Skip to content

Commit f8146e6

Browse files
authored
Normalize operation (#150)
CUDA based normalization operation integrated into cucim. Normalization supports following types. 1- simple range based normalization 2- Atangent based normalization closes #151
1 parent 1634f9d commit f8146e6

File tree

8 files changed

+240
-6
lines changed

8 files changed

+240
-6
lines changed

python/cucim/src/cucim/core/operations/expose/tests/test_expose.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from cucim.core.operations.expose.transform import (color_jitter, image_flip,
22
image_rotate_90,
3+
normalize_data,
34
rand_image_flip,
45
rand_image_rotate_90,
56
rand_zoom,
@@ -11,6 +12,7 @@ def test_exposed_transforms():
1112
assert image_flip is not None
1213
assert image_rotate_90 is not None
1314
assert scale_intensity_range is not None
15+
assert normalize_data is not None
1416
assert zoom is not None
1517
assert rand_zoom is not None
1618
assert rand_image_flip is not None

python/cucim/src/cucim/core/operations/expose/transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# limitations under the License.
1414

1515
from cucim.core.operations.color import color_jitter
16-
from cucim.core.operations.intensity import (rand_zoom, scale_intensity_range,
17-
zoom)
16+
from cucim.core.operations.intensity import (normalize_data, rand_zoom,
17+
scale_intensity_range, zoom)
1818
from cucim.core.operations.spatial import (image_flip, image_rotate_90,
1919
rand_image_flip,
2020
rand_image_rotate_90)

python/cucim/src/cucim/core/operations/intensity/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from .normalize import normalize_data
12
from .scaling import scale_intensity_range
23
from .zoom import rand_zoom, zoom
34

45
__all__ = [
6+
"normalize_data",
57
"scale_intensity_range",
68
"zoom",
79
"rand_zoom"

python/cucim/src/cucim/core/operations/intensity/kernel/cuda_kernel_source.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,38 @@
1414

1515
cuda_kernel_code = r'''
1616
extern "C" {
17+
__global__ void normalize_data_by_range(float *in, float *out, \
18+
float norm_factor, \
19+
float min_value, \
20+
int total_size)
21+
{
22+
const unsigned int j = blockIdx.x * blockDim.x + threadIdx.x;
23+
24+
if( j < total_size ) {
25+
out[j] = norm_factor * (in[j] - min_value);
26+
}
27+
}
28+
29+
__global__ void normalize_data_by_atan(float *in, float *out, \
30+
float norm_factor, \
31+
float min_value, \
32+
int total_size)
33+
{
34+
const unsigned int j = blockIdx.x * blockDim.x + threadIdx.x;
35+
36+
if( j < total_size ) {
37+
out[j] = norm_factor * atan(in[j] - min_value);
38+
}
39+
}
40+
1741
__global__ void scaleVolume(float* image, float* output, \
1842
float x, float y, float bmin, \
1943
float bmax, int W)
2044
{
21-
const unsigned int j = blockIdx.x * blockDim.x + threadIdx.x;
22-
if(j < W) {
23-
output[j] = fmaxf(fminf(image[j] * x - y, bmax), bmin);
24-
}
45+
const unsigned int j = blockIdx.x * blockDim.x + threadIdx.x;
46+
if(j < W) {
47+
output[j] = fmaxf(fminf(image[j] * x - y, bmax), bmin);
48+
}
2549
}
2650
2751
__global__ void zoom_in_kernel(float *input_tensor, float *output_tensor, \
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright (c) 2021, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any
16+
17+
import cupy
18+
import numpy as np
19+
20+
from .kernel.cuda_kernel_source import cuda_kernel_code
21+
22+
CUDA_KERNELS = cupy.RawModule(code=cuda_kernel_code)
23+
24+
25+
def normalize_data(
26+
img: Any,
27+
norm_constant: float,
28+
min_value: float,
29+
max_value: float,
30+
type: str = 'range'
31+
) -> Any:
32+
"""
33+
Apply intensity normalization to the input array.
34+
Normalize intensities to the range of [0, norm_constant].
35+
36+
Parameters
37+
----------
38+
img : channel first, cupy.ndarray or numpy.ndarray
39+
Input data of shape (C, H, W). Can also batch process input of shape
40+
(N, C, H, W). Can be a numpy.ndarray or cupy.ndarray.
41+
norm_constant: float
42+
Normalization range of the input data. [0, norm_constant]
43+
min_value : float
44+
Minimum intensity value in input data.
45+
max_value : float
46+
Maximum intensity value in input data.
47+
type : {'range', 'atan'}
48+
Type of normalization.
49+
50+
Returns
51+
-------
52+
out : cupy.ndarray or numpy.ndarray
53+
Output data. Same dimensions and type as input.
54+
55+
Raises
56+
------
57+
TypeError
58+
If input 'img' is not cupy.ndarray or numpy.ndarray
59+
ValueError
60+
If input original intensity min and max are same
61+
ValueError
62+
If incorrect normalization type is invoked
63+
64+
Examples
65+
--------
66+
>>> import cucim.core.operations.intensity as its
67+
>>> # input is channel first 3d array
68+
>>> output_array = its.normalize_data(input_arr,
69+
10, 0 , 255)
70+
"""
71+
try:
72+
if max_value - min_value == 0.0:
73+
raise ValueError("Minimum and Maximum intensity \
74+
same in input data")
75+
76+
if type not in ['range', 'atan']:
77+
raise ValueError("Incorrect normalization type. \
78+
Supported types are: \
79+
range based: 1,\
80+
atangent based: 2")
81+
82+
to_cupy = False
83+
84+
if isinstance(img, np.ndarray):
85+
to_cupy = True
86+
cupy_img = cupy.asarray(img, dtype=cupy.float32, order='C')
87+
elif not isinstance(img, cupy.ndarray):
88+
raise TypeError("img must be a cupy.ndarray or numpy.ndarray")
89+
else:
90+
cupy_img = cupy.ascontiguousarray(img)
91+
92+
if cupy_img.dtype != cupy.float32:
93+
if cupy.can_cast(img.dtype, cupy.float32) is False:
94+
raise ValueError(
95+
"Cannot safely cast type {cupy_img.dtype.name} to \
96+
'float32'"
97+
)
98+
else:
99+
cupy_img = cupy_img.astype(cupy.float32)
100+
101+
normalize = CUDA_KERNELS.get_function("normalize_data_by_range")
102+
103+
if type == 'atan':
104+
normalize = CUDA_KERNELS.get_function("normalize_data_by_atan")
105+
106+
value_range = max_value - min_value
107+
norm_factor = norm_constant / value_range
108+
109+
total_size = int(np.prod(img.shape))
110+
blockx = 128
111+
gridx = int((total_size - 1) / blockx + 1)
112+
113+
result = cupy.empty(img.shape, dtype=cupy_img.dtype)
114+
115+
normalize((gridx, 1, 1), (blockx, 1, 1),
116+
(cupy_img, result, np.float32(norm_factor),
117+
np.float32(min_value),
118+
np.int32(total_size)))
119+
120+
if img.dtype != cupy.float32:
121+
result = result.astype(img.dtype)
122+
123+
if to_cupy is True:
124+
result = cupy.asnumpy(result)
125+
126+
except Exception:
127+
raise
128+
129+
return result
110 KB
Loading
93.4 KB
Loading
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import os
2+
3+
import cupy
4+
import numpy as np
5+
import pytest
6+
import skimage.data
7+
from PIL import Image
8+
9+
import cucim.core.operations.intensity as its
10+
11+
12+
def get_input_arr():
13+
img = skimage.data.astronaut()
14+
arr = np.asarray(img)
15+
arr = np.transpose(arr)
16+
return arr
17+
18+
19+
def get_norm_data():
20+
dirname = os.path.dirname(__file__)
21+
img1 = Image.open(os.path.join(os.path.abspath(dirname),
22+
"normalized.png"))
23+
arr_o = np.asarray(img1)
24+
arr_o = np.transpose(arr_o)
25+
return arr_o
26+
27+
28+
def get_norm_atan_data():
29+
dirname = os.path.dirname(__file__)
30+
img1 = Image.open(os.path.join(os.path.abspath(dirname),
31+
"normalized_atan.png"))
32+
arr_o = np.asarray(img1)
33+
arr_o = np.transpose(arr_o)
34+
return arr_o
35+
36+
37+
def test_norm_param():
38+
arr = get_input_arr()
39+
with pytest.raises(ValueError):
40+
its.normalize_data(arr, 10.0, 255, 255)
41+
with pytest.raises(ValueError):
42+
its.normalize_data(arr, 10.0, 0, 255, 'invalid')
43+
with pytest.raises(TypeError):
44+
img = Image.fromarray(arr.T, 'RGB')
45+
its.normalize_data(img, 10.0, 0, 255)
46+
47+
48+
def test_norm_numpy_input():
49+
arr = get_input_arr()
50+
norm_arr = get_norm_data()
51+
output = its.normalize_data(arr, 10.0, 0, 255)
52+
assert np.allclose(output, norm_arr)
53+
54+
norm_atan_arr = get_norm_atan_data()
55+
output = its.normalize_data(arr, 10000, 0, 255, 'atan')
56+
assert np.allclose(output, norm_atan_arr)
57+
58+
59+
def test_norm_cupy_input():
60+
arr = get_input_arr()
61+
norm_arr = get_norm_data()
62+
cupy_arr = cupy.asarray(arr)
63+
cupy_output = its.normalize_data(cupy_arr, 10.0, 0, 255)
64+
np_output = cupy.asnumpy(cupy_output)
65+
assert np.allclose(np_output, norm_arr)
66+
67+
68+
def test_norm_batchinput():
69+
arr = get_input_arr()
70+
norm_arr = get_norm_data()
71+
arr_batch = np.stack((arr,) * 8, axis=0)
72+
output = its.normalize_data(arr_batch, 10.0, 0, 255)
73+
74+
assert output.shape[0] == 8
75+
76+
for i in range(output.shape[0]):
77+
assert np.allclose(output[i], norm_arr)

0 commit comments

Comments
 (0)