Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,31 +268,22 @@ def _init_reward(self, critic_model_name_or_path):
# If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
zero_stage = 0

ds_config = get_eval_ds_config(offload=self.args.offload,
ds_config = get_eval_ds_config(offload=self.args.offload_reward_model,
dtype=self.args.dtype,
stage=zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
ds_config[
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
) * self.args.gradient_accumulation_steps

ds_eval_config = get_eval_ds_config(offload=False,
dtype=self.args.dtype,
stage=zero_stage)

# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
ds_eval_config[
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
ds_eval_config[
ds_config[
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
) * self.args.gradient_accumulation_steps

# Model
reward_model = create_critic_model(
model_name_or_path=critic_model_name_or_path,
tokenizer=self.tokenizer,
ds_config=ds_eval_config,
ds_config=ds_config,
num_padding_at_beginning=self.args.num_padding_at_beginning,
rlhf_training=True,
dropout=self.args.critic_dropout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ def parse_args():
'--offload_reference_model',
action='store_true',
help='Enable ZeRO Offload techniques for reference model')
parser.add_argument('--offload_reward_model',
action='store_true',
help='Enable ZeRO Offload techniques for reward model')
parser.add_argument(
'--actor_zero_stage',
type=int,
Expand Down
Loading