How to optimize dataset aggregation on top of java objects dataset - scala

Could you please help me with a basic question:
I have some Java class:
public class ProbePoint implements Serializable, Cloneable {
private long arrivalTimeMillis = 0;
private long captureTimeMillis = 0;
//...
}
public class Trip implements Serializable, Cloneable {
private ArrayList<ProbePoint> points = new ArrayList<>();
//...
}
I have a Dataset[Trip] and need to collect some min/max values. What would be a better implementation than the following?
public class DataRanges implements Serializable {
private long minCaptureTs;
private long maxArrivalTs;
}
val timesDs: Dataset[DataRanges] = trips.mapPartitions(t => {
var minCaptTime = Long.MaxValue
var maxArrTime = Long.MinValue
t.foreach(f => {
// assumes points is sorted by time and that ProbePoint has a getCaptureTimeMillis getter
if (f.points.head.getCaptureTimeMillis < minCaptTime) minCaptTime = f.points.head.getCaptureTimeMillis
if (f.points.last.getArrivalTimeMillis > maxArrTime) maxArrTime = f.points.last.getArrivalTimeMillis
})
// assumes DataRanges has a (minCaptureTs, maxArrivalTs) constructor
Iterator[DataRanges](
new DataRanges(minCaptTime, maxArrTime))
})(Encoders.bean(classOf[DataRanges]))
val times = timesDs.agg(min("minCaptureTs"), max("maxArrivalTs")).head()

Looking at the Java classes, the schema of Dataset[Trip] should be
root
|-- points: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- arrivalTimeMillis: long (nullable = false)
| | |-- captureTimeMillis: long (nullable = false)
It would be possible to explode the array and then take the min and max of the resulting columns, thus simplifying the code a bit:
val df = tripsDF
.withColumn("exploded", explode($"points"))
.withColumn("arrivalTimeMillis", $"exploded.arrivalTimeMillis")
.withColumn("captureTimeMillis", $"exploded.captureTimeMillis")
// the question asks for the minimum capture time and the maximum arrival time
val Row(minCaptureTime: Long, maxArrivalTime: Long) =
df.agg(min("captureTimeMillis"), max("arrivalTimeMillis")).head
println(minCaptureTime)
println(maxArrivalTime)
The code in the question assumes that the arrays in the Trip class are sorted: the minimal capture time is always taken from the first element of the array and the maximum arrival time is always taken from the last one. This code takes the minimum and maximum over all ProbePoints, so the logic is slightly different.
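To make the difference between the two approaches concrete, here is a minimal sketch on plain Scala collections (no Spark involved). The case classes are simplified stand-ins for the Java classes in the question, and the names firstLastRange and fullScanRange are made up for this illustration.

```scala
object ProbeRanges {
  // Simplified stand-ins for the Java classes in the question.
  final case class ProbePoint(arrivalTimeMillis: Long, captureTimeMillis: Long)
  final case class Trip(points: Seq[ProbePoint])

  // Question's logic: trust the ordering and look only at the first and last point.
  // Like the full scan below, this throws if there are no points at all.
  def firstLastRange(trips: Seq[Trip]): (Long, Long) = {
    val nonEmpty = trips.filter(_.points.nonEmpty)
    (nonEmpty.map(_.points.head.captureTimeMillis).min,
     nonEmpty.map(_.points.last.arrivalTimeMillis).max)
  }

  // Answer's logic: scan every point, so the ordering does not matter.
  def fullScanRange(trips: Seq[Trip]): (Long, Long) = {
    val all = trips.flatMap(_.points)
    (all.map(_.captureTimeMillis).min, all.map(_.arrivalTimeMillis).max)
  }

  // One unsorted trip: the earliest capture sits in the middle, the latest arrival comes first.
  val unsorted = Seq(Trip(Seq(ProbePoint(90, 50), ProbePoint(20, 10), ProbePoint(40, 30))))
}
```

On the unsorted sample the two functions disagree, so the simpler explode-based version is only a drop-in replacement if the points really are sorted (or if scanning all points is what you actually want).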

Related

How to group by a dataframe of a specific class

I have one dataframe with this schema:
|-- Agreement_A1: string (nullable = true)
|-- Line_A1: string (nullable = true)
|-- Line_A2: string (nullable = true)
I create a new dataframe with this code:
val df2 = df.map(row => new MapResultRequestLine().apply(row))(Encoders.bean(classOf[AgreementLine]))
Function apply() is this:
public AgreementLine apply(Row row) {
AgreementLine agrLine = new AgreementLine();
agrLine.Agreement_A1 = row.getAs("Agreement_A1");
Line res = new Line();
res.Line_A1 = row.getAs("Line_A1");
res.Line_A2 = row.getAs("Line_A2");
agrLine.line = res
return agrLine;
}
Class AgreementLine looks like this:
public class AgreementLine{
public String agreementCrocCode;
public Line line;
}
Class Line is this:
public class Line{
public String Line_A1;
public String Line_A2;
}
How can I group df2 so that the resulting DataFrame has an Agreement_A1 column and a list of Line?
I have tried it this way:
val groupedDF = df2.groupBy($"Agreement_A1").agg(collect_set((array($"line"))).as("lines"))
But it shows an error "cannot resolve 'Agreement_A1' given input columns: [];"
The issue is here:
val df2 = df.map(row => new MapResultRequestLine().apply(row))(Encoders.bean(classOf[AgreementLine]))
Scala does not show the inferred type here, so you assume df2 is a DataFrame (that is, a Dataset[Row]).
But it is actually a Dataset[AgreementLine]. Because of the bean encoder, the whole schema has been lost, which is why df2.printSchema returns an empty result.
Thus, when you call df2.groupBy($"Agreement_A1"), it throws an exception because there is no column named "Agreement_A1".
The solution is to give your Dataset (df2 in your case) a proper schema.
Unfortunately, I don't know how to do that directly (I'm a beginner as well).
My only solution is to convert the Dataset back to an RDD[Row] (note that df2.rdd gives you an RDD[AgreementLine]) and build a new DataFrame with a custom schema.
Hope you'll get a better solution.
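A possible explanation for the empty schema, offered as an assumption rather than a confirmed diagnosis: as far as I know, Encoders.bean builds its columns from JavaBean properties (getter/setter pairs), and the classes in the question only have public fields. The sketch below does not call Spark at all; it only uses java.beans.Introspector to show which properties bean-style introspection can see. PlainLine and BeanLine are hypothetical stand-ins for the question's classes.

```scala
import java.beans.Introspector
import scala.beans.BeanProperty

object BeanSchemaDemo {
  // Hypothetical class without getX/setX accessors, comparable to the public-field classes in the question.
  class PlainLine { var lineA1: String = "" }

  // Same data exposed through JavaBean accessors (getLineA1 / setLineA1).
  class BeanLine { @BeanProperty var lineA1: String = "" }

  // Names of the bean properties that have both a getter and a setter.
  def beanProperties(cls: Class[_]): Seq[String] =
    Introspector.getBeanInfo(cls).getPropertyDescriptors.toSeq
      .filter(p => p.getReadMethod != null && p.getWriteMethod != null)
      .map(_.getName)
}
```

If that assumption holds, adding getters and setters to AgreementLine and Line (and naming the property consistently, since the class declares agreementCrocCode while apply() sets Agreement_A1) might be enough to get a non-empty schema without going through an RDD.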

Apply a function to each row of a Spark DataFrame

I have a Spark DataFrame (df) with 2 columns (Report_id and Cluster_number).
I want to apply a function (getClusterInfo) to df which will return the name for each cluster, i.e. if the cluster number is '3' then, for a specific report_id, the 3 rows shown below will be written:
{"cluster_id":"1","influencers":[{"screenName":"A"},{"screenName":"B"},{"screenName":"C"},...]}
{"cluster_id":"2","influencers":[{"screenName":"D"},{"screenName":"E"},{"screenName":"F"},...]}
{"cluster_id":"3","influencers":[{"screenName":"G"},{"screenName":"H"},{"screenName":"E"},...]}
I am using foreach on df to apply the getClusterInfo function, but can't figure out how to convert the output to a DataFrame (Report_id, Array[cluster_info]).
Here is the code snippet:
df.foreach(row => {
val report_id = row(0)
val cluster_no = row(1).toString
val cluster_numbers = new Range(0, cluster_no.toInt - 1, 1)
for (cluster <- cluster_numbers.by(1)) {
val cluster_id = report_id + "_" + cluster
//get cluster influencers
val result = getClusterInfo(cluster_id)
println(result.get)
val res : String = result.get.toString()
// TODO ?
}
.. //TODO ?
})
Generally speaking, you shouldn't use foreach when you want to map something into something else; foreach is good for applying functions that only have side effects and return nothing.
In this case, if I got the details right (probably not), you can use a User-Defined Function (UDF) and explode the result:
import org.apache.spark.sql.functions._
import spark.implicits._
// I'm assuming we have these case classes (or similar)
case class Influencer(screenName: String)
case class ClusterInfo(cluster_id: String, influencers: Array[Influencer])
// I'm assuming this method is supplied - with your own implementation
def getClusterInfo(clusterId: String): ClusterInfo =
ClusterInfo(clusterId, Array(Influencer(clusterId)))
// some sample data - assuming both columns are integers:
val df = Seq((222, 3), (333, 4)).toDF("Report_id", "Cluster_number")
// actual solution:
// UDF that returns an array of ClusterInfo;
// Array size is 'clusterNo', creates cluster id for each element and maps it to info
val clusterInfoUdf = udf { (clusterNo: Int, reportId: Int) =>
(1 to clusterNo).map(v => s"${reportId}_$v").map(getClusterInfo)
}
// apply UDF to each record and explode - to create one record per array item
val result = df.select(explode(clusterInfoUdf($"Cluster_number", $"Report_id")))
result.printSchema()
// root
// |-- col: struct (nullable = true)
// | |-- cluster_id: string (nullable = true)
// | |-- influencers: array (nullable = true)
// | | |-- element: struct (containsNull = true)
// | | | |-- screenName: string (nullable = true)
result.show(truncate = false)
// +-----------------------------+
// |col |
// +-----------------------------+
// |[222_1,WrappedArray([222_1])]|
// |[222_2,WrappedArray([222_2])]|
// |[222_3,WrappedArray([222_3])]|
// |[333_1,WrappedArray([333_1])]|
// |[333_2,WrappedArray([333_2])]|
// |[333_3,WrappedArray([333_3])]|
// |[333_4,WrappedArray([333_4])]|
// +-----------------------------+
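The same idea can be sketched on plain Scala collections, which may make the foreach vs. map distinction easier to see. This is only an illustration without Spark: expand is a made-up helper name, and getClusterInfo is the same placeholder implementation as in the snippet above. flatMap corresponds to "UDF + explode", while map keeps one result per input row, which is closer to the (Report_id, Array[cluster_info]) shape asked for in the question.

```scala
object ClusterExpansion {
  final case class Influencer(screenName: String)
  final case class ClusterInfo(cluster_id: String, influencers: Seq[Influencer])

  // Placeholder lookup; replace with your own implementation.
  def getClusterInfo(clusterId: String): ClusterInfo =
    ClusterInfo(clusterId, Seq(Influencer(clusterId)))

  // One (Report_id, Cluster_number) row becomes Cluster_number ClusterInfo values.
  def expand(reportId: Int, clusterNo: Int): Seq[ClusterInfo] =
    (1 to clusterNo).map(v => getClusterInfo(s"${reportId}_$v"))

  val rows = Seq((222, 3), (333, 4))

  // One output element per cluster, like explode on the UDF result.
  val exploded: Seq[ClusterInfo] = rows.flatMap { case (r, n) => expand(r, n) }

  // One output element per report, carrying all of its clusters.
  val grouped: Seq[(Int, Seq[ClusterInfo])] = rows.map { case (r, n) => (r, expand(r, n)) }
}
```

In Spark terms, the grouped variant would correspond to selecting Report_id together with the UDF result and skipping explode.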

Spark - recursive function as udf generates an Exception

I am working with DataFrames whose elements have a schema similar to:
root
|-- NPAData: struct (nullable = true)
| |-- NPADetails: struct (nullable = true)
| | |-- location: string (nullable = true)
| | |-- manager: string (nullable = true)
| |-- service: array (nullable = true)
| | |-- element: struct (containsNull = true)
| | | |-- serviceName: string (nullable = true)
| | | |-- serviceCode: string (nullable = true)
|-- NPAHeader: struct (nullable = true)
| | |-- npaNumber: string (nullable = true)
| | |-- date: string (nullable = true)
In my DataFrame I want to group all elements which have the same NPAHeader.code, so I am using the following line:
val groupedNpa = orderedNpa.groupBy($"NPAHeader.code" ).agg(collect_list(struct($"NPAData",$"NPAHeader")).as("npa"))
After this I have a dataframe with the following schema:
StructType(StructField(npaNumber,StringType,true), StructField(npa,ArrayType(StructType(StructField(NPAData...)))))
An example of each Row would be something similar to:
[1234,WrappedArray([npaNew,npaOlder,...npaOldest])]
Now what I want is to generate another DataFrame which picks just one of the elements in the WrappedArray, so I want an output similar to:
[1234,npaNew]
Note: The chosen element from the WrappedArray is the one that matches a complex logic after iterating over the whole WrappedArray. But to simplify the question, I will always pick the last element of the WrappedArray (after iterating all over it).
To do so, I want to define a recursive udf:
import org.apache.spark.sql.functions.udf
def returnRow(elementList : Row)(index:Int): Row = {
val dif = elementList.size - index
val row :Row = dif match{
case 0 => elementList.getAs[Row](index)
case _ => returnRow(elementList)(index + 1)
}
row
}
val returnRow_udf = udf(returnRow _)
groupedNpa.map{row => (row.getAs[String]("npaNumber"),returnRow_udf(groupedNpa("npa")(0)))}
But I am getting the following error in the map:
Exception in thread "main" java.lang.UnsupportedOperationException:
Schema for type Int => Unit is not supported
What am I doing wrong?
As an aside, I am not sure if I am correctly passing the npa column with groupedNpa("npa"). I am accessing the WrappedArray as a Row, because I don't know how to iterate over Array[Row] (the get(index) method is not present in Array[Row])
TL;DR Just use one of the methods described in How to select the first row of each group?
If you want to use complex logic, and return Row you can skip SQL API and use groupByKey:
// placeholders: supply your own group function and a Row encoder
val f: (String, Iterator[org.apache.spark.sql.Row]) => Row = ???
val encoder: Encoder[Row] = ???
df.groupByKey(_.getAs[String]("NPAHeader.code")).mapGroups(f)(encoder)
or better:
// placeholder: supply your own "keep one of the two rows" function
val g: (Row, Row) => Row = ???
df.groupByKey(_.getAs[String]("NPAHeader.code")).reduceGroups(g)
where encoder is a valid RowEncoder (Encoder error while trying to map dataframe row to updated row).
Your code is faulty in multiple ways:
groupBy doesn't guarantee the order of values. So:
orderBy(...).groupBy(....).agg(collect_list(...))
can have non-deterministic output. If you really decide to go this route you should skip orderBy and sort collected array explicitly.
You cannot pass curried function to udf. You'd have to uncurry it first, but it would require different order of arguments (see example below).
If you could, this might be the correct way to call it (Note that you omit the second argument):
returnRow_udf(groupedNpa("npa")(0))
To make it worse, you call it inside map, where udfs are not applicable at all.
udf cannot return Row. It has to return external Scala type.
External representation for array<struct> is Seq[Row]. You cannot just substitute it with Row.
SQL arrays can be accessed by index with apply:
df.select($"array"(size($"array") - 1))
but it is not a correct method due to non-determinism. You could apply sort_array, but as pointed out at the beginning, there are more efficient solutions.
Surprisingly recursion is not so relevant. You could design function like this:
def size(i: Int=0)(xs: Seq[Any]): Int = xs match {
case Seq() => i
case null => i
case Seq(h, t @ _*) => size(i + 1)(t)
}
val size_ = udf(size() _)
and it would work just fine:
Seq((1, Seq("a", "b", "c"))).toDF("id", "array")
.select(size_($"array"))
although recursion is an overkill, if you can just iterate over Seq.
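To illustrate the reduceGroups suggestion without Spark, here is a minimal sketch on plain collections. The Npa case class and its fields are hypothetical simplifications of the schema in the question, and "newest date wins" stands in for the complex selection logic. The point is that a binary function of the shape (Row, Row) => Row is enough to pick one element per key, with no ordering assumptions and no recursion.

```scala
object PickPerGroup {
  // Hypothetical, simplified record: npaNumber is the grouping key, date drives the choice.
  final case class Npa(npaNumber: String, date: String, location: String)

  // Binary "keep the better one" function, the same shape as g: (Row, Row) => Row.
  // ISO-formatted date strings compare correctly as plain strings.
  def newer(a: Npa, b: Npa): Npa = if (a.date >= b.date) a else b

  // groupBy + reduce here plays the role of groupByKey + reduceGroups.
  def latestPerKey(rows: Seq[Npa]): Map[String, Npa] =
    rows.groupBy(_.npaNumber).map { case (k, vs) => k -> vs.reduce(newer) }

  val sample = Seq(
    Npa("1234", "2017-01-01", "A"),
    Npa("1234", "2017-03-01", "B"),
    Npa("1234", "2017-02-01", "C"),
    Npa("5678", "2016-12-31", "D"))
}
```

Because the reduce function compares the records themselves, the result does not depend on the order in which the grouped values arrive.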

How to filter outlier rows from Spark Dataframe based on distance to previous valid row

I'm processing geospatial data using Spark 2.0 Dataframes with the following schema:
root
|-- date: timestamp (nullable = true)
|-- lat: double (nullable = true)
|-- lon: double (nullable = true)
|-- accuracy: double (nullable = true)
|-- track_id: long (nullable = true)
I have seen that the location signal sometimes jumps to a completely different place. The strange thing is that the signal then remains at the remote location for a certain time, say around 25 seconds or 5 samples, and then jumps back to where I stand.
I'd like to remove outliers by calculating the speed between the current row and the last valid row.
If the speed is above a given threshold, the current row should be dropped and the last valid row remains the same.
If the speed is below the threshold the current row is added to the result data frame and becomes the new last valid row.
I tried to accomplish this by using the mapPartitions function:
case class OutlierFilterTrackEntry(val dateSec: Long, val track_id: Long, val lat: Double, val lon: Double, val accuracy: Double)
def filterOutlier(iter : Iterator[OutlierFilterTrackEntry]) : Iterator[OutlierFilterTrackEntry] = {
val maxSpeed = 1000 / 3.6 // 1000 km/h in m/s
val buf = scala.collection.mutable.ListBuffer.empty[OutlierFilterTrackEntry]
var lastValid : OutlierFilterTrackEntry = null
while (iter.hasNext){
val cur = iter.next;
if(lastValid == null){
lastValid = cur
}else{
//val distance = calculateDistance(cur.lat, cur.lon, lastValid.lat, lastValid.lon)
val distance = GeometryEngine.geodesicDistanceOnWGS84(new Point(cur.lon, cur.lat), new Point(lastValid.lon, lastValid.lat))
val diffTime = (cur.dateSec - lastValid.dateSec)
val speed = distance / diffTime
println(speed)
if(speed < maxSpeed) {
lastValid = cur
buf.append(cur)
}
}
}
buf.toList.iterator
}
val numPartitions = tracks6.select($"track_id").distinct().count.toInt
println(numPartitions)
val partitionedTracks = tracks6.select(col("date").cast("BigInt").alias("dateSec"), $"track_id", $"lat", $"lon", $"accuracy")
.orderBy($"date")
.repartition(numPartitions, $"track_id")
val tracksWithoutOutliers = partitionedTracks
.as[OutlierFilterTrackEntry]
.mapPartitions(filterOutlier)
println(tracksWithoutOutliers.count)
tracksWithoutOutliers.toDF.show(400)
But I get the following error:
org.apache.spark.SparkException: Job aborted due to stage failure: Task not serializable: java.io.NotSerializableException: org.apache.spark.sql.expressions.WindowSpec
Serialization stack:
- object not serializable (class: org.apache.spark.sql.expressions.WindowSpec, value: org.apache.spark.sql.expressions.WindowSpec@711a5138)
- field (class: linebb5d1438a8a14914aa58d2465d24fa2a59.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw, name: windowSpecAllPrev, type: class org.apache.spark.sql.expressions.WindowSpec)
Any other approaches are welcome.
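One possible approach, sketched here on plain Scala collections, is to write the "compare against the last kept row" rule as a fold. This is a sketch under assumptions, not a tested Spark job: Entry is a trimmed-down version of OutlierFilterTrackEntry, and distanceMeters uses the haversine formula as a dependency-free stand-in for GeometryEngine.geodesicDistanceOnWGS84. Unlike the loop in the question, the first entry is kept, and entries with a non-positive time difference are dropped to avoid dividing by zero.

```scala
object OutlierFilter {
  final case class Entry(dateSec: Long, lat: Double, lon: Double)

  // Great-circle distance in metres (haversine formula).
  def distanceMeters(a: Entry, b: Entry): Double = {
    val r = 6371000.0 // mean Earth radius in metres
    val dLat = math.toRadians(b.lat - a.lat)
    val dLon = math.toRadians(b.lon - a.lon)
    val h = math.pow(math.sin(dLat / 2), 2) +
      math.cos(math.toRadians(a.lat)) * math.cos(math.toRadians(b.lat)) *
        math.pow(math.sin(dLon / 2), 2)
    2 * r * math.asin(math.sqrt(h))
  }

  // Keep an entry only if the speed from the last kept entry stays below maxSpeed (m/s).
  // The input must already be sorted by time.
  def filterOutliers(sorted: Seq[Entry], maxSpeed: Double = 1000 / 3.6): Vector[Entry] =
    sorted.foldLeft(Vector.empty[Entry]) { (kept, cur) =>
      kept.lastOption match {
        case None => kept :+ cur // first entry becomes the initial "last valid" row
        case Some(last) =>
          val dt = (cur.dateSec - last.dateSec).toDouble
          if (dt > 0 && distanceMeters(last, cur) / dt < maxSpeed) kept :+ cur else kept
      }
    }
}
```

In a Spark job this function would have to run once per track on rows sorted within that track; note that a repartition after orderBy shuffles the data, so the ordering should be re-established inside each partition or group. As for the error itself, the stack trace points at a notebook-level value named windowSpecAllPrev, so keeping the logic in a standalone object like this may also avoid capturing it in the closure, although that is a guess.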

Dropping a nested column from Spark DataFrame

I have a DataFrame with the schema
root
|-- label: string (nullable = true)
|-- features: struct (nullable = true)
| |-- feat1: string (nullable = true)
| |-- feat2: string (nullable = true)
| |-- feat3: string (nullable = true)
While I am able to filter the DataFrame using
val data = rawData
.filter( !(rawData("features.feat1") <=> "100") )
I am unable to drop the columns using
val data = rawData
.drop("features.feat1")
Is it something that I am doing wrong here? I also tried (unsuccessfully) doing drop(rawData("features.feat1")), though it does not make much sense to do so.
Thanks in advance,
Nikhil
It is just a programming exercise but you can try something like this:
import org.apache.spark.sql.{DataFrame, Column}
import org.apache.spark.sql.types.{StructType, StructField}
import org.apache.spark.sql.{functions => f}
import scala.util.Try
case class DFWithDropFrom(df: DataFrame) {
def getSourceField(source: String): Try[StructField] = {
Try(df.schema.fields.filter(_.name == source).head)
}
def getType(sourceField: StructField): Try[StructType] = {
Try(sourceField.dataType.asInstanceOf[StructType])
}
def genOutputCol(names: Array[String], source: String): Column = {
f.struct(names.map(x => f.col(source).getItem(x).alias(x)): _*)
}
def dropFrom(source: String, toDrop: Array[String]): DataFrame = {
getSourceField(source)
.flatMap(getType)
.map(_.fieldNames.diff(toDrop))
.map(genOutputCol(_, source))
.map(df.withColumn(source, _))
.getOrElse(df)
}
}
Example usage:
scala> case class features(feat1: String, feat2: String, feat3: String)
defined class features
scala> case class record(label: String, features: features)
defined class record
scala> val df = sc.parallelize(Seq(record("a_label", features("f1", "f2", "f3")))).toDF
df: org.apache.spark.sql.DataFrame = [label: string, features: struct<feat1:string,feat2:string,feat3:string>]
scala> DFWithDropFrom(df).dropFrom("features", Array("feat1")).show
+-------+--------+
| label|features|
+-------+--------+
|a_label| [f2,f3]|
+-------+--------+
scala> DFWithDropFrom(df).dropFrom("foobar", Array("feat1")).show
+-------+----------+
| label| features|
+-------+----------+
|a_label|[f1,f2,f3]|
+-------+----------+
scala> DFWithDropFrom(df).dropFrom("features", Array("foobar")).show
+-------+----------+
| label| features|
+-------+----------+
|a_label|[f1,f2,f3]|
+-------+----------+
Add an implicit conversion and you're good to go.
This version allows you to remove nested columns at any level:
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructType, DataType}
/**
* Various Spark utilities and extensions of DataFrame
*/
object DataFrameUtils {
private def dropSubColumn(col: Column, colType: DataType, fullColName: String, dropColName: String): Option[Column] = {
if (fullColName.equals(dropColName)) {
None
} else {
colType match {
case colType: StructType =>
if (dropColName.startsWith(s"${fullColName}.")) {
Some(struct(
colType.fields
.flatMap(f =>
dropSubColumn(col.getField(f.name), f.dataType, s"${fullColName}.${f.name}", dropColName) match {
case Some(x) => Some(x.alias(f.name))
case None => None
})
: _*))
} else {
Some(col)
}
case other => Some(col)
}
}
}
protected def dropColumn(df: DataFrame, colName: String): DataFrame = {
df.schema.fields
.flatMap(f => {
if (colName.startsWith(s"${f.name}.")) {
dropSubColumn(col(f.name), f.dataType, f.name, colName) match {
case Some(x) => Some((f.name, x))
case None => None
}
} else {
None
}
})
.foldLeft(df.drop(colName)) {
case (df, (colName, column)) => df.withColumn(colName, column)
}
}
/**
* Extended version of DataFrame that allows to operate on nested fields
*/
implicit class ExtendedDataFrame(df: DataFrame) extends Serializable {
/**
* Drops nested field from DataFrame
*
* @param colName Dot-separated nested field name
*/
def dropNestedColumn(colName: String): DataFrame = {
DataFrameUtils.dropColumn(df, colName)
}
}
}
Usage:
import DataFrameUtils._
df.dropNestedColumn("a.b.c.d")
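The recursion in dropSubColumn can be easier to follow on a toy model. The sketch below is not Spark code: Field, Leaf and Struct are made-up types that mimic a schema tree, and dropSub mirrors the structure of the function above (return None for the field to drop, rebuild a struct only when the path to drop goes through it, otherwise keep the field untouched).

```scala
object NestedDrop {
  // Toy schema model: a field is either a leaf or a struct of named fields.
  sealed trait Field
  case object Leaf extends Field
  final case class Struct(fields: List[(String, Field)]) extends Field

  // None when this field is the one to drop, otherwise the (possibly rebuilt) field.
  def dropSub(field: Field, fullName: String, dropName: String): Option[Field] =
    if (fullName == dropName) None
    else field match {
      // Only rebuild a struct when the path to drop passes through it.
      case Struct(fields) if dropName.startsWith(fullName + ".") =>
        Some(Struct(fields.flatMap { case (name, child) =>
          dropSub(child, s"$fullName.$name", dropName).map(name -> _)
        }))
      case other => Some(other)
    }

  // Entry point: apply dropSub to every top-level field.
  def drop(root: Struct, dropName: String): Struct =
    Struct(root.fields.flatMap { case (name, child) =>
      dropSub(child, name, dropName).map(name -> _)
    })

  val schema = Struct(List(
    "label" -> Leaf,
    "features" -> Struct(List("feat1" -> Leaf, "feat2" -> Leaf, "feat3" -> Leaf))))
}
```

Dropping a name that does not exist leaves the tree unchanged, which matches the behavior shown for the DataFrame version in the earlier usage examples.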
Expanding on spektom's answer, with support for array types:
object DataFrameUtils {
private def dropSubColumn(col: Column, colType: DataType, fullColName: String, dropColName: String): Option[Column] = {
if (fullColName.equals(dropColName)) {
None
} else if (dropColName.startsWith(s"$fullColName.")) {
colType match {
case colType: StructType =>
Some(struct(
colType.fields
.flatMap(f =>
dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
case Some(x) => Some(x.alias(f.name))
case None => None
})
: _*))
case colType: ArrayType =>
colType.elementType match {
case innerType: StructType =>
Some(struct(innerType.fields
.flatMap(f =>
dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
case Some(x) => Some(x.alias(f.name))
case None => None
})
: _*))
}
case other => Some(col)
}
} else {
Some(col)
}
}
protected def dropColumn(df: DataFrame, colName: String): DataFrame = {
df.schema.fields
.flatMap(f => {
if (colName.startsWith(s"${f.name}.")) {
dropSubColumn(col(f.name), f.dataType, f.name, colName) match {
case Some(x) => Some((f.name, x))
case None => None
}
} else {
None
}
})
.foldLeft(df.drop(colName)) {
case (df, (colName, column)) => df.withColumn(colName, column)
}
}
/**
* Extended version of DataFrame that allows to operate on nested fields
*/
implicit class ExtendedDataFrame(df: DataFrame) extends Serializable {
/**
* Drops nested field from DataFrame
*
* @param colName Dot-separated nested field name
*/
def dropNestedColumn(colName: String): DataFrame = {
DataFrameUtils.dropColumn(df, colName)
}
}
}
I will expand upon mmendez.semantic's answer here, accounting for the issues described in the sub-thread.
def dropSubColumn(col: Column, colType: DataType, fullColName: String, dropColName: String): Option[Column] = {
if (fullColName.equals(dropColName)) {
None
} else if (dropColName.startsWith(s"$fullColName.")) {
colType match {
case colType: StructType =>
Some(struct(
colType.fields
.flatMap(f =>
dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
case Some(x) => Some(x.alias(f.name))
case None => None
})
: _*))
case colType: ArrayType =>
colType.elementType match {
case innerType: StructType =>
// we are potentially dropping a column from within a struct, that is itself inside an array
// Spark has some very strange behavior in this case, which they insist is not a bug
// see https://issues.apache.org/jira/browse/SPARK-31779 and associated comments
// and also the thread here: https://stackoverflow.com/a/39943812/375670
// this is a workaround for that behavior
// first, get all struct fields
val innerFields = innerType.fields
// next, create a new type for all the struct fields EXCEPT the column that is to be dropped
// we will need this later
val preserveNamesStruct = ArrayType(StructType(
innerFields.filterNot(f => s"$fullColName.${f.name}".equals(dropColName))
))
// next, apply dropSubColumn recursively to build up the new values after dropping the column
val filteredInnerFields = innerFields.flatMap(f =>
dropSubColumn(col.getField(f.name), f.dataType, s"$fullColName.${f.name}", dropColName) match {
case Some(x) => Some(x.alias(f.name))
case None => None
}
)
// finally, use arrays_zip to unwrap the arrays that were introduced by building up the new, filtered
// struct in this way (see comments in SPARK-31779), and then cast to the StructType we created earlier
// to get the original names back
Some(arrays_zip(filteredInnerFields:_*).cast(preserveNamesStruct))
}
case _ => Some(col)
}
} else {
Some(col)
}
}
def dropColumn(df: DataFrame, colName: String): DataFrame = {
df.schema.fields.flatMap(f => {
if (colName.startsWith(s"${f.name}.")) {
dropSubColumn(col(f.name), f.dataType, f.name, colName) match {
case Some(x) => Some((f.name, x))
case None => None
}
} else {
None
}
}).foldLeft(df.drop(colName)) {
case (df, (colName, column)) => df.withColumn(colName, column)
}
}
Usage in spark-shell:
// if defining the functions above in your spark-shell session, you first need imports
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
// now you can paste the function definitions
// create a deeply nested and complex JSON structure
val jsonData = """{
"foo": "bar",
"top": {
"child1": 5,
"child2": [
{
"child2First": "one",
"child2Second": 2,
"child2Third": -19.51
}
],
"child3": ["foo", "bar", "baz"],
"child4": [
{
"child2First": "two",
"child2Second": 3,
"child2Third": 16.78
}
]
}
}"""
// read it into a DataFrame
val df = spark.read.option("multiline", "true").json(Seq(jsonData).toDS())
// remove a sub-column
val modifiedDf = dropColumn(df, "top.child2.child2First")
modifiedDf.printSchema
root
|-- foo: string (nullable = true)
|-- top: struct (nullable = false)
| |-- child1: long (nullable = true)
| |-- child2: array (nullable = true)
| | |-- element: struct (containsNull = true)
| | | |-- child2Second: long (nullable = true)
| | | |-- child2Third: double (nullable = true)
| |-- child3: array (nullable = true)
| | |-- element: string (containsNull = true)
| |-- child4: array (nullable = true)
| | |-- element: struct (containsNull = true)
| | | |-- child2First: string (nullable = true)
| | | |-- child2Second: long (nullable = true)
| | | |-- child2Third: double (nullable = true)
modifiedDf.show(truncate=false)
+---+------------------------------------------------------+
|foo|top |
+---+------------------------------------------------------+
|bar|[5, [[2, -19.51]], [foo, bar, baz], [[two, 3, 16.78]]]|
+---+------------------------------------------------------+
For Spark 3.1+, you can use method dropFields on struct type columns:
An expression that drops fields in StructType by name. This is a no-op
if schema doesn't contain field name(s).
val df = sql("SELECT named_struct('feat1', 1, 'feat2', 2, 'feat3', 3) features")
val df1 = df.withColumn("features", $"features".dropFields("feat1"))
Another (PySpark) way would be to drop the features.feat1 column by creating features again:
from pyspark.sql.functions import col, arrays_zip
display(df
.withColumn("features", arrays_zip("features.feat2", "features.feat3"))
.withColumn("features", col("features").cast(schema))
)
Where schema is the new schema (excluding features.feat1).
from pyspark.sql.types import StructType, StructField, StringType
schema = StructType(
[
StructField('feat2', StringType(), True),
StructField('feat3', StringType(), True),
]
)
Following spektom's Scala snippet, I've created similar code in Java.
Since Java 8 doesn't have foldLeft, I used forEachOrdered. This code is suitable for Spark 2.x (I'm using 2.1).
I also noticed that dropping a column and then adding it back using withColumn with the same name doesn't work, so I'm just replacing the column, and it seems to work.
The code is not fully tested, hope it works :-)
public class DataFrameUtils {
public static Dataset<Row> dropNestedColumn(Dataset<Row> dataFrame, String columnName) {
final DataFrameFolder dataFrameFolder = new DataFrameFolder(dataFrame);
Arrays.stream(dataFrame.schema().fields())
.flatMap( f -> {
if (columnName.startsWith(f.name() + ".")) {
final Optional<Column> column = dropSubColumn(col(f.name()), f.dataType(), f.name(), columnName);
if (column.isPresent()) {
return Stream.of(new Tuple2<>(f.name(), column));
} else {
return Stream.empty();
}
} else {
return Stream.empty();
}
}).forEachOrdered(colTuple -> dataFrameFolder.accept(colTuple));
return dataFrameFolder.getDF();
}
private static Optional<Column> dropSubColumn(Column col, DataType colType, String fullColumnName, String dropColumnName) {
Optional<Column> column = Optional.empty();
if (!fullColumnName.equals(dropColumnName)) {
if (colType instanceof StructType) {
if (dropColumnName.startsWith(fullColumnName + ".")) {
column = Optional.of(struct(getColumns(col, (StructType)colType, fullColumnName, dropColumnName)));
}
} else {
column = Optional.of(col);
}
}
return column;
}
private static Column[] getColumns(Column col, StructType colType, String fullColumnName, String dropColumnName) {
return Arrays.stream(colType.fields())
.flatMap(f -> {
final Optional<Column> column = dropSubColumn(col.getField(f.name()), f.dataType(),
fullColumnName + "." + f.name(), dropColumnName);
if (column.isPresent()) {
return Stream.of(column.get().alias(f.name()));
} else {
return Stream.empty();
}
}
).toArray(Column[]::new);
}
private static class DataFrameFolder implements Consumer<Tuple2<String, Optional<Column>>> {
private Dataset<Row> df;
public DataFrameFolder(Dataset<Row> df) {
this.df = df;
}
public Dataset<Row> getDF() {
return df;
}
@Override
public void accept(Tuple2<String, Optional<Column>> colTuple) {
if (!colTuple._2().isPresent()) {
df = df.drop(colTuple._1());
} else {
df = df.withColumn(colTuple._1(), colTuple._2().get());
}
}
}
}
Usage example:
private class Pojo {
private String str;
private Integer number;
private List<String> strList;
private Pojo2 pojo2;
public String getStr() {
return str;
}
public Integer getNumber() {
return number;
}
public List<String> getStrList() {
return strList;
}
public Pojo2 getPojo2() {
return pojo2;
}
}
private class Pojo2 {
private String str;
private Integer number;
private List<String> strList;
public String getStr() {
return str;
}
public Integer getNumber() {
return number;
}
public List<String> getStrList() {
return strList;
}
}
SQLContext context = new SQLContext(new SparkContext("local[1]", "test"));
Dataset<Row> df = context.createDataFrame(Collections.emptyList(), Pojo.class);
Dataset<Row> dfRes = DataFrameUtils.dropNestedColumn(df, "pojo2.str");
Original struct:
root
|-- number: integer (nullable = true)
|-- pojo2: struct (nullable = true)
| |-- number: integer (nullable = true)
| |-- str: string (nullable = true)
| |-- strList: array (nullable = true)
| | |-- element: string (containsNull = true)
|-- str: string (nullable = true)
|-- strList: array (nullable = true)
| |-- element: string (containsNull = true)
After drop:
root
|-- number: integer (nullable = true)
|-- pojo2: struct (nullable = false)
| |-- number: integer (nullable = true)
| |-- strList: array (nullable = true)
| | |-- element: string (containsNull = true)
|-- str: string (nullable = true)
|-- strList: array (nullable = true)
| |-- element: string (containsNull = true)
PySpark implementation
import pyspark.sql.functions as sf
def _drop_nested_field(
schema: StructType,
field_to_drop: str,
parents: List[str] = None,
) -> Column:
parents = list() if parents is None else parents
src_col = lambda field_names: sf.col('.'.join(f'`{c}`' for c in field_names))
if '.' in field_to_drop:
root, subfield = field_to_drop.split('.', maxsplit=1)
field_to_drop_from = next(f for f in schema.fields if f.name == root)
return sf.struct(
*[src_col(parents + [f.name]) for f in schema.fields if f.name != root],
_drop_nested_field(
schema=field_to_drop_from.dataType,
field_to_drop=subfield,
parents=parents + [root]
).alias(root)
)
else:
# select all columns except the one to drop
return sf.struct(
*[src_col(parents + [f.name])for f in schema.fields if f.name != field_to_drop],
)
def drop_nested_field(
df: DataFrame,
field_to_drop: str,
) -> DataFrame:
if '.' in field_to_drop:
root, subfield = field_to_drop.split('.', maxsplit=1)
field_to_drop_from = next(f for f in df.schema.fields if f.name == root)
return df.withColumn(root, _drop_nested_field(
schema=field_to_drop_from.dataType,
field_to_drop=subfield,
parents=[root]
))
else:
return df.drop(field_to_drop)
df = drop_nested_field(df, 'a.b.c.d')
Adding a Java version of the solution.
Utility class (pass your dataset and the nested column to be dropped to the dropNestedColumn function).
(There are a few bugs in Lior Chaga's answer; I corrected them while trying to use it.)
public class NestedColumnActions {
// separator between nested field names (used below but not defined in the original snippet)
private static final String DOT = ".";
/*
dataset : dataset in which we want to drop columns
columnName : nested column that needs to be deleted
*/
public static Dataset<?> dropNestedColumn(Dataset<?> dataset, String columnName) {
//Special case of top level column deletion
if(!columnName.contains("."))
return dataset.drop(columnName);
final DataSetModifier dataFrameFolder = new DataSetModifier(dataset);
Arrays.stream(dataset.schema().fields())
.flatMap(f -> {
//If the column name to be deleted starts with current top level column
if (columnName.startsWith(f.name() + DOT)) {
//Get new column structure under f , expected after deleting the required column
final Optional<Column> column = dropSubColumn(functions.col(f.name()), f.dataType(), f.name(), columnName);
if (column.isPresent()) {
return Stream.of(new Tuple2<>(f.name(), column));
} else {
return Stream.empty();
}
} else {
return Stream.empty();
}
})
//Call accept function with Tuples of (top level column name, new column structure under it)
.forEach(colTuple -> dataFrameFolder.accept(colTuple));
return dataFrameFolder.getDataset();
}
private static Optional<Column> dropSubColumn(Column col, DataType colType, String fullColumnName, String dropColumnName) {
Optional<Column> column = Optional.empty();
if (!fullColumnName.equals(dropColumnName)) {
if (colType instanceof StructType) {
if (dropColumnName.startsWith(fullColumnName + DOT)) {
column = Optional.of(functions.struct(getColumns(col, (StructType) colType, fullColumnName, dropColumnName)));
}
else {
column = Optional.of(col);
}
} else {
column = Optional.of(col);
}
}
return column;
}
private static Column[] getColumns(Column col, StructType colType, String fullColumnName, String dropColumnName) {
return Arrays.stream(colType.fields())
.flatMap(f -> {
final Optional<Column> column = dropSubColumn(col.getField(f.name()), f.dataType(),
fullColumnName + "." + f.name(), dropColumnName);
if (column.isPresent()) {
return Stream.of(column.get().alias(f.name()));
} else {
return Stream.empty();
}
}
).toArray(Column[]::new);
}
private static class DataSetModifier implements Consumer<Tuple2<String, Optional<Column>>> {
private Dataset<?> df;
public DataSetModifier(Dataset<?> df) {
this.df = df;
}
public Dataset<?> getDataset() {
return df;
}
/*
colTuple[0]:top level column name
colTuple[1]:new column structure under it
*/
#Override
public void accept(Tuple2<String, Optional<Column>> colTuple) {
if (!colTuple._2().isPresent()) {
df = df.drop(colTuple._1());
} else {
df = df.withColumn(colTuple._1(), colTuple._2().get());
}
}
}
}
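The recursion in the utility class above can be easier to follow on plain data. The following is only a rough sketch in Python, using nested dicts in place of Spark's StructType; drop_path is a hypothetical helper written for illustration and is not part of Spark or of the Java code. It follows the same three cases as dropSubColumn: skip the field when the path ends on it, rebuild the struct when the path continues below it, and keep everything else unchanged.

```python
def drop_path(struct, path):
    """Return a copy of the nested dict `struct` without the field at dotted `path`."""
    head, _, rest = path.partition(".")
    rebuilt = {}
    for name, value in struct.items():
        if name != head:
            rebuilt[name] = value                   # unrelated field: keep as is
        elif rest and isinstance(value, dict):
            rebuilt[name] = drop_path(value, rest)  # path continues: rebuild the struct below
        elif rest:
            rebuilt[name] = value                   # path runs through a non-struct: nothing to drop
        # otherwise this is the target field itself, so it is left out
    return rebuilt


row = {"a": {"b": {"c": {"d": 1, "e": 2}}, "x": 3}, "y": 4}
print(drop_path(row, "a.b.c.d"))
# {'a': {'b': {'c': {'e': 2}}, 'x': 3}, 'y': 4}
```

As in the Java version, every struct along the path is rebuilt from its surviving fields rather than modified in place, which is why the top level column ends up being replaced through withColumn.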
The Make Structs Easy* library makes it easy to perform operations like adding, dropping, and renaming fields inside nested data structures. The library is available in both Scala and Python.
Assuming you have the following data:
import org.apache.spark.sql.functions._
case class Features(feat1: String, feat2: String, feat3: String)
case class Record(features: Features, arrayOfFeatures: Seq[Features])
val df = Seq(
Record(Features("hello", "world", "!"), Seq(Features("red", "orange", "yellow"), Features("green", "blue", "indigo")))
).toDF
df.printSchema
// root
// |-- features: struct (nullable = true)
// | |-- feat1: string (nullable = true)
// | |-- feat2: string (nullable = true)
// | |-- feat3: string (nullable = true)
// |-- arrayOfFeatures: array (nullable = true)
// | |-- element: struct (containsNull = true)
// | | |-- feat1: string (nullable = true)
// | | |-- feat2: string (nullable = true)
// | | |-- feat3: string (nullable = true)
df.show(false)
// +-----------------+----------------------------------------------+
// |features |arrayOfFeatures |
// +-----------------+----------------------------------------------+
// |[hello, world, !]|[[red, orange, yellow], [green, blue, indigo]]|
// +-----------------+----------------------------------------------+
Then dropping feat2 from features is as simple as:
import com.github.fqaiser94.mse.methods._
// drop feat2 from features
df.withColumn("features", $"features".dropFields("feat2")).show(false)
// +----------+----------------------------------------------+
// |features |arrayOfFeatures |
// +----------+----------------------------------------------+
// |[hello, !]|[[red, orange, yellow], [green, blue, indigo]]|
// +----------+----------------------------------------------+
I noticed a lot of follow-up comments on other solutions asking whether there is a way to drop a field from a struct that is itself nested inside an array. This can be done by combining the functions provided by the Make Structs Easy library with those provided by the spark-hofs library, as follows:
import za.co.absa.spark.hofs._
// drop feat2 in each element of arrayOfFeatures
df.withColumn("arrayOfFeatures", transform($"arrayOfFeatures", features => features.dropFields("feat2"))).show(false)
// +-----------------+--------------------------------+
// |features |arrayOfFeatures |
// +-----------------+--------------------------------+
// |[hello, world, !]|[[red, yellow], [green, indigo]]|
// +-----------------+--------------------------------+
*Full disclosure: I am the author of the Make Structs Easy library that is referenced in this answer.
With Spark 3.1+, which added Column.dropFields, this becomes short and effective:
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.col

object DatasetOps {
  implicit class DatasetOps[T](val dataset: Dataset[T]) {
    def dropFields(fieldNames: String*): DataFrame =
      fieldNames.foldLeft(dataset.toDF()) { (df, fieldName) =>
        // Splits "column.rest.of.path" into the top level column and the path below it
        val subFieldRegex = "(\\w+)\\.(.+)".r
        fieldName match {
          case subFieldRegex(columnName, subFieldPath) =>
            df.withColumn(columnName, col(columnName).dropFields(subFieldPath))
          case _ => df.drop(fieldName)
        }
      }
  }
}
This also preserves the nullability (the nullable flag) of the remaining fields in the schema.
Usage (with the implicit class in scope, e.g. via import DatasetOps._):
dataset.dropFields("some_column", "some_struct.some_sub_field.some_field")
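To see how subFieldRegex routes each field name, here is a small Python sketch that applies the same pattern outside Spark. split_field is a hypothetical helper for illustration only. re.ASCII restricts \w to letters, digits and underscore, which is my assumption for matching the default behaviour of the Java regex engine that Scala uses, and fullmatch stands in for the whole-string match that a Scala regex extractor performs.

```python
import re

# Same pattern as subFieldRegex in the Scala snippet above
SUB_FIELD = re.compile(r"(\w+)\.(.+)", re.ASCII)


def split_field(field_name):
    """Return (column, sub_field_path), or (field_name, None) for the top level case."""
    match = SUB_FIELD.fullmatch(field_name)
    if match:
        return match.group(1), match.group(2)
    return field_name, None


for name in ["some_column", "some_struct.some_sub_field.some_field", "my-struct.field"]:
    print(name, "->", split_field(name))
```

Only the first dot splits the name, so the remainder is passed to Column.dropFields as a single dotted path. A name such as my-struct.field does not match, because \w excludes the hyphen; it therefore takes the top level branch, where drop would most likely be a no-op since no top level column carries that name.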