9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,15 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.3.0a0] - 2025-11-10

### Added

- Added `EDMPrecond3D`, a 3D `EDMPrecond` wrapper class, together with a 3D version of `SongUNet`, for diffusion models:
  - `SongUNet3D`: 3D U-Net diffusion backbone extending DDPM++ and NCSN++ to volumetric data
  - New test covering `SongUNet3D`
  - `EDMPrecond3D`: 3D preconditioning wrapper for volumetric diffusion models

## [1.3.0a0] - 2025-XX-YY
Comment on lines +9 to 18

Contributor
syntax: duplicate version header - there are two `## [1.3.0a0]` sections (lines 9 and 18), which should be merged into one

Suggested change (a single merged section):

## [1.3.0a0] - 2025-XX-YY

### Added

- Added `EDMPrecond3D`, a 3D `EDMPrecond` wrapper class, together with a 3D version of `SongUNet`, for diffusion models:
  - `SongUNet3D`: 3D U-Net diffusion backbone extending DDPM++ and NCSN++ to volumetric data
  - New test covering `SongUNet3D`
  - `EDMPrecond3D`: 3D preconditioning wrapper for volumetric diffusion models


### Added
2 changes: 2 additions & 0 deletions physicsnemo/models/diffusion/__init__.py
@@ -26,10 +26,12 @@
UNetBlock,
)
from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd
from .song_unet3d import SongUNet3D
from .dhariwal_unet import DhariwalUNet
from .unet import UNet, StormCastUNet
from .preconditioning import (
EDMPrecond,
EDMPrecond3D,
EDMPrecondSuperResolution,
EDMPrecondSR,
VEPrecond,
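The `__init__.py` change re-exports the new classes from the diffusion subpackage, so both should be importable from the package namespace. A minimal sketch of the resulting import path, assuming only the package layout shown in this diff:

```python
# Both new classes are re-exported by physicsnemo/models/diffusion/__init__.py,
# per the diff above.
from physicsnemo.models.diffusion import EDMPrecond3D, SongUNet3D
```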
105 changes: 105 additions & 0 deletions physicsnemo/models/diffusion/preconditioning.py
@@ -689,6 +689,111 @@ def round_sigma(sigma: Union[float, List, torch.Tensor]):
return torch.as_tensor(sigma)


@dataclass
class EDMPrecond3DMetaData(ModelMetaData):
"""EDMPrecond meta data"""

name: str = "EDMPrecond3D"
# Optimization
jit: bool = False
cuda_graphs: bool = False
amp_cpu: bool = False
amp_gpu: bool = True
torch_fx: bool = False
# Data type
bf16: bool = False
# Inference
onnx: bool = False
# Physics informed
func_torch: bool = False
auto_grad: bool = False


class EDMPrecond3D(EDMPrecond):
"""
Apply EDM preconditioning to denoise a 3D volumetric input.

Parameters
----------
x : torch.Tensor
Noisy volumetric input of shape (B, C, D, H, W) where B is batch size,
C is channels, and D, H, W are spatial dimensions.
sigma : torch.Tensor
Noise level(s) of shape (B,) or (B, 1).
condition : torch.Tensor, optional
Additional conditioning input to concatenate along channel dimension.
Must have shape (B, C_cond, D, H, W), by default None.
class_labels : torch.Tensor, optional
Class labels for conditional generation of shape (B, label_dim).
If None and label_dim > 0, zero labels are used, by default None.
force_fp32 : bool, optional
Force FP32 precision regardless of `use_fp16` setting, by default False.
**model_kwargs : dict
Additional keyword arguments passed to the underlying model's forward method.

Returns
-------
torch.Tensor
Denoised volumetric output of shape (B, C, D, H, W).

Raises
------
ValueError
If the model output dtype doesn't match the expected dtype (when not
using autocast).
"""

def forward(
self,
x,
sigma,
condition=None,
class_labels=None,
force_fp32=False,
**model_kwargs,
):
        x = x.to(torch.float32)
        # Reshape sigma so it broadcasts over (B, C, D, H, W) volumes.
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1, 1)

class_labels = (
None
if self.label_dim == 0
else torch.zeros([1, self.label_dim], device=x.device)
if class_labels is None
else class_labels.to(torch.float32).reshape(-1, self.label_dim)
)
dtype = (
torch.float16
if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
else torch.float32
)

        # EDM preconditioning coefficients (Karras et al., 2022):
        # D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise)
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.log() / 4

arg = c_in * x

        # Concatenate the conditioning volume along the channel dimension.
        if condition is not None:
            arg = torch.cat([arg, condition], dim=1)

F_x = self.model(
arg.to(dtype),
c_noise.flatten(),
class_labels=class_labels,
**model_kwargs,
)

if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
raise ValueError(
f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
)

D_x = c_skip * x + c_out * F_x.to(torch.float32)
return D_x


@dataclass
class EDMPrecondSuperResolutionMetaData(ModelMetaData):
"""EDMPrecondSuperResolution meta data"""
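`EDMPrecond3D.forward` implements the standard EDM preconditioning, `D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise)`, broadcast over 5D volumes. A minimal usage sketch follows; since `EDMPrecond3D` inherits its constructor from `EDMPrecond`, the argument names (`img_resolution`, `img_channels`, `model_type`) are assumed from that signature and are illustrative only:

```python
import torch

from physicsnemo.models.diffusion import EDMPrecond3D

# Assumed constructor arguments, inherited from EDMPrecond; the exact
# names and defaults are not shown in this diff.
precond = EDMPrecond3D(
    img_resolution=32,        # hypothetical 32^3 voxel volume
    img_channels=1,
    model_type="SongUNet3D",  # assumed to select the new 3D backbone
)

x = torch.randn(2, 1, 32, 32, 32)  # noisy input, shape (B, C, D, H, W)
sigma = torch.full((2,), 0.5)      # one noise level per batch element

with torch.no_grad():
    denoised = precond(x, sigma)   # denoised output, same shape as x

assert denoised.shape == x.shape
```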