I already posted a similar question, and someone gave me a trick to avoid using an "if condition".
Here I am in a similar position, and I cannot find any trick to avoid it...
I have a dataframe.
var df = sc.parallelize(Array(
(1, "2017-06-29 10:53:53.0","2017-06-25 14:60:53.0","boulanger.fr"),
(2, "2017-07-05 10:48:57.0","2017-09-05 08:60:53.0","patissier.fr"),
(3, "2017-06-28 10:31:42.0","2017-02-28 20:31:42.0","boulanger.fr"),
(4, "2017-08-21 17:31:12.0","2017-10-21 10:29:12.0","patissier.fr"),
(5, "2017-07-28 11:22:42.0","2017-05-28 11:22:42.0","boulanger.fr"),
(6, "2017-08-23 17:03:43.0","2017-07-23 09:03:43.0","patissier.fr"),
(7, "2017-08-24 16:08:07.0","2017-08-22 16:08:07.0","boulanger.fr"),
(8, "2017-08-31 17:20:43.0","2017-05-22 17:05:43.0","patissier.fr"),
(9, "2017-09-04 14:35:38.0","2017-07-04 07:30:25.0","boulanger.fr"),
(10, "2017-09-07 15:10:34.0","2017-07-29 12:10:34.0","patissier.fr"))).toDF("id", "date1","date2", "mail")
df = df.withColumn("date1", (unix_timestamp($"date1", "yyyy-MM-dd HH:mm:ss").cast("timestamp")))
df = df.withColumn("date2", (unix_timestamp($"date2", "yyyy-MM-dd HH:mm:ss").cast("timestamp")))
df = df.orderBy("date1", "date2")
It looks like:
+---+---------------------+---------------------+------------+
|id |date1 |date2 |mail |
+---+---------------------+---------------------+------------+
|3 |2017-06-28 10:31:42.0|2017-02-28 20:31:42.0|boulanger.fr|
|1 |2017-06-29 10:53:53.0|2017-06-25 15:00:53.0|boulanger.fr|
|2 |2017-07-05 10:48:57.0|2017-09-05 09:00:53.0|patissier.fr|
|5 |2017-07-28 11:22:42.0|2017-05-28 11:22:42.0|boulanger.fr|
|4 |2017-08-21 17:31:12.0|2017-10-21 10:29:12.0|patissier.fr|
|6 |2017-08-23 17:03:43.0|2017-07-23 09:03:43.0|patissier.fr|
|7 |2017-08-24 16:08:07.0|2017-08-22 16:08:07.0|boulanger.fr|
|8 |2017-08-31 17:20:43.0|2017-05-22 17:05:43.0|patissier.fr|
|9 |2017-09-04 14:35:38.0|2017-07-04 07:30:25.0|boulanger.fr|
|10 |2017-09-07 15:10:34.0|2017-07-29 12:10:34.0|patissier.fr|
+---+---------------------+---------------------+------------+
For each id I want to count, among all the other lines, the number of lines with:
a date1 in [my_current_date1 - 60 days, my_current_date1 - 1 day]
a date2 < my_current_date1
the same mail as my current mail
If I look at the line with id 5, I want to return the number of lines with:
date1 in [2017-05-29 11:22:42.0, 2017-07-27 11:22:42.0]
date2 < 2017-07-28 11:22:42.0
mail = boulanger.fr
--> The result would be 2 (corresponding to id 1 and id 3)
So I would like to do something like:
val w = Window.partitionBy("mail").orderBy(col("date1").cast("long")).rangeBetween(-60*24*60*60,-1*24*60*60)
df = df.withColumn("all_previous", count("mail").over(w))
But this satisfies conditions 1 and 3, not the second one... I have to add something to include this second condition comparing date2 to my date1...
Using a generalized Window spec with last(date1) being the current date1 per Window partition, and a sum over 0s and 1s as a conditional count, here's how I would incorporate your condition #2 into the counting criteria:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
def days(n: Long): Long = n * 24 * 60 * 60
val w = Window.partitionBy("mail").orderBy($"date1".cast("long"))
val w1 = w.rangeBetween(days(-60), days(0))
val w2 = w.rangeBetween(days(-60), days(-1))
df.withColumn("all_previous", sum(
when($"date2".cast("long") < last($"date1").over(w1).cast("long"), 1).
otherwise(0)
).over(w2)
).na.fill(0).
show
// +---+-------------------+-------------------+------------+------------+
// | id| date1| date2| mail|all_previous|
// +---+-------------------+-------------------+------------+------------+
// | 3|2017-06-28 10:31:42|2017-02-28 20:31:42|boulanger.fr| 0|
// | 1|2017-06-29 10:53:53|2017-06-25 15:00:53|boulanger.fr| 1|
// | 5|2017-07-28 11:22:42|2017-05-28 11:22:42|boulanger.fr| 2|
// | 7|2017-08-24 16:08:07|2017-08-22 16:08:07|boulanger.fr| 3|
// | 9|2017-09-04 14:35:38|2017-07-04 07:30:25|boulanger.fr| 2|
// | 2|2017-07-05 10:48:57|2017-09-05 09:00:53|patissier.fr| 0|
// | 4|2017-08-21 17:31:12|2017-10-21 10:29:12|patissier.fr| 0|
// | 6|2017-08-23 17:03:43|2017-07-23 09:03:43|patissier.fr| 0|
// | 8|2017-08-31 17:20:43|2017-05-22 17:05:43|patissier.fr| 1|
// | 10|2017-09-07 15:10:34|2017-07-29 12:10:34|patissier.fr| 2|
// +---+-------------------+-------------------+------------+------------+
[UPDATE]
This solution is incorrect, even though the result appears to be correct with the sample dataset. In particular, last($"date1").over(w1) did not work the way I intended. The answer is kept in the hope that it serves as a lead towards a working solution.
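For a possible working direction (my own sketch, not part of the original answer), the three conditions can be expressed directly as a conditional self-join on the same dataframe, assuming the df above with id, date1, date2, mail and the timestamps already parsed:
// Sketch only: count matching rows per id with a non-equi self-join.
import org.apache.spark.sql.functions._

val other = df.select(
  $"id".as("o_id"), $"date1".as("o_date1"),
  $"date2".as("o_date2"), $"mail".as("o_mail"))

val counted = df.join(other,
    $"mail" === $"o_mail" && $"id" =!= $"o_id" &&          // same mail, other lines only
    $"o_date1" >= $"date1" - expr("INTERVAL 60 DAYS") &&   // condition 1: date1 window
    $"o_date1" <= $"date1" - expr("INTERVAL 1 DAY") &&
    $"o_date2" < $"date1",                                 // condition 2: date2 before current date1
    "left")
  .groupBy($"id", $"date1", $"date2", $"mail")
  .agg(count($"o_id").as("all_previous"))                  // count(o_id) ignores non-matching left rows
The join is more expensive than a window, but it evaluates all three conditions per pair of rows, which the window-based attempt could not.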
I have a dataframe called rawEDI that looks something like this:
| Line_number | Segment |
| ----------- | ------- |
| 1           | ST      |
| 2           | BPT     |
| 3           | SE      |
| 4           | ST      |
| 5           | BPT     |
| 6           | N1      |
| 7           | SE      |
| 8           | ST      |
| 9           | PTD     |
| 10          | SE      |
Each row represents a line in a file. Each line is called a segment and is denoted by a segment identifier, a short string. Segments are grouped together in chunks that start with an ST segment identifier and end with an SE segment identifier. There can be any number of ST chunks in a given file, and the size of an ST chunk is not fixed.
I want to create a new column on the dataframe that represents numerically what ST group a given segment belongs to. This will allow me to use groupBy to perform aggregate operations across all ST segments without having to loop over each individual ST segment, which is too slow.
The final DataFrame would look like this:
| Line_number | Segment | ST_Group |
| ----------- | ------- | -------- |
| 1           | ST      | 1        |
| 2           | BPT     | 1        |
| 3           | SE      | 1        |
| 4           | ST      | 2        |
| 5           | BPT     | 2        |
| 6           | N1      | 2        |
| 7           | SE      | 2        |
| 8           | ST      | 3        |
| 9           | PTD     | 3        |
| 10          | SE      | 3        |
In short, I want to create and populate a DataFrame column with a number that increments by one whenever the value "ST" appears in the Segment column.
I am using Spark 2.3.2 and Scala 2.11.8.
My initial thought was to use iteration. I collected another DataFrame, df, that contains the starting and ending line_number for each ST chunk, looking like this:
| Start | End |
| ----- | --- |
| 1     | 3   |
| 4     | 7   |
| 8     | 10  |
Then iterate over the rows of the dataframe and use them to populate the new column, like this:
var st = 1
for (row <- df.collect()) {
  val start = row(0)
  val end = row(1)
  var labelSTs = rawEDI.filter("line_number > = ${start}").filter("line_number <= ${end}").withColumn("ST_Group", lit(st))
  st = st + 1
}
However, this yields an empty DataFrame. Additionally, the use of a for loop is time-prohibitive, taking over 20s on my machine for this. Achieving this result without the use of a loop would be huge, but a solution with a loop may also be acceptable if performant.
I have a hunch this can be accomplished using a udf or a Window, but I'm not certain how to attack that.
This:
val func = udf((s: String) => if (s == "ST") 1 else 0)
var labelSTs = rawEDI.withColumn("ST_Group", func(col("segment")))
only populates the column with 1 at each ST segment start.
And this:
val w = Window.partitionBy("Segment").orderBy("line_number")
val labelSTs = rawEDI.withColumn("ST_Group", row_number().over(w))
returns a nonsense dataframe.
One way is to create an intermediate dataframe of "groups" that would tell you on which line each group starts and ends (sort of what you've already done), and then join it to the original table using greater-than/less-than conditions.
Sample data
scala> val input = Seq((1,"ST"),(2,"BPT"),(3,"SE"),(4,"ST"),(5,"BPT"),
(6,"N1"),(7,"SE"),(8,"ST"),(9,"PTD"),(10,"SE"))
.toDF("linenumber","segment")
scala> input.show(false)
+----------+-------+
|linenumber|segment|
+----------+-------+
|1 |ST |
|2 |BPT |
|3 |SE |
|4 |ST |
|5 |BPT |
|6 |N1 |
|7 |SE |
|8 |ST |
|9 |PTD |
|10 |SE |
+----------+-------+
Create a dataframe for groups, using Window just as your hunch was telling you:
scala> val groups = input.where("segment='ST'")
.withColumn("endline",lead("linenumber",1) over Window.orderBy("linenumber"))
.withColumn("groupnumber",row_number() over Window.orderBy("linenumber"))
.withColumnRenamed("linenumber","startline")
.drop("segment")
scala> groups.show(false)
+---------+-----------+-------+
|startline|groupnumber|endline|
+---------+-----------+-------+
|1 |1 |4 |
|4 |2 |8 |
|8 |3 |null |
+---------+-----------+-------+
Join both to get the result
scala> input.join(groups,
input("linenumber") >= groups("startline") &&
(input("linenumber") < groups("endline") || groups("endline").isNull))
.select("linenumber","segment","groupnumber")
.show(false)
+----------+-------+-----------+
|linenumber|segment|groupnumber|
+----------+-------+-----------+
|1 |ST |1 |
|2 |BPT |1 |
|3 |SE |1 |
|4 |ST |2 |
|5 |BPT |2 |
|6 |N1 |2 |
|7 |SE |2 |
|8 |ST |3 |
|9 |PTD |3 |
|10 |SE |3 |
+----------+-------+-----------+
The only problem with this is Window.orderBy() on an unpartitioned dataframe, which collects all data to a single partition and thus could be a performance killer.
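A related, hedged note (my own addition, not from the original answer): the windows above only run over the filtered ST rows, which is usually small; for the join itself, broadcasting the small groups dataframe avoids shuffling the large input:
import org.apache.spark.sql.functions.broadcast

// groups has one row per ST chunk, so broadcast it; the large input dataframe
// is then not shuffled for this non-equi join.
input.join(broadcast(groups),
    input("linenumber") >= groups("startline") &&
    (input("linenumber") < groups("endline") || groups("endline").isNull))
  .select("linenumber", "segment", "groupnumber")
  .show(false)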
If you just want to add a column with a number that increments by one whenever the value "ST" appears in the Segment column, you can first filter the lines with the ST segment into a separate dataframe:
var labelSTs = rawEDI.filter("segement == 'ST'");
// then group by ST and collect to list the linenumbers
var groupedDf = labelSTs.groupBy("Segment").agg(collect_list("Line_number").alias("Line_numbers"))
// now you need to flat back the data frame and log the line number index
var flattedDf = groupedDf.select($"Segment", explode($"Line_numbers").as("Line_number"))
// log the line_number index in your target column ST_Group
val withIndexDF = flattenedDF.withColumn("ST_Group", row_number().over(Window.partitionBy($"Segment").orderBy($"Line_number")))
and you have this as a result:
+-------+-----------+--------+
|Segment|Line_number|ST_Group|
+-------+-----------+--------+
|     ST|          1|       1|
|     ST|          4|       2|
|     ST|          8|       3|
+-------+-----------+--------+
Then you can join this back with the other segments in the initial dataframe.
I found a simpler way: add a column which will have 1 when the segment column value is ST, and 0 otherwise. Then, using a Window function, compute the cumulative sum of that new column. This will give you the desired results.
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
val rawEDI=Seq((1,"ST"),(2,"BPT"),(3,"SE"),(4,"ST"),(5,"BPT"),(6,"N1"),(7,"SE"),(8,"ST"),(9,"PTD"),(10,"SE")).toDF("line_number","segment")
val newDf=rawEDI.withColumn("ST_Group", ($"segment" === "ST").cast("bigint"))
val windowSpec = Window.orderBy("line_number")
newDf.withColumn("ST_Group", sum("ST_Group").over(windowSpec))
.show
+-----------+-------+--------+
|line_number|segment|ST_Group|
+-----------+-------+--------+
| 1| ST| 1|
| 2| BPT| 1|
| 3| SE| 1|
| 4| ST| 2|
| 5| BPT| 2|
| 6| N1| 2|
| 7| SE| 2|
| 8| ST| 3|
| 9| PTD| 3|
| 10| SE| 3|
+-----------+-------+--------+
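As a usage sketch (my own addition, building on the question's stated goal of using groupBy per ST chunk; the specific aggregates are made up for illustration), once ST_Group is in place the per-chunk aggregations become a plain groupBy over the newDf and windowSpec defined above:
// Hypothetical per-chunk aggregation over the ST_Group column computed above.
newDf.withColumn("ST_Group", sum("ST_Group").over(windowSpec))
  .groupBy("ST_Group")
  .agg(count("segment").as("segments_per_chunk"),
       min("line_number").as("start_line"),
       max("line_number").as("end_line"))
  .orderBy("ST_Group")
  .show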
I'm new to Spark SQL, and I want to calculate the percentage of each status in my data.
Here is my data:
A B
11 1
11 3
12 1
13 3
12 2
13 1
11 1
12 2
So, I can do it in SQL like this:
select (C.oneTotal / C.total) as onePercentage,
(C.twoTotal / C.total) as twotPercentage,
(C.threeTotal / C.total) as threPercentage
from (select count(*) as total,
A,
sum(case when B = '1' then 1 else 0 end) as oneTotal,
sum(case when B = '2' then 1 else 0 end) as twoTotal,
sum(case when B = '3' then 1 else 0 end) as threeTotal
from test
group by A) as C;
But with the Spark SQL DataFrame API, I first calculate the total count of every status like below:
// wrong code
val cc = transDF.select("transData.*").groupBy("A")
.agg(count("transData.*").alias("total"),
sum(when(col("B") === "1", 1)).otherwise(0)).alias("oneTotal")
sum(when(col("B") === "2", 1).otherwise(0)).alias("twoTotal")
The problem is that the sum(when)'s result is zero.
Am I using it incorrectly?
How can I implement this in Spark SQL, just like my SQL above, and then calculate the proportion of every status?
Thank you for your help. In the end, I solved it with sum(when); below is my current code.
val cc = transDF.select("transData.*").groupBy("A")
.agg(count("transData.*").alias("total"),
sum(when(col("B") === "1", 1).otherwise(0)).alias("oneTotal"),
sum(when(col("B") === "2", 1).otherwise(0)).alias("twoTotal"))
.select(col("total"),
col("A"),
col("oneTotal") / col("total").alias("oneRate"),
col("twoTotal") / col("total").alias("twoRate"))
Thanks again.
You can use sum(when(...)) or also count(when(...)), the second option being shorter to write:
val df = Seq(
(11, 1),
(11, 3),
(12, 1),
(13, 3),
(12, 2),
(13, 1),
(11, 1),
(12, 2)
).toDF("A", "B")
df
.groupBy($"A")
.agg(
count("*").as("total"),
count(when($"B"==="1",$"A")).as("oneTotal"),
count(when($"B"==="2",$"A")).as("twoTotal"),
count(when($"B"==="3",$"A")).as("threeTotal")
)
.select(
$"A",
($"oneTotal"/$"total").as("onePercentage"),
($"twoTotal"/$"total").as("twoPercentage"),
($"threeTotal"/$"total").as("threePercentage")
)
.show()
gives
+---+------------------+------------------+------------------+
| A| onePercentage| twoPercentage| threePercentage|
+---+------------------+------------------+------------------+
| 12|0.3333333333333333|0.6666666666666666| 0.0|
| 13| 0.5| 0.0| 0.5|
| 11|0.6666666666666666| 0.0|0.3333333333333333|
+---+------------------+------------------+------------------+
Alternatively, you could produce a "long" table with window functions:
import org.apache.spark.sql.expressions.Window

df
.groupBy($"A",$"B").count()
.withColumn("total",sum($"count").over(Window.partitionBy($"A")))
.select(
$"A",
$"B",
($"count"/$"total").as("percentage")
).orderBy($"A",$"B")
.show()
+---+---+------------------+
| A| B| percentage|
+---+---+------------------+
| 11| 1|0.6666666666666666|
| 11| 3|0.3333333333333333|
| 12| 1|0.3333333333333333|
| 12| 2|0.6666666666666666|
| 13| 1| 0.5|
| 13| 3| 0.5|
+---+---+------------------+
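If the wide, one-row-per-A layout is still wanted, the long table can be pivoted back. This is a sketch of my own (not part of the original answer), assuming the same df and that the status values are exactly 1, 2 and 3:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

df
  .groupBy($"A", $"B").count()
  .withColumn("total", sum($"count").over(Window.partitionBy($"A")))
  .withColumn("percentage", $"count" / $"total")
  .groupBy($"A")
  .pivot("B", Seq(1, 2, 3))        // listing the pivot values keeps the plan static
  .agg(first($"percentage"))
  .na.fill(0.0)                    // missing statuses become 0.0 instead of null
  .show()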
As far as I understand, you want to implement the same logic as the SQL shown in the question. One way is the example below:
package examples
import org.apache.log4j.Level
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
object AggTest extends App {
val logger = org.apache.log4j.Logger.getLogger("org")
logger.setLevel(Level.WARN)
val spark = SparkSession.builder.appName(getClass.getName)
.master("local[*]").getOrCreate
import spark.implicits._
val df = Seq(
(11, 1),
(11, 3),
(12, 1),
(13, 3),
(12, 2),
(13, 1),
(11, 1),
(12, 2)
).toDF("A", "B")
df.show(false)
df.createOrReplaceTempView("test")
spark.sql(
"""
|select (C.oneTotal / C.total) as onePercentage,
| (C.twoTotal / C.total) as twotPercentage,
| (C.threeTotal / C.total) as threPercentage
|from (select count(*) as total,
| A,
| sum(case when B = '1' then 1 else 0 end) as oneTotal,
| sum(case when B = '2' then 1 else 0 end) as twoTotal,
| sum(case when B = '3' then 1 else 0 end) as threeTotal
| from test
| group by A) as C
""".stripMargin).show
}
Result:
+---+---+
|A |B |
+---+---+
|11 |1 |
|11 |3 |
|12 |1 |
|13 |3 |
|12 |2 |
|13 |1 |
|11 |1 |
|12 |2 |
+---+---+
+------------------+------------------+------------------+
| onePercentage| twotPercentage| threPercentage|
+------------------+------------------+------------------+
|0.3333333333333333|0.6666666666666666| 0.0|
| 0.5| 0.0| 0.5|
|0.6666666666666666| 0.0|0.3333333333333333|
+------------------+------------------+------------------+
I am using Spark 2.3 in my Scala application. I have a dataframe created from Spark SQL, named sqlDF in the sample code I shared. I have a string list with the items below:
List[] stringList items
-9,-8,-7,-6
I want to replace, in all columns of the dataframe, all values that match an item of this list with 0.
Initial dataframe
column1 | column2 | column3
      1 |       1 |       1
      2 |      -5 |       1
      6 |      -6 |       1
     -7 |      -8 |      -7
It must return:
column1 | column2 | column3
      1 |       1 |       1
      2 |      -5 |       1
      6 |       0 |       1
      0 |       0 |       0
For this, I am iterating the query below over all columns (more than 500) in sqlDF.
sqlDF = sqlDF.withColumn(currColumnName, when(col(currColumnName).isin(stringList:_*), 0).otherwise(col(currColumnName)))
But I am getting the error below. By the way, if I choose only one column to iterate over, it works; but if I run the code above iterating over 500 columns, it fails:
Exception in thread "streaming-job-executor-0" java.lang.StackOverflowError
  at scala.collection.generic.GenTraversableFactory$GenericCanBuildFrom.apply(GenTraversableFactory.scala:57)
  at scala.collection.generic.GenTraversableFactory$GenericCanBuildFrom.apply(GenTraversableFactory.scala:52)
  at scala.collection.TraversableLike$class.builder$1(TraversableLike.scala:229)
  at scala.collection.TraversableLike$class.map(TraversableLike.scala:233)
  at scala.collection.immutable.List.map(List.scala:285)
  at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:333)
  at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)
What am I missing?
Here is a different approach, applying a left anti join between each column and X, where X is your list of items turned into a dataframe. The left anti join returns all the items not present in X; we then concatenate the results through an outer join (which can be replaced with a left join for better performance, though that will exclude records with all zeros, i.e. id == 3) based on an id assigned with monotonically_increasing_id:
import org.apache.spark.sql.functions.{monotonically_increasing_id, col}
val df = Seq(
(1, 1, 1),
(2, -5, 1),
(6, -6, 1),
(-7, -8, -7))
.toDF("c1", "c2", "c3")
.withColumn("id", monotonically_increasing_id())
val exdf = Seq(-9, -8, -7, -6).toDF("x")
df.columns.map{ c =>
df.select("id", c).join(exdf, col(c) === $"x", "left_anti")
}
.reduce((df1, df2) => df1.join(df2, Seq("id"), "outer"))
.na.fill(0)
.show
Output:
+---+---+---+---+
| id| c1| c2| c3|
+---+---+---+---+
| 0| 1| 1| 1|
| 1| 2| -5| 1|
| 3| 0| 0| 0|
| 2| 6| 0| 1|
+---+---+---+---+
foldLeft works perfectly for your case here, as below:
import org.apache.spark.sql.functions.{col, when}
import spark.implicits._

val df = spark.sparkContext.parallelize(Seq(
(1, 1, 1),
(2, -5, 1),
(6, -6, 1),
(-7, -8, -7)
)).toDF("a", "b", "c")
val list = Seq(-7, -8, -9)
val resultDF = df.columns.foldLeft(df) { (acc, name) => {
acc.withColumn(name, when(col(name).isin(list: _*), 0).otherwise(col(name)))
}
}
Output:
+---+---+---+
|a |b |c |
+---+---+---+
|1 |1 |1 |
|2 |-5 |1 |
|6 |-6 |1 |
|0 |0 |0 |
+---+---+---+
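On the original StackOverflowError with 500+ columns: a hedged alternative (my own sketch, reusing the df and list from the example above; your real dataframe is sqlDF with its own list) is to build every replacement expression in a single select instead of chaining hundreds of withColumn calls, which keeps the logical plan shallow:
import org.apache.spark.sql.functions.{col, when}

// one select with all replacement expressions, instead of hundreds of chained withColumns
val replaced = df.select(df.columns.map { name =>
  when(col(name).isin(list: _*), 0).otherwise(col(name)).as(name)
}: _*)
replaced.show(false)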
I would suggest you broadcast the list of Strings:
val stringList=sc.broadcast(<Your List of List[String]>)
After that use this :
sqlDF = sqlDF.withColumn(currColumnName, when(col(currColumnName).isin(stringList.value:_*), 0).otherwise(col(currColumnName)))
Make sure your currColumnName column is also in String format; the comparison should be String to String.
I am trying to join 2 data frames; from the 1st DF I need to pass a dynamic number of columns and join it with another DF. The complexity I am facing here is that I have a case statement on the output of the 1st DF. I am able to get the desired output by creating temp views, but I am not able to achieve the same output through the DataFrame API.
Below is the snippet I have tried; it works as expected.
// Sample DF1
val studentDF = Seq(
(1, "Peter","M",15,"Tution Received"),
(2, "Merry","F",14,null),
(3, "Sam","M",16,"Tution Received"),
(4, "Kat","O",16,null),
(5, "Keivn","M",18,null)
).toDF("Enrollment", "Name","Gender","Age","Notes")
//Sample DF2
val studentFees = Seq((1,"$500","Deposit"),(2, "$800","Deposit"),(3,"$200","Deposit"),(4,"$100","Deposit")).toDF("Enrollment","Fees","Notes")
studentDF.createOrReplaceTempView("STUDENT")
studentFees.createOrReplaceTempView("FEES")
val displayColumns = List("Enrollment","Name","Gender").map("a."+_).reduce(_+","+_)
val queryStr = spark.sql(s"select $displayColumns, case when a.Notes is null then b.Notes else a.Notes end as Notes, b.Fees from STUDENT a join FEES b on a.Enrollment=b.Enrollment")
queryStr.show()
+----------+-----+------+---------------+----+
|Enrollment| Name|Gender| Notes|Fees|
+----------+-----+------+---------------+----+
| 1|Peter| M|Tution Received|$500|
| 2|Merry| F| Deposit|$800|
| 3| Sam| M|Tution Received|$200|
| 4| Kat| O| Deposit|$100|
+----------+-----+------+---------------+----+
// Below is not giving the desired output
val displayColumns = List("Enrollment","Name","Gender","Notes")
val queryStr = studentDF.select(displayColumns.head, displayColumns.tail: _*).alias("a")
  .join(studentFees.as("b"), Seq("Enrollment"), "inner")
  .withColumn("Notes", when($"a.Notes".isNull, $"b.Notes").otherwise($"a.Notes"))
queryStr.show()
+----------+-----+------+---------------+----+---------------+
|Enrollment| Name|Gender|          Notes|Fees|          Notes|
+----------+-----+------+---------------+----+---------------+
| 1|Peter| M|Tution Received|$500|Tution Received|
| 2|Merry| F| Deposit|$800| Deposit|
| 3| Sam| M|Tution Received|$200|Tution Received|
| 4| Kat| O| Deposit|$100| Deposit|
+----------+-----+------+---------------+----+---------------+
// Expecting the output like below.
+----------+-----+------+---------------+----+
|Enrollment| Name|Gender| Notes|Fees|
+----------+-----+------+---------------+----+
| 1|Peter| M|Tution Received|$500|
| 2|Merry| F| Deposit|$800|
| 3| Sam| M|Tution Received|$200|
| 4| Kat| O| Deposit|$100|
+----------+-----+------+---------------+----+
Is there a better way to handle such scenarios instead of creating temp tables/views?
Thank you all, whoever read my post!!
I was able to find the solution to my problem.
val displayColumns = List("Enrollment","Name","Gender","Notes")
val queryStr = studentDF.select(displayColumns.head, displayColumns.tail: _*).alias("a")
  .join(studentFees.as("b"), Seq("Enrollment"), "inner")
  .select($"a.*", when($"a.Notes".isNull, $"b.Notes").otherwise($"a.Notes").as("Notes"), $"b.Fees")
  .drop($"a.Notes")
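A slightly shorter, equivalent variant (my own sketch, using the same dataframes defined above) replaces the when/otherwise with coalesce, which picks the first non-null value:
import org.apache.spark.sql.functions.coalesce

// coalesce(a.Notes, b.Notes) keeps a.Notes when present, otherwise falls back to b.Notes
val queryStr2 = studentDF.select(displayColumns.head, displayColumns.tail: _*).alias("a")
  .join(studentFees.as("b"), Seq("Enrollment"), "inner")
  .select($"a.Enrollment", $"a.Name", $"a.Gender",
    coalesce($"a.Notes", $"b.Notes").as("Notes"), $"b.Fees")
queryStr2.show()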
I am new to UDFs in Spark. I have also read the answer here.
Problem statement: I'm trying to do pattern matching on a dataframe column.
Ex: Dataframe
val df = Seq((1, Some("z")), (2, Some("abs,abc,dfg")),
(3,Some("a,b,c,d,e,f,abs,abc,dfg"))).toDF("id", "text")
df.show()
+---+--------------------+
| id| text|
+---+--------------------+
| 1| z|
| 2| abs,abc,dfg|
| 3|a,b,c,d,e,f,abs,a...|
+---+--------------------+
df.filter($"text".contains("abs,abc,dfg")).count()
//returns 2 as abs exits in 2nd row and 3rd row
Now I want to do this pattern matching for every row in column $text and add new column called count.
Result:
+---+--------------------+-----+
| id| text|count|
+---+--------------------+-----+
| 1| z| 1|
| 2| abs,abc,dfg| 2|
| 3|a,b,c,d,e,f,abs,a...| 1|
+---+--------------------+-----+
I tried to define a udf, passing the $text column as Array[Seq[String]]. But I am not able to get what I intended.
What I tried so far:
val txt = df.select("text").collect.map(_.toSeq.map(_.toString)) // convert the column to Array[Seq[String]]
val valsum = udf((txt: Array[Seq[String]], pattern: String) => { txt.count(_ == pattern) })
df.withColumn("newCol", valsum(lit(txt), df("text"))).show()
Any help would be appreciated.
You will have to know all the elements of the text column, which can be done using collect_list by grouping all the rows of your dataframe into one group. Then just count, for each row, how many elements of the collected array contain that row's text value, as in the following code.
import scala.collection.mutable
import sqlContext.implicits._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._
val df = Seq((1, Some("z")), (2, Some("abs,abc,dfg")),(3,Some("a,b,c,d,e,f,abs,abc,dfg"))).toDF("id", "text")
val valsum = udf((txt: String, array : mutable.WrappedArray[String])=> array.filter(element => element.contains(txt)).size)
df.withColumn("grouping", lit("g"))
.withColumn("array", collect_list("text").over(Window.partitionBy("grouping")))
.withColumn("count", valsum($"text", $"array"))
.drop("grouping", "array")
.show(false)
You should have the following output:
+---+-----------------------+-----+
|id |text |count|
+---+-----------------------+-----+
|1 |z |1 |
|2 |abs,abc,dfg |2 |
|3 |a,b,c,d,e,f,abs,abc,dfg|1 |
+---+-----------------------+-----+
I hope this is helpful.