21
21
from PIL import Image
22
22
23
23
# pyspark
24
- from pyspark import Row
25
24
from pyspark import SparkContext
26
- from sparkdl .image .image import ImageSchema
27
25
from pyspark .sql .functions import udf
28
26
from pyspark .sql .types import (
29
27
BinaryType , IntegerType , StringType , StructField , StructType )
30
28
31
-
32
- # ImageType represents supported OpenCV types
33
- # fields:
34
- # name - OpenCvMode
35
- # ord - Ordinal of the corresponding OpenCV mode (stored in mode field of ImageSchema).
36
- # nChannels - number of channels in the image
37
- # dtype - data type of the image's array, sorted as a numpy compatible string.
38
- #
39
- # NOTE: likely to be migrated to Spark ImageSchema code in the near future.
40
- _OcvType = namedtuple ("OcvType" , ["name" , "ord" , "nChannels" , "dtype" ])
41
-
42
-
43
- _supportedOcvTypes = (
44
- _OcvType (name = "CV_8UC1" , ord = 0 , nChannels = 1 , dtype = "uint8" ),
45
- _OcvType (name = "CV_32FC1" , ord = 5 , nChannels = 1 , dtype = "float32" ),
46
- _OcvType (name = "CV_8UC3" , ord = 16 , nChannels = 3 , dtype = "uint8" ),
47
- _OcvType (name = "CV_32FC3" , ord = 21 , nChannels = 3 , dtype = "float32" ),
48
- _OcvType (name = "CV_8UC4" , ord = 24 , nChannels = 4 , dtype = "uint8" ),
49
- _OcvType (name = "CV_32FC4" , ord = 29 , nChannels = 4 , dtype = "float32" ),
50
- )
51
-
52
- # NOTE: likely to be migrated to Spark ImageSchema code in the near future.
53
- _ocvTypesByName = {m .name : m for m in _supportedOcvTypes }
54
- _ocvTypesByOrdinal = {m .ord : m for m in _supportedOcvTypes }
55
-
56
-
57
- def imageTypeByOrdinal (ord ):
58
- if not ord in _ocvTypesByOrdinal :
59
- raise KeyError ("unsupported image type with ordinal %d, supported OpenCV types = %s" % (
60
- ord , str (_supportedOcvTypes )))
61
- return _ocvTypesByOrdinal [ord ]
62
-
63
-
64
- def imageTypeByName (name ):
65
- if not name in _ocvTypesByName :
66
- raise KeyError ("unsupported image type with name '%s', supported supported OpenCV types = %s" % (
67
- name , str (_supportedOcvTypes )))
68
- return _ocvTypesByName [name ]
69
-
70
-
71
- def imageArrayToStruct (imgArray , origin = "" ):
72
- """
73
- Create a row representation of an image from an image array.
74
-
75
- :param imgArray: ndarray, image data.
76
- :return: Row, image as a DataFrame Row with schema==ImageSchema.
77
- """
78
- # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists.
79
- if len (imgArray .shape ) == 4 :
80
- if imgArray .shape [0 ] != 1 :
81
- raise ValueError (
82
- "The first dimension of a 4-d image array is expected to be 1." )
83
- imgArray = imgArray .reshape (imgArray .shape [1 :])
84
- imageType = _arrayToOcvMode (imgArray )
85
- height , width , nChannels = imgArray .shape
86
- data = bytearray (imgArray .tobytes ())
87
- return Row (origin = origin , mode = imageType .ord , height = height ,
88
- width = width , nChannels = nChannels , data = data )
89
-
90
-
91
- def imageStructToArray (imageRow ):
92
- """
93
- Convert an image to a numpy array.
94
-
95
- :param imageRow: Row, must use imageSchema.
96
- :return: ndarray, image data.
97
- """
98
- imType = imageTypeByOrdinal (imageRow .mode )
99
- shape = (imageRow .height , imageRow .width , imageRow .nChannels )
100
- return np .ndarray (shape , imType .dtype , imageRow .data )
29
+ from sparkdl .image .image import ImageSchema
101
30
102
31
103
32
def imageStructToPIL (imageRow ):
@@ -107,20 +36,20 @@ def imageStructToPIL(imageRow):
107
36
:param imageRow: Row, must have ImageSchema
108
37
:return PIL image
109
38
"""
110
- imgType = imageTypeByOrdinal (imageRow . mode )
111
- if imgType .dtype != ' uint8' :
39
+ ary = ImageSchema . toNDArray (imageRow )
40
+ if ary .dtype != np . uint8 :
112
41
raise ValueError ("Can not convert image of type " +
113
- imgType .dtype + " to PIL, can only deal with 8U format" )
114
- ary = imageStructToArray ( imageRow )
42
+ ary .dtype + " to PIL, can only deal with 8U format" )
43
+
115
44
# PIL expects RGB order, image schema is BGR
116
45
# => we need to flip the order unless there is only one channel
117
- if imgType .nChannels != 1 :
46
+ if imageRow .nChannels != 1 :
118
47
ary = _reverseChannels (ary )
119
- if imgType .nChannels == 1 :
48
+ if imageRow .nChannels == 1 :
120
49
return Image .fromarray (obj = ary , mode = 'L' )
121
- elif imgType .nChannels == 3 :
50
+ elif imageRow .nChannels == 3 :
122
51
return Image .fromarray (obj = ary , mode = 'RGB' )
123
- elif imgType .nChannels == 4 :
52
+ elif imageRow .nChannels == 4 :
124
53
return Image .fromarray (obj = ary , mode = 'RGBA' )
125
54
else :
126
55
raise ValueError ("don't know how to convert " +
@@ -132,19 +61,6 @@ def PIL_to_imageStruct(img):
132
61
return _reverseChannels (np .asarray (img ))
133
62
134
63
135
- def _arrayToOcvMode (arr ):
136
- assert len (arr .shape ) == 3 , "Array should have 3 dimensions but has shape {}" .format (
137
- arr .shape )
138
- num_channels = arr .shape [2 ]
139
- if arr .dtype == "uint8" :
140
- name = "CV_8UC%d" % num_channels
141
- elif arr .dtype == "float32" :
142
- name = "CV_32FC%d" % num_channels
143
- else :
144
- raise ValueError ("Unsupported type '%s'" % arr .dtype )
145
- return imageTypeByName (name )
146
-
147
-
148
64
def fixColorChannelOrdering (currentOrder , imgAry ):
149
65
if currentOrder == 'RGB' :
150
66
return _reverseChannels (imgAry )
@@ -160,6 +76,24 @@ def fixColorChannelOrdering(currentOrder, imgAry):
160
76
"Unexpected channel order, expected one of L,RGB,BGR but got " + currentChannelOrder )
161
77
162
78
79
+ def _stripBatchSize (imgArray ):
80
+ """
81
+ Strip batch size (if it's there) from a multi dimensional array.
82
+ Assumes batch size is the first coordinate and is equal to 1.
83
+ Batch size != 1 will cause an error.
84
+
85
+ :param imgArray: ndarray, image data.
86
+ :return: imgArray without the leading batch size
87
+ """
88
+ # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists.
89
+ if len (imgArray .shape ) == 4 :
90
+ if imgArray .shape [0 ] != 1 :
91
+ raise ValueError (
92
+ "The first dimension of a 4-d image array is expected to be 1." )
93
+ imgArray = imgArray .reshape (imgArray .shape [1 :])
94
+ return imgArray
95
+
96
+
163
97
def _reverseChannels (ary ):
164
98
return ary [..., ::- 1 ]
165
99
@@ -183,8 +117,8 @@ def _resizeImageAsRow(imgAsRow):
183
117
return imgAsRow
184
118
imgAsPil = imageStructToPIL (imgAsRow ).resize (sz )
185
119
# PIL is RGB based while image schema is BGR based => we need to flip the channels
186
- imgAsArray = _reverseChannels ( np . asarray ( imgAsPil ) )
187
- return imageArrayToStruct (imgAsArray , origin = imgAsRow .origin )
120
+ imgAsArray = PIL_to_imageStruct ( imgAsPil )
121
+ return ImageSchema . toImage (imgAsArray , origin = imgAsRow .origin )
188
122
return udf (_resizeImageAsRow , ImageSchema .imageSchema ['image' ].dataType )
189
123
190
124
@@ -242,7 +176,7 @@ def readImagesWithCustomFn(path, decode_f, numPartition=None):
242
176
def _readImagesWithCustomFn (path , decode_f , numPartition , sc ):
243
177
def _decode (path , raw_bytes ):
244
178
try :
245
- return imageArrayToStruct (decode_f (raw_bytes ), origin = path )
179
+ return ImageSchema . toImage (decode_f (raw_bytes ), origin = path )
246
180
except BaseException :
247
181
return None
248
182
decodeImage = udf (_decode , ImageSchema .imageSchema ['image' ].dataType )
0 commit comments