
Commit 5ab1e81

asr_librispeech support deepspeed and update aispeech_asr
1 parent f07183e commit 5ab1e81

File tree

8 files changed: +143 −53 lines changed


examples/aispeech_asr/README.md

Lines changed: 52 additions & 0 deletions
@@ -89,7 +89,59 @@ if [[ $use_peft == "true" ]]; then
ckpt_path= # For DDP training, provide the path to the saved pt file; for DeepSpeed training, convert mp_rank_00_model_states.pt to model.pt using the `scripts/transcribe_deepspeed_to_pt.py` script
fi
```

### DeepSpeed

When training with `bf16`/`fp16`, DeepSpeed saves about 20 GB of GPU memory compared to `torchrun` when training a 7B model. For 7B models, ZeRO stage 0/1/2 is recommended; for extremely large models, ZeRO-3 can be used, though communication may become a bottleneck.

```json
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
```
Note that with ZeRO stage 0/1/2, DeepSpeed saves the checkpoint in its own layout, so `mp_rank_00_model_states.pt` has to be converted to `model.pt` with a script, e.g. `python scripts/transcribe_deepspeed_to_pt.py mp_rank_00_model_states.pt output_dir`. The checkpoint directory looks like this:

```
global_step1000
global_step1000/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
...
global_step1000/mp_rank_00_model_states.pt
latest
zero_to_fp32.py
```
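
For reference, the conversion just extracts the `module` state dict from the DeepSpeed checkpoint. Below is a minimal sketch of what `scripts/transcribe_deepspeed_to_pt.py` does (the `map_location` argument is an extra here, not part of the original script):

```python
# Sketch of the ZeRO-0/1/2 checkpoint conversion: mp_rank_00_model_states.pt -> model.pt
import sys
import torch

in_path = sys.argv[1]    # e.g. global_step1000/mp_rank_00_model_states.pt
out_dir = sys.argv[2]    # directory that will receive model.pt
state_dict = torch.load(in_path, map_location="cpu")["module"]  # DeepSpeed stores the weights under "module"
torch.save(state_dict, f"{out_dir}/model.pt")
```

The resulting `model.pt` is a plain state dict, so it can be passed to `ckpt_path` just like a checkpoint produced by DDP training.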

If training with ZeRO-3, the checkpoint is saved in a different layout and can be converted with the `zero_to_fp32.py` script that DeepSpeed writes next to it, e.g. `python zero_to_fp32.py global_step50 outputdir`.

```
global_step50
global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt
global_step50/zero_pp_rank_0_mp_rank_00_optim_states.pt
...
latest
zero_to_fp32.py
```

If you train with bf16/fp16 in DeepSpeed and encounter NaN in the train/eval loss, check the `autocast` call in `src/slam_llm/utils/deepspeed_utils.py` and set the dtype explicitly:

```python
with autocast():                       # original code: dtype is chosen automatically
with autocast(dtype=torch.bfloat16):   # set bf16 explicitly; this should work
with autocast(dtype=torch.float16):    # or force fp16
```
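
The reason forcing `torch.bfloat16` helps is dynamic range: fp16 overflows to `inf` above 65504, and the `inf` can then turn into NaN in later ops (e.g. `inf - inf`), while bf16 keeps float32's exponent range. A quick standalone check:

```python
import torch

print(torch.finfo(torch.float16).max)    # 65504.0
print(torch.finfo(torch.bfloat16).max)   # ~3.39e38, same exponent range as float32

x = torch.tensor(70000.0)
print(x.to(torch.float16))               # inf  -> can become NaN downstream
print(x.to(torch.bfloat16))              # 70144., still finite
```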
## Decoding

- **Single-machine single-GPU decoding**: Refer to `scripts/decode.sh`

examples/aispeech_asr/finetune_deepspeed.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ class RunConfig:
     ckpt_path: Optional[str] = field(
         default=None, metadata={"help": "The path to projector checkpoint"}
     )
-    deepspeed_config : str =""
+    deepspeed_config : str ="examples/aispeech_asr/conf/ds_config.json"
     deepspeed_ckpt_path: Optional[str] = field(
         default=None, metadata={"help": "The path to projector checkpoint"}
     )

examples/aispeech_asr/scripts/finetune_deepspeed.sh

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ use_fp16=true
 freeze_encoder=true
 pad_or_trim=true # For whisper
 
-deepspeed_config=code_dir=examples/aispeech_asr/conf/ds_config.json
+deepspeed_config=examples/aispeech_asr/conf/ds_config.json
 
 if [[ $use_peft == "true" || $freeze_encoder == false ]];then
     ckpt_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/project/aispeech_asr/exp/slidespeech/20250414/whisper_linear_Qwen2.5-7B-Instruct_lorafalse_padtrue_normal_asr_speedfalse_specaugfalse-1515_slidespeech_text/mala_asr_epoch_2_step_7000

examples/aispeech_asr/slam_llm

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/asr_librispeech/README.md

Lines changed: 54 additions & 0 deletions
@@ -53,6 +53,60 @@ Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memor
If you are interested in running full parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as the following command. This option loads the model on rank 0 only before moving it to the devices to construct FSDP, which dramatically saves CPU memory when loading large models (on an 8-GPU node, it reduces CPU memory from 2+ TB to 280 GB for the 70B model). This has been tested with `BF16` on 16x A100, 80GB GPUs.

### Fine-tuning using DeepSpeed

If you're interested in training with DeepSpeed, refer to the script `finetune_whisper_large_linear_vicuna_7b_deepspeed.sh`. The training configuration is shown in `conf/ds_config.json`. When training with `bf16`/`fp16`, DeepSpeed saves about 20 GB of GPU memory compared to `torchrun` when training a 7B model. For 7B models, ZeRO stage 0/1/2 is recommended; for extremely large models, ZeRO-3 can be used, though communication may become a bottleneck.

```json
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
```
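
For orientation, this is roughly how DeepSpeed consumes such a config. The snippet below is an illustrative sketch only, not this repo's actual wiring: the toy `torch.nn.Linear` model and the inline dict are placeholders, and in this repo the config is passed as a file path while training is launched through the `deepspeed` CLI.

```python
import deepspeed
import torch

ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "fp16": {"enabled": True},
    # "stage" is the knob discussed above: 0/1/2 for 7B-scale models, 3 for very large ones.
    "zero_optimization": {"stage": 2, "offload_optimizer": {"device": "cpu"}},
}

model = torch.nn.Linear(1024, 1024)  # stand-in for the actual speech-LLM model
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,  # a path such as "conf/ds_config.json" also works here
)
```

Run it under the `deepspeed` launcher (e.g. `deepspeed --num_gpus=8 train.py`) so the distributed environment is initialized.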

Note that with ZeRO stage 0/1/2, DeepSpeed saves the checkpoint in its own layout, so `mp_rank_00_model_states.pt` has to be converted to `model.pt` with a script, e.g. `python transcribe_deepspeed_to_pt.py mp_rank_00_model_states.pt output_dir`. The checkpoint directory looks like this:

```
global_step1000
global_step1000/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
...
global_step1000/mp_rank_00_model_states.pt
latest
zero_to_fp32.py
```

If training with ZeRO-3, the checkpoint is saved in a different layout and can be converted with the `zero_to_fp32.py` script that DeepSpeed writes next to it, e.g. `python zero_to_fp32.py global_step50 outputdir`.

```
global_step50
global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt
global_step50/zero_pp_rank_0_mp_rank_00_optim_states.pt
...
latest
zero_to_fp32.py
```
103+
If you use bf16/fp16 training in DeepSpeed and encounter NaN in train/eval loss, check the autocast in `src/slam_llm/utils/deepspeed_utils.py`:
104+
105+
```python
106+
with autocast() # original code
107+
with autocast(dtype=torch.bfloat16)
108+
with autocast(dtype=torch.float16)
109+
```
## Citation
You can refer to the paper for more results.
```
Lines changed: 23 additions & 47 deletions
@@ -1,7 +1,6 @@
 #!/bin/bash
 # export PYTHONPATH=/root/whisper:$PYTHONPATH
 export PYTHONPATH=/root/fairseq:$PYTHONPATH
-# export CUDA_VISIBLE_DEVICES=6,7
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
 export TOKENIZERS_PARALLELISM=false
 # export CUDA_LAUNCH_BLOCKING=1
@@ -12,81 +11,58 @@ export OMP_NUM_THREADS=1
 # export NCCL_DEBUG_SUBSYS=ALL
 # export TORCH_DISTRIBUTED_DEBUG=INFO
 
-run_dir=/work/SLAM-LLM
+run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM-NPU
 cd $run_dir
 code_dir=examples/asr_librispeech
 
-speech_encoder_path=/cxgroup/model/whisper/large-v3.pt
+speech_encoder_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/wavlm/WavLM-Large.pt
+llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/vicuna-7b-v1.5
+train_data_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/multitask_wav.jsonl
+val_data_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/multitask_wav.jsonl
 
-llm_path=/cxgroup/model/vicuna-7b-v1.5
-# llm_path=/nfs/maziyang.mzy/models/vicuna-13b-v1.5
-
-output_dir=/work/exps/vicuna-7b-v1.5-librispeech-linear-steplrwarmupkeep1e-4-whisper-largev3-$(date +"%Y%m%d")-deepspeed
+output_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM-NPU/examples/asr_librispeech/exp-$(date +"%Y%m%d")
 
 hydra_args="
     hydra.run.dir=$output_dir \
     ++model_config.llm_name=vicuna-7b-v1.5 \
     ++model_config.llm_path=$llm_path \
     ++model_config.llm_dim=4096 \
-    ++model_config.encoder_name=whisper \
+    ++model_config.encoder_name=wavlm \
+    ++model_config.normalize=true \
+    ++dataset_config.normalize=true \
     ++model_config.encoder_projector_ds_rate=5 \
     ++model_config.encoder_path=$speech_encoder_path \
-    ++model_config.encoder_dim=1280 \
+    ++model_config.encoder_dim=1024 \
     ++model_config.encoder_projector=linear \
     ++dataset_config.dataset=speech_dataset \
-    ++dataset_config.train_data_path=data/librispeech/train960.jsonl \
-    ++dataset_config.val_data_path=data/librispeech/dev.jsonl \
-    ++dataset_config.input_type=mel \
-    ++dataset_config.mel_size=128 \
+    ++dataset_config.train_data_path=$train_data_path \
+    ++dataset_config.val_data_path=$val_data_path \
+    ++dataset_config.input_type=raw \
     ++train_config.model_name=asr \
-    ++train_config.num_epochs=6 \
-    ++train_config.enable_deepspeed=true \
+    ++train_config.num_epochs=3 \
     ++train_config.freeze_encoder=true \
     ++train_config.freeze_llm=true \
     ++train_config.batching_strategy=custom \
     ++train_config.warmup_steps=1000 \
     ++train_config.total_steps=100000 \
     ++train_config.lr=1e-4 \
-    ++train_config.validation_interval=1000 \
-    ++train_config.batch_size_training=4 \
+    ++train_config.validation_interval=50 \
+    ++train_config.batch_size_training=1 \
     ++train_config.val_batch_size=4 \
-    ++train_config.num_workers_dataloader=4 \
+    ++train_config.num_workers_dataloader=2 \
     ++train_config.output_dir=$output_dir \
     ++metric=acc \
 "
-# ++train_config.use_peft=true \
-# ++train_config.peft_config.r=32 \
-# ++model_config.encoder_projector=linear \
-# ++model_config.encoder_projector_ds_rate=5 \
-# ++train_config.peft_config.peft_method=lora \
-# --peft_ckpt "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4" \
-# --ckpt_path "/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-5-whisper-prompt-padding30-20231228/asr/4/model.pt" \
-#++log_config.log_file=/$output_dir/train.log \
-#++log_config.use_wandb=true \
-#++log_config.wandb_dir=$output_dir \
-#++log_config.wandb_entity_name=zym22 \
-#++log_config.wandb_project_name=slam-llm \
-#++log_config.wandb_exp_name=${0##*/%.*} \
-#++log_config.log_interval 5 \
+
+
+
 
 deepspeed \
-    --include localhost:4,5 \
-    --master_port=29502 \
+    --num_gpus=8 \
+    --num_nodes=1 \
     $code_dir/deepspeed_finetune_asr.py \
     $hydra_args
 # --num_gpus=2 \
 # --num_nodes=1 \
+# --master_port=29502 \
 
-# -m debugpy --listen 5678 --wait-for-client
-# if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
-# python -m debugpy --listen 5678 --wait-for-client finetune_asr.py \
-# $hydra_args
-# else
-# deepspeed \
-# --num_nodes=1 \
-# --include localhost:6,7 \
-# --master_port=29502 \
-# $code_dir/deepspeed_finetune_asr.py \
-# $hydra_args
-# # --num_gpus=2 \
-# fi
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
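# Convert a DeepSpeed ZeRO-0/1/2 checkpoint (mp_rank_00_model_states.pt) into a plain model.pt.
# Usage: python transcribe_deepspeed_to_pt.py <path/to/mp_rank_00_model_states.pt> <output_dir>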
import argparse
import torch
import torch_npu
import sys
in_path = sys.argv[1]
out_path = sys.argv[2]
weight_dict = torch.load(in_path)["module"]
torch.save(weight_dict, f"{out_path}/model.pt")
print("[Finish]")

src/slam_llm/datasets/speech_dataset_large.py

Lines changed: 3 additions & 3 deletions
@@ -60,22 +60,22 @@ def __init__(self, dataset_config, tokenizer=None, split='train'):
     def __iter__(self):
         multitask_task_path = os.path.join(self.data_path,"multitask.jsonl")
         worker_info = torch.utils.data.get_worker_info()
-        if worker_info is None: # 不在 DataLoader 的多进程环境中
+        if worker_info is None: # Not in the multi-processing environment of DataLoader.
             num_workers = 1
             worker_id = 0
         else:
             num_workers = worker_info.num_workers
             worker_id = worker_info.id
 
-        # 获取分布式环境中的进程信息
+        # Obtain the process information in the distributed environment.
         if dist.is_available() and dist.is_initialized():
             world_size = dist.get_world_size()
             rank = dist.get_rank()
         else:
             world_size = 1
             rank = 0
 
-        # 计算每个 worker 和每个进程应该处理的数据范围
+        # Calculate the data range that each worker and each process should handle.
         total_num_workers = num_workers * world_size
         worker_rank = rank * num_workers + worker_id
         data_index = 0
