I have a dataframe with data like this
I want to compare each record with the next year's records and see whether that id is present. If the id is there in the year and in the next year, it's 'Existing'. If it's there in the year but not in the next year, it's 'Left'. If it's not there in the year but is there in the next year, it's 'New'. I want output like the one below. The columns 2017-18, 2018-19, etc. should be created dynamically.
How do I achieve this?
After getting the data in the above format, I need to aggregate the sales for each year band as shown below. For example, for 2017-2018:
New_sales = sum of all 2018 sales (the later year in 2017-2018) marked as 'New', which is 25 here.
Left_sales = sum of all 2017 sales (the earlier year in 2017-2018) marked as 'Left', which is 100 here.
Existing_sales = sum of the 2017 sales marked as 'Existing' minus the sum of the 2018 sales marked as 'Existing'.
Existing_sales = (50+75) (2017 sales, 'Existing') - (20+50) (2018 sales, 'Existing') = 125-70 = 55
How do I achieve this?
Since your date is a string, you can start from sample data like this:
from pyspark.sql import functions as func

df = spark.createDataFrame(
data=[
(1, '31/12/2017'),
(2, '31/12/2017'),
(3, '31/12/2017'),
(1, '31/12/2018'),
(3, '31/12/2018'),
(5, '31/12/2018'),
],
schema=['id', 'date']
)
First, extract the year from the date string:
df2 = df.withColumn('year', func.split(func.col('date'), '/').getItem(2))
df2.show(10, False)
+---+----------+----+
|id |date |year|
+---+----------+----+
|1 |31/12/2017|2017|
|2 |31/12/2017|2017|
|3 |31/12/2017|2017|
|1 |31/12/2018|2018|
|3 |31/12/2018|2018|
|5 |31/12/2018|2018|
+---+----------+----+
Then collect the set of years for each id as a reference:
df3 = df2.groupby('id')\
.agg(func.collect_set('year').alias('year_lst'))
df3.show(3, False)
+---+------------+
|id |year_lst |
+---+------------+
|1 |[2017, 2018]|
|2 |[2017] |
|3 |[2017, 2018]|
+---+------------+
Then you can join the reference back to the data:
df4 = df2.join(df3, on='id', how='left')
df4.show(10, False)
+---+----------+----+------------+
|id |date |year|year_lst |
+---+----------+----+------------+
|1 |31/12/2017|2017|[2017, 2018]|
|2 |31/12/2017|2017|[2017] |
|3 |31/12/2017|2017|[2017, 2018]|
|1 |31/12/2018|2018|[2017, 2018]|
|3 |31/12/2018|2018|[2017, 2018]|
|5 |31/12/2018|2018|[2018] |
+---+----------+----+------------+
The last step is to create the columns dynamically, which you can do with a for loop:
year_loop = ['2017', '2018', '2019', '2020', '2021']
for idx in range(len(year_loop) - 1):
    this_year = year_loop[idx]
    next_year = year_loop[idx + 1]
    column_name = f"{this_year}-{next_year}"
    # 'New': absent this year, present next year
    new_condition = (~func.array_contains(func.col('year_lst'), this_year)) & (func.array_contains(func.col('year_lst'), next_year))
    # 'Existing': present in both years
    exist_condition = (func.array_contains(func.col('year_lst'), this_year)) & (func.array_contains(func.col('year_lst'), next_year))
    # 'Left': present this year, absent next year
    left_condition = (func.array_contains(func.col('year_lst'), this_year)) & (~func.array_contains(func.col('year_lst'), next_year))
    df4 = df4.withColumn(column_name, func.when(new_condition, func.lit('New'))
                         .when(exist_condition, func.lit('Existing'))
                         .when(left_condition, func.lit('Left')))
df4.show(10, False)
+---+----------+----+------------+---------+---------+---------+---------+
|id |date |year|year_lst |2017-2018|2018-2019|2019-2020|2020-2021|
+---+----------+----+------------+---------+---------+---------+---------+
|1 |31/12/2017|2017|[2017, 2018]|Existing |Left |null |null |
|2 |31/12/2017|2017|[2017] |Left |null |null |null |
|3 |31/12/2017|2017|[2017, 2018]|Existing |Left |null |null |
|1 |31/12/2018|2018|[2017, 2018]|Existing |Left |null |null |
|3 |31/12/2018|2018|[2017, 2018]|Existing |Left |null |null |
|5 |31/12/2018|2018|[2018] |New |Left |null |null |
+---+----------+----+------------+---------+---------+---------+---------+
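The second part of the question (New_sales, Left_sales and Existing_sales per year band) is not covered by the steps above. Here is a minimal plain-Python sketch of that arithmetic, without Spark, just to pin down the logic. The per-id sales figures are an assumption chosen to match the totals quoted in the question (id 2 leaves with 100, id 5 is new with 25, ids 1 and 3 are existing), and the helper names `classify` and `aggregate_band` are hypothetical.

```python
# Assumed sample data: (id, year) -> sales. The question's table is not shown,
# so these values are picked to reproduce the totals it quotes.
sales = {
    (1, 2017): 50, (2, 2017): 100, (3, 2017): 75,
    (1, 2018): 20, (3, 2018): 50, (5, 2018): 25,
}


def classify(ids_this, ids_next):
    """Label every id seen in either year of a band."""
    labels = {}
    for i in ids_this | ids_next:
        if i in ids_this and i in ids_next:
            labels[i] = "Existing"
        elif i in ids_this:
            labels[i] = "Left"
        else:
            labels[i] = "New"
    return labels


def aggregate_band(sales, this_year, next_year):
    """Compute the three sales totals for one year band."""
    ids_this = {i for (i, y) in sales if y == this_year}
    ids_next = {i for (i, y) in sales if y == next_year}
    labels = classify(ids_this, ids_next)
    # New: sales of the later year; Left: sales of the earlier year.
    new = sum(sales[(i, next_year)] for i, s in labels.items() if s == "New")
    left = sum(sales[(i, this_year)] for i, s in labels.items() if s == "Left")
    # Existing: earlier-year sales minus later-year sales.
    existing = sum(
        sales[(i, this_year)] - sales[(i, next_year)]
        for i, s in labels.items()
        if s == "Existing"
    )
    return {"New_sales": new, "Left_sales": left, "Existing_sales": existing}


# Year bands are derived from the data instead of being hard-coded.
years = sorted({y for (_, y) in sales})
bands = {f"{a}-{b}": aggregate_band(sales, a, b) for a, b in zip(years, years[1:])}
print(bands)
```

In Spark itself, the same totals could probably be obtained with conditional sums (a `sum` over a `when` expression) grouped per band, but that has not been tested here.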
I have a problem joining two DataFrames grouped by id:
val df1 = Seq(
(1, 1,100),
(1, 3,20),
(2, 5,5),
(2, 2,10)).toDF("id", "index","value")
val df2 = Seq(
(1, 0),
(2, 0),
(3, 0),
(4, 0),
(5,0)).toDF("index", "value")
df1 should be joined with df2 on the index column, separately for every id, so that each id gets every index from df2.
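Expected result:
+---+-----+-----+
|id |index|value|
+---+-----+-----+
|1  |1    |100  |
|1  |2    |0    |
|1  |3    |20   |
|1  |4    |0    |
|1  |5    |0    |
|2  |1    |0    |
|2  |2    |10   |
|2  |3    |0    |
|2  |4    |0    |
|2  |5    |5    |
+---+-----+-----+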
Please help me with this.
First of all, I would replace your df2 table with this:
var df2 = Seq(
(Array(1, 2), Array(1, 2, 3, 4, 5))
).toDF("id", "index")
This allows us to use explode to generate a table with every (id, index) combination:
df2 = df2
.withColumn("id", explode(col("id")))
.withColumn("index", explode(col("index")))
and it gives:
+---+-----+
|id |index|
+---+-----+
|1 |1 |
|1 |2 |
|1 |3 |
|1 |4 |
|1 |5 |
|2 |1 |
|2 |2 |
|2 |3 |
|2 |4 |
|2 |5 |
+---+-----+
Now, all we need to do, is join with your df1 as below:
df2 = df2
.join(df1, Seq("id", "index"), "left")
.withColumn("value", when(col("value").isNull, 0).otherwise(col("value")))
And we get this final output:
+---+-----+-----+
|id |index|value|
+---+-----+-----+
|1 |1 |100 |
|1 |2 |0 |
|1 |3 |20 |
|1 |4 |0 |
|1 |5 |0 |
|2 |1 |0 |
|2 |2 |10 |
|2 |3 |0 |
|2 |4 |0 |
|2 |5 |5 |
+---+-----+-----+
which should be what you want. Good luck!
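To make the idea independent of Spark, here is a small plain-Python sketch of the same logic: build every (id, index) pair, then look up the value from df1 and default to 0 when it is missing. The variable names are hypothetical; the data is taken from the question.

```python
# Rows of df1 from the question: (id, index, value).
df1_rows = [(1, 1, 100), (1, 3, 20), (2, 5, 5), (2, 2, 10)]
# The full set of indexes that every id should end up with.
all_indexes = [1, 2, 3, 4, 5]

# Look up known values by (id, index).
known = {(i, idx): v for i, idx, v in df1_rows}
ids = sorted({i for i, _, _ in df1_rows})

# Cross product of ids and indexes; missing pairs default to 0,
# which mirrors the left join followed by the null replacement.
result = [(i, idx, known.get((i, idx), 0)) for i in ids for idx in all_indexes]

for row in result:
    print(row)
```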
df_hrrchy
|lefId |Lineage |
|-------|--------------------------------------|
|36326 |["36326","36465","36976","36091","82"]|
|36121 |["36121","36908","36976","36091","82"]|
|36380 |["36380","36465","36976","36091","82"]|
|36448 |["36448","36465","36976","36091","82"]|
|36683 |["36683","36465","36976","36091","82"]|
|36949 |["36949","36908","36976","36091","82"]|
|37349 |["37349","36908","36976","36091","82"]|
|37026 |["37026","36908","36976","36091","82"]|
|36879 |["36879","36465","36976","36091","82"]|
df_trans
|tranID | T_Id |
|-----------|-------------------------------------------------------------------------|
|1000540 |["36121","36326","37349","36949","36380","37026","36448","36683","36879"]|
df_creds
|T_Id |T_val |T_Goal |Parent_T_Id |Parent_Val |parent_Goal|
|-------|-------|-------|---------------|----------------|-----------|
|36448 |100 |1 |36465 |200 |1 |
|36465 |200 |1 |36976 |300 |2 |
|36326 |90 |1 |36465 |200 |1 |
|36091 |500 |19 |82 |600 |4 |
|36121 |90 |1 |36908 |200 |1 |
|36683 |90 |1 |36465 |200 |1 |
|36908 |200 |1 |36976 |300 |2 |
|36949 |90 |1 |36908 |200 |1 |
|36976 |300 |2 |36091 |500 |19 |
|37026 |90 |1 |36908 |200 |1 |
|37349 |100 |1 |36908 |200 |1 |
|36879 |90 |1 |36465 |200 |1 |
|36380 |90 |1 |36465 |200 |1 |
Desired Result
|T_id   |children                                  |T_Val  |T_Goal |parent_T_id |parent_Goal |trans_id |
|-------|------------------------------------------|-------|-------|------------|------------|---------|
|36091  |["36976"]                                 |500    |19     |82          |4           |1000540  |
|36465  |["36448","36326","36683","36879","36380"] |200    |1      |36976       |2           |1000540  |
|36908  |["36121","36949","37026","37349"]         |200    |1      |36976       |2           |1000540  |
|36976  |["36465","36908"]                         |300    |2      |36091       |19          |1000540  |
|36683  |null                                      |90     |1      |36465       |1           |1000540  |
|37026  |null                                      |90     |1      |36908       |1           |1000540  |
|36448  |null                                      |100    |1      |36465       |1           |1000540  |
|36949  |null                                      |90     |1      |36908       |1           |1000540  |
|36326  |null                                      |90     |1      |36465       |1           |1000540  |
|36380  |null                                      |90     |1      |36465       |1           |1000540  |
|36879  |null                                      |90     |1      |36465       |1           |1000540  |
|36121  |null                                      |90     |1      |36908       |1           |1000540  |
|37349  |null                                      |100    |1      |36908       |1           |1000540  |
Code Tried
from pyspark.sql import functions as F
from pyspark.sql import DataFrame
from pyspark.sql.functions import explode, collect_set, expr, col, collect_list, array_contains, lit
from functools import reduce
for row in df_transactions.rdd.toLocalIterator():
    # def find_nodemap(row):
    dfs = []
    df_hy_set = (df_hrrchy.filter(df_hrrchy.lefId.isin(row["T_ds"]))
                 .select(explode("Lineage").alias("Terrs"))
                 .agg(collect_set(col("Terrs")).alias("hierarchy_list"))
                 .select(F.lit(row["trans_id"]).alias("trans_id"), "hierarchy_list")
                 )
    df_childrens = (df_creds.join(df_hy_set, expr("array_contains(hierarchy_list, T_id)"))
                    .select("T_id", "T_Val", "T_Goal", "parent_T_id", "parent_Goal", "trans_id")
                    .groupBy("parent_T_id").agg(collect_list("T_id").alias("children"))
                    )
    df_filter_creds = (df_creds.join(df_hy_set, expr("array_contains(hierarchy_list, T_id)"))
                       .select("T_id", "T_val", "T_Goal", "parent_T_id", "parent_Goal", "trans_id")
                       )
    df_nodemap = (df_filter_creds.alias("A").join(df_childrens.alias("B"), col("A.T_id") == col("B.parent_T_id"), "left")
                  .select("A.T_id", "B.children", "A.T_val", "A.T_Goal", "A.parent_T_id", "A.parent_Goal", "A.trans_id")
                  )
    display(df_nodemap)
    # dfs.append(df_nodemap)
# df = reduce(DataFrame.union, dfs)
# display(df)
My problem: it's a bad design. df_trans has millions of rows, and looping through the DataFrame takes forever. Can I do this without looping? I tried a couple of other methods but was not able to get the desired result.
You certainly need to process the entire DataFrame in batch rather than iterating row by row.
The key point is to "reverse" df_hrrchy, i.e. to derive from the parent lineage the list of children for every T_Id:
val df_children = df_hrrchy.withColumn("children", slice($"Lineage", lit(1), size($"Lineage") - 1))
.withColumn("parents", slice($"Lineage", 2, 999999))
.select(explode(arrays_zip($"children", $"parents")).as("rels"))
.distinct
.groupBy($"rels.parents".as("T_Id"))
.agg(collect_set($"rels.children").as("children"))
df_children.show(false)
+-----+-----------------------------------+
|T_Id |children |
+-----+-----------------------------------+
|36091|[36976] |
|36465|[36448, 36380, 36326, 36879, 36683]|
|36976|[36465, 36908] |
|82 |[36091] |
|36908|[36949, 37349, 36121, 37026] |
+-----+-----------------------------------+
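The "reversal" step can be hard to picture from the column expressions alone, so here is a plain-Python sketch of the same idea: pair each element of a lineage with its successor to get (child, parent) edges, then group the children by parent. It uses only the first three lineage rows from df_hrrchy, and the variable names are hypothetical.

```python
from collections import defaultdict

# First three Lineage arrays from df_hrrchy in the question.
lineages = [
    ["36326", "36465", "36976", "36091", "82"],
    ["36121", "36908", "36976", "36091", "82"],
    ["36380", "36465", "36976", "36091", "82"],
]

children = defaultdict(set)
for lineage in lineages:
    # zip(lineage, lineage[1:]) plays the role of arrays_zip on the two slices:
    # every element is paired with the element that follows it.
    for child, parent in zip(lineage, lineage[1:]):
        children[parent].add(child)

for parent, kids in sorted(children.items()):
    print(parent, sorted(kids))
```

Using a set per parent removes duplicate edges, which is what `distinct` and `collect_set` do in the Spark version.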
Then expand the list of T_Ids in df_trans and also include all T_Ids from the hierarchy:
val df_trans_map = df_trans.withColumn("T_Id", explode($"T_Id"))
.join(df_hrrchy, array_contains($"Lineage", $"T_Id"))
.select($"tranID", explode($"Lineage").as("T_Id"))
.distinct
df_trans_map.show(false)
+-------+-----+
|tranID |T_Id |
+-------+-----+
|1000540|36976|
|1000540|82 |
|1000540|36091|
|1000540|36465|
|1000540|36326|
|1000540|36121|
|1000540|36908|
|1000540|36380|
|1000540|36448|
|1000540|36683|
|1000540|36949|
|1000540|37349|
|1000540|37026|
|1000540|36879|
+-------+-----+
With this it is just a simple join to obtain final result:
df_trans_map.join(df_creds, Seq("T_Id"))
.join(df_children, Seq("T_Id"), "left_outer")
.show(false)
+-----+-------+-----+------+-----------+----------+-----------+-----------------------------------+
|T_Id |tranID |T_val|T_Goal|Parent_T_Id|Parent_Val|parent_Goal|children |
+-----+-------+-----+------+-----------+----------+-----------+-----------------------------------+
|36976|1000540|300 |2 |36091 |500 |19 |[36465, 36908] |
|36091|1000540|500 |19 |82 |600 |4 |[36976] |
|36465|1000540|200 |1 |36976 |300 |2 |[36448, 36380, 36326, 36879, 36683]|
|36326|1000540|90 |1 |36465 |200 |1 |null |
|36121|1000540|90 |1 |36908 |200 |1 |null |
|36908|1000540|200 |1 |36976 |300 |2 |[36949, 37349, 36121, 37026] |
|36380|1000540|90 |1 |36465 |200 |1 |null |
|36448|1000540|100 |1 |36465 |200 |1 |null |
|36683|1000540|90 |1 |36465 |200 |1 |null |
|36949|1000540|90 |1 |36908 |200 |1 |null |
|37349|1000540|100 |1 |36908 |200 |1 |null |
|37026|1000540|90 |1 |36908 |200 |1 |null |
|36879|1000540|90 |1 |36465 |200 |1 |null |
+-----+-------+-----+------+-----------+----------+-----------+-----------------------------------+
You need to rewrite this to use the full cluster. Using toLocalIterator means that you aren't fully utilizing the cluster for shared work.
The code below was not run, as you didn't provide a workable data set to test with. If you do, I'll run the code to make sure it's sound.
from pyspark.sql import functions as F
from pyspark.sql import DataFrame
from pyspark.sql.functions import explode, collect_set, expr, col, collect_list, array_contains, lit
from functools import reduce
# Explode the arrays. This creates a lot of short-lived records, but the flip side is
# that the entire cluster completes the work instead of the driver.
df_trans_expld = df_trans.select(df_trans.tranID, explode(df_trans.T_Id).alias("T_Id"))
df_hrrchy_expld = df_hrrchy.select(df_hrrchy.lefId, explode(df_hrrchy.Lineage).alias("Lineage"))
# Joining the exploded data does the same job as the original filter.
df_hy_set = (df_trans_expld.join(df_hrrchy_expld, df_hrrchy_expld.lefId == df_trans_expld.T_Id, "left")
             .groupBy("tranID")
             .agg(collect_set(col("Lineage")).alias("hierarchy_list"))
             .select(col("tranID").alias("trans_id"), "hierarchy_list")
             )
# Logic unchanged from here down
df_childrens = (df_creds.join(df_hy_set, expr("array_contains(hierarchy_list, T_id)"))
                .select("T_id", "T_Val", "T_Goal", "parent_T_id", "parent_Goal", "trans_id")
                .groupBy("parent_T_id").agg(collect_list("T_id").alias("children"))
                )
df_filter_creds = (df_creds.join(df_hy_set, expr("array_contains(hierarchy_list, T_id)"))
                   .select("T_id", "T_val", "T_Goal", "parent_T_id", "parent_Goal", "trans_id")
                   )
df_nodemap = (df_filter_creds.alias("A").join(df_childrens.alias("B"), col("A.T_id") == col("B.parent_T_id"), "left")
              .select("A.T_id", "B.children", "A.T_val", "A.T_Goal", "A.parent_T_id", "A.parent_Goal", "A.trans_id")
              )
# No need to append/union data, as it's now just one DataFrame: df_nodemap
I'd have to look into this more, but I'm pretty sure your existing code pulls all the data through the driver, which really slows things down. The version above makes use of all executors to complete the work.
There may be a further optimization: getting rid of array_contains and using a join instead. I'd have to look at the explain plan to see whether you could get even more performance out of it. I don't remember off the top of my head; you are avoiding a shuffle, so it may be better as is.
I have a PySpark DataFrame:
df1 = spark.createDataFrame([
("s1", "i1", 0),
("s1", "i2", 1),
("s1", "i3", 2),
("s1", None, 3),
("s1", "i5", 4),
],
["session_id", "item_id", "pos"])
df1.show(truncate=False)
pos is the position or rank of the item in the session.
Now I want to create new sessions without any null values in them. I want to do this by starting a new session after every null item. Basically I want to break existing sessions into multiple sessions, removing the null item_id in the process.
The expected output would look something like this:
+----------+-------+---+--------------+
|session_id|item_id|pos|new_session_id|
+----------+-------+---+--------------+
|s1 |i1 |0 | s1_0|
|s1 |i2 |1 | s1_0|
|s1 |i3 |2 | s1_0|
|s1 |null |3 | None|
|s1 |i5 |4 | s1_4|
+----------+-------+---+--------------+
How do I achieve this?
I'm not sure about the configuration of your Spark job, but to avoid using a collect action to build the reference for your "new" sessions in a built-in Python data structure, I would use built-in Spark SQL functions to build it. Based on your example, and assuming the DataFrame is already sorted:
from pyspark.sql import SparkSession
from pyspark.sql import functions as func
from pyspark.sql.window import Window
from pyspark.sql.types import *
df = spark.createDataFrame(
[("s1", "i1", 0), ("s1", "i2", 1), ("s1", "i3", 2), ("s1", None, 3), ("s1", None, 4), ("s1", "i6", 5), ("s2", "i7", 6), ("s2", None, 7), ("s2", "i9", 8), ("s2", "i10", 9), ("s2", "i11", 10)],
["session_id", "item_id", "pos"]
)
df.show(20, False)
+----------+-------+---+
|session_id|item_id|pos|
+----------+-------+---+
|s1 |i1 |0 |
|s1 |i2 |1 |
|s1 |i3 |2 |
|s1 |null |3 |
|s1 |null |4 |
|s1 |i6 |5 |
|s2 |i7 |6 |
|s2 |null |7 |
|s2 |i9 |8 |
|s2 |i10 |9 |
|s2 |i11 |10 |
+----------+-------+---+
Step 1: As the data is already sorted, we can use a lag function to shift the data to the next record:
df2 = df\
.withColumn('lag_item', func.lag('item_id', 1).over(Window.partitionBy('session_id').orderBy('pos')))
df2.show(20, False)
+----------+-------+---+--------+
|session_id|item_id|pos|lag_item|
+----------+-------+---+--------+
|s1 |i1 |0 |null |
|s1 |i2 |1 |i1 |
|s1 |i3 |2 |i2 |
|s1 |null |3 |i3 |
|s1 |null |4 |null |
|s1 |i6 |5 |null |
|s2 |i7 |6 |null |
|s2 |null |7 |i7 |
|s2 |i9 |8 |null |
|s2 |i10 |9 |i9 |
|s2 |i11 |10 |i10 |
+----------+-------+---+--------+
Step 2: After applying the lag function, we can see whether the item_id in the previous record is NULL. A row with a non-null item_id and a null lag_item marks the start of a new session, so we can filter on that and build the reference:
reference = df2\
.filter((func.col('item_id').isNotNull())&(func.col('lag_item').isNull()))\
.groupby('session_id')\
.agg(func.collect_set('pos').alias('session_id_set'))
reference.show(100, False)
+----------+--------------+
|session_id|session_id_set|
+----------+--------------+
|s1 |[0, 5] |
|s2 |[6, 8] |
+----------+--------------+
Step 3: Join the reference back to the data and write a simple UDF to find which new session each row belongs to:
@func.udf(returnType=IntegerType())
def udf_find_session(item_id, pos, session_id_set):
    r_val = None
    if item_id is not None:
        # collect_set does not guarantee order, so sort the start positions first
        for item in sorted(session_id_set):
            if pos >= item:
                r_val = item
            else:
                break
    return r_val
df3 = df2.select('session_id', 'item_id', 'pos')\
.join(reference, on='session_id', how='inner')
df4 = df3.withColumn('new_session_id', udf_find_session(func.col('item_id'), func.col('pos'), func.col('session_id_set')))
df4.show(20, False)
+----------+-------+---+--------------+
|session_id|item_id|pos|new_session_id|
+----------+-------+---+--------------+
|s1 |i1 |0 |0 |
|s1 |i2 |1 |0 |
|s1 |i3 |2 |0 |
|s1 |null |3 |null |
|s1 |null |4 |null |
|s1 |i6 |5 |5 |
|s2 |i7 |6 |6 |
|s2 |null |7 |null |
|s2 |i9 |8 |8 |
|s2 |i10 |9 |8 |
|s2 |i11 |10 |8 |
+----------+-------+---+--------------+
The last step is just to concatenate session_id and new_session_id into the string you want to show as the new session id.
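For reference, the whole rule, including that final concatenation, can be sketched in plain Python without Spark. This is only an illustration of the logic under the assumption that the new id has the form `<session_id>_<start pos>`, as in the question's expected output; the function name `assign_new_sessions` is hypothetical.

```python
from itertools import groupby


def assign_new_sessions(rows):
    """rows: iterable of (session_id, item_id, pos) tuples."""
    out = []
    ordered = sorted(rows, key=lambda r: (r[0], r[2]))
    for session_id, group in groupby(ordered, key=lambda r: r[0]):
        start = None  # pos at which the current run of non-null items began
        for _, item_id, pos in group:
            if item_id is None:
                # A null item closes the current run and gets no new id.
                start = None
                out.append((session_id, item_id, pos, None))
            else:
                if start is None:
                    start = pos
                out.append((session_id, item_id, pos, f"{session_id}_{start}"))
    return out


rows = [("s1", "i1", 0), ("s1", "i2", 1), ("s1", "i3", 2), ("s1", None, 3), ("s1", "i5", 4)]
result = assign_new_sessions(rows)
for row in result:
    print(row)
```

Consecutive null items are handled the same way as a single one: the run start is simply reset until the next non-null item appears.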
I'm trying to find the max of a column grouped by spark partition id. I'm getting the wrong value when applying the max function though. Here is the code:
val partitionCol = uuid()
val localRankCol = "test"
df = df.withColumn(partitionCol, spark_partition_id)
val windowSpec = WindowSpec.partitionBy(partitionCol).orderBy(sortExprs:_*)
val rankDF = df.withColumn(localRankCol, dense_rank().over(windowSpec))
val rankRangeDF = rankDF.agg(max(localRankCol))
rankRangeDF.show(false)
sortExprs is applying an ascending sort on sales.
And the result with some dummy data is (partitionCol is 5th column):
+--------------+------+-----+---------------------------------+--------------------------------+----+
|title |region|sales|r6bea781150fa46e3a0ed761758a50dea|5683151561af407282380e6cf25f87b5|test|
+--------------+------+-----+---------------------------------+--------------------------------+----+
|Die Hard |US |100.0|1 |0 |1 |
|Rambo |US |100.0|1 |0 |1 |
|Die Hard |AU |200.0|1 |0 |2 |
|House of Cards|EU |400.0|1 |0 |3 |
|Summer Break |US |400.0|1 |0 |3 |
|Rambo |EU |100.0|1 |1 |1 |
|Summer Break |APAC |200.0|1 |1 |2 |
|Rambo |APAC |300.0|1 |1 |3 |
|House of Cards|US |500.0|1 |1 |4 |
+--------------+------+-----+---------------------------------+--------------------------------+----+
+---------+
|max(test)|
+---------+
|5 |
+---------+
"test" column has a max value of 4 but 5 is being returned.
I have the input data set like:
id operation value
1 null 1
1 discard 0
2 null 1
2 null 2
2 max 0
3 null 1
3 null 1
3 list 0
I want to group the input and produce rows according to "operation" column.
for group 1, operation="discard", then the output is null,
for group 2, operation="max", the output is:
2 null 2
for group 3, operation="list", the output is:
3 null 1
3 null 1
So finally the output is like:
id operation value
2 null 2
3 null 1
3 null 1
Is there a solution for this?
I know there is a similar question, how-to-iterate-grouped-data-in-spark, but the differences compared to that are:
I want to produce more than one row for each group. Is that possible, and how?
I want my logic to be easily extended with more operations in the future. Are user-defined aggregate functions (UDAFs) the only possible solution?
Update 1:
Thanks, stack0114106. Here are more details in response to his answer: e.g. for id=1, operation="max", I want to iterate over all the items with id=1 and find the max value rather than assign a hard-coded value. That's why I want to iterate over the rows in each group. Below is an updated example:
The input:
scala> val df = Seq((0,null,1),(0,"discard",0),(1,null,1),(1,null,2),(1,"max",0),(2,null,1),(2,null,3),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]
scala> df.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|0 |null |1 |
|0 |discard |0 |
|1 |null |1 |
|1 |null |2 |
|1 |max |0 |
|2 |null |1 |
|2 |null |3 |
|2 |max |0 |
|3 |null |1 |
|3 |null |1 |
|3 |list |0 |
+---+---------+-----+
The expected output:
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|1 |null |2 |
|2 |null |3 |
|3 |null |1 |
|3 |null |1 |
+---+---------+-----+
Group everything, collecting the values, then write the logic for each operation:
import org.apache.spark.sql.functions._
val grouped=df.groupBy($"id").agg(max($"operation").as("op"),collect_list($"value").as("vals"))
val maxs=grouped.filter($"op"==="max").withColumn("val",explode($"vals")).groupBy($"id").agg(max("val").as("value"))
val lists=grouped.filter($"op"==="list").withColumn("value",explode($"vals")).filter($"value"=!=0).select($"id",$"value")
//we don't collect the "discard"
//and we can add additional subsets for new "operations"
val result=maxs.union(lists)
//if you need the null in "operation" column add it with withColumn
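On the extensibility point raised in the question, one way to picture it is as a table of per-operation handlers, so that adding an operation means adding one entry. The following is a plain-Python sketch of that dispatch idea, not Spark code; the `OPERATIONS` table and the variable names are hypothetical, and the data is the updated example from the question.

```python
from collections import defaultdict

# Each handler receives the data values of one group and returns the output values.
OPERATIONS = {
    "discard": lambda values: [],
    "max": lambda values: [max(values)],
    "list": lambda values: list(values),
}

# Updated example from the question: (id, operation, value).
rows = [
    (0, None, 1), (0, "discard", 0),
    (1, None, 1), (1, None, 2), (1, "max", 0),
    (2, None, 1), (2, None, 3), (2, "max", 0),
    (3, None, 1), (3, None, 1), (3, "list", 0),
]

values = defaultdict(list)  # id -> data values (rows whose operation is null)
ops = {}                    # id -> operation name (the marker row of the group)
for id_, op, value in rows:
    if op is None:
        values[id_].append(value)
    else:
        ops[id_] = op

# Apply the group's handler and emit zero, one or many rows per group.
result = [
    (id_, None, v)
    for id_ in sorted(ops)
    for v in OPERATIONS[ops[id_]](values[id_])
]

for row in result:
    print(row)
```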
You can use flatMap operation on the dataframe and generate required rows based on the conditions that you mentioned. Check this out
scala> val df = Seq((1,null,1),(1,"discard",0),(2,null,1),(2,null,2),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]
scala> df.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|1 |null |1 |
|1 |discard |0 |
|2 |null |1 |
|2 |null |2 |
|2 |max |0 |
|3 |null |1 |
|3 |null |1 |
|3 |list |0 |
+---+---------+-----+
scala> df.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => (0,0) case "max" => (1,2) case "list" => (2,1) } ; (0 until s._1).map( i => (r.getInt(0),null,s._2) ) }).show(false)
+---+----+---+
|_1 |_2 |_3 |
+---+----+---+
|2 |null|2 |
|3 |null|1 |
|3 |null|1 |
+---+----+---+
Spark assigns the default column names _1, _2, etc., so you can map them to actual names with toDF, as below:
scala> val df2 = df.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => (0,0) case "max" => (1,2) case "list" => (2,1) } ; (0 until s._1).map( i => (r.getInt(0),null,s._2) ) }).toDF("id","operation","value")
df2: org.apache.spark.sql.DataFrame = [id: int, operation: null ... 1 more field]
scala> df2.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|2 |null |2 |
|3 |null |1 |
|3 |null |1 |
+---+---------+-----+
EDIT1:
Since you need the max(value) for each id, you can use window functions and get the max value in a new column, then use the same technique and get the results. Check this out
scala> val df = Seq((0,null,1),(0,"discard",0),(1,null,1),(1,null,2),(1,"max",0),(2,null,1),(2,null,3),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]
scala> df.createOrReplaceTempView("michael")
scala> val df2 = spark.sql(""" select *, max(value) over(partition by id) mx from michael """)
df2: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 2 more fields]
scala> df2.show(false)
+---+---------+-----+---+
|id |operation|value|mx |
+---+---------+-----+---+
|1 |null |1 |2 |
|1 |null |2 |2 |
|1 |max |0 |2 |
|3 |null |1 |1 |
|3 |null |1 |1 |
|3 |list |0 |1 |
|2 |null |1 |3 |
|2 |null |3 |3 |
|2 |max |0 |3 |
|0 |null |1 |1 |
|0 |discard |0 |1 |
+---+---------+-----+---+
scala> val df3 = df2.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => 0 case "max" => 1 case "list" => 2 } ; (0 until s).map( i => (r.getInt(0),null,r.getInt(3) )) }).toDF("id","operation","value")
df3: org.apache.spark.sql.DataFrame = [id: int, operation: null ... 1 more field]
scala> df3.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|1 |null |2 |
|3 |null |1 |
|3 |null |1 |
|2 |null |3 |
+---+---------+-----+