Skip to content

Commit 6e375b8

Browse files
committed
Add spacing for TWA sampling
1 parent 558d2cc commit 6e375b8

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

dicee/weight_averaging.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ class TWA(AbstractCallback):
395395

396396
def __init__(self, twa_start_epoch: int, lr_init: float,
397397
num_samples: int = 5, reg_lambda: float = 0.0,
398-
max_epochs: int = None):
398+
max_epochs: int = None, twa_c_epochs :int = 5):
399399
"""
400400
Parameters
401401
----------
@@ -409,13 +409,16 @@ def __init__(self, twa_start_epoch: int, lr_init: float,
409409
Regularization coefficient for β updates.
410410
max_epochs : int
411411
Total number of training epochs.
412+
twa_c_epochs : int
413+
Spacing (in epochs) between consecutive weight samples for TWA: before TWA starts, weights are sampled only at epochs whose index is a multiple of this value.
412414
"""
413415
super().__init__()
414416
self.twa_start_epoch = twa_start_epoch
415417
self.num_samples = num_samples
416418
self.reg_lambda = reg_lambda
417419
self.max_epochs = max_epochs
418420
self.lr_init = lr_init
421+
self.twa_c_epochs = twa_c_epochs
419422

420423
# State variables
421424
self.current_epoch = -1
@@ -464,7 +467,7 @@ def on_train_epoch_start(self, trainer, model):
464467
def on_train_epoch_end(self, trainer, model):
465468
"""Main TWA logic: build subspace and update in β space."""
466469
# Step 1: before TWA starts, collect a weight sample every `twa_c_epochs` epochs
467-
if self.current_epoch < self.twa_start_epoch:
470+
if self.current_epoch < self.twa_start_epoch and self.current_epoch % self.twa_c_epochs == 0 :
468471
self.sample_weights(model) # rolling buffer handled inside
469472
return
470473

0 commit comments

Comments
 (0)