Pyspark Rolling Average starting at first row - pyspark

I am trying to calculate a rolling average in PySpark. I have it working, but it seems to behave differently from what I expected: the rolling average starts at the first row.
For example:
import pyspark.sql.functions as f
from pyspark.sql import Window

columns = ['month', 'day', 'value']
data = [('JAN', '01', '20000'), ('JAN', '02', '40000'), ('JAN', '03', '30000'), ('JAN', '04', '25000'), ('JAN', '05', '5000'), ('JAN', '06', '15000'),
        ('FEB', '01', '10000'), ('FEB', '02', '50000'), ('FEB', '03', '100000'), ('FEB', '04', '60000'), ('FEB', '05', '1000'), ('FEB', '06', '10000')]
df_test = spark.createDataFrame(data).toDF(*columns)
win = Window.partitionBy('month').orderBy('day').rowsBetween(-2, 0)
df_test.withColumn('rolling_average', f.avg('value').over(win)).show()
+-----+---+------+------------------+
|month|day| value| rolling_average|
+-----+---+------+------------------+
| JAN| 01| 20000| 20000.0|
| JAN| 02| 40000| 30000.0|
| JAN| 03| 30000| 30000.0|
| JAN| 04| 25000|31666.666666666668|
| JAN| 05| 5000| 20000.0|
| JAN| 06| 15000| 15000.0|
| FEB| 01| 10000| 10000.0|
| FEB| 02| 50000| 30000.0|
| FEB| 03|100000|53333.333333333336|
| FEB| 04| 60000| 70000.0|
| FEB| 05| 1000|53666.666666666664|
| FEB| 06| 10000|23666.666666666668|
+-----+---+------+------------------+
The following would be more in line with what I expect. Is there a way to get this behavior?
+-----+---+------+------------------+
|month|day| value| rolling_average|
+-----+---+------+------------------+
| JAN| 01| 20000| null|
| JAN| 02| 40000| null|
| JAN| 03| 30000| 30000.0|
| JAN| 04| 25000|31666.666666666668|
| JAN| 05| 5000| 20000.0|
| JAN| 06| 15000| 15000.0|
| FEB| 01| 10000| null|
| FEB| 02| 50000| null|
| FEB| 03|100000|53333.333333333336|
| FEB| 04| 60000| 70000.0|
| FEB| 05| 1000|53666.666666666664|
| FEB| 06| 10000|23666.666666666668|
+-----+---+------+------------------+
The issue with the default behavior is that I need another column to keep track of where the lag should start from.

Use the row_number() window function, then a when + otherwise expression to replace the leading values with null.
To change where the average should start, change the value in the when condition, col("rn") <= <value>.
Example:
import pyspark.sql.functions as f
from pyspark.sql.functions import row_number, when, col, lit
from pyspark.sql import Window

columns = ['month', 'day', 'value']
data = [('JAN', '01', '20000'), ('JAN', '02', '40000'), ('JAN', '03', '30000'), ('JAN', '04', '25000'), ('JAN', '05', '5000'), ('JAN', '06', '15000'),
        ('FEB', '01', '10000'), ('FEB', '02', '50000'), ('FEB', '03', '100000'), ('FEB', '04', '60000'), ('FEB', '05', '1000'), ('FEB', '06', '10000')]
df_test = spark.createDataFrame(data).toDF(*columns)
win = Window.partitionBy('month').orderBy('day').rowsBetween(-2, 0)
win1 = Window.partitionBy('month').orderBy('day')
df_test.withColumn('rolling_average', f.avg('value').over(win)).\
    withColumn("rn", row_number().over(win1)).\
    withColumn("rolling_average", when(col("rn") <= 2, lit(None)).\
        otherwise(col("rolling_average"))).\
    drop("rn").\
    show()
#+-----+---+------+------------------+
#|month|day| value| rolling_average|
#+-----+---+------+------------------+
#| FEB| 01| 10000| null|
#| FEB| 02| 50000| null|
#| FEB| 03|100000|53333.333333333336|
#| FEB| 04| 60000| 70000.0|
#| FEB| 05| 1000|53666.666666666664|
#| FEB| 06| 10000|23666.666666666668|
#| JAN| 01| 20000| null|
#| JAN| 02| 40000| null|
#| JAN| 03| 30000| 30000.0|
#| JAN| 04| 25000|31666.666666666668|
#| JAN| 05| 5000| 20000.0|
#| JAN| 06| 15000| 15000.0|
#+-----+---+------+------------------+

A more condensed version of #484:
import pyspark.sql.functions as f
from pyspark.sql import Window
w1 = Window.partitionBy('month').orderBy('day')
w2 = Window.partitionBy('month').orderBy('day').rowsBetween(-2, 0)
df.withColumn("rolling_average", f.when(f.row_number().over(w1) > f.lit(2), f.avg('value').over(w2))).show(10, False)
p.s. Please do not mark this as an answer :)
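A variant of the same idea, as a rough sketch reusing df_test from above: gate on f.count over the same rolling window so no extra row_number column is needed, since the average is kept only once the frame actually holds 3 rows.
import pyspark.sql.functions as f
from pyspark.sql import Window

w = Window.partitionBy('month').orderBy('day').rowsBetween(-2, 0)

# keep the rolling average only when the frame contains a full 3 rows; when()
# without otherwise() leaves the first two rows of each month as null
df_test.withColumn(
    'rolling_average',
    f.when(f.count('value').over(w) >= 3, f.avg('value').over(w))
).show()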

Related

Spark collect_set from a column using window function approach

I have a sample dataset with salaries. I want to distribute the salaries into 3 buckets, find the lowest salary in each bucket, collect those lower bounds into an array, and attach that array to the original set. I am trying to use window functions to do that, but the array seems to be built up progressively rather than computed over the whole dataset.
Here is the code that I have written
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val spark = sparkSession
import spark.implicits._

val simpleData = Seq(("James", "Sales", 3000),
  ("Michael", "Sales", 3100),
  ("Robert", "Sales", 3200),
  ("Maria", "Finance", 3300),
  ("James", "Sales", 3400),
  ("Scott", "Finance", 3500),
  ("Jen", "Finance", 3600),
  ("Jeff", "Marketing", 3700),
  ("Kumar", "Marketing", 3800),
  ("Saif", "Sales", 3900)
)
val df = simpleData.toDF("employee_name", "department", "salary")

val windowSpec = Window.orderBy("salary")
val ntileFrame = df.withColumn("ntile", ntile(3).over(windowSpec))
val lowWindowSpec = Window.partitionBy("ntile")
val ntileMinDf = ntileFrame.withColumn("lower_bound", min("salary").over(lowWindowSpec))
var rangeDf = ntileMinDf.withColumn("range", collect_set("lower_bound").over(windowSpec))
rangeDf.show()
I am getting the dataset like this
+-------------+----------+------+-----+-----------+------------------+
|employee_name|department|salary|ntile|lower_bound| range|
+-------------+----------+------+-----+-----------+------------------+
| James| Sales| 3000| 1| 3000| [3000]|
| Michael| Sales| 3100| 1| 3000| [3000]|
| Robert| Sales| 3200| 1| 3000| [3000]|
| Maria| Finance| 3300| 1| 3000| [3000]|
| James| Sales| 3400| 2| 3400| [3000, 3400]|
| Scott| Finance| 3500| 2| 3400| [3000, 3400]|
| Jen| Finance| 3600| 2| 3400| [3000, 3400]|
| Jeff| Marketing| 3700| 3| 3700|[3000, 3700, 3400]|
| Kumar| Marketing| 3800| 3| 3700|[3000, 3700, 3400]|
| Saif| Sales| 3900| 3| 3700|[3000, 3700, 3400]|
+-------------+----------+------+-----+-----------+------------------+
I am expecting the dataset to look like this
+-------------+----------+------+-----+-----------+------------------+
|employee_name|department|salary|ntile|lower_bound| range|
+-------------+----------+------+-----+-----------+------------------+
| James| Sales| 3000| 1| 3000|[3000, 3700, 3400]|
| Michael| Sales| 3100| 1| 3000|[3000, 3700, 3400]|
| Robert| Sales| 3200| 1| 3000|[3000, 3700, 3400]|
| Maria| Finance| 3300| 1| 3000|[3000, 3700, 3400]|
| James| Sales| 3400| 2| 3400|[3000, 3700, 3400]|
| Scott| Finance| 3500| 2| 3400|[3000, 3700, 3400]|
| Jen| Finance| 3600| 2| 3400|[3000, 3700, 3400]|
| Jeff| Marketing| 3700| 3| 3700|[3000, 3700, 3400]|
| Kumar| Marketing| 3800| 3| 3700|[3000, 3700, 3400]|
| Saif| Sales| 3900| 3| 3700|[3000, 3700, 3400]|
+-------------+----------+------+-----+-----------+------------------+
To ensure that your window takes into account all rows, and not only the rows up to the current one, you can use the rowsBetween method with Window.unboundedPreceding and Window.unboundedFollowing as arguments. Your last line thus becomes:
var rangeDf = ntileMinDf.withColumn(
  "range",
  collect_set("lower_bound")
    .over(Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
)
and you get the following rangeDf dataframe:
+-------------+----------+------+-----+-----------+------------------+
|employee_name|department|salary|ntile|lower_bound| range|
+-------------+----------+------+-----+-----------+------------------+
| James| Sales| 3000| 1| 3000|[3000, 3700, 3400]|
| Michael| Sales| 3100| 1| 3000|[3000, 3700, 3400]|
| Robert| Sales| 3200| 1| 3000|[3000, 3700, 3400]|
| Maria| Finance| 3300| 1| 3000|[3000, 3700, 3400]|
| James| Sales| 3400| 2| 3400|[3000, 3700, 3400]|
| Scott| Finance| 3500| 2| 3400|[3000, 3700, 3400]|
| Jen| Finance| 3600| 2| 3400|[3000, 3700, 3400]|
| Jeff| Marketing| 3700| 3| 3700|[3000, 3700, 3400]|
| Kumar| Marketing| 3800| 3| 3700|[3000, 3700, 3400]|
| Saif| Sales| 3900| 3| 3700|[3000, 3700, 3400]|
+-------------+----------+------+-----+-----------+------------------+
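For reference, a rough PySpark equivalent of the same fix (assuming the intermediate DataFrame is named ntile_min_df and already carries the lower_bound column, mirroring ntileMinDf above) could look like:
import pyspark.sql.functions as F
from pyspark.sql import Window

# unbounded frame, so collect_set sees every row instead of only the rows
# up to the current one
w_all = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

range_df = ntile_min_df.withColumn('range', F.collect_set('lower_bound').over(w_all))
range_df.show()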

PySpark - Getting the latest date less than another given date

I need some help. I have two dataframes, one has a few dates and the other has my significant data, catalogued by date.
It goes something like this:
First df, with the relevant data
+------+----------+---------------+
| id| test_date| score|
+------+----------+---------------+
| 1|2021-03-31| 94|
| 1|2021-01-31| 93|
| 1|2020-12-31| 100|
| 1|2020-06-30| 95|
| 1|2019-10-31| 58|
| 1|2017-10-31| 78|
| 2|2020-01-31| 79|
| 2|2018-03-31| 66|
| 2|2016-05-31| 77|
| 3|2021-05-31| 97|
| 3|2020-07-31| 100|
| 3|2019-07-31| 99|
| 3|2019-06-30| 98|
| 3|2018-07-31| 91|
| 3|2018-02-28| 86|
| 3|2017-11-30| 82|
+------+----------+---------------+
Second df, with the dates
+--------------+--------------+--------------+
| eval_date_1| eval_date_2| eval_date_3|
+--------------+--------------+--------------+
| 2021-01-31| 2020-10-31| 2019-06-30|
+--------------+--------------+--------------+
Needed DF
+------+--------------+---------+--------------+---------+--------------+---------+
| id| eval_date_1| score_1 | eval_date_2| score_2 | eval_date_3| score_3 |
+------+--------------+---------+--------------+---------+--------------+---------+
| 1| 2021-01-31| 93| 2020-10-31| 95| 2019-06-30| 78|
| 2| 2021-01-31| 79| 2020-10-31| 79| 2019-06-30| 66|
| 3| 2021-01-31| 100| 2020-10-31| 100| 2019-06-30| 98|
+------+--------------+---------+--------------+---------+--------------+---------+
So, for instance, for the first id, the needed df takes the scores from the second, fourth, and sixth rows of the first df. Those are the most recent test dates that are equal to or earlier than the corresponding eval_date in the second df.
Assuming df is your main dataframe and df_date is the one that contains only the dates:
from functools import reduce
from pyspark.sql import functions as F, Window as W

df_final = reduce(
    lambda a, b: a.join(b, on="id"),
    (
        df.join(
            F.broadcast(df_date.select(f"eval_date_{i}")),
            on=F.col(f"eval_date_{i}") >= F.col("test_date"),
        )
        .withColumn(
            "rnk",
            F.row_number().over(W.partitionBy("id").orderBy(F.col("test_date").desc())),
        )
        .where("rnk=1")
        .select("id", f"eval_date_{i}", "score")
        for i in range(1, 4)
    ),
)
df_final.show()
+---+-----------+-----+-----------+-----+-----------+-----+
| id|eval_date_1|score|eval_date_2|score|eval_date_3|score|
+---+-----------+-----+-----------+-----+-----------+-----+
| 1| 2021-01-31| 93| 2020-10-31| 95| 2019-06-30| 78|
| 3| 2021-01-31| 100| 2020-10-31| 100| 2019-06-30| 98|
| 2| 2021-01-31| 79| 2020-10-31| 79| 2019-06-30| 66|
+---+-----------+-----+-----------+-----+-----------+-----+
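One detail: the three score columns in this result all keep the name score, while the needed DF uses score_1/score_2/score_3. If the distinct names matter, a small tweak of the same snippet (aliasing the score inside the generator, shown here as an untested sketch) would be:
from functools import reduce
from pyspark.sql import functions as F, Window as W

df_final = reduce(
    lambda a, b: a.join(b, on="id"),
    (
        df.join(
            F.broadcast(df_date.select(f"eval_date_{i}")),
            on=F.col(f"eval_date_{i}") >= F.col("test_date"),
        )
        .withColumn(
            "rnk",
            F.row_number().over(W.partitionBy("id").orderBy(F.col("test_date").desc())),
        )
        .where("rnk=1")
        # alias the score per iteration so the joined result has score_1 ... score_3
        .select("id", f"eval_date_{i}", F.col("score").alias(f"score_{i}"))
        for i in range(1, 4)
    ),
)
df_final.show()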

How can we combine values from two columns of a dataframe of same data_type and get the count of each element?

val data = Seq(
("India","Pakistan","India"),
("Australia","India","India"),
("New Zealand","Zimbabwe","New Zealand"),
("West Indies", "Bangladesh","Bangladesh"),
("Sri Lanka","Bangladesh","Bangladesh"),
("Sri Lanka","Bangladesh","Bangladesh"),
("Sri Lanka","Bangladesh","Bangladesh")
)
val df = data.toDF("Team_1","Team_2","Winner")
I have this dataframe. I want to get a count of how many matches each team has played.
There are 3 approaches discussed in the other answers. I tried to evaluate them (just for education/awareness) in terms of time taken/elapsed, with respect to performance...
import org.apache.log4j.Level
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

object Katu_37 extends App {

  val logger = org.apache.log4j.Logger.getLogger("org")
  logger.setLevel(Level.WARN)

  val spark = SparkSession.builder.appName(getClass.getName)
    .master("local[*]").getOrCreate

  import spark.implicits._

  val data = Seq(
    ("India", "Pakistan", "India"),
    ("Australia", "India", "India"),
    ("New Zealand", "Zimbabwe", "New Zealand"),
    ("West Indies", "Bangladesh", "Bangladesh"),
    ("Sri Lanka", "Bangladesh", "Bangladesh"),
    ("Sri Lanka", "Bangladesh", "Bangladesh"),
    ("Sri Lanka", "Bangladesh", "Bangladesh")
  )
  val df = data.toDF("Team_1", "Team_2", "Winner")
  df.show

  exec {
    println("METHOD 1 ")
    df.select("Team_1").union(df.select("Team_2")).groupBy("Team_1").agg(count("Team_1")).show()
  }

  exec {
    println("METHOD 2 ")
    df.select(array($"Team_1", $"Team_2").as("Team")).select("Team").withColumn("Team", explode($"Team")).groupBy("Team").agg(count("Team")).show()
  }

  exec {
    println("METHOD 3 ")
    val matchesCount = df.selectExpr("Team_1 as Teams").union(df.selectExpr("Team_2 as Teams"))
    matchesCount.groupBy("Teams").count().withColumnRenamed("count", "MatchesPlayed").show()
  }

  /**
   * Runs the given block, printing its result and the elapsed time in nanoseconds.
   *
   * @param f block to execute and time
   * @tparam T result type of the block
   * @return Unit; the result and the elapsed time are printed as side effects
   */
  def exec[T](f: => T) = {
    val starttime = System.nanoTime()
    println("t = " + f)
    val endtime = System.nanoTime()
    val elapsedTime = endtime - starttime
    // import java.util.concurrent.TimeUnit
    // val convertToSeconds = TimeUnit.MINUTES.convert(elapsedTime, TimeUnit.NANOSECONDS)
    println("time Elapsed " + elapsedTime)
  }
}
Result :
+-----------+----------+-----------+
| Team_1| Team_2| Winner|
+-----------+----------+-----------+
| India| Pakistan| India|
| Australia| India| India|
|New Zealand| Zimbabwe|New Zealand|
|West Indies|Bangladesh| Bangladesh|
| Sri Lanka|Bangladesh| Bangladesh|
| Sri Lanka|Bangladesh| Bangladesh|
| Sri Lanka|Bangladesh| Bangladesh|
+-----------+----------+-----------+
METHOD 1
+-----------+-------------+
| Team_1|count(Team_1)|
+-----------+-------------+
| Sri Lanka| 3|
| India| 2|
|West Indies| 1|
| Bangladesh| 4|
| Zimbabwe| 1|
|New Zealand| 1|
| Australia| 1|
| Pakistan| 1|
+-----------+-------------+
t = ()
time Elapsed 2729302088
METHOD 2
+-----------+-----------+
| Team|count(Team)|
+-----------+-----------+
| Sri Lanka| 3|
| India| 2|
|West Indies| 1|
| Bangladesh| 4|
| Zimbabwe| 1|
|New Zealand| 1|
| Australia| 1|
| Pakistan| 1|
+-----------+-----------+
t = ()
time Elapsed 646513918
METHOD 3
+-----------+-------------+
| Teams|MatchesPlayed|
+-----------+-------------+
| Sri Lanka| 3|
| India| 2|
|West Indies| 1|
| Bangladesh| 4|
| Zimbabwe| 1|
|New Zealand| 1|
| Australia| 1|
| Pakistan| 1|
+-----------+-------------+
t = ()
time Elapsed 988510662
I observed that the org.apache.spark.sql.functions.array approach takes less time (646513918 nanoseconds) than the union approach (2729302088 nanoseconds)...
val matchesCount = df.selectExpr("Team_1 as Teams").union(df.selectExpr("Team_2 as Teams"))
matchesCount.groupBy("Teams").count().withColumnRenamed("count","MatchesPlayed").show()
+-----------+--------------+
| Teams|MatchesPlayed|
+-----------+--------------+
| Sri Lanka| 3|
| India| 2|
|West Indies| 1|
| Bangladesh| 4|
| Zimbabwe| 1|
|New Zealand| 1|
| Australia| 1|
| Pakistan| 1|
+-----------+--------------+
You can either use a union with a select statement, or use array from org.apache.spark.sql.functions:
// METHOD 1
df.select("Team_1").union(df.select("Team_2")).groupBy("Team_1").agg(count("Team_1")).show()
// METHOD 2
df.select(array($"Team_1", $"Team_2").as("Team")).select("Team").withColumn("Team",explode($"Team")).groupBy("Team").agg(count("Team")).show()
Using a select statement and union:
+-----------+-------------+
| Team_1|count(Team_1)|
+-----------+-------------+
| Sri Lanka| 3|
| India| 2|
|West Indies| 1|
| Bangladesh| 4|
| Zimbabwe| 1|
|New Zealand| 1|
| Australia| 1|
| Pakistan| 1|
+-----------+-------------+
Time Elapsed : 1588835600
Using array :
+-----------+-----------+
| Team|count(Team)|
+-----------+-----------+
| Sri Lanka| 3|
| India| 2|
|West Indies| 1|
| Bangladesh| 4|
| Zimbabwe| 1|
|New Zealand| 1|
| Australia| 1|
| Pakistan| 1|
+-----------+-----------+
Time Elapsed : 342103600
Performance-wise, using org.apache.spark.sql.functions.array is better (at least in this small local test).

unpivoting the dataframe in spark and scala

I have a dataframe like:
+----------+-----+------+------+-----+---+
| product|china|france|german|india|usa|
+----------+-----+------+------+-----+---+
| beans| 496| 200| 210| 234|119|
| banana| null| 345| 234| 123|122|
|starwberry| 340| 430| 246| 111|321|
| mango| null| 345| 456| 110|223|
| chiku| 765| 455| 666| 122|222|
| apple| 109| 766| 544| 444|333|
+----------+-----+------+------+-----+---+
I want to unpivot it while keeping multiple columns fixed, like:
import spark.implicits._
val unPivotDF = testData.select($"product",$"german", expr("stack(4, 'china', china, 'usa', usa, 'france', france,'india',india) " +
"as (Country,Total)"))
unPivotDF.show()
which gives the output below:
+----------+------+-------+-----+
| product|german|Country|Total|
+----------+------+-------+-----+
| beans| 210| china| 496|
| beans| 210| usa| 119|
| beans| 210| france| 200|
| beans| 210| india| 234|
| banana| 234| china| null|
| banana| 234| usa| 122|
| banana| 234| france| 345|
| banana| 234| india| 123|
|starwberry| 246| china| 340|
|starwberry| 246| usa| 321|
|starwberry| 246| france| 430|
|starwberry| 246| india| 111|
This is perfect, but the fixed columns (product and german here) are only known at runtime, so I cannot hard-code the column names in the select statement.
So what I was doing is:
var fixedCol = List[String]()
fixedCol = "german" :: fixedCol
fixedCol = "product" :: fixedCol
val col = df.select(fixedCol: _*, expr("stack(.......)")) // it gives an error, as the first argument of select is fixed and the second arg is varargs
I know it can be done with SQL, but I cannot use SQL:
df.createOrReplaceTempView("df")
spark.sql("select.......")
Is there any other way to make it dynamic?
Convert all the column names and the expression to a List[Column]. This works because select has an overload that takes Column* varargs, so the fixed columns and the stack expression can be passed together:
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, expr}

val fixedCol: List[Column] = List(col("german"), col("product"), expr("stack(.......)"))
df.select(fixedCol: _*)

PySpark: Simulate SQL's UPDATE

I have two Spark DataFrames:
trg
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.2|unchanged|
| 3| 0.3|unchanged|
+---+-----+---------+
src
+---+-----+-------+-----+
|key|value| flag|merge|
+---+-----+-------+-----+
| 1| 0.11|changed| 0|
| 2| 0.22|changed| 1|
| 3| 0.33|changed| 0|
+---+-----+-------+-----+
I need to "update" trg.value and trg.flag based on src.merge as described by the following SQL logic:
UPDATE trg
INNER JOIN src ON trg.key = src.key
SET trg.value = src.value,
trg.flag = src.flag
WHERE src.merge = 1;
Expected new trg:
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1 |unchanged|
| 2| 0.22| changed|
| 3| 0.3 |unchanged|
+---+-----+---------+
I have tried using when(). It works for the flag field (since it can have only two values), but not for the value field, because I don't know how to pick the value from the corresponding row:
from pyspark.sql.functions import when
trg = spark.createDataFrame(data=[('1', '0.1', 'unchanged'),
('2', '0.2', 'unchanged'),
('3', '0.3', 'unchanged')],
schema=['key', 'value', 'flag'])
src = spark.createDataFrame(data=[('1', '0.11', 'changed', '0'),
('2', '0.22', 'changed', '1'),
('3', '0.33', 'changed', '0')],
schema=['key', 'value', 'flag', 'merge'])
new_trg = (trg.alias('trg').join(src.alias('src'), on=['key'], how='inner')
.select(
'trg.*',
when(src.merge == 1, 'changed').otherwise('unchanged').alias('flag'),
when(src.merge == 1, ???).otherwise(???).alias('value')))
Is there any other, preferably idiomatic, way to translate that SQL logic to PySpark?
newdf = (trg.join(src, on=['key'], how='inner')
.select(trg.key,
when( src.merge==1, src.value)
.otherwise(trg.value).alias('value'),
when( src.merge==1, src.flag)
.otherwise(trg.flag).alias('flag')))
newdf.show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.22| changed|
| 3| 0.3|unchanged|
+---+-----+---------+
Imports and Create datasets
import pyspark.sql.functions as f
l1 = [(1, 0.1, 'unchanged'), (2, 0.2, 'unchanged'), (3, 0.3, 'unchanged')]
dfl1 = spark.createDataFrame(l1).toDF('key', 'value', 'flag')
dfl1.show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.2|unchanged|
| 3| 0.3|unchanged|
+---+-----+---------+
l2 = [(1, 0.11, 'changed', 0), (2, 0.22, 'changed', 1), (3, 0.33, 'changed', 0)]
dfl2 = spark.createDataFrame(l2).toDF('key', 'value', 'flag', 'merge')
dfl2.show()
+---+-----+-------+-----+
|key|value| flag|merge|
+---+-----+-------+-----+
| 1| 0.11|changed| 0|
| 2| 0.22|changed| 1|
| 3| 0.33|changed| 0|
+---+-----+-------+-----+
# filtering upfront for better performance in next join
# dfl2 = dfl2.where(dfl2['merge'] == 1)
Join datasets
join_cond = [dfl1['key'] == dfl2['key'], dfl2['merge'] == 1]
dfl12 = dfl1.join(dfl2, join_cond, 'left_outer')
dfl12.show()
+---+-----+---------+----+-----+-------+-----+
|key|value| flag| key|value| flag|merge|
+---+-----+---------+----+-----+-------+-----+
| 1| 0.1|unchanged|null| null| null| null|
| 3| 0.3|unchanged|null| null| null| null|
| 2| 0.2|unchanged| 2| 0.22|changed| 1|
+---+-----+---------+----+-----+-------+-----+
Use the when function: if the value from the joined src row is null, keep the original value; otherwise use the new one.
df = dfl12.withColumn('new_value', f.when(dfl2['value'].isNotNull(), dfl2['value']).otherwise(dfl1['value'])).\
withColumn('new_flag', f.when(dfl2['flag'].isNotNull(), dfl2['flag']).otherwise(dfl1['flag']))
df.show()
+---+-----+---------+----+-----+-------+-----+---------+---------+
|key|value| flag| key|value| flag|merge|new_value| new_flag|
+---+-----+---------+----+-----+-------+-----+---------+---------+
| 1| 0.1|unchanged|null| null| null| null| 0.1|unchanged|
| 3| 0.3|unchanged|null| null| null| null| 0.3|unchanged|
| 2| 0.2|unchanged| 2| 0.22|changed| 1| 0.22| changed|
+---+-----+---------+----+-----+-------+-----+---------+---------+
df.select(dfl1['key'], df['new_value'], df['new_flag']).show()
+---+---------+---------+
|key|new_value| new_flag|
+---+---------+---------+
| 1| 0.1|unchanged|
| 3| 0.3|unchanged|
| 2| 0.22| changed|
+---+---------+---------+
This is laid out step by step for understanding; you could combine a couple of the steps into one.
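For example (a sketch reusing dfl1 and dfl2 from above, with coalesce standing in for the isNotNull/otherwise pair), the left_outer join and the two new columns can collapse into a single select:
import pyspark.sql.functions as f

# same join condition as before: match on key only for rows with merge == 1,
# so unmatched rows keep nulls on the right-hand side
join_cond = [dfl1['key'] == dfl2['key'], dfl2['merge'] == 1]

result = (
    dfl1.join(dfl2, join_cond, 'left_outer')
    .select(
        dfl1['key'],
        f.coalesce(dfl2['value'], dfl1['value']).alias('value'),
        f.coalesce(dfl2['flag'], dfl1['flag']).alias('flag'),
    )
)
result.show()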
import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql.functions import when
spark = SparkSession.builder.appName("test").getOrCreate()
data1 = [(1, 0.1, 'unchanged'), (2, 0.2,'unchanged'), (3, 0.3, 'unchanged')]
schema = ['key', 'value', 'flag']
df1 = spark.createDataFrame(data1, schema=schema)
df1.show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.2|unchanged|
| 3| 0.3|unchanged|
+---+-----+---------+
data2 = [(1, 0.11, 'changed',0), (2, 0.22,'changed',1), (3, 0.33, 'changed',0)]
schema2 = ['key', 'value', 'flag', 'merge']
df2 = spark.createDataFrame(data2, schema=schema2)
df2.show()
+---+-----+-------+-----+
|key|value| flag|merge|
+---+-----+-------+-----+
| 1| 0.11|changed| 0|
| 2| 0.22|changed| 1|
| 3| 0.33|changed| 0|
+---+-----+-------+-----+
df2 = df2.withColumnRenamed("value", "value1").withColumnRenamed("flag", 'flag1')
mer = df1.join(df2, ['key'], 'inner')
mer = mer.withColumn("temp", when(mer.merge == 1, mer.value1).otherwise(mer.value))
mer = mer.withColumn("temp1", when(mer.merge == 1, 'changed').otherwise('unchanged'))
output = mer.select(mer.key, mer.temp.alias('value'), mer.temp1.alias('flag'))
output.orderBy(output.value.asc()).show()
+---+-----+---------+
|key|value| flag|
+---+-----+---------+
| 1| 0.1|unchanged|
| 2| 0.22| changed|
| 3| 0.3|unchanged|
+---+-----+---------+