I came up with the following to convert a List[Int] => Try[BigDecimal]:
import scala.util.Try
def f(xs: List[Int]): Try[BigDecimal] =
Try { xs.mkString.toInt }.map ( BigDecimal(_) )
Example:
scala> f(List(1,2,3,4))
res4: scala.util.Try[BigDecimal] = Success(1234)
scala> f(List(1,2,3,55555))
res5: scala.util.Try[BigDecimal] = Success(12355555)
Is there a way to write this function without resorting to a String conversion step?
Not very pretty, and I'm not convinced it's much more efficient. Here's the basic outline.
val pwrs:Stream[BigInt] = 10 #:: pwrs.map(_ * 10)
List(1,2,3,55555).foldLeft(0:BigInt)((p,i) => pwrs.find(_ > i).get * p + i)
Here it is a little more fleshed out with error handling.
import scala.util.Try
def f(xs: List[Int]): Try[BigDecimal] = Try {
lazy val pwrs: Stream[BigDecimal] = 10 #:: pwrs.map(_ * 10)
xs.foldLeft(0: BigDecimal) {
case (acc, i) if i >= 0 => pwrs.find(_ > i).get * acc + i
case _ => throw new Error("bad")
}
}
UPDATE
Just for giggles, I thought I'd plug some code into Rex Kerr's handy benchmarking/profiling tool, Thyme.
the code
import scala.util.Try
def fString(xs: List[Int]): Try[BigInt] = Try { BigInt(xs.mkString) }
def fStream(xs: List[Int]): Try[BigInt] = Try {
lazy val pwrs: Stream[BigInt] = 10 #:: pwrs.map(_ * 10)
xs.foldLeft(0: BigInt) {
case (acc, i) if i >= 0 => pwrs.find(_ > i).get * acc + i
case _ => throw new Error("bad")
}
}
def fLog10(xs: List[Int]): Try[BigInt] = Try {
xs.foldLeft(0: BigInt) {
case (acc, i) if i >= 0 =>
math.pow(10, math.ceil(math.log10(i))).toInt * acc + i
case _ => throw new Error("bad")
}
}
fString() is a slight simplification of Kevin's original question. fStream() is my proposed non-string implementation. fLog10 is the same but with Alexey's suggested enhancement.
You'll note that I'm using BigInt instead of BigDecimal. I found that both non-string methods encountered a bug somewhere around the 37th digit of the result. Some kind of rounding error or something, but there was no problem with BigInt so that's what I used.
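A plausible explanation for that rounding, offered as a guess rather than a diagnosis: Scala's BigDecimal carries a MathContext, and the default one is DECIMAL128, which keeps 34 significant digits. The sketch below runs the same digit-appending fold with both accumulator types; the 40-digit input is made up for illustration.

```scala
// 40 single-digit elements, so the exact result has 40 digits.
val digits = List.fill(40)(1)

// Same fold, once with BigInt and once with BigDecimal as the accumulator.
val viaBigInt: BigInt =
  digits.foldLeft(BigInt(0))((acc, d) => acc * 10 + d)
val viaBigDecimal: BigDecimal =
  digits.foldLeft(BigDecimal(0))((acc, d) => acc * 10 + d)

// Reference value and the precision of the default MathContext.
val exact = BigInt("1" * 40)
val defaultPrecision = BigDecimal.defaultMathContext.getPrecision

println(s"default precision: $defaultPrecision digits")
println(s"BigInt exact:      ${viaBigInt == exact}")
println(s"BigDecimal exact:  ${viaBigDecimal.toBigInt == exact}")
```

If that is the cause, building the value as a BigInt and converting to BigDecimal only at the very end should sidestep it.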
test setup
// create a List of 40 Ints and check its contents
val lst = List.fill(40)(util.Random.nextInt(20000))
lst.min // 5
lst.max // 19858
lst.mkString.length // 170
val th = ichi.bench.Thyme.warmed(verbose = print)
th.pbenchWarm(th.Warm(fString(lst)), title="fString")
th.pbenchWarm(th.Warm(fStream(lst)), title="fStream")
th.pbenchWarm(th.Warm(fLog10(lst)), title="fLog10")
results
Benchmark for fString (20 calls in 345.6 ms)
  Time:    4.015 us   95% CI 3.957 us - 4.073 us   (n=19)
  Garbage: 109.9 ns   (n=2 sweeps measured)
Benchmark for fStream (20 calls in 305.6 ms)
  Time:    7.118 us   95% CI 7.024 us - 7.213 us   (n=19)
  Garbage: 293.0 ns   (n=3 sweeps measured)
Benchmark for fLog10 (20 calls in 382.8 ms)
  Time:    9.205 us   95% CI 9.187 us - 9.222 us   (n=17)
  Garbage: 73.24 ns   (n=2 sweeps measured)
So I was right to doubt the efficiency of the non-string algorithm: both non-string versions are slower than fString. Oddly, using math._ to avoid Stream creation doesn't appear to be better. I didn't expect that.
takeaway
Number-to-string and string-to-number transitions are reasonably efficient.
import scala.util.{Try, Success}
import scala.annotation.tailrec
def findOrder(i: Int): Long = {
@tailrec
def _findOrder(i: Int, order: Long): Long = {
if (i < order) order
else _findOrder(i, order * 10)
}
_findOrder(i, 10) // start at 10 so that a 0 element still occupies one digit, as it does in xs.mkString
}
def f(xs: List[Int]): Try[BigDecimal] = Try(
xs.foldLeft(BigDecimal(0))((acc, i) => acc * findOrder(i) + i)
)
To find the correct power of 10 more efficiently (replace pwrs.find(_ > i).get with nextPowerOf10(i) in @jwvh's answer):
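def nextPowerOf10(x: Int) = {
  // floor + 1 rather than ceil, so exact powers of 10 (10, 100, ...) also map to the next power;
  // 0 is treated as a one-digit number because log10(0) is -Infinity
  val n = if (x == 0) 1 else math.floor(math.log10(x)).toInt + 1
  BigDecimal(10).pow(n)
}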
Since you start with an Int, there should be no rounding issues.
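The edge cases are where the arithmetic versions tend to drift from the string version: a 0 element, and exact powers of ten, where ceil(log10(x)) gives an exponent that is one too small. Here is a minimal integer-only sketch that can be checked against mkString; shiftFor and concatDigits are names made up for this example.

```scala
// Smallest power of 10 strictly greater than a non-negative i,
// with 0 treated as a one-digit number (so the result is at least 10).
def shiftFor(i: Int): BigInt = {
  require(i >= 0, "negative elements are not supported in this sketch")
  var p = BigInt(10)
  while (p <= i) p *= 10
  p
}

// Appends the decimal digits of each element to the accumulator.
def concatDigits(xs: List[Int]): BigInt =
  xs.foldLeft(BigInt(0))((acc, i) => acc * shiftFor(i) + i)

// Compare against the string-based version on inputs that include
// a 0 and exact powers of ten.
val tricky = List(1, 0, 10, 100)
println(concatDigits(tricky) == BigInt(tricky.mkString))
```

The loop does at most ten multiplications per element for an Int, so it avoids floating point entirely at a small cost.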
Related
I have a lot of calculations contributing to a final result, with no restrictions on the order of the contributions. Seems like Futures should be able to speed this up, and they do, but not in the way I thought they would. Here is code comparing performance of a very silly way to divide integers:
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future}
object scale_me_up {
def main(args: Array[String]): Unit = {
val M = 500 * 1000
val N = 5
Thread.sleep(3210) // let launcher settle down
for (it <- 0 until 15) {
val method = it % 3
val start = System.currentTimeMillis()
val result = divide(M, N, method)
val elapsed = System.currentTimeMillis() - start
assert(result == M / N)
if (it >= 6) {
val methods = Array("ordinary", "fast parallel", "nice parallel")
val name = methods(method)
println(f"$name%15s: $elapsed ms")
}
}
}
def is_multiple_of(m: Int, n: Int): Boolean = {
val result = !(1 until n).map(_ + (m / n) * n).toSet.contains(m)
assert(result == (m % n == 0)) // yes, a less crazy implementation exists
result
}
def divide(m: Int, n: Int, method: Int): Int = {
method match {
case 0 =>
(1 to m).count(is_multiple_of(_, n))
case 1 =>
(1 to m)
.map { x =>
Future { is_multiple_of(x, n) }
}
.count(Await.result(_, Duration.Inf))
case 2 =>
Await.result(divide_futuristically(m, n), Duration.Inf)
}
}
def divide_futuristically(m: Int, n: Int): Future[Int] = {
val futures = (1 to m).map { x =>
Future { is_multiple_of(x, n) }
}
Future.foldLeft(futures)(0) { (count, flag) =>
{ if (flag) { count + 1 } else { count } }
}
/* much worse performing alternative:
Future.sequence(futures).map(_.count(identity))
*/
}
}
When I run this, the parallel case 1 is somewhat faster than the ordinary case 0 code (hurray), but case 2 takes twice as long. Of course, it depends on the system and whether enough work needs to be done per future (which here grows with the denominator N) to offset the concurrency overhead. [PS] As expected, decreasing N gives case 0 the lead and increasing N enough makes both case 1 and case 2 about twice as fast as case 0 on my two core CPU.
I'm led to believe that divide_futuristically is a better way to express this kind of calculation: returning a future with the combined result. Blocking is just something we need here to measure performance. But in reality, the more blocking, the faster everyone is finished. What am I doing wrong? Several alternatives to summarize the futures (like sequence) all take the same penalty.
[PPS] this was with Scala 2.12 running on Java 11 on a 2 core CPU. With Java 12 on a 6 core CPU, the difference is much less significant (though the alternative with sequence still drags its feet). With Scala 2.13, the difference is even less and as you increase the amount of work per iteration, divide_futuristically starts outperforming the competition. The future is here at last...
It seems you did everything right. I tried several different approaches myself, even .par, but got the same or worse results.
I've dived into Future.foldLeft and tried to analyze what caused the latency:
/** A non-blocking, asynchronous left fold over the specified futures,
* with the start value of the given zero.
* The fold is performed asynchronously in left-to-right order as the futures become completed.
* The result will be the first failure of any of the futures, or any failure in the actual fold,
* or the result of the fold.
*
* Example:
* {{{
* val futureSum = Future.foldLeft(futures)(0)(_ + _)
* }}}
*
* @tparam T the type of the value of the input Futures
* @tparam R the type of the value of the returned `Future`
* @param futures the `scala.collection.immutable.Iterable` of Futures to be folded
* @param zero the start value of the fold
* @param op the fold operation to be applied to the zero and futures
* @return the `Future` holding the result of the fold
*/
def foldLeft[T, R](futures: scala.collection.immutable.Iterable[Future[T]])(zero: R)(op: (R, T) => R)(implicit executor: ExecutionContext): Future[R] =
foldNext(futures.iterator, zero, op)
private[this] def foldNext[T, R](i: Iterator[Future[T]], prevValue: R, op: (R, T) => R)(implicit executor: ExecutionContext): Future[R] =
if (!i.hasNext) successful(prevValue)
else i.next().flatMap { value => foldNext(i, op(prevValue, value), op) }
and this part:
else i.next().flatMap { value => foldNext(i, op(prevValue, value), op) }
.flatMap produces a new Future which is submitted to the executor. In other words, every
{ (count, flag) =>
{ if (flag) { count + 1 } else { count } }
}
is executed as a new Future.
I suppose this is the part that causes the latency observed in the experiment.
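One way to see this without a profiler is to count what reaches the executor. The sketch below is only an illustration under stated assumptions: countingEc is a hypothetical ExecutionContext that runs each task inline and counts submissions, and the inputs are already-completed futures so that only the fold's own callbacks are counted.

```scala
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

// Hypothetical counting executor for this sketch: runs each task inline
// on the calling thread and records how many tasks were submitted.
val submitted = new AtomicInteger(0)
val countingEc: ExecutionContext = new ExecutionContext {
  def execute(runnable: Runnable): Unit = {
    submitted.incrementAndGet()
    runnable.run()
  }
  def reportFailure(cause: Throwable): Unit = throw cause
}

val n = 100
// Already-completed inputs, so only the fold's callbacks reach the executor.
val inputs = List.tabulate(n)(i => Future.successful(i % 5 == 0))

val counted: Future[Int] =
  Future.foldLeft(inputs)(0)((count, flag) => if (flag) count + 1 else count)(countingEc)

val multiples = Await.result(counted, 5.seconds)
val tasks = submitted.get()
println(s"fold result: $multiples, tasks submitted: $tasks for $n inputs")
```

With a real thread pool, each of those submissions is a separate hand-off, and they happen strictly one after another, which fits the observed slowdown when the work per element is tiny.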
I am a Java developer learning Scala at the moment. It is generally accepted that Java is more verbose than Scala. I just need to call 2 or more methods concurrently and then combine the results. The official Scala documentation at docs.scala-lang.org/overviews/core/futures.html suggests using a for-comprehension for that, so I used that out-of-the-box solution straightforwardly. Then I thought about how I would do it with CompletableFuture and was surprised that it produced more concise and faster code than Scala's Future.
Let's consider a basic concurrent case: summing up the values in an array. For simplicity, let's split the array into 2 parts (hence there will be 2 worker threads). Java's sumConcurrently takes only 4 LOC, while Scala's version requires 12 LOC. Also, Java's version is 15% faster on my computer.
Complete code, not benchmark optimised.
Java impl.:
public class CombiningCompletableFuture {
static int sumConcurrently(List<Integer> numbers) throws ExecutionException, InterruptedException {
int mid = numbers.size() / 2;
return CompletableFuture.supplyAsync( () -> sumSequentially(numbers.subList(0, mid)))
.thenCombine(CompletableFuture.supplyAsync( () -> sumSequentially(numbers.subList(mid, numbers.size())))
, (left, right) -> left + right).get();
}
static int sumSequentially(List<Integer> numbers) {
try {
Thread.sleep(TimeUnit.SECONDS.toMillis(1));
} catch (InterruptedException ignored) { }
return numbers.stream().mapToInt(Integer::intValue).sum();
}
public static void main(String[] args) throws ExecutionException, InterruptedException {
List<Integer> from1toTen = IntStream.rangeClosed(1, 10).boxed().collect(toList());
long start = System.currentTimeMillis();
long sum = sumConcurrently(from1toTen);
long duration = System.currentTimeMillis() - start;
System.out.printf("sum is %d in %d ms.", sum, duration);
}
}
Scala's impl.:
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
object CombiningFutures extends App {
def sumConcurrently(numbers: Seq[Int]) = {
val splitted = numbers.splitAt(5)
val leftFuture = Future {
sumSequentally(splitted._1)
}
val rightFuture = Future {
sumSequentally(splitted._2)
}
val totalFuture = for {
left <- leftFuture
right <- rightFuture
} yield left + right
Await.result(totalFuture, Duration.Inf)
}
def sumSequentally(numbers: Seq[Int]) = {
Thread.sleep(1000)
numbers.sum
}
val from1toTen = 1 to 10
val start = System.currentTimeMillis
val sum = sumConcurrently(from1toTen)
val duration = System.currentTimeMillis - start
println(s"sum is $sum in $duration ms.")
}
Any explanations, or suggestions on how to improve the Scala code without impacting readability too much?
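A verbose Scala version of your sumConcurrently:
def sumConcurrently(numbers: List[Int]): Future[Int] = {
  val (v1, v2) = numbers.splitAt(numbers.length / 2)
  // start both futures before the for-comprehension so they run concurrently;
  // a Future created inside a generator only starts after the previous one completes
  val f1 = Future(v1.sum)
  val f2 = Future(v2.sum)
  for {
    sum1 <- f1
    sum2 <- f2
  } yield sum1 + sum2
}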
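A detail worth checking with any for-comprehension over futures: a Future constructed inside a generator is only created once the previous generator has completed, so the two run one after the other. The sketch below records start and end events to make that ordering visible; the task helper and the event names are made up for illustration.

```scala
import java.util.concurrent.ConcurrentLinkedQueue
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// Thread-safe log of what happened, in order.
val events = new ConcurrentLinkedQueue[String]()

// Hypothetical helper: a future that logs when its body starts and ends.
def task(name: String, value: Int): Future[Int] = Future {
  events.add(s"$name-start")
  Thread.sleep(50)
  events.add(s"$name-end")
  value
}

// Both futures are constructed inside the for, so task("b", 2) is only
// called from the flatMap callback that runs after task("a", 1) completes.
val sequential: Future[Int] = for {
  a <- task("a", 1)
  b <- task("b", 2)
} yield a + b

val total = Await.result(sequential, 5.seconds)
val order = events.toArray(new Array[String](0)).toList
println(s"total = $total, order = $order")
```

Assigning the two futures to vals before the for is what lets them overlap.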
A more concise version
def sumConcurrently2(numbers: List[Int]): Future[Int] = numbers.splitAt(numbers.length / 2) match {
case (l1, l2) => Future.sequence(List(Future(l1.sum), Future(l2.sum))).map(_.sum)
}
And all this is because we have to partition the list. Let's say we had to write a function which takes a few lists and returns the sum of their sums using multiple concurrent computations,
def sumConcurrently3(lists: List[Int]*): Future[Int] =
Future.sequence(lists.map(l => Future(l.sum))).map(_.sum)
If the above looks cryptic... then let me de-crypt it,
def sumConcurrently3(lists: List[Int]*): Future[Int] = {
val listOfFuturesOfSums = lists.map { l => Future(l.sum) }
val futureOfListOfSums = Future.sequence(listOfFuturesOfSums)
futureOfListOfSums.map { l => l.sum }
}
Now, whenever you use the result of a Future (let's say the future completes at time t1) in a computation, it means that this computation is bound to happen after time t1. You can do it with blocking like this in Scala,
val sumFuture = sumConcurrently(List(1, 2, 3, 4))
val sum = Await.result(sumFuture, Duration.Inf)
val anotherSum = sum + 100
println("another sum = " + anotherSum)
But what is the point of all that if you are blocking the current thread while waiting for the computation on those threads to finish? Why not move the whole computation into the future itself?
val sumFuture = sumConcurrently(List(1, 2, 3, 4))
val anotherSumFuture = sumFuture.map(s => s + 100)
anotherSumFuture.foreach(s => println("another sum = " + s))
Now you are not blocking anywhere, and the threads can be used wherever they are required.
The Future implementation and API in Scala are designed to let you write your program while avoiding blocking as far as possible.
For the task at hand, the following is probably not the tersest option either:
def sumConcurrently(numbers: Vector[Int]): Future[Int] = {
val (v1, v2) = numbers.splitAt(numbers.length / 2)
Future(v1.sum).zipWith(Future(v2.sum))(_ + _)
}
As I mentioned in my comment there are several issues with your example.
Suppose that I'm trying to implement a very simple domain specific language with only one operation:
printLine(line)
Then I want to write a program that takes an integer n as input, prints something if n is divisible by 10k, and then calls itself with n + 1, until n reaches some maximum value N.
Omitting all syntactic noise caused by for-comprehensions, what I want is:
@annotation.tailrec def p(n: Int): Unit = {
if (n % 10000 == 0) printLine("line")
if (n > N) () else p(n + 1)
}
Essentially, it would be a kind of "fizzbuzz".
Here are a few attempts to implement this using the Free monad from Scalaz 7.3.0-M7:
import scalaz._
object Demo1 {
// define operations of a little domain specific language
sealed trait Lang[X]
case class PrintLine(line: String) extends Lang[Unit]
// define the domain specific language as the free monad of operations
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
// lift operations into the free monad
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
// write a program that is just a loop that prints current index
// after every few iteration steps
val mod = 100000
val N = 1000000
// straightforward syntax: deadly slow, exits with OutOfMemoryError
def p0(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- (if (i > N) ret else p0(i + 1))
} yield ()
// Same as above, but written out without `for`
def p1(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
}
// Same as above, with `map` attached to recursive call
def p2(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p2(i + 1).map{ ignore2 => () })
}
// Same as above, but without the `map`; performs ok.
def p3(i: Int): Prog[Unit] = {
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
if (i > N) ret else p3(i + 1)
}
}
// Variation of the above; Ok.
def p4(i: Int): Prog[Unit] = (for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
} yield ()).flatMap{ ignored2 =>
if (i > N) ret else p4(i + 1)
}
// try to use the variable returned by the last generator after yield,
// hope that the final `map` is optimized away (it's not optimized away...)
def p5(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
stopHere <- (if (i > N) ret else p5(i + 1))
} yield stopHere
// define an interpreter that translates the programs into Trampoline
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case PrintLine(l) => Trampoline.delay(println(l))
}
}
// try it out
def main(args: Array[String]): Unit = {
println("\n p0")
p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p1")
p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p2")
p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p3")
p3(0).foldMap(interpreter).run // ok
println("\n p4")
p4(0).foldMap(interpreter).run // ok
println("\n p5")
p5(0).foldMap(interpreter).run // OutOfMemory
}
}
Unfortunately, the straightforward translation (p0) seems to run with some kind of O(N^2) overhead, and crashes with an OutOfMemoryError. The problem seems to be that the for-comprehension appends a map{x => ()} after the recursive call to p0, which forces the Free monad to fill the entire memory with reminders to "finish 'p0' and then do nothing".
If I manually "unroll" the for comprehension, and write out the last flatMap explicitly (as in p3 and p4), then the problem goes away, and everything runs smoothly. This, however, is an extremely brittle workaround: the behavior of the program changes dramatically if we simply append a map(id) to it, and this map(id) isn't even visible in the code, because it is generated automatically by the for-comprehension.
In this older post here: https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/
it has been repeatedly advised to wrap recursive calls into a suspend. Here is an attempt with Applicative instance and suspend:
import scalaz._
// Essentially same as in `Demo1`, but this time with
// an `Applicative` and an explicit `Suspend` in the
// `for`-comprehension
object Demo2 {
sealed trait Lang[H]
case class Const[H](h: H) extends Lang[H]
case class PrintLine[H](line: String) extends Lang[H]
implicit object Lang extends Applicative[Lang] {
def point[A](a: => A): Lang[A] = Const(a)
def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
case Const(x) => {
f match {
case Const(ab) => Const(ab(x))
case _ => throw new Error
}
}
case PrintLine(l) => PrintLine(l)
}
}
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
val mod = 100000
val N = 2000000
// try to suspend the entire second generator
def p7(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- Free.suspend(if (i > N) ret else p7(i + 1))
} yield ()
// try to suspend the recursive call
def p8(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- if (i > N) ret else Free.suspend(p8(i + 1))
} yield ()
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case Const(x) => Trampoline.done(x)
case PrintLine(l) =>
(Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
}
}
def main(args: Array[String]): Unit = {
p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
p8(0).foldMap(interpreter).run // same...
}
}
Inserting suspend did not really help: it's still slow, and crashes with OutOfMemoryErrors.
Should I use the suspend somehow differently?
Maybe there is some purely syntactic remedy that makes it possible to use for-comprehensions without generating the map in the end?
I'd really appreciate if someone could point out what I'm doing wrong here, and how to repair it.
That superfluous map added by the Scala compiler moves the recursion from tail position to non-tail position. Free monad still makes this stack safe, but space complexity becomes O(N) instead of O(1). (Specifically, it is still not O(N^2).)
Whether it is possible to make scalac optimize that map away makes for a separate question (which I don't know the answer to).
I will try to illustrate what is going on when interpreting p1 versus p3. (I will ignore the translation to Trampoline, which is redundant (see below).)
p3 (i.e. without extra map)
Let me use the following shorthand:
def cont(i: Int): Unit => Prog[Unit] =
ignore1 => if (i > N) ret else p3(i + 1)
Now p3(0) is interpreted as follows
p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)
and so on... You see that the amount of memory needed at any point doesn't exceed some constant upper bound.
p1 (i.e. with extra map)
I will use the following shorthands:
def cont(i: Int): Unit => Prog[Unit] =
ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
def cpu: Unit => Prog[Unit] = // constant pure unit
ignore => Free.pure(())
Now p1(0) is interpreted as follows:
p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu
and so on... You see that the memory consumption depends linearly on N. We just moved the evaluation from stack to heap.
Takeaway: To keep Free memory-friendly, keep the recursion in "tail position", that is, on the right-hand side of flatMap (or map).
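The difference between the two shapes can be reproduced with a toy model. To be clear about the assumptions: Prog, Done, Bind and maxPending below are a hand-rolled stand-in written for this illustration, not scalaz's Free or its interpreter; the sketch only counts how many continuations are pending at once.

```scala
// Toy stand-in for a free monad over Unit results (hypothetical, not scalaz's Free).
sealed trait Prog {
  def flatMap(f: Unit => Prog): Prog = Bind(this, f)
  // like Free.map, expressed through flatMap
  def map(f: Unit => Unit): Prog = flatMap { u => f(u); Done }
}
case object Done extends Prog
final case class Bind(first: Prog, next: Unit => Prog) extends Prog

// Runs a program with an explicit continuation stack and reports
// the largest number of continuations that were pending at once.
def maxPending(prog: Prog): Int = {
  var current: Prog = prog
  var pending: List[Unit => Prog] = Nil
  var depth = 0
  var maxDepth = 0
  var finished = false
  while (!finished) {
    current match {
      case Bind(first, next) =>
        pending = next :: pending
        depth += 1
        maxDepth = math.max(maxDepth, depth)
        current = first
      case Done =>
        pending match {
          case next :: rest =>
            pending = rest
            depth -= 1
            current = next(())
          case Nil =>
            finished = true
        }
    }
  }
  maxDepth
}

// Shaped like p3: the recursive call is the last thing the continuation does.
def tailLoop(i: Int, n: Int): Prog =
  Done.flatMap(_ => if (i > n) Done else tailLoop(i + 1, n))

// Shaped like p1: a trailing map wraps every recursive call.
def mappedLoop(i: Int, n: Int): Prog =
  Done.flatMap(_ => (if (i > n) Done else mappedLoop(i + 1, n)).map(_ => ()))

println(s"tail shape:   ${maxPending(tailLoop(0, 1000))} pending at most")
println(s"mapped shape: ${maxPending(mappedLoop(0, 1000))} pending at most")
```

In this model the tail-shaped loop never has more than one continuation waiting, while the mapped one accumulates one per iteration, which mirrors the traces above.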
Aside: The translation to Trampoline is not necessary, since Free is already trampolined. You could interpret directly to Id and use foldMapRec for stack-safe interpretation:
val idInterpreter = new (Lang ~> Id) {
def apply[A](cmd: Lang[A]): Id[A] = cmd match {
case PrintLine(l) => println(l)
}
}
p0(0).foldMapRec(idInterpreter)
This will regain you some fraction of memory (but doesn't make the problem go away).
As part of an experiment I wanted to see how long it would take to compute Ack(0,0) to Ack(4,19) both with and without caching/memoization. But I keep running into a simple stumbling block... My stack keeps overflowing.
Here's my code:
import org.joda.time.{Seconds, DateTime}
import scala.collection.mutable
// Wrapper case class to make it easier to look for a specific m n combination when using a Map
// also makes the code cleaner by letting me use a match rather than chained ifs
case class Ack(m: Int, n: Int)
object Program extends App {
// basic ackermann without cache; reaches stack overflow at around Ack(3,11)
def compute(a: Ack): Int = a match {
case Ack(0, n) => n + 1
case Ack(m, 0) => compute(Ack(m - 1, 1))
case Ack(m, n) => compute(Ack(m - 1, compute(Ack(m, n - 1))))
}
// ackermann WITH cache; also reaches stack overflow at around Ack(3,11)
def computeWithCache(a: Ack, cache: mutable.Map[Ack, Int]): Int = if(cache contains a) {
val res = cache(a)
println(s"result from cache: $res")
return res // explicit use of 'return' for readability's sake
} else {
val res = a match {
case Ack(0, n) => n + 1
case Ack(m, 0) => computeWithCache(Ack(m - 1, 1), cache)
case Ack(m, n) => computeWithCache(Ack(m - 1, computeWithCache(Ack(m, n - 1), cache)), cache)
}
cache += a -> res
return res
}
// convenience method
def getSeconds(start: DateTime, end: DateTime): Int =
Seconds.secondsBetween(start, end).getSeconds
val time = DateTime.now
val cache = mutable.Map[Ack, Int]()
for{
i <- 0 to 4
j <- 0 until 20
} println(s"result of ackermann($i, $j) => ${computeWithCache(Ack(i, j), cache)}. Time: ${getSeconds(time, DateTime.now)}")
}
I run Scala 2.11.1 in IntelliJ Idea 13.1.3 with the Scala and SBT plugins.
Is there anything I can do to not stack overflow at around Ack(3, 11)?
I've tried adding javacOptions += "-Xmx2G" to my build.sbt, but it just seems to make the problem worse.
What's the best way to terminate a fold early? As a simplified example, imagine I want to sum up the numbers in an Iterable, but if I encounter something I'm not expecting (say an odd number) I might want to terminate. This is a first approximation
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => None
}
}
However, this solution is pretty ugly (as in, if I did a .foreach and a return -- it'd be much cleaner and clearer) and worst of all, it traverses the entire iterable even if it encounters a non-even number.
So what would be the best way to write a fold like this, that terminates early? Should I just go and write this recursively, or is there a more accepted way?
My first choice would usually be to use recursion. It is only moderately less compact, is potentially faster (certainly no slower), and in early termination can make the logic more clear. In this case you need nested defs which is a little awkward:
def sumEvenNumbers(nums: Iterable[Int]) = {
def sumEven(it: Iterator[Int], n: Int): Option[Int] = {
if (it.hasNext) {
val x = it.next
if ((x % 2) == 0) sumEven(it, n+x) else None
}
else Some(n)
}
sumEven(nums.iterator, 0)
}
My second choice would be to use return, as it keeps everything else intact and you only need to wrap the fold in a def so you have something to return from--in this case, you already have a method, so:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
Some(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) return None
n+x
})
}
which in this particular case is a lot more compact than recursion (though we got especially unlucky with recursion since we had to do an iterable/iterator transformation). The jumpy control flow is something to avoid when all else is equal, but here it's not. No harm in using it in cases where it's valuable.
If I was doing this often and wanted it within the middle of a method somewhere (so I couldn't just use return), I would probably use exception-handling to generate non-local control flow. That is, after all, what it is good at, and error handling is not the only time it's useful. The only trick is to avoid generating a stack trace (which is really slow), and that's easy because the trait NoStackTrace and its child trait ControlThrowable already do that for you. Scala already uses this internally (in fact, that's how it implements the return from inside the fold!). Let's make our own (can't be nested, though one could fix that):
import scala.util.control.ControlThrowable
case class Returned[A](value: A) extends ControlThrowable {}
def shortcut[A](a: => A) = try { a } catch { case Returned(v) => v }
def sumEvenNumbers(nums: Iterable[Int]) = shortcut{
Option(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) throw Returned(None)
n+x
})
}
Here of course using return is better, but note that you could put shortcut anywhere, not just wrapping an entire method.
Next in line for me would be to re-implement fold (either myself or to find a library that does it) so that it could signal early termination. The two natural ways of doing this are to not propagate the value but an Option containing the value, where None signifies termination; or to use a second indicator function that signals completion. The Scalaz lazy fold shown by Kim Stebel already covers the first case, so I'll show the second (with a mutable implementation):
def foldOrFail[A,B](it: Iterable[A])(zero: B)(fail: A => Boolean)(f: (B,A) => B): Option[B] = {
val ii = it.iterator
var b = zero
while (ii.hasNext) {
val x = ii.next
if (fail(x)) return None
b = f(b,x)
}
Some(b)
}
def sumEvenNumbers(nums: Iterable[Int]) = foldOrFail(nums)(0)(_ % 2 != 0)(_ + _)
(Whether you implement the termination by recursion, return, laziness, etc. is up to you.)
I think that covers the main reasonable variants; there are some other options also, but I'm not sure why one would use them in this case. (Iterator itself would work well if it had a findOrPrevious, but it doesn't, and the extra work it takes to do that by hand makes it a silly option to use here.)
The scenario you describe (exit upon some unwanted condition) seems like a good use case for the takeWhile method. It is essentially filter, except that it stops at the first element that doesn't meet the condition.
For example:
val list = List(2,4,6,8,6,4,2,5,3,2)
list.takeWhile(_ % 2 == 0) //result is List(2,4,6,8,6,4,2)
This will work just fine for Iterators/Iterables too. The solution I suggest for your "sum of even numbers, but break on odd" is:
list.iterator.takeWhile(_ % 2 == 0).foldLeft(...)
And just to prove that it's not wasting your time once it hits an odd number...
scala> val list = List(2,4,5,6,8)
list: List[Int] = List(2, 4, 5, 6, 8)
scala> def condition(i: Int) = {
| println("processing " + i)
| i % 2 == 0
| }
condition: (i: Int)Boolean
scala> list.iterator.takeWhile(condition _).sum
processing 2
processing 4
processing 5
res4: Int = 6
You can do what you want in a functional style using the lazy version of foldRight in scalaz. For a more in depth explanation, see this blog post. While this solution uses a Stream, you can convert an Iterable into a Stream efficiently with iterable.toStream.
import scalaz._
import Scalaz._
val str = Stream(2,1,2,2,2,2,2,2,2)
var i = 0 //only here for testing
val r = str.foldr(Some(0):Option[Int])((n,s) => {
println(i)
i+=1
if (n % 2 == 0) s.map(n+) else None
})
This only prints
0
1
which clearly shows that the anonymous function is only called twice (i.e. until it encounters the odd number). That is due to the definition of foldr, whose signature (in case of Stream) is def foldr[B](b: B)(f: (Int, => B) => B)(implicit r: scalaz.Foldable[Stream]): B. Note that the anonymous function takes a by-name parameter as its second argument, so it need not be evaluated.
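The by-name trick does not depend on scalaz. As a rough standard-library sketch of the same idea (lazyFoldRight is a name invented here, and this version is plain recursion, so it is not stack-safe for long lists):

```scala
// Right fold whose combining function receives the folded tail by name,
// so the tail is only computed if the function actually uses it.
def lazyFoldRight[A, B](xs: List[A], zero: => B)(f: (A, => B) => B): B = xs match {
  case Nil          => zero
  case head :: tail => f(head, lazyFoldRight(tail, zero)(f))
}

var calls = 0 // only here to count invocations

val result: Option[Int] =
  lazyFoldRight(List(2, 1, 2, 2, 2, 2, 2, 2, 2), Option(0)) { (n, rest) =>
    calls += 1
    // `rest` is evaluated only on this branch; an odd n never touches it
    if (n % 2 == 0) rest.map(n + _) else None
  }

println(s"result = $result after $calls calls")
```

The function is invoked for the leading 2 and then for the 1, where it returns None without ever forcing the rest of the list.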
Btw, you can still write this with the OP's pattern matching solution, but I find if/else and map more elegant.
Well, Scala does allow non local returns. There are differing opinions on whether or not this is a good style.
scala> def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
| nums.foldLeft (Some(0): Option[Int]) {
| case (None, _) => return None
| case (Some(s), n) if n % 2 == 0 => Some(s + n)
| case (Some(_), _) => None
| }
| }
sumEvenNumbers: (nums: Iterable[Int])Option[Int]
scala> sumEvenNumbers(2 to 10)
res8: Option[Int] = None
scala> sumEvenNumbers(2 to 10 by 2)
res9: Option[Int] = Some(30)
EDIT:
In this particular case, as @Arjan suggested, you can also do:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => return None
}
}
You can use foldM from the cats library (as suggested by @Didac), but I suggest using Either instead of Option if you want to get the actual sum out.
bifoldMap is used to extract the result from Either.
import cats.implicits._
def sumEven(nums: Stream[Int]): Either[Int, Int] = {
nums.foldM(0) {
case (acc, n) if n % 2 == 0 => Either.right(acc + n)
case (acc, n) => {
println(s"Stopping on number: $n")
Either.left(acc)
}
}
}
examples:
println("Result: " + sumEven(Stream(2, 2, 3, 11)).bifoldMap(identity, identity))
> Stopping on number: 3
> Result: 4
println("Result: " + sumEven(Stream(2, 7, 2, 3)).bifoldMap(identity, identity))
> Stopping on number: 7
> Result: 2
Cats has a method called foldM which does short-circuiting (for Vector, List, Stream, ...).
It works as follows:
def sumEvenNumbers(nums: Stream[Int]): Option[Long] = {
import cats.implicits._
nums.foldM(0L) {
case (acc, c) if c % 2 == 0 => Some(acc + c)
case _ => None
}
}
If it finds an odd element, it returns None without computing the rest; otherwise it returns the sum of the even entries.
If you want to keep the running sum up to the point where an odd entry is found, you should use an Either[Long, Long].
@Rex Kerr your answer helped me, but I needed to tweak it to use Either
def foldOrFail[A,B,C,D](map: B => Either[D, C])(merge: (A, C) => A)(initial: A)(it: Iterable[B]): Either[D, A] = {
val ii= it.iterator
var b= initial
while (ii.hasNext) {
val x= ii.next
map(x) match {
case Left(error) => return Left(error)
case Right(d) => b= merge(b, d)
}
}
Right(b)
}
You could try using a temporary var and using takeWhile. Here is a version.
var continue = true
// sample stream of 2's and then a stream of 3's.
val evenSum = (Stream.fill(10)(2) ++ Stream.fill(10)(3)).takeWhile(_ => continue)
.foldLeft(Option[Int](0)){
case (result,i) if i%2 != 0 =>
continue = false;
// return whatever is appropriate either the accumulated sum or None.
result
case (optionSum,i) => optionSum.map( _ + i)
}
The evenSum should be Some(20) in this case.
You can throw a well-chosen exception upon encountering your termination criterion, handling it in the calling code.
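As a minimal sketch of that idea, assuming a marker exception defined just for this purpose (OddNumberFound is a hypothetical name), mixing in NoStackTrace keeps the throw cheap because no stack trace is filled in:

```scala
import scala.util.control.NoStackTrace

// Hypothetical marker exception for this sketch; NoStackTrace avoids
// the cost of capturing a stack trace on every early exit.
case object OddNumberFound extends RuntimeException with NoStackTrace

def sumEvenNumbers(nums: Iterable[Int]): Option[Int] =
  try {
    Some(nums.foldLeft(0) { (sum, n) =>
      // abandon the fold as soon as an odd number shows up
      if (n % 2 != 0) throw OddNumberFound
      sum + n
    })
  } catch {
    case OddNumberFound => None
  }

println(sumEvenNumbers(List(2, 4, 6)))
println(sumEvenNumbers(List(2, 3, 4)))
```

Keeping the try/catch inside the method means callers only ever see the Option.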
A more beautiful solution would be using span:
val (l, r) = numbers.span(_ % 2 == 0)
if(r.isEmpty) Some(l.sum)
else None
... but it traverses the list twice if all the numbers are even.
Just for "academic" reasons (:
var headers = Source.fromFile(file).getLines().next().split(",")
var closeHeaderIdx = headers.takeWhile { s => !"Close".equals(s) }.foldLeft(0)((i, S) => i+1)
It takes twice as long as it should, but it is a nice one-liner.
If "Close" is not found, it will return
headers.size
Another (better) is this one:
var headers = Source.fromFile(file).getLines().next().split(",").toList
var closeHeaderIdx = headers.indexOf("Close")