Rewrite LogicalPlan to push down a UDF from an Aggregate - Scala

I have defined a UDF named "inc" that increments its input by one. This is the code of my UDF:
spark.udf.register("inc", (x: Long) => x + 1)
This is my test SQL:
val df = spark.sql("select sum(inc(vals)) from data")
df.explain(true)
df.show()
This is the optimized plan for that SQL:
== Optimized Logical Plan ==
Aggregate [sum(inc(vals#4L)) AS sum(inc(vals))#7L]
+- LocalRelation [vals#4L]
I want to rewrite the plan and extract the "inc" out of the "sum", just like the Python UDF extraction does.
This is the optimized plan I want:
Aggregate [sum(inc_val#6L) AS sum(inc(vals))#7L]
+- Project [inc(vals#4L) AS inc_val#6L]
+- LocalRelation [vals#4L]
I have found that the source file "ExtractPythonUDFs.scala" provides a similar transformation, but it works on PythonUDF and inserts a new node named "ArrowEvalPython". This is the logical plan for the Python UDF:
== Optimized Logical Plan ==
Aggregate [sum(pythonUDF0#7L) AS sum(inc(vals))#4L]
+- Project [pythonUDF0#7L]
+- ArrowEvalPython [inc(vals#0L)], [pythonUDF0#7L], 200
+- Repartition 10, true
+- RelationV2[vals#0L] parquet file:/tmp/vals.parquet
What I want to insert is just a Project node; I don't want to define a new node type.
This is the test code of my project:
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

object RewritePlanTest {

  case class UdfRule(spark: SparkSession) extends Rule[LogicalPlan] {

    def collectUDFs(e: Expression): Seq[Expression] = e match {
      case udf: ScalaUDF => Seq(udf)
      case _ => e.children.flatMap(collectUDFs)
    }

    override def apply(plan: LogicalPlan): LogicalPlan = plan match {
      case agg @ Aggregate(g, a, _) if g.isEmpty && a.length == 1 =>
        val udfs = agg.expressions.flatMap(collectUDFs)
        println("================")
        udfs.foreach(println)
        val test = udfs(0).isInstanceOf[NamedExpression]
        println(s"cast ScalaUDF to NamedExpression = ${test}")
        println("================")
        agg
      case _ => plan
    }
  }

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)

    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Rewrite plan test")
      .withExtensions(e => e.injectOptimizerRule(UdfRule))
      .getOrCreate()

    val input = Seq(100L, 200L, 300L)
    import spark.implicits._
    input.toDF("vals").createOrReplaceTempView("data")

    spark.udf.register("inc", (x: Long) => x + 1)
    val df = spark.sql("select sum(inc(vals)) from data")
    df.explain(true)
    df.show()
    spark.stop()
  }
}
I have extracted the ScalaUDF from the Aggregate node, but the argument needed for the Project node is a Seq[NamedExpression]:
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
and a ScalaUDF cannot be cast to NamedExpression, so I have no idea how to construct the Project node.
Can someone give me some advice?
Thanks.

OK, I finally found a way to answer this question.
Although a ScalaUDF cannot be cast to NamedExpression, an Alias can.
So I create an Alias from the ScalaUDF and then construct the Project:
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpectsInputTypes, ExprId, Expression, NamedExpression, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, Project, Subquery}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{AbstractDataType, DataType}

import scala.collection.mutable

object RewritePlanTest {

  case class UdfRule(spark: SparkSession) extends Rule[LogicalPlan] {

    def collectUDFs(e: Expression): Seq[Expression] = e match {
      case udf: ScalaUDF => Seq(udf)
      case _ => e.children.flatMap(collectUDFs)
    }

    override def apply(plan: LogicalPlan): LogicalPlan = plan match {
      case agg @ Aggregate(g, a, c) if g.isEmpty && a.length == 1 =>
        val udfs = agg.expressions.flatMap(collectUDFs)
        if (udfs.isEmpty) {
          agg
        } else {
          // Wrap every ScalaUDF in an Alias so it becomes a NamedExpression,
          // then project the aliased UDFs below the Aggregate.
          val alias_udf = for (i <- 0 until udfs.size) yield Alias(udfs(i), s"udf${i}")()
          val alias_set = mutable.HashMap[Expression, Attribute]()
          val proj = Project(alias_udf, c)
          alias_set ++= udfs.zip(proj.output)
          // Replace each UDF inside the aggregate expressions with the
          // corresponding attribute produced by the new Project.
          val new_agg = agg.withNewChildren(Seq(proj)).transformExpressionsUp {
            case udf: ScalaUDF if alias_set.contains(udf) => alias_set(udf)
          }
          println("====== new agg ======")
          println(new_agg)
          new_agg
        }
      case _ => plan
    }
  }

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)

    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Rewrite plan test")
      .withExtensions(e => e.injectOptimizerRule(UdfRule))
      .getOrCreate()

    val input = Seq(100L, 200L, 300L)
    import spark.implicits._
    input.toDF("vals").createOrReplaceTempView("data")

    spark.udf.register("inc", (x: Long) => x + 1)
    val df = spark.sql("select sum(inc(vals)) from data where vals > 100")
    // val plan = df.queryExecution.analyzed
    // println(plan)
    df.explain(true)
    df.show()
    spark.stop()
  }
}
This code outputs the LogicalPlan that I wanted:
====== new agg ======
Aggregate [sum(udf0#9L) AS sum(inc(vals))#7L]
+- Project [inc(vals#4L) AS udf0#9L]
+- LocalRelation [vals#4L]
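A side note (not part of the original answer): if the SparkSession has already been built and withExtensions is no longer an option, the same rule can be appended to the session's experimental optimizer hooks instead. A minimal sketch, assuming the UdfRule defined above:
// Register the rule on an already-built SparkSession instead of via SparkSessionExtensions.
spark.experimental.extraOptimizations ++= Seq(UdfRule(spark))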

Related

How to aggregate unknown-typed maps into a single map?

Consider the code:
import org.apache.log4j.Logger
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{aggregate, col, map, map_concat}
import org.apache.spark.sql.types.StructType
/**
 * A batch application that reads hard-coded JSON strings and aggregates each
 * category's "compounds" array into a map keyed by compound name.
 */
object MyBatchApp {
lazy val logger: Logger = Logger.getLogger(this.getClass)
val jobName = "MyBatchApp"
def main(args: Array[String]): Unit = {
try {
val spark = SparkSession.builder().appName(jobName).master("local[*]").getOrCreate()
import spark.implicits._
val strings = Seq(
"""{"batch_id":"111111111","id":"111111111","lab_data":{"categories":{"alkaloids":null,"cannabinoids":{"compounds":[{"limit":"","limitrangehigh":"","limitrangelow":"","lod":"","loq":"0.100","max":"","name":"cbd","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"},{"limit":"","limitrangehigh":"","limitrangelow":"","lod":"","loq":"0.100","max":"","name":"cbg","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0.06374552683896600"}]},"dna":null,"flavonoids":null,"foreign_matter":null,"general":null,"homogeneity":null,"metals":{"compounds":[{"limit":"1000","limitrangehigh":"","limitrangelow":"","lod":"166000","loq":"333.000","max":"","name":"lead","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"24.000"},{"limit":"400.0","limitrangehigh":"","limitrangelow":"","lod":"66000","loq":"133.000","max":"","name":"arsenic","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"32.000"}]},"microbials":{"compounds":[{"limit":"100","limitrangehigh":"","limitrangelow":"","lod":"","loq":"100","max":"","name":"ecoli","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"},{"limit":"1","limitrangehigh":"","limitrangelow":"","lod":"","loq":"","max":"","name":"salmonella_spp","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"}]}}}}""",
"""{"batch_id":"222222222","id":"222222222","lab_data":{"categories":{"alkaloids":null,"cannabinoids":{"compounds":[{"limit":"","limitrangehigh":"","limitrangelow":"","lod":"","loq":"0.100","max":"","name":"cbd","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"},{"limit":"","limitrangehigh":"","limitrangelow":"","lod":"","loq":"0.100","max":"","name":"cbg","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0.06374552683896600"}]},"dna":null,"flavonoids":null,"foreign_matter":null,"general":null,"homogeneity":null,"metals":{"compounds":[{"limit":"1000","limitrangehigh":"","limitrangelow":"","lod":"166000","loq":"333.000","max":"","name":"lead","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"24.000"},{"limit":"400.0","limitrangehigh":"","limitrangelow":"","lod":"66000","loq":"133.000","max":"","name":"arsenic","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"32.000"}]},"microbials":{"compounds":[{"limit":"100","limitrangehigh":"","limitrangelow":"","lod":"","loq":"100","max":"","name":"ecoli","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"},{"limit":"1","limitrangehigh":"","limitrangelow":"","lod":"","loq":"","max":"","name":"salmonella_spp","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"}]}}}}"""
)
spark.read.json(strings.toDS).createOrReplaceTempView("sample")
val sourceSchema = spark.sql(s"select lab_data.categories.* from sample").head.schema
val source = spark.sql(s"select * from sample")
val parquet = sourceSchema.fields.filter(f =>
f.dataType.isInstanceOf[StructType] &&
f.dataType.asInstanceOf[StructType].fieldNames.contains("compounds"))
.foldLeft(source)((source, f) => source
.withColumn(s"sample_lab_data_new_categories_${f.name}_compounds",
aggregate(
col(s"sample.lab_data.categories.${f.name}.compounds"),
map(),
// map().cast("map<string,struct<limit:string,limitrangehigh:string,limitrangelow:string,lod:string,loq:string,max:string,regulatornotes:string,rpd:string,rsd:string,spike:string,stdev:string,value:string>>"),
(acc, c) => map_concat(acc, map(c.getField("name"), c.dropFields("name")))))
)
parquet.show()
}
}
}
It fails with the error:
Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve 'aggregate(sample.`lab_data`.`categories`.`cannabinoids`.`compounds`, map(), lambdafunction(map_concat(namedlambdavariable(), map(namedlambdavariable().`name`, update_fields(namedlambdavariable()))), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))' due to data type mismatch: argument 3 requires map<null,null> type, however, 'lambdafunction(map_concat(namedlambdavariable(), map(namedlambdavariable().`name`, update_fields(namedlambdavariable()))), namedlambdavariable(), namedlambdavariable())' is of map<string,struct<limit:string,limitrangehigh:string,limitrangelow:string,lod:string,loq:string,max:string,regulatornotes:string,rpd:string,rsd:string,spike:string,stdev:string,value:string>> type.; Project [batch_id#5, id#6, lab_data#7, aggregate(lab_data#7.categories.cannabinoids.compounds, map(), lambdafunction(map_concat(cast(lambda x_0#45 as map<string,struct<limit:string,limitrangehigh:string,limitrangelow:string,lod:string,loq:string,max:string,regulatornotes:string,rpd:string,rsd:string,spike:string,stdev:string,value:string>>), map(lambda y_1#46.name, update_fields(lambda y_1#46, DropField(name)))), lambda x_0#45, lambda y_1#46, false), lambdafunction(lambda x_2#47, lambda x_2#47, false)) AS sample_lab_data_new_categories_cannabinoids_compounds#43]
+- Project [batch_id#5, id#6, lab_data#7]
   +- SubqueryAlias sample
      +- LogicalRDD [batch_id#5, id#6, lab_data#7], false
But when I uncomment the map().cast line, it works. I wonder: is it possible to use map() without explicit typing? For example, in situations when additional fields are added.
You can get the schema of the inner structs of the compounds arrays, then use it to cast the initial map value for the aggregate function:
//...
.foldLeft(source)((source, f) => {
val arrayOfStructs = col(s"sample.lab_data.categories.${f.name}.compounds")
val structSchema = source.select(arrayOfStructs(0)).schema.map(_.dataType.simpleString).mkString
source.withColumn(
s"sample_lab_data_new_categories_${f.name}_compounds",
aggregate(
arrayOfStructs,
map().cast(s"map<string,$structSchema>"),
(acc, c) =>
map_concat(acc, map(c.getField("name"), c.dropFields("name")))
)
)
}
)
//...
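A side note of my own, not part of the original answer: the analyzer error above says that argument 3 requires map&lt;null,null&gt; because that is the type of a bare map() literal, which is why the initial value must be cast before aggregate accepts the lambda's typed result. You can confirm this with a one-liner:
// My illustration: the empty map() literal is typed map<null,null> until it is cast.
spark.range(1).select(map().as("empty_map")).printSchema()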

Assert RDD is not sorted

I have a method called split that accepts an RDD[T] and a splitSize and returns an Array[RDD[T]].
Now, one of the test cases I'm writing should verify that this function also randomly shuffles the RDD.
So I create a sorted RDD, and then see the results:
it should "randomize shuffle" in {
val inputRDD = sc.parallelize((0 until 16))
val result = RDDUtils.split(inputRDD, 2)
result.foreach(rdd => {
rdd.collect.foreach(println)
})
// Assert result is not sorted
}
If the results are:
0
1
2
3
..
15
Then it's not working as expected.
A good result can be something like:
11
3
9
14
...
1
6
How can I assert that the output Array[RDD[T]] is not sorted?
You could try something like this:
val resultOrder = result.sortBy(....)
assert(!resultOrder.sameElements(result))
or
val resultOrder = result.sortBy(....)
assert(!resultOrder.toList == result.toList)
It's important to note that the key is knowing how to sort the Array. For an Integer data type it's easy, but for a complex data type you may need an implicit Ordering for your type, e.g.:
implicit val ordering: Ordering[T] =
Ordering.fromLessThan[T]((sa: T, sb: T) => sa < sb)
// OR
implicit val ordering: Ordering[MyClass] =
Ordering.fromLessThan[MyClass]((sa: MyClass, sb: MyClass) => sa.field1 < sb.field1)
The exact code depends on your data type.
As a full example:
package tests
import org.apache.log4j.{Level, Logger}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
object SortArrayRDD {
val spark = SparkSession
.builder()
.appName("SortArrayRDD")
.master("local[*]")
.config("spark.sql.shuffle.partitions","4") //Change to a more reasonable default number of partitions for our data
.config("spark.app.id","SortArrayRDD") // To silence Metrics warning
.getOrCreate()
val sc = spark.sparkContext
def main(args: Array[String]): Unit = {
try {
Logger.getRootLogger.setLevel(Level.ERROR)
val arrRDD: Array[RDD[Int]] = Array(sc.parallelize(List(2,3)),sc.parallelize(List(10,11)),sc.parallelize(List(6,7)),sc.parallelize(List(8,9)),
sc.parallelize(List(4,5)),sc.parallelize(List(0,1)),sc.parallelize(List(12,13)),sc.parallelize(List(14,15)))
val aux = arrRDD
implicit val ordering: Ordering[RDD[Int]] = Ordering.fromLessThan[RDD[Int]]((sa: RDD[Int], sb: RDD[Int]) => sa.sum() < sb.sum())
aux.sorted.foreach(rdd => println(rdd.collect().mkString(",")))
val resultOrder = aux.sorted
assert(!resultOrder.sameElements(arrRDD))
println("It's unordered")
} finally {
sc.stop()
}
}
}
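A variation on the same idea (my own sketch, not from the original answer): instead of ordering the RDD handles, collect and flatten the contents and compare them against their sorted form:
// The split is considered shuffled if the concatenated contents are not already in ascending order.
val flattened: Array[Int] = arrRDD.flatMap(_.collect())
assert(!(flattened sameElements flattened.sorted))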

Spark collect_list and limit resulting list

I have a dataframe of the following format:
name merged
key1 (internalKey1, value1)
key1 (internalKey2, value2)
...
key2 (internalKey3, value3)
...
What I want to do is group the dataframe by the name, collect the list and limit the size of the list.
This is how I group by the name and collect the list:
val res = df.groupBy("name")
.agg(collect_list(col("merged")).as("final"))
The resulting dataframe is something like:
key1 [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list
key2 [(internalKey3, value3),...]
What I want to do is limit the size of the produced lists for each key. I've tried multiple ways to do that but had no success. I've already seen some posts that suggest third-party solutions, but I want to avoid that. Is there a way?
While a UDF does what you need, if you're looking for a more performant way that is also memory-sensitive, the way to do this is to write a UDAF. Unfortunately, the UDAF API is not as extensible as the aggregate functions that ship with Spark. However, you can build on Spark's internal APIs to do what you need.
Here is an implementation of collect_list_limit that is mostly a copy-paste of Spark's internal CollectList AggregateFunction. I would just extend it, but it's a case class. Really, all that's needed is to override the update and merge methods to respect a passed-in limit:
case class CollectListLimit(
child: Expression,
limitExp: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
val limit = limitExp.eval( null ).asInstanceOf[Int]
def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
if( buffer.size < limit ) super.update(buffer, input)
else buffer
}
override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
if( buffer.size >= limit ) buffer
else if( other.size >= limit ) other
else ( buffer ++= other ).take( limit )
}
override def prettyName: String = "collect_list_limit"
}
And to actually register it, we can use Spark's internal FunctionRegistry, which takes the name and a builder, which is effectively a function that creates a CollectListLimit from the provided expressions:
val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )
Edit:
It turns out that adding it to the builtin registry only works if you haven't created the SparkContext yet, as an immutable clone is made on startup. If you have an existing context, then this should work to add it via reflection:
val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )
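Assuming the class compiles against your Spark version's internal Collect API and the registration above succeeds, the function can then be called from SQL like any built-in aggregate. A hypothetical usage, where "data" is an assumed temp view name:
// Hypothetical usage of the registered aggregate.
df.createOrReplaceTempView("data")
spark.sql("select name, collect_list_limit(merged, 3) as final from data group by name").show(false)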
You can create a function that limits the size of the aggregated ArrayType column as shown below:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column
case class KV(k: String, v: String)
val df = Seq(
("key1", KV("internalKey1", "value1")),
("key1", KV("internalKey2", "value2")),
("key2", KV("internalKey3", "value3")),
("key2", KV("internalKey4", "value4")),
("key2", KV("internalKey5", "value5"))
).toDF("name", "merged")
def limitSize(n: Int, arrCol: Column): Column =
array( (0 until n).map( arrCol.getItem ): _* )
df.
groupBy("name").agg( collect_list(col("merged")).as("final") ).
select( $"name", limitSize(2, $"final").as("final2") ).
show(false)
// +----+----------------------------------------------+
// |name|final2 |
// +----+----------------------------------------------+
// |key1|[[internalKey1,value1], [internalKey2,value2]]|
// |key2|[[internalKey3,value3], [internalKey4,value4]]|
// +----+----------------------------------------------+
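Two caveats worth noting (my addition, not part of the original answer): getItem pads the result with nulls when a group has fewer than n elements, and on Spark 2.4+ the built-in slice function achieves the same truncation directly:
import org.apache.spark.sql.functions.{col, collect_list, slice}
// Spark 2.4+ alternative: slice(column, start = 1, length = 2) truncates without padding nulls.
df.groupBy("name")
  .agg(collect_list(col("merged")).as("final"))
  .select(col("name"), slice(col("final"), 1, 2).as("final2"))
  .show(false)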
You can use a UDF.
Here is a possible example that doesn't require a schema and performs a meaningful reduction:
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import scala.collection.mutable
object TestJob1 {
def main (args: Array[String]): Unit = {
val sparkSession = SparkSession
.builder()
.appName(this.getClass.getName.replace("$", ""))
.master("local")
.getOrCreate()
val sc = sparkSession.sparkContext
import sparkSession.sqlContext.implicits._
val rawDf = Seq(
("key", 1L, "gargamel"),
("key", 4L, "pe_gadol"),
("key", 2L, "zaam"),
("key1", 5L, "naval")
).toDF("group", "quality", "other")
rawDf.show(false)
rawDf.printSchema
val rawSchema = rawDf.schema
val fUdf = udf(reduceByQuality, rawSchema)
val aggDf = rawDf
.groupBy("group")
.agg(
count(struct("*")).as("num_reads"),
max(col("quality")).as("quality"),
collect_list(struct("*")).as("horizontal")
)
.withColumn("short", fUdf($"horizontal"))
.drop("horizontal")
aggDf.printSchema
aggDf.show(false)
}
def reduceByQuality= (x: Any) => {
val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
val red = d.reduce((r1, r2) => {
val quality1 = r1.getAs[Long]("quality")
val quality2 = r2.getAs[Long]("quality")
val r3 = quality1 match {
case a if a >= quality2 =>
r1
case _ =>
r2
}
r3
})
red
}
}
Here is an example with data like yours:
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.functions._
import scala.collection.mutable
object TestJob {
def main (args: Array[String]): Unit = {
val sparkSession = SparkSession
.builder()
.appName(this.getClass.getName.replace("$", ""))
.master("local")
.getOrCreate()
val sc = sparkSession.sparkContext
import sparkSession.sqlContext.implicits._
val df1 = Seq(
("key1", ("internalKey1", "value1")),
("key1", ("internalKey2", "value2")),
("key2", ("internalKey3", "value3")),
("key2", ("internalKey4", "value4")),
("key2", ("internalKey5", "value5"))
)
.toDF("name", "merged")
// df1.printSchema
//
// df1.show(false)
val res = df1
.groupBy("name")
.agg( collect_list(col("merged")).as("final") )
res.printSchema
res.show(false)
def f= (x: Any) => {
val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
val d1 = d.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]].head
d1.toString
}
val fUdf = udf(f, StringType)
val d2 = res
.withColumn("d", fUdf(col("final")))
.drop("final")
d2.printSchema()
d2
.show(false)
}
}
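If you need the first n elements per group rather than just the head, a small adaptation of the same idea (my sketch, not from the original answer), using a UDF that receives the collected rows as Seq[Row]:
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, udf}
// Keep at most two of the collected rows per group, stringified for display.
val takeTwo = udf((xs: Seq[Row]) => xs.take(2).map(_.toString))
val limited = res.withColumn("final2", takeTwo(col("final"))).drop("final")
limited.show(false)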

Scala: Product with Serializable does not take parameters

My objective is to read data from a CSV file and convert my RDD to a DataFrame in Scala/Spark. This is my code:
package xxx.DataScience.CompensationStudy
import org.apache.spark._
import org.apache.log4j._
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types._
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
object CompensationAnalysis {
case class GetDF(profil_date:String, profil_pays:String, param_tarif2:String, param_tarif3:String, dt_titre:String, dt_langues:String,
dt_diplomes:String, dt_experience:String, dt_formation:String, dt_outils:String, comp_applications:String,
comp_interventions:String, comp_competence:String)
def main(args: Array[String]) {
Logger.getLogger("org").setLevel(Level.ERROR)
val conf = new SparkConf().setAppName("CompensationAnalysis ")
val sc = new SparkContext(conf)
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.implicits._
val lines = sc.textFile("C:/Users/../Downloads/CompensationStudy.csv").flatMap { l =>
l.split(",") match {
case field: Array[String] if field.size > 13 => Some(field(0), field(1), field(2), field(3), field(4), field(5), field(6), field(7), field(8), field(9), field(10), field(11), field(12))
case field: Array[String] if field.size == 1 => Some((field(0), "default value"))
case _ => None
}
}
At this stage, I get the error: Product with Serializable does not take parameters
val summary = lines.collect().map(x => GetDF(x("profil_date"), x("profil_pays"), x("param_tarif2"), x("param_tarif3"), x("dt_titre"), x("dt_langues"), x("dt_diplomes"), x("dt_experience"), x("dt_formation"), x("dt_outils"), x("comp_applications"), x("comp_interventions"), x("comp_competence")))
val sum_df = summary.toDF()
df.printSchema
}
}
Can anyone help, please?
You have several things you should improve. The most urgent problem, which causes the exception, is, as @CyrilleCorpet points out, "the three different lines in the pattern matching return values of types Some[Tuple13], Some[Tuple2] and None.type. The least-upper-bound is then Option[Product with Serializable], which complies with flatMap's signature (where the result should be an Iterable[T]) modulo some implicit conversion."
Basically, if you had Some[Tuple13], Some[Tuple13], and None or Some[Tuple2], Some[Tuple2], and None, you would be better off.
Also, pattern matching on types is generally a bad idea because of type erasure, and pattern matching isn't even great anyway for your situation.
So you could set default values in your case class:
case class GetDF(profile_date: String,
profile_pays: String = "default",
param_tarif2: String = "default",
...
)
Then in your lambda:
val tokens = l.split(",")
if (tokens.length > 13) {
  Some(GetDF(tokens(0), tokens(1), tokens(2), ...))
} else if (tokens.length == 1) {
  Some(GetDF(tokens(0)))
} else {
  None
}
Now in all cases you are returning Option[GetDF]. You can flatMap the RDD to get rid of all the Nones and keep only GetDF instances.
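Putting it together, a minimal sketch under the assumptions above (the default-valued GetDF from this answer and the original file path from the question):
// Each line now yields an Option[GetDF], so flatMap drops the Nones and
// leaves an RDD[GetDF] that converts cleanly to a DataFrame.
val records = sc.textFile("C:/Users/../Downloads/CompensationStudy.csv").flatMap { l =>
  val tokens = l.split(",")
  if (tokens.length > 13) Some(GetDF(tokens(0), tokens(1), tokens(2))) // remaining fields keep their defaults
  else if (tokens.length == 1) Some(GetDF(tokens(0)))
  else None
}
val sum_df = records.toDF()
sum_df.printSchema()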

spark map partitions to fill nan values

I want to fill NaN values in Spark using the last known good observation - see: Spark / Scala: fill nan with last good observation
My current solution uses window functions to accomplish the task, but this is not great, as all values end up in a single partition. I thought
val imputed: RDD[FooBar] = recordsDF.rdd.mapPartitionsWithIndex { case (i, iter) => fill(i, iter) }
should work a lot better. But strangely, my fill function is not executed. What is wrong with my code?
+----------+--------------------+
| foo| bar|
+----------+--------------------+
|2016-01-01| first|
|2016-01-02| second|
| null| noValidFormat|
|2016-01-04|lastAssumingSameDate|
+----------+--------------------+
Here is the full example code:
import java.sql.Date
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
case class FooBar(foo: Date, bar: String)
object WindowFunctionExample extends App {
Logger.getLogger("org").setLevel(Level.WARN)
val conf: SparkConf = new SparkConf()
.setAppName("foo")
.setMaster("local[*]")
val spark: SparkSession = SparkSession
.builder()
.config(conf)
.enableHiveSupport()
.getOrCreate()
import spark.implicits._
val myDff = Seq(("2016-01-01", "first"), ("2016-01-02", "second"),
("2016-wrongFormat", "noValidFormat"),
("2016-01-04", "lastAssumingSameDate"))
val recordsDF = myDff
.toDF("foo", "bar")
.withColumn("foo", 'foo.cast("Date"))
.as[FooBar]
recordsDF.show
def notMissing(row: FooBar): Boolean = {
row.foo != null
}
val toCarry = recordsDF.rdd.mapPartitionsWithIndex { case (i, iter) => Iterator((i, iter.filter(notMissing(_)).toSeq.lastOption)) }.collectAsMap
println("###################### carry ")
println(toCarry)
println(toCarry.foreach(println))
println("###################### carry ")
val toCarryBd = spark.sparkContext.broadcast(toCarry)
def fill(i: Int, iter: Iterator[FooBar]): Iterator[FooBar] = {
var lastNotNullRow: FooBar = toCarryBd.value(i).get
iter.map(row => {
if (!notMissing(row)) {
  FooBar(lastNotNullRow.foo, row.bar)
} else {
lastNotNullRow = row
row
}
})
}
// The algorithm does not step into the for loop for filling the null values. Strange
val imputed: RDD[FooBar] = recordsDF.rdd.mapPartitionsWithIndex { case (i, iter) => fill(i, iter) }
val imputedDF = imputed.toDS()
println(imputedDF.orderBy($"foo").collect.toList)
imputedDF.show
spark.stop
}
Edit:
I fixed the code as outlined by the comment, but toCarryBd contains None values. How can this happen, given that I filter explicitly with
def notMissing(row: FooBar): Boolean = { row.foo != null }
iter.filter(notMissing(_)).toSeq.lastOption
for non-None values?
(2,None)
(5,None)
(4,None)
(7,Some(FooBar(2016-01-04,lastAssumingSameDate)))
(1,Some(FooBar(2016-01-01,first)))
(3,Some(FooBar(2016-01-02,second)))
(6,None)
(0,None)
This leads to NoSuchElementException: None.get when trying to access toCarryBd.
Firstly, if your foo field can be null, I would recommend creating the case class as:
case class FooBar(foo: Option[Date], bar: String)
Then, you can rewrite your notMissing function to something like:
def notMissing(row: Option[FooBar]): Boolean = row.isDefined && row.get.foo.isDefined
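Building on that, a sketch of my own (not part of the original answer) of how the carry map and the fill function could work entirely with Options, which also covers the None entries produced by empty partitions:
// Assumes FooBar(foo: Option[Date], bar: String) and that toCarryBd holds the
// last defined row per partition (or None for an empty partition).
def fill(i: Int, iter: Iterator[FooBar]): Iterator[FooBar] = {
  var lastNotNull: Option[FooBar] = toCarryBd.value(i)
  iter.map { row =>
    if (row.foo.isDefined) { lastNotNull = Some(row); row }
    else lastNotNull.map(last => FooBar(last.foo, row.bar)).getOrElse(row)
  }
}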