I have the following DataFrame:
name,email,phone,country
------------------------------------------------
[Mike,mike#example.com,+91-9999999999,Italy]
[Alex,alex#example.com,+91-9999999998,France]
[John,john#example.com,+1-1111111111,United States]
[Donald,donald#example.com,+1-2222222222,United States]
[Dan,dan#example.com,+91-9999444999,Poland]
[Scott,scott#example.com,+91-9111999998,Spain]
[Rob,rob#example.com,+91-9114444998,Italy]
which is exposed as a temp table tagged_users:
resultDf.createOrReplaceTempView("tagged_users")
I need to add an additional column tag to this DataFrame and assign calculated tags based on different SQL conditions, which are described in the following map (key: tag name, value: condition for the WHERE clause):
val tags = Map(
  "big" -> "country IN (SELECT * FROM big_countries)",
  "medium" -> "country IN (SELECT * FROM medium_countries)",
  //2000 other different tags and conditions
  "sometag" -> "name = 'Donald' AND email = 'donald#example.com' AND phone = '+1-2222222222'"
)
I have the following DataFrames (as data dictionaries) registered so that I can use them in SQL queries:
Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")
I want to test each row in my tagged_users table and assign it the appropriate tags. I tried to implement the following logic to achieve this:
tags.foreach {
  case (tag, tagCondition) => {
    resultDf = spark.sql(buildTagQuery(tag, tagCondition, "tagged_users"))
      .withColumn("tag", lit(tag).cast(StringType))
  }
}

def buildTagQuery(tag: String, tagCondition: String, table: String): String = {
  f"SELECT * FROM $table WHERE $tagCondition"
}
but I don't know how to accumulate tags rather than override them. Currently the result is the following DataFrame:
name,email,phone,country,tag
Dan,dan#example.com,+91-9999444999,Poland,medium
Scott,scott#example.com,+91-9111999998,Spain,medium
but I need something like:
name,email,phone,country,tag
Mike,mike#example.com,+91-9999999999,Italy,big
Alex,alex#example.com,+91-9999999998,France,big
John,john#example.com,+1-1111111111,United States,big
Donald,donald#example.com,+1-2222222222,United States,(big|sometag)
Dan,dan#example.com,+91-9999444999,Poland,medium
Scott,scott#example.com,+91-9111999998,Spain,(big|medium)
Rob,rob#example.com,+91-9114444998,Italy,big
Please note that Donald should have 2 tags (big|sometag) and Scott should have 2 tags (big|medium).
Please show how to implement it.
UPDATED
val spark = SparkSession
  .builder()
  .appName("Java Spark SQL basic example")
  .config("spark.master", "local")
  .getOrCreate()

import spark.implicits._
import spark.sql

Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")

val df = Seq(
  ("Mike", "mike#example.com", "+91-9999999999", "Italy"),
  ("Alex", "alex#example.com", "+91-9999999998", "France"),
  ("John", "john#example.com", "+1-1111111111", "United States"),
  ("Donald", "donald#example.com", "+1-2222222222", "United States"),
  ("Dan", "dan#example.com", "+91-9999444999", "Poland"),
  ("Scott", "scott#example.com", "+91-9111999998", "Spain"),
  ("Rob", "rob#example.com", "+91-9114444998", "Italy")).toDF("name", "email", "phone", "country")

df.collect.foreach(println)
df.createOrReplaceTempView("tagged_users")

val tags = Map(
  "big" -> "country IN (SELECT * FROM big_countries)",
  "medium" -> "country IN (SELECT * FROM medium_countries)",
  "sometag" -> "name = 'Donald' AND email = 'donald#example.com' AND phone = '+1-2222222222'")

val sep_tag = tags.map((x) => { s"when array_contains(" + x._1 + ", country) then '" + x._1 + "' " }).mkString
val combine_sel_tag1 = tags.map((x) => { s" array_contains(" + x._1 + ",country) " }).mkString(" and ")
val combine_sel_tag2 = tags.map((x) => x._1).mkString(" '(", "|", ")' ")
val combine_sel_all = " case when " + combine_sel_tag1 + " then " + combine_sel_tag2 + sep_tag + " end as tags "
val crosqry = tags.map((x) => { s" cross join ( select collect_list(country) as " + x._1 + " from " + x._1 + "_countries) " + x._1 + " " }).mkString
val qry = " select name,email,phone,country, " + combine_sel_all + " from tagged_users " + crosqry

spark.sql(qry).show
spark.stop()
fails with the following exception, because crosqry builds a cross join against a <tag>_countries table for every key in tags, and no sometag_countries table exists:
Caused by: org.apache.spark.sql.catalyst.analysis.NoSuchTableException: Table or view 'sometag_countries' not found in database 'default';
at org.apache.spark.sql.catalyst.catalog.ExternalCatalog$class.requireTableExists(ExternalCatalog.scala:48)
at org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.requireTableExists(InMemoryCatalog.scala:45)
at org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.getTable(InMemoryCatalog.scala:326)
at org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener.getTable(ExternalCatalogWithListener.scala:138)
at org.apache.spark.sql.catalyst.catalog.SessionCatalog.lookupRelation(SessionCatalog.scala:701)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations$.org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveRelations$$lookupTableFromCatalog(Analyzer.scala:730)
... 74 more
Check out this DF solution:
scala> val df = Seq(("Mike","mike#example.com","+91-9999999999","Italy"),
| ("Alex","alex#example.com","+91-9999999998","France"),
| ("John","john#example.com","+1-1111111111","United States"),
| ("Donald","donald#example.com","+1-2222222222","United States"),
| ("Dan","dan#example.com","+91-9999444999","Poland"),
| ("Scott","scott#example.com","+91-9111999998","Spain"),
| ("Rob","rob#example.com","+91-9114444998","Italy")
| ).toDF("name","email","phone","country")
df: org.apache.spark.sql.DataFrame = [name: string, email: string ... 2 more fields]
scala> val dfbc=Seq("Italy", "France", "United States", "Spain").toDF("country")
dfbc: org.apache.spark.sql.DataFrame = [country: string]
scala> val dfmc=Seq("Poland", "Hungary", "Spain").toDF("country")
dfmc: org.apache.spark.sql.DataFrame = [country: string]
scala> val dfbc2=dfbc.agg(collect_list('country).as("bcountry"))
dfbc2: org.apache.spark.sql.DataFrame = [bcountry: array<string>]
scala> val dfmc2=dfmc.agg(collect_list('country).as("mcountry"))
dfmc2: org.apache.spark.sql.DataFrame = [mcountry: array<string>]
scala> val df2=df.crossJoin(dfbc2).crossJoin(dfmc2)
df2: org.apache.spark.sql.DataFrame = [name: string, email: string ... 4 more fields]
scala> df2.selectExpr("*","case when array_contains(bcountry,country) and array_contains(mcountry,country) then '(big|medium)' when array_contains(bcountry,country) then 'big' when array_contains(mcountry,country) then 'medium' else 'none' end as `tags`").select("name","email","phone","country","tags").show(false)
+------+------------------+--------------+-------------+------------+
|name |email |phone |country |tags |
+------+------------------+--------------+-------------+------------+
|Mike |mike#example.com |+91-9999999999|Italy |big |
|Alex |alex#example.com |+91-9999999998|France |big |
|John |john#example.com |+1-1111111111 |United States|big |
|Donald|donald#example.com|+1-2222222222 |United States|big |
|Dan |dan#example.com |+91-9999444999|Poland |medium |
|Scott |scott#example.com |+91-9111999998|Spain |(big|medium)|
|Rob |rob#example.com |+91-9114444998|Italy |big |
+------+------------------+--------------+-------------+------------+
scala>
SQL approach
scala> Seq(("Mike","mike#example.com","+91-9999999999","Italy"),
| ("Alex","alex#example.com","+91-9999999998","France"),
| ("John","john#example.com","+1-1111111111","United States"),
| ("Donald","donald#example.com","+1-2222222222","United States"),
| ("Dan","dan#example.com","+91-9999444999","Poland"),
| ("Scott","scott#example.com","+91-9111999998","Spain"),
| ("Rob","rob#example.com","+91-9114444998","Italy")
| ).toDF("name","email","phone","country").createOrReplaceTempView("tagged_users")
scala> Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
scala> Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")
scala> spark.sql(""" select name,email,phone,country,case when array_contains(bc,country) and array_contains(mc,country) then '(big|medium)' when array_contains(bc,country) then 'big' when array_contains(mc,country) then 'medium' else 'none' end as tags from tagged_users cross join ( select collect_list(country) as bc from big_countries ) b cross join ( select collect_list(country) as mc from medium_countries ) c """).show(false)
+------+------------------+--------------+-------------+------------+
|name |email |phone |country |tags |
+------+------------------+--------------+-------------+------------+
|Mike |mike#example.com |+91-9999999999|Italy |big |
|Alex |alex#example.com |+91-9999999998|France |big |
|John |john#example.com |+1-1111111111 |United States|big |
|Donald|donald#example.com|+1-2222222222 |United States|big |
|Dan |dan#example.com |+91-9999444999|Poland |medium |
|Scott |scott#example.com |+91-9111999998|Spain |(big|medium)|
|Rob |rob#example.com |+91-9114444998|Italy |big |
+------+------------------+--------------+-------------+------------+
scala>
Iterating through the tags
scala> val tags = Map(
| "big" -> "country IN (SELECT * FROM big_countries)",
| "medium" -> "country IN (SELECT * FROM medium_countries)")
tags: scala.collection.immutable.Map[String,String] = Map(big -> country IN (SELECT * FROM big_countries), medium -> country IN (SELECT * FROM medium_countries))
scala> val sep_tag = tags.map( (x) => { s"when array_contains("+x._1+", country) then '" + x._1 + "' " } ).mkString
sep_tag: String = "when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium' "
scala> val combine_sel_tag1 = tags.map( (x) => { s" array_contains("+x._1+",country) " } ).mkString(" and ")
combine_sel_tag1: String = " array_contains(big,country) and array_contains(medium,country) "
scala> val combine_sel_tag2 = tags.map( (x) => x._1 ).mkString(" '(","|", ")' ")
combine_sel_tag2: String = " '(big|medium)' "
scala> val combine_sel_all = " case when " + combine_sel_tag1 + " then " + combine_sel_tag2 + sep_tag + " end as tags "
combine_sel_all: String = " case when array_contains(big,country) and array_contains(medium,country) then '(big|medium)' when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium' end as tags "
scala> val crosqry = tags.map( (x) => { s" cross join ( select collect_list(country) as "+x._1+" from "+x._1+"_countries) "+ x._1 + " " } ).mkString
crosqry: String = " cross join ( select collect_list(country) as big from big_countries) big cross join ( select collect_list(country) as medium from medium_countries) medium "
scala> val qry = " select name,email,phone,country, " + combine_sel_all + " from tagged_users " + crosqry
qry: String = " select name,email,phone,country, case when array_contains(big,country) and array_contains(medium,country) then '(big|medium)' when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium' end as tags from tagged_users cross join ( select collect_list(country) as big from big_countries) big cross join ( select collect_list(country) as medium from medium_countries) medium "
scala> spark.sql(qry).show
+------+------------------+--------------+-------------+------------+
| name| email| phone| country| tags|
+------+------------------+--------------+-------------+------------+
| Mike| mike#example.com|+91-9999999999| Italy| big|
| Alex| alex#example.com|+91-9999999998| France| big|
| John| john#example.com| +1-1111111111|United States| big|
|Donald|donald#example.com| +1-2222222222|United States| big|
| Dan| dan#example.com|+91-9999444999| Poland| medium|
| Scott| scott#example.com|+91-9111999998| Spain|(big|medium)|
| Rob| rob#example.com|+91-9114444998| Italy| big|
+------+------------------+--------------+-------------+------------+
scala>
UPDATE2:
scala> Seq(("Mike","mike#example.com","+91-9999999999","Italy"),
| ("Alex","alex#example.com","+91-9999999998","France"),
| ("John","john#example.com","+1-1111111111","United States"),
| ("Donald","donald#example.com","+1-2222222222","United States"),
| ("Dan","dan#example.com","+91-9999444999","Poland"),
| ("Scott","scott#example.com","+91-9111999998","Spain"),
| ("Rob","rob#example.com","+91-9114444998","Italy")
| ).toDF("name","email","phone","country").createOrReplaceTempView("tagged_users")
scala> Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
scala> Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")
scala> val tags = Map(
| "big" -> "country IN (SELECT * FROM big_countries)",
| "medium" -> "country IN (SELECT * FROM medium_countries)",
| "sometag" -> "name = 'Donald' AND email = 'donald#example.com' AND phone = '+1-2222222222'")
tags: scala.collection.immutable.Map[String,String] = Map(big -> country IN (SELECT * FROM big_countries), medium -> country IN (SELECT * FROM medium_countries), sometag -> name = 'Donald' AND email = 'donald#example.com' AND phone = '+1-2222222222')
scala> val sql_tags = tags.map( x => { val p = x._2.trim.toUpperCase.split(" ");
| val qry = if(p.contains("IN") && p.contains("FROM"))
| s" case when array_contains((select collect_list("+p.head +") from " + p.last.replaceAll("[)]","")+ " ), " +p.head + " ) then '" + x._1 + " ' else '' end " + x._1 + " "
| else
| " case when " + x._2 + " then '" + x._1 + " ' else '' end " + x._1 + " ";
| qry } ).mkString(",")
sql_tags: String = " case when array_contains((select collect_list(COUNTRY) from BIG_COUNTRIES ), COUNTRY ) then 'big ' else '' end big , case when array_contains((select collect_list(COUNTRY) from MEDIUM_COUNTRIES ), COUNTRY ) then 'medium ' else '' end medium , case when name = 'Donald' AND email = 'donald#example.com' AND phone = '+1-2222222222' then 'sometag ' else '' end sometag "
scala> val outer_query = tags.map( x=> x._1).mkString(" regexp_replace(trim(concat(", ",", " )),' ','|') tags ")
outer_query: String = " regexp_replace(trim(concat(big,medium,sometag )),' ','|') tags "
scala> spark.sql(" select name,email, country, " + outer_query + " from ( select name,email, country ," + sql_tags + " from tagged_users ) " ).show
+------+------------------+-------------+-----------+
| name| email| country| tags|
+------+------------------+-------------+-----------+
| Mike| mike#example.com| Italy| big|
| Alex| alex#example.com| France| big|
| John| john#example.com|United States| big|
|Donald|donald#example.com|United States|big|sometag|
| Dan| dan#example.com| Poland| medium|
| Scott| scott#example.com| Spain| big|medium|
| Rob| rob#example.com| Italy| big|
+------+------------------+-------------+-----------+
scala>
If you need to accumulate the results rather than just execute each query, use map instead of foreach and then union the results:
import org.apache.spark.sql.functions.lit

// Run each tag query and label its matching rows with the tag name
val o = tags.map {
  case (tag, tagCondition) =>
    spark.sql(buildTagQuery(tag, tagCondition, "tagged_users"))
      .withColumn("tag", lit(tag))
}

// Union the per-tag results into a single DataFrame
val tagged = o.tail.foldLeft(o.head) {
  case (acc, df) => acc.union(df)
}
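The union produces one row per matching (user, tag) pair, so Donald and Scott each appear twice. A minimal sketch of collapsing those duplicates into the pipe-separated format from the question, assuming the tagged DataFrame from above:

import org.apache.spark.sql.functions.{collect_list, concat_ws}

// One row per user, tags joined as "big|sometag"
// (note: collect_list does not guarantee the order of the tags)
val collapsed = tagged
  .groupBy("name", "email", "phone", "country")
  .agg(concat_ws("|", collect_list("tag")).as("tag"))

collapsed.show(false)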
I would define multiple tag tables with columns value, tag.
Then your tags definition would be a collection, say Seq[(String, String)], where the first tuple element is the column on which the tag is calculated.
Lets say
Seq(
  "country" -> "bigCountries", // Columns [country, bigCountry]
  "country" -> "mediumCountries", // Columns [country, mediumCountry]
  "email" -> "hotmailLosers" // Columns [email, hotmailLoser]
)
Then iterate through this list, left joining each table on the relevant column.
After joining each table, set your tags column to the current value plus the joined column when it is not null.
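A minimal sketch of that idea, assuming df is the tagged_users DataFrame from the question and hypothetical dictionary tables bigCountries, mediumCountries and hotmailLosers that each expose their tag name in a tag column; as a small variation, it collects the per-table tag columns first and concatenates them once at the end, which sidesteps the null handling in a running update:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, concat_ws}

// (join column on tagged_users, dictionary with columns [<joinCol>, tag])
val tagTables: Seq[(String, DataFrame)] = Seq(
  "country" -> spark.table("bigCountries"),
  "country" -> spark.table("mediumCountries"),
  "email"   -> spark.table("hotmailLosers")
)

// Left join each dictionary, keeping its matched tag in its own column
val joined = tagTables.zipWithIndex.foldLeft(df) {
  case (acc, ((joinCol, table), i)) =>
    acc.join(table.withColumnRenamed("tag", s"tag_$i"), Seq(joinCol), "left")
}

// concat_ws skips NULLs, so only the tags that actually matched survive
val tagNames = tagTables.indices.map(i => s"tag_$i")
val tagged = joined
  .withColumn("tags", concat_ws("|", tagNames.map(col): _*))
  .drop(tagNames: _*)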