Skip to content

Commit cf36f83

Browse files
committed
specialize undefinedness
1 parent 5719fb3 commit cf36f83

File tree

3 files changed

+88
-14
lines changed

3 files changed

+88
-14
lines changed

test/jit_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import tempfile
3434
import textwrap
3535

36-
IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR = False
36+
IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR = True
3737

3838
class ProfilingMode(Enum):
3939
OFF = 1
@@ -44,7 +44,8 @@ class ProfilingMode(Enum):
4444
def enable_profiling_mode(flag):
4545
if IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR:
4646
old_prof_exec_state = torch._C._jit_set_profiling_executor(flag != ProfilingMode.OFF)
47-
old_prof_mode_state = torch._C._jit_set_profiling_mode(flag == ProfilingMode.FULL)
47+
#old_prof_mode_state = torch._C._jit_set_profiling_mode(flag == ProfilingMode.FULL)
48+
old_prof_mode_state = torch._C._jit_set_profiling_mode(False)
4849
try:
4950
yield
5051
finally:

test/test_jit_fuser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
if IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR:
2222
torch._C._jit_set_profiling_executor(True)
23-
torch._C._jit_set_profiling_mode(True)
23+
torch._C._jit_set_profiling_mode(False)
2424

2525

2626
def strip_profiling_nodes(nodes):

torch/csrc/jit/graph_executor.cpp

Lines changed: 84 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -206,21 +206,27 @@ static void unpackReturnTuple(Stack &stack) {
206206
stack.insert(stack.end(), tuple->elements().begin(), tuple->elements().end());
207207
}
208208

209+
struct DifferentiableGraphOp;
210+
209211
struct DifferentiableGraphBackward : public autograd::Node {
210212
DifferentiableGraphBackward(
211-
GraphExecutor executor,
213+
const std::shared_ptr<Graph>& unspec_graph,
212214
size_t input_size,
213-
size_t capture_size)
214-
: executor(std::move(executor)),
215-
captures_(capture_size),
216-
input_instructions_(input_size) {}
215+
size_t capture_size,
216+
c10::optional<GraphExecutor>& grad_executor)
217+
: captures_(capture_size),
218+
input_instructions_(input_size),
219+
unspecialized_graph_(unspec_graph),
220+
grad_executor_(grad_executor) {}
217221

218222
variable_list apply(variable_list&& inputs) override {
219223
Stack stack;
220-
stack.reserve(captures_.size() + inputs.size());
224+
size_t num_args = captures_.size() + inputs.size();
225+
stack.reserve(num_args);
221226

222227
input_instructions_.unpack(std::move(inputs), stack);
223228
captures_.unpack(stack, shared_from_this());
229+
GraphExecutor& executor = getExecutor(stack);
224230
GRAPH_DEBUG("Running DifferentiableGraphBackward for ", &executor);
225231
executor.run(stack);
226232
unpackReturnTuple(stack);
@@ -259,6 +265,66 @@ struct DifferentiableGraphBackward : public autograd::Node {
259265
captures_.capture(val, is_output);
260266
}
261267

268+
269+
static c10::TensorTypePtr getTensorType(bool defined) {
270+
auto tensor_type = TensorType::get();
271+
272+
if (defined) {
273+
return tensor_type;
274+
}
275+
276+
return tensor_type->withUndefined();
277+
}
278+
279+
GraphExecutor& getExecutor(Stack& stack) {
280+
281+
// tensor lists are hashed as a single boolean value
282+
// since all tensors will be either defined or undefined
283+
284+
std::vector<bool> hash;
285+
286+
for (IValue& v : stack) {
287+
if (v.isTensorList()) {
288+
auto list = v.toTensorListRef();
289+
hash.push_back(list.size() > 0 ? list[0].defined() : true);
290+
} else if (v.isTensor()) {
291+
hash.push_back(v.toTensor().defined());
292+
} else {
293+
// assume that every other type is defined
294+
hash.push_back(true);
295+
}
296+
}
297+
298+
299+
TORCH_INTERNAL_ASSERT(unspecialized_graph_->inputs().size() == hash.size());
300+
if (grad_executors_.count(hash) == 0) {
301+
302+
std::shared_ptr<Graph> spec_copy = unspecialized_graph_->copy();
303+
304+
for (auto i = 0; i < hash.size(); i++) {
305+
auto input_type = spec_copy->inputs().at(i);
306+
bool defined = hash[i];
307+
if (input_type->type()->kind() == TensorType::Kind) {
308+
input_type->setType(getTensorType(defined));
309+
} else if (
310+
input_type->type()->kind() == ListType::Kind &&
311+
input_type->type()->expect<ListType>()->getElementType()->kind() ==
312+
TensorType::Kind) {
313+
input_type->setType(ListType::create(getTensorType(defined)));
314+
}
315+
}
316+
grad_executors_[hash] = GraphExecutor(spec_copy);
317+
}
318+
319+
320+
// set last optimized graph
321+
// make a copy because DifferentiableBackward might disappear
322+
// by the time we get to use diff_op_.grad_executor
323+
grad_executor_ = GraphExecutor(grad_executors_[hash].graph()) ;
324+
return grad_executors_[hash];
325+
326+
}
327+
262328
void addOutputForTensor(const at::Tensor& tensor) {
263329
auto v = Variable(tensor);
264330
add_next_edge(v.defined() ? v.gradient_edge() : autograd::Edge{});
@@ -319,6 +385,9 @@ struct DifferentiableGraphBackward : public autograd::Node {
319385
GraphExecutor executor;
320386
CaptureList captures_;
321387
UnpackInstructions input_instructions_;
388+
std::unordered_map<std::vector<bool>, GraphExecutor> grad_executors_;
389+
std::shared_ptr<Graph> unspecialized_graph_;
390+
c10::optional<GraphExecutor>& grad_executor_;
322391
};
323392

324393
// an optimized way of executing the subgraph computed directly on
@@ -330,17 +399,18 @@ struct DifferentiableGraphOp {
330399
DifferentiableGraphOp(Gradient grad)
331400
: f(grad.f),
332401
grad(std::move(grad)),
333-
grad_executor(this->grad.df),
402+
grad_executor(),
334403
num_inputs(this->grad.f->inputs().size()),
335404
num_outputs(this->grad.f->outputs().size()) {}
336405

337406
// XXX: keep in mind that stack can be larger than the inputs we need!
338407
int operator()(Stack& stack) const {
339408
auto grad_fn = std::make_shared<DifferentiableGraphBackward>(
340-
grad_executor,
409+
this->grad.df,
341410
grad.df_input_vjps.size(),
342411
grad.df_input_captured_inputs.size() +
343-
grad.df_input_captured_outputs.size());
412+
grad.df_input_captured_outputs.size(),
413+
grad_executor);
344414

345415
{
346416
auto inputs = last(stack, num_inputs);
@@ -378,6 +448,7 @@ struct DifferentiableGraphOp {
378448

379449
private:
380450
friend GraphExecutor* detail::getGradExecutor(Operation& op);
451+
friend struct DifferentiableGraphBackward;
381452

382453
at::Tensor detach(at::Tensor t) const {
383454
if (!t.defined()) {
@@ -426,7 +497,7 @@ struct DifferentiableGraphOp {
426497

427498
Code f;
428499
Gradient grad;
429-
GraphExecutor grad_executor;
500+
mutable c10::optional<GraphExecutor> grad_executor;
430501

431502
const size_t num_inputs;
432503
const size_t num_outputs;
@@ -459,7 +530,9 @@ namespace detail {
459530

460531
GraphExecutor* getGradExecutor(Operation& op) {
461532
if (auto diff_op = op.target<DifferentiableGraphOp>()) {
462-
return &diff_op->grad_executor;
533+
534+
TORCH_INTERNAL_ASSERT(diff_op->grad_executor.has_value())
535+
return &(*diff_op->grad_executor);
463536
}
464537
return nullptr;
465538
}

0 commit comments

Comments (0)