-
Notifications
You must be signed in to change notification settings - Fork 484
Interpolation model example #1149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jleinonen
wants to merge
37
commits into
NVIDIA:1.3.0-rc
Choose a base branch
from
jleinonen:interp-model-example
base: 1.3.0-rc
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 3 commits
Commits
Show all changes
37 commits
Select commit
Hold shift + click to select a range
b0dbf92
Temporal interpolation training recipe
jleinonen af84a7b
Add README
jleinonen d9841bc
Merge branch 'NVIDIA:main' into interp-model-example
jleinonen f56991e
Docs changes based on comments
jleinonen ae6eed1
Update docstrings and README
jleinonen 5430eb4
Add temporal interpolation animation
jleinonen 84289f5
Merge branch 'NVIDIA:main' into interp-model-example
jleinonen 2b2c81e
Add animation link
jleinonen 6f09aa1
Add shape check in loss
jleinonen 811a38a
Updates of configs + trainer
jleinonen cb23fe6
Update config comments
jleinonen c32a01b
Merge branch 'NVIDIA:main' into interp-model-example
jleinonen f642d78
Merge branch 'main' into interp-model-example
CharlelieLrt 8d6ae42
Update README.md
megnvidia e1c202b
Added wandb logging
CharlelieLrt dc2b215
Merge branch 'interp-model-example' of https://github.com/jleinonen/m…
CharlelieLrt a253d6e
Reformated sections in docstring for GeometricL2Loss
CharlelieLrt aad2683
Merge branch 'main' into interp-model-example
CharlelieLrt e25aaeb
Update README and configs
jleinonen 2ea2897
Merge branch 'interp-model-example' of
jleinonen 8475e1d
README changes + type hint fixes
jleinonen 547f10d
Update README.md
jleinonen 41560fc
Merge branch 'NVIDIA:main' into interp-model-example
jleinonen 5912555
Draft of validation script
jleinonen 2a6c06a
Update validation and README
jleinonen 06d6aa3
Merge branch 'NVIDIA:main' into interp-model-example
jleinonen 62b8174
Merge branch '1.3.0-rc' into interp-model-example
CharlelieLrt 551486e
Fixed command in README.md for temporal_interpolation example
CharlelieLrt d0d2214
Removed unused import in datapipe/climate_interp.py
CharlelieLrt 2792ae6
Updated license headers in temporal_interpolation example
CharlelieLrt f600d13
Renamed methods to avoid implicit shadowing in Trainer class
CharlelieLrt f1fd540
Cosmetic changes in train.py and removed unused import in validate.py
CharlelieLrt 2b8de5c
Added clamp in validate.py to make sure step does not go out of bounds
CharlelieLrt d719952
Added the temporal_interpolation example to the docs + updated CHANGE…
CharlelieLrt 390d778
Addressing remaining comments
jleinonen e6477aa
Merged two data source classes in climate_interp.py
CharlelieLrt 04729dd
Merge branch 'interp-model-example' of https://github.com/jleinonen/m…
CharlelieLrt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| # Earth-2 Temporal Interpolation Model | ||
|
|
||
| The temporal interpolation model is used to increase the temporal resolution of AI-based | ||
| forecast models. These typically have a native temporal resolution of 6 hours; the | ||
| interpolation allows this to be improved to 1 hour. With appropriate training data, even | ||
| higher temporal resolutions should be achievable. | ||
|
|
||
| This PhysicsNeMo example shows how to train a ModAFNO-based temporal interpolation model | ||
| with a custom dataset. For access to the pre-trained model, see the [wrapper in | ||
| Earth2Studio](https://nvidia.github.io/earth2studio/modules/generated/models/px/earth2studio.models.px.InterpModAFNO.html#earth2studio.models.px.InterpModAFNO). | ||
| A technical description of the model can be found in the paper ["Modulated Adaptive | ||
| Fourier Neural Operators for Temporal Interpolation of Weather | ||
| Forecasts"](https://arxiv.org/abs/2410.18904). | ||
|
|
||
CharlelieLrt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ## Requirements | ||
|
|
||
| ### Environment | ||
|
|
||
| You need to have PhysicsNeMo installed on a GPU system. Training useful models in | ||
| practice requires a multi-GPU system; for the original model, 64 H100 GPUs were used. | ||
| Using the [PhysicsNeMo | ||
| container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/physicsnemo/containers/physicsnemo) | ||
| is recommended. | ||
|
|
||
| This example uses MLFlow to manage data logging. MLFlow is not installed by default; | ||
| install it by running | ||
|
|
||
| ```bash | ||
| pip install mlflow | ||
CharlelieLrt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ``` | ||
|
|
||
| ### Data | ||
CharlelieLrt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| To train a temporal interpolation model, you need the following: | ||
|
|
||
| * A dataset of yearly HDF5 files at 1-hour resolution. For more details, see the section | ||
| ["Data Format and Structure" in the diagnostic model | ||
| example](https://github.com/NVIDIA/physicsnemo/blob/5a64525c40eada2248cd3eacee0a6ac4735ae380/examples/weather/diagnostic/README.md#data-format-and-structure). | ||
| These datasets can be very large: the dataset used to train the original model, with | ||
| 73 variables from 1980 to 2017, is approximately 100 TB in size. The data used to | ||
| train the original model are on the ERA5 0.25 degree grid with shape `(721, 1440)` but | ||
| other resolutions should work too. | ||
| * Statistics files containing the mean and standard deviation of each channel in the | ||
| data files. They should be found in the `stats/global_means.npy` and | ||
| `stats/global_stds.npy` files in your data directory. They should be `.npy` files | ||
| containing a 1D array with length equal to the number of variables in the dataset, | ||
| with each value giving the mean (for `global_means.npy`) or standard deviation (for | ||
| `global_stds.npy`) of the corresponding variable. | ||
| * A JSON file with metadata about the contents of the HDF5 files. See [here](data.json) | ||
CharlelieLrt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for an example describing the dataset used to train the original model. | ||
| * Optional: NetCDF4 files containing the orography and land-sea mask for the grid | ||
| contained in the data. These should contain a variable of the same shape as the data | ||
|
|
||
| ## Configuration | ||
|
|
||
| The model training is controlled by YAML configuration files managed by | ||
| [Hydra](https://hydra.cc/), found in the `config` directory. The full configuration for | ||
| training the original model is [`train_interp.yaml`](config/train_interp.yaml). | ||
| [`train_interp_lite.yaml`](config/train_interp_lite.yaml) runs a short test run with a | ||
| lightweight model, which is not expected to produce useful checkpoints but can be used | ||
| to test that training runs without errors. | ||
|
|
||
| See the comments in the configuration files for an explanation of each configuration | ||
| parameter. To replicate the model from the paper, you only need to change the file and | ||
| directory paths to correspond to those on your system. If you train it with a custom | ||
| dataset, you may also need to change the `model.in_channels` and `model.out_channels` | ||
| parameters. | ||
|
|
||
| ## Starting training | ||
CharlelieLrt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| Test training by running the `train.py` script using the "lite" configuration file on a | ||
| system with a GPU: | ||
|
|
||
| ```bash | ||
| python train.py --config-name=train_interp_lite.yaml | ||
| ``` | ||
|
|
||
| For a multi-GPU or multi-node training job, launch the training with the | ||
| `train_interp.yaml` configuration file using `torchrun` or MPI. For example, to train on | ||
| 8 nodes with 8 GPUs each for a total of 64 GPUs, start a distributed compute job (e.g. | ||
| using SLURM or Run:ai) and use: | ||
|
|
||
| ```bash | ||
| torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml | ||
| ``` | ||
|
|
||
| or the equivalent `mpirun` command. The code will automatically utilize all GPUs | ||
| available to the job. Remember to set `training.batch_size` in the configuration file to | ||
| the batch size *per process*. | ||
|
|
||
| Configuration parameters can be overridden from the command line using the Hydra syntax. | ||
| For instance, to set the optimizer learning rate to 0.0001 for the current run, you | ||
| could use | ||
|
|
||
| ```bash | ||
| torchrun --nnodes=8 --nproc-per-node=8 train.py --config-name=train_interp.yaml ++training.optimizer_params.lr=0.0001 | ||
| ``` | ||
56 changes: 56 additions & 0 deletions
56
examples/weather/temporal_interpolation/config/train_interp.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. | ||
| # SPDX-FileCopyrightText: All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| model: | ||
| model_type: "modafno" # should always be "modafno" | ||
| model_name: "modafno-cplxscale-smallpatch" # name for the model | ||
| inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size | ||
| in_channels: 155 # number of input channels to the model, use 155 for 73-var ERA5 | ||
CharlelieLrt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| out_channels: 73 # number of output channels from the model, use 73 for 73-var ERA5 | ||
| patch_size: [2,2] # size of AFNO patches | ||
| embed_dim: 512 # embedding dimension | ||
| mlp_ratio: 2.0 # multiplier for MLP hidden layer size | ||
jleinonen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| num_blocks: 12 # number of AFNO blocks | ||
|
|
||
| scale_shift_mode: complex # "real" or "complex" | ||
CharlelieLrt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| embed_model: | ||
| dim: 64 # width of time embedding net | ||
| depth: 1 # depth of time embedding net | ||
| method: sinusoidal # embedding type used in time embedding net | ||
CharlelieLrt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| datapipe: | ||
| data_dir: "/data/era5-73varQ-hourly" # directory where data files are located | ||
| metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # directory to metadata json file | ||
| geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file | ||
| lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of lsm file | ||
| use_latlon: True # when True, return latitude and longitude from datapipe | ||
| num_samples_per_year_train: 8748 # number of training samples per year (8748 == 365 * 24 - 12) | ||
CharlelieLrt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| num_samples_per_year_valid: 64 # number of validation samples per year | ||
| batch_size_train: 1 # batch size per GPU | ||
|
|
||
| training: | ||
| max_epoch: 120 # number of data "epochs" | ||
| samples_per_epoch: 50000 # number of samples per "epoch" | ||
CharlelieLrt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved | ||
CharlelieLrt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| optimizer_params: | ||
| lr: 5e-4 # learning rate | ||
| betas: [0.9, 0.95] # beta parameters for Adam | ||
|
|
||
| logging: | ||
| mlflow: | ||
| use_mlflow: True # when True, produce logs with mlflow | ||
| experiment_name: "Forecast interpolation model" # experiment name, can be set freely | ||
| user_name: "PhysicsNeMo User" # user name, can be set freely | ||
60 changes: 60 additions & 0 deletions
60
examples/weather/temporal_interpolation/config/train_interp_lite.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. | ||
| # SPDX-FileCopyrightText: All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # Config file for testing training. Does a very short run with a small model. | ||
| # Can be used to test that training runs without errors, not expected to | ||
| # produce useful checkpoints. | ||
|
|
||
| model: | ||
| model_type: "modafno" # should always be "modafno" | ||
| model_name: "modafno-test" # name for the model | ||
| inp_shape: [720, 1440] # should be [720, 1440], must be divisible by patch_size | ||
| in_channels: 155 # number of input channels to the model, use 155 for 73-var ERA5 | ||
| out_channels: 73 # number of output channels from the model, use 73 for 73-var ERA5 | ||
| patch_size: [8,8] # size of AFNO patches | ||
| embed_dim: 64 # embedding dimension | ||
| mlp_ratio: 2.0 # multiplier for MLP hidden layer size | ||
| num_blocks: 2 # number of AFNO blocks | ||
|
|
||
| scale_shift_mode: complex # "real" or "complex" | ||
| embed_model: | ||
| dim: 64 # width of time embedding net | ||
| depth: 1 # depth of time embedding net | ||
| method: sinusoidal # embedding type used in time embedding net | ||
|
|
||
| datapipe: | ||
| data_dir: "/data/era5-73varQ-hourly" # directory where data files are located | ||
| metadata_path: "/data/era5-73varQ-hourly/metadata/data.json" # directory to metadata json file | ||
| geopotential_filename: "/data/era5-wind_gust/invariants/orography.nc" # location of orography file | ||
| lsm_filename: "/data/era5-wind_gust/invariants/land_sea_mask.nc" # location of lsm file | ||
| use_latlon: True # when True, return latitude and longitude from datapipe | ||
| num_samples_per_year_train: 8748 # number of training samples per year (8748 == 365 * 24 - 12) | ||
| num_samples_per_year_valid: 64 # number of validation samples per year | ||
| batch_size_train: 1 # batch size per GPU | ||
|
|
||
| training: | ||
| max_epoch: 4 # number of data "epochs" | ||
| samples_per_epoch: 50 # number of samples per "epoch" | ||
| checkpoint_dir: "/checkpoints/fcinterp/" # location where checkpoints are saved | ||
| optimizer_params: | ||
| lr: 5e-4 # learning rate | ||
| betas: [0.9, 0.95] # beta parameters for Adam | ||
|
|
||
| logging: | ||
| mlflow: | ||
| use_mlflow: True # when True, produce logs with mlflow | ||
| experiment_name: "Forecast interpolation model" # experiment name, can be set freely | ||
| user_name: "PhysicsNeMo User" # user name, can be set freely |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.