How to force DataFrame evaluation in Spark - Scala

Sometimes (e.g. for testing and benchmarking) I want to force the execution of the transformations defined on a DataFrame. AFAIK, calling an action like count does not ensure that all columns are actually computed; show may only compute a subset of all rows (see the examples below).
My current solution is to write the DataFrame to HDFS using df.write.saveAsTable, but this "clutters" my system with tables I don't want to keep.
So what is the best way to trigger the evaluation of a DataFrame?
Edit:
Note that there is also a recent discussion on the Spark developer list: http://apache-spark-developers-list.1001551.n3.nabble.com/Will-count-always-trigger-an-evaluation-of-each-row-td21018.html
I made a small example which shows that count on a DataFrame does not evaluate everything (tested using Spark 1.6.3 with spark-master = local[2]):
val df = sc.parallelize(Seq(1)).toDF("id")
val myUDF = udf((i:Int) => {throw new RuntimeException;i})
df.withColumn("test",myUDF($"id")).count // runs fine
df.withColumn("test",myUDF($"id")).show() // gives Exception
Using the same logic, here is an example showing that show does not evaluate all rows:
val df = sc.parallelize(1 to 10).toDF("id")
val myUDF = udf((i:Int) => {if(i==10) throw new RuntimeException;i})
df.withColumn("test",myUDF($"id")).show(5) // runs fine
df.withColumn("test",myUDF($"id")).show(10) // gives Exception
Edit 2 (for Eliasah): The exception says this:
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 6.0 failed 1 times, most recent failure: Lost task 0.0 in stage 6.0 (TID 6, localhost): java.lang.RuntimeException
at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply$mcII$sp(<console>:68)
at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply(<console>:68)
at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$anonfun$1.apply(<console>:68)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
at org.apache.spark.sql.execution.Project$$anonfun$1$$anonfun$apply$1.apply(basicOperators.scala:51)
at org.apache.spark.sql.execution.Project$$anonfun$1$$anonfun$apply$1.apply(basicOperators.scala:49)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
.
.
.
.
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1431)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1419)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1418)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1418)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:799)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:799)
at scala.Option.foreach(Option.scala:236)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:799)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1640)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1599)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1588)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:620)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:1832)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:1845)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:1858)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:212)
at org.apache.spark.sql.execution.Limit.executeCollect(basicOperators.scala:165)
at org.apache.spark.sql.execution.SparkPlan.executeCollectPublic(SparkPlan.scala:174)
at org.apache.spark.sql.DataFrame$$anonfun$org$apache$spark$sql$DataFrame$$execute$1$1.apply(DataFrame.scala:1500)
at org.apache.spark.sql.DataFrame$$anonfun$org$apache$spark$sql$DataFrame$$execute$1$1.apply(DataFrame.scala:1500)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:56)
at org.apache.spark.sql.DataFrame.withNewExecutionId(DataFrame.scala:2087)
at org.apache.spark.sql.DataFrame.org$apache$spark$sql$DataFrame$$execute$1(DataFrame.scala:1499)
at org.apache.spark.sql.DataFrame.org$apache$spark$sql$DataFrame$$collect(DataFrame.scala:1506)
at org.apache.spark.sql.DataFrame$$anonfun$head$1.apply(DataFrame.scala:1376)
at org.apache.spark.sql.DataFrame$$anonfun$head$1.apply(DataFrame.scala:1375)
at org.apache.spark.sql.DataFrame.withCallback(DataFrame.scala:2100)
at org.apache.spark.sql.DataFrame.head(DataFrame.scala:1375)
at org.apache.spark.sql.DataFrame.take(DataFrame.scala:1457)
at org.apache.spark.sql.DataFrame.showString(DataFrame.scala:170)
at org.apache.spark.sql.DataFrame.show(DataFrame.scala:350)
at org.apache.spark.sql.DataFrame.show(DataFrame.scala:311)
at org.apache.spark.sql.DataFrame.show(DataFrame.scala:319)
at $iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC$$iwC.<init>(<console>:74)
.
.
.
.

It's a bit late, but here's the fundamental reason: count does not behave the same way on an RDD as on a DataFrame.
DataFrames carry an optimization: in some cases you do not need to load the data to know how many elements there are (especially in a case like yours, where no data shuffling is involved). Hence, when count is called, the DataFrame will not load any data and will never reach your exception-throwing UDF. You can easily verify this by defining your own DefaultSource and Relation and observing that calling count on a DataFrame always ends up in the buildScan method with no requiredColumns, no matter how many columns you selected (see org.apache.spark.sql.sources.interfaces to understand more). It's actually a very efficient optimization ;-)
RDDs, though, have no such optimization (which is why one should always try to use DataFrames when possible). Hence count on an RDD executes the whole lineage and returns the sum of the sizes of the iterators over all partitions.
Calling dataframe.count falls under the first explanation, but calling dataframe.rdd.count falls under the second, because you did build an RDD out of your DataFrame. Note that calling dataframe.cache().count forces the DataFrame to be materialized, since you asked Spark to cache the results (so it needs to load all the data and transform it). But it does have the side effect of caching your data...
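To make the difference concrete, here is a minimal sketch of the three variants discussed above (assuming the same Spark 1.6 shell setup as in the question, with sc and the implicits available):
import org.apache.spark.sql.functions.udf
val df = sc.parallelize(1 to 10).toDF("id")
val myUDF = udf((i: Int) => { if (i == 10) throw new RuntimeException; i })
val withTest = df.withColumn("test", myUDF($"id"))
withTest.count          // optimized away: the UDF column is never computed, no exception
withTest.rdd.count      // runs the full lineage on the RDD, throws the RuntimeException
withTest.cache().count  // caching forces full materialization, throws the RuntimeException (and caches the data)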

I guess simply getting the underlying RDD from the DataFrame and triggering an action on it should achieve what you're looking for.
df.withColumn("test",myUDF($"id")).rdd.count // this gives proper exceptions

It appears that df.cache.count is the way to go:
scala> val myUDF = udf((i:Int) => {if(i==1000) throw new RuntimeException;i})
myUDF: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(IntegerType)))
scala> val df = sc.parallelize(1 to 1000).toDF("id")
df: org.apache.spark.sql.DataFrame = [id: int]
scala> df.withColumn("test",myUDF($"id")).show(10)
[rdd_51_0]
+---+----+
| id|test|
+---+----+
| 1| 1|
| 2| 2|
| 3| 3|
| 4| 4|
| 5| 5|
| 6| 6|
| 7| 7|
| 8| 8|
| 9| 9|
| 10| 10|
+---+----+
only showing top 10 rows
scala> df.withColumn("test",myUDF($"id")).count
res13: Long = 1000
scala> df.withColumn("test",myUDF($"id")).cache.count
org.apache.spark.SparkException: Failed to execute user defined function($anonfun$1: (int) => int)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
.
.
.
Caused by: java.lang.RuntimeException

I prefer to use df.write.parquet(). This adds disk I/O time that you can estimate and subtract out later, but you can be sure that Spark performed each step you expected and did not trick you with lazy evaluation.
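A minimal sketch of that approach (the temporary path is just an example; any scratch location you can clean up later works):
// Write the result out to force Spark to compute every transformation, then delete the files afterwards.
val tmpPath = "/tmp/force_eval_" + System.currentTimeMillis
df.withColumn("test", myUDF($"id"))
  .write
  .mode("overwrite")
  .parquet(tmpPath)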

Related

Spark: Is "count" on Grouped Data a Transformation or an Action?

I know that count called on an RDD or a DataFrame is an action. But while fiddling with the Spark shell, I observed the following:
scala> val empDF = Seq((1,"James Gordon", 30, "Homicide"),(2,"Harvey Bullock", 35, "Homicide"),(3,"Kristen Kringle", 28, "Records"),(4,"Edward Nygma", 30, "Forensics"),(5,"Leslie Thompkins", 31, "Forensics")).toDF("id", "name", "age", "department")
empDF: org.apache.spark.sql.DataFrame = [id: int, name: string, age: int, department: string]
scala> empDF.show
+---+----------------+---+----------+
| id| name|age|department|
+---+----------------+---+----------+
| 1| James Gordon| 30| Homicide|
| 2| Harvey Bullock| 35| Homicide|
| 3| Kristen Kringle| 28| Records|
| 4| Edward Nygma| 30| Forensics|
| 5|Leslie Thompkins| 31| Forensics|
+---+----------------+---+----------+
scala> empDF.groupBy("department").count //count returned a DataFrame
res1: org.apache.spark.sql.DataFrame = [department: string, count: bigint]
scala> res1.show
+----------+-----+
|department|count|
+----------+-----+
| Homicide| 2|
| Records| 1|
| Forensics| 2|
+----------+-----+
When I called count on GroupedData (empDF.groupBy("department")), I got another DataFrame as the result (res1). This leads me to believe that count in this case was a transformation. It is further supported by the fact that no computations were triggered when I called count; instead, they started when I ran res1.show.
I haven't been able to find any documentation that suggests count could be a transformation as well. Could someone please shed some light on this?
The .count() you have used in your code is on a RelationalGroupedDataset; it creates a new column with the count of elements per group. This is a transformation. Refer:
https://spark.apache.org/docs/1.6.0/api/scala/index.html#org.apache.spark.sql.GroupedDataset
The .count() that you normally use on an RDD/DataFrame/Dataset is completely different from the above; that .count() is an action. Refer: https://spark.apache.org/docs/1.6.0/api/scala/index.html#org.apache.spark.rdd.RDD
EDIT:
To avoid confusion in the future, always express the grouped count explicitly with .agg() when operating on a grouped Dataset:
empDF.groupBy($"department").agg(count($"department") as "countDepartment").show
Case 1:
You use rdd.count() to count the number of rows. Since it initiates DAG execution and returns the result to the driver, it's an action on an RDD.
for ex: rdd.count // returns a Long value
Case 2:
If you call count on a DataFrame, it also initiates DAG execution and returns the result to the driver, so it's an action on a DataFrame.
for ex: df.count // returns a Long value
Case 3:
In your case you are calling groupBy on a DataFrame, which returns a RelationalGroupedDataset object, and you are calling count on that grouped Dataset, which returns a DataFrame. So it's a transformation, since it neither brings the data to the driver nor initiates DAG execution.
for ex:
df.groupBy("department") // returns RelationalGroupedDataset
.count // returns a DataFrame, so a transformation
.count // returns a Long value since it is called on a DataFrame, so an action
As you've already figured out, if a method returns a distributed object (a Dataset or an RDD), it can be qualified as a transformation.
However, this distinction is much better suited to RDDs than to Datasets. The latter feature an optimizer, including the recently added cost-based optimizer, and can be much less lazy than the old API, blurring the difference between transformation and action in some cases.
Here, however, it is safe to say count is a transformation.
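A quick way to observe this laziness in the shell (a sketch, assuming the empDF from the question is in scope): groupBy(...).count only builds a query plan, while count on the resulting DataFrame runs a job.
val grouped = empDF.groupBy("department").count // transformation: returns a DataFrame, no job runs yet
grouped.explain()                               // prints the physical plan, still nothing executed
grouped.count                                   // action: triggers a job and returns a Long (the number of groups)
grouped.show()                                  // action: triggers a job and prints the grouped counts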

How to flatten a nested field in a Spark Dataset?

I have a nested field like the one below. I want to call flatMap (I think) to produce a flattened row.
My dataset has
A,B,[[x,y,z]],C
I want to convert it to produce output like
A,B,X,Y,Z,C
This is for Spark 2.0+
Thanks!
Apache DataFu has a generic explodeArray method that will do
exactly what you need.
import datafu.spark.DataFrameOps._
val df = sc.parallelize(Seq(("A","B",Array("X","Y","Z"),"C"))).toDF
df.explodeArray(col("_3"), "token").show
This will produce:
+---+---+---------+---+------+------+------+
| _1| _2| _3| _4|token0|token1|token2|
+---+---+---------+---+------+------+------+
| A| B|[X, Y, Z]| C| X| Y| Z|
+---+---+---------+---+------+------+------+
One thing to consider is that this method evaluates the DataFrame in order to determine how many columns to create, so if it's expensive to compute, it should be cached first (see the sketch below).
Full disclosure - I am a member of Apache DataFu.
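For example, a minimal sketch of that caching advice (assuming the same imports and df as above):
val cached = df.cache()                          // avoid recomputing the upstream work twice
cached.explodeArray(col("_3"), "token").show()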
Try this for RDD:
val rdd = sc.parallelize(Seq(("A","B",Array("X","Y","Z"),"C")))
rdd.flatMap(x => (Option(x._3).map(y => (x._1,x._2,y(0),y(1),y(2),x._4 )))).collect.foreach(println)
Output:
(A,B,X,Y,Z,C)
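If you'd rather stay in the DataFrame API, here is a sketch that assumes the array always has exactly three elements, as in the example:
import org.apache.spark.sql.functions.col
val df = sc.parallelize(Seq(("A", "B", Array("X", "Y", "Z"), "C"))).toDF
df.select(col("_1"), col("_2"),
  col("_3").getItem(0).as("X"),
  col("_3").getItem(1).as("Y"),
  col("_3").getItem(2).as("Z"),
  col("_4")).show()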

How to evaluate expressions that are the column values?

I have a big dataframe with millions of rows as follows:
A B C Eqn
12 3 4 A+B
32 8 9 B*C
56 12 2 A+B*C
How to evaluate the expressions in the Eqn column?
You could create a custom UDF that evaluates these arithmetic functions
def evalUDF = udf((a: Int, b: Int, c: Int, eqn: String) => {
  val eqnParts = eqn
    .replace("A", a.toString)
    .replace("B", b.toString)
    .replace("C", c.toString)
    .split("""\b""")
    .toList
  val (sum, _) = eqnParts.tail.foldLeft((eqnParts.head.toInt, "")) {
    case ((runningTotal, "+"), num) => (runningTotal + num.toInt, "")
    case ((runningTotal, "-"), num) => (runningTotal - num.toInt, "")
    case ((runningTotal, "*"), num) => (runningTotal * num.toInt, "")
    case ((runningTotal, _), op)    => (runningTotal, op)
  }
  sum
})
evalDf
  .withColumn("eval", evalUDF('A, 'B, 'C, 'Eqn))
  .show()
Output:
+---+---+---+-----+----+
| A| B| C| Eqn|eval|
+---+---+---+-----+----+
| 12| 3| 4| A+B| 15|
| 32| 8| 9| B*C| 72|
| 56| 12| 2|A+B*C| 136|
+---+---+---+-----+----+
As you can see this works, but it is very fragile (spaces, unknown operators, etc. will break the code) and it does not respect operator precedence (otherwise the last result should have been 80).
So you could write all of that yourself, or find a library that already does it (like https://gist.github.com/daixque/1610753)?
Maybe the performance overhead will be large (especially if you start using recursive parsers), but at least you can perform it on a DataFrame instead of collecting it first.
I think the only way to execute the SQL expressions that are inside a DataFrame is to select("Eqn").collect first, and then execute the expressions iteratively against the source Dataset.
Since the expressions are in a DataFrame, which is nothing more than a description of a distributed computation that will be executed on Spark executors, there is no way you could submit Spark jobs while processing the expressions on the executors. It is simply too late in the execution pipeline; you have to be back on the driver to be able to submit new Spark jobs, say to execute the expressions.
With the expressions on the driver you would then take the corresponding rows for each expression and simply use withColumn to evaluate them.
I think it's easier to describe than to develop a working Spark application, but that's how I'd go about it.
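A minimal sketch of that idea (assuming Spark 2.x, that the DataFrame is called evalDf as above, and that the Eqn values are valid Spark SQL expressions so that expr can evaluate them against the A/B/C columns):
import org.apache.spark.sql.functions.{col, expr}
// Collect the distinct equations to the driver, then evaluate each one against its own rows.
val eqns = evalDf.select("Eqn").distinct.collect.map(_.getString(0))
val evaluated = eqns
  .map(e => evalDf.filter(col("Eqn") === e).withColumn("eval", expr(e)))
  .reduce(_ union _)
evaluated.show()
// Unlike the accepted UDF, expr respects operator precedence, so A+B*C yields 80 here.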
I am late, but in case someone is looking for:
a generic math expression interpreter using variables, or
complex/unknown expressions that cannot be hardcoded into the UDF (as in the accepted answer),
then you can use javax.script.ScriptEngineManager:
import javax.script.SimpleBindings
import javax.script.ScriptEngineManager
import java.util.Map
import java.util.HashMap
def calculateFunction = (mathExpression: String, A: Double, B: Double, C: Double) => {
  val vars: Map[String, Object] = new HashMap[String, Object]()
  vars.put("A", A.asInstanceOf[Object])
  vars.put("B", B.asInstanceOf[Object])
  vars.put("C", C.asInstanceOf[Object])
  val engine = new ScriptEngineManager().getEngineByExtension("js")
  val result = engine.eval(mathExpression, new SimpleBindings(vars))
  result.asInstanceOf[Double]
}
val calculateUDF = spark.udf.register("calculateFunction",calculateFunction)
NOTE: This handles generic expressions and is robust, but it gives much worse performance than the accepted answer and is heavy on memory.
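A possible usage on the example DataFrame (a sketch; the integer columns may need a cast to Double to match the function's signature):
evalDf
  .withColumn("eval", calculateUDF($"Eqn", $"A".cast("double"), $"B".cast("double"), $"C".cast("double")))
  .show()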

Spark - How to use QuantileDiscretizer with RandomForestClassifier

Is it possible to use QuantileDiscretizer, keeping NaN values, with a RandomForestClassifier?
I have been getting an error like this:
18/03/23 17:38:15 ERROR Executor: Exception in task 3.0 in stage 133.0 (TID 381)
java.lang.IllegalArgumentException: DecisionTree given invalid data: Feature 1 is categorical with values in {0,...,1, but a data point gives it value 2.0.
Bad data point: (1.0,[1.0,2.0])
Example
The idea here is to create a numeric column and discretize it using quantiles, keeping invalid numbers (NaN) in a special bucket.
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler,
QuantileDiscretizer}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassifier}
val tseq = Seq((0, "a", 1.0), (1, "b", 0.0), (2, "c", 2.0),
(3, "a", 1.0), (4, "a", 3.0), (5, "c", Double.NaN))
val tdf = SparkInit.ss.createDataFrame(tseq).toDF("id", "category", "class")
val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
val discr = new QuantileDiscretizer()
.setInputCol("class")
.setOutputCol("quant")
.setNumBuckets(2)
.setHandleInvalid("keep")
val assembler = new VectorAssembler()
.setInputCols(Array("categoryIndex", "quant"))
.setOutputCol("features")
val rf = new RandomForestClassifier()
.setLabelCol("categoryIndex")
.setFeaturesCol("features")
.setNumTrees(3)
new Pipeline()
.setStages(Array(indexer, discr, assembler, rf))
.fit(tdf)
.transform(tdf)
.show()
Without trying to fit the Random Forest, I was getting a DataFrame like this:
+---+--------+-----+-------------+-----+---------+
| id|category|class|categoryIndex|quant| features|
+---+--------+-----+-------------+-----+---------+
| 0| a| 1.0| 0.0| 1.0|[0.0,1.0]|
| 1| b| 0.0| 2.0| 0.0|[2.0,0.0]|
| 2| c| 2.0| 1.0| 1.0|[1.0,1.0]|
| 3| a| 1.0| 0.0| 1.0|[0.0,1.0]|
| 4| a| 3.0| 0.0| 1.0|[0.0,1.0]|
| 5| c| NaN| 1.0| 2.0|[1.0,2.0]|
+---+--------+-----+-------------+-----+---------+
If I try to fit the model, I get the error:
18/03/23 17:54:12 WARN DecisionTreeMetadata: DecisionTree reducing maxBins from 32 to 6 (= number of training instances)
18/03/23 17:54:12 WARN BlockManager: Putting block rdd_490_3 failed due to an exception
18/03/23 17:54:12 WARN BlockManager: Block rdd_490_3 could not be removed as it was not found on disk or in memory
18/03/23 17:54:12 ERROR Executor: Exception in task 3.0 in stage 143.0 (TID 414)
java.lang.IllegalArgumentException: DecisionTree given invalid data: Feature 1 is categorical with values in {0,...,1, but a data point gives it value 2.0.
Bad data point: (1.0,[1.0,2.0])
at org.apache.spark.ml.tree.impl.TreePoint$.findBin(TreePoint.scala:124)
at org.apache.spark.ml.tree.impl.TreePoint$.org$apache$spark$ml$tree$impl$TreePoint$$labeledPointToTreePoint(TreePoint.scala:93)
at org.apache.spark.ml.tree.impl.TreePoint$$anonfun$convertToTreeRDD$2.apply(TreePoint.scala:73)
at org.apache.spark.ml.tree.impl.TreePoint$$anonfun$convertToTreeRDD$2.apply(TreePoint.scala:72)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
at org.apache.spark.storage.memory.MemoryStore.putIteratorAsValues(MemoryStore.scala:216)
Does QuantileDiscretizer insert some kind of metadata about the special extra bucket? It's weird that I was able to build a model before with columns containing the same values, just without forcing any discretization.
Update
Yes, the column does have metadata attached, and it looks like this:
org.apache.spark.sql.types.Metadata = {"ml_attr":
{"ord":true,
"vals":["-Infinity, 5.0","5.0, 10.0","10.0, Infinity"],
"type":"nominal"}
}
The question now might be: how do I correctly set the metadata to include values like Double.NaN?
The workaround I used was simply to remove the associated metadata from the discretized column, letting the decision tree implementation decide what to do with the data. I think the column then effectively becomes a numerical column ([0, 1, 2, 2, 1], for example), but if too many distinct values are created, the column could be discretized again (look at the maxBins parameter).
In my case, the simplest way to remove the metadata was to fill the DataFrame after applying QuantileDiscretizer:
// Nothing is actually filled in my case, since there was no missing
// values before this operation.
df.na.fill(Double.NaN, Array("quant"))
I'm almost sure you could also remove the metadata manually by accessing the column object directly.
Update
We can change a column's metadata by creating an alias (reference):
val metadata: Metadata = ...
df.select($"colA".as("colB", metadata))
This answer describes a way to get the column's metadata by getting the respective StructField of a DataFrame's schema.
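Putting the two updates together, here is a sketch of clearing the metadata on the discretized column by aliasing it with empty metadata (discretized is a hypothetical name for the DataFrame produced after QuantileDiscretizer):
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.Metadata
// Re-select every column, replacing "quant" with an alias that carries empty metadata.
val noMeta = discretized.select(discretized.columns.map {
  case "quant" => col("quant").as("quant", Metadata.empty)
  case other   => col(other)
}: _*)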

Spark: Parquet DataFrame operations fail when forcing schema on read

(Spark 2.0.2)
The problem here arises when you have Parquet files with different schemas and force a schema during read. Even though you can print the schema and run show() fine, you cannot apply any filtering logic to the missing columns.
Here are the two example schemata:
// assuming you are running this code in a spark REPL
import spark.implicits._
case class Foo(i: Int)
case class Bar(i: Int, j: Int)
So Bar includes all the fields of Foo and adds one more (j). In real life this arises when you start with schema Foo and later decide that you need more fields, ending up with schema Bar.
Let's simulate the two different parquet files.
// assuming you are on a Mac or Linux OS
spark.createDataFrame(Foo(1)::Nil).write.parquet("/tmp/foo")
spark.createDataFrame(Bar(1,2)::Nil).write.parquet("/tmp/bar")
What we want here is to always read data using the more generic schema Bar. That is, rows written with schema Foo should have j as null.
case 1: We read a mix of both schema
spark.read.option("mergeSchema", "true").parquet("/tmp/foo", "/tmp/bar").show()
+---+----+
| i| j|
+---+----+
| 1| 2|
| 1|null|
+---+----+
spark.read.option("mergeSchema", "true").parquet("/tmp/foo", "/tmp/bar").filter($"j".isNotNull).show()
+---+---+
| i| j|
+---+---+
| 1| 2|
+---+---+
case 2: We only have Bar data
spark.read.parquet("/tmp/bar").show()
+---+---+
| i| j|
+---+---+
| 1| 2|
+---+---+
case 3: We only have Foo data
scala> spark.read.parquet("/tmp/foo").show()
+---+
| i|
+---+
| 1|
+---+
The problematic case is case 3, where the resulting schema is of type Foo and not Bar. Since we are migrating to schema Bar, we always want to get schema Bar from our data (old and new).
The suggested solution would be to define the schema programmatically to always be Bar. Let's see how to do this:
val barSchema = org.apache.spark.sql.Encoders.product[Bar].schema
//barSchema: org.apache.spark.sql.types.StructType = StructType(StructField(i,IntegerType,false), StructField(j,IntegerType,false))
Running show() works great:
scala> spark.read.schema(barSchema).parquet("/tmp/foo").show()
+---+----+
| i| j|
+---+----+
| 1|null|
+---+----+
However, if you try to filter on the missing column j, things fail.
scala> spark.read.schema(barSchema).parquet("/tmp/foo").filter($"j".isNotNull).show()
17/09/07 18:13:50 ERROR Executor: Exception in task 0.0 in stage 230.0 (TID 481)
java.lang.IllegalArgumentException: Column [j] was not found in schema!
at org.apache.parquet.Preconditions.checkArgument(Preconditions.java:55)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.getColumnDescriptor(SchemaCompatibilityValidator.java:181)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.validateColumn(SchemaCompatibilityValidator.java:169)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.validateColumnFilterPredicate(SchemaCompatibilityValidator.java:151)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.visit(SchemaCompatibilityValidator.java:91)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.visit(SchemaCompatibilityValidator.java:58)
at org.apache.parquet.filter2.predicate.Operators$NotEq.accept(Operators.java:194)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.validate(SchemaCompatibilityValidator.java:63)
at org.apache.parquet.filter2.compat.RowGroupFilter.visit(RowGroupFilter.java:59)
at org.apache.parquet.filter2.compat.RowGroupFilter.visit(RowGroupFilter.java:40)
at org.apache.parquet.filter2.compat.FilterCompat$FilterPredicateCompat.accept(FilterCompat.java:126)
at org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups(RowGroupFilter.java:46)
at org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.initialize(SpecificParquetRecordReaderBase.java:110)
at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.initialize(VectorizedParquetRecordReader.java:109)
at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anonfun$buildReader$1.apply(ParquetFileFormat.scala:381)
at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anonfun$buildReader$1.apply(ParquetFileFormat.scala:355)
at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.nextIterator(FileScanRDD.scala:168)
at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:109)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.scan_nextBatch$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:377)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:231)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:225)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:99)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:322)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
The issue is due to Parquet filter pushdown, which is not correctly handled in parquet-mr versions < 1.9.0.
You can check https://issues.apache.org/jira/browse/PARQUET-389 for more details.
You can either upgrade the parquet-mr version or add a new column and base the filter on the new column.
For example:
dfNew = df.withColumn("new_j", when($"j".isNotNull, $"j").otherwise(lit(null)))
dfNew.filter($"new_j".isNotNull)
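Another option I believe should work here (an assumption on my part, not from the answers above) is to disable Parquet filter pushdown for the affected read, so the predicate on the missing column is never handed to parquet-mr, at the cost of filtering only after the scan:
spark.conf.set("spark.sql.parquet.filterPushdown", "false")
spark.read.schema(barSchema).parquet("/tmp/foo").filter($"j".isNotNull).show()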
On Spark 1.6 this worked fine; schema retrieval was changed and HiveContext was used:
val barSchema = ScalaReflection.schemaFor[Bar].dataType.asInstanceOf[StructType]
println(s"barSchema: $barSchema")
hiveContext.read.schema(barSchema).parquet("tmp/foo").filter($"j".isNotNull).show()
Result is:
barSchema: StructType(StructField(i,IntegerType,false), StructField(j,IntegerType,false))
+---+----+
| i| j|
+---+----+
| 1|null|
+---+----+
What worked for me is to use the createDataFrame API with RDD[Row] and the new schema (with at least the new columns being nullable).
// Make the columns nullable (probably you don't need to make them all nullable)
val barSchemaNullable = org.apache.spark.sql.types.StructType(
barSchema.map(_.copy(nullable = true)).toArray)
// We create the df (but this is not what you want to use, since it still has the same issue)
val df = spark.read.schema(barSchemaNullable).parquet("/tmp/foo")
// Here is the final API that gives a working DataFrame
val fixedDf = spark.createDataFrame(df.rdd, barSchemaNullable)
fixedDf.filter($"j".isNotNull).show()
+---+---+
| i| j|
+---+---+
+---+---+