model setup [WIP: do not merge] #181

Closed · wants to merge 3 commits
2 changes: 1 addition & 1 deletion setup.sh
@@ -110,4 +110,4 @@ else
fi

# Install maxdiffusion
-pip3 install -U . || echo "Failed to install maxdiffusion" >&2
+pip3 install -e . || echo "Failed to install maxdiffusion" >&2
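Note: switching from `pip3 install -U .` to `pip3 install -e .` makes this an editable install, so local changes under `src/maxdiffusion` take effect without reinstalling. A hedged way to confirm the package resolves to the working tree rather than `site-packages`:

```python
# Hedged check, assuming the editable install above succeeded:
# the imported package should live in the repository's src/maxdiffusion directory.
import os

import maxdiffusion

print(os.path.dirname(maxdiffusion.__file__))
```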
2 changes: 2 additions & 0 deletions src/maxdiffusion/__init__.py
@@ -373,6 +373,7 @@
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
+_import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
@@ -453,6 +454,7 @@
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
+from .models.ltx_video.transformers.transformer3d import Transformer3DModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (
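With both the lazy `_import_structure` entry and the eager import in place, the LTX-Video transformer is importable from the package root. A minimal sketch, assuming the package is installed (for example via the editable install from `setup.sh`):

```python
# Hedged sketch of the import path enabled by the __init__.py change above.
from maxdiffusion import Transformer3DModel  # lazy, resolved through _import_structure

# Equivalent direct path, matching the new _import_structure entry:
# from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
```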
9 changes: 7 additions & 2 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -213,8 +213,13 @@ def load_state_if_possible(
max_logging.log(f"restoring from this run's directory latest step {latest_step}")
try:
if not enable_single_replica_ckpt_restoring:
-item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
-return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
+# A single-space checkpoint_item is used as a sentinel for checkpoints saved as a bare
+# PyTree without a named item (e.g. the LTX-Video transformer); restore it directly.
+if checkpoint_item == " ":
+  return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.StandardRestore(abstract_unboxed_pre_state))
+else:
+  item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
+  return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))

def map_to_pspec(data):
pspec = data.sharding.spec
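A condensed sketch of the two restore paths above; the single-space sentinel and the argument types come from the diff, while the standalone function wrapper is only for illustration:

```python
# Hedged sketch of the branching added to load_state_if_possible (names mirror the diff above).
import orbax.checkpoint as ocp


def restore_latest(checkpoint_manager, latest_step, checkpoint_item, abstract_unboxed_pre_state):
  if checkpoint_item == " ":
    # Sentinel: the checkpoint was written as a bare PyTree (e.g. the LTX-Video transformer),
    # so it is restored directly against the abstract target state.
    return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
  # Default: the checkpoint stores a single named item inside a Composite.
  item = {checkpoint_item: ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
  return checkpoint_manager.restore(latest_step, args=ocp.args.Composite(**item))
```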
51 changes: 51 additions & 0 deletions src/maxdiffusion/configs/ltx_video.yml
@@ -0,0 +1,51 @@
#hardware
hardware: 'tpu'
skip_jax_distributed_system: False

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'


run_name: ''
output_dir: 'ltx-video-output'
save_config_to_gcs: False

#parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
logical_axis_rules: [
['batch', 'data'],
['activation_batch', ['data','fsdp']],
['activation_heads', 'tensor'],
['activation_kv', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1




learning_rate_schedule_steps: -1
max_train_steps: 500 #TODO: change this
pretrained_model_name_or_path: ''
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
jit_initializers: True
enable_single_replica_ckpt_restoring: False
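A rough sketch of how this config would be consumed, assuming `pyconfig` follows its usual convention (program name, then the config path, then `key=value` overrides; the `run_name` override below is illustrative). `create_device_mesh` and `mesh_axes` are used the same way in `generate_ltx_video.py` below:

```python
# Hedged sketch: load the new LTX-Video config and build the 3-axis mesh it declares.
from jax.sharding import Mesh

from maxdiffusion import pyconfig
from maxdiffusion.max_utils import create_device_mesh

pyconfig.initialize(["", "src/maxdiffusion/configs/ltx_video.yml", "run_name=ltx_smoke_test"])
config = pyconfig.config

devices_array = create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)  # axes: ('data', 'fsdp', 'tensor')
```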
92 changes: 92 additions & 0 deletions src/maxdiffusion/generate_ltx_video.py
@@ -0,0 +1,92 @@
from absl import app
from typing import Sequence
import jax
from flax import linen as nn
import json
from flax.linen import partitioning as nn_partitioning
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
import os
import functools
import jax.numpy as jnp
from maxdiffusion import pyconfig
from maxdiffusion.max_utils import (
create_device_mesh,
setup_initial_state,
get_memory_allocations,
)
from jax.sharding import Mesh, PartitionSpec as P
import orbax.checkpoint as ocp


def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids):
print("prompt_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
print("latents.shape: ", latents.shape, latents.dtype)
print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
print("segment_ids.shape: ", segment_ids.shape, segment_ids.dtype)
print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)

def run(config):
key = jax.random.PRNGKey(0)

devices_array = create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)

base_dir = os.path.dirname(__file__)

# Load the model config (includes the relative checkpoint path).
config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
with open(config_path, "r") as f:
model_config = json.load(f)
relative_ckpt_path = model_config["ckpt_path"]

ignored_keys = ["_class_name", "_diffusers_version", "_name_or_path", "causal_temporal_positioning", "in_channels", "ckpt_path"]
in_channels = model_config["in_channels"]
for name in ignored_keys:
if name in model_config:
del model_config[name]


transformer = Transformer3DModel(**model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh)
transformer_param_shapes = transformer.init_weights(in_channels, model_config['caption_channels'], eval_only=True)  # quick smoke test: eval_only=True returns the parameter shapes only

weights_init_fn = functools.partial(
transformer.init_weights,
in_channels,
model_config['caption_channels'],
eval_only = True
)

absolute_ckpt_path = os.path.abspath(relative_ckpt_path)

checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path)
transformer_state, transformer_state_shardings = setup_initial_state(
model=transformer,
tx=None,
config=config,
mesh=mesh,
weights_init_fn=weights_init_fn,
checkpoint_manager=checkpoint_manager,
checkpoint_item=" ",
model_params=None,
training=False,
)





def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
run(pyconfig.config)


if __name__ == "__main__":
app.run(main)
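Since the script is driven through absl and `pyconfig`, it would presumably be invoked with the config path and any overrides on the command line, e.g. `python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml run_name=ltx_test` (the override is illustrative). The same entry point can be exercised programmatically:

```python
# Hedged sketch: call the script's run() directly with a loaded config.
from maxdiffusion import pyconfig
from maxdiffusion.generate_ltx_video import run

pyconfig.initialize(["", "src/maxdiffusion/configs/ltx_video.yml", "run_name=ltx_test"])
run(pyconfig.config)
```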





6 changes: 5 additions & 1 deletion src/maxdiffusion/max_utils.py
@@ -402,7 +402,11 @@ def setup_initial_state(
config.enable_single_replica_ckpt_restoring,
)
if state:
-state = state[checkpoint_item]
+if checkpoint_item == " ":
+  # Single-space sentinel: the restored object is already the state itself, not a Composite keyed by item name.
+  state = state
+else:
+  state = state[checkpoint_item]
if not state:
max_logging.log(f"Could not find the item in orbax, creating state...")
init_train_state_partial = functools.partial(
4 changes: 2 additions & 2 deletions src/maxdiffusion/models/__init__.py
@@ -14,7 +14,7 @@

from typing import TYPE_CHECKING

-from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
+from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available


_import_structure = {}
@@ -32,7 +32,7 @@
from .vae_flax import FlaxAutoencoderKL
from .lora import *
from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
-
+from .ltx_video.transformers.transformer3d import Transformer3DModel
else:
import sys

2 changes: 1 addition & 1 deletion src/maxdiffusion/models/attention_flax.py
@@ -1188,4 +1188,4 @@ def setup(self):
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.proj(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
+return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
Empty file.
70 changes: 70 additions & 0 deletions src/maxdiffusion/models/ltx_video/gradient_checkpoint.py
@@ -0,0 +1,70 @@
from enum import Enum, auto
from typing import Optional

import jax
from flax import linen as nn

SKIP_GRADIENT_CHECKPOINT_KEY = "skip"


class GradientCheckpointType(Enum):
"""
Defines which gradient checkpointing policy to apply

NONE - no gradient checkpointing
FULL - full gradient checkpointing wherever possible (minimum memory usage)
MATMUL_WITHOUT_BATCH - gradient checkpointing for every linear/matmul operation,
except those involving the batch dimension; attention and projection layers are
checkpointed, but the backward pass with respect to the parameters is not
"""

NONE = auto()
FULL = auto()
MATMUL_WITHOUT_BATCH = auto()

@classmethod
def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
"""
Constructs the gradient checkpoint type from a string

Args:
s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None.

Returns:
GradientCheckpointType: The policy that corresponds to the string
"""
if s is None:
s = "none"
return GradientCheckpointType[s.upper()]

def to_jax_policy(self):
"""
Converts the gradient checkpoint type to a jax policy
"""
match self:
case GradientCheckpointType.NONE:
return SKIP_GRADIENT_CHECKPOINT_KEY
case GradientCheckpointType.FULL:
return None
case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

def apply(self, module: nn.Module) -> nn.Module:
"""
Applies the gradient checkpointing policy to a module.
If no rematerialization is needed, the module is returned unchanged.

Args:
module (nn.Module): the module to apply the policy to

Returns:
nn.Module: the module with the policy applied
"""
policy = self.to_jax_policy()
if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
return module
return nn.remat( # pylint: disable=invalid-name
module,
prevent_cse=False,
policy=policy,
)
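Since the helper is self-contained, a short usage sketch follows; `ToyBlock` is purely illustrative, everything else uses only what the file above defines:

```python
# Hedged usage sketch for GradientCheckpointType (ToyBlock is a made-up module).
import jax
import jax.numpy as jnp
from flax import linen as nn

from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType


class ToyBlock(nn.Module):
  features: int = 16

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)


# "matmul_without_batch" maps to jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims;
# "none" (or None) skips rematerialization and returns the module unchanged.
RematBlock = GradientCheckpointType.from_str("matmul_without_batch").apply(ToyBlock)

x = jnp.ones((2, 16))
variables = RematBlock(features=16).init(jax.random.PRNGKey(0), x)
y = RematBlock(features=16).apply(variables, x)
```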