Skip to content

Commit 968825a

Browse files
Merge pull request #912 from FlorentinD/udf_aggrDur
Support durations in aggregation functions
2 parents ae1bcdc + 53018c5 commit 968825a

File tree

4 files changed

+156
-29
lines changed

4 files changed

+156
-29
lines changed

morpheus-spark-cypher/src/main/scala/org/opencypher/morpheus/impl/SparkSQLExprMapper.scala

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import org.opencypher.morpheus.impl.convert.SparkConversions._
3535
import org.opencypher.morpheus.impl.expressions.AddPrefix._
3636
import org.opencypher.morpheus.impl.expressions.EncodeLong._
3737
import org.opencypher.morpheus.impl.temporal.TemporalConversions._
38-
import org.opencypher.morpheus.impl.temporal.TemporalUdfs
38+
import org.opencypher.morpheus.impl.temporal.{TemporalUdafs, TemporalUdfs}
3939
import org.opencypher.okapi.api.types._
4040
import org.opencypher.okapi.api.value.CypherValue.CypherMap
4141
import org.opencypher.okapi.impl.exception._
@@ -348,11 +348,26 @@ object SparkSQLExprMapper {
348348
else collect_list(child0)
349349

350350
case CountStar => count(ONE_LIT)
351-
case _: Avg => avg(child0)
352-
353-
case _: Max => max(child0)
354-
case _: Min => min(child0)
355-
case _: Sum => sum(child0)
351+
case _: Avg =>
352+
expr.cypherType match {
353+
case CTDuration => TemporalUdafs.durationAvg(child0)
354+
case _ => avg(child0)
355+
}
356+
case _: Max =>
357+
expr.cypherType match {
358+
case CTDuration => TemporalUdafs.durationMax(child0)
359+
case _ => max(child0)
360+
}
361+
case _: Min =>
362+
expr.cypherType match {
363+
case CTDuration => TemporalUdafs.durationMin(child0)
364+
case _ => min(child0)
365+
}
366+
case _: Sum =>
367+
expr.cypherType match {
368+
case CTDuration => TemporalUdafs.durationSum(child0)
369+
case _ => sum(child0)
370+
}
356371

357372

358373
case BigDecimal(_, precision, scale) =>
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/**
2+
* Copyright (c) 2016-2019 "Neo4j Sweden, AB" [https://neo4j.com]
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
* Attribution Notice under the terms of the Apache License 2.0
17+
*
18+
* This work was created by the collective efforts of the openCypher community.
19+
* Without limiting the terms of Section 6, any Derivative Work that is not
20+
* approved by the public consensus process of the openCypher Implementers Group
21+
* should not be described as “Cypher” (and Cypher® is a registered trademark of
22+
* Neo4j Inc.) or as "openCypher". Extensions by implementers or prototypes or
23+
* proposals for change that have been documented or implemented should only be
24+
* described as "implementation extensions to Cypher" or as "proposed changes to
25+
* Cypher that are not yet approved by the openCypher community".
26+
*/
27+
package org.opencypher.morpheus.impl.temporal
28+
29+
import org.apache.logging.log4j.scala.Logging
30+
import org.apache.spark.sql.Row
31+
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
32+
import org.apache.spark.sql.types.{CalendarIntervalType, DataType, LongType, StructField, StructType}
33+
import org.apache.spark.unsafe.types.CalendarInterval
34+
import org.opencypher.okapi.impl.temporal.TemporalConstants
35+
import org.opencypher.morpheus.impl.temporal.TemporalConversions._
36+
37+
/**
 * Spark user-defined aggregate functions (UDAFs) for Cypher duration values,
 * represented at the Spark level as [[CalendarInterval]].
 *
 * Spark's built-in `sum`/`avg`/`min`/`max` do not support `CalendarIntervalType`,
 * so these UDAFs implement the aggregations manually. Ordering comparisons are
 * performed by converting intervals to `Duration` via `TemporalConversions.toDuration`.
 */
object TemporalUdafs extends Logging {

  /**
   * Base class for duration aggregations that keep a single interval in the
   * aggregation buffer (sum, min, max).
   *
   * @param aggrName name used for the single buffer column (diagnostic only)
   */
  abstract class SimpleDurationAggregation(aggrName: String) extends UserDefinedAggregateFunction {
    override def inputSchema: StructType = StructType(Array(StructField("duration", CalendarIntervalType)))
    override def bufferSchema: StructType = StructType(Array(StructField(aggrName, CalendarIntervalType)))
    override def dataType: DataType = CalendarIntervalType
    override def deterministic: Boolean = true
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      // Zero-length interval as the neutral element; subclasses override where
      // zero is not a valid neutral element (min/max).
      buffer(0) = new CalendarInterval(0, 0L)
    }
    override def evaluate(buffer: Row): Any = buffer.getAs[CalendarInterval](0)
  }

  /** Sum of durations: component-wise addition of months and microseconds. */
  class DurationSum extends SimpleDurationAggregation("sum") {
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      buffer(0) = buffer.getAs[CalendarInterval](0).add(input.getAs[CalendarInterval](0))
    }
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer2.getAs[CalendarInterval](0).add(buffer1.getAs[CalendarInterval](0))
    }
  }

  /** Maximum duration, compared via `toDuration` ordering. */
  class DurationMax extends SimpleDurationAggregation("max") {
    // Start from the smallest representable interval instead of the zero
    // interval; otherwise aggregating exclusively negative durations would
    // incorrectly return a zero duration. Mirrors the sentinel used by
    // DurationMin. NOTE(review): like DurationMin, this sentinel leaks as the
    // result when no rows are aggregated — confirm empty input cannot reach
    // evaluate, or map the sentinel to null.
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = new CalendarInterval(Int.MinValue, Long.MinValue)
    }
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val currMaxInterval = buffer.getAs[CalendarInterval](0)
      val inputInterval = input.getAs[CalendarInterval](0)
      buffer(0) = if (currMaxInterval.toDuration.compare(inputInterval.toDuration) >= 0) currMaxInterval else inputInterval
    }
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val interval1 = buffer1.getAs[CalendarInterval](0)
      val interval2 = buffer2.getAs[CalendarInterval](0)
      buffer1(0) = if (interval1.toDuration.compare(interval2.toDuration) >= 0) interval1 else interval2
    }
  }

  /** Minimum duration, compared via `toDuration` ordering. */
  class DurationMin extends SimpleDurationAggregation("min") {
    // Largest representable interval as the neutral element so any real input
    // replaces it on the first update.
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = new CalendarInterval(Int.MaxValue, Long.MaxValue)
    }
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val currMinInterval = buffer.getAs[CalendarInterval](0)
      val inputInterval = input.getAs[CalendarInterval](0)
      buffer(0) = if (inputInterval.toDuration.compare(currMinInterval.toDuration) >= 0) currMinInterval else inputInterval
    }
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val interval1 = buffer1.getAs[CalendarInterval](0)
      val interval2 = buffer2.getAs[CalendarInterval](0)
      buffer1(0) = if (interval2.toDuration.compare(interval1.toDuration) >= 0) interval1 else interval2
    }
  }

  /**
   * Average of durations: sums months and microseconds alongside a row count,
   * then divides both components by the count. Note that months and
   * microseconds are divided independently (integer division), matching the
   * component-wise semantics of [[CalendarInterval]].
   */
  class DurationAvg extends UserDefinedAggregateFunction {
    override def inputSchema: StructType = StructType(Array(StructField("duration", CalendarIntervalType)))
    override def bufferSchema: StructType = StructType(Array(StructField("sum", CalendarIntervalType), StructField("cnt", LongType)))
    override def dataType: DataType = CalendarIntervalType
    override def deterministic: Boolean = true
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = new CalendarInterval(0, 0L)
      buffer(1) = 0L
    }
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      buffer(0) = buffer.getAs[CalendarInterval](0).add(input.getAs[CalendarInterval](0))
      buffer(1) = buffer.getLong(1) + 1
    }
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer2.getAs[CalendarInterval](0).add(buffer1.getAs[CalendarInterval](0))
      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }
    override def evaluate(buffer: Row): Any = {
      val sumInterval = buffer.getAs[CalendarInterval](0)
      val cnt = buffer.getLong(1)
      // Guard against division by zero on empty input; Spark's built-in avg
      // yields null in that case, so we do the same instead of throwing.
      if (cnt == 0L) null
      else new CalendarInterval((sumInterval.months / cnt).toInt, sumInterval.microseconds / cnt)
    }
  }

  // Shared stateless instances used by the SQL expression mapper.
  val durationSum = new DurationSum()
  val durationAvg = new DurationAvg()
  val durationMin = new DurationMin()
  val durationMax = new DurationMax()
}

morpheus-spark-cypher/src/main/scala/org/opencypher/morpheus/impl/temporal/TemporalUdfs.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,28 +119,28 @@ object TemporalUdfs extends Logging {
119119
} else {
120120
val days = duration.microseconds / CalendarInterval.MICROS_PER_DAY
121121
// Note: in cypher days (and weeks) make up their own group, thus we have to exclude them for all values < day
122-
val daysInMicros = days * CalendarInterval.MICROS_PER_DAY
122+
val daysInMicros = days * CalendarInterval.MICROS_PER_DAY
123123

124124
val l: Long = accessor match {
125125
case "years" => duration.months / 12
126126
case "quarters" => duration.months / 3
127127
case "months" => duration.months
128128
case "weeks" => duration.microseconds / CalendarInterval.MICROS_PER_DAY / 7
129129
case "days" => duration.microseconds / CalendarInterval.MICROS_PER_DAY
130-
case "hours" => (duration.microseconds - daysInMicros ) / CalendarInterval.MICROS_PER_HOUR
131-
case "minutes" => (duration.microseconds - daysInMicros ) / CalendarInterval.MICROS_PER_MINUTE
132-
case "seconds" => (duration.microseconds - daysInMicros ) / CalendarInterval.MICROS_PER_SECOND
133-
case "milliseconds" => (duration.microseconds - daysInMicros ) / CalendarInterval.MICROS_PER_MILLI
130+
case "hours" => (duration.microseconds - daysInMicros) / CalendarInterval.MICROS_PER_HOUR
131+
case "minutes" => (duration.microseconds - daysInMicros) / CalendarInterval.MICROS_PER_MINUTE
132+
case "seconds" => (duration.microseconds - daysInMicros) / CalendarInterval.MICROS_PER_SECOND
133+
case "milliseconds" => (duration.microseconds - daysInMicros) / CalendarInterval.MICROS_PER_MILLI
134134
case "microseconds" => duration.microseconds - daysInMicros
135135

136136
case "quartersofyear" => (duration.months / 3) % 4
137137
case "monthsofquarter" => duration.months % 3
138138
case "monthsofyear" => duration.months % 12
139139
case "daysofweek" => (duration.microseconds / CalendarInterval.MICROS_PER_DAY) % 7
140-
case "minutesofhour" => ((duration.microseconds - daysInMicros )/ CalendarInterval.MICROS_PER_MINUTE) % 60
141-
case "secondsofminute" => ((duration.microseconds - daysInMicros ) / CalendarInterval.MICROS_PER_SECOND) % 60
142-
case "millisecondsofsecond" => ((duration.microseconds - daysInMicros ) / CalendarInterval.MICROS_PER_MILLI) % 1000
143-
case "microsecondsofsecond" => (duration.microseconds - daysInMicros ) % 1000000
140+
case "minutesofhour" => ((duration.microseconds - daysInMicros) / CalendarInterval.MICROS_PER_MINUTE) % 60
141+
case "secondsofminute" => ((duration.microseconds - daysInMicros) / CalendarInterval.MICROS_PER_SECOND) % 60
142+
case "millisecondsofsecond" => ((duration.microseconds - daysInMicros) / CalendarInterval.MICROS_PER_MILLI) % 1000
143+
case "microsecondsofsecond" => (duration.microseconds - daysInMicros) % 1000000
144144

145145
case other => throw UnsupportedOperationException(s"Unknown Duration accessor: $other")
146146
}

morpheus-testing/src/test/scala/org/opencypher/morpheus/impl/acceptance/AggregationTests.scala

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,11 @@ class AggregationTests extends MorpheusTestSuite with ScanGraphInit {
126126
))
127127
}
128128

129-
//todo: cypher should allow avg on durations, but spark does not support avg on durations (calendarintervals)
130-
ignore("avg on durations") {
131-
val result = morpheus.graphs.empty.cypher("UNWIND [duration('P1DT12H'), duration('P1DT200H')] AS d RETURN AVG(d) as res")
129+
it("avg on durations") {
130+
val result = morpheus.graphs.empty.cypher("UNWIND [duration('P1DT12H'), duration('P1DT20H')] AS d RETURN AVG(d) as res")
132131

133132
result.records.toMaps should equal(Bag(
134-
CypherMap("res" -> Duration(days = 1, hours = 12))
133+
CypherMap("res" -> Duration(days = 1, hours = 16))
135134
))
136135
}
137136
}
@@ -366,8 +365,7 @@ class AggregationTests extends MorpheusTestSuite with ScanGraphInit {
366365
))
367366
}
368367

369-
//todo: spark does not support min on durations (calendarintervals)
370-
ignore("min on durations") {
368+
it("min on durations") {
371369
val result = morpheus.graphs.empty.cypher("UNWIND [duration('P1DT12H'), duration('P1DT200H')] AS d RETURN MIN(d) as res")
372370

373371
result.records.toMaps should equal(Bag(
@@ -472,12 +470,11 @@ class AggregationTests extends MorpheusTestSuite with ScanGraphInit {
472470
))
473471
}
474472

475-
//todo: spark does not support max on durations (calendarintervals)
476-
ignore("max on durations") {
477-
val result = morpheus.graphs.empty.cypher("UNWIND [duration('P1DT12H'), duration('P1DT200H')] AS d RETURN MAX(d) as res")
473+
it("max on durations") {
474+
val result = morpheus.graphs.empty.cypher("UNWIND [duration('P10DT12H'), duration('P1DT24H')] AS d RETURN MAX(d) as res")
478475

479476
result.records.toMaps should equal(Bag(
480-
CypherMap("res" -> Duration(days = 1, hours = 200))
477+
CypherMap("res" -> Duration(days = 10, hours = 12))
481478
))
482479
}
483480

@@ -584,12 +581,11 @@ class AggregationTests extends MorpheusTestSuite with ScanGraphInit {
584581
))
585582
}
586583

587-
//todo: cypher should sum over durations, but spark does not support sum over durations (calendarintervals)
588-
ignore("sum on durations") {
589-
val result = morpheus.graphs.empty.cypher("UNWIND [duration('P1DT12H'), duration('P1DT200H')] AS d RETURN SUM(d) as res")
584+
it("sum on durations") {
585+
val result = morpheus.graphs.empty.cypher("UNWIND [duration('P1DT12H'), duration('P1DT24H')] AS d RETURN SUM(d) as res")
590586

591587
result.records.toMaps should equal(Bag(
592-
CypherMap("res" -> Duration(days = 2, hours = 12))
588+
CypherMap("res" -> Duration(days = 3, hours = 12))
593589
))
594590
}
595591
}

0 commit comments

Comments
 (0)