I was wondering if there is a way to execute very simple tasks on another thread in Scala without a lot of overhead.
Basically I would like to make a global 'executor' that can handle executing an arbitrary number of tasks. I can then use the executor to build up additional constructs.
Additionally, it would be nice if clients did not have to think about blocking versus non-blocking behaviour at all.
I know that the Scala actors library is built on top of Doug Lea's fork/join framework, and that it supports, to a limited degree, what I am trying to accomplish. However, from my understanding I would have to pre-allocate an 'Actor Pool' to accomplish this.
I would like to avoid making a global thread pool for this, as from what I understand it is not all that good at fine-grained parallelism.
Here is a simple example:
import concurrent.SyncVar
object SimpleExecutor {
import actors.Actor._
def exec[A](task: => A) : SyncVar[A] = {
//what goes here?
//This is what I currently have
val x = new concurrent.SyncVar[A]
//The overhead of making the actor appears to be a killer
actor {
x.set(task)
}
x
}
//Not really sure what to stick here
def execBlocker[A](task: => A) : SyncVar[A] = exec(task)
}
and now an example of using exec:
object Examples {
//Benchmarks a task
def benchmark(blk : => Unit) = {
val start = System.nanoTime
blk
System.nanoTime - start
}
//Benchmarks and compares 2 tasks
def cmp(a: => Any, b: => Any) = {
val at = benchmark(a)
val bt = benchmark(b)
println(at + " " + bt + " " +at.toDouble / bt)
}
//Simple example for simple non blocking comparison
import SimpleExecutor._
def paraAdd(hi: Int) = (0 until hi) map (i=>exec(i+5)) foreach (_.get)
def singAdd(hi: Int) = (0 until hi) foreach (i=>i+5)
//Simple example for the blocking performance
import Thread.sleep
def paraSle(hi : Int) = (0 until hi) map (i=>exec(sleep(i))) foreach (_.get)
def singSle(hi : Int) = (0 until hi) foreach (i=>sleep(i))
}
Finally to run the examples (might want to do it a few times so HotSpot can warm up):
import Examples._
cmp(paraAdd(10000), singAdd(10000))
cmp(paraSle(100), singSle(100))
That's what Futures were made for. Just import scala.actors.Futures._ and use future to create new futures, methods like awaitAll to wait on the results for a while, apply or respond to block until the result is received, isSet to see whether it's ready or not, etc.
You don't need to create a thread pool either. Or, at least, normally you don't. Why do you think you do?
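A minimal sketch of that API (assuming the old scala.actors.Futures, which has since been deprecated; exact behaviour varies across Scala 2.x versions):
import scala.actors.Futures._

// Schedule a task on the actors' internal thread pool; a Future is returned.
val answer = future { 21 + 21 }

println(answer.isSet)   // non-blocking check: is the result available yet?
awaitAll(1000L, answer) // wait up to 1000 ms for all the listed futures
println(answer())       // apply() blocks until the result is available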
EDIT
You can't gain performance by parallelizing something as simple as an integer addition, because an addition is even faster than a function call. Concurrency will only bring performance gains by avoiding time lost to blocking I/O and by using multiple CPU cores to execute tasks in parallel. In the latter case, the task must be computationally expensive enough to offset the cost of dividing the workload and merging the results.
One other reason to go for concurrency is to improve the responsiveness of the application. That's not making it faster, that's making it respond faster to the user, and one way of doing that is getting even relatively fast operations offloaded to another thread so that the threads handling what the user sees or does can be faster. But I digress.
There's a serious problem with your code:
def paraAdd(hi: Int) = (0 until hi) map (i=>exec(i+5)) foreach (_.get)
def singAdd(hi: Int) = (0 until hi) foreach (i=>i+5)
Or, translating into futures,
def paraAdd(hi: Int) = (0 until hi) map (i=>future(i+5)) foreach (_.apply)
def singAdd(hi: Int) = (0 until hi) foreach (i=>i+5)
You might think paraAdd is doing the tasks in parallel, but it isn't, because Range has a non-strict implementation of map (that's up to Scala 2.7; starting with Scala 2.8.0, Range is strict). You can look it up in other Scala questions. What happens is this:
A range is created from 0 until hi
A range projection is created from each element i of the range into a function that returns future(i+5) when called.
For each element of the range projection (i => future(i+5)), the element is evaluated (foreach is strict) and then the function apply is called on it.
So, because future is not called in step 2, but only in step 3, you'll wait for each future to complete before doing the next one. You can fix it with:
def paraAdd(hi: Int) = (0 until hi).force map (i=>future(i+5)) foreach (_.apply)
Which will give you better performance, but never as good as a simple immediate addition. On the other hand, suppose you do this:
def repeat(n: Int, f: => Any) = (0 until n) foreach (_ => f)
def paraRepeat(n: Int, f: => Any) =
(0 until n).force map (_ => future(f)) foreach (_.apply)
And then compare:
cmp(repeat(100, singAdd(100000)), paraRepeat(100, singAdd(100000)))
You may start seeing gains (it will depend on the number of cores and processor speed).
Related
I am trying to write a Stream in Scala and I do not understand why it keeps some intermediate objects in memory (and eventually runs out of memory).
Here is a simplified version of my code:
val slols = {
def genLols (curr:Vector[Int]) :Stream[Int] = 0 #:: genLols(curr map (_ + 1))
genLols(List.fill(1000 * 1000)(0).toVector)
}
println(slols(1000))
This seems to keep the intermediate currs in memory, and I do not see why.
Here is the same code written with an iterator (much better memory consumption):
val ilols = new Iterator [Int] {
var curr = List.fill(1000 * 1000)(0).toVector
def hasNext = true
def next () :Int = { curr = curr map (_ + 1) ; 0 }
}
val silols = ilols.toStream
println(silols(1000))
EDIT: I do want to keep the 0s in memory; my goal is to not keep the currs, because they are just used to take one step of the computation (it may not be obvious in my simplified example). And the 0s alone cannot cause an out-of-memory error (they are not that heavy to store).
When you assign a stream to a val you prevent its head from being garbage collected.
In the first case that means that the references to the curr vectors from every iteration are retained.
In the second case the situation is better, as the iterator stores only a single instance of curr at a time. But still, all the Int elements of the stream are memoized.
Try replacing val slols with def slols if you don't need memoization.
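A minimal sketch of that change, keeping the rest of the example as it is:
// With def instead of val, the Stream is rebuilt on each access, so no
// long-lived reference to its head (and to the captured currs) is kept
// between uses.
def slols: Stream[Int] = {
  def genLols(curr: Vector[Int]): Stream[Int] = 0 #:: genLols(curr map (_ + 1))
  genLols(List.fill(1000 * 1000)(0).toVector)
}
println(slols(1000))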
And by the way, your two examples are not equivalent in terms of logic (probably you're aware of it).
Check this article for details.
How do I compute the factorial using Scala actors?
And would it prove more time-efficient than, for instance,
def factorial(n: Int): BigInt = (BigInt(1) to BigInt(n)).par.product
Many Thanks.
Problem
You have to split up your input into partial products. These partial products can then be calculated in parallel. The partial products are then multiplied to get the final product.
This can be reduced to a broader class of problems: the so-called parallel prefix calculation. You can read up on it on Wikipedia.
Short version: when you calculate a*b*c*d with an associative operation _ * _, you can structure the calculation as a*(b*(c*d)) or as (a*b)*(c*d). With the second approach, you can calculate a*b and c*d in parallel and then compute the final result from these partial results. Of course you can do this recursively when you have a larger number of input values.
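Before the actor version below, here is a small sketch of that recursive split-and-combine idea using plain futures from the standard library (my own illustration; parPrefixReduce and threshold are made-up names, not part of the solution that follows):
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

// Split the input in half, compute each half concurrently, then combine the
// two partial results with the same associative operation.
def parPrefixReduce[T](xs: IndexedSeq[T], threshold: Int)(op: (T, T) => T): Future[T] =
  if (xs.length <= threshold) Future(xs.reduceLeft(op))
  else {
    val (left, right) = xs.splitAt(xs.length / 2)
    val leftF  = parPrefixReduce(left, threshold)(op)
    val rightF = parPrefixReduce(right, threshold)(op)
    leftF.zip(rightF).map { case (l, r) => op(l, r) }
  }

// e.g. the factorial of 100:
// Await.result(parPrefixReduce((BigInt(1) to BigInt(100)).toIndexedSeq, 10)(_ * _), Duration.Inf)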
Solution
Disclaimer
This sounds a little bit like a homework assignment. So I will provide a solution that has two properties:
It contains a small bug
It shows how to solve parallel prefix in general, without solving the problem directly
So you can see how the solution should be structured, but no one can use it to cheat on her homework.
Solution in detail
First I need a few imports
import akka.event.Logging
import java.util.concurrent.TimeUnit
import scala.concurrent.duration.FiniteDuration
import akka.actor._
Then I create some helper classes for the communication between the actors
case class Calculate[T](values : Seq[T], index : Int, parallelLimit : Int, fn : (T,T) => T)
trait CalculateResponse
case class CalculationResult[T](result : T, index : Int) extends CalculateResponse
case object Busy extends CalculateResponse
Instead of telling the receiver that you are busy, the actor could also use the stash or implement its own queue for partial results. But in this case I think the sender should decide how many parallel calculations are allowed.
Now I create the actor:
class ParallelPrefixActor[T] extends Actor {
val log = Logging(context.system, this)
val subCalculation = Props(classOf[ParallelPrefixActor[BigInt]])
val fanOut = 2
def receive = waitForCalculation
def waitForCalculation : Actor.Receive = {
case c : Calculate[T] =>
log.debug(s"Start calculation for ${c.values.length} values, segment nr. ${c.index}, from ${c.values.head} to ${c.values.last}")
if (c.values.length < c.parallelLimit) {
log.debug("Calculating result direct")
val result = c.values.reduceLeft(c.fn)
sender ! CalculationResult(result, c.index)
}else{
val groupSize: Int = Math.max(1, (c.values.length / fanOut) + Math.min(c.values.length % fanOut, 1))
log.debug(s"Splitting calculation for ${c.values.length} values up to ${fanOut} children, ${groupSize} elements each, limit ${c.parallelLimit}")
def segments=c.values.grouped(groupSize)
log.debug("Starting children")
segments.zipWithIndex.foreach{case (values, index) =>
context.actorOf(subCalculation) ! c.copy(values = values, index = index)
}
val partialResults: Vector[T] = segments.map(_.head).to[Vector]
log.debug(s"Waiting for ${partialResults.length} results (${partialResults.indices})")
context.become(waitForResults(segments.length, partialResults, c, sender), discardOld = true)
}
}
def waitForResults(outstandingResults : Int, partialResults : Vector[T], originalRequest : Calculate[T], originalSender : ActorRef) : Actor.Receive = {
case c : Calculate[_] => sender ! Busy
case r : CalculationResult[T] =>
log.debug(s"Putting result ${r.result} on position ${r.index} in ${partialResults.length}")
val updatedResults = partialResults.updated(r.index, r.result)
log.debug("Killing sub-worker")
sender ! PoisonPill
if (outstandingResults==1) {
log.debug("Calculating result from partial results")
val result = updatedResults.reduceLeft(originalRequest.fn)
originalSender ! CalculationResult(result, originalRequest.index)
context.become(waitForCalculation, discardOld = true)
}else{
log.debug(s"Still waiting for ${outstandingResults-1} results")
// For fanOut > 2 one could here already combine consecutive partial results
context.become(waitForResults(outstandingResults-1, updatedResults, originalRequest, originalSender), discardOld = true)
}
}
}
Optimizations
Using parallel prefix calculation is not optimal here. The actors calculating the product of the bigger numbers will do much more work than the actors calculating the product of the smaller numbers (e.g. when calculating 1 * ... * 100, it is faster to calculate 1 * ... * 10 than 90 * ... * 100). So it might be a good idea to shuffle the numbers, so that big numbers are mixed with small numbers. This works in this case because we use a commutative operation; parallel prefix calculation in general only needs an associative operation to work.
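For example, one might shuffle the input before handing it to the calculation (a small sketch; only valid because multiplication is commutative as well as associative):
import scala.util.Random

// Mix big and small numbers so that each segment gets roughly equal work.
val orderedInput  = (1 to 10000).map(BigInt(_))
val balancedInput = Random.shuffle(orderedInput)
// balancedInput can then be passed to Calculate instead of the ordered range.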
Performance
In theory
Performance of the actor solution is worse than the "naive" solution (using parallel collections) for small amounts of data. The actor solution will shine when you perform complex calculations or distribute your calculation to specialized hardware (e.g. a graphics card or FPGA) or across multiple machines. With the actors you can control who does which calculation, and you can even restart "hanging" calculations. This can give a big speed-up.
On a single machine, the actor solution might help when you have a non-uniform memory architecture. You could then organize the actors in a way that pins memory to a certain processor.
Some measurement
I did some real performance measurement using a Scala worksheet in IntelliJ IDEA.
First I set up the actor system:
// Setup the actor system
val system = ActorSystem("root")
// Start one calculation actor
val calculationStart = Props(classOf[ParallelPrefixActor[BigInt]])
val calcolon = system.actorOf(calculationStart, "Calcolon-BigInt")
val inbox = Inbox.create(system)
Then I defined a helper method to measure time:
// Helper function to measure time
def time[A] (id : String)(f: => A) = {
val start = System.nanoTime()
val result = f
val stop = System.nanoTime()
println(s"""Time for "${id}": ${(stop-start)*1e-6d}ms""")
result
}
And then I did some performance measurement:
// Test code
val limit = 10000
def testRange = (1 to limit).map(BigInt(_))
time("par product")(testRange.par.product)
val timeOut = FiniteDuration(240, TimeUnit.SECONDS)
inbox.send(calcolon, Calculate[BigInt]((1 to limit).map(BigInt(_)), 0, 10, _ * _))
time("actor product")(inbox.receive(timeOut))
time("par sum")(testRange.par.sum)
inbox.send(calcolon, Calculate[BigInt](testRange, 0, 5, _ + _))
time("actor sum")(inbox.receive(timeOut))
I got the following results
> Time for "par product": 134.38289ms
res0: scala.math.BigInt = 284625968091705451890641321211986889014805140170279923
079417999427441134000376444377299078675778477581588406214231752883004233994015
351873905242116138271617481982419982759241828925978789812425312059465996259867
065601615720360323979263287367170557419759620994797203461536981198970926112775
004841988454104755446424421365733030767036288258035489674611170973695786036701
910715127305872810411586405612811653853259684258259955846881464304255898366493
170592517172042765974074461334000541940524623034368691540594040662278282483715
120383221786446271838229238996389928272218797024593876938030946273322925705554
596900278752822425443480211275590191694254290289169072190970836905398737474524
833728995218023632827412170402680867692104515558405671725553720158521328290342
799898184493136...
Time for "actor product": 1310.217247ms
res2: Any = CalculationResult(28462596809170545189064132121198688901480514017027
992307941799942744113400037644437729907867577847758158840621423175288300423399
401535187390524211613827161748198241998275924182892597878981242531205946599625
986706560161572036032397926328736717055741975962099479720346153698119897092611
277500484198845410475544642442136573303076703628825803548967461117097369578603
670191071512730587281041158640561281165385325968425825995584688146430425589836
649317059251717204276597407446133400054194052462303436869154059404066227828248
371512038322178644627183822923899638992827221879702459387693803094627332292570
555459690027875282242544348021127559019169425429028916907219097083690539873747
452483372899521802363282741217040268086769210451555840567172555372015852132829
034279989818449...
> Time for "par sum": 6.488620999999999ms
res3: scala.math.BigInt = 50005000
> Time for "actor sum": 657.752832ms
res5: Any = CalculationResult(50005000,0)
You can easily see that the actor version is much slower than using parallel collections.
I have a list of possible input Values
val inputValues = List(1,2,3,4,5)
I have a really long to compute function that gives me a result
def reallyLongFunction( input: Int ) : Option[String] = { ..... }
Using scala parallel collections, I can easily do
inputValues.par.map( reallyLongFunction( _ ) )
to get all the results, in parallel. The problem is, I don't really want all the results; I only want the FIRST result. As soon as one of my inputs is a success, I want my output and I want to move on with my life. But this does a lot of extra work.
So how do I get the best of both worlds? I want to
Get the first result that returns something from my long function
Stop all my other threads from doing useless work.
Edit -
I solved it like a dumb Java programmer by having
@volatile var done = false;
which is set and checked inside my reallyLongFunction. This works, but does not feel very Scala. I would like a better way to do this.
(Update: no, it doesn't work; it doesn't do the map.)
Would it work to do something like:
inputValues.par.find({ v => reallyLongFunction(v); true })
The implementation uses this:
protected[this] class Find[U >: T](pred: T => Boolean, protected[this] val pit: IterableSplitter[T]) extends Accessor[Option[U], Find[U]] {
@volatile var result: Option[U] = None
def leaf(prev: Option[Option[U]]) = { if (!pit.isAborted) result = pit.find(pred); if (result != None) pit.abort }
protected[this] def newSubtask(p: IterableSplitter[T]) = new Find(pred, p)
override def merge(that: Find[U]) = if (this.result == None) result = that.result
}
which looks pretty similar in spirit to your @volatile except you don't have to look at it ;-)
I interpreted your question in the same way as huynhjl, but if you just want to search and discard the Nones, you could do something like this to avoid repeating the computation when a suitable outcome is found:
class Computation[A,B](value: A, function: A => B) {
lazy val result = function(value)
}
def f(x: Int) = { // your function here
Thread.sleep(100 - x)
if (x > 5) Some(x * 10)
else None
}
val list = List.range(1, 20) map (i => new Computation(i, f))
val found = list.par find (_.result.isDefined)
//found is Option[Computation[Int,Option[Int]]]
val result = found map (_.result.get)
//result is Option[Int]
However find for parallel collections seems to do a lot of unnecessary work (see this question), so this might not work well, with current versions of Scala at least.
Volatile flags are used in the parallel collections (take a look at the source for find, exists, and forall), so I think your idea is a good one. It's actually better if you can include the flag in the function itself. It kills referential transparency on your function (i.e. for certain inputs your function now sometimes returns None rather than Some), but since you're discarding the stopped computations, this shouldn't matter.
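A sketch of that idea (my own illustration; StoppableSearch is a made-up name):
// Wrap the expensive function together with a shared @volatile flag: once any
// thread has found a result, later calls return None immediately.
class StoppableSearch(f: Int => Option[String]) {
  @volatile private var done = false
  def apply(i: Int): Option[String] =
    if (done) None
    else f(i) match {
      case found @ Some(_) => done = true; found
      case None            => None
    }
}

// Usage: the map still visits every input, but the stopped calls are cheap.
// val search = new StoppableSearch(reallyLongFunction)
// inputValues.par.map(search(_)).flatten.headOption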
If you're willing to use a non-core library, I think Futures would be a good match for this task. For instance:
Akka's Futures include Futures.firstCompletedOf
Twitter's Futures include Future.select
...both of which appear to enable the functionality you're looking for.
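For comparison, here is a sketch of the same idea with the standard library's scala.concurrent futures (an addition of mine, not one of the libraries named above); note that the losing computations are not cancelled, their results are merely ignored:
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// Future.find completes with the first result satisfying the predicate.
def firstDefined(inputs: List[Int])(f: Int => Option[String]): Option[String] = {
  val attempts = inputs.map(i => Future(f(i)))
  Await.result(Future.find(attempts)(_.isDefined), 1.minute).flatten
}

// firstDefined(inputValues)(reallyLongFunction)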
Given a very large instance of collection.parallel.mutable.ParHashMap (or any other parallel collection), how can one abort a filtering parallel scan once a given number of matches, say 50, has been found?
Attempting to accumulate intermediate matches in a thread-safe "external" data structure or keeping an external AtomicInteger with result count seems to be 2 to 3 times slower on 4 cores than using a regular collection.mutable.HashMap and pegging a single core at 100%.
I am aware that find or exists on Par* collections do abort "on the inside". Is there a way to generalize this to find more than one result?
Here's the code, which still seems to be 2 to 3 times slower on a ParHashMap with ~79,000 entries and also has the problem of stuffing more than maxResults results into the results CHM (which is probably due to a thread being preempted after incrementAndGet but before break, allowing other threads to add more elements). Update: it seems the slowdown is due to worker threads contending on counter.incrementAndGet(), which of course defeats the purpose of the whole parallel scan :-(
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

def find(filter: Node => Boolean, maxResults: Int): Iterable[Node] =
{
val counter = new AtomicInteger(0)
val results = new ConcurrentHashMap[Key, Node](maxResults)
import util.control.Breaks._
breakable
{
for ((key, node) <- parHashMap if filter(node))
{
results.put(key, node)
val total = counter.incrementAndGet()
if (total > maxResults) break
}
}
results.values.toArray(new Array[Node](results.size))
}
I would first do a parallel scan in which maxResults would be thread-local. This would find up to (maxResults * numberOfThreads) results.
Then I would do a single-threaded pass to reduce it to maxResults.
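A rough sketch of what I mean (a hypothetical helper of mine, untested, assuming Scala 2.x parallel collections; each chunk keeps at most maxResults matches, and the final take trims the concatenation single-threadedly):
def findUpTo[K, V](m: Map[K, V], filter: V => Boolean, maxResults: Int): Seq[V] = {
  val chunkSize = math.max(1, m.size / Runtime.getRuntime.availableProcessors)
  val perChunk = m.toSeq.grouped(chunkSize).toSeq.par.map { chunk =>
    // the lazy view stops evaluating once maxResults matches are found
    chunk.view.map(_._2).filter(filter).take(maxResults).force
  }
  perChunk.seq.flatten.take(maxResults)
}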
I performed an interesting investigation of your case.
Investigation reasoning
I suspected the problem was the mutability of the input Map, and I will try to explain why: the HashMap implementation organizes the data into different buckets, as one can see on Wikipedia.
The first thread-safe collections in Java, the synchronized collections, were based on synchronizing all the methods around the underlying implementation and resulted in poor performance. Further research and thinking led to the more performant concurrent collections, such as ConcurrentHashMap, whose approach is smarter: why not protect each bucket with its own lock?
My feeling was that the performance problem occurs because:
when you run your filter in parallel, some threads conflict on accessing the same bucket at once and hit the same lock, because your map is mutable;
you hold a counter to see how many results you have, while you could actually check the size of your result. If you have a thread-safe way to build a collection, you don't need a thread-safe counter too.
Investigation result
I developed a test case and found out that I was wrong. The problem is with the concurrent nature of the output map. In fact, that is where the collision occurs: when you are putting elements into the map, rather than when you are iterating over it. Additionally, since you only want the values as the result, you don't need the keys, the hashing, or the other map features. It might be interesting to test how your version performs if you remove the AtomicInteger and use only the result map to check whether you have collected enough elements.
Please note that the following code is written for Scala 2.9.2. I explain in another post why I need two different functions for the parallel and the non-parallel version: Calling map on a parallel collection via a reference to an ancestor type
// imports assumed for the aliases and types used below
import scala.collection.immutable.VectorBuilder
import scala.collection.parallel.ParMap
import scala.collection.parallel.immutable.{ParHashMap => ImmutableParHashMap}
import scala.collection.parallel.mutable.{ParHashMap => MutableParHashMap}

object MapPerformance {
val size = 100000
val items = Seq.tabulate(size)( x => (x,x*2))
val concurrentParallelMap = ImmutableParHashMap(items:_*)
val concurrentMutableParallelMap = MutableParHashMap(items:_*)
val unparallelMap = Map(items:_*)
class ThreadSafeIndexedSeqBuilder[T](maxSize:Int) {
val underlyingBuilder = new VectorBuilder[T]()
var counter = 0
def sizeHint(hint:Int) { underlyingBuilder.sizeHint(hint) }
def +=(item:T):Boolean ={
synchronized{
if(counter>=maxSize)
false
else{
underlyingBuilder+=item
counter+=1
true
}
}
}
def result():Vector[T] = underlyingBuilder.result()
}
def find(map:ParMap[Int,Int],filter: Int => Boolean, maxResults: Int): Iterable[Int] =
{
// we already know the maximum size
val resultsBuilder = new ThreadSafeIndexedSeqBuilder[Int](maxResults)
resultsBuilder.sizeHint(maxResults)
import util.control.Breaks._
breakable
{
for ((key, node) <- map if filter(node))
{
val newItemAdded = resultsBuilder+=node
if (!newItemAdded)
break()
}
}
resultsBuilder.result().seq
}
def findUnParallel(map:Map[Int,Int],filter: Int => Boolean, maxResults: Int): Iterable[Int] =
{
// we already know the maximum size
val resultsBuilder = Array.newBuilder[Int]
resultsBuilder.sizeHint(maxResults)
var counter = 0
for {
(key, node) <- map if filter(node)
if counter < maxResults
}{
resultsBuilder+=node
counter+=1
}
resultsBuilder.result()
}
def measureTime[K](f: => K):(Long,K) = {
val startMutable = System.currentTimeMillis()
val result = f
val endMutable = System.currentTimeMillis()
(endMutable-startMutable,result)
}
def main(args:Array[String]) = {
val maxResultSetting=10
(1 to 10).foreach{
tryNumber =>
println("Try number " +tryNumber)
val (mutableTime, mutableResult) = measureTime(find(concurrentMutableParallelMap,_%2==0,maxResultSetting))
val (immutableTime, immutableResult) = measureTime(find(concurrentParallelMap,_%2==0,maxResultSetting))
val (unparallelTime, unparallelResult) = measureTime(findUnParallel(unparallelMap,_%2==0,maxResultSetting))
assert(mutableResult.size==maxResultSetting)
assert(immutableResult.size==maxResultSetting)
assert(unparallelResult.size==maxResultSetting)
println(" The mutable version has taken " + mutableTime + " milliseconds")
println(" The immutable version has taken " + immutableTime + " milliseconds")
println(" The unparallel version has taken " + unparallelTime + " milliseconds")
}
}
}
With this code, the parallel version (with both the mutable and the immutable input map) is consistently about 3.5 times faster than the non-parallel one on my machine.
You could try to get an iterator and then create a lazy list (a Stream) where you filter (with your predicate) and take the number of elements you want. Because a Stream is non-strict, this 'taking' of elements is not evaluated immediately.
Afterwards you can force the execution by adding ".par" to the whole thing and achieve parallelization.
Example code:
A parallelized map with random values (simulating your parallel hash map):
scala> myMap
res14: scala.collection.parallel.immutable.ParMap[Int,Int] = ParMap(66978401 -> -1331298976, 256964068 -> 126442706, 1698061835 -> 1622679396, -1556333580 -> -1737927220, 791194343 -> -591951714, -1907806173 -> 365922424, 1970481797 -> 162004380, -475841243 -> -445098544, -33856724 -> -1418863050, 1851826878 -> 64176692, 1797820893 -> 405915272, -1838192182 -> 1152824098, 1028423518 -> -2124589278, -670924872 -> 1056679706, 1530917115 -> 1265988738, -808655189 -> -1742792788, 873935965 -> 733748120, -1026980400 -> -163182914, 576661388 -> 900607992, -1950678599 -> -731236098)
Get an iterator and create a Stream from the iterator and filter it.
In this case my predicate only accepts pairs whose value member is even.
I want to get 10 even elements, so I take 10 elements which will only get evaluated when I force it to:
scala> val mapIterator = myMap.toIterator
mapIterator: Iterator[(Int, Int)] = HashTrieIterator(20)
scala> val r = Stream.continually(mapIterator.next()).filter(_._2 % 2 == 0).take(10)
r: scala.collection.immutable.Stream[(Int, Int)] = Stream((66978401,-1331298976), ?)
Finally, I force the evaluation which only gets 10 elements as planned
scala> r.force
res16: scala.collection.immutable.Stream[(Int, Int)] = Stream((66978401,-1331298976), (256964068,126442706), (1698061835,1622679396), (-1556333580,-1737927220), (791194343,-591951714), (-1907806173,365922424), (1970481797,162004380), (-475841243,-445098544), (-33856724,-1418863050), (1851826878,64176692))
This way you only get the number of elements you want (without needing to process the remaining elements) and you parallelize the process without locks, atomics or breaks.
Please compare this to your solutions to see if it is any good.
Let's say we have a list of states and we want to sequence them:
import cats.data.State
import cats.instances.list._
import cats.syntax.traverse._
trait MachineState
case object ContinueRunning extends MachineState
case object StopRunning extends MachineState
case class Machine(candy: Int)
val addCandy: Int => State[Machine, MachineState] = amount =>
State[Machine, MachineState] { machine =>
val newCandyAmount = machine.candy + amount
if(newCandyAmount > 10)
(machine, StopRunning)
else
(machine.copy(newCandyAmount), ContinueRunning)
}
List(addCandy(1),
addCandy(2),
addCandy(5),
addCandy(10),
addCandy(20),
addCandy(50)).sequence.run(Machine(0)).value
Result would be
(Machine(10),List(ContinueRunning, ContinueRunning, ContinueRunning, StopRunning, StopRunning, StopRunning))
It's obvious that the last 3 steps are redundant. Is there a way to make this sequence stop early? Here, when StopRunning gets returned I would like to stop. For example, a list of Eithers would fail fast and stop the sequence early if needed (because it acts like a monad).
For the record, I do know that it is possible to simply write a tail recursion that checks each state as it is run and stops the recursion if some condition is satisfied. I just want to know if there is a more elegant way of doing this. The recursion solution seems like a lot of boilerplate to me; am I wrong?
Thank you!:))
There are two things that need to be done here.
The first is understanding what is actually happening:
State takes some state value, threads it between many composed calls, and in the process produces some output value as well
in your case Machine is the state threaded between calls, while MachineState is the output of a single operation
sequence (usually) takes a collection (here a List) of some parametric stuff, here State[Machine, _], and turns the nesting inside out (here: List[State[Machine, _]] -> State[Machine, List[_]]; _ is the gap that you'll be filling with your type)
the result is that you'll thread the state (Machine(0)) through all the functions, while combining the output of each of them (MachineState) into a list of outputs
// ammonite
// to better see how many times things are being run
@ {
val addCandy: Int => State[Machine, MachineState] = amount =>
State[Machine, MachineState] { machine =>
val newCandyAmount = machine.candy + amount
println("new attempt with " + machine + " and " + amount)
if(newCandyAmount > 10)
(machine, StopRunning)
else
(machine.copy(newCandyAmount), ContinueRunning)
}
}
addCandy: Int => State[Machine, MachineState] = ammonite.$sess.cmd24$$$Lambda$2669/1733815710@25c887ca
@ List(addCandy(1),
addCandy(2),
addCandy(5),
addCandy(10),
addCandy(20),
addCandy(50)).sequence.run(Machine(0)).value
new attempt with Machine(0) and 1
new attempt with Machine(1) and 2
new attempt with Machine(3) and 5
new attempt with Machine(8) and 10
new attempt with Machine(8) and 20
new attempt with Machine(8) and 50
res25: (Machine, List[MachineState]) = (Machine(8), List(ContinueRunning, ContinueRunning, ContinueRunning, StopRunning, StopRunning, StopRunning))
In other words, if what you want is circuit breaking, then .sequence might not be what you want.
As a matter of fact, you probably want something else: to combine a list of A => (A, B) functions into one function which stops the next computation if the result of the current one is StopRunning (in your code nothing tells the code what the circuit-break condition is or how it should be applied). I would suggest doing it explicitly with some other function, e.g.:
@ {
List(addCandy(1),
addCandy(2),
addCandy(5),
addCandy(10),
addCandy(20),
addCandy(50))
.reduce { (a, b) =>
a.flatMap {
// flatMap and map uses MachineState
// - the second parameter is the result after all!
// we are pattern matching on it to decide if we want to
// proceed with computation or stop it
case ContinueRunning => b // runs next computation
case StopRunning => State.pure(StopRunning) // returns current result without modifying it
}
}
.run(Machine(0))
.value
}
new attempt with Machine(0) and 1
new attempt with Machine(1) and 2
new attempt with Machine(3) and 5
new attempt with Machine(8) and 10
res23: (Machine, MachineState) = (Machine(8), StopRunning)
This will eliminate the need for running the code within addCandy, but you cannot really get rid of the code that combines states together, so this reduce logic will be applied at runtime n-1 times (where n is the size of your list), and that cannot be helped.
By the way, if you take a closer look at Either you will find that it also computes n results and only then combines them, so that it looks like it's circuit breaking, but in fact it isn't. sequence combines the results of "parallel" computations but won't interrupt them if any of them fails.