Skip to content

Commit 7bd22d1

Browse files
committed
[RFC] Implement basic on disk caching
stack-info: PR: #336, branch: oulgen/stack/26
1 parent bc3438b commit 7bd22d1

File tree

8 files changed

+220
-7
lines changed

8 files changed

+220
-7
lines changed

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
4242
- name: Install PyTorch
4343
run: |
44-
pip3 install torch --index-url https://download.pytorch.org/whl/cu128
44+
pip3 install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
4545
4646
- name: Install lint dependencies
4747
run: ./lint.sh install

helion/_testing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import collections
4+
import contextlib
45
import importlib
56
import inspect
67
import operator
@@ -15,6 +16,7 @@
1516
import torch
1617
from triton.testing import do_bench
1718

19+
from ._utils import counters
1820
from .runtime.config import Config
1921
from helion._compat import get_tensor_descriptor_fn_name
2022

@@ -289,6 +291,20 @@ def tearDownClass(cls) -> None:
289291
super().tearDownClass()
290292
del cls._expected_journal
291293

294+
def setUp(self) -> None:
    """
    Prepare per-test state: open an ExitStack that owns the test's
    context managers, enter ``fresh_cache()`` on it, and reset the
    global ``counters``.
    """
    super().setUp()

    # Owned by the test instance so tearDown can unwind everything at once.
    stack = contextlib.ExitStack()
    self._test_stack = stack

    from torch._inductor.utils import fresh_cache

    # NOTE(review): presumably this gives each test an isolated Inductor
    # cache directory, so cached configs do not leak between tests -- confirm
    # against torch._inductor.utils.fresh_cache.
    stack.enter_context(fresh_cache())

    # Start every test with empty counters so assertions on them are absolute.
    counters.clear()
303+
304+
def tearDown(self) -> None:
    """
    Undo per-test state created in ``setUp``.

    The ExitStack opened in ``setUp`` is closed even when the parent
    ``tearDown`` raises, so the contexts entered on it are not leaked
    into the tests that run afterwards.
    """
    try:
        super().tearDown()
    finally:
        # Always unwind the contexts entered in setUp (e.g. fresh_cache()).
        self._test_stack.close()
307+
292308
def assertExpectedJournal(self, value: str) -> None:
293309
"""
294310
Assert that the given value matches the expected output stored in <testfile>.expected.

helion/_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from __future__ import annotations
2+
3+
import collections
4+
5+
# Global two-level tally: counters[group][event] is an int that starts at 0
# for any group/event that has not been touched yet.
counters: collections.defaultdict[str, collections.Counter[str]]
counters = collections.defaultdict(collections.Counter)

helion/autotuner/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from .cache import AutotuneCache as AutotuneCache
34
from .config_fragment import BooleanFragment as BooleanFragment
45
from .config_fragment import EnumFragment as EnumFragment
56
from .config_fragment import IntegerFragment as IntegerFragment

helion/autotuner/cache.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from __future__ import annotations
2+
3+
import dataclasses
4+
import functools
5+
import hashlib
6+
import inspect
7+
import logging
8+
import os
9+
from pathlib import Path
10+
import textwrap
11+
from typing import TYPE_CHECKING
12+
from typing import Sequence
13+
14+
import torch
15+
16+
from .._utils import counters
17+
from ..runtime.config import Config
18+
19+
if TYPE_CHECKING:
20+
from ..runtime.kernel import Kernel
21+
22+
log: logging.Logger = logging.getLogger(__name__)
23+
24+
"""
25+
TODO(oulgen)
26+
- Allow user defined cache keys that can be passed on @helion.kernel
27+
- Add import/export for set of configs
28+
"""
29+
30+
31+
@dataclasses.dataclass(frozen=True)
class AutotuneCacheKey:
    """
    Everything that identifies one cached autotuning result.

    helion_key: Hash of source code of Helion
    torch_key: Hash of source code of PyTorch
    system_hash: Hash of system information,
        including Triton, current device, cuda/rocm arch version
    kernel_source_hash: Hash of source code of input Helion kernel
    input_dtypes: (argument position, dtype) pairs of input tensors
    input_shapes: (argument position, shape) pairs of input tensors
    """

    helion_key: str
    torch_key: str
    system_hash: str
    kernel_source_hash: str
    input_dtypes: list[tuple[int, torch.dtype]]
    input_shapes: list[tuple[int, torch.Size]]

    def stable_hash(self) -> str:
        """Return the SHA-256 hex digest of this key's dataclass ``repr``."""
        # repr() of a dataclass lists every field in declaration order, so two
        # keys with equal field reprs always map to the same digest.
        return hashlib.sha256(repr(self).encode("utf-8")).hexdigest()
52+
53+
54+
class AutotuneCache:
    """
    On-disk cache of the best config found by autotuning.

    One instance corresponds to one (kernel, args) pair; the pair is reduced
    to an ``AutotuneCacheKey`` whose hash names the cache file.
    """

    def __init__(self, kernel: Kernel, args: Sequence[object]) -> None:
        """Compute and store the cache key for ``kernel`` called with ``args``."""
        self.key: AutotuneCacheKey = AutotuneCache._generate_key(kernel, args)

    @staticmethod
    def _generate_key(kernel: Kernel, args: Sequence[object]) -> AutotuneCacheKey:
        """
        Build the cache key from the kernel source, the Helion/PyTorch/system
        hashes, and the dtype and shape of every tensor in ``args``.

        Only ``torch.Tensor`` arguments contribute; arguments of any other
        type are not part of the key.
        """
        from torch._inductor.codecache import CacheBase
        from torch._inductor.codecache import torch_key

        # Dedent so the hash does not depend on how deeply the kernel is nested.
        kernel_source = textwrap.dedent(inspect.getsource(kernel.fn))
        kernel_source_hash = hashlib.sha256(kernel_source.encode("utf-8")).hexdigest()

        # Keep the positional index so the same dtypes/shapes in a different
        # argument order produce a different key.
        tensors = [
            (idx, a) for idx, a in enumerate(args) if isinstance(a, torch.Tensor)
        ]
        input_dtypes = [(idx, t.dtype) for idx, t in tensors]
        input_shapes = [(idx, t.shape) for idx, t in tensors]

        return AutotuneCacheKey(
            helion_key=helion_key(),
            torch_key=torch_key().hex(),
            system_hash=CacheBase.get_system()["hash"],
            kernel_source_hash=kernel_source_hash,
            input_dtypes=input_dtypes,
            input_shapes=input_shapes,
        )

    def _get_cache_key(self) -> str:
        """Return the hex digest that names this entry's cache file."""
        return self.key.stable_hash()

    def _get_local_cache_path(self) -> Path:
        """Return ``<cache_dir>/helion/<key>.best_config`` for this entry."""
        from torch._inductor.runtime.cache_dir_utils import (
            cache_dir,  # pyright: ignore[reportPrivateImportUsage]
        )

        return Path(cache_dir()) / "helion" / f"{self._get_cache_key()}.best_config"

    def get(self) -> Config | None:
        """
        Load the cached config for this key.

        Returns the config on a hit and ``None`` on a miss. Any failure to
        load the file is treated as a miss (the cache is best-effort), and
        ``counters["autotune"]`` records ``cache_hit`` / ``cache_miss``.
        """
        path = self._get_local_cache_path()
        try:
            config = Config.load(path)
        except FileNotFoundError:
            # Ordinary miss: nothing has been stored for this key yet.
            log.debug("Cache miss on config at %s", path)
        except Exception:
            # The entry exists but could not be read or parsed. Still a miss,
            # but keep the traceback so a corrupt entry can be diagnosed.
            log.debug("Cache miss on config at %s", path, exc_info=True)
        else:
            log.debug("Cache hit on config at %s", path)
            counters["autotune"]["cache_hit"] += 1
            return config

        counters["autotune"]["cache_miss"] += 1
        return None

    def put(self, config: Config) -> None:
        """Write ``config`` to this key's cache file and count a ``cache_put``."""
        path = self._get_local_cache_path()
        config.save(path)
        log.debug("Cache write of config at %s", path)
        counters["autotune"]["cache_put"] += 1
110+
111+
112+
@functools.cache
def helion_key() -> str:
    """
    Return a SHA-256 hex digest computed by inductor's ``build_code_hash``
    over the directory two levels above this file (the Helion package).

    The result is memoized with ``functools.cache``, so the hash is computed
    at most once per process.
    """
    from torch._inductor.codecache import build_code_hash

    package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    digest = hashlib.sha256()
    build_code_hash([package_root], "", digest)
    return digest.hexdigest()

helion/runtime/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def from_json(cls, json_str: str) -> Config:
118118

119119
def save(self, path: str | Path) -> None:
    """Save the config to a JSON file, creating missing parent directories."""
    target = Path(path)
    # The destination directory may not exist yet on first write.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(self.to_json())
122123

123124
@classmethod

helion/runtime/kernel.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -438,13 +438,21 @@ def autotune(
438438
else:
439439
self.settings.check_autotuning_disabled()
440440

441-
from ..autotuner import DifferentialEvolutionSearch
441+
from ..autotuner import AutotuneCache
442442

443-
config = DifferentialEvolutionSearch(
444-
self,
445-
args,
446-
**kwargs, # pyright: ignore[reportArgumentType]
447-
).autotune()
443+
cache = AutotuneCache(self.kernel, args)
444+
config = cache.get()
445+
446+
if config is None:
447+
from ..autotuner import DifferentialEvolutionSearch
448+
449+
config = DifferentialEvolutionSearch(
450+
self,
451+
args,
452+
**kwargs, # pyright: ignore[reportArgumentType]
453+
).autotune()
454+
455+
cache.put(config)
448456
self.set_config(config)
449457
return config
450458

test/test_cache.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
import unittest
5+
from unittest import mock
6+
7+
import torch
8+
9+
from helion._testing import DEVICE
10+
from helion._testing import TestCase
11+
from helion._testing import import_path
12+
from helion._utils import counters
13+
14+
datadir = Path(__file__).parent / "data"
15+
basic_kernels = import_path(datadir / "basic_kernels.py")
16+
17+
18+
class BasicSearch:
    """
    Stand-in for a real autotuner search: performs no search and always
    yields the kernel's default config.
    """

    def __init__(self, bound_kernel, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored so this
        # class can be constructed the same way as the search it replaces.
        self.bound_kernel = bound_kernel

    def autotune(self):
        """Return the default config from the bound kernel's config spec."""
        spec = self.bound_kernel.config_spec
        return spec.default_config()
24+
25+
26+
class TestCache(TestCase):
    """Checks autotune cache miss/hit/put accounting across repeated runs."""

    def _run_add(self, x):
        """Run the ``add`` kernel on ``(x, x)`` and check the numeric result."""
        result = basic_kernels.add(x, x)
        torch.testing.assert_close(result, x + x)

    def _assert_autotune_counters(self, *, miss, hit, put):
        """Assert the exact values of the three autotune cache counters."""
        self.assertEqual(counters["autotune"]["cache_miss"], miss)
        self.assertEqual(counters["autotune"]["cache_hit"], hit)
        self.assertEqual(counters["autotune"]["cache_put"], put)

    @mock.patch("helion.autotuner.DifferentialEvolutionSearch", new=BasicSearch)
    def test_basic(self):
        a = torch.randn(16, device=DEVICE, dtype=torch.bfloat16)
        b = torch.randn(16, device=DEVICE, dtype=torch.float16)

        # First run: nothing cached yet, so one miss followed by one put.
        self._run_add(a)
        self._assert_autotune_counters(miss=1, hit=0, put=1)

        basic_kernels.add.reset()

        # Same inputs after a reset: the stored config is found.
        self._run_add(a)
        self._assert_autotune_counters(miss=1, hit=1, put=1)

        basic_kernels.add.reset()

        # Different dtype: a new key, hence another miss and another put.
        self._run_add(b)
        self._assert_autotune_counters(miss=2, hit=1, put=2)
56+
57+
58+
if __name__ == "__main__":
59+
unittest.main()

0 commit comments

Comments
 (0)