https://gist.github.com/Chillee/97e530fe23897d4f730dbb0f89a98d1e
The idea here is that after we've written a Helion kernel, actually using it in practice means we can't completely specialize it on one shape. And of course, Helion autotuning is prohibitively expensive to run on every shape.
You can write your own custom wrapper that chooses specific configs for specific settings, but this is kind of painful and loses a lot of the benefit of autotuning (what do you do for different hardware?).
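To make that pain concrete, a hand-rolled version of such a wrapper might look roughly like the sketch below. Everything in it is illustrative: the `helion.Config` values and the bucket boundary are made up, the "one config baked into each kernel variant" structure is just one way to do it, and the kernel body is elided.

```python
# Hypothetical hand-rolled dispatch (the "custom wrapper" approach).
# Config values and the shape bucket boundary are made up for illustration.
import torch
import helion


def _variant(cfg: helion.Config):
    @helion.kernel(config=cfg, static_shapes=False)
    def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5):
        ...  # elided: same body as the helion_rms_norm_fwd kernel shown below
    return rms_norm_fwd


_SMALL = _variant(helion.Config(block_sizes=[32], num_warps=4))   # made-up config
_LARGE = _variant(helion.Config(block_sizes=[128], num_warps=8))  # made-up config


def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5):
    # Hand-written shape dispatch that you now have to maintain yourself,
    # and which says nothing about other dtypes or other GPUs.
    kernel = _SMALL if x.shape[-1] <= 2048 else _LARGE
    return kernel(x, weight, eps)
```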
So, how do we get the benefits of leveraging autotuning while still being able to use the kernel?
The main goals of this wrapper are:
- I want to primarily only define the @helion.kernel and not need to write significant wrapper code.
- I want the configs to be easily shareable (and ideally understandable) across versions/users.
- I want to be able to easily retune the kernel for different hardware (one possible layout for a shared config store is sketched after this list).
- I want the kernel to "work" for most shapes/dtypes, while still being able to prioritize the shapes I care most about perf for.
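To make the shareable/retunable goals concrete, here is one plausible layout for the saved configs. This is only a guess at what such a store could look like, not necessarily what the linked gist does: configs keyed by kernel name, GPU, and a per-input kernel key, so a checked-in JSON file can be shared across users and regenerated per hardware.

```python
# Hypothetical on-disk config store (illustrative only; file name and layout are made up).
# One entry per (kernel_name, hardware, kernel_key); values are helion.Config kwargs.
import json

import torch
import helion

CONFIG_FILE = "rms_norm_fwd.autotune.json"  # hypothetical path, checked into the repo


def load_config(kernel_name: str, kernel_key: str) -> helion.Config | None:
    hardware = torch.cuda.get_device_name()  # e.g. "NVIDIA H100"
    with open(CONFIG_FILE) as f:
        store = json.load(f)
    entry = store.get(kernel_name, {}).get(hardware, {}).get(kernel_key)
    # Fall back to autotuning (or a default config) when a key is missing,
    # so the kernel still "works" for shapes that weren't explicitly tuned.
    return helion.Config(**entry) if entry is not None else None
```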
For example, an RMSNorm kernel might look like this:
```python
@helion_aot_autotune(
    "rms_norm_fwd",
    kernel_key=rms_norm_fwd_key,
    primary_inputs=partial(rms_norm_fwd_inputs, sizes=[512, 1024, 2048, 4096, 6144, 8192]),
    secondary_inputs=partial(
        rms_norm_fwd_inputs, sizes=list(range(128, 8192, 128))
    ),
)
@helion.kernel(static_shapes=False, ignore_warnings=[helion.exc.TensorOperationInWrapper])
def helion_rms_norm_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-5,
) -> tuple[torch.Tensor, torch.Tensor]:
    ...
```
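The example references `rms_norm_fwd_key` and `rms_norm_fwd_inputs` without defining them (they live in the linked gist). A rough guess at their shape is below, purely to show what the `kernel_key` / `primary_inputs` / `secondary_inputs` arguments are consuming; the actual signatures in the gist may differ.

```python
# Hypothetical helpers (a guess at their shape, not the gist's actual code).
import torch


def rms_norm_fwd_key(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> str:
    # Bucket the inputs into a cache key: here the config lookup specializes
    # on dtype and hidden size, but not on the batch dimension.
    return f"{x.dtype}-{x.shape[-1]}"


def rms_norm_fwd_inputs(sizes: list[int]):
    # Generate representative inputs for each hidden size. Primary sizes are
    # the ones prioritized for perf; secondary sizes mainly need to work.
    for n in sizes:
        x = torch.randn(4096, n, device="cuda", dtype=torch.bfloat16)
        weight = torch.randn(n, device="cuda", dtype=torch.bfloat16)
        yield (x, weight)
```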
And the resulting kernel gets perf like this:
[benchmark figure]