Accumulate result until some condition is met in functional way - scala

I have an expensive computation in a loop, and I need to find the max value it produces. However, if a value ever equals some limit (UpperBound below), I'd like to stop the calculation and return my accumulator.
It may easily be done by recursion:
val list: List[Int] = ???
val UpperBound = ???
def findMax(ls: List[Int], max: Int): Int = ls match {
case h :: rest =>
val v = expensiveComputation(h)
if (v == UpperBound) v
else findMax(rest, math.max(max, v))
case _ => max
}
findMax(list, 0)
My question: does this pattern have a name, and is it reflected in the Scala collection library?
Update: the question "Do something up to N times or until condition is met in Scala" has an interesting idea (using laziness with find or exists at the end), but it is not directly applicable to my particular case, or it requires a mutable var to track the accumulator.

I think your recursive function is quite nice, so honestly I wouldn't change that, but here's a way to use the collections library:
list.foldLeft(0) {
case (max, next) =>
if(max == UpperBound)
max
else
math.max(expensiveComputation(next), max)
}
It will iterate over the whole list, but after it has hit the upper bound it won't perform the expensive computation.
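To make that claim concrete, here is a small sketch that counts how often the expensive function actually runs. The values of UpperBound, the input list, and the body of expensiveComputation are made-up stand-ins, not the asker's real code.

```scala
object FoldSkipDemo {
  // Hypothetical stand-ins for the names used in the question.
  val UpperBound = 10
  var calls = 0
  def expensiveComputation(x: Int): Int = { calls += 1; x * 2 }

  // Same fold as above: once max hits UpperBound, the remaining
  // elements are still visited but the computation is skipped.
  val result: Int = List(1, 3, 5, 7, 9).foldLeft(0) { (max, next) =>
    if (max == UpperBound) max
    else math.max(expensiveComputation(next), max)
  }
}
```

With these inputs the bound is reached at the third element, so result is 10 and calls is 3 even though the list has five elements.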
Update
Based on your comment I tried adapting foldLeft a bit, based on LinearSeqOptimized's foldLeft implementation.
def foldLeftWithExit[A, B](list: Seq[A])(z: B)(exit: B => Boolean)(f: (B, A) => B): B = {
var acc = z
var remaining = list
while (!remaining.isEmpty && !exit(acc)) {
acc = f(acc, remaining.head)
remaining = remaining.tail
}
acc
}
Calling it:
foldLeftWithExit(list)(0)(_ == UpperBound) {
case (max, next) => math.max(expensiveComputation(next), max)
}
You could potentially use implicits to omit the first parameter of list.
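One way that might look is an implicit class that adds the method to any Seq. This is only a sketch: FoldSyntax and FoldWithExitOps are names invented here, and nothing like foldLeftWithExit exists in the standard library.

```scala
object FoldSyntax {
  // Hypothetical extension method; not part of the Scala collections API.
  implicit class FoldWithExitOps[A](self: Seq[A]) {
    def foldLeftWithExit[B](z: B)(exit: B => Boolean)(f: (B, A) => B): B = {
      var acc = z
      var remaining = self
      // Stop as soon as the accumulator satisfies the exit condition.
      while (remaining.nonEmpty && !exit(acc)) {
        acc = f(acc, remaining.head)
        remaining = remaining.tail
      }
      acc
    }
  }
}

object FoldSyntaxDemo {
  import FoldSyntax._
  // Running max that stops once it reaches 5 or more; 3 and 9 are never looked at.
  val r: Int = List(1, 5, 3, 9).foldLeftWithExit(0)(_ >= 5)((m, x) => math.max(m, x))
}
```

Note that head/tail on a general Seq can be slow for non-list collections, so in practice you might restrict this to List or iterate with an Iterator instead.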
Hope this helps.

Related

Find max value in a list recursively in scala

I'm new to Scala, there is a better way to express this with the most basic knowledge possible?
def findMax(xs: List[Int]): Int = {
xs match {
case x :: tail => (if (tail.length==0) x else (if(x>findMax(tail)) x else (findMax(tail))))
}
}
There are two problems here. First, you call tail.length, which is an O(N) operation, so in the worst case this will cost you N*N steps, where N is the length of the sequence. The second is that your function is not tail-recursive - you nest the findMax calls "from outside to inside".
The usual strategy to write the correct recursive function is
to think about each possible pattern case: here you have either the empty list Nil or the non-empty list head :: tail. This solves your first problem.
to carry along the temporary result (here the current guess of the maximum value) as another argument of the function. This solves your second problem.
This gives:
import scala.annotation.tailrec
@tailrec
def findMax(xs: List[Int], max: Int): Int = xs match {
case head :: tail => findMax(tail, if (head > max) head else max)
case Nil => max
}
val z = util.Random.shuffle((1 to 100).toList)
assert(findMax(z, Int.MinValue) == 100)
If you don't want to expose this additional argument, you can write an auxiliary inner function.
def findMax(xs: List[Int]): Int = {
@tailrec
def loop(ys: List[Int], max: Int): Int = ys match {
case head :: tail => loop(tail, if (head > max) head else max)
case Nil => max
}
loop(xs, Int.MinValue)
}
val z = util.Random.shuffle((1 to 100).toList)
assert(findMax(z) == 100)
For simplicity we return Int.MinValue if the list is empty. A better solution might be to throw an exception for this case.
The @tailrec annotation here is optional; it simply makes the compiler check that we have indeed defined a tail-recursive function. This has the advantage that we cannot produce a stack overflow if the list is extremely long.
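As for the empty-list case mentioned above, a third option besides Int.MinValue or an exception is to return an Option. The following is a sketch of that idea; SafeMax and findMaxOption are names made up for illustration.

```scala
import scala.annotation.tailrec

object SafeMax {
  // Hypothetical variant: None for an empty list instead of a sentinel value.
  def findMaxOption(xs: List[Int]): Option[Int] = {
    @tailrec
    def loop(ys: List[Int], max: Int): Int = ys match {
      case head :: tail => loop(tail, if (head > max) head else max)
      case Nil          => max
    }
    xs match {
      case head :: tail => Some(loop(tail, head)) // seed with the first element
      case Nil          => None
    }
  }
}
```

Seeding the loop with the first element also means no special value like Int.MinValue is needed at all.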
Any time you're reducing a collection to a single value, consider using one of the fold functions instead of explicit recursion.
List(3,7,1).fold(Int.MinValue)(Math.max)
// 7
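If the Int.MinValue seed bothers you, reduceOption from the standard collections is one alternative. A minimal sketch (ReduceMax and maxOption are illustrative names only):

```scala
object ReduceMax {
  // reduceOption yields None for an empty collection instead of needing a seed.
  def maxOption(xs: List[Int]): Option[Int] = xs.reduceOption(_ max _)
}
```

Here maxOption(List(3, 7, 1)) gives Some(7), and an empty list gives None.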
I'm new to Scala too (though I'm into Haskell!).
My attempt at this would be as below.
Note that I assume a non-empty list, since the max of an empty list does not make sense.
I first define a helper method which simply returns the max of 2 numbers.
def maxOf2 (x:Int, y:Int): Int = {
if (x >= y) x
else y
}
Armed with this simple function, we can build a recursive function to find the 'max' as below:
def findMax(xs: List[Int]): Int = {
if (xs.tail.isEmpty)
xs.head
else
maxOf2(xs.head, findMax(xs.tail))
}
I feel this is a pretty 'clear'(though not 'efficient') way to do it.
I wanted to make the concept of recursion obvious.
Hope this helps!
Elaborating on @fritz's answer. If you pass in an empty list, it will throw a java.lang.UnsupportedOperationException: tail of empty list
So, keeping the algorithm intact, I made this adjustment:
def max(xs: List[Int]): Int = {
def maxOfTwo(x: Int, y: Int): Int = {
if (x >= y) x else y
}
if (xs.isEmpty) throw new UnsupportedOperationException("What man?")
else if (xs.size == 1) xs.head
else maxOfTwo(xs.head, max(xs.tail))
}
@fritz Thanks for the answer
Using pattern matching and recursion,
def top(xs: List[Int]): Int = xs match {
case Nil => sys.error("no max in empty list")
case x :: Nil => x
case x :: xs => math.max(x, top(xs))
}
Pattern matching is used to decompose the list into head and rest. A single element list is denoted with x :: Nil. We recurse on the rest of the list and compare for maximum on the head item of the list at each recursive stage. To make the cases exhaustive (to make a well-defined function) we consider also empty lists (Nil).
def maxl(xl: List[Int]): Int = {
if ( (xl.head > xl.tail.head) && (xl.tail.length >= 1) )
return xl.head
else
if(xl.tail.length == 1)
xl.tail.head
else
maxl(xl.tail)
}

Idiomatic "do until" collection updating

Scenario:
val col: IndexedSeq[Array[Char]] = for (i <- 1 to n) yield {
val x = for (j <- 1 to m) yield 'x'
x.toArray
}
This is a fairly simple char matrix. toArray used to allow updating.
var west = last.x - 1
while (west >= 0 && arr(last.y)(west) == '.') {
arr(last.y)(west) = ch;
west -= 1;
}
This is updating all . to ch until a non-dot char is found.
Generically, update until stop condition is met, unknown number of steps.
What is the idiomatic equivalent of it?
Conclusion
It's doable, but the trade-off isn't worth it: a lot of performance is lost to expressive syntax when the collection allows updating.
Your wish for a "cleaner, more idiomatic" solution is of course a little fuzzy, because it leaves a lot of room for subjectivity. In general, I'd consider a tail-recursive updating routine more idiomatic, but it might not be "cleaner" if you're more familiar with a non-functional programming style. I came up with this:
@tailrec
def update(arr:List[Char], replace:Char, replacement:Char, result:List[Char] = Nil):List[Char] = arr match {
case `replace` :: tail =>
update(tail, replace, replacement, replacement :: result)
case _ => result.reverse ::: arr
}
This takes one of the inner sequences (assuming a List for easier pattern matching, since Arrays are trivially convertible to lists), and replaces the replace char with the replacement recursively.
You can then use map to update the outer sequence, like so:
col.map { x => update(x.toList, '.', ch) }
Another more reusable alternative is writing your own mapUntil, or using one which is implemented in a supplemental library (Scalaz probably has something like it). The one I came up with looks like this:
def mapUntil[T](input:List[T])(f:(T => Option[T])) = {
@tailrec
def inner(xs:List[T], result:List[T]):List[T] = xs match {
case Nil => result.reverse
case head :: tail => f(head) match {
case None => (head :: result).reverse ::: tail
case Some(x) => inner(tail, x :: result)
}
}
inner(input, Nil)
}
It does the same as a regular map invocation, except that it stops as soon as the passed function returns None, e.g.
mapUntil(List(1,2,3,4)) {
case x if x >= 3 => None
case x => Some(x-1)
}
Will result in
List[Int] = List(0, 1, 3, 4)
If you want to look at Scalaz, this answer might be a good place to start.
x3ro's answer is the right answer, esp. if you care about performance or are going to be using this operation in multiple places. I would like to add a simple solution using only what you find in the collections API:
col.map { a =>
val (l, r) = a.span(_ == '.')
l.map {
case '.' => ch
case x => x
} ++ r
}
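The span version above scans from the left, while the loop in the question scans westward from a starting column. If you want to keep the in-place update but drop the hand-written loop condition, one possibility is lastIndexWhere with an end index. This is a sketch only; WestFill and fillWest are invented names, and start plays the role of last.x - 1 from the question.

```scala
object WestFill {
  // Hypothetical helper: overwrite the run of '.' that ends at index `start`,
  // moving toward index 0, and stop at the first non-dot character.
  def fillWest(row: Array[Char], start: Int, ch: Char): Unit = {
    // Index of the nearest non-dot at or before `start`, or -1 if there is none.
    val stop = row.lastIndexWhere(_ != '.', start)
    for (i <- stop + 1 to start) row(i) = ch
  }
}
```

For example, filling "ab..x" from index 3 with '#' leaves "ab##x"; if start is -1 the range is empty and nothing changes.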

Why doesn't my recursive function return the max value of a List

I have the following recursive function in Scala that should return the maximum size integer in the List. Is anyone able to tell me why the largest value is not returned?
def max(xs: List[Int]): Int = {
var largest = xs.head
println("largest: " + largest)
if (!xs.tail.isEmpty) {
var next = xs.tail.head
println("next: " + next)
largest = if (largest > next) largest else next
var remaining = List[Int]()
remaining = largest :: xs.tail.tail
println("remaining: " + remaining)
max(remaining)
}
return largest
}
Print out statements show me that I've successfully managed to bring back the largest value in the List as the head (which was what I wanted) but the function still returns back the original head in the list. I'm guessing this is because the reference for xs is still referring to the original xs list, problem is I can't override that because it's a val.
Any ideas what I'm doing wrong?
You should use the return value of the inner call to max and compare that to the local largest value.
Something like the following (removed println just for readability):
def max(xs: List[Int]): Int = {
var largest = xs.head
if (!xs.tail.isEmpty) {
val next = max(xs.tail)
largest = if (largest > next) largest else next
}
return largest
}
Bye.
I have an answer to your question but first...
This is the most minimal recursive implementation of max I've ever been able to think up:
def max(xs: List[Int]): Option[Int] = xs match {
case Nil => None
case List(x: Int) => Some(x)
case x :: y :: rest => max( (if (x > y) x else y) :: rest )
}
(OK, my original version was ever so slightly more minimal but I wrote that in Scheme which doesn't have Option or type safety etc.) It doesn't need an accumulator or a local helper function because it compares the first two items on the list and discards the smaller, a process which - performed recursively - inevitably leaves you with a list of just one element which must be bigger than all the rest.
OK, why your original solution doesn't work... It's quite simple: you do nothing with the return value from the recursive call to max. All you had to do was change the line
max(remaining)
to
largest = max(remaining)
and your function would work. It wouldn't be the prettiest solution, but it would work. As it is, your code looks as if it assumes that changing the value of largest inside the recursive call will change it in the outside context from which it was called. But each new call to max creates a completely new version of largest which only exists inside that new iteration of the function. Your code then throws away the return value from max(remaining) and returns the original value of largest, which hasn't changed.
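For reference, this is roughly what the original function looks like with that one-line change applied (printlns removed, and wrapped in an object named FixedMax purely for illustration):

```scala
object FixedMax {
  def max(xs: List[Int]): Int = {
    var largest = xs.head
    if (!xs.tail.isEmpty) {
      val next = xs.tail.head
      largest = if (largest > next) largest else next
      val remaining = largest :: xs.tail.tail
      // The fix: keep the result of the recursive call instead of discarding it.
      largest = max(remaining)
    }
    largest
  }
}
```

Like the original, it still throws on an empty list because of xs.head.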
Another way to solve this would have been to use a local (inner) function after declaring var largest. That would have looked like this:
def max(xs: List[Int]): Int = {
var largest = xs.head
def loop(ys: List[Int]): Unit = {
if (!ys.isEmpty) {
var next = ys.head
largest = if (largest > next) largest else next
loop(ys.tail)
}
}
loop(xs.tail)
return largest
}
Generally, though, it is better to have recursive functions be entirely self-contained (that is, not to look at or change external variables but only at their input) and to return a meaningful value.
When writing a recursive solution of this kind, it often helps to think in reverse. Think first about what things are going to look like when you get to the end of the list. What is the exit condition? What will things look like and where will I find the value to return?
If you do this, then the case which you use to exit the recursive function (by returning a simple value rather than making another recursive call) is usually very simple. The other case matches just need to deal with a) invalid input and b) what to do if you are not yet at the end. a) is usually simple and b) can usually be broken down into just a few different situations, each with a simple thing to do before making another recursive call.
If you look at my solution, you'll see that the first case deals with invalid input, the second is my exit condition and the third is "what to do if we're not at the end".
In many other recursive solutions, Nil is the natural end of the recursion.
This is the point at which I (as always) recommend reading The Little Schemer. It teaches you recursion (and basic Scheme) at the same time (both of which are very good things to learn).
It has been pointed out that Scala has some powerful functions which can help you avoid recursion (or hide the messy details of it), but to use them well you really do need to understand how recursion works.
The following is a typical way to solve this sort of problem. It uses an inner tail-recursive function that includes an extra "accumulator" value, which in this case will hold the largest value found so far:
def max(xs: List[Int]): Int = {
def go(xs: List[Int], acc: Int): Int = xs match {
case Nil => acc // We've emptied the list, so just return the final result
case x :: rest => if (acc > x) go(rest, acc) else go(rest, x) // Keep going, with remaining list and updated largest-value-so-far
}
go(xs, Int.MinValue)
}
Nevermind I've resolved the issue...
I finally came up with:
def max(xs: List[Int]): Int = {
var largest = 0
var remaining = List[Int]()
if (!xs.isEmpty) {
largest = xs.head
if (!xs.tail.isEmpty) {
var next = xs.tail.head
largest = if (largest > next) largest else next
remaining = largest :: xs.tail.tail
}
}
if (!remaining.tail.isEmpty) max(remaining) else xs.head
}
Kinda glad we have loops - this is an excessively complicated solution and hard to get your head around in my opinion. I resolved the problem by making sure the recursive call was the last statement in the function either that or xs.head is returned as the result if there isn't a second member in the array.
The most concise but also clear version I have ever seen is this:
def max(xs: List[Int]): Int = {
def maxIter(a: Int, xs: List[Int]): Int = {
if (xs.isEmpty) a
else a max maxIter(xs.head, xs.tail)
}
maxIter(xs.head, xs.tail)
}
This has been adapted from the solutions to a homework assignment in the official Scala Coursera course: https://github.com/purlin/Coursera-Scala/blob/master/src/example/Lists.scala
but here I use the rich operator max to return the largest of its two operands. This saves having to redefine this function within the def max block.
What about this?
def max(xs: List[Int]): Int = {
maxRecursive(xs, 0)
}
def maxRecursive(xs: List[Int], max: Int): Int = {
if (xs.isEmpty) max
else maxRecursive(xs.tail, if (xs.head > max) xs.head else max)
}
What about this one ?
def max(xs: List[Int]): Int = {
var largest = xs.head
if( !xs.tail.isEmpty ) {
if(xs.head < max(xs.tail)) largest = max(xs.tail)
}
largest
}
My answer using recursion is:
def max(xs: List[Int]): Int =
xs match {
case Nil => throw new NoSuchElementException("empty list is not allowed")
case head :: Nil => head
case head :: tail =>
if (head >= tail.head)
if (tail.length > 1)
max(head :: tail.tail)
else
head
else
max(tail)
}

Abort early in a fold

What's the best way to terminate a fold early? As a simplified example, imagine I want to sum up the numbers in an Iterable, but if I encounter something I'm not expecting (say an odd number) I might want to terminate. This is a first approximation
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => None
}
}
However, this solution is pretty ugly (as in, if I did a .foreach and a return -- it'd be much cleaner and clearer) and worst of all, it traverses the entire iterable even if it encounters a non-even number.
So what would be the best way to write a fold like this, that terminates early? Should I just go and write this recursively, or is there a more accepted way?
My first choice would usually be to use recursion. It is only moderately less compact, is potentially faster (certainly no slower), and in early termination can make the logic more clear. In this case you need nested defs which is a little awkward:
def sumEvenNumbers(nums: Iterable[Int]) = {
def sumEven(it: Iterator[Int], n: Int): Option[Int] = {
if (it.hasNext) {
val x = it.next
if ((x % 2) == 0) sumEven(it, n+x) else None
}
else Some(n)
}
sumEven(nums.iterator, 0)
}
My second choice would be to use return, as it keeps everything else intact and you only need to wrap the fold in a def so you have something to return from--in this case, you already have a method, so:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
Some(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) return None
n+x
})
}
which in this particular case is a lot more compact than recursion (though we got especially unlucky with recursion since we had to do an iterable/iterator transformation). The jumpy control flow is something to avoid when all else is equal, but here it's not. No harm in using it in cases where it's valuable.
If I was doing this often and wanted it within the middle of a method somewhere (so I couldn't just use return), I would probably use exception-handling to generate non-local control flow. That is, after all, what it is good at, and error handling is not the only time it's useful. The only trick is to avoid generating a stack trace (which is really slow), and that's easy because the trait NoStackTrace and its child trait ControlThrowable already do that for you. Scala already uses this internally (in fact, that's how it implements the return from inside the fold!). Let's make our own (can't be nested, though one could fix that):
import scala.util.control.ControlThrowable
case class Returned[A](value: A) extends ControlThrowable {}
def shortcut[A](a: => A) = try { a } catch { case Returned(v) => v }
def sumEvenNumbers(nums: Iterable[Int]) = shortcut{
Option(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) throw Returned(None)
n+x
})
}
Here of course using return is better, but note that you could put shortcut anywhere, not just wrapping an entire method.
Next in line for me would be to re-implement fold (either myself or to find a library that does it) so that it could signal early termination. The two natural ways of doing this are to not propagate the value but an Option containing the value, where None signifies termination; or to use a second indicator function that signals completion. The Scalaz lazy fold shown by Kim Stebel already covers the first case, so I'll show the second (with a mutable implementation):
def foldOrFail[A,B](it: Iterable[A])(zero: B)(fail: A => Boolean)(f: (B,A) => B): Option[B] = {
val ii = it.iterator
var b = zero
while (ii.hasNext) {
val x = ii.next
if (fail(x)) return None
b = f(b,x)
}
Some(b)
}
def sumEvenNumbers(nums: Iterable[Int]) = foldOrFail(nums)(0)(_ % 2 != 0)(_ + _)
(Whether you implement the termination by recursion, return, laziness, etc. is up to you.)
I think that covers the main reasonable variants; there are some other options also, but I'm not sure why one would use them in this case. (Iterator itself would work well if it had a findOrPrevious, but it doesn't, and the extra work it takes to do that by hand makes it a silly option to use here.)
The scenario you describe (exit upon some unwanted condition) seems like a good use case for the takeWhile method. It is essentially filter, but should end upon encountering an element that doesn't meet the condition.
For example:
val list = List(2,4,6,8,6,4,2,5,3,2)
list.takeWhile(_ % 2 == 0) //result is List(2,4,6,8,6,4,2)
This will work just fine for Iterators/Iterables too. The solution I suggest for your "sum of even numbers, but break on odd" is:
list.iterator.takeWhile(_ % 2 == 0).foldLeft(...)
And just to prove that it's not wasting your time once it hits an odd number...
scala> val list = List(2,4,5,6,8)
list: List[Int] = List(2, 4, 5, 6, 8)
scala> def condition(i: Int) = {
| println("processing " + i)
| i % 2 == 0
| }
condition: (i: Int)Boolean
scala> list.iterator.takeWhile(condition _).sum
processing 2
processing 4
processing 5
res4: Int = 6
You can do what you want in a functional style using the lazy version of foldRight in scalaz. For a more in depth explanation, see this blog post. While this solution uses a Stream, you can convert an Iterable into a Stream efficiently with iterable.toStream.
import scalaz._
import Scalaz._
val str = Stream(2,1,2,2,2,2,2,2,2)
var i = 0 //only here for testing
val r = str.foldr(Some(0):Option[Int])((n,s) => {
println(i)
i+=1
if (n % 2 == 0) s.map(n+) else None
})
This only prints
0
1
which clearly shows that the anonymous function is only called twice (i.e. until it encounters the odd number). That is due to the definition of foldr, whose signature (in case of Stream) is def foldr[B](b: B)(f: (Int, => B) => B)(implicit r: scalaz.Foldable[Stream]): B. Note that the anonymous function takes a by-name parameter as its second argument, so it need not be evaluated.
Btw, you can still write this with the OP's pattern matching solution, but I find if/else and map more elegant.
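To see the by-name trick without pulling in Scalaz, here is a hand-rolled sketch of the same idea on a plain List. It is not the Scalaz implementation; LazyFold and foldRightLazy are names made up for this example.

```scala
object LazyFold {
  // Hypothetical helper: the second argument of `f` is by-name, so the rest of
  // the list is only folded if `f` actually asks for it.
  def foldRightLazy[A, B](xs: List[A])(z: => B)(f: (A, => B) => B): B = xs match {
    case head :: tail => f(head, foldRightLazy(tail)(z)(f))
    case Nil          => z
  }

  def sumEvenNumbers(nums: List[Int]): Option[Int] =
    foldRightLazy[Int, Option[Int]](nums)(Some(0)) { (n, rest) =>
      // `rest` is never evaluated once an odd number is seen.
      if (n % 2 == 0) rest.map(_ + n) else None
    }
}
```

This version is not stack-safe for very long lists, since each even element adds a stack frame before the result is known.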
Well, Scala does allow non local returns. There are differing opinions on whether or not this is a good style.
scala> def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
| nums.foldLeft (Some(0): Option[Int]) {
| case (None, _) => return None
| case (Some(s), n) if n % 2 == 0 => Some(s + n)
| case (Some(_), _) => None
| }
| }
sumEvenNumbers: (nums: Iterable[Int])Option[Int]
scala> sumEvenNumbers(2 to 10)
res8: Option[Int] = None
scala> sumEvenNumbers(2 to 10 by 2)
res9: Option[Int] = Some(30)
EDIT:
In this particular case, as @Arjan suggested, you can also do:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => return None
}
}
You can use foldM from the cats library (as suggested by @Didac), but I suggest using Either instead of Option if you want to get the actual sum out.
bifoldMap is used to extract the result from Either.
import cats.implicits._
def sumEven(nums: Stream[Int]): Either[Int, Int] = {
nums.foldM(0) {
case (acc, n) if n % 2 == 0 => Either.right(acc + n)
case (acc, n) => {
println(s"Stopping on number: $n")
Either.left(acc)
}
}
}
examples:
println("Result: " + sumEven(Stream(2, 2, 3, 11)).bifoldMap(identity, identity))
> Stopping on number: 3
> Result: 4
println("Result: " + sumEven(Stream(2, 7, 2, 3)).bifoldMap(identity, identity))
> Stopping on number: 7
> Result: 2
Cats has a method called foldM which does short-circuiting (for Vector, List, Stream, ...).
It works as follows:
def sumEvenNumbers(nums: Stream[Int]): Option[Long] = {
import cats.implicits._
nums.foldM(0L) {
case (acc, c) if c % 2 == 0 => Some(acc + c)
case _ => None
}
}
If it finds a non-even element, it returns None without computing the rest; otherwise it returns the sum of the even entries.
If you want to keep the sum accumulated up to the point where a non-even entry is found, you should use an Either[Long, Long]
@Rex Kerr your answer helped me, but I needed to tweak it to use Either
def foldOrFail[A,B,C,D](map: B => Either[D, C])(merge: (A, C) => A)(initial: A)(it: Iterable[B]): Either[D, A] = {
val ii= it.iterator
var b= initial
while (ii.hasNext) {
val x= ii.next
map(x) match {
case Left(error) => return Left(error)
case Right(d) => b= merge(b, d)
}
}
Right(b)
}
You could try using a temporary var and using takeWhile. Here is a version.
var continue = true
// sample stream of 2's and then a stream of 3's.
val evenSum = (Stream.fill(10)(2) ++ Stream.fill(10)(3)).takeWhile(_ => continue)
.foldLeft(Option[Int](0)){
case (result,i) if i%2 != 0 =>
continue = false;
// return whatever is appropriate either the accumulated sum or None.
result
case (optionSum,i) => optionSum.map( _ + i)
}
The evenSum should be Some(20) in this case.
You can throw a well-chosen exception upon encountering your termination criterion, handling it in the calling code.
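A sketch of what that could look like for the sum-of-evens example. OddNumber is a marker exception invented here; mixing in NoStackTrace from the standard library avoids the cost of filling in a stack trace.

```scala
import scala.util.control.NoStackTrace

object ExceptionExit {
  // Hypothetical marker exception used only to abort the fold.
  private case object OddNumber extends RuntimeException with NoStackTrace

  def sumEvenNumbers(nums: Iterable[Int]): Option[Int] =
    try Some(nums.foldLeft(0) { (sum, n) =>
      if (n % 2 != 0) throw OddNumber // termination criterion reached
      sum + n
    })
    catch { case OddNumber => None }
}
```

Keeping the exception private to the object means callers only ever see the Option result.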
A more beautiful solution would be using span:
val (l, r) = numbers.span(_ % 2 == 0)
if(r.isEmpty) Some(l.sum)
else None
... but it traverses the list two times if all the numbers are even
Just for "academic" reasons (:
var headers = Source.fromFile(file).getLines().next().split(",")
var closeHeaderIdx = headers.takeWhile { s => !"Close".equals(s) }.foldLeft(0)((i, S) => i+1)
It does twice the work it should, but it is a nice one-liner.
If "Close" is not found, it will return
headers.size
Another (better) is this one:
var headers = Source.fromFile(file).getLines().next().split(",").toList
var closeHeaderIdx = headers.indexOf("Close")

Encoding recursive tree-creation with while loop + stacks

I'm a bit embarrassed to admit this, but I seem to be pretty stumped by what should be a simple programming problem. I'm building a decision tree implementation, and have been using recursion to take a list of labeled samples, recursively split the list in half, and turn it into a tree.
Unfortunately, with deep trees I run into stack overflow errors (ha!), so my first thought was to use continuations to turn it into tail recursion. Unfortunately Scala doesn't support that kind of TCO, so the only solution is to use a trampoline. A trampoline seems kinda inefficient and I was hoping there would be some simple stack-based imperative solution to this problem, but I'm having a lot of trouble finding it.
The recursive version looks sort of like (simplified):
private def trainTree(samples: Seq[Sample], usedFeatures: Set[Int]): DTree = {
if (shouldStop(samples)) {
DTLeaf(makeProportions(samples))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
DTBranch(
trainTree(statsWithFeature, usedFeatures + featureIdx),
trainTree(statsWithoutFeature, usedFeatures + featureIdx),
featureIdx)
}
}
So basically I'm recursively subdividing the list into two according to some feature of the data, and passing through a list of used features so I don't repeat - that's all handled in the "getSplittingFeature" function so we can ignore it. The code is really simple! Still, I'm having trouble figuring out a stack-based solution that doesn't just use closures and effectively become a trampoline. I know we'll at least have to keep around little "frames" of arguments in the stack but I would like to avoid closure calls.
I get that I should be writing out explicitly what the callstack and program counter handle for me implicitly in the recursive solution, but I'm having trouble doing that without continuations. At this point it's hardly even about efficiency, I'm just curious. So please, no need to remind me that premature optimization is the root of all evil and the trampoline-based solution will probably work just fine. I know it probably will - this is basically a puzzle for its own sake.
Can anyone tell me what the canonical while-loop-and-stack-based solution to this sort of thing is?
UPDATE: Based on Thipor Kong's excellent solution, I've coded up a while-loops/stacks/hashtable based implementation of the algorithm that should be a direct translation of the recursive version. This is exactly what I was looking for:
FINAL UPDATE: I've used sequential integer indices, as well as putting everything back into arrays instead of maps for performance, added maxDepth support, and finally have a solution with the same performance as the recursive version (not sure about memory usage but I would guess less):
private def trainTreeNoMaxDepth(startingSamples: Seq[Sample], startingMaxDepth: Int): DTree = {
// Use arraybuffer as dense mutable int-indexed map - no IndexOutOfBoundsException, just expand to fit
type DenseIntMap[T] = ArrayBuffer[T]
def updateIntMap[@specialized T](ab: DenseIntMap[T], idx: Int, item: T, dfault: T = null.asInstanceOf[T]) = {
if (ab.length <= idx) {ab.insertAll(ab.length, Iterable.fill(idx - ab.length + 1)(dfault)) }
ab.update(idx, item)
}
var currentChildId = 0 // get childIdx or create one if it's not there already
def child(childMap: DenseIntMap[Int], heapIdx: Int) =
if (childMap.length > heapIdx && childMap(heapIdx) != -1) childMap(heapIdx)
else {currentChildId += 1; updateIntMap(childMap, heapIdx, currentChildId, -1); currentChildId }
// go down
val leftChildren, rightChildren = new DenseIntMap[Int]() // heapIdx -> childHeapIdx
val todo = Stack((startingSamples, Set.empty[Int], startingMaxDepth, 0)) // samples, usedFeatures, maxDepth, heapIdx
val branches = new Stack[(Int, Int)]() // heapIdx, featureIdx
val nodes = new DenseIntMap[DTree]() // heapIdx -> node
while (!todo.isEmpty) {
val (samples, usedFeatures, maxDepth, heapIdx) = todo.pop()
if (shouldStop(samples) || maxDepth == 0) {
updateIntMap(nodes, heapIdx, DTLeaf(makeProportions(samples)))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
todo.push((statsWithFeature, usedFeatures + featureIdx, maxDepth - 1, child(leftChildren, heapIdx)))
todo.push((statsWithoutFeature, usedFeatures + featureIdx, maxDepth - 1, child(rightChildren, heapIdx)))
branches.push((heapIdx, featureIdx))
}
}
// go up
while (!branches.isEmpty) {
val (heapIdx, featureIdx) = branches.pop()
updateIntMap(nodes, heapIdx, DTBranch(nodes(child(leftChildren, heapIdx)), nodes(child(rightChildren, heapIdx)), featureIdx))
}
nodes(0)
}
Just store the binary tree in an array, as described on Wikipedia: for node i, the left child goes into 2*i+1 and the right child into 2*i+2. When going "down", you keep a collection of todos that still have to be split to reach a leaf. Once you've got only leaves, you go upward (from right to left in the array) to build the decision nodes.
Update: A cleaned-up version that also supports the features stored in the branches (type parameter B), is more functional/fully pure, and supports sparse trees with a map, as suggested by ron.
Update2-3: Make economical use of the name space for node ids and abstract over the type of ids to allow for large trees. Take node ids from a Stream.
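As a toy illustration of that array layout (not the decision-tree code itself), the sketch below stores leaf values in the last slots and fills each inner node with the sum of its children, walking right to left. HeapLayout and buildSums are invented names, and the sketch assumes the number of leaves is a power of two so that all leaves sit on the last level.

```scala
object HeapLayout {
  def left(i: Int): Int = 2 * i + 1
  def right(i: Int): Int = 2 * i + 2

  // Toy stand-in for "building decision nodes": each inner node becomes
  // the sum of its two children. Assumes leaves.length is a power of two.
  def buildSums(leaves: Array[Int]): Array[Int] = {
    val inner = leaves.length - 1 // such a tree with n leaves has n - 1 inner nodes
    val nodes = Array.fill(inner)(0) ++ leaves
    // Right to left, so both children are final before their parent is computed.
    for (i <- inner - 1 to 0 by -1)
      nodes(i) = nodes(left(i)) + nodes(right(i))
    nodes
  }
}
```

For leaves 1, 2, 3, 4 the resulting array is 10, 3, 7, 1, 2, 3, 4, with the root at index 0.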
sealed trait DTree[A, B]
case class DTLeaf[A, B](a: A, b: B) extends DTree[A, B]
case class DTBranch[A, B](left: DTree[A, B], right: DTree[A, B], b: B) extends DTree[A, B]
def mktree[A, B, Id](a: A, b: B, split: (A, B) => Option[(A, A, B)], ids: Stream[Id]) = {
@tailrec
def goDown(todo: Seq[(A, B, Id)], branches: Seq[(Id, B, Id, Id)], leafs: Map[Id, DTree[A, B]], ids: Stream[Id]): (Seq[(Id, B, Id, Id)], Map[Id, DTree[A, B]]) =
todo match {
case Nil => (branches, leafs)
case (a, b, id) :: rest =>
split(a, b) match {
case None =>
goDown(rest, branches, leafs + (id -> DTLeaf(a, b)), ids)
case Some((left, right, b2)) =>
val leftId #:: rightId #:: idRest = ids
goDown((right, b2, rightId) +: (left, b2, leftId) +: rest, (id, b2, leftId, rightId) +: branches, leafs, idRest)
}
}
@tailrec
def goUp[A, B](branches: Seq[(Id, B, Id, Id)], nodes: Map[Id, DTree[A, B]]): Map[Id, DTree[A, B]] =
branches match {
case Nil => nodes
case (id, b, leftId, rightId) :: rest =>
goUp(rest, nodes + (id -> DTBranch(nodes(leftId), nodes(rightId), b)))
}
val rootId #:: restIds = ids
val (branches, leafs) = goDown(Seq((a, b, rootId)), Seq(), Map(), restIds)
goUp(branches, leafs)(rootId)
}
// try it out
def split(xs: Seq[Int], b: Int) =
if (xs.size > 1) {
val (left, right) = xs.splitAt(xs.size / 2)
Some((left, right, b + 1))
} else {
None
}
val tree = mktree(0 to 1000, 0, split _, Stream.from(0))
println(tree)