Using Streams ... still got Out of Memory - scala

I am trying to generate a file which contains very large test data in the form of a JSON array.
Since my test data is really big I cannot use "mkString". This is also why I am using Streams and tail recursion.
But my program still gets an Out of Memory exception.
package com.abhi
import java.io.FileWriter
import scala.annotation.tailrec
object GenerateTestFile {
def jsonField(name: String, value : String) : String = {
s""""$name":"$value""""
}
def writeRecords(records : Stream[String], fw : FileWriter) : Unit = {
@tailrec
def inner(records: Stream[String]) : Unit = {
records match {
case head #:: Stream.Empty => fw.write(head)
case head #:: tail => fw.write(head + ","); inner(tail)
case Stream.Empty =>
}
}
inner(records)
}
def main(args: Array[String]) : Unit = {
val fileWriter = new FileWriter("./sample.json", true)
fileWriter.write("[")
val lines = (1 to 10000000)
.toStream
.map(x => (
jsonField("id", x.toString),
jsonField("FieldA", "FieldA" + x),
jsonField("FieldB", "FieldB" +x),
jsonField("FieldC", "FieldC" + x),
jsonField("FieldD", "FieldD" + x),
jsonField("FieldE", "FieldE" + x)
))
.map (t => t.productIterator.mkString("{",",", "}"))
writeRecords(lines, fileWriter)
fileWriter.write("]")
fileWriter.close()
}
}
Exception
[error] (run-main-0) java.lang.OutOfMemoryError: GC overhead limit exceeded
java.lang.OutOfMemoryError: GC overhead limit exceeded
at java.util.Arrays.copyOf(Arrays.java:2367)
at java.lang.AbstractStringBuilder.expandCapacity(AbstractStringBuilder.java:130)
at java.lang.AbstractStringBuilder.ensureCapacityInternal(AbstractStringBuilder.java:114)
at java.lang.AbstractStringBuilder.append(AbstractStringBuilder.java:415)
at java.lang.StringBuilder.append(StringBuilder.java:132)
at java.lang.StringBuilder.append(StringBuilder.java:128)
at scala.StringContext.standardInterpolator(StringContext.scala:125)
at scala.StringContext.s(StringContext.scala:95)
at com.abhi.GenerateTestFile$.jsonField(GenerateTestFile.scala:9)
at com.abhi.GenerateTestFile$$anonfun$1.apply(GenerateTestFile.scala:32)

What you are looking for is not a stream but an iterator: something that materialises the values one by one and, once an element has been processed (i.e. written to the file), lets go of the memory (a small Iterator-based sketch follows the code below).
When you create your JSON you actually hold in memory the entire sequence of numbers, and on top of that you generate, for each element, a new large text blob which you later write to the file. Memory-wise, the initial sequence is insignificant compared to the size of the text.
What I did was use a for comprehension to produce the elements one by one. With foldLeft I make sure that I map each element to a JSON string, write it to disk and release the memory (the reference to any created object is lost, so the GC can kick in and reclaim it). Unfortunately, with this approach you cannot make use of the parallelism features.
def main(args: Array[String]): Unit = {
val fileWriter = new FileWriter("./sample.json", true)
fileWriter.write("[")
fileWriter.write(createObject(1).productIterator.mkString("{", ",", "}"))
val lines = (for (v <- 2 to 10000000) yield v)
.foldLeft(0)((_, x) => {
if (x % 50000 == 0)
println(s"We've reached $x element")
fileWriter.write(createObject(x).productIterator.mkString(",{", ",", "}"))
x
})
fileWriter.write("]")
fileWriter.close()
}
def createObject(x: Int) =
(jsonField("id", x.toString),
jsonField("FieldA", "FieldA" + x),
jsonField("FieldB", "FieldB" + x),
jsonField("FieldC", "FieldC" + x),
jsonField("FieldD", "FieldD" + x),
jsonField("FieldE", "FieldE" + x))

from the source of Stream:
* - One must be cautious of memoization; you can very quickly eat up large
* amounts of memory if you're not careful. The reason for this is that the
* memoization of the `Stream` creates a structure much like
* [[scala.collection.immutable.List]]. So long as something is holding on to
* the head, the head holds on to the tail, and so it continues recursively.
* If, on the other hand, there is nothing holding on to the head (e.g. we used
* `def` to define the `Stream`) then once it is no longer being used directly,
* it disappears.
So, if you inline lines (or define it with a def), and rewrite your writeRecords function so that it does not hold onto a reference to the initial head (or declare the parameter as a 'call by name' value using an arrow, records: => Stream[String], which does basically the same thing as def vs. val; a sketch of that variant follows the code below), elements should be garbage collected as the rest of the stream is processed:
@tailrec
def writeRecords(records : Stream[String], fw : FileWriter) : Unit = {
records match {
case head #:: Stream.Empty => fw.write(head)
case head #:: tail => fw.write(head + ","); writeRecords(tail, fw)
case Stream.Empty =>
}
}
writeRecords((1 to 10000000)
.toStream
.map(x => (
jsonField("id", x.toString),
jsonField("FieldA", "FieldA" + x),
jsonField("FieldB", "FieldB" +x),
jsonField("FieldC", "FieldC" + x),
jsonField("FieldD", "FieldD" + x),
jsonField("FieldE", "FieldE" + x)
))
.map (t => t.productIterator.mkString("{",",", "}")),
fileWriter)
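And here is a sketch of the by-name variant mentioned above (an assumption about how one would apply it, not code from the original answer). The outer parameter is then only a thunk, so nothing outside the tail-recursive loop keeps the head alive, provided the caller passes the stream expression inline (or through a def) rather than through a val:
def writeRecords(records: => Stream[String], fw: FileWriter): Unit = {
  @tailrec
  def inner(rs: Stream[String]): Unit = rs match {
    case head #:: Stream.Empty => fw.write(head)
    case head #:: tail => fw.write(head + ","); inner(tail)
    case _ =>
  }
  // forcing the thunk here builds the stream; only inner's parameter (overwritten on
  // every tail call) holds a reference to it, so processed heads can be collected
  inner(records)
}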

Related

Is it faster to create a new Map or clear it and use again?

I need to use many Maps in my project so I wonder which way is more efficient:
val map: mutable.Map[Int, Int] = mutable.Map.empty
for (_ <- 0 until big_number)
{
// do something with map
map.clear()
}
or
for (_ <- 0 until big_number)
{
val map: mutable.Map[Int, Int] = mutable.Map.empty
// do something with map
}
to use in terms of time and memory?
Well, my formal answer would always be: it depends. You need to benchmark your own scenario and see what fits it better. I'll provide an example of how you can benchmark your own code. Let's start with writing a measuring method:
def measure(name: String, f: () => Unit): Unit = {
val l = System.currentTimeMillis()
println(name + ": " + (System.currentTimeMillis() - l)) // marks the start (prints roughly 0)
f()
println(name + ": " + (System.currentTimeMillis() - l)) // total elapsed time in milliseconds
}
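For example, to time one variant on its own you would pass its body as a thunk (a hypothetical usage of the helper above, separate from the benchmark that follows):
measure("outer", () => {
  val map = scala.collection.mutable.Map.empty[Int, Int]
  for (i <- 0 until 1000000) {
    map(i % 100) = i // do something with the map
    map.clear()
  }
})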
Let's assume that in each iteration we need to insert into the map one key-value pair, and then to print it:
import scala.collection.mutable
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

Await.result(Future.sequence(Seq(Future {
measure("inner", () => {
for (i <- 0 until 10) {
val map2 = mutable.Map.empty[Int, Int]
map2(i) = i
println(map2)
}
})
},
Future {
measure("outer", () => {
val map1 = mutable.Map.empty[Int, Int]
for (i <- 0 until 10) {
map1(i) = i
println(map1)
map1.clear()
}
})
})), 10.seconds)
The output in this case is almost always equal between the inner and the outer. Please note that I run the two options in parallel; if I didn't, the first one would always take significantly more time, no matter which of them runs first.
Therefore, we can conclude that in this case they are almost the same.
But, if for example I add an immutable option:
Future {
measure("immutable", () => {
for (i <- 0 until 10) {
val map1 = Map[Int, Int](i -> i)
println(map1)
}
})
}
it always finishes first. This makes sense because immutable collections are much more performant than the mutable ones.
For better performance tests you probably need to use a third-party library such as ScalaMeter, or others that exist.
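As a rough sketch (assuming ScalaMeter's quick-benchmark API; check its documentation for the exact keys and defaults), an inline measurement with a warm-up phase could look like this:
import org.scalameter._
val time = config(
  Key.exec.benchRuns -> 20
) withWarmer {
  new Warmer.Default
} measure {
  val map = scala.collection.mutable.Map.empty[Int, Int]
  for (i <- 0 until 100000) map(i) = i
}
println("mutable map fill: " + time)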

Free ~> Trampoline : recursive program crashes with OutOfMemoryError

Suppose that I'm trying to implement a very simple domain specific language with only one operation:
printLine(line)
Then I want to write a program that takes an integer n as input, prints something if n is divisible by 10k, and then calls itself with n + 1, until n reaches some maximum value N.
Omitting all syntactic noise caused by for-comprehensions, what I want is:
@annotation.tailrec def p(n: Int): Unit = {
if (n % 10000 == 0) printLine("line")
if (n > N) () else p(n + 1)
}
Essentially, it would be a kind of "fizzbuzz".
Here are a few attempts to implement this using the Free monad from Scalaz 7.3.0-M7:
import scalaz._
object Demo1 {
// define operations of a little domain specific language
sealed trait Lang[X]
case class PrintLine(line: String) extends Lang[Unit]
// define the domain specific language as the free monad of operations
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
// lift operations into the free monad
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
// write a program that is just a loop that prints current index
// after every few iteration steps
val mod = 100000
val N = 1000000
// straightforward syntax: deadly slow, exits with OutOfMemoryError
def p0(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- (if (i > N) ret else p0(i + 1))
} yield ()
// Same as above, but written out without `for`
def p1(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
}
// Same as above, with `map` attached to recursive call
def p2(i: Int): Prog[Unit] =
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
(if (i > N) ret else p2(i + 1).map{ ignore2 => () })
}
// Same as above, but without the `map`; performs ok.
def p3(i: Int): Prog[Unit] = {
(if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
ignore1 =>
if (i > N) ret else p3(i + 1)
}
}
// Variation of the above; Ok.
def p4(i: Int): Prog[Unit] = (for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
} yield ()).flatMap{ ignored2 =>
if (i > N) ret else p4(i + 1)
}
// try to use the variable returned by the last generator after yield,
// hope that the final `map` is optimized away (it's not optimized away...)
def p5(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
stopHere <- (if (i > N) ret else p5(i + 1))
} yield stopHere
// define an interpreter that translates the programs into Trampoline
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case PrintLine(l) => Trampoline.delay(println(l))
}
}
// try it out
def main(args: Array[String]): Unit = {
println("\n p0")
p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p1")
p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p2")
p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
println("\n p3")
p3(0).foldMap(interpreter).run // ok
println("\n p4")
p4(0).foldMap(interpreter).run // ok
println("\n p5")
p5(0).foldMap(interpreter).run // OutOfMemory
}
}
Unfortunately, the straightforward translation (p0) seems to run with some kind of O(N^2) overhead, and crashes with an OutOfMemoryError. The problem seems to be that the for-comprehension appends a map{x => ()} after the recursive call to p0, which forces the Free monad to fill the entire memory with reminders to "finish 'p0' and then do nothing".
If I manually "unroll" the for comprehension, and write out the last flatMap explicitly (as in p3 and p4), then the problem goes away, and everything runs smoothly. This, however, is an extremely brittle workaround: the behavior of the program changes dramatically if we simply append a map(id) to it, and this map(id) isn't even visible in the code, because it is generated automatically by the for-comprehension.
In this older post here: https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/
it has been repeatedly advised to wrap recursive calls in a suspend. Here is an attempt with an Applicative instance and suspend:
import scalaz._
// Essentially same as in `Demo1`, but this time with
// an `Applicative` and an explicit `Suspend` in the
// `for`-comprehension
object Demo2 {
sealed trait Lang[H]
case class Const[H](h: H) extends Lang[H]
case class PrintLine[H](line: String) extends Lang[H]
implicit object Lang extends Applicative[Lang] {
def point[A](a: => A): Lang[A] = Const(a)
def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
case Const(x) => {
f match {
case Const(ab) => Const(ab(x))
case _ => throw new Error
}
}
case PrintLine(l) => PrintLine(l)
}
}
type Prog[X] = Free[Lang, X]
import Free.{liftF, pure}
def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
def ret: Prog[Unit] = Free.pure(())
val mod = 100000
val N = 2000000
// try to suspend the entire second generator
def p7(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- Free.suspend(if (i > N) ret else p7(i + 1))
} yield ()
// try to suspend the recursive call
def p8(i: Int): Prog[Unit] = for {
_ <- (if (i % mod == 0) printLine("i = " + i) else ret)
_ <- if (i > N) ret else Free.suspend(p8(i + 1))
} yield ()
import scalaz.Trampoline
type Exec[X] = Free.Trampoline[X]
val interpreter = new (Lang ~> Exec) {
def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
case Const(x) => Trampoline.done(x)
case PrintLine(l) =>
(Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
}
}
def main(args: Array[String]): Unit = {
p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
p8(0).foldMap(interpreter).run // same...
}
}
Inserting suspend did not really help: it's still slow, and crashes with OutOfMemoryErrors.
Should I use the suspend somehow differently?
Maybe there is some purely syntactic remedy that makes it possible to use for-comprehensions without generating the map in the end?
I'd really appreciate if someone could point out what I'm doing wrong here, and how to repair it.
That superfluous map added by the Scala compiler moves the recursion from tail position to non-tail position. The Free monad still makes this stack-safe, but the space complexity becomes O(N) instead of O(1). (Specifically, it is still not O(N^2).)
Whether it is possible to make scalac optimize that map away makes for a separate question (which I don't know the answer to).
I will try to illustrate what is going on when interpreting p1 versus p3. (I will ignore the translation to Trampoline, which is redundant (see below).)
p3 (i.e. without extra map)
Let me use the following shorthand:
def cont(i: Int): Unit => Prog[Unit] =
ignore1 => if (i > N) ret else p3(i + 1)
Now p3(0) is interpreted as follows
p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)
and so on... You see that the amount of memory needed at any point doesn't exceed some constant upper bound.
p1 (i.e. with extra map)
I will use the following shorthands:
def cont(i: Int): Unit => Prog[Unit] =
ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
def cpu: Unit => Prog[Unit] = // constant pure unit
ignore => Free.pure(())
Now p1(0) is interpreted as follows:
p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu
and so on... You see that the memory consumption depends linearly on N. We just moved the evaluation from stack to heap.
Take away: To keep Free memory friendly, keep the recursion in "tail position", that is, on the right hand-side of flatMap (or map).
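In the code above, p3 and p4 already have that shape. Schematically (step is a hypothetical per-iteration effect standing in for the printLine/ret choice, not a definition from the original program):
def loop(i: Int): Prog[Unit] =
  step(i).flatMap { _ =>
    // tail position: the recursive call is returned as-is, with no map wrapped around it
    if (i > N) ret else loop(i + 1)
  }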
Aside: The translation to Trampoline is not necessary, since Free is already trampolined. You could interpret directly to Id and use foldMapRec for stack-safe interpretation:
val idInterpreter = new (Lang ~> Id) {
def apply[A](cmd: Lang[A]): Id[A] = cmd match {
case PrintLine(l) => println(l)
}
}
p0(0).foldMapRec(idInterpreter)
This will regain you some fraction of memory (but doesn't make the problem go away).

How to count number of total items where a class references itself

I am new to Scala. I need to count the number of categories in the list, and I am trying to build a tail-recursive function, without any success.
case class Category(name:String, children: List[Category])
val lists = List(
Category("1",
List(Category("1.1",
List(Category("1.2", Nil))
))
)
,Category("2", Nil),
Category("3",
List(Category("3.1", Nil))
)
)
Nyavro's solution can be made much faster (by several orders of magnitude) if you use Lists instead of Streams and also add the children at the front of the work list instead of appending them at the back.
That's because x.children is usually a lot shorter than xs, and Scala's List is an immutable singly linked list, which makes prepend operations a lot faster than append operations.
Here is an example
import scala.annotation.tailrec
case class Category(name:String, children: List[Category])
@tailrec
def childCount(cats:Stream[Category], acc:Int):Int =
cats match {
case Stream.Empty => acc
case x #:: xs => childCount(xs ++ x.children, acc+1)
}
@tailrec
def childCount2(cats: List[Category], acc:Int): Int =
cats match {
case Nil => acc
case x :: xs => childCount2(x.children ++ xs, acc + 1)
}
def generate(depth: Int, children: Int): List[Category] = {
if(depth == 0) Nil
else (0 until children).map(i => Category("abc", generate(depth - 1, children))).toList
}
val list = generate(8, 3)
var start = System.nanoTime
var count = childCount(list.toStream, 0)
var end = System.nanoTime
println("count: " + count)
println("time: " + ((end - start)/1e6) + "ms")
start = System.nanoTime
count = childCount2(list, 0)
end = System.nanoTime
println("count: " + count)
println("time: " + ((end - start)/1e6) + "ms")
output:
count: 9840
time: 2226.761485ms
count: 9840
time: 3.90171ms
Consider the following idea.
Let's define a function childCount, taking a collection of categories (cats) and the count so far (acc). To organize tail-recursive processing we take the first category from the collection and increment acc. We have now processed the first item, but there are more items to process: the children of the first element. The idea is to put these unprocessed children at the end of the collection and call childCount again.
You can implement it this way:
@tailrec
def childCount(cats:Stream[Category], acc:Int):Int =
cats match {
case Stream.Empty => acc
case x #:: xs => childCount(xs ++ x.children, acc+1)
}
call it:
val count = childCount(lists.toStream, 0)
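For contrast, here is a minimal non-tail-recursive sketch (not from either answer) that computes the same count; it is shorter, but it can overflow the stack on very deep category trees:
def childCountSimple(cats: List[Category]): Int =
  cats.map(c => 1 + childCountSimple(c.children)).sum
childCountSimple(lists) // 6, same as childCount(lists.toStream, 0)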

Abort early in a fold

What's the best way to terminate a fold early? As a simplified example, imagine I want to sum up the numbers in an Iterable, but if I encounter something I'm not expecting (say an odd number) I might want to terminate. This is a first approximation
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => None
}
}
However, this solution is pretty ugly (as in, if I did a .foreach and a return -- it'd be much cleaner and clearer) and worst of all, it traverses the entire iterable even if it encounters a non-even number.
So what would be the best way to write a fold like this, that terminates early? Should I just go and write this recursively, or is there a more accepted way?
My first choice would usually be to use recursion. It is only moderately less compact, is potentially faster (certainly no slower), and in early termination can make the logic more clear. In this case you need nested defs which is a little awkward:
def sumEvenNumbers(nums: Iterable[Int]) = {
def sumEven(it: Iterator[Int], n: Int): Option[Int] = {
if (it.hasNext) {
val x = it.next
if ((x % 2) == 0) sumEven(it, n+x) else None
}
else Some(n)
}
sumEven(nums.iterator, 0)
}
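For example (expected results, assuming the semantics described in the question):
sumEvenNumbers(List(2, 4, 6)) // Some(12)
sumEvenNumbers(List(2, 3, 4)) // None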
My second choice would be to use return, as it keeps everything else intact and you only need to wrap the fold in a def so you have something to return from--in this case, you already have a method, so:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
Some(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) return None
n+x
})
}
which in this particular case is a lot more compact than recursion (though we got especially unlucky with recursion since we had to do an iterable/iterator transformation). The jumpy control flow is something to avoid when all else is equal, but here it's not. No harm in using it in cases where it's valuable.
If I was doing this often and wanted it within the middle of a method somewhere (so I couldn't just use return), I would probably use exception-handling to generate non-local control flow. That is, after all, what it is good at, and error handling is not the only time it's useful. The only trick is to avoid generating a stack trace (which is really slow), and that's easy because the trait NoStackTrace and its child trait ControlThrowable already do that for you. Scala already uses this internally (in fact, that's how it implements the return from inside the fold!). Let's make our own (can't be nested, though one could fix that):
import scala.util.control.ControlThrowable
case class Returned[A](value: A) extends ControlThrowable {}
def shortcut[A](a: => A) = try { a } catch { case Returned(v) => v }
def sumEvenNumbers(nums: Iterable[Int]) = shortcut{
Option(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) throw Returned(None)
n+x
})
}
Here of course using return is better, but note that you could put shortcut anywhere, not just wrapping an entire method.
Next in line for me would be to re-implement fold (either myself or to find a library that does it) so that it could signal early termination. The two natural ways of doing this are to not propagate the value but an Option containing the value, where None signifies termination; or to use a second indicator function that signals completion. The Scalaz lazy fold shown by Kim Stebel already covers the first case, so I'll show the second (with a mutable implementation):
def foldOrFail[A,B](it: Iterable[A])(zero: B)(fail: A => Boolean)(f: (B,A) => B): Option[B] = {
val ii = it.iterator
var b = zero
while (ii.hasNext) {
val x = ii.next
if (fail(x)) return None
b = f(b,x)
}
Some(b)
}
def sumEvenNumbers(nums: Iterable[Int]) = foldOrFail(nums)(0)(_ % 2 != 0)(_ + _)
(Whether you implement the termination by recursion, return, laziness, etc. is up to you.)
I think that covers the main reasonable variants; there are some other options also, but I'm not sure why one would use them in this case. (Iterator itself would work well if it had a findOrPrevious, but it doesn't, and the extra work it takes to do that by hand makes it a silly option to use here.)
The scenario you describe (exit upon some unwanted condition) seems like a good use case for the takeWhile method. It is essentially a filter, but it ends upon encountering an element that doesn't meet the condition.
For example:
val list = List(2,4,6,8,6,4,2,5,3,2)
list.takeWhile(_ % 2 == 0) //result is List(2,4,6,8,6,4,2)
This will work just fine for Iterators/Iterables too. The solution I suggest for your "sum of even numbers, but break on odd" is:
list.iterator.takeWhile(_ % 2 == 0).foldLeft(...)
And just to prove that it's not wasting your time once it hits an odd number...
scala> val list = List(2,4,5,6,8)
list: List[Int] = List(2, 4, 5, 6, 8)
scala> def condition(i: Int) = {
| println("processing " + i)
| i % 2 == 0
| }
condition: (i: Int)Boolean
scala> list.iterator.takeWhile(condition _).sum
processing 2
processing 4
processing 5
res4: Int = 6
You can do what you want in a functional style using the lazy version of foldRight in scalaz. For a more in depth explanation, see this blog post. While this solution uses a Stream, you can convert an Iterable into a Stream efficiently with iterable.toStream.
import scalaz._
import Scalaz._
val str = Stream(2,1,2,2,2,2,2,2,2)
var i = 0 //only here for testing
val r = str.foldr(Some(0):Option[Int])((n,s) => {
println(i)
i+=1
if (n % 2 == 0) s.map(n+) else None
})
This only prints
0
1
which clearly shows that the anonymous function is only called twice (i.e. until it encounters the odd number). That is due to the definition of foldr, whose signature (in case of Stream) is def foldr[B](b: B)(f: (Int, => B) => B)(implicit r: scalaz.Foldable[Stream]): B. Note that the anonymous function takes a by-name parameter as its second argument, so it need not be evaluated.
Btw, you can still write this with the OP's pattern matching solution, but I find if/else and map more elegant.
Well, Scala does allow non-local returns. There are differing opinions on whether or not this is good style.
scala> def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
| nums.foldLeft (Some(0): Option[Int]) {
| case (None, _) => return None
| case (Some(s), n) if n % 2 == 0 => Some(s + n)
| case (Some(_), _) => None
| }
| }
sumEvenNumbers: (nums: Iterable[Int])Option[Int]
scala> sumEvenNumbers(2 to 10)
res8: Option[Int] = None
scala> sumEvenNumbers(2 to 10 by 2)
res9: Option[Int] = Some(30)
EDIT:
In this particular case, as @Arjan suggested, you can also do:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => return None
}
}
You can use foldM from the cats library (as suggested by @Didac), but I suggest using Either instead of Option if you want to get the actual sum out.
bifoldMap is used to extract the result from Either.
import cats.implicits._
def sumEven(nums: Stream[Int]): Either[Int, Int] = {
nums.foldM(0) {
case (acc, n) if n % 2 == 0 => Either.right(acc + n)
case (acc, n) => {
println(s"Stopping on number: $n")
Either.left(acc)
}
}
}
examples:
println("Result: " + sumEven(Stream(2, 2, 3, 11)).bifoldMap(identity, identity))
> Stopping on number: 3
> Result: 4
println("Result: " + sumEven(Stream(2, 7, 2, 3)).bifoldMap(identity, identity))
> Stopping on number: 7
> Result: 2
Cats has a method called foldM which does short-circuiting (for Vector, List, Stream, ...).
It works as follows:
def sumEvenNumbers(nums: Stream[Int]): Option[Long] = {
import cats.implicits._
nums.foldM(0L) {
case (acc, c) if c % 2 == 0 => Some(acc + c)
case _ => None
}
}
If it finds an odd element it returns None without computing the rest; otherwise it returns the sum of the even entries.
If you want to keep the sum accumulated up to the point where an odd entry is found, you should use an Either[Long, Long], for example as sketched below.
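A minimal sketch of that variant (assuming the same cats imports as above): Left carries the partial sum at the first odd element, Right the total when every element is even.
def sumEvenNumbersOrPartial(nums: Stream[Int]): Either[Long, Long] = {
  import cats.implicits._
  nums.foldM(0L) {
    case (acc, c) if c % 2 == 0 => Either.right(acc + c)
    case (acc, _) => Either.left(acc)
  }
}
sumEvenNumbersOrPartial(Stream(2, 4, 5, 6)) // Left(6)
sumEvenNumbersOrPartial(Stream(2, 4, 6))    // Right(12)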
@Rex Kerr, your answer helped me, but I needed to tweak it to use Either
def foldOrFail[A,B,C,D](map: B => Either[D, C])(merge: (A, C) => A)(initial: A)(it: Iterable[B]): Either[D, A] = {
val ii = it.iterator
var b = initial
while (ii.hasNext) {
val x = ii.next
map(x) match {
case Left(error) => return Left(error)
case Right(d) => b = merge(b, d)
}
}
Right(b)
}
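Hypothetical usage of this Either-based variant (the values are only for illustration):
val result = foldOrFail[Int, Int, Int, String] { x =>
  if (x % 2 == 0) Right(x) else Left("odd number: " + x)
} ((acc, c) => acc + c)(0)(List(2, 4, 6))
// result == Right(12); List(2, 3, 4) would yield Left("odd number: 3")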
You could try using a temporary var and using takeWhile. Here is a version.
var continue = true
// sample stream of 2's and then a stream of 3's.
val evenSum = (Stream.fill(10)(2) ++ Stream.fill(10)(3)).takeWhile(_ => continue)
.foldLeft(Option[Int](0)){
case (result,i) if i%2 != 0 =>
continue = false;
// return whatever is appropriate either the accumulated sum or None.
result
case (optionSum,i) => optionSum.map( _ + i)
}
The evenSum should be Some(20) in this case.
You can throw a well-chosen exception upon encountering your termination criterion, handling it in the calling code.
A more beautiful solution would be using span:
val (l, r) = numbers.span(_ % 2 == 0)
if(r.isEmpty) Some(l.sum)
else None
... but it traverses the list two times if all the numbers are even
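If the double traversal matters, a sketch of the same idea over an iterator (an assumption, not part of the original answer) consumes the even prefix only once:
val it = numbers.iterator
val (evens, rest) = it.span(_ % 2 == 0)
val sum = evens.sum // consumes the even prefix exactly once
val result = if (rest.hasNext) None else Some(sum)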
Just for an "academic" reasons (:
import scala.io.Source

var headers = Source.fromFile(file).getLines().next().split(",")
var closeHeaderIdx = headers.takeWhile { s => !"Close".equals(s) }.foldLeft(0)((i, _) => i + 1)
It takes twice as long as it should, but it is a nice one-liner.
If "Close" is not found it will return
headers.size
Another (better) is this one:
var headers = Source.fromFile(file).getLines().next().split(",").toList
var closeHeaderIdx = headers.indexOf("Close")

Scala: Read some data of an Enumerator[T] and return the remaining Enumerator[T]

I am using the asynchronous I/O library of the Play framework, which uses Iteratees and Enumerators. I now have an Iteratee[T] as data sink (for simplification say it's an Iteratee[Byte] which stores its content into a file). This Iteratee[Byte] is passed to the function which handles the writing.
But before writing I want to add some statistical information at the beginning of the file (for simplification say it's one Byte), so I transform the iteratee the following way before passing it to the write function:
def write(value: Byte, output: Iteratee[Byte]): Iteratee[Byte] =
Iteratee.flatten(output.feed(Input.El(value)))
When I now read the stored file from the disk, I get an Enumerator[Byte] for it.
First I want to read and remove the additional data, and then I want to pass the rest of the Enumerator[Byte] to a function which handles the reading.
So I also need to transform the enumerator:
def read(input: Enumerator[Byte]): (Byte, Enumerator[Byte]) = {
val firstEnumeratorEntry = ...
val remainingEnumerator = ...
(firstEnumeratorEntry, remainingEnumerator)
}
But I have no idea how to do this. How can I read some bytes from an Enumerator and get the remaining Enumerator?
Replacing Iteratee[Byte] with OutputStream and Enumerator[Byte] with InputStream, this would be very easy:
def write(value: Byte, output: OutputStream) = {
output.write(value)
output
}
def read(input: InputStream) = (input.read,input)
But I need the asynchronous I/O of the play framework.
I wonder if you can tackle your goal from another angle.
The function that would use the remaining enumerator, let's call it remaining, is presumably applied to an iteratee to do the processing of the remainder: remaining |>> iteratee, yielding another iteratee. Let's call that resulting iteratee iteratee2. Can you check whether you can get a reference to iteratee2? If so, you can read and process the first byte using a first iteratee head, then combine head and iteratee2 through flatMap:
val head = Enumeratee.take[Byte](1) &>> Iteratee.foreach[Byte](println)
val processing = for { h <- head; i <- iteratee2 } yield (h, i)
Iteratee.flatten(processing).run
If you cannot get a hold of iteratee2 - which would be the case if your enumerator combines with an enumeratee that you did not implement - then this approach won't work.
Here is one way to achieve this by folding within the Iteratee, with an appropriate (kind-of) State accumulator (a tuple here).
I read the routes file; the first byte will be read as a Char and the others will be appended to a String as UTF-8 byte strings.
def index = Action {
/*let's do everything asyncly*/
Async {
/*for comprehension for read-friendly*/
for (
i <- read; /*read the file */
(r:(Option[Char], String)) <- i.run /*"create" the related Promise and run it*/
) yield Ok("first : " + r._1.get + "\n" + "rest" + r._2) /* map the Promised result in a correct Request's Result*/
}
}
def read = {
//get the routes file in an Enumerator
val file: Enumerator[Array[Byte]] = Enumerator.fromFile(Play.getFile("/conf/routes"))
//apply the enumerator with an Iteratee that folds the data as wished
file(Iteratee.fold((None, ""):(Option[Char], String)) { (acc, b) =>
acc._1 match {
/*on the first chunk*/ case None => (Some(b(0).toChar), acc._2 + new String(b.tail, Charset.forName("utf-8")))
/*on other chunks*/ case x => (x, acc._2 + new String(b, Charset.forName("utf-8")))
}
})
}
EDIT
I found yet another way using Enumeratee, but it needs to create 2 Enumerators (one short-lived). However, it is a bit more elegant. We use a "kind-of" Enumeratee, the Traversable one, which works at a finer level than Enumeratee (chunk level).
We use take 1, which will take only 1 byte and then close the stream. On the other one, we use drop, which simply drops the first byte (because we're using an Enumerator[Array[Byte]]).
Furthermore, read2 now has a signature much closer to what you wished, because it returns 2 enumerators (not so far from (Promise, Enumerator)).
def index = Action {
Async {
val (first, rest) = read2
val enee = Enumeratee.map[Array[Byte]] {bs => new String(bs, Charset.forName("utf-8"))}
def useEnee(enumor:Enumerator[Array[Byte]]) = Iteratee.flatten(enumor &> enee |>> Iteratee.consume[String]()).run.asInstanceOf[Promise[String]]
for {
f <- useEnee(first);
r <- useEnee(rest)
} yield Ok("first : " + f + "\n" + "rest" + r)
}
}
def read2 = {
def create = Enumerator.fromFile(Play.getFile("/conf/routes"))
val file: Enumerator[Array[Byte]] = create
val file2: Enumerator[Array[Byte]] = create
(file &> Traversable.take[Array[Byte]](1), file2 &> Traversable.drop[Array[Byte]](1))
}
Actually we like Iteratees because they compose. So instead of creating multiple Enumerators from your original one, you rather compose the two Iteratees sequentially (read-first and read-rest), and feed them with your single Enumerator.
For this you need a sequential composition method; I call it andThen here. Below is a rough implementation. Note that returning the unconsumed input is a bit harsh; maybe the behavior could be customized with a typeclass based on the Input type. Also, it doesn't handle passing the leftover input from the first iteratee to the second one (exercise :).
object Iteratees {
def andThen[E, A, B](a: Iteratee[E, A], b: Iteratee[E, B]): Iteratee[E, (A,B)] = new Iteratee[E, (A,B)] {
def fold[C](
done: ((A, B), Input[E]) => Promise[C],
cont: ((Input[E]) => Iteratee[E, (A, B)]) => Promise[C],
error: (String, Input[E]) => Promise[C]): Promise[C] = {
a.fold(
(ra, aleft) => b.fold(
(rb, bleft) => done((ra, rb), aleft /* could be magicop(aleft, bleft)*/),
(bcont) => cont(e => bcont(e) map (rb => (ra, rb))),
(s, err) => error(s, err)
),
(acont) => cont(e => andThen[E, A, B](acont(e), b)),
(s, err) => error(s, err)
)
}
}
}
Now you can just use the following:
object Application extends Controller {
def index = Action { Async {
val strings: Enumerator[String] = Enumerator("1","2","3","4")
val takeOne = Cont[String, String](e => e match {
case Input.El(e) => Done(e, Input.Empty)
case x => Error("not enough", x)
})
val takeRest = Iteratee.consume[String]()
val firstAndRest = Iteratees.andThen(takeOne, takeRest)
val futureRes = strings(firstAndRest) flatMap (_.run)
futureRes.map(x => Ok(x.toString)) // prints (1,234)
} }
}