diff --git a/helion/_testing.py b/helion/_testing.py
index c323152d..76e29dde 100644
--- a/helion/_testing.py
+++ b/helion/_testing.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import collections
+import contextlib
 import importlib
 import inspect
 import operator
@@ -15,6 +16,7 @@
 import torch
 from triton.testing import do_bench
 
+from ._utils import counters
 from .runtime.config import Config
 from helion._compat import get_tensor_descriptor_fn_name
 
@@ -291,6 +293,24 @@ def tearDownClass(cls) -> None:
         super().tearDownClass()
         del cls._expected_journal
 
+    def setUp(self) -> None:
+        """Enter inductor's fresh_cache() context and reset the helion counters."""
+        super().setUp()
+        self._test_stack = contextlib.ExitStack()
+
+        from torch._inductor.utils import fresh_cache
+
+        self._test_stack.enter_context(fresh_cache())
+
+        counters.clear()
+
+    def tearDown(self) -> None:
+        """Close the per-test context stack even if the parent tearDown raises."""
+        try:
+            super().tearDown()
+        finally:
+            self._test_stack.close()
+
     def assertExpectedJournal(self, value: str) -> None:
         """
         Assert that the given value matches the expected output stored in .expected.
diff --git a/helion/_utils.py b/helion/_utils.py
new file mode 100644
index 00000000..73824ff3
--- /dev/null
+++ b/helion/_utils.py
@@ -0,0 +1,7 @@
+from __future__ import annotations
+
+import collections
+
+counters: collections.defaultdict[str, collections.Counter[str]] = (
+    collections.defaultdict(collections.Counter)
+)
diff --git a/helion/autotuner/__init__.py b/helion/autotuner/__init__.py
index bb38ca46..de09d5ca 100644
--- a/helion/autotuner/__init__.py
+++ b/helion/autotuner/__init__.py
@@ -9,4 +9,6 @@
     DifferentialEvolutionSearch as DifferentialEvolutionSearch,
 )
 from .finite_search import FiniteSearch as FiniteSearch
+from .local_cache import LocalAutotuneCache as LocalAutotuneCache
+from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache
 from .random_search import RandomSearch as RandomSearch
diff --git a/helion/autotuner/base_cache.py b/helion/autotuner/base_cache.py
new file mode 100644
index 00000000..64e4718c
--- /dev/null
+++ b/helion/autotuner/base_cache.py
@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+import abc
+import dataclasses
+import functools
+import hashlib
+import logging
+import os
+from typing import TYPE_CHECKING
+from typing import Hashable
+from typing import Sequence
+
+from torch._inductor.codecache import build_code_hash
+from torch._inductor.codecache import torch_key
+from torch._inductor.runtime.triton_compat import triton_key
+
+from .._utils import counters
+
+if TYPE_CHECKING:
+    from ..runtime.config import Config
+    from ..runtime.kernel import BoundKernel
+    from .base_search import BaseSearch
+
+log: logging.Logger = logging.getLogger(__name__)
+
+
+@functools.cache
+def helion_key() -> str:
+    """Return a SHA-256 hex digest of the helion package tree via build_code_hash."""
+    here = os.path.abspath(__file__)
+    helion_path = os.path.dirname(os.path.dirname(here))
+
+    combined_hash = hashlib.sha256()
+    build_code_hash([helion_path], "", combined_hash)
+    return combined_hash.hexdigest()
+
+
+@functools.cache
+def torch_key_wrapper() -> str:
+    """Return the bytes from ``torch_key()`` as a hex string (cached)."""
+    return torch_key().hex()
+
+
+@functools.cache
+def triton_key_wrapper() -> str:
+    """Return the value of ``triton_key()`` (cached)."""
+    return triton_key()
+
+
+class CacheKeyBase:
+    """
+    Base class to provide utility functions to all cache key dataclasses
+    """
+
+    def stable_hash(self) -> str:
+        """Return the SHA-256 hex digest of ``repr(self)``."""
+        return hashlib.sha256(repr(self).encode("utf-8")).hexdigest()
+
+
+@dataclasses.dataclass(frozen=True)
+class BoundKernelInMemoryCacheKey(CacheKeyBase):
+    """
+    Default in memory cache key.
+
+    This key includes:
+
+    specialization_key: Information about all kernel inputs.
+                        For tensors this means their device, shape, size etc.
+    extra_results: Information regarding `hl.specialize` decisions
+    """
+
+    specialization_key: tuple[Hashable, ...]
+    extra_results: tuple[Hashable, ...]
+
+
+@dataclasses.dataclass(frozen=True)
+class LooseAutotuneCacheKey(BoundKernelInMemoryCacheKey):
+    """
+    Autotune Cache key to use for most use cases.
+
+    This key includes (in addition to BoundKernelInMemoryCacheKey):
+
+    kernel_source_hash: Hash of source code of input Helion kernel
+    hardware: Hardware of the input device
+    runtime_name: Version of the cuda/rocm arch
+    """
+
+    kernel_source_hash: str
+    hardware: str
+    runtime_name: str
+
+
+@dataclasses.dataclass(frozen=True)
+class StrictAutotuneCacheKey(LooseAutotuneCacheKey):
+    """
+    Autotune Cache key to use for utmost strictness in terms of re-autotuning
+    when library source code changes.
+
+    This key includes (in addition to LooseAutotuneCacheKey):
+
+    helion_key: Hash of source code of Helion
+    torch_key: Hash of source code of PyTorch
+    triton_key: Hash of source code of Triton
+    """
+
+    helion_key: str = dataclasses.field(default_factory=helion_key)
+    torch_key: str = dataclasses.field(default_factory=torch_key_wrapper)
+    triton_key: str = dataclasses.field(default_factory=triton_key_wrapper)
+
+
+class AutotuneCacheBase(abc.ABC):
+    """
+    Abstract base class that all autotune caches need to implement.
+    Any user defined cache will need to extend this class, and
+    provide implementations for get and put methods.
+    """
+
+    def __init__(
+        self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
+    ) -> None:
+        self.autotuner = autotuner
+        self.kernel = kernel
+        self.args = args
+
+    @abc.abstractmethod
+    def get(self) -> Config | None:
+        """Return the cached config for this kernel/args, or None on a miss."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def put(self, config: Config) -> None:
+        """Store ``config`` as the cached result for this kernel/args."""
+        raise NotImplementedError
+
+    def autotune(self) -> Config:
+        """Return the cached config, or run the autotuner and cache its result."""
+        if (config := self.get()) is not None:
+            counters["autotune"]["cache_hit"] += 1
+            log.debug("cache hit: %s", config)
+            return config
+
+        counters["autotune"]["cache_miss"] += 1
+        log.debug("cache miss")
+
+        config = self.autotuner.autotune()
+
+        self.put(config)
+        counters["autotune"]["cache_put"] += 1
+        log.debug("cache put: %s", config)
+
+        return config
diff --git a/helion/autotuner/local_cache.py b/helion/autotuner/local_cache.py
new file mode 100644
index 00000000..2277aaa7
--- /dev/null
+++ b/helion/autotuner/local_cache.py
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+import hashlib
+import inspect
+import logging
+import os
+from pathlib import Path
+import textwrap
+from typing import TYPE_CHECKING
+from typing import Sequence
+
+import torch
+from torch._inductor.runtime.cache_dir_utils import (
+    cache_dir,  # pyright: ignore[reportPrivateImportUsage]
+)
+
+from ..runtime.config import Config
+from .base_cache import AutotuneCacheBase
+from .base_cache import LooseAutotuneCacheKey
+from .base_cache import StrictAutotuneCacheKey
+
+if TYPE_CHECKING:
+    from ..runtime.kernel import BoundKernel
+    from .base_search import BaseSearch
+
+log: logging.Logger = logging.getLogger(__name__)
+
+
+class LocalAutotuneCache(AutotuneCacheBase):
+    """
+    This class implements the local autotune cache, storing the
+    best config artifact on the local file system either by default
+    on torch's cache directory, or at a user specified HELION_CACHE_DIR
+    directory.
+    It uses the LooseAutotuneCacheKey implementation for the cache key
+    which takes into account device and source code properties, but does
+    not account for library level code changes such as Triton, Helion or
+    PyTorch. Use StrictLocalAutotuneCache to consider these properties.
+    """
+
+    def __init__(
+        self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
+    ) -> None:
+        super().__init__(kernel, args, autotuner)
+        self.key = self._generate_key()
+
+    def _generate_key(self) -> LooseAutotuneCacheKey:
+        """
+        Build the cache key from the in-memory key, a hash of the kernel
+        source, and the hardware/runtime of the first CUDA or ROCm tensor
+        argument.
+
+        Raises:
+            RuntimeError: if no argument is a tensor on a CUDA/ROCm device.
+        """
+        in_memory_cache_key = self.kernel.kernel._create_bound_kernel_cache_key(
+            self.kernel,
+            tuple(self.args),
+            self.kernel.kernel.specialization_key(self.args),
+        )
+        kernel_source = textwrap.dedent(inspect.getsource(self.kernel.kernel.fn))
+        kernel_source_hash = hashlib.sha256(kernel_source.encode("utf-8")).hexdigest()
+
+        hardware = None
+        runtime_name = None
+
+        for arg in self.args:
+            # torch.cuda.get_device_properties() rejects non-CUDA devices, so
+            # skip tensors that live elsewhere (e.g. on the CPU).
+            if isinstance(arg, torch.Tensor) and arg.device.type == "cuda":
+                device_properties = torch.cuda.get_device_properties(arg.device)
+                if torch.version.cuda is not None:  # pyright: ignore[reportAttributeAccessIssue]
+                    hardware = device_properties.name
+                    runtime_name = torch.version.cuda  # pyright: ignore[reportAttributeAccessIssue]
+                else:
+                    hardware = device_properties.gcnArchName
+                    runtime_name = torch.version.hip  # pyright: ignore[reportAttributeAccessIssue]
+                break
+
+        # An explicit raise (unlike assert) still fires under ``python -O``.
+        if hardware is None or runtime_name is None:
+            raise RuntimeError(
+                "LocalAutotuneCache requires at least one CUDA/ROCm tensor argument"
+            )
+        return LooseAutotuneCacheKey(
+            specialization_key=in_memory_cache_key.specialization_key,
+            extra_results=in_memory_cache_key.extra_results,
+            kernel_source_hash=kernel_source_hash,
+            hardware=hardware,
+            runtime_name=runtime_name,
+        )
+
+    def _get_local_cache_path(self) -> Path:
+        """Return the file that holds the best config for ``self.key``."""
+        if (user_path := os.environ.get("HELION_CACHE_DIR", None)) is not None:
+            cache_path = Path(user_path)
+        else:
+            cache_path = Path(cache_dir()) / "helion"
+
+        return cache_path / f"{self.key.stable_hash()}.best_config"
+
+    def get(self) -> Config | None:
+        """Load the cached config; a missing or unreadable entry is a miss."""
+        path = self._get_local_cache_path()
+        try:
+            return Config.load(path)
+        except FileNotFoundError:
+            return None
+        except Exception:
+            # Best effort: a corrupt entry must not break autotuning, but
+            # leave a trace instead of swallowing the error silently.
+            log.debug("ignoring unreadable cache entry %s", path, exc_info=True)
+            return None
+
+    def put(self, config: Config) -> None:
+        """Persist ``config`` at this key's cache path."""
+        path = self._get_local_cache_path()
+        config.save(path)
+
+
+class StrictLocalAutotuneCache(LocalAutotuneCache):
+    """
+    Stricter implementation of the local autotune cache, which takes into
+    account library level code changes such as Triton, Helion or PyTorch.
+    """
+
+    def _generate_key(self) -> StrictAutotuneCacheKey:
+        """Extend the loose key with the helion/torch/triton source hashes."""
+        loose_key = super()._generate_key()
+        return StrictAutotuneCacheKey(**vars(loose_key))
diff --git a/helion/runtime/config.py b/helion/runtime/config.py
index f12b5c3f..ac6ab985 100644
--- a/helion/runtime/config.py
+++ b/helion/runtime/config.py
@@ -3,9 +3,11 @@
 from collections.abc import Iterator
 from collections.abc import Mapping
 import json
+import os
 from pathlib import Path
 from typing import Literal
 from typing import cast
+import uuid
 
 from ..autotuner.config_spec import DEFAULT_NUM_STAGES
 from ..autotuner.config_spec import DEFAULT_NUM_WARPS
@@ -118,7 +120,17 @@ def from_json(cls, json_str: str) -> Config:
 
     def save(self, path: str | Path) -> None:
         """Save the config to a JSON file."""
-        Path(path).write_text(self.to_json())
+        # Write to a temp file next to the target and move it into place so
+        # concurrent readers never see a partially written config.
+        Path(path).parent.mkdir(parents=True, exist_ok=True)
+
+        tmp = Path(path).parent / f"tmp.{uuid.uuid4()!s}"
+        try:
+            tmp.write_text(self.to_json())
+            # os.replace overwrites an existing target on every platform.
+            os.replace(tmp, path)
+        finally:
+            tmp.unlink(missing_ok=True)
 
     @classmethod
     def load(cls, path: str | Path) -> Config:
diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py
index 52247cd4..d87b45ba 100644
--- a/helion/runtime/kernel.py
+++ b/helion/runtime/kernel.py
@@ -45,6 +45,7 @@
     from torch._guards import Source
 
     from ..autotuner import ConfigSpec
+    from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
 
     ConfigLike = Config | dict[str, object]
 
@@ -53,12 +54,6 @@
 CompiledConfig = Callable[..., _R]
 
 
-@dataclasses.dataclass(frozen=True)
-class BoundKernelInMemoryCacheKey:
-    specialization_key: tuple[Hashable, ...]
-    extra_results: tuple[Hashable, ...]
-
-
 class Kernel(Generic[_R]):
     def __init__(
         self,
@@ -114,6 +109,8 @@ def __init__(
     def _get_bound_kernel_cache_key(
         self, args: tuple[object, ...], signature: tuple[Hashable, ...]
     ) -> BoundKernelInMemoryCacheKey | None:
+        from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
+
         extra_fns = self._specialize_extra.get(signature)
         if extra_fns is not None:
             extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
@@ -126,6 +123,8 @@ def _create_bound_kernel_cache_key(
         args: tuple[object, ...],
         signature: tuple[Hashable, ...],
     ) -> BoundKernelInMemoryCacheKey:
+        from ..autotuner.base_cache import BoundKernelInMemoryCacheKey
+
         self._specialize_extra[signature] = extra_fns = bound_kernel._specialize_extra()
         extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns])
         return BoundKernelInMemoryCacheKey(signature, extra_results)
@@ -458,12 +457,18 @@ def autotune(
         self.settings.check_autotuning_disabled()
 
         from ..autotuner import DifferentialEvolutionSearch
+        from ..autotuner import LocalAutotuneCache
 
-        config = DifferentialEvolutionSearch(
+        config = LocalAutotuneCache(
             self,
             args,
-            **kwargs,  # pyright: ignore[reportArgumentType]
+            DifferentialEvolutionSearch(
+                self,
+                args,
+                **kwargs,  # pyright: ignore[reportArgumentType]
+            ),
         ).autotune()
+
         self.set_config(config)
         return config
 
diff --git a/test/test_cache.py b/test/test_cache.py
new file mode 100644
index 00000000..616fc120
--- /dev/null
+++ b/test/test_cache.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+from pathlib import Path
+import unittest
+
+import torch
+
+from helion._testing import DEVICE
+from helion._testing import TestCase
+from helion._testing import import_path
+from helion._utils import counters
+from helion.autotuner import StrictLocalAutotuneCache
+from helion.autotuner.base_search import BaseSearch
+
+datadir = Path(__file__).parent / "data"
+basic_kernels = import_path(datadir / "basic_kernels.py")
+
+
+class BasicSearch(BaseSearch):
+    def autotune(self):
+        return self.config_spec.default_config()
+
+
+class TestCache(TestCase):
+    def test_basic(self):
+        a = torch.randn(16, device=DEVICE, dtype=torch.bfloat16)
+        args_a = (a, a)
+        b = torch.randn(16, device=DEVICE, dtype=torch.float16)
+        args_b = (b, b)
+
+        # TODO(oulgen): Using a custom autotuner is very verbose, requires passing args 3 times etc
+        bound_kernel = basic_kernels.add.bind(args_a)
+        config = StrictLocalAutotuneCache(
+            bound_kernel, args_a, BasicSearch(bound_kernel, args_a)
+        ).autotune()
+        bound_kernel.set_config(config)
+        result = bound_kernel(*args_a)
+        torch.testing.assert_close(result, a + a)
+
+        self.assertEqual(counters["autotune"]["cache_miss"], 1)
+        self.assertEqual(counters["autotune"]["cache_hit"], 0)
+        self.assertEqual(counters["autotune"]["cache_put"], 1)
+
+        basic_kernels.add.reset()
+
+        bound_kernel = basic_kernels.add.bind(args_a)
+        config = StrictLocalAutotuneCache(
+            bound_kernel, args_a, BasicSearch(bound_kernel, args_a)
+        ).autotune()
+        bound_kernel.set_config(config)
+        result = bound_kernel(*args_a)
+        torch.testing.assert_close(result, a + a)
+
+        self.assertEqual(counters["autotune"]["cache_miss"], 1)
+        self.assertEqual(counters["autotune"]["cache_hit"], 1)
+        self.assertEqual(counters["autotune"]["cache_put"], 1)
+
+        basic_kernels.add.reset()
+
+        bound_kernel = basic_kernels.add.bind(args_b)
+        config = StrictLocalAutotuneCache(
+            bound_kernel, args_b, BasicSearch(bound_kernel, args_b)
+        ).autotune()
+        bound_kernel.set_config(config)
+        result = bound_kernel(*args_b)
+        torch.testing.assert_close(result, b + b)
+
+        self.assertEqual(counters["autotune"]["cache_miss"], 2)
+        self.assertEqual(counters["autotune"]["cache_hit"], 1)
+        self.assertEqual(counters["autotune"]["cache_put"], 2)
+
+
+if __name__ == "__main__":
+    unittest.main()