diff --git a/examples/aispeech_asr/README.md b/examples/aispeech_asr/README.md
new file mode 100644
index 00000000..4568c335
--- /dev/null
+++ b/examples/aispeech_asr/README.md
@@ -0,0 +1,152 @@
+# AISPEECH_ASR
+
+## Overview
+
+This example is designed for large-scale industrial data training, suitable for datasets on the order of 100,000 hours. Its main features include:
+- **Support for multi-task training**: Designed to support tasks such as ASR and ST through a unified data format.
+- **Dynamic prompt selection**: Supports random selection from multiple prompts.
+- **Iterative dataset**: Uses an iterative dataset format to reduce startup time for large datasets.
+- **DeepSpeed training**: Supports DeepSpeed training to significantly reduce memory usage.
+- **Multi-machine multi-GPU inference**: Supports distributed inference across multiple machines and GPUs to reduce evaluation time.
+- **Dynamic frame batching**: Batches utterances by total frame count rather than using a fixed batch size, significantly reducing training and evaluation time (cuts training time by roughly 3/4 on 100,000 hours of data).
+
+This example is modified from `mala_asr_slidespeech`.
+
+## Model Architecture
+
+The model architecture can be freely selected within the scope supported by SLAM-LLM. Below are some recommended configurations:
+- **Encoder**: WavLM, Whisper
+- **Projector**: Linear
+- **LLM**: Qwen2.5-7B-Instruct, Vicuna1.5-7B
+
+## Data Preparation
+
+The following two files are required:
+- `multitask.jsonl`
+- `multiprompt.jsonl`
+
+### multitask.jsonl
+
+The format of this file is as follows, where `path` supports both ark format and wav files:
+```json
+{"key": "BAC009S0002W0122", "task": "ASR", "target": "而对楼市成交抑制作用最大的限购", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:17"}
+{"key": "BAC009S0002W0123", "task": "ASR", "target": "也成为地方政府的眼中钉", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:191758"}
+{"key": "BAC009S0002W0124", "task": "ASR", "target": "自六月底呼和浩特市率先宣布取消限购后", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:315339"}
+{"key": "BAC009S0764W0238", "task": "hotword", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/test/data/data_wav.1.ark:17343733", "target": "形成一批具有国际竞争力的中国企业", "hotword": "中国"}
+```
+
+### multiprompt.jsonl
+
+The format of this file is as follows:
+```json
+{"task": "ASR", "prompt": "Transcribe speech to text."}
+{"task": "ASR", "prompt": "请识别语音."}
+{"task": "ZH2EN", "prompt": "请识别语音并翻译为英文:"}
+{"task": "EN2ZH", "prompt": "请识别语音并翻译为中文:"}
+{"task": "prevtext", "prompt": "Transcribe speech to text, below are the previous historical transcription texts:{}."}
+{"task": "hotword", "prompt": "Transcribe speech to text, follow words may occur:{}."}
+```
+
+### Notes
+- If multiple prompts are provided for a task, one of them is selected at random for each sample.
+- For additional information (e.g., hotwords), include a field named after the task in `multitask.jsonl` and use `{}` in the prompt to inject this information (see the sketch after these notes). Additionally, update `append_info_tasks` in `aispeech_asr_config.py`:
+  ```python
+  append_info_tasks: List = field(default_factory=lambda: ["hotword"])
+  ```
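+
+The snippet below is a minimal sketch of how the two files are combined at training time: a prompt is drawn at random for the sample's task, and for tasks listed in `append_info_tasks` the extra field is injected into the `{}` placeholder (it mirrors the prompt handling in `MultiTaskDataset` from `speech_dataset_large.py`; the file names are the ones described above):
+```python
+import json
+import random
+
+# Load all prompts, grouped by task.
+multitask_prompt_list = {}
+with open("multiprompt.jsonl") as f:
+    for line in f:
+        item = json.loads(line.strip())
+        multitask_prompt_list.setdefault(item["task"], []).append(item["prompt"])
+
+append_info_tasks = ["hotword"]  # must match append_info_tasks in aispeech_asr_config.py
+
+# One entry from multitask.jsonl (path omitted here).
+sample = {"key": "BAC009S0764W0238", "task": "hotword",
+          "target": "形成一批具有国际竞争力的中国企业", "hotword": "中国"}
+
+prompt = random.choice(multitask_prompt_list[sample["task"]])  # dynamic prompt selection
+if sample["task"] in append_info_tasks:
+    prompt = prompt.format(sample[sample["task"]])  # inject the hotword into {}
+print(prompt)  # Transcribe speech to text, follow words may occur:中国.
+```
+In the real dataset the selected prompt is also wrapped with `prompt_style` (see the training scripts below) before being tokenized together with the audio features.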
+
+## Training a New Model
+
+### Script Preparation
+
+Prepare and modify the following content in `scripts/finetune_deepspeed.sh` or `scripts/finetune_torchrun.sh` (DeepSpeed is recommended):
+```bash
+run_dir= # Directory to save the model
+train_scp_file_path= # Path to training data
+dev_scp_file_path= # Path to validation data
+train_max_frame_length=1500 # Maximum frame length for training
+eval_max_frame_length=1000 # Maximum frame length for evaluation
+multitask_prompt_path= # Path to multitask.jsonl
+prompt_style="\{\}" # Prompt style, e.g., "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" or "USER: {}\n ASSISTANT:"
+projector=linear # Type of projector
+encoder_name=whisper # Name of the encoder
+llm_name=Qwen2.5-7B-Instruct # Name of the LLM
+use_peft=false # Whether to use PEFT (for the LLM)
+use_fp16=true # Whether to use FP16
+freeze_encoder=true # Whether to freeze the encoder
+pad_or_trim=true # Whether to use pad_or_trim (for Whisper)
+deepspeed_config= # Path to the DeepSpeed configuration file
+```
+
+Typically, we first train the projector and then fine-tune the LLM with LoRA. For projector training, set:
+```bash
+use_peft=false
+```
+
+For LoRA training, set (with `ckpt_path` pointing to the model saved in the previous step):
+```bash
+use_peft=true
+if [[ $use_peft == "true" ]]; then
+    ckpt_path= # For DDP training, provide the path to the saved pt file; for DeepSpeed training, convert mp_rank_00_model_states.pt to model.pt using the `scripts/transcribe_deepspeed_to_pt.py` script
+fi
+```
+
+### DeepSpeed
+When training a 7B model with `bf16`/`fp16`, DeepSpeed saves about 20GB of GPU memory compared to `torchrun`. For 7B models, it is recommended to use `zero-0`/`1`/`2`; for extremely large models, `zero-3` can be used, though communication may become a bottleneck.
+
+```json
+{
+    "train_micro_batch_size_per_gpu": 4,
+    "gradient_accumulation_steps": 1,
+    "optimizer": {
+        "type": "Adam",
+        "params": {
+            "lr": 1e-4
+        }
+    },
+    "fp16": {
+        "enabled": true
+    },
+    "zero_optimization": {
+        "stage": 2,
+        "offload_optimizer": {
+            "device": "cpu"
+        }
+    }
+}
+```
+
+Note that when using `zero-0`/`1`/`2`, the DeepSpeed checkpoint is saved in a format that requires converting `mp_rank_00_model_states.pt` to `model.pt`, e.g. `python scripts/transcribe_deepspeed_to_pt.py mp_rank_00_model_states.pt output_dir`.
+
+```
+global_step1000
+global_step1000/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
+...
+global_step1000/mp_rank_00_model_states.pt
+latest
+zero_to_fp32.py
+```
+
+If training with `Zero-3`, the model is saved in a different format and can be converted using `python zero_to_fp32.py global_step50 outputdir`.
+
+```
+global_step50
+global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt
+global_step50/zero_pp_rank_0_mp_rank_00_optim_states.pt
+...
+latest +zero_to_fp32.py +``` +If you use bf16/fp16 training in DeepSpeed and encounter NaN in train/eval loss, check the autocast in `src/slam_llm/utils/deepspeed_utils.py`: + +```python +with autocast() # original code +with autocast(dtype=torch.bfloat16) # must work +with autocast(dtype=torch.float16) +``` +## Decoding + +- **Single-machine single-GPU decoding**: Refer to `scripts/decode.sh` +- **Single-machine multi-GPU decoding**: Refer to `scripts/decode_deepspeed.sh` + +## Multi-Machine Multi-GPU Support + +Multi-machine multi-GPU training can be supported with minor modifications to the `finetune_deepspeed.sh` or `scripts/decode_deepspeed.sh` scripts. Due to environment-specific requirements, this example does not include dedicated scripts for multi-machine multi-GPU setups. diff --git a/examples/aispeech_asr/README_zh.md b/examples/aispeech_asr/README_zh.md new file mode 100644 index 00000000..f23f5d37 --- /dev/null +++ b/examples/aispeech_asr/README_zh.md @@ -0,0 +1,99 @@ +# AISPEECH_ASR + +## 概述 + +这是为工业界大规模数据训练准备的示例,适用于10万小时量级的数据训练,主要特点如下: +- **多任务训练支持**:通过设计数据格式,支持包括ASR、ST等多种任务。 +- **动态Prompt选择**:支持在多个Prompt中随机选择。 +- **迭代式dataset**:采用迭代形式的dataset,减少大数据量时的启动时间。 +- **Deepspeed训练**:支持Deepspeed训练,显著减少内存使用。 +- **多机多卡推理**:支持多机多卡推理,减少评估时间。 +- **动态帧数组合**:根据每个音频大小动态组合合适的帧数进行训练,而非使用固定的batch_size,大大减少了训练和评估时间(在10万小时量级的数据上,训练时间减少了3/4)。 + +本示例基于`mala_asr_slidespeech`进行修改。 + +## 模型架构 + +可以根据需要,在SLAM—LMM支持的范围内动态选择模型架构。以下是一些推荐的模型配置: +- **Encoder**:WavLM, Whisper +- **Projector**:Linear +- **LLM**:Qwen2.5-7B-Instruct, Vicuna1.5-7B + +## 数据准备 + +需要准备以下两个文件: +- `multitask.jsonl` +- `multiprompt.jsonl` + +### multitask.jsonl + +该文件的内容格式如下,其中`path`支持ark格式和wav文件: +```json +{"key": "BAC009S0002W0122", "task": "ASR", "target": "而对楼市成交抑制作用最大的限购", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:17"} +{"key": "BAC009S0002W0123", "task": "ASR", "target": "也成为地方政府的眼中钉", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:191758"} +{"key": "BAC009S0002W0124", "task": "ASR", "target": "自六月底呼和浩特市率先宣布取消限购后", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/train/data/data_wav.1.ark:315339"} +{"key": "BAC009S0764W0238", "task": "hotword", "path": "/aistor/aispeech/hpc_stor01/group/asr/mandarin/aishell-1/asr/test/data/data_wav.1.ark:17343733", "target": "形成一批具有国际竞争力的中国企业", "hotword": "中国"} +``` + +### multiprompt.jsonl + +该文件的内容格式如下: +```json +{"task": "ASR", "prompt": "Transcribe speech to text."} +{"task": "ASR", "prompt": "请识别语音."} +{"task": "ZH2EN", "prompt": "请识别语音并翻译为英文:"} +{"task": "EN2ZH", "prompt": "请识别语音并翻译为中文:"} +{"task": "prevtext", "prompt": "Transcribe speech to text, below are the previous historical transcription texts:{}."} +{"task": "hotword", "prompt": "Transcribe speech to text, follow words may occur:{}."} +``` + +### 注意事项 +- 如果有多条Prompt,会动态选择其中一条。 +- 如果有额外信息(如热词),请在`multitask.jsonl`中提供与任务同名的信息,并在Prompt中使用`{}`注入该信息。同时,修改`aispeech_config`文件中的`append_info_tasks`: + ```python + append_info_tasks: List = field(default_factory=lambda: ["hotword"]) + ``` + +## 训练新模型 + +### 脚本准备 + +在`scripts/finetune_deepspeed.sh`或`scripts/finetune_torchrun.sh`中准备并修改以下内容(推荐使用Deepspeed): +```bash +run_dir= # 模型保存目录 +train_scp_file_path= # 训练数据路径 +dev_scp_file_path= # 验证数据路径 +train_max_frame_length=1500 # 训练时的最大帧长度 +eval_max_frame_length=1000 # 评估时的最大帧长度 +multitask_prompt_path= # multitask.jsonl文件路径 +prompt_style="\{\}" # Prompt样式,可选格式如"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"或"USER: {}\n 
ASSISTANT:" +projector=linear # Projector类型 +encoder_name=whisper # Encoder名称 +llm_name=Qwen2.5-7B-Instruct # LLM名称 +use_peft=false # 是否使用PEFT(对于LLM) +use_fp16=true # 是否使用FP16 +freeze_encoder=true # 是否冻结Encoder +pad_or_trim=true # 是否使用pad_or_trim(对于Whisper) +deepspeed_config= # DeepSpeed配置文件路径 +``` + +通常,我们首先训练Projector,然后再训练LoRA。训练Projector时,设置如下: +```bash +use_peft=false +``` + +训练LoRA时,设置如下(`ckpt_path`是上一步训练保存的模型路径): +```bash +use_peft=true +if [[ $use_peft == "true" ]]; then + ckpt_path= # 如果是DDP训练,直接写入保存的pt文件路径;如果是Deepspeed训练,需将mp_rank_00_model_states.pt文件转化为model.pt,可使用`scripts/transcribe_deepspeed_to_pt.py`脚本 +fi +``` + +## 解码 + +- **单机单卡解码**:参考`scripts/decode.sh` +- **单机多卡解码**:参考`scripts/decode_deepspeed.sh` + +## 多机多卡支持 +简单修改脚本finetune_deepspeed.sh 或者scripts/decode_deepspeed.sh`后可以支持多机多卡训练,因为环境不同所做的修改也不同,本实例就不放出多机多卡的脚本了 diff --git a/examples/aispeech_asr/aispeech_asr_config.py b/examples/aispeech_asr/aispeech_asr_config.py new file mode 100644 index 00000000..cf1bdab5 --- /dev/null +++ b/examples/aispeech_asr/aispeech_asr_config.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass, field +from typing import Optional, List +from torch.distributed.fsdp import ShardingStrategy + + +@dataclass +class ModelConfig: + file: str = "examples/aispeech_asr/model/aispeech_asr.py:model_factory" + llm_name: str = "vicuna-7b-v1.5" + llm_path: str = "PATH/to/LLAMA/7B" + llm_type: str = "decoder_only" + llm_dim: int = 4096 + whisper_decode : Optional[bool] = False + encoder_name: Optional[str] = None + encoder_ds_rate: int = 2 + encoder_path: Optional[str] = None + encoder_path_hf: Optional[str] = None + encoder_dim: int = 1280 + encoder_projector: str = "linear" + qformer_layers : int = 8 + encoder_projector_ds_rate: int = 5 + modal: str = "audio" + normalize: Optional[bool] = field(default=False, metadata={ + "help": "whether input is normalized, used for models such as wavlm" + }) + encoder_type: str = field(default="finetune", metadata={ + "help": "whether model is only pretrained or finetuned, used for models such as hubert" + }) + + +@dataclass +class PeftConfig: + peft_method: str = "lora" # None , llama_adapter, prefix + r: int = 64 + lora_alpha: int = 16 + target_modules: List = field(default_factory=lambda: [ "q_proj","k_proj", "v_proj", "o_proj", "up_proj","gate_proj","down_proj"]) + bias: str = "none" + task_type: str = "CAUSAL_LM" + lora_dropout: float = 0.05 + inference_mode: bool = False + +@dataclass +class TrainConfig: + model_name:str = "PATH/to/LLAMA/7B" + enable_ddp:bool = False + enable_deepspeed:bool = False + enable_fsdp:bool = False + low_cpu_fsdp:bool = False + run_validation:bool = True + batch_size_training: Optional[int] = None + batching_strategy:str = field(default="packing", metadata={ + "help":"alternative: padding" + }) # + context_length:int = 4096 + gradient_accumulation_steps:int = 1 + num_epochs:int = 3 + num_workers_dataloader:int = 1 + warmup_steps:int = 1000 + total_steps:int = 100000 + validation_interval:int = 1000 + lr:float = 1e-4 + weight_decay:float = 0.0 + gamma:float = 0.85 + seed:int = 42 + use_fp16:bool = False + mixed_precision:bool = True + val_batch_size:Optional[int] = None + + use_peft:bool = False + peft_config:PeftConfig = field(default_factory=PeftConfig) + output_dir:str = "PATH/to/save/PEFT/model" + freeze_layers:bool = False + num_freeze_layers:int = 1 + quantization:bool = False + one_gpu:bool = False + save_model:bool = True + dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP + 
dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP + save_optimizer:bool = False # will be used if using FSDP + use_fast_kernels:bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels + run_test_during_validation:bool = False + run_test_during_validation_file:str = "test.wav" + run_test_during_validation_prompt:str = "<|ASR|>" + freeze_llm:bool = field(default=False, metadata={ + "help": "whether to freeze llm when finetuning, should be true when use peft finetuning" + }) + freeze_encoder:bool = False + +@dataclass +class DataConfig: + dataset: str = "multitask_dataset" + train_max_frame_length: int = 1500 + eval_max_frame_length: int = 1000 + multitask_prompt_path: str = "/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/multiprompt.jsonl" + prompt_style: str = "\{\}" # + append_info_tasks : List = field(default_factory=lambda: [ "hotword"]) + file: str = "examples/aispeech_asr/slam_llm/datasets/speech_dataset_large.py:get_speech_dataset" + train_scp_file_path: str = "" + dev_scp_file_path: str = "" + test_scp_file_path: str = "" + train_split: str = "train" + dev_split: str = "dev" + test_split:str = "test" + pad_or_trim: bool = True + prompt: Optional[str] = None + use_ocr: bool = True + inference_mode: bool = False + lower: bool = False + fix_length_audio: int = -1 + inference_mode:bool = False + input_type: str = field(default="raw", metadata={ + "help":"Use raw when input is wav, mel when for whisper" + }) + mel_size: int = field(default=80, metadata={ + "help": "80 for whisper large v1 and v2, 128 for v3" + }) + normalize: Optional[bool] = field(default=False, metadata={ + "help": "whether input is normalized, used for models such as wavlm" + }) + +@dataclass +class FSDPConfig: + mixed_precision: bool = True + use_fp16: bool = False + # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD + sharding_strategy: ShardingStrategy = "SHARD_GRAD_OP" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP + checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. 
+ fsdp_activation_checkpointing: bool = True + fsdp_cpu_offload: bool = False + pure_bf16: bool = False + optimizer: str = "AdamW" + +@dataclass +class LogConfig: + use_wandb: bool = False + wandb_dir: str = "tmp/test_wandb" + wandb_entity_name: str = "project_name" + wandb_project_name: str = "project_name" + wandb_exp_name: str = "exp_name" + log_file: str = "tmp/test.log" + log_interval: int = 5 diff --git a/examples/aispeech_asr/conf/ds_config.json b/examples/aispeech_asr/conf/ds_config.json new file mode 100644 index 00000000..e6726609 --- /dev/null +++ b/examples/aispeech_asr/conf/ds_config.json @@ -0,0 +1,37 @@ +{ + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "eps": 1e-06 + } + }, + "bf16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 100, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 0.01 + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "overlap_comm": true, + "reduce_scatter": true, + "contiguous_gradients": true + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0.00, + "warmup_max_lr": 0.00005, + "warmup_num_steps": 1000 + } + }, + "checkpoint_activations": false + +} \ No newline at end of file diff --git a/examples/aispeech_asr/conf/prompt.yaml b/examples/aispeech_asr/conf/prompt.yaml new file mode 100644 index 00000000..32cf2374 --- /dev/null +++ b/examples/aispeech_asr/conf/prompt.yaml @@ -0,0 +1,14 @@ +dataset_config: + # we put prompt here, because the hydra override in shell script only support a small subset of chars + # prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " + # prompt: "<|im_start|>user\n请将语音转写为汉字<|im_end|>\n<|im_start|>assistant\n" + # prompt: "识别语音" + # prompt : "将上面的语音转写为英文" + # prompt: "Transcribe speech to English." + # prompt: "Transcribe speech to text.And then translate the text to spanish." + # prompt: "Transcribe speech to text." + # prompt: "Tell me what is the language of the text." + prompt: "Transcribe speech to text." + # prompt: "Transcribe speech to text.Follow words may occur in audio:{}." 
+ # prompt: "" + # prompt: "请问上面有几个句子,有多少个字,给字编号然后输出文本" diff --git a/examples/aispeech_asr/finetune_deepspeed.py b/examples/aispeech_asr/finetune_deepspeed.py new file mode 100644 index 00000000..89d39a0e --- /dev/null +++ b/examples/aispeech_asr/finetune_deepspeed.py @@ -0,0 +1,58 @@ +from slam_llm.pipeline.finetune_deepspeed import main as train +from typing import Optional +import argparse +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from aispeech_asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig +from slam_llm.utils.deepspeed_utils import deepspeed_main_wrapper + +import sys +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + ckpt_path: Optional[str] = field( + default=None, metadata={"help": "The path to projector checkpoint"} + ) + deepspeed_config : str ="examples/aispeech_asr/conf/ds_config.json" + deepspeed_ckpt_path: Optional[str] = field( + default=None, metadata={"help": "The path to projector checkpoint"} + ) + deepspeed_ckpt_id: Optional[str] = field( + default=None, metadata={"help": "The id to projector checkpoint"} + ) + +@deepspeed_main_wrapper(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + train(kwargs) + + +if __name__ == "__main__": + main_hydra() \ No newline at end of file diff --git a/examples/aispeech_asr/finetune_torchrun.py b/examples/aispeech_asr/finetune_torchrun.py new file mode 100644 index 00000000..b9c1e86e --- /dev/null +++ b/examples/aispeech_asr/finetune_torchrun.py @@ -0,0 +1,49 @@ +from slam_llm.pipeline.finetune import main as train +from typing import Optional + +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from aispeech_asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + ckpt_path: Optional[str] = field( + default=None, metadata={"help": "The path to projector checkpoint"} + ) + +@hydra.main(config_name=None, 
version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + train(kwargs) + + +if __name__ == "__main__": + main_hydra() \ No newline at end of file diff --git a/examples/aispeech_asr/inference_batch.py b/examples/aispeech_asr/inference_batch.py new file mode 100644 index 00000000..44ce05a7 --- /dev/null +++ b/examples/aispeech_asr/inference_batch.py @@ -0,0 +1,53 @@ +from slam_llm.pipeline.inference_batch import main as inference + +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from typing import Optional +from aispeech_asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig + + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + decode_log: str = field( + default="output/decode_log", + metadata={"help": "The prefix for the decode output"}, + ) + ckpt_path: str = field( + default="output/model.pt", metadata={"help": "The path to projector checkpoint"} + ) + peft_ckpt: Optional[str] = field( + default=None, + metadata={ + "help": "The path to peft checkpoint, should be a directory including adapter_config.json" + }, + ) + + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + # kwargs = to_plain_list(cfg) + log_level = getattr(logging, cfg.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if cfg.get("debug", False): + import pdb + + pdb.set_trace() + + inference(cfg) + + +if __name__ == "__main__": + main_hydra() diff --git a/examples/aispeech_asr/inference_batch_deepspeed.py b/examples/aispeech_asr/inference_batch_deepspeed.py new file mode 100644 index 00000000..19264772 --- /dev/null +++ b/examples/aispeech_asr/inference_batch_deepspeed.py @@ -0,0 +1,60 @@ +from slam_llm.pipeline.inference_batch_deepspeed import main as inference +from typing import Optional +import argparse +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from aispeech_asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig +import sys +from slam_llm.utils.deepspeed_utils import deepspeed_main_wrapper +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: 
FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + decode_log: str = field( + default="output/decode_log", + metadata={"help": "The prefix for the decode output"}, + ) + ckpt_path: str = field( + default="output/model.pt", metadata={"help": "The path to projector checkpoint"} + ) + peft_ckpt: Optional[str] = field( + default=None, + metadata={ + "help": "The path to peft checkpoint, should be a directory including adapter_config.json" + }, + ) + +@deepspeed_main_wrapper(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + inference(kwargs) + + +if __name__ == "__main__": + main_hydra() \ No newline at end of file diff --git a/examples/aispeech_asr/model/aispeech_asr.py b/examples/aispeech_asr/model/aispeech_asr.py new file mode 100644 index 00000000..830b49fa --- /dev/null +++ b/examples/aispeech_asr/model/aispeech_asr.py @@ -0,0 +1,156 @@ +import torch +import os +import logging +from slam_llm.models.slam_model import ( + slam_model, + setup_tokenizer, + setup_encoder, + setup_encoder_projector, + setup_llm, +) +from slam_llm.utils.train_utils import print_model_size + +logger = logging.getLogger(__name__) + +def model_factory(train_config, model_config, **kwargs): + # return necessary components for training + tokenizer = setup_tokenizer(train_config, model_config, **kwargs) + + encoder = setup_encoder(train_config, model_config, **kwargs) + + # llm + llm = setup_llm(train_config, model_config, **kwargs) + + # projector + encoder_projector = setup_encoder_projector( + train_config, model_config, **kwargs + ) + model = slam_model_asr( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + ckpt_path = kwargs.get( + "ckpt_path", None + ) # FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft) + if ckpt_path is not None: + logger.info("loading other parts from: {}".format(ckpt_path)) + ckpt_dict = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(ckpt_dict, strict=False) + + + print_model_size( + model, + train_config, + ( + int(os.environ["RANK"]) + if train_config.enable_fsdp or train_config.enable_ddp + else 0 + ), + ) + return model, tokenizer + + +class slam_model_asr(slam_model): + def __init__( + self, + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ): + super().__init__( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + + @torch.no_grad() + def inference( + self, + wav_path=None, + prompt=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=None, + assistant_model=None, + streamer=None, + negative_prompt_ids=None, + negative_prompt_attention_mask=None, + 
**kwargs, + ): + # inference for asr model + + device = kwargs.get("device", "npu") + if os.path.exists(wav_path): # Audio-Text QA + import whisper + + audio_raw = whisper.load_audio(wav_path) + audio_raw = whisper.pad_or_trim(audio_raw) + + mel_size = getattr( + self.dataset_config, "mel_size", 80 + ) # 80 for large v1 and v2, 128 for large v3 + audio_mel = ( + whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size) + .permute(1, 0)[None, :, :] + .to(device) + ) + + encoder_outs = self.encoder.extract_variable_length_features( + audio_mel.permute(0, 2, 1) + ) + + if self.model_config.encoder_projector == "q-former": + audio_mel_post_mask = torch.ones( + encoder_outs.size()[:-1], dtype=torch.long + ).to(encoder_outs.device) + encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) + else: # Text QA + encoder_outs = torch.empty( + 1, 0, self.llm.model.embed_tokens.embedding_dim + ).to(device) + + prompt = "USER: {}\n ASSISTANT:".format(prompt) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device) + + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(prompt_ids) + elif hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) + + inputs_embeds = torch.cat( + (encoder_outs, inputs_embeds[None, :, :]), dim=1 + ) # [audio,prompt] + + attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to( + inputs_embeds.device + ) + + # generate + model_outputs = self.generate( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs + ) + + return model_outputs diff --git a/examples/aispeech_asr/scripts/decode.sh b/examples/aispeech_asr/scripts/decode.sh new file mode 100644 index 00000000..30790213 --- /dev/null +++ b/examples/aispeech_asr/scripts/decode.sh @@ -0,0 +1,91 @@ +#!/bin/bash +set -e +run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM-NPU +cd $run_dir +code_dir=examples/aispeech_asr + +prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:" +projector=linear +encoder_name=whisper +llm_name=Qwen2.5-7B-Instruct +use_peft=false +use_fp16=false +pad_or_trim=true +encoder_projector_ds_rate=5 +eval_max_frame_length=1000 +ckpt_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/project/aispeech_asr/exp/librispeech/20250322/whisper_linear_Qwen2.5-7B-Instruct_lorafalse_padtrue_normal_asr_speedfalse_specaugfalse-1121/mala_asr_epoch_2_step_25000_best +test_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test + + +# Choose Encoder +if [[ $encoder_name == "whisper" ]] +then + speech_encoder_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/whisper/large-v3.pt + mel_size=128 + encoder_dim=1280 + input_type=mel +elif [[ $encoder_name == "wavlm" ]] +then + speech_encoder_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/wavlm/WavLM-Large.pt + encoder_dim=1024 + input_type=raw + mel_size=128 +else + exit 1 +fi + +# Choose LLM +if [[ $llm_name == "vicuna-7b-v1.5" ]] +then + llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/vicuna-7b-v1.5 + llm_dim=4096 +elif [[ $llm_name == "Qwen2.5-7B-Instruct" ]] +then + 
llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/Qwen2.5-7B-Instruct + llm_dim=3584 +elif [[ $llm_name == "Qwen2-7B" ]] +then + llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/Qwen2-7B + llm_dim=3584 +elif [[ $llm_name == "Qwen2.5-7B" ]] +then + llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/Qwen2.5-7B + llm_dim=3584 +else + exit 1 +fi + +decode_log=$ckpt_path/decode +python \ + $code_dir/inference_batch.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + hydra.run.dir=$ckpt_path \ + ++model_config.llm_path=$llm_path \ + ++model_config.llm_dim=$llm_dim \ + ++model_config.encoder_name=$encoder_name \ + ++model_config.normalize=true \ + ++model_config.encoder_projector_ds_rate=5 \ + ++model_config.encoder_path=$speech_encoder_path \ + ++model_config.encoder_dim=$encoder_dim \ + ++model_config.encoder_projector=$projector \ + ++dataset_config.prompt_style=$prompt_style \ + ++dataset_config.dataset=$dataset \ + ++dataset_config.pad_or_trim=$pad_or_trim \ + ++dataset_config.test_scp_file_path=$test_scp_file_path \ + ++dataset_config.input_type=$input_type \ + ++dataset_config.mel_size=$mel_size \ + ++dataset_config.inference_mode=true \ + ++train_config.model_name=aispeech_asr \ + ++train_config.freeze_encoder=true \ + ++train_config.freeze_llm=true \ + ++train_config.use_peft=$use_peft \ + ++train_config.batching_strategy=dynamic \ + ++train_config.num_epochs=1 \ + ++train_config.num_workers_dataloader=0 \ + ++train_config.output_dir=$output_dir \ + ++decode_log=$decode_log \ + ++ckpt_path=$ckpt_path/model.pt + + + diff --git a/examples/aispeech_asr/scripts/decode_deepspeed.sh b/examples/aispeech_asr/scripts/decode_deepspeed.sh new file mode 100644 index 00000000..846874ba --- /dev/null +++ b/examples/aispeech_asr/scripts/decode_deepspeed.sh @@ -0,0 +1,93 @@ +#!/bin/bash +set -e +run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM-NPU +cd $run_dir +code_dir=examples/aispeech_asr + +prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:" +projector=linear +encoder_name=whisper +llm_name=Qwen2.5-7B-Instruct +use_peft=false +use_fp16=false +pad_or_trim=true +encoder_projector_ds_rate=5 +eval_max_frame_length=1000 +ckpt_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/project/aispeech_asr/exp/librispeech/20250322/whisper_linear_Qwen2.5-7B-Instruct_lorafalse_padtrue_normal_asr_speedfalse_specaugfalse-1121/mala_asr_epoch_2_step_25000_best +test_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test + + +# Choose Encoder +if [[ $encoder_name == "whisper" ]] +then + speech_encoder_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/whisper/large-v3.pt + mel_size=128 + encoder_dim=1280 + input_type=mel +elif [[ $encoder_name == "wavlm" ]] +then + speech_encoder_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/wavlm/WavLM-Large.pt + encoder_dim=1024 + input_type=raw + mel_size=128 +else + exit 1 +fi + +# Choose LLM +if [[ $llm_name == "vicuna-7b-v1.5" ]] +then + llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/vicuna-7b-v1.5 + llm_dim=4096 +elif [[ $llm_name == "Qwen2.5-7B-Instruct" ]] +then + llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/Qwen2.5-7B-Instruct + llm_dim=3584 +elif [[ $llm_name == "Qwen2-7B" ]] +then + llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/Qwen2-7B + 
llm_dim=3584 +elif [[ $llm_name == "Qwen2.5-7B" ]] +then + llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/Qwen2.5-7B + llm_dim=3584 +else + exit 1 +fi + +decode_log=$ckpt_path/decode +deepspeed \ + --num_nodes 1 \ + --num_gpus 8 \ + $code_dir/inference_batch_deepspeed.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + hydra.run.dir=$ckpt_path \ + ++model_config.llm_path=$llm_path \ + ++model_config.llm_dim=$llm_dim \ + ++model_config.encoder_name=$encoder_name \ + ++model_config.normalize=true \ + ++model_config.encoder_projector_ds_rate=5 \ + ++model_config.encoder_path=$speech_encoder_path \ + ++model_config.encoder_dim=$encoder_dim \ + ++model_config.encoder_projector=$projector \ + ++dataset_config.prompt_style=$prompt_style \ + ++dataset_config.dataset=$dataset \ + ++dataset_config.pad_or_trim=$pad_or_trim \ + ++dataset_config.test_scp_file_path=$test_scp_file_path \ + ++dataset_config.input_type=$input_type \ + ++dataset_config.mel_size=$mel_size \ + ++dataset_config.inference_mode=true \ + ++train_config.model_name=aispeech_asr \ + ++train_config.freeze_encoder=true \ + ++train_config.freeze_llm=true \ + ++train_config.use_peft=$use_peft \ + ++train_config.batching_strategy=dynamic \ + ++train_config.num_epochs=1 \ + ++train_config.num_workers_dataloader=0 \ + ++train_config.output_dir=$output_dir \ + ++decode_log=$decode_log \ + ++ckpt_path=$ckpt_path/model.pt + + + diff --git a/examples/aispeech_asr/scripts/finetune_deepspeed.sh b/examples/aispeech_asr/scripts/finetune_deepspeed.sh new file mode 100644 index 00000000..52071137 --- /dev/null +++ b/examples/aispeech_asr/scripts/finetune_deepspeed.sh @@ -0,0 +1,124 @@ +#!/bin/bash +# export PYTHONPATH=/root/fairseq:$PYTHONPATH +# export ASCEND_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export TOKENIZERS_PARALLELISM=false +export HCCL_CONNECT_TIMEOUT=7200 +# export CUDA_LAUNCH_BLOCKING=1 +export HYDRA_FULL_ERROR=1 +export OMP_NUM_THREADS=1 + + + +run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM-NPU +cd $run_dir +code_dir=examples/aispeech_asr + +train_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test +dev_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test +train_max_frame_length=500 +eval_max_frame_length=500 +multitask_prompt_path="/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/multiprompt.jsonl" +prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:" +projector=linear +encoder_name=whisper +llm_name=Qwen2.5-7B-Instruct +use_peft=false # For llm +use_fp16=true +freeze_encoder=true +pad_or_trim=true # For whisper + +deepspeed_config=examples/aispeech_asr/conf/ds_config.json + +if [[ $use_peft == "true" || $freeze_encoder == false ]];then + ckpt_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/project/aispeech_asr/exp/slidespeech/20250414/whisper_linear_Qwen2.5-7B-Instruct_lorafalse_padtrue_normal_asr_speedfalse_specaugfalse-1515_slidespeech_text/mala_asr_epoch_2_step_7000 +fi + +# Choose Encoder +if [[ $encoder_name == "whisper" ]] +then + speech_encoder_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/whisper/large-v3.pt + mel_size=128 + encoder_dim=1280 + input_type=mel + +elif [[ $encoder_name == "wavlm" ]] +then + speech_encoder_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/wavlm/WavLM-Large.pt + encoder_dim=1024 + input_type=raw + mel_size=128 +else + exit 1 +fi + 
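+# NOTE: llm_dim must match the hidden size of the chosen LLM
+# (4096 for vicuna-7b-v1.5, 3584 for Qwen2.5-7B-Instruct, 1536 for Qwen2.5-1.5B-Instruct).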
+# Choose LLM +if [[ $llm_name == "vicuna-7b-v1.5" ]] +then + llm_path= + llm_dim=4096 +elif [[ $llm_name == "Qwen2.5-7B-Instruct" ]] +then + llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/Qwen2.5-7B-Instruct + llm_dim=3584 +elif [[ $llm_name == "Qwen2-7B" ]] +then + llm_path= + llm_dim=3584 +elif [[ $llm_name == "Qwen2.5-1.5B-Instruct" ]] +then + llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/Qwen2.5-1.5B-Instruct + llm_dim=3584 +else + exit 1 +fi + + + + + + +output_dir=${code_dir}/exp/$(date +"%Y%m%d-%H%M") +hydra_args=" +hydra.run.dir=$output_dir \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=$llm_dim \ +++model_config.encoder_name=$encoder_name \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_dim=$encoder_dim \ +++model_config.encoder_projector=$projector \ +++dataset_config.prompt_style=$prompt_style \ +++dataset_config.train_max_frame_length=$train_max_frame_length \ +++dataset_config.eval_max_frame_length=$eval_max_frame_length \ +++dataset_config.multitask_prompt_path=$multitask_prompt_path \ +++dataset_config.input_type=$input_type \ +++dataset_config.mel_size=$mel_size \ +++dataset_config.pad_or_trim=$pad_or_trim \ +++dataset_config.train_scp_file_path=$train_scp_file_path \ +++dataset_config.dev_scp_file_path=$dev_scp_file_path \ +++train_config.model_name=aispeech_asr \ +++train_config.num_epochs=5 \ +++train_config.freeze_encoder=$freeze_encoder \ +++train_config.freeze_llm=true \ +++train_config.use_peft=$use_peft \ +++train_config.batching_strategy=dynamic \ +++train_config.validation_interval=1000 \ +++train_config.num_workers_dataloader=8 \ +++train_config.output_dir=$output_dir \ +++metric=acc \ +" +if [[ $use_peft == "true" || $freeze_encoder == false ]];then + hydra_args+="++ckpt_path=$ckpt_path/model.pt" +fi + + +deepspeed \ + --num_nodes 1 \ + --num_gpus 8 \ + $code_dir/finetune_deepspeed.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + ++train_config.enable_fsdp=false \ + ++train_config.enable_ddp=true \ + ++train_config.use_fp16=$use_fp16 \ + ++deepspeed_config=$deepspeed_config \ + ${hydra_args} diff --git a/examples/aispeech_asr/scripts/finetune_torchrun.sh b/examples/aispeech_asr/scripts/finetune_torchrun.sh new file mode 100644 index 00000000..abcd3e5e --- /dev/null +++ b/examples/aispeech_asr/scripts/finetune_torchrun.sh @@ -0,0 +1,126 @@ +#!/bin/bash +# export PYTHONPATH=/root/fairseq:$PYTHONPATH +# export ASCEND_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export TOKENIZERS_PARALLELISM=false +export HCCL_CONNECT_TIMEOUT=7200 +# export CUDA_LAUNCH_BLOCKING=1 +export HYDRA_FULL_ERROR=1 +export OMP_NUM_THREADS=1 + + + +run_dir=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/github/SLAM-LLM-NPU +cd $run_dir +code_dir=examples/aispeech_asr + +train_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test +dev_scp_file_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/aishell-1/asr/test +train_max_frame_length=1500 +eval_max_frame_length=3000 +multitask_prompt_path="/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/data/multiprompt.jsonl" +prompt_style="\{\}" # "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" | "USER: {}\n ASSISTANT:" +projector=linear +encoder_name=whisper +llm_name=Qwen2.5-1.5B-Instruct +use_peft=false # For llm +use_fp16=true +freeze_encoder=true +pad_or_trim=true # For whisper + + +if [[ $use_peft == "true" || $freeze_encoder == false ]];then + 
ckpt_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/project/aispeech_asr/exp/slidespeech/20250414/whisper_linear_Qwen2.5-7B-Instruct_lorafalse_padtrue_normal_asr_speedfalse_specaugfalse-1515_slidespeech_text/mala_asr_epoch_2_step_7000 +fi + +# Choose Encoder +if [[ $encoder_name == "whisper" ]] +then + speech_encoder_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/whisper/large-v3.pt + mel_size=128 + encoder_dim=1280 + input_type=mel + +elif [[ $encoder_name == "wavlm" ]] +then + speech_encoder_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/wavlm/WavLM-Large.pt + encoder_dim=1024 + input_type=raw + mel_size=128 +else + exit 1 +fi + +# Choose LLM +if [[ $llm_name == "vicuna-7b-v1.5" ]] +then + llm_path= + llm_dim=4096 +elif [[ $llm_name == "Qwen2.5-7B-Instruct" ]] +then + llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/Qwen2.5-7B-Instruct + llm_dim=3584 +elif [[ $llm_name == "Qwen2-7B" ]] +then + llm_path= + llm_dim=3584 +elif [[ $llm_name == "Qwen2.5-1.5B-Instruct" ]] +then + llm_path=/aistor/aispeech/hpc_stor01/home/fangyangui/workingspace/model/Qwen2.5-1.5B-Instruct + llm_dim=1536 +elif [[ $llm_name == "Qwen2.5-7B" ]] +then + llm_path= + llm_dim=3584 +else + exit 1 +fi + + + + + + +output_dir=${code_dir}/exp/$(date +"%Y%m%d-%H%M") +hydra_args=" +hydra.run.dir=$output_dir \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=$llm_dim \ +++model_config.encoder_name=$encoder_name \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_dim=$encoder_dim \ +++model_config.encoder_projector=$projector \ +++dataset_config.prompt_style=$prompt_style \ +++dataset_config.train_max_frame_length=$train_max_frame_length \ +++dataset_config.eval_max_frame_length=$eval_max_frame_length \ +++dataset_config.multitask_prompt_path=$multitask_prompt_path \ +++dataset_config.input_type=$input_type \ +++dataset_config.mel_size=$mel_size \ +++dataset_config.pad_or_trim=$pad_or_trim \ +++dataset_config.train_scp_file_path=$train_scp_file_path \ +++dataset_config.dev_scp_file_path=$dev_scp_file_path \ +++train_config.model_name=aispeech_asr \ +++train_config.num_epochs=5 \ +++train_config.freeze_encoder=$freeze_encoder \ +++train_config.freeze_llm=true \ +++train_config.use_peft=$use_peft \ +++train_config.batching_strategy=dynamic \ +++train_config.validation_interval=10 \ +++train_config.num_workers_dataloader=8 \ +++train_config.output_dir=$output_dir \ +++metric=acc \ +" +if [[ $use_peft == "true" || $freeze_encoder == false ]];then + hydra_args+="++ckpt_path=$ckpt_path/model.pt" +fi + +torchrun \ + --nnodes 1 \ + --nproc_per_node 2 \ + --master_port=29505 \ + $code_dir/finetune_torchrun.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + ++train_config.enable_fsdp=false \ + ++train_config.enable_ddp=true \ + ++train_config.use_fp16=true \ + ${hydra_args} diff --git a/examples/aispeech_asr/scripts/transcribe_deepspeed_to_pt.py b/examples/aispeech_asr/scripts/transcribe_deepspeed_to_pt.py new file mode 100644 index 00000000..e2a02862 --- /dev/null +++ b/examples/aispeech_asr/scripts/transcribe_deepspeed_to_pt.py @@ -0,0 +1,9 @@ +import argparse +import torch +import torch_npu +import sys +in_path = sys.argv[1] +out_path = sys.argv[2] +weight_dict = torch.load(in_path)["module"] +torch.save(weight_dict, f"{out_path}/model.pt") +print("[Finish]") \ No newline at end of file diff --git a/examples/asr_librispeech/README.md b/examples/asr_librispeech/README.md index 8502c95a..fb663512 100644 
--- a/examples/asr_librispeech/README.md +++ b/examples/asr_librispeech/README.md @@ -53,6 +53,60 @@ Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memor If you are interested in running full parameter fine-tuning on the 70B model, you can enable `low_cpu_fsdp` mode as the following command. This option will load model on rank0 only before moving model to devices to construct FSDP. This can dramatically save cpu memory when loading large models like 70B (on a 8-gpu node, this reduces cpu memory from 2+T to 280G for 70B model). This has been tested with `BF16` on 16xA100, 80GB GPUs. +### Fine-tuning using Deepspeed + +If you're interested in training with DeepSpeed, refer to the script `finetune_whisper_large_linear_vicuna_7b_deepspeed.sh`. The training configuration is shown in `conf/ds_config.json`. When using `bf16`/`fp16` for training, it saves about 20GB of GPU memory compared to `torchrun` when training a 7B model. For 7B models, it's recommended to use `zero-0`/`1`/`2`, while for extremely large models, `zero-3` can be used, though communication may become a bottleneck. + +```json +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "fp16": { + "enabled": true + }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu" + } + } +} +``` + +Note that when using `zero-0`/`1`/`2`, the DeepSpeed model is saved in a format that requires a script to convert `mp_rank_00_model_states.pt` to `model.pt`, such as `python transcribe_deepspeed_to_pt.py mp_rank_00_model_states.pt output_dir`. + +``` +global_step1000 +global_step1000/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt +... +global_step1000/mp_rank_00_model_states.pt +latest +zero_to_fp32.py +``` + +If training with `Zero-3`, the model is saved in a different format and can be converted using `python zero_to_fp32.py global_step50 outputdir`. + +``` +global_step50 +global_step50/zero_pp_rank_0_mp_rank_00_model_states.pt +global_step50/zero_pp_rank_0_mp_rank_00_optim_states.pt +... +latest +zero_to_fp32.py +``` +If you use bf16/fp16 training in DeepSpeed and encounter NaN in train/eval loss, check the autocast in `src/slam_llm/utils/deepspeed_utils.py`: + +```python +with autocast() # original code +with autocast(dtype=torch.bfloat16) +with autocast(dtype=torch.float16) +``` ## Citation You can refer to the paper for more results. 
``` diff --git a/examples/asr_librispeech/scripts/transcribe_deepspeed_to_pt.py b/examples/asr_librispeech/scripts/transcribe_deepspeed_to_pt.py new file mode 100644 index 00000000..e2a02862 --- /dev/null +++ b/examples/asr_librispeech/scripts/transcribe_deepspeed_to_pt.py @@ -0,0 +1,9 @@ +import argparse +import torch +import torch_npu +import sys +in_path = sys.argv[1] +out_path = sys.argv[2] +weight_dict = torch.load(in_path)["module"] +torch.save(weight_dict, f"{out_path}/model.pt") +print("[Finish]") \ No newline at end of file diff --git a/src/slam_llm/datasets/speech_dataset_large.py b/src/slam_llm/datasets/speech_dataset_large.py new file mode 100644 index 00000000..2ba299ac --- /dev/null +++ b/src/slam_llm/datasets/speech_dataset_large.py @@ -0,0 +1,273 @@ +import torch +from torch.utils.data import Dataset,IterableDataset +import whisper +import kaldiio +import types +from functools import partial +# import pyroomacoustics as pra +import torch.distributed as dist +import string +import copy +import numpy as np +import copy +from tqdm import tqdm +import os +import json +import random +import torchaudio +import random +import logging +import subprocess + + +class MultiTaskDataset(IterableDataset): + def __init__(self, dataset_config, tokenizer=None, split='train'): + super().__init__() + self.multitask_prompt_list = {} + self.append_info_tasks = dataset_config.append_info_tasks + with open(dataset_config.multitask_prompt_path) as f_prompt: + for line in f_prompt: + item = json.loads(line.strip()) + if item["task"] in self.multitask_prompt_list: + self.multitask_prompt_list[item["task"]].append(item["prompt"]) + else: + self.multitask_prompt_list[item["task"]] = [item["prompt"]] + print(f"[Prompt] {self.multitask_prompt_list}") + if split == "train": + self.data_path = dataset_config.train_scp_file_path + elif split == "val": + self.data_path = dataset_config.dev_scp_file_path + elif split == "test": + self.data_path = dataset_config.test_scp_file_path + else: + raise ValueError("split must be train val test") + + self.llm_name = dataset_config.get("llm_name", None) + self.prompt_template1 = dataset_config.get("prompt_style", "{}") + self.answer_template = "{}" + self.dataset_config = dataset_config + self.tokenizer = tokenizer + self.split = split + self.pad_or_trim = dataset_config.get("pad_or_trim", False) + self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + self.mel_size = dataset_config.get("mel_size", 80) # 80 for whisper large v1 and v2, 128 for large v3 + self.fix_length_audio = dataset_config.get("fix_length_audio", -1) + self.inference_mode = dataset_config.get("inference_mode", False) + self.normalize = dataset_config.get("normalize", False) + self.input_type = dataset_config.get("input_type", None) + assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]" + + def __iter__(self): + multitask_task_path = os.path.join(self.data_path,"multitask.jsonl") + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: # Not in the multi-processing environment of DataLoader. + num_workers = 1 + worker_id = 0 + else: + num_workers = worker_info.num_workers + worker_id = worker_info.id + + # Obtain the process information in the distributed environment. + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + world_size = 1 + rank = 0 + + # Calculate the data range that each worker and each process should handle. 
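+        # Every DataLoader worker on every rank gets a unique id (worker_rank = rank * num_workers + worker_id),
+        # and lines of multitask.jsonl are dealt out round-robin (data_index % total_num_workers == worker_rank),
+        # so no two workers read the same example.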
+ total_num_workers = num_workers * world_size + worker_rank = rank * num_workers + worker_id + data_index = 0 + with open(multitask_task_path) as f_task: + for line in f_task: + if (data_index % total_num_workers) == worker_rank: + item = json.loads(line.strip()) + ark_path = item["path"] + numpy_array = kaldiio.load_mat(ark_path) + audio_raw = numpy_array[1].astype(np.float32) / 32768 + if len(audio_raw) / 16000 > 30: + continue + key = item["key"] + target = item["target"] + if self.input_type == "raw": + audio_raw = torch.from_numpy(audio_raw).float() + if self.normalize: + audio_raw = torch.nn.functional.layer_norm(audio_raw, audio_raw.shape) + audio_length = len(audio_raw) // 320 # ad-hoc for fairseq 320x downsample + audio_length = audio_length // 5 # ad-hoc for 5x fc downsample + elif self.input_type == "mel": + if self.pad_or_trim == True: + audio_raw = whisper.pad_or_trim(audio_raw) + audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0) + audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats + audio_length = audio_length // 5 # ad-hoc for 5x fc downsample + if self.fix_length_audio > 0: + audio_length = self.fix_length_audio + audio_pseudo = torch.full((audio_length,), -1) # placeholder + + prompt = random.choice(self.multitask_prompt_list[item["task"]]) + prompt = self.prompt_template1.format(prompt) + if item["task"] in self.append_info_tasks: + prompt = prompt.format(item[item["task"]]) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + + if self.inference_mode: + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) + example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt] + example_mask = example_ids.ge(-1) # [True,True] + yield { + "input_ids": example_ids, + "attention_mask": example_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + 'audio_length': audio_length, + 'key': key, + 'target': target, + } + else: + answer = self.answer_template.format(target) + example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. 
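+                        # The first audio_length + prompt_length label positions (audio placeholder + prompt)
+                        # are set to IGNORE_INDEX (-100) below, so the loss is computed only on the answer and eos tokens.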
+ example_ids = self.tokenizer.encode(example) # [prompt,answer] + example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] + example_ids = torch.tensor( + example_ids, dtype=torch.int64 + ) + example_ids = torch.cat((audio_pseudo, example_ids)) # [audio,prompt,answer,eos] + + labels_ids = copy.deepcopy(example_ids) # [audio,prompt,answer,eos] + labels_ids[:audio_length + prompt_length] = -1 # [-1,-1,answer,eos]; + example_mask = example_ids.ge(-1) # FIX(GZF): [True,True,True,True] + + label_mask = labels_ids.ge(0) # [False,False,True,True] + example_ids[~example_mask] = 0 # [audio,prompt,answer,eos] + labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,answer,eos] + yield { + "input_ids": example_ids, + "labels": labels_ids, + "attention_mask": example_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + 'audio_length': audio_length, + } + data_index += 1 + + def pad(self, sequence, max_length, padding_idx=0): + if isinstance(sequence, (int, list, tuple)): + if len(sequence) < max_length: + sequence = sequence + [padding_idx] * (max_length - len(sequence)) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, torch.Tensor): + if len(sequence) < max_length: + sequence = torch.cat( + (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, np.ndarray): + if len(sequence) < max_length: + sequence = np.concatenate( + (sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx))) + else: + sequence = sequence[:max_length] + else: + raise Exception("Type mismatch during padding!") + return sequence + + def collator(self, samples): + assert samples is not None + input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) + input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) + for s in samples]) + attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) + for s in samples]) + if self.input_type == "raw": + audio_raw_max_length = max([s['audio'].shape[0] for s in samples]) + audio_raw = torch.stack([self.pad(s['audio'], audio_raw_max_length, 0) + for s in samples]) + audio_mask = torch.zeros(len(samples), audio_raw_max_length) + for line, sample in enumerate(samples): + audio_mask[line, :sample['audio'].shape[0]] = 1 + elif self.input_type == "mel": + audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples]) + audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0) + for s in samples]) + audio_mel_post_mask = torch.zeros(len(samples), (audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats + for line, sample in enumerate(samples): + audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1 + + modality_mask = torch.zeros_like(attention_mask) + for line, sample in enumerate(samples): + modality_mask[line, :sample['audio_length']] = 1 + + if self.inference_mode: + keys = [s['key'] for s in samples] + targets = [s['target'] for s in samples] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mask": audio_mask if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, + 
"modality_mask": modality_mask, + "keys": keys, + "targets": targets + } + + labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX) + for s in samples]) + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mask": audio_mask if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, + "modality_mask": modality_mask + } + +class MultiTaskDynamicBatchDataset(IterableDataset): + def __init__(self, dataset: IterableDataset, window_class) -> None: + super().__init__() + self.dp = dataset + + assert window_class is not None + self.window_class = window_class + self.collator = self.dp.collator + self._buffer = [] + def __iter__(self): + for elem in self.dp: + if not self.window_class(elem, self._buffer): + self._buffer.append(elem) + else: + if len(self._buffer) > 0: + yield self._buffer + del self._buffer + self._buffer = [elem] + if len(self._buffer) > 0: + yield self._buffer + del self._buffer + self._buffer = [] + + +def window_class(elem,buffer,max_frame_length): + if len(buffer) == 0: + return True + max_frame = max(len(elem["input_ids"]),max([ len(_["input_ids"]) for _ in buffer])) + return (len(buffer) + 1) * max_frame > max_frame_length + +def get_speech_dataset(dataset_config, tokenizer, split): + dataset = MultiTaskDataset(dataset_config, tokenizer, split) + if split == "train": + dataset = MultiTaskDynamicBatchDataset(dataset,partial(window_class,max_frame_length = dataset_config.train_max_frame_length)) + else: + dataset = MultiTaskDynamicBatchDataset(dataset,partial(window_class,max_frame_length = dataset_config.eval_max_frame_length)) + return dataset + + + + diff --git a/src/slam_llm/pipeline/finetune.py b/src/slam_llm/pipeline/finetune.py index 4ced3c51..dce517a8 100644 --- a/src/slam_llm/pipeline/finetune.py +++ b/src/slam_llm/pipeline/finetune.py @@ -197,14 +197,14 @@ def main(kwargs: DictConfig): dataset_config, split="train", ) - if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: + if (not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0) and train_config.batching_strategy != "dynamic": logger.info(f"--> Training Set Length = {len(dataset_train)}") dataset_val = get_preprocessed_dataset( tokenizer, dataset_config, split="val", ) - if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: + if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0 and train_config.batching_strategy != "dynamic": logger.info(f"--> Validation Set Length = {len(dataset_val)}") if train_config.batching_strategy == "packing": dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length) diff --git a/src/slam_llm/pipeline/finetune_deepspeed.py b/src/slam_llm/pipeline/finetune_deepspeed.py index 8f275faf..bb1b35c7 100644 --- a/src/slam_llm/pipeline/finetune_deepspeed.py +++ b/src/slam_llm/pipeline/finetune_deepspeed.py @@ -189,14 +189,14 @@ def main(kwargs: DictConfig): dataset_config, split="train", ) - if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: + if (not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0) and train_config.batching_strategy != "dynamic": logger.info(f"--> Training Set Length = {len(dataset_train)}") dataset_val = get_preprocessed_dataset( tokenizer, dataset_config, split="val", 
     )
-    if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0:
+    if (not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0) and train_config.batching_strategy != "dynamic":
         logger.info(f"--> Validation Set Length = {len(dataset_val)}")
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
diff --git a/src/slam_llm/pipeline/inference_batch.py b/src/slam_llm/pipeline/inference_batch.py
index 1dc03940..a405d638 100644
--- a/src/slam_llm/pipeline/inference_batch.py
+++ b/src/slam_llm/pipeline/inference_batch.py
@@ -109,7 +109,7 @@ def main(kwargs: DictConfig):
         dataset_config,
         split="test",
     )
-    if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0:
+    if (not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0) and train_config.batching_strategy != "dynamic":
         logger.info(f"--> Training Set Length = {len(dataset_test)}")

     test_dataloader = torch.utils.data.DataLoader(
@@ -127,7 +127,7 @@ def main(kwargs: DictConfig):
     pred_path = kwargs.get('decode_log') + "_pred"
     gt_path = kwargs.get('decode_log') + "_gt"
     with open(pred_path, "w") as pred, open(gt_path, "w") as gt:
-        for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
+        for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader) if train_config.batching_strategy != "dynamic" else None):
             for key in batch.keys():
                 batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
             model_outputs = model.generate(**batch)
diff --git a/src/slam_llm/pipeline/inference_batch_deepspeed.py b/src/slam_llm/pipeline/inference_batch_deepspeed.py
new file mode 100644
index 00000000..18bcd2ad
--- /dev/null
+++ b/src/slam_llm/pipeline/inference_batch_deepspeed.py
@@ -0,0 +1,192 @@
+# os
+import os
+import fire
+import deepspeed
+import random
+import importlib
+from tqdm import tqdm
+# nn
+import torch
+import torch_npu
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from torch.utils.data import DistributedSampler
+# opt
+import torch.optim as optim
+from torch.optim.lr_scheduler import StepLR
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+)
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
+from slam_llm.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
+import torch.distributed as dist
+# config
+# from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+# from llama_recipes.configs import train_config as TRAIN_CONFIG
+# from llama_recipes.configs import model_config as MODEL_CONFIG
+# from llama_recipes.configs import log_config as LOG_CONFIG
+from slam_llm.data.concatenator import ConcatDataset
+
+# util
+from slam_llm.utils import fsdp_auto_wrap_policy
+from slam_llm.utils.config_utils import get_dataloader_kwargs
+
+from slam_llm.utils.dataset_utils import get_preprocessed_dataset, load_module_from_py_file
+from slam_llm.utils.model_utils import get_custom_model_factory
+from slam_llm.utils.deepspeed_utils import (
+    train,
+    freeze_transformer_layers,
+    setup,
+    setup_environ_flags,
+    clear_gpu_cache,
+)
+
+import sys
+import logging
+import wandb
+
+import hydra
+from omegaconf import DictConfig, ListConfig, OmegaConf
+from pathlib import Path
+
+
+@hydra.main(config_name=None, version_base=None)  # strict=False would allow unknown arguments to be ignored
+def main_hydra(cfg: DictConfig):
+    def to_plain_list(cfg_item):
+        if isinstance(cfg_item, ListConfig):
+            return OmegaConf.to_container(cfg_item, resolve=True)
+        elif isinstance(cfg_item, DictConfig):
+            return {k: to_plain_list(v) for k, v in cfg_item.items()}
+        else:
+            return cfg_item
+
+    # kwargs = to_plain_list(cfg)
+    kwargs = cfg
+    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+
+    logging.basicConfig(level=log_level)
+
+    if kwargs.get("debug", False):
+        import pdb
+        pdb.set_trace()
+
+    main(kwargs)
+
+
+def main(kwargs: DictConfig):
+    # Update the configuration for the training and sharding process
+    # train_config, fsdp_config, model_config, log_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG(), LOG_CONFIG()
+    # update_config((train_config, fsdp_config, model_config, log_config), **kwargs)
+    train_config, model_config, log_config, dataset_config = kwargs.train_config, \
+                                                             kwargs.model_config, \
+                                                             kwargs.log_config, \
+                                                             kwargs.dataset_config
+    del kwargs.train_config
+    del kwargs.model_config
+    del kwargs.log_config
+    del kwargs.dataset_config
+
+    # Set log
+    if not os.path.exists(os.path.dirname(log_config.log_file)):
+        os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True)
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        filemode='w'
+    )
+
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+
+    file_handler = logging.FileHandler(filename=log_config.log_file, mode='w')
+    file_handler.setLevel(logging.INFO)
+    file_formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+    file_handler.setFormatter(file_formatter)
+
+    logger.handlers[0].setLevel(logging.INFO)
+    console_formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+    logger.handlers[0].setFormatter(console_formatter)
+
+    logger.addHandler(file_handler)
+
+    # Set the seeds for reproducibility
+    torch_npu.npu.manual_seed(train_config.seed)
+    torch.manual_seed(train_config.seed)
+    random.seed(train_config.seed)
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    logger.info(f"local_rank: {local_rank}, rank: {rank}, world_size: {world_size}")
+
+    deepspeed.init_distributed(
+        dist_backend='nccl',  # use the NCCL backend (GPU scenario)
+    )
+
+    if rank == 0:
+        logger.info("train_config: {}".format(train_config))
+        logger.info("model_config: {}".format(model_config))
+        logger.info("log_config: {}".format(log_config))
+
+    # Set wandb
+    if rank == 0:
+        if log_config.use_wandb:
+            if not os.path.exists(log_config.wandb_dir):
+                os.makedirs(log_config.wandb_dir, exist_ok=True)
+            wandb_config = {"train_config": train_config, "model_config": model_config, "log_config": log_config}
+            wandb.init(dir=log_config.wandb_dir, entity=log_config.wandb_entity_name, project=log_config.wandb_project_name, name=log_config.wandb_exp_name, config=wandb_config)
+
+    model_factory = get_custom_model_factory(model_config, logger)
+    model, tokenizer = model_factory(train_config, model_config, **kwargs)
+    device = torch.device(f"npu:{local_rank}" if torch.npu.is_available() else "cpu")  # FIX(MZY): put the whole model to device.
+ model.to(device) + model.eval() + logger.info("dataset_config: {}".format(dataset_config)) + dataset_test = get_preprocessed_dataset( + tokenizer, + dataset_config, + split="test", + ) + # sampler = DistributedSampler( + # dataset_test, + # rank=dist.get_rank(), + # num_replicas=dist.get_world_size(), + # ) + test_dataloader = torch.utils.data.DataLoader( + dataset_test, + num_workers=train_config.num_workers_dataloader, + pin_memory=True, + shuffle=False, + batch_size=train_config.val_batch_size, + drop_last=False, + collate_fn=dataset_test.collator, + # sampler=sampler + # multiprocessing_context=mp.get_context("spawn") + ) + + logger.info("=====================================") + pred_path = kwargs.get('decode_log') + f"_pred" + gt_path = kwargs.get('decode_log') + f"_gt" + pred_result = "" + gt_result = "" + with torch.no_grad(): + for step, batch in tqdm(enumerate(test_dataloader)): + for key in batch.keys(): + batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key] + model_outputs = model.generate(**batch) + if hasattr(model, 'tokenizer'): + output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True) + else: + output_text = tokenizer.batch_decode(model_outputs, skip_special_tokens=True) + for key, text, target in zip(batch["keys"], output_text, batch["targets"]): + pred_result += key + " " + text.strip() + "\n" + gt_result += key + " " + target + "\n" + with open(pred_path, "a+") as pred, open(gt_path, "a+") as gt: + pred.write(pred_result) + gt.write(gt_result) +if __name__ == "__main__": + main_hydra() \ No newline at end of file diff --git a/src/slam_llm/utils/config_utils.py b/src/slam_llm/utils/config_utils.py index b0aadf7d..adbce533 100644 --- a/src/slam_llm/utils/config_utils.py +++ b/src/slam_llm/utils/config_utils.py @@ -91,6 +91,12 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode): kwargs["batch_size"] = batch_size kwargs["drop_last"] = True kwargs["collate_fn"] = default_data_collator + elif train_config.batching_strategy == "dynamic": + kwargs["sampler"] = None + kwargs["batch_size"] = None + kwargs["drop_last"] = False + kwargs["collate_fn"] = dataset.collator + logger.info(f"Using batching strategy: {train_config.batching_strategy}") else: # raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}") if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed: diff --git a/src/slam_llm/utils/deepspeed_utils.py b/src/slam_llm/utils/deepspeed_utils.py index 6ac5608d..75903fc4 100644 --- a/src/slam_llm/utils/deepspeed_utils.py +++ b/src/slam_llm/utils/deepspeed_utils.py @@ -7,7 +7,7 @@ from contextlib import nullcontext from pathlib import Path from pkg_resources import packaging - +import datetime import functools import hydra @@ -107,7 +107,28 @@ def decorated_main(cfg_passthrough: Optional[DictConfig] = None) -> Any: return main_decorator - +def deepspeed_join(group_join): + """ + Copy from wenet:https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/executor.py#L64 + """ + try: + # NOTE(xcsong): Why we need a new group? + # Because Deepspeed has its own group where all the relevant communication + # operations are executed. If we add a communication operation that is not + # managed by Deepspeed in this group, it's highly likely to cause + # communication chaos, resulting in hard-to-troubleshoot hangs. 
+ dist.monitored_barrier(group=group_join, + timeout=group_join.options._timeout) + except RuntimeError as e: + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + logging.info("Detected uneven workload distribution: {}\n".format(e) + + "Break current worker to manually join all workers, " + + "world_size {}, current rank {}, current local_rank {}\n". + format(world_size, rank, local_rank)) + return True + return False def set_tokenizer_params(tokenizer: LlamaTokenizer): tokenizer.pad_token_id = 0 @@ -169,19 +190,22 @@ def train( best_val_loss = float("inf") best_val_acc = 0.0 for epoch in range(train_config.num_epochs): + dist.barrier() + group_join = dist.new_group( + backend="gloo", timeout=datetime.timedelta(seconds=3)) epoch_start_time = time.perf_counter() with MemoryTrace() as memtrace: # track the memory usage model.train() total_loss = 0.0 total_acc = 0.0 - total_length = len(train_dataloader) // gradient_accumulation_steps - pbar = tqdm( - colour="blue", - desc=f"Training Epoch: {epoch+1}", - total=total_length, - dynamic_ncols=True, - ) + if train_config.batching_strategy != "dynamic": + total_length = len(train_dataloader)//gradient_accumulation_steps + pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True) + else: + pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", dynamic_ncols=True) for step, batch in enumerate(train_dataloader): + if train_config.batching_strategy == "dynamic" and deepspeed_join(group_join): + break for key in batch.keys(): batch[key] = ( batch[key].to(local_rank).half() @@ -193,8 +217,8 @@ def train( else batch[key] ) ) - # with autocast(): - outputs, *rest = model(**batch) + with autocast(): + outputs, *rest = model(**batch) acc = rest[0] if rest else -1 loss = outputs.loss @@ -209,7 +233,7 @@ def train( "train_inner/train_inner_loss": loss, "train_inner/train_inner_accuracy": acc, }, - step=(epoch * total_length + step), + step=(epoch * total_length + step) if train_config.batching_strategy != "dynamic" else step + 1, ) else: wandb.log( @@ -217,7 +241,7 @@ def train( "train_inner/train_inner_loss": loss, "train_inner/train_inner_accuracy": acc, }, - step=(epoch * total_length + step), + step=(epoch * total_length + step) if train_config.batching_strategy != "dynamic" else step + 1, ) total_loss += loss.detach().float() @@ -227,17 +251,17 @@ def train( model.backward(loss) model.step() - if (step + 1) % gradient_accumulation_steps == 0 or step == len( + if (step + 1) % gradient_accumulation_steps == 0 or ( train_config.batching_strategy != "dynamic" and step == len( train_dataloader - ) - 1: + ) - 1): pbar.update(1) pbar.set_description( - f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()}, acc: {acc})" + f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader) if train_config.batching_strategy != 'dynamic' else ''} completed (loss: {loss.detach().float()}, acc: {acc})" ) if ( - (epoch * total_length + step + 1) % train_config.validation_interval + (epoch * total_length + step + 1 if train_config.batching_strategy != "dynamic" else step + 1) % train_config.validation_interval == 0 and train_config.run_validation ): @@ -302,6 +326,7 @@ def train( logger.info("=====================================") dist.barrier() pbar.close() + dist.destroy_process_group(group_join) epoch_end_time = time.perf_counter() - 
epoch_start_time epoch_times.append(epoch_end_time) @@ -311,8 +336,8 @@ def train( ): dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) dist.all_reduce(total_acc, op=dist.ReduceOp.SUM) - train_epoch_loss = total_loss / len(train_dataloader) - train_epoch_acc = total_acc / len(train_dataloader) + train_epoch_loss = total_loss / (len(train_dataloader) if train_config.batching_strategy != "dynamic" else (step + 1)* train_config.num_epochs) + train_epoch_acc = total_acc / (len(train_dataloader) if train_config.batching_strategy != "dynamic" else (step + 1)* train_config.num_epochs) if train_config.enable_fsdp or train_config.enable_ddp: train_epoch_loss = train_epoch_loss / world_size train_epoch_acc = train_epoch_acc / world_size @@ -411,13 +436,11 @@ def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer): ) # (Fix:MZY): fix expected scalar type mismatch in norm with MemoryTrace() as memtrace: - total_length = len(eval_dataloader) - pbar = tqdm( - colour="green", - desc=f"Evaluating Epoch", - total=total_length, - dynamic_ncols=True, - ) + if train_config.batching_strategy != "dynamic": + total_length = len(eval_dataloader) + pbar = tqdm(colour="green", desc=f"Evaluating Epoch", total=total_length, dynamic_ncols=True) + else: + pbar = tqdm(colour="green", desc=f"Evaluating Epoch", dynamic_ncols=True) for step, batch in enumerate(eval_dataloader): for key in batch.keys(): batch[key] = ( @@ -446,7 +469,7 @@ def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer): ) pbar.update(1) pbar.set_description( - f"step: {step+1}/{total_length}, eval_loss: {eval_loss/(step+1):.4f}, eval_acc: {eval_acc/(step+1):.4f}" + f"step: {step+1}/{total_length if train_config.batching_strategy != 'dynamic' else '' }, eval_loss: {eval_loss/(step+1):.4f}, eval_acc: {eval_acc/(step+1):.4f}" ) # If there's more than one CUDA device, reduce evaluation loss across all devices @@ -457,8 +480,8 @@ def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer): dist.all_reduce(eval_acc, op=dist.ReduceOp.SUM) # Compute average loss and perplexity - eval_epoch_loss = eval_loss / len(eval_dataloader) - eval_epoch_acc = eval_acc / len(eval_dataloader) + eval_epoch_loss = eval_loss / (len(eval_dataloader) if train_config.batching_strategy != "dynamic" else step + 1) + eval_epoch_acc = eval_acc / (len(eval_dataloader) if train_config.batching_strategy != "dynamic" else step + 1) eval_epoch_loss = eval_epoch_loss / world_size eval_epoch_acc = eval_epoch_acc / world_size eval_ppl = torch.exp(eval_epoch_loss) diff --git a/src/slam_llm/utils/train_utils.py b/src/slam_llm/utils/train_utils.py index 8f5c34e6..621306b5 100644 --- a/src/slam_llm/utils/train_utils.py +++ b/src/slam_llm/utils/train_utils.py @@ -7,7 +7,7 @@ from contextlib import nullcontext from pathlib import Path from pkg_resources import packaging - +from torch.distributed.algorithms.join import Join import torch import torch.cuda.nccl as nccl @@ -88,12 +88,15 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche best_val_acc = 0.0 for epoch in range(train_config.num_epochs): epoch_start_time = time.perf_counter() - with MemoryTrace() as memtrace: # track the memory usage + with MemoryTrace() as memtrace,Join([model,optimizer]): # track the memory usage model.train() total_loss = 0.0 total_acc = 0.0 - total_length = len(train_dataloader)//gradient_accumulation_steps - pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True) + if 
train_config.batching_strategy != "dynamic": + total_length = len(train_dataloader)//gradient_accumulation_steps + pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True) + else: + pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", dynamic_ncols=True) for step, batch in enumerate(train_dataloader): for key in batch.keys(): if train_config.enable_fsdp or train_config.enable_ddp: @@ -117,16 +120,16 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if log_config.use_wandb and step % log_config.log_interval == 0: if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: - wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step)) + wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step) if train_config.batching_strategy != "dynamic" else step + 1) else: - wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step)) - + wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step) if train_config.batching_strategy != "dynamic" else step + 1) total_loss += loss.detach().float() total_acc += acc if train_config.use_fp16: # if fp16 is enabled, use gradient scaler to handle gradient update scaler.scale(loss).backward() - if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + + if (step + 1) % gradient_accumulation_steps == 0 or (train_config.batching_strategy != "dynamic" and step == len(train_dataloader) - 1): scaler.step(optimizer) scaler.update() if lr_scheduler is not None: @@ -139,15 +142,15 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if log_config.use_wandb and step % log_config.log_interval == 0: if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: - wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step) if train_config.batching_strategy != "dynamic" else step + 1) else: - wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step) if train_config.batching_strategy != "dynamic" else step + 1) optimizer.zero_grad() pbar.update(1) else: # regular backpropagation when fp16 is not used loss.backward() - if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + if (step + 1) % gradient_accumulation_steps == 0 or ( train_config.batching_strategy != "dynamic" and step == len(train_dataloader) - 1): optimizer.step() if lr_scheduler is not None: lr_scheduler.step() @@ -159,15 +162,15 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if log_config.use_wandb and step % log_config.log_interval == 0: if train_config.enable_fsdp or train_config.enable_ddp: if rank==0: - wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step) if train_config.batching_strategy != "dynamic" else step + 1) else: - wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step) if train_config.batching_strategy != "dynamic" else step + 1) 
optimizer.zero_grad() pbar.update(1) - pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()}, acc: {acc})") + pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader) if train_config.batching_strategy != 'dynamic' else ''} completed (loss: {loss.detach().float()}, acc: {acc})") - if (epoch * total_length + step + 1) % train_config.validation_interval == 0 and train_config.run_validation: + if (epoch * total_length + step + 1 if train_config.batching_strategy != "dynamic" else step + 1) % train_config.validation_interval == 0 and train_config.run_validation: eval_ppl, eval_epoch_loss, *rest = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer) eval_epoch_acc = rest[0] if rest else -1 checkpoint_start_time = time.perf_counter() @@ -323,8 +326,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche if torch.cuda.device_count() > 1 and (train_config.enable_fsdp or train_config.enable_ddp): dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) dist.all_reduce(total_acc, op=dist.ReduceOp.SUM) - train_epoch_loss = total_loss / len(train_dataloader) - train_epoch_acc = total_acc / len(train_dataloader) + train_epoch_loss = total_loss / (len(train_dataloader) if train_config.batching_strategy != "dynamic" else (step + 1) *train_config.num_epochs) + train_epoch_acc = total_acc / (len(train_dataloader) if train_config.batching_strategy != "dynamic" else (step + 1) *train_config.num_epochs) if train_config.enable_fsdp or train_config.enable_ddp: train_epoch_loss = train_epoch_loss/world_size train_epoch_acc = train_epoch_acc/world_size @@ -411,8 +414,11 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext # (Fix:MZY): fix expected scalar type mismatch in norm with MemoryTrace() as memtrace: - total_length = len(eval_dataloader) - pbar = tqdm(colour="green", desc=f"Evaluating Epoch", total=total_length, dynamic_ncols=True) + if train_config.batching_strategy != "dynamic": + total_length = len(eval_dataloader) + pbar = tqdm(colour="green", desc=f"Evaluating Epoch", total=total_length, dynamic_ncols=True) + else: + pbar = tqdm(colour="green", desc=f"Evaluating Epoch", dynamic_ncols=True) for step, batch in enumerate(eval_dataloader): for key in batch.keys(): if train_config.enable_fsdp or train_config.enable_ddp: @@ -438,7 +444,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): except Exception: pass # vallex does not need to show it's result (we can't view any thing from abstract acoustic token) pbar.update(1) - pbar.set_description(f"step: {step+1}/{total_length}, eval_loss: {eval_loss/(step+1):.4f}, eval_acc: {eval_acc/(step+1):.4f}") + pbar.set_description(f"step: {step+1}/{total_length if train_config.batching_strategy != 'dynamic' else '' }, eval_loss: {eval_loss/(step+1):.4f}, eval_acc: {eval_acc/(step+1):.4f}") # If there's more than one CUDA device, reduce evaluation loss across all devices if torch.cuda.device_count() > 1 and train_config.enable_fsdp or train_config.enable_ddp: @@ -446,8 +452,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): dist.all_reduce(eval_acc, op=dist.ReduceOp.SUM) # Compute average loss and perplexity - eval_epoch_loss = eval_loss / len(eval_dataloader) - eval_epoch_acc = eval_acc / len(eval_dataloader) + 
eval_epoch_loss = eval_loss / (len(eval_dataloader) if train_config.batching_strategy != "dynamic" else step + 1) + eval_epoch_acc = eval_acc / (len(eval_dataloader) if train_config.batching_strategy != "dynamic" else step + 1) if train_config.enable_fsdp or train_config.enable_ddp: eval_epoch_loss = eval_epoch_loss/world_size eval_epoch_acc = eval_epoch_acc/world_size
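
The `dynamic` batching path wired up above (in `config_utils.py`, `deepspeed_utils.py`, and `train_utils.py`) assumes the dataloader does no batching of its own: `MultiTaskDynamicBatchDataset` already yields lists of samples whose padded frame count stays below `max_frame_length`, and the dataset's `collator` turns each list into one padded batch. Below is a minimal usage sketch, not part of the patch, assuming `dataset_config` and `tokenizer` are constructed as in the finetune pipelines and that `get_speech_dataset` is importable from the dataset module added in this example:

```python
import torch

# Assumed to exist already, as in the finetune pipelines: dataset_config, tokenizer.
# get_speech_dataset wraps MultiTaskDataset in MultiTaskDynamicBatchDataset for the train split.
dataset_train = get_speech_dataset(dataset_config, tokenizer, split="train")

# Mirrors the "dynamic" branch added to get_dataloader_kwargs:
# batch_size=None disables automatic batching, so every list the iterable dataset
# yields is handed straight to the dataset's own collator as a single batch.
train_dataloader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=None,                    # each yielded list is already one batch
    sampler=None,
    drop_last=False,
    collate_fn=dataset_train.collator,  # pads input_ids / audio features within the list
)

for batch in train_dataloader:
    # batch is the dict produced by the collator: input_ids, attention_mask,
    # audio or audio_mel (depending on input_type), modality_mask, and labels in training.
    break
```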