Skip to content

Commit 04a89ee

Browse files
eclipse0922, pre-commit-ci[bot], coderabbitai[bot], ericspod, KumoLiu
authored
Generate heatmap transforms (#8579)
Fixes #3328.

### Description

A few sentences describing the changes proposed in this pull request.

This pull request introduces `GenerateHeatmap` and `GenerateHeatmapd` transforms for creating Gaussian heatmaps from landmark coordinates. The input points are currently expected in ZYX order, but this can be changed to support XYZ if preferred. The transforms support only non-batched (N, D) inputs; batched (B, N, D) inputs, which an earlier revision of this PR accepted, are not supported. Example notebooks are included for demonstration and will be removed before the PR is merged.

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: sewon.jeon <[email protected]>
Signed-off-by: sewon jeon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: YunLiu <[email protected]>
1 parent f493ecd commit 04a89ee

File tree

5 files changed

+854
-1
lines changed

5 files changed

+854
-1
lines changed

monai/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@
293293
AsDiscrete,
294294
DistanceTransformEDT,
295295
FillHoles,
296+
GenerateHeatmap,
296297
Invert,
297298
KeepLargestConnectedComponent,
298299
LabelFilter,
@@ -319,6 +320,9 @@
319320
FillHolesD,
320321
FillHolesd,
321322
FillHolesDict,
323+
GenerateHeatmapd,
324+
GenerateHeatmapD,
325+
GenerateHeatmapDict,
322326
InvertD,
323327
Invertd,
324328
InvertDict,

monai/transforms/post/array.py

Lines changed: 157 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,14 @@
3838
remove_small_objects,
3939
)
4040
from monai.transforms.utils_pytorch_numpy_unification import unravel_index
41-
from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option
41+
from monai.utils import (
42+
TransformBackends,
43+
convert_data_type,
44+
convert_to_tensor,
45+
ensure_tuple,
46+
get_equivalent_dtype,
47+
look_up_option,
48+
)
4249
from monai.utils.type_conversion import convert_to_dst_type
4350

4451
__all__ = [
@@ -54,6 +61,7 @@
5461
"SobelGradients",
5562
"VoteEnsemble",
5663
"Invert",
64+
"GenerateHeatmap",
5765
"DistanceTransformEDT",
5866
]
5967

@@ -742,6 +750,154 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO
742750
return self.post_convert(out_pt, img)
743751

744752

753+
class GenerateHeatmap(Transform):
    """
    Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates.

    Each landmark is rendered into its own channel: a unit impulse is placed at the
    (rounded) landmark location and blurred with a Gaussian filter.

    Notes:
        - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
        - Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D.
        - Output layout uses channel-first convention with one channel per landmark.
        - Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions (2 or 3).
        - Output heatmap shape: (N, Y, X) for 2D or (N, Z, Y, X) for 3D.
        - Each channel index corresponds to one landmark.
        - Landmarks that are non-finite or lie outside ``[0, size)`` along any axis
          produce an all-zero channel.

    Args:
        sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
        spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.
        truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window.
        normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.
        dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).

    Raises:
        ValueError: when ``sigma`` or ``truncated`` is non-positive or non-finite, when ``dtype``
            is not a floating-point type, or when ``spatial_shape`` cannot be resolved.

    """

    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

    def __init__(
        self,
        sigma: Sequence[float] | float = 5.0,
        spatial_shape: Sequence[int] | None = None,
        truncated: float = 4.0,
        normalize: bool = True,
        dtype: np.dtype | torch.dtype | type = np.float32,
    ) -> None:
        if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)):
            sigma_values = tuple(float(s) for s in sigma)
            if any(s <= 0 for s in sigma_values):
                raise ValueError("Argument `sigma` values must be positive.")
            # NaN compares False against 0, so it needs an explicit finiteness check.
            if not all(np.isfinite(s) for s in sigma_values):
                raise ValueError("Argument `sigma` values must be finite.")
            self._sigma = sigma_values
        else:
            sigma_value = float(sigma)
            if sigma_value <= 0:
                raise ValueError("Argument `sigma` must be positive.")
            if not np.isfinite(sigma_value):
                raise ValueError("Argument `sigma` must be finite.")
            self._sigma = (sigma_value,)

        truncated_value = float(truncated)
        if truncated_value <= 0:
            raise ValueError("Argument `truncated` must be positive.")
        if not np.isfinite(truncated_value):
            raise ValueError("Argument `truncated` must be finite.")
        self.truncated = truncated_value

        self.normalize = normalize
        self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
        self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
        # Validate that dtype is floating-point for meaningful Gaussian values
        if not self.torch_dtype.is_floating_point:
            raise ValueError(f"Argument `dtype` must be a floating-point type, got {self.torch_dtype}")
        self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape)

    def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor:
        """
        Args:
            points: landmark coordinates as ndarray/Tensor with shape (N, D),
                ordered as (Y, X) for 2D or (Z, Y, X) for 3D, where N is the number
                of landmarks and D is the spatial dimensionality.
            spatial_shape: spatial size as a sequence. If None, uses the value provided at construction.

        Returns:
            Heatmaps with shape (N, *spatial), one channel per landmark. The container type
            follows ``points`` (tensor in, tensor out; array in, array out).

        Raises:
            ValueError: if points shape/dimension or spatial_shape is invalid.
        """
        points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)

        if points_t.ndim != 2:
            raise ValueError(
                f"Argument `points` must be a 2D array with shape (num_points, spatial_dims), got shape {points_t.shape}."
            )

        if points_t.shape[-1] not in (2, 3):
            raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.")

        device = points_t.device
        num_points, spatial_dims = points_t.shape

        target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)
        sigma = self._resolve_sigma(spatial_dims)

        # Start from an all-zero volume; each valid landmark gets a unit impulse in its own channel.
        heatmap = torch.zeros((num_points, *target_shape), dtype=self.torch_dtype, device=device)
        bounds_t = torch.as_tensor(target_shape, device=device, dtype=points_t.dtype)

        # A landmark is usable only if every coordinate lies in [0, size). NaN fails both
        # comparisons and +/-inf fails one of them, so non-finite points are rejected here too.
        valid = ((points_t >= 0) & (points_t < bounds_t)).all(dim=1)

        # Blurring an all-zero volume yields zeros, so the filter is only needed when at
        # least one landmark is usable (this also covers the N == 0 case).
        if bool(valid.any()):
            channels = torch.nonzero(valid, as_tuple=True)[0]
            # Round to the nearest voxel, then clamp to [0, size-1] because e.g. 9.7 rounds
            # to 10, which would be out of range in a size-10 axis.
            upper = torch.as_tensor(target_shape, device=device, dtype=torch.long) - 1
            centers = torch.minimum(points_t[valid].round().long().clamp_min(0), upper)
            heatmap[(channels, *centers.unbind(dim=1))] = 1.0

            # GaussianFilter expects (batch, channel, *spatial); landmarks are treated as the
            # batch with a singleton channel so that every landmark is filtered independently.
            gaussian_filter = GaussianFilter(
                spatial_dims=spatial_dims, sigma=sigma, truncated=self.truncated, approx="erf", requires_grad=False
            ).to(device=device, dtype=self.torch_dtype)
            heatmap = gaussian_filter(heatmap.unsqueeze(1)).squeeze(1)

            if self.normalize:
                # Scale each channel by its own peak; channels without a positive peak
                # (skipped landmarks) are divided by 1 and therefore stay all-zero.
                peaks = heatmap.amax(dim=tuple(range(1, heatmap.ndim)), keepdim=True)
                heatmap = heatmap / torch.where(peaks > 0, peaks, torch.ones_like(peaks))

        target_dtype = self.torch_dtype if isinstance(points, (torch.Tensor, MetaTensor)) else self.numpy_dtype
        converted, _, _ = convert_to_dst_type(heatmap, points, dtype=target_dtype)
        return converted

    def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]:
        """
        Pick the spatial shape to use (call-time value wins over the constructor value),
        broadcast a single value to ``spatial_dims`` entries and validate the result.

        Raises:
            ValueError: if no shape is available, its length does not match ``spatial_dims``
                (and is not 1), or any dimension is not positive.
        """
        shape = call_shape if call_shape is not None else self.spatial_shape
        if shape is None:
            raise ValueError("Argument `spatial_shape` must be provided either at construction time or call time.")
        shape_tuple = ensure_tuple(shape)
        if len(shape_tuple) != spatial_dims:
            if len(shape_tuple) != 1:
                raise ValueError(
                    "Argument `spatial_shape` length must match the landmarks' spatial dims (or pass a single int to broadcast)."
                )
            shape_tuple = shape_tuple * spatial_dims  # type: ignore
        resolved = tuple(int(s) for s in shape_tuple)
        # Fail early with the documented exception type instead of an opaque allocation error.
        if any(s <= 0 for s in resolved):
            raise ValueError(f"Argument `spatial_shape` values must be positive, got {resolved}.")
        return resolved

    def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
        """
        Return one sigma per spatial dimension, broadcasting a single stored value if needed.

        Raises:
            ValueError: if the stored sigma sequence has neither 1 nor ``spatial_dims`` entries.
        """
        if len(self._sigma) == spatial_dims:
            return self._sigma
        if len(self._sigma) == 1:
            return self._sigma * spatial_dims
        raise ValueError("Argument `sigma` sequence length must equal the number of spatial dimensions.")
899+
900+
745901
class ProbNMS(Transform):
746902
"""
747903
Performs probability based non-maximum suppression (NMS) on the probabilities map via

0 commit comments

Comments
 (0)