@@ -1298,7 +1298,14 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:
             # Process labels and masks
             labels = batch_encoded_inputs.pop('labels')
             logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item()
-
+            extra_kwargs = {
+                'completion_mask':
+                labels[:, -logits_to_keep:] != -100,
+                'truncated_mask':
+                torch.tensor([b['is_truncated'] for b in batch], dtype=torch.bool, device=self.accelerator.device),
+                'logits_to_keep':
+                logits_to_keep,
+            }
             if self.template.padding_free:
                 position_ids = batch_encoded_inputs.get('text_position_ids')
                 if position_ids is None:
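
A minimal sketch of what the `logits_to_keep` arithmetic above computes, using toy tensors (the values are illustrative, not the trainer's actual inputs):

```python
import torch

# labels: (batch, seq_len); -100 marks prompt tokens that carry no loss.
labels = torch.tensor([
    [-100, -100, 11, 12, 13],    # prompt length 2, completion length 3
    [-100, -100, -100, 21, 22],  # prompt length 3, completion length 2
])
# argmax over "not -100" finds the first completion token per row; subtracting it from
# the sequence length gives each completion length, and .max() keeps the longest one.
logits_to_keep = (labels.shape[-1] - torch.ne(labels, -100).int().argmax(-1)).max().item()
print(logits_to_keep)                       # 3
# Slicing the last logits_to_keep positions keeps every completion token; positions
# that are still -100 (shorter completions) end up masked out.
print(labels[:, -logits_to_keep:] != -100)  # [[True, True, True], [False, True, True]]
```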
@@ -1308,21 +1315,16 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:
                 lengths = torch.diff(
                     torch.cat([(position_ids == 0).nonzero(as_tuple=True)[0],
                                torch.tensor([len(position_ids)]).to(position_ids.device)]))
+                total_lengths = lengths.sum()
+                # The first sentence has its prompt portion removed due to logits_to_keep
+                lengths[0] = lengths[0] - (total_lengths - logits_to_keep)
+                extra_kwargs.update({'seq_lengths': lengths})
                 advantages_stacked = torch.stack([data['advantages'] for data in batch])
-                all_advandages = torch.repeat_interleave(advantages_stacked, lengths)
+                all_advantages = torch.repeat_interleave(advantages_stacked, lengths)
             else:
-                all_advandages = torch.stack([data['advantages'] for data in batch])
-
-            batch_encoded_inputs.update({
-                'completion_mask':
-                labels[:, -logits_to_keep:] != -100,
-                'truncated_mask':
-                torch.tensor([b['is_truncated'] for b in batch], dtype=torch.bool),
-                'logits_to_keep':
-                logits_to_keep,
-                'advantages':
-                all_advandages
-            })
+                all_advantages = torch.stack([data['advantages'] for data in batch])
+            extra_kwargs.update({'advantages': all_advantages})
+            batch_encoded_inputs.update(extra_kwargs)

             with torch.no_grad():
                 batch_encoded_inputs['old_per_token_logps'] = (
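
To make the padding-free bookkeeping concrete, the sketch below shows how per-sequence lengths are recovered from position ids that restart at 0, how the first length is corrected for the prompt tokens dropped by `logits_to_keep`, and how per-sequence advantages are broadcast to tokens. The numbers are invented for illustration:

```python
import torch

# Three completions packed into one row; text_position_ids restart at 0 at each boundary.
position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 2, 0, 1])
lengths = torch.diff(
    torch.cat([(position_ids == 0).nonzero(as_tuple=True)[0],
               torch.tensor([len(position_ids)])]))
print(lengths)  # tensor([4, 3, 2])

# Suppose slicing to the last logits_to_keep tokens dropped 2 prompt tokens; the patch
# attributes all of the dropped tokens to the first packed sequence.
logits_to_keep = 7
lengths[0] = lengths[0] - (lengths.sum() - logits_to_keep)
print(lengths)  # tensor([2, 3, 2])

# One advantage per completion is repeated out to one value per remaining token.
advantages = torch.tensor([0.5, -0.1, 0.3])
print(torch.repeat_interleave(advantages, lengths))
# tensor([ 0.5000,  0.5000, -0.1000, -0.1000, -0.1000,  0.3000,  0.3000])
```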
@@ -1344,7 +1346,10 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:
         # --- log completion lengths ---
         mode = 'train' if self.model.training else 'eval'
         device = self.accelerator.device
-        local_lengths = [inp['completion_mask'].sum(1).tolist() for inp in ga_batch_encoded_inputs]
+        if self.template.padding_free:
+            local_lengths = [inp['seq_lengths'].tolist() for inp in ga_batch_encoded_inputs]
+        else:
+            local_lengths = [inp['completion_mask'].sum(1).tolist() for inp in ga_batch_encoded_inputs]
         total_lengths = self._gather_and_flatten(local_lengths, dtype=torch.float32, device=device, flatten_level=1)

         self._metrics[mode]['completions/mean_length'].append(total_lengths.mean().item())
@@ -1405,7 +1410,7 @@ def _compute_loss(self, model, inputs):
         mode = 'train' if self.model.training else 'eval'

         # Check batch size and decide processing strategy
-        batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else len(inputs.get('completion_mask', []))
+        batch_size = inputs['seq_lengths'].shape[0] if self.template.padding_free else inputs['input_ids'].shape[0]
         expected_bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size

         should_chunk = self.dynamic_num_samples and any(gather_object([batch_size > expected_bs]))
@@ -1427,7 +1432,8 @@ def _compute_loss_and_metrics(self, model, inputs):

         completion_mask = inputs['completion_mask']
         truncated_mask = inputs['truncated_mask']
-
+        if self.template.padding_free:
+            lengths = inputs['seq_lengths']
         per_token_logps, entropies = self._get_per_token_logps_and_entropies(
             model, inputs, compute_entropy=self.compute_entropy)

@@ -1438,7 +1444,11 @@ def _compute_loss_and_metrics(self, model, inputs):
             # fill the padded token with NaN
             entropies = entropies.masked_fill(completion_mask == 0, float('nan'))
             if self.args.log_entropy:
-                per_completion_entropies_mean = torch.nanmean(entropies, dim=1)
+                if self.template.padding_free:
+                    entropy_list = torch.split(entropies, lengths.tolist())
+                    per_completion_entropies_mean = torch.stack([torch.nanmean(e) for e in entropy_list])
+                else:
+                    per_completion_entropies_mean = torch.nanmean(entropies, dim=1)
                 global_per_completion_entropies_mean = gather(per_completion_entropies_mean)
                 entropy_metrics = {
                     'entropy_logs': global_per_completion_entropies_mean.tolist(),
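
For the packed case, the per-completion entropy reduction amounts to splitting the entropy values by sequence length and taking a NaN-aware mean per chunk; a tiny example (assuming the entropies have already been flattened to one value per token, with NaN at masked positions):

```python
import torch

lengths = torch.tensor([3, 2])
# One entropy per packed token; NaN marks a token that was masked out.
entropies = torch.tensor([0.20, 0.40, float('nan'), 0.10, 0.30])
per_completion = torch.stack(
    [torch.nanmean(chunk) for chunk in torch.split(entropies, lengths.tolist())])
print(per_completion)  # tensor([0.3000, 0.2000])
```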
@@ -1458,7 +1468,11 @@ def _compute_loss_and_metrics(self, model, inputs):
             if all(truncated_mask):
                 logger.info('All completions are overlong and truncated, '
                             'resulting in NaN values for some metrics (e.g., KL)')
-            truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask).to(completion_mask.device)
+            if self.template.padding_free:
+                truncated_mask = torch.repeat_interleave(truncated_mask, lengths).unsqueeze(0)
+                assert truncated_mask.shape == completion_mask.shape
+            else:
+                truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask)
             completion_mask = completion_mask & (~truncated_mask)

         # Compute the KL divergence between the model and the reference model
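
A quick illustration of why both branches above yield a token-level mask matching `completion_mask` (shapes and values are illustrative):

```python
import torch

truncated = torch.tensor([True, False])  # one flag per completion
lengths = torch.tensor([2, 3])

# padding-free: completions are packed into a single row of 5 tokens
packed = torch.repeat_interleave(truncated, lengths).unsqueeze(0)
print(packed)  # tensor([[ True,  True, False, False, False]]), shape (1, 5)

# padded: each completion sits in its own row of the (batch, max_len) mask
padded = truncated.unsqueeze(-1).expand(2, 3)
print(padded)  # tensor([[ True,  True,  True], [False, False, False]])
```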
@@ -1477,14 +1491,29 @@ def _compute_loss_and_metrics(self, model, inputs):
         log_ratio = per_token_logps - old_per_token_logps
         if self.importance_sampling_level == 'token':
             log_importance_weights = log_ratio
-        elif self.importance_sampling_level == 'sequence':
-            log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
-            log_importance_weights = log_importance_weights.unsqueeze(-1)
-        elif self.importance_sampling_level == 'sequence_token':
-            # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)]
-            seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
-            seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1)  # Stop gradient
-            log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
+        elif self.importance_sampling_level in ['sequence', 'sequence_token']:
+            if self.template.padding_free:
+                # split to batch, compute seq-level normalization
+                log_ratio_list = torch.split(log_ratio.squeeze(0), lengths.tolist())
+                mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist())
+                seq_weights = [(lr * m).sum() / m.sum().clamp(min=1.0) for lr, m in zip(log_ratio_list, mask_list)]
+                seq_level_log_weights = torch.stack(seq_weights).to(log_ratio.dtype).unsqueeze(-1)
+                if self.importance_sampling_level == 'sequence':
+                    log_importance_weights = seq_level_log_weights
+                else:
+                    seq_level_log_weight = seq_level_log_weights.detach()
+                    seq_level_log_weight = torch.repeat_interleave(seq_level_log_weight, lengths).unsqueeze(0)
+                    log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
+            else:
+                seq_level_log_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(
+                    min=1.0).unsqueeze(-1)
+                if self.importance_sampling_level == 'sequence':
+                    log_importance_weights = seq_level_log_weights
+                else:
+                    # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)]
+                    seq_level_log_weight = seq_level_log_weights.detach()
+                    log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
+
         else:
             raise ValueError(
                 f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
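
Both branches compute the same sequence-level quantity; a rough sketch of the packed (padding-free) path with invented toy tensors:

```python
import torch

lengths = torch.tensor([2, 3])                           # two packed completions, 5 tokens total
log_ratio = torch.tensor([[0.2, 0.4, -0.1, 0.3, 0.2]])
completion_mask = torch.tensor([[1., 1., 1., 1., 0.]])   # last token masked (e.g. truncated)

# Per-sequence masked mean of the token log-ratios.
log_ratio_list = torch.split(log_ratio.squeeze(0), lengths.tolist())
mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist())
seq_weights = torch.stack([(lr * m).sum() / m.sum().clamp(min=1.0)
                           for lr, m in zip(log_ratio_list, mask_list)])
print(seq_weights)  # tensor([0.3000, 0.1000])

# GSPO-token ('sequence_token'): broadcast the detached sequence weight back to every
# token of its sequence so the per-token gradient term can be added on top.
token_level = torch.repeat_interleave(seq_weights.detach(), lengths).unsqueeze(0)
print(token_level.shape)  # torch.Size([1, 5])
```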
@@ -1509,17 +1538,26 @@ def _compute_loss_and_metrics(self, model, inputs):
             per_token_loss = per_token_loss + self.beta * per_token_kl

         if self.loss_type == 'grpo':
-            loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
+            if self.template.padding_free:
+                loss_list = torch.split(per_token_loss.squeeze(0), lengths.tolist())
+                mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist())
+                sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0)
+                               for loss, mask in zip(loss_list, mask_list)]
+                loss = torch.stack(sample_loss).mean()
+            else:
+                loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
         elif self.loss_type == 'bnpo':
             loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
         elif self.loss_type == 'dr_grpo':
-            loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
+            batch_size = lengths.shape[0] if self.template.padding_free else inputs['input_ids'].shape[0]
+            loss = (per_token_loss * completion_mask).sum() / (batch_size * self.max_completion_length)
         else:
             raise ValueError(f'Unknown loss type: {self.loss_type}')

         completion_token_count = completion_mask.sum().clamp(min=1.0)

         def masked_batch_mean(x):
+            # compute the token-level average
             if x.shape[1] == 1:  # when importance_sampling_level == "sequence"
                 return x.mean()
             else:
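
The three loss types differ only in how the masked per-token losses are normalized. A compact comparison on toy tensors, with the packed 'grpo' branch written the way the patch computes it (via `torch.split`); `max_completion_length` here is an assumed generation limit:

```python
import torch

lengths = torch.tensor([2, 3])
per_token_loss = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
completion_mask = torch.tensor([[1., 1., 1., 1., 1.]])
max_completion_length = 4  # assumed generation limit

# grpo: mean over tokens within each completion, then mean over completions
losses = torch.split(per_token_loss.squeeze(0), lengths.tolist())
masks = torch.split(completion_mask.squeeze(0), lengths.tolist())
grpo = torch.stack([(l * m).sum() / m.sum().clamp(min=1.0) for l, m in zip(losses, masks)]).mean()
print(grpo)     # (1.5 + 4.0) / 2 = 2.75

# bnpo: single mean over all unmasked tokens in the batch
bnpo = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
print(bnpo)     # 15 / 5 = 3.0

# dr_grpo: sum of token losses divided by (num completions * max_completion_length)
dr_grpo = (per_token_loss * completion_mask).sum() / (lengths.shape[0] * max_completion_length)
print(dr_grpo)  # 15 / 8 = 1.875
```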
@@ -1531,7 +1569,6 @@ def masked_batch_mean(x):
             'entropy': entropy_metrics,
             'completion_mask': completion_mask,
             'completion_token_count': completion_token_count,
-            'masked_batch_mean_fn': masked_batch_mean
         }

         if self.beta != 0.0:
@@ -1601,7 +1638,7 @@ def _compute_loss_chunked(self, model, inputs: DataType):
         """
         mode = 'train' if self.model.training else 'eval'
         chunk_size = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size
-        batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else len(inputs.get('completion_mask', []))
+        batch_size = inputs['seq_lengths'].shape[0] if self.template.padding_free else inputs['input_ids'].shape[0]

         # Decide how many chunks every rank must run
         batch_sizes = gather_object([batch_size])
@@ -1777,7 +1814,7 @@ def _get_per_token_logps_and_entropies(self,
         When rollout count is larger than expected, we process in smaller batches
         to control memory usage.
         """
-        batch_size = inputs['input_ids'].shape[0]
+        batch_size = inputs['seq_lengths'].shape[0] if self.template.padding_free else inputs['input_ids'].shape[0]
         mode = 'train' if self.model.training else 'eval'
         expected_bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size  # noqa
         should_chunk = self.dynamic_num_samples and any(gather_object([batch_size > expected_bs]))
@@ -1816,7 +1853,7 @@ def _get_per_token_logps_and_entropies_single(self,
             k: v
             for k, v in inputs.items() if k not in [
                 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
-                'truncated_mask'
+                'truncated_mask', 'seq_lengths'
             ]
         }
         if 'logits_to_keep' in self.model_kwarg_keys:
@@ -1862,8 +1899,7 @@ def _get_per_token_logps_and_entropies_chunked(self,
                 Concatenated per-token entropies, or ``None`` if ``compute_entropy`` is
                 ``False``.
         """
-
-        batch_size = inputs['input_ids'].shape[0]
+        batch_size = inputs['seq_lengths'].shape[0] if self.template.padding_free else inputs['input_ids'].shape[0]
         mode = 'train' if self.model.training else 'eval'
         chunk_size = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size

@@ -1926,6 +1962,7 @@ def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):

     def compute_liger_loss(self, unwrapped_model, inputs):
         # Compute the per-token log probabilities for the model
+        assert not self.template.padding_free
         input_ids = inputs['input_ids']
         logits_to_keep = inputs['logits_to_keep']
         completion_ids = input_ids[:, -logits_to_keep:]
@@ -2359,6 +2396,8 @@ def _server_rollout(self, inputs: DataType, request_config: RequestConfig,
                 'With --dynamic_sample enabled, only the last valid sample of each '
                 f'{self.args.generation_batch_size}-sized batch will be kept; '
                 'some requests may therefore be dropped.')
+            if self.template.padding_free:
+                raise NotImplementedError('Padding-free mode is not supported with dynamic sampling')
         # Initialize empty outputs for non-main processes
         if not self.accelerator.is_main_process:
             all_outputs = [None] * outputs_count