PyTorch reference mode (both eager and torch.compile) #339


Open
wants to merge 1 commit into main

Conversation

@yf225 yf225 (Contributor) commented Jul 20, 2025

Stacked PRs:

  • PyTorch reference mode (both eager and torch.compile)

Fixes #77.
Please see inline code comments on the PR.

yf225 added a commit that referenced this pull request Jul 20, 2025
stack-info: PR: #339, branch: yf225/stack/34
@meta-cla meta-cla bot added the CLA Signed label Jul 20, 2025
    fn = torch.compile(fn, fullgraph=True)
    result = fn(*args)
else:
    result = fn(*args)
yf225 (Contributor, Author) commented:

High-level idea: patch the hl.* ops as well as the necessary torch.* ops so that the Helion kernel can run in the two modes below (a rough sketch of the dispatch follows the list):

  • reference eager mode
  • reference torch.compile mode
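
To make that dispatch concrete, here is a minimal sketch under stated assumptions: the helper name `_run_ref` and the exact way the mode is selected are illustrative, not necessarily how this PR wires it up (only HELION_REF_COMPILE=1 is referenced elsewhere in this thread).

```python
import os

import torch


def _run_ref(kernel_fn, *args):
    """Illustrative sketch: run a Helion kernel body as plain PyTorch.

    Assumes the hl.* ops have already been swapped for pure-PyTorch
    reference implementations (e.g. tiles become Python slices).
    """
    if os.environ.get("HELION_REF_COMPILE") == "1":
        # Reference torch.compile mode: trace the whole body with Dynamo.
        compiled = torch.compile(kernel_fn, fullgraph=True)
        return compiled(*args)
    # Reference eager mode: just call the Python body directly.
    return kernel_fn(*args)
```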

)

# Step 3: Handle block_size (in ref mode, full dim size is always used as block_size)
block_size_list = [None] * len(end_list)
@yf225 yf225 (Contributor, Author) commented Jul 20, 2025:

Always use the full dim size in ref modes, regardless of the block_size value.
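
A minimal sketch of that rule (the helper name `_ref_tile_slices` is hypothetical, not the PR's actual helper):

```python
def _ref_tile_slices(end_list: list[int]) -> list[slice]:
    """In ref mode every dimension gets a single tile spanning the whole
    dimension, i.e. block_size == dim size, whatever block_size was set to."""
    return [slice(0, end) for end in end_list]


# For a (128, 257) problem this yields [slice(0, 128), slice(0, 257)],
# so the kernel body executes once over the full tensors.
```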

x_part = hl.load(
    x, [tile0, tile1], extra_mask=(tile1.index < x.size(1))[None, :]
)
@yf225 yf225 (Contributor, Author) commented Jul 20, 2025:

Since we are treating the tile as a Python slice object in ref mode, tile.index no longer works and we have to use hl.tile_index().

To make the UX better, in a follow-up PR I am thinking of adding a RefTile class that Dynamo can understand, and supporting tile APIs like .index / .begin / .end on that class.
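
A rough sketch of what such a RefTile could look like (speculative follow-up, not code in this PR; how loads convert it back to a slice is one of the open design questions):

```python
import torch


class RefTile:
    """Sketch: a ref-mode tile that keeps the normal tile API surface."""

    def __init__(self, begin: int, end: int, device: torch.device | str = "cpu") -> None:
        self._slice = slice(begin, end)
        self._device = device

    @property
    def begin(self) -> int:
        return self._slice.start

    @property
    def end(self) -> int:
        return self._slice.stop

    @property
    def index(self) -> torch.Tensor:
        # Absolute indices covered by this tile, matching tile.index semantics.
        return torch.arange(self.begin, self.end, device=self._device)

    def as_slice(self) -> slice:
        # Loads/stores in ref mode would convert the tile back to a slice.
        return self._slice
```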

A contributor commented:

We should make tile.index work. Maybe rather than changing the examples we should skip tests in reference mode.

@yf225 yf225 requested review from jansel, oulgen, drisspg and joydddd July 20, 2025 23:50
class TestExamplesRefCompile(test_examples.TestExamples):
    """Run all TestExamples tests in reference torch.compile mode via HELION_REF_COMPILE=1."""

# NOTE: All tests in TestExamples are run in ref torch.compile(fullgraph=True) mode by default in this test file.
@yf225 yf225 (Contributor, Author) commented Jul 20, 2025:

Currently all examples in TestExamples pass in both ref eager mode and ref compile mode.

I'm planning to add more ref mode unit tests covering test_reduce.py, test_associative_scan.py, etc. in the next PR.
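
A sketch of how such a subclass can flip every inherited test into ref compile mode, assuming HELION_REF_COMPILE=1 is read per kernel run (the exact mechanism in the PR may differ):

```python
import os
import unittest

import test_examples  # the existing eager-mode example test suite


class TestExamplesRefCompile(test_examples.TestExamples):
    """Re-run every TestExamples test with HELION_REF_COMPILE=1."""

    def setUp(self) -> None:
        os.environ["HELION_REF_COMPILE"] = "1"
        super().setUp()

    def tearDown(self) -> None:
        super().tearDown()
        os.environ.pop("HELION_REF_COMPILE", None)


if __name__ == "__main__":
    unittest.main()
```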

@yf225 yf225 changed the title PyTorch reference mode (both eager and torch.compile) PyTorch reference mode (supports both eager and torch.compile) Jul 21, 2025
@yf225 yf225 changed the title PyTorch reference mode (supports both eager and torch.compile) PyTorch reference mode (both eager and torch.compile) Jul 21, 2025
@@ -2,3 +2,4 @@ pytest
typing_extensions
pre-commit
filecheck
numpy
yf225 (Contributor, Author) commented:

Unfortunately torch.compile now requires numpy via the torch.onnx dependency:

File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 275, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 793, in transform
    tracer.run()
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3518, in run
    super().run()
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1371, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1275, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 851, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2922, in CALL
    self._call(inst)
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2916, in _call
    self.call_function(fn, args, kwargs)
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1199, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py", line 1429, in call_function
    special_handler = self._get_handlers().get(self.value)
                      ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py", line 476, in _get_handlers
    @register(*tracing_state_functions())
               ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py", line 193, in tracing_state_functions
    torch.onnx.is_in_onnx_export: False,
    ^^^^^^^^^^
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/__init__.py", line 2734, in __getattr__
    return importlib.import_module(f".{name}", __name__)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/venv/lib/python3.12/importlib/__init__.py", line 90, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1387, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1360, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1331, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 935, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 999, in exec_module
  File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/onnx/__init__.py", line 51, in <module>
    from ._internal.exporter._onnx_program import ONNXProgram
  File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_onnx_program.py", line 18, in <module>
    import numpy as np
torch._dynamo.exc.InternalTorchDynamoError: ModuleNotFoundError: No module named 'numpy'

We could either try to patch PyTorch to do something like "don't access torch.onnx if numpy is not available", or add the numpy dependency to Helion.


def zeros(shape: list[int | slice], dtype: torch.dtype = torch.float32) -> Tensor:
    processed_shape = _normalize_shape(shape)
    return torch.zeros(processed_shape, dtype=dtype, device="cuda")
A contributor commented:

Shouldn't hardcode cuda
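
One way to address this is to thread the device through instead of hardcoding it; a minimal sketch (the explicit `device` parameter and the fallback logic are illustrative assumptions, not necessarily what the PR ends up doing):

```python
import torch
from torch import Tensor


def zeros(
    shape: list[int | slice],
    dtype: torch.dtype = torch.float32,
    device: torch.device | str | None = None,
) -> Tensor:
    processed_shape = _normalize_shape(shape)  # existing helper in this file
    if device is None:
        # Fall back to the current accelerator rather than hardcoding "cuda".
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.zeros(processed_shape, dtype=dtype, device=device)
```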

    yield from dim_tiles[0]
else:
    # Multi-dimensional - yield tuples of slices
    yield from itertools.product(*dim_tiles)
A contributor commented:

Can you explain why we would need more than 1 tile?

_e5m2_matmul_available = False


def _patched_addmm(
A contributor commented:

Why do we need to patch PyTorch ops?

@yf225 yf225 (Contributor, Author) commented Jul 21, 2025:

  • For addmm and baddbmm:
    Helion examples use the pattern torch.addmm(acc_fp32, A_bf16, B_bf16), which is a natural way to express the intent of "do a bf16 GEMM, then accumulate into fp32", but it fails under eager mode / torch.compile execution (PyTorch raises an error) because the three input tensors have mismatched dtypes.
    One way to work around this is to patch those PyTorch GEMM ops so the accumulator may have a different dtype than the other inputs (implemented in this PR); a rough sketch follows after this list.

  • For matmul:
    The PyTorch fp8 GEMM API (torch._scaled_mm) doesn't support FP8 E5M2 input, so there is currently no way to enable PyTorch reference mode for fp8 GEMM kernels that use FP8 E5M2 input (which appears quite common in TritonBench). The workaround here is to override torch.matmul in ref mode to dispatch to a custom Triton E5M2 GEMM kernel for this case.
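
A rough sketch of the addmm patch described above (illustrative only; the PR's actual implementation, argument handling, and the matmul/fp8 path differ):

```python
import torch

_original_addmm = torch.addmm


def _patched_addmm(
    input: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
) -> torch.Tensor:
    """Allow torch.addmm(acc_fp32, A_bf16, B_bf16) in reference mode."""
    if input.dtype == mat1.dtype == mat2.dtype:
        # Matching dtypes: defer to the real op.
        return _original_addmm(input, mat1, mat2, beta=beta, alpha=alpha)
    # Mixed dtypes: do the GEMM in the operand dtype, then accumulate into the
    # accumulator's dtype. NOTE: computing the product in bf16 and then casting
    # loses precision vs. fp32 accumulation; see the precision discussion
    # further down this thread.
    result = alpha * torch.matmul(mat1, mat2).to(input.dtype)
    if beta == 0:
        return result
    return result + beta * input
```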

yf225 (Contributor, Author) commented:

Discussed offline:

  • For addmm and baddbmm:
    Instead of monkey-patching, use a __torch_function__ mode (https://docs.pytorch.org/tutorials/recipes/torch_compile_torch_function_modes.html); a rough sketch follows after this list.

  • For matmul:
    We can remove this patching entirely and do the following:

    • Change all examples to use torch._scaled_mm instead of torch.matmul for fp8 GEMM.
    • When fp8 inputs are passed to torch.matmul, in Helion as well as in PyTorch eager mode, show an error message pointing users to torch._scaled_mm.
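
A minimal sketch of the __torch_function__-mode approach (the class name is hypothetical and alpha/beta handling is omitted; see the linked tutorial for the real API):

```python
import torch
from torch.overrides import TorchFunctionMode


class HelionRefTorchFunctionMode(TorchFunctionMode):
    """Sketch: intercept addmm/baddbmm instead of monkey-patching them."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in (torch.addmm, torch.baddbmm) and args[0].dtype != args[1].dtype:
            # Mixed-dtype accumulator: compute the GEMM in the operand dtype,
            # then accumulate into the accumulator's dtype (alpha/beta handling
            # omitted for brevity).
            acc, a, b = args[:3]
            return torch.matmul(a, b).to(acc.dtype) + acc
        return func(*args, **kwargs)


# Usage: calls made inside the context manager, including ones traced by
# torch.compile, go through the hook above.
# with HelionRefTorchFunctionMode():
#     out = torch.addmm(acc_fp32, a_bf16, b_bf16)
```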

A contributor commented:

hopefully we can point to the public one soon: pytorch/pytorch#157950

A commenter wrote:

For bf16 × bf16 -> fp32, torch supports this via addmm/baddbmm calls with out_dtype arguments.

A commenter wrote:

e5m2 is very rarely used; the absolute majority of training jobs have switched to e4m3, which is also the only fp8 format supported for grouped GEMMs. Instead of trying to mock coverage for e5m2, should we work with TritonBench to switch to more realistic test cases? Bad benchmarks are worse than no benchmarks.

  self.env = CompileEnvironment(_find_device(args), self.kernel.settings)
- with self.env:
+ with self.env:  # pyright: ignore[reportOptionalContextManager]
A contributor commented:

Fix types throughout.

Comment on lines +63 to +67
    ref_eager: bool = False,
    ref_compile: bool = False,
A contributor commented:

These should go in settings, not as new attributes.

What happens if both of these are set? Should this be an enum?
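
If this becomes a single settings enum, a minimal sketch (field and value names are illustrative) could look like:

```python
import enum


class RefMode(enum.Enum):
    """Fold ref_eager/ref_compile into one setting so the two modes can
    never be enabled at the same time."""

    OFF = "off"
    EAGER = "eager"
    COMPILE = "compile"


# Hypothetical use inside Settings and the runner:
#   ref_mode: RefMode = RefMode.OFF
#
#   if settings.ref_mode is RefMode.COMPILE:
#       fn = torch.compile(fn, fullgraph=True)
#   elif settings.ref_mode is RefMode.EAGER:
#       pass  # run the reference body directly
```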


if beta == 0:
    return result
return result + (beta * input)
A commenter wrote:

This loses precision compared to a matmul that produces an fp32 result, and higher precision is the main reason to use it at all. If Helion is using torch.addmm(fp32, bf16, bf16), should it follow PyTorch semantics and add an out_dtype argument when needed?
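
For comparison, a hedged sketch of the higher-precision behavior this comment is pointing at, upcasting the operands so the product is computed and accumulated in the accumulator's dtype (whether the PR should do this or rely on an out_dtype argument is the open question):

```python
import torch


def _ref_addmm_high_precision(
    acc: torch.Tensor,
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    *,
    beta: float = 1.0,
    alpha: float = 1.0,
) -> torch.Tensor:
    # Upcast the bf16 operands so the matmul itself produces fp32, instead of
    # computing a bf16 product and casting it afterwards.
    result = alpha * torch.matmul(mat1.to(acc.dtype), mat2.to(acc.dtype))
    if beta == 0:
        return result
    return result + beta * acc
```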

Labels
CLA Signed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Automatic reference implementations
4 participants