Conversation

xmfan (Member) commented Aug 20, 2025

Stack from ghstack (oldest at bottom):

Fixes silent incorrectness for autograd function tracing, where we rely on FakeTensor metadata (`requires_grad`) to determine whether to use the HOP or not:

if requires_grad and torch.is_grad_enabled():

Stared at this with @anijain2305 yesterday: `Tensor.__setitem__` can update tensor metadata, and we can just run the fake prop and extract the output metadata from the updated FakeTensor.
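To make the metadata mutation concrete, here is a minimal standalone sketch (not this PR's test) of `Tensor.__setitem__` flipping `requires_grad` on the target in eager mode:

```python
import torch

x = torch.randn(4)                      # plain leaf, requires_grad=False
y = torch.randn(4, requires_grad=True)

assert not x.requires_grad
x[0] = y[0]                             # Tensor.__setitem__ with a grad-requiring source
assert x.requires_grad                  # metadata was mutated in place
assert type(x.grad_fn).__name__ == "CopySlices"
```

Under `torch.compile`, running the fake prop for the same op mutates the FakeTensor's metadata the same way, which is why reading `requires_grad` off the updated FakeTensor gives the right answer.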

FIXES #160901

It should also be the root cause behind the issue in pytorch/torchtitan#1604 @bdhirsh @ruisizhang123

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @mlazos

pytorch-bot (bot) commented Aug 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161036

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 33ecc00 with merge base 95e456f:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

xmfan added a commit that referenced this pull request Aug 20, 2025
xmfan added a commit that referenced this pull request Aug 20, 2025
@xmfan xmfan marked this pull request as ready for review August 20, 2025 14:41
@xmfan xmfan requested review from anijain2305 and bdhirsh August 20, 2025 14:44
        torch.compile(fn, backend="eager")(x2, y2).sum().backward()
        self.assertTrue(x2.requires_grad)

        self.assertEqual(x1.grad, x2.grad)

Contributor:
I tried running your test locally without your changes, and it looks like this test passes without the changes in this PR. It turns out that `x` is a non-leaf tensor after the mutation (`x.grad_fn` is `CopySlices`), so `x.grad` is `None` in both examples.

You probably need to assert on `self.assertEqual(y1.grad, y2.grad)` instead, since the "wrong gradients" that we are computing today should propagate into `y2.grad`.
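A quick standalone check of the non-leaf point (names are illustrative, not from the PR's test):

```python
import torch

x = torch.randn(4)
y = torch.randn(4, requires_grad=True)
x[0] = y[0]                    # the mutation turns x into a non-leaf tensor
assert not x.is_leaf
assert type(x.grad_fn).__name__ == "CopySlices"

# backward() only populates .grad on leaf tensors, so x.grad stays None
# whether or not the computed gradients are correct; the signal lands on y.
(x * y).sum().backward()
assert x.grad is None
assert y.grad is not None
```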


xmfan (Member, Author):

Thanks for the catch; let me just add the original repro's permutations...

        target_cls, tx, example_value, infer_subclass_type(example_value)
    )
    for k, v in specialized_props.items():
        setattr(self, k, v)

Contributor:
We might need to audit a few other places in dynamo too. For example, this will also change the requires_grad-ness of `x`:

    x = torch.randn(4)
    y = torch.randn(4, requires_grad=True)
    x.add_(y)
    # x.requires_grad is now True

so any tensor data mutation ops will need the same treatment.
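A runnable version of that sketch (standalone, assuming eager mode with grad enabled):

```python
import torch

x = torch.randn(4)                      # starts as a requires_grad=False leaf
y = torch.randn(4, requires_grad=True)

x.add_(y)                               # in-place data mutation with a grad-requiring operand
assert x.requires_grad                  # same metadata flip as Tensor.__setitem__
assert x.grad_fn is not None            # x is now a non-leaf with a backward node
```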


Contributor:
Confirmed locally. I tried tweaking your `input_fn` to use `add_` and hit the same correctness issue when asserting on `y.grad`:

        def fn(x, y):
            x.add_(y)
            return MyFn.apply(x)


Contributor:
Hmm, does this PR also fix the problem for `add_`? It's not clear to me where. (If not, no worries, but can you file another issue?)


bdhirsh (Contributor) left a comment:

Left a few comments - nice find!

xmfan added a commit that referenced this pull request Aug 20, 2025
xmfan added a commit that referenced this pull request Aug 21, 2025
@xmfan xmfan requested a review from bdhirsh August 21, 2025 04:00
xmfan (Member, Author) commented Aug 21, 2025

The hf_Bart failure reproduces on the base commit.

xmfan (Member, Author) commented Aug 22, 2025

@pytorchbot merge -i

pytorch-bot added the `ciflow/trunk` label (Trigger trunk jobs on your pull request) Aug 22, 2025
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged while ignoring the following 1 checks: inductor / linux-jammy-cpu-py3.9-gcc11-inductor / test (cpu_inductor_torchbench, 1, 2, linux.8xlarge.amx)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

