How to unit test BroadcastProcessFunction in flink when processElement depends on broadcasted data - scala

I implemented a Flink stream with a BroadcastProcessFunction. I get my model in processBroadcastElement and apply it to each event in processElement.
I can't find a way to unit test my stream, because I can't ensure that the model is dispatched before the first event.
I see two ways to achieve this:
1. Find a way to have the model pushed into the stream first
2. Have the broadcast state filled with the model prior to the execution of the stream, so that it is restored
I may have missed something, but I have not found a simple way to do either.
Here is a simple unit test that shows my issue:
import org.apache.flink.api.common.state.MapStateDescriptor
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction
import org.apache.flink.streaming.api.functions.sink.SinkFunction
import org.apache.flink.streaming.api.scala._
import org.apache.flink.util.Collector
import org.scalatest.Matchers._
import org.scalatest.{BeforeAndAfter, FunSuite}
import scala.collection.mutable
class BroadCastProcessor extends BroadcastProcessFunction[Int, (Int, String), String] {
import BroadCastProcessor._
override def processElement(value: Int,
ctx: BroadcastProcessFunction[Int, (Int, String), String]#ReadOnlyContext,
out: Collector[String]): Unit = {
val broadcastState = ctx.getBroadcastState(broadcastStateDescriptor)
if (broadcastState.contains(value)) {
out.collect(broadcastState.get(value))
}
}
override def processBroadcastElement(value: (Int, String),
ctx: BroadcastProcessFunction[Int, (Int, String), String]#Context,
out: Collector[String]): Unit = {
ctx.getBroadcastState(broadcastStateDescriptor).put(value._1, value._2)
}
}
object BroadCastProcessor {
val broadcastStateDescriptor: MapStateDescriptor[Int, String] = new MapStateDescriptor[Int, String]("int_to_string", classOf[Int], classOf[String])
}
class CollectSink extends SinkFunction[String] {
import CollectSink._
override def invoke(value: String): Unit = {
values += value
}
}
object CollectSink { // must be static
val values: mutable.MutableList[String] = mutable.MutableList[String]()
}
class BroadCastProcessTest extends FunSuite with BeforeAndAfter {
before {
CollectSink.values.clear()
}
test("add_elem_to_broadcast_and_process_should_apply_broadcast_rule") {
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setParallelism(1)
val dataToProcessStream = env.fromElements(1)
val ruleToBroadcastStream = env.fromElements(1 -> "1", 2 -> "2", 3 -> "3")
val broadcastStream = ruleToBroadcastStream.broadcast(BroadCastProcessor.broadcastStateDescriptor)
dataToProcessStream
.connect(broadcastStream)
.process(new BroadCastProcessor)
.addSink(new CollectSink())
// execute
env.execute()
CollectSink.values should contain("1")
}
}
Update thanks to David Anderson
I went for the buffer solution. I defined a process function for the synchronization:
class SynchronizeModelAndEvent(modelNumberToWaitFor: Int) extends CoProcessFunction[Int, (Int, String), Int] {
val eventBuffer: mutable.MutableList[Int] = mutable.MutableList[Int]()
var modelEventsNumber = 0
override def processElement1(value: Int, ctx: CoProcessFunction[Int, (Int, String), Int]#Context, out: Collector[Int]): Unit = {
if (modelEventsNumber < modelNumberToWaitFor) {
eventBuffer += value
return
}
out.collect(value)
}
override def processElement2(value: (Int, String), ctx: CoProcessFunction[Int, (Int, String), Int]#Context, out: Collector[Int]): Unit = {
modelEventsNumber += 1
if (modelEventsNumber >= modelNumberToWaitFor) {
eventBuffer.foreach(event => out.collect(event))
}
}
}
And so I need to add it to my stream:
dataToProcessStream
.connect(ruleToBroadcastStream)
.process(new SynchronizeModelAndEvent(3))
.connect(broadcastStream)
.process(new BroadCastProcessor)
.addSink(new CollectSink())
Thanks

There isn't an easy way to do this. You could have processElement buffer all of its input until the model has been received by processBroadcastElement. Or run the job once with no event traffic and take a savepoint once the model has been broadcast. Then restore that savepoint into the same job, but with its event input connected.
By the way, the capability you are looking for is often referred to as "side inputs" in the Flink community.
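The buffering idea can be modelled without Flink at all. The sketch below is only an illustration: `BufferingRuleProcessor`, `onRule` and `onEvent` are hypothetical names standing in for the function, processBroadcastElement and processElement, and it assumes the number of distinct rules to wait for is known up front.

```scala
import scala.collection.mutable

// Hypothetical, Flink-free model of "buffer events until the model is complete".
final class BufferingRuleProcessor(rulesToWaitFor: Int) {
  private val rules = mutable.Map.empty[Int, String]
  private val pendingEvents = mutable.ArrayBuffer.empty[Int]

  // Same lookup as processElement in the question: emit only if a rule matches.
  private def applyRules(event: Int): Option[String] = rules.get(event)

  /** Counterpart of processBroadcastElement: store the rule, flush the buffer once complete. */
  def onRule(rule: (Int, String)): Seq[String] = {
    rules += rule
    if (rules.size >= rulesToWaitFor) {
      val flushed = pendingEvents.toList.flatMap(applyRules)
      pendingEvents.clear()
      flushed
    } else Nil
  }

  /** Counterpart of processElement: buffer while the model is still incomplete. */
  def onEvent(event: Int): Seq[String] =
    if (rules.size < rulesToWaitFor) {
      pendingEvents += event
      Nil
    } else applyRules(event).toList
}
```

In a real job the buffer would live inside the process function itself. A plain field like `pendingEvents` is not part of Flink's checkpointed state, so this shape is mainly suitable for tests.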

Thanks to David Anderson and Matthieu, I wrote this generic CoProcessFunction that delays the event stream until the expected number of rules has arrived:
import org.apache.flink.streaming.api.functions.co.CoProcessFunction
import org.apache.flink.util.Collector
import scala.collection.mutable
class SynchronizeEventsWithRules[A,B](rulesToWait: Int) extends CoProcessFunction[A, B, A] {
val eventBuffer: mutable.MutableList[A] = mutable.MutableList[A]()
var processedRules = 0
override def processElement1(value: A, ctx: CoProcessFunction[A, B, A]#Context, out: Collector[A]): Unit = {
if (processedRules < rulesToWait) {
println("1 item buffered")
println(rulesToWait+"--"+processedRules)
eventBuffer += value
return
}
eventBuffer.clear()
println("send input to output without buffering:")
out.collect(value)
}
override def processElement2(value: B, ctx: CoProcessFunction[A, B, A]#Context, out: Collector[A]): Unit = {
processedRules += 1
println("1 rule processed processedRules:"+processedRules)
if (processedRules >= rulesToWait && eventBuffer.length>0) {
println("send buffered data to output")
eventBuffer.foreach(event => out.collect(event))
eventBuffer.clear()
}
}
}
Unfortunately, it did not help in my case, because the function under test was a KeyedBroadcastProcessFunction, which makes the delay on the event data irrelevant. I then tried applying a flatMap that makes the rule stream n times larger, where n is the number of CPUs, so that the resulting event stream would always be in sync with the rule stream and arrive after it, but that did not help either.
In the end I came to this simple solution. It is of course not deterministic, but because of the nature of parallelism and concurrency the problem itself is not deterministic either.
If we set delayMilis big enough (>100), the result is deterministic in practice:
val delayMilis = 100
val synchronizedInput = inputEventStream.map(x=>{
Thread.sleep(delayMilis)
x
}).keyBy(_.someKey)
You can also change the mapping function to the following, which applies the delay only to the first element:
package util
import org.apache.flink.api.common.functions.MapFunction
class DelayEvents[T](delayMilis: Int) extends MapFunction[T,T] {
var delayed = false
override def map(value: T): T = {
if (!delayed) {
delayed=true
Thread.sleep(delayMilis)
}
value
}
}
val delayMilis = 100
val synchronizedInput = inputEventStream.map(new DelayEvents(100)).keyBy(_.someKey)

Related

Scala stream and ExecutionContext issue

I'm new to Scala and I'm facing a few problems in my assignment:
I want to build a stream class that can do 3 main tasks: filter, map, and forEach.
My stream's data is an array of elements. Each of the 3 main tasks should run in 2 different threads on my stream's array.
In addition, I need to divide the declaration of an action and its actual execution into two different parts: first declare all the tasks on the stream, and only when I call stream.run() should the actual actions happen.
My code:
class LearningStream[A]() {
val es: ExecutorService = Executors.newFixedThreadPool(2)
val ec = ExecutionContext.fromExecutorService(es)
var streamValues: ArrayBuffer[A] = ArrayBuffer[A]()
var r: Runnable = () => "";
def setValues(streamv: ArrayBuffer[A]) = {
streamValues = streamv;
}
def filter(p: A => Boolean): LearningStream[A] = {
var ls_filtered: LearningStream[A] = new LearningStream[A]()
r = () => {
println("running real filter..")
val (l,r) = streamValues.splitAt(streamValues.length/2)
val a:ArrayBuffer[A]=es.submit(()=>l.filter(p)).get()
val b:ArrayBuffer[A]=es.submit(()=>r.filter(p)).get()
ls_filtered.setValues(a++b)
}
return ls_filtered
}
def map[B](f: A => B): LearningStream[B] = {
var ls_map: LearningStream[B] = new LearningStream[B]()
r = () => {
println("running real map..")
val (l,r) = streamValues.splitAt(streamValues.length/2)
val a:ArrayBuffer[B]=es.submit(()=>l.map(f)).get()
val b:ArrayBuffer[B]=es.submit(()=>r.map(f)).get()
ls_map.setValues(a++b)
}
return ls_map
}
def forEach(c: A => Unit): Unit = {
r=()=>{
println("running real forEach")
streamValues.foreach(c)}
}
def insert(a: A): Unit = {
streamValues += a
}
def start(): Unit = {
ec.submit(r)
}
def shutdown(): Unit = {
ec.shutdown()
}
}
my main :
def main(args: Array[String]): Unit = {
var factorial=1
val s = new LearningStream[String]
s.filter(str=>str.startsWith("-")).map(s=>s.toInt*(-1)).forEach(i=>factorial=factorial*i)
for(i <- -5 to 5){
s.insert(i.toString)
}
println(s.streamValues)
s.start()
println(factorial)
}
The main prints only the filter's output, and the factorial isn't changed (still 1).
What am I missing here?
My solution: @Levi Ramsey left a few good hints in the comments, if you want hints and not the full solution.
First problem: only one command (filter) ran and the others didn't. Solution: in the runnable of each command, add a call that submits the next stream's runnable:
ec.submit(ls_map.r)
To be able to shut down all the streams, we need to add another LearningStream data member to the class. However, we can't just add a regular LearningStream field, because its type depends on the type parameter [A]. Therefore, I implemented a trait that has the close function, and the data member is of that trait type.
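An alternative to chaining runnables is to let each stage wrap the previous one, so that nothing executes until the last stage is run. This is only a sketch under my own assumptions, not the asker's final code: `LazyPipeline` and `LazyPipelineDemo` are hypothetical names, and the input is treated as a fixed collection rather than something that is inserted into later.

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

// Hypothetical sketch: every stage wraps the previous one; nothing runs until run().
final class LazyPipeline[A] private (compute: ExecutionContext => Vector[A]) {

  // Evaluate the upstream stage, then apply `op` to both halves, each half as its own Future.
  private def stage[B](op: Vector[A] => Vector[B]): LazyPipeline[B] =
    new LazyPipeline[B]({ ec =>
      implicit val context: ExecutionContext = ec
      val input = compute(ec)
      val (left, right) = input.splitAt(input.length / 2)
      val l = Future(op(left))
      val r = Future(op(right))
      Await.result(l, 10.seconds) ++ Await.result(r, 10.seconds)
    })

  def filter(p: A => Boolean): LazyPipeline[A] = stage(_.filter(p))
  def map[B](f: A => B): LazyPipeline[B] = stage(_.map(f))

  /** Executes all declared stages; blocks the calling thread, never a pool thread. */
  def run()(implicit ec: ExecutionContext): Vector[A] = compute(ec)
}

object LazyPipeline {
  def apply[A](values: Seq[A]): LazyPipeline[A] = new LazyPipeline[A](_ => values.toVector)
}

object LazyPipelineDemo {
  // Declare first, run later on a two-thread pool, then multiply the results.
  def factorial(): Int = {
    val pool = Executors.newFixedThreadPool(2)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)
    try {
      LazyPipeline((-5 to 5).map(_.toString))
        .filter(_.startsWith("-"))
        .map(_.toInt * -1)
        .run()
        .product
    } finally pool.shutdown()
  }
}
```

Because `run()` returns the computed values, the caller can combine them after the pipeline has finished instead of mutating a shared `var` from another thread.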

Scala RestartSink Future

I'm trying to re-create functionality similar to Akka Streams' RestartSink.
I've come up with this code. However, since the stage only declares a SinkShape instead of being a Sink, I'm having trouble specifying that it should have Future[Done] instead of NotUsed as its materialized value type. I'm only able to get a sink of type [MessageActionPair, NotUsed] instead of the desired [MessageActionPair, Future[Done]]. I'm still learning my way around this framework, so I'm sure I'm missing something small. I tried calling Source.toMat(RestartWithBackoffSink...), but that doesn't give the desired result either.
private final class RestartWithBackoffSink(
sourcePool: Seq[SqsEndpoint],
minBackoff: FiniteDuration,
maxBackoff: FiniteDuration,
randomFactor: Double) extends GraphStage[SinkShape[MessageActionPair]] { self ⇒
val in = Inlet[MessageActionPair]("RoundRobinRestartWithBackoffSink.in")
override def shape = SinkShape(in)
override def createLogic(inheritedAttributes: Attributes) = new RestartWithBackoffLogic(
"Sink", shape, minBackoff, maxBackoff, randomFactor, onlyOnFailures = false) {
override protected def logSource = self.getClass
override protected def startGraph() = {
val sourceOut = createSubOutlet(in)
Source.fromGraph(sourceOut.source).runWith(createSink(getEndpoint))(subFusingMaterializer)
}
override protected def backoff() = {
setHandler(in, new InHandler {
override def onPush() = ()
})
}
private def createSink(endpoint: SqsEndpoint): Sink[MessageActionPair, Future[Done]] = {
SqsAckSink(endpoint.queue.url)(endpoint.client)
}
def getEndpoint: SqsEndpoint = {
if(isTimedOut) {
index = (index + 1) % sourcePool.length
restartCount = 0
}
sourcePool(index)
}
backoff()
}
}
Compile error here, since the types don't match:
def withBackoff[T](minBackoff: FiniteDuration, maxBackoff: FiniteDuration, randomFactor: Double, sourcePool: Seq[SqsEndpoint]): Sink[MessageActionPair, Future[Done]] = {
Sink.fromGraph(new RestartWithBackoffSink(sourcePool, minBackoff, maxBackoff, randomFactor))
}
By extending GraphStage[SinkShape[MessageActionPair]] you are defining a stage with no materialized value, or more precisely a stage that materializes to NotUsed.
You have to decide whether your stage can materialize into anything meaningful. The Akka Streams documentation on custom stages covers materialized values.
If so: you have to extend GraphStageWithMaterializedValue[SinkShape[MessageActionPair], Future[Done]] and properly override the createLogicAndMaterializedValue function. More guidance can be found in the docs.
If not: you can change your types as shown below:
def withBackoff[T](minBackoff: FiniteDuration, maxBackoff: FiniteDuration, randomFactor: Double, sourcePool: Seq[SqsEndpoint]): Sink[MessageActionPair, NotUsed] = {
Sink.fromGraph(new RestartWithBackoffSink(sourcePool, minBackoff, maxBackoff, randomFactor))
}
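For the first option, a minimal sketch of a sink stage that materializes a Future[Done] might look like the following. It is not the restart logic from the question: `CompletionSink` is a hypothetical name, the stage simply discards its input, and the code assumes Akka 2.6 or later is on the classpath.

```scala
import akka.Done
import akka.stream.{Attributes, Inlet, SinkShape}
import akka.stream.stage.{GraphStageLogic, GraphStageWithMaterializedValue, InHandler}
import scala.concurrent.{Future, Promise}

// Hypothetical sink that ignores its elements and reports completion through its materialized value.
final class CompletionSink[T] extends GraphStageWithMaterializedValue[SinkShape[T], Future[Done]] {
  val in: Inlet[T] = Inlet[T]("CompletionSink.in")
  override val shape: SinkShape[T] = SinkShape(in)

  override def createLogicAndMaterializedValue(
      inheritedAttributes: Attributes): (GraphStageLogic, Future[Done]) = {
    val done = Promise[Done]()
    val logic = new GraphStageLogic(shape) with InHandler {
      override def preStart(): Unit = pull(in) // a sink has to request the first element itself

      override def onPush(): Unit = {
        grab(in) // discard the element
        pull(in)
      }

      override def onUpstreamFinish(): Unit = {
        done.trySuccess(Done)
        completeStage()
      }

      override def onUpstreamFailure(ex: Throwable): Unit = {
        done.tryFailure(ex)
        failStage(ex)
      }

      setHandler(in, this)
    }
    (logic, done.future)
  }
}
```

With such a stage, `Sink.fromGraph(new CompletionSink[MessageActionPair])` has the type Sink[MessageActionPair, Future[Done]], which is the signature withBackoff was asking for.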

Akka streams change return type of 3rd party flow/stage

I have a graph that reads from SQS, writes to another system and then deletes from SQS. In order to delete from SQS I need the receipt handle on the SqsMessage object.
In the case of HTTP flows, the signature of the flow lets me choose which type gets emitted downstream from the flow:
Flow[(HttpRequest, T), (Try[HttpResponse], T), HostConnectionPool]
In this case I can set T to SqsMessage and I still have all the data I need.
However, some connectors, e.g. Google Cloud Pub/Sub, emit a pub/sub ID that is completely useless to me.
Downstream of the pub/sub flow I need to be able to access the SQS message ID which I had prior to the pub/sub flow.
What is the best way to work around this without rewriting the pub/sub connector?
I conceptually want something a bit like this:
Flow[SqsMessage] //i have my data at this point
within(
.map(toPubSubMessage)
.via(pubSub))
... from here I have the same type I had before `within`, but it still behaves like a linear graph with back pressure etc.
You can use the PassThrough integration pattern.
For an example of usage, look at akka-streams-kafka -> class akka.kafka.scaladsl.Producer -> method def flow[K, V, PassThrough].
So just implement your own stage with a PassThrough element; see the example in akka.kafka.internal.ProducerStage[K, V, PassThrough]:
package my.package
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.Future
import scala.util.{Failure, Success, Try}
import akka.stream._
import akka.stream.ActorAttributes.SupervisionStrategy
import akka.stream.stage._
final case class Message[V, PassThrough](record: V, passThrough: PassThrough)
final case class Result[R, PassThrough](result: R, message: PassThrough)
class PathThroughStage[R, V, PassThrough]
extends GraphStage[FlowShape[Message[V, PassThrough], Future[Result[R, PassThrough]]]] {
private val in = Inlet[Message[V, PassThrough]]("messages")
private val out = Outlet[Future[Result[R, PassThrough]]]("result")
override val shape = FlowShape(in, out)
override protected def createLogic(inheritedAttributes: Attributes) = {
val logic = new GraphStageLogic(shape) with StageLogging {
lazy val decider = inheritedAttributes.get[SupervisionStrategy]
.map(_.decider)
.getOrElse(Supervision.stoppingDecider)
val awaitingConfirmation = new AtomicInteger(0)
@volatile var inIsClosed = false
var completionState: Option[Try[Unit]] = None
override protected def logSource: Class[_] = classOf[PathThroughStage[R, V, PassThrough]]
def checkForCompletion() = {
if (isClosed(in) && awaitingConfirmation.get == 0) {
completionState match {
case Some(Success(_)) => completeStage()
case Some(Failure(ex)) => failStage(ex)
case None => failStage(new IllegalStateException("Stage completed, but there is no info about status"))
}
}
}
val checkForCompletionCB = getAsyncCallback[Unit] { _ =>
checkForCompletion()
}
val failStageCb = getAsyncCallback[Throwable] { ex =>
failStage(ex)
}
setHandler(out, new OutHandler {
override def onPull() = {
tryPull(in)
}
})
setHandler(in, new InHandler {
override def onPush() = {
val msg = grab(in)
val f = Future[Result[R, PassThrough]] {
try {
Result(// TODO YOUR logic
msg.record,
msg.passThrough)
} catch {
case exception: Exception =>
decider(exception) match {
case Supervision.Stop =>
failStageCb.invoke(exception)
case _ =>
Result(exception, msg.passThrough)
}
}
if (awaitingConfirmation.decrementAndGet() == 0 && inIsClosed) checkForCompletionCB.invoke(())
}
awaitingConfirmation.incrementAndGet()
push(out, f)
}
override def onUpstreamFinish() = {
inIsClosed = true
completionState = Some(Success(()))
checkForCompletion()
}
override def onUpstreamFailure(ex: Throwable) = {
inIsClosed = true
completionState = Some(Failure(ex))
checkForCompletion()
}
})
override def postStop() = {
log.debug("Stage completed")
super.postStop()
}
}
logic
}
}
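If the wrapped flow emits exactly one element per input element and keeps the order, a lighter variant of the same pattern can be built from Broadcast and Zip instead of a custom stage. The sketch below rests on that assumption (a flow that drops, duplicates or reorders elements would pair outputs with the wrong inputs); `PassThroughFlow` is a name chosen here for illustration, and Akka 2.6 or later is assumed.

```scala
import akka.NotUsed
import akka.stream.FlowShape
import akka.stream.scaladsl.{Broadcast, Flow, GraphDSL, Zip}

object PassThroughFlow {
  // Sends each input both through `inner` and around it, then pairs the two again.
  def apply[A, B](inner: Flow[A, B, NotUsed]): Flow[A, (B, A), NotUsed] =
    Flow.fromGraph(GraphDSL.create() { implicit builder =>
      import GraphDSL.Implicits._

      val broadcast = builder.add(Broadcast[A](2))
      val zip = builder.add(Zip[B, A]())

      broadcast.out(0) ~> inner ~> zip.in0 // processed element
      broadcast.out(1) ~> zip.in1          // original element, untouched

      FlowShape(broadcast.in, zip.out)
    })
}
```

In the question's terms, `inner` would be something like `Flow[SqsMessage].map(toPubSubMessage).via(pubSub)`, and downstream stages would receive the pub/sub result together with the original SqsMessage.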

Asynchronous Iterable over remote data

There is some data that I have pulled from a remote API, for which I use a Future-style interface. The data is structured as a linked-list. A relevant example data container is shown below.
case class Data(information: Int) {
def hasNext: Boolean = ??? // Implemented
def next: Future[Data] = ??? // Implemented
}
Now I'm interested in adding some functionality to the data class, such as map, foreach, reduce, etc. To do so I want to implement some form of IterableLike such that it inherits these methods.
Given below is the trait Data may extend, such that it gets this property.
trait AsyncIterable[+T]
extends IterableLike[Future[T], AsyncIterable[T]]
{
def hasNext : Boolean
def next : Future[T]
// How to implement?
override def iterator: Iterator[Future[T]] = ???
override protected[this] def newBuilder: mutable.Builder[Future[T], AsyncIterable[T]] = ???
override def seq: TraversableOnce[Future[T]] = ???
}
It should be a non-blocking implementation, which when acted on, starts requesting the next data from the remote data source.
It is then possible to do cool stuff such as
case class Data(information: Int) extends AsyncIterable[Data]
val data = Data(1) // And more, of course
// Asynchronously print all the information.
data.foreach(data => println(data.information))
It is also acceptable for the interface to be different. But the result should in some way represent asynchronous iteration over the collection. Preferably in a way that is familiar to developers, as it will be part of an (open source) library.
In production I would use one of the following:
Akka Streams
Reactive Extensions
For private tests I would implement something similar to the following.
(Explanations are below)
I have modified your Data a little bit:
abstract class AsyncIterator[T] extends Iterator[Future[T]] {
def hasNext: Boolean
def next(): Future[T]
}
For it we can implement this Iterable:
class AsyncIterable[T](sourceIterator: AsyncIterator[T])
extends IterableLike[Future[T], AsyncIterable[T]]
{
private def stream(): Stream[Future[T]] =
if(sourceIterator.hasNext) {sourceIterator.next #:: stream()} else {Stream.empty}
val asStream = stream()
override def iterator = asStream.iterator
override def seq = asStream.seq
override protected[this] def newBuilder = throw new UnsupportedOperationException()
}
To see it in action, use the following code:
object Example extends App {
val source = "Hello World!";
val iterator1 = new DelayedIterator[Char](100L, source.toCharArray)
new AsyncIterable(iterator1).foreach(_.foreach(print)) //prints 1 char per 100 ms
pause(2000L)
val iterator2 = new DelayedIterator[String](100L, source.toCharArray.map(_.toString))
new AsyncIterable(iterator2).reduceLeft((fl: Future[String], fr) =>
for(l <- fl; r <- fr) yield {println(s"$l+$r"); l + r}) //prints 1 line per 100 ms
pause(2000L)
def pause(duration: Long) = {println("->"); Thread.sleep(duration); println("\n<-")}
}
class DelayedIterator[T](delay: Long, data: Seq[T]) extends AsyncIterator[T] {
private val dataIterator = data.iterator
private var nextTime = System.currentTimeMillis() + delay
override def hasNext = dataIterator.hasNext
override def next = {
val thisTime = math.max(System.currentTimeMillis(), nextTime)
val thisValue = dataIterator.next()
nextTime = thisTime + delay
Future {
val now = System.currentTimeMillis()
if(thisTime > now) Thread.sleep(thisTime - now) //Your implementation will be better
thisValue
}
}
}
Explanation
AsyncIterable uses Stream because it's calculated lazily and it's simple.
Pros:
simplicity
multiple calls to iterator and seq methods return same iterable with all items.
Cons:
could lead to memory overflow, because the stream keeps all previously obtained values.
the first value is fetched eagerly during creation of AsyncIterable
DelayedIterator is a very simplistic implementation of AsyncIterator; don't blame me for the quick and dirty code here.
It still seems strange to me to see a synchronous hasNext and an asynchronous next().
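One way around the synchronous hasNext / asynchronous next mismatch is to model the link to the next element as an Option and fold over the chain with flatMap. The following is a sketch with hypothetical names (`Page`, `AsyncPages`, `foldAsync`, `foreachAsync`), not a drop-in replacement for the question's Data class.

```scala
import scala.concurrent.{ExecutionContext, Future}

// Hypothetical variant of Data: the link to the next page is an Option,
// so there is no separate hasNext to keep in sync with next.
final case class Page(information: Int, next: Option[() => Future[Page]])

object AsyncPages {
  /** Folds over the chain without blocking; a page is requested only after the previous one has arrived. */
  def foldAsync[B](start: Page, zero: B)(f: (B, Page) => B)(
      implicit ec: ExecutionContext): Future[B] = {
    val acc = f(zero, start)
    start.next match {
      case Some(fetch) => fetch().flatMap(page => foldAsync(page, acc)(f))
      case None        => Future.successful(acc)
    }
  }

  /** Runs a side effect on every page, in order. */
  def foreachAsync(start: Page)(f: Page => Unit)(implicit ec: ExecutionContext): Future[Unit] =
    foldAsync(start, ())((_, page) => f(page))
}
```

The returned Future completes once the whole chain has been visited, so callers can compose it further or wait for it only at the edge of the program.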
Using Twitter Spool I've implemented a working example.
To implement spool I modified the example in the documentation.
import com.twitter.concurrent.Spool
import com.twitter.util.{Await, Return, Promise}
import scala.concurrent.{ExecutionContext, Future}
trait AsyncIterable[+T <: AsyncIterable[T]] { self : T =>
def hasNext : Boolean
def next : Future[T]
def spool(implicit ec: ExecutionContext) : Spool[T] = {
def fill(currentPage: Future[T], rest: Promise[Spool[T]]) {
currentPage foreach { cPage =>
if(hasNext) {
val nextSpool = new Promise[Spool[T]]
rest() = Return(cPage *:: nextSpool)
fill(next, nextSpool)
} else {
val emptySpool = new Promise[Spool[T]]
emptySpool() = Return(Spool.empty[T])
rest() = Return(cPage *:: emptySpool)
}
}
}
val rest = new Promise[Spool[T]]
if(hasNext) {
fill(next, rest)
} else {
rest() = Return(Spool.empty[T])
}
self *:: rest
}
}
Data is the same as before, and now we can use it.
// Cool stuff
implicit val ec = scala.concurrent.ExecutionContext.global
val data = Data(1) // And others
// Print all the information asynchronously
val fut = data.spool.foreach(data => println(data.information))
Await.ready(fut)
It will throw an exception on the second element, because the implementation of next was not provided.

Sequencing Scala Futures with bounded parallelism (without messing around with ExecutorContexts)

Background: I have a function:
def doWork(symbol: String): Future[Unit]
which initiates some side-effects to fetch data and store it, and completes a Future when it's done. However, the back-end infrastructure has usage limits, such that no more than 5 of these requests can be made in parallel. I have a list of N symbols that I need to get through:
var symbols = Array("MSFT",...)
but I want to sequence them such that no more than 5 are executing simultaneously. Given:
val allowableParallelism = 5
my current solution is (assuming I'm working with async/await):
val symbolChunks = symbols.toList.grouped(allowableParallelism).toList
def toThunk(x: List[String]) = () => Future.sequence(x.map(doWork))
val symbolThunks = symbolChunks.map(toThunk)
val done = Promise[Unit]()
def procThunks(x: List[() => Future[List[Unit]]]): Unit = x match {
case Nil => done.success(())
case x::xs => x().onComplete(_ => procThunks(xs))
}
procThunks(symbolThunks)
await { done.future }
but, for obvious reasons, I'm not terribly happy with it. I feel like this should be possible with folds, but every time I try, I end up eagerly creating the Futures. I also tried out a version with RxScala Observables, using concatMap, but that also seemed like overkill.
Is there a better way to accomplish this?
I have an example of how to do it with scalaz-stream. It's quite a lot of code, because the Scala Future has to be converted to a scalaz Task (an abstraction for deferred computation). However, this code only needs to be added to the project once. Another option is to use Task for defining 'doWork'. I personally prefer Task for building async programs.
import scala.concurrent.{Future => SFuture}
import scala.util.Random
import scala.concurrent.ExecutionContext.Implicits.global
import scalaz.stream._
import scalaz.concurrent._
val P = scalaz.stream.Process
val rnd = new Random()
def doWork(symbol: String): SFuture[Unit] = SFuture {
Thread.sleep(rnd.nextInt(1000))
println(s"Symbol: $symbol. Thread: ${Thread.currentThread().getName}")
}
val symbols = Seq("AAPL", "MSFT", "GOOGL", "CVX").
flatMap(s => Seq.fill(5)(s).zipWithIndex.map(t => s"${t._1}${t._2}"))
implicit class Transformer[+T](fut: => SFuture[T]) {
def toTask(implicit ec: scala.concurrent.ExecutionContext): Task[T] = {
import scala.util.{Failure, Success}
import scalaz.syntax.either._
Task.async {
register =>
fut.onComplete {
case Success(v) => register(v.right)
case Failure(ex) => register(ex.left)
}
}
}
}
implicit class ConcurrentProcess[O](val process: Process[Task, O]) {
def concurrently[O2](concurrencyLevel: Int)(f: Channel[Task, O, O2]): Process[Task, O2] = {
val actions =
process.
zipWith(f)((data, f) => f(data))
val nestedActions =
actions.map(P.eval)
merge.mergeN(concurrencyLevel)(nestedActions)
}
}
val workChannel = io.channel((s: String) => doWork(s).toTask)
val process = Process.emitAll(symbols).concurrently(5)(workChannel)
process.run.run
Once you have all these transformations in scope, basically all you need is:
val workChannel = io.channel((s: String) => doWork(s).toTask)
val process = Process.emitAll(symbols).concurrently(5)(workChannel)
Quite short and self-describing.
Although you've already got an excellent answer, I thought I might still offer an opinion or two about these matters.
I remember seeing somewhere (on someone's blog) "use actors for state and use futures for concurrency".
So my first thought would be to utilize actors somehow. To be precise, I would have a master actor with a router launching multiple worker actors, with number of workers restrained according to allowableParallelism. So, assuming I have
def doWorkInternal (symbol: String): Unit
which does the work from your doWork taken 'outside of the future', I would have something along these lines (very rudimentary, not taking many details into consideration, and practically copying code from the Akka documentation):
import akka.actor._
import akka.routing.{ActorRefRoutee, RoundRobinRoutingLogic, Router}
case class WorkItem (symbol: String)
case class WorkItemCompleted (symbol: String)
case class WorkLoad (symbols: Array[String])
case class WorkLoadCompleted ()
class Worker extends Actor {
def receive = {
case WorkItem (symbol) =>
doWorkInternal (symbol)
sender () ! WorkItemCompleted (symbol)
}
}
class Master extends Actor {
var pending = Set[String] ()
var originator: Option[ActorRef] = None
var router = {
val routees = Vector.fill (allowableParallelism) {
val r = context.actorOf(Props[Worker])
context watch r
ActorRefRoutee(r)
}
Router (RoundRobinRoutingLogic(), routees)
}
def receive = {
case WorkLoad (symbols) =>
originator = Some (sender ())
context become processing
for (symbol <- symbols) {
router.route (WorkItem (symbol), self)
pending += symbol
}
}
def processing: Receive = {
case Terminated (a) =>
router = router.removeRoutee(a)
val r = context.actorOf(Props[Worker])
context watch r
router = router.addRoutee(r)
case WorkItemCompleted (symbol) =>
pending -= symbol
if (pending.size == 0) {
context become receive
originator.get ! WorkLoadCompleted()
}
}
}
You could query the master actor with ask and receive a WorkLoadCompleted in a future.
But thinking more about 'state' (of number of simultaneous requests in processing) to be hidden somewhere, together with implementing necessary code for not exceeding it, here's something of the 'future gateway intermediary' sort, if you don't mind imperative style and mutable (used internally only though) structures:
object Guardian
{
private val incoming = new collection.mutable.HashMap[String, Promise[Unit]]()
private val outgoing = new collection.mutable.HashMap[String, Future[Unit]]()
private val pending = new collection.mutable.Queue[String]
def doWorkGuarded (symbol: String): Future[Unit] = {
synchronized {
val p = Promise[Unit] ()
incoming(symbol) = p
if (incoming.size <= allowableParallelism)
launchWork (symbol)
else
pending.enqueue (symbol)
p.future
}
}
private def completionHandler (t: Try[Unit]): Unit = {
synchronized {
for (symbol <- outgoing.keySet) {
val f = outgoing (symbol)
if (f.isCompleted) {
incoming (symbol).completeWith (f)
incoming.remove (symbol)
outgoing.remove (symbol)
}
}
for (i <- outgoing.size until allowableParallelism) {
if (pending.nonEmpty) {
val symbol = pending.dequeue()
launchWork (symbol)
}
}
}
}
private def launchWork (symbol: String): Unit = {
val f = doWork(symbol)
outgoing(symbol) = f
f.onComplete(completionHandler)
}
}
doWork now is exactly like yours, returning Future[Unit], with the idea that instead of using something like
val futures = symbols.map (doWork (_)).toSeq
val future = Future.sequence(futures)
which would launch futures not regarding allowableParallelism at all, I would instead use
val futures = symbols.map (Guardian.doWorkGuarded (_)).toSeq
val future = Future.sequence(futures)
Think about some hypothetical database access driver with non-blocking interface, i.e. returning futures on requests, which is limited in concurrency by being built over some connection pool for example - you wouldn't want it to return futures not taking parallelism level into account, and require you to juggle with them to keep parallelism under control.
This example is more illustrative than practical, since I wouldn't normally expect the 'outgoing' interface to use futures like this (which is quite OK for the 'incoming' interface).
First, obviously some purely functional wrapper around Scala's Future is needed, because Future is side-effecting and runs as soon as it can. Let's call it Deferred:
import scala.concurrent.Future
import scala.util.control.Exception.nonFatalCatch
class Deferred[+T](f: () => Future[T]) {
def run(): Future[T] = f()
}
object Deferred {
def apply[T](future: => Future[T]): Deferred[T] =
new Deferred(() => nonFatalCatch.either(future).fold(Future.failed, identity))
}
And here is the routine:
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.immutable.Seq
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.util.control.Exception.nonFatalCatch
import scala.util.{Failure, Success}
trait ConcurrencyUtils {
def runWithBoundedParallelism[T](parallelism: Int = Runtime.getRuntime.availableProcessors())
(operations: Seq[Deferred[T]])
(implicit ec: ExecutionContext): Deferred[Seq[T]] =
if (parallelism > 0) Deferred {
val indexedOps = operations.toIndexedSeq // index for faster access
val promise = Promise[Seq[T]]()
val acc = new CopyOnWriteArrayList[(Int, T)] // concurrent acc
val nextIndex = new AtomicInteger(parallelism) // keep track of the next index atomically
def run(operation: Deferred[T], index: Int): Unit = {
operation.run().onComplete {
case Success(value) =>
acc.add((index, value)) // accumulate result value
if (acc.size == indexedOps.size) { // we've done
import scala.collection.JavaConversions._
// in concurrent setting next line may be called multiple times, that's why trySuccess instead of success
promise.trySuccess(acc.view.sortBy(_._1).map(_._2).toList)
} else {
val next = nextIndex.getAndIncrement() // get and inc atomically
if (next < indexedOps.size) { // run next operation if exists
run(indexedOps(next), next)
}
}
case Failure(t) =>
promise.tryFailure(t) // same here (may be called multiple times, let's prevent stdout pollution)
}
}
if (operations.nonEmpty) {
indexedOps.view.take(parallelism).zipWithIndex.foreach((run _).tupled) // run as much as allowed
promise.future
} else {
Future.successful(Seq.empty)
}
} else {
throw new IllegalArgumentException("Parallelism must be positive")
}
}
In a nutshell, we initially run as many operations as allowed, and then on each operation's completion we run the next available operation, if any. So the only difficulty here is maintaining the next operation index and the results accumulator in a concurrent setting. I'm not an absolute concurrency expert, so let me know if there are potential problems in the code above. Notice that the returned value is also a deferred computation that has to be run.
Usage and test:
import org.scalatest.{Matchers, FlatSpec}
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.time.{Seconds, Span}
import scala.collection.immutable.Seq
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.concurrent.duration._
class ConcurrencyUtilsSpec extends FlatSpec with Matchers with ScalaFutures with ConcurrencyUtils {
"runWithBoundedParallelism" should "return results in correct order" in {
val comp1 = mkDeferredComputation(1)
val comp2 = mkDeferredComputation(2)
val comp3 = mkDeferredComputation(3)
val comp4 = mkDeferredComputation(4)
val comp5 = mkDeferredComputation(5)
val compountComp = runWithBoundedParallelism(2)(Seq(comp1, comp2, comp3, comp4, comp5))
whenReady(compountComp.run()) { result =>
result should be (Seq(1, 2, 3, 4, 5))
}
}
// increase default ScalaTest patience
implicit val defaultPatience = PatienceConfig(timeout = Span(10, Seconds))
private def mkDeferredComputation[T](result: T, sleepDuration: FiniteDuration = 100.millis): Deferred[T] =
Deferred {
Future {
Thread.sleep(sleepDuration.toMillis)
result
}
}
}
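The same "start N operations, then start the next one whenever one completes" idea can also be sketched with a shared queue and N independent lanes, using only the standard library. `BoundedParallelism` and `traverseBounded` are hypothetical names. One behavioural difference worth noting: if one operation fails, the other lanes keep draining the queue even though the overall result fails.

```scala
import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.concurrent.TrieMap
import scala.concurrent.{ExecutionContext, Future}

object BoundedParallelism {
  /** Runs `work` for every item with at most `parallelism` futures in flight; results keep input order. */
  def traverseBounded[A, B](items: Seq[A], parallelism: Int)(work: A => Future[B])(
      implicit ec: ExecutionContext): Future[Vector[B]] = {
    require(parallelism > 0, "parallelism must be positive")

    val queue = new ConcurrentLinkedQueue[(A, Int)]()
    items.zipWithIndex.foreach(pair => queue.add(pair))
    val results = TrieMap.empty[Int, B]

    // A lane takes one item at a time and only starts the next after the current one completed.
    def lane(): Future[Unit] =
      Option(queue.poll()) match {
        case Some((item, index)) =>
          work(item).flatMap { result =>
            results.put(index, result)
            lane()
          }
        case None => Future.unit // queue drained, this lane is done
      }

    // Starting `parallelism` lanes bounds the number of concurrently running operations.
    val lanes = Vector.fill(parallelism)(lane())
    Future.sequence(lanes).map(_ => Vector.tabulate(items.size)(i => results(i)))
  }
}
```

For the original question this would be called as `BoundedParallelism.traverseBounded(symbols.toSeq, allowableParallelism)(doWork)`.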
Use Monix Task. Here is an example from the Monix documentation with parallelism = 10:
val items = 0 until 1000
// The list of all tasks needed for execution
val tasks = items.map(i => Task(i * 2))
// Building batches of 10 tasks to execute in parallel:
val batches = tasks.sliding(10,10).map(b => Task.gather(b))
// Sequencing batches, then flattening the final result
val aggregate = Task.sequence(batches).map(_.flatten.toList)
// Evaluation:
aggregate.foreach(println)
//=> List(0, 2, 4, 6, 8, 10, 12, 14, 16,...
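The batching approach also works with plain Scala futures, as long as each batch is created inside flatMap so that it starts only after the previous batch has finished. This is a sketch with hypothetical names (`Batched`, `traverseInBatches`). As with the original chunked solution, a slow item holds up its whole batch.

```scala
import scala.concurrent.{ExecutionContext, Future}

object Batched {
  /** Runs `work` in consecutive batches of `batchSize`; a batch starts only when the previous one is done. */
  def traverseInBatches[A, B](items: Seq[A], batchSize: Int)(work: A => Future[B])(
      implicit ec: ExecutionContext): Future[Vector[B]] = {
    require(batchSize > 0, "batchSize must be positive")
    items.grouped(batchSize).foldLeft(Future.successful(Vector.empty[B])) { (doneSoFar, batch) =>
      // `work` is invoked inside flatMap, so no future of this batch exists before doneSoFar completes.
      doneSoFar.flatMap { results =>
        Future.traverse(batch)(work).map(results ++ _)
      }
    }
  }
}
```

The fold itself walks over all batches immediately, but it only chains callbacks; the futures are created lazily, which is the part that is easy to get wrong when folding over already-started futures.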
Akka Streams allows you to do the following:
import akka.NotUsed
import akka.stream.Materializer
import akka.stream.scaladsl.Source
import scala.concurrent.Future
def sequence[A: Manifest, B](items: Seq[A], func: A => Future[B], parallelism: Int)(
implicit mat: Materializer
): Future[Seq[B]] = {
val futures: Source[B, NotUsed] =
Source[A](items.toList).mapAsync(parallelism)(x => func(x))
futures.runFold(Seq.empty[B])(_ :+ _)
}
sequence(symbols, doWork, allowableParallelism)