spark aggregating column into Set efficiently - scala

How can I efficiently aggregate a column into a Set (an array of unique elements) in Spark?
case class Foo(a: String, b: String, c: Int, d: Array[String])
val df = Seq(
  Foo("A", "A", 123, Array("A")),
  Foo("A", "A", 123, Array("B")),
  Foo("B", "B", 123, Array("C", "A")),
  Foo("B", "B", 123, Array("C", "E", "A")),
  Foo("B", "B", 123, Array("D"))
).toDS()
Will result in
+---+---+---+---------+
|  a|  b|  c|        d|
+---+---+---+---------+
|  A|  A|123|      [A]|
|  A|  A|123|      [B]|
|  B|  B|123|   [C, A]|
|  B|  B|123|[C, E, A]|
|  B|  B|123|      [D]|
+---+---+---+---------+
What I am looking for is (the ordering of the d column is not important):
+---+---+---+------------+
|  a|  b|  c|           d|
+---+---+---+------------+
|  A|  A|123|      [A, B]|
|  B|  B|123|[C, A, E, D]|
+---+---+---+------------+
This may be a bit similar to "How to aggregate values into collection after groupBy?" or to the example from High Performance Spark at https://github.com/high-performance-spark/high-performance-spark-examples/blob/57a6267fb77fae5a90109bfd034ae9c18d2edf22/src/main/scala/com/high-performance-spark-examples/transformations/SmartAggregations.scala#L33-L43
Using the following code:
import org.apache.spark.sql.functions.udf
val flatten = udf((xs: Seq[Seq[String]]) => xs.flatten.distinct)
val d = flatten(collect_list($"d")).alias("d")
df.groupBy($"a", $"b", $"c").agg(d).show
will produce the desired result, but I wonder whether there are any possibilities to improve performance using the RDD API, as outlined in the book, and I would also like to know how to formulate it using the Dataset API.
Details about the execution for this minimal sample follow below:
== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
+- Aggregate [a#45, b#46, c#47], [a#45, b#46, c#47, UDF(collect_list(d#48, 0, 0)) AS d#82]
+- LocalRelation [a#45, b#46, c#47, d#48]
== Physical Plan ==
CollectLimit 21
+- SortAggregate(key=[a#45, b#46, c#47], functions=[collect_list(d#48, 0, 0)], output=[a#45, b#46, c#47, d#82])
+- *Sort [a#45 ASC NULLS FIRST, b#46 ASC NULLS FIRST, c#47 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(a#45, b#46, c#47, 200)
+- LocalTableScan [a#45, b#46, c#47, d#48]
Edit
The problems with this operation are outlined very well in https://github.com/awesome-spark/spark-gotchas/blob/master/04_rdd_actions_and_transformations_by_example.md#be-smart-about-groupbykey
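For reference, here is a minimal sketch (my own, not from the linked page) of what an RDD-side aggregation that avoids groupByKey could look like for this example; it relies only on the Foo case class and the df defined above:
import org.apache.spark.rdd.RDD

// Sketch only: reduceByKey performs map-side combining, so full groups are
// never materialized on a single node the way groupByKey would materialize them.
val setsPerKey: RDD[Foo] = df.rdd
  .map(foo => ((foo.a, foo.b, foo.c), foo.d.toSet))
  .reduceByKey(_ ++ _)
  .map { case ((a, b, c), d) => Foo(a, b, c, d.toArray) }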
Edit 2
As you can see, the DAG for the Dataset query suggested below is more complicated, and instead of 0.4 seconds it seems to take 2 seconds.

Try this:
df.groupByKey(foo => (foo.a, foo.b, foo.c))
  .reduceGroups {
    (foo1, foo2) => foo1.copy(d = (foo1.d ++ foo2.d).distinct)
  }
  .map(_._2)
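If you are on Spark 2.4 or later (an assumption; the question may predate it), the built-in collection functions can replace the UDF from the question entirely; a minimal sketch, assuming the same df and spark.implicits._ in scope:
import org.apache.spark.sql.functions.{array_distinct, collect_list, flatten}

// Sketch only: flatten the collected array-of-arrays and deduplicate without a UDF,
// which keeps the whole expression visible to the Catalyst optimizer.
df.groupBy($"a", $"b", $"c")
  .agg(array_distinct(flatten(collect_list($"d"))).alias("d"))
  .show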

Related

Pivot and aggregate a PySpark Data Frame with alias

I have a PySpark DataFrame similar to this:
df = sc.parallelize([
    ("c1", "A", 3.4, 0.4, 3.5),
    ("c1", "B", 9.6, 0.0, 0.0),
    ("c1", "A", 2.8, 0.4, 0.3),
    ("c1", "B", 5.4, 0.2, 0.11),
    ("c2", "A", 0.0, 9.7, 0.3),
    ("c2", "B", 9.6, 8.6, 0.1),
    ("c2", "A", 7.3, 9.1, 7.0),
    ("c2", "B", 0.7, 6.4, 4.3)
]).toDF(["user_id", "type", "d1", "d2", "d3"])
df.show()
which gives:
+-------+----+---+---+----+
|user_id|type| d1| d2| d3|
+-------+----+---+---+----+
| c1| A|3.4|0.4| 3.5|
| c1| B|9.6|0.0| 0.0|
| c1| A|2.8|0.4| 0.3|
| c1| B|5.4|0.2|0.11|
| c2| A|0.0|9.7| 0.3|
| c2| B|9.6|8.6| 0.1|
| c2| A|7.3|9.1| 7.0|
| c2| B|0.7|6.4| 4.3|
+-------+----+---+---+----+
And I've pivoted it by the type column, aggregating the result with sum():
data_wide = df.groupBy('user_id')\
.pivot('type').sum()
data_wide.show()
which gives:
+-------+-----------------+------------------+-----------+------------------+-----------+------------------+
|user_id| A_sum(`d1`)| A_sum(`d2`)|A_sum(`d3`)| B_sum(`d1`)|B_sum(`d2`)| B_sum(`d3`)|
+-------+-----------------+------------------+-----------+------------------+-----------+------------------+
| c1|6.199999999999999| 0.8| 3.8| 15.0| 0.2| 0.11|
| c2| 7.3|18.799999999999997| 7.3|10.299999999999999| 15.0|4.3999999999999995|
+-------+-----------------+------------------+-----------+------------------+-----------+------------------+
Now, the resulting column names contain the ` (backtick) character, and this is a problem when, for example, introducing these new columns into a VectorAssembler, because it returns a syntax error in the attribute name. For this reason I need to rename the columns, but calling the withColumnRenamed method inside a loop or inside a reduce(lambda...) function takes a lot of time (my df actually has 11,520 columns).
Is there any way to avoid this character in the pivot+aggregation step or recursively assign an alias that depends on the name of the new pivoted column?
Thank you in advance
You can do the renaming within the aggregation for the pivot using alias:
import pyspark.sql.functions as f
data_wide = df.groupBy('user_id')\
.pivot('type')\
.agg(*[f.sum(x).alias(x) for x in df.columns if x not in {"user_id", "type"}])
data_wide.show()
#+-------+-----------------+------------------+----+------------------+----+------------------+
#|user_id| A_d1| A_d2|A_d3| B_d1|B_d2| B_d3|
#+-------+-----------------+------------------+----+------------------+----+------------------+
#| c1|6.199999999999999| 0.8| 3.8| 15.0| 0.2| 0.11|
#| c2| 7.3|18.799999999999997| 7.3|10.299999999999999|15.0|4.3999999999999995|
#+-------+-----------------+------------------+----+------------------+----+------------------+
However, this is really no different than doing the pivot and renaming afterwards. Here is the execution plan for this method:
#== Physical Plan ==
#HashAggregate(keys=[user_id#0], functions=[pivotfirst(type#1, sum(`d1`) AS `d1`#169, A, B, 0, 0), pivotfirst(type#1, sum(`d2`)
#AS `d2`#170, A, B, 0, 0), pivotfirst(type#1, sum(`d3`) AS `d3`#171, A, B, 0, 0)])
#+- Exchange hashpartitioning(user_id#0, 200)
# +- HashAggregate(keys=[user_id#0], functions=[partial_pivotfirst(type#1, sum(`d1`) AS `d1`#169, A, B, 0, 0), partial_pivotfirst(type#1, sum(`d2`) AS `d2`#170, A, B, 0, 0), partial_pivotfirst(type#1, sum(`d3`) AS `d3`#171, A, B, 0, 0)])
# +- *HashAggregate(keys=[user_id#0, type#1], functions=[sum(d1#2), sum(d2#3), sum(d3#4)])
# +- Exchange hashpartitioning(user_id#0, type#1, 200)
# +- *HashAggregate(keys=[user_id#0, type#1], functions=[partial_sum(d1#2), partial_sum(d2#3), partial_sum(d3#4)])
# +- Scan ExistingRDD[user_id#0,type#1,d1#2,d2#3,d3#4]
Compare this with the method in this answer:
import re

def clean_names(df):
    p = re.compile(r"^(\w+?)_([a-z]+)\((\w+)\)(?:\(\))?")
    return df.toDF(*[p.sub(r"\1_\3", c) for c in df.columns])
pivoted = df.groupBy('user_id').pivot('type').sum()
clean_names(pivoted).explain()
#== Physical Plan ==
#HashAggregate(keys=[user_id#0], functions=[pivotfirst(type#1, sum(`d1`)#363, A, B, 0, 0), pivotfirst(type#1, sum(`d2`)#364, A, B, 0, 0), pivotfirst(type#1, sum(`d3`)#365, A, B, 0, 0)])
#+- Exchange hashpartitioning(user_id#0, 200)
# +- HashAggregate(keys=[user_id#0], functions=[partial_pivotfirst(type#1, sum(`d1`)#363, A, B, 0, 0), partial_pivotfirst(type#1, sum(`d2`)#364, A, B, 0, 0), partial_pivotfirst(type#1, sum(`d3`)#365, A, B, 0, 0)])
# +- *HashAggregate(keys=[user_id#0, type#1], functions=[sum(d1#2), sum(d2#3), sum(d3#4)])
# +- Exchange hashpartitioning(user_id#0, type#1, 200)
# +- *HashAggregate(keys=[user_id#0, type#1], functions=[partial_sum(d1#2), partial_sum(d2#3), partial_sum(d3#4)])
# +- Scan ExistingRDD[user_id#0,type#1,d1#2,d2#3,d3#4]
You'll see that the two are practically identical. You'll likely have some minuscule speed up by avoiding the regular expression, but it will be negligible compared to the pivot.
I wrote an easy and fast function to rename PySpark pivot table columns. Enjoy! :)
# This function renames a pivot table's ugly default column names
def rename_pivot_cols(rename_df, remove_agg):
    """Change a Spark pivot table's default ugly column names with ease.
    Option 1: remove_agg = True:  `2_sum(sum_amt)` --> `sum_amt_2`
    Option 2: remove_agg = False: `2_sum(sum_amt)` --> `sum_sum_amt_2`
    """
    for column in rename_df.columns:
        if remove_agg:
            start_index = column.find('(')
            end_index = column.find(')')
            if start_index > 0 and end_index > 0:
                rename_df = rename_df.withColumnRenamed(column, column[start_index + 1:end_index] + '_' + column[:1])
        else:
            new_column = column.replace('(', '_').replace(')', '')
            rename_df = rename_df.withColumnRenamed(column, new_column[2:] + '_' + new_column[:1])
    return rename_df

Speed up spark dataframe groupBy

I am fairly inexperienced in Spark, and need help with groupBy and aggregate functions on a dataframe. Consider the following dataframe:
val df = Seq(
  (1, "a", "1"),
  (1, "b", "3"),
  (1, "c", "6"),
  (2, "a", "9"),
  (2, "c", "10"),
  (1, "b", "8"),
  (2, "c", "3"),
  (3, "r", "19")
).toDF("col1", "col2", "col3")
df.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
| 1| a| 1|
| 1| b| 3|
| 1| c| 6|
| 2| a| 9|
| 2| c| 10|
| 1| b| 8|
| 2| c| 3|
| 3| r| 19|
+----+----+----+
I need to group by col1 and by col2 (separately) and calculate the mean of col3 for each, which I can do using:
val col1df = df.groupBy("col1").agg(round(mean("col3"),2).alias("mean_col1"))
val col2df = df.groupBy("col2").agg(round(mean("col3"),2).alias("mean_col2"))
However, on a large DataFrame with a few million rows and tens of thousands of unique elements in the columns to group by, this takes a very long time. Besides, I have many more columns to group by, so it takes insanely long, which I am looking to reduce. Is there a better way to do the groupBy followed by the aggregation?
You could use ideas from Multiple Aggregations; it might do everything in a single shuffle, which is the most expensive operation.
Example:
val df = Seq(
  (1, "a", "1"),
  (1, "b", "3"),
  (1, "c", "6"),
  (2, "a", "9"),
  (2, "c", "10"),
  (1, "b", "8"),
  (2, "c", "3"),
  (3, "r", "19")
).toDF("col1", "col2", "col3")
df.createOrReplaceTempView("data")
val grpRes = spark.sql("""select grouping_id() as gid, col1, col2, round(mean(col3), 2) as res
from data group by col1, col2 grouping sets ((col1), (col2)) """)
grpRes.show(100, false)
Output:
+---+----+----+----+
|gid|col1|col2|res |
+---+----+----+----+
|1 |3 |null|19.0|
|2 |null|b |5.5 |
|2 |null|c |6.33|
|1 |1 |null|4.5 |
|2 |null|a |5.0 |
|1 |2 |null|7.33|
|2 |null|r |19.0|
+---+----+----+----+
gid is a bit funny to use, as it has some binary calculations underneath, but if your grouping columns cannot contain nulls, then you can use it to select the correct groups.
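For example, a minimal sketch (my own, based on the gid values visible in the output above) of splitting the combined result back into the two per-column aggregations:
import org.apache.spark.sql.functions.col

// gid == 1 rows were grouped by col1 only, gid == 2 rows by col2 only (see output above).
val col1df = grpRes.filter(col("gid") === 1).select(col("col1"), col("res").alias("mean_col1"))
val col2df = grpRes.filter(col("gid") === 2).select(col("col2"), col("res").alias("mean_col2"))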
Execution Plan:
scala> grpRes.explain
== Physical Plan ==
*(2) HashAggregate(keys=[col1#111, col2#112, spark_grouping_id#108], functions=[avg(cast(col3#9 as double))])
+- Exchange hashpartitioning(col1#111, col2#112, spark_grouping_id#108, 200)
+- *(1) HashAggregate(keys=[col1#111, col2#112, spark_grouping_id#108], functions=[partial_avg(cast(col3#9 as double))])
+- *(1) Expand [List(col3#9, col1#109, null, 1), List(col3#9, null, col2#110, 2)], [col3#9, col1#111, col2#112, spark_grouping_id#108]
+- LocalTableScan [col3#9, col1#109, col2#110]
As you can see, there is a single Exchange operation, i.e. the expensive shuffle.

Why does joining two Spark DataFrames fail unless I add ".as('alias)" to both?

Assume there are 2 Spark DataFrames we'd like to join, for whatever reason:
val df1 = Seq(("A", 1), ("B", 2), ("C", 3)).toDF("agent", "in_count")
val df2 = Seq(("A", 2), ("C", 2), ("D", 2)).toDF("agent", "out_count")
It can be done with the code like this:
val joinedDf = df1.as('d1).join(df2.as('d2), ($"d1.agent" === $"d2.agent"))
// Result:
joinedDf.show
+-----+--------+-----+---------+
|agent|in_count|agent|out_count|
+-----+--------+-----+---------+
| A| 1| A| 2|
| C| 3| C| 2|
+-----+--------+-----+---------+
Now, what I don't understand is: why does it work only as long as I use the aliases df1.as('d1) and df2.as('d2)? I can imagine that there would be name clashes between the columns if I wrote it bluntly like
val joinedDf = df1.join(df2, ($"df1.agent" === $"df2.agent")) // fails
But... I don't understand why I can't use .as(alias) with only one of the two DataFrames:
df1.as('d1).join(df2, ($"d1.agent" === $"df2.agent")).show()
fails with
org.apache.spark.sql.AnalysisException: cannot resolve '`df2.agent`' given input columns: [agent, in_count, agent, out_count];;
'Join Inner, (agent#25 = 'df2.agent)
:- SubqueryAlias d1
: +- Project [_1#22 AS agent#25, _2#23 AS in_count#26]
: +- LocalRelation [_1#22, _2#23]
+- Project [_1#32 AS agent#35, _2#33 AS out_count#36]
+- LocalRelation [_1#32, _2#33]
Why is the last example invalid?
When you use an alias, the DataFrame is converted into an org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [agent: string, in_count: int], so you can use $"d1.agent" there.
If you want to join the DataFrames directly, without aliases, you can do it like this:
scala> val joinedDf = df1.join(df2, (df1("agent") === df2("agent")))
joinedDf: org.apache.spark.sql.DataFrame = [agent: string, in_count: int ... 2 more fields]
scala> joinedDf.show
+-----+--------+-----+---------+
|agent|in_count|agent|out_count|
+-----+--------+-----+---------+
| A| 1| A| 2|
| C| 3| C| 2|
+-----+--------+-----+---------+
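As a further option, not shown above: when the join key has the same name on both sides, you can join on a sequence of column names, which keeps a single copy of the key column and avoids the ambiguity altogether; a minimal sketch using the same df1 and df2:
// Sketch only: the "agent" column appears once in the result, so no aliases are needed.
val joinedByName = df1.join(df2, Seq("agent"))
joinedByName.show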

How to compute the sum of orders over a 12 months period sliding by 1 month per customer in Spark

I am relatively new to Spark with Scala. Currently I am trying to aggregate order data in Spark over a 12-month period that slides monthly.
Below is a simple sample of my data; I tried to format it so you can easily test it:
import spark.implicits._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
var sample = Seq(
  ("C1", "01/01/2016", 20), ("C1", "02/01/2016", 5),
  ("C1", "03/01/2016", 2), ("C1", "04/01/2016", 3), ("C1", "05/01/2017", 5),
  ("C1", "08/01/2017", 5), ("C1", "01/02/2017", 10), ("C1", "01/02/2017", 10),
  ("C1", "01/03/2017", 10)
).toDF("id", "order_date", "orders")

sample = sample.withColumn("order_date",
  to_date(unix_timestamp($"order_date", "dd/MM/yyyy").cast("timestamp")))
sample.show
+---+----------+------+
| id|order_date|orders|
+---+----------+------+
| C1|2016-01-01| 20|
| C1|2016-01-02| 5|
| C1|2016-01-03| 2|
| C1|2016-01-04| 3|
| C1|2017-01-05| 5|
| C1|2017-01-08| 5|
| C1|2017-02-01| 10|
| C1|2017-02-01| 10|
| C1|2017-03-01| 10|
+---+----------+------+
The outcome imposed on me is the following:
id period_start period_end rolling
C1 2015-01-01 2016-01-01 30
C1 2016-01-01 2017-01-01 40
C1 2016-02-01 2017-02-01 30
C1 2016-03-01 2017-03-01 40
What I have tried to do so far:
I collapsed the dates per customer to the first day of the month
(i.e. 2016-01-[1..31] >> 2016-01-01).
import org.joda.time._

val collapse_month = (month: Integer, year: Integer) => {
  val dt = new DateTime()
    .withYear(year)
    .withMonthOfYear(month)
    .withDayOfMonth(1)
  dt.toString("yyyy-MM-dd")
}
val collapse_month_udf = udf(collapse_month)

sample = sample.withColumn("period_end",
  collapse_month_udf(
    month(col("order_date")),
    year(col("order_date"))
  ).as("date"))
sample.groupBy($"id", $"period_end")
.agg(sum($"orders").as("orders"))
.orderBy("period_end").show
+---+----------+------+
| id|period_end|orders|
+---+----------+------+
| C1|2016-01-01| 30|
| C1|2017-01-01| 10|
| C1|2017-02-01| 20|
| C1|2017-03-01| 10|
+---+----------+------+
I tried the provided window function, but I was not able to use a 12-month window sliding by one month.
I am really not sure what the best way is to proceed from this point without it taking 5 hours, given how much data I have to work with.
Any help would be appreciated.
I tried the provided window function but I was not able to use a 12-month window sliding by one month.
You can still use window with longer intervals, but all parameters have to be expressed in days or weeks:
window($"order_date", "365 days", "28 days")
Unfortunately, window won't respect month or year boundaries, so it won't be that useful for you.
Personally I would aggregate data first:
val byMonth = sample
.groupBy($"id", trunc($"order_date", "month").alias("order_month"))
.agg(sum($"orders").alias("orders"))
+---+-----------+-----------+
| id|order_month|sum(orders)|
+---+-----------+-----------+
| C1| 2017-01-01| 10|
| C1| 2016-01-01| 30|
| C1| 2017-02-01| 20|
| C1| 2017-03-01| 10|
+---+-----------+-----------+
Create reference date range:
import java.time.temporal.ChronoUnit
val Row(start: java.sql.Date, end: java.sql.Date) = byMonth
.select(min($"order_month"), max($"order_month"))
.first
val months = (0L to ChronoUnit.MONTHS.between(
start.toLocalDate, end.toLocalDate))
.map(i => java.sql.Date.valueOf(start.toLocalDate.plusMonths(i)))
.toDF("order_month")
And combine with unique ids:
val ref = byMonth.select($"id").distinct.crossJoin(months)
and join back with the source:
val expanded = ref.join(byMonth, Seq("id", "order_month"), "leftouter")
+---+-----------+------+
| id|order_month|orders|
+---+-----------+------+
| C1| 2016-01-01| 30|
| C1| 2016-02-01| null|
| C1| 2016-03-01| null|
| C1| 2016-04-01| null|
| C1| 2016-05-01| null|
| C1| 2016-06-01| null|
| C1| 2016-07-01| null|
| C1| 2016-08-01| null|
| C1| 2016-09-01| null|
| C1| 2016-10-01| null|
| C1| 2016-11-01| null|
| C1| 2016-12-01| null|
| C1| 2017-01-01| 10|
| C1| 2017-02-01| 20|
| C1| 2017-03-01| 10|
+---+-----------+------+
With data prepared like this you can use window functions:
import org.apache.spark.sql.expressions.Window
val w = Window.partitionBy($"id")
.orderBy($"order_month")
.rowsBetween(-12, Window.currentRow)
expanded.withColumn("rolling", sum("orders").over(w))
.na.drop(Seq("orders"))
.select(
$"order_month" - expr("INTERVAL 12 MONTHS") as "period_start",
$"order_month" as "period_end",
$"rolling")
+------------+----------+-------+
|period_start|period_end|rolling|
+------------+----------+-------+
| 2015-01-01|2016-01-01| 30|
| 2016-01-01|2017-01-01| 40|
| 2016-02-01|2017-02-01| 30|
| 2016-03-01|2017-03-01| 40|
+------------+----------+-------+
Please be advised this is a very expensive operation, requiring at least two shuffles:
== Physical Plan ==
*Project [cast(cast(order_month#104 as timestamp) - interval 1 years as date) AS period_start#1387, order_month#104 AS period_end#1388, rolling#1375L]
+- *Filter AtLeastNNulls(n, orders#55L)
+- Window [sum(orders#55L) windowspecdefinition(id#7, order_month#104 ASC NULLS FIRST, ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS rolling#1375L], [id#7], [order_month#104 ASC NULLS FIRST]
+- *Sort [id#7 ASC NULLS FIRST, order_month#104 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id#7, 200)
+- *Project [id#7, order_month#104, orders#55L]
+- *BroadcastHashJoin [id#7, order_month#104], [id#181, order_month#49], LeftOuter, BuildRight
:- BroadcastNestedLoopJoin BuildRight, Cross
: :- *HashAggregate(keys=[id#7], functions=[])
: : +- Exchange hashpartitioning(id#7, 200)
: : +- *HashAggregate(keys=[id#7], functions=[])
: : +- *HashAggregate(keys=[id#7, trunc(order_date#14, month)#1394], functions=[])
: : +- Exchange hashpartitioning(id#7, trunc(order_date#14, month)#1394, 200)
: : +- *HashAggregate(keys=[id#7, trunc(order_date#14, month) AS trunc(order_date#14, month)#1394], functions=[])
: : +- LocalTableScan [id#7, order_date#14]
: +- BroadcastExchange IdentityBroadcastMode
: +- LocalTableScan [order_month#104]
+- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true], input[1, date, true]))
+- *HashAggregate(keys=[id#181, trunc(order_date#14, month)#1395], functions=[sum(cast(orders#183 as bigint))])
+- Exchange hashpartitioning(id#181, trunc(order_date#14, month)#1395, 200)
+- *HashAggregate(keys=[id#181, trunc(order_date#14, month) AS trunc(order_date#14, month)#1395], functions=[partial_sum(cast(orders#183 as bigint))])
+- LocalTableScan [id#181, order_date#14, orders#183]
It is also possible to express this using a rangeBetween frame, but you have to encode the data first:
val encoded = byMonth
.withColumn("order_month_offset",
// Choose "zero" date appropriate in your scenario
months_between($"order_month", to_date(lit("1970-01-01"))))
val w = Window.partitionBy($"id")
.orderBy($"order_month_offset")
.rangeBetween(-12, Window.currentRow)
encoded.withColumn("rolling", sum($"orders").over(w))
+---+-----------+------+------------------+-------+
| id|order_month|orders|order_month_offset|rolling|
+---+-----------+------+------------------+-------+
| C1| 2016-01-01| 30| 552.0| 30|
| C1| 2017-01-01| 10| 564.0| 40|
| C1| 2017-02-01| 20| 565.0| 30|
| C1| 2017-03-01| 10| 566.0| 40|
+---+-----------+------+------------------+-------+
This would make the join with the reference table obsolete and simplify the execution plan.
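To produce the same period_start/period_end shape as before on top of the rangeBetween variant, the final projection can be reused; a sketch assuming the encoded DataFrame and the window w defined just above (the na.drop step is no longer needed because no reference rows were introduced):
// Sketch only: same projection as in the first variant, applied to the encoded data.
encoded
  .withColumn("rolling", sum($"orders").over(w))
  .select(
    $"order_month" - expr("INTERVAL 12 MONTHS") as "period_start",
    $"order_month" as "period_end",
    $"rolling")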

Self-join not working as expected with the DataFrame API

I am trying to get the latest records from a table using a self-join. It works using spark.sql but does not work using the Spark DataFrame API.
Can anyone help? Is it a bug?
I am using Spark 2.2.0 in local mode
Creating input DataFrame:
scala> val df3 = spark.sparkContext.parallelize(Array((1,"a",1),(1,"aa",2),(2,"b",2),(2,"bb",5))).toDF("id","value","time")
df3: org.apache.spark.sql.DataFrame = [id: int, value: string ... 1 more field]
scala> val df33 = df3
df33: org.apache.spark.sql.DataFrame = [id: int, value: string ... 1 more field]
scala> df3.show
+---+-----+----+
| id|value|time|
+---+-----+----+
| 1| a| 1|
| 1| aa| 2|
| 2| b| 2|
| 2| bb| 5|
+---+-----+----+
scala> df33.show
+---+-----+----+
| id|value|time|
+---+-----+----+
| 1| a| 1|
| 1| aa| 2|
| 2| b| 2|
| 2| bb| 5|
+---+-----+----+
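(Note: for the SQL query below to resolve df3 and df33 as tables, the DataFrames presumably were also registered as temporary views; a minimal sketch of that step, which the question omits:)
scala> df3.createOrReplaceTempView("df3")
scala> df33.createOrReplaceTempView("df33")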
Now performing the join using SQL: works
scala> spark.sql("select df33.* from df3 join df33 on df3.id = df33.id and df3.time < df33.time").show
+---+-----+----+
| id|value|time|
+---+-----+----+
| 1| aa| 2|
| 2| bb| 5|
+---+-----+----+
Now performing the join using the DataFrame API: doesn't work
scala> df3.join(df33, (df3.col("id") === df33.col("id")) && (df3.col("time") < df33.col("time")) ).select(df33.col("id"),df33.col("value"),df33.col("time")).show
+---+-----+----+
| id|value|time|
+---+-----+----+
+---+-----+----+
The thing to notice is the explain plans: the DataFrame API version collapses to an empty LocalTableScan!
scala> df3.join(df33, (df3.col("id") === df33.col("id")) && (df3.col("time") < df33.col("time")) ).select(df33.col("id"),df33.col("value"),df33.col("time")).explain
== Physical Plan ==
LocalTableScan <empty>, [id#150, value#151, time#152]
scala> spark.sql("select df33.* from df3 join df33 on df3.id = df33.id and df3.time < df33.time").explain
== Physical Plan ==
*Project [id#1241, value#1242, time#1243]
+- *SortMergeJoin [id#150], [id#1241], Inner, (time#152 < time#1243)
:- *Sort [id#150 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(id#150, 200)
: +- *Project [_1#146 AS id#150, _3#148 AS time#152]
: +- *SerializeFromObject [assertnotnull(input[0, scala.Tuple3, true])._1 AS _1#146, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString,
assertnotnull(input[0, scala.Tuple3, true])._2, true) AS _2#147, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#148]
: +- Scan ExternalRDDScan[obj#145]
+- *Sort [id#1241 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(id#1241, 200)
+- *Project [_1#146 AS id#1241, _2#147 AS value#1242, _3#148 AS time#1243]
+- *SerializeFromObject [assertnotnull(input[0, scala.Tuple3, true])._1 AS _1#146, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString,
assertnotnull(input[0, scala.Tuple3, true])._2, true) AS _2#147, assertnotnull(input[0, scala.Tuple3, true])._3 AS _3#148]
+- Scan ExternalRDDScan[obj#145]
No, that's not a bug. When you reassign the DataFrame to a new variable, as you have done, Spark copies the lineage but does not duplicate the data, so you end up comparing a column with itself.
Using spark.sql is slightly different because it actually works on aliases of your DataFrames.
So the correct way to perform a self-join using the API is to alias your DataFrame as follows:
val df1 = Seq((1,"a",1),(1,"aa",2),(2,"b",2),(2,"bb",5)).toDF("id","value","time")
df1.as("df1").join(df1.as("df2"), $"df1.id" === $"df2.id" && $"df1.time" < $"df2.time").select($"df2.*").show
// +---+-----+----+
// | id|value|time|
// +---+-----+----+
// | 1| aa| 2|
// | 2| bb| 5|
// +---+-----+----+
For more information about self-joins, I recommend reading High Performance Spark by Holden Karau and Rachel Warren, Chapter 4.
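As a side note, and not part of the answer above: if the end goal is only the latest record per id, a window function avoids the self-join entirely. A minimal sketch, reusing the df1 defined above:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

// Sketch only: rank the records per id by time descending and keep the top one.
val latestPerId = df1
  .withColumn("rn", row_number().over(Window.partitionBy("id").orderBy(col("time").desc)))
  .filter(col("rn") === 1)
  .drop("rn")
latestPerId.show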