Skip to content

Commit 40a8a50

Browse files
pilleyefacebook-github-bot
authored andcommitted
Add static_runtime::fused_equally_split (#2)
Summary: Pull Request resolved: pytorch/pytorch-canary#2 Pull Request resolved: pytorch#66881 Adds `static_runtime::fused_equally_split` operator and removes `is_fused` logic from original operator. Modifies `FuseUnpackListV2` to map `fb::equally_split` to this new operator. Test Plan: ``` adityapillai@5960 /data/sandcastle/boxes/fbsource/fbcode 1m 13s ❯ buck test //caffe2/benchmarks/static_runtime/fb:test_fb_operators ``` and sandcastle strange_what_could_go_wrong Reviewed By: mikeiovine Differential Revision: D31742293 fbshipit-source-id: 60b35589c8817719b005d49811f575b6590d1c39
1 parent 391eb1d commit 40a8a50

File tree

3 files changed

+21
-16
lines changed

3 files changed

+21
-16
lines changed

benchmarks/static_runtime/test_utils.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,18 @@ void testStaticRuntime(
277277
compareTensorLists(args_tensors, args_copy, use_allclose, use_equalnan);
278278
}
279279

280+
bool hasProcessedNodeWithName(
281+
torch::jit::StaticModule& smodule,
282+
const char* name) {
283+
for (torch::jit::ProcessedNode& pnode : smodule.runtime().nodes()) {
284+
auto op_name = pnode.node()->kind().toQualString();
285+
if (strcmp(op_name, name) == 0) {
286+
return true;
287+
}
288+
}
289+
return false;
290+
}
291+
280292
} // namespace test
281293
} // namespace jit
282294
} // namespace torch

benchmarks/static_runtime/test_utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <vector>
77

88
#include <torch/csrc/jit/ir/ir.h>
9+
#include <torch/csrc/jit/runtime/static/impl.h>
910

1011
namespace c10 {
1112
struct IValue;
@@ -30,6 +31,8 @@ void testStaticRuntime(
3031

3132
std::shared_ptr<Graph> getGraphFromIR(const std::string& ir);
3233

34+
bool hasProcessedNodeWithName(torch::jit::StaticModule& smodule, const char *name);
35+
3336
} // namespace test
3437
} // namespace jit
3538
} // namespace torch

torch/csrc/jit/runtime/static/passes.cpp

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,9 @@ TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
347347
m.def(torch::schema(
348348
"static_runtime::VarTupleUnpack(...) -> ...",
349349
c10::AliasAnalysisKind::CONSERVATIVE));
350+
m.def(torch::schema(
351+
"static_runtime::fused_equally_split(Tensor input, int num_split, int dim) -> ...",
352+
c10::AliasAnalysisKind::PURE_FUNCTION));
350353
}
351354

352355
void FuseSignLog1P(std::shared_ptr<torch::jit::Graph>& graph) {
@@ -529,13 +532,11 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
529532
const std::vector<Value*> graph_outputs(
530533
graph->outputs().begin(), graph->outputs().end());
531534
auto nodes = graph->nodes();
532-
std::vector<Node*> equally_splits_to_remove;
533535
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
534536
Node* node = *it;
535537
const std::string node_qual_string = node->kind().toQualString();
536538
if (node_qual_string == "fb::sigrid_transforms" ||
537539
node_qual_string == "fb::sigrid_transforms_torch_bind" ||
538-
node_qual_string == "fb::equally_split" ||
539540
node_qual_string == "fb::gather_ranges_to_dense" ||
540541
node_qual_string == "fb::gather_ranges_to_dense_v2" ||
541542
node_qual_string == "fb::variadic_sigrid_transforms_torch_bind") {
@@ -577,22 +578,9 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
577578
it_next.destroyCurrent(); // remove list_unpack
578579

579580
node->eraseOutput(0);
580-
581-
if (node_qual_string == "fb::equally_split" &&
582-
node->outputs().size() == 1) {
583-
// This captures a case of `y = fb::equally_split(x, 1, _)` where y
584-
// becomes just an alias of x.
585-
// If this case is found, replace y with x to avoid executing this op.
586-
equally_splits_to_remove.push_back(node);
587-
}
588581
}
589582
}
590583

591-
for (Node* node : equally_splits_to_remove) {
592-
node->output(0)->replaceAllUsesWith(node->input(0));
593-
node->destroy();
594-
}
595-
596584
#ifndef NDEBUG
597585
graph->lint();
598586
AliasDb db2(graph);
@@ -601,7 +589,9 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
601589
}
602590

603591
void FuseListUnpackV2(std::shared_ptr<torch::jit::Graph>& graph) {
604-
const FastMap<c10::Symbol, c10::Symbol> unfused_to_fused = {};
592+
const FastMap<c10::Symbol, c10::Symbol> unfused_to_fused = {
593+
{c10::Symbol::fromQualString("fb::equally_split"),
594+
c10::Symbol::fromQualString("static_runtime::fused_equally_split")}};
605595

606596
AliasDb alias_db(
607597
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);

0 commit comments

Comments
 (0)