Merged
44 commits
dfab661
Add examples/sec_emotioncaps
xrysamuel Oct 1, 2024
d2d5fe7
Merge branch 'X-LANCE:main' into main
xrysamuel Oct 2, 2024
832bf02
fix #92 for fsdp training
Oct 14, 2024
0045773
Merge pull request #153 from nuaalixu/main
ddlBoJack Oct 15, 2024
84bf800
Merge branch 'X-LANCE:main' into main
xrysamuel Oct 16, 2024
f37fe73
Add README to examples/sec_emotioncaps
xrysamuel Oct 16, 2024
902f33e
test
cwx-worst-one Oct 18, 2024
a1d8052
Merge pull request #159 from X-LANCE/main
cwx-worst-one Oct 18, 2024
1738c01
sloth
cwx-worst-one Oct 18, 2024
9ea7716
sloth
cwx-worst-one Oct 18, 2024
f5505fd
test
cwx-worst-one Oct 18, 2024
8373500
test
cwx-worst-one Oct 18, 2024
6fb784b
Merge pull request #160 from X-LANCE/cwx_slam_aac
ddlBoJack Oct 27, 2024
c16e3a9
fix model name, README and fsdp config for examples/sec_emotioncaps
xrysamuel Nov 5, 2024
dbfcfca
Merge pull request #156 from xrysamuel/main
ddlBoJack Nov 5, 2024
378fb87
improve instruction of data preparation for Mala-asr
Nov 8, 2024
3a729b9
Merge pull request #168 from X-LANCE/ygr_pr2
ddlBoJack Nov 8, 2024
41d6185
update README
ddlBoJack Nov 8, 2024
8ab989d
Revert "update README"
ddlBoJack Nov 8, 2024
68c5fb9
update README
ddlBoJack Nov 8, 2024
dc5d254
add example data and modify readme for DRCap
xiquan-li Nov 8, 2024
f605258
Modify data path
xiquan-li Nov 8, 2024
09a4bf7
update readme
xiquan-li Nov 8, 2024
135f69b
minor changes
xiquan-li Nov 8, 2024
819ce1f
add some files
Nov 9, 2024
f32b8a2
Merge pull request #170 from Andreas-Xi/lxq-drcap
ddlBoJack Nov 9, 2024
90bf0ec
Merge branch 'main' of github.com:ddlBoJack/SLAM-LLM into ygr_pr2
Nov 17, 2024
7d2c2d7
for ctc-assisted llm-basd CASR pr
Nov 17, 2024
a9bd1fe
fix pr
Nov 17, 2024
7cb95ce
fix
Nov 17, 2024
86dc310
save
Nov 17, 2024
52fab27
f
Nov 17, 2024
80cc33f
Merge pull request #173 from X-LANCE/ygr_pr2
ddlBoJack Nov 17, 2024
f42716d
update README
ddlBoJack Nov 17, 2024
63fa976
test
yxduir Nov 24, 2024
33b84ed
test
yxduir Nov 24, 2024
6c26585
Merge pull request #176 from X-LANCE/yxdu
ddlBoJack Nov 24, 2024
7090a46
test
yxduir Nov 27, 2024
781c131
test
yxduir Nov 27, 2024
f887fc7
test
yxduir Nov 27, 2024
8d2dc88
Merge pull request #179 from X-LANCE/yxdu
ddlBoJack Nov 27, 2024
85d4b0b
upload ctc_file and remove irrelavant codes
Nov 29, 2024
12539e4
fix
yanghaha0908 Nov 30, 2024
87e1449
Merge pull request #181 from X-LANCE/ygr_pr2
ddlBoJack Nov 30, 2024
69 changes: 66 additions & 3 deletions README.md
@@ -20,15 +20,17 @@ developers to train custom multimodal large language model (MLLM), focusing on
# Table of Contents
1. [News](#news)
2. [Installation](#installation)
3. [Uasge](#uasge)
3. [Usage](#usage)
- [List of Recipes](#list-of-recipes)
- [Configuration Priority](#configuration-priority)
4. [Features](#features)
5. [Acknowledge](#acknowledge)
6. [Citation](#citation)

# News
- [Update Oct. 12, 2024] Recipes for [SLAM-AAC](examples/slam_aac/README.md) have been supported.
- [Update Nov. 17, 2024] Recipes for [LLM-Based Contextual ASR](examples/contextual_asr/README.md) have been supported.
- [Update Nov. 5, 2024] Recipes for [speech emotion captioning (SEC)](examples/sec_emotioncaps/README.md) with [emotion2vec](https://github.com/ddlBoJack/emotion2vec) as the encoder have been supported.
- [Update Oct. 12, 2024] Recipes for [SLAM-AAC](examples/slam_aac/README.md) with [EAT](https://github.com/cwx-worst-one/EAT) as the encoder have been supported.
- [Update Sep. 28, 2024] Recipes for [CoT-ST](examples/st_covost2/README.md) have been supported.
- [Update Sep. 25, 2024] Recipes for [DRCap](examples/drcap_zeroshot_aac/README.md) have been supported.
- [Update Jun. 12, 2024] Recipes for [MaLa-ASR](examples/mala_asr_slidespeech/README.md) have been supported.
@@ -83,13 +85,15 @@ We provide reference implementations of various LLM-based speech, audio, and mus

- Contextual Automatic Speech Recognition (CASR)
- [MaLa-ASR](examples/mala_asr_slidespeech/README.md)
- [LLM-Based Contextual ASR](examples/contextual_asr/README.md)

- [Visual Speech Recognition (VSR)](examples/vsr_LRS3/README.md)
- Speech-to-Text Translation (S2TT)
- [CoT-ST](examples/st_covost2/README.md)

- Text-to-Speech (TTS)
- [VALL-E-X](examples/vallex/README.md)
- [Speech Emotion Captioning (SEC)](examples/sec_emotioncaps/README.md)

- **Audio Task**
- [Automated Audio Captioning (AAC)](examples/aac_audiocaps/README.md)
@@ -118,7 +122,10 @@ command-line (shell file) > Hydra configuration (yaml file) > dataclass configuration (python file)
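
To make this priority concrete, here is a minimal, hypothetical sketch (not from the repository) of how one field is resolved when a dataclass default, a YAML file, and a command-line override all set it, assuming standard Hydra/OmegaConf merge semantics:

```python
# Illustrative only: the config class and values below are made up.
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class TrainConfig:
    lr: float = 1e-4       # dataclass default: lowest priority
    batch_size: int = 4


# A YAML configuration overrides the dataclass default ...
yaml_cfg = OmegaConf.create({"lr": 5e-5})
# ... and a command-line override (dotlist) wins over both.
cli_cfg = OmegaConf.from_dotlist(["lr=1e-5"])

# Later arguments take precedence in OmegaConf.merge.
cfg = OmegaConf.merge(OmegaConf.structured(TrainConfig), yaml_cfg, cli_cfg)
print(cfg.lr)          # the command-line value wins
print(cfg.batch_size)  # untouched dataclass default
```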
- We borrow code from [Fairseq](https://github.com/facebookresearch/fairseq) for deepspeed configuration.
- We thank the contributors for providing diverse recipes.

## Citation
# Citation

## Speech Task

SLAM-ASR:
```
@article{ma2024embarrassingly,
@@ -128,4 +135,60 @@ SLAM-ASR:
year={2024}
}
```
Mala-ASR:
```
@article{yang2024mala,
title={MaLa-ASR: Multimedia-Assisted LLM-Based ASR},
author={Yang, Guanrou and Ma, Ziyang and Yu, Fan and Gao, Zhifu and Zhang, Shiliang and Chen, Xie},
journal={Proc. INTERSPEECH},
year={2024}
}
```
LLM-Based Contextual ASR:
```
@article{yang2024ctc,
title={CTC-Assisted LLM-Based Contextual ASR},
author={Yang, Guanrou and Ma, Ziyang and Gao, Zhifu and Zhang, Shiliang and Chen, Xie},
journal={Proc. SLT},
year={2024}
}
```
CoT-ST:
```
@article{du2024cot,
title={CoT-ST: Enhancing LLM-based Speech Translation with Multimodal Chain-of-Thought},
author={Du, Yexing and Ma, Ziyang and Yang, Yifan and Deng, Keqi and Chen, Xie and Yang, Bo and Xiang, Yang and Liu, Ming and Qin, Bing},
journal={arXiv preprint arXiv:2409.19510},
year={2024}
}
```


## Audio Task
SLAM-AAC:
```
@article{chen2024slam,
title={SLAM-AAC: Enhancing Audio Captioning with Paraphrasing Augmentation and CLAP-Refine through LLMs},
author={Chen, Wenxi and Ma, Ziyang and Li, Xiquan and Xu, Xuenan and Liang, Yuzhe and Zheng, Zhisheng and Yu, Kai and Chen, Xie},
journal={arXiv preprint arXiv:2410.09503},
year={2024}
}
```
DRCap:
```
@article{li2024drcap,
title={DRCap: Decoding CLAP Latents with Retrieval-augmented Generation for Zero-shot Audio Captioning},
author={Li, Xiquan and Chen, Wenxi and Ma, Ziyang and Xu, Xuenan and Liang, Yuzhe and Zheng, Zhisheng and Kong, Qiuqiang and Chen, Xie},
journal={arXiv preprint arXiv:2410.09472},
year={2024}
}
```
BAT:
```
@article{zheng2024bat,
title={BAT: Learning to Reason about Spatial Sounds with Large Language Models},
author={Zheng, Zhisheng and Peng, Puyuan and Ma, Ziyang and Chen, Xie and Choi, Eunsol and Harwath, David},
journal={Proc. ICML},
year={2024}
}
```
6 changes: 5 additions & 1 deletion examples/aac_audiocaps/aac_config.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, List

from torch.distributed.fsdp import ShardingStrategy


@dataclass
class ModelConfig:
file: str = "examples/aac_audiocaps/model/slam_model_aac.py:model_factory"
@@ -114,7 +118,7 @@ class FSDPConfig:
mixed_precision: bool = True
use_fp16: bool = False
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
fsdp_activation_checkpointing: bool = True
fsdp_cpu_offload: bool = False
6 changes: 5 additions & 1 deletion examples/asr_librispeech/asr_config.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, List

from torch.distributed.fsdp import ShardingStrategy


@dataclass
class ModelConfig:
file: str = "examples/asr_librispeech/model/slam_model_asr.py:model_factory"
@@ -108,7 +112,7 @@ class FSDPConfig:
mixed_precision: bool = True
use_fp16: bool = False
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
fsdp_activation_checkpointing: bool = True
fsdp_cpu_offload: bool = False
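For context on the `sharding_strategy` change in the two config diffs above, here is a minimal sketch, assuming standard PyTorch FSDP, of how a strategy name such as "NO_SHARD" can be resolved to the `ShardingStrategy` enum that `FullyShardedDataParallel` expects; the helper is illustrative, not the repository's code:

```python
# Illustrative only: map a config string onto the torch ShardingStrategy enum,
# which is what FullyShardedDataParallel takes via its sharding_strategy kwarg.
from torch.distributed.fsdp import ShardingStrategy


def resolve_sharding_strategy(name: str) -> ShardingStrategy:
    # ShardingStrategy is a plain Enum, so lookup by member name works.
    return ShardingStrategy[name]


print(resolve_sharding_strategy("NO_SHARD"))    # ShardingStrategy.NO_SHARD
print(resolve_sharding_strategy("FULL_SHARD"))  # ShardingStrategy.FULL_SHARD
```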
62 changes: 62 additions & 0 deletions examples/contextual_asr/README.md
@@ -0,0 +1,62 @@
# CTC-Assisted LLM-Based Contextual ASR

## Guides

[CTC-Assisted LLM-Based Contextual ASR](https://arxiv.org/abs/2411.06437) is an LLM-based contextual ASR model that first uses CTC decoding results to filter potentially relevant hotwords from a pre-defined hotword list, and then incorporates them into the LLM prompt to improve recognition of the hotwords.
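
The sketch below illustrates this filtering idea only; the matching rule, threshold, and prompt template wiring are assumptions, not the repository's implementation:

```python
# Illustrative hotword filtering: keep pre-defined hotwords that roughly match
# some word in the CTC decoding result, then format them into the LLM prompt.
# The edit-distance criterion and threshold are assumptions for illustration.
from difflib import SequenceMatcher


def filter_hotwords(ctc_hypothesis: str, hotword_list: list[str],
                    threshold: float = 0.8) -> list[str]:
    ctc_words = set(ctc_hypothesis.lower().split())
    selected = []
    for hotword in hotword_list:
        score = max((SequenceMatcher(None, hotword.lower(), w).ratio()
                     for w in ctc_words), default=0.0)
        if score >= threshold:
            selected.append(hotword)
    return selected


def build_prompt(selected: list[str]) -> str:
    if not selected:
        return "Transcribe speech to text. "
    return ("Transcribe speech to text. Some hotwords might help. "
            f"The hotwords are {' '.join(selected)}. ")


hyp = "he met his frend at the station"
print(build_prompt(filter_hotwords(hyp, ["FRIEND", "STATION", "HOLMES"])))
```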

## Model Architecture

We use the WavLM-Large model, pre-trained on 94,000 hours of data and fine-tuned on 960 hours of LibriSpeech data with CTC loss, as our speech encoder. We use the public Vicuna 7B as our large language model decoder, and a simple linear projector, consisting of a 1-D convolution layer and two linear layers, as our adapter. Refer to our [paper](https://arxiv.org/pdf/2411.06437) for more details.

![](docs/model.png)
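
Below is a minimal PyTorch sketch of a projector with this shape (a 1-D convolution for temporal downsampling followed by two linear layers); the hidden size, downsampling rate, and encoder dimension are assumptions for illustration, not the exact module from the repository:

```python
# Illustrative projector: Conv1d (temporal downsampling) + two linear layers
# mapping encoder features to the LLM embedding dimension. All dimensions and
# the downsampling rate are assumptions.
import torch
import torch.nn as nn


class LinearProjector(nn.Module):
    def __init__(self, encoder_dim: int = 1024, llm_dim: int = 4096,
                 ds_rate: int = 5, hidden_dim: int = 2048):
        super().__init__()
        # stride == kernel_size downsamples the time axis by ds_rate.
        self.conv = nn.Conv1d(encoder_dim, encoder_dim,
                              kernel_size=ds_rate, stride=ds_rate)
        self.fc1 = nn.Linear(encoder_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, llm_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, encoder_dim) -> (batch, time // ds_rate, llm_dim)
        x = self.conv(x.transpose(1, 2)).transpose(1, 2)
        return self.fc2(torch.relu(self.fc1(x)))


features = torch.randn(2, 100, 1024)      # e.g. frame-level encoder features
print(LinearProjector()(features).shape)  # torch.Size([2, 20, 4096])
```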

## Checkpoints
We only train the linear projector in this recipe.
| Encoder | Projector | LLM |
|---|---|---|
| [CTC Fine-tuned WavLM-Large](https://drive.google.com/file/d/12ZmSSbDvx73W0eK1wpUgajapCLhqh5DI/view?usp=drive_link) (~315.45M) | [Linear](https://drive.google.com/file/d/1Zlbsnz1YUWtYtt-yNyoPK5OhR30kwLfS/view?usp=drive_link) (~15.74M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) (~6.7B) |

## Performance
![](docs/performance.png)


## Data preparation
The artificial biasing lists constructed in [Contextualized streaming end-to-end speech recognition with trie-based deep biasing and shallow fusion](https://arxiv.org/pdf/2104.02194) are used for contextual ASR testing; refer to the official [repo](https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias). The 5,000 most frequent words in the LibriSpeech training corpus are categorized as common words, with the remainder classified as rare words. The biasing list generated for the test set consists of two segments: rare words in the transcriptions, and distractors sampled from the 209.2K-word rare vocabulary. Biasing lists of varying lengths are generated by incorporating N = {100, 500, 1000, 2000} distractors into the lists.


The Viterbi decoding results of our CTC fine-tuned WavLM-Large are provided for [test-clean](https://drive.google.com/file/d/1kMzPx8oRK3aOsxNaMGski3zH8z5Otvek/view?usp=drive_link) and [test-other](https://drive.google.com/file/d/12KHaatVg5O0MIBTcf8e_rNjV_i9WLBFR/view?usp=drive_link) (set as ``ctc_file`` in contextual_asr_config.py).
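
As an illustration of the scheme described above, here is a minimal sketch of building a per-utterance biasing list (rare words in the reference plus N sampled distractors); the data structures and names are assumptions, not the official is21_deep_bias tooling:

```python
# Illustrative biasing-list construction: rare words that appear in the
# transcript plus N distractors sampled from the rare-word vocabulary.
import random


def build_biasing_list(transcript: str, common_words: set[str],
                       rare_vocab: list[str], num_distractors: int = 100,
                       seed: int = 0) -> list[str]:
    words = transcript.upper().split()
    rare_in_utt = [w for w in words if w not in common_words]
    rng = random.Random(seed)
    distractors = rng.sample(
        [w for w in rare_vocab if w not in rare_in_utt], num_distractors)
    return rare_in_utt + distractors


common = {"THE", "CAT", "SAT", "ON", "MAT"}
rare_vocabulary = [f"RAREWORD{i}" for i in range(1000)]
print(build_biasing_list("the cat sat on the xylophone",
                         common, rare_vocabulary, num_distractors=5))
```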

## Decoding with checkpoints
LLM-based ASR inference script:
```
bash decode_wavlm_libri960_ft_char.sh
```
LLM-based contextual ASR inference script, with different biasing list sizes:
```
bash decode_wavlm_libri960_ft_char_hotwords.sh
```


## Training the model
LLM-based ASR training script, using the CTC fine-tuned WavLM as the encoder and "Transcribe speech to text." as the prompt:
```
bash finetune_wavlm_libri960_ft_char.sh
```
LLM-based contextual ASR training script, using the CTC fine-tuned WavLM as the encoder and "Transcribe speech to text. Some hotwords might help. The hotwords are {}." as the prompt:
```
bash finetune_wavlm_libri960_ft_char_hotwords.sh
```


## Citation
You can refer to the paper for more results.
```
@article{yang2024ctc,
title={CTC-Assisted LLM-Based Contextual ASR},
author={Yang, Guanrou and Ma, Ziyang and Gao, Zhifu and Zhang, Shiliang and Chen, Xie},
journal={Proc. SLT},
year={2024}
}
```
19 changes: 19 additions & 0 deletions examples/contextual_asr/conf/ds_config.json
@@ -0,0 +1,19 @@
{
"train_micro_batch_size_per_gpu": 4,
"gradient_accumulation_steps": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-4
}
},
"fp16": {
"enabled": true
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu"
}
}
}
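As a usage note, here is a minimal sketch of how a JSON config such as the one above is typically handed to `deepspeed.initialize`; the toy model is a placeholder, and a real run is launched with the `deepspeed` launcher rather than plain `python`:

```python
# Minimal sketch: pass the JSON config above to deepspeed.initialize, which
# returns the wrapped engine and its optimizer. The toy model is a placeholder.
import deepspeed
import torch.nn as nn

model = nn.Linear(16, 16)

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="examples/contextual_asr/conf/ds_config.json",
)
```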
4 changes: 4 additions & 0 deletions examples/contextual_asr/conf/prompt.yaml
@@ -0,0 +1,4 @@
dataset_config:
# we put the prompt here because the hydra override in the shell script only supports a small subset of characters
# prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
prompt: "Transcribe speech to text. "
135 changes: 135 additions & 0 deletions examples/contextual_asr/contextual_asr_config.py
@@ -0,0 +1,135 @@
from dataclasses import dataclass, field
from typing import Optional, List
@dataclass
class ModelConfig:
file: str = "examples/contextual_asr/model/slam_model_contextual_asr.py:model_factory"
llm_name: str = "vicuna-13b-v1.5"
llm_path: str = "PATH/to/LLAMA/7B"
llm_type: str = "decoder_only"
llm_dim: int = 4096
encoder_name: Optional[str] = None
encoder_ds_rate: int = 2
encoder_path: Optional[str] = None
encoder_dim: int = 1280
encoder_projector: str = "linear"
encoder_projector_ds_rate: int = 5
modal: str = "audio"
normalize: Optional[bool] = field(default=False, metadata={
"help": "whether input is normalized, used for models such as wavlm"
})
encoder_type: str = field(default="finetune", metadata={
"help": "whether model is only pretrained or finetuned, used for models such as hubert"
})

@dataclass
class PeftConfig:
peft_method: str = "lora" # None , llama_adapter, prefix
r: int = 8
lora_alpha: int = 32
# target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ])
target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj","k_proj","o_proj" ])
bias: str = "none"
task_type: str = "CAUSAL_LM"
lora_dropout: float = 0.05
inference_mode: bool = False

@dataclass
class TrainConfig:
model_name:str = "PATH/to/LLAMA/7B"
enable_ddp:bool = False
enable_deepspeed:bool = False
enable_fsdp:bool = False
low_cpu_fsdp:bool = False
run_validation:bool = True
batch_size_training:int = 4
batching_strategy:str = field(default="packing", metadata={
"help":"alternative: padding"
})
context_length:int = 4096
gradient_accumulation_steps:int = 1
num_epochs:int = 3
num_workers_dataloader:int = 1
warmup_steps:int = 1000
total_steps:int = 100000
validation_interval:int = 1000
lr:float = 1e-4
weight_decay:float = 0.0
gamma:float = 0.85
seed:int = 42
use_fp16:bool = False
mixed_precision:bool = True
val_batch_size:int = 1
use_peft:bool = False
peft_config:PeftConfig = field(default_factory=PeftConfig)
output_dir:str = "PATH/to/save/PEFT/model"
freeze_layers:bool = False
num_freeze_layers:int = 1
quantization:bool = False
one_gpu:bool = False
save_model:bool = True
dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP
dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP
save_optimizer:bool = False # will be used if using FSDP
use_fast_kernels:bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and xFormers memory-efficient kernels
run_test_during_validation:bool = False
run_test_during_validation_file:str = "test.wav"
run_test_during_validation_prompt:str = "<|ASR|>"
freeze_llm:bool = field(default=False, metadata={
"help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
})
freeze_encoder:bool = False

@dataclass
class DataConfig:
dataset: str = "speech_dataset"
file: str = "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
train_data_path: Optional[str] = None
val_data_path: Optional[str] = None
train_split: str = "train"
test_split:str = "validation"
prompt: Optional[str] = None
data_path: Optional[str] = None
max_words: Optional[int] = None
max_mel: Optional[float] = None
fix_length_audio: int = -1
inference_mode:bool = False
input_type: str = field(default="raw", metadata={
"help":"Use raw when input is wav, mel when for whisper"
})
mel_size: int = field(default=80, metadata={
"help": "80 for whisper large v1 and v2, 128 for v3"
})
normalize: Optional[bool] = field(default=False, metadata={
"help": "whether input is normalized, used for models such as wavlm"
})
infer_type: str = "bias"
infer_file: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/test-clean.biasing_100.tsv"
ctc_file: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_large_libri_test_other_char.txt"
common_words_5k_dir: str="/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/words/common_words_5k.txt"
probability_threshold: float = 0.9
word_num: int = 15
filter_infer_sentence: bool = False
filter_infer_sentence_few: bool = False
first: int = 1

@dataclass
class FSDPConfig:
mixed_precision: bool = True
use_fp16: bool = False
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
fsdp_activation_checkpointing: bool = True
fsdp_cpu_offload: bool = False
pure_bf16: bool = False
optimizer: str = "AdamW"

@dataclass
class LogConfig:
use_wandb: bool = False
wandb_dir: str = "/root/test_wandb"
wandb_entity_name: str = "project_name"
wandb_project_name: str = "project_name"
wandb_exp_name: str = "exp_name"
log_file: str = "/root/test.log"
log_interval: int = 5
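
The `PeftConfig` defaults above correspond to Hugging Face PEFT's LoRA options; below is an illustrative sketch of the matching `peft.LoraConfig`, with the actual wiring inside the repository left as an assumption:

```python
# Illustrative: a peft.LoraConfig built from the same values as the PeftConfig
# dataclass above; how the repository applies it to the LLM is not shown in
# this diff, so treat the wiring as an assumption.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    inference_mode=False,
)

# A causal LM (e.g. vicuna-7b-v1.5 loaded with transformers) would then be
# wrapped with: llm = get_peft_model(llm, lora_config)
```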