diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index ad153963..45096edf 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -231,6 +231,11 @@ def to_fake(self, obj: object, origin: Origin) -> object:

             fn = extract_helper_function(obj)
             return lift_closures(fn, origin)
+        # Handle GraphModule - treat it like a function
+        if isinstance(obj, torch.fx.GraphModule):
+            # A GraphModule is itself callable, so return it as-is;
+            # it will be invoked when called during execution
+            return obj
         if isinstance(obj, ConstExpr):
             return obj.value
         if isinstance(obj, str):
diff --git a/helion/exc.py b/helion/exc.py
index e742cb25..4030c575 100644
--- a/helion/exc.py
+++ b/helion/exc.py
@@ -353,3 +353,7 @@ class InvalidSequenceSubscription(BaseError):

 class InvalidAPIUsage(BaseError):
     message = "Invalid usage of Helion API: {0}"
+
+
+class GraphModuleUnsupportedOps(BaseError):
+    message = "GraphModule contains unsupported operations: {0}. Only pure computation graphs are supported (no get_attr or call_module ops)."
diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py
index 57782fe0..7a0fc370 100644
--- a/helion/runtime/kernel.py
+++ b/helion/runtime/kernel.py
@@ -4,6 +4,7 @@
 import dataclasses
 import functools
 import inspect
+import itertools
 import logging
 import operator
 import re
@@ -22,7 +23,9 @@
 from torch._dynamo.source import TensorProperty
 from torch._dynamo.source import TensorPropertySource
 from torch._inductor.codecache import PyCodeCache
+from torch._inductor.codecache import compiled_fx_graph_hash
 from torch._subclasses import FakeTensor
+from torch.utils.weak import WeakIdKeyDictionary

 from .. import exc
 from .._compiler.ast_extension import unparse
@@ -55,6 +58,9 @@
 _R = TypeVar("_R")
 CompiledConfig = Callable[..., _R]

+# Cache for GraphModule hashes
+_graph_module_hash_cache: WeakIdKeyDictionary = WeakIdKeyDictionary()
+

 class Kernel(Generic[_R]):
     def __init__(
@@ -203,7 +209,10 @@ def _specialization_key(self, obj: object) -> Hashable:
         try:
             extractor = _specialization_extractors[type(obj)]
         except KeyError:
-            if isinstance(obj, tuple) and hasattr(obj, "_fields"):
+            if isinstance(obj, torch.fx.GraphModule):
+                # GraphModule instances are generated subclasses; match the base class
+                extractor = _specialization_extractors[torch.fx.GraphModule]
+            elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
                 # this is a namedtuple
                 extractor = _specialization_extractors["namedtuple"]
             elif dataclasses.is_dataclass(obj):
@@ -696,6 +705,27 @@ def _function_key(fn: Kernel, obj: types.FunctionType) -> object:
     return obj.__code__


+def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable:
+    """Generate a specialization key for GraphModule arguments."""
+    # Check if already cached
+    if obj in _graph_module_hash_cache:
+        return _graph_module_hash_cache[obj]
+
+    # Check for unsupported operations
+    unsupported_ops = {
+        node.op
+        for node in itertools.chain(
+            obj.graph.find_nodes(op="call_module"),
+            obj.graph.find_nodes(op="get_attr"),
+        )
+    }
+    if unsupported_ops:
+        raise exc.GraphModuleUnsupportedOps(", ".join(sorted(unsupported_ops)))
+
+    _graph_module_hash_cache[obj] = rv = str(compiled_fx_graph_hash(obj, [], {}, []))
+    return rv
+
+
 _specialization_extractors: dict[
     type[object] | str, Callable[[Kernel, object], Hashable]
 ] = {  # pyright: ignore[reportAssignmentType]
@@ -715,6 +745,7 @@ def _function_key(fn: Kernel, obj: types.FunctionType) -> object:
     "dataclass": lambda fn, x: _mapping_key(fn, dataclasses.asdict(x), type(x)),  # pyright: ignore[reportArgumentType]
     types.FunctionType: _function_key,
     types.BuiltinFunctionType: lambda fn, x: x,
+    torch.fx.GraphModule: _graph_module_key,
     ConstExpr: lambda fn, x: x.value,  # pyright: ignore[reportAttributeAccessIssue]
 }

diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py
index 8fea25e1..b8bc238a 100644
--- a/helion/runtime/settings.py
+++ b/helion/runtime/settings.py
@@ -95,7 +95,6 @@ class _Settings:
         RefMode.EAGER if os.environ.get("HELION_INTERPRET", "") == "1" else RefMode.OFF
     )
     autotuner_fn: AutotunerFunction = default_autotuner_fn
-    set_triton_allocator: bool = True


 class Settings(_Settings):
@@ -118,7 +117,6 @@ class Settings(_Settings):
         "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
         "ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.",
         "autotuner_fn": "Function to create an autotuner",
-        "set_triton_allocator": "If True, insert helion.runtime.set_triton_allocator() call in generated code. Default is True.",
     }
     assert __slots__.keys() == {field.name for field in dataclasses.fields(_Settings)}

diff --git a/test/test_graph_module.expected b/test/test_graph_module.expected
new file mode 100644
index 00000000..ae50ad98
--- /dev/null
+++ b/test/test_graph_module.expected
@@ -0,0 +1,63 @@
+This file is automatically generated by assertExpectedJournal calls in test_graph_module.py.
+Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
+
+--- assertExpectedJournal(TestGraphModule.test_graph_module_arg)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_apply_graph_module(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_0 = 1.0
+    v_1 = load + v_0
+    v_2 = tl_math.sin(v_1)
+    tl.store(out + indices_0 * out_stride_0, v_2, mask_0)
+
+def apply_graph_module(func_m, x, *, _launcher=_default_launcher):
+    """Kernel that applies a GraphModule function to tensor elements."""
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 1024
+    _launcher(_helion_apply_graph_module, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestGraphModule.test_graph_module_with_multiple_ops)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_apply_graph_module(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_0 = 2.0
+    v_1 = load * v_0
+    v_2 = tl.full([], 0, tl.int32)
+    v_3 = triton_helpers.maximum(v_2, v_1)
+    v_4 = 1.0
+    v_5 = v_3 + v_4
+    v_6 = tl_math.cos(v_5)
+    tl.store(out + indices_0 * out_stride_0, v_6, mask_0)
+
+def apply_graph_module(func_m, x, *, _launcher=_default_launcher):
+    """Kernel that applies a GraphModule function to tensor elements."""
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 512
+    _launcher(_helion_apply_graph_module, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
diff --git a/test/test_graph_module.py b/test/test_graph_module.py
new file mode 100644
index 00000000..1bc43dcd
--- /dev/null
+++ b/test/test_graph_module.py
@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+import unittest
+
+import torch
+
+import helion
+from helion._testing import DEVICE
+from helion._testing import RefEagerTestBase
+from helion._testing import TestCase
+from helion._testing import code_and_output
+from helion._testing import skipIfRefEager
+import helion.language as hl
+
+
+@helion.kernel(use_default_config=True)
+def apply_graph_module(func_m, x):
+    """Kernel that applies a GraphModule function to tensor elements."""
+    out = torch.empty_like(x)
+    for tile in hl.tile(out.size()):
+        out[tile] = func_m(x[tile])
+    return out
+
+
+class TestGraphModule(RefEagerTestBase, TestCase):
+    def test_graph_module_arg(self):
+        """Test that GraphModule arguments work in kernels."""
+        x = torch.randn(1000, device=DEVICE)
+
+        # Create a GraphModule with a simple computation
+        gm = torch.fx.symbolic_trace(lambda x: torch.sin(x + 1))
+
+        # This should work - GraphModule is treated like a function call
+        code, result = code_and_output(apply_graph_module, (gm, x))
+        expected = torch.sin(x + 1)
+
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedJournal(code)
+
+    def test_graph_module_with_multiple_ops(self):
+        """Test GraphModule with multiple operations."""
+        x = torch.randn(512, device=DEVICE)
+
+        # Create a more complex GraphModule
+        def complex_func(x):
+            return torch.cos(torch.relu(x * 2) + 1)
+
+        gm = torch.fx.symbolic_trace(complex_func)
+
+        code, result = code_and_output(apply_graph_module, (gm, x))
+        expected = complex_func(x)
+
+        torch.testing.assert_close(result, expected)
+        self.assertExpectedJournal(code)
+
+    def test_graph_module_specialization(self):
+        """Test that different GraphModules get specialized separately."""
+        x = torch.randn(256, device=DEVICE)
+
+        # Create two different GraphModules
+        gm1 = torch.fx.symbolic_trace(lambda x: torch.sin(x))
+        gm2 = torch.fx.symbolic_trace(lambda x: torch.cos(x))
+
+        # Each should get its own specialization
+        code1, result1 = code_and_output(apply_graph_module, (gm1, x))
+        code2, result2 = code_and_output(apply_graph_module, (gm2, x))
+
+        torch.testing.assert_close(result1, torch.sin(x))
+        torch.testing.assert_close(result2, torch.cos(x))
+
+    @skipIfRefEager("doesn't make required call")
+    def test_graph_module_with_unsupported_ops(self):
+        """Test that GraphModules with unsupported ops raise an error."""
+        x = torch.randn(128, device=DEVICE)
+
+        # Create a module with call_module (unsupported)
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(128, 128)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        module = MyModule()
+        gm = torch.fx.symbolic_trace(module)
+
+        # This should raise an error due to call_module op
+        with self.assertRaises(helion.exc.GraphModuleUnsupportedOps) as cm:
+            apply_graph_module(gm, x)
+
+        self.assertIn("call_module", str(cm.exception))
+
+    def test_graph_module_caching(self):
+        """Test that GraphModule hash caching works correctly."""
+        x = torch.randn(256, device=DEVICE)
+
+        # Create a GraphModule
+        gm = torch.fx.symbolic_trace(lambda x: torch.sin(x))
+
+        # Call the kernel twice with the same GraphModule
+        # Should use cached hash the second time
+        code1, result1 = code_and_output(apply_graph_module, (gm, x))
+        code2, result2 = code_and_output(apply_graph_module, (gm, x))
+
+        torch.testing.assert_close(result1, torch.sin(x))
+        torch.testing.assert_close(result2, torch.sin(x))
+
+        # Same GraphModule should produce same code
+        self.assertEqual(code1, code2)
+
+
+if __name__ == "__main__":
+    unittest.main()