Window function does not act as expected when I use Order By (PySpark)

I have read this comprehensive material, yet I don't understand why the Window function acts this way.
Here's a little example:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.window import Window
spark = SparkSession.builder.getOrCreate()
columns = ["CATEGORY", "REVENUE"]
data = [("Cell Phone", "6000"),
("Tablet", "1500"),
("Tablet", "5500"),
("Cell Phone", "5000"),
("Cell Phone", "6000"),
("Tablet", "2500"),
("Cell Phone", "3000"),
("Cell Phone", "3000"),
("Tablet", "3000"),
("Tablet", "4500"),
("Tablet", "6500")]
df = spark.createDataFrame(data=data, schema=columns)
window_spec = Window.partitionBy(df['CATEGORY']).orderBy(df['REVENUE'])
revenue_difference = F.max(df['REVENUE']).over(window_spec)
df.select(
df['CATEGORY'],
df['REVENUE'],
revenue_difference.alias("revenue_difference")).show()
So when I write orderBy(df['REVENUE']), I get this:
+----------+-------+------------------+
| CATEGORY|REVENUE|revenue_difference|
+----------+-------+------------------+
|Cell Phone| 3000| 3000|
|Cell Phone| 3000| 3000|
|Cell Phone| 5000| 5000|
|Cell Phone| 6000| 6000|
|Cell Phone| 6000| 6000|
| Tablet| 1500| 1500|
| Tablet| 2500| 2500|
| Tablet| 3000| 3000|
| Tablet| 4500| 4500|
| Tablet| 5500| 5500|
| Tablet| 6500| 6500|
+----------+-------+------------------+
But when I write orderBy(df['REVENUE'].desc()), I get this:
+----------+-------+------------------+
| CATEGORY|REVENUE|revenue_difference|
+----------+-------+------------------+
|Cell Phone| 6000| 6000|
|Cell Phone| 6000| 6000|
|Cell Phone| 5000| 6000|
|Cell Phone| 3000| 6000|
|Cell Phone| 3000| 6000|
| Tablet| 6500| 6500|
| Tablet| 5500| 6500|
| Tablet| 4500| 6500|
| Tablet| 3000| 6500|
| Tablet| 2500| 6500|
| Tablet| 1500| 6500|
+----------+-------+------------------+
I don't understand this because, the way I see it, the MAX value in each window stays the same no matter what the order is. Can someone please explain what I am not getting here?
Thank you!

The simple reason is that, once the window has an orderBy, the default frame is a range frame from Window.unboundedPreceding to Window.currentRow. This means that the max is taken from the first row in that partition up to the current row (including any rows that tie with it on the ordering column), NOT up to the last row of the partition.
This is a common gotcha. You can replace F.max() with F.sum() and see what output you get; it also changes depending on how you order the partition.
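To make the frame behaviour concrete, here is a minimal plain-Python sketch (not Spark code) of what a frame running from the start of the partition to the current row does to an aggregate. The helper name running_aggregate is hypothetical, and the revenues are treated as integers for simplicity, whereas the question stores them as strings.

```python
def running_aggregate(values, aggregate, descending=False):
    """Apply `aggregate` over a frame that ends at the current row's peers."""
    ordered = sorted(values, reverse=descending)
    results = []
    for current in ordered:
        # Frame: every row ordered at or before the current value, ties included.
        if descending:
            frame = [v for v in ordered if v >= current]
        else:
            frame = [v for v in ordered if v <= current]
        results.append((current, aggregate(frame)))
    return results


# The "Cell Phone" partition from the question, as integers for simplicity.
cell_phone = [6000, 5000, 6000, 3000, 3000]

max_ascending = running_aggregate(cell_phone, max)
max_descending = running_aggregate(cell_phone, max, descending=True)
sum_ascending = running_aggregate(cell_phone, sum)

print(max_ascending)   # max grows as the frame grows
print(max_descending)  # the largest value is already in the first frame
print(sum_ascending)   # a running total, which also depends on the order
```

With ascending order the largest value enters the frame last, so the max grows along with the frame. With descending order it is already in the first frame, which is why every row shows the partition maximum.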
To solve this, you can specify that you want the max of each partition to always be calculated using the full window partition, like so:
window_spec = Window.partitionBy(df['CATEGORY']).orderBy(df['REVENUE']).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
revenue_difference = F.max(df['REVENUE']).over(window_spec)
df.select(
df['CATEGORY'],
df['REVENUE'],
revenue_difference.alias("revenue_difference")).show()
+----------+-------+------------------+
| CATEGORY|REVENUE|revenue_difference|
+----------+-------+------------------+
| Tablet| 6500| 6500|
| Tablet| 5500| 6500|
| Tablet| 4500| 6500|
| Tablet| 3000| 6500|
| Tablet| 2500| 6500|
| Tablet| 1500| 6500|
|Cell Phone| 6000| 6000|
|Cell Phone| 6000| 6000|
|Cell Phone| 5000| 6000|
|Cell Phone| 3000| 6000|
|Cell Phone| 3000| 6000|
+----------+-------+------------------+

Related

Spark collect_set from a column using window function approach

I have a sample dataset with salaries. I want to distribute the salaries into 3 buckets, find the lowest salary in each bucket, collect those lower bounds into an array, and attach that array to the original dataset. I am trying to use window functions to do that, but the array seems to be built up progressively, row by row.
Here is the code that I have written:
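import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{collect_set, min, ntile}
val spark = sparkSession
import spark.implicits._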
val simpleData = Seq(("James", "Sales", 3000),
("Michael", "Sales", 3100),
("Robert", "Sales", 3200),
("Maria", "Finance", 3300),
("James", "Sales", 3400),
("Scott", "Finance", 3500),
("Jen", "Finance", 3600),
("Jeff", "Marketing", 3700),
("Kumar", "Marketing", 3800),
("Saif", "Sales", 3900)
)
val df = simpleData.toDF("employee_name", "department", "salary")
val windowSpec = Window.orderBy("salary")
val ntileFrame = df.withColumn("ntile", ntile(3).over(windowSpec))
val lowWindowSpec = Window.partitionBy("ntile")
val ntileMinDf = ntileFrame.withColumn("lower_bound", min("salary").over(lowWindowSpec))
var rangeDf = ntileMinDf.withColumn("range", collect_set("lower_bound").over(windowSpec))
rangeDf.show()
I am getting the dataset like this
+-------------+----------+------+-----+-----------+------------------+
|employee_name|department|salary|ntile|lower_bound| range|
+-------------+----------+------+-----+-----------+------------------+
| James| Sales| 3000| 1| 3000| [3000]|
| Michael| Sales| 3100| 1| 3000| [3000]|
| Robert| Sales| 3200| 1| 3000| [3000]|
| Maria| Finance| 3300| 1| 3000| [3000]|
| James| Sales| 3400| 2| 3400| [3000, 3400]|
| Scott| Finance| 3500| 2| 3400| [3000, 3400]|
| Jen| Finance| 3600| 2| 3400| [3000, 3400]|
| Jeff| Marketing| 3700| 3| 3700|[3000, 3700, 3400]|
| Kumar| Marketing| 3800| 3| 3700|[3000, 3700, 3400]|
| Saif| Sales| 3900| 3| 3700|[3000, 3700, 3400]|
+-------------+----------+------+-----+-----------+------------------+
I am expecting the dataset to look like this
+-------------+----------+------+-----+-----------+------------------+
|employee_name|department|salary|ntile|lower_bound| range|
+-------------+----------+------+-----+-----------+------------------+
| James| Sales| 3000| 1| 3000|[3000, 3700, 3400]|
| Michael| Sales| 3100| 1| 3000|[3000, 3700, 3400]|
| Robert| Sales| 3200| 1| 3000|[3000, 3700, 3400]|
| Maria| Finance| 3300| 1| 3000|[3000, 3700, 3400]|
| James| Sales| 3400| 2| 3400|[3000, 3700, 3400]|
| Scott| Finance| 3500| 2| 3400|[3000, 3700, 3400]|
| Jen| Finance| 3600| 2| 3400|[3000, 3700, 3400]|
| Jeff| Marketing| 3700| 3| 3700|[3000, 3700, 3400]|
| Kumar| Marketing| 3800| 3| 3700|[3000, 3700, 3400]|
| Saif| Sales| 3900| 3| 3700|[3000, 3700, 3400]|
+-------------+----------+------+-----+-----------+------------------+
Because windowSpec has an orderBy, its default frame ends at the current row. To make your window take all rows into account, and not only the rows up to the current row, you can use the rowsBetween method with Window.unboundedPreceding and Window.unboundedFollowing as arguments. Your last line thus becomes:
var rangeDf = ntileMinDf.withColumn(
"range",
collect_set("lower_bound")
.over(Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
)
and you get the following rangeDf dataframe:
+-------------+----------+------+-----+-----------+------------------+
|employee_name|department|salary|ntile|lower_bound| range|
+-------------+----------+------+-----+-----------+------------------+
| James| Sales| 3000| 1| 3000|[3000, 3700, 3400]|
| Michael| Sales| 3100| 1| 3000|[3000, 3700, 3400]|
| Robert| Sales| 3200| 1| 3000|[3000, 3700, 3400]|
| Maria| Finance| 3300| 1| 3000|[3000, 3700, 3400]|
| James| Sales| 3400| 2| 3400|[3000, 3700, 3400]|
| Scott| Finance| 3500| 2| 3400|[3000, 3700, 3400]|
| Jen| Finance| 3600| 2| 3400|[3000, 3700, 3400]|
| Jeff| Marketing| 3700| 3| 3700|[3000, 3700, 3400]|
| Kumar| Marketing| 3800| 3| 3700|[3000, 3700, 3400]|
| Saif| Sales| 3900| 3| 3700|[3000, 3700, 3400]|
+-------------+----------+------+-----+-----------+------------------+
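The difference between the progressive result and the expected one can be reproduced without Spark. The following is a rough plain-Python sketch under a few assumptions: the helper ntile is hypothetical and follows the usual SQL NTILE rule (the first buckets receive the extra rows), and the sets are sorted for readability, whereas collect_set does not guarantee any element order.

```python
def ntile(row_count, buckets):
    """Bucket number (1-based) for each row position, following SQL NTILE rules."""
    base, extra = divmod(row_count, buckets)
    assignment = []
    for bucket in range(1, buckets + 1):
        # The first `extra` buckets receive one additional row.
        size = base + (1 if bucket <= extra else 0)
        assignment.extend([bucket] * size)
    return assignment


# Salaries from the question, already ordered as Window.orderBy("salary") would.
salaries = [3000, 3100, 3200, 3300, 3400, 3500, 3600, 3700, 3800, 3900]
tiles = ntile(len(salaries), 3)

# lower_bound: minimum salary within each ntile bucket.
lower_by_tile = {}
for salary, tile in zip(salaries, tiles):
    lower_by_tile[tile] = min(salary, lower_by_tile.get(tile, salary))
lower_bounds = [lower_by_tile[tile] for tile in tiles]

# Running frame: each row only sees the values up to and including itself.
running = [sorted(set(lower_bounds[: i + 1])) for i in range(len(lower_bounds))]

# Full frame: every row sees the values of the whole dataset.
full = [sorted(set(lower_bounds)) for _ in lower_bounds]

for salary, tile, run, whole in zip(salaries, tiles, running, full):
    print(salary, tile, run, whole)
```

The running column matches the progressive output in the question, and the full column matches the expected one.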

PySpark: remove duplicates based on 2 columns

I have the following df in PySpark:
+---------+----------+--------+-----+----------+------+
|firstname|middlename|lastname| ncf| date|salary|
+---------+----------+--------+-----+----------+------+
| James| | V|36636|2021-09-03| 3000| remove
| Michael| Rose| |40288|2021-09-10| 4000|
| Robert| |Williams|42114|2021-08-03| 4000|
| Maria| Anne| Jones|39192|2021-05-13| 4000|
| Jen| Mary| Brown| |2020-09-03| -1|
| James| | Smith|36636|2021-09-03| 3000| remove
| James| | Smith|36636|2021-09-04| 3000|
+---------+----------+--------+-----+----------+------+
I need to remove the rows where both ncf and date are equal to those of another row. The resulting df will be:
+---------+----------+--------+-----+----------+------+
|firstname|middlename|lastname| ncf| date|salary|
+---------+----------+--------+-----+----------+------+
| Michael| Rose| |40288|2021-09-10| 4000|
| Robert| |Williams|42114|2021-08-03| 4000|
| Maria| Anne| Jones|39192|2021-05-13| 4000|
| Jen| Mary| Brown| |2020-09-03| -1|
| James| | Smith|36636|2021-09-04| 3000|
+---------+----------+--------+-----+----------+------+
The dropDuplicates method removes duplicates within a subset of columns. Note that it keeps one row from each group of duplicates rather than dropping the whole group:
df.dropDuplicates(['ncf', 'date'])
Alternatively, you can use a window function to check whether two or more rows share the same ncf and date:
from pyspark.sql import functions as F
from pyspark.sql import Window as W
df.withColumn('duplicated', F.count('*').over(W.partitionBy('ncf', 'date').orderBy(F.lit(1))) > 1)
# +---------+----------+--------+-----+----------+------+----------+
# |firstname|middlename|lastname| ncf| date|salary|duplicated|
# +---------+----------+--------+-----+----------+------+----------+
# | Jen| Mary| Brown| |2020-09-03| -1| false|
# | James| | V|36636|2021-09-03| 3000| true|
# | James| | Smith|36636|2021-09-03| 3000| true|
# | Michael| Rose| |40288|2021-09-10| 4000| false|
# | Robert| |Williams|42114|2021-08-03| 4000| false|
# | James| | Smith|36636|2021-09-04| 3000| false|
# | Maria| Anne| Jones|39192|2021-05-13| 4000| false|
# +---------+----------+--------+-----+----------+------+----------+
You can now use the duplicated column to filter the rows as desired, for example by keeping only the rows where it is false.
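The two approaches do not give the same result, which a small plain-Python sketch (not Spark code) can show. The rows are copied from the question. The key helper and the variable names are hypothetical, and keeping the first row per key is an assumption of this sketch, since Spark does not promise which row of a duplicate group dropDuplicates keeps.

```python
from collections import Counter

# (firstname, middlename, lastname, ncf, date, salary)
rows = [
    ("James", "", "V", "36636", "2021-09-03", 3000),
    ("Michael", "Rose", "", "40288", "2021-09-10", 4000),
    ("Robert", "", "Williams", "42114", "2021-08-03", 4000),
    ("Maria", "Anne", "Jones", "39192", "2021-05-13", 4000),
    ("Jen", "Mary", "Brown", "", "2020-09-03", -1),
    ("James", "", "Smith", "36636", "2021-09-03", 3000),
    ("James", "", "Smith", "36636", "2021-09-04", 3000),
]


def key(row):
    # ncf and date sit at positions 3 and 4 of each tuple.
    return (row[3], row[4])


counts = Counter(key(row) for row in rows)

# Like filtering on duplicated == false: every row of a repeated group is dropped.
without_duplicated_groups = [row for row in rows if counts[key(row)] == 1]

# Like dropDuplicates in spirit: one row per (ncf, date) pair survives.
seen = set()
one_per_key = []
for row in rows:
    if key(row) not in seen:
        seen.add(key(row))
        one_per_key.append(row)

print(len(without_duplicated_groups), len(one_per_key))
```

Only the first variant reproduces the five-row result asked for in the question. The second keeps six rows, because one of the two rows with ncf 36636 and date 2021-09-03 survives.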

PySpark - Getting the latest date less than another given date

I need some help. I have two dataframes, one has a few dates and the other has my significant data, catalogued by date.
It goes something like this:
First df, with the relevant data
+------+----------+---------------+
| id| test_date| score|
+------+----------+---------------+
| 1|2021-03-31| 94|
| 1|2021-01-31| 93|
| 1|2020-12-31| 100|
| 1|2020-06-30| 95|
| 1|2019-10-31| 58|
| 1|2017-10-31| 78|
| 2|2020-01-31| 79|
| 2|2018-03-31| 66|
| 2|2016-05-31| 77|
| 3|2021-05-31| 97|
| 3|2020-07-31| 100|
| 3|2019-07-31| 99|
| 3|2019-06-30| 98|
| 3|2018-07-31| 91|
| 3|2018-02-28| 86|
| 3|2017-11-30| 82|
+------+----------+---------------+
Second df, with the dates
+--------------+--------------+--------------+
| eval_date_1| eval_date_2| eval_date_3|
+--------------+--------------+--------------+
| 2021-01-31| 2020-10-31| 2019-06-30|
+--------------+--------------+--------------+
Needed DF
+------+--------------+---------+--------------+---------+--------------+---------+
| id| eval_date_1| score_1 | eval_date_2| score_2 | eval_date_3| score_3 |
+------+--------------+---------+--------------+---------+--------------+---------+
| 1| 2021-01-31| 93| 2020-10-31| 95| 2019-06-30| 78|
| 2| 2021-01-31| 79| 2020-10-31| 79| 2019-06-30| 66|
| 3| 2021-01-31| 100| 2020-10-31| 100| 2019-06-30| 98|
+------+--------------+---------+--------------+---------+--------------+---------+
So, for instance, for the first id, the needed df takes the scores from the second, fourth and sixth rows of the first df. Those are the most recent test dates that are equal to or earlier than each eval_date in the second df.
Assuming df is your main dataframe and df_date is the one which contains only dates.
from functools import reduce
from pyspark.sql import functions as F, Window as W
df_final = reduce(
    lambda a, b: a.join(b, on="id"),
    (
        df.join(
            F.broadcast(df_date.select(f"eval_date_{i}")),
            on=F.col(f"eval_date_{i}") >= F.col("test_date"),
        )
        .withColumn(
            "rnk",
            F.row_number().over(W.partitionBy("id").orderBy(F.col("test_date").desc())),
        )
        .where("rnk=1")
        .select("id", f"eval_date_{i}", "score")
        for i in range(1, 4)
    ),
)
df_final.show()
+---+-----------+-----+-----------+-----+-----------+-----+
| id|eval_date_1|score|eval_date_2|score|eval_date_3|score|
+---+-----------+-----+-----------+-----+-----------+-----+
| 1| 2021-01-31| 93| 2020-10-31| 95| 2019-06-30| 78|
| 3| 2021-01-31| 100| 2020-10-31| 100| 2019-06-30| 98|
| 2| 2021-01-31| 79| 2020-10-31| 79| 2019-06-30| 66|
+---+-----------+-----+-----------+-----+-----------+-----+
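The rule that the join and row_number implement (for each id and each eval date, take the score of the latest test_date that is not after the eval date) can be checked with a small plain-Python sketch. This is not Spark code. The function name latest_score_on_or_before is hypothetical, and returning None when no test date qualifies is an assumption of the sketch (the inner join above would drop such an id instead).

```python
# (id, test_date, score) rows copied from the first dataframe in the question.
scores = [
    (1, "2021-03-31", 94), (1, "2021-01-31", 93), (1, "2020-12-31", 100),
    (1, "2020-06-30", 95), (1, "2019-10-31", 58), (1, "2017-10-31", 78),
    (2, "2020-01-31", 79), (2, "2018-03-31", 66), (2, "2016-05-31", 77),
    (3, "2021-05-31", 97), (3, "2020-07-31", 100), (3, "2019-07-31", 99),
    (3, "2019-06-30", 98), (3, "2018-07-31", 91), (3, "2018-02-28", 86),
    (3, "2017-11-30", 82),
]
eval_dates = ["2021-01-31", "2020-10-31", "2019-06-30"]


def latest_score_on_or_before(rows, wanted_id, eval_date):
    """Score of the most recent test for `wanted_id` dated on or before `eval_date`."""
    # ISO-formatted dates (YYYY-MM-DD) sort correctly as plain strings.
    candidates = [
        (test_date, score)
        for row_id, test_date, score in rows
        if row_id == wanted_id and test_date <= eval_date
    ]
    if not candidates:
        return None  # assumption: no qualifying test yields None
    return max(candidates)[1]


ids = sorted({row_id for row_id, _, _ in scores})
result = {
    row_id: [latest_score_on_or_before(scores, row_id, d) for d in eval_dates]
    for row_id in ids
}
print(result)
```

The values per id line up with the score_1, score_2 and score_3 columns of the needed dataframe.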

Spark Dataframe Scala: add new columns by some conditions

I revised my question so that it is easier to understand.
The original df looks like this:
+---+----------+-------+----+------+
| id|tim |price | qty|qtyChg|
+---+----------+-------+----+------+
| 1| 31951.509| 0.370| 1| 1|
| 2| 31951.515|145.380| 100| 100|
| 3| 31951.519|149.370| 100| 100|
| 4| 31951.520|144.370| 100| 100|
| 5| 31951.520|119.370| 5| 5|
| 6| 31951.520|149.370| 300| 200|
| 7| 31951.521|149.370| 400| 100|
| 8| 31951.522|149.370| 410| 10|
| 9| 31951.522|149.870| 50| 50|
| 10| 31951.522|109.370| 50| 50|
| 11| 31951.522|144.370| 400| 300|
| 12| 31951.524|149.370| 610| 200|
| 13| 31951.526|135.130| 22| 22|
| 14| 31951.527|149.370| 750| 140|
| 15| 31951.528| 89.370| 100| 100|
| 16| 31951.528|145.870| 50| 50|
| 17| 31951.528|139.370| 100| 100|
| 18| 31951.531|144.370| 410| 10|
| 19| 31951.531|149.370| 769| 19|
| 20| 31951.538|149.370| 869| 100|
| 21| 31951.538|144.880| 200| 200|
| 22| 31951.541|139.370| 221| 121|
| 23| 31951.542|149.370|1199| 330|
| 24| 31951.542|139.370| 236| 15|
| 25| 31951.542|144.370| 510| 100|
| 26| 31951.543|146.250| 50| 50|
| 27| 31951.543|143.820| 100| 100|
| 28| 31951.543|139.370| 381| 145|
| 29| 31951.544|149.370|1266| 67|
| 30| 31951.544|150.000| 50| 50|
| 31| 31951.544|137.870| 300| 300|
| 32| 31951.544|140.470| 10| 10|
| 33| 31951.545|150.000| 53| 3|
| 34| 31951.545|140.000| 25| 25|
| 35| 31951.545|148.310| 8| 8|
| 36| 31951.547|149.000| 20| 20|
| 37| 31951.549|143.820| 102| 2|
| 38| 31951.549|150.110| 75| 75|
+---+----------+-------+----+------+
Then I run this code:
val ww = Window.partitionBy().orderBy($"tim")
val step1 = df.withColumn("sequence",sort_array(collect_set(col("price")).over(ww),asc=false))
.withColumn("top1price",col("sequence").getItem(0))
.withColumn("top2price",col("sequence").getItem(1))
.drop("sequence")
The new dataframe looks like this:
+---+---------+-------+----+------+---------+---------+
| id| tim| price| qty|qtyChg|top1price|top2price|
+---+---------+-------+----+------+---------+---------+
| 1|31951.509| 0.370| 1| 1| 0.370| null|
| 2|31951.515|145.380| 100| 100| 145.380| 0.370|
| 3|31951.519|149.370| 100| 100| 149.370| 145.380|
| 4|31951.520|149.370| 300| 200| 149.370| 145.380|
| 5|31951.520|144.370| 100| 100| 149.370| 145.380|
| 6|31951.520|119.370| 5| 5| 149.370| 145.380|
| 7|31951.521|149.370| 400| 100| 149.370| 145.380|
| 8|31951.522|109.370| 50| 50| 149.870| 149.370|
| 9|31951.522|144.370| 400| 300| 149.870| 149.370|
| 10|31951.522|149.870| 50| 50| 149.870| 149.370|
| 11|31951.522|149.370| 410| 10| 149.870| 149.370|
| 12|31951.524|149.370| 610| 200| 149.870| 149.370|
| 13|31951.526|135.130| 22| 22| 149.870| 149.370|
| 14|31951.527|149.370| 750| 140| 149.870| 149.370|
| 15|31951.528| 89.370| 100| 100| 149.870| 149.370|
| 16|31951.528|139.370| 100| 100| 149.870| 149.370|
| 17|31951.528|145.870| 50| 50| 149.870| 149.370|
| 18|31951.531|144.370| 410| 10| 149.870| 149.370|
| 19|31951.531|149.370| 769| 19| 149.870| 149.370|
| 20|31951.538|144.880| 200| 200| 149.870| 149.370|
| 21|31951.538|149.370| 869| 100| 149.870| 149.370|
| 22|31951.541|139.370| 221| 121| 149.870| 149.370|
| 23|31951.542|144.370| 510| 100| 149.870| 149.370|
| 24|31951.542|139.370| 236| 15| 149.870| 149.370|
| 25|31951.542|149.370|1199| 330| 149.870| 149.370|
| 26|31951.543|139.370| 381| 145| 149.870| 149.370|
| 27|31951.543|143.820| 100| 100| 149.870| 149.370|
| 28|31951.543|146.250| 50| 50| 149.870| 149.370|
| 29|31951.544|140.470| 10| 10| 150.000| 149.870|
| 30|31951.544|137.870| 300| 300| 150.000| 149.870|
| 31|31951.544|150.000| 50| 50| 150.000| 149.870|
| 32|31951.544|149.370|1266| 67| 150.000| 149.870|
| 33|31951.545|140.000| 25| 25| 150.000| 149.870|
| 34|31951.545|150.000| 53| 3| 150.000| 149.870|
| 35|31951.545|148.310| 8| 8| 150.000| 149.870|
| 36|31951.547|149.000| 20| 20| 150.000| 149.870|
| 37|31951.549|150.110| 75| 75| 150.110| 150.000|
| 38|31951.549|143.820| 102| 2| 150.110| 150.000|
+---+---------+-------+----+------+---------+---------+
I am hoping to get two new columns, top1priceQty and top2priceQty, which store the most recent qty corresponding to top1price and top2price.
For example, in row 6, top1price = 149.370; based on this value, I want to get its corresponding qty, which is 400 (not 100 or 300). In row 33, when top1price = 150.00000000, I want to get its corresponding qty, which is 53 and comes from row 32, not 50 from row 28. The same rule applies to top2price.
Thank you all in advance!
You were very close to the answer yourself. Instead of collecting a set of just one column, collect an array of 'LMTPRICE' and its corresponding 'qty' (LMTPRICE and INTEREST_TIME appear to be the names that the price and tim columns had before the question was revised). To keep the ordering by INTEREST_TIME, so that the most recent qty is picked for a given price, put INTEREST_TIME in the array after 'LMTPRICE' and before 'qty'. Then use getItem(0).getItem(0) for top1price and getItem(0).getItem(2) for top1priceQty.
df.withColumn("sequence", sort_array(collect_set(array("LMTPRICE", "INTEREST_TIME", "qty")).over(ww), asc = false))
.withColumn("top1price", col("sequence").getItem(0).getItem(0))
.withColumn("top1priceQty", col("sequence").getItem(0).getItem(2).cast("int"))
.drop("sequence")
.show(false)
+-----+-------------+--------+---+------+---------+------------+
|index|INTEREST_TIME|LMTPRICE|qty|qtyChg|top1price|top1priceQty|
+-----+-------------+--------+---+------+---------+------------+
|0 |31951.509 |0.37 |1 |1 |0.37 |1 |
|1 |31951.515 |145.38 |100|100 |145.38 |100 |
|2 |31951.519 |149.37 |100|100 |149.37 |100 |
|3 |31951.52 |119.37 |5 |5 |149.37 |300 |
|4 |31951.52 |144.37 |100|100 |149.37 |300 |
|5 |31951.52 |149.37 |300|200 |149.37 |300 |
|6 |31951.521 |149.37 |400|100 |149.37 |400 |
|7 |31951.522 |149.87 |50 |50 |149.87 |50 |
|8 |31951.522 |149.37 |410|10 |149.87 |50 |
|9 |31951.522 |109.37 |50 |50 |149.87 |50 |
|10 |31951.522 |144.37 |400|300 |149.87 |50 |
|11 |31951.524 |149.87 |610|200 |149.87 |610 |
|12 |31951.526 |135.13 |22 |22 |149.87 |610 |
|13 |31951.527 |149.37 |750|140 |149.87 |610 |
|14 |31951.528 |139.37 |100|100 |149.87 |610 |
|15 |31951.528 |145.87 |50 |50 |149.87 |610 |
|16 |31951.528 |89.37 |100|100 |149.87 |610 |
|17 |31951.531 |144.37 |410|10 |149.87 |610 |
|18 |31951.531 |149.37 |769|19 |149.87 |610 |
|19 |31951.538 |149.37 |869|100 |149.87 |610 |
+-----+-------------+--------+---+------+---------+------------+
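Why sorting the (price, time, qty) arrays in descending order yields the most recent qty of the top price can be illustrated with plain-Python tuples, which compare element by element much like the sorted arrays do. This is only a sketch, not Spark code: the function name top_price_with_qty is hypothetical, the data is the first seven rows of the output above, and the frame is assumed to contain every row whose time is at or before the current row's time.

```python
# (time, price, qty) for the first seven rows of the output above.
rows = [
    (31951.509, 0.37, 1),
    (31951.515, 145.38, 100),
    (31951.519, 149.37, 100),
    (31951.52, 119.37, 5),
    (31951.52, 144.37, 100),
    (31951.52, 149.37, 300),
    (31951.521, 149.37, 400),
]


def top_price_with_qty(rows):
    """For each row, the highest price seen so far and its most recent qty."""
    ordered = sorted(rows, key=lambda row: row[0])
    results = []
    for current_time, _, _ in ordered:
        # Frame: all rows at or before the current time, rows with equal time included.
        frame = [(price, time, qty) for time, price, qty in ordered if time <= current_time]
        # Tuples compare by price first, then by time, so the maximum is the
        # latest occurrence of the highest price.
        price, _, qty = max(frame)
        results.append((price, qty))
    return results


top = top_price_with_qty(rows)
for row, (price, qty) in zip(rows, top):
    print(row, price, qty)
```

Placing the time between price and qty is what breaks ties in favour of the latest row; with only (price, qty) the tie would be decided by the larger qty instead.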

unpivoting the dataframe in spark and scala

I have a dataframe like:
+----------+-----+------+------+-----+---+
| product|china|france|german|india|usa|
+----------+-----+------+------+-----+---+
| beans| 496| 200| 210| 234|119|
| banana| null| 345| 234| 123|122|
|starwberry| 340| 430| 246| 111|321|
| mango| null| 345| 456| 110|223|
| chiku| 765| 455| 666| 122|222|
| apple| 109| 766| 544| 444|333|
+----------+-----+------+------+-----+---+
I want to unpivot it while keeping multiple columns fixed, like this:
import spark.implicits._
val unPivotDF = testData.select($"product",$"german", expr("stack(4, 'china', china, 'usa', usa, 'france', france,'india',india) " +
"as (Country,Total)"))
unPivotDF.show()
which gives the output below:
+----------+------+-------+-----+
| product|german|Country|Total|
+----------+------+-------+-----+
| beans| 210| china| 496|
| beans| 210| usa| 119|
| beans| 210| france| 200|
| beans| 210| india| 234|
| banana| 234| china| null|
| banana| 234| usa| 122|
| banana| 234| france| 345|
| banana| 234| india| 123|
|starwberry| 246| china| 340|
|starwberry| 246| usa| 321|
|starwberry| 246| france| 430|
|starwberry| 246| india| 111|
This is perfect, but the fixed columns such as product and german are only known at runtime, so I cannot hard-code the column names in the select statement.
So this is what I tried:
var fixedCol = List[String]()
fixedCol = "german" :: fixedCol
fixedCol = "product" :: fixedCol
val col = df.select(fixedCol:_*, expr("stack(.......)")) // gives an error because the first argument of select is fixed and the second is varargs
I know it can be done using SQL as below, but I cannot use SQL:
df.createOrReplaceTempView("df")
spark.sql("select.......")
Is there any other way to make it dynamic?
Convert all the column names and the expr into a single List[Column], then pass that list to select as varargs:
val fixedCol : List[Column] = List(col("german") , col("product") , expr("stack(.......)"))
df.select(fixedCol:_*)
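If the columns to unpivot are also only known at runtime, the stack(...) expression string itself can be generated. Below is a rough sketch in plain Python, used only to illustrate the string building; the same logic carries over to Scala string interpolation. The helper name build_stack_expr is hypothetical, and the column lists are taken from the question.

```python
def build_stack_expr(value_columns, key_name="Country", value_name="Total"):
    """Build a stack(...) expression string for the given value columns."""
    # Each column contributes a label literal followed by the column itself.
    pairs = ", ".join(f"'{column}', {column}" for column in value_columns)
    return f"stack({len(value_columns)}, {pairs}) as ({key_name},{value_name})"


all_columns = ["product", "china", "france", "german", "india", "usa"]
fixed_columns = ["product", "german"]  # runtime information in the question

# Everything that is not fixed gets unpivoted.
value_columns = [column for column in all_columns if column not in fixed_columns]
stack_expr = build_stack_expr(value_columns)
print(stack_expr)
```

The resulting string can then be wrapped in expr(...) and appended to the list of fixed columns before calling select.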