
Commit bd5e5aa

Merge pull request #183 from X-LANCE/main
merge latest main branch
2 parents: 85191fb + 87e1449


64 files changed: +4874 / -109315 lines

README.md

Lines changed: 66 additions & 3 deletions
@@ -20,15 +20,17 @@ developers to train custom multimodal large language model (MLLM), focusing on <
 # Table of Contents
 1. [News](#news)
 2. [Installation](#installation)
-3. [Uasge](#uasge)
+3. [Usage](#usage)
 - [List of Recipes](#list-of-recipes)
 - [Configuration Priority](#configuration-priority)
 4. [Features](#features)
 5. [Acknowledge](#acknowledge)
 6. [Citation](#citation)
 
 # News
-- [Update Oct. 12, 2024] Recipes for [SLAM-AAC](examples/slam_aac/README.md) have been supported.
+- [Update Nov. 17, 2024] Recipes for [LLM-Based Contextual ASR](examples/contextual_asr/README.md) have been supported.
+- [Update Nov. 5, 2024] Recipes for [speech emotion captioning (SEC)](examples/sec_emotioncaps/README.md) with [emotion2vec](https://github.com/ddlBoJack/emotion2vec) as the encoder have been supported.
+- [Update Oct. 12, 2024] Recipes for [SLAM-AAC](examples/slam_aac/README.md) with [EAT](https://github.com/cwx-worst-one/EAT) as the encoder have been supported.
 - [Update Sep. 28, 2024] Recipes for [CoT-ST](examples/st_covost2/README.md) have been supported.
 - [Update Sep. 25, 2024] Recipes for [DRCap](examples/drcap_zeroshot_aac/README.md) have been supported.
 - [Update Jun. 12, 2024] Recipes for [MaLa-ASR](examples/mala_asr_slidespeech/README.md) have been supported.
@@ -83,13 +85,15 @@ We provide reference implementations of various LLM-based speech, audio, and mus
 
 - Contextual Automatic Speech Recognition (CASR)
 - [ Mala-ASR](examples/mala_asr_slidespeech/README.md)
+- [LLM-Based Contextual ASR](examples/contextual_asr/README.md)
 
 - [Visual Speech Recognition (VSR)](examples/vsr_LRS3/README.md)
 - Speech-to-Text Translation (S2TT)
 - [CoT-ST](examples/st_covost2/README.md)
 
 - Text-to-Speech (TTS)
 - [VALL-E-X](examples/vallex/README.md)
+- [Speech Emotion Captioning (SEC)](examples/sec_emotioncaps/README.md)
 
 - **Audio Task**
 - [Automated Audio Captioning (AAC)](examples/aac_audiocaps/README.md)
@@ -118,7 +122,10 @@ command-line (shell file) > Hydra configuration (yaml file) > dataclass configur
 - We borrow code from [Fairseq](https://github.com/facebookresearch/fairseq) for deepspeed configuration.
 - We thank the contributors for providing diverse recipes.
 
-## Citation
+# Citation
+
+## Speech Task
+
 SLAM-ASR:
 ```
 @article{ma2024embarrassingly,
@@ -128,4 +135,60 @@ SLAM-ASR:
 year={2024}
 }
 ```
+Mala-ASR:
+```
+@article{yang2024mala,
+title={MaLa-ASR: Multimedia-Assisted LLM-Based ASR},
+author={Yang, Guanrou and Ma, Ziyang and Yu, Fan and Gao, Zhifu and Zhang, Shiliang and Chen, Xie},
+journal={Proc. INTERSPEECH},
+year={2024}
+}
+```
+LLM-Based Contextual ASR:
+```
+@article{yang2024ctc,
+title={CTC-Assisted LLM-Based Contextual ASR},
+author={Yang, Guanrou and Ma, Ziyang and Gao, Zhifu and Zhang, Shiliang and Chen, Xie},
+journal={Proc. SLT},
+year={2024}
+}
+```
+CoT-ST:
+```
+@article{du2024cot,
+title={CoT-ST: Enhancing LLM-based Speech Translation with Multimodal Chain-of-Thought},
+author={Du, Yexing and Ma, Ziyang and Yang, Yifan and Deng, Keqi and Chen, Xie and Yang, Bo and Xiang, Yang and Liu, Ming and Qin, Bing},
+journal={arXiv preprint arXiv:2409.19510},
+year={2024}
+}
+```
+
 
+## Audio Task
+SLAM-AAC:
+```
+@article{chen2024slam,
+title={SLAM-AAC: Enhancing Audio Captioning with Paraphrasing Augmentation and CLAP-Refine through LLMs},
+author={Chen, Wenxi and Ma, Ziyang and Li, Xiquan and Xu, Xuenan and Liang, Yuzhe and Zheng, Zhisheng and Yu, Kai and Chen, Xie},
+journal={arXiv preprint arXiv:2410.09503},
+year={2024}
+}
+```
+DRCap:
+```
+@article{li2024drcap,
+title={DRCap: Decoding CLAP Latents with Retrieval-augmented Generation for Zero-shot Audio Captioning},
+author={Li, Xiquan and Chen, Wenxi and Ma, Ziyang and Xu, Xuenan and Liang, Yuzhe and Zheng, Zhisheng and Kong, Qiuqiang and Chen, Xie},
+journal={arXiv preprint arXiv:2410.09472},
+year={2024}
+}
+```
+BAT:
+```
+@article{zheng2024bat,
+title={BAT: Learning to Reason about Spatial Sounds with Large Language Models},
+author={Zheng, Zhisheng and Peng, Puyuan and Ma, Ziyang and Chen, Xie and Choi, Eunsol and Harwath, David},
+journal={Proc. ICML},
+year={2024}
+}
+```
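The hunk above keeps the README's configuration-priority rule: command-line (shell) overrides beat the Hydra yaml, which beats the dataclass defaults. As a rough illustration of that layering, here is a minimal sketch using OmegaConf directly with made-up field values; it is not the repo's actual entry point.

```python
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class TrainConfig:
    # Hypothetical defaults standing in for the dataclass configuration.
    lr: float = 1e-4
    num_epochs: int = 3


dataclass_cfg = OmegaConf.structured(TrainConfig)  # lowest priority
yaml_cfg = OmegaConf.create({"lr": 5e-5})          # stands in for a Hydra yaml file
cli_cfg = OmegaConf.from_dotlist(["lr=1e-5"])      # stands in for shell-script overrides

# Later arguments win, matching command-line > yaml > dataclass.
merged = OmegaConf.merge(dataclass_cfg, yaml_cfg, cli_cfg)
print(merged.lr)          # 1e-05, from the command-line override
print(merged.num_epochs)  # 3, untouched dataclass default
```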

examples/aac_audiocaps/aac_config.py

Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/aac_audiocaps/model/slam_model_aac.py:model_factory"
@@ -114,7 +118,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False

examples/asr_librispeech/asr_config.py

Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,9 @@
 from dataclasses import dataclass, field
 from typing import Optional, List
+
+from torch.distributed.fsdp import ShardingStrategy
+
+
 @dataclass
 class ModelConfig:
     file: str = "examples/asr_librispeech/model/slam_model_asr.py:model_factory"
@@ -108,7 +112,7 @@ class FSDPConfig:
     mixed_precision: bool = True
     use_fp16: bool = False
     # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
-    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
     checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool = True
     fsdp_cpu_offload: bool = False
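Both config diffs above change the `sharding_strategy` annotation from `str` to the FSDP `ShardingStrategy` enum while keeping the string default `"NO_SHARD"`. A minimal sketch of how such a value might be normalized before building the FSDP wrapper; the `resolve_sharding_strategy` helper is hypothetical and not part of the repo.

```python
from torch.distributed.fsdp import ShardingStrategy


def resolve_sharding_strategy(value) -> ShardingStrategy:
    """Accept either a ShardingStrategy member or its name (e.g. "NO_SHARD")."""
    if isinstance(value, ShardingStrategy):
        return value
    # Enum name lookup; raises KeyError for unknown strategy names.
    return ShardingStrategy[value]


print(resolve_sharding_strategy("NO_SHARD"))                   # ShardingStrategy.NO_SHARD
print(resolve_sharding_strategy(ShardingStrategy.FULL_SHARD))  # passes through unchanged
```

As the inline comment in the diff notes, `NO_SHARD` is the value to use when running plain DDP rather than sharded FSDP.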

examples/contextual_asr/README.md

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# CTC-Assisted LLM-Based Contextual ASR
+
+## Guides
+
+[CTC-Assisted LLM-Based Contextual ASR](https://arxiv.org/abs/2411.06437) is an LLM-based contextual ASR model that first uses CTC decoding results to filter potentially relevant hotwords from a pre-defined hotword list, and then incorporates them into the LLM prompt to improve hotword recognition.
+
+## Model Architecture
+
+We use the WavLM-Large model, pre-trained on 94,000 hours of data and fine-tuned on 960 hours of LibriSpeech data with CTC loss, as our speech encoder. We use the public Vicuna 7B as our large language model decoder, and a simple-structured linear projector, consisting of a 1-D convolution layer and two linear layers, as our adapter. Refer to our [paper](https://arxiv.org/pdf/2411.06437) for more details.
+
+![](docs/model.png)
+
+## Checkpoints
+We only train the linear projector in this recipe.
+Encoder | Projector | LLM
+|---|---|---|
+[CTC Fine-tuned WavLM-Large](https://drive.google.com/file/d/12ZmSSbDvx73W0eK1wpUgajapCLhqh5DI/view?usp=drive_link)(~315.45M) | [Linear](https://drive.google.com/file/d/1Zlbsnz1YUWtYtt-yNyoPK5OhR30kwLfS/view?usp=drive_link)(~15.74M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5)(~6.7B)
+
+## Performance
+![](docs/performance.png)
+
+
+## Data preparation
+The artificial biasing list constructed in [Contextualized streaming end-to-end speech recognition with trie-based deep biasing and shallow fusion](https://arxiv.org/pdf/2104.02194) is used for contextual ASR testing; refer to the official [repo](https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias).
+They categorize the 5,000 most frequent words in the LibriSpeech training corpus as common
+words, with the remainder classified as rare words. The biasing list generated for the test set consists of two parts: the rare words in the transcriptions, and distractors sampled from the 209.2K rare-word vocabulary. Biasing lists of varying lengths are generated by incorporating N = {100, 500, 1000, 2000} distractors into the lists.
+
+
+The Viterbi decoding results of our CTC fine-tuned WavLM-Large: [test-clean](https://drive.google.com/file/d/1kMzPx8oRK3aOsxNaMGski3zH8z5Otvek/view?usp=drive_link), [test-other](https://drive.google.com/file/d/12KHaatVg5O0MIBTcf8e_rNjV_i9WLBFR/view?usp=drive_link) (``ctc_file`` in contextual_asr_config.py)
+
+## Decoding with checkpoints
+LLM-based ASR inference script:
+```
+bash decode_wavlm_libri960_ft_char.sh
+```
+LLM-based contextual ASR inference script, with different biasing list sizes:
+```
+bash decode_wavlm_libri960_ft_char_hotwords.sh
+```
+
+
+## Training the model
+LLM-based ASR training script: uses the CTC fine-tuned WavLM as the encoder and "Transcribe speech to text." as the prompt.
+```
+bash finetune_wavlm_libri960_ft_char.sh
+```
+LLM-based contextual ASR training script: uses the CTC fine-tuned WavLM as the encoder and "Transcribe speech to text. Some hotwords might help. The hotwords are {}." as the prompt.
+```
+bash finetune_wavlm_libri960_ft_char_hotwords.sh
+```
+
+
+## Citation
+You can refer to the paper for more results.
+```
+@article{yang2024ctc,
+title={CTC-Assisted LLM-Based Contextual ASR},
+author={Yang, Guanrou and Ma, Ziyang and Gao, Zhifu and Zhang, Shiliang and Chen, Xie},
+journal={Proc. SLT},
+year={2024}
+}
+```
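The Guides section of the README above describes the core idea: the CTC decoding result is used to filter a large biasing list down to a handful of likely hotwords, which are then inserted into the LLM prompt. The following is a rough, hypothetical sketch of that filtering step; the function name, the fuzzy-matching scheme, and the example strings are all illustrative, so see the paper and the recipe's scripts for the actual method.

```python
import difflib


def filter_hotwords(ctc_hypothesis: str, biasing_list: list[str], word_num: int = 15) -> list[str]:
    """Keep the hotwords that best match any word in the CTC decoding result."""
    ctc_words = ctc_hypothesis.lower().split()

    def best_score(hotword: str) -> float:
        return max(
            (difflib.SequenceMatcher(None, hotword.lower(), w).ratio() for w in ctc_words),
            default=0.0,
        )

    ranked = sorted(biasing_list, key=best_score, reverse=True)
    return ranked[:word_num]


ctc_hyp = "he spoke of zorro aster and the village elders"   # toy CTC output
biasing = ["ZOROASTER", "ELINOR", "MARIANNE", "WILLOUGHBY"]  # toy biasing list
hotwords = filter_hotwords(ctc_hyp, biasing, word_num=2)
prompt = ("Transcribe speech to text. Some hotwords might help. "
          f"The hotwords are {' '.join(hotwords)}.")
print(prompt)
```

The `word_num` cap mirrors the `word_num` field in the recipe's DataConfig later in this commit, and the prompt template is the one quoted in the training section above.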
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+{
+    "train_micro_batch_size_per_gpu": 4,
+    "gradient_accumulation_steps": 1,
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 1e-4
+        }
+    },
+    "fp16": {
+        "enabled": true
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "offload_optimizer": {
+            "device": "cpu"
+        }
+    }
+}
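The 19-line JSON added above (its path is not shown in this view) is a DeepSpeed configuration: Adam at lr 1e-4, fp16, and ZeRO stage 3 with the optimizer state offloaded to CPU. A minimal sketch of handing such a file to DeepSpeed, assuming a placeholder model and a hypothetical path, and run under the deepspeed launcher rather than invoked directly; the recipe's own shell scripts drive the real training loop.

```python
import torch
import deepspeed

# Placeholder model; the recipe builds its own encoder/projector/LLM stack.
model = torch.nn.Linear(16, 4)

# "ds_config.json" is a hypothetical path holding the JSON shown in the diff above.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",
)
```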
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+dataset_config:
+    # we put the prompt here because the hydra override in the shell script only supports a small subset of characters
+    # prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
+    prompt: "Transcribe speech to text. "
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+from dataclasses import dataclass, field
+from typing import Optional, List
+@dataclass
+class ModelConfig:
+    file: str = "examples/contextual_asr/model/slam_model_contextual_asr.py:model_factory"
+    llm_name: str = "vicuna-13b-v1.5"
+    llm_path: str = "PATH/to/LLAMA/7B"
+    llm_type: str = "decoder_only"
+    llm_dim: int = 4096
+    encoder_name: Optional[str] = None
+    encoder_ds_rate: int = 2
+    encoder_path: Optional[str] = None
+    encoder_dim: int = 1280
+    encoder_projector: str = "linear"
+    encoder_projector_ds_rate: int = 5
+    modal: str = "audio"
+    normalize: Optional[bool] = field(default=False, metadata={
+        "help": "whether input is normalized, used for models such as wavlm"
+    })
+    encoder_type: str = field(default="finetune", metadata={
+        "help": "whether model is only pretrained or finetuned, used for models such as hubert"
+    })
+
+@dataclass
+class PeftConfig:
+    peft_method: str = "lora" # None , llama_adapter, prefix
+    r: int = 8
+    lora_alpha: int = 32
+    # target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ])
+    target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj","k_proj","o_proj" ])
+    bias: str = "none"
+    task_type: str = "CAUSAL_LM"
+    lora_dropout: float = 0.05
+    inference_mode: bool = False
+
+@dataclass
+class TrainConfig:
+    model_name:str = "PATH/to/LLAMA/7B"
+    enable_ddp:bool = False
+    enable_deepspeed:bool = False
+    enable_fsdp:bool = False
+    low_cpu_fsdp:bool = False
+    run_validation:bool = True
+    batch_size_training:int = 4
+    batching_strategy:str = field(default="packing", metadata={
+        "help":"alternative: padding"
+    })
+    context_length:int = 4096
+    gradient_accumulation_steps:int = 1
+    num_epochs:int = 3
+    num_workers_dataloader:int = 1
+    warmup_steps:int = 1000
+    total_steps:int = 100000
+    validation_interval:int = 1000
+    lr:float = 1e-4
+    weight_decay:float = 0.0
+    gamma:float = 0.85
+    seed:int = 42
+    use_fp16:bool = False
+    mixed_precision:bool = True
+    val_batch_size:int = 1
+    use_peft:bool = False
+    peft_config:PeftConfig = field(default_factory=PeftConfig)
+    output_dir:str = "PATH/to/save/PEFT/model"
+    freeze_layers:bool = False
+    num_freeze_layers:int = 1
+    quantization:bool = False
+    one_gpu:bool = False
+    save_model:bool = True
+    dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP
+    dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP
+    save_optimizer:bool = False # will be used if using FSDP
+    use_fast_kernels:bool = False # Enable using SDPA from PyTorch Accelerated Transformers, to make use of Flash Attention and Xformers memory-efficient kernels
+    run_test_during_validation:bool = False
+    run_test_during_validation_file:str = "test.wav"
+    run_test_during_validation_prompt:str = "<|ASR|>"
+    freeze_llm:bool = field(default=False, metadata={
+        "help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
+    })
+    freeze_encoder:bool = False
+
+@dataclass
+class DataConfig:
+    dataset: str = "speech_dataset"
+    file: str = "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
+    train_data_path: Optional[str] = None
+    val_data_path: Optional[str] = None
+    train_split: str = "train"
+    test_split:str = "validation"
+    prompt: Optional[str] = None
+    data_path: Optional[str] = None
+    max_words: Optional[int] = None
+    max_mel: Optional[float] = None
+    fix_length_audio: int = -1
+    inference_mode:bool = False
+    input_type: str = field(default="raw", metadata={
+        "help":"Use raw when input is wav, mel when for whisper"
+    })
+    mel_size: int = field(default=80, metadata={
+        "help": "80 for whisper large v1 and v2, 128 for v3"
+    })
+    normalize: Optional[bool] = field(default=False, metadata={
+        "help": "whether input is normalized, used for models such as wavlm"
+    })
+    infer_type: str = "bias"
+    infer_file: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/test-clean.biasing_100.tsv"
+    ctc_file: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_large_libri_test_other_char.txt"
+    common_words_5k_dir: str="/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/words/common_words_5k.txt"
+    probability_threshold: float = 0.9
+    word_num: int = 15
+    filter_infer_sentence: bool = False
+    filter_infer_sentence_few: bool = False
+    first: int = 1
+
+@dataclass
+class FSDPConfig:
+    mixed_precision: bool = True
+    use_fp16: bool = False
+    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
+    sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
+    fsdp_activation_checkpointing: bool = True
+    fsdp_cpu_offload: bool = False
+    pure_bf16: bool = False
+    optimizer: str = "AdamW"
+
+@dataclass
+class LogConfig:
+    use_wandb: bool = False
+    wandb_dir: str = "/root/test_wandb"
+    wandb_entity_name: str = "project_name"
+    wandb_project_name: str = "project_name"
+    wandb_exp_name: str = "exp_name"
+    log_file: str = "/root/test.log"
+    log_interval: int = 5
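The dataclasses above follow the same pattern as the other recipes' config files: defaults live in the dataclass, and individual fields are overridden per experiment. A small, hypothetical usage sketch; the module name `contextual_asr_config` and the chosen values are illustrative only, since the real scripts drive this through Hydra.

```python
from contextual_asr_config import DataConfig, ModelConfig, TrainConfig

model_cfg = ModelConfig(encoder_name="wavlm", encoder_dim=1024)  # illustrative overrides
data_cfg = DataConfig(
    inference_mode=True,
    infer_type="bias",           # use the biasing-list inference path
    word_num=15,                 # max hotwords kept per utterance
    probability_threshold=0.9,   # CTC-based filtering threshold
)
train_cfg = TrainConfig(freeze_encoder=True, freeze_llm=True, use_peft=False)

print(data_cfg.infer_type, data_cfg.word_num, train_cfg.freeze_llm)
```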
