Skip to content

Commit 910179c

Browse files
authored
fix hello_world tf result printing (#2910)
1 parent 81b57ca commit 910179c

File tree

1 file changed

+36
-16
lines changed

1 file changed

+36
-16
lines changed

examples/hello-world/hello_world.ipynb

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -559,21 +559,30 @@
559559
},
560560
"outputs": [],
561561
"source": [
562-
"from nvflare.fuel.utils import fobs\n",
563-
"from nvflare.app_common.decomposers import common_decomposers\n",
564562
"import pprint\n",
565-
"\n",
566-
"# This example stores numpy arrays in FOBS format. Decomposers for Numpy is not registered automatically.\n",
567-
"common_decomposers.register()\n",
563+
"from nvflare.app_opt.tf.utils import flat_layer_weights_dict\n",
564+
"from tensorflow.keras import layers, models\n",
568565
"\n",
569566
"result = sess.download_job_result(job_id)\n",
570-
"with open(result + \"/workspace/app_server/tf_model.weights.h5\", \"rb\") as f:\n",
571-
" bytes = f.read()\n",
572567
"\n",
573-
"weights = fobs.loads(bytes)\n",
568+
"class Net(models.Sequential):\n",
569+
" def __init__(self, input_shape=(None, 28, 28)):\n",
570+
" super().__init__()\n",
571+
" self._input_shape = input_shape\n",
572+
" self.add(layers.Flatten())\n",
573+
" self.add(layers.Dense(128, activation=\"relu\"))\n",
574+
" self.add(layers.Dropout(0.2))\n",
575+
" self.add(layers.Dense(10))\n",
576+
"\n",
577+
"model = Net()\n",
578+
"model.build(input_shape=(None, 28, 28))\n",
579+
"model.load_weights(result + \"/workspace/app_server/tf_model.weights.h5\")\n",
580+
"model.summary()\n",
581+
"\n",
582+
"layer_weights_dict = flat_layer_weights_dict({layer.name: layer.get_weights() for layer in model.layers})\n",
574583
"\n",
575584
"pp = pprint.PrettyPrinter(indent=4)\n",
576-
"pp.pprint(weights)"
585+
"pp.pprint(layer_weights_dict)"
577586
]
578587
},
579588
{
@@ -872,19 +881,30 @@
872881
},
873882
"outputs": [],
874883
"source": [
875-
"from nvflare.fuel.utils import fobs\n",
876-
"from nvflare.app_common.decomposers import common_decomposers\n",
877884
"import pprint\n",
885+
"from nvflare.app_opt.tf.utils import flat_layer_weights_dict\n",
886+
"from tensorflow.keras import layers, models\n",
878887
"\n",
879-
"common_decomposers.register()\n",
880888
"result = sess.download_job_result(job_id)\n",
881-
"with open(result + \"/workspace/app_server/tf_model.weights.h5\", \"rb\") as f:\n",
882-
" bytes = f.read()\n",
883889
"\n",
884-
"weights = fobs.loads(bytes)\n",
890+
"class Net(models.Sequential):\n",
891+
" def __init__(self, input_shape=(None, 28, 28)):\n",
892+
" super().__init__()\n",
893+
" self._input_shape = input_shape\n",
894+
" self.add(layers.Flatten())\n",
895+
" self.add(layers.Dense(128, activation=\"relu\"))\n",
896+
" self.add(layers.Dropout(0.2))\n",
897+
" self.add(layers.Dense(10))\n",
898+
"\n",
899+
"model = Net()\n",
900+
"model.build(input_shape=(None, 28, 28))\n",
901+
"model.load_weights(result + \"/workspace/app_server/tf_model.weights.h5\")\n",
902+
"model.summary()\n",
903+
"\n",
904+
"layer_weights_dict = flat_layer_weights_dict({layer.name: layer.get_weights() for layer in model.layers})\n",
885905
"\n",
886906
"pp = pprint.PrettyPrinter(indent=4)\n",
887-
"pp.pprint(weights)"
907+
"pp.pprint(layer_weights_dict)"
888908
]
889909
},
890910
{

0 commit comments

Comments
 (0)