Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 95 additions & 147 deletions opensora/train/train_t2v_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from transformers.utils import ContextManagers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
from accelerate.utils import DistributedType, ProjectConfiguration, set_seed, DeepSpeedPlugin
from accelerate.state import AcceleratorState
from packaging import version
from tqdm.auto import tqdm
Expand All @@ -56,7 +56,7 @@
import diffusers
from diffusers import DDPMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3

Expand All @@ -69,6 +69,7 @@
from opensora.models.diffusion import Diffusion_models, Diffusion_models_class
from opensora.utils.dataset_utils import Collate, LengthGroupedSampler
from opensora.utils.utils import explicit_uniform_sampling
from opensora.utils.ema import EMAModel
from opensora.sample.pipeline_opensora import OpenSoraPipeline
from opensora.models.causalvideovae import ae_stride_config, ae_wrapper

Expand All @@ -78,6 +79,8 @@
check_min_version("0.24.0")
logger = get_logger(__name__)

GB = 1024 * 1024 * 1024

@torch.inference_mode()
def log_validation(args, model, vae, text_encoder, tokenizer, accelerator, weight_dtype, global_step, ema=False):
positive_prompt = "(masterpiece), (best quality), (ultra-detailed), {}. emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, sharp focus, high budget, cinemascope, moody, epic, gorgeous"
Expand Down Expand Up @@ -170,11 +173,26 @@ def main(args):
npu_config.seed_everything(args.seed)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

# https://huggingface.co/docs/accelerate/v1.0.0/en/usage_guides/deepspeed_multiple_model#using-multiple-models-with-deepspeed
train_plugin = DeepSpeedPlugin(hf_ds_config=args.train_deepspeed_config_file)
text_encoder_plugin = DeepSpeedPlugin(hf_ds_config=args.eval_deepspeed_config_file)
text_encoder_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size

# Use FP32 for EMA model
ema_plugin = DeepSpeedPlugin(hf_ds_config=args.eval_deepspeed_config_file)
ema_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
ema_plugin.deepspeed_config["bf16"]["enabled"] = False

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
deepspeed_plugins={
"train": train_plugin,
"text_encoder": text_encoder_plugin,
"ema": ema_plugin,
},
)

if args.num_frames != 1:
Expand Down Expand Up @@ -219,44 +237,37 @@ def main(args):
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

# Create model:

# Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
# For this to work properly all models must be run through `accelerate.prepare`. But accelerate
# will try to assign the same optimizer with the same weights to all models during
# `deepspeed.initialize`, which of course doesn't work.
#
# For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
# frozen models from being partitioned during `zero.Init` which gets called during
# `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
# across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
def deepspeed_zero_init_disabled_context_manager():
    """
    Return a list of context managers that disables DeepSpeed `zero.Init`,
    or an empty list when DeepSpeed is not in use.

    Used to wrap `from_pretrained` calls for frozen models (VAE, text
    encoders) so their parameters are NOT partitioned under ZeRO stage 3;
    only the trainable diffusion model gets sharded.
    """
    # NOTE(review): relies on the module-level `accelerate` import and on
    # AcceleratorState having been initialized before this runs — confirm
    # against the full file (only a diff fragment is visible here).
    deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
    if deepspeed_plugin is None:
        # No DeepSpeed plugin active: nothing to disable.
        return []

    # Context manager that temporarily disables zero.Init partitioning.
    return [deepspeed_plugin.zero3_init_context_manager(enable=False)]

with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
kwargs = {}
ae = ae_wrapper[args.ae](args.ae_path, cache_dir=args.cache_dir, **kwargs).eval()

if args.enable_tiling:
ae.vae.enable_tiling()

kwargs = {
'torch_dtype': weight_dtype,
'low_cpu_mem_usage': False
}
text_enc_1 = get_text_warpper(args.text_encoder_name_1)(args, **kwargs).eval()
checkpoint_path, global_step = None, 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
checkpoint_path = args.resume_from_checkpoint
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
checkpoint_path = os.path.join(args.output_dir, dirs[-1]) if len(dirs) > 0 else None

text_enc_2 = None
if args.text_encoder_name_2 is not None:
text_enc_2 = get_text_warpper(args.text_encoder_name_2)(args, **kwargs).eval()
if checkpoint_path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
else:
accelerator.print(f"Resuming from checkpoint {checkpoint_path}")
global_step = int(checkpoint_path.split("-")[-1])

# STEP 1: Prepare VAE model
kwargs = {}
ae = ae_wrapper[args.ae](args.ae_path, cache_dir=args.cache_dir, **kwargs).eval()
if args.enable_tiling:
ae.vae.enable_tiling()
ae.vae.requires_grad_(False)
ae.vae.to(accelerator.device, dtype=weight_dtype)
logger.info(f"Load VAE model finish, memory_allocated: {torch.cuda.memory_allocated()/GB:.2f} GB", main_process_only=True)

# STEP 2: Prepare diffusion model (Trained model must be prepared first!)
accelerator.state.select_deepspeed_plugin("train")
ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae]
ae.vae_scale_factor = (ae_stride_t, ae_stride_h, ae_stride_w)
assert ae_stride_h == ae_stride_w, f"Support only ae_stride_h == ae_stride_w now, but found ae_stride_h ({ae_stride_h}), ae_stride_w ({ae_stride_w})"
Expand Down Expand Up @@ -317,13 +328,6 @@ def deepspeed_zero_init_disabled_context_manager():
print(f'Successfully load {len(model_state_dict) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!')

model.gradient_checkpointing = args.gradient_checkpointing
# Freeze vae and text encoders.
ae.vae.requires_grad_(False)
text_enc_1.requires_grad_(False)
if text_enc_2 is not None:
text_enc_2.requires_grad_(False)
# Set model as trainable.
model.train()

kwargs = dict(
prediction_type=args.prediction_type,
Expand All @@ -341,62 +345,15 @@ def deepspeed_zero_init_disabled_context_manager():
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
else:
noise_scheduler = DDPMScheduler(**kwargs)
# Move unet, vae and text_encoder to device and cast to weight_dtype
# The VAE is in float32 to avoid NaN losses.
if not args.extra_save_mem:
ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else weight_dtype)
text_enc_1.to(accelerator.device, dtype=weight_dtype)
if text_enc_2 is not None:
text_enc_2.to(accelerator.device, dtype=weight_dtype)

# Create EMA for the unet.
if args.use_ema:
ema_model = deepcopy(model)
ema_model = EMAModel(ema_model.parameters(), decay=args.ema_decay, update_after_step=args.ema_start_step,
model_cls=Diffusion_models_class[args.model], model_config=ema_model.config,
foreach=args.foreach_ema)

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
    """
    Custom `accelerator.save_state` pre-hook: serialize the trainable model
    (and the EMA copy, if enabled) in diffusers `save_pretrained` format.

    `models`/`weights` are the parallel lists Accelerate passes to save
    hooks; popping from `weights` tells Accelerate the corresponding model
    has already been saved and must not be serialized again.
    """
    # Only rank 0 writes checkpoints to avoid concurrent-write corruption.
    if accelerator.is_main_process:
        if args.use_ema:
            # EMA weights go to a separate "model_ema" subfolder.
            ema_model.save_pretrained(os.path.join(output_dir, "model_ema"))

        # NOTE(review): every model is written to the same "model" subfolder;
        # with more than one tracked model the later saves would overwrite the
        # earlier ones — presumably only one model is registered here. Verify.
        for i, model in enumerate(models):
            model.save_pretrained(os.path.join(output_dir, "model"))
            if weights:  # Don't pop if empty
                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

def load_model_hook(models, input_dir):
    """
    Custom `accelerator.load_state` pre-hook: restore the trainable model
    (and the EMA copy, if enabled) from a diffusers-format checkpoint.

    Popping from `models` signals Accelerate that the model has been loaded
    here and should not be loaded again by the default path.
    """
    if args.use_ema:
        # Restore EMA weights from the "model_ema" subfolder, then copy the
        # state into the live EMA tracker and free the temporary object.
        load_model = EMAModel.from_pretrained(
            os.path.join(input_dir, "model_ema"),
            Diffusion_models_class[args.model],
            foreach=args.foreach_ema,
        )
        ema_model.load_state_dict(load_model.state_dict())
        if args.offload_ema:
            # Keep EMA on pinned host memory for cheap async device copies.
            ema_model.pin_memory()
        else:
            ema_model.to(accelerator.device)
        del load_model

    for i in range(len(models)):
        # pop models so that they are not loaded again
        model = models.pop()

        # load diffusers style into model
        load_model = Diffusion_models_class[args.model].from_pretrained(input_dir, subfolder="model")
        model.register_to_config(**load_model.config)

        model.load_state_dict(load_model.state_dict())
        del load_model

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
if checkpoint_path:
ema_model = EMAModel.from_pretrained(os.path.join(checkpoint_path, "model_ema"),
model_cls=Diffusion_models_class[args.model])
else:
ema_model = EMAModel(deepcopy(model), decay=args.ema_decay, update_after_step=args.ema_start_step,
model_cls=Diffusion_models_class[args.model], model_config=model.config)

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
Expand Down Expand Up @@ -523,12 +480,34 @@ def load_model_hook(models, input_dir):
model, optimizer, train_dataloader, lr_scheduler
)
logger.info(f'after accelerator.prepare')

if checkpoint_path:
accelerator.load_state(checkpoint_path)
logger.info(f"Load diffusion model finish, memory_allocated: {torch.cuda.memory_allocated()/GB:.2f} GB", main_process_only=True)

# STEP 3: Prepare text encoder model
accelerator.state.select_deepspeed_plugin("text_encoder")
text_enc_1 = get_text_warpper(args.text_encoder_name_1)(args)
text_enc_1 = accelerator.prepare(text_enc_1)
text_enc_1.eval()
logger.info(f"Load text encoder model 1 finish, memory_allocated: {torch.cuda.memory_allocated()/GB:.2f} GB", main_process_only=True)

text_enc_2 = None
if args.text_encoder_name_2 is not None:
text_enc_2 = get_text_warpper(args.text_encoder_name_2)(args)
text_enc_2 = accelerator.prepare(text_enc_2)
text_enc_2.eval()
logger.info(f"Load text encoder model 2 finish, memory_allocated: {torch.cuda.memory_allocated()/GB:.2f} GB", main_process_only=True)

# STEP 4: Prepare EMA model
accelerator.state.select_deepspeed_plugin("ema")
if args.use_ema:
if args.offload_ema:
ema_model.pin_memory()
else:
ema_model.to(accelerator.device)
ema_model.model = accelerator.prepare(ema_model.model)
ema_model.model.eval()
logger.info(f"Load EMA model finish, memory_allocated: {torch.cuda.memory_allocated()/GB:.2f} GB", main_process_only=True)

# STEP 5: All models have been prepared, start training
accelerator.state.select_deepspeed_plugin("train")
model.train()

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
Expand Down Expand Up @@ -562,40 +541,12 @@ def load_model_hook(models, input_dir):
logger.info(f" Text_enc_1 = {args.text_encoder_name_1}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_1.parameters()) / 1e9} B")
if args.text_encoder_name_2 is not None:
logger.info(f" Text_enc_2 = {args.text_encoder_name_2}; Dtype = {weight_dtype}; Parameters = {sum(p.numel() for p in text_enc_2.parameters()) / 1e9} B")

global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None

if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch

else:
initial_global_step = 0
if args.use_ema:
logger.info(f" EMA model = {type(ema_model.model)}; Dtype = {ema_model.model.dtype}; Parameters = {sum(p.numel() for p in ema_model.model.parameters()) / 1e9} B")

progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
initial=global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
Expand Down Expand Up @@ -659,7 +610,13 @@ def sync_gradients_info(loss):
shutil.rmtree(removing_checkpoint)

save_path = os.path.join(args.output_dir, f"checkpoint-{progress_info.global_step}")
# FIXME: https://github.com/huggingface/accelerate/issues/3140
acc_models = accelerator._models
accelerator._models = [model for model in acc_models if model.checkpoint_engine is not None]
accelerator.save_state(save_path)
accelerator._models = acc_models
if args.use_ema:
ema_model.save_pretrained(os.path.join(save_path, "model_ema"))
logger.info(f"Saved state to {save_path}")

logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
Expand Down Expand Up @@ -841,12 +798,6 @@ def train_one_step(step_, data_item_, prof_=None):
x, attn_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = data_item_
# print(f'step: {step_}, rank: {accelerator.process_index}, x: {x.shape}, dtype: {x.dtype}')
# assert not torch.any(torch.isnan(x)), 'torch.any(torch.isnan(x))'
if args.extra_save_mem:
torch.cuda.empty_cache()
ae.vae.to(accelerator.device, dtype=torch.float32 if args.vae_fp32 else weight_dtype)
text_enc_1.to(accelerator.device, dtype=weight_dtype)
if text_enc_2 is not None:
text_enc_2.to(accelerator.device, dtype=weight_dtype)

x = x.to(accelerator.device, dtype=ae.vae.dtype) # B C T H W
# x = x.to(accelerator.device, dtype=torch.float32) # B C T H W
Expand Down Expand Up @@ -891,12 +842,6 @@ def train_one_step(step_, data_item_, prof_=None):
# import sys;sys.exit()

# print("rank {} | step {} | after encode".format(accelerator.process_index, step_))
if args.extra_save_mem:
ae.vae.to('cpu')
text_enc_1.to('cpu')
if text_enc_2 is not None:
text_enc_2.to('cpu')
torch.cuda.empty_cache()

current_step_frame = x.shape[2]
current_step_sp_state = get_sequence_parallel_state()
Expand Down Expand Up @@ -1009,6 +954,10 @@ def train_one_epoch(prof_=None):
if __name__ == "__main__":
parser = argparse.ArgumentParser()

# deepspeed
parser.add_argument("--train_deepspeed_config_file", type=str, required=True, help="deepspeed config file for training, e.g diffusion model")
parser.add_argument("--eval_deepspeed_config_file", type=str, required=True, help="deepspeed config file for evaluation, e.g text encoder, EMA model")

# dataset & dataloader
parser.add_argument("--dataset", type=str, required=True)
parser.add_argument("--data", type=str, required='')
Expand All @@ -1035,7 +984,6 @@ def train_one_epoch(prof_=None):

# text encoder & vae & diffusion model
parser.add_argument('--vae_fp32', action='store_true')
parser.add_argument('--extra_save_mem', action='store_true')
parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="Latte-XL/122")
parser.add_argument('--enable_tiling', action='store_true')
parser.add_argument('--interpolation_scale_h', type=float, default=1.0)
Expand Down
Loading