64
64
GridSamplePadMode ,
65
65
InterpolateMode ,
66
66
NumpyPadMode ,
67
+ SpaceKeys ,
67
68
convert_to_cupy ,
68
69
convert_to_dst_type ,
69
70
convert_to_numpy ,
@@ -560,7 +561,7 @@ def __init__(
560
561
self ,
561
562
axcodes : str | None = None ,
562
563
as_closest_canonical : bool = False ,
563
- labels : Sequence [tuple [str , str ]] | None = (( "L" , "R" ), ( "P" , "A" ), ( "I" , "S" )) ,
564
+ labels : Sequence [tuple [str , str ]] | None = None ,
564
565
lazy : bool = False ,
565
566
) -> None :
566
567
"""
@@ -573,7 +574,9 @@ def __init__(
573
574
as_closest_canonical: if True, load the image as closest to canonical axis format.
574
575
labels: optional, None or sequence of (2,) sequences
575
576
(2,) sequences are labels for (beginning, end) of output axis.
576
- Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
577
+ Defaults to using the ``"space"`` attribute of a metatensor,
578
+ where appliable, or (('L', 'R'), ('P', 'A'), ('I', 'S'))``
579
+ otherwise (i.e. for plain tensors).
577
580
lazy: a flag to indicate whether this transform should execute lazily or not.
578
581
Defaults to False
579
582
@@ -619,9 +622,15 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
619
622
raise ValueError (f"data_array must have at least one spatial dimension, got { spatial_shape } ." )
620
623
affine_ : np .ndarray
621
624
affine_np : np .ndarray
625
+ labels = self .labels
622
626
if isinstance (data_array , MetaTensor ):
623
627
affine_np , * _ = convert_data_type (data_array .peek_pending_affine (), np .ndarray )
624
628
affine_ = to_affine_nd (sr , affine_np )
629
+
630
+ # Set up "labels" such that LPS tensors are handled correctly by default
631
+ if self .labels is None and SpaceKeys (data_array .meta ["space" ]) == SpaceKeys .LPS :
632
+ labels = (("R" , "L" ), ("A" , "P" ), ("I" , "S" )) # value for LPS
633
+
625
634
else :
626
635
warnings .warn ("`data_array` is not of type `MetaTensor, assuming affine to be identity." )
627
636
# default to identity
@@ -640,7 +649,7 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
640
649
f"{ self .__class__ .__name__ } : spatial shape = { spatial_shape } , channels = { data_array .shape [0 ]} ,"
641
650
"please make sure the input is in the channel-first format."
642
651
)
643
- dst = nib .orientations .axcodes2ornt (self .axcodes [:sr ], labels = self . labels )
652
+ dst = nib .orientations .axcodes2ornt (self .axcodes [:sr ], labels = labels )
644
653
if len (dst ) < sr :
645
654
raise ValueError (
646
655
f"axcodes must match data_array spatially, got axcodes={ len (self .axcodes )} D data_array={ sr } D"
@@ -653,8 +662,18 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
653
662
transform = self .pop_transform (data )
654
663
# Create inverse transform
655
664
orig_affine = transform [TraceKeys .EXTRA_INFO ]["original_affine" ]
656
- orig_axcodes = nib .orientations .aff2axcodes (orig_affine )
657
- inverse_transform = Orientation (axcodes = orig_axcodes , as_closest_canonical = False , labels = self .labels )
665
+ labels = self .labels
666
+
667
+ # Set up "labels" such that LPS tensors are handled correctly by default
668
+ if (
669
+ isinstance (data , MetaTensor ) and
670
+ self .labels is None and
671
+ SpaceKeys (data .meta ["space" ]) == SpaceKeys .LPS
672
+ ):
673
+ labels = (("R" , "L" ), ("A" , "P" ), ("I" , "S" )) # value for LPS
674
+
675
+ orig_axcodes = nib .orientations .aff2axcodes (orig_affine , labels = labels )
676
+ inverse_transform = Orientation (axcodes = orig_axcodes , as_closest_canonical = False , labels = labels )
658
677
# Apply inverse
659
678
with inverse_transform .trace_transform (False ):
660
679
data = inverse_transform (data )
0 commit comments