From 73e4c94ce1f8d5fade0ae96ea35ecfb3da876445 Mon Sep 17 00:00:00 2001 From: tomasatdatabricks Date: Thu, 28 Dec 2017 15:17:00 -0800 Subject: [PATCH 1/3] merged with master --- .../estimators/keras_image_file_estimator.py | 4 +- python/sparkdl/image/image.py | 84 ++++++++---- python/sparkdl/image/imageIO.py | 126 +++++------------- python/sparkdl/param/image_params.py | 11 +- python/sparkdl/transformers/tf_image.py | 7 +- python/sparkdl/udf/keras_image_model.py | 4 +- python/tests/graph/test_pieces.py | 10 +- python/tests/image/test_imageIO.py | 14 +- python/tests/transformers/keras_image_test.py | 4 +- python/tests/transformers/named_image_test.py | 9 +- python/tests/transformers/tf_image_test.py | 4 +- python/tests/udf/keras_sql_udf_test.py | 5 +- .../com/databricks/sparkdl/ImageUtils.scala | 6 +- .../apache/spark/ml/image/ImageSchema.scala | 52 ++++++-- .../sparkdl/DeepImageFeaturizerSuite.scala | 8 +- .../databricks/sparkdl/ImageUtilsSuite.scala | 4 +- 16 files changed, 175 insertions(+), 177 deletions(-) diff --git a/python/sparkdl/estimators/keras_image_file_estimator.py b/python/sparkdl/estimators/keras_image_file_estimator.py index 1d67ed6b..da303d96 100644 --- a/python/sparkdl/estimators/keras_image_file_estimator.py +++ b/python/sparkdl/estimators/keras_image_file_estimator.py @@ -24,11 +24,11 @@ import pyspark.ml.linalg as spla from pyspark.ml.param import Param, Params, TypeConverters -from sparkdl.image.imageIO import imageStructToArray from sparkdl.param import ( keyword_only, CanLoadImage, HasKerasModel, HasKerasOptimizer, HasKerasLoss, HasOutputMode, HasInputCol, HasInputImageNodeName, HasLabelCol, HasOutputNodeName, HasOutputCol) from sparkdl.transformers.keras_image import KerasImageFileTransformer +from sparkdl.image.image import ImageSchema import sparkdl.utils.jvmapi as JVMAPI import sparkdl.utils.keras_model as kmutil @@ -202,7 +202,7 @@ def _getNumpyFeaturesAndLabels(self, dataset): rows = image_df.collect() for row in rows: spimg = row[tmp_image_col] - features = imageStructToArray(spimg) + features = ImageSchema.toNDArray(spimg) localFeatures.append(features) if not localFeatures: # NOTE(phi-dbq): pep-8 recommended against testing 0 == len(array) diff --git a/python/sparkdl/image/image.py b/python/sparkdl/image/image.py index 2569abdb..0049ff6c 100644 --- a/python/sparkdl/image/image.py +++ b/python/sparkdl/image/image.py @@ -27,6 +27,8 @@ """ import numpy as np +from collections import namedtuple + from pyspark import SparkContext from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string from pyspark.sql import DataFrame, SparkSession @@ -42,9 +44,23 @@ class _ImageSchema(object): def __init__(self): self._imageSchema = None self._ocvTypes = None + self._ocvTypesByName = None + self._ocvTypesByMode = None self._imageFields = None self._undefinedImageType = None + _OcvType = namedtuple("OcvType", ["name", "mode", "nChannels", "dataType", "nptype"]) + + _ocvToNumpyMap = { + "8U": np.dtype("uint8"), + "8S": np.dtype("int8"), + "16U": np.dtype('uint16'), + "16S": np.dtype('int16'), + "32S": np.dtype('int32'), + "32F": np.dtype('float32'), + "64F": np.dtype('float64')} + _numpyToOcvMap = {x[1]: x[0] for x in _ocvToNumpyMap.items()} + @property def imageSchema(self): """ @@ -57,7 +73,7 @@ def imageSchema(self): """ if self._imageSchema is None: - ctx = SparkContext._active_spark_context + ctx = SparkContext.getOrCreate() jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema() self._imageSchema = 
_parse_datatype_json_string(jschema.json())
         return self._imageSchema
@@ -71,11 +87,31 @@ def ocvTypes(self):
 
         .. versionadded:: 2.3.0
         """
-        if self._ocvTypes is None:
-            ctx = SparkContext._active_spark_context
-            self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes())
-        return self._ocvTypes
+        if self._ocvTypes is None:
+            ctx = SparkContext.getOrCreate()
+            self._ocvTypes = [self._OcvType(name=x.name(),
+                                            mode=x.mode(),
+                                            nChannels=x.nChannels(),
+                                            dataType=x.dataType(),
+                                            nptype=self._ocvToNumpyMap[x.dataType()])
+                              for x in ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()]
+        return self._ocvTypes[:]
+
+    def ocvTypeByName(self, name):
+        if self._ocvTypesByName is None:
+            self._ocvTypesByName = {x.name: x for x in self.ocvTypes}
+        if name not in self._ocvTypesByName:
+            raise ValueError(
+                "Unsupported image type '%s'; supported OpenCV types = %s" %
+                (name, str(list(self._ocvTypesByName.keys()))))
+        return self._ocvTypesByName[name]
+
+    def ocvTypeByMode(self, mode):
+        if self._ocvTypesByMode is None:
+            self._ocvTypesByMode = {x.mode: x for x in self.ocvTypes}
+        return self._ocvTypesByMode[mode]
 
     @property
     def imageFields(self):
@@ -88,7 +123,7 @@
         """
         if self._imageFields is None:
-            ctx = SparkContext._active_spark_context
+            ctx = SparkContext.getOrCreate()
             self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields())
         return self._imageFields
 
@@ -101,7 +136,7 @@ def undefinedImageType(self):
         """
         if self._undefinedImageType is None:
-            ctx = SparkContext._active_spark_context
+            ctx = SparkContext.getOrCreate()
             self._undefinedImageType = \
                 ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType()
         return self._undefinedImageType
 
@@ -126,15 +161,20 @@ def toNDArray(self, image):
             raise ValueError(
                 "image argument should have attributes specified in "
                 "ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields))
-
         height = image.height
         width = image.width
         nChannels = image.nChannels
+        ocvType = self.ocvTypeByMode(image.mode)
+        if nChannels != ocvType.nChannels:
+            raise ValueError(
+                "Unexpected number of channels; image has %d channels but OcvType '%s' expects %d channels." %
+                (nChannels, ocvType.name, ocvType.nChannels))
+        itemSz = ocvType.nptype.itemsize
         return np.ndarray(
             shape=(height, width, nChannels),
-            dtype=np.uint8,
+            dtype=ocvType.nptype,
             buffer=image.data,
-            strides=(width * nChannels, nChannels, 1))
+            strides=(width * nChannels * itemSz, nChannels * itemSz, itemSz))
 
     def toImage(self, array, origin=""):
         """
@@ -152,29 +192,27 @@
                 "array argument should be numpy.ndarray; however, it got [%s]."
                % type(array))
         if array.ndim != 3:
-            raise ValueError("Invalid array shape")
+            raise ValueError("Invalid array shape %s" % str(array.shape))
 
         height, width, nChannels = array.shape
-        ocvTypes = ImageSchema.ocvTypes
-        if nChannels == 1:
-            mode = ocvTypes["CV_8UC1"]
-        elif nChannels == 3:
-            mode = ocvTypes["CV_8UC3"]
-        elif nChannels == 4:
-            mode = ocvTypes["CV_8UC4"]
-        else:
-            raise ValueError("Invalid number of channels")
+        dtype = array.dtype
+        if dtype not in self._numpyToOcvMap:
+            raise ValueError(
+                "Unsupported array data type '%s'; supported types are %s" %
+                (str(dtype), str(list(self._numpyToOcvMap.keys()))))
+        ocvName = "CV_%sC%d" % (self._numpyToOcvMap[dtype], nChannels)
+        ocvType = self.ocvTypeByName(ocvName)
 
         # Running `bytearray(numpy.array([1]))` fails in specific Python versions
         # with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3.
         # Here, it avoids it by converting it to bytes.
-        data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
+        data = bytearray(array.tobytes())
 
         # Creating new Row with _create_row(), because Row(name = value, ... )
         # orders fields by name, which conflicts with expected schema order
         # when the new DataFrame is created by UDF
         return _create_row(self.imageFields,
-                           [origin, height, width, nChannels, mode, data])
+                           [origin, height, width, nChannels, ocvType.mode, data])
 
     def readImages(self, path, recursive=False, numPartitions=-1,
                    dropImageFailures=False, sampleRatio=1.0, seed=0):
@@ -203,7 +241,7 @@
 
         .. versionadded:: 2.3.0
         """
-        ctx = SparkContext._active_spark_context
+        ctx = SparkContext.getOrCreate()
         spark = SparkSession(ctx)
         image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema
         jsession = spark._jsparkSession
diff --git a/python/sparkdl/image/imageIO.py b/python/sparkdl/image/imageIO.py
index a2a60145..d56f9da5 100644
--- a/python/sparkdl/image/imageIO.py
+++ b/python/sparkdl/image/imageIO.py
@@ -21,83 +21,12 @@
 from PIL import Image
 
 # pyspark
-from pyspark import Row
 from pyspark import SparkContext
-from sparkdl.image.image import ImageSchema
 from pyspark.sql.functions import udf
 from pyspark.sql.types import (
     BinaryType, IntegerType, StringType, StructField, StructType)
-
-# ImageType represents supported OpenCV types
-# fields:
-#   name - OpenCvMode
-#   ord - Ordinal of the corresponding OpenCV mode (stored in mode field of ImageSchema).
-#   nChannels - number of channels in the image
-#   dtype - data type of the image's array, sorted as a numpy compatible string.
-#
-# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
-_OcvType = namedtuple("OcvType", ["name", "ord", "nChannels", "dtype"])
-
-
-_supportedOcvTypes = (
-    _OcvType(name="CV_8UC1", ord=0, nChannels=1, dtype="uint8"),
-    _OcvType(name="CV_32FC1", ord=5, nChannels=1, dtype="float32"),
-    _OcvType(name="CV_8UC3", ord=16, nChannels=3, dtype="uint8"),
-    _OcvType(name="CV_32FC3", ord=21, nChannels=3, dtype="float32"),
-    _OcvType(name="CV_8UC4", ord=24, nChannels=4, dtype="uint8"),
-    _OcvType(name="CV_32FC4", ord=29, nChannels=4, dtype="float32"),
-)
-
-# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
-_ocvTypesByName = {m.name: m for m in _supportedOcvTypes}
-_ocvTypesByOrdinal = {m.ord: m for m in _supportedOcvTypes}
-
-
-def imageTypeByOrdinal(ord):
-    if not ord in _ocvTypesByOrdinal:
-        raise KeyError("unsupported image type with ordinal %d, supported OpenCV types = %s" % (
-            ord, str(_supportedOcvTypes)))
-    return _ocvTypesByOrdinal[ord]
-
-
-def imageTypeByName(name):
-    if not name in _ocvTypesByName:
-        raise KeyError("unsupported image type with name '%s', supported supported OpenCV types = %s" % (
-            name, str(_supportedOcvTypes)))
-    return _ocvTypesByName[name]
-
-
-def imageArrayToStruct(imgArray, origin=""):
-    """
-    Create a row representation of an image from an image array.
-
-    :param imgArray: ndarray, image data.
-    :return: Row, image as a DataFrame Row with schema==ImageSchema.
-    """
-    # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists.
-    if len(imgArray.shape) == 4:
-        if imgArray.shape[0] != 1:
-            raise ValueError(
-                "The first dimension of a 4-d image array is expected to be 1.")
-        imgArray = imgArray.reshape(imgArray.shape[1:])
-    imageType = _arrayToOcvMode(imgArray)
-    height, width, nChannels = imgArray.shape
-    data = bytearray(imgArray.tobytes())
-    return Row(origin=origin, mode=imageType.ord, height=height,
-               width=width, nChannels=nChannels, data=data)
-
-
-def imageStructToArray(imageRow):
-    """
-    Convert an image to a numpy array.
-
-    :param imageRow: Row, must use imageSchema.
-    :return: ndarray, image data.
-    """
-    imType = imageTypeByOrdinal(imageRow.mode)
-    shape = (imageRow.height, imageRow.width, imageRow.nChannels)
-    return np.ndarray(shape, imType.dtype, imageRow.data)
+from sparkdl.image.image import ImageSchema
 
 
 def imageStructToPIL(imageRow):
@@ -107,20 +36,20 @@ def imageStructToPIL(imageRow):
     :param imageRow: Row, must have ImageSchema
     :return PIL image
     """
-    imgType = imageTypeByOrdinal(imageRow.mode)
-    if imgType.dtype != 'uint8':
+    ary = ImageSchema.toNDArray(imageRow)
+    if ary.dtype != np.uint8:
         raise ValueError("Can not convert image of type " +
-                         imgType.dtype + " to PIL, can only deal with 8U format")
-    ary = imageStructToArray(imageRow)
+                         str(ary.dtype) + " to PIL, can only deal with 8U format")
+
     # PIL expects RGB order, image schema is BGR
     # => we need to flip the order unless there is only one channel
-    if imgType.nChannels != 1:
+    if imageRow.nChannels != 1:
         ary = _reverseChannels(ary)
-    if imgType.nChannels == 1:
+    if imageRow.nChannels == 1:
         return Image.fromarray(obj=ary, mode='L')
-    elif imgType.nChannels == 3:
+    elif imageRow.nChannels == 3:
         return Image.fromarray(obj=ary, mode='RGB')
-    elif imgType.nChannels == 4:
+    elif imageRow.nChannels == 4:
         return Image.fromarray(obj=ary, mode='RGBA')
     else:
         raise ValueError("don't know how to convert " +
@@ -132,19 +61,6 @@ def PIL_to_imageStruct(img):
     return _reverseChannels(np.asarray(img))
 
 
-def _arrayToOcvMode(arr):
-    assert len(arr.shape) == 3, "Array should have 3 dimensions but has shape {}".format(
-        arr.shape)
-    num_channels = arr.shape[2]
-    if arr.dtype == "uint8":
-        name = "CV_8UC%d" % num_channels
-    elif arr.dtype == "float32":
-        name = "CV_32FC%d" % num_channels
-    else:
-        raise ValueError("Unsupported type '%s'" % arr.dtype)
-    return imageTypeByName(name)
-
-
 def fixColorChannelOrdering(currentOrder, imgAry):
     if currentOrder == 'RGB':
         return _reverseChannels(imgAry)
@@ -160,6 +76,24 @@ def fixColorChannelOrdering(currentOrder, imgAry):
                 "Unexpected channel order, expected one of L,RGB,BGR but got "
                 + currentChannelOrder)
 
+
+def _stripBatchSize(imgArray):
""" + Strip batch size (if it's there) from a multi dimensional array. + Assumes batch size is the first coordinate and is equal to 1. + Batch size != 1 will cause an error. + + :param imgArray: ndarray, image data. + :return: imgArray without the leading batch size + """ + # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists. + if len(imgArray.shape) == 4: + if imgArray.shape[0] != 1: + raise ValueError( + "The first dimension of a 4-d image array is expected to be 1.") + imgArray = imgArray.reshape(imgArray.shape[1:]) + return imgArray + + def _reverseChannels(ary): return ary[..., ::-1] @@ -183,8 +117,8 @@ def _resizeImageAsRow(imgAsRow): return imgAsRow imgAsPil = imageStructToPIL(imgAsRow).resize(sz) # PIL is RGB based while image schema is BGR based => we need to flip the channels - imgAsArray = _reverseChannels(np.asarray(imgAsPil)) - return imageArrayToStruct(imgAsArray, origin=imgAsRow.origin) + imgAsArray = PIL_to_imageStruct(imgAsPil) + return ImageSchema.toImage(imgAsArray, origin=imgAsRow.origin) return udf(_resizeImageAsRow, ImageSchema.imageSchema['image'].dataType) @@ -242,7 +176,7 @@ def readImagesWithCustomFn(path, decode_f, numPartition=None): def _readImagesWithCustomFn(path, decode_f, numPartition, sc): def _decode(path, raw_bytes): try: - return imageArrayToStruct(decode_f(raw_bytes), origin=path) + return ImageSchema.toImage(decode_f(raw_bytes), origin=path) except BaseException: return None decodeImage = udf(_decode, ImageSchema.imageSchema['image'].dataType) diff --git a/python/sparkdl/param/image_params.py b/python/sparkdl/param/image_params.py index 0e7dcdbc..65824df1 100644 --- a/python/sparkdl/param/image_params.py +++ b/python/sparkdl/param/image_params.py @@ -22,8 +22,7 @@ from sparkdl.image.image import ImageSchema from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql.functions import udf -from sparkdl.image.imageIO import imageArrayToStruct -from sparkdl.image.imageIO import _reverseChannels +from sparkdl.image.imageIO import _reverseChannels, _stripBatchSize from sparkdl.param import SparkDLTypeConverters OUTPUT_MODES = ["vector", "image"] @@ -95,10 +94,10 @@ def loadImagesInternal(self, dataframe, inputCol): # Load from external resources can fail, so we should allow None to be returned def load_image_uri_impl(uri): - try: - return imageArrayToStruct(_reverseChannels(loader(uri))) - except BaseException: # pylint: disable=bare-except - return None + # try: + return ImageSchema.toImage(_reverseChannels(_stripBatchSize(loader(uri)))) + # except BaseException: # pylint: disable=bare-except + # return None load_udf = udf(load_image_uri_impl, ImageSchema.imageSchema['image'].dataType) return dataframe.withColumn(self._loadedImageCol(), load_udf(dataframe[inputCol])) diff --git a/python/sparkdl/transformers/tf_image.py b/python/sparkdl/transformers/tf_image.py index 3796288e..520e472b 100644 --- a/python/sparkdl/transformers/tf_image.py +++ b/python/sparkdl/transformers/tf_image.py @@ -162,8 +162,7 @@ def _getImageDtype(self, dataset): # Assumes that the dtype for all images is the same in the given dataframe. 
pdf = dataset.select(self.getInputCol()).take(1) img = pdf[0][self.getInputCol()] - img_type = imageIO.imageTypeByOrdinal(img.mode) - return img_type.dtype + return ImageSchema.ocvTypeByMode(img.mode).nptype # TODO: duplicate code, same functionality as sparkdl.graph.pieces.py::builSpImageConverter # TODO: It should be extracted as a util function and shared @@ -229,10 +228,10 @@ def _convertOutputToImage(self, df, tfs_output_col, output_shape): def to_image(orig_image, numeric_data): # Assume the returned image has float pixels but same #channels as input - mode = imageIO.imageTypeByName('CV_32FC%d' % orig_image.nChannels) + ocvType = ImageSchema.ocvTypeByName('CV_32FC%d' % orig_image.nChannels) data = bytearray(np.array(numeric_data).astype(np.float32).tobytes()) nChannels = orig_image.nChannels - return Row(origin="", mode=mode.ord, height=height, + return Row(origin="", mode=ocvType.mode, height=height, width=width, nChannels=nChannels, data=data) to_image_udf = udf(to_image, ImageSchema.imageSchema['image'].dataType) resDf = df.withColumn(self.getOutputCol(), to_image_udf( diff --git a/python/sparkdl/udf/keras_image_model.py b/python/sparkdl/udf/keras_image_model.py index f5ba9db0..8ba6602b 100644 --- a/python/sparkdl/udf/keras_image_model.py +++ b/python/sparkdl/udf/keras_image_model.py @@ -137,8 +137,6 @@ def _serialize_and_reload_with(preprocessor): def udf_impl(spimg): import numpy as np from tempfile import NamedTemporaryFile - from sparkdl.image.imageIO import imageArrayToStruct - img = imageIO.imageStructToPIL(spimg) # Warning: must use lossless format to guarantee consistency temp_fp = NamedTemporaryFile(suffix='.png') @@ -150,6 +148,6 @@ def udf_impl(spimg): # Keras works in RGB order, need to fix the order img_arr_reloaded = imageIO.fixColorChannelOrdering( currentOrder='RGB', imgAry=img_arr_reloaded) - return imageArrayToStruct(img_arr_reloaded) + return ImageSchema.toImage(img_arr_reloaded) return udf_impl diff --git a/python/tests/graph/test_pieces.py b/python/tests/graph/test_pieces.py index 1395bf41..5ac1bed0 100644 --- a/python/tests/graph/test_pieces.py +++ b/python/tests/graph/test_pieces.py @@ -40,8 +40,8 @@ from sparkdl.graph.builder import IsolatedSession, GraphFunction import sparkdl.graph.pieces as gfac import sparkdl.graph.utils as tfx -from sparkdl.image.imageIO import imageArrayToStruct -from sparkdl.image.imageIO import imageTypeByOrdinal +from sparkdl.image.image import ImageSchema +from sparkdl.image import imageIO from ..tests import SparkDLTestCase @@ -64,10 +64,10 @@ def exec_gfn_spimg_decode(spimg_dict, img_dtype): return img_out def check_image_round_trip(img_arr): - spimg_dict = imageArrayToStruct(img_arr).asDict() + spimg_dict = ImageSchema.toImage(img_arr).asDict() spimg_dict['data'] = bytes(spimg_dict['data']) img_arr_out = exec_gfn_spimg_decode( - spimg_dict, imageTypeByOrdinal(spimg_dict['mode']).dtype) + spimg_dict, ImageSchema.ocvTypeByMode(spimg_dict['mode']).nptype) self.assertTrue(np.all(img_arr_out == img_arr)) for fp in img_fpaths: @@ -159,7 +159,7 @@ def test_pipeline(self): img_input = xcpt.preprocess_input(img_arr) preds_ref = xcpt_model.predict(img_input) - spimg_input_dict = imageArrayToStruct(img_input).asDict() + spimg_input_dict = ImageSchema.toImage(imageIO._stripBatchSize(img_input)).asDict() spimg_input_dict['data'] = bytes(spimg_input_dict['data']) with IsolatedSession() as issn: # Need blank import scope name so that spimg fields match the input names diff --git a/python/tests/image/test_imageIO.py 
b/python/tests/image/test_imageIO.py index 2c172603..5ec7cc78 100644 --- a/python/tests/image/test_imageIO.py +++ b/python/tests/image/test_imageIO.py @@ -81,7 +81,7 @@ def test_resize(self): self.assertRaises(ValueError, imageIO.createResizeImageUDF, [1, 2, 3]) make_smaller = imageIO.createResizeImageUDF([4, 5]).func - imgAsRow = imageIO.imageArrayToStruct(array) + imgAsRow = ImageSchema.toImage(array) smallerImg = make_smaller(imgAsRow) self.assertEqual(smallerImg.height, 4) self.assertEqual(smallerImg.width, 5) @@ -89,7 +89,7 @@ def test_resize(self): # Compare to PIL resizing imgAsPIL = PIL.Image.fromarray(obj=imageIO._reverseChannels(array)).resize((5, 4)) smallerAry = imageIO._reverseChannels(np.asarray(imgAsPIL)) - np.testing.assert_array_equal(smallerAry, imageIO.imageStructToArray(smallerImg)) + np.testing.assert_array_equal(smallerAry, ImageSchema.toNDArray(smallerImg)) # Test that resize with the same size is a no-op sameImage = imageIO.createResizeImageUDF((imgAsRow.height, imgAsRow.width)).func(imgAsRow) self.assertEqual(imgAsRow, sameImage) @@ -103,11 +103,11 @@ def test_imageConversions(self): """ def _test(array): height, width, chan = array.shape - imgAsStruct = imageIO.imageArrayToStruct(array) + imgAsStruct = ImageSchema.toImage(array) self.assertEqual(imgAsStruct.height, height) self.assertEqual(imgAsStruct.width, width) self.assertEqual(imgAsStruct.data, array.tobytes()) - imgReconstructed = imageIO.imageStructToArray(imgAsStruct) + imgReconstructed = ImageSchema.toNDArray(imgAsStruct) np.testing.assert_array_equal(array, imgReconstructed) for nChannels in (1, 3, 4): # unsigned bytes @@ -129,7 +129,7 @@ def test_readImages(self): img = validImages.first().image self.assertEqual(img.height, array.shape[0]) self.assertEqual(img.width, array.shape[1]) - self.assertEqual(imageIO.imageTypeByOrdinal(img.mode).nChannels, array.shape[2]) + self.assertEqual(ImageSchema.ocvTypeByMode(img.mode).nChannels, array.shape[2]) # array comes out of PIL and is in RGB order self.assertEqual(img.data, array.tobytes()) @@ -137,8 +137,8 @@ def test_udf_schema(self): # Test that utility functions can be used to create a udf that accepts and return # imageSchema def do_nothing(imgRow): - array = imageIO.imageStructToArray(imgRow) - return imageIO.imageArrayToStruct(array) + array = ImageSchema.toNDArray(imgRow) + return ImageSchema.toImage(array) do_nothing_udf = udf(do_nothing, ImageSchema.imageSchema['image'].dataType) df = imageIO._readImagesWithCustomFn( diff --git a/python/tests/transformers/keras_image_test.py b/python/tests/transformers/keras_image_test.py index d215d991..8d8d7f2f 100644 --- a/python/tests/transformers/keras_image_test.py +++ b/python/tests/transformers/keras_image_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
# -from sparkdl.image.imageIO import imageStructToArray +from sparkdl.image.image import ImageSchema from sparkdl.transformers.keras_image import KerasImageFileTransformer from sparkdl.transformers.utils import InceptionV3Constants from ..tests import SparkDLTestCase @@ -39,7 +39,7 @@ def test_loadImages(self): img_col = transformer._loadedImageCol() expected_shape = InceptionV3Constants.INPUT_SHAPE + (3,) for row in image_df.collect(): - arr = imageStructToArray(row[img_col]) + arr = ImageSchema.toNDArray(row[img_col]) self.assertEqual(arr.shape, expected_shape) diff --git a/python/tests/transformers/named_image_test.py b/python/tests/transformers/named_image_test.py index 110be520..4ace8934 100644 --- a/python/tests/transformers/named_image_test.py +++ b/python/tests/transformers/named_image_test.py @@ -126,6 +126,12 @@ def test_buildtfgraphforname(self): self.assertEqual(kerasPredict.shape, tfPredict.shape) np.testing.assert_array_almost_equal(kerasPredict, tfPredict) + def _rowWithImage(self, img): + row = ImageSchema.toImage(img.astype('uint8')) + # re-order row to avoid pyspark bug + return [[getattr(row, field.name) + for field in ImageSchema.imageSchema['image'].dataType]] + def test_DeepImagePredictorNoReshape(self): """ Run sparkDL predictor on manually-resized images and compare result to the @@ -135,8 +141,7 @@ def test_DeepImagePredictorNoReshape(self): kerasPredict = self.kerasPredict def rowWithImage(img): - # return [imageIO.imageArrayToStruct(img.astype('uint8'), imageType.sparkMode)] - row = imageIO.imageArrayToStruct(img.astype('uint8')) + row = ImageSchema.toImage(img.astype('uint8')) # re-order row to avoid pyspark bug return [[getattr(row, field.name) for field in ImageSchema.imageSchema['image'].dataType]] diff --git a/python/tests/transformers/tf_image_test.py b/python/tests/transformers/tf_image_test.py index f1a346f2..74ea46f1 100644 --- a/python/tests/transformers/tf_image_test.py +++ b/python/tests/transformers/tf_image_test.py @@ -21,7 +21,7 @@ import tensorflow as tf import sparkdl.graph.utils as tfx -from sparkdl.image.imageIO import imageStructToArray +from sparkdl.image.image import ImageSchema from sparkdl.image import imageIO from sparkdl.transformers.keras_utils import KSessionWrap from sparkdl.transformers.tf_image import TFImageTransformer @@ -134,7 +134,7 @@ def _executeTensorflow(self, graph, input_tensor_name, output_tensor_name, values = {} topK = {} for img_row in image_collected: - image = np.expand_dims(imageStructToArray(img_row[input_col]), axis=0) + image = np.expand_dims(ImageSchema.toNDArray(img_row[input_col]), axis=0) uri = img_row['image']['origin'] output = sess.run([output_tensor], feed_dict={ diff --git a/python/tests/udf/keras_sql_udf_test.py b/python/tests/udf/keras_sql_udf_test.py index 9743e940..462be9fc 100644 --- a/python/tests/udf/keras_sql_udf_test.py +++ b/python/tests/udf/keras_sql_udf_test.py @@ -34,7 +34,6 @@ import sparkdl.graph.utils as tfx from sparkdl.udf.keras_image_model import registerKerasImageUDF from sparkdl.utils import jvmapi as JVMAPI -from sparkdl.image.imageIO import imageArrayToStruct from sparkdl.image.imageIO import _reverseChannels from ..tests import SparkDLTestCase from ..transformers.image_utils import getSampleImagePathsDF @@ -103,11 +102,11 @@ def pil_load_spimg(fpath): import numpy as np img_arr = np.array(Image.open(fpath), dtype=np.uint8) # PIL is RGB, image schema is BGR => need to flip the channels - return imageArrayToStruct(_reverseChannels(img_arr)) + return 
ImageSchema.toImage(_reverseChannels(img_arr))
 
        def keras_load_spimg(fpath):
            # Keras loads image in RGB order, ImageSchema expects BGR => need to flip
-            return imageArrayToStruct(_reverseChannels(keras_load_img(fpath)))
+            return ImageSchema.toImage(_reverseChannels(keras_load_img(fpath)))
 
        # Load image with Keras and store it in our image schema
        JVMAPI.registerUDF('keras_load_spimg', keras_load_spimg,
diff --git a/src/main/scala/com/databricks/sparkdl/ImageUtils.scala b/src/main/scala/com/databricks/sparkdl/ImageUtils.scala
index 2e660ef6..194a0df6 100644
--- a/src/main/scala/com/databricks/sparkdl/ImageUtils.scala
+++ b/src/main/scala/com/databricks/sparkdl/ImageUtils.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions.udf
 
-private[sparkdl] object ImageUtils {
+object ImageUtils {
 
   /**
    * Takes a Row image (spImage) and returns a Java BufferedImage. Currently supports 1 & 3
@@ -100,7 +100,7 @@ private[sparkdl] object ImageUtils {
       }
       h += 1
     }
-    Row(origin, height, width, channels, ImageSchema.ocvTypes("CV_8UC3"), decoded)
+    Row(origin, height, width, channels, ImageSchema.OpenCvType.get("CV_8UC3").mode, decoded)
   }
 
   /**
@@ -115,7 +115,7 @@
    * @param spImage image to resize.
    * @return resized image, if the input was BGR or 1 channel, the output will be BGR.
    */
-  private[sparkdl] def resizeImage(
+  def resizeImage(
       tgtHeight: Int,
       tgtWidth: Int,
       tgtChannels: Int,
diff --git a/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
index 9ebce2ad..3ff04852 100644
--- a/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
+++ b/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
@@ -42,18 +42,44 @@ object ImageSchema {
 
   val undefinedImageType = "Undefined"
 
-  /**
-   * (Scala-specific) OpenCV type mapping supported
-   */
-  val ocvTypes: Map[String, Int] = Map(
-    undefinedImageType -> -1,
-    "CV_8U" -> 0, "CV_8UC1" -> 0, "CV_8UC3" -> 16, "CV_8UC4" -> 24
-  )
+
+  case class OpenCvType(mode: Int, dataType: String, nChannels: Int) {
+    def name: String = "CV_" + dataType + "C" + nChannels
+    override def toString: String = "OpenCvType(mode = " + mode + ", name = " + name + ")"
+  }
+
+  object OpenCvType {
+    def get(name: String): OpenCvType = {
+      ocvTypes.find(x => x.name == name).getOrElse(throw new IllegalArgumentException("Unknown OpenCV type " + name))
+    }
+    def get(mode: Int): OpenCvType = {
+      ocvTypes.find(x => x.mode == mode).getOrElse(throw new IllegalArgumentException("Unknown OpenCV mode " + mode))
+    }
+    val undefinedType = OpenCvType(-1, "N/A", -1)
+  }
 
   /**
-   * (Java-specific) OpenCV type mapping supported
+   * Mapping of OpenCV type names to their numeric modes:
+   *
+   *          C1  C2  C3  C4
+   * CV_8U     0   8  16  24
+   * CV_8S     1   9  17  25
+   * CV_16U    2  10  18  26
+   * CV_16S    3  11  19  27
+   * CV_32S    4  12  20  28
+   * CV_32F    5  13  21  29
+   * CV_64F    6  14  22  30
    */
-  val javaOcvTypes: java.util.Map[String, Int] = ocvTypes.asJava
+  val ocvTypes = {
+    val types =
+      for (nc <- Array(1, 2, 3, 4);
+           dt <- Array("8U", "8S", "16U", "16S", "32S", "32F", "64F"))
+        yield (dt, nc)
+    val ordinals = for (i <- 0 to 3; j <- 0 to 6) yield (i * 8 + j)
+    (ordinals zip types).map(x => OpenCvType(x._1, x._2._1, x._2._2))
+  }
+
+  val javaOcvTypes = ocvTypes.asJava
 
   /**
    * Schema for the image column: Row(String, Int, Int, Int, Int, Array[Byte])
@@ -124,7 +150,7 @@
    * @return Row with the default values
    */
   private[spark] def invalidImageRow(origin: String): Row =
-    Row(Row(origin, -1, -1, -1, ocvTypes(undefinedImageType), Array.ofDim[Byte](0)))
+    Row(Row(origin, -1, -1, -1, OpenCvType.undefinedType.mode, Array.ofDim[Byte](0)))
 
   /**
    * Convert the compressed image (jpeg, png, etc.) into OpenCV
@@ -147,11 +173,11 @@
     val height = img.getHeight
     val width = img.getWidth
     val (nChannels, mode) = if (isGray) {
-      (1, ocvTypes("CV_8UC1"))
+      (1, OpenCvType.get("CV_8UC1").mode)
     } else if (hasAlpha) {
-      (4, ocvTypes("CV_8UC4"))
+      (4, OpenCvType.get("CV_8UC4").mode)
     } else {
-      (3, ocvTypes("CV_8UC3"))
+      (3, OpenCvType.get("CV_8UC3").mode)
     }
 
     val imageSize = height * width * nChannels
diff --git a/src/test/scala/com/databricks/sparkdl/DeepImageFeaturizerSuite.scala b/src/test/scala/com/databricks/sparkdl/DeepImageFeaturizerSuite.scala
index 6182723b..5b3a25f8 100644
--- a/src/test/scala/com/databricks/sparkdl/DeepImageFeaturizerSuite.scala
+++ b/src/test/scala/com/databricks/sparkdl/DeepImageFeaturizerSuite.scala
@@ -17,10 +17,9 @@ package com.databricks.sparkdl
 
 import org.scalatest.FunSuite
-
 import org.apache.spark.ml.image.ImageSchema
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types.{StructField, StructType}
@@ -48,8 +47,8 @@
     assert(featurizer.transformSchema(myData.schema) === transformed.schema)
 
     // check that we can materialize a row, and the type is Vector.
-    val result = transformed.select(col(outputColName)).collect()
-    assert(result.forall { r: Row => r.getAs[Vector](0).size == 24 })
+    val result = transformed.select(col(outputColName)).collect().map((r: Row) => r.getAs[Vector](0))
+    assert(result.forall { v: Vector => v.size == 24 })
 
   test ("Test schema validation.") {
diff --git a/src/test/scala/com/databricks/sparkdl/ImageUtilsSuite.scala b/src/test/scala/com/databricks/sparkdl/ImageUtilsSuite.scala
index 8ec69522..aef9cb1a 100644
--- a/src/test/scala/com/databricks/sparkdl/ImageUtilsSuite.scala
+++ b/src/test/scala/com/databricks/sparkdl/ImageUtilsSuite.scala
@@ -66,7 +66,7 @@ class ImageUtilsSuite extends FunSuite {
     val rand = new Random(971)
     val imageData = Array.ofDim[Byte](height * width * channels)
     rand.nextBytes(imageData)
-    val spImage = Row(null, height, width, channels, ImageSchema.ocvTypes("CV_8UC3"), imageData)
+    val spImage = Row(null, height, width, channels, ImageSchema.OpenCvType.get("CV_8UC3").mode, imageData)
     val bufferedImage = ImageUtils.spImageToBufferedImage(spImage)
     val testImage = ImageUtils.spImageFromBufferedImage(bufferedImage)
     assert(spImage === testImage, "Image changed during conversion.")
@@ -84,7 +84,7 @@
       (0 until width).flatMap { j =>
         Seq(x + j + 1, x + j + 4, x + j + 7)
       }
     }.map(_.toByte).toArray
-    val spImage = Row(null, height, width, 3, ImageSchema.ocvTypes("CV_8UC3"), rawData)
+    val spImage = Row(null, height, width, 3, ImageSchema.OpenCvType.get("CV_8UC3").mode, rawData)
     val bufferedImage = ImageUtils.spImageToBufferedImage(spImage)
 
     for (h <- 0 until height) {

From c8c90e013731e5fbf03277895ec800f37a33d939 Mon Sep 17 00:00:00 2001
From: tomasatdatabricks
Date: Fri, 29 Dec 2017 13:52:31 -0800
Subject: [PATCH 2/3] Added undefined ocv type to the list of types

---
 python/sparkdl/image/image.py                              | 1 +
 src/main/scala/org/apache/spark/ml/image/ImageSchema.scala | 7 ++-----
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/python/sparkdl/image/image.py b/python/sparkdl/image/image.py
index 0049ff6c..83f9cf1f 100644
--- a/python/sparkdl/image/image.py
+++ b/python/sparkdl/image/image.py
@@ -52,6 +52,7 @@ def __init__(self):
     _OcvType = namedtuple("OcvType", ["name", "mode", "nChannels", "dataType", "nptype"])
 
     _ocvToNumpyMap = {
+        "N/A": "N/A",
         "8U": np.dtype("uint8"),
         "8S": np.dtype("int8"),
         "16U": np.dtype('uint16'),
diff --git a/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
index 3ff04852..00915d77 100644
--- a/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
+++ b/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
@@ -40,11 +40,8 @@ import org.apache.spark.sql.types._
 @Since("2.3.0")
 object ImageSchema {
 
-  val undefinedImageType = "Undefined"
-
-
   case class OpenCvType(mode: Int, dataType: String, nChannels: Int) {
-    def name: String = "CV_" + dataType + "C" + nChannels
+    def name: String = if (mode == -1) "Undefined" else "CV_" + dataType + "C" + nChannels
     override def toString: String = "OpenCvType(mode = " + mode + ", name = " + name + ")"
   }
 
@@ -76,7 +73,7 @@ object ImageSchema {
            dt <- Array("8U", "8S", "16U", "16S", "32S", "32F", "64F"))
         yield (dt, nc)
     val ordinals = for (i <- 0 to 3; j <- 0 to 6) yield (i * 8 + j)
-    (ordinals zip types).map(x => OpenCvType(x._1, x._2._1, x._2._2))
+    OpenCvType.undefinedType +: (ordinals zip types).map(x => OpenCvType(x._1, x._2._1, x._2._2))
   }
 
   val javaOcvTypes = ocvTypes.asJava

From ce17c433861eccf5bbaeb2af8b14b1e672c282b6 Mon Sep 17 00:00:00 2001
From: tomasatdatabricks
Date: Wed, 3 Jan 2018 11:12:39 -0800
Subject: [PATCH 3/3] added conversion test for all ocv types

---
 python/tests/image/test_imageIO.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/python/tests/image/test_imageIO.py b/python/tests/image/test_imageIO.py
index 5ec7cc78..d7b257a8 100644
--- a/python/tests/image/test_imageIO.py
+++ b/python/tests/image/test_imageIO.py
@@ -18,6 +18,7 @@
 # 3rd party
 import numpy as np
 import PIL.Image
+import random
 
 # pyspark
 from pyspark.sql.functions import col, udf
@@ -156,5 +157,25 @@ def test_filesTODF(self):
         self.assertTrue(hasattr(first, "filePath"))
         self.assertEqual(type(first.fileData), bytearray)
 
+    def test_ocv_types(self):
+        ocvList = ImageSchema.ocvTypes
+        self.assertEqual("Undefined", ocvList[0].name)
+        self.assertEqual(-1, ocvList[0].mode)
+        self.assertEqual("N/A", ocvList[0].dataType)
+        for x in ocvList:
+            self.assertEqual(x, ImageSchema.ocvTypeByName(x.name))
+            self.assertEqual(x, ImageSchema.ocvTypeByMode(x.mode))
+
+    def test_conversions(self):
+        ary_src = [[[1e7 * random.random() for z in range(4)] for y in range(10)] for x in range(20)]
+        for ocvType in ImageSchema.ocvTypes:
+            if ocvType.name == 'Undefined':
+                continue
+            x = [[ary_src[i][j][0:ocvType.nChannels] for j in range(len(ary_src[0]))] for i in range(len(ary_src))]
+            npary0 = np.array(x).astype(ocvType.nptype)
+            img = ImageSchema.toImage(npary0)
+            self.assertEqual(ocvType, ImageSchema.ocvTypeByMode(img.mode))
+            npary1 = ImageSchema.toNDArray(img)
+            np.testing.assert_array_equal(npary0, npary1)
 
 # TODO: make unit tests for arrayToImageRow on arrays of varying shapes, channels, dtypes.
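
Note on the reworked API (illustration only, not part of the patches): the series replaces imageIO.imageArrayToStruct/imageStructToArray with ImageSchema.toImage/ImageSchema.toNDArray and adds the ocvTypeByName/ocvTypeByMode lookups. A minimal round-trip sketch, assuming the patched sparkdl is importable and a SparkSession is active (the ImageSchema singleton calls into the JVM); the origin URI here is made up:

    import numpy as np
    from sparkdl.image.image import ImageSchema

    # Round-trip a float32, 3-channel (BGR) array through the image schema.
    arr = np.random.rand(32, 64, 3).astype(np.float32)
    row = ImageSchema.toImage(arr, origin="memory://example")  # hypothetical origin

    # The mode stored on the row maps back to the OcvType that produced it.
    ocv = ImageSchema.ocvTypeByMode(row.mode)
    assert ocv.name == "CV_32FC3" and ocv.nptype == np.dtype("float32")

    # toNDArray reconstructs the array using the OcvType's dtype and item size,
    # so non-uint8 images now survive the round trip byte-for-byte.
    back = ImageSchema.toNDArray(row)
    np.testing.assert_array_equal(arr, back)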
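
The mode table added to ImageSchema.scala encodes OpenCV's layout mode = depth + 8 * (nChannels - 1), with depth ordinals 8U=0, 8S=1, 16U=2, 16S=3, 32S=4, 32F=5, 64F=6. A quick cross-check of that arithmetic against the Python-side lookup, under the same assumptions as the sketch above:

    from sparkdl.image.image import ImageSchema

    depths = ["8U", "8S", "16U", "16S", "32S", "32F", "64F"]
    for nc in (1, 2, 3, 4):
        for depth_ord, dt in enumerate(depths):
            ocv = ImageSchema.ocvTypeByName("CV_%sC%d" % (dt, nc))
            # e.g. CV_32FC3 -> 5 + 8 * 2 = 21, matching the table in the patch
            assert ocv.mode == depth_ord + 8 * (nc - 1)
            assert ocv.nChannels == nc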