Skip to content

Commit 38d8c66

Browse files
authored
Merge pull request #150 from X-LANCE/seld
SELD: fix typo; add ckpt link; add inference code; update model performance.
2 parents db55bea + 8c05584 commit 38d8c66

16 files changed

+957
-106
lines changed

examples/seld_spatialsoundqa/README.md

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,70 @@
11
# <img src="assets/bat.png" alt="SELD_SpatialSoundQA" width="25" height="25"> SELD_SpatialSoundQA
22

3-
This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/zszheng147/Spatial-AST#citation)].
3+
This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/X-LANCE/SLAM-LLM/tree/main/examples/seld_spatialsoundqa#citation)].
44

55
Checkout our [demo page](https://zhishengzheng.com/BAT/) and enjoy a QA game with spatial audio.
66

7-
## Performance and checkpoints
8-
Encoder | Projector | PEFT | LLM
9-
|---|---|---|---|
10-
[Spatial-AST](https://huggingface.co/zhisheng01/Bat/blob/main/spatial-ast.pth) | Q-Former | adapter |[llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b)
7+
## Performance evaluation on **SpatialSoundQA**
8+
We use [Spatial-AST](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth) as audio encoder, [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) as LLM backbone. We finetune the model by adding Q-Former and LoRA. To calculate MAP, you can refer to [calculate_map.py](https://github.com/X-LANCE/SLAM-LLM/blob/main/examples/seld_spatialsoundqa/scripts/calculate_map.py)
9+
<img src="assets/performance.png" alt="xxx">
10+
11+
12+
## Checkpoints
13+
Encoder | Projector | LLM |
14+
|---|---|---|
15+
[Spatial-AST](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/SpatialAST/finetuned.pth) | [Q-former](https://huggingface.co/datasets/zhisheng01/SpatialAudio/blob/main/BAT/model.pt)(~73.56M) | [llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b) |
16+
17+
## Demo (Spatial Audio Inference)
18+
Try [`inference.ipynb`](https://github.com/X-LANCE/SLAM-LLM/blob/main/examples/seld_spatialsoundqa/inference.ipynb).
19+
1120

1221
## Data preparation
1322
You need to prepare the data jsonl in this format. Below is an example.
14-
You can download the SpatialSoundQA dataset from [huggingface](https://huggingface.co/datasets/zhisheng01/SpatialSoundQA).
15-
```
16-
{"audio_id": "eval/audio/YI-HlrcP6Qg4", "reverb_id": "q9vSo1VnCiC/0.npy", "audio_id2": null, "reverb_id2": null, "question_id": 0, "question_type": "CLASSIFICATION", "question": "Enumerate the sound occurrences in the audio clip.", "answer": "accelerating, revving, vroom; car; vehicle"}
23+
You can download the SpatialSoundQA dataset from [SpatialAudio](https://huggingface.co/datasets/zhisheng01/SpatialAudio).
24+
```json
25+
{
26+
"audio_id": "eval/audio/YI-HlrcP6Qg4",
27+
"reverb_id": "q9vSo1VnCiC/0.npy",
28+
"audio_id2": null,
29+
"reverb_id2": null,
30+
"question_id": 0,
31+
"question_type": "CLASSIFICATION",
32+
"question": "Enumerate the sound occurrences in the audio clip.",
33+
"answer": "accelerating, revving, vroom; car; vehicle"
34+
}
35+
1736
...
18-
{"audio_id": "eval/audio/YZX2fVPmUidA", "reverb_id": "q9vSo1VnCiC/32.npy", "audio_id2": "eval/audio/YjNjUU01quLs", "reverb_id2": "q9vSo1VnCiC/31.npy", "question_id": 58, "question_type": "MIXUP_NONBINARY_DISTANCE", "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", "answer": "2m"}
37+
38+
{
39+
"audio_id": "eval/audio/YZX2fVPmUidA",
40+
"reverb_id": "q9vSo1VnCiC/32.npy",
41+
"audio_id2": "eval/audio/YjNjUU01quLs",
42+
"reverb_id2": "q9vSo1VnCiC/31.npy",
43+
"question_id": 58,
44+
"question_type": "MIXUP_NONBINARY_DISTANCE",
45+
"question": "How far away is the sound of the banjo from the sound of the whack, thwack?",
46+
"answer": "2m"
47+
}
1948
```
2049

2150
## Train a new model
2251
```bash
23-
bash examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
52+
cd examples/seld_spatialsoundqa/
53+
bash scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
2454
```
2555

2656
## Decoding with checkpoints
2757
```bash
28-
bash examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
58+
cd examples/seld_spatialsoundqa/
59+
bash scripts/decode_spatial-ast_qformer_llama_2_7b.sh
2960
```
3061

3162

3263
## TODO
3364
- [x] Decode with checkpoints
3465
- [x] Upload SpatialSoundQA dataset
35-
- [ ] Upload pretrained checkpoints
36-
- [ ] Update model performance
66+
- [x] Upload pretrained checkpoints
67+
- [x] Update model performance
3768

3869
## Citation
3970
```
656 KB
Binary file not shown.
592 KB
Binary file not shown.
625 KB
Binary file not shown.
625 KB
Binary file not shown.
509 KB
Loading

examples/seld_spatialsoundqa/dataset/spatial_audio_dataset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ def __init__(
3737
split,
3838
):
3939
super().__init__()
40-
dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.jsonl')
41-
with open(dataset_path) as f:
42-
self.data = [json.loads(line) for line in f.readlines()]
40+
dataset_path = os.path.join(dataset_config['qa_data_root'], dataset_config['stage'], split + '.json')
41+
self.data = json.load(open(dataset_path))["data"]
4342

4443
self.anechoic_data_root = dataset_config['anechoic_data_root'] # which is AudioSet in this case
4544
self.reverb_data_root = dataset_config['reverb_data_root']

examples/seld_spatialsoundqa/finetune_seld.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import hydra
22
import logging
3+
from typing import Optional
34
from dataclasses import dataclass, field
45
from omegaconf import DictConfig, ListConfig, OmegaConf
56

@@ -16,32 +17,20 @@ class RunConfig:
1617
peft_config: PeftConfig = field(default_factory=PeftConfig)
1718
debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
1819
metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
19-
ckpt_path: str = field(
20-
default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
20+
ckpt_path: Optional[str] = field(
21+
default=None, metadata={"help": "The path to projector checkpoint"}
2122
)
2223

2324
@hydra.main(config_name=None, version_base=None)
2425
def main_hydra(cfg: DictConfig):
2526
run_config = RunConfig()
2627
cfg = OmegaConf.merge(run_config, cfg)
27-
def to_plain_list(cfg_item):
28-
if isinstance(cfg_item, ListConfig):
29-
return OmegaConf.to_container(cfg_item, resolve=True)
30-
elif isinstance(cfg_item, DictConfig):
31-
return {k: to_plain_list(v) for k, v in cfg_item.items()}
32-
else:
33-
return cfg_item
34-
35-
# kwargs = to_plain_list(cfg)
36-
kwargs = cfg
37-
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
28+
cfg.train_config.peft_config = cfg.peft_config
29+
30+
log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
3831
logging.basicConfig(level=log_level)
39-
40-
if kwargs.get("debug", False):
41-
import pdb;
42-
pdb.set_trace()
4332

44-
train(kwargs)
33+
train(cfg)
4534

4635

4736
if __name__ == "__main__":

examples/seld_spatialsoundqa/inference.ipynb

Lines changed: 786 additions & 0 deletions
Large diffs are not rendered by default.

examples/seld_spatialsoundqa/inference_seld_batch.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,11 @@ class RunConfig:
3636
def main_hydra(cfg: DictConfig):
3737
run_config = RunConfig()
3838
cfg = OmegaConf.merge(run_config, cfg)
39-
# kwargs = to_plain_list(cfg)
40-
log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
39+
cfg.train_config.peft_config = cfg.peft_config
4140

41+
log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
4242
logging.basicConfig(level=log_level)
4343

44-
if cfg.get("debug", False):
45-
import pdb
46-
47-
pdb.set_trace()
48-
4944
inference(cfg)
5045

5146

0 commit comments

Comments
 (0)