Set literal value over Window if condition is met, Spark Scala

I need to check a condition over a window:
- If the column IND_DEF is 20 for some row, I want to change the value of the column premium to 1 for every row of the window (the policyId partition) that this row belongs to.
My initial DataFrame looks like this:
+--------+----+-------+-----+-------+
|policyId|name|premium|state|IND_DEF|
+--------+----+-------+-----+-------+
| 1| BK| null| KT| 40|
| 1| AK| -31| null| 30|
| 1| VZ| null| IL| 20|
| 2| VK| 32| LI| 7|
| 2| CK| 25| YNZ| 10|
| 2| CK| 0| null| 5|
| 2| VK| 30| IL| 25|
+--------+----+-------+-----+-------+
And I want to achieve this:
+--------+----+-------+-----+-------+
|policyId|name|premium|state|IND_DEF|
+--------+----+-------+-----+-------+
| 1| BK| 1| KT| 40|
| 1| AK| 1| null| 30|
| 1| VZ| 1| IL| 20|
| 2| VK| 32| LI| 7|
| 2| CK| 25| YNZ| 10|
| 2| CK| 0| null| 5|
| 2| VK| 30| IL| 25|
+--------+----+-------+-----+-------+
I am trying the following code, but it does not work:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val df_946 = Seq [(Int, String, Integer, String, Int)]((1,"VZ",null,"IL",20),(1, "AK", -31,null,30),(1,"BK", null,"KT",40),(2,"CK",0,null,5),(2,"CK",25,"YNZ",10),(2,"VK",30,"IL",25),(2,"VK",32,"LI",7)).toDF("policyId", "name", "premium", "state","IND_DEF").orderBy("policyId")
val winSpec = Window.partitionBy("policyId").orderBy("policyId")
// Fails: lit(1) is not an aggregate or window function, so .over() is rejected; when() is also evaluated per row, so only the IND_DEF == 20 row could change
val df_947 = df_946.withColumn("premium",when(col("IND_DEF") === 20,lit(1).over(winSpec)).otherwise(col("premium")))

You can generate an array of IND_DEF values via collect_list for each window partition and recreate column premium based on the array_contains condition:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._
val df = Seq(
(1, None, 40),
(1, Some(-31), 30),
(1, None, 20),
(2, Some(32), 7),
(2, Some(30), 10)
).toDF("policyId", "premium", "IND_DEF")
val win = Window.partitionBy($"policyId")
df.
withColumn("indList", collect_list($"IND_DEF").over(win)).
withColumn("premium", when(array_contains($"indList", 20), 1).otherwise($"premium")).
drop($"indList").
show
// +--------+-------+-------+
// |policyId|premium|IND_DEF|
// +--------+-------+-------+
// | 1| 1| 40|
// | 1| 1| 30|
// | 1| 1| 20|
// | 2| 32| 7|
// | 2| 30| 10|
// +--------+-------+-------+
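A variant of the same idea avoids building an array on every row. Take `max` over the partition of a per-row flag that is 1 when IND_DEF = 20 and null otherwise. This is a sketch, not part of the original answer. The names `PremiumOverride` and `overridePremium` are hypothetical, and it assumes a local SparkSession.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object PremiumOverride {
  // Sets premium to 1 on every row of a policyId partition that contains IND_DEF == 20.
  def overridePremium(df: DataFrame): DataFrame = {
    // No orderBy: the aggregate sees the whole partition, not a running frame.
    val win = Window.partitionBy(col("policyId"))
    // when(...) without otherwise yields null for non-matching rows and max skips nulls,
    // so hasTwenty is 1 if any row of the partition matches, else null.
    val hasTwenty = max(when(col("IND_DEF") === 20, 1)).over(win)
    df.withColumn("premium", when(hasTwenty === 1, lit(1)).otherwise(col("premium")))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("premium-override").getOrCreate()
    import spark.implicits._

    val df = Seq[(Int, String, Option[Int], Option[String], Int)](
      (1, "BK", None, Some("KT"), 40),
      (1, "AK", Some(-31), None, 30),
      (1, "VZ", None, Some("IL"), 20),
      (2, "VK", Some(32), Some("LI"), 7),
      (2, "CK", Some(25), Some("YNZ"), 10),
      (2, "CK", Some(0), None, 5),
      (2, "VK", Some(30), Some("IL"), 25)
    ).toDF("policyId", "name", "premium", "state", "IND_DEF")

    overridePremium(df).orderBy("policyId", "IND_DEF").show()
    spark.stop()
  }
}
```

The window has no orderBy, so its frame is the whole partition. That is what lets the condition reach every row of the policy.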

Related

Spark Categorize ordered dataframe values by a condition

Let's say I have a DataFrame:
val userData = spark.createDataFrame(Seq(
(1, 0),
(2, 2),
(3, 3),
(4, 0),
(5, 3),
(6, 4)
)).toDF("order_clause", "some_value")
userData.withColumn("passed", when(col("some_value") <= 1.5,1))
.show()
+------------+----------+------+
|order_clause|some_value|passed|
+------------+----------+------+
| 1| 0| 1|
| 2| 2| null|
| 3| 3| null|
| 4| 0| 1|
| 5| 3| null|
| 6| 4| null|
+------------+----------+------+
That DataFrame is ordered by order_clause. Each time some_value drops to 1.5 or below, a new round starts, so I can say the previous round is done.
What I want to do is create a column round like this:
+------------+----------+------+-----+
|order_clause|some_value|passed|round|
+------------+----------+------+-----+
| 1| 0| 1| 1|
| 2| 2| null| 1|
| 3| 3| null| 1|
| 4| 0| 1| 2|
| 5| 3| null| 2|
| 6| 4| null| 2|
+------------+----------+------+-----+
With that column I could then select subsets of rounds from this DataFrame. I searched for hints on how to do this but have not found a way.
You're probably looking for a rolling sum of the passed column. You can do it using a sum window function:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum, when}
val result = userData.withColumn(
"passed",
when(col("some_value") <= 1.5, 1)
).withColumn(
"round",
sum("passed").over(Window.orderBy("order_clause"))
)
result.show
+------------+----------+------+-----+
|order_clause|some_value|passed|round|
+------------+----------+------+-----+
| 1| 0| 1| 1|
| 2| 2| null| 1|
| 3| 3| null| 1|
| 4| 0| 1| 2|
| 5| 3| null| 2|
| 6| 4| null| 2|
+------------+----------+------+-----+
Or, more simply:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum, when}
val result = userData.withColumn(
"round",
sum(when(col("some_value") <= 1.5, 1)).over(Window.orderBy("order_clause"))
)
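Once round exists, the subsets the question asks for are plain filters or groupings. Below is a minimal sketch that assumes a local SparkSession. The names `Rounds` and `withRound` are hypothetical. As in the answer above, Window.orderBy without partitionBy pulls all rows into a single partition, so this only suits small data.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object Rounds {
  // Running count of rows with some_value <= 1.5; each such row starts a new round.
  def withRound(df: DataFrame): DataFrame = {
    val upToCurrent = Window.orderBy("order_clause")
      .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    df.withColumn("round", sum(when(col("some_value") <= 1.5, 1)).over(upToCurrent))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rounds").getOrCreate()
    import spark.implicits._

    val userData = Seq((1, 0), (2, 2), (3, 3), (4, 0), (5, 3), (6, 4))
      .toDF("order_clause", "some_value")
    val rounds = withRound(userData)

    // One subset: only the rows of the second round.
    rounds.filter($"round" === 2).show()
    // Per-round summary: row count and the order_clause span of each round.
    rounds.groupBy("round")
      .agg(count("*").as("rows"), min("order_clause").as("first"), max("order_clause").as("last"))
      .orderBy("round")
      .show()
    spark.stop()
  }
}
```

Rows before the first some_value <= 1.5 would get a null round, because a sum over only nulls is null. Wrap the sum in coalesce(..., lit(0)) if that case can occur.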

Spark window functions: first match in window

I'm trying to extend the results of my previous question, but haven't been able to figure out how to achieve my new goal.
Before, I wanted to key on either a flag match or a string match. Now, I want to create a unique grouping key from a run starting with either a flag being true or the first string match preceding a run of true flag values.
Here's some example data:
val msgList = List("b", "f")
val df = spark.createDataFrame(Seq(("a", false), ("b", false), ("c", false), ("b", false), ("c", true), ("d", false), ("e", true), ("f", true), ("g", false)))
.toDF("message", "flag")
.withColumn("index", monotonically_increasing_id)
df.show
+-------+-----+-----+
|message| flag|index|
+-------+-----+-----+
| a|false| 0|
| b|false| 1|
| c|false| 2|
| b|false| 3|
| c| true| 4|
| d|false| 5|
| e| true| 6|
| f| true| 7|
| g|false| 8|
+-------+-----+-----+
The desired output is something equivalent to either of key1 or key2:
+-------+-----+-----+-----+-----+
|message| flag|index| key1| key2|
+-------+-----+-----+-----+-----+
| a|false| 0| 0| null|
| b|false| 1| 1| 1|
| c|false| 2| 1| 1|
| b|false| 3| 1| 1|
| c| true| 4| 1| 1|
| d|false| 5| 2| null|
| e| true| 6| 3| 2|
| f| true| 7| 3| 2|
| g|false| 8| 4| null|
+-------+-----+-----+-----+-----+
From the answer to my previous question, I already have a precursor:
import org.apache.spark.sql.expressions.Window
val checkMsg = udf { (s: String) => s != null && msgList.exists(s.contains(_)) }
val df2 = df.withColumn("message_match", checkMsg($"message"))
.withColumn("match_or_flag", when($"message_match" || $"flag", 1).otherwise(0))
.withColumn("lead", lead("match_or_flag", -1, 1).over(Window.orderBy("index")))
.withColumn("switched", when($"match_or_flag" =!= $"lead", $"index"))
.withColumn("base_key", last("switched", ignoreNulls = true).over(Window.orderBy("index").rowsBetween(Window.unboundedPreceding, 0)))
df2.show
+-------+-----+-----+-------------+-------------+----+--------+--------+
|message| flag|index|message_match|match_or_flag|lead|switched|base_key|
+-------+-----+-----+-------------+-------------+----+--------+--------+
| a|false| 0| false| 0| 1| 0| 0|
| b|false| 1| true| 1| 0| 1| 1|
| c|false| 2| false| 0| 1| 2| 2|
| b|false| 3| true| 1| 0| 3| 3|
| c| true| 4| false| 1| 1| null| 3|
| d|false| 5| false| 0| 1| 5| 5|
| e| true| 6| false| 1| 0| 6| 6|
| f| true| 7| true| 1| 1| null| 6|
| g|false| 8| false| 0| 1| 8| 8|
+-------+-----+-----+-------------+-------------+----+--------+--------+
base_key here is somewhat close to key1 above, but it assigns separate keys to row 1, row 2 and rows 3-4. I want rows 1-4 to get a single key, because row 1 contains the first msgList match within or preceding a run of flag = true.
Looking at the Spark window function API, it looks like there might be some way to use rangeBetween to accomplish this as of Spark 2.3.0, but the docs are bare enough that I haven't been able to figure out how to make it work.
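No answer to this question is included here. What follows is one possible approach. It is not from the original thread and has only been checked by hand against the example data. The idea is to cut the rows into segments that each end with a run of flag = true. Each segment's group then starts at the earlier of its first msgList match and its first true flag.

The names `RunKeys`, `addKeys` and the intermediate columns are hypothetical. The sketch replaces the checkMsg UDF with Column.contains. Like the precursor, it relies on unpartitioned windows ordered by index.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object RunKeys {
  // Hypothetical helper: expects columns message (string), flag (boolean), index (ordering).
  def addKeys(df: DataFrame, msgList: Seq[String]): DataFrame = {
    val ordered = Window.orderBy("index")
    val strictlyBefore = ordered.rowsBetween(Window.unboundedPreceding, -1)

    // Same test as the checkMsg UDF, without a UDF: message contains any msgList entry.
    val isMatch = msgList
      .map(m => coalesce(col("message").contains(m), lit(false)))
      .foldLeft(lit(false))(_ || _)

    val segmented = df
      .withColumn("message_match", isMatch)
      // A run of flag = true ends where the next row is unflagged (or there is no next row).
      .withColumn("run_end", col("flag") && !coalesce(lead("flag", 1).over(ordered), lit(false)))
      // Segment id = number of run ends strictly before this row, so a segment holds
      // some unflagged rows followed by at most one run.
      .withColumn("seg", coalesce(sum(when(col("run_end"), 1).otherwise(0)).over(strictlyBefore), lit(0L)))

    val bySeg = Window.partitionBy("seg")
    val grouped = segmented
      .withColumn("first_flag", min(when(col("flag"), col("index"))).over(bySeg))
      .withColumn("first_match", min(when(col("message_match"), col("index"))).over(bySeg))
      // The group starts at the earlier of the two (least skips nulls);
      // a segment without a run (first_flag is null) gets no group.
      .withColumn("group_start", when(col("first_flag").isNotNull, least(col("first_flag"), col("first_match"))))
      .withColumn("in_group", coalesce(col("index") >= col("group_start"), lit(false)))

    grouped
      // key1: dense numbering of (segment, in_group) blocks, starting at 0.
      .withColumn("key1", dense_rank().over(Window.orderBy("seg", "in_group")) - 1)
      // key2: numbering of the run groups only, null elsewhere.
      .withColumn("key2", when(col("in_group"), dense_rank().over(Window.partitionBy("in_group").orderBy("seg"))))
      .drop("message_match", "run_end", "seg", "first_flag", "first_match", "group_start", "in_group")
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("run-keys").getOrCreate()
    import spark.implicits._
    val df = Seq(("a", false), ("b", false), ("c", false), ("b", false), ("c", true),
        ("d", false), ("e", true), ("f", true), ("g", false))
      .toDF("message", "flag")
      .withColumn("index", monotonically_increasing_id())
    addKeys(df, Seq("b", "f")).orderBy("index").show()
    spark.stop()
  }
}
```

Traced by hand on the example data, this gives key1 = 0, 1, 1, 1, 1, 2, 3, 3, 4 and key2 = null, 1, 1, 1, 1, null, 2, 2, null. Only the ordering of index is used, so non-consecutive monotonically_increasing_id values are fine.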

Creating a unique grouping key from column-wise runs in a Spark DataFrame

I have something analogous to this, where spark is my SparkSession. I've imported spark.implicits._ so I can use the $ syntax:
val df = spark.createDataFrame(Seq(("a", 0L), ("b", 1L), ("c", 1L), ("d", 1L), ("e", 0L), ("f", 1L)))
.toDF("id", "flag")
.withColumn("index", monotonically_increasing_id)
.withColumn("run_key", when($"flag" === 1, $"index").otherwise(0))
df.show
df: org.apache.spark.sql.DataFrame = [id: string, flag: bigint ... 2 more fields]
+---+----+-----+-------+
| id|flag|index|run_key|
+---+----+-----+-------+
| a| 0| 0| 0|
| b| 1| 1| 1|
| c| 1| 2| 2|
| d| 1| 3| 3|
| e| 0| 4| 0|
| f| 1| 5| 5|
+---+----+-----+-------+
I want to create another column with a unique grouping key for each nonzero chunk of run_key, something equivalent to this:
+---+----+-----+-------+---+
| id|flag|index|run_key|key|
+---+----+-----+-------+---+
| a| 0| 0| 0| 0|
| b| 1| 1| 1| 1|
| c| 1| 2| 2| 1|
| d| 1| 3| 3| 1|
| e| 0| 4| 0| 0|
| f| 1| 5| 5| 2|
+---+----+-----+-------+---+
It could be the first value in each run, average of each run, or some other value -- it doesn't really matter as long as it's guaranteed to be unique so that I can group on it afterward to compare other values between groups.
Edit: BTW, I don't need to retain the rows where flag is 0.
One approach would be to 1) create a column $"lag1" using the Window function lag() on $"flag", 2) create another column $"switched" holding the $"index" value in rows where $"flag" changes, and finally 3) create the key column, which copies $"switched" from the last non-null row via last() and rowsBetween().
Note that this solution uses a Window function without partitioning, so Spark moves all rows into a single partition. It may therefore not work for large datasets.
val df = Seq(
("a", 0L), ("b", 1L), ("c", 1L), ("d", 1L), ("e", 0L), ("f", 1L),
("g", 1L), ("h", 0L), ("i", 0L), ("j", 1L), ("k", 1L), ("l", 1L)
).toDF("id", "flag").
withColumn("index", monotonically_increasing_id).
withColumn("run_key", when($"flag" === 1, $"index").otherwise(0))
import org.apache.spark.sql.expressions.Window
df.withColumn( "lag1", lag("flag", 1, -1).over(Window.orderBy("index")) ).
withColumn( "switched", when($"flag" =!= $"lag1", $"index") ).
withColumn( "key", last("switched", ignoreNulls = true).over(
Window.orderBy("index").rowsBetween(Window.unboundedPreceding, 0)
) )
// +---+----+-----+-------+----+--------+---+
// | id|flag|index|run_key|lag1|switched|key|
// +---+----+-----+-------+----+--------+---+
// | a| 0| 0| 0| -1| 0| 0|
// | b| 1| 1| 1| 0| 1| 1|
// | c| 1| 2| 2| 1| null| 1|
// | d| 1| 3| 3| 1| null| 1|
// | e| 0| 4| 0| 1| 4| 4|
// | f| 1| 5| 5| 0| 5| 5|
// | g| 1| 6| 6| 1| null| 5|
// | h| 0| 7| 0| 1| 7| 7|
// | i| 0| 8| 0| 0| null| 7|
// | j| 1| 9| 9| 0| 9| 9|
// | k| 1| 10| 10| 1| null| 9|
// | l| 1| 11| 11| 1| null| 9|
// +---+----+-----+-------+----+--------+---+
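If the goal is to group the runs afterwards, a consecutive run number can be handier than the index-based key. One way to get it is a running count of the rows where flag changes. The sketch below takes that assumption and adds the grouping step the question mentions. The names `RunGroups` and `withRunId` are hypothetical, and the window is again unpartitioned.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object RunGroups {
  // run_id increments every time flag differs from the previous row's flag.
  def withRunId(df: DataFrame): DataFrame = {
    val ordered = Window.orderBy("index")
    val upToCurrent = ordered.rowsBetween(Window.unboundedPreceding, Window.currentRow)
    df.withColumn("lag1", lag("flag", 1, -1L).over(ordered))
      .withColumn("run_id", sum(when(col("flag") =!= col("lag1"), 1).otherwise(0)).over(upToCurrent))
      .drop("lag1")
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("run-groups").getOrCreate()
    import spark.implicits._
    val df = Seq(
      ("a", 0L), ("b", 1L), ("c", 1L), ("d", 1L), ("e", 0L), ("f", 1L),
      ("g", 1L), ("h", 0L), ("i", 0L), ("j", 1L), ("k", 1L), ("l", 1L)
    ).toDF("id", "flag").withColumn("index", monotonically_increasing_id())

    // Keep only the flag = 1 runs and group on the run id.
    withRunId(df).filter($"flag" === 1)
      .groupBy("run_id")
      .agg(count("*").as("rows"), collect_list("id").as("ids"))
      .orderBy("run_id")
      .show()
    spark.stop()
  }
}
```

On the twelve-row data above, the flag = 1 runs get run_id 2 (b, c, d), 4 (f, g) and 6 (j, k, l).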
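You can label each run with the largest index at which flag is 0 that is still smaller than the index of the row in question. Something like:
import org.apache.spark.sql.functions.max
// Indices of the rows where flag is 0, renamed so the self-join is unambiguous
val flags = df.filter($"flag" === 0).select($"index".as("flagIndex"))
// For each row, the largest flag-0 index strictly smaller than its own index
val indices = df.select("index")
.join(flags, $"index" > $"flagIndex")
.groupBy($"index")
.agg(max($"flagIndex").as("groupKey"))
// Left join keeps a leading run with no earlier flag-0 row (its groupKey is null);
// flag-0 rows are dropped because they would share the previous run's key and are not needed
val dfWithGroups = df.join(indices, Seq("index"), "left").filter($"flag" === 1)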

Pivot scala dataframe with conditional counting

I would like to aggregate this DataFrame and count the number of observations with a value less than or equal to the "BUCKET" field for each level. For example:
val myDF = Seq(
("foo", 0),
("foo", 0),
("bar", 0),
("foo", 1),
("foo", 1),
("bar", 1),
("foo", 2),
("bar", 2),
("foo", 3),
("bar", 3)).toDF("COL1", "BUCKET")
myDF.show
+----+------+
|COL1|BUCKET|
+----+------+
| foo| 0|
| foo| 0|
| bar| 0|
| foo| 1|
| foo| 1|
| bar| 1|
| foo| 2|
| bar| 2|
| foo| 3|
| bar| 3|
+----+------+
I can count the number of observations matching each bucket value using this code:
myDF.groupBy("COL1").pivot("BUCKET").count.show
+----+---+---+---+---+
|COL1| 0| 1| 2| 3|
+----+---+---+---+---+
| bar| 1| 1| 1| 1|
| foo| 2| 2| 1| 1|
+----+---+---+---+---+
But I want to count the number of rows whose "BUCKET" value is less than or equal to the corresponding column header after pivoting, like this:
+----+---+---+---+---+
|COL1| 0| 1| 2| 3|
+----+---+---+---+---+
| bar| 1| 2| 3| 4|
| foo| 2| 4| 5| 6|
+----+---+---+---+---+
You can achieve this using a window function, as follows:
import org.apache.spark.sql.expressions.Window.partitionBy
import org.apache.spark.sql.functions.{count, first}
myDF.
select(
$"COL1",
$"BUCKET",
count($"BUCKET").over(partitionBy($"COL1").orderBy($"BUCKET")).as("ROLLING_COUNT")).
groupBy($"COL1").pivot("BUCKET").agg(first("ROLLING_COUNT")).
show()
+----+---+---+---+---+
|COL1| 0| 1| 2| 3|
+----+---+---+---+---+
| bar| 1| 2| 3| 4|
| foo| 2| 4| 5| 6|
+----+---+---+---+---+
What you are specifying here is a count of your observations over windows partitioned by a key (COL1 in this case). Adding an ordering makes the count rolling. With orderBy, the default frame runs from the start of the partition up to the current row and includes all rows with the same BUCKET value (a RANGE frame). That is why both foo rows with BUCKET 0 get 2. These rolling counts are then pivoted into the final result.
This is the result of applying the window function:
myDF.
select(
$"COL1",
$"BUCKET",
count($"BUCKET").over(partitionBy($"COL1").orderBy($"BUCKET")).as("ROLLING_COUNT")).
show()
+----+------+-------------+
|COL1|BUCKET|ROLLING_COUNT|
+----+------+-------------+
| bar| 0| 1|
| bar| 1| 2|
| bar| 2| 3|
| bar| 3| 4|
| foo| 0| 2|
| foo| 0| 2|
| foo| 1| 4|
| foo| 1| 4|
| foo| 2| 5|
| foo| 3| 6|
+----+------+-------------+
Finally, you group by COL1, pivot over BUCKET and take only the first rolling count of each (COL1, BUCKET) cell. Any of them would do, since rows with the same BUCKET are peers in the RANGE frame and share the same count. This gives the result you were looking for.
In a way, window functions are very similar to aggregations over groupings, but are more flexible and powerful. This just scratches the surface of window functions and you can dig a little bit deeper by having a look at this introductory reading.
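The same computation can be written with the frame spelled out explicitly, which makes the tie behaviour visible. This is a sketch that assumes a local SparkSession. The names `RollingPivot` and `cumulativeCounts` are hypothetical. It also swaps in max for first as the pivot aggregate, which is not part of the answer above. max stays correct even if the frame is changed to rowsBetween, where tied rows get different counts.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object RollingPivot {
  def cumulativeCounts(df: DataFrame): DataFrame = {
    // The frame Spark uses by default once orderBy is present, written out:
    // RANGE includes every row whose BUCKET is <= the current one, ties included.
    val w = Window.partitionBy("COL1").orderBy("BUCKET")
      .rangeBetween(Window.unboundedPreceding, Window.currentRow)
    df.withColumn("ROLLING_COUNT", count(col("BUCKET")).over(w))
      .groupBy("COL1")
      .pivot("BUCKET")
      // Under RANGE, rows sharing a BUCKET share one count, so max equals first here;
      // under ROWS they would differ and max would still pick the full count.
      .agg(max("ROLLING_COUNT"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rolling-pivot").getOrCreate()
    import spark.implicits._
    val myDF = Seq(("foo", 0), ("foo", 0), ("bar", 0), ("foo", 1), ("foo", 1),
      ("bar", 1), ("foo", 2), ("bar", 2), ("foo", 3), ("bar", 3)).toDF("COL1", "BUCKET")
    cumulativeCounts(myDF).orderBy("COL1").show()
    spark.stop()
  }
}
```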
Here's another approach. Pivot the plain counts first, then traverse the pivoted BUCKET value columns with foldLeft, adding each column's running total to the next. foldLeft carries a tuple of (DataFrame, String), so it transforms the DataFrame while also remembering the name of the previous column:
val pivotedDF = myDF.groupBy($"COL1").pivot("BUCKET").count
val buckets = pivotedDF.columns.filter(_ != "COL1")
buckets.drop(1).foldLeft((pivotedDF, buckets.head))( (acc, c) =>
( acc._1.withColumn(c, col(acc._2) + col(c)), c )
)._1.show
// +----+---+---+---+---+
// |COL1| 0| 1| 2| 3|
// +----+---+---+---+---+
// | bar| 1| 2| 3| 4|
// | foo| 2| 4| 5| 6|
// +----+---+---+---+---+
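Both answers assume every COL1 level has rows in every bucket. If a level is missing a bucket, its pivoted cell is null. In the foldLeft version, that null also propagates into every later column. Below is a sketch of the foldLeft version that guards against this, assuming the bucket values are known up front. The names `FoldPivot` and `cumulative`, and the extra baz rows, are hypothetical.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

object FoldPivot {
  def cumulative(df: DataFrame, buckets: Seq[Int]): DataFrame = {
    // Explicit pivot values fix the column order and create a column for every
    // bucket; cells with no rows come back null, so fill them with 0 first.
    val pivoted = df.groupBy("COL1").pivot("BUCKET", buckets).count().na.fill(0L)
    val names = buckets.map(_.toString)
    // foldLeft carries (DataFrame, name of the previous column), as in the answer above.
    names.drop(1).foldLeft((pivoted, names.head)) { case ((acc, prev), c) =>
      (acc.withColumn(c, col(prev) + col(c)), c)
    }._1
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("fold-pivot").getOrCreate()
    import spark.implicits._
    val df = Seq(("foo", 0), ("foo", 0), ("bar", 0), ("foo", 1), ("foo", 1), ("bar", 1),
      ("foo", 2), ("bar", 2), ("foo", 3), ("bar", 3), ("baz", 0), ("baz", 2)).toDF("COL1", "BUCKET")
    cumulative(df, Seq(0, 1, 2, 3)).orderBy("COL1").show()
    spark.stop()
  }
}
```

The hypothetical baz level has rows only in buckets 0 and 2. For it, this yields 1, 1, 2, 2 instead of nulls.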

spark sql conditional maximum

I have a tall table which contains up to 10 values per group. How can I transform this table into a wide format, i.e. add 2 columns, each holding the value that is smaller than or equal to a threshold?
I want to find the maximum per group, but it needs to be smaller than a specified value like:
min(max('value1), lit(5)).over(Window.partitionBy('grouping))
However, min() is an aggregate over a single column, so it does not seem to accept the result of the inner function together with a Scala value.
The problem can be described as:
Seq(Seq(1,2,3,4).max,5).min
Where Seq(1,2,3,4) is returned by the window.
How can I formulate this in spark sql?
edit
E.g.
+--------+-----+---------+
|grouping|value|something|
+--------+-----+---------+
| 1| 1| first|
| 1| 2| second|
| 1| 3| third|
| 1| 4| fourth|
| 1| 7| 7|
| 1| 10| 10|
| 21| 1| first|
| 21| 2| second|
| 21| 3| third|
+--------+-----+---------+
created by
case class MyThing(grouping: Int, value:Int, something:String)
val df = Seq(MyThing(1,1, "first"), MyThing(1,2, "second"), MyThing(1,3, "third"),MyThing(1,4, "fourth"),MyThing(1,7, "7"), MyThing(1,10, "10"),
MyThing(21,1, "first"), MyThing(21,2, "second"), MyThing(21,3, "third")).toDS
Where
df
.withColumn("somethingAtLeast5AndMaximum5", max('value).over(Window.partitionBy('grouping)))
.withColumn("somethingAtLeast6OupToThereshold2", max('value).over(Window.partitionBy('grouping)))
.show
returns
+--------+-----+---------+----------------------------+-------------------------+
|grouping|value|something|somethingAtLeast5AndMaximum5| somethingAtLeast6OupToThereshold2 |
+--------+-----+---------+----------------------------+-------------------------+
| 1| 1| first| 10| 10|
| 1| 2| second| 10| 10|
| 1| 3| third| 10| 10|
| 1| 4| fourth| 10| 10|
| 1| 7| 7| 10| 10|
| 1| 10| 10| 10| 10|
| 21| 1| first| 3| 3|
| 21| 2| second| 3| 3|
| 21| 3| third| 3| 3|
+--------+-----+---------+----------------------------+-------------------------+
Instead, I would like to write something like:
lit(Seq(max('value).asInstanceOf[java.lang.Integer], new java.lang.Integer(2)).min).over(Window.partitionBy('grouping))
But that does not work, because max('value) is a Column expression, not a scalar value.
The expected output should look like this:
+--------+-----+---------+----------------------------+-------------------------+
|grouping|value|something|somethingAtLeast5AndMaximum5|somethingAtLeast6OupToThereshold2|
+--------+-----+---------+----------------------------+-------------------------+
| 1| 4| fourth| 4| 7|
| 21| 1| first| 3| NULL|
+--------+-----+---------+----------------------------+-------------------------+
edit2
When trying a pivot
df.groupBy("grouping").pivot("value").agg(first('something)).show
+--------+-----+------+-----+------+----+----+
|grouping| 1| 2| 3| 4| 7| 10|
+--------+-----+------+-----+------+----+----+
| 1|first|second|third|fourth| 7| 10|
| 21|first|second|third| null|null|null|
+--------+-----+------+-----+------+----+----+
The remaining part of the problem is that some columns might not exist or might be null.
When aggregating to arrays:
df.groupBy("grouping").agg(collect_list('value).alias("value"), collect_list('something).alias("something"))
+--------+-------------------+--------------------+
|grouping| value| something|
+--------+-------------------+--------------------+
| 1|[1, 2, 3, 4, 7, 10]|[first, second, t...|
| 21| [1, 2, 3]|[first, second, t...|
+--------+-------------------+--------------------+
The values are already next to each other, but the right values need to be selected. This is probably still more efficient than a join or window function.
It would be easier to do this in two separate steps: calculate the max over a Window, then use when...otherwise on the result to produce min(x, 5):
df.withColumn("tmp", max('value1).over(Window.partitionBy('grouping)))
.withColumn("result", when('tmp > lit(5), 5).otherwise('tmp))
EDIT: some example data to clarify this:
val df = Seq((1, 1),(1, 2),(1, 3),(1, 4),(2, 7),(2, 8))
.toDF("grouping", "value1")
df.withColumn("result", max('value1).over(Window.partitionBy('grouping)))
.withColumn("result", when('result > lit(5), 5).otherwise('result))
.show()
// +--------+------+------+
// |grouping|value1|result|
// +--------+------+------+
// | 1| 1| 4| // 4, because Seq(Seq(1,2,3,4).max,5).min = 4
// | 1| 2| 4|
// | 1| 3| 4|
// | 1| 4| 4|
// | 2| 7| 5| // 5, because Seq(Seq(7,8).max,5).min = 5
// | 2| 8| 5|
// +--------+------+------+
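Capping the maximum at 5 gives 5 for grouping 1 in the question's data. The expected output there shows 4, which reads more like "the largest value that is <= 5". Under that reading, and assuming the second column means the smallest value >= 6, conditional aggregates handle both columns without a separate step. The capped version from this answer can also be written in one call with least(). Below is a sketch with hypothetical names (`ConditionalMax`, `summarize`):

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

object ConditionalMax {
  def summarize(df: DataFrame): DataFrame =
    df.groupBy("grouping").agg(
      // Largest value <= 5: other rows map to null, and max skips nulls.
      max(when(col("value") <= 5, col("value"))).as("maxUpTo5"),
      // Smallest value >= 6; null when the group has no such value.
      min(when(col("value") >= 6, col("value"))).as("minFrom6"),
      // min(max(value), 5) in a single expression, as in the answer above.
      least(max(col("value")), lit(5)).as("maxCappedAt5")
    )

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("conditional-max").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (1, 1, "first"), (1, 2, "second"), (1, 3, "third"), (1, 4, "fourth"),
      (1, 7, "7"), (1, 10, "10"),
      (21, 1, "first"), (21, 2, "second"), (21, 3, "third")
    ).toDF("grouping", "value", "something")

    summarize(df).orderBy("grouping").show()
    spark.stop()
  }
}
```

On the question's data this should give (1, 4, 7, 5) and (21, 3, null, 3). The same expressions also work as window columns if every row should be kept, e.g. max(when(col("value") <= 5, col("value"))).over(Window.partitionBy("grouping")).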