diff --git a/README.md b/README.md index 23d36e1..627ab81 100644 --- a/README.md +++ b/README.md @@ -506,7 +506,8 @@ We evaluate video generation on **a single RTX 5090 GPU**. The E2E Time refers t In this repo, we provide training code based on Wan2.1 and its synthetic data. The training builds on the rCM codebase (https://github.com/NVlabs/rcm), with infrastructure support including FSDP2, Ulysses CP, and selective activation checkpointing (SAC). For rCM training instructions, please refer to the original rCM repository; [SLA (Sparse-Linear Attention)](https://github.com/thu-ml/SLA) training guidance is provided here. -#### Additional Installation +### Additional Installation + For rCM/SLA training, additionally run: ```bash @@ -514,7 +515,8 @@ pip install megatron-core hydra-core wandb webdataset pip install --no-build-isolation transformer_engine[pytorch] ``` -#### Checkpoints Downloading +### Checkpoints Downloading + Download the Wan2.1 pretrained checkpoints in `.pth` format and VAE/text encoder to `assets/checkpoints`: ```bash @@ -530,7 +532,7 @@ python -m torch.distributed.checkpoint.format_utils torch_to_dcp assets/checkpoi After training, the saved `.dcp` checkpoints can be converted to `.pth` using the script `scripts/dcp_to_pth.py`. -#### Dataset Downloading +### Dataset Downloading We provide Wan2.1-14B-synthesized datasets. Download to `assets/datasets` using: @@ -539,7 +541,8 @@ We provide Wan2.1-14B-synthesized datasets. Download to `assets/datasets` using: git clone https://huggingface.co/datasets/worstcoder/Wan_datasets assets/datasets ``` -#### Start Training +### Start Training + We implement white-box SLA training by aligning the predictions of the SLA-enabled model with those of the full-attention pretrained model. Unlike black-box training in the original paper, which tunes the pretrained model using diffusion loss, white-box training mitigates distribution shift and is less sensitive to the training data. Single-node training example: @@ -572,7 +575,7 @@ torchrun --nproc_per_node=8 \ Please refer to `turbodiffusion/rcm/configs/experiments/sla/wan2pt1_t2v.py` for the 14B config or perform modifications as needed. -#### Model Merging +### Model Merging The parameter updates from SLA training can be merged into rCM checkpoints using `turbodiffusion/scripts/merge_models.py`, enabling rCM to perform sparse attention inference. Specify `--base` as the rCM model, `--diff_base` as the pretrained model, and `--diff_target` as the SLA-tuned model. diff --git a/turbodiffusion/imaginaire/trainer.py b/turbodiffusion/imaginaire/trainer.py index 5cc9fd0..b4d80c8 100644 --- a/turbodiffusion/imaginaire/trainer.py +++ b/turbodiffusion/imaginaire/trainer.py @@ -237,6 +237,7 @@ def train( self.callbacks.on_train_end(model, iteration=iteration) self.checkpointer.finalize() distributed.barrier() + distributed.destroy_process_group() self.callbacks.on_app_end() def training_step( diff --git a/turbodiffusion/imaginaire/utils/distributed.py b/turbodiffusion/imaginaire/utils/distributed.py index e42ca42..2c4303f 100644 --- a/turbodiffusion/imaginaire/utils/distributed.py +++ b/turbodiffusion/imaginaire/utils/distributed.py @@ -156,6 +156,12 @@ def barrier() -> None: dist.barrier() +def destroy_process_group() -> None: + """Destroy the distributed process group.""" + if dist.is_available() and dist.is_initialized(): + dist.destroy_process_group() + + def rank0_first(func: Callable) -> Callable: """run the function on rank 0 first, then on other ranks."""