spark window function conditional restart - pyspark

I am trying to calculate the activity value that does not originate from extra credit.
Input:
+------+--------+------+
|period|activity|credit|
+------+--------+------+
| 1| 5| 0|
| 2| 0| 3|
| 3| 4| 0|
| 4| 0| 3|
| 5| 1| 0|
| 6| 1| 0|
| 7| 5| 0|
| 8| 0| 1|
| 9| 0| 1|
| 10| 5| 0|
+------+--------+------+
Output:
rdd = sc.parallelize([(1,5,0,5),(2,0,3,0),(3,4,0,1),(4,0,3,0),(5,1,0,0),(6,1,0,0),(7,5,0,4),(8,0,1,0),(9,0,1,0),(10,5,0,3)])
df = rdd.toDF(["period","activity","credit","realActivity"])
+------+--------+------+------------+
|period|activity|credit|realActivity|
+------+--------+------+------------+
| 1| 5| 0| 5|
| 2| 0| 3| 0|
| 3| 4| 0| 1|
| 4| 0| 3| 0|
| 5| 1| 0| 0|
| 6| 1| 0| 0|
| 7| 5| 0| 4|
| 8| 0| 1| 0|
| 9| 0| 1| 0|
| 10| 5| 0| 3|
+------+--------+------+------------+
I tried to create a credit balance column that adds and deducts based on the row type, but I could not restart it conditionally (every time it goes below zero) based on itself. It looks like a recursive problem that I am not sure how to translate into something PySpark-friendly. Obviously, I can't do the following, self-referencing the previous value:
w = Window.orderBy("period")
df = df.withColumn("realActivity", lag("realActivity",1,0).over(w) - lag("credit", 1, 0).over(w) - lag("activity",1,0).over(w) )
Update:
As was pointed out, this is not possible with a window calculation alone. Therefore I would like to do something like the snippet below to calculate creditBalance, which would then let me calculate realActivity.
df['creditBalance'] = 0
for i in range(1, len(df)):
    if df.loc[i-1, 'creditBalance'] > 0:
        df.loc[i, 'creditBalance'] = df.loc[i-1, 'creditBalance'] + df.loc[i, 'credit'] - df.loc[i, 'activity']
    elif df.loc[i, 'credit'] > 0:
        df.loc[i, 'creditBalance'] = df.loc[i, 'credit'] - df.loc[i, 'activity']
Now, my only question is: how can I apply this "local" function to each group in a spark dataframe?
write dataframe to files by group and process locally?
custom map and collect the rows for the local execution?
collapse rows to a single row by group and process that ?
anything else?

#pansen,
I've solved the issue with the following code. It may be useful in case you are trying to solve a similar problem.
def creditUsage(rows):
    '''
    Input: list of "timestamp;activity;credit" strings, e.g.
        ['1;5;0', '2;0;3', '3;4;0', '4;0;3', '5;1;0', '6;1;0', '7;5;0', '8;0;1', '9;0;1', '10;5;0']
    Output: list of "timestamp;creditUsage" strings
    '''
    # sort the packed rows by timestamp
    timestamps = [int(r.split(";")[0]) for r in rows]
    rows = [r for _, r in sorted(zip(timestamps, rows))]
    print(rows)
    timestamp, trActivity, credit = zip(*[(int(ts), float(act), float(rbonus))
                                          for r in rows for [ts, act, rbonus] in [r.split(";")]])
    creditBalance, creditUsage = [0.0] * len(credit), [0.0] * len(credit)
    for i in range(0, len(trActivity)):
        # at i == 0, creditBalance[-1] is the (still 0.0) last element, so the balance starts at credit[0]
        creditBalance[i] = creditBalance[i-1] + credit[i]
        # if the credit balance does not cover the activity, the whole balance is used;
        # otherwise the activity itself is the usage
        creditUsage[i] = creditBalance[i] if creditBalance[i] - trActivity[i] < 0 else trActivity[i]
        creditBalance[i] += (-creditUsage[i])
    output = ["{0};{1:02}".format(t_, r_) for t_, r_ in zip(timestamp, creditUsage)]
    return output
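The logic can be sanity-checked locally (plain Python, no Spark needed) on the sample rows from the docstring:
sample = ['1;5;0', '2;0;3', '3;4;0', '4;0;3', '5;1;0', '6;1;0', '7;5;0', '8;0;1', '9;0;1', '10;5;0']
print(creditUsage(sample))
# ['1;0.0', '2;0.0', '3;3.0', '4;0.0', '5;1.0', '6;1.0', '7;1.0', '8;0.0', '9;0.0', '10;2.0']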
from pyspark.sql.functions import udf, col, concat_ws, collect_list, explode, split
from pyspark.sql.types import ArrayType, StringType

realBonusUDF = udf(creditUsage, ArrayType(StringType()))

a = df.withColumn('data', concat_ws(';', col('period'), col('activity'), col('credit'))) \
    .groupBy('userID').agg(collect_list('data').alias('data')) \
    .withColumn('data', realBonusUDF('data')) \
    .withColumn("data", explode("data")) \
    .withColumn("data", split("data", ";")) \
    .withColumn("timestamp", col('data')[0].cast("int")) \
    .withColumn("creditUsage", col('data')[1].cast("float")) \
    .drop('data')
Output:
+------+---------+-----------+
|userID|timestamp|creditUsage|
+------+---------+-----------+
| 123| 1| 0.0|
| 123| 2| 0.0|
| 123| 3| 3.0|
| 123| 4| 0.0|
| 123| 5| 1.0|
| 123| 6| 1.0|
| 123| 7| 1.0|
| 123| 8| 0.0|
| 123| 9| 0.0|
| 123| 10| 2.0|
+------+---------+-----------+
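As an aside, on Spark 3.x the same per-user logic can also be written as a grouped-map pandas UDF with applyInPandas, which avoids packing the rows into ";"-delimited strings. This is only a sketch under assumptions: the input df has userID, period, activity and credit columns, and the schema types below are guesses you would adjust to your data.
import pandas as pd
from pyspark.sql.types import StructType, StructField, LongType, DoubleType

# assumed output schema; adjust the types to match your actual columns
out_schema = StructType([
    StructField("userID", LongType()),
    StructField("period", LongType()),
    StructField("creditUsage", DoubleType()),
])

def credit_usage_pdf(pdf: pd.DataFrame) -> pd.DataFrame:
    pdf = pdf.sort_values("period").reset_index(drop=True)
    balance, usage = 0.0, []
    for act, cred in zip(pdf["activity"], pdf["credit"]):
        balance += cred
        used = balance if balance - act < 0 else act   # same rule as creditUsage() above
        balance -= used
        usage.append(float(used))
    return pd.DataFrame({"userID": pdf["userID"], "period": pdf["period"], "creditUsage": usage})

result = df.groupBy("userID").applyInPandas(credit_usage_pdf, schema=out_schema)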

Related

Window function based on a condition

I have the following DF:
|-----------------------|
|Date | Val | Cond|
|-----------------------|
|2022-01-08 | 2 | 0 |
|2022-01-09 | 4 | 1 |
|2022-01-10 | 6 | 1 |
|2022-01-11 | 8 | 0 |
|2022-01-12 | 2 | 1 |
|2022-01-13 | 5 | 1 |
|2022-01-14 | 7 | 0 |
|2022-01-15 | 9 | 0 |
|-----------------------|
For every date, I need to sum the Val of the two most recent earlier dates where Cond = 1; my expected output is:
|-----------------|
|Date | Sum |
|-----------------|
|2022-01-08 | 0 | No sum because there aren't two earlier dates with Cond = 1
|2022-01-09 | 0 | No sum because there aren't two earlier dates with Cond = 1
|2022-01-10 | 0 | No sum because there aren't two earlier dates with Cond = 1
|2022-01-11 | 10 | (4+6)
|2022-01-12 | 10 | (4+6)
|2022-01-13 | 8 | (2+6)
|2022-01-14 | 7 | (5+2)
|2022-01-15 | 7 | (5+2)
|-----------------|
I've tried to get the output DF using this code:
df = df.where("Cond= 1").withColumn(
"ListView",
f.collect_list("Val").over(windowSpec.rowsBetween(-2, -1))
)
But when I use .where("Cond = 1") I exclude the dates where Cond equals zero.
I found the following answer, but it didn't help me:
Window.rowsBetween - only consider rows fulfilling a specific condition (e.g. not being null)
How can I achieve my expected output using window functions?
The MVCE:
from pyspark.sql.types import StructType, StructField, DateType, IntegerType

data_1 = [
    ("2022-01-08", 2, 0),
    ("2022-01-09", 4, 1),
    ("2022-01-10", 6, 1),
    ("2022-01-11", 8, 0),
    ("2022-01-12", 2, 1),
    ("2022-01-13", 5, 1),
    ("2022-01-14", 7, 0),
    ("2022-01-15", 9, 0)
]
schema_1 = StructType([
    # note: DateType will not accept the string literals above directly;
    # the answer below switches to StringType plus to_date
    StructField("Date", DateType(), True),
    StructField("Val", IntegerType(), True),
    StructField("Cond", IntegerType(), True)
])
df_1 = spark.createDataFrame(data=data_1, schema=schema_1)
The following should do the trick (but I'm sure it can be further optimized).
Setup:
from pyspark.sql import Window
from pyspark.sql.functions import to_date, col, when, sum, max  # note: shadows Python's built-in sum/max
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

data_1 = [
    ("2022-01-08", 2, 0),
    ("2022-01-09", 4, 1),
    ("2022-01-10", 6, 1),
    ("2022-01-11", 8, 0),
    ("2022-01-12", 2, 1),
    ("2022-01-13", 5, 1),
    ("2022-01-14", 7, 0),
    ("2022-01-15", 9, 0),
    ("2022-01-16", 9, 0),
    ("2022-01-17", 9, 0)
]
schema_1 = StructType([
    StructField("Date", StringType(), True),
    StructField("Val", IntegerType(), True),
    StructField("Cond", IntegerType(), True)
])
df_1 = spark.createDataFrame(data=data_1, schema=schema_1)
df_1 = df_1.withColumn('Date', to_date("Date", "yyyy-MM-dd"))
+----------+---+----+
| Date|Val|Cond|
+----------+---+----+
|2022-01-08| 2| 0|
|2022-01-09| 4| 1|
|2022-01-10| 6| 1|
|2022-01-11| 8| 0|
|2022-01-12| 2| 1|
|2022-01-13| 5| 1|
|2022-01-14| 7| 0|
|2022-01-15| 9| 0|
|2022-01-16| 9| 0|
|2022-01-17| 9| 0|
+----------+---+----+
Create a new DF only with Cond==1 rows to obtain the sum of two consecutive rows with that condition:
windowSpec = Window.partitionBy("Cond").orderBy("Date")
df_2 = df_1.where(df_1.Cond == 1).withColumn(
    "Sum",
    sum("Val").over(windowSpec.rowsBetween(-1, 0))
).withColumn('date_1', col('date')).drop('date')
+---+----+---+----------+
|Val|Cond|Sum| date_1|
+---+----+---+----------+
| 4| 1| 4|2022-01-09|
| 6| 1| 10|2022-01-10|
| 2| 1| 8|2022-01-12|
| 5| 1| 7|2022-01-13|
+---+----+---+----------+
Do a left join to get the sum into the original data frame, and set the sum to zero for the rows with Cond==0:
df_3 = df_1.join(df_2.select('sum', col('date_1')), df_1.Date == df_2.date_1, "left").drop('date_1').fillna(0)
+----------+---+----+---+
| Date|Val|Cond|sum|
+----------+---+----+---+
|2022-01-08| 2| 0| 0|
|2022-01-09| 4| 1| 4|
|2022-01-10| 6| 1| 10|
|2022-01-11| 8| 0| 0|
|2022-01-12| 2| 1| 8|
|2022-01-13| 5| 1| 7|
|2022-01-14| 7| 0| 0|
|2022-01-15| 9| 0| 0|
|2022-01-16| 9| 0| 0|
|2022-01-17| 9| 0| 0|
+----------+---+----+---+
Do a cumulative sum on the condition column:
df_3=df_3.withColumn('cond_sum', sum('cond').over(Window.orderBy('Date')))
+----------+---+----+---+--------+
| Date|Val|Cond|sum|cond_sum|
+----------+---+----+---+--------+
|2022-01-08| 2| 0| 0| 0|
|2022-01-09| 4| 1| 4| 1|
|2022-01-10| 6| 1| 10| 2|
|2022-01-11| 8| 0| 0| 2|
|2022-01-12| 2| 1| 8| 3|
|2022-01-13| 5| 1| 7| 4|
|2022-01-14| 7| 0| 0| 4|
|2022-01-15| 9| 0| 0| 4|
|2022-01-16| 9| 0| 0| 4|
|2022-01-17| 9| 0| 0| 4|
+----------+---+----+---+--------+
Finally, for each partition where the cond_sum is greater than 1, use the max sum for that partition:
df_3.withColumn('sum', when(df_3.cond_sum > 1, max('sum').over(Window.partitionBy('cond_sum'))).otherwise(0)).show()
+----------+---+----+---+--------+
| Date|Val|Cond|sum|cond_sum|
+----------+---+----+---+--------+
|2022-01-08| 2| 0| 0| 0|
|2022-01-09| 4| 1| 0| 1|
|2022-01-10| 6| 1| 10| 2|
|2022-01-11| 8| 0| 10| 2|
|2022-01-12| 2| 1| 8| 3|
|2022-01-13| 5| 1| 7| 4|
|2022-01-14| 7| 0| 7| 4|
|2022-01-15| 9| 0| 7| 4|
|2022-01-16| 9| 0| 7| 4|
|2022-01-17| 9| 0| 7| 4|
+----------+---+----+---+--------+
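If the sum should only look at rows strictly before the current date (the question's expected output keeps 2022-01-10 at 0), a single window plus array functions can also do it in one pass. A sketch, assuming Spark 2.4+ for element_at with negative indices, and the df_1 from the setup above:
from pyspark.sql import Window
import pyspark.sql.functions as F

# all strictly earlier rows; note a global orderBy window pulls everything into one partition
w_prev = Window.orderBy("Date").rowsBetween(Window.unboundedPreceding, -1)

df_alt = (
    df_1
    # Vals of earlier rows with Cond == 1 (when() without otherwise() yields nulls, which collect_list drops)
    .withColumn("prev_ones", F.collect_list(F.when(F.col("Cond") == 1, F.col("Val"))).over(w_prev))
    # sum of the last two of them, or 0 if fewer than two exist
    .withColumn("Sum", F.when(
        F.size("prev_ones") >= 2,
        F.expr("element_at(prev_ones, -1) + element_at(prev_ones, -2)")
    ).otherwise(F.lit(0)))
    .drop("prev_ones")
)
df_alt.select("Date", "Sum").show()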

How to replace for loop in python with map transformation in pyspark where we want to compare previous row and current row with multiple conditions

I've run into a roadblock while applying a map function on a PySpark dataframe and need your help getting past it.
The actual problem is more complicated, but let me simplify it with the example below using a dictionary and a for loop; I need the solution in PySpark.
Here is an example of Python code on dummy data. I want the same in a PySpark map transformation with a when clause, using a window or any other way.
Problem: I have a PySpark dataframe whose column names are the keys in the dictionary below, and I want to add/modify a section column with logic similar to the for loop in this example.
record = [
    {'id': xyz, 'SN': xyz, 'miles': xyz, 'feet': xyz, 'MP': xyz, 'section': xyz},
    {'id': xyz, 'SN': xyz, 'miles': xyz, 'feet': xyz, 'MP': xyz, 'section': xyz},
    {'id': xyz, 'SN': xyz, 'miles': xyz, 'feet': xyz, 'MP': xyz, 'section': xyz}
]
last_rec = None
section = 0
for cur_rec in record:
    if last_rec is not None:
        if (last_rec['id'] != cur_rec['id']) or (last_rec['SN'] != cur_rec['SN']):
            section += 1
        elif (last_rec['miles'] == cur_rec['miles']) and (abs(last_rec['feet'] - cur_rec['feet']) > 1):
            section += 1
        elif (last_rec['MP'] == 555) and (cur_rec['MP'] != 555):
            section += 1
        elif abs(last_rec['miles'] - cur_rec['miles']) > 1:
            section += 1
    cur_rec['section'] = section
    last_rec = cur_rec
Your window function is a cumulative sum of a boolean variable.
Let's start with a sample dataframe:
import numpy as np
record_df = spark.createDataFrame(
    [list(x) for x in zip(*[np.random.randint(0, 10, 100).tolist() for _ in range(5)])],
    ['id', 'SN', 'miles', 'feet', 'MP'])
record_df.show()
+---+---+-----+----+---+
| id| SN|miles|feet| MP|
+---+---+-----+----+---+
| 9| 5| 7| 5| 1|
| 0| 6| 3| 7| 5|
| 8| 2| 7| 3| 5|
| 0| 2| 6| 5| 8|
| 0| 8| 9| 1| 5|
| 8| 5| 1| 6| 0|
| 0| 3| 9| 0| 3|
| 6| 4| 9| 0| 8|
| 5| 8| 8| 1| 0|
| 3| 0| 9| 9| 9|
| 1| 1| 2| 7| 0|
| 1| 3| 7| 7| 6|
| 4| 9| 5| 5| 5|
| 3| 6| 0| 0| 0|
| 5| 5| 5| 9| 3|
| 8| 3| 7| 8| 1|
| 7| 1| 3| 1| 8|
| 3| 1| 5| 2| 5|
| 6| 2| 3| 5| 6|
| 9| 4| 5| 9| 1|
+---+---+-----+----+---+
A cumulative sum is an ordered window function, therefore we'll need to use monotonically_increasing_id to give an order to our rows:
import pyspark.sql.functions as psf
record_df = record_df.withColumn(
    'rn',
    psf.monotonically_increasing_id())
For the boolean variable we'll need to use lag:
from pyspark.sql import Window
w = Window.orderBy('rn')
record_df = record_df.select(
    record_df.columns
    + [psf.lag(c).over(w).alias('prev_' + c) for c in ['id', 'SN', 'miles', 'feet', 'MP']])
Since all the conditions yield the same result on section, it is an or clause:
clause = (psf.col("prev_id") != psf.col("id")) | (psf.col("prev_SN") != psf.col("SN")) \
| ((psf.col("prev_miles") == psf.col("miles")) & (psf.abs(psf.col("prev_feet") - psf.col("feet")) > 1)) \
| ((psf.col("prev_MP") == 555) & (psf.col("MP") != 555)) \
| (psf.abs(psf.col("prev_miles") - psf.col("miles")) > 1)
record_df = record_df.withColumn("tmp", (clause).cast('int'))
And finally, the cumulative sum:
record_df = record_df.withColumn("section", psf.sum("tmp").over(w))

How to apply the formula to each Row of Spark DataFrame in Scala?

I have the following DataFrame df:
+------+----+---------+--------+--------+
|nodeId|ni |type |avg_ni |std_ni |
+------+----+---------+--------+--------+
| 1| 1| 0| 0.5| 0.7071|
| 0| 0| 0| 0.5| 0.7071|
| 2| 0| 2| 0.0| 0.0|
| 3| 0| 4| 0.6667| 1.1547|
| 4| 2| 4| 0.6667| 1.1547|
| 5| 0| 4| 0.6667| 1.1547|
+------+----+---------+--------+--------+
I want to apply the formula (ni - avg_ni) / std_ni to each Row.
I tried it this way, but it does not work:
df.map(x => (x("ni")-x("avg_ni")/x("std_ni"))).show()
Just use withColumn or select:
df.select(($"ni" - $ "avg_ni") / $"std_ni")
optionally with conversion
df.select(($"ni" - $ "avg_ni") / $"std_ni").as[Double]

Pyspark - Ranking columns keeping ties

I'm looking for a way to rank columns of a dataframe preserving ties. Specifically for this example, I have a pyspark dataframe as follows where I want to generate ranks for colA & colB (though I want to support being able to rank N number of columns)
+--------+----------+-----+----+
| Entity| id| colA|colB|
+--------+----------+-----+----+
| a|8589934652| 21| 50|
| b| 112| 9| 23|
| c|8589934629| 9| 23|
| d|8589934702| 8| 21|
| e| 20| 2| 21|
| f|8589934657| 2| 5|
| g|8589934601| 1| 5|
| h|8589934653| 1| 4|
| i|8589934620| 0| 4|
| j|8589934643| 0| 3|
| k|8589934618| 0| 3|
| l|8589934602| 0| 2|
| m|8589934664| 0| 2|
| n| 25| 0| 1|
| o| 67| 0| 1|
| p|8589934642| 0| 1|
| q|8589934709| 0| 1|
| r|8589934660| 0| 1|
| s| 30| 0| 1|
| t| 55| 0| 1|
+--------+----------+-----+----+
What I'd like is a way to rank this dataframe where tied values receive the same rank such as:
+--------+----------+-----+----+---------+---------+
| Entity| id| colA|colB|colA_rank|colB_rank|
+--------+----------+-----+----+---------+---------+
| a|8589934652| 21| 50| 1| 1|
| b| 112| 9| 23| 2| 2|
| c|8589934629| 9| 21| 2| 3|
| d|8589934702| 8| 21| 3| 3|
| e| 20| 2| 21| 4| 3|
| f|8589934657| 2| 5| 4| 4|
| g|8589934601| 1| 5| 5| 4|
| h|8589934653| 1| 4| 5| 5|
| i|8589934620| 0| 4| 6| 5|
| j|8589934643| 0| 3| 6| 6|
| k|8589934618| 0| 3| 6| 6|
| l|8589934602| 0| 2| 6| 7|
| m|8589934664| 0| 2| 6| 7|
| n| 25| 0| 1| 6| 8|
| o| 67| 0| 1| 6| 8|
| p|8589934642| 0| 1| 6| 8|
| q|8589934709| 0| 1| 6| 8|
| r|8589934660| 0| 1| 6| 8|
| s| 30| 0| 1| 6| 8|
| t| 55| 0| 1| 6| 8|
+--------+----------+-----+----+---------+---------+
My current implementation with the first dataframe looks like:
def getRanks(mydf, cols=None, ascending=False):
    from pyspark.sql import Row
    # This takes a dataframe and a list of columns to rank
    # If no list is provided, it ranks *all* columns
    # returns a new dataframe
    def addRank(ranked_rdd, col, ascending):
        # This assumes an RDD of the form (Row(...), list[...])
        # it orders the rdd by col, finds the order, then adds that to the list
        myrdd = ranked_rdd.sortBy(lambda pair: pair[0][col],
                                  ascending=ascending).zipWithIndex()
        return myrdd.map(lambda pair_idx: (pair_idx[0][0],
                                           pair_idx[0][1] + [pair_idx[1] + 1]))
    myrdd = mydf.rdd
    fields = myrdd.first().__fields__
    ranked_rdd = myrdd.map(lambda x: (x, []))
    if cols is None:
        cols = fields
    for col in cols:
        ranked_rdd = addRank(ranked_rdd, col, ascending)
    rank_names = [x + "_rank" for x in cols]
    # Hack to make sure columns come back in the right order
    ranked_rdd = ranked_rdd.map(lambda pair: Row(*(pair[0].__fields__ + rank_names))(*(tuple(pair[0]) + tuple(pair[1]))))
    return ranked_rdd.toDF()
which produces:
+--------+----------+-----+----+---------+---------+
| Entity| id| colA|colB|colA_rank|colB_rank|
+--------+----------+-----+----+---------+---------+
| a|8589934652| 21| 50| 1| 1|
| b| 112| 9| 23| 2| 2|
| c|8589934629| 9| 23| 3| 3|
| d|8589934702| 8| 21| 4| 4|
| e| 20| 2| 21| 5| 5|
| f|8589934657| 2| 5| 6| 6|
| g|8589934601| 1| 5| 7| 7|
| h|8589934653| 1| 4| 8| 8|
| i|8589934620| 0| 4| 9| 9|
| j|8589934643| 0| 3| 10| 10|
| k|8589934618| 0| 3| 11| 11|
| l|8589934602| 0| 2| 12| 12|
| m|8589934664| 0| 2| 13| 13|
| n| 25| 0| 1| 14| 14|
| o| 67| 0| 1| 15| 15|
| p|8589934642| 0| 1| 16| 16|
| q|8589934709| 0| 1| 17| 17|
| r|8589934660| 0| 1| 18| 18|
| s| 30| 0| 1| 19| 19|
| t| 55| 0| 1| 20| 20|
+--------+----------+-----+----+---------+---------+
As you can see, the function getRanks() takes a dataframe, specifies the columns to be ranked, sorts them, and uses zipWithIndex() to generate an ordering or rank. However, I can't figure out a way to preserve ties.
This Stack Overflow post is the closest solution I've found:
rank-users-by-column, but it appears to only handle one column (I think).
Thanks so much for the help in advance!
EDIT: column 'id' is generated from calling monotonically_increasing_id() and in my implementation is cast to a string.
You're looking for dense_rank
First let's create our dataframe:
df = spark.createDataFrame(sc.parallelize(
    [["a",8589934652,21,50],["b",112,9,23],["c",8589934629,9,23],
     ["d",8589934702,8,21],["e",20,2,21],["f",8589934657,2,5],
     ["g",8589934601,1,5],["h",8589934653,1,4],["i",8589934620,0,4],
     ["j",8589934643,0,3],["k",8589934618,0,3],["l",8589934602,0,2],
     ["m",8589934664,0,2],["n",25,0,1],["o",67,0,1],["p",8589934642,0,1],
     ["q",8589934709,0,1],["r",8589934660,0,1],["s",30,0,1],["t",55,0,1]]
), ["Entity","id","colA","colB"])
We'll define two window specs:
from pyspark.sql import Window
import pyspark.sql.functions as psf
wA = Window.orderBy(psf.desc("colA"))
wB = Window.orderBy(psf.desc("colB"))
df = df.withColumn(
    "colA_rank",
    psf.dense_rank().over(wA)
).withColumn(
    "colB_rank",
    psf.dense_rank().over(wB)
)
+------+----------+----+----+---------+---------+
|Entity| id|colA|colB|colA_rank|colB_rank|
+------+----------+----+----+---------+---------+
| a|8589934652| 21| 50| 1| 1|
| b| 112| 9| 23| 2| 2|
| c|8589934629| 9| 23| 2| 2|
| d|8589934702| 8| 21| 3| 3|
| e| 20| 2| 21| 4| 3|
| f|8589934657| 2| 5| 4| 4|
| g|8589934601| 1| 5| 5| 4|
| h|8589934653| 1| 4| 5| 5|
| i|8589934620| 0| 4| 6| 5|
| j|8589934643| 0| 3| 6| 6|
| k|8589934618| 0| 3| 6| 6|
| l|8589934602| 0| 2| 6| 7|
| m|8589934664| 0| 2| 6| 7|
| n| 25| 0| 1| 6| 8|
| o| 67| 0| 1| 6| 8|
| p|8589934642| 0| 1| 6| 8|
| q|8589934709| 0| 1| 6| 8|
| r|8589934660| 0| 1| 6| 8|
| s| 30| 0| 1| 6| 8|
| t| 55| 0| 1| 6| 8|
+------+----------+----+----+---------+---------+
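Since the question asks for ranking an arbitrary number of columns, the same pattern extends with a loop (a sketch; the list of columns to rank is an assumption):
cols_to_rank = ["colA", "colB"]          # any number of columns
for c in cols_to_rank:
    # note: a global orderBy window moves all rows to a single partition
    w = Window.orderBy(psf.desc(c))
    df = df.withColumn(c + "_rank", psf.dense_rank().over(w))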
I'll also pose an alternative:
from pyspark.sql.functions import col

# 'data' is assumed to be the input dataframe (e.g. the df created above)
for cols in data.columns[2:]:
    lookup = (data.select(cols)
              .distinct()
              .orderBy(cols, ascending=False)
              .rdd
              .zipWithIndex()
              .map(lambda x: x[0] + (x[1], ))
              .toDF([cols, cols + "_rank_lookup"]))
    name = cols + "_ranks"
    data = data.join(lookup, [cols]).withColumn(name, col(cols + "_rank_lookup") + 1).drop(cols + "_rank_lookup")
Not as elegant as dense_rank() and I'm uncertain as to performance implications.

Filtering rows based on subsequent row values in spark dataframe [duplicate]

I have a dataframe (Spark):
id value
3 0
3 1
3 0
4 1
4 0
4 0
I want to create a new dataframe:
3 0
3 1
4 1
I need to remove all the rows after value = 1 for each id. I tried window functions on a Spark dataframe (Scala) but couldn't find a solution. It seems I am going in the wrong direction.
I am looking for a solution in Scala. Thanks.
Output using monotonically_increasing_id
scala> val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id", "value")
data: org.apache.spark.sql.DataFrame = [id: int, value: int]
scala> val dataWithIndex = data.withColumn("idx", monotonically_increasing_id())
scala> val minIdx = dataWithIndex.filter($"value" === 1).groupBy($"id").agg(min($"idx")).toDF("r_id", "min_idx")
minIdx: org.apache.spark.sql.DataFrame = [r_id: int, min_idx: bigint]
scala> dataWithIndex.join(minIdx,($"r_id" === $"id") && ($"idx" <= $"min_idx")).select($"id", $"value").show
+---+-----+
| id|value|
+---+-----+
| 3| 0|
| 3| 1|
| 4| 1|
+---+-----+
This solution won't work if we apply a sort transformation to the original dataframe, because monotonically_increasing_id() is generated from the original DF rather than the sorted DF. I had missed that requirement before.
All suggestions are welcome.
One way is to use monotonically_increasing_id() and a self-join:
val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id", "value")
data.show
+---+-----+
| id|value|
+---+-----+
| 3| 0|
| 3| 1|
| 3| 0|
| 4| 1|
| 4| 0|
| 4| 0|
+---+-----+
Now we generate a column named idx with an increasing Long:
val dataWithIndex = data.withColumn("idx", monotonically_increasing_id())
// dataWithIndex.cache()
Now we get the min(idx) for each id where value = 1:
val minIdx = dataWithIndex
  .filter($"value" === 1)
  .groupBy($"id")
  .agg(min($"idx"))
  .toDF("r_id", "min_idx")
Now we join the min(idx) back to the original DataFrame:
dataWithIndex.join(
minIdx,
($"r_id" === $"id") && ($"idx" <= $"min_idx")
).select($"id", $"value").show
+---+-----+
| id|value|
+---+-----+
| 3| 0|
| 3| 1|
| 4| 1|
+---+-----+
Note: monotonically_increasing_id() generates its value based on the partition of the row. This value may change each time dataWithIndex is re-evaluated. In my code above, because of lazy evaluation, it's only when I call the final show that monotonically_increasing_id() is evaluated.
If you want to force the value to stay the same, for example so you can use show to evaluate the above step-by-step, uncomment this line above:
// dataWithIndex.cache()
Hi, I found a solution using Window and a self-join.
val data = Seq((3,0,2),(3,1,3),(3,0,1),(4,1,6),(4,0,5),(4,0,4),(1,0,7),(1,1,8),(1,0,9),(2,1,10),(2,0,11),(2,0,12)).toDF("id", "value","sorted")
data.show
scala> data.show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
| 3| 0| 2|
| 3| 1| 3|
| 3| 0| 1|
| 4| 1| 6|
| 4| 0| 5|
| 4| 0| 4|
| 1| 0| 7|
| 1| 1| 8|
| 1| 0| 9|
| 2| 1| 10|
| 2| 0| 11|
| 2| 0| 12|
+---+-----+------+
val sort_df=data.sort($"sorted")
scala> sort_df.show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
| 3| 0| 1|
| 3| 0| 2|
| 3| 1| 3|
| 4| 0| 4|
| 4| 0| 5|
| 4| 1| 6|
| 1| 0| 7|
| 1| 1| 8|
| 1| 0| 9|
| 2| 1| 10|
| 2| 0| 11|
| 2| 0| 12|
+---+-----+------+
var window = Window.partitionBy("id").orderBy("sorted")
val sort_idx = sort_df.select($"*", row_number().over(window).as("count_index"))
val minIdx = sort_idx.filter($"value" === 1).groupBy("id").agg(min("count_index")).toDF("idx", "min_idx")
val result_id = sort_idx.join(minIdx, ($"id" === $"idx") && ($"count_index" <= $"min_idx"))
result_id.show
+---+-----+------+-----------+---+-------+
| id|value|sorted|count_index|idx|min_idx|
+---+-----+------+-----------+---+-------+
| 1| 0| 7| 1| 1| 2|
| 1| 1| 8| 2| 1| 2|
| 2| 1| 10| 1| 2| 1|
| 3| 0| 1| 1| 3| 3|
| 3| 0| 2| 2| 3| 3|
| 3| 1| 3| 3| 3| 3|
| 4| 0| 4| 1| 4| 3|
| 4| 0| 5| 2| 4| 3|
| 4| 1| 6| 3| 4| 3|
+---+-----+------+-----------+---+-------+
Still looking for more optimized solutions. Thanks.
You can simply use groupBy like this
val df2 = df1.groupBy("id","value").count().select("id","value")
Here your df1 is
id value
3 0
3 1
3 0
4 1
4 0
4 0
And the resulting dataframe is df2, which is your expected output, like this:
id value
3 0
3 1
4 1
4 0
Use the isin method and filter as below:
val data = Seq((3,0,2),(3,1,3),(3,0,1),(4,1,6),(4,0,5),(4,0,4),(1,0,7),(1,1,8),(1,0,9),(2,1,10),(2,0,11),(2,0,12)).toDF("id", "value","sorted")
val idFilter = List(1, 2)
data.filter($"id".isin(idFilter:_*)).show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
| 1| 0| 7|
| 1| 1| 8|
| 1| 0| 9|
| 2| 1| 10|
| 2| 0| 11|
| 2| 0| 12|
+---+-----+------+
Example: filter based on value:
val valFilter = List(0)
data.filter($"value".isin(valFilter:_*)).show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
| 3| 0| 2|
| 3| 0| 1|
| 4| 0| 5|
| 4| 0| 4|
| 1| 0| 7|
| 1| 0| 9|
| 2| 0| 11|
| 2| 0| 12|
+---+-----+------+