
# Import unified datapipe
from datapipe import SimSample, simsample_collate
+from omegaconf import open_dict


class Trainer:
@@ -113,6 +114,58 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper):
        )
        self.sampler = sampler

+        if cfg.training.num_validation_samples > 0:
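+            # Use at most one replica per sample, then round the sample count
+            # down to a multiple of the replica count so drop_last=True splits
+            # the validation set evenly across ranks.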
+            self.num_validation_replicas = min(
+                self.dist.world_size, cfg.training.num_validation_samples
+            )
+            self.num_validation_samples = (
+                cfg.training.num_validation_samples
+                // self.num_validation_replicas
+                * self.num_validation_replicas
+            )
+            logger0.info(f"Number of validation samples: {self.num_validation_samples}")
+
+            # Create a validation dataset
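+            # Note: val_cfg aliases self.cfg.datapipe, so the open_dict edits
+            # below mutate the shared config tree in place rather than a copy.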
+            val_cfg = self.cfg.datapipe
+            with open_dict(val_cfg):  # or open_dict(cfg) to open the whole tree
+                val_cfg.data_dir = self.cfg.training.raw_data_dir_validation
+                val_cfg.num_samples = self.num_validation_samples
+            val_dataset = instantiate(
+                val_cfg,
+                name="crash_validation",
+                reader=reader,
+                split="validation",
+                logger=logger0,
+            )
+
+            if self.dist.rank < self.num_validation_replicas:
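+                # Only the first num_validation_replicas ranks receive real
+                # validation data; the rest fall through to the empty loader.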
+                # Sampler
+                if self.dist.world_size > 1:
+                    sampler = DistributedSampler(
+                        val_dataset,
+                        num_replicas=self.num_validation_replicas,
+                        rank=self.dist.rank,
+                        shuffle=False,
+                        drop_last=True,
+                    )
+                else:
+                    sampler = None
+
+                self.val_dataloader = torch.utils.data.DataLoader(
+                    val_dataset,
+                    batch_size=1,  # variable N per sample
+                    shuffle=(sampler is None),
+                    drop_last=True,
+                    pin_memory=True,
+                    num_workers=cfg.training.num_dataloader_workers,
+                    sampler=sampler,
+                    collate_fn=simsample_collate,
+                )
+            else:
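+                # Ranks without a validation shard get an empty loader, so
+                # validate() runs the same code path on every rank.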
+                self.val_dataloader = torch.utils.data.DataLoader(
+                    torch.utils.data.Subset(val_dataset, []), batch_size=1
+                )
+
        # Model
        self.model = instantiate(cfg.model)
        logging.getLogger().setLevel(logging.INFO)
@@ -203,6 +256,48 @@ def backward(self, loss):
        loss.backward()
        self.optimizer.step()

+    @torch.no_grad()
+    def validate(self, epoch):
+        """Run validation error computation."""
+        self.model.eval()
+
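+        # Accumulators: overall squared error plus a per-rollout-step breakdown.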
+        MSE = torch.zeros(1, device=self.dist.device)
+        MSE_w_time = torch.zeros(self.rollout_steps, device=self.dist.device)
+        for sample in self.val_dataloader:
+            sample = sample[0].to(self.dist.device)  # SimSample.to()
+            T = self.rollout_steps
+
+            # Model forward
+            pred_seq = self.model(sample=sample, data_stats=self.data_stats)
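+            # pred_seq is assumed to be [T, N, Fo], matching exact_seq below.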
272+
273+ # Exact sequence
274+ N = sample .node_target .size (0 )
275+ Fo = 3 # output features per node
276+ assert sample .node_target .size (1 ) == T * Fo , (
277+ f"target dim { sample .node_target .size (1 )} != { T * Fo } "
278+ )
279+ exact_seq = (
280+ sample .node_target .view (N , T , Fo ).transpose (0 , 1 ).contiguous ()
281+ ) # [T,N,Fo]
282+
283+ # Compute and add error
284+ SqError = torch .square (pred_seq - exact_seq )
285+ MSE_w_time += torch .mean (SqError , dim = (1 , 2 ))
286+ MSE += torch .mean (SqError )
287+
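+        # Reduce after the loop so ranks with empty loaders still reach the
+        # collective calls together with the active ranks.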
+        # Sum errors across all ranks
+        if self.dist.world_size > 1:
+            torch.distributed.all_reduce(MSE, op=torch.distributed.ReduceOp.SUM)
+            torch.distributed.all_reduce(MSE_w_time, op=torch.distributed.ReduceOp.SUM)
+
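+        # Each sample contributed a single mean, so dividing the cross-rank sum
+        # by the global sample count yields the average error per sample.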
+        val_stats = {
+            "MSE_w_time": MSE_w_time / self.num_validation_samples,
+            "MSE": MSE / self.num_validation_samples,
+        }
+
+        self.model.train()  # Switch back to training mode
+        return val_stats
+

@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
@@ -247,7 +342,8 @@ def main(cfg: DictConfig) -> None:

        if dist.world_size > 1:
            torch.distributed.barrier()
-        if dist.rank == 0:
+
+        if dist.rank == 0 and (epoch + 1) % cfg.training.save_checkpoint_freq == 0:
            save_checkpoint(
                cfg.training.ckpt_path,
                models=trainer.model,
@@ -258,6 +354,31 @@ def main(cfg: DictConfig) -> None:
            )
            logger.info(f"Saved model on rank {dist.rank}")

+        # Validation
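+        # Every rank must enter validate() together: it ends in all_reduce calls.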
+        if (
+            cfg.training.num_validation_samples > 0
+            and (epoch + 1) % cfg.training.validation_freq == 0
+        ):
+            val_stats = trainer.validate(epoch)
+
+            # Log detailed validation statistics
+            logger0.info(
+                f"Validation epoch {epoch + 1}: MSE: {val_stats['MSE'].item():.3e}"
+            )
+
+            if dist.rank == 0:
+                # Log to tensorboard
+                trainer.writer.add_scalar("val/MSE", val_stats["MSE"].item(), epoch)
+
+                # Log per-timestep MSEs
+                for i in range(len(val_stats["MSE_w_time"])):
+                    trainer.writer.add_scalar(
+                        f"val/timestep_{i}_MSE",
+                        val_stats["MSE_w_time"][i].item(),
+                        epoch,
+                    )
+
    logger0.info("Training completed!")
    if dist.rank == 0:
        trainer.writer.close()