
[RFC] Implement basic on disk caching #336


Open. Wants to merge 1 commit into base: oulgen/stack/25.

2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -41,7 +41,7 @@ jobs:
       - name: Install PyTorch
         run: |
-          pip3 install torch --index-url https://download.pytorch.org/whl/cu128
+          pip3 install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
       - name: Install lint dependencies
         run: ./lint.sh install
16 changes: 16 additions & 0 deletions helion/_testing.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import collections
import contextlib
import importlib
import inspect
import operator
@@ -15,6 +16,7 @@
import torch
from triton.testing import do_bench

from ._utils import counters
from .runtime.config import Config
from helion._compat import get_tensor_descriptor_fn_name

@@ -289,6 +291,20 @@ def tearDownClass(cls) -> None:
        super().tearDownClass()
        del cls._expected_journal

    def setUp(self) -> None:
        super().setUp()
        self._test_stack = contextlib.ExitStack()

        from torch._inductor.utils import fresh_cache

        self._test_stack.enter_context(fresh_cache())

        counters.clear()

    def tearDown(self) -> None:
        super().tearDown()
        self._test_stack.close()

    def assertExpectedJournal(self, value: str) -> None:
        """
        Assert that the given value matches the expected output stored in <testfile>.expected.
7 changes: 7 additions & 0 deletions helion/_utils.py
@@ -0,0 +1,7 @@
from __future__ import annotations

import collections

counters: collections.defaultdict[str, collections.Counter[str]] = (
    collections.defaultdict(collections.Counter)
)
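
For context, a quick illustration of the nested defaultdict(Counter) pattern: missing keys materialize on first access, so call sites can increment without any setup and tests can wipe everything with clear(). The counter names below mirror the ones used later in this PR.

from helion._utils import counters

counters["autotune"]["cache_hit"] += 1     # the "autotune" Counter is created on demand
print(counters["autotune"]["cache_miss"])  # 0 -- Counter returns 0 for missing entries
counters.clear()                           # what the test harness does in setUp()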
1 change: 1 addition & 0 deletions helion/autotuner/__init__.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from .cache import AutotuneCache as AutotuneCache
from .config_fragment import BooleanFragment as BooleanFragment
from .config_fragment import EnumFragment as EnumFragment
from .config_fragment import IntegerFragment as IntegerFragment
121 changes: 121 additions & 0 deletions helion/autotuner/cache.py
@@ -0,0 +1,121 @@
from __future__ import annotations

import dataclasses
import functools
import hashlib
import inspect
import logging
import os
from pathlib import Path
import textwrap
from typing import TYPE_CHECKING
from typing import Sequence

import torch

from .._utils import counters
from ..runtime.config import Config

if TYPE_CHECKING:
    from ..runtime.kernel import Kernel

log: logging.Logger = logging.getLogger(__name__)

"""
TODO(oulgen)
- Allow user defined cache keys that can be passed on @helion.kernel
- Add import/export for set of configs
"""


@dataclasses.dataclass(frozen=True)
class AutotuneCacheKey:
"""
helion_key: Hash of source code of Helion
torch_key: Hash of source code of PyTorch
system_hash: Hash of system information,
including Triton, current device, cuda/rocm arch version
Comment on lines +34 to +37

Contributor: Do we want to include these by default?

I was thinking that for this use case, where autotuning is very expensive and caching is off by default, it may make sense to default to a more minimal key.

Contributor Author: My take is that correctness is of the utmost importance, and by default the key should be as strict as possible: if we write to the file system and people later update their Helion/Triton/PyTorch version, we should not hand them bad autotuning results.

Having said this, one of the TODOs on L26 is giving users customization over the cache key. With that customization, we can let users choose to omit one or more of these strict, included-by-default requirements.

What do you think?

Contributor: I think the baseline here is users copying and pasting the config into their code. That makes the config a BC surface where we need to make sure old configs work with new versions of Helion/PyTorch.

I'm a lot more worried about surprise re-autotuning.

    kernel_source_hash: Hash of source code of input Helion kernel
    input_dtypes: dtypes of input tensors
    input_shapes: shapes of input tensors
    """

    helion_key: str
    torch_key: str
    system_hash: str
    kernel_source_hash: str
    input_dtypes: list[tuple[int, torch.dtype]]
    input_shapes: list[tuple[int, torch.Size]]

    def stable_hash(self) -> str:
        return hashlib.sha256(repr(self).encode("utf-8")).hexdigest()
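
Illustrative sketch (not part of the diff): building a key by hand and hashing it. The string fields below are placeholders; in the PR they come from helion_key(), torch_key(), and CacheBase.get_system(). Because the dataclass is frozen and repr() of these field types is deterministic, stable_hash() returns the same digest across processes for identical inputs.

import torch

key = AutotuneCacheKey(
    helion_key="<hash of Helion source>",
    torch_key="<hash of PyTorch source>",
    system_hash="<hash of system info>",
    kernel_source_hash="<hash of kernel source>",
    input_dtypes=[(0, torch.bfloat16)],
    input_shapes=[(0, torch.Size([16]))],
)
print(key.stable_hash())  # 64-character hex digest, later used as the cache file name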


class AutotuneCache:

Contributor: docstring

    def __init__(self, kernel: Kernel, args: Sequence[object]) -> None:
        self.key: AutotuneCacheKey = AutotuneCache._generate_key(kernel, args)

    @staticmethod
    def _generate_key(kernel: Kernel, args: Sequence[object]) -> AutotuneCacheKey:
        from torch._inductor.codecache import CacheBase
        from torch._inductor.codecache import torch_key

        kernel_source = textwrap.dedent(inspect.getsource(kernel.fn))
        kernel_source_hash = hashlib.sha256(kernel_source.encode("utf-8")).hexdigest()

        input_dtypes = []
        input_shapes = []

        for idx, a in enumerate(args):
            if isinstance(a, torch.Tensor):
                input_dtypes.append((idx, a.dtype))
                input_shapes.append((idx, a.shape))
Comment on lines +69 to +72

Contributor: Maybe we should mimic/reuse some of the key-generation infra in kernel.py (for the in-memory cache). IMO the memory key should be a subset of the disk key; right now the memory key includes some extra stuff.

        return AutotuneCacheKey(
            helion_key=helion_key(),
            torch_key=torch_key().hex(),
            system_hash=CacheBase.get_system()["hash"],
            kernel_source_hash=kernel_source_hash,
            input_dtypes=input_dtypes,
            input_shapes=input_shapes,
        )

    def _get_cache_key(self) -> str:
        return self.key.stable_hash()

    def _get_local_cache_path(self) -> Path:
        from torch._inductor.runtime.cache_dir_utils import (
            cache_dir,  # pyright: ignore[reportPrivateImportUsage]
        )

        return Path(cache_dir()) / "helion" / f"{self._get_cache_key()}.best_config"
Contributor: It might be useful to let users specify a filename for the cache so they can more easily distribute it to production environments.

Contributor Author: The TODO on L27 is about importing and exporting a file with all the autotune settings. Do you think that would be good enough, or should I add an env option to control the cache directory?

FWIW we can just do both.

    def get(self) -> Config | None:
        path = self._get_local_cache_path()
        try:
            config = Config.load(path)
            log.debug("Cache hit on config at %s", path)
            counters["autotune"]["cache_hit"] += 1
            return config
        except Exception:
            log.debug("Cache miss on config at %s", path)
            counters["autotune"]["cache_miss"] += 1
            return None

    def put(self, config: Config) -> None:
        path = self._get_local_cache_path()
        config.save(path)
Contributor: Should we do the save-to-a-temp-file-and-rename trick here so the write is atomic? I'm thinking about multi-process races.

log.debug("Cache write of config at %s", path)
counters["autotune"]["cache_put"] += 1


@functools.cache
def helion_key() -> str:
    from torch._inductor.codecache import build_code_hash

    here = os.path.abspath(__file__)
    helion_path = os.path.dirname(os.path.dirname(here))

    combined_hash = hashlib.sha256()
    build_code_hash([helion_path], "", combined_hash)
    return combined_hash.hexdigest()
1 change: 1 addition & 0 deletions helion/runtime/config.py
@@ -118,6 +118,7 @@ def from_json(cls, json_str: str) -> Config:

    def save(self, path: str | Path) -> None:
        """Save the config to a JSON file."""
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        Path(path).write_text(self.to_json())

    @classmethod
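
A small round-trip sketch of the save/load pair this PR relies on; the path and block_sizes value are illustrative. With the added mkdir(parents=True, exist_ok=True), intermediate directories such as the cache's helion/ folder are created on first save.

from pathlib import Path
from helion.runtime.config import Config

cfg = Config(block_sizes=[64])  # illustrative settings
path = Path("/tmp/helion_cache_demo/abc123.best_config")
cfg.save(path)                  # parent directory is created automatically now
assert Config.load(path).to_json() == cfg.to_json()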
20 changes: 14 additions & 6 deletions helion/runtime/kernel.py
@@ -438,13 +438,21 @@ def autotune(
         else:
             self.settings.check_autotuning_disabled()

-        from ..autotuner import DifferentialEvolutionSearch
+        from ..autotuner import AutotuneCache

-        config = DifferentialEvolutionSearch(
-            self,
-            args,
-            **kwargs,  # pyright: ignore[reportArgumentType]
-        ).autotune()
+        cache = AutotuneCache(self.kernel, args)
+        config = cache.get()
+
+        if config is None:
+            from ..autotuner import DifferentialEvolutionSearch
+
+            config = DifferentialEvolutionSearch(
+                self,
+                args,
+                **kwargs,  # pyright: ignore[reportArgumentType]
+            ).autotune()
+
+            cache.put(config)
         self.set_config(config)
         return config
59 changes: 59 additions & 0 deletions test/test_cache.py
@@ -0,0 +1,59 @@
from __future__ import annotations

from pathlib import Path
import unittest
from unittest import mock

import torch

from helion._testing import DEVICE
from helion._testing import TestCase
from helion._testing import import_path
from helion._utils import counters

datadir = Path(__file__).parent / "data"
basic_kernels = import_path(datadir / "basic_kernels.py")


class BasicSearch:
    def __init__(self, bound_kernel, *args, **kwargs):
        self.bound_kernel = bound_kernel

    def autotune(self):
        return self.bound_kernel.config_spec.default_config()


class TestCache(TestCase):
    @mock.patch("helion.autotuner.DifferentialEvolutionSearch", new=BasicSearch)
Contributor: Should we add a cleaner way for people to set the autotuner to use?

Contributor Author: Yeah, let's define an API that people can set. Out of scope for this PR, but I can look into it. Filed #337.

    def test_basic(self):
        a = torch.randn(16, device=DEVICE, dtype=torch.bfloat16)
        b = torch.randn(16, device=DEVICE, dtype=torch.float16)

        result = basic_kernels.add(a, a)
        torch.testing.assert_close(result, a + a)

        self.assertEqual(counters["autotune"]["cache_miss"], 1)
        self.assertEqual(counters["autotune"]["cache_hit"], 0)
        self.assertEqual(counters["autotune"]["cache_put"], 1)

        basic_kernels.add.reset()

        result = basic_kernels.add(a, a)
        torch.testing.assert_close(result, a + a)

        self.assertEqual(counters["autotune"]["cache_miss"], 1)
        self.assertEqual(counters["autotune"]["cache_hit"], 1)
        self.assertEqual(counters["autotune"]["cache_put"], 1)

        basic_kernels.add.reset()

        result = basic_kernels.add(b, b)
        torch.testing.assert_close(result, b + b)

        self.assertEqual(counters["autotune"]["cache_miss"], 2)
        self.assertEqual(counters["autotune"]["cache_hit"], 1)
        self.assertEqual(counters["autotune"]["cache_put"], 2)


if __name__ == "__main__":
    unittest.main()