Commit 52fab27

Author: 蒄骰
Commit message: f
1 parent 86dc310 commit 52fab27

5 files changed: +66 -7 lines changed

examples/contextual_asr/README.md

Lines changed: 7 additions & 4 deletions
@@ -28,11 +28,16 @@ words, with the remainder classified as rare words. The biasing list generated f
 
 
 ## Decoding with checkpoints
-LLM-based Contextual ASR Inference script, with different biasing sizes and test sets.
+LLM-based ASR Inference script.
+```
+bash decode_wavlm_libri960_ft_char.sh
+```
+LLM-based Contextual ASR Inference script, with different biasing list sizes.
 ```
 bash decode_wavlm_libri960_ft_char_hotwords.sh
 ```
 
+
 ## Training the model
 LLM-based ASR Training script: using CTC fine-tuned WavLM as encoder and “Transcribe speech to text.” as prompt.
 ```
@@ -53,6 +58,4 @@ You can refer to the paper for more results.
 journal={Proc. SLT},
 year={2024}
 }
-```
-
-
+```

examples/contextual_asr/contextual_asr_config.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ class TrainConfig:
 @dataclass
 class DataConfig:
     dataset: str = "speech_dataset"
-    file: str = "examples/contextual_asr/dataset/hotwords_dataset.py:get_speech_dataset"
+    file: str = "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
     train_data_path: Optional[str] = None
     val_data_path: Optional[str] = None
     train_split: str = "train"

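For reference, the `file` field follows a `path/to/module.py:function` spec. A minimal sketch of how such a spec can be resolved to a callable with `importlib` (an illustration of the convention only, not necessarily SLAM-LLM's actual loader):

```python
import importlib.util

def resolve_dataset_factory(spec: str):
    """Resolve a "path/to/module.py:function" spec to a callable."""
    module_path, func_name = spec.split(":")
    # Load the module directly from its file path.
    module_spec = importlib.util.spec_from_file_location("dataset_module", module_path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)

# Example with the default from DataConfig:
get_speech_dataset = resolve_dataset_factory(
    "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
)
```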
examples/contextual_asr/scripts/finetune/finetune_wavlm_libri960_ft_char_hotwords.sh

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ hydra.run.dir=$output_dir \
 ++dataset_config.val_data_path=$val_data_path \
 ++dataset_config.input_type=raw \
 ++dataset_config.dataset=hotwords_dataset \
-++dataset_config.file=src/slam_llm/datasets/hotwords_dataset.py:get_speech_dataset \
+++dataset_config.file=examples/contextual_asr/dataset/hotwords_dataset.py:get_speech_dataset \
 ++train_config.model_name=asr \
 ++train_config.num_epochs=5 \
 ++train_config.freeze_encoder=true \
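The `++key=value` flags are Hydra command-line overrides: `+` adds a new config entry, and `++` adds the entry or overrides it if it already exists. A rough sketch of the resulting merge using OmegaConf directly (illustrative; the training entry point's actual wiring is not shown in this diff):

```python
from dataclasses import dataclass
from typing import Optional
from omegaconf import OmegaConf

@dataclass
class DataConfig:
    dataset: str = "speech_dataset"
    file: str = "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
    train_data_path: Optional[str] = None

# Defaults come from the dataclass; command-line style overrides merge on top.
base = OmegaConf.structured(DataConfig)
overrides = OmegaConf.from_dotlist([
    "dataset=hotwords_dataset",
    "file=examples/contextual_asr/dataset/hotwords_dataset.py:get_speech_dataset",
])
cfg = OmegaConf.merge(base, overrides)
print(cfg.file)  # -> examples/contextual_asr/dataset/hotwords_dataset.py:get_speech_dataset
```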
examples/contextual_asr/scripts/infer/decode_wavlm_libri960_ft_char.sh

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+#!/bin/bash
+#export PYTHONPATH=/root/whisper:$PYTHONPATH
+export CUDA_VISIBLE_DEVICES=2
+export TOKENIZERS_PARALLELISM=false
+# export CUDA_LAUNCH_BLOCKING=1
+
+run_dir=/nfs/yangguanrou.ygr/codes/SLAM-LLM
+cd $run_dir
+code_dir=examples/contextual_asr
+
+speech_encoder_path=/nfs/yangguanrou.ygr/ckpts/wavlm_large_ft_libri960_char/wavlm_large_ft_libri960_char.pt
+llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5
+
+output_dir=/nfs/yangguanrou.ygr/experiments_librispeech/vicuna-7b-v1.5-WavLM-Large-libri960-ft-char-20240521
+ckpt_path=$output_dir/asr_epoch_3_step_9780
+N=100
+for ref_split in test_clean test_other; do
+    split=librispeech_${ref_split}
+    val_data_path=/nfs/maziyang.mzy/data/librispeech/${split}.jsonl
+    decode_log=$ckpt_path/decode_${split}_beam4_debug
+    python $code_dir/inference_contextual_asr_batch.py \
+        --config-path "conf" \
+        --config-name "prompt.yaml" \
+        hydra.run.dir=$ckpt_path \
+        ++model_config.llm_name="vicuna-7b-v1.5" \
+        ++model_config.llm_path=$llm_path \
+        ++model_config.llm_dim=4096 \
+        ++model_config.encoder_name=wavlm \
+        ++model_config.normalize=true \
+        ++dataset_config.normalize=true \
+        ++model_config.encoder_projector_ds_rate=5 \
+        ++model_config.encoder_path=$speech_encoder_path \
+        ++model_config.encoder_dim=1024 \
+        ++model_config.encoder_projector=cov1d-linear \
+        ++dataset_config.dataset=speech_dataset \
+        ++dataset_config.val_data_path=$val_data_path \
+        ++dataset_config.input_type=raw \
+        ++dataset_config.inference_mode=true \
+        ++train_config.model_name=asr \
+        ++train_config.freeze_encoder=true \
+        ++train_config.freeze_llm=true \
+        ++train_config.batching_strategy=custom \
+        ++train_config.num_epochs=1 \
+        ++train_config.val_batch_size=1 \
+        ++train_config.num_workers_dataloader=0 \
+        ++train_config.output_dir=$output_dir \
+        ++decode_log=$decode_log \
+        ++ckpt_path=$ckpt_path/model.pt && \
+    python src/slam_llm/utils/whisper_tn.py ${decode_log}_gt ${decode_log}_gt.proc && \
+    python src/slam_llm/utils/whisper_tn.py ${decode_log}_pred ${decode_log}_pred.proc && \
+    python src/slam_llm/utils/compute_wer.py ${decode_log}_gt.proc ${decode_log}_pred.proc ${decode_log}.proc.wer && \
+    python /nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_score.py \
+        --refs /nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/ref_score/${ref_split}.biasing_${N}.tsv \
+        --hyps ${decode_log}_pred.proc \
+        --output_file ${decode_log}.proc.wer
+done

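The loop above normalizes references and hypotheses with `whisper_tn.py` and then scores them with `compute_wer.py`; those utilities' implementations are not part of this diff. As a rough illustration of the metric being computed, here is a minimal word error rate function (a hypothetical helper: word-level edit distance divided by reference length):

```python
# Illustrative WER computation, not the repository's compute_wer.py.
def wer(ref: str, hyp: str) -> float:
    r, h = ref.split(), hyp.split()
    # d[i][j] = minimum edits to turn the first i reference words
    # into the first j hypothesis words.
    d = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        d[i][0] = i
    for j in range(len(h) + 1):
        d[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            sub = d[i - 1][j - 1] + (r[i - 1] != h[j - 1])
            d[i][j] = min(sub, d[i - 1][j] + 1, d[i][j - 1] + 1)
    return d[len(r)][len(h)] / max(len(r), 1)

print(wer("the cat sat", "the cat sad"))  # 1 substitution / 3 words ≈ 0.33
```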
examples/contextual_asr/scripts/infer/decode_wavlm_libri960_ft_char_hotwords.sh

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ export TOKENIZERS_PARALLELISM=false
 export CUDA_LAUNCH_BLOCKING=1
 export HYDRA_FULL_ERROR=1
 
-run_dir=/root/SLAM-LLM
+run_dir=/nfs/yangguanrou.ygr/codes/SLAM-LLM
 cd $run_dir
 code_dir=examples/contextual_asr
 
