
Commit 2ab898e

Merge pull request #147 from X-LANCE/cwx_slam_aac
SLAM-AAC open-source
2 parents d599ce4 + b8dbc12 commit 2ab898e

18 files changed: +1479, -2 lines

examples/slam_aac/README.md

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# SLAM-AAC

SLAM-AAC is an LLM-based model for the Automated Audio Captioning (AAC) task. Inspired by techniques in machine translation and ASR, the model enhances audio captioning by incorporating paraphrasing augmentation and a plug-and-play CLAP-Refine strategy.
<!-- For more details, please refer to the [paper](). -->

## Model Architecture
SLAM-AAC uses EAT as the audio encoder and Vicuna-7B as the LLM decoder. During training, only the Linear Projector and LoRA modules are trainable. For inference, multiple candidates are generated with different beam sizes and then refined using the CLAP-Refine strategy.

![](./docs/model.png)

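To make the data flow concrete, here is a minimal, illustrative sketch of the projector idea (not the repository's exact module): frame-level features from the audio encoder are downsampled by stacking consecutive frames (cf. `encoder_projector_ds_rate: 5` in `aac_config.py`) and linearly projected to the LLM embedding dimension. The 768-dim encoder output and all shapes here are assumptions for illustration only.

```python
import torch
import torch.nn as nn

class LinearProjector(nn.Module):
    """Toy stand-in for the trainable projector: stack every `ds_rate`
    encoder frames and map them into the LLM embedding space."""
    def __init__(self, encoder_dim: int = 768, llm_dim: int = 4096, ds_rate: int = 5):
        super().__init__()
        self.ds_rate = ds_rate
        self.proj = nn.Linear(encoder_dim * ds_rate, llm_dim)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        b, t, d = feats.shape
        t = t - t % self.ds_rate                       # drop leftover frames
        feats = feats[:, :t, :].reshape(b, t // self.ds_rate, d * self.ds_rate)
        return self.proj(feats)

audio_feats = torch.randn(1, 500, 768)                 # pretend EAT output
audio_embeds = LinearProjector()(audio_feats)          # -> (1, 100, 4096)
print(audio_embeds.shape)
# These audio embeddings would be concatenated with the prompt embeddings and
# fed to the frozen LLM, whose LoRA adapters are the other trainable component.
```
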
## Performance and checkpoints
We have released the pre-trained checkpoint of SLAM-AAC, as well as the checkpoints fine-tuned on the Clotho and AudioCaps datasets. The provided checkpoints include the model's Linear Projector and LoRA modules. When using each component, be sure to set up the corresponding environment according to the instructions in the respective repository (e.g., for [EAT](https://github.com/cwx-worst-one/EAT)).

### Pre-training
SLAM-AAC was pre-trained on a combination of the AudioCaps, Clotho, WavCaps, and MACS datasets. For more information on these datasets, refer to [this repository](https://github.com/Labbeti/aac-datasets). Additionally, the Clotho dataset was augmented using a back-translation-based paraphrasing technique.

|Audio Encoder | LLM | Checkpoint | Pre-training Dataset|
|:---:|:---:|:---:|:---:|
|[EAT-base (fine-tuned)](https://drive.google.com/file/d/1aCYiQmoZv_Gh1FxnR-CCWpNAp6DIJzn6/view?usp=sharing) |[vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [link](https://drive.google.com/drive/folders/10kOjB112AeGYA_0mIUr8f1-i5rSg08_O?usp=sharing) | AudioCaps, Clotho, WavCaps, MACS |

### Fine-tuning
We fine-tuned the pre-trained model on the Clotho and AudioCaps datasets, respectively. The final evaluation was conducted using audio captions generated with the CLAP-Refine decoding strategy.

|Dataset | Audio Encoder | LLM | Checkpoint | METEOR | CIDEr | SPICE | SPIDEr | SPIDEr-FL | FENSE |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| Clotho | [EAT-base (fine-tuned)](https://drive.google.com/file/d/1aCYiQmoZv_Gh1FxnR-CCWpNAp6DIJzn6/view?usp=sharing) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [link](https://drive.google.com/drive/folders/1QX7CM9YAddPi02_NRChI5mzsNmBBtA63?usp=sharing) | 19.7 | 51.5 | 14.8 | 33.2 | 33.0 | 54.0 |
| AudioCaps | [EAT-base (fine-tuned)](https://drive.google.com/file/d/1aCYiQmoZv_Gh1FxnR-CCWpNAp6DIJzn6/view?usp=sharing) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [link](https://drive.google.com/drive/folders/1GhFPiSVmBE9BvBhYWCEqkFuH-avKl-4g?usp=sharing) | 26.8 | 84.1 | 19.4 | 51.8 | 51.5 | 66.8 |

## Data preparation
Ensure your `jsonl` data follows the structure outlined below:
```json
{"key": "Y7fmOlUlwoNg_1", "source": "/root/data/AudioCaps/waveforms/test/Y7fmOlUlwoNg.wav", "target": "Constant rattling noise and sharp vibrations"}
{"key": "Y6BJ455B1aAs_1", "source": "/root/data/AudioCaps/waveforms/test/Y6BJ455B1aAs.wav", "target": "A rocket flies by followed by a loud explosion and fire crackling as a truck engine runs idle"}
```
In addition, you can refer to the [manifest](https://drive.google.com/drive/folders/1NJinoWg3yXKSPm-pRrhqKLvCD9dtDuDG?usp=sharing) file we've provided, which includes the Clotho dataset enhanced with **paraphrasing augmentation** as a bonus.

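If you need to build such a manifest yourself, the following is a minimal sketch; the sample list, key format, and output filename (`my_manifest.jsonl`) are illustrative, and any unique `key` per entry works.

```python
import json
from pathlib import Path

# Illustrative (wav path, caption) pairs -- replace with your own data listing.
samples = [
    ("/root/data/AudioCaps/waveforms/test/Y7fmOlUlwoNg.wav",
     "Constant rattling noise and sharp vibrations"),
]

with open("my_manifest.jsonl", "w") as f:
    for idx, (wav, caption) in enumerate(samples, start=1):
        entry = {
            "key": f"{Path(wav).stem}_{idx}",  # unique id for the sample
            "source": wav,                     # path to the audio file
            "target": caption,                 # reference caption
        }
        f.write(json.dumps(entry) + "\n")
```
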
## Model Training
To pre-train the SLAM-AAC model on the pre-training data, run the following command:
```bash
# Pre-train the model
bash scripts/pretrain.sh
```

You can fine-tune the model on the AudioCaps or Clotho datasets using the [provided checkpoint](https://drive.google.com/drive/folders/10kOjB112AeGYA_0mIUr8f1-i5rSg08_O?usp=sharing) or your own pre-trained model by running the following commands:

```bash
# Fine-tune on AudioCaps
bash scripts/finetune_audiocaps.sh

# Fine-tune on Clotho
bash scripts/finetune_clotho.sh
```

You can also fine-tune the model without loading any pre-trained weights, though this may result in reduced performance.

### Note
- In the current version of SLAM-LLM, the `peft_ckpt` parameter is no longer required. However, if you are using the checkpoint we provide, which was trained with an earlier version, please keep the `peft_ckpt` parameter in your configuration to ensure compatibility.
- Due to differences in dependency versions, there may be slight variations in the performance of the SLAM-AAC model.

## Inference
To perform inference with the trained models, you can use the following commands to decode with the standard beam search method:
```bash
# Inference on AudioCaps (Beam Search)
bash scripts/inference_audiocaps_bs.sh

# Inference on Clotho (Beam Search)
bash scripts/inference_clotho_bs.sh
```

For improved results, you can use the CLAP-Refine strategy, which reranks candidates generated with multiple beam sizes. To use this method, you need to download our pre-trained [CLAP](https://drive.google.com/drive/folders/1X4NYE08N-kbOy6s_Itb0wBR_3X8oZF56?usp=sharing) model. Note that CLAP-Refine takes longer to run, but it can provide better-quality captions. You can execute the following commands:
```bash
# Inference on AudioCaps (CLAP-Refine)
bash scripts/inference_audiocaps_CLAP_Refine.sh

# Inference on Clotho (CLAP-Refine)
bash scripts/inference_clotho_CLAP_Refine.sh
```

If you already have the generated candidates and want to refine them directly with the CLAP-Refine strategy, you can run the following command:
```bash
bash scripts/clap_refine.sh
```

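Conceptually, CLAP-Refine scores each beam-search candidate against the audio with the CLAP model and keeps the caption with the highest audio-text similarity. The sketch below illustrates only that selection step, with random tensors standing in for real CLAP embeddings; it is not the repository's implementation (see `scripts/clap_refine.sh` for the actual pipeline).

```python
import torch
import torch.nn.functional as F

def pick_best_caption(audio_emb: torch.Tensor, text_embs: torch.Tensor, captions: list) -> str:
    """Return the caption whose CLAP text embedding is closest (cosine) to the audio embedding."""
    sims = F.cosine_similarity(text_embs, audio_emb.unsqueeze(0), dim=-1)
    return captions[int(sims.argmax())]

# Toy example: 3 candidate captions, 512-dim vectors in place of real CLAP embeddings.
captions = ["a dog barks in the distance", "rain falls on a metal roof", "an engine idles nearby"]
best = pick_best_caption(torch.randn(512), torch.randn(3, 512), captions)
print(best)
```
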
<!-- ## Citation
You can refer to the paper for more results.
```

``` -->

examples/slam_aac/aac_config.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
1+
from dataclasses import dataclass, field
2+
from typing import Optional, List
3+
@dataclass
4+
class ModelConfig:
5+
file: str = "examples/slam_aac/model/slam_model_aac.py:model_factory"
6+
llm_name: str = "vicuna-13b-v1.5"
7+
llm_path: str = "PATH/to/LLAMA/7B"
8+
llm_type: str = "decoder_only"
9+
llm_dim: int = 4096
10+
encoder_name: Optional[str] = None
11+
encoder_ds_rate: int = 2
12+
encoder_path: Optional[str] = None
13+
encoder_dim: int = 1280
14+
encoder_projector: str = "linear"
15+
encoder_projector_ds_rate: int = 5
16+
encoder_fairseq_dir: str = "/fairseq/EAT"
17+
modal: str = "audio"
18+
normalize: Optional[bool] = field(default=False, metadata={
19+
"help": "whether inpit is normalized, used for models such as wavlm"
20+
})
21+
do_sample: bool = False
22+
top_p: float = 1.0
23+
temperature: float = 1.0
24+
num_beams: int = 4
25+
num_return_sequences: int = 1
26+
length_penalty: float = 1.0
27+
repetition_penalty: float = 1.0
28+
max_new_tokens: int = 200
29+
min_length: int = 1
30+
31+
@dataclass
32+
class PeftConfig:
33+
peft_method: str = "lora" # None , llama_adapter, prefix
34+
r: int = 8
35+
lora_alpha: int = 32
36+
target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ])
37+
bias: str = "none"
38+
task_type: str = "CAUSAL_LM"
39+
lora_dropout: float = 0.05
40+
inference_mode: bool = False
41+
42+
@dataclass
43+
class TrainConfig:
44+
model_name:str = "PATH/to/LLAMA/7B"
45+
enable_ddp:bool = False
46+
enable_deepspeed:bool = False
47+
enable_fsdp:bool = False
48+
low_cpu_fsdp:bool = False
49+
run_validation:bool = True
50+
batch_size_training:int = 4
51+
batching_strategy:str = field(default="packing", metadata={
52+
"help":"alternative: padding"
53+
}) #
54+
context_length:int = 4096
55+
gradient_accumulation_steps:int = 1
56+
num_epochs:int = 3
57+
num_workers_dataloader:int = 1
58+
warmup_steps:int = 1000
59+
total_steps:int = 100000
60+
validation_interval:int = 1000
61+
lr:float = 1e-4
62+
weight_decay:float = 0.0
63+
gamma:float = 0.85
64+
seed:int = 42
65+
use_fp16:bool = False
66+
mixed_precision:bool = True
67+
val_batch_size:int = 1
68+
69+
use_peft:bool = False
70+
peft_config:PeftConfig = field(default_factory=PeftConfig)
71+
output_dir:str = "PATH/to/save/PEFT/model"
72+
freeze_layers:bool = False
73+
num_freeze_layers:int = 1
74+
quantization:bool = False
75+
one_gpu:bool = False
76+
save_model:bool = True
77+
dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP
78+
dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP
79+
save_optimizer:bool = False # will be used if using FSDP
80+
use_fast_kernels:bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
81+
run_test_during_validation:bool = False
82+
run_test_during_validation_file:str = "test.wav"
83+
run_test_during_validation_prompt:str = "<|ASR|>"
84+
freeze_llm:bool = field(default=False, metadata={
85+
"help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
86+
})
87+
freeze_encoder:bool = False
88+
specaug:bool = False
89+
noise_aug:bool = False
90+
91+
@dataclass
92+
class DataConfig:
93+
dataset: str = "audio_dataset"
94+
file: str = "src/slam_llm/datasets/audio_dataset.py:get_audio_dataset"
95+
train_data_path: Optional[str] = None
96+
val_data_path: Optional[str] = None
97+
train_split: str = "train"
98+
test_split:str = "validation"
99+
prompt: Optional[str] = None
100+
data_path: Optional[str] = None
101+
max_words: Optional[int] = None
102+
max_mel: Optional[float] = None
103+
fix_length_audio: int = -1
104+
inference_mode:bool = False
105+
model_name: str = 'eat'
106+
fbank_mean: float = -4.268
107+
fbank_std: float = 4.569
108+
target_length: int = 1024
109+
fixed_length: bool = False
110+
prompt: str = "Describe the audio you hear."
111+
random_crop: bool = False
112+
encoder_projector_ds_rate: int = 5
113+
input_type: str = field(default="raw", metadata={
114+
"help":"Use raw when input is wav, mel when for whisper"
115+
})
116+
mel_size: int = field(default=80, metadata={
117+
"help": "80 for whisper large v1 and v2, 128 for v3"
118+
})
119+
normalize: Optional[bool] = field(default=False, metadata={
120+
"help": "whether inpit is normalized, used for models such as wavlm"
121+
})
122+
123+
@dataclass
124+
class FSDPConfig:
125+
mixed_precision: bool = True
126+
use_fp16: bool = False
127+
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
128+
sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
129+
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
130+
fsdp_activation_checkpointing: bool = True
131+
fsdp_cpu_offload: bool = False
132+
pure_bf16: bool = False
133+
optimizer: str = "AdamW"
134+
135+
@dataclass
136+
class LogConfig:
137+
use_wandb: bool = False
138+
wandb_dir: str = "/root/test_wandb"
139+
wandb_entity_name: str = "project_name"
140+
wandb_project_name: str = "project_name"
141+
wandb_exp_name: str = "exp_name"
142+
log_file: str = "/root/test.log"
143+
log_interval: int = 5
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
1+
{
2+
"train_micro_batch_size_per_gpu": 4,
3+
"gradient_accumulation_steps": 1,
4+
"optimizer": {
5+
"type": "Adam",
6+
"params": {
7+
"lr": 1e-4
8+
}
9+
},
10+
"fp16": {
11+
"enabled": true
12+
},
13+
"zero_optimization": {
14+
"stage": 3,
15+
"offload_optimizer": {
16+
"device": "cpu"
17+
}
18+
}
19+
}

examples/slam_aac/conf/prompt.yaml

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
1+
dataset_config:
2+
# we put prompt here, because the hydra override in shell script only support a small subset of chars
3+
prompt: "Describe the audio you hear."

examples/slam_aac/docs/model.png

132 KB

examples/slam_aac/finetune_aac.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
1+
from slam_llm.pipeline.finetune import main as train
2+
3+
import hydra
4+
import logging
5+
from typing import Optional
6+
from dataclasses import dataclass, field
7+
from omegaconf import DictConfig, ListConfig, OmegaConf
8+
from aac_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig
9+
10+
@dataclass
11+
class RunConfig:
12+
dataset_config: DataConfig = field(default_factory=DataConfig)
13+
model_config: ModelConfig = field(default_factory=ModelConfig)
14+
train_config: TrainConfig = field(default_factory=TrainConfig)
15+
log_config: LogConfig = field(default_factory=LogConfig)
16+
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
17+
debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
18+
metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
19+
ckpt_path: Optional[str] = field(
20+
default=None, metadata={"help": "The path to projector checkpoint"}
21+
)
22+
peft_ckpt: Optional[str] = field(
23+
default=None, metadata={"help": "The path to peft checkpoint"}
24+
)
25+
26+
@hydra.main(config_name=None, version_base=None)
27+
def main_hydra(cfg: DictConfig):
28+
run_config = RunConfig()
29+
cfg = OmegaConf.merge(run_config, cfg)
30+
def to_plain_list(cfg_item):
31+
if isinstance(cfg_item, ListConfig):
32+
return OmegaConf.to_container(cfg_item, resolve=True)
33+
elif isinstance(cfg_item, DictConfig):
34+
return {k: to_plain_list(v) for k, v in cfg_item.items()}
35+
else:
36+
return cfg_item
37+
38+
# kwargs = to_plain_list(cfg)
39+
kwargs = cfg
40+
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
41+
42+
logging.basicConfig(level=log_level)
43+
44+
if kwargs.get("debug", False):
45+
import pdb;
46+
pdb.set_trace()
47+
48+
train(kwargs)
49+
50+
51+
if __name__ == "__main__":
52+
main_hydra()
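As a usage note, the `OmegaConf.merge(run_config, cfg)` call above follows standard OmegaConf semantics: the structured dataclasses supply defaults and the Hydra command-line overrides win. A minimal sketch of that behavior (the paths and override values below are placeholders):

```python
from omegaconf import OmegaConf
from aac_config import ModelConfig  # examples/slam_aac/aac_config.py

defaults = OmegaConf.structured(ModelConfig)
overrides = OmegaConf.from_dotlist(["llm_path=/models/vicuna-7b-v1.5", "num_beams=8"])
merged = OmegaConf.merge(defaults, overrides)

print(merged.llm_path)  # /models/vicuna-7b-v1.5  (override wins)
print(merged.llm_dim)   # 4096                    (default kept)
```
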
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
1+
from slam_llm.pipeline.inference_batch import main as inference
2+
3+
import hydra
4+
import logging
5+
from dataclasses import dataclass, field
6+
from omegaconf import DictConfig, ListConfig, OmegaConf
7+
from typing import Optional
8+
from aac_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig
9+
10+
11+
@dataclass
12+
class RunConfig:
13+
dataset_config: DataConfig = field(default_factory=DataConfig)
14+
model_config: ModelConfig = field(default_factory=ModelConfig)
15+
train_config: TrainConfig = field(default_factory=TrainConfig)
16+
log_config: LogConfig = field(default_factory=LogConfig)
17+
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
18+
debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
19+
metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
20+
decode_log: str = field(
21+
default="output/decode_log",
22+
metadata={"help": "The prefix for the decode output"},
23+
)
24+
ckpt_path: str = field(
25+
default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
26+
)
27+
peft_ckpt: Optional[str] = field(
28+
default=None,
29+
metadata={
30+
"help": "The path to peft checkpoint, should be a directory including adapter_config.json"
31+
},
32+
)
33+
34+
35+
@hydra.main(config_name=None, version_base=None)
36+
def main_hydra(cfg: DictConfig):
37+
run_config = RunConfig()
38+
cfg = OmegaConf.merge(run_config, cfg)
39+
# kwargs = to_plain_list(cfg)
40+
log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
41+
42+
logging.basicConfig(level=log_level)
43+
44+
if cfg.get("debug", False):
45+
import pdb
46+
47+
pdb.set_trace()
48+
49+
inference(cfg)
50+
51+
52+
if __name__ == "__main__":
53+
main_hydra()
