Suppose I have these case classes:
case class Employee(id: Long, proj_id: Long, office_id: Long, salary: Long)
case class Renumeration(id: Long, amount: Long)
I'd like to update a collection of Employee based on Renumeration using Spark:
val right: Dataset[Renumeration] = ???
val left: Dataset[Employee] = ???
left.joinWith(broadcast(right),left("proj_id") === right("id"),"leftouter")
.map { case(left,right) => updateProj(left,right) }
.joinWith(broadcast(right),left("office_id") === right("id"),"leftouter")
.map { case(left,right) => updateOffice(left,right) }
def updateProj(emp: Employee, ren: Renumeration): Employee = ??? // business logic
def updateOffice(emp: Employee, ren: Renumeration): Employee = ??? // business logic
The first join and map work; however, when I add the second join, Spark fails to resolve the column and throws this instead:
org.apache.spark.sql.AnalysisException: Resolved attribute(s) office_id#42L missing from id#114L,salary#117L,id#34L,amount#35L,proj_id#115L,office_id#116L in operator !Join LeftOuter, (office_id#42L = id#34L). Attribute(s) with the same name appear in the operation: office_id. Please check if the right attribute(s) are used.;;
!Join LeftOuter, (office_id#42L = id#34L)
:- SerializeFromObject [assertnotnull(assertnotnull(input[0, Employee, true])).id AS id#114L, assertnotnull(assertnotnull(input[0, Employee, true])).proj_id AS proj_id#115L, assertnotnull(assertnotnull(input[0, Employee, true])).office_id AS office_id#116L, assertnotnull(assertnotnull(input[0, Employee, true])).salary AS salary#117L]
: +- MapElements <function1>, class scala.Tuple2, [StructField(_1,StructType(StructField(id,LongType,false), StructField(proj_id,LongType,false), StructField(office_id,LongType,false), StructField(salary,LongType,false)),true), StructField(_2,StructType(StructField(id,LongType,false), StructField(amount,LongType,false)),true)], obj#113: Employee
: +- DeserializeToObject newInstance(class scala.Tuple2), obj#112: scala.Tuple2
: +- Join LeftOuter, (_1#103.proj_id = _2#104.id)
: :- Project [named_struct(id, id#40L, proj_id, proj_id#41L, office_id, office_id#42L, salary, salary#43L) AS _1#103]
: : +- LocalRelation <empty>, [id#40L, proj_id#41L, office_id#42L, salary#43L]
: +- Project [named_struct(id, id#34L, amount, amount#35L) AS _2#104]
: +- ResolvedHint (broadcast)
: +- LocalRelation <empty>, [id#34L, amount#35L]
+- ResolvedHint (broadcast)
+- LocalRelation <empty>, [id#34L, amount#35L]
Any idea why Spark cannot resolve the column even though I am using a typed Dataset? And what should I do to make this work, if possible?
The error occurs because the attribute reference returned by left("office_id") no longer exists in the newly projected dataset (i.e. the dataset produced by the first join and map).
If you look closely at the nested relation in the execution plan:
: +- LocalRelation <empty>, [id#40L, proj_id#41L, office_id#42L, salary#43L]
you can see that the reference to office_id in the left dataset is office_id#42L. However, further up the plan this reference no longer exists in the projection:
SerializeFromObject [assertnotnull(assertnotnull(input[0, Employee, true])).id AS id#114L, assertnotnull(assertnotnull(input[0, Employee, true])).proj_id AS proj_id#115L, assertnotnull(assertnotnull(input[0, Employee, true])).office_id AS office_id#116L, assertnotnull(assertnotnull(input[0, Employee, true])).salary AS salary#117L]
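where the only office_id reference available is office_id#116L. The map step serializes the Employee objects back into brand-new attributes, so a condition built from left("office_id") points at an attribute that the second join's input no longer has.
To resolve this, you can bind the intermediate result to a variable and build the second join condition from it, e.g.: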
val right: Dataset[Renumeration] = ???
val left: Dataset[Employee] = ???
val leftTemp = left.joinWith(broadcast(right),left("proj_id") === right("id"),"leftouter")
.map { case(left,right) => updateProj(left,right) }
val leftFinal = leftTemp.joinWith(broadcast(right),leftTemp("office_id") === right("id"),"leftouter")
.map { case(left,right) => updateOffice(left,right) }
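or you could use the $"office_id" shorthand (it needs import spark.implicits._), i.e. $"office_id" === right("id"), which is an unresolved column that Spark resolves by name against the second join's actual input, e.g.: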
left.joinWith(broadcast(right),left("proj_id") === right("id"),"leftouter")
.map { case(left,right) => updateProj(left,right) }
.joinWith(broadcast(right),$"office_id" === right("id"),"leftouter")
.map { case(left,right) => updateOffice(left,right) }
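A self-contained sketch of the intermediate-variable fix that can be run in local mode. The bodies of updateProj and updateOffice below are hypothetical stand-ins for the real business logic (they just add the matched amount to the salary), and the object name and sample rows are made up. One thing worth handling: with a "leftouter" joinWith, the right-hand object is null for rows that found no match.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.functions.broadcast

case class Employee(id: Long, proj_id: Long, office_id: Long, salary: Long)
case class Renumeration(id: Long, amount: Long)

object TwoJoinsExample {
  // Hypothetical business logic: add the matched amount (if any) to the salary.
  // In a "leftouter" joinWith the right-hand value is null when nothing matched.
  def updateProj(emp: Employee, ren: Renumeration): Employee =
    if (ren == null) emp else emp.copy(salary = emp.salary + ren.amount)

  def updateOffice(emp: Employee, ren: Renumeration): Employee =
    if (ren == null) emp else emp.copy(salary = emp.salary + ren.amount)

  def run(): Seq[Employee] = {
    val spark = SparkSession.builder().master("local[1]").appName("two-joins").getOrCreate()
    import spark.implicits._

    // Made-up sample data
    val left: Dataset[Employee] =
      Seq(Employee(1, 10, 20, 100), Employee(2, 11, 99, 100)).toDS()
    val right: Dataset[Renumeration] =
      Seq(Renumeration(10, 5), Renumeration(20, 7)).toDS()

    // First join + map: the map re-serializes Employee into new attributes
    val afterProj: Dataset[Employee] = left
      .joinWith(broadcast(right), left("proj_id") === right("id"), "leftouter")
      .map { case (e, r) => updateProj(e, r) }

    // Second join: office_id is taken from afterProj, not from left
    val afterOffice: Dataset[Employee] = afterProj
      .joinWith(broadcast(right), afterProj("office_id") === right("id"), "leftouter")
      .map { case (e, r) => updateOffice(e, r) }

    afterOffice.collect().sortBy(_.id).toSeq
  }

  def main(args: Array[String]): Unit = run().foreach(println)
}
```

The $"office_id" variant should behave the same way, since an unresolved name is looked up in the second join's own input rather than pinned to an attribute of left.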
Let me know if this works for you.
Consider the code:
import org.apache.log4j.Logger
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{aggregate, col, map, map_concat}
import org.apache.spark.sql.types.StructType
/**
* A batch application that parses hard-coded JSON lab results and, for each category
* that has a compounds array, builds a map from compound name to the remaining fields.
*/
object MyBatchApp {
lazy val logger: Logger = Logger.getLogger(this.getClass)
val jobName = "MyBatchApp"
def main(args: Array[String]): Unit = {
try {
val spark = SparkSession.builder().appName(jobName).master("local[*]").getOrCreate()
import spark.implicits._
val strings = Seq(
"""{"batch_id":"111111111","id":"111111111","lab_data":{"categories":{"alkaloids":null,"cannabinoids":{"compounds":[{"limit":"","limitrangehigh":"","limitrangelow":"","lod":"","loq":"0.100","max":"","name":"cbd","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"},{"limit":"","limitrangehigh":"","limitrangelow":"","lod":"","loq":"0.100","max":"","name":"cbg","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0.06374552683896600"}]},"dna":null,"flavonoids":null,"foreign_matter":null,"general":null,"homogeneity":null,"metals":{"compounds":[{"limit":"1000","limitrangehigh":"","limitrangelow":"","lod":"166000","loq":"333.000","max":"","name":"lead","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"24.000"},{"limit":"400.0","limitrangehigh":"","limitrangelow":"","lod":"66000","loq":"133.000","max":"","name":"arsenic","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"32.000"}]},"microbials":{"compounds":[{"limit":"100","limitrangehigh":"","limitrangelow":"","lod":"","loq":"100","max":"","name":"ecoli","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"},{"limit":"1","limitrangehigh":"","limitrangelow":"","lod":"","loq":"","max":"","name":"salmonella_spp","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"}]}}}}""",
"""{"batch_id":"222222222","id":"222222222","lab_data":{"categories":{"alkaloids":null,"cannabinoids":{"compounds":[{"limit":"","limitrangehigh":"","limitrangelow":"","lod":"","loq":"0.100","max":"","name":"cbd","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"},{"limit":"","limitrangehigh":"","limitrangelow":"","lod":"","loq":"0.100","max":"","name":"cbg","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0.06374552683896600"}]},"dna":null,"flavonoids":null,"foreign_matter":null,"general":null,"homogeneity":null,"metals":{"compounds":[{"limit":"1000","limitrangehigh":"","limitrangelow":"","lod":"166000","loq":"333.000","max":"","name":"lead","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"24.000"},{"limit":"400.0","limitrangehigh":"","limitrangelow":"","lod":"66000","loq":"133.000","max":"","name":"arsenic","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"32.000"}]},"microbials":{"compounds":[{"limit":"100","limitrangehigh":"","limitrangelow":"","lod":"","loq":"100","max":"","name":"ecoli","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"},{"limit":"1","limitrangehigh":"","limitrangelow":"","lod":"","loq":"","max":"","name":"salmonella_spp","regulatornotes":null,"rpd":"","rsd":"","spike":"","stdev":"","value":"0"}]}}}}"""
)
spark.read.json(strings.toDS).createOrReplaceTempView("sample")
val sourceSchema = spark.sql(s"select lab_data.categories.* from sample").head.schema
val source = spark.sql(s"select * from sample")
val parquet = sourceSchema.fields.filter(f =>
f.dataType.isInstanceOf[StructType] &&
f.dataType.asInstanceOf[StructType].fieldNames.contains("compounds"))
.foldLeft(source)((source, f) => source
.withColumn(s"sample_lab_data_new_categories_${f.name}_compounds",
aggregate(
col(s"sample.lab_data.categories.${f.name}.compounds"),
map(),
// map().cast("map<string,struct<limit:string,limitrangehigh:string,limitrangelow:string,lod:string,loq:string,max:string,regulatornotes:string,rpd:string,rsd:string,spike:string,stdev:string,value:string>>"),
(acc, c) => map_concat(acc, map(c.getField("name"), c.dropFields("name")))))
)
parquet.show()
}
}
}
It fails with this error:
Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve 'aggregate(sample.`lab_data`.`categories`.`cannabinoids`.`compounds`, map(), lambdafunction(map_concat(namedlambdavariable(), map(namedlambdavariable().`name`, update_fields(namedlambdavariable()))), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))' due to data type mismatch: argument 3 requires map<null,null> type, however, 'lambdafunction(map_concat(namedlambdavariable(), map(namedlambdavariable().`name`, update_fields(namedlambdavariable()))), namedlambdavariable(), namedlambdavariable())' is of map<string,struct<limit:string,limitrangehigh:string,limitrangelow:string,lod:string,loq:string,max:string,regulatornotes:string,rpd:string,rsd:string,spike:string,stdev:string,value:string>> type.; Project [batch_id#5, id#6, lab_data#7, aggregate(lab_data#7.categories.cannabinoids.compounds, map(), lambdafunction(map_concat(cast(lambda x_0#45 as map<string,struct<limit:string,limitrangehigh:string,limitrangelow:string,lod:string,loq:string,max:string,regulatornotes:string,rpd:string,rsd:string,spike:string,stdev:string,value:string>>), map(lambda y_1#46.name, update_fields(lambda y_1#46, DropField(name)))), lambda x_0#45, lambda y_1#46, false), lambdafunction(lambda x_2#47, lambda x_2#47, false)) AS sample_lab_data_new_categories_cannabinoids_compounds#43]
+- Project [batch_id#5, id#6, lab_data#7] +- SubqueryAlias sample
+- LogicalRDD [batch_id#5, id#6, lab_data#7], false
But when I uncomment the map().cast(...) line, it works. Is it possible to use map() without explicit typing, for example so that the code keeps working when additional fields are added?
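You can read the type of the inner structs of the compounds arrays from the DataFrame schema and use it to cast the initial map value of the aggregate function. Because the map values are built with c.dropFields("name"), the name field has to be removed from that type as well; otherwise map_concat receives two maps with different value types. This needs import org.apache.spark.sql.types.{ArrayType, MapType, StringType, StructType}:
//...
.foldLeft(source)((source, f) => {
val arrayOfStructs = col(s"sample.lab_data.categories.${f.name}.compounds")
// Element type of the compounds array, read from the schema (no rows are fetched)
val elementType = source.select(arrayOfStructs).schema.fields.head.dataType
.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]
// The map values are the structs without "name", so drop it from the value type too
val valueType = new StructType(elementType.fields.filterNot(_.name == "name"))
source.withColumn(
s"sample_lab_data_new_categories_${f.name}_compounds",
aggregate(
arrayOfStructs,
map().cast(MapType(StringType, valueType)),
(acc, c) =>
map_concat(acc, map(c.getField("name"), c.dropFields("name")))
)
)
}
)
//...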
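On the follow-up about avoiding explicit typing altogether: a different technique (not the one in the answer) is to skip aggregate and its zero value and build the map with transform plus map_from_entries, whose result type Spark infers from the array elements. A minimal sketch, assuming Spark 3.1+ (for dropFields) and a made-up, trimmed-down compounds array; the object name is hypothetical.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, map_from_entries, struct, transform}

object CompoundsToMap {
  def run(): Map[String, Row] = {
    val spark = SparkSession.builder().master("local[1]").appName("compounds-map").getOrCreate()
    import spark.implicits._

    // Trimmed-down sample in the same shape as lab_data.categories.<x>.compounds
    val json = Seq(
      """{"compounds":[{"name":"cbd","loq":"0.100","value":"0"},{"name":"cbg","loq":"0.100","value":"0.06"}]}""")
    val df = spark.read.json(json.toDS())

    // Each element becomes a (name, struct-without-name) entry, so neither a
    // zero value nor a hand-written map<...> type string is needed.
    val byName = map_from_entries(
      transform(col("compounds"), c => struct(c.getField("name"), c.dropFields("name"))))

    df.select(byName.as("compounds_by_name")).head().getMap[String, Row](0).toMap
  }

  def main(args: Array[String]): Unit = run().foreach(println)
}
```

Because the value type comes from the data itself, extra fields added to the compounds structs would simply show up in the map values.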
I am trying to create a DataFrame from an RDD so that I can write it to JSON in the following format.
A sample JSON (the expected output) is shown below:
"1234":[
{
loc:'abc',
cost1:1.234,
cost2:2.3445
},
{
loc:'www',
cost1:1.534,
cost2:6.3445
}
]
I am able to generate the JSON with cost1 and cost2 as strings, but I want cost1 and cost2 to be doubles.
I get an error while creating the DataFrame from the RDD with the schema I defined.
Somehow the data is being treated as String instead of Double.
Can someone help me get this right?
Below is the Scala code of my sample implementation:
object csv2json {
def f[T](v: T) = v match {
case _: Int => "Int"
case _: String => "String"
case _: Float => "Float"
case _: Double => "Double"
case _:BigDecimal => "BigDecimal"
case _ => "Unknown"
}
def main(args: Array[String]): Unit = {
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
val input_df = Seq(("12345", "111","1.34","2.34"),("123456", "112","1.343","2.344"),("1234", "113","1.353","2.354"),("1231", "114","5.343","6.344")).toDF("item_id","loc","cost1","cost2")
input_df.show()
val inputRDD = input_df.rdd.map(data => {
val nodeObj = scala.collection.immutable.Map("nodeId" -> data(1).toString()
,"soc" -> data(2).toString().toDouble
,"mdc" -> data(3).toString().toDouble)
(data(0).toString(),nodeObj)
})
val inputRDDAgg = inputRDD.aggregateByKey(scala.collection.mutable.ListBuffer.empty[Any])((nodeAAggreg,costValue) => nodeAAggreg += costValue , (nodeAAggreg,costValue) => nodeAAggreg ++ costValue)
val inputRDDAggRow = inputRDDAgg.map(data => {
println(data._1 + "and------ " + f(data._1))
println(data._2 + "and------ " + f(data._2))
val skuObj = Row(
data._1,
data._2)
skuObj
}
)
val innerSchema = ArrayType(MapType(StringType, DoubleType, true))
val schema:StructType = StructType(Seq(StructField(name="skuId", dataType=StringType),StructField(name="nodes", innerSchema)))
val finalJsonDF = spark.createDataFrame(inputRDDAggRow, schema)
finalJsonDF.show()
}
}
Below is the exception stacktrace:
java.lang.RuntimeException: Error while encoding: java.lang.ClassCastException: java.lang.String cannot be cast to java.lang.Double
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, skuId), StringType), true, false) AS skuId#32
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else mapobjects(MapObjects_loopValue0, MapObjects_loopIsNull0, ObjectType(class java.lang.Object), if (isnull(validateexternaltype(lambdavariable(MapObjects_loopValue0, MapObjects_loopIsNull0, ObjectType(class java.lang.Object), true), MapType(StringType,DoubleType,true)))) null else newInstance(class org.apache.spark.sql.catalyst.util.ArrayBasedMapData), validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, nodes), ArrayType(MapType(StringType,DoubleType,true),true)), None) AS nodes#33
at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.toRow(ExpressionEncoder.scala:291)
at org.apache.spark.sql.SparkSession$$anonfun$4.apply(SparkSession.scala:589)
at org.apache.spark.sql.SparkSession$$anonfun$4.apply(SparkSession.scala:589)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
I would suggest staying with Datasets/DataFrames and using the built-in functions, since Spark can optimize them in ways it cannot optimize hand-written RDD code.
So you can do the following to achieve your requirement:
import org.apache.spark.sql.functions._
val finalJsonDF = input_df
.groupBy("item_id")
.agg(
collect_list(
struct(
col("loc"),
col("cost1").cast("double").as("cost1"),
col("cost2").cast("double").as("cost2")))
.as("jsonData"))
where collect_list and struct are built-in functions. The .as(...) aliases keep the field names cost1 and cost2; without them, Spark names the unnamed cast expressions col2 and col3.
This should give you
+-------+-------------------+
|item_id|jsonData |
+-------+-------------------+
|123456 |[[112,1.343,2.344]]|
|1234 |[[113,1.353,2.354]]|
|1231 |[[114,5.343,6.344]]|
|12345 |[[111,1.34,2.34]] |
+-------+-------------------+
and saving it to a JSON file, as per your requirement,
finalJsonDF.coalesce(1).write.json("path to output file")
should give you
{"item_id":"123456","jsonData":[{"loc":"112","cost1":1.343,"cost2":2.344}]}
{"item_id":"1234","jsonData":[{"loc":"113","cost1":1.353,"cost2":2.354}]}
{"item_id":"1231","jsonData":[{"loc":"114","cost1":5.343,"cost2":6.344}]}
{"item_id":"12345","jsonData":[{"loc":"111","cost1":1.34,"cost2":2.34}]}
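To check the generated keys without writing files, a small sketch can render each row with toJSON, which should produce the same kind of line that write.json writes. The data is a two-row subset of the question's input_df and the object name is hypothetical; the .as(...) aliases on the cast columns are what give the keys cost1 and cost2 rather than Spark's generated names for unnamed expressions.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_list, struct}

object CostsAsDoubles {
  def run(): Seq[String] = {
    val spark = SparkSession.builder().master("local[1]").appName("costs-json").getOrCreate()
    import spark.implicits._

    // Subset of the question's input, costs still stored as strings
    val inputDf = Seq(("12345", "111", "1.34", "2.34"), ("1231", "114", "5.343", "6.344"))
      .toDF("item_id", "loc", "cost1", "cost2")

    val finalJsonDF = inputDf
      .groupBy("item_id")
      .agg(collect_list(struct(
        col("loc"),
        col("cost1").cast("double").as("cost1"), // alias keeps the key name
        col("cost2").cast("double").as("cost2"))).as("jsonData"))

    // One JSON document per row; sorted so the result order is deterministic
    finalJsonDF.toJSON.collect().sorted.toSeq
  }

  def main(args: Array[String]): Unit = run().foreach(println)
}
```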
I see a schema mismatch in your code. I made a simple fix as a workaround:
I changed data(1).toString to data(1).toString.toDouble. Your ArrayType(MapType(StringType, DoubleType, true)) declares every map value as a Double, whereas one of your values (nodeId) is a String, which is what triggers the ClassCastException. Note that this turns nodeId into a number, as the 114.0 values in the output show.
val inputRDD = input_df.rdd.map(data => {
val nodeObj = scala.collection.immutable.Map("nodeId" -> data(1).toString.toDouble
,"soc" -> data(2).toString().toDouble
,"mdc" -> data(3).toString().toDouble)
(data(0).toString(),nodeObj)
})
Output
+------+--------------------------------------------------+
|skuId |nodes |
+------+--------------------------------------------------+
|1231 |[Map(nodeId -> 114.0, soc -> 5.343, mdc -> 6.344)]|
|12345 |[Map(nodeId -> 111.0, soc -> 1.34, mdc -> 2.34)] |
|123456|[Map(nodeId -> 112.0, soc -> 1.343, mdc -> 2.344)]|
|1234 |[Map(nodeId -> 113.0, soc -> 1.353, mdc -> 2.354)]|
+------+--------------------------------------------------+
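If loc should stay a string instead of becoming 114.0, one option is to declare each node as a struct rather than a map, so every field gets its own type. A minimal sketch of that idea along the question's RDD route; the object name is hypothetical, and the second "1234" row is made up to show grouping. The key point is converting the strings to the declared types before building the Rows.

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types._

object NodesAsStructs {
  // One struct per node: loc stays a string, the costs are doubles
  val nodeType: StructType = StructType(Seq(
    StructField("loc", StringType),
    StructField("cost1", DoubleType),
    StructField("cost2", DoubleType)))

  val schema: StructType = StructType(Seq(
    StructField("skuId", StringType),
    StructField("nodes", ArrayType(nodeType))))

  def build(spark: SparkSession): DataFrame = {
    val input = Seq(
      ("1234", "113", "1.353", "2.354"),
      ("1234", "115", "0.5", "0.75"), // made-up extra node for the same sku
      ("1231", "114", "5.343", "6.344"))
    val rows = spark.sparkContext.parallelize(input)
      // Convert to the types the schema declares before building Rows
      .map { case (sku, loc, c1, c2) => (sku, Row(loc, c1.toDouble, c2.toDouble)) }
      .groupByKey()
      .map { case (sku, nodes) => Row(sku, nodes.toList) }
    spark.createDataFrame(rows, schema)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("nodes-structs").getOrCreate()
    build(spark).show(false)
  }
}
```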
Hope this helps!
SPARK_VERSION = 2.2.0
I ran into an interesting issue when trying to filter a DataFrame that has a column added using a UDF. I can replicate the problem with a smaller data set.
Given the dummy case classes:
case class Info(number: Int, color: String)
case class Record(name: String, infos: Seq[Info])
and the following data:
val blue = Info(1, "blue")
val black = Info(2, "black")
val yellow = Info(3, "yellow")
val orange = Info(4, "orange")
val white = Info(5, "white")
val a = Record("a", Seq(blue, black, white))
val a2 = Record("a", Seq(yellow, white, orange))
val b = Record("b", Seq(blue, black))
val c = Record("c", Seq(white, orange))
val d = Record("d", Seq(orange, black))
do the following...
Create two dataframes (we will call them left and right)
val left = Seq(a, b).toDF
val right = Seq(a2, c, d).toDF
Join those dataframes using a full_outer join, and take only what is on the right side
val rightOnlyInfos = left.alias("l")
.join(right.alias("r"), Seq("name"), "full_outer")
.filter("l.infos is null")
.select($"name", $"r.infos".as("r_infos"))
This results in the following:
rightOnlyInfos.show(false)
+----+-----------------------+
|name|r_infos |
+----+-----------------------+
|c |[[5,white], [4,orange]]|
|d |[[4,orange], [2,black]]|
+----+-----------------------+
Using the following udf, add a new column that is a boolean and represents whether or not one of the r_infos contains the color black
def hasBlack = (s: Seq[Row]) => {
s.exists{ case Row(num: Int, color: String) =>
color == "black"
}
}
val joinedBreakdown = rightOnlyInfos.withColumn("has_black", udf(hasBlack).apply($"r_infos"))
This is where I am seeing problems now. If I do the following, I get no errors:
joinedBreakdown.show(false)
and it produces, as expected:
+----+-----------------------+---------+
|name|r_infos |has_black|
+----+-----------------------+---------+
|c |[[5,white], [4,orange]]|false |
|d |[[4,orange], [2,black]]|true |
+----+-----------------------+---------+
and the schema
joinedBreakdown.printSchema
shows
root
|-- name: string (nullable = true)
|-- r_infos: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- number: integer (nullable = false)
| | |-- color: string (nullable = true)
|-- has_black: boolean (nullable = true)
However, when I try to filter by that results, I get an error:
joinedBreakdown.filter("has_black == true").show(false)
With the following error:
org.apache.spark.SparkException: Failed to execute user defined function($anonfun$hasBlack$1: (array<struct<number:int,color:string>>) => boolean)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1075)
at org.apache.spark.sql.catalyst.expressions.BinaryExpression.eval(Expression.scala:411)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.org$apache$spark$sql$catalyst$optimizer$EliminateOuterJoin$$canFilterOutNull(joins.scala:127)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$rightHasNonNullPredicate$lzycompute$1$1.apply(joins.scala:138)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$rightHasNonNullPredicate$lzycompute$1$1.apply(joins.scala:138)
at scala.collection.LinearSeqOptimized$class.exists(LinearSeqOptimized.scala:93)
at scala.collection.immutable.List.exists(List.scala:84)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.rightHasNonNullPredicate$lzycompute$1(joins.scala:138)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.rightHasNonNullPredicate$1(joins.scala:138)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.org$apache$spark$sql$catalyst$optimizer$EliminateOuterJoin$$buildNewJoinType(joins.scala:145)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$apply$2.applyOrElse(joins.scala:152)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$apply$2.applyOrElse(joins.scala:150)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:266)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272)
at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:256)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.apply(joins.scala:150)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin.apply(joins.scala:116)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:85)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:82)
at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124)
at scala.collection.immutable.List.foldLeft(List.scala:84)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:82)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:74)
at scala.collection.immutable.List.foreach(List.scala:381)
at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:74)
at org.apache.spark.sql.execution.QueryExecution.optimizedPlan$lzycompute(QueryExecution.scala:78)
at org.apache.spark.sql.execution.QueryExecution.optimizedPlan(QueryExecution.scala:78)
at org.apache.spark.sql.execution.QueryExecution.sparkPlan$lzycompute(QueryExecution.scala:84)
at org.apache.spark.sql.execution.QueryExecution.sparkPlan(QueryExecution.scala:80)
at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:89)
at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:89)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:2832)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2153)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2366)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:245)
at org.apache.spark.sql.Dataset.show(Dataset.scala:646)
at org.apache.spark.sql.Dataset.show(Dataset.scala:623)
... 58 elided
Caused by: java.lang.NullPointerException
at $anonfun$hasBlack$1.apply(<console>:41)
at $anonfun$hasBlack$1.apply(<console>:40)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:92)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:91)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1072)
... 114 more
EDIT: I opened a JIRA issue. Pasting it here for tracking purposes:
https://issues.apache.org/jira/browse/SPARK-22942
This answer doesn't explain why the issue exists, but it gives workarounds that I have found.
I have run into a problem exactly like this. I'm not sure of the cause, but I have two workarounds that work for me. Someone much smarter than me will probably be able to explain it all to you, but here are my solutions to the problem.
FIRST SOLUTION
Spark is acting like the column doesn't exist yet, probably because of some kind of filter push-down. Force Spark to cache the result before filtering; this makes the column "exist".
val joinedBreakdown = rightOnlyInfos.withColumn("has_black", udf(hasBlack).apply($"r_infos")).cache()
println(joinedBreakdown.count()) // count() populates the cache after the UDF has been applied
joinedBreakdown.filter("has_black == true").show(false)
joinedBreakdown.filter("has_black == true").explain
OUTPUT
2
+----+-----------------------+---------+
|name|r_infos |has_black|
+----+-----------------------+---------+
|d |[[4,orange], [2,black]]|true |
+----+-----------------------+---------+
== Physical Plan ==
*Filter (has_black#112632 = true)
+- InMemoryTableScan [name#112622, r_infos#112628, has_black#112632], [(has_black#112632 = true)]
+- InMemoryRelation [name#112622, r_infos#112628, has_black#112632], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
+- *Project [coalesce(name#112606, name#112614) AS name#112622, infos#112615 AS r_infos#112628, UDF(infos#112615) AS has_black#112632]
+- *Filter isnull(infos#112607)
+- SortMergeJoin [name#112606], [name#112614], FullOuter
:- *Sort [name#112606 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(name#112606, 200)
: +- LocalTableScan [name#112606, infos#112607]
+- *Sort [name#112614 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(name#112614, 200)
+- LocalTableScan [name#112614, infos#112615]
SECOND SOLUTION
No idea why this one works, but do the same thing you did, except put a try/catch in the UDF. And before I get yelled at for it, please know that using try/catch for control flow is an anti-pattern. To learn more, I recommend this question and answer. NOTE: I edited your UDF slightly to make it look like something I am more familiar with.
def hasBlack = udf((s: Seq[Row]) => {
try{
s.exists{ case Row(num: Int, color: String) =>
color == "black"
}
} catch {
case ex: Exception => false
}
})
val joinedBreakdown = rightOnlyInfos.withColumn("has_black", hasBlack($"r_infos"))
joinedBreakdown.filter("has_black == true").explain
joinedBreakdown.filter("has_black == true").show(false)
OUTPUT
== Physical Plan ==
*Project [coalesce(name#112565, name#112573) AS name#112581, infos#112574 AS r_infos#112587, UDF(infos#112574) AS has_black#112591]
+- *Filter isnull(infos#112566)
+- *BroadcastHashJoin [name#112565], [name#112573], RightOuter, BuildLeft, false
:- BroadcastExchange HashedRelationBroadcastMode(ArrayBuffer(input[0, string, false]))
: +- *Filter isnotnull(name#112565)
: +- LocalTableScan [name#112565, infos#112566]
+- *Filter (UDF(infos#112574) = true)
+- LocalTableScan [name#112573, infos#112574]
+----+-----------------------+---------+
|name|r_infos |has_black|
+----+-----------------------+---------+
|d |[[4,orange], [2,black]]|true |
+----+-----------------------+---------+
You can see that the query plans differ because I am forcing the UDF to be applied before the filter.
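The stack trace in the question hints at a likely cause: the exception comes from EliminateOuterJoin.canFilterOutNull, an optimizer rule that appears to evaluate the filter predicate with the outer join's columns set to null to decide whether the join can be simplified. The UDF then receives a null Seq and s.exists throws, which would also explain why the try/catch version works. Under that reading, a null-safe UDF should be enough on its own. A sketch under that assumption, reusing the question's data (the object name is hypothetical):

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.udf

case class Info(number: Int, color: String)
case class Record(name: String, infos: Seq[Info])

object NullSafeUdf {
  // A null array (e.g. the optimizer probing with nulls) yields false instead of an NPE;
  // getAs by name also tolerates a null color.
  val hasBlack = udf((s: Seq[Row]) =>
    Option(s).exists(_.exists(r => r != null && r.getAs[String]("color") == "black")))

  def run(): Seq[String] = {
    val spark = SparkSession.builder().master("local[1]").appName("null-safe-udf").getOrCreate()
    import spark.implicits._

    val blue = Info(1, "blue"); val black = Info(2, "black"); val yellow = Info(3, "yellow")
    val orange = Info(4, "orange"); val white = Info(5, "white")
    val left = Seq(Record("a", Seq(blue, black, white)), Record("b", Seq(blue, black))).toDF
    val right = Seq(Record("a", Seq(yellow, white, orange)), Record("c", Seq(white, orange)),
      Record("d", Seq(orange, black))).toDF

    val rightOnlyInfos = left.alias("l")
      .join(right.alias("r"), Seq("name"), "full_outer")
      .filter("l.infos is null")
      .select($"name", $"r.infos".as("r_infos"))

    // No cache() or try/catch: the filter on the UDF column is applied directly
    rightOnlyInfos
      .withColumn("has_black", hasBlack($"r_infos"))
      .filter("has_black == true")
      .select("name").as[String].collect().toSeq
  }

  def main(args: Array[String]): Unit = run().foreach(println)
}
```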
How to express COUNT(DISTINCT ...) in Slick?
I want to build an equivalent of this query:
sql"""select formatdatetime("timestamp",'yyyy.MM.dd'), count(distinct "order_id")
from "sales" group by
formatdatetime("timestamp",'yyyy.MM.dd')""".as[(String,Option[Int])]
I tried this:
val values = sales groupBy { entry =>
formatDatetime(entry.timestamp, datetimeFormat)
} map { case(formattedDatetime, group) =>
(formattedDatetime, group.distinctOn(_.orderId).length.?)
}
which throws a runtime exception:
[info] slick.SlickTreeException: Cannot convert node to SQL Comprehension
[info] | Path s9._2 : Vector[t2<{s3: Int', s4: java.sql.Timestamp', s5: scala.math.BigDecimal', s6: java.sql.Timestamp', s7: String', s8: String'}>]
(I use H2)
What definitely works / My best shot so far:
val countDistinctOrderId = SimpleExpression.nullary[Int] { queryBuilder =>
import slick.util.MacroSupport._
import queryBuilder._
b"""count(distinct "order_id")"""
}
val values = sales groupBy { entry =>
formatDatetime(entry.timestamp, datetimeFormat)
} map { case(formattedDatetime, group) =>
(formattedDatetime, countDistinctOrderId.?)
}
I have a number of filters that I need to apply to a DataFrame in Spark, but I only know at runtime which filters to use. Currently I am applying them in individual filter calls, but that fails if one of the filters is not defined:
myDataFrame
.filter(_filter1)
.filter(_filter2)
.filter(_filter3)...
I can't figure out how to exclude, e.g., _filter2 dynamically at runtime when it is not needed.
Should I do it by creating one big filter:
var filter = _filter1
if (_filter2 != null)
filter = filter.and(_filter2)
...
Or is there a good pattern for this in Spark that I haven't found?
One possible solution is to default all filters to lit(true):
import org.apache.spark.sql.functions._
val df = Seq(1, 2, 3).toDF("x")
val filter_1 = lit(true)
val filter_2 = col("x") > 1
val filter_3 = lit(true)
val filtered = df.filter(filter_1).filter(filter_2).filter(filter_3)
This will keep null out of your code and trivially true predicates will be pruned from the execution plan:
filtered.explain
== Physical Plan ==
*Project [value#1 AS x#3]
+- *Filter (value#1 > 1)
+- LocalTableScan [value#1]
You can of course make it even simpler with a sequence of predicates:
import org.apache.spark.sql.Column
val preds: Seq[Column] = Seq(lit(true), col("x") > 1, lit(true))
df.where(preds.foldLeft(lit(true))(_ and _))
and, if implemented right, skip placeholders completely.
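A minimal sketch of the "skip placeholders completely" variant, assuming each runtime filter is represented as an Option[Column] that is None when the filter is not needed; applyFilters, the object name, and the sample data are hypothetical. If the filters arrive as possibly-null Columns, Option(_filter2) turns a null into None.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

object OptionalFilters {
  // AND together only the defined predicates; with none defined, df is returned unchanged
  def applyFilters(df: DataFrame, filters: Seq[Option[Column]]): DataFrame =
    filters.flatten.reduceOption(_ && _) match {
      case Some(predicate) => df.filter(predicate)
      case None            => df
    }

  def run(): Seq[Int] = {
    val spark = SparkSession.builder().master("local[1]").appName("optional-filters").getOrCreate()
    import spark.implicits._
    val df = Seq(1, 2, 3, 4).toDF("x")

    val filter1: Option[Column] = Some(col("x") > 1)
    val filter2: Option[Column] = Option(null) // a filter that was never set becomes None
    val filter3: Option[Column] = Some(col("x") < 4)

    applyFilters(df, Seq(filter1, filter2, filter3)).as[Int].collect().sorted.toSeq
  }

  def main(args: Array[String]): Unit = println(run())
}
```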
First, I would get rid of the null filters:
val filters:List[A => Boolean] = nullableFilters.filter(_!=null)
Then define a function to chain the filters:
def chainFilters[A](filters:List[A => Boolean])(v:A) = filters.forall(f => f(v))
Now you can apply the filters to your data. Since these are plain Scala functions, this uses the typed filter(func: T => Boolean) overload, so df must be a Dataset[A]:
df.filter(chainFilters(filters) _)
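Because chainFilters only combines plain Scala predicates, the idea can be checked without Spark at all. A small sketch on a List[Int] with made-up predicates (the object name is hypothetical). Keep in mind that on a Dataset such lambdas are opaque to Spark's optimizer, unlike Column predicates.

```scala
object ChainFiltersDemo {
  // True only if every predicate accepts v; an empty list accepts everything
  def chainFilters[A](filters: List[A => Boolean])(v: A): Boolean =
    filters.forall(f => f(v))

  def demo(): List[Int] = {
    val isPositive: Int => Boolean = _ > 0
    val isEven: Int => Boolean = _ % 2 == 0
    // The middle filter was not defined at runtime
    val nullableFilters: List[Int => Boolean] = List(isPositive, null, isEven)
    val filters = nullableFilters.filter(_ != null)
    List(-2, 1, 2, 3, 4).filter(chainFilters(filters))
  }

  def main(args: Array[String]): Unit = println(demo())
}
```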
Why not:
var df = // load
if (_filter2 != null) {
df = df.filter(_filter2)
}
etc
Alternatively, create a list of filters:
var df = // load
val filters = Seq (filter1, filter2, filter3, ...)
filters.filter(_ != null).foreach (x => df = df.filter(x))
// Sorry if there are mistakes in the code; it's more of an idea, as I can't test it right now
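The same loop can be written without a var by folding the non-null filters over the DataFrame. A sketch with hypothetical filter values and object name:

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

object FoldFilters {
  // Each non-null predicate adds one .filter step; null entries are skipped
  def applyAll(df: DataFrame, filters: Seq[Column]): DataFrame =
    filters.filter(_ != null).foldLeft(df)((acc, p) => acc.filter(p))

  def run(): Long = {
    val spark = SparkSession.builder().master("local[1]").appName("fold-filters").getOrCreate()
    import spark.implicits._
    val df = Seq(1, 2, 3, 4, 5).toDF("x")

    val filter1 = col("x") >= 2
    val filter2: Column = null // not needed at runtime
    val filter3 = col("x") =!= 4

    applyAll(df, Seq(filter1, filter2, filter3)).count()
  }

  def main(args: Array[String]): Unit = println(run())
}
```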