@@ -189,56 +189,39 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
189189{
190190 using namespace ngraph ;
191191
192- // Traverse fprop to make a map that stores parameters with the same
193- // shape and element type as the nodes in fprop
194- NodeMap node_param_map;
195- ngraph::traverse_nodes (fprop, [&node_param_map](std::shared_ptr<Node> node) {
196- node_param_map.add (
197- node, std::make_shared<op::Parameter>(node->get_element_type (), node->get_shape ()));
198- });
199-
200192 // Traverse bprop to find all of the nodes in the graph
201193 std::unordered_set<std::shared_ptr<Node>> in_bprop;
202194 ngraph::traverse_nodes (bprop, [&in_bprop](std::shared_ptr<Node> node) {
203- if (in_bprop.count (node) == 0 )
195+
196+ if (node->get_outputs ().size () == 1 )
204197 {
205- in_bprop.insert (node);
198+ if (in_bprop.count (node) == 0 )
199+ {
200+ in_bprop.insert (node);
201+ }
206202 }
203+
207204 });
208205
209- // Get the input paramters of fprop
210- std::unordered_set<std::shared_ptr<Node>> fprop_params;
211- for (auto node : fprop->get_parameters ())
212- {
213- if (fprop_params.count (node) == 0 )
206+ // Traverse fprop to make a map that stores parameters with the same
207+ // shape and element type as the nodes in fprop
208+ FpropCache fprop_cache;
209+ fprop_cache.node_param_map = std::make_shared<NodeMap>();
210+ ngraph::traverse_nodes (fprop, [&fprop_cache, &in_bprop](std::shared_ptr<Node> node) {
211+ if (in_bprop.count (node) != 0 )
214212 {
215- fprop_params.insert (node);
213+ fprop_cache.node_param_map ->add (
214+ node, std::make_shared<op::Parameter>(node->get_element_type (), node->get_shape ()));
216215 }
217- }
216+ });
218217
219218 // Find all of the nodes that are intermediate values of fprop and used in
220219 // bprop
221220 // and store those nodes that aren't needed in bprop
222- FpropCache fprop_cache;
223221 std::vector<std::shared_ptr<Node>> unused_nodes;
224- for (auto kv : node_param_map.get_node_map ())
225- {
226- // if it's not in bprop, mark it unused
227- if (in_bprop.count (kv.first ) == 0 )
228- {
229- unused_nodes.push_back (kv.first );
230- }
231- // otherwise save in in the ouputs
232- else
233- {
234- fprop_cache.fprop_output_nodes .push_back (kv.first );
235- }
236- }
237-
238- // erase all unused nodes form the map
239- for (auto node : unused_nodes)
222+ for (auto kv : fprop_cache.node_param_map ->get_node_map ())
240223 {
241- node_param_map. get_node_map (). erase (node );
224+ fprop_cache. fprop_output_nodes . push_back (kv. first );
242225 }
243226
244227 // create the new outputs for fprop and the new fprop function
@@ -262,13 +245,13 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
262245
263246 // clone the nodes in bprop, replacing fprop-related nodes with the
264247 // intermediate parameters
265- ngraph::clone_nodes (bprop->get_ops (), node_param_map);
248+ ngraph::clone_nodes (bprop->get_ops (), *(fprop_cache. node_param_map ) );
266249
267250 // get cloned bprop results
268251 ResultVector cloned_results;
269252 for (auto node : bprop->get_results ())
270253 {
271- auto result = std::dynamic_pointer_cast<op::Result>(node_param_map. get (node));
254+ auto result = std::dynamic_pointer_cast<op::Result>(fprop_cache. node_param_map -> get (node));
272255 if (!result)
273256 {
274257 throw ngraph_error (" Expected op::Result values for op::Result keys in node_param_map" );
@@ -281,14 +264,14 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
281264 for (auto param : adjoints)
282265 {
283266 bprop_input_params.push_back (
284- std::dynamic_pointer_cast<op::Parameter>(node_param_map. get (param)));
267+ std::dynamic_pointer_cast<op::Parameter>(fprop_cache. node_param_map -> get (param)));
285268 }
286269
287270 // add the cached fprop nodes as inputs to bprop
288271 for (auto x : fprop_cache.fprop_output_nodes )
289272 {
290273 bprop_input_params.push_back (
291- std::dynamic_pointer_cast<op::Parameter>(node_param_map. get (x)));
274+ std::dynamic_pointer_cast<op::Parameter>(fprop_cache. node_param_map -> get (x)));
292275 }
293276
294277 // create the new bprop function
0 commit comments