diff --git a/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala b/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala
index b8c17706..b5736064 100644
--- a/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala
+++ b/src/main/scala/com/redislabs/provider/redis/redisFunctions.scala
@@ -5,6 +5,8 @@ import com.redislabs.provider.redis.util.ConnectionUtils.withConnection
 import com.redislabs.provider.redis.util.PipelineUtils._
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
+import redis.clients.jedis.StreamEntryID
+
 import scala.collection.JavaConversions.mapAsJavaMap
 
 /**
@@ -382,6 +384,18 @@ class RedisContext(@transient val sc: SparkContext) extends Serializable {
                     readWriteConfig: ReadWriteConfig = ReadWriteConfig.fromSparkConf(sc.getConf)) {
     vs.foreachPartition(partition => setFixedList(listName, listSize, partition, redisConfig, readWriteConfig))
   }
+
+  /**
+    * Write RDD of (stream name, entry id, fields) as Redis stream entries
+    *
+    * @param kvs RDD of tuples (stream name, entry id, Map(field name, field value))
+    */
+  def toRedisSTREAMs(kvs: RDD[(String, StreamEntryID, Map[String, String])])
+                    (implicit
+                     redisConfig: RedisConfig = RedisConfig.fromSparkConf(sc.getConf),
+                     readWriteConfig: ReadWriteConfig = ReadWriteConfig.fromSparkConf(sc.getConf)) {
+    kvs.foreachPartition(partition => setStream(partition, redisConfig, readWriteConfig))
+  }
 }
 
@@ -611,6 +625,31 @@ object RedisContext extends Serializable {
     pipeline.sync()
     conn.close()
   }
+
+  /**
+    * @param streams tuples of (stream name, entry id, fields) to be saved in the target host
+    */
+  def setStream(streams: Iterator[(String, StreamEntryID, Map[String, String])],
+                redisConfig: RedisConfig,
+                readWriteConfig: ReadWriteConfig) {
+    implicit val rwConf: ReadWriteConfig = readWriteConfig
+
+    streams
+      .map { case (key, entryId, hash) =>
+        (redisConfig.getHost(key), (key, entryId, hash))
+      }
+      .toArray
+      .groupBy(_._1)
+      .foreach { case (node, arr) =>
+        withConnection(node.endpoint.connect()) { conn =>
+          foreachWithPipeline(conn, arr) { (pipeline, a) =>
+            val (key, entryId, hash) = a._2
+            pipeline.xadd(key, entryId, hash)
+          }
+        }
+      }
+  }
+
 }
 
 trait RedisFunctions {
diff --git a/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala b/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala
index 17102052..ba7e3776 100644
--- a/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala
+++ b/src/test/scala/com/redislabs/provider/redis/rdd/RedisRddExtraSuite.scala
@@ -4,6 +4,7 @@ import com.redislabs.provider.redis.util.ConnectionUtils.withConnection
 import org.scalatest.Matchers
 import com.redislabs.provider.redis._
 import com.redislabs.provider.redis.util.TestUtils
+import redis.clients.jedis.StreamEntryID
 import redis.clients.jedis.exceptions.JedisConnectionException
 
 import scala.collection.JavaConverters._
@@ -73,6 +74,20 @@ trait RedisRddExtraSuite extends SparkRedisSuite with Keys with Matchers {
     verifyHash("hash2", map2)
   }
 
+  test("toRedisSTREAMs") {
+    val map1 = Map("k1" -> "v1", "k2" -> "v2")
+    val map2 = Map("k3" -> "v3", "k4" -> "v4")
+    val streams = Seq(
+      ("stream1", null.asInstanceOf[StreamEntryID], map1),
+      ("stream2", null.asInstanceOf[StreamEntryID], map2)
+    )
+    val rdd = sc.parallelize(streams)
+    sc.toRedisSTREAMs(rdd)
+
+    verifyStreamLastEntry("stream1", map1)
+    verifyStreamLastEntry("stream2", map2)
+  }
+
   test("connection fails with incorrect user/pass") {
     assertThrows[JedisConnectionException] {
       new RedisConfig(RedisEndpoint(
@@ -112,4 +127,11 @@ trait RedisRddExtraSuite extends SparkRedisSuite with Keys with Matchers {
     }
   }
 
+  def verifyStreamLastEntry(stream: String, vals: Map[String, String]): Unit = {
+    withConnection(redisConfig.getHost(stream).endpoint.connect()) { conn =>
+      // TODO: breaking in Jedis 4
+      conn.xrevrange(stream, null, null, 1).get(0).getFields.asScala should be(vals)
+    }
+  }
+
 }
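
For context, a minimal usage sketch of the new API (assuming a SparkContext sc with spark-redis host/port already set in the Spark conf; the stream name "clicks" and the field names/values are illustrative, and the null entry id mirrors the test above, letting Redis auto-generate ids as XADD does with "*"):

    import com.redislabs.provider.redis._
    import redis.clients.jedis.StreamEntryID

    // Each element is (stream key, entry id, fields). setStream groups tuples by the
    // node owning each stream key and pipelines the XADD calls per node.
    val entries = sc.parallelize(Seq(
      ("clicks", null.asInstanceOf[StreamEntryID], Map("user" -> "u1", "page" -> "/home")),
      ("clicks", null.asInstanceOf[StreamEntryID], Map("user" -> "u2", "page" -> "/docs"))
    ))
    sc.toRedisSTREAMs(entries)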