MONAISegInferenceOperator Additional Arguments #547


Open · wants to merge 2 commits into base: main
Conversation

bluna301 (Contributor)

When translating the MONAI CT TotalSegmentator MONAI Bundle to a MAP, it was determined that the segmentations produced by the MONAI Bundle and the MAP were not exactly aligned (i.e., Dice != 1 for all organs).

This appears to be because MONAISegInferenceOperator does not accept all of the arguments accepted by the sliding_window_inference method, specifically mode and padding_mode for this Bundle. Once a custom operator accepting these input arguments was added to the MAP pipeline, the Dice score for all organs was ≈0.99 when comparing against the Bundle.

This initial commit includes the logic for accepting and applying the mode and padding_mode arguments. Happy to discuss further whether the remaining missing sliding_window_inference arguments should be added as well.
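For illustration, the forwarding pattern the commit adds can be sketched with a stand-in inference function. The names `fake_sliding_window_inference` and `SegOperatorSketch` below are hypothetical stand-ins so the example runs without MONAI installed; the real operator wraps `monai.inferers.sliding_window_inference`:

```python
def fake_sliding_window_inference(inputs, roi_size, sw_batch_size, predictor,
                                  overlap=0.25, mode="constant", padding_mode="constant"):
    # Stand-in for monai.inferers.sliding_window_inference: simply echoes
    # the blending/padding settings it received.
    return {"mode": mode, "padding_mode": padding_mode}

class SegOperatorSketch:
    """Hypothetical, minimal version of MonaiSegInferenceOperator."""

    def __init__(self, roi_size, sw_batch_size=1, overlap=0.25,
                 mode="constant", padding_mode="constant"):
        self._roi_size = roi_size
        self.sw_batch_size = sw_batch_size
        self.overlap = overlap
        # The two new arguments, stored for later forwarding.
        self._mode = mode
        self._padding_mode = padding_mode

    def compute_impl(self, input_image, predictor):
        # Forward the stored settings to the inference call, mirroring
        # the diff in this PR.
        return fake_sliding_window_inference(
            inputs=input_image,
            roi_size=self._roi_size,
            sw_batch_size=self.sw_batch_size,
            predictor=predictor,
            overlap=self.overlap,
            mode=self._mode,
            padding_mode=self._padding_mode,
        )

op = SegOperatorSketch(roi_size=(96, 96, 96), mode="gaussian", padding_mode="replicate")
result = op.compute_impl(input_image=None, predictor=None)
print(result)  # {'mode': 'gaussian', 'padding_mode': 'replicate'}
```

Without such forwarding, the inference call silently falls back to its defaults (constant blending and padding), which is what caused the Dice mismatch against the Bundle.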

@bluna301 bluna301 added the enhancement New feature or request label Jul 28, 2025
@bluna301 bluna301 self-assigned this Jul 28, 2025
@bluna301 bluna301 requested review from vikashg and MMelQin July 28, 2025 22:34
@@ -47,13 +47,29 @@
__all__ = ["MonaiSegInferenceOperator", "InfererType", "InMemImageReader"]


class BlendModeType(StrEnum):
Collaborator

Can we simply use the enum from MONAI core instead of redefining it and needing to handle the mapping (just in case the core decides to update the str enum)? Also applicable to PytorchPadModeType. I think it is fine to expose some of the innards of this class, and further add in the inline doc that internally it uses the MONAI inference utils functions for sliding window and simple inference, and the params need to be as expected by them.

Similar to the conditional import of the sliding_window_inference, we can do

BlendModeType, _ = optional_import("monai.utils", name="BlendMode")
PytorchPadModeType, _ = optional_import("monai.utils", name="PytorchPadMode")

Contributor Author

I initially had BlendMode and PytorchPadMode as optional imports from MONAI utils like you suggested, but during type checking, mypy was struggling to deal with type checking of optional imports. I ended up having these as required imports, but these do now rely on the MONAI utils definition.

from monai.utils import BlendMode, PytorchPadMode

# ...

# define other StrEnum types
BlendModeType = BlendMode
PytorchPadModeType = PytorchPadMode

Happy to change these back to optional imports if you like, just let me know.

@@ -320,6 +367,8 @@ def compute_impl(self, input_image, context):
roi_size=self._roi_size,
sw_batch_size=self.sw_batch_size,
overlap=self.overlap,
mode=self._mode,
padding_mode=self._padding_mode,
Collaborator

It is good to add these couple params as explicitly supported named args to the MONAI inference function. In addition, I had the idea of extending the support to all other params on the MONAI inference functions by passing the (filtered) **kwargs down to the functions. I will provide a static function for the filtering in the general comment section.

Contributor Author

Filtering functionality added. Tested this with a few example cases, and the desired behavior was observed. Here is an example: we can see the input and predictor inputs are ignored, and that the buffer_steps and buffer_dim parameters are successfully passed through to sliding_window_inference, producing a ValueError (values chosen purposely to trigger the error):

# delegates inference and saving output to the built-in operator
# parameters pulled from inference.yaml file of the MONAI bundle
infer_operator = MonaiSegInferenceOperator(
    self.fragment,
    roi_size=(96, 96, 96),
    pre_transforms=pre_transforms,
    post_transforms=post_transforms,
    overlap=0.25,
    app_context=self.app_context,
    model_name="",
    inferer=InfererType.SLIDING_WINDOW,
    sw_batch_size=1,
    mode=BlendModeType.GAUSSIAN,
    padding_mode=PytorchPadModeType.REPLICATE,
    model_path=self.model_path,
    inputs="test",
    predictor="testing",
    buffer_steps=1,
    buffer_dim=4
)
[2025-08-01 19:04:15,548] [INFO] (ct_totalseg_operator.CTTotalSegOperator) - TorchScript model detected
[2025-08-01 19:04:15,548] [WARNING] (monai_seg_inference_operator.MonaiSegInferenceOperator) - 'inputs' is already explicity defined or used; ignoring input arg
[2025-08-01 19:04:15,548] [WARNING] (monai_seg_inference_operator.MonaiSegInferenceOperator) - 'predictor' is already explicity defined or used; ignoring input arg

....

[2025-08-01 19:04:18,125] [INFO] (monai_seg_inference_operator.MonaiSegInferenceOperator) - Input of <class 'monai.data.meta_tensor.MetaTensor'> shape: torch.Size([1, 1, 270, 270, 204])
[error] [gxf_wrapper.cpp:118] Exception occurred for operator: 'ct_totalseg_op' - ValueError: buffer_dim must be in [-3, 3], got 4.

At:
  /home/bluna301/miniconda3/envs/ct-totalsegmentator/lib/python3.9/site-packages/monai/inferers/utils.py(142): sliding_window_inference
  /home/bluna301/ct-totalsegmentator-map/my_app/monai_seg_inference_operator.py(435): compute_impl
  /home/bluna301/ct-totalsegmentator-map/my_app/ct_totalseg_operator.py(226): compute

[error] [entity_executor.cpp:596] Failed to tick codelet ct_totalseg_op in entity: ct_totalseg_op code: GXF_FAILURE
[warning] [greedy_scheduler.cpp:243] Error while executing entity 28 named 'ct_totalseg_op': GXF_FAILURE
[info] [greedy_scheduler.cpp:401] Scheduler finished.
[error] [program.cpp:580] wait failed. Deactivating...
[error] [runtime.cpp:1649] Graph wait failed with error: GXF_FAILURE
[warning] [gxf_executor.cpp:2428] GXF call GxfGraphWait(context) in line 2428 of file /workspace/holoscan-sdk/src/core/executors/gxf/gxf_executor.cpp failed with 'GXF_FAILURE' (1)
[info] [gxf_executor.cpp:2438] Graph execution finished.
[error] [gxf_executor.cpp:2446] Graph execution error: GXF_FAILURE
Traceback (most recent call last):
  File "/home/bluna301/miniconda3/envs/ct-totalsegmentator/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/bluna301/miniconda3/envs/ct-totalsegmentator/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/bluna301/ct-totalsegmentator-map/my_app/__main__.py", line 25, in <module>
    CTTotalSegmentatorApp().run()
  File "/home/bluna301/ct-totalsegmentator-map/my_app/app.py", line 61, in run
    super().run(*args, **kwargs)
  File "/home/bluna301/ct-totalsegmentator-map/my_app/ct_totalseg_operator.py", line 226, in compute
    seg_image = infer_operator.compute_impl(input_image, context)
  File "/home/bluna301/ct-totalsegmentator-map/my_app/monai_seg_inference_operator.py", line 391, in compute_impl
    d[self._pred_dataset_key] = sliding_window_inference(
  File "/home/bluna301/miniconda3/envs/ct-totalsegmentator/lib/python3.9/site-packages/monai/inferers/utils.py", line 142, in sliding_window_inference
    raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")
ValueError: buffer_dim must be in [-3, 3], got 4.

@MMelQin
Collaborator

MMelQin commented Jul 30, 2025

It is good to add these couple of parameters explicitly on the __init__ function of this operator. In addition, we can also consider supporting all other params of the MONAI sliding_window_inference function implicitly, i.e. by passing the (filtered) **kwargs down to the function. The filtering is to ensure no keys in the kwargs duplicate the explicitly supported params.

A test example

>>> def test(a: str = "it", b: str = "is Good", c: bool = True, d: dict = None, *args: Any, **kwargs: Any):
...     print(f"{a} {b}")
...     bool_val = "True" if c else "False"
...     print(f"The boolean val is {bool_val}")
...     if d:
...         for key, val in d.items():
...             print(f"{key}: {val}")
... 
>>> mine={'c': True, 'extra': 'unknown_to_fn'}
>>> test(a = 'Test', b = 1, **mine)
Test 1
The boolean val is True
>>> mine={'b': 2, 'c': False, 'extra': 'unknown_to_fn'}
>>> test(a = 'Test', b = 2, **mine)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: __main__.test() got multiple values for keyword argument 'b'
>>> mine={'b': 3, 'c': True, 'extra': 'unknown_to_fn'}
>>> test(a = 'Test', **mine)
Test 3
The boolean val is True
>>>

A function to filter out the params in the __init__ from the **kwargs can be implemented as

    @staticmethod
    def filter_sliding_window_params(**kwargs) -> Dict[str, Any]:
        """
        Returns a dictionary of named parameters of the sliding_window_inference function
        that are not in the __init__ of this class, or explicitly used on calling sliding_window_inference.
        """

        # requires `import inspect` at module level
        init_params = inspect.signature(MonaiSegInferenceOperator).parameters

        filtered_params = {}
        for name, val in kwargs.items():
            if name not in init_params and name not in ["inputs", "predictor", "args", "kwargs"]:
                filtered_params[name] = val
        return filtered_params

This function can be called in the init, and the resultant dictionary saved as an attribute, say self._implicit_params. When calling the sliding_window_inference, just need to add **self._implicit_params. The description of this class can also be updated to explain the use of the **kwargs.
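Putting those pieces together, here is a minimal, MONAI-free sketch of the wiring described above. The names `OperatorSketch` and `fake_sliding_window_inference` are illustrative stand-ins, not the actual classes in the codebase:

```python
import inspect

def fake_sliding_window_inference(inputs, roi_size, overlap=0.25, **kwargs):
    # Stand-in for monai.inferers.sliding_window_inference; echoes the
    # implicit kwargs it received.
    return kwargs

class OperatorSketch:
    """Hypothetical operator illustrating the kwargs-filtering pattern."""

    def __init__(self, roi_size, overlap=0.25, **kwargs):
        self._roi_size = roi_size
        self.overlap = overlap
        # Keep only kwargs that are neither explicit __init__ params nor
        # names used explicitly on the inference call.
        self._implicit_params = self.filter_sliding_window_params(**kwargs)

    @staticmethod
    def filter_sliding_window_params(**kwargs):
        init_params = inspect.signature(OperatorSketch).parameters
        return {
            name: val
            for name, val in kwargs.items()
            if name not in init_params
            and name not in ("inputs", "predictor", "args", "kwargs")
        }

    def compute_impl(self, input_image):
        return fake_sliding_window_inference(
            inputs=input_image,
            roi_size=self._roi_size,
            overlap=self.overlap,
            **self._implicit_params,  # implicit params forwarded here
        )

op = OperatorSketch(roi_size=(96, 96, 96), buffer_steps=1, inputs="ignored")
print(op.compute_impl(None))  # {'buffer_steps': 1}
```

Note the sketch drops filtered keys silently; the actual operator logs a warning for each ignored argument, as shown in the run output above.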

@bluna301
Contributor Author

@MMelQin many thanks for your review! Completely agree with all of your suggestions, and appreciate the base filtering function you provided for handling the implicit parameters. I will work on integrating your suggestions into the operator.


sonarqubecloud bot commented Aug 1, 2025

@bluna301
Contributor Author

bluna301 commented Aug 2, 2025

@MMelQin PR has been updated - please see comments for detailed updates.

I did have one remaining question about the device parameter - I see this is hardcoded into the operator, and this is also a sliding_window_inference parameter. Is this something we need to address?

@bluna301 bluna301 requested a review from MMelQin August 2, 2025 02:07
Collaborator

@MMelQin MMelQin left a comment


LGTM. Thanks for expanding the support.
