diff --git a/opensora/train/train_t2v_diffusers.py b/opensora/train/train_t2v_diffusers.py index 26bb1abab..a0e5ec9d5 100644 --- a/opensora/train/train_t2v_diffusers.py +++ b/opensora/train/train_t2v_diffusers.py @@ -47,7 +47,7 @@ from transformers.utils import ContextManagers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import DistributedType, ProjectConfiguration, set_seed +from accelerate.utils import DistributedType, ProjectConfiguration, set_seed, DeepSpeedPlugin from accelerate.state import AcceleratorState from packaging import version from tqdm.auto import tqdm @@ -56,7 +56,7 @@ import diffusers from diffusers import DDPMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel, compute_snr +from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 @@ -69,6 +69,7 @@ from opensora.models.diffusion import Diffusion_models, Diffusion_models_class from opensora.utils.dataset_utils import Collate, LengthGroupedSampler from opensora.utils.utils import explicit_uniform_sampling +from opensora.utils.ema import EMAModel from opensora.sample.pipeline_opensora import OpenSoraPipeline from opensora.models.causalvideovae import ae_stride_config, ae_wrapper @@ -78,6 +79,8 @@ check_min_version("0.24.0") logger = get_logger(__name__) +GB = 1024 * 1024 * 1024 + @torch.inference_mode() def log_validation(args, model, vae, text_encoder, tokenizer, accelerator, weight_dtype, global_step, ema=False): positive_prompt = "(masterpiece), (best quality), (ultra-detailed), {}. emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, sharp focus, high budget, cinemascope, moody, epic, gorgeous" @@ -170,11 +173,26 @@ def main(args): npu_config.seed_everything(args.seed) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + # https://huggingface.co/docs/accelerate/v1.0.0/en/usage_guides/deepspeed_multiple_model#using-multiple-models-with-deepspeed + train_plugin = DeepSpeedPlugin(hf_ds_config=args.train_deepspeed_config_file) + text_encoder_plugin = DeepSpeedPlugin(hf_ds_config=args.eval_deepspeed_config_file) + text_encoder_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size + + # Use FP32 for EMA model + ema_plugin = DeepSpeedPlugin(hf_ds_config=args.eval_deepspeed_config_file) + ema_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size + ema_plugin.deepspeed_config["bf16"]["enabled"] = False + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, + deepspeed_plugins={ + "train": train_plugin, + "text_encoder": text_encoder_plugin, + "ema": ema_plugin, + }, ) if args.num_frames != 1: @@ -219,44 +237,37 @@ def main(args): elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Create model: - - # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. - # For this to work properly all models must be run through `accelerate.prepare`. 
But accelerate - # will try to assign the same optimizer with the same weights to all models during - # `deepspeed.initialize`, which of course doesn't work. - # - # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 - # frozen models from being partitioned during `zero.Init` which gets called during - # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding - # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. - def deepspeed_zero_init_disabled_context_manager(): - """ - returns either a context list that includes one that will disable zero.Init or an empty context list - """ - deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None - if deepspeed_plugin is None: - return [] - - return [deepspeed_plugin.zero3_init_context_manager(enable=False)] - - with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - kwargs = {} - ae = ae_wrapper[args.ae](args.ae_path, cache_dir=args.cache_dir, **kwargs).eval() - - if args.enable_tiling: - ae.vae.enable_tiling() - - kwargs = { - 'torch_dtype': weight_dtype, - 'low_cpu_mem_usage': False - } - text_enc_1 = get_text_warpper(args.text_encoder_name_1)(args, **kwargs).eval() + checkpoint_path, global_step = None, 0 + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + checkpoint_path = args.resume_from_checkpoint + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + checkpoint_path = os.path.join(args.output_dir, dirs[-1]) if len(dirs) > 0 else None - text_enc_2 = None - if args.text_encoder_name_2 is not None: - text_enc_2 = get_text_warpper(args.text_encoder_name_2)(args, **kwargs).eval() + if checkpoint_path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + else: + accelerator.print(f"Resuming from checkpoint {checkpoint_path}") + global_step = int(checkpoint_path.split("-")[-1]) + + # STEP 1: Prepare VAE model + kwargs = {} + ae = ae_wrapper[args.ae](args.ae_path, cache_dir=args.cache_dir, **kwargs).eval() + if args.enable_tiling: + ae.vae.enable_tiling() + ae.vae.requires_grad_(False) + ae.vae.to(accelerator.device, dtype=weight_dtype) + logger.info(f"Load VAE model finish, memory_allocated: {torch.cuda.memory_allocated()/GB:.2f} GB", main_process_only=True) + # STEP 2: Prepare diffusion model (Trained model must be prepared first!) + accelerator.state.select_deepspeed_plugin("train") ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae] ae.vae_scale_factor = (ae_stride_t, ae_stride_h, ae_stride_w) assert ae_stride_h == ae_stride_w, f"Support only ae_stride_h == ae_stride_w now, but found ae_stride_h ({ae_stride_h}), ae_stride_w ({ae_stride_w})" @@ -317,13 +328,6 @@ def deepspeed_zero_init_disabled_context_manager(): print(f'Successfully load {len(model_state_dict) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!') model.gradient_checkpointing = args.gradient_checkpointing - # Freeze vae and text encoders. - ae.vae.requires_grad_(False) - text_enc_1.requires_grad_(False) - if text_enc_2 is not None: - text_enc_2.requires_grad_(False) - # Set model as trainable. 
- model.train() kwargs = dict( prediction_type=args.prediction_type, @@ -341,62 +345,15 @@ def deepspeed_zero_init_disabled_context_manager(): noise_scheduler_copy = copy.deepcopy(noise_scheduler) else: noise_scheduler = DDPMScheduler(**kwargs) - # Move unet, vae and text_encoder to device and cast to weight_dtype - # The VAE is in float32 to avoid NaN losses. - if not args.extra_save_mem: - ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else weight_dtype) - text_enc_1.to(accelerator.device, dtype=weight_dtype) - if text_enc_2 is not None: - text_enc_2.to(accelerator.device, dtype=weight_dtype) # Create EMA for the unet. if args.use_ema: - ema_model = deepcopy(model) - ema_model = EMAModel(ema_model.parameters(), decay=args.ema_decay, update_after_step=args.ema_start_step, - model_cls=Diffusion_models_class[args.model], model_config=ema_model.config, - foreach=args.foreach_ema) - - # `accelerate` 0.16.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.16.0"): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - if accelerator.is_main_process: - if args.use_ema: - ema_model.save_pretrained(os.path.join(output_dir, "model_ema")) - - for i, model in enumerate(models): - model.save_pretrained(os.path.join(output_dir, "model")) - if weights: # Don't pop if empty - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - def load_model_hook(models, input_dir): - if args.use_ema: - load_model = EMAModel.from_pretrained( - os.path.join(input_dir, "model_ema"), - Diffusion_models_class[args.model], - foreach=args.foreach_ema, - ) - ema_model.load_state_dict(load_model.state_dict()) - if args.offload_ema: - ema_model.pin_memory() - else: - ema_model.to(accelerator.device) - del load_model - - for i in range(len(models)): - # pop models so that they are not loaded again - model = models.pop() - - # load diffusers style into model - load_model = Diffusion_models_class[args.model].from_pretrained(input_dir, subfolder="model") - model.register_to_config(**load_model.config) - - model.load_state_dict(load_model.state_dict()) - del load_model - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) + if checkpoint_path: + ema_model = EMAModel.from_pretrained(os.path.join(checkpoint_path, "model_ema"), + model_cls=Diffusion_models_class[args.model]) + else: + ema_model = EMAModel(deepcopy(model), decay=args.ema_decay, update_after_step=args.ema_start_step, + model_cls=Diffusion_models_class[args.model], model_config=model.config) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices @@ -523,12 +480,34 @@ def load_model_hook(models, input_dir): model, optimizer, train_dataloader, lr_scheduler ) logger.info(f'after accelerator.prepare') - + if checkpoint_path: + accelerator.load_state(checkpoint_path) + logger.info(f"Load diffusion model finish, memory_allocated: {torch.cuda.memory_allocated()/GB:.2f} GB", main_process_only=True) + + # STEP 3: Prepare text encoder model + accelerator.state.select_deepspeed_plugin("text_encoder") + text_enc_1 = get_text_warpper(args.text_encoder_name_1)(args) + text_enc_1 = accelerator.prepare(text_enc_1) + text_enc_1.eval() + logger.info(f"Load text encoder model 1 finish, memory_allocated: 
{torch.cuda.memory_allocated()/GB:.2f} GB", main_process_only=True) + + text_enc_2 = None + if args.text_encoder_name_2 is not None: + text_enc_2 = get_text_warpper(args.text_encoder_name_2)(args) + text_enc_2 = accelerator.prepare(text_enc_2) + text_enc_2.eval() + logger.info(f"Load text encoder model 2 finish, memory_allocated: {torch.cuda.memory_allocated()/GB:.2f} GB", main_process_only=True) + + # STEP 4: Prepare EMA model + accelerator.state.select_deepspeed_plugin("ema") if args.use_ema: - if args.offload_ema: - ema_model.pin_memory() - else: - ema_model.to(accelerator.device) + ema_model.model = accelerator.prepare(ema_model.model) + ema_model.model.eval() + logger.info(f"Load EMA model finish, memory_allocated: {torch.cuda.memory_allocated()/GB:.2f} GB", main_process_only=True) + + # STEP 5: All models have been prepared, start training + accelerator.state.select_deepspeed_plugin("train") + model.train() # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -562,40 +541,12 @@ def load_model_hook(models, input_dir): logger.info(f" Text_enc_1 = {args.text_encoder_name_1}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_1.parameters()) / 1e9} B") if args.text_encoder_name_2 is not None: logger.info(f" Text_enc_2 = {args.text_encoder_name_2}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_2.parameters()) / 1e9} B") - - global_step = 0 - first_epoch = 0 - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - accelerator.print( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - args.resume_from_checkpoint = None - initial_global_step = 0 - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - initial_global_step = global_step - first_epoch = global_step // num_update_steps_per_epoch - - else: - initial_global_step = 0 + if args.use_ema: + logger.info(f" EMA model = {type(ema_model.model)}; Dtype = {ema_model.model.dtype}; Parameters = {sum(p.numel() for p in ema_model.model.parameters()) / 1e9} B") progress_bar = tqdm( range(0, args.max_train_steps), - initial=initial_global_step, + initial=global_step, desc="Steps", # Only show the progress bar once on each machine. 
disable=not accelerator.is_local_main_process, @@ -659,7 +610,13 @@ def sync_gradients_info(loss): shutil.rmtree(removing_checkpoint) save_path = os.path.join(args.output_dir, f"checkpoint-{progress_info.global_step}") + # FIXME: https://github.com/huggingface/accelerate/issues/3140 + acc_models = accelerator._models + accelerator._models = [model for model in acc_models if model.checkpoint_engine is not None] accelerator.save_state(save_path) + accelerator._models = acc_models + if args.use_ema: + ema_model.save_pretrained(os.path.join(save_path, "model_ema")) logger.info(f"Saved state to {save_path}") logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} @@ -841,12 +798,6 @@ def train_one_step(step_, data_item_, prof_=None): x, attn_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = data_item_ # print(f'step: {step_}, rank: {accelerator.process_index}, x: {x.shape}, dtype: {x.dtype}') # assert not torch.any(torch.isnan(x)), 'torch.any(torch.isnan(x))' - if args.extra_save_mem: - torch.cuda.empty_cache() - ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else weight_dtype) - text_enc_1.to(accelerator.device, dtype=weight_dtype) - if text_enc_2 is not None: - text_enc_2.to(accelerator.device, dtype=weight_dtype) x = x.to(accelerator.device, dtype=ae.vae.dtype) # B C T H W # x = x.to(accelerator.device, dtype=torch.float32) # B C T H W @@ -891,12 +842,6 @@ def train_one_step(step_, data_item_, prof_=None): # import sys;sys.exit() # print("rank {} | step {} | after encode".format(accelerator.process_index, step_)) - if args.extra_save_mem: - ae.vae.to('cpu') - text_enc_1.to('cpu') - if text_enc_2 is not None: - text_enc_2.to('cpu') - torch.cuda.empty_cache() current_step_frame = x.shape[2] current_step_sp_state = get_sequence_parallel_state() @@ -1009,6 +954,10 @@ def train_one_epoch(prof_=None): if __name__ == "__main__": parser = argparse.ArgumentParser() + # deepspeed + parser.add_argument("--train_deepspeed_config_file", type=str, required=True, help="deepspeed config file for training, e.g diffusion model") + parser.add_argument("--eval_deepspeed_config_file", type=str, required=True, help="deepspeed config file for evaluation, e.g text encoder, EMA model") + # dataset & dataloader parser.add_argument("--dataset", type=str, required=True) parser.add_argument("--data", type=str, required='') @@ -1035,7 +984,6 @@ def train_one_epoch(prof_=None): # text encoder & vae & diffusion model parser.add_argument('--vae_fp32', action='store_true') - parser.add_argument('--extra_save_mem', action='store_true') parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="Latte-XL/122") parser.add_argument('--enable_tiling', action='store_true') parser.add_argument('--interpolation_scale_h', type=float, default=1.0) diff --git a/opensora/utils/ema.py b/opensora/utils/ema.py index a906efee9..91df866c0 100644 --- a/opensora/utils/ema.py +++ b/opensora/utils/ema.py @@ -1,6 +1,6 @@ -import contextlib +import os import copy -import random +import math from typing import Any, Dict, Iterable, List, Optional, Union from diffusers.utils import ( @@ -17,6 +17,12 @@ import numpy as np import torch +import deepspeed +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + +def _z3_params_to_fetch(param_list): + return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] # Adapted from diffusers-style ema 
https://github.com/huggingface/diffusers/blob/main/src/diffusers/training_utils.py#L263 @@ -27,7 +33,7 @@ class EMAModel: def __init__( self, - parameters: Iterable[torch.nn.Parameter], + model: torch.nn.Module, decay: float = 0.9999, min_decay: float = 0.0, update_after_step: int = 0, @@ -57,22 +63,8 @@ def __init__( gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 215.4k steps). """ - - if isinstance(parameters, torch.nn.Module): - deprecation_message = ( - "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " - "Please pass the parameters of the module instead." - ) - deprecate( - "passing a `torch.nn.Module` to `ExponentialMovingAverage`", - "1.0.0", - deprecation_message, - standard_warn=False, - ) - parameters = parameters.parameters() - - # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility - use_ema_warmup = True + + self.model = model if kwargs.get("max_value", None) is not None: deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." @@ -84,9 +76,6 @@ def __init__( deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) min_decay = kwargs["min_value"] - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters] - if kwargs.get("device", None) is not None: deprecation_message = "The `device` argument is deprecated. Please use `to` instead." deprecate("device", "1.0.0", deprecation_message, standard_warn=False) @@ -131,7 +120,7 @@ def from_pretrained(cls, path, model_cls) -> "EMAModel": ema_kwargs = cls.extract_ema_kwargs(config) model = model_cls.from_pretrained(path) - ema_model = cls(model.parameters(), model_cls=model_cls, model_config=config) + ema_model = cls(model, model_cls=model_cls, model_config=config) ema_model.load_state_dict(ema_kwargs) return ema_model @@ -143,13 +132,24 @@ def save_pretrained(self, path): if self.model_config is None: raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") - model = self.model_cls.from_config(self.model_config) + rank = int(os.getenv("RANK", "0")) state_dict = self.state_dict() - state_dict.pop("shadow_params", None) - - model.register_to_config(**state_dict) - self.copy_to(model.parameters()) - model.save_pretrained(path) + state_dict.pop("model") + + model_to_save = self.model.module if hasattr(self.model, "module") else self.model + model_state_dict = {} + for k, v in model_to_save.named_parameters(): + # only gather z3 params + params_to_fetch = _z3_params_to_fetch([v]) + with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): + vv = v.data.cpu() + if rank == 0: + model_state_dict[k] = vv + + if rank == 0: + self.model.register_to_config(**state_dict) + self.model.save_config(path) + torch.save(model_state_dict, os.path.join(path, "diffusion_pytorch_model.bin")) def get_decay(self, optimization_step: int) -> float: """ @@ -194,19 +194,38 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): self.cur_decay_value = decay one_minus_decay = 1 - decay - context_manager = contextlib.nullcontext - if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): - import deepspeed - - for s_param, param in zip(self.shadow_params, parameters): - if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): - context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) - - with 
context_manager(): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) + # https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/partition_parameters.py#L1543 + for s_param, param in zip(self.model.parameters(), parameters): + s_tensor, tensor = None, None + if hasattr(s_param, "ds_tensor"): # EMA ZeRO-3 + s_tensor = s_param.ds_tensor + if hasattr(param, "ds_tensor"): # DiT ZeRO-3 + tensor = param.ds_tensor + else: # DiT ZeRO-2 + rank, world_size = int(os.getenv("RANK")), int(os.getenv("WORLD_SIZE")) + partition_size = math.ceil(param.numel()/world_size) + start = partition_size * rank + end = start + partition_size + + one_dim_param = param.data.contiguous().view(-1) + if start < param.numel() and end <= param.numel(): + tensor = one_dim_param.narrow(0, start, partition_size) + elif start < param.numel(): + elems_to_copy = param.numel() - start + s_tensor = s_param.ds_tensor.narrow(0, 0, elems_to_copy) + tensor = one_dim_param.narrow(0, start, elems_to_copy) + else: + continue + else: # DiT/EMA ZeRO-2 + s_tensor = s_param.data + tensor = param.data + + assert s_tensor.shape == tensor.shape, f"mismatch shape, s_tensor: {s_tensor.shape}, tensor: {tensor.shape}" + + if param.requires_grad: + s_tensor.sub_(one_minus_decay * (s_tensor - tensor)) + else: + s_tensor.copy_(tensor) def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: """ @@ -218,7 +237,7 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: `ExponentialMovingAverage` was initialized will be used. """ parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): + for s_param, param in zip(self.model.parameters(), parameters): param.data.copy_(s_param.to(param.device).data) @@ -229,10 +248,7 @@ def to(self, device=None, dtype=None) -> None: device: like `device` argument to `torch.Tensor.to` """ # .to() on the tensors handles None correctly - self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) - for p in self.shadow_params - ] + self.model = self.model.to(device=device, dtype=dtype) def state_dict(self) -> dict: r""" @@ -250,7 +266,7 @@ def state_dict(self) -> dict: "use_ema_warmup": self.use_ema_warmup, "inv_gamma": self.inv_gamma, "power": self.power, - "shadow_params": self.shadow_params, + "model": self.model.state_dict(), } def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: @@ -319,10 +335,19 @@ def load_state_dict(self, state_dict: dict) -> None: if not isinstance(self.power, (float, int)): raise ValueError("Invalid power") - shadow_params = state_dict.get("shadow_params", None) - if shadow_params is not None: - self.shadow_params = shadow_params - if not isinstance(self.shadow_params, list): - raise ValueError("shadow_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): - raise ValueError("shadow_params must all be Tensors") + model_state_dict = state_dict.get("model", None) + if model_state_dict is not None: + self.model.load_state_dict(model_state_dict) + + +if __name__ == "__main__": + import ipdb + from opensora.models.diffusion.opensora_v1_3.modeling_opensora import OpenSoraT2V_v1_3 + + model_path = "" + ema_model = EMAModel.from_pretrained(model_path, OpenSoraT2V_v1_3) + ipdb.set_trace() + + save_path = "" + ema_model.save_pretrained(save_path) + ema_model2 = EMAModel.from_pretrained(save_path, OpenSoraT2V_v1_3) diff --git a/pyproject.toml b/pyproject.toml index 
846449483..95a562c6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "timm==0.9.16", "torchdiffeq==0.2.3", "torchmetrics==1.3.2", "tqdm==4.66.5", "urllib3==2.2.2", "uvicorn==0.27.1", "scikit-video==1.1.11", "imageio-ffmpeg==0.4.9", "sentencepiece==0.1.99", "beautifulsoup4==4.12.3", "ftfy==6.1.3", "moviepy==1.0.3", "wandb==0.16.3", "tensorboard==2.14.0", "pydantic==2.6.4", "gradio==4.0.0", - "torch==2.1.0", "torchvision==0.16.0", "xformers==0.0.22.post7", "accelerate==0.34.0", "diffusers==0.30.2", "deepspeed==0.12.6" + "torch==2.1.0", "torchvision==0.16.0", "xformers==0.0.22.post7", "accelerate==1.0.0", "diffusers==0.30.2", "deepspeed==0.12.6" ] [project.optional-dependencies] diff --git a/scripts/accelerate_configs/zero3.json b/scripts/accelerate_configs/zero3.json index ef467e4d5..c1b9d640c 100644 --- a/scripts/accelerate_configs/zero3.json +++ b/scripts/accelerate_configs/zero3.json @@ -22,7 +22,7 @@ "sub_group_size": 1e9, "reduce_bucket_size": 5e8, "stage3_prefetch_bucket_size": 5e8, - "stage3_param_persistence_threshold": "auto", + "stage3_param_persistence_threshold": 1e5, "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9, "stage3_gather_16bit_weights_on_model_save": true diff --git a/scripts/text_condition/npu/train_t2v_v1_3.sh b/scripts/text_condition/npu/train_t2v_v1_3.sh index f1b7427d9..59844c745 100644 --- a/scripts/text_condition/npu/train_t2v_v1_3.sh +++ b/scripts/text_condition/npu/train_t2v_v1_3.sh @@ -17,6 +17,8 @@ export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True accelerate launch \ --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ opensora/train/train_t2v_diffusers.py \ + --train_deepspeed_config_file scripts/accelerate_configs/zero2.json \ + --eval_deepspeed_config_file scripts/accelerate_configs/zero3.json \ --model OpenSoraT2V_v1_3-2B/122 \ --text_encoder_name_1 google/mt5-xxl \ --cache_dir "../../cache_dir/" \
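
Note on the multi-plugin setup introduced at the top of `train_t2v_diffusers.py`: with `accelerate>=1.0`, each model can be prepared under its own `DeepSpeedPlugin`, and the trainable model must go through `accelerator.prepare` before the frozen ones. Below is a minimal sketch of that pattern rather than the actual Open-Sora code: the config paths, batch size, and `nn.Linear` stand-ins are placeholders, and it assumes the JSON configs hold plain ZeRO settings (no `"auto"` fields) and that the script runs under `accelerate launch`.

```python
import torch
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

train_plugin = DeepSpeedPlugin(hf_ds_config="zero2.json")    # trained diffusion model
frozen_plugin = DeepSpeedPlugin(hf_ds_config="zero3.json")   # frozen / eval-only models
# Set explicitly here because no dataloader is prepared in this sketch; the patch
# only needs to set it for the frozen text-encoder and EMA plugins.
for plugin in (train_plugin, frozen_plugin):
    plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 1

accelerator = Accelerator(
    deepspeed_plugins={"train": train_plugin, "text_encoder": frozen_plugin},
)

model = torch.nn.Linear(8, 8)          # stands in for the diffusion transformer
text_encoder = torch.nn.Linear(8, 8)   # stands in for the frozen text encoder
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# The trainable model is prepared first, under its own plugin.
accelerator.state.select_deepspeed_plugin("train")
model, optimizer = accelerator.prepare(model, optimizer)

# Each frozen model is then prepared under the plugin selected for it.
accelerator.state.select_deepspeed_plugin("text_encoder")
text_encoder = accelerator.prepare(text_encoder)
text_encoder.eval()

# Switch back to the training plugin before entering the training loop.
accelerator.state.select_deepspeed_plugin("train")
model.train()
```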
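
The rewritten `EMAModel.save_pretrained` gathers ZeRO-3-partitioned parameters one at a time and keeps a CPU copy only on rank 0 before writing `diffusion_pytorch_model.bin`. The gather itself boils down to the pattern below; `gather_full_state_dict` is a hypothetical helper shown for illustration, not part of the patch.

```python
import os
import torch
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus


def gather_full_state_dict(model):
    """Collect a full state dict from a possibly ZeRO-3 partitioned model."""
    unwrapped = model.module if hasattr(model, "module") else model
    rank = int(os.getenv("RANK", "0"))
    state_dict = {}
    for name, param in unwrapped.named_parameters():
        # Only parameters that ZeRO-3 has partitioned away need to be gathered.
        to_fetch = [
            p for p in [param]
            if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
        ]
        with deepspeed.zero.GatheredParameters(to_fetch, enabled=len(to_fetch) > 0):
            full = param.data.cpu()   # full tensor is materialized inside the context
            if rank == 0:
                state_dict[name] = full
    return state_dict                  # empty dict on non-zero ranks
```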
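
In `EMAModel.step`, the ZeRO-2 branch has to line up a full (unpartitioned) trained parameter with the EMA side's ZeRO-3 shard (`ds_tensor`), so it slices the flattened parameter by rank, with a shorter slice on the last rank that owns the ragged tail. A small self-contained example of that slice arithmetic, using made-up `rank`/`world_size` values:

```python
import math
import torch

param = torch.arange(10, dtype=torch.float32)            # full parameter, numel() == 10
world_size, rank = 4, 3                                  # pretend we are the last rank
partition_size = math.ceil(param.numel() / world_size)   # 3 elements per rank
start = partition_size * rank                            # rank 3 starts at element 9
end = start + partition_size                             # 12 > numel(): ragged tail

flat = param.data.contiguous().view(-1)
if start < param.numel() and end <= param.numel():
    local = flat.narrow(0, start, partition_size)            # full-sized shard
elif start < param.numel():
    local = flat.narrow(0, start, param.numel() - start)     # shorter final shard
else:
    local = None                                             # rank owns only padding

print(local)  # tensor([9.]) -- the single trailing element owned by rank 3
```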