Flink: How to convert the deprecated fold to aggregate? - Scala

I am following the Flink quick start example: Monitoring the Wikipedia Edit Stream.
The example is in Java, and I am implementing it in Scala, as follows:
/**
 * Wikipedia Edit Monitoring
 */
object WikipediaEditMonitoring {
  def main(args: Array[String]) {
    // set up the execution environment
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
    val edits: DataStream[WikipediaEditEvent] = env.addSource(new WikipediaEditsSource)
    val result = edits.keyBy( _.getUser )
      .timeWindow(Time.seconds(5))
      .fold(("", 0L)) {
        (acc: (String, Long), event: WikipediaEditEvent) => {
          (event.getUser, acc._2 + event.getByteDiff)
        }
      }
    result.print
    // execute program
    env.execute("Wikipedia Edit Monitoring")
  }
}
However, the fold function in Flink is deprecated, and the aggregate function is recommended instead.
I could not find an example or tutorial showing how to convert the deprecated fold to aggregate.
Any idea how to do this? It is probably not just a matter of swapping in aggregate.
UPDATE
I have another implementation, as follows:
/**
 * Wikipedia Edit Monitoring
 */
object WikipediaEditMonitoring {
  def main(args: Array[String]) {
    // set up the execution environment
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
    val edits: DataStream[WikipediaEditEvent] = env.addSource(new WikipediaEditsSource)
    val result = edits
      .map( e => UserWithEdits(e.getUser, e.getByteDiff) )
      .keyBy( "user" )
      .timeWindow(Time.seconds(5))
      .sum("edits")
    result.print
    // execute program
    env.execute("Wikipedia Edit Monitoring")
  }

  /** Data type for words with count */
  case class UserWithEdits(user: String, edits: Long)
}
I would also like to know how to implement this with a self-defined AggregateFunction.
UPDATE
I followed this documentation: AggregateFunction, but have the following question:
In the source code of the AggregateFunction interface for release 1.3, you will see that add indeed returns void:
void add(IN value, ACC accumulator);
But in the 1.4 version of AggregateFunction, it returns the accumulator:
ACC add(IN value, ACC accumulator);
How should I handle this?
The Flink version I am using is 1.3.2, the documentation for this version does not cover AggregateFunction, and release 1.4 is not available in the artifact repository yet.

You will find some documentation for AggregateFunction in the Flink 1.4 docs, including an example.
The version included in 1.3.2 is limited to mutable accumulator types, where the add operation modifies the accumulator in place. This has been fixed for Flink 1.4, but 1.4 has not been released yet.
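Until 1.4 is available, here is a minimal sketch of the same aggregation written against the 1.3-style interface, where add returns void and mutates the accumulator in place (the mutable accumulator class is made up for illustration):

import org.apache.flink.api.common.functions.AggregateFunction
import org.apache.flink.streaming.connectors.wikiedits.WikipediaEditEvent

// Mutable accumulator, because the 1.3.x interface mutates it rather than returning a new one.
class EditAccumulator(var user: String, var bytes: Long)

class SumAggregate13 extends AggregateFunction[WikipediaEditEvent, EditAccumulator, (String, Long)] {

  override def createAccumulator(): EditAccumulator = new EditAccumulator("", 0L)

  // In Flink 1.3.x `add` returns void, so the accumulator is updated in place.
  override def add(value: WikipediaEditEvent, accumulator: EditAccumulator): Unit = {
    accumulator.user = value.getUser
    accumulator.bytes += value.getByteDiff
  }

  override def getResult(accumulator: EditAccumulator): (String, Long) =
    (accumulator.user, accumulator.bytes)

  override def merge(a: EditAccumulator, b: EditAccumulator): EditAccumulator = {
    a.bytes += b.bytes
    a
  }
}

You would then replace the fold call with .aggregate(new SumAggregate13). On 1.4, the tuple-based version below works directly, since add returns the new accumulator: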

import org.apache.flink.api.common.functions.AggregateFunction
import org.apache.flink.streaming.api.scala._
import org.apache.flink.api.common.serialization.SimpleStringSchema
import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer08
import org.apache.flink.streaming.connectors.wikiedits.{WikipediaEditEvent, WikipediaEditsSource}
class SumAggregate extends AggregateFunction[WikipediaEditEvent, (String, Int), (String, Int)] {
  override def createAccumulator() = ("", 0)

  override def add(value: WikipediaEditEvent, accumulator: (String, Int)) =
    (value.getUser, value.getByteDiff + accumulator._2)

  override def getResult(accumulator: (String, Int)) = accumulator

  override def merge(a: (String, Int), b: (String, Int)) = (a._1, a._2 + b._2)
}

object WikipediaAnalysis extends App {
  val see: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
  val edits: DataStream[WikipediaEditEvent] = see.addSource(new WikipediaEditsSource())

  val result: DataStream[(String, Int)] = edits
    .keyBy(_.getUser)
    .timeWindow(Time.seconds(5))
    .aggregate(new SumAggregate)
//  .fold(("", 0))((acc, event) => (event.getUser, acc._2 + event.getByteDiff))

  result.print()
  result.map(_.toString()).addSink(new FlinkKafkaProducer08[String]("localhost:9092", "wiki-result", new SimpleStringSchema()))

  see.execute("Wikipedia User Edit Volume")
}

Related

JDBC sink for Flink fails with not serializable error

I am following https://ci.apache.org/projects/flink/flink-docs-master/dev/connectors/jdbc.html to use a MySQL database as a sink for Flink. The code compiles successfully, but executing the job in a Flink cluster fails with:
The program finished with the following exception:
The implementation of the AbstractJdbcOutputFormat is not serializable. The object probably contains or references non serializable fields.
org.apache.flink.api.java.ClosureCleaner.clean(ClosureCleaner.java:151)
org.apache.flink.api.java.ClosureCleaner.clean(ClosureCleaner.java:126)
org.apache.flink.api.java.ClosureCleaner.clean(ClosureCleaner.java:71)
org.apache.flink.streaming.api.environment.StreamExecutionEnvironment.clean(StreamExecutionEnvironment.java:1899)
org.apache.flink.streaming.api.datastream.DataStream.clean(DataStream.java:189)
org.apache.flink.streaming.api.datastream.DataStream.addSink(DataStream.java:1296)
org.apache.flink.streaming.api.scala.DataStream.addSink(DataStream.scala:1131)
Aggregator.Aggregator$.main(Aggregator.scala:81)
Here is the relevant part of the code:
object Aggregator {
  @throws[Exception]
  def main(args: Array[String]): Unit = {
    [...]
    val counts = stream.map { x => (
        x.get("value").get("id").asInt(),
        x.get("value").get("kpi").asDouble()
      )}
      .keyBy(0)
      .timeWindow(Time.seconds(60))
      .sum(1)

    counts.print()

    val statementBuilder: JdbcStatementBuilder[(Int, Double)] = (ps: PreparedStatement, t: (Int, Double)) => {
      ps.setInt(1, t._1);
      ps.setDouble(2, t._2);
    };

    val connection = new JdbcConnectionOptions.JdbcConnectionOptionsBuilder()
      .withDriverName("mysql.Driver")
      .withPassword("XXX")
      .withUrl("jdbc:mysql://<DB_HOST>:3306/<DB_NAME>")
      .withUsername("<USERNAME>")
      .build();

    val jdbcSink = JdbcSink.sink(
      "INSERT INTO table (id, kpi) VALUES (?, ?)",
      statementBuilder,
      connection);

    counts.addSink(jdbcSink)

    env.execute("Aggregator")
  }
}
I am not sure which part of the code is the problem here and how to debug. Unfortunately I also cannot find a reference implementation for a JDBC sink in Scala. Any help is appreciated!
What worked for me was explicitly creating the JdbcStatementBuilder as an anonymous class. Something like:
val statementBuilder: JdbcStatementBuilder[(Int, Double)] =
  new JdbcStatementBuilder[(Int, Double)] {
    override def accept(ps: PreparedStatement, t: (Int, Double)): Unit = {
      ps.setInt(1, t._1)
      ps.setDouble(2, t._2)
    }
  }
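For completeness, a sketch of wiring that explicit builder into the sink from the question (the connection details are placeholders, and the JdbcExecutionOptions batching values are purely illustrative):

import org.apache.flink.connector.jdbc.{JdbcConnectionOptions, JdbcExecutionOptions, JdbcSink}

val jdbcSink = JdbcSink.sink(
  "INSERT INTO table (id, kpi) VALUES (?, ?)",
  statementBuilder,                              // the explicit JdbcStatementBuilder above
  JdbcExecutionOptions.builder()
    .withBatchSize(100)                          // illustrative batching settings
    .withBatchIntervalMs(200)
    .build(),
  new JdbcConnectionOptions.JdbcConnectionOptionsBuilder()
    .withDriverName("com.mysql.cj.jdbc.Driver")  // placeholder driver/connection details
    .withUrl("jdbc:mysql://<DB_HOST>:3306/<DB_NAME>")
    .withUsername("<USERNAME>")
    .withPassword("XXX")
    .build()
)

counts.addSink(jdbcSink)

The only essential difference from the original code is that statementBuilder is a concrete anonymous class rather than a Scala lambda, which appears to be what the closure cleaner was rejecting.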

Context in processFunction for KeyedProcessFunction is null

I am trying to use KeyedProcessFunction, but the ctx: Context variable in processElement inside my KeyedProcessFunction is returning null. Note that I'm using the default TimeCharacteristic, which is ProcessingTime (so I'm not even setting it).
I found this on Stack Overflow, but that one relates to EventTime and not ProcessingTime.
Following the exact example of https://ci.apache.org/projects/flink/flink-docs-stable/dev/stream/operators/process_function.html#example, I have created the following using Scala 2.11.12 and Flink 1.10, and I'm still getting the same error.
import org.apache.flink.streaming.api.scala._
import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
import org.apache.flink.api.java.tuple.Tuple
import org.apache.flink.streaming.api.functions.KeyedProcessFunction
import org.apache.flink.util.Collector
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
object example {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setParallelism(1)

    // the source data stream
    val stream = env.socketTextStream("localhost", 9999).map(x => {
        var splitCsv = x.stripLineEnd.split(",")
        (splitCsv(0), splitCsv(1))
      }
    )

    // apply the process function onto a keyed stream
    val result: DataStream[Tuple2[String, Long]] = stream
      .keyBy(0)
      .process(new CountWithTimeoutFunction())

    result.print()

    env.execute("Flink Streaming Demo STDOUT")
  }

  /**
   * The data type stored in the state
   */
  case class CountWithTimestamp(key: String, count: Long, lastModified: Long)

  /**
   * The implementation of the ProcessFunction that maintains the count and timeouts
   */
  class CountWithTimeoutFunction extends KeyedProcessFunction[Tuple, (String, String), (String, Long)] {

    /** The state that is maintained by this process function */
    lazy val state: ValueState[CountWithTimestamp] = getRuntimeContext
      .getState(new ValueStateDescriptor[CountWithTimestamp]("myState", classOf[CountWithTimestamp]))

    override def processElement(
        value: (String, String),
        ctx: KeyedProcessFunction[Tuple, (String, String), (String, Long)]#Context,
        out: Collector[(String, Long)]): Unit = {

      // initialize or retrieve/update the state
      val current: CountWithTimestamp = state.value match {
        case null =>
          CountWithTimestamp(value._1, 1, ctx.timestamp)
        case CountWithTimestamp(key, count, lastModified) =>
          CountWithTimestamp(key, count + 1, ctx.timestamp)
      }

      // write the state back
      state.update(current)

      // schedule the next timer 60 seconds from the current event time
      ctx.timerService.registerEventTimeTimer(current.lastModified + 60000)
    }

    override def onTimer(
        timestamp: Long,
        ctx: KeyedProcessFunction[Tuple, (String, String), (String, Long)]#OnTimerContext,
        out: Collector[(String, Long)]): Unit = {

      state.value match {
        case CountWithTimestamp(key, count, lastModified) if (timestamp == lastModified + 60000) =>
          out.collect((key, count))
        case _ =>
      }
    }
  }
}
Here is the error:
Caused by: java.lang.NullPointerException
    at scala.Predef$.Long2long(Predef.scala:363)
    at com.leidos.example$CountWithTimeoutFunction.processElement(example.scala:57)
    at com.leidos.example$CountWithTimeoutFunction.processElement(example.scala:42)
    at org.apache.flink.streaming.api.operators.KeyedProcessOperator.processElement(KeyedProcessOperator.java:85)
    at org.apache.flink.streaming.runtime.tasks.OneInputStreamTask$StreamTaskNetworkOutput.emitRecord(OneInputStreamTask.java:173)
    at org.apache.flink.streaming.runtime.io.StreamTaskNetworkInput.processElement(StreamTaskNetworkInput.java:151)
    at org.apache.flink.streaming.runtime.io.StreamTaskNetworkInput.emitNext(StreamTaskNetworkInput.java:128)
    at org.apache.flink.streaming.runtime.io.StreamOneInputProcessor.processInput(StreamOneInputProcessor.java:69)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.processInput(StreamTask.java:311)
    at org.apache.flink.streaming.runtime.tasks.mailbox.MailboxProcessor.runMailboxLoop(MailboxProcessor.java:187)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.runMailboxLoop(StreamTask.java:487)
    at org.apache.flink.streaming.runtime.tasks.StreamTask.invoke(StreamTask.java:470)
    at org.apache.flink.runtime.taskmanager.Task.doRun(Task.java:707)
    at org.apache.flink.runtime.taskmanager.Task.run(Task.java:532)
    at java.lang.Thread.run(Thread.java:748)
Any idea what I am doing wrong?
Thank you in advance!
The problem is that on line 57 you are accessing the timestamp field of the Context. This field is null if you are using ProcessingTime, or if you don't specify a timestamp extractor when using EventTime.
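Since the question uses processing time, here is a minimal sketch of the corresponding adjustment: read the current processing time from the timer service instead of ctx.timestamp, and register a processing-time timer (the rest of the function is unchanged from the question):

override def processElement(
    value: (String, String),
    ctx: KeyedProcessFunction[Tuple, (String, String), (String, Long)]#Context,
    out: Collector[(String, Long)]): Unit = {

  // ctx.timestamp is null with ProcessingTime, so use the timer service instead.
  val now = ctx.timerService.currentProcessingTime

  val current: CountWithTimestamp = state.value match {
    case null => CountWithTimestamp(value._1, 1, now)
    case CountWithTimestamp(key, count, _) => CountWithTimestamp(key, count + 1, now)
  }
  state.update(current)

  // Register a processing-time timer 60 seconds from now.
  ctx.timerService.registerProcessingTimeTimer(current.lastModified + 60000)
}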

Scala: how to modify the default metric for cross validation

I found the code below on this site:
https://spark.apache.org/docs/2.3.1/ml-tuning.html
// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric
// is areaUnderROC.
val cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(new BinaryClassificationEvaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(2) // Use 3+ in practice
.setParallelism(2) // Evaluate up to 2 parameter settings in parallel
As they say, the default metric for BinaryClassificationEvaluator is areaUnderROC (AUC).
How can I change this default metric to the F1-score?
I tried:
// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric
// is areaUnderROC.
val cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(new BinaryClassificationEvaluator.setMetricName("f1"))
.setEstimatorParamMaps(paramGrid)
.setNumFolds(2) // Use 3+ in practice
.setParallelism(2) // Evaluate up to 2 parameter settings in parallel
But I got some errors.
I searched on many sites but did not find a solution.
setMetricName only accepts "areaUnderPR" or "areaUnderROC". You will need to write your own Evaluator; something like this:
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.{Dataset, functions => F}
class FScoreEvaluator(override val uid: String) extends Evaluator with HasPredictionCol with HasLabelCol {

  def this() = this(Identifiable.randomUID("FScoreEvaluator"))

  def evaluate(dataset: Dataset[_]): Double = {
    val truePositive = F.sum(((F.col(getLabelCol) === 1) && (F.col(getPredictionCol) === 1)).cast(IntegerType))
    val predictedPositive = F.sum((F.col(getPredictionCol) === 1).cast(IntegerType))
    val actualPositive = F.sum((F.col(getLabelCol) === 1).cast(IntegerType))

    val precision = truePositive / predictedPositive
    val recall = truePositive / actualPositive
    val fScore = F.lit(2) * (precision * recall) / (precision + recall)

    dataset.select(fScore).collect()(0)(0).asInstanceOf[Double]
  }

  override def copy(extra: ParamMap): Evaluator = defaultCopy(extra)
}
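To tie this back to the original question, the custom evaluator can then replace the BinaryClassificationEvaluator in the cross-validator. A sketch, assuming the pipeline and paramGrid from the question; as written, FScoreEvaluator relies on the default "label" and "prediction" column names:

import org.apache.spark.ml.tuning.CrossValidator

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new FScoreEvaluator())   // default isLargerBetter = true, which is correct for F1
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(2)
  .setParallelism(2)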
This is based on @gmds's answer. Make sure your Spark version is >= 2.3.
You can also follow the implementation of RegressionEvaluator in Spark to implement other custom evaluators.
I also added isLargerBetter so that the instantiated evaluator can be used in model selection (e.g. cross-validation).
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.{Dataset, functions => F}
class WRmseEvaluator(override val uid: String) extends Evaluator with HasPredictionCol with HasLabelCol with HasWeightCol {

  def this() = this(Identifiable.randomUID("wrmseEval"))

  def setPredictionCol(value: String): this.type = set(predictionCol, value)
  def setLabelCol(value: String): this.type = set(labelCol, value)
  def setWeightCol(value: String): this.type = set(weightCol, value)

  def evaluate(dataset: Dataset[_]): Double = {
    dataset
      .withColumn("residual", F.col(getLabelCol) - F.col(getPredictionCol))
      .select(
        F.sqrt(F.sum(F.col(getWeightCol) * F.pow(F.col("residual"), 2)) / F.sum(getWeightCol))
      )
      .collect()(0)(0).asInstanceOf[Double]
  }

  override def copy(extra: ParamMap): Evaluator = defaultCopy(extra)

  override def isLargerBetter: Boolean = false
}
The following is how to use it.
val wrmseEvaluator = new WRmseEvaluator()
.setLabelCol(labelColName)
.setPredictionCol(predColName)
.setWeightCol(weightColName)
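The evaluator can also be called directly; a small sketch, assuming a predictions DataFrame with the configured label, prediction and weight columns:

// Lower is better here: because isLargerBetter is false, CrossValidator and
// TrainValidationSplit will pick the model with the smallest weighted RMSE.
val wrmse: Double = wrmseEvaluator.evaluate(predictions)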

How to ensure constant Avro schema generation and avoid the 'Too many schema objects created for x' exception?

I am experiencing a reproducible error while producing Avro messages with reactive kafka and avro4s. Once the identityMapCapacity of the client (CachedSchemaRegistryClient) is reached, serialization fails with
java.lang.IllegalStateException: Too many schema objects created for <myTopic>-value
This is unexpected, since all messages should have the same schema - they are serializations of the same case class.
val avroProducerSettings: ProducerSettings[String, GenericRecord] =
  ProducerSettings(system, Serdes.String().serializer(), avroSerde.serializer())
    .withBootstrapServers(settings.bootstrapServer)

val avroProdFlow: Flow[ProducerMessage.Message[String, GenericRecord, String],
                       ProducerMessage.Result[String, GenericRecord, String],
                       NotUsed] = Producer.flow(avroProducerSettings)

val avroQueue: SourceQueueWithComplete[Message[String, GenericRecord, String]] =
  Source.queue(bufferSize, overflowStrategy)
    .via(avroProdFlow)
    .map(logResult)
    .to(Sink.ignore)
    .run()

...
queue.offer(msg)
The serializer is a KafkaAvroSerializer, instantiated with a new CachedSchemaRegistryClient(settings.schemaRegistry, 1000)
Generating the GenericRecord:
def toAvro[A](a: A)(implicit recordFormat: RecordFormat[A]): GenericRecord =
  recordFormat.to(a)

val makeEdgeMessage: (Edge, String) => Message[String, GenericRecord, String] = { (edge, topic) =>
  val edgeAvro: GenericRecord = toAvro(edge)
  val record = new ProducerRecord[String, GenericRecord](topic, edge.id, edgeAvro)
  ProducerMessage.Message(record, edge.id)
}
The schema is created deep in the code (io.confluent.kafka.serializers.AbstractKafkaAvroSerDe#getSchema, invoked by io.confluent.kafka.serializers.AbstractKafkaAvroSerializer#serializeImpl) where I have no influence over it, so I have no idea how to fix the leak. It looks to me like the two Confluent projects do not work well together.
The issues I have found here, here and here do not seem to address my use case.
The two workarounds currently available to me are:
- not using the schema registry - obviously not a long-term solution
- creating a custom SchemaRegistryClient that does not rely on object identity - doable, but I would like to avoid creating more issues than I solve by reimplementing it
Is there a way to generate or cache a consistent schema depending on message/record type and use it with my setup?
edit 2017.11.20
The issue in my case was that each instance of GenericRecord carrying my message was serialized by a different instance of RecordFormat, each containing a different instance of the Schema. The implicit resolution here generated a new instance each time.
def toAvro[A](a: A)(implicit recordFormat: RecordFormat[A]): GenericRecord = recordFormat.to(a)
The solution was to pin the RecordFormat instance to a val and reuse it explicitly. Many thanks to https://github.com/heliocentrist for explaining the details.
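Concretely, the fix looks something like this (a sketch with avro4s; Edge is the case class from the question):

import com.sksamuel.avro4s.RecordFormat
import org.apache.avro.generic.GenericRecord

// Resolve the RecordFormat once and reuse it, so every message is serialized
// with the same underlying Schema instance.
val edgeFormat: RecordFormat[Edge] = RecordFormat[Edge]

def toAvro(edge: Edge): GenericRecord = edgeFormat.to(edge)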
Original response:
After waiting for a while (there was also no answer to the GitHub issue), I had to implement my own SchemaRegistryClient. Over 90% of it is copied from the original CachedSchemaRegistryClient, just translated into Scala. Using a Scala mutable.Map fixed the memory leak. I have not performed any comprehensive tests, so use it at your own risk.
import java.util
import io.confluent.kafka.schemaregistry.client.rest.entities.{ Config, SchemaString }
import io.confluent.kafka.schemaregistry.client.rest.entities.requests.ConfigUpdateRequest
import io.confluent.kafka.schemaregistry.client.rest.{ RestService, entities }
import io.confluent.kafka.schemaregistry.client.{ SchemaMetadata, SchemaRegistryClient }
import org.apache.avro.Schema
import scala.collection.mutable
class CachingSchemaRegistryClient(val restService: RestService, val identityMapCapacity: Int)
    extends SchemaRegistryClient {

  val schemaCache: mutable.Map[String, mutable.Map[Schema, Integer]] = mutable.Map()
  val idCache: mutable.Map[String, mutable.Map[Integer, Schema]] =
    mutable.Map(null.asInstanceOf[String] -> mutable.Map())
  val versionCache: mutable.Map[String, mutable.Map[Schema, Integer]] = mutable.Map()

  def this(baseUrl: String, identityMapCapacity: Int) {
    this(new RestService(baseUrl), identityMapCapacity)
  }

  def this(baseUrls: util.List[String], identityMapCapacity: Int) {
    this(new RestService(baseUrls), identityMapCapacity)
  }

  def registerAndGetId(subject: String, schema: Schema): Int =
    restService.registerSchema(schema.toString, subject)

  def getSchemaByIdFromRegistry(id: Int): Schema = {
    val restSchema: SchemaString = restService.getId(id)
    (new Schema.Parser).parse(restSchema.getSchemaString)
  }

  def getVersionFromRegistry(subject: String, schema: Schema): Int = {
    val response: entities.Schema = restService.lookUpSubjectVersion(schema.toString, subject)
    response.getVersion.intValue
  }

  override def getVersion(subject: String, schema: Schema): Int = synchronized {
    val schemaVersionMap: mutable.Map[Schema, Integer] =
      versionCache.getOrElseUpdate(subject, mutable.Map())

    val version: Integer = schemaVersionMap.getOrElse(
      schema, {
        if (schemaVersionMap.size >= identityMapCapacity) {
          throw new IllegalStateException(s"Too many schema objects created for $subject!")
        }
        val version = new Integer(getVersionFromRegistry(subject, schema))
        schemaVersionMap.put(schema, version)
        version
      }
    )
    version.intValue()
  }

  override def getAllSubjects: util.List[String] = restService.getAllSubjects()

  override def getByID(id: Int): Schema = synchronized { getBySubjectAndID(null, id) }

  override def getBySubjectAndID(subject: String, id: Int): Schema = synchronized {
    val idSchemaMap: mutable.Map[Integer, Schema] = idCache.getOrElseUpdate(subject, mutable.Map())
    idSchemaMap.getOrElseUpdate(id, getSchemaByIdFromRegistry(id))
  }

  override def getSchemaMetadata(subject: String, version: Int): SchemaMetadata = {
    val response = restService.getVersion(subject, version)
    val id = response.getId.intValue
    val schema = response.getSchema
    new SchemaMetadata(id, version, schema)
  }

  override def getLatestSchemaMetadata(subject: String): SchemaMetadata = synchronized {
    val response = restService.getLatestVersion(subject)
    val id = response.getId.intValue
    val version = response.getVersion.intValue
    val schema = response.getSchema
    new SchemaMetadata(id, version, schema)
  }

  override def updateCompatibility(subject: String, compatibility: String): String = {
    val response: ConfigUpdateRequest = restService.updateCompatibility(compatibility, subject)
    response.getCompatibilityLevel
  }

  override def getCompatibility(subject: String): String = {
    val response: Config = restService.getConfig(subject)
    response.getCompatibilityLevel
  }

  override def testCompatibility(subject: String, schema: Schema): Boolean =
    restService.testCompatibility(schema.toString(), subject, "latest")

  override def register(subject: String, schema: Schema): Int = synchronized {
    val schemaIdMap: mutable.Map[Schema, Integer] =
      schemaCache.getOrElseUpdate(subject, mutable.Map())

    val id = schemaIdMap.getOrElse(
      schema, {
        if (schemaIdMap.size >= identityMapCapacity)
          throw new IllegalStateException(s"Too many schema objects created for $subject!")
        val id: Integer = new Integer(registerAndGetId(subject, schema))
        schemaIdMap.put(schema, id)
        idCache(null).put(id, schema)
        id
      }
    )
    id.intValue()
  }
}
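For completeness, one way to plug the custom client in is via the KafkaAvroSerializer constructor that accepts a SchemaRegistryClient (a sketch; settings.schemaRegistry is from the question):

import io.confluent.kafka.serializers.KafkaAvroSerializer

// Hand the Scala client to the Confluent serializer instead of letting it
// build its own CachedSchemaRegistryClient.
val registryClient = new CachingSchemaRegistryClient(settings.schemaRegistry, 1000)
val avroSerializer = new KafkaAvroSerializer(registryClient)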

How to create a custom Transformer from a UDF?

I was trying to create and save a Pipeline with custom stages. I need to add a column to my DataFrame using a UDF. Is it possible to convert a UDF or a similar action into a Transformer?
My custom UDF looks like this, and I'd like to learn how to turn it into a custom Transformer.
def getFeatures(n: String) = {
  val NUMBER_FEATURES = 4
  val name = n.split(" +")(0).toLowerCase
  ((1 to NUMBER_FEATURES)
    .filter(size => size <= name.length)
    .map(size => name.substring(name.length - size)))
}
val tokenizeUDF = sqlContext.udf.register("tokenize", (name: String) => getFeatures(name))
It is not a fully featured solution, but you can start with something like this:
import org.apache.spark.ml.{UnaryTransformer}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
class NGramTokenizer(override val uid: String)
    extends UnaryTransformer[String, Seq[String], NGramTokenizer] {

  def this() = this(Identifiable.randomUID("ngramtokenizer"))

  override protected def createTransformFunc: String => Seq[String] = {
    getFeatures _
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType)
  }

  override protected def outputDataType: DataType = {
    new ArrayType(StringType, true)
  }
}
Quick check:
val df = Seq((1L, "abcdef"), (2L, "foobar")).toDF("k", "v")
val transformer = new NGramTokenizer().setInputCol("v").setOutputCol("vs")
transformer.transform(df).show
// +---+------+------------------+
// | k| v| vs|
// +---+------+------------------+
// | 1|abcdef|[f, ef, def, cdef]|
// | 2|foobar|[r, ar, bar, obar]|
// +---+------+------------------+
You can even try to generalize it to something like this:
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import scala.reflect.runtime.universe._
class UnaryUDFTransformer[T : TypeTag, U : TypeTag](
    override val uid: String,
    f: T => U
  ) extends UnaryTransformer[T, U, UnaryUDFTransformer[T, U]] {

  override protected def createTransformFunc: T => U = f

  override protected def validateInputType(inputType: DataType): Unit =
    require(inputType == schemaFor[T].dataType)

  override protected def outputDataType: DataType = schemaFor[U].dataType
}
val transformer = new UnaryUDFTransformer("featurize", getFeatures)
.setInputCol("v")
.setOutputCol("vs")
If you want to use a UDF rather than the wrapped function, you'll have to extend Transformer directly and override the transform method. Unfortunately, the majority of the useful classes are private, so it can be rather tricky; a rough sketch of this approach is shown below.
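For illustration, a minimal sketch of extending Transformer directly and applying the UDF inside transform (the uid and column names are made up; it mirrors the getFeatures logic from the question):

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

class GetFeaturesTransformer(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("getFeaturesTransformer"))

  // The UDF wraps the same logic as getFeatures from the question.
  private val tokenizeUdf = udf { n: String =>
    val NUMBER_FEATURES = 4
    val name = n.split(" +")(0).toLowerCase
    (1 to NUMBER_FEATURES)
      .filter(size => size <= name.length)
      .map(size => name.substring(name.length - size))
  }

  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn("features", tokenizeUdf(col("name")))   // hypothetical column names

  override def transformSchema(schema: StructType): StructType =
    StructType(schema.fields :+ StructField("features", ArrayType(StringType), nullable = true))

  override def copy(extra: ParamMap): GetFeaturesTransformer = defaultCopy(extra)
}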
Alternatively, you can register the UDF:
spark.udf.register("getFeatures", getFeatures _)
and use SQLTransformer
import org.apache.spark.ml.feature.SQLTransformer
val transformer = new SQLTransformer()
.setStatement("SELECT *, getFeatures(v) AS vs FROM __THIS__")
transformer.transform(df).show
// +---+------+------------------+
// | k| v| vs|
// +---+------+------------------+
// | 1|abcdef|[f, ef, def, cdef]|
// | 2|foobar|[r, ar, bar, obar]|
// +---+------+------------------+
I initially tried to extend the Transformer and UnaryTransformer abstracts, but ran into trouble with my application being unable to reach DefaultParamsWritable. As an example that may be relevant to your problem, I created a simple term normalizer as a UDF, following along from this example. My goal is to match terms against patterns and sets in order to replace them with generic terms. For example:
"\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b".r -> "emailaddr"
This is the class
import scala.util.matching.Regex
class TermNormalizer(normMap: Map[Any, String]) {
  val normalizationMap = normMap

  def normalizeTerms(terms: Seq[String]): Seq[String] = {
    var termsUpdated = terms
    for ((term, idx) <- termsUpdated.view.zipWithIndex) {
      for (normalizer <- normalizationMap.keys: Iterable[Any]) {
        normalizer match {
          case (regex: Regex) =>
            if (!regex.findFirstIn(term).isEmpty) termsUpdated =
              termsUpdated.updated(idx, normalizationMap(regex))
          case (set: Set[String]) =>
            if (set.contains(term)) termsUpdated =
              termsUpdated.updated(idx, normalizationMap(set))
        }
      }
    }
    termsUpdated
  }
}
I use it like this:
val testMap: Map[Any, String] = Map(
  "hadoop".r -> "elephant",
  "spark".r -> "sparky",
  "cool".r -> "neat",
  Set("123", "456") -> "set1",
  Set("789", "10") -> "set2")

val testTermNormalizer = new TermNormalizer(testMap)
val termNormalizerUdf = udf(testTermNormalizer.normalizeTerms(_: Seq[String]))

val trainingTest = sqlContext.createDataFrame(Seq(
  (0L, "spark is cool 123", 1.0),
  (1L, "adsjkfadfk akjdsfhad 456", 0.0),
  (2L, "spark rocks my socks 789 10", 1.0),
  (3L, "hadoop is cool 10", 0.0)
)).toDF("id", "text", "label")

val testTokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")

val tokenizedTrainingTest = testTokenizer.transform(trainingTest)

println(tokenizedTrainingTest
  .select($"id", $"text", $"words", termNormalizerUdf($"words"), $"label").show(false))
Now that I've read the question a little closer, it sounds like you're asking how to avoid doing it this way lol. Anyway, I'll still post it in case someone in the future is looking for an easy way to apply transformer-ish functionality.
If you wish to make the transformer writable as well, you can re-implement traits such as HasInputCol from the sharedParams library in a public package of your choice, and then use them together with the DefaultParamsWritable trait to make the transformer persistable (a rough sketch follows below).
This way you also avoid having to place part of your code inside the Spark core ml packages, at the cost of maintaining a parallel set of params in your own package. This isn't really a problem given that they hardly ever change.
But do track the issue on their JIRA board that asks for some of the common sharedParams to be made public instead of private to ml, so that people can use them directly from outside classes.
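For illustration, a minimal sketch of what re-implemented shared params might look like alongside DefaultParamsWritable (the trait and class names here are made up, and this assumes a Spark version where DefaultParamsWritable/Readable are public):

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

// Re-implemented stand-ins for the private-to-ml shared param traits.
trait HasInputColumn extends Params {
  final val inputCol = new Param[String](this, "inputCol", "input column name")
  final def getInputCol: String = $(inputCol)
}

trait HasOutputColumn extends Params {
  final val outputCol = new Param[String](this, "outputCol", "output column name")
  final def getOutputCol: String = $(outputCol)
}

class SimpleTokenizer(override val uid: String)
    extends Transformer with HasInputColumn with HasOutputColumn with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("simpleTokenizer"))

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  private val tokenize = udf { s: String => s.toLowerCase.split("\\s+").toSeq }

  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn($(outputCol), tokenize(col($(inputCol))))

  override def transformSchema(schema: StructType): StructType =
    StructType(schema.fields :+ StructField($(outputCol), ArrayType(StringType), nullable = true))

  override def copy(extra: ParamMap): SimpleTokenizer = defaultCopy(extra)
}

// Companion object so that SimpleTokenizer.load(...) works after save(...).
object SimpleTokenizer extends DefaultParamsReadable[SimpleTokenizer]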