Free ~> Trampoline : recursive program crashes with OutOfMemoryError - scala

Suppose that I'm trying to implement a very simple domain specific language with only one operation:
printLine(line)
Then I want to write a program that takes an integer n as input, prints something if n is divisible by 10k, and then calls itself with n + 1, until n reaches some maximum value N.
Omitting all syntactic noise caused by for-comprehensions, what I want is:
@annotation.tailrec def p(n: Int): Unit = {
if (n % 10000 == 0) printLine("line")
if (n > N) () else p(n + 1)
}
Essentially, it would be a kind of "fizzbuzz".
Here are a few attempts to implement this using the Free monad from Scalaz 7.3.0-M7:
import scalaz._
object Demo1 {
// define operations of a little domain specific language
sealed trait Lang[X]
case class PrintLine(line: String) extends Lang[Unit]
// define the domain specific language as the free monad of operations
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
// lift operations into the free monad
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
// write a program that is just a loop that prints current index
// after every few iteration steps
val mod = 100000
val N = 1000000
// straightforward syntax: deadly slow, exits with OutOfMemoryError
def p0(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- (if (i > N) ret else p0(i + 1))
} yield ()
// Same as above, but written out without `for`
def p1(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
}
// Same as above, with `map` attached to recursive call
def p2(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p2(i + 1).map{ ignore2 => () })
}
// Same as above, but without the `map`; performs ok.
def p3(i: Int): Prog[Unit] = {
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
if (i > N) ret else p3(i + 1)
}
}
// Variation of the above; Ok.
def p4(i: Int): Prog[Unit] = (for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
} yield ()).flatMap{ ignored2 =>
if (i > N) ret else p4(i + 1)
}
// try to use the variable returned by the last generator after yield,
// hope that the final `map` is optimized away (it's not optimized away...)
def p5(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
stopHere <- (if (i > N) ret else p5(i + 1))
} yield stopHere
// define an interpreter that translates the programs into Trampoline
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case PrintLine(l) => Trampoline.delay(println(l))
}
}
// try it out
def main(args: Array[String]): Unit = {
println("\n p0")
p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p1")
p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p2")
p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p3")
p3(0).foldMap(interpreter).run // ok
println("\n p4")
p4(0).foldMap(interpreter).run // ok
println("\n p5")
p5(0).foldMap(interpreter).run // OutOfMemory
}
}
Unfortunately, the straightforward translation (p0) seems to run with some kind of O(N^2) overhead, and crashes with an OutOfMemoryError. The problem seems to be that the for-comprehension appends a map{x => ()} after the recursive call to p0, which forces the Free monad to fill the entire memory with reminders to "finish 'p0' and then do nothing".
If I manually "unroll" the for comprehension, and write out the last flatMap explicitly (as in p3 and p4), then the problem goes away, and everything runs smoothly. This, however, is an extremely brittle workaround: the behavior of the program changes dramatically if we simply append a map(id) to it, and this map(id) isn't even visible in the code, because it is generated automatically by the for-comprehension.
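To make the hidden map visible, here is a minimal sketch with a toy Box type (a hypothetical stand-in, not Scalaz's Free) that counts how often map is called. As far as I understand the Scala 2 desugaring, `for { _ <- a; _ <- b } yield ()` has the shape of `desugared` below, while `manual` is the hand-unrolled form in the spirit of p3:

```scala
object DesugarSketch {
  // Counts calls to Box.map; only here to make the trailing map observable.
  var mapCalls = 0

  // Toy container, a hypothetical stand-in for Free.
  final case class Box[A](value: A) {
    def flatMap[B](f: A => Box[B]): Box[B] = f(value)
    def map[B](f: A => B): Box[B] = { mapCalls += 1; Box(f(value)) }
  }

  // Shape of what the for-comprehension with `yield ()` turns into:
  // the last generator is followed by a map.
  def desugared: Box[Unit] =
    Box(1).flatMap(_ => Box(2).map(_ => ()))

  // Hand-written variant: the last step is a plain flatMap, no map at all.
  def manual: Box[Unit] =
    Box(1).flatMap(_ => Box(2).flatMap(_ => Box(())))
}
```

Both values are equal, but only `desugared` goes through map. In a recursive program that extra step is what sits after the recursive call.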
In this older post here: https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/
it has been repeatedly advised to wrap recursive calls into a suspend. Here is an attempt with Applicative instance and suspend:
import scalaz._
// Essentially same as in `Demo1`, but this time with
// an `Applicative` and an explicit `Suspend` in the
// `for`-comprehension
object Demo2 {
sealed trait Lang[H]
case class Const[H](h: H) extends Lang[H]
case class PrintLine[H](line: String) extends Lang[H]
implicit object Lang extends Applicative[Lang] {
def point[A](a: => A): Lang[A] = Const(a)
def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
case Const(x) => {
f match {
case Const(ab) => Const(ab(x))
case _ => throw new Error
}
}
case PrintLine(l) => PrintLine(l)
}
}
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
val mod = 100000
val N = 2000000
// try to suspend the entire second generator
def p7(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- Free.suspend(if (i > N) ret else p7(i + 1))
} yield ()
// try to suspend the recursive call
def p8(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- if (i > N) ret else Free.suspend(p8(i + 1))
} yield ()
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case Const(x) => Trampoline.done(x)
case PrintLine(l) =>
(Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
}
}
def main(args: Array[String]): Unit = {
p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
p8(0).foldMap(interpreter).run // same...
}
}
Inserting suspend did not really help: it's still slow, and crashes with OutOfMemoryErrors.
Should I use the suspend somehow differently?
Maybe there is some purely syntactic remedy that makes it possible to use for-comprehensions without generating the map in the end?
I'd really appreciate if someone could point out what I'm doing wrong here, and how to repair it.

That superfluous map added by the Scala compiler moves the recursion from tail position to non-tail position. The Free monad still makes this stack-safe, but space complexity becomes O(N) instead of O(1). (Specifically, it is still not O(N^2).)
Whether it is possible to make scalac optimize that map away makes for a separate question (which I don't know the answer to).
I will try to illustrate what is going on when interpreting p1 versus p3. (I will ignore the translation to Trampoline, which is redundant (see below).)
p3 (i.e. without extra map)
Let me use the following shorthand:
def cont(i: Int): Unit => Prog[Unit] =
ignore1 => if (i > N) ret else p3(i + 1)
Now p3(0) is interpreted as follows
p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)
and so on... You see that the amount of memory needed at any point doesn't exceed some constant upper bound.
p1 (i.e. with extra map)
I will use the following shorthands:
def cont(i: Int): Unit => Prog[Unit] =
ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
def cpu: Unit => Prog[Unit] = // constant pure unit
ignore => Free.pure(())
Now p1(0) is interpreted as follows:
p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu
and so on... You see that the memory consumption depends linearly on N. We just moved the evaluation from stack to heap.
Takeaway: To keep Free memory-friendly, keep the recursion in "tail position", that is, on the right-hand side of flatMap (or map).
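The same tail-position principle can be sketched with the standard library's scala.util.control.TailCalls, used here purely as a stand-in for Free (that substitution, the object name and the loop bodies are my assumptions for illustration). Both loops are stack-safe and compute the same number, but the second one leaves a pending map behind every recursive call:

```scala
import scala.util.control.TailCalls._

object TailPositionSketch {
  val N = 100000

  // Recursion on the right-hand side of flatMap: nothing remains to be done
  // after the recursive call returns, so no continuations pile up.
  def tailLoop(i: Int, acc: Int): TailRec[Int] =
    done(acc).flatMap { a =>
      if (i > N) done(a) else tailcall(tailLoop(i + 1, a + 1))
    }

  // A map attached after the recursive call: each level leaves one
  // continuation waiting on the heap until the recursion bottoms out.
  def nonTailLoop(i: Int): TailRec[Int] =
    done(()).flatMap { _ =>
      if (i > N) done(0) else tailcall(nonTailLoop(i + 1)).map(_ + 1)
    }
}
```

Calling `.result` on either value runs the trampoline. The results agree; the difference is only in how much is kept alive while running.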
Aside: The translation to Trampoline is not necessary, since Free is already trampolined. You could interpret directly to Id and use foldMapRec for stack-safe interpretation:
val idInterpreter = new (Lang ~> Id) {
def apply[A](cmd: Lang[A]): Id[A] = cmd match {
case PrintLine(l) => println(l)
}
}
p0(0).foldMapRec(idInterpreter)
This will regain you some fraction of memory (but doesn't make the problem go away).

Related

Functional way of interrupting lazy iteration depending on timeout and comparison between previous and next, while, LazyList vs Stream

Background
I have the following scenario. I want to execute the method of a class from an external library, repeatedly, and I want to do so until a certain timeout condition and result condition (compared to the previous result) is met. Furthermore I want to collect the return values, even on the "failed" run (the run with the "failing" result condition that should interrupt further execution).
Thus far I have accomplished this with initializing an empty var result: Result, a var stop: Boolean and using a while loop that runs while the conditions are true and modifying the outer state. I would like to get rid of this and use a functional approach.
Some context. Each run is expected to run from 0 to 60 minutes and the total time of iteration is capped at 60 minutes. Theoretically, there's no bound to how many times it executes in this period but in practice, it's generally 2-60 times.
The problem is, the runs take a long time so I need to stop the execution. My idea is to use some kind of lazy Iterator or Stream coupled with scanLeft and Option.
Code
Boiler plate
This code isn't particularly relevant, but it is used in my sample approaches and provides identical but somewhat random pseudo-runtime results.
import scala.collection.mutable.ListBuffer
import scala.util.Random
val r = Random
r.setSeed(1)
val sleepingTimes: Seq[Int] = (1 to 601)
.map(x => Math.pow(2, x).toInt * r.nextInt(100))
.toList
.filter(_ > 0)
.sorted
val randomRes = r.shuffle((0 to 600).map(x => r.nextInt(10)).toList)
case class Result(val a: Int, val slept: Int)
class Lib() {
def run(i: Int) = {
println(s"running ${i}")
Thread.sleep(sleepingTimes(i))
Result(randomRes(i), sleepingTimes(i))
}
}
case class Baz(i: Int, result: Result)
val lib = new Lib()
val timeout = 10 * 1000
While approach
val iteratorStart = System.currentTimeMillis()
val iterator = for {
i <- (0 to 600).iterator
if System.currentTimeMillis() < iteratorStart + timeout
f = Baz(i, lib.run(i))
} yield f
val iteratorBuffer = ListBuffer[Baz]()
if (iterator.hasNext) iteratorBuffer.append(iterator.next())
var run = true
while (run && iterator.hasNext) {
val next = iterator.next()
run = iteratorBuffer.last.result.a < next.result.a
iteratorBuffer.append(next)
}
Stream approach (Scala.2.12)
Full example
val streamStart = System.currentTimeMillis()
val stream = for {
i <- (0 to 600).toStream
if System.currentTimeMillis() < streamStart + timeout
} yield Baz(i, lib.run(i))
var last: Option[Baz] = None
val head = stream.headOption
val tail = if (stream.nonEmpty) stream.tail else stream
val streamVersion = (tail
.scanLeft((head, true))((x, y) => {
if (x._1.exists(_.result.a > y.result.a)) (Some(y), false)
else (Some(y), true)
})
.takeWhile {
case (baz, continue) =>
if (!baz.eq(head)) last = baz
continue
}
.map(_._1)
.toList :+ last).flatten
LazyList approach (Scala 2.13)
Full example
val lazyListStart = System.currentTimeMillis()
val lazyList = for {
i <- (0 to 600).to(LazyList)
if System.currentTimeMillis() < lazyListStart + timeout
} yield Baz(i, lib.run(i))
var last: Option[Baz] = None
val head = lazyList.headOption
val tail = if (lazyList.nonEmpty) lazyList.tail else lazyList
val lazyListVersion = (tail
.scanLeft((head, true))((x, y) => {
if (x._1.exists(_.result.a > y.result.a)) (Some(y), false)
else (Some(y), true)
})
.takeWhile {
case (baz, continue) =>
if (!baz.eq(head)) last = baz
continue
}
.map(_._1)
.toList :+ last).flatten
Result
Both approaches appear to yield the correct end result:
List(Baz(0,Result(4,170)), Baz(1,Result(5,208)))
and they interrupt execution as desired.
Edit: The desired outcome is to not execute the next iteration but still return the result of the iteration that caused the interruption. Thus the desired result is
List(Baz(0,Result(4,170)), Baz(1,Result(5,208)), Baz(2,Result(2,256)))
and lib.run(i) should only run 3 times.
This is achieved by the while approach, as well as the LazyList approach but not the Stream approach which executes lib.run 4 times (Bad!).
Question
Is there another stateless approach, which is hopefully more elegant?
Edit
I realized my examples were faulty and not returning the "failing" result, which it should, and that they kept executing beyond the stop condition. I rewrote the code and examples but I believe the spirit of the question is the same.
I would use something higher level, like fs2.
(or any other high-level streaming library, like: monix observables, akka streams or zio zstreams)
def runUntilOrTimeout[F[_]: Concurrent: Timer, A](work: F[A], timeout: FiniteDuration)
(stop: (A, A) => Boolean): Stream[F, A] = {
val interrupt =
Stream.sleep_(timeout)
val run =
Stream
.repeatEval(work)
.zipWithPrevious
.takeThrough {
case (Some(p), c) if stop(p, c) => false
case _ => true
} map {
case (_, c) => c
}
run mergeHaltBoth interrupt
}
You can see it working here.
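For comparison, the "zip with previous, then take through" idea can also be sketched without a streaming library, using Iterator.unfold from the Scala 2.13 standard library. The names `TakeThroughSketch` and `takeThroughPrev` are hypothetical, and the timeout is left out; it could be folded into the same state as a deadline check.

```scala
object TakeThroughSketch {
  // Emits elements up to and including the first one for which
  // stop(previous, current) holds, and pulls nothing further from the source.
  def takeThroughPrev[A](source: Iterator[A])(stop: (A, A) => Boolean): Iterator[A] =
    Iterator.unfold((Option.empty[A], false)) { case (prev, halted) =>
      if (halted || !source.hasNext) None
      else {
        val cur = source.next()
        // Carry the current element forward as the next "previous" and
        // remember whether this element triggered the stop condition.
        Some((cur, (Some(cur), prev.exists(p => stop(p, cur)))))
      }
    }
}
```

With a source of results 4, 5, 2, 7 and `stop = (p, c) => p > c`, this yields 4, 5, 2 and evaluates the source three times, which is the behaviour asked for in the edit to the question.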

List[Int] => Int without String Conversion?

I came up with the following to convert a List[Int] => Try[BigDecimal]:
import scala.util.Try
def f(xs: List[Int]): Try[BigDecimal] =
Try { xs.mkString.toInt }.map ( BigDecimal(_) )
Example:
scala> f(List(1,2,3,4))
res4: scala.util.Try[BigDecimal] = Success(1234)
scala> f(List(1,2,3,55555))
res5: scala.util.Try[BigDecimal] = Success(12355555)
Is there a way to write this function without resorting to a String conversion step?
Not very pretty, and I'm not convinced it's much more efficient. Here's the basic outline.
val pwrs:Stream[BigInt] = 10 #:: pwrs.map(_ * 10)
List(1,2,3,55555).foldLeft(0:BigInt)((p,i) => pwrs.find(_ > i).get * p + i)
Here it is a little more fleshed out with error handling.
import scala.util.Try
def f(xs: List[Int]): Try[BigDecimal] = Try {
lazy val pwrs: Stream[BigDecimal] = 10 #:: pwrs.map(_ * 10)
xs.foldLeft(0: BigDecimal) {
case (acc, i) if i >= 0 => pwrs.find(_ > i).get * acc + i
case _ => throw new Error("bad")
}
}
UPDATE
Just for giggles, I thought I'd plug some code into Rex Kerr's handy benchmarking/profiling tool, Thyme.
the code
import scala.util.Try
def fString(xs: List[Int]): Try[BigInt] = Try { BigInt(xs.mkString) }
def fStream(xs: List[Int]): Try[BigInt] = Try {
lazy val pwrs: Stream[BigInt] = 10 #:: pwrs.map(_ * 10)
xs.foldLeft(0: BigInt) {
case (acc, i) if i >= 0 => pwrs.find(_ > i).get * acc + i
case _ => throw new Error("bad")
}
}
def fLog10(xs: List[Int]): Try[BigInt] = Try {
xs.foldLeft(0: BigInt) {
case (acc, i) if i >= 0 =>
math.pow(10, math.ceil(math.log10(i))).toInt * acc + i
case _ => throw new Error("bad")
}
}
fString() is a slight simplification of Kevin's original question. fStream() is my proposed non-string implementation. fLog10 is the same but with Alexey's suggested enhancement.
You'll note that I'm using BigInt instead of BigDecimal. I found that both non-string methods encountered a bug somewhere around the 37th digit of the result. Some kind of rounding error or something, but there was no problem with BigInt so that's what I used.
test setup
// create a List of 40 Ints and check its contents
val lst = List.fill(40)(util.Random.nextInt(20000))
lst.min // 5
lst.max // 19858
lst.mkString.length // 170
val th = ichi.bench.Thyme.warmed(verbose = print)
th.pbenchWarm(th.Warm(fString(lst)), title="fString")
th.pbenchWarm(th.Warm(fStream(lst)), title="fStream")
th.pbenchWarm(th.Warm(fLog10(lst)), title="fLog10")
results
Benchmark for fString (20 calls in 345.6 ms) Time: 4.015 us 95%
CI 3.957 us - 4.073 us (n=19) Garbage: 109.9 ns (n=2 sweeps
measured)
Benchmark for fStream (20 calls in 305.6 ms) Time: 7.118 us 95%
CI 7.024 us - 7.213 us (n=19) Garbage: 293.0 ns (n=3 sweeps
measured)
Benchmark for fLog10 (20 calls in 382.8 ms) Time: 9.205 us 95%
CI 9.187 us - 9.222 us (n=17) Garbage: 73.24 ns (n=2 sweeps
measured)
So I was right about the efficiency of the non-string algorithm. Oddly, using math._ to avoid Stream creation doesn't appear to be better. I didn't expect that.
takeaway
Number-to-string and string-to-number transitions are reasonably efficient.
import scala.util.{Try, Success}
import scala.annotation.tailrec
def findOrder(i: Int): Long = {
@tailrec
def _findOrder(i: Int, order: Long): Long = {
if (i < order) order
else _findOrder(i, order * 10)
}
_findOrder(i, 10) // start at 10 so that a lone 0 still occupies one digit
}
def f(xs: List[Int]): Try[BigDecimal] = Try(
xs.foldLeft(BigDecimal(0))((acc, i) => acc * findOrder(i) + i)
)
To find the correct power of 10 more efficiently (replace pwrs.find(_ > i).get with nextPowerOf10(i) in @jwvh's answer):
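def nextPowerOf10(x: Int) = {
// floor + 1 rather than ceil, so that exact powers of 10 (1, 10, 100, ...) get the next power;
// 0 is treated as a single digit, since log10(0) is -Infinity
val n = if (x == 0) 1 else math.floor(math.log10(x)).toInt + 1
BigDecimal(10).pow(n)
}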
Since you start with an Int, there should be no rounding issues.
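As a cross-check of the arithmetic used in these answers, here is a small self-contained sketch that avoids both String and floating point. The names `ConcatDigits`, `shift` and `concat` are hypothetical, and returning Option instead of Try is my simplification:

```scala
object ConcatDigits {
  // Smallest power of 10 strictly greater than i:
  // 10 for 0..9, 100 for 10..99, 1000 for 100..999, and so on.
  def shift(i: Int): BigInt = {
    @annotation.tailrec
    def go(p: BigInt): BigInt = if (p > i) p else go(p * 10)
    go(BigInt(10))
  }

  // Concatenates the decimal representations of non-negative Ints.
  // Negative input has no sensible concatenation, so it yields None.
  def concat(xs: List[Int]): Option[BigInt] =
    if (xs.exists(_ < 0)) None
    else Some(xs.foldLeft(BigInt(0))((acc, i) => acc * shift(i) + i))
}
```

The edge cases worth checking by hand are 0 and exact powers of 10: `concat(List(1, 10, 100, 0))` has to shift by 100, 1000 and 10 respectively to agree with `mkString`.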

Using Streams ... still got Out of Memory

I am trying to generate a file which contains a very large amount of test data in the form of a JSON array.
Since my test data is really big I cannot use "mkString". This is also why I am using Streams and tail recursion.
But still my program gets an Out of memory exception
package com.abhi
import java.io.FileWriter
import scala.annotation.tailrec
object GenerateTestFile {
def jsonField(name: String, value : String) : String = {
s""""$name":"$value""""
}
def writeRecords(records : Stream[String], fw : FileWriter) : Unit = {
@tailrec
def inner(records: Stream[String]) : Unit = {
records match {
case head #:: Stream.Empty => fw.write(head)
case head #:: tail => fw.write(head + ","); inner(tail)
case Stream.Empty =>
}
}
inner(records)
}
def main(args: Array[String]) : Unit = {
val fileWriter = new FileWriter("./sample.json", true)
fileWriter.write("[")
val lines = (1 to 10000000)
.toStream
.map(x => (
jsonField("id", x.toString),
jsonField("FieldA", "FieldA" + x),
jsonField("FieldB", "FieldB" +x),
jsonField("FieldC", "FieldC" + x),
jsonField("FieldD", "FieldD" + x),
jsonField("FieldE", "FieldE" + x)
))
.map (t => t.productIterator.mkString("{",",", "}"))
writeRecords(lines, fileWriter)
fileWriter.write("]")
fileWriter.close()
}
}
Exception
[error] (run-main-0) java.lang.OutOfMemoryError: GC overhead limit exceeded
java.lang.OutOfMemoryError: GC overhead limit exceeded
at java.util.Arrays.copyOf(Arrays.java:2367)
at java.lang.AbstractStringBuilder.expandCapacity(AbstractStringBuilder.java:130)
at java.lang.AbstractStringBuilder.ensureCapacityInternal(AbstractStringBuilder.java:114)
at java.lang.AbstractStringBuilder.append(AbstractStringBuilder.java:415)
at java.lang.StringBuilder.append(StringBuilder.java:132)
at java.lang.StringBuilder.append(StringBuilder.java:128)
at scala.StringContext.standardInterpolator(StringContext.scala:125)
at scala.StringContext.s(StringContext.scala:95)
at com.abhi.GenerateTestFile$.jsonField(GenerateTestFile.scala:9)
at com.abhi.GenerateTestFile$$anonfun$1.apply(GenerateTestFile.scala:32)
What you are looking for is not a stream but an iterator that materialises the values one by one and, once each value has been processed (i.e. written to the file), lets its memory be reclaimed.
When you create your JSON you actually hold the entire sequence of numbers in memory, and in addition you generate a new, large text blob for each element, which you later put in the file. Memory-wise, the initial sequence is insignificant compared to the size of the text.
What I did was iterate over the range so that elements are released one by one. With foldLeft I make sure that I map each element to a JSON string, write it to disk and release the memory (the reference to any created object is lost, so the GC can kick in and reclaim the memory). Unfortunately, with this approach you cannot make use of the parallelism features.
def main(args: Array[String]): Unit = {
val fileWriter = new FileWriter("./sample.json", true)
fileWriter.write("[")
fileWriter.write(createObject(1).productIterator.mkString("{", ",", "}"))
val lines = (2 to 10000000).iterator // an Iterator, so the range is never materialised as a collection
.foldLeft(0)((_, x) => {
if (x % 50000 == 0)
println(s"We've reached $x element")
fileWriter.write(createObject(x).productIterator.mkString(",{", ",", "}"))
x
})
fileWriter.write("]")
fileWriter.close()
}
def createObject(x: Int) =
(jsonField("id", x.toString),
jsonField("FieldA", "FieldA" + x),
jsonField("FieldB", "FieldB" + x),
jsonField("FieldC", "FieldC" + x),
jsonField("FieldD", "FieldD" + x),
jsonField("FieldE", "FieldE" + x))
from the source of Stream:
* - One must be cautious of memoization; you can very quickly eat up large
* amounts of memory if you're not careful. The reason for this is that the
* memoization of the `Stream` creates a structure much like
* [[scala.collection.immutable.List]]. So long as something is holding on to
* the head, the head holds on to the tail, and so it continues recursively.
* If, on the other hand, there is nothing holding on to the head (e.g. we used
* `def` to define the `Stream`) then once it is no longer being used directly,
* it disappears.
So, if you inline lines (or write it as a def), and rewrite your writeRecords function to not hold onto a reference to the initial head, (or write the param as 'call by name' value using an arrow: records: => Stream[String], which does basically the same thing as def vs val) elements should be garbage collected as the rest of the stream is processed:
@tailrec
def writeRecords(records : Stream[String], fw : FileWriter) : Unit = {
records match {
case head #:: Stream.Empty => fw.write(head)
case head #:: tail => fw.write(head + ","); writeRecords(tail, fw)
case Stream.Empty =>
}
}
writeRecords((1 to 10000000)
.toStream
.map(x => (
jsonField("id", x.toString),
jsonField("FieldA", "FieldA" + x),
jsonField("FieldB", "FieldB" +x),
jsonField("FieldC", "FieldC" + x),
jsonField("FieldD", "FieldD" + x),
jsonField("FieldE", "FieldE" + x)
))
.map (t => t.productIterator.mkString("{",",", "}")),
fileWriter)
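Both answers come down to never holding a reference to records that have already been written. A minimal sketch of that idea with a plain Iterator follows; `JsonArrayWriter` and `writeArray` are hypothetical names, and taking a java.io.Writer instead of a FileWriter is my assumption so that the function is easy to try out in memory:

```scala
import java.io.{StringWriter, Writer}

object JsonArrayWriter {
  // Writes the records as one JSON array, separated by commas.
  // The Iterator hands out one record at a time and keeps no reference to
  // earlier ones, so each string can be collected once it has been written.
  def writeArray(records: Iterator[String], out: Writer): Unit = {
    out.write("[")
    var first = true
    records.foreach { r =>
      if (!first) out.write(",")
      out.write(r)
      first = false
    }
    out.write("]")
  }
}
```

For the original program the call would look something like `writeArray((1 to 10000000).iterator.map(render), fileWriter)`, where `render` stands for whatever builds one JSON object from an Int.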

Stack overflow when trying to compute Ackermann

As part of an experiment I wanted to see how long it would take to compute Ack(0,0) to Ack(4,19) both with and without caching/memoization. But I keep running into a simple stumbling block... My stack keeps overflowing.
Here's my code:
import org.joda.time.{Seconds, DateTime}
import scala.collection.mutable
// Wrapper case class to make it easier to look for a specific m n combination when using a Map
// also makes the code cleaner by letting me use a match rather than chained ifs
case class Ack(m: Int, n: Int)
object Program extends App {
// basic ackermann without cache; reaches stack overflow at around Ack(3,11)
def compute(a: Ack): Int = a match {
case Ack(0, n) => n + 1
case Ack(m, 0) => compute(Ack(m - 1, 1))
case Ack(m, n) => compute(Ack(m - 1, compute(Ack(m, n - 1))))
}
// ackermann WITH cache; also reaches stack overflow at around Ack(3,11)
def computeWithCache(a: Ack, cache: mutable.Map[Ack, Int]): Int = if(cache contains a) {
val res = cache(a)
println(s"result from cache: $res")
return res // explicit use of 'return' for readability's sake
} else {
val res = a match {
case Ack(0, n) => n + 1
case Ack(m, 0) => computeWithCache(Ack(m - 1, 1), cache)
case Ack(m, n) => computeWithCache(Ack(m - 1, computeWithCache(Ack(m, n - 1), cache)), cache)
}
cache += a -> res
return res
}
// convenience method
def getSeconds(start: DateTime, end: DateTime): Int =
Seconds.secondsBetween(start, end).getSeconds
val time = DateTime.now
val cache = mutable.Map[Ack, Int]()
for{
i <- 0 to 4
j <- 0 until 20
} println(s"result of ackermann($i, $j) => ${computeWithCache(Ack(i, j), cache)}. Time: ${getSeconds(time, DateTime.now)}")
}
I run Scala 2.11.1 in IntelliJ Idea 13.1.3 with the Scala and SBT plugins.
Is there anything I can do to not stack overflow at around Ack(3, 11)?
I've tried adding javacOptions += "-Xmx2G" to my build.sbt, but it just seems to make the problem worse.

Abort early in a fold

What's the best way to terminate a fold early? As a simplified example, imagine I want to sum up the numbers in an Iterable, but if I encounter something I'm not expecting (say an odd number) I might want to terminate. This is a first approximation
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => None
}
}
However, this solution is pretty ugly (as in, if I did a .foreach and a return -- it'd be much cleaner and clearer) and worst of all, it traverses the entire iterable even if it encounters a non-even number.
So what would be the best way to write a fold like this, that terminates early? Should I just go and write this recursively, or is there a more accepted way?
My first choice would usually be to use recursion. It is only moderately less compact, is potentially faster (certainly no slower), and in early termination can make the logic more clear. In this case you need nested defs which is a little awkward:
def sumEvenNumbers(nums: Iterable[Int]) = {
def sumEven(it: Iterator[Int], n: Int): Option[Int] = {
if (it.hasNext) {
val x = it.next
if ((x % 2) == 0) sumEven(it, n+x) else None
}
else Some(n)
}
sumEven(nums.iterator, 0)
}
My second choice would be to use return, as it keeps everything else intact and you only need to wrap the fold in a def so you have something to return from--in this case, you already have a method, so:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
Some(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) return None
n+x
})
}
which in this particular case is a lot more compact than recursion (though we got especially unlucky with recursion since we had to do an iterable/iterator transformation). The jumpy control flow is something to avoid when all else is equal, but here it's not. No harm in using it in cases where it's valuable.
If I was doing this often and wanted it within the middle of a method somewhere (so I couldn't just use return), I would probably use exception-handling to generate non-local control flow. That is, after all, what it is good at, and error handling is not the only time it's useful. The only trick is to avoid generating a stack trace (which is really slow), and that's easy because the trait NoStackTrace and its child trait ControlThrowable already do that for you. Scala already uses this internally (in fact, that's how it implements the return from inside the fold!). Let's make our own (can't be nested, though one could fix that):
import scala.util.control.ControlThrowable
case class Returned[A](value: A) extends ControlThrowable {}
def shortcut[A](a: => A) = try { a } catch { case Returned(v) => v }
def sumEvenNumbers(nums: Iterable[Int]) = shortcut{
Option(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) throw Returned(None)
n+x
})
}
Here of course using return is better, but note that you could put shortcut anywhere, not just wrapping an entire method.
Next in line for me would be to re-implement fold (either myself or to find a library that does it) so that it could signal early termination. The two natural ways of doing this are to not propagate the value but an Option containing the value, where None signifies termination; or to use a second indicator function that signals completion. The Scalaz lazy fold shown by Kim Stebel already covers the first case, so I'll show the second (with a mutable implementation):
def foldOrFail[A,B](it: Iterable[A])(zero: B)(fail: A => Boolean)(f: (B,A) => B): Option[B] = {
val ii = it.iterator
var b = zero
while (ii.hasNext) {
val x = ii.next
if (fail(x)) return None
b = f(b,x)
}
Some(b)
}
def sumEvenNumbers(nums: Iterable[Int]) = foldOrFail(nums)(0)(_ % 2 != 0)(_ + _)
(Whether you implement the termination by recursion, return, laziness, etc. is up to you.)
I think that covers the main reasonable variants; there are some other options also, but I'm not sure why one would use them in this case. (Iterator itself would work well if it had a findOrPrevious, but it doesn't, and the extra work it takes to do that by hand makes it a silly option to use here.)
The scenario you describe (exit upon some unwanted condition) seems like a good use case for the takeWhile method. It is essentially filter, but it stops upon encountering the first element that doesn't meet the condition.
For example:
val list = List(2,4,6,8,6,4,2,5,3,2)
list.takeWhile(_ % 2 == 0) //result is List(2,4,6,8,6,4,2)
This will work just fine for Iterators/Iterables too. The solution I suggest for your "sum of even numbers, but break on odd" is:
list.iterator.takeWhile(_ % 2 == 0).foldLeft(...)
And just to prove that it's not wasting your time once it hits an odd number...
scala> val list = List(2,4,5,6,8)
list: List[Int] = List(2, 4, 5, 6, 8)
scala> def condition(i: Int) = {
| println("processing " + i)
| i % 2 == 0
| }
condition: (i: Int)Boolean
scala> list.iterator.takeWhile(condition _).sum
processing 2
processing 4
processing 5
res4: Int = 6
You can do what you want in a functional style using the lazy version of foldRight in scalaz. For a more in depth explanation, see this blog post. While this solution uses a Stream, you can convert an Iterable into a Stream efficiently with iterable.toStream.
import scalaz._
import Scalaz._
val str = Stream(2,1,2,2,2,2,2,2,2)
var i = 0 //only here for testing
val r = str.foldr(Some(0):Option[Int])((n,s) => {
println(i)
i+=1
if (n % 2 == 0) s.map(n+) else None
})
This only prints
0
1
which clearly shows that the anonymous function is only called twice (i.e. until it encounters the odd number). That is due to the definition of foldr, whose signature (in case of Stream) is def foldr[B](b: B)(f: (Int, => B) => B)(implicit r: scalaz.Foldable[Stream]): B. Note that the anonymous function takes a by-name parameter as its second argument, so it need not be evaluated.
Btw, you can still write this with the OP's pattern matching solution, but I find if/else and map more elegant.
Well, Scala does allow non local returns. There are differing opinions on whether or not this is a good style.
scala> def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
| nums.foldLeft (Some(0): Option[Int]) {
| case (None, _) => return None
| case (Some(s), n) if n % 2 == 0 => Some(s + n)
| case (Some(_), _) => None
| }
| }
sumEvenNumbers: (nums: Iterable[Int])Option[Int]
scala> sumEvenNumbers(2 to 10)
res8: Option[Int] = None
scala> sumEvenNumbers(2 to 10 by 2)
res9: Option[Int] = Some(30)
EDIT:
In this particular case, as @Arjan suggested, you can also do:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => return None
}
}
You can use foldM from the cats library (as suggested by @Didac), but I suggest using Either instead of Option if you want to get the actual sum out.
bifoldMap is used to extract the result from Either.
import cats.implicits._
def sumEven(nums: Stream[Int]): Either[Int, Int] = {
nums.foldM(0) {
case (acc, n) if n % 2 == 0 => Either.right(acc + n)
case (acc, n) => {
println(s"Stopping on number: $n")
Either.left(acc)
}
}
}
examples:
println("Result: " + sumEven(Stream(2, 2, 3, 11)).bifoldMap(identity, identity))
> Stopping on number: 3
> Result: 4
println("Result: " + sumEven(Stream(2, 7, 2, 3)).bifoldMap(identity, identity))
> Stopping on number: 7
> Result: 2
Cats has a method called foldM which does short-circuiting (for Vector, List, Stream, ...).
It works as follows:
def sumEvenNumbers(nums: Stream[Int]): Option[Long] = {
import cats.implicits._
nums.foldM(0L) {
case (acc, c) if c % 2 == 0 => Some(acc + c)
case _ => None
}
}
If it finds an odd element it returns None without computing the rest; otherwise it returns the sum of the even entries.
If you want to keep the sum accumulated until an odd entry is found, you should use an Either[Long, Long].
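A rough sketch of the Either[Long, Long] variant, written against the standard library only rather than cats' foldM (that substitution and the name `SumUntilOdd` are mine). Left carries the partial sum at the point where an odd element stopped the traversal, Right the full sum:

```scala
object SumUntilOdd {
  // Left(partial sum) if an odd element ended the traversal early,
  // Right(total) if every element was even.
  def sumEven(nums: Iterable[Int]): Either[Long, Long] = {
    @annotation.tailrec
    def go(it: Iterator[Int], acc: Long): Either[Long, Long] =
      if (!it.hasNext) Right(acc)
      else {
        val x = it.next()
        // Returning Left here means the rest of the iterator is never touched.
        if (x % 2 == 0) go(it, acc + x) else Left(acc)
      }
    go(nums.iterator, 0L)
  }
}
```

`sumEven(List(2, 2, 3, 11)).merge` then gives the sum up to the stopping point regardless of which side it ended on.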
@Rex Kerr your answer helped me, but I needed to tweak it to use Either
def foldOrFail[A,B,C,D](map: B => Either[D, C])(merge: (A, C) => A)(initial: A)(it: Iterable[B]): Either[D, A] = {
val ii= it.iterator
var b= initial
while (ii.hasNext) {
val x= ii.next
map(x) match {
case Left(error) => return Left(error)
case Right(d) => b= merge(b, d)
}
}
Right(b)
}
You could try using a temporary var and using takeWhile. Here is a version.
var continue = true
// sample stream of 2's and then a stream of 3's.
val evenSum = (Stream.fill(10)(2) ++ Stream.fill(10)(3)).takeWhile(_ => continue)
.foldLeft(Option[Int](0)){
case (result,i) if i%2 != 0 =>
continue = false;
// return whatever is appropriate either the accumulated sum or None.
result
case (optionSum,i) => optionSum.map( _ + i)
}
The evenSum should be Some(20) in this case.
You can throw a well-chosen exception upon encountering your termination criterion, handling it in the calling code.
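A minimal sketch of what that could look like, assuming a marker exception of my own invention (`OddFound`) that mixes in NoStackTrace so that throwing it stays cheap:

```scala
import scala.util.control.NoStackTrace

object AbortFold {
  // Hypothetical marker exception; NoStackTrace avoids filling in a stack trace.
  private case object OddFound extends RuntimeException with NoStackTrace

  def sumEvenNumbers(nums: Iterable[Int]): Option[Int] =
    try Some(nums.foldLeft(0) { (acc, x) =>
      // Throwing from inside the fold abandons the rest of the traversal.
      if (x % 2 != 0) throw OddFound
      acc + x
    })
    catch { case OddFound => None }
}
```

The exception is caught right around the fold, so callers only ever see the Option.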
A more beautiful solution would be using span:
val (l, r) = numbers.span(_ % 2 == 0)
if(r.isEmpty) Some(l.sum)
else None
... but it traverses the list two times if all the numbers are even
Just for "academic" reasons (:
var headers = Source.fromFile(file).getLines().next().split(",")
var closeHeaderIdx = headers.takeWhile { s => !"Close".equals(s) }.foldLeft(0)((i, S) => i+1)
It takes twice as long as it should, but it is a nice one-liner.
If "Close" not found it will return
headers.size
Another (better) is this one:
var headers = Source.fromFile(file).getLines().next().split(",").toList
var closeHeaderIdx = headers.indexOf("Close")