Mocking functions used inside Spark UDFs - Scala

I have a user-defined function as follows:
def myFunc(df: DataFrame, column_name: String, func: (String) => String = myFunc) = {
  // Wrap `func` in a UDF and apply it to `column_name`, adding the result as "newColumn".
  val my_udf = udf { (columnValue: String) => func(columnValue) }
  df.withColumn("newColumn", my_udf(df(column_name)))
}
Now, while writing test cases, I want to mock the func argument. I'm using Mockito for testing, so I created a mock function with:
var fmock = mock[(String)=>String]
when(fmock(any[String])).thenReturn("default_string_write_value")
However, when I try to pass fmock, I get the following exception: java.io.NotSerializableException.
So I tried to make fmock serializable by declaring it as follows:
var fmock = mock[(String)=>String](Mockito.withSettings().serializable())
Now, when I call collect on the DataFrame returned by myFunc, I get the following exception:
java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD
The error only shows up when the collect operation runs.
Also, a weird thing is happening. The test case passes in IntelliJ (that's the IDE I'm using), but fails when I run it with Maven.
The test case is as follows (assume I already have a DataFrame TESTDF in place with columns COL1, COL2, COL3):
val fmock = mock[(String)=>String](Mockito.withSettings().serializable())
when(fmock(any[String])).thenReturn("default_string_write_value")
val ansDf = myFunc(TESTDF, COL1, fmock)
var rowValueArray = ansDf.select("newColumn").collect.map(x => x(0).asInstanceOf[String])
assert(rowValueArray(0) == "default_string_write_value")
Note that I'm using Spark 2.3.3 and Scala 2.11.8. The application itself runs fine; the issue only appears in the test cases, and only when they are run with Maven.
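A workaround worth sketching (an assumption on my part, not something from the thread): plain Scala function literals are Serializable, so a hand-written stub instead of a Mockito mock sidesteps both exceptions, at the cost of losing Mockito's invocation verification:
// Plain stub instead of a Mockito mock; Scala function literals are
// Serializable, so the UDF closure can be shipped to the executors.
val fstub: String => String = _ => "default_string_write_value"
val ansDf = myFunc(TESTDF, COL1, fstub)
val rowValueArray = ansDf.select("newColumn").collect.map(_.getString(0))
assert(rowValueArray(0) == "default_string_write_value")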

Related

How do we write a unit test for a UDF in Scala

I have the following user-defined function in Scala:
val returnKey: UserDefinedFunction = udf((key: String) => {
val abc: String = key
abc
})
Now I want to unit test whether it returns the correct value. How do I write a unit test for it? This is what I tried:
class CommonTest extends FunSuite with Matchers {
  test("Invalid String Test") {
    val key = "Test Key"
    val returnedKey = returnKey(col(key))
    returnedKey should equal (key)
  }
}
But since returnKey is a UDF, I am not sure how to call it or how to test this particular scenario.
A UserDefinedFunction is effectively a wrapper around your Scala function that can be used to transform Column expressions. In other words, the UDF given in the question wraps a function of String => String to create a function of Column => Column.
I usually pick one of two approaches to testing UDFs.
Test the UDF in a Spark plan. In other words, create a test DataFrame and apply the UDF to it, then collect the DataFrame and check its contents.
// In your test
val testDF = Seq("Test Key", "", null).toDS().toDF("s")
val result = testDF.select(returnKey(col("s"))).as[String].collect.toSet
result should be(Set("Test Key", "", null))
Notice that this lets us test all our edge cases in a single Spark plan. In this case, I have included tests for the empty string and null.
Extract the Scala function being wrapped by the UDF and test it as you would any other Scala function.
def returnKeyImpl(key: String) = {
  val abc: String = key
  abc
}

val returnKey = udf(returnKeyImpl _)
Now we can test returnKeyImpl by passing in strings and checking the string output.
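For instance, a minimal sketch of such a test in the same FunSuite with Matchers style used above (the test class name and the chosen cases are assumptions):
// Plain unit test of the underlying Scala function -- no SparkSession needed.
class ReturnKeyImplTest extends FunSuite with Matchers {
  test("returnKeyImpl returns its input unchanged") {
    returnKeyImpl("Test Key") should equal ("Test Key")
    returnKeyImpl("") should equal ("")
    returnKeyImpl(null) should be (null)
  }
}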
Which is better?
There is a trade-off between these two approaches, and my recommendation is different depending on the situation.
If you are doing a larger test on bigger datasets, I would recommend testing the UDF in a Spark job.
Testing the UDF in a Spark job can raise issues that you wouldn't catch by only testing the underlying Scala function. For example, if your underlying Scala function relies on a non-serializable object, then Spark will be unable to broadcast the UDF to the workers and you will get an exception.
On the other hand, starting Spark jobs in every unit test for every UDF can be quite slow. If you are only doing a small unit test, it will likely be faster to just test the underlying Scala function.

Scala UDF with multiple parameters used in PySpark

I have a UDF written in Scala that I'd like to be able to call from a PySpark session. The UDF takes two parameters: a string column value and a second string parameter. I've been able to successfully call the UDF when it takes only a single parameter (the column value). I'm struggling to call the UDF when multiple parameters are required. Here's what I've been able to do so far in Scala and then through PySpark:
Scala UDF:
class SparkUDFTest() extends Serializable {
  def stringLength(columnValue: String, columnName: String): Int = {
    LOG.info("Column name is: " + columnName)
    columnValue.length
  }
}
When using this in Scala, I've been able to register and use this UDF:
Scala main class:
val udfInstance = new SparkUDFTest()
val stringLength = spark.sqlContext.udf.register("stringlength", udfInstance.stringLength _)
val newDF = df.withColumn("name", stringLength(col("email"), lit("email")))
The above works successfully. Here's the attempt through PySpark:
def testStringLength(colValue, colName):
    testpackage = "com.test.example.udf.SparkUDFTest"
    udfInstance = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader().loadClass(testpackage).newInstance().stringLength().apply
    return Column(udfInstance(_to_seq(sc, [colValue], _to_java_column), colName))
Call the UDF in PySpark:
df.withColumn("email", testStringLength("email", lit("email")))
Doing the above and making some adjustments in PySpark gives me the following errors:
py4j.Py4JException: Method getStringLength([]) does not exist
or
java.lang.ClassCastException: com.test.example.udf.SparkUDFTest$$anonfun$stringLength$1 cannot be cast to scala.Function1
or
TypeError: 'Column' object is not callable
I was able to modify the UDF to take just a single parameter (the column value), and I could successfully call it and get back a new DataFrame.
Scala UDF class:
class SparkUDFTest() extends Serializable {
  def testStringLength(): UserDefinedFunction = udf(stringLength _)

  def stringLength(columnValue: String): Int = {
    LOG.info("Column value is: " + columnValue)
    columnValue.length
  }
}
Updating Python code:
def testStringLength(colValue, colName):
    testpackage = "com.test.example.udf.SparkUDFTest"
    udfInstance = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader().loadClass(testpackage).newInstance().testStringLength().apply
    return Column(udfInstance(_to_seq(sc, [colValue], _to_java_column)))
The above works successfully. I'm still struggling to call the UDF when it takes an extra parameter. How can the second parameter be passed to the UDF through PySpark?
I was able to resolve this by using currying. First I registered the UDF as:
def testStringLength(columnName: String): UserDefinedFunction =
  udf((colValue: String) => stringLength(colValue, columnName))
Then I called the UDF:
udfInstance = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader().loadClass(testpackage).newInstance().testStringLength("email").apply
df.withColumn("email", Column(udfInstance(_to_seq(sc, [col("email")], _to_java_column))))
This can be cleaned up a bit more but it's how I got it to work.
Edit: The reason I went with currying is that even when I used 'lit' on the second argument that I wanted to pass to the UDF as a String, I kept experiencing the "TypeError: 'Column' object is not callable" error. In Scala I did not experience this issue. I am not sure why this was happening in PySpark; it may be due to some complication between the Python interpreter and the Scala code. It is still unclear, but currying works for me.
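For reference, here is a rough sketch of what the curried factory could look like on the Scala side, putting the pieces above together (the class and method names come from the question; the logging call is only illustrative):
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf

class SparkUDFTest() extends Serializable {
  // Bind the constant column name on the Scala side so the returned UDF
  // only takes the column value; this is what PySpark ends up calling.
  def testStringLength(columnName: String): UserDefinedFunction =
    udf((columnValue: String) => stringLength(columnValue, columnName))

  def stringLength(columnValue: String, columnName: String): Int = {
    println("Column name is: " + columnName) // stand-in for the question's LOG.info
    columnValue.length
  }
}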

flatMap Compile Error found: TraversableOnce[String] required: TraversableOnce[String]

EDIT#2: This might be memory related. Logs are showing out-of-heap errors.
Yes, it is definitely memory related. docker logs reports all the out-of-heap spewage from the JVM, but the Jupyter web notebook does not pass that on to the user. Instead the user gets kernel failures and occasional weird behavior, such as code not compiling correctly.
Spark 1.6, particularly docker run -d .... jupyter/all-spark-notebook
I would like to count accounts in a file of ~1 million transactions. This is simple enough and could be done without Spark, but I've hit an odd error trying it with Spark/Scala.
The input data has type RDD[etherTrans], where etherTrans is a custom type wrapping a single transaction: a timestamp, the from and to accounts, and the value transacted in ether.
class etherTrans(ts_in: Long, afrom_in: String, ato_in: String, ether_in: Float)
  extends Serializable {
  var ts: Long = ts_in
  var afrom: String = afrom_in
  var ato: String = ato_in
  var ether: Float = ether_in
  override def toString(): String = ts.toString + "," + afrom + "," + ato + "," + ether.toString
}
data:RDD[etherTrans] looks ok:
data.take(10).foreach(println)
etherTrans(1438918233,0xa1e4380a3b1f749673e270229993ee55f35663b4,0x5df9b87991262f6ba471f09758cde1c0fc1de734,3.1337E-14)
etherTrans(1438918613,0xbd08e0cddec097db7901ea819a3d1fd9de8951a2,0x5c12a8e43faf884521c2454f39560e6c265a68c8,19.9)
etherTrans(1438918630,0x63ac545c991243fa18aec41d4f6f598e555015dc,0xc93f2250589a6563f5359051c1ea25746549f0d8,599.9895)
etherTrans(1438918983,0x037dd056e7fdbd641db5b6bea2a8780a83fae180,0x7e7ec15a5944e978257ddae0008c2f2ece0a6090,100.0)
etherTrans(1438919175,0x3f2f381491797cc5c0d48296c14fd0cd00cdfa2d,0x4bd5f0ee173c81d42765154865ee69361b6ad189,803.9895)
etherTrans(1438919394,0xa1e4380a3b1f749673e270229993ee55f35663b4,0xc9d4035f4a9226d50f79b73aafb5d874a1b6537e,3.1337E-14)
etherTrans(1438919451,0xc8ebccc5f5689fa8659d83713341e5ad19349448,0xc8ebccc5f5689fa8659d83713341e5ad19349448,0.0)
etherTrans(1438919461,0xa1e4380a3b1f749673e270229993ee55f35663b4,0x5df9b87991262f6ba471f09758cde1c0fc1de734,3.1337E-14)
etherTrans(1438919491,0xf0cf0af5bd7d8a3a1cad12a30b097265d49f255d,0xb608771949021d2f2f1c9c5afb980ad8bcda3985,100.0)
etherTrans(1438919571,0x1c68a66138783a63c98cc675a9ec77af4598d35e,0xc8ebccc5f5689fa8659d83713341e5ad19349448,50.0)
This next function compiles fine and is written this way because earlier attempts complained of a type mismatch between Array[String] or List[String] and TraversableOnce[?]:
def arrow(e:etherTrans):TraversableOnce[String] = Array(e.afrom,e.ato)
But then using this function with flatMap to get an RDD[String] of all accounts fails.
val accts:RDD[String] = data.flatMap(arrow)
Name: Compile Error
Message: :38: error: type mismatch;
found : etherTrans(in class $iwC)(in class $iwC)(in class $iwC)(in class $iwC) => TraversableOnce[String]
required: etherTrans(in class $iwC)(in class $iwC)(in class $iwC)(in class $iwC) => TraversableOnce[String]
val accts:RDD[String] = data.flatMap(arrow)
^
StackTrace:
Make sure you scroll right to see it complain that TraversableOnce[String] doesn't match TraversableOnce[String].
This must be a fairly common problem, as a more blatant type mismatch comes up in "Generate List of Pairs" and, while there isn't enough context there, something similar is suggested in "I have a Scala List, how can I get a TraversableOnce?".
What's going on here?
EDIT: The issue reported above doesn't appear, and the code works fine, in an older spark-shell: Spark 1.3.1 running standalone in a Docker container. The errors are generated when running in the Spark 1.6 Scala Jupyter environment with the jupyter/all-spark-notebook Docker container.
Also #zero323 says that this toy example:
val rdd = sc.parallelize(Seq((1L, "foo", "bar", 1))).map{ case (ts, fr, to, et) => new etherTrans(ts, fr, to, et)}
rdd.flatMap(arrow).collect
worked for him in the terminal spark-shell with Spark 1.6.0 / Scala 2.10.5, and that it also works with Scala 2.11.7 and Spark 1.5.2.
I think you should switch to case classes, and it should work fine. Using "regular" classes might cause weird issues when serializing them, and it looks like all you need are value objects, so case classes are a better fit for your use case.
An example:
case class EtherTrans(ts: Long, afrom: String, ato: String, ether: Float)

val source = sc.parallelize(Array(
  (1L, "from1", "to1", 1.234F),
  (2L, "from2", "to2", 3.456F)
))

val data = source.map { l => EtherTrans(l._1, l._2, l._3, l._4) }

def arrow(e: EtherTrans) = Array(e.afrom, e.ato)

data.map(arrow).take(5)
// res3: Array[Array[String]] = Array(Array(from1, to1), Array(from2, to2))
If you need to, you can just create some method / object to generate your case classes.
If you don't really need the "toString" method for your logic, but just for "presentation", keep it out of the case class: you can always add it with a map operation before storing or showing the data.
Also, if you are on Spark 1.6.0 or higher, you could try using the Dataset API instead, which would look more or less like this:
val data = sqlContext.read.text("your_file").as[EtherTrans]
https://databricks.com/blog/2016/01/04/introducing-spark-datasets.html
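To tie this back to the original flatMap question, here is a small sketch of the case-class version of arrow being flatMapped (the sample values are taken from the answer's example; an existing SparkContext sc is assumed):
import org.apache.spark.rdd.RDD

case class EtherTrans(ts: Long, afrom: String, ato: String, ether: Float)

// Sample data for illustration; in the question this comes from the transaction file.
val data: RDD[EtherTrans] = sc.parallelize(Seq(
  EtherTrans(1L, "from1", "to1", 1.234F),
  EtherTrans(2L, "from2", "to2", 3.456F)
))

def arrow(e: EtherTrans): TraversableOnce[String] = Array(e.afrom, e.ato)

// Using the case class, the flatMap from the question compiles as expected.
val accts: RDD[String] = data.flatMap(arrow)
accts.collect().foreach(println)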

Functions in a Scala constructor don't get called

I am using Scala with the Play framework to create a web application. I have a class that connects to a Cassandra DB. I am using the constructor to connect to the database, but it doesn't work; in fact, I can't call any function from the constructor. I'm new to Scala, but from what I've read in the Scala tutorials, it should work. Here is the code:
class Database {
  var cluster = Cluster.builder()
    .addContactPoint(Play.application.configuration.getString("cassandra.node"))
    .build()
  var session = cluster.connect("acm")
}
I removed the rest of the class body for clarity.
These functions don't get called when I create an instance of the class, and the variables are left unassigned when I use them in another function. They work fine when called from a regular function. I also tested it with the logger, but nothing is written. So what is going on here?
I tried this in the REPL:
scala> class A {
  var x = 1
  println(s"x = $x")
}
scala> val a = new A
x = 1
and I got the expected result: x = 1 is printed when the instance is created. From what is given in the context, I think it should work; if it doesn't, the problem most likely lies somewhere else.
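As a quick sanity check along the same lines (a sketch only; the Cluster and Play calls are copied from the question, the println is added just for illustration), a print statement at the top of the constructor body shows whether it runs at all:
class Database {
  println("Database constructor running") // should print as soon as `new Database` is evaluated

  var cluster = Cluster.builder()
    .addContactPoint(Play.application.configuration.getString("cassandra.node"))
    .build()
  var session = cluster.connect("acm")
}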

Is it possible to use scalap from a Scala script?

I am using scalap to read out the field names of some case classes (as discussed in this question). Both the case classes and the code that uses scalap to analyze them have been compiled and put into a jar file on the classpath.
Now I want to run a script that uses this code, so I followed the instructions and came up with something like
::#!
#echo off
call scala -classpath *;./libs/* %0 %*
goto :eof
::!#
//Code relying on pre-compiled code that uses scalap
which does not work:
java.lang.ClassCastException: scala.None$ cannot be cast to scala.Option
  at scala.tools.nsc.interpreter.ByteCode$.caseParamNamesForPath(ByteCode.scala:45)
  at scala.tools.nsc.interpreter.ProductCompletion.caseNames(ProductCompletion.scala:22)
However, the code works just fine when I compile everything. I played around with additional scala options like -savecompiled, but this did not help. Is this a bug, or can this not work in principle? (If so, could someone explain why not? As I said, the case classes to be analyzed by scalap are compiled.)
Note: I use Scala 2.9.1-1.
EDIT
Here is what I am essentially trying to do (providing a simple way to create multiple instances of a case class):
//This is pre-compiled:
import scala.tools.nsc.interpreter.ProductCompletion
//...
trait MyFactoryTrait[T <: MyFactoryTrait[T] with Product] {
  this: T =>

  private[this] val copyMethod = this.getClass.getMethods.find(x => x.getName == "copy").get

  lazy val productCompletion = new ProductCompletion(this)

  /** The names of all specified fields. */
  lazy val fieldNames = productCompletion.caseNames //<- provokes the exception (see above)

  def createSeq(...): Seq[T] = {
    val x = fieldNames map { ... } // <- this method uses the fieldNames value
    //[...] invoke copyMethod to create instances
  }

  // ...
}
//This is pre-compiled too:
case class MyCaseClass(x: Int = 0, y: Int = 0) extends MyFactoryTrait[MyCaseClass]
//This should be interpreted (but crashes):
val seq = MyCaseClass().createSeq(...)
Note: I moved on to Scala 2.9.2, the error stays the same (so probably not a bug).
This is a bug in the compiler:
If you run the program inside an IDE, for example IntelliJ IDEA, the code executes fine; however, no field names are found.
If you run it from the command line using scala, you get the error you mentioned.
Type-safe code should never compile and then throw a runtime ClassCastException.
Please open a bug at https://issues.scala-lang.org/secure/Dashboard.jspa