
Commit 43d9e54

Add: an example, aispeech_asr, and a large-scale dataset implementation, speech_dataset_large, with support for multi-machine multi-GPU decoding.
1 parent aa9ac13 commit 43d9e54

24 files changed: +1765 −54 lines

examples/aispeech_asr/README.md

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
# AISPEECH_ASR

## Overview

This example is designed for large-scale industrial data training, suitable for datasets on the order of 100,000 hours. Its main features include:
- **Support for multi-task training**: Supports tasks such as ASR and ST through a unified data format.
- **Dynamic prompt selection**: Supports random selection from multiple prompts.
- **Iterative dataset**: Uses an iterative dataset format to reduce startup time for large datasets.
- **DeepSpeed training**: Supports DeepSpeed training to significantly reduce memory usage.
- **Multi-machine multi-GPU inference**: Supports distributed inference across multiple machines and GPUs to reduce evaluation time.
- **Dynamic frame batching**: Batches utterances by total frame count rather than using a fixed batch size, significantly reducing training and evaluation time (roughly a 3/4 reduction in training time on 100,000 hours of data); a sketch of the idea follows this overview.
- **Ascend NPU compatibility**: Optimized for compatibility with Ascend NPU.

This example is modified from `mala_asr_slidespeech`.
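The dynamic frame batching above can be pictured as grouping utterances until a frame budget is reached. Below is a minimal sketch under assumed names (`num_frames` per sample, a `max_frame_length` budget matching `train_max_frame_length`); the actual batching logic lives in the dataset implementation shipped with this example.

```python
from typing import Dict, Iterable, Iterator, List


def dynamic_frame_batches(
    samples: Iterable[Dict],       # each sample carries a "num_frames" field (assumed name)
    max_frame_length: int = 1500,  # frame budget per batch, cf. train_max_frame_length
) -> Iterator[List[Dict]]:
    """Group samples so that the summed frame count stays under the budget."""
    batch, frames = [], 0
    for sample in samples:
        n = sample["num_frames"]
        if batch and frames + n > max_frame_length:
            yield batch
            batch, frames = [], 0
        batch.append(sample)
        frames += n
    if batch:
        yield batch
```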
## Model Architecture

The model architecture can be selected as needed within the scope supported by SLAM-LLM. Below are some recommended configurations:
- **Encoder**: WavLM, Whisper
- **Projector**: Linear
- **LLM**: Qwen2.5-7B-Instruct, Vicuna1.5-7B

## Data Preparation

The following two files are required:
- `multitask.jsonl`
- `multiprompt.jsonl`

### multitask.jsonl

The format of this file is as follows, where `path` supports both Kaldi ark entries (with byte offsets) and wav files:
```json
{"key": "BAC009S0002W0122", "task": "ASR", "target": "而对楼市成交抑制作用最大的限购", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:17"}
{"key": "BAC009S0002W0123", "task": "ASR", "target": "也成为地方政府的眼中钉", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:191758"}
{"key": "BAC009S0002W0124", "task": "ASR", "target": "自六月底呼和浩特市率先宣布取消限购后", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:315339"}
{"key": "BAC009S0764W0238", "task": "hotword", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/test/data/data_wav.1.ark:17343733", "target": "形成一批具有国际竞争力的中国企业", "hotword": "中国"}
```
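To make the `path` convention concrete, the sketch below parses one manifest line and loads its audio. The use of `kaldiio` for `ark:offset` entries and `soundfile` for plain wav paths is an assumption for illustration only; the example's own dataset code may use different readers.

```python
import json

import kaldiio    # assumed reader for "file.ark:offset" entries
import soundfile  # assumed reader for plain wav paths


def load_sample(jsonl_line: str):
    """Parse one multitask.jsonl line and return (metadata, sample_rate, waveform)."""
    item = json.loads(jsonl_line)
    path = item["path"]
    if ".ark:" in path:
        # kaldiio resolves the byte offset after the colon
        rate, wav = kaldiio.load_mat(path)
    else:
        wav, rate = soundfile.read(path)
    return item, rate, wav
```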
### multiprompt.jsonl

The format of this file is as follows:
```json
{"task": "ASR", "prompt": "Transcribe speech to text."}
{"task": "ASR", "prompt": "请识别语音."}
{"task": "ZH2EN", "prompt": "请识别语音并翻译为英文:"}
{"task": "EN2ZH", "prompt": "请识别语音并翻译为中文:"}
{"task": "prevtext", "prompt": "Transcribe speech to text, below are the previous historical transcription texts:{}."}
{"task": "hotword", "prompt": "Transcribe speech to text, follow words may occur:{}."}
```

### Notes
- If a task has multiple prompts, one is selected dynamically for each sample.
- To pass additional information (e.g., hotwords), add a field named after the task to the entry in `multitask.jsonl` and place `{}` in the prompt where that information should be injected. Also list the task in `append_info_tasks` in the `aispeech_config` file (see the sketch after this list):
```python
append_info_tasks: List = field(default_factory=lambda: ["hotword"])
```
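The following is a minimal sketch of how prompt selection and information injection could work, assuming prompts are grouped by task and that the extra field shares the task's name (as in the `hotword` example above); the actual logic lives in the example's dataset code.

```python
import json
import random
from collections import defaultdict

APPEND_INFO_TASKS = ["hotword"]  # mirrors append_info_tasks in the config


def load_prompts(multiprompt_path: str) -> dict:
    """Group prompt templates by task name."""
    prompts = defaultdict(list)
    with open(multiprompt_path, encoding="utf-8") as f:
        for line in f:
            entry = json.loads(line)
            prompts[entry["task"]].append(entry["prompt"])
    return prompts


def build_prompt(sample: dict, prompts: dict) -> str:
    """Pick a random prompt for the sample's task and inject extra info if required."""
    task = sample["task"]
    prompt = random.choice(prompts[task])
    if task in APPEND_INFO_TASKS:
        prompt = prompt.format(sample[task])  # e.g. the "hotword" field fills the {}
    return prompt
```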
## Training a New Model
### Script Preparation

Prepare and modify the following variables in `scripts/finetune_deepspeed.sh` or `scripts/finetune_torchrun.sh` (DeepSpeed is recommended):
```bash
run_dir=                     # Directory to save the model
train_scp_file_path=         # Path to training data
dev_scp_file_path=           # Path to validation data
train_max_frame_length=1500  # Maximum total frames per batch during training
eval_max_frame_length=1000   # Maximum total frames per batch during evaluation
multitask_prompt_path=       # Path to multitask.jsonl
prompt_style="\{\}"          # Prompt style, e.g., "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" or "USER: {}\n ASSISTANT:"
projector=linear             # Type of projector
encoder_name=whisper         # Name of the encoder
llm_name=Qwen2.5-7B-Instruct # Name of the LLM
use_peft=false               # Whether to use PEFT (for the LLM)
use_fp16=true                # Whether to use FP16
freeze_encoder=true          # Whether to freeze the encoder
pad_or_trim=true             # Whether to use pad_or_trim (for Whisper)
deepspeed_config=            # Path to the DeepSpeed configuration file
```
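How `prompt_style` is applied is worth making concrete: the prompt selected from `multiprompt.jsonl` is substituted into the style template before being passed to the LLM. A minimal sketch, assuming a plain `str.format` substitution (the actual template handling is done inside the example's dataset/model code):

```python
def apply_prompt_style(prompt_style: str, prompt: str) -> str:
    """Wrap a task prompt in the configured chat template."""
    return prompt_style.format(prompt)


# Plain style: the prompt is used as-is.
print(apply_prompt_style("{}", "Transcribe speech to text."))

# Qwen-style chat template wrapping the same prompt.
qwen_style = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
print(apply_prompt_style(qwen_style, "Transcribe speech to text."))
```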
Typically, we first train the projector and then fine-tune the LLM with LoRA. For projector training, set:
```bash
use_peft=false
```

For LoRA training, set the following (with `ckpt_path` pointing to the model saved in the previous step):
```bash
use_peft=true
if [[ $use_peft == "true" ]]; then
    ckpt_path=  # For DDP training, provide the path to the saved pt file; for DeepSpeed training, first convert mp_rank_00_model_states.pt to model.pt using the `scripts/transcribe_deepspeed_to_pt.py` script
fi
```
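For reference, the DeepSpeed-to-pt conversion mentioned above can be sketched as loading the rank-0 model states and re-saving only the module weights. This is a minimal illustration, assuming the checkpoint stores the weights under a `module` key (the usual DeepSpeed layout); `scripts/transcribe_deepspeed_to_pt.py` in the repository is the authoritative version.

```python
import sys

import torch


def deepspeed_to_pt(src: str, dst: str) -> None:
    """Extract module weights from a DeepSpeed mp_rank_00_model_states.pt checkpoint."""
    states = torch.load(src, map_location="cpu")
    # DeepSpeed checkpoints typically keep the model state_dict under "module".
    state_dict = states["module"] if "module" in states else states
    torch.save(state_dict, dst)


if __name__ == "__main__":
    deepspeed_to_pt(sys.argv[1], sys.argv[2])  # e.g. mp_rank_00_model_states.pt model.pt
```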
## Decoding
- **Single-machine single-GPU decoding**: Refer to `scripts/decode.sh`
- **Single-machine multi-GPU decoding**: Refer to `scripts/decode_deepspeed.sh`

## Multi-Machine Multi-GPU Support

Multi-machine multi-GPU training and decoding can be enabled with minor modifications to `scripts/finetune_deepspeed.sh` or `scripts/decode_deepspeed.sh`. Because the required changes depend on the environment, this example does not include dedicated multi-machine multi-GPU scripts.

examples/aispeech_asr/README_zh.md

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# AISPEECH_ASR

## Overview

This example is prepared for large-scale industrial data training and is suitable for datasets on the order of 100,000 hours. Its main features are:
- **Multi-task training support**: A unified data format supports tasks such as ASR and ST.
- **Dynamic prompt selection**: One prompt is randomly selected from several candidates.
- **Iterative dataset**: An iterative-style dataset reduces startup time on large data volumes.
- **DeepSpeed training**: DeepSpeed training significantly reduces memory usage.
- **Multi-machine multi-GPU inference**: Distributed inference across machines and GPUs reduces evaluation time.
- **Dynamic frame batching**: Batches are assembled dynamically according to the frame count of each audio clip instead of a fixed batch_size, greatly reducing training and evaluation time (training time drops by 3/4 on 100,000 hours of data).
- **Ascend NPU support**: Adapted to run on Ascend NPU.

This example is modified from `mala_asr_slidespeech`.

## Model Architecture

The model architecture can be chosen as needed within the scope supported by SLAM-LLM. Some recommended configurations:
- **Encoder**: WavLM, Whisper
- **Projector**: Linear
- **LLM**: Qwen2.5-7B-Instruct, Vicuna1.5-7B

## Data Preparation

The following two files are required:
- `multitask.jsonl`
- `multiprompt.jsonl`

### multitask.jsonl

The format of this file is as follows, where `path` supports both Kaldi ark entries and wav files:
```json
{"key": "BAC009S0002W0122", "task": "ASR", "target": "而对楼市成交抑制作用最大的限购", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:17"}
{"key": "BAC009S0002W0123", "task": "ASR", "target": "也成为地方政府的眼中钉", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:191758"}
{"key": "BAC009S0002W0124", "task": "ASR", "target": "自六月底呼和浩特市率先宣布取消限购后", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:315339"}
{"key": "BAC009S0764W0238", "task": "hotword", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/test/data/data_wav.1.ark:17343733", "target": "形成一批具有国际竞争力的中国企业", "hotword": "中国"}
```
### multiprompt.jsonl

The format of this file is as follows:
```json
{"task": "ASR", "prompt": "Transcribe speech to text."}
{"task": "ASR", "prompt": "请识别语音."}
{"task": "ZH2EN", "prompt": "请识别语音并翻译为英文:"}
{"task": "EN2ZH", "prompt": "请识别语音并翻译为中文:"}
{"task": "prevtext", "prompt": "Transcribe speech to text, below are the previous historical transcription texts:{}."}
{"task": "hotword", "prompt": "Transcribe speech to text, follow words may occur:{}."}
```

### Notes
- If a task has multiple prompts, one of them is selected dynamically.
- For additional information (e.g., hotwords), provide a field named after the task in `multitask.jsonl` and use `{}` in the prompt to inject that information. Also update `append_info_tasks` in the `aispeech_config` file:
```python
append_info_tasks: List = field(default_factory=lambda: ["hotword"])
```
## Training a New Model

### Script Preparation

Prepare and modify the following variables in `scripts/finetune_deepspeed.sh` or `scripts/finetune_torchrun.sh` (DeepSpeed is recommended):
```bash
run_dir=                     # Directory to save the model
train_scp_file_path=         # Path to training data
dev_scp_file_path=           # Path to validation data
train_max_frame_length=1500  # Maximum frame length during training
eval_max_frame_length=1000   # Maximum frame length during evaluation
multitask_prompt_path=       # Path to multitask.jsonl
prompt_style="\{\}"          # Prompt style, e.g., "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" or "USER: {}\n ASSISTANT:"
projector=linear             # Projector type
encoder_name=whisper         # Encoder name
llm_name=Qwen2.5-7B-Instruct # LLM name
use_peft=false               # Whether to use PEFT (for the LLM)
use_fp16=true                # Whether to use FP16
freeze_encoder=true          # Whether to freeze the encoder
pad_or_trim=true             # Whether to use pad_or_trim (for Whisper)
deepspeed_config=            # Path to the DeepSpeed configuration file
```

Typically, we first train the projector and then fine-tune with LoRA. For projector training, set:
```bash
use_peft=false
```

For LoRA training, set the following (`ckpt_path` is the path to the model saved in the previous step):
```bash
use_peft=true
if [[ $use_peft == "true" ]]; then
    ckpt_path=  # For DDP training, write the path of the saved pt file directly; for DeepSpeed training, convert mp_rank_00_model_states.pt to model.pt using the `scripts/transcribe_deepspeed_to_pt.py` script
fi
```

## Decoding

- **Single-machine single-GPU decoding**: Refer to `scripts/decode.sh`
- **Single-machine multi-GPU decoding**: Refer to `scripts/decode_deepspeed.sh`

## Multi-Machine Multi-GPU Support

Multi-machine multi-GPU training and decoding can be supported with minor modifications to `scripts/finetune_deepspeed.sh` or `scripts/decode_deepspeed.sh`. Because the required changes depend on the environment, this example does not provide dedicated multi-machine multi-GPU scripts.
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
from dataclasses import dataclass, field
from typing import Optional, List
from torch.distributed.fsdp import ShardingStrategy


@dataclass
class ModelConfig:
    file: str = "examples/aispeech_asr/model/aispeech_asr.py:model_factory"
    llm_name: str = "vicuna-7b-v1.5"
    llm_path: str = "PATH/to/LLAMA/7B"
    llm_type: str = "decoder_only"
    llm_dim: int = 4096
    whisper_decode: Optional[bool] = False
    encoder_name: Optional[str] = None
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = None
    encoder_path_hf: Optional[str] = None
    encoder_dim: int = 1280
    encoder_projector: str = "linear"
    qformer_layers: int = 8
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })
    encoder_type: str = field(default="finetune", metadata={
        "help": "whether model is only pretrained or finetuned, used for models such as hubert"
    })
@dataclass
class PeftConfig:
    peft_method: str = "lora"  # None, llama_adapter, prefix
    r: int = 64
    lora_alpha: int = 16
    target_modules: List = field(default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False


@dataclass
class TrainConfig:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: Optional[int] = None
    batching_strategy: str = field(default="packing", metadata={
        "help": "alternative: padding"
    })
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: Optional[int] = None

    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # will be used if using FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = False  # Enable SDPA from PyTorch Accelerated Transformers to use Flash Attention and Xformers memory-efficient kernels
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|ASR|>"
    freeze_llm: bool = field(default=False, metadata={
        "help": "whether to freeze the llm when finetuning, should be true when using peft finetuning"
    })
    freeze_encoder: bool = False
@dataclass
class DataConfig:
    dataset: str = "multitask_dataset"
    train_max_frame_length: int = 1500
    eval_max_frame_length: int = 1000
    multitask_prompt_path: str = "/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/multiprompt.jsonl"
    prompt_style: str = "\{\}"
    append_info_tasks: List = field(default_factory=lambda: ["hotword"])
    file: str = "examples/aispeech_asr/slam_llm/datasets/speech_dataset_large.py:get_speech_dataset"
    train_scp_file_path: str = ""
    dev_scp_file_path: str = ""
    test_scp_file_path: str = ""
    train_split: str = "train"
    dev_split: str = "dev"
    test_split: str = "test"
    pad_or_trim: bool = True
    prompt: Optional[str] = None
    use_ocr: bool = True
    inference_mode: bool = False
    lower: bool = False
    fix_length_audio: int = -1
    input_type: str = field(default="raw", metadata={
        "help": "Use raw when the input is wav, mel when using Whisper features"
    })
    mel_size: int = field(default=80, metadata={
        "help": "80 for whisper large v1 and v2, 128 for v3"
    })
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })


@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD"  # ShardingStrategy.FULL_SHARD
    sharding_strategy: ShardingStrategy = "SHARD_GRAD_OP"  # ShardingStrategy.NO_SHARD  # MZY: set NO_SHARD when using DDP
    checkpoint_type: str = "SHARDED_STATE_DICT"  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size; alternatively use FULL_STATE_DICT
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"


@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "tmp/test_wandb"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "tmp/test.log"
    log_interval: int = 5
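As a usage note, these dataclasses can be instantiated directly and overridden field by field; the shell scripts above effectively do the same thing through command-line overrides. A minimal sketch with illustrative values (it assumes the dataclasses defined above are importable):

```python
# Illustrative only: build the configs for a LoRA fine-tuning run.
model_config = ModelConfig(encoder_name="whisper", llm_name="Qwen2.5-7B-Instruct", encoder_projector="linear")
peft_config = PeftConfig(r=64, lora_alpha=16)
train_config = TrainConfig(use_peft=True, peft_config=peft_config, freeze_encoder=True, use_fp16=True)
data_config = DataConfig(train_max_frame_length=1500, eval_max_frame_length=1000)

print(train_config.peft_config.target_modules)  # modules wrapped by LoRA
```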
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
{
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 5e-5,
            "betas": [0.9, 0.999],
            "eps": 1e-06
        }
    },
    "bf16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 100,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 0.01
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "overlap_comm": true,
        "reduce_scatter": true,
        "contiguous_gradients": true
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0.00,
            "warmup_max_lr": 0.00005,
            "warmup_num_steps": 1000
        }
    },
    "checkpoint_activations": false
}
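For context, this JSON is the kind of file referenced by `deepspeed_config` in the training script. A minimal sketch of how such a config is typically consumed; the stand-in model and the `ds_config.json` filename are placeholders for illustration, not part of this commit:

```python
import deepspeed
import torch

# A stand-in model just to show how the JSON config above is consumed;
# in the example, the SLAM-LLM model built by model_factory would be passed instead.
model = torch.nn.Linear(512, 512)

model_engine, optimizer, _, scheduler = deepspeed.initialize(
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
    config="ds_config.json",  # hypothetical filename for the configuration shown above
)
```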
