How do I detect if a Spark DataFrame has a column - scala

When I create a DataFrame from a JSON file in Spark SQL, how can I tell if a given column exists before calling .select?
Example JSON schema:
{
  "a": {
    "b": 1,
    "c": 2
  }
}
This is what I want to do:
potential_columns = Seq("b", "c", "d")
df = potential_columns.map(column => if (df.hasColumn(column)) df.select(s"a.$column"))
but I can't find a good function for hasColumn. The closest I've gotten is to test if the column is in this somewhat awkward array:
res17: Array[String] = Array(b, c)

Just assume it exists and let it fail with Try. Plain and simple, and it supports arbitrary nesting:
import scala.util.Try
import org.apache.spark.sql.DataFrame
def hasColumn(df: DataFrame, path: String) = Try(df(path)).isSuccess
val df = sqlContext.read.json(sc.parallelize(
  """{"foo": [{"bar": {"foobar": 3}}]}""" :: Nil))
hasColumn(df, "foobar")
// Boolean = false
hasColumn(df, "foo")
// Boolean = true
hasColumn(df, "foo.bar")
// Boolean = true
hasColumn(df, "foo.bar.foobar")
// Boolean = true
hasColumn(df, "foo.bar.foobaz")
// Boolean = false
Or even simpler:
val columns = Seq(
  "foobar", "foo", "foo.bar", "foo.bar.foobar", "foo.bar.foobaz")
columns.flatMap(c => Try(df(c)).toOption)
// Seq[org.apache.spark.sql.Column] = List(
//   foo, foo.bar AS bar#12, foo.bar.foobar AS foobar#13)
Python equivalent:
from pyspark.sql.utils import AnalysisException
from pyspark.sql import Row
def has_column(df, col):
    try:
        df[col]
        return True
    except AnalysisException:
        return False
df = sc.parallelize([Row(foo=[Row(bar=Row(foobar=3))])]).toDF()
has_column(df, "foobar")
## False
has_column(df, "foo")
## True
has_column(df, "foo.bar")
## True
has_column(df, "foo.bar.foobar")
## True
has_column(df, "foo.bar.foobaz")
## False

Another option, which I normally use, is
df.columns.contains("column-name-to-check")
This returns a Boolean.

Actually, you don't even need to call select in order to use columns; you can call it on the DataFrame itself:
// define test data
case class Test(a: Int, b: Int)
val testList = List(Test(1,2), Test(3,4))
val testDF = sqlContext.createDataFrame(testList)
// define the hasColumn function
def hasColumn(df: org.apache.spark.sql.DataFrame, colName: String) = df.columns.contains(colName)
// then you can just use it on the DF with a given column name
hasColumn(testDF, "a") // <-- true
hasColumn(testDF, "c") // <-- false
Alternatively you can define an implicit class using the pimp my library pattern so that the hasColumn method is available on your dataframes directly
implicit class DataFrameImprovements(df: org.apache.spark.sql.DataFrame) {
  def hasColumn(colName: String) = df.columns.contains(colName)
}
Then you can use it as:
testDF.hasColumn("a") // <-- true
testDF.hasColumn("c") // <-- false

Try is not optimal, as it will evaluate the expression inside Try before it makes the decision.
For large data sets, use the below in Scala:

For those who stumble across this looking for a Python solution, I use:
if 'column_name_to_check' in df.columns:
# do something
When I tried @Jai Prakash's answer of df.columns.contains('column-name-to-check') using Python, I got AttributeError: 'list' object has no attribute 'contains'.

Your other option for this would be to do some array manipulation (in this case an intersect) on the df.columns and your potential_columns.
// Loading some data (so you can just copy & paste right into spark-shell)
case class Document( a: String, b: String, c: String)
val df = sc.parallelize(Seq(Document("a", "b", "c")), 2).toDF
// The columns we want to extract
val potential_columns = Seq("b", "c", "d")
// Get the intersect of the potential columns and the actual columns,
// we turn the array of strings into column objects
// Finally turn the result into a vararg (: _*)
df.select(potential_columns.intersect(df.columns).map(df(_)): _*).show
Alas, this will not work for your inner object scenario above. You will need to look at the schema for that.
I'm going to change your potential_columns to fully qualified column names:
val potential_columns = Seq("a.b", "a.c", "a.d")
// Our object model
case class Document( a: String, b: String, c: String)
case class Document2( a: Document, b: String, c: String)
// And some data...
val df = sc.parallelize(Seq(Document2(Document("a", "b", "c"), "b2", "c2")), 2).toDF
// We go through each of the fields in the schema.
// For StructTypes we return an array of parentName.fieldName
// For everything else we return an array containing just the field name
// We then flatten the complete list of field names
// Then we intersect that with our potential_columns leaving us just a list of column we want
// we turn the array of strings into column objects
// Finally turn the result into a vararg (: _*)
df.select(df.schema.map(a => a.dataType match {
  case s: org.apache.spark.sql.types.StructType => s.fieldNames.map(x => a.name + "." + x)
  case _ => Array(a.name)
}).flatMap(x => x).intersect(potential_columns).map(df(_)): _*).show
This only goes one level deep, so to make it generic you would have to do more work.
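One way to make this generic is to walk the schema recursively. The following is only a sketch under a few assumptions: flattenFieldNames and sampleSchema are hypothetical names introduced here, the helper descends into StructType fields only, and it does not look inside arrays or maps.

```scala
import org.apache.spark.sql.types._

// Hypothetical helper: collect fully qualified field names such as "a.c.d".
// Only StructType fields are descended into; arrays and maps are treated as leaves.
def flattenFieldNames(schema: StructType, prefix: String = ""): Seq[String] =
  schema.fields.toSeq.flatMap { field =>
    val name = if (prefix.isEmpty) field.name else s"$prefix.${field.name}"
    field.dataType match {
      case nested: StructType => name +: flattenFieldNames(nested, name)
      case _                  => Seq(name)
    }
  }

// Illustrative schema (not from the question): a struct nested two levels deep.
val sampleSchema = StructType(Seq(
  StructField("a", StructType(Seq(
    StructField("b", IntegerType),
    StructField("c", StructType(Seq(StructField("d", StringType))))
  ))),
  StructField("e", StringType)
))

val sampleNames = flattenFieldNames(sampleSchema)
```

With such a list, the select above could probably be written as df.select(flattenFieldNames(df.schema).intersect(potential_columns).map(df(_)): _*).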

In PySpark you can simply run
'field' in df.columns

If you parse your JSON with an explicit schema when you load it, you don't need to check for the column: if it's not in the JSON source, it will appear as a null column.
val schemaJson = """
{
  "type": "struct",
  "fields": [
    {
      "name": "field1",
      "type": "string",
      "nullable": true,
      "metadata": {}
    },
    {
      "name": "field2",
      "type": "string",
      "nullable": true,
      "metadata": {}
    }
  ]
}
"""
val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType]
val djson =
.schema(schema )
.option("badRecordsPath", readExceptionPath)

def hasColumn(df: org.apache.spark.sql.DataFrame, colName: String) =
Use the function above to check whether a column exists, including nested column names.

In PySpark, df.columns gives you a list of columns in the dataframe, so
"colName" in df.columns
would return True or False. Give that a try. Good luck!

For nested columns you can use


ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast

I am using an Aggregator to apply some custom merge on a DataFrame after grouping its records by their primary key:
case class Player(
  pk: String,
  ts: String,
  first_name: String,
  date_of_birth: String
)
case class PlayerProcessed(
  var ts: String,
  var first_name: String,
  var date_of_birth: String
)
// Custom Aggregator - this is just for the example, the actual one is more complex
object BatchDedupe extends Aggregator[Player, PlayerProcessed, PlayerProcessed] {
  def zero: PlayerProcessed = PlayerProcessed("0", null, null)
  def reduce(bf: PlayerProcessed, in: Player): PlayerProcessed = {
    bf.ts = in.ts
    bf.first_name = in.first_name
    bf.date_of_birth = in.date_of_birth
    bf
  }
  def merge(bf1: PlayerProcessed, bf2: PlayerProcessed): PlayerProcessed = {
    bf1.ts = bf2.ts
    bf1.first_name = bf2.first_name
    bf1.date_of_birth = bf2.date_of_birth
    bf1
  }
  def finish(reduction: PlayerProcessed): PlayerProcessed = reduction
  def bufferEncoder: Encoder[PlayerProcessed] = Encoders.product
  def outputEncoder: Encoder[PlayerProcessed] = Encoders.product
}
val ply1 = Player("12121212121212", "10000001", "Rogger", "1980-01-02")
val ply2 = Player("12121212121212", "10000002", "Rogg", null)
val ply3 = Player("12121212121212", "10000004", null, "1985-01-02")
val ply4 = Player("12121212121212", "10000003", "Roggelio", "1982-01-02")
val seq_users = sc.parallelize(Seq(ply1, ply2, ply3, ply4)).toDF.as[Player]
val grouped = seq_users.groupByKey(_.pk)
val non_sorted = grouped.agg(BatchDedupe.toColumn.name("deduped"))
This returns:
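+--------------+--------------------------------+
|key           |deduped                         |
+--------------+--------------------------------+
|12121212121212|{10000003, Roggelio, 1982-01-02}|
+--------------+--------------------------------+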
Now, I would like to order the records by ts before aggregating them. From here I understand that .sortBy("ts") does not guarantee the order after the .groupByKey(_.pk), so I tried to apply the .sortBy between the .groupByKey and the .agg.
The output of the .groupByKey(_.pk) is a KeyValueGroupedDataset[String, Player], whose second element is an Iterator. So to apply some sorting logic there I convert it into a Seq:
val sorted = grouped.mapGroups{case(k, iter) => (k, iter.toSeq.sortBy(_.ts))}.agg(BatchDedupe.toColumn.name("deduped"))
However, the output of .mapGroups after adding the sorting logic is a Dataset[(String, Seq[Player])]. So when I try to invoke the .agg function on it I am getting the following exception:
Caused by: ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to $line050e0d37885948cd91f7f7dd9e3b4da9311.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Player
How could I convert back the output of my .mapGroups(...) into a KeyValueGroupedDataset[String,Player]?
I tried to cast back to Iterator as follows:
val sorted = grouped.mapGroups{case(k, iter) => (k, iter.toSeq.sortBy(_.ts).toIterator)}.agg(BatchDedupe.toColumn.name("deduped"))
But this approach produced the following exception:
UnsupportedOperationException: No Encoder found for Iterator[Player]
- field (class: "scala.collection.Iterator", name: "_2")
- root class: "scala.Tuple2"
How else can I add the sort logic between the .groupByKey and .agg methods?
Based on the discussion above, the purpose of the Aggregator is to get the latest field values per Player by ts ignoring null values.
This can be achieved fairly easily by aggregating each field individually with max_by. With that, there is no need for a custom Aggregator or a mutable aggregation buffer.
import org.apache.spark.sql.functions._
val players: Dataset[Player] = ...
// aggregate all columns except the key individually by ts
// NULLs will be ignored (SQL standard)
val aggColumns = players.columns
.filterNot(_ == "pk")
.map(colName => expr(s"max_by($colName, if(isNotNull($colName), ts, null))").as(colName))
val aggregatedPlayers = players
  .groupBy("pk")
  .agg(aggColumns.head, aggColumns.tail: _*)
On the most recent versions of Spark you can also use the built-in max_by function:
import org.apache.spark.sql.functions._
val players: Dataset[Player] = ...
// aggregate all columns except the key individually by ts
// NULLs will be ignored (SQL standard)
val aggColumns = players.columns
.filterNot(_ == "pk")
.map(colName => max_by(col(colName), when(col(colName).isNotNull, col("ts"))).as(colName))
val aggregatedPlayers = players
  .groupBy("pk")
  .agg(aggColumns.head, aggColumns.tail: _*)
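To make the intended semantics concrete, here is a plain-Scala sketch (no Spark involved) of "for each field, take the value from the row with the greatest ts where that field is non-null". It only illustrates the logic under the assumption that ts values compare correctly as strings; latestNonNull, samplePlayers and latestPerPlayer are hypothetical names, and this is not how Spark evaluates max_by internally.

```scala
case class Player(pk: String, ts: String, first_name: String, date_of_birth: String)

// Same four records as in the question.
val samplePlayers = Seq(
  Player("12121212121212", "10000001", "Rogger", "1980-01-02"),
  Player("12121212121212", "10000002", "Rogg", null),
  Player("12121212121212", "10000004", null, "1985-01-02"),
  Player("12121212121212", "10000003", "Roggelio", "1982-01-02")
)

// Hypothetical helper: value of `field` from the row with the greatest ts
// among the rows where that field is non-null (ts is compared as a string).
def latestNonNull(rows: Seq[Player])(field: Player => String): Option[String] =
  rows.filter(row => field(row) != null).sortBy(_.ts).lastOption.map(field)

// pk -> (ts, first_name, date_of_birth), each field resolved independently.
val latestPerPlayer: Map[String, (Option[String], Option[String], Option[String])] =
  samplePlayers.groupBy(_.pk).map { case (pk, rows) =>
    (pk, (latestNonNull(rows)(_.ts),
          latestNonNull(rows)(_.first_name),
          latestNonNull(rows)(_.date_of_birth)))
  }
```

Note that each field is resolved on its own, so the merged record can combine values from different input rows: here first_name comes from ts 10000003 and date_of_birth from ts 10000004.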

How to use the functions.explode to flatten element in dataFrame

I've made this piece of code :
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
case class RawPanda(id: Long, zip: String, pt: String, happy: Boolean, attributes: Array[Double])
case class PandaPlace(name: String, pandas: Array[RawPanda])
object TestSparkDataFrame extends App {
  System.setProperty("hadoop.home.dir", "E:\\Programmation\\Libraries\\hadoop")
  val conf = new SparkConf().setAppName("TestSparkDataFrame").set("spark.driver.memory", "4g").setMaster("local[*]")
  val session = SparkSession.builder().config(conf).getOrCreate()
  import session.implicits._
  def createAndPrintSchemaRawPanda(session: SparkSession): DataFrame = {
    val newPanda = RawPanda(1, "M1B 5K7", "giant", true, Array(0.1, 0.1))
    val pandaPlace = PandaPlace("torronto", Array(newPanda))
    val df = session.createDataFrame(Seq(pandaPlace))
    df
  }
  val df2 = createAndPrintSchemaRawPanda(session)
  df2.show
+--------+--------------------+
|    name|              pandas|
+--------+--------------------+
|torronto|[[1,M1B 5K7,giant...|
+--------+--------------------+
  val pandaInfo = df2.explode(df2("pandas")) {
    case Row(pandas: Seq[Row]) =>
      pandas.map {
        case (Row(
          id: Long,
          zip: String,
          pt: String,
          happy: Boolean,
          attrs: Seq[Double])) => RawPanda(id, zip, pt, happy, attrs.toArray)
      }
  }
  pandaInfo.show
+--------+--------------------+---+-------+-----+-----+----------+
|    name|              pandas| id|    zip|   pt|happy|attributes|
+--------+--------------------+---+-------+-----+-----+----------+
|torronto|[[1,M1B 5K7,giant...|  1|M1B 5K7|giant| true|[0.1, 0.1]|
+--------+--------------------+---+-------+-----+-----+----------+
}
The problem is that the explode function as I used it is deprecated, so I would like to recalculate the pandaInfo dataframe using the method advised in the warning:
use flatMap() or select() with functions.explode() instead
But then when I do :
val pandaInfo = df2.select(functions.explode(df2("pandas")))
I obtain the same result as I had in df2.
I don't know how to proceed to use flatMap or functions.explode.
How could I use flatMap or functions.explode to obtain the result that I want (the one in pandaInfo)?
I've seen this post and this other one but none of them helped me.
Calling select with explode function returns a DataFrame where the Array pandas is "broken up" into individual records; Then, if you want to "flatten" the structure of the resulting single "RawPanda" per record, you can select the individual columns using a dot-separated "route":
val pandaInfo2 = df2.select($"name", explode($"pandas") as "pandas")
  .select($"name", $"pandas",
    $"pandas.id" as "id",
    $"pandas.zip" as "zip",
    $"pandas.pt" as "pt",
    $"pandas.happy" as "happy",
    $"pandas.attributes" as "attributes"
  )
A less verbose version of the exact same operation would be:
import org.apache.spark.sql.Encoders // going to use this to "encode" case class into schema
val pandaColumns = Encoders.product[RawPanda].schema.fieldNames
val pandaInfo3 = df2.select($"name", explode($"pandas") as "pandas")
  .select(Seq($"name", $"pandas") ++ pandaColumns.map(f => $"pandas.$f" as f): _*)
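Since the question also asks about flatMap, here is a plain-Scala sketch of the same reshaping on ordinary collections: one output row per panda, with the parent's name copied onto each row. PandaRow, samplePlaces and flatPandas are hypothetical names introduced for the illustration. With session.implicits._ in scope, a lambda of the same shape could presumably be passed to Dataset.flatMap on df2.as[PandaPlace], but that variant is untested here.

```scala
case class RawPanda(id: Long, zip: String, pt: String, happy: Boolean, attributes: Array[Double])
case class PandaPlace(name: String, pandas: Array[RawPanda])

// Hypothetical flattened row type: the place name plus the panda's own fields.
case class PandaRow(name: String, id: Long, zip: String, pt: String,
                    happy: Boolean, attributes: Array[Double])

val samplePlaces = Seq(
  PandaPlace("torronto", Array(RawPanda(1, "M1B 5K7", "giant", true, Array(0.1, 0.1))))
)

// flatMap emits one PandaRow per element of each place's pandas array,
// which is the collection-level analogue of explode followed by select.
val flatPandas: Seq[PandaRow] = samplePlaces.flatMap { place =>
  place.pandas.toSeq.map { p =>
    PandaRow(place.name, p.id, p.zip, p.pt, p.happy, p.attributes)
  }
}
```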

Filtering in Scala

So suppose I have the following data (only the first few rows, this data covers an entire year) -
(2014-08-31T00:05:00.000+01:00, John)
(2014-08-31T00:11:00.000+01:00, Sarah)
(2014-08-31T00:12:00.000+01:00, George)
(2014-08-31T00:05:00.000+01:00, John)
(2014-09-01T00:05:00.000+01:00, Sarah)
(2014-09-01T00:05:00.000+01:00, George)
(2014-09-01T00:05:00.000+01:00, Jason)
I would like to filter the data so that I only see what the names are for a specific date (say, 2014-09-05). I've tried doing this using the filter function in Scala but I keep receiving the following error -
error: value xxxx is not a member of (org.joda.time.DateTime, String)
Is there another way of doing this?
The filter method takes a function, called a predicate, that takes as parameter an element of your (I'm assuming) RDD, and returns a Boolean.
The returned RDD will keep only the rows for which the predicate evaluates to true.
In your case, it seems that what you want is something like
rdd.filter {
  case (date, _) => date.withTimeAtStartOfDay() == new DateTime("2017-03-31")
}
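The error in the question ("value xxxx is not a member of (DateTime, String)") typically means a method was called on the whole tuple rather than on its first element. Below is a minimal, Spark-free sketch of a predicate that destructures the pair first. It assumes java.time instead of Joda-Time, and sampleEvents and namesOn are hypothetical names; the same pattern-matching shape should carry over to RDD.filter.

```scala
import java.time.{LocalDate, OffsetDateTime}

// Sample rows from the question, parsed into (timestamp, name) pairs.
val sampleEvents: List[(OffsetDateTime, String)] = List(
  "2014-08-31T00:05:00.000+01:00" -> "John",
  "2014-08-31T00:11:00.000+01:00" -> "Sarah",
  "2014-08-31T00:12:00.000+01:00" -> "George",
  "2014-09-01T00:05:00.000+01:00" -> "Sarah",
  "2014-09-01T00:05:00.000+01:00" -> "George",
  "2014-09-01T00:05:00.000+01:00" -> "Jason"
).map { case (ts, name) => (OffsetDateTime.parse(ts), name) }

// Destructure each pair, compare only the calendar date (in the
// timestamp's own offset), and keep the name when it matches.
def namesOn(day: LocalDate, rows: List[(OffsetDateTime, String)]): List[String] =
  rows.collect { case (ts, name) if ts.toLocalDate == day => name }

val namesOnSept1 = namesOn(LocalDate.parse("2014-09-01"), sampleEvents)
```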
I presume from the tag your question is in the context of Spark and not pure Scala. Given that, you could filter a dataframe on a date and get the associated name(s) like this:
import org.apache.spark.sql.functions._
import sparkSession.implicits._
("2014-08-31T00:05:00.000+01:00", "John"),
("2014-08-31T00:11:00.000+01:00", "Sarah")
.toDF("date", "name")
Note that the Date above is java.sql.Date.
Here's a function that takes a date, a list of datetime-name pairs, and returns a list of names for the date:
def getNames(d: String, l: List[(String, String)]): List[String] = {
  // capture everything before the 'T', i.e. the date part of the timestamp
  val date = """^([^T]*).*""".r
  val dateMap = l.map {
    case (x, y) => ( x match { case date(z) => z }, y )
  }
    .groupBy(_._1)
    .map { case (day, pairs) => day -> pairs.map(_._2) }
  dateMap.getOrElse(d, List[String]())
}
val list = List(
  ("2014-08-31T00:05:00.000+01:00", "John"),
  ("2014-08-31T00:11:00.000+01:00", "Sarah"),
  ("2014-08-31T00:12:00.000+01:00", "George"),
  ("2014-08-31T00:05:00.000+01:00", "John"),
  ("2014-09-01T00:05:00.000+01:00", "Sarah"),
  ("2014-09-01T00:05:00.000+01:00", "George"),
  ("2014-09-01T00:05:00.000+01:00", "Jason")
)
getNames("2014-09-01", list)
res1: List[String] = List(Sarah, George, Jason)
val dateTimeStringZero = "2014-08-12T00:05:00.000+01:00"
val dateTimeOne:DateTime = org.joda.time.format.ISODateTimeFormat.dateTime.withZoneUTC.parseDateTime(dateTimeStringZero)
import java.text.SimpleDateFormat
val df = new DateTime(new SimpleDateFormat("yyyy-MM-dd").parse("2014-08-12"))

How to iterate scala wrappedArray? (Spark)

I perform the following operations:
val tempDict = sqlContext.sql("""select words.pName_token, collect_set(words.pID) as docids
  from words
  group by words.pName_token""").toDF()
val wordDocs = tempDict.filter(newDict("pName_token")===word)
val listDocs = wordDocs.rdd.map(t => t(1)).collect()
listDocs: Array[Any] = Array(WrappedArray(123, 234, 205876618, 456))
My question is how do I iterate over this wrapped array or convert this into a list?
The options I get for the listDocs are apply, asInstanceOf, clone, isInstanceOf, length, toString, and update.
How do I proceed?
Here is one way to solve this.
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import scala.collection.mutable.WrappedArray
val data = Seq((Seq(1,2,3),Seq(4,5,6),Seq(7,8,9)))
val df = sqlContext.createDataFrame(data)
val first = df.first
// use getAs to recover the array's element type
val mapped = first.getAs[WrappedArray[Int]](0)
// now we can use it like a normal collection
// collect the rows and pattern match on the array columns
val rows = df.collect.map {
  case Row(a: Seq[Any], b: Seq[Any], c: Seq[Any]) =>
    (a, b, c)
}
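For the Array[Any] from the question specifically, a pattern match on the general Seq type is usually enough, because WrappedArray is a Seq. The sketch below uses a plain Seq as a stand-in for what Spark returns, so it runs without Spark; flattenedDocs and docIds are hypothetical names, and the element type Int is an assumption based on the values shown.

```scala
// Stand-in for the collected result; Spark would hand back a WrappedArray here.
val listDocs: Array[Any] = Array(Seq(123, 234, 205876618, 456))

// Match on scala.collection.Seq so both immutable and mutable sequences
// (including WrappedArray) are unwrapped into their elements.
val flattenedDocs: List[Any] = listDocs.toList.flatMap {
  case xs: scala.collection.Seq[_] => xs.toList
  case other                       => List(other)
}

// Keep only the elements that really are Ints, giving a typed list to iterate over.
val docIds: List[Int] = flattenedDocs.collect { case n: Int => n }
```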

Pass array as an UDF parameter in Spark SQL

I'm trying to transform a dataframe via a function that takes an array as a parameter. My code looks something like this:
def getCategory(categories:Array[String], input:String): String = {
val myArray = Array("a", "b", "c")
val myCategories =udf(getCategory _ )
val df = sqlContext.parquetFile("myfile.parquet")
val df1 = df.withColumn("newCategory", myCategories(lit(myArray), col("myInput")))
However, lit doesn't like arrays and this script errors. I tried defining a new partially applied function and then the udf after that:
val newFunc = getCategory(myArray, _:String)
val myCategories = udf(newFunc)
val df1 = df.withColumn("newCategory", myCategories(col("myInput")))
This doesn't work either: I get a NullPointerException, and it appears myArray is not being recognized. Any ideas on how I can pass an array as a parameter to a function applied to a dataframe?
On a separate note, any explanation as to why doing something simple like using a function on a dataframe is so complicated (define function, redefine it as UDF, etc, etc)?
Most likely not the prettiest solution but you can try something like this:
def getCategory(categories: Array[String]) = {
  udf((input:String) => categories(input.toInt))
}
df.withColumn("newCategory", getCategory(myArray)(col("myInput")))
You could also try an array of literals:
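// Spark passes array columns to Scala UDFs as Seq, so the parameter is a Seq[String]
val getCategory = udf(
  (input:String, categories: Seq[String]) => categories(input.toInt))
df.withColumn(
  "newCategory", getCategory($"myInput", array(myArray.map(lit(_)): _*)))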
On a side note using Map instead of Array is probably a better idea:
def mapCategory(categories: Map[String, String], default: String) = {
  udf((input:String) => categories.getOrElse(input, default))
}
val myMap = Map[String, String]("1" -> "a", "2" -> "b", "3" -> "c")
df.withColumn("newCategory", mapCategory(myMap, "foo")(col("myInput")))
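To see why the Map version tends to be more forgiving, here is a Spark-free sketch of the two lookup functions that the UDFs above wrap. fromArray and fromMap are hypothetical names; the point is only that indexing an array throws on unexpected input, while Map.getOrElse falls back to the default.

```scala
import scala.util.Try

val myArray = Array("a", "b", "c")
val myMap = Map("1" -> "a", "2" -> "b", "3" -> "c")

// Array lookup: parses the input as a zero-based index, so a non-numeric
// or out-of-range input throws.
def fromArray(categories: Array[String])(input: String): String =
  categories(input.toInt)

// Map lookup: unknown keys fall back to the supplied default.
def fromMap(categories: Map[String, String], default: String)(input: String): String =
  categories.getOrElse(input, default)

val arrayHit = fromArray(myArray)("1")          // index 1 of the array
val arrayMiss = Try(fromArray(myArray)("7"))    // captured failure instead of a crash
val mapHit = fromMap(myMap, "foo")("1")
val mapMiss = fromMap(myMap, "foo")("7")
```

Inside a UDF, an exception like the one captured in arrayMiss would fail the whole Spark task, which is one reason to prefer the Map with a default.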
Since Spark 1.5.0 you can also use an array function:
import org.apache.spark.sql.functions.array
val colArray = array(myArray map(lit _): _*)
myCategories(lit(colArray), col("myInput"))
See also Spark UDF with varargs