This has me confused. I'm using "spark-testing-base_2.11" % "2.0.0_0.5.0" for the tests. Can anyone explain why the map function changes the schema when I use a Dataset, but preserves it when I go through the RDD? Any insights greatly appreciated.
import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.sql.{ Encoders, SparkSession }
import org.scalatest.{ FunSpec, Matchers }
class TransformSpec extends FunSpec with Matchers with SharedSparkContext {
describe("data transformation") {
it("the rdd maintains the schema") {
val spark = SparkSession.builder.getOrCreate()
import spark.implicits._
val personEncoder = Encoders.product[TestPerson]
val personDS = Seq(TestPerson("JoeBob", 29)).toDS
personDS.schema shouldEqual personEncoder.schema
val mappedSet = personDS.rdd.map { p: TestPerson => p.copy(age = p.age + 1) }.toDS
personEncoder.schema shouldEqual mappedSet.schema
}
it("datasets choke on explicit schema") {
val spark = SparkSession.builder.getOrCreate()
import spark.implicits._
val personEncoder = Encoders.product[TestPerson]
val personDS = Seq(TestPerson("JoeBob", 29)).toDS
personDS.schema shouldEqual personEncoder.schema
val mappedSet = personDS.map[TestPerson] { p: TestPerson => p.copy(age = p.age + 1) }
personEncoder.schema shouldEqual mappedSet.schema
}
}
}
case class TestPerson(name: String, age: Int)
A couple of things are conspiring against you here. Spark appears to have special casing for what types it considers nullable.
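MyString and MyInt are not defined in the original; judging from the schema printed below, they are presumably simple single-field wrapper case classes along these lines:
case class MyString(value: String)
case class MyInt(value: Int)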
case class TestTypes(
scalaString: String,
javaString: java.lang.String,
myString: MyString,
scalaInt: Int,
javaInt: java.lang.Integer,
myInt: MyInt
)
Encoders.product[TestTypes].schema.printTreeString results in:
root
|-- scalaString: string (nullable = true)
|-- javaString: string (nullable = true)
|-- myString: struct (nullable = true)
| |-- value: string (nullable = true)
|-- scalaInt: integer (nullable = false)
|-- javaInt: integer (nullable = true)
|-- myInt: struct (nullable = true)
| |-- value: integer (nullable = false)
But if you map the values you will end up with everything nullable:
val testTypes: Seq[TestTypes] = Nil
val testDS = testTypes.toDS
testDS.map(foo => foo).schema.printTreeString results in everything being nullable:
root
|-- scalaString: string (nullable = true)
|-- javaString: string (nullable = true)
|-- myString: struct (nullable = true)
| |-- value: string (nullable = true)
|-- scalaInt: integer (nullable = true)
|-- javaInt: integer (nullable = true)
|-- myInt: struct (nullable = true)
| |-- value: integer (nullable = true)
Even if you force the schema to be correct, Spark explicitly ignores nullability comparisons when applying a schema, which is why you lose the few nullability guarantees you had when you convert back to the typed representation.
You could enrich your types to be able to force a nonNull schema:
implicit class StructImprovements(s: StructType) {
def nonNull: StructType = StructType(s.map(_.copy(nullable = false)))
}
implicit class DsImprovements[T: Encoder](ds: Dataset[T]) {
def nonNull: Dataset[T] = {
val nnSchema = ds.schema.nonNull
ds.sparkSession.createDataFrame(ds.toDF.rdd, nnSchema).as[T] // createDataFrame stands in for the deprecated SQLContext.applySchema
}
}
val mappedSet = personDS.map { p =>
p.copy(age = p.age + 1)
}.nonNull
But you will find it evaporates as soon as you apply any interesting operation, and when schemas are compared again, Spark treats shapes that differ only in nullability as the same.
This appears to be by design https://github.com/apache/spark/pull/11785
A map is a transformation operation on the data. It takes the input and a function, and applies that function to every element of the input. The output is the set of return values of that function, so the schema of the output data depends on the return type of the function.
The map operation is a standard and heavily used operation in functional programming. See https://en.m.wikipedia.org/wiki/Map_(higher-order_function) if you want to read more.
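A minimal sketch, reusing the TestPerson case class from the question (and assuming spark.implicits._ is in scope): the schema of the result comes from the encoder of the function's return type, not from the input Dataset.
val people = Seq(TestPerson("JoeBob", 29)).toDS
people.map(p => p.name).printSchema
// root
//  |-- value: string (nullable = true)
people.map(p => (p.name, p.age + 1)).printSchema
// root
//  |-- _1: string (nullable = true)
//  |-- _2: integer (nullable = false)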
Related
I'm new to Scala and Spark and I have an issue: I tried to load a dataset and filter it by timestamp, but I got an error.
I have one case class that contains two other case classes; after saving the generated dataset, I tried to reload it.
This is the schema of the created dataset:
// root
// |-- date: date (nullable = false)
// |-- info: struct (nullable = false)
// | |-- name: string (nullable = true)
// | |-- inlife: bool (nullable = true)
// |-- extra: struct (nullable = false)
// | |-- prob: string (nullable = true)
// | |-- ext: string (nullable = true)
// |-- dt: string (nullable = true)
But I got this error:
No such struct field name in prob, ext;
Here is the piece of my code:
val spark = SparkSession.builder()
.master("local[2]")
.getOrCreate()
case class Info(
name: Option[String] = None,
inLife: Option[Boolean] = None)
case class Extra(
prob: Option[Float] = None,
ext: Option[String] = None)
case class Person(
date : java.sql.Timestamp = java.sql.Timestamp.from(Instant.EPOCH),
inf : Info = Info(),
pExt : Extra = Extra(),
dt : Option[java.sql.Timestamp] = Option(java.sql.Timestamp.from(Instant.EPOCH))
)
implicit val encI = Encoders.product[Info]
implicit val encE = Encoders.product[Extra]
implicit val encP = Encoders.product[Person]
val p = Seq(Person())
val ds = spark.createDataset(p)
ds.show
ds.printSchema
val uri = new URI("/Users/me/testSaving/test")
ds.repartition(10)
.write.mode(SaveMode.Append)
.partitionBy("date", "dt")
.parquet(uri.toString)
When I try to load it without filtering, the dataset loads without a problem:
val loadedDS = spark.read.format("parquet")
.schema(spark.emptyDataset[Person].schema)
.option("basePath", uri.toString)
.load(uri.toString)
But when I load it with filtering, there are many errors:
val loadedDS = spark.read.format("parquet")
.schema(spark.emptyDataset[Person].schema)
.option("basePath", uri.toString)
.load(uri.toString)
.filter({
val ts = unix_timestamp(to_timestamp(col("dt"), "yyyy-MM-dd HH:mm:ss"))
lit(Instant.MIN.getEpochSecond) <= ts and ts < lit(Instant.MAX.getEpochSecond)}).as[Person]
Note that this part of the code was working well before I added the dt field to the Person class.
I am trying to move data from GP to HDFS using Scala & Spark.
val execQuery = "select * from schema.tablename"
val yearDF = spark.read.format("jdbc").option("url", connectionUrl).option("dbtable", s"(${execQuery}) as year2016").option("user", devUserName).option("password", devPassword).option("partitionColumn","header_id").option("lowerBound", 19919927).option("upperBound", 28684058).option("numPartitions",30).load()
val yearDFSchema = yearDF.schema
The schema for yearDF is:
root
|-- source_system_name: string (nullable = true)
|-- table_refresh_delay_min: decimal(38,30) (nullable = true)
|-- release_number: decimal(38,30) (nullable = true)
|-- change_number: decimal(38,30) (nullable = true)
|-- interface_queue_enabled_flag: string (nullable = true)
|-- rework_enabled_flag: string (nullable = true)
|-- fdm_application_id: decimal(15,0) (nullable = true)
|-- history_enabled_flag: string (nullable = true)
The schema of the same table on Hive, as given by our project:
val hiveColumns = "source_system_name:String|description:String|creation_date:Timestamp|status:String|status_date:Timestamp|table_refresh_delay_min:Timestamp|release_number:Double|change_number:Double|interface_queue_enabled_flag:String|rework_enabled_flag:String|fdm_application_id:Bigint|history_enabled_flag:String"
So I took hiveColumns and created a new StructType as given below:
def convertDatatype(datatype: String): DataType = {
val convert = datatype.toLowerCase match { // lower-case so the capitalised type names in hiveColumns ("String", "Double", "Bigint", ...) match
case "string" => StringType
case "bigint" => LongType
case "int" => IntegerType
case "double" => DoubleType
case "date" => TimestampType
case "boolean" => BooleanType
case "timestamp" => TimestampType
}
convert
}
val schemaList = hiveColumns.split("\\|")
val newSchema = new StructType(schemaList.map(col => col.split(":")).map(e => StructField(e(0), convertDatatype(e(1)), true)))
newSchema.printTreeString()
root
|-- source_system_name: string (nullable = true)
|-- table_refresh_delay_min: double (nullable = true)
|-- release_number: double (nullable = true)
|-- change_number: double (nullable = true)
|-- interface_queue_enabled_flag: string (nullable = true)
|-- rework_enabled_flag: string (nullable = true)
|-- fdm_application_id: long (nullable = true)
|-- history_enabled_flag: string (nullable = true)
When I try to apply my new schema (newSchema) to yearDF as below, I get the exception:
Caused by: java.lang.RuntimeException: java.math.BigDecimal is not a valid external type for schema of double
The exception occurs due to conversion of decimal to double.
What I don't understand is how I can convert the datatypes of the columns table_refresh_delay_min, release_number, change_number, and fdm_application_id in the StructType newSchema from DoubleType to the corresponding datatypes present in yearDF's schema. That is:
If a column in yearDFSchema has a decimal datatype with a non-zero scale, in this case decimal(38,30), I need to convert the same column's datatype in newSchema to DecimalType(38,30).
Could anyone let me know how I can achieve this?
Errors like this occur when you try to apply a schema to an RDD[Row] using the developer API functions:
def createDataFrame(rows: List[Row], schema: StructType): DataFrame
def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame
def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame
In such cases the stored data types have to match the external (i.e. value type in Scala) data types listed in the official Spark SQL data types reference, and no type casting or coercion is applied.
Therefore it is your responsibility as a user to ensure that the data and schema are compatible.
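For illustration, a minimal sketch of that failure mode (hypothetical single-column schema; the exception matches the one in the question):
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// The Row stores a java.math.BigDecimal, but the schema declares DoubleType.
// The mismatch is accepted at analysis time and only fails during evaluation with
// "java.math.BigDecimal is not a valid external type for schema of double".
val rows = sc.parallelize(Seq(Row(new java.math.BigDecimal("1.5"))))
val doubleSchema = StructType(Seq(StructField("x", DoubleType)))
spark.createDataFrame(rows, doubleSchema).show()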
The description of the problem you've provided indicates a rather different scenario, one that calls for a CAST. Let's create a dataset with exactly the same schema as in your example:
val yearDF = spark.createDataFrame(
sc.parallelize(Seq[Row]()),
StructType(Seq(
StructField("source_system_name", StringType),
StructField("table_refresh_delay_min", DecimalType(38, 30)),
StructField("release_number", DecimalType(38, 30)),
StructField("change_number", DecimalType(38, 30)),
StructField("interface_queue_enabled_flag", StringType),
StructField("rework_enabled_flag", StringType),
StructField("fdm_application_id", DecimalType(15, 0)),
StructField("history_enabled_flag", StringType)
)))
yearDF.printSchema
root
|-- source_system_name: string (nullable = true)
|-- table_refresh_delay_min: decimal(38,30) (nullable = true)
|-- release_number: decimal(38,30) (nullable = true)
|-- change_number: decimal(38,30) (nullable = true)
|-- interface_queue_enabled_flag: string (nullable = true)
|-- rework_enabled_flag: string (nullable = true)
|-- fdm_application_id: decimal(15,0) (nullable = true)
|-- history_enabled_flag: string (nullable = true)
and desired types like
val dtypes = Seq(
"source_system_name" -> "string",
"table_refresh_delay_min" -> "double",
"release_number" -> "double",
"change_number" -> "double",
"interface_queue_enabled_flag" -> "string",
"rework_enabled_flag" -> "string",
"fdm_application_id" -> "long",
"history_enabled_flag" -> "string"
)
then you can just map:
val mapping = dtypes.toMap
yearDF.select(yearDF.columns.map { c => col(c).cast(mapping(c)) }: _*).printSchema
root
|-- source_system_name: string (nullable = true)
|-- table_refresh_delay_min: double (nullable = true)
|-- release_number: double (nullable = true)
|-- change_number: double (nullable = true)
|-- interface_queue_enabled_flag: string (nullable = true)
|-- rework_enabled_flag: string (nullable = true)
|-- fdm_application_id: long (nullable = true)
|-- history_enabled_flag: string (nullable = true)
This of course assumes that actual and desired types are compatible, and CAST is allowed.
If you still experience problems due to the peculiarities of a specific JDBC driver, you should consider placing the cast directly in the query, either manually (see: In Apache Spark 2.0.0, is it possible to fetch a query from an external database (rather than grab the whole table)?):
val externalDtypes = Seq(
"source_system_name" -> "text",
"table_refresh_delay_min" -> "double precision",
"release_number" -> "float8",
"change_number" -> "float8",
"interface_queue_enabled_flag" -> "string",
"rework_enabled_flag" -> "string",
"fdm_application_id" -> "bigint",
"history_enabled_flag" -> "string"
)
val fields = externalDtypes.map {
  case (c, t) => s"CAST(`$c` AS $t)"
}.mkString(", ")
val dbTable = s"""(select $fields from schema.tablename) as tmp"""
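The resulting string can then be passed as the dbtable option (a sketch, reusing the connection options from the question):
spark.read.format("jdbc")
  .option("url", connectionUrl)
  .option("dbtable", dbTable)
  .option("user", devUserName)
  .option("password", devPassword)
  .load()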
or through custom schema:
spark.read
.format("jdbc")
.option(
"customSchema",
dtypes.map { case (c, t) => s"`$c` $t" } .mkString(", "))
...
.load()
I am trying to replace certain characters in all the columns of my DataFrame, which has a lot of nested struct types.
I tried to process the schema fields recursively, and for some reason it only renames the fields at the top level even though it reaches the leaf nodes.
I am trying to replace the ':' character in the column names with '_'.
Here is the Scala code I have written:
class UpdateSchema {
val logger = LoggerFactory.getLogger(classOf[UpdateSchema])
Logger.getLogger("org").setLevel(Level.OFF)
Logger.getLogger("akka").setLevel(Level.OFF)
val sparkSession = SparkLauncher.spark
import sparkSession.implicits._
def updateSchema(filePath:String):Boolean ={
logger.info(".updateSchema() : filePath ={}",filePath);
logger.info(".updateSchema() : sparkSession ={}",sparkSession);
if(sparkSession!=null){
var xmlDF = sparkSession
.read
.format("com.databricks.spark.xml")
.option("rowTag","ns:fltdMessage")
.option("inferschema","true")
.option("attributePrefix","attr_")
.load(filePath)
.toDF()
xmlDF.printSchema()
val updatedDF = renameDataFrameColumns(xmlDF.toDF())
updatedDF.printSchema()
}
else
logger.info(".updateSchema(): Spark Session is NULL !!!");
false;
}
def replaceSpecialChars(str:String):String ={
val newColumn:String = str.replaceAll(":", "_")
//logger.info(".replaceSpecialChars() : Old Column Name =["+str+"] New Column Name =["+newColumn+"]")
return newColumn
}
def renameColumn(df:DataFrame,colName:String,prefix:String):DataFrame ={
val newColuName:String = replaceSpecialChars(colName)
logger.info(".renameColumn(): prefix=["+prefix+"] colName=["+colName+"] New Column Name=["+newColuName+"]")
if(prefix.equals("")){
if(df.col(colName)!=null){
return df.withColumnRenamed(colName, replaceSpecialChars(colName))
}
else{
logger.error(".logSchema() : Column ["+prefix+"."+colName+"] Not found in DataFrame !! ")
logger.info("Prefix ="+prefix+" Existing Columns =["+df.columns.mkString("),(")+"]")
throw new Exception("Unable to find Column ["+prefix+"."+colName+"]")
}
}
else{
if(df.col(prefix+"."+colName)!=null){
return df.withColumnRenamed(prefix+"."+colName, prefix+"."+replaceSpecialChars(colName))
}
else{
logger.error(".logSchema() : Column ["+prefix+"."+colName+"] Not found in DataFrame !! ")
logger.info("Prefix ="+prefix+" Existing Columns =["+df.columns.mkString("),(")+"]")
throw new Exception("Unable to find Column ["+prefix+"."+colName+"]")
}
}
}
def getStructType(schema:StructType,fieldName:String):StructType = {
schema.fields.foreach(field => {
field.dataType match{
case st:StructType => {
logger.info(".getStructType(): Current Field Name =["+field.name.toString()+"] Checking for =["+fieldName+"]")
if(field.name.toString().equals(fieldName)){
return field.dataType.asInstanceOf[StructType]
}
else{
getStructType(st,fieldName)
}
}
case _ =>{
logger.info(".getStructType(): Non Struct Type. Ignoring Filed=["+field.name.toString()+"]");
}
}
})
throw new Exception("Unable to find Struct Type for filed Name["+fieldName+"]")
}
def processSchema(df:DataFrame,schema:StructType,prefix:String):DataFrame ={
var updatedDF:DataFrame =df
schema.fields.foreach(field =>{
field.dataType match {
case st:StructType => {
logger.info(".processSchema() : Struct Type =["+st+"]");
logger.info(".processSchema() : Field Data Type =["+field.dataType+"]");
logger.info(".processSchema() : Renaming the Struct Field =["+field.name.toString()+"] st=["+st.fieldNames.mkString(",")+"]")
updatedDF = renameColumn(updatedDF,field.name.toString(),prefix)
logger.info(".processSchema() : Column List after Rename =["+updatedDF.columns.mkString(",")+"]")
// updatedDF.schema.fields.foldLeft(z)(op)
val renamedCol:String = replaceSpecialChars(field.name.toString())
var fieldType:DataType = null;
//if(prefix.equals(""))
fieldType = schema.fields.find(f =>{ (f.name.toString().equals(field.name.toString()))}).get.dataType
if(prefix.trim().equals("")
//&& fieldType.isInstanceOf[StructType]
){
updatedDF = processSchema(updatedDF,
getStructType(updatedDF.schema,renamedCol),
replaceSpecialChars(field.name.toString()))
}
else{
updatedDF = processSchema(updatedDF,
getStructType(updatedDF.schema,renamedCol),
prefix+"."+replaceSpecialChars(field.name.toString()))
}
}
case _ => {
updatedDF = renameColumn(updatedDF,field.name.toString(),prefix)
}
}
})
//updatedDF.printSchema()
return updatedDF
}
def renameDataFrameColumns(df:DataFrame):DataFrame ={
val schema = df.schema;
return processSchema(df,schema,"")
}
}
Here's a recursive method that revises a DataFrame schema by renaming, via replaceAll, any column whose name contains the substring to be replaced:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
def renameAllColumns(schema: StructType, from: String, to: String): StructType = {
def recurRename(schema: StructType, from: String, to:String): Seq[StructField] =
schema.fields.map{
case StructField(name, dtype: StructType, nullable, meta) =>
StructField(name.replaceAll(from, to), StructType(recurRename(dtype, from, to)), nullable, meta)
case StructField(name, dtype: ArrayType, nullable, meta) => dtype.elementType match {
case struct: StructType => StructField(name.replaceAll(from, to), ArrayType(StructType(recurRename(struct, from, to)), true), nullable, meta)
case other => StructField(name.replaceAll(from, to), other, nullable, meta)
}
case StructField(name, dtype, nullable, meta) =>
StructField(name.replaceAll(from, to), dtype, nullable, meta)
}
StructType(recurRename(schema, from, to))
}
Testing the method on a sample DataFrame with a nested structure:
case class M(i: Int, `p:q`: String)
case class N(j: Int, m: M)
val df = Seq(
(1, "a", Array(N(7, M(11, "x")), N(72, M(112, "x2")))),
(2, "b", Array(N(8, M(21, "y")))),
(3, "c", Array(N(9, M(31, "z"))))
).toDF("c1", "c2:0", "c3")
df.printSchema
// root
// |-- c1: integer (nullable = false)
// |-- c2:0: string (nullable = true)
// |-- c3: array (nullable = true)
// | |-- element: struct (containsNull = true)
// | | |-- j: integer (nullable = false)
// | | |-- m: struct (nullable = true)
// | | | |-- i: integer (nullable = false)
// | | | |-- p:q: string (nullable = true)
val newSchema = renameAllColumns(df.schema, ":", "_")
spark.createDataFrame(df.rdd, newSchema).printSchema
// root
// |-- c1: integer (nullable = false)
// |-- c2_0: string (nullable = true)
// |-- c3: array (nullable = true)
// | |-- element: struct (containsNull = true)
// | | |-- j: integer (nullable = false)
// | | |-- m: struct (nullable = true)
// | | | |-- i: integer (nullable = false)
// | | | |-- p_q: string (nullable = true)
Note that since replaceAll supports regex patterns, the method can be applied with more versatile replacement conditions. For example, here's how to trim a column name from the ':' character onward:
val newSchema = renameAllColumns(df.schema, """:.*""", "")
spark.createDataFrame(df.rdd, newSchema).printSchema
// root
// |-- c1: integer (nullable = false)
// |-- c2: string (nullable = true)
// |-- c3: array (nullable = true)
// | |-- element: struct (containsNull = true)
// | | |-- j: integer (nullable = false)
// | | |-- m: struct (nullable = true)
// | | | |-- i: integer (nullable = false)
// | | | |-- p: string (nullable = true)
Unfortunately, you can't easily rename a single nested field using withColumnRenamed like you are attempting to do. The only way I know of to rename nested fields is to cast the field to a type with the same structure and data types but new field names. This has to be done on the top-level field, so you need to do all the fields in one go. Here is an example:
Create some input data
case class InnerRecord(column1: String, column2: Int)
case class Record(field: InnerRecord)
val df = Seq(
Record(InnerRecord("a", 1)),
Record(InnerRecord("b", 2))
).toDF
df.printSchema
Input data looks like this:
root
|-- field: struct (nullable = true)
| |-- column1: string (nullable = true)
| |-- column2: integer (nullable = false)
This is an example using withColumnRenamed. You'll notice in the output it doesn't actually do anything!
val updated = df.withColumnRenamed("field.column1", "field.newname")
updated.printSchema
root
|-- field: struct (nullable = true)
| |-- column1: string (nullable = true)
| |-- column2: integer (nullable = false)
Here is how you can do it instead with casting. The function will recursively recreate the nested field type while updating the names. In my case I just replaced "column" with "col_". I also only ran it on one field, but you could easily loop across all the fields in the schema (see the sketch after the output below).
import org.apache.spark.sql.types._
def rename(dataType: DataType): DataType = dataType match {
case StructType(fields) =>
StructType(fields.map {
case StructField(name, dtype, nullable, meta) =>
val newName = name.replace("column", "col_")
StructField(newName, rename(dtype), nullable, meta)
})
case _ => dataType
}
val fDataType = df.schema.filter(_.name == "field").head.dataType
val updated = df.withColumn("field", $"field".cast(rename(fDataType)))
updated.printSchema
Which prints:
root
|-- field: struct (nullable = true)
| |-- col_1: string (nullable = true)
| |-- col_2: integer (nullable = false)
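To apply the rename across every top-level field rather than just field, a sketch of the loop mentioned above (assuming the same rename function and df):
import org.apache.spark.sql.functions.col

// Cast each column to its renamed DataType; for non-struct columns rename returns the type unchanged, so the cast is a no-op.
val updatedAll = df.schema.fields.foldLeft(df) { (acc, f) =>
  acc.withColumn(f.name, col(f.name).cast(rename(f.dataType)))
}
updatedAll.printSchema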
I had some issues with the answer from @Leo C, so I've used a slight variation instead. It also takes an arbitrary mapping function f for the renaming.
def renameAllColumns(schema: StructType, f: String => String): StructType = {
def recurRename(schema: StructType, f: String => String): Seq[StructField] =
schema.fields.map{
case StructField(name, dtype: StructType, nullable, meta) =>
StructField(f(name), StructType(recurRename(dtype, f)), nullable, meta)
case StructField(name, dtype: ArrayType, nullable, meta) => dtype.elementType match {
case struct: StructType => StructField(f(name), ArrayType(StructType(recurRename(struct, f)), true), nullable, meta)
case other => StructField(f(name), ArrayType(other), nullable, meta)
}
case StructField(name, dtype, nullable, meta) =>
StructField(f(name), dtype, nullable, meta)
}
StructType(recurRename(schema, f))
}
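A usage sketch, reusing the df from the previous answer and replacing ':' with '_':
val newSchema = renameAllColumns(df.schema, _.replaceAll(":", "_"))
spark.createDataFrame(df.rdd, newSchema).printSchema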
When I retrieve a dataset in Spark 2 using a select statement, the underlying columns inherit the data types of the queried columns.
val ds1 = spark.sql("select 1 as a, 2 as b, 'abd' as c")
ds1.printSchema()
root
|-- a: integer (nullable = false)
|-- b: integer (nullable = false)
|-- c: string (nullable = false)
Now if I convert this into a case class, it will correctly convert the values, but the underlying schema is still wrong.
case class abc(a: Double, b: Double, c: String)
val ds2 = ds1.as[abc]
ds2.printSchema()
root
|-- a: integer (nullable = false)
|-- b: integer (nullable = false)
|-- c: string (nullable = false)
ds2.collect
res18: Array[abc] = Array(abc(1.0,2.0,abd))
I "SHOULD" be able to specify the encoder to use when I create the second dataset, but scala seems to ignore this parameter (Is this a BUG?):
val abc_enc = org.apache.spark.sql.Encoders.product[abc]
val ds2 = ds1.as[abc](abc_enc)
ds2.printSchema
root
|-- a: integer (nullable = false)
|-- b: integer (nullable = false)
|-- c: string (nullable = false)
So the only way I can see to do this simply, without very complex mapping is to use createDataset, but this requires a collect on the underlying object, so it's not ideal.
val ds2 = spark.createDataset(ds1.as[abc].collect)
This is an open issue in the Spark API (see ticket SPARK-17694).
So what you need to do is an extra explicit conversion. Something like this should work:
ds1.as[abc].map(x => x : abc)
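The identity map forces the rows through the abc encoder, so the schema of the result should pick up the case class types (a sketch of the expected output):
val ds2 = ds1.as[abc].map(x => x: abc)
ds2.printSchema
root
|-- a: double (nullable = false)
|-- b: double (nullable = false)
|-- c: string (nullable = true)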
You can simply use the cast method on the columns:
import org.apache.spark.sql.types.DoubleType
import spark.implicits._
val ds2 = ds1.select($"a".cast(DoubleType), $"b".cast(DoubleType), $"c")
ds2.printSchema()
You should get:
root
|-- a: double (nullable = false)
|-- b: double (nullable = false)
|-- c: string (nullable = false)
You could also cast the columns while selecting with a SQL query, as below:
import spark.implicits._
val ds = Seq((1,2,"abc"),(1,2,"abc")).toDF("a", "b","c").createOrReplaceTempView("temp")
val ds1 = spark.sql("select cast(a as Double) , cast (b as Double), c from temp")
ds1.printSchema()
This has the schema:
root
|-- a: double (nullable = false)
|-- b: double (nullable = false)
|-- c: string (nullable = true)
Now you can convert it to a Dataset with the case class:
case class abc(a: Double, b: Double, c: String)
val ds2 = ds1.as[abc]
ds2.printSchema()
Which now has the required schema
root
|-- a: double (nullable = false)
|-- b: double (nullable = false)
|-- c: string (nullable = true)
Hope this helps!
OK, I think I've resolved this in a better way.
Instead of using a collect when we create a new dataset, we can just reference the rdd of the dataset.
So instead of
val ds2 = spark.createDataset(ds1.as[abc].collect)
We use:
val ds2 = spark.createDataset(ds1.as[abc].rdd)
ds2.printSchema
root
|-- a: double (nullable = false)
|-- b: double (nullable = false)
|-- c: string (nullable = true)
This keeps the lazy evaluation intact, but allows the new dataset to use the Encoder for the abc case class, and the subsequent schema will reflect this when we use it to create a new table.
I'm working through a Databricks example. The schema for the dataframe looks like:
> parquetDF.printSchema
root
|-- department: struct (nullable = true)
| |-- id: string (nullable = true)
| |-- name: string (nullable = true)
|-- employees: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- firstName: string (nullable = true)
| | |-- lastName: string (nullable = true)
| | |-- email: string (nullable = true)
| | |-- salary: integer (nullable = true)
In the example, they show how to explode the employees column into 4 additional columns:
val explodeDF = parquetDF.explode($"employees") {
case Row(employee: Seq[Row]) => employee.map{ employee =>
val firstName = employee(0).asInstanceOf[String]
val lastName = employee(1).asInstanceOf[String]
val email = employee(2).asInstanceOf[String]
val salary = employee(3).asInstanceOf[Int]
Employee(firstName, lastName, email, salary)
}
}.cache()
display(explodeDF)
How would I do something similar with the department column (i.e. add two additional columns to the dataframe called "id" and "name")? The methods aren't exactly the same, and I can only figure out how to create a brand new data frame using:
val explodeDF = parquetDF.select("department.id","department.name")
display(explodeDF)
If I try:
val explodeDF = parquetDF.explode($"department") {
case Row(dept: Seq[String]) => dept.map{dept =>
val id = dept(0)
val name = dept(1)
}
}.cache()
display(explodeDF)
I get the warning and error:
<console>:38: warning: non-variable type argument String in type pattern Seq[String] is unchecked since it is eliminated by erasure
case Row(dept: Seq[String]) => dept.map{dept =>
^
<console>:37: error: inferred type arguments [Unit] do not conform to method explode's type parameter bounds [A <: Product]
val explodeDF = parquetDF.explode($"department") {
^
In my opinion the most elegant solution is to star-expand the struct using a select operator, as shown below:
val explodedDf2 = parquetDF.select("department.*", "*")
https://docs.databricks.com/spark/latest/spark-sql/complex-types.html
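If you only want the original columns plus two new ones named id and name, you can also select them explicitly (a sketch; withDept is just an illustrative name):
import org.apache.spark.sql.functions.col
val withDept = parquetDF.select(col("*"), col("department.id").as("id"), col("department.name").as("name"))
withDept.printSchema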
You could use something like this:
var explodeDeptDF = explodeDF.withColumn("id", explodeDF("department.id"))
explodeDeptDF = explodeDeptDF.withColumn("name", explodeDeptDF("department.name"))
which you helped me toward, along with these questions:
Flattening Rows in Spark
Spark 1.4.1 DataFrame explode list of JSON objects
This seems to work (though maybe not the most elegant solution).
var explodeDF2 = explodeDF.withColumn("id", explodeDF("department.id"))
explodeDF2 = explodeDF2.withColumn("name", explodeDF2("department.name"))