In the example below, the code applies the computation to the same set of original records every time. Instead, the code should feed each previously computed value into the next computation to produce the subsequent quantity.
package playground

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{KeyValueGroupedDataset, SparkSession}

object basic2 extends App {

  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  val spark = SparkSession
    .builder()
    .appName("Sample app")
    .master("local")
    .getOrCreate()

  import spark.implicits._

  final case class Owner(car: String, pcode: String, qtty: Double)
  final case class Invoice(car: String, pcode: String, qtty: Double)

  val data = Seq(
    Owner("A", "666", 80),
    Owner("B", "555", 20),
    Owner("A", "444", 50),
    Owner("A", "222", 20),
    Owner("C", "444", 20),
    Owner("C", "666", 80),
    Owner("C", "555", 120),
    Owner("A", "888", 100)
  )

  val fleet = Seq(Invoice("A", "666", 15), Invoice("A", "888", 12))

  val owners = spark.createDataset(data)
  val invoices = spark.createDataset(fleet)

  val gb: KeyValueGroupedDataset[Invoice, (Owner, Invoice)] = owners
    .joinWith(invoices, invoices("car") === owners("car"), "inner")
    .groupByKey(_._2)

  gb.flatMapGroups {
      case (fleet, group) =>
        val subOwner: Vector[Owner] = group.toVector.map(_._1)
        val calculatedRes = subOwner.filter(_.car == fleet.car)
        calculatedRes.map(c => c.copy(qtty = .3 * c.qtty + fleet.qtty))
    }
    .show()
}
/**
* +---+-----+----+
* |car|pcode|qtty|
* +---+-----+----+
* | A| 666|39.0|
* | A| 444|30.0|
* | A| 222|21.0|
* | A| 888|45.0|
* | A| 666|36.0|
* | A| 444|27.0|
* | A| 222|18.0|
* | A| 888|42.0|
* +---+-----+----+
*
* +---+-----+----+
* |car|pcode|qtty|
* +---+-----+----+
* | A| 666|0.3 * 39.0 + 12|
* | A| 444|0.3 * 30.0 + 12|
* | A| 222|0.3 * 21.0 + 12|
* | A| 888|0.3 * 45.0 + 12|
* +---+-----+----+
*/
The second table above shows the expected output; the first table is what the code in this question produces.
How can the expected output be produced in an iterative way?
Note that the order of computation doesn't matter: the results will differ depending on the order, but either is a valid answer.
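For concreteness, the intended recurrence can be sketched on plain Scala collections with the numbers from the tables above (owner A / pcode 666 starts at qtty 80, and car A has invoices 15 and 12); this is only an illustration of the chaining, not Spark code:
// each step feeds the previous result into the next: q' = 0.3 * q + invoiceQtty
val invoiceQttys = Seq(15.0, 12.0)
val chained = invoiceQttys.foldLeft(80.0)((q, inv) => 0.3 * q + inv)
// step 1: 0.3 * 80 + 15 = 39.0; step 2: 0.3 * 39.0 + 12 = 23.7
println(chained) // 23.7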
Check the code below.
import org.apache.spark.sql.functions._ // udf, collect_list, first

val getQtty = udf((invoicesQtty: Seq[Double], ownersQtty: Double) => {
  invoicesQtty.tail.foldLeft(0.3 * ownersQtty + invoicesQtty.head)(
    (totalIQ, nextInvoiceQtty) => 0.3 * totalIQ + nextInvoiceQtty
  )
})

val getQttyStr = udf((invoicesQtty: Seq[Double], ownersQtty: Double) => {
  val totalIQ = 0.3 * ownersQtty + invoicesQtty.head
  invoicesQtty.tail.foldLeft("")(
    (data, nextInvoiceQtty) => {
      s"0.3 * ${if (data.isEmpty) totalIQ else s"(${data})"} + ${nextInvoiceQtty}"
    }
  )
})

owners
  .join(invoices, invoices("car") === owners("car"), "inner")
  .orderBy(invoices("qtty").desc)
  .groupBy(owners("car"), owners("pcode"))
  .agg(
    collect_list(invoices("qtty")).as("invoices_qtty"),
    first(owners("qtty")).as("owners_qtty")
  )
  .withColumn("qtty", getQtty($"invoices_qtty", $"owners_qtty"))
  .withColumn("qtty_str", getQttyStr($"invoices_qtty", $"owners_qtty"))
  .show(false)
Result
+---+-----+-------------+-----------+----+-----------------+
|car|pcode|invoices_qtty|owners_qtty|qtty|qtty_str |
+---+-----+-------------+-----------+----+-----------------+
|A |666 |[15.0, 12.0] |80.0 |23.7|0.3 * 39.0 + 12.0|
|A |888 |[15.0, 12.0] |100.0 |25.5|0.3 * 45.0 + 12.0|
|A |444 |[15.0, 12.0] |50.0 |21.0|0.3 * 30.0 + 12.0|
|A |222 |[15.0, 12.0] |20.0 |18.3|0.3 * 21.0 + 12.0|
+---+-----+-------------+-----------+----+-----------------+
Related
I have the following dataframe:
val df1 = Seq(
  ("1_2_3", "5_10"),
  ("4_5_6", "15_20")
).toDF("c1", "c2")
+-----+-----+
| c1| c2|
+-----+-----+
|1_2_3| 5_10|
|4_5_6|15_20|
+-----+-----+
How can I get the sum in a separate column based on these conditions:
- Omit the third value after the delimiter '_' in the first column, i.e. drop '_3' and '_6' from 1_2_3 and 4_5_6.
- Add the remaining values position-wise across the two columns: 1+5 and 2+10 for the first row, and 4+15 and 5+20 for the second.
Expected output -
+-----+-----+-----+
| c1| c2| res|
+-----+-----+-----+
|1_2_3| 5_10| 6_12|
|4_5_6|15_20|19_25|
+-----+-----+-----+
Try this: zip_with + split.
val df1 = Seq(
  ("1_2_3", "5_10"),
  ("4_5_6", "15_20")
).toDF("c1", "c2")
df1.show(false)

df1.withColumn("res",
    expr("concat_ws('_', zip_with(split(c1, '_'), split(c2, '_'), (x, y) -> cast(x+y as int)))"))
  .show(false)
/**
* +-----+-----+-----+
* |c1 |c2 |res |
* +-----+-----+-----+
* |1_2_3|5_10 |6_12 |
* |4_5_6|15_20|19_25|
* +-----+-----+-----+
*/
Update: doing this dynamically for 50 columns -
val end = 51 // 50 cols
val df = spark.sql("select '1_2_3' as c1")
val new_df = Range(2, end).foldLeft(df){(df, i) => df.withColumn(s"c$i", $"c1")}
new_df.show(false)
/**
* +-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
* |c1 |c2 |c3 |c4 |c5 |c6 |c7 |c8 |c9 |c10 |c11 |c12 |c13 |c14 |c15 |c16 |c17 |c18 |c19 |c20 |c21 |c22 |c23 |c24 |c25 |c26 |c27 |c28 |c29 |c30 |c31 |c32 |c33 |c34 |c35 |c36 |c37 |c38 |c39 |c40 |c41 |c42 |c43 |c44 |c45 |c46 |c47 |c48 |c49 |c50 |
* +-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
* |1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|
* +-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
*/
val res = new_df.withColumn("res", $"c1")

Range(2, end).foldLeft(res) { (df4, i) =>
  df4.withColumn("res",
    expr(s"concat_ws('_', zip_with(split(res, '_'), split(c$i, '_'), (x, y) -> cast(x+y as int)))"))
}.show(false)
/**
* +-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+----------+
* |c1 |c2 |c3 |c4 |c5 |c6 |c7 |c8 |c9 |c10 |c11 |c12 |c13 |c14 |c15 |c16 |c17 |c18 |c19 |c20 |c21 |c22 |c23 |c24 |c25 |c26 |c27 |c28 |c29 |c30 |c31 |c32 |c33 |c34 |c35 |c36 |c37 |c38 |c39 |c40 |c41 |c42 |c43 |c44 |c45 |c46 |c47 |c48 |c49 |c50 |res |
* +-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+----------+
* |1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|1_2_3|50_100_150|
* +-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+----------+
*/
I have been told that EXCEPT is a very costly operation and one should always try to avoid using EXCEPT.
My Use Case -
val myFilter = "rollNo='11' AND class='10'"
val rawDataDf = spark.table(<table_name>)
val myFilteredDataframe = rawDataDf.where(myFilter)
val allOthersDataframe = rawDataDf.except(myFilteredDataframe)
But I am confused: in such a use case, what are my alternatives?
Use a left anti join, as below:
import org.apache.spark.sql.functions.lit

val df = spark.range(2).withColumn("name", lit("foo"))
df.show(false)
df.printSchema()
/**
* +---+----+
* |id |name|
* +---+----+
* |0 |foo |
* |1 |foo |
* +---+----+
*
* root
* |-- id: long (nullable = false)
* |-- name: string (nullable = false)
*/
val df2 = df.filter("id=0")
df.join(df2, df.columns.toSeq, "leftanti")
.show(false)
/**
* +---+----+
* |id |name|
* +---+----+
* |1 |foo |
* +---+----+
*/
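Applied to the use case in the question, a minimal sketch along the same lines (rawDataDf and myFilter are the names from the question; joining on all columns mimics the whole-row comparison that except performs):
// rawDataDf and myFilter as defined in the question above
val myFilteredDataframe = rawDataDf.where(myFilter)
val allOthersDataframe = rawDataDf.join(myFilteredDataframe, rawDataDf.columns.toSeq, "leftanti")
Note that except is equivalent to EXCEPT DISTINCT (it deduplicates the result), whereas a left anti join keeps duplicate rows, and the two can also differ when the join columns contain nulls.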
I am very new to Spark. I have to perform string manipulation operations and create a new column in a Spark dataframe. I have created UDFs for the string manipulation, but for performance reasons I want to do this without UDFs. Following is my code and output. Could you please help me do this in a better way?
object Demo2 extends Context {
  import org.apache.spark.sql.functions.udf

  def main(args: Array[String]): Unit = {
    import sparkSession.sqlContext.implicits._

    val data = Seq(
      ("bankInfo.SBI.C_1.Kothrud.Pune.displayInfo"),
      ("bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo"),
      ("bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo"),
      ("bankInfo.HDFC.C_4.Deccan.Pune.displayInfo")
    )
    val df = data.toDF("Key")
    println("Input Dataframe")
    df.show(false)

    // get local_address
    val get_local_address = udf((key: String) => {
      val first_index = key.indexOf(".")
      val tmp_key = key.substring(first_index + 1)
      val last_index = tmp_key.lastIndexOf(".")
      val local_address = tmp_key.substring(0, last_index)
      local_address
    })

    // get address
    val get_address = udf((key: String) => {
      val first_index = key.indexOf(".")
      val tmp_key = key.substring(first_index + 1)
      val last_index1 = tmp_key.lastIndexOf(".")
      val tmp_key1 = tmp_key.substring(0, last_index1)
      val last_index2 = tmp_key1.lastIndexOf(".")
      val first_index1 = tmp_key1.lastIndexOf(".", last_index2 - 1)
      val address = tmp_key1.substring(0, first_index1) + tmp_key1.substring(last_index2)
      address
    })

    val df2 = df
      .withColumn("Local Address", get_local_address(df("Key")))
      .withColumn("Address", get_address(df("Key")))

    println("Output Dataframe")
    df2.show(false)
  }
}
Input Dataframe
+----------------------------------------------+
|Key |
+----------------------------------------------+
|bankInfo.SBI.C_1.Kothrud.Pune.displayInfo |
|bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |
|bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|
|bankInfo.HDFC.C_4.Deccan.Pune.displayInfo |
+----------------------------------------------+
Output Dataframe
+----------------------------------------------+-------------------------+---------------+
|Key |Local Address |Address |
+----------------------------------------------+-------------------------+---------------+
|bankInfo.SBI.C_1.Kothrud.Pune.displayInfo |SBI.C_1.Kothrud.Pune |SBI.C_1.Pune |
|bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |ICICI.C_2.TilakRoad.Pune |ICICI.C_2.Pune |
|bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|Axis.C_3.Santacruz.Mumbai|Axis.C_3.Mumbai|
|bankInfo.HDFC.C_4.Deccan.Pune.displayInfo |HDFC.C_4.Deccan.Pune |HDFC.C_4.Pune |
+----------------------------------------------+-------------------------+---------------+
Since the split key yields a fixed-size array, you can structure its fields and then concatenate them as required.
Load the test data provided
val data =
"""
|Key
|bankInfo.SBI.C_1.Kothrud.Pune.displayInfo
|bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo
|bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo
|bankInfo.HDFC.C_4.Deccan.Pune.displayInfo
""".stripMargin
val stringDS1 = data.split(System.lineSeparator())
.map(_.split("\\|").map(_.replaceAll("""^[ \t]+|[ \t]+$""", "")).mkString(","))
.toSeq.toDS()
val df1 = spark.read
.option("sep", ",")
.option("inferSchema", "true")
.option("header", "true")
.option("nullValue", "null")
.csv(stringDS1)
df1.show(false)
df1.printSchema()
/**
* +----------------------------------------------+
* |Key |
* +----------------------------------------------+
* |bankInfo.SBI.C_1.Kothrud.Pune.displayInfo |
* |bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |
* |bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|
* |bankInfo.HDFC.C_4.Deccan.Pune.displayInfo |
* +----------------------------------------------+
*
* root
* |-- Key: string (nullable = true)
*/
Derive the columns from the fixed-format string column
df1.select($"key", split($"key", "\\.").as("x"))
.withColumn("bankInfo",
expr(
"""
|named_struct('name', element_at(x, 2), 'cust_id', element_at(x, 3),
| 'branch', element_at(x, 4), 'dist', element_at(x, 5))
""".stripMargin))
.select($"key",
concat_ws(".", $"bankInfo.name", $"bankInfo.cust_id", $"bankInfo.branch", $"bankInfo.dist")
.as("Local_Address"),
concat_ws(".", $"bankInfo.name", $"bankInfo.cust_id", $"bankInfo.dist")
.as("Address"))
.show(false)
/**
* +----------------------------------------------+-------------------------+---------------+
* |key |Local_Address |Address |
* +----------------------------------------------+-------------------------+---------------+
* |bankInfo.SBI.C_1.Kothrud.Pune.displayInfo |SBI.C_1.Kothrud.Pune |SBI.C_1.Pune |
* |bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |ICICI.C_2.TilakRoad.Pune |ICICI.C_2.Pune |
* |bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|Axis.C_3.Santacruz.Mumbai|Axis.C_3.Mumbai|
* |bankInfo.HDFC.C_4.Deccan.Pune.displayInfo |HDFC.C_4.Deccan.Pune |HDFC.C_4.Pune |
* +----------------------------------------------+-------------------------+---------------+
*/
df1.select($"key", split($"key", "\\.").as("x"))
.withColumn("bankInfo",
expr("named_struct('name', x[1], 'cust_id', x[2], 'branch', x[3], 'dist', x[4])"))
.select($"key",
concat_ws(".", $"bankInfo.name", $"bankInfo.cust_id", $"bankInfo.branch", $"bankInfo.dist")
.as("Local_Address"),
concat_ws(".", $"bankInfo.name", $"bankInfo.cust_id", $"bankInfo.dist")
.as("Address"))
.show(false)
/**
* +----------------------------------------------+-------------------------+---------------+
* |key |Local_Address |Address |
* +----------------------------------------------+-------------------------+---------------+
* |bankInfo.SBI.C_1.Kothrud.Pune.displayInfo |SBI.C_1.Kothrud.Pune |SBI.C_1.Pune |
* |bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |ICICI.C_2.TilakRoad.Pune |ICICI.C_2.Pune |
* |bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|Axis.C_3.Santacruz.Mumbai|Axis.C_3.Mumbai|
* |bankInfo.HDFC.C_4.Deccan.Pune.displayInfo |HDFC.C_4.Deccan.Pune |HDFC.C_4.Pune |
* +----------------------------------------------+-------------------------+---------------+
*/
Check the code below.
scala> df.show(false)
+----------------------------------------------+
|Key |
+----------------------------------------------+
|bankInfo.SBI.C_1.Kothrud.Pune.displayInfo |
|bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |
|bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|
|bankInfo.HDFC.C_4.Deccan.Pune.displayInfo |
+----------------------------------------------+
scala> val maxLength = df.select(split($"key","\\.").as("keys")).withColumn("length",size($"keys")).select(max($"length").as("length")).map(_.getAs[Int](0)).collect.head
maxLength: Int = 6
scala> val address_except = Seq(0,3,maxLength-1)
address_except: Seq[Int] = List(0, 3, 5)
scala> val local_address_except = Seq(0,maxLength-1)
local_address_except: Seq[Int] = List(0, 5)
scala> def parse(column: Column,indexes:Seq[Int]) = (0 to maxLength).filter(i => !indexes.contains(i)).map(i => column(i)).reduce(concat_ws(".",_,_))
parse: (column: org.apache.spark.sql.Column, indexes: Seq[Int])org.apache.spark.sql.Column
scala> df.select(split($"key","\\.").as("keys")).withColumn("local_address",parse($"keys",local_address_except)).withColumn("address",parse($"keys",address_except)).show(false)
+-----------------------------------------------------+-------------------------+---------------+
|keys |local_address |address |
+-----------------------------------------------------+-------------------------+---------------+
|[bankInfo, SBI, C_1, Kothrud, Pune, displayInfo] |SBI.C_1.Kothrud.Pune |SBI.C_1.Pune |
|[bankInfo, ICICI, C_2, TilakRoad, Pune, displayInfo] |ICICI.C_2.TilakRoad.Pune |ICICI.C_2.Pune |
|[bankInfo, Axis, C_3, Santacruz, Mumbai, displayInfo]|Axis.C_3.Santacruz.Mumbai|Axis.C_3.Mumbai|
|[bankInfo, HDFC, C_4, Deccan, Pune, displayInfo] |HDFC.C_4.Deccan.Pune |HDFC.C_4.Pune |
+-----------------------------------------------------+-------------------------+---------------+
I have two DataFrames:
scala> df1.show()
+----+----+----+---+----+
|col1|col2|col3| |colN|
+----+----+----+ +----+
| 2|null| 3|...| 4|
| 4| 3| 3| | 1|
| 5| 2| 8| | 1|
+----+----+----+---+----+
scala> df2.show() // has one row only (avg())
+----+----+----+---+----+
|col1|col2|col3| |colN|
+----+----+----+ +----+
| 3.6|null| 4.6|...| 2|
+----+----+----+---+----+
and a constant val c : Double = 0.1.
Desired output is a DataFrame df3 given element-wise by

df3(i)(j) = df1(i)(j) * (1 - c) + df2(j) * c,   for i = 1..n and j = 1..m,

where df2(j) is the j-th entry of df2's single row, n = numberOfRows and m = numberOfColumns.
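As a quick check with the values shown above, the first row and first column would become 2 * (1 - 0.1) + 3.6 * 0.1 = 1.8 + 0.36 = 2.16.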
I already looked through the list of sql.functions and failed to implement it myself using nested map operations (I also fear performance issues). One idea I had was:
val cBc = spark.sparkContext.broadcast(c)
val df2Bc = spark.sparkContext.broadcast(averageObservation)
df1.rdd.map(row => {
  for (colIdx <- 0 until row.length) {
    val correspondingDf2value = df2Bc.value.head().getDouble(colIdx)
    row.getDouble(colIdx) * (1 - cBc.value) + correspondingDf2value * cBc.value
  }
})
Thank you in advance!
(cross)join combined with select is more than enough and will be much more efficient than mapping. Required imports:
import org.apache.spark.sql.functions.{broadcast, col, lit}
and expression:
val exprs = df1.columns.map { x => (df1(x) * (1 - c) + df2(x) * c).alias(x) }
join and select:
df1.crossJoin(broadcast(df2)).select(exprs: _*)
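For reference, a self-contained sketch of the same approach (the toy data, column names, and average values below are made up for illustration; the expression and crossJoin mirror the answer above):
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

val spark = SparkSession.builder().appName("weighted-combine").master("local").getOrCreate()
import spark.implicits._

val c: Double = 0.1
val df1 = Seq((2.0, 3.0), (4.0, 3.0), (5.0, 8.0)).toDF("col1", "col2")
val df2 = Seq((3.6667, 4.6667)).toDF("col1", "col2") // one row holding the column averages

// for every column x: (1 - c) * df1(x) + c * df2(x), keeping the original column name
val exprs = df1.columns.map { x => (df1(x) * (1 - c) + df2(x) * c).alias(x) }

df1.crossJoin(broadcast(df2)).select(exprs: _*).show()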
I have two large dataframes: [a] one which has all events identified by an id, and [b] a list of ids. I want to filter [a] based on the ids in [b] using the stat.bloomFilter implementation in Spark 2.0.0.
However, I don't see any operations in the Dataset API to join the Bloom filter to dataframe [a].
val in1 = spark.sparkContext.parallelize(List(0, 1, 2, 3, 4, 5))
val df1 = in1.map(x => (x, x+1, x+2)).toDF("c1", "c2", "c3")
val in2 = spark.sparkContext.parallelize(List(0, 1, 2))
val df2 = in2.map(x => (x)).toDF("c1")
val expectedNumItems: Long = 1000
val fpp: Double = 0.005
val sbf = df1.stat.bloomFilter($"c1", expectedNumItems, fpp)
val sbf2 = df2.stat.bloomFilter($"c1", expectedNumItems, fpp)
What is the best way to filter 'df1' based on values in df2?
Thanks!
You can use a UDF:
def might_contain(f: org.apache.spark.util.sketch.BloomFilter) = udf((x: Int) =>
if(x != null) f.mightContain(x) else false)
df1.where(might_contain(sbf2)($"C1"))
I think I found the correct way to do this, but would still like pointers to see if there are better ways to manage this.
Here's my solution -
val in1 = spark.sparkContext.parallelize(List(0, 1, 2, 3, 4, 5))
val d1 = in1.map(x => (x, x+1, x+2)).toDF("c1", "c2", "c3")
val in2 = spark.sparkContext.parallelize(List(0, 1, 2))
val d2 = in2.map(x => (x)).toDF("c1")
val s2 = d2.stat.bloomFilter($"c1", expectedNumItems, fpp)
val a = spark.sparkContext.broadcast(s2)
val x = d1.rdd.filter(x => a.value.mightContain(x(0)))
case class newType(c1: Int, c2: Int, c3: Int) extends Serializable
val xDF = x.map(y => newType(y(0).toString.toInt, y(1).toString.toInt, y(2).toString.toInt)).toDF()
scala> d1.show(10)
+---+---+---+
| c1| c2| c3|
+---+---+---+
| 0| 1| 2|
| 1| 2| 3|
| 2| 3| 4|
| 3| 4| 5|
| 4| 5| 6|
| 5| 6| 7|
+---+---+---+
scala> d2.show(10)
+---+
| c1|
+---+
| 0|
| 1|
| 2|
+---+
scala> xDF.show(10)
+---+---+---+
| c1| c2| c3|
+---+---+---+
| 0| 1| 2|
| 1| 2| 3|
| 2| 3| 4|
+---+---+---+
I built an implicit class that wraps https://stackoverflow.com/a/41989703/6723616
Comments welcome!
/**
 * Copyright 2017 Yahoo, Inc.
 * Zlib license: https://www.zlib.net/zlib_license.html
 */
package me.klotz.spark.utils

import org.apache.spark.sql.functions._
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.util.sketch.BloomFilter
import org.apache.spark.SparkContext

object BloomFilterEnhancedJoin {

  // not parameterized for field type; assumes string

  /**
   * Like bigDF.join(smallDF, field), but accelerated with a Bloom filter.
   * You pass in a size estimate of the bigDF, and a ratio of acceptable false positives out of the expected result set size.
   * ratio=1 is a good start; that will result in about 50% false positives in the big-small join, so the filter accepts
   * about as many as it passes, rather than rejecting almost all. Pass in a size estimate of the big dataframe
   * to avoid enumerating it. The small DataFrame gets enumerated anyway.
   *
   * Example use:
   * <code>
   * import me.klotz.spark.utils.BloomFilterEnhancedJoin._
   * val (dups_joined, bloomFilterBroadcast) = df_big.joinBloom(1024L*1024L*1024L, dups, 10.0, "id")
   * dups_joined.write.format("orc").save("dups")
   * bloomFilterBroadcast.unpersist
   * </code>
   */
  implicit class BloomFilterEnhancedJoiner(bigdf: Dataset[Row]) {
    /**
     * You should call bloomFilterBroadcast.unpersist afterwards.
     */
    def joinBloom(bigDFCountEstimate: Long, smallDF: Dataset[Row], ratio: Double, field: String) = {
      val sc = smallDF.sparkSession.sparkContext
      val smallDFCount = smallDF.count
      val fpr = smallDFCount.toDouble / bigDFCountEstimate.toDouble / ratio
      println(s"fpr=${fpr} = smallDFCount=${smallDFCount} / bigDFCountEstimate=${bigDFCountEstimate} / ratio=${ratio}")

      val bloomFilterBroadcast = sc.broadcast(smallDF.stat.bloomFilter(field, smallDFCount, fpr))
      val mightContain = udf((x: String) => if (x != null) bloomFilterBroadcast.value.mightContainString(x) else false)

      (bigdf.filter(mightContain(col(field))).join(smallDF, field), bloomFilterBroadcast)
    }
  }
}
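As a rough worked example of the false-positive-rate heuristic described in the scaladoc above (all counts here are illustrative assumptions, not values from the original post):
val bigDFCountEstimate = 1000000000L  // assume ~1e9 rows in the big DataFrame
val smallDFCount       = 1000000L     // assume ~1e6 rows in the small DataFrame
val ratio              = 10.0
val fpr = smallDFCount.toDouble / bigDFCountEstimate.toDouble / ratio  // = 1.0e-4
// rows wrongly passed by the filter ~ bigDFCountEstimate * fpr = 1e5; if the true result
// set is roughly the size of the small DataFrame (~1e6 rows), that is about one false
// positive per ten true matches, and the exact join afterwards removes them anyway.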