Support GraphModule inputs #488

Merged · 1 commit · Aug 13, 2025

5 changes: 5 additions & 0 deletions helion/_compiler/compile_environment.py
@@ -231,6 +231,11 @@ def to_fake(self, obj: object, origin: Origin) -> object:

fn = extract_helper_function(obj)
return lift_closures(fn, origin)
# Handle GraphModule - treat it like a function
if isinstance(obj, torch.fx.GraphModule):
# GraphModule can be treated like a callable function
# We return it as-is since it will be called during execution
return obj
if isinstance(obj, ConstExpr):
return obj.value
if isinstance(obj, str):
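
As the comment above notes, a GraphModule is itself callable, which is why to_fake can hand it back unchanged and let it be invoked like any other function. A minimal sketch of that behavior, using only plain torch.fx (no Helion involved):

import torch

# A traced GraphModule behaves like the function it was traced from,
# so returning it as-is keeps it directly callable during execution.
gm = torch.fx.symbolic_trace(lambda x: torch.sin(x + 1))
x = torch.randn(8)
torch.testing.assert_close(gm(x), torch.sin(x + 1))
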
4 changes: 4 additions & 0 deletions helion/exc.py
@@ -353,3 +353,7 @@ class InvalidSequenceSubscription(BaseError):

class InvalidAPIUsage(BaseError):
message = "Invalid usage of Helion API: {0}"


class GraphModuleUnsupportedOps(BaseError):
message = "GraphModule contains unsupported operations: {0}. Only pure computation graphs are supported (no load_attr or call_module ops)."
33 changes: 32 additions & 1 deletion helion/runtime/kernel.py
@@ -4,6 +4,7 @@
import dataclasses
import functools
import inspect
import itertools
import logging
import operator
import re
@@ -22,7 +23,9 @@
from torch._dynamo.source import TensorProperty
from torch._dynamo.source import TensorPropertySource
from torch._inductor.codecache import PyCodeCache
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._subclasses import FakeTensor
from torch.utils.weak import WeakIdKeyDictionary

from .. import exc
from .._compiler.ast_extension import unparse
@@ -55,6 +58,9 @@
_R = TypeVar("_R")
CompiledConfig = Callable[..., _R]

# Cache for GraphModule hashes
_graph_module_hash_cache: WeakIdKeyDictionary = WeakIdKeyDictionary()


class Kernel(Generic[_R]):
def __init__(
@@ -203,7 +209,10 @@ def _specialization_key(self, obj: object) -> Hashable:
try:
extractor = _specialization_extractors[type(obj)]
except KeyError:
if isinstance(obj, tuple) and hasattr(obj, "_fields"):
if isinstance(obj, torch.fx.GraphModule):
# GraphModule subclasses need special handling
extractor = _specialization_extractors[torch.fx.GraphModule]
elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
# this is a namedtuple
extractor = _specialization_extractors["namedtuple"]
elif dataclasses.is_dataclass(obj):
@@ -696,6 +705,27 @@ def _function_key(fn: Kernel, obj: types.FunctionType) -> object:
return obj.__code__


def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable:
"""Generate a specialization key for GraphModule arguments."""
# Check if already cached
if obj in _graph_module_hash_cache:
return _graph_module_hash_cache[obj]

# Check for unsupported operations
unsupported_ops = {
node.op
for node in itertools.chain(
obj.graph.find_nodes(op="call_module"),
obj.graph.find_nodes(op="get_attr"),
)
}
if unsupported_ops:
raise exc.GraphModuleUnsupportedOps(", ".join(sorted(unsupported_ops)))

_graph_module_hash_cache[obj] = rv = str(compiled_fx_graph_hash(obj, [], {}, []))
return rv


_specialization_extractors: dict[
type[object] | str, Callable[[Kernel, object], Hashable]
] = { # pyright: ignore[reportAssignmentType]
@@ -715,6 +745,7 @@ def _function_key(fn: Kernel, obj: types.FunctionType) -> object:
"dataclass": lambda fn, x: _mapping_key(fn, dataclasses.asdict(x), type(x)), # pyright: ignore[reportArgumentType]
types.FunctionType: _function_key,
types.BuiltinFunctionType: lambda fn, x: x,
torch.fx.GraphModule: _graph_module_key,
ConstExpr: lambda fn, x: x.value, # pyright: ignore[reportAttributeAccessIssue]
}

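
The hash cache above is a WeakIdKeyDictionary, so entries are keyed by object identity and dropped automatically once a GraphModule is garbage-collected. A small sketch of that behavior (the "cached-hash" value is just a placeholder standing in for the real graph hash):

import gc
import torch
from torch.utils.weak import WeakIdKeyDictionary

cache = WeakIdKeyDictionary()
gm = torch.fx.symbolic_trace(lambda x: torch.sin(x))
cache[gm] = "cached-hash"   # placeholder value for illustration
assert gm in cache

del gm
gc.collect()                # collect any reference cycles inside the GraphModule
assert len(cache) == 0      # the cache entry disappears with the module
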
2 changes: 0 additions & 2 deletions helion/runtime/settings.py
@@ -95,7 +95,6 @@ class _Settings:
RefMode.EAGER if os.environ.get("HELION_INTERPRET", "") == "1" else RefMode.OFF
)
autotuner_fn: AutotunerFunction = default_autotuner_fn
set_triton_allocator: bool = True


class Settings(_Settings):
@@ -118,7 +117,6 @@ class Settings(_Settings):
"allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
"ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.",
"autotuner_fn": "Function to create an autotuner",
"set_triton_allocator": "If True, insert helion.runtime.set_triton_allocator() call in generated code. Default is True.",
}
assert __slots__.keys() == {field.name for field in dataclasses.fields(_Settings)}

63 changes: 63 additions & 0 deletions test/test_graph_module.expected
@@ -0,0 +1,63 @@
This file is automatically generated by assertExpectedJournal calls in test_graph_module.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestGraphModule.test_graph_module_arg)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_apply_graph_module(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
v_0 = 1.0
v_1 = load + v_0
v_2 = tl_math.sin(v_1)
tl.store(out + indices_0 * out_stride_0, v_2, mask_0)

def apply_graph_module(func_m, x, *, _launcher=_default_launcher):
"""Kernel that applies a GraphModule function to tensor elements."""
out = torch.empty_like(x)
_BLOCK_SIZE_0 = 1024
_launcher(_helion_apply_graph_module, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestGraphModule.test_graph_module_with_multiple_ops)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_apply_graph_module(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
v_0 = 2.0
v_1 = load * v_0
v_2 = tl.full([], 0, tl.int32)
v_3 = triton_helpers.maximum(v_2, v_1)
v_4 = 1.0
v_5 = v_3 + v_4
v_6 = tl_math.cos(v_5)
tl.store(out + indices_0 * out_stride_0, v_6, mask_0)

def apply_graph_module(func_m, x, *, _launcher=_default_launcher):
"""Kernel that applies a GraphModule function to tensor elements."""
out = torch.empty_like(x)
_BLOCK_SIZE_0 = 512
_launcher(_helion_apply_graph_module, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out
114 changes: 114 additions & 0 deletions test/test_graph_module.py
@@ -0,0 +1,114 @@
from __future__ import annotations

import unittest

import torch

import helion
from helion._testing import DEVICE
from helion._testing import RefEagerTestBase
from helion._testing import TestCase
from helion._testing import code_and_output
from helion._testing import skipIfRefEager
import helion.language as hl


@helion.kernel(use_default_config=True)
def apply_graph_module(func_m, x):
"""Kernel that applies a GraphModule function to tensor elements."""
out = torch.empty_like(x)
for tile in hl.tile(out.size()):
out[tile] = func_m(x[tile])
return out


class TestGraphModule(RefEagerTestBase, TestCase):
def test_graph_module_arg(self):
"""Test that GraphModule arguments work in kernels."""
x = torch.randn(1000, device=DEVICE)

# Create a GraphModule with a simple computation
gm = torch.fx.symbolic_trace(lambda x: torch.sin(x + 1))

# This should work - GraphModule is treated like a function call
code, result = code_and_output(apply_graph_module, (gm, x))
expected = torch.sin(x + 1)

torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

def test_graph_module_with_multiple_ops(self):
"""Test GraphModule with multiple operations."""
x = torch.randn(512, device=DEVICE)

# Create a more complex GraphModule
def complex_func(x):
return torch.cos(torch.relu(x * 2) + 1)

gm = torch.fx.symbolic_trace(complex_func)

code, result = code_and_output(apply_graph_module, (gm, x))
expected = complex_func(x)

torch.testing.assert_close(result, expected)
self.assertExpectedJournal(code)

def test_graph_module_specialization(self):
"""Test that different GraphModules get specialized separately."""
x = torch.randn(256, device=DEVICE)

# Create two different GraphModules
gm1 = torch.fx.symbolic_trace(lambda x: torch.sin(x))
gm2 = torch.fx.symbolic_trace(lambda x: torch.cos(x))

# Each should get its own specialization
code1, result1 = code_and_output(apply_graph_module, (gm1, x))
code2, result2 = code_and_output(apply_graph_module, (gm2, x))

torch.testing.assert_close(result1, torch.sin(x))
torch.testing.assert_close(result2, torch.cos(x))

@skipIfRefEager("doesn't make required call")
def test_graph_module_with_unsupported_ops(self):
"""Test that GraphModules with unsupported ops raise an error."""
x = torch.randn(128, device=DEVICE)

# Create a module with call_module (unsupported)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(128, 128)

def forward(self, x):
return self.linear(x)

module = MyModule()
gm = torch.fx.symbolic_trace(module)

# This should raise an error due to call_module op
with self.assertRaises(helion.exc.GraphModuleUnsupportedOps) as cm:
apply_graph_module(gm, x)

self.assertIn("call_module", str(cm.exception))

def test_graph_module_caching(self):
"""Test that GraphModule hash caching works correctly."""
x = torch.randn(256, device=DEVICE)

# Create a GraphModule
gm = torch.fx.symbolic_trace(lambda x: torch.sin(x))

# Call the kernel twice with the same GraphModule
# Should use cached hash the second time
code1, result1 = code_and_output(apply_graph_module, (gm, x))
code2, result2 = code_and_output(apply_graph_module, (gm, x))

torch.testing.assert_close(result1, torch.sin(x))
torch.testing.assert_close(result2, torch.sin(x))

# Same GraphModule should produce same code
self.assertEqual(code1, code2)


if __name__ == "__main__":
unittest.main()