152 changes: 152 additions & 0 deletions examples/aispeech_asr/README.md
# AISPEECH_ASR

## Overview

This example is designed for large-scale industrial data training, suitable for datasets on the order of 100,000 hours. Its main features include:
- **Support for multi-task training**: Designed to support tasks such as ASR and ST through a unified data format.
- **Dynamic prompt selection**: Supports random selection from multiple prompts.
- **Iterative dataset**: Uses an iterative dataset format to reduce startup time for large datasets.
- **DeepSpeed training**: Supports DeepSpeed training to significantly reduce memory usage.
- **Multi-machine multi-GPU inference**: Supports distributed inference across multiple machines and GPUs to reduce evaluation time.
- **Dynamic frame batching**: Packs batches by total frame count per audio rather than using a fixed batch size, significantly reducing training and evaluation time (roughly a 3/4 reduction in training time on 100,000 hours of data); a minimal sketch of the idea follows below.

This example is modified from `mala_asr_slidespeech`.
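
The dynamic frame batching mentioned above can be thought of as a greedy packer over frame counts. The sketch below only illustrates the idea (the actual iterative dataset implementation differs); the names and thresholds are illustrative.

```python
from typing import Iterable, Iterator, List, Tuple


def dynamic_frame_batches(
    utterances: Iterable[Tuple[str, int]],  # (key, num_frames) pairs
    max_frame_length: int = 1500,           # e.g. train_max_frame_length
) -> Iterator[List[str]]:
    """Greedily pack utterances until adding one more would exceed the frame cap."""
    batch, total = [], 0
    for key, frames in utterances:
        if batch and total + frames > max_frame_length:
            yield batch
            batch, total = [], 0
        batch.append(key)
        total += frames
    if batch:
        yield batch


# Short utterances form large batches; long utterances end up nearly alone.
for b in dynamic_frame_batches([("a", 300), ("b", 400), ("c", 900), ("d", 1400)]):
    print(b)  # ['a', 'b'], ['c'], ['d']
```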

## Model Architecture

The model architecture can be chosen freely within the scope supported by SLAM-LLM. Below are some recommended configurations:
- **Encoder**: WavLM, Whisper
- **Projector**: Linear
- **LLM**: Qwen2.5-7B-Instruct, Vicuna1.5-7B

## Data Preparation

The following two files are required:
- `multitask.jsonl`
- `multiprompt.jsonl`

### multitask.jsonl

The format of this file is as follows, where `path` supports both Kaldi ark entries (`file.ark:offset`) and plain wav files:
```json
{"key": "BAC009S0002W0122", "task": "ASR", "target": "而对楼市成交抑制作用最大的限购", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:17"}
{"key": "BAC009S0002W0123", "task": "ASR", "target": "也成为地方政府的眼中钉", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:191758"}
{"key": "BAC009S0002W0124", "task": "ASR", "target": "自六月底呼和浩特市率先宣布取消限购后", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:315339"}
{"key": "BAC009S0764W0238", "task": "hotword", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/test/data/data_wav.1.ark:17343733", "target": "形成一批具有国际竞争力的中国企业", "hotword": "中国"}
```
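
As an illustration of how such entries can be consumed, the sketch below parses one line and loads its audio; it assumes the `kaldiio` and `soundfile` packages are available and is not the example's own loader.

```python
import json

import kaldiio    # assumed available for reading Kaldi "archive.ark:offset" entries
import soundfile  # assumed available for plain wav files


def load_multitask_entry(line: str):
    """Parse one multitask.jsonl line and load its audio."""
    entry = json.loads(line)
    path = entry["path"]
    if ".ark:" in path:
        # kaldiio returns (sample_rate, samples) for wav data stored in an ark
        sample_rate, samples = kaldiio.load_mat(path)
    else:
        samples, sample_rate = soundfile.read(path)
    return entry["key"], entry["task"], entry["target"], samples, sample_rate
```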

### multiprompt.jsonl

The format of this file is as follows:
```json
{"task": "ASR", "prompt": "Transcribe speech to text."}
{"task": "ASR", "prompt": "请识别语音."}
{"task": "ZH2EN", "prompt": "请识别语音并翻译为英文:"}
{"task": "EN2ZH", "prompt": "请识别语音并翻译为中文:"}
{"task": "prevtext", "prompt": "Transcribe speech to text, below are the previous historical transcription texts:{}."}
{"task": "hotword", "prompt": "Transcribe speech to text, follow words may occur:{}."}
```

### Notes
- If multiple prompts are provided, one will be selected dynamically.
- For additional information (e.g., hotwords), add a field named after the task in `multitask.jsonl` and use `{}` in the prompt to inject it (see the sketch after the snippet below). Also update `append_info_tasks` in `aispeech_asr_config.py`:
```python
append_info_tasks: List = field(default_factory=lambda: ["hotword"])
```
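
For reference, the selection and injection described above could look roughly like this; a minimal sketch with illustrative function names, not the actual dataset code:

```python
import json
import random


def load_prompts(multiprompt_path: str) -> dict:
    """Map task name -> list of candidate prompts from multiprompt.jsonl."""
    prompts = {}
    with open(multiprompt_path, encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            prompts.setdefault(item["task"], []).append(item["prompt"])
    return prompts


APPEND_INFO_TASKS = ["hotword"]  # mirrors append_info_tasks in the config


def build_prompt(entry: dict, prompts: dict) -> str:
    task = entry["task"]
    prompt = random.choice(prompts[task])        # dynamic prompt selection
    if task in APPEND_INFO_TASKS:
        prompt = prompt.format(entry[task])      # inject e.g. entry["hotword"] into "{}"
    return prompt
```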

## Training a New Model

### Script Preparation

Prepare and modify the following content in `scripts/finetune_deepspeed.sh` or `scripts/finetune_torchrun.sh` (DeepSpeed is recommended):
```bash
run_dir= # Directory to save the model
train_scp_file_path= # Path to training data
dev_scp_file_path= # Path to validation data
train_max_frame_length=1500 # Maximum frame length for training
eval_max_frame_length=1000 # Maximum frame length for evaluation
multitask_prompt_path= # Path to multitask.jsonl
prompt_style="\{\}" # Prompt style, e.g., "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" or "USER: {}\n ASSISTANT:"
projector=linear # Type of projector
encoder_name=whisper # Name of the encoder
llm_name=Qwen2.5-7B-Instruct # Name of the LLM
use_peft=false # Whether to use PEFT (for LLM)
use_fp16=true # Whether to use FP16
freeze_encoder=true # Whether to freeze the encoder
pad_or_trim=true # Whether to use pad_or_trim (for Whisper)
deepspeed_config= # Path to DeepSpeed configuration file
```

Typically, we first train the projector and then fine-tune with LoRA. For projector training, set:
```bash
use_peft=false
```

For LoRA training, set (with `ckpt_path` pointing to the model saved in the previous step):
```bash
use_peft=true
if [[ $use_peft == "true" ]]; then
    ckpt_path= # For DDP training, provide the path to the saved .pt file; for DeepSpeed training, convert mp_rank_00_model_states.pt to model.pt using the `scripts/transcribe_deepspeed_to_pt.py` script
fi
```

### DeepSpeed

When training with `bf16`/`fp16`, DeepSpeed saves about 20 GB of GPU memory compared to `torchrun` for a 7B model. For 7B models, ZeRO stage 0/1/2 is recommended; for extremely large models, ZeRO-3 can be used, though communication may become a bottleneck. A typical `deepspeed_config` looks like this:

```json
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
```

Note that when using ZeRO stage 0/1/2, DeepSpeed saves the checkpoint in a format that must be converted: run `python scripts/transcribe_deepspeed_to_pt.py mp_rank_00_model_states.pt output_dir` to turn `mp_rank_00_model_states.pt` into `model.pt` (a rough sketch of the conversion follows the listing). The checkpoint directory looks like this:

```
global_step1000
global_step1000/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
...
global_step1000/mp_rank_00_model_states.pt
latest
zero_to_fp32.py
```
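
A rough sketch of what such a conversion does is shown below, assuming the checkpoint stores the model weights under DeepSpeed's usual `module` key; refer to `scripts/transcribe_deepspeed_to_pt.py` for the authoritative version.

```python
import os
import sys

import torch


def deepspeed_states_to_pt(states_path: str, output_dir: str) -> None:
    """Extract model weights from a ZeRO-0/1/2 checkpoint and save them as model.pt."""
    checkpoint = torch.load(states_path, map_location="cpu")
    state_dict = checkpoint.get("module", checkpoint)  # DeepSpeed usually nests weights under "module"
    os.makedirs(output_dir, exist_ok=True)
    torch.save(state_dict, os.path.join(output_dir, "model.pt"))


if __name__ == "__main__":
    deepspeed_states_to_pt(sys.argv[1], sys.argv[2])
```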

If training with ZeRO-3, the model is saved in a different format and can be converted with DeepSpeed's bundled script, e.g. `python zero_to_fp32.py global_step50 output_dir`.

```
global_step50
global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt
global_step50/zero_pp_rank_0_mp_rank_00_optim_states.pt
...
latest
zero_to_fp32.py
```
If you use bf16/fp16 training in DeepSpeed and encounter NaN in the train/eval loss, check the `autocast` call in `src/slam_llm/utils/deepspeed_utils.py` and pass the dtype explicitly:

```python
with autocast():                        # original code
with autocast(dtype=torch.bfloat16):    # pass bf16 explicitly; this typically resolves the NaN
with autocast(dtype=torch.float16):     # or fp16
```

## Decoding

- **Single-machine single-GPU decoding**: Refer to `scripts/decode.sh`
- **Single-machine multi-GPU decoding**: Refer to `scripts/decode_deepspeed.sh`

## Multi-Machine Multi-GPU Support

Multi-machine multi-GPU training can be supported with minor modifications to the `scripts/finetune_deepspeed.sh` or `scripts/decode_deepspeed.sh` scripts. Because the required changes depend on the environment, this example does not include dedicated multi-machine multi-GPU scripts.
99 changes: 99 additions & 0 deletions examples/aispeech_asr/README_zh.md
# AISPEECH_ASR

## Overview

This example is designed for large-scale industrial data training and targets datasets on the order of 100,000 hours. Its main features are:
- **Multi-task training support**: A unified data format supports tasks such as ASR and ST.
- **Dynamic prompt selection**: One prompt is randomly selected from several candidates.
- **Iterative dataset**: An iterable-style dataset reduces startup time on large datasets.
- **DeepSpeed training**: DeepSpeed support significantly reduces memory usage.
- **Multi-machine multi-GPU inference**: Distributed inference reduces evaluation time.
- **Dynamic frame batching**: Batches are packed by total frame count per audio rather than using a fixed batch_size, greatly reducing training and evaluation time (roughly a 3/4 reduction in training time on 100,000 hours of data).

This example is modified from `mala_asr_slidespeech`.

## Model Architecture

The model architecture can be chosen as needed within the scope supported by SLAM-LLM. Some recommended configurations:
- **Encoder**: WavLM, Whisper
- **Projector**: Linear
- **LLM**: Qwen2.5-7B-Instruct, Vicuna1.5-7B

## Data Preparation

The following two files are required:
- `multitask.jsonl`
- `multiprompt.jsonl`

### multitask.jsonl

The format of this file is as follows, where `path` supports both Kaldi ark entries (`file.ark:offset`) and plain wav files:
```json
{"key": "BAC009S0002W0122", "task": "ASR", "target": "而对楼市成交抑制作用最大的限购", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:17"}
{"key": "BAC009S0002W0123", "task": "ASR", "target": "也成为地方政府的眼中钉", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:191758"}
{"key": "BAC009S0002W0124", "task": "ASR", "target": "自六月底呼和浩特市率先宣布取消限购后", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:315339"}
{"key": "BAC009S0764W0238", "task": "hotword", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/test/data/data_wav.1.ark:17343733", "target": "形成一批具有国际竞争力的中国企业", "hotword": "中国"}
```

### multiprompt.jsonl

The format of this file is as follows:
```json
{"task": "ASR", "prompt": "Transcribe speech to text."}
{"task": "ASR", "prompt": "请识别语音."}
{"task": "ZH2EN", "prompt": "请识别语音并翻译为英文:"}
{"task": "EN2ZH", "prompt": "请识别语音并翻译为中文:"}
{"task": "prevtext", "prompt": "Transcribe speech to text, below are the previous historical transcription texts:{}."}
{"task": "hotword", "prompt": "Transcribe speech to text, follow words may occur:{}."}
```

### Notes
- If multiple prompts are provided, one will be selected dynamically.
- For additional information (e.g., hotwords), add a field named after the task in `multitask.jsonl` and use `{}` in the prompt to inject it. Also update `append_info_tasks` in `aispeech_asr_config.py`:
```python
append_info_tasks: List = field(default_factory=lambda: ["hotword"])
```

## Training a New Model

### Script Preparation

Prepare and modify the following in `scripts/finetune_deepspeed.sh` or `scripts/finetune_torchrun.sh` (DeepSpeed is recommended):
```bash
run_dir= # Directory to save the model
train_scp_file_path= # Path to training data
dev_scp_file_path= # Path to validation data
train_max_frame_length=1500 # Maximum frame length for training
eval_max_frame_length=1000 # Maximum frame length for evaluation
multitask_prompt_path= # Path to multitask.jsonl
prompt_style="\{\}" # Prompt style, e.g., "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" or "USER: {}\n ASSISTANT:"
projector=linear # Type of projector
encoder_name=whisper # Name of the encoder
llm_name=Qwen2.5-7B-Instruct # Name of the LLM
use_peft=false # Whether to use PEFT (for the LLM)
use_fp16=true # Whether to use FP16
freeze_encoder=true # Whether to freeze the encoder
pad_or_trim=true # Whether to use pad_or_trim (for Whisper)
deepspeed_config= # Path to the DeepSpeed configuration file
```

Typically, we first train the projector and then fine-tune with LoRA. For projector training, set:
```bash
use_peft=false
```

For LoRA training, set (with `ckpt_path` pointing to the model saved in the previous step):
```bash
use_peft=true
if [[ $use_peft == "true" ]]; then
    ckpt_path= # For DDP training, provide the path to the saved .pt file; for DeepSpeed training, convert mp_rank_00_model_states.pt to model.pt using the `scripts/transcribe_deepspeed_to_pt.py` script
fi
```

## Decoding

- **Single-machine single-GPU decoding**: Refer to `scripts/decode.sh`
- **Single-machine multi-GPU decoding**: Refer to `scripts/decode_deepspeed.sh`

## Multi-Machine Multi-GPU Support

Multi-machine multi-GPU training can be supported with minor modifications to `scripts/finetune_deepspeed.sh` or `scripts/decode_deepspeed.sh`. Because the required changes depend on the environment, this example does not include dedicated multi-machine multi-GPU scripts.
141 changes: 141 additions & 0 deletions examples/aispeech_asr/aispeech_asr_config.py
from dataclasses import dataclass, field
from typing import Optional, List
from torch.distributed.fsdp import ShardingStrategy


@dataclass
class ModelConfig:
    file: str = "examples/aispeech_asr/model/aispeech_asr.py:model_factory"
    llm_name: str = "vicuna-7b-v1.5"
    llm_path: str = "PATH/to/LLAMA/7B"
    llm_type: str = "decoder_only"
    llm_dim: int = 4096
    whisper_decode: Optional[bool] = False
    encoder_name: Optional[str] = None
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = None
    encoder_path_hf: Optional[str] = None
    encoder_dim: int = 1280
    encoder_projector: str = "linear"
    qformer_layers: int = 8
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })
    encoder_type: str = field(default="finetune", metadata={
        "help": "whether model is only pretrained or finetuned, used for models such as hubert"
    })


@dataclass
class PeftConfig:
    peft_method: str = "lora"  # None, llama_adapter, prefix
    r: int = 64
    lora_alpha: int = 16
    target_modules: List = field(default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False

@dataclass
class TrainConfig:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: Optional[int] = None
    batching_strategy: str = field(default="packing", metadata={
        "help": "alternative: padding"
    })
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: Optional[int] = None

    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # will be used if using FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = False  # Enable SDPA from PyTorch Accelerated Transformers, which uses Flash Attention and Xformers memory-efficient kernels
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|ASR|>"
    freeze_llm: bool = field(default=False, metadata={
        "help": "whether to freeze the llm when finetuning, should be true when using peft finetuning"
    })
    freeze_encoder: bool = False

@dataclass
class DataConfig:
    dataset: str = "multitask_dataset"
    train_max_frame_length: int = 1500
    eval_max_frame_length: int = 1000
    multitask_prompt_path: str = "/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/multiprompt.jsonl"
    prompt_style: str = "\{\}"
    append_info_tasks: List = field(default_factory=lambda: ["hotword"])
    file: str = "examples/aispeech_asr/slam_llm/datasets/speech_dataset_large.py:get_speech_dataset"
    train_scp_file_path: str = ""
    dev_scp_file_path: str = ""
    test_scp_file_path: str = ""
    train_split: str = "train"
    dev_split: str = "dev"
    test_split: str = "test"
    pad_or_trim: bool = True
    prompt: Optional[str] = None
    use_ocr: bool = True
    inference_mode: bool = False
    lower: bool = False
    fix_length_audio: int = -1
    input_type: str = field(default="raw", metadata={
        "help": "Use raw when the input is wav, mel for whisper"
    })
    mel_size: int = field(default=80, metadata={
        "help": "80 for whisper large v1 and v2, 128 for v3"
    })
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })

@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD"  # ShardingStrategy.FULL_SHARD
    sharding_strategy: ShardingStrategy = "SHARD_GRAD_OP"  # ShardingStrategy.NO_SHARD  # MZY: set NO_SHARD when using DDP
    checkpoint_type: str = "SHARDED_STATE_DICT"  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"

@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "tmp/test_wandb"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "tmp/test.log"
    log_interval: int = 5