88 changes: 88 additions & 0 deletions docs/source/BestPractices/GPTQ量化模型GRPO训练.md
@@ -0,0 +1,88 @@
# GRPO Training of a GPTQ-Quantized Model in Colocate Mode

## 1. The Problem and Possible Solutions

Known issue: with vLLM acceleration, the current code merges the LoRA adapter into the base model before updating the parameters of the model served by vLLM, but LoRA cannot be merged into a GPTQ-quantized model.

In practice: with vLLM acceleration, a quantized model fails when the model is moved to vLLM (`_move_model_to_vllm`), raising `AttributeError: 'GPTQLoraLinear' object has no attribute 'get_delta_weight'`, the same error as https://github.com/modelscope/ms-swift/issues/3949.
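The failure is inherent to how LoRA merging works: merging needs a dense floating-point base weight to add the LoRA delta onto, while a GPTQ layer only stores packed integer weights. Below is a minimal sketch of the merge step, simplified from peft's `LoraLayer` (the real code also handles scaling conventions, dtype casts, and `fan_in_fan_out`):

```python
import torch

def merge_lora(base_weight: torch.Tensor, lora_A: torch.Tensor,
               lora_B: torch.Tensor, scaling: float) -> torch.Tensor:
    # get_delta_weight: the dense update contributed by the adapter
    delta = (lora_B @ lora_A) * scaling
    # merging adds the delta onto a dense fp weight; a GPTQ linear has no
    # such `weight` (only packed int4 `qweight` plus per-group scales/zeros),
    # so GPTQLoraLinear cannot implement get_delta_weight and merging fails
    return base_weight + delta
```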

With the current framework, training is only possible without vLLM inference acceleration, which is very slow (this option is not considered further).

There are two ways to address this problem:

- Option 1: modify ms-swift so that `_move_model_to_vllm` saves the LoRA parameters to a local path at every step, and passes them to the LLM engine via the adapter-request argument when calling it (see the sketch after this list).

- Option 2: dequantize the GPTQ-int4 model, train on top of the dequantized weights, save the LoRA adapter, and use the quantized base model for final deployment.
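Option 1 relies on the fact that vLLM can attach a LoRA adapter per request without modifying the base weights. A minimal sketch of that mechanism using vLLM's public `LoRARequest` API (ms-swift wraps this in its own `AdapterRequest`; the model and adapter paths are the placeholders from the script below):

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# LoRA support must be enabled when the engine is created (mirrored by the
# enable_lora flag added to prepare_vllm in the patch below)
llm = LLM(model='/xxx/deepseek-r1-distill-qwen-32b-gptq-int4', enable_lora=True)

# the adapter saved at the current training step is attached to the request;
# the packed GPTQ base weights are never touched
outputs = llm.generate(
    ['<prompt>'],
    SamplingParams(temperature=0.7, max_tokens=512),
    lora_request=LoRARequest('lora_step_1', 1, '/xxx/tmp_path_for_lora'),
)
```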

## 2. Option 2

For Option 2, we first tested whether ms-swift supports LoRA-mode GRPO on the non-quantized 32B model. Findings:
- vLLM in server mode does not work: updating the parameters of the model served by vLLM fails with a communication timeout, the same error as https://github.com/modelscope/ms-swift/issues/4797.
- Colocate mode works.

We have not yet produced correct GPTQ dequantization code, so Option 2 stops here for now.
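For reference, the dequantization itself is mathematically simple; the hard part is reproducing AutoGPTQ's exact bit packing and tensor layout. A conceptual sketch of the per-group formula only (assuming group-wise asymmetric int4 quantization; not a working converter):

```python
import torch

def dequantize_group(q: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
    """Recover fp weights for one quantization group: w = scale * (q - zero).

    q:     integer weights for one group, already unpacked from qweight
    scale: per-output-channel fp16 scales for this group
    zero:  per-output-channel integer zero points for this group
    """
    return scale * (q.float() - zero.float())
```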

## 3. Option 1

For Option 1, we modified the ms-swift code as described above; the changes passed testing and the experiment was completed.

### 3.1 Example Script

```bash
MASTER_PORT=29502 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
NPROC_PER_NODE=4 \
swift rlhf \
--rlhf_type grpo \
--model /xxx/deepseek-r1-distill-qwen-32b-gptq-int4 \
--external_plugins examples/train/grpo/plugin/plugin.py \
--reward_funcs external_xxx_accuracy external_xxx_format external_xxx_len \
--reward_weights 1.0 1.0 1.0 \
--vllm_mode colocate \
--use_vllm true \
--vllm_gpu_memory_utilization 0.4 \
--vllm_tensor_parallel_size 4 \
--torch_dtype bfloat16 \
--dataset 'xxx/xxx.json' \
--max_completion_length 5120 \
--num_train_epochs 5 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-6 \
--gradient_accumulation_steps 4 \
--eval_steps 50 \
--save_steps 50 \
--save_total_limit 10 \
--logging_steps 5 \
--max_length 16000 \
--train_type lora \
--lora_rank 8 \
--lora_alpha 16 \
--target_modules all-linear \
--resume_only_model \
--resume_from_checkpoint /xxx/checkpoint-xxx \
--output_dir /xxx/xxx \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--num_generations 16 \
--temperature 0.7 \
--top_p 1.0 \
--top_k 80 \
--log_completions true \
--report_to tensorboard \
--model_type deepseek_r1_distill \
--async_generate false \
--deepspeed zero3 \
--sleep_level 1 \
--max_steps 1500 \
--vllm_max_model_len 30000 \
--local_adapter_path /xxx/tmp_path_for_lora
```

### 3.2 Notes

- The `move_model_batches` argument cannot be used here, i.e. the LoRA parameters cannot be passed to vLLM in batches; otherwise training fails with `[rank0]: IndexError: too many indices for tensor of dimension 1`.

- When resuming training, e.g. continuing with GRPO on a LoRA previously trained via SFT: if the earlier run used DeepSpeed ZeRO-3, the same `zero3` stage must be used here. The usually recommended `zero3_offload`, `offload_optimizer true`, and `offload_model true` settings cannot be used; otherwise training fails with `[rank0]: KeyError: 'bias_correction'`.

- If you run out of GPU memory, lower values such as `vllm_gpu_memory_utilization` and `vllm_max_model_len`.
1 change: 1 addition & 0 deletions swift/llm/argument/rlhf_args.py
@@ -107,6 +107,7 @@ class RLHFArguments(TeacherModelArguments, GRPOArguments, PPOArguments, RewardMo
undesirable_weight: float = 1.0
# PPO/GRPO/GKD
temperature: float = 0.9
local_adapter_path: Optional[str] = None
# RM
center_rewards_coefficient: Optional[float] = None
# GKD
1 change: 1 addition & 0 deletions swift/trainers/arguments.py
@@ -247,6 +247,7 @@ def get_vllm_engine_kwargs(self):

@dataclass
class GRPOArgumentsMixin(VllmArguments):
local_adapter_path: Optional[str] = None
epsilon: float = 0.2
epsilon_high: Optional[float] = None
delta: Optional[float] = None
87 changes: 58 additions & 29 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -18,6 +18,9 @@
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

import json
import shutil

import torch
import torch.nn as nn

from swift.llm.infer.infer_engine.utils import AdapterRequest
@@ -107,6 +110,8 @@ def __init__(self,
from swift.trainers.rlhf_arguments import GRPOConfig
args: GRPOConfig = kwargs['args']
self.args = args
self.local_adapter_path = getattr(args, 'local_adapter_path', None)
self.enable_lora = bool(self.local_adapter_path)
self.ref_adapter_name = getattr(args, 'ref_adapter_name', None)
self.model_adapter_name = None
# for async generate
@@ -529,6 +534,7 @@ def prepare_vllm(self, model):
max_model_len=self.args.vllm_max_model_len,
seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
template=self.template,
enable_lora=self.enable_lora,
distributed_executor_backend='external_launcher',
)
return engine
@@ -568,34 +574,47 @@ def _move_model_to_vllm(self, skip_async_check=False):
parameter for name, parameter in self.model.named_parameters()
if not parameter_group or name in parameter_group
]
with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group):
self.model.merge_adapter()
state_dict = self.model.state_dict()
state_dict = {
k.removeprefix('base_model.model.').replace('.base_layer', ''): v
for k, v in state_dict.items()
}
state_dict = {k: v for k, v in state_dict.items() if self.model.prefix not in k}
# When module to save, remove its prefix and discard the original module
state_dict = {
k.replace('modules_to_save.default.', ''): v
for k, v in state_dict.items() if 'original_module' not in k
}
if parameter_group_no_lora:
parameter_group_no_lora = [n.replace('base_model.model.', '') for n in parameter_group_no_lora]
state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora}
assert len(state_dict) > 0 and all(
[state.shape != torch.Size([0]) for state in state_dict.values()])

if self.vllm_mode == 'server' and self.accelerator.is_main_process:
for name, param in state_dict.items():
self.vllm_client.update_named_param(name, param)
elif self.vllm_mode == 'colocate':
llm_model = self.engine.inner_model
llm_model.load_weights(state_dict.items())
with patch_lora_unmerge(self.model):
self.model.unmerge_adapter()
del state_dict
# save the LoRA adapter to the local adapter path; it is attached at
# inference time via an adapter request instead of being merged
if self.local_adapter_path:
with gather_if_zero3(parameters):
if self.accelerator.is_main_process:
if os.path.exists(self.local_adapter_path):
# delete the adapter saved at the previous step
shutil.rmtree(self.local_adapter_path)
logger.info(f'step: {self.state.global_step}, deleted previous LoRA adapter')

os.makedirs(self.local_adapter_path)
self.model.save_pretrained(self.local_adapter_path, peft_format=True)
logger.info(f'step: {self.state.global_step}, saved latest LoRA adapter to the local adapter path')
else:
with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group):
self.model.merge_adapter()
state_dict = self.model.state_dict()
state_dict = {
k.removeprefix('base_model.model.').replace('.base_layer', ''): v
for k, v in state_dict.items()
}
state_dict = {k: v for k, v in state_dict.items() if self.model.prefix not in k}
# When module to save, remove its prefix and discard the original module
state_dict = {
k.replace('modules_to_save.default.', ''): v
for k, v in state_dict.items() if 'original_module' not in k
}
if parameter_group_no_lora:
parameter_group_no_lora = [n.replace('base_model.model.', '') for n in parameter_group_no_lora]
state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora}
assert len(state_dict) > 0 and all(
[state.shape != torch.Size([0]) for state in state_dict.values()])

if self.vllm_mode == 'server' and self.accelerator.is_main_process:
for name, param in state_dict.items():
self.vllm_client.update_named_param(name, param)
elif self.vllm_mode == 'colocate':
llm_model = self.engine.inner_model
llm_model.load_weights(state_dict.items())
with patch_lora_unmerge(self.model):
self.model.unmerge_adapter()
del state_dict
else:
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
@@ -1949,7 +1968,17 @@ def _engine_infer(
asdict(request_config),
use_tqdm=use_tqdm)
else:
res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm)
# attach the local LoRA adapter to the vLLM engine via an adapter request
if self.local_adapter_path:
if not os.path.exists(self.local_adapter_path):
raise FileNotFoundError(f'local adapter path not found: {self.local_adapter_path}')
adapter_name = f'lora_{self.state.global_step}'
adapter_request = AdapterRequest(adapter_name, self.local_adapter_path)
if self.accelerator.is_main_process:
logger.info(f'adapter_request info: {adapter_request}')
res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm, adapter_request=adapter_request)
else:
res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm)
if all(isinstance(r, RolloutOutput) for r in res):
return res
else: