
Passing in make_fx input into helion #434

@mitkotak

Description


Hey, I was wondering whether I can pass a make_fx()'d graph into Helion to leverage its autotuner and generate Triton code.

import torch

import helion
import helion.language as hl
from torch.fx.experimental.proxy_tensor import make_fx

class F(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        return x + 1


# Needs a device argument, otherwise it complains
@helion.kernel()
def F_kernel(func_m, device):
    return func_m

x = torch.randn(1000, device="cuda")
m = make_fx(F(), tracing_mode="symbolic", _allow_non_fake_inputs=True, _error_on_data_dependent_ops=True)(x)

# Sanity check: the traced graph matches eager execution
torch.testing.assert_close(m(x), F()(x))

F_kernel(lambda x: m(x), x)

Running it fails during autotuning:
Traceback (most recent call last):
  File "/home/mkotak/atomic_architects/projects/helion_playground/fx_helion.py", line 24, in <module>
    F_kernel(lambda x: F()(x), x)
  File "/home/mkotak/atomic_architects/projects/helion_playground/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 272, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkotak/atomic_architects/projects/helion_playground/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 581, in __call__
    self.autotune(args)
  File "/home/mkotak/atomic_architects/projects/helion_playground/.venv/lib/python3.11/site-packages/helion/runtime/kernel.py", line 473, in autotune
    config = self.settings.autotuner_fn(self, args, **kwargs).autotune()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkotak/atomic_architects/projects/helion_playground/.venv/lib/python3.11/site-packages/helion/runtime/settings.py", line 68, in default_autotuner_fn
    return LocalAutotuneCache(DifferentialEvolutionSearch(bound_kernel, args, **kwargs))  # pyright: ignore[reportArgumentType]
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkotak/atomic_architects/projects/helion_playground/.venv/lib/python3.11/site-packages/helion/autotuner/differential_evolution.py", line 34, in __init__
    super().__init__(kernel, args)
  File "/home/mkotak/atomic_architects/projects/helion_playground/.venv/lib/python3.11/site-packages/helion/autotuner/base_search.py", line 319, in __init__
    self.config_gen: ConfigGeneration = ConfigGeneration(self.config_spec)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mkotak/atomic_architects/projects/helion_playground/.venv/lib/python3.11/site-packages/helion/autotuner/config_generation.py", line 55, in __init__
    self.min_block_size: int = max(
                               ^^^^
ValueError: max() arg is an empty sequence
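The last frame fails inside `ConfigGeneration.__init__`, where `min_block_size` is computed with `max(...)`. One plausible reading is that the sequence passed to `max()` is built from the kernel's tunable block sizes, which come from `hl.tile` loops. This is an assumption; the traceback does not show the argument. `F_kernel` only returns its argument and has no such loop, so under that reading the sequence would be empty. The stdlib-only sketch below reproduces this failure mode. It uses a hypothetical list of per-loop minimum block sizes that stands in for Helion's real data structure, and it shows that `max(..., default=...)` does not raise on empty input.

```python
# Hypothetical stand-in for Helion's per-loop block-size data; NOT Helion's real API.
# Each entry represents the minimum block size that one hl.tile loop would allow.


def min_block_size(block_size_mins: list[int]) -> int:
    # Same shape as the failing call: max() over one value per tiled loop.
    # Raises ValueError when the list is empty (no tiled loops).
    return max(block_size_mins)


def min_block_size_or_default(block_size_mins: list[int], default: int = 1) -> int:
    # max(..., default=...) returns `default` instead of raising on empty input.
    return max(block_size_mins, default=default)


# A kernel with two tiled loops would contribute two entries.
print(min_block_size([16, 32]))  # 32

# A kernel with no hl.tile loops (like F_kernel) contributes none, so max() raises.
try:
    min_block_size([])
except ValueError as exc:
    print(f"ValueError: {exc}")

print(min_block_size_or_default([]))  # 1
```

The exact wording of the `ValueError` message depends on the Python version, so only the exception type should be relied on.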
