[Diffusion] add TorchCompileTransformation #913
Open
legitnull wants to merge 2 commits into flagos-ai:main from legitnull:compile
@@ -0,0 +1,129 @@
```python
# Copied from https://github.com/vllm-project/vllm/blob/6ac5e06f7c5d4658c9fb119826a92d9910730fb4/vllm/compilation/inductor_pass.py.
# Below is the original copyright:

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
import hashlib
import inspect
import json
import types

from collections.abc import Callable
from contextlib import contextmanager
from typing import Any

import torch

from torch import fx
from torch._inductor.custom_graph_pass import CustomGraphPass
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

_pass_context = None


class PassContext:
    def __init__(self, runtime_shape: int | None):
        self.runtime_shape = runtime_shape


def get_pass_context() -> PassContext:
    """Get the current pass context."""
    assert _pass_context is not None
    return _pass_context


@contextmanager
def pass_context(runtime_shape: int | None):
    """A context manager that stores the current pass context,
    usually it is a list of sizes to specialize.
    """
    global _pass_context
    prev_context = _pass_context
    _pass_context = PassContext(runtime_shape)
    try:
        yield
    finally:
        _pass_context = prev_context


class InductorPass(CustomGraphPass):
    """
    A custom graph pass that uses a hash of its source as the UUID.
    This is defined as a convenience and should work in most cases.
    """

    def uuid(self) -> Any:
        """
        Provide a unique identifier for the pass, used in Inductor code cache.
        This should depend on the pass implementation, so that changes to the
        pass result in recompilation.
        By default, the object source is hashed.
        """
        return InductorPass.hash_source(self)

    @staticmethod
    def hash_source(*srcs: str | Any):
        """
        Utility method to hash the sources of functions or objects.
        :param srcs: strings or objects to add to the hash.
            Objects and functions have their source inspected.
        :return:
        """
        hasher = hashlib.sha256()
        for src in srcs:
            if isinstance(src, str):
                src_str = src
            elif isinstance(src, (types.FunctionType, type)):
                src_str = inspect.getsource(src)
            else:
                # object instance
                src_str = inspect.getsource(src.__class__)
            hasher.update(src_str.encode("utf-8"))
        return hasher.hexdigest()

    @staticmethod
    def hash_dict(dict_: dict[Any, Any]):
        """
        Utility method to hash a dictionary, can alternatively be used for uuid.
        :return: A sha256 hash of the json rep of the dictionary.
        """
        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()

    def is_applicable(self, shape: int | None):
        return True


class CallableInductorPass(InductorPass):
    """
    This class is a wrapper for a callable that automatically provides an
    implementation of the UUID.
    """

    def __init__(self, callable: Callable[[fx.Graph], None], uuid: Any | None = None):
        self.callable = callable
        self._uuid = self.hash_source(callable) if uuid is None else uuid

    def __call__(self, graph: torch.fx.Graph):
        self.callable(graph)

    def uuid(self) -> Any:
        return self._uuid


def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
    """
    Applies a FakeTensorMode context. This is useful when you don't want to
    create or run things with real tensors.
    """

    @functools.wraps(fn)
    def fn_new(*args, **kwargs) -> Any:
        with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
            result = fn(*args, **kwargs)

        return result

    return fn_new
```
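For orientation (not part of the diff), here is a minimal sketch of how these pieces are typically wired together: wrap a plain FX graph rewrite in `CallableInductorPass`, install it as Inductor's custom post-grad pass, and run under `pass_context`. The `torch._inductor.config.post_grad_custom_post_pass` knob and its acceptance of a `CustomGraphPass` instance are assumptions about a recent PyTorch 2.x build.

```python
# Sketch only: assumes a recent PyTorch 2.x where Inductor accepts a
# CustomGraphPass via torch._inductor.config.post_grad_custom_post_pass.
import torch
from torch import fx


def _log_graph_pass(graph: fx.Graph) -> None:
    # A trivial pass body; a real pass would match and rewrite nodes here.
    print(f"post-grad graph has {len(list(graph.nodes))} nodes")


# uuid() is derived from the callable's source, so editing _log_graph_pass
# changes the hash and invalidates Inductor's code cache.
my_pass = CallableInductorPass(_log_graph_pass)
torch._inductor.config.post_grad_custom_post_pass = my_pass

compiled = torch.compile(torch.nn.Linear(4, 4))

# pass_context records the runtime shape being specialized; passes that care
# can consult get_pass_context() and skip shapes they don't apply to.
with pass_context(runtime_shape=None):
    compiled(torch.randn(2, 4))
```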
78 changes: 78 additions & 0 deletions
flagscale/transformations/diffusion/timestep_embedding_flip_sine_cosine_pass.py
@@ -0,0 +1,78 @@
````python
import torch

from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    Match,
    PatternMatcherPass,
    register_graph_pattern,
)

from flagscale.compilation.inductor_pass import InductorPass
from flagscale.runner.utils import logger

aten = torch.ops.aten
_timestep_embedding_flip_sine_cosine_matcher = PatternMatcherPass(
    pass_name="timestep_embedding_flip_sine_cosine_matcher"
)

# Shared inputs we want to capture once and reuse
base = Arg()  # %mul_17 feeding both sin and cos
split_idx = Arg()  # %floordiv reused by both slices

inner_cat = CallFunction(
    aten.cat.default,
    [CallFunction(aten.sin.default, base), CallFunction(aten.cos.default, base)],
    -1,
    _users=2,  # The cat feeds both slices
)
slice_hi = CallFunction(aten.slice.Tensor, inner_cat, 1, split_idx, 9223372036854775807)
slice_lo = CallFunction(aten.slice.Tensor, inner_cat, 1, 0, split_idx)
pattern = CallFunction(aten.cat.default, [slice_hi, slice_lo], -1)


@register_graph_pattern(pattern, pass_dict=_timestep_embedding_flip_sine_cosine_matcher)
def _rewrite_timestep_embedding_flip_sine_cosine(match: Match, base, split_idx) -> None:
    logger.debug(f"Applying TimestepEmbeddingFlipSineCosinePattern at {match.nodes}")

    sin_node, cos_node, inner_cat, slice_hi, slice_lo, cat_node = match.nodes

    graph = cat_node.graph

    with graph.inserting_before(cat_node):
        new_cat = graph.call_function(torch.ops.aten.cat.default, args=([cos_node, sin_node], -1))
        new_cat.meta.update(cat_node.meta)

    cat_node.replace_all_uses_with(new_cat)

    for node in (slice_hi, slice_lo):
        if len(node.users) == 0 and node in graph.nodes:
            graph.erase_node(node)

    if len(inner_cat.users) == 0 and inner_cat in graph.nodes:
        graph.erase_node(inner_cat)

    if len(cat_node.users) == 0 and cat_node in graph.nodes:
        graph.erase_node(cat_node)

    graph.eliminate_dead_code()


class TimestepEmbeddingFlipSineCosinePass(InductorPass):
    """
    A pass to rewrite the following code snippet from [diffusers](https://github.com/huggingface/diffusers/blob/v0.35.2/src/diffusers/models/embeddings.py#L69):
    ```
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    ```
    to the following code snippet:
    ```
    emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
    ```

    This pass serves only as a demonstration; no significant performance improvement is expected.
    Also, we could register the pattern with Inductor's post-grad matcher pass directly, in which
    case this class would be unnecessary.
    """

    def __call__(self, graph: torch.fx.Graph) -> None:
        _timestep_embedding_flip_sine_cosine_matcher.apply(graph)
````
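A hedged usage sketch (not in the PR): the pass can be installed as Inductor's custom post-grad pass and exercised against a diffusers-style sinusoidal embedding with the `flip_sin_to_cos` behavior. The helper function and the config attribute below are illustrative assumptions, not part of this change.

```python
# Illustrative only: assumes torch._inductor.config.post_grad_custom_post_pass
# exists (recent PyTorch 2.x) and that this embedding lowers to the
# sin/cos cat + two slices + cat shape the pattern was written against.
import torch


def timestep_embedding(t: torch.Tensor, dim: int = 8) -> torch.Tensor:
    half_dim = dim // 2
    freqs = torch.exp(-torch.arange(half_dim, dtype=torch.float32))
    emb = t[:, None].float() * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    # flip_sin_to_cos branch that the pass is meant to collapse into a
    # single cat([cos, sin], dim=-1)
    return torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)


torch._inductor.config.post_grad_custom_post_pass = TimestepEmbeddingFlipSineCosinePass()
compiled = torch.compile(timestep_embedding)
out = compiled(torch.arange(4))
```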
What does 9223372036854775807 mean?
-1
I still don't quite understand why it has to be written this way. Why not just write -1 directly?
I didn't dig into it. Maybe it has something to do with how aten.slice represents the end of the dim.
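For what it's worth, the constant is `2**63 - 1` (`torch.iinfo(torch.int64).max`): PyTorch lowers an open-ended slice like `emb[:, half_dim:]` to `aten.slice.Tensor` with that sentinel as the end index, so the pattern has to match the literal the tracer emits rather than `-1` (which `aten.slice` would wrap to "up to the last element"). A quick way to check, as a sketch:

```python
# Not from the PR: trace an open-ended slice and inspect the aten op it lowers to.
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x: torch.Tensor) -> torch.Tensor:
    return x[:, 4:]


gm = make_fx(f)(torch.randn(2, 8))
print(gm.graph)
# Expect a node roughly like:
#   slice_1 = torch.ops.aten.slice.Tensor(x_1, 1, 4, 9223372036854775807)
print(2**63 - 1 == torch.iinfo(torch.int64).max)  # True
```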