feat: Integrate CorrDiffSolar model with MultiDiffusion to PhysicsNem… #1222
base: main
Greptile Overview
Greptile Summary
This PR adds a comprehensive solar irradiance downscaling example to PhysicsNeMo, implementing a two-stage generative model (Regression + Diffusion) with MultiDiffusion for large-scale inference.
Key Changes:
Critical Issue:
Architecture:
Confidence Score: 2/5
Important Files Changed
File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant TrainScript as train.py
    participant GenScript as generate.py
    participant Dataset as SolarDataset
    participant RegNet as Regression Network
    participant DiffNet as Diffusion Network
    participant MultiDiff as MultiDiffusion

    Note over User,MultiDiff: Training Phase (Stage 1: Regression)
    User->>TrainScript: torchrun train.py --config regression
    TrainScript->>Dataset: Initialize SolarDataset
    Dataset-->>TrainScript: Return dataset with windows
    loop Training Loop
        TrainScript->>Dataset: Get batch (img_clean, img_lr, windows)
        Dataset-->>TrainScript: Return training data
        TrainScript->>RegNet: Forward pass
        RegNet-->>TrainScript: Predictions
        TrainScript->>TrainScript: Compute loss & update weights
    end
    TrainScript-->>User: Save regression checkpoint

    Note over User,MultiDiff: Training Phase (Stage 2: Diffusion)
    User->>TrainScript: torchrun train.py --config diffusion
    TrainScript->>RegNet: Load regression checkpoint
    TrainScript->>Dataset: Initialize SolarDataset
    loop Training Loop
        TrainScript->>Dataset: Get batch
        Dataset-->>TrainScript: Return data
        TrainScript->>RegNet: Get regression baseline
        RegNet-->>TrainScript: Base prediction
        TrainScript->>DiffNet: Forward pass (learn residual)
        DiffNet-->>TrainScript: Residual prediction
        TrainScript->>TrainScript: Compute residual loss
    end
    TrainScript-->>User: Save diffusion checkpoint

    Note over User,MultiDiff: Inference Phase
    User->>GenScript: python generate.py --config inference
    GenScript->>Dataset: Initialize SolarDataset (generating=True)
    Dataset-->>GenScript: Return dataset with sliding windows
    GenScript->>RegNet: Load regression checkpoint
    GenScript->>DiffNet: Load diffusion checkpoint
    loop For each time step
        GenScript->>Dataset: Get batch (img_tar, img_lr, windows)
        Dataset-->>GenScript: Return full-size data with windows
        GenScript->>MultiDiff: generate_solar(img_lr, windows, nets)
        loop For each window (Regression)
            MultiDiff->>RegNet: regression_step(img_lr_patch)
            RegNet-->>MultiDiff: Patch prediction
            MultiDiff->>MultiDiff: Stitch patches together
        end
        MultiDiff->>MultiDiff: Average overlapping regions
        alt Diffusion enabled
            loop For each ensemble seed
                loop For each diffusion step
                    loop For each window
                        MultiDiff->>DiffNet: Denoise patch
                        DiffNet-->>MultiDiff: Denoised patch
                        MultiDiff->>MultiDiff: Accumulate in value/count
                    end
                    MultiDiff->>MultiDiff: Average overlapping regions
                end
                MultiDiff->>MultiDiff: residual = diffusion_output
                MultiDiff->>MultiDiff: final = regression + residual
            end
        end
        MultiDiff-->>GenScript: Final high-res output
        GenScript->>GenScript: Save to NetCDF
    end
    GenScript-->>User: Output: corrdiff_output.nc
```
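The inference path in the diagram (patch-wise regression, per-step window denoising with value/count accumulation, then `final = regression + residual`) can be sketched in PyTorch as below. This is a minimal illustration under stated assumptions, not the PR's `generate_solar`: the helper names `sliding_windows`, `stitch`, and `multidiffusion_generate`, the `denoise_step(x_t, cond, mean, t)` callback signature, and the rule of clamping the last window to the domain border are all hypothetical. It also assumes `x_lr` has already been interpolated onto the high-resolution grid so that every window slices matching regions of all inputs.

```python
import torch

# Hypothetical helpers for illustration; not the PR's generate_solar / regression_step API.


def sliding_windows(height, width, patch, stride):
    """Top-left corners of patch x patch windows that tile an H x W grid.

    The last window along each axis is shifted back to touch the border, so
    every pixel is covered by at least one window.
    """
    assert height >= patch and width >= patch, "domain must be at least one patch"

    def starts(size):
        s = list(range(0, size - patch + 1, stride))
        if s[-1] + patch < size:
            s.append(size - patch)
        return s

    return [(i, j) for i in starts(height) for j in starts(width)]


def stitch(fn, tensors, patch, stride):
    """Apply fn to co-located patches of every tensor and average the overlaps.

    Patch outputs are summed into `value` and coverage into `count`; their
    ratio is the overlap-averaged full-domain field.
    """
    b, _, h, w = tensors[0].shape
    value = None
    count = torch.zeros(1, 1, h, w, dtype=tensors[0].dtype, device=tensors[0].device)
    for i, j in sliding_windows(h, w, patch, stride):
        win = (slice(None), slice(None), slice(i, i + patch), slice(j, j + patch))
        out = fn(*[t[win] for t in tensors])
        if value is None:
            value = torch.zeros(b, out.shape[1], h, w, dtype=out.dtype, device=out.device)
        value[win] += out
        count[win] += 1
    return value / count


@torch.no_grad()
def multidiffusion_generate(x_lr, regression_net, denoise_step, num_steps,
                            out_channels, patch, stride, generator=None):
    """Regression mean plus a MultiDiffusion-sampled residual on the full domain."""
    b, _, h, w = x_lr.shape
    # Stage 1: patch-wise regression, stitched into a full-domain mean.
    mean = stitch(regression_net, [x_lr], patch, stride)
    # Stage 2: start from noise, denoise every window at every step, and
    # re-average the overlaps after each step so neighboring windows agree.
    x = torch.randn(b, out_channels, h, w, generator=generator,
                    dtype=x_lr.dtype, device=x_lr.device)
    for t in reversed(range(num_steps)):
        x = stitch(lambda xt, c, m: denoise_step(xt, c, m, t),
                   [x, x_lr, mean], patch, stride)
    # Final high-resolution field = regression baseline + sampled residual.
    return mean + x
```

Averaging after every denoising step, rather than only once at the end, is what keeps overlapping windows from drifting apart; a smaller `stride` (larger overlap) costs more network calls per step but gives smoother seams. Running the sampler once per ensemble seed, as in the diagram, only requires passing a differently seeded `generator`.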
13 files reviewed, 1 comment
syntax: missing comma after `logger0 = logger0`, causing a syntax error
```python
image_tar_full=image_tar,
windows=windows,
logger0 = logger0,
img_out_channels = img_out_channels,
```
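The review comment can be checked mechanically. In the minimal sketch below, `generate_fn` is a hypothetical stand-in for the reviewed call (only the keyword-argument lines come from the diff). Inside parentheses Python joins lines implicitly, so without the comma the parser sees `logger0 = logger0 img_out_channels = ...` and rejects it, while the suggested change parses cleanly.

```python
import ast

# Hypothetical call site: only the keyword-argument lines come from the reviewed diff.
broken = """
generate_fn(
    image_tar_full=image_tar,
    windows=windows,
    logger0 = logger0
    img_out_channels = img_out_channels,
)
"""
# Apply the suggested change: add the missing comma after `logger0 = logger0`.
fixed = broken.replace("logger0 = logger0\n", "logger0 = logger0,\n")


def parses(source):
    """Return True if `source` is syntactically valid Python."""
    try:
        ast.parse(source)
    except SyntaxError:
        return False
    return True


print(parses(broken), parses(fixed))  # False True
```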
**Pull Request**
feat(examples): Add CorrDiffSolar for high-resolution solar downscaling
**Pull Request Body (Description)**
Hi PhysicsNeMo Team,
This Pull Request introduces a new, comprehensive example for high-resolution solar irradiance downscaling using a conditional diffusion model, named CorrDiffSolar. This end-to-end example demonstrates a real-world climate science application, showing how to prepare complex datasets, train a two-stage generative model, and perform large-scale inference using techniques such as MultiDiffusion. It serves as a valuable use case for researchers interested in AI-based downscaling.

**Key Contributions**

- A two-stage `Regression + Diffusion` pipeline that upscales low-resolution (0.25°) ERA5 data to high-resolution (0.05°) solar radiation fields, a 5× refinement in grid spacing.
- Data-preparation scripts in `prepare_solar_data/` that process raw ERA5 and Himawari-8 satellite data into a model-ready format, including DEM file generation.
- **`SolarDataset`:** a specialized PyTorch `Dataset` for efficiently handling the paired low-res/high-res spatiotemporal data required for training.
- **`MultiDiffusion`:** the inference logic uses a sliding-window approach to generate predictions for large domains that do not fit in GPU memory, making the model scalable.

**Proposed Code Structure**
All new code is self-contained within a new directory in the `examples/` folder to ensure no disruption to the core library:

**Getting Started: A Quick Workflow Overview**
A complete guide is available in `examples/solar_downscaling/README.md`. The high-level steps are:

1. **Environment Setup:**
2. **Data Preparation:** Run the scripts in `prepare_solar_data/` to process the data, generate the DEM, and compute statistics. This creates the `HRdata/`, `LRdata/`, `dem.nc`, and `stats.json` files required for training.
3. **Model Training (Two Stages):**
4. **Model Inference:**
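The Data Preparation step produces a `stats.json` file of normalization statistics. A minimal sketch of how per-channel statistics of this kind could be computed and applied as a z-score normalization is shown below. The JSON layout (`{channel: {"mean": ..., "std": ...}}`), the `(time, channel, lat, lon)` array layout, the channel names, and the helper names are assumptions for illustration, not the format written by `prepare_solar_data/`. NaN-aware reductions are used because gridded satellite fields may contain missing pixels.

```python
import json

import numpy as np


def compute_stats(data, channel_names):
    """Per-channel mean/std over (time, lat, lon) of a (time, channel, lat, lon) array.

    Hypothetical helper; NaNs (e.g. missing satellite pixels) are ignored.
    """
    mean = np.nanmean(data, axis=(0, 2, 3))
    std = np.nanstd(data, axis=(0, 2, 3))
    return {
        name: {"mean": float(m), "std": float(s)}
        for name, m, s in zip(channel_names, mean, std)
    }


def normalize(data, stats, channel_names, eps=1e-8):
    """Z-score each channel with stored statistics, broadcast over time, lat, lon."""
    mean = np.array([stats[c]["mean"] for c in channel_names]).reshape(1, -1, 1, 1)
    std = np.array([stats[c]["std"] for c in channel_names]).reshape(1, -1, 1, 1)
    return (data - mean) / (std + eps)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    fields = rng.normal(loc=300.0, scale=50.0, size=(8, 2, 16, 16))
    names = ["ssrd", "tcc"]  # hypothetical channel selection
    stats = compute_stats(fields, names)
    print(json.dumps(stats, indent=2))  # contents of a stats.json-style file
```

Computing the statistics once during preparation and reusing them at inference keeps training and generation on the same scale; model outputs are mapped back to physical units with `x * std + mean`.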
We believe this example will be a great addition to the PhysicsNeMo repository. Please let us know if you have any questions or feedback.