
Commit b68324c

[bugfix] fix grpo padding_free (#5965)
* fix padding_free & overlong filter
* fix
* fix importance_sampling_level
* fix loss_type grpo
* revert non-padding free for importance_sampling_level
* fix length logging
* fix logging metrics
1 parent 197a845 commit b68324c

File tree

1 file changed: +74 -35 lines changed


swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 74 additions & 35 deletions
@@ -1298,7 +1298,14 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:
         # Process labels and masks
         labels = batch_encoded_inputs.pop('labels')
         logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item()
-
+        extra_kwargs = {
+            'completion_mask':
+            labels[:, -logits_to_keep:] != -100,
+            'truncated_mask':
+            torch.tensor([b['is_truncated'] for b in batch], dtype=torch.bool, device=self.accelerator.device),
+            'logits_to_keep':
+            logits_to_keep,
+        }
         if self.template.padding_free:
             position_ids = batch_encoded_inputs.get('text_position_ids')
             if position_ids is None:
@@ -1308,21 +1315,16 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:
             lengths = torch.diff(
                 torch.cat([(position_ids == 0).nonzero(as_tuple=True)[0],
                            torch.tensor([len(position_ids)]).to(position_ids.device)]))
+            total_lengths = lengths.sum()
+            # The first sentence has its prompt portion removed due to logits_to_keep
+            lengths[0] = lengths[0] - (total_lengths - logits_to_keep)
+            extra_kwargs.update({'seq_lengths': lengths})
             advantages_stacked = torch.stack([data['advantages'] for data in batch])
-            all_advandages = torch.repeat_interleave(advantages_stacked, lengths)
+            all_advantages = torch.repeat_interleave(advantages_stacked, lengths)
         else:
-            all_advandages = torch.stack([data['advantages'] for data in batch])
-
-        batch_encoded_inputs.update({
-            'completion_mask':
-            labels[:, -logits_to_keep:] != -100,
-            'truncated_mask':
-            torch.tensor([b['is_truncated'] for b in batch], dtype=torch.bool),
-            'logits_to_keep':
-            logits_to_keep,
-            'advantages':
-            all_advandages
-        })
+            all_advantages = torch.stack([data['advantages'] for data in batch])
+        extra_kwargs.update({'advantages': all_advantages})
+        batch_encoded_inputs.update(extra_kwargs)

         with torch.no_grad():
             batch_encoded_inputs['old_per_token_logps'] = (
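As background for the hunk above, here is a minimal standalone sketch (tensor values invented purely for illustration, not code from the repository) of how the packed padding_free layout recovers per-sequence lengths from position_ids and broadcasts one advantage per sequence to its tokens:

import torch

# Two packed sequences of lengths 3 and 4; position_ids restart at 0 for each sequence.
position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3])

# Sequence boundaries are where position_ids == 0; diffs of the boundary indices give lengths.
starts = (position_ids == 0).nonzero(as_tuple=True)[0]
lengths = torch.diff(torch.cat([starts, torch.tensor([len(position_ids)])]))
print(lengths)  # tensor([3, 4])

# One advantage per sequence, repeated once per token of that sequence.
advantages = torch.tensor([0.5, -1.0])
per_token_advantages = torch.repeat_interleave(advantages, lengths)
print(per_token_advantages)  # [0.5, 0.5, 0.5, -1.0, -1.0, -1.0, -1.0]

The actual change additionally shortens lengths[0] by the prompt tokens that logits_to_keep drops from the first packed sequence, which this sketch leaves out.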
@@ -1344,7 +1346,10 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:
         # --- log completion lengths ---
         mode = 'train' if self.model.training else 'eval'
         device = self.accelerator.device
-        local_lengths = [inp['completion_mask'].sum(1).tolist() for inp in ga_batch_encoded_inputs]
+        if self.template.padding_free:
+            local_lengths = [inp['seq_lengths'].tolist() for inp in ga_batch_encoded_inputs]
+        else:
+            local_lengths = [inp['completion_mask'].sum(1).tolist() for inp in ga_batch_encoded_inputs]
         total_lengths = self._gather_and_flatten(local_lengths, dtype=torch.float32, device=device, flatten_level=1)

         self._metrics[mode]['completions/mean_length'].append(total_lengths.mean().item())
@@ -1405,7 +1410,7 @@ def _compute_loss(self, model, inputs):
         mode = 'train' if self.model.training else 'eval'

         # Check batch size and decide processing strategy
-        batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else len(inputs.get('completion_mask', []))
+        batch_size = inputs['seq_lengths'].shape[0] if self.template.padding_free else inputs['input_ids'].shape[0]
         expected_bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size

         should_chunk = self.dynamic_num_samples and any(gather_object([batch_size > expected_bs]))
@@ -1427,7 +1432,8 @@ def _compute_loss_and_metrics(self, model, inputs):

         completion_mask = inputs['completion_mask']
         truncated_mask = inputs['truncated_mask']
-
+        if self.template.padding_free:
+            lengths = inputs['seq_lengths']
         per_token_logps, entropies = self._get_per_token_logps_and_entropies(
             model, inputs, compute_entropy=self.compute_entropy)

@@ -1438,7 +1444,11 @@ def _compute_loss_and_metrics(self, model, inputs):
             # fill the padded token with NaN
             entropies = entropies.masked_fill(completion_mask == 0, float('nan'))
             if self.args.log_entropy:
-                per_completion_entropies_mean = torch.nanmean(entropies, dim=1)
+                if self.template.padding_free:
+                    entropy_list = torch.split(entropies, lengths.tolist())
+                    per_completion_entropies_mean = torch.stack([torch.nanmean(e) for e in entropy_list])
+                else:
+                    per_completion_entropies_mean = torch.nanmean(entropies, dim=1)
                 global_per_completion_entropies_mean = gather(per_completion_entropies_mean)
                 entropy_metrics = {
                     'entropy_logs': global_per_completion_entropies_mean.tolist(),
@@ -1458,7 +1468,11 @@ def _compute_loss_and_metrics(self, model, inputs):
             if all(truncated_mask):
                 logger.info('All completions are overlong and truncated, '
                             'resulting in NaN some values for some metrics (e.g., KL)')
-            truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask).to(completion_mask.device)
+            if self.template.padding_free:
+                truncated_mask = torch.repeat_interleave(truncated_mask, lengths).unsqueeze(0)
+                assert truncated_mask.shape == completion_mask.shape
+            else:
+                truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask)
             completion_mask = completion_mask & (~truncated_mask)

         # Compute the KL divergence between the model and the reference model
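The overlong filter above expands one boolean per completion into one boolean per packed token. A small self-contained sketch (shapes and values invented for illustration, not repository code):

import torch

lengths = torch.tensor([3, 4])                          # per-sequence token counts
completion_mask = torch.ones(1, 7, dtype=torch.bool)    # packed layout: (1, total_tokens)
truncated = torch.tensor([False, True])                 # the second completion was cut off

# One flag per sequence becomes one flag per token, keeping the (1, total_tokens) shape.
truncated_mask = torch.repeat_interleave(truncated, lengths).unsqueeze(0)
assert truncated_mask.shape == completion_mask.shape

# Tokens of truncated completions are dropped from the loss mask.
completion_mask = completion_mask & (~truncated_mask)
print(completion_mask.int())  # [[1, 1, 1, 0, 0, 0, 0]]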
@@ -1477,14 +1491,29 @@ def _compute_loss_and_metrics(self, model, inputs):
         log_ratio = per_token_logps - old_per_token_logps
         if self.importance_sampling_level == 'token':
             log_importance_weights = log_ratio
-        elif self.importance_sampling_level == 'sequence':
-            log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
-            log_importance_weights = log_importance_weights.unsqueeze(-1)
-        elif self.importance_sampling_level == 'sequence_token':
-            # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)]
-            seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
-            seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1)  # Stop gradient
-            log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
+        elif self.importance_sampling_level in ['sequence', 'sequence_token']:
+            if self.template.padding_free:
+                # split to batch, compute seq-level normalization
+                log_ratio_list = torch.split(log_ratio.squeeze(0), lengths.tolist())
+                mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist())
+                seq_weights = [(lr * m).sum() / m.sum().clamp(min=1.0) for lr, m in zip(log_ratio_list, mask_list)]
+                seq_level_log_weights = torch.stack(seq_weights).to(log_ratio.dtype).unsqueeze(-1)
+                if self.importance_sampling_level == 'sequence':
+                    log_importance_weights = seq_level_log_weights
+                else:
+                    seq_level_log_weight = seq_level_log_weights.detach()
+                    seq_level_log_weight = torch.repeat_interleave(seq_level_log_weight, lengths).unsqueeze(0)
+                    log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
+            else:
+                seq_level_log_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(
+                    min=1.0).unsqueeze(-1)
+                if self.importance_sampling_level == 'sequence':
+                    log_importance_weights = seq_level_log_weights
+                else:
+                    # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)]
+                    seq_level_log_weight = seq_level_log_weights.detach()
+                    log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
+
         else:
             raise ValueError(
                 f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
@@ -1509,17 +1538,26 @@ def _compute_loss_and_metrics(self, model, inputs):
             per_token_loss = per_token_loss + self.beta * per_token_kl

         if self.loss_type == 'grpo':
-            loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
+            if self.template.padding_free:
+                loss_list = torch.split(per_token_loss.squeeze(0), lengths.tolist())
+                mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist())
+                sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0)
+                               for loss, mask in zip(loss_list, mask_list)]
+                loss = torch.stack(sample_loss).mean()
+            else:
+                loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
         elif self.loss_type == 'bnpo':
             loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
         elif self.loss_type == 'dr_grpo':
-            loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
+            batch_size = lengths.shape[0] if self.template.padding_free else inputs['input_ids'].shape[0]
+            loss = (per_token_loss * completion_mask).sum() / (batch_size * self.max_completion_length)
         else:
             raise ValueError(f'Unknown loss type: {self.loss_type}')

         completion_token_count = completion_mask.sum().clamp(min=1.0)

         def masked_batch_mean(x):
+            # compute for token-level average
             if x.shape[1] == 1:  # when importance_sampling_level == "sequence"
                 return x.mean()
             else:
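In the padding_free branch the 'grpo' loss keeps its per-sample normalization by averaging within each packed sequence first and across sequences second, mirroring the padded formula ((loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean(). A minimal sketch (values invented for illustration, not repository code):

import torch

lengths = torch.tensor([3, 4])
per_token_loss = torch.randn(1, 7)     # packed layout: (1, total_tokens)
completion_mask = torch.ones(1, 7)

loss_list = torch.split(per_token_loss.squeeze(0), lengths.tolist())
mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist())

# Masked mean inside each sample, then a plain mean across samples.
sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0)
               for loss, mask in zip(loss_list, mask_list)]
loss = torch.stack(sample_loss).mean()
print(loss.shape)  # torch.Size([]) -- a scalar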
@@ -1531,7 +1569,6 @@ def masked_batch_mean(x):
             'entropy': entropy_metrics,
             'completion_mask': completion_mask,
             'completion_token_count': completion_token_count,
-            'masked_batch_mean_fn': masked_batch_mean
         }

         if self.beta != 0.0:
@@ -1601,7 +1638,7 @@ def _compute_loss_chunked(self, model, inputs: DataType):
        """
        mode = 'train' if self.model.training else 'eval'
        chunk_size = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size
-        batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else len(inputs.get('completion_mask', []))
+        batch_size = inputs['seq_lengths'].shape[0] if self.template.padding_free else inputs['input_ids'].shape[0]

        # Decide how many chunks every rank must run
        batch_sizes = gather_object([batch_size])
@@ -1777,7 +1814,7 @@ def _get_per_token_logps_and_entropies(self,
        When rollout count is larger than expected, we process in smaller batches
        to control memory usage.
        """
-        batch_size = inputs['input_ids'].shape[0]
+        batch_size = inputs['seq_lengths'].shape[0] if self.template.padding_free else inputs['input_ids'].shape[0]
        mode = 'train' if self.model.training else 'eval'
        expected_bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size  # noqa
        should_chunk = self.dynamic_num_samples and any(gather_object([batch_size > expected_bs]))
@@ -1816,7 +1853,7 @@ def _get_per_token_logps_and_entropies_single(self,
            k: v
            for k, v in inputs.items() if k not in [
                'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
-                'truncated_mask'
+                'truncated_mask', 'seq_lengths'
            ]
        }
        if 'logits_to_keep' in self.model_kwarg_keys:
@@ -1862,8 +1899,7 @@ def _get_per_token_logps_and_entropies_chunked(self,
            Concatenated per-token entropies, or ``None`` if ``compute_entropy`` is
            ``False``.
        """
-
-        batch_size = inputs['input_ids'].shape[0]
+        batch_size = inputs['seq_lengths'].shape[0] if self.template.padding_free else inputs['input_ids'].shape[0]
        mode = 'train' if self.model.training else 'eval'
        chunk_size = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size

@@ -1926,6 +1962,7 @@ def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):

    def compute_liger_loss(self, unwrapped_model, inputs):
        # Compute the per-token log probabilities for the model
+        assert not self.template.padding_free
        input_ids = inputs['input_ids']
        logits_to_keep = inputs['logits_to_keep']
        completion_ids = input_ids[:, -logits_to_keep:]
@@ -2359,6 +2396,8 @@ def _server_rollout(self, inputs: DataType, request_config: RequestConfig,
                'With --dynamic_sample enabled, only the last valid sample of each '
                f'{self.args.generation_batch_size}-sized batch will be kept; '
                'some requests may therefore be dropped.')
+        if self.template.padding_free:
+            raise NotImplementedError('Padding free mode is not supported for dynamic sample')
        # Initialize empty outputs for non-main processes
        if not self.accelerator.is_main_process:
            all_outputs = [None] * outputs_count
