Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def sliding_window_inference(
*args: Any,
**kwargs: Any,
) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:

"""
Sliding window inference on `inputs` with `predictor`.

Expand Down Expand Up @@ -134,6 +135,14 @@ def sliding_window_inference(
- input must be channel-first and have a batch dim, supports N-D sliding window.

"""

# auto transform (N,D,H,W,C) → (N,C,D,H,W)
if isinstance(inputs, torch.Tensor) and inputs.ndim == 5 and inputs.shape[-1] in (1, 3, 4):
inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()
Comment on lines +139 to +141
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

False-positive NDHWC detection: will permute valid NCDHW when W in {1,3,4}.

inputs.shape[-1] ∈ {1,3,4} is not sufficient; legitimate channel-first volumes often have a last spatial size of 1/3/4 (single-slice, RGB-like W=3, etc.). In that case permute(0, 4, 1, 2, 3) silently reorders NCDHW → NWCDH and breaks inference. The change also lacks NHWC (4D) handling and a fail-fast path for ambiguous shapes.

Apply a safer heuristic with ambiguity checks and 2D support:

-    # auto transform (N,D,H,W,C) → (N,C,D,H,W)
-    if isinstance(inputs, torch.Tensor) and inputs.ndim == 5 and inputs.shape[-1] in (1, 3, 4):
-        inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()
+    # Heuristic channel-last -> channel-first normalization with ambiguity guard.
+    if isinstance(inputs, torch.Tensor):
+        if inputs.ndim == 5:  # NDHWC or NCDHW
+            c2, cl = inputs.shape[1], inputs.shape[-1]
+            if cl in (1, 2, 3, 4) and c2 not in (1, 2, 3, 4):
+                inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()
+            elif cl in (1, 2, 3, 4) and c2 in (1, 2, 3, 4) and cl != c2:
+                raise ValueError(
+                    f"Ambiguous channel dimension: dim=1 ({c2}) vs dim=-1 ({cl}). "
+                    "Please reorder explicitly to channel-first."
+                )
+        elif inputs.ndim == 4:  # NHWC or NCHW
+            c2, cl = inputs.shape[1], inputs.shape[-1]
+            if cl in (1, 2, 3, 4) and c2 not in (1, 2, 3, 4):
+                inputs = inputs.permute(0, 3, 1, 2).contiguous()
+            elif cl in (1, 2, 3, 4) and c2 in (1, 2, 3, 4) and cl != c2:
+                raise ValueError(
+                    f"Ambiguous channel dimension: dim=1 ({c2}) vs dim=-1 ({cl}). "
+                    "Please reorder explicitly to channel-first."
+                )

Also update the docstring (Args/Note) to explicitly state that NHWC/NDHWC inputs are accepted but are normalized to channel-first before processing.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# auto transform (N,D,H,W,C) → (N,C,D,H,W)
if isinstance(inputs, torch.Tensor) and inputs.ndim == 5 and inputs.shape[-1] in (1, 3, 4):
inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()
# Heuristic channel-last -> channel-first normalization with ambiguity guard.
if isinstance(inputs, torch.Tensor):
if inputs.ndim == 5: # NDHWC or NCDHW
c2, cl = inputs.shape[1], inputs.shape[-1]
if cl in (1, 2, 3, 4) and c2 not in (1, 2, 3, 4):
inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()
elif cl in (1, 2, 3, 4) and c2 in (1, 2, 3, 4) and cl != c2:
raise ValueError(
f"Ambiguous channel dimension: dim=1 ({c2}) vs dim=-1 ({cl}). "
"Please reorder explicitly to channel-first."
)
elif inputs.ndim == 4: # NHWC or NCHW
c2, cl = inputs.shape[1], inputs.shape[-1]
if cl in (1, 2, 3, 4) and c2 not in (1, 2, 3, 4):
inputs = inputs.permute(0, 3, 1, 2).contiguous()
elif cl in (1, 2, 3, 4) and c2 in (1, 2, 3, 4) and cl != c2:
raise ValueError(
f"Ambiguous channel dimension: dim=1 ({c2}) vs dim=-1 ({cl}). "
"Please reorder explicitly to channel-first."
)
🤖 Prompt for AI Agents
In monai/inferers/utils.py around lines 139-141, the current NDHWC detection
uses inputs.shape[-1] in (1,3,4) which misidentifies valid NCDHW volumes when a
spatial dimension equals 1/3/4 and silently permutes them; replace this with a
safer heuristic: for 5D tensors, first check if the channel dimension is already
channel-first by testing inputs.shape[1] ∈ {1,2,3,4} and only permute when the
last dim looks like channels and the second dim does not; for 4D tensors apply
the analogous NHWC→NCHW logic; if both candidate dims look like channels
(ambiguous) raise a clear ValueError asking the caller to reorder the input to
channel-first explicitly; update the function docstring Args/Note to state
accepted input formats (NCDHW, NDHWC, NCHW, NHWC), that inputs will be
normalized to channel-first, and mention the ambiguity error and how to resolve
it.





buffered = buffer_steps is not None and buffer_steps > 0
num_spatial_dims = len(inputs.shape) - 2
if buffered:
Expand Down
6 changes: 6 additions & 0 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
Raises:
ValueError: when `y_pred` has fewer than three dimensions.
"""

if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4):
y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous()
if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4):
y = y.permute(0, 4, 1, 2, 3).contiguous()

Comment on lines +138 to +142
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Dice layout normalization is unsafe and incomplete; add NHWC support, ambiguity checks, and label-map/one-hot consistency enforcement.

Current rule permutes whenever last dim ∈ {1,3,4} for 5D only. This misfires for valid NCDHW with W ∈ {1,3,4}, ignores 4D NHWC, and doesn’t fail fast on ambiguous shapes or label-map vs one-hot mismatches when num_classes is None (per PR goals).

Replace with robust normalization using num_classes as an authoritative hint, NHWC/NDHWC support, and explicit errors:

-        if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4):
-            y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous()
-        if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4):
-            y = y.permute(0, 4, 1, 2, 3).contiguous()
+        # Normalize to channel-first; handle NHWC/NDHWC and fail fast on ambiguity.
+        def _norm(t: torch.Tensor, name: str) -> torch.Tensor:
+            if t.ndim not in (4, 5):
+                return t
+            c2, cl = t.shape[1], t.shape[-1]
+            # num_classes is authoritative when provided
+            if self.num_classes is not None:
+                if c2 in (self.num_classes, 1):
+                    return t
+                if cl in (self.num_classes, 1):
+                    return t.permute(0, t.ndim - 1, *range(1, t.ndim - 1)).contiguous()
+                raise ValueError(
+                    f"{name}: cannot infer channel dimension with num_classes={self.num_classes}: "
+                    f"dim1={c2}, dim-1={cl}."
+                )
+            # Heuristic: prefer the side where channels > 1 and the other side == 1
+            if c2 > 1 and cl == 1:
+                return t  # NCHW[D]
+            if cl > 1 and c2 == 1:
+                return t.permute(0, t.ndim - 1, *range(1, t.ndim - 1)).contiguous()  # NHWC/NDHWC
+            # Ambiguous (both >1 or equal small values) -> fail fast
+            if (c2 > 1 and cl > 1) or (c2 == cl and c2 in (1, 2, 3, 4)):
+                raise ValueError(
+                    f"{name}: ambiguous channel dimension (dim1={c2}, dim-1={cl}). "
+                    "Set num_classes explicitly or reorder the inputs."
+                )
+            return t
+
+        y_pred = _norm(y_pred, "y_pred")
+        y = _norm(y, "y")
+        # Inconsistent forms require num_classes
+        if self.num_classes is None and ((y_pred.shape[1] == 1) ^ (y.shape[1] == 1)):
+            raise ValueError(
+                "Inconsistent inputs: label-map vs one-hot but num_classes is None. "
+                "Provide num_classes to disambiguate."
+            )

Also clarify the DiceMetric docstring to state supported input layouts (NCHW[D], NHWC/NDHWC) and the explicit error conditions above.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4):
y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous()
if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4):
y = y.permute(0, 4, 1, 2, 3).contiguous()
# Normalize to channel-first; handle NHWC/NDHWC and fail fast on ambiguity.
def _norm(t: torch.Tensor, name: str) -> torch.Tensor:
if t.ndim not in (4, 5):
return t
c2, cl = t.shape[1], t.shape[-1]
# num_classes is authoritative when provided
if self.num_classes is not None:
if c2 in (self.num_classes, 1):
return t
if cl in (self.num_classes, 1):
return t.permute(0, t.ndim - 1, *range(1, t.ndim - 1)).contiguous()
raise ValueError(
f"{name}: cannot infer channel dimension with num_classes={self.num_classes}: "
f"dim1={c2}, dim-1={cl}."
)
# Heuristic: prefer the side where channels > 1 and the other side == 1
if c2 > 1 and cl == 1:
return t # NCHW[D]
if cl > 1 and c2 == 1:
return t.permute(0, t.ndim - 1, *range(1, t.ndim - 1)).contiguous() # NHWC/NDHWC
# Ambiguous (both >1 or equal small values) -> fail fast
if (c2 > 1 and cl > 1) or (c2 == cl and c2 in (1, 2, 3, 4)):
raise ValueError(
f"{name}: ambiguous channel dimension (dim1={c2}, dim-1={cl}). "
"Set num_classes explicitly or reorder the inputs."
)
return t
y_pred = _norm(y_pred, "y_pred")
y = _norm(y, "y")
# Inconsistent forms require num_classes
if self.num_classes is None and ((y_pred.shape[1] == 1) ^ (y.shape[1] == 1)):
raise ValueError(
"Inconsistent inputs: label-map vs one-hot but num_classes is None. "
"Provide num_classes to disambiguate."
)
🤖 Prompt for AI Agents
In monai/metrics/meandice.py around lines 138 to 142, the current NDHWC-to-NCDHW
permutation logic is unsafe: it blindly permutes 5D tensors when the last dim is
1/3/4, ignores 4D NHWC, and doesn't use num_classes to disambiguate label-map vs
one-hot inputs. Replace this block with a robust normalization routine that:
detects and supports NHWC (4D) and NDHWC/NCDHW (5D) layouts, uses num_classes as
an authoritative hint to decide whether the channel-last dimension is classes
(one-hot) or spatial (width), enforces consistency between num_classes and
channel size (raise ValueError on ambiguity), converts any NHWC/NDHWC input to
NCHW/NCDHW via explicit permute only after checks, and raises clear errors for
ambiguous shapes or mismatched label-map vs one-hot inputs; also update the
DiceMetric docstring to list supported layouts (NCHW[D], NHWC/NDHWC) and
enumerate the explicit error conditions.

dims = y_pred.ndimension()
if dims < 3:
raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")
Expand Down
Loading