9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,15 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.3.0a0] - 2025-11-10

### Added

- Added `EDMPrecond3D` wrapper class with a 3D version of `SongUNet` for diffusion models:
  - `SongUNet3D`: 3D U-Net diffusion backbone extending DDPM++ and NCSN++ to volumetric data
  - New test covering `SongUNet3D`
  - `EDMPrecond3D`: 3D preconditioning wrapper for volumetric diffusion models

## [1.3.0a0] - 2025-XX-YY
Comment on lines +9 to 18 (Contributor):
syntax: duplicate version header - there are two ## [1.3.0a0] sections (lines 9 and 18), which should be merged into one

Suggested change (merge the two sections: drop the duplicate dated header and move the new entries under the existing one):

## [1.3.0a0] - 2025-XX-YY

### Added

- Added `EDMPrecond3D` wrapper class with a 3D version of `SongUNet` for diffusion models:
  - `SongUNet3D`: 3D U-Net diffusion backbone extending DDPM++ and NCSN++ to volumetric data
  - New test covering `SongUNet3D`
  - `EDMPrecond3D`: 3D preconditioning wrapper for volumetric diffusion models


### Added
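A minimal usage sketch of the new classes described above (the constructor arguments and the `model_type` routing are assumptions extrapolated from the existing 2D `EDMPrecond` API, which `EDMPrecond3D` subclasses; they are not confirmed by this diff):

```python
import torch
from physicsnemo.models.diffusion import EDMPrecond3D

# Hypothetical construction: assumes EDMPrecond3D reuses the EDMPrecond
# constructor and that model_type selects the new 3D backbone.
precond = EDMPrecond3D(
    img_resolution=32,
    img_channels=4,
    model_type="SongUNet3D",
)

x = torch.randn(2, 4, 32, 32, 32)  # (B, C, D, H, W) noisy volume
sigma = torch.rand(2)              # one noise level per batch element
denoised = precond(x, sigma)       # denoised volume, same shape as x
```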
2 changes: 2 additions & 0 deletions physicsnemo/models/diffusion/__init__.py
@@ -26,10 +26,12 @@
UNetBlock,
)
from .song_unet import SongUNet, SongUNetPosEmbd, SongUNetPosLtEmbd
from .song_unet3d import SongUNet3D
from .dhariwal_unet import DhariwalUNet
from .unet import UNet, StormCastUNet
from .preconditioning import (
EDMPrecond,
EDMPrecond3D,
EDMPrecondSuperResolution,
EDMPrecondSR,
VEPrecond,
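With these re-exports in place, the new classes are importable from the package root:

```python
from physicsnemo.models.diffusion import EDMPrecond3D, SongUNet3D
```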
105 changes: 105 additions & 0 deletions physicsnemo/models/diffusion/preconditioning.py
@@ -689,6 +689,111 @@ def round_sigma(sigma: Union[float, List, torch.Tensor]):
return torch.as_tensor(sigma)


@dataclass
class EDMPrecond3DMetaData(ModelMetaData):
"""EDMPrecond meta data"""

name: str = "EDMPrecond3D"
# Optimization
jit: bool = False
cuda_graphs: bool = False
amp_cpu: bool = False
amp_gpu: bool = True
torch_fx: bool = False
# Data type
bf16: bool = False
# Inference
onnx: bool = False
# Physics informed
func_torch: bool = False
auto_grad: bool = False


class EDMPrecond3D(EDMPrecond):
"""
Apply EDM preconditioning to denoise a 3D volumetric input.

Parameters
----------
x : torch.Tensor
Noisy volumetric input of shape (B, C, D, H, W) where B is batch size,
C is channels, and D, H, W are spatial dimensions.
sigma : torch.Tensor
Noise level(s) of shape (B,) or (B, 1).
condition : torch.Tensor, optional
Additional conditioning input to concatenate along channel dimension.
Must have shape (B, C_cond, D, H, W), by default None.
class_labels : torch.Tensor, optional
Class labels for conditional generation of shape (B, label_dim).
If None and label_dim > 0, zero labels are used, by default None.
force_fp32 : bool, optional
Force FP32 precision regardless of `use_fp16` setting, by default False.
**model_kwargs : dict
Additional keyword arguments passed to the underlying model's forward method.

Returns
-------
torch.Tensor
Denoised volumetric output of shape (B, C, D, H, W).

Raises
------
ValueError
If the model output dtype doesn't match the expected dtype (when not
using autocast).
"""

def forward(
self,
x,
sigma,
condition=None,
class_labels=None,
force_fp32=False,
**model_kwargs,
):
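        # Keep the outer computation in FP32; only the inner model call below
        # is cast to `dtype` (FP16 when enabled on CUDA).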
x = x.to(torch.float32)
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1, 1)

class_labels = (
None
if self.label_dim == 0
else torch.zeros([1, self.label_dim], device=x.device)
if class_labels is None
else class_labels.to(torch.float32).reshape(-1, self.label_dim)
)
dtype = (
torch.float16
if (self.use_fp16 and not force_fp32 and x.device.type == "cuda")
else torch.float32
)

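        # EDM preconditioning coefficients (Karras et al., 2022, Table 1)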
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
c_noise = sigma.log() / 4

arg = c_in * x

if condition is not None:
arg = torch.cat([arg, condition], dim=1)

F_x = self.model(
arg.to(dtype),
c_noise.flatten(),
class_labels=class_labels,
**model_kwargs,
)

if (F_x.dtype != dtype) and not torch.is_autocast_enabled():
raise ValueError(
f"Expected the dtype to be {dtype}, but got {F_x.dtype} instead."
)

D_x = c_skip * x + c_out * F_x.to(torch.float32)
return D_x
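As a standalone sanity check (not part of the PR), the scalar behavior of the preconditioning coefficients can be verified directly; `sigma_data = 0.5` is assumed here, matching the EDM paper's default:

```python
import torch

sigma_data = 0.5           # assumed default, as in Karras et al. (2022)
sigma = torch.tensor(0.5)  # choose sigma == sigma_data

c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
c_out = sigma * sigma_data / (sigma**2 + sigma_data**2).sqrt()
c_in = 1 / (sigma_data**2 + sigma**2).sqrt()
c_noise = sigma.log() / 4

# At sigma == sigma_data, c_skip == 0.5: the skip path and the network
# output contribute equally to the denoised estimate D_x.
print(c_skip.item())   # 0.5
print(c_out.item())    # ~0.3536 (0.25 / sqrt(0.5))
print(c_in.item())     # ~1.4142 (1 / sqrt(0.5))
print(c_noise.item())  # ~-0.1733 (ln(0.5) / 4)
```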


@dataclass
class EDMPrecondSuperResolutionMetaData(ModelMetaData):
"""EDMPrecondSuperResolution meta data"""