Commit 80cc33f

Merge pull request #173 from X-LANCE/ygr_pr2
for ctc-assisted llm-based CASR codes PR
2 parents f32b8a2 + 52fab27 commit 80cc33f

15 files changed: +1385 −0 lines

examples/contextual_asr/README.md

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# CTC-Assisted LLM-Based Contextual ASR

## Guides

[CTC-Assisted LLM-Based Contextual ASR](https://arxiv.org/abs/2411.06437) is an LLM-based contextual ASR model that first uses CTC decoding results to filter potentially relevant hotwords from a pre-defined hotword list, then incorporates them into the LLM prompt input to improve recognition of those hotwords.
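The filtering step can be pictured with a short sketch: score each pre-defined hotword against the CTC hypothesis and keep only close matches for the prompt. This is an illustration of the idea only; the matching rule, threshold, and names below are assumptions, not the repo's exact implementation (the recipe's `filter_type`, `probability_threshold`, and `word_num` options play the corresponding roles).

```
from difflib import SequenceMatcher
from typing import List

def filter_hotwords(ctc_hyp: str, hotwords: List[str],
                    threshold: float = 0.9, word_num: int = 15) -> List[str]:
    """Keep hotwords whose characters closely match a span of the CTC output."""
    kept = []
    for word in hotwords:
        w, h = word.lower(), ctc_hyp.lower()
        match = SequenceMatcher(None, w, h).find_longest_match(0, len(w), 0, len(h))
        score = match.size / max(len(w), 1)  # fraction of the hotword recovered
        if score >= threshold:
            kept.append((score, word))
    kept.sort(reverse=True)                  # best matches first
    return [word for _, word in kept[:word_num]]

hyp = "the whale circled the pequod once more"
print(filter_hotwords(hyp, ["PEQUOD", "NAUTILUS", "AHAB"], threshold=0.8))  # ['PEQUOD']
```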
## Model Architecture

We use the WavLM-Large model, pre-trained on 94,000 hours of data and fine-tuned on 960 hours of LibriSpeech data with CTC loss, as our speech encoder. We use the public Vicuna 7B as our large language model decoder, and a simple linear projector, consisting of a 1-D convolution layer and two linear layers, as our adapter. Refer to our [paper](https://arxiv.org/pdf/2411.06437) for more details.

![](docs/model.png)
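The adapter is small enough to sketch directly. Below is a minimal PyTorch version under assumed dimensions: a stride-5 1-D convolution for 5x temporal downsampling, then two linear layers mapping 1024-dim WavLM features into Vicuna's 4096-dim embedding space. The hidden width is an assumption, though with these sizes the parameter count comes out near the ~15.74M listed for the projector checkpoint below.

```
import torch
import torch.nn as nn

class LinearProjector(nn.Module):
    """1-D conv + two linear layers, as described above; dims are assumptions."""
    def __init__(self, encoder_dim=1024, llm_dim=4096, ds_rate=5):
        super().__init__()
        # stride-`ds_rate` convolution downsamples the WavLM frame rate 5x
        self.conv = nn.Conv1d(encoder_dim, encoder_dim,
                              kernel_size=ds_rate, stride=ds_rate)
        self.linear1 = nn.Linear(encoder_dim, 2 * encoder_dim)
        self.linear2 = nn.Linear(2 * encoder_dim, llm_dim)

    def forward(self, x):  # x: (batch, frames, encoder_dim)
        x = self.conv(x.transpose(1, 2)).transpose(1, 2)
        return self.linear2(torch.relu(self.linear1(x)))

feats = torch.randn(2, 100, 1024)      # 100 frames of WavLM-Large features
print(LinearProjector()(feats).shape)  # torch.Size([2, 20, 4096])
```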
## Checkpoints

We only train the linear projector in this recipe.

| Encoder | Projector | LLM |
|---|---|---|
| [CTC Fine-tuned WavLM-Large](https://drive.google.com/file/d/12ZmSSbDvx73W0eK1wpUgajapCLhqh5DI/view?usp=drive_link) (~315.45M) | [Linear](https://drive.google.com/file/d/1Zlbsnz1YUWtYtt-yNyoPK5OhR30kwLfS/view?usp=drive_link) (~15.74M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) (~6.7B) |
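Because only the projector is updated, the encoder and LLM stay frozen during training. A minimal PyTorch sketch of that setup with stand-in modules (the attribute names are assumptions, corresponding to `freeze_encoder=True` and `freeze_llm=True` in the training config):

```
import torch.nn as nn

class ToySlamModel(nn.Module):
    """Stand-in with the three parts named above (attribute names assumed)."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 16)            # WavLM-Large stand-in
        self.encoder_projector = nn.Linear(16, 32)  # projector stand-in (trained)
        self.llm = nn.Linear(32, 32)                # vicuna-7b-v1.5 stand-in

model = ToySlamModel()
for p in model.encoder.parameters():  # freeze_encoder=True
    p.requires_grad = False
for p in model.llm.parameters():      # freeze_llm=True
    p.requires_grad = False
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # only the projector's parameters remain trainable
```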
## Performance

![](docs/performance.png)
## Data preparation

The artificial biasing lists constructed in [Contextualized streaming end-to-end speech recognition with trie-based deep biasing and shallow fusion](https://arxiv.org/pdf/2104.02194) are used for contextual ASR testing; refer to the official [repo](https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias). They categorize the 5,000 most frequent words in the LibriSpeech training corpus as common words, with the remainder classified as rare words. The biasing list generated for each test utterance consists of two segments: the rare words in its transcription, and distractors sampled from the 209.2K-word rare-word vocabulary. Biasing lists of varying lengths are generated by incorporating N = {100, 500, 1000, 2000} distractors, as sketched below.
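The function and variable names in this sketch are hypothetical; the real lists come from the repo linked above.

```
import random
from typing import List, Set

def make_biasing_list(ref_words: List[str], common_5k: Set[str],
                      rare_vocab: Set[str], n_distractors: int = 100) -> List[str]:
    """Rare words from the reference transcript plus N sampled distractors."""
    rare_in_ref = sorted({w for w in ref_words if w not in common_5k})
    pool = sorted(rare_vocab - set(rare_in_ref))      # 209.2K rare-word vocabulary
    distractors = random.sample(pool, n_distractors)  # N in {100, 500, 1000, 2000}
    return rare_in_ref + distractors

biasing = make_biasing_list(
    "the pequod sailed".split(), common_5k={"the", "sailed"},
    rare_vocab={"pequod", "nautilus", "ahab", "ishmael"}, n_distractors=2)
print(biasing)  # ['pequod'] plus two sampled distractors
```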
## Decoding with checkpoints

LLM-based ASR inference script:
```
bash decode_wavlm_libri960_ft_char.sh
```

LLM-based contextual ASR inference script, with different biasing list sizes:
```
bash decode_wavlm_libri960_ft_char_hotwords.sh
```
## Training the model

LLM-based ASR training script, using CTC fine-tuned WavLM as encoder and "Transcribe speech to text." as prompt:
```
bash finetune_wavlm_libri960_ft_char.sh
```

LLM-based contextual ASR training script, using CTC fine-tuned WavLM as encoder and "Transcribe speech to text. Some hotwords might help. The hotwords are {}." as prompt:
```
bash finetune_wavlm_libri960_ft_char_hotwords.sh
```
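For concreteness, the contextual prompt is the template above with the filtered hotwords spliced into the `{}` slot; how the words are delimited is an assumption here:

```
hotwords = ["PEQUOD", "AHAB"]  # output of the CTC-based filtering step
prompt = ("Transcribe speech to text. Some hotwords might help. "
          "The hotwords are {}.".format(" ".join(hotwords)))
print(prompt)
# Transcribe speech to text. Some hotwords might help. The hotwords are PEQUOD AHAB.
```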
## Citation

You can refer to the paper for more results.
```
@article{yang2024ctc,
  title={CTC-Assisted LLM-Based Contextual ASR},
  author={Yang, Guanrou and Ma, Ziyang and Gao, Zhifu and Zhang, Shiliang and Chen, Xie},
  journal={Proc. SLT},
  year={2024}
}
```
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
1+
{
2+
"train_micro_batch_size_per_gpu": 4,
3+
"gradient_accumulation_steps": 1,
4+
"optimizer": {
5+
"type": "Adam",
6+
"params": {
7+
"lr": 1e-4
8+
}
9+
},
10+
"fp16": {
11+
"enabled": true
12+
},
13+
"zero_optimization": {
14+
"stage": 3,
15+
"offload_optimizer": {
16+
"device": "cpu"
17+
}
18+
}
19+
}
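A DeepSpeed config like this one is consumed at engine initialization. A minimal sketch, assuming the standard `deepspeed.initialize` API and a stand-in model (the real recipe builds the full SLAM-LLM model and launches through the shell scripts):

```
import torch.nn as nn
import deepspeed

model = nn.Linear(16, 16)  # stand-in for the real model
# Wraps the model with ZeRO stage-3 partitioning and CPU optimizer offload,
# per the JSON above; intended to run under the `deepspeed` launcher.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",  # assumed path to the config above
)
```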
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
1+
dataset_config:
2+
# we put prompt here, because the hydra override in shell script only support a small subset of chars
3+
# prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
4+
prompt: "Transcribe speech to text. "
Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
1+
from dataclasses import dataclass, field
2+
from typing import Optional, List
3+
@dataclass
4+
class ModelConfig:
5+
file: str = "examples/contextual_asr/model/slam_model_contextual_asr.py:model_factory"
6+
llm_name: str = "vicuna-13b-v1.5"
7+
llm_path: str = "PATH/to/LLAMA/7B"
8+
llm_type: str = "decoder_only"
9+
llm_dim: int = 4096
10+
encoder_name: Optional[str] = None
11+
encoder_ds_rate: int = 2
12+
encoder_path: Optional[str] = None
13+
encoder_dim: int = 1280
14+
encoder_projector: str = "linear"
15+
encoder_projector_ds_rate: int = 5
16+
modal: str = "audio"
17+
normalize: Optional[bool] = field(default=False, metadata={
18+
"help": "whether input is normalized, used for models such as wavlm"
19+
})
20+
encoder_type: str = field(default="finetune", metadata={
21+
"help": "whether model is only pretrained or finetuned, used for models such as hubert"
22+
})
23+
24+
@dataclass
25+
class PeftConfig:
26+
peft_method: str = "lora" # None , llama_adapter, prefix
27+
r: int = 8
28+
lora_alpha: int = 32
29+
# target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ])
30+
target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj","k_proj","o_proj" ])
31+
bias: str = "none"
32+
task_type: str = "CAUSAL_LM"
33+
lora_dropout: float = 0.05
34+
inference_mode: bool = False
35+
36+
@dataclass
37+
class TrainConfig:
38+
model_name:str = "PATH/to/LLAMA/7B"
39+
enable_ddp:bool = False
40+
enable_deepspeed:bool = False
41+
enable_fsdp:bool = False
42+
low_cpu_fsdp:bool = False
43+
run_validation:bool = True
44+
batch_size_training:int = 4
45+
batching_strategy:str = field(default="packing", metadata={
46+
"help":"alternative: padding"
47+
})
48+
context_length:int = 4096
49+
gradient_accumulation_steps:int = 1
50+
num_epochs:int = 3
51+
num_workers_dataloader:int = 1
52+
warmup_steps:int = 1000
53+
total_steps:int = 100000
54+
validation_interval:int = 1000
55+
lr:float = 1e-4
56+
weight_decay:float = 0.0
57+
gamma:float = 0.85
58+
seed:int = 42
59+
use_fp16:bool = False
60+
mixed_precision:bool = True
61+
val_batch_size:int = 1
62+
use_peft:bool = False
63+
peft_config:PeftConfig = field(default_factory=PeftConfig)
64+
output_dir:str = "PATH/to/save/PEFT/model"
65+
freeze_layers:bool = False
66+
num_freeze_layers:int = 1
67+
quantization:bool = False
68+
one_gpu:bool = False
69+
save_model:bool = True
70+
dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP
71+
dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP
72+
save_optimizer:bool = False # will be used if using FSDP
73+
use_fast_kernels:bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
74+
run_test_during_validation:bool = False
75+
run_test_during_validation_file:str = "test.wav"
76+
run_test_during_validation_prompt:str = "<|ASR|>"
77+
freeze_llm:bool = field(default=False, metadata={
78+
"help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
79+
})
80+
freeze_encoder:bool = False
81+
82+
@dataclass
83+
class DataConfig:
84+
dataset: str = "speech_dataset"
85+
file: str = "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
86+
train_data_path: Optional[str] = None
87+
val_data_path: Optional[str] = None
88+
train_split: str = "train"
89+
test_split:str = "validation"
90+
prompt: Optional[str] = None
91+
data_path: Optional[str] = None
92+
max_words: Optional[int] = None
93+
max_mel: Optional[float] = None
94+
fix_length_audio: int = -1
95+
inference_mode:bool = False
96+
input_type: str = field(default="raw", metadata={
97+
"help":"Use raw when input is wav, mel when for whisper"
98+
})
99+
mel_size: int = field(default=80, metadata={
100+
"help": "80 for whisper large v1 and v2, 128 for v3"
101+
})
102+
normalize: Optional[bool] = field(default=False, metadata={
103+
"help": "whether input is normalized, used for models such as wavlm"
104+
})
105+
infer_type: str = "bias"
106+
infer_file: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/test-clean.biasing_100.tsv"
107+
ctc_file: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_large_libri_test_other_char.txt"
108+
filter_type: str = "char"
109+
phn_to_name_dict: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_ft_libri960_${ref_split}_phn.json"
110+
common_words_5k_dir: str="/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/words/common_words_5k.txt"
111+
probability_threshold: float = 0.9
112+
word_num: int = 15
113+
filter_infer_sentence: bool = False
114+
filter_infer_sentence_few: bool = False
115+
first: int = 1
116+
117+
@dataclass
118+
class FSDPConfig:
119+
mixed_precision: bool = True
120+
use_fp16: bool = False
121+
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
122+
sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
123+
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
124+
fsdp_activation_checkpointing: bool = True
125+
fsdp_cpu_offload: bool = False
126+
pure_bf16: bool = False
127+
optimizer: str = "AdamW"
128+
129+
@dataclass
130+
class LogConfig:
131+
use_wandb: bool = False
132+
wandb_dir: str = "/root/test_wandb"
133+
wandb_entity_name: str = "project_name"
134+
wandb_project_name: str = "project_name"
135+
wandb_exp_name: str = "exp_name"
136+
log_file: str = "/root/test.log"
137+
log_interval: int = 5
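As a hypothetical illustration of how these dataclasses are populated (values echo the recipe's defaults and README; in practice the shell scripts set them through hydra overrides):

```
# Hypothetical direct instantiation; the recipe normally fills these via hydra.
model_cfg = ModelConfig(llm_name="vicuna-7b-v1.5", encoder_name="wavlm",
                        encoder_projector="linear")
train_cfg = TrainConfig(freeze_encoder=True,  # only the projector is trained
                        freeze_llm=True, lr=1e-4)
data_cfg = DataConfig(infer_type="bias", filter_type="char",
                      probability_threshold=0.9, word_num=15)
```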
