Skip to content

Commit 2281c4e

Browse files
Fix image_stats integration test (#3790)
Fix CI after changes in #3753 ### Description Add the old job into CI for testing. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated.
1 parent e9e2837 commit 2281c4e

File tree

5 files changed

+225
-2
lines changed

5 files changed

+225
-2
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
{
2+
"format_version": 2,
3+
"executors": [
4+
{
5+
"tasks": [
6+
"fed_stats_pre_run", "fed_stats"
7+
],
8+
"executor": {
9+
"id": "Executor",
10+
"path": "nvflare.app_common.executors.statistics.statistics_executor.StatisticsExecutor",
11+
"args": {
12+
"generator_id": "local_hist_generator"
13+
}
14+
}
15+
}
16+
],
17+
"task_result_filters": [
18+
{
19+
"tasks": ["fed_stats"],
20+
"filters":[
21+
{
22+
"path": "nvflare.app_common.filters.statistics_privacy_filter.StatisticsPrivacyFilter",
23+
"args": {
24+
"result_cleanser_ids": [
25+
"min_count_cleanser"
26+
]
27+
}
28+
}
29+
]
30+
}
31+
],
32+
"task_data_filters": [],
33+
"components": [
34+
{
35+
"id": "local_hist_generator",
36+
"path": "image_statistics.ImageStatistics",
37+
"args": {
38+
"data_root": "/tmp/nvflare/image_stats/data"
39+
}
40+
},
41+
{
42+
"id": "min_count_cleanser",
43+
"path": "nvflare.app_common.statistics.min_count_cleanser.MinCountCleanser",
44+
"args": {
45+
"min_count": 10
46+
}
47+
}
48+
]
49+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
{
2+
"format_version": 2,
3+
"workflows": [
4+
{
5+
"id": "fed_stats_controller",
6+
"path": "nvflare.app_common.workflows.statistics_controller.StatisticsController",
7+
"args": {
8+
"min_clients": 4,
9+
"statistic_configs": {
10+
"count": {},
11+
"histogram": {
12+
"*": {
13+
"bins": 255, "range": [0,256]
14+
}
15+
}
16+
},
17+
"writer_id": "stats_writer",
18+
"enable_pre_run_task": true
19+
}
20+
}
21+
],
22+
"components": [
23+
{
24+
"id": "stats_writer",
25+
"path": "nvflare.app_common.statistics.json_stats_file_persistor.JsonStatsFileWriter",
26+
"args": {
27+
"output_path": "statistics/image_statistics.json",
28+
"json_encoder_path": "nvflare.app_common.utils.json_utils.ObjectEncoder"
29+
}
30+
}
31+
]
32+
}
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import glob
16+
import os
17+
from typing import Dict, List, Optional
18+
19+
import numpy as np
20+
from monai.data import ITKReader, load_decathlon_datalist
21+
from monai.transforms import LoadImage
22+
23+
from nvflare.apis.fl_context import FLContext
24+
from nvflare.app_common.abstract.statistics_spec import Bin, DataType, Feature, Histogram, HistogramType, Statistics
25+
from nvflare.security.logging import secure_log_traceback
26+
27+
28+
class ImageStatistics(Statistics):
29+
def __init__(self, data_root: str = "/tmp/nvflare/image_stats/data", data_list_key: str = "data"):
30+
"""local image statistics generator .
31+
Args:
32+
data_root: directory with local image data.
33+
data_list_key: data list key to use.
34+
Returns:
35+
a Shareable with the computed local statistics`
36+
"""
37+
super().__init__()
38+
self.data_list_key = data_list_key
39+
self.data_root = data_root
40+
self.data_list = None
41+
self.client_name = None
42+
43+
self.loader = None
44+
self.failure_images = 0
45+
self.fl_ctx = None
46+
47+
def initialize(self, fl_ctx: FLContext):
48+
self.fl_ctx = fl_ctx
49+
self.client_name = fl_ctx.get_identity_name()
50+
self.loader = LoadImage(image_only=True)
51+
self.loader.register(ITKReader())
52+
self._load_data_list(self.client_name, fl_ctx)
53+
54+
if self.data_list is None:
55+
raise ValueError("data is not loaded. make sure the data is loaded")
56+
57+
def _load_data_list(self, client_name, fl_ctx: FLContext) -> bool:
58+
dataset_json = glob.glob(os.path.join(self.data_root, client_name + "*.json"))
59+
if len(dataset_json) != 1:
60+
self.log_error(
61+
fl_ctx, f"No unique matching dataset list found in {self.data_root} for client {client_name}"
62+
)
63+
return False
64+
dataset_json = dataset_json[0]
65+
self.log_info(fl_ctx, f"Reading data from {dataset_json}")
66+
67+
data_list = load_decathlon_datalist(
68+
data_list_file_path=dataset_json, data_list_key=self.data_list_key, base_dir=self.data_root
69+
)
70+
self.data_list = {"train": data_list}
71+
72+
self.log_info(fl_ctx, f"Client {client_name} has {len(self.data_list)} images")
73+
return True
74+
75+
def pre_run(
76+
self,
77+
statistics: List[str],
78+
num_of_bins: Optional[Dict[str, Optional[int]]],
79+
bin_ranges: Optional[Dict[str, Optional[List[float]]]],
80+
):
81+
return {}
82+
83+
def features(self) -> Dict[str, List[Feature]]:
84+
return {"train": [Feature("intensity", DataType.FLOAT)]}
85+
86+
def count(self, dataset_name: str, feature_name: str) -> int:
87+
image_paths = self.data_list[dataset_name]
88+
return len(image_paths)
89+
90+
def failure_count(self, dataset_name: str, feature_name: str) -> int:
91+
92+
return self.failure_images
93+
94+
def histogram(
95+
self, dataset_name: str, feature_name: str, num_of_bins: int, global_min_value: float, global_max_value: float
96+
) -> Histogram:
97+
histogram_bins: List[Bin] = []
98+
histogram = np.zeros((num_of_bins,), dtype=np.int64)
99+
bin_edges = []
100+
for i, entry in enumerate(self.data_list[dataset_name]):
101+
file = entry.get("image")
102+
try:
103+
img = self.loader(file)
104+
curr_histogram, bin_edges = np.histogram(
105+
img, bins=num_of_bins, range=(global_min_value, global_max_value)
106+
)
107+
histogram += curr_histogram
108+
bin_edges = bin_edges.tolist()
109+
110+
if i % 100 == 0:
111+
self.logger.info(
112+
f"{self.client_name}, adding {i + 1} of {len(self.data_list[dataset_name])}: {file}"
113+
)
114+
except Exception as e:
115+
self.failure_images += 1
116+
self.logger.critical(
117+
f"Failed to load file {file} with exception: {e.__str__()}. " f"Skipping this image..."
118+
)
119+
120+
if num_of_bins + 1 != len(bin_edges):
121+
secure_log_traceback()
122+
raise ValueError(
123+
f"bin_edges size: {len(bin_edges)} is not matching with number of bins + 1: {num_of_bins + 1}"
124+
)
125+
126+
for j in range(num_of_bins):
127+
low_value = bin_edges[j]
128+
high_value = bin_edges[j + 1]
129+
bin_sample_count = histogram[j]
130+
histogram_bins.append(Bin(low_value=low_value, high_value=high_value, sample_count=bin_sample_count))
131+
132+
return Histogram(HistogramType.STANDARD, histogram_bins)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"name": "image_stats",
3+
"resource_spec": {},
4+
"min_clients" : 4,
5+
"deploy_map": {
6+
"app": [
7+
"@ALL"
8+
]
9+
}
10+
}

tests/integration_test/data/test_configs/standalone_job/image_stats.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
additional_python_paths:
22
- ../../examples/advanced/federated-statistics/image_stats
33
cleanup: true
4-
jobs_root_dir: ../../examples/advanced/federated-statistics/image_stats/jobs
4+
jobs_root_dir: ./data/jobs
55
project_yaml: data/projects/4_clients.yml
66
tests:
77
- event_sequence:
@@ -26,6 +26,6 @@ tests:
2626
- cp ../../examples/advanced/federated-statistics/image_stats/requirements.txt ../../examples/advanced/federated-statistics/image_stats/temp_requirements.txt
2727
- sed -i '/nvflare\|jupyter\|notebook/d' ../../examples/advanced/federated-statistics/image_stats/temp_requirements.txt
2828
- pip install -r ../../examples/advanced/federated-statistics/image_stats/temp_requirements.txt
29-
- bash ../../examples/advanced/federated-statistics/image_stats/prepare_data.sh
29+
- python ../../examples/advanced/federated-statistics/image_stats/prepare_data.py --input_dir /tmp/nvflare/image_stats/data --output_dir /tmp/nvflare/image_stats/data
3030
- rm -f ../../examples/advanced/federated-statistics/image_stats/temp_requirements.txt
3131
test_name: Test example image_stats.

0 commit comments

Comments
 (0)