From ab0888765a2d898db10c4ac5e6fad24bff5b676b Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 12 Mar 2018 15:30:44 -0400 Subject: [PATCH] Proof of concept implementation of asof Join --- .../scala/org/apache/spark/Partitioner.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 13 ++ .../plans/physical/partitioning.scala | 133 ++++++++++++ .../spark/sql/ExperimentalMethods.scala | 82 +++++++- .../spark/sql/execution/AsofJoinExec.scala | 131 ++++++++++++ .../sql/execution/BroadcastAsofJoinExec.scala | 132 ++++++++++++ .../spark/sql/execution/QueryExecution.scala | 3 +- .../exchange/ShuffleExchangeExec.scala | 58 +++++- .../org/apache/spark/sql/functions.scala | 36 +++- .../org/apache/spark/sql/AsofJoinSuite.scala | 197 ++++++++++++++++++ 10 files changed, 775 insertions(+), 12 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/AsofJoinExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/BroadcastAsofJoinExec.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/AsofJoinSuite.scala diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index c940cb25d478b..787a2d33fd189 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -156,7 +156,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( private var ordering = implicitly[Ordering[K]] // An array of upper bounds for the first (partitions - 1) partitions - private var rangeBounds: Array[K] = { + var rangeBounds: Array[K] = { if (partitions <= 1) { Array.empty } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a4fca790dd086..5f6507a67a5f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -275,6 +275,19 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { } } +case class AsofJoin( + left: LogicalPlan, + right: LogicalPlan, + leftOn: Expression, + rightOn: Expression, + leftBy: Seq[Expression], + rightBy: Seq[Expression], + tolerance: Long) + extends BinaryNode { + + override def output: Seq[Attribute] = left.output ++ right.output.map(_.withNullability(true)) +} + case class Join( left: LogicalPlan, right: LogicalPlan, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 4d9a9925fe3ff..f5e9a4c55d908 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,9 +17,16 @@ package org.apache.spark.sql.catalyst.plans.physical +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import com.google.common.collect.{Range => GRange, RangeMap, TreeRangeMap} + +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} + /** * Specifies how tuples that share common expressions will be distributed when a query is executed * in 
parallel on many machines. Distribution can be used to refer to two distinct physical @@ -148,6 +155,43 @@ case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { } } +/** + * A object that holds distribution ranges to be shared cross different nodes. + * + * This is not thread safe. + */ +class DelayedRange extends Serializable { + var range: IndexedSeq[GRange[java.lang.Long]] = _ + + def realized(): Boolean = { + range != null + } + + def setRange(range: IndexedSeq[GRange[java.lang.Long]]): Unit = { + this.range = range + } + + def getRange(): IndexedSeq[GRange[java.lang.Long]] = { + if (range == null) { + throw new Exception("DelayedRange is not realized") + } else { + range + } + } +} + +case class DelayedOverlappedRangeDistribution( + key: Expression, + overlap: Long, + core: Boolean // Decides whether it is range-defining +) extends Distribution { + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + DelayedOverlappedRangePartitioning(key, null, numPartitions, overlap, core) + } +} + /** * Describes how an operator's output is split across partitions. It has 2 major properties: * 1. number of partitions. @@ -261,6 +305,95 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) } } +case class DelayedOverlappedRangePartitioning( + key: Expression, + var delayedRange: DelayedRange, + numPartitions: Int, + overlap: Long, + core: Boolean +) extends Expression with Partitioning with Unevaluable { + override def children: Seq[Expression] = Seq(key) + override def nullable: Boolean = false + override def dataType: DataType = IntegerType + + def setDelayedRange(delayedRange: DelayedRange): Unit = this.delayedRange = delayedRange + + override def satisfies(required: Distribution): Boolean = { + super.satisfies(required) || { + required match { + case DelayedOverlappedRangeDistribution(requiredKey, requiredOverlap, _) => + key.semanticEquals(requiredKey) && overlap == requiredOverlap + case OrderedDistribution(ordering) => + (ordering.length == 1) && + ordering.head.child.semanticEquals(key) && + ordering.head.direction == Ascending && + overlap == 0 + case ClusteredDistribution(clustering, None) => + (clustering.length == 1) && clustering.head.semanticEquals(key) && overlap == 0 + case _ => false + } + } + } + + def realizeDelayedRange(bounds: Array[Long]): Unit = { + if (delayedRange.realized()) { + throw new Exception("ranges are realized already") + } else { + val ranges: IndexedSeq[GRange[java.lang.Long]] = { + (Seq(Long.MinValue) ++ bounds).zip(bounds ++ Seq(Long.MaxValue)).map { + case (lower, upper) => GRange.closedOpen(Long.box(lower), Long.box(upper)) + }.toIndexedSeq + } + + delayedRange.setRange(ranges) + } + } + + def withPartitionIds( + iter: Iterator[InternalRow], + expr: Expression, + output: Seq[Attribute]): Iterator[Product2[Int, InternalRow]] = { + val expandedRanges = delayedRange.getRange().map{ + case range => GRange.closedOpen( + Long.box( + if (range.lowerEndpoint() == Long.MinValue) { + Long.MinValue + } + else { + range.lowerEndpoint() - overlap + } + ), + range.upperEndpoint() + ) + } + + val keyProj = UnsafeProjection.create(Seq(expr), output) + val proj = UnsafeProjection.create(output, output) + + var currentStartIndex = 0 + + iter.flatMap { row => + val key = keyProj(row).getLong(0) + + // Update + while(currentStartIndex < expandedRanges.length && + !expandedRanges(currentStartIndex).contains(key)) { + currentStartIndex += 1 + } + + var i = 
currentStartIndex + val rows = new ArrayBuffer[Product2[Int, UnsafeRow]]() + + while(i < expandedRanges.length && expandedRanges(i).contains(key)) { + rows.append((i, proj(row))) + i += 1 + } + + rows + } + } +} + /** * A collection of [[Partitioning]]s that can be used to describe the partitioning * scheme of the output of a physical operator. It is usually used for an operator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index bd8dd6ea3fe0f..4c8b9497fcc92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -18,8 +18,83 @@ package org.apache.spark.sql import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{AsofJoin, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.physical.{DelayedOverlappedRangePartitioning, DelayedRange} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{AsofJoinExec, BroadcastAsofJoinExec, SparkPlan} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} +import org.apache.spark.sql.internal.SQLConf + +// These are codes that can be added via experimental methods +// The actual rules don't need to be in this file. Keep them here for now +// for convenience. + +object AsofJoinStrategy extends Strategy { + + private def canBroadcastByHints(left: LogicalPlan, right: LogicalPlan) + : Boolean = { + left.stats.hints.broadcast || right.stats.hints.broadcast + } + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case AsofJoin(left, right, leftOn, rightOn, leftBy, rightBy, tolerance) + if canBroadcastByHints(left, right) => + val buildSide = if (left.stats.hints.broadcast) { + BuildLeft + } else { + BuildRight + } + BroadcastAsofJoinExec( + buildSide, + leftOn, + rightOn, + leftBy, rightBy, tolerance, planLater(left), planLater(right)) :: Nil + + case AsofJoin(left, right, leftOn, rightOn, leftBy, rightBy, tolerance) => + AsofJoinExec( + leftOn, + rightOn, + leftBy, rightBy, tolerance, planLater(left), planLater(right)) :: Nil + + case _ => Nil + } +} + +/** + * This must run after ensure requirements. This is not great but I don't know another way to + * do this, unless we modify ensure requirements. + * + * Currently this mutate the state of partitioning (by setting the delayed range object) so + * it's not great. We might need to make partitioning immutable and copy nodes with new + * partitioning. 
+ * + */ +object EnsureRange extends Rule[SparkPlan] { + + private def ensureChildrenRange(operator: SparkPlan): SparkPlan = operator match { + case asof: AsofJoinExec => + // This code assumes EnsureRequirement will set the left and right partitioning + // properly + val leftPartitioning = + asof.left.outputPartitioning.asInstanceOf[DelayedOverlappedRangePartitioning] + val rightPartitioning = + asof.right.outputPartitioning.asInstanceOf[DelayedOverlappedRangePartitioning] + + if (leftPartitioning.delayedRange == null) { + val delayedRange = new DelayedRange() + leftPartitioning.setDelayedRange(delayedRange) + rightPartitioning.setDelayedRange(delayedRange) + } else { + rightPartitioning.setDelayedRange(leftPartitioning.delayedRange) + } + asof + } + + override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator: AsofJoinExec => ensureChildrenRange(operator) + } +} + /** * :: Experimental :: @@ -42,14 +117,17 @@ class ExperimentalMethods private[sql]() { * * @since 1.3.0 */ - @volatile var extraStrategies: Seq[Strategy] = Nil + @volatile var extraStrategies: Seq[Strategy] = Seq(AsofJoinStrategy) @volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil + @volatile var extraPreparations: Seq[Rule[SparkPlan]] = Seq(EnsureRange) + override def clone(): ExperimentalMethods = { val result = new ExperimentalMethods result.extraStrategies = extraStrategies result.extraOptimizations = extraOptimizations + result.extraPreparations = extraPreparations result } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AsofJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AsofJoinExec.scala new file mode 100644 index 0000000000000..b762b99dcb4a1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AsofJoinExec.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.util.{HashMap => JHashMap} + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.plans.physical._ + +class NullUnsafeProjection extends UnsafeProjection { + override def apply(row: InternalRow): UnsafeRow = new UnsafeRow() +} + +case class AsofJoinExec( + leftOnExpr: Expression, + rightOnExpr: Expression, + leftByExpr: Seq[Expression], + rightByExpr: Seq[Expression], + tolerance: Long, + left: SparkPlan, + right: SparkPlan) extends BinaryExecNode { + + override val output: Seq[Attribute] = + left.output ++ right.output.map(_.withNullability(true)) + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override val requiredChildDistribution: Seq[Distribution] = { + Seq( + DelayedOverlappedRangeDistribution(leftOnExpr, 0, true), + DelayedOverlappedRangeDistribution(rightOnExpr, tolerance, false) + ) + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(leftOnExpr.map(SortOrder(_, Ascending)), rightOnExpr.map(SortOrder(_, Ascending))) + + override def outputOrdering: Seq[SortOrder] = Seq(SortOrder(leftOnExpr, Ascending)) + + /** + * Iterates until we are at the last row without going over current key, + * save the last row. + */ + @annotation.tailrec + private def catchUp( + leftOn: Long, + rightIter: BufferedIterator[InternalRow], + rightOnProj: UnsafeProjection, + rightByProj: UnsafeProjection, + lastSeen: JHashMap[UnsafeRow, InternalRow] + ) { + val nextRight = if (rightIter.hasNext) rightIter.head else null + + if (nextRight != null) { + val rightOn = rightOnProj(nextRight).getLong(0) + val rightBy = rightByProj(nextRight) + if (rightOn <= leftOn) { + val row = rightIter.next + // TODO: Can we avoid the copy? + lastSeen.put(rightBy.copy(), row.copy()) + catchUp(leftOn, rightIter, rightOnProj, rightByProj, lastSeen) + } + } + } + + override protected def doExecute(): RDD[InternalRow] = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + + val rightNullRow = new GenericInternalRow(right.output.length) + val joinedRow = new JoinedRow() + + val leftOnProj = UnsafeProjection.create(Seq(leftOnExpr), left.output) + val rightOnProj = UnsafeProjection.create(Seq(rightOnExpr), right.output) + val leftByProj = if (leftByExpr.isEmpty) { + new NullUnsafeProjection() + } else { + UnsafeProjection.create(leftByExpr, left.output) + } + val rightByProj = if (rightByExpr.isEmpty) { + new NullUnsafeProjection() + } else { + UnsafeProjection.create(rightByExpr, right.output) + } + + val resultProj = UnsafeProjection.create(output, output) + val bufferedRightIter = rightIter.buffered + + // TODO: Use unsafe map? 
+ val lastSeen = new JHashMap[UnsafeRow, InternalRow](1024) + + leftIter.map { leftRow => + val leftOn = leftOnProj(leftRow).getLong(0) + catchUp(leftOn, bufferedRightIter, rightOnProj, rightByProj, lastSeen) + + val rightRow = lastSeen.get(leftByProj(leftRow)) + + if (rightRow == null) { + resultProj(joinedRow.withLeft(leftRow).withRight(rightNullRow)) + } else { + val rightOn = rightOnProj(rightRow).getLong(0) + + if (leftOn - tolerance <= rightOn && rightOn <= leftOn) { + resultProj(joinedRow.withLeft(leftRow).withRight(rightRow)) + } else { + resultProj(joinedRow.withLeft(leftRow).withRight(rightNullRow)) + } + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BroadcastAsofJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BroadcastAsofJoinExec.scala new file mode 100644 index 0000000000000..efb63a0b1825c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BroadcastAsofJoinExec.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.{HashMap => JHashMap} + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, GenericInternalRow, JoinedRow, SortOrder, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + +case class BroadcastAsofJoinExec( + buildSide: BuildSide, + leftOnExpr: Expression, + rightOnExpr: Expression, + leftByExpr: Seq[Expression], + rightByExpr: Seq[Expression], + tolerance: Long, + left: SparkPlan, + right: SparkPlan) extends BinaryExecNode { + + override def requiredChildDistribution: Seq[Distribution] = { + buildSide match { + case BuildLeft => + BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil + } + } + + override val output: Seq[Attribute] = left.output ++ right.output.map(_.withNullability(true)) + + override def outputPartitioning: Partitioning = buildSide match { + case BuildLeft => right.outputPartitioning + case BuildRight => left.outputPartitioning + } + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def outputOrdering: Seq[SortOrder] = Seq(SortOrder(leftOnExpr, Ascending)) + + /** + * Iterates until we are at the last row without going over current key, + * save the last row. 
+ */ + @annotation.tailrec + private def catchUp( + leftOn: Long, + rightIter: BufferedIterator[InternalRow], + rightOnProj: UnsafeProjection, + rightByProj: UnsafeProjection, + lastSeen: JHashMap[UnsafeRow, InternalRow] + ) { + val nextRight = if (rightIter.hasNext) rightIter.head else null + + if (nextRight != null) { + val rightOn = rightOnProj(nextRight).getLong(0) + val rightBy = rightByProj(nextRight) + if (rightOn <= leftOn) { + val row = rightIter.next + // TODO: Can we avoid the copy? + lastSeen.put(rightBy.copy(), row.copy()) + catchUp(leftOn, rightIter, rightOnProj, rightByProj, lastSeen) + } + } + } + + override protected def doExecute(): RDD[InternalRow] = { + // left broadcast is hard ... + val broadcastedRelation = right.executeBroadcast[Array[InternalRow]]() + + left.execute().mapPartitionsInternal { leftIter => + val rightIter = broadcastedRelation.value.iterator + val rightNullRow = new GenericInternalRow(right.output.length) + val joinedRow = new JoinedRow() + + val leftOnProj = UnsafeProjection.create(Seq(leftOnExpr), left.output) + val rightOnProj = UnsafeProjection.create(Seq(rightOnExpr), right.output) + val leftByProj = if (leftByExpr.isEmpty) { + new NullUnsafeProjection() + } else { + UnsafeProjection.create(leftByExpr, left.output) + } + val rightByProj = if (rightByExpr.isEmpty) { + new NullUnsafeProjection() + } else { + UnsafeProjection.create(rightByExpr, right.output) + } + + val resultProj = UnsafeProjection.create(output, output) + val bufferedRightIter = rightIter.buffered + + // TODO: Use unsafe map? + val lastSeen = new JHashMap[UnsafeRow, InternalRow](1024) + + leftIter.map { leftRow => + val leftOn = leftOnProj(leftRow).getLong(0) + catchUp(leftOn, bufferedRightIter, rightOnProj, rightByProj, lastSeen) + + val rightRow = lastSeen.get(leftByProj(leftRow)) + + if (rightRow == null) { + resultProj(joinedRow.withLeft(leftRow).withRight(rightNullRow)) + } else { + val rightOn = rightOnProj(rightRow).getLong(0) + + if (leftOn - tolerance <= rightOn && rightOn <= leftOn) { + resultProj(joinedRow.withLeft(leftRow).withRight(rightRow)) + } else { + resultProj(joinedRow.withLeft(leftRow).withRight(rightNullRow)) + } + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 7cae24bf5976c..4278ac18441ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -94,7 +94,8 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { EnsureRequirements(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), ReuseExchange(sparkSession.sessionState.conf), - ReuseSubquery(sparkSession.sessionState.conf)) + ReuseSubquery(sparkSession.sessionState.conf) + ) ++ sparkSession.experimental.extraPreparations protected def stringOrError[A](f: => A): String = try f.toString catch { case e: AnalysisException => e.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 4d95ee34f30de..0ac1a4701cb4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -26,7 +26,7 @@ import 
org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -63,6 +63,13 @@ case class ShuffleExchangeExec( override def outputPartitioning: Partitioning = newPartitioning + // TODO: Verify this + /* override def outputOrdering: Seq[SortOrder] = newPartitioning match { + case OverlappedRangePartitioning(key, _, _, _) => + Seq(SortOrder(key, Ascending)) + case _ => Nil + } */ + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) @@ -205,6 +212,7 @@ object ShuffleExchangeExec { outputAttributes: Seq[Attribute], newPartitioning: Partitioning, serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = { + val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) case HashPartitioning(_, n) => @@ -232,6 +240,23 @@ object ShuffleExchangeExec { override def numPartitions: Int = 1 override def getPartition(key: Any): Int = 0 } + case partitioning @ DelayedOverlappedRangePartitioning(_, _, n, _, core) if core => + val rddForSampling = rdd.mapPartitionsInternal { iter => + val mutablePair = new MutablePair[InternalRow, Null]() + iter.map(row => mutablePair.update(row.copy(), null)) + } + val sortingExpressions = Seq(SortOrder(partitioning.key, Ascending)) + implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes) + new RangePartitioner( + n, + rddForSampling, + ascending = true, + samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) + case DelayedOverlappedRangePartitioning(_, _, n, _, core) if !core => + new Partitioner { + override def numPartitions: Int = n + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } case _ => sys.error(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. 
} @@ -248,6 +273,7 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) case RangePartitioning(_, _) | SinglePartition => identity + case _: DelayedOverlappedRangePartitioning => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } @@ -303,11 +329,31 @@ object ShuffleExchangeExec { val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } } - } else { - newRdd.mapPartitionsInternal { iter => - val getPartitionKey = getPartitionKeyExtractor() - val mutablePair = new MutablePair[Int, InternalRow]() - iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } + } + else { + newPartitioning match { + case partitioning: DelayedOverlappedRangePartitioning if partitioning.core => + + val p = part.asInstanceOf[RangePartitioner[InternalRow, Null]] + val proj = UnsafeProjection.create(Seq(partitioning.key), outputAttributes) + val bounds = p.rangeBounds.map{ row => proj(row).getLong(0)} + partitioning.realizeDelayedRange(bounds) + + newRdd.mapPartitionsInternal { iter => + val getPartitionKey = getPartitionKeyExtractor() + val mutablePair = new MutablePair[Int, InternalRow]() + iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } + } + + case partitioning: DelayedOverlappedRangePartitioning if !partitioning.core => + newRdd.mapPartitionsInternal(iter => + partitioning.withPartitionIds(iter, partitioning.key, outputAttributes)) + case _ => + newRdd.mapPartitionsInternal { iter => + val getPartitionKey = getPartitionKeyExtractor() + val mutablePair = new MutablePair[Int, InternalRow]() + iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d54c02c3d06f..ac79f21205156 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try @@ -30,8 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} -import org.apache.spark.sql.execution.SparkSqlParser +import org.apache.spark.sql.catalyst.plans.logical.{AsofJoin, HintInfo, ResolvedHint} +import org.apache.spark.sql.execution.{LogicalRDD, SparkSqlParser} import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -3683,4 +3684,35 @@ object functions { def callUDF(udfName: String, cols: Column*): Column = withExpr { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } + + def asofJoin( + left: DataFrame, + right: DataFrame, + on: String, + by: String, + tolerance: Long + ): DataFrame = { + val leftOn = left.logicalPlan.output.find(_.name == on).get + val rightOn = right.logicalPlan.output.find(_.name == on).get + val leftBy: 
Seq[Expression] = if (by != null) { + Seq(left.logicalPlan.output.find(_.name == by).get) + } else { + Seq.empty + } + val rightBy: Seq[Expression] = if (by != null) { + Seq(right.logicalPlan.output.find(_.name == by).get) + } else { + Seq.empty + } + Dataset.ofRows( + left.sparkSession, + AsofJoin( + left.logicalPlan, + right.logicalPlan, + leftOn, + rightOn, + leftBy, + rightBy, + tolerance)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AsofJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AsofJoinSuite.scala new file mode 100644 index 0000000000000..0da8598d9aba9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/AsofJoinSuite.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer +import scala.language.existentials +import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.expressions.{Ascending, PythonUDF, SortOrder} +import org.apache.spark.sql.execution.{BinaryExecNode, SortExec} +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.functions.asofJoin +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{LongType, StructField, StructType} + +class AsofJoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + setupTestData() + + private lazy val df1 = spark.createDataFrame(Seq( + (100, 1, 5.0, 3.0), + (200, 2, 10.0, 4.0), + (300, 1, 15.0, 6.0), + (300, 2, 10.0, 5.0), + (400, 2, 8.0, 7.0) + )).toDF("time", "id", "v1", "w1") + + private lazy val df2 = spark.createDataFrame(Seq( + (90, 1, 20.0, 1.0), + (185, 2, 30.0, 2.0), + (250, 1, 40.0, 3.0), + (250, 2, 40.0, 3.0), + (300, 1, 30.0, 5.0) + )).toDF("time", "id", "v2", "w2") + + private lazy val df3 = spark.createDataFrame(Seq( + (90, 1, 20.0, 1.0), + (185, 2, 30.0, 2.0), + (250, 1, 40.0, 3.0), + (300, 1, 50.0, 2.0), + (400, 1, 60.0, 7.0) + )).toDF("time", "id", "v3", "w3") + + test("Asof Join 2 with flatmap groups in pandas") { + import org.apache.spark.sql.functions.{asofJoin, sum, lit, array, explode} + + val df1 = spark.range(0, 1000, 3).toDF("time") + val df2 = spark.range(1, 1000, 3).toDF("time") + val tolerance = 10 + val context = (0L, 1000L) + + val udf = PythonUDF("foo", null, StructType(Seq(StructField("time", LongType))), + Seq(df1("time").expr), + evalType = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, true) + + val result = asofJoin(df1, df2, "time", null, 
tolerance) + .groupBy(df1("time")).flatMapGroupsInPandas(udf) + + result.explain(true) + } + + // TODO: Make this work + test("Asof Join 2 unsorted") { + import org.apache.spark.sql.functions.{asofJoin, sum, lit, array, explode, col, desc} + + val df1 = spark.range(0, 1000, 3).toDF("time").repartition(47) + val df2 = spark.range(1, 1000, 3).toDF("time").repartition(47) + + df1.show() + val tolerance = 10 + val context = (0L, 1000L) + + df1.sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") + + val result = asofJoin(df1, df2, "time", null, tolerance) + result.explain(true) + + result.show() + + // println(result.queryExecution.executedPlan.execute().first()) + } + + test("Asof join 2") { + import org.apache.spark.sql.functions.asofJoin + + val df1 = spark.range(0, 1000, 3).toDF("time") + val df2 = spark.range(1, 1000, 3).toDF("time") + val tolerance = 10 + + val result = asofJoin(df1, df2, "time", null, tolerance) + + result.explain(true) + + result.show(1000) + } + + test("Asof join 2 with key") { + val result = asofJoin(df1, df2, "time", "id", tolerance = 100) + + result.explain(true) + + result.show() + } + + test("Asof Join 3") { + import org.apache.spark.sql.functions.{asofJoin, lit, array, explode} + + val df1 = spark.range(0, 1000, 3).toDF("time") + val df2 = spark.range(1, 1000, 3).toDF("time") + val df3 = spark.range(2, 1000, 3).toDF("time").withColumn("v2", lit(2)) + val tolerance = 10 + + val result = asofJoin( + asofJoin(df1, df2, "time", null, tolerance) + .withColumn("v1", lit(1)), + df3, + "time", + null, + tolerance) + + result.explain(true) + + result.show(1000) + } + + test("column prune") { + // This works automatically! + import org.apache.spark.sql.functions.{asofJoin, lit, sum, array, explode} + + val tolerance = 20 + + val result = asofJoin(df1, df2, "time", null, tolerance).agg(sum(df2("v2"))) + // val result2 = df1.join(df2, "time").agg(sum(df2("v2"))) + + result.explain(true) + } + + test("count") { + import org.apache.spark.sql.functions.{asofJoin, lit, sum, array, explode} + + val tolerance = 20 + + val result = asofJoin(df1, df2, "time", null, tolerance) + + result.groupBy().count().explain(true) + } + + test("right broadcast") { + import org.apache.spark.sql.functions.{asofJoin, broadcast} + + val df1 = spark.range(0, 1000, 3).toDF("time") + val df2 = spark.range(1, 1000, 3).toDF("time") + val tolerance = 10 + + val result = asofJoin(df1, broadcast(df2), "time", null, tolerance) + + result.explain(true) + + result.show(1000) + } + + test("left broadcast (not working)") { + import org.apache.spark.sql.functions.{asofJoin, broadcast} + + val df1 = spark.range(0, 1000, 3).toDF("time") + val df2 = spark.range(1, 1000, 3).toDF("time") + val tolerance = 10 + val context = (0L, 1000L) + + val result = asofJoin(broadcast(df1), df2, "time", null, tolerance) + + result.explain(true) + + result.show(1000) + } + +}
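
Reviewer note (not part of the patch): two standalone Scala sketches follow to make the core ideas above easier to review. The first mirrors the semantics that AsofJoinExec's catchUp loop and lastSeen map implement: for every left row, take the most recent right row per "by" key whose "on" value is at most the left "on", and join it only if it is within the tolerance. Both inputs are assumed sorted ascending by "on"; Row, asofJoin and every other name in the sketch are illustrative, not APIs introduced by the patch.

object AsofJoinSketch {
  case class Row(on: Long, by: Int, value: Double)

  def asofJoin(
      left: Seq[Row],
      right: Seq[Row],
      tolerance: Long): Seq[(Row, Option[Row])] = {
    val rightIter = right.iterator.buffered
    // Last right row seen per "by" key, analogous to the lastSeen JHashMap above.
    val lastSeen = scala.collection.mutable.Map.empty[Int, Row]

    left.map { l =>
      // Advance the right side up to (and including) the current left "on",
      // remembering the latest row for each "by" key (the catchUp step).
      while (rightIter.hasNext && rightIter.head.on <= l.on) {
        val r = rightIter.next()
        lastSeen(r.by) = r
      }
      // rightOn <= leftOn is guaranteed by the advance above, so only the
      // lower end of the tolerance window needs to be checked here.
      val matched = lastSeen.get(l.by).filter(r => l.on - tolerance <= r.on)
      (l, matched)
    }
  }

  def main(args: Array[String]): Unit = {
    val left = Seq(Row(100, 1, 5.0), Row(200, 2, 10.0), Row(300, 1, 15.0))
    val right = Seq(Row(90, 1, 20.0), Row(185, 2, 30.0), Row(250, 1, 40.0))
    asofJoin(left, right, tolerance = 100).foreach(println)
  }
}

The second sketch illustrates the overlapped range shuffle that DelayedOverlappedRangePartitioning.withPartitionIds performs for the non-core (right) side: a row is fanned out to every partition whose range, once its lower bound is extended downwards by the overlap (the tolerance), contains the row's key. Again the bounds, names and types are illustrative; the real code operates on InternalRow via UnsafeProjection and Guava ranges.

object OverlappedRangeSketch {
  // Turn the (n - 1) upper bounds produced by the RangePartitioner into n ranges.
  def ranges(bounds: Seq[Long]): IndexedSeq[(Long, Long)] =
    ((Long.MinValue +: bounds) zip (bounds :+ Long.MaxValue)).toIndexedSeq

  // A key belongs to partition i if it falls in [lower - overlap, upper).
  def partitionIds(key: Long, bounds: Seq[Long], overlap: Long): Seq[Int] =
    ranges(bounds).zipWithIndex.collect {
      case ((lower, upper), i)
          if key >= (if (lower == Long.MinValue) Long.MinValue else lower - overlap) &&
            key < upper => i
    }

  def main(args: Array[String]): Unit = {
    // With bounds at 100 and 200 and an overlap (tolerance) of 10, key 95 is needed
    // by both partition 0 ([MinValue, 100)) and partition 1 ([100 - 10, 200)).
    println(partitionIds(95L, Seq(100L, 200L), overlap = 10L)) // Vector(0, 1)
  }
}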