-
Notifications
You must be signed in to change notification settings - Fork 1.3k
BUGFIX: support NDHWC input in sliding_window_inference and DiceMetric #8550
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -134,6 +134,12 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Raises: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ValueError: when `y_pred` has fewer than three dimensions. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
y = y.permute(0, 4, 1, 2, 3).contiguous() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+138
to
+142
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Dice layout normalization is unsafe and incomplete; add NHWC support, ambiguity checks, and label-map/one-hot consistency enforcement. Current rule permutes whenever last dim ∈ {1,3,4} for 5D only. This misfires for valid NCDHW with W ∈ {1,3,4}, ignores 4D NHWC, and doesn’t fail fast on ambiguous shapes or label-map vs one-hot mismatches when num_classes is None (per PR goals). Replace with robust normalization using num_classes as an authoritative hint, NHWC/NDHWC support, and explicit errors: - if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4):
- y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous()
- if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4):
- y = y.permute(0, 4, 1, 2, 3).contiguous()
+ # Normalize to channel-first; handle NHWC/NDHWC and fail fast on ambiguity.
+ def _norm(t: torch.Tensor, name: str) -> torch.Tensor:
+ if t.ndim not in (4, 5):
+ return t
+ c2, cl = t.shape[1], t.shape[-1]
+ # num_classes is authoritative when provided
+ if self.num_classes is not None:
+ if c2 in (self.num_classes, 1):
+ return t
+ if cl in (self.num_classes, 1):
+ return t.permute(0, t.ndim - 1, *range(1, t.ndim - 1)).contiguous()
+ raise ValueError(
+ f"{name}: cannot infer channel dimension with num_classes={self.num_classes}: "
+ f"dim1={c2}, dim-1={cl}."
+ )
+ # Heuristic: prefer the side where channels > 1 and the other side == 1
+ if c2 > 1 and cl == 1:
+ return t # NCHW[D]
+ if cl > 1 and c2 == 1:
+ return t.permute(0, t.ndim - 1, *range(1, t.ndim - 1)).contiguous() # NHWC/NDHWC
+ # Ambiguous (both >1 or equal small values) -> fail fast
+ if (c2 > 1 and cl > 1) or (c2 == cl and c2 in (1, 2, 3, 4)):
+ raise ValueError(
+ f"{name}: ambiguous channel dimension (dim1={c2}, dim-1={cl}). "
+ "Set num_classes explicitly or reorder the inputs."
+ )
+ return t
+
+ y_pred = _norm(y_pred, "y_pred")
+ y = _norm(y, "y")
+ # Inconsistent forms require num_classes
+ if self.num_classes is None and ((y_pred.shape[1] == 1) ^ (y.shape[1] == 1)):
+ raise ValueError(
+ "Inconsistent inputs: label-map vs one-hot but num_classes is None. "
+ "Provide num_classes to disambiguate."
+ ) Also clarify the DiceMetric docstring to state supported input layouts (NCHW[D], NHWC/NDHWC) and the explicit error conditions above. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dims = y_pred.ndimension() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if dims < 3: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
False-positive NDHWC detection: will permute valid NCDHW when W in {1,2,3,4}.
inputs.shape[-1] ∈ {1,3,4} is not sufficient; legitimate channel-first volumes often have a spatial size of 1/3/4 (single-slice, RGB-like W=3, etc.). This silently reorders NCDHW → NCW... and breaks inference. Also missing NHWC (4D) handling and ambiguous-shape fail-fast.
Apply a safer heuristic with ambiguity checks and 2D support:
Also update the docstring (Args/Note) to explicitly state that NHWC/NDHWC inputs are accepted but are normalized to channel-first before processing.
📝 Committable suggestion
🤖 Prompt for AI Agents