Merged

sloth #159

40 changes: 28 additions & 12 deletions README.md
@@ -28,15 +28,18 @@ developers to train custom multimodal large language model (MLLM), focusing on
6. [Citation](#citation)

# News
- [Update Jun. 12, 2024] Recipes for [MaLa-ASR](examples/mala_asr_slidespeech/README.md) has been supported.
- [Update Oct. 12, 2024] Recipes for [SLAM-AAC](examples/slam_aac/README.md) have been supported.
- [Update Sep. 28, 2024] Recipes for [CoT-ST](examples/st_covost2/README.md) have been supported.
- [Update Sep. 25, 2024] Recipes for [DRCap](examples/drcap_zeroshot_aac/README.md) have been supported.
- [Update Jun. 12, 2024] Recipes for [MaLa-ASR](examples/mala_asr_slidespeech/README.md) have been supported.
- **[CALL FOR EXAMPLE]** We sincerely invite developers and researchers to develop new applications and conduct academic research based on SLAM-LLM, and to submit pull requests with your examples! We also welcome engineering PRs (such as improving and speeding up multi-node training).
- [Update May. 22, 2024] Please join [slack](https://join.slack.com/t/slam-llm/shared_invite/zt-2mc0pkhhs-5jjOi8Cwc8R1Xc8IQmykDA) or [WeChat group](./docs/Wechat.jpg). We will sync our updates and Q&A here.
- [Update May. 21, 2024] Recipes for [Spatial Audio Understanding](examples/seld_spatialsoundqa/README.md) has been supported.
- [Update May. 20, 2024] Recipes for [music caption (MC)](examples/mc_musiccaps/README.md) has been supported.
- [Update May. 8, 2024] Recipes for [visual speech recognition (VSR)](examples/vsr_LRS3/README.md) has been supported.
- [Update May. 4, 2024] Recipes for [zero-shot text-to-speech (TTS)](examples/vallex/README.md) has been supported.
- [Update Apr. 28, 2024] Recipes for [automated audio captioning (AAC)](examples/aac_audiocaps/README.md) has been supported.
- [Update Mar. 31, 2024] Recipes for [automatic speech recognition (ASR)](examples/asr_librispeech/README.md) has been supported.
- [Update May. 21, 2024] Recipes for [Spatial Audio Understanding](examples/seld_spatialsoundqa/README.md) have been supported.
- [Update May. 20, 2024] Recipes for [music caption (MC)](examples/mc_musiccaps/README.md) have been supported.
- [Update May. 8, 2024] Recipes for [visual speech recognition (VSR)](examples/vsr_LRS3/README.md) have been supported.
- [Update May. 4, 2024] Recipes for [zero-shot text-to-speech (TTS)](examples/vallex/README.md) have been supported.
- [Update Apr. 28, 2024] Recipes for [automated audio captioning (AAC)](examples/aac_audiocaps/README.md) have been supported.
- [Update Mar. 31, 2024] Recipes for [automatic speech recognition (ASR)](examples/asr_librispeech/README.md) have been supported.

# Installation
```bash
@@ -75,12 +78,25 @@ docker run -it --gpus all --name slam --shm-size=256g slam-llm:latest /bin/bash
## List of Recipes
We provide reference implementations of various LLM-based speech, audio, and music tasks:
- **Speech Task**
- [Automatic Speech Recognition (ASR)](examples/asr_librispeech/README.md)
- [Text-to-Speech (TTS)](examples/vallex/README.md)
- [Visual Speech Recognition (VSR)](examples/vsr_LRS3/README.md)
- Automatic Speech Recognition (ASR)
- [SLAM-ASR](examples/asr_librispeech/README.md)

- Contextual Automatic Speech Recognition (CASR)
- [MaLa-ASR](examples/mala_asr_slidespeech/README.md)

- [Visual Speech Recognition (VSR)](examples/vsr_LRS3/README.md)
- Speech-to-Text Translation (S2TT)
- [CoT-ST](examples/st_covost2/README.md)

- Text-to-Speech (TTS)
- [VALL-E-X](examples/vallex/README.md)

- **Audio Task**
- [Automated Audio Captioning (AAC)](examples/aac_audiocaps/README.md)
- [Spatial Audio Understanding](examples/seld_spatialsoundqa/README.md)
- [SLAM-AAC](examples/slam_aac/README.md)
- [DRCap](examples/drcap_zeroshot_aac/README.md)
- Spatial Audio Understanding
- [BAT](examples/seld_spatialsoundqa/README.md)
- **Music Task**
- [Music Caption (MC)](examples/mc_musiccaps/README.md)

@@ -103,7 +119,7 @@ command-line (shell file) > Hydra configuration (yaml file) > dataclass configuration
- We thank the contributors for providing diverse recipes.

## Citation

SLAM-ASR:
```
@article{ma2024embarrassingly,
title={An Embarrassingly Simple Approach for LLM with Strong ASR Capacity},
6 changes: 5 additions & 1 deletion examples/aac_audiocaps/aac_config.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, List

from torch.distributed.fsdp import ShardingStrategy


@dataclass
class ModelConfig:
file: str = "examples/aac_audiocaps/model/slam_model_aac.py:model_factory"
@@ -114,7 +118,7 @@ class FSDPConfig:
mixed_precision: bool = True
use_fp16: bool = False
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
fsdp_activation_checkpointing: bool = True
fsdp_cpu_offload: bool = False
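The `str` → `ShardingStrategy` typing change above is repeated in each example config in this PR. A minimal sketch of the effect, assuming OmegaConf's usual enum handling for structured configs (string values are validated against member names and coerced to the enum, so typos fail fast instead of reaching FSDP):

```python
# Sketch only, not repo code: how an enum-typed field behaves under OmegaConf.
from dataclasses import dataclass

from omegaconf import OmegaConf
from torch.distributed.fsdp import ShardingStrategy


@dataclass
class FSDPConfig:
    # Annotated with the enum; the string default is validated by member name.
    sharding_strategy: ShardingStrategy = "NO_SHARD"


cfg = OmegaConf.structured(FSDPConfig)
print(cfg.sharding_strategy)           # ShardingStrategy.NO_SHARD (an enum, not a str)
cfg.sharding_strategy = "FULL_SHARD"   # coerced to ShardingStrategy.FULL_SHARD
# cfg.sharding_strategy = "NO_SHARDS"  # would raise a ValidationError
```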
6 changes: 5 additions & 1 deletion examples/asr_librispeech/asr_config.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, List

from torch.distributed.fsdp import ShardingStrategy


@dataclass
class ModelConfig:
file: str = "examples/asr_librispeech/model/slam_model_asr.py:model_factory"
@@ -108,7 +112,7 @@ class FSDPConfig:
mixed_precision: bool = True
use_fp16: bool = False
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
fsdp_activation_checkpointing: bool = True
fsdp_cpu_offload: bool = False
6 changes: 5 additions & 1 deletion examples/drcap_zeroshot_aac/drcap_config.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass, field
from typing import Optional, List

from torch.distributed.fsdp import ShardingStrategy


@dataclass
class ModelConfig:
file: str = "examples/drcap_zeroshot_aac/model/slam_model_drcap.py:model_factory"
@@ -113,7 +117,7 @@ class FSDPConfig:
mixed_precision: bool = True
use_fp16: bool = False
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
fsdp_activation_checkpointing: bool = True
fsdp_cpu_offload: bool = False
5 changes: 4 additions & 1 deletion examples/mala_asr_slidespeech/mala_asr_config.py
@@ -1,5 +1,8 @@
from dataclasses import dataclass, field
from typing import Optional, List
from torch.distributed.fsdp import ShardingStrategy


@dataclass
class ModelConfig:
file: str = "examples/mala_asr_slidespeech/model/slam_model_mala_asr.py:model_factory"
@@ -109,7 +112,7 @@ class FSDPConfig:
mixed_precision: bool = True
use_fp16: bool = False
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
fsdp_activation_checkpointing: bool = True
fsdp_cpu_offload: bool = False
5 changes: 4 additions & 1 deletion examples/mc_musiccaps/mir_config.py
@@ -1,5 +1,8 @@
from dataclasses import dataclass, field
from typing import Optional, List

from torch.distributed.fsdp import ShardingStrategy

@dataclass
class ModelConfig:
file: str = "examples/mc_musiccaps/model/slam_model_mir.py:model_factory"
@@ -112,7 +115,7 @@ class FSDPConfig:
mixed_precision: bool = True
use_fp16: bool = False
# sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
sharding_strategy: ShardingStrategy = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
fsdp_activation_checkpointing: bool = True
fsdp_cpu_offload: bool = False
57 changes: 44 additions & 13 deletions examples/seld_spatialsoundqa/README.md
@@ -1,39 +1,70 @@
# <img src="assets/bat.png" alt="SELD_SpatialSoundQA" width="25" height="25"> SELD_SpatialSoundQA

This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/zszheng147/Spatial-AST#citation)].
This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/X-LANCE/SLAM-LLM/tree/main/examples/seld_spatialsoundqa#citation)].

Checkout our [demo page](https://zhishengzheng.com/BAT/) and enjoy a QA game with spatial audio.

## Performance and checkpoints
Encoder | Projector | PEFT | LLM
|---|---|---|---|
[Spatial-AST](https://huggingface.co/zhisheng01/Bat/blob/main/spatial-ast.pth) | Q-Former | adapter |[llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b)
## Performance evaluation on **SpatialSoundQA**
We use [Spatial-AST](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth) as the audio encoder and [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) as the LLM backbone. We fine-tune the model by adding a Q-Former and LoRA. To calculate mAP, you can refer to [calculate_map.py](https://github.com/X-LANCE/SLAM-LLM/blob/main/examples/seld_spatialsoundqa/scripts/calculate_map.py).
<img src="assets/performance.png" alt="xxx">
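For reference, macro-averaged AP (mAP) over the label set can be computed as in the sketch below; this is a generic illustration with placeholder shapes, not the repo's `calculate_map.py`.

```python
# Generic mAP sketch (illustrative only; see scripts/calculate_map.py for the repo's version).
import numpy as np
from sklearn.metrics import average_precision_score

num_clips, num_classes = 100, 355                                 # placeholder sizes
y_true = np.random.randint(0, 2, size=(num_clips, num_classes))   # multi-hot labels
y_score = np.random.rand(num_clips, num_classes)                  # per-class model scores

# mAP = mean over classes of each class's average precision.
mAP = average_precision_score(y_true, y_score, average="macro")
print(f"mAP: {mAP:.4f}")
```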


## Checkpoints
Encoder | Projector | LLM |
|---|---|---|
[Spatial-AST](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth) | [Q-Former](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/BAT/model.pt) (~73.56M) | [llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b) |
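A minimal sketch for fetching these checkpoints programmatically with `huggingface_hub` (repo and file paths are taken from the links above; access to the llama-2 weights themselves still requires Meta's approval):

```python
# Sketch: download the encoder and projector checkpoints listed in the table above.
from huggingface_hub import hf_hub_download

encoder_ckpt = hf_hub_download(
    repo_id="zhisheng01/SpatialAudio",
    filename="SpatialAST/finetuned.pth",
    repo_type="dataset",
)
projector_ckpt = hf_hub_download(
    repo_id="zhisheng01/SpatialAudio",
    filename="BAT/model.pt",
    repo_type="dataset",
)
print(encoder_ckpt, projector_ckpt)
```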

## Demo (Spatial Audio Inference)
Try [`inference.ipynb`](https://github.com/X-LANCE/SLAM-LLM/blob/main/examples/seld_spatialsoundqa/inference.ipynb).


## Data preparation
You need to prepare the data jsonl in this format. Below is an example.
You can download the SpatialSoundQA dataset from [huggingface](https://huggingface.co/datasets/zhisheng01/SpatialSoundQA).
```
{"audio_id": "eval/audio/YI-HlrcP6Qg4", "reverb_id": "q9vSo1VnCiC/0.npy", "audio_id2": null, "reverb_id2": null, "question_id": 0, "question_type": "CLASSIFICATION", "question": "Enumerate the sound occurrences in the audio clip.", "answer": "accelerating, revving, vroom; car; vehicle"}
You can download the SpatialSoundQA dataset from [SpatialAudio](https://huggingface.co/datasets/zhisheng01/SpatialAudio).
```json
{
"audio_id": "eval/audio/YI-HlrcP6Qg4",
"reverb_id": "q9vSo1VnCiC/0.npy",
"audio_id2": null,
"reverb_id2": null,
"question_id": 0,
"question_type": "CLASSIFICATION",
"question": "Enumerate the sound occurrences in the audio clip.",
"answer": "accelerating, revving, vroom; car; vehicle"
}

...
{"audio_id": "eval/audio/YZX2fVPmUidA", "reverb_id": "q9vSo1VnCiC/32.npy", "audio_id2": "eval/audio/YjNjUU01quLs", "reverb_id2": "q9vSo1VnCiC/31.npy", "question_id": 58, "question_type": "MIXUP_NONBINARY_DISTANCE", "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", "answer": "2m"}

{
"audio_id": "eval/audio/YZX2fVPmUidA",
"reverb_id": "q9vSo1VnCiC/32.npy",
"audio_id2": "eval/audio/YjNjUU01quLs",
"reverb_id2": "q9vSo1VnCiC/31.npy",
"question_id": 58,
"question_type": "MIXUP_NONBINARY_DISTANCE",
"question": "How far away is the sound of the banjo from the sound of the whack, thwack?",
"answer": "2m"
}
```
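Entries in this format are consumed by the updated dataset loader in this PR from a `.json` file with a top-level `"data"` list. A minimal reading/filtering sketch (the file path is a placeholder for wherever you put the downloaded split):

```python
# Sketch: load SpatialSoundQA entries from the .json layout used by the updated loader.
import json

with open("SpatialSoundQA/eval.json") as f:   # placeholder path
    entries = json.load(f)["data"]            # top-level "data" list, one dict per QA pair

distance_qs = [e for e in entries if "DISTANCE" in e["question_type"]]
print(f"{len(entries)} entries total, {len(distance_qs)} distance questions")
print(distance_qs[0]["question"], "->", distance_qs[0]["answer"])
```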

## Train a new model
```bash
bash examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
cd examples/seld_spatialsoundqa/
bash scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
```

## Decoding with checkpoints
```bash
bash examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
cd examples/seld_spatialsoundqa/
bash scripts/decode_spatial-ast_qformer_llama_2_7b.sh
```


## TODO
- [x] Decode with checkpoints
- [x] Upload SpatialSoundQA dataset
- [ ] Upload pretrained checkpoints
- [ ] Update model performance
- [x] Upload pretrained checkpoints
- [x] Update model performance

## Citation
```
Binary file added examples/seld_spatialsoundqa/assets/74.npy
Binary file not shown.
Binary file added examples/seld_spatialsoundqa/assets/75.npy
Binary file not shown.
Binary file not shown.
Binary file not shown.
5 changes: 2 additions & 3 deletions examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py
@@ -37,9 +37,8 @@ def __init__(
split,
):
super().__init__()
dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.jsonl')
with open(dataset_path) as f:
self.data = [json.loads(line) for line in f.readlines()]
dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.json')
self.data = json.load(open(dataset_path))["data"]

self.anechoic_data_root = dataset_config['anechoic_data_root'] # which is AudioSet in this case
self.reverb_data_root = dataset_config['reverb_data_root']
25 changes: 7 additions & 18 deletions examples/seld_spatialsoundqa/finetune_seld.py
@@ -1,5 +1,6 @@
import hydra
import logging
from typing import Optional
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf

@@ -16,32 +17,20 @@ class RunConfig:
peft_config: PeftConfig = field(default_factory=PeftConfig)
debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
ckpt_path: str = field(
default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
ckpt_path: Optional[str] = field(
default=None, metadata={"help": "The path to projector checkpoint"}
)

@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
run_config = RunConfig()
cfg = OmegaConf.merge(run_config, cfg)
def to_plain_list(cfg_item):
if isinstance(cfg_item, ListConfig):
return OmegaConf.to_container(cfg_item, resolve=True)
elif isinstance(cfg_item, DictConfig):
return {k: to_plain_list(v) for k, v in cfg_item.items()}
else:
return cfg_item

# kwargs = to_plain_list(cfg)
kwargs = cfg
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
cfg.train_config.peft_config = cfg.peft_config

log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
logging.basicConfig(level=log_level)

if kwargs.get("debug", False):
import pdb;
pdb.set_trace()

train(kwargs)
train(cfg)


if __name__ == "__main__":
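The refactor above relies on a standard Hydra + OmegaConf pattern: dataclass defaults merged with whatever arrives from the command line or yaml, with the merged config handed straight to `train`. A stripped-down sketch (field names are illustrative, not the full `RunConfig`):

```python
# Minimal sketch of the entry-point pattern in finetune_seld.py (illustrative fields).
import logging
from dataclasses import dataclass
from typing import Optional

import hydra
from omegaconf import DictConfig, OmegaConf


@dataclass
class RunConfig:
    log_level: str = "INFO"
    ckpt_path: Optional[str] = None  # optional now, mirroring the change above


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig) -> None:
    # Dataclass defaults first; command-line / yaml values win on merge.
    cfg = OmegaConf.merge(OmegaConf.structured(RunConfig()), cfg)
    logging.basicConfig(level=getattr(logging, str(cfg.log_level).upper()))
    # train(cfg)  # pass the merged config to the training entry point


if __name__ == "__main__":
    main_hydra()
```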
786 changes: 786 additions & 0 deletions examples/seld_spatialsoundqa/inference.ipynb


9 changes: 2 additions & 7 deletions examples/seld_spatialsoundqa/inference_seld_batch.py
@@ -36,16 +36,11 @@ class RunConfig:
def main_hydra(cfg: DictConfig):
run_config = RunConfig()
cfg = OmegaConf.merge(run_config, cfg)
# kwargs = to_plain_list(cfg)
log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
cfg.train_config.peft_config = cfg.peft_config

log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
logging.basicConfig(level=log_level)

if cfg.get("debug", False):
import pdb

pdb.set_trace()

inference(cfg)


25 changes: 2 additions & 23 deletions examples/seld_spatialsoundqa/model/slam_model_seld.py
@@ -74,26 +74,5 @@ def __init__(
tokenizer,
train_config,
model_config,
**kwargs,
)

@torch.no_grad()
def inference(
self,
wav_path=None,
reverb_path=None,
prompt=None,
generation_config=None,
logits_processor=None,
stopping_criteria=None,
prefix_allowed_tokens_fn=None,
synced_gpus=None,
assistant_model=None,
streamer=None,
negative_prompt_ids=None,
negative_prompt_attention_mask=None,
**kwargs,
):
#!TODO:
# inference for SELD model
pass
**kwargs
)