Change value of nested column in DataFrame - scala

I have a DataFrame with two levels of nested fields:
root
|-- request: struct (nullable = true)
| |-- dummyID: string (nullable = true)
| |-- data: struct (nullable = true)
| | |-- fooID: string (nullable = true)
| | |-- barID: string (nullable = true)
I want to update the value of the fooID column. I was able to update a first-level value (for example the dummyID column) using this question as a reference: How to add a nested column to a DataFrame
Input data:
{
"request": {
"dummyID": "test_id",
"data": {
"fooID": "abc",
"barID": "1485351"
}
}
}
Output data:
{
"request": {
"dummyID": "test_id",
"data": {
"fooID": "def",
"barID": "1485351"
}
}
}
How can I do it using Scala?

Here is a generic solution to this problem that makes it possible to update any number of nested values, at any level, based on an arbitrary function applied in a recursive traversal:
def mutate(df: DataFrame, fn: Column => Column): DataFrame = {
// Get a projection with fields mutated by `fn` and select it
// out of the original frame with the schema reassigned to the original
// frame (explained later)
df.sqlContext.createDataFrame(df.select(traverse(df.schema, fn):_*).rdd, df.schema)
}
def traverse(schema: StructType, fn: Column => Column, path: String = ""): Array[Column] = {
schema.fields.map(f => {
f.dataType match {
case s: StructType => struct(traverse(s, fn, path + f.name + "."): _*)
case _ => fn(col(path + f.name))
}
})
}
This is effectively equivalent to the usual "just redefine the whole struct as a projection" solutions, but it automates re-nesting fields with the original structure AND preserves nullability/metadata (which are lost when you redefine the structs manually). Annoyingly, preserving those properties isn't possible while creating the projection (afaict) so the code above redefines the schema manually.
An example application:
case class Organ(name: String, count: Int)
case class Disease(id: Int, name: String, organ: Organ)
case class Drug(id: Int, name: String, alt: Array[String])
val df = Seq(
(1, Drug(1, "drug1", Array("x", "y")), Disease(1, "disease1", Organ("heart", 2))),
(2, Drug(2, "drug2", Array("a")), Disease(2, "disease2", Organ("eye", 3)))
).toDF("id", "drug", "disease")
df.show(false)
+---+------------------+-------------------------+
|id |drug |disease |
+---+------------------+-------------------------+
|1 |[1, drug1, [x, y]]|[1, disease1, [heart, 2]]|
|2 |[2, drug2, [a]] |[2, disease2, [eye, 3]] |
+---+------------------+-------------------------+
// Update the integer field ("count") at the lowest level:
val df2 = mutate(df, c => if (c.toString == "disease.organ.count") c - 1 else c)
df2.show(false)
+---+------------------+-------------------------+
|id |drug |disease |
+---+------------------+-------------------------+
|1 |[1, drug1, [x, y]]|[1, disease1, [heart, 1]]|
|2 |[2, drug2, [a]] |[2, disease2, [eye, 2]] |
+---+------------------+-------------------------+
// This will NOT necessarily be equal unless the metadata and nullability
// of all fields is preserved (as the code above does)
assertResult(df.schema.toString)(df2.schema.toString)
A limitation of this is that it cannot add new fields, only update existing ones (though the map can be changed into a flatMap and the function to return Array[Column] for that, if you don't care about preserving nullability/metadata).
Additionally, here is a more generic version for Dataset[T]:
case class Record(id: Int, drug: Drug, disease: Disease)
def mutateDS[T](df: Dataset[T], fn: Column => Column)(implicit enc: Encoder[T]): Dataset[T] = {
df.sqlContext.createDataFrame(df.select(traverse(df.schema, fn):_*).rdd, enc.schema).as[T]
}
// To call as typed dataset:
val fn: Column => Column = c => if (c.toString == "disease.organ.count") c - 1 else c
mutateDS(df.as[Record], fn).show(false)
// To call as untyped dataset:
implicit val encoder: ExpressionEncoder[Row] = RowEncoder(df.schema) // This is necessary regardless of sparkSession.implicits._ imports
mutateDS(df, fn).show(false)

One way, although cumbersome, is to fully unpack and recreate the column by explicitly referencing each element of the original struct.
dataFrame.withColumn("person",
struct(
col("person.age").alias("age"),
struct(
col("person.name.first").alias("first"),
lit("some new value").alias("last")).alias("name")))
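Applied to the schema from the question, the same idea could look like the sketch below. The sample row, the local SparkSession and the variable names (`rebuilt`, `patched`) are assumptions made for illustration. The second variant relies on `Column.withField`, which is only available from Spark 3.1 onwards.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, struct}

// Assumption: a local session purely for demonstration
val spark = SparkSession.builder().master("local[*]").appName("nested-update").getOrCreate()
import spark.implicits._

// Rebuild the question's shape: request { dummyID, data { fooID, barID } }
val df = Seq(("test_id", "abc", "1485351"))
  .toDF("dummyID", "fooID", "barID")
  .select(
    struct(
      col("dummyID"),
      struct(col("fooID"), col("barID")).alias("data")
    ).alias("request"))

// Variant 1: re-create every level explicitly, replacing only fooID
val rebuilt = df.withColumn("request",
  struct(
    col("request.dummyID").alias("dummyID"),
    struct(
      lit("def").alias("fooID"),
      col("request.data.barID").alias("barID")
    ).alias("data")))

// Variant 2 (Spark 3.1+ only): replace the nested field in place
val patched = df.withColumn("request", col("request").withField("data.fooID", lit("def")))

rebuilt.select("request.dummyID", "request.data.fooID", "request.data.barID").show(false)
```

Variant 1 has to list every sibling field, so it grows with the width of the struct; variant 2 touches only the field being changed.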

Related

Scala Spark Dataframe new column from object column [duplicate]

I am trying to implement a custom UDT and be able to reference it from Spark SQL (as explained in the Spark SQL whitepaper, section 4.4.2).
The real example is to have a custom UDT backed by an off-heap data structure using Cap'n Proto, or similar.
For this posting, I have made up a contrived example. I know that I could just use Scala case classes and not have to do any work at all, but that isn't my goal.
For example, I have a Person containing several attributes and I want to be able to SELECT person.first_name FROM person. I'm running into the error Can't extract value from person#1 and I'm not sure why.
Here is the full source (also available at https://github.com/andygrove/spark-sql-udt)
package com.theotherandygrove
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
object Example {
def main(arg: Array[String]): Unit = {
val conf = new SparkConf()
.setAppName("Example")
.setMaster("local[*]")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
val schema = StructType(List(
StructField("person_id", DataTypes.IntegerType, true),
StructField("person", new MockPersonUDT, true)))
// load initial RDD
val rdd = sc.parallelize(List(
MockPersonImpl(1),
MockPersonImpl(2)
))
// convert to RDD[Row]
val rowRdd = rdd.map(person => Row(person.getAge, person))
// convert to DataFrame (RDD + Schema)
val dataFrame = sqlContext.createDataFrame(rowRdd, schema)
// register as a table
dataFrame.registerTempTable("person")
// selecting an attribute of the UDT fails with: Can't extract value from person#1
val results = sqlContext.sql("SELECT person.first_name FROM person WHERE person.age < 100")
val people = results.collect
people.map(row => {
println(row)
})
}
}
trait MockPerson {
def getFirstName: String
def getLastName: String
def getAge: Integer
def getState: String
}
class MockPersonUDT extends UserDefinedType[MockPerson] {
override def sqlType: DataType = StructType(List(
StructField("firstName", StringType, nullable=false),
StructField("lastName", StringType, nullable=false),
StructField("age", IntegerType, nullable=false),
StructField("state", StringType, nullable=false)
))
override def userClass: Class[MockPerson] = classOf[MockPerson]
override def serialize(obj: Any): Any = obj.asInstanceOf[MockPersonImpl].getAge
override def deserialize(datum: Any): MockPerson = MockPersonImpl(datum.asInstanceOf[Integer])
}
@SQLUserDefinedType(udt = classOf[MockPersonUDT])
@SerialVersionUID(123L)
case class MockPersonImpl(n: Integer) extends MockPerson with Serializable {
def getFirstName = "First" + n
def getLastName = "Last" + n
def getAge = n
def getState = "AK"
}
If I simply SELECT person FROM person then the query works. I just can't reference the attributes in SQL, even though they are defined in the schema.
You get this error because the schema defined by sqlType is never exposed and is not intended to be accessed directly. It simply provides a way to express complex data types using native Spark SQL types.
You can access individual attributes using UDFs, but first let's show that the internal structure is indeed not exposed:
dataFrame.printSchema
// root
// |-- person_id: integer (nullable = true)
// |-- person: mockperso (nullable = true)
To create a UDF we need functions that take as an argument an object of the type represented by the given UDT:
import org.apache.spark.sql.functions.udf
val getFirstName = (person: MockPerson) => person.getFirstName
val getLastName = (person: MockPerson) => person.getLastName
val getAge = (person: MockPerson) => person.getAge
which can be wrapped using udf function:
val getFirstNameUDF = udf(getFirstName)
val getLastNameUDF = udf(getLastName)
val getAgeUDF = udf(getAge)
dataFrame.select(
getFirstNameUDF($"person").alias("first_name"),
getLastNameUDF($"person").alias("last_name"),
getAgeUDF($"person").alias("age")
).show()
// +----------+---------+---+
// |first_name|last_name|age|
// +----------+---------+---+
// | First1| Last1| 1|
// | First2| Last2| 2|
// +----------+---------+---+
To use these with raw SQL you have to register the functions through SQLContext:
sqlContext.udf.register("first_name", getFirstName)
sqlContext.udf.register("last_name", getLastName)
sqlContext.udf.register("age", getAge)
sqlContext.sql("""
SELECT first_name(person) AS first_name, last_name(person) AS last_name
FROM person
WHERE age(person) < 100""").show
// +----------+---------+
// |first_name|last_name|
// +----------+---------+
// | First1| Last1|
// | First2| Last2|
// +----------+---------+
Unfortunately it comes with a price tag attached. First of all, every operation requires deserialization. It also substantially limits the ways in which a query can be optimized. In particular, any join operation on one of these fields requires a Cartesian product.
In practice if you want to encode a complex structure, which contains attributes that can be expressed using built-in types, it is better to use StructType:
case class Person(first_name: String, last_name: String, age: Int)
val df = sc.parallelize(
(1 to 2).map(i => (i, Person(s"First$i", s"Last$i", i)))).toDF("id", "person")
df.printSchema
// root
// |-- id: integer (nullable = false)
// |-- person: struct (nullable = true)
// | |-- first_name: string (nullable = true)
// | |-- last_name: string (nullable = true)
// | |-- age: integer (nullable = false)
df
.where($"person.age" < 100)
.select($"person.first_name", $"person.last_name")
.show
// +----------+---------+
// |first_name|last_name|
// +----------+---------+
// | First1| Last1|
// | First2| Last2|
// +----------+---------+
and reserve UDTs for actual type extensions like the built-in VectorUDT, or for things that can benefit from a specific representation, like enumerations.

Apply function to each row of Spark Dataframe

I have a Spark DataFrame (df) with two columns (Report_id and Cluster_number).
I want to apply a function (getClusterInfo) to df which will return the name for each cluster, i.e. if the cluster number is '3' for a specific report_id, then the 3 rows below will be written:
{"cluster_id":"1","influencers":[{"screenName":"A"},{"screenName":"B"},{"screenName":"C"},...]}
{"cluster_id":"2","influencers":[{"screenName":"D"},{"screenName":"E"},{"screenName":"F"},...]}
{"cluster_id":"3","influencers":[{"screenName":"G"},{"screenName":"H"},{"screenName":"E"},...]}
I am using foreach on df to apply the getClusterInfo function, but can't figure out how to convert the output to a DataFrame (Report_id, Array[cluster_info]).
Here is the code snippet:
df.foreach(row => {
val report_id = row(0)
val cluster_no = row(1).toString
val cluster_numbers = new Range(0, cluster_no.toInt - 1, 1)
for (cluster <- cluster_numbers.by(1)) {
val cluster_id = report_id + "_" + cluster
//get cluster influencers
val result = getClusterInfo(cluster_id)
println(result.get)
val res : String = result.get.toString()
// TODO ?
}
.. //TODO ?
})
Generally speaking, you shouldn't use foreach when you want to map something into something else; foreach is good for applying functions that only have side effects and return nothing.
In this case, if I got the details right (probably not), you can use a User-Defined Function (UDF) and explode the result:
import org.apache.spark.sql.functions._
import spark.implicits._
// I'm assuming we have these case classes (or similar)
case class Influencer(screenName: String)
case class ClusterInfo(cluster_id: String, influencers: Array[Influencer])
// I'm assuming this method is supplied - with your own implementation
def getClusterInfo(clusterId: String): ClusterInfo =
ClusterInfo(clusterId, Array(Influencer(clusterId)))
// some sample data - assuming both columns are integers:
val df = Seq((222, 3), (333, 4)).toDF("Report_id", "Cluster_number")
// actual solution:
// UDF that returns an array of ClusterInfo;
// Array size is 'clusterNo', creates cluster id for each element and maps it to info
val clusterInfoUdf = udf { (clusterNo: Int, reportId: Int) =>
(1 to clusterNo).map(v => s"${reportId}_$v").map(getClusterInfo)
}
// apply UDF to each record and explode - to create one record per array item
val result = df.select(explode(clusterInfoUdf($"Cluster_number", $"Report_id")))
result.printSchema()
// root
// |-- col: struct (nullable = true)
// | |-- cluster_id: string (nullable = true)
// | |-- influencers: array (nullable = true)
// | | |-- element: struct (containsNull = true)
// | | | |-- screenName: string (nullable = true)
result.show(truncate = false)
// +-----------------------------+
// |col |
// +-----------------------------+
// |[222_1,WrappedArray([222_1])]|
// |[222_2,WrappedArray([222_2])]|
// |[222_3,WrappedArray([222_3])]|
// |[333_1,WrappedArray([333_1])]|
// |[333_2,WrappedArray([333_2])]|
// |[333_3,WrappedArray([333_3])]|
// |[333_4,WrappedArray([333_4])]|
// +-----------------------------+

Spark - recursive function as udf generates an Exception

I am working with DataFrames whose elements have a schema similar to:
root
|-- NPAData: struct (nullable = true)
| |-- NPADetails: struct (nullable = true)
| | |-- location: string (nullable = true)
| | |-- manager: string (nullable = true)
| |-- service: array (nullable = true)
| | |-- element: struct (containsNull = true)
| | | |-- serviceName: string (nullable = true)
| | | |-- serviceCode: string (nullable = true)
|-- NPAHeader: struct (nullable = true)
| |-- npaNumber: string (nullable = true)
| |-- date: string (nullable = true)
In my DataFrame I want to group all elements that have the same NPAHeader.code. To do that I am using the following line:
val groupedNpa = orderedNpa.groupBy($"NPAHeader.code" ).agg(collect_list(struct($"NPAData",$"NPAHeader")).as("npa"))
After this I have a dataframe with the following schema:
StructType(StructField(npaNumber,StringType,true), StructField(npa,ArrayType(StructType(StructField(NPAData...)))))
An example of each Row would be something similar to:
[1234,WrappedArray([npaNew,npaOlder,...npaOldest])]
Now I want to generate another DataFrame that picks just one of the elements in the WrappedArray, so I want an output similar to:
[1234,npaNew]
Note: The chosen element from the WrappedArray is the one that matches some complex logic after iterating over the whole WrappedArray. To simplify the question, I will always pick the last element of the WrappedArray (after iterating over all of it).
To do so, I want to define a recursive udf:
import org.apache.spark.sql.functions.udf
def returnRow(elementList : Row)(index:Int): Row = {
val dif = elementList.size - index
val row :Row = dif match{
case 0 => elementList.getAs[Row](index)
case _ => returnRow(elementList)(index + 1)
}
row
}
val returnRow_udf = udf(returnRow _)
groupedNpa.map{row => (row.getAs[String]("npaNumber"),returnRow_udf(groupedNpa("npa")(0)))}
But I am getting the following error in the map:
Exception in thread "main" java.lang.UnsupportedOperationException:
Schema for type Int => Unit is not supported
What am I doing wrong?
As an aside, I am not sure whether I am passing the npa column correctly with groupedNpa("npa"). I am accessing the WrappedArray as a Row because I don't know how to iterate over Array[Row] (the get(index) method is not present in Array[Row]).
TL;DR Just use one of the methods described in How to select the first row of each group?
If you want to use complex logic and return Row, you can skip the SQL API and use groupByKey:
val f: (String, Iterator[org.apache.spark.sql.Row]) => Row = ??? // your selection logic
val encoder: Encoder[Row] = ??? // a valid RowEncoder, see below
df.groupByKey(_.getAs[String]("NPAHeader.code")).mapGroups(f)(encoder)
or better:
val g: (Row, Row) => Row = ??? // keeps the preferred one of two rows
df.groupByKey(_.getAs[String]("NPAHeader.code")).reduceGroups(g)
where encoder is a valid RowEncoder (Encoder error while trying to map dataframe row to updated row).
Your code is faulty in multiple ways:
groupBy doesn't guarantee the order of values. So:
orderBy(...).groupBy(....).agg(collect_list(...))
can have non-deterministic output. If you really decide to go this route, you should skip orderBy and sort the collected array explicitly.
You cannot pass a curried function to udf. You'd have to uncurry it first, but that would require a different order of arguments (see example below).
If you could, this might be the correct way to call it (Note that you omit the second argument):
returnRow_udf(groupedNpa("npa")(0))
To make it worse, you call it inside map, where udfs are not applicable at all.
A udf cannot return Row. It has to return an external Scala type.
The external representation for array<struct> is Seq[Row]. You cannot just substitute it with Row.
SQL arrays can be accessed by index with apply:
df.select($"array"(size($"array") - 1))
but it is not a correct method due to non-determinism. You could apply sort_array, but as pointed out at the beginning, there are more efficient solutions.
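As a rough sketch of "sort the collected array explicitly": the example below puts the ordering key first in the struct, collects, sorts the array in descending order and takes the first element. The flattened columns (`code`, `date`, `payload`) and the sample rows are hypothetical stand-ins for the question's nested schema, and it assumes the date strings sort correctly as text (ISO format).

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_list, sort_array, struct}

// Assumption: a local session purely for demonstration
val spark = SparkSession.builder().master("local[*]").appName("latest-per-group").getOrCreate()
import spark.implicits._

// Hypothetical flattened data: (code, date, payload)
val npa = Seq(
  ("1234", "2017-01-01", "npaOldest"),
  ("1234", "2017-03-01", "npaNew"),
  ("1234", "2017-02-01", "npaOlder"),
  ("5678", "2017-02-15", "npaOnly")
).toDF("code", "date", "payload")

val latest = npa
  .groupBy(col("code"))
  // Structs compare field by field, so `date` (first field) drives the order
  .agg(sort_array(collect_list(struct(col("date"), col("payload"))), asc = false).as("npa"))
  // After a descending sort the newest entry sits at index 0
  .select(col("code"), col("npa").getItem(0).getField("payload").as("latest"))

latest.show(false)
```

Because the sort happens inside the aggregated array, the result no longer depends on the order in which rows reach the aggregation.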
Surprisingly, recursion is not so relevant. You could design a function like this:
def size(i: Int=0)(xs: Seq[Any]): Int = xs match {
case Seq() => i
case null => i
case Seq(h, t @ _*) => size(i + 1)(t)
}
val size_ = udf(size() _)
and it would work just fine:
Seq((1, Seq("a", "b", "c"))).toDF("id", "array")
.select(size_($"array"))
although recursion is overkill if you can just iterate over the Seq.
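For comparison, a non-recursive version in plain Scala might look like this. The helper name `pick` and the `better` parameter are made up for illustration; the selection rule used here (always prefer the later element) is only a placeholder for the real logic.

```scala
// Pick one element with a single pass over the sequence.
// `better` decides which of two candidates to keep.
def pick[A](xs: Seq[A])(better: (A, A) => A): Option[A] =
  Option(xs).flatMap(_.reduceOption(better)) // Option(xs) guards against a null array

// Placeholder logic: always keep the later element, i.e. the last one wins
val last = pick(Seq("npaOldest", "npaOlder", "npaNew"))((_, b) => b)

println(last)
```

Returning an Option makes the empty and null cases explicit instead of failing at runtime.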

Why Spark SQL translates String "null" to Object null for Float/Double types?

I have a dataframe containing float and double values.
scala> val df = List((Float.NaN, Double.NaN), (1f, 0d)).toDF("x", "y")
df: org.apache.spark.sql.DataFrame = [x: float, y: double]
scala> df.show
+---+---+
| x| y|
+---+---+
|NaN|NaN|
|1.0|0.0|
+---+---+
scala> df.printSchema
root
|-- x: float (nullable = false)
|-- y: double (nullable = false)
When I replace NaN values with null, I pass "null" as a String in the Map given to the fill operation.
scala> val map = df.columns.map((_, "null")).toMap
map: scala.collection.immutable.Map[String,String] = Map(x -> null, y -> null)
scala> df.na.fill(map).printSchema
root
|-- x: float (nullable = true)
|-- y: double (nullable = true)
scala> df.na.fill(map).show
+----+----+
| x| y|
+----+----+
|null|null|
| 1.0| 0.0|
+----+----+
And I got the correct value. But I was not able to understand how/why Spark SQL translates "null" as a String into a null object.
If you look into the fill function (in DataFrameNaFunctions, which is what df.na returns), it checks the datatype and tries to convert the replacement to the datatype of the column in the schema. If the value can be converted, it is converted; otherwise null is returned.
It does not convert "null" to a null object; it returns null if an exception occurs while converting.
val map = df.columns.map((_, "WHATEVER")).toMap
gives null
and val map = df.columns.map((_, "9999.99")).toMap
gives 9999.99
If you want to update the NaN values with a value of the same datatype, you get the expected result.
Hope this helps you to understand!
It's not that "null" as a String translates to a null object. You can try using the transformation with any String and still get null (with the exception of strings that can directly be cast to double/float, see below). For example, using
val map = df.columns.map((_, "abc")).toMap
would give the same result. My guess is that, because the columns are of type float and double, casting a non-numeric replacement string to those types gives null. Using a number instead works as expected, e.g.
val map = df.columns.map((_, 1)).toMap
As some strings can directly be cast to double or float those are also possible to use in this case.
val map = df.columns.map((_, "1")).toMap
I've looked into the source code: in fill, your string is cast to a double/float:
private def fillCol[T](col: StructField, replacement: T): Column = {
col.dataType match {
case DoubleType | FloatType =>
coalesce(nanvl(df.col("`" + col.name + "`"), lit(null)),
lit(replacement).cast(col.dataType)).as(col.name)
case _ =>
coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
}
}
The relevant source code for casting is here (similar code for Floats):
Cast.scala (taken from Spark 1.6.3) :
// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try s.toString.toDouble catch {
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1d else 0d)
case DateType =>
buildCast[Int](_, d => null)
case TimestampType =>
buildCast[Long](_, t => timestampToDouble(t))
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
}
So Spark tries to convert the String to a Double (s.toString.toDouble); if that's not possible (i.e. you get a NumberFormatException), you get a null. So instead of "null" you could also use "foo" with the same outcome. But if you use "1.0" in your map, then NaNs and nulls will be replaced with 1.0, because the String "1.0" is indeed parsable as a Double.
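The behaviour can be reproduced without Spark in a few lines of plain Scala. This is only a model of the two snippets quoted above, not Spark's actual code: `None` stands in for SQL null, and the function names are chosen for illustration.

```scala
// Models the StringType branch of castToDouble:
// strings that cannot be parsed become "null" (None here)
def castToDouble(s: String): Option[Double] =
  try Some(s.toDouble) catch { case _: NumberFormatException => None }

// Models coalesce(nanvl(column, null), replacement.cast(double)):
// NaN is first turned into null, then the cast replacement fills the gap
def fill(value: Option[Double], replacement: String): Option[Double] =
  value.filterNot(_.isNaN).orElse(castToDouble(replacement))

println(fill(Some(Double.NaN), "null")) // replacement is unparsable, so the result stays empty
println(fill(Some(Double.NaN), "1.0"))  // replacement parses, so NaN is replaced
println(fill(Some(1.0), "null"))        // a regular value is left untouched
```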

How to Validate contents of Spark Dataframe

I have the Scala Spark code below, which works, but should not.
The 2nd column has data of mixed types, whereas in the schema I have defined it as IntegerType. My actual program has over 100 columns, and keeps deriving multiple child DataFrames after transformations.
How can I validate that the fields of an RDD or DataFrame contain values of the correct datatype, and then either ignore invalid rows or change the column contents to some default value? Any more pointers for data quality checks with DataFrame or RDD are appreciated.
var theSeq = Seq(("X01", "41"),
("X01", 41),
("X01", 41),
("X02", "ab"),
("X02", "%%"))
val newRdd = sc.parallelize(theSeq)
val rowRdd = newRdd.map(r => Row(r._1, r._2))
val theSchema = StructType(Seq(StructField("ID", StringType, true),
StructField("Age", IntegerType, true)))
val theNewDF = sqc.createDataFrame(rowRdd, theSchema)
theNewDF.show()
First of all, passing a schema is simply a way to avoid type inference. It is not validated or enforced during DataFrame creation. On a side note, I wouldn't describe a ClassCastException as working well. For a moment I thought you had actually found a bug.
I think the important question is how you get data like theSeq / newRdd in the first place. Is it something you parse yourself, or is it received from an external component? Simply looking at the type (Seq[(String, Any)] / RDD[(String, Any)] respectively) you already know it is not a valid input for a DataFrame. Probably the way to handle things at this level is to embrace static typing. Scala provides quite a few neat ways to handle unexpected conditions (Try, Either, Option), where the last one is the simplest and, as a bonus, works well with Spark SQL. A rather simplistic way to handle things could look like this:
def validateInt(x: Any) = x match {
case x: Int => Some(x)
case _ => None
}
def validateString(x: Any) = x match {
case x: String => Some(x)
case _ => None
}
val newRddOption: RDD[(Option[String], Option[Int])] = newRdd.map{
case (id, age) => (validateString(id), validateInt(age))}
Since Options can be easily composed you can add additional checks like this:
def validateAge(age: Int) = {
if(age >= 0 && age < 150) Some(age)
else None
}
val newRddValidated: RDD[(Option[String], Option[Int])] = newRddOption.map{
case (id, age) => (id, age.flatMap(validateAge))}
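If the reason for a rejection matters, the same composition can be sketched with Either instead of Option. The function names (`intOrError`, `ageOrError`, `checkAge`) and the error messages are made up for illustration; the sketch assumes Scala 2.12 or later, where Either is right-biased and has flatMap.

```scala
// Type check that reports why a value was rejected
def intOrError(x: Any): Either[String, Int] = x match {
  case i: Int => Right(i)
  case other  => Left(s"not an Int: $other")
}

// Range check with the same bounds as validateAge
def ageOrError(age: Int): Either[String, Int] =
  if (age >= 0 && age < 150) Right(age) else Left(s"age out of range: $age")

// Checks compose with flatMap; the first failure short-circuits the rest
def checkAge(x: Any): Either[String, Int] = intOrError(x).flatMap(ageOrError)

println(checkAge(41))
println(checkAge("ab"))
println(checkAge(200))
```

Calling `.toOption` on the result brings it back to the Option-based representation used in the rest of this answer.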
Next, instead of Row, which is a very crude container, I would use case classes:
case class Record(id: Option[String], age: Option[Int])
val records: RDD[Record] = newRddValidated.map{case (id, age) => Record(id, age)}
At this moment all you have to do is call toDF:
import org.apache.spark.sql.DataFrame
val df: DataFrame = records.toDF
df.printSchema
// root
// |-- id: string (nullable = true)
// |-- age: integer (nullable = true)
This was the hard, but arguably more elegant, way. A faster one is to let the SQL casting system do the job for you. First let's convert everything to Strings:
val stringRdd: RDD[(String, String)] = sc.parallelize(theSeq).map(
p => (p._1.toString, p._2.toString))
Next create a DataFrame:
import org.apache.spark.sql.types._
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col
val df: DataFrame = stringRdd.toDF("id", "age")
val expectedTypes = Seq(StringType, IntegerType)
val exprs: Seq[Column] = df.columns.zip(expectedTypes).map{
case (c, t) => col(c).cast(t).alias(c)}
val dfProcessed: DataFrame = df.select(exprs: _*)
And the result:
dfProcessed.printSchema
// root
// |-- id: string (nullable = true)
// |-- age: integer (nullable = true)
dfProcessed.show
// +---+----+
// | id| age|
// +---+----+
// |X01| 41|
// |X01| 41|
// |X01| 41|
// |X02|null|
// |X02|null|
// +---+----+
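Once invalid values have become null, the two options mentioned in the question (ignoring invalid rows, or substituting a default) can be sketched with the DataFrame `na` functions. The sample rows, the default value -1 and the variable names are assumptions for illustration. The sketch also assumes lenient (non-ANSI) cast semantics, since with ANSI mode enabled an invalid cast raises an error instead of producing null.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.IntegerType

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("validate")
  // Assumption: lenient casts, so unparsable strings become null
  .config("spark.sql.ansi.enabled", "false")
  .getOrCreate()
import spark.implicits._

// Everything starts as a string, as in the approach above
val raw = Seq(("X01", "41"), ("X02", "ab"), ("X02", "%%")).toDF("id", "age")
val typed = raw.select(col("id"), col("age").cast(IntegerType).alias("age"))

// Option 1: drop rows whose age could not be cast
val validOnly = typed.na.drop(Seq("age"))

// Option 2: keep the rows and substitute a (hypothetical) default value
val withDefault = typed.na.fill(-1L, Seq("age"))

validOnly.show()
withDefault.show()
```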
In version 1.4 or older:
import org.apache.spark.sql.execution.debug._
theNewDF.typeCheck
It was removed via SPARK-9754 though. I haven't checked but I think typeCheck becomes sqlContext.debug beforehand