Spark CASE WHEN/When Otherwise that doesn't stop evaluating and accumulates - pyspark

Say I have the following table/dataframe:
Id | Col1 | Col2 | Col3
---|------|------|-----
1  | 100  | aaa  | xxx
2  | 200  | aaa  | yyy
3  | 300  | ccc  | zzz
I need to calculate an extra column CalculatedValue which could have one or multiple values based on other columns' values.
I have tried with a regular CASE WHEN statement like:
from pyspark.sql.functions import expr

df_out = (df_source
    .withColumn('CalculatedValue',
        expr("""CASE WHEN Col1 = 100 THEN 'AAA111'
                     WHEN Col2 = 'aaa' THEN 'BBB222'
                     WHEN Col3 = 'zzz' THEN 'CCC333'
                END"""))
)
Note I'm doing it with expr() because the actual CASE WHEN statement is a very long string built dynamically.
This results in a table/dataframe like this:
Id | Col1 | Col2 | Col3 | CalculatedValue
---|------|------|------|----------------
1  | 100  | aaa  | xxx  | AAA111
2  | 200  | aaa  | yyy  | BBB222
3  | 300  | ccc  | zzz  | CCC333
However, what I need looks more like this, where the CASE WHEN statement doesn't stop evaluating after the first match but instead evaluates all conditions and accumulates all matches into, say, an array:
Id | Col1 | Col2 | Col3 | CalculatedValue
---|------|------|------|-----------------
1  | 100  | aaa  | xxx  | [AAA111, BBB222]
2  | 200  | aaa  | yyy  | BBB222
3  | 300  | ccc  | zzz  | CCC333
Any ideas?
Thanks

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, when}
import spark.implicits._
val df = Seq(
(1, 100, "aaa", "xxx"),
(2, 200, "aaa", "yyy"),
(3, 300, "ccc", "zzz")
).toDF("Id", "Col1", "Col2", "Col3")
val resDF = df.withColumn(
  "CalculatedValue",
  when(col("Col1") === 100 && col("Col2") === "aaa" && col("Col3") === "zzz",
    Array("AAA111", "BBB222", "CCC333"))
    .when(col("Col1") === 100 && col("Col2") === "aaa" && col("Col3") =!= "zzz",
      Array("AAA111", "BBB222"))
    .when(col("Col1") === 100 && col("Col2") =!= "aaa" && col("Col3") === "zzz",
      Array("AAA111", "CCC333"))
    .when(col("Col1") =!= 100 && col("Col2") === "aaa" && col("Col3") === "zzz",
      Array("BBB222", "CCC333"))
    .when(col("Col1") =!= 100 && col("Col2") =!= "aaa" && col("Col3") === "zzz",
      Array("CCC333"))
    .when(col("Col1") === 100 && col("Col2") =!= "aaa" && col("Col3") =!= "zzz",
      Array("AAA111"))
    .when(col("Col1") =!= 100 && col("Col2") === "aaa" && col("Col3") =!= "zzz",
      Array("BBB222"))
    .otherwise(Array("unknown"))
)
resDF.show(false)
/*
+---+----+----+----+----------------+
|Id |Col1|Col2|Col3|CalculatedValue |
+---+----+----+----+----------------+
|1 |100 |aaa |xxx |[AAA111, BBB222]|
|2 |200 |aaa |yyy |[BBB222] |
|3 |300 |ccc |zzz |[CCC333] |
+---+----+----+----+----------------+
*/
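If you are on Spark 2.4+ and want to keep building the expression as one dynamic SQL string (as in the question), a more scalable sketch is to evaluate every rule independently, collect the results in an array, and drop the nulls with the higher-order filter function. It is shown in Scala like the answer above, but the expr string is plain Spark SQL, so the same string works from PySpark; resDF2 is just an illustrative name and the three rules are the ones from the question:
import org.apache.spark.sql.functions.expr

// Sketch only (requires Spark 2.4+ for the higher-order `filter` SQL function):
// each rule is its own CASE WHEN that yields a label or NULL; filter() then
// keeps only the non-null matches, so every rule is always evaluated.
val resDF2 = df.withColumn(
  "CalculatedValue",
  expr("""
    filter(
      array(
        CASE WHEN Col1 = 100   THEN 'AAA111' END,
        CASE WHEN Col2 = 'aaa' THEN 'BBB222' END,
        CASE WHEN Col3 = 'zzz' THEN 'CCC333' END
      ),
      x -> x IS NOT NULL
    )
  """)
)
resDF2.show(false)
This should give [AAA111, BBB222] for row 1 and single-element arrays for rows 2 and 3, without having to enumerate every combination of conditions.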

Related

spark scala conditional join replace null values

I have two dataframes. I want to replace the values in col1 of df1 that are null with the values from col1 of df2. Please keep in mind that df1 can have > 10^6 rows (similarly to df2), and that df1 has some additional columns which are different from some additional columns of df2.
I know how to do join but I do not know how to do some kind of conditional join here in Spark with Scala.
df1
name | col1 | col2 | col3
----------------------------
foo | 0.1 | ...
bar | null |
hello | 0.6 |
foobar | null |
df2
name | col1 | col7
--------------------
lorem | 0.1 |
bar | 0.52 |
foobar | 0.47 |
EDIT:
This is my current solution:
df1.select("name", "col2", "col3").join(df2, (df1("name") === df2("name")), "left").select(df1("name"), col("col1"))
EDIT2:
val df1 = Seq(
("foo", Seq(0.1), 10, "a"),
("bar", Seq(), 20, "b"),
("hello", Seq(0.1), 30, "c"),
("foobar", Seq(), 40, "d")
).toDF("name", "col1", "col2", "col3")
val df2 = Seq(
("lorem", Seq(0.1), "x"),
("bar", Seq(0.52), "y"),
("foobar", Seq(0.47), "z")
).toDF("name", "col1", "col7")
display(df1.
join(df2, Seq("name"), "left_outer").
select(df1("name"), coalesce(df1("col1"), df2("col1")).as("col1")))
returns:
name | col1
bar | []
foo | [0.1]
foobar | []
hello | [0.1]
Consider using coalesce on col1 after performing the left join. To handle both nulls and empty arrays (in the case of ArrayType), as per the revised requirement in the comments section, a when/otherwise clause is used, as shown below:
val df1 = Seq(
("foo", Some(Seq(0.1)), 10, "a"),
("bar", None, 20, "b"),
("hello", Some(Seq(0.1)), 30, "c"),
("foobar", Some(Seq()), 40, "d")
).toDF("name", "col1", "col2", "col3")
val df2 = Seq(
("lorem", Seq(0.1), "x"),
("bar", Seq(0.52), "y"),
("foobar", Seq(0.47), "z")
).toDF("name", "col1", "col7")
df1.
join(df2, Seq("name"), "left_outer").
select(
df1("name"),
coalesce(
when(lit(df1.schema("col1").dataType.typeName) === "array" && size(df1("col1")) === 0, df2("col1")).otherwise(df1("col1")),
df2("col1")
).as("col1")
).
show
/*
+------+------+
| name| col1|
+------+------+
| foo| [0.1]|
| bar|[0.52]|
| hello| [0.1]|
|foobar|[0.47]|
+------+------+
*/
UPDATE:
It appears that Spark, surprisingly, does not handle conditionA && conditionB the way most other languages do -- even when conditionA is false, conditionB will still be evaluated, and replacing && with nested when/otherwise still does not resolve the issue. It might be due to limitations in how the internally translated CASE/WHEN/ELSE SQL is executed.
As a result, the above when/otherwise data-type check fails when col1 is not an ArrayType, because it applies the array-specific function size() to it. Given that, I would forgo the dynamic column-type check and run different queries depending on whether col1 is an ArrayType, assuming that is known upfront:
df1.
join(df2, Seq("name"), "left_outer").
select(
df1("name"),
coalesce(
when(size(df1("col1")) === 0, df2("col1")).otherwise(df1("col1")), // <-- if col1 is an array
// df1("col1"), // <-- if col1 is not an array
df2("col1")
).as("col1")
).
show
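If the column type is only known at runtime, a hedged sketch of the same idea is to do the type check on the driver in Scala rather than inside the SQL expression, so the analyzer never sees size() applied to a non-array column (col1Expr is an illustrative name):
import org.apache.spark.sql.functions.{coalesce, size, when}
import org.apache.spark.sql.types.ArrayType

// Sketch: pick the expression on the driver from the actual schema, so only
// the applicable branch is ever analyzed and executed.
val col1Expr =
  if (df1.schema("col1").dataType.isInstanceOf[ArrayType])
    when(size(df1("col1")) === 0, df2("col1")).otherwise(df1("col1"))
  else
    df1("col1")

df1.
  join(df2, Seq("name"), "left_outer").
  select(df1("name"), coalesce(col1Expr, df2("col1")).as("col1")).
  show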

Spark - Drop null values from map column

I'm using Spark to read a CSV file and then gather all the fields to create a map. Some of the fields are empty and I'd like to remove them from the map.
So for a CSV that looks like this:
"animal", "colour", "age"
"cat" , "black" ,
"dog" , , "3"
I'd like to get a dataset with the following maps:
Map("animal" -> "cat", "colour" -> "black")
Map("animal" -> "dog", "age" -> "3")
This is what I have so far:
val csv_cols_n_vals: Array[Column] = csv.columns.flatMap { c => Array(lit(c), col(c)) }
sparkSession.read
.option("header", "true")
.csv(csvLocation)
.withColumn("allFieldsMap", map(csv_cols_n_vals: _*))
I've tried a few variations, but I can't seem to find the correct solution.
There is most certainly a better and more efficient way using the DataFrame API, but here is a map/flatMap solution:
val df = Seq(("cat", "black", null), ("dog", null, "3")).toDF("animal", "colour", "age")
val cols = df.columns
df.map(r => {
cols.flatMap( c => {
val v = r.getAs[String](c)
if (v != null) {
Some(Map(c -> v))
} else {
None
}
}).reduce(_ ++ _)
}).toDF("map").show(false)
Which produces:
+--------------------------------+
|map |
+--------------------------------+
|[animal -> cat, colour -> black]|
|[animal -> dog, age -> 3] |
+--------------------------------+
scala> df.show(false)
+------+------+----+
|animal|colour|age |
+------+------+----+
|cat |black |null|
|dog |null |3 |
+------+------+----+
Building Expressions
val colExpr = df
.columns // getting list of columns from dataframe.
.map{ columnName =>
when(
col(columnName).isNotNull, // checking if column is not null
map(
lit(columnName),
col(columnName)
) // Adding column name and its value inside map
)
.otherwise(map())
}
.reduce(map_concat(_,_))
// finally using map_concat function to concat map values.
The above code will create the expression below.
map_concat(
map_concat(
CASE WHEN (animal IS NOT NULL) THEN map(animal, animal) ELSE map() END,
CASE WHEN (colour IS NOT NULL) THEN map(colour, colour) ELSE map() END
),
CASE WHEN (age IS NOT NULL) THEN map(age, age) ELSE map() END
)
Applying colExpr to the DataFrame:
scala>
df
.withColumn("allFieldsMap",colExpr)
.show(false)
+------+------+----+--------------------------------+
|animal|colour|age |allFieldsMap |
+------+------+----+--------------------------------+
|cat |black |null|[animal -> cat, colour -> black]|
|dog |null |3 |[animal -> dog, age -> 3] |
+------+------+----+--------------------------------+
Spark-sql solution:
val df = Seq(("cat", "black", null), ("dog", null, "3")).toDF("animal", "colour", "age")
df.show(false)
+------+------+----+
|animal|colour|age |
+------+------+----+
|cat |black |null|
|dog |null |3 |
+------+------+----+
df.createOrReplaceTempView("a_vw")
val cols_str = df.columns.flatMap( x => Array("\"".concat(x).concat("\""),x)).mkString(",")
spark.sql(s"""
select collect_list(m2) res from (
select id, key, value, map(key,value) m2 from (
select id, explode(m) as (key,value) from
( select monotonically_increasing_id() id, map(${cols_str}) m from a_vw )
)
where value is not null
) group by id
""")
.show(false)
+------------------------------------+
|res |
+------------------------------------+
|[[animal -> cat], [colour -> black]]|
|[[animal -> dog], [age -> 3]] |
+------------------------------------+
Or much shorter
spark.sql(s"""
select collect_list(case when value is not null then map(key,value) end ) res from (
select id, explode(m) as (key,value) from
( select monotonically_increasing_id() id, map(${cols_str}) m from a_vw )
) group by id
""")
.show(false)
+------------------------------------+
|res |
+------------------------------------+
|[[animal -> cat], [colour -> black]]|
|[[animal -> dog], [age -> 3]] |
+------------------------------------+
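On Spark 3.0+, a shorter DataFrame-API variant (a sketch, not one of the original answers) is to build the full map first and then drop the null-valued entries with map_filter; keyValueCols is an illustrative name:
import org.apache.spark.sql.functions.{col, lit, map, map_filter}

// Sketch (Spark 3.0+): one map from all columns, then keep only the entries
// whose value is not null. Keys are the literal column names, so they are never null.
val keyValueCols = df.columns.flatMap(c => Seq(lit(c), col(c)))
df.withColumn("allFieldsMap", map_filter(map(keyValueCols: _*), (_, v) => v.isNotNull))
  .show(false)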

Combine two datasets based on value

I have following two datasets:
val dfA = Seq(
("001", "10", "Cat"),
("001", "20", "Dog"),
("001", "30", "Bear"),
("002", "10", "Mouse"),
("002", "20", "Squirrel"),
("002", "30", "Turtle"),
).toDF("Package", "LineItem", "Animal")
val dfB = Seq(
("001", "", "X", "A"),
("001", "", "Y", "B"),
("002", "", "X", "C"),
("002", "", "Y", "D"),
("002", "20", "X" ,"E")
).toDF("Package", "LineItem", "Flag", "Category")
I need to join them with specific conditions:
a) There is always a row in dfB with the X flag and empty LineItem which should be the default Category for the Package from dfA
b) When there is a LineItem specified in dfB the default Category should be overwritten with the Category associated to this LineItem
Expected output:
+---------+----------+----------+----------+
| Package | LineItem | Animal | Category |
+---------+----------+----------+----------+
| 001 | 10 | Cat | A |
+---------+----------+----------+----------+
| 001 | 20 | Dog | A |
+---------+----------+----------+----------+
| 001 | 30 | Bear | A |
+---------+----------+----------+----------+
| 002 | 10 | Mouse | C |
+---------+----------+----------+----------+
| 002 | 20 | Squirrel | E |
+---------+----------+----------+----------+
| 002 | 30 | Turtle | C |
+---------+----------+----------+----------+
I spent some time on it today, but I have no idea how it could be accomplished. I'd appreciate your assistance.
Thanks!
You can use two joins + a when clause:
val dfC = dfA
.join(dfB, dfB.col("Flag") === "X" && dfA.col("LineItem") === dfB.col("LineItem") && dfA.col("Package") === dfB.col("Package"))
.select(dfA.col("Package").as("priorPackage"), dfA.col("LineItem").as("priorLineItem"), dfB.col("Category").as("priorCategory"))
.as("dfC")
val dfD = dfA
.join(dfB, dfB.col("LineItem") === "" && dfB.col("Flag") === "X" && dfA.col("Package") === dfB.col("Package"), "left_outer")
.join(dfC, dfA.col("LineItem") === dfC.col("priorLineItem") && dfA.col("Package") === dfC.col("priorPackage"), "left_outer")
.select(
dfA.col("package"),
dfA.col("LineItem"),
dfA.col("Animal"),
when(dfC.col("priorCategory").isNotNull, dfC.col("priorCategory")).otherwise(dfB.col("Category")).as("Category")
)
dfD.show()
This should work for you:
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._
val dfA = Seq(
("001", "10", "Cat"),
("001", "20", "Dog"),
("001", "30", "Bear"),
("002", "10", "Mouse"),
("002", "20", "Squirrel"),
("002", "30", "Turtle")
).toDF("Package", "LineItem", "Animal")
val dfB = Seq(
("001", "", "X", "A"),
("001", "", "Y", "B"),
("002", "", "X", "C"),
("002", "", "Y", "D"),
("002", "20", "X" ,"E")
).toDF("Package", "LineItem", "Flag", "Category")
val result = {
dfA.as("a")
.join(dfB.where('Flag === "X").as("b"), $"a.Package" === $"b.Package" and ($"a.LineItem" === $"b.LineItem" or $"b.LineItem" === ""), "left")
.withColumn("anyRowsInGroupWithBLineItemDefined", first(when($"b.LineItem" =!= "", lit(true)), ignoreNulls = true).over(Window.partitionBy($"a.Package", $"a.LineItem")).isNotNull)
.where(!$"anyRowsInGroupWithBLineItemDefined" or ($"anyRowsInGroupWithBLineItemDefined" and $"b.LineItem" =!= ""))
.select($"a.Package", $"a.LineItem", $"a.Animal", $"b.Category")
}
result.orderBy($"a.Package", $"a.LineItem").show(false)
// +-------+--------+--------+--------+
// |Package|LineItem|Animal |Category|
// +-------+--------+--------+--------+
// |001 |10 |Cat |A |
// |001 |20 |Dog |A |
// |001 |30 |Bear |A |
// |002 |10 |Mouse |C |
// |002 |20 |Squirrel|E |
// |002 |30 |Turtle |C |
// +-------+--------+--------+--------+
The "tricky" part is calculating whether or not there are any rows with LineItem defined in dfB for a given Package, LineItem in dfA. You can see how I perform this calculation in anyRowsInGroupWithBLineItemDefined which involves the use of a window function. Other than that, it's just a normal SQL programming exercise.
Also want to note that this code should be more efficient than the other solution as here we only shuffle the data twice (during join and during window function) and only read in each dataset once.
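For comparison, here is a hedged sketch of yet another way to express the same requirement using only equi-joins and a coalesce; defaults and overrides are illustrative names (the X-flag rows with an empty LineItem are the defaults, the ones with a concrete LineItem are the overrides):
import org.apache.spark.sql.functions.coalesce
import spark.implicits._

// Sketch: join the per-Package default category and the per-LineItem override
// separately; the override wins whenever it exists.
val defaults = dfB.filter($"Flag" === "X" && $"LineItem" === "")
  .select($"Package", $"Category".as("defaultCategory"))
val overrides = dfB.filter($"Flag" === "X" && $"LineItem" =!= "")
  .select($"Package", $"LineItem", $"Category".as("overrideCategory"))

dfA
  .join(defaults, Seq("Package"), "left")
  .join(overrides, Seq("Package", "LineItem"), "left")
  .select($"Package", $"LineItem", $"Animal",
    coalesce($"overrideCategory", $"defaultCategory").as("Category"))
  .orderBy($"Package", $"LineItem")
  .show(false)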

How do I compare each column in a table using DataFrame by Scala

There are two tables; one is ID Table 1 and the other is Attribute Table 2.
Table 1
Table 2
If the IDs in the same row of Table 1 have the same attribute, we get 1, else we get 0. Finally, we get the result, Table 3.
Table 3
For example, id1 and id2 have a different color and size, so the id1 and id2 row (2nd row in Table 3) has "id1 id2 0 0";
id1 and id3 have the same color and a different size, so the id1 and id3 row (3rd row in Table 3) has "id1 id3 1 0";
Same attribute---1
Different attribute---0
How can I get the result Table 3 using Scala dataframe?
This should do the trick
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.types.IntegerType
import spark.implicits._
val t1 = List(
("id1","id2"),
("id1","id3"),
("id2","id3")
).toDF("id_x", "id_y")
val t2 = List(
("id1","blue","m"),
("id2","red","s"),
("id3","blue","s")
).toDF("id", "color", "size")
t1
.join(t2.as("x"), $"id_x" === $"x.id", "inner")
.join(t2.as("y"), $"id_y" === $"y.id", "inner")
.select(
'id_x,
'id_y,
when($"x.color" === $"y.color",1).otherwise(0).alias("color").cast(IntegerType),
when($"x.size" === $"y.size",1).otherwise(0).alias("size").cast(IntegerType)
)
.show()
Resulting in:
+----+----+-----+----+
|id_x|id_y|color|size|
+----+----+-----+----+
| id1| id2| 0| 0|
| id1| id3| 1| 0|
| id2| id3| 0| 1|
+----+----+-----+----+
Here is how you can do it using a UDF, which helps you to understand the steps; however, the repetition of code can be minimized to improve performance (see the sketch after this answer).
import org.apache.spark.sql.functions.{col, udf}
import spark.implicits._
val df1 = spark.sparkContext.parallelize(Seq(
("id1", "id2"),
("id1","id3"),
("id2","id3")
)).toDF("idA", "idB")
val df2 = spark.sparkContext.parallelize(Seq(
("id1", "blue", "m"),
("id2", "red", "s"),
("id3", "blue", "s")
)).toDF("id", "color", "size")
val firstJoin = df1.join(df2, df1("idA") === df2("id"), "inner")
.withColumnRenamed("color", "colorA")
.withColumnRenamed("size", "sizeA")
.withColumnRenamed("id", "idx")
val secondJoin = firstJoin.join(df2, firstJoin("idB") === df2("id"), "inner")
val check = udf((v1: String, v2:String ) => {
if (v1.equalsIgnoreCase(v2)) 1 else 0
})
val result = secondJoin
.withColumn("color", check(col("colorA"), col("color")))
.withColumn("size", check(col("sizeA"), col("size")))
val finalResult = result.select("idA", "idB", "color", "size")
Hope this helps!
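To make the per-column repetition in both answers go away, here is a hedged sketch that generates the 1/0 comparison columns from t2's schema; it reuses the t1/t2 names and the x/y aliases from the first answer, while attrCols and comparisons are illustrative names:
import org.apache.spark.sql.functions.{col, when}
import spark.implicits._

// Sketch: every attribute column of t2 (everything except "id") gets its own
// 1/0 comparison column, so adding a new attribute needs no code change.
val attrCols = t2.columns.filterNot(_ == "id")
val comparisons = attrCols.map(c => when(col(s"x.$c") === col(s"y.$c"), 1).otherwise(0).alias(c))

t1
  .join(t2.as("x"), $"id_x" === $"x.id", "inner")
  .join(t2.as("y"), $"id_y" === $"y.id", "inner")
  .select(($"id_x" +: $"id_y" +: comparisons.toSeq): _*)
  .show()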

Dataframe filter issue, how to do?

Env: Spark 1.6, Scala
My dataframe is like below:
DF=
DT col1 col2
----------|---|----
2017011011| AA| BB
2017011011| CC| DD
2017011015| PP| BB
2017011015| QQ| DD
2017011016| AA| BB
2017011016| CC| DD
2017011017| PP| BB
2017011017| QQ| DD
How can I filter to get a result like the SQL: select * from DF where DT in (select distinct DT from DF order by DT desc limit 3)
The output should have the last 3 dates:
2017011015 |PP |BB
2017011015 |QQ |DD
2017011016 |AA |BB
2017011016 |CC |DD
2017011017 |PP |BB
2017011017 |QQ |DD
Thanks
Hossain
Tested on Spark 1.6.1
import sqlContext.implicits._
val df = sqlContext.createDataFrame(Seq(
(2017011011, "AA", "BB"),
(2017011011, "CC", "DD"),
(2017011015, "PP", "BB"),
(2017011015, "QQ", "DD"),
(2017011016, "AA", "BB"),
(2017011016, "CC", "DD"),
(2017011017, "PP", "BB"),
(2017011017, "QQ", "DD")
)).select(
$"_1".as("DT"),
$"_2".as("col1"),
$"_3".as("col2")
)
val dates = df.select($"DT")
.distinct()
.orderBy(-$"DT")
.map(_.getInt(0))
.take(3)
val result = df.filter(dates.map($"DT" === _).reduce(_ || _))
result.show()
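A window-function alternative (a sketch; dense_rank is available on Spark 1.6, but note that Window.orderBy without a partitionBy funnels all rows through a single partition for the ranking):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.dense_rank

// Sketch: rank the distinct dates in descending order and keep the rows
// belonging to the three most recent dates.
val ranked = df.withColumn("rk", dense_rank().over(Window.orderBy($"DT".desc)))
ranked.filter($"rk" <= 3).drop("rk").show()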