Skip to content

Commit 28078f7

Browse files
author
blublinsky
committed
Add serializers/desirializers for models
1 parent 7f6875e commit 28078f7

File tree

6 files changed

+111
-7
lines changed

6 files changed

+111
-7
lines changed

src/main/scala/com/lightbend/modelServer/ModelServingFlatJob.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import java.util.Properties
44

55
import com.lightbend.kafka.ModelServingConfiguration
66
import com.lightbend.model.winerecord.WineRecord
7+
import com.lightbend.modelServer.model.{PMMLModelSerializerKryo, TensorFlowModel, TensorFlowModelSerializerKryo}
78
import com.lightbend.modelServer.typeschema.ByteArraySchema
89
import org.apache.flink.api.scala._
910
import org.apache.flink.configuration.{ConfigConstants, Configuration, QueryableStateOptions}
@@ -80,6 +81,10 @@ object ModelServingFlatJob {
8081
def buildGraph(env : StreamExecutionEnvironment) : Unit = {
8182
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
8283
env.enableCheckpointing(5000)
84+
// Add custom serializers
85+
env.getConfig.addDefaultKryoSerializer(TensorFlowModel.getClass, classOf[TensorFlowModelSerializerKryo])
86+
env.getConfig.addDefaultKryoSerializer(TensorFlowModel.getClass, classOf[PMMLModelSerializerKryo])
87+
8388

8489
// configure Kafka consumer
8590
// Data

src/main/scala/com/lightbend/modelServer/ModelServingKeyedJob.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ import java.util.Properties
44

55
import com.lightbend.kafka.ModelServingConfiguration
66
import com.lightbend.model.winerecord.WineRecord
7+
import com.lightbend.modelServer.model.{PMMLModelSerializerKryo, TensorFlowModel, TensorFlowModelSerializerKryo}
78
import org.apache.flink.api.scala._
89
import com.lightbend.modelServer.typeschema.ByteArraySchema
910
import org.apache.flink.streaming.api.TimeCharacteristic
1011
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
1112
import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer010
12-
1313
import org.apache.flink.configuration.Configuration
1414
import org.apache.flink.configuration.ConfigConstants
1515
import org.apache.flink.configuration.QueryableStateOptions
@@ -84,6 +84,9 @@ object ModelServingKeyedJob {
8484
def buildGraph(env : StreamExecutionEnvironment) : Unit = {
8585
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
8686
env.enableCheckpointing(5000)
87+
// Add custom serializer
88+
env.getConfig.addDefaultKryoSerializer(TensorFlowModel.getClass, classOf[TensorFlowModelSerializerKryo])
89+
env.getConfig.addDefaultKryoSerializer(TensorFlowModel.getClass, classOf[PMMLModelSerializerKryo])
8790

8891
// configure Kafka consumer
8992
// Data

src/main/scala/com/lightbend/modelServer/model/PMMLModel.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ package com.lightbend.modelServer.model
99
import org.jpmml.evaluator.{FieldValue, ModelEvaluatorFactory, TargetField}
1010
import org.jpmml.evaluator.visitors._
1111
import org.jpmml.model.PMMLUtil
12-
import java.io.{ByteArrayInputStream, InputStream}
12+
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
1313

1414
import org.jpmml.evaluator.Computable
1515
import com.lightbend.model.winerecord.WineRecord
@@ -67,6 +67,12 @@ class PMMLModel(inputStream: Array[Byte]) extends Model {
6767
case _ => .0
6868
}
6969

70+
def toBytes : Array[Byte] = {
71+
var stream = new ByteArrayOutputStream()
72+
PMMLUtil.marshal(pmml, stream)
73+
stream.toByteArray
74+
}
75+
7076
}
7177

7278
object PMMLModel{
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package com.lightbend.modelServer.model
2+
3+
/**
4+
* Created by boris on 6/2/17.
5+
*/
6+
import com.esotericsoftware.kryo.io.{Input, Output}
7+
import com.esotericsoftware.kryo.{Kryo, Serializer}
8+
9+
10+
class PMMLModelSerializerKryo extends Serializer[PMMLModel]{
11+
12+
super.setAcceptsNull(false)
13+
super.setImmutable(true)
14+
15+
/** Reads bytes and returns a new object of the specified concrete type.
16+
* <p>
17+
* Before Kryo can be used to read child objects, {@link Kryo#reference(Object)} must be called with the parent object to
18+
* ensure it can be referenced by the child objects. Any serializer that uses {@link Kryo} to read a child object may need to
19+
* be reentrant.
20+
* <p>
21+
* This method should not be called directly, instead this serializer can be passed to {@link Kryo} read methods that accept a
22+
* serialier.
23+
*
24+
* @return May be null if { @link #getAcceptsNull()} is true. */
25+
26+
override def read(kryo: Kryo, input: Input, `type`: Class[PMMLModel]): PMMLModel = {
27+
val bytes = Stream.continually(input.readByte()).takeWhile(_ != -1).toArray
28+
PMMLModel(bytes).get
29+
}
30+
31+
/** Writes the bytes for the object to the output.
32+
* <p>
33+
* This method should not be called directly, instead this serializer can be passed to {@link Kryo} write methods that accept a
34+
* serialier.
35+
*
36+
* @param value May be null if { @link #getAcceptsNull()} is true. */
37+
38+
override def write(kryo: Kryo, output: Output, value: PMMLModel): Unit = {
39+
output.write(value.toBytes)
40+
}
41+
}

src/main/scala/com/lightbend/modelServer/model/TensorFlowModel.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package com.lightbend.modelServer.model
22

3-
import java.io.InputStream
43

54
import com.lightbend.model.winerecord.WineRecord
65
import org.tensorflow.{Graph, Session, Tensor}
@@ -9,11 +8,12 @@ import org.tensorflow.{Graph, Session, Tensor}
98
* Created by boris on 5/26/17.
109
* Implementation of tensorflow model
1110
*/
12-
class TensorFlowModel(inputStream: Array[Byte]) extends Model {
11+
12+
class TensorFlowModel(inputStream : Array[Byte]) extends Model {
1313

1414
val graph = new Graph
1515
graph.importGraphDef(inputStream)
16-
val session = new Session (graph)
16+
val session = new Session(graph)
1717

1818
override def score(input: AnyVal): AnyVal = {
1919

@@ -45,8 +45,16 @@ class TensorFlowModel(inputStream: Array[Byte]) extends Model {
4545
}
4646

4747
override def cleanup(): Unit = {
48-
session.close
49-
graph.close
48+
try{
49+
session.close
50+
}catch {
51+
case t: Throwable => // Swallow
52+
}
53+
try{
54+
graph.close
55+
}catch {
56+
case t: Throwable => // Swallow
57+
}
5058
}
5159
}
5260

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package com.lightbend.modelServer.model
2+
3+
/**
4+
* Created by boris on 6/2/17.
5+
*/
6+
import com.esotericsoftware.kryo.{Kryo, Serializer}
7+
import com.esotericsoftware.kryo.io.{Input, Output}
8+
9+
10+
class TensorFlowModelSerializerKryo extends Serializer[TensorFlowModel]{
11+
12+
super.setAcceptsNull(false)
13+
super.setImmutable(true)
14+
15+
/** Reads bytes and returns a new object of the specified concrete type.
16+
* <p>
17+
* Before Kryo can be used to read child objects, {@link Kryo#reference(Object)} must be called with the parent object to
18+
* ensure it can be referenced by the child objects. Any serializer that uses {@link Kryo} to read a child object may need to
19+
* be reentrant.
20+
* <p>
21+
* This method should not be called directly, instead this serializer can be passed to {@link Kryo} read methods that accept a
22+
* serialier.
23+
*
24+
* @return May be null if { @link #getAcceptsNull()} is true. */
25+
26+
override def read(kryo: Kryo, input: Input, `type`: Class[TensorFlowModel]): TensorFlowModel = {
27+
val bytes = Stream.continually(input.readByte()).takeWhile(_ != -1).toArray
28+
TensorFlowModel(bytes).get
29+
}
30+
31+
/** Writes the bytes for the object to the output.
32+
* <p>
33+
* This method should not be called directly, instead this serializer can be passed to {@link Kryo} write methods that accept a
34+
* serialier.
35+
*
36+
* @param value May be null if { @link #getAcceptsNull()} is true. */
37+
38+
override def write(kryo: Kryo, output: Output, value: TensorFlowModel): Unit = {
39+
output.write(value.graph.toGraphDef)
40+
}
41+
}

0 commit comments

Comments
 (0)