PySpark: How to map values without a join or using RDD?

I have a dataframe df that looks like this:
from pyspark.sql.functions import lit, col, create_map

df = spark.createDataFrame(
    [
        ("1", "A", "B", "2020-01-01", 6),
        ("2", "A", "B", "2020-01-01", 6),
        ("3", "A", "C", "2020-01-01", 6),
        ("4", "A", "C", "2020-01-01", 6),
        ("5", "B", "D", "2020-01-01", 10),
        ("6", "B", "D", "2020-01-01", 10),
    ],
    ["id", "map1", "map2", "date", "var"]
)
+---+----+----+----------+---+
| id|map1|map2| date|var|
+---+----+----+----------+---+
| 1| A| B|2020-01-01| 6|
| 2| A| B|2020-01-01| 6|
| 3| A| C|2020-01-01| 6|
| 4| A| C|2020-01-01| 6|
| 5| B| D|2020-01-01| 10|
| 6| B| D|2020-01-01| 10|
+---+----+----+----------+---+
Now I would like to map the map2 column through the map1 → var relationship, so that each row gets, in a new column, the var value associated with its map2 value (i.e. the var of the rows where that value appears as map1).
Note that for each map1 value the var value is the same across rows (A → 6, B → 10); map1 can never be null, but map2 can be null.
I want to do this without using join/RDD/UDF as far as possible, relying only on pure PySpark functions for performance.
First, I create a map column of key : value pairs:
df = df.withColumn("mapp", create_map('map1', 'var'))
I tried something like the following, but this obviously does not work dynamically:
df = df.withColumn('var_mapped', df["mapp"].getItem(df['map1']))
What are some solutions/functions to use in this case? Any help would be appreciated.

To get all the key-value combinations of the map across rows, you can use window functions. In this case, you can use collect_set of a struct of the columns map1 and var over an all-rows window, then create a map using map_from_entries. It would be something like this:
from pyspark.sql.functions import map_from_entries, collect_set, struct, col
from pyspark.sql.window import Window
df = df.withColumn("mapp", map_from_entries(collect_set(struct(col('map1'), col('var'))).over(Window.partitionBy())))
df.show()
+---+----+----+----------+---+-----------------+
| id|map1|map2| date|var| mapp|
+---+----+----+----------+---+-----------------+
| 1| A| B|2020-01-01| 6|[B -> 10, A -> 6]|
| 2| A| B|2020-01-01| 6|[B -> 10, A -> 6]|
| 3| A| C|2020-01-01| 6|[B -> 10, A -> 6]|
| 4| A| C|2020-01-01| 6|[B -> 10, A -> 6]|
| 5| B| D|2020-01-01| 10|[B -> 10, A -> 6]|
| 6| B| D|2020-01-01| 10|[B -> 10, A -> 6]|
+---+----+----+----------+---+-----------------+
After that, you can map column map2 using .getItem().
df = df.withColumn('res', col('mapp').getItem(col('map2'))).fillna(0)
df.show()
+---+----+----+----------+---+-----------------+---+
| id|map1|map2| date|var| mapp|res|
+---+----+----+----------+---+-----------------+---+
| 1| A| B|2020-01-01| 6|[B -> 10, A -> 6]| 10|
| 2| A| B|2020-01-01| 6|[B -> 10, A -> 6]| 10|
| 3| A| C|2020-01-01| 6|[B -> 10, A -> 6]| 0|
| 4| A| C|2020-01-01| 6|[B -> 10, A -> 6]| 0|
| 5| B| D|2020-01-01| 10|[B -> 10, A -> 6]| 0|
| 6| B| D|2020-01-01| 10|[B -> 10, A -> 6]| 0|
+---+----+----+----------+---+-----------------+---+
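If you prefer a single expression that starts from the original df and skips the intermediate mapp column, the lookup and the default can be combined with coalesce. A minimal sketch using the same functions as above (the 0 default mirrors the fillna(0) above):
from pyspark.sql.functions import map_from_entries, collect_set, struct, col, coalesce, lit
from pyspark.sql.window import Window

# Build the map1 -> var map over all rows, look up map2 in it,
# and fall back to 0 when map2 has no matching map1.
mapp = map_from_entries(collect_set(struct(col('map1'), col('var'))).over(Window.partitionBy()))
df = df.withColumn('res', coalesce(mapp.getItem(col('map2')), lit(0)))
Note that Window.partitionBy() with no columns pulls all rows into a single partition for the window computation, so this (like the answer above) can become a bottleneck on very large dataframes.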

Dataframe transform column for each day into rows [duplicate]

Is there an equivalent of Pandas Melt function in Apache Spark in PySpark or at least in Scala?
Until now I was running a sample dataset in Python, and now I want to use Spark for the entire dataset.
Spark >= 3.4
In Spark 3.4 or later you can use the built-in melt method:
(sdf
    .melt(
        ids=['A'], values=['B', 'C'],
        variableColumnName="variable",
        valueColumnName="value")
    .show())
+---+--------+-----+
| A|variable|value|
+---+--------+-----+
| a| B| 1|
| a| C| 2|
| b| B| 3|
| b| C| 4|
| c| B| 5|
| c| C| 6|
+---+--------+-----+
This method is available across all APIs, so it can be used in Scala:
sdf.melt(Array($"A"), Array($"B", $"C"), "variable", "value")
or in SQL:
SELECT * FROM sdf UNPIVOT (val FOR col in (col_1, col_2))
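As a concrete illustration of the SQL form, a minimal sketch against the example frame used here (assuming Spark >= 3.4, with id column A and value columns B and C) could look like this:
# Assumes Spark >= 3.4; register the example frame as a temp view first
sdf.createOrReplaceTempView("sdf")
spark.sql("""
    SELECT * FROM sdf
    UNPIVOT (value FOR variable IN (B, C))
""").show()
Columns not listed inside UNPIVOT (here A) are kept as id columns.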
Spark 3.2 (Python only, requires Pandas and pyarrow)
(sdf
    .to_koalas()
    .melt(id_vars=['A'], value_vars=['B', 'C'])
    .to_spark()
    .show())
+---+--------+-----+
| A|variable|value|
+---+--------+-----+
| a| B| 1|
| a| C| 2|
| b| B| 3|
| b| C| 4|
| c| B| 5|
| c| C| 6|
+---+--------+-----+
Spark < 3.2
There is no built-in function (if you work with SQL and have Hive support enabled you can use the stack function, but it is not exposed in Spark and has no native implementation), but it is trivial to roll your own. Required imports:
from pyspark.sql.functions import array, col, explode, lit, struct
from pyspark.sql import DataFrame
from typing import Iterable
Example implementation:
def melt(
        df: DataFrame,
        id_vars: Iterable[str], value_vars: Iterable[str],
        var_name: str = "variable", value_name: str = "value") -> DataFrame:
    """Convert :class:`DataFrame` from wide to long format."""

    # Create array<struct<variable: str, value: ...>>
    _vars_and_vals = array(*(
        struct(lit(c).alias(var_name), col(c).alias(value_name))
        for c in value_vars))

    # Add to the DataFrame and explode
    _tmp = df.withColumn("_vars_and_vals", explode(_vars_and_vals))

    cols = id_vars + [
        col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]]
    return _tmp.select(*cols)
And some tests (based on Pandas doctests):
import pandas as pd

pdf = pd.DataFrame({'A': {0: 'a', 1: 'b', 2: 'c'},
                    'B': {0: 1, 1: 3, 2: 5},
                    'C': {0: 2, 1: 4, 2: 6}})

pd.melt(pdf, id_vars=['A'], value_vars=['B', 'C'])
A variable value
0 a B 1
1 b B 3
2 c B 5
3 a C 2
4 b C 4
5 c C 6
sdf = spark.createDataFrame(pdf)
melt(sdf, id_vars=['A'], value_vars=['B', 'C']).show()
+---+--------+-----+
| A|variable|value|
+---+--------+-----+
| a| B| 1|
| a| C| 2|
| b| B| 3|
| b| C| 4|
| c| B| 5|
| c| C| 6|
+---+--------+-----+
Note: For use with legacy Python versions remove type annotations.
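As a small usage note, value_vars can of course be computed dynamically, e.g. to melt every column except the id columns (a sketch reusing the melt defined above):
# Melt every column except 'A', using the melt() defined above
value_cols = [c for c in sdf.columns if c != 'A']
melt(sdf, id_vars=['A'], value_vars=value_cols).show()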
Related:
R SparkR - equivalent to melt function
Gather in sparklyr
I came across this question in my search for an implementation of melt in Spark for Scala.
Posting my Scala port in case someone else also stumbles upon this.
import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame

/** Extends the [[org.apache.spark.sql.DataFrame]] class
  *
  * @param df the data frame to melt
  */
implicit class DataFrameFunctions(df: DataFrame) {

  /** Convert [[org.apache.spark.sql.DataFrame]] from wide to long format.
    *
    * melt is (kind of) the inverse of pivot
    * melt is currently (02/2017) not implemented in Spark
    *
    * @see reshape package in R (https://cran.r-project.org/web/packages/reshape/index.html)
    * @see this is a Scala adaptation of http://stackoverflow.com/questions/41670103/pandas-melt-function-in-apache-spark
    *
    * @todo method overloading for simple calling
    *
    * @param id_vars the columns to preserve
    * @param value_vars the columns to melt
    * @param var_name the name for the column holding the melted column names
    * @param value_name the name for the column holding the values of the melted columns
    */
  def melt(
      id_vars: Seq[String], value_vars: Seq[String],
      var_name: String = "variable", value_name: String = "value"): DataFrame = {

    // Create array<struct<variable: str, value: ...>>
    val _vars_and_vals = array((for (c <- value_vars) yield {
      struct(lit(c).alias(var_name), col(c).alias(value_name))
    }): _*)

    // Add to the DataFrame and explode
    val _tmp = df.withColumn("_vars_and_vals", explode(_vars_and_vals))

    val cols = id_vars.map(col _) ++ {
      for (x <- List(var_name, value_name)) yield {
        col("_vars_and_vals")(x).alias(x)
      }
    }

    _tmp.select(cols: _*)
  }
}
Since I'm not that advanced in Scala, I'm sure there is room for improvement.
Any comments are welcome.
I upvoted user6910411's answer. It works as expected; however, it cannot handle None values well, so I refactored his melt function to the following:
from pyspark.sql.functions import array, col, explode, lit
from pyspark.sql.functions import create_map
from pyspark.sql import DataFrame
from typing import Iterable
from itertools import chain
def melt(
        df: DataFrame,
        id_vars: Iterable[str], value_vars: Iterable[str],
        var_name: str = "variable", value_name: str = "value") -> DataFrame:
    """Convert :class:`DataFrame` from wide to long format."""

    # Create map<key: value>
    _vars_and_vals = create_map(
        list(chain.from_iterable([
            [lit(c), col(c)] for c in value_vars]
        ))
    )

    _tmp = df.select(*id_vars, explode(_vars_and_vals)) \
        .withColumnRenamed('key', var_name) \
        .withColumnRenamed('value', value_name)

    return _tmp
Testing it with the following dataframe:
import pandas as pd

pdf = pd.DataFrame({'A': {0: 'a', 1: 'b', 2: 'c'},
                    'B': {0: 1, 1: 3, 2: 5},
                    'C': {0: 2, 1: 4, 2: 6},
                    'D': {1: 7, 2: 9}})

pd.melt(pdf, id_vars=['A'], value_vars=['B', 'C', 'D'])
A variable value
0 a B 1.0
1 b B 3.0
2 c B 5.0
3 a C 2.0
4 b C 4.0
5 c C 6.0
6 a D NaN
7 b D 7.0
8 c D 9.0
sdf = spark.createDataFrame(pdf)
melt(sdf, id_vars=['A'], value_vars=['B', 'C', 'D']).show()
+---+--------+-----+
| A|variable|value|
+---+--------+-----+
| a| B| 1.0|
| a| C| 2.0|
| a| D| NaN|
| b| B| 3.0|
| b| C| 4.0|
| b| D| 7.0|
| c| B| 5.0|
| c| C| 6.0|
| c| D| 9.0|
+---+--------+-----+
UPD
Finally, I've found the most effective implementation for me. It uses all the cluster resources in my YARN configuration.
from pyspark.sql.functions import explode

def melt(df):
    sp = df.columns[1:]
    return (df
            .rdd
            .map(lambda x: [str(x[0]), [(str(i[0]),
                                         float(i[1] if i[1] else 0)) for i in zip(sp, x[1:])]],
                 preservesPartitioning=True)
            .toDF()
            .withColumn('_2', explode('_2'))
            .rdd.map(lambda x: [str(x[0]),
                                str(x[1][0]),
                                float(x[1][1] if x[1][1] else 0)],
                     preservesPartitioning=True)
            .toDF()
            )
For a very wide dataframe I saw performance degrade during the _vars_and_vals generation from user6910411's answer.
It was useful to implement melting via selectExpr:
import pandas as pd

columns = ['a', 'b', 'c', 'd', 'e', 'f']
pd_df = pd.DataFrame([[1, 2, 3, 4, 5, 6], [4, 5, 6, 7, 9, 8], [7, 8, 9, 1, 2, 4], [8, 3, 9, 8, 7, 4]], columns=columns)
df = spark.createDataFrame(pd_df)
+---+---+---+---+---+---+
| a| b| c| d| e| f|
+---+---+---+---+---+---+
| 1| 2| 3| 4| 5| 6|
| 4| 5| 6| 7| 9| 8|
| 7| 8| 9| 1| 2| 4|
| 8| 3| 9| 8| 7| 4|
+---+---+---+---+---+---+
cols = df.columns[1:]
df.selectExpr('a', "stack({}, {})".format(len(cols), ', '.join("'{}', {}".format(i, i) for i in cols))).show()
+---+----+----+
| a|col0|col1|
+---+----+----+
| 1| b| 2|
| 1| c| 3|
| 1| d| 4|
| 1| e| 5|
| 1| f| 6|
| 4| b| 5|
| 4| c| 6|
| 4| d| 7|
| 4| e| 9|
| 4| f| 8|
| 7| b| 8|
| 7| c| 9|
...
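If you want nicer names than the default col0/col1, the stack expression also accepts column aliases inside selectExpr; a minimal sketch against the same dataframe (the names variable and value are just examples):
cols = df.columns[1:]
stack_expr = "stack({}, {}) as (variable, value)".format(
    len(cols), ', '.join("'{0}', {0}".format(c) for c in cols))
df.selectExpr('a', stack_expr).show()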
Use a list comprehension to create an array of structs of column names and column values, then explode the new column using the magic inline function. Code below:
from pyspark.sql import functions as F

melted_df = (
    df
    # Create an array of structs of column names and corresponding values
    .withColumn('tab', F.array(*[F.struct(F.lit(x).alias('var'), F.col(x).alias('val'))
                                 for x in df.columns if x != 'A']))
    # Explode the array of structs
    .selectExpr('A', "inline(tab)")
)
melted_df.show()
+---+---+---+
| A|var|val|
+---+---+---+
| a| B| 1|
| a| C| 2|
| b| B| 3|
| b| C| 4|
| c| B| 5|
| c| C| 6|
+---+---+---+
1) Copy & paste
2) Change the first 2 variables
to_melt = {'latin', 'greek', 'chinese'}
new_names = ['lang', 'letter']
melt_str = ','.join([f"'{c}', `{c}`" for c in to_melt])
df = df.select(
*(set(df.columns) - to_melt),
F.expr(f"stack({len(to_melt)}, {melt_str}) ({','.join(new_names)})")
)
Rows with null are created if some values are null. To remove them, add this:
.filter(f"!{new_names[1]} is null")
Full test:
from pyspark.sql import functions as F
df = spark.createDataFrame([(101, "A", "Σ", "西"), (102, "B", "Ω", "诶")], ['ID', 'latin', 'greek', 'chinese'])
df.show()
# +---+-----+-----+-------+
# | ID|latin|greek|chinese|
# +---+-----+-----+-------+
# |101| A| Σ| 西|
# |102| B| Ω| 诶|
# +---+-----+-----+-------+
to_melt = {'latin', 'greek', 'chinese'}
new_names = ['lang', 'letter']
melt_str = ','.join([f"'{c}', `{c}`" for c in to_melt])
df = df.select(
*(set(df.columns) - to_melt),
F.expr(f"stack({len(to_melt)}, {melt_str}) ({','.join(new_names)})")
)
df.show()
# +---+-------+------+
# | ID| lang|letter|
# +---+-------+------+
# |101| latin| A|
# |101| greek| Σ|
# |101|chinese| 西|
# |102| latin| B|
# |102| greek| Ω|
# |102|chinese| 诶|
# +---+-------+------+

What are the Scala Spark performance of groupBy vs pivot?

I am facing an issue where I have to pivot a Spark DataFrame with different aggregation functions, depending on the column value I decide to pivot on. I am using this other question on SO as my starting point.
Let's take the following as starting point:
scala> val data = Seq((1, "k1", "measureA", 2), (1, "k1", "measureA", 4), (1, "k1", "measureB", 5), (1, "k1", "measureB", 7), (1, "k1", "measureC", 7), (1, "k1", "measureC", 1), (2, "k1", "measureB", 8), (2, "k1", "measureC", 9), (2, "k2", "measureA", 5), (2, "k2", "measureC", 5), (2, "k2", "measureC", 8))
data: Seq[(Int, String, String, Int)] = List((1,k1,measureA,2), (1,k1,measureA,4), (1,k1,measureB,5), (1,k1,measureB,7), (1,k1,measureC,7), (1,k1,measureC,1), (2,k1,measureB,8), (2,k1,measureC,9), (2,k2,measureA,5), (2,k2,measureC,5), (2,k2,measureC,8))
scala> val df = data.toDF("ts","key","measure_type","value")
df: org.apache.spark.sql.DataFrame = [ts: int, key: string ... 2 more fields]
scala> df.show
+---+---+------------+-----+
| ts|key|measure_type|value|
+---+---+------------+-----+
| 1| k1| measureA| 2|
| 1| k1| measureA| 4|
| 1| k1| measureB| 5|
| 1| k1| measureB| 7|
| 1| k1| measureC| 7|
| 1| k1| measureC| 1|
| 2| k1| measureB| 8|
| 2| k1| measureC| 9|
| 2| k2| measureA| 5|
| 2| k2| measureC| 5|
| 2| k2| measureC| 8|
+---+---+------------+-----+
What does perform better? A groupBy + agg:
val ddf = df.groupBy("ts", "key").agg(
  sum(when(col("measure_type") === "measureA", col("value"))).as("measureA"),
  avg(when(col("measure_type") === "measureB", col("value"))).as("measureB"),
  max(when(col("measure_type") === "measureC", col("value"))).as("measureC"))
ddf.show
+---+---+--------+--------+--------+
| ts|key|measureA|measureB|measureC|
+---+---+--------+--------+--------+
| 1| k1| 6| 6.0| 7|
| 2| k1| null| 8.0| 9|
| 2| k2| 5| null| 8|
+---+---+--------+--------+--------+
Or pivot + agg:
val listA = Seq("measureA")
val listB = Seq("measureB")
val listC = Seq("measureC")

val ddf = df.groupBy("ts", "key").pivot(col("measure_type"), Seq("measureA", "measureB", "measureC")).agg(
  sum(when(col("measure_type").isInCollection(listA), col("value"))).as("measureA"),
  avg(when(col("measure_type").isInCollection(listB), col("value"))).as("measureB"),
  max(when(col("measure_type").isInCollection(listC), col("value"))).as("measureC"))
ddf.show()
+---+---+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+
| ts|key|measureA_measureA|measureA_measureB|measureA_measureC|measureB_measureA|measureB_measureB|measureB_measureC|measureC_measureA|measureC_measureB|measureC_measureC|
+---+---+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+
| 2| k2| 5| null| null| null| null| null| null| null| 8|
| 2| k1| null| null| null| null| 8.0| null| null| null| 9|
| 1| k1| 6| null| null| null| 6.0| null| null| null| 7|
+---+---+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+-----------------+
I am aware that the second DataFrame I get as output is different, because it contains every combination of the Seq() of column names I passed to the pivot method with the different aggregation functions (SQL CASEs, because of the when()) I chose, so 3 * 3 = 9 columns. But if you filter this second DataFrame, removing the columns that contain only null, the result is the same.
Also, I am wondering whether I am doing something wrong in the second approach, or whether there is a better way to rename the columns so as to avoid the all-null columns from the start. Without the .as() aliases the schema looks like this:
val ddf = df.groupBy("ts", "key").pivot(col("measure_type"), Seq("measureA", "measureB", "measureC")).agg(
  sum(when(col("measure_type").isInCollection(listA), col("value"))),
  avg(when(col("measure_type").isInCollection(listB), col("value"))),
  max(when(col("measure_type").isInCollection(listC), col("value"))))
ddf:org.apache.spark.sql.DataFrame
ts:integer
key:string
measureA_sum(CASE WHEN (measure_type IN (measureA)) THEN value END):long
measureA_avg(CASE WHEN (measure_type IN (measureB)) THEN value END):double
measureA_max(CASE WHEN (measure_type IN (measureC)) THEN value END):integer
measureB_sum(CASE WHEN (measure_type IN (measureA)) THEN value END):long
measureB_avg(CASE WHEN (measure_type IN (measureB)) THEN value END):double
measureB_max(CASE WHEN (measure_type IN (measureC)) THEN value END):integer
measureC_sum(CASE WHEN (measure_type IN (measureA)) THEN value END):long
measureC_avg(CASE WHEN (measure_type IN (measureB)) THEN value END):double
measureC_max(CASE WHEN (measure_type IN (measureC)) THEN value END):integer
I have decided not to post the ddf.show() output because of the very verbose headers. The result is the same as in the pivot + agg example, just with the headers listed above.

How to compare value of one row with all the other rows in PySpark on grouped values

Problem statement
Consider the following data (see code generation at the bottom)
+-----+-----+-------+--------+
|index|group|low_num|high_num|
+-----+-----+-------+--------+
| 0| 1| 1| 1|
| 1| 1| 2| 2|
| 2| 1| 3| 3|
| 3| 2| 1| 3|
+-----+-----+-------+--------+
Then, for a given index, I want to count how many times that index's high_num is greater than low_num, over all low_num values in the same group.
For instance, consider the second row, with index 1. Index 1 is in group 1 and its high_num is 2. That high_num is greater than the low_num on index 0, equal to the low_num on index 1 itself, and smaller than the low_num on index 2. So the high_num of index 1 is greater than a low_num in the group once, and I want the value in the desired column to be 1.
Dataset with desired output
+-----+-----+-------+--------+-------+
|index|group|low_num|high_num|desired|
+-----+-----+-------+--------+-------+
| 0| 1| 1| 1| 0|
| 1| 1| 2| 2| 1|
| 2| 1| 3| 3| 2|
| 3| 2| 1| 3| 1|
+-----+-----+-------+--------+-------+
Dataset generation code
from pyspark.sql import SparkSession

spark = (
    SparkSession
    .builder
    .getOrCreate()
)

## Example df
## Note the inclusion of "desired", which is the desired output.
df = spark.createDataFrame(
    [
        (0, 1, 1, 1, 0),
        (1, 1, 2, 2, 1),
        (2, 1, 3, 3, 2),
        (3, 2, 1, 3, 1)
    ],
    schema=["index", "group", "low_num", "high_num", "desired"]
)
Pseudocode that might have solved the problem
The pseudocode might look like this:
import pyspark.sql.functions as F
from pyspark.sql.window import Window

w_spec = Window.partitionBy("group").rowsBetween(
    Window.unboundedPreceding, Window.unboundedFollowing)

## F.collect_list_when does not exist
## F.current_col does not exist
## Probably wouldn't work like this anyway
ddf = df.withColumn("Counts",
    F.size(F.collect_list_when(
        F.current_col("high_number") > F.col("low_number"), 1
    ).otherwise(None).over(w_spec))
)
You can do a filter on the collect_list, and check its size:
import pyspark.sql.functions as F

df2 = df.withColumn(
    'desired',
    F.expr('size(filter(collect_list(low_num) over (partition by group), x -> x < high_num))')
)
df2.show()
+-----+-----+-------+--------+-------+
|index|group|low_num|high_num|desired|
+-----+-----+-------+--------+-------+
| 0| 1| 1| 1| 0|
| 1| 1| 2| 2| 1|
| 2| 1| 3| 3| 2|
| 3| 2| 1| 3| 1|
+-----+-----+-------+--------+-------+
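The same filter-and-size idea can also be written with the Python column-function API instead of an SQL string. A minimal sketch, assuming Spark >= 3.1 (where pyspark.sql.functions.filter is available):
import pyspark.sql.functions as F
from pyspark.sql.window import Window

w = Window.partitionBy("group")
df2 = df.withColumn(
    "desired",
    # collect all low_num of the group, keep those below this row's high_num, count them
    F.size(F.filter(F.collect_list("low_num").over(w),
                    lambda x: x < F.col("high_num")))
)
df2.show()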

How to automatically drop constant columns in pyspark?

I have a Spark dataframe in PySpark and I need to drop all constant columns from it. Since I don't know which columns are constant, I cannot manually deselect them; I need an automatic procedure. I am surprised I was not able to find a simple solution on Stack Overflow.
Example:
import pandas as pd
import pyspark
from pyspark.sql.session import SparkSession

spark = SparkSession.builder.appName("test").getOrCreate()

d = {'col1': [1, 2, 3, 4, 5],
     'col2': [1, 2, 3, 4, 5],
     'col3': [0, 0, 0, 0, 0],
     'col4': [0, 0, 0, 0, 0]}
df_panda = pd.DataFrame(data=d)
df_spark = spark.createDataFrame(df_panda)
df_spark.show()
Output:
+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
| 1| 1| 0| 0|
| 2| 2| 0| 0|
| 3| 3| 0| 0|
| 4| 4| 0| 0|
| 5| 5| 0| 0|
+----+----+----+----+
Desired output:
+----+----+
|col1|col2|
+----+----+
| 1| 1|
| 2| 2|
| 3| 3|
| 4| 4|
| 5| 5|
+----+----+
What is the best way to automatically drop constant columns in pyspark?
Count distinct values in each column first and then drop columns that contain only one distinct value:
import pyspark.sql.functions as f
cnt = df_spark.agg(*(f.countDistinct(c).alias(c) for c in df_spark.columns)).first()
cnt
# Row(col1=5, col2=5, col3=1, col4=1)
df_spark.drop(*[c for c in cnt.asDict() if cnt[c] == 1]).show()
+----+----+
|col1|col2|
+----+----+
| 1| 1|
| 2| 2|
| 3| 3|
| 4| 4|
| 5| 5|
+----+----+
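Wrapped up as a small reusable helper (the name drop_constant_columns is just for illustration), the same idea looks like this:
import pyspark.sql.functions as f

def drop_constant_columns(df):
    """Drop every column with at most one distinct non-null value."""
    cnt = df.agg(*(f.countDistinct(c).alias(c) for c in df.columns)).first()
    # countDistinct ignores nulls, so an all-null column counts as 0 and is dropped too
    return df.drop(*[c for c in df.columns if cnt[c] <= 1])

drop_constant_columns(df_spark).show()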

Apache Spark - Scala API - Aggregate on sequentially increasing key

I have a data frame that looks something like this:
val df = sc.parallelize(Seq(
  (3, 1, "A"), (3, 2, "B"), (3, 3, "C"),
  (2, 1, "D"), (2, 2, "E"),
  (3, 1, "F"), (3, 2, "G"), (3, 3, "G"),
  (2, 1, "X"), (2, 2, "X")
)).toDF("TotalN", "N", "String")
+------+---+------+
|TotalN| N|String|
+------+---+------+
| 3| 1| A|
| 3| 2| B|
| 3| 3| C|
| 2| 1| D|
| 2| 2| E|
| 3| 1| F|
| 3| 2| G|
| 3| 3| G|
| 2| 1| X|
| 2| 2| X|
+------+---+------+
I need to aggregate the strings by concatenating them together based on the TotalN and the sequentially increasing ID (N). The problem is there is not a unique ID for each aggregation I can group by. So, I need to do something like "for each row look at the TotalN, loop through the next N rows and concatenate, then reset".
+------+------+
|TotalN|String|
+------+------+
| 3| ABC|
| 2| DE|
| 3| FGG|
| 2| XX|
+------+------+
Any pointers much appreciated.
Using Spark 2.3.1 and the Scala API.
Try this:
import org.apache.spark.sql.functions.collect_list

val df = spark.sparkContext.parallelize(Seq(
  (3, 1, "A"), (3, 2, "B"), (3, 3, "C"),
  (2, 1, "D"), (2, 2, "E"),
  (3, 1, "F"), (3, 2, "G"), (3, 3, "G"),
  (2, 1, "X"), (2, 2, "X")
)).toDF("TotalN", "N", "String")

df.createOrReplaceTempView("data")

val sqlDF = spark.sql(
  """
    | SELECT TotalN d, N, String, ROW_NUMBER() over (order by TotalN) as rowNum
    | FROM data
  """.stripMargin)

sqlDF.withColumn("key", $"N" - $"rowNum")
  .groupBy("key").agg(collect_list('String).as("texts")).show()
The solution is to calculate a grouping variable using the row_number function, which can then be used in the groupBy.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

val w = Window.orderBy("TotalN")

df.withColumn("GeneratedID", $"N" - row_number.over(w)).show
+------+---+------+-----------+
|TotalN| N|String|GeneratedID|
+------+---+------+-----------+
| 2| 1| D| 0|
| 2| 2| E| 0|
| 2| 1| X| -2|
| 2| 2| X| -2|
| 3| 1| A| -4|
| 3| 2| B| -4|
| 3| 3| C| -4|
| 3| 1| F| -7|
| 3| 2| G| -7|
| 3| 3| G| -7|
+------+---+------+-----------+