Skip to content

[Ref Mode] PyTorch reference mode (eager only) #339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions helion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
from .runtime import Kernel
from .runtime import kernel
from .runtime import kernel as jit # alias
from .runtime.settings import RefMode
from .runtime.settings import Settings
from .runtime.settings import set_default_settings

__all__ = [
"Config",
"Kernel",
"RefMode",
"Settings",
"cdiv",
"exc",
Expand Down
11 changes: 11 additions & 0 deletions helion/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ._utils import counters
from .runtime.config import Config
from helion._compat import get_tensor_descriptor_fn_name
from helion.runtime.ref_mode import is_ref_mode_enabled

if TYPE_CHECKING:
import types
Expand Down Expand Up @@ -47,6 +48,16 @@ def code_and_output(
args: tuple[object, ...],
**kwargs: object,
) -> tuple[str, object]:
bound = fn.bind(args)
if is_ref_mode_enabled(bound.kernel.settings):
if kwargs:
config = Config(**kwargs) # pyright: ignore[reportArgumentType]
bound._config = config
result = fn(*args)
# Return the original kernel source code
code = inspect.getsource(fn.fn)
return code, result

if kwargs:
config = Config(
**kwargs # pyright: ignore[reportArgumentType]
Expand Down
27 changes: 27 additions & 0 deletions helion/language/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class APIFunc(Protocol):
_to_device_ir: Callable[..., object] | None
_allow_host_tensor: bool
_signature: inspect.Signature
_ref_fn: Callable[..., object] | None

def __call__(self, *args: object, **kwargs: object) -> object: ...

Expand Down Expand Up @@ -133,8 +134,17 @@ def api(
def _impl(fn: _C) -> _C:
@functools.wraps(fn)
def wrapper(*args: object, **kwargs: object) -> object:
from ..runtime.ref_mode import is_in_ref_mode_context

bound = api._signature.bind(*args, **kwargs)
bound.apply_defaults()

if is_in_ref_mode_context():
assert api._ref_fn is not None, (
f"{fn.__qualname__} must be decorated with @helion.ref() to be used in ref mode"
)
return api._ref_fn(*bound.arguments.values())

flat_args = api._prepare_args(*bound.arguments.values())
del args, kwargs

Expand Down Expand Up @@ -187,6 +197,7 @@ def wrapper(*args: object, **kwargs: object) -> object:
api._signature = signature or inspect.signature(
cast("Callable[..., object]", fn)
)
api._ref_fn = None
return wrapper # pyright: ignore[reportReturnType]

return _impl
Expand Down Expand Up @@ -289,6 +300,22 @@ def _impl(to_device_ir_fn: Callable[..., object]) -> Callable[..., Never]:
return _impl # pyright: ignore[reportReturnType]


def ref(
original_fn: Callable[..., object],
) -> _NoReturnDecorator[object]:
def _impl(ref_fn: Callable[..., object]) -> Callable[..., Never]:
assert is_api_func(original_fn), (
f"{ref.__qualname__} can only be used on API functions"
)
assert original_fn._ref_fn is None, (
"ref mode implementation can only be registered once per function"
)
original_fn._ref_fn = ref_fn
return _no_call

return _impl # pyright: ignore[reportReturnType]


def _default_type_function(
fake_fn: Callable[..., object], tiles_as_sizes: bool
) -> Callable[..., TypeInfo]:
Expand Down
5 changes: 5 additions & 0 deletions helion/language/constexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,8 @@ def _(state: CodegenState) -> ast.AST:
value = value.__int__()
assert isinstance(value, int)
return expr_from_string(repr(value))


@_decorators.ref(specialize)
def _(value: int | torch.SymInt) -> int:
return int(value)
17 changes: 17 additions & 0 deletions helion/language/creation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .._compiler.compile_environment import CompileEnvironment
from ..exc import NotInsideKernel
from . import _decorators
from .ref_tile import RefTile

if TYPE_CHECKING:
import ast
Expand Down Expand Up @@ -144,6 +145,22 @@ def _(
return None


@_decorators.ref(full)
def _(
shape: list[int | RefTile],
value: float,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
processed_shape = []
for s in shape:
if isinstance(s, RefTile):
processed_shape.append(s.stop - s.start)
else:
processed_shape.append(s)
env = CompileEnvironment.current()
return torch.full(processed_shape, value, dtype=dtype, device=env.device)


def arange(
*args: int,
dtype: torch.dtype | None = None,
Expand Down
Loading
Loading