7 changes: 5 additions & 2 deletions examples/s2s/README.md
@@ -41,9 +41,12 @@ ds = load_dataset("DATASET_NAME")
### JSONL
We also support the JSONL format for its concise structure. Below is an example:
```jsonl
{"key": "1", "source_wav": "/xxx/1.wav", "source_text": "Can you recommend some Chinese food for me?", "target_wav": "/xxx/1.wav", "target_text": "Sure! I recommend trying dumplings, Peking duck, and mapo tofu for a mix of flavors and textures in Chinese cuisine. These dishes offer a good balance of savory, spicy, and crispy elements."}
{"key": "1", "source_wav": "/xxx/1.wav", "source_text": "Can you recommend some Chinese food for me?", "target_token": [742, 383, 455, 619, 180], "target_text": "Sure! I recommend trying dumplings, Peking duck, and mapo tofu for a mix of flavors and textures in Chinese cuisine. These dishes offer a good balance of savory, spicy, and crispy elements."}
```

+🔔**Update**:
+We now use the `target_token` field in place of `target_wav`. When using your own data, you need to generate the corresponding audio response tokens yourself (e.g., the [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) tokens used in SLAM-Omni).

## Checkpoints
We reproduced the single-stage fine-tuning results of SLAM-Omni with a group size of **3**. The following checkpoints are available for download:
- [Single-Round Dialogue (English)](https://drive.google.com/drive/folders/1ZmM1h5ZTvS-piuN-msmctmZdi51GWLAu?usp=sharing): Trained on VoiceAssistant-400K.
@@ -144,4 +147,4 @@ Mini-Omni:


## License
-Our code is released under MIT License. The Chinese dialogue model is licensed under GPL-3.0 due to its use of Belle data and is intended for research purposes only.
+Our code is released under MIT License. The Chinese dialogue model is licensed under GPL-3.0 due to its use of Belle data and is intended for research purposes only.
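
A minimal sketch of how a `target_token` manifest line could be produced for your own data. The `extract_speech_tokens` helper and the answer-wav path are placeholders; substitute whatever speech tokenizer you actually use (e.g., CosyVoice for SLAM-Omni):

```python
import json

# Placeholder: plug in your own speech tokenizer here
# (SLAM-Omni uses CosyVoice tokens; that API is not shown in this PR).
def extract_speech_tokens(wav_path):
    raise NotImplementedError("generate audio response tokens with your codec")

entry = {
    "key": "1",
    "source_wav": "/xxx/1.wav",
    "source_text": "Can you recommend some Chinese food for me?",
    # audio response tokens replace the old target_wav field
    "target_token": extract_speech_tokens("/xxx/1_answer.wav"),
    "target_text": "Sure! I recommend trying dumplings, Peking duck, and mapo tofu.",
}

with open("train.jsonl", "a", encoding="utf-8") as fout:
    fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
```
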
10 changes: 10 additions & 0 deletions examples/s2s/demo/demo_data/jsonl_demo-en.jsonl

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions examples/s2s/demo/demo_data/jsonl_demo-zh.jsonl

Large diffs are not rendered by default.

6 changes: 0 additions & 6 deletions examples/s2s/demo/demo_data/jsonl_demo.jsonl

This file was deleted.

2 changes: 1 addition & 1 deletion examples/s2s/s2s_config.py
@@ -189,7 +189,7 @@ class DataConfig:
"help": "whether input is normalized, used for models such as wavlm"
})
seed: int = 42
manifest_format: str = field(default="datasets", metadata={ "help": "alternative: jsonl" })
manifest_format: str = field(default="parquet", metadata={ "help": "alternative: jsonl" })
split_size: float = 0.1

vocab_config: VocabConfig = field(default_factory=VocabConfig)
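
For context, a minimal sketch (abbreviated to the fields touched here) of how the new default interacts with the per-script Hydra overrides in the diffs below:

```python
from dataclasses import dataclass, field

@dataclass
class DataConfig:
    # parquet (Hugging Face datasets) is now the default manifest format
    manifest_format: str = field(default="parquet",
                                 metadata={"help": "alternative: jsonl"})
    seed: int = 42
    split_size: float = 0.1

# Each script now chooses the format explicitly and forwards it via Hydra:
#   ++dataset_config.manifest_format=$manifest_format
```
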
5 changes: 3 additions & 2 deletions examples/s2s/scripts/finetune/finetune_s2s.sh
@@ -32,9 +32,10 @@ num_latency_tokens=0 # number of delay tokens (in front of the ge
do_layershift=false # if false, tokens in each layer use the same codebook; otherwise, each layer uses a different codebook

# dataset settings
+manifest_format=parquet # parquet or jsonl
train_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
val_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
-load_from_cache_file=true # set to true if you have already generated the cache file, otherwise set to false
+load_from_cache_file=true # set to true if you have already generated the cache file, otherwise set to false

# training settings
batch_size_training=6
@@ -89,7 +90,7 @@ hydra.run.dir=$output_dir \
++dataset_config.input_type=mel \
++dataset_config.mel_size=$mel_size \
++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
++dataset_config.split_size=$split_size \
++dataset_config.load_from_cache_file=$load_from_cache_file \
++dataset_config.task_type=$task_type \
5 changes: 3 additions & 2 deletions examples/s2s/scripts/finetune/finetune_s2s_group.sh
@@ -32,9 +32,10 @@ num_latency_tokens=0 # number of delay tokens (in front of the ge
do_layershift=false # if false, tokens in each layer use the same codebook; otherwise, each layer uses a different codebook

# dataset settings
+manifest_format=parquet # parquet or jsonl
train_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
val_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
-load_from_cache_file=true # set to true if you have already generated the cache file, otherwise set to false
+load_from_cache_file=true # set to true if you have already generated the cache file, otherwise set to false

# training settings
batch_size_training=6
@@ -96,7 +97,7 @@ hydra.run.dir=$output_dir \
++dataset_config.input_type=mel \
++dataset_config.mel_size=$mel_size \
++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
++dataset_config.split_size=$split_size \
++dataset_config.load_from_cache_file=$load_from_cache_file \
++dataset_config.task_type=$task_type \
3 changes: 2 additions & 1 deletion examples/s2s/scripts/finetune/mini-omni/finetune_s2s.sh
@@ -20,6 +20,7 @@ mel_size=80 # 80 128 ( only whisper-large-v3 supports 128 )
llm_dim=896 # 896 1536 2048 3584 -> 0.5B 1.5B 3B 7B

# dataset settings
+manifest_format=parquet # parquet or jsonl
train_data_path="/valleblob/v-wenxichen/data/s2s/VoiceAssistant-400K"
val_data_path="/valleblob/v-wenxichen/data/s2s/VoiceAssistant-400K"
load_from_cache_file=false # set to true if you have already generated the cache file, otherwise set to false
@@ -75,7 +76,7 @@ hydra.run.dir=$output_dir \
++dataset_config.input_type=mel \
++dataset_config.mel_size=$mel_size \
++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
++dataset_config.split_size=$split_size \
++dataset_config.load_from_cache_file=$load_from_cache_file \
++dataset_config.task_type=$task_type \
2 changes: 1 addition & 1 deletion examples/s2s/scripts/inference/inference_s2s_batch.sh
@@ -39,7 +39,7 @@ ckpt_path=/valleblob/v-wenxichen/exp/s2s/s2s_train_v3-gpu16-btz3-lr5e-4-fp16-epo
# val_data_path=/home/v-wenxichen/SLAM-LLM/examples/s2s/demo/data/${split}.jsonl

# huggingface dataset
-manifest_format=datasets
+manifest_format=parquet
val_data_path="/valleblob/v-wenxichen/data/s2s/VoiceAssistant-400K-v1/test"
load_from_cache_file=false
dataset_sample_seed=777
@@ -29,7 +29,7 @@ split=test
# val_data_path=/home/v-wenxichen/SLAM-LLM/examples/s2s/demo/data/${split}.jsonl

# huggingface dataset
-manifest_format=datasets
+manifest_format=parquet
val_data_path="gpt-omni/VoiceAssistant-400K"
load_from_cache_file=true
dataset_sample_seed=777
@@ -28,7 +28,7 @@ split=test
# val_data_path=/home/v-wenxichen/SLAM-LLM/examples/s2s/demo/data/${split}.jsonl

# huggingface dataset
-manifest_format=datasets
+manifest_format=parquet
val_data_path="gpt-omni/VoiceAssistant-400K"
load_from_cache_file=true
dataset_sample_seed=1234
2 changes: 1 addition & 1 deletion examples/s2s/scripts/inference/mini-omni/inference_tts.sh
@@ -25,7 +25,7 @@ split=test
# val_data_path=/home/v-wenxichen/SLAM-LLM/examples/s2s/demo/data/${split}.jsonl

# huggingface dataset
-manifest_format=datasets
+manifest_format=parquet
val_data_path="gpt-omni/VoiceAssistant-400K"
load_from_cache_file=true
dataset_sample_seed=1234
3 changes: 2 additions & 1 deletion examples/s2s/scripts/pretrain/pretrain_asr.sh
Expand Up @@ -32,6 +32,7 @@ num_latency_tokens=0 # number of delay tokens (in front of the ge
do_layershift=false # if false, tokens in each layer use the same codebook; otherwise, each layer uses a different codebook

# dataset settings
+manifest_format=parquet # parquet or jsonl
train_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
val_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
load_from_cache_file=true # set to true if you have already generated the cache file, otherwise set to false
@@ -96,7 +97,7 @@ hydra.run.dir=$output_dir \
++dataset_config.input_type=mel \
++dataset_config.mel_size=$mel_size \
++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
++dataset_config.split_size=$split_size \
++dataset_config.load_from_cache_file=$load_from_cache_file \
++dataset_config.task_type=$task_type \
3 changes: 2 additions & 1 deletion examples/s2s/scripts/pretrain/pretrain_asr_debug.sh
Expand Up @@ -32,6 +32,7 @@ num_latency_tokens=0 # number of latency tokens (in front of the
do_layershift=false # if false, tokens in each layer use the same codebook; otherwise, each layer uses a different codebook

# dataset settings
+manifest_format=parquet # parquet or jsonl
train_data_path="/valleblob/v-wenxichen/data/s2s/parquet_data_test/en"
val_data_path="/valleblob/v-wenxichen/data/s2s/parquet_data_test/en"
load_from_cache_file=true # set to true if you have already generated the cache file, otherwise set to false
@@ -97,7 +98,7 @@ hydra.run.dir=$output_dir \
++dataset_config.input_type=mel \
++dataset_config.mel_size=$mel_size \
++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
++dataset_config.split_size=$split_size \
++dataset_config.load_from_cache_file=$load_from_cache_file \
++dataset_config.task_type=$task_type \
3 changes: 2 additions & 1 deletion examples/s2s/scripts/pretrain/pretrain_tts.sh
@@ -28,6 +28,7 @@ num_latency_tokens=0 # number of delay tokens (in front of the ge
do_layershift=false # if false, tokens in each layer use the same codebook; otherwise, each layer uses a different codebook

# dataset settings
+manifest_format=parquet # parquet or jsonl
train_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
val_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
load_from_cache_file=true # set to true if you have already generated the cache file, otherwise set to false
@@ -90,7 +91,7 @@ hydra.run.dir=$output_dir \
++dataset_config.val_data_path=$val_data_path \
++dataset_config.input_type=mel \
++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
++dataset_config.split_size=$split_size \
++dataset_config.load_from_cache_file=$load_from_cache_file \
++dataset_config.task_type=$task_type \
3 changes: 2 additions & 1 deletion examples/s2s/scripts/pretrain/pretrain_tts_debug.sh
@@ -26,6 +26,7 @@ num_latency_tokens=0 # number of latency tokens (in front of the
do_layershift=false # if false, tokens in each layer use the same codebook; otherwise, each layer uses a different codebook

# dataset settings
+manifest_format=parquet # parquet or jsonl
train_data_path="/valleblob/v-wenxichen/data/debug/1"
val_data_path="/valleblob/v-wenxichen/data/debug/1"
load_from_cache_file=true # set to true if you have already generated the cache file, otherwise set to false
@@ -82,7 +83,7 @@ hydra.run.dir=$output_dir \
++dataset_config.val_data_path=$val_data_path \
++dataset_config.input_type=mel \
++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
++dataset_config.split_size=$split_size \
++dataset_config.load_from_cache_file=$load_from_cache_file \
++dataset_config.task_type=$task_type \
21 changes: 11 additions & 10 deletions examples/s2s/speech_dataset_s2s.py
@@ -31,11 +31,11 @@ def __init__(self,
self.inference_mode = dataset_config.get("inference_mode", False)
self.normalize = dataset_config.get("normalize", False)
self.input_type = dataset_config.get("input_type", None)
self.manifest_format = dataset_config.get("manifest_format", "datasets")
self.manifest_format = dataset_config.get("manifest_format", "parquet")
self.seed = dataset_config.get("seed", 42)
self.split_size = dataset_config.get("split_size", 0.1)
assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]"
assert self.manifest_format in ["datasets", "jsonl"], "manifest_format must be one of [datasets, jsonl]"
assert self.manifest_format in ["parquet", "jsonl"], "manifest_format must be one of [parquet, jsonl]"

# vocab config
self.vocab_config = dataset_config.get("vocab_config", None)
@@ -88,7 +88,7 @@ def __init__(self,
self.data_list = []

# TODO: design a better way to load data
if self.manifest_format == "datasets":
if self.manifest_format == "parquet":
from datasets import load_dataset, load_from_disk
if dataset_config.load_from_cache_file:
ds = load_dataset(dataset_config.train_data_path) # load from Hugging Face datasets
@@ -99,7 +99,7 @@ def __init__(self,
self.data_list = train_val_split['train']
else:
self.data_list = train_val_split['test']
-else:
+elif self.manifest_format == "jsonl":
if split == "train":
with open(dataset_config.train_data_path, encoding='utf-8') as fin:
for line in fin:
@@ -110,6 +110,8 @@ def __init__(self,
for line in fin:
data_dict = json.loads(line.strip())
self.data_list.append(data_dict)
+else:
+raise ValueError("manifest_format must be one of [parquet, jsonl]")

def get_source_len(self, data_dict):
return data_dict["source_len"]
@@ -120,16 +122,15 @@ def get_target_len(self, data_dict):
def __len__(self):
return len(self.data_list)

-# NOTE: here datasets format is just for VoiceAssistant-400K dataset, and we only support the whisper format
def extract_audio_feature(self, audio_path):
# audio path is a dictionary, resample the audio to 16kHz
if self.manifest_format == "datasets" and isinstance(audio_path, dict):
if self.manifest_format == "parquet" and isinstance(audio_path, dict):
audio_raw = audio_path['array']
audio_raw_sr = audio_path['sampling_rate']
if not isinstance(audio_raw, np.ndarray):
audio_raw = np.array(audio_raw)
audio_raw = librosa.resample(audio_raw, orig_sr=audio_raw_sr, target_sr=16000).astype(np.float32)
elif self.manifest_format == "datasets" and (isinstance(audio_path, str) or isinstance(audio_path, list)):
elif (self.manifest_format == "parquet" and (isinstance(audio_path, str) or isinstance(audio_path, list))) or (self.manifest_format == "jsonl" and isinstance(audio_path, list)):
if self.code_type == "SNAC":
audio_res, audio_length = get_snac_answer_token(audio_path)
elif self.code_type == "CosyVoice":
@@ -233,7 +234,7 @@ def __getitem__(self, index):
audio_length = 0
target_audio_length = 0

if self.manifest_format == "datasets":
if self.manifest_format == "parquet":
source_audio = data_dict.get("question_audio", None)
if self.code_type == "SNAC":
target_audio = data_dict.get("answer_snac", None)
@@ -245,12 +246,12 @@ def __getitem__(self, index):
key = source_audio['path']
elif self.manifest_format == "jsonl":
source_audio = data_dict.get("source_wav", None)
target_audio = data_dict.get("target_wav", None)
target_audio = data_dict.get("target_token", None)
source_text = data_dict.get("source_text", None)
target_text = data_dict.get("target_text", None)
key = data_dict.get("key", None)
else:
raise ValueError("manifest_format must be one of [datasets, jsonl]")
raise ValueError("manifest_format must be one of [parquet, jsonl]")

if task_type == "s2s" or task_type == "asr":
audio_mel, audio_length = self.extract_audio_feature(source_audio)
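
Taken together, the dataset changes reduce to the following branching. This is a condensed sketch; the train/validation split and the `load_from_disk` cache path are omitted:

```python
import json
from datasets import load_dataset

def load_manifest(manifest_format, path):
    # Condensed restatement of the loading branch in speech_dataset_s2s.py
    if manifest_format == "parquet":
        # a Hub dataset ID or local parquet data, loaded via HF datasets
        return load_dataset(path)
    elif manifest_format == "jsonl":
        with open(path, encoding="utf-8") as fin:
            return [json.loads(line.strip()) for line in fin]
    else:
        raise ValueError("manifest_format must be one of [parquet, jsonl]")
```

Under the jsonl branch, each entry is expected to carry `key`, `source_wav`, `source_text`, `target_token`, and `target_text`, matching the fields read in `__getitem__`.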