@@ -395,7 +395,7 @@ class TWA(AbstractCallback):
395
395
396
396
def __init__ (self , twa_start_epoch : int , lr_init : float ,
397
397
num_samples : int = 5 , reg_lambda : float = 0.0 ,
398
- max_epochs : int = None ):
398
+ max_epochs : int = None , twa_c_epochs : int = 5 ):
399
399
"""
400
400
Parameters
401
401
----------
@@ -409,13 +409,16 @@ def __init__(self, twa_start_epoch: int, lr_init: float,
409
409
Regularization coefficient for β updates.
410
410
max_epochs : int
411
411
Total number of training epochs.
412
+ twa_c_epochs : int
413
+ Spacing (in epochs) between consecutive weight samples for TWA.
412
414
"""
413
415
super ().__init__ ()
414
416
self .twa_start_epoch = twa_start_epoch
415
417
self .num_samples = num_samples
416
418
self .reg_lambda = reg_lambda
417
419
self .max_epochs = max_epochs
418
420
self .lr_init = lr_init
421
+ self .twa_c_epochs = twa_c_epochs
419
422
420
423
# State variables
421
424
self .current_epoch = - 1
@@ -464,7 +467,7 @@ def on_train_epoch_start(self, trainer, model):
464
467
def on_train_epoch_end (self , trainer , model ):
465
468
"""Main TWA logic: build subspace and update in β space."""
466
469
# Step 1: collect weight samples before TWA starts
467
- if self .current_epoch < self .twa_start_epoch :
470
+ if self .current_epoch < self .twa_start_epoch and self . current_epoch % self . twa_c_epochs == 0 :
468
471
self .sample_weights (model ) # rolling buffer handled inside
469
472
return
470
473
0 commit comments