PyTorch reference mode (both eager and torch.compile) #339
base: main
Conversation
stack-info: PR: #339, branch: yf225/stack/34
    fn = torch.compile(fn, fullgraph=True)
    result = fn(*args)
else:
    result = fn(*args)
High-level idea: patch the `hl.*` ops as well as the necessary `torch.*` ops, to be able to run the Helion kernel in:

- reference eager mode
- reference `torch.compile` mode
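A minimal sketch of the top-level dispatch under this idea (the `HELION_REF_COMPILE` env var appears in this PR's tests; the wrapper function itself is illustrative, not the PR's exact code):

```python
import os
import torch

def run_kernel_in_ref_mode(fn, *args):
    # With hl.* (and the necessary torch.*) ops patched to reference
    # implementations, the kernel body is plain Python/PyTorch code.
    if os.environ.get("HELION_REF_COMPILE") == "1":
        # Reference torch.compile mode: trace the whole kernel body.
        fn = torch.compile(fn, fullgraph=True)
    # Reference eager mode: just call the patched kernel body directly.
    return fn(*args)
```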
helion/ref/hl_patch.py
)

# Step 3: Handle block_size (in ref mode, full dim size is always used as block_size)
block_size_list = [None] * len(end_list)
Always use the full dim size in ref modes, regardless of the `block_size` value.
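Concretely, a hypothetical illustration, assuming ref mode represents a tile as a plain Python slice:

```python
def ref_tile(dim_size: int, block_size: int | None = None) -> slice:
    # In ref mode the block_size hint is ignored: a single tile covers
    # the entire dimension, so the kernel body runs once per dimension.
    return slice(0, dim_size)

assert ref_tile(128, block_size=32) == slice(0, 128)
```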
x_part = hl.load(
    x, [tile0, tile1], extra_mask=(tile1.index < x.size(1))[None, :]
)
Since we are treating `tile` as a Python slice object in ref mode, `tile.index` no longer works and we have to use `hl.tile_index()`.

To make the UX better, in a follow-up PR I am thinking of adding a `RefTile` class that Dynamo can understand, and supporting tile APIs like `.index` / `.begin` / `.end` on that class.
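A rough sketch of what such a `RefTile` could look like (hypothetical; the follow-up PR may shape this differently):

```python
import torch

class RefTile:
    """Ref-mode tile backed by a plain Python slice (hypothetical)."""

    def __init__(self, begin: int, end: int) -> None:
        self.begin = begin
        self.end = end

    @property
    def index(self) -> torch.Tensor:
        # tile.index: the offsets this tile covers, as a tensor.
        return torch.arange(self.begin, self.end)

    def to_slice(self) -> slice:
        # Used when the tile indexes a tensor dimension in ref mode.
        return slice(self.begin, self.end)
```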
We should make `tile.index` work. Maybe rather than changing the examples we should skip tests in reference mode.
class TestExamplesRefCompile(test_examples.TestExamples):
    """Run all TestExamples tests in reference torch.compile mode via HELION_REF_COMPILE=1."""

# NOTE: All tests in TestExamples are run in ref torch.compile(fullgraph=True) mode by default in this test file.
Currently all examples in `TestExamples` pass with ref eager mode and ref compile mode. Planning to add more ref mode unit tests to cover `test_reduce.py`, `test_associative_scan.py`, etc. in the next PR.
@@ -2,3 +2,4 @@ pytest
 typing_extensions
 pre-commit
 filecheck
+numpy
Unfortunately `torch.compile` now requires `numpy` via a `torch.onnx` dependency:
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 275, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 793, in transform
tracer.run()
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3518, in run
super().run()
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1371, in run
while self.step():
^^^^^^^^^^^
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1275, in step
self.dispatch_table[inst.opcode](self, inst)
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 851, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2922, in CALL
self._call(inst)
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2916, in _call
self.call_function(fn, args, kwargs)
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1199, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py", line 1429, in call_function
special_handler = self._get_handlers().get(self.value)
^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py", line 476, in _get_handlers
@register(*tracing_state_functions())
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py", line 193, in tracing_state_functions
torch.onnx.is_in_onnx_export: False,
^^^^^^^^^^
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/__init__.py", line 2734, in __getattr__
return importlib.import_module(f".{name}", __name__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/venv/lib/python3.12/importlib/__init__.py", line 90, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<frozen importlib._bootstrap>", line 1387, in _gcd_import
File "<frozen importlib._bootstrap>", line 1360, in _find_and_load
File "<frozen importlib._bootstrap>", line 1331, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 935, in _load_unlocked
File "<frozen importlib._bootstrap_external>", line 999, in exec_module
File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/onnx/__init__.py", line 51, in <module>
from ._internal.exporter._onnx_program import ONNXProgram
File "/opt/conda/envs/venv/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_onnx_program.py", line 18, in <module>
import numpy as np
torch._dynamo.exc.InternalTorchDynamoError: ModuleNotFoundError: No module named 'numpy'
We could either try to patch PyTorch to do something like "don't access `torch.onnx` if `numpy` is not available", or add the `numpy` dependency to Helion.
helion/ref/hl_patch.py
def zeros(shape: list[int | slice], dtype: torch.dtype = torch.float32) -> Tensor:
    processed_shape = _normalize_shape(shape)
    return torch.zeros(processed_shape, dtype=dtype, device="cuda")
Shouldn't hardcode `cuda`.
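A possible fix, sketched under the assumption that ref mode can recover the device from the kernel's input tensors (`_current_ref_device` is a hypothetical helper; `_find_device(args)` elsewhere in this PR suggests the information is available):

```python
def zeros(
    shape: list[int | slice],
    dtype: torch.dtype = torch.float32,
    device: torch.device | str | None = None,
) -> Tensor:
    processed_shape = _normalize_shape(shape)
    # Fall back to the device of the kernel's inputs instead of
    # hardcoding "cuda".
    return torch.zeros(
        processed_shape, dtype=dtype, device=device or _current_ref_device()
    )
```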
helion/ref/hl_patch.py
    yield from dim_tiles[0]
else:
    # Multi-dimensional - yield tuples of slices
    yield from itertools.product(*dim_tiles)
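For context, a toy illustration of what the multi-dimensional branch yields (assuming `dim_tiles` holds one list of slices per tiled dimension):

```python
import itertools

# Two tiled dimensions, each covered by a single full-extent slice in ref mode.
dim_tiles = [[slice(0, 64)], [slice(0, 128)]]

# One tuple of slices per grid point:
print(list(itertools.product(*dim_tiles)))
# [(slice(0, 64, None), slice(0, 128, None))]
```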
Can you explain why we would need more than 1 tile?
helion/ref/torch_patch.py
_e5m2_matmul_available = False


def _patched_addmm(
Why do we need to patch PyTorch ops?
- For `addmm` and `baddbmm`: in Helion examples there is a usage pattern of `torch.addmm(acc_fp32, A_bf16, B_bf16)`, which makes sense for expressing the intent of "do a bf16 GEMM, then accumulate into fp32", but fails eager-mode / `torch.compile` execution (PyTorch raises an error) due to the dtype mismatch between the three input tensors. To work around this, one way is to patch those PyTorch GEMM ops to allow `acc` to have a different dtype than the other inputs (implemented in this PR).
- For `matmul`: the PyTorch fp8 GEMM API (`torch._scaled_mm`) doesn't support FP8 E5M2 input, so there is currently no way to enable PyTorch reference mode for fp8 GEMM kernels that use FP8 E5M2 inputs (which appear quite common in TritonBench). The workaround here is to override `torch.matmul` in ref mode to dispatch to a custom Triton e5m2 GEMM kernel to support this use case.
Discussed offline:

- For `addmm` and `baddbmm`: instead of monkey-patching, use `__torch_function__` mode (https://docs.pytorch.org/tutorials/recipes/torch_compile_torch_function_modes.html); see the sketch after this list.
- For `matmul`: we can completely remove this patching and do the following:
  - Change all examples to use `torch._scaled_mm` instead of `torch.matmul` for fp8 GEMM.
  - When people pass fp8 inputs to `torch.matmul` in Helion and in PyTorch eager mode, show an error message pointing them to `torch._scaled_mm`.
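For the `addmm`/`baddbmm` case, a minimal sketch of the `__torch_function__`-mode approach (not the PR's implementation; upcasting the operands is one possible semantic, and the tutorial linked above covers how such modes compose with `torch.compile`):

```python
import torch
from torch.overrides import TorchFunctionMode

class RefAddmmMode(TorchFunctionMode):
    """Allow torch.addmm(acc_fp32, a_bf16, b_bf16) by upcasting the
    matmul operands to the accumulator's dtype (hypothetical sketch)."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.addmm and args[0].dtype != args[1].dtype:
            acc, a, b = args[:3]
            # The mode is disabled inside this handler, so calling func
            # again does not recurse back into __torch_function__.
            return func(acc, a.to(acc.dtype), b.to(acc.dtype), **kwargs)
        return func(*args, **kwargs)

acc = torch.zeros(4, 4, dtype=torch.float32)
a = torch.randn(4, 8, dtype=torch.bfloat16)
b = torch.randn(8, 4, dtype=torch.bfloat16)
with RefAddmmMode():
    out = torch.addmm(acc, a, b)  # fp32 result, no dtype-mismatch error
```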
hopefully we can point to the public one soon: pytorch/pytorch#157950
For bf16 x bf16 -> fp32, torch supports these `addmm`/`baddbmm` calls with an `out_dtype` argument.
e5m2 is very rarely used; the absolute majority of training jobs have switched to e4m3, which is also the only fp8 format supported for grouped GEMMs. Instead of trying to mock coverage for e5m2, should we work with TritonBench to switch to more realistic test cases? Bad benchmarks are worse than no benchmarks.
self.env = CompileEnvironment(_find_device(args), self.kernel.settings)
with self.env:  # pyright: ignore[reportOptionalContextManager]
Fix types throughout.
    ref_eager: bool = False,
    ref_compile: bool = False,
These should go in settings, not as new attributes.

What happens if both of these are set? Should this be an enum?
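If a single setting is preferred, one possible shape (hypothetical):

```python
import enum

class RefMode(enum.Enum):
    OFF = "off"          # normal Helion compilation
    EAGER = "eager"      # run the kernel body as plain PyTorch eager code
    COMPILE = "compile"  # run it under torch.compile(fullgraph=True)

# e.g. in Settings: ref_mode: RefMode = RefMode.OFF
# A single enum makes the ref_eager + ref_compile combination unrepresentable.
```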
if beta == 0:
    return result
return result + (beta * input)
This loses precision compared to having the matmul produce an fp32 result, and higher precision is the main reason to use it at all. If Helion is using `torch.addmm(fp32, bf16, bf16)`, should it follow PyTorch semantics and add the `out_dtype` argument when needed?
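To make the precision point concrete, a small self-contained comparison (not from the PR; it contrasts a bf16 matmul that is upcast afterwards with a matmul accumulated in fp32 throughout):

```python
import torch

a = torch.randn(512, 512, dtype=torch.bfloat16)
b = torch.randn(512, 512, dtype=torch.bfloat16)
acc = torch.randn(512, 512, dtype=torch.float32)

# bf16 x bf16 -> bf16 matmul, upcast only afterwards: bits are lost when
# the intermediate result is rounded to bf16.
lossy = (a @ b).to(torch.float32) + acc

# fp32 accumulation throughout (what an fp32 out_dtype would provide).
exact = a.to(torch.float32) @ b.to(torch.float32) + acc

print((lossy - exact).abs().max())  # noticeably nonzero
```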
Stacked PRs:
PyTorch reference mode (both eager and torch.compile)
Fixes #77.
Please see inline code comments on the PR.