
Commit 5912555: Draft of validation script
1 parent 41560fc

8 files changed: 435 additions, 100 deletions

examples/weather/temporal_interpolation/README.md

Lines changed: 2 additions & 1 deletion

@@ -69,7 +69,8 @@ To train a temporal interpolation model, ensure you have the following:
   containing a 1D array with length equal to the number of variables in the dataset,
   with each value giving the mean (for `global_means.npy`) or standard deviation (for
   `global_stds.npy`) of the corresponding variable.
-* A JSON file with metadata about the contents of the HDF5 files. Refer to [data sample](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/data.json)
+* A JSON file with metadata about the contents of the HDF5 files. Refer to the [data
+  sample](https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/temporal_interpolation/data/data.json)
   for an example describing the dataset used to train the original model.
 * Optional: NetCDF4 files containing the orography and land-sea mask for the grid
   contained in the data. These should contain a variable of the same shape as the data.
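
Not part of the commit, but for orientation: a minimal sketch of how the per-variable statistics described in the README excerpt above could be applied for normalization. The file paths and the channel-first sample layout are assumptions, not taken from the repository.

import numpy as np

# Hypothetical paths; the README only specifies the file names and that each
# file holds a 1D array with one value per variable.
means = np.load("stats/global_means.npy")  # shape: (num_variables,)
stds = np.load("stats/global_stds.npy")    # shape: (num_variables,)

def normalize(sample: np.ndarray) -> np.ndarray:
    # Assumes variables along the leading axis, e.g. (num_variables, lat, lon).
    return (sample - means[:, None, None]) / stds[:, None, None]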

examples/weather/temporal_interpolation/config/train_interp.yaml

Lines changed: 4 additions & 3 deletions

@@ -46,9 +46,10 @@ training:
   samples_per_epoch: 50000 # number of samples per "epoch"
   load_epoch: "latest" # int, null or "latest"; "latest" loads the most recent checkpoint in checkpoint_dir
   checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved
-  optimizer_params:
-    lr: 5e-4 # learning rate
-    betas: [0.9, 0.95] # beta parameters for Adam
+
+optimizer_params:
+  lr: 5e-4 # learning rate
+  betas: [0.9, 0.95] # beta parameters for Adam
 
 logging:
   mlflow:
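
Judging by the new cfg.get("optimizer_params", {}) lookup in train.py below, the optimizer_params block now sits at the top level of the config rather than under training:. A minimal sketch of reading it with OmegaConf; the path is the config from this example and the snippet is illustrative, not part of the commit.

from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/weather/temporal_interpolation/config/train_interp.yaml")
opt_params = (
    OmegaConf.to_container(cfg.optimizer_params) if "optimizer_params" in cfg else {}
)
print(opt_params)  # expected along the lines of {'lr': 0.0005, 'betas': [0.9, 0.95]}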

examples/weather/temporal_interpolation/config/train_interp_lite.yaml

Lines changed: 4 additions & 3 deletions

@@ -50,9 +50,10 @@ training:
   samples_per_epoch: 50 # number of samples per "epoch"
   load_epoch: "latest" # int, null or "latest"; "latest" loads the most recent checkpoint in checkpoint_dir
   checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved
-  optimizer_params:
-    lr: 5e-4 # learning rate
-    betas: [0.9, 0.95] # beta parameters for Adam
+
+optimizer_params:
+  lr: 5e-4 # learning rate
+  betas: [0.9, 0.95] # beta parameters for Adam
 
 logging:
   mlflow:

examples/weather/temporal_interpolation/datapipe/climate_interp.py

Lines changed: 0 additions & 3 deletions

@@ -70,9 +70,6 @@ def __call__(
 
         # Shuffle before the next epoch starts
        if self.shuffle and sample_info.epoch_idx != self.last_epoch:
-            # All workers use the same rng seed so the resulting
-            # indices are the same across workers
-            # np.random.default_rng(seed=sample_info.epoch_idx).shuffle(self.indices)
             print("Shuffling indices")
             np.random.shuffle(self.indices)
             self.last_epoch = sample_info.epoch_idx
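
The deleted comment referred to seeding the RNG with the epoch index so that every data-loading worker produces the same permutation; the commit drops that commented-out variant and keeps plain np.random.shuffle. For reference, a standalone sketch of the epoch-seeded idea (not part of the commit; function name is hypothetical):

import numpy as np

def shuffle_in_sync(indices: np.ndarray, epoch_idx: int) -> np.ndarray:
    # Seeding with the epoch index gives every worker the same permutation
    # for that epoch without any inter-worker communication.
    np.random.default_rng(seed=epoch_idx).shuffle(indices)
    return indices

indices = np.arange(10)
print(shuffle_in_sync(indices.copy(), epoch_idx=3))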

examples/weather/temporal_interpolation/train.py

Lines changed: 88 additions & 12 deletions

@@ -16,6 +16,8 @@
 
 import os
 import datetime
+from typing import Any
+import warnings
 
 import hydra
 from omegaconf import OmegaConf
@@ -36,6 +38,11 @@
 from utils import distribute, loss
 from utils.trainer import Trainer
 
+try:
+    from apex.optimizers import FusedAdam
+except ImportError:
+    warnings.warn("Apex is not installed, defaulting to PyTorch optimizers.")
+
 
 def setup_datapipes(
     *,
@@ -182,6 +189,10 @@ def setup_model(
 
     Parameters
     ----------
+    num_variables : int
+        Number of atmospheric variables in the model.
+    num_auxiliaries : int
+        Number of auxiliary input channels.
     model_cfg : dict or None, optional
         Model configuration dict.
@@ -213,17 +224,70 @@ def setup_model(
     return model
 
 
+def setup_optimizer(
+    model: torch.nn.Module,
+    max_epoch: int,
+    opt_cls: type[torch.optim.Optimizer] | None = None,
+    opt_params: dict | None = None,
+    scheduler_cls: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+    scheduler_params: dict[str, Any] | None = None,
+) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
+    """Setup optimizer.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model that optimizer is applied to.
+    max_epoch : int
+        Maximum number of training epochs (used for scheduler setup).
+    opt_cls : type[torch.optim.Optimizer] or None, optional
+        Optimizer class. When None, will setup apex.optimizers.FusedAdam
+        if available, otherwise PyTorch Adam.
+    opt_params : dict or None, optional
+        Dict of parameters (e.g. learning rate) to pass to optimizer.
+    scheduler_cls : type[torch.optim.lr_scheduler.LRScheduler] or None, optional
+        Scheduler class. When None, will setup CosineAnnealingLR.
+    scheduler_params : dict[str, Any] or None, optional
+        Dict of parameters to pass to scheduler.
+
+    Returns
+    -------
+    tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]
+        The initialized optimizer and learning rate scheduler.
+    """
+
+    opt_kwargs = {"lr": 0.0005}
+    if opt_params is not None:
+        opt_kwargs.update(opt_params)
+    if opt_cls is None:
+        try:
+            opt_cls = FusedAdam
+        except NameError:  # in case we don't have apex
+            opt_cls = torch.optim.Adam
+
+    scheduler_kwargs = {}
+    if scheduler_cls is None:
+        scheduler_cls = torch.optim.lr_scheduler.CosineAnnealingLR
+        scheduler_kwargs["T_max"] = max_epoch
+    if scheduler_params is not None:
+        scheduler_kwargs.update(scheduler_params)
+
+    optimizer = opt_cls(model.parameters(), **opt_kwargs)
+    scheduler = scheduler_cls(optimizer, **scheduler_kwargs)
+    return (optimizer, scheduler)
+
+
 @torch.no_grad()
 def input_output_from_batch_data(
-    batch: dict[str, torch.Tensor], time_scale: float = 6 * 3600.0
+    batch: list[dict[str, torch.Tensor]], time_scale: float = 6 * 3600.0
 ) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
     """
     Convert the datapipe output dict to model input and output batches.
 
     Parameters
     ----------
-    batch : dict[str, torch.Tensor]
-        The data dict returned by the datapipe.
+    batch : list[dict[str, torch.Tensor]]
+        The list of data dicts returned by the datapipe.
     time_scale : float, optional
         Number of seconds between the interpolation endpoints (default 6 hours).
@@ -235,16 +299,17 @@ def input_output_from_batch_data(
     batch = batch[0]
     # Concatenate all input variables to a single tensor
     atmos_vars = batch["state_seq-atmos"]
-    cos_zenith = batch["cos_zenith-atmos"].squeeze(dim=2)
 
-    sincos_latlon = batch["latlon"]
-    geop = batch["geopotential"]
-    lsm = batch["land_sea_mask"]
-
-    atmos_vars_in = torch.cat(
-        [atmos_vars[:, 0], atmos_vars[:, 1], cos_zenith, sincos_latlon, geop, lsm],
-        dim=1,
-    )
+    atmos_vars_in = [atmos_vars[:, 0], atmos_vars[:, 1]]
+    if "cos_zenith-atmos" in batch:
+        atmos_vars_in = atmos_vars_in + [batch["cos_zenith-atmos"].squeeze(dim=2)]
+    if "latlon" in batch:
+        atmos_vars_in = atmos_vars_in + [batch["latlon"]]
+    if "geopotential" in batch:
+        atmos_vars_in = atmos_vars_in + [batch["geopotential"]]
+    if "land_sea_mask" in batch:
+        atmos_vars_in = atmos_vars_in + [batch["land_sea_mask"]]
+    atmos_vars_in = torch.cat(atmos_vars_in, dim=1)
 
     atmos_vars_out = atmos_vars[:, 2]
 
@@ -286,6 +351,15 @@ def setup_trainer(**cfg: dict) -> Trainer:
     )
     (model, dist_manager) = distribute.distribute_model(model)
 
+    # Setup optimizer and learning rate scheduler
+    (optimizer, scheduler) = setup_optimizer(
+        model,
+        cfg["training"].get("max_epoch", 1),
+        opt_params=cfg.get("optimizer_params", {}),
+        scheduler_params=cfg.get("scheduler_params", {}),
+    )
+
+    # Initialize mlflow
     mlflow_cfg = cfg.get("logging", {}).get("mlflow", {})
     if mlflow_cfg.pop("use_mlflow", False):
         initialize_mlflow(**mlflow_cfg)
@@ -334,6 +408,8 @@ def setup_trainer(**cfg: dict) -> Trainer:
         train_datapipe=train_datapipe,
         valid_datapipe=valid_datapipe,
         input_output_from_batch_data=input_output_from_batch_data,
+        optimizer=optimizer,
+        scheduler=scheduler,
         use_wandb=use_wandb,
         **cfg["training"],
     )
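
The main behavioral change in input_output_from_batch_data is that the auxiliary channels (cosine zenith angle, lat/lon encodings, geopotential, land-sea mask) are now appended only if the datapipe actually provides them. A self-contained illustration of that pattern with synthetic tensors; the shapes are made up and only the concatenation logic mirrors the diff above.

import torch

def concat_available(batch: dict[str, torch.Tensor]) -> torch.Tensor:
    # Always use the two endpoint states, then append whichever auxiliary
    # inputs are present in the batch, mirroring the new conditional logic.
    atmos_vars = batch["state_seq-atmos"]
    parts = [atmos_vars[:, 0], atmos_vars[:, 1]]
    for key in ("cos_zenith-atmos", "latlon", "geopotential", "land_sea_mask"):
        if key in batch:
            aux = batch[key]
            parts.append(aux.squeeze(dim=2) if key == "cos_zenith-atmos" else aux)
    return torch.cat(parts, dim=1)

batch = {
    "state_seq-atmos": torch.randn(2, 3, 4, 32, 64),  # (batch, time, vars, lat, lon)
    "land_sea_mask": torch.randn(2, 1, 32, 64),       # only one auxiliary provided
}
print(concat_available(batch).shape)  # torch.Size([2, 9, 32, 64])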

examples/weather/temporal_interpolation/utils/trainer.py

Lines changed: 9 additions & 78 deletions

@@ -16,7 +16,6 @@
 
 from collections.abc import Callable, Sequence
 from typing import Any, Literal
-import warnings
 import time
 
 import torch
@@ -29,11 +28,6 @@
 from physicsnemo.launch.logging import LaunchLogger, PythonLogger
 from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
 
-try:
-    from apex.optimizers import FusedAdam
-except ImportError:
-    warnings.warn("Apex is not installed, defaulting to PyTorch optimizers.")
-
 
 class Trainer:
     """Training loop.
@@ -52,18 +46,13 @@ class Trainer:
         ClimateDatapipe providing validation data.
     samples_per_epoch : int
         Number of samples to draw from the datapipe per 'epoch'.
+    optimizer : torch.optim.Optimizer
+        Optimizer used for training.
+    scheduler : torch.optim.lr_scheduler.LRScheduler
+        Learning rate scheduler.
     input_output_from_batch_data : Callable, optional
         Function that converts datapipe outputs to training batches.
         If not provided, will try to use outputs as-is.
-    optimizer : type[torch.optim.Optimizer] or None, optional
-        Optimizer class used for training. When None, will setup
-        apex.optimizers.FusedAdam if available, otherwise PyTorch Adam.
-    optimizer_params : dict[str, Any] or None, optional
-        Dict of parameters (e.g. learning rate) to pass to optimizer.
-    scheduler : type[torch.optim.lr_scheduler.LRScheduler] or None, optional
-        Learning rate scheduler class. When None, will setup CosineAnnealingLR.
-    scheduler_params : dict[str, Any] or None, optional
-        Dict of parameters to pass to LR scheduler.
     max_epoch : int, optional
         The last training epoch.
     load_epoch : int, "latest", or None, optional
@@ -90,11 +79,9 @@ def __init__(
         train_datapipe: ClimateDatapipe,
         valid_datapipe: ClimateDatapipe,
         samples_per_epoch: int,
+        optimizer: torch.optim.Optimizer,
+        scheduler: torch.optim.lr_scheduler.LRScheduler,
         input_output_from_batch_data: Callable = lambda x: x,
-        optimizer: type[torch.optim.Optimizer] | None = None,
-        optimizer_params: dict[str, Any] | None = None,
-        scheduler: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
-        scheduler_params: dict[str, Any] | None = None,
         max_epoch: int = 1,
         load_epoch: int | Literal["latest"] | None = "latest",
         checkpoint_every: int = 1,
@@ -110,13 +97,8 @@ def __init__(
         self.valid_datapipe = valid_datapipe
         self.max_epoch = max_epoch
         self.input_output_from_batch_data = input_output_from_batch_data
-        self.optimizer, self.lr_scheduler = self.setup_optimizer(
-            model,
-            opt_cls=optimizer,
-            opt_params=optimizer_params,
-            scheduler_cls=scheduler,
-            scheduler_params=scheduler_params,
-        )
+        self.optimizer = optimizer
+        self.lr_scheduler = scheduler
         self.validation_callbacks = validation_callbacks
         self.device = self.dist_manager.device
         self.logger = PythonLogger()
@@ -309,57 +291,6 @@ def validate_on_epoch(self) -> torch.Tensor:
         model.train()
         return loss_epoch / num_examples
 
-    def setup_optimizer(
-        self,
-        model: torch.nn.Module,
-        opt_cls: type[torch.optim.Optimizer] | None = None,
-        opt_params: dict | None = None,
-        scheduler_cls: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
-        scheduler_params: dict[str, Any] | None = None,
-    ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
-        """Setup optimizer.
-
-        Parameters
-        ----------
-        model : torch.nn.Module
-            Model that optimizer is applied to.
-        opt_cls : type[torch.optim.Optimizer] or None, optional
-            Optimizer class. When None, will setup apex.optimizers.FusedAdam
-            if available, otherwise PyTorch Adam.
-        opt_params : dict or None, optional
-            Dict of parameters (e.g. learning rate) to pass to optimizer.
-        scheduler_cls : type[torch.optim.lr_scheduler.LRScheduler] or None, optional
-            Scheduler class. When None, will setup CosineAnnealingLR.
-        scheduler_params : dict[str, Any] or None, optional
-            Dict of parameters to pass to scheduler.
-
-        Returns
-        -------
-        tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]
-            The initialized optimizer and learning rate scheduler.
-        """
-
-        opt_kwargs = {"lr": 0.0005}
-        if opt_params is not None:
-            opt_kwargs.update(opt_params)
-
-        if opt_cls is None:
-            try:
-                opt_cls = FusedAdam
-            except NameError:  # in case we don't have apex
-                opt_cls = torch.optim.Adam
-
-        scheduler_kwargs = {}
-        if scheduler_cls is None:
-            scheduler_cls = torch.optim.lr_scheduler.CosineAnnealingLR
-            scheduler_kwargs["T_max"] = self.max_epoch
-        if scheduler_params is not None:
-            scheduler_kwargs.update(scheduler_params)
-
-        optimizer = opt_cls(model.parameters(), **opt_kwargs)
-        scheduler = scheduler_cls(optimizer, **scheduler_kwargs)
-        return (optimizer, scheduler)
-
     def load_checkpoint(self, epoch: int | None = None) -> int:
         """Try to load model state from a checkpoint.
@@ -377,7 +308,7 @@ def load_checkpoint(self, epoch: int | None = None) -> int:
         """
         if self.checkpoint_dir is None:
             raise ValueError("checkpoint_dir must be set in order to load checkpoints.")
-        metadata = {"total_samples_trained": self.total_samples_trained}
+        metadata = {}
         self.epoch = load_checkpoint(
             self.checkpoint_dir,
             models=self.model,
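
With this refactor the Trainer no longer constructs its own optimizer; callers build the optimizer and scheduler first (for example via the new setup_optimizer in train.py) and pass the instances in. A minimal, self-contained equivalent of the default behavior, using plain Adam in place of the optional apex FusedAdam and the values that appear in the configs above:

import torch

model = torch.nn.Linear(8, 8)  # stand-in for the interpolation model
max_epoch = 100

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.95))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epoch)

# trainer = Trainer(..., optimizer=optimizer, scheduler=scheduler, ...)
for epoch in range(max_epoch):
    # ... one epoch of optimizer.step() calls would go here ...
    scheduler.step()  # cosine-anneals the learning rate over max_epoch epochs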
