PySpark: count common occurrences

After doing a market basket analysis and extracting the rules, I also want to count the common occurrences of items - as tuples - in order to visualize them in Tableau. Below you find the items for each ID / the members of each basket.
df = sqlContext.createDataFrame([("ID_1", "Butter"),
("ID_1", "Toast"),
("ID_1","Ham"),
("ID_2", "Ham"),
("ID_2", "Toast"),
("ID_2","Egg"),],
["ID","VAL"])
df.show()
+----+------+
| ID| VAL|
+----+------+
|ID_1|Butter|
|ID_1| Toast|
|ID_1| Ham|
|ID_2| Ham|
|ID_2| Toast|
|ID_2| Egg|
+----+------+
This is the result I want to achieve:
res = sqlContext.createDataFrame([("Butter", "Butter", 0),
("Butter", "Toast", 1),
("Butter", "Ham", 1),
("Butter", "Egg", 0),
("Toast", "Toast", 0),
("Toast", "Ham", 2),
("Toast", "Egg", 1),
("Ham", "Ham", 0),
("Ham", "Egg", 0),
("Egg", "Egg", 0),],
["VAL_1","VAL_2", "COUNT"])
res.show()
+------+------+-----+
| VAL_1| VAL_2|COUNT|
+------+------+-----+
|Butter|Butter| 0|
|Butter| Toast| 1|
|Butter| Ham| 1|
|Butter| Egg| 0|
| Toast| Toast| 0|
| Toast| Ham| 2|
| Toast| Egg| 1|
| Ham| Ham| 0|
| Ham| Egg| 0|
| Egg| Egg| 0|
+------+------+-----+

Try the below; you might also want to use withColumnRenamed to rename the calculated count column:
df.groupBy(['ID','VAL']).count().show()
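The groupBy above counts rows per (ID, VAL) pair; to get the pairwise VAL_1/VAL_2 co-occurrence counts shown in the expected output, one option is a self-join on ID (a minimal sketch, not part of the original answer; it returns only pairs that actually co-occur, each pair listed once with the values in alphabetical order):
from pyspark.sql import functions as F

pairs = (df.alias("a")
         .join(df.alias("b"), "ID")                  # pair items within the same basket
         .where(F.col("a.VAL") < F.col("b.VAL"))     # keep each unordered pair once
         .groupBy(F.col("a.VAL").alias("VAL_1"),
                  F.col("b.VAL").alias("VAL_2"))
         .agg(F.count("*").alias("COUNT")))
pairs.show()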

Related

Repeated values in pyspark

I have a dataframe in PySpark with three columns:
df1 = spark.createDataFrame([
('a', 3, 4.2),
('a', 7, 4.2),
('b', 7, 2.6),
('c', 7, 7.21),
('c', 11, 7.21),
('c', 18, 7.21),
('d', 15, 9.0),
], ['model', 'number', 'price'])
df1.show()
+-----+------+-----+
|model|number|price|
+-----+------+-----+
| a| 3| 4.2|
| a| 7| 4.2|
| b| 7| 2.6|
| c| 7| 7.21|
| c| 11| 7.21|
| c| 18| 7.21|
| d| 15| 9.0|
+-----+------+-----+
Is there a way in PySpark to display only the values that are repeated in the column 'price'?
Like in df2:
df2 = spark.createDataFrame([
('a', 3, 4.2),
('a', 7, 4.2),
('c', 7, 7.21),
('c', 11, 7.21),
('c', 18, 7.21),
], ['model', 'number', 'price'])
df2.show()
+-----+------+-----+
|model|number|price|
+-----+------+-----+
| a| 3| 4.2|
| a| 7| 4.2|
| c| 7| 7.21|
| c| 11| 7.21|
| c| 18| 7.21|
+-----+------+-----+
I tried to do this, but it didn't work:
df = df1.groupBy("model","price").count().filter("count > 1")
df2 = df1.where((df.model == df1.model) & (df.price == df1.price))
df2.show()
It included the values that are not repeated too:
+-----+------+-----+
|model|number|price|
+-----+------+-----+
| a| 3| 4.2|
| a| 7| 4.2|
| b| 7| 2.6|
| c| 7| 7.21|
| c| 11| 7.21|
| c| 18| 7.21|
| d| 15| 9.0|
+-----+------+-----+
You can do so with a window function. We partition by price, take a count and filter count > 1.
from pyspark.sql import Window
from pyspark.sql import functions as f
w = Window().partitionBy('price')
df1.withColumn('_c', f.count('price').over(w)).filter('_c > 1').drop('_c').show()
+-----+------+-----+
|model|number|price|
+-----+------+-----+
| a| 3| 4.2|
| a| 7| 4.2|
| c| 7| 7.21|
| c| 11| 7.21|
| c| 18| 7.21|
+-----+------+-----+
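For comparison, the join idea from the question also works if the aggregated counts are actually joined back, for example with a left semi join (a sketch, not part of the original answer; assumes df1 as defined above):
# prices that occur more than once
dup_prices = df1.groupBy("price").count().filter("count > 1").select("price")
# keep only the rows of df1 whose price is in that set
# (exact equality on the double is fine here, the values come from the same column)
df1.join(dup_prices, on="price", how="left_semi").show()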

Pyspark: Add column with average of groupby

I have a dataframe like this:
test = spark.createDataFrame(
[
(1, 0, 100),
(2, 0, 200),
(3, 1, 150),
(4, 1, 250),
],
['id', 'flag', 'col1']
)
I would like to create another column and fill it with the average of col1 grouped by flag:
test.groupBy(f.col('flag')).agg(f.avg(f.col("col1"))).show()
+----+---------+
|flag|avg(col1)|
+----+---------+
| 0| 150.0|
| 1| 200.0|
+----+---------+
End product:
+---+----+----+---+
| id|flag|col1|avg|
+---+----+----+---+
| 1| 0| 100|150|
| 2| 0| 200|150|
| 3| 1| 150|200|
| 4| 1| 250|200|
+---+----+----+---+
You can use the window function:
from pyspark.sql.window import Window
from pyspark.sql import functions as F
w = Window.partitionBy('flag')
test.withColumn("avg", F.avg("col1").over(w)).show()
+---+----+----+-----+
| id|flag|col1| avg|
+---+----+----+-----+
| 1| 0| 100|150.0|
| 2| 0| 200|150.0|
| 3| 1| 150|200.0|
| 4| 1| 250|200.0|
+---+----+----+-----+
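For comparison, the same result can be had without a window by aggregating and joining back (a sketch, assuming `test` as created above; the window approach avoids the extra join):
from pyspark.sql import functions as F

# per-flag averages, then joined back onto every row
avgs = test.groupBy("flag").agg(F.avg("col1").alias("avg"))
test.join(avgs, on="flag", how="left").select("id", "flag", "col1", "avg").show()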

PySpark: Simulate SQL's UPDATE

I have two Spark DataFrames:
trg
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.2|unchanged|
| 3| 0.3|unchanged|
+---+-----+---------+
src
+---+-----+-------+-----+
|key|value| flag|merge|
+---+-----+-------+-----+
| 1| 0.11|changed| 0|
| 2| 0.22|changed| 1|
| 3| 0.33|changed| 0|
+---+-----+-------+-----+
I need to "update" trg.value and trg.flag based on src.merge as described by the following SQL logic:
UPDATE trg
INNER JOIN src ON trg.key = src.key
SET trg.value = src.value,
trg.flag = src.flag
WHERE src.merge = 1;
Expected new trg:
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1 |unchanged|
| 2| 0.22| changed|
| 3| 0.3 |unchanged|
+---+-----+---------+
I have tried using when(). It works for the flag field (since it can have only two values), but not for the value field, because I don't know how to pick the value from the corresponding row:
from pyspark.sql.functions import when
trg = spark.createDataFrame(data=[('1', '0.1', 'unchanged'),
('2', '0.2', 'unchanged'),
('3', '0.3', 'unchanged')],
schema=['key', 'value', 'flag'])
src = spark.createDataFrame(data=[('1', '0.11', 'changed', '0'),
('2', '0.22', 'changed', '1'),
('3', '0.33', 'changed', '0')],
schema=['key', 'value', 'flag', 'merge'])
new_trg = (trg.alias('trg').join(src.alias('src'), on=['key'], how='inner')
.select(
'trg.*',
when(src.merge == 1, 'changed').otherwise('unchanged').alias('flag'),
when(src.merge == 1, ???).otherwise(???).alias('value')))
Is there any other, preferably idiomatic, way to translate that SQL logic to PySpark?
newdf = (trg.join(src, on=['key'], how='inner')
.select(trg.key,
when( src.merge==1, src.value)
.otherwise(trg.value).alias('value'),
when( src.merge==1, src.flag)
.otherwise(trg.flag).alias('flag')))
newdf.show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.22| changed|
| 3| 0.3|unchanged|
+---+-----+---------+
Imports and Create datasets
import pyspark.sql.functions as f
l1 = [(1, 0.1, 'unchanged'), (2, 0.2, 'unchanged'), (3, 0.3, 'unchanged')]
dfl1 = spark.createDataFrame(l1).toDF('key', 'value', 'flag')
dfl1.show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.2|unchanged|
| 3| 0.3|unchanged|
+---+-----+---------+
l2 = [(1, 0.11, 'changed', 0), (2, 0.22, 'changed', 1), (3, 0.33, 'changed', 0)]
dfl2 = spark.createDataFrame(l2).toDF('key', 'value', 'flag', 'merge')
dfl2.show()
+---+-----+-------+-----+
|key|value| flag|merge|
+---+-----+-------+-----+
| 1| 0.11|changed| 0|
| 2| 0.22|changed| 1|
| 3| 0.33|changed| 0|
+---+-----+-------+-----+
# filtering upfront for better performance in next join
# dfl2 = dfl2.where(dfl2['merge'] == 1)
Join datasets
join_cond = [dfl1['key'] == dfl2['key'], dfl2['merge'] == 1]
dfl12 = dfl1.join(dfl2, join_cond, 'left_outer')
dfl12.show()
+---+-----+---------+----+-----+-------+-----+
|key|value| flag| key|value| flag|merge|
+---+-----+---------+----+-----+-------+-----+
| 1| 0.1|unchanged|null| null| null| null|
| 3| 0.3|unchanged|null| null| null| null|
| 2| 0.2|unchanged| 2| 0.22|changed| 1|
+---+-----+---------+----+-----+-------+-----+
Use the when function: if the joined value is null, keep the original value; otherwise use the new value.
df = dfl12.withColumn('new_value', f.when(dfl2['value'].isNotNull(), dfl2['value']).otherwise(dfl1['value'])).\
withColumn('new_flag', f.when(dfl2['flag'].isNotNull(), dfl2['flag']).otherwise(dfl1['flag']))
df.show()
+---+-----+---------+----+-----+-------+-----+---------+---------+
|key|value| flag| key|value| flag|merge|new_value| new_flag|
+---+-----+---------+----+-----+-------+-----+---------+---------+
| 1| 0.1|unchanged|null| null| null| null| 0.1|unchanged|
| 3| 0.3|unchanged|null| null| null| null| 0.3|unchanged|
| 2| 0.2|unchanged| 2| 0.22|changed| 1| 0.22| changed|
+---+-----+---------+----+-----+-------+-----+---------+---------+
df.select(dfl1['key'], df['new_value'], df['new_flag']).show()
+---+---------+---------+
|key|new_value| new_flag|
+---+---------+---------+
| 1| 0.1|unchanged|
| 3| 0.3|unchanged|
| 2| 0.22| changed|
+---+---------+---------+
This is broken out step by step for understanding; you can combine a couple of the steps into one.
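For example, the filter, the join and the two when steps can be collapsed into one expression (a sketch, assuming dfl1 and dfl2 as created above), restricting dfl2 to merge == 1 up front and using coalesce in place of the isNotNull checks:
import pyspark.sql.functions as f

result = (dfl1.join(dfl2.where(dfl2['merge'] == 1)
                        .selectExpr('key', 'value as new_value', 'flag as new_flag'),
                    on='key', how='left')
              .select('key',
                      f.coalesce(f.col('new_value'), f.col('value')).alias('value'),
                      f.coalesce(f.col('new_flag'), f.col('flag')).alias('flag')))
result.show()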
import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql.functions import when
spark = SparkSession.builder.appName("test").getOrCreate()
data1 = [(1, 0.1, 'unchanged'), (2, 0.2,'unchanged'), (3, 0.3, 'unchanged')]
schema = ['key', 'value', 'flag']
df1 = spark.createDataFrame(data1, schema=schema)
df1.show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.2|unchanged|
| 3| 0.3|unchanged|
+---+-----+---------+
data2 = [(1, 0.11, 'changed',0), (2, 0.22,'changed',1), (3, 0.33, 'changed',0)]
schema2 = ['key', 'value', 'flag', 'merge']
df2 = spark.createDataFrame(data2, schema=schema2)
df2.show()
+---+-----+-------+-----+
|key|value| flag|merge|
+---+-----+-------+-----+
| 1| 0.11|changed| 0|
| 2| 0.22|changed| 1|
| 3| 0.33|changed| 0|
+---+-----+-------+-----+
df2 = df2.withColumnRenamed("value", "value1").withColumnRenamed("flag", 'flag1')
mer = df1.join(df2, ['key'], 'inner')
mer = mer.withColumn("temp", when(mer.merge == 1, mer.value1).otherwise(mer.value))
mer = mer.withColumn("temp1", when(mer.merge == 1, 'changed').otherwise('unchanged'))
output = mer.select(mer.key, mer.temp.alias('value'), mer.temp1.alias('flag'))
output.orderBy(output.value.asc()).show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.22| changed|
| 3| 0.3|unchanged|
+---+-----+---------+

Adding column with sum of all rows above in same grouping

I need to create a 'rolling count' column which takes the previous count and adds the new count for each day and company. I have already organized and sorted the dataframe into groups of ascending dates per company with the corresponding count. I also added an 'ix' column which indexes each grouping, like so:
+--------------------+--------------------+-----+---+
| Normalized_Date| company|count| ix|
+--------------------+--------------------+-----+---+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10|
+--------------------+--------------------+-----+---+
The output I need would simply add up all the counts up to that date for each company. Like so:
+--------------------+--------------------+-----+---+------------+
| Normalized_Date| company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1| 7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1| 9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1| 7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2| 67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3| 68|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4| 77|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5| 106|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6| 148|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7| 465|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8| 468|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9| 483|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10| 484|
+--------------------+--------------------+-----+---+------------+
I figured the lag function would be of use, and I was able to get each row of rollingcount with ix > 1 to add the count directly above it with the following code:
w = Window.partitionBy('company').orderBy(F.unix_timestamp('Normalized_Date', 'MM/dd/yyyy HH:mm:ss aaa').cast('timestamp'))
refined_DF = solutionDF.withColumn("rn", F.row_number().over(w))
solutionDF = refined_DF.withColumn('RollingCount', F.when(refined_DF['rn'] > 1, refined_DF['count'] + F.lag(refined_DF['count'], 1).over(w)).otherwise(refined_DF['count']))
which yields the following df:
+--------------------+--------------------+-----+---+------------+
| Normalized_Date| company|count| ix|RollingCount|
+--------------------+--------------------+-----+---+------------+
|09/25/2018 00:00:...|[5c40c8510fb7c017...| 7| 1| 7|
|09/25/2018 00:00:...|[5bdb2b543951bf07...| 9| 1| 9|
|11/28/2017 00:00:...|[593b0d9f3f21f9dd...| 7| 1| 7|
|11/29/2017 00:00:...|[593b0d9f3f21f9dd...| 60| 2| 67|
|01/09/2018 00:00:...|[593b0d9f3f21f9dd...| 1| 3| 61|
|04/27/2018 00:00:...|[593b0d9f3f21f9dd...| 9| 4| 10|
|09/25/2018 00:00:...|[593b0d9f3f21f9dd...| 29| 5| 38|
|11/20/2018 00:00:...|[593b0d9f3f21f9dd...| 42| 6| 71|
|12/11/2018 00:00:...|[593b0d9f3f21f9dd...| 317| 7| 359|
|01/04/2019 00:00:...|[593b0d9f3f21f9dd...| 3| 8| 320|
|02/13/2019 00:00:...|[593b0d9f3f21f9dd...| 15| 9| 18|
|04/01/2019 00:00:...|[593b0d9f3f21f9dd...| 1| 10| 16|
+--------------------+--------------------+-----+---+------------+
I just need it to sum all of the counts ix rows above it. I have tried using a udf to figure out the 'count' input into the lag function, but I keep getting a "'Column' object is not callable" error, plus it doesn't do the sum of all of the rows. I have also tried using a loop but that seems impossible because it will make a new dataframe each time through, plus I would need to join them all afterwards. There must be an easier and simpler way to do this. Perhaps a different function than lag?
lag returns a single row at a fixed offset before the current row, but you need a range of rows to calculate the cumulative sum. Therefore you have to use a window frame with rangeBetween (or rowsBetween). Have a look at the example below:
import pyspark.sql.functions as F
from pyspark.sql import Window
l = [
('09/25/2018', '5c40c8510fb7c017', 7, 1),
('09/25/2018', '5bdb2b543951bf07', 9, 1),
('11/28/2017', '593b0d9f3f21f9dd', 7, 1),
('11/29/2017', '593b0d9f3f21f9dd', 60, 2),
('01/09/2018', '593b0d9f3f21f9dd', 1, 3),
('04/27/2018', '593b0d9f3f21f9dd', 9, 4),
('09/25/2018', '593b0d9f3f21f9dd', 29, 5),
('11/20/2018', '593b0d9f3f21f9dd', 42, 6),
('12/11/2018', '593b0d9f3f21f9dd', 317, 7),
('01/04/2019', '593b0d9f3f21f9dd', 3, 8),
('02/13/2019', '593b0d9f3f21f9dd', 15, 9),
('04/01/2019', '593b0d9f3f21f9dd', 1, 10)
]
columns = ['Normalized_Date', 'company','count', 'ix']
df=spark.createDataFrame(l, columns)
df = df.withColumn('Normalized_Date', F.to_date(df.Normalized_Date, 'MM/dd/yyyy'))
w = Window.partitionBy('company').orderBy('Normalized_Date').rangeBetween(Window.unboundedPreceding, 0)
df = df.withColumn('Rolling_count', F.sum('count').over(w))
df.show()
Output:
+---------------+----------------+-----+---+-------------+
|Normalized_Date| company|count| ix|Rolling_count|
+---------------+----------------+-----+---+-------------+
| 2018-09-25|5c40c8510fb7c017| 7| 1| 7|
| 2018-09-25|5bdb2b543951bf07| 9| 1| 9|
| 2017-11-28|593b0d9f3f21f9dd| 7| 1| 7|
| 2017-11-29|593b0d9f3f21f9dd| 60| 2| 67|
| 2018-01-09|593b0d9f3f21f9dd| 1| 3| 68|
| 2018-04-27|593b0d9f3f21f9dd| 9| 4| 77|
| 2018-09-25|593b0d9f3f21f9dd| 29| 5| 106|
| 2018-11-20|593b0d9f3f21f9dd| 42| 6| 148|
| 2018-12-11|593b0d9f3f21f9dd| 317| 7| 465|
| 2019-01-04|593b0d9f3f21f9dd| 3| 8| 468|
| 2019-02-13|593b0d9f3f21f9dd| 15| 9| 483|
| 2019-04-01|593b0d9f3f21f9dd| 1| 10| 484|
+---------------+----------------+-----+---+-------------+
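For reference, a rowsBetween frame gives the same result on this sample, because every (company, Normalized_Date) pair is unique; with tied dates, rangeBetween would give all tied rows the same cumulative value, while rowsBetween accumulates row by row (a sketch, assuming df, F and Window as above):
# physical-row frame instead of the value-based range frame
w_rows = (Window.partitionBy('company').orderBy('Normalized_Date')
                .rowsBetween(Window.unboundedPreceding, Window.currentRow))
df.withColumn('Rolling_count', F.sum('count').over(w_rows)).show()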
Try this.
You need the sum of all rows from the start of the partition up to the current row in the window frame.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.expressions.WindowSpec
import org.apache.spark.sql.functions._
val df = Seq(
("5c40c8510fb7c017", 7, 1),
("5bdb2b543951bf07", 9, 1),
("593b0d9f3f21f9dd", 7, 1),
("593b0d9f3f21f9dd", 60, 2),
("593b0d9f3f21f9dd", 1, 3),
("593b0d9f3f21f9dd", 9, 4),
("593b0d9f3f21f9dd", 29, 5),
("593b0d9f3f21f9dd", 42, 6),
("593b0d9f3f21f9dd", 317, 7),
("593b0d9f3f21f9dd", 3, 8),
("593b0d9f3f21f9dd", 15, 9),
("593b0d9f3f21f9dd", 1, 10)
).toDF("company", "count", "ix")
scala> df.show(false)
+----------------+-----+---+
|company |count|ix |
+----------------+-----+---+
|5c40c8510fb7c017|7 |1 |
|5bdb2b543951bf07|9 |1 |
|593b0d9f3f21f9dd|7 |1 |
|593b0d9f3f21f9dd|60 |2 |
|593b0d9f3f21f9dd|1 |3 |
|593b0d9f3f21f9dd|9 |4 |
|593b0d9f3f21f9dd|29 |5 |
|593b0d9f3f21f9dd|42 |6 |
|593b0d9f3f21f9dd|317 |7 |
|593b0d9f3f21f9dd|3 |8 |
|593b0d9f3f21f9dd|15 |9 |
|593b0d9f3f21f9dd|1 |10 |
+----------------+-----+---+
scala> val overColumns = Window.partitionBy("company").orderBy("ix").rowsBetween(Window.unboundedPreceding, Window.currentRow)
overColumns: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@3ed5e17c
scala> val outputDF = df.withColumn("RollingCount", sum("count").over(overColumns))
outputDF: org.apache.spark.sql.DataFrame = [company: string, count: int ... 2 more fields]
scala> outputDF.show(false)
+----------------+-----+---+------------+
|company |count|ix |RollingCount|
+----------------+-----+---+------------+
|5c40c8510fb7c017|7 |1 |7 |
|5bdb2b543951bf07|9 |1 |9 |
|593b0d9f3f21f9dd|7 |1 |7 |
|593b0d9f3f21f9dd|60 |2 |67 |
|593b0d9f3f21f9dd|1 |3 |68 |
|593b0d9f3f21f9dd|9 |4 |77 |
|593b0d9f3f21f9dd|29 |5 |106 |
|593b0d9f3f21f9dd|42 |6 |148 |
|593b0d9f3f21f9dd|317 |7 |465 |
|593b0d9f3f21f9dd|3 |8 |468 |
|593b0d9f3f21f9dd|15 |9 |483 |
|593b0d9f3f21f9dd|1 |10 |484 |
+----------------+-----+---+------------+

Dataframe get first and last value of corresponding column

Is it possible to get the first value of the corresponding column within a subgroup?
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.{Window, WindowSpec}
object tmp {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
val input = Seq(
(1235, 1, 1101, 0),
(1235, 2, 1102, 0),
(1235, 3, 1103, 1),
(1235, 4, 1104, 1),
(1235, 5, 1105, 0),
(1235, 6, 1106, 0),
(1235, 7, 1107, 1),
(1235, 8, 1108, 1),
(1235, 9, 1109, 1),
(1235, 10, 1110, 0),
(1235, 11, 1111, 0)
).toDF("SERVICE_ID", "COUNTER", "EVENT_ID", "FLAG")
lazy val window: WindowSpec = Window.partitionBy("SERVICE_ID").orderBy("COUNTER")
val firsts = input.withColumn("first_value", first("EVENT_ID", ignoreNulls = true).over(window.rangeBetween(Long.MinValue, Long.MaxValue)))
firsts.orderBy("SERVICE_ID", "COUNTER").show()
}
}
The output I want: the first (or previous) value of the column EVENT_ID where FLAG = 1, and the last (or next) value of the column EVENT_ID where FLAG = 1, partitioned by SERVICE_ID and sorted by COUNTER.
+----------+-------+--------+----+-----------+-----------+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG|first_value|last_value|
+----------+-------+--------+----+-----------+-----------+
| 1235| 1| 1101| 0| 0| 1103|
| 1235| 2| 1102| 0| 0| 1103|
| 1235| 3| 1103| 1| 0| 1106|
| 1235| 4| 1104| 0| 1103| 1106|
| 1235| 5| 1105| 0| 1103| 1106|
| 1235| 6| 1106| 1| 0| 1108|
| 1235| 7| 1107| 0| 1106| 1108|
| 1235| 8| 1108| 1| 0| 1109|
| 1235| 9| 1109| 1| 0| 1110|
| 1235| 10| 1110| 1| 0| 0|
| 1235| 11| 1111| 0| 1110| 0|
| 1235| 12| 1112| 0| 1110| 0|
+----------+-------+--------+----+-----------+-----------+
First the dataframe needs to be formed into groups. A new group starts each time the "FLAG" column equals 1. To do this, first add a column "ID" to the dataframe:
lazy val window: WindowSpec = Window.partitionBy("SERVICE_ID").orderBy("COUNTER")
val df_flag = input.filter($"FLAG" === 1)
.withColumn("ID", row_number().over(window))
val df_other = input.filter($"FLAG" =!= 1)
.withColumn("ID", lit(0))
// Create a group for each flag event
val df = df_flag.union(df_other)
.withColumn("ID", max("ID").over(window.rowsBetween(Long.MinValue, 0)))
.cache()
df.show() gives:
+----------+-------+--------+----+---+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG| ID|
+----------+-------+--------+----+---+
| 1235| 1| 1111| 1| 1|
| 1235| 2| 1112| 0| 1|
| 1235| 3| 1114| 0| 1|
| 1235| 4| 2221| 1| 2|
| 1235| 5| 2225| 0| 2|
| 1235| 6| 2226| 0| 2|
| 1235| 7| 2227| 1| 3|
+----------+-------+--------+----+---+
Now that we have a column separating the events, we need to add the correct "EVENT_ID" (renamed "first_value") to each event. In addition to the "first_value", calculate and add a second column "last_value", which is the id of the next flagged event.
val df_event = df.filter($"FLAG" === 1)
.select("EVENT_ID", "ID", "SERVICE_ID", "COUNTER")
.withColumnRenamed("EVENT_ID", "first_value")
.withColumn("last_value", lead($"first_value",1,0).over(window))
.drop("COUNTER")
val df_final = df.join(df_event, Seq("ID", "SERVICE_ID"))
.drop("ID")
.withColumn("first_value", when($"FLAG" === 1, lit(0)).otherwise($"first_value"))
df_final.show() gives us:
+----------+-------+--------+----+-----------+----------+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG|first_value|last_value|
+----------+-------+--------+----+-----------+----------+
| 1235| 1| 1111| 1| 0| 2221|
| 1235| 2| 1112| 0| 1111| 2221|
| 1235| 3| 1114| 0| 1111| 2221|
| 1235| 4| 2221| 1| 0| 2227|
| 1235| 5| 2225| 0| 2221| 2227|
| 1235| 6| 2226| 0| 2221| 2227|
| 1235| 7| 2227| 1| 0| 0|
+----------+-------+--------+----+-----------+----------+
This can be solved in two steps:
1. get the events with "FLAG" == 1 and the valid range for each event;
2. join the result of step 1 with the input, by range.
Some column renaming is included for readability and can be shortened:
val window = Window.partitionBy("SERVICE_ID").orderBy("COUNTER").rowsBetween(Window.currentRow, 1)
val eventRangeDF = input.where($"FLAG" === 1)
.withColumn("RANGE_END", max($"COUNTER").over(window))
.withColumnRenamed("COUNTER", "RANGE_START")
.select("SERVICE_ID", "EVENT_ID", "RANGE_START", "RANGE_END")
eventRangeDF.show(false)
val result = input.where($"FLAG" === 0).as("i").join(eventRangeDF.as("e"),
expr("e.SERVICE_ID=i.SERVICE_ID And i.COUNTER>e.RANGE_START and i.COUNTER<e.RANGE_END"))
.select($"i.SERVICE_ID", $"i.COUNTER", $"i.EVENT_ID", $"i.FLAG", $"e.EVENT_ID".alias("first_value"))
// include FLAG=1
.union(input.where($"FLAG" === 1).select($"SERVICE_ID", $"COUNTER", $"EVENT_ID", $"FLAG", lit(0).alias("first_value")))
result.sort("COUNTER").show(false)
Output:
+----------+--------+-----------+---------+
|SERVICE_ID|EVENT_ID|RANGE_START|RANGE_END|
+----------+--------+-----------+---------+
|1235 |1111 |1 |4 |
|1235 |2221 |4 |7 |
|1235 |2227 |7 |7 |
+----------+--------+-----------+---------+
+----------+-------+--------+----+-----------+
|SERVICE_ID|COUNTER|EVENT_ID|FLAG|first_value|
+----------+-------+--------+----+-----------+
|1235 |1 |1111 |1 |0 |
|1235 |2 |1112 |0 |1111 |
|1235 |3 |1114 |0 |1111 |
|1235 |4 |2221 |1 |0 |
|1235 |5 |2225 |0 |2221 |
|1235 |6 |2226 |0 |2221 |
|1235 |7 |2227 |1 |0 |
+----------+-------+--------+----+-----------+
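For anyone working in PySpark rather than Scala, a minimal sketch of the same requirement using window frames with ignorenulls. It assumes `input_df` is a PySpark DataFrame with the same SERVICE_ID, COUNTER, EVENT_ID and FLAG columns as the Scala `input` above, and reads the requirement as: first_value is the EVENT_ID of the most recent earlier row with FLAG = 1 (0 if none, and 0 on FLAG = 1 rows themselves), and last_value is the EVENT_ID of the next row with FLAG = 1 (0 if none):
from pyspark.sql import functions as F
from pyspark.sql import Window

w = Window.partitionBy("SERVICE_ID").orderBy("COUNTER")
flagged_event = F.when(F.col("FLAG") == 1, F.col("EVENT_ID"))   # null on non-flagged rows

result = (input_df
    .withColumn("first_value",
        F.when(F.col("FLAG") == 1, F.lit(0)).otherwise(
            F.coalesce(F.last(flagged_event, ignorenulls=True)
                        .over(w.rowsBetween(Window.unboundedPreceding, -1)),
                       F.lit(0))))
    .withColumn("last_value",
        F.coalesce(F.first(flagged_event, ignorenulls=True)
                    .over(w.rowsBetween(1, Window.unboundedFollowing)),
                   F.lit(0))))
result.orderBy("SERVICE_ID", "COUNTER").show()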