Let me explain this with an example, starting with the following dataframe:
val df = Seq((1, "CS", 0, Array(0.1, 0.2, 0.4, 0.5)),
(4, "Ed", 0, Array(0.4, 0.8, 0.3, 0.6)),
(7, "CS", 0, Array(0.2, 0.5, 0.4, 0.7)),
(101, "CS", 1, Array(0.5, 0.7, 0.3, 0.8)),
(5, "CS", 1, Array(0.4, 0.2, 0.6, 0.9))).toDF("id", "dept", "test", "array")
df.show()
+---+----+----+--------------------+
| id|dept|test| array|
+---+----+----+--------------------+
| 1| CS| 0|[0.1, 0.2, 0.4, 0.5]|
| 4| Ed| 0|[0.4, 0.8, 0.3, 0.6]|
| 7| CS| 0|[0.2, 0.5, 0.4, 0.7]|
|101| CS| 1|[0.5, 0.7, 0.3, 0.8]|
| 5| CS| 1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+
Consider the following two common operations as examples (the question is not limited to them):
import org.apache.spark.sql.functions._ // for `when`
val dfFilter1 = df.where($"dept" === "CS")
val dfFilter3 = df.withColumn("category", when($"dept" === "CS" && $"id" === 101, 10).otherwise(0))
Now, I have a string variable colName = "dept", and $"dept" in the operations above has to be replaced by colName in some form while keeping the same functionality. I managed to achieve the first one as follows:
val dfFilter2 = df.where(s"${colName} = 'CS'")
But a similar operation fails in the second case:
val dfFilter4 = df.withColumn("category", when(s"${colName} = 'CS'" && $"id" === 101, 10).otherwise(0))
Specifically it gives the following error:
Name: Unknown Error
Message: <console>:35: error: value && is not a member of String
val dfFilter4 = df.withColumn("category", when(s"${colName} = 'CS'" && $"id" === 101, 10).otherwise(0))
My understanding so far is that once I use s"${variable}" to interpolate a variable, everything becomes a pure string, and it is difficult to involve logical operations.
So, my questions are:
1. What is the best way to use a string variable such as colName in operations similar to the two I listed above? (I also do not like the solution I have for .where().)
2. Are there any general guidelines for using such a string variable in more general operations beyond the two examples here? (It has always felt very case-specific whenever I deal with string-related operations.)
You can use the expr function:
val dfFilter4 = df.withColumn("category", when(expr(s"${colName} = 'CS' and id = 101"), 10).otherwise(0))
Reason for the error
The where function defined with a string query, as below, works
val dfFilter2 = df.where(s"${colName} = 'CS'")
because where has supporting APIs (overloads) for both String and Column:
/**
* Filters rows using the given condition. This is an alias for `filter`.
* {{{
* // The following are equivalent:
* peopleDs.filter($"age" > 15)
* peopleDs.where($"age" > 15)
* }}}
*
* @group typedrel
* @since 1.6.0
*/
def where(condition: Column): Dataset[T] = filter(condition)
and
/**
* Filters rows using the given SQL expression.
* {{{
* peopleDs.where("age > 15")
* }}}
*
* @group typedrel
* @since 1.6.0
*/
def where(conditionExpr: String): Dataset[T] = {
filter(Column(sparkSession.sessionState.sqlParser.parseExpression(conditionExpr)))
}
But there is only one API for the when function, and it supports only the Column type:
/**
* Evaluates a list of conditions and returns one of multiple possible result expressions.
* If otherwise is not defined at the end, null is returned for unmatched conditions.
*
* {{{
* // Example: encoding gender string column into integer.
*
* // Scala:
* people.select(when(people("gender") === "male", 0)
* .when(people("gender") === "female", 1)
* .otherwise(2))
*
* // Java:
* people.select(when(col("gender").equalTo("male"), 0)
* .when(col("gender").equalTo("female"), 1)
* .otherwise(2))
* }}}
*
* @group normal_funcs
* @since 1.4.0
*/
def when(condition: Column, value: Any): Column = withExpr {
CaseWhen(Seq((condition.expr, lit(value).expr)))
}
So you cannot pass a string SQL query to the when function.
The correct way of doing it is as follows:
val dfFilter4 = df.withColumn("category", when(col(s"${colName}") === "CS" && $"id" === 101, 10).otherwise(0))
or, in short:
val dfFilter4 = df.withColumn("category", when(col(colName) === "CS" && col("id") === 101, 10).otherwise(0))
What is the best way to use a string variable such as colName for operations similar to the two I listed above?
You can use the col function from org.apache.spark.sql.functions:
import org.apache.spark.sql.functions._
val colName = "dept"
For dfFilter2
val dfFilter2 = df.where(col(colName) === "CS")
For dfFilter4
val dfFilter4 = df.withColumn("category", when(col(colName) === "CS" && $"id" === 101, 10).otherwise(0))
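More generally, for your second question (this part is my own addition, not from the original answer): anywhere the API accepts a Column you can use col(colName), and anywhere you prefer SQL text you can build the whole expression as a string and parse it with expr. A small sketch of the two interchangeable styles, reusing the df and colName from above:
import org.apache.spark.sql.functions.{col, expr}

val colName = "dept"

// Column-based: build the predicate from col()/$"" and the usual Column operators
val byColumn = df.where(col(colName) === "CS" && $"id" === 101)

// SQL-string based: build the whole predicate as text and let expr() parse it
val byExpr = df.where(expr(s"$colName = 'CS' AND id = 101"))
Both produce the same result; which one to use is mostly a matter of whether you find it easier to compose Column objects or SQL strings at runtime.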
I have a sparse vector in Spark and I want to randomly shuffle (reorder) its contents. This vector is actually a tf-idf vector, and I want to reorder it so that in my new dataset the features are in a different order. Is there any way to do this using Scala?
This is my code for generating the tf-idf vectors:
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, IDF, Tokenizer}

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val wordsData = tokenizer.transform(data).cache()
val cvModel: CountVectorizerModel = new CountVectorizer()
.setInputCol("words")
.setOutputCol("rawFeatures")
.fit(wordsData)
val featurizedData = cvModel.transform(wordsData).cache()
val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
val idfModel = idf.fit(featurizedData)
val rescaledData = idfModel.transform(featurizedData).cache()
Perhaps this is useful-
Load the test data
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions.udf
import scala.collection.mutable

val data = Array(
Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
)
val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")
df.show(false)
df.printSchema()
/**
* +---------------------+
* |features |
* +---------------------+
* |(5,[1,3],[1.0,7.0]) |
* |[2.0,0.0,3.0,4.0,5.0]|
* |[4.0,0.0,0.0,6.0,7.0]|
* +---------------------+
*
* root
* |-- features: vector (nullable = true)
*/
Shuffle the vector
val shuffleVector = udf((vector: Vector) =>
Vectors.dense(scala.util.Random.shuffle(mutable.WrappedArray.make[Double](vector.toArray)).toArray)
)
val p = df.withColumn("shuffled_vector", shuffleVector($"features"))
p.show(false)
p.printSchema()
/**
* +---------------------+---------------------+
* |features |shuffled_vector |
* +---------------------+---------------------+
* |(5,[1,3],[1.0,7.0]) |[1.0,0.0,0.0,0.0,7.0]|
* |[2.0,0.0,3.0,4.0,5.0]|[0.0,3.0,2.0,5.0,4.0]|
* |[4.0,0.0,0.0,6.0,7.0]|[4.0,7.0,6.0,0.0,0.0]|
* +---------------------+---------------------+
*
* root
* |-- features: vector (nullable = true)
* |-- shuffled_vector: vector (nullable = true)
*/
You can also use the above udf to create a Transformer and put it in a pipeline, as sketched below.
Please make sure to use import org.apache.spark.ml.linalg._ for the Vector and Vectors types.
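For reference, here is a minimal sketch of such a Transformer (my own illustration, not part of the original answer; VectorShuffler is a hypothetical name, and it assumes the shuffleVector udf defined above is in scope):
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

class VectorShuffler(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("vectorShuffler"))

  // append a shuffled copy of the "features" column
  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn("shuffled_vector", shuffleVector(dataset("features")))

  override def transformSchema(schema: StructType): StructType =
    schema.add("shuffled_vector", schema("features").dataType)

  override def copy(extra: ParamMap): VectorShuffler = defaultCopy(extra)
}
It can then be used like any other pipeline stage, for example after the IDF step from the question.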
Update 1: convert the shuffled vector to sparse
val shuffleVectorToSparse = udf((vector: Vector) =>
Vectors.dense(scala.util.Random.shuffle(mutable.WrappedArray.make[Double](vector.toArray)).toArray).toSparse
)
val p1 = df.withColumn("shuffled_vector", shuffleVectorToSparse($"features"))
p1.show(false)
p1.printSchema()
/**
* +---------------------+-------------------------------+
* |features |shuffled_vector |
* +---------------------+-------------------------------+
* |(5,[1,3],[1.0,7.0]) |(5,[0,3],[1.0,7.0]) |
* |[2.0,0.0,3.0,4.0,5.0]|(5,[1,2,3,4],[5.0,3.0,2.0,4.0])|
* |[4.0,0.0,0.0,6.0,7.0]|(5,[1,3,4],[7.0,4.0,6.0]) |
* +---------------------+-------------------------------+
*
* root
* |-- features: vector (nullable = true)
* |-- shuffled_vector: vector (nullable = true)
*/
How can I replicate this code to get the dataframe size in pyspark?
scala> val df = spark.range(10)
scala> print(spark.sessionState.executePlan(df.queryExecution.logical).optimizedPlan.stats)
Statistics(sizeInBytes=80.0 B, hints=none)
What I would like to do is get the sizeInBytes value into a variable.
In Spark 2.4 you can do
df = spark.range(10)
df.createOrReplaceTempView('myView')
spark.sql('explain cost select * from myView').show(truncate=False)
|== Optimized Logical Plan ==
Range (0, 10, step=1, splits=Some(8)), Statistics(sizeInBytes=80.0 B, hints=none)
In Spark 3.0.0-preview2 you can use explain with the cost mode:
df = spark.range(10)
df.explain(mode='cost')
== Optimized Logical Plan ==
Range (0, 10, step=1, splits=Some(8)), Statistics(sizeInBytes=80.0 B)
See if this helps.
The code below reads the JSON file source and computes stats such as size in bytes, number of rows, etc. These stats also help Spark make intelligent decisions while optimizing the execution plan. The code should be the same in pyspark too.
/**
* file content
* spark-test-data.json
* --------------------
* {"id":1,"name":"abc1"}
* {"id":2,"name":"abc2"}
* {"id":3,"name":"abc3"}
*/
val fileName = "spark-test-data.json"
val path = getClass.getResource("/" + fileName).getPath
spark.catalog.createTable("df", path, "json")
.show(false)
/**
* +---+----+
* |id |name|
* +---+----+
* |1 |abc1|
* |2 |abc2|
* |3 |abc3|
* +---+----+
*/
// Collect only statistics that do not require scanning the whole table (that is, size in bytes).
spark.sql("ANALYZE TABLE df COMPUTE STATISTICS NOSCAN")
spark.sql("DESCRIBE EXTENDED df ").filter(col("col_name") === "Statistics").show(false)
/**
* +----------+---------+-------+
* |col_name |data_type|comment|
* +----------+---------+-------+
* |Statistics|68 bytes | |
* +----------+---------+-------+
*/
spark.sql("ANALYZE TABLE df COMPUTE STATISTICS")
spark.sql("DESCRIBE EXTENDED df ").filter(col("col_name") === "Statistics").show(false)
/**
* +----------+----------------+-------+
* |col_name |data_type |comment|
* +----------+----------------+-------+
* |Statistics|68 bytes, 3 rows| |
* +----------+----------------+-------+
*/
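To capture that statistics value in a variable rather than just displaying it, one option is to read it off the DESCRIBE EXTENDED result; a sketch against the table analyzed above (the returned text still needs parsing):
val statsText: String = spark.sql("DESCRIBE EXTENDED df")
  .filter(col("col_name") === "Statistics")
  .select("data_type")
  .head()
  .getString(0) // e.g. "68 bytes, 3 rows"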
More info: the Databricks SQL docs.
Typically, you can access the scala methods through py4j. I just tried this in the pyspark shell:
>>> spark._jsparkSession.sessionState().executePlan(df._jdf.queryExecution().logical()).optimizedPlan().stats().sizeInBytes()
716
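For completeness, the same value can also be captured into a variable on the Scala side (a sketch assuming Spark 2.3+, where stats takes no argument, matching the snippet in the question):
val df = spark.range(10)
val sizeInBytes: BigInt = spark.sessionState
  .executePlan(df.queryExecution.logical)
  .optimizedPlan
  .stats
  .sizeInBytes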
I have a dataframe with three columns: id, index and value.
+---+-----+-------------------+
| id|index| value|
+---+-----+-------------------+
| A| 1023|0.09938822262205915|
| A| 1046| 0.3110047630613805|
| A| 1069| 0.8486710971453512|
+---+-----+-------------------+
root
|-- id: string (nullable = true)
|-- index: integer (nullable = false)
|-- value: double (nullable = false)
Then, I have another dataframe which shows desirable periods for each id:
+---+-----------+---------+
| id|start_index|end_index|
+---+-----------+---------+
| A| 1069| 1276|
| B| 2066| 2291|
| B| 1616| 1841|
| C| 3716| 3932|
+---+-----------+---------+
root
|-- id: string (nullable = true)
|-- start_index: integer (nullable = false)
|-- end_index: integer (nullable = false)
I have three templates as below
val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)
The goal is, for each row in dfIntervals, to apply a function (let's assume it is correlation) that receives the value column from dfRaw and one of the three template arrays, and to add three columns to dfIntervals, one per template.
Assumptions:
1 - The sizes of the template arrays are exactly 10.
2 - There are no duplicates in the index column of dfRaw.
3 - The start_index and end_index values in dfIntervals exist in the index column of dfRaw, and there are exactly 10 rows between them. For instance, dfRaw.filter($"id" === "A").filter($"index" >= 1069 && $"index" <= 1276).count (the first row in dfIntervals) results in exactly 10.
Here's the code that generates these dataframes:
import org.apache.spark.sql.functions._
val mySeed = 1000
/* Defining templates for correlation analysis*/
val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)
/* Defining raw data*/
var dfRaw = Seq(
("A", (1023 to 1603 by 23).toArray),
("B", (341 to 2300 by 25).toArray),
("C", (2756 to 3954 by 24).toArray)
).toDF("id", "index")
dfRaw = dfRaw.select($"id", explode($"index") as "index").withColumn("value", rand(seed=mySeed))
/* Defining intervals*/
var dfIntervals = Seq(
("A", 1069, 1276),
("B", 2066, 2291),
("B", 1616, 1841),
("C", 3716, 3932)
).toDF("id", "start_index", "end_index")
The result is three columns added to the dfIntervals dataframe, named corr_w_template1, corr_w_template2 and corr_w_template3.
PS: I could not find a correlation function in Scala. Let's assume such a function exists (as below) and that we will make a udf out of it if needed.
def correlation(arr1: Array[Double], arr2: Array[Double]): Double
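For concreteness, such a helper could be a plain Pearson correlation over two equal-length arrays; the body below is only an illustrative sketch of that assumed function:
def correlation(arr1: Array[Double], arr2: Array[Double]): Double = {
  require(arr1.length == arr2.length && arr1.nonEmpty, "arrays must be non-empty and of equal length")
  val n = arr1.length
  val mean1 = arr1.sum / n
  val mean2 = arr2.sum / n
  // unnormalized covariance and standard deviations; the 1/n factors cancel in the ratio
  val cov  = arr1.zip(arr2).map { case (a, b) => (a - mean1) * (b - mean2) }.sum
  val std1 = math.sqrt(arr1.map(a => (a - mean1) * (a - mean1)).sum)
  val std2 = math.sqrt(arr2.map(b => (b - mean2) * (b - mean2)).sum)
  cov / (std1 * std2)
}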
Ok.
Let's define a UDF function.
For testing purposes, let's say it always returns 1.
val correlation = functions.udf( (values: mutable.WrappedArray[Double], template: mutable.WrappedArray[Double]) => {
1f
})
val orderUdf = udf((values: mutable.WrappedArray[Row]) => {
values.sortBy(r => r.getAs[Int](0)).map(r => r.getAs[Double](1))
})
Then let's join your 2 data frames with the defined rules and collect the values into one column called values, applying our orderUdf so they stay ordered by index.
val df = dfIntervals.join(dfRaw,dfIntervals("id") === dfRaw("id") && dfIntervals("start_index") <= dfRaw("index") && dfRaw("index") <= dfIntervals("end_index") )
.groupBy(dfIntervals("id"), dfIntervals("start_index"), dfIntervals("end_index"))
.agg(orderUdf(collect_list(struct(dfRaw("index"), dfRaw("value")))).as("values"))
Finally, apply our udf and show the result.
df.withColumn("corr_w_template1",correlation(df("values"), lit(template1)))
.withColumn("corr_w_template2",correlation(df("values"), lit(template2)))
.withColumn("corr_w_template3",correlation(df("values"), lit(template3)))
.show(10)
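Once a real array-to-array correlation helper exists (for example the question's assumed correlation(arr1, arr2) function), the test udf can delegate to it instead of returning 1f. A sketch; correlationUdf is a name I am introducing, and it would replace the dummy val correlation above:
val correlationUdf = functions.udf((values: mutable.WrappedArray[Double], template: mutable.WrappedArray[Double]) =>
  correlation(values.toArray, template.toArray)
)
The three withColumn calls would then use correlationUdf(df("values"), lit(template1)) and so on.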
Here is the full example code:
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession, functions}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import scala.collection.JavaConverters._
import scala.collection.mutable
val conf = new SparkConf().setAppName("learning").setMaster("local[2]")
val session = SparkSession.builder().config(conf).getOrCreate()
val mySeed = 1000
/* Defining templates for correlation analysis*/
val template1 = Array(0.0, 0.1, 0.15, 0.2, 0.3, 0.33, 0.42, 0.51, 0.61, 0.7)
val template2 = Array(0.96, 0.89, 0.82, 0.76, 0.71, 0.65, 0.57, 0.51, 0.41, 0.35)
val template3 = Array(0.0, 0.07, 0.21, 0.41, 0.53, 0.42, 0.34, 0.25, 0.19, 0.06)
val schema1 = DataTypes.createStructType(Array(
DataTypes.createStructField("id",DataTypes.StringType,false),
DataTypes.createStructField("index",DataTypes.createArrayType(DataTypes.IntegerType),false)
))
val schema2 = DataTypes.createStructType(Array(
DataTypes.createStructField("id",DataTypes.StringType,false),
DataTypes.createStructField("start_index",DataTypes.IntegerType,false),
DataTypes.createStructField("end_index",DataTypes.IntegerType,false)
))
/* Defining raw data*/
var dfRaw = session.createDataFrame(Seq(
("A", (1023 to 1603 by 23).toArray),
("B", (341 to 2300 by 25).toArray),
("C", (2756 to 3954 by 24).toArray)
).map(r => Row(r._1 , r._2)).asJava, schema1)
dfRaw = dfRaw.select(dfRaw("id"), explode(dfRaw("index")) as "index")
.withColumn("value", rand(seed=mySeed))
/* Defining intervals*/
var dfIntervals = session.createDataFrame(Seq(
("A", 1069, 1276),
("B", 2066, 2291),
("B", 1616, 1841),
("C", 3716, 3932)
).map(r => Row(r._1 , r._2,r._3)).asJava, schema2)
//Define udf
val correlation = functions.udf( (values: mutable.WrappedArray[Double], template: mutable.WrappedArray[Double]) => {
1f
})
val orderUdf = udf((values: mutable.WrappedArray[Row]) => {
values.sortBy(r => r.getAs[Int](0)).map(r => r.getAs[Double](1))
})
val df = dfIntervals.join(dfRaw,dfIntervals("id") === dfRaw("id") && dfIntervals("start_index") <= dfRaw("index") && dfRaw("index") <= dfIntervals("end_index") )
.groupBy(dfIntervals("id"), dfIntervals("start_index"), dfIntervals("end_index"))
.agg(orderUdf(collect_list(struct(dfRaw("index"), dfRaw("value")))).as("values"))
df.withColumn("corr_w_template1",correlation(df("values"), lit(template1)))
.withColumn("corr_w_template2",correlation(df("values"), lit(template2)))
.withColumn("corr_w_template3",correlation(df("values"), lit(template3)))
.show(10,false)
I have 2 DataFrames
case class UserTransactions(id: Long, transactionDate: java.sql.Date, currencyUsed: String, value: Long)
ID, TransactionDate, CurrencyUsed, value
1, 2016-01-05, USD, 100
1, 2016-01-09, GBP, 150
1, 2016-02-01, USD, 50
1, 2016-02-10, JPN, 10
2, 2016-01-10, EURO, 50
2, 2016-01-10, GBP, 100
case class ReportingTime(userId: Long, reportDate: java.sql.Date)
userId, reportDate
1, 2016-01-05
1, 2016-01-31
1, 2016-02-15
2, 2016-01-10
2, 2016-02-01
Now I want to get a summary by combining all previously used currencies per userId and reportDate, with their summed values. The results should look like:
userId, reportDate, transactionSummary
1, 2016-01-05, None
1, 2016-01-31, (USD -> 100)(GBP-> 150) // combined above 2 transactions less than 2016-01-31
1, 2016-02-15, (USD -> 150)(GBP-> 150)(JPN->10) // combined transactions less than 2016-02-15
2, 2016-01-10, None
2, 2016-02-01, (EURO-> 50) (GBP-> 100)
What is the best way to do this? We have over 300 million transactions, and each user can have up to 10,000 transactions.
The snippet below achieves your requirement. The initial join and aggregation are done via the DataFrame API of pyspark; the grouping of the data (using reduceByKey) and the final dataset preparation are done via the RDD API, since it is more suitable for such operations.
from datetime import datetime
from pyspark.sql.functions import udf
from pyspark.sql.types import DateType
from pyspark.sql import functions as F
df1 = spark.createDataFrame([(1,'2016-01-05','USD',100),
(1,'2016-01-09','GBP',150),
(1,'2016-02-01','USD',50),
(1,'2016-02-10','JPN',10),
(2,'2016-01-10','EURO',50),
(2,'2016-01-10','GBP',100)],['id', 'tdate', 'currency', 'value'])
df2 = spark.createDataFrame([(1,'2016-01-05'),
(1,'2016-01-31'),
(1,'2016-02-15'),
(2,'2016-01-10'),
(2,'2016-02-01')],['user_id', 'report_date'])
func = udf (lambda x: datetime.strptime(x, '%Y-%m-%d'), DateType()) ### function to convert string data type to date data type
df2 = df2.withColumn('tdate', func(df2.report_date))
df1 = df1.withColumn('tdate', func(df1.tdate))
result = df2.join(df1, (df1.id == df2.user_id) & (df1.tdate < df2.report_date), 'left_outer').select('user_id', 'report_date', 'currency', 'value').groupBy('user_id', 'report_date', 'currency').agg(F.sum('value').alias('value'))
data = result.rdd.map(lambda x: (x.user_id,x.report_date,x.currency,x.value)).keyBy(lambda x: (x[0],x[1])).mapValues(lambda x: filter(lambda x: bool(x),[(x[2],x[3]) if x[2] else None])).reduceByKey(lambda x,y: x + y).map(lambda x: (x[0][0],x[0][1], x[1]))
The final result generated is as shown below.
>>> spark.createDataFrame([ (x[0],x[1],str(x[2])) for x in data.collect()], ['id', 'date', 'values']).orderBy('id', 'date').show(20, False)
+---+----------+--------------------------------------------+
|id |date |values |
+---+----------+--------------------------------------------+
|1 |2016-01-05|[] |
|1 |2016-01-31|[(u'USD', 100), (u'GBP', 150)] |
|1 |2016-02-15|[(u'USD', 150), (u'GBP', 150), (u'JPN', 10)]|
|2 |2016-01-10|[] |
|2 |2016-02-01|[(u'EURO', 50), (u'GBP', 100)] |
+---+----------+--------------------------------------------+
In case someone needs it in Scala:
import java.text.SimpleDateFormat
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import spark.implicits._ // assumes a SparkSession named spark, as in spark-shell

case class Transaction(id: String, date: java.sql.Date, currency:Option[String], value: Option[Long])
case class Report(id:String, date:java.sql.Date)
def toDate(date: String): java.sql.Date = {
val sf = new SimpleDateFormat("yyyy-MM-dd")
new java.sql.Date(sf.parse(date).getTime)
}
val allTransactions = Seq(
Transaction("1", toDate("2016-01-05"),Some("USD"),Some(100L)),
Transaction("1", toDate("2016-01-09"),Some("GBP"),Some(150L)),
Transaction("1",toDate("2016-02-01"),Some("USD"),Some(50L)),
Transaction("1",toDate("2016-02-10"),Some("JPN"),Some(10L)),
Transaction("2",toDate("2016-01-10"),Some("EURO"),Some(50L)),
Transaction("2",toDate("2016-01-10"),Some("GBP"),Some(100L))
)
val allReports = Seq(
Report("1",toDate("2016-01-05")),
Report("1",toDate("2016-01-31")),
Report("1",toDate("2016-02-15")),
Report("2",toDate("2016-01-10")),
Report("2",toDate("2016-02-01"))
)
val transections:Dataset[Transaction] = spark.createDataFrame(allTransactions).as[Transaction]
val reports: Dataset[Report] = spark.createDataFrame(allReports).as[Report]
val result = reports.alias("rp").join(transections.alias("tx"), (col("tx.id") === col("rp.id")) && (col("tx.date") < col("rp.date")), "left_outer")
.select("rp.id", "rp.date", "currency", "value")
.groupBy("rp.id", "rp.date", "currency").agg(sum("value"))
.toDF("id", "date", "currency", "value")
.as[Transaction]
val data = result.rdd.keyBy(x => (x.id , x.date))
.mapValues(x => if (x.currency.isDefined) collection.Map[String, Long](x.currency.get -> x.value.get) else collection.Map[String, Long]())
.reduceByKey((x,y) => x ++ y).map(x => (x._1._1, x._1._2, x._2))
.toDF("id", "date", "map")
.orderBy("id", "date")
Console output
+---+----------+--------------------------------------+
|id |date |map |
+---+----------+--------------------------------------+
|1 |2016-01-05|Map() |
|1 |2016-01-31|Map(GBP -> 150, USD -> 100) |
|1 |2016-02-15|Map(USD -> 150, GBP -> 150, JPN -> 10)|
|2 |2016-01-10|Map() |
|2 |2016-02-01|Map(GBP -> 100, EURO -> 50) |
+---+----------+--------------------------------------+
Given Table 1 with one column "x" of type String.
I want to create Table 2 with a column "y" that is an integer representation of the date strings given in "x".
It is essential to keep null values in column "y".
Table 1 (Dataframe df1):
+----------+
| x|
+----------+
|2015-09-12|
|2015-09-13|
| null|
| null|
+----------+
root
|-- x: string (nullable = true)
Table 2 (Dataframe df2):
+----------+--------+
| x| y|
+----------+--------+
| null| null|
| null| null|
|2015-09-12|20150912|
|2015-09-13|20150913|
+----------+--------+
root
|-- x: string (nullable = true)
|-- y: integer (nullable = true)
The user-defined function (udf) to convert values from column "x" into those of column "y" is:
val extractDateAsInt = udf[Int, String] (
(d:String) => d.substring(0, 10)
.filterNot( "-".toSet)
.toInt )
and it works, but dealing with null values is not possible.
Even though I can do something like
val extractDateAsIntWithNull = udf[Int, String] (
(d:String) =>
if (d != null) d.substring(0, 10).filterNot( "-".toSet).toInt
else 1 )
I have found no way to "produce" null values via udfs (of course, since Ints cannot be null).
My current solution for creation of df2 (Table 2) is as follows:
// holds data of table 1
val df1 = ...
// filter entries from df1, that are not null
val dfNotNulls = df1.filter(df1("x")
.isNotNull)
.withColumn("y", extractDateAsInt(df1("x")))
.withColumnRenamed("x", "right_x")
// create df2 via a left join of df1 and dfNotNulls
val df2 = df1.join( dfNotNulls, df1("x") === dfNotNulls("right_x"), "leftouter" ).drop("right_x")
Questions:
The current solution seems cumbersome (and probably not efficient wrt. performance). Is there a better way?
@Spark-developers: Is there a type NullableInt planned / available, such that the following udf is possible (see the code excerpt below)?
Code excerpt
val extractDateAsNullableInt = udf[NullableInt, String] (
(d:String) =>
if (d != null) d.substring(0, 10).filterNot( "-".toSet).toInt
else null )
This is where Option comes in handy:
val extractDateAsOptionInt = udf((d: String) => d match {
case null => None
case s => Some(s.substring(0, 10).filterNot("-".toSet).toInt)
})
or, to make it slightly safer in the general case:
import scala.util.Try
val extractDateAsOptionInt = udf((d: String) => Try(
d.substring(0, 10).filterNot("-".toSet).toInt
).toOption)
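Either variant is then applied like any other udf; a quick usage sketch against the df1 from the question:
val df2 = df1.withColumn("y", extractDateAsOptionInt($"x"))
// rows where x is null end up with y = null; y is a nullable integer column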
All credit goes to Dmitriy Selivanov, who pointed out this solution as a (missing?) edit here.
An alternative is to handle null outside the UDF:
import org.apache.spark.sql.functions.{lit, when}
import org.apache.spark.sql.types.IntegerType
val extractDateAsInt = udf(
(d: String) => d.substring(0, 10).filterNot("-".toSet).toInt
)
df.withColumn("y",
when($"x".isNull, lit(null))
.otherwise(extractDateAsInt($"x"))
.cast(IntegerType)
)
Scala actually has a nice factory function, Option(), that can make this even more concise:
val extractDateAsOptionInt = udf((d: String) =>
Option(d).map(_.substring(0, 10).filterNot("-".toSet).toInt))
Internally the Option object's apply method is just doing the null check for you:
def apply[A](x: A): Option[A] = if (x == null) None else Some(x)
Supplementary code
Building on the nice answer from @zero323, I created the following code to have user-defined functions available that handle null values as described. I hope it is helpful for others!
/**
* Set of methods to construct [[org.apache.spark.sql.UserDefinedFunction]]s that
* handle `null` values.
*/
object NullableFunctions {
import org.apache.spark.sql.functions._
import scala.reflect.runtime.universe.{TypeTag}
import org.apache.spark.sql.UserDefinedFunction
/**
* Given a function A1 => RT, create a [[org.apache.spark.sql.UserDefinedFunction]] such that
* - if the input is null, None is returned. This will create a null value in the output Spark column.
* - if the input is non-null, Some(f(input)) is returned, thus creating f(input) as the value in the output column.
* @param f function from A1 => RT
* @tparam RT return type
* @tparam A1 input parameter type
* @return a [[org.apache.spark.sql.UserDefinedFunction]] with the behaviour described above
*/
def nullableUdf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
udf[Option[RT],A1]( (i: A1) => i match {
case null => None
case s => Some(f(i))
})
}
/**
* Given a function A1, A2 => RT, create a [[org.apache.spark.sql.UserDefinedFunction]] such that
* - if one of the function's input parameters is null, None is returned.
*   This will create a null value in the output Spark column.
* - if both input parameters are non-null, Some(f(input1, input2)) is returned, thus creating f(input1, input2)
*   as the value in the output column.
* @param f function from (A1, A2) => RT
* @tparam RT return type
* @tparam A1 first input parameter type
* @tparam A2 second input parameter type
* @return a [[org.apache.spark.sql.UserDefinedFunction]] with the behaviour described above
*/
def nullableUdf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
udf[Option[RT], A1, A2]( (i1: A1, i2: A2) => (i1, i2) match {
case (null, _) => None
case (_, null) => None
case (s1, s2) => Some((f(s1,s2)))
} )
}
}
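A short usage sketch of the object above, reusing the date-extraction logic from the question (my own illustration):
import NullableFunctions._

val extractDateAsNullableInt = nullableUdf(
  (d: String) => d.substring(0, 10).filterNot("-".toSet).toInt
)

val df2 = df1.withColumn("y", extractDateAsNullableInt($"x"))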