From 967b19d42f74f1593774016cca2094ef6a628afc Mon Sep 17 00:00:00 2001
From: teamtee <2624071330@qq.com>
Date: Fri, 16 May 2025 21:40:30 +0800
Subject: [PATCH 1/2] Fixed the join bug caused by Deepspeed adaptation and
 improved the model saving of Deepspeed

---
 .gitignore                                     |  2 +-
 examples/aispeech_asr/README.md                | 26 ++-----------------
 examples/aispeech_asr/README_zh.md             |  4 +--
 examples/aispeech_asr/aispeech_asr_config.py   |  4 ++-
 examples/aispeech_asr/scripts/decode.sh        |  5 ++--
 .../aispeech_asr/scripts/decode_deepspeed.sh   |  4 +--
 .../scripts/finetune_deepspeed.sh              | 13 +++++-----
 .../aispeech_asr/scripts/finetune_torchrun.sh  | 17 ++++++------
 .../scripts/transcribe_deepspeed_to_pt.py      |  9 -------
 examples/asr_librispeech/README.md             | 22 +---------------
 .../scripts/transcribe_deepspeed_to_pt.py      |  9 -------
 src/slam_llm/datasets/speech_dataset_large.py  |  4 ++-
 src/slam_llm/utils/checkpoint_handler.py       | 14 +++++++---
 src/slam_llm/utils/deepspeed_utils.py          |  2 +-
 src/slam_llm/utils/train_utils.py              |  6 ++---
 15 files changed, 45 insertions(+), 96 deletions(-)
 delete mode 100644 examples/aispeech_asr/scripts/transcribe_deepspeed_to_pt.py
 delete mode 100644 examples/asr_librispeech/scripts/transcribe_deepspeed_to_pt.py

diff --git a/.gitignore b/.gitignore
index 65408424..eddcb9b8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,7 +15,7 @@ data/
 jobs/
 debug/
 audio/
-
+exp/
 examples/s2s/scripts/debug
 examples/vsr_LRS3/scripts/decode_avhubert_vo_vicuna_7b_noself.sh
 examples/asr_librispeech/scripts/decode_hubert_xtralarge_linear_vicuna_7b_copy.sh
diff --git a/examples/aispeech_asr/README.md b/examples/aispeech_asr/README.md
index 4568c335..55918731 100644
--- a/examples/aispeech_asr/README.md
+++ b/examples/aispeech_asr/README.md
@@ -66,7 +66,6 @@ dev_scp_file_path= # Path to validation data
 train_max_frame_length=1500 # Maximum frame length for training
 eval_max_frame_length=1000 # Maximum frame length for evaluation
 multitask_prompt_path= # Path to multitask.jsonl
-prompt_style="\{\}" # Prompt style, e.g., "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" or "USER: {}\n ASSISTANT:"
 projector=linear # Type of projector
 encoder_name=whisper # Name of the encoder
 llm_name=Qwen2.5-7B-Instruct # Name of the LLM
@@ -86,7 +85,7 @@ For LoRA training, set (with `ckpt_path` pointing to the model saved in the prev
 ```bash
 use_peft=true
 if [[ $use_peft == "true" ]]; then
-    ckpt_path= # For DDP training, provide the path to the saved pt file; for DeepSpeed training, convert mp_rank_00_model_states.pt to model.pt using the `scripts/transcribe_deepspeed_to_pt.py` script
+    ckpt_path=
 fi
 ```
 ### Deepspeed
@@ -113,28 +112,7 @@ When using `bf16`/`fp16` for training, deepspeed saves about 20GB of GPU memory
 }
 }
 ```
-
-Note that when using `zero-0`/`1`/`2`, the DeepSpeed model is saved in a format that requires a script to convert `mp_rank_00_model_states.pt` to `model.pt`, such as `python scripts/transcribe_deepspeed_to_pt.py mp_rank_00_model_states.pt output_dir`.
-
-```
-global_step1000
-global_step1000/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
-...
-global_step1000/mp_rank_00_model_states.pt
-latest
-zero_to_fp32.py
-```
-
-If training with `Zero-3`, the model is saved in a different format and can be converted using `python zero_to_fp32.py global_step50 outputdir`.
-
-```
-global_step50
-global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt
-global_step50/zero_pp_rank_0_mp_rank_00_optim_states.pt
-...
-latest
-zero_to_fp32.py
-```
+Note that when using `zero-0`/`1`/`2`/`3`, the DeepSpeed model is saved as `pytorch_model.bin`
 
 If you use bf16/fp16 training in DeepSpeed and encounter NaN in train/eval loss, check the autocast in `src/slam_llm/utils/deepspeed_utils.py`:
 ```python
diff --git a/examples/aispeech_asr/README_zh.md b/examples/aispeech_asr/README_zh.md
index f23f5d37..2807f2d0 100644
--- a/examples/aispeech_asr/README_zh.md
+++ b/examples/aispeech_asr/README_zh.md
@@ -82,11 +82,11 @@ deepspeed_config= # DeepSpeed配置文件路径
 use_peft=false
 ```
-训练LoRA时,设置如下(`ckpt_path`是上一步训练保存的模型路径):
+训练LoRA时,设置如下(`ckpt_path`是上一步训练保存的模型路径`pytorch_model.bin/model.pt`):
 ```bash
 use_peft=true
 if [[ $use_peft == "true" ]]; then
-    ckpt_path= # 如果是DDP训练,直接写入保存的pt文件路径;如果是Deepspeed训练,需将mp_rank_00_model_states.pt文件转化为model.pt,可使用`scripts/transcribe_deepspeed_to_pt.py`脚本
+    ckpt_path=
 fi
 ```
diff --git a/examples/aispeech_asr/aispeech_asr_config.py b/examples/aispeech_asr/aispeech_asr_config.py
index cf1bdab5..242a4d49 100644
--- a/examples/aispeech_asr/aispeech_asr_config.py
+++ b/examples/aispeech_asr/aispeech_asr_config.py
@@ -91,8 +91,10 @@ class DataConfig:
     dataset: str = "multitask_dataset"
     train_max_frame_length: int = 1500
     eval_max_frame_length: int = 1000
+    audio_sample_rate: int = 16000
+    max_audio_length: int = 30
     multitask_prompt_path: str = "/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/multiprompt.jsonl"
-    prompt_style: str = "\{\}" #
+    prompt_style: str = "{}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:" Comment:Changed it in aispeech_asr_config.py
     append_info_tasks : List = field(default_factory=lambda: [ "hotword"])
     file: str = "examples/aispeech_asr/slam_llm/datasets/speech_dataset_large.py:get_speech_dataset"
     train_scp_file_path: str = ""
diff --git a/examples/aispeech_asr/scripts/decode.sh b/examples/aispeech_asr/scripts/decode.sh
index 30790213..33bb3fe3 100644
--- a/examples/aispeech_asr/scripts/decode.sh
+++ b/examples/aispeech_asr/scripts/decode.sh
@@ -1,10 +1,9 @@
 #!/bin/bash
 set -e
-run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM-NPU
+run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM
 cd $run_dir
 code_dir=examples/aispeech_asr
-prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:"
 projector=linear
 encoder_name=whisper
 llm_name=Qwen2.5-7B-Instruct
@@ -15,6 +14,7 @@ encoder_projector_ds_rate=5
 eval_max_frame_length=1000
 ckpt_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/project/aispeech_asr/exp/librispeech/20250322/whisper_linear_Qwen2.5-7B-Instruct_lorafalse_padtrue_normal_asr_speedfalse_specaugfalse-1121/mala_asr_epoch_2_step_25000_best
 test_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test
+# prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:" Comment:Changed it in aispeech_asr_config.py
 # Choose Encoder
@@ -69,7 +69,6 @@ python \
 ++model_config.encoder_path=$speech_encoder_path \
 ++model_config.encoder_dim=$encoder_dim \
 ++model_config.encoder_projector=$projector \
- ++dataset_config.prompt_style=$prompt_style \
 ++dataset_config.dataset=$dataset \
 ++dataset_config.pad_or_trim=$pad_or_trim \
 ++dataset_config.test_scp_file_path=$test_scp_file_path \
diff --git a/examples/aispeech_asr/scripts/decode_deepspeed.sh b/examples/aispeech_asr/scripts/decode_deepspeed.sh
index 846874ba..2b384672 100644
--- a/examples/aispeech_asr/scripts/decode_deepspeed.sh
+++ b/examples/aispeech_asr/scripts/decode_deepspeed.sh
@@ -1,10 +1,9 @@
 #!/bin/bash
 set -e
-run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM-NPU
+run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM
 cd $run_dir
 code_dir=examples/aispeech_asr
-prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:"
 projector=linear
 encoder_name=whisper
 llm_name=Qwen2.5-7B-Instruct
@@ -15,6 +14,7 @@ encoder_projector_ds_rate=5
 eval_max_frame_length=1000
 ckpt_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/project/aispeech_asr/exp/librispeech/20250322/whisper_linear_Qwen2.5-7B-Instruct_lorafalse_padtrue_normal_asr_speedfalse_specaugfalse-1121/mala_asr_epoch_2_step_25000_best
 test_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test
+# prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:" Comment:Changed it in aispeech_asr_config.py
 # Choose Encoder
diff --git a/examples/aispeech_asr/scripts/finetune_deepspeed.sh b/examples/aispeech_asr/scripts/finetune_deepspeed.sh
index 52071137..1e86e00f 100644
--- a/examples/aispeech_asr/scripts/finetune_deepspeed.sh
+++ b/examples/aispeech_asr/scripts/finetune_deepspeed.sh
@@ -9,16 +9,16 @@ export OMP_NUM_THREADS=1
-run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM-NPU
+run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM
 cd $run_dir
 code_dir=examples/aispeech_asr
 train_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test
 dev_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test
-train_max_frame_length=500
-eval_max_frame_length=500
+train_max_frame_length=2000
+eval_max_frame_length=2500
 multitask_prompt_path="/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/multiprompt.jsonl"
-prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:"
+# prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:" Comment:Changed it in aispeech_asr_config.py
 projector=linear
 encoder_name=whisper
 llm_name=Qwen2.5-7B-Instruct
@@ -30,7 +30,7 @@ pad_or_trim=true # For whisper
 deepspeed_config=examples/aispeech_asr/conf/ds_config.json
 if [[ $use_peft == "true" || $freeze_encoder == false ]];then
-    ckpt_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/project/aispeech_asr/exp/slidespeech/20250414/whisper_linear_Qwen2.5-7B-Instruct_lorafalse_padtrue_normal_asr_speedfalse_specaugfalse-1515_slidespeech_text/mala_asr_epoch_2_step_7000
+    ckpt_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/project/aispeech_asr/exp/slidespeech/20250414/whisper_linear_Qwen2.5-7B-Instruct_lorafalse_padtrue_normal_asr_speedfalse_specaugfalse-1515_slidespeech_text/mala_asr_epoch_2_step_7000/pytorch_model.bin
 fi
 # Choose Encoder
@@ -86,7 +86,6 @@ hydra.run.dir=$output_dir \
 ++model_config.encoder_path=$speech_encoder_path \
 ++model_config.encoder_dim=$encoder_dim \
 ++model_config.encoder_projector=$projector \
-++dataset_config.prompt_style=$prompt_style \
 ++dataset_config.train_max_frame_length=$train_max_frame_length \
 ++dataset_config.eval_max_frame_length=$eval_max_frame_length \
 ++dataset_config.multitask_prompt_path=$multitask_prompt_path \
@@ -107,7 +106,7 @@ hydra.run.dir=$output_dir \
 ++metric=acc \
 "
 if [[ $use_peft == "true" || $freeze_encoder == false ]];then
-    hydra_args+="++ckpt_path=$ckpt_path/model.pt"
+    hydra_args+="++ckpt_path=$ckpt_path"
 fi
diff --git a/examples/aispeech_asr/scripts/finetune_torchrun.sh b/examples/aispeech_asr/scripts/finetune_torchrun.sh
index abcd3e5e..f08b1602 100644
--- a/examples/aispeech_asr/scripts/finetune_torchrun.sh
+++ b/examples/aispeech_asr/scripts/finetune_torchrun.sh
@@ -9,16 +9,16 @@ export OMP_NUM_THREADS=1
-run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM-NPU
+run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM
 cd $run_dir
 code_dir=examples/aispeech_asr
 train_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test
 dev_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test
-train_max_frame_length=1500
-eval_max_frame_length=3000
+train_max_frame_length=1400
+eval_max_frame_length=2000
 multitask_prompt_path="/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/multiprompt.jsonl"
-prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:"
+# prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:" Comment:Changed it in aispeech_asr_config.py
 projector=linear
 encoder_name=whisper
 llm_name=Qwen2.5-1.5B-Instruct
@@ -29,7 +29,7 @@ pad_or_trim=true # For whisper
 if [[ $use_peft == "true" || $freeze_encoder == false ]];then
-    ckpt_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/project/aispeech_asr/exp/slidespeech/20250414/whisper_linear_Qwen2.5-7B-Instruct_lorafalse_padtrue_normal_asr_speedfalse_specaugfalse-1515_slidespeech_text/mala_asr_epoch_2_step_7000
+    ckpt_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/project/aispeech_asr/exp/slidespeech/20250414/whisper_linear_Qwen2.5-7B-Instruct_lorafalse_padtrue_normal_asr_speedfalse_specaugfalse-1515_slidespeech_text/mala_asr_epoch_2_step_7000/model.pt
 fi
 # Choose Encoder
@@ -89,7 +89,6 @@ hydra.run.dir=$output_dir \
 ++model_config.encoder_path=$speech_encoder_path \
 ++model_config.encoder_dim=$encoder_dim \
 ++model_config.encoder_projector=$projector \
-++dataset_config.prompt_style=$prompt_style \
 ++dataset_config.train_max_frame_length=$train_max_frame_length \
 ++dataset_config.eval_max_frame_length=$eval_max_frame_length \
 ++dataset_config.multitask_prompt_path=$multitask_prompt_path \
@@ -104,18 +103,18 @@ hydra.run.dir=$output_dir \
 ++train_config.freeze_llm=true \
 ++train_config.use_peft=$use_peft \
 ++train_config.batching_strategy=dynamic \
-++train_config.validation_interval=10 \
+++train_config.validation_interval=1000 \
 ++train_config.num_workers_dataloader=8 \
 ++train_config.output_dir=$output_dir \
 ++metric=acc \
 "
 if [[ $use_peft == "true" || $freeze_encoder == false ]];then
-    hydra_args+="++ckpt_path=$ckpt_path/model.pt"
+    hydra_args+="++ckpt_path=$ckpt_path"
 fi
 torchrun \
     --nnodes 1 \
-    --nproc_per_node 2 \
+    --nproc_per_node 8 \
     --master_port=29505 \
     $code_dir/finetune_torchrun.py \
     --config-path "conf" \
diff --git a/examples/aispeech_asr/scripts/transcribe_deepspeed_to_pt.py b/examples/aispeech_asr/scripts/transcribe_deepspeed_to_pt.py
deleted file mode 100644
index e2a02862..00000000
--- a/examples/aispeech_asr/scripts/transcribe_deepspeed_to_pt.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import argparse
-import torch
-import torch_npu
-import sys
-in_path = sys.argv[1]
-out_path = sys.argv[2]
-weight_dict = torch.load(in_path)["module"]
-torch.save(weight_dict, f"{out_path}/model.pt")
-print("[Finish]")
\ No newline at end of file
diff --git a/examples/asr_librispeech/README.md b/examples/asr_librispeech/README.md
index fb663512..e53c1673 100644
--- a/examples/asr_librispeech/README.md
+++ b/examples/asr_librispeech/README.md
@@ -79,27 +79,7 @@ If you're interested in training with DeepSpeed, refer to the script `finetune_w
 }
 ```
-Note that when using `zero-0`/`1`/`2`, the DeepSpeed model is saved in a format that requires a script to convert `mp_rank_00_model_states.pt` to `model.pt`, such as `python transcribe_deepspeed_to_pt.py mp_rank_00_model_states.pt output_dir`.
-
-```
-global_step1000
-global_step1000/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
-...
-global_step1000/mp_rank_00_model_states.pt
-latest
-zero_to_fp32.py
-```
-
-If training with `Zero-3`, the model is saved in a different format and can be converted using `python zero_to_fp32.py global_step50 outputdir`.
-
-```
-global_step50
-global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt
-global_step50/zero_pp_rank_0_mp_rank_00_optim_states.pt
-...
-latest
-zero_to_fp32.py
-```
+Note that when using `zero-0`/`1`/`2`/`3`, the DeepSpeed model is saved as `pytorch_model.bin`
 
 If you use bf16/fp16 training in DeepSpeed and encounter NaN in train/eval loss, check the autocast in `src/slam_llm/utils/deepspeed_utils.py`:
 ```python
diff --git a/examples/asr_librispeech/scripts/transcribe_deepspeed_to_pt.py b/examples/asr_librispeech/scripts/transcribe_deepspeed_to_pt.py
deleted file mode 100644
index e2a02862..00000000
--- a/examples/asr_librispeech/scripts/transcribe_deepspeed_to_pt.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import argparse
-import torch
-import torch_npu
-import sys
-in_path = sys.argv[1]
-out_path = sys.argv[2]
-weight_dict = torch.load(in_path)["module"]
-torch.save(weight_dict, f"{out_path}/model.pt")
-print("[Finish]")
\ No newline at end of file
diff --git a/src/slam_llm/datasets/speech_dataset_large.py b/src/slam_llm/datasets/speech_dataset_large.py
index 2ba299ac..311f0f9f 100644
--- a/src/slam_llm/datasets/speech_dataset_large.py
+++ b/src/slam_llm/datasets/speech_dataset_large.py
@@ -55,6 +55,8 @@ def __init__(self, dataset_config, tokenizer=None, split='train'):
         self.inference_mode = dataset_config.get("inference_mode", False)
         self.normalize = dataset_config.get("normalize", False)
         self.input_type = dataset_config.get("input_type", None)
+        self.max_audio_length = dataset_config.get("max_audio_length", 30)
+        self.audio_sample_rate = dataset_config.get("audio_sample_rate", 16000)
         assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]"

     def __iter__(self):
@@ -86,7 +88,7 @@ def __iter__(self):
             ark_path = item["path"]
             numpy_array = kaldiio.load_mat(ark_path)
             audio_raw = numpy_array[1].astype(np.float32) / 32768
-            if len(audio_raw) / 16000 > 30:
+            if len(audio_raw) / self.audio_sample_rate > self.max_audio_length:
                 continue
             key = item["key"]
             target = item["target"]
diff --git a/src/slam_llm/utils/checkpoint_handler.py b/src/slam_llm/utils/checkpoint_handler.py
index afaabef0..24341534 100644
--- a/src/slam_llm/utils/checkpoint_handler.py
+++ b/src/slam_llm/utils/checkpoint_handler.py
@@ -6,7 +6,8 @@
 import torch
 import time
 from collections import OrderedDict
-
+from deepspeed.utils.zero_to_fp32 import (
+    convert_zero_checkpoint_to_fp32_state_dict)
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     StateDictType,
@@ -168,11 +169,18 @@ def save_model_checkpoint(
 def save_model_checkpoint_deepspeed(model, cfg, checkpoint_name="checkpoint"):
     logger.info(f"--> saving model ...")
     save_dir = os.path.join(cfg.output_dir, checkpoint_name)
-    os.makedirs(save_dir, exist_ok=True)
+    dist.barrier()
+    if os.environ["RANK"] == "0":
+        os.makedirs(save_dir, exist_ok=True)
+    dist.barrier()
     # save_full_path = os.path.join(save_dir, "model.pt")
     save_full_path = save_dir
     model.save_checkpoint(save_dir=save_full_path, exclude_frozen_parameters=True)
-    logger.info(f"encoder saved at {save_full_path}")
+    dist.barrier()
+    if os.environ["RANK"] == "0":
+        convert_zero_checkpoint_to_fp32_state_dict(save_full_path,save_full_path)
+    dist.barrier()
+    logger.info(f"encoder saved at {save_full_path}_model")

 def save_model_checkpoint_peft(model, optimizer, rank, cfg, checkpoint_name="checkpoint", save_trainable_only=True):
     logger.info(f"--> saving model ...")
diff --git a/src/slam_llm/utils/deepspeed_utils.py b/src/slam_llm/utils/deepspeed_utils.py
index 75903fc4..21b60f93 100644
--- a/src/slam_llm/utils/deepspeed_utils.py
+++ b/src/slam_llm/utils/deepspeed_utils.py
@@ -123,7 +123,7 @@ def deepspeed_join(group_join):
         local_rank = int(os.environ["LOCAL_RANK"])
         rank = int(os.environ["RANK"])
         world_size = int(os.environ["WORLD_SIZE"])
-        logging.info("Detected uneven workload distribution: {}\n".format(e) +
+        logging.info("Detected uneven workload distribution. " +
                      "Break current worker to manually join all workers, " +
                      "world_size {}, current rank {}, current local_rank {}\n".
                      format(world_size, rank, local_rank))
diff --git a/src/slam_llm/utils/train_utils.py b/src/slam_llm/utils/train_utils.py
index 621306b5..4205a6e7 100644
--- a/src/slam_llm/utils/train_utils.py
+++ b/src/slam_llm/utils/train_utils.py
@@ -88,7 +88,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     best_val_acc = 0.0
     for epoch in range(train_config.num_epochs):
         epoch_start_time = time.perf_counter()
-        with MemoryTrace() as memtrace,Join([model,optimizer]): # track the memory usage
+        with MemoryTrace() as memtrace,Join([model]): # track the memory usage
             model.train()
             total_loss = 0.0
             total_acc = 0.0
@@ -326,8 +326,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if torch.cuda.device_count() > 1 and (train_config.enable_fsdp or train_config.enable_ddp):
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
             dist.all_reduce(total_acc, op=dist.ReduceOp.SUM)
-        train_epoch_loss = total_loss / (len(train_dataloader) if train_config.batching_strategy != "dynamic" else (step + 1) *train_config.num_epochs)
-        train_epoch_acc = total_acc / (len(train_dataloader) if train_config.batching_strategy != "dynamic" else (step + 1) *train_config.num_epochs)
+        train_epoch_loss = total_loss / (len(train_dataloader) if train_config.batching_strategy != "dynamic" else (step + 1) )
+        train_epoch_acc = total_acc / (len(train_dataloader) if train_config.batching_strategy != "dynamic" else (step + 1) )
         if train_config.enable_fsdp or train_config.enable_ddp:
             train_epoch_loss = train_epoch_loss/world_size
             train_epoch_acc = train_epoch_acc/world_size
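The revised `save_model_checkpoint_deepspeed` above barriers all ranks, lets rank 0 create the output directory, writes the ZeRO shards with `exclude_frozen_parameters=True`, and then has rank 0 call DeepSpeed's `convert_zero_checkpoint_to_fp32_state_dict` to produce the consolidated `pytorch_model.bin`. A minimal standalone sketch of that save-side pattern is shown below; the `engine` and `save_dir` names are illustrative assumptions, not code from this patch.

```python
# Hedged sketch (not part of the patch): the rank-0 + barrier save pattern
# adopted by save_model_checkpoint_deepspeed. `engine` is assumed to be an
# initialized DeepSpeed engine and `save_dir` a path on a shared filesystem.
import os
import torch.distributed as dist
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

def save_deepspeed_checkpoint(engine, save_dir: str) -> None:
    dist.barrier()                      # keep ranks in step before touching the FS
    if os.environ["RANK"] == "0":
        os.makedirs(save_dir, exist_ok=True)
    dist.barrier()                      # directory now exists for every rank
    # Every rank writes its ZeRO shard; frozen weights are skipped.
    engine.save_checkpoint(save_dir=save_dir, exclude_frozen_parameters=True)
    dist.barrier()
    if os.environ["RANK"] == "0":
        # Consolidate the shards into a single fp32 pytorch_model.bin.
        convert_zero_checkpoint_to_fp32_state_dict(save_dir, save_dir)
    dist.barrier()
```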
From 856f430eb719aa5fc707bf5ded932bdbab448a7c Mon Sep 17 00:00:00 2001
From: teamtee <87838510+teamtee@users.noreply.github.com>
Date: Mon, 16 Jun 2025 13:47:06 +0800
Subject: [PATCH 2/2] Update README.md

---
 examples/asr_librispeech/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/asr_librispeech/README.md b/examples/asr_librispeech/README.md
index e53c1673..a8a9d053 100644
--- a/examples/asr_librispeech/README.md
+++ b/examples/asr_librispeech/README.md
@@ -79,7 +79,7 @@ If you're interested in training with DeepSpeed, refer to the script `finetune_w
 }
 ```
-Note that when using `zero-0`/`1`/`2`/`3`, the DeepSpeed model is saved as `pytorch_model.bin`
+Note that when using `zero-0`/`1`/`2`/`3`, the DeepSpeed model is saved as `pytorch_model.bin`, and you should change "++ckpt_path=$ckpt_path/model.pt" to "++ckpt_path=$ckpt_path/pytorch_model.bin" in the script to use the model during inference.
 
 If you use bf16/fp16 training in DeepSpeed and encounter NaN in train/eval loss, check the autocast in `src/slam_llm/utils/deepspeed_utils.py`:
 ```python
@@ -96,4 +96,4 @@ You can refer to the paper for more results.
   journal={arXiv preprint arXiv:2402.08846},
   year={2024}
 }
-```
\ No newline at end of file
+```
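Since both READMEs now point `ckpt_path` at the consolidated `pytorch_model.bin`, a brief load-side sketch may be useful: the checkpoint holds only the trainable parameters (the engine is saved with `exclude_frozen_parameters=True`), so it has to be loaded non-strictly on top of the frozen encoder/LLM weights. The helper name and the example path below are illustrative assumptions, not part of the patch.

```python
# Hedged sketch (not part of the patch): loading the consolidated checkpoint
# written by convert_zero_checkpoint_to_fp32_state_dict().
import torch

def load_trainable_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    # pytorch_model.bin only contains the trainable tensors (e.g. the projector
    # and, when use_peft=true, the LoRA weights), so strict=False leaves the
    # frozen encoder/LLM parameters untouched.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"loaded {len(state_dict)} tensors, "
          f"{len(missing)} missing, {len(unexpected)} unexpected")

# Example invocation (path is a placeholder):
# load_trainable_weights(slam_model, "exp/.../pytorch_model.bin")
```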