|
38 | 38 | remove_small_objects, |
39 | 39 | ) |
40 | 40 | from monai.transforms.utils_pytorch_numpy_unification import unravel_index |
41 | | -from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option |
| 41 | +from monai.utils import ( |
| 42 | + TransformBackends, |
| 43 | + convert_data_type, |
| 44 | + convert_to_tensor, |
| 45 | + ensure_tuple, |
| 46 | + get_equivalent_dtype, |
| 47 | + look_up_option, |
| 48 | +) |
42 | 49 | from monai.utils.type_conversion import convert_to_dst_type |
43 | 50 |
|
44 | 51 | __all__ = [ |
|
54 | 61 | "SobelGradients", |
55 | 62 | "VoteEnsemble", |
56 | 63 | "Invert", |
| 64 | + "GenerateHeatmap", |
57 | 65 | "DistanceTransformEDT", |
58 | 66 | ] |
59 | 67 |
|
@@ -742,6 +750,154 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO |
742 | 750 | return self.post_convert(out_pt, img) |
743 | 751 |
|
744 | 752 |
|
| 753 | +class GenerateHeatmap(Transform): |
| 754 | + """ |
| 755 | + Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates. |
| 756 | +
|
| 757 | + Notes: |
| 758 | + - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D. |
| 759 | + - Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D. |
| 760 | + - Output layout uses channel-first convention with one channel per landmark. |
| 761 | + - Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions (2 or 3). |
| 762 | + - Output heatmap shape: (N, Y, X) for 2D or (N, Z, Y, X) for 3D. |
| 763 | + - Each channel index corresponds to one landmark. |
| 764 | +
|
| 765 | + Args: |
| 766 | + sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions. |
| 767 | + spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform. |
| 768 | + truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window. |
| 769 | + normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``. |
| 770 | + dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes). |
| 771 | +
|
| 772 | + Raises: |
| 773 | + ValueError: when ``sigma`` is non-positive or ``spatial_shape`` cannot be resolved. |
| 774 | +
|
| 775 | + """ |
| 776 | + |
| 777 | + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] |
| 778 | + |
| 779 | + def __init__( |
| 780 | + self, |
| 781 | + sigma: Sequence[float] | float = 5.0, |
| 782 | + spatial_shape: Sequence[int] | None = None, |
| 783 | + truncated: float = 4.0, |
| 784 | + normalize: bool = True, |
| 785 | + dtype: np.dtype | torch.dtype | type = np.float32, |
| 786 | + ) -> None: |
| 787 | + if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)): |
| 788 | + if any(s <= 0 for s in sigma): |
| 789 | + raise ValueError("Argument `sigma` values must be positive.") |
| 790 | + self._sigma = tuple(float(s) for s in sigma) |
| 791 | + else: |
| 792 | + if float(sigma) <= 0: |
| 793 | + raise ValueError("Argument `sigma` must be positive.") |
| 794 | + self._sigma = (float(sigma),) |
| 795 | + if truncated <= 0: |
| 796 | + raise ValueError("Argument `truncated` must be positive.") |
| 797 | + self.truncated = float(truncated) |
| 798 | + self.normalize = normalize |
| 799 | + self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor) |
| 800 | + self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray) |
| 801 | + # Validate that dtype is floating-point for meaningful Gaussian values |
| 802 | + if not self.torch_dtype.is_floating_point: |
| 803 | + raise ValueError(f"Argument `dtype` must be a floating-point type, got {self.torch_dtype}") |
| 804 | + self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape) |
| 805 | + |
| 806 | + def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor: |
| 807 | + """ |
| 808 | + Args: |
| 809 | + points: landmark coordinates as ndarray/Tensor with shape (N, D), |
| 810 | + ordered as (Y, X) for 2D or (Z, Y, X) for 3D, where N is the number |
| 811 | + of landmarks and D is the spatial dimensionality. |
| 812 | + spatial_shape: spatial size as a sequence. If None, uses the value provided at construction. |
| 813 | +
|
| 814 | + Returns: |
| 815 | + Heatmaps with shape (N, *spatial), one channel per landmark. |
| 816 | +
|
| 817 | + Raises: |
| 818 | + ValueError: if points shape/dimension or spatial_shape is invalid. |
| 819 | + """ |
| 820 | + original_points = points |
| 821 | + points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False) |
| 822 | + |
| 823 | + if points_t.ndim != 2: |
| 824 | + raise ValueError( |
| 825 | + f"Argument `points` must be a 2D array with shape (num_points, spatial_dims), got shape {points_t.shape}." |
| 826 | + ) |
| 827 | + |
| 828 | + if points_t.shape[-1] not in (2, 3): |
| 829 | + raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.") |
| 830 | + |
| 831 | + device = points_t.device |
| 832 | + num_points, spatial_dims = points_t.shape |
| 833 | + |
| 834 | + target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims) |
| 835 | + sigma = self._resolve_sigma(spatial_dims) |
| 836 | + |
| 837 | + # Create sparse image with impulses at landmark locations |
| 838 | + heatmap = torch.zeros((num_points, *target_shape), dtype=self.torch_dtype, device=device) |
| 839 | + bounds_t = torch.as_tensor(target_shape, device=device, dtype=points_t.dtype) |
| 840 | + |
| 841 | + for idx, center in enumerate(points_t): |
| 842 | + if not torch.isfinite(center).all(): |
| 843 | + continue |
| 844 | + if not ((center >= 0).all() and (center < bounds_t).all()): |
| 845 | + continue |
| 846 | + # Round to nearest integer for impulse placement, then clamp to valid index range |
| 847 | + center_int = center.round().long() |
| 848 | + # Clamp indices to [0, size-1] to avoid out-of-bounds (e.g., 9.7 rounds to 10 in size-10 array) |
| 849 | + bounds_max = (bounds_t - 1).long() |
| 850 | + center_int = torch.minimum(torch.maximum(center_int, torch.zeros_like(center_int)), bounds_max) |
| 851 | + # Place impulse (use maximum in case of overlapping landmarks) |
| 852 | + current_val = heatmap[idx][tuple(center_int)] |
| 853 | + heatmap[idx][tuple(center_int)] = torch.maximum( |
| 854 | + current_val, torch.tensor(1.0, dtype=self.torch_dtype, device=device) |
| 855 | + ) |
| 856 | + |
| 857 | + # Apply Gaussian blur using GaussianFilter |
| 858 | + # Reshape to (num_points, 1, *spatial) for per-channel filtering |
| 859 | + heatmap_input = heatmap.unsqueeze(1) # Add channel dimension |
| 860 | + |
| 861 | + gaussian_filter = GaussianFilter( |
| 862 | + spatial_dims=spatial_dims, sigma=sigma, truncated=self.truncated, approx="erf", requires_grad=False |
| 863 | + ).to(device=device, dtype=self.torch_dtype) |
| 864 | + |
| 865 | + heatmap_blurred = gaussian_filter(heatmap_input) |
| 866 | + heatmap = heatmap_blurred.squeeze(1) # Remove channel dimension |
| 867 | + |
| 868 | + # Normalize per channel if requested |
| 869 | + if self.normalize: |
| 870 | + for idx in range(num_points): |
| 871 | + peak = heatmap[idx].amax() |
| 872 | + if peak > 0: |
| 873 | + heatmap[idx].div_(peak) |
| 874 | + |
| 875 | + target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype |
| 876 | + converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype) |
| 877 | + return converted |
| 878 | + |
| 879 | + def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]: |
| 880 | + shape = call_shape if call_shape is not None else self.spatial_shape |
| 881 | + if shape is None: |
| 882 | + raise ValueError("Argument `spatial_shape` must be provided either at construction time or call time.") |
| 883 | + shape_tuple = ensure_tuple(shape) |
| 884 | + if len(shape_tuple) != spatial_dims: |
| 885 | + if len(shape_tuple) == 1: |
| 886 | + shape_tuple = shape_tuple * spatial_dims # type: ignore |
| 887 | + else: |
| 888 | + raise ValueError( |
| 889 | + "Argument `spatial_shape` length must match the landmarks' spatial dims (or pass a single int to broadcast)." |
| 890 | + ) |
| 891 | + return tuple(int(s) for s in shape_tuple) |
| 892 | + |
| 893 | + def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]: |
| 894 | + if len(self._sigma) == spatial_dims: |
| 895 | + return self._sigma |
| 896 | + if len(self._sigma) == 1: |
| 897 | + return self._sigma * spatial_dims |
| 898 | + raise ValueError("Argument `sigma` sequence length must equal the number of spatial dimensions.") |
| 899 | + |
| 900 | + |
745 | 901 | class ProbNMS(Transform): |
746 | 902 | """ |
747 | 903 | Performs probability based non-maximum suppression (NMS) on the probabilities map via |
|
0 commit comments