diff --git a/setup.sh b/setup.sh index fc3f640e..9beb11d2 100644 --- a/setup.sh +++ b/setup.sh @@ -110,4 +110,4 @@ else fi # Install maxdiffusion -pip3 install -U . || echo "Failed to install maxdiffusion" >&2 +pip3 install -e . || echo "Failed to install maxdiffusion" >&2 diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 7415ed68..1cfd6293 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -373,6 +373,7 @@ _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"] + _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["schedulers"].extend( @@ -453,6 +454,7 @@ from .models.modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel + from .models.ltx_video.transformers.transformer3d import Transformer3DModel from .models.vae_flax import FlaxAutoencoderKL from .pipelines import FlaxDiffusionPipeline from .schedulers import ( diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index dd78eaa6..7366cf86 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -213,8 +213,13 @@ def load_state_if_possible( max_logging.log(f"restoring from this run's directory latest step {latest_step}") try: if not enable_single_replica_ckpt_restoring: - item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} - return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) + # item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} + # return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) #currently changed to this + if checkpoint_item == " ": + return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state)) + else: + item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} + return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item)) #currently changed to this def map_to_pspec(data): pspec = data.sharding.spec diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml new file mode 100644 index 00000000..b6052988 --- /dev/null +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -0,0 +1,51 @@ +#hardware +hardware: 'tpu' +skip_jax_distributed_system: False + +jax_cache_dir: '' +weights_dtype: 'bfloat16' +activations_dtype: 'bfloat16' + + +run_name: '' +output_dir: 'ltx-video-output' +save_config_to_gcs: False + +#parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', ['data','fsdp']], + ['activation_heads', 'tensor'], + ['activation_kv', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] +dcn_data_parallelism: 1 # 
recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: -1 +ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + + + + +learning_rate_schedule_steps: -1 +max_train_steps: 500 #TODO: change this +pretrained_model_name_or_path: '' +unet_checkpoint: '' +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +per_device_batch_size: 1 +compile_topology_num_slices: -1 +quantization_local_shard_count: -1 +jit_initializers: True +enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py new file mode 100644 index 00000000..8fef5280 --- /dev/null +++ b/src/maxdiffusion/generate_ltx_video.py @@ -0,0 +1,92 @@ +from json import encoder +from absl import app +from typing import Sequence +import jax +from flax import linen as nn +import json +from flax.linen import partitioning as nn_partitioning +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel +import os +import functools +import jax.numpy as jnp +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import ( + create_device_mesh, + setup_initial_state, + get_memory_allocations, +) +from jax.sharding import Mesh, PartitionSpec as P +import orbax.checkpoint as ocp + + +def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids): + print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) + print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) + print("latents.shape: ", latents.shape, latents.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype) + print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype) + print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype) + +def run(config): + key = jax.random.PRNGKey(0) + + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + base_dir = os.path.dirname(__file__) + + ##load in model config + config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json") + with open(config_path, "r") as f: + model_config = json.load(f) + relative_ckpt_path = model_config["ckpt_path"] + + ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "in_channels", "ckpt_path"] + in_channels = model_config["in_channels"] + for name in ignored_keys: + if name in model_config: + del model_config[name] + + + transformer = Transformer3DModel(**model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh) + transformer_param_shapes = transformer.init_weights(in_channels, model_config['caption_channels'], eval_only = True) #use this to test! 
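+  # NOTE (editorial, assumption): the init_weights call above looks like a shape sanity
+  # check; with eval_only=True it appears to return abstract parameter shapes (hence the
+  # name transformer_param_shapes). The restore target actually used for checkpoint
+  # loading is built from weights_init_fn passed to setup_initial_state below.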
+ + weights_init_fn = functools.partial( + transformer.init_weights, + in_channels, + model_config['caption_channels'], + eval_only = True + ) + + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + + checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path) + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + checkpoint_manager=checkpoint_manager, + checkpoint_item=" ", + model_params=None, + training=False, + ) + + + + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) + + + + + diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f9..c04aab34 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -402,7 +402,11 @@ def setup_initial_state( config.enable_single_replica_ckpt_restoring, ) if state: - state = state[checkpoint_item] + ###!Edited + if checkpoint_item == " ": + state = state + else: + state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") init_train_state_partial = functools.partial( diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 95861e24..f03bc306 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING -from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available +from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available _import_structure = {} @@ -32,7 +32,7 @@ from .vae_flax import FlaxAutoencoderKL from .lora import * from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel - + from .ltx_video.transformers.transformer3d import Transformer3DModel else: import sys diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 2f894605..aeea037f 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -1188,4 +1188,4 @@ def setup(self): def __call__(self, hidden_states, deterministic=True): hidden_states = self.proj(hidden_states) hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) - return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic) + return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic) \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py new file mode 100644 index 00000000..f32cc945 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/gradient_checkpoint.py @@ -0,0 +1,70 @@ +from enum import Enum, auto +from typing import Optional + +import jax +from flax import linen as nn + +SKIP_GRADIENT_CHECKPOINT_KEY = "skip" + + +class GradientCheckpointType(Enum): + """ + Defines the type of the gradient checkpoint we will have + + NONE - means no gradient checkpoint + FULL - means full gradient checkpoint, wherever possible (minimum memory usage) + MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, + except for ones that involve batch dimension - that 
means that all attention and projection + layers will have gradient checkpoint, but not the backward with respect to the parameters + """ + + NONE = auto() + FULL = auto() + MATMUL_WITHOUT_BATCH = auto() + + @classmethod + def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType": + """ + Constructs the gradient checkpoint type from a string + + Args: + s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None. + + Returns: + GradientCheckpointType: The policy that corresponds to the string + """ + if s is None: + s = "none" + return GradientCheckpointType[s.upper()] + + def to_jax_policy(self): + """ + Converts the gradient checkpoint type to a jax policy + """ + match self: + case GradientCheckpointType.NONE: + return SKIP_GRADIENT_CHECKPOINT_KEY + case GradientCheckpointType.FULL: + return None + case GradientCheckpointType.MATMUL_WITHOUT_BATCH: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + def apply(self, module: nn.Module) -> nn.Module: + """ + Applies a gradient checkpoint policy to a module + if no policy is needed, it will return the module as is + + Args: + module (nn.Module): the module to apply the policy to + + Returns: + nn.Module: the module with the policy applied + """ + policy = self.to_jax_policy() + if policy == SKIP_GRADIENT_CHECKPOINT_KEY: + return module + return nn.remat( # pylint: disable=invalid-name + module, + prevent_cse=False, + policy=policy, + ) diff --git a/src/maxdiffusion/models/ltx_video/linear.py b/src/maxdiffusion/models/ltx_video/linear.py new file mode 100644 index 00000000..4fdc1444 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/linear.py @@ -0,0 +1,109 @@ +from typing import Union, Iterable, Tuple, Optional, Callable + +import numpy as np +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + + +Shape = Tuple[int, ...] +Initializer = Callable[[jax.random.PRNGKey, Shape, jax.numpy.dtype], jax.Array] +InitializerAxis = Union[int, Shape] + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple(ax if ax >= 0 else ndim + ax for ax in axes) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] +KernelInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array] + + +class DenseGeneral(nn.Module): + """A linear transformation with flexible axes. + + Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86 + + Attributes: + features: tuple with numbers of output features. + axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation. + bias_norm: whether to add normalization before adding bias. + quant: quantization config, defaults to None implying no quantization. 
+ """ + + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + weight_dtype: jnp.dtype = jnp.float32 + dtype: np.dtype = jnp.float32 + kernel_init: KernelInitializer = lecun_normal() + kernel_axes: Tuple[Optional[str], ...] = () + use_bias: bool = False + matmul_precision: str = "default" + + bias_init: Initializer = jax.nn.initializers.constant(0.0) + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + + def compute_dot_general(inputs, kernel, axis, contract_ind): + """Computes a dot_general operation that may be quantized.""" + dot_general = jax.lax.dot_general + matmul_precision = jax.lax.Precision(self.matmul_precision) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision) + + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + kernel_in_axis = np.arange(len(axis)) + kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) + kernel = self.param( + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + ) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = ( + self.kernel_axes[-len(features) :], + kernel_shape[-len(features) :], + ) + bias = self.param( + "bias", + nn.with_logical_partitioning(self.bias_init, bias_axes), + bias_shape, + self.weight_dtype, + ) + bias = jnp.asarray(bias, self.dtype) + + output += bias + return output diff --git a/src/maxdiffusion/models/ltx_video/main.py b/src/maxdiffusion/models/ltx_video/main.py new file mode 100644 index 00000000..9ed184cd --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/main.py @@ -0,0 +1,22 @@ + +import argparse +import json +from typing import Any, Dict, Optional +import os +import jax +import jax.numpy as jnp +import jax.lib.xla_extension +import flax +from flax.training import train_state +import torch +import optax +import orbax.checkpoint as ocp +from safetensors.torch import load_file + +from maxdiffusion.models.ltx_video.transformers_pytorch.transformer_pt import Transformer3DModel_PT + +base_dir = os.path.dirname(__file__) +config_path = os.path.join(base_dir, "xora_v1.2-13B-balanced-128.json") +with open(config_path, "r") as f: + model_config = json.load(f) +transformer = Transformer3DModel_PT.from_config(model_config) \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/repeatable_layer.py b/src/maxdiffusion/models/ltx_video/repeatable_layer.py new file mode 100644 index 00000000..d723ca00 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/repeatable_layer.py @@ -0,0 +1,102 @@ +from dataclasses import field +from typing import Any, Callable, Dict, List, Tuple, Optional + +import jax +from flax import linen as nn +from flax.linen import partitioning + + +class RepeatableCarryBlock(nn.Module): + """ + Integrates an input module in a jax carry format + + ergo, the module assumes the role of a building block + and returns both input and output across all blocks + """ + + module: Callable[[Any], nn.Module] + module_init_args: List[Any] 
+ module_init_kwargs: Dict[str, Any] + + @nn.compact + def __call__(self, *args) -> Tuple[jax.Array, None]: + """ + jax carry-op format of block + assumes the input contains an input tensor to the block along with kwargs that might be send to the block + kwargs are assumed to have static role, while the input changes between cycles + + Returns: + Tuple[jax.Array, None]: Output tensor from the block + """ + mod = self.module(*self.module_init_args, **self.module_init_kwargs) + output = mod(*args) + return output, None + + +class RepeatableLayer(nn.Module): + """ + RepeatableLayer will assume a similar role to torch.nn.ModuleList + with the condition that each block has the same graph, and only the parameters differ + + The compilation in RepeatableLayer will happen only once, in contrast to repeat-graph compilation + """ + + module: Callable[[Any], nn.Module] + """ + A Callable function for single block construction + """ + + num_layers: int + """ + The amount of blocks to build + """ + + module_init_args: List[Any] = field(default_factory=list) + """ + args passed to RepeatableLayer.module callable, to support block construction + """ + + module_init_kwargs: Dict[str, Any] = field(default_factory=dict) + """ + kwargs passed to RepeatableLayer.module callable, to support block construction + """ + + pspec_name: Optional[str] = None + """ + Partition spec metadata + """ + + param_scan_axis: int = 0 + """ + The axis that the "layers" will be aggragated on + eg: if a kernel is shaped (8, 16) + N layers will be (N, 8, 16) if param_scan_axis=0 + and (8, N, 16) if param_scan_axis=1 + """ + + @nn.compact + def __call__(self, *args): + + scan_kwargs = {} + if self.pspec_name is not None: + scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name} + + initializing = self.is_mutable_collection("params") + params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis) + scan_fn = nn.scan( + RepeatableCarryBlock, + variable_axes={ + "params": params_spec, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, # Separate params per timestep + split_rngs={"params": True}, + in_axes=(nn.broadcast,) * (len(args) - 1), + length=self.num_layers, + **scan_kwargs, + ) + wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs) + x, _ = wrapped_function(*args) + return x diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/maxdiffusion/models/ltx_video/transformers/activations.py b/src/maxdiffusion/models/ltx_video/transformers/activations.py new file mode 100644 index 00000000..2a4f7180 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/activations.py @@ -0,0 +1,173 @@ +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import lecun_normal + +from diffusers.utils.deprecation_utils import deprecate + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, KernelInitializer + + +ACTIVATION_FUNCTIONS = { + "swish": jax.nn.silu, + "silu": jax.nn.silu, + "mish": lambda x: x * jax.nn.tanh(jax.nn.softplus(x)), # Mish is not in JAX by default + "gelu": jax.nn.gelu, + "relu": jax.nn.relu, +} + + +@jax.jit +def approximate_gelu(x: jax.Array) -> jax.Array: + """ + Computes Gaussian Error Linear Unit (GELU) activation function + + Args: + x (jax.Array): The input tensor 
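+    Returns: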
+ + jax.Array: The output tensor + """ + # The error function (erf) in GELU asymptotically approaches -1 for very large negative inputs + # sometimes it results in jnp.nan in jax on TPU's, this prevents this behavior + if x.dtype in (jax.numpy.float64,): + x = x.clip(-10, None) + return jax.nn.gelu(x, approximate=True) + +def get_activation(act_fn: str): + """Returns the activation function from string.""" + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + approximate: str = "none" + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def gelu(self, gate: jax.Array) -> jax.Array: + approximate_to_tanh = self.approximate == "tanh" + if approximate_to_tanh: + return approximate_gelu(gate) + else: + return jax.nn.gelu(gate, approximate=False) + + @nn.compact + def __call__(self, hidden_states): + if self.approximate not in ("none", "tanh"): + raise ValueError(f"approximate must be 'none' or 'tanh', got {self.approximate}") + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + hidden_states = proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + + proj = DenseGeneral( + features=self.dim_out * 2, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + + hidden_states = proj(hidden_states) + hidden_states, gate = jnp.split(hidden_states, 2, axis=-1) + return hidden_states * jax.nn.gelu(gate, approximate=False) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + dim_in: int + dim_out: int + bias: bool = True + + kernel_axes: Tuple[Optional[str], ...] = () + kernel_init: KernelInitializer = lecun_normal() + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, x): + proj = DenseGeneral( + features=self.dim_out, + use_bias=self.bias, + kernel_axes=self.kernel_axes, + kernel_init=self.kernel_init, + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj", + ) + x = proj(x) + return x * jax.nn.sigmoid(1.702 * x) diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py new file mode 100644 index 00000000..1ca38f70 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -0,0 +1,194 @@ +from typing import Dict, Optional, Tuple + +import jax +import jax.nn +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.transformers.activations import get_activation +from maxdiffusion.models.ltx_video.linear import DenseGeneral + + +def get_timestep_embedding_multidim( + timesteps: jnp.ndarray, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> jnp.ndarray: + """ + Computes sinusoidal timestep embeddings while preserving the original dimensions. + No reshaping to 1D is performed at any stage. + + Args: + timesteps (jnp.ndarray): A Tensor of arbitrary shape containing timestep values. + embedding_dim (int): The dimension of the output. + flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) + or `sin, cos` (if False). + downscale_freq_shift (float): Controls the delta between frequencies between dimensions. + scale (float): Scaling factor applied to the embeddings. + max_period (int): Controls the maximum frequency of the embeddings. + + Returns: + jnp.ndarray: A Tensor of shape (*timesteps.shape, embedding_dim) with positional embeddings. 
+ """ + half_dim = embedding_dim // 2 + exponent = -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) + exponent = exponent / (half_dim - downscale_freq_shift) + shape = (1,) * timesteps.ndim + (half_dim,) # (1, 1, ..., 1, half_dim) + emb = jnp.exp(exponent).reshape(*shape) # Expand to match timesteps' shape + emb = nn.with_logical_constraint(emb, ("activation_batch", "activation_norm_length", "activation_embed")) + emb = timesteps[..., None] * emb # Broadcasting to match shape (*timesteps.shape, half_dim) + emb = scale * emb + emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1) # Shape (*timesteps.shape, embedding_dim) + if flip_sin_to_cos: + emb = jnp.concatenate([emb[..., half_dim:], emb[..., :half_dim]], axis=-1) + + return emb + + +class TimestepEmbedding(nn.Module): + in_channels: int + time_embed_dim: int + act_fn: str = "silu" + out_dim: Optional[int] = None + sample_proj_bias: bool = True + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + + def setup(self): + """Initialize layers efficiently""" + self.linear_1 = DenseGeneral( + self.time_embed_dim, + use_bias=self.sample_proj_bias, + kernel_axes=(None, "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + ) + + self.act = get_activation(self.act_fn) + time_embed_dim_out = self.out_dim if self.out_dim is not None else self.time_embed_dim + self.linear_2 = DenseGeneral( + time_embed_dim_out, + use_bias=self.sample_proj_bias, + kernel_axes=("embed", "mlp"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + ) + + def __call__(self, sample, condition=None): + sample = nn.with_logical_constraint(sample, ("activation_batch", "activation_norm_length", "activation_embed")) + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + num_channels: int + flip_sin_to_cos: bool + downscale_freq_shift: float + scale: int = 1 + + def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: + t_emb = get_timestep_embedding_multidim( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class AlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ + + """ + + embedding_dim: int + size_emb_dim: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize sub-modules.""" + self.outdim = self.size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.embedding_dim, + name="timestep_embedder", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def __call__(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype)) + return timesteps_emb + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in: https://arxiv.org/abs/2310.00426; Section 2.3. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
+ """ + + embedding_dim: int + embedding_coefficient: int = 6 + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + self.emb = AlphaCombinedTimestepSizeEmbeddings( + self.embedding_dim, + size_emb_dim=self.embedding_dim // 3, + name="emb", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + self.silu = jax.nn.silu + self.linear = DenseGeneral( + self.embedding_coefficient * self.embedding_dim, + use_bias=True, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear", + ) + + def __call__( + self, + timestep: jnp.ndarray, + added_cond_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[jnp.dtype] = None, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Compute AdaLayerNorm-Single modulation. + + Returns: + Tuple: + - Processed embedding after SiLU + linear transformation. + - Original embedded timestep. + """ + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/maxdiffusion/models/ltx_video/transformers/attention.py b/src/maxdiffusion/models/ltx_video/transformers/attention.py new file mode 100644 index 00000000..54398139 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/attention.py @@ -0,0 +1,913 @@ +from functools import partial +import functools +import math +from typing import Any, Dict, Optional, Tuple +from enum import Enum, auto +import jax +import jax.nn as jnn +import jax.numpy as jnp +from jax.ad_checkpoint import checkpoint_name +from jax.experimental.shard_map import shard_map +from jax.experimental.pallas.ops.tpu.flash_attention import ( + flash_attention as jax_flash_attention, + SegmentIds, + BlockSizes, +) + +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral, Initializer +from maxdiffusion.models.ltx_video.transformers.activations import ( + GELU, + GEGLU, + ApproximateGELU, +) + + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() + + +class Identity(nn.Module): + def __call__(self, x): + return x + + +class BasicTransformerBlock(nn.Module): + dim: int + num_attention_heads: int + attention_head_dim: int + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + attention_bias: bool = False + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + norm_elementwise_affine: bool = True + adaptive_norm: str = "single_scale_shift" + standardization_norm: str = "layer_norm" + norm_eps: float = 1e-5 + qk_norm: str = None + final_dropout: bool = False + attention_type: str = ("default",) # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None + ff_bias: bool = True + attention_out_bias: bool = True + use_tpu_flash_attention: bool = True + use_rope: bool = False + ffn_dim_mult: Optional[int] = 4 + attention_op: Optional[nn.Module] = None + sharding_mesh: Optional[jax.sharding.Mesh] = None + + dtype: jax.numpy.dtype = jnp.float32 + weight_dtype: jax.numpy.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + assert self.standardization_norm in 
["layer_norm", "rms_norm"] + assert self.adaptive_norm in ["single_scale_shift", "single_scale", "none"] + assert self.use_tpu_flash_attention, "Jax version only use tpu_flash attention." + + if self.standardization_norm == "layer_norm": + make_norm_layer = partial( + nn.LayerNorm, + epsilon=self.norm_eps, + param_dtype=self.weight_dtype, + dtype=self.dtype, + ) + else: + make_norm_layer = partial( + RMSNorm, + epsilon=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("norm",), + ) + + # 1. Self-Attn + self.norm1 = make_norm_layer(name="norm1") + self.attn1 = Attention( + query_dim=self.dim, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + cross_attention_dim=self.cross_attention_dim if self.only_cross_attention else None, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn1", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 2. Cross-Attn + if self.cross_attention_dim is not None or self.double_self_attention: + self.attn2 = Attention( + query_dim=self.dim, + cross_attention_dim=self.cross_attention_dim if not self.double_self_attention else None, + heads=self.num_attention_heads, + dim_head=self.attention_head_dim, + dropout=self.dropout, + bias=self.attention_bias, + upcast_attention=self.upcast_attention, + out_bias=self.attention_out_bias, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + attention_op=self.attention_op, + name="attn2", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + ) + if self.adaptive_norm == "none": + self.attn2_norm = make_norm_layer() + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(name="norm2") + # 3. Feed-forward + self.ff = FeedForward( + self.dim, + dropout=self.dropout, + activation_fn=self.activation_fn, + final_dropout=self.final_dropout, + inner_dim=self.ff_inner_dim, + bias=self.ff_bias, + mult=self.ffn_dim_mult, + name="ff", + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + # 4. Scale-Shift + if self.adaptive_norm != "none": + num_ada_params = 4 if self.adaptive_norm == "single_scale" else 6 + + def ada_initalizer(key): + return jax.random.normal(key, (num_ada_params, self.dim), dtype=self.weight_dtype) / self.dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(ada_initalizer, ("ada", "embed")), + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + segment_ids: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_segment_ids: Optional[jnp.ndarray] = None, + timestep: Optional[jnp.ndarray] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[jnp.ndarray] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> jnp.ndarray: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + print("Passing `scale` to `cross_attention_kwargs` is depcrecated. 
`scale` will be ignored.") + + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + hidden_states = checkpoint_name(hidden_states, "basic_transformer_block hidden_states") + + batch_size = hidden_states.shape[0] + + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + # Adaptive Norm + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].astype(self.weight_dtype) + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + # Moving ada values to computation dtype to prevent dtype promotion + ada_values = ada_values.astype(self.dtype) + ada_values = nn.with_logical_constraint( + ada_values, ("activation_batch", "activation_norm_length", "activation_ada", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 6, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ( + jnp.squeeze(arr, axis=2) for arr in jnp.split(ada_values, 4, axis=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if norm_hidden_states.shape[1] == 1: + norm_hidden_states = jnp.squeeze(norm_hidden_states, axis=1) + + # 1. Self-Attention + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids if self.only_cross_attention else segment_ids, + sharding_mesh=self.sharding_mesh, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **(cross_attention_kwargs or {}), + ) + + attn_output = nn.with_logical_constraint( + attn_output, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + + # 3. Cross-Attention + if self.attn2 is not None: + attn_input = self.attn2_norm(hidden_states) if self.adaptive_norm == "none" else hidden_states + attn_input = nn.with_logical_constraint( + attn_input, ("activation_batch", "activation_norm_length", "activation_embed") + ) + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + segment_ids=segment_ids, + kv_attention_segment_ids=encoder_attention_segment_ids, + sharding_mesh=self.sharding_mesh, + **(cross_attention_kwargs or {}), + ) + hidden_states = attn_output + hidden_states + + # 4. 
Feed-Forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = nn.with_logical_constraint( + norm_hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + ff_output = self.ff(norm_hidden_states) + ff_output = nn.with_logical_constraint( + ff_output, ("activation_batch", "activation_norm_length", "activation_embed") + ) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = jnp.squeeze(hidden_states, axis=1) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return hidden_states + + +class Attention(nn.Module): + query_dim: int + cross_attention_dim: Optional[int] = None + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + bias: bool = False + upcast_attention: bool = False + upcast_softmax: bool = False + cross_attention_norm: Optional[str] = None + added_kv_proj_dim: Optional[int] = None + out_bias: bool = True + scale_qk: bool = True + qk_norm: Optional[str] = None + only_cross_attention: bool = False + eps: float = 1e-5 + rescale_output_factor: float = 1.0 + residual_connection: bool = False + out_dim: Optional[int] = None + use_tpu_flash_attention: bool = True + use_rope: bool = False + attention_op: Optional[nn.Module] = None + + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + def setup(self): + """Initialize layers in Flax `setup()`.""" + self.inner_dim = self.out_dim if self.out_dim is not None else self.dim_head * self.heads + self.use_bias = self.bias + self.is_cross_attention = self.cross_attention_dim is not None + self.fused_projections = False + out_dim = self.out_dim if self.out_dim is not None else self.query_dim + self.scale = self.dim_head**-0.5 if self.scale_qk else 1.0 + + # Query and Key Normalization + if self.qk_norm is None: + self.q_norm = Identity() + self.k_norm = Identity() + elif self.qk_norm == "rms_norm": + self.q_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + self.k_norm = RMSNorm(epsilon=self.eps, kernel_axes=("norm",)) + elif self.qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(epsilon=self.eps) + self.k_norm = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError(f"Unsupported qk_norm method: {self.qk_norm}") + + if out_dim is not None: + self.heads_count = out_dim // self.dim_head + + # Validate parameters + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. " + "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if self.cross_attention_norm is None: + self.norm_cross = None + elif self.cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(epsilon=self.eps) + else: + raise ValueError( + f"Unknown cross_attention_norm: {self.cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'." 
+ ) + + # Linear layers for queries, keys, values + self.to_q = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_q", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv"), + axis=-1, + ) + + if not self.only_cross_attention: + self.to_k = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_k", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + self.to_v = DenseGeneral( + features=(self.inner_dim,), + use_bias=self.bias, + name="to_v", + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + kernel_axes=("embed", "kv_head_dim"), + axis=-1, + ) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Dense(self.inner_dim, name="add_k_proj") + self.add_v_proj = nn.Dense(self.inner_dim, name="add_v_proj") + + self.to_out = [ + DenseGeneral( + features=(out_dim,), + use_bias=self.out_bias, + axis=-1, + kernel_axes=("kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="to_out.0", + matmul_precision=self.matmul_precision, + ), + nn.Dropout(self.dropout), + ] + + if self.attention_op is not None: + self.attention = self.attention_op + else: + _tpu_available = any(device.platform == "tpu" for device in jax.devices()) + self.attention = AttentionOp() if _tpu_available else ExplicitAttention() + if not _tpu_available: + print("Warning: Running with explicit attention since tpu is not available.") + + def __call__( + self, + hidden_states: jnp.ndarray, + freqs_cis: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + segment_ids: Optional[jnp.ndarray] = None, + kv_attention_segment_ids: Optional[jnp.ndarray] = None, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + skip_layer_mask: Optional[jnp.ndarray] = None, + skip_layer_strategy: Optional[str] = None, + temb: Optional[jnp.ndarray] = None, + deterministic: bool = True, + **cross_attention_kwargs, + ) -> jnp.ndarray: + cross_attention_kwargs = { k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters } + assert cross_attention_kwargs.get("scale", None) is None, "Not supported" + + input_axis_names = ("activation_batch", "activation_length", "activation_embed") + hidden_states = nn.with_logical_constraint(hidden_states, input_axis_names) + if encoder_hidden_states is not None: + encoder_hidden_states = nn.with_logical_constraint(encoder_hidden_states, input_axis_names) + + residual = hidden_states + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = jnp.reshape(hidden_states, (batch_size, channel, height * width)) + hidden_states = jnp.swapaxes(hidden_states, 1, 2) + + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + if skip_layer_mask is not None: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1)) + + query = self.to_q(hidden_states) + query = self.q_norm(query) + + if encoder_hidden_states is not None: + if self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + key = self.to_k(encoder_hidden_states) + key = self.k_norm(key) + else: + encoder_hidden_states = hidden_states + key = self.to_k(hidden_states) + key = self.k_norm(key) + if 
self.use_rope: + key = apply_rotary_emb(key, freqs_cis) + query = apply_rotary_emb(query, freqs_cis) + + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = jnp.reshape(query, (batch_size, -1, self.heads, head_dim)) + query = jnp.swapaxes(query, 1, 2) + query = nn.with_logical_constraint( + query, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + query = checkpoint_name(query, "attention query") + + key = jnp.reshape(key, (batch_size, -1, self.heads, head_dim)) + key = jnp.swapaxes(key, 1, 2) + key = nn.with_logical_constraint( + key, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + key = checkpoint_name(key, "attention key") + + value = jnp.reshape(value, (batch_size, -1, self.heads, head_dim)) + value = jnp.swapaxes(value, 1, 2) + value = nn.with_logical_constraint( + value, ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + ) + value = checkpoint_name(value, "attention value") + + assert self.use_tpu_flash_attention, "JAX only support `use_tpu_flash_attention`" + + q_segment_ids = segment_ids + if q_segment_ids is not None: + q_segment_ids = q_segment_ids.astype(jnp.float32) + + if kv_attention_segment_ids is not None and q_segment_ids is None: + q_segment_ids = jnp.ones((batch_size, query.shape[2]), dtype=jnp.float32) + + hidden_states_a = self.attention( + query, key, value, q_segment_ids, kv_attention_segment_ids, sharding_mesh, self.dtype + ) + + hidden_states_a: jax.Array = nn.with_logical_constraint( + hidden_states_a, ("activation_kv_batch", "activation_heads", "activation_length", "activation_kv") + ) + + hidden_states_a = jnp.reshape(jnp.swapaxes(hidden_states_a, 1, 2), (batch_size, -1, self.heads * head_dim)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.AttentionSkip: + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (1.0 - skip_layer_mask) + else: + hidden_states = hidden_states_a + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states, deterministic=deterministic) # Dropout + + if input_ndim == 4: + hidden_states = jnp.reshape(jnp.swapaxes(hidden_states, -1, -2), (batch_size, channel, height, width)) + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + skip_layer_mask = jnp.reshape(skip_layer_mask, (batch_size, 1, 1, 1)) + + if self.residual_connection: + if skip_layer_mask is not None and skip_layer_strategy == SkipLayerStrategy.Residual: + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + if self.rescale_output_factor != 1.0: + hidden_states = hidden_states / self.rescale_output_factor + hidden_states = checkpoint_name(hidden_states, "attention_output") + + return hidden_states + + def prepare_attention_mask( + self, attention_mask: jnp.ndarray, target_length: int, batch_size: int, out_dim: int = 3 + ) -> jnp.ndarray: + head_size = self.heads_count + if attention_mask is None: + return attention_mask + + current_length = attention_mask.shape[-1] + if current_length != target_length: + remaining_length = target_length - current_length + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, remaining_length)), constant_values=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = jnp.repeat(attention_mask, head_size, axis=0) + 
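+        # For a 4-D mask, insert a heads axis and repeat the mask once per attention head.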
elif out_dim == 4: + attention_mask = jnp.expand_dims(attention_mask, axis=1) + attention_mask = jnp.repeat(attention_mask, head_size, axis=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: jnp.ndarray) -> jnp.ndarray: + assert self.norm_cross is not None, "self.norm_cross must be defined to call norm_encoder_hidden_states." + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = jnp.swapaxes(encoder_hidden_states, 1, 2) + else: + raise ValueError("Unknown normalization type for cross-attention.") + + return encoder_hidden_states + + +class AttentionOp(nn.Module): + @nn.compact + def __call__( + self, + q: jax.Array, # [batch_size, heads, q_tokens, hidden_dim] + k: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + v: jax.Array, # [batch_size, heads, kv_tokens, hidden_dim] + q_segment_ids: jax.Array, # [batch_size, q_tokens] + kv_segment_ids: jax.Array, # [batch_size, kv_tokens] + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + block_sizes: Optional[BlockSizes] = None, + ): + if block_sizes is None: + block_sizes = self.default_block_sizes(q, k, dtype) + + scale_factor = 1 / math.sqrt(q.shape[-1]) + + + def partial_flash_attention(q, k, v, q_segment_ids, kv_segment_ids): + s = ( + # flash attention expects segment ids to be float32 + SegmentIds(q_segment_ids.astype(jnp.float32), kv_segment_ids.astype(jnp.float32)) + if q_segment_ids is not None and kv_segment_ids is not None + else None + ) + output = jax_flash_attention( + q, + k, + v, + None, + s, + sm_scale=scale_factor, + block_sizes=block_sizes, + ) + return output + + if sharding_mesh is not None: + if q.ndim != 4: + raise ValueError(f"Expected input with 4 dims, got {q.ndim}.") + if q_segment_ids is not None and q_segment_ids.ndim != 2: + raise ValueError(f"Expected mask with 2 dims, got {q_segment_ids.ndim}.") + # Based on: ("activation_kv_batch", "activation_kv_heads", "activation_length", "activation_kv_head_dim") + # Computation of the spec based on the logical constraints can be found in logical_axes_to_spec.py. 
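+            # NOTE (editorial, assumption): the active PartitionSpec below is all-None, so
+            # shard_map passes the full (replicated) q/k/v arrays to every device; the
+            # commented-out specs appear to be kept as a reference for a sharded layout.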
+ # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) + # qkvo_sharding_spec = jax.sharding.PartitionSpec( + # ("data", "fsdp", "fsdp_transpose", "expert"), + # ("tensor", "tensor_transpose", "sequence", "tensor_sequence"), + # None, + # None, + # ) + qkvo_sharding_spec = jax.sharding.PartitionSpec( + None, + None, + None, + None, + ) + #Based on: ("activation_kv_batch", "activation_length") + # qkv_segment_ids_spec = jax.sharding.PartitionSpec(("data", "fsdp", "fsdp_transpose", "expert"), "sequence") + qkv_segment_ids_spec = jax.sharding.PartitionSpec(None, None) + wrapped_flash_attention = shard_map( + partial_flash_attention, + mesh=sharding_mesh, + in_specs=( + qkvo_sharding_spec, + qkvo_sharding_spec, + qkvo_sharding_spec, + qkv_segment_ids_spec, + qkv_segment_ids_spec, + ), + out_specs=qkvo_sharding_spec, + check_rep=False, + ) + else: + wrapped_flash_attention = partial_flash_attention + + return wrapped_flash_attention( + q, + k, + v, + q_segment_ids, + kv_segment_ids, + ) + + def default_block_sizes(self, q: jax.Array, k: jax.Array, dtype: jnp.dtype = jnp.float32) -> BlockSizes: + """ + Default block sizes for Flash Attention. + + TPU kernel ops runs in grids, the bigger the grid - the more data that is loaded on the SRAM + we want to utilize the SRAM the best we can + + too big grids will cuase cache misses and slow down the computation while the faster SRAM retrieves the other block data + from the slower HBRAM + + a certain balance has to be met to get the best performance + imho, that balance must be computed with the combination of the information supplied by q and k (which will supply query sequence and key/value sequence lengths) + along with the SRAM cache size + + ** SRAM cache size for TPU + V5P - 1MB SRAM per core + + Args: + q (jax.Array): Query tensor to be used + k (jax.Array): Key tensor to be used + + Returns: + BlockSizes: Grid block sizes + """ + max_block_size = 1024 if dtype == jnp.bfloat16 else 512 + return BlockSizes( + block_q=min(max_block_size, q.shape[-2]), + block_k_major=min(max_block_size, k.shape[-2]), + block_k=min(max_block_size, k.shape[-2]), + block_b=min(1, q.shape[0]), + block_q_major_dkv=min(max_block_size, q.shape[-2]), + block_k_major_dkv=min(max_block_size, k.shape[-2]), + block_q_dkv=min(max_block_size, q.shape[-2]), + block_k_dkv=min(max_block_size, k.shape[-2]), + block_q_dq=min(max_block_size, q.shape[-2]), + block_k_dq=min(512, k.shape[-2]), + block_k_major_dq=min(max_block_size, k.shape[-2]), + ) + + +class ExplicitAttention(nn.Module): + def __call__( + self, + q: jax.Array, + k: jax.Array, + v: jax.Array, + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, + sharding_mesh: Optional[jax.sharding.Mesh] = None, + dtype: jnp.dtype = jnp.float32, + ): + assert sharding_mesh is None, "Explicit attention does not support sharding mesh." 
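+        # Build a boolean mask from matching query/key segment ids, then convert it into an
+        # additive bias (0 where attention is allowed, -inf where masked) applied to the
+        # attention logits before the softmax.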
+ attn_mask = None + if kv_segment_ids is not None: + q_segment_ids_expanded = q_segment_ids[:, None, :, None] + kv_segment_ids_expanded = kv_segment_ids[:, None, None, :] + attn_mask = q_segment_ids_expanded == kv_segment_ids_expanded + + scale_factor = 1 / jnp.sqrt(q.shape[-1]) + attn_bias = jnp.zeros((q.shape[-2], k.shape[-2]), dtype=q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == jnp.bool_: + attn_bias = jnp.where(attn_mask, attn_bias, float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = q @ k.swapaxes(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = jnn.softmax(attn_weight, axis=-1) + + return attn_weight @ v + + +class RMSNorm(nn.Module): + """ + RMSNorm is a normalization layer that normalizes the input using the root mean square. + """ + + epsilon: float + dtype: jnp.dtype = jnp.float32 + elementwise_affine: bool = True + weight_dtype: jnp.dtype = jnp.float32 + kernel_axes: Tuple[Optional[str], ...] = () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, hidden_states: jax.Array) -> jax.Array: + """ + Forward pass of the RMSNorm layer. + + First we compute the variance (mean of the square of the input) + and then normalize the input using the root mean square. + + NOTE: if weight is in mixed precision, the operand should be in the same precision. + Args: + hidden_states (jax.Array): Input data + + Returns: + jax.Array: Normed data + """ + + # dim = (self.dim,) if isinstance(self.dim, numbers.Integral) else self.dim + dim = hidden_states.shape[-1] + if self.elementwise_affine: + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (dim,), + self.weight_dtype, + ) + else: + scale = None + + input_dtype = hidden_states.dtype + variance = jnp.mean(jnp.square(hidden_states.astype(jnp.float32)), axis=-1, keepdims=True) + hidden_states: jax.Array = hidden_states * jax.lax.rsqrt(variance + self.epsilon) + + if self.elementwise_affine: + # convert into half-precision if necessary + hidden_states = (hidden_states.astype(self.dtype) * scale.astype(self.dtype)).astype(input_dtype) + else: + hidden_states = hidden_states.astype(input_dtype) + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
+ """ + + dim_out: Optional[int] = None + mult: int = 4 + dropout: float = 0.0 + activation_fn: str = "gelu" + final_dropout: bool = False + bias: bool = True + inner_dim: Optional[int] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, hidden_states: jax.Array, scale: float = 1.0, deterministic: bool = False) -> jax.Array: + dim = hidden_states.shape[-1] + if self.inner_dim is None: + inner_dim = dim * self.mult + if inner_dim < 256: + raise ValueError("inner_dim must be at least 256") + inner_dim = round(inner_dim / 256) * 256 # round to nearest multiple of 256 + else: + inner_dim = self.inner_dim + + dim_out = self.dim_out if self.dim_out is not None else dim + + act_kwargs = { + "name": "net.0", + "bias": self.bias, + "kernel_axes": ("embed", "mlp"), + "matmul_precision": self.matmul_precision, + "weight_dtype": self.weight_dtype, + "dtype": self.dtype, + } + match self.activation_fn: + case "gelu": + act_fn = GELU(dim, inner_dim, **act_kwargs) + case "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", **act_kwargs) + case "geglu": + act_fn = GEGLU(dim, inner_dim, **act_kwargs) + case "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, **act_kwargs) + case _: + raise ValueError(f"activation function {self.activation_fn} not supported") + + if isinstance(act_fn, GEGLU): + hidden_states = act_fn(hidden_states, scale) + else: + hidden_states = act_fn(hidden_states) + + hidden_states = checkpoint_name(hidden_states, "FFN - activation") + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + hidden_states = DenseGeneral( + dim_out, + use_bias=self.bias, + kernel_axes=("mlp", "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="net.2", + )(hidden_states) + hidden_states = checkpoint_name(hidden_states, "FFN - Reprojection") + if self.final_dropout: + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + hidden_states = nn.Dropout(self.dropout)(hidden_states, deterministic=deterministic) + + return hidden_states + + +def apply_rotary_emb(input_tensor: jax.Array, freqs_cis: Tuple[jax.Array, jax.Array]) -> jax.Array: + """ + Integrates positional information into input tensors using RoPE. 
+ + Args: + input_tensor (jax.Array): Input_tensor (from QKV of attention mechanism) + freqs_cis (Tuple[jax.Array, jax.Array]): The sine and cosine frequencies + + Returns: + jax.Array: Tensor where positional information has been integrated into the original input tensor + """ + if len(freqs_cis) != 2: + raise ValueError("freqs_cis must be a tuple of 2 elements") + + cos_freqs, sin_freqs = freqs_cis + + t_dup = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2) + t1, t2 = jnp.split(t_dup, 2, axis=-1) + t_dup = jnp.concatenate([-t2, t1], axis=-1) + input_tensor_rot = t_dup.reshape(*input_tensor.shape) + + # Apply rotary embeddings + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out diff --git a/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py new file mode 100644 index 00000000..dff8b8c6 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/caption_projection.py @@ -0,0 +1,40 @@ +from flax import linen as nn +import jax.numpy as jnp + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.activations import approximate_gelu + + +class CaptionProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. + """ + + in_features: int + hidden_size: int + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + + @nn.compact + def __call__(self, caption): + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_1", + )(caption) + hidden_states = approximate_gelu(hidden_states) + hidden_states = DenseGeneral( + self.hidden_size, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="linear_2", + )(hidden_states) + return hidden_states diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py new file mode 100644 index 00000000..5c087f42 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -0,0 +1,294 @@ +from typing import List, Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import linen as nn + +from maxdiffusion.models.ltx_video.linear import DenseGeneral +from maxdiffusion.models.ltx_video.transformers.adaln import AdaLayerNormSingle +from maxdiffusion.models.ltx_video.transformers.attention import BasicTransformerBlock +from maxdiffusion.models.ltx_video.transformers.caption_projection import CaptionProjection +from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType +from maxdiffusion.models.ltx_video.repeatable_layer import RepeatableLayer + + +class Transformer3DModel(nn.Module): + num_attention_heads: int = 16 + attention_head_dim: int = 88 + out_channels: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + cross_attention_dim: Optional[int] = None + attention_bias: bool = False + activation_fn: str = "geglu" + num_embeds_ada_norm: Optional[int] = None + only_cross_attention: bool = False + double_self_attention: bool = False + upcast_attention: bool = False + adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm" # 
'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True + norm_eps: float = 1e-5 + attention_type: str = "default" + caption_channels: int = None + use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None + positional_embedding_type: str = "rope" + positional_embedding_theta: Optional[float] = None + positional_embedding_max_pos: Optional[List[int]] = None + timestep_scale_multiplier: Optional[float] = None + ffn_dim_mult: Optional[int] = 4 + output_scale: Optional[float] = None + attention_op: Optional[nn.Module] = None + dtype: jnp.dtype = jnp.float32 + weight_dtype: jnp.dtype = jnp.float32 + matmul_precision: str = "default" + sharding_mesh: Optional[jax.sharding.Mesh] = None + param_scan_axis: int = 0 + gradient_checkpointing: Optional[str] = None + + + def setup(self): + assert self.out_channels is not None, "out channels must be specified in model config." + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.patchify_proj = DenseGeneral( + self.inner_dim, + use_bias=True, + kernel_axes=(None, "embed"), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="patchify_proj", + ) + self.freq_cis_pre_computer = FreqsCisPrecomputer( + self.positional_embedding_max_pos, self.positional_embedding_theta, self.inner_dim + ) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=4 if self.adaptive_norm == "single_scale" else 6, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + + def scale_shift_table_init(key): + return jax.random.normal(key, (2, self.inner_dim)) / self.inner_dim**0.5 + + self.scale_shift_table = self.param( + "scale_shift_table", # Trainable parameter name + nn.with_logical_partitioning(scale_shift_table_init, ("ada", "embed")), + ) + self.norm_out = nn.LayerNorm(epsilon=1e-6, use_scale=False, use_bias=False) + self.proj_out = DenseGeneral( + self.out_channels, + use_bias=True, + kernel_axes=("embed", None), + matmul_precision=self.matmul_precision, + weight_dtype=self.weight_dtype, + dtype=self.dtype, + name="proj_out", + ) + self.use_rope = self.positional_embedding_type == "rope" + if self.num_layers > 0: + RemattedBasicTransformerBlock = GradientCheckpointType.from_str(self.gradient_checkpointing).apply( + BasicTransformerBlock + ) + + self.transformer_blocks = RepeatableLayer( + RemattedBasicTransformerBlock, + num_layers=self.num_layers, + module_init_kwargs=dict( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + adaptive_norm=self.adaptive_norm, + standardization_norm=self.standardization_norm, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + use_tpu_flash_attention=self.use_tpu_flash_attention, + qk_norm=self.qk_norm, + use_rope=self.use_rope, + ffn_dim_mult=self.ffn_dim_mult, + attention_op=self.attention_op, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + sharding_mesh=self.sharding_mesh, + 
name="CheckpointBasicTransformerBlock_0", + ), + pspec_name="layers", + param_scan_axis=self.param_scan_axis, + ) + + if self.caption_channels is not None: + self.caption_projection = CaptionProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + matmul_precision=self.matmul_precision, + ) + def init_weights(self, in_channels, caption_channels, eval_only=True): + example_inputs = {} + batch_size, num_tokens = 4, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, caption_channels), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + + if eval_only: + return jax.eval_shape( + self.init, + jax.random.PRNGKey(42), ##need to change? + **example_inputs, + )["params"] + else: + return self.init(jax.random.PRNGKey(42), **example_inputs)['params'] + + def __call__( + self, + hidden_states, + indices_grid, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + segment_ids=None, + encoder_attention_segment_ids=None, + return_dict=True, + ): + hidden_states = self.patchify_proj(hidden_states) + freqs_cis = self.freq_cis_pre_computer(indices_grid) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + batch_size = hidden_states.shape[0] + + timestep, embedded_timestep = self.adaln_single( + timestep, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + + if self.num_layers > 0: + hidden_states = self.transformer_blocks( + hidden_states, + freqs_cis, + segment_ids, + encoder_hidden_states, + encoder_attention_segment_ids, + timestep, + cross_attention_kwargs, + class_labels, + ) + # Output processing + + scale_shift_values = ( + self.scale_shift_table[jnp.newaxis, jnp.newaxis, :, :] + embedded_timestep[:, :, jnp.newaxis] + ) + scale_shift_values = nn.with_logical_constraint( + scale_shift_values, ("activation_batch", "activation_length", "activation_ada", "activation_embed") + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if self.output_scale: + hidden_states = hidden_states / self.output_scale + + return hidden_states + + +def log_base(x: jax.Array, base: jax.Array) -> jax.Array: + """ + Computes log of x with defined base. + + Args: + x (jax.Array): log value + base (jax.Array): base of the log + + Returns: + jax.Array: log(x)[base] + """ + return jnp.log(x) / jnp.log(base) + + + + + +class FreqsCisPrecomputer(nn.Module): + """ + computes frequency components (cosine and sine embeddings) for positional encodings based on fractional positions. + This is commonly used in rotary embeddings (RoPE) for transformers. 
+ """ + + positional_embedding_max_pos: List[int] + positional_embedding_theta: float + inner_dim: int + + def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array: + fractional_positions = jnp.stack( + [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], + axis=-1, + ) + return fractional_positions + + @nn.compact + def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]: + source_dtype = indices_grid.dtype + dtype = jnp.float32 # We need full precision in the freqs_cis computation. + dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + indices = jnp.power( + theta, + jnp.linspace( + log_base(start, theta), + log_base(end, theta), + dim // 6, + dtype=dtype, + ), + ) + indices = indices.astype(dtype) + + indices = indices * jnp.pi / 2 + + freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2 + + cos_freq = jnp.cos(freqs).repeat(2, axis=-1) + sin_freq = jnp.sin(freqs).repeat(2, axis=-1) + + if dim % 6 != 0: + cos_padding = jnp.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = jnp.zeros_like(sin_freq[:, :, : dim % 6]) + + cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1) + return cos_freq.astype(source_dtype), sin_freq.astype(source_dtype) diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py new file mode 100644 index 00000000..55c5bf37 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py @@ -0,0 +1,1265 @@ +import inspect +from importlib import import_module +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention import _chunked_feed_forward +from diffusers.models.attention_processor import ( + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + SpatialNorm, +) +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.models.normalization import RMSNorm +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange +from torch import nn + +from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +try: + from torch_xla.experimental.custom_kernel import flash_attention +except ImportError: + # workaround for automatic tests. Currently this function is manually patched + # to the torch_xla lib on setup of container + pass + +# code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py + +logger = logging.get_logger(__name__) + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. 
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`): + The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none". + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): + The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_eps: float = 1e-5, + qk_norm: Optional[str] = None, + final_dropout: bool = False, + attention_type: str = "default", # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_tpu_flash_attention = use_tpu_flash_attention + self.adaptive_norm = adaptive_norm + + assert standardization_norm in ["layer_norm", "rms_norm"] + assert adaptive_norm in ["single_scale_shift", "single_scale", "none"] + + make_norm_layer = ( + nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. 
Self-Attn + self.norm1 = make_norm_layer( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=( + cross_attention_dim if not double_self_attention else None + ), + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) # is self-attn if encoder_hidden_states is none + + if adaptive_norm == "none": + self.attn2_norm = make_norm_layer( + dim, norm_eps, norm_elementwise_affine + ) + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 5. Scale-shift for PixArt-Alpha. + if adaptive_norm != "none": + num_ada_params = 4 if adaptive_norm == "single_scale" else 6 + self.scale_shift_table = nn.Parameter( + torch.randn(num_ada_params, dim) / dim**0.5 + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. + """ + self.use_tpu_flash_attention = True + self.attn1.set_use_tpu_flash_attention() + self.attn2.set_use_tpu_flash_attention() + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored." + ) + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. 
Self-Attention + batch_size = hidden_states.shape[0] + + original_hidden_states = hidden_states + + norm_hidden_states = self.norm1(hidden_states) + + # Apply ada_norm_single + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None] + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + ada_values.unbind(dim=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + norm_hidden_states = norm_hidden_states.squeeze( + 1 + ) # TODO: Check if this is needed + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = ( + cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + ) + + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=( + encoder_hidden_states if self.only_cross_attention else None + ), + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.adaptive_norm == "none": + attn_input = self.attn2_norm(hidden_states) + else: + attn_input = hidden_states + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm2(hidden_states) + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + ff_output = self.ff(norm_hidden_states) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.TransformerBlock + ): + skip_layer_mask = skip_layer_mask.view(-1, 1, 1) + hidden_states = hidden_states * skip_layer_mask + original_hidden_states * ( + 1.0 - skip_layer_mask + ) + + return hidden_states + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. 
+ cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. 
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + qk_norm: Optional[str] = None, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.use_tpu_flash_attention = use_tpu_flash_attention + self.use_rope = use_rope + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + if qk_norm is None: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + elif qk_norm == "rms_norm": + self.q_norm = RMSNorm(dim_head * heads, eps=1e-5) + self.k_norm = RMSNorm(dim_head * heads, eps=1e-5) + elif qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + else: + raise ValueError(f"Unsupported qk_norm method: {qk_norm}") + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 
+ ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm( + num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True + ) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm( + f_channels=query_dim, zq_channels=spatial_norm_dim + ) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, + num_groups=cross_attention_norm_num_groups, + eps=1e-5, + affine=True, + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + linear_cls = nn.Linear + + self.linear_cls = linear_cls + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = AttnProcessor2_0() + self.set_processor(processor) + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel. + """ + self.use_tpu_flash_attention = True + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info( + f"You are removing possibly trained weights of {self.processor} with {processor}" + ) + self._modules.pop("processor") + + self.processor = processor + + def get_processor( + self, return_deprecated_lora: bool = False + ) -> "AttentionProcessor": # noqa: F821 + r""" + Get the attention processor in use. 
+ + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None + for name, module in self.named_modules() + if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not posssible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr( + import_module(__name__), "LoRA" + non_lora_processor_cls_name + ) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [ + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + ]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict( + self.to_out[0].lora_layer.state_dict() + ) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict( + self.to_out[0].lora_layer.state_dict() + ) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict( + self.add_k_proj.lora_layer.state_dict() + ) + lora_processor.add_v_proj_lora.load_state_dict( + self.add_v_proj.lora_layer.state_dict() + ) + else: + 
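+                # The added key/value projections carry no LoRA layers, so leave the
+                # corresponding slots on the LoRA processor unset.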
lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward( + self, + hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + skip_layer_mask (`torch.Tensor`, *optional*): + The skip layer mask to use. If `None`, no mask is applied. + skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers to skip for spatiotemporal guidance. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set( + inspect.signature(self.processor.__call__).parameters.keys() + ) + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by" + f" {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = { + k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters + } + + return self.processor( + self, + hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. 
+ """ + + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape( + batch_size, seq_len * extra_dim, head_size, dim // head_size + ) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape( + batch_size * head_size, seq_len * extra_dim, dim // head_size + ) + + return tensor + + def get_attention_scores( + self, + query: torch.Tensor, + key: torch.Tensor, + attention_mask: torch.Tensor = None, + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, + attention_mask: torch.Tensor, + target_length: int, + batch_size: int, + out_dim: int = 3, + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. 
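+                # Allocate an explicit zero block and append it along the key-length axis,
+                # mirroring what F.pad does in the non-MPS branch below.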
+ padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, + ) + padding = torch.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor + ) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert ( + self.norm_cross is not None + ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @staticmethod + def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + cos_freqs = freqs_cis[0] + sin_freqs = freqs_cis[1] + + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.FloatTensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. 
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if skip_layer_mask is not None: + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1) + + if (attention_mask is not None) and (not attn.use_tpu_flash_attention): + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + query = attn.q_norm(query) + + if encoder_hidden_states is not None: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + key = attn.k_norm(key) + else: # if no context provided do self-attention + encoder_hidden_states = hidden_states + key = attn.to_k(hidden_states) + key = attn.k_norm(key) + if attn.use_rope: + key = attn.apply_rotary_emb(key, freqs_cis) + query = attn.apply_rotary_emb(query, freqs_cis) + + value = attn.to_v(encoder_hidden_states) + value_for_stg = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + + if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention' + q_segment_indexes = None + if ( + attention_mask is not None + ): # if mask is required need to tune both segmenIds fields + # attention_mask = torch.squeeze(attention_mask).to(torch.float32) + attention_mask = attention_mask.to(torch.float32) + q_segment_indexes = torch.ones( + batch_size, query.shape[2], device=query.device, dtype=torch.float32 + ) + assert ( + attention_mask.shape[1] == key.shape[2] + ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]" + + assert ( + query.shape[2] % 128 == 0 + ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]" + assert ( + key.shape[2] % 128 == 0 + ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]" + + # run the TPU kernel implemented in jax with pallas + hidden_states_a = flash_attention( + q=query, + k=key, + v=value, + q_segment_ids=q_segment_indexes, + kv_segment_ids=attention_mask, + sm_scale=attn.scale, + ) + else: + hidden_states_a = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + + hidden_states_a = hidden_states_a.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + 
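+        # Cast the attention output back to the query dtype and, when spatiotemporal guidance
+        # (STG) is active, blend it with either the incoming hidden states (AttentionSkip) or
+        # the raw value projections (AttentionValues), weighted by skip_layer_mask.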
hidden_states_a = hidden_states_a.to(query.dtype) + + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.AttentionSkip + ): + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * ( + 1.0 - skip_layer_mask + ) + elif ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.AttentionValues + ): + hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * ( + 1.0 - skip_layer_mask + ) + else: + hidden_states = hidden_states_a + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Residual + ): + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1) + + if attn.residual_connection: + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Residual + ): + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
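+            # Route the message through diffusers' deprecation helper before ignoring the
+            # legacy `scale` argument.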
+ deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + query = attn.q_norm(query) + key = attn.k_norm(key) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
+ """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + else: + raise ValueError(f"Unsupported activation function: {activation_fn}") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py new file mode 100644 index 00000000..14fc56d6 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/embeddings.py @@ -0,0 +1,129 @@ +# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py +import math + +import numpy as np +import torch +from einops import rearrange +from torch import nn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w) + grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w) + grid = grid.reshape([3, 1, w, h, f]) + pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) + pos_embed = pos_embed.transpose(1, 0, 2, 3) + return rearrange(pos_embed, "h w f c -> (f h w) c") + + +def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 3 != 0: + raise ValueError("embed_dim must be divisible by 3") + + # use half of dimensions to encode grid_h + emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3) + + emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos_shape = pos.shape + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + out = out.reshape([*pos_shape, -1])[0] + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D) + return emb + + +class SinusoidalPositionalEmbedding(nn.Module): + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. 
+ max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim) + ) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py new file mode 100644 index 00000000..28ad834f --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/symmetric_patchifier.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from diffusers.configuration_utils import ConfigMixin +from einops import rearrange +from torch import Tensor + + +class Patchifier(ConfigMixin, ABC): + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) + + @abstractmethod + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + raise NotImplementedError("Patchify method not implemented") + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_latent_coords( + self, latent_num_frames, latent_height, latent_width, batch_size, device + ): + """ + Return a tensor of shape [batch_size, 3, num_patches] containing the + top-left corner latent coordinates of each latent patch. + The tensor is repeated for each batch element. 
+ """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, self._patch_size[0], device=device), + torch.arange(0, latent_height, self._patch_size[1], device=device), + torch.arange(0, latent_width, self._patch_size[2], device=device), + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = rearrange( + latent_coords, "b c f h w -> b c (f h w)", b=batch_size + ) + return latent_coords + + +class SymmetricPatchifier(Patchifier): + def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: + b, _, f, h, w = latents.shape + latent_coords = self.get_latent_coords(f, h, w, b, latents.device) + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents, latent_coords + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer_pt.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer_pt.py new file mode 100644 index 00000000..75b7b510 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/transformer_pt.py @@ -0,0 +1,507 @@ +# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union +import os +import json +import glob +from pathlib import Path + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import PixArtAlphaTextProjection +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle +from diffusers.utils import BaseOutput, is_torch_version +from diffusers.utils import logging +from torch import nn +from safetensors import safe_open + + +from maxdiffusion.models.ltx_video.transformers_pytorch.attention import BasicTransformerBlock +from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +from maxdiffusion.models.ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + TRANSFORMER_KEYS_RENAME_DICT, +) + + +logger = logging.get_logger(__name__) + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. 
+ """ + + sample: torch.FloatTensor + + +class Transformer3DModel_PT(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None, + positional_embedding_type: str = "rope", + positional_embedding_theta: Optional[float] = None, + positional_embedding_max_pos: Optional[List[int]] = None, + timestep_scale_multiplier: Optional[float] = None, + causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated + ): + super().__init__() + self.use_tpu_flash_attention = ( + use_tpu_flash_attention # FIXME: push config down to the attention modules + ) + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True) + self.positional_embedding_type = positional_embedding_type + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.use_rope = self.positional_embedding_type == "rope" + self.timestep_scale_multiplier = timestep_scale_multiplier + + if self.positional_embedding_type == "absolute": + raise ValueError("Absolute positional embedding is no longer supported") + elif self.positional_embedding_type == "rope": + if positional_embedding_theta is None: + raise ValueError( + "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined" + ) + if positional_embedding_max_pos is None: + raise ValueError( + "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined" + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + adaptive_norm=adaptive_norm, + standardization_norm=standardization_norm, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=self.use_rope, + ) + for d in range(num_layers) + ] + ) + + # 4. 
Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter( + torch.randn(2, inner_dim) / inner_dim**0.5 + ) + self.proj_out = nn.Linear(inner_dim, self.out_channels) + + self.adaln_single = AdaLayerNormSingle( + inner_dim, use_additional_conditions=False + ) + if adaptive_norm == "single_scale": + self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=inner_dim + ) + + self.gradient_checkpointing = False + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. + """ + logger.info("ENABLE TPU FLASH ATTENTION -> TRUE") + self.use_tpu_flash_attention = True + # push config down to the attention modules + for block in self.transformer_blocks: + block.set_use_tpu_flash_attention() + + def create_skip_layer_mask( + self, + batch_size: int, + num_conds: int, + ptb_index: int, + skip_block_list: Optional[List[int]] = None, + ): + if skip_block_list is None or len(skip_block_list) == 0: + return None + num_layers = len(self.transformer_blocks) + mask = torch.ones( + (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype + ) + for block_idx in skip_block_list: + mask[block_idx, ptb_index::num_conds] = 0 + return mask + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(3) + ], + dim=-1, + ) + return fractional_positions + + def precompute_freqs_cis(self, indices_grid, spacing="exp"): + dtype = torch.float32 # We need full precision in the freqs_cis computation. 
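+        # Builds the rotary (RoPE) tables consumed by the attention blocks: each of the
+        # three axes (f, h, w) contributes dim // 6 frequencies over the fractional
+        # positions, each frequency is duplicated for its paired rotation channel, and any
+        # remaining dim % 6 channels are padded with an identity rotation (cos = 1, sin = 0).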
+        dim = self.inner_dim
+        theta = self.positional_embedding_theta
+
+        fractional_positions = self.get_fractional_positions(indices_grid)
+
+        start = 1
+        end = theta
+        device = fractional_positions.device
+        if spacing == "exp":
+            indices = theta ** (
+                torch.linspace(
+                    math.log(start, theta),
+                    math.log(end, theta),
+                    dim // 6,
+                    device=device,
+                    dtype=dtype,
+                )
+            )
+            indices = indices.to(dtype=dtype)
+        elif spacing == "exp_2":
+            indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim)
+            indices = indices.to(dtype=dtype)
+        elif spacing == "linear":
+            indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
+        elif spacing == "sqrt":
+            indices = torch.linspace(
+                start**2, end**2, dim // 6, device=device, dtype=dtype
+            ).sqrt()
+
+        indices = indices * math.pi / 2
+
+        if spacing == "exp_2":
+            freqs = (
+                (indices * fractional_positions.unsqueeze(-1))
+                .transpose(-1, -2)
+                .flatten(2)
+            )
+        else:
+            freqs = (
+                (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
+                .transpose(-1, -2)
+                .flatten(2)
+            )
+
+        cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
+        sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
+        if dim % 6 != 0:
+            cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
+            sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
+            cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
+            sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
+        return cos_freq.to(self.dtype), sin_freq.to(self.dtype)
+
+    def load_state_dict(
+        self,
+        state_dict: Dict,
+        *args,
+        **kwargs,
+    ):
+        if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]):
+            state_dict = {
+                key.replace("model.diffusion_model.", ""): value
+                for key, value in state_dict.items()
+                if key.startswith("model.diffusion_model.")
+            }
+        super().load_state_dict(state_dict, **kwargs)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_path: Optional[Union[str, os.PathLike]],
+        *args,
+        **kwargs,
+    ):
+        pretrained_model_path = Path(pretrained_model_path)
+        if pretrained_model_path.is_dir():
+            config_path = pretrained_model_path / "transformer" / "config.json"
+            with open(config_path, "r") as f:
+                config = make_hashable_key(json.load(f))
+
+            assert config in diffusers_and_ours_config_mapping, (
+                "Provided diffusers checkpoint config for transformer is not supported. "
+                "We only support diffusers configs found in Lightricks/LTX-Video."
+            )
+
+            config = diffusers_and_ours_config_mapping[config]
+            state_dict = {}
+            ckpt_paths = (
+                pretrained_model_path
+                / "transformer"
+                / "diffusion_pytorch_model*.safetensors"
+            )
+            dict_list = glob.glob(str(ckpt_paths))
+            for dict_path in dict_list:
+                part_dict = {}
+                with safe_open(dict_path, framework="pt", device="cpu") as f:
+                    for k in f.keys():
+                        part_dict[k] = f.get_tensor(k)
+                state_dict.update(part_dict)
+
+            for key in list(state_dict.keys()):
+                new_key = key
+                for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+                    new_key = new_key.replace(replace_key, rename_key)
+                state_dict[new_key] = state_dict.pop(key)
+
+            with torch.device("meta"):
+                transformer = cls.from_config(config)
+            transformer.load_state_dict(state_dict, assign=True, strict=True)
+        elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
+            ".safetensors"
+        ):
+            comfy_single_file_state_dict = {}
+            with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
+                metadata = f.metadata()
+                for k in f.keys():
+                    comfy_single_file_state_dict[k] = f.get_tensor(k)
+            configs = json.loads(metadata["config"])
+            transformer_config = configs["transformer"]
+            with torch.device("meta"):
+                transformer = cls.from_config(transformer_config)
+            transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
+        return transformer
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        indices_grid: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        skip_layer_mask: Optional[torch.Tensor] = None,
+        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
+        return_dict: bool = True,
+    ):
+        """
+        The [`Transformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+                Input `hidden_states`.
+            indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
+            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            timestep ( `torch.LongTensor`, *optional*):
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                `AdaLayerZeroNorm`.
+            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            attention_mask ( `torch.Tensor`, *optional*):
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+                negative values to the attention scores corresponding to "discard" tokens.
+ encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + skip_layer_mask ( `torch.Tensor`, *optional*): + A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position + `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index. + skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # for tpu attention offload 2d token masks are used. No need to transform. + if not self.use_tpu_flash_attention: + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + hidden_states = self.patchify_proj(hidden_states) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + freqs_cis = self.precompute_freqs_cis(indices_grid) + + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view( + batch_size, -1, embedded_timestep.shape[-1] + ) + + # 2. 
Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view( + batch_size, -1, hidden_states.shape[-1] + ) + + for block_idx, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + freqs_cis, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + ( + skip_layer_mask[block_idx] + if skip_layer_mask is not None + else None + ), + skip_layer_strategy, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + skip_layer_mask=( + skip_layer_mask[block_idx] + if skip_layer_mask is not None + else None + ), + skip_layer_strategy=skip_layer_strategy, + ) + + # 3. Output + scale_shift_values = ( + self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if not return_dict: + return (hidden_states,) + + return Transformer3DModelOutput(sample=hidden_states) \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py new file mode 100644 index 00000000..13c1ebf9 --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/convert_torch_weights_to_jax.py @@ -0,0 +1,256 @@ +import argparse +import json +from typing import Any, Dict, Optional + +import jax +import jax.numpy as jnp +from flax.training import train_state +import optax +import orbax.checkpoint as ocp +from safetensors.torch import load_file + +from maxdiffusion.models.ltx_video.transformers_pytorch.transformer_pt import Transformer3DModel_PT +from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel as JaxTranformer3DModel +from maxdiffusion.models.ltx_video.utils.torch_compat import torch_statedict_to_jax + +from huggingface_hub import hf_hub_download +import os + +class Checkpointer: + """ + Checkpointer - to load and store JAX checkpoints + """ + + STATE_DICT_SHAPE_KEY = "shape" + STATE_DICT_DTYPE_KEY = "dtype" + TRAIN_STATE_FILE_NAME = "train_state" + + def __init__( + self, + checkpoint_dir: str, + use_zarr3: bool = False, + save_buffer_size: Optional[int] = None, + restore_buffer_size: Optional[int] = None, + ): + """ + Constructs the checkpointer object + """ + opts = ocp.CheckpointManagerOptions( + enable_async_checkpointing=True, + step_format_fixed_length=8, # to make the format of "00000000" + ) + self.use_zarr3 = use_zarr3 + self.save_buffer_size = save_buffer_size + 
self.restore_buffer_size = restore_buffer_size + registry = ocp.DefaultCheckpointHandlerRegistry() + self.train_state_handler = ocp.PyTreeCheckpointHandler( + save_concurrent_gb=save_buffer_size, + restore_concurrent_gb=restore_buffer_size, + use_zarr3=use_zarr3, + ) + registry.add( + self.TRAIN_STATE_FILE_NAME, + ocp.args.PyTreeSave, + self.train_state_handler, + ) + self.manager = ocp.CheckpointManager( + directory=checkpoint_dir, + options=opts, + handler_registry=registry, + ) + + @property + def save_buffer_size_bytes(self) -> Optional[int]: + if self.save_buffer_size is None: + return None + return self.save_buffer_size * 2**30 + + @staticmethod + def state_dict_to_structure_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Converts a state dict to a dictionary stating the shape and dtype of the state_dict elements. + With this, we can reconstruct the state_dict structure later on. + """ + return jax.tree_util.tree_map( + lambda t: { + Checkpointer.STATE_DICT_SHAPE_KEY: tuple(t.shape), + Checkpointer.STATE_DICT_DTYPE_KEY: t.dtype.name, + }, + state_dict, + is_leaf=lambda t: isinstance(t, jax.Array), + ) + + def save( + self, + step: int, + state: train_state.TrainState, + config: Dict[str, Any], + ): + """ + Saves the checkpoint asynchronously + + NOTE that state is going to be copied for this operation + + Args: + step (int): The step of the checkpoint + state (TrainStateWithEma): A trainstate containing both the parameters and the optimizer state + config (Dict[str, Any]): A dictionary containing the configuration of the model + """ + self.wait() + args = ocp.args.Composite( + train_state=ocp.args.PyTreeSave( + state, + ocdbt_target_data_file_size=self.save_buffer_size_bytes, + ), + config=ocp.args.JsonSave(config), + meta_params=ocp.args.JsonSave(self.state_dict_to_structure_dict(state.params)), + ) + self.manager.save( + step, + args=args, + ) + + def wait(self): + """ + Waits for the checkpoint save operation to complete + """ + self.manager.wait_until_finished() + + +""" +Convert Torch checkpoints to JAX. + +This script loads a Torch checkpoint (either regular or sharded), converts it to Jax weights, and saved it. +""" +def main(args): + """ + Convert a Torch checkpoint into JAX. + """ + + if args.output_step_num > 1: + print( + "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between " + "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " + "training loss when resuming from the converted checkpoint." 
+ ) + + print("Loading safetensors, flush = True") + weight_file = "ltxv-13b-0.9.7-dev.safetensors" + + #download from huggingface, otherwise load from local + if (args.local_ckpt_path is None): + print("Loading from HF", flush = True) + model_name = "Lightricks/LTX-Video" + local_file_path = hf_hub_download( + repo_id=model_name, + filename=weight_file, + local_dir=args.download_ckpt_path, + local_dir_use_symlinks=False, + ) + else: + base_dir = args.local_ckpt_path + local_file_path = os.path.join(base_dir, weight_file) + torch_state_dict = load_file(local_file_path) + + + + print("Initializing pytorch transformer..", flush=True) + transformer_config = json.loads(open(args.transformer_config_path, "r").read()) + ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "ckpt_path"] + for key in ignored_keys: + if key in transformer_config: + del transformer_config[key] + + transformer = Transformer3DModel_PT.from_config(transformer_config) + + + print("Loading torch weights into transformer..", flush=True) + transformer.load_state_dict(torch_state_dict) + torch_state_dict = transformer.state_dict() + + print("Creating jax transformer with params..", flush=True) + transformer_config["use_tpu_flash_attention"] = True + in_channels = transformer_config["in_channels"] + del transformer_config["in_channels"] + jax_transformer3d = JaxTranformer3DModel(**transformer_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch") + example_inputs = {} + batch_size, num_tokens = 2, 256 + input_shapes = { + "hidden_states": (batch_size, num_tokens, in_channels), + "indices_grid": (batch_size, 3, num_tokens), + "encoder_hidden_states": (batch_size, 128, transformer_config["caption_channels"]), + "timestep": (batch_size, 256), + "segment_ids": (batch_size, 256), + "encoder_attention_segment_ids": (batch_size, 128), + } + for name, shape in input_shapes.items(): + example_inputs[name] = jnp.ones( + shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool + ) + params_jax = jax_transformer3d.init(jax.random.PRNGKey(42), **example_inputs) + + + print("Converting torch params to jax..", flush=True) + params_jax = torch_statedict_to_jax(params_jax, torch_state_dict) + + print("Creating checkpointer and jax state for saving..", flush=True) + relative_ckpt_path = args.output_dir + absolute_ckpt_path = os.path.abspath(relative_ckpt_path) + tx = optax.adamw(learning_rate=1e-5) + with jax.default_device("cpu"): + state = train_state.TrainState( + step=args.output_step_num, + apply_fn=jax_transformer3d.apply, + params=params_jax, + tx=tx, + opt_state=tx.init(params_jax), + ) + with ocp.CheckpointManager(absolute_ckpt_path) as mngr: + mngr.save(args.output_step_num, args = ocp.args.StandardSave(state.params)) + print("Done.", flush=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Torch checkpoints to Jax format.") + parser.add_argument( + "--local_ckpt_path", + type=str, + required=False, + help="Local path of the checkpoint to convert. If not provided, will download from huggingface for example '/mnt/ckpt/00536000' or '/opt/dmd-torch-model/ema.pt'", + ) + + parser.add_argument( + "--download_ckpt_path", + type=str, + required=False, + help="Location to download safetensors from huggingface", + ) + + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Path to save the checkpoint to. 
for example 'gs://lt-research-mm-europe-west4/jax_trainings/converted-from-torch'", + ) + parser.add_argument( + "--output_step_num", + default=1, + type=int, + required=False, + help=( + "The step number to assign to the output checkpoint. The result will be saved using this step value. " + "⚠️ Warning: The optimizer state is not converted. Changing the output step may lead to a mismatch between " + "the model parameters and optimizer state. This can affect optimizer moments and may result in a spike in " + "training loss when resuming from the converted checkpoint." + ), + ) + parser.add_argument( + "--transformer_config_path", + default="/opt/txt2img/txt2img/config/transformer3d/ltxv2B-v1.0.json", + type=str, + required=False, + help="Path to Transformer3D structure config to load the weights based on.", + ) + + args = parser.parse_args() + main(args) diff --git a/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py new file mode 100644 index 00000000..9ee89d1d --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/diffusers_config_mapping.py @@ -0,0 +1,174 @@ +def make_hashable_key(dict_key): + def convert_value(value): + if isinstance(value, list): + return tuple(value) + elif isinstance(value, dict): + return tuple(sorted((k, convert_value(v)) for k, v in value.items())) + else: + return value + + return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) + + +DIFFUSERS_SCHEDULER_CONFIG = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.32.0.dev0", + "base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} +DIFFUSERS_TRANSFORMER_CONFIG = { + "_class_name": "LTXVideoTransformer3DModel", + "_diffusers_version": "0.32.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_out_bias": True, + "caption_channels": 4096, + "cross_attention_dim": 2048, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_layers": 28, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm_across_heads", +} +DIFFUSERS_VAE_CONFIG = { + "_class_name": "AutoencoderKLLTXVideo", + "_diffusers_version": "0.32.0.dev0", + "block_out_channels": [128, 256, 512, 512], + "decoder_causal": False, + "encoder_causal": True, + "in_channels": 3, + "latent_channels": 128, + "layers_per_block": [4, 3, 3, 3, 4], + "out_channels": 3, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-06, + "scaling_factor": 1.0, + "spatio_temporal_scaling": [True, True, True, False], +} + +OURS_SCHEDULER_CONFIG = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": "SD3", + "base_resolution": None, + "target_shift_terminal": 0.1, +} + +OURS_TRANSFORMER_CONFIG = { + "_class_name": "Transformer3DModel", + "_diffusers_version": "0.25.1", + "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 2048, + "double_self_attention": False, + "dropout": 0.0, 
+ "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "norm_num_groups": 32, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 28, + "num_vector_embeds": None, + "only_cross_attention": False, + "out_channels": 128, + "project_to_2d_pos": True, + "upcast_attention": False, + "use_linear_projection": False, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, +} +OURS_VAE_CONFIG = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, +} + + +diffusers_and_ours_config_mapping = { + make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, + make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, + make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, +} + + +TRANSFORMER_KEYS_RENAME_DICT = { + "proj_in": "patchify_proj", + "time_embed": "adaln_single", + "norm_q": "q_norm", + "norm_k": "k_norm", +} + + +VAE_KEYS_RENAME_DICT = { + "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", + "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", + "decoder.up_blocks.3": "decoder.up_blocks.9", + "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", + "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", + "decoder.up_blocks.2": "decoder.up_blocks.6", + "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", + "decoder.up_blocks.1": "decoder.up_blocks.3", + "decoder.up_blocks.0": "decoder.up_blocks.1", + "decoder.mid_block": "decoder.up_blocks.0", + "encoder.down_blocks.3": "encoder.down_blocks.8", + "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", + "encoder.down_blocks.2": "encoder.down_blocks.6", + "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", + "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", + "encoder.down_blocks.1": "encoder.down_blocks.3", + "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", + "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", + "encoder.down_blocks.0": "encoder.down_blocks.0", + "encoder.mid_block": "encoder.down_blocks.9", + "conv_shortcut.conv": "conv_shortcut", + "resnets": "res_blocks", + "norm3": "norm3.norm", + "latents_mean": "per_channel_statistics.mean-of-means", + "latents_std": "per_channel_statistics.std-of-means", +} \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py new file mode 100644 index 00000000..c34df51f --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/utils/skip_layer_strategy.py @@ -0,0 +1,8 @@ +from enum import Enum, auto + + +class SkipLayerStrategy(Enum): + AttentionSkip = auto() + AttentionValues = auto() + Residual = auto() + TransformerBlock = auto() \ No newline at end of file diff --git a/src/maxdiffusion/models/ltx_video/utils/torch_compat.py b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py new file mode 100644 index 00000000..450cf9b1 --- /dev/null 
+++ b/src/maxdiffusion/models/ltx_video/utils/torch_compat.py @@ -0,0 +1,519 @@ +import re +from copy import copy +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union, Any + +import flax +import jax +import torch +import torch.utils._pytree as pytree +from flax.traverse_util import flatten_dict + + + +AnyTensor = Union[jax.Array, torch.Tensor] +StateDict = Dict[str, AnyTensor] + +ScanRepeatableCarryBlock = "ScanRepeatableCarryBlock" + +JaxParams = Dict[str, Union[Dict[str, jax.Array], jax.Array]] + +def unbox_logically_partioned(statedict: JaxParams) -> JaxParams: + return jax.tree_util.tree_map( + lambda t: t.unbox() if isinstance(t, flax.linen.spmd.LogicallyPartitioned) else t, + statedict, + is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned), + ) + +def torch_tensor_to_jax_array(data: torch.Tensor) -> jax.Array: + match data.dtype: + case torch.bfloat16: + return jax.numpy.from_dlpack(data) + case _: + return jax.numpy.array(data) + + +def is_stack_or_tensor(param: Any) -> bool: + """ + Returns True if param is of type tensor or list/tuple of tensors (stack of tensors) + + Used for mapping utils + """ + return isinstance(param, (torch.Tensor, list, tuple)) + + +def convert_tensor_stack_to_tensor(param: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + """ + Converts a list of torch tensors to a single torch tensor. + Args: + param (Union[List[torch.Tensor], torch.Tensor]): The parameter to convert. + + Returns: + torch.Tensor: The converted tensor. + """ + if isinstance(param, list): + return torch.stack(param) + return param + + +@dataclass +class ConvertAction: + """ + Defines a set of actions to be done on a given parameter. + + The definition must be commutative, i.e. the order of the actions should not matter. + also we should strive for actions to be reversible (so the same action can be used for both directions). + """ + + transpose: Optional[Tuple[int, int]] = None + """ + If defined, transposes the tensor with the given indices. + Example: (1, 0) transposes a (at least 2D tensor) from (..., a, b) to (..., b, a). + """ + + rename: Optional[Dict[str, str]] = None + """ + If defined, renames the parameter according to the given mapping. + Example: {"torch": "weight", "jax": "kernel"} + * renames "torch.weight" to "jax.kernel" when converting from torch to jax. + * renames "jax.kernel" to "torch.weight" when converting from jax to torch. + """ + + split_by: Optional[str] = None + """ + If defined, splits the parameter by the given delimiter. + Example: "ScanRepeatableCarryBlock.k1" assumes the parameter is a concatenation of multiple tensors (shaped: (n, ...)). + and splits them into individual tensors named as "ScanRepeatableCarryBlock.0.k1", "ScanRepeatableCarryBlock.n.k1". + """ + + group_by: Optional[str] = None + """ + If defined, groups the parameter by the given delimiter. + Example: "ScanRepeatableCarryBlock.0.k1", "ScanRepeatableCarryBlock.1.k1", "ScanRepeatableCarryBlock.2.k1" + will be grouped into a single tensor named "ScanRepeatableCarryBlock.k1" shaped (n, ...). + + *** Note: + this is kind of the reverse of split_by, only a different behavior. + it's easy to define "actions" that are reversible in base of context (jax->torch, torch->jax). + but it's very wrong to do so, since it blocks modular behavior and makes the code harder to maintain. 
+ + """ + + jax_groups: Optional[List[str]] = None + """ + Generally used in group_by, this is a list of all possible keys that can be used to group the parameters. + This must be defined if group_by is defined. + + It's due to the un-reversibility nature of the group_by action. + """ + + def apply_transpose(self, mini_statedict: StateDict) -> StateDict: + """ + Applies the transpose action if defined + Args: + mini_statedict (StateDict): Local context of the state dict + + Returns: + StateDict: Output local context of the state dict + """ + + if self.transpose is None: + return mini_statedict + index0, index1 = self.transpose + return {param_name: param.swapaxes(index0, index1) for param_name, param in mini_statedict.items()} + + def apply_rename(self, mini_statedict: StateDict, delim: str) -> StateDict: + """ + Applies the rename action if defined + + Args: + mini_statedict (StateDict): Local context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + StateDict: Output local context of the state dict + """ + if self.rename is None: + return mini_statedict + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + param = mini_statedict.pop(param_name) + parts = param_name.split(delim) + rename_source = "torch" if isinstance(param, torch.Tensor) else "jax" + rename_target = "jax" if isinstance(param, torch.Tensor) else "torch" + source_name = self.rename[rename_source] + dest_name = self.rename[rename_target] + if source_name == param_name: + new_param_name = dest_name + else: + # There is always ```self.rename[rename_source]``` in parts + index = parts.index(self.rename[rename_source]) + parts[index] = self.rename[rename_target] + new_param_name = delim.join(parts) + mini_statedict[new_param_name] = param + + return mini_statedict + + def apply_split_by(self, mini_statedict: StateDict, new_params: List, delim: str) -> Tuple[StateDict, List[str]]: + """ + Applies the split_by action if defined + + Args: + mini_statedict (StateDict): Local state dict + new_params (List): State containing list of new params that were created during the process (if any) + delim (str): Output local context of the state dict + + Returns: + Tuple[StateDict, List[str]]: Output local context of the state dict and list of new keys to add to the global state dict. 
+ """ + if self.split_by is None: + return mini_statedict, new_params + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + parts = param_name.split(delim) + indices = [i for i, p in enumerate(parts) if self.split_by in p] + if len(indices) != 1: + raise ValueError(f"Expected exactly one split_by in param_name: {param_name}") + index = indices[0] + params = mini_statedict.pop(param_name) + for i, param in enumerate(params): + new_parts = parts[:index] + [f"{i}"] + parts[index + 2 :] + new_param_name = delim.join(new_parts) + mini_statedict[new_param_name] = param + new_params.append(new_param_name) + + return mini_statedict, new_params + + def apply_group_by( + self, mini_statedict: StateDict, new_params: List, full_statedict: StateDict, delim: str + ) -> Tuple[StateDict, List[str]]: + """ + Applies the group_by action if defined + + Args: + mini_statedict (StateDict): Local state dict + new_params (List): State containing list of new params that were created during the process (if any) + full_statedict (StateDict): Global context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + Tuple[StateDict, List[str]]: Output local context of the state dict and list of new keys to add to the global state dict. + """ + if self.group_by is None: + return mini_statedict, new_params + + param_names = list(mini_statedict.keys()) + for param_name in param_names: + param = mini_statedict.pop(param_name) + jax_keywords = extract_scan_keywords(param_name, self.jax_groups, delim) + block_index = re.findall(r"\.\d+\.", param_name)[0][1:-1] + parts = param_name.split(delim) + index = parts.index(block_index) + prefix = delim.join(parts[:index]) + suffix = delim.join(parts[index + 1 :]) + + new_param_name = f"{prefix}.{delim.join(jax_keywords)}.{suffix}" + + if new_param_name not in full_statedict: + full_statedict[new_param_name] = [param] + else: + full_statedict[new_param_name] = full_statedict[new_param_name] + [param] + + return mini_statedict, new_params + + def __call__( + self, + mini_statedict: StateDict, + new_params: List, + full_statedict: StateDict, + delim: str, + ) -> Tuple[StateDict, List[str]]: + """ + Given a state dict, applies the transformations defined in the ConvertAction. + + Args: + mini_statedict (StateDict): Local context of the state dict + new_params (List): new params that were created during the process (if any) + full_statedict (StateDict): Global context of the state dict + delim (str): delimiter used for parsing (usually "."), kept as parameter for flexibility. + + Returns: + Tuple[StateDict, List[str]]: Updated local state dict and list of new keys to add to the global state dict. + """ + mini_statedict = self.apply_transpose(mini_statedict) + mini_statedict = self.apply_rename(mini_statedict, delim) + mini_statedict, new_params = self.apply_split_by(mini_statedict, new_params, delim) + mini_statedict, new_params = self.apply_group_by(mini_statedict, new_params, full_statedict, delim) + return mini_statedict, new_params + + +def is_kernel_2d(param_name: str, param: AnyTensor) -> bool: + """ + Checks if the parameter is a 2D kernel (weight) or not. + usually applies to linear layers or convolutions. + Args: + param_name (str): Name of the parameter + param (AnyTensor): The parameter itself (could be either jax or torch Tensor) + + Returns: + bool: True if the parameter is a weight for linear/convolutional layer or not. 
+ """ + expected_name = "weight" if isinstance(param, torch.Tensor) else "kernel" + return expected_name in param_name and param.ndim == 2 + + +def is_scan_repeatable(param_name: str, _) -> bool: + """ + Checks if the parameter is a scan repeatable carry block parameter. + + Args: + param_name (str): Parameter name + _ (_type_): Unused, will contain the parameter itself + + Returns: + bool: True if the parameter is a scan repeatable carry block parameter or not. + """ + return ScanRepeatableCarryBlock in param_name + + +def is_scale_shift_table(param_name: str, _) -> bool: + """ + Checks if the parameter is a scale shift table parameter. + + Args: + param_name (str): Parameter name + _ (_type_): Unused, will contain the parameter itself + + Returns: + bool: True if the parameter is a scale shift table parameter or not. + """ + return "scale_shift_table" in param_name + + +def is_affine_scale_param(param_name: str, parameter: AnyTensor, jax_flattened_keys: List[str]) -> bool: + """ + Checks if the parameter is an affine scale parameter. + + Args: + param_name (str): Parameter name + parameter (AnyTensor): The parameter itself + jax_flattened_keys (List[str]): Flattened list of the keys use in jax (for reference and keys search) + + + Returns: + bool: True if the parameter is an affine scale parameter or not. + """ + if isinstance(parameter, torch.Tensor): + return "weight" in param_name and parameter.ndim == 1 and param_name not in jax_flattened_keys + else: + return "scale" in param_name and parameter.ndim == 1 + + +def extract_scan_keywords(param_name: str, jax_flattened_keys: List[str], delim: str) -> Optional[Tuple[str, str]]: + """ + Extracts the keywords from the scan repeatable carry block parameter (if exists) + + If the parameter is a scan repeatable carry block, it will return the keywords that are used to group the parameters. + otherwise it will return None. + + Args: + param_name (str): Name of the parameter + jax_flattened_keys (List[str]): Flattened list of the keys use in jax (for reference and keys search) + delim (str): The delimiter used in the parameter name (in torch) + + Returns: + Optional[Tuple[str, str]]: Tuple of the keywords used to group the parameters (or None if it is not a scan repeatable carry block) + """ + block_indices = re.findall(r"\.\d+\.", param_name) + + if len(block_indices) == 0: + return None + block_indices = [block_indices[0]] + block_index = block_indices[0][1:-1] + parts = param_name.split(delim) + index = parts.index(block_index) + prefix = delim.join(parts[:index]) + suffix = delim.join(parts[index + 1 :]) + + for flat_key in jax_flattened_keys: + if flat_key.startswith(prefix) and flat_key.endswith(suffix): + mid_layer = flat_key[len(prefix) + 1 : -len(suffix) - 1] + mid_parts = mid_layer.split(delim) + if not any(ScanRepeatableCarryBlock in mid_part for mid_part in mid_parts): + continue + return mid_parts + + return None + + +def should_be_scan_repeatable(param_name: str, param: AnyTensor, jax_flattened_keys: List[str], delim: str) -> bool: + """ + Checks if the parameter should be a scan repeatable carry block or not. + Args: + param_name (str): The name of the parameter + param (AnyTensor): the Parameter itself + jax_flattened_keys (List[str]): Flattened list of the keys use in jax (for reference and keys search) + delim (str): The delimiter used in the parameter name (in torch) + + Returns: + bool: True if the paramter should be treated scan repeatable block parameter. 
+ """ + if not isinstance(param, torch.Tensor): + return False + + keywords = extract_scan_keywords(param_name, jax_flattened_keys, delim) + return keywords is not None + + +def jax_statedict_to_torch( + jax_params: JaxParams, rulebook: Optional[Dict[Callable[[str, AnyTensor], bool], ConvertAction]] = None +) -> Dict[str, torch.Tensor]: + """ + Converts a JAX state dict to a torch state dict. + + Args: + jax_params (JaxParams): The current params in JAX format, to ease parsing and conversion. + rulebook (Optional[Dict[Callable[[str, AnyTensor], bool], ConvertAction]], optional): Defines a rulebook stating how to convert state dict from jax to torch. + Defaults to None. + + + Returns: + Dict[str, torch.Tensor]: The converted state dict in torch format (Pytorch state dict). + """ + + affine_scale_search = partial(is_affine_scale_param, jax_flattened_keys=[]) + + if rulebook is None: + rulebook = { + is_scan_repeatable: ConvertAction(split_by=ScanRepeatableCarryBlock), + is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), + affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), + } + if "params" not in jax_params: + raise ValueError('Expected "params" key in jax_params, are you sure you are passing the correct object?') + + jax_params = copy(jax_params["params"]) # Non reference copy + jax_params = unbox_logically_partioned(jax_params) + + delim = "." + # Move to flattened dict to match torch state dict convention + flattened_params = flatten_dict(jax_params, sep=delim) + + param_names = list(flattened_params.keys()) + for param_name in param_names: + param = flattened_params.pop(param_name) + mini_statedict = {param_name: param} + new_params = [] + for condition, rule in rulebook.items(): + if condition(param_name, param): + mini_statedict, new_params = rule(mini_statedict, new_params, flattened_params, delim) + if len(mini_statedict) == 1: + param_name = list(mini_statedict.keys())[0] + + flattened_params.update(mini_statedict) + param_names.extend(new_params) + + flattened_params = pytree.tree_map(convert_tensor_stack_to_tensor, flattened_params, is_leaf=is_stack_or_tensor) + + to_cpu = pytree.tree_map(lambda t: jax.device_put(t, jax.devices("cpu")[0]), flattened_params) + to_torch = pytree.tree_map(torch.from_dlpack, to_cpu) + return to_torch + + +def torch_statedict_to_jax( + jax_params: JaxParams, + torch_params: Dict[str, torch.Tensor], +) -> JaxParams: + """ + Converts a torch state dict to a JAX state dict. + + Args: + jax_params (JaxParams): The current params in JAX format, to ease parsing and conversion. + torch_params (Dict[str, torch.Tensor]): The current params in torch format, to load parameters from. + + Returns: + JaxParams: The state dict in JAX format. + """ + with jax.default_device("cpu"): + jax_params = copy(jax_params) + jax_params = unbox_logically_partioned(jax_params) + torch_params = copy(torch_params) + + if "params" not in jax_params: + raise ValueError('Expected "params" key in jax_params, are you sure you are passing the correct object?') + + delim = "." 
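+        # The rulebook below drives the conversion: 2D kernels are transposed and renamed
+        # weight <-> kernel, 1D affine scales are renamed weight <-> scale, and per-layer
+        # torch parameters that map onto a scanned block are stacked into a single
+        # ScanRepeatableCarryBlock tensor.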
+ flattened_keys = list(flatten_dict(jax_params["params"], sep=".").keys()) + scan_repeatable_cond = partial(should_be_scan_repeatable, jax_flattened_keys=flattened_keys, delim=delim) + affine_scale_search = partial(is_affine_scale_param, jax_flattened_keys=flattened_keys) + + rulebook = { + is_kernel_2d: ConvertAction(transpose=(1, 0), rename=dict(torch="weight", jax="kernel")), + affine_scale_search: ConvertAction(rename=dict(torch="weight", jax="scale")), + scan_repeatable_cond: ConvertAction(group_by=ScanRepeatableCarryBlock, jax_groups=flattened_keys), + } + + # First pass - Rulebook + param_names = list(torch_params.keys()) + for param_name in param_names: + param = torch_params.pop(param_name) + mini_statedict = {param_name: param} + new_params = [] + for condition, rule in rulebook.items(): + if condition(param_name, param): + mini_statedict, new_params = rule(mini_statedict, new_params, torch_params, delim=delim) + if len(mini_statedict) == 1: + param_name = list(mini_statedict.keys())[0] + + torch_params.update(mini_statedict) + param_names.extend(new_params) + + # Ensures any list of tensors are converted to a single tensor + # This is due to the fact that the scan repeatable block is a list of tensors + torch_params = pytree.tree_map(convert_tensor_stack_to_tensor, torch_params, is_leaf=is_stack_or_tensor) + + to_jax: Dict = pytree.tree_map(torch_tensor_to_jax_array, torch_params) + + def nested_insert(param_name: str, param: torch.Tensor, nested_dict: Dict): + """ + Inserts a parameter into a nested dictionary. (to fit Jax format) + The keys in torch are split into groups by a delimiter of choice (usually "." to fit torch schema) + and then inserted into a nested dictionary. + + in case the parameter is of the form of "a.b" and "a.b" is a layer type in jax - + the parameter will be inserted as "a.b": {...: param}. this ensures compatibility between jax layers and torch layers. 
+ + Args: + param_name (str): Parameter name + param (torch.Tensor): Parameter itself + nested_dict (Dict): Current nested dict state + """ + if delim not in param_name: + nested_dict[param_name] = param + return + + parts = param_name.split(delim) + if len(parts) == 1: + return nested_insert(parts[0], param, nested_dict) + else: + key = parts[0] + # May be either complex key or nested key + if len(parts) > 2 and re.fullmatch(r"\d+", parts[1]) is not None: + key = delim.join(parts[:2]) + new_param_name = delim.join(parts[2:]) + else: + new_param_name = delim.join(parts[1:]) + new_nested_dict = nested_dict.get(key, {}) + nested_dict[key] = new_nested_dict + return nested_insert(new_param_name, param, new_nested_dict) + + params = {} + for param_name, param in to_jax.items(): + nested_insert(param_name, param, params) + + # Jax state dict is usually held as dict containings "parmas" keys which contains + # dict of dict containing all the params + return {"params": params} diff --git a/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json new file mode 100644 index 00000000..10414eab --- /dev/null +++ b/src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json @@ -0,0 +1,26 @@ +{ + "ckpt_path": "/mnt/disks/diffusionproj/jax_weights", + "activation_fn": "gelu-approximate", + "attention_bias": true, + "attention_head_dim": 128, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 4096, + "double_self_attention": false, + "dropout": 0.0, + "norm_elementwise_affine": false, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 48, + "only_cross_attention": false, + "out_channels": 128, + "upcast_attention": false, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, + "in_channels": 128 +} diff --git a/test_lightning.png b/test_lightning.png deleted file mode 100644 index 36e844cc..00000000 Binary files a/test_lightning.png and /dev/null differ