Finding most common non-null prefix per group in spark - scala

I need to write a structured query that finds the most common non-null PREFIX (by number of occurrences) per UNIQUE_GUEST_ID.
Here is the input data:
val inputDf = Seq(
(1, "Mr"),
(1, "Mme"),
(1, "Mr"),
(1, null),
(1, null),
(1, null),
(2, "Mr"),
(3, null)).toDF("UNIQUE_GUEST_ID", "PREFIX")
println("Input:")
inputDf.show(false)
My solution was:
inputDf
.groupBy($"UNIQUE_GUEST_ID")
.agg(collect_list($"PREFIX").alias("PREFIX"))
But that is not what I need:
Expected:
+---------------+------+
|UNIQUE_GUEST_ID|PREFIX|
+---------------+------+
|1 |Mr |
|2 |Mr |
|3 |null |
+---------------+------+
Actual:
+---------------+-------------+
|UNIQUE_GUEST_ID|PREFIX |
+---------------+-------------+
|1 |[Mr, Mme, Mr]|
|3 |[] |
|2 |[Mr] |
+---------------+-------------+

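Try this (count($"PREFIX") ignores nulls, so a null prefix gets a count of 0 and is only picked when a guest has no non-null prefix):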
val inputDf = Seq(
(1, "Mr"),
(1, "Mme"),
(1, "Mr"),
(1, null),
(1, null),
(1, null),
(2, "Mr"),
(3, null)).toDF("UNIQUE_GUEST_ID", "PREFIX")
println("Input:")
inputDf.show(false)
/**
* Input:
* +---------------+------+
* |UNIQUE_GUEST_ID|PREFIX|
* +---------------+------+
* |1 |Mr |
* |1 |Mme |
* |1 |Mr |
* |1 |null |
* |1 |null |
* |1 |null |
* |2 |Mr |
* |3 |null |
* +---------------+------+
*/
inputDf
.groupBy($"UNIQUE_GUEST_ID", $"PREFIX").agg(count($"PREFIX").as("count"))
.groupBy($"UNIQUE_GUEST_ID")
.agg(max( struct( $"count", $"PREFIX")).as("max"))
.selectExpr("UNIQUE_GUEST_ID", "max.PREFIX")
.show(false)
/**
* +---------------+------+
* |UNIQUE_GUEST_ID|PREFIX|
* +---------------+------+
* |2 |Mr |
* |1 |Mr |
* |3 |null |
* +---------------+------+
*/
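To check the logic outside Spark, here is a plain-Scala sketch of the same idea, assuming the rows are available as a local Seq of (UNIQUE_GUEST_ID, PREFIX) pairs. The object name MostCommonPrefix is hypothetical. Nulls are dropped before counting (as count($"PREFIX") ignores them), and, like max(struct($"count", $"PREFIX")), a tie on the count goes to the lexicographically larger prefix.

```scala
object MostCommonPrefix {
  // Returns, for each guest id, the most common non-null prefix (None if all are null).
  def mostCommon(rows: Seq[(Int, String)]): Map[Int, Option[String]] =
    rows.groupBy(_._1).map { case (id, group) =>
      // (count, prefix) pairs for the non-null prefixes of this guest
      val counts: Seq[(Int, String)] =
        group.flatMap(r => Option(r._2)).groupBy(identity).toSeq.map { case (p, ps) => (ps.size, p) }
      // tuple ordering compares the count first, then the prefix, like max(struct(...))
      id -> (if (counts.isEmpty) None else Some(counts.max._2))
    }
}
```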

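Alternatively, count each prefix, compute the maximum count per guest, and join the two back together:
val df2 = inputDf.groupBy('UNIQUE_GUEST_ID,'PREFIX).agg(count('PREFIX).as("ct"))
val df3 = df2.groupBy('UNIQUE_GUEST_ID).agg(max('ct).as("ct"))
df2.join(df3,Seq("ct","UNIQUE_GUEST_ID")).show()
Output (if two prefixes tie for the highest count within a guest, the join keeps one row for each):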
+---+---------------+------+
| ct|UNIQUE_GUEST_ID|PREFIX|
+---+---------------+------+
| 1| 2| Mr|
| 0| 3| null|
| 2| 1| Mr|
+---+---------------+------+
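The join-based version keeps every prefix whose count equals the guest's maximum, so ties produce several rows per guest. The plain-Scala sketch below mimics df2, df3 and the join on a local Seq to make that visible; the object MaxCountJoin is hypothetical and only an illustration, not a replacement for the Spark code.

```scala
object MaxCountJoin {
  // Mimics df2 / df3 / join: count each (id, prefix) pair, take the maximum count
  // per id, and keep every pair whose count equals that maximum.
  def topPrefixes(rows: Seq[(Int, String)]): Seq[(Int, String, Int)] = {
    // df2: a null prefix gets count 0, matching count('PREFIX), which ignores nulls
    val df2 = rows.groupBy(identity).toSeq.map { case ((id, p), rs) =>
      (id, p, if (p == null) 0 else rs.size)
    }
    // df3: the maximum count per id
    val df3 = df2.groupBy(_._1).map { case (id, cs) => id -> cs.map(_._3).max }
    // join on (id, count): ties yield several rows for the same id
    df2.filter { case (id, _, ct) => df3(id) == ct }.sortBy(r => (r._1, Option(r._2)))
  }
}
```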

Related

Spark dataframe join aggregating by ID

I have a problem joining two DataFrames grouped by ID:
val df1 = Seq(
(1, 1,100),
(1, 3,20),
(2, 5,5),
(2, 2,10)).toDF("id", "index","value")
val df2 = Seq(
(1, 0),
(2, 0),
(3, 0),
(4, 0),
(5,0)).toDF("index", "value")
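df1 should be joined with df2 on the index column for every id, so that each id gets all five indexes and value is 0 where df1 has no matching row.
Expected result:
+---+-----+-----+
|id |index|value|
+---+-----+-----+
|1  |1    |100  |
|1  |2    |0    |
|1  |3    |20   |
|1  |4    |0    |
|1  |5    |0    |
|2  |1    |0    |
|2  |2    |10   |
|2  |3    |0    |
|2  |4    |0    |
|2  |5    |5    |
+---+-----+-----+
Please help me with this.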
First of all, I would replace your df2 table with this:
var df2 = Seq(
(Array(1, 2), Array(1, 2, 3, 4, 5))
).toDF("id", "index")
This lets us use explode to generate every (id, index) combination:
df2 = df2
.withColumn("id", explode(col("id")))
.withColumn("index", explode(col("index")))
and it gives:
+---+-----+
|id |index|
+---+-----+
|1 |1 |
|1 |2 |
|1 |3 |
|1 |4 |
|1 |5 |
|2 |1 |
|2 |2 |
|2 |3 |
|2 |4 |
|2 |5 |
+---+-----+
Now all we need to do is left-join with your df1 and replace the missing values with 0:
df2 = df2
.join(df1, Seq("id", "index"), "left")
.withColumn("value", when(col("value").isNull, 0).otherwise(col("value")))
And we get this final output:
+---+-----+-----+
|id |index|value|
+---+-----+-----+
|1 |1 |100 |
|1 |2 |0 |
|1 |3 |20 |
|1 |4 |0 |
|1 |5 |0 |
|2 |1 |0 |
|2 |2 |10 |
|2 |3 |0 |
|2 |4 |0 |
|2 |5 |5 |
+---+-----+-----+
which should be what you want. Good luck!
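The same idea in plain Scala, as a sketch: build every (id, index) pair and fall back to 0 when df1 has no value. Unlike the answer, the ids here are taken from df1 instead of being hard-coded as Array(1, 2); that is an assumption about what you want. In Spark, a comparable pair table could be built with Dataset.crossJoin between the distinct ids and the indexes. The object name FillMissingIndexes is hypothetical.

```scala
object FillMissingIndexes {
  // Builds every (id, index) pair and looks up its value in df1, defaulting to 0.
  // Assumption: the ids are the distinct ids found in df1 (the answer hard-codes Array(1, 2)).
  def fill(df1: Seq[(Int, Int, Int)], indexes: Seq[Int]): Seq[(Int, Int, Int)] = {
    // (id, index) -> value, like the right-hand side of the left join
    val values: Map[(Int, Int), Int] =
      df1.map { case (id, index, value) => (id, index) -> value }.toMap
    val ids = df1.map(_._1).distinct.sorted
    for {
      id    <- ids
      index <- indexes
    } yield (id, index, values.getOrElse((id, index), 0)) // 0 replaces the missing value
  }
}
```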

Find min value for every 5 hour interval

My DataFrame:
val df = Seq(
("1", 1),
("1", 1),
("1", 2),
("1", 4),
("1", 5),
("1", 6),
("1", 8),
("1", 12),
("1", 12),
("1", 13),
("1", 14),
("1", 15),
("1", 16)
).toDF("id", "time")
In this case the first interval starts at hour 1, so every row up to 6 (1 + 5) belongs to it.
But 8 - 1 > 5, so the second interval starts at 8 and goes up to 13.
Then 14 - 8 > 5, so the third interval starts at 14, and so on.
The desired result:
+---+----+--------+
|id |time|min_time|
+---+----+--------+
|1 |1 |1 |
|1 |1 |1 |
|1 |2 |1 |
|1 |4 |1 |
|1 |5 |1 |
|1 |6 |1 |
|1 |8 |8 |
|1 |12 |8 |
|1 |12 |8 |
|1 |13 |8 |
|1 |14 |14 |
|1 |15 |14 |
|1 |16 |14 |
+---+----+--------+
I'm trying to do it using the min function, but I don't know how to account for this condition.
val window = Window.partitionBy($"id").orderBy($"time")
df
.select($"id", $"time")
.withColumn("min_time", when(($"time" - min($"time").over(window)) <= 5, min($"time").over(window)).otherwise($"time"))
.show(false)
What I get:
+---+----+--------+
|id |time|min_time|
+---+----+--------+
|1 |1 |1 |
|1 |1 |1 |
|1 |2 |1 |
|1 |4 |1 |
|1 |5 |1 |
|1 |6 |1 |
|1 |8 |8 |
|1 |12 |12 |
|1 |12 |12 |
|1 |13 |13 |
|1 |14 |14 |
|1 |15 |15 |
|1 |16 |16 |
+---+----+--------+
You can go with your first idea of using an aggregate function over a window. But instead of combining Spark's built-in functions, you can define your own user-defined aggregate function (UDAF).
Analysis
As you correctly supposed, we should apply a kind of min function over a window. On the rows of this window, we want to implement the following rule:
Given rows sorted by time, if the difference between the current row's time and the previous row's min_time is greater than 5, then the current row's min_time is the current row's time; otherwise it is the previous row's min_time.
However, the aggregate functions provided by Spark cannot access the previous row's min_time. There is a lag function, but it can only access values that already exist in previous rows. Since the previous row's min_time is itself still being computed, lag cannot reach it.
Thus we have to define our own aggregate function.
Solution
Defining a tailor-made aggregate function
To define our aggregate function, we need to create a class that extends the Aggregator abstract class. Below is the complete implementation:
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}
object MinByInterval extends Aggregator[Integer, Integer, Integer] {
def zero: Integer = null
def reduce(buffer: Integer, time: Integer): Integer = {
if (buffer == null || time - buffer > 5) time else buffer
}
def merge(b1: Integer, b2: Integer): Integer = {
throw new NotImplementedError("should not use as general aggregation")
}
def finish(reduction: Integer): Integer = reduction
def bufferEncoder: Encoder[Integer] = Encoders.INT
def outputEncoder: Encoder[Integer] = Encoders.INT
}
We use Integer for the input, buffer and output types. We chose Integer because it is a nullable Int. We could have used Option[Int], but Spark's documentation advises against creating new objects in aggregator methods for performance reasons, which is what would happen with a wrapper type like Option.
We implement the rule defined in the Analysis section in the reduce method:
def reduce(buffer: Integer, time: Integer): Integer = {
if (buffer == null || time - buffer > 5) time else buffer
}
Here time is the value of the time column in the current row, and buffer is the previously computed value, i.e. the min_time of the previous row. Since the window sorts rows by time, time is always greater than or equal to buffer. The null buffer case only happens when processing the first row.
The merge method is not used when the aggregate function is evaluated over a window, so we don't implement it.
The finish method is the identity, since no final computation is needed on the aggregated value, and both the buffer and output encoders are Encoders.INT.
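To see the reduce rule in action without Spark, here is a plain-Scala sketch that replays it over one id's sorted times with scanLeft, emitting the running buffer after each row, which is roughly what the growing window frame does. Option[Int] stands in for the nullable Integer buffer, and the object name MinByIntervalSketch is hypothetical.

```scala
object MinByIntervalSketch {
  // Same rule as MinByInterval.reduce, with Option[Int] standing in for a null buffer:
  // keep the current interval start unless the time is more than 5 past it.
  def reduce(buffer: Option[Int], time: Int): Int = buffer match {
    case Some(start) if time - start <= 5 => start
    case _                                => time
  }

  // Replays reduce over one id's sorted times, emitting the buffer after every row.
  def minTimes(sortedTimes: Seq[Int]): Seq[Int] =
    sortedTimes
      .scanLeft(Option.empty[Int])((buffer, time) => Some(reduce(buffer, time)))
      .tail    // drop the initial empty buffer (the `zero`)
      .flatten
}
```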
Calling user defined aggregate function
Now, with Spark 3.0 or later (functions.udaf was added in 3.0), we can call our user-defined aggregate function with the following code:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, udaf}
val minTime = udaf(MinByInterval)
val window = Window.partitionBy("id").orderBy("time")
df.withColumn("min_time", minTime(col("time")).over(window))
Run
Given the input dataframe in the question, we get:
+---+----+--------+
|id |time|min_time|
+---+----+--------+
|1 |1 |1 |
|1 |1 |1 |
|1 |2 |1 |
|1 |4 |1 |
|1 |5 |1 |
|1 |6 |1 |
|1 |8 |8 |
|1 |12 |8 |
|1 |12 |8 |
|1 |13 |8 |
|1 |14 |14 |
|1 |15 |14 |
|1 |16 |14 |
+---+----+--------+
Input data (the sample from the question, extended with a second id):
val df = Seq(
("1", 1),
("1", 1),
("1", 2),
("1", 4),
("1", 5),
("1", 6),
("1", 8),
("1", 12),
("1", 12),
("1", 13),
("1", 14),
("1", 15),
("1", 16),
("2", 4),
("2", 8),
("2", 10),
("2", 11),
("2", 11),
("2", 12),
("2", 13),
("2", 20)
).toDF("id", "time")
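The data must be sorted; otherwise the result will be incorrect.
First, I define a case class (named TimeRow here rather than Row, to avoid clashing with org.apache.spark.sql.Row). The var min in its companion object holds the start of the current interval and is only updated when a new interval begins.
case class TimeRow(id: String, time: Int, min: Int) {
def getMin: TimeRow = {
// Start a new interval when: more than 5 past the current start, no start yet (-99),
// or this is the first row of an id (the min field still holds row_number == 1)
if (time - TimeRow.min > 5 || TimeRow.min == -99 || min == 1) {
TimeRow.min = time
}
TimeRow(id, time, TimeRow.min)
}
}
object TimeRow {
// Shared mutable state: it is visible to every task running in the same JVM, so this
// only works reliably when the rows are processed one at a time, in sorted order.
var min: Int = -99
}
Then number the rows within each id and map every row through getMin:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number
import spark.implicits._
val window = Window.partitionBy($"id").orderBy($"time")
df
.withColumn("min", row_number().over(window)) // temporarily holds the row number within each id
.as[TimeRow]
.map(_.getMin)
.show(40)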
Result
+---+----+---+
| id|time|min|
+---+----+---+
| 1| 1| 1|
| 1| 1| 1|
| 1| 2| 1|
| 1| 4| 1|
| 1| 5| 1|
| 1| 6| 1|
| 1| 8| 8|
| 1| 12| 8|
| 1| 12| 8|
| 1| 13| 8|
| 1| 14| 14|
| 1| 15| 14|
| 1| 16| 14|
| 2| 4| 4|
| 2| 8| 4|
| 2| 10| 10|
| 2| 11| 10|
| 2| 11| 10|
| 2| 12| 10|
| 2| 13| 10|
| 2| 20| 20|
+---+----+---+
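The global var in TimeRow can be avoided by carrying the interval start through a fold per id. Here is a plain-Scala sketch of that, assuming one id's rows fit in memory (in Spark this would correspond to something like groupByKey followed by flatMapGroups, which is not shown here). The object name MinPerIdSketch is hypothetical.

```scala
object MinPerIdSketch {
  // Applies the interval rule per id, threading the interval start through a fold
  // instead of storing it in a shared var. Rows are (id, time); output is (id, time, min).
  def withMin(rows: Seq[(String, Int)]): Seq[(String, Int, Int)] =
    rows.groupBy(_._1).toSeq.sortBy(_._1).flatMap { case (id, group) =>
      val times = group.map(_._2).sorted
      val (_, out) = times.foldLeft((Option.empty[Int], Vector.empty[(String, Int, Int)])) {
        case ((start, acc), time) =>
          // keep the current start unless time is more than 5 past it (or there is none yet)
          val newStart = start.filter(s => time - s <= 5).getOrElse(time)
          (Option(newStart), acc :+ ((id, time, newStart)))
      }
      out
    }
}
```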

Spark Scala filter on group of result

I am trying to filter a DataFrame based on groups of rows.
Sample DataFrame code:
scala> val df = sc.parallelize(Seq(
(1, 1, "m10", "t22"),
(1, 2, "m10", "t22"),
(1, 3, "m11", "t22"),
(1, 4, "m11", "t22"),
(1, 5, "m10", "t22"),
(1, 6, "m10", "t22"),
(1, 7, "m10", "t22"),
(1, 8, "m11", "t22"),
(1, 9, "m10", "t22"),
(1, 10, "m10", "t22"),
(2, 1, "m10", "t22"),
(2, 2, "m11", "t22"),
(2, 3, "m10", "t22"),
(2, 4, "m10", "t22"),
(2, 5, "m10", "t22"),
(2, 9, "m10", "t22"),
(2, 10, "m11", "t22"),
(3, 4, "m10", "t22"),
(3, 5, "m11", "t22"),
(3, 6, "m10", "t22"),
(3, 7, "m10", "t22"),
(3, 8, "m10", "t22"),
(3, 9, "m11", "t22"),
(3, 10, "m10", "t22")
)
).toDF("org_id", "rule_id", "period_id", "base_id")
The data looks like this:
scala> df.show(50, false)
+------+-------+---------+-------+
|org_id|rule_id|period_id|base_id|
+------+-------+---------+-------+
|1 |1 |m10 |t21 |
|1 |2 |m10 |t22 |
|1 |3 |m11 |t22 |
|1 |4 |m11 |t22 |
|1 |5 |m10 |t23 |
|1 |6 |m10 |t22 |
|1 |7 |m10 |t22 |
|1 |8 |m11 |t22 |
|1 |9 |m10 |t22 |
|1 |10 |m10 |t22 |
|2 |1 |m10 |t22 |
|2 |2 |m11 |t22 |
|2 |3 |m10 |t23 |
|2 |4 |m10 |t22 |
|2 |5 |m10 |t22 |
|2 |9 |m10 |t22 |
|2 |10 |m11 |t22 |
|3 |4 |m10 |t22 |
|3 |5 |m11 |t22 |
|3 |6 |m10 |t22 |
|3 |7 |m10 |t22 |
|3 |8 |m10 |t22 |
|3 |9 |m11 |t22 |
|3 |10 |m10 |t23 |
+------+-------+---------+-------+
Based on a properties file, I need to filter the rows within each org_id group. The properties file looks like this:
4=1,2,3
7=1,4,5
9=8,10
.....................
.....................
In the properties file, both the keys and the values are rule_ids.
I keep a row with rule_id 4 only if the same org_id group also contains rule_ids 1, 2 and 3; otherwise the row with rule_id 4 must be deleted. The same applies to the other rule_id keys in the properties file.
Expected result:
+------+-------+---------+-------+
|org_id|rule_id|period_id|base_id|
+------+-------+---------+-------+
|1 |1 |m10 |t21 |
|1 |2 |m10 |t22 |
|1 |3 |m11 |t22 |
|1 |4 |m11 |t22 |
|1 |5 |m10 |t23 |
|1 |6 |m10 |t22 |
|1 |7 |m10 |t22 |
|1 |8 |m11 |t22 |
|1 |9 |m10 |t22 |
|1 |10 |m10 |t22 |
|2 |1 |m10 |t22 |
|2 |2 |m11 |t22 |
|2 |3 |m10 |t23 |
|2 |4 |m10 |t22 |
|2 |5 |m10 |t22 |
|2 |10 |m11 |t22 |
|3 |5 |m11 |t22 |
|3 |6 |m10 |t22 |
|3 |8 |m10 |t22 |
|3 |9 |m11 |t22 |
|3 |10 |m10 |t23 |
+------+-------+---------+-------+
I am stuck and don't know how to proceed. Any suggestions would be greatly appreciated.
This approach uses multiple joins and aggregations, so hopefully the data is not too big.
Basically, we build records holding the set of sub-rules for each rule. A join then correlates the original records with the sub-rules that must exist for that org/rule combination, as well as with the rules actually present within that org, creating orgsContainingRulesDF. Using this DataFrame, you can filter out rules for which not all sub-rules are present.
// Assume rule/sub-rule info can be read as either a Map or List of Tuple
val rules = Map(4->Set(1,2,3), 7->Set(1,4,5), 9->Set(8,10))
val rulesDF = rules.toList.toDF("rule", "sub_rules")
// For each org_id, get a set of rules which appear under it
val ruleSetsDF = df.groupBy(col("org_id")).agg(collect_set(col("rule_id")) as "rules")
// For each rule with sub-rules, match with orgs containing that rule
// Also get the full list of rules pertaining to that org
// (inner joins: a rule that never appears in df has nothing to delete, while a left join
// would pass null arrays to the UDF below for such a rule)
val orgsContainingRulesDF = rulesDF.join(df, $"rule" === $"rule_id").join(ruleSetsDF, Seq("org_id"))
// Create a UDF for determining if all items in first seq are in second seq
// (rule_id and the sub-rules are integers, so the UDF takes Seq[Int])
val subsetOf = udf((array1: Seq[Int], array2: Seq[Int]) => {
Set(array1:_*).subsetOf(Set(array2:_*))
})
// Create DF with items to delete
// i.e. org-and-rule-id-pairs where not all sub-rules appear in exhibited rules
val toDeleteDF = orgsContainingRulesDF.filter(!subsetOf($"sub_rules", $"rules"))
// Use a left anti-join (inverse of left join) to only preserve records
// with no corresponding toDeleteDF record
val resultDF = df.join(toDeleteDF, Seq("org_id", "rule_id"), "left_anti").orderBy($"org_id", $"rule_id")
Result is as expected:
resultDF.show(25,false)
+------+-------+---------+-------+
|org_id|rule_id|period_id|base_id|
+------+-------+---------+-------+
|1 |1 |m10 |t22 |
|1 |2 |m10 |t22 |
|1 |3 |m11 |t22 |
|1 |4 |m11 |t22 |
|1 |5 |m10 |t22 |
|1 |6 |m10 |t22 |
|1 |7 |m10 |t22 |
|1 |8 |m11 |t22 |
|1 |9 |m10 |t22 |
|1 |10 |m10 |t22 |
|2 |1 |m10 |t22 |
|2 |2 |m11 |t22 |
|2 |3 |m10 |t22 |
|2 |4 |m10 |t22 |
|2 |5 |m10 |t22 |
|2 |10 |m11 |t22 |
|3 |5 |m11 |t22 |
|3 |6 |m10 |t22 |
|3 |8 |m10 |t22 |
|3 |9 |m11 |t22 |
|3 |10 |m10 |t22 |
+------+-------+---------+-------+
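A plain-Scala sketch of the filtering rule itself can help verify the expected output. It assumes the rows are reduced to (org_id, rule_id) pairs and the properties file is already parsed into a Map; the object name RuleFilterSketch is hypothetical. Rows whose rule_id is not a key in the map are always kept.

```scala
object RuleFilterSketch {
  // A row whose rule_id has sub-rules is kept only if every sub-rule appears
  // somewhere in the same org_id group; other rows are kept unchanged.
  def filterRows(rows: Seq[(Int, Int)], rules: Map[Int, Set[Int]]): Seq[(Int, Int)] = {
    // all rule_ids present per org_id (like collect_set over the org_id group)
    val rulesInOrg: Map[Int, Set[Int]] =
      rows.groupBy(_._1).map { case (org, rs) => org -> rs.map(_._2).toSet }
    rows.filter { case (org, rule) =>
      // forall on None is true, so rules without sub-rules always pass
      rules.get(rule).forall(_.subsetOf(rulesInOrg(org)))
    }
  }
}
```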
This problem can be solved using a SQL window function (array_sort and array_intersect require Spark 2.4 or later).
Let's register your original data and the properties file as temporary views data and rule_filters respectively:
Seq(
(1, 1, "m10", "t22"),
(1, 2, "m10", "t22"),
(1, 3, "m11", "t22"),
(1, 4, "m11", "t22"),
(1, 5, "m10", "t22"),
(1, 6, "m10", "t22"),
(1, 7, "m10", "t22"),
(1, 8, "m11", "t22"),
(1, 9, "m10", "t22"),
(1, 10, "m10", "t22"),
(2, 1, "m10", "t22"),
(2, 2, "m11", "t22"),
(2, 3, "m10", "t22"),
(2, 4, "m10", "t22"),
(2, 5, "m10", "t22"),
(2, 9, "m10", "t22"),
(2, 10, "m11", "t22"),
(3, 4, "m10", "t22"),
(3, 5, "m11", "t22"),
(3, 6, "m10", "t22"),
(3, 7, "m10", "t22"),
(3, 8, "m10", "t22"),
(3, 9, "m11", "t22"),
(3, 10, "m10", "t22")
).toDF(
"org_id",
"rule_id",
"period_id",
"base_id"
).createOrReplaceTempView("data")
Seq(
"4=1,2,3",
"7=1,4,5",
"9=8,10"
).map { line =>
val Array(key, values) = line.split("=")
(key, values.split(",").map(_.toInt).sorted)
}.toDF(
"key",
"rules"
).createOrReplaceTempView("rule_filters")
Then the following SQL query solves the problem:
SELECT
org_id,
rule_id,
period_id,
base_id
FROM
(
SELECT
*,
array_sort(
collect_set(rule_id) OVER (
PARTITION BY org_id ROWS BETWEEN UNBOUNDED PRECEDING
AND UNBOUNDED FOLLOWING
)
) AS rules_in_org
FROM
data
LEFT JOIN rule_filters ON rule_id = key
)
WHERE
rules IS NULL
OR array_intersect(rules_in_org, rules) = rules
ORDER BY
org_id,
rule_id
If you prefer, you may also implement it using the DataFrame API:
import org.apache.spark.sql.expressions.Window
spark.table("data")
.join(spark.table("rule_filters"), $"data.rule_id" === $"rule_filters.key", "left")
.select(
$"*",
array_sort(
collect_set($"rule_id").over(
Window
.partitionBy($"org_id")
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)
) as "rules_within_org"
)
.filter($"rules".isNull || array_intersect($"rules_within_org", $"rules") === $"rules")
.drop("key", "rules", "rules_within_org")
.orderBy($"org_id", $"rule_id")
.show(Int.MaxValue)
+------+-------+---------+-------+
|org_id|rule_id|period_id|base_id|
+------+-------+---------+-------+
| 1| 1| m10| t22|
| 1| 2| m10| t22|
| 1| 3| m11| t22|
| 1| 4| m11| t22|
| 1| 5| m10| t22|
| 1| 6| m10| t22|
| 1| 7| m10| t22|
| 1| 8| m11| t22|
| 1| 9| m10| t22|
| 1| 10| m10| t22|
| 2| 1| m10| t22|
| 2| 2| m11| t22|
| 2| 3| m10| t22|
| 2| 4| m10| t22|
| 2| 5| m10| t22|
| 2| 10| m11| t22|
| 3| 5| m11| t22|
| 3| 6| m10| t22|
| 3| 8| m10| t22|
| 3| 9| m11| t22|
| 3| 10| m10| t22|
+------+-------+---------+-------+

How to create a column expression using a subquery in spark scala

Given any DataFrame, I want to add a column called "has_duplicates" whose value indicates whether each row is unique. Example input DataFrame:
val df = Seq((1, 2), (2, 5), (1, 7), (1, 2), (2, 5)).toDF("A", "B")
Given an input columns: Seq[String], I know how to get the count of each row:
val countsDf = df.withColumn("count", count("*").over(Window.partitionBy(columns.map(col(_)): _*)))
But I'm not sure how to use this to create a column expression for the final column indicating whether each row is unique.
Something like
def getEvaluationExpression(df: DataFrame): Column = {
when(col("count") > 1, lit("fail")).otherwise(lit("pass"))
}
but the count needs to be evaluated on the spot using the query above.
Try the code below.
scala> df.withColumn("has_duplicates", when(count("*").over(Window.partitionBy(df.columns.map(col(_)): _*)) > 1 , lit("fail")).otherwise("pass")).show(false)
+---+---+--------------+
|A |B |has_duplicates|
+---+---+--------------+
|1 |7 |pass |
|1 |2 |fail |
|1 |2 |fail |
|2 |5 |fail |
|2 |5 |fail |
+---+---+--------------+
Or
scala> df.withColumn("count",count("*").over(Window.partitionBy(df.columns.map(col(_)): _*))).withColumn("has_duplicates", when($"count" > 1 , lit("fail")).otherwise("pass")).show(false)
+---+---+-----+--------------+
|A |B |count|has_duplicates|
+---+---+-----+--------------+
|1 |7 |1 |pass |
|1 |2 |2 |fail |
|1 |2 |2 |fail |
|2 |5 |2 |fail |
|2 |5 |2 |fail |
+---+---+-----+--------------+
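Both variants compute count("*") over a window partitioned by all columns. The plain-Scala sketch below does the same counting on a local Seq, which can help when checking expected outputs; it assumes two Int columns as in the example, and the object name DuplicateFlagSketch is hypothetical.

```scala
object DuplicateFlagSketch {
  // Counts identical rows (like count("*") over a window partitioned by every column)
  // and flags each row "fail" if it occurs more than once, "pass" otherwise.
  def flag(rows: Seq[(Int, Int)]): Seq[((Int, Int), String)] = {
    val counts: Map[(Int, Int), Int] =
      rows.groupBy(identity).map { case (row, same) => row -> same.size }
    rows.map(row => row -> (if (counts(row) > 1) "fail" else "pass"))
  }
}
```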

spark-scala: Transform the dataframe to generate new column gender and vice versa [closed]

Closed. This question needs to be more focused. It is not currently accepting answers.
Closed 2 years ago.
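Table1:
+-----+----+------+
|class|male|female|
+-----+----+------+
|1    |2   |1     |
|2    |0   |2     |
|3    |2   |0     |
+-----+----+------+
Table2:
+-----+------+
|class|gender|
+-----+------+
|1    |m     |
|1    |f     |
|1    |m     |
|2    |f     |
|2    |f     |
|3    |m     |
|3    |m     |
+-----+------+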
Using spark-scala, take the data from table1 and write it into another table in the format of table2. Please also do the reverse.
Please help me with this.
Thanks in advance.
You can use a UDF and the explode function as below.
import org.apache.spark.sql.functions._
import spark.implicits._
val df=Seq((1,2,1),(2,0,2),(3,2,0)).toDF("class","male","female")
//Input Df
+-----+----+------+
|class|male|female|
+-----+----+------+
| 1| 2| 1|
| 2| 0| 2|
| 3| 2| 0|
+-----+----+------+
val getGenderUdf=udf((x:Int,y:Int)=>List.fill(x)("m")++List.fill(y)("f"))
val df1=df.withColumn("gender",getGenderUdf(df.col("male"),df.col("female"))).drop("male","female").withColumn("gender",explode($"gender"))
df1.show()
+-----+------+
|class|gender|
+-----+------+
| 1| m|
| 1| m|
| 1| f|
| 2| f|
| 2| f|
| 3| m|
| 3| m|
+-----+------+
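Outside Spark, the UDF plus explode step is just a flatMap. Here is a plain-Scala sketch, assuming the rows are a local Seq of (class, male, female) tuples; the object name GenderExpandSketch is hypothetical.

```scala
object GenderExpandSketch {
  // Each (class, male, female) row becomes `male` rows of "m" followed by
  // `female` rows of "f", like List.fill in the UDF followed by explode.
  def expand(rows: Seq[(Int, Int, Int)]): Seq[(Int, String)] =
    rows.flatMap { case (cls, male, female) =>
      (List.fill(male)("m") ++ List.fill(female)("f")).map(g => (cls, g))
    }
}
```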
Reverse of df1 (back to one count column per gender):
val df2=df1.groupBy("class").pivot("gender").agg(count("gender")).na.fill(0).withColumnRenamed("m","male").withColumnRenamed("f","female")
df2.show()
//Sample Output:
+-----+------+----+
|class|female|male|
+-----+------+----+
| 1| 1| 2|
| 3| 0| 2|
| 2| 2| 0|
+-----+------+----+
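The pivot step can be checked the same way. This plain-Scala sketch (the object name GenderCountSketch is hypothetical) counts "m" and "f" per class, and a class with no rows of a gender gets 0, as na.fill(0) does after the pivot. Unlike the Spark output above, the result is sorted by class.

```scala
object GenderCountSketch {
  // Counts "m" and "f" per class, returning (class, male, female) sorted by class;
  // a missing gender simply counts as 0.
  def countBack(rows: Seq[(Int, String)]): Seq[(Int, Int, Int)] =
    rows.groupBy(_._1).toSeq.sortBy(_._1).map { case (cls, group) =>
      val genders = group.map(_._2)
      (cls, genders.count(_ == "m"), genders.count(_ == "f"))
    }
}
```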
val inDF = Seq((1,2,1),
(2, 0, 2),
(3, 2, 0)).toDF("class", "male", "female")
val testUdf = udf((m: Int, f: Int) => {
val ml = 1.to(m).map(_ => "m")
val fml = 1.to(f).map(_ => "f")
ml ++ fml
})
val df1 = inDF.withColumn("mf", testUdf('male, 'female))
.drop("male", "female")
.select('class, explode('mf).alias("gender"))
Perhaps this is helpful: a solution without a UDF.
Requires Spark 2.4 or later.
Load the test data provided:
val data =
"""
|class | male | female
|1 | 2 | 1
|2 | 0 | 2
|3 | 2 | 0
""".stripMargin
val stringDS1 = data.split(System.lineSeparator())
.map(_.split("\\|").map(_.replaceAll("""^[ \t]+|[ \t]+$""", "")).mkString(","))
.toSeq.toDS()
val df1 = spark.read
.option("sep", ",")
.option("inferSchema", "true")
.option("header", "true")
.option("nullValue", "null")
.csv(stringDS1)
df1.show(false)
df1.printSchema()
/**
* +-----+----+------+
* |class|male|female|
* +-----+----+------+
* |1 |2 |1 |
* |2 |0 |2 |
* |3 |2 |0 |
* +-----+----+------+
*
* root
* |-- class: integer (nullable = true)
* |-- male: integer (nullable = true)
* |-- female: integer (nullable = true)
*/
Compute the gender arrays and explode them:
val df2 = df1.select($"class",
when($"male" >= 1, sequence(lit(1), col("male"))).otherwise(array()).as("male"),
when($"female" >= 1, sequence(lit(1), col("female"))).otherwise(array()).as("female")
).withColumn("male", expr("TRANSFORM(male, x -> 'm')"))
.withColumn("female", expr("TRANSFORM(female, x -> 'f')"))
.withColumn("gender", explode(concat($"male", $"female")))
.select("class", "gender")
df2.show(false)
/**
* +-----+------+
* |class|gender|
* +-----+------+
* |1 |m |
* |1 |m |
* |1 |f |
* |2 |f |
* |2 |f |
* |3 |m |
* |3 |m |
* +-----+------+
*/
Vice versa:
df2.groupBy("class").agg(collect_list("gender").as("gender"))
.withColumn("male", expr("size(FILTER(gender, x -> x='m'))"))
.withColumn("female", expr("size(FILTER(gender, x -> x='f'))"))
.select("class", "male", "female")
.orderBy("class")
.show(false)
/**
* +-----+----+------+
* |class|male|female|
* +-----+----+------+
* |1 |2 |1 |
* |2 |0 |2 |
* |3 |2 |0 |
* +-----+----+------+
*/