Commits (37)
- b0dbf92 Temporal interpolation training recipe (jleinonen, Oct 3, 2025)
- af84a7b Add README (jleinonen, Oct 8, 2025)
- d9841bc Merge branch 'NVIDIA:main' into interp-model-example (jleinonen, Oct 8, 2025)
- f56991e Docs changes based on comments (jleinonen, Oct 13, 2025)
- ae6eed1 Update docstrings and README (jleinonen, Oct 14, 2025)
- 5430eb4 Add temporal interpolation animation (jleinonen, Oct 14, 2025)
- 84289f5 Merge branch 'NVIDIA:main' into interp-model-example (jleinonen, Oct 14, 2025)
- 2b2c81e Add animation link (jleinonen, Oct 14, 2025)
- 6f09aa1 Add shape check in loss (jleinonen, Oct 14, 2025)
- 811a38a Updates of configs + trainer (jleinonen, Oct 15, 2025)
- cb23fe6 Update config comments (jleinonen, Oct 15, 2025)
- c32a01b Merge branch 'NVIDIA:main' into interp-model-example (jleinonen, Oct 20, 2025)
- f642d78 Merge branch 'main' into interp-model-example (CharlelieLrt, Oct 21, 2025)
- 8d6ae42 Update README.md (megnvidia, Oct 21, 2025)
- e1c202b Added wandb logging (CharlelieLrt, Oct 22, 2025)
- dc2b215 Merge branch 'interp-model-example' of https://github.com/jleinonen/m… (CharlelieLrt, Oct 22, 2025)
- a253d6e Reformated sections in docstring for GeometricL2Loss (CharlelieLrt, Oct 22, 2025)
- aad2683 Merge branch 'main' into interp-model-example (CharlelieLrt, Oct 22, 2025)
- e25aaeb Update README and configs (jleinonen, Oct 22, 2025)
- 2ea2897 Merge branch 'interp-model-example' of (jleinonen, Oct 22, 2025)
- 8475e1d README changes + type hint fixes (jleinonen, Oct 22, 2025)
- 547f10d Update README.md (jleinonen, Oct 22, 2025)
- 41560fc Merge branch 'NVIDIA:main' into interp-model-example (jleinonen, Oct 23, 2025)
- 5912555 Draft of validation script (jleinonen, Oct 23, 2025)
- 2a6c06a Update validation and README (jleinonen, Oct 27, 2025)
- 06d6aa3 Merge branch 'NVIDIA:main' into interp-model-example (jleinonen, Oct 27, 2025)
- 62b8174 Merge branch '1.3.0-rc' into interp-model-example (CharlelieLrt, Oct 30, 2025)
- 551486e Fixed command in README.md for temporal_interpolation example (CharlelieLrt, Oct 30, 2025)
- d0d2214 Removed unused import in datapipe/climate_interp.py (CharlelieLrt, Oct 31, 2025)
- 2792ae6 Updated license headers in temporal_interpolation example (CharlelieLrt, Oct 31, 2025)
- f600d13 Renamed methods to avoid implicit shadowing in Trainer class (CharlelieLrt, Oct 31, 2025)
- f1fd540 Cosmetic changes in train.py and removed unused import in validate.py (CharlelieLrt, Oct 31, 2025)
- 2b8de5c Added clamp in validate.py to make sure step does not go out of bounds (CharlelieLrt, Oct 31, 2025)
- d719952 Added the temporal_interpolation example to the docs + updated CHANGE… (CharlelieLrt, Oct 31, 2025)
- 390d778 Addressing remaining comments (jleinonen, Nov 13, 2025)
- e6477aa Merged two data source classes in climate_interp.py (CharlelieLrt, Nov 15, 2025)
- 04729dd Merge branch 'interp-model-example' of https://github.com/jleinonen/m… (CharlelieLrt, Nov 15, 2025)
**`examples/weather/temporal_interpolation/README.md`** (97 additions, 0 deletions)
# Earth-2 Temporal Interpolation Model

The temporal interpolation model is used to increase the temporal resolution of AI-based
forecast models. These typically have a native temporal resolution of 6 hours; the
interpolation allows this to be improved to 1 hour. With appropriate training data, even
higher temporal resolutions should be achievable.

This PhysicsNeMo example shows how to train a ModAFNO-based temporal interpolation model
with a custom dataset. For access to the pre-trained model, see the [wrapper in
Earth2Studio](https://nvidia.github.io/earth2studio/modules/generated/models/px/earth2studio.models.px.InterpModAFNO.html#earth2studio.models.px.InterpModAFNO).
A technical description of the model can be found in the paper ["Modulated Adaptive
Fourier Neural Operators for Temporal Interpolation of Weather
Forecasts"](https://arxiv.org/abs/2410.18904).

## Requirements

### Environment

You need to have PhysicsNeMo installed on a GPU system. Training useful models in
practice requires a multi-GPU system; for the original model, 64 H100 GPUs were used.
Using the [PhysicsNeMo
container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/physicsnemo/containers/physicsnemo)
is recommended.

This example uses MLflow for experiment logging. MLflow is not installed by default;
install it by running

```bash
pip install mlflow
```

### Data

To train a temporal interpolation model, you need the following:

* A dataset of yearly HDF5 files at 1-hour resolution. For more details, see the section
["Data Format and Structure" in the diagnostic model
example](https://github.com/NVIDIA/physicsnemo/blob/5a64525c40eada2248cd3eacee0a6ac4735ae380/examples/weather/diagnostic/README.md#data-format-and-structure).
  These datasets can be very large: the dataset used to train the original model, with
  73 variables from 1980 to 2017, is approximately 100 TB in size. The original training
  data are on the ERA5 0.25 degree grid with shape `(721, 1440)`, but other resolutions
  should work too.
* Statistics files containing the mean and standard deviation of each channel in the
  data files, located at `stats/global_means.npy` and `stats/global_stds.npy` in your
  data directory. Each is a `.npy` file containing a 1D array with length equal to the
  number of variables in the dataset, where each value gives the mean
  (`global_means.npy`) or standard deviation (`global_stds.npy`) of the corresponding
  variable. A sketch for generating these files is shown after this list.
* A JSON file with metadata about the contents of the HDF5 files. See [here](data.json)
for an example describing the dataset used to train the original model.
* Optional: NetCDF4 files containing the orography and land-sea mask for the grid
  contained in the data. These should each contain a variable of the same shape as the
  data grid.
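
If your dataset does not already include the statistics files, they can be computed
directly from the HDF5 data. The snippet below is a minimal sketch, not part of this
example's code: it assumes the yearly files live under `train/` in the data directory,
use the `.h5` extension, and store their data in a dataset named `fields` with shape
`(time, channels, lat, lon)`. Adjust the paths and dataset name to match your setup.

```python
# Minimal sketch for generating stats/global_means.npy and stats/global_stds.npy.
# Assumptions (adjust as needed): yearly HDF5 files under <data_dir>/train/*.h5,
# each with a dataset named "fields" shaped (time, channels, lat, lon).
import glob
import os

import h5py
import numpy as np

data_dir = "/data/era5-73varQ-hourly"  # same directory as datapipe.data_dir
files = sorted(glob.glob(os.path.join(data_dir, "train", "*.h5")))

count = 0
sum_mean = None
sum_sq = None
for path in files:
    with h5py.File(path, "r") as f:
        fields = f["fields"]  # (time, channels, lat, lon)
        for t in range(fields.shape[0]):
            x = fields[t].astype(np.float64)  # one time step: (channels, lat, lon)
            if sum_mean is None:
                sum_mean = np.zeros(x.shape[0])
                sum_sq = np.zeros(x.shape[0])
            # accumulate per-channel spatial means of x and x**2
            sum_mean += x.mean(axis=(1, 2))
            sum_sq += (x**2).mean(axis=(1, 2))
            count += 1

mean = sum_mean / count
std = np.sqrt(sum_sq / count - mean**2)

# save 1D arrays of length n_channels, as described above
os.makedirs(os.path.join(data_dir, "stats"), exist_ok=True)
np.save(os.path.join(data_dir, "stats", "global_means.npy"), mean)
np.save(os.path.join(data_dir, "stats", "global_stds.npy"), std)
```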

## Configuration

The model training is controlled by YAML configuration files managed by
[Hydra](https://hydra.cc/), found in the `config` directory. The full configuration for
training the original model is [`train_interp.yaml`](config/train_interp.yaml).
[`train_interp_lite.yaml`](config/train_interp_lite.yaml) performs a short test run with
a lightweight model; it is not expected to produce useful checkpoints but can be used to
verify that training runs without errors.

See the comments in the configuration files for an explanation of each configuration
parameter. To replicate the model from the paper, you only need to change the file and
directory paths to correspond to those on your system. If you train it with a custom
dataset, you may also need to change the `model.in_channels` and `model.out_channels`
parameters.
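
For a custom dataset, one quick way to determine `model.out_channels` is to read the
channel dimension from one of your data files. The sketch below is illustrative only:
the file path is a placeholder, and it assumes the HDF5 files store a dataset named
`fields` with shape `(time, channels, lat, lon)`.

```python
# Minimal sketch: read the channel count from one HDF5 data file to help set
# model.out_channels (and cross-check the metadata JSON) for a custom dataset.
# The path and the "fields" dataset name are placeholders; adjust for your data.
import h5py

with h5py.File("/data/era5-73varQ-hourly/train/2000.h5", "r") as f:
    n_channels = f["fields"].shape[1]  # (time, channels, lat, lon)

print(f"model.out_channels should be {n_channels}")
```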

## Starting training

Test training by running the `train.py` script using the "lite" configuration file on a
system with a GPU:

```bash
python train.py --config-name=train_interp_lite.yaml
```

For a multi-GPU or multi-node training job, launch the training with the
`train_interp.yaml` configuration file using `torchrun` or MPI. For example, to train on
8 nodes with 8 GPUs each for a total of 64 GPUs, start a distributed compute job (e.g.
using SLURM or Run:ai) and use:

```bash
torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml
```

or the equivalent `mpirun` command. The code will automatically utilize all GPUs
available to the job. Remember that `datapipe.batch_size_train` in the configuration
file sets the batch size *per process*.

Configuration parameters can be overridden from the command line using the Hydra syntax.
For instance, to set the optimizer learning rate to 0.0001 for the current run, you
could use

```bash
torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml ++training.optimizer_params.lr=0.0001
```
**`examples/weather/temporal_interpolation/config/train_interp.yaml`** (56 additions, 0 deletions)
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

model:
model_type: "modafno" # should always be "modafno"
model_name: "modafno-cplxscale-smallpatch" # name for the model
inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size
in_channels: 155 # number of input channels to the model, use 155 for 73-var ERA5
out_channels: 73 # number of output channels from the model, use 73 for 73-var ERA5
patch_size: [2,2] # size of AFNO patches
embed_dim: 512 # embedding dimension
mlp_ratio: 2.0 # multiplier for MLP hidden layer size
num_blocks: 12 # number of AFNO blocks

scale_shift_mode: complex # "real" or "complex"
embed_model:
dim: 64 # width of time embedding net
depth: 1 # depth of time embedding net
method: sinusoidal # embedding type used in time embedding net

datapipe:
data_dir: "/data/era5-73varQ-hourly" # directory where data files are located
  metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # path to the metadata JSON file
geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file
lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of lsm file
use_latlon: True # when True, return latitude and longitude from datapipe
num_samples_per_year_train: 8748 # number of training samples per year (8748 == 365 * 24 - 12)
num_samples_per_year_valid: 64 # number of validation samples per year
batch_size_train: 1 # batch size per GPU

training:
max_epoch: 120 # number of data "epochs"
samples_per_epoch: 50000 # number of samples per "epoch"
checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved
optimizer_params:
lr: 5e-4 # learning rate
betas: [0.9, 0.95] # beta parameters for Adam

logging:
mlflow:
use_mlflow: True # when True, produce logs with mlflow
experiment_name: "Forecast interpolation model" # experiment name, can be set freely
user_name: "PhysicsNeMo User" # user name, can be set freely
**`examples/weather/temporal_interpolation/config/train_interp_lite.yaml`** (60 additions, 0 deletions)
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Config file for testing training. Does a very short run with a small model.
# Can be used to test that training runs without errors, not expected to
# produce useful checkpoints.

model:
model_type: "modafno" # should always be "modafno"
model_name: "modafno-test" # name for the model
inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size
in_channels: 155 # number of input channels to the model, use 155 for 73-var ERA5
out_channels: 73 # number of output channels from the model, use 73 for 73-var ERA5
patch_size: [8,8] # size of AFNO patches
embed_dim: 64 # embedding dimension
mlp_ratio: 2.0 # multiplier for MLP hidden layer size
num_blocks: 2 # number of AFNO blocks

scale_shift_mode: complex # "real" or "complex"
embed_model:
dim: 64 # width of time embedding net
depth: 1 # depth of time embedding net
method: sinusoidal # embedding type used in time embedding net

datapipe:
data_dir: "/data/era5-73varQ-hourly" # directory where data files are located
  metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # path to the metadata JSON file
geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file
lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of lsm file
use_latlon: True # when True, return latitude and longitude from datapipe
num_samples_per_year_train: 8748 # number of training samples per year (8748 == 365 * 24 - 12)
num_samples_per_year_valid: 64 # number of validation samples per year
batch_size_train: 1 # batch size per GPU

training:
max_epoch: 4 # number of data "epochs"
samples_per_epoch: 50 # number of samples per "epoch"
checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved
optimizer_params:
lr: 5e-4 # learning rate
betas: [0.9, 0.95] # beta parameters for Adam

logging:
mlflow:
use_mlflow: True # when True, produce logs with mlflow
experiment_name: "Forecast interpolation model" # experiment name, can be set freely
user_name: "PhysicsNeMo User" # user name, can be set freely