Mocking SparkSession for unit testing - scala

I have a method in my Spark application that loads data from a MySQL database. The method looks something like this:
trait DataManager {
  val session: SparkSession

  def loadFromDatabase(input: Input): DataFrame = {
    session.read.jdbc(input.jdbcUrl, s"(${input.selectQuery}) T0",
      input.columnName, 0L, input.maxId, input.parallelism, input.connectionProperties)
  }
}
The method does nothing other than executing the jdbc method to load data from the database. How can I test it? The standard approach would be to create a mock of the session object, which is an instance of SparkSession, but since SparkSession has a private constructor I was not able to mock it with ScalaMock.
The core issue is that this function is purely side-effecting (the side effect being pulling data from a relational database), and I am unsure how to unit test it given my trouble mocking SparkSession.
So is there any way to mock SparkSession, or any better way than mocking to test this method?

In your case I would recommend not mocking the SparkSession. Mocking it would more or less mock the entire function (which you could do anyway). If you want to test this function, my suggestion is to run an embedded database (like H2) and use a real SparkSession. To do this you need to provide the SparkSession to your DataManager.
Untested sketch:
Your code:
class DataManager(session: SparkSession) {
  def loadFromDatabase(input: Input): DataFrame = {
    session.read.jdbc(input.jdbcUrl, s"(${input.selectQuery}) T0",
      input.columnName, 0L, input.maxId, input.parallelism, input.connectionProperties)
  }
}
Your test-case:
class DataManagerTest extends FunSuite with BeforeAndAfterAll {
  override def beforeAll(): Unit = {
    val conn = DriverManager.getConnection("jdbc:h2:~/test", "sa", "")
    // your insert statements go here
    conn.close()
  }

  test("should load data from database") {
    val dm = new DataManager(SparkSession.builder().master("local[*]").getOrCreate())
    val input = Input(jdbcUrl = "jdbc:h2:~/test", selectQuery = "SELECT whateveryouneed FROM whereeveryouputit")
    val loadedData = dm.loadFromDatabase(input)
    // assert on loadedData here
  }
}

You can use mockito-scala to mock SparkSession, as shown in this article.
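For illustration, here is an untested sketch of that idea using plain Mockito (which mockito-scala wraps). RETURNS_DEEP_STUBS lets the session.read.jdbc(...) call chain be stubbed in one step; someInput is a placeholder for whatever Input you construct in the test:
import java.util.Properties

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.mockito.ArgumentMatchers.{any, anyInt, anyLong, anyString}
import org.mockito.Mockito.{mock, when, RETURNS_DEEP_STUBS}

// Deep stubs make session.read return a stubbed DataFrameReader automatically.
val mockSession = mock(classOf[SparkSession], RETURNS_DEEP_STUBS)
val dummyDf     = mock(classOf[Dataset[Row]]) // DataFrame is an alias for Dataset[Row]

when(mockSession.read.jdbc(anyString, anyString, anyString,
  anyLong, anyLong, anyInt, any(classOf[Properties]))).thenReturn(dummyDf)

// Wire the mock into the trait from the question and check the delegation.
val dm = new DataManager { override val session: SparkSession = mockSession }
assert(dm.loadFromDatabase(someInput) eq dummyDf) // someInput: any Input you construct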


How to mock a function to return a dummy value in scala? [duplicate]

object ReadUtils {
  def readData(sqlContext: SQLContext, fileType: FileType.Value): List[DataFrame] = {
    //some logic
  }
}
I am writing a test for the execute function:
import com.utils.ReadUtils.readData

class Logs extends Interface with NativeImplicits {
  override def execute(sqlContext: SQLContext) {
    val inputDFs: List[DataFrame] = readData(sqlContext, FileType.PARQUET)
    //some logic
  }
}
How do I mock the readData function to return a dummy value when writing a test for the execute function? Currently it calls the actual function:
test("Log Test") {
val df1 = //some dummy df
val sparkSession = SparkSession
.builder()
.master("local[*]")
.appName("test")
.getOrCreate()
sparkSession.sparkContext.setLogLevel("ERROR")
val log = new Logs()
val mockedReadUtils = mock[ReadUtils.type]
when(mockedReadUtils.readData(sparkSession.sqlContext,FileType.PARQUET)).thenReturn(df1)
log.execute(sparkSession.sqlContext)
The simple answer is: you can't. Objects are singletons in Scala, and you can't mock singletons; that's one of the reasons you're told to avoid singletons as much as possible.
You could mock sqlContext instead, along with all of its functions that are called inside readData.
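For illustration, an untested sketch of that second option with Mockito deep stubs; it assumes readData only calls sqlContext.read.parquet(path) internally, which may not match your actual //some logic:
import org.apache.spark.sql.{Dataset, Row, SQLContext}
import org.mockito.ArgumentMatchers.anyString
import org.mockito.Mockito.{mock, when, RETURNS_DEEP_STUBS}

val dummyDf        = mock(classOf[Dataset[Row]])
val mockSqlContext = mock(classOf[SQLContext], RETURNS_DEEP_STUBS)

// Whatever path readData asks for, hand back the dummy DataFrame.
when(mockSqlContext.read.parquet(anyString)).thenReturn(dummyDf)

val log = new Logs()
log.execute(mockSqlContext) // readData now reads from the stubbed sqlContext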
As another approach, you could try to add Dependency Injection with some sort of Cake Pattern - https://medium.com/rahasak/scala-cake-pattern-e0cd894dae4e
trait DataReader {
  def readData(sqlContext: SQLContext, fileType: FileType.Value): List[DataFrame]
}

trait RealDataReader extends DataReader {
  def readData(sqlContext: SQLContext, fileType: FileType.Value): List[DataFrame] = {
    // some code
  }
}

trait MockedDataReader extends DataReader {
  def readData(sqlContext: SQLContext, fileType: FileType.Value): List[DataFrame] = {
    // some mocking code
  }
}

abstract class Logs extends Interface with NativeImplicits with DataReader {
  override def execute(sqlContext: SQLContext) {
    val inputDFs: List[DataFrame] = readData(sqlContext, FileType.PARQUET)
    //some logic
  }
}

class RealLogs extends Logs with RealDataReader     // that would be the real class
class MockedLogs extends Logs with MockedDataReader // that would be the class for tests
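A hypothetical test using the mocked variant could then be as small as this (assuming MockedDataReader returns canned DataFrames and sparkSession comes from the test setup shown above):
test("execute runs against mocked data") {
  val logs = new MockedLogs // picks up readData from MockedDataReader
  logs.execute(sparkSession.sqlContext)
  // assert on whatever execute produces
}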

What is the best way to execute "non-transformation" actions on elements of a Dataset?

I'm new to Spark and I'm looking for a way to execute actions on all elements of a Dataset with Spark Structured Streaming.
I know this is a specific use case: what I want is to iterate through all elements of the Dataset, perform an action on each, then continue to work with the Dataset.
Example:
I have val df: Dataset[Person], and I would like to be able to do something like:
def execute(df: Dataset[Person]): Dataset[Person] = {
  df.foreach((p: Person) => {
    someHttpClient.doRequest(httpPostRequest(p.asString)) // this is pseudo code / not compiling
  })
  df
}
Unfortunately, foreach is not available with Structured Streaming; I get the error "Queries with streaming sources must be executed with writeStream.start()".
I tried to use map(), but then the error "Task not serializable" occurred, I think because the http request, or the http client, is not serializable.
I know Spark is mostly used to filter and transform, but is there a way to handle this specific use case well?
Thanks :)
val conf = new SparkConf().setMaster("local[*]").setAppName("Example")
val jssc = new JavaStreamingContext(conf, Durations.seconds(1)) // the second argument is the interval at which streaming data will be divided into batches
Before concluding whether a solution exists or not, let's ask a few questions.
How does Spark Streaming work?
Spark Streaming receives live input data streams from an input source and divides the data into batches, which are then processed by the Spark engine; the final batch results are pushed down to downstream applications.
How does the batch execution start?
Spark evaluates all transformations applied on a DStream lazily; they only execute when an action is triggered (i.e. only when you start the streaming context):
jssc.start(); // Start the computation
jssc.awaitTermination(); // Wait for the computation to terminate.
Note: each batch of a DStream contains multiple partitions (it is just like running a sequence of Spark batch jobs until the input source stops producing data).
So you can have custom logic like below.
dStream.foreachRDD(new VoidFunction[JavaRDD[Object]] {
  override def call(t: JavaRDD[Object]): Unit = {
    t.foreach(new VoidFunction[Object] {
      override def call(t: Object): Unit = {
        // pseudo code: someHttpClient.doRequest(httpPostRequest(t.asString))
      }
    })
  }
})
But again, make sure your someHttpClient is serializable, or create that object inside the loop as shown below.
dStream.foreachRDD(new VoidFunction[JavaRDD[Object]] {
  override def call(t: JavaRDD[Object]): Unit = {
    // create the someHttpClient object here
    t.foreach(new VoidFunction[Object] {
      override def call(t: Object): Unit = {
        // pseudo code: someHttpClient.doRequest(httpPostRequest(t.asString))
      }
    })
  }
})
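A common refinement of the same idea, sketched here with the Scala DStream API, is to create the client once per partition via foreachPartition, so it is built on the executor and never has to be serialized from the driver (SomeHttpClient and httpPostRequest are placeholders carried over from the question's pseudo code):
import org.apache.spark.streaming.dstream.DStream

def sendAll(stream: DStream[String]): Unit =
  stream.foreachRDD { rdd =>
    rdd.foreachPartition { records =>
      val client = new SomeHttpClient() // hypothetical client, constructed on the executor
      try records.foreach(r => client.doRequest(httpPostRequest(r)))
      finally client.close()
    }
  }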
For Spark Structured Streaming:
import org.apache.spark.sql._
import org.apache.spark.sql.streaming.StreamingQuery

val spark = SparkSession
  .builder()
  .appName("example")
  .getOrCreate()

// example socket source, copied from the structured-streaming docs
val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()

// Per-row side effects go through writeStream.foreach with a ForeachWriter;
// calling lines.foreach directly is what raises the
// "Queries with streaming sources must be executed with writeStream.start()" error.
val query: StreamingQuery = lines.writeStream
  .foreach(new ForeachWriter[Row] {
    override def open(partitionId: Long, epochId: Long): Boolean = {
      // create the someHttpClient object here (once per partition and epoch)
      // to tackle the serialization errors
      true
    }
    override def process(row: Row): Unit = {
      // someHttpClient.doRequest(httpPostRequest(row.toString))
    }
    override def close(errorOrNull: Throwable): Unit = {
      // shut the client down here
    }
  })
  .start()

query.awaitTermination()
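Since Spark 2.4 there is also foreachBatch, which hands you each micro-batch as a regular Dataset so batch-style code can be reused. A rough sketch, again with the hypothetical SomeHttpClient / httpPostRequest placeholders:
// Typing the function explicitly avoids the foreachBatch overload ambiguity in Scala 2.12.
val sendBatch: (Dataset[Row], Long) => Unit = (batch, _) =>
  batch.rdd.foreachPartition { rows =>
    val client = new SomeHttpClient() // built on the executor, one per partition
    try rows.foreach(r => client.doRequest(httpPostRequest(r.toString)))
    finally client.close()
  }

val query2 = lines.writeStream.foreachBatch(sendBatch).start()
query2.awaitTermination()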

How to Call Methods in Scala

I have written some code in Scala and created one method inside it; now I want to call this method from another program, but I'm not getting any result after calling it.
First program:
object helper_class {
  def driver {
    def main(args: Array[String]) {
      val sparksession = SparkSession.builder().appName("app").enableHiveSupport().getOrCreate();
      val filepath: String = args(0)
      val d1 = spark.sql(s"load data inpath '${args(0)}' into table databasename.tablename")
      //some more reusable code
    }
  }
}
Second program:
import Packagename.helper_class.driver

object child_program {
  def main(args: Array[String]) {
    driver // I want to call this method from helper_class
  }
}
If I remove def main(args: Array[String]) from the first program, it gives an error near args(0) saying args(0) is not found.
I am planning to pass args(0) via spark-submit.
Can someone please help me with how I should implement this?
The main method is the entry point for running a program, so it can't be wrapped inside another method.
I think what you are trying to do is to load some Spark setup code and then use that "driver" to do something else. This is my best guess at how that could work:
object ChildProgram {
  def main(args: Array[String]): Unit = {
    val driver = Helper.driver(args)
    // do something with the driver
  }
}

object Helper {
  def driver(args: Array[String]) = {
    val sparksession = SparkSession.builder().appName("app").enableHiveSupport().getOrCreate()
    val filepath: String = args(0)
    val d1 = sparksession.sql(s"load data inpath '${args(0)}' into table databasename.tablename")
    //some more reusable code
    // I am assuming you are going to return the driver here
  }
}
That said, I would highly recommend reading a bit on Scala before attempting to go down that route, because you are likely to face even more obstacles. If I am to recommend a resource, you can try the awesome book "Scala for the Impatient" which should get you up and running very quickly.

How to unit-test that a class is serializable for Spark?

I just found a bug caused by a class's serialization in Spark.
=> Now I want to write a unit test for it, but I don't see how.
Notes:
the failure happens in a (de)serialized object which has been broadcast
I want to test exactly what Spark will do, to assert it will work once deployed
the class to serialize is a standard class (not a case class) which extends Serializable
Looking into Spark's broadcast code, I found a way. But it uses private Spark code, so it might become invalid if Spark's internals change. Still, it works.
Add a test class in a package starting with org.apache.spark, such as:
package org.apache.spark.my_company_tests

// [imports]

/**
 * test data that need to be broadcast in spark (using kryo)
 */
class BroadcastSerializationTests extends FlatSpec with Matchers {

  it should "serialize a transient val, which should be lazy" in {
    val data = new MyClass(42) // data to test
    val conf = new SparkConf()

    // Serialization
    // code found in TorrentBroadcast.(un)blockifyObject that is used by TorrentBroadcastFactory
    val blockSize = 4 * 1024 * 1024 // 4 MB
    val out = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
    val ser = new KryoSerializer(conf).newInstance() // Here I test using KryoSerializer; you can use JavaSerializer too
    val serOut = ser.serializeStream(out)

    Utils.tryWithSafeFinally { serOut.writeObject(data) } { serOut.close() }

    // Deserialization
    val blocks = out.toChunkedByteBuffer.getChunks()
    val in = new SequenceInputStream(blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration)
    val serIn = ser.deserializeStream(in)
    val data2 = Utils.tryWithSafeFinally { serIn.readObject[MyClass]() } { serIn.close() }

    // run test on data2
    data2.yo shouldBe data.yo
  }
}
class MyClass(i: Int) extends Serializable {
  @transient val yo = 1 to i // add lazy to make the test pass: non-lazy transient vals are not recomputed after deserialization
}
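If you can live without reproducing the exact TorrentBroadcast block layout, a lighter-weight variant (untested sketch) is to round-trip the object through a KryoSerializer instance directly; it avoids the ChunkedByteBuffer internals while still catching most (de)serialization bugs:
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

val ser = new KryoSerializer(new SparkConf()).newInstance()

val data  = new MyClass(42)
val data2 = ser.deserialize[MyClass](ser.serialize(data)) // serialize returns a ByteBuffer

data2.yo shouldBe data.yo // same assertion as above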

Scala: how can I perform actions when tests are over?

I'm using scalatest_2.11 version 2.2.1. I'm trying to write tests that run on a SparkContext. How can I initiate a SparkContext when the tests begin, carry this sc through all the tests, and then stop it when they are all done?
(I know that it's supposed to stop by itself, but I'd still like to do it myself.)
You can use BeforeAndAfterAll from ScalaTest. Define a base trait that starts and stops the SparkContext and use it with your other tests.
trait SparkTest extends BeforeAndAfterAll {
  self: Suite =>

  @transient var sc: SparkContext = _

  override def beforeAll(): Unit = {
    val conf = new SparkConf().
      setMaster("local[*]").
      setAppName("test")
    sc = new SparkContext(conf)
    super.beforeAll()
  }

  override def afterAll(): Unit = {
    try {
      sc.stop()
    } finally {
      super.afterAll()
    }
  }
}

// Mix-in the trait with your tests like below.
class MyTest extends FunSuite with SparkTest {
  test("my test") {
    // you can access "sc" here.
  }
}
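Alternatively, the spark-testing-base library ships a ready-made trait that handles the same lifecycle. A minimal sketch, assuming the com.holdenkarau spark-testing-base test dependency is on your classpath:
import com.holdenkarau.spark.testing.SharedSparkContext
import org.scalatest.FunSuite

class MyOtherTest extends FunSuite with SharedSparkContext {
  test("counting with the shared context") {
    // "sc" is provided and cleaned up by SharedSparkContext
    assert(sc.parallelize(1 to 3).count() === 3)
  }
}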