Adding column with sum of all rows above in same grouping - scala

I need to create a 'rolling count' column which takes the previous count and adds the new count for each day and company. I have already organized and sorted the dataframe into groups of ascending dates per company with the corresponding count. I also added a 'ix' column which indexes each grouping, like so:
+--------------------+--------------------+-----+---+
| Normalized_Date| company|count| ix|
+--------------------+--------------------+-----+---+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10|
+--------------------+--------------------+-----+---+
The output I need would simply add up all the counts up to that date for each company. Like so:
+--------------------+--------------------+-----+---+------------+
| Normalized_Date| company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1| 7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1| 9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1| 7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2| 67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3| 68|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4| 77|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5| 106|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6| 148|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7| 465|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8| 468|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9| 483|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10| 484|
+--------------------+--------------------+-----+---+------------+
I figured the lag function would be of use, and I was able to get each row of RollingCount with ix > 1 to add the count from the row directly above it with the following code:
w = Window.partitionBy('company').orderBy(F.unix_timestamp('Normalized_Date','MM/dd/yyyy HH:mm:ss aaa').cast('timestamp'))
refined_DF = solutionDF.withColumn("rn", F.row_number().over(w))
solutionDF = refined_DF.withColumn('RollingCount', F.when(refined_DF['rn'] > 1, refined_DF['count'] + F.lag(refined_DF['count'], count=1).over(w)).otherwise(refined_DF['count']))
which yields the following df:
+--------------------+--------------------+-----+---+------------+
| Normalized_Date| company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1| 7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1| 9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1| 7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2| 67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3| 61|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4| 10|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5| 38|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6| 71|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7| 359|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8| 320|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9| 18|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10| 16|
+--------------------+--------------------+-----+---+------------+
I just need it to sum all of the counts in the rows above it (within the same company). I have tried using a udf to figure out the 'count' input into the lag function, but I keep getting a "'Column' object is not callable" error, and it still wouldn't sum all of the preceding rows. I have also tried using a loop, but that seems impractical because it would create a new dataframe each time through, and I would need to join them all afterwards. There must be an easier and simpler way to do this. Perhaps a different function than lag?

lag only returns a single row at a fixed offset before the current row, but you need a whole range of rows to calculate the cumulative sum. Therefore you have to define the window frame with rangeBetween (or rowsBetween). Have a look at the example below:
import pyspark.sql.functions as F
from pyspark.sql import Window
l = [
('09/25/2018', '5c40c8510fb7c017', 7, 1),
('09/25/2018', '5bdb2b543951bf07', 9, 1),
('11/28/2017', '593b0d9f3f21f9dd', 7, 1),
('11/29/2017', '593b0d9f3f21f9dd', 60, 2),
('01/09/2018', '593b0d9f3f21f9dd', 1, 3),
('04/27/2018', '593b0d9f3f21f9dd', 9, 4),
('09/25/2018', '593b0d9f3f21f9dd', 29, 5),
('11/20/2018', '593b0d9f3f21f9dd', 42, 6),
('12/11/2018', '593b0d9f3f21f9dd', 317, 7),
('01/04/2019', '593b0d9f3f21f9dd', 3, 8),
('02/13/2019', '593b0d9f3f21f9dd', 15, 9),
('04/01/2019', '593b0d9f3f21f9dd', 1, 10)
]
columns = ['Normalized_Date', 'company','count', 'ix']
df=spark.createDataFrame(l, columns)
df = df.withColumn('Normalized_Date', F.to_date(df.Normalized_Date, 'MM/dd/yyyy'))
w = Window.partitionBy('company').orderBy('Normalized_Date').rangeBetween(Window.unboundedPreceding, 0)
df = df.withColumn('Rolling_count', F.sum('count').over(w))
df.show()
Output:
+---------------+----------------+-----+---+-------------+
|Normalized_Date| company|count| ix|Rolling_count|
+---------------+----------------+-----+---+-------------+
| 2018-09-25|5c40c8510fb7c017| 7| 1| 7|
| 2018-09-25|5bdb2b543951bf07| 9| 1| 9|
| 2017-11-28|593b0d9f3f21f9dd| 7| 1| 7|
| 2017-11-29|593b0d9f3f21f9dd| 60| 2| 67|
| 2018-01-09|593b0d9f3f21f9dd| 1| 3| 68|
| 2018-04-27|593b0d9f3f21f9dd| 9| 4| 77|
| 2018-09-25|593b0d9f3f21f9dd| 29| 5| 106|
| 2018-11-20|593b0d9f3f21f9dd| 42| 6| 148|
| 2018-12-11|593b0d9f3f21f9dd| 317| 7| 465|
| 2019-01-04|593b0d9f3f21f9dd| 3| 8| 468|
| 2019-02-13|593b0d9f3f21f9dd| 15| 9| 483|
| 2019-04-01|593b0d9f3f21f9dd| 1| 10| 484|
+---------------+----------------+-----+---+-------------+
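A quick note on the frame choice: rangeBetween defines the frame by the ordering value, so any rows sharing the same Normalized_Date fall into the frame together, while rowsBetween (used in the Scala answer below) counts physical rows. A minimal Scala sketch of the two variants, purely illustrative:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

// RANGE frame: bounded by the ordering value, so ties on Normalized_Date are summed together.
val byValue = Window.partitionBy("company").orderBy("Normalized_Date")
  .rangeBetween(Window.unboundedPreceding, Window.currentRow)

// ROWS frame: bounded by row position, strictly the current row and the rows before it.
val byPosition = Window.partitionBy("company").orderBy("Normalized_Date")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

// e.g. df.withColumn("Rolling_count", sum("count").over(byValue))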

Try this. You need the sum of all rows from the start of the partition up to and including the current row in the window frame.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.expressions.WindowSpec
import org.apache.spark.sql.functions._
val df = Seq(
("5c40c8510fb7c017", 7, 1),
("5bdb2b543951bf07", 9, 1),
("593b0d9f3f21f9dd", 7, 1),
("593b0d9f3f21f9dd", 60, 2),
("593b0d9f3f21f9dd", 1, 3),
("593b0d9f3f21f9dd", 9, 4),
("593b0d9f3f21f9dd", 29, 5),
("593b0d9f3f21f9dd", 42, 6),
("593b0d9f3f21f9dd", 317, 7),
("593b0d9f3f21f9dd", 3, 8),
("593b0d9f3f21f9dd", 15, 9),
("593b0d9f3f21f9dd", 1, 10)
).toDF("company", "count", "ix")
scala> df.show(false)
+----------------+-----+---+
|company |count|ix |
+----------------+-----+---+
|5c40c8510fb7c017|7 |1 |
|5bdb2b543951bf07|9 |1 |
|593b0d9f3f21f9dd|7 |1 |
|593b0d9f3f21f9dd|60 |2 |
|593b0d9f3f21f9dd|1 |3 |
|593b0d9f3f21f9dd|9 |4 |
|593b0d9f3f21f9dd|29 |5 |
|593b0d9f3f21f9dd|42 |6 |
|593b0d9f3f21f9dd|317 |7 |
|593b0d9f3f21f9dd|3 |8 |
|593b0d9f3f21f9dd|15 |9 |
|593b0d9f3f21f9dd|1 |10 |
+----------------+-----+---+
scala> val overColumns = Window.partitionBy("company").orderBy("ix").rowsBetween(Window.unboundedPreceding, Window.currentRow)
overColumns: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@3ed5e17c
scala> val outputDF = df.withColumn("RollingCount", sum("count").over(overColumns))
outputDF: org.apache.spark.sql.DataFrame = [company: string, count: int ... 2 more fields]
scala> outputDF.show(false)
+----------------+-----+---+------------+
|company |count|ix |RollingCount|
+----------------+-----+---+------------+
|5c40c8510fb7c017|7 |1 |7 |
|5bdb2b543951bf07|9 |1 |9 |
|593b0d9f3f21f9dd|7 |1 |7 |
|593b0d9f3f21f9dd|60 |2 |67 |
|593b0d9f3f21f9dd|1 |3 |68 |
|593b0d9f3f21f9dd|9 |4 |77 |
|593b0d9f3f21f9dd|29 |5 |106 |
|593b0d9f3f21f9dd|42 |6 |148 |
|593b0d9f3f21f9dd|317 |7 |465 |
|593b0d9f3f21f9dd|3 |8 |468 |
|593b0d9f3f21f9dd|15 |9 |483 |
|593b0d9f3f21f9dd|1 |10 |484 |
+----------------+-----+---+------------+
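A hedged side note: since your original frame is ordered by Normalized_Date rather than by the precomputed ix, the same rolling sum can also be taken directly over the parsed date, so the ix column isn't strictly needed. A sketch, assuming your date strings parse with the pattern "MM/dd/yyyy HH:mm:ss" (adjust to your real format) and that solutionDF is your original dataframe:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Order the window by the parsed date; the frame covers everything up to the current row.
val byDate = Window.partitionBy("company")
  .orderBy(to_timestamp(col("Normalized_Date"), "MM/dd/yyyy HH:mm:ss"))
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val rollingDF = solutionDF.withColumn("RollingCount", sum("count").over(byDate))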

Related

Selecting rows by data corresponding to other rows of the same dataframe

I'm struggling to select the rows of my dataframe. The selection depends on the data inside the same dataframe.
My dataset looks something like this:
from pyspark.sql.session import SparkSession
sc = SparkSession.builder.getOrCreate()
columns = ['Id', 'ActorId', 'EventId', 'Time']
vals = [(3, 3, 'START', '2020-06-22'),
(4, 3, 'END', '2020-06-24'),
(5, 3, 'OTHER', '2019-01-15'),
(6, 3, 'OTHER', '2020-07-24'),
(7, 3, 'OTHER', '2020-06-23'),
(8, 4, 'START', '2018-01-15'),
(9, 4, 'END', '2019-01-14'),
(10, 4, 'OTHER', '2018-11-14')]
events = sc.createDataFrame(vals,columns)
events.show()
Which results in:
+---+-------+-------+----------+
| Id|ActorId|EventId| Time|
+---+-------+-------+----------+
| 3| 3| START|2020-06-22|
| 4| 3| END|2020-06-24|
| 5| 3| OTHER|2019-01-15|
| 6| 3| OTHER|2020-07-24|
| 7| 3| OTHER|2020-06-23|
| 8| 4| START|2018-01-15|
| 9| 4| END|2019-01-14|
| 10| 4| OTHER|2018-11-14|
+---+-------+-------+----------+
(Bear in mind that this is just an example, an extract of the data.)
I want to find all rows with EventId == OTHER where the Time is not between the START and END events of the same ActorId.
The result should look like:
+---+-------+-------+----------+
| Id|ActorId|EventID| Time|
+---+-------+-------+----------+
| 5| 3| OTHER|2019-01-15|
| 6| 3| OTHER|2020-07-24|
+---+-------+-------+----------+
Thank you for your help!!!
This will solve your problem. There is only one assumption in the code below: that START and END in the EventId column always appear in the first and second row of each group.
from pyspark.sql import functions as F, types as T
from pyspark.sql import Window as W

_w = W.partitionBy('ActorId').orderBy('ActorId')
events = events.withColumn('start_date', F.first('Time').over(_w))
events = events.withColumn('row_num', F.row_number().over(_w))
events = events.withColumn('end_date', F.when(F.col('row_num') == F.lit('2'), F.col('Time')))
events = events.withColumn('end_date', F.coalesce(F.when(F.col('row_num') == F.lit('2'), F.col('Time')), F.min('end_date').over(_w)))
events = events.withColumn('passed_col', F.when(
    ((F.col('Time').cast(T.TimestampType()) > F.col('start_date').cast(T.TimestampType())) &
     (F.col('Time').cast(T.TimestampType()) > F.col('end_date').cast(T.TimestampType()))) |
    ((F.col('Time').cast(T.TimestampType()) < F.col('start_date').cast(T.TimestampType())) &
     (F.col('Time').cast(T.TimestampType()) < F.col('end_date').cast(T.TimestampType()))),
    F.lit("Passed")))
events = events.select('Id', 'ActorId', 'EventId', 'Time', 'passed_col')
events.show()
+---+-------+-------+----------+----------+
| Id|ActorId|EventId| Time|passed_col|
+---+-------+-------+----------+----------+
| 3| 3| START|2020-06-22| null|
| 4| 3| END|2020-06-24| null|
| 5| 3| OTHER|2019-01-15| Passed|
| 6| 3| OTHER|2020-07-24| Passed|
| 7| 3| OTHER|2020-06-23| null|
| 8| 4| START|2018-01-15| null|
| 9| 4| END|2019-01-14| null|
| 10| 4| OTHER|2018-11-14| null|
+---+-------+-------+----------+----------+
Final answer, after filtering:
events = events.filter(F.col('passed_col') == F.lit('Passed')).select('Id', 'ActorId', 'EventId', 'Time', 'passed_col')
events.show()
+---+-------+-------+----------+----------+
| Id|ActorId|EventId| Time|passed_col|
+---+-------+-------+----------+----------+
| 5| 3| OTHER|2019-01-15| Passed|
| 6| 3| OTHER|2020-07-24| Passed|
+---+-------+-------+----------+----------+
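A hedged aside on the assumption above: the START and END times can also be derived per ActorId with a conditional min/max instead of relying on row order. A Scala sketch of that idea (not part of the answer above; it assumes at most one START and one END per ActorId and a DataFrame named events with the question's columns):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Pull each actor's START and END times across the partition, then keep the OTHER rows
// whose Time falls outside that range. ISO yyyy-MM-dd strings compare correctly as text.
val w = Window.partitionBy("ActorId")
val outside = events
  .withColumn("start_date", min(when(col("EventId") === "START", col("Time"))).over(w))
  .withColumn("end_date", max(when(col("EventId") === "END", col("Time"))).over(w))
  .filter(col("EventId") === "OTHER" &&
    (col("Time") < col("start_date") || col("Time") > col("end_date")))
  .select("Id", "ActorId", "EventId", "Time")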
val res = events
.filter('EventId.equalTo("OTHER"))
.filter('ActorId.equalTo(3))
.filter(!'Time.between("2020-06-01","2020-06-25"))
res.show(false)
// +---+-------+-------+----------+
// |Id |ActorId|EventId|Time |
// +---+-------+-------+----------+
// |5 |3 |OTHER |2019-01-15|
// |6 |3 |OTHER |2020-07-24|
// +---+-------+-------+----------+
or
val res = events
.filter('EventId.equalTo("OTHER"))
.filter(!'Time.between("2018-01-01","2018-12-31"))
.filter(!'Time.between("2020-06-01","2020-06-25"))

Set literal value over Window if condition suited Spark Scala

I need to check a condition over a window:
- If the column IND_DEF is 20, then I want to change the value of the column premium for the whole window that row belongs to, setting it to 1.
My initial Dataframe looks like this:
+--------+----+-------+-----+-------+
|policyId|name|premium|state|IND_DEF|
+--------+----+-------+-----+-------+
| 1| BK| null| KT| 40|
| 1| AK| -31| null| 30|
| 1| VZ| null| IL| 20|
| 2| VK| 32| LI| 7|
| 2| CK| 25| YNZ| 10|
| 2| CK| 0| null| 5|
| 2| VK| 30| IL| 25|
+--------+----+-------+-----+-------+
And I want to achieve this:
+--------+----+-------+-----+-------+
|policyId|name|premium|state|IND_DEF|
+--------+----+-------+-----+-------+
| 1| BK| 1| KT| 40|
| 1| AK| 1| null| 30|
| 1| VZ| 1| IL| 20|
| 2| VK| 32| LI| 7|
| 2| CK| 25| YNZ| 10|
| 2| CK| 0| null| 5|
| 2| VK| 30| IL| 25|
+--------+----+-------+-----+-------+
I am trying the following code, but it does not work:
val df_946 = Seq [(Int, String, Integer, String, Int)]((1,"VZ",null,"IL",20),(1, "AK", -31,null,30),(1,"BK", null,"KT",40),(2,"CK",0,null,5),(2,"CK",25,"YNZ",10),(2,"VK",30,"IL",25),(2,"VK",32,"LI",7)).toDF("policyId", "name", "premium", "state","IND_DEF").orderBy("policyId")
val winSpec = Window.partitionBy("policyId").orderBy("policyId")
val df_947 = df_946.withColumn("premium",when(col("IND_DEF") === 20,lit(1).over(winSpec)).otherwise(col("premium")))
You can generate an array of IND_DEF values via collect_list for each window partition and recreate column premium based on the array_contains condition:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._
val df = Seq(
(1, None, 40),
(1, Some(-31), 30),
(1, None, 20),
(2, Some(32), 7),
(2, Some(30), 10)
).toDF("policyId", "premium", "IND_DEF")
val win = Window.partitionBy($"policyId")
df.
withColumn("indList", collect_list($"IND_DEF").over(win)).
withColumn("premium", when(array_contains($"indList", 20), 1).otherwise($"premium")).
drop($"indList").
show
// +--------+-------+-------+
// |policyId|premium|IND_DEF|
// +--------+-------+-------+
// | 1| 1| 40|
// | 1| 1| 30|
// | 1| 1| 20|
// | 2| 32| 7|
// | 2| 30| 10|
// +--------+-------+-------+

How do I replace null values of multiple columns with values from multiple different columns

I have a data frame like below
data = [
(1, None,7,10,11,19),
(1, 4,None,10,43,58),
(None, 4,7,67,88,91),
(1, None,7,78,96,32)
]
df = spark.createDataFrame(data, ["A_min", "B_min","C_min","A_max", "B_max","C_max"])
df.show()
I want the null values in the 'min' columns to be replaced by the value from the corresponding 'max' column.
For example, null values of the A_min column should be replaced by the A_max column.
The result should look like the data frame below.
+-----+-----+-----+-----+-----+-----+
|A_min|B_min|C_min|A_max|B_max|C_max|
+-----+-----+-----+-----+-----+-----+
| 1| 11| 7| 10| 11| 19|
| 1| 4| 58| 10| 43| 58|
| 67| 4| 7| 67| 88| 91|
| 1| 96| 7| 78| 96| 32|
+-----+-----+-----+-----+-----+-----+
I have tried the code below, defining the column lists, but clearly this does not work. I'd really appreciate any help.
min_cols = ["A_min", "B_min","C_min"]
max_cols = ["A_max", "B_max","C_max"]
for i in min_cols
df = df.withColumn(i,when(f.col(i)=='',max_cols.otherwise(col(i))))
display(df)
Assuming you have the same number of min and max columns, you can use coalesce along with a Python list comprehension to obtain your solution:
from pyspark.sql.functions import coalesce
min_cols = ["A_min", "B_min","C_min"]
max_cols = ["A_max", "B_max","C_max"]
df.select(*[coalesce(df[val], df[max_cols[pos]]).alias(val) for pos, val in enumerate(min_cols)], *max_cols).show()
Output:
+-----+-----+-----+-----+-----+-----+
|A_min|B_min|C_min|A_max|B_max|C_max|
+-----+-----+-----+-----+-----+-----+
| 1| 11| 7| 10| 11| 19|
| 1| 4| 58| 10| 43| 58|
| 67| 4| 7| 67| 88| 91|
| 1| 96| 7| 78| 96| 32|
+-----+-----+-----+-----+-----+-----+

Dataframe get first and last value of corresponding column

Is it possible to get the first value of the corresponding column within a subgroup?
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.{Window, WindowSpec}
object tmp {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
val input = Seq(
(1235, 1, 1101, 0),
(1235, 2, 1102, 0),
(1235, 3, 1103, 1),
(1235, 4, 1104, 1),
(1235, 5, 1105, 0),
(1235, 6, 1106, 0),
(1235, 7, 1107, 1),
(1235, 8, 1108, 1),
(1235, 9, 1109, 1),
(1235, 10, 1110, 0),
(1235, 11, 1111, 0)
).toDF("SERVICE_ID", "COUNTER", "EVENT_ID", "FLAG")
lazy val window: WindowSpec = Window.partitionBy("SERVICE_ID").orderBy("COUNTER")
val firsts = input.withColumn("first_value", first("EVENT_ID", ignoreNulls = true).over(window.rangeBetween(Long.MinValue, Long.MaxValue)))
firsts.orderBy("SERVICE_ID", "COUNTER").show()
}
}
The output I want:
First (or previous) value of column EVENT_ID based on FLAG = 1,
and last (or next) value of column EVENT_ID based on FLAG = 1,
partitioned by SERVICE_ID and sorted by COUNTER.
+----------+-------+--------+----+-----------+-----------+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG|first_value|last_value|
+----------+-------+--------+----+-----------+-----------+
| 1235| 1| 1101| 0| 0| 1103|
| 1235| 2| 1102| 0| 0| 1103|
| 1235| 3| 1103| 1| 0| 1106|
| 1235| 4| 1104| 0| 1103| 1106|
| 1235| 5| 1105| 0| 1103| 1106|
| 1235| 6| 1106| 1| 0| 1108|
| 1235| 7| 1107| 0| 1106| 1108|
| 1235| 8| 1108| 1| 0| 1109|
| 1235| 9| 1109| 1| 0| 1110|
| 1235| 10| 1110| 1| 0| 0|
| 1235| 11| 1111| 0| 1110| 0|
| 1235| 12| 1112| 0| 1110| 0|
+----------+-------+--------+----+-----------+-----------+
First the dataframe needs to be formed into groups. A new group starts each time the "FLAG" column equals 1. To do this, first add a column "ID" to the dataframe:
lazy val window: WindowSpec = Window.partitionBy("SERVICE_ID").orderBy("COUNTER")
val df_flag = input.filter($"FLAG" === 1)
.withColumn("ID", row_number().over(window))
val df_other = input.filter($"FLAG" =!= 1)
.withColumn("ID", lit(0))
// Create a group for each flag event
val df = df_flag.union(df_other)
.withColumn("ID", max("ID").over(window.rowsBetween(Long.MinValue, 0)))
.cache()
df.show() gives:
+----------+-------+--------+----+---+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG| ID|
+----------+-------+--------+----+---+
| 1235| 1| 1111| 1| 1|
| 1235| 2| 1112| 0| 1|
| 1235| 3| 1114| 0| 1|
| 1235| 4| 2221| 1| 2|
| 1235| 5| 2225| 0| 2|
| 1235| 6| 2226| 0| 2|
| 1235| 7| 2227| 1| 3|
+----------+-------+--------+----+---+
Now that we have a column separating the events, we need to add the correct "EVENT_ID" (renamed "first_value") to each event. In addition to the "first_value", calculate and add a second column "last_value", which is the id of the next flagged event.
val df_event = df.filter($"FLAG" === 1)
.select("EVENT_ID", "ID", "SERVICE_ID", "COUNTER")
.withColumnRenamed("EVENT_ID", "first_value")
.withColumn("last_value", lead($"first_value",1,0).over(window))
.drop("COUNTER")
val df_final = df.join(df_event, Seq("ID", "SERVICE_ID"))
.drop("ID")
.withColumn("first_value", when($"FLAG" === 1, lit(0)).otherwise($"first_value"))
df_final.show() gives us:
+----------+-------+--------+----+-----------+----------+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG|first_value|last_value|
+----------+-------+--------+----+-----------+----------+
| 1235| 1| 1111| 1| 0| 2221|
| 1235| 2| 1112| 0| 1111| 2221|
| 1235| 3| 1114| 0| 1111| 2221|
| 1235| 4| 2221| 1| 0| 2227|
| 1235| 5| 2225| 0| 2221| 2227|
| 1235| 6| 2226| 0| 2221| 2227|
| 1235| 7| 2227| 1| 0| 0|
+----------+-------+--------+----+-----------+----------+
This can be solved in two steps:
1. get the events with "FLAG" == 1 and the valid range for each such event;
2. join step 1 with the input, by range.
Some column renaming is included for readability and can be shortened:
val window = Window.partitionBy("SERVICE_ID").orderBy("COUNTER").rowsBetween(Window.currentRow, 1)
val eventRangeDF = input.where($"FLAG" === 1)
.withColumn("RANGE_END", max($"COUNTER").over(window))
.withColumnRenamed("COUNTER", "RANGE_START")
.select("SERVICE_ID", "EVENT_ID", "RANGE_START", "RANGE_END")
eventRangeDF.show(false)
val result = input.where($"FLAG" === 0).as("i").join(eventRangeDF.as("e"),
expr("e.SERVICE_ID=i.SERVICE_ID And i.COUNTER>e.RANGE_START and i.COUNTER<e.RANGE_END"))
.select($"i.SERVICE_ID", $"i.COUNTER", $"i.EVENT_ID", $"i.FLAG", $"e.EVENT_ID".alias("first_value"))
// include FLAG=1
.union(input.where($"FLAG" === 1).select($"SERVICE_ID", $"COUNTER", $"EVENT_ID", $"FLAG", lit(0).alias("first_value")))
result.sort("COUNTER").show(false)
Output:
+----------+--------+-----------+---------+
|SERVICE_ID|EVENT_ID|RANGE_START|RANGE_END|
+----------+--------+-----------+---------+
|1235 |1111 |1 |4 |
|1235 |2221 |4 |7 |
|1235 |2227 |7 |7 |
+----------+--------+-----------+---------+
+----------+-------+--------+----+-----------+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG|first_value|
+----------+-------+--------+----+-----------+
|1235 |1 |1111 |1 |0 |
|1235 |2 |1112 |0 |1111 |
|1235 |3 |1114 |0 |1111 |
|1235 |4 |2221 |1 |0 |
|1235 |5 |2225 |0 |2221 |
|1235 |6 |2226 |0 |2221 |
|1235 |7 |2227 |1 |0 |
+----------+-------+--------+----+-----------+

spark sql conditional maximum

I have a tall table which contains up to 10 values per group. How can I transform this table into a wide format, i.e. add 2 columns which hold the value smaller than or equal to a threshold?
I want to find the maximum per group, but it needs to be smaller than a specified value like:
min(max('value1), lit(5)).over(Window.partitionBy('grouping))
However, min() will only work on a column, not on the Scala value returned from the inner function.
The problem can be described as:
Seq(Seq(1,2,3,4).max,5).min
Where Seq(1,2,3,4) is returned by the window.
How can I formulate this in spark sql?
edit
E.g.
+--------+-----+---------+
|grouping|value|something|
+--------+-----+---------+
| 1| 1| first|
| 1| 2| second|
| 1| 3| third|
| 1| 4| fourth|
| 1| 7| 7|
| 1| 10| 10|
| 21| 1| first|
| 21| 2| second|
| 21| 3| third|
+--------+-----+---------+
created by
case class MyThing(grouping: Int, value:Int, something:String)
val df = Seq(MyThing(1,1, "first"), MyThing(1,2, "second"), MyThing(1,3, "third"),MyThing(1,4, "fourth"),MyThing(1,7, "7"), MyThing(1,10, "10"),
MyThing(21,1, "first"), MyThing(21,2, "second"), MyThing(21,3, "third")).toDS
Where
df
.withColumn("somethingAtLeast5AndMaximum5", max('value).over(Window.partitionBy('grouping)))
.withColumn("somethingAtLeast6OupToThereshold2", max('value).over(Window.partitionBy('grouping)))
.show
returns
+--------+-----+---------+----------------------------+-------------------------+
|grouping|value|something|somethingAtLeast5AndMaximum5| somethingAtLeast6OupToThereshold2 |
+--------+-----+---------+----------------------------+-------------------------+
| 1| 1| first| 10| 10|
| 1| 2| second| 10| 10|
| 1| 3| third| 10| 10|
| 1| 4| fourth| 10| 10|
| 1| 7| 7| 10| 10|
| 1| 10| 10| 10| 10|
| 21| 1| first| 3| 3|
| 21| 2| second| 3| 3|
| 21| 3| third| 3| 3|
+--------+-----+---------+----------------------------+-------------------------+
Instead, I would rather formulate:
lit(Seq(max('value).asInstanceOf[java.lang.Integer], new java.lang.Integer(2)).min).over(Window.partitionBy('grouping))
But that does not work as max('value) is not a scalar value.
Expected output should look like
+--------+-----+---------+----------------------------+-------------------------+
|grouping|value|something|somethingAtLeast5AndMaximum5|somethingAtLeast6OupToThereshold2|
+--------+-----+---------+----------------------------+-------------------------+
| 1| 4| fourth| 4| 7|
| 21| 1| first| 3| NULL|
+--------+-----+---------+----------------------------+-------------------------+
edit2
When trying a pivot
df.groupBy("grouping").pivot("value").agg(first('something)).show
+--------+-----+------+-----+------+----+----+
|grouping| 1| 2| 3| 4| 7| 10|
+--------+-----+------+-----+------+----+----+
| 1|first|second|third|fourth| 7| 10|
| 21|first|second|third| null|null|null|
+--------+-----+------+-----+------+----+----+
The second part of the problem remains that some columns might not exist or be null.
When aggregating to arrays:
df.groupBy("grouping").agg(collect_list('value).alias("value"), collect_list('something).alias("something"))
+--------+-------------------+--------------------+
|grouping| value| something|
+--------+-------------------+--------------------+
| 1|[1, 2, 3, 4, 7, 10]|[first, second, t...|
| 21| [1, 2, 3]|[first, second, t...|
+--------+-------------------+--------------------+
The values are already next to each other, but the right values need to be selected. This is probably still more efficient than a join or window function.
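For what it's worth, a hedged sketch of that selection step: with Spark 2.4+ the collected array can be filtered and reduced with the SQL higher-order functions, e.g. picking the largest value at or below a threshold (the threshold 5 and the column names value_list and maxAtMost5 are only illustrative):
// Sketch only, assuming Spark 2.4+ for filter()/array_max(); result is null when no value qualifies.
df.groupBy("grouping")
  .agg(collect_list('value).alias("value_list"))
  .withColumn("maxAtMost5", expr("array_max(filter(value_list, x -> x <= 5))"))
  .show()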
It would be easier to do this in two separate steps: calculate the max over the window, and then use when...otherwise on the result to produce min(x, 5):
df.withColumn("tmp", max('value1).over(Window.partitionBy('grouping)))
.withColumn("result", when('tmp > lit(5), 5).otherwise('tmp))
EDIT: some example data to clarify this:
val df = Seq((1, 1),(1, 2),(1, 3),(1, 4),(2, 7),(2, 8))
.toDF("grouping", "value1")
df.withColumn("result", max('value1).over(Window.partitionBy('grouping)))
.withColumn("result", when('result > lit(5), 5).otherwise('result))
.show()
// +--------+------+------+
// |grouping|value1|result|
// +--------+------+------+
// | 1| 1| 4| // 4, because Seq(Seq(1,2,3,4).max,5).min = 4
// | 1| 2| 4|
// | 1| 3| 4|
// | 1| 4| 4|
// | 2| 7| 5| // 5, because Seq(Seq(7,8).max,5).min = 5
// | 2| 8| 5|
// +--------+------+------+