Exception handling in a UDF: Spark 1.6 - Scala

def parse_values(value: String) = {
  val values = value.split(",").map(_.trim)
  values.foldLeft(Array[(Int, Double)]()) {
    case (acc, present) =>
      val Array(k, v) = {
        present.split(",")(0).split(":") match {
          case Array(_) => Array("0", "0.0")
          case arr => arr
        }
      }
      acc :+ (k.trim.toInt, v.trim.toDouble)
  }
}
This function parses a column of strings into an array of keys and values, e.g. "50:63.25,100:58.38" into [[50,63.25], [100,58.38]]. My UDF thus creates a wrapped array of struct elements of Int and Double:
| |-- element: struct (containsNull = true)
| | |-- _1: integer (nullable = false)
| | |-- _2: double (nullable = false)
In some cases the input string is not correctly formatted and I get a java.lang.NumberFormatException, because k.trim.toInt cannot convert dirty data such as ".01-4.1293", which is one of the offending strings in a huge dataset. Can anyone help me with this issue?
I would like to return an empty array, or an array with [0,0.0], when an exception occurs. Any suggestions?

You can use the Try class.
Instead of
(k.trim.toInt, v.trim.toDouble)
wrap each conversion in a Try with getOrElse:
(Try(k.trim.toInt).getOrElse(0), Try(v.trim.toDouble).getOrElse(0.0))
It returns the parsed value on success and the default of your choice on failure.
Quick test:
scala> import scala.util.Try
scala> val invalid: String = ".01-4.1293"
scala> val valid: String = "56"
scala> Try(invalid.toInt).getOrElse(0)
res19: Int = 0
scala> Try(valid.toInt).getOrElse(0)
res20: Int = 56
Applied to your whole function:
import scala.util.Try
def parse_values(value: String) = {
  val values = value.split(",").map(_.trim)
  values.foldLeft(Array[(Int, Double)]()) {
    case (acc, present) =>
      val Array(k, v) = {
        present.split(",")(0).split(":") match {
          case Array(_) => Array("0", "0.0")
          case arr => arr
        }
      }
      // each conversion falls back to its default instead of throwing
      acc :+ ((Try(k.trim.toInt).getOrElse(0), Try(v.trim.toDouble).getOrElse(0.0)))
  }
}
You can find more information about functional error handling and the Try class here.
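The answer above substitutes (0, 0.0) field by field. For the other option in the question, returning an empty array as soon as any pair is malformed, one approach is to wrap the whole parse in a single Try. This is a minimal sketch in plain Scala (no Spark needed to try it out); ParseValuesOrEmpty and parseValuesOrEmpty are hypothetical names. A null input also yields an empty array, because the resulting NullPointerException is caught as well.

```scala
import scala.util.Try

object ParseValuesOrEmpty {
  // Hypothetical variant of parse_values: returns an EMPTY array when any
  // "key:value" pair is malformed, instead of defaulting each field to 0 / 0.0.
  def parseValuesOrEmpty(value: String): Array[(Int, Double)] =
    Try {
      value.split(",").map(_.trim).filter(_.nonEmpty).map { pair =>
        pair.split(":") match {
          // toInt / toDouble throw NumberFormatException on dirty data
          case Array(k, v) => (k.trim.toInt, v.trim.toDouble)
          case _ => throw new IllegalArgumentException(s"not a key:value pair: $pair")
        }
      }
    }.getOrElse(Array.empty[(Int, Double)]) // any non-fatal failure => empty array
}
```

In Spark it can be wrapped as usual, e.g. val parseUdf = udf(ParseValuesOrEmpty.parseValuesOrEmpty _). Choosing between this and the per-field default is a data-quality decision: the empty array drops the whole record's values, while the Try-per-field version keeps the valid pairs.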


Apache Spark Null Value when casting incompatible DecimalType vs ClassCastException

Casting DecimalType(10,5), e.g. 99999.99999, to DecimalType(5,4) in Apache Spark silently returns null.
Is it possible to change this behavior so that Spark throws an exception (for example some CastException) in this case and fails the job, instead of silently returning null?
According to the Spark source on GitHub, https://github.com/apache/spark/blob/3ab96d7acf870e53c9016b0b63d0b328eec23bed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala#L499
/**
 * Change the precision / scale in a given decimal to those set in decimalType (if any),
 * returning null if it overflows or modifying value in-place and returning it if successful.
 *
 * NOTE: this modifies value in-place, so don't call it on external data.
 */
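The overflow rule in that comment can be made concrete with plain Scala BigDecimal arithmetic. The sketch below is an approximation, not Spark's own code: it assumes the value is first rounded to the target scale with HALF_UP and overflows if the result then needs more than `precision` significant digits. Under that rule 99999.99999 rounds to 100000.0000, which needs 10 digits, while DecimalType(5,4) allows only 5 (at most one digit before the decimal point), hence the null. DecimalFitSketch, fitsDecimal and castOrFail are hypothetical names.

```scala
import scala.math.BigDecimal.RoundingMode

object DecimalFitSketch {
  // Rounds `value` to `scale` fractional digits (HALF_UP assumed) and checks
  // whether the result fits in `precision` significant digits in total.
  def fitsDecimal(value: BigDecimal, precision: Int, scale: Int): Boolean =
    value.setScale(scale, RoundingMode.HALF_UP).precision <= precision

  // Hypothetical strict variant: throw instead of returning null on overflow.
  def castOrFail(value: BigDecimal, precision: Int, scale: Int): BigDecimal =
    if (fitsDecimal(value, precision, scale)) value.setScale(scale, RoundingMode.HALF_UP)
    else throw new ArithmeticException(
      s"$value cannot be represented as DecimalType($precision,$scale)")
}
```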
Another thread, Spark: cast decimal without changing nullable property of column, also suggests there may be no direct way to make the code fail when a value cannot be cast.
So you could check the cast column for null values and add logic that fails the job if any appear.
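A minimal sketch of that null check, assuming a DataFrame whose column is being cast in place; StrictDecimalCast, castDecimalOrFail and the temporary column name are hypothetical. It keeps the source and cast values side by side and throws if a non-null input became null, which is exactly the overflow case. The count triggers an extra Spark job, so caching the input first may be worthwhile. On Spark 3.0 or later it may be simpler to enable ANSI mode (spark.sql.ansi.enabled=true), under which an overflowing cast raises an error instead of returning null; check the ANSI compliance notes for your version.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DecimalType

object StrictDecimalCast {
  // Casts `column` to DecimalType(precision, scale) and fails if any non-null
  // source value turned into null, i.e. the cast overflowed.
  def castDecimalOrFail(df: DataFrame, column: String, precision: Int, scale: Int): DataFrame = {
    val tmp = s"${column}__cast_check" // hypothetical helper column name
    val withCast = df.withColumn(tmp, col(column).cast(DecimalType(precision, scale)))
    // rows where the input was present but the cast produced null
    val overflowed = withCast.filter(col(column).isNotNull && col(tmp).isNull).count()
    if (overflowed > 0)
      throw new ArithmeticException(
        s"$overflowed value(s) in '$column' do not fit DecimalType($precision,$scale)")
    // overwrite the original column (keeps its position) and drop the helper
    withCast.withColumn(column, col(tmp)).drop(tmp)
  }
}
```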
As I mentioned in my comment above, you can try to achieve what you want with a UserDefinedFunction. I'm facing the same problem and managed to solve mine with a UDF. I wanted to cast a column to DoubleType, but I don't know its type upfront, and I want my application to fail when the parsing fails, rather than silently produce null as you describe.
In the code below, the UDF takes a struct as its parameter and tries to parse the only value in that struct to a double. If this fails, it throws an exception, which causes the job to fail.
import org.apache.spark.SparkException
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, struct, udf}
import spark.implicits._
val cast_to_double = udf((number: Row) => {
  try {
    number.get(0) match {
      case s: String => s.toDouble
      case d: Double => d
      case l: Long => l.toDouble
      case i: Int => i.toDouble
      case _ => throw new NumberFormatException
    }
  } catch {
    case _: NumberFormatException => throw new IllegalArgumentException("Can't parse this so called number of yours.")
  }
})
try {
  val intDF = List(1).toDF("something")
  val secondIntDF = intDF.withColumn("something_else", cast_to_double(struct(col("something"))))
  secondIntDF.printSchema()
  secondIntDF.show()
  val stringIntDF = List("1").toDF("something")
  val secondStringIntDF = stringIntDF.withColumn("something_else", cast_to_double(struct(col("something"))))
  secondStringIntDF.printSchema()
  secondStringIntDF.show()
  val stringDF = List("string").toDF("something")
  val secondStringDF = stringDF.withColumn("something_else", cast_to_double(struct(col("something"))))
  secondStringDF.printSchema()
  secondStringDF.show() // the UDF throws here, failing the job
} catch {
  case se: SparkException => println(se.getCause.getMessage)
}
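root
 |-- something: integer (nullable = false)
 |-- something_else: double (nullable = false)
+---------+--------------+
|something|something_else|
+---------+--------------+
|        1|           1.0|
+---------+--------------+
root
 |-- something: string (nullable = true)
 |-- something_else: double (nullable = false)
+---------+--------------+
|something|something_else|
+---------+--------------+
|        1|           1.0|
+---------+--------------+
root
 |-- something: string (nullable = true)
 |-- something_else: double (nullable = false)
Can't parse this so called number of yours.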

How to create a Row from a given case class?

Imagine that you have the following case classes:
case class B(key: String, value: Int)
case class A(name: String, data: B)
Given an instance of A, how do I create a Spark Row? e.g.
val a = A("a", B("b", 0))
val row = ???
NOTE: given row, I need to be able to get the data with:
val name: String = row.getAs[String]("name")
val b: Row = row.getAs[Row]("data")
The following seems to match what you're looking for.
scala> spark.version
res0: String = 2.3.0
scala> val a = A("a", B("b", 0))
a: A = A(a,B(b,0))
import org.apache.spark.sql.Encoders
val schema = Encoders.product[A].schema
scala> schema.printTreeString
|-- name: string (nullable = true)
|-- data: struct (nullable = true)
| |-- key: string (nullable = true)
| |-- value: integer (nullable = false)
val values = a.productIterator.toSeq.toArray
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
val row: Row = new GenericRowWithSchema(values, schema)
scala> val name: String = row.getAs[String]("name")
name: String = a
// the following won't work since B =!= Row
scala> val b: Row = row.getAs[Row]("data")
java.lang.ClassCastException: B cannot be cast to org.apache.spark.sql.Row
... 55 elided
Very short, but probably not the fastest, as it first creates a DataFrame and then collects it again:
import session.implicits._
val row = Seq(a).toDF().first()
@Jacek Laskowski's answer is great!
To complete it, here is some syntactic sugar:
val row = Row(a.productIterator.toSeq: _*)
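And a recursive method if you happen to have nested case classes:
import org.apache.spark.sql.Row
def productToRow(product: Product): Row = {
  val sequence = product.productIterator.toSeq.map {
    // note: Lists, Options and tuples are Products too, so they are also turned into Rows
    case product: Product => productToRow(product)
    case e => e
  }
  Row(sequence: _*)
}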
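Neither Row(...) nor productToRow attaches a schema, so row.getAs[String]("name") fails on the result (a Row without a schema does not support fieldIndex). Combining the recursion with the GenericRowWithSchema approach from the first answer gives nested rows that support by-name access at every level. A sketch, assuming nested case classes map one-to-one onto nested StructTypes (Options, collections and maps of case classes are not handled); ProductRows and productToRowWithSchema are hypothetical names, and GenericRowWithSchema is a catalyst-internal class.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.StructType

case class B(key: String, value: Int)
case class A(name: String, data: B)

object ProductRows {
  // Recursively converts a (possibly nested) case class into a Row that carries
  // its schema, pairing each product element with the matching StructField.
  def productToRowWithSchema(product: Product, schema: StructType): Row = {
    val values = product.productIterator.toSeq.zip(schema.fields).map {
      case (p: Product, field) if field.dataType.isInstanceOf[StructType] =>
        productToRowWithSchema(p, field.dataType.asInstanceOf[StructType])
      case (other, _) => other
    }
    new GenericRowWithSchema(values.toArray, schema)
  }
}
```

Usage, with the schema derived as in the first answer: ProductRows.productToRowWithSchema(a, Encoders.product[A].schema); afterwards row.getAs[Row]("data").getAs[String]("key") resolves by name.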
I don't think there exists a public API that can do it directly. Internally, Spark uses the Encoder.toRow method to convert objects to org.apache.spark.sql.catalyst.expressions.UnsafeRow, but this method is private. You could try to:
1. Obtain an Encoder for the class:
val enc: Encoder[A] = ExpressionEncoder()
2. Use reflection to access the toRow method and make it accessible.
3. Call it to convert the object to an UnsafeRow.
4. Obtain a RowEncoder for the expected schema (enc.schema).
5. Convert the UnsafeRow to a Row.
I haven't tried this, so I cannot guarantee that it will work.

Spark UDAF: How to get value from input by column field name in UDAF (User-Defined Aggregation Function)?

I am trying to use a Spark UDAF to summarize two existing columns into a new column. Most of the Spark UDAF tutorials out there use indices to get the value of each column of the input Row, like this:
input.getAs[String](1)
which is used in my update method (override def update(buffer: MutableAggregationBuffer, input: Row): Unit). It works in my case as well. However, I want to use the field name of the column to get the value, like this:
input.getAs[String](ColumnNames.BehaviorType)
where ColumnNames.BehaviorType is a String defined in an object:
/**
 * Column names in the original dataset
 */
object ColumnNames {
  val JobSeekerID = "JobSeekerID"
  val JobID = "JobID"
  val Date = "Date"
  val BehaviorType = "BehaviorType"
}
This time it does not work. I got the following exception:
java.lang.IllegalArgumentException: Field "BehaviorType" does not exist.
  at ...
  at org.apache.spark.sql.Row$class.getAs(Row.scala:333)
  at ...
Some relevant code segments:
This is part of my UDAF:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
class UserBehaviorRecordsUDAF extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(
    StructField("JobID", IntegerType) ::
    StructField("BehaviorType", StringType) :: Nil)
  // ... bufferSchema, dataType, deterministic, initialize, merge and evaluate omitted
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // println
    // println(bufferSchema.treeString)
    input.getAs[String](ColumnNames.BehaviorType) match { //ColumnNames.BehaviorType //1 //TODO WHY??
      case BehaviourTypes.viewed_job =>
        buffer(0) =
          buffer.getAs[Seq[Int]](0) :+ //Array[Int] //TODO WHY??
          input.getAs[Int](0) //ColumnNames.JobID
      case BehaviourTypes.bookmarked_job =>
        buffer(1) =
          buffer.getAs[Seq[Int]](1) :+ //Array[Int]
          input.getAs[Int](0) //ColumnNames.JobID
      case BehaviourTypes.applied_job =>
        buffer(2) =
          buffer.getAs[Seq[Int]](2) :+ //Array[Int]
          input.getAs[Int](0) //ColumnNames.JobID
    }
  }
}
The following is the part of the code that calls the UDAF:
val ubrUDAF = new UserBehaviorRecordsUDAF
val userProfileDF = userBehaviorDS
  // (grouping step not shown here)
  .agg(ubrUDAF(
    userBehaviorDS.col(ColumnNames.JobID), //userBehaviorDS.col(ColumnNames.JobID)
    userBehaviorDS.col(ColumnNames.BehaviorType) //userBehaviorDS.col(ColumnNames.BehaviorType)
  ).as("profile str"))
It seems the field names in the schema of the input Row are not passed into the UDAF. The schema of the input Row is:
root
 |-- input0: integer (nullable = true)
 |-- input1: string (nullable = true)
while my inputSchema is:
root
 |-- JobID: integer (nullable = true)
 |-- BehaviorType: string (nullable = true)
What is the problem with my code?
I also want to use the field names from my inputSchema in my update method to create maintainable code.
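import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
class MyUDAF extends UserDefinedAggregateFunction {
  // ... inputSchema, bufferSchema, dataType, etc. as in your UDAF
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    // re-attach the declared inputSchema so that getAs by field name works
    val inputWSchema = new GenericRowWithSchema(input.toSeq.toArray, inputSchema)
    // e.g. inputWSchema.getAs[String]("BehaviorType")
  }
}
Ultimately I switched to an Aggregator, which ran in half the time.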
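For reference, a typed Aggregator sidesteps the by-name lookup problem entirely, because fields are read from a case class rather than a Row. A sketch, assuming Spark 2.x or later, a Dataset of a hypothetical Behavior case class (Date omitted), and behavior-type strings "viewed_job", "bookmarked_job" and "applied_job" (the question's BehaviourTypes values are not shown). Behavior, Profile and UserBehaviorAgg are illustrative names, not the poster's code.

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical typed input record mirroring the question's columns.
case class Behavior(JobSeekerID: Int, JobID: Int, BehaviorType: String)

// Buffer and result: job IDs collected per behavior type.
case class Profile(viewed: Seq[Int], bookmarked: Seq[Int], applied: Seq[Int])

object UserBehaviorAgg extends Aggregator[Behavior, Profile, Profile] {
  def zero: Profile = Profile(Seq.empty, Seq.empty, Seq.empty)

  // Fields are accessed by name on a typed object: no index-vs-name issue.
  def reduce(p: Profile, b: Behavior): Profile = b.BehaviorType match {
    case "viewed_job"     => p.copy(viewed = p.viewed :+ b.JobID)
    case "bookmarked_job" => p.copy(bookmarked = p.bookmarked :+ b.JobID)
    case "applied_job"    => p.copy(applied = p.applied :+ b.JobID)
    case _                => p // ignore unknown behavior types
  }

  // Combines partial buffers from different partitions.
  def merge(x: Profile, y: Profile): Profile =
    Profile(x.viewed ++ y.viewed, x.bookmarked ++ y.bookmarked, x.applied ++ y.applied)

  def finish(p: Profile): Profile = p

  def bufferEncoder: Encoder[Profile] = Encoders.product[Profile]
  def outputEncoder: Encoder[Profile] = Encoders.product[Profile]
}
```

Usage on a typed Dataset: ds.groupByKey(_.JobSeekerID).agg(UserBehaviorAgg.toColumn), which yields a Dataset[(Int, Profile)]. The order of job IDs within each list is not guaranteed across partitions.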

Spark scala - Nested StructType conversion to Map

I am using Spark 1.6 in Scala.
I created an index in Elasticsearch with an object. The object "params" was created as a Map[String, Map[String, String]]. Example:
val params: Map[String, Map[String, String]] = Map("p1" -> Map("p1_detail" -> "table1"), "p2" -> Map("p2_detail" -> "table2", "p2_filter" -> "filter2"), "p3" -> Map("p3_detail" -> "table3"))
That gives me records that look like the following:
{
  "_index": "x",
  "_type": "1",
  "_id": "xxxxxxxxxxxx",
  "_score": 1,
  "_timestamp": 1506537199650,
  "_source": {
    "a": "toto",
    "b": "tata",
    "c": "description",
    "params": {
      "p1": {
        "p1_detail": "table1"
      },
      "p2": {
        "p2_detail": "table2",
        "p2_filter": "filter2"
      },
      "p3": {
        "p3_detail": "table3"
      }
    }
  }
}
Then I am trying to read the Elasticsearch index in order to update the values.
Spark reads the index with the following schema:
|-- a: string (nullable = true)
|-- b: string (nullable = true)
|-- c: string (nullable = true)
|-- params: struct (nullable = true)
| |-- p1: struct (nullable = true)
| | |-- p1_detail: string (nullable = true)
| |-- p2: struct (nullable = true)
| | |-- p2_detail: string (nullable = true)
| | |-- p2_filter: string (nullable = true)
| |-- p3: struct (nullable = true)
| | |-- p3_detail: string (nullable = true)
My problem is that the object is read as a struct. To manage and easily update the fields, I want a Map instead, as I am not very familiar with StructType.
I tried to get the object in a UDF as a Map but I have the following error:
User class threw exception: org.apache.spark.sql.AnalysisException: cannot resolve 'UDF(params)' due to data type mismatch: argument 1 requires map<string,map<string,string>> type, however, 'params' is of struct<p1:struct<p1_detail:string>,p2:struct<p2_detail:string,p2_filter:string>,p3:struct<p3_detail:string>> type.;
UDF code snippet:
val getSubField: Map[String, Map[String, String]] => String = (params: Map[String, Map[String, String]]) => {
  val return_string = params("p1").getOrElse("p1_detail", null.asInstanceOf[String])
  return_string
}
My question: how can I convert this struct to a Map? I saw the toMap method in the documentation but cannot figure out how to use it (I am not very familiar with implicit parameters), as I am a Scala beginner.
Thanks in advance,
I finally solved it as follows:
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf
def convertRowToMap[T](row: Row): Map[String, T] = {
  row.schema.fieldNames
    .filter(field => !row.isNullAt(row.fieldIndex(field))) // skip null fields
    .map(field => field -> row.getAs[T](field))
    .toMap
}
/* udf that converts Row to Map */
val rowToMap: Row => Map[String, Map[String, String]] = (row: Row) => {
  val mapTemp = convertRowToMap[Row](row)
  val mapToReturn = mapTemp.map { case (k, v) => k -> convertRowToMap[String](v) }
  mapToReturn
}
val udfrowToMap = udf(rowToMap)
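On Spark 2.0 or later (not the 1.6 used in the question), the same conversion can be done without a UDF by building the map from the schema with the built-in functions.map. A sketch, assuming every field of the struct column is itself a struct whose leaves can be cast to string; StructToMapSketch and structOfStructsToMap are hypothetical names. Unlike convertRowToMap above, null leaves are kept as null map values rather than dropped.

```scala
import org.apache.spark.sql.{Column, DataFrame, functions}
import org.apache.spark.sql.types.StructType

object StructToMapSketch {
  // Builds a map<string, map<string,string>> column from a struct-of-structs column.
  def structOfStructsToMap(df: DataFrame, structCol: String): Column = {
    val outer = df.schema(structCol).dataType.asInstanceOf[StructType]
    val entries: Seq[Column] = outer.fields.toSeq.flatMap { sub =>
      val inner = sub.dataType.asInstanceOf[StructType]
      // alternating key/value columns: "p1_detail", params.p1.p1_detail, ...
      val innerKVs: Seq[Column] = inner.fields.toSeq.flatMap { leaf =>
        Seq(functions.lit(leaf.name),
            functions.col(s"$structCol.${sub.name}.${leaf.name}").cast("string"))
      }
      Seq(functions.lit(sub.name), functions.map(innerKVs: _*))
    }
    functions.map(entries: _*)
  }
}
```

Usage: df.withColumn("params", StructToMapSketch.structOfStructsToMap(df, "params")).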
You can't declare the parameter type as a StructType object; declare it as Row instead.
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructType}
//Schema of parameter
def schema: StructType = (new StructType).add("p1", (new StructType).add("p1_detail", StringType))
  .add("p2", (new StructType).add("p2_detail", StringType).add("p2_filter", StringType))
  .add("p3", (new StructType).add("p3_detail", StringType))
//Not allowed: schema is a value, not a type, so this does not compile
//val extractVal: schema => collection.Map[Nothing, Nothing] = _.getMap(0)
// UDF example to process struct column
val extractVal: (Row) => collection.Map[Nothing, Nothing] = _.getMap(0)
// You would implement something similar
val getSubField: Row => String = (params: Row) => {
  val p1 = params.getAs[Row]("p1")
  p1.getAs[String]("p1_detail")
}
I hope this helps!

Dropping a nested column from Spark DataFrame

I have a DataFrame with the schema
|-- label: string (nullable = true)
|-- features: struct (nullable = true)
| |-- feat1: string (nullable = true)
| |-- feat2: string (nullable = true)
| |-- feat3: string (nullable = true)
While I am able to filter the data frame using
val data = rawData
  .filter( !(rawData("features.feat1") <=> "100") )
I am unable to drop the columns using
val data = rawData
  .drop("features.feat1")
Am I doing something wrong here? I also tried (unsuccessfully) doing drop(rawData("features.feat1")), though it does not make much sense to do so.
Thanks in advance,
It is just a programming exercise but you can try something like this:
import org.apache.spark.sql.{DataFrame, Column}
import org.apache.spark.sql.types.{StructType, StructField}
import org.apache.spark.sql.{functions => f}
import scala.util.Try
case class DFWithDropFrom(df: DataFrame) {
  def getSourceField(source: String): Try[StructField] = {
    Try(df.schema.fields.filter(_.name == source).head)
  }
  def getType(sourceField: StructField): Try[StructType] = {
    Try(sourceField.dataType.asInstanceOf[StructType])
  }
  def genOutputCol(names: Array[String], source: String): Column = {
    f.struct(names.map(x => f.col(source).getItem(x).alias(x)): _*)
  }
  def dropFrom(source: String, toDrop: Array[String]): DataFrame = {
    getSourceField(source)
      .flatMap(getType)
      .map(_.fieldNames.diff(toDrop))
      .map(genOutputCol(_, source))
      .map(df.withColumn(source, _))
      .getOrElse(df) // unknown or non-struct source: return df unchanged
  }
}
Example usage:
scala> case class features(feat1: String, feat2: String, feat3: String)
defined class features
scala> case class record(label: String, features: features)
defined class record
scala> val df = sc.parallelize(Seq(record("a_label", features("f1", "f2", "f3")))).toDF
df: org.apache.spark.sql.DataFrame = [label: string, features: struct<feat1:string,feat2:string,feat3:string>]
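scala> DFWithDropFrom(df).dropFrom("features", Array("feat1")).show
+-------+--------+
|  label|features|
+-------+--------+
|a_label| [f2,f3]|
+-------+--------+
scala> DFWithDropFrom(df).dropFrom("foobar", Array("feat1")).show
+-------+----------+
|  label|  features|
+-------+----------+
|a_label|[f1,f2,f3]|
+-------+----------+
scala> DFWithDropFrom(df).dropFrom("features", Array("foobar")).show
+-------+----------+
|  label|  features|
+-------+----------+
|a_label|[f1,f2,f3]|
+-------+----------+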
Add an implicit conversion and you're good to go.
This version allows you to remove nested columns at any level:
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructType, DataType}
/**
 * Various Spark utilities and extensions of DataFrame
 */
object DataFrameUtils {
  private def dropSubColumn(col: Column, colType: DataType, fullColName: String, dropColName: String): Option[Column] = {
    if (fullColName.equals(dropColName)) {
      None
    } else {
      colType match {
        case colType: StructType =>
          if (dropColName.startsWith(s"${fullColName}.")) {
            Some(struct(
              colType.fields
                .flatMap(f =>
                  dropSubColumn(col.getField(f.name), f.dataType, s"${fullColName}.${f.name}", dropColName) match {
                    case Some(x) => Some(x.alias(f.name))
                    case None => None
                  })
                : _*))
          } else {
            Some(col)
          }
        case other => Some(col)
      }
    }
  }
  protected def dropColumn(df: DataFrame, colName: String): DataFrame = {
    df.schema.fields
      .flatMap(f => {
        if (colName.startsWith(s"${f.name}.")) {
          dropSubColumn(col(f.name), f.dataType, f.name, colName) match {
            case Some(x) => Some((f.name, x))
            case None => None
          }
        } else {
          None
        }
      })
      .foldLeft(df.drop(colName)) {
        case (df, (colName, column)) => df.withColumn(colName, column)
      }
  }
  /**
   * Extended version of DataFrame that allows to operate on nested fields
   */
  implicit class ExtendedDataFrame(df: DataFrame) extends Serializable {
    /**
     * Drops nested field from DataFrame
     *
     * @param colName Dot-separated nested field name
     */
    def dropNestedColumn(colName: String): DataFrame = {
      DataFrameUtils.dropColumn(df, colName)
    }
  }
}
import DataFrameUtils._
Expanding on spektom's answer, with support for array types:
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DataType, StructType}
object DataFrameUtils {
  private def dropSubColumn(col: Column, colType: DataType, fullColName: String, dropColName: String): Option[Column] = {
    if (fullColName.equals(dropColName)) {
      None
    } else if (dropColName.startsWith(s"$fullColName.")) {
      colType match {
        case colType: StructType =>
          Some(struct(
            colType.fields
              .flatMap(f =>
                dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
                  case Some(x) => Some(x.alias(f.name))
                  case None => None
                })
              : _*))
        case colType: ArrayType =>
          colType.elementType match {
            case innerType: StructType =>
              Some(struct(
                innerType.fields
                  .flatMap(f =>
                    dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
                      case Some(x) => Some(x.alias(f.name))
                      case None => None
                    })
                  : _*))
            case _ => Some(col)
          }
        case other => Some(col)
      }
    } else {
      Some(col)
    }
  }
  protected def dropColumn(df: DataFrame, colName: String): DataFrame = {
    df.schema.fields
      .flatMap(f => {
        if (colName.startsWith(s"${f.name}.")) {
          dropSubColumn(col(f.name), f.dataType, f.name, colName) match {
            case Some(x) => Some((f.name, x))
            case None => None
          }
        } else {
          None
        }
      })
      .foldLeft(df.drop(colName)) {
        case (df, (colName, column)) => df.withColumn(colName, column)
      }
  }
  /**
   * Extended version of DataFrame that allows to operate on nested fields
   */
  implicit class ExtendedDataFrame(df: DataFrame) extends Serializable {
    /**
     * Drops nested field from DataFrame
     *
     * @param colName Dot-separated nested field name
     */
    def dropNestedColumn(colName: String): DataFrame = {
      DataFrameUtils.dropColumn(df, colName)
    }
  }
}
Expanding on mmendez.semantic's answer here, and accounting for the issues described in the sub-thread:
def dropSubColumn(col: Column, colType: DataType, fullColName: String, dropColName: String): Option[Column] = {
  if (fullColName.equals(dropColName)) {
    None
  } else if (dropColName.startsWith(s"$fullColName.")) {
    colType match {
      case colType: StructType =>
        Some(struct(
          colType.fields
            .flatMap(f =>
              dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
                case Some(x) => Some(x.alias(f.name))
                case None => None
              })
            : _*))
      case colType: ArrayType =>
        colType.elementType match {
          case innerType: StructType =>
            // we are potentially dropping a column from within a struct, that is itself inside an array
            // Spark has some very strange behavior in this case, which they insist is not a bug
            // see https://issues.apache.org/jira/browse/SPARK-31779 and associated comments
            // and also the thread here: https://stackoverflow.com/a/39943812/375670
            // this is a workaround for that behavior
            // first, get all struct fields
            val innerFields = innerType.fields
            // next, create a new type for all the struct fields EXCEPT the column that is to be dropped
            // we will need this later
            val preserveNamesStruct = ArrayType(StructType(
              innerFields.filterNot(f => s"$fullColName.${f.name}".equals(dropColName))
            ))
            // next, apply dropSubColumn recursively to build up the new values after dropping the column
            val filteredInnerFields = innerFields.flatMap(f =>
              dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
                case Some(x) => Some(x.alias(f.name))
                case None => None
              }
            )
            // finally, use arrays_zip to unwrap the arrays that were introduced by building up the new, filtered
            // struct in this way (see comments in SPARK-31779), and then cast to the StructType we created earlier
            // to get the original names back
            Some(arrays_zip(filteredInnerFields: _*).cast(preserveNamesStruct))
          case _ => Some(col)
        }
      case _ => Some(col)
    }
  } else {
    Some(col)
  }
}
def dropColumn(df: DataFrame, colName: String): DataFrame = {
  df.schema.fields.flatMap(f => {
    if (colName.startsWith(s"${f.name}.")) {
      dropSubColumn(col(f.name), f.dataType, f.name, colName) match {
        case Some(x) => Some((f.name, x))
        case None => None
      }
    } else {
      None
    }
  }).foldLeft(df.drop(colName)) {
    case (df, (colName, column)) => df.withColumn(colName, column)
  }
}
Usage in spark-shell:
// if defining the functions above in your spark-shell session, you first need imports
import org.apache.spark.sql._
import org.apache.spark.sql.types._
// now you can paste the function definitions
// create a deeply nested and complex JSON structure
val jsonData = """{
  "foo": "bar",
  "top": {
    "child1": 5,
    "child2": [
      {
        "child2First": "one",
        "child2Second": 2,
        "child2Third": -19.51
      }
    ],
    "child3": ["foo", "bar", "baz"],
    "child4": [
      {
        "child2First": "two",
        "child2Second": 3,
        "child2Third": 16.78
      }
    ]
  }
}"""
// read it into a DataFrame
val df = spark.read.option("multiline", "true").json(Seq(jsonData).toDS())
// remove a sub-column
val modifiedDf = dropColumn(df, "top.child2.child2First")
modifiedDf.printSchema()
root
|-- foo: string (nullable = true)
|-- top: struct (nullable = false)
| |-- child1: long (nullable = true)
| |-- child2: array (nullable = true)
| | |-- element: struct (containsNull = true)
| | | |-- child2Second: long (nullable = true)
| | | |-- child2Third: double (nullable = true)
| |-- child3: array (nullable = true)
| | |-- element: string (containsNull = true)
| |-- child4: array (nullable = true)
| | |-- element: struct (containsNull = true)
| | | |-- child2First: string (nullable = true)
| | | |-- child2Second: long (nullable = true)
| | | |-- child2Third: double (nullable = true)
modifiedDf.show(false)
|foo|top |
|bar|[5, [[2, -19.51]], [foo, bar, baz], [[two, 3, 16.78]]]|
For Spark 3.1+, you can use the dropFields method on struct-type columns:
An expression that drops fields in StructType by name. This is a no-op
if schema doesn't contain field name(s).
val df = sql("SELECT named_struct('feat1', 1, 'feat2', 2, 'feat3', 3) features")
val df1 = df.withColumn("features", $"features".dropFields("feat1"))
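dropFields works on any struct-typed Column, including the element variable inside the built-in higher-order function functions.transform (in the Scala API since Spark 3.0), so on Spark 3.1+ a field can also be removed from every struct inside an array column without extra libraries. A sketch; NativeDropFields and its method names are hypothetical.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, transform}

object NativeDropFields {
  // Drop `field` from a struct column (Spark 3.1+).
  def dropFromStruct(df: DataFrame, structCol: String, field: String): DataFrame =
    df.withColumn(structCol, col(structCol).dropFields(field))

  // Drop `field` from every struct element of an array-of-structs column.
  def dropFromArrayOfStructs(df: DataFrame, arrayCol: String, field: String): DataFrame =
    df.withColumn(arrayCol, transform(col(arrayCol), elem => elem.dropFields(field)))
}
```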
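Another (PySpark) way is to drop features.feat1 by creating features again without it. This variant applies when features is an array of structs, since arrays_zip only accepts array columns:
from pyspark.sql.functions import col, arrays_zip
df = (df
      .withColumn("features", arrays_zip("features.feat2", "features.feat3"))
      .withColumn("features", col("features").cast(schema)))
where schema is the new schema (excluding features.feat1):
from pyspark.sql.types import ArrayType, StructType, StructField, StringType
schema = ArrayType(StructType([
    StructField('feat2', StringType(), True),
    StructField('feat3', StringType(), True),
]))
If features is a plain struct, as in the question, use struct("features.feat2", "features.feat3") instead of arrays_zip; no cast is needed then.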
Following spektom's code snippet for Scala, I've created similar code in Java.
Since Java 8 streams have no foldLeft, I used forEachOrdered. This code is suitable for Spark 2.x (I'm using 2.1).
I also noticed that dropping a column and then adding it back with withColumn under the same name doesn't work, so I just replace the column, and that seems to work.
The code is not fully tested; I hope it works :-)
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.struct;
import java.util.Arrays;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
public class DataFrameUtils {
    public static Dataset<Row> dropNestedColumn(Dataset<Row> dataFrame, String columnName) {
        final DataFrameFolder dataFrameFolder = new DataFrameFolder(dataFrame);
        Arrays.stream(dataFrame.schema().fields())
            .flatMap( f -> {
                if (columnName.startsWith(f.name() + ".")) {
                    final Optional<Column> column = dropSubColumn(col(f.name()), f.dataType(), f.name(), columnName);
                    if (column.isPresent()) {
                        return Stream.of(new Tuple2<>(f.name(), column));
                    } else {
                        return Stream.empty();
                    }
                } else {
                    return Stream.empty();
                }
            }).forEachOrdered(colTuple -> dataFrameFolder.accept(colTuple));
        return dataFrameFolder.getDF();
    }
    private static Optional<Column> dropSubColumn(Column col, DataType colType, String fullColumnName, String dropColumnName) {
        Optional<Column> column = Optional.empty();
        if (!fullColumnName.equals(dropColumnName)) {
            if (colType instanceof StructType) {
                if (dropColumnName.startsWith(fullColumnName + ".")) {
                    column = Optional.of(struct(getColumns(col, (StructType)colType, fullColumnName, dropColumnName)));
                }
            } else {
                column = Optional.of(col);
            }
        }
        return column;
    }
    private static Column[] getColumns(Column col, StructType colType, String fullColumnName, String dropColumnName) {
        return Arrays.stream(colType.fields())
            .flatMap(f -> {
                    final Optional<Column> column = dropSubColumn(col.getField(f.name()), f.dataType(),
                            fullColumnName + "." + f.name(), dropColumnName);
                    if (column.isPresent()) {
                        return Stream.of(column.get().alias(f.name()));
                    } else {
                        return Stream.empty();
                    }
                }
            ).toArray(Column[]::new);
    }
    private static class DataFrameFolder implements Consumer<Tuple2<String, Optional<Column>>> {
        private Dataset<Row> df;
        public DataFrameFolder(Dataset<Row> df) {
            this.df = df;
        }
        public Dataset<Row> getDF() {
            return df;
        }
        @Override
        public void accept(Tuple2<String, Optional<Column>> colTuple) {
            if (!colTuple._2().isPresent()) {
                df = df.drop(colTuple._1());
            } else {
                df = df.withColumn(colTuple._1(), colTuple._2().get());
            }
        }
    }
}
Usage example:
private class Pojo {
    private String str;
    private Integer number;
    private List<String> strList;
    private Pojo2 pojo2;
    public String getStr() {
        return str;
    }
    public Integer getNumber() {
        return number;
    }
    public List<String> getStrList() {
        return strList;
    }
    public Pojo2 getPojo2() {
        return pojo2;
    }
}
private class Pojo2 {
    private String str;
    private Integer number;
    private List<String> strList;
    public String getStr() {
        return str;
    }
    public Integer getNumber() {
        return number;
    }
    public List<String> getStrList() {
        return strList;
    }
}
SQLContext context = new SQLContext(new SparkContext("local[1]", "test"));
Dataset<Row> df = context.createDataFrame(Collections.emptyList(), Pojo.class);
Dataset<Row> dfRes = DataFrameUtils.dropNestedColumn(df, "pojo2.str");
Original struct:
|-- number: integer (nullable = true)
|-- pojo2: struct (nullable = true)
| |-- number: integer (nullable = true)
| |-- str: string (nullable = true)
| |-- strList: array (nullable = true)
| | |-- element: string (containsNull = true)
|-- str: string (nullable = true)
|-- strList: array (nullable = true)
| |-- element: string (containsNull = true)
After drop:
|-- number: integer (nullable = true)
|-- pojo2: struct (nullable = false)
| |-- number: integer (nullable = true)
| |-- strList: array (nullable = true)
| | |-- element: string (containsNull = true)
|-- str: string (nullable = true)
|-- strList: array (nullable = true)
| |-- element: string (containsNull = true)
PySpark implementation
from typing import List
import pyspark.sql.functions as sf
from pyspark.sql import Column, DataFrame
from pyspark.sql.types import StructType
def _drop_nested_field(
    schema: StructType,
    field_to_drop: str,
    parents: List[str] = None,
) -> Column:
    parents = list() if parents is None else parents
    src_col = lambda field_names: sf.col('.'.join(f'`{c}`' for c in field_names))
    if '.' in field_to_drop:
        root, subfield = field_to_drop.split('.', maxsplit=1)
        field_to_drop_from = next(f for f in schema.fields if f.name == root)
        # keep the siblings as they are and rebuild the branch that contains the field
        return sf.struct(
            *[src_col(parents + [f.name]) for f in schema.fields if f.name != root],
            _drop_nested_field(
                schema=field_to_drop_from.dataType,
                field_to_drop=subfield,
                parents=parents + [root]
            ).alias(root)
        )
    else:
        # select all columns except the one to drop
        return sf.struct(
            *[src_col(parents + [f.name]) for f in schema.fields if f.name != field_to_drop],
        )
def drop_nested_field(
    df: DataFrame,
    field_to_drop: str,
) -> DataFrame:
    if '.' in field_to_drop:
        root, subfield = field_to_drop.split('.', maxsplit=1)
        field_to_drop_from = next(f for f in df.schema.fields if f.name == root)
        return df.withColumn(root, _drop_nested_field(
            schema=field_to_drop_from.dataType,
            field_to_drop=subfield,
            parents=[root]
        ))
    else:
        return df.drop(field_to_drop)
df = drop_nested_field(df, 'a.b.c.d')
Adding a Java version of the solution.
Utility class (pass your dataset and the nested column to be dropped to the dropNestedColumn function).
(There are a few bugs in Lior Chaga's answer; I corrected them while trying to use it.)
import java.util.Arrays;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
public class NestedColumnActions {
    private static final String DOT = ".";
    /*
    dataset : dataset in which we want to drop columns
    columnName : nested column that needs to be deleted
    */
    public static Dataset<?> dropNestedColumn(Dataset<?> dataset, String columnName) {
        //Special case of top level column deletion
        if (!columnName.contains(DOT)) {
            return dataset.drop(columnName);
        }
        final DataSetModifier dataFrameFolder = new DataSetModifier(dataset);
        Arrays.stream(dataset.schema().fields())
                .flatMap(f -> {
                    //If the column name to be deleted starts with current top level column
                    if (columnName.startsWith(f.name() + DOT)) {
                        //Get new column structure under f , expected after deleting the required column
                        final Optional<Column> column = dropSubColumn(functions.col(f.name()), f.dataType(), f.name(), columnName);
                        if (column.isPresent()) {
                            return Stream.of(new Tuple2<>(f.name(), column));
                        } else {
                            return Stream.empty();
                        }
                    } else {
                        return Stream.empty();
                    }
                })
                //Call accept function with Tuples of (top level column name, new column structure under it)
                .forEach(colTuple -> dataFrameFolder.accept(colTuple));
        return dataFrameFolder.getDataset();
    }
    private static Optional<Column> dropSubColumn(Column col, DataType colType, String fullColumnName, String dropColumnName) {
        Optional<Column> column = Optional.empty();
        if (!fullColumnName.equals(dropColumnName)) {
            if (colType instanceof StructType) {
                if (dropColumnName.startsWith(fullColumnName + DOT)) {
                    column = Optional.of(functions.struct(getColumns(col, (StructType) colType, fullColumnName, dropColumnName)));
                }
                else {
                    column = Optional.of(col);
                }
            } else {
                column = Optional.of(col);
            }
        }
        return column;
    }
    private static Column[] getColumns(Column col, StructType colType, String fullColumnName, String dropColumnName) {
        return Arrays.stream(colType.fields())
                .flatMap(f -> {
                    final Optional<Column> column = dropSubColumn(col.getField(f.name()), f.dataType(),
                            fullColumnName + "." + f.name(), dropColumnName);
                    if (column.isPresent()) {
                        return Stream.of(column.get().alias(f.name()));
                    } else {
                        return Stream.empty();
                    }
                }).toArray(Column[]::new);
    }
    private static class DataSetModifier implements Consumer<Tuple2<String, Optional<Column>>> {
        private Dataset<?> df;
        public DataSetModifier(Dataset<?> df) {
            this.df = df;
        }
        public Dataset<?> getDataset() {
            return df;
        }
        /*
        colTuple[0]:top level column name
        colTuple[1]:new column structure under it
        */
        @Override
        public void accept(Tuple2<String, Optional<Column>> colTuple) {
            if (!colTuple._2().isPresent()) {
                df = df.drop(colTuple._1());
            } else {
                df = df.withColumn(colTuple._1(), colTuple._2().get());
            }
        }
    }
}
The Make Structs Easy* library makes it easy to perform operations like adding, dropping, and renaming fields inside nested data structures. The library is available in both Scala and Python.
Assuming you have the following data:
import org.apache.spark.sql.functions._
import spark.implicits._
case class Features(feat1: String, feat2: String, feat3: String)
case class Record(features: Features, arrayOfFeatures: Seq[Features])
val df = Seq(
  Record(Features("hello", "world", "!"), Seq(Features("red", "orange", "yellow"), Features("green", "blue", "indigo")))
).toDF
df.printSchema
// root
// |-- features: struct (nullable = true)
// | |-- feat1: string (nullable = true)
// | |-- feat2: string (nullable = true)
// | |-- feat3: string (nullable = true)
// |-- arrayOfFeatures: array (nullable = true)
// | |-- element: struct (containsNull = true)
// | | |-- feat1: string (nullable = true)
// | | |-- feat2: string (nullable = true)
// | | |-- feat3: string (nullable = true)
df.show(false)
// +-----------------+----------------------------------------------+
// |features |arrayOfFeatures |
// +-----------------+----------------------------------------------+
// |[hello, world, !]|[[red, orange, yellow], [green, blue, indigo]]|
// +-----------------+----------------------------------------------+
Then dropping feat2 from features is as simple as:
import com.github.fqaiser94.mse.methods._
// drop feat2 from features
df.withColumn("features", $"features".dropFields("feat2")).show(false)
// +----------+----------------------------------------------+
// |features |arrayOfFeatures |
// +----------+----------------------------------------------+
// |[hello, !]|[[red, orange, yellow], [green, blue, indigo]]|
// +----------+----------------------------------------------+
I noticed there were a lot of follow-up comments on other solutions asking if there's a way to drop a Column nested inside a struct nested inside of an array. This can be done by combining the functions provided by the Make Structs Easy library with the functions provided by spark-hofs library, as follows:
import za.co.absa.spark.hofs._
// drop feat2 in each element of arrayOfFeatures
df.withColumn("arrayOfFeatures", transform($"arrayOfFeatures", features => features.dropFields("feat2"))).show(false)
// +-----------------+--------------------------------+
// |features |arrayOfFeatures |
// +-----------------+--------------------------------+
// |[hello, world, !]|[[red, yellow], [green, indigo]]|
// +-----------------+--------------------------------+
*Full disclosure: I am the author of the Make Structs Easy library that is referenced in this answer.
With Spark 3.1+, short and effective:
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.col
object DatasetOps {
  implicit class DatasetOps[T](val dataset: Dataset[T]) {
    def dropFields(fieldNames: String*): DataFrame =
      fieldNames.foldLeft(dataset.toDF()) { (dataset, fieldName) =>
        // "struct.path.to.field" -> (struct, path.to.field)
        val subFieldRegex = "(\\w+)\\.(.+)".r
        fieldName match {
          case subFieldRegex(columnName, subFieldPath) =>
            dataset.withColumn(columnName, col(columnName).dropFields(subFieldPath))
          case _ => dataset.drop(fieldName)
        }
      }
  }
}
This also preserves the nullability (required or not) of the fields in the schema.
import DatasetOps._
dataset.dropFields("some_column", "some_struct.some_sub_field.some_field")