How can I add or replace fields in a struct at any nested level?
This input:
val rdd = sc.parallelize(Seq(
"""{"a": {"xX": 1,"XX": 2},"b": {"z": 0}}""",
"""{"a": {"xX": 3},"b": {"z": 0}}""",
"""{"a": {"XX": 3},"b": {"z": 0}}""",
"""{"a": {"xx": 4},"b": {"z": 0}}"""))
var df = sqlContext.read.json(rdd)
Yields the following schema:
root
|-- a: struct (nullable = true)
| |-- XX: long (nullable = true)
| |-- xX: long (nullable = true)
| |-- xx: long (nullable = true)
|-- b: struct (nullable = true)
| |-- z: long (nullable = true)
Then I can do this:
import org.apache.spark.sql.functions._
val overlappingNames = Seq(col("a.xx"), col("a.xX"), col("a.XX"))
df = df
.withColumn("a_xx",
coalesce(overlappingNames:_*))
.dropNestedColumn("a.xX")
.dropNestedColumn("a.XX")
.dropNestedColumn("a.xx")
(dropNestedColumn is borrowed from this answer:
https://stackoverflow.com/a/39943812/1068385. I'm basically looking for the inverse operation of that.)
And the schema becomes:
root
|-- a: struct (nullable = false)
|-- b: struct (nullable = true)
| |-- z: long (nullable = true)
|-- a_xx: long (nullable = true)
Obviously this doesn't replace (or add) a.xx; instead it adds the new field a_xx at the root level.
I'd like to be able to do this instead:
val overlappingNames = Seq(col("a.xx"), col("a.xX"), col("a.XX"))
df = df
.withNestedColumn("a.xx",
coalesce(overlappingNames:_*))
.dropNestedColumn("a.xX")
.dropNestedColumn("a.XX")
So that it would result in this schema:
root
|-- a: struct (nullable = false)
| |-- xx: long (nullable = true)
|-- b: struct (nullable = true)
| |-- z: long (nullable = true)
How can I achieve that?
The practical goal here is to be case-insensitive with column names in the input JSON. The final step would be simple: collect all overlapping column names and apply the coalesce on each.
It might not be as elegant or as efficient as it could be but here is what I came up with:
object DataFrameUtils {
private def nullableCol(parentCol: Column, c: Column): Column = {
when(parentCol.isNotNull, c)
}
private def nullableCol(c: Column): Column = {
nullableCol(c, c)
}
private def createNestedStructs(splitted: Seq[String], newCol: Column): Column = {
splitted
.foldRight(newCol) {
case (colName, nestedStruct) => nullableCol(struct(nestedStruct as colName))
}
}
private def recursiveAddNestedColumn(splitted: Seq[String], col: Column, colType: DataType, nullable: Boolean, newCol: Column): Column = {
colType match {
case colType: StructType if splitted.nonEmpty => {
var modifiedFields: Seq[(String, Column)] = colType.fields
.map(f => {
var curCol = col.getField(f.name)
if (f.name == splitted.head) {
curCol = recursiveAddNestedColumn(splitted.tail, curCol, f.dataType, f.nullable, newCol)
}
(f.name, curCol as f.name)
})
if (!modifiedFields.exists(_._1 == splitted.head)) {
modifiedFields :+= (splitted.head, nullableCol(col, createNestedStructs(splitted.tail, newCol)) as splitted.head)
}
var modifiedStruct: Column = struct(modifiedFields.map(_._2): _*)
if (nullable) {
modifiedStruct = nullableCol(col, modifiedStruct)
}
modifiedStruct
}
case _ => createNestedStructs(splitted, newCol)
}
}
private def addNestedColumn(df: DataFrame, newColName: String, newCol: Column): DataFrame = {
if (newColName.contains('.')) {
val splitted = newColName.split('.')
val modifiedOrAdded: (String, Column) = df.schema.fields
.find(_.name == splitted.head)
.map(f => (f.name, recursiveAddNestedColumn(splitted.tail, col(f.name), f.dataType, f.nullable, newCol)))
.getOrElse {
(splitted.head, createNestedStructs(splitted.tail, newCol) as splitted.head)
}
df.withColumn(modifiedOrAdded._1, modifiedOrAdded._2)
} else {
// Top level addition, use spark method as-is
df.withColumn(newColName, newCol)
}
}
implicit class ExtendedDataFrame(df: DataFrame) extends Serializable {
/**
* Add nested field to DataFrame
*
* @param newColName Dot-separated nested field name
* @param newCol New column value
*/
def withNestedColumn(newColName: String, newCol: Column): DataFrame = {
DataFrameUtils.addNestedColumn(df, newColName, newCol)
}
}
}
Feel free to improve on it.
val data = spark.sparkContext.parallelize(List("""{ "a1": 1, "a3": { "b1": 3, "b2": { "c1": 5, "c2": 6 } } }"""))
val df: DataFrame = spark.read.json(data)
val df2 = df.withNestedColumn("a3.b2.c3.d1", $"a3.b2")
should produce:
assertResult("struct<a1:bigint,a3:struct<b1:bigint,b2:struct<c1:bigint,c2:bigint,c3:struct<d1:struct<c1:bigint,c2:bigint>>>>>")(df2.schema.simpleString)
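For the final step mentioned in the question (collecting the field names that differ only by case), here is a minimal sketch in plain Scala. It assumes the field names have already been read from the schema; the name list and the `overlapping` value are illustrative only and not part of the answer above.

```scala
// Field names as they might appear under struct "a" (plus one non-overlapping name).
val fieldNames = Seq("XX", "xX", "xx", "z")

// Group names by their lower-cased form and keep only groups with several spellings.
val overlapping: Map[String, Seq[String]] =
  fieldNames
    .groupBy(_.toLowerCase)
    .filter { case (_, variants) => variants.size > 1 }

// Each entry maps a target name to the variants that would be passed to coalesce.
overlapping.foreach { case (target, variants) =>
  println(s"$target <- ${variants.mkString(", ")}")
}
```

Each resulting entry could then drive one withNestedColumn call followed by dropNestedColumn for the remaining spellings.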
I need updatedDF to have the new column names, but it is not updated: it still has the old columns and their names.
val schema = "sku_cd#sku_code,ean_nbr#ean,vnr_cd#nan_key,dsupp_pcmdty_desc#pack_descr"
val schemaArr = schema.split(",")
var df = spark.sql("""select sku_code, ean , nan_key, pack_descr from db.products""")
val updatedDF = populateAttributes(df,schemaArr)
def populateAttributes(df:DataFrame,schemaArr:Array[String]) : DataFrame = {
for(i <- schemaArr)
{
val targetCol = i.split("#")(0)
val sourceCol = i.split("#")(1)
df.withColumn(targetCol, col(sourceCol))
}
df
}
I get below output which is incorrect
scala> updatedDF.printSchema
root
|-- sku_code: string (nullable = true)
|-- ean: string (nullable = true)
|-- nan_key: string (nullable = true)
|-- pack_descr: string (nullable = true)
Expected output
|-- sku_cd: string (nullable = true)
|-- ean_nbr: string (nullable = true)
|-- vnr_cd: string (nullable = true)
|-- dsupp_pcmdty_desc: string (nullable = true)
You are not updating the dataframe in your for loop. The line:
df.withColumn(targetCol, col(sourceCol))
will create a new dataframe and df will remain the same.
You can reassign the DataFrame in each iteration by using a var (inside the method, copy the parameter into a local var first, since method parameters cannot be reassigned). Also use withColumnRenamed to rename a column:
df = df.withColumnRenamed(sourceCol, targetCol)
Or better, use foldLeft :
def populateAttributes(df:DataFrame,schemaArr:Array[String]) : DataFrame = {
schemaArr.foldLeft(df)((acc, m) => {
val mapping = m.split("#")
acc.withColumnRenamed(mapping(1), mapping(0))
})
}
Another way using a select expression :
val selectExpr = schemaArr.map(m => {
val mapping = m.split("#")
col(mapping(1)).as(mapping(0))
})
val updatedDF = df.select(selectExpr:_*)
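The difference between discarding a result and threading it through foldLeft can be seen without Spark. The sketch below uses an immutable Vector of column names as a stand-in for the DataFrame; the `renamed` helper is hypothetical and only mimics the "returns a new value" behaviour of withColumnRenamed.

```scala
// Mapping strings in the question's "target#source" format.
val schemaArr = Array("sku_cd#sku_code", "ean_nbr#ean")

// Stand-in for a DataFrame: an immutable list of column names (assumption for illustration).
val columns = Vector("sku_code", "ean")

// Hypothetical helper: returns a NEW Vector and never mutates its input.
def renamed(cols: Vector[String], from: String, to: String): Vector[String] =
  cols.map(c => if (c == from) to else c)

// Discarding the result, as in the question's loop, leaves `columns` unchanged.
schemaArr.foreach { m =>
  val Array(target, source) = m.split("#")
  renamed(columns, source, target) // result is thrown away
}

// foldLeft passes each new value on to the next step.
val updated = schemaArr.foldLeft(columns) { (acc, m) =>
  val Array(target, source) = m.split("#")
  renamed(acc, source, target)
}
```

After this runs, `columns` still holds the old names while `updated` holds the new ones, which is the same effect seen with df and updatedDF in the question.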
Just another way to do what blackbishop did
val schema = "sku_cd#sku_code,ean_nbr#ean,vnr_cd#nan_key,dsupp_pcmdty_desc#pack_descr"
val schemaArr = schema.split(",").toSeq
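val outputDF = schemaArr.foldLeft(inputDF) { (df, x) =>
// Each entry is "target#source": rename the source column to the target name
val Array(target, source) = x.split('#')
df.withColumnRenamed(source, target)
}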
I have 5 queries like below:
select * from table1
select * from table2
select * from table3
select * from table4
select * from table5
I need to execute these queries sequentially and append each query's output to a single JSON file. I wrote the code below, but it stores the output of each query in a separate part file instead of one.
Below is my code:
def store(jobEntity: JobDetails, jobRunId: Int): Unit = {
UDFUtil.registerUdfFunctions()
var outputTableName: String = null
val jobQueryMap = jobEntity.jobQueryList.map(jobQuery => (jobQuery.sequenceId, jobQuery))
val sortedQueries = scala.collection.immutable.TreeMap(jobQueryMap.toSeq: _*).toMap
LOGGER.debug("sortedQueries ===>" + sortedQueries)
try {
outputTableName = jobEntity.destinationEntity
var resultDF: DataFrame = null
sortedQueries.values.foreach(jobQuery => {
LOGGER.debug(s"jobQuery.query ===> ${jobQuery.query}")
resultDF = SparkSession.builder.getOrCreate.sqlContext.sql(jobQuery.query)
if (jobQuery.partitionColumn != null && !jobQuery.partitionColumn.trim.isEmpty) {
resultDF = resultDF.repartition(jobQuery.partitionColumn.split(",").map(col): _*)
}
if (jobQuery.isKeepInMemory) {
resultDF = resultDF.persist(StorageLevel.MEMORY_AND_DISK_SER)
}
if (jobQuery.isCheckpointEnabled) {
val checkpointDir = ApplicationConfig.getAppConfig(JobConstants.CHECKPOINT_DIR)
val fs = FileSystem.get(new Storage(JsonUtil.toMap[String](jobEntity.sourceConnection)).asHadoopConfig())
val path = new Path(checkpointDir)
if (!fs.exists(path)) {
fs.mkdirs(path)
}
resultDF.explain(true)
SparkSession.builder.getOrCreate.sparkContext.setCheckpointDir(checkpointDir)
resultDF = resultDF.checkpoint
}
resultDF = {
if (jobQuery.isBroadCast) {
import org.apache.spark.sql.functions.broadcast
broadcast(resultDF)
} else
resultDF
}
tempViewsList.+=(jobQuery.queryAliasName)
resultDF.createOrReplaceTempView(jobQuery.queryAliasName)
// resultDF.explain(true)
val map: Map[String, String] = JsonUtil.toMap[String](jobEntity.sinkConnection)
LOGGER.debug("sink details :: " + map)
if (resultDF != null && !resultDF.take(1).isEmpty) {
resultDF.show(false)
val sinkDetails = new Storage(JsonUtil.toMap[String](jobEntity.sinkConnection))
val path = sinkDetails.basePath + File.separator + jobEntity.destinationEntity
println("path::: " + path)
resultDF.repartition(1).write.mode(SaveMode.Append).json(path)
}
}
)
Please ignore the other things (checkpointing, logging, auditing) that this method does besides reading and writing.
Use the below example as a reference for your problem.
I have three tables with Json data (with different schema) as below:
table1 --> Personal Data Table
table2 --> Company Data Table
table3 --> Salary Data Table
I read these three tables one by one, sequentially as you require, and apply a few transformations to the data (exploding the JSON array column). The list TableColList holds each table name together with its array column name, separated by a colon (":").
OutDFList is the list of all transformed DataFrames.
At the end, I reduce all DataFrames in OutDFList into a single DataFrame and write it to one JSON output location.
Note: I used join to reduce the DataFrames. You can also use union (if they have the same columns) or another operation, as your requirements dictate.
Check below code:
scala> spark.sql("select * from table1").printSchema
root
|-- Personal: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- DOB: string (nullable = true)
| | |-- EmpID: string (nullable = true)
| | |-- Name: string (nullable = true)
scala> spark.sql("select * from table2").printSchema
root
|-- Company: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- EmpID: string (nullable = true)
| | |-- JoinDate: string (nullable = true)
| | |-- Project: string (nullable = true)
scala> spark.sql("select * from table3").printSchema
root
|-- Salary: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- EmpID: string (nullable = true)
| | |-- Monthly: string (nullable = true)
| | |-- Yearly: string (nullable = true)
scala> val TableColList = List("table1:Personal", "table2:Company", "table3:Salary")
TableColList: List[String] = List(table1:Personal, table2:Company, table3:Salary)
scala> val OutDFList = TableColList.map{ X =>
| val table = X.split(":")(0)
| val arrayColumn = X.split(":")(1)
| val df = spark.sql(s"""SELECT * FROM """ + table).select(explode(col(arrayColumn)) as "data").select("data.*")
| df}
OutDFList: List[org.apache.spark.sql.DataFrame] = List([DOB: string, EmpID: string ... 1 more field], [EmpID: string, JoinDate: string ... 1 more field], [EmpID: string, Monthly: string ... 1 more field])
scala> val FinalOutDF = OutDFList.reduce((df1, df2) => df1.join(df2, "EmpID"))
FinalOutDF: org.apache.spark.sql.DataFrame = [EmpID: string, DOB: string ... 5 more fields]
scala> FinalOutDF.printSchema
root
|-- EmpID: string (nullable = true)
|-- DOB: string (nullable = true)
|-- Name: string (nullable = true)
|-- JoinDate: string (nullable = true)
|-- Project: string (nullable = true)
|-- Monthly: string (nullable = true)
|-- Yearly: string (nullable = true)
scala> FinalOutDF.write.json("/FinalJsonOut")
First things first, you need to union all the schemas:
import org.apache.spark.sql.functions._
val df1 = sc.parallelize(List(
(42, 11),
(43, 21)
)).toDF("foo", "bar")
val df2 = sc.parallelize(List(
(44, true, 1.0),
(45, false, 3.0)
)).toDF("foo", "foo0", "foo1")
val cols1 = df1.columns.toSet
val cols2 = df2.columns.toSet
val total = cols1 ++ cols2 // union
def expr(myCols: Set[String], allCols: Set[String]) = {
allCols.toList.map(x => x match {
case x if myCols.contains(x) => col(x)
case _ => lit(null).as(x)
})
}
// Use a new name here: `total` is already defined above as the set of all column names
val unioned = df1.select(expr(cols1, total):_*).unionAll(df2.select(expr(cols2, total):_*))
unioned.show()
Then save to a single JSON file:
df.coalesce(1).write.mode("append").json("/some/path")
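Putting the two steps together, here is a minimal end-to-end sketch. It assumes a local SparkSession and two small DataFrames with identical columns as hypothetical stand-ins for the query results; the temporary output directory is also just for illustration. Note that Spark writes a directory, and the single JSON part file lives inside it.

```scala
import java.nio.file.Files
import org.apache.spark.sql.{SaveMode, SparkSession}

// Local session for illustration only.
val spark = SparkSession.builder.master("local[1]").appName("single-json").getOrCreate()
import spark.implicits._

// Hypothetical stand-ins for the results of two queries (same columns).
val first  = Seq((1, "a"), (2, "b")).toDF("id", "name")
val second = Seq((3, "c")).toDF("id", "name")

// Combine all results first, then write once.
val combined = Seq(first, second).reduce(_ union _)

// Hypothetical output location inside a fresh temporary directory.
val outDir = Files.createTempDirectory("json-out").resolve("result").toString

// coalesce(1) leaves a single partition, so one part-*.json file is written.
combined.coalesce(1).write.mode(SaveMode.Overwrite).json(outDir)
```

Writing once after the union is what keeps the output to one part file; calling write in append mode once per query adds at least one new part file per call, which matches what the question observed.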
UPD
If you are not using DataFrames and work with plain SQL queries instead, writing to a single file stays the same (coalesce(1) or repartition(1)):
spark.sql(
"""
|SELECT id, name
|FROM (
| SELECT first.id, first.name FROM first
| UNION
| SELECT second.id, second.name FROM second
| ORDER BY name
| ) t
""".stripMargin).show()
I have a set of dataframes, dfs, with different schema, for example:
root
|-- A_id: string (nullable = true)
|-- b_cd: string (nullable = true)
|-- c_id: integer (nullable = true)
|-- d_info: struct (nullable = true)
| |-- eid: string (nullable = true)
| |-- oid: string (nullable = true)
|-- l: array (nullable = true)
| |-- m: struct (containsNull = true)
| | |-- n: string (nullable = true)
| | |-- o: string (nullable = true)
..........
I want to check whether, for example, "oid" appears in one of the columns (here under the d_info column). How can I search inside the schemas of a set of DataFrames and tell them apart? PySpark or Scala suggestions are both helpful. Thank you.
A dictionary/map of [node -> root-to-node path] can be built for a DataFrame's StructType (including nested StructTypes) using a recursive function.
import org.apache.spark.sql.types.{ArrayType, StructType}
val df = spark.read.json("nested_data.json")
val path = searchSchema(df.schema, "n", "root")
def searchSchema(schema: StructType, key: String, path: String): String = {
val paths = scala.collection.mutable.Map[String, String]()
addPaths(schema, path, paths)
paths(key)
}
def addPaths(schema: StructType, path: String, paths: scala.collection.mutable.Map[String, String]): Unit = {
for (field <- schema.fields) {
val _path = s"$path.${field.name}"
paths += (field.name -> _path)
field.dataType match {
case structType: StructType => addPaths(structType, _path, paths)
// Only descend into arrays of structs; a blind cast would throw for e.g. array<string>
case ArrayType(elementType: StructType, _) => addPaths(elementType, _path, paths)
case _ => // do nothing
}
}
}
Input and output
Input = {"A_id":"A_id","b_cd":"b_cd","c_id":1,"d_info":{"eid":"eid","oid":"oid"},"l":[{"m":{"n":"n1","o":"01"}},{"m":{"n":"n2","o":"02"}}]}
Output = Map(n -> root.l.m.n, b_cd -> root.b_cd, d_info -> root.d_info, m -> root.l.m, oid -> root.d_info.oid, c_id -> root.c_id, l -> root.l, o -> root.l.m.o, eid -> root.d_info.eid, A_id -> root.A_id)
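Because the map above is keyed by field name, a name that occurs in several places keeps only its last path. A variation that returns every matching path is sketched below; `findPaths` is a hypothetical helper, the schema is rebuilt by hand from the question's example, and MapType is deliberately not handled. It only needs the Spark SQL types on the classpath, not a running session.

```scala
import org.apache.spark.sql.types._

// Collect every dotted path whose last segment equals `key`.
def findPaths(dt: DataType, key: String, prefix: String = ""): Seq[String] = dt match {
  case st: StructType =>
    st.fields.toSeq.flatMap { f =>
      val path = if (prefix.isEmpty) f.name else s"$prefix.${f.name}"
      val here = if (f.name == key) Seq(path) else Seq.empty[String]
      here ++ findPaths(f.dataType, key, path)
    }
  // Look through arrays at their element type, keeping the same prefix.
  case ArrayType(elementType, _) => findPaths(elementType, key, prefix)
  case _ => Seq.empty
}

// Hand-built schema resembling the one in the question (illustration only).
val schema = StructType(Seq(
  StructField("A_id", StringType),
  StructField("d_info", StructType(Seq(
    StructField("eid", StringType),
    StructField("oid", StringType)))),
  StructField("l", ArrayType(StructType(Seq(
    StructField("m", StructType(Seq(
      StructField("n", StringType),
      StructField("o", StringType))))))))
))

val oidPaths = findPaths(schema, "oid")
val hasOid = oidPaths.nonEmpty
```

To tell a set of DataFrames apart, the same call can be applied to each df.schema and the DataFrames filtered on whether the result is non-empty.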
I am trying to replace certain characters in all the columns of my DataFrame which has lot of nested Struct Types.
I tried to process the schema fields recursively, but for some reason it only renames the fields at the top level, even though it reaches the leaf nodes.
I am trying to replace the : character in the column names with _
Here is the Scala code I have written:
class UpdateSchema {
val logger = LoggerFactory.getLogger(classOf[UpdateSchema])
Logger.getLogger("org").setLevel(Level.OFF)
Logger.getLogger("akka").setLevel(Level.OFF)
val sparkSession = SparkLauncher.spark
import sparkSession.implicits._
def updateSchema(filePath:String):Boolean ={
logger.info(".updateSchema() : filePath ={}",filePath);
logger.info(".updateSchema() : sparkSession ={}",sparkSession);
if(sparkSession!=null){
var xmlDF = sparkSession
.read
.format("com.databricks.spark.xml")
.option("rowTag","ns:fltdMessage")
.option("inferschema","true")
.option("attributePrefix","attr_")
.load(filePath)
.toDF()
xmlDF.printSchema()
val updatedDF = renameDataFrameColumns(xmlDF.toDF())
updatedDF.printSchema()
}
else
logger.info(".updateSchema(): Spark Session is NULL !!!");
false;
}
def replaceSpecialChars(str:String):String ={
val newColumn:String = str.replaceAll(":", "_")
//logger.info(".replaceSpecialChars() : Old Column Name =["+str+"] New Column Name =["+newColumn+"]")
return newColumn
}
def renameColumn(df:DataFrame,colName:String,prefix:String):DataFrame ={
val newColuName:String = replaceSpecialChars(colName)
logger.info(".renameColumn(): prefix=["+prefix+"] colName=["+colName+"] New Column Name=["+newColuName+"]")
if(prefix.equals("")){
if(df.col(colName)!=null){
return df.withColumnRenamed(colName, replaceSpecialChars(colName))
}
else{
logger.error(".logSchema() : Column ["+prefix+"."+colName+"] Not found in DataFrame !! ")
logger.info("Prefix ="+prefix+" Existing Columns =["+df.columns.mkString("),(")+"]")
throw new Exception("Unable to find Column ["+prefix+"."+colName+"]")
}
}
else{
if(df.col(prefix+"."+colName)!=null){
return df.withColumnRenamed(prefix+"."+colName, prefix+"."+replaceSpecialChars(colName))
}
else{
logger.error(".logSchema() : Column ["+prefix+"."+colName+"] Not found in DataFrame !! ")
logger.info("Prefix ="+prefix+" Existing Columns =["+df.columns.mkString("),(")+"]")
throw new Exception("Unable to find Column ["+prefix+"."+colName+"]")
}
}
}
def getStructType(schema:StructType,fieldName:String):StructType = {
schema.fields.foreach(field => {
field.dataType match{
case st:StructType => {
logger.info(".getStructType(): Current Field Name =["+field.name.toString()+"] Checking for =["+fieldName+"]")
if(field.name.toString().equals(fieldName)){
return field.dataType.asInstanceOf[StructType]
}
else{
getStructType(st,fieldName)
}
}
case _ =>{
logger.info(".getStructType(): Non Struct Type. Ignoring Filed=["+field.name.toString()+"]");
}
}
})
throw new Exception("Unable to find Struct Type for filed Name["+fieldName+"]")
}
def processSchema(df:DataFrame,schema:StructType,prefix:String):DataFrame ={
var updatedDF:DataFrame =df
schema.fields.foreach(field =>{
field.dataType match {
case st:StructType => {
logger.info(".processSchema() : Struct Type =["+st+"]");
logger.info(".processSchema() : Field Data Type =["+field.dataType+"]");
logger.info(".processSchema() : Renaming the Struct Field =["+field.name.toString()+"] st=["+st.fieldNames.mkString(",")+"]")
updatedDF = renameColumn(updatedDF,field.name.toString(),prefix)
logger.info(".processSchema() : Column List after Rename =["+updatedDF.columns.mkString(",")+"]")
// updatedDF.schema.fields.foldLeft(z)(op)
val renamedCol:String = replaceSpecialChars(field.name.toString())
var fieldType:DataType = null;
//if(prefix.equals(""))
fieldType = schema.fields.find(f =>{ (f.name.toString().equals(field.name.toString()))}).get.dataType
if(prefix.trim().equals("")
//&& fieldType.isInstanceOf[StructType]
){
updatedDF = processSchema(updatedDF,
getStructType(updatedDF.schema,renamedCol),
replaceSpecialChars(field.name.toString()))
}
else{
updatedDF = processSchema(updatedDF,
getStructType(updatedDF.schema,renamedCol),
prefix+"."+replaceSpecialChars(field.name.toString()))
}
}
case _ => {
updatedDF = renameColumn(updatedDF,field.name.toString(),prefix)
}
}
})
//updatedDF.printSchema()
return updatedDF
}
def renameDataFrameColumns(df:DataFrame):DataFrame ={
val schema = df.schema;
return processSchema(df,schema,"")
}
}
Here's a recursive method that revises a DataFrame schema by renaming, via replaceAll, any column whose name contains the substring to be replaced:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
def renameAllColumns(schema: StructType, from: String, to: String): StructType = {
def recurRename(schema: StructType, from: String, to:String): Seq[StructField] =
schema.fields.map{
case StructField(name, dtype: StructType, nullable, meta) =>
StructField(name.replaceAll(from, to), StructType(recurRename(dtype, from, to)), nullable, meta)
case StructField(name, dtype: ArrayType, nullable, meta) => dtype.elementType match {
case struct: StructType => StructField(name.replaceAll(from, to), ArrayType(StructType(recurRename(struct, from, to)), dtype.containsNull), nullable, meta)
// Keep the ArrayType wrapper for arrays of non-struct elements
case other => StructField(name.replaceAll(from, to), ArrayType(other, dtype.containsNull), nullable, meta)
}
case StructField(name, dtype, nullable, meta) =>
StructField(name.replaceAll(from, to), dtype, nullable, meta)
}
StructType(recurRename(schema, from, to))
}
Testing the method on a sample DataFrame with a nested structure:
case class M(i: Int, `p:q`: String)
case class N(j: Int, m: M)
val df = Seq(
(1, "a", Array(N(7, M(11, "x")), N(72, M(112, "x2")))),
(2, "b", Array(N(8, M(21, "y")))),
(3, "c", Array(N(9, M(31, "z"))))
).toDF("c1", "c2:0", "c3")
df.printSchema
// root
// |-- c1: integer (nullable = false)
// |-- c2:0: string (nullable = true)
// |-- c3: array (nullable = true)
// | |-- element: struct (containsNull = true)
// | | |-- j: integer (nullable = false)
// | | |-- m: struct (nullable = true)
// | | | |-- i: integer (nullable = false)
// | | | |-- p:q: string (nullable = true)
val newSchema = renameAllColumns(df.schema, ":", "_")
spark.createDataFrame(df.rdd, newSchema).printSchema
// root
// |-- c1: integer (nullable = false)
// |-- c2_0: string (nullable = true)
// |-- c3: array (nullable = true)
// | |-- element: struct (containsNull = true)
// | | |-- j: integer (nullable = false)
// | | |-- m: struct (nullable = true)
// | | | |-- i: integer (nullable = false)
// | | | |-- p_q: string (nullable = true)
Note that since replaceAll accepts a regex pattern, the method can handle more versatile replacement conditions. For example, here's how to trim off the part of a column name starting from the ':' character:
val newSchema = renameAllColumns(df.schema, """:.*""", "")
spark.createDataFrame(df.rdd, newSchema).printSchema
// root
// |-- c1: integer (nullable = false)
// |-- c2: string (nullable = true)
// |-- c3: array (nullable = true)
// | |-- element: struct (containsNull = true)
// | | |-- j: integer (nullable = false)
// | | |-- m: struct (nullable = true)
// | | | |-- i: integer (nullable = false)
// | | | |-- p: string (nullable = true)
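One consequence of replaceAll taking a regex is that characters such as '.' need quoting when they are meant literally. A small plain-Scala sketch of the difference (the sample names are made up for illustration):

```scala
import java.util.regex.Pattern

val name = "a.b:c"

// "." is a regex wildcard, so every character is replaced.
val unquoted = name.replaceAll(".", "_")

// Quoting the pattern makes it match the literal dot only.
val quoted = name.replaceAll(Pattern.quote("."), "_")

// For purely literal substitutions, String.replace avoids regex altogether.
val literal = name.replace(".", "_")

// The regex from the example above: drop everything from the first ':' onward.
val trimmed = "c2:0".replaceAll(":.*", "")
```

So when calling renameAllColumns with a `from` value that should be matched literally, wrapping it in Pattern.quote is the safer choice.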
Unfortunately, you can't easily rename a single nested field using withColumnRenamed like you are attempting to do. The only way I know of to rename nested fields is to cast the field to a type with the same structure and data types but new field names. This has to be done on the top-level field, so you need to handle all of its nested fields in one go. Here is an example:
Create some input data
case class InnerRecord(column1: String, column2: Int)
case class Record(field: InnerRecord)
val df = Seq(
Record(InnerRecord("a", 1)),
Record(InnerRecord("b", 2))
).toDF
df.printSchema
Input data looks like this:
root
|-- field: struct (nullable = true)
| |-- column1: string (nullable = true)
| |-- column2: integer (nullable = false)
This is an example using withColumnRenamed. You'll notice in the output it doesn't actually do anything!
val updated = df.withColumnRenamed("field.column1", "field.newname")
updated.printSchema
root
|-- field: struct (nullable = true)
| |-- column1: string (nullable = true)
| |-- column2: integer (nullable = false)
Here is how you can do it instead with casting. The function will recursively recreate the nested field type while updating the name. In my case I just replaced "column" with "col_". I also only ran it on one field, but you could easily loop across all your fields in schema.
import org.apache.spark.sql.types._
def rename(dataType: DataType): DataType = dataType match {
case StructType(fields) =>
StructType(fields.map {
case StructField(name, dtype, nullable, meta) =>
val newName = name.replace("column", "col_")
StructField(newName, rename(dtype), nullable, meta)
})
case _ => dataType
}
val fDataType = df.schema.filter(_.name == "field").head.dataType
val updated = df.withColumn("field", $"field".cast(rename(fDataType)))
updated.printSchema
Which prints:
root
|-- field: struct (nullable = true)
| |-- col_1: string (nullable = true)
| |-- col_2: integer (nullable = false)
I had some issues with the answer from @Leo C, so I've used a slight variation instead. It also takes any mapping function f for the rename.
def renameAllColumns(schema: StructType, f: String => String): StructType = {
def recurRename(schema: StructType, f: String => String): Seq[StructField] =
schema.fields.map{
case StructField(name, dtype: StructType, nullable, meta) =>
StructField(f(name), StructType(recurRename(dtype, f)), nullable, meta)
case StructField(name, dtype: ArrayType, nullable, meta) => dtype.elementType match {
case struct: StructType => StructField(f(name), ArrayType(StructType(recurRename(struct, f)), true), nullable, meta)
case other => StructField(f(name), ArrayType(other), nullable, meta)
}
case StructField(name, dtype, nullable, meta) =>
StructField(f(name), dtype, nullable, meta)
}
StructType(recurRename(schema, f))
}
This has me confused. I'm using "spark-testing-base_2.11" % "2.0.0_0.5.0" for the test. Can anyone explain why the map function changes the schema if using a Dataset, but works if I use the RDD? Any insights greatly appreciated.
import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.sql.{ Encoders, SparkSession }
import org.scalatest.{ FunSpec, Matchers }
class TransformSpec extends FunSpec with Matchers with SharedSparkContext {
describe("data transformation") {
it("the rdd maintains the schema") {
val spark = SparkSession.builder.getOrCreate()
import spark.implicits._
val personEncoder = Encoders.product[TestPerson]
val personDS = Seq(TestPerson("JoeBob", 29)).toDS
personDS.schema shouldEqual personEncoder.schema
val mappedSet = personDS.rdd.map { p: TestPerson => p.copy(age = p.age + 1) }.toDS
personEncoder.schema shouldEqual mappedSet.schema
}
it("datasets choke on explicit schema") {
val spark = SparkSession.builder.getOrCreate()
import spark.implicits._
val personEncoder = Encoders.product[TestPerson]
val personDS = Seq(TestPerson("JoeBob", 29)).toDS
personDS.schema shouldEqual personEncoder.schema
val mappedSet = personDS.map[TestPerson] { p: TestPerson => p.copy(age = p.age + 1) }
personEncoder.schema shouldEqual mappedSet.schema
}
}
}
case class TestPerson(name: String, age: Int)
A couple of things are conspiring against you here. Spark appears to have special casing for what types it considers nullable.
case class TestTypes(
scalaString: String,
javaString: java.lang.String,
myString: MyString,
scalaInt: Int,
javaInt: java.lang.Integer,
myInt: MyInt
)
Encoders.product[TestTypes].schema.printTreeString results in:
root
|-- scalaString: string (nullable = true)
|-- javaString: string (nullable = true)
|-- myString: struct (nullable = true)
| |-- value: string (nullable = true)
|-- scalaInt: integer (nullable = false)
|-- javaInt: integer (nullable = true)
|-- myInt: struct (nullable = true)
| |-- value: integer (nullable = false)
But if you map over the Dataset, you end up with everything nullable:
val testTypes: Seq[TestTypes] = Nil
val testDS = testTypes.toDS
testDS.map(foo => foo).schema.printTreeString results in:
root
|-- scalaString: string (nullable = true)
|-- javaString: string (nullable = true)
|-- myString: struct (nullable = true)
| |-- value: string (nullable = true)
|-- scalaInt: integer (nullable = true)
|-- javaInt: integer (nullable = true)
|-- myInt: struct (nullable = true)
| |-- value: integer (nullable = true)
Even if you force the schema to be correct, Spark explicitly ignores nullability when applying a schema, which is why you lose the few nullability guarantees you had when you convert back to the typed representation.
You could enrich your types to be able to force a nonNull schema:
implicit class StructImprovements(s: StructType) {
def nonNull: StructType = StructType(s.map(_.copy(nullable = false)))
}
implicit class DsImprovements[T: Encoder](ds: Dataset[T]) {
def nonNull: Dataset[T] = {
val nnSchema = ds.schema.nonNull
// Rebuild the DataFrame from its rows with the non-nullable schema
ds.sparkSession.createDataFrame(ds.toDF.rdd, nnSchema).as[T]
}
}
val mappedSet = personDS.map { p =>
p.copy(age = p.age + 1)
}.nonNull
But you will find that it evaporates as soon as you apply any interesting operation. And when schemas are compared again, Spark treats them as the same if the shape matches apart from nullability.
This appears to be by design: https://github.com/apache/spark/pull/11785
A map is a transformation operation on the data. It takes the input and a function, and applies that function to every element of the input data. The output is the set of return values of this function, so the schema of the output data depends on the return type of the function.
The map operation is a fairly standard and heavily used operation in functional programming. Look at https://en.m.wikipedia.org/wiki/Map_(higher-order_function) if you want to read more.
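As a small illustration of that last point on ordinary Scala collections (no Spark involved, reusing the TestPerson case class from the question): the element type of the result is whatever the mapping function returns.

```scala
case class TestPerson(name: String, age: Int)

val people = Seq(TestPerson("JoeBob", 29))

// Function returns TestPerson, so the result is again a Seq[TestPerson].
val older: Seq[TestPerson] = people.map(p => p.copy(age = p.age + 1))

// Function returns String, so the result is a Seq[String].
val names: Seq[String] = people.map(_.name)
```

The same reasoning applies to Dataset.map: the output schema is derived from the encoder of the function's return type rather than carried over from the input.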