As far as I understand, this is the most efficient way to calculate an average in Spark: Spark : Average of values instead of sum in reduceByKey using Scala.
My question is: if I use the high-level Dataset API with a groupBy followed by avg() from Spark's functions, will I get the same RDD operations under the hood? Can I trust Catalyst, or should I use the low-level RDD API? In other words, will writing low-level code yield better results than using a Dataset?
Example code:
employees
.groupBy($"employee")
.agg(
avg($"salary").as("avg_salary")
)
Versus:
employees
.mapValues(employee => (employee.salary, 1)) // map entry with a count of 1
.reduceByKey {
case ((sumL, countL), (sumR, countR)) =>
(sumL + sumR, countL + countR)
}
.mapValues {
case (sum , count) => sum / count
}
I don't see it as a black-and-white question. In general, if you have an RDD, especially a PairRDD, and need the result as an RDD, it makes sense to settle on reduceByKey. On the other hand, given a DataFrame, I would recommend going with groupBy/agg(avg).
A couple of things to consider:
Built-in optimization
While reduceByKey is relatively efficient compared to functions like groupByKey, it does introduce a stage boundary, since the operation requires repartitioning the data by key. Depending on the RDD's partitions, the number of tasks in the derived stage may end up too small to take advantage of the available CPU cores, potentially resulting in a performance bottleneck. Such a performance issue could be addressed, for instance, by manually setting numPartitions in reduceByKey:
def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)]
But the point is that, to fully optimize RDD operations, one might need to put in some manual tuning effort. In contrast, most DataFrame operations are automatically optimized by the built-in Catalyst query optimizer.
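To illustrate why numPartitions bounds the parallelism of the reduce stage, here is a plain Scala sketch (no Spark) of hash partitioning by key, using a non-negative modulo of the key's hashCode. The names HashPartitionSketch, partitionOf and busyPartitions are made up for this illustration and are not Spark APIs.

```scala
// Plain Scala sketch (no Spark): assigning keys to partitions by hash.
object HashPartitionSketch {
  // Non-negative modulo of the key's hashCode, the idea behind hash partitioning.
  def partitionOf(key: Any, numPartitions: Int): Int = {
    val raw = key.hashCode % numPartitions
    if (raw < 0) raw + numPartitions else raw
  }

  // Number of partitions that receive at least one key,
  // i.e. how many reduce-side tasks would actually have work to do.
  def busyPartitions(keys: Seq[Any], numPartitions: Int): Int =
    keys.map(partitionOf(_, numPartitions)).distinct.size
}
```

However many distinct keys there are, at most numPartitions tasks can work on them in parallel, which is why a value that is too small can leave cores idle.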
Memory usage efficiency
Perhaps the other, more significant factor is memory usage for large datasets. When an RDD needs to be distributed across nodes or written to disk, Spark will serialize every row of data as an object, and those objects are subject to costly garbage collection overhead. On the other hand, with knowledge of a DataFrame's schema, Spark doesn't need to serialize the data into objects. The Tungsten execution engine can leverage off-heap memory to store data in binary format for transformations, resulting in more efficient use of memory.
In conclusion, while there may be more knobs for tweaking using low-level code, that does not necessarily result in more performant code due to inadequate optimization, additional cost for serialization, etc.
We can draw the same conclusion from the plans that Spark generates.
This is the plan for the DataFrame syntax:
val employees = spark.createDataFrame(Seq(("E1",100.0), ("E2",200.0),("E3",300.0))).toDF("employee","salary")
employees
.groupBy($"employee")
.agg(
avg($"salary").as("avg_salary")
).explain(true)
Plan -
== Parsed Logical Plan ==
'Aggregate ['employee], [unresolvedalias('employee, None), avg('salary) AS avg_salary#11]
+- Project [_1#0 AS employee#4, _2#1 AS salary#5]
   +- LocalRelation [_1#0, _2#1]
== Analyzed Logical Plan ==
employee: string, avg_salary: double
Aggregate [employee#4], [employee#4, avg(salary#5) AS avg_salary#11]
+- Project [_1#0 AS employee#4, _2#1 AS salary#5]
   +- LocalRelation [_1#0, _2#1]
== Optimized Logical Plan ==
Aggregate [employee#4], [employee#4, avg(salary#5) AS avg_salary#11]
+- LocalRelation [employee#4, salary#5]
== Physical Plan ==
*(2) HashAggregate(keys=[employee#4], functions=[avg(salary#5)], output=[employee#4, avg_salary#11])
+- Exchange hashpartitioning(employee#4, 10)
   +- *(1) HashAggregate(keys=[employee#4], functions=[partial_avg(salary#5)], output=[employee#4, sum#17, count#18L])
      +- LocalTableScan [employee#4, salary#5]
As the plan shows, a first HashAggregate computes a partial average (a sum and a count per key) within each partition, then Exchange hashpartitioning shuffles those partial results by key, and a second HashAggregate produces the final average. In other words, Catalyst optimized the DataFrame operation as if we had written it with reduceByKey, so we needn't take on the burden of writing low-level code.
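The two HashAggregate steps can be mimicked on plain Scala collections. The sketch below is not Spark code: the nested Seq standing in for partitions and the names TwoPhaseAvg, partialAgg, mergeAgg and average are assumptions made for illustration. It shows why carrying (sum, count) pairs, as partial_avg does in the plan, lets each partition pre-aggregate before the exchange.

```scala
// Plain Scala sketch (no Spark): each inner Seq stands in for one partition.
object TwoPhaseAvg {
  type Partial = Map[String, (Double, Long)] // key -> (sum, count)

  // Phase 1, per partition: comparable to HashAggregate with partial_avg.
  def partialAgg(partition: Seq[(String, Double)]): Partial =
    partition.foldLeft(Map.empty[String, (Double, Long)]) {
      case (acc, (key, salary)) =>
        val (sum, count) = acc.getOrElse(key, (0.0, 0L))
        acc.updated(key, (sum + salary, count + 1L))
    }

  // After the exchange: merge partial (sum, count) pairs that share a key.
  def mergeAgg(left: Partial, right: Partial): Partial =
    right.foldLeft(left) {
      case (acc, (key, (sum, count))) =>
        val (s, c) = acc.getOrElse(key, (0.0, 0L))
        acc.updated(key, (s + sum, c + count))
    }

  // Phase 2: the final division, comparable to HashAggregate with avg.
  def average(partitions: Seq[Seq[(String, Double)]]): Map[String, Double] =
    partitions
      .map(partialAgg)
      .foldLeft(Map.empty[String, (Double, Long)])(mergeAgg)
      .map { case (key, (sum, count)) => key -> (sum / count) }
}
```

For example, TwoPhaseAvg.average(Seq(Seq("E1" -> 100.0, "E2" -> 200.0), Seq("E1" -> 300.0))) gives an average of 200.0 for both E1 and E2, even though the E1 rows live in different partitions.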
Here is what the RDD code and its plan look like.
employees
.map(employee => (employee.getAs[String]("employee"), (employee.getAs[Double]("salary"), 1))) // key by employee; map each entry with a count of 1
.rdd.reduceByKey {
case ((sumL, countL), (sumR, countR)) =>
(sumL + sumR, countL + countR)
}
.mapValues {
case (sum , count) => sum / count
}.toDF().explain(true)
Plan -
== Parsed Logical Plan ==
SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#30, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2 AS _2#31]
+- ExternalRDD [obj#29]
== Analyzed Logical Plan ==
_1: string, _2: double
SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#30, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2 AS _2#31]
+- ExternalRDD [obj#29]
== Optimized Logical Plan ==
SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#30, assertnotnull(input[0, scala.Tuple2, true])._2 AS _2#31]
+- ExternalRDD [obj#29]
== Physical Plan ==
*(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#30, assertnotnull(input[0, scala.Tuple2, true])._2 AS _2#31]
+- Scan[obj#29]
Catalyst sees the RDD only as an opaque ExternalRDD scan, so it cannot optimize the aggregation itself. The plan also involves serializing data from objects, which means extra memory pressure.
Conclusion
I would use the DataFrame syntax for its simplicity and possibly better performance.
Set a breakpoint at println("done") and, while the program is paused, go to http://localhost:4040/stages/. You will see the result there.
import org.apache.spark.sql.SparkSession

val spark = SparkSession
.builder()
.master("local[*]")
.appName("example")
.getOrCreate()
val employees = spark.createDataFrame(Seq(("employee1",1000),("employee2",2000),("employee3",1500))).toDF("employee","salary")
import spark.implicits._
import org.apache.spark.sql.functions._
// Spark functions
employees
.groupBy("employee")
.agg(
avg($"salary").as("avg_salary")
).show()
// your low-level code
println("done")
I'm using com.datastax.spark:spark-cassandra-connector_2.11:2.4.0 when running Zeppelin notebooks, and I don't understand the difference between two operations in Spark. One operation takes a lot of time to compute; the second one executes immediately. Could someone explain the difference between these two operations:
import com.datastax.spark.connector._
import org.apache.spark.sql.cassandra._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import spark.implicits._
case class SomeClass(val someField:String)
val timelineItems = spark.read.format("org.apache.spark.sql.cassandra").options(scala.collection.immutable.Map("spark.cassandra.connection.host" -> "127.0.0.1", "table" -> "timeline_items", "keyspace" -> "timeline" )).load()
//some simplified code:
val timelineRow = timelineItems
.map(x => {SomeClass("test")})
.filter(x => x != null)
.toDF()
.limit(4)
//first operation (takes a lot of time. It seems spark iterates through all items in Cassandra and doesn't use laziness with limit 4)
println(timelineRow.count()) //return: 4
//second operation (executes immediately); 300 - just random number which doesn't affect the result
println(timelineRow.take(300).length) //return: 4
What you see is the difference between the implementations of Limit (a transformation-like operation) and CollectLimit (an action-like operation). However, the difference in timings is highly misleading, and not something you can expect in the general case.
First, let's create an MCVE:
spark.conf.set("spark.sql.files.maxPartitionBytes", 500)
val ds = spark.read
.text("README.md")
.as[String]
.map{ x => {
Thread.sleep(1000)
x
}}
val dsLimit4 = ds.limit(4)
Make sure we start with a clean slate:
spark.sparkContext.statusTracker.getJobIdsForGroup(null).isEmpty
Boolean = true
Invoke count:
dsLimit4.count()
and take a look at the execution plan (from Spark UI):
== Parsed Logical Plan ==
Aggregate [count(1) AS count#12L]
+- GlobalLimit 4
+- LocalLimit 4
+- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
+- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
+- DeserializeToObject cast(value#0 as string).toString, obj#5: java.lang.String
+- Relation[value#0] text
== Analyzed Logical Plan ==
count: bigint
Aggregate [count(1) AS count#12L]
+- GlobalLimit 4
+- LocalLimit 4
+- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
+- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
+- DeserializeToObject cast(value#0 as string).toString, obj#5: java.lang.String
+- Relation[value#0] text
== Optimized Logical Plan ==
Aggregate [count(1) AS count#12L]
+- GlobalLimit 4
+- LocalLimit 4
+- Project
+- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
+- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
+- DeserializeToObject value#0.toString, obj#5: java.lang.String
+- Relation[value#0] text
== Physical Plan ==
*(2) HashAggregate(keys=[], functions=[count(1)], output=[count#12L])
+- *(2) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#15L])
+- *(2) GlobalLimit 4
+- Exchange SinglePartition
+- *(1) LocalLimit 4
+- *(1) Project
+- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
+- *(1) MapElements <function1>, obj#6: java.lang.String
+- *(1) DeserializeToObject value#0.toString, obj#5: java.lang.String
+- *(1) FileScan text [value#0] Batched: false, Format: Text, Location: InMemoryFileIndex[file:/path/to/README.md], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<value:string>
The core component is
+- *(2) GlobalLimit 4
   +- Exchange SinglePartition
      +- *(1) LocalLimit 4
which indicates that we can expect a wide operation with multiple stages. We can see a single job
spark.sparkContext.statusTracker.getJobIdsForGroup(null)
Array[Int] = Array(0)
with two stages
spark.sparkContext.statusTracker.getJobInfo(0).get.stageIds
Array[Int] = Array(0, 1)
with eight
spark.sparkContext.statusTracker.getStageInfo(0).get.numTasks
Int = 8
and one
spark.sparkContext.statusTracker.getStageInfo(1).get.numTasks
Int = 1
task respectively.
Now let's compare it to
dsLimit4.take(300).size
which generates the following plan:
== Parsed Logical Plan ==
GlobalLimit 300
+- LocalLimit 300
+- GlobalLimit 4
+- LocalLimit 4
+- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
+- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
+- DeserializeToObject cast(value#0 as string).toString, obj#5: java.lang.String
+- Relation[value#0] text
== Analyzed Logical Plan ==
value: string
GlobalLimit 300
+- LocalLimit 300
+- GlobalLimit 4
+- LocalLimit 4
+- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
+- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
+- DeserializeToObject cast(value#0 as string).toString, obj#5: java.lang.String
+- Relation[value#0] text
== Optimized Logical Plan ==
GlobalLimit 4
+- LocalLimit 4
+- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
+- MapElements <function1>, class java.lang.String, [StructField(value,StringType,true)], obj#6: java.lang.String
+- DeserializeToObject value#0.toString, obj#5: java.lang.String
+- Relation[value#0] text
== Physical Plan ==
CollectLimit 4
+- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#7]
+- *(1) MapElements <function1>, obj#6: java.lang.String
+- *(1) DeserializeToObject value#0.toString, obj#5: java.lang.String
+- *(1) FileScan text [value#0] Batched: false, Format: Text, Location: InMemoryFileIndex[file:/path/to/README.md], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<value:string>
While both global and local limits still occur, there is no exchange in the middle. Therefore we can expect a single-stage operation. Please note that the planner narrowed the limit down to the more restrictive value.
As expected we see a single new job:
spark.sparkContext.statusTracker.getJobIdsForGroup(null)
Array[Int] = Array(1, 0)
which generated only one stage:
spark.sparkContext.statusTracker.getJobInfo(1).get.stageIds
Array[Int] = Array(2)
with only one task
spark.sparkContext.statusTracker.getStageInfo(2).get.numTasks
Int = 1
What does it mean for us?
In the count case, Spark used a wide transformation: it applies LocalLimit on each partition and shuffles the partial results to perform GlobalLimit.
In the take case, Spark used a narrow transformation and evaluated LocalLimit only on the first partition.
Obviously the latter approach won't work when the number of values in the first partition is lower than the requested limit.
val dsLimit105 = ds.limit(105) // There are 105 lines
In that case count will use exactly the same logic as before (I encourage you to confirm that empirically), but take will follow a rather different path. So far we have triggered only two jobs:
spark.sparkContext.statusTracker.getJobIdsForGroup(null)
Array[Int] = Array(1, 0)
Now if we execute
dsLimit105.take(300).size
you'll see that it required 3 more jobs:
spark.sparkContext.statusTracker.getJobIdsForGroup(null)
Array[Int] = Array(4, 3, 2, 1, 0)
So what's going on here? As noted before, evaluating a single partition is not enough to satisfy the limit in the general case. In such a case Spark iteratively evaluates LocalLimit on partitions until GlobalLimit is satisfied, increasing the number of partitions taken in each iteration.
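The iterative strategy can be sketched on plain Scala collections (no Spark). This is a simplified model, not Spark's implementation: the name IterativeTake and the growthFactor parameter are assumptions made for illustration, and Spark's real heuristic for choosing how many partitions to try next is more involved.

```scala
// Plain Scala sketch (no Spark) of an iterative take over partitions.
object IterativeTake {
  // Returns the collected rows and the number of rounds ("jobs") needed.
  // growthFactor is a hypothetical knob controlling how fast the scan widens.
  def take[A](partitions: Seq[Seq[A]], n: Int, growthFactor: Int = 4): (Seq[A], Int) = {
    var collected = Vector.empty[A]
    var scanned = 0 // partitions evaluated so far
    var rounds = 0
    while (collected.size < n && scanned < partitions.size) {
      // The first round looks at one partition, later rounds at progressively more.
      val toTry = if (scanned == 0) 1 else scanned * growthFactor
      val batch = partitions.slice(scanned, scanned + toTry)
      val remaining = n - collected.size
      // Per-partition limit first, then the global cap on the combined rows.
      collected = collected ++ batch.flatMap(_.take(remaining)).take(remaining)
      scanned += batch.size
      rounds += 1
    }
    (collected, rounds)
  }
}
```

With eight partitions of two rows each, IterativeTake.take(parts, 1) finishes after one round on the first partition, while IterativeTake.take(parts, 300) needs three rounds (1, then 4, then the remaining 3 partitions) before it runs out of data.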
Such a strategy can have significant performance implications. Starting Spark jobs alone is not cheap, and when the upstream object is the result of a wide transformation, things can get quite ugly (in the best-case scenario you can read shuffle files, but if these are lost for some reason, Spark might be forced to re-execute all the dependencies).
To summarize:
take is an action, and can short-circuit in specific cases where the upstream process is narrow and LocalLimits can satisfy GlobalLimit using the first few partitions.
limit is a transformation, and always evaluates all LocalLimits, as there is no iterative escape hatch.
While one can behave better than the other in specific cases, they are not interchangeable, and neither guarantees better performance in general.
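For the limit path, a plain Scala sketch (no Spark) makes the cost visible: every partition is touched before the global cap is applied. The names LimitThenCount and limitCount are assumptions made for illustration only.

```scala
// Plain Scala sketch (no Spark): LocalLimit on every partition, then GlobalLimit.
object LimitThenCount {
  // Returns (row count after the limit, number of partitions that were evaluated).
  def limitCount[A](partitions: Seq[Seq[A]], n: Int): (Int, Int) = {
    val localLimited = partitions.map(_.take(n)) // LocalLimit: runs on each partition
    val single = localLimited.flatten            // Exchange SinglePartition
    (single.take(n).size, localLimited.size)     // GlobalLimit, then count
  }
}
```

With eight partitions of two rows each, LimitThenCount.limitCount(parts, 4) returns (4, 8): four rows survive, but all eight partitions had to be evaluated to get there.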
I know that there exist several transformations which preserve parent partitioning (if it was set before - e.g. mapValues) and some which do not preserve it (e.g. map).
I use the Dataset API of Spark 2.2. My question is: does the dropDuplicates transformation preserve partitioning? Imagine this code:
case class Item(one: Int, two: Int, three: Int)
import session.implicits._
val ds = session.createDataset(List(Item(1,2,3), Item(1,2,3)))
val repart = ds.repartition('one, 'two).cache()
repart.dropDuplicates(List("one", "two")) // will partitioning be preserved?
Generally, dropDuplicates does a shuffle (and thus does not preserve partitioning), but in your special case it does NOT do an additional shuffle, because you have already partitioned the dataset in a suitable form, which the optimizer takes into account:
repart.dropDuplicates(List("one","two")).explain()
== Physical Plan ==
*HashAggregate(keys=[one#3, two#4, three#5], functions=[])
+- *HashAggregate(keys=[one#3, two#4, three#5], functions=[])
   +- InMemoryTableScan [one#3, two#4, three#5]
      +- InMemoryRelation [one#3, two#4, three#5], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
         +- Exchange hashpartitioning(one#3, two#4, 200)
            +- LocalTableScan [one#3, two#4, three#5]
The keyword to look for here is Exchange: this plan contains only the one introduced by repartition.
But consider the following code where you first repartition the dataset using plain repartition():
val repart = ds.repartition(200).cache()
repart.dropDuplicates(List("one","two")).explain()
This will indeed trigger an additional shuffle (now you have two Exchange steps):
== Physical Plan ==
*HashAggregate(keys=[one#3, two#4], functions=[first(three#5, false)])
+- Exchange hashpartitioning(one#3, two#4, 200)
   +- *HashAggregate(keys=[one#3, two#4], functions=[partial_first(three#5, false)])
      +- InMemoryTableScan [one#3, two#4, three#5]
         +- InMemoryRelation [one#3, two#4, three#5], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
            +- Exchange RoundRobinPartitioning(200)
               +- LocalTableScan [one#3, two#4, three#5]
NOTE: I checked this with Spark 2.1; it may be different in Spark 2.2, because the optimizer changed in Spark 2.2 (cost-based optimizer).
No, dropDuplicates doesn't preserve partitions since it has a shuffle boundary, which doesn't guarantee order.
dropDuplicates is approximately:
ds.groupBy(columnId).agg(/* take first column from any available partition */)
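The group-then-keep-first reading can be made concrete on a plain Scala collection (no Spark). The object name DropDuplicatesSketch is an assumption for illustration, and Item mirrors the case class from the question. Unlike this local version, which always keeps the first row it encounters, Spark gives no guarantee about which row of a group survives.

```scala
// Plain Scala sketch (no Spark): one surviving row per (one, two) key.
object DropDuplicatesSketch {
  final case class Item(one: Int, two: Int, three: Int)

  // Group rows by the key columns and keep a single representative per group,
  // comparable to groupBy(one, two) followed by first(three).
  def dropDuplicates(items: Seq[Item]): Seq[Item] =
    items.groupBy(i => (i.one, i.two)).values.map(_.head).toSeq
}
```

The grouping step is what forces rows with the same key to meet in one place, which in Spark means a shuffle unless the data is already partitioned by those keys.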
Spark SQL doesn't seem to be able to push down LIMIT operators on inner joined tables. This is an issue when joining large tables to extract a small subset of rows. I'm testing on Spark 2.2.1 (most recent release).
Below is a contrived example, which runs in the spark-shell (Scala):
First, set up the tables:
case class Customer(id: Long, name: String, email: String, zip: String)
case class Order(id: Long, customer: Long, date: String, amount: Long)
val customers = Seq(
Customer(0, "George Washington", "gwashington#usa.gov", "22121"),
Customer(1, "John Adams", "gwashington#usa.gov", "02169"),
Customer(2, "Thomas Jefferson", "gwashington#usa.gov", "22902"),
Customer(3, "James Madison", "gwashington#usa.gov", "22960"),
Customer(4, "James Monroe", "gwashington#usa.gov", "22902")
)
val orders = Seq(
Order(1, 1, "07/04/1776", 23456),
Order(2, 3, "03/14/1760", 7850),
Order(3, 2, "05/23/1784", 12400),
Order(4, 3, "09/03/1790", 6550),
Order(5, 4, "07/21/1795", 2550),
Order(6, 0, "11/27/1787", 1440)
)
import spark.implicits._
val customerTable = spark.sparkContext.parallelize(customers).toDS()
customerTable.createOrReplaceTempView("customer")
val orderTable = spark.sparkContext.parallelize(orders).toDS()
orderTable.createOrReplaceTempView("order")
Now run the following join query, with a LIMIT and an arbitrary filter for each joined table:
scala> val join = spark.sql("SELECT c.* FROM customer c JOIN order o ON c.id = o.customer WHERE c.id > 1 AND o.amount > 5000 LIMIT 1")
Then print the corresponding optimized execution plan:
scala> println(join.queryExecution.sparkPlan.toString)
CollectLimit 1
+- Project [id#5L, name#6, email#7, zip#8]
+- SortMergeJoin [id#5L], [customer#17L], Inner
:- Filter (id#5L > 1)
: +- SerializeFromObject [assertnotnull(input[0, $line14.$read$$iw$$iw$Customer, true]).id AS id#5L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$Customer, true]).name, true) AS name#6, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$Customer, true]).email, true) AS email#7, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$Customer, true]).zip, true) AS zip#8]
: +- Scan ExternalRDDScan[obj#4]
+- Project [customer#17L]
+- Filter ((amount#19L > 5000) && (customer#17L > 1))
+- SerializeFromObject [assertnotnull(input[0, $line15.$read$$iw$$iw$Order, true]).id AS id#16L, assertnotnull(input[0, $line15.$read$$iw$$iw$Order, true]).customer AS customer#17L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line15.$read$$iw$$iw$Order, true]).date, true) AS date#18, assertnotnull(input[0, $line15.$read$$iw$$iw$Order, true]).amount AS amount#19L]
+- Scan ExternalRDDScan[obj#15]
and you can immediately see that both tables are sorted in their entirety before being joined (although for these small example tables the explicit Sort step is not shown before the SortMergeJoin), and only afterward is the LIMIT applied.
If one of the databases contains billions of rows, this query becomes extremely slow and resource intensive, regardless of the LIMIT size.
Is Spark capable of optimizing such a query? Or can I work around the issue without mangling my SQL beyond recognition?
Spark 3.0 adds LimitPushDown so from that version it will be able to do that.
Is Spark capable of optimizing such a query
In short, it is not.
Using the old nomenclature, join is a wide transformation, where each output partition depends on every upstream partition. As a result, both parent datasets have to be fully scanned to compute even a single partition of the child.
It is not impossible that some optimizations will be included in the future; however, if your goal is to:
extract a small subset of rows.
then you should consider using a database, not Apache Spark.
Not sure whether this will be helpful in your use case, but did you consider applying the filter and limit clauses to the order table itself?
val orderTableLimited = orderTable.filter($"customer" > 1).filter($"amount" > 5000).limit(1)
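Limiting one side before the join is only safe when every limited row is guaranteed to find a match. The plain Scala sketch below (no Spark) uses simplified versions of the question's Customer and Order classes; the object name LimitBeforeJoin and the trimmed data set are assumptions for illustration. It shows the two orderings producing different results when the early limit uses only the amount filter.

```scala
// Plain Scala sketch (no Spark): why LIMIT cannot blindly move below an inner join.
object LimitBeforeJoin {
  final case class Customer(id: Long, name: String)
  final case class Order(id: Long, customer: Long, amount: Long)

  val customers = Seq(
    Customer(1, "John Adams"),
    Customer(2, "Thomas Jefferson"),
    Customer(3, "James Madison"))
  val orders = Seq(Order(1, 1, 23456), Order(2, 3, 7850), Order(3, 2, 12400))

  // Inner join on c.id = o.customer with both filters from the query.
  def join(cs: Seq[Customer], os: Seq[Order]): Seq[Customer] =
    for {
      c <- cs if c.id > 1
      o <- os if o.amount > 5000 && o.customer == c.id
    } yield c

  // LIMIT applied after the join, as in the original SQL: one matching customer.
  val joinThenLimit: Seq[Customer] = join(customers, orders).take(1)

  // LIMIT applied to the orders first, using only the amount filter:
  // the single surviving order belongs to customer 1, so the join is empty.
  val limitThenJoin: Seq[Customer] =
    join(customers, orders.filter(_.amount > 5000).take(1))
}
```

The suggestion above avoids this particular trap because it also filters on customer > 1 before limiting, but it still assumes that every remaining order refers to an existing customer.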
I have a Dataset that I created from an RDD, and I am trying to join it with another Dataset that is created from my Phoenix table:
val dfToJoin = sparkSession.createDataset(rddToJoin)
val tableDf = sparkSession
.read
.option("table", "table")
.option("zkURL", "localhost")
.format("org.apache.phoenix.spark")
.load()
val joinedDf = dfToJoin.join(tableDf, "columnToJoinOn")
When I execute it, it seems that the whole database table is loaded to do the join.
Is there a way to do such a join so that the filtering is done in the database instead of in Spark?
Also: dfToJoin is smaller than the table; I do not know if this is important.
Edit: Basically, I want to join my Phoenix table with a Dataset created through Spark, without fetching the whole table into the executors.
Edit2: Here is the physical plan:
*Project [FEATURE#21, SEQUENCE_IDENTIFIER#22, TAX_NUMBER#23,
WINDOW_NUMBER#24, uniqueIdentifier#5, readLength#6]
+- *SortMergeJoin [FEATURE#21], [feature#4], Inner
:- *Sort [FEATURE#21 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(FEATURE#21, 200)
: +- *Filter isnotnull(FEATURE#21)
: +- *Scan PhoenixRelation(FEATURES,localhost,false)
[FEATURE#21,SEQUENCE_IDENTIFIER#22,TAX_NUMBER#23,WINDOW_NUMBER#24]
PushedFilters: [IsNotNull(FEATURE)], ReadSchema:
struct<FEATURE:int,SEQUENCE_IDENTIFIER:string,TAX_NUMBER:int,
WINDOW_NUMBER:int>
+- *Sort [feature#4 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(feature#4, 200)
+- *Filter isnotnull(feature#4)
+- *SerializeFromObject [assertnotnull(input[0, utils.CaseClasses$QueryFeature, true], top level Product input object).feature AS feature#4, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, utils.CaseClasses$QueryFeature, true], top level Product input object).uniqueIdentifier, true) AS uniqueIdentifier#5, assertnotnull(input[0, utils.CaseClasses$QueryFeature, true], top level Product input object).readLength AS readLength#6]
+- Scan ExternalRDDScan[obj#3]
As you can see the equals-filter is not contained in the pushed-filters list, so it is obvious that no predicate pushdown is happening.
Spark will fetch the Phoenix table records to the appropriate executors (not the entire table to one executor).
As there is no direct filter on the Phoenix table DataFrame, we see only *Filter isnotnull(FEATURE#21) in the physical plan.
As you mention, far fewer Phoenix rows are needed once a filter is applied. You can push a filter on the feature column down to the Phoenix table by first collecting the feature ids found in the other dataset.
import sparkSession.implicits._ // needed for $"..." and .as[Int]

// This is spread across workers - fully distributed
val dfToJoin = sparkSession.createDataset(rddToJoin)
// This sits in the driver - not distributed, so it must be small enough to collect
val list_of_feature_ids = dfToJoin
  .select("feature")
  .distinct()
  .as[Int] // feature is an int column according to the plan in the question
  .collect
  .toList
// This is spread across workers - fully distributed
val tableDf = sparkSession
  .read
  .option("table", "table")
  .option("zkURL", "localhost")
  .format("org.apache.phoenix.spark")
  .load()
  .filter($"FEATURE".isin(list_of_feature_ids: _*)) // added filter
// This is spread across workers - fully distributed
val joinedDf = dfToJoin.join(tableDf, "columnToJoinOn")
joinedDf.explain() // check PushedFilters to see whether the filter reached Phoenix