Convert a Spark 2.2 UDAF to a Spark 3.0 Aggregator - Scala

I have an already-written UDAF in Scala, built against Spark 2.4. Our Databricks cluster was on the 6.4 runtime, which is no longer supported, so we need to move to 7.3 LTS, which has long-term support and uses Spark 3. UserDefinedAggregateFunction is deprecated in Spark 3 and will most likely be removed in the future, so I am trying to convert the UDAF into an Aggregator:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, DataType}

object MaxCampaignIdAggregator extends UserDefinedAggregateFunction with java.io.Serializable {

  override def inputSchema: StructType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  def bufferSchema: StructType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  // Returned Data Type.
  def dataType: DataType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  // Self-explaining
  def deterministic: Boolean = true

  // This function is called whenever key changes
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = null
    buffer(1) = null
  }

  // Iterate over each entry of a group
  def update(buffer: MutableAggregationBuffer, inputRow: Row): Unit = {
    val inputId = inputRow.getAs[Int](0)
    val actualInputId = inputRow.get(0)
    val inputName = inputRow.getString(1)
    val bufferId = buffer.getAs[Int](0)
    val actualBufferId = buffer.get(0)
    val bufferName = buffer.getString(1)
    if (actualBufferId == null) {
      buffer(0) = actualInputId
      buffer(1) = inputName
    } else if (actualInputId != null) {
      if (inputId > bufferId) {
        buffer(0) = inputId
        buffer(1) = inputName
      }
    }
  }

  // Merge two partial aggregates
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    val buffer1Id = buffer1.getAs[Int](0)
    val actualbuffer1Id = buffer1.get(0)
    val buffer1Name = buffer1.getString(1)
    val buffer2Id = buffer2.getAs[Int](0)
    val actualbuffer2Id = buffer2.get(0)
    val buffer2Name = buffer2.getString(1)
    if (actualbuffer1Id == null) {
      buffer1(0) = actualbuffer2Id
      buffer1(1) = buffer2Name
    } else if (actualbuffer2Id != null) {
      if (buffer2Id > buffer1Id) {
        buffer1(0) = buffer2Id
        buffer1(1) = buffer2Name
      }
    }
  }

  // Called after all the entries are exhausted.
  def evaluate(buffer: Row): Any = {
    Row(buffer.get(0), buffer.getString(1))
  }
}
After usage this gives output such as:
{"id": 1282, "name": "McCormick Christmas"}
{"id": 1305, "name": "McCormick Perfect Pinch"}
{"id": 1677, "name": "Viking Cruises Viking Cruises"}

Related

How to port UDAF to Aggregator?

I have a DF looking like this:
time,channel,value
0,foo,5
0,bar,23
100,foo,42
...
I want a DF like this:
time,foo,bar
0,5,23
100,42,...
In Spark 2, I did it with a UDAF like this:
case class ColumnBuilderUDAF(channels: Seq[String]) extends UserDefinedAggregateFunction {
  @transient lazy val inputSchema: StructType = StructType {
    StructField("channel", StringType, nullable = false) ::
      StructField("value", DoubleType, nullable = false) ::
      Nil
  }

  @transient lazy val bufferSchema: StructType = StructType {
    channels
      .toList
      .indices
      .map(i => StructField("c%d".format(i), DoubleType, nullable = false))
  }

  @transient lazy val dataType: DataType = bufferSchema
  @transient lazy val deterministic: Boolean = false

  def initialize(buffer: MutableAggregationBuffer): Unit =
    channels.indices.foreach(buffer(_) = NaN)

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val channel = input.getAs[String](0)
    val p = channels.indexOf(channel)
    if (p >= 0 && p < channels.length) {
      val v = input.getAs[Double](1)
      if (!v.isNaN) {
        buffer(p) = v
      }
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    channels
      .indices
      .foreach { i =>
        val v2 = buffer2.getAs[Double](i)
        if ((!v2.isNaN) && buffer1.getAs[Double](i).isNaN) {
          buffer1(i) = v2
        }
      }

  def evaluate(buffer: Row): Any =
    new GenericRowWithSchema(channels.indices.map(buffer.getAs[Double]).toArray, dataType.asInstanceOf[StructType])
}
which I use like this:
val cb = ColumnBuilderUDAF(Seq("foo", "bar"))
val dfColumnar = df.groupBy($"time").agg(cb($"channel", $"value") as "c")
and then, I rename c.c0, c.c1 etc. to foo, bar etc.
In Spark 3, UDAF is deprecated and Aggregator should be used instead. So I began to port it like this:
case class ColumnBuilder(channels: Seq[String]) extends Aggregator[(String, Double), Array[Double], Row] {
  lazy val bufferEncoder: Encoder[Array[Double]] = Encoders.javaSerialization[Array[Double]]

  lazy val zero: Array[Double] = channels.map(_ => Double.NaN).toArray

  def reduce(b: Array[Double], a: (String, Double)): Array[Double] = {
    val index = channels.indexOf(a._1)
    if (index >= 0 && !a._2.isNaN) b(index) = a._2
    b
  }

  def merge(b1: Array[Double], b2: Array[Double]): Array[Double] = {
    (0 until b1.length.min(b2.length)).foreach(i => if (b1(i).isNaN) b1(i) = b2(i))
    b1
  }

  def finish(reduction: Array[Double]): Row =
    new GenericRowWithSchema(reduction.map(x => x: Any), outputEncoder.schema)

  def outputEncoder: Encoder[Row] = ??? // what goes here?
}
I don't know how to implement the Encoder[Row] as Spark does not have a pre-defined one. If I simply do a straightforward approach like this:
val outputEncoder: Encoder[Row] = new Encoder[Row] {
  val schema: StructType = StructType(channels.map(StructField(_, DoubleType, nullable = false)))
  val clsTag: ClassTag[Row] = classTag[Row]
}
I get a ClassCastException because outputEncoder actually has to be an ExpressionEncoder.
So, how do I implement this correctly? Or do I still have to use the deprecated UDAF?
You can do it with groupBy and pivot:
import spark.implicits._
import org.apache.spark.sql.functions._

val df = Seq(
  (0, "foo", 5),
  (0, "bar", 23),
  (100, "foo", 42)
).toDF("time", "channel", "value")

df.groupBy("time")
  .pivot("channel")
  .agg(first("value"))
  .show(false)
Output:
+----+----+---+
|time|bar |foo|
+----+----+---+
|100 |null|42 |
|0 |23 |5 |
+----+----+---+
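If you specifically want to keep the Aggregator from the question, one possible way to get the missing Encoder[Row] is to build it from the schema with Spark's RowEncoder. Note this is only a sketch: RowEncoder lives in a catalyst-internal package, so it is not a stable public API, and the helper name below is mine:
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Builds the ExpressionEncoder[Row] the Aggregator needs from the channel list
// (channels is the same Seq[String] as in ColumnBuilder above).
def rowEncoderFor(channels: Seq[String]): Encoder[Row] =
  RowEncoder(StructType(channels.map(StructField(_, DoubleType, nullable = false))))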

Spark sessionization using data frames

I want to do clickstream sessionization on a Spark data frame. Let's say I have loaded a data frame that has events from multiple sessions, with the following schema -
And I want to aggregate (stitch) the sessions, like this -
I have explored UDAFs and Window functions but could not understand how I can use them for this specific use case. I know that partitioning the data by session id puts an entire session's data in a single partition, but how do I aggregate it?
The idea is to aggregate all the events specific to each session as a single output record.
You can use collect_set:
def process(implicit spark: SparkSession) = {
  import spark._
  import org.apache.spark.sql.Row
  import org.apache.spark.sql.functions.{concat, col, collect_set}
  import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

  val seq = Seq(Row(1, 1, "startTime=1549270909"), Row(1, 1, "endTime=1549270913"))
  val rdd = spark.sparkContext.parallelize(seq)
  val df1 = spark.createDataFrame(rdd, StructType(List(
    StructField("sessionId", IntegerType, false),
    StructField("userId", IntegerType, false),
    StructField("session", StringType, false))))

  df1.groupBy("sessionId").agg(collect_set("session"))
}
That gives you:
+---------+------------------------------------------+
|sessionId|collect_set(session) |
+---------+------------------------------------------+
|1 |[startTime=1549270909, endTime=1549270913]|
+---------+------------------------------------------+
as output.
If you need more complex logic, it can be encapsulated in a UDAF like the following:
class YourComplexLogicStrings extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(StructField("input", StringType) :: Nil)
  override def bufferSchema: StructType = StructType(StructField("pair", StringType) :: Nil)
  override def dataType: DataType = StringType
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = ""

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val b = buffer.getAs[String](0)
    val i = input.getAs[String](0)
    buffer(0) = { if (b.isEmpty) b + i else b + " + " + i }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val b1 = buffer1.getAs[String](0)
    val b2 = buffer2.getAs[String](0)
    if (!b1.isEmpty)
      buffer1(0) = (b1) ++ "," ++ (b2)
    else
      buffer1(0) = b2
  }

  override def evaluate(buffer: Row): Any = {
    val yourString = buffer.getAs[String](0)
    // Compute your logic and return another String
    yourString
  }
}
def process0(implicit spark: SparkSession) = {
  import org.apache.spark.sql.Row
  import org.apache.spark.sql.functions.{concat, col, collect_set}
  import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

  val agg0 = new YourComplexLogicStrings()
  val seq = Seq(Row(1, 1, "startTime=1549270909"), Row(1, 1, "endTime=1549270913"))
  val rdd = spark.sparkContext.parallelize(seq)
  val df1 = spark.createDataFrame(rdd, StructType(List(
    StructField("sessionId", IntegerType, false),
    StructField("userId", IntegerType, false),
    StructField("session", StringType, false))))

  df1.groupBy("sessionId").agg(agg0(col("session")))
}
It gives:
+---------+---------------------------------------+
|sessionId|yourcomplexlogicstrings(session) |
+---------+---------------------------------------+
|1 |startTime=1549270909,endTime=1549270913|
+---------+---------------------------------------+
Note that you can express quite complex logic using Spark SQL functions directly if you want to avoid UDAFs.
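For instance, the concatenation done by the UDAF above can be sketched with built-in functions only (assuming the df1 from the previous snippet; the variable name stitched is mine):
import org.apache.spark.sql.functions.{array_join, array_sort, collect_list, col}

// Collect the session events, sort them, and join them into one string,
// without any UDAF.
val stitched = df1
  .groupBy("sessionId")
  .agg(array_join(array_sort(collect_list(col("session"))), ",").as("session"))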

Computing Quartiles over Windowed Dataframe

I have some data, for the sake of discussion take it to be given by:
val schema = Seq("id", "day", "value")
val data = Seq(
(1, 1, 1),
(1, 2, 11),
(1, 3, 1),
(1, 4, 11),
(1, 5, 1),
(1, 6, 11),
(2, 1, 1),
(2, 2, 11),
(2, 3, 1),
(2, 4, 11),
(2, 5, 1),
(2, 6, 11)
)
val df = sc.parallelize(data).toDF(schema: _*)
I would like to compute quartiles for each ID over a moving window of days. Something like
val w = Window.partitionBy("id").orderBy("day").rangeBetween(-2, 0)
df.select(col("id"),col("day"),collect_list(col("value")).over(w),quartiles(col("value")).over(w).as("Quartiles"))
Of course there isn't a built-in quartiles function for this, so I need to write a UserDefinedAggregateFunction. The following is a simple (albeit non-scalable) solution, based on an existing CollectionFunction example:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class QuartilesFunction extends UserDefinedAggregateFunction {
  def inputSchema: StructType =
    StructType(StructField("value", DoubleType, false) :: Nil)

  def bufferSchema: StructType = StructType(
    StructField("lower", ArrayType(DoubleType, true), true) ::
      StructField("upper", ArrayType(DoubleType, true), true) :: Nil)

  override def dataType: DataType = ArrayType(DoubleType, true)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = IndexedSeq[Double]()
    buffer(1) = IndexedSeq[Double]()
  }

  def rebalance(lower: IndexedSeq[Double], upper: IndexedSeq[Double]) = {
    (lower ++ upper).splitAt((lower.length + upper.length) / 2)
  }

  def sorted_median(x: IndexedSeq[Double]): Option[Double] = {
    if (x.length == 0) {
      None
    }
    val N = x.length
    val (lower, upper) = x.splitAt(N / 2)
    Some(
      if (N % 2 == 0) {
        (lower.last + upper.head) / 2.0
      } else {
        upper.head
      }
    )
  }

  // this is how to update the buffer given an input
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val lower = buffer(0).asInstanceOf[IndexedSeq[Double]]
    val upper = buffer(1).asInstanceOf[IndexedSeq[Double]]
    val value = input.getAs[Double](0)
    if (lower.length == 0) {
      buffer(0) = Array(value)
    } else {
      if (value >= lower.last) {
        buffer(1) = (value +: upper).sortWith(_ < _)
      } else {
        buffer(0) = (lower :+ value).sortWith(_ < _)
      }
    }
    val (result0, result1) = rebalance(buffer(0).asInstanceOf[IndexedSeq[Double]], buffer(1).asInstanceOf[IndexedSeq[Double]])
    buffer(0) = result0
    buffer(1) = result1
  }

  // this is how to merge two objects with the buffer schema type
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1(0).asInstanceOf[IndexedSeq[Double]] ++ buffer2(0).asInstanceOf[IndexedSeq[Double]]
    buffer1(1) = buffer1(1).asInstanceOf[IndexedSeq[Double]] ++ buffer2(1).asInstanceOf[IndexedSeq[Double]]
    val (result0, result1) = rebalance(buffer1(0).asInstanceOf[IndexedSeq[Double]], buffer1(1).asInstanceOf[IndexedSeq[Double]])
    buffer1(0) = result0
    buffer1(1) = result1
  }

  def evaluate(buffer: Row): Array[Option[Double]] = {
    val lower =
      if (buffer(0) == null) {
        IndexedSeq[Double]()
      } else {
        buffer(0).asInstanceOf[IndexedSeq[Double]]
      }
    val upper =
      if (buffer(1) == null) {
        IndexedSeq[Double]()
      } else {
        buffer(1).asInstanceOf[IndexedSeq[Double]]
      }
    val Q1 = sorted_median(lower)
    val Q2 = if (upper.length == 0) { None } else { Some(upper.head) }
    val Q3 = sorted_median(upper)
    Array(Q1, Q2, Q3)
  }
}
However, executing the following produces an error:
val quartiles = new QuartilesFunction
df.select('*).show
val w = org.apache.spark.sql.expressions.Window.partitionBy("id").orderBy("day").rangeBetween(-2, 0)
val x = df.select(col("id"),col("day"),collect_list(col("value")).over(w),quartiles(col("value")).over(w).as("Quantiles"))
x.show
The error is:
org.apache.spark.SparkException: Task not serializable
The offending function seems to be sorted_median. If I replace the code with:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class QuartilesFunction extends UserDefinedAggregateFunction {
  def inputSchema: StructType =
    StructType(StructField("value", DoubleType, false) :: Nil)

  def bufferSchema: StructType = StructType(
    StructField("lower", ArrayType(DoubleType, true), true) ::
      StructField("upper", ArrayType(DoubleType, true), true) :: Nil)

  override def dataType: DataType = ArrayType(DoubleType, true)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = IndexedSeq[Double]()
    buffer(1) = IndexedSeq[Double]()
  }

  def rebalance(lower: IndexedSeq[Double], upper: IndexedSeq[Double]) = {
    (lower ++ upper).splitAt((lower.length + upper.length) / 2)
  }

  /*
  def sorted_median(x: IndexedSeq[Double]): Option[Double] = {
    if (x.length == 0) {
      None
    }
    val N = x.length
    val (lower, upper) = x.splitAt(N / 2)
    Some(
      if (N % 2 == 0) {
        (lower.last + upper.head) / 2.0
      } else {
        upper.head
      }
    )
  }
  */

  // this is how to update the buffer given an input
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val lower = buffer(0).asInstanceOf[IndexedSeq[Double]]
    val upper = buffer(1).asInstanceOf[IndexedSeq[Double]]
    val value = input.getAs[Double](0)
    if (lower.length == 0) {
      buffer(0) = Array(value)
    } else {
      if (value >= lower.last) {
        buffer(1) = (value +: upper).sortWith(_ < _)
      } else {
        buffer(0) = (lower :+ value).sortWith(_ < _)
      }
    }
    val (result0, result1) = rebalance(buffer(0).asInstanceOf[IndexedSeq[Double]], buffer(1).asInstanceOf[IndexedSeq[Double]])
    buffer(0) = result0
    buffer(1) = result1
  }

  // this is how to merge two objects with the buffer schema type
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1(0).asInstanceOf[IndexedSeq[Double]] ++ buffer2(0).asInstanceOf[IndexedSeq[Double]]
    buffer1(1) = buffer1(1).asInstanceOf[IndexedSeq[Double]] ++ buffer2(1).asInstanceOf[IndexedSeq[Double]]
    val (result0, result1) = rebalance(buffer1(0).asInstanceOf[IndexedSeq[Double]], buffer1(1).asInstanceOf[IndexedSeq[Double]])
    buffer1(0) = result0
    buffer1(1) = result1
  }

  def evaluate(buffer: Row): Array[Option[Double]] = {
    val lower =
      if (buffer(0) == null) {
        IndexedSeq[Double]()
      } else {
        buffer(0).asInstanceOf[IndexedSeq[Double]]
      }
    val upper =
      if (buffer(1) == null) {
        IndexedSeq[Double]()
      } else {
        buffer(1).asInstanceOf[IndexedSeq[Double]]
      }
    val Q1 = Some(1.0) // sorted_median(lower)
    val Q2 = Some(2.0) // if(upper.length==0) { None } else { Some(upper.head) }
    val Q3 = Some(3.0) // sorted_median(upper)
    Array(Q1, Q2, Q3)
  }
}
Then everything works, except that it doesn't compute quartiles (obviously). I don't understand the error and the rest of the stacktrace isn't any more illuminating. Could someone help me to understand what the problem is and/or how to compute these quartiles?
If you have a Hive context (or Hive support enabled) you can use the percentile UDAF as follows:
val dfQuartiles = df.select(
  col("id"),
  col("day"),
  collect_list(col("value")).over(w).as("values"),
  callUDF("percentile", col("value"), lit(0.25)).over(w).as("Q1"),
  callUDF("percentile", col("value"), lit(0.50)).over(w).as("Q2"),
  callUDF("percentile", col("value"), lit(0.75)).over(w).as("Q3"),
  callUDF("percentile", col("value"), lit(1.0)).over(w).as("Q4")
)
Alternatively, you can use a UDF to calculate the quartiles from values (as you have this array anyway):
val calcPercentile = udf((xs: Seq[Int], percentile: Double) => {
  val ss = xs.toSeq.sorted
  val index = ((ss.size - 1) * percentile).toInt
  ss(index)
})

val dfQuartiles = df.select(
  col("id"),
  col("day"),
  collect_list(col("value")).over(w).as("values")
)
  .withColumn("Q1", calcPercentile($"values", lit(0.25)))
  .withColumn("Q2", calcPercentile($"values", lit(0.50)))
  .withColumn("Q3", calcPercentile($"values", lit(0.75)))
  .withColumn("Q4", calcPercentile($"values", lit(1.00)))

How to add/mutate a Map object in MutableAggregationBuffer in UDAF?

I use Spark 2.0.1 and Scala 2.11.
This is a question related to user-defined aggregate function (UDAF) in Spark. I'm using the example answer provided here to ask my question:
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Row, Column}

object DummyUDAF extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", StringType)

  def bufferSchema = new StructType()
    .add("buff", ArrayType(LongType))
    .add("buff2", ArrayType(DoubleType))

  def dataType = new StructType()
    .add("xs", ArrayType(LongType))
    .add("ys", ArrayType(DoubleType))

  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = {}
  def update(buffer: MutableAggregationBuffer, input: Row) = {}
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {}
  def evaluate(buffer: Row) = (Array(1L, 2L, 3L), Array(1.0, 2.0, 3.0))
}
I'm able to return multiple Maps instead of an Array easily, but not able to mutate the map in the update method.
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Row, Column}
import scala.collection.mutable.Map

object DummyUDAF extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", DoubleType).add("y", IntegerType)

  def bufferSchema = new StructType()
    .add("buff", MapType(DoubleType, IntegerType))
    .add("buff2", MapType(DoubleType, IntegerType))

  def dataType = new StructType()
    .add("xs", MapType(DoubleType, IntegerType))
    .add("ys", MapType(DoubleType, IntegerType))

  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = scala.collection.mutable.Map[Double, Int]()
    buffer(1) = scala.collection.mutable.Map[Double, Int]()
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0).asInstanceOf[Map[Double, Int]](input.getDouble(0)) = input.getInt(1)
    buffer(1).asInstanceOf[Map[Double, Int]](input.getDouble(0) * 10) = input.getInt(1) * 10
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1(0).asInstanceOf[Map[Double, Int]] ++= buffer2(0).asInstanceOf[Map[Double, Int]]
    buffer1(1).asInstanceOf[Map[Double, Int]] ++= buffer2(1).asInstanceOf[Map[Double, Int]]
  }

  //def evaluate(buffer: Row) = (Map(1.0->10,2.0->20), Map(10.0->100,11.0->110))
  def evaluate(buffer: Row) = (buffer(0).asInstanceOf[Map[Double, Int]], buffer(1).asInstanceOf[Map[Double, Int]])
}
This compiles fine, but gives a runtime error:
val df = Seq((1.0, 1), (2.0, 2)).toDF("k", "v")
df.select(DummyUDAF($"k", $"v")).show(1, false)
org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 70.0 failed 4 times, most recent failure: Lost task 1.3 in stage 70.0 (TID 204, 10.91.252.25): java.lang.ClassCastException: scala.collection.immutable.Map$EmptyMap$ cannot be cast to scala.collection.mutable.Map
Another solution discussed here indicates that this could be a problem due to the MapType in the StructType. However, when I try out the solution mentioned there, I still get the same error.
val distudaf = new DistinctValues
val df = Seq(("a", "a1"), ("a", "a1"), ("a", "a2"), ("b", "b1"), ("b", "b2"), ("b", "b3"), ("b", "b1"), ("b", "b1")).toDF("col1", "col2")
df.groupBy("col1").agg(distudaf($"col2").as("DV")).show
org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 22.0 failed 4 times, most recent failure: Lost task 1.3 in stage 22.0 (TID 100, 10.91.252.25): java.lang.ClassCastException: scala.collection.immutable.Map$EmptyMap$ cannot be cast to scala.collection.mutable.Map
My preference would be to mutate the Map, given that I expect the Map to be huge; making a copy and re-assigning may lead to performance/memory bottlenecks.
My limited understanding of UDAFs is that you should only set what you want to be (semantically) updated, i.e. take what is already set in the MutableAggregationBuffer, combine it with what you want to add, and assign it back with = (which will call update(i: Int, value: Any): Unit under the covers).
Your code could look as follows:
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  val newBuffer0 = buffer(0).asInstanceOf[Map[Double, Int]]
  buffer(0) = newBuffer0 + (input.getDouble(0) -> input.getInt(1))

  val newBuffer1 = buffer(1).asInstanceOf[Map[Double, Int]]
  buffer(1) = newBuffer1 + (input.getDouble(0) * 10 -> input.getInt(1) * 10)
}
The complete DummyUDAF could be as follows:
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Row, Column}

object DummyUDAF extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", DoubleType).add("y", IntegerType)

  def bufferSchema = new StructType()
    .add("buff", MapType(DoubleType, IntegerType))
    .add("buff2", MapType(DoubleType, IntegerType))

  def dataType = new StructType()
    .add("xs", MapType(DoubleType, IntegerType))
    .add("ys", MapType(DoubleType, IntegerType))

  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = Map[Double, Int]()
    buffer(1) = Map[Double, Int]()
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val newBuffer0 = buffer(0).asInstanceOf[Map[Double, Int]]
    buffer(0) = newBuffer0 + (input.getDouble(0) -> input.getInt(1))

    val newBuffer1 = buffer(1).asInstanceOf[Map[Double, Int]]
    buffer(1) = newBuffer1 + (input.getDouble(0) * 10 -> input.getInt(1) * 10)
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1(0) = buffer1(0).asInstanceOf[Map[Double, Int]] ++ buffer2(0).asInstanceOf[Map[Double, Int]]
    buffer1(1) = buffer1(1).asInstanceOf[Map[Double, Int]] ++ buffer2(1).asInstanceOf[Map[Double, Int]]
  }

  //def evaluate(buffer: Row) = (Map(1.0->10,2.0->20), Map(10.0->100,11.0->110))
  def evaluate(buffer: Row) = (buffer(0).asInstanceOf[Map[Double, Int]], buffer(1).asInstanceOf[Map[Double, Int]])
}
Late to the party. I just discovered that one can use
override def bufferSchema: StructType = StructType(List(
  StructField("map", ObjectType(classOf[mutable.Map[String, Long]]))
))
to use mutable.Map in a buffer.
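A hypothetical sketch of how an update method might then look with such a buffer; this leans entirely on the claim above that the ObjectType slot holds the map as a plain JVM object, and the counting logic is only an illustration:
import scala.collection.mutable
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.MutableAggregationBuffer

def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  // Pull the map out of the ObjectType slot, mutate it, and write it back.
  val counts = buffer.getAs[mutable.Map[String, Long]](0)
  val key = input.getString(0)
  counts(key) = counts.getOrElse(key, 0L) + 1L
  buffer(0) = counts
}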

How to properly save a Spark RDD result to a MySQL database

I currently do the following to save a Spark RDD result into a MySQL database.
import anorm._
import java.sql.Connection
import org.apache.spark.rdd.RDD

val wordCounts: RDD[(String, Int)] = ...

def getDbConnection(dbUrl: String): Connection = {
  Class.forName("com.mysql.jdbc.Driver").newInstance()
  java.sql.DriverManager.getConnection(dbUrl)
}

def using[X <: {def close()}, A](resource: X)(f: X => A): A =
  try {
    f(resource)
  } finally {
    resource.close()
  }

wordCounts.foreachPartition(iter => {
  using(getDbConnection(dbUrl)) { implicit conn =>
    iter.foreach { case (word, count) =>
      SQL"insert into WordCount VALUES($word, $count)".executeUpdate()
    }
  }
})
Is there a better way?
I tried the following, but it is very slow compared to the first approach:
val sqlContext = new SQLContext(sc)
val wordCountSchema = StructType(List(
  StructField("word", StringType, nullable = false),
  StructField("count", IntegerType, nullable = false)))
val wordCountRowRDD = wordCounts.map(p => org.apache.spark.sql.Row(p._1, p._2))
val wordCountDF = sqlContext.createDataFrame(wordCountRowRDD, wordCountSchema)
wordCountDF.registerTempTable("WordCount")
wordCountDF.write.mode("overwrite").jdbc(dbUrl, "WordCount", new java.util.Properties())
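One common way to speed up the manual approach is to batch the inserts per partition instead of issuing one statement per row. A minimal sketch with plain JDBC, reusing wordCounts and dbUrl from the first snippet; the table and column names are assumed to match the WordCount table above, and the batch size is arbitrary:
import java.sql.DriverManager

// Hypothetical batch size; tune for your table and MySQL configuration.
val batchSize = 1000

wordCounts.foreachPartition { iter =>
  val conn = DriverManager.getConnection(dbUrl)
  try {
    conn.setAutoCommit(false)
    val stmt = conn.prepareStatement("INSERT INTO WordCount (word, count) VALUES (?, ?)")
    var n = 0
    iter.foreach { case (word, count) =>
      stmt.setString(1, word)
      stmt.setInt(2, count)
      stmt.addBatch()
      n += 1
      if (n % batchSize == 0) stmt.executeBatch()
    }
    stmt.executeBatch() // flush the remainder
    conn.commit()
  } finally {
    conn.close()
  }
}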