Commit dbfcfca

Merge pull request #156 from xrysamuel/main
Add recipe "sec_emotioncaps"
2 parents 6fb784b + c16e3a9 commit dbfcfca

12 files changed: +608 −0 lines

examples/sec_emotioncaps/README.md

Lines changed: 42 additions & 0 deletions
# Speech Emotion Caption

## Model Architecture

This recipe generates high-quality, human-like speech emotion descriptions. The model is built on a **Q-Former projector** and the **vicuna-7b-v1.5** LLM, and is trained on **an unpublished dataset**, a large-scale dataset for speech emotion captioning.

![](docs/model.png)

## Performance and checkpoints

We only train the Q-Former projector in this recipe.

Encoder | Projector | LLM | Similarity Score
---|---|---|---
[emotion2vec_base](https://huggingface.co/emotion2vec/emotion2vec_base) | [Q-Former](to_do) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 71.10

> **Note**: The baseline model [SECap](https://github.com/thuhcsi/SECap) was tested in our environment and achieved a similarity score of 71.52, slightly higher than our model's 71.10.

## Data preparation

Prepare the data as a jsonl file in the following format:

```
{"key": "key_name", "source": "path_to_wav_file", "target": "corresponding_caption"}
...
```
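As a minimal sketch (the keys, wav paths, and captions below are placeholders, not part of the recipe), such a jsonl file can be written with a few lines of Python:

```
import json

# Placeholder examples; replace with your own wav paths and reference captions.
samples = [
    {"key": "utt_0001", "source": "/path/to/utt_0001.wav", "target": "说话人的语气平静而略带喜悦。"},
    {"key": "utt_0002", "source": "/path/to/utt_0002.wav", "target": "说话人听起来焦虑不安。"},
]

with open("train.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        # One JSON object per line; keep non-ASCII captions readable.
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```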

## Decode with checkpoints

```
bash decode_emotion2vec_qformer_vicuna_7b.sh
```

Before running the script, modify the paths in it, including `speech_encoder_path`, `llm_path`, `output_dir`, `ckpt_path`, `val_data_path` and `decode_log`.

## Train a new model

If you have sufficient relevant data, you can train the model yourself.

```
bash finetune_emotion2vec_qformer_vicuna_7b.sh
```
Lines changed: 19 additions & 0 deletions
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
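This is a standard DeepSpeed ZeRO-3 config (fp16, Adam at lr 1e-4, optimizer state offloaded to CPU). The SLAM-LLM training pipeline consumes it for you; as a rough stand-alone sketch of how such a config is normally wired in (toy model, config file name assumed, run under the `deepspeed` launcher), it would look like:

```
import torch
import deepspeed

# Toy model for illustration only; the recipe trains the Q-Former projector inside slam_model_sec.
model = torch.nn.Linear(768, 4096)

# ZeRO stage 3, CPU optimizer offload, fp16 and Adam(lr=1e-4) all come from the JSON above.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",  # path/name of this config file is assumed here
)

x = torch.randn(4, 768, device=engine.device, dtype=torch.half)
loss = engine(x).float().pow(2).mean()
engine.backward(loss)  # DeepSpeed handles fp16 loss scaling and ZeRO partitioning
engine.step()
```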
Lines changed: 3 additions & 0 deletions
dataset_config:
  # We put the prompt here because the hydra override in the shell script only supports a limited set of characters.
  # English gloss of the prompt: "Describe in one Chinese sentence the emotion of the speaker in the audio above."
  prompt: "请用中文用一句话描述上面给出的音频中说话人的情感。"
Binary file added (111 KB), presumably the architecture figure `docs/model.png` referenced in the README.
Lines changed: 49 additions & 0 deletions
from slam_llm.pipeline.finetune import main as train

import hydra
import logging
from typing import Optional
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from sec_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    ckpt_path: Optional[str] = field(
        default=None, metadata={"help": "The path to projector checkpoint"}
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()
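The key mechanism in this entry point is `OmegaConf.merge(run_config, cfg)`: the recipe's structured defaults are merged with whatever Hydra collected from the YAML config and the `++key=value` overrides on the command line. A small self-contained illustration of that merge behaviour (toy dataclasses, not the recipe's real config classes):

```
from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class TrainConfig:
    num_epochs: int = 1
    freeze_llm: bool = True

@dataclass
class RunConfig:
    train_config: TrainConfig = field(default_factory=TrainConfig)

# Simulate what Hydra would hand over for `++train_config.num_epochs=3`.
overrides = OmegaConf.create({"train_config": {"num_epochs": 3}})
merged = OmegaConf.merge(RunConfig(), overrides)

print(merged.train_config.num_epochs)  # 3    -> CLI override wins
print(merged.train_config.freeze_llm)  # True -> default is kept
```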
Lines changed: 53 additions & 0 deletions
from slam_llm.pipeline.inference_batch import main as inference

import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Optional
from sec_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    decode_log: str = field(
        default="output/decode_log",
        metadata={"help": "The prefix for the decode output"},
    )
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
    )
    peft_ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "The path to peft checkpoint, should be a directory including adapter_config.json"
        },
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)
    # kwargs = to_plain_list(cfg)
    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if cfg.get("debug", False):
        import pdb

        pdb.set_trace()

    inference(cfg)


if __name__ == "__main__":
    main_hydra()
Lines changed: 155 additions & 0 deletions
import torch
import os
import logging
from slam_llm.models.slam_model import (
    slam_model,
    setup_tokenizer,
    setup_encoder,
    setup_encoder_projector,
    setup_llm,
)
from slam_llm.utils.train_utils import print_model_size

logger = logging.getLogger(__name__)


def model_factory(train_config, model_config, **kwargs):
    # return necessary components for training
    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)

    encoder = setup_encoder(train_config, model_config, **kwargs)

    # llm
    llm = setup_llm(train_config, model_config, **kwargs)

    # projector
    encoder_projector = setup_encoder_projector(
        train_config, model_config, **kwargs
    )
    model = slam_model_sec(
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    )

    ckpt_path = kwargs.get(
        "ckpt_path", None
    )  # FIX(MZY): load model ckpt (mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
    if ckpt_path is not None:
        logger.info("loading other parts from: {}".format(ckpt_path))
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt_dict, strict=False)

    print_model_size(
        model,
        train_config,
        (
            int(os.environ["RANK"])
            if train_config.enable_fsdp or train_config.enable_ddp
            else 0
        ),
    )
    return model, tokenizer


class slam_model_sec(slam_model):
    def __init__(
        self,
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    ):
        super().__init__(
            encoder,
            llm,
            encoder_projector,
            tokenizer,
            train_config,
            model_config,
            **kwargs,
        )

    @torch.no_grad()
    def inference(
        self,
        wav_path=None,
        prompt=None,
        generation_config=None,
        logits_processor=None,
        stopping_criteria=None,
        prefix_allowed_tokens_fn=None,
        synced_gpus=None,
        assistant_model=None,
        streamer=None,
        negative_prompt_ids=None,
        negative_prompt_attention_mask=None,
        **kwargs,
    ):
        # inference for asr model

        device = kwargs.get("device", "cuda")
        if os.path.exists(wav_path):  # Audio-Text QA
            import whisper

            audio_raw = whisper.load_audio(wav_path)
            audio_raw = whisper.pad_or_trim(audio_raw)

            mel_size = getattr(
                self.dataset_config, "mel_size", 80
            )  # 80 for large v1 and v2, 128 for large v3
            audio_mel = (
                whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size)
                .permute(1, 0)[None, :, :]
                .to(device)
            )

            encoder_outs = self.encoder.extract_variable_length_features(
                audio_mel.permute(0, 2, 1)
            )

            if self.model_config.encoder_projector == "q-former":
                audio_mel_post_mask = torch.ones(
                    encoder_outs.size()[:-1], dtype=torch.long
                ).to(encoder_outs.device)
                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)
        else:  # Text QA
            encoder_outs = torch.empty(
                1, 0, self.llm.model.embed_tokens.embedding_dim
            ).to(device)

        prompt = "USER: {}\n ASSISTANT:".format(prompt)
        prompt_ids = self.tokenizer.encode(prompt)
        prompt_length = len(prompt_ids)
        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)

        if hasattr(self.llm.model, "embed_tokens"):
            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
        elif hasattr(self.llm.model.model, "embed_tokens"):
            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
        else:
            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)

        inputs_embeds = torch.cat(
            (encoder_outs, inputs_embeds[None, :, :]), dim=1
        )  # [audio, prompt]

        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
            inputs_embeds.device
        )

        # generate
        model_outputs = self.generate(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs
        )

        return model_outputs
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#!/bin/bash
2+
# export PYTHONPATH=/root/whisper:$PYTHONPATH
3+
# export PYTHONPATH=/root/fairseq:$PYTHONPATH
4+
export CUDA_VISIBLE_DEVICES=1
5+
export TOKENIZERS_PARALLELISM=false
6+
# export CUDA_LAUNCH_BLOCKING=1
7+
export OMP_NUM_THREADS=1
8+
9+
# debug setting for multiple gpus
10+
# export NCCL_DEBUG=INFO
11+
# export NCCL_DEBUG_SUBSYS=ALL
12+
# export TORCH_DISTRIBUTED_DEBUG=INFO
13+
14+
run_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/SLAM-LLM
15+
cd $run_dir
16+
code_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/SLAM-LLM/examples/sec_emotioncaps
17+
18+
speech_encoder_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/ckpt/emotion2vec_base.pt
19+
llm_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/ckpt/vicuna-7b-v1.5
20+
val_data_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/data/valid.jsonl
21+
22+
encoder_fairseq_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/deps/emotion2vec/upstream
23+
24+
output_dir=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/out/sec-decode-$(date +"%Y%m%d-%s")
25+
26+
ckpt_path=/hpc_stor03/sjtu_home/ruiyang.xu/SLAM/out/sec-finetune-20241001-1727786623/sec_epoch_1_step_3000/model.pt
27+
28+
decode_log=$output_dir/decode_log
29+
30+
hydra_args="
31+
hydra.run.dir=$output_dir \
32+
++model_config.llm_name=vicuna-7b-v1.5 \
33+
++model_config.llm_path=$llm_path \
34+
++model_config.llm_dim=4096 \
35+
++model_config.encoder_name=emotion2vec \
36+
++model_config.encoder_projector_ds_rate=5 \
37+
++model_config.encoder_path=$speech_encoder_path \
38+
++model_config.encoder_fairseq_dir=$encoder_fairseq_dir \
39+
++model_config.encoder_dim=768 \
40+
++model_config.encoder_projector=q-former \
41+
++dataset_config.dataset=speech_dataset \
42+
++dataset_config.val_data_path=$val_data_path \
43+
++dataset_config.data_path=$val_data_path \
44+
++dataset_config.inference_mode=true \
45+
++dataset_config.input_type=raw \
46+
++train_config.model_name=sec \
47+
++train_config.num_epochs=1 \
48+
++train_config.freeze_encoder=true \
49+
++train_config.freeze_llm=true \
50+
++train_config.batching_strategy=custom \
51+
++train_config.val_batch_size=4 \
52+
++train_config.num_workers_dataloader=2 \
53+
++train_config.output_dir=$output_dir \
54+
++log_config.log_file=$output_dir/train.log \
55+
++ckpt_path=$ckpt_path \
56+
++decode_log=$decode_log
57+
"
58+
59+
# -m debugpy --listen 5678 --wait-for-client
60+
python $code_dir/inference_sec_batch.py \
61+
--config-path "conf" \
62+
--config-name "prompt.yaml" \
63+
$hydra_args
