Skip to content

Commit f616462

Browse files
merged with master
1 parent aeff9c9 commit f616462

File tree

16 files changed

+175
-177
lines changed

16 files changed

+175
-177
lines changed

python/sparkdl/estimators/keras_image_file_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
import pyspark.ml.linalg as spla
2525
from pyspark.ml.param import Param, Params, TypeConverters
2626

27-
from sparkdl.image.imageIO import imageStructToArray
2827
from sparkdl.param import (
2928
keyword_only, CanLoadImage, HasKerasModel, HasKerasOptimizer, HasKerasLoss, HasOutputMode,
3029
HasInputCol, HasInputImageNodeName, HasLabelCol, HasOutputNodeName, HasOutputCol)
3130
from sparkdl.transformers.keras_image import KerasImageFileTransformer
31+
from sparkdl.image.image import ImageSchema
3232
import sparkdl.utils.jvmapi as JVMAPI
3333
import sparkdl.utils.keras_model as kmutil
3434

@@ -202,7 +202,7 @@ def _getNumpyFeaturesAndLabels(self, dataset):
202202
rows = image_df.collect()
203203
for row in rows:
204204
spimg = row[tmp_image_col]
205-
features = imageStructToArray(spimg)
205+
features = ImageSchema.toNDArray(spimg)
206206
localFeatures.append(features)
207207

208208
if not localFeatures: # NOTE(phi-dbq): pep-8 recommended against testing 0 == len(array)

python/sparkdl/image/image.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
"""
2828

2929
import numpy as np
30+
from collections import namedtuple
31+
3032
from pyspark import SparkContext
3133
from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string
3234
from pyspark.sql import DataFrame, SparkSession
@@ -42,9 +44,23 @@ class _ImageSchema(object):
4244
def __init__(self):
    """
    Create the schema helper with all of its cached values unset.

    Each attribute starts as ``None``; the accessors shown in this file
    fill an attribute the first time they find it is ``None``.
    """
    cacheSlots = (
        "_imageSchema",
        "_ocvTypes",
        "_ocvTypesByName",
        "_ocvTypesByMode",
        "_imageFields",
        "_undefinedImageType",
    )
    for slot in cacheSlots:
        setattr(self, slot, None)
4751

52+
# Record type describing one OpenCV image type: its name, the integer mode,
# the channel count, the OpenCV depth code and the matching numpy dtype.
_OcvType = namedtuple("OcvType", "name mode nChannels dataType nptype")

# OpenCV depth code -> numpy dtype used to hold pixel values of that depth.
_ocvToNumpyMap = {
    ocvDepth: np.dtype(numpyName)
    for ocvDepth, numpyName in (
        ("8U", "uint8"),
        ("8S", "int8"),
        ("16U", "uint16"),
        ("16S", "int16"),
        ("32S", "int32"),
        ("32F", "float32"),
        ("64F", "float64"),
    )
}
# Inverse lookup: numpy dtype -> OpenCV depth code.
_numpyToOcvMap = {npType: ocvDepth for ocvDepth, npType in _ocvToNumpyMap.items()}
63+
4864
@property
4965
def imageSchema(self):
5066
"""
@@ -57,7 +73,7 @@ def imageSchema(self):
5773
"""
5874

5975
if self._imageSchema is None:
60-
ctx = SparkContext._active_spark_context
76+
ctx = SparkContext.getOrCreate()
6177
jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema()
6278
self._imageSchema = _parse_datatype_json_string(jschema.json())
6379
return self._imageSchema
@@ -71,11 +87,30 @@ def ocvTypes(self):
7187
7288
.. versionadded:: 2.3.0
7389
"""
74-
7590
if self._ocvTypes is None:
76-
ctx = SparkContext._active_spark_context
77-
self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes())
78-
return self._ocvTypes
91+
ctx = SparkContext.getOrCreate()
92+
self._ocvTypes = [self._OcvType(name=x.name(),
93+
mode=x.mode(),
94+
nChannels=x.nChannels(),
95+
dataType=x.dataType(),
96+
nptype=self._ocvToNumpyMap[x.dataType()])
97+
for x in ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()]
98+
return [x for x in self._ocvTypes]
99+
100+
def ocvTypeByName(self, name):
    """
    Look up a supported OpenCV type descriptor by its name.

    The name-indexed table is built from ``self.ocvTypes`` on first use and
    cached in ``self._ocvTypesByName``.

    :param name: str, OpenCV type name to look up.
    :return: the entry of ``self.ocvTypes`` whose ``name`` equals ``name``.
    :raises ValueError: if no supported type has that name.
    """
    if self._ocvTypesByName is None:
        self._ocvTypesByName = {x.name: x for x in self.ocvTypes}
    if name not in self._ocvTypesByName:
        # Render the supported names as a sorted list: str(dict.keys()) shows
        # up as "dict_keys([...])" on Python 3 and has no stable order.
        raise ValueError(
            "Unsupported image format, can not find matching OpenCvFormat for type = '%s'; "
            "currently supported formats = %s" % (name, sorted(self._ocvTypesByName)))
    return self._ocvTypesByName[name]
109+
110+
def ocvTypeByMode(self, mode):
    """
    Look up a supported OpenCV type descriptor by its mode value.

    The mode-indexed table is built from ``self.ocvTypes`` on first use and
    cached in ``self._ocvTypesByMode``.

    :param mode: mode value to look up (compared against each type's ``mode``).
    :return: the entry of ``self.ocvTypes`` whose ``mode`` equals ``mode``.
    :raises KeyError: if no supported type has that mode.
    """
    byMode = self._ocvTypesByMode
    if byMode is None:
        byMode = dict((ocvType.mode, ocvType) for ocvType in self.ocvTypes)
        self._ocvTypesByMode = byMode
    return byMode[mode]
79114

80115
@property
81116
def imageFields(self):
@@ -88,7 +123,7 @@ def imageFields(self):
88123
"""
89124

90125
if self._imageFields is None:
91-
ctx = SparkContext._active_spark_context
126+
ctx = SparkContext.getOrCreate()
92127
self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields())
93128
return self._imageFields
94129

@@ -101,7 +136,7 @@ def undefinedImageType(self):
101136
"""
102137

103138
if self._undefinedImageType is None:
104-
ctx = SparkContext._active_spark_context
139+
ctx = SparkContext.getOrCreate()
105140
self._undefinedImageType = \
106141
ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType()
107142
return self._undefinedImageType
@@ -126,15 +161,20 @@ def toNDArray(self, image):
126161
raise ValueError(
127162
"image argument should have attributes specified in "
128163
"ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields))
129-
130164
height = image.height
131165
width = image.width
132166
nChannels = image.nChannels
167+
ocvType = self.ocvTypeByMode(image.mode)
168+
if nChannels != ocvType.nChannels:
169+
raise ValueError(
170+
"Unexpected number of channels, image has %d channels but OcvType '%s' expects %d channels." %
171+
(nChannels, ocvType.name, ocvType.nChannels))
172+
itemSz = ocvType.nptype.itemsize
133173
return np.ndarray(
134174
shape=(height, width, nChannels),
135-
dtype=np.uint8,
175+
dtype=ocvType.nptype,
136176
buffer=image.data,
137-
strides=(width * nChannels, nChannels, 1))
177+
strides=(width * nChannels * itemSz, nChannels * itemSz, itemSz))
138178

139179
def toImage(self, array, origin=""):
140180
"""
@@ -152,29 +192,27 @@ def toImage(self, array, origin=""):
152192
"array argument should be numpy.ndarray; however, it got [%s]." % type(array))
153193

154194
if array.ndim != 3:
155-
raise ValueError("Invalid array shape")
195+
raise ValueError("Invalid array shape %s" % str(array.shape))
156196

157197
height, width, nChannels = array.shape
158-
ocvTypes = ImageSchema.ocvTypes
159-
if nChannels == 1:
160-
mode = ocvTypes["CV_8UC1"]
161-
elif nChannels == 3:
162-
mode = ocvTypes["CV_8UC3"]
163-
elif nChannels == 4:
164-
mode = ocvTypes["CV_8UC4"]
165-
else:
166-
raise ValueError("Invalid number of channels")
198+
dtype = array.dtype
199+
if not dtype in self._numpyToOcvMap:
200+
raise ValueError(
201+
"Unexpected/unsupported array data type '%s', currently only supported formats are %s" %
202+
(str(array.dtype), str(self._numpyToOcvMap.keys())))
203+
ocvName = "CV_%sC%d" % (self._numpyToOcvMap[dtype], nChannels)
204+
ocvType = self.ocvTypeByName(ocvName)
167205

168206
# Running `bytearray(numpy.array([1]))` fails in specific Python versions
169207
# with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3.
170208
# Here, it avoids it by converting it to bytes.
171-
data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
209+
data = bytearray(array.tobytes())
172210

173211
# Creating new Row with _create_row(), because Row(name = value, ... )
174212
# orders fields by name, which conflicts with expected schema order
175213
# when the new DataFrame is created by UDF
176214
return _create_row(self.imageFields,
177-
[origin, height, width, nChannels, mode, data])
215+
[origin, height, width, nChannels, ocvType.mode, data])
178216

179217
def readImages(self, path, recursive=False, numPartitions=-1,
180218
dropImageFailures=False, sampleRatio=1.0, seed=0):
@@ -203,7 +241,7 @@ def readImages(self, path, recursive=False, numPartitions=-1,
203241
.. versionadded:: 2.3.0
204242
"""
205243

206-
ctx = SparkContext._active_spark_context
244+
ctx = SparkContext.getOrCreate()
207245
spark = SparkSession(ctx)
208246
image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema
209247
jsession = spark._jsparkSession

python/sparkdl/image/imageIO.py

Lines changed: 30 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -21,83 +21,12 @@
2121
from PIL import Image
2222

2323
# pyspark
24-
from pyspark import Row
2524
from pyspark import SparkContext
26-
from sparkdl.image.image import ImageSchema
2725
from pyspark.sql.functions import udf
2826
from pyspark.sql.types import (
2927
BinaryType, IntegerType, StringType, StructField, StructType)
3028

31-
32-
# ImageType represents supported OpenCV types
33-
# fields:
34-
# name - OpenCvMode
35-
# ord - Ordinal of the corresponding OpenCV mode (stored in mode field of ImageSchema).
36-
# nChannels - number of channels in the image
37-
# dtype - data type of the image's array, sorted as a numpy compatible string.
38-
#
39-
# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
40-
_OcvType = namedtuple("OcvType", ["name", "ord", "nChannels", "dtype"])
41-
42-
43-
_supportedOcvTypes = (
44-
_OcvType(name="CV_8UC1", ord=0, nChannels=1, dtype="uint8"),
45-
_OcvType(name="CV_32FC1", ord=5, nChannels=1, dtype="float32"),
46-
_OcvType(name="CV_8UC3", ord=16, nChannels=3, dtype="uint8"),
47-
_OcvType(name="CV_32FC3", ord=21, nChannels=3, dtype="float32"),
48-
_OcvType(name="CV_8UC4", ord=24, nChannels=4, dtype="uint8"),
49-
_OcvType(name="CV_32FC4", ord=29, nChannels=4, dtype="float32"),
50-
)
51-
52-
# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
53-
_ocvTypesByName = {m.name: m for m in _supportedOcvTypes}
54-
_ocvTypesByOrdinal = {m.ord: m for m in _supportedOcvTypes}
55-
56-
57-
def imageTypeByOrdinal(ord):
58-
if not ord in _ocvTypesByOrdinal:
59-
raise KeyError("unsupported image type with ordinal %d, supported OpenCV types = %s" % (
60-
ord, str(_supportedOcvTypes)))
61-
return _ocvTypesByOrdinal[ord]
62-
63-
64-
def imageTypeByName(name):
65-
if not name in _ocvTypesByName:
66-
raise KeyError("unsupported image type with name '%s', supported supported OpenCV types = %s" % (
67-
name, str(_supportedOcvTypes)))
68-
return _ocvTypesByName[name]
69-
70-
71-
def imageArrayToStruct(imgArray, origin=""):
72-
"""
73-
Create a row representation of an image from an image array.
74-
75-
:param imgArray: ndarray, image data.
76-
:return: Row, image as a DataFrame Row with schema==ImageSchema.
77-
"""
78-
# Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists.
79-
if len(imgArray.shape) == 4:
80-
if imgArray.shape[0] != 1:
81-
raise ValueError(
82-
"The first dimension of a 4-d image array is expected to be 1.")
83-
imgArray = imgArray.reshape(imgArray.shape[1:])
84-
imageType = _arrayToOcvMode(imgArray)
85-
height, width, nChannels = imgArray.shape
86-
data = bytearray(imgArray.tobytes())
87-
return Row(origin=origin, mode=imageType.ord, height=height,
88-
width=width, nChannels=nChannels, data=data)
89-
90-
91-
def imageStructToArray(imageRow):
92-
"""
93-
Convert an image to a numpy array.
94-
95-
:param imageRow: Row, must use imageSchema.
96-
:return: ndarray, image data.
97-
"""
98-
imType = imageTypeByOrdinal(imageRow.mode)
99-
shape = (imageRow.height, imageRow.width, imageRow.nChannels)
100-
return np.ndarray(shape, imType.dtype, imageRow.data)
29+
from sparkdl.image.image import ImageSchema
10130

10231

10332
def imageStructToPIL(imageRow):
@@ -107,20 +36,20 @@ def imageStructToPIL(imageRow):
10736
:param imageRow: Row, must have ImageSchema
10837
:return PIL image
10938
"""
110-
imgType = imageTypeByOrdinal(imageRow.mode)
111-
if imgType.dtype != 'uint8':
39+
ary = ImageSchema.toNDArray(imageRow)
40+
if ary.dtype != np.uint8:
11241
raise ValueError("Can not convert image of type " +
113-
imgType.dtype + " to PIL, can only deal with 8U format")
114-
ary = imageStructToArray(imageRow)
42+
ary.dtype + " to PIL, can only deal with 8U format")
43+
11544
# PIL expects RGB order, image schema is BGR
11645
# => we need to flip the order unless there is only one channel
117-
if imgType.nChannels != 1:
46+
if imageRow.nChannels != 1:
11847
ary = _reverseChannels(ary)
119-
if imgType.nChannels == 1:
48+
if imageRow.nChannels == 1:
12049
return Image.fromarray(obj=ary, mode='L')
121-
elif imgType.nChannels == 3:
50+
elif imageRow.nChannels == 3:
12251
return Image.fromarray(obj=ary, mode='RGB')
123-
elif imgType.nChannels == 4:
52+
elif imageRow.nChannels == 4:
12453
return Image.fromarray(obj=ary, mode='RGBA')
12554
else:
12655
raise ValueError("don't know how to convert " +
@@ -132,19 +61,6 @@ def PIL_to_imageStruct(img):
13261
return _reverseChannels(np.asarray(img))
13362

13463

135-
def _arrayToOcvMode(arr):
136-
assert len(arr.shape) == 3, "Array should have 3 dimensions but has shape {}".format(
137-
arr.shape)
138-
num_channels = arr.shape[2]
139-
if arr.dtype == "uint8":
140-
name = "CV_8UC%d" % num_channels
141-
elif arr.dtype == "float32":
142-
name = "CV_32FC%d" % num_channels
143-
else:
144-
raise ValueError("Unsupported type '%s'" % arr.dtype)
145-
return imageTypeByName(name)
146-
147-
14864
def fixColorChannelOrdering(currentOrder, imgAry):
14965
if currentOrder == 'RGB':
15066
return _reverseChannels(imgAry)
@@ -160,6 +76,24 @@ def fixColorChannelOrdering(currentOrder, imgAry):
16076
"Unexpected channel order, expected one of L,RGB,BGR but got " + currentChannelOrder)
16177

16278

79+
def _stripBatchSize(imgArray):
    """
    Remove a leading batch axis of length 1 from a 4-d array.

    Arrays that are not 4-dimensional are returned unchanged. For a 4-d
    array the first axis is taken to be the batch size and must equal 1.

    :param imgArray: ndarray, image data.
    :return: imgArray without the leading batch axis.
    :raises ValueError: if imgArray is 4-d and its first dimension is not 1.
    """
    dims = imgArray.shape
    if len(dims) != 4:
        # No batch axis present: nothing to strip.
        return imgArray
    if dims[0] != 1:
        raise ValueError(
            "The first dimension of a 4-d image array is expected to be 1.")
    return imgArray.reshape(dims[1:])
95+
96+
16397
def _reverseChannels(ary):
    """
    Return ``ary`` with the order of its last axis reversed.

    Callers in this file use it to switch image arrays between RGB and BGR
    channel order.

    :param ary: array indexable with an ellipsis and a slice.
    :return: ``ary`` indexed so that the last axis runs backwards.
    """
    lastAxisBackwards = (Ellipsis, slice(None, None, -1))
    return ary[lastAxisBackwards]
16599

@@ -183,8 +117,8 @@ def _resizeImageAsRow(imgAsRow):
183117
return imgAsRow
184118
imgAsPil = imageStructToPIL(imgAsRow).resize(sz)
185119
# PIL is RGB based while image schema is BGR based => we need to flip the channels
186-
imgAsArray = _reverseChannels(np.asarray(imgAsPil))
187-
return imageArrayToStruct(imgAsArray, origin=imgAsRow.origin)
120+
imgAsArray = PIL_to_imageStruct(imgAsPil)
121+
return ImageSchema.toImage(imgAsArray, origin=imgAsRow.origin)
188122
return udf(_resizeImageAsRow, ImageSchema.imageSchema['image'].dataType)
189123

190124

@@ -242,7 +176,7 @@ def readImagesWithCustomFn(path, decode_f, numPartition=None):
242176
def _readImagesWithCustomFn(path, decode_f, numPartition, sc):
243177
def _decode(path, raw_bytes):
    """
    Decode one file's raw bytes into an image row, or None on failure.

    :param path: value stored as the ``origin`` of the resulting image row.
    :param raw_bytes: raw file content handed to ``decode_f``.
    :return: the result of ``ImageSchema.toImage`` on the decoded array, or
             None if decoding or conversion raised an error.
    """
    try:
        return ImageSchema.toImage(decode_f(raw_bytes), origin=path)
    except Exception:  # pylint: disable=broad-except
        # Best effort: an undecodable file yields None rather than an error.
        # Exception (not BaseException) so that KeyboardInterrupt and
        # SystemExit are not swallowed.
        return None
248182
decodeImage = udf(_decode, ImageSchema.imageSchema['image'].dataType)

python/sparkdl/param/image_params.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
from sparkdl.image.image import ImageSchema
2323
from pyspark.ml.param import Param, Params, TypeConverters
2424
from pyspark.sql.functions import udf
25-
from sparkdl.image.imageIO import imageArrayToStruct
26-
from sparkdl.image.imageIO import _reverseChannels
25+
from sparkdl.image.imageIO import _reverseChannels, _stripBatchSize
2726
from sparkdl.param import SparkDLTypeConverters
2827

2928
OUTPUT_MODES = ["vector", "image"]
@@ -95,10 +94,10 @@ def loadImagesInternal(self, dataframe, inputCol):
9594
# Load from external resources can fail, so we should allow None to be returned
9695

9796
def load_image_uri_impl(uri):
    """
    Load the image at ``uri`` and convert it to an image row.

    The array returned by ``loader`` (taken from the enclosing scope) has a
    leading batch axis of size 1 stripped if present, then its channel
    order reversed, before being converted with ``ImageSchema.toImage``.

    :param uri: value passed to ``loader`` to locate the image.
    :return: image row, or None if loading or conversion raised an error.
    """
    try:
        return ImageSchema.toImage(_reverseChannels(_stripBatchSize(loader(uri))))
    except Exception:  # pylint: disable=broad-except
        # Loading from external resources can fail; return None for that row
        # instead of raising out of the UDF.
        return None
102101
load_udf = udf(load_image_uri_impl, ImageSchema.imageSchema['image'].dataType)
103102
return dataframe.withColumn(self._loadedImageCol(), load_udf(dataframe[inputCol]))
104103

0 commit comments

Comments
 (0)