@@ -206,21 +206,27 @@ static void unpackReturnTuple(Stack &stack) {
206206 stack.insert (stack.end (), tuple->elements ().begin (), tuple->elements ().end ());
207207}
208208
209+ struct DifferentiableGraphOp ;
210+
209211struct DifferentiableGraphBackward : public autograd ::Node {
210212 DifferentiableGraphBackward (
211- GraphExecutor executor ,
213+ const std::shared_ptr<Graph>& unspec_graph ,
212214 size_t input_size,
213- size_t capture_size)
214- : executor(std::move(executor)),
215- captures_ (capture_size),
216- input_instructions_(input_size) {}
215+ size_t capture_size,
216+ c10::optional<GraphExecutor>& grad_executor)
217+ : captures_(capture_size),
218+ input_instructions_ (input_size),
219+ unspecialized_graph_(unspec_graph),
220+ grad_executor_(grad_executor) {}
217221
218222 variable_list apply (variable_list&& inputs) override {
219223 Stack stack;
220- stack.reserve (captures_.size () + inputs.size ());
224+ size_t num_args = captures_.size () + inputs.size ();
225+ stack.reserve (num_args);
221226
222227 input_instructions_.unpack (std::move (inputs), stack);
223228 captures_.unpack (stack, shared_from_this ());
229+ GraphExecutor& executor = getExecutor (stack);
224230 GRAPH_DEBUG (" Running DifferentiableGraphBackward for " , &executor);
225231 executor.run (stack);
226232 unpackReturnTuple (stack);
@@ -259,6 +265,66 @@ struct DifferentiableGraphBackward : public autograd::Node {
259265 captures_.capture (val, is_output);
260266 }
261267
268+
269+ static c10::TensorTypePtr getTensorType (bool defined ) {
270+ auto tensor_type = TensorType::get ();
271+
272+ if (defined ) {
273+ return tensor_type;
274+ }
275+
276+ return tensor_type->withUndefined ();
277+ }
278+
279+ GraphExecutor& getExecutor (Stack& stack) {
280+
281+ // tensor lists are hashed as a single boolean value
282+ // since all tensors will be either defined or undefined
283+
284+ std::vector<bool > hash;
285+
286+ for (IValue& v : stack) {
287+ if (v.isTensorList ()) {
288+ auto list = v.toTensorListRef ();
289+ hash.push_back (list.size () > 0 ? list[0 ].defined () : true );
290+ } else if (v.isTensor ()) {
291+ hash.push_back (v.toTensor ().defined ());
292+ } else {
293+ // assume that every other type is defined
294+ hash.push_back (true );
295+ }
296+ }
297+
298+
299+ TORCH_INTERNAL_ASSERT (unspecialized_graph_->inputs ().size () == hash.size ());
300+ if (grad_executors_.count (hash) == 0 ) {
301+
302+ std::shared_ptr<Graph> spec_copy = unspecialized_graph_->copy ();
303+
304+ for (auto i = 0 ; i < hash.size (); i++) {
305+ auto input_type = spec_copy->inputs ().at (i);
306+ bool defined = hash[i];
307+ if (input_type->type ()->kind () == TensorType::Kind) {
308+ input_type->setType (getTensorType (defined ));
309+ } else if (
310+ input_type->type ()->kind () == ListType::Kind &&
311+ input_type->type ()->expect <ListType>()->getElementType ()->kind () ==
312+ TensorType::Kind) {
313+ input_type->setType (ListType::create (getTensorType (defined )));
314+ }
315+ }
316+ grad_executors_[hash] = GraphExecutor (spec_copy);
317+ }
318+
319+
320+ // set last optimized graph
321+ // make a copy because DifferentiableBackward might disappear
322+ // by the time we get to use diff_op_.grad_executor
323+ grad_executor_ = GraphExecutor (grad_executors_[hash].graph ()) ;
324+ return grad_executors_[hash];
325+
326+ }
327+
262328 void addOutputForTensor (const at::Tensor& tensor) {
263329 auto v = Variable (tensor);
264330 add_next_edge (v.defined () ? v.gradient_edge () : autograd::Edge{});
@@ -319,6 +385,9 @@ struct DifferentiableGraphBackward : public autograd::Node {
319385 GraphExecutor executor;
320386 CaptureList captures_;
321387 UnpackInstructions input_instructions_;
388+ std::unordered_map<std::vector<bool >, GraphExecutor> grad_executors_;
389+ std::shared_ptr<Graph> unspecialized_graph_;
390+ c10::optional<GraphExecutor>& grad_executor_;
322391};
323392
324393// an optimized way of executing the subgraph computed directly on
@@ -330,17 +399,18 @@ struct DifferentiableGraphOp {
330399 DifferentiableGraphOp (Gradient grad)
331400 : f(grad.f),
332401 grad (std::move(grad)),
333- grad_executor(this ->grad.df ),
402+ grad_executor(),
334403 num_inputs(this ->grad.f->inputs ().size()),
335404 num_outputs(this ->grad.f->outputs ().size()) {}
336405
337406 // XXX: keep in mind that stack can be larger than the inputs we need!
338407 int operator ()(Stack& stack) const {
339408 auto grad_fn = std::make_shared<DifferentiableGraphBackward>(
340- grad_executor ,
409+ this -> grad . df ,
341410 grad.df_input_vjps .size (),
342411 grad.df_input_captured_inputs .size () +
343- grad.df_input_captured_outputs .size ());
412+ grad.df_input_captured_outputs .size (),
413+ grad_executor);
344414
345415 {
346416 auto inputs = last (stack, num_inputs);
@@ -378,6 +448,7 @@ struct DifferentiableGraphOp {
378448
379449 private:
380450 friend GraphExecutor* detail::getGradExecutor (Operation& op);
451+ friend struct DifferentiableGraphBackward ;
381452
382453 at::Tensor detach (at::Tensor t) const {
383454 if (!t.defined ()) {
@@ -426,7 +497,7 @@ struct DifferentiableGraphOp {
426497
427498 Code f;
428499 Gradient grad;
429- GraphExecutor grad_executor;
500+ mutable c10::optional< GraphExecutor> grad_executor;
430501
431502 const size_t num_inputs;
432503 const size_t num_outputs;
@@ -459,7 +530,9 @@ namespace detail {
459530
460531GraphExecutor* getGradExecutor (Operation& op) {
461532 if (auto diff_op = op.target <DifferentiableGraphOp>()) {
462- return &diff_op->grad_executor ;
533+
534+ TORCH_INTERNAL_ASSERT (diff_op->grad_executor .has_value ())
535+ return &(*diff_op->grad_executor );
463536 }
464537 return nullptr ;
465538}
0 commit comments