[RFC] Implement basic on disk caching #336
base: oulgen/stack/25
New file: `helion/_utils.py` (the counters utility; the path is inferred from the imports below). `@@ -0,0 +1,7 @@`

```python
from __future__ import annotations

import collections

counters: collections.defaultdict[str, collections.Counter[str]] = (
    collections.defaultdict(collections.Counter)
)
```
New file: the autotune cache module. `@@ -0,0 +1,121 @@`

```python
from __future__ import annotations

import dataclasses
import functools
import hashlib
import inspect
import logging
import os
from pathlib import Path
import textwrap
from typing import TYPE_CHECKING
from typing import Sequence

import torch

from .._utils import counters
from ..runtime.config import Config

if TYPE_CHECKING:
    from ..runtime.kernel import Kernel

log: logging.Logger = logging.getLogger(__name__)

"""
TODO(oulgen)
- Allow user defined cache keys that can be passed on @helion.kernel
- Add import/export for set of configs
"""


@dataclasses.dataclass(frozen=True)
class AutotuneCacheKey:
    """
    helion_key: Hash of source code of Helion
    torch_key: Hash of source code of PyTorch
    system_hash: Hash of system information,
        including Triton, current device, cuda/rocm arch version
    kernel_source_hash: Hash of source code of input Helion kernel
    input_dtypes: dtypes of input tensors
    input_shapes: shapes of input tensors
    """

    helion_key: str
    torch_key: str
    system_hash: str
    kernel_source_hash: str
    input_dtypes: list[tuple[int, torch.dtype]]
    input_shapes: list[tuple[int, torch.Size]]

    def stable_hash(self) -> str:
        return hashlib.sha256(repr(self).encode("utf-8")).hexdigest()


class AutotuneCache:
```

> **Review comment** (on `class AutotuneCache:`): docstring — this class needs one.
```python
    def __init__(self, kernel: Kernel, args: Sequence[object]) -> None:
        self.key: AutotuneCacheKey = AutotuneCache._generate_key(kernel, args)

    @staticmethod
    def _generate_key(kernel: Kernel, args: Sequence[object]) -> AutotuneCacheKey:
        from torch._inductor.codecache import CacheBase
        from torch._inductor.codecache import torch_key

        kernel_source = textwrap.dedent(inspect.getsource(kernel.fn))
        kernel_source_hash = hashlib.sha256(kernel_source.encode("utf-8")).hexdigest()

        input_dtypes = []
        input_shapes = []

        for idx, a in enumerate(args):
            if isinstance(a, torch.Tensor):
                input_dtypes.append((idx, a.dtype))
                input_shapes.append((idx, a.shape))
```
> **Review comment** (on lines +69 to +72): Maybe we should mimic/reuse some of the key generation infra in kernel.py (for the in-memory cache). IMO the memory key should be a subset of the disk key. I think right now the memory key includes some extra stuff.
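One way that suggestion could be realized is a shared helper; this is only a sketch, since the actual in-memory key code in kernel.py is not part of this diff and `extract_tensor_signature` is a hypothetical name:

```python
from __future__ import annotations

from typing import Sequence

import torch


def extract_tensor_signature(
    args: Sequence[object],
) -> tuple[list[tuple[int, torch.dtype]], list[tuple[int, torch.Size]]]:
    """Hypothetical shared helper: if both the in-memory key (kernel.py)
    and AutotuneCacheKey built their tensor signature here, the memory key
    would be a strict subset of the disk key by construction."""
    input_dtypes: list[tuple[int, torch.dtype]] = []
    input_shapes: list[tuple[int, torch.Size]] = []
    for idx, a in enumerate(args):
        if isinstance(a, torch.Tensor):
            input_dtypes.append((idx, a.dtype))
            input_shapes.append((idx, a.shape))
    return input_dtypes, input_shapes
```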
```python
        return AutotuneCacheKey(
            helion_key=helion_key(),
            torch_key=torch_key().hex(),
            system_hash=CacheBase.get_system()["hash"],
            kernel_source_hash=kernel_source_hash,
            input_dtypes=input_dtypes,
            input_shapes=input_shapes,
        )

    def _get_cache_key(self) -> str:
        return self.key.stable_hash()

    def _get_local_cache_path(self) -> Path:
        from torch._inductor.runtime.cache_dir_utils import (
            cache_dir,  # pyright: ignore[reportPrivateImportUsage]
        )

        return Path(cache_dir()) / "helion" / f"{self._get_cache_key()}.best_config"
```
> **Review comment** (on `_get_local_cache_path`): It might be useful to let users specify a filename for the cache so they can more easily distribute it to production environments.
>
> **Reply:** The TODO on L27 is about importing and exporting a file with all the autotune settings. Do you think that would be good enough, or should I add an env option to control the cache directory? FWIW we can just do both.
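If the env option lands, a minimal sketch of the override might look like this; `HELION_CACHE_DIR` is a hypothetical variable name, not something this PR defines:

```python
    def _get_local_cache_path(self) -> Path:
        # Hypothetical override: HELION_CACHE_DIR is not defined by this PR.
        if (override := os.environ.get("HELION_CACHE_DIR")) is not None:
            return Path(override) / f"{self._get_cache_key()}.best_config"

        from torch._inductor.runtime.cache_dir_utils import cache_dir

        return Path(cache_dir()) / "helion" / f"{self._get_cache_key()}.best_config"
```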
```python
    def get(self) -> Config | None:
        path = self._get_local_cache_path()
        try:
            config = Config.load(path)
            log.debug("Cache hit on config at %s", path)
            counters["autotune"]["cache_hit"] += 1
            return config
        except Exception:
            log.debug("Cache miss on config at %s", path)
            counters["autotune"]["cache_miss"] += 1
            return None
```
```python
    def put(self, config: Config) -> None:
        path = self._get_local_cache_path()
        config.save(path)

        log.debug("Cache write of config at %s", path)
        counters["autotune"]["cache_put"] += 1
```

> **Review comment** (on `config.save(path)`): Should we do the save-to-a-temp-file-and-rename trick here so the change is atomic? I'm thinking about multi-process races.
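A sketch of that trick, assuming `Config.save` simply writes the file at the path it is given (the `.tmp` naming scheme is illustrative):

```python
    def put(self, config: Config) -> None:
        path = self._get_local_cache_path()

        # Hypothetical atomic variant: write to a process-unique temp file
        # in the same directory, then rename over the final path.
        # os.replace is atomic on POSIX, so a concurrent reader sees either
        # the old config or the new one, never a partial write.
        tmp = path.with_name(f"{path.name}.{os.getpid()}.tmp")
        config.save(tmp)
        os.replace(tmp, path)

        log.debug("Cache write of config at %s", path)
        counters["autotune"]["cache_put"] += 1
```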
```python
@functools.cache
def helion_key() -> str:
    from torch._inductor.codecache import build_code_hash

    here = os.path.abspath(__file__)
    helion_path = os.path.dirname(os.path.dirname(here))

    combined_hash = hashlib.sha256()
    build_code_hash([helion_path], "", combined_hash)
    return combined_hash.hexdigest()
```
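Putting the pieces together, the intended call-site flow appears to be roughly the following; this is a sketch only, since the actual integration point in the kernel code is not part of this diff and `search` stands in for whatever autotuner is configured:

```python
# Hypothetical call site inside the autotuning entry point.
cache = AutotuneCache(kernel, args)
config = cache.get()  # None on a miss; bumps cache_hit/cache_miss counters
if config is None:
    config = search.autotune()  # run the expensive autotuning search
    cache.put(config)  # persist the winning config for future runs
```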
New file: the cache tests. `@@ -0,0 +1,59 @@`
```python
from __future__ import annotations

from pathlib import Path
import unittest
from unittest import mock

import torch

from helion._testing import DEVICE
from helion._testing import TestCase
from helion._testing import import_path
from helion._utils import counters

datadir = Path(__file__).parent / "data"
basic_kernels = import_path(datadir / "basic_kernels.py")


class BasicSearch:
    def __init__(self, bound_kernel, *args, **kwargs):
        self.bound_kernel = bound_kernel

    def autotune(self):
        return self.bound_kernel.config_spec.default_config()


class TestCache(TestCase):
    @mock.patch("helion.autotuner.DifferentialEvolutionSearch", new=BasicSearch)
```
> **Review comment** (on the `mock.patch` line): Should we add a cleaner way for people to set the autotuner to use?
>
> **Reply:** Yeah, let's define an API that people can set. Out of scope for this PR, but I can look into it. Filed #337.
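One possible shape for such an API, purely illustrative since #337 tracks the real design and the `autotuner_cls` parameter is made up:

```python
import helion
import torch

# Hypothetical: select the search strategy per kernel instead of mock.patch.
@helion.kernel(autotuner_cls=BasicSearch)
def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    ...
```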
```python
    def test_basic(self):
        a = torch.randn(16, device=DEVICE, dtype=torch.bfloat16)
        b = torch.randn(16, device=DEVICE, dtype=torch.float16)

        result = basic_kernels.add(a, a)
        torch.testing.assert_close(result, a + a)

        self.assertEqual(counters["autotune"]["cache_miss"], 1)
        self.assertEqual(counters["autotune"]["cache_hit"], 0)
        self.assertEqual(counters["autotune"]["cache_put"], 1)

        basic_kernels.add.reset()

        result = basic_kernels.add(a, a)
        torch.testing.assert_close(result, a + a)

        self.assertEqual(counters["autotune"]["cache_miss"], 1)
        self.assertEqual(counters["autotune"]["cache_hit"], 1)
        self.assertEqual(counters["autotune"]["cache_put"], 1)

        basic_kernels.add.reset()

        result = basic_kernels.add(b, b)
        torch.testing.assert_close(result, b + b)

        self.assertEqual(counters["autotune"]["cache_miss"], 2)
        self.assertEqual(counters["autotune"]["cache_hit"], 1)
        self.assertEqual(counters["autotune"]["cache_put"], 2)


if __name__ == "__main__":
    unittest.main()
```
> **Review comment** (on the `AutotuneCacheKey` fields): Do we want to include these by default? I was thinking for this use case, where autotuning is very expensive and caching is off by default, it may make sense to default to a more minimal key.
>
> **Reply:** My take is that correctness is of utmost importance, and by default we should have as strict a key as possible: if we are writing to the file system and people update their Helion/Triton/Torch version, we shouldn't give them bad autotuning results. That said, on L26 one of the TODOs is giving users customization over the cache key. With that customization, we can allow users to choose to omit one or more of these strict, included-by-default requirements. What do you think?
>
> **Reply:** I think the baseline here is users copy-pasting the config into their code. This makes the config a BC surface where we need to make sure old configs work with new versions of Helion/PyTorch. I'm a lot more worried about surprise re-autotuning.
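For the customization mentioned in the TODO, a user-defined cache key might look something like this; the `cache_key` parameter is hypothetical, named after the TODO rather than an implemented API:

```python
import helion
import torch

# Hypothetical: pin the cache key so saved configs survive version bumps,
# deliberately opting out of the strict helion/torch/system hashes.
@helion.kernel(cache_key="myteam-add-v1")
def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    ...
```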