Skip to content

Commit aeff9c9

Browse files
tomasatdatabrickssueann
authored andcommitted
Replace sparkdl's ImageSchema with Spark2.3's version (#85)
Use Spark 2.3's ImageSchema as image interface. The biggest change is using opposite ordering of color channels - BGR instead of RGB, requires extra reordering in various places. The change affects mostly the transformers and the udf creation functionality. Some noteworthy decisions: - For DeepImageFeaturizer & DeepImagePredictor, we preserved ability to read and resize images in python using PIL to match Keras. Those image read & resize utilities are not recommended for external use as it's likely to cause confusion. - For KerasImageFileTransformer and the keras udf creator, we assume the preprocessing function & model inputs work on RGB images since Keras works with RGB images. - For TFImageTransformer, we added a param to specify the channel ordering expected by the tf.graph’s input layer. Having this param explicitly raises awareness that you could be doing the wrong thing, and makes the code easier to reason about. Also needed a few tweaks to run with spark 2.3 - notably UDFs are now referenced by SQL identifier and can not have dash as part of the name. [TODO] - In order to run on spark < 2.3, the image schema files have been copied here and need to be removed in the future once Spark 2.3 is released. - During this work we discovered that ImageSchema-related utilities in Spark 2.3 should support float32 types and a bit more info about the modes. Once that is done we can remove some code from this PR and use the Spark 2.3 version instead.
1 parent 94452d6 commit aeff9c9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+1162
-400
lines changed

README.md

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,23 @@ To try running the examples below, check out the Databricks notebook [DeepLearni
8080

8181
### Working with images in Spark
8282

83-
The first step to applying deep learning on images is the ability to load the images. Deep Learning Pipelines includes utility functions that can load millions of images into a Spark DataFrame and decode them automatically in a distributed fashion, allowing manipulation at scale.
83+
The first step to applying deep learning on images is the ability to load the images. Spark and Deep Learning Pipelines include utility functions that can load millions of images into a Spark DataFrame and decode them automatically in a distributed fashion, allowing manipulation at scale.
84+
85+
Using Spark's ImageSchema
86+
87+
```python
88+
from sparkdl.image.image import ImageSchema
89+
image_df = ImageSchema.readImages("/data/myimages")
90+
```
91+
92+
or if custom image library is needed:
8493

8594
```python
86-
from sparkdl import readImages
87-
image_df = readImages("/data/myimages")
95+
from sparkdl.image import imageIO as imageIO
96+
image_df = imageIO.readImagesWithCustomFn("/data/myimages",decode_f=<your image library, see imageIO.PIL_decode>)
8897
```
8998

90-
The resulting DataFrame contains a string column named "filePath" containing the path to each image file, and a image struct ("`SpImage`") column named "image" containing the decoded image data.
99+
The resulting DataFrame contains a string column named "image" containing an image struct with schema == ImageSchema.
91100

92101
```python
93102
image_df.show()
@@ -109,7 +118,7 @@ featurizer = DeepImageFeaturizer(inputCol="image", outputCol="features", modelNa
109118
lr = LogisticRegression(maxIter=20, regParam=0.05, elasticNetParam=0.3, labelCol="label")
110119
p = Pipeline(stages=[featurizer, lr])
111120

112-
model = p.fit(train_images_df) # train_images_df is a dataset of images (SpImage) and labels
121+
model = p.fit(train_images_df) # train_images_df is a dataset of images and labels
113122

114123
# Inspect training error
115124
df = model.transform(train_images_df.limit(10)).select("image", "probability", "uri", "label")
@@ -127,11 +136,13 @@ Spark DataFrames are a natural construct for applying deep learning models to a
127136
There are many well-known deep learning models for images. If the task at hand is very similar to what the models provide (e.g. object recognition with ImageNet classes), or for pure exploration, one can use the Transformer `DeepImagePredictor` by simply specifying the model name.
128137

129138
```python
130-
from sparkdl import readImages, DeepImagePredictor
139+
from sparkdl.image.image import ImageSchema
140+
141+
from sparkdl import DeepImagePredictor
131142

132143
predictor = DeepImagePredictor(inputCol="image", outputCol="predicted_labels",
133144
modelName="InceptionV3", decodePredictions=True, topK=10)
134-
image_df = readImages("/data/myimages")
145+
image_df = ImageSchema.readImages("/data/myimages")
135146
predictions_df = predictor.transform(image_df)
136147
```
137148

@@ -140,7 +151,8 @@ Spark DataFrames are a natural construct for applying deep learning models to a
140151
Deep Learning Pipelines provides a Transformer that will apply the given TensorFlow Graph to a DataFrame containing a column of images (e.g. loaded using the utilities described in the previous section). Here is a very simple example of how a TensorFlow Graph can be used with the Transformer. In practice, the TensorFlow Graph will likely be restored from files before calling `TFImageTransformer`.
141152

142153
```python
143-
from sparkdl import readImages, TFImageTransformer
154+
from sparkdl.image.image import ImageSchema
155+
from sparkdl import TFImageTransformer
144156
import sparkdl.graph.utils as tfx
145157
from sparkdl.transformers import utils
146158
import tensorflow as tf
@@ -155,7 +167,7 @@ Spark DataFrames are a natural construct for applying deep learning models to a
155167
transformer = TFImageTransformer(inputCol="image", outputCol="predictions", graph=frozen_graph,
156168
inputTensor=image_arr, outputTensor=resized_images,
157169
outputMode="image")
158-
image_df = readImages("/data/myimages")
170+
image_df = ImageSchema.readImages("/data/myimages")
159171
processed_image_df = transformer.transform(image_df)
160172
```
161173

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ sparkComponents ++= Seq("mllib-local", "mllib", "sql")
3535
// add any Spark Package dependencies using spDependencies.
3636
// e.g. spDependencies += "databricks/spark-avro:0.1"
3737
spDependencies += s"databricks/tensorframes:0.2.9-s_${scalaMajorVersion}"
38-
spDependencies += "Microsoft/spark-images:0.1"
38+
3939

4040
// These versions are ancient, but they cross-compile around scala 2.10 and 2.11.
4141
// Update them when dropping support for scala 2.10

project/plugins.sbt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
// You may use this file to add plugin dependencies for sbt.
22
resolvers += "Spark Packages repo" at "https://dl.bintray.com/spark-packages/maven/"
3-
43
addSbtPlugin("org.spark-packages" %% "sbt-spark-package" % "0.2.5")
5-
64
// scalacOptions in (Compile,doc) := Seq("-groups", "-implicits")
7-
85
addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0")

python/sparkdl/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
#
15-
1615
from .graph.input import TFInputGraph
17-
from .image.imageIO import imageSchema, imageType, readImages
1816
from .transformers.keras_image import KerasImageFileTransformer
1917
from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer
2018
from .transformers.tf_image import TFImageTransformer

python/sparkdl/estimators/keras_image_file_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
logger = logging.getLogger('sparkdl')
3838

39+
3940
class KerasImageFileEstimator(Estimator, HasInputCol, HasInputImageNodeName,
4041
HasOutputCol, HasOutputNodeName, HasLabelCol,
4142
HasKerasModel, HasKerasOptimizer, HasKerasLoss,

python/sparkdl/graph/builder.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
logger = logging.getLogger('sparkdl')
2929

30+
3031
class IsolatedSession(object):
3132
"""
3233
Provide an isolated session to work with mixed Keras and TensorFlow
@@ -43,6 +44,7 @@ class IsolatedSession(object):
4344
In this case, all Keras models loaded in this session will be accessible
4445
as a subgraph of of `graph`
4546
"""
47+
4648
def __init__(self, graph=None, using_keras=False):
4749
self.graph = graph or tf.Graph()
4850
self.sess = tf.Session(graph=self.graph)
@@ -166,7 +168,7 @@ def _fromKerasModelFile(cls, file_path):
166168
'Keras model must be specified as HDF5 file'
167169

168170
with IsolatedSession(using_keras=True) as issn:
169-
K.set_learning_phase(0) # Testing phase
171+
K.set_learning_phase(0) # Testing phase
170172
model = load_model(file_path)
171173
gfn = issn.asGraphFunction(model.inputs, model.outputs)
172174

@@ -223,7 +225,8 @@ def fromList(cls, functions):
223225
# We currently only support single input/output for intermediary stages
224226
# The functions could still take multi-dimensional tensor, but only one
225227
if len(gfn_out.input_names) != 1:
226-
raise NotImplementedError("Only support single input/output for intermediary layers")
228+
raise NotImplementedError(
229+
"Only support single input/output for intermediary layers")
227230

228231
# Acquire initial placeholders' properties
229232
# We want the input names of the merged function are not under scoped

python/sparkdl/graph/input.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
# pylint: disable=invalid-name,wrong-spelling-in-comment,wrong-spelling-in-docstring
2525

26+
2627
class TFInputGraph(object):
2728
"""
2829
An opaque object containing TensorFlow graph.
@@ -84,7 +85,6 @@ class TFInputGraph(object):
8485
Please see the example above.
8586
"""
8687

87-
8888
def __init__(self, graph_def, input_tensor_name_from_signature,
8989
output_tensor_name_from_signature):
9090
self.graph_def = graph_def
@@ -281,6 +281,7 @@ def _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names, fetch_n
281281
return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
282282
fetch_names=fetch_names)
283283

284+
284285
def _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key, feed_names, fetch_names):
285286
"""
286287
Construct a TFInputGraph from a SavedModel.

python/sparkdl/graph/pieces.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import tensorflow as tf
1919

2020
from sparkdl.graph.builder import IsolatedSession
21-
from sparkdl.image.imageIO import SparkMode
21+
from sparkdl.image import imageIO
2222

2323
logger = logging.getLogger('sparkdl')
2424

@@ -29,7 +29,8 @@
2929
Deserializing ProtocolBuffer bytes is in general faster than directly loading Keras models.
3030
"""
3131

32-
def buildSpImageConverter(img_dtype):
32+
33+
def buildSpImageConverter(channelOrder, img_dtype):
3334
"""
3435
Convert a imageIO byte encoded image into a image tensor suitable as input to ConvNets
3536
The name of the input must be a subset of those specified in `image.imageIO.imageSchema`.
@@ -48,23 +49,25 @@ def buildSpImageConverter(img_dtype):
4849
# This is the default behavior of Python Image Library
4950
shape = tf.reshape(tf.stack([height, width, num_channels], axis=0),
5051
shape=(3,), name='shape')
51-
if img_dtype == SparkMode.RGB:
52+
if img_dtype == 'uint8':
5253
image_uint8 = tf.decode_raw(image_buffer, tf.uint8, name="decode_raw")
5354
image_float = tf.to_float(image_uint8)
54-
else:
55-
assert img_dtype == SparkMode.RGB_FLOAT32, \
56-
"Unsupported dtype for image: {}".format(img_dtype)
55+
elif img_dtype == 'float32':
5756
image_float = tf.decode_raw(image_buffer, tf.float32, name="decode_raw")
58-
57+
else:
58+
raise ValueError(
59+
'unsupported image data type "%s", currently only know how to handle uint8 and float32' % img_dtype)
5960
image_reshaped = tf.reshape(image_float, shape, name="reshaped")
61+
image_reshaped = imageIO.fixColorChannelOrdering(channelOrder, image_reshaped)
6062
image_input = tf.expand_dims(image_reshaped, 0, name="image_input")
6163
gfn = issn.asGraphFunction([height, width, image_buffer, num_channels], [image_input])
6264

6365
return gfn
6466

67+
6568
def buildFlattener():
66-
"""
67-
Build a flattening layer to remove the extra leading tensor dimension.
69+
"""
70+
Build a flattening layer to remove the extra leading tensor dimension.
6871
e.g. a tensor of shape [1, W, H, C] will have a shape [W, H, C] after applying this.
6972
"""
7073
with IsolatedSession() as issn:

python/sparkdl/graph/tensorframes_udf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
logger = logging.getLogger('sparkdl')
2525

26+
2627
def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=False, register=True):
2728
"""
2829
Create a Spark SQL UserDefinedFunction from a given TensorFlow Graph

python/sparkdl/graph/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
one of the four target variants.
3232
"""
3333

34+
3435
def validated_graph(graph):
3536
"""
3637
Check if the input is a valid :py:class:`tf.Graph` and return it.
@@ -41,6 +42,7 @@ def validated_graph(graph):
4142
assert isinstance(graph, tf.Graph), 'must provide tf.Graph, but get {}'.format(type(graph))
4243
return graph
4344

45+
4446
def get_shape(tfobj_or_name, graph):
4547
"""
4648
Return the shape of the tensor as a list
@@ -52,6 +54,7 @@ def get_shape(tfobj_or_name, graph):
5254
_shape = get_tensor(tfobj_or_name, graph).get_shape().as_list()
5355
return [-1 if x is None else x for x in _shape]
5456

57+
5558
def get_op(tfobj_or_name, graph):
5659
"""
5760
Get a :py:class:`tf.Operation` object.
@@ -76,6 +79,7 @@ def get_op(tfobj_or_name, graph):
7679
assert isinstance(op, tf.Operation), err_msg.format(_op_name, type(op), op)
7780
return op
7881

82+
7983
def get_tensor(tfobj_or_name, graph):
8084
"""
8185
Get a :py:class:`tf.Tensor` object
@@ -100,6 +104,7 @@ def get_tensor(tfobj_or_name, graph):
100104
assert isinstance(tnsr, tf.Tensor), err_msg.format(_tensor_name, type(tnsr), tnsr)
101105
return tnsr
102106

107+
103108
def tensor_name(tfobj_or_name, graph=None):
104109
"""
105110
Derive the :py:class:`tf.Tensor` name from a :py:class:`tf.Operation` or :py:class:`tf.Tensor`
@@ -130,6 +135,7 @@ def tensor_name(tfobj_or_name, graph=None):
130135
else:
131136
raise TypeError('invalid tf.Tensor name query type {}'.format(type(tfobj_or_name)))
132137

138+
133139
def op_name(tfobj_or_name, graph=None):
134140
"""
135141
Derive the :py:class:`tf.Operation` name from a :py:class:`tf.Operation` or
@@ -158,9 +164,11 @@ def op_name(tfobj_or_name, graph=None):
158164
else:
159165
raise TypeError('invalid tf.Operation name query type {}'.format(type(tfobj_or_name)))
160166

167+
161168
def add_scope_to_name(scope, name):
162169
""" Prepends the provided scope to the passed-in op or tensor name. """
163-
return "%s/%s"%(scope, name)
170+
return "%s/%s" % (scope, name)
171+
164172

165173
def validated_output(tfobj_or_name, graph):
166174
"""
@@ -172,6 +180,7 @@ def validated_output(tfobj_or_name, graph):
172180
graph = validated_graph(graph)
173181
return op_name(tfobj_or_name, graph)
174182

183+
175184
def validated_input(tfobj_or_name, graph):
176185
"""
177186
Validate and return the input names useable GraphFunction
@@ -186,6 +195,7 @@ def validated_input(tfobj_or_name, graph):
186195
('input must be Placeholder, but get', op.type)
187196
return name
188197

198+
189199
def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False):
190200
"""
191201
Create a static view of the graph by

0 commit comments

Comments
 (0)