Skip to content

Commit 727df69

Browse files
Merge branch 'main' into saikrishnanc/update-crash
2 parents 4a9a974 + f8fd198 commit 727df69

File tree

7 files changed

+133
-11
lines changed

7 files changed

+133
-11
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
| [**Install Guide**](#installation)
1414
| [**Getting Started**](#getting-started)
1515
| [**Contributing Guidelines**](#contributing-to-physicsnemo)
16-
| [**License**](#license)
16+
| [**Dev blog**](https://nvidia.github.io/physicsnemo/blog/)
1717

1818
## What is PhysicsNeMo?
1919

examples/structural_mechanics/crash/conf/training/default.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# └───────────────────────────────────────────┘
2020

2121
raw_data_dir: "/code/datasets/gm_crash/train" # TODO change
22+
raw_data_dir_validation: "/code/datasets/gm_crash/validation"
2223
max_workers_preprocessing: 64 # Maximum parallel workers
2324

2425
# ┌───────────────────────────────────────────┐
@@ -27,9 +28,12 @@ max_workers_preprocessing: 64 # Maximum parallel workers
2728

2829
num_time_steps: 14
2930
num_training_samples: 8
31+
num_validation_samples: 8
3032
start_lr: 0.0001
3133
end_lr: 0.0000003
3234
epochs: 10000
35+
validation_freq: 10
36+
save_chckpoint_freq: 10
3337

3438
# ┌───────────────────────────────────────────┐
3539
# │ Performance Optimization │

examples/structural_mechanics/crash/d3plot_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def __call__(
427427
split: str,
428428
logger=None,
429429
):
430-
write_vtp = False if split == "train" else True
430+
write_vtp = False if split in ("train", "validation") else True
431431
return process_d3plot_data(
432432
data_dir=data_dir,
433433
num_samples=num_samples,

examples/structural_mechanics/crash/inference.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,14 @@ def run_on_single_run(self, run_path: str):
150150
os.symlink(run_path, os.path.join(tmpdir, run_name))
151151

152152
# Instantiate a dataset that sees exactly one run
153+
reader = instantiate(self.cfg.reader)
153154
dataset = instantiate(
154155
self.cfg.datapipe,
155156
name="crash_test",
157+
reader=reader,
156158
split="test",
157159
num_steps=self.cfg.training.num_time_steps,
158160
num_samples=1,
159-
write_vtp=True, # ensures it writes ./output_<run_name>/frame_*.vtp
160161
logger=self.logger,
161162
data_dir=tmpdir, # IMPORTANT: dataset reads from the tmpdir with single run
162163
)
@@ -197,12 +198,7 @@ def run_on_single_run(self, run_path: str):
197198
sample = sample.to(self.device)
198199

199200
# Forward rollout: expected to return [T,N,3]
200-
pred_seq = self.model(
201-
node_features=sample.node_features,
202-
edge_index=sample.edge_index,
203-
edge_features=sample.edge_features,
204-
data_stats=data_stats,
205-
)
201+
pred_seq = self.model(sample=sample, data_stats=data_stats)
206202

207203
# Exact sequence (if provided)
208204
exact_seq = None
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
jaxtyping==0.3.3
12
lasso-python==2.0.3
23
torch_geometric==2.6.1
34
torch_scatter>=2.1.2

examples/structural_mechanics/crash/train.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
# Import unified datapipe
3939
from datapipe import SimSample, simsample_collate
40+
from omegaconf import open_dict
4041

4142

4243
class Trainer:
@@ -113,6 +114,58 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper):
113114
)
114115
self.sampler = sampler
115116

117+
if cfg.training.num_validation_samples > 0:
118+
self.num_validation_replicas = min(
119+
self.dist.world_size, cfg.training.num_validation_samples
120+
)
121+
self.num_validation_samples = (
122+
cfg.training.num_validation_samples
123+
// self.num_validation_replicas
124+
* self.num_validation_replicas
125+
)
126+
logger0.info(f"Number of validation samples: {self.num_validation_samples}")
127+
128+
# Create a validation dataset
129+
val_cfg = self.cfg.datapipe
130+
with open_dict(val_cfg): # or open_dict(cfg) to open the whole tree
131+
val_cfg.data_dir = self.cfg.training.raw_data_dir_validation
132+
val_cfg.num_samples = self.num_validation_samples
133+
val_dataset = instantiate(
134+
val_cfg,
135+
name="crash_validation",
136+
reader=reader,
137+
split="validation",
138+
logger=logger0,
139+
)
140+
141+
if self.dist.rank < self.num_validation_replicas:
142+
# Sampler
143+
if self.dist.world_size > 1:
144+
sampler = DistributedSampler(
145+
val_dataset,
146+
num_replicas=self.num_validation_replicas,
147+
rank=self.dist.rank,
148+
shuffle=False,
149+
drop_last=True,
150+
)
151+
else:
152+
sampler = None
153+
154+
self.val_dataloader = torch.utils.data.DataLoader(
155+
val_dataset,
156+
batch_size=1, # variable N per sample
157+
shuffle=(sampler is None),
158+
drop_last=True,
159+
pin_memory=True,
160+
num_workers=cfg.training.num_dataloader_workers,
161+
sampler=sampler,
162+
collate_fn=simsample_collate,
163+
)
164+
else:
165+
self.val_dataloader = torch.utils.data.DataLoader(
166+
torch.utils.data.Subset(val_dataset, []), batch_size=1
167+
)
168+
116169
# Model
117170
self.model = instantiate(cfg.model)
118171
logging.getLogger().setLevel(logging.INFO)
@@ -203,6 +256,48 @@ def backward(self, loss):
203256
loss.backward()
204257
self.optimizer.step()
205258

259+
@torch.no_grad()
260+
def validate(self, epoch):
261+
"""Run validation error computation"""
262+
self.model.eval()
263+
264+
MSE = torch.zeros(1, device=self.dist.device)
265+
MSE_w_time = torch.zeros(self.rollout_steps, device=self.dist.device)
266+
for idx, sample in enumerate(self.val_dataloader):
267+
sample = sample[0].to(self.dist.device) # SimSample .to()
268+
T = self.rollout_steps
269+
270+
# Model forward
271+
pred_seq = self.model(sample=sample, data_stats=self.data_stats)
272+
273+
# Exact sequence
274+
N = sample.node_target.size(0)
275+
Fo = 3 # output features per node
276+
assert sample.node_target.size(1) == T * Fo, (
277+
f"target dim {sample.node_target.size(1)} != {T * Fo}"
278+
)
279+
exact_seq = (
280+
sample.node_target.view(N, T, Fo).transpose(0, 1).contiguous()
281+
) # [T,N,Fo]
282+
283+
# Compute and add error
284+
SqError = torch.square(pred_seq - exact_seq)
285+
MSE_w_time += torch.mean(SqError, dim=(1, 2))
286+
MSE += torch.mean(SqError)
287+
288+
# Sum errors across all ranks
289+
if self.dist.world_size > 1:
290+
torch.distributed.all_reduce(MSE, op=torch.distributed.ReduceOp.SUM)
291+
torch.distributed.all_reduce(MSE_w_time, op=torch.distributed.ReduceOp.SUM)
292+
293+
val_stats = {
294+
"MSE_w_time": MSE_w_time / self.num_validation_samples,
295+
"MSE": MSE / self.num_validation_samples,
296+
}
297+
298+
self.model.train() # Switch back to training mode
299+
return val_stats
300+
206301

207302
@hydra.main(version_base="1.3", config_path="conf", config_name="config")
208303
def main(cfg: DictConfig) -> None:
@@ -247,7 +342,8 @@ def main(cfg: DictConfig) -> None:
247342

248343
if dist.world_size > 1:
249344
torch.distributed.barrier()
250-
if dist.rank == 0:
345+
346+
if dist.rank == 0 and (epoch + 1) % cfg.training.save_chckpoint_freq == 0:
251347
save_checkpoint(
252348
cfg.training.ckpt_path,
253349
models=trainer.model,
@@ -258,6 +354,31 @@ def main(cfg: DictConfig) -> None:
258354
)
259355
logger.info(f"Saved model on rank {dist.rank}")
260356

357+
# Validation
358+
if (
359+
cfg.training.num_validation_samples > 0
360+
and (epoch + 1) % cfg.training.validation_freq == 0
361+
):
362+
# logger0.info(f"Validation started...")
363+
val_stats = trainer.validate(epoch)
364+
365+
# Log detailed validation statistics
366+
logger0.info(
367+
f"Validation epoch {epoch + 1}: MSE: {val_stats['MSE'].item():.3e}, "
368+
)
369+
370+
if dist.rank == 0:
371+
# Log to tensorboard
372+
trainer.writer.add_scalar("val/MSE", val_stats["MSE"].item(), epoch)
373+
374+
# Log individual timestep relative errors
375+
for i in range(len(val_stats["MSE_w_time"])):
376+
trainer.writer.add_scalar(
377+
f"val/timestep_{i}_MSE",
378+
val_stats["MSE_w_time"][i].item(),
379+
epoch,
380+
)
381+
261382
logger0.info("Training completed!")
262383
if dist.rank == 0:
263384
trainer.writer.close()

examples/structural_mechanics/crash/vtp_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def __call__(
237237
logger=None,
238238
**kwargs,
239239
):
240-
write_vtp = False if split == "train" else True
240+
write_vtp = False if split in ("train", "validation") else True
241241
return process_vtp_data(
242242
data_dir=data_dir,
243243
num_samples=num_samples,

0 commit comments

Comments
 (0)