How to create a custom Transformer from a UDF? - scala

I am trying to create and save a Pipeline with custom stages. I need to add a column to my DataFrame using a UDF, so I was wondering whether it is possible to convert a UDF, or a similar operation, into a Transformer.
My custom UDF looks like this, and I'd like to learn how to turn it into a custom Transformer.
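def getFeatures(n: String) = {
  val name = n.split(" +")(0).toLowerCase
  // The line that starts this expression is missing from the post. The range below is an
  // assumption: the sample output further down shows suffixes of length 1 to 4.
  (1 to 4)
    .filter(size => size <= name.length)
    .map(size => name.substring(name.length - size))
}
val tokenizeUDF = sqlContext.udf.register("tokenize", (name: String) => getFeatures(name))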

It is not a fully featured solution, but you can start with something like this:
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
class NGramTokenizer(override val uid: String)
  extends UnaryTransformer[String, Seq[String], NGramTokenizer] {
  def this() = this(Identifiable.randomUID("ngramtokenizer"))
  // The wrapped function is applied to every value of the input column.
  override protected def createTransformFunc: String => Seq[String] = {
    getFeatures _
  }
  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType)
  }
  override protected def outputDataType: DataType = {
    new ArrayType(StringType, true)
  }
}
Quick check:
val df = Seq((1L, "abcdef"), (2L, "foobar")).toDF("k", "v")
val transformer = new NGramTokenizer().setInputCol("v").setOutputCol("vs")
transformer.transform(df).show
// +---+------+------------------+
// |  k|     v|                vs|
// +---+------+------------------+
// |  1|abcdef|[f, ef, def, cdef]|
// |  2|foobar|[r, ar, bar, obar]|
// +---+------+------------------+
You can even try to generalize it to something like this:
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import scala.reflect.runtime.universe._
class UnaryUDFTransformer[T : TypeTag, U : TypeTag](
  override val uid: String,
  f: T => U
) extends UnaryTransformer[T, U, UnaryUDFTransformer[T, U]] {
  override protected def createTransformFunc: T => U = f
  override protected def validateInputType(inputType: DataType): Unit =
    require(inputType == schemaFor[T].dataType)
  override protected def outputDataType: DataType = schemaFor[U].dataType
}
val transformer = new UnaryUDFTransformer("featurize", getFeatures)
If you want to use the UDF itself rather than the wrapped function, you'll have to extend Transformer directly and override the transform method. Unfortunately, the majority of the useful classes are private, so it can be rather tricky.
Alternatively, you can register the UDF:
spark.udf.register("getFeatures", getFeatures _)
and use SQLTransformer:
import org.apache.spark.ml.feature.SQLTransformer
val transformer = new SQLTransformer()
  .setStatement("SELECT *, getFeatures(v) AS vs FROM __THIS__")
transformer.transform(df).show
// +---+------+------------------+
// |  k|     v|                vs|
// +---+------+------------------+
// |  1|abcdef|[f, ef, def, cdef]|
// |  2|foobar|[r, ar, bar, obar]|
// +---+------+------------------+

I initially tried to extend the Transformer and UnaryTransformer abstract classes, but ran into trouble because my application could not access DefaultParamsWritable. As an example that may be relevant to your problem, I created a simple term normalizer as a UDF, following along from this example. My goal is to match terms against patterns and sets and replace them with generic terms. For example:
"""\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b""".r -> "emailaddr"
This is the class
import scala.util.matching.Regex
class TermNormalizer(normMap: Map[Any, String]) {
  val normalizationMap = normMap
  def normalizeTerms(terms: Seq[String]): Seq[String] = {
    var termsUpdated = terms
    for ((term, idx) <- termsUpdated.view.zipWithIndex) {
      for (normalizer <- normalizationMap.keys: Iterable[Any]) {
        normalizer match {
          case (regex: Regex) =>
            if (!regex.findFirstIn(term).isEmpty) termsUpdated =
              termsUpdated.updated(idx, normalizationMap(regex))
          case (set: Set[String]) =>
            if (set.contains(term)) termsUpdated =
              termsUpdated.updated(idx, normalizationMap(set))
        }
      }
    }
    // Return the updated terms, as required by the declared Seq[String] result type.
    termsUpdated
  }
}
I use it like this:
val testMap: Map[Any, String] = Map("hadoop".r -> "elephant",
"spark".r -> "sparky", "cool".r -> "neat",
Set("123", "456") -> "set1",
Set("789", "10") -> "set2")
val testTermNormalizer = new TermNormalizer(testMap)
val termNormalizerUdf = udf(testTermNormalizer.normalizeTerms(_: Seq[String]))
val trainingTest = sqlContext.createDataFrame(Seq(
(0L, "spark is cool 123", 1.0),
(1L, "adsjkfadfk akjdsfhad 456", 0.0),
(2L, "spark rocks my socks 789 10", 1.0),
(3L, "hadoop is cool 10", 0.0)
)).toDF("id", "text", "label")
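val testTokenizer = new Tokenizer()
  .setInputCol("text")   // assumed: the select below reads a "words" column derived from "text"
  .setOutputCol("words")
val tokenizedTrainingTest = testTokenizer.transform(trainingTest)
tokenizedTrainingTest
  .select($"id", $"text", $"words", termNormalizerUdf($"words"), $"label")
  .show(false)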
Now that I read the question a little closer, it sounds like you're asking how to avoid doing it this way lol. Anyways, I'll still post it in case someone in the future is looking for an easy way to get transformer-like functionality.
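For readers who want the same normalization logic without the `Map[Any, String]` keys and the unchecked `Set[String]` pattern match, here is a minimal sketch in plain Scala (no Spark needed). The names `Rule`, `RegexRule`, `SetRule` and `normalize` are made up for this illustration and are not part of the answer above or of Spark. One deliberate difference: here the first matching rule wins, whereas the loop above lets a later matching rule overwrite an earlier one.

```scala
import scala.util.matching.Regex

// Hypothetical rule types; they replace the Map[Any, String] keys.
sealed trait Rule {
  def replacement: String
  def matches(term: String): Boolean
}

final case class RegexRule(regex: Regex, replacement: String) extends Rule {
  // Same test as in the answer: the regex occurs somewhere in the term.
  def matches(term: String): Boolean = regex.findFirstIn(term).isDefined
}

final case class SetRule(terms: Set[String], replacement: String) extends Rule {
  def matches(term: String): Boolean = terms.contains(term)
}

// Replace each term by the replacement of the first rule that matches it;
// terms that match no rule are kept unchanged.
def normalize(rules: Seq[Rule])(terms: Seq[String]): Seq[String] =
  terms.map(t => rules.find(_.matches(t)).map(_.replacement).getOrElse(t))

val rules = Seq(
  RegexRule("hadoop".r, "elephant"),
  RegexRule("spark".r, "sparky"),
  SetRule(Set("123", "456"), "set1")
)

val normalized = normalize(rules)(Seq("spark", "is", "cool", "123"))
```

Since `normalize(rules)` is an ordinary `Seq[String] => Seq[String]` function, it should be possible to wrap it in `udf(...)` the same way as `normalizeTerms`, although that part is not exercised here.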

If you wish to make the transformer writable as well, you can re-implement traits such as HasInputCol from the sharedParams library in a public package of your choice and then use them with the DefaultParamsWritable trait to make the transformer persistable.
This way you also avoid having to place part of your code inside the Spark core ml packages, at the cost of maintaining a parallel set of params in your own package. This isn't really a problem, given that they hardly ever change.
But do track the bug in their JIRA board here that asks for some of the common sharedParams to be made public instead of private to ml, so that people can use them directly from outside classes.


Efficient way to collect HashSet during map operation on some Dataset

I have a big dataset to transform from one structure to another. During that phase I also want to collect some information about a computed field (quadkeys for given lat/longs). I don't want to attach this information to every result row, since that would cause a lot of duplicated information and memory overhead. All I need to know is which particular quadkeys are touched by the given coordinates. Is there any way to do this within one job, so that I don't iterate over the dataset twice?
def load(paths: Seq[String]): (Dataset[ResultStruct], Dataset[String]) = {
val df ="com.databricks.spark.csv").option("header", "true")
.option("delimiter", "\t")
val qkSet = mutable.HashSet.empty[String]
val result = => {
val id =
val points = toPoints(c.geom)
points.foreach(p => qkSet.add(Quadkey.get(, p.lon, 6).getId))
createResultStruct(id, points)
return result, //some dataset created from qkSet's from all executors
You could use accumulators
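import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.sql.Row
import org.apache.spark.util.AccumulatorV2
class SetAccumulator[T] extends AccumulatorV2[T, Set[T]] {
  import scala.collection.JavaConverters._
  private val items = new ConcurrentHashMap[T, Boolean]
  override def isZero: Boolean = items.isEmpty
  override def copy(): AccumulatorV2[T, Set[T]] = {
    val other = new SetAccumulator[T]
    other.items.putAll(items) // assumed: the copy starts with the same items
    other
  }
  override def reset(): Unit = items.clear()
  override def add(v: T): Unit = items.put(v, true)
  override def merge(
      other: AccumulatorV2[T, Set[T]]): Unit = other match {
    case setAccumulator: SetAccumulator[T] => items.putAll(setAccumulator.items)
  }
  override def value: Set[T] = items.keys().asScala.toSet
}
val df = Seq("foo", "bar", "foo", "foo").toDF("test")
val acc = new SetAccumulator[String]
spark.sparkContext.register(acc)
// The body of this map and the action after it are garbled in the post; what follows is an
// assumed reconstruction: add each value to the accumulator, then force evaluation with count.
df.map {
  case Row(str: String) =>
    acc.add(str)
    str
}.count()
println(acc.value)
// Set(bar, foo)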
Note that map itself is lazy, so an action such as count is needed to actually force the computation. Depending on the real use case, another option would be to cache the data frame and use plain SQL functions instead, for example distinct() on the "test" column.

Evaluate complex type with quasiquote scala, unlifting

I need to compile a function and then evaluate it with different parameters of type List[Map[String, AnyRef]].
The following code does not compile with that type, but it does compile with a simple type like List[Int].
I found that there are only certain implementations of Liftable in scala.reflect.api.StandardLiftables.StandardLiftableInstances.
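import scala.reflect.runtime.universe
import scala.reflect.runtime.universe._
import scala.tools.reflect.ToolBox
val tb = universe.runtimeMirror(getClass.getClassLoader).mkToolBox()
val functionWrapper =
  """
    |object FunctionWrapper {
    |  def makeBody(messages: List[Map[String, AnyRef]]) = Map.empty
    |}
  """.stripMargin
val functionSymbol = tb.define(tb.parse(functionWrapper).asInstanceOf[tb.u.ImplDef])
val list: List[Map[String, AnyRef]] = List(Map("1" -> "2"))
tb.eval(q"$functionSymbol.makeBody($list)") // unquoting $list is what fails to compile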
I get a compilation error for this. How can I make it work?
Error:(22, 38) Can't unquote List[Map[String,AnyRef]], consider using
... or providing an implicit instance of
The problem comes not from the complicated type but from the attempt to use AnyRef. When you unquote a value, you ask the infrastructure to build a valid syntax tree that recreates an object exactly matching the one you pass. Unfortunately, this is not possible for all objects. For example, assume that you passed a reference to Thread.currentThread() as part of the Map. How could that possibly work? The compiler is simply not able to recreate such a complicated object (not to mention making it the current thread). So you have two obvious alternatives:
Make your argument a Tree as well, i.e. something like this:
def testTree() = {
  val tb = universe.runtimeMirror(getClass.getClassLoader).mkToolBox()
  val functionWrapper =
    """
      | object FunctionWrapper {
      |   def makeBody(messages: List[Map[String, AnyRef]]) = Map.empty
      | }
    """.stripMargin
  val functionSymbol = tb.define(tb.parse(functionWrapper).asInstanceOf[tb.u.ImplDef])
  //val list: List[Map[String, AnyRef]] = List(Map("1" -> "2"))
  val list = q"""List(Map("1" -> "2"))"""
  val res = tb.eval(q"$functionSymbol.makeBody($list)")
  println(s"testTree = $res")
}
The obvious drawback of this approach is that you lose type safety at compile time and might need to provide a lot of context for the tree to work.
Another approach is to not pass anything containing AnyRef to the compiler infrastructure at all. This means you create a function-like Wrapper:
package so {
  trait Wrapper {
    def call(args: List[Map[String, AnyRef]]): Map[String, AnyRef]
  }
}
and then make your generated code return a Wrapper instead of directly executing the logic and call the Wrapper from the usual Scala code rather than inside compiled code. Something like this:
def testWrapper() = {
  val tb = universe.runtimeMirror(getClass.getClassLoader).mkToolBox()
  val functionWrapper =
    """
      |object FunctionWrapper {
      |  import scala.collection._
      |  import so.Wrapper /* <- here probably different package :) */
      |  def createWrapper(): Wrapper = new Wrapper {
      |    override def call(args: List[Map[String, AnyRef]]): Map[String, AnyRef] = Map.empty
      |  }
      |}
      | """.stripMargin
  val functionSymbol = tb.define(tb.parse(functionWrapper).asInstanceOf[tb.u.ImplDef])
  val list: List[Map[String, AnyRef]] = List(Map("1" -> "2"))
  val tree: tb.u.Tree = q"$functionSymbol.createWrapper()"
  val wrapper = tb.eval(tree).asInstanceOf[Wrapper]
  // The right-hand side is missing in the post; calling the wrapper with the list is assumed.
  val res = wrapper.call(list)
  println(s"testWrapper = $res")
}
P.S. I'm not sure what you are doing, but beware of performance issues. Scala is a hard language to compile, so it might easily take more time to compile your custom code than to run it. If performance becomes an issue, you might need other methods, such as full-blown macro code generation or at least caching of the compiled code.
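The caching idea from the P.S. can be sketched independently of the toolbox. This is a minimal sketch, assuming the compile step can be expressed as a function from source text to a compiled value. `CompileCache` and the stand-in `fakeCompile` are hypothetical names; in real code the function passed in would wrap something like the `tb.define` / `tb.eval` calls shown earlier.

```scala
import scala.collection.mutable

// Hypothetical helper: remembers the result of an expensive compile step per source string.
final class CompileCache[A](compile: String => A) {
  private val cache = mutable.Map.empty[String, A]

  // Compiles `source` only the first time it is seen; later calls reuse the stored result.
  def get(source: String): A = synchronized {
    cache.getOrElseUpdate(source, compile(source))
  }
}

// Stand-in for the real (slow) compilation; it counts how often it runs.
var compileCalls = 0
def fakeCompile(source: String): Int => Int = {
  compileCalls += 1
  (x: Int) => x + source.length
}

val cache = new CompileCache[Int => Int](fakeCompile)
val f1 = cache.get("x + 1")
val f2 = cache.get("x + 1") // served from the cache, no second compile
val result = f1(10)
```

Keying the cache on the source text is the simplest choice; if the generated sources are large, a hash of the text could serve as the key instead.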

How to use countDistinct in Scala with Spark?

I've tried to use the countDistinct function, which should be available in Spark 1.5 according to the Databricks blog. However, I got the following exception:
Exception in thread "main" org.apache.spark.sql.AnalysisException: undefined function countDistinct;
I've found that on the Spark developers' mailing list they suggest using the count and distinct functions to get the same result that countDistinct should produce:
count(distinct <columnName>)
// Instead of
countDistinct(<columnName>)
Because I build aggregation expressions dynamically from a list of aggregation function names, I'd prefer not to have any special cases that require different treatment.
So, is it possible to unify it by:
registering a new UDAF which will be an alias for count(distinct columnName)
manually registering the CountDistinct function already implemented in Spark, which is probably one of those in the following import:
import org.apache.spark.sql.catalyst.expressions.{CountDistinctFunction, CountDistinct}
or do it in any other way?
Example (with some local references and unnecessary code removed):
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Column, SQLContext, DataFrame}
import org.apache.spark.sql.functions._
import scala.collection.mutable.ListBuffer
class Flattener(sc: SparkContext) {
val sqlContext = new SQLContext(sc)
def flatTable(data: DataFrame, groupField: String): DataFrame = {
val flatteningExpressions =
flatMap(x => getFlatteningExpressions(x._1, x._2)).toList
data.groupBy(groupField).agg (
expr(s"count($groupField) as groupSize"),
private def getFlatteningExpressions(fieldName: String, fieldType: DType): List[Column] = {
val aggFuncs = getAggregationFunctons(fieldType) => expr(s"$f($fieldName) as ${fieldName}_$f"))
private def getAggregationFunctons(fieldType: DType): List[String] = {
val aggFuncs = new ListBuffer[String]()
if(fieldType == DType.NUMERIC) {
aggFuncs += ("avg", "min", "max")
if(fieldType == DType.CATEGORY) {
aggFuncs += "countDistinct"
countDistinct can be used in two different forms:
df.groupBy("A").agg(expr("count(distinct B)"))
or
df.groupBy("A").agg(countDistinct("B"))
However, neither of these methods works when you want to use them on the same column together with your custom UDAF (implemented as UserDefinedAggregateFunction in Spark 1.5):
// Assume that we have already implemented and registered StdDev UDAF
df.groupBy("A").agg(countDistinct("B"), expr("StdDev(B)"))
// Will cause
Exception in thread "main" org.apache.spark.sql.AnalysisException: StdDev is implemented based on the new Aggregate Function interface and it cannot be used with functions implemented based on the old Aggregate Function interface.;
Due to these limitations, the most reasonable option seems to be implementing countDistinct as a UDAF, which should allow all functions to be treated in the same way and countDistinct to be used alongside other UDAFs.
An example implementation can look like this:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
class CountDistinct extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)
  // Add one input value to the set of distinct items kept in the buffer.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = (buffer.getSeq[String](0).toSet + input.getString(0)).toSeq
  }
  override def bufferSchema: StructType = StructType(
    StructField("items", ArrayType(StringType, true)) :: Nil
  )
  // Combine the distinct items of two partial buffers.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = (buffer1.getSeq[String](0).toSet ++ buffer2.getSeq[String](0).toSet).toSeq
  }
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq[String]()
  }
  override def deterministic: Boolean = true
  override def evaluate(buffer: Row): Any = {
    // The body is missing in the post; returning the number of buffered items is assumed.
    buffer.getSeq[String](0).length
  }
  override def dataType: DataType = IntegerType
}
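The buffer handling above can be checked without a cluster. The following plain-Scala sketch models the three UDAF steps (update, merge, evaluate) on an ordinary `Set[String]`. The name `DistinctBuffer` and its methods are illustrative only and are not Spark API.

```scala
// Illustrative model of the UDAF buffer; not Spark API.
final case class DistinctBuffer(items: Set[String] = Set.empty) {
  // Corresponds to update(): add one input value to the partial result.
  def update(value: String): DistinctBuffer = copy(items = items + value)
  // Corresponds to merge(): combine two partial results from different partitions.
  def merge(other: DistinctBuffer): DistinctBuffer = copy(items = items ++ other.items)
  // Corresponds to evaluate(): the number of distinct values seen.
  def evaluate: Int = items.size
}

// Two "partitions" of the same group, aggregated separately and then merged.
val partition1 = Seq("a", "b", "a").foldLeft(DistinctBuffer())(_ update _)
val partition2 = Seq("b", "c").foldLeft(DistinctBuffer())(_ update _)
val distinctCount = partition1.merge(partition2).evaluate
```

Because set union is associative and commutative, the order in which partial buffers are merged does not change the result. The cost is that every distinct value of a group is held in the buffer.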
Not sure if I really understood your problem, but this is an example for the countDistinct aggregated function:
val values = Array((1, 2), (1, 3), (2, 2), (1, 2))
val myDf = sc.parallelize(values).toDF("id", "foo")
import org.apache.spark.sql.functions.countDistinct
myDf.groupBy('id).agg(countDistinct('foo) as 'distinctFoo) show
+---+-----------+
| id|distinctFoo|
+---+-----------+
|  1|          2|
|  2|          1|
+---+-----------+

Reduce two Scala methods, that only differ in one Object Type

I have the following two methods, using objects from Apache Spark.
def SVMModelScoring(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] = {
  val model = SVMModel.load(sc, modelFileName)
  val scoreAndLabels =
    MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
      val score = model.predict(point.features)
      (score, point.label)
    }
  return scoreAndLabels
}
def DecisionTreeScoring(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] = {
  val model = DecisionTreeModel.load(sc, modelFileName)
  val scoreAndLabels =
    MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
      val score = model.predict(point.features)
      (score, point.label)
    }
  return scoreAndLabels
}
My previous attempts to merge these functions have resulted in errors around model.predict.
Is there a way I can pass model as a weakly typed parameter in Scala?
Disclaimer - I've never used Apache Spark.
It looks to me like the only difference between the two methods is the way the model is instantiated. It's unfortunate that the two model instances don't actually share a common trait that provides predict(...) but we can still make this work by pulling out the part that changes - the scorer:
def scoreWith(sc: SparkContext, scoringDataset: String)(scorer: (Vector)=>Double): RDD[(Double, Double)] = {
  MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
    val score = scorer(point.features)
    (score, point.label)
  }
}
Now we can get the previous functionality with:
def svmScorer(sc: SparkContext, scoringDataset: String, modelFileName: String) =
  scoreWith(sc, scoringDataset)(SVMModel.load(sc, modelFileName).predict)
def dtScorer(sc: SparkContext, scoringDataset: String, modelFileName: String) =
  scoreWith(sc, scoringDataset)(DecisionTreeModel.load(sc, modelFileName).predict)
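The same pattern can be tried out without Spark. In this minimal sketch, `ModelA` and `ModelB` are made-up stand-ins for two model classes that both have a `predict` method but share no common trait, and `Features` is a stand-in for Spark's `Vector`. None of these names come from MLlib.

```scala
type Features = Vector[Double] // stand-in for org.apache.spark.mllib.linalg.Vector

// Two unrelated classes that happen to expose the same method shape.
class ModelA { def predict(features: Features): Double = features.sum }
class ModelB { def predict(features: Features): Double = features.max }

// The shared logic takes the varying part (the scorer) as a function argument.
def scoreWith(data: Seq[(Double, Features)])(scorer: Features => Double): Seq[(Double, Double)] =
  data.map { case (label, features) => (scorer(features), label) }

val data: Seq[(Double, Features)] =
  Seq((1.0, Vector(1.0, 2.0)), (0.0, Vector(3.0, 4.0)))

// Method references are turned into functions, so no common supertype is needed.
val scoresA = scoreWith(data)(new ModelA().predict)
val scoresB = scoreWith(data)(new ModelB().predict)
```

The design choice is to depend only on the function shape `Features => Double` rather than on a model type, which is why the two unrelated classes can be used interchangeably.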

How do I provide basic configuration for a Scala application?

I am working on a small GUI application written in Scala. There are a few settings that the user will set in the GUI and I want them to persist between program executions. Basically I want a scala.collection.mutable.Map that automatically persists to a file when modified.
This seems like it must be a common problem, but I have been unable to find a lightweight solution. How is this problem typically solved?
I do a lot of this, and I use .properties files (it's idiomatic in Java-land). I keep my config pretty straight-forward by design, though. If you have nested config constructs you might want a different format like YAML (if humans are the main authors) or JSON or XML (if machines are the authors).
Here's some example code for loading props, manipulating as Scala Map, then saving as .properties again:
import java.io._
import java.util._
import scala.collection.JavaConverters._
val f = new File("")
// foo=bar
// baz=123
val props = new Properties
// Note: in real code make sure all these streams are
// closed carefully in try/finally
val fis = new InputStreamReader(new FileInputStream(f), "UTF-8")
props.load(fis) // assumed: this call is missing from the post but needed for the output below
fis.close()
println(props) // {baz=123, foo=bar}
val map = props.asScala // Get to Scala Map via JavaConverters
map("foo") = "42"
map("quux") = "newvalue"
println(map) // Map(baz -> 123, quux -> newvalue, foo -> 42)
println(props) // {baz=123, quux=newvalue, foo=42}
val fos = new OutputStreamWriter(new FileOutputStream(f), "UTF-8")
props.store(fos, "") // assumed reconstruction of the garbled last line
fos.close()
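The note in the code above about closing streams in try/finally can be made concrete. This is a minimal sketch using only the JDK; `loadProps` and `saveProps` are hypothetical helper names, and the temporary file stands in for whatever settings file the application uses.

```scala
import java.io.{File, FileInputStream, FileOutputStream, InputStreamReader, OutputStreamWriter}
import java.nio.charset.StandardCharsets
import java.util.Properties

// Reads a .properties file as UTF-8 and always closes the reader.
def loadProps(file: File): Properties = {
  val props = new Properties
  val in = new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8)
  try props.load(in) finally in.close()
  props
}

// Writes the properties back as UTF-8 and always closes the writer.
def saveProps(props: Properties, file: File): Unit = {
  val out = new OutputStreamWriter(new FileOutputStream(file), StandardCharsets.UTF_8)
  try props.store(out, "") finally out.close()
}

// Round trip through a temporary file.
val tmp = File.createTempFile("settings", ".properties")
tmp.deleteOnExit()

val settings = new Properties
settings.setProperty("foo", "bar")
saveProps(settings, tmp)

val reloaded = loadProps(tmp)
val foo = reloaded.getProperty("foo")
```

Calling `saveProps` after every change to the settings would give roughly the "persist on modification" behaviour asked for in the question, at the cost of rewriting the whole file each time.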
Here's an example of using XML and a case class for reading a config. A real class can be nicer than a map. (You could also do what sbt and at least one project do, take the config as Scala source and compile it in; saving it is less automatic. Or as a repl script. I haven't googled, but someone must have done that.)
Here's the simpler code.
This version uses a case class:
case class PluginDescription(name: String, classname: String) {
def toXML: Node = {
object PluginDescription {
  def fromXML(xml: Node): PluginDescription = {
    // extract one field
    def getField(field: String): Option[String] = {
      val text = (xml \\ field).text.trim
      if (text == "") None else Some(text)
    }
    def extracted = {
      val name = "name"
      val claas = "classname"
      val vs = Map(name -> getField(name), claas -> getField(claas))
      if (vs.values exists (_.isEmpty)) fail()
      else PluginDescription(name = vs(name).get, classname = vs(claas).get)
    }
    def fail() = throw new RuntimeException("Bad plugin descriptor.")
    // check the top-level tag
    xml match {
      case <plugin>{_*}</plugin> => extracted
      case _ => fail()
    }
  }
}
This code reflectively calls the apply of a case class. The use case is that fields missing from config can be supplied by default args. No type conversions here. E.g., case class Config(foo: String = "bar").
// isn't it easier to write a quick loop to reflect the field names?
import scala.reflect.runtime.{currentMirror => cm, universe => ru}
import ru._
def fromXML(xml: Node): Option[PluginDescription] = {
def extract[A]()(implicit tt: TypeTag[A]): Option[A] = {
// extract one field
def getField(field: String): Option[String] = {
val text = (xml \\ field).text.trim
if (text == "") None else Some(text)
val apply = ru.newTermName("apply")
val module = ru.typeOf[A].typeSymbol.companionSymbol.asModule
val ts = module.moduleClass.typeSignature
val m = (ts member apply).asMethod
val im = cm reflect (cm reflectModule module).instance
val mm = im reflectMethod m
def getDefault(i: Int): Option[Any] = {
val n = ru.newTermName("apply$default$" + (i+1))
val m = ts member n
if (m == NoSymbol) None
else Some((im reflectMethod m.asMethod)())
def extractArgs(pss: List[List[Symbol]]): List[Option[Any]] =
pss.flatten.zipWithIndex map (p => getField( orElse getDefault(p._2))
val args = extractArgs(m.paramss)
if (args exists (!_.isDefined)) None
else Some(mm(args.flatten: _*).asInstanceOf[A])
// check the top-level tag
xml match {
case <plugin>{_*}</plugin> => extract[PluginDescription]()
case _ => None
XML has loadFile and save, it's too bad there seems to be no one-liner for Properties.
$ scala
Welcome to Scala version 2.10.0-RC5 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_06).
Type in expressions to have them evaluated.
Type :help for more information.
scala> import
scala> import java.util._
import java.util._
scala> import java.io.{StringReader, File=>JFile}
import java.io.{StringReader, File=>JFile}
scala> import scala.collection.JavaConverters._
import scala.collection.JavaConverters._
scala> val p = new Properties
p: java.util.Properties = {}
scala> p load new StringReader(
| (new File(new JFile(""))).slurp)
scala> p.asScala
res2: scala.collection.mutable.Map[String,String] = Map(foo -> bar)
As it all boils down to serializing a map / object to a file, your choices are:
classic serialization to Bytecode
serialization to XML
serialization to JSON (easy using Jackson, or Lift-JSON)
use of a properties file (ugly, no utf-8 support)
serialization to a proprietary format (ugly, why reinvent the wheel)
I suggest to convert Map to Properties and vice versa. "*.properties" files are standard for storing configuration in Java world, why not use it for Scala?
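The conversion in both directions can be sketched as follows. `toProperties` and `fromProperties` are hypothetical helper names. The sketch uses `JavaConverters` to match the other answers here; on Scala 2.13 and later, `scala.jdk.CollectionConverters` is the non-deprecated replacement.

```scala
import java.util.Properties
import scala.collection.JavaConverters._

// Copies a Scala Map into a fresh java.util.Properties.
def toProperties(settings: Map[String, String]): Properties = {
  val props = new Properties
  settings.foreach { case (k, v) => props.setProperty(k, v) }
  props
}

// Takes an immutable snapshot of the string entries of a Properties object.
def fromProperties(props: Properties): Map[String, String] =
  props.stringPropertyNames().asScala.map(k => k -> props.getProperty(k)).toMap

val original = Map("foo" -> "bar", "baz" -> "123")
val roundTripped = fromProperties(toProperties(original))
```

The snapshot is immutable, so changes have to be converted back with `toProperties` before saving.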
The common formats are *.properties and *.xml. Since Scala supports XML natively, using an XML config is easier than in Java.