[RFC] Implement basic on disk caching #336

Open · wants to merge 1 commit into base: main
16 changes: 16 additions & 0 deletions helion/_testing.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import collections
import contextlib
import importlib
import inspect
import operator
@@ -15,6 +16,7 @@
import torch
from triton.testing import do_bench

from ._utils import counters
from .runtime.config import Config
from helion._compat import get_tensor_descriptor_fn_name

@@ -291,6 +293,20 @@ def tearDownClass(cls) -> None:
super().tearDownClass()
del cls._expected_journal

def setUp(self) -> None:
super().setUp()
self._test_stack = contextlib.ExitStack()

from torch._inductor.utils import fresh_cache

self._test_stack.enter_context(fresh_cache())

counters.clear()

def tearDown(self) -> None:
super().tearDown()
self._test_stack.close()

def assertExpectedJournal(self, value: str) -> None:
"""
Assert that the given value matches the expected output stored in <testfile>.expected.
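The `fresh_cache()` context plus the `counters.clear()` call gives each test an isolated cache directory and clean counters. A minimal sketch of a test built on this harness, assuming `my_kernel` stands in for a real Helion kernel and `TestCase` for the test base class above:

```python
# Hypothetical test; `my_kernel` and the exact autotune() call are
# stand-ins, not part of this PR.
import torch

from helion._utils import counters

class TestDiskCache(TestCase):
    def test_second_autotune_hits_cache(self) -> None:
        args = (torch.randn(1024, device="cuda"),)
        my_kernel.autotune(args)  # first run: miss, then put
        self.assertEqual(counters["autotune"]["cache_miss"], 1)
        my_kernel.autotune(args)  # second run: served from disk
        self.assertEqual(counters["autotune"]["cache_hit"], 1)
```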
7 changes: 7 additions & 0 deletions helion/_utils.py
@@ -0,0 +1,7 @@
from __future__ import annotations

import collections

counters: collections.defaultdict[str, collections.Counter[str]] = (
collections.defaultdict(collections.Counter)
)
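`counters` is a two-level mapping (category → event → count), similar to the counters utility PyTorch uses internally. Since both levels are default-constructed, call sites can increment without any setup, as this quick illustration shows:

```python
from helion._utils import counters

# Missing categories and events spring into existence on first access.
counters["autotune"]["cache_hit"] += 1
counters["autotune"]["cache_hit"] += 1
assert counters["autotune"]["cache_hit"] == 2
assert counters["autotune"]["cache_miss"] == 0  # never incremented

counters.clear()  # the test harness above does this in setUp()
```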
2 changes: 2 additions & 0 deletions helion/autotuner/__init__.py
@@ -9,4 +9,6 @@
DifferentialEvolutionSearch as DifferentialEvolutionSearch,
)
from .finite_search import FiniteSearch as FiniteSearch
from .local_cache import LocalAutotuneCache as LocalAutotuneCache
from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache
from .random_search import RandomSearch as RandomSearch
149 changes: 149 additions & 0 deletions helion/autotuner/base_cache.py
@@ -0,0 +1,149 @@
from __future__ import annotations

import abc
import dataclasses
import functools
import hashlib
import logging
import os
from typing import TYPE_CHECKING
from typing import Hashable
from typing import Sequence

from .._utils import counters

if TYPE_CHECKING:
from ..runtime.config import Config
from ..runtime.kernel import BoundKernel
from .base_search import BaseSearch

log: logging.Logger = logging.getLogger(__name__)


@functools.cache
def helion_key() -> str:
from torch._inductor.codecache import build_code_hash

here = os.path.abspath(__file__)
helion_path = os.path.dirname(os.path.dirname(here))

combined_hash = hashlib.sha256()
build_code_hash([helion_path], "", combined_hash)
return combined_hash.hexdigest()


@functools.cache
def torch_key_wrapper() -> str:
from torch._inductor.codecache import torch_key

return torch_key().hex()


@functools.cache
def triton_key_wrapper() -> str:
from torch._inductor.runtime.triton_compat import triton_key

return triton_key()


class CacheKeyBase:
"""
Base class providing utility functions for all cache key dataclasses.
"""

def stable_hash(self) -> str:
return hashlib.sha256(repr(self).encode("utf-8")).hexdigest()


@dataclasses.dataclass(frozen=True)
class BoundKernelInMemoryCacheKey(CacheKeyBase):
"""
Default in-memory cache key.

This key includes:

specialization_key: Information about all kernel inputs.
For tensors this means their device, shape, size, etc.
extra_results: Information regarding `hl.specialize` decisions
"""

specialization_key: tuple[Hashable, ...]
extra_results: tuple[Hashable, ...]


@dataclasses.dataclass(frozen=True)
class LooseAutotuneCacheKey(BoundKernelInMemoryCacheKey):
"""
Autotune cache key to use for most use cases.

This key includes (in addition to BoundKernelInMemoryCacheKey):

kernel_source_hash: Hash of the source code of the input Helion kernel
hardware: Hardware name of the input device
runtime_name: Version of the CUDA/ROCm runtime
"""

kernel_source_hash: str
hardware: str
runtime_name: str

def stable_hash(self) -> str:
return hashlib.sha256(repr(self).encode("utf-8")).hexdigest()


@dataclasses.dataclass(frozen=True)
class StrictAutotuneCacheKey(LooseAutotuneCacheKey):
"""
Autotune cache key to use for maximum strictness, forcing re-autotuning
whenever library source code changes.

This key includes (in addition to LooseAutotuneCacheKey):

helion_key: Hash of source code of Helion
torch_key: Hash of source code of PyTorch
triton_key: Hash of source code of Triton
"""

helion_key: str = dataclasses.field(default_factory=helion_key)
torch_key: str = dataclasses.field(default_factory=torch_key_wrapper)
triton_key: str = dataclasses.field(default_factory=triton_key_wrapper)


class AutotuneCacheBase(abc.ABC):
"""
Abstract base class that all autotune caches need to implement.
Any user-defined cache must extend this class and provide
implementations for the get and put methods.
"""

def __init__(
self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
) -> None:
self.autotuner = autotuner
self.kernel = kernel
self.args = args

@abc.abstractmethod
def get(self) -> Config | None:
raise NotImplementedError

@abc.abstractmethod
def put(self, config: Config) -> None:
raise NotImplementedError

def autotune(self) -> Config:
if (config := self.get()) is not None:
counters["autotune"]["cache_hit"] += 1
log.debug("cache hit: %s", str(config))
return config

counters["autotune"]["cache_miss"] += 1
log.debug("cache miss")

config = self.autotuner.autotune()

self.put(config)
counters["autotune"]["cache_put"] += 1
log.debug("cache put: %s", str(config))

return config
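`AutotuneCacheBase` leaves the storage backend entirely to the subclass: `get` and `put` only deal in `Config` objects, while `autotune()` handles the hit/miss/put bookkeeping and counters. A minimal sketch of a user-defined cache, assuming an in-process dict as the backend and a hypothetical key choice:

```python
from __future__ import annotations

from helion.autotuner.base_cache import AutotuneCacheBase
from helion.runtime.config import Config

class InMemoryAutotuneCache(AutotuneCacheBase):
    """Illustrative only: configs live in a process-wide dict, so
    repeated autotune() calls in one process reuse the first result."""

    _store: dict[str, Config] = {}

    def _key(self) -> str:
        # Hypothetical key choice: reuse the kernel's input specialization.
        return repr(self.kernel.kernel.specialization_key(self.args))

    def get(self) -> Config | None:
        return self._store.get(self._key())

    def put(self, config: Config) -> None:
        self._store[self._key()] = config
```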
107 changes: 107 additions & 0 deletions helion/autotuner/local_cache.py
@@ -0,0 +1,107 @@
from __future__ import annotations

import hashlib
import inspect
import logging
import os
from pathlib import Path
import textwrap
from typing import TYPE_CHECKING
from typing import Sequence

import torch

from ..runtime.config import Config
from .base_cache import AutotuneCacheBase
from .base_cache import LooseAutotuneCacheKey
from .base_cache import StrictAutotuneCacheKey

if TYPE_CHECKING:
from ..runtime.kernel import BoundKernel
from .base_search import BaseSearch

log: logging.Logger = logging.getLogger(__name__)


class LocalAutotuneCache(AutotuneCacheBase):
"""
This class implements the local autotune cache, storing the
best config artifact on the local file system, by default under
torch's cache directory or in a user-specified HELION_CACHE_DIR
directory.
It uses the LooseAutotuneCacheKey implementation for the cache key,
which takes into account device and source code properties but does
not account for library-level code changes in Triton, Helion, or
PyTorch. Use StrictLocalAutotuneCache to also account for those.
"""

def __init__(
self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
) -> None:
super().__init__(kernel, args, autotuner)
self.key = self._generate_key()

def _generate_key(self) -> LooseAutotuneCacheKey:
in_memory_cache_key = self.kernel.kernel._create_bound_kernel_cache_key(
self.kernel,
tuple(self.args),
self.kernel.kernel.specialization_key(self.args),
)
kernel_source = textwrap.dedent(inspect.getsource(self.kernel.kernel.fn))
kernel_source_hash = hashlib.sha256(kernel_source.encode("utf-8")).hexdigest()

hardware = None
runtime_name = None

for arg in self.args:
if isinstance(arg, torch.Tensor):
device_properties = torch.cuda.get_device_properties(arg.device)
if torch.version.cuda is not None: # pyright: ignore[reportAttributeAccessIssue]
hardware = device_properties.name
runtime_name = torch.version.cuda # pyright: ignore[reportAttributeAccessIssue]
else:
hardware = device_properties.gcnArchName
runtime_name = torch.version.hip # pyright: ignore[reportAttributeAccessIssue]

assert hardware is not None and runtime_name is not None
return LooseAutotuneCacheKey(
specialization_key=in_memory_cache_key.specialization_key,
extra_results=in_memory_cache_key.extra_results,
kernel_source_hash=kernel_source_hash,
hardware=hardware,
runtime_name=runtime_name,
)

def _get_local_cache_path(self) -> Path:
if (user_path := os.environ.get("HELION_CACHE_DIR", None)) is not None:
cache_path = Path(user_path)
else:
from torch._inductor.runtime.cache_dir_utils import (
cache_dir, # pyright: ignore[reportPrivateImportUsage]
)

cache_path = Path(cache_dir()) / "helion"

return cache_path / f"{self.key.stable_hash()}.best_config"

def get(self) -> Config | None:
path = self._get_local_cache_path()
try:
return Config.load(path)
except Exception:
return None

def put(self, config: Config) -> None:
path = self._get_local_cache_path()
config.save(path)


class StrictLocalAutotuneCache(LocalAutotuneCache):
"""
Stricter implementation of the local autotune cache, which also takes
library-level code changes in Triton, Helion, and PyTorch into account.
"""

def _generate_key(self) -> StrictAutotuneCacheKey:
loose_key = super()._generate_key()
return StrictAutotuneCacheKey(**vars(loose_key))
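A sketch of how these classes are used directly, mirroring the change to `Kernel.autotune()` below; `bound_kernel` and `args` are assumed stand-ins for a bound Helion kernel and its example inputs, and setting `HELION_CACHE_DIR` is optional:

```python
import os

from helion.autotuner import DifferentialEvolutionSearch
from helion.autotuner import StrictLocalAutotuneCache

# Optional: redirect the cache away from torch's default cache directory.
os.environ["HELION_CACHE_DIR"] = "/tmp/helion-cache"

# The cache wraps a search: get() short-circuits autotuning on a hit;
# otherwise the wrapped search runs and its best config is put().
config = StrictLocalAutotuneCache(
    bound_kernel,
    args,
    DifferentialEvolutionSearch(bound_kernel, args),
).autotune()
```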
10 changes: 9 additions & 1 deletion helion/runtime/config.py
@@ -3,9 +3,11 @@
from collections.abc import Iterator
from collections.abc import Mapping
import json
import os
from pathlib import Path
from typing import Literal
from typing import cast
import uuid

from ..autotuner.config_spec import DEFAULT_NUM_STAGES
from ..autotuner.config_spec import DEFAULT_NUM_WARPS
@@ -118,7 +120,13 @@ def from_json(cls, json_str: str) -> Config:

def save(self, path: str | Path) -> None:
"""Save the config to a JSON file."""
Path(path).write_text(self.to_json())
# Write to a temp file in the same directory, then rename it over the
# destination so concurrent writers never leave a partially written file
Path(path).parent.mkdir(parents=True, exist_ok=True)

tmp = Path(path).parent / f"tmp.{uuid.uuid4()!s}"
tmp.write_text(self.to_json())
os.rename(str(tmp), str(path))

@classmethod
def load(cls, path: str | Path) -> Config:
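A small sketch of the resulting save/load round trip; the `Config` fields shown are hypothetical:

```python
from pathlib import Path

from helion.runtime.config import Config

cfg = Config(block_sizes=[64, 32], num_warps=4)  # hypothetical fields
path = Path("/tmp/helion-cache/example.best_config")

# save() creates parent directories, writes tmp.<uuid> alongside the
# target, then renames it into place, so readers never see a partial file.
cfg.save(path)
restored = Config.load(path)
print(restored.to_json())
```

The rename is atomic only when the temp file and the destination share a filesystem, which holds here because the temp file is created in the destination's parent directory.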
21 changes: 13 additions & 8 deletions helion/runtime/kernel.py
@@ -45,6 +45,7 @@
from torch._guards import Source

from ..autotuner import ConfigSpec
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey

ConfigLike = Config | dict[str, object]

@@ -53,12 +54,6 @@
CompiledConfig = Callable[..., _R]


@dataclasses.dataclass(frozen=True)
class BoundKernelInMemoryCacheKey:
specialization_key: tuple[Hashable, ...]
extra_results: tuple[Hashable, ...]


class Kernel(Generic[_R]):
def __init__(
self,
@@ -114,6 +109,8 @@ def __init__(
def _get_bound_kernel_cache_key(
self, args: tuple[object, ...], signature: tuple[Hashable, ...]
) -> BoundKernelInMemoryCacheKey | None:
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey

extra_fns = self._specialize_extra.get(signature)
if extra_fns is not None:
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
@@ -126,6 +123,8 @@ def _create_bound_kernel_cache_key(
args: tuple[object, ...],
signature: tuple[Hashable, ...],
) -> BoundKernelInMemoryCacheKey:
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey

self._specialize_extra[signature] = extra_fns = bound_kernel._specialize_extra()
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
return BoundKernelInMemoryCacheKey(signature, extra_results)
@@ -458,12 +457,18 @@ def autotune(
self.settings.check_autotuning_disabled()

from ..autotuner import DifferentialEvolutionSearch
from ..autotuner import LocalAutotuneCache

config = DifferentialEvolutionSearch(
config = LocalAutotuneCache(
self,
args,
**kwargs, # pyright: ignore[reportArgumentType]
DifferentialEvolutionSearch(
self,
args,
**kwargs, # pyright: ignore[reportArgumentType]
),
).autotune()

self.set_config(config)
return config
