-
Notifications
You must be signed in to change notification settings - Fork 15
[RFC] Implement basic on disk caching #336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
oulgen
wants to merge
1
commit into
main
Choose a base branch
from
oulgen/stack/26
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from __future__ import annotations | ||
|
||
import collections | ||
|
||
counters: collections.defaultdict[str, collections.Counter[str]] = ( | ||
collections.defaultdict(collections.Counter) | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
from __future__ import annotations | ||
|
||
import abc | ||
import dataclasses | ||
import functools | ||
import hashlib | ||
import logging | ||
import os | ||
from typing import TYPE_CHECKING | ||
from typing import Hashable | ||
from typing import Sequence | ||
|
||
from .._utils import counters | ||
|
||
if TYPE_CHECKING: | ||
from ..runtime.config import Config | ||
from ..runtime.kernel import BoundKernel | ||
from .base_search import BaseSearch | ||
|
||
log: logging.Logger = logging.getLogger(__name__) | ||
|
||
|
||
@functools.cache
def helion_key() -> str:
    """Return a stable hex digest of the Helion library's source tree.

    Cached for the life of the process; used to invalidate on-disk
    autotune caches when the Helion source code itself changes.
    """
    from torch._inductor.codecache import build_code_hash

    # Two directory levels up from this file is the Helion package root.
    package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    digest = hashlib.sha256()
    build_code_hash([package_root], "", digest)
    return digest.hexdigest()
|
||
|
||
@functools.cache
def torch_key_wrapper() -> str:
    """Hex-encoded hash of the installed PyTorch source (process-cached)."""
    from torch._inductor.codecache import torch_key

    return torch_key().hex()
|
||
|
||
@functools.cache
def triton_key_wrapper() -> str:
    """Hash identifying the installed Triton version (process-cached)."""
    from torch._inductor.runtime.triton_compat import triton_key

    return triton_key()
|
||
|
||
class CacheKeyBase:
    """
    Shared helpers for every cache-key dataclass.
    """

    def stable_hash(self) -> str:
        """Deterministic hex digest derived from this key's ``repr``."""
        encoded = repr(self).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()
|
||
|
||
@dataclasses.dataclass(frozen=True)
class BoundKernelInMemoryCacheKey(CacheKeyBase):
    """
    In-memory cache key used by default for bound kernels.

    Fields:
        specialization_key: captures properties of every kernel input
            (for tensors: device, shape, size, and so on).
        extra_results: records the outcomes of ``hl.specialize`` decisions.
    """

    specialization_key: tuple[Hashable, ...]
    extra_results: tuple[Hashable, ...]
|
||
|
||
@dataclasses.dataclass(frozen=True)
class LooseAutotuneCacheKey(BoundKernelInMemoryCacheKey):
    """
    Autotune cache key to use for most use cases.

    In addition to the fields of BoundKernelInMemoryCacheKey:

    kernel_source_hash: hash of the source code of the input Helion kernel
    hardware: hardware name of the input device
    runtime_name: version of the CUDA/ROCm runtime
    """

    kernel_source_hash: str
    hardware: str
    runtime_name: str

    # NOTE: stable_hash() is inherited from CacheKeyBase.  A previous
    # override here duplicated the base implementation byte-for-byte and
    # was removed as redundant; behavior is unchanged.
|
||
|
||
@dataclasses.dataclass(frozen=True)
class StrictAutotuneCacheKey(LooseAutotuneCacheKey):
    """
    Autotune cache key to use for utmost strictness: forces re-autotuning
    whenever library-level source code changes.

    In addition to the fields of LooseAutotuneCacheKey (the original
    docstring incorrectly said "StrictAutotuneCacheKey" here):

    helion_key: hash of the source code of Helion
    torch_key: hash of the source code of PyTorch
    triton_key: hash of the source code of Triton
    """

    # default_factory defers the (cached) library hashing until a key
    # instance is actually constructed.
    helion_key: str = dataclasses.field(default_factory=helion_key)
    torch_key: str = dataclasses.field(default_factory=torch_key_wrapper)
    triton_key: str = dataclasses.field(default_factory=triton_key_wrapper)
|
||
|
||
class AutotuneCacheBase(abc.ABC):
    """
    Abstract interface that every autotune cache must implement.

    User-defined caches subclass this and provide ``get`` and ``put``;
    ``autotune`` then wraps the underlying search with cache lookups.
    """

    def __init__(
        self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
    ) -> None:
        self.autotuner = autotuner
        self.kernel = kernel
        self.args = args

    @abc.abstractmethod
    def get(self) -> Config | None:
        """Return the cached config, or ``None`` on a miss."""
        raise NotImplementedError

    @abc.abstractmethod
    def put(self, config: Config) -> None:
        """Persist *config* for later retrieval."""
        raise NotImplementedError

    def autotune(self) -> Config:
        """Return a config, consulting the cache before autotuning."""
        cached = self.get()
        if cached is not None:
            counters["autotune"]["cache_hit"] += 1
            log.debug("cache hit: %s", str(cached))
            return cached

        counters["autotune"]["cache_miss"] += 1
        log.debug("cache miss")

        config = self.autotuner.autotune()

        self.put(config)
        counters["autotune"]["cache_put"] += 1
        log.debug("cache put: %s", str(config))

        return config
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from __future__ import annotations | ||
|
||
import hashlib | ||
import inspect | ||
import logging | ||
import os | ||
from pathlib import Path | ||
import textwrap | ||
from typing import TYPE_CHECKING | ||
from typing import Sequence | ||
|
||
import torch | ||
|
||
from ..runtime.config import Config | ||
from .base_cache import AutotuneCacheBase | ||
from .base_cache import LooseAutotuneCacheKey | ||
from .base_cache import StrictAutotuneCacheKey | ||
|
||
if TYPE_CHECKING: | ||
from ..runtime.kernel import BoundKernel | ||
from .base_search import BaseSearch | ||
|
||
log: logging.Logger = logging.getLogger(__name__) | ||
|
||
|
||
class LocalAutotuneCache(AutotuneCacheBase):
    """
    File-system backed autotune cache.

    The best-config artifact is stored on the local file system, by
    default under torch's cache directory, or under a user-specified
    HELION_CACHE_DIR directory.

    Keys are LooseAutotuneCacheKey instances, which account for device
    and source-code properties but deliberately ignore library-level
    changes (Triton, Helion, PyTorch).  Use StrictLocalAutotuneCache
    when those must be considered as well.
    """

    def __init__(
        self, kernel: BoundKernel, args: Sequence[object], autotuner: BaseSearch
    ) -> None:
        super().__init__(kernel, args, autotuner)
        self.key = self._generate_key()

    def _generate_key(self) -> LooseAutotuneCacheKey:
        """Build the on-disk cache key for this kernel + argument set."""
        in_memory_cache_key = self.kernel.kernel._create_bound_kernel_cache_key(
            self.kernel,
            tuple(self.args),
            self.kernel.kernel.specialization_key(self.args),
        )
        source = textwrap.dedent(inspect.getsource(self.kernel.kernel.fn))
        source_hash = hashlib.sha256(source.encode("utf-8")).hexdigest()

        hardware = None
        runtime_name = None

        # Derive device identity from tensor arguments; if several
        # tensors are present the last one seen wins, matching the
        # original behavior.
        for arg in self.args:
            if not isinstance(arg, torch.Tensor):
                continue
            props = torch.cuda.get_device_properties(arg.device)
            if torch.version.cuda is not None:  # pyright: ignore[reportAttributeAccessIssue]
                hardware = props.name
                runtime_name = torch.version.cuda  # pyright: ignore[reportAttributeAccessIssue]
            else:
                hardware = props.gcnArchName
                runtime_name = torch.version.hip  # pyright: ignore[reportAttributeAccessIssue]

        assert hardware is not None and runtime_name is not None
        return LooseAutotuneCacheKey(
            specialization_key=in_memory_cache_key.specialization_key,
            extra_results=in_memory_cache_key.extra_results,
            kernel_source_hash=source_hash,
            hardware=hardware,
            runtime_name=runtime_name,
        )

    def _get_local_cache_path(self) -> Path:
        """Resolve the file path holding this key's best config."""
        user_path = os.environ.get("HELION_CACHE_DIR", None)
        if user_path is not None:
            root = Path(user_path)
        else:
            from torch._inductor.runtime.cache_dir_utils import (
                cache_dir,  # pyright: ignore[reportPrivateImportUsage]
            )

            root = Path(cache_dir()) / "helion"

        return root / f"{self.key.stable_hash()}.best_config"

    def get(self) -> Config | None:
        """Load the cached config; ``None`` when absent or unreadable."""
        try:
            return Config.load(self._get_local_cache_path())
        except Exception:
            # Any load failure (missing file, corrupt data) is a miss.
            return None

    def put(self, config: Config) -> None:
        """Write *config* to its on-disk cache location."""
        config.save(self._get_local_cache_path())
|
||
|
||
class StrictLocalAutotuneCache(LocalAutotuneCache):
    """
    Variant of LocalAutotuneCache whose key additionally reflects
    library-level code changes (Triton, Helion, PyTorch).
    """

    def _generate_key(self) -> StrictAutotuneCacheKey:
        # Widen the loose key into a strict one; the library hashes are
        # filled in by the strict key's default factories.
        base_key = super()._generate_key()
        return StrictAutotuneCacheKey(**vars(base_key))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,7 @@ | |
from torch._guards import Source | ||
|
||
from ..autotuner import ConfigSpec | ||
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey | ||
|
||
ConfigLike = Config | dict[str, object] | ||
|
||
|
@@ -53,12 +54,6 @@ | |
CompiledConfig = Callable[..., _R] | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class BoundKernelInMemoryCacheKey: | ||
specialization_key: tuple[Hashable, ...] | ||
extra_results: tuple[Hashable, ...] | ||
|
||
|
||
class Kernel(Generic[_R]): | ||
def __init__( | ||
self, | ||
|
@@ -114,6 +109,8 @@ def __init__( | |
def _get_bound_kernel_cache_key( | ||
self, args: tuple[object, ...], signature: tuple[Hashable, ...] | ||
) -> BoundKernelInMemoryCacheKey | None: | ||
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey | ||
Review comment: move this import to the top of the file. |
||
|
||
extra_fns = self._specialize_extra.get(signature) | ||
if extra_fns is not None: | ||
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns]) | ||
|
@@ -126,6 +123,8 @@ def _create_bound_kernel_cache_key( | |
args: tuple[object, ...], | ||
signature: tuple[Hashable, ...], | ||
) -> BoundKernelInMemoryCacheKey: | ||
from ..autotuner.base_cache import BoundKernelInMemoryCacheKey | ||
Review comment: move this import to the top of the file. |
||
|
||
self._specialize_extra[signature] = extra_fns = bound_kernel._specialize_extra() | ||
extra_results: tuple[Hashable, ...] = tuple([s(args) for s in extra_fns]) | ||
return BoundKernelInMemoryCacheKey(signature, extra_results) | ||
|
@@ -458,12 +457,18 @@ def autotune( | |
self.settings.check_autotuning_disabled() | ||
|
||
from ..autotuner import DifferentialEvolutionSearch | ||
from ..autotuner import LocalAutotuneCache | ||
|
||
config = DifferentialEvolutionSearch( | ||
config = LocalAutotuneCache( | ||
self, | ||
args, | ||
**kwargs, # pyright: ignore[reportArgumentType] | ||
DifferentialEvolutionSearch( | ||
self, | ||
args, | ||
**kwargs, # pyright: ignore[reportArgumentType] | ||
), | ||
).autotune() | ||
|
||
self.set_config(config) | ||
return config | ||
|
||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inductor imports at top of file throughout PR