How to calculate whether transactions fall into specific periods without loops - pyspark

I have a PySpark dataframe of transactions by customer which feeds into a dashboard. For each rolling 12 month time period, I want to calculate whether a customer is 'New' (never before purchased), 'Retained' (made a purchase in the 12 months before the start of the current time period and purchased in the current time period), or 'Reactivated' (made a purchase prior to the previous 12 months, didn't purchase in the previous 12 months, and purchased in the current time period).
Clarification of 'current time period':
If current period is the Rolling 12 Months to the end of September 2022, any purchase from October 2021 to September 2022 falls into the 'current' time period. Purchases from October 2020 to September 2021 fall into the 'previous 12 months', and purchases from September 2020 and earlier are 'prior to the previous 12 months'.
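In other words, each evaluation period boils down to three half-open date ranges. A small illustration in plain Python (using dateutil, with the boundaries for the SEP 2022 example above):
from datetime import date
from dateutil.relativedelta import relativedelta

# Illustration for the "SEP 2022" period, expressed as half-open ranges [start, end)
period_end = date(2022, 10, 1)                             # first day after September 2022
current_start = period_end - relativedelta(months=12)      # 2021-10-01
prev_12m_start = current_start - relativedelta(months=12)  # 2020-10-01

# current period:     current_start  <= txn_date < period_end
# previous 12 months: prev_12m_start <= txn_date < current_start
# prior to that:      txn_date < prev_12m_start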
input:
+-----------+--------------+----------------+
|customer_id|transaction_id|transaction_date|
+-----------+--------------+----------------+
|          1|             1|     2019-JAN-10|
|          1|             2|     2019-DEC-15|
|          1|             3|     2022-SEP-07|
+-----------+--------------+----------------+
intermediate:
+-----------+------+-----------+--------+-----------+------------+------------------+
|customer_id|txn_id|   txn_date|  period|txn_current|txn_prev_12m|txn_prior_prev_12m|
+-----------+------+-----------+--------+-----------+------------+------------------+
|          1|     1|2019-JAN-10|SEP 2022|          0|           0|                 1|
|          1|     2|2019-DEC-15|SEP 2022|          0|           0|                 1|
|          1|     3|2022-SEP-07|SEP 2022|          1|           0|                 0|
+-----------+------+-----------+--------+-----------+------------+------------------+
final:
+-----------+----------+-----------+------------+------------------+-----------+
|customer_id|txn_period|txn_current|txn_prev_12m|txn_prior_prev_12m|     status|
+-----------+----------+-----------+------------+------------------+-----------+
|          1|  SEP 2022|          1|           0|                 2|Reactivated|
+-----------+----------+-----------+------------+------------------+-----------+
My current solution loops through each required evaluation period (Jan 2022, Feb 2022, Mar 2022, etc.), classifying the customer status for that period. This step, however, takes hours to process because it has to loop through dozens of different time periods over a dataframe with millions of rows.
I feel like I'm missing something obvious, but how can I calculate this without looping through each time period and checking whether each individual transaction falls within the bounds of that time period?

You can use the lag function to extract each customer's previous order, then calculate the datediff between the transaction date and the previous date.
from pyspark.sql import functions as F
from pyspark.sql import Window as W

# assumes transaction_date is (or can be cast to) a date column
window = W.partitionBy('customer_id').orderBy('transaction_date')

(
    df
    .withColumn('prev_order_date', F.lag('transaction_date').over(window))
    .withColumn('datediff', F.datediff(F.col('transaction_date'), F.col('prev_order_date')))
).show()
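If you go down this route, the gap since the previous purchase can then be mapped onto the statuses. A rough sketch (my own addition, not tested against the exact rolling-period boundaries above; it assumes transaction_date is already a date column, and it only labels the periods in which a purchase actually happened):
from pyspark.sql import functions as F
from pyspark.sql import Window as W

window = W.partitionBy('customer_id').orderBy('transaction_date')

(
    df
    .withColumn('prev_order_date', F.lag('transaction_date').over(window))
    .withColumn('days_since_prev', F.datediff('transaction_date', 'prev_order_date'))
    # first ever purchase -> New; previous purchase within 365 days -> Retained;
    # otherwise there was an earlier purchase but none in the last 12 months -> Reactivated
    .withColumn(
        'status',
        F.when(F.col('prev_order_date').isNull(), F.lit('New'))
         .when(F.col('days_since_prev') <= 365, F.lit('Retained'))
         .otherwise(F.lit('Reactivated'))
    )
).show()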

Here's my implementation: generate a table of all the time periods for the current year, then cross join that table with all the transactions and perform the calculations. The final output assumes you want the status per customer for each period.
from pyspark.sql import functions as F
from pyspark.sql.types import *
from pyspark.sql import Window
from datetime import *
from dateutil.relativedelta import relativedelta
year_start = date(datetime.today().year, 1, 1)
n_months = 12
date_list = [year_start + relativedelta(months=i) for i in range(n_months)]
dates_formatted = [(d.strftime("%b %Y"),) for d in date_list]
period_df = spark.createDataFrame(dates_formatted, ["period"])
period_df.show()
current_time_period = "SEP 2022"
df = spark.createDataFrame(
    [
        ("1", "1", "2019-JAN-10"),
        ("1", "2", "2019-DEC-15"),
        ("1", "3", "2022-SEP-07"),
    ],
    ["customer_id", "transaction_id", "transaction_date"],
)

cols = [
    "customer_id",
    "txn_id",
    "txn_date",
    "period",
    "txn_current",
    "txn_prev_12m",
    "txn_prior_prev_12m",
]

txn_period_df = df.crossJoin(period_df)
txn_period_df.show(n=100, truncate=False)
# intermediate:
df = (
    txn_period_df
    .withColumnRenamed("transaction_id", "txn_id")
    .withColumnRenamed("transaction_date", "txn_date")
    .withColumn("txn_date_formatted", F.to_date(F.col("txn_date"), "yyyy-MMM-dd"))
    # add one month so the half-open window [start, end) covers transactions
    # through the end of the named month
    .withColumn("current_date", F.add_months(F.to_date(F.col("period"), "MMM yyyy"), 1))
    # any purchase from October 2021 to September 2022 falls into the 'current' time period
    .withColumn("txn_current_date", F.add_months(F.col("current_date"), -12))
    .withColumn(
        "txn_current",
        F.when(
            (F.col("txn_date_formatted") >= F.col("txn_current_date"))
            & (F.col("txn_date_formatted") < F.col("current_date")), F.lit(1)
        ).otherwise(F.lit(0)),
    )
    # purchases from October 2020 to September 2021 fall into the 'previous 12 months'
    .withColumn("txn_prev_12m_date", F.add_months(F.col("txn_current_date"), -12))
    .withColumn(
        "txn_prev_12m",
        F.when(
            (F.col("txn_date_formatted") >= F.col("txn_prev_12m_date"))
            & (F.col("txn_date_formatted") < F.col("txn_current_date")), F.lit(1)
        ).otherwise(F.lit(0)),
    )
    # purchases from September 2020 and earlier are 'prior to the previous 12 months'
    .withColumn(
        "txn_prior_prev_12m",
        F.when(
            F.col("txn_date_formatted") < F.col("txn_prev_12m_date"), F.lit(1)
        ).otherwise(F.lit(0)),
    )
    .select(cols)
)
df.show()
cols = [
    "customer_id",
    "txn_period",
    "txn_current",
    "txn_prev_12m",
    "txn_prior_prev_12m",
    "status",
]

txn_agg_window = Window.partitionBy(
    "customer_id",
    "txn_period",
).orderBy(F.col("customer_id"))

# final:
final_df = (
    df
    .withColumnRenamed("period", "txn_period")
    .withColumn(
        "txn_current",
        F.sum("txn_current").over(txn_agg_window),
    )
    .withColumn(
        "txn_prev_12m",
        F.sum("txn_prev_12m").over(txn_agg_window),
    )
    .withColumn(
        "txn_prior_prev_12m",
        F.sum("txn_prior_prev_12m").over(txn_agg_window),
    )
    .withColumn(
        "row_num",
        F.row_number().over(txn_agg_window),
    )
    .filter(F.col("row_num") == 1)
    .drop("row_num")
    .withColumn(
        "status",
        F.when(
            (F.col("txn_prior_prev_12m") > 0)
            & (F.col("txn_prev_12m") == 0)
            & (F.col("txn_current") > 0), F.lit("Reactivated")
        )
        .when(
            (F.col("txn_prev_12m") > 0)
            & (F.col("txn_current") > 0), F.lit("Retained")
        )
        # 'New' means no purchase history before this period but a purchase in it,
        # so the current-period flag must be > 0 here
        .when(
            (F.col("txn_prior_prev_12m") == 0)
            & (F.col("txn_prev_12m") == 0)
            & (F.col("txn_current") > 0), F.lit("New")
        )
        .otherwise(F.lit("NA").cast(StringType())),
    )
    .select(cols)
    .orderBy(F.to_date(F.col("txn_period"), "MMM yyyy").asc())
)
final_df.show()
periods:
+--------+
| period|
+--------+
|Jan 2022|
|Feb 2022|
|Mar 2022|
|Apr 2022|
|May 2022|
|Jun 2022|
|Jul 2022|
|Aug 2022|
|Sep 2022|
|Oct 2022|
|Nov 2022|
|Dec 2022|
+--------+
intermediate:
+-----------+------+-----------+--------+-----------+------------+------------------+
|customer_id|txn_id| txn_date| period|txn_current|txn_prev_12m|txn_prior_prev_12m|
+-----------+------+-----------+--------+-----------+------------+------------------+
| 1| 1|2019-JAN-10|Jan 2022| 0| 0| 1|
| 1| 1|2019-JAN-10|Feb 2022| 0| 0| 1|
| 1| 1|2019-JAN-10|Mar 2022| 0| 0| 1|
| 1| 1|2019-JAN-10|Apr 2022| 0| 0| 1|
| 1| 1|2019-JAN-10|May 2022| 0| 0| 1|
| 1| 1|2019-JAN-10|Jun 2022| 0| 0| 1|
| 1| 1|2019-JAN-10|Jul 2022| 0| 0| 1|
| 1| 1|2019-JAN-10|Aug 2022| 0| 0| 1|
| 1| 1|2019-JAN-10|Sep 2022| 0| 0| 1|
| 1| 1|2019-JAN-10|Oct 2022| 0| 0| 1|
| 1| 1|2019-JAN-10|Nov 2022| 0| 0| 1|
| 1| 1|2019-JAN-10|Dec 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|Jan 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|Feb 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|Mar 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|Apr 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|May 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|Jun 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|Jul 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|Aug 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|Sep 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|Oct 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|Nov 2022| 0| 0| 1|
| 1| 2|2019-DEC-15|Dec 2022| 0| 0| 1|
| 1| 3|2022-SEP-07|Jan 2022| 0| 0| 0|
| 1| 3|2022-SEP-07|Feb 2022| 0| 0| 0|
| 1| 3|2022-SEP-07|Mar 2022| 0| 0| 0|
| 1| 3|2022-SEP-07|Apr 2022| 0| 0| 0|
| 1| 3|2022-SEP-07|May 2022| 0| 0| 0|
| 1| 3|2022-SEP-07|Jun 2022| 0| 0| 0|
| 1| 3|2022-SEP-07|Jul 2022| 0| 0| 0|
| 1| 3|2022-SEP-07|Aug 2022| 0| 0| 0|
| 1| 3|2022-SEP-07|Sep 2022| 1| 0| 0|
| 1| 3|2022-SEP-07|Oct 2022| 1| 0| 0|
| 1| 3|2022-SEP-07|Nov 2022| 1| 0| 0|
| 1| 3|2022-SEP-07|Dec 2022| 1| 0| 0|
+-----------+------+-----------+--------+-----------+------------+------------------+
final:
+-----------+----------+-----------+------------+------------------+-----------+
|customer_id|txn_period|txn_current|txn_prev_12m|txn_prior_prev_12m| status|
+-----------+----------+-----------+------------+------------------+-----------+
| 1| Jan 2022| 0| 0| 2| NA|
| 1| Feb 2022| 0| 0| 2| NA|
| 1| Mar 2022| 0| 0| 2| NA|
| 1| Apr 2022| 0| 0| 2| NA|
| 1| May 2022| 0| 0| 2| NA|
| 1| Jun 2022| 0| 0| 2| NA|
| 1| Jul 2022| 0| 0| 2| NA|
| 1| Aug 2022| 0| 0| 2| NA|
| 1| Sep 2022| 1| 0| 2|Reactivated|
| 1| Oct 2022| 1| 0| 2|Reactivated|
| 1| Nov 2022| 1| 0| 2|Reactivated|
| 1| Dec 2022| 1| 0| 2|Reactivated|
+-----------+----------+-----------+------------+------------------+-----------+

Related

Window function based on a condition

I have the following DF:
|-----------------------|
|Date | Val | Cond|
|-----------------------|
|2022-01-08 | 2 | 0 |
|2022-01-09 | 4 | 1 |
|2022-01-10 | 6 | 1 |
|2022-01-11 | 8 | 0 |
|2022-01-12 | 2 | 1 |
|2022-01-13 | 5 | 1 |
|2022-01-14 | 7 | 0 |
|2022-01-15 | 9 | 0 |
|-----------------------|
I need to sum, for every date, the values of the two previous dates where cond = 1; my expected output is:
|-----------------|
|Date | Sum |
|-----------------|
|2022-01-08 | 0 | No sum because there are not two dates with cond = 1 before this date
|2022-01-09 | 0 | No sum because there are not two dates with cond = 1 before this date
|2022-01-10 | 0 | No sum because there are not two dates with cond = 1 before this date
|2022-01-11 | 10 | (4+6)
|2022-01-12 | 10 | (4+6)
|2022-01-13 | 8 | (2+6)
|2022-01-14 | 7 | (5+2)
|2022-01-15 | 7 | (5+2)
|-----------------|
I've tried to get the output DF using this code:
df = df.where("Cond = 1").withColumn(
    "ListView",
    f.collect_list("Val").over(windowSpec.rowsBetween(-2, -1))
)
But when I use .where("Cond = 1") I exclude the dates where cond is equal to zero.
I found the following answer, but it didn't help me:
Window.rowsBetween - only consider rows fulfilling a specific condition (e.g. not being null)
How can I achieve my expected output using window functions?
The MVCE:
data_1 = [
    ("2022-01-08", 2, 0),
    ("2022-01-09", 4, 1),
    ("2022-01-10", 6, 1),
    ("2022-01-11", 8, 0),
    ("2022-01-12", 2, 1),
    ("2022-01-13", 5, 1),
    ("2022-01-14", 7, 0),
    ("2022-01-15", 9, 0)
]

# note: the dates above are strings, so DateType() will be rejected by createDataFrame;
# either convert them to datetime.date or use StringType plus to_date (as in the
# answer's setup below)
schema_1 = StructType([
    StructField("Date", DateType(), True),
    StructField("Val", IntegerType(), True),
    StructField("Cond", IntegerType(), True)
])

df_1 = spark.createDataFrame(data=data_1, schema=schema_1)
The following should do the trick (but I'm sure it can be further optimized).
Setup:
# imports for the functions and types used throughout this answer
from pyspark.sql import Window
from pyspark.sql.functions import col, max, sum, to_date, when
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

data_1 = [
    ("2022-01-08", 2, 0),
    ("2022-01-09", 4, 1),
    ("2022-01-10", 6, 1),
    ("2022-01-11", 8, 0),
    ("2022-01-12", 2, 1),
    ("2022-01-13", 5, 1),
    ("2022-01-14", 7, 0),
    ("2022-01-15", 9, 0),
    ("2022-01-16", 9, 0),
    ("2022-01-17", 9, 0)
]

schema_1 = StructType([
    StructField("Date", StringType(), True),
    StructField("Val", IntegerType(), True),
    StructField("Cond", IntegerType(), True)
])

df_1 = spark.createDataFrame(data=data_1, schema=schema_1)
df_1 = df_1.withColumn('Date', to_date("Date", "yyyy-MM-dd"))
+----------+---+----+
| Date|Val|Cond|
+----------+---+----+
|2022-01-08| 2| 0|
|2022-01-09| 4| 1|
|2022-01-10| 6| 1|
|2022-01-11| 8| 0|
|2022-01-12| 2| 1|
|2022-01-13| 5| 1|
|2022-01-14| 7| 0|
|2022-01-15| 9| 0|
|2022-01-16| 9| 0|
|2022-01-17| 9| 0|
+----------+---+----+
Create a new DF only with Cond==1 rows to obtain the sum of two consecutive rows with that condition:
windowSpec = Window.partitionBy("Cond").orderBy("Date")

df_2 = df_1.where(df_1.Cond == 1).withColumn(
    "Sum",
    sum("Val").over(windowSpec.rowsBetween(-1, 0))
).withColumn('date_1', col('date')).drop('date')
+---+----+---+----------+
|Val|Cond|Sum| date_1|
+---+----+---+----------+
| 4| 1| 4|2022-01-09|
| 6| 1| 10|2022-01-10|
| 2| 1| 8|2022-01-12|
| 5| 1| 7|2022-01-13|
+---+----+---+----------+
Do a left join to get the sum into the original data frame, and set the sum to zero for the rows with Cond==0:
df_3 = df_1.join(df_2.select('sum', col('date_1')), df_1.Date == df_2.date_1, "left").drop('date_1').fillna(0)
+----------+---+----+---+
| Date|Val|Cond|sum|
+----------+---+----+---+
|2022-01-08| 2| 0| 0|
|2022-01-09| 4| 1| 4|
|2022-01-10| 6| 1| 10|
|2022-01-11| 8| 0| 0|
|2022-01-12| 2| 1| 8|
|2022-01-13| 5| 1| 7|
|2022-01-14| 7| 0| 0|
|2022-01-15| 9| 0| 0|
|2022-01-16| 9| 0| 0|
|2022-01-17| 9| 0| 0|
+----------+---+----+---+
Do a cumulative sum on the condition column:
df_3=df_3.withColumn('cond_sum', sum('cond').over(Window.orderBy('Date')))
+----------+---+----+---+--------+
| Date|Val|Cond|sum|cond_sum|
+----------+---+----+---+--------+
|2022-01-08| 2| 0| 0| 0|
|2022-01-09| 4| 1| 4| 1|
|2022-01-10| 6| 1| 10| 2|
|2022-01-11| 8| 0| 0| 2|
|2022-01-12| 2| 1| 8| 3|
|2022-01-13| 5| 1| 7| 4|
|2022-01-14| 7| 0| 0| 4|
|2022-01-15| 9| 0| 0| 4|
|2022-01-16| 9| 0| 0| 4|
|2022-01-17| 9| 0| 0| 4|
+----------+---+----+---+--------+
Finally, for each partition where the cond_sum is greater than 1, use the max sum for that partition:
df_3.withColumn('sum', when(df_3.cond_sum > 1, max('sum').over(Window.partitionBy('cond_sum'))).otherwise(0)).show()
+----------+---+----+---+--------+
| Date|Val|Cond|sum|cond_sum|
+----------+---+----+---+--------+
|2022-01-08| 2| 0| 0| 0|
|2022-01-09| 4| 1| 0| 1|
|2022-01-10| 6| 1| 10| 2|
|2022-01-11| 8| 0| 10| 2|
|2022-01-12| 2| 1| 8| 3|
|2022-01-13| 5| 1| 7| 4|
|2022-01-14| 7| 0| 7| 4|
|2022-01-15| 9| 0| 7| 4|
|2022-01-16| 9| 0| 7| 4|
|2022-01-17| 9| 0| 7| 4|
+----------+---+----+---+--------+

how to join dataframes with some similar values and multiple keys / scala

I am having trouble producing the following table. The first two tables are my source tables, which I would like to join; the third table shows how I would like the result to look.
I tried an outer join using the keys "ID" and "date", but the result is not the same as in this example. The problem is that some def_ values in the two tables share the same date, and I would like to get them into the same row.
I used following join:
val df_result = df_1.join(df_2, Seq("ID", "date"), "outer")
df
+----+-----+-----------+
|ID |def_a| date |
+----+-----+-----------+
| 01| 1| 2019-01-31|
| 02| 1| 2019-12-31|
| 03| 1| 2019-11-30|
| 01| 1| 2019-10-31|
df
+----+-----+-----+-----------+
|ID |def_b|def_c|date |
+----+-----+-----+-----------+
| 01| 1| 0| 2017-01-31|
| 02| 1| 1| 2019-12-31|
| 03| 1| 1| 2018-11-30|
| 03| 0| 1| 2019-11-30|
| 01| 1| 1| 2018-09-30|
| 02| 1| 1| 2018-08-31|
| 01| 1| 1| 2018-07-31|
result
+----+-----+-----+-----+-----------+
|ID |def_a|def_b|def_c|date |
+----+-----+-----+-----+-----------+
| 01| 1| 0| 0| 2019-01-31|
| 02| 1| 1| 1| 2019-12-31|
| 03| 1| 0| 1| 2019-11-30|
| 01| 1| 0| 0| 2019-10-31|
| 01| 0| 1| 0| 2017-01-31|
| 03| 0| 1| 1| 2018-11-30|
| 01| 0| 1| 1| 2018-09-30|
| 02| 0| 1| 1| 2018-08-31|
| 01| 0| 1| 1| 2018-07-31|
I would be grateful for any help.
Hope the following code helps. Note that the aggregation needs the actual column names (def_a, def_b, def_c), and the nulls left by the outer join should be filled with 0 so the output matches the desired result:
df_result
  .na.fill(0)
  .groupBy("ID", "date")
  .agg(
    max("def_a").as("def_a"),
    max("def_b").as("def_b"),
    max("def_c").as("def_c")
  )

Filtering on multiple columns in Spark dataframes

Suppose I have a dataframe in Spark as shown below -
val df = Seq(
  (0, 0, 0, 0.0),
  (1, 0, 0, 0.1),
  (0, 1, 0, 0.11),
  (0, 0, 1, 0.12),
  (1, 1, 0, 0.24),
  (1, 0, 1, 0.27),
  (0, 1, 1, 0.30),
  (1, 1, 1, 0.40)
).toDF("A", "B", "C", "rate")
Here is how it looks like -
scala> df.show()
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 0| 0| 0| 0.0|
| 1| 0| 0| 0.1|
| 0| 1| 0|0.11|
| 0| 0| 1|0.12|
| 1| 1| 0|0.24|
| 1| 0| 1|0.27|
| 0| 1| 1| 0.3|
| 1| 1| 1| 0.4|
+---+---+---+----+
A, B and C are the advertising channels in this case. 0 and 1 represent absence and presence of a channel respectively, so the dataframe holds all 2^3 = 8 combinations.
I want to filter the records from this dataframe that show the presence of exactly 2 channels at a time (AB, AC, BC). Here is how I want my output to be -
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 1| 1| 0|0.24|
| 1| 0| 1|0.27|
| 0| 1| 1| 0.3|
+---+---+---+----+
I can write 3 statements to get the output by doing -
scala> df.filter($"A" === 1 && $"B" === 1 && $"C" === 0).show()
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 1| 1| 0|0.24|
+---+---+---+----+
scala> df.filter($"A" === 1 && $"B" === 0 && $"C" === 1).show()
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 1| 0| 1|0.27|
+---+---+---+----+
scala> df.filter($"A" === 0 && $"B" === 1 && $"C" === 1).show()
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 0| 1| 1| 0.3|
+---+---+---+----+
However, I want to achieve this using either a single statement that does my job or a function that helps me get the output.
I was thinking of using a case statement to match the values. However in general my dataframe might consist of more than 3 channels -
scala> df.show()
+---+---+---+---+----+
| A| B| C| D|rate|
+---+---+---+---+----+
| 0| 0| 0| 0| 0.0|
| 0| 0| 0| 1| 0.1|
| 0| 0| 1| 0| 0.1|
| 0| 0| 1| 1|0.59|
| 0| 1| 0| 0| 0.1|
| 0| 1| 0| 1|0.89|
| 0| 1| 1| 0|0.39|
| 0| 1| 1| 1| 0.4|
| 1| 0| 0| 0| 0.0|
| 1| 0| 0| 1|0.99|
| 1| 0| 1| 0|0.49|
| 1| 0| 1| 1| 0.1|
| 1| 1| 0| 0|0.79|
| 1| 1| 0| 1| 0.1|
| 1| 1| 1| 0| 0.1|
| 1| 1| 1| 1| 0.1|
+---+---+---+---+----+
In this scenario I would want my output as -
scala> df.show()
+---+---+---+---+----+
| A| B| C| D|rate|
+---+---+---+---+----+
| 0| 0| 1| 1|0.59|
| 0| 1| 0| 1|0.89|
| 0| 1| 1| 0|0.39|
| 1| 0| 0| 1|0.99|
| 1| 0| 1| 0|0.49|
| 1| 1| 0| 0|0.79|
+---+---+---+---+----+
which shows rates for paired presence of channels => (AB, AC, AD, BC, BD, CD).
Kindly help.
One way could be to sum the columns and then filter only when the result of the sum is 2.
import org.apache.spark.sql.functions._
df.withColumn("res", $"A" + $"B" + $"C").filter($"res" === lit(2)).drop("res").show
The output is:
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 1| 1| 0|0.24|
| 1| 0| 1|0.27|
| 0| 1| 1| 0.3|
+---+---+---+----+

How to enumerate the rows of a dataframe? Spark Scala

I have a dataframe (renderDF) like this:
+------+---+-------+
| uid|sid|renders|
+------+---+-------+
| david| 0| 0|
|rachel| 1| 0|
|rachel| 3| 0|
|rachel| 2| 0|
| pep| 2| 0|
| pep| 0| 1|
| pep| 1| 1|
|rachel| 0| 1|
| rick| 1| 1|
| ross| 0| 3|
| rick| 0| 3|
+------+---+-------+
I want to use a window function to achieve this result
+------+---+-------+-----------+
| uid|sid|renders|row_number |
+------+---+-------+-----------+
| david| 0| 0| 1 |
|rachel| 1| 0| 2 |
|rachel| 3| 0| 3 |
|rachel| 2| 0| 4 |
| pep| 2| 0| 5 |
| pep| 0| 1| 6 |
| pep| 1| 1| 7 |
|rachel| 0| 1| 8 |
| rick| 1| 1| 9 |
| ross| 0| 3| 10 |
| rick| 0| 3| 11 |
+------+---+-------+-----------+
I tried:
val windowRender = Window.partitionBy('sid).orderBy('Renders)
renderDF.withColumn("row_number", row_number() over windowRender)
But it doesn't do what I need.
Is the partition my problem?
try this:
val dfWithRownumber = renderDF.withColumn("row_number", row_number.over(Window.partitionBy(lit(1)).orderBy("renders")))

how to filter out a null value from spark dataframe

I created a dataframe in spark with the following schema:
root
|-- user_id: long (nullable = false)
|-- event_id: long (nullable = false)
|-- invited: integer (nullable = false)
|-- day_diff: long (nullable = true)
|-- interested: integer (nullable = false)
|-- event_owner: long (nullable = false)
|-- friend_id: long (nullable = false)
And the data is shown below:
+----------+----------+-------+--------+----------+-----------+---------+
| user_id| event_id|invited|day_diff|interested|event_owner|friend_id|
+----------+----------+-------+--------+----------+-----------+---------+
| 4236494| 110357109| 0| -1| 0| 937597069| null|
| 78065188| 498404626| 0| 0| 0| 2904922087| null|
| 282487230|2520855981| 0| 28| 0| 3749735525| null|
| 335269852|1641491432| 0| 2| 0| 1490350911| null|
| 437050836|1238456614| 0| 2| 0| 991277599| null|
| 447244169|2095085551| 0| -1| 0| 1579858878| null|
| 516353916|1076364848| 0| 3| 1| 3597645735| null|
| 528218683|1151525474| 0| 1| 0| 3433080956| null|
| 531967718|3632072502| 0| 1| 0| 3863085861| null|
| 627948360|2823119321| 0| 0| 0| 4092665803| null|
| 811791433|3513954032| 0| 2| 0| 415464198| null|
| 830686203| 99027353| 0| 0| 0| 3549822604| null|
|1008893291|1115453150| 0| 2| 0| 2245155244| null|
|1239364869|2824096896| 0| 2| 1| 2579294650| null|
|1287950172|1076364848| 0| 0| 0| 3597645735| null|
|1345896548|2658555390| 0| 1| 0| 2025118823| null|
|1354205322|2564682277| 0| 3| 0| 2563033185| null|
|1408344828|1255629030| 0| -1| 1| 804901063| null|
|1452633375|1334001859| 0| 4| 0| 1488588320| null|
|1625052108|3297535757| 0| 3| 0| 1972598895| null|
+----------+----------+-------+--------+----------+-----------+---------+
I want to filter out the rows have null values in the field of "friend_id".
scala> val aaa = test.filter("friend_id is null")
scala> aaa.count
I got res52: Long = 0, which is obviously not right. What is the right way to get it?
One more question: I want to replace the values in the friend_id field. I want to replace null with 0 and any other value with 1. The code I can figure out is:
val aaa = train_friend_join.select($"user_id", $"event_id", $"invited", $"day_diff", $"interested", $"event_owner", ($"friend_id" != null)?1:0)
This code also doesn't work. Can anyone tell me how I can fix it? Thanks
Let's say you have this data setup (so that results are reproducible):
// declaring data types
case class Company(cName: String, cId: String, details: String)
case class Employee(name: String, id: String, email: String, company: Company)
// setting up example data
val e1 = Employee("n1", null, "n1#c1.com", Company("c1", "1", "d1"))
val e2 = Employee("n2", "2", "n2#c1.com", Company("c1", "1", "d1"))
val e3 = Employee("n3", "3", "n3#c1.com", Company("c1", "1", "d1"))
val e4 = Employee("n4", "4", "n4#c2.com", Company("c2", "2", "d2"))
val e5 = Employee("n5", null, "n5#c2.com", Company("c2", "2", "d2"))
val e6 = Employee("n6", "6", "n6#c2.com", Company("c2", "2", "d2"))
val e7 = Employee("n7", "7", "n7#c3.com", Company("c3", "3", "d3"))
val e8 = Employee("n8", "8", "n8#c3.com", Company("c3", "3", "d3"))
val employees = Seq(e1, e2, e3, e4, e5, e6, e7, e8)
val df = sc.parallelize(employees).toDF
Data is:
+----+----+---------+---------+
|name| id| email| company|
+----+----+---------+---------+
| n1|null|n1#c1.com|[c1,1,d1]|
| n2| 2|n2#c1.com|[c1,1,d1]|
| n3| 3|n3#c1.com|[c1,1,d1]|
| n4| 4|n4#c2.com|[c2,2,d2]|
| n5|null|n5#c2.com|[c2,2,d2]|
| n6| 6|n6#c2.com|[c2,2,d2]|
| n7| 7|n7#c3.com|[c3,3,d3]|
| n8| 8|n8#c3.com|[c3,3,d3]|
+----+----+---------+---------+
Now to filter employees with null ids, you will do --
df.filter("id is null").show
which will correctly show you following:
+----+----+---------+---------+
|name| id| email| company|
+----+----+---------+---------+
| n1|null|n1#c1.com|[c1,1,d1]|
| n5|null|n5#c2.com|[c2,2,d2]|
+----+----+---------+---------+
Coming to the second part of your question, you can replace the null ids with 0 and other values with 1 with this --
df.withColumn("id", when($"id".isNull, 0).otherwise(1)).show
This results in:
+----+---+---------+---------+
|name| id| email| company|
+----+---+---------+---------+
| n1| 0|n1#c1.com|[c1,1,d1]|
| n2| 1|n2#c1.com|[c1,1,d1]|
| n3| 1|n3#c1.com|[c1,1,d1]|
| n4| 1|n4#c2.com|[c2,2,d2]|
| n5| 0|n5#c2.com|[c2,2,d2]|
| n6| 1|n6#c2.com|[c2,2,d2]|
| n7| 1|n7#c3.com|[c3,3,d3]|
| n8| 1|n8#c3.com|[c3,3,d3]|
+----+---+---------+---------+
Or like df.filter($"friend_id".isNotNull)
df.where(df.col("friend_id").isNull)
There are two ways to create the filter condition: 1) manually, 2) dynamically.
Sample DataFrame:
val df = spark.createDataFrame(Seq(
  (0, "a1", "b1", "c1", "d1"),
  (1, "a2", "b2", "c2", "d2"),
  (2, "a3", "b3", null, "d3"),
  (3, "a4", null, "c4", "d4"),
  (4, null, "b5", "c5", "d5")
)).toDF("id", "col1", "col2", "col3", "col4")
+---+----+----+----+----+
| id|col1|col2|col3|col4|
+---+----+----+----+----+
| 0| a1| b1| c1| d1|
| 1| a2| b2| c2| d2|
| 2| a3| b3|null| d3|
| 3| a4|null| c4| d4|
| 4|null| b5| c5| d5|
+---+----+----+----+----+
1) Creating the filter condition manually, i.e. using the DataFrame where or filter function:
df.filter(col("col1").isNotNull && col("col2").isNotNull).show
or
df.where("col1 is not null and col2 is not null").show
Result:
+---+----+----+----+----+
| id|col1|col2|col3|col4|
+---+----+----+----+----+
| 0| a1| b1| c1| d1|
| 1| a2| b2| c2| d2|
| 2| a3| b3|null| d3|
+---+----+----+----+----+
2) Creating the filter condition dynamically: this is useful when we don't want any column to have a null value and there are a large number of columns, which is usually the case.
Creating the filter condition manually in these cases wastes a lot of time. In the code below we include all columns dynamically using map and reduce on the DataFrame's columns:
val filterCond = df.columns.map(x=>col(x).isNotNull).reduce(_ && _)
How filterCond looks:
filterCond: org.apache.spark.sql.Column = (((((id IS NOT NULL) AND (col1 IS NOT NULL)) AND (col2 IS NOT NULL)) AND (col3 IS NOT NULL)) AND (col4 IS NOT NULL))
Filtering:
val filteredDf = df.filter(filterCond)
Result:
+---+----+----+----+----+
| id|col1|col2|col3|col4|
+---+----+----+----+----+
| 0| a1| b1| c1| d1|
| 1| a2| b2| c2| d2|
+---+----+----+----+----+
A good solution for me was to drop the rows with any null values:
Dataset<Row> filtered = df.filter(row => !row.anyNull);
If you are interested in the opposite case (rows that do contain nulls), just use row.anyNull instead.
(Spark 2.1.0 using the Java API)
The following lines work well:
test.filter("friend_id is not null")
Based on the hint from Michael Kopaniov, the following works:
df.where(df("id").isNotNull).show
Here is a solution for Spark in Java. To select the rows containing nulls, when you have Dataset<Row> data, you do:
Dataset<Row> containingNulls = data.where(data.col("COLUMN_NAME").isNull())
To filter out data without nulls you do:
Dataset<Row> withoutNulls = data.where(data.col("COLUMN_NAME").isNotNull())
Often dataframes contain columns of type String where instead of nulls we have empty strings like "". To filter out such data as well we do:
Dataset<Row> withoutNullsAndEmpty = data.where(data.col("COLUMN_NAME").isNotNull().and(data.col("COLUMN_NAME").notEqual("")))
For the first question: it is correct, you are filtering out the nulls and hence the count is zero.
For the second part (the replacement), use something like below:
val options = Map("path" -> "...\\ex.csv", "header" -> "true")
val dfNull = spark.sqlContext.load("com.databricks.spark.csv", options)
scala> dfNull.show
+----------+----------+-------+--------+----------+-----------+---------+
| user_id| event_id|invited|day_diff|interested|event_owner|friend_id|
+----------+----------+-------+--------+----------+-----------+---------+
| 4236494| 110357109| 0| -1| 0| 937597069| null|
| 78065188| 498404626| 0| 0| 0| 2904922087| null|
| 282487230|2520855981| 0| 28| 0| 3749735525| null|
| 335269852|1641491432| 0| 2| 0| 1490350911| null|
| 437050836|1238456614| 0| 2| 0| 991277599| null|
| 447244169|2095085551| 0| -1| 0| 1579858878| a|
| 516353916|1076364848| 0| 3| 1| 3597645735| b|
| 528218683|1151525474| 0| 1| 0| 3433080956| c|
| 531967718|3632072502| 0| 1| 0| 3863085861| null|
| 627948360|2823119321| 0| 0| 0| 4092665803| null|
| 811791433|3513954032| 0| 2| 0| 415464198| null|
| 830686203| 99027353| 0| 0| 0| 3549822604| null|
|1008893291|1115453150| 0| 2| 0| 2245155244| null|
|1239364869|2824096896| 0| 2| 1| 2579294650| d|
|1287950172|1076364848| 0| 0| 0| 3597645735| null|
|1345896548|2658555390| 0| 1| 0| 2025118823| null|
|1354205322|2564682277| 0| 3| 0| 2563033185| null|
|1408344828|1255629030| 0| -1| 1| 804901063| null|
|1452633375|1334001859| 0| 4| 0| 1488588320| null|
|1625052108|3297535757| 0| 3| 0| 1972598895| null|
+----------+----------+-------+--------+----------+-----------+---------+
dfNull.withColumn("friend_idTmp", when($"friend_id".isNull, "1").otherwise("0")).drop($"friend_id").withColumnRenamed("friend_idTmp", "friend_id").show
+----------+----------+-------+--------+----------+-----------+---------+
| user_id| event_id|invited|day_diff|interested|event_owner|friend_id|
+----------+----------+-------+--------+----------+-----------+---------+
| 4236494| 110357109| 0| -1| 0| 937597069| 1|
| 78065188| 498404626| 0| 0| 0| 2904922087| 1|
| 282487230|2520855981| 0| 28| 0| 3749735525| 1|
| 335269852|1641491432| 0| 2| 0| 1490350911| 1|
| 437050836|1238456614| 0| 2| 0| 991277599| 1|
| 447244169|2095085551| 0| -1| 0| 1579858878| 0|
| 516353916|1076364848| 0| 3| 1| 3597645735| 0|
| 528218683|1151525474| 0| 1| 0| 3433080956| 0|
| 531967718|3632072502| 0| 1| 0| 3863085861| 1|
| 627948360|2823119321| 0| 0| 0| 4092665803| 1|
| 811791433|3513954032| 0| 2| 0| 415464198| 1|
| 830686203| 99027353| 0| 0| 0| 3549822604| 1|
|1008893291|1115453150| 0| 2| 0| 2245155244| 1|
|1239364869|2824096896| 0| 2| 1| 2579294650| 0|
|1287950172|1076364848| 0| 0| 0| 3597645735| 1|
|1345896548|2658555390| 0| 1| 0| 2025118823| 1|
|1354205322|2564682277| 0| 3| 0| 2563033185| 1|
|1408344828|1255629030| 0| -1| 1| 804901063| 1|
|1452633375|1334001859| 0| 4| 0| 1488588320| 1|
|1625052108|3297535757| 0| 3| 0| 1972598895| 1|
+----------+----------+-------+--------+----------+-----------+---------+
val df = Seq(
  ("1001", "1007"),
  ("1002", null),
  ("1003", "1005"),
  (null, "1006")
).toDF("user_id", "friend_id")
Data is:
+-------+---------+
|user_id|friend_id|
+-------+---------+
| 1001| 1007|
| 1002| null|
| 1003| 1005|
| null| 1006|
+-------+---------+
Drop rows containing any null or NaN values in the specified columns of the Seq:
df.na.drop(Seq("friend_id"))
.show()
Output:
+-------+---------+
|user_id|friend_id|
+-------+---------+
| 1001| 1007|
| 1003| 1005|
| null| 1006|
+-------+---------+
If do not specify columns, drop row as long as any column of a row contains null or NaN values:
df.na.drop()
.show()
Output:
+-------+---------+
|user_id|friend_id|
+-------+---------+
| 1001| 1007|
| 1003| 1005|
+-------+---------+
Another easy way to filter out null values from multiple columns in a Spark dataframe. Note that COALESCE only returns null when all of the listed columns are null, so this keeps a row as long as at least one of those columns has a value:
df.filter(" COALESCE(col1, col2, col3, col4, col5, col6) IS NOT NULL")
If you instead need to drop rows that contain a null in any column, use
df.na.drop()
I used the following code to solve my question. It works, but as we all know it goes a country mile around the problem. So, is there a shortcut? Thanks
def filter_null(field: Any): Int = field match {
  case null => 0
  case _ => 1
}

val test = train_event_join.join(
  user_friends_pair,
  train_event_join("user_id") === user_friends_pair("user_id") &&
    train_event_join("event_owner") === user_friends_pair("friend_id"),
  "left"
).select(
  train_event_join("user_id"),
  train_event_join("event_id"),
  train_event_join("invited"),
  train_event_join("day_diff"),
  train_event_join("interested"),
  train_event_join("event_owner"),
  user_friends_pair("friend_id")
).rdd.map {
  line => (
    line(0).toString.toLong,
    line(1).toString.toLong,
    line(2).toString.toLong,
    line(3).toString.toLong,
    line(4).toString.toLong,
    line(5).toString.toLong,
    filter_null(line(6))
  )
}.toDF("user_id", "event_id", "invited", "day_diff", "interested", "event_owner", "creator_is_friend")