[dynamo] propagate tensor metadata on Tensor.__setitem__(tensor) #161036
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161036
Note: links to docs will display an error until the doc builds have completed.
❌ 1 New Failure as of commit 33ecc00 with merge base 95e456f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/dynamo/test_repros.py (outdated)

```python
torch.compile(fn, backend="eager")(x2, y2).sum().backward()
self.assertTrue(x2.requires_grad)

self.assertEqual(x1.grad, x2.grad)
```
I tried running your test locally without your changes, and it looks like this test passes without the changes in this PR - it turns out that `x` is a non-leaf tensor after the mutation (`x.grad_fn` is CopySlices), so `x.grad` is `None` in both examples.

You probably need to assert on `self.assertEqual(y1.grad, y2.grad)` instead, since the "wrong gradients" that we are computing today should propagate into `y2.grad`.
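For illustration, a minimal standalone sketch of that behavior (my own example, not taken from the test):

```python
import torch

x = torch.randn(4)                       # leaf tensor, requires_grad=False
y = torch.randn(4, requires_grad=True)
x[:2] = y[:2]                            # Tensor.__setitem__ with a tensor value
print(x.requires_grad)                   # True: metadata propagated from y
print(x.grad_fn)                         # CopySlices -> x is now a non-leaf
x.sum().backward()
print(x.grad)                            # None: .grad is only populated for leaf tensors
print(y.grad)                            # tensor([1., 1., 0., 0.]): grads flow to y
```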
Thanks for the catch, let me just add the original repro's permutations...
```python
    target_cls, tx, example_value, infer_subclass_type(example_value)
)
for k, v in specialized_props.items():
    setattr(self, k, v)
```
We might need to audit a few other places in dynamo too. For example, this will also change the requires_grad-ness of `x`:

```python
x = torch.randn(4)
y = torch.randn(4, requires_grad=True)
x.add_(y)
# x.requires_grad is now True
```

so any tensor data mutation ops will need the same treatment.
Confirmed locally - I tried tweaking your input_fn to use `add_` and hit the same correctness issue when asserting on `y.grad`:

```python
def fn(x, y):
    x.add_(y)
    return MyFn.apply(x)
```
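For reference, a self-contained sketch of what that repro might look like; `MyFn`'s body and the setup names here are my assumptions, not the test's actual code:

```python
import torch

class MyFn(torch.autograd.Function):  # hypothetical stand-in for the test's MyFn
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2  # non-trivial backward so a missed HOP shows up in grads

def fn(x, y):
    x.add_(y)
    return MyFn.apply(x)

x1, y1 = torch.randn(4), torch.randn(4, requires_grad=True)
x2 = x1.clone()
y2 = y1.detach().clone().requires_grad_(True)

fn(x1, y1).sum().backward()                            # eager reference
torch.compile(fn, backend="eager")(x2, y2).sum().backward()
torch.testing.assert_close(y1.grad, y2.grad)           # per the comment: assert on y.grad
```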
Hmm, does this PR also fix the problem for `add_`? It's not clear to me where (if not, no worries, but can you file another issue?)
Left a few comments - nice find!
hf_Bart fails on base commit

@pytorchbot merge -i

Merge started. Your change will be merged while ignoring the following 1 check: inductor / linux-jammy-cpu-py3.9-gcc11-inductor / test (cpu_inductor_torchbench, 1, 2, linux.8xlarge.amx). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Fixes silent incorrectness for autograd function tracing, where we rely on FakeTensor metadata (requires_grad) to determine whether to HOP or not:
pytorch/torch/_dynamo/variables/misc.py, line 671 in 5ee464d

Stared at this with @anijain2305 yesterday: `Tensor.__setitem__` can update tensor metadata, and we can just run the fake prop and extract the output metadata from the updated FakeTensor.

FIXES #160901
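As a rough illustration of that idea (a sketch under my own assumptions, not the actual Dynamo change):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Sketch: run the __setitem__ on FakeTensors and read the mutated metadata
# off the fake target, instead of trusting a stale cached value.
with FakeTensorMode() as fake_mode:
    fake_x = fake_mode.from_tensor(torch.zeros(4))
    fake_y = fake_mode.from_tensor(torch.randn(4, requires_grad=True))
    fake_x[:2] = fake_y[:2]          # fake prop mutates fake_x's metadata too
    print(fake_x.requires_grad)      # True -> propagate onto the TensorVariable
```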
It should also be the root cause behind the issue in pytorch/torchtitan#1604 @bdhirsh @ruisizhang123
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos