Skip to content

Commit 0dbcb85

Browse files
committed
Orientation transform checks for space of a metatensor
1 parent c3a317d commit 0dbcb85

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

monai/transforms/spatial/array.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
GridSamplePadMode,
6565
InterpolateMode,
6666
NumpyPadMode,
67+
SpaceKeys,
6768
convert_to_cupy,
6869
convert_to_dst_type,
6970
convert_to_numpy,
@@ -560,7 +561,7 @@ def __init__(
560561
self,
561562
axcodes: str | None = None,
562563
as_closest_canonical: bool = False,
563-
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
564+
labels: Sequence[tuple[str, str]] | None = None,
564565
lazy: bool = False,
565566
) -> None:
566567
"""
@@ -573,7 +574,9 @@ def __init__(
573574
as_closest_canonical: if True, load the image as closest to canonical axis format.
574575
labels: optional, None or sequence of (2,) sequences
575576
(2,) sequences are labels for (beginning, end) of output axis.
576-
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
577+
Defaults to using the ``"space"`` attribute of a metatensor,
578+
where appliable, or (('L', 'R'), ('P', 'A'), ('I', 'S'))``
579+
otherwise (i.e. for plain tensors).
577580
lazy: a flag to indicate whether this transform should execute lazily or not.
578581
Defaults to False
579582
@@ -619,9 +622,15 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
619622
raise ValueError(f"data_array must have at least one spatial dimension, got {spatial_shape}.")
620623
affine_: np.ndarray
621624
affine_np: np.ndarray
625+
labels = self.labels
622626
if isinstance(data_array, MetaTensor):
623627
affine_np, *_ = convert_data_type(data_array.peek_pending_affine(), np.ndarray)
624628
affine_ = to_affine_nd(sr, affine_np)
629+
630+
# Set up "labels" such that LPS tensors are handled correctly by default
631+
if self.labels is None and SpaceKeys(data_array.meta["space"]) == SpaceKeys.LPS:
632+
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS
633+
625634
else:
626635
warnings.warn("`data_array` is not of type `MetaTensor, assuming affine to be identity.")
627636
# default to identity
@@ -640,7 +649,7 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
640649
f"{self.__class__.__name__}: spatial shape = {spatial_shape}, channels = {data_array.shape[0]},"
641650
"please make sure the input is in the channel-first format."
642651
)
643-
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels)
652+
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=labels)
644653
if len(dst) < sr:
645654
raise ValueError(
646655
f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D"
@@ -653,8 +662,18 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
653662
transform = self.pop_transform(data)
654663
# Create inverse transform
655664
orig_affine = transform[TraceKeys.EXTRA_INFO]["original_affine"]
656-
orig_axcodes = nib.orientations.aff2axcodes(orig_affine)
657-
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=self.labels)
665+
labels = self.labels
666+
667+
# Set up "labels" such that LPS tensors are handled correctly by default
668+
if (
669+
isinstance(data, MetaTensor) and
670+
self.labels is None and
671+
SpaceKeys(data.meta["space"]) == SpaceKeys.LPS
672+
):
673+
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS
674+
675+
orig_axcodes = nib.orientations.aff2axcodes(orig_affine, labels=labels)
676+
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=labels)
658677
# Apply inverse
659678
with inverse_transform.trace_transform(False):
660679
data = inverse_transform(data)

monai/transforms/spatial/dictionary.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def __init__(
550550
keys: KeysCollection,
551551
axcodes: str | None = None,
552552
as_closest_canonical: bool = False,
553-
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
553+
labels: Sequence[tuple[str, str]] | None = None,
554554
allow_missing_keys: bool = False,
555555
lazy: bool = False,
556556
) -> None:
@@ -564,7 +564,9 @@ def __init__(
564564
as_closest_canonical: if True, load the image as closest to canonical axis format.
565565
labels: optional, None or sequence of (2,) sequences
566566
(2,) sequences are labels for (beginning, end) of output axis.
567-
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
567+
Defaults to using the ``"space"`` attribute of a metatensor,
568+
where appliable, or (('L', 'R'), ('P', 'A'), ('I', 'S'))``
569+
otherwise (i.e. for plain tensors).
568570
allow_missing_keys: don't raise exception if key is missing.
569571
lazy: a flag to indicate whether this transform should execute lazily or not.
570572
Defaults to False

0 commit comments

Comments
 (0)