|
559 | 559 | }, |
560 | 560 | "outputs": [], |
561 | 561 | "source": [ |
562 | | - "from nvflare.fuel.utils import fobs\n", |
563 | | - "from nvflare.app_common.decomposers import common_decomposers\n", |
564 | 562 | "import pprint\n", |
565 | | - "\n", |
566 | | - "# This example stores numpy arrays in FOBS format. Decomposers for Numpy is not registered automatically.\n", |
567 | | - "common_decomposers.register()\n", |
| 563 | + "from nvflare.app_opt.tf.utils import flat_layer_weights_dict\n", |
| 564 | + "from tensorflow.keras import layers, models\n", |
568 | 565 | "\n", |
569 | 566 | "result = sess.download_job_result(job_id)\n", |
570 | | - "with open(result + \"/workspace/app_server/tf_model.weights.h5\", \"rb\") as f:\n", |
571 | | - " bytes = f.read()\n", |
572 | 567 | "\n", |
573 | | - "weights = fobs.loads(bytes)\n", |
| 568 | + "class Net(models.Sequential):\n", |
| 569 | + " def __init__(self, input_shape=(None, 28, 28)):\n", |
| 570 | + " super().__init__()\n", |
| 571 | + " self._input_shape = input_shape\n", |
| 572 | + " self.add(layers.Flatten())\n", |
| 573 | + " self.add(layers.Dense(128, activation=\"relu\"))\n", |
| 574 | + " self.add(layers.Dropout(0.2))\n", |
| 575 | + " self.add(layers.Dense(10))\n", |
| 576 | + "\n", |
| 577 | + "model = Net()\n", |
| 578 | + "model.build(input_shape=(None, 28, 28))\n", |
| 579 | + "model.load_weights(result + \"/workspace/app_server/tf_model.weights.h5\")\n", |
| 580 | + "model.summary()\n", |
| 581 | + "\n", |
| 582 | + "layer_weights_dict = flat_layer_weights_dict({layer.name: layer.get_weights() for layer in model.layers})\n", |
574 | 583 | "\n", |
575 | 584 | "pp = pprint.PrettyPrinter(indent=4)\n", |
576 | | - "pp.pprint(weights)" |
| 585 | + "pp.pprint(layer_weights_dict)" |
577 | 586 | ] |
578 | 587 | }, |
579 | 588 | { |
|
872 | 881 | }, |
873 | 882 | "outputs": [], |
874 | 883 | "source": [ |
875 | | - "from nvflare.fuel.utils import fobs\n", |
876 | | - "from nvflare.app_common.decomposers import common_decomposers\n", |
877 | 884 | "import pprint\n", |
| 885 | + "from nvflare.app_opt.tf.utils import flat_layer_weights_dict\n", |
| 886 | + "from tensorflow.keras import layers, models\n", |
878 | 887 | "\n", |
879 | | - "common_decomposers.register()\n", |
880 | 888 | "result = sess.download_job_result(job_id)\n", |
881 | | - "with open(result + \"/workspace/app_server/tf_model.weights.h5\", \"rb\") as f:\n", |
882 | | - " bytes = f.read()\n", |
883 | 889 | "\n", |
884 | | - "weights = fobs.loads(bytes)\n", |
| 890 | + "class Net(models.Sequential):\n", |
| 891 | + " def __init__(self, input_shape=(None, 28, 28)):\n", |
| 892 | + " super().__init__()\n", |
| 893 | + " self._input_shape = input_shape\n", |
| 894 | + " self.add(layers.Flatten())\n", |
| 895 | + " self.add(layers.Dense(128, activation=\"relu\"))\n", |
| 896 | + " self.add(layers.Dropout(0.2))\n", |
| 897 | + " self.add(layers.Dense(10))\n", |
| 898 | + "\n", |
| 899 | + "model = Net()\n", |
| 900 | + "model.build(input_shape=(None, 28, 28))\n", |
| 901 | + "model.load_weights(result + \"/workspace/app_server/tf_model.weights.h5\")\n", |
| 902 | + "model.summary()\n", |
| 903 | + "\n", |
| 904 | + "layer_weights_dict = flat_layer_weights_dict({layer.name: layer.get_weights() for layer in model.layers})\n", |
885 | 905 | "\n", |
886 | 906 | "pp = pprint.PrettyPrinter(indent=4)\n", |
887 | | - "pp.pprint(weights)" |
| 907 | + "pp.pprint(layer_weights_dict)" |
888 | 908 | ] |
889 | 909 | }, |
890 | 910 | { |
|
0 commit comments