diff --git a/examples/weather/corrdiff/README_SOLAR.md b/examples/weather/corrdiff/README_SOLAR.md new file mode 100644 index 0000000000..34156ef874 --- /dev/null +++ b/examples/weather/corrdiff/README_SOLAR.md @@ -0,0 +1,182 @@ + +### **CorrDiffSolar** + +CorrDiffSolar is a generative model for high-resolution solar irradiance downscaling. It utilizes a two-stage approach to upscale low-resolution ERA5 reanalysis data to high-resolution Shortwave Downward Radiation (SWDR) fields. + +The process involves: +1. A **Regression Model** to generate a base downscaled prediction. +2. A **Diffusion Model** to learn and add high-frequency details by correcting the residual of the initial prediction. + +This guide covers the entire workflow, including data preparation, environment setup, model training, and inference. + +--- + +### **1. Project Overview** + +The project consists of the following core components: + +* **Data Preprocessing Scripts**: Convert raw ERA5 and Himawari-8 data into a model-ready format. +* **SolarDataset Class**: A custom PyTorch Dataset for efficiently loading and processing paired high- and low-resolution data. +* **Regression Model**: A UNet-based model that performs the initial coarse downscaling. +* **Diffusion Model**: A conditional diffusion model that refines the regression output, adding realistic, high-resolution details. +* **MultiDiffusion Sampling**: A sliding-window inference strategy that enables the generation of large, high-resolution images while overcoming GPU memory limitations. + +--- + +### **2. Environment Setup** + +For a seamless setup, it is recommended to use the official PhysicsNeMo Docker container, which includes all necessary dependencies. Please ensure you have pulled the latest image and started the container before proceeding. + +--- + +### **3. Data Preparation** + +This model requires two primary data sources: low-resolution ERA5 data and high-resolution Himawari-8 SWDR data. Please download the raw data before following the instructions below. + +#### **Raw Data Sources** + +* **Low-Resolution Input Data (ERA5):** + * **Source**: ERA5 Reanalysis (0.25-degree, 1-hour resolution) + * **Time Range**: 2016-2020 + * **Path**: `path/to/ERA5/rawdata` + * **Variables**: + * `ssrd` (Surface Solar Radiation Downwards), `t2m` (2-meter Temperature), `tcwv` (Total Column Water Vapour) + * `t`, `q`, `z` (Temperature, Specific Humidity, Geopotential) at pressure levels: 50, 100, 300, 500, 925, 1000 hPa + +* **High-Resolution Target Data (Himawari-8):** + * **Source**: Himawari-8 (H08) L3 Hourly Surface Solar Radiation Product for China and Surrounding Areas (0.05-degree, 10-minute resolution). Available [here](https://data.tpdc.ac.cn/en/data/4c4fbfa7-b165-48db-8525-0d1c165b39c4). + * **Time Range**: 2016-2020 + * **Path**: `path/to/H08/rawdata` + * **Variable**: `SWDR` + +#### **3.1. Preparing High-Resolution (HR) Data** + +1. **Organize Raw Data**: Place the downloaded 10-minute interval `_SWDR.nc` files into a directory structure of `YEAR/YYYYMM/DD/`. +2. **Run Preprocessing Script**: + * Modify the `data_directory` variable in `prepare_solar_data/prepare_h08_hourly.py` to point to your raw H08 data root. + * Execute the script. It will automatically merge the six 10-minute files for each hour, crop them to the target domain, filter out missing and nighttime data, and save the results into the `HRdata/H08_YYYY_hourly/` directory. + +#### **3.2. Preparing Low-Resolution (LR) Data** + +1. **Organize Raw Data**: Store the annual ERA5 data in separate `.nc` files for each variable (e.g., `2016_2m_temperature.nc`, `2016_q_500.nc`). +2. **Run Preprocessing Script**: + * Open `prepare_solar_data/prepare_era5.py`. + * Modify the `path` variable to point to the directory containing your annual ERA5 `.nc` files. + * Modify the `output_filename` variable to specify the output path for the Zarr archives (`LRdata/`). + * Execute the script to generate an `era5_YYYY_opt.zarr` file for each year. + +#### **3.3. Preparing Static Files** + +1. **dem.nc**: Go to `prepare_solar_data/get_dem` directory. Run the `1-download.sh` and `2-outputDEM` script to download, interpolate, crop, save, and visualize the DEM file over interested region. +2. **stats.json**: This file contains normalization statistics (mean, std, etc.) for all variables. Run `prepare_solar_data/get_stats.py` to compute these statistics from your training data. + + +After completing the above preprocessing steps, your final data directory should have the following structure: + +``` +/path/to/your/data/ +├── HRdata/ +│ ├── H08_2016_hourly/ +│ │ ├── H08_20160101_0000_hourly.nc +│ │ └── ... +│ ├── H08_2017_hourly/ +│ └── ... +├── LRdata/ +│ ├── era5_2016_opt.zarr/ +│ ├── era5_2017_opt.zarr/ +│ └── ... +├── dem.nc +└── stats.json +``` +--- + +### **4. Model Training** + +The training is a two-stage process. Use Hydra to launch training from the command line. During training, we obtain the window data with fixed size 320x320 via randomly sampling. + +#### **Stage 1: Train the Regression Model** + +This stage trains the base UNet for initial downscaling. + +1. **Configuration**: `config_training_multidiffsolar_regression.yaml` +2. **Modify Config**: + * Set `dataset.data_path` to your final data root directory (`/path/to/your/data/`). + * Ensure `dataset.stats_path` points to your `stats.json` file. + * Adjust `training.hp.total_batch_size` and `training.hp.batch_size_per_gpu` based on your hardware. +3. **Launch Training**: + * For single-GPU training: + ```bash + python train.py --config-name=config_training_multidiffsolar_regression.yaml + ``` + * For multi-GPU training: + ```bash + torchrun --standalone --nnodes=1 --nproc_per_node=8 train.py --config-name=config_training_multidiffsolar_regression.yaml + ``` + Checkpoints will be saved in the directory specified by Hydra (e.g., `outputs/solar_regression/`). + +#### **Stage 2: Train the Diffusion Model** + +This stage trains the residual diffusion model, conditioned on the pre-trained regression model. + +1. **Configuration**: `config_training_multidiffsolar_diffusion.yaml` +2. **Modify Config**: + * Verify `data_path` and `stats_path` as in Stage 1. + * **Crucial Step**: In the `training.io` section, set `regression_checkpoint_path` to the path of a regression model checkpoint (`.pt`) from Stage 1. +3. **Launch Training**: + * For single-GPU or multi-GPU training, use the same commands as above but with the new config file: + ```bash + torchrun --standalone --nnodes=1 --nproc_per_node=8 train.py --config-name=config_training_multidiffsolar_diffusion.yaml + ``` + Checkpoints will be saved in the corresponding output directory. + +--- + +### **5. Model Inference** + +Inference can be performed with or without the diffusion model refinement. + +#### **Mode 1: Regression-Only Inference** + +This mode is faster and provides a good baseline result. + +1. **Configuration**: `config_generate_multidiffsolar.yaml` +2. **Modify Config**: + * Ensure `data_path` and `stats_path` are correct. + * In the `inference` section, set `regression_checkpoint_path` to point to your chosen regression model checkpoint. +3. **Launch Inference**: + ```bash + python test.py --config-name=config_generate_multidiffsolar.yaml + ``` + +#### **Mode 2: Full CorrDiffSolar Inference (Regression + Diffusion)** + +This mode produces the highest quality results by applying the full two-stage model. + +1. **Configuration**: `config_generate_multidiffsolar_wDiff.yaml` +2. **Modify Config**: + * Check `data_path` and `stats_path`. + * **Crucial Step**: Set both `regression_checkpoint_path` and `diffusion_checkpoint_path` in the `inference` section. +3. **Launch Inference**: + ```bash + python test.py --config-name=config_generate_multidiffsolar_wDiff.yaml + ``` + Generated predictions will be saved in the output directory specified by Hydra. + +--- + +### **6. Code Structure Overview** + +* **`SolarDataset` Class**: The core data loader. Its `_get_files_stats` method scans data directories, while `__getitem__` fetches, interpolates, normalizes, computes the solar zenith angles, and crops data samples. +* **`MultiDiffusion` Class**: The sampler used during inference, located in `helpers/generate_helpers.py`. It implements a sliding-window (patch-based) approach to apply the diffusion model across large images, seamlessly stitching the results by averaging overlapping regions. +* **`generate_solar` function**: The function located in `helpers/generate_helpers.py` performs the inference. + +--- + +### **7. Visualization Example** + +Use the provided `plot.py` script to visualize the generated outputs against the ground truth. + +![image-20251020164826810](./images/swdr_reg_1028_s80_.png) + + +![image-20251020164826810](./images/swdr_wDiff_1028_s80_.png) \ No newline at end of file diff --git a/examples/weather/corrdiff/conf/config_generate_multidiffsolar.yaml b/examples/weather/corrdiff/conf/config_generate_multidiffsolar.yaml new file mode 100644 index 0000000000..d89f386aae --- /dev/null +++ b/examples/weather/corrdiff/conf/config_generate_multidiffsolar.yaml @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: generate_solar # Change `my_job_name` + run: + dir: ./output/${hydra:job.name} # Change `my_output_dir` + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, and generation +defaults: + + - dataset: solar + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - generation: non_patched + # The base generation parameters. + # Accepted values: + # `patched`: base parameters for a patch-based model + # `non_patched`: base parameters for a non-patched model + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + type: datasets/solar_dataset.py::SolarDataset + # Path to the user-defined dataset class. The user-defined dataset class is + # automatically loaded from the path. The user-defined class "DatasetClass" + # must be defined in the path "path/to/dataset.py". + data_path: data + # Path to .nc data file + stats_path: data/stats.json + input_variables: ["ssrd","t2m","tcwv","t50","t100","t300","t500","t925","t1000","q50","q100","q300","q500","q925","q1000","z50","z100","z300","z500","z925","z1000"] + generating: true + train: false + valid_years: [2020] + stride_gen: 160 + # Path to json stats file + +# Generation parameters to specialize +generation: + num_ensembles: 1 + # int, number of ensembles to generate per input + seed_batch_size: 1 + # int, size of the batched inference + times: ["2020-07-01T00:00:00"] + # List[str], time stamps in ISO 8601 format. Replace and list desired target + # time stamps. + has_lead_time: True + inference_mode: regression + io: + #res_ckpt_filename: ./all_checkpoints_regression/UNet.0.1500160.mdlus + # Path to checkpoint file for the diffusion model + reg_ckpt_filename: ./checkpoints_regression/UNet.0.4640000.mdlus + # Path to checkpoint file for the regression model + output_filename: "regression_test_output.nc" + # Path to checkpoint filename for the mean predictor model + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" diff --git a/examples/weather/corrdiff/conf/config_generate_multidiffsolar_wDiff.yaml b/examples/weather/corrdiff/conf/config_generate_multidiffsolar_wDiff.yaml new file mode 100644 index 0000000000..d7de3d9e80 --- /dev/null +++ b/examples/weather/corrdiff/conf/config_generate_multidiffsolar_wDiff.yaml @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: generate_solar # Change `my_job_name` + run: + dir: ./output/${hydra:job.name} # Change `my_output_dir` + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, and generation +defaults: + + - dataset: solar + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - generation: non_patched + # The base generation parameters. + # Accepted values: + # `patched`: base parameters for a patch-based model + # `non_patched`: base parameters for a non-patched model + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + type: datasets/solar_dataset.py::SolarDataset + # Path to the user-defined dataset class. The user-defined dataset class is + # automatically loaded from the path. The user-defined class "DatasetClass" + # must be defined in the path "path/to/dataset.py". + data_path: data + # Path to .nc data file + stats_path: data/stats.json + input_variables: ["ssrd","t2m","tcwv","t50","t100","t300","t500","t925","t1000","q50","q100","q300","q500","q925","q1000","z50","z100","z300","z500","z925","z1000"] + generating: true + train: false + valid_years: [2020] + stride_gen: 160 + # Path to json stats file + +# Generation parameters to specialize +generation: + num_ensembles: 1 + # int, number of ensembles to generate per input + seed_batch_size: 1 + # int, size of the batched inference + times: ["2020-07-01T00:00:00"] + # List[str], time stamps in ISO 8601 format. Replace and list desired target + # time stamps. + has_lead_time: True + inference_mode: all + io: + res_ckpt_filename: ./checkpoints_diffusion/EDMPrecondSuperResolution.0.680064.mdlus + # Path to checkpoint file for the diffusion model + reg_ckpt_filename: ./checkpoints_regression/UNet.0.4640000.mdlus + # Path to checkpoint file for the regression model + output_filename: "wDiff_test_output_stride320.nc" + # Path to checkpoint filename for the mean predictor model + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" diff --git a/examples/weather/corrdiff/conf/config_training_multidiffsolar_diffusion.yaml b/examples/weather/corrdiff/conf/config_training_multidiffsolar_diffusion.yaml new file mode 100644 index 0000000000..a7aeb05e37 --- /dev/null +++ b/examples/weather/corrdiff/conf/config_training_multidiffsolar_diffusion.yaml @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: solar_diffusion # Change `my_job_name` + run: + dir: ./output/${hydra:job.name} # Change `my_output_dir` + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: solar #keep same with the conf/base/dataset/{}.yaml + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: diffusion + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: normal + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - training: ${model} + # The base training parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + type: datasets/solar_dataset.py::SolarDataset + # Path to the user-defined dataset class. The user-defined dataset class is + # automatically loaded from the path. The user-defined class "DatasetClass" + # must be defined in the path "path/to/dataset.py". + data_path: data + # Path to .nc data file + stats_path: data/stats.json + input_variables: ["t2m","tcwv","t50","t100","t300","t500","t925","t1000","q50","q100","q300","q500","q925","q1000","z50","z100","z300","z500","z925","z1000"] + stride_train: 40 + window_size: 320 + +# Training parameters +training: + hp: + training_duration: 10000000 + total_batch_size: 128 + batch_size_per_gpu: 4 + lr: 0.0001 + + io: + regression_checkpoint_path: ./checkpoints_regression/UNet.0.4640000.mdlus + # Where to load the regression checkpoint. Should be overridden. + print_progress_freq: 1024 + # How often to print progress + save_checkpoint_freq: 40000 + # How often to save the checkpoints, measured in number of processed samples + save_n_recent_checkpoints: -1 + # Set to a positive integer to only keep the most recent n checkpoints + validation_freq: 10000 + # how often to record the validation loss, measured in number of processed samples + validation_steps: 80 + # how many loss evaluations are used to compute the validation loss per checkpoint + + +validation: + # Reuse the same dataset class as the training dataset + type: ${dataset.type} + train: false + # Reuse the same data and stats paths + data_path: ${dataset.data_path} + stats_path: ${dataset.stats_path} + + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients \ No newline at end of file diff --git a/examples/weather/corrdiff/conf/config_training_multidiffsolar_regression.yaml b/examples/weather/corrdiff/conf/config_training_multidiffsolar_regression.yaml new file mode 100644 index 0000000000..24f224a50c --- /dev/null +++ b/examples/weather/corrdiff/conf/config_training_multidiffsolar_regression.yaml @@ -0,0 +1,119 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: false + name: solar_regression # Change `my_job_name` + run: + dir: ./output/${hydra:job.name} # Change `my_output_dir` + searchpath: + - pkg://conf/base # Do not modify + +# Base parameters for dataset, model, training, and validation +defaults: + + - dataset: solar #keep same with the conf/base/dataset/{}.yaml + # The dataset type for training. + # Accepted values: + # `gefs_hrrr`: full GEFS-HRRR dataset for continental US. + # `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments. + # `cwb`: full CWB dataset for Taiwan. + # `custom`: user-defined dataset. Parameters need to be specified below. + + - model: regression + # The model type. + # Accepted values: + # `regression`: a regression UNet for deterministic predictions + # `lt_aware_ce_regression`: similar to `regression` but with lead time + # conditioning + # `diffusion`: a diffusion UNet for residual predictions + # `patched_diffusion`: a more memory-efficient diffusion model + # `lt_aware_patched_diffusion`: similar to `patched_diffusion` but + # with lead time conditioning + + - model_size: normal + # The model size configuration. + # Accepted values: + # `normal`: normal model size + # `mini`: smaller model size for fast experiments + + - training: ${model} + # The base training parameters. Determined by the model type. + + +# Dataset parameters. Used for `custom` dataset type. +# Modify or add below parameters that should be passed as argument to the +# user-defined dataset class. +dataset: + type: datasets/solar_dataset.py::SolarDataset + # Path to the user-defined dataset class. The user-defined dataset class is + # automatically loaded from the path. The user-defined class "DatasetClass" + # must be defined in the path "path/to/dataset.py". + data_path: data + # Path to .nc data file + stats_path: data/stats.json + input_variables: ["t2m","tcwv","t50","t100","t300","t500","t925","t1000","q50","q100","q300","q500","q925","q1000","z50","z100","z300","z500","z925","z1000"] + stride_train: 40 + window_size: 320 + +# Training parameters +training: + hp: + training_duration: 5500000 + total_batch_size: 128 + batch_size_per_gpu: 4 + lr: 0.0002 + lr_decay: 1 + # LR decay rate + lr_rampup: 550000 + # Rampup for learning rate, in number of samples + lr_decay_rate: 5e5 + # Learning rate decay threshold in number of samples, applied every lr_decay_rate samples. + final_lr: 0.000002 + + io: + regression_checkpoint_path: null + # Where to load the regression checkpoint. Should be overridden. + print_progress_freq: 1024 + # How often to print progress + save_checkpoint_freq: 40000 + # How often to save the checkpoints, measured in number of processed samples + save_n_recent_checkpoints: -1 + # Set to a positive integer to only keep the most recent n checkpoints + validation_freq: 10000 + # how often to record the validation loss, measured in number of processed samples + validation_steps: 80 + # how many loss evaluations are used to compute the validation loss per checkpoint + + +validation: + # Reuse the same dataset class as the training dataset + type: ${dataset.type} + train: false + # Reuse the same data and stats paths + data_path: ${dataset.data_path} + stats_path: ${dataset.stats_path} + + +# Parameters for wandb logging +wandb: + mode: offline + # Configure whether to use wandb: "offline", "online", "disabled" + results_dir: "./wandb" + # Directory to store wandb results + watch_model: false + # If true, wandb will track model parameters and gradients \ No newline at end of file diff --git a/examples/weather/corrdiff/datasets/solar_dataset.py b/examples/weather/corrdiff/datasets/solar_dataset.py new file mode 100644 index 0000000000..4558557cd3 --- /dev/null +++ b/examples/weather/corrdiff/datasets/solar_dataset.py @@ -0,0 +1,525 @@ +# Data loader for TWC MVP: GEFS and HRRR forecasts +# adapted from https://gitlab-master.nvidia.com/earth-2/corrdiff-internal/-/blob/dpruitt/hrrr/explore/dpruitt/hrrr/datasets/hrrr.py + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime, timedelta +import glob +import logging +import os +from typing import Iterable, Tuple, Union, List +import copy + +import cftime +import dask +import json +import numpy as np +import torch +import xarray as xr +import cv2 + +from physicsnemo.distributed import DistributedManager + +from datasets.base import ChannelMetadata, DownscalingDataset + +from earth2studio.utils import ( + handshake_coords, + handshake_dim, + interp, +) +import pandas as pd +from physicsnemo.utils.zenith_angle import cos_zenith_angle_from_timestamp, cos_zenith_angle +import random +def convert_datetime_to_cftime( + time: datetime, cls=cftime.DatetimeGregorian +) -> cftime.DatetimeGregorian: + """Convert a Python datetime object to a cftime DatetimeGregorian object.""" + return cls(time.year, time.month, time.day, time.hour, time.minute, time.second) + + +def time_range( + start_time: datetime, + end_time: datetime, + step: timedelta, + inclusive: bool = False, +): + """Like the Python `range` iterator, but with datetimes.""" + t = start_time + while (t <= end_time) if inclusive else (t < end_time): + yield t + t += step + + +class SolarDataset(DownscalingDataset): + """ + Paired dataset object serving time-synchronized pairs of ERA5 and Envision-wind samples + Expects data to be stored under directory specified by 'location' + ERA5 under /ERA5/ + Solar under /H08/ + Within ERA5 directory, there should be one zarr file per year containing the data of interest. + Within H08 directory, there should be many nc file per hour containing the data of interest. + """ + + def __init__( + self, + *, + data_path: str, + stats_path: str, + input_variables: Union[List[str], None] = None, + output_variables: Union[List[str], None] = None, + invariant_variables: Union[List[str], None] = ("dem"), + train: bool = True, + normalize: bool = True, + train_years: Iterable[int] = (2022,), + valid_years: Iterable[int] = (2021,), + sample_shape: Tuple[int, int] = [-1, -1], + ds_factor: int = 1, + shard: bool = False, + overfit: bool = False, + use_all: bool = False, + normal_way: str = "min_max", + generating: bool = False, + stride_train: int = 80, + stride_gen: int = 160, + window_size: int = 320, + solar: bool = True + ): + dask.config.set( + scheduler="synchronous" + ) # for threadsafe multiworker dataloaders + self.data_path = data_path + self.train = train + self.normalize = normalize + self.output_variables = output_variables + self.input_variables = input_variables + self.invariant_variables = invariant_variables + self.train_years = list(train_years) + self.valid_years = list(valid_years) + self.normal_way = normal_way + self.solar = solar + + self.sample_shape = sample_shape + self.ds_factor = ds_factor + self.shard = shard + self.use_all = use_all + self.output_variables_load = copy.deepcopy(output_variables) + + self._get_files_stats() + self.overfit = overfit + + self.window_size = window_size + self.generating = generating + if not self.generating: + self.stride_train = stride_train + self.windows = self.get_windows(stride=self.stride_train) + logging.info(f"The num of training windows is: {len(self.windows)}") + else: + self.stride_gen = stride_gen + self.era5_input_variables = self.input_variables + + + with open(stats_path, "r") as f: + stats = json.load(f) + + (self.input_center, self.input_scale) = _load_stats( + stats, self.input_variables, "era5", self.normal_way + ) + (self.output_center, self.output_scale) = _load_stats( + stats, self.output_variables, "h08", self.normal_way + ) + if self.invariant_variables is not None: + (self.inv_center, self.inv_scale) = _load_stats( + stats, self.invariant_variables, "inv", self.normal_way + ) + self.invs = self._get_inv() + + lon_input_grid,lat_input_grid = np.meshgrid(self.era5_lon, self.era5_lat) + self.lon_output_grid,self.lat_output_grid = np.meshgrid(self.H08_lon, self.H08_lat) + self._interpolator = interp.LatLonInterpolation( + lat_input_grid, + lon_input_grid, + self.lat_output_grid, + self.lon_output_grid, + ) + + def _apply_window(self,x,window): + ((y_start, y_end), (x_start, x_end)) = window + """ + x: [,,, hight,width] + Crop the data to the H08-window size + """ + if len(x.shape)==2: + return x[y_start:y_end,x_start:x_end] + elif len(x.shape)==3: + return x[:,y_start:y_end,x_start:x_end] + + def _get_files_stats(self): + """ + Scan directories and extract metadata for H08 and ERA5 + + We assume: + - ERA5 files are at self.data_path/LRdata/ with name era5_YYYY_opt.zarr + - HR files are at self.data_path/HRdata/ with folder name H08_YYYY_hourly + """ + + # training or validating, different files will be read + years_to_use = self.train_years if self.train else self.valid_years + logging.info(f"years_to_use: {years_to_use}") + # ERA5 parsing + self.ds_era5 = {} + LR_paths_all = glob.glob( + os.path.join(self.data_path, "LRdata", "era5_*_opt.zarr") + ) + logging.info(f"LR_paths_all: {LR_paths_all}") + # Get years from paths. e.g. '.../era5_2021_opt.zarr' -> '2021' + era5_years = [os.path.basename(p).split('.')[0].split('_')[1] for p in LR_paths_all] + self.era5_paths = dict(zip(era5_years, LR_paths_all)) + logging.info(f"era5_years: {era5_years}") + # Only keep the years to be used + self.era5_paths = { + year: path + for (year, path) in self.era5_paths.items() + if int(year) in years_to_use + } + + # Use the first year to load metadata + first_era5_key = sorted(self.era5_paths.keys())[0] + with xr.open_zarr(self.era5_paths[first_era5_key], consolidated=True) as ds: + + self.era5_lat = ds['latitude'].values + self.era5_lon = ds['longitude'].values + + # H08 parsing + self.ds_H08 = {} + HR_paths_all = glob.glob( + os.path.join(self.data_path, "HRdata", "H08_*_hourly") + ) + + H08_years = [os.path.basename(p).split('.')[0].split('_')[1] for p in HR_paths_all] + self.H08_paths = dict(zip(H08_years, HR_paths_all)) + self.H08_paths = { + year: path + for (year, path) in self.H08_paths.items() + if int(year) in years_to_use + } + + first_H08_key = sorted(self.H08_paths.keys())[0] #folds + #the first path self.H08_paths[first_H08_key] + nc_file = glob.glob(os.path.join(self.H08_paths[first_H08_key], '*.nc'))[0] + logging.info(f"We achieve the lat/lon from {nc_file}") + with xr.open_dataset(nc_file) as ds: + self.H08_lat = ds['latitude'].values + self.H08_lon = ds['longitude'].values + + # Get all years + self.years = set([int(key) for key in self.H08_paths.keys()]) + self.n_samples_total = self.compute_total_samples() + + def __len__(self): + return len(self.valid_samples)-1 + + + def compute_total_samples(self): + + # count the total number of samples from valid_time of H08 files + all_datetimes = set() + # Loop self.H08_paths.values() from _get_files_stats + for year, path in self.H08_paths.items(): + logging.info(f"Reading {year} H08 files: {path}") + nc_files = glob.glob(os.path.join(path, 'H08_*.nc')) + for file_path in nc_files: + filename = os.path.basename(file_path) + datetime_str = '_'.join(filename.split('_')[1:3]) + datetime_obj = np.datetime64(datetime.strptime(datetime_str, '%Y%m%d_%H%M')) + all_datetimes.add(datetime_obj) + self.valid_samples = sorted(list(all_datetimes)) + + logging.info( + "Scan done. We have {} samlpes".format(len(self.valid_samples)) + ) + logging.info(f"The first time: {self.valid_samples[0]}") + logging.info(f"The last time: {self.valid_samples[-1]}") + + # prepare data for distributed training + if self.shard: + dist_manager = DistributedManager() + self.valid_samples = np.array_split( + self.valid_samples, dist_manager.world_size + )[dist_manager.rank] + logging.info( + f"(Rank {dist_manager.rank}) " + f"has {len(self.valid_samples)} samples" + ) + + return len(self.valid_samples) + + def normalize_input(self, x): + x = x.astype(np.float32) + if self.normalize: + x -= self.input_center + x /= self.input_scale + return x + + def denormalize_input(self, x): + x = x.astype(np.float32) + if self.normalize: + x *= self.input_scale + x += self.input_center + return x + + def normalize_output(self, x): + x = x.astype(np.float32) + if self.normalize: + x -= self.output_center + x /= self.output_scale + return x + + def denormalize_output(self, x): + x = x.astype(np.float32) + + if self.normalize: + x *= self.output_scale + x += self.output_center + return x + + + def _interp(self,x): + + x = torch.from_numpy(x).unsqueeze(0) + x = self._interpolator(x.float()).squeeze(0).numpy() + return x + def _get_inv(self): + file_path = os.path.join(self.data_path, "dem.nc") + + ds = xr.open_dataset(file_path) + invs = [] + for inv in self.invariant_variables: + invs.append(ds[inv].values) + invs = np.stack(invs) + invs = (invs - self.inv_center)/self.inv_scale + + return invs + def _get_era5(self, ts): + """ + Retrieve ERA5 samples from zarr files given valid_time + """ + year = ts.astype('datetime64[Y]').astype(int) + 1970 + year_str = str(year) + + #cache the handle + if year_str not in self.ds_era5: + era5_path = self.era5_paths[year_str] + self.ds_era5[year_str] = xr.open_zarr(era5_path, consolidated=True) + #get the handle + era5_handle = self.ds_era5[year_str] + + era5_field = [] + for var in self.input_variables: + era5_field.append(era5_handle[var].sel(valid_time=ts,method='nearest').values) + era5_field = np.stack(era5_field) + + if len(era5_field.shape) == 4: + era5_field = era5_field[0] + + era5_field = self._interp(era5_field) + + era5_field = self.normalize_input(era5_field) + + return era5_field + + def _get_H08(self, ts): + """ + Retrieve H08 samples from nc files given valid_time + """ + + year = ts.astype('datetime64[Y]').astype(int) + 1970 + year_str = str(year) + #ts --> nc_filename + H08_path = self.H08_paths[year_str] + ts_pd = pd.to_datetime(ts) + datetime_str = ts_pd.strftime('%Y%m%d_%H%M') + filename = f"H08_{datetime_str}_hourly.nc" + file_path = os.path.join(H08_path, filename) + H08_handle = xr.open_dataset(file_path) + H08_field = [] + for var in self.output_variables: + H08_field.append(H08_handle[var].values) + H08_field = np.stack(H08_field) + + if len(H08_field.shape) == 4: + H08_field = H08_field[0] + H08_field = self.normalize_output(H08_field) + return H08_field + + def image_shape(self) -> Tuple[int, int]: + """Get the (height, width) of the data (same for input and output).""" + + return (self.window_size, self.window_size) + + def compute_sza(self, ts): + """Compute solar zenith angle for given coordinates. + """ + grid = np.meshgrid(self.H08_lon, self.H08_lat) + lon, lat = grid[0].reshape(-1), grid[1].reshape(-1) + + pd_ts = pd.to_datetime(ts) + yy, mm, dd, hh = pd_ts.year, pd_ts.month, pd_ts.day, pd_ts.hour + + zeith_arr = [] + for miint in range(6): + ztime = datetime(yy, mm, dd, hh, miint * 10, 0) + zeith = cos_zenith_angle(ztime, lon, lat).reshape((len(self.H08_lat),len(self.H08_lon))) + + zeith_arr.append(zeith) + + zeith = np.stack(zeith_arr) + + return zeith + + def get_windows(self, stride=8): + window_size = 320 #self.image_shape[0] + height, width = len(self.H08_lat),len(self.H08_lon) + + if window_size > height or window_size > width: + raise ValueError("window_size cannot be larger than the panorama dimensions") + h_starts = list(range(0, height - window_size, stride)) + w_starts = list(range(0, width - window_size, stride)) + + if (height - window_size) not in h_starts: + h_starts.append(height - window_size) + + if (width - window_size) not in w_starts: + w_starts.append(width - window_size) + + windows = [] + for h_s in h_starts: + for w_s in w_starts: + h_e = h_s + window_size + w_e = w_s + window_size + windows.append((h_s, h_e, w_s, w_e)) + + return windows + + def __getitem__(self, global_idx): + """Return a tuple of: + - H08_field: High-resolution H08 output data + - era5_field: Low-resolution ERA5 input data (interpolated) + - lead_time_label: Lead time + """ + time_index = self._global_idx_to_datetime(global_idx) + + H08_sample = self._get_H08(time_index) + era5_sample_T = self._get_era5(time_index) + time_index_1 = self._global_idx_to_datetime(global_idx+1) + era5_sample_T_1 = self._get_era5(time_index_1) + era5_sample = np.stack([era5_sample_T, era5_sample_T_1], axis=1) + C,H,W = era5_sample_T.shape + era5_sample = era5_sample.reshape(C*2, H, W) + zeith = self.compute_sza(time_index) + + era5_sample = np.concatenate([era5_sample,zeith],axis=0) + + if np.isnan(era5_sample).any() or np.isnan(H08_sample).any(): + logging.info(f"We find nan in sample at {time_index}") + torch.cuda.nvtx.range_pop() + if self.invariant_variables is not None: + img_lr = np.concatenate([era5_sample,self.invs],axis=0) + else: + img_lr = era5_sample + + if not self.generating: + #when training, we randomly select a window + idx = random.randint(0,len(self.windows)-1) + window = ((self.windows[idx][0], self.windows[idx][1]),(self.windows[idx][2], self.windows[idx][3])) + + img_lr = self._apply_window(img_lr,window) + H08_sample = self._apply_window(H08_sample,window) + else: + #when generating, we return all sliding windows + window = self.get_windows(stride=160) + logging.info(f"window0:{window[0]}") + logging.info(f"windows:{window}") + + pd_ts = pd.to_datetime(time_index) + yy, mm, dd, hh = pd_ts.year, pd_ts.month, pd_ts.day, pd_ts.hour + + return H08_sample, img_lr, window, (yy, mm, dd, hh) + + def _global_idx_to_datetime(self, global_idx): + """ + Parse a global sample index and return the input/target timstamps as datetimes + """ + return self.valid_samples[global_idx] + + @staticmethod + def _create_lowres_(x, factor=4): + # downsample the high res imag + x = x.transpose(1, 2, 0) + x = x[::factor, ::factor, :] # 8x8x3 #subsample + # upsample with bicubic interpolation to bring the image to the nominal size + x = cv2.resize( + x, (x.shape[1] * factor, x.shape[0] * factor), interpolation=cv2.INTER_CUBIC + ) # 32x32x3 + x = x.transpose(2, 0, 1) # 3x32x32 + return x + + def latitude(self): + return self.H08_lat #if self.train else self.crop_to_fit(self.H08_lat) + + def longitude(self): + return self.H08_lon #if self.train else self.crop_to_fit(self.H08_lon) + + def input_channels(self): + era5_variables = self.input_variables + self.era5_input = era5_variables + era5_variables_1 = [s + '_1' for s in self.input_variables] + variables = era5_variables + era5_variables_1 + ['zeith0','zeith1','zeith2','zeith3','zeith4','zeith5'] + #[ var for pair in zip(era5_variables, era5_variables_1) for var in pair] + ['zeith0','zeith1','zeith2','zeith3','zeith4','zeith5'] + if self.invariant_variables is not None: + variables += self.invariant_variables + return [ChannelMetadata(name=n) for n in variables] + else: + return [ChannelMetadata(name=n) for n in variables] + + def output_channels(self): + variables = self.output_variables + [s + '_1' for s in self.output_variables]+ [s + '_2' for s in self.output_variables]+ [s + '_3' for s in self.output_variables]+ [s + '_4' for s in self.output_variables]+ [s + '_5' for s in self.output_variables] + return [ChannelMetadata(name=n) for n in variables] + + def time(self): + return self.valid_samples + + + + +def _load_stats(stats, variables, group, normal_way = "min_max"): + + if normal_way == "min_max": + center = np.array([stats[group][v]["center"] for v in variables])[:, None, None].astype( + np.float32 + ) + scale = np.array([stats[group][v]["scale"] for v in variables])[:, None, None].astype( + np.float32 + ) + elif normal_way == "mean_std": + center = np.array([stats[group][v]["mean"] for v in variables])[:, None, None].astype( + np.float32 + ) + scale = np.array([stats[group][v]["std"] for v in variables])[:, None, None].astype( + np.float32 + ) + + return (center, scale) diff --git a/examples/weather/corrdiff/generate.py b/examples/weather/corrdiff/generate.py index 5e1e699427..5392245a76 100644 --- a/examples/weather/corrdiff/generate.py +++ b/examples/weather/corrdiff/generate.py @@ -88,7 +88,7 @@ def main(cfg: DictConfig) -> None: # Create dataset object dataset_cfg = OmegaConf.to_container(cfg.dataset) - + # Register dataset (if custom dataset) register_dataset(cfg.dataset.type) logger0.info(f"Using dataset: {cfg.dataset.type}") @@ -100,6 +100,7 @@ def main(cfg: DictConfig) -> None: dataset, sampler = get_dataset_and_sampler( dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time ) + solar = getattr(dataset, 'solar', False) #If we are doing solar downscaling img_shape = dataset.image_shape() img_out_channels = len(dataset.output_channels()) @@ -381,7 +382,8 @@ def elapsed_time(self, _): start = end = DummyEvent() times = dataset.time() - for dataset_index, (image_tar, image_lr, *lead_time_label) in zip( + + for dataset_index, batch in zip( sampler, iter(data_loader), ): @@ -392,7 +394,11 @@ def elapsed_time(self, _): if time_index == warmup_steps: start.record() - # continue + if solar: + image_tar, image_lr, windows, *lead_time_label = batch + else: + image_tar, image_lr, *lead_time_label = batch + if lead_time_label: lead_time_label = lead_time_label[0].to(dist.device).contiguous() else: @@ -403,7 +409,19 @@ def elapsed_time(self, _): .to(memory_format=torch.channels_last) ) image_tar = image_tar.to(device=device).to(torch.float32) - image_out = generate_fn() + if solar:#We need perform multi-diffusion generation + image_out = generate_solar(image_lr_full=image_lr, + image_tar_full=image_tar, + windows=windows, + net_reg=net_reg, + logger0 = logger0 + img_out_channels = img_out_channels, + net_res = net_res, + seeds = seeds + ) + else: + image_out = generate_fn() + if dist.rank == 0: batch_size = image_out.shape[0] if cfg.generation.perf.io_synchronous: @@ -419,6 +437,7 @@ def elapsed_time(self, _): image_lr.cpu(), time_index, dataset_index, + solar ) ) else: @@ -431,6 +450,7 @@ def elapsed_time(self, _): image_lr.cpu(), time_index, dataset_index, + solar ) end.record() end.synchronize() diff --git a/examples/weather/corrdiff/helpers/generate_helpers.py b/examples/weather/corrdiff/helpers/generate_helpers.py index abfa4fee87..71f40c91c3 100644 --- a/examples/weather/corrdiff/helpers/generate_helpers.py +++ b/examples/weather/corrdiff/helpers/generate_helpers.py @@ -21,6 +21,15 @@ from datasets.dataset import init_dataset_from_config from datasets.base import DownscalingDataset +import torch +import torch.nn as nn +from torch import Tensor +from typing import Optional, Callable +from tqdm.auto import tqdm +import nvtx +from physicsnemo.utils.corrdiff import regression_step +from physicsnemo.utils.generative import StackedRandomGenerator + def get_dataset_and_sampler(dataset_cfg, times, has_lead_time=False): """ @@ -52,6 +61,7 @@ def save_images( image_lr, time_index, dataset_index, + solar ): """ Saves inferencing result along with the baseline @@ -75,7 +85,17 @@ def save_images( # weather sub-plot image_lr2 = image_lr[0].unsqueeze(0) image_lr2 = image_lr2.cpu().numpy() - image_lr2 = dataset.denormalize_input(image_lr2) + if solar: #In solar downscaling, we input the ear5 variables at two time (T, T+1) + len_era5 = len(dataset.era5_input) + era5_sample = image_lr2[:,0:2*len_era5,:,:] + + era5_sample_T = era5_sample[:,0::2, :, :] + era5_sample_T_1 = era5_sample[:,1::2, :, :] + + image_lr2[:,0:len_era5,:,:] = dataset.denormalize_input(era5_sample_T) + image_lr2[:,len_era5:2*len_era5,:,:] = dataset.denormalize_input(era5_sample_T_1) + else: + image_lr2 = dataset.denormalize_input(image_lr2) image_tar2 = image_tar[0].unsqueeze(0) image_tar2 = image_tar2.cpu().numpy() @@ -113,3 +133,206 @@ def save_images( writer.write_input(channel_name, time_index, image_lr2[0, channel_idx]) if channel_idx == image_lr2.shape[1] - 1: break + + + +class MultiDiffusion(nn.Module): + def __init__(self, device): + super().__init__() + self.device = device + + @torch.no_grad() + def __call__( + self, + net: torch.nn.Module, + img_lr: Tensor, + regression_output: Tensor, + class_labels: Optional[Tensor] = None, + randn_like: Callable[[Tensor], Tensor] = torch.randn_like, + windows: Optional[Tensor] = None, + lead_time_label: Optional[Tensor] = None, + num_steps: int = 18, + sigma_min: float = 0.002, + sigma_max: float = 800, + rho: float = 7, + S_churn: float = 0, + S_min: float = 0, + S_max: float = float("inf"), + S_noise: float = 1, + ) -> Tensor: + """ + + Args: + net (torch.nn.Module): the diffusion model + regression_output (Tensor): output from regression model (B, C_cond, H, W)。 + randn_like (Callable): gaussian sampler + windows : All windows + stride (int): the stride between windows + + Returns: + Tensor: (B, C_out, H, W) + """ + + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + batch_size, _, height, width = regression_output.shape + x_lr = torch.cat((regression_output,img_lr), dim=1) + latents = randn_like(regression_output) + + step_indices = torch.arange(num_steps, device=self.device) + t_steps = ( + sigma_max ** (1 / rho) + + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) + + views = windows + + value = torch.zeros_like(latents) + count = torch.zeros_like(latents) + + optional_args = {} + if lead_time_label is not None: + optional_args["lead_time_label"] = lead_time_label + + x_next = latents * t_steps[0] + + for i, (t_cur, t_next) in enumerate(tqdm(zip(t_steps[:-1], t_steps[1:]), total=num_steps)): + x_cur = x_next.clone() + + value.zero_() + count.zero_() + #print(f"x_cur:{x_cur.shape}") + for view in views: + h_start, h_end, w_start, w_end = int(view[0]),int(view[1]),int(view[2]),int(view[3]) + x_cur_view = x_cur[:, :, h_start:h_end, w_start:w_end] + x_lr_view = x_lr[:, :, h_start:h_end, w_start:w_end] + + #(Churning) + gamma = S_churn / num_steps if S_min <= t_cur <= S_max else 0 + t_hat = net.round_sigma(t_cur + gamma * t_cur) + x_hat_view = x_cur_view + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur_view) + #(Euler step part 1) + denoised_view = net( + x_hat_view, + x_lr_view, + t_hat, + class_labels, + **optional_args, + ) + + d_cur_view = (x_hat_view - denoised_view) / t_hat + + x_next_view_first_order = x_hat_view + (t_next - t_hat) * d_cur_view + + if i < num_steps - 1: + denoised_prime_view = net( + x_next_view_first_order, + x_lr_view, + t_next, + class_labels, + **optional_args, + ) + d_prime_view = (x_next_view_first_order - denoised_prime_view) / t_next + x_next_view = x_hat_view + (t_next - t_hat) * (0.5 * d_cur_view + 0.5 * d_prime_view) + else: + x_next_view = x_next_view_first_order + + value[:, :, h_start:h_end, w_start:w_end] += x_next_view + count[:, :, h_start:h_end, w_start:w_end] += 1 + #print(f"One window finished.") + + x_next = torch.where(count > 0, value / count, value) + + return x_next + + + +def generate_solar( + image_lr_full: torch.Tensor, + image_tar_full: torch.Tensor, + windows: list, + net_reg: torch.nn.Module, + logger0, + img_out_channels: int = None, + net_res: torch.nn.Module = None, + seeds: list = None, + lead_time_label: torch.Tensor = None, +): + """ + A full generation function for solar downscaling + Regression and selectabel Resiual in Multi-Diffusion way + + This function first run regression model on multi-windows to get the high-resolution output. + if net_res is provided, Multi-Diffusion is performed for fine details + Args: + image_lr_full (torch.Tensor): Low-resolution input tensor + image_tar_full (torch.Tensor): The target high-resolution tensor + windows (list): The pre-defined windows + net_reg (torch.nn.Module): Regression network + logger0 + net_res (torch.nn.Module, optional): The resiual Diffusion network + + Returns: + torch.Tensor or None: (B, C, H, W)。 + """ + with nvtx.annotate("generate_solar", color="blue"): + device = image_lr_full.device + + image_reg_full = torch.zeros_like(image_tar_full).to(device=device).to(torch.float32) + counts = torch.zeros_like(image_tar_full) + logger0.info(f"Input LR shape: {image_lr_full.shape}") + logger0.info(f"Target HR shape for stitching: {image_reg_full.shape}") + + with nvtx.annotate("solar_regression_stitching", color="green"): + for window in windows: + y_start, y_end, x_start, x_end = map(int, [window[0].item(), window[1].item(), window[2].item(), window[3].item()]) + + image_lr_patch = image_lr_full[:, :, y_start:y_end, x_start:x_end] + + _, _, h, w = image_lr_patch.shape + latents_shape = (image_lr_patch.shape[0], img_out_channels, h, w) + + with nvtx.annotate("regression_model_step", color="yellow"): + image_reg_patch = regression_step( + net=net_reg, + img_lr=image_lr_patch.to(memory_format=torch.channels_last), + latents_shape=latents_shape, + lead_time_label=lead_time_label, + ) + + image_reg_full[:, :, y_start:y_end, x_start:x_end] += image_reg_patch + counts[:, :, y_start:y_end, x_start:x_end] += 1 + + counts = torch.where(counts == 0, torch.ones_like(counts), counts) + image_reg_full = image_reg_full / counts + logger0.info(f"Stitched regression image shape: {image_reg_full.shape}") + + final_output = image_reg_full + + if net_res: + mdiff = MultiDiffusion(image_reg_full.device) + with nvtx.annotate("solar_multidiffusion", color="purple"): + logger0.info("Performing Multi-Diffusion step...") + regression_output = image_reg_full + ensemble_outputs = [] + + for i in seeds: + rnd = StackedRandomGenerator(regression_output.device, [i]) + + image_res_out_full = mdiff( + net=net_res, + img_lr=image_lr_full, + regression_output=regression_output, + windows=windows, + randn_like=rnd.randn_like, + ) + ensemble_outputs.append(regression_output + image_res_out_full) + + final_output = torch.cat(ensemble_outputs, dim=0) + logger0.info(f"Final ensemble output shape: {final_output.shape}") + else: + logger0.info("Skipping diffusion step. Output is from the regression model.") + + + return final_output \ No newline at end of file diff --git a/examples/weather/corrdiff/images/swdr_reg_1028_s80_.png b/examples/weather/corrdiff/images/swdr_reg_1028_s80_.png new file mode 100644 index 0000000000..ee7d494fb5 Binary files /dev/null and b/examples/weather/corrdiff/images/swdr_reg_1028_s80_.png differ diff --git a/examples/weather/corrdiff/images/swdr_wDiff_1028_s80_.png b/examples/weather/corrdiff/images/swdr_wDiff_1028_s80_.png new file mode 100644 index 0000000000..4b08be6ab1 Binary files /dev/null and b/examples/weather/corrdiff/images/swdr_wDiff_1028_s80_.png differ diff --git a/examples/weather/corrdiff/prepare_solar_data/get_dem/1-download.sh b/examples/weather/corrdiff/prepare_solar_data/get_dem/1-download.sh new file mode 100755 index 0000000000..fe73dfde3e --- /dev/null +++ b/examples/weather/corrdiff/prepare_solar_data/get_dem/1-download.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +wget https://srtm.csi.cgiar.org/wp-content/uploads/files/srtm_30x30/TIFF/N00E060.zip +wget https://srtm.csi.cgiar.org/wp-content/uploads/files/srtm_30x30/TIFF/N00E090.zip +wget https://srtm.csi.cgiar.org/wp-content/uploads/files/srtm_30x30/TIFF/N30E060.zip +wget https://srtm.csi.cgiar.org/wp-content/uploads/files/srtm_30x30/TIFF/N30E090.zip +wget https://srtm.csi.cgiar.org/wp-content/uploads/files/srtm_30x30/TIFF/N30E120.zip +wget https://srtm.csi.cgiar.org/wp-content/uploads/files/srtm_30x30/TIFF/N00E120.zip + +mkdir -p tif + +unzip -d tif N00E060.zip +unzip -d tif N00E090.zip +unzip -d tif N30E060.zip +unzip -d tif N30E090.zip +unzip -d tif N30E120.zip +unzip -d tif N00E120.zip diff --git a/examples/weather/corrdiff/prepare_solar_data/get_dem/2-outputDEM.py b/examples/weather/corrdiff/prepare_solar_data/get_dem/2-outputDEM.py new file mode 100644 index 0000000000..f1785eb33c --- /dev/null +++ b/examples/weather/corrdiff/prepare_solar_data/get_dem/2-outputDEM.py @@ -0,0 +1,122 @@ +import subprocess +import sys +import os + +import xarray as xr +try: + import rioxarray +except ImportError: + try: + subprocess.check_call([sys.executable, "-m", "pip", "install", "rioxarray"]) + print("rioxarray installed successfully.") + except subprocess.CalledProcessError as e: + print(f"Failed to install rioxarray. Error: {e}") + sys.exit(1) + +import numpy as np + +import glob +import re + +import cartopy.crs as ccrs +import cartopy.feature as cfeature +import matplotlib.pyplot as plt + + +tiff_folder_path = './tif' + +tiff_files = sorted(glob.glob(os.path.join(tiff_folder_path, 'cut_*.tif'))) + +if not tiff_files: + print(f"no files under '{tiff_folder_path}'") +else: + lat_keys = sorted(list(set(re.search(r'(n\d+|s\d+)', f).group(1) for f in tiff_files)), reverse=True) + lon_keys = sorted(list(set(re.search(r'(e\d+|w\d+)', f).group(1) for f in tiff_files))) + + nested_files = [] + for lat_key in lat_keys: + row = [] + for lon_key in lon_keys: + + matching_file = next((f for f in tiff_files if lat_key in f and lon_key in f), None) + if matching_file: + row.append(matching_file) + if row: + nested_files.append(row) + + + for r in nested_files: + print([os.path.basename(p) for p in r]) + + try: + merged_dataset = xr.open_mfdataset( + nested_files, + engine="rasterio", + combine='nested', + concat_dim=['y', 'x'], + chunks={} + ) + + print("\nSucess") + print(merged_dataset) + + except Exception as e: + print(f"\nFail:{e}") + + +data_subset = merged_dataset['band_data'].squeeze('band', drop=True).rename({'y': 'latitude', 'x': 'longitude'}) + +print(data_subset) + +lon_num_points = int(round((135-80) / 0.05)) + 1 +lat_num_points = int(round((55-15) / 0.05)) + 1 +target_lon = np.linspace(80, 135, lon_num_points) +target_lat = np.linspace(15, 55, lat_num_points) + + +interpolated_ds = data_subset.interp( + latitude=target_lat, + longitude=target_lon, + method="nearest", + kwargs={"fill_value": None} +) +interpolated_ds = interpolated_ds.fillna(0) + +print('Interpolated data structure') +interpolated_ds = interpolated_ds.to_dataset(name='dem') +print(interpolated_ds) + + +print("\nTake a view...") + +data_to_plot_dask = interpolated_ds['dem'] + +fig = plt.figure(figsize=(12, 10)) +ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) + +im = ax.pcolormesh( + data_to_plot_dask['longitude'], + data_to_plot_dask['latitude'], + data_to_plot_dask.compute(), + transform=ccrs.PlateCarree(), + cmap='terrain', +) + +ax.coastlines() +ax.add_feature(cfeature.BORDERS, linestyle=':') +ax.add_feature(cfeature.OCEAN, zorder=100, edgecolor='k') +ax.add_feature(cfeature.LAND) +ax.add_feature(cfeature.RIVERS) +ax.add_feature(cfeature.LAKES) + +gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True, + linewidth=1, color='gray', alpha=0.5, linestyle='--') +gl.top_labels = False +gl.right_labels = False + +plt.colorbar(im, ax=ax, shrink=0.7, label='Elevation (m)') +plt.title('Interpolated Terrain Data (0.05-degree Resolution)') +plt.savefig('dem_005degree.png', dpi=150) +OUTPUT_NC_PATH = 'dem.nc' +print(f"\nStoring to: {OUTPUT_NC_PATH}...") +interpolated_ds.to_netcdf(OUTPUT_NC_PATH) diff --git a/examples/weather/corrdiff/prepare_solar_data/prepare_era5.py b/examples/weather/corrdiff/prepare_solar_data/prepare_era5.py new file mode 100644 index 0000000000..3abdddc04d --- /dev/null +++ b/examples/weather/corrdiff/prepare_solar_data/prepare_era5.py @@ -0,0 +1,89 @@ + + +import netCDF4 as nc +import numpy as np +import xarray as xr +import os +from datetime import datetime +from tqdm import tqdm + + +def rename_variable_on_load(ds, filename, filename_to_varname_map): + + base_filename = os.path.basename(filename) + + target_var_name = filename_to_varname_map.get(base_filename) + + if 'time' in ds.coords and 'valid_time' not in ds.coords: + ds = ds.rename({'time': 'valid_time'}) + if 'pressure_level' in ds.dims and ds.dims['pressure_level'] == 1: + ds = ds.squeeze('pressure_level', drop=True) + + if target_var_name: + original_var_name = list(ds.data_vars)[0] + if original_var_name != target_var_name: + return ds.rename({original_var_name: target_var_name}) + return ds + + +def main(year): + path = "path/to/ERA5/data/{}".format(year) + + filename_to_varname_map = { + f'{year}_2m_temperature.nc': 't2m', + f'{year}_surface_pressure.nc': 'sp', + f'{year}_ssrd.nc': 'ssrd', + f'{year}_total_column_water_vapour.nc': 'tcwv' + } + + for level in [1000, 925, 500, 300, 100, 50]: + filename_to_varname_map[f'{year}_q_{level}.nc'] = f'q{level}' + filename_to_varname_map[f'{year}_t_{level}.nc'] = f't{level}' + filename_to_varname_map[f'{year}_z_{level}.nc'] = f'z{level}' + + input_files = [os.path.join(path, fname) for fname in filename_to_varname_map.keys()] + + print(input_files) + + from functools import partial + preprocess_func = partial(rename_variable_on_load, filename_to_varname_map=filename_to_varname_map) + + ds = xr.open_mfdataset( + input_files, + preprocess=lambda ds: preprocess_func(ds, ds.encoding["source"]), + combine='by_coords' + ) + + + lat_bounds = [15, 55] + lon_bounds = [80, 135] + ds_cropped = ds.sel( + latitude=slice(max(lat_bounds), min(lat_bounds)), + longitude=slice(min(lon_bounds), max(lon_bounds)) + ) + ds_sorted = ds_cropped.sortby('latitude') + + order = [ + "sp", "t2m", "tcwv", "ssrd", "q1000", "q925", "q500", "q300", "q100", "q50", + "t1000", "t925", "t500", "t300", "t100", "t50", + "z1000", "z925", "z500", "z300", "z100", "z50" + ] + ds_sorted = ds_sorted[order] + + print(ds['valid_time'].values[0],ds['valid_time'].values[-1]) + output_filename = "output/path/era5_{}_opt.zarr".format(year) + ds_final = ds_sorted.chunk({ + 'valid_time': 1, + 'latitude': "auto", + 'longitude': "auto", + }) + print(ds_final) + print(ds_final.data_vars) + ds_final = ds_final.fillna(0) + ds_final.to_zarr(output_filename, mode='w', consolidated=True) + + +if __name__ == "__main__": + years = [2016,2017,2018,2019,2020] + for year in years: + main(year) \ No newline at end of file diff --git a/examples/weather/corrdiff/prepare_solar_data/prepare_h08_hourly.py b/examples/weather/corrdiff/prepare_solar_data/prepare_h08_hourly.py new file mode 100644 index 0000000000..08f1be6966 --- /dev/null +++ b/examples/weather/corrdiff/prepare_solar_data/prepare_h08_hourly.py @@ -0,0 +1,79 @@ +import xarray as xr +import pandas as pd +import os +from pathlib import Path + +def process_year_data(base_path, year): + """ + Args: + base_path (str or Path): the path to data + year (int) + """ + base_path = Path(base_path) + output_dir = base_path / f"H08_{year}_hourly" + output_dir.mkdir(exist_ok=True) + + #creat the times range, with 10mins interval + time_range = pd.date_range(start=f'{year}-01-01 00:00', end=f'{year}-12-31 23:50', freq='10min') + + for group_name, group_df in time_range.to_frame().resample('H'): + if len(group_df) == 6: # Make sure there are 6 files in an hour + file_paths = [] + all_files_exist = True + for dt in group_df.index: + #your original SWDR nc files + file_path = base_path / str(dt.year) / f"{dt.year}{dt.month:02d}" / f"{dt.day:02d}" / f"H08_{dt.strftime('%Y%m%d_%H%M')}_SWDR.nc" + + if file_path.exists(): + file_paths.append(file_path) + else: + print(f"Files Missing: {group_name.strftime('%Y-%m-%d %H')}") + all_files_exist = False + break + + if all_files_exist: + try: + datasets = [xr.open_dataset(fp) for fp in file_paths] + + lat_slice = slice(55, 15) + lon_slice = slice(80, 135) + + datasets = [ds.sel(lat=lat_slice, lon=lon_slice) for ds in datasets] + + is_night_data = False + for ds in datasets: + if ds['SWDR'].max() == 0: + print(f"This is night (SWDR全为0). Skip:{group_name.strftime('%Y-%m-%d %H')}") + is_night_data = True + break + + if not is_night_data: + valid_times = [pd.to_datetime(ds.encoding['source'].split('_')[1] + ds.encoding['source'].split('_')[2], format='%Y%m%d%H%M') for ds in datasets] + + combined_ds = xr.concat( + [ds['SWDR'] for ds in datasets], + dim=pd.Index(valid_times, name='valid_time') + ) + + combined_ds = combined_ds.fillna(0) + + combined_ds = combined_ds.astype('float32') + + combined_ds = combined_ds.sortby('lat') + combined_ds = combined_ds.sortby('valid_time') + + combined_ds = combined_ds.rename({'lat': 'latitude', 'lon': 'longitude'}) + + output_filename = output_dir / f"H08_{group_name.strftime('%Y%m%d_%H%M')}_hourly.nc" + + combined_ds.to_netcdf(output_filename) + print(f"Saved H08_{group_name.strftime('%Y%m%d_%H%M')}_hourly") + except Exception as e: + print(f"Errors happen at {group_name.strftime('%Y-%m-%d %H')}: {e}") + + +data_directory = "path/to/original/data" + +year_to_process = [2016,2017,2018,2019] +for year in year_to_process: + process_year_data(data_directory, year) \ No newline at end of file diff --git a/examples/weather/corrdiff/train.py b/examples/weather/corrdiff/train.py index d56c367cc1..08bdf6eb38 100644 --- a/examples/weather/corrdiff/train.py +++ b/examples/weather/corrdiff/train.py @@ -131,7 +131,7 @@ def main(cfg: DictConfig) -> None: # Resolve and parse configs OmegaConf.resolve(cfg) dataset_cfg = OmegaConf.to_container(cfg.dataset) # TODO needs better handling - + # Register custom dataset if specified in config register_dataset(cfg.dataset.type) logger0.info(f"Using dataset: {cfg.dataset.type}") @@ -191,6 +191,7 @@ def main(cfg: DictConfig) -> None: # Parse image configuration & update model args dataset_channels = len(dataset.input_channels()) + solar = getattr(dataset, 'solar', False) #If we are doing solar downscaling img_in_channels = dataset_channels img_shape = dataset.image_shape() img_out_channels = len(dataset.output_channels()) @@ -541,9 +542,13 @@ def main(cfg: DictConfig) -> None: f"accumulation round {n_i}", color="Magenta" ): with nvtx.annotate("loading data", color="green"): - img_clean, img_lr, *lead_time_label = next( + batch = next( dataset_iterator ) + if solar: + img_clean, img_lr, windows, *lead_time_label = batch + else: + img_clean, img_lr, *lead_time_label = batch if use_apex_gn: img_clean = img_clean.to( dist.device,