I want to unit test code that reads a DataFrame from an RDBMS using sparkSession.read.jdbc(...), but I haven't found a way to mock DataFrameReader so that it returns a dummy DataFrame for the test.
Code example:
object ConfigurationLoader {
def readTable(tableName: String)(implicit spark: SparkSession): DataFrame = {
spark.read
.format("jdbc")
.option("url", s"$postgresUrl/$postgresDatabase")
.option("dbtable", tableName)
.option("user", postgresUsername)
.option("password", postgresPassword)
.option("driver", postgresDriver)
.load()
}
def loadUsingFilter(dummyFilter: String*)(implicit spark: SparkSession): DataFrame = {
readTable(postgresFilesTableName)
.where(col("column").isin(dummyFilter: _*))
}
}
And the second problem: to mock a Scala object, it looks like I need a different approach to creating such a service.
In my opinion, unit tests are not meant to test database connections. This should be done in integration tests to check that all the parts work together. Unit tests are just meant to test your functional logic, not Spark's ability to read from a database.
This is why I would design your code slightly differently and do just that, without caring about the DB.
/** This, I don't test. I trust spark.read */
def readTable(tableName: String)(implicit spark: SparkSession): DataFrame = {
spark.read
.option(...)
...
.load()
// Nothing more
}
/** This I test, this is my logic. */
def transform(df: DataFrame, dummyFilter: String*): DataFrame = {
  df
    .where(col("column").isin(dummyFilter: _*))
}
Then I use the code this way in production.
val source = readTable("...")
val result = transform(source, filter)
And now transform, that contains my logic, is easy to test. In case you wonder how to create dummy dataframes, one way I like is this:
// requires import spark.implicits._ for toDF
val df = Seq((1, Some("a"), true), (2, Some("b"), false),
  (3, None, true)).toDF("x", "y", "z")
// and the test
val result = transform(df, filter)
result should be ...
If you want to test sparkSession.read.jdbc(...), you can play with an in-memory H2 database. I do it sometimes when I'm writing learning tests. You can find an example here: https://github.com/bartosz25/spark-scala-playground/blob/d3cad26ff236ae78884bdeb300f2e59a616dc479/src/test/scala/com/waitingforcode/sql/LoadingDataTest.scala Please note, however, that you may encounter some subtle differences compared with a "real" RDBMS.
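For illustration, here is a rough sketch of such a test (my own example, not taken from the linked repository; it assumes the com.h2database:h2 test dependency is on the classpath and a local SparkSession named spark is in scope, and the table and column names are made up):
// Sketch: create an in-memory H2 table, then read it back through the JDBC source
import java.sql.DriverManager

val url = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"
val conn = DriverManager.getConnection(url, "sa", "")
conn.createStatement().execute("CREATE TABLE files (id INT, file_type VARCHAR(10))")
conn.createStatement().execute("INSERT INTO files VALUES (1, 'txt'), (2, 'doc')")

val df = spark.read
  .format("jdbc")
  .option("url", url)
  .option("dbtable", "files")
  .option("user", "sa")
  .option("password", "")
  .option("driver", "org.h2.Driver")
  .load()

assert(df.count() == 2)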
On the other hand, you can better separate the concerns of the code and create the DataFrame differently, for instance with the toDF(...) method. You can find an example here: https://github.com/bartosz25/spark-scala-playground/blob/77ea416d2493324ddd6f3f2be42122855596d238/src/test/scala/com/waitingforcode/sql/CorrelatedSubqueryTest.scala
Finally, and IMO, if you have to mock DataFrameReader, it probably means there is something to improve in how the code is separated. For instance, you can put all your filters inside a Filters object and test each filter separately. The same goes for mapping or aggregation functions. Two years ago I wrote a blog post about testing Apache Spark: https://www.waitingforcode.com/apache-spark/testing-spark-applications/read It describes the RDD API, but the idea of separating concerns is the same.
Updated:
object Filters {
def isInFileTypes(inputDataFrame: DataFrame, fileTypes: Seq[String]): DataFrame = {
inputDataFrame.where(col("column").isin(fileTypes: _*))
}
}
object ConfigurationLoader {
def readTable(tableName: String)(implicit spark: SparkSession): DataFrame = {
val input = spark.read
.format("jdbc")
.option("url", s"$postgresUrl/$postgresDatabase")
.option("dbtable", tableName)
.option("user", postgresUsername)
.option("password", postgresPassword)
.option("driver", postgresDriver)
.load()
    Filters.isInFileTypes(input, Seq("txt", "doc"))
  }
}
And with that you can test your filtering logic however you want :) If you have more filters and want to test them, you can also combine them in a single method, pass any DataFrame you want, and voilà :)
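For example, a minimal test of the filter could look like this (a sketch, assuming a local SparkSession named spark is in scope):
import spark.implicits._

val input = Seq((1, "txt"), (2, "doc"), (3, "csv")).toDF("id", "column")
val filtered = Filters.isInFileTypes(input, Seq("txt", "doc"))

assert(filtered.count() == 2)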
You shouldn't test the .load() unless you have very good reasons to do so. It's Apache Spark internal logic, already tested.
Update, answer for:
So, now I am able to test filters, but how do I make sure that readTable really uses the proper filter (sorry for the thoroughness, it is just a question of full coverage)? Perhaps you have some simple approach for mocking a Scala object (it is actually my second problem). – dytyniak 14 mins ago
object MyApp {
def main(args: Array[String]): Unit = {
val inputDataFrame = readTable(postgreSQLConnection)
val outputDataFrame = ProcessingLogic.generateOutputDataFrame(inputDataFrame)
}
}
object ProcessingLogic {
def generateOutputDataFrame(inputDataFrame: DataFrame): DataFrame = {
// Here you apply all needed filters, transformations & co
}
}
As you can see, there is no need to mock an object here. It may seem redundant, but it's not, because you can test every filter in isolation thanks to the Filters object, and all of your processing logic combined thanks to the ProcessingLogic object (the name is only an example). And you can create your DataFrame in any valid way. The drawback is that you will need to define a schema explicitly or use case classes, since for your PostgreSQL source Apache Spark resolves the schema automatically (I explained this here: https://www.waitingforcode.com/apache-spark-sql/schema-projection/read).
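As an illustration, a test DataFrame with an explicit schema could be built like this (a sketch; column names and types are made up, and a SparkSession named spark is assumed):
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("id", IntegerType, nullable = false),
  StructField("column", StringType, nullable = true)
))

// Build rows by hand instead of relying on schema inference from the database
val testInput = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row(1, "txt"), Row(2, "pdf"))),
  schema
)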
You can write unit tests for all of DataFrameWriter, DataFrameReader, DataStreamReader, and DataStreamWriter. The steps are:
1. Mock
2. Behavior
3. Assertion
The sample test case below uses the above steps.
Maven-based dependencies:
<dependency>
    <groupId>org.scalatestplus</groupId>
    <artifactId>mockito-3-4_2.11</artifactId>
    <version>3.2.3.0</version>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.mockito</groupId>
    <artifactId>mockito-inline</artifactId>
    <version>2.13.0</version>
    <scope>test</scope>
</dependency>
Let's use an example of a Spark class where the source is Hive and the sink is JDBC:
import java.util.Properties

import org.apache.spark.sql.{DataFrame, SparkSession}

class DummySource extends SparkPipeline {

  /**
   * Method to read the source and create a DataFrame
   *
   * @param spark : SparkSession
   * @return : DataFrame
   */
  override def read(spark: SparkSession): DataFrame = {
    spark.read.table("Table_Name").filter("_2 > 1")
  }

  /**
   * Method to transform the DataFrame
   *
   * @param df : DataFrame
   * @return : DataFrame
   */
  override def transform(df: DataFrame): DataFrame = ???

  /**
   * Method to write/save the DataFrame to a target
   *
   * @param df : DataFrame
   */
  override def write(df: DataFrame): Unit =
    df.write.jdbc("url", "targetTableName", new Properties())
}
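The SparkPipeline trait that DummySource extends is not shown in the referenced article; a minimal definition it might correspond to is sketched below (an assumption, not the article's actual code):
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical read/transform/write contract for a pipeline stage
trait SparkPipeline {
  def read(spark: SparkSession): DataFrame
  def transform(df: DataFrame): DataFrame
  def write(df: DataFrame): Unit
}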
Mocking Read
test("Spark read table") {
val dummySource = new DummySource()
val sparkSession = SparkSession
.builder()
.master("local[*]")
.appName("mocking spark test")
.getOrCreate()
val testData = Seq(("one", 1), ("two", 2))
val df = sparkSession.createDataFrame(testData)
df.show()
val mockDataFrameReader = mock[DataFrameReader]
val mockSpark = mock[SparkSession]
when(mockSpark.read).thenReturn(mockDataFrameReader)
when(mockDataFrameReader.table("Table_Name")).thenReturn(df)
dummySource.read(mockSpark).count() should be(1)
}
Mocking Write
test("Spark write") {
val dummySource = new DummySource()
val mockDf = mock[DataFrame]
val mockDataFrameWriter = mock[DataFrameWriter[Row]]
when(mockDf.write).thenReturn(mockDataFrameWriter)
when(mockDataFrameWriter.mode(SaveMode.Append)).thenReturn(mockDataFrameWriter)
doNothing().when(mockDataFrameWriter).jdbc("url", "targetTableName", new Properties())
dummySource.write(df = mockDf)
}
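For completeness, the two tests above assume scaffolding roughly like the following (a sketch; the exact traits depend on the ScalaTest and scalatestplus-mockito versions you pull in):
import java.util.Properties

import org.apache.spark.sql._
import org.mockito.Mockito.{doNothing, when}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar

class DummySourceSpec extends AnyFunSuite with Matchers with MockitoSugar {
  // test("Spark read table") { ... } and test("Spark write") { ... } from above go here
}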
Streaming code is covered in the reference below.
Ref : https://medium.com/walmartglobaltech/spark-mocking-read-readstream-write-and-writestream-b6fe70761242
Related
I need to execute some functions based on the values that I receive from topics. I'm currently using ForeachWriter to convert all the topics to a List.
Now, I want to pass this List as a parameter to the methods.
This is what I have so far
def doA(mylist: List[String]) = { /* something for A */ }
def doB(mylist: List[String]) = { /* something for B */ }
And this is how I call my streaming queries:
//{"s":"a","v":"2"}
//{"s":"b","v":"3"}
val readTopics = spark.readStream.format("kafka").option("kafka.bootstrap.servers", "localhost:9092").option("subscribe", "myTopic").load()
val schema = new StructType()
.add("s",StringType)
.add("v",StringType)
val parseStringDF = readTopics.selectExpr("CAST(value AS STRING)")
val parseDF = parseStringDF.select(from_json(col("value"), schema).as("data"))
.select("data.*")
parseDF.writeStream
.format("console")
.outputMode("append")
.start()
//fails here
val listOfTopics = parseDF.select("s").map(row => (row.getString(0))).collect.toList
//unable to call the below methods
for (t <- listOfTopics ){
if(t == "a")
doA(listOfTopics)
else if (t == "b")
doB(listOfTopics)
else
println("do nothing")
}
spark.streams.awaitAnyTermination()
Questions:
How can I call a stand-alone (non-streaming) method in a streaming job?
I cannot use ForeachWriter here as I want to pass a SparkSession to methods and since SparkSession is not serializable, I cannot use ForeachWriter. What are the alternatives to call the methods doA and doB in parallel?
If you want to be able to collect data to the local Spark driver, you need to use parseDF.writeStream.foreachBatch, i.e. handle each micro-batch as a regular DataFrame rather than going through a ForeachWriter (see the sketch below)
It's unclear what you need the SparkSession for within your two methods, but since they are working on non-Spark datatypes, you probably shouldn't be using a SparkSession instance, anyway
Alternatively, you could .select() and filter on your topic column, then apply the functions to separate "topic-a" and "topic-b" dataframes, thus parallelizing the workload. Otherwise, you would be better off just using a regular KafkaConsumer from kafka-clients, or Kafka Streams, rather than Spark.
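A rough sketch of the foreachBatch approach (doA, doB, and parseDF come from the question; everything else is illustrative and may need adapting to your Spark/Scala versions):
import org.apache.spark.sql.DataFrame

// Each micro-batch arrives as a plain DataFrame, so collect() is allowed here
def processBatch(batchDF: DataFrame, batchId: Long): Unit = {
  val values = batchDF.select("s").collect().map(_.getString(0)).toList
  if (values.contains("a")) doA(values)
  if (values.contains("b")) doB(values)
}

val query = parseDF.writeStream
  .foreachBatch(processBatch _)
  .start()

query.awaitTermination()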
I want to test a method we have that is formatted something like this:
def extractTable( spark: SparkSession, /* unrelated other parameters */ ): DataFrame = {
// Code before that I want to test
val df = spark.read
.format("jdbc")
.option("url", "URL")
.option("driver", "<Driver>")
.option("fetchsize", "1000")
.option("dbtable", "select * from whatever")
.load()
// Code after that I want to test
}
And I am trying to make stubs of the spark object, and the DataFrameReader objects that the read and option methods return:
val sparkStub = stub[ SparkSession ]
val dataFrameReaderStub = stub[ DataFrameReader ]
( dataFrameReaderStub.format _).when(*).returning( dataFrameReaderStub ) // Works
( dataFrameReaderStub.option _).when(*, *).returning( dataFrameReaderStub ) // Error
( dataFrameReaderStub.load _).when(*).returning( ??? ) // Return a dataframe // Error
( sparkStub.read _).when().returning( dataFrameReaderStub )
But I am getting an error on dataFrameReaderStub.option and dataFrameReaderStub.load that says "Cannot resolve symbol option" and "Cannot resolve symbol load". But these methods definitely exist on the object that spark.read returns.
How can I resolve this error, or is there a better way to mock/test the code I have?
I would suggest you look at this library for testing Spark code: https://github.com/holdenk/spark-testing-base
Mix in this with your test suite: https://github.com/holdenk/spark-testing-base/wiki/SharedSparkContext
...or alternatively, spin up your own SparkSession with a local[2] master and load the test data from CSV/Parquet/JSON.
Mocking Spark classes will be quite painful and probably not a success. I am speaking from experience here, both working for a long time with Spark, and maintaining ScalaMock as a library.
You are better off using Spark in your tests, but not against the real datasources.
Instead, load the test data from CSV/Parquet/JSON, or programmatically generate it (if it contains timestamps and such).
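For instance, a test along those lines might look like this (a sketch; the fixture path, class names, and assertion are illustrative):
import org.apache.spark.sql.SparkSession
import org.scalatest.funsuite.AnyFunSuite

class ExtractTableSpec extends AnyFunSuite {

  test("logic around the jdbc read works on a fixture") {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("extract-table-test")
      .getOrCreate()

    // Read a small fixture instead of hitting the real database
    val df = spark.read.json("src/test/resources/sample.json")

    // ... exercise the "code before" / "code after" logic against df ...
    assert(df.count() > 0)

    spark.stop()
  }
}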
I have a CSV which I need to read as a Dataset. The CSV has 140 columns and no header.
I created a schema with StructType(Seq(StructField(...), StructField(...), ...)) and the code to read it is as follows:
object dataParser {
  def getData(inputPath: String, delimeter: String)(implicit spark: SparkSession): Dataset[MyCaseClass] = {
    val parsedData: Dataset[MyCaseClass] = spark.read
      .format("csv")
      .option("header", "false")
      .option("delimiter", delimeter)
      .option("inferSchema", "true")
      .schema(mySchema)
      .load(inputPath)
      .as[MyCaseClass]
    parsedData
  }
}
And the case class I created is like:
case class MyCaseClass(
  mycaseClass1: MyCaseClass1,
  mycaseClass2: MyCaseClass2,
  mycaseClass3: MyCaseClass3,
  mycaseClass4: MyCaseClass4,
  mycaseClass5: MyCaseClass5,
  mycaseClass6: MyCaseClass6,
  mycaseClass7: MyCaseClass7
)
case class MyCaseClass1(
  // first 20 columns of the CSV and their datatypes
)
case class MyCaseClass2(
  // next 20 columns of the CSV and their datatypes
)
and so on.
But when I try to compile it, it gives me the error below:
Unable to find encoder for type stored in a Dataset. Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases.
[error] .as[myCaseClass]
I'm calling this from my Scala app as:
object MyTestApp{
def main(args: Array[String]): Unit ={
implicit val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
import spark.implicits._
run(args)
}
def run(args: Array[String])(implicit spark: SparkSession): Unit = {
val inputPath = args.get("inputData")
val delimeter = Constants.delimeter
val myData = Dataparser.getData(inputPath, delimeter)
}
}
I'm also not very sure about the approach, as I'm new to Datasets.
I saw multiple answers around this issue, but they were mainly for a very small number of columns which can be contained within the scope of a single case class, and with a header, which makes this a little simpler.
Any help would be really appreciated.
Thanks to all the viewers. Actually, I found the issue. Posting the answer here so that others who come across such an issue will be able to resolve it.
I needed to import spark.implicits._ here:
object dataParser {
  def getData(inputPath: String, delimeter: String)(implicit spark: SparkSession): Dataset[MyCaseClass] = {
    import spark.implicits._  // this was the missing piece
    val parsedData: Dataset[MyCaseClass] = spark.read
      .format("csv")
      .option("header", "false")
      .option("delimiter", delimeter)
      .option("inferSchema", "true")
      .schema(mySchema)
      .load(inputPath)
      .as[MyCaseClass]
    parsedData
  }
}
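For reference, importing spark.implicits._ is what brings an implicit Encoder[MyCaseClass] into scope for .as[MyCaseClass]. An alternative that should work as well (a sketch, assuming MyCaseClass and its nested classes are plain case classes) is to declare the encoder explicitly:
import org.apache.spark.sql.{Encoder, Encoders}

// Product encoder derived from the case class definition
implicit val myCaseClassEncoder: Encoder[MyCaseClass] = Encoders.product[MyCaseClass]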
I am learning more about Scala and Spark but have become stuck on how to structure a function when I am using two tables as input. My goal is to condense my code and utilise more functions. I am stuck on how to structure the functions when using two tables which I intend to join. My code without a function looks like:
val spark = SparkSession
.builder()
.master("local[*]")
.appName("XX1")
.getOrCreate()
val df1 = spark.sqlContext.read
.format("com.databricks.spark.csv")
.option("header", "true")
.option("delimiter", ",")
.option("inferSchema", "true")
.load("C:/Users/YYY/Documents/YYY.csv")
// df1: org.apache.spark.sql.DataFrame = [customerID: int, StoreID: int, FirstName: string, Surname: string, dateofbirth: int]
val df2 = spark.sqlContext.read
.format("com.databricks.spark.csv")
.option("header", "true")
.option("delimiter", ",")
.option("inferSchema", "true")
.load("C:/Users/XXX/Documents/XXX.csv")
df1.printSchema()
df1.createOrReplaceTempView("customerinfo")
df2.createOrReplaceTempView("customerorders")
def innerjoinA(df1: DataFrame, df2:Dataframe): Array[String]={
val innerjoindf= df1.join(df2,"customerId")
}
innerjoin().show()
}
My question is: how do I properly define the function for innerjoinA (and why?), and how exactly am I able to call it later in the program? And to a greater point, what else could I format as a function in this example?
You could do something like this.
Create a function to create the SparkSession and one to read the CSV. You can put these functions into a different file if they need to be called by other programs as well.
Just for the join, there is no need to create a function. However, you could create one to reflect the business flow and give it a proper name.
import org.apache.spark.sql.{DataFrame, SparkSession}
def getSparkSession(): SparkSession = {
val spark = SparkSession
.builder()
.master("local[*]")
.appName("XX1")
.getOrCreate()
spark
}
def readCSV(filePath: String): DataFrame = {
val df = getSparkSession().sqlContext.read
.format("com.databricks.spark.csv")
.option("header", "true")
.option("delimiter", ",")
.option("inferSchema", "true")
.load(filePath)
df
}
def getCustomerDetails(customer: DataFrame, details: DataFrame) : DataFrame = {
customer.join(details,"customerId")
}
val xxxDF = readCSV("C:/Users/XXX/Documents/XXX.csv")
val yyyDF = readCSV("C:/Users/XXX/Documents/YYY.csv")
getCustomerDetails(xxxDF, yyyDF).show()
The basic premise of grouping complex transformations and joins into methods is sound. Only you know if a special innerjoin method makes sense in your use case.
I usually define them as extension methods so I can chain them one after another.
object DataFrameExtensions { // could also be a trait if you prefer to mix it in
implicit class JoinDataFrameExtensions(df:DataFrame){
def innerJoin(df2:DataFrame):DataFrame = df.join(df2, Seq("ColumnName"))
}
}
And then later on in the code I import/mix in the methods I want and call them on the DataFrame:
originalDataFrame.innerJoin(toBeJoinedDataFrame).show()
I prefer extension methods but you can also just declare a method DataFrame => DataFrame and use it in the .transform method already defined on the Dataset API.
def innerJoin(df2:DataFrame)(df1:DataFrame):DataFrame = df1.join(df2, Seq("ColumnName"))
val join = innerJoin(toBeJoinedDataFrame) _
originalDataFrame.transform(join).show()
I would like to know about the unit testing side of Spark Structured Streaming. My scenario is, I am getting data from Kafka and I am consuming it using Spark Structured Streaming and applying some transformations on top of the data.
I am not sure how I can test this using Scala and Spark. Can someone tell me how to do unit testing in Structured Streaming using Scala? I am new to streaming.
tl;dr Use MemoryStream to add events and memory sink for the output.
The following code should help to get started:
import org.apache.spark.sql.execution.streaming.{LongOffset, MemoryStream}
implicit val sqlCtx = spark.sqlContext
import spark.implicits._
val events = MemoryStream[Event]
val sessions = events.toDS
assert(sessions.isStreaming, "sessions must be a streaming Dataset")
// use sessions event stream to apply required transformations
val transformedSessions = ...
val streamingQuery = transformedSessions
.writeStream
.format("memory")
.queryName(queryName)
.option("checkpointLocation", checkpointLocation)
.outputMode(queryOutputMode)
.start
// Add events to MemoryStream as if they came from Kafka
val batch = Seq(
eventGen.generate(userId = 1, offset = 1.second),
eventGen.generate(userId = 2, offset = 2.seconds))
val currentOffset = events.addData(batch)
streamingQuery.processAllAvailable()
events.commit(currentOffset.asInstanceOf[LongOffset])
// check the output
// The output is in queryName table
// The following code simply shows the result
spark
.table(queryName)
.show(truncate = false)
So, I tried to implement the answer from @Jacek, but I couldn't find how to create the eventGen object, and I also wanted to test a small streaming application that writes data to the console. I am also using MemoryStream, and here I show a small working example.
The class that I am testing is:
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.{DataFrame, SparkSession, functions}
object StreamingDataFrames {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder()
.appName(StreamingDataFrames.getClass.getSimpleName)
.master("local[2]")
.getOrCreate()
val lines = readData(spark, "socket")
val streamingQuery = writeData(lines)
streamingQuery.awaitTermination()
}
def readData(spark: SparkSession, source: String = "socket"): DataFrame = {
val lines: DataFrame = spark.readStream
.format(source)
.option("host", "localhost")
.option("port", 12345)
.load()
lines
}
def writeData(df: DataFrame, sink: String = "console", queryName: String = "calleventaggs", outputMode: String = "append"): StreamingQuery = {
println(s"Is this a streaming data frame: ${df.isStreaming}")
val shortLines: DataFrame = df.filter(functions.length(col("value")) >= 3)
val query = shortLines.writeStream
.format(sink)
.queryName(queryName)
.outputMode(outputMode)
.start()
query
}
}
I test only the writeData method. That is why I split the query into two methods.
Then here is the spec to test the class. I use a SharedSparkSession class to facilitate opening and closing the Spark context, like the one shown here (a sketch of such a trait appears after the spec below).
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.streaming.{LongOffset, MemoryStream}
import org.github.explore.spark.SharedSparkSession
import org.scalatest.funsuite.AnyFunSuite
class StreamingDataFramesSpec extends AnyFunSuite with SharedSparkSession {
test("spark structured streaming can read from memory socket") {
// We can import sql implicits
implicit val sqlCtx = sparkSession.sqlContext
import sqlImplicits._
val events = MemoryStream[String]
val queryName: String = "calleventaggs"
// Add events to MemoryStream as if they came from Kafka
val batch = Seq(
"this is a value to read",
"and this is another value"
)
val currentOffset = events.addData(batch)
val streamingQuery = StreamingDataFrames.writeData(events.toDF(), "memory", queryName)
streamingQuery.processAllAvailable()
events.commit(currentOffset.asInstanceOf[LongOffset])
val result: DataFrame = sparkSession.table(queryName)
result.show
streamingQuery.awaitTermination(1000L)
assertResult(batch.size)(result.count)
val values = result.take(2)
assertResult(batch(0))(values(0).getString(0))
assertResult(batch(1))(values(1).getString(0))
}
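The SharedSparkSession trait used above comes from the author's own project and is not shown here; a minimal sketch of what such a helper could look like (an assumption about its shape, not the original code):
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, Suite}

trait SharedSparkSession extends BeforeAndAfterAll { self: Suite =>

  // Lazily created local SparkSession shared by all tests in the suite
  @transient lazy val sparkSession: SparkSession = SparkSession.builder()
    .master("local[2]")
    .appName("structured-streaming-tests")
    .getOrCreate()

  override def afterAll(): Unit = {
    sparkSession.stop()
    super.afterAll()
  }
}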
}