[WIP - do not merge!] Move sparkdl utilities for conversion between numpy arrays and image schema to ImageSchema #90


Open: wants to merge 3 commits into master
4 changes: 2 additions & 2 deletions python/sparkdl/estimators/keras_image_file_estimator.py
@@ -24,11 +24,11 @@
 import pyspark.ml.linalg as spla
 from pyspark.ml.param import Param, Params, TypeConverters

-from sparkdl.image.imageIO import imageStructToArray
 from sparkdl.param import (
     keyword_only, CanLoadImage, HasKerasModel, HasKerasOptimizer, HasKerasLoss, HasOutputMode,
     HasInputCol, HasInputImageNodeName, HasLabelCol, HasOutputNodeName, HasOutputCol)
 from sparkdl.transformers.keras_image import KerasImageFileTransformer
+from sparkdl.image.image import ImageSchema
 import sparkdl.utils.jvmapi as JVMAPI
 import sparkdl.utils.keras_model as kmutil
@@ -202,7 +202,7 @@ def _getNumpyFeaturesAndLabels(self, dataset):
         rows = image_df.collect()
         for row in rows:
             spimg = row[tmp_image_col]
-            features = imageStructToArray(spimg)
+            features = ImageSchema.toNDArray(spimg)
             localFeatures.append(features)

         if not localFeatures:  # NOTE(phi-dbq): pep-8 recommended against testing 0 == len(array)
85 changes: 62 additions & 23 deletions python/sparkdl/image/image.py
@@ -27,6 +27,8 @@
"""

import numpy as np
from collections import namedtuple

from pyspark import SparkContext
from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string
from pyspark.sql import DataFrame, SparkSession
@@ -42,9 +44,24 @@ class _ImageSchema(object):
     def __init__(self):
         self._imageSchema = None
         self._ocvTypes = None
+        self._ocvTypesByName = None
+        self._ocvTypesByMode = None
         self._imageFields = None
         self._undefinedImageType = None

+    _OcvType = namedtuple("OcvType", ["name", "mode", "nChannels", "dataType", "nptype"])
+
+    _ocvToNumpyMap = {
+        "N/A": "N/A",
+        "8U": np.dtype("uint8"),
+        "8S": np.dtype("int8"),
+        "16U": np.dtype("uint16"),
+        "16S": np.dtype("int16"),
+        "32S": np.dtype("int32"),
+        "32F": np.dtype("float32"),
+        "64F": np.dtype("float64")}
+    _numpyToOcvMap = {v: k for k, v in _ocvToNumpyMap.items()}

     @property
     def imageSchema(self):
         """
@@ -57,7 +74,7 @@ def imageSchema(self):
         """

         if self._imageSchema is None:
-            ctx = SparkContext._active_spark_context
+            ctx = SparkContext.getOrCreate()
             jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema()
             self._imageSchema = _parse_datatype_json_string(jschema.json())
         return self._imageSchema
@@ -71,11 +88,30 @@ def ocvTypes(self):

         .. versionadded:: 2.3.0
         """

         if self._ocvTypes is None:
-            ctx = SparkContext._active_spark_context
-            self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes())
-        return self._ocvTypes
+            ctx = SparkContext.getOrCreate()
+            self._ocvTypes = [self._OcvType(name=x.name(),
+                                            mode=x.mode(),
+                                            nChannels=x.nChannels(),
+                                            dataType=x.dataType(),
+                                            nptype=self._ocvToNumpyMap[x.dataType()])
+                              for x in ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()]
+        return self._ocvTypes[:]
+
+    def ocvTypeByName(self, name):
+        if self._ocvTypesByName is None:
+            self._ocvTypesByName = {x.name: x for x in self.ocvTypes}
+        if name not in self._ocvTypesByName:
+            raise ValueError(
+                "Unsupported image format: cannot find a matching OpenCV type for '%s'; "
+                "currently supported formats: %s" %
+                (name, str(list(self._ocvTypesByName.keys()))))
+        return self._ocvTypesByName[name]
+
+    def ocvTypeByMode(self, mode):
+        if self._ocvTypesByMode is None:
+            self._ocvTypesByMode = {x.mode: x for x in self.ocvTypes}
+        return self._ocvTypesByMode[mode]

     @property
     def imageFields(self):
@@ -88,7 +124,7 @@ def imageFields(self):
         """

         if self._imageFields is None:
-            ctx = SparkContext._active_spark_context
+            ctx = SparkContext.getOrCreate()
             self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields())
         return self._imageFields

@@ -101,7 +137,7 @@ def undefinedImageType(self):
         """

         if self._undefinedImageType is None:
-            ctx = SparkContext._active_spark_context
+            ctx = SparkContext.getOrCreate()
             self._undefinedImageType = \
                 ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType()
         return self._undefinedImageType
@@ -126,15 +162,20 @@ def toNDArray(self, image):
             raise ValueError(
                 "image argument should have attributes specified in "
                 "ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields))

         height = image.height
         width = image.width
         nChannels = image.nChannels
+        ocvType = self.ocvTypeByMode(image.mode)
+        if nChannels != ocvType.nChannels:
+            raise ValueError(
+                "Unexpected number of channels: the image has %d channels but OcvType '%s' "
+                "expects %d channels." % (nChannels, ocvType.name, ocvType.nChannels))
+        itemSz = ocvType.nptype.itemsize
         return np.ndarray(
             shape=(height, width, nChannels),
-            dtype=np.uint8,
+            dtype=ocvType.nptype,
             buffer=image.data,
-            strides=(width * nChannels, nChannels, 1))
+            strides=(width * nChannels * itemSz, nChannels * itemSz, itemSz))
Review comment (Contributor):
Will numpy figure out the right strides if we don't pass it explicitly?

Reply (Contributor, Author):
Hmm, yeah, I would think so. The original code from the MS folks was like this and I did not want to make more changes than necessary.
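
For reference, a quick standalone check (a sketch, not part of this PR) suggests numpy does infer exactly these strides for a C-contiguous buffer when none are passed:

import numpy as np

height, width, nChannels = 2, 3, 4
dt = np.dtype("float32")
buf = bytes(height * width * nChannels * dt.itemsize)

# Strides spelled out, as toNDArray now does.
explicit = np.ndarray(shape=(height, width, nChannels), dtype=dt, buffer=buf,
                      strides=(width * nChannels * dt.itemsize,
                               nChannels * dt.itemsize,
                               dt.itemsize))
# Omitting strides defaults to C-contiguous layout, which matches.
inferred = np.ndarray(shape=(height, width, nChannels), dtype=dt, buffer=buf)
assert explicit.strides == inferred.strides == (48, 16, 4)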


     def toImage(self, array, origin=""):
         """
@@ -152,29 +193,27 @@
                 "array argument should be numpy.ndarray; however, it got [%s]." % type(array))

         if array.ndim != 3:
-            raise ValueError("Invalid array shape")
+            raise ValueError("Invalid array shape %s" % str(array.shape))
Review comment (Contributor):
Do we want to reshape 2-d arrays to be shape + (1,)?

Reply (Contributor, Author):
I agree with their approach. I think it is better to make the caller pass the arguments in the expected format rather than to auto-convert, unless the conversion is completely unambiguous.

So in this case we say images are always 3-dimensional arrays, and it is up to the user to make sure they conform to that. Otherwise they might be passing something other than what they think they are passing, and we would mask their bug until later.
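
To illustrate the contract agreed on above (a hedged sketch, not in the PR; it assumes a live SparkContext, since the OpenCV type lookup goes through the JVM): a caller holding a 2-d grayscale array makes the channel dimension explicit before converting.

import numpy as np
from sparkdl.image.image import ImageSchema

gray = np.zeros((32, 32), dtype=np.uint8)              # 2-d grayscale array
img_row = ImageSchema.toImage(gray[:, :, np.newaxis])  # explicit (32, 32, 1) maps to CV_8UC1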


         height, width, nChannels = array.shape
-        ocvTypes = ImageSchema.ocvTypes
-        if nChannels == 1:
-            mode = ocvTypes["CV_8UC1"]
-        elif nChannels == 3:
-            mode = ocvTypes["CV_8UC3"]
-        elif nChannels == 4:
-            mode = ocvTypes["CV_8UC4"]
-        else:
-            raise ValueError("Invalid number of channels")
+        dtype = array.dtype
+        if dtype not in self._numpyToOcvMap:
+            raise ValueError(
+                "Unexpected/unsupported array data type '%s'; currently supported types: %s" %
+                (str(array.dtype), str(list(self._numpyToOcvMap.keys()))))
+        ocvName = "CV_%sC%d" % (self._numpyToOcvMap[dtype], nChannels)
+        ocvType = self.ocvTypeByName(ocvName)

         # Running `bytearray(numpy.array([1]))` fails in specific Python versions
         # with a specific NumPy version, for example Python 3.6.0 with NumPy 1.13.3.
         # Converting to bytes first avoids the issue.
-        data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
+        data = bytearray(array.tobytes())

         # Creating a new Row with _create_row(), because Row(name=value, ...)
         # orders fields by name, which conflicts with the expected schema order
         # when the new DataFrame is created by a UDF
         return _create_row(self.imageFields,
-                           [origin, height, width, nChannels, mode, data])
+                           [origin, height, width, nChannels, ocvType.mode, data])

     def readImages(self, path, recursive=False, numPartitions=-1,
                    dropImageFailures=False, sampleRatio=1.0, seed=0):
@@ -203,7 +242,7 @@ def readImages(self, path, recursive=False, numPartitions=-1,
         .. versionadded:: 2.3.0
         """

-        ctx = SparkContext._active_spark_context
+        ctx = SparkContext.getOrCreate()
         spark = SparkSession(ctx)
         image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema
         jsession = spark._jsparkSession
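
Taken together, toImage and toNDArray should now round-trip any supported dtype rather than assuming uint8. A usage sketch (again assuming a running SparkContext, since the schema and OpenCV type table are fetched from the JVM):

import numpy as np
from sparkdl.image.image import ImageSchema

arr = (np.random.rand(4, 5, 3) * 255).astype(np.uint8)  # maps to CV_8UC3
row = ImageSchema.toImage(arr, origin="in-memory")
back = ImageSchema.toNDArray(row)

assert back.shape == (4, 5, 3) and back.dtype == np.uint8
assert np.array_equal(arr, back)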
126 changes: 30 additions & 96 deletions python/sparkdl/image/imageIO.py
@@ -21,83 +21,12 @@
 from PIL import Image

 # pyspark
-from pyspark import Row
 from pyspark import SparkContext
 from sparkdl.image.image import ImageSchema
 from pyspark.sql.functions import udf
 from pyspark.sql.types import (
     BinaryType, IntegerType, StringType, StructField, StructType)


-# ImageType represents supported OpenCV types
-# fields:
-#   name - OpenCvMode
-#   ord - ordinal of the corresponding OpenCV mode (stored in the mode field of ImageSchema)
-#   nChannels - number of channels in the image
-#   dtype - data type of the image's array, stored as a numpy-compatible string
-#
-# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
-_OcvType = namedtuple("OcvType", ["name", "ord", "nChannels", "dtype"])
-
-
-_supportedOcvTypes = (
-    _OcvType(name="CV_8UC1", ord=0, nChannels=1, dtype="uint8"),
-    _OcvType(name="CV_32FC1", ord=5, nChannels=1, dtype="float32"),
-    _OcvType(name="CV_8UC3", ord=16, nChannels=3, dtype="uint8"),
-    _OcvType(name="CV_32FC3", ord=21, nChannels=3, dtype="float32"),
-    _OcvType(name="CV_8UC4", ord=24, nChannels=4, dtype="uint8"),
-    _OcvType(name="CV_32FC4", ord=29, nChannels=4, dtype="float32"),
-)
-
-# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
-_ocvTypesByName = {m.name: m for m in _supportedOcvTypes}
-_ocvTypesByOrdinal = {m.ord: m for m in _supportedOcvTypes}
-
-
-def imageTypeByOrdinal(ord):
-    if ord not in _ocvTypesByOrdinal:
-        raise KeyError("unsupported image type with ordinal %d, supported OpenCV types = %s" % (
-            ord, str(_supportedOcvTypes)))
-    return _ocvTypesByOrdinal[ord]
-
-
-def imageTypeByName(name):
-    if name not in _ocvTypesByName:
-        raise KeyError("unsupported image type with name '%s', supported OpenCV types = %s" % (
-            name, str(_supportedOcvTypes)))
-    return _ocvTypesByName[name]
-
-
-def imageArrayToStruct(imgArray, origin=""):
-    """
-    Create a row representation of an image from an image array.
-
-    :param imgArray: ndarray, image data.
-    :return: Row, image as a DataFrame Row with schema==ImageSchema.
-    """
-    # Sometimes tensors have a leading "batch-size" dimension. Assume it to be 1 if it exists.
-    if len(imgArray.shape) == 4:
-        if imgArray.shape[0] != 1:
-            raise ValueError(
-                "The first dimension of a 4-d image array is expected to be 1.")
-        imgArray = imgArray.reshape(imgArray.shape[1:])
-    imageType = _arrayToOcvMode(imgArray)
-    height, width, nChannels = imgArray.shape
-    data = bytearray(imgArray.tobytes())
-    return Row(origin=origin, mode=imageType.ord, height=height,
-               width=width, nChannels=nChannels, data=data)
-
-
-def imageStructToArray(imageRow):
-    """
-    Convert an image to a numpy array.
-
-    :param imageRow: Row, must use imageSchema.
-    :return: ndarray, image data.
-    """
-    imType = imageTypeByOrdinal(imageRow.mode)
-    shape = (imageRow.height, imageRow.width, imageRow.nChannels)
-    return np.ndarray(shape, imType.dtype, imageRow.data)


 def imageStructToPIL(imageRow):
@@ -107,20 +36,20 @@ def imageStructToPIL(imageRow):
     :param imageRow: Row, must have ImageSchema
     :return PIL image
     """
-    imgType = imageTypeByOrdinal(imageRow.mode)
-    if imgType.dtype != 'uint8':
+    ary = ImageSchema.toNDArray(imageRow)
+    if ary.dtype != np.uint8:
         raise ValueError("Cannot convert image of type " +
-                         imgType.dtype + " to PIL, can only deal with 8U format")
-    ary = imageStructToArray(imageRow)
+                         str(ary.dtype) + " to PIL, can only deal with 8U format")

     # PIL expects RGB order; the image schema is BGR,
     # so we need to flip the order unless there is only one channel.
-    if imgType.nChannels != 1:
+    if imageRow.nChannels != 1:
         ary = _reverseChannels(ary)
-    if imgType.nChannels == 1:
+    if imageRow.nChannels == 1:
         return Image.fromarray(obj=ary, mode='L')
-    elif imgType.nChannels == 3:
+    elif imageRow.nChannels == 3:
         return Image.fromarray(obj=ary, mode='RGB')
-    elif imgType.nChannels == 4:
+    elif imageRow.nChannels == 4:
         return Image.fromarray(obj=ary, mode='RGBA')
     else:
         raise ValueError("don't know how to convert " +
@@ -132,19 +61,6 @@ def PIL_to_imageStruct(img):
     return _reverseChannels(np.asarray(img))


-def _arrayToOcvMode(arr):
-    assert len(arr.shape) == 3, "Array should have 3 dimensions but has shape {}".format(
-        arr.shape)
-    num_channels = arr.shape[2]
-    if arr.dtype == "uint8":
-        name = "CV_8UC%d" % num_channels
-    elif arr.dtype == "float32":
-        name = "CV_32FC%d" % num_channels
-    else:
-        raise ValueError("Unsupported type '%s'" % arr.dtype)
-    return imageTypeByName(name)


 def fixColorChannelOrdering(currentOrder, imgAry):
     if currentOrder == 'RGB':
         return _reverseChannels(imgAry)
@@ -160,6 +76,24 @@ def fixColorChannelOrdering(currentOrder, imgAry):
         "Unexpected channel order, expected one of L,RGB,BGR but got " + currentOrder)


+def _stripBatchSize(imgArray):
+    """
+    Strip the batch size (if present) from a multi-dimensional array.
+    Assumes the batch size is the first coordinate and is equal to 1;
+    a batch size != 1 raises an error.
+
+    :param imgArray: ndarray, image data.
+    :return: imgArray without the leading batch dimension.
+    """
+    # Sometimes tensors have a leading "batch-size" dimension. Assume it to be 1 if it exists.
+    if len(imgArray.shape) == 4:
+        if imgArray.shape[0] != 1:
+            raise ValueError(
+                "The first dimension of a 4-d image array is expected to be 1.")
+        imgArray = imgArray.reshape(imgArray.shape[1:])
+    return imgArray


 def _reverseChannels(ary):
     return ary[..., ::-1]

Expand All @@ -183,8 +117,8 @@ def _resizeImageAsRow(imgAsRow):
return imgAsRow
imgAsPil = imageStructToPIL(imgAsRow).resize(sz)
# PIL is RGB based while image schema is BGR based => we need to flip the channels
imgAsArray = _reverseChannels(np.asarray(imgAsPil))
return imageArrayToStruct(imgAsArray, origin=imgAsRow.origin)
imgAsArray = PIL_to_imageStruct(imgAsPil)
return ImageSchema.toImage(imgAsArray, origin=imgAsRow.origin)
return udf(_resizeImageAsRow, ImageSchema.imageSchema['image'].dataType)


@@ -242,7 +176,7 @@ def readImagesWithCustomFn(path, decode_f, numPartition=None):
 def _readImagesWithCustomFn(path, decode_f, numPartition, sc):
     def _decode(path, raw_bytes):
         try:
-            return imageArrayToStruct(decode_f(raw_bytes), origin=path)
+            return ImageSchema.toImage(decode_f(raw_bytes), origin=path)
         except BaseException:
             return None
     decodeImage = udf(_decode, ImageSchema.imageSchema['image'].dataType)
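
One consumer-facing effect of the move: a decode function handed to readImagesWithCustomFn just returns a numpy array, and the row conversion now goes through ImageSchema.toImage. A hedged sketch (the path is a placeholder; assumes PIL is installed):

import io
from PIL import Image
from sparkdl.image.imageIO import PIL_to_imageStruct, readImagesWithCustomFn

def decode(raw_bytes):
    # Decode to RGB with PIL, then flip to the BGR channel order the schema uses.
    return PIL_to_imageStruct(Image.open(io.BytesIO(raw_bytes)).convert("RGB"))

image_df = readImagesWithCustomFn("/path/to/images", decode_f=decode)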
11 changes: 5 additions & 6 deletions python/sparkdl/param/image_params.py
@@ -22,8 +22,7 @@
 from sparkdl.image.image import ImageSchema
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.sql.functions import udf
-from sparkdl.image.imageIO import imageArrayToStruct
-from sparkdl.image.imageIO import _reverseChannels
+from sparkdl.image.imageIO import _reverseChannels, _stripBatchSize
 from sparkdl.param import SparkDLTypeConverters

 OUTPUT_MODES = ["vector", "image"]
@@ -95,10 +94,10 @@ def loadImagesInternal(self, dataframe, inputCol):
         # Loading from external resources can fail, so we should allow None to be returned

         def load_image_uri_impl(uri):
             try:
-                return imageArrayToStruct(_reverseChannels(loader(uri)))
+                return ImageSchema.toImage(_reverseChannels(_stripBatchSize(loader(uri))))
             except BaseException:  # pylint: disable=bare-except
                 return None
         load_udf = udf(load_image_uri_impl, ImageSchema.imageSchema['image'].dataType)
         return dataframe.withColumn(self._loadedImageCol(), load_udf(dataframe[inputCol]))