Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 995671a

Browse files
authored
[v0.1.0] Multi-output fprop_cache tentative fix (#657)
Contains multiple fixes to GetOutputElement, BatchNorm, autodiff, fprop_cache to integrate multi-output batchnorm and fprop_cache
1 parent feeaed5 commit 995671a

File tree

5 files changed

+53
-42
lines changed

5 files changed

+53
-42
lines changed

src/ngraph/autodiff/adjoints.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
125125
void autodiff::Adjoints::add_delta(const std::shared_ptr<Node>& x,
126126
const std::shared_ptr<Node>& delta)
127127
{
128-
if (!x->has_same_type(delta))
128+
if (!x->has_same_type(delta) && delta->get_shape() != x->get_outputs().at(0).get_shape())
129129
{
130130
throw ngraph_error("Autodiff internal error: Mismatch on backprop and op in add_delta.");
131131
}

src/ngraph/ops/batch_norm.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,23 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
148148
auto gamma = get_input_op(0);
149149
auto beta = get_input_op(1);
150150
auto input = get_input_op(2);
151-
auto mean = std::make_shared<op::GetOutputElement>(shared_from_this(), 1);
152-
auto var = std::make_shared<op::GetOutputElement>(shared_from_this(), 2);
151+
152+
//Extract mean and variance outputs from BatchNorm
153+
//as these are used by BatchNormBackprop.
154+
//The users of the outputs (GetOutputElements' Inputs) aren't sorted
155+
//and get_n() is used to sort the inputs in the same order as Batchnorm's outputs
156+
//Next, Mean and Variance (`at(1)` and `at(2)`) are extracted
157+
//Please see `add_output` in `BatchNorm::BatchNorm` for more details
158+
std::vector<std::shared_ptr<Node>> goes(get_outputs().size());
159+
160+
for (auto _input : get_output_inputs(0))
161+
{
162+
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(_input->get_node());
163+
goes.at(goe->get_n()) = _input->get_node();
164+
}
165+
166+
auto mean = goes.at(1);
167+
auto var = goes.at(2);
153168
auto bbn = std::make_shared<op::BatchNormBackprop>(
154169
get_eps_value(), gamma, beta, input, mean, var, delta);
155170
auto dinput = std::make_shared<op::GetOutputElement>(bbn, 0);

src/ngraph/ops/get_output_element.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ namespace ngraph
6868
}
6969

7070
protected:
71+
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
72+
const std::shared_ptr<Node>& delta) override
73+
{
74+
//Filter out updates(deltas) from mean and variance (for batchnorm)
75+
//as dinput is the only update required.
76+
//This logic needs to be generalized as new multi-output ops are introduced
77+
if (get_n() == 0)
78+
{
79+
adjoints.add_delta(get_inputs().at(0).get_output().get_node(), delta);
80+
}
81+
}
7182
size_t m_n;
7283
};
7384
}

src/ngraph/util.cpp

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -189,56 +189,39 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
189189
{
190190
using namespace ngraph;
191191

192-
// Traverse fprop to make a map that stores parameters with the same
193-
// shape and element type as the nodes in fprop
194-
NodeMap node_param_map;
195-
ngraph::traverse_nodes(fprop, [&node_param_map](std::shared_ptr<Node> node) {
196-
node_param_map.add(
197-
node, std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
198-
});
199-
200192
// Traverse bprop to find all of the nodes in the graph
201193
std::unordered_set<std::shared_ptr<Node>> in_bprop;
202194
ngraph::traverse_nodes(bprop, [&in_bprop](std::shared_ptr<Node> node) {
203-
if (in_bprop.count(node) == 0)
195+
196+
if (node->get_outputs().size() == 1)
204197
{
205-
in_bprop.insert(node);
198+
if (in_bprop.count(node) == 0)
199+
{
200+
in_bprop.insert(node);
201+
}
206202
}
203+
207204
});
208205

209-
// Get the input parameters of fprop
210-
std::unordered_set<std::shared_ptr<Node>> fprop_params;
211-
for (auto node : fprop->get_parameters())
212-
{
213-
if (fprop_params.count(node) == 0)
206+
// Traverse fprop to make a map that stores parameters with the same
207+
// shape and element type as the nodes in fprop
208+
FpropCache fprop_cache;
209+
fprop_cache.node_param_map = std::make_shared<NodeMap>();
210+
ngraph::traverse_nodes(fprop, [&fprop_cache, &in_bprop](std::shared_ptr<Node> node) {
211+
if (in_bprop.count(node) != 0)
214212
{
215-
fprop_params.insert(node);
213+
fprop_cache.node_param_map->add(
214+
node, std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
216215
}
217-
}
216+
});
218217

219218
// Find all of the nodes that are intermediate values of fprop and used in
220219
// bprop
221220
// and store those nodes that aren't needed in bprop
222-
FpropCache fprop_cache;
223221
std::vector<std::shared_ptr<Node>> unused_nodes;
224-
for (auto kv : node_param_map.get_node_map())
225-
{
226-
// if it's not in bprop, mark it unused
227-
if (in_bprop.count(kv.first) == 0)
228-
{
229-
unused_nodes.push_back(kv.first);
230-
}
231-
// otherwise save it in the outputs
232-
else
233-
{
234-
fprop_cache.fprop_output_nodes.push_back(kv.first);
235-
}
236-
}
237-
238-
// erase all unused nodes from the map
239-
for (auto node : unused_nodes)
222+
for (auto kv : fprop_cache.node_param_map->get_node_map())
240223
{
241-
node_param_map.get_node_map().erase(node);
224+
fprop_cache.fprop_output_nodes.push_back(kv.first);
242225
}
243226

244227
// create the new outputs for fprop and the new fprop function
@@ -262,13 +245,13 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
262245

263246
// clone the nodes in bprop, replacing fprop-related nodes with the
264247
// intermediate parameters
265-
ngraph::clone_nodes(bprop->get_ops(), node_param_map);
248+
ngraph::clone_nodes(bprop->get_ops(), *(fprop_cache.node_param_map));
266249

267250
// get cloned bprop results
268251
ResultVector cloned_results;
269252
for (auto node : bprop->get_results())
270253
{
271-
auto result = std::dynamic_pointer_cast<op::Result>(node_param_map.get(node));
254+
auto result = std::dynamic_pointer_cast<op::Result>(fprop_cache.node_param_map->get(node));
272255
if (!result)
273256
{
274257
throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map");
@@ -281,14 +264,14 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
281264
for (auto param : adjoints)
282265
{
283266
bprop_input_params.push_back(
284-
std::dynamic_pointer_cast<op::Parameter>(node_param_map.get(param)));
267+
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(param)));
285268
}
286269

287270
// add the cached fprop nodes as inputs to bprop
288271
for (auto x : fprop_cache.fprop_output_nodes)
289272
{
290273
bprop_input_params.push_back(
291-
std::dynamic_pointer_cast<op::Parameter>(node_param_map.get(x)));
274+
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(x)));
292275
}
293276

294277
// create the new bprop function

src/ngraph/util.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace ngraph
2828
{
2929
class Node;
3030
class Function;
31+
class NodeMap;
3132
class stopwatch;
3233

3334
namespace runtime
@@ -229,6 +230,7 @@ namespace ngraph
229230
std::shared_ptr<Function> fprop;
230231
std::shared_ptr<Function> bprop;
231232
std::vector<std::shared_ptr<Node>> fprop_output_nodes;
233+
std::shared_ptr<NodeMap> node_param_map;
232234
};
233235

234236
/**

0 commit comments

Comments
 (0)