From f44213469d1a607da4d9f44880322465d43dc4ef Mon Sep 17 00:00:00 2001 From: Deepak Akhare Date: Thu, 30 Oct 2025 16:24:04 -0700 Subject: [PATCH 1/9] validation added: works for multi-node job. --- .../crash/conf/training/default.yaml | 1 + examples/structural_mechanics/crash/train.py | 110 +++++++++++++++++- 2 files changed, 110 insertions(+), 1 deletion(-) diff --git a/examples/structural_mechanics/crash/conf/training/default.yaml b/examples/structural_mechanics/crash/conf/training/default.yaml index 2d92f6456c..53269986c7 100644 --- a/examples/structural_mechanics/crash/conf/training/default.yaml +++ b/examples/structural_mechanics/crash/conf/training/default.yaml @@ -27,6 +27,7 @@ max_workers_preprocessing: 64 # Maximum parallel workers num_time_steps: 14 num_training_samples: 8 +num_validation_samples: 8 start_lr: 0.0001 end_lr: 0.0000003 epochs: 10000 diff --git a/examples/structural_mechanics/crash/train.py b/examples/structural_mechanics/crash/train.py index 31caeffb33..f554dd2dde 100644 --- a/examples/structural_mechanics/crash/train.py +++ b/examples/structural_mechanics/crash/train.py @@ -38,12 +38,13 @@ # Import unified datapipe from datapipe import SimSample, simsample_collate +from omegaconf import open_dict class Trainer: """Trainer for crash simulation models with unified SimSample input.""" - def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper): + def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper, validation: bool = True): assert DistributedManager.is_initialized() self.dist = DistributedManager() self.cfg = cfg @@ -104,6 +105,49 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper): ) self.sampler = sampler + if validation: + self.num_validation_replicas = min(self.dist.world_size, cfg.training.num_validation_samples) + self.num_validation_samples = cfg.training.num_validation_samples // self.num_validation_replicas * self.num_validation_replicas + logger0.info(f'Number of validation samples: {self.num_validation_samples}') + + # Create a validation dataset + val_cfg = self.cfg.datapipe + with open_dict(val_cfg): # or open_dict(cfg) to open the whole tree + val_cfg.data_dir = self.cfg.training.raw_data_dir_test + val_cfg.num_samples = self.num_validation_samples + val_dataset = instantiate( + val_cfg, + name="crash_test", + split="test", + logger=logger0, + ) + + if self.dist.rank < self.num_validation_replicas: + # Sampler + if self.dist.world_size > 1: + sampler = DistributedSampler( + val_dataset, + num_replicas=self.num_validation_replicas, + rank=self.dist.rank, + shuffle=False, + drop_last=True, + ) + else: + sampler = None + + self.val_dataloader = torch.utils.data.DataLoader( + val_dataset, + batch_size=1, # variable N per sample + shuffle=(sampler is None), + drop_last=True, + pin_memory=True, + num_workers=cfg.training.num_dataloader_workers, + sampler=sampler, + collate_fn=simsample_collate, + ) + else: + self.val_dataloader = torch.utils.data.DataLoader(torch.utils.data.Subset(val_dataset, []), batch_size=1) + # Model self.model = instantiate(cfg.model) logging.getLogger().setLevel(logging.INFO) @@ -194,6 +238,49 @@ def backward(self, loss): loss.backward() self.optimizer.step() + @torch.no_grad() + def validate_with_rollout(self, epoch): + """Run validation using the rollout approach with relative error computation""" + self.model.eval() + + MSE = torch.zeros(1, device=self.dist.device) + MSE_w_time = torch.zeros(self.rollout_steps, device=self.dist.device) + for idx, sample in enumerate(self.val_dataloader): + sample = sample[0].to(self.dist.device) # SimSample .to() + T = self.rollout_steps + Fo = 3 + # Forward rollout: expected to return [T,N,3] + pred_seq = self.model(sample=sample, data_stats=self.data_stats) + + # Exact sequence (if provided) + exact_seq = None + if sample.node_target is not None: + N = sample.node_target.size(0) + assert sample.node_target.size(1) == T * Fo + exact_seq = ( + sample.node_target.view(N, T, Fo) + .transpose(0, 1) + .contiguous() + ) + + # Compute detailed relative error losses + SqError = torch.square(pred_seq - exact_seq) + MSE_w_time += torch.mean(SqError, dim=(1,2)) + MSE += torch.mean(SqError) + + if self.dist.world_size > 1: + torch.distributed.all_reduce(MSE, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(MSE_w_time, op=torch.distributed.ReduceOp.SUM) + + # Combine all statistics + val_stats = { + 'MSE_w_time': MSE_w_time / self.num_validation_samples, + 'MSE': MSE / self.num_validation_samples, + } + + self.model.train() # Switch back to training mode + return val_stats + @hydra.main(version_base="1.3", config_path="conf", config_name="config") def main(cfg: DictConfig) -> None: @@ -249,6 +336,27 @@ def main(cfg: DictConfig) -> None: ) logger.info(f"Saved model on rank {dist.rank}") + # Validation + #TODO: Add validation frequency to config + val_freq = 10 + if (epoch + 1) % val_freq == 0: + # logger0.info(f"Validation started...") + val_stats = trainer.validate_with_rollout(epoch) + + # Log detailed validation statistics + logger0.info( + f"Validation epoch {epoch+1}: " + f"MSE: {val_stats['MSE'].item():.3e}, " + ) + + if dist.rank == 0: + # Log to tensorboard + trainer.writer.add_scalar("val/MSE", val_stats['MSE'].item(), epoch) + + # Log individual timestep relative errors + for i in range(len(val_stats['MSE_w_time'])): + trainer.writer.add_scalar(f"val/timestep_{i}_MSE", val_stats['MSE_w_time'][i].item(), epoch) + logger0.info("Training completed!") if dist.rank == 0: trainer.writer.close() From 87ad1606208117065a87bd730753628eea16472d Mon Sep 17 00:00:00 2001 From: Deepak Akhare Date: Thu, 30 Oct 2025 16:58:09 -0700 Subject: [PATCH 2/9] rename and rearrange validation function --- examples/structural_mechanics/crash/train.py | 23 ++++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/examples/structural_mechanics/crash/train.py b/examples/structural_mechanics/crash/train.py index f554dd2dde..a21bdf7877 100644 --- a/examples/structural_mechanics/crash/train.py +++ b/examples/structural_mechanics/crash/train.py @@ -239,8 +239,8 @@ def backward(self, loss): self.optimizer.step() @torch.no_grad() - def validate_with_rollout(self, epoch): - """Run validation using the rollout approach with relative error computation""" + def validate(self, epoch): + """Run validation error computation""" self.model.eval() MSE = torch.zeros(1, device=self.dist.device) @@ -248,31 +248,30 @@ def validate_with_rollout(self, epoch): for idx, sample in enumerate(self.val_dataloader): sample = sample[0].to(self.dist.device) # SimSample .to() T = self.rollout_steps - Fo = 3 - # Forward rollout: expected to return [T,N,3] + + # Model forward pred_seq = self.model(sample=sample, data_stats=self.data_stats) # Exact sequence (if provided) exact_seq = None if sample.node_target is not None: N = sample.node_target.size(0) - assert sample.node_target.size(1) == T * Fo - exact_seq = ( - sample.node_target.view(N, T, Fo) - .transpose(0, 1) - .contiguous() + Fo = 3 # output features per node + assert sample.node_target.size(1) == T * Fo, ( + f"target dim {sample.node_target.size(1)} != {T * Fo}" ) + exact_seq = sample.node_target.view(N, T, Fo).transpose(0, 1).contiguous() # [T,N,Fo] - # Compute detailed relative error losses + # Compute and add error SqError = torch.square(pred_seq - exact_seq) MSE_w_time += torch.mean(SqError, dim=(1,2)) MSE += torch.mean(SqError) + # Sum errors across all ranks if self.dist.world_size > 1: torch.distributed.all_reduce(MSE, op=torch.distributed.ReduceOp.SUM) torch.distributed.all_reduce(MSE_w_time, op=torch.distributed.ReduceOp.SUM) - # Combine all statistics val_stats = { 'MSE_w_time': MSE_w_time / self.num_validation_samples, 'MSE': MSE / self.num_validation_samples, @@ -341,7 +340,7 @@ def main(cfg: DictConfig) -> None: val_freq = 10 if (epoch + 1) % val_freq == 0: # logger0.info(f"Validation started...") - val_stats = trainer.validate_with_rollout(epoch) + val_stats = trainer.validate(epoch) # Log detailed validation statistics logger0.info( From f8334029f83dd35b7b11f221a0529119ce3658ed Mon Sep 17 00:00:00 2001 From: Deepak Akhare Date: Thu, 30 Oct 2025 17:39:35 -0700 Subject: [PATCH 3/9] validate_every_n_epochs, save_ckpt_every_n_epochs added in config --- .../crash/conf/training/default.yaml | 2 ++ examples/structural_mechanics/crash/train.py | 11 +++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/structural_mechanics/crash/conf/training/default.yaml b/examples/structural_mechanics/crash/conf/training/default.yaml index 53269986c7..8ac24b988f 100644 --- a/examples/structural_mechanics/crash/conf/training/default.yaml +++ b/examples/structural_mechanics/crash/conf/training/default.yaml @@ -31,6 +31,8 @@ num_validation_samples: 8 start_lr: 0.0001 end_lr: 0.0000003 epochs: 10000 +validate_every_n_epochs: 10 +save_ckpt_every_n_epochs: 10 # ┌───────────────────────────────────────────┐ # │ Performance Optimization │ diff --git a/examples/structural_mechanics/crash/train.py b/examples/structural_mechanics/crash/train.py index a21bdf7877..4e857366d6 100644 --- a/examples/structural_mechanics/crash/train.py +++ b/examples/structural_mechanics/crash/train.py @@ -44,7 +44,7 @@ class Trainer: """Trainer for crash simulation models with unified SimSample input.""" - def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper, validation: bool = True): + def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper): assert DistributedManager.is_initialized() self.dist = DistributedManager() self.cfg = cfg @@ -105,7 +105,7 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper, validation: ) self.sampler = sampler - if validation: + if cfg.training.num_validation_samples > 0: self.num_validation_replicas = min(self.dist.world_size, cfg.training.num_validation_samples) self.num_validation_samples = cfg.training.num_validation_samples // self.num_validation_replicas * self.num_validation_replicas logger0.info(f'Number of validation samples: {self.num_validation_samples}') @@ -324,7 +324,8 @@ def main(cfg: DictConfig) -> None: if dist.world_size > 1: torch.distributed.barrier() - if dist.rank == 0: + + if dist.rank == 0 and (epoch + 1) % cfg.training.save_ckpt_every_n_epochs == 0: save_checkpoint( cfg.training.ckpt_path, models=trainer.model, @@ -336,9 +337,7 @@ def main(cfg: DictConfig) -> None: logger.info(f"Saved model on rank {dist.rank}") # Validation - #TODO: Add validation frequency to config - val_freq = 10 - if (epoch + 1) % val_freq == 0: + if cfg.training.num_validation_samples > 0 and (epoch + 1) % cfg.training.validate_every_n_epochs == 0: # logger0.info(f"Validation started...") val_stats = trainer.validate(epoch) From 50aecf30879c23fccd3cf2c550f8835d59afdfe3 Mon Sep 17 00:00:00 2001 From: Deepak Akhare Date: Fri, 31 Oct 2025 16:30:19 -0700 Subject: [PATCH 4/9] corrected bug (args of model) in inference --- examples/structural_mechanics/crash/inference.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/structural_mechanics/crash/inference.py b/examples/structural_mechanics/crash/inference.py index 983a61aac5..cb252f03c1 100644 --- a/examples/structural_mechanics/crash/inference.py +++ b/examples/structural_mechanics/crash/inference.py @@ -197,12 +197,7 @@ def run_on_single_run(self, run_path: str): sample = sample.to(self.device) # Forward rollout: expected to return [T,N,3] - pred_seq = self.model( - node_features=sample.node_features, - edge_index=sample.edge_index, - edge_features=sample.edge_features, - data_stats=data_stats, - ) + pred_seq = self.model(sample=sample, data_stats=data_stats) # Exact sequence (if provided) exact_seq = None From cc2add32ca68890a02b988919d1ec6f97bbe74b2 Mon Sep 17 00:00:00 2001 From: Deepak Akhare Date: Mon, 3 Nov 2025 11:33:25 -0800 Subject: [PATCH 5/9] args in validation code updated --- examples/structural_mechanics/crash/train.py | 69 ++++++++++++-------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/examples/structural_mechanics/crash/train.py b/examples/structural_mechanics/crash/train.py index 2d2fa57beb..a60a643a04 100644 --- a/examples/structural_mechanics/crash/train.py +++ b/examples/structural_mechanics/crash/train.py @@ -111,22 +111,29 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper): self.sampler = sampler if cfg.training.num_validation_samples > 0: - self.num_validation_replicas = min(self.dist.world_size, cfg.training.num_validation_samples) - self.num_validation_samples = cfg.training.num_validation_samples // self.num_validation_replicas * self.num_validation_replicas - logger0.info(f'Number of validation samples: {self.num_validation_samples}') - + self.num_validation_replicas = min( + self.dist.world_size, cfg.training.num_validation_samples + ) + self.num_validation_samples = ( + cfg.training.num_validation_samples + // self.num_validation_replicas + * self.num_validation_replicas + ) + logger0.info(f"Number of validation samples: {self.num_validation_samples}") + # Create a validation dataset val_cfg = self.cfg.datapipe - with open_dict(val_cfg): # or open_dict(cfg) to open the whole tree - val_cfg.data_dir = self.cfg.training.raw_data_dir_test + with open_dict(val_cfg): # or open_dict(cfg) to open the whole tree + val_cfg.data_dir = self.cfg.inference.raw_data_dir_test val_cfg.num_samples = self.num_validation_samples val_dataset = instantiate( val_cfg, name="crash_test", + reader=reader, split="test", logger=logger0, ) - + if self.dist.rank < self.num_validation_replicas: # Sampler if self.dist.world_size > 1: @@ -151,7 +158,9 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper): collate_fn=simsample_collate, ) else: - self.val_dataloader = torch.utils.data.DataLoader(torch.utils.data.Subset(val_dataset, []), batch_size=1) + self.val_dataloader = torch.utils.data.DataLoader( + torch.utils.data.Subset(val_dataset, []), batch_size=1 + ) # Model self.model = instantiate(cfg.model) @@ -247,13 +256,13 @@ def backward(self, loss): def validate(self, epoch): """Run validation error computation""" self.model.eval() - + MSE = torch.zeros(1, device=self.dist.device) MSE_w_time = torch.zeros(self.rollout_steps, device=self.dist.device) for idx, sample in enumerate(self.val_dataloader): - sample = sample[0].to(self.dist.device) # SimSample .to() + sample = sample[0].to(self.dist.device) # SimSample .to() T = self.rollout_steps - + # Model forward pred_seq = self.model(sample=sample, data_stats=self.data_stats) @@ -265,23 +274,25 @@ def validate(self, epoch): assert sample.node_target.size(1) == T * Fo, ( f"target dim {sample.node_target.size(1)} != {T * Fo}" ) - exact_seq = sample.node_target.view(N, T, Fo).transpose(0, 1).contiguous() # [T,N,Fo] + exact_seq = ( + sample.node_target.view(N, T, Fo).transpose(0, 1).contiguous() + ) # [T,N,Fo] # Compute and add error SqError = torch.square(pred_seq - exact_seq) - MSE_w_time += torch.mean(SqError, dim=(1,2)) + MSE_w_time += torch.mean(SqError, dim=(1, 2)) MSE += torch.mean(SqError) # Sum errors across all ranks if self.dist.world_size > 1: torch.distributed.all_reduce(MSE, op=torch.distributed.ReduceOp.SUM) torch.distributed.all_reduce(MSE_w_time, op=torch.distributed.ReduceOp.SUM) - + val_stats = { - 'MSE_w_time': MSE_w_time / self.num_validation_samples, - 'MSE': MSE / self.num_validation_samples, + "MSE_w_time": MSE_w_time / self.num_validation_samples, + "MSE": MSE / self.num_validation_samples, } - + self.model.train() # Switch back to training mode return val_stats @@ -342,23 +353,29 @@ def main(cfg: DictConfig) -> None: logger.info(f"Saved model on rank {dist.rank}") # Validation - if cfg.training.num_validation_samples > 0 and (epoch + 1) % cfg.training.validate_every_n_epochs == 0: + if ( + cfg.training.num_validation_samples > 0 + and (epoch + 1) % cfg.training.validate_every_n_epochs == 0 + ): # logger0.info(f"Validation started...") val_stats = trainer.validate(epoch) - + # Log detailed validation statistics logger0.info( - f"Validation epoch {epoch+1}: " - f"MSE: {val_stats['MSE'].item():.3e}, " + f"Validation epoch {epoch + 1}: MSE: {val_stats['MSE'].item():.3e}, " ) - + if dist.rank == 0: # Log to tensorboard - trainer.writer.add_scalar("val/MSE", val_stats['MSE'].item(), epoch) - + trainer.writer.add_scalar("val/MSE", val_stats["MSE"].item(), epoch) + # Log individual timestep relative errors - for i in range(len(val_stats['MSE_w_time'])): - trainer.writer.add_scalar(f"val/timestep_{i}_MSE", val_stats['MSE_w_time'][i].item(), epoch) + for i in range(len(val_stats["MSE_w_time"])): + trainer.writer.add_scalar( + f"val/timestep_{i}_MSE", + val_stats["MSE_w_time"][i].item(), + epoch, + ) logger0.info("Training completed!") if dist.rank == 0: From 887c71487f0cfd49a2a660541dec6c19d1a010a9 Mon Sep 17 00:00:00 2001 From: Deepak Akhare Date: Wed, 5 Nov 2025 12:18:43 -0800 Subject: [PATCH 6/9] val path added and args name changed --- .../crash/conf/training/default.yaml | 5 ++-- examples/structural_mechanics/crash/train.py | 28 +++++++++---------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/examples/structural_mechanics/crash/conf/training/default.yaml b/examples/structural_mechanics/crash/conf/training/default.yaml index 8ac24b988f..27203d2d91 100644 --- a/examples/structural_mechanics/crash/conf/training/default.yaml +++ b/examples/structural_mechanics/crash/conf/training/default.yaml @@ -19,6 +19,7 @@ # └───────────────────────────────────────────┘ raw_data_dir: "/code/datasets/gm_crash/train" # TODO change +raw_data_dir_validation: "/code/datasets/gm_crash/validation" max_workers_preprocessing: 64 # Maximum parallel workers # ┌───────────────────────────────────────────┐ @@ -31,8 +32,8 @@ num_validation_samples: 8 start_lr: 0.0001 end_lr: 0.0000003 epochs: 10000 -validate_every_n_epochs: 10 -save_ckpt_every_n_epochs: 10 +validation_freq: 10 +save_chckpoint_freq: 10 # ┌───────────────────────────────────────────┐ # │ Performance Optimization │ diff --git a/examples/structural_mechanics/crash/train.py b/examples/structural_mechanics/crash/train.py index a60a643a04..1f377fc13f 100644 --- a/examples/structural_mechanics/crash/train.py +++ b/examples/structural_mechanics/crash/train.py @@ -124,11 +124,11 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper): # Create a validation dataset val_cfg = self.cfg.datapipe with open_dict(val_cfg): # or open_dict(cfg) to open the whole tree - val_cfg.data_dir = self.cfg.inference.raw_data_dir_test + val_cfg.data_dir = self.cfg.training.raw_data_dir_validation val_cfg.num_samples = self.num_validation_samples val_dataset = instantiate( val_cfg, - name="crash_test", + name="crash_validation", reader=reader, split="test", logger=logger0, @@ -266,17 +266,15 @@ def validate(self, epoch): # Model forward pred_seq = self.model(sample=sample, data_stats=self.data_stats) - # Exact sequence (if provided) - exact_seq = None - if sample.node_target is not None: - N = sample.node_target.size(0) - Fo = 3 # output features per node - assert sample.node_target.size(1) == T * Fo, ( - f"target dim {sample.node_target.size(1)} != {T * Fo}" - ) - exact_seq = ( - sample.node_target.view(N, T, Fo).transpose(0, 1).contiguous() - ) # [T,N,Fo] + # Exact sequence + N = sample.node_target.size(0) + Fo = 3 # output features per node + assert sample.node_target.size(1) == T * Fo, ( + f"target dim {sample.node_target.size(1)} != {T * Fo}" + ) + exact_seq = ( + sample.node_target.view(N, T, Fo).transpose(0, 1).contiguous() + ) # [T,N,Fo] # Compute and add error SqError = torch.square(pred_seq - exact_seq) @@ -341,7 +339,7 @@ def main(cfg: DictConfig) -> None: if dist.world_size > 1: torch.distributed.barrier() - if dist.rank == 0 and (epoch + 1) % cfg.training.save_ckpt_every_n_epochs == 0: + if dist.rank == 0 and (epoch + 1) % cfg.training.save_chckpoint_freq == 0: save_checkpoint( cfg.training.ckpt_path, models=trainer.model, @@ -355,7 +353,7 @@ def main(cfg: DictConfig) -> None: # Validation if ( cfg.training.num_validation_samples > 0 - and (epoch + 1) % cfg.training.validate_every_n_epochs == 0 + and (epoch + 1) % cfg.training.validation_freq == 0 ): # logger0.info(f"Validation started...") val_stats = trainer.validate(epoch) From aa3c3d452ff210f604d94328b12f11e2d121f4db Mon Sep 17 00:00:00 2001 From: Deepak Akhare Date: Mon, 10 Nov 2025 12:07:11 -0800 Subject: [PATCH 7/9] validation split added -> write_vtp=False --- examples/structural_mechanics/crash/d3plot_reader.py | 2 +- examples/structural_mechanics/crash/train.py | 2 +- examples/structural_mechanics/crash/vtp_reader.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/structural_mechanics/crash/d3plot_reader.py b/examples/structural_mechanics/crash/d3plot_reader.py index 4f8b0edc9f..ba3ae6f025 100644 --- a/examples/structural_mechanics/crash/d3plot_reader.py +++ b/examples/structural_mechanics/crash/d3plot_reader.py @@ -427,7 +427,7 @@ def __call__( split: str, logger=None, ): - write_vtp = False if split == "train" else True + write_vtp = False if split == ("train" or "validation") else True return process_d3plot_data( data_dir=data_dir, num_samples=num_samples, diff --git a/examples/structural_mechanics/crash/train.py b/examples/structural_mechanics/crash/train.py index 48084a70e4..1e583e3905 100644 --- a/examples/structural_mechanics/crash/train.py +++ b/examples/structural_mechanics/crash/train.py @@ -134,7 +134,7 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper): val_cfg, name="crash_validation", reader=reader, - split="test", + split="validation", logger=logger0, ) diff --git a/examples/structural_mechanics/crash/vtp_reader.py b/examples/structural_mechanics/crash/vtp_reader.py index c87d41ab09..5e70e35c48 100644 --- a/examples/structural_mechanics/crash/vtp_reader.py +++ b/examples/structural_mechanics/crash/vtp_reader.py @@ -225,7 +225,7 @@ def __call__( logger=None, **kwargs, ): - write_vtp = False if split == "train" else True + write_vtp = False if split == ("train" or "validation") else True return process_vtp_data( data_dir=data_dir, num_samples=num_samples, From 4f570ee2a97f2f6c869e8c55f5837cafcac368af Mon Sep 17 00:00:00 2001 From: Deepak Akhare Date: Mon, 10 Nov 2025 12:17:11 -0800 Subject: [PATCH 8/9] fixed inference bug --- examples/structural_mechanics/crash/inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/structural_mechanics/crash/inference.py b/examples/structural_mechanics/crash/inference.py index cb252f03c1..8e54d67a8a 100644 --- a/examples/structural_mechanics/crash/inference.py +++ b/examples/structural_mechanics/crash/inference.py @@ -150,13 +150,14 @@ def run_on_single_run(self, run_path: str): os.symlink(run_path, os.path.join(tmpdir, run_name)) # Instantiate a dataset that sees exactly one run + reader = instantiate(self.cfg.reader) dataset = instantiate( self.cfg.datapipe, name="crash_test", + reader=reader, split="test", num_steps=self.cfg.training.num_time_steps, num_samples=1, - write_vtp=True, # ensures it writes ./output_/frame_*.vtp logger=self.logger, data_dir=tmpdir, # IMPORTANT: dataset reads from the tmpdir with single run ) From 3a195f6eba0bf102e785dd42dbb024f89424d8c8 Mon Sep 17 00:00:00 2001 From: Deepak Akhare Date: Mon, 10 Nov 2025 12:58:54 -0800 Subject: [PATCH 9/9] bug fix: write_vtp --- examples/structural_mechanics/crash/d3plot_reader.py | 2 +- examples/structural_mechanics/crash/vtp_reader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/structural_mechanics/crash/d3plot_reader.py b/examples/structural_mechanics/crash/d3plot_reader.py index ba3ae6f025..86c4d27592 100644 --- a/examples/structural_mechanics/crash/d3plot_reader.py +++ b/examples/structural_mechanics/crash/d3plot_reader.py @@ -427,7 +427,7 @@ def __call__( split: str, logger=None, ): - write_vtp = False if split == ("train" or "validation") else True + write_vtp = False if split in ("train", "validation") else True return process_d3plot_data( data_dir=data_dir, num_samples=num_samples, diff --git a/examples/structural_mechanics/crash/vtp_reader.py b/examples/structural_mechanics/crash/vtp_reader.py index 5e70e35c48..3cb3e88a06 100644 --- a/examples/structural_mechanics/crash/vtp_reader.py +++ b/examples/structural_mechanics/crash/vtp_reader.py @@ -225,7 +225,7 @@ def __call__( logger=None, **kwargs, ): - write_vtp = False if split == ("train" or "validation") else True + write_vtp = False if split in ("train", "validation") else True return process_vtp_data( data_dir=data_dir, num_samples=num_samples,