34 changes: 25 additions & 9 deletions examples/openjourney/conf/inference/t2i.yaml
@@ -6,14 +6,30 @@ engine:
  device: "cuda"
  torch_dtype: float16
  transformations:
    LogIOTransformation:
      log_level: "info"
    StateScopeTransformation:
      {}
    TorchCompileTransformation:
      options:
        mode: "default"
        dynamic: true
        fullgraph: false
        disable: false
      passes:
        TimestepEmbeddingFlipSineCosinePass:


generate:
  prompts: "retro serie of different cars with different colors and shapes, mdjrny-v4 style"
  width: 512
  height: 512
  num_inference_steps: 50
  guidance_scale: 7.0
  - prompt: "retro serie of different cars with different colors and shapes, mdjrny-v4 style"
    width: 512
    height: 512
    num_inference_steps: 50
    guidance_scale: 7.0
  - prompt: "retro serie of different cars with different colors and shapes, mdjrny-v4 style"
    width: 512
    height: 512
    num_inference_steps: 50
    guidance_scale: 7.0
  - prompt: "retro serie of different cars with different colors and shapes, mdjrny-v4 style"
    width: 256
    height: 256
    num_inference_steps: 50
    guidance_scale: 7.0

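The new `passes` section is consumed by `create_passes_from_config` (added below in `flagscale/transformations/__init__.py`); a minimal sketch of the mapping, assuming the section is handed over as an OmegaConf mapping:

```
from omegaconf import OmegaConf

from flagscale.transformations import create_passes_from_config

# A pass name with no kwargs maps to None in the YAML above;
# create_passes_from_config normalizes that to an empty kwargs dict.
cfg = OmegaConf.create({"TimestepEmbeddingFlipSineCosinePass": None})
passes = create_passes_from_config(cfg)
print([type(p).__name__ for p in passes])  # ['TimestepEmbeddingFlipSineCosinePass']
```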
129 changes: 129 additions & 0 deletions flagscale/compilation/inductor_pass.py
@@ -0,0 +1,129 @@
# Copied from https://github.com/vllm-project/vllm/blob/6ac5e06f7c5d4658c9fb119826a92d9910730fb4/vllm/compilation/inductor_pass.py.
# Below is the original copyright:

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
import hashlib
import inspect
import json
import types

from collections.abc import Callable
from contextlib import contextmanager
from typing import Any

import torch

from torch import fx
from torch._inductor.custom_graph_pass import CustomGraphPass
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

_pass_context = None


class PassContext:
    def __init__(self, runtime_shape: int | None):
        self.runtime_shape = runtime_shape


def get_pass_context() -> PassContext:
    """Get the current pass context."""
    assert _pass_context is not None
    return _pass_context


@contextmanager
def pass_context(runtime_shape: int | None):
    """A context manager that stores the current pass context,
    usually the runtime shape to specialize for.
    """
    global _pass_context
    prev_context = _pass_context
    _pass_context = PassContext(runtime_shape)
    try:
        yield
    finally:
        _pass_context = prev_context
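A quick usage sketch of the context manager above; the shape value is illustrative:

```
from flagscale.compilation.inductor_pass import get_pass_context, pass_context

with pass_context(runtime_shape=512):
    # Passes running inside the block can query the shape being compiled for.
    assert get_pass_context().runtime_shape == 512
# On exit the previous context (None here) is restored, even on exceptions.
```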


class InductorPass(CustomGraphPass):
    """
    A custom graph pass that uses a hash of its source as the UUID.
    This is defined as a convenience and should work in most cases.
    """

    def uuid(self) -> Any:
        """
        Provide a unique identifier for the pass, used in Inductor code cache.
        This should depend on the pass implementation, so that changes to the
        pass result in recompilation.
        By default, the object source is hashed.
        """
        return InductorPass.hash_source(self)

    @staticmethod
    def hash_source(*srcs: str | Any):
        """
        Utility method to hash the sources of functions or objects.
        :param srcs: strings or objects to add to the hash.
            Objects and functions have their source inspected.
        :return: A sha256 hex digest of the combined sources.
        """
        hasher = hashlib.sha256()
        for src in srcs:
            if isinstance(src, str):
                src_str = src
            elif isinstance(src, (types.FunctionType, type)):
                src_str = inspect.getsource(src)
            else:
                # object instance
                src_str = inspect.getsource(src.__class__)
            hasher.update(src_str.encode("utf-8"))
        return hasher.hexdigest()

    @staticmethod
    def hash_dict(dict_: dict[Any, Any]):
        """
        Utility method to hash a dictionary, can alternatively be used for uuid.
        :return: A sha256 hash of the json rep of the dictionary.
        """
        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).hexdigest()

    def is_applicable(self, shape: int | None):
        return True
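For reference, a minimal hypothetical subclass (`NoopPass` is illustrative, not part of this PR); the inherited uuid() hashes the subclass source, so editing the pass invalidates Inductor's cache:

```
from torch import fx


class NoopPass(InductorPass):
    def __call__(self, graph: fx.Graph) -> None:
        pass  # a real pass would mutate `graph` here


print(NoopPass().uuid())  # sha256 hex digest of NoopPass's source
```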


class CallableInductorPass(InductorPass):
    """
    This class is a wrapper for a callable that automatically provides an
    implementation of the UUID.
    """

    def __init__(self, callable: Callable[[fx.Graph], None], uuid: Any | None = None):
        self.callable = callable
        self._uuid = self.hash_source(callable) if uuid is None else uuid

    def __call__(self, graph: torch.fx.Graph):
        self.callable(graph)

    def uuid(self) -> Any:
        return self._uuid


def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
    """
    Applies a FakeTensorMode context. This is useful when you don't want to
    create or run things with real tensors.
    """

    @functools.wraps(fn)
    def fn_new(*args, **kwargs) -> Any:
        with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
            result = fn(*args, **kwargs)

        return result

    return fn_new
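Two usage sketches for the helpers above, with illustrative function names:

```
import torch

from torch import fx

from flagscale.compilation.inductor_pass import CallableInductorPass, enable_fake_mode


def strip_dead_code(graph: fx.Graph) -> None:
    graph.eliminate_dead_code()


# Wraps a bare function so it carries a stable uuid() for the Inductor cache.
dce_pass = CallableInductorPass(strip_dead_code)


@enable_fake_mode
def shape_probe() -> torch.Size:
    # Runs under FakeTensorMode: no real 4 MiB allocation happens here.
    return torch.empty(1024, 1024).shape


print(shape_probe())  # torch.Size([1024, 1024])
```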
5 changes: 3 additions & 2 deletions flagscale/inference/diffusion_entrypoint.py
@@ -1,8 +1,9 @@
import argparse

from typing import Union
from typing import Any, Union

from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf.base import DictKeyType

from flagscale.inference.inference_engine import InferenceEngine
from flagscale.runner.utils import logger
@@ -61,7 +62,7 @@ def _normalize_runs(gen_cfg):

    runs = _normalize_runs(generate_cfg)

    for idx, run_cfg in enumerate(runs):
    for idx, run_cfg in enumerate[dict[DictKeyType, Any] | dict[str, Any]](runs):
        single_cfg = dict(run_cfg)
        name_prefix = single_cfg.pop("name", None) or f"sample_{idx}"
        outputs = engine.generate(**single_cfg)
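For context, each normalized run reaching the loop above is a plain dict; a hypothetical entry matching the t2i.yaml list in this PR:

```
run_cfg = {
    "name": "cars_512",  # optional; the loop falls back to f"sample_{idx}"
    "prompt": "retro serie of different cars with different colors and shapes, mdjrny-v4 style",
    "width": 512,
    "height": 512,
    "num_inference_steps": 50,
    "guidance_scale": 7.0,
}
# The loop pops "name" and forwards the rest as engine.generate(**single_cfg).
```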
7 changes: 6 additions & 1 deletion flagscale/inference/inference_engine.py
@@ -200,11 +200,16 @@ def apply_transformations(self) -> None:
        `Transformation` will be applied in the EXACT order as specified in the config.
        """

        # TODO(yupu): run preflight/supports check for each transformation
        transforms_cfg = self.vconfig.engine.transformations or {}
        transformations = create_transformations_from_config(transforms_cfg)
        for t in transformations:
            if not t.preflight():
                raise ValueError(
                    f"Transformation {t} is not supported: hardware or Python package requirements not met"
                )

            for name, mod in t.targets(self.backbone):
                logger.debug(f"Applying transformation: {t} on {name}")
                success = t.apply(mod)
                if not success:
                    raise ValueError(f"Failed to apply transformation: {t} on {name}")
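The preflight/targets/apply contract enforced above, sketched as a hypothetical subclass; the signatures are inferred from these call sites and from log_io_transformation.py, and may differ from the real base class:

```
import torch

from torch import nn

from flagscale.transformations.transformation import Transformation


class CudaOnlyTransformation(Transformation):
    def preflight(self) -> bool:
        # Checked before any mutation; False aborts with a ValueError above.
        return torch.cuda.is_available()

    def targets(self, model: nn.Module):
        # Yields (name, module) pairs that apply() should transform.
        yield "backbone", model

    def apply(self, module: nn.Module) -> bool:
        return True  # False signals failure to apply_transformations()
```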
38 changes: 37 additions & 1 deletion flagscale/transformations/__init__.py
@@ -3,20 +3,28 @@
from omegaconf import DictConfig

from .diffusion.taylorseer_transformation import TaylorSeerTransformation
from .diffusion.timestep_embedding_flip_sine_cosine_pass import TimestepEmbeddingFlipSineCosinePass
from .diffusion.timestep_tracker_transformation import TimestepTrackerTransformation
from .log_io_transformation import LogIOTransformation
from .state_scope_transformation import StateScopeTransformation
from .torch_compile_transformation import TorchCompileTransformation
from .transformation import Transformation
from flagscale.compilation.inductor_pass import InductorPass

# Registry of supported Transformation classes by their class names.
_TRANSFORMATION_REGISTRY: Dict[str, Type[Transformation]] = {
    "LogIOTransformation": LogIOTransformation,
    "StateScopeTransformation": StateScopeTransformation,
    "TimestepTrackerTransformation": TimestepTrackerTransformation,
    "TaylorSeerTransformation": TaylorSeerTransformation,
    "TorchCompileTransformation": TorchCompileTransformation,
}

__all__ = ["create_transformations_from_config"]
_PASS_REGISTRY: Dict[str, Type[InductorPass]] = {
    "TimestepEmbeddingFlipSineCosinePass": TimestepEmbeddingFlipSineCosinePass
}

__all__ = ["create_transformations_from_config", "create_passes_from_config"]


def create_transformations_from_config(cfg: DictConfig) -> List[Transformation]:
@@ -48,3 +56,31 @@ def create_transformations_from_config(cfg: DictConfig) -> List[Transformation]:
        instances.append(inst)

    return instances


def create_passes_from_config(cfg: DictConfig) -> List[InductorPass]:
    """Instantiate passes from the configuration

    Args:
        cfg: The configuration

    Returns:
        A list of instantiated passes
    """
    instances: List[InductorPass] = []

    for name, kwargs in cfg.items():
        cls = _PASS_REGISTRY.get(name)
        if cls is None:
            raise KeyError(
                f"Unknown pass class '{name}'. Available: {sorted(_PASS_REGISTRY.keys())}"
            )
        try:
            if kwargs is None:
                kwargs = {}
            inst = cls(**kwargs)
        except TypeError as e:
            raise TypeError(f"Failed to instantiate pass '{name}' with kwargs {kwargs}: {e}") from e
        instances.append(inst)

    return instances
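The sibling factory follows the same shape; a minimal sketch mirroring the transformations block in t2i.yaml:

```
from omegaconf import OmegaConf

from flagscale.transformations import create_transformations_from_config

# Names are registry keys; values are constructor kwargs ({} means none).
cfg = OmegaConf.create({"LogIOTransformation": {"log_level": "info"}})
(log_io,) = create_transformations_from_config(cfg)
print(type(log_io).__name__)  # LogIOTransformation
```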
78 changes: 78 additions & 0 deletions flagscale/transformations/diffusion/timestep_embedding_flip_sine_cosine_pass.py
@@ -0,0 +1,78 @@
import torch

from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    Match,
    PatternMatcherPass,
    register_graph_pattern,
)

from flagscale.compilation.inductor_pass import InductorPass
from flagscale.runner.utils import logger

aten = torch.ops.aten
_timestep_embedding_flip_sine_cosine_matcher = PatternMatcherPass(
    pass_name="timestep_embedding_flip_sine_cosine_matcher"
)

# Shared inputs we want to capture once and reuse
base = Arg()  # %mul_17 feeding both sin and cos
split_idx = Arg()  # %floordiv reused by both slices

inner_cat = CallFunction(
    aten.cat.default,
    [CallFunction(aten.sin.default, base), CallFunction(aten.cos.default, base)],
    -1,
    _users=2,  # The cat feeds both slices
)
slice_hi = CallFunction(aten.slice.Tensor, inner_cat, 1, split_idx, 9223372036854775807)

Reviewer: What does 9223372036854775807 mean?

Author: -1

Reviewer: I still don't quite understand why it has to be written this way. Why not just write -1 directly?

Author: I didn't dig into it. Maybe it has something to do with how aten.slice represents the end of the dim.

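A short note that resolves the question above: the literal is INT64_MAX, ATen's sentinel for an open-ended slice, and -1 would not be equivalent:

```
import sys

# 9223372036854775807 == 2**63 - 1 == sys.maxsize on 64-bit CPython.
assert 9223372036854775807 == 2**63 - 1 == sys.maxsize

# Tracing `emb[:, half_dim:]` turns the open-ended stop into sys.maxsize, so
# the post-grad graph literally contains
#   aten.slice.Tensor(emb, 1, half_dim, 9223372036854775807)
# and ATen clamps the end to the dimension size. An end of -1 follows Python
# slice semantics instead (stop one element before the end), so the pattern
# must match the INT64_MAX literal that the lowering emits.
```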
slice_lo = CallFunction(aten.slice.Tensor, inner_cat, 1, 0, split_idx)
pattern = CallFunction(aten.cat.default, [slice_hi, slice_lo], -1)


@register_graph_pattern(pattern, pass_dict=_timestep_embedding_flip_sine_cosine_matcher)
def _rewrite_timestep_embedding_flip_sine_cosine(match: Match, base, split_idx) -> None:
    logger.debug(f"Applying TimestepEmbeddingFlipSineCosinePattern at {match.nodes}")

    sin_node, cos_node, inner_cat, slice_hi, slice_lo, cat_node = match.nodes

    graph = cat_node.graph

    with graph.inserting_before(cat_node):
        new_cat = graph.call_function(torch.ops.aten.cat.default, args=([cos_node, sin_node], -1))
        new_cat.meta.update(cat_node.meta)

    cat_node.replace_all_uses_with(new_cat)

    for node in (slice_hi, slice_lo):
        if len(node.users) == 0 and node in graph.nodes:
            graph.erase_node(node)

    if len(inner_cat.users) == 0 and inner_cat in graph.nodes:
        graph.erase_node(inner_cat)

    if len(cat_node.users) == 0 and cat_node in graph.nodes:
        graph.erase_node(cat_node)

    graph.eliminate_dead_code()


class TimestepEmbeddingFlipSineCosinePass(InductorPass):
    """
    A pass to rewrite the following code snippet from [diffusers](https://github.com/huggingface/diffusers/blob/v0.35.2/src/diffusers/models/embeddings.py#L69):
    ```
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    ```
    to the following code snippet:
    ```
    emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
    ```

    This pass serves only as a demonstration; no significant performance improvement is expected.
    Alternatively, the pattern could be registered with the post-grad matcher pass directly, in which case this class would be unnecessary.
    """

    def __call__(self, graph: torch.fx.Graph) -> None:
        _timestep_embedding_flip_sine_cosine_matcher.apply(graph)
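One plausible way to wire the pass into compilation, sketched against Inductor's post_grad_custom_post_pass hook; whether TorchCompileTransformation installs it this way is not shown in this diff:

```
import torch

from flagscale.transformations.diffusion.timestep_embedding_flip_sine_cosine_pass import (
    TimestepEmbeddingFlipSineCosinePass,
)

# CustomGraphPass instances participate in the Inductor code cache via uuid().
torch._inductor.config.post_grad_custom_post_pass = TimestepEmbeddingFlipSineCosinePass()

compiled = torch.compile(torch.nn.Linear(8, 8))  # pass runs post-grad
```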
2 changes: 2 additions & 0 deletions flagscale/transformations/log_io_transformation.py
@@ -66,6 +66,8 @@ def __init__(self, log_level: str = "info") -> None:
        self._log_level = log_level

    def apply(self, model: nn.Module) -> bool:
        logger.debug(f"Applying LogIOTransformation to {model.__class__.__name__}")

        reg = ModuleHookRegistry.get_or_create_registry(model)
        hook = LogIOHook(log_level=self._log_level)
        reg.register_hook(hook, "log_io")
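A small sketch of applying the transformation directly; that apply() returns True on success is an assumption consistent with the engine check above:

```
import torch.nn as nn

from flagscale.transformations.log_io_transformation import LogIOTransformation

t = LogIOTransformation(log_level="debug")
ok = t.apply(nn.Linear(4, 4))  # registers a LogIOHook on the module
assert ok
```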