Skip to content

Commit 297297e

Browse files
LostnEkkoTao Getge25
authored
GEFS-HRRR Corrdiff Devs (#703)
* merge gefs-hrrr model * update gefs-hrrr * Update gefs_hrrr.py * Update gefs_hrrr.yaml edit stats path * Update train.py * Update gefs_hrrr.py * Add unit test for GEFS Corrdiff regression loss and lead-time aware songunet * Format so init prob_channel scalar factor by channel length * Add docstrings and license for dataloader, formating in general * Delete examples/generative/corrdiff/stats.json * Update loss.py * Update song_unet.py * Update utils.py * Update README.md * Update README.md * Update CHANGELOG.md * Fixing generalization to unit test case for corrdiff utils * Update unit test for fixing diffusion step signature * Reformat with black --------- Co-authored-by: Tao Ge <[email protected]> Co-authored-by: Tao Ge <[email protected]>
1 parent f46e25f commit 297297e

30 files changed

+1993
-112
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
- Bistride Multiscale MeshGraphNet example.
1919
- FIGConvUNet model and example.
2020
- The Transolver model.
21+
- Incoporated CorrDiff-GEFS-HRRR model into CorrDiff, with lead-time aware SongUNet and
22+
cross entropy loss.
2123

2224
### Changed
2325

examples/generative/corrdiff/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,11 @@ CorrDiff training is handled by `train.py` and controlled by YAML configuration
3030
### Training the regression model
3131
To train the CorrDiff-Mini regression model, we use the main configuration file [config_training_mini_regression.yaml](conf/config_training_mini_regression.yaml). This includes the following components:
3232
* The HRRR-Mini dataset: [conf/dataset/hrrrmini.yaml](conf/dataset/hrrrmini.yaml)
33+
* The GEFS-HRRR dataset: [conf/dataset/hrrrmini.yaml](conf/dataset/gefs_hrrr.yaml)
3334
* The CorrDiff-Mini regression model: [conf/model/corrdiff_regression_mini.yaml](conf/model/corrdiff_regression_mini.yaml)
3435
* The CorrDiff-Mini regression training options: [conf/training/corrdiff_regression_mini.yaml](conf/training/corrdiff_regression_mini.yaml)
36+
* The CorrDiff-GEFS-HRRR regression training options: [conf/model/corrdiff_regression_mini.yaml](conf/training/config_training_gefs_regression.yaml)
37+
3538
To start the training, run:
3639
```bash
3740
python train.py --config-name=config_training_mini_regression.yaml ++dataset.data_path=</path/to/dataset>/hrrr_mini_train.nc ++dataset.stats_path=</path/to/dataset>/stats.json
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
hydra:
18+
job:
19+
chdir: true
20+
name: gefs_hrrr_generation
21+
run:
22+
dir: output/${hydra:job.name}
23+
24+
# Get defaults
25+
defaults:
26+
27+
# Dataset
28+
- dataset/gefs_hrrr
29+
30+
# Sampler
31+
- sampler/stochastic
32+
#- sampler/deterministic
33+
34+
# Generation
35+
- generation/patched_based_gefs_hrrr
36+
#- generation/patched_based
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
hydra:
18+
job:
19+
chdir: true
20+
name: gefs_hrrr_diffusion
21+
run:
22+
dir: ./outputs/${hydra:job.name}
23+
24+
# Get defaults
25+
defaults:
26+
27+
# Dataset
28+
- dataset/gefs_hrrr
29+
30+
# Model
31+
- model/corrdiff_patched_diffusion_gefs_hrrr
32+
33+
# Training
34+
- training/corrdiff_patched_diffusion_gefs_hrrr
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
hydra:
18+
job:
19+
chdir: true
20+
name: gefs_hrrr_regression
21+
run:
22+
dir: ./outputs/${hydra:job.name}
23+
24+
# Get defaults
25+
defaults:
26+
27+
# Dataset
28+
- dataset/gefs_hrrr
29+
30+
# Model
31+
- model/corrdiff_regression_gefs_hrrr
32+
33+
# Training
34+
- training/corrdiff_regression_gefs_hrrr

examples/generative/corrdiff/conf/dataset/cwb_train.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ ds_factor: 4
2626
min_path: null
2727
max_path: null
2828
global_means_path: null
29-
global_stds_path: null
29+
global_stds_path: null
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
type: gefs_hrrr
18+
data_path: /data
19+
stats_path: modulus/examples/generative/corrdiff/stats.json
20+
output_variables: ["u10m", "v10m", "t2m", "precip", "cat_snow", "cat_ice", "cat_freez", "cat_rain", "cat_none"]
21+
prob_variables: ["cat_snow", "cat_ice", "cat_freez", "cat_rain"]
22+
input_surface_variables: ["u10m", "v10m", "t2m", "q2m", "sp", "msl", "precipitable_water"]
23+
input_isobaric_variables: ['u1000', 'u925', 'u850', 'u700', 'u500', 'u250', 'v1000', 'v925', 'v850', 'v700', 'v500', 'v250', 'z1000', 'z925', 'z850', 'z700', 'z500', 'z200', 't1000', 't925', 't850', 't700', 't500', 't100', 'r1000', 'r925', 'r850', 'r700', 'r500', 'r100']
24+
ds_factor: 4
25+
train: False
26+
hrrr_window: [[1,1057], [4,1796]] # need dims to be divisible by 16 [[0,1024], [0,1024]]

examples/generative/corrdiff/conf/dataset/hrrrmini.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
type: hrrr_mini
1818
data_path: /data/corrdiff-mini/hrrr_mini_train.nc
1919
stats_path: /data/corrdiff-mini/stats.json
20-
output_variables: ['10u', '10v']
20+
output_variables: ['10u', '10v']
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
num_ensembles: 1
18+
# Number of ensembles to generate per input
19+
seed_batch_size: 1
20+
# Size of the batched inference
21+
inference_mode: all
22+
# Choose between "all" (regression + diffusion), "regression" or "diffusion"
23+
patch_size: 448
24+
patch_shape_x: 448
25+
patch_shape_y: 448
26+
# Patch size. Patch-based sampling will be utilized if these dimensions differ from
27+
# img_shape_x and img_shape_y
28+
overlap_pixels: 4
29+
# Number of overlapping pixels between adjacent patches
30+
boundary_pixels: 2
31+
# Number of boundary pixels to be cropped out. 2 is recommanded to address the boundary
32+
# artifact.
33+
hr_mean_conditioning: true
34+
gridtype: learnable
35+
N_grid_channels: 100
36+
sample_res: full
37+
# Sampling resolution
38+
times_range: null
39+
times:
40+
- "2024011212f00"
41+
- "2024011212f03"
42+
- "2024011212f06"
43+
- "2024011212f09"
44+
- "2024011212f12"
45+
- "2024011212f15"
46+
- "2024011212f18"
47+
- "2024011212f21"
48+
- "2024011212f24"
49+
50+
has_lead_time: true
51+
52+
perf:
53+
force_fp16: false
54+
# Whether to force fp16 precision for the model. If false, it'll use the precision
55+
# specified upon training.
56+
use_torch_compile: false
57+
# whether to use torch.compile on the diffusion model
58+
# this will make the first time stamp generation very slow due to compilation overheads
59+
# but will significantly speed up subsequent inference runs
60+
num_writer_workers: 1
61+
# number of workers to use for writing file
62+
# To support multiple workers a threadsafe version of the netCDF library must be used
63+
64+
io:
65+
res_ckpt_filename: EDMPrecondSRV2_updated.0.5821440.mdlus
66+
# Checkpoint filename for the diffusion model
67+
reg_ckpt_filename: UNet_updated.0.1960960.mdlus
68+
# Checkpoint filename for the mean predictor model
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
name: lt_aware_patched_diffusion
18+
# Name of the preconditioner
19+
hr_mean_conditioning: True
20+
# High-res mean (regression's output) as additional condition
21+
scale_cond_input: True
22+
# If true, also scales the input conditioning
23+
# For backward compatibility, this is true by default
24+
# We recommend setting this to false for new training runs

0 commit comments

Comments
 (0)