From d2b7aa89e81d59a00874412c7b6cda9562048ccf Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 10 Sep 2018 16:42:56 -0700 Subject: [PATCH] WIP attempt to support torch.einsum in jit --- test/test_jit_einsum.py | 17 +++++++++++++++++ torch/csrc/jit/tracer.cpp | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 test/test_jit_einsum.py diff --git a/test/test_jit_einsum.py b/test/test_jit_einsum.py new file mode 100644 index 0000000000000..4000276f99446 --- /dev/null +++ b/test/test_jit_einsum.py @@ -0,0 +1,17 @@ +import torch +from common import TestCase + + +class TestEinsum(TestCase): + def test_jit(self): + + def fn(x, y): + return torch.einsum('i,j->ij', x, y) + + jit_fn = torch.jit.trace(fn, (torch.ones(2), torch.ones(3)), check_trace=False) + + x = torch.randn(2) + y = torch.randn(3) + expected = fn(x, y) + actual = jit_fn(x, y) + self.assertLess(torch.abs(actual - expected) < 1e-6) diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index d2c4ef9f0da5a..4132109bfd652 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -39,7 +39,7 @@ void addInputs(Node *n, const char * name, bool value) { detail::g void addInputs(Node *n, const char * name, double value) { detail::genericAddInput(n, value); } void addInputs(Node *n, const char * name, const at::Scalar& value) { detail::genericAddInput(n, value); } void addInputs(Node *n, const char * name, const at::Tensor& value) { n->addInput(getValueTrace(value)); } -void addInputs(Node *n, const char * name, const std::string& value) { detail::badArgType(); } +void addInputs(Node *n, const char * name, const std::string& value) { detail::genericAddInput(n, value); } void addInputs(Node *n, const char * name, const at::SparseTensorRef& value) { detail::badArgType(); } void addInputs(Node *n, const char * name, at::TensorList value) {