I need to create a 'rolling count' column that takes the previous count and adds the new count for each day and company. I have already organized and sorted the dataframe into groups of ascending dates per company with the corresponding count. I also added an 'ix' column that indexes each grouping, like so:
+--------------------+--------------------+-----+---+
| Normalized_Date| company|count| ix|
+--------------------+--------------------+-----+---+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10|
+--------------------+--------------------+-----+---+
The output I need simply adds up all the counts up to and including that date for each company, like so:
+--------------------+--------------------+-----+---+------------+
| Normalized_Date| company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1| 7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1| 9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1| 7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2| 67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3| 68|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4| 77|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5| 106|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6| 148|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7| 465|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8| 468|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9| 483|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10| 484|
+--------------------+--------------------+-----+---+------------+
I figured the lag function would be useful. With the following code, each RollingCount row with ix > 1 adds the count directly above it:
w = Window.partitionBy('company').orderBy(F.unix_timestamp('Normalized_Date','MM/dd/yyyy HH:mm:ss aaa').cast('timestamp'))
refined_DF = solutionDF.withColumn("rn", F.row_number().over(w))
solutionDF = refined_DF.withColumn('RollingCount',F.when(refined_DF['rn'] > 1, refined_DF['count'] + F.lag(refined_DF['count'],count= 1 ).over(w)).otherwise(refined_DF['count']))
which yields the following df:
+--------------------+--------------------+-----+---+------------+
| Normalized_Date| company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1| 7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1| 9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1| 7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2| 67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3| 61|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4| 10|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5| 38|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6| 71|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7| 359|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8| 320|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9| 18|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10| 16|
+--------------------+--------------------+-----+---+------------+
I just need it to sum the counts of all the rows above it (all rows with a smaller ix). I tried a UDF to compute the 'count' argument of the lag function, but I keep getting a "'Column' object is not callable" error. Even if that worked, it would not sum all of the rows. I also tried a loop, but that seems impossible: it would create a new dataframe on each pass, and I would need to join them all afterwards. There must be a simpler way to do this. Perhaps a different function than lag?
lag returns a single row at a fixed offset before the current row, but a cumulative sum needs a range of rows. Therefore you have to define the window frame with rangeBetween (or rowsBetween). Have a look at the example below:
import pyspark.sql.functions as F
from pyspark.sql import Window
l = [
('09/25/2018', '5c40c8510fb7c017', 7, 1),
('09/25/2018', '5bdb2b543951bf07', 9, 1),
('11/28/2017', '593b0d9f3f21f9dd', 7, 1),
('11/29/2017', '593b0d9f3f21f9dd', 60, 2),
('01/09/2018', '593b0d9f3f21f9dd', 1, 3),
('04/27/2018', '593b0d9f3f21f9dd', 9, 4),
('09/25/2018', '593b0d9f3f21f9dd', 29, 5),
('11/20/2018', '593b0d9f3f21f9dd', 42, 6),
('12/11/2018', '593b0d9f3f21f9dd', 317, 7),
('01/04/2019', '593b0d9f3f21f9dd', 3, 8),
('02/13/2019', '593b0d9f3f21f9dd', 15, 9),
('04/01/2019', '593b0d9f3f21f9dd', 1, 10)
]
columns = ['Normalized_Date', 'company','count', 'ix']
df=spark.createDataFrame(l, columns)
df = df.withColumn('Normalized_Date', F.to_date(df.Normalized_Date, 'MM/dd/yyyy'))
w = Window.partitionBy('company').orderBy('Normalized_Date').rangeBetween(Window.unboundedPreceding, 0)
df = df.withColumn('Rolling_count', F.sum('count').over(w))
df.show()
Output:
+---------------+----------------+-----+---+-------------+
|Normalized_Date| company|count| ix|Rolling_count|
+---------------+----------------+-----+---+-------------+
| 2018-09-25|5c40c8510fb7c017| 7| 1| 7|
| 2018-09-25|5bdb2b543951bf07| 9| 1| 9|
| 2017-11-28|593b0d9f3f21f9dd| 7| 1| 7|
| 2017-11-29|593b0d9f3f21f9dd| 60| 2| 67|
| 2018-01-09|593b0d9f3f21f9dd| 1| 3| 68|
| 2018-04-27|593b0d9f3f21f9dd| 9| 4| 77|
| 2018-09-25|593b0d9f3f21f9dd| 29| 5| 106|
| 2018-11-20|593b0d9f3f21f9dd| 42| 6| 148|
| 2018-12-11|593b0d9f3f21f9dd| 317| 7| 465|
| 2019-01-04|593b0d9f3f21f9dd| 3| 8| 468|
| 2019-02-13|593b0d9f3f21f9dd| 15| 9| 483|
| 2019-04-01|593b0d9f3f21f9dd| 1| 10| 484|
+---------------+----------------+-----+---+-------------+
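The following plain-Python sketch (no Spark needed) shows why the lag version in the question drifts while the window-frame sum does not. It mimics both computations on the counts of company 593b0d9f3f21f9dd from the tables above. The helper names lag_sum and running_sum are illustrative only and are not Spark APIs.

```python
from itertools import accumulate

# Counts of company 593b0d9f3f21f9dd, already sorted by Normalized_Date (from the tables above).
counts = [7, 60, 1, 9, 29, 42, 317, 3, 15, 1]


def lag_sum(values):
    """Mimics count + lag(count, 1): each value plus only its immediate predecessor."""
    return [v if i == 0 else v + values[i - 1] for i, v in enumerate(values)]


def running_sum(values):
    """Mimics sum(count) over a frame from unboundedPreceding to the current row."""
    return list(accumulate(values))


print(lag_sum(counts))      # [7, 67, 61, 10, 38, 71, 359, 320, 18, 16]
print(running_sum(counts))  # [7, 67, 68, 77, 106, 148, 465, 468, 483, 484]
```

The first list reproduces the RollingCount column the question got from lag. The second matches the Rolling_count output above.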
Try this.
You need the sum of all preceding rows up to and including the current row in the window frame.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.expressions.WindowSpec
import org.apache.spark.sql.functions._
val df = Seq(
("5c40c8510fb7c017", 7, 1),
("5bdb2b543951bf07", 9, 1),
("593b0d9f3f21f9dd", 7, 1),
("593b0d9f3f21f9dd", 60, 2),
("593b0d9f3f21f9dd", 1, 3),
("593b0d9f3f21f9dd", 9, 4),
("593b0d9f3f21f9dd", 29, 5),
("593b0d9f3f21f9dd", 42, 6),
("593b0d9f3f21f9dd", 317, 7),
("593b0d9f3f21f9dd", 3, 8),
("593b0d9f3f21f9dd", 15, 9),
("593b0d9f3f21f9dd", 1, 10)
).toDF("company", "count", "ix")
scala> df.show(false)
+----------------+-----+---+
|company |count|ix |
+----------------+-----+---+
|5c40c8510fb7c017|7 |1 |
|5bdb2b543951bf07|9 |1 |
|593b0d9f3f21f9dd|7 |1 |
|593b0d9f3f21f9dd|60 |2 |
|593b0d9f3f21f9dd|1 |3 |
|593b0d9f3f21f9dd|9 |4 |
|593b0d9f3f21f9dd|29 |5 |
|593b0d9f3f21f9dd|42 |6 |
|593b0d9f3f21f9dd|317 |7 |
|593b0d9f3f21f9dd|3 |8 |
|593b0d9f3f21f9dd|15 |9 |
|593b0d9f3f21f9dd|1 |10 |
+----------------+-----+---+
scala> val overColumns = Window.partitionBy("company").orderBy("ix").rowsBetween(Window.unboundedPreceding, Window.currentRow)
overColumns: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec#3ed5e17c
scala> val outputDF = df.withColumn("RollingCount", sum("count").over(overColumns))
outputDF: org.apache.spark.sql.DataFrame = [company: string, count: int ... 2 more fields]
scala> outputDF.show(false)
+----------------+-----+---+------------+
|company |count|ix |RollingCount|
+----------------+-----+---+------------+
|5c40c8510fb7c017|7 |1 |7 |
|5bdb2b543951bf07|9 |1 |9 |
|593b0d9f3f21f9dd|7 |1 |7 |
|593b0d9f3f21f9dd|60 |2 |67 |
|593b0d9f3f21f9dd|1 |3 |68 |
|593b0d9f3f21f9dd|9 |4 |77 |
|593b0d9f3f21f9dd|29 |5 |106 |
|593b0d9f3f21f9dd|42 |6 |148 |
|593b0d9f3f21f9dd|317 |7 |465 |
|593b0d9f3f21f9dd|3 |8 |468 |
|593b0d9f3f21f9dd|15 |9 |483 |
|593b0d9f3f21f9dd|1 |10 |484 |
+----------------+-----+---+------------+
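The two answers differ in one way that matters. A rangeBetween frame includes every peer row whose ORDER BY value equals the current row's value, while a rowsBetween frame counts physical rows only. In this data the difference has no effect, because each company has at most one row per date and ix is unique. With duplicate dates, however, the two frames give different totals, and under rowsBetween the order among tied rows is not guaranteed. The plain-Python sketch below imitates both frames on a hypothetical partition that contains a tie. The data and helper names are made up for illustration.

```python
def rows_frame_sum(rows):
    """ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.

    rows is a list of (order_key, value) already sorted by order_key; each row
    sees only the rows physically before it plus itself.
    """
    total, out = 0, []
    for _, value in rows:
        total += value
        out.append(total)
    return out


def range_frame_sum(rows):
    """RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.

    Each row also sees its peers, i.e. every row whose order_key equals its own.
    """
    return [sum(v for k, v in rows if k <= key) for key, _ in rows]


# Hypothetical partition with two rows on the same date (a tie in the ORDER BY key).
# ISO date strings compare correctly as plain strings.
rows = [("2018-01-09", 1), ("2018-01-09", 9), ("2018-04-27", 29)]
print(rows_frame_sum(rows))   # [1, 10, 39]
print(range_frame_sum(rows))  # [10, 10, 39]
```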
I'm struggling to select rows of my dataframe. The selection depends on data inside the same dataframe.
My dataset looks something like this:
from pyspark.sql.session import SparkSession
sc = SparkSession.builder.getOrCreate()
columns = ['Id', 'ActorId', 'EventId', 'Time']
vals = [(3, 3, 'START', '2020-06-22'),
(4, 3, 'END', '2020-06-24'),
(5, 3, 'OTHER', '2019-01-15'),
(6, 3, 'OTHER', '2020-07-24'),
(7, 3, 'OTHER', '2020-06-23'),
(8, 4, 'START', '2018-01-15'),
(9, 4, 'END', '2019-01-14'),
(10, 4, 'OTHER', '2018-11-14')]
events = sc.createDataFrame(vals,columns)
events.show()
Which results in:
+---+-------+-------+----------+
| Id|ActorId|EventId| Time|
+---+-------+-------+----------+
| 3| 3| START|2020-06-22|
| 4| 3| END|2020-06-24|
| 5| 3| OTHER|2019-01-15|
| 6| 3| OTHER|2020-07-24|
| 7| 3| OTHER|2020-06-23|
| 8| 4| START|2018-01-15|
| 9| 4| END|2019-01-14|
| 10| 4| OTHER|2018-11-14|
+---+-------+-------+----------+
(Bear in mind that this is just an example, an extract of the data.)
I want to find all rows with EventId==OTHER, where time is not between the START and END Events of the same ActorId.
The result should look like:
+---+-------+-------+----------+
| Id|ActorId|EventId| Time|
+---+-------+-------+----------+
| 5| 3| OTHER|2019-01-15|
| 6| 3| OTHER|2020-07-24|
+---+-------+-------+----------+
Thank you for your help!!!
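This will solve your problem. The code below makes one assumption: within each ActorId group, START and END always appear as the 1st and 2nd rows. Ordering a window by its own partition key ('ActorId') leaves the row order undefined, so the window here is ordered by 'Id' instead. That puts START and END first in this sample data.
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import Window as W
# Order by Id so START and END are rows 1 and 2 of each ActorId partition (true for the sample data)
_w = W.partitionBy('ActorId').orderBy('Id')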
events = events.withColumn('start_date', F.first('Time').over(_w))
events = events.withColumn('row_num', F.row_number().over(_w))
events = events.withColumn('end_date', F.when(F.col('row_num') == F.lit('2'), F.col('Time')))
events = events.withColumn('end_date', F.coalesce(F.when(F.col('row_num') == F.lit('2'), F.col('Time')), F.min('end_date').over(_w)))
events = events.withColumn('passed_col', F.when(
(
((F.col('Time').cast(T.TimestampType()) > F.col('start_date').cast(T.TimestampType())) & (F.col('Time').cast(T.TimestampType()) > F.col('end_date').cast(T.TimestampType()))) |
(
(F.col('Time').cast(T.TimestampType()) < F.col('start_date').cast(T.TimestampType()))
& (F.col('Time').cast(T.TimestampType()) < F.col('end_date').cast(T.TimestampType())))),F.lit("Passed")))
events = events.select('Id', 'ActorId', 'EventId', 'Time', 'passed_col')
events.show()
+---+-------+-------+----------+----------+
| Id|ActorId|EventId| Time|passed_col|
+---+-------+-------+----------+----------+
| 3| 3| START|2020-06-22| null|
| 4| 3| END|2020-06-24| null|
| 5| 3| OTHER|2019-01-15| Passed|
| 6| 3| OTHER|2020-07-24| Passed|
| 7| 3| OTHER|2020-06-23| null|
| 8| 4| START|2018-01-15| null|
| 9| 4| END|2019-01-14| null|
| 10| 4| OTHER|2018-11-14| null|
+---+-------+-------+----------+----------+
Final answer, after filtering:
events = events.filter(F.col('passed_col') == F.lit('Passed')).select('Id', 'ActorId', 'EventId', 'Time', 'passed_col')
events.show()
+---+-------+-------+----------+----------+
| Id|ActorId|EventId| Time|passed_col|
+---+-------+-------+----------+----------+
| 5| 3| OTHER|2019-01-15| Passed|
| 6| 3| OTHER|2020-07-24| Passed|
+---+-------+-------+----------+----------+
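The answer above relies on START and END being the first two rows of each group. The plain-Python sketch below performs the same selection but looks up START and END by their EventId instead of by position. It assumes each ActorId has exactly one START and one END, and that Time values are ISO date strings, which compare correctly as strings. Actors missing either event are simply skipped. In PySpark, the same lookup could probably be written as a conditional aggregate such as F.max(F.when(F.col('EventId') == 'START', F.col('Time'))).over(Window.partitionBy('ActorId')). Treat that as an untested suggestion.

```python
from collections import defaultdict

# Rows copied from the example: (Id, ActorId, EventId, Time).
events = [
    (3, 3, "START", "2020-06-22"),
    (4, 3, "END", "2020-06-24"),
    (5, 3, "OTHER", "2019-01-15"),
    (6, 3, "OTHER", "2020-07-24"),
    (7, 3, "OTHER", "2020-06-23"),
    (8, 4, "START", "2018-01-15"),
    (9, 4, "END", "2019-01-14"),
    (10, 4, "OTHER", "2018-11-14"),
]


def others_outside_window(rows):
    # Per actor, record the START and END times by EventId, not by row position.
    bounds = defaultdict(dict)
    for _, actor, event, time in rows:
        if event in ("START", "END"):
            bounds[actor][event] = time
    result = []
    for row in rows:
        _, actor, event, time = row
        b = bounds.get(actor, {})
        # Keep OTHER rows whose time lies outside [START, END] of the same actor.
        if event == "OTHER" and "START" in b and "END" in b:
            if not (b["START"] <= time <= b["END"]):
                result.append(row)
    return result


print(others_outside_window(events))
# [(5, 3, 'OTHER', '2019-01-15'), (6, 3, 'OTHER', '2020-07-24')]
```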
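// Assumes `vals` is the events data loaded as a Spark DataFrame and spark.implicits._ is imported
// (for the 'EventId symbol syntax); the date ranges are hard-coded from the sample START/END rows.
val res = vals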
.filter('EventId.equalTo("OTHER"))
.filter('ActorId.equalTo(3))
.filter(!'Time.between("2020-06-01","2020-06-25"))
res.show(false)
// +---+-------+-------+----------+
// |Id |ActorId|EventId|Time |
// +---+-------+-------+----------+
// |5 |3 |OTHER |2019-01-15|
// |6 |3 |OTHER |2020-07-24|
// +---+-------+-------+----------+
or
val res = vals
.filter('EventId.equalTo("OTHER"))
.filter(!'Time.between("2018-01-01","2018-12-31"))
.filter(!'Time.between("2020-06-01","2020-06-25"))
I need to check a condition over a window:
- If the column IND_DEF is 20 in any row, I want to set the column premium to 1 for every row of the window (partition) that the row belongs to.
My initial Dataframe looks like this:
+--------+----+-------+-----+-------+
|policyId|name|premium|state|IND_DEF|
+--------+----+-------+-----+-------+
| 1| BK| null| KT| 40|
| 1| AK| -31| null| 30|
| 1| VZ| null| IL| 20|
| 2| VK| 32| LI| 7|
| 2| CK| 25| YNZ| 10|
| 2| CK| 0| null| 5|
| 2| VK| 30| IL| 25|
+--------+----+-------+-----+-------+
And I want to achieve this:
+--------+----+-------+-----+-------+
|policyId|name|premium|state|IND_DEF|
+--------+----+-------+-----+-------+
| 1| BK| 1| KT| 40|
| 1| AK| 1| null| 30|
| 1| VZ| 1| IL| 20|
| 2| VK| 32| LI| 7|
| 2| CK| 25| YNZ| 10|
| 2| CK| 0| null| 5|
| 2| VK| 30| IL| 25|
+--------+----+-------+-----+-------+
I am trying the following code, but it does not work:
val df_946 = Seq [(Int, String, Integer, String, Int)]((1,"VZ",null,"IL",20),(1, "AK", -31,null,30),(1,"BK", null,"KT",40),(2,"CK",0,null,5),(2,"CK",25,"YNZ",10),(2,"VK",30,"IL",25),(2,"VK",32,"LI",7)).toDF("policyId", "name", "premium", "state","IND_DEF").orderBy("policyId")
val winSpec = Window.partitionBy("policyId").orderBy("policyId")
val df_947 = df_946.withColumn("premium",when(col("IND_DEF") === 20,lit(1).over(winSpec)).otherwise(col("premium")))
You can generate an array of IND_DEF values via collect_list for each window partition and recreate column premium based on the array_contains condition:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._
val df = Seq(
(1, None, 40),
(1, Some(-31), 30),
(1, None, 20),
(2, Some(32), 7),
(2, Some(30), 10)
).toDF("policyId", "premium", "IND_DEF")
val win = Window.partitionBy($"policyId")
df.
withColumn("indList", collect_list($"IND_DEF").over(win)).
withColumn("premium", when(array_contains($"indList", 20), 1).otherwise($"premium")).
drop($"indList").
show
// +--------+-------+-------+
// |policyId|premium|IND_DEF|
// +--------+-------+-------+
// | 1| 1| 40|
// | 1| 1| 30|
// | 1| 1| 20|
// | 2| 32| 7|
// | 2| 30| 10|
// +--------+-------+-------+
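The plain-Python sketch below mirrors the two passes the answer performs. First it gathers all IND_DEF values per policyId (the collect_list step). Then it rewrites premium for every row whose partition contains 20 (the array_contains step). The function name and its trigger/new_premium parameters are hypothetical. In Spark, a conditional aggregate such as max(when($"IND_DEF" === 20, 1)).over(win) should produce the same per-partition flag without building an array. That variant does not come from the original answer.

```python
from collections import defaultdict

# (policyId, premium, IND_DEF) rows from the answer's example; None stands for null.
rows = [(1, None, 40), (1, -31, 30), (1, None, 20), (2, 32, 7), (2, 30, 10)]


def flag_partitions(rows, trigger=20, new_premium=1):
    # Pass 1: collect every IND_DEF per policyId (what collect_list over the window gathers).
    ind_defs = defaultdict(list)
    for policy_id, _, ind_def in rows:
        ind_defs[policy_id].append(ind_def)
    # Pass 2: overwrite premium on every row of a partition that contains the trigger value.
    return [
        (pid, new_premium if trigger in ind_defs[pid] else premium, ind)
        for pid, premium, ind in rows
    ]


print(flag_partitions(rows))
# [(1, 1, 40), (1, 1, 30), (1, 1, 20), (2, 32, 7), (2, 30, 10)]
```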
I have a data frame like the one below:
data = [
(1, None,7,10,11,19),
(1, 4,None,10,43,58),
(None, 4,7,67,88,91),
(1, None,7,78,96,32)
]
df = spark.createDataFrame(data, ["A_min", "B_min","C_min","A_max", "B_max","C_max"])
df.show()
I want the null values in each column whose name ends in '_min' to be replaced by the value of the equivalent '_max' column.
For example, null values in the A_min column should be replaced by the A_max value.
It should be like the data frame below.
+-----+-----+-----+-----+-----+-----+
|A_min|B_min|C_min|A_max|B_max|C_max|
+-----+-----+-----+-----+-----+-----+
| 1| 11| 7| 10| 11| 19|
| 1| 4| 58| 10| 43| 58|
| 67| 4| 7| 67| 88| 91|
| 1| 96| 7| 78| 96| 32|
+-----+-----+-----+-----+-----+-----+
I have tried the code below by defining the columns but clearly this does not work. Really appreciate any help.
min_cols = ["A_min", "B_min","C_min"]
max_cols = ["A_max", "B_max","C_max"]
for i in min_cols
df = df.withColumn(i,when(f.col(i)=='',max_cols.otherwise(col(i))))
display(df)
Assuming you have the same number of max and min columns, listed in matching order, you can use coalesce with a Python list comprehension to obtain your solution:
from pyspark.sql.functions import coalesce
min_cols = ["A_min", "B_min","C_min"]
max_cols = ["A_max", "B_max","C_max"]
df.select(*[coalesce(df[val], df[max_cols[pos]]).alias(val) for pos, val in enumerate(min_cols)], *max_cols).show()
Output:
+-----+-----+-----+-----+-----+-----+
|A_min|B_min|C_min|A_max|B_max|C_max|
+-----+-----+-----+-----+-----+-----+
| 1| 11| 7| 10| 11| 19|
| 1| 4| 58| 10| 43| 58|
| 67| 4| 7| 67| 88| 91|
| 1| 96| 7| 78| 96| 32|
+-----+-----+-----+-----+-----+-----+
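The answer pairs the columns by list position, so min_cols and max_cols must be listed in the same order. The plain-Python sketch below derives each partner from the name suffix instead (A_min to A_max) and applies the same coalesce rule. The helper names are illustrative. A PySpark variant along the same lines would be coalesce(df[c], df[c.replace('_min', '_max')]).alias(c) for each c in min_cols, assuming every _min column has a matching _max column.

```python
# Rows from the question, as dicts; None stands for null.
rows = [
    {"A_min": 1, "B_min": None, "C_min": 7, "A_max": 10, "B_max": 11, "C_max": 19},
    {"A_min": 1, "B_min": 4, "C_min": None, "A_max": 10, "B_max": 43, "C_max": 58},
    {"A_min": None, "B_min": 4, "C_min": 7, "A_max": 67, "B_max": 88, "C_max": 91},
    {"A_min": 1, "B_min": None, "C_min": 7, "A_max": 78, "B_max": 96, "C_max": 32},
]


def coalesce(*values):
    """Return the first non-None value, like Spark's coalesce."""
    return next((v for v in values if v is not None), None)


def fill_min_from_max(row):
    out = dict(row)  # copy, so the input row is left untouched
    for name in row:
        if name.endswith("_min"):
            partner = name[: -len("_min")] + "_max"  # e.g. A_min -> A_max
            out[name] = coalesce(row[name], row.get(partner))
    return out


filled = [fill_min_from_max(r) for r in rows]
print([[r["A_min"], r["B_min"], r["C_min"]] for r in filled])
# [[1, 11, 7], [1, 4, 58], [67, 4, 7], [1, 96, 7]]
```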
Is it possible to get the first value of the corresponding column within a subgroup?
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.{Window, WindowSpec}
object tmp {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
val input = Seq(
(1235, 1, 1101, 0),
(1235, 2, 1102, 0),
(1235, 3, 1103, 1),
(1235, 4, 1104, 1),
(1235, 5, 1105, 0),
(1235, 6, 1106, 0),
(1235, 7, 1107, 1),
(1235, 8, 1108, 1),
(1235, 9, 1109, 1),
(1235, 10, 1110, 0),
(1235, 11, 1111, 0)
).toDF("SERVICE_ID", "COUNTER", "EVENT_ID", "FLAG")
lazy val window: WindowSpec = Window.partitionBy("SERVICE_ID").orderBy("COUNTER")
val firsts = input.withColumn("first_value", first("EVENT_ID", ignoreNulls = true).over(window.rangeBetween(Long.MinValue, Long.MaxValue)))
firsts.orderBy("SERVICE_ID", "COUNTER").show()
}
}
Output I want: the first (or previous) value of column EVENT_ID where FLAG = 1, and the last (or next) value of column EVENT_ID where FLAG = 1, partitioned by SERVICE_ID and sorted by COUNTER:
+----------+-------+--------+----+-----------+-----------+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG|first_value|last_value|
+----------+-------+--------+----+-----------+-----------+
| 1235| 1| 1101| 0| 0| 1103|
| 1235| 2| 1102| 0| 0| 1103|
| 1235| 3| 1103| 1| 0| 1106|
| 1235| 4| 1104| 0| 1103| 1106|
| 1235| 5| 1105| 0| 1103| 1106|
| 1235| 6| 1106| 1| 0| 1108|
| 1235| 7| 1107| 0| 1106| 1108|
| 1235| 8| 1108| 1| 0| 1109|
| 1235| 9| 1109| 1| 0| 1110|
| 1235| 10| 1110| 1| 0| 0|
| 1235| 11| 1111| 0| 1110| 0|
| 1235| 12| 1112| 0| 1110| 0|
+----------+-------+--------+----+-----------+-----------+
First, the dataframe needs to be split into groups. A new group starts each time the "FLAG" column equals 1. To do this, first add a column "ID" to the dataframe:
lazy val window: WindowSpec = Window.partitionBy("SERVICE_ID").orderBy("COUNTER")
val df_flag = input.filter($"FLAG" === 1)
.withColumn("ID", row_number().over(window))
val df_other = input.filter($"FLAG" =!= 1)
.withColumn("ID", lit(0))
// Create a group for each flag event
val df = df_flag.union(df_other)
.withColumn("ID", max("ID").over(window.rowsBetween(Long.MinValue, 0)))
.cache()
df.show() gives:
+----------+-------+--------+----+---+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG| ID|
+----------+-------+--------+----+---+
| 1235| 1| 1111| 1| 1|
| 1235| 2| 1112| 0| 1|
| 1235| 3| 1114| 0| 1|
| 1235| 4| 2221| 1| 2|
| 1235| 5| 2225| 0| 2|
| 1235| 6| 2226| 0| 2|
| 1235| 7| 2227| 1| 3|
+----------+-------+--------+----+---+
Now that we have a column separating the events, we need to add the correct "EVENT_ID" (renamed "first_value") to each event. In addition to "first_value", calculate and add a second column "last_value", which is the EVENT_ID of the next flagged event.
val df_event = df.filter($"FLAG" === 1)
.select("EVENT_ID", "ID", "SERVICE_ID", "COUNTER")
.withColumnRenamed("EVENT_ID", "first_value")
.withColumn("last_value", lead($"first_value",1,0).over(window))
.drop("COUNTER")
val df_final = df.join(df_event, Seq("ID", "SERVICE_ID"))
.drop("ID")
.withColumn("first_value", when($"FLAG" === 1, lit(0)).otherwise($"first_value"))
df_final.show() gives us:
+----------+-------+--------+----+-----------+----------+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG|first_value|last_value|
+----------+-------+--------+----+-----------+----------+
| 1235| 1| 1111| 1| 0| 2221|
| 1235| 2| 1112| 0| 1111| 2221|
| 1235| 3| 1114| 0| 1111| 2221|
| 1235| 4| 2221| 1| 0| 2227|
| 1235| 5| 2225| 0| 2221| 2227|
| 1235| 6| 2226| 0| 2221| 2227|
| 1235| 7| 2227| 1| 0| 0|
+----------+-------+--------+----+-----------+----------+
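The grouping trick above is a variant of the classic gaps-and-islands pattern: a row number on flagged rows, 0 elsewhere, then a running max. The plain-Python sketch below applies the same logic to the answer's sample rows. It takes first_value from the group's opening flagged row and last_value from the next group's opener, which is what lead(..., 1, 0) produces. The helper names are hypothetical. One difference is deliberate: rows before the first flagged row fall into group 0 and are kept here with first_value 0, whereas the answer's inner join on ID drops them.

```python
# (COUNTER, EVENT_ID, FLAG) rows, matching the answer's df.show() output above.
rows = [(1, 1111, 1), (2, 1112, 0), (3, 1114, 0), (4, 2221, 1),
        (5, 2225, 0), (6, 2226, 0), (7, 2227, 1)]


def assign_groups(rows):
    """Give every row the number of flagged rows seen so far (0 before the first flag)."""
    group, out = 0, []
    for counter, event_id, flag in sorted(rows):
        if flag == 1:
            group += 1  # same effect as the running max over row_number() of flagged rows
        out.append((counter, event_id, flag, group))
    return out


def first_and_last(rows):
    grouped = assign_groups(rows)
    # EVENT_ID of the flagged row that opens each group, in group order.
    openers = [event_id for _, event_id, flag, _ in grouped if flag == 1]
    result = []
    for counter, event_id, flag, group in grouped:
        # Flagged rows get 0 (as in the answer's when(...) step); rows before the first flag also get 0.
        first_value = 0 if flag == 1 or group == 0 else openers[group - 1]
        # Next group's opener, or 0 after the last group, like lead(first_value, 1, 0).
        last_value = openers[group] if group < len(openers) else 0
        result.append((counter, event_id, flag, first_value, last_value))
    return result


for row in first_and_last(rows):
    print(row)
```

Each printed tuple is (COUNTER, EVENT_ID, FLAG, first_value, last_value) and matches the corresponding df_final row above.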
This can be solved in two steps:
1. get the events with "FLAG" == 1 and the valid range for each event;
2. join the result of step 1 with input, by range.
Some column renaming is included for readability; it can be shortened:
val window = Window.partitionBy("SERVICE_ID").orderBy("COUNTER").rowsBetween(Window.currentRow, 1)
val eventRangeDF = input.where($"FLAG" === 1)
.withColumn("RANGE_END", max($"COUNTER").over(window))
.withColumnRenamed("COUNTER", "RANGE_START")
.select("SERVICE_ID", "EVENT_ID", "RANGE_START", "RANGE_END")
eventRangeDF.show(false)
val result = input.where($"FLAG" === 0).as("i").join(eventRangeDF.as("e"),
expr("e.SERVICE_ID=i.SERVICE_ID And i.COUNTER>e.RANGE_START and i.COUNTER<e.RANGE_END"))
.select($"i.SERVICE_ID", $"i.COUNTER", $"i.EVENT_ID", $"i.FLAG", $"e.EVENT_ID".alias("first_value"))
// include FLAG=1
.union(input.where($"FLAG" === 1).select($"SERVICE_ID", $"COUNTER", $"EVENT_ID", $"FLAG", lit(0).alias("first_value")))
result.sort("COUNTER").show(false)
Output:
+----------+--------+-----------+---------+
|SERVICE_ID|EVENT_ID|RANGE_START|RANGE_END|
+----------+--------+-----------+---------+
|1235 |1111 |1 |4 |
|1235 |2221 |4 |7 |
|1235 |2227 |7 |7 |
+----------+--------+-----------+---------+
+----------+-------+--------+----+-----------+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG|first_value|
+----------+-------+--------+----+-----------+
|1235 |1 |1111 |1 |0 |
|1235 |2 |1112 |0 |1111 |
|1235 |3 |1114 |0 |1111 |
|1235 |4 |2221 |1 |0 |
|1235 |5 |2225 |0 |2221 |
|1235 |6 |2226 |0 |2221 |
|1235 |7 |2227 |1 |0 |
+----------+-------+--------+----+-----------+
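The range-join approach can be mimicked in plain Python. The sketch below reproduces the eventRangeDF ranges and the joined first_value column shown above. Names are illustrative. Note the strict bounds in the join condition: the last flagged event gets RANGE_END equal to its own COUNTER, so any FLAG = 0 rows after it match no range and are dropped. The same happens to FLAG = 0 rows before the first flagged row. You may need to handle this if your real data contains such rows.

```python
# (COUNTER, EVENT_ID, FLAG) rows, matching the answer's sample data above.
rows = [(1, 1111, 1), (2, 1112, 0), (3, 1114, 0), (4, 2221, 1),
        (5, 2225, 0), (6, 2226, 0), (7, 2227, 1)]


def event_ranges(rows):
    """Step 1: one (EVENT_ID, RANGE_START, RANGE_END) tuple per flagged row."""
    flagged = sorted((counter, event_id) for counter, event_id, flag in rows if flag == 1)
    ranges = []
    for i, (counter, event_id) in enumerate(flagged):
        # max(COUNTER) over the current flagged row and the next one (rowsBetween(currentRow, 1))
        end = max(c for c, _ in flagged[i:i + 2])
        ranges.append((event_id, counter, end))
    return ranges


def attach_first_value(rows, ranges):
    """Step 2: range join for FLAG == 0 rows, plus flagged rows with first_value 0."""
    out = []
    for counter, event_id, flag in sorted(rows):
        if flag == 1:
            out.append((counter, event_id, flag, 0))
            continue
        for opener, start, end in ranges:
            if start < counter < end:  # strict bounds, as in the answer's join condition
                out.append((counter, event_id, flag, opener))
    return out


ranges = event_ranges(rows)
print(ranges)  # [(1111, 1, 4), (2221, 4, 7), (2227, 7, 7)]
for row in attach_first_value(rows, ranges):
    print(row)
```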
I have a tall table that contains up to 10 values per group. How can I transform this table into a wide format, i.e. add 2 columns that hold the value smaller than or equal to a threshold?
I want to find the maximum per group, but it needs to be smaller than a specified value like:
min(max('value1), lit(5)).over(Window.partitionBy('grouping))
However, min() only works on a column, not on the plain Scala value that the inner function would return?
The problem can be described as:
Seq(Seq(1,2,3,4).max,5).min
Where Seq(1,2,3,4) is returned by the window.
How can I formulate this in spark sql?
edit
E.g.
+--------+-----+---------+
|grouping|value|something|
+--------+-----+---------+
| 1| 1| first|
| 1| 2| second|
| 1| 3| third|
| 1| 4| fourth|
| 1| 7| 7|
| 1| 10| 10|
| 21| 1| first|
| 21| 2| second|
| 21| 3| third|
+--------+-----+---------+
created by
case class MyThing(grouping: Int, value:Int, something:String)
val df = Seq(MyThing(1,1, "first"), MyThing(1,2, "second"), MyThing(1,3, "third"),MyThing(1,4, "fourth"),MyThing(1,7, "7"), MyThing(1,10, "10"),
MyThing(21,1, "first"), MyThing(21,2, "second"), MyThing(21,3, "third")).toDS
Where
df
.withColumn("somethingAtLeast5AndMaximum5", max('value).over(Window.partitionBy('grouping)))
.withColumn("somethingAtLeast6OupToThereshold2", max('value).over(Window.partitionBy('grouping)))
.show
returns
+--------+-----+---------+----------------------------+-------------------------+
|grouping|value|something|somethingAtLeast5AndMaximum5| somethingAtLeast6OupToThereshold2 |
+--------+-----+---------+----------------------------+-------------------------+
| 1| 1| first| 10| 10|
| 1| 2| second| 10| 10|
| 1| 3| third| 10| 10|
| 1| 4| fourth| 10| 10|
| 1| 7| 7| 10| 10|
| 1| 10| 10| 10| 10|
| 21| 1| first| 3| 3|
| 21| 2| second| 3| 3|
| 21| 3| third| 3| 3|
+--------+-----+---------+----------------------------+-------------------------+
Instead, I would rather formulate:
lit(Seq(max('value).asInstanceOf[java.lang.Integer], new java.lang.Integer(2)).min).over(Window.partitionBy('grouping))
But that does not work as max('value) is not a scalar value.
The expected output should look like:
+--------+-----+---------+----------------------------+-------------------------+
|grouping|value|something|somethingAtLeast5AndMaximum5|somethingAtLeast6OupToThereshold2|
+--------+-----+---------+----------------------------+-------------------------+
| 1| 4| fourth| 4| 7|
| 21| 1| first| 3| NULL|
+--------+-----+---------+----------------------------+-------------------------+
edit2
When trying a pivot
df.groupBy("grouping").pivot("value").agg(first('something)).show
+--------+-----+------+-----+------+----+----+
|grouping| 1| 2| 3| 4| 7| 10|
+--------+-----+------+-----+------+----+----+
| 1|first|second|third|fourth| 7| 10|
| 21|first|second|third| null|null|null|
+--------+-----+------+-----+------+----+----+
The second part of the problem remains that some columns might not exist or be null.
When aggregating to arrays:
df.groupBy("grouping").agg(collect_list('value).alias("value"), collect_list('something).alias("something"))
+--------+-------------------+--------------------+
|grouping| value| something|
+--------+-------------------+--------------------+
| 1|[1, 2, 3, 4, 7, 10]|[first, second, t...|
| 21| [1, 2, 3]|[first, second, t...|
+--------+-------------------+--------------------+
The values are already next to each other, but the right values need to be selected. This is probably still more efficient than a join or window function.
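If you go the array route, one option is to keep value and something together so the selection cannot misalign them. The plain-Python sketch below picks, per group, the largest value at or below a threshold together with its something. The function name and threshold parameter are hypothetical. In Spark this would correspond to collecting pairs, for example collect_list(struct('value, 'something)), rather than two separate lists. That translation is an assumption and has not been tested here.

```python
from collections import defaultdict

# (grouping, value, something) rows from the MyThing example above.
data = [(1, 1, "first"), (1, 2, "second"), (1, 3, "third"), (1, 4, "fourth"),
        (1, 7, "7"), (1, 10, "10"),
        (21, 1, "first"), (21, 2, "second"), (21, 3, "third")]


def largest_at_most(data, threshold):
    """Per grouping, the (value, something) pair with the largest value <= threshold, else None."""
    groups = defaultdict(list)
    for grouping, value, something in data:
        groups[grouping].append((value, something))  # pairs stay aligned, unlike two separate lists
    result = {}
    for grouping, pairs in groups.items():
        candidates = [pair for pair in pairs if pair[0] <= threshold]
        result[grouping] = max(candidates) if candidates else None
    return result


print(largest_at_most(data, 5))  # {1: (4, 'fourth'), 21: (3, 'third')}
```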
It would be easier to do this in two separate steps: calculate the max over the window, and then use when...otherwise on the result to produce min(x, 5):
df.withColumn("tmp", max('value1).over(Window.partitionBy('grouping)))
.withColumn("result", when('tmp > lit(5), 5).otherwise('tmp))
EDIT: some example data to clarify this:
val df = Seq((1, 1),(1, 2),(1, 3),(1, 4),(2, 7),(2, 8))
.toDF("grouping", "value1")
df.withColumn("result", max('value1).over(Window.partitionBy('grouping)))
.withColumn("result", when('result > lit(5), 5).otherwise('result))
.show()
// +--------+------+------+
// |grouping|value1|result|
// +--------+------+------+
// | 1| 1| 4| // 4, because Seq(Seq(1,2,3,4).max,5).min = 4
// | 1| 2| 4|
// | 1| 3| 4|
// | 1| 4| 4|
// | 2| 7| 5| // 5, because Seq(Seq(7,8).max,5).min = 5
// | 2| 8| 5|
// +--------+------+------+
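The plain-Python sketch below applies the same two-step logic as a quick check of the numbers in the comments above: compute the group max first, then cap it at 5. The function name is illustrative. Spark also provides a least function in org.apache.spark.sql.functions, so least(max('value1).over(Window.partitionBy('grouping)), lit(5)) should collapse both steps into one expression. That form is not part of the original answer.

```python
# (grouping, value1) rows from the example above.
rows = [(1, 1), (1, 2), (1, 3), (1, 4), (2, 7), (2, 8)]


def capped_group_max(rows, cap=5):
    # Step 1: max of value1 per grouping (the window max).
    group_max = {}
    for grouping, value in rows:
        group_max[grouping] = max(group_max.get(grouping, value), value)
    # Step 2: when(result > cap, cap).otherwise(result), i.e. min(result, cap).
    return [(g, v, min(group_max[g], cap)) for g, v in rows]


print(capped_group_max(rows))
# [(1, 1, 4), (1, 2, 4), (1, 3, 4), (1, 4, 4), (2, 7, 5), (2, 8, 5)]
```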