As part of an experiment I wanted to see how long it would take to compute Ack(0,0) to Ack(4,19) both with and without caching/memoization. But I keep running into a simple stumbling block... My stack keeps overflowing.
Here's my code:
import org.joda.time.{Seconds, DateTime}
import scala.collection.mutable
// Wrapper case class to make it easier to look for a specific m n combination when using a Map
// also makes the code cleaner by letting me use a match rather than chained ifs
case class Ack(m: Int, n: Int)
object Program extends App {
// basic ackermann without cache; reaches stack overflow at around Ack(3,11)
def compute(a: Ack): Int = a match {
case Ack(0, n) => n + 1
case Ack(m, 0) => compute(Ack(m - 1, 1))
case Ack(m, n) => compute(Ack(m - 1, compute(Ack(m, n - 1))))
}
// ackermann WITH cache; also reaches stack overflow at around Ack(3,11)
def computeWithCache(a: Ack, cache: mutable.Map[Ack, Int]): Int = if(cache contains a) {
val res = cache(a)
println(s"result from cache: $res")
return res // explicit use of 'return' for readability's sake
} else {
val res = a match {
case Ack(0, n) => n + 1
case Ack(m, 0) => computeWithCache(Ack(m - 1, 1), cache)
case Ack(m, n) => computeWithCache(Ack(m - 1, computeWithCache(Ack(m, n - 1), cache)), cache)
}
cache += a -> res
return res
}
// convenience method
def getSeconds(start: DateTime, end: DateTime): Int =
Seconds.secondsBetween(start, end).getSeconds
val time = DateTime.now
val cache = mutable.Map[Ack, Int]()
for{
i <- 0 to 4
j <- 0 until 20
} println(s"result of ackermann($i, $j) => ${computeWithCache(Ack(i, j), cache)}. Time: ${getSeconds(time, DateTime.now)}")
}
I run Scala 2.11.1 in IntelliJ Idea 13.1.3 with the Scala and SBT plugins.
Is there anything I can do to not stack overflow at around Ack(3, 11)?
I've tried adding javacOptions += "-Xmx2G" to my build.sbt, but it just seems to make the problem worse.
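One way to sidestep the JVM call stack entirely is to keep the pending outer calls in a heap-allocated list and loop over it. The following is only a minimal sketch of that idea, not a tuned implementation; the names `AckermannSketch`, `ack`, and `loop` are made up for illustration, and it uses plain `Int`, so it only makes sense for small `m`. Separately, note that `javacOptions` are passed to the Java compiler and that `-Xmx` sizes the heap; the thread stack size of the running JVM is controlled by `-Xss`.

```scala
import scala.annotation.tailrec

object AckermannSketch {
  // Ackermann without deep recursion: `pending` holds the m values of
  // the calls that are still waiting for their second argument.
  def ack(m: Int, n: Int): Int = {
    @tailrec
    def loop(pending: List[Int], acc: Int): Int = pending match {
      case Nil                   => acc                       // nothing left to apply
      case 0 :: rest             => loop(rest, acc + 1)       // Ack(0, n) = n + 1
      case k :: rest if acc == 0 => loop((k - 1) :: rest, 1)  // Ack(k, 0) = Ack(k - 1, 1)
      // Ack(k, n) = Ack(k - 1, Ack(k, n - 1)): evaluate the inner call first,
      // leave k - 1 on the list to be applied to its result
      case k :: rest             => loop(k :: (k - 1) :: rest, acc - 1)
    }
    loop(List(m), n)
  }
}
```

The list can still grow large, but it lives on the heap, so a bigger `-Xmx` actually helps here. The running time is unchanged; Ack(4, n) for n > 1 remains out of reach for any approach.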
I'm using pattern matching in scala a lot. Many times I need to do some calculations in guard part and sometimes they are pretty expensive. Is there any way to bind calculated values to separate value?
// I want to reuse the result of prettyExpensiveFunc safely in the body
people.collect {
case ...
case Some(Right((x, y))) if prettyExpensiveFunc(x, y) > 0 => prettyExpensiveFunc(x, y)
}
//ideally something like that could be helpful, but it doesn't compile:
people.collect {
case ...
case Some(Right((x, y))) if {val z = prettyExpensiveFunc(x, y); z > 0} => z
}
// this solution works, but it isn't safe for some `Seq` types and is risky when more cases are used.
var cache:Int = 0
people.collect {
case ...
case Some(Right((x, y))) if {cache = prettyExpensiveFunc(x, y); cache > 0} => cache
}
Is there any better solution?
PS: The example is simplified, and I don't expect answers that show that I don't need pattern matching here.
You can use cats.Eval to make the expensive calculations lazy and memoized: create the Evals in .map, then extract .value (computed at most once, and only if needed) in .collect.
values.map { value =>
val expensiveCheck1 = Eval.later { prettyExpensiveFunc(value) }
val expensiveCheck2 = Eval.later { anotherExpensiveFunc(value) }
(value, expensiveCheck1, expensiveCheck2)
}.collect {
case (value, lazyResult1, _) if lazyResult1.value > 0 => ...
case (value, _, lazyResult2) if lazyResult2.value > 0 => ...
case (value, lazyResult1, lazyResult2) if lazyResult1.value > lazyResult2.value => ...
...
}
I don't see a way of doing what you want without creating some implementation of lazy evaluation, and if you have to use one, you might as well use an existing one instead of rolling your own.
EDIT. Just in case you haven't noticed: you aren't losing the ability to pattern match by using a tuple here:
values.map {
// original value -> lazily evaluated, memoized expensive calculation
case a @ Some(Right((x, y))) => a -> Some(Eval.later(prettyExpensiveFunc(x, y)))
case a => a -> None
}.collect {
// match type and calculation
...
case (Some(Right((x, y))), Some(lazyResult)) if lazyResult.value > 0 => ...
...
}
Why not run the function first for every element and then work with a tuple?
Seq(1,2,3,4,5).map(e => (e, prettyExpensiveFunc(e))).collect {
case ...
case (x, y) if y => y
}
I tried writing my own matchers, and the result is OK but not perfect. My matcher is untyped, and making it fully typed is a bit ugly.
class Matcher[T,E](f:PartialFunction[T, E]) {
def unapply(z: T): Option[E] = if (f.isDefinedAt(z)) Some(f(z)) else None
}
def newMatcherAny[E](f:PartialFunction[Any, E]) = new Matcher(f)
def newMatcher[T,E](f:PartialFunction[T, E]) = new Matcher(f)
def prettyExpensiveFunc(x:Int) = {println(s"-- prettyExpensiveFunc($x)"); x%2+x*x}
val x = Seq(
Some(Right(22)),
Some(Right(10)),
Some(Left("Oh now")),
None
)
val PersonAgeRank = newMatcherAny { case Some(Right(x:Int)) => (x, prettyExpensiveFunc(x)) }
x.collect {
case PersonAgeRank(age, rank) if rank > 100 => println("age:"+age + " rank:" + rank)
}
https://scalafiddle.io/sf/hFbcAqH/3
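For comparison, the same idea can be written as a hand-rolled, fully typed extractor object. This is a minimal sketch under assumptions: the element type `Option[Either[String, (Int, Int)]]`, the name `Ranked`, and the stand-in body of `prettyExpensiveFunc` are all made up for illustration. The expensive value is computed inside `unapply`, so the guard and the body share it, and it should be evaluated only once per element that reaches that case.

```scala
object ExtractorSketch {
  var calls = 0 // only here to observe how often the expensive function runs

  // stand-in for the real expensive computation (assumption)
  def prettyExpensiveFunc(x: Int, y: Int): Int = { calls += 1; x * y }

  // Typed extractor: matches Some(Right((x, y))) and also yields the computed value
  object Ranked {
    def unapply(p: Option[Either[String, (Int, Int)]]): Option[(Int, Int, Int)] =
      p match {
        case Some(Right((x, y))) => Some((x, y, prettyExpensiveFunc(x, y)))
        case _                   => None
      }
  }

  val people: Seq[Option[Either[String, (Int, Int)]]] =
    Seq(Some(Right((2, 3))), Some(Right((-1, 4))), Some(Left("oops")), None)

  // `rank` is bound once by the extractor and reused in both guard and body
  val result: Seq[Int] = people.collect {
    case Ranked(_, _, rank) if rank > 0 => rank
  }
}
```

The trade-off is one extractor per shape you want to match, which is exactly the verbosity the generic `Matcher` above tries to avoid.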
Suppose that I'm trying to implement a very simple domain specific language with only one operation:
printLine(line)
Then I want to write a program that takes an integer n as input, prints something if n is divisible by 10k, and then calls itself with n + 1, until n reaches some maximum value N.
Omitting all syntactic noise caused by for-comprehensions, what I want is:
@annotation.tailrec def p(n: Int): Unit = {
if (n % 10000 == 0) printLine("line")
if (n > N) () else p(n + 1)
}
Essentially, it would be a kind of "fizzbuzz".
Here are a few attempts to implement this using the Free monad from Scalaz 7.3.0-M7:
import scalaz._
object Demo1 {
// define operations of a little domain specific language
sealed trait Lang[X]
case class PrintLine(line: String) extends Lang[Unit]
// define the domain specific language as the free monad of operations
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
// lift operations into the free monad
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
// write a program that is just a loop that prints current index
// after every few iteration steps
val mod = 100000
val N = 1000000
// straightforward syntax: deadly slow, exits with OutOfMemoryError
def p0(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- (if (i > N) ret else p0(i + 1))
} yield ()
// Same as above, but written out without `for`
def p1(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
}
// Same as above, with `map` attached to recursive call
def p2(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p2(i + 1).map{ ignore2 => () })
}
// Same as above, but without the `map`; performs ok.
def p3(i: Int): Prog[Unit] = {
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
if (i > N) ret else p3(i + 1)
}
}
// Variation of the above; Ok.
def p4(i: Int): Prog[Unit] = (for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
} yield ()).flatMap{ ignored2 =>
if (i > N) ret else p4(i + 1)
}
// try to use the variable returned by the last generator after yield,
// hope that the final `map` is optimized away (it's not optimized away...)
def p5(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
stopHere <- (if (i > N) ret else p5(i + 1))
} yield stopHere
// define an interpreter that translates the programs into Trampoline
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case PrintLine(l) => Trampoline.delay(println(l))
}
}
// try it out
def main(args: Array[String]): Unit = {
println("\n p0")
p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p1")
p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p2")
p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p3")
p3(0).foldMap(interpreter).run // ok
println("\n p4")
p4(0).foldMap(interpreter).run // ok
println("\n p5")
p5(0).foldMap(interpreter).run // OutOfMemory
}
}
Unfortunately, the straightforward translation (p0) seems to run with some kind of O(N^2) overhead, and crashes with an OutOfMemoryError. The problem seems to be that the for-comprehension appends a map{x => ()} after the recursive call to p0, which forces the Free monad to fill the entire memory with reminders to "finish 'p0' and then do nothing".
If I manually "unroll" the for comprehension, and write out the last flatMap explicitly (as in p3 and p4), then the problem goes away, and everything runs smoothly. This, however, is an extremely brittle workaround: the behavior of the program changes dramatically if we simply append a map(id) to it, and this map(id) isn't even visible in the code, because it is generated automatically by the for-comprehension.
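To make the hidden trailing map visible, here is a small self-contained sketch that has nothing to do with scalaz: `Box` is a made-up wrapper that merely records which combinators are called on it. It suggests that a two-generator for-comprehension with `yield` behaves like `flatMap` followed by an inner `map`.

```scala
object DesugarSketch {
  // Hypothetical wrapper that logs the combinators applied to it
  final case class Box[A](value: A, log: List[String]) {
    def flatMap[B](f: A => Box[B]): Box[B] = {
      val next = f(value)
      Box(next.value, (log :+ "flatMap") ++ next.log)
    }
    def map[B](f: A => B): Box[B] = Box(f(value), log :+ "map")
  }

  // for-comprehension with two generators and a yield
  val viaFor: Box[Unit] = for {
    _ <- Box(1, Nil)
    _ <- Box(2, Nil)
  } yield ()

  // what the compiler generates for it, written by hand
  val byHand: Box[Unit] = Box(1, Nil).flatMap(_ => Box(2, Nil).map(_ => ()))
}
```

Both values carry the log `List("flatMap", "map")`: the last generator is always followed by a `map`, even when the yielded value is just `()`.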
In this older post here: https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/
it has been repeatedly advised to wrap recursive calls into a suspend. Here is an attempt with Applicative instance and suspend:
import scalaz._
// Essentially same as in `Demo1`, but this time with
// an `Applicative` and an explicit `Suspend` in the
// `for`-comprehension
object Demo2 {
sealed trait Lang[H]
case class Const[H](h: H) extends Lang[H]
case class PrintLine[H](line: String) extends Lang[H]
implicit object Lang extends Applicative[Lang] {
def point[A](a: => A): Lang[A] = Const(a)
def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
case Const(x) => {
f match {
case Const(ab) => Const(ab(x))
case _ => throw new Error
}
}
case PrintLine(l) => PrintLine(l)
}
}
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
val mod = 100000
val N = 2000000
// try to suspend the entire second generator
def p7(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- Free.suspend(if (i > N) ret else p7(i + 1))
} yield ()
// try to suspend the recursive call
def p8(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- if (i > N) ret else Free.suspend(p8(i + 1))
} yield ()
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case Const(x) => Trampoline.done(x)
case PrintLine(l) =>
(Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
}
}
def main(args: Array[String]): Unit = {
p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
p8(0).foldMap(interpreter).run // same...
}
}
Inserting suspend did not really help: it's still slow, and crashes with OutOfMemoryErrors.
Should I use the suspend somehow differently?
Maybe there is some purely syntactic remedy that makes it possible to use for-comprehensions without generating the map in the end?
I'd really appreciate if someone could point out what I'm doing wrong here, and how to repair it.
That superfluous map added by the Scala compiler moves the recursion from tail position to non-tail position. The Free monad still makes this stack safe, but space complexity becomes O(N) instead of O(1). (Specifically, it is still not O(N^2).)
Whether it is possible to make scalac optimize that map away makes for a separate question (which I don't know the answer to).
I will try to illustrate what is going on when interpreting p1 versus p3. (I will ignore the translation to Trampoline, which is redundant (see below).)
p3 (i.e. without extra map)
Let me use the following shorthand:
def cont(i: Int): Unit => Prog[Unit] =
ignore1 => if (i > N) ret else p3(i + 1)
Now p3(0) is interpreted as follows
p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)
and so on... You see that the amount of memory needed at any point doesn't exceed some constant upper bound.
p1 (i.e. with extra map)
I will use the following shorthands:
def cont(i: Int): Unit => Prog[Unit] =
ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
def cpu: Unit => Prog[Unit] = // constant pure unit
ignore => Free.pure(())
Now p1(0) is interpreted as follows:
p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu
and so on... You see that the memory consumption depends linearly on N. We just moved the evaluation from stack to heap.
Take away: To keep Free memory friendly, keep the recursion in "tail position", that is, on the right hand-side of flatMap (or map).
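The shape of a loop that follows this advice can be sketched with the standard library's `scala.util.control.TailCalls` instead of scalaz `Free`. That is a deliberate substitution to keep the example dependency-free; it illustrates the structure only and says nothing about the memory behavior of `Free`. The names `loop`, `N`, and `printedLines` are made up. The recursive call is the final result of the `flatMap` continuation, with no trailing `map` after it.

```scala
import scala.util.control.TailCalls._

object TailPositionSketch {
  val N = 100000

  // Counts how many lines would have been printed instead of printing them.
  def loop(i: Int, printed: Int): TailRec[Int] =
    (if (i % 10000 == 0) done(printed + 1) else done(printed)).flatMap { p =>
      // recursion on the right-hand side of flatMap, nothing appended after it
      if (i >= N) done(p) else tailcall(loop(i + 1, p))
    }

  val printedLines: Int = loop(0, 0).result
}
```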
Aside: The translation to Trampoline is not necessary, since Free is already trampolined. You could interpret directly to Id and use foldMapRec for stack-safe interpretation:
val idInterpreter = new (Lang ~> Id) {
def apply[A](cmd: Lang[A]): Id[A] = cmd match {
case PrintLine(l) => println(l)
}
}
p0(0).foldMapRec(idInterpreter)
This will regain you some fraction of memory (but doesn't make the problem go away).
I am new to Scala. I need to count the number of categories in the list, and I am trying to build a tail-recursive function, without any success.
case class Category(name:String, children: List[Category])
val lists = List(
Category("1",
List(Category("1.1",
List(Category("1.2", Nil))
))
)
,Category("2", Nil),
Category("3",
List(Category("3.1", Nil))
)
)
Nyavro's solution can be made much faster (by several orders of magnitude) if you use Lists instead of Streams and also append elements at the front.
That's because x.children is usually a lot shorter than xs and Scala's List is an immutable singly linked list making prepend operations a lot faster than append operations.
Here is an example
import scala.annotation.tailrec
case class Category(name:String, children: List[Category])
@tailrec
def childCount(cats:Stream[Category], acc:Int):Int =
cats match {
case Stream.Empty => acc
case x #:: xs => childCount(xs ++ x.children, acc+1)
}
@tailrec
def childCount2(cats: List[Category], acc:Int): Int =
cats match {
case Nil => acc
case x :: xs => childCount2(x.children ++ xs, acc + 1)
}
def generate(depth: Int, children: Int): List[Category] = {
if(depth == 0) Nil
else (0 until children).map(i => Category("abc", generate(depth - 1, children))).toList
}
val list = generate(8, 3)
var start = System.nanoTime
var count = childCount(list.toStream, 0)
var end = System.nanoTime
println("count: " + count)
println("time: " + ((end - start)/1e6) + "ms")
start = System.nanoTime
count = childCount2(list, 0)
end = System.nanoTime
println("count: " + count)
println("time: " + ((end - start)/1e6) + "ms")
output:
count: 9840
time: 2226.761485ms
count: 9840
time: 3.90171ms
Consider the following idea.
Let's define a function childCount that takes a collection of categories (cats) and the number of categories counted so far (acc). To make the processing tail-recursive, we take the first element from the collection and increment acc. That element is now processed, but it leaves more items to process: its children. The idea is to append these unprocessed children to the end of the collection and call childCount again.
You can implement it this way:
@tailrec
def childCount(cats:Stream[Category], acc:Int):Int =
cats match {
case Stream.Empty => acc
case x #:: xs => childCount(xs ++ x.children, acc+1)
}
call it:
val count = childCount(lists.toStream, 0)
I came up with the following to convert a List[Int] => Try[BigDecimal]:
import scala.util.Try
def f(xs: List[Int]): Try[BigDecimal] =
Try { xs.mkString.toInt }.map ( BigDecimal(_) )
Example:
scala> f(List(1,2,3,4))
res4: scala.util.Try[BigDecimal] = Success(1234)
scala> f(List(1,2,3,55555))
res5: scala.util.Try[BigDecimal] = Success(12355555)
Is there a way to write this function without resorting to a String conversion step?
Not very pretty, and I'm not convinced it's much more efficient. Here's the basic outline.
val pwrs:Stream[BigInt] = 10 #:: pwrs.map(_ * 10)
List(1,2,3,55555).foldLeft(0:BigInt)((p,i) => pwrs.find(_ > i).get * p + i)
Here it is a little more fleshed out with error handling.
import scala.util.Try
def f(xs: List[Int]): Try[BigDecimal] = Try {
lazy val pwrs: Stream[BigDecimal] = 10 #:: pwrs.map(_ * 10)
xs.foldLeft(0: BigDecimal) {
case (acc, i) if i >= 0 => pwrs.find(_ > i).get * acc + i
case _ => throw new Error("bad")
}
}
UPDATE
Just for giggles, I thought I'd plug some code into Rex Kerr's handy benchmarking/profiling tool, Thyme.
the code
import scala.util.Try
def fString(xs: List[Int]): Try[BigInt] = Try { BigInt(xs.mkString) }
def fStream(xs: List[Int]): Try[BigInt] = Try {
lazy val pwrs: Stream[BigInt] = 10 #:: pwrs.map(_ * 10)
xs.foldLeft(0: BigInt) {
case (acc, i) if i >= 0 => pwrs.find(_ > i).get * acc + i
case _ => throw new Error("bad")
}
}
def fLog10(xs: List[Int]): Try[BigInt] = Try {
xs.foldLeft(0: BigInt) {
case (acc, i) if i >= 0 =>
// NOTE: ceil(log10(i)) undercounts the digits when i is 0, 1, or an exact power of 10
math.pow(10, math.ceil(math.log10(i))).toInt * acc + i
case _ => throw new Error("bad")
}
}
fString() is a slight simplification of Kevin's original question. fStream() is my proposed non-string implementation. fLog10 is the same but with Alexey's suggested enhancement.
You'll note that I'm using BigInt instead of BigDecimal. I found that both non-string methods encountered a bug somewhere around the 37th digit of the result. Some kind of rounding error or something, but there was no problem with BigInt so that's what I used.
test setup
// create a List of 40 Ints and check its contents
val lst = List.fill(40)(util.Random.nextInt(20000))
lst.min // 5
lst.max // 19858
lst.mkString.length // 170
val th = ichi.bench.Thyme.warmed(verbose = print)
th.pbenchWarm(th.Warm(fString(lst)), title="fString")
th.pbenchWarm(th.Warm(fStream(lst)), title="fStream")
th.pbenchWarm(th.Warm(fLog10(lst)), title="fLog10")
results
Benchmark for fString (20 calls in 345.6 ms) Time: 4.015 us 95%
CI 3.957 us - 4.073 us (n=19) Garbage: 109.9 ns (n=2 sweeps
measured)
Benchmark for fStream (20 calls in 305.6 ms) Time: 7.118 us 95%
CI 7.024 us - 7.213 us (n=19) Garbage: 293.0 ns (n=3 sweeps
measured)
Benchmark for fLog10 (20 calls in 382.8 ms) Time: 9.205 us 95%
CI 9.187 us - 9.222 us (n=17) Garbage: 73.24 ns (n=2 sweeps
measured)
So I was right about the efficiency of the non-string algorithm. Oddly, using math._ to avoid Stream creation doesn't appear to be better. I didn't expect that.
takeaway
Number-to-string and string-to-number transitions are reasonably efficient.
import scala.util.{Try, Success}
import scala.annotation.tailrec
def findOrder(i: Int): Long = {
@tailrec
def _findOrder(i: Int, order: Long): Long = {
if (i < order) order
else _findOrder(i, order * 10)
}
_findOrder(i, 10) // start at 10 so that a single digit (including 0) shifts acc by one place
}
def f(xs: List[Int]): Try[BigDecimal] = Try(
xs.foldLeft(BigDecimal(0))((acc, i) => acc * findOrder(i) + i)
)
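To find the correct power of 10 more efficiently (replace pwrs.find(_ > i).get with nextPowerOf10(i) in @jwvh's answer):
def nextPowerOf10(x: Int) = {
// number of decimal digits of x (assumes x >= 0); floor + 1 rather than ceil,
// so that 0, 1, 10, 100, ... are also shifted by the right amount
val n = if (x == 0) 1 else math.floor(math.log10(x)).toInt + 1
BigDecimal(10).pow(n)
}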
Since you start with an Int, there should be no rounding issues.
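The edge cases are easy to get wrong with logarithms, so here is a hedged sketch that counts digits by repeated integer division instead. The names `ConcatSketch`, `digits`, and `concat` are illustrative only, and it returns a `BigInt` (as the benchmark code in this thread does) rather than a `BigDecimal`.

```scala
import scala.annotation.tailrec
import scala.util.Try

object ConcatSketch {
  // Number of decimal digits of a non-negative Int; 0 counts as one digit.
  def digits(x: Int): Int = {
    require(x >= 0, s"negative element: $x")
    @tailrec
    def go(v: Int, n: Int): Int = if (v < 10) n else go(v / 10, n + 1)
    go(x, 1)
  }

  // Shift the accumulator left by the digit count of each element, then add it.
  def concat(xs: List[Int]): Try[BigInt] = Try {
    xs.foldLeft(BigInt(0))((acc, i) => acc * BigInt(10).pow(digits(i)) + i)
  }
}
```

For example, `ConcatSketch.concat(List(1, 0, 10))` yields `Success(1010)`, which a `ceil(log10(i))` based shift would not, and a negative element turns into a `Failure` via `require`.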
What's the best way to terminate a fold early? As a simplified example, imagine I want to sum up the numbers in an Iterable, but if I encounter something I'm not expecting (say an odd number) I might want to terminate. This is a first approximation
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => None
}
}
However, this solution is pretty ugly (as in, if I did a .foreach and a return -- it'd be much cleaner and clearer) and worst of all, it traverses the entire iterable even if it encounters a non-even number.
So what would be the best way to write a fold like this, that terminates early? Should I just go and write this recursively, or is there a more accepted way?
My first choice would usually be to use recursion. It is only moderately less compact, is potentially faster (certainly no slower), and in early termination can make the logic more clear. In this case you need nested defs which is a little awkward:
def sumEvenNumbers(nums: Iterable[Int]) = {
def sumEven(it: Iterator[Int], n: Int): Option[Int] = {
if (it.hasNext) {
val x = it.next
if ((x % 2) == 0) sumEven(it, n+x) else None
}
else Some(n)
}
sumEven(nums.iterator, 0)
}
My second choice would be to use return, as it keeps everything else intact and you only need to wrap the fold in a def so you have something to return from--in this case, you already have a method, so:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
Some(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) return None // test the element x, not the running sum n
n+x
})
}
which in this particular case is a lot more compact than recursion (though we got especially unlucky with recursion since we had to do an iterable/iterator transformation). The jumpy control flow is something to avoid when all else is equal, but here it's not. No harm in using it in cases where it's valuable.
If I was doing this often and wanted it within the middle of a method somewhere (so I couldn't just use return), I would probably use exception-handling to generate non-local control flow. That is, after all, what it is good at, and error handling is not the only time it's useful. The only trick is to avoid generating a stack trace (which is really slow), and that's easy because the trait NoStackTrace and its child trait ControlThrowable already do that for you. Scala already uses this internally (in fact, that's how it implements the return from inside the fold!). Let's make our own (can't be nested, though one could fix that):
import scala.util.control.ControlThrowable
case class Returned[A](value: A) extends ControlThrowable {}
def shortcut[A](a: => A) = try { a } catch { case Returned(v) => v }
def sumEvenNumbers(nums: Iterable[Int]) = shortcut{
Option(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) throw Returned(None)
n+x
})
}
Here of course using return is better, but note that you could put shortcut anywhere, not just wrapping an entire method.
Next in line for me would be to re-implement fold (either myself or to find a library that does it) so that it could signal early termination. The two natural ways of doing this are to not propagate the value but an Option containing the value, where None signifies termination; or to use a second indicator function that signals completion. The Scalaz lazy fold shown by Kim Stebel already covers the first case, so I'll show the second (with a mutable implementation):
def foldOrFail[A,B](it: Iterable[A])(zero: B)(fail: A => Boolean)(f: (B,A) => B): Option[B] = {
val ii = it.iterator
var b = zero
while (ii.hasNext) {
val x = ii.next
if (fail(x)) return None
b = f(b,x)
}
Some(b)
}
def sumEvenNumbers(nums: Iterable[Int]) = foldOrFail(nums)(0)(_ % 2 != 0)(_ + _)
(Whether you implement the termination by recursion, return, laziness, etc. is up to you.)
I think that covers the main reasonable variants; there are some other options also, but I'm not sure why one would use them in this case. (Iterator itself would work well if it had a findOrPrevious, but it doesn't, and the extra work it takes to do that by hand makes it a silly option to use here.)
The scenario you describe (exit upon some unwanted condition) seems like a good use case for the takeWhile method. It is essentially filter, but it stops at the first element that doesn't meet the condition.
For example:
val list = List(2,4,6,8,6,4,2,5,3,2)
list.takeWhile(_ % 2 == 0) //result is List(2,4,6,8,6,4,2)
This will work just fine for Iterators/Iterables too. The solution I suggest for your "sum of even numbers, but break on odd" is:
list.iterator.takeWhile(_ % 2 == 0).foldLeft(...)
And just to prove that it's not wasting your time once it hits an odd number...
scala> val list = List(2,4,5,6,8)
list: List[Int] = List(2, 4, 5, 6, 8)
scala> def condition(i: Int) = {
| println("processing " + i)
| i % 2 == 0
| }
condition: (i: Int)Boolean
scala> list.iterator.takeWhile(condition _).sum
processing 2
processing 4
processing 5
res4: Int = 6
You can do what you want in a functional style using the lazy version of foldRight in scalaz. For a more in depth explanation, see this blog post. While this solution uses a Stream, you can convert an Iterable into a Stream efficiently with iterable.toStream.
import scalaz._
import Scalaz._
val str = Stream(2,1,2,2,2,2,2,2,2)
var i = 0 //only here for testing
val r = str.foldr(Some(0):Option[Int])((n,s) => {
println(i)
i+=1
if (n % 2 == 0) s.map(n+) else None
})
This only prints
0
1
which clearly shows that the anonymous function is only called twice (i.e. until it encounters the odd number). That is due to the definition of foldr, whose signature (in case of Stream) is def foldr[B](b: B)(f: (Int, => B) => B)(implicit r: scalaz.Foldable[Stream]): B. Note that the anonymous function takes a by-name parameter as its second argument, so it need not be evaluated.
Btw, you can still write this with the OP's pattern matching solution, but I find if/else and map more elegant.
Well, Scala does allow non local returns. There are differing opinions on whether or not this is a good style.
scala> def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
| nums.foldLeft (Some(0): Option[Int]) {
| case (None, _) => return None
| case (Some(s), n) if n % 2 == 0 => Some(s + n)
| case (Some(_), _) => None
| }
| }
sumEvenNumbers: (nums: Iterable[Int])Option[Int]
scala> sumEvenNumbers(2 to 10)
res8: Option[Int] = None
scala> sumEvenNumbers(2 to 10 by 2)
res9: Option[Int] = Some(30)
EDIT:
In this particular case, as @Arjan suggested, you can also do:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => return None
}
}
You can use foldM from the cats library (as suggested by @Didac), but I suggest using Either instead of Option if you want to get the actual sum out.
bifoldMap is used to extract the result from Either.
import cats.implicits._
def sumEven(nums: Stream[Int]): Either[Int, Int] = {
nums.foldM(0) {
case (acc, n) if n % 2 == 0 => Either.right(acc + n)
case (acc, n) => {
println(s"Stopping on number: $n")
Either.left(acc)
}
}
}
examples:
println("Result: " + sumEven(Stream(2, 2, 3, 11)).bifoldMap(identity, identity))
> Stopping on number: 3
> Result: 4
println("Result: " + sumEven(Stream(2, 7, 2, 3)).bifoldMap(identity, identity))
> Stopping on number: 7
> Result: 2
Cats has a method called foldM which does short-circuiting (for Vector, List, Stream, ...).
It works as follows:
def sumEvenNumbers(nums: Stream[Int]): Option[Long] = {
import cats.implicits._
nums.foldM(0L) {
case (acc, c) if c % 2 == 0 => Some(acc + c)
case _ => None
}
}
If it finds an odd element, it returns None without computing the rest; otherwise it returns the sum of the even entries.
If you want to keep the sum accumulated until an odd entry is found, you should use an Either[Long, Long].
@Rex Kerr, your answer helped me, but I needed to tweak it to use Either:
def foldOrFail[A,B,C,D](map: B => Either[D, C])(merge: (A, C) => A)(initial: A)(it: Iterable[B]): Either[D, A] = {
val ii= it.iterator
var b= initial
while (ii.hasNext) {
val x= ii.next
map(x) match {
case Left(error) => return Left(error)
case Right(d) => b= merge(b, d)
}
}
Right(b)
}
You could try using a temporary var and using takeWhile. Here is a version.
var continue = true
// sample stream of 2's and then a stream of 3's.
val evenSum = (Stream.fill(10)(2) ++ Stream.fill(10)(3)).takeWhile(_ => continue)
.foldLeft(Option[Int](0)){
case (result,i) if i%2 != 0 =>
continue = false;
// return whatever is appropriate either the accumulated sum or None.
result
case (optionSum,i) => optionSum.map( _ + i)
}
The evenSum should be Some(20) in this case.
You can throw a well-chosen exception upon encountering your termination criterion, handling it in the calling code.
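A minimal sketch of that approach, assuming a dedicated exception without a stack trace is acceptable; the names `EarlyExitSketch` and `OddFound` are made up for illustration:

```scala
import scala.util.control.NoStackTrace

object EarlyExitSketch {
  // Marker exception; NoStackTrace avoids the cost of filling in a stack trace
  private case object OddFound extends RuntimeException with NoStackTrace

  def sumEvenNumbers(nums: Iterable[Int]): Option[Int] =
    try {
      Some(nums.foldLeft(0) { (acc, n) =>
        if (n % 2 != 0) throw OddFound // abandon the fold at the first odd number
        acc + n
      })
    } catch {
      case OddFound => None
    }
}
```

Here the fold is wrapped directly in the method that handles the exception, so the non-local jump stays contained in one place.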
A more beautiful solution would be to use span:
val (l, r) = numbers.span(_ % 2 == 0)
if(r.isEmpty) Some(l.sum)
else None
... but it traverses the list two times if all the numbers are even
Just for "academic" reasons (:
var headers = Source.fromFile(file).getLines().next().split(",")
var closeHeaderIdx = headers.takeWhile { s => !"Close".equals(s) }.foldLeft(0)((i, _) => i+1)
It does twice the work it should, but it is a nice one-liner.
If "Close" is not found, it returns
headers.size
Another (better) option is this one:
var headers = Source.fromFile(file).getLines().next().split(",").toList
var closeHeaderIdx = headers.indexOf("Close")