diff --git a/examples/slam_aac/README.md b/examples/slam_aac/README.md
new file mode 100644
index 00000000..eb39d158
--- /dev/null
+++ b/examples/slam_aac/README.md
@@ -0,0 +1,88 @@
+# SLAM-AAC
+
+SLAM-AAC is a LLM-based model for Automated Audio Captioning (AAC) task. Inspired by techniques in machine translation and ASR, the model enhances audio captioning by incorporating paraphrasing augmentation and a plug-and-play CLAP-Refine strategy.
+
+
+## Model Architecture
+SLAM-AAC uses EAT as the audio encoder and Vicuna-7B as the LLM decoder. During training, only the Linear Projector and LoRA modules are trainable. For inference, multiple candidates are generated using different beam sizes, which are then refined using the CLAP-Refine strategy.
+
+
+
+## Performance and checkpoints
+We have released the pre-trained checkpoint of SLAM-AAC, as well as the fine-tuned checkpoints for the Clotho and AudioCaps datasets. The provided checkpoints include the model's Linear Projector and LoRA modules. Please note that when using each component, be sure to set up the corresponding environments according to the instructions provided in the respective repositories (e.g., for [EAT](https://github.com/cwx-worst-one/EAT)).
+
+### Pre-training
+SLAM-AAC was pre-trained on a combination of AudioCaps, Clotho, WavCaps, and MACS datasets. For more information on these datasets, you can refer to [this repository](https://github.com/Labbeti/aac-datasets). Additionally, the Clotho dataset was augmented using a back-translation-based paraphrasing technique.
+Audio Encoder | LLM | Checkpoint | Pre-training Dataset|
+|:---:|:---:|:---:|:---:|
+[EAT-base (fine-tuned)](https://drive.google.com/file/d/1aCYiQmoZv_Gh1FxnR-CCWpNAp6DIJzn6/view?usp=sharing) |[vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [link](https://drive.google.com/drive/folders/10kOjB112AeGYA_0mIUr8f1-i5rSg08_O?usp=sharing) | AudioCaps, Clotho, WavCaps, MACS |
+
+### Fine-tuning
+We fine-tuned the pre-trained model on the Clotho and AudioCaps datasets, respectively. The final evaluation was conducted using audio captions generated with the CLAP-Refine decoding strategy.
+Dataset | Audio Encoder | LLM | Checkpoint | METEOR | CIDEr | SPICE | SPIDEr | SPIDEr-FL | FENSE
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+| Clotho | [EAT-base (fine-tuned)](https://drive.google.com/file/d/1aCYiQmoZv_Gh1FxnR-CCWpNAp6DIJzn6/view?usp=sharing) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [link](https://drive.google.com/drive/folders/1QX7CM9YAddPi02_NRChI5mzsNmBBtA63?usp=sharing) | 19.7 | 51.5 | 14.8 |33.2 | 33.0 | 54.0 |
+| AudioCaps | [EAT-base (fine-tuned)](https://drive.google.com/file/d/1aCYiQmoZv_Gh1FxnR-CCWpNAp6DIJzn6/view?usp=sharing) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [link](https://drive.google.com/drive/folders/1GhFPiSVmBE9BvBhYWCEqkFuH-avKl-4g?usp=sharing) | 26.8 | 84.1 | 19.4 | 51.8 | 51.5 | 66.8 |
+
+
+## Data preparation
+Ensure your `jsonl` data follows the structure outlined below:
+```json
+{"key": "Y7fmOlUlwoNg_1", "source": "/root/data/AudioCaps/waveforms/test/Y7fmOlUlwoNg.wav", "target": "Constant rattling noise and sharp vibrations"}
+{"key": "Y6BJ455B1aAs_1", "source": "/root/data/AudioCaps/waveforms/test/Y6BJ455B1aAs.wav", "target": "A rocket flies by followed by a loud explosion and fire crackling as a truck engine runs idle"}
+```
+In addition, you can refer to the [manifest](https://drive.google.com/drive/folders/1NJinoWg3yXKSPm-pRrhqKLvCD9dtDuDG?usp=sharing) file we've provided, which includes the Clotho dataset enhanced with **paraphrasing augmentation** as bonus.
+
+## Model Training
+To pre-train the SLAM-AAC model with pre-training data, you can run the following command:
+```bash
+# Pre-train the model
+bash scripts/pretrain.sh
+```
+
+You can fine-tune the model on the AudioCaps or Clotho datasets using the [provided checkpoint](https://drive.google.com/drive/folders/10kOjB112AeGYA_0mIUr8f1-i5rSg08_O?usp=sharing) or your own pre-trained model by running the following commands:
+
+```bash
+# Fine-tune on AudioCaps
+bash scripts/finetune_audiocaps.sh
+
+# Fine-tune on Clotho
+bash scripts/finetune_clotho.sh
+```
+
+You can also fine-tune the model without loading any pre-trained weights, though this may result in reduced performance.
+
+
+### Note
+- In the current version of SLAM-LLM, the `peft_ckpt` parameter is no longer required. However, if you are using the checkpoint provided by us, which was trained with an earlier version, please keep the `peft_ckpt` parameter in your configuration to ensure compatibility.
+- Due to differences in dependency versions, there may be slight variations in the performance of the SLAM-AAC model.
+
+## Inference
+To perform inference with the trained models, you can use the following commands to decode using the common beam search method:
+```bash
+# Inference on AudioCaps (Beam Search)
+bash scripts/inference_audiocaps_bs.sh
+
+# Inference on Clotho (Beam Search)
+bash scripts/inference_clotho_bs.sh
+```
+
+For improved inference results, you can use the CLAP-Refine strategy, which utilizes multiple beam search decoding. To use this method, you need to download and use our pre-trained [CLAP](https://drive.google.com/drive/folders/1X4NYE08N-kbOy6s_Itb0wBR_3X8oZF56?usp=sharing) model. Note that CLAP-Refine may take longer to run, but it can provide better quality outputs. You can execute the following commands:
+```bash
+# Inference on AudioCaps (CLAP-Refine)
+bash scripts/inference_audiocaps_CLAP_Refine.sh
+
+# Inference on Clotho (CLAP-Refine)
+bash scripts/inference_clotho_CLAP_Refine.sh
+```
+
+If you already have the generated candidates and want to directly refine them using the CLAP-Refine strategy, you can run the following command:
+```bash
+bash scripts/clap_refine.sh
+```
+
+
diff --git a/examples/slam_aac/aac_config.py b/examples/slam_aac/aac_config.py
new file mode 100644
index 00000000..50fca279
--- /dev/null
+++ b/examples/slam_aac/aac_config.py
@@ -0,0 +1,143 @@
+from dataclasses import dataclass, field
+from typing import Optional, List
+@dataclass
+class ModelConfig:
+ file: str = "examples/slam_aac/model/slam_model_aac.py:model_factory"
+ llm_name: str = "vicuna-13b-v1.5"
+ llm_path: str = "PATH/to/LLAMA/7B"
+ llm_type: str = "decoder_only"
+ llm_dim: int = 4096
+ encoder_name: Optional[str] = None
+ encoder_ds_rate: int = 2
+ encoder_path: Optional[str] = None
+ encoder_dim: int = 1280
+ encoder_projector: str = "linear"
+ encoder_projector_ds_rate: int = 5
+ encoder_fairseq_dir: str = "/fairseq/EAT"
+ modal: str = "audio"
+ normalize: Optional[bool] = field(default=False, metadata={
+ "help": "whether inpit is normalized, used for models such as wavlm"
+ })
+ do_sample: bool = False
+ top_p: float = 1.0
+ temperature: float = 1.0
+ num_beams: int = 4
+ num_return_sequences: int = 1
+ length_penalty: float = 1.0
+ repetition_penalty: float = 1.0
+ max_new_tokens: int = 200
+ min_length: int = 1
+
+@dataclass
+class PeftConfig:
+ peft_method: str = "lora" # None , llama_adapter, prefix
+ r: int = 8
+ lora_alpha: int = 32
+ target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ])
+ bias: str = "none"
+ task_type: str = "CAUSAL_LM"
+ lora_dropout: float = 0.05
+ inference_mode: bool = False
+
+@dataclass
+class TrainConfig:
+ model_name:str = "PATH/to/LLAMA/7B"
+ enable_ddp:bool = False
+ enable_deepspeed:bool = False
+ enable_fsdp:bool = False
+ low_cpu_fsdp:bool = False
+ run_validation:bool = True
+ batch_size_training:int = 4
+ batching_strategy:str = field(default="packing", metadata={
+ "help":"alternative: padding"
+ }) #
+ context_length:int = 4096
+ gradient_accumulation_steps:int = 1
+ num_epochs:int = 3
+ num_workers_dataloader:int = 1
+ warmup_steps:int = 1000
+ total_steps:int = 100000
+ validation_interval:int = 1000
+ lr:float = 1e-4
+ weight_decay:float = 0.0
+ gamma:float = 0.85
+ seed:int = 42
+ use_fp16:bool = False
+ mixed_precision:bool = True
+ val_batch_size:int = 1
+
+ use_peft:bool = False
+ peft_config:PeftConfig = field(default_factory=PeftConfig)
+ output_dir:str = "PATH/to/save/PEFT/model"
+ freeze_layers:bool = False
+ num_freeze_layers:int = 1
+ quantization:bool = False
+ one_gpu:bool = False
+ save_model:bool = True
+ dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP
+ dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP
+ save_optimizer:bool = False # will be used if using FSDP
+ use_fast_kernels:bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+ run_test_during_validation:bool = False
+ run_test_during_validation_file:str = "test.wav"
+ run_test_during_validation_prompt:str = "<|ASR|>"
+ freeze_llm:bool = field(default=False, metadata={
+ "help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
+ })
+ freeze_encoder:bool = False
+ specaug:bool = False
+ noise_aug:bool = False
+
+@dataclass
+class DataConfig:
+ dataset: str = "audio_dataset"
+ file: str = "src/slam_llm/datasets/audio_dataset.py:get_audio_dataset"
+ train_data_path: Optional[str] = None
+ val_data_path: Optional[str] = None
+ train_split: str = "train"
+ test_split:str = "validation"
+ prompt: Optional[str] = None
+ data_path: Optional[str] = None
+ max_words: Optional[int] = None
+ max_mel: Optional[float] = None
+ fix_length_audio: int = -1
+ inference_mode:bool = False
+ model_name: str = 'eat'
+ fbank_mean: float = -4.268
+ fbank_std: float = 4.569
+ target_length: int = 1024
+ fixed_length: bool = False
+ prompt: str = "Describe the audio you hear."
+ random_crop: bool = False
+ encoder_projector_ds_rate: int = 5
+ input_type: str = field(default="raw", metadata={
+ "help":"Use raw when input is wav, mel when for whisper"
+ })
+ mel_size: int = field(default=80, metadata={
+ "help": "80 for whisper large v1 and v2, 128 for v3"
+ })
+ normalize: Optional[bool] = field(default=False, metadata={
+ "help": "whether inpit is normalized, used for models such as wavlm"
+ })
+
+@dataclass
+class FSDPConfig:
+ mixed_precision: bool = True
+ use_fp16: bool = False
+ # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
+ sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+ checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
+ fsdp_activation_checkpointing: bool = True
+ fsdp_cpu_offload: bool = False
+ pure_bf16: bool = False
+ optimizer: str = "AdamW"
+
+@dataclass
+class LogConfig:
+ use_wandb: bool = False
+ wandb_dir: str = "/root/test_wandb"
+ wandb_entity_name: str = "project_name"
+ wandb_project_name: str = "project_name"
+ wandb_exp_name: str = "exp_name"
+ log_file: str = "/root/test.log"
+ log_interval: int = 5
diff --git a/examples/slam_aac/conf/ds_config.json b/examples/slam_aac/conf/ds_config.json
new file mode 100644
index 00000000..7ea70e4a
--- /dev/null
+++ b/examples/slam_aac/conf/ds_config.json
@@ -0,0 +1,19 @@
+{
+ "train_micro_batch_size_per_gpu": 4,
+ "gradient_accumulation_steps": 1,
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1e-4
+ }
+ },
+ "fp16": {
+ "enabled": true
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "cpu"
+ }
+ }
+}
\ No newline at end of file
diff --git a/examples/slam_aac/conf/prompt.yaml b/examples/slam_aac/conf/prompt.yaml
new file mode 100644
index 00000000..88c7e595
--- /dev/null
+++ b/examples/slam_aac/conf/prompt.yaml
@@ -0,0 +1,3 @@
+dataset_config:
+ # we put prompt here, because the hydra override in shell script only support a small subset of chars
+ prompt: "Describe the audio you hear."
diff --git a/examples/slam_aac/docs/model.png b/examples/slam_aac/docs/model.png
new file mode 100644
index 00000000..4b7dee30
Binary files /dev/null and b/examples/slam_aac/docs/model.png differ
diff --git a/examples/slam_aac/finetune_aac.py b/examples/slam_aac/finetune_aac.py
new file mode 100644
index 00000000..c46a8663
--- /dev/null
+++ b/examples/slam_aac/finetune_aac.py
@@ -0,0 +1,52 @@
+from slam_llm.pipeline.finetune import main as train
+
+import hydra
+import logging
+from typing import Optional
+from dataclasses import dataclass, field
+from omegaconf import DictConfig, ListConfig, OmegaConf
+from aac_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig
+
+@dataclass
+class RunConfig:
+ dataset_config: DataConfig = field(default_factory=DataConfig)
+ model_config: ModelConfig = field(default_factory=ModelConfig)
+ train_config: TrainConfig = field(default_factory=TrainConfig)
+ log_config: LogConfig = field(default_factory=LogConfig)
+ fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
+ debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
+ metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
+ ckpt_path: Optional[str] = field(
+ default=None, metadata={"help": "The path to projector checkpoint"}
+ )
+ peft_ckpt: Optional[str] = field(
+ default=None, metadata={"help": "The path to peft checkpoint"}
+ )
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+ run_config = RunConfig()
+ cfg = OmegaConf.merge(run_config, cfg)
+ def to_plain_list(cfg_item):
+ if isinstance(cfg_item, ListConfig):
+ return OmegaConf.to_container(cfg_item, resolve=True)
+ elif isinstance(cfg_item, DictConfig):
+ return {k: to_plain_list(v) for k, v in cfg_item.items()}
+ else:
+ return cfg_item
+
+ # kwargs = to_plain_list(cfg)
+ kwargs = cfg
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+
+ logging.basicConfig(level=log_level)
+
+ if kwargs.get("debug", False):
+ import pdb;
+ pdb.set_trace()
+
+ train(kwargs)
+
+
+if __name__ == "__main__":
+ main_hydra()
\ No newline at end of file
diff --git a/examples/slam_aac/inference_aac_batch.py b/examples/slam_aac/inference_aac_batch.py
new file mode 100644
index 00000000..357b43ec
--- /dev/null
+++ b/examples/slam_aac/inference_aac_batch.py
@@ -0,0 +1,53 @@
+from slam_llm.pipeline.inference_batch import main as inference
+
+import hydra
+import logging
+from dataclasses import dataclass, field
+from omegaconf import DictConfig, ListConfig, OmegaConf
+from typing import Optional
+from aac_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig
+
+
+@dataclass
+class RunConfig:
+ dataset_config: DataConfig = field(default_factory=DataConfig)
+ model_config: ModelConfig = field(default_factory=ModelConfig)
+ train_config: TrainConfig = field(default_factory=TrainConfig)
+ log_config: LogConfig = field(default_factory=LogConfig)
+ fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
+ debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
+ metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
+ decode_log: str = field(
+ default="output/decode_log",
+ metadata={"help": "The prefix for the decode output"},
+ )
+ ckpt_path: str = field(
+ default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
+ )
+ peft_ckpt: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The path to peft checkpoint, should be a directory including adapter_config.json"
+ },
+ )
+
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+ run_config = RunConfig()
+ cfg = OmegaConf.merge(run_config, cfg)
+ # kwargs = to_plain_list(cfg)
+ log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
+
+ logging.basicConfig(level=log_level)
+
+ if cfg.get("debug", False):
+ import pdb
+
+ pdb.set_trace()
+
+ inference(cfg)
+
+
+if __name__ == "__main__":
+ main_hydra()
diff --git a/examples/slam_aac/model/slam_model_aac.py b/examples/slam_aac/model/slam_model_aac.py
new file mode 100644
index 00000000..c4c33d79
--- /dev/null
+++ b/examples/slam_aac/model/slam_model_aac.py
@@ -0,0 +1,349 @@
+import torch
+import os
+import logging
+from slam_llm.models.slam_model import (
+ slam_model,
+ setup_tokenizer,
+ setup_encoder,
+ setup_encoder_projector,
+ setup_llm,
+)
+from slam_llm.utils.train_utils import print_model_size
+from torchaudio.transforms import Resample
+from slam_llm.models.BEATs.BEATs import BEATs
+from slam_llm.models.EAT.EAT import EAT_preprocess
+import torchaudio
+from typing import List, Optional
+from transformers import T5ForConditionalGeneration
+from slam_llm.utils.metric import compute_accuracy
+import numpy as np
+import torch.nn.functional as F
+
+logger = logging.getLogger(__name__)
+
+def model_factory(train_config, model_config, **kwargs):
+ # return necessary components for training
+ tokenizer = setup_tokenizer(train_config, model_config, **kwargs)
+
+ encoder = setup_encoder(train_config, model_config, **kwargs)
+
+ # llm
+ llm = setup_llm(train_config, model_config, **kwargs)
+
+ # projector
+ encoder_projector = setup_encoder_projector(
+ train_config, model_config, **kwargs
+ )
+ model = slam_model_aac(
+ encoder,
+ llm,
+ encoder_projector,
+ tokenizer,
+ train_config,
+ model_config,
+ **kwargs,
+ )
+
+ ckpt_path = kwargs.get(
+ "ckpt_path", None
+ ) # FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
+ if ckpt_path is not None:
+ logger.info("loading other parts from: {}".format(ckpt_path))
+ ckpt_dict = torch.load(ckpt_path, map_location="cpu")
+ model.load_state_dict(ckpt_dict, strict=False)
+
+ print_model_size(
+ model,
+ train_config,
+ (
+ int(os.environ["RANK"])
+ if train_config.enable_fsdp or train_config.enable_ddp
+ else 0
+ ),
+ )
+ return model, tokenizer
+
+
+class slam_model_aac(slam_model):
+ def __init__(
+ self,
+ encoder,
+ llm,
+ encoder_projector,
+ tokenizer,
+ train_config,
+ model_config,
+ **kwargs,
+ ):
+ super().__init__(
+ encoder,
+ llm,
+ encoder_projector,
+ tokenizer,
+ train_config,
+ model_config,
+ **kwargs,
+ )
+
+
+ def forward(self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ):
+
+ audio_mel = kwargs.get("audio_mel", None)
+ audio_mel_mask = kwargs.get("audio_mel_mask", None)
+ audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper
+
+ audio = kwargs.get("audio", None)
+ audio_mask = kwargs.get("audio_mask", None)
+ visual = kwargs.get("visual", None)
+
+ # for text encoder
+ instruct_ids = kwargs.get("instruct_ids", None)
+ instruct_mask = kwargs.get("instruct_mask", None)
+
+ modality_mask = kwargs.get("modality_mask", None)
+
+ if audio_mel is not None:
+ audio_mel = audio_mel.unsqueeze(dim=1)
+
+ # noise aug
+ if audio_mel is not None and self.train_config.noise_aug and self.llm.training:
+ audio_mel = audio_mel + torch.rand((audio_mel.shape[2], audio_mel.shape[3]),device="cuda") * np.random.rand() / 10
+
+ # Specaug
+ if audio_mel is not None and self.train_config.specaug and self.llm.training:
+ from torchlibrosa.augmentation import SpecAugmentation
+ spec_augmenter = SpecAugmentation(time_drop_width=64,
+ time_stripes_num=2,
+ freq_drop_width=8,
+ freq_stripes_num=2)
+ audio_mel = spec_augmenter(audio_mel)
+
+ encoder_outs = None
+ if audio_mel is not None or audio is not None:
+ if self.model_config.encoder_name == "whisper":
+ encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim
+ if self.model_config.encoder_name == "beats":
+ encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel.squeeze(dim=1), padding_mask = audio_mel_mask, feature_only = True) # bs*seq*dim
+ if self.model_config.encoder_name == "eat":
+ encoder_outs = self.encoder.model.extract_features(audio_mel, padding_mask = None, mask=False, remove_extra_tokens = False)['x']
+ if self.model_config.encoder_name == "clap":
+ if text is not None:
+ encoder_outs = self.encoder.encode_text(text).unsqueeze(1) # [btz, 1, dim]
+ elif audio is not None:
+ encoder_outs = self.encoder.encode_audio(audio) # with projection-based decoding
+ if self.model_config.encoder_name == "SpatialAST":
+ encoder_outs = self.encoder(audio) # output: [bs, seq_len=3+512, dim=768]
+ if self.model_config.encoder_name == "wavlm":
+ encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask
+ if self.model_config.encoder_name == "hubert":
+ results = self.encoder(source = audio, padding_mask = 1-audio_mask)
+ if self.model_config.encoder_type == "pretrain":
+ encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"]
+ if self.model_config.encoder_type == "finetune":
+ encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
+ encoder_outs = encoder_outs.transpose(0, 1)
+ if self.model_config.encoder_name == "av_hubert":
+ results = self.encoder(source={'video':visual, 'audio':audio}, padding_mask=visual_mask) # bs*seq*dim
+ encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
+ encoder_outs = encoder_outs.transpose(0, 1)
+ audio_mel_post_mask = (~audio_mel_post_mask).float()
+ if self.model_config.encoder_name == 'musicfm':
+ encoder_outs = self.encoder.extract_features(audio, padding_mask = None) # MusicFM doesn't support padding mask
+ if self.encoder is None:
+ encoder_outs = audio_mel if audio_mel is not None else audio
+
+ if self.model_config.encoder_projector == "q-former":
+ encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
+ if self.model_config.encoder_projector == "linear":
+ encoder_outs = self.encoder_projector(encoder_outs)
+ if self.model_config.encoder_projector == "cov1d-linear":
+ encoder_outs = self.encoder_projector(encoder_outs)
+
+ if instruct_ids is not None:
+ if self.encoder is not None:
+ encoder_outs = self.encoder(input_ids=instruct_ids, attention_mask=instruct_mask).last_hidden_state
+
+ if self.model_config.encoder_projector == "q-former":
+ encoder_outs = self.encoder_projector(encoder_outs, instruct_mask)
+ if self.model_config.encoder_projector == "linear":
+ encoder_outs = self.encoder_projector(encoder_outs)
+
+ if input_ids is not None:
+ input_ids[input_ids == -1] = 0
+ if isinstance(self.llm, T5ForConditionalGeneration):
+ inputs_embeds = self.llm.shared(input_ids)
+ else:
+ if hasattr(self.llm.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.embed_tokens(input_ids)
+ elif hasattr(self.llm.model.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+ else:
+ inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+
+ if modality_mask is not None:
+ modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1)
+ modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()
+
+ encoder_outs_pad = torch.zeros_like(inputs_embeds)
+ for i in range(encoder_outs.shape[0]):
+ encoder_outs_pad[
+ i, modality_mask_start_indices[i]:modality_mask_start_indices[i]+modality_lengths[i]
+ ] = encoder_outs[i][:modality_lengths[i]]
+
+ inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])
+
+ if kwargs.get("inference_mode", False):
+ return inputs_embeds, attention_mask
+
+
+ model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+ acc = -1
+ if self.metric:
+ with torch.no_grad():
+ preds = torch.argmax(model_outputs.logits, -1)
+ acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)
+
+ return model_outputs, acc
+
+
+ @torch.no_grad()
+ def generate(self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ):
+ kwargs["inference_mode"] = True
+
+ inputs_embeds, attention_mask = self.forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ **kwargs,
+ )
+
+ model_outputs = self.llm.generate(
+ inputs_embeds=inputs_embeds,
+ # max_length=kwargs.get("max_length", 200),
+ max_new_tokens=self.model_config.max_new_tokens,
+ num_beams=self.model_config.num_beams,
+ num_return_sequences=self.model_config.num_return_sequences,
+ do_sample=self.model_config.do_sample,
+ min_length=self.model_config.min_length,
+ top_p=self.model_config.top_p,
+ repetition_penalty=self.model_config.repetition_penalty,
+ length_penalty=self.model_config.length_penalty,
+ temperature=self.model_config.temperature,
+ attention_mask=attention_mask,
+ bos_token_id=self.tokenizer.bos_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ pad_token_id=self.tokenizer.pad_token_id
+ )
+
+ return model_outputs
+
+ @torch.no_grad()
+ def inference(
+ self,
+ wav_path = None,
+ prompt = None,
+ dataset_config = None,
+ generation_config = None,
+ logits_processor = None,
+ stopping_criteria = None,
+ prefix_allowed_tokens_fn = None,
+ synced_gpus = None,
+ assistant_model = None,
+ streamer = None,
+ negative_prompt_ids = None,
+ negative_prompt_attention_mask = None,
+ **kwargs,
+ ):
+ device = kwargs.get("device", "cuda")
+ if os.path.exists(wav_path):
+ try:
+ audio_raw, sample_rate = torchaudio.load(wav_path)
+ if audio_raw.shape[1] == 0:
+ raise ValueError("Empty audio file")
+ resampler = Resample(orig_freq=sample_rate, new_freq=16000)
+ audio_raw = resampler(audio_raw)
+
+ except (FileNotFoundError, ValueError, RuntimeError):
+ audio_raw = torch.zeros(1, 16000)
+
+ if self.model_config.encoder_name == "beats":
+ audio_mel = BEATs.preprocess(audio_raw[0], fbank_mean=dataset_config.fbank_mean, fbank_std=dataset_config.fbank_std)
+ elif self.model_config.encoder_name == "eat":
+ audio_mel = EAT_preprocess(source=audio_raw[0],norm_mean=dataset_config.fbank_mean,norm_std=dataset_config.fbank_std,
+ target_length=dataset_config.target_length,fixed_length=dataset_config.fixed_length,random_crop=dataset_config.random_crop)
+ else:
+ pass
+
+ audio_mel = audio_mel.unsqueeze(dim=0)
+ audio_mel_mask = torch.ones_like(audio_mel)
+ audio_mel = audio_mel.to(device)
+ audio_mel_mask = audio_mel_mask.to(device)
+
+ if self.model_config.encoder_name == "beats":
+ encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, padding_mask = audio_mel_mask, feature_only = True)
+ if self.model_config.encoder_name == "eat":
+ encoder_outs = self.encoder.model.extract_features(audio_mel.unsqueeze(dim=1), padding_mask = None, mask=False, remove_extra_tokens = False)['x']
+
+ if self.model_config.encoder_projector == "q-former":
+ audio_mel_post_mask = torch.ones(encoder_outs.size()[:-1], dtype=torch.long).to(encoder_outs.device)
+ encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
+ if self.model_config.encoder_projector == "linear":
+ encoder_outs = self.encoder_projector(encoder_outs)
+ else: # Text QA
+ encoder_outs = torch.empty(1, 0, self.llm.model.embed_tokens.embedding_dim).to(device)
+
+ prompt = "USER: {} \n ASSISTANT:".format(prompt)
+ prompt_ids = self.tokenizer.encode(prompt)
+ prompt_length = len(prompt_ids)
+ prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)
+
+ if hasattr(self.llm.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+ elif hasattr(self.llm.model.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+ else:
+ inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+ inputs_embeds = torch.cat((encoder_outs, inputs_embeds[None, :, :]), dim=1) # [audio,prompt]
+
+ attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(inputs_embeds.device)
+
+ # generate
+ model_outputs = self.generate(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ **kwargs
+ )
+
+ return model_outputs
diff --git a/examples/slam_aac/scripts/clap_refine.sh b/examples/slam_aac/scripts/clap_refine.sh
new file mode 100644
index 00000000..e7b76550
--- /dev/null
+++ b/examples/slam_aac/scripts/clap_refine.sh
@@ -0,0 +1,20 @@
+export CUDA_VISIBLE_DEVICES=0
+export HF_ENDPOINT=https://hf-mirror.com
+
+run_dir=/data/wenxi.chen/SLAM-LLM
+cd $run_dir
+code_dir=examples/slam_aac
+
+clap_dir=/data/xiquan.li/models/clap
+inference_data_path=/data/wenxi.chen/data/audiocaps/new_test.jsonl
+output_dir=/data/wenxi.chen/cp/aac_epoch_2_step_182_audiocaps_seed42
+
+echo "Running CLAP-Refine"
+
+# -m debugpy --listen 6666 --wait-for-client
+python ${code_dir}/utils/clap_refine.py \
+ --start_beam 2 --end_beam 8 \
+ --clap_ckpt $clap_dir/best_model.pt \
+ --config $clap_dir/clap_config.yaml \
+ --test_jsonl $inference_data_path \
+ --exp_explorer $output_dir
\ No newline at end of file
diff --git a/examples/slam_aac/scripts/finetune_audiocaps.sh b/examples/slam_aac/scripts/finetune_audiocaps.sh
new file mode 100644
index 00000000..7aae293a
--- /dev/null
+++ b/examples/slam_aac/scripts/finetune_audiocaps.sh
@@ -0,0 +1,100 @@
+#!/bin/bash
+export PYTHONPATH=/root/fairseq:$PYTHONPATH
+export CUDA_VISIBLE_DEVICES=1
+export TOKENIZERS_PARALLELISM=false
+export OMP_NUM_THREADS=7
+
+
+run_dir=/data/wenxi.chen/SLAM-LLM
+cd $run_dir
+code_dir=examples/slam_aac
+
+audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
+llm_path=/data/xiquan.li/models/vicuna-7b-v1.5
+
+seed=42
+btz=4
+lr=8e-6
+encoder_projector_ds_rate=5
+
+train_jsonl_path=/data/wenxi.chen/data/audiocaps/train.jsonl
+val_jsonl_path=/data/wenxi.chen/data/audiocaps/val.jsonl
+
+exp_name=slam-aac_AudioCaps_fine-tune
+output_dir=/data/wenxi.chen/exps/AudioCaps/${exp_name}
+
+ckpt_path=/data/wenxi.chen/cp/wavcaps_pt_v7-seed666_btz16_lr1e-4-short_prompt_10w/aac_epoch_4_step_4001/model.pt # path to load the pre-trained model
+peft_ckpt=/data/wenxi.chen/cp/wavcaps_pt_v7-seed666_btz16_lr1e-4-short_prompt_10w/aac_epoch_4_step_4001
+# ↑ This parameter is required for loading the old version of the SLAM-LLM model. Our released checkpoint uses the old version. In the new version, this parameter is no longer needed.
+
+hydra_args="
+hydra.run.dir=$output_dir \
+++model_config.llm_name=vicuna-7b-v1.5 \
+++model_config.llm_path=$llm_path \
+++model_config.llm_dim=4096 \
+++model_config.encoder_name=eat \
+++model_config.encoder_ds_rate=2 \
+++model_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
+++model_config.encoder_path=$audio_encoder_path \
+++model_config.encoder_dim=768 \
+++model_config.encoder_projector=linear \
+++dataset_config.encoder_projector_ds_rate=${encoder_projector_ds_rate} \
+++dataset_config.dataset=audio_dataset \
+++dataset_config.train_data_path=$train_jsonl_path \
+++dataset_config.val_data_path=$val_jsonl_path \
+++dataset_config.input_type=mel \
+++dataset_config.fbank_mean=-4.268 \
+++dataset_config.fbank_std=4.569 \
+++dataset_config.model_name=eat \
+++dataset_config.fixed_length=true \
+++dataset_config.target_length=1024 \
+++train_config.model_name=aac \
+++train_config.freeze_encoder=true \
+++train_config.freeze_llm=false \
+++train_config.batching_strategy=custom \
+++train_config.warmup_steps=1000 \
+++train_config.total_steps=100000 \
+++train_config.lr=$lr \
+++train_config.validation_interval=500 \
+++train_config.batch_size_training=$btz \
+++train_config.val_batch_size=$btz \
+++train_config.num_workers_dataloader=4 \
+++train_config.use_fp16=true \
+++train_config.output_dir=$output_dir \
+++train_config.seed=${seed} \
+++train_config.use_peft=true \
+++train_config.peft_config.peft_method=lora \
+++train_config.specaug=true \
+++log_config.log_file="${output_dir}/train.log" \
+++log_config.wandb_dir=${output_dir} \
+++log_config.wandb_entity_name=wxc12 \
+++log_config.wandb_project_name=slam-llm \
+++log_config.wandb_exp_name=$exp_name \
+++log_config.use_wandb=true \
+++metric=acc \
+++ckpt_path=$ckpt_path \
+++peft_ckpt=$peft_ckpt \
+"
+
+# note: to train the linear layer only, you could set '++train_config.use_peft=false' and 'train_config.freeze_llm=true'
+# -m debugpy --listen 5678 --wait-for-client
+if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
+ python $code_dir/finetune_aac.py \
+ --config-path "conf" \
+ --config-name "prompt.yaml" \
+ $hydra_args
+else
+ torchrun \
+ --nnodes 1 \
+ --nproc_per_node 2 \
+ --master_port=29503 \
+ $code_dir/finetune_asr.py \
+ --config-path "conf" \
+ --config-name "prompt.yaml" \
+ ++train_config.enable_fsdp=false \
+ ++train_config.enable_ddp=true \
+ ++train_config.use_fp16=true \
+ $hydra_args
+fi
+
+# bash /data/wenxi.chen/SLAM-LLM/examples/slam_aac/scripts/finetune_audiocaps.sh
\ No newline at end of file
diff --git a/examples/slam_aac/scripts/finetune_clotho.sh b/examples/slam_aac/scripts/finetune_clotho.sh
new file mode 100644
index 00000000..94ed81c6
--- /dev/null
+++ b/examples/slam_aac/scripts/finetune_clotho.sh
@@ -0,0 +1,100 @@
+#!/bin/bash
+export PYTHONPATH=/root/fairseq:$PYTHONPATH
+export CUDA_VISIBLE_DEVICES=0
+export TOKENIZERS_PARALLELISM=false
+export OMP_NUM_THREADS=7
+
+
+run_dir=/data/wenxi.chen/SLAM-LLM
+cd $run_dir
+code_dir=examples/slam_aac
+
+audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
+llm_path=/data/xiquan.li/models/vicuna-7b-v1.5
+
+seed=10086
+btz=4
+lr=8e-6
+encoder_projector_ds_rate=5
+
+train_jsonl_path=/data/wenxi.chen/data/clotho/development.jsonl
+val_jsonl_path=/data/wenxi.chen/data/clotho/validation.jsonl
+
+exp_name=slam-aac_Clotho_fine-tune
+output_dir=/data/wenxi.chen/exps/Clotho/${exp_name}
+
+ckpt_path=/data/wenxi.chen/cp/wavcaps_pt_v7-seed666_btz16_lr1e-4-short_prompt_10w/aac_epoch_4_step_4001/model.pt # path to load the pre-trained model
+peft_ckpt=/data/wenxi.chen/cp/wavcaps_pt_v7-seed666_btz16_lr1e-4-short_prompt_10w/aac_epoch_4_step_4001
+# ↑ This parameter is required for loading the old version of the SLAM-LLM model. Our released checkpoint uses the old version. In the new version, this parameter is no longer needed.
+
+hydra_args="
+hydra.run.dir=$output_dir \
+++model_config.llm_name=vicuna-7b-v1.5 \
+++model_config.llm_path=$llm_path \
+++model_config.llm_dim=4096 \
+++model_config.encoder_name=eat \
+++model_config.encoder_ds_rate=2 \
+++model_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
+++model_config.encoder_path=$audio_encoder_path \
+++model_config.encoder_dim=768 \
+++model_config.encoder_projector=linear \
+++dataset_config.encoder_projector_ds_rate=${encoder_projector_ds_rate} \
+++dataset_config.dataset=audio_dataset \
+++dataset_config.train_data_path=$train_jsonl_path \
+++dataset_config.val_data_path=$val_jsonl_path \
+++dataset_config.input_type=mel \
+++dataset_config.fbank_mean=-4.268 \
+++dataset_config.fbank_std=4.569 \
+++dataset_config.model_name=eat \
+++dataset_config.fixed_length=true \
+++dataset_config.target_length=1024 \
+++train_config.model_name=aac \
+++train_config.freeze_encoder=true \
+++train_config.freeze_llm=false \
+++train_config.batching_strategy=custom \
+++train_config.warmup_steps=1000 \
+++train_config.total_steps=100000 \
+++train_config.lr=$lr \
+++train_config.validation_interval=500 \
+++train_config.batch_size_training=$btz \
+++train_config.val_batch_size=$btz \
+++train_config.num_workers_dataloader=4 \
+++train_config.use_fp16=true \
+++train_config.output_dir=$output_dir \
+++train_config.seed=${seed} \
+++train_config.use_peft=true \
+++train_config.peft_config.peft_method=lora \
+++train_config.specaug=true \
+++log_config.log_file="${output_dir}/train.log" \
+++log_config.wandb_dir=${output_dir} \
+++log_config.wandb_entity_name=wxc12 \
+++log_config.wandb_project_name=slam-llm \
+++log_config.wandb_exp_name=$exp_name \
+++log_config.use_wandb=true \
+++metric=acc \
+++ckpt_path=$ckpt_path \
+++peft_ckpt=$peft_ckpt \
+"
+
+# note: to train the linear layer only, you could set '++train_config.use_peft=false' and 'train_config.freeze_llm=true'
+# -m debugpy --listen 5678 --wait-for-client
+if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
+ python $code_dir/finetune_aac.py \
+ --config-path "conf" \
+ --config-name "prompt.yaml" \
+ $hydra_args
+else
+ torchrun \
+ --nnodes 1 \
+ --nproc_per_node 2 \
+ --master_port=29503 \
+ $code_dir/finetune_asr.py \
+ --config-path "conf" \
+ --config-name "prompt.yaml" \
+ ++train_config.enable_fsdp=false \
+ ++train_config.enable_ddp=true \
+ ++train_config.use_fp16=true \
+ $hydra_args
+fi
+
+# bash /data/wenxi.chen/SLAM-LLM/examples/slam_aac/scripts/finetune_clotho.sh
\ No newline at end of file
diff --git a/examples/slam_aac/scripts/inference_audiocaps_CLAP_Refine.sh b/examples/slam_aac/scripts/inference_audiocaps_CLAP_Refine.sh
new file mode 100644
index 00000000..aec6df91
--- /dev/null
+++ b/examples/slam_aac/scripts/inference_audiocaps_CLAP_Refine.sh
@@ -0,0 +1,82 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=1
+export TOKENIZERS_PARALLELISM=false
+
+run_dir=/data/wenxi.chen/SLAM-LLM
+cd $run_dir
+code_dir=examples/slam_aac
+
+audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
+llm_path=/data/xiquan.li/models/vicuna-7b-v1.5
+clap_dir=/data/xiquan.li/models/clap
+
+encoder_projector_ds_rate=5
+
+inference_data_path=/data/wenxi.chen/data/audiocaps/new_test.jsonl
+output_dir=/data/wenxi.chen/cp/aac_epoch_2_step_182_audiocaps_seed42
+
+# define the beam size range
+beam_range=(2 3 4 5 6 7 8)
+
+for num_beams in "${beam_range[@]}"; do
+ decode_log=$output_dir/decode_beam${num_beams}
+
+ if [ -f "${decode_log}_pred" ]; then
+ echo "Decode log ${decode_log}_pred already exists, skipping this beam size..."
+ continue
+ fi
+
+ echo "Running inference with num_beams=$num_beams"
+
+ python $code_dir/inference_aac_batch.py \
+ --config-path "conf" \
+ --config-name "prompt.yaml" \
+ hydra.run.dir=$output_dir \
+ ++model_config.llm_name="vicuna-7b-v1.5" \
+ ++model_config.llm_path=$llm_path \
+ ++model_config.llm_dim=4096 \
+ ++model_config.encoder_name=eat \
+ ++model_config.encoder_path=$audio_encoder_path \
+ ++model_config.encoder_dim=768 \
+ ++model_config.encoder_projector=linear \
+ ++model_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
+ ++model_config.normalize=true \
+ ++dataset_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
+ ++dataset_config.dataset=audio_dataset \
+ ++dataset_config.val_data_path=$inference_data_path \
+ ++dataset_config.fbank_mean=-4.268 \
+ ++dataset_config.fbank_std=4.569 \
+ ++dataset_config.model_name=eat \
+ ++dataset_config.inference_mode=true \
+ ++dataset_config.normalize=true \
+ ++dataset_config.input_type=mel \
+ ++dataset_config.fixed_length=true \
+ ++dataset_config.target_length=1024 \
+ ++train_config.model_name=aac \
+ ++train_config.batching_strategy=custom \
+ ++train_config.num_epochs=1 \
+ ++train_config.val_batch_size=4 \
+ ++train_config.num_workers_dataloader=0 \
+ ++train_config.output_dir=$output_dir \
+ ++train_config.freeze_encoder=true \
+ ++train_config.freeze_llm=false \
+ ++train_config.use_peft=true \
+ ++ckpt_path=$output_dir/model.pt \
+ ++peft_ckpt=$output_dir \
+ ++decode_log=$decode_log \
+ ++model_config.num_beams=$num_beams
+done
+
+# note: to inference model trained the linear layer only, you could set '++train_config.use_peft=false' and 'train_config.freeze_llm=true'
+
+echo "Running CLAP-Refine"
+
+# -m debugpy --listen 6666 --wait-for-client
+python ${code_dir}/utils/clap_refine.py \
+ --start_beam 2 --end_beam 8 \
+ --clap_ckpt $clap_dir/best_model.pt \
+ --config $clap_dir/clap_config.yaml \
+ --test_jsonl $inference_data_path \
+ --exp_explorer $output_dir
+
+# bash /data/wenxi.chen/SLAM-LLM/examples/slam_aac/scripts/inference_audiocaps_CLAP_Refine.sh
\ No newline at end of file
diff --git a/examples/slam_aac/scripts/inference_audiocaps_bs.sh b/examples/slam_aac/scripts/inference_audiocaps_bs.sh
new file mode 100644
index 00000000..b97e3fff
--- /dev/null
+++ b/examples/slam_aac/scripts/inference_audiocaps_bs.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=0
+export TOKENIZERS_PARALLELISM=false
+
+run_dir=/data/wenxi.chen/SLAM-LLM
+cd $run_dir
+code_dir=examples/slam_aac
+
+audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
+llm_path=/data/xiquan.li/models/vicuna-7b-v1.5
+
+encoder_projector_ds_rate=5
+num_beams=4
+
+inference_data_path=/data/wenxi.chen/data/audiocaps/new_test.jsonl
+output_dir=/data/wenxi.chen/cp/aac_epoch_2_step_182_audiocaps_seed42
+decode_log=$output_dir/decode_beam${num_beams}
+
+
+# -m debugpy --listen 5678 --wait-for-client
+python $code_dir/inference_aac_batch.py \
+ --config-path "conf" \
+ --config-name "prompt.yaml" \
+ hydra.run.dir=$output_dir \
+ ++model_config.llm_name="vicuna-7b-v1.5" \
+ ++model_config.llm_path=$llm_path \
+ ++model_config.llm_dim=4096 \
+ ++model_config.encoder_name=eat \
+ ++model_config.encoder_path=$audio_encoder_path \
+ ++model_config.encoder_dim=768 \
+ ++model_config.encoder_projector=linear \
+ ++model_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
+ ++model_config.normalize=true \
+ ++dataset_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
+ ++dataset_config.dataset=audio_dataset \
+ ++dataset_config.val_data_path=$inference_data_path \
+ ++dataset_config.fbank_mean=-4.268 \
+ ++dataset_config.fbank_std=4.569 \
+ ++dataset_config.model_name=eat \
+ ++dataset_config.inference_mode=true \
+ ++dataset_config.normalize=true \
+ ++dataset_config.input_type=mel \
+ ++dataset_config.fixed_length=true \
+ ++dataset_config.target_length=1024 \
+ ++train_config.model_name=aac \
+ ++train_config.batching_strategy=custom \
+ ++train_config.num_epochs=1 \
+ ++train_config.val_batch_size=4 \
+ ++train_config.num_workers_dataloader=0 \
+ ++train_config.output_dir=$output_dir \
+ ++train_config.freeze_encoder=true \
+ ++train_config.freeze_llm=false \
+ ++train_config.use_peft=true \
+ ++ckpt_path=$output_dir/model.pt \
+ ++peft_ckpt=$output_dir \
+ ++decode_log=$decode_log \
+ ++model_config.num_beams=$num_beams
+
+# note: to inference model trained the linear layer only, you could set '++train_config.use_peft=false' and 'train_config.freeze_llm=true'
+# bash /data/wenxi.chen/SLAM-LLM/examples/slam_aac/scripts/inference_audiocaps_bs.sh
diff --git a/examples/slam_aac/scripts/inference_clotho_CLAP_Refine.sh b/examples/slam_aac/scripts/inference_clotho_CLAP_Refine.sh
new file mode 100644
index 00000000..d14b186b
--- /dev/null
+++ b/examples/slam_aac/scripts/inference_clotho_CLAP_Refine.sh
@@ -0,0 +1,82 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=2
+export TOKENIZERS_PARALLELISM=false
+
+run_dir=/data/wenxi.chen/SLAM-LLM
+cd $run_dir
+code_dir=examples/slam_aac
+
+audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
+llm_path=/data/xiquan.li/models/vicuna-7b-v1.5
+clap_dir=/data/xiquan.li/models/clap
+
+encoder_projector_ds_rate=5
+
+inference_data_path=/data/wenxi.chen/data/clotho/evaluation_single.jsonl
+output_dir=/data/wenxi.chen/cp/wavcaps_pt_v7_epoch4-clotho_ft-seed10086_btz4_lr8e-6-short_prompt_10w/aac_epoch_1_step_4500
+
+# define the beam size range
+beam_range=(2 3 4 5 6 7 8)
+
+for num_beams in "${beam_range[@]}"; do
+ decode_log=$output_dir/decode_beam${num_beams}
+
+ if [ -f "${decode_log}_pred" ]; then
+ echo "Decode log ${decode_log}_pred already exists, skipping this beam size..."
+ continue
+ fi
+
+ echo "Running inference with num_beams=$num_beams"
+
+ python $code_dir/inference_aac_batch.py \
+ --config-path "conf" \
+ --config-name "prompt.yaml" \
+ hydra.run.dir=$output_dir \
+ ++model_config.llm_name="vicuna-7b-v1.5" \
+ ++model_config.llm_path=$llm_path \
+ ++model_config.llm_dim=4096 \
+ ++model_config.encoder_name=eat \
+ ++model_config.encoder_path=$audio_encoder_path \
+ ++model_config.encoder_dim=768 \
+ ++model_config.encoder_projector=linear \
+ ++model_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
+ ++model_config.normalize=true \
+ ++dataset_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
+ ++dataset_config.dataset=audio_dataset \
+ ++dataset_config.val_data_path=$inference_data_path \
+ ++dataset_config.fbank_mean=-4.268 \
+ ++dataset_config.fbank_std=4.569 \
+ ++dataset_config.model_name=eat \
+ ++dataset_config.inference_mode=true \
+ ++dataset_config.normalize=true \
+ ++dataset_config.input_type=mel \
+ ++dataset_config.fixed_length=true \
+ ++dataset_config.target_length=1024 \
+ ++train_config.model_name=aac \
+ ++train_config.batching_strategy=custom \
+ ++train_config.num_epochs=1 \
+ ++train_config.val_batch_size=4 \
+ ++train_config.num_workers_dataloader=0 \
+ ++train_config.output_dir=$output_dir \
+ ++train_config.freeze_encoder=true \
+ ++train_config.freeze_llm=false \
+ ++train_config.use_peft=true \
+ ++ckpt_path=$output_dir/model.pt \
+ ++peft_ckpt=$output_dir \
+ ++decode_log=$decode_log \
+ ++model_config.num_beams=$num_beams
+done
+
+# note: to inference model trained the linear layer only, you could set '++train_config.use_peft=false' and 'train_config.freeze_llm=true'
+
+echo "Running CLAP-Refine"
+
+# -m debugpy --listen 6666 --wait-for-client
+python ${code_dir}/utils/clap_refine.py \
+ --start_beam 2 --end_beam 8 \
+ --clap_ckpt $clap_dir/best_model.pt \
+ --config $clap_dir/clap_config.yaml \
+ --test_jsonl $inference_data_path \
+ --exp_explorer $output_dir
+
+# bash /data/wenxi.chen/SLAM-LLM/examples/slam_aac/scripts/inference_clotho_CLAP_Refine.sh
\ No newline at end of file
diff --git a/examples/slam_aac/scripts/inference_clotho_bs.sh b/examples/slam_aac/scripts/inference_clotho_bs.sh
new file mode 100644
index 00000000..6e27784f
--- /dev/null
+++ b/examples/slam_aac/scripts/inference_clotho_bs.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=1
+export TOKENIZERS_PARALLELISM=false
+
+run_dir=/data/wenxi.chen/SLAM-LLM
+cd $run_dir
+code_dir=examples/slam_aac
+
+audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
+llm_path=/data/xiquan.li/models/vicuna-7b-v1.5
+
+encoder_projector_ds_rate=5
+num_beams=4
+
+inference_data_path=/data/wenxi.chen/data/clotho/evaluation_single.jsonl
+output_dir=/data/wenxi.chen/cp/wavcaps_pt_v7_epoch4-clotho_ft-seed10086_btz4_lr8e-6-short_prompt_10w/aac_epoch_1_step_4500
+decode_log=$output_dir/decode_beam${num_beams}
+
+
+# -m debugpy --listen 5678 --wait-for-client
+python $code_dir/inference_aac_batch.py \
+ --config-path "conf" \
+ --config-name "prompt.yaml" \
+ hydra.run.dir=$output_dir \
+ ++model_config.llm_name="vicuna-7b-v1.5" \
+ ++model_config.llm_path=$llm_path \
+ ++model_config.llm_dim=4096 \
+ ++model_config.encoder_name=eat \
+ ++model_config.encoder_path=$audio_encoder_path \
+ ++model_config.encoder_dim=768 \
+ ++model_config.encoder_projector=linear \
+ ++model_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
+ ++model_config.normalize=true \
+ ++dataset_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
+ ++dataset_config.dataset=audio_dataset \
+ ++dataset_config.val_data_path=$inference_data_path \
+ ++dataset_config.fbank_mean=-4.268 \
+ ++dataset_config.fbank_std=4.569 \
+ ++dataset_config.model_name=eat \
+ ++dataset_config.inference_mode=true \
+ ++dataset_config.normalize=true \
+ ++dataset_config.input_type=mel \
+ ++dataset_config.fixed_length=true \
+ ++dataset_config.target_length=1024 \
+ ++train_config.model_name=aac \
+ ++train_config.batching_strategy=custom \
+ ++train_config.num_epochs=1 \
+ ++train_config.val_batch_size=4 \
+ ++train_config.num_workers_dataloader=0 \
+ ++train_config.output_dir=$output_dir \
+ ++train_config.freeze_encoder=true \
+ ++train_config.freeze_llm=false \
+ ++train_config.use_peft=true \
+ ++ckpt_path=$output_dir/model.pt \
+ ++peft_ckpt=$output_dir \
+ ++decode_log=$decode_log \
+ ++model_config.num_beams=$num_beams
+
+# note: to inference model trained the linear layer only, you could set '++train_config.use_peft=false' and 'train_config.freeze_llm=true'
+# bash /data/wenxi.chen/SLAM-LLM/examples/slam_aac/scripts/inference_clotho_bs.sh
diff --git a/examples/slam_aac/scripts/pretrain.sh b/examples/slam_aac/scripts/pretrain.sh
new file mode 100644
index 00000000..1a099f29
--- /dev/null
+++ b/examples/slam_aac/scripts/pretrain.sh
@@ -0,0 +1,95 @@
+#!/bin/bash
+export PYTHONPATH=/root/fairseq:$PYTHONPATH
+export CUDA_VISIBLE_DEVICES=2
+export TOKENIZERS_PARALLELISM=false
+export OMP_NUM_THREADS=7
+
+
+run_dir=/data/wenxi.chen/SLAM-LLM
+cd $run_dir
+code_dir=examples/slam_aac
+
+audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
+llm_path=/data/xiquan.li/models/vicuna-7b-v1.5
+
+seed=666
+btz=16
+lr=1e-4
+encoder_projector_ds_rate=5
+
+train_jsonl_path=/data/wenxi.chen/data/pretrain/merged_data_v7.jsonl
+val_jsonl_path=/data/wenxi.chen/data/clotho/validation.jsonl
+
+exp_name=slam-aac_pre-train
+output_dir=/root/exps/test/${exp_name}
+
+hydra_args="
+hydra.run.dir=$output_dir \
+++model_config.llm_name=vicuna-7b-v1.5 \
+++model_config.llm_path=$llm_path \
+++model_config.llm_dim=4096 \
+++model_config.encoder_name=eat \
+++model_config.encoder_ds_rate=2 \
+++model_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
+++model_config.encoder_path=$audio_encoder_path \
+++model_config.encoder_dim=768 \
+++model_config.encoder_projector=linear \
+++dataset_config.encoder_projector_ds_rate=${encoder_projector_ds_rate} \
+++dataset_config.dataset=audio_dataset \
+++dataset_config.train_data_path=$train_jsonl_path \
+++dataset_config.val_data_path=$val_jsonl_path \
+++dataset_config.input_type=mel \
+++dataset_config.fbank_mean=-4.268 \
+++dataset_config.fbank_std=4.569 \
+++dataset_config.model_name=eat \
+++dataset_config.fixed_length=true \
+++dataset_config.target_length=1024 \
+++train_config.model_name=aac \
+++train_config.freeze_encoder=true \
+++train_config.freeze_llm=false \
+++train_config.batching_strategy=custom \
+++train_config.warmup_steps=1000 \
+++train_config.total_steps=100000 \
+++train_config.lr=$lr \
+++train_config.validation_interval=500 \
+++train_config.batch_size_training=$btz \
+++train_config.val_batch_size=$btz \
+++train_config.num_workers_dataloader=4 \
+++train_config.use_fp16=true \
+++train_config.output_dir=$output_dir \
+++train_config.seed=${seed} \
+++train_config.use_peft=true \
+++train_config.run_validation=false \
+++train_config.peft_config.peft_method=lora \
+++train_config.specaug=true \
+++log_config.log_file="${output_dir}/train.log" \
+++log_config.wandb_dir=${output_dir} \
+++log_config.wandb_entity_name=wxc12 \
+++log_config.wandb_project_name=slam-llm \
+++log_config.wandb_exp_name=$exp_name \
+++log_config.use_wandb=true \
+++metric=acc \
+"
+
+# note: to train the linear layer only, you could set '++train_config.use_peft=false' and 'train_config.freeze_llm=true'
+# -m debugpy --listen 5678 --wait-for-client
+if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then
+ python $code_dir/finetune_aac.py \
+ --config-path "conf" \
+ --config-name "prompt.yaml" \
+ $hydra_args
+else
+ torchrun \
+ --nnodes 1 \
+ --nproc_per_node 2 \
+ --master_port=29503 \
+ $code_dir/finetune_asr.py \
+ --config-path "conf" \
+ --config-name "prompt.yaml" \
+ ++train_config.enable_fsdp=false \
+ ++train_config.enable_ddp=true \
+ ++train_config.use_fp16=true \
+ $hydra_args
+fi
+
+# bash /data/wenxi.chen/SLAM-LLM/examples/slam_aac/scripts/pretrain.sh
\ No newline at end of file
diff --git a/examples/slam_aac/utils/clap_refine.py b/examples/slam_aac/utils/clap_refine.py
new file mode 100644
index 00000000..d352b5a4
--- /dev/null
+++ b/examples/slam_aac/utils/clap_refine.py
@@ -0,0 +1,171 @@
+## CLAP-refine
+from ruamel import yaml
+import torch
+from tqdm import tqdm
+import argparse
+import json
+import torchaudio
+from torchaudio.transforms import Resample
+from slam_llm.models.CLAP.ase_model import ASE
+from torch.utils.data import Dataset, DataLoader
+import torch.nn.functional as F
+
+## we use dataloader to accelerate data I/O
+class caption_dataset(Dataset):
+ def __init__(self, captions, remove_dup=True) -> None:
+ super().__init__()
+ self.captions = captions
+
+ def __getitem__(self, index):
+ return self.captions[index]
+
+ def __len__(self):
+ return len(self.captions)
+
+class audio_dataset(Dataset):
+ def __init__(self, database) -> None:
+ super().__init__()
+ self.input_type = 'raw'
+ self.max_length = 10
+ self.sr = 32000
+ wav_path = []
+ for dataset in database:
+ with open(dataset, 'r') as f:
+ for line in f:
+ data = json.loads(line)
+ wav_path.append(data['source'])
+ self.wav_paths = wav_path
+
+ def __getitem__(self, index):
+ wav_path = self.wav_paths[index]
+ wav_info = torchaudio.info(wav_path)
+ waveform, sr = torchaudio.load(wav_path, num_frames=self.max_length*wav_info.sample_rate) #[1, wav_length]
+ if waveform.shape[-1] < 0.1*self.sr:
+ waveform = torch.zeros(self.max_length*self.sr)
+ else:
+ waveform = waveform[0]
+ if self.input_type == "raw":
+ resampler = Resample(orig_freq=sr, new_freq=32000) # 32k for HTSAT
+ elif self.input_type == "mel":
+ resampler = Resample(orig_freq=sr, new_freq=16000) # 16k for EAT
+ waveform = resampler(waveform)
+ return waveform
+
+ def __len__(self):
+ return len(self.wav_paths)
+
+ def collator(self, samples):
+ audio_list = []
+ max_length = max([i.shape[-1] for i in samples])
+
+ for audio in samples: # audio: raw or mel
+
+ if audio.dim() == 1: # raw
+ if audio.shape[-1] < max_length:
+ pad_length = max_length - audio.shape[-1]
+ audio = F.pad(audio, [0, pad_length], "constant", 0.0)
+ audio_list.append(audio)
+ elif audio.dim() == 2: # mel
+ audio_list.append(audio)
+
+ audios = torch.stack(audio_list, dim=0)
+ return audios
+
+def read_captions(decode_log):
+ audio_ids, captions = [], []
+ with open(decode_log, 'r') as f:
+ for idx, line in enumerate(f):
+ line = line.strip()
+ line_strip = [i for i in line.split('\t') if i]
+ if len(line_strip) == 2:
+ audio_id, caption = line_strip
+ else:
+ audio_id, caption = line_strip, ''
+ print("No caption detected")
+
+ audio_ids.append(audio_id)
+ captions.append(caption)
+
+ return audio_ids, captions
+
+def encode_text(dl):
+ embeds, caps = [], []
+ for i, b in enumerate(tqdm(dl, total=len(dl))):
+ embeds.append(model.encode_text(b).detach_())
+ caps += b
+ return torch.vstack(embeds), caps
+
+def encode_audio(dl):
+ device = torch.device("cuda")
+ embeds = []
+ for i, b in enumerate(tqdm(dl, total=len(dl))):
+ b = b.to(device)
+ embeds.append(model.encode_audio(b).detach_())
+ return torch.vstack(embeds), None
+
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--start_beam", type=int, default=2,
+ help="Start beam for reranking [included]")
+ parser.add_argument("--end_beam", type=int, default=8,
+ help="End beam for reranking [included]")
+ parser.add_argument("--clap_ckpt", type=str, required=True,
+ help="model ckpt for CLAP encoder")
+ parser.add_argument("--config", type=str, required=True,
+ help="model config for CLAP encoder")
+ parser.add_argument("--test_jsonl", type=str, required=True,
+ help="jsonl file for test set to get audio sources")
+ parser.add_argument("--exp_explorer", type=str, required=True,
+ help="dir to load candidate captions")
+ parser.add_argument("--rank", type=int, default=1,
+ help="rank for selecting the candidate caption")
+ args = parser.parse_args()
+
+ clap_ckpt = args.clap_ckpt
+ config = args.config
+ exp_explorer, test_jsonl = args.exp_explorer, args.test_jsonl
+ start_beam, end_beam = args.start_beam, args.end_beam
+ cand_files = [f'{exp_explorer}/decode_beam{i}_pred' for i in range(start_beam, end_beam+1)]
+
+ print(f"--Clap re-ranking for beam {start_beam}~{end_beam}--")
+
+ ## Load captions & audios
+ cand_captions = [read_captions(log)[1] for log in cand_files]
+ audio_ids, _ = read_captions(cand_files[0])
+
+ ## Load model
+ with open(config, 'r') as f:
+ config = yaml.safe_load(f)
+ model = ASE(config)
+ cp_dict = torch.load(clap_ckpt)['model']
+ model.load_state_dict(cp_dict)
+ model.cuda().eval()
+
+ # Encode
+ audio_ds = audio_dataset([test_jsonl])
+ train_loader = DataLoader(audio_ds, batch_size=1, shuffle=False, collate_fn=audio_ds.collator) # NOTE: btz should be 1, if not performance will be harmed due to zero padding
+ audio_embeds, train_caps = encode_audio(train_loader) # [b, dim]
+
+ caption_embeds= []
+ for captions in cand_captions:
+ embeds = encode_text(DataLoader(caption_dataset(captions), batch_size=512, shuffle=False))[0]
+ caption_embeds.append(embeds)
+ caption_embeds = torch.stack(caption_embeds) # [b, dim]
+
+ # Select
+ sim = (audio_embeds.unsqueeze(0) * caption_embeds).sum(-1) # [b, n]
+ sorted, indices = torch.sort(sim, dim=0, descending=True) # [b, n]
+ best_captions = []
+ for i in range(indices.shape[1]):
+ best_captions.append(cand_captions[int(indices[args.rank-1][i])][i])
+
+ # Write
+ output_file = exp_explorer + '/' + f"decode_beam{start_beam}-{end_beam}_pred"
+ with open(output_file, 'w') as f:
+ for i, caption in enumerate(best_captions):
+ audio_id = audio_ids[i]
+ line = f'{audio_id}\t{caption}'
+ f.write(line + '\n')
+ print(f"Clap refine finished, decode file saved at {output_file}")
\ No newline at end of file
diff --git a/examples/st_covost2/README.md b/examples/st_covost2/README.md
index 13d708f8..2864311e 100755
--- a/examples/st_covost2/README.md
+++ b/examples/st_covost2/README.md
@@ -1,5 +1,15 @@
# ST_covost2
+
+## Model Stracture
+
+
+
+## Multitask
+
+
+
+
## Download Model
We only train the q-former projector in this recipe.
Encoder | Projector | LLM
@@ -33,7 +43,7 @@ You can find the test jsonl in "test_st.jsonl"
{"audio": "/userhome/speech/data/common/4/en/clips/common_voice_en_699711.mp3", "prompt": "\"She'll be all right.\"<|zh|>", "gt": "\"She'll be all right.\"<|zh|>她会没事的。", "source": "covost_enenzh"}
```
## Train Stage
-Here, we have designed a four-step training process, where each training session uses the checkpoint obtained from the previous training session.
+Here, we have designed a three-step training process, where each training session uses the checkpoint obtained from the previous training session.
```
#In this step, we perform ASR pretraining to acquire speech recognition capabilities.
bash asr_pretrain.sh
@@ -41,10 +51,8 @@ bash asr_pretrain.sh
#In this phase, we conduct multimodal machine translation training to enhance the final performance.
bash mmt.sh
-#monolingual SRT training.
+#monolingual SRT training and multitask training.
bash srt.sh
-
-#multilingual multitask training.
bash zsrt.sh
```
@@ -53,7 +61,7 @@ bash zsrt.sh
You can try our pre-trained model.
```
-bash infer.sh
+bash infer_enzh.sh
```
## Citation
diff --git a/examples/st_covost2/image/framework.jpg b/examples/st_covost2/image/framework.jpg
new file mode 100644
index 00000000..d0f746e0
Binary files /dev/null and b/examples/st_covost2/image/framework.jpg differ
diff --git a/examples/st_covost2/image/prompt.png b/examples/st_covost2/image/prompt.png
new file mode 100644
index 00000000..48807499
Binary files /dev/null and b/examples/st_covost2/image/prompt.png differ
diff --git a/examples/st_covost2/scripts/infer.sh b/examples/st_covost2/scripts/infer_enzh.sh
similarity index 100%
rename from examples/st_covost2/scripts/infer.sh
rename to examples/st_covost2/scripts/infer_enzh.sh
diff --git a/src/slam_llm/models/CLAP/feature_extractor.py b/src/slam_llm/models/CLAP/feature_extractor.py
index 43febe0b..94681f2a 100644
--- a/src/slam_llm/models/CLAP/feature_extractor.py
+++ b/src/slam_llm/models/CLAP/feature_extractor.py
@@ -27,10 +27,10 @@ def __init__(self, audio_config):
fmin=audio_config["f_min"],
fmax=audio_config["f_max"],
ref=1.0,
- amin=1e-6,
+ amin=audio_config.get("amin", 1e-6),
top_db=None,
freeze_parameters=True)
-
+
def forward(self, input):
# input: waveform [bs, wav_length]
mel_feats = self.mel_trans(input)