Skip to content

Commit 1a30a0b

Browse files
Add JSON generation widgets for nnUNet and update requirements
1 parent 568d25a commit 1a30a0b

File tree

9 files changed

+2894
-114
lines changed

9 files changed

+2894
-114
lines changed

monai/apps/nnunet/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
from __future__ import annotations
1313

1414
from .nnunet_bundle import (
15-
ModelnnUNetWrapper,
15+
convert_monai_bundle_to_nnunet,
1616
convert_nnunet_to_monai_bundle,
17+
get_network_from_nnunet_plans,
1718
get_nnunet_monai_predictor,
1819
get_nnunet_trainer,
20+
nnUNetMONAIModelWrapper,
1921
)
2022
from .nnunetv2_runner import nnUNetV2Runner
2123
from .utils import NNUNETMode, analyze_data, create_new_data_copy, create_new_dataset_json

monai/apps/nnunet/nnunet_bundle.py

Lines changed: 259 additions & 113 deletions
Large diffs are not rendered by default.

monai/nvflare/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.

monai/nvflare/json_generator.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
from __future__ import annotations
12+
13+
import json
14+
import os.path
15+
16+
from nvflare.apis.event_type import EventType
17+
from nvflare.apis.fl_context import FLContext
18+
from nvflare.widgets.widget import Widget
19+
20+
21+
class PrepareJsonGenerator(Widget):
22+
"""
23+
A widget class to prepare and generate a JSON file containing data preparation configurations.
24+
25+
Parameters
26+
----------
27+
results_dir : str, optional
28+
The directory where the results will be stored (default is "prepare").
29+
json_file_name : str, optional
30+
The name of the JSON file to be generated (default is "data_dict.json").
31+
32+
Methods
33+
-------
34+
handle_event(event_type: str, fl_ctx: FLContext)
35+
Handles events during the federated learning process. Clears the data preparation configuration
36+
at the start of a run and saves the configuration to a JSON file at the end of a run.
37+
"""
38+
39+
def __init__(self, results_dir="prepare", json_file_name="data_dict.json"):
40+
super(PrepareJsonGenerator, self).__init__()
41+
42+
self._results_dir = results_dir
43+
self._data_prepare_config = {}
44+
self._json_file_name = json_file_name
45+
46+
def handle_event(self, event_type: str, fl_ctx: FLContext):
47+
if event_type == EventType.START_RUN:
48+
self._data_prepare_config.clear()
49+
elif event_type == EventType.END_RUN:
50+
self._data_prepare_config = fl_ctx.get_prop("client_data_dict", None)
51+
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
52+
data_prepare_res_dir = os.path.join(run_dir, self._results_dir)
53+
if not os.path.exists(data_prepare_res_dir):
54+
os.makedirs(data_prepare_res_dir)
55+
56+
res_file_path = os.path.join(data_prepare_res_dir, self._json_file_name)
57+
with open(res_file_path, "w") as f:
58+
json.dump(self._data_prepare_config, f)
59+
60+
61+
class nnUNetPackageReportJsonGenerator(Widget):
62+
"""
63+
A class to generate JSON reports for nnUNet package.
64+
65+
Parameters
66+
----------
67+
results_dir : str, optional
68+
Directory where the report will be saved (default is "package_report").
69+
json_file_name : str, optional
70+
Name of the JSON file to save the report (default is "package_report.json").
71+
72+
Methods
73+
-------
74+
handle_event(event_type: str, fl_ctx: FLContext)
75+
Handles events to clear the report at the start of a run and save the report at the end of a run.
76+
"""
77+
78+
def __init__(self, results_dir="package_report", json_file_name="package_report.json"):
79+
super(nnUNetPackageReportJsonGenerator, self).__init__()
80+
81+
self._results_dir = results_dir
82+
self._report = {}
83+
self._json_file_name = json_file_name
84+
85+
def handle_event(self, event_type: str, fl_ctx: FLContext):
86+
if event_type == EventType.START_RUN:
87+
self._report.clear()
88+
elif event_type == EventType.END_RUN:
89+
datasets = fl_ctx.get_prop("package_report", None)
90+
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
91+
cross_val_res_dir = os.path.join(run_dir, self._results_dir)
92+
if not os.path.exists(cross_val_res_dir):
93+
os.makedirs(cross_val_res_dir)
94+
95+
res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
96+
with open(res_file_path, "w") as f:
97+
json.dump(datasets, f)
98+
99+
100+
class nnUNetPlansJsonGenerator(Widget):
101+
"""
102+
A class to generate JSON files for nnUNet plans.
103+
104+
Parameters
105+
----------
106+
results_dir : str, optional
107+
Directory where the preprocessing results will be stored (default is "nnUNet_preprocessing").
108+
json_file_name : str, optional
109+
Name of the JSON file to be generated (default is "nnUNetPlans.json").
110+
111+
Methods
112+
-------
113+
handle_event(event_type: str, fl_ctx: FLContext)
114+
Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves
115+
the plans to a JSON file at the end of a run.
116+
"""
117+
118+
def __init__(self, results_dir="nnUNet_preprocessing", json_file_name="nnUNetPlans.json"):
119+
120+
super(nnUNetPlansJsonGenerator, self).__init__()
121+
122+
self._results_dir = results_dir
123+
self._nnUNetPlans = {}
124+
self._json_file_name = json_file_name
125+
126+
def handle_event(self, event_type: str, fl_ctx: FLContext):
127+
if event_type == EventType.START_RUN:
128+
self._nnUNetPlans.clear()
129+
elif event_type == EventType.END_RUN:
130+
datasets = fl_ctx.get_prop("nnunet_plans", None)
131+
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
132+
cross_val_res_dir = os.path.join(run_dir, self._results_dir)
133+
if not os.path.exists(cross_val_res_dir):
134+
os.makedirs(cross_val_res_dir)
135+
136+
res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
137+
with open(res_file_path, "w") as f:
138+
json.dump(datasets, f)
139+
140+
141+
class nnUNetValSummaryJsonGenerator(Widget):
142+
"""
143+
A widget to generate a JSON summary for nnUNet validation results.
144+
145+
Parameters
146+
----------
147+
results_dir : str, optional
148+
Directory where the nnUNet training results are stored (default is "nnUNet_train").
149+
json_file_name : str, optional
150+
Name of the JSON file to save the validation summary (default is "val_summary.json").
151+
152+
Methods
153+
-------
154+
handle_event(event_type: str, fl_ctx: FLContext)
155+
Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves
156+
the validation summary to a JSON file at the end of a run.
157+
"""
158+
159+
def __init__(self, results_dir="nnUNet_train", json_file_name="val_summary.json"):
160+
161+
super(nnUNetValSummaryJsonGenerator, self).__init__()
162+
163+
self._results_dir = results_dir
164+
self._nnUNetPlans = {}
165+
self._json_file_name = json_file_name
166+
167+
def handle_event(self, event_type: str, fl_ctx: FLContext):
168+
if event_type == EventType.START_RUN:
169+
self._nnUNetPlans.clear()
170+
elif event_type == EventType.END_RUN:
171+
datasets = fl_ctx.get_prop("val_summary_dict", None)
172+
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
173+
cross_val_res_dir = os.path.join(run_dir, self._results_dir)
174+
if not os.path.exists(cross_val_res_dir):
175+
os.makedirs(cross_val_res_dir)
176+
177+
res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
178+
with open(res_file_path, "w") as f:
179+
json.dump(datasets, f)

0 commit comments

Comments
 (0)