-
Notifications
You must be signed in to change notification settings - Fork 495
[WIP - do not merge!] Move sparkdl utilities for conversion between numpy arrays and image schema to ImageSchema #90
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,8 @@ | |
""" | ||
|
||
import numpy as np | ||
from collections import namedtuple | ||
|
||
from pyspark import SparkContext | ||
from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string | ||
from pyspark.sql import DataFrame, SparkSession | ||
|
@@ -42,9 +44,24 @@ class _ImageSchema(object): | |
def __init__(self): | ||
self._imageSchema = None | ||
self._ocvTypes = None | ||
self._ocvTypesByName = None | ||
self._ocvTypesByMode = None | ||
self._imageFields = None | ||
self._undefinedImageType = None | ||
|
||
_OcvType = namedtuple("OcvType", ["name", "mode", "nChannels", "dataType", "nptype"]) | ||
|
||
_ocvToNumpyMap = { | ||
"N/A": "N/A", | ||
"8U": np.dtype("uint8"), | ||
"8S": np.dtype("int8"), | ||
"16U": np.dtype('uint16'), | ||
"16S": np.dtype('int16'), | ||
"32S": np.dtype('int32'), | ||
"32F": np.dtype('float32'), | ||
"64F": np.dtype('float64')} | ||
_numpyToOcvMap = {x[1]: x[0] for x in _ocvToNumpyMap.items()} | ||
|
||
@property | ||
def imageSchema(self): | ||
""" | ||
|
@@ -57,7 +74,7 @@ def imageSchema(self): | |
""" | ||
|
||
if self._imageSchema is None: | ||
ctx = SparkContext._active_spark_context | ||
ctx = SparkContext.getOrCreate() | ||
jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema() | ||
self._imageSchema = _parse_datatype_json_string(jschema.json()) | ||
return self._imageSchema | ||
|
@@ -71,11 +88,30 @@ def ocvTypes(self): | |
|
||
.. versionadded:: 2.3.0 | ||
""" | ||
|
||
if self._ocvTypes is None: | ||
ctx = SparkContext._active_spark_context | ||
self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()) | ||
return self._ocvTypes | ||
ctx = SparkContext.getOrCreate() | ||
self._ocvTypes = [self._OcvType(name=x.name(), | ||
mode=x.mode(), | ||
nChannels=x.nChannels(), | ||
dataType=x.dataType(), | ||
nptype=self._ocvToNumpyMap[x.dataType()]) | ||
for x in ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()] | ||
return self._ocvTypes[:] | ||
|
||
def ocvTypeByName(self, name): | ||
if self._ocvTypesByName is None: | ||
self._ocvTypesByName = {x.name: x for x in self.ocvTypes} | ||
if not name in self._ocvTypesByName: | ||
raise ValueError( | ||
"Unsupported image format, can not find matching OpenCvFormat for type = '%s'; currently supported formats = %s" % | ||
(name, str( | ||
self._ocvTypesByName.keys()))) | ||
return self._ocvTypesByName[name] | ||
|
||
def ocvTypeByMode(self, mode): | ||
if self._ocvTypesByMode is None: | ||
self._ocvTypesByMode = {x.mode: x for x in self.ocvTypes} | ||
return self._ocvTypesByMode[mode] | ||
|
||
@property | ||
def imageFields(self): | ||
|
@@ -88,7 +124,7 @@ def imageFields(self): | |
""" | ||
|
||
if self._imageFields is None: | ||
ctx = SparkContext._active_spark_context | ||
ctx = SparkContext.getOrCreate() | ||
self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields()) | ||
return self._imageFields | ||
|
||
|
@@ -101,7 +137,7 @@ def undefinedImageType(self): | |
""" | ||
|
||
if self._undefinedImageType is None: | ||
ctx = SparkContext._active_spark_context | ||
ctx = SparkContext.getOrCreate() | ||
self._undefinedImageType = \ | ||
ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType() | ||
return self._undefinedImageType | ||
|
@@ -126,15 +162,20 @@ def toNDArray(self, image): | |
raise ValueError( | ||
"image argument should have attributes specified in " | ||
"ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields)) | ||
|
||
height = image.height | ||
width = image.width | ||
nChannels = image.nChannels | ||
ocvType = self.ocvTypeByMode(image.mode) | ||
if nChannels != ocvType.nChannels: | ||
raise ValueError( | ||
"Unexpected number of channels, image has %d channels but OcvType '%s' expects %d channels." % | ||
(nChannels, ocvType.name, ocvType.nChannels)) | ||
itemSz = ocvType.nptype.itemsize | ||
return np.ndarray( | ||
shape=(height, width, nChannels), | ||
dtype=np.uint8, | ||
dtype=ocvType.nptype, | ||
buffer=image.data, | ||
strides=(width * nChannels, nChannels, 1)) | ||
strides=(width * nChannels * itemSz, nChannels * itemSz, itemSz)) | ||
|
||
def toImage(self, array, origin=""): | ||
""" | ||
|
@@ -152,29 +193,27 @@ def toImage(self, array, origin=""): | |
"array argument should be numpy.ndarray; however, it got [%s]." % type(array)) | ||
|
||
if array.ndim != 3: | ||
raise ValueError("Invalid array shape") | ||
raise ValueError("Invalid array shape %s" % str(array.shape)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want to reshape 2d arrays to be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with their approach. I think it's better to make the caller pass the arguments in expected format rather than trying to auto-convert unless that is completely unambiguous. So in this case, we say images are always 3 dimensional arrays and it's up to the user to make sure they conform to that. Otherwise they might be passing something else than they think they are passing and we would mask their bug until later. |
||
|
||
height, width, nChannels = array.shape | ||
ocvTypes = ImageSchema.ocvTypes | ||
if nChannels == 1: | ||
mode = ocvTypes["CV_8UC1"] | ||
elif nChannels == 3: | ||
mode = ocvTypes["CV_8UC3"] | ||
elif nChannels == 4: | ||
mode = ocvTypes["CV_8UC4"] | ||
else: | ||
raise ValueError("Invalid number of channels") | ||
dtype = array.dtype | ||
if not dtype in self._numpyToOcvMap: | ||
raise ValueError( | ||
"Unexpected/unsupported array data type '%s', currently only supported formats are %s" % | ||
(str(array.dtype), str(self._numpyToOcvMap.keys()))) | ||
ocvName = "CV_%sC%d" % (self._numpyToOcvMap[dtype], nChannels) | ||
ocvType = self.ocvTypeByName(ocvName) | ||
|
||
# Running `bytearray(numpy.array([1]))` fails in specific Python versions | ||
# with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3. | ||
# Here, it avoids it by converting it to bytes. | ||
data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes()) | ||
data = bytearray(array.tobytes()) | ||
|
||
# Creating new Row with _create_row(), because Row(name = value, ... ) | ||
# orders fields by name, which conflicts with expected schema order | ||
# when the new DataFrame is created by UDF | ||
return _create_row(self.imageFields, | ||
[origin, height, width, nChannels, mode, data]) | ||
[origin, height, width, nChannels, ocvType.mode, data]) | ||
|
||
def readImages(self, path, recursive=False, numPartitions=-1, | ||
dropImageFailures=False, sampleRatio=1.0, seed=0): | ||
|
@@ -203,7 +242,7 @@ def readImages(self, path, recursive=False, numPartitions=-1, | |
.. versionadded:: 2.3.0 | ||
""" | ||
|
||
ctx = SparkContext._active_spark_context | ||
ctx = SparkContext.getOrCreate() | ||
spark = SparkSession(ctx) | ||
image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema | ||
jsession = spark._jsparkSession | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will numpy figure out the right strides if we don't pass it explicitly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm yeah I would think so. The original code from ms folks was like this and I did not want to do more changes than necessary.