Skip to content

Commit d97bdd5

Browse files
committed
Enables automatic transform group tracking for inversion
Addresses an issue where `Invertd` fails when postprocessing contains invertible transforms before `Invertd` is called. The solution uses automatic group tracking: `Compose` assigns its ID to child transforms, allowing `Invertd` to filter and select only the relevant transforms for inversion. This ensures correct inversion when multiple transform pipelines are used or when post-processing steps include invertible transforms. `TraceableTransform` now stores group information. `Invertd` now filters transforms by group, falling back to the original behavior if no group information is present (for backward compatibility). Adds tests to verify the fix and group isolation.
1 parent e267705 commit d97bdd5

File tree

5 files changed

+317
-2
lines changed

5 files changed

+317
-2
lines changed

monai/transforms/compose.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,48 @@ def __init__(
262262
self.set_random_state(seed=get_seed())
263263
self.overrides = overrides
264264

265+
# Automatically assign group ID to child transforms for inversion tracking
266+
self._set_transform_groups()
267+
268+
def _set_transform_groups(self):
269+
"""
270+
Automatically set group IDs on child transforms for inversion tracking.
271+
This allows Invertd to identify which transforms belong to this Compose instance.
272+
Recursively sets groups on wrapped transforms (e.g., array transforms inside dictionary transforms).
273+
"""
274+
from monai.transforms.inverse import TraceableTransform
275+
276+
group_id = str(id(self))
277+
visited = set() # Track visited objects to avoid infinite recursion
278+
279+
def set_group_recursive(obj, gid):
280+
"""Recursively set group on transform and its wrapped transforms."""
281+
# Avoid infinite recursion
282+
obj_id = id(obj)
283+
if obj_id in visited:
284+
return
285+
visited.add(obj_id)
286+
287+
if isinstance(obj, TraceableTransform):
288+
obj._group = gid
289+
290+
# Handle wrapped transforms in dictionary transforms
291+
# Check common attribute patterns for wrapped transforms
292+
for attr_name in dir(obj):
293+
# Skip magic methods and common non-transform attributes
294+
if attr_name.startswith('__') or attr_name in ('transforms', 'transform'):
295+
continue
296+
try:
297+
attr = getattr(obj, attr_name, None)
298+
if attr is not None and isinstance(attr, TraceableTransform) and not isinstance(attr, Compose):
299+
# Recursively set group on nested transforms
300+
set_group_recursive(attr, gid)
301+
except Exception:
302+
pass
303+
304+
for transform in self.transforms:
305+
set_group_recursive(transform, group_id)
306+
265307
@LazyTransform.lazy.setter # type: ignore
266308
def lazy(self, val: bool):
267309
self._lazy = val

monai/transforms/inverse.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,13 @@ def get_transform_info(self) -> dict:
125125
self.tracing,
126126
self._do_transform if hasattr(self, "_do_transform") else True,
127127
)
128-
return dict(zip(self.transform_info_keys(), vals))
128+
info = dict(zip(self.transform_info_keys(), vals))
129+
130+
# Add group if set (automatically set by Compose)
131+
if hasattr(self, "_group") and self._group is not None:
132+
info[TraceKeys.GROUP] = self._group
133+
134+
return info
129135

130136
def push_transform(self, data, *args, **kwargs):
131137
"""

monai/transforms/post/dictionary.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,29 @@ def __init__(
859859
self.post_func = ensure_tuple_rep(post_func, len(self.keys))
860860
self._totensor = ToTensor()
861861

862+
def _filter_transforms_by_group(self, all_transforms: list[dict]) -> list[dict]:
863+
"""
864+
Filter applied_operations to only include transforms from the target Compose instance.
865+
Uses automatic group tracking where Compose assigns its ID to child transforms.
866+
"""
867+
from monai.utils import TraceKeys
868+
869+
# Get the group ID of the transform (Compose instance)
870+
target_group = str(id(self.transform))
871+
872+
# Filter transforms that match the target group
873+
filtered = []
874+
for xform in all_transforms:
875+
xform_group = xform.get(TraceKeys.GROUP)
876+
if xform_group == target_group:
877+
filtered.append(xform)
878+
879+
# If no transforms match (backward compatibility), return all transforms
880+
if not filtered:
881+
return all_transforms
882+
883+
return filtered
884+
862885
def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
863886
d = dict(data)
864887
for (
@@ -894,8 +917,11 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
894917

895918
orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}"
896919
if orig_key in d and isinstance(d[orig_key], MetaTensor):
897-
transform_info = d[orig_key].applied_operations
920+
all_transforms = d[orig_key].applied_operations
898921
meta_info = d[orig_key].meta
922+
923+
# Automatically filter by Compose instance group ID
924+
transform_info = self._filter_transforms_by_group(all_transforms)
899925
else:
900926
transform_info = d[InvertibleTransform.trace_key(orig_key)]
901927
meta_info = d.get(orig_meta_key, {})

monai/utils/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ class TraceKeys(StrEnum):
334334
TRACING: str = "tracing"
335335
STATUSES: str = "statuses"
336336
LAZY: str = "lazy"
337+
GROUP: str = "group"
337338

338339

339340
class TraceStatusKeys(StrEnum):

tests/transforms/inverse/test_invertd.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,246 @@ def test_invert(self):
137137

138138
set_determinism(seed=None)
139139

140+
def test_invertd_with_postprocessing_transforms(self):
141+
"""Test that Invertd ignores postprocessing transforms using automatic group tracking.
142+
143+
This is a regression test for the issue where Invertd would fail when
144+
postprocessing contains invertible transforms before Invertd is called.
145+
The fix uses automatic group tracking where Compose assigns its ID to child transforms.
146+
"""
147+
from monai.data import MetaTensor, create_test_image_2d
148+
from monai.transforms.utility.dictionary import Lambdad
149+
150+
img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
151+
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
152+
key = "image"
153+
154+
# Preprocessing pipeline
155+
preprocessing = Compose([
156+
EnsureChannelFirstd(key),
157+
Spacingd(key, pixdim=[2.0, 2.0]),
158+
])
159+
160+
# Postprocessing with Lambdad before Invertd
161+
# Previously this would raise RuntimeError about transform ID mismatch
162+
postprocessing = Compose([
163+
Lambdad(key, func=lambda x: x), # Should be ignored during inversion
164+
Invertd(key, transform=preprocessing, orig_keys=key)
165+
])
166+
167+
# Apply transforms
168+
item = {key: img}
169+
pre = preprocessing(item)
170+
171+
# This should NOT raise an error (was failing before the fix)
172+
try:
173+
post = postprocessing(pre)
174+
# If we get here, the bug is fixed
175+
self.assertIsNotNone(post)
176+
self.assertIn(key, post)
177+
print(f"SUCCESS! Automatic group tracking fixed the bug.")
178+
print(f" Preprocessing group ID: {id(preprocessing)}")
179+
print(f" Postprocessing group ID: {id(postprocessing)}")
180+
except RuntimeError as e:
181+
if "getting the most recently applied invertible transform" in str(e):
182+
self.fail(f"Invertd still has the postprocessing transform bug: {e}")
183+
184+
def test_invertd_multiple_pipelines(self):
185+
"""Test that Invertd correctly handles multiple independent preprocessing pipelines."""
186+
from monai.data import MetaTensor, create_test_image_2d
187+
from monai.transforms.utility.dictionary import Lambdad
188+
189+
img1, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
190+
img1 = MetaTensor(img1, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
191+
img2, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
192+
img2 = MetaTensor(img2, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
193+
194+
# Two different preprocessing pipelines
195+
preprocessing1 = Compose([
196+
EnsureChannelFirstd("image1"),
197+
Spacingd("image1", pixdim=[2.0, 2.0]),
198+
])
199+
200+
preprocessing2 = Compose([
201+
EnsureChannelFirstd("image2"),
202+
Spacingd("image2", pixdim=[1.5, 1.5]),
203+
])
204+
205+
# Postprocessing that inverts both
206+
postprocessing = Compose([
207+
Lambdad(["image1", "image2"], func=lambda x: x),
208+
Invertd("image1", transform=preprocessing1, orig_keys="image1"),
209+
Invertd("image2", transform=preprocessing2, orig_keys="image2"),
210+
])
211+
212+
# Apply transforms
213+
item = {"image1": img1, "image2": img2}
214+
pre1 = preprocessing1(item)
215+
pre2 = preprocessing2(pre1)
216+
217+
# Should not raise error - each Invertd should only invert its own pipeline
218+
post = postprocessing(pre2)
219+
self.assertIn("image1", post)
220+
self.assertIn("image2", post)
221+
222+
def test_invertd_multiple_postprocessing_transforms(self):
223+
"""Test Invertd with multiple invertible transforms in postprocessing before Invertd."""
224+
from monai.data import MetaTensor, create_test_image_2d
225+
from monai.transforms.utility.dictionary import Lambdad
226+
227+
img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
228+
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
229+
key = "image"
230+
231+
preprocessing = Compose([
232+
EnsureChannelFirstd(key),
233+
Spacingd(key, pixdim=[2.0, 2.0]),
234+
])
235+
236+
# Multiple transforms in postprocessing before Invertd
237+
postprocessing = Compose([
238+
Lambdad(key, func=lambda x: x * 2),
239+
Lambdad(key, func=lambda x: x + 1),
240+
Lambdad(key, func=lambda x: x - 1),
241+
Invertd(key, transform=preprocessing, orig_keys=key)
242+
])
243+
244+
item = {key: img}
245+
pre = preprocessing(item)
246+
post = postprocessing(pre)
247+
248+
self.assertIsNotNone(post)
249+
self.assertIn(key, post)
250+
251+
def test_invertd_group_isolation(self):
252+
"""Test that groups correctly isolate transforms from different Compose instances."""
253+
from monai.data import MetaTensor, create_test_image_2d
254+
255+
img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
256+
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
257+
key = "image"
258+
259+
# First preprocessing
260+
preprocessing1 = Compose([
261+
EnsureChannelFirstd(key),
262+
Spacingd(key, pixdim=[2.0, 2.0]),
263+
])
264+
265+
# Second preprocessing (different pipeline)
266+
preprocessing2 = Compose([
267+
Spacingd(key, pixdim=[1.5, 1.5]),
268+
])
269+
270+
item = {key: img}
271+
pre1 = preprocessing1(item)
272+
273+
# Verify group IDs are in applied_operations
274+
self.assertTrue(len(pre1[key].applied_operations) > 0)
275+
group1 = pre1[key].applied_operations[0].get("group")
276+
self.assertIsNotNone(group1)
277+
self.assertEqual(group1, str(id(preprocessing1)))
278+
279+
# Apply second preprocessing
280+
pre2 = preprocessing2(pre1)
281+
282+
# Should have operations from both pipelines with different groups
283+
groups = [op.get("group") for op in pre2[key].applied_operations]
284+
self.assertIn(str(id(preprocessing1)), groups)
285+
self.assertIn(str(id(preprocessing2)), groups)
286+
287+
# Inverting preprocessing1 should only invert its transforms
288+
inverter = Invertd(key, transform=preprocessing1, orig_keys=key)
289+
inverted = inverter(pre2)
290+
self.assertIsNotNone(inverted)
291+
292+
def test_compose_inverse_with_groups(self):
293+
"""Test that Compose.inverse() works correctly with automatic group tracking."""
294+
from monai.data import MetaTensor, create_test_image_2d
295+
296+
img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
297+
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
298+
key = "image"
299+
300+
# Create a preprocessing pipeline
301+
preprocessing = Compose([
302+
EnsureChannelFirstd(key),
303+
Spacingd(key, pixdim=[2.0, 2.0]),
304+
])
305+
306+
# Apply preprocessing
307+
item = {key: img}
308+
pre = preprocessing(item)
309+
310+
# Call inverse() directly on the Compose object
311+
inverted = preprocessing.inverse(pre)
312+
313+
# Should successfully invert
314+
self.assertIsNotNone(inverted)
315+
self.assertIn(key, inverted)
316+
# Shape should be restored after inversion
317+
self.assertEqual(inverted[key].shape[1:], img.shape)
318+
319+
def test_compose_inverse_with_postprocessing_groups(self):
320+
"""Test Compose.inverse() when data has been through multiple pipelines with different groups."""
321+
from monai.data import MetaTensor, create_test_image_2d
322+
from monai.transforms.utility.dictionary import Lambdad
323+
324+
img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
325+
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
326+
key = "image"
327+
328+
# Preprocessing pipeline
329+
preprocessing = Compose([
330+
EnsureChannelFirstd(key),
331+
Spacingd(key, pixdim=[2.0, 2.0]),
332+
])
333+
334+
# Postprocessing pipeline (different group)
335+
postprocessing = Compose([
336+
Lambdad(key, func=lambda x: x * 2),
337+
])
338+
339+
# Apply both pipelines
340+
item = {key: img}
341+
pre = preprocessing(item)
342+
post = postprocessing(pre)
343+
344+
# Now call inverse() directly on preprocessing
345+
# This tests that inverse() can handle data that has transforms from multiple groups
346+
# This WILL fail because applied_operations contains postprocessing transforms
347+
# and inverse() doesn't do group filtering (only Invertd does)
348+
with self.assertRaises(RuntimeError):
349+
inverted = preprocessing.inverse(post)
350+
351+
def test_mixed_invertd_and_compose_inverse(self):
352+
"""Test mixing Invertd (with group filtering) and Compose.inverse() (without filtering)."""
353+
from monai.data import MetaTensor, create_test_image_2d
354+
355+
img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2)
356+
img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]})
357+
key = "image"
358+
359+
# First pipeline
360+
pipeline1 = Compose([
361+
EnsureChannelFirstd(key),
362+
Spacingd(key, pixdim=[2.0, 2.0]),
363+
])
364+
365+
# Apply first pipeline
366+
item = {key: img}
367+
result1 = pipeline1(item)
368+
369+
# Use Compose.inverse() directly - should work fine
370+
inverted1 = pipeline1.inverse(result1)
371+
self.assertIsNotNone(inverted1)
372+
self.assertEqual(inverted1[key].shape[1:], img.shape)
373+
374+
# Now apply pipeline again and use Invertd
375+
result2 = pipeline1(item)
376+
inverter = Invertd(key, transform=pipeline1, orig_keys=key)
377+
inverted2 = inverter(result2)
378+
self.assertIsNotNone(inverted2)
379+
140380

141381
if __name__ == "__main__":
142382
unittest.main()

0 commit comments

Comments
 (0)