Combinatorics in Scala: How to iterate/enumerate all possibilities to merge multiple sequences/lists (riffle shuffle permutations)

Updated question:
In my original question I did not know how to refer to the following problem. To clarify my question, I added the following illustration from Wikipedia:
It turns out that the problem is also named after this analogy: Riffle shuffle permutations. Based on this terminology my question simply becomes: How can I iterate/enumerate all riffle shuffle permutations in the general case of multiple decks?
Original question:
Let's assume we are given multiple sequences and we want to merge these sequences into one single sequence. The resulting sequence should preserve the order of the original sequences. Think of merging multiple stacks of cards (say Seq[Seq[T]]) into one single stack (Seq[T]) by randomly drawing a card from any (random) stack. All input stacks should be fully merged into the resulting stack. How can I iterate or enumerate all possible compositions of such a resulting sequence?
To clarify: if I have three stacks A, B, C (of, say, 5 elements each), then I do not just want the six possible arrangements of these stacks like "all of A, all of B, all of C" and "all of A, all of C, all of B" etc. I want all possible compositions like "1st of A, 1st of B, 2nd of A, 1st of C, 3rd of A, 2nd of B, ...".
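For a sense of scale (a standard fact about multinomial coefficients, not something stated in the question): the number of riffle shuffle permutations of stacks with sizes n1, ..., nk is (n1 + ... + nk)! / (n1! * ... * nk!). A minimal sketch that computes it; the name RiffleCount is only for illustration:

```scala
object RiffleCount {
  // Multinomial coefficient (n1 + ... + nk)! / (n1! * ... * nk!).
  // A riffle is fixed by choosing which positions of the merged stack come
  // from which input stack; the order inside each stack is already forced.
  def count(sizes: Seq[Int]): BigInt = {
    def fact(n: Int): BigInt = (1 to n).foldLeft(BigInt(1))(_ * _)
    fact(sizes.sum) / sizes.map(fact).product
  }
}
```

For three stacks of five cards, RiffleCount.count(Seq(5, 5, 5)) is 756756, so materialising every riffle quickly gets expensive.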
Since I'm a bit under the weather today, my first approach is terribly ugly and also produces duplicates:
def enumerateCompositions[T](stacks: Seq[Seq[T]], prefix: Seq[T]): Seq[Seq[T]] = {
if (stacks.length == 0) return {
Seq(prefix)
}
stacks.zipWithIndex.flatMap{ case (stack, stackInd) =>
if (stack.length > 0) {
val stacksWithHeadRemoved = stacks.indices.map{ i =>
if (i != stackInd) stacks(i) else stacks(i).drop(1)
}
enumerateCompositions(stacksWithHeadRemoved, prefix :+ stack.head)
} else {
val remainingStacks = stacks.indices.filterNot(_ == stackInd).map(i => stacks(i))
enumerateCompositions(remainingStacks, prefix)
}
}
}
Any idea how to make this more elegant and get rid of the duplicates?
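The duplicates come from the else branch: when a stack is already empty, the function recurses on the remaining stacks with the same prefix, which re-enumerates exactly what the non-empty branches of the flatMap already produce. A minimal sketch of the same recursion that filters out empty stacks first instead (RiffleEnum is an illustrative name):

```scala
object RiffleEnum {
  // Same idea as the question's enumerateCompositions, but empty stacks are
  // dropped up front, so every branch draws a real card and no branch
  // repeats the enumeration for an unchanged prefix.
  def enumerateCompositions[T](stacks: Seq[Seq[T]], prefix: Seq[T] = Seq.empty): Seq[Seq[T]] = {
    val nonEmpty = stacks.filter(_.nonEmpty)
    if (nonEmpty.isEmpty) Seq(prefix)
    else nonEmpty.indices.flatMap { i =>
      // draw the top card of stack i and recurse on what is left
      val rest = nonEmpty.updated(i, nonEmpty(i).tail)
      enumerateCompositions(rest, prefix :+ nonEmpty(i).head)
    }
  }
}
```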

Let's call this operation "to riffle". Here is a clean idiomatic solution:
def allRiffles[T](stack1: List[T], stack2: List[T]): List[List[T]] =
(stack1, stack2) match {
case (x :: xs, y :: ys) => {
allRiffles(xs, stack2).map(x :: _) ++
allRiffles(stack1, ys).map(y :: _)
}
case _ => List(stack1 ++ stack2) // at least one is empty
}
def allRifflesSeq[T](stacks: Seq[List[T]]): List[List[T]] =
stacks.foldLeft(List(List[T]())) { (z, x) =>
z.flatMap(y => allRiffles(y, x))
}
allRiffles will produce all the possible rifflings of two stacks. allRifflesSeq will take a sequence of stacks and produce all the possible rifflings using a fold. For example, if allRifflesSeq is given stacks A, B, and C, it first produces all possible rifflings of A and B and then riffles C into each of those rifflings.
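When experimenting with allRiffles or allRifflesSeq, it can help to check each result mechanically. A small hypothetical helper (not part of the answer), assuming all cards are distinct so that each card identifies the stack it came from:

```scala
object RiffleCheck {
  // True if `merged` has exactly as many cards as all stacks together and
  // each stack's cards appear in `merged` in their original relative order.
  // Assumes every card is distinct across all stacks.
  def isRiffleOf[T](merged: List[T], stacks: List[List[T]]): Boolean =
    merged.size == stacks.map(_.size).sum &&
      stacks.forall { s =>
        val cards = s.toSet
        merged.filter(cards) == s
      }
}
```

For example, allRifflesSeq(stacks).forall(RiffleCheck.isRiffleOf(_, stacks.toList)) should hold for any input with distinct cards.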
Note that allRiffles consumes stack space proportional to the combined length of its two input stacks, and allRifflesSeq consumes stack space bounded by the total number of cards across all stacks. Also, the returned list can be huge (combinatorial explosion) and consume a lot of heap space. An Iterator-based solution is safer, but much less pretty:
def allRiffles[T](stacks: List[List[T]]): Iterator[List[T]] = new Iterator[List[T]] {
type Frame = (List[List[T]], List[List[T]], List[T])
var stack = List[Frame]((Nil, stacks, Nil))
var ready = false
var cachedHasNext: Boolean = _
var cachedNext: List[T] = _
def computeNext: Unit = {
while (stack.nonEmpty) {
val (doneStacks, stacks, prefix) :: stackTail = stack
stack = stackTail
stacks match {
case Nil => {
cachedNext = prefix.reverse
cachedHasNext = true
return
}
case Nil :: rest =>
stack ::= (doneStacks, rest, prefix)
case (xs@(x :: xtail)) :: rest =>
if (rest.nonEmpty)
stack ::= (xs :: doneStacks, rest, prefix)
val newStacks = doneStacks.reverse ++ (if (xtail.isEmpty) rest
else xtail :: rest)
stack ::= (Nil, newStacks, x :: prefix)
}
}
cachedHasNext = false
}
def ensureReady = {
if (!ready) {
computeNext
ready = true
}
}
def next = {
ensureReady
if (cachedHasNext) {
val next = cachedNext
ready = false
next
} else Iterator.empty.next
}
def hasNext = {
ensureReady
cachedHasNext
}
}
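If stack safety is not the main concern, a lazy Iterator can also be written compactly with flatMap. A sketch (LazyRiffles is an illustrative name): results are produced on demand, so the full list is never materialised, but the nesting depth still grows with the total number of cards, so it lacks the stack safety of the explicit-frame iterator above.

```scala
object LazyRiffles {
  // Lazily enumerates every riffle of the given stacks. Each step picks a
  // non-empty stack, takes its top card, and recurses on what remains.
  def riffles[T](stacks: List[List[T]]): Iterator[List[T]] = {
    val nonEmpty = stacks.filter(_.nonEmpty)
    if (nonEmpty.isEmpty) Iterator(Nil)
    else nonEmpty.indices.iterator.flatMap { i =>
      val rest = nonEmpty.updated(i, nonEmpty(i).tail)
      riffles(rest).map(nonEmpty(i).head :: _)
    }
  }
}
```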

I have written this in Java for the two-stack case. Here is the code:
import java.util.*;
import java.io.*;
public class RiffleShufflePermutation {
protected static ArrayList<String> a1 = null;
protected static ArrayList<String> a2 = null;
protected static ArrayList<ArrayList<String>> a = new ArrayList<ArrayList<String>>();
private static int getStartingPosition(ArrayList<String> inA, String inS) {
for(String s : inA)
if (s.equals(inS))
return inA.indexOf(s)+1;
return 0;
}
private static void shootRiffle(int previous, int current) {
ArrayList<ArrayList<String>> newAA = new ArrayList<ArrayList<String>>();
ArrayList<String> newA;
int start;
for(ArrayList<String> al : a) {
start = (previous < 0)?0:getStartingPosition(al,a2.get(previous));
for(int i=start; i<=al.size(); i++) {
newA = new ArrayList<String>();
newA.addAll(al);
newA.add(i,a2.get(current));
newAA.add(newA);
}
}
a.clear();
a.addAll(newAA);
}
public static void main(String[] args) {
a1 = new ArrayList<String>(Arrays.asList("a1","a2","a3"));
a2 = new ArrayList<String>(Arrays.asList("b1","b2"));
a.add(a1);
for (int i=0; i<a2.size(); i++)
shootRiffle(i-1,i);
int i = 0;
for (ArrayList<String> s : a)
System.out.println(String.format("%2d",++i)+":"+s);
}
}
Here is the output:
1:[b1, b2, a1, a2, a3]
2:[b1, a1, b2, a2, a3]
3:[b1, a1, a2, b2, a3]
4:[b1, a1, a2, a3, b2]
5:[a1, b1, b2, a2, a3]
6:[a1, b1, a2, b2, a3]
7:[a1, b1, a2, a3, b2]
8:[a1, a2, b1, b2, a3]
9:[a1, a2, b1, a3, b2]
10:[a1, a2, a3, b1, b2]
Hopefully this is useful.
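The Java code builds each result by inserting every card of the second list at any position after the previously inserted card. A Scala sketch of the same insertion idea for two stacks (riffleByInsertion is an illustrative name); it yields the ten results in the same order as the output above:

```scala
object InsertRiffle {
  // Riffles `inserts` into `base`: each card of `inserts` may go at any
  // position after the previously inserted card (the idea behind shootRiffle).
  def riffleByInsertion[T](base: List[T], inserts: List[T]): List[List[T]] =
    // state: (partial result, first position allowed for the next card)
    inserts.foldLeft(List((base, 0))) { case (acc, card) =>
      acc.flatMap { case (seq, start) =>
        (start to seq.length).map(i => (seq.patch(i, List(card), 0), i + 1))
      }
    }.map(_._1)
}
```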


Add a marker character between duplicate pair in list

I am working on an exercise in which I need to figure out how to add a designated marker char between two duplicate elements in a list.
input - a string
output - a list of string pairs
Two rules:
If the input string has duplicate characters, a char x needs to be added between them. For example, trees becomes tr, ex, es.
If the duplicate char pair is xx, add a q between them. For example, boxx becomes bo, xq, x.
Both rules apply together. For example, if the input is HelloScalaxxxx the output should be List("He", "lx", "lo", "Sc", "al", "ax", "xq", "xq", "x")
I got the first rule working with the following code, but I am struggling to get the second rule satisfied.
input.foldRight[List[Char]](Nil) {
case (h, t) =>
println(h :: t)
if (t.nonEmpty) {
(h, t.head) match {
case ('x', 'x') => t ::: List(h, 'q')
case _ => if (h == t.head) h :: 'x' :: t else h :: t
}
} else h :: t
}
.mkString("").grouped(2).toSeq
I think I am close, for the input HelloScalaxxxx it produces List("He", "lx", "lo", "Sc", "al", "ax", "xq", "xq", "xq"), but with an extra q in the last pair.
I don't want to use a regex-based solution. Looking for an idiomatic Scala version.
I tried searching for existing answers but no luck. Any help would be appreciated. Thank you.
I assume you want to apply the xx rule first...but you can decide.
"Trees & Scalaxxxx"
.replaceAll("(x)(?=\\1)","$1q")
.replaceAll("([^x])(?=\\1)","$1x")
.grouped(2).toList
//res0: List[String] = List(Tr, ex, es, " &", " S", ca, la, xq, xq, xq, x)
And here's the non-regex offering.
"Trees & Scalaxxxx"
.foldLeft(('_',"")){
case (('x',acc),'x') => ('x', s"${acc}qx")
case ((p,acc),c) if c == p &&
p.isLetter => ( c , s"${acc}x$c")
case ((_,acc),c) => ( c , s"$acc$c")
}._2.grouped(2).toList
Tail recursive solution
def processString(input: String): List[String] = {
@scala.annotation.tailrec
def inner(buffer: List[String], str: String): List[String] = {
// recursion ending condition. Nothing left to process
if (str.isEmpty) return buffer
val c0 = str.head
val c1 = if (str.isDefinedAt(1)) {
str(1)
} else {
// recursion ending condition. Only head remains.
return buffer :+ c0.toString
}
val (newBuffer, remainingString) =
(c0, c1) match {
case ('x', 'x') => (buffer :+ "xq", str.substring(1))
case (_, _) if c0 == c1 => (buffer :+ s"${c0}x", str.substring(1))
case _ => (buffer :+ s"$c0$c1", str.substring(2))
}
inner(newBuffer, remainingString)
}
// start here. Pass empty buffer and complete input string
inner(List.empty, input)
}
println(processString("trees"))
println(processString("boxx"))
println(processString("HelloScalaxxxx"))
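The same pairwise logic can also be written with pattern matching on List[Char]. A minimal sketch (pairs is an illustrative name); unlike processString it is not tail recursive, so it assumes reasonably short input:

```scala
object MarkerPairs {
  // Splits `s` into pairs: "xx" becomes "xq", any other doubled char c
  // becomes c + "x"; the second char of a doubled pair starts the next pair,
  // and a trailing single char is kept on its own.
  def pairs(s: String): List[String] = {
    def loop(cs: List[Char]): List[String] = cs match {
      case 'x' :: 'x' :: rest        => "xq" :: loop('x' :: rest)
      case a :: b :: rest if a == b  => s"${a}x" :: loop(b :: rest)
      case a :: b :: rest            => s"$a$b" :: loop(rest)
      case a :: Nil                  => List(a.toString)
      case Nil                       => Nil
    }
    loop(s.toList)
  }
}
```

For HelloScalaxxxx it returns List(He, lx, lo, Sc, al, ax, xq, xq, x), the output the question asks for.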

How can I elegantly return a map while doing work

I'm new to Scala, coming over from Java, and I'm having trouble elegantly returning a Map from this function. What's an elegant way to rewrite this function so it doesn't have this awful repetition?
val data = getData
if (someTest(data)) {
val D = doSomething(data)
val E = doWork(D)
if (someTest2(E)) {
val a = A()
val b = B()
Map(a -> b)
} else {
Map.empty
}
} else {
Map.empty
}
If you have a problem with connecting too many conditions with &&, you can put everything into the natural short-circuiting monad (namely Option), perform a bunch of filter and map steps on it, replace the result by Map(A() -> B()) if all the tests are successful, and then unwrap the Option with a getOrElse at the end:
Option(getData)
.filter(someTest)
.map(data => doWork(doSomething(data)))
.filter(someTest2)
.map(_ => Map(A() -> B()))
.getOrElse(Map.empty)
In this way, you can organize your code "more vertically".
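To try the Option chain, the question's functions need concrete definitions. Everything below except the chain itself is a hypothetical stand-in; the chain spells out the composition as a lambda, which works whether doSomething and doWork are methods or function values:

```scala
object OptionChainDemo {
  // Hypothetical stand-ins for the question's A, B, someTest, doSomething,
  // doWork and someTest2, chosen only so the snippet runs.
  case class A(); case class B()
  def someTest(data: Int): Boolean = data > 0
  def doSomething(data: Int): Int = data * 2
  def doWork(d: Int): String = d.toString
  def someTest2(e: String): Boolean = e.length < 3

  // Each filter short-circuits to None; getOrElse supplies the empty Map.
  def build(data: Int): Map[A, B] =
    Option(data)
      .filter(someTest)
      .map(d => doWork(doSomething(d)))
      .filter(someTest2)
      .map(_ => Map(A() -> B()))
      .getOrElse(Map.empty)
}
```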
Andrey's answer is correct, but the logic can also be written using a for comprehension:
(for {
data <- Option(getData) if someTest(data)
d = doSomething(data)
e = doWork(d) if someTest2(e)
} yield {
Map(A() -> B())
}).getOrElse(Map.empty)
This retains a bit more of the original form of the code, but it is a matter of taste which version to use. You can also put the if on a separate line if that makes it clearer.
Note that I have retained the values of d and e on the assumption that they are actually meaningful in the real code. If not then there can be a single if expression that does all the tests, as noted in other answers:
(for {
data <- Option(getData)
if someTest(data) && someTest2(doWork(doSomething(data)))
} yield {
Map(A() -> B())
}).getOrElse(Map.empty)
If the repetition you mean is the two else blocks returning Map.empty, you can rewrite the code to take advantage of short-circuit evaluation:
val data = getData
if (someTest(data) && someTest2(doWork(doSomething(data)))) {
val a = A()
val b = B()
Map(a -> b)
} else {
Map.empty
}
Second solution using lazy evaluation:
val data = getData
lazy val D = doSomething(data)
lazy val E = doWork(D)
if (someTest(data) && someTest2(E)) {
val a = A()
val b = B()
Map(a -> b)
} else {
Map.empty
}
D, E and someTest2(E) won't get evaluated if someTest(data) is false.
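A small check of that claim, using hypothetical doSomething/doWork stand-ins that count how often they run:

```scala
object LazyDemo {
  // Counts how often the (hypothetical) expensive steps actually run.
  var calls = 0
  def doSomething(data: Int): Int = { calls += 1; data * 2 }
  def doWork(d: Int): Int = { calls += 1; d + 1 }

  def build(data: Int): Map[String, Int] = {
    lazy val d = doSomething(data)
    lazy val e = doWork(d)
    // e (and therefore d) is forced only if the first test passes;
    // once forced, the lazy vals are cached and not recomputed.
    if (data > 0 && e < 100) Map("e" -> e) else Map.empty
  }
}
```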

Topological sort in scala

I'm looking for a nice implementation of topological sorting in scala.
The solution should be stable:
If input is already sorted, the output should be unchanged
The algorithm should be deterministic (hashCode has no effect)
I suspect there are libraries that can do this, but I'd rather not add nontrivial dependencies just for this.
Example problem:
case class Node(name: String)(val referenced: Node*)
val a = Node("a")()
val b = Node("b")(a)
val c = Node("c")(a)
val d = Node("d")(b, c)
val e = Node("e")(d)
val f = Node("f")()
assertEquals("Previous order is kept",
Vector(f, a, b, c, d, e),
topoSort(Vector(f, a, b, c, d, e)))
assertEquals(Vector(a, b, c, d, f, e),
topoSort(Vector(d, c, b, f, a, e)))
Here the order is defined such that, if the nodes were, say, declarations in a programming language referencing other declarations, the result order would be such that no declaration is used before it has been declared.
Here is my own solution. Additionally, it returns any loops detected in the input.
The format of the nodes is not fixed because the caller provides a visitor that
will take a node and a callback and call the callback for each referenced node.
If the loop reporting is not necessary, it should be easy to remove.
import scala.collection.mutable
// Based on https://en.wikipedia.org/wiki/Topological_sorting?oldformat=true#Depth-first_search
object TopologicalSort {
case class Result[T](result: IndexedSeq[T], loops: IndexedSeq[IndexedSeq[T]])
type Visit[T] = (T) => Unit
// A visitor is a function that takes a node and a callback.
// The visitor calls the callback for each node referenced by the given node.
type Visitor[T] = (T, Visit[T]) => Unit
def topoSort[T <: AnyRef](input: Iterable[T], visitor: Visitor[T]): Result[T] = {
// Buffer, because it is operated in a stack like fashion
val temporarilyMarked = mutable.Buffer[T]()
val permanentlyMarked = mutable.HashSet[T]()
val loopsBuilder = IndexedSeq.newBuilder[IndexedSeq[T]]
val resultBuilder = IndexedSeq.newBuilder[T]
def visit(node: T): Unit = {
if (temporarilyMarked.contains(node)) {
val loopStartIndex = temporarilyMarked.indexOf(node)
val loop = temporarilyMarked.slice(loopStartIndex, temporarilyMarked.size)
.toIndexedSeq
loopsBuilder += loop
} else if (!permanentlyMarked.contains(node)) {
temporarilyMarked += node
visitor(node, visit)
permanentlyMarked += node
temporarilyMarked.remove(temporarilyMarked.size - 1, 1)
resultBuilder += node
}
}
for (i <- input) {
if (!permanentlyMarked.contains(i)) {
visit(i)
}
}
Result(resultBuilder.result(), loopsBuilder.result())
}
}
In the example of the question this would be applied like this:
import TopologicalSort._
def visitor(node: Node, callback: Node => Unit): Unit = {
node.referenced.foreach(callback)
}
assertEquals("Previous order is kept",
Vector(f, a, b, c, d, e),
topoSort(Vector(f, a, b, c, d, e), visitor).result)
assertEquals(Vector(a, b, c, d, f, e),
topoSort(Vector(d, c, b, f, a, e), visitor).result)
Some thoughts on complexity:
The worst case complexity of this solution is actually above O(n + m) because the temporarilyMarked array is scanned for each node.
The asymptotic complexity could be improved by replacing temporarilyMarked with, for example, a HashSet.
A true O(n + m) would be achieved if the marks were stored directly inside the nodes, but storing them outside makes writing a generic solution easier.
I haven't run any performance tests, but I suspect scanning the temporarilyMarked array is not a problem even in large graphs as long as they are not very deep.
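One way to get constant-time membership checks without losing the path order needed for loop reporting is a mutable.LinkedHashSet (a swap-in, not the plain HashSet mentioned above). A sketch under that assumption, keeping the same Result/Visitor shape as the solution above:

```scala
import scala.collection.mutable

object TopologicalSortLinked {
  case class Result[T](result: IndexedSeq[T], loops: IndexedSeq[IndexedSeq[T]])
  type Visitor[T] = (T, T => Unit) => Unit

  def topoSort[T <: AnyRef](input: Iterable[T], visitor: Visitor[T]): Result[T] = {
    // The current DFS path: O(1) contains, but insertion order is kept,
    // so the loop can still be sliced out when a back edge is found.
    val onPath = mutable.LinkedHashSet[T]()
    val done = mutable.HashSet[T]()
    val loops = IndexedSeq.newBuilder[IndexedSeq[T]]
    val result = IndexedSeq.newBuilder[T]

    def visit(node: T): Unit =
      if (onPath.contains(node)) {
        // back edge: report the path segment from `node` to the current node
        loops += onPath.iterator.dropWhile(_ != node).toIndexedSeq
      } else if (!done.contains(node)) {
        onPath += node
        visitor(node, visit)
        onPath -= node // node is the most recently added path element
        done += node
        result += node
      }

    input.foreach(visit)
    Result(result.result(), loops.result())
  }
}
```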
Example code and test on Github
Very similar code is also published here. That version has a test suite which can be useful for experimenting with and exploring the implementation.
Why would you detect loops?
Detecting loops can be useful for example in serialization situations where most of the data can be handled as a DAG, but loops can be handled with some kind of special arrangement.
The test suite in the Github code linked in the section above contains various cases with multiple loops.
Here's a purely functional implementation that returns the topological ordering ONLY if the graph is acyclic.
case class Node(label: Int)
case class Graph(adj: Map[Node, Set[Node]]) {
case class DfsState(discovered: Set[Node] = Set(), activeNodes: Set[Node] = Set(), tsOrder: List[Node] = List(),
isCylic: Boolean = false)
def dfs: (List[Node], Boolean) = {
def dfsVisit(currState: DfsState, src: Node): DfsState = {
val newState = currState.copy(discovered = currState.discovered + src, activeNodes = currState.activeNodes + src,
isCylic = currState.isCylic || adj(src).exists(currState.activeNodes))
val finalState = adj(src).filterNot(newState.discovered).foldLeft(newState)(dfsVisit(_, _))
finalState.copy(tsOrder = src :: finalState.tsOrder, activeNodes = finalState.activeNodes - src)
}
val stateAfterSearch = adj.keys.foldLeft(DfsState()) {(state, n) => if (state.discovered(n)) state else dfsVisit(state, n)}
(stateAfterSearch.tsOrder, stateAfterSearch.isCylic)
}
def topologicalSort: Option[List[Node]] = dfs match {
case (topologicalOrder, false) => Some(topologicalOrder)
case _ => None
}
}

Combining multiple Lists of arbitrary length

I am looking for an approach to join multiple Lists in the following manner:
ListA a b c
ListB 1 2 3 4
ListC + # * § %
..
..
..
Resulting List: a 1 + b 2 # c 3 * 4 § %
In Words: The elements in sequential order, starting at first list combined into the resulting list. An arbitrary amount of input lists could be there varying in length.
I tried multiple approaches with variants of zip and sliding iterators, but none worked, especially with lists of varying length. There has to be an elegant way in Scala ;)
val lists = List(ListA, ListB, ListC)
lists.flatMap(_.zipWithIndex).sortBy(_._2).map(_._1)
It's pretty self-explanatory. It just zips each value with its position on its respective list, sorts by index, then pulls the values back out.
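This relies on sortBy being stable (the Scala collections documentation states that sorted and sortBy are stable), so elements that share an index keep the order of the lists they came from. A typed sketch of the same one-liner (Interleave is an illustrative name):

```scala
object Interleave {
  // Round-robin merge: every list's i-th element comes before any (i+1)-th
  // element; the stable sort keeps list order among equal indices.
  def interleave[A](lists: List[List[A]]): List[A] =
    lists.flatMap(_.zipWithIndex).sortBy(_._2).map(_._1)
}
```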
Here's how I would do it:
class ListTests extends FunSuite {
test("The three lists from his example") {
val l1 = List("a", "b", "c")
val l2 = List(1, 2, 3, 4)
val l3 = List("+", "#", "*", "§", "%")
// All lists together
val l = List(l1, l2, l3)
// Max length of a list (to pad the shorter ones)
val maxLen = l.map(_.size).max
// Wrap the elements in Option and pad with None
val padded = l.map { list => list.map(Some(_)) ++ Stream.continually(None).take(maxLen - list.size) }
// Transpose
val trans = padded.transpose
// Flatten the lists then flatten the options
val result = trans.flatten.flatten
// Voilà
assert(List("a", 1, "+", "b", 2, "#", "c", 3, "*", 4, "§", "%") === result)
}
}
Here's an imperative solution if efficiency is paramount:
def combine[T](xss: List[List[T]]): List[T] = {
val b = List.newBuilder[T]
var its = xss.map(_.iterator)
while (!its.isEmpty) {
its = its.filter(_.hasNext)
its.foreach(b += _.next)
}
b.result
}
You can use padTo, transpose, and flatten to good effect here:
lists.map(_.map(Some(_)).padTo(lists.map(_.length).max, None)).transpose.flatten.flatten
Here's a small recursive solution.
def flatList(lists: List[List[Any]]) = {
def loop(output: List[Any], xss: List[List[Any]]): List[Any] = (xss collect { case x :: xs => x }) match {
case Nil => output
case heads => loop(output ::: heads, xss.collect({ case x :: xs => xs }))
}
loop(List[Any](), lists)
}
And here is a simple streams approach which can cope with an arbitrary sequence of sequences, each of potentially infinite length.
def flatSeqs[A](ssa: Seq[Seq[A]]): Stream[A] = {
def seqs(xss: Seq[Seq[A]]): Stream[Seq[A]] = xss collect { case xs if !xs.isEmpty => xs.head } match {
case Seq() => Stream.empty // matches any empty Seq, not just Nil
case heads => heads #:: seqs(xss collect { case xs if !xs.isEmpty => xs.tail })
}
seqs(ssa).flatten
}
Here's something short but not exceedingly efficient:
def heads[A](xss: List[List[A]]) = xss.map(_.splitAt(1)).unzip
def interleave[A](xss: List[List[A]]) = Iterator.
iterate(heads(xss)){ case (_, tails) => heads(tails) }.
map(_._1.flatten).
takeWhile(! _.isEmpty).
flatten.toList
Here's a recursive solution that's O(n). The accepted solution (using sort) is O(n log n). Some testing I've done suggests the second solution using transpose is also O(n log n) due to the implementation of transpose. The use of reverse below looks suspicious (since it's an O(n) operation itself), but convince yourself that it can't be called too often or on too-large lists.
def intercalate[T](lists: List[List[T]]) : List[T] = {
def intercalateHelper(newLists: List[List[T]], oldLists: List[List[T]], merged: List[T]): List[T] = {
(newLists, oldLists) match {
case (Nil, Nil) => merged
case (Nil, zss) => intercalateHelper(zss.reverse, Nil, merged)
case (Nil::xss, zss) => intercalateHelper(xss, zss, merged)
case ( (y::ys)::xss, zss) => intercalateHelper(xss, ys::zss, y::merged)
}
}
intercalateHelper(lists, List.empty, List.empty).reverse
}

Encoding recursive tree-creation with while loop + stacks

I'm a bit embarrassed to admit this, but I seem to be pretty stumped by what should be a simple programming problem. I'm building a decision tree implementation, and have been using recursion to take a list of labeled samples, recursively split the list in half, and turn it into a tree.
Unfortunately, with deep trees I run into stack overflow errors (ha!), so my first thought was to use continuations to turn it into tail recursion. Unfortunately Scala doesn't support that kind of TCO, so the only solution is to use a trampoline. A trampoline seems kinda inefficient and I was hoping there would be some simple stack-based imperative solution to this problem, but I'm having a lot of trouble finding it.
The recursive version looks sort of like (simplified):
private def trainTree(samples: Seq[Sample], usedFeatures: Set[Int]): DTree = {
if (shouldStop(samples)) {
DTLeaf(makeProportions(samples))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
DTBranch(
trainTree(statsWithFeature, usedFeatures + featureIdx),
trainTree(statsWithoutFeature, usedFeatures + featureIdx),
featureIdx)
}
}
So basically I'm recursively subdividing the list into two according to some feature of the data, and passing through a list of used features so I don't repeat - that's all handled in the "getSplittingFeature" function so we can ignore it. The code is really simple! Still, I'm having trouble figuring out a stack-based solution that doesn't just use closures and effectively become a trampoline. I know we'll at least have to keep around little "frames" of arguments in the stack but I would like to avoid closure calls.
I get that I should be writing out explicitly what the call stack and program counter handle for me implicitly in the recursive solution, but I'm having trouble doing that without continuations. At this point it's hardly even about efficiency, I'm just curious. So please, no need to remind me that premature optimization is the root of all evil and the trampoline-based solution will probably work just fine. I know it probably will - this is basically a puzzle for its own sake.
Can anyone tell me what the canonical while-loop-and-stack-based solution to this sort of thing is?
UPDATE: Based on Thipor Kong's excellent solution, I've coded up a while-loops/stacks/hashtable based implementation of the algorithm that should be a direct translation of the recursive version. This is exactly what I was looking for:
FINAL UPDATE: I've used sequential integer indices, as well as putting everything back into arrays instead of maps for performance, added maxDepth support, and finally have a solution with the same performance as the recursive version (not sure about memory usage but I would guess less):
private def trainTreeNoMaxDepth(startingSamples: Seq[Sample], startingMaxDepth: Int): DTree = {
import scala.collection.mutable.{ArrayBuffer, Stack}
// Use ArrayBuffer as a dense mutable int-indexed map: no IndexOutOfBoundsException, just expand to fit
type DenseIntMap[T] = ArrayBuffer[T]
def updateIntMap[@specialized T](ab: DenseIntMap[T], idx: Int, item: T, dfault: T = null.asInstanceOf[T]) = {
if (ab.length <= idx) {ab.insertAll(ab.length, Iterable.fill(idx - ab.length + 1)(dfault)) }
ab.update(idx, item)
}
var currentChildId = 0 // get childIdx or create one if it's not there already
def child(childMap: DenseIntMap[Int], heapIdx: Int) =
if (childMap.length > heapIdx && childMap(heapIdx) != -1) childMap(heapIdx)
else {currentChildId += 1; updateIntMap(childMap, heapIdx, currentChildId, -1); currentChildId }
// go down
val leftChildren, rightChildren = new DenseIntMap[Int]() // heapIdx -> childHeapIdx
val todo = Stack((startingSamples, Set.empty[Int], startingMaxDepth, 0)) // samples, usedFeatures, maxDepth, heapIdx
val branches = new Stack[(Int, Int)]() // heapIdx, featureIdx
val nodes = new DenseIntMap[DTree]() // heapIdx -> node
while (!todo.isEmpty) {
val (samples, usedFeatures, maxDepth, heapIdx) = todo.pop()
if (shouldStop(samples) || maxDepth == 0) {
updateIntMap(nodes, heapIdx, DTLeaf(makeProportions(samples)))
} else {
val featureIdx = getSplittingFeature(samples, usedFeatures)
val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
todo.push((statsWithFeature, usedFeatures + featureIdx, maxDepth - 1, child(leftChildren, heapIdx)))
todo.push((statsWithoutFeature, usedFeatures + featureIdx, maxDepth - 1, child(rightChildren, heapIdx)))
branches.push((heapIdx, featureIdx))
}
}
// go up
while (!branches.isEmpty) {
val (heapIdx, featureIdx) = branches.pop()
updateIntMap(nodes, heapIdx, DTBranch(nodes(child(leftChildren, heapIdx)), nodes(child(rightChildren, heapIdx)), featureIdx))
}
nodes(0)
}
Just store the binary tree in an array, as described on Wikipedia: for node i, the left child goes into 2*i+1 and the right child into 2*i+2. When going "down", you keep a collection of todos that still have to be split to reach a leaf. Once you've got only leaves, go upward (from right to left in the array) to build the decision nodes:
Update: A cleaned-up version that also supports the features stored in the branches (type parameter B), that is more functional/fully pure, and that supports sparse trees with a map as suggested by ron.
Update2-3: Make economical use of the namespace for node ids and abstract over the type of ids to allow for large trees. Take node ids from a Stream.
sealed trait DTree[A, B]
case class DTLeaf[A, B](a: A, b: B) extends DTree[A, B]
case class DTBranch[A, B](left: DTree[A, B], right: DTree[A, B], b: B) extends DTree[A, B]
def mktree[A, B, Id](a: A, b: B, split: (A, B) => Option[(A, A, B)], ids: Stream[Id]) = {
import scala.annotation.tailrec
@tailrec
def goDown(todo: Seq[(A, B, Id)], branches: Seq[(Id, B, Id, Id)], leafs: Map[Id, DTree[A, B]], ids: Stream[Id]): (Seq[(Id, B, Id, Id)], Map[Id, DTree[A, B]]) =
todo match {
case Nil => (branches, leafs)
case (a, b, id) :: rest =>
split(a, b) match {
case None =>
goDown(rest, branches, leafs + (id -> DTLeaf(a, b)), ids)
case Some((left, right, b2)) =>
val leftId #:: rightId #:: idRest = ids
goDown((right, b2, rightId) +: (left, b2, leftId) +: rest, (id, b2, leftId, rightId) +: branches, leafs, idRest)
}
}
@tailrec
def goUp[A, B](branches: Seq[(Id, B, Id, Id)], nodes: Map[Id, DTree[A, B]]): Map[Id, DTree[A, B]] =
branches match {
case Nil => nodes
case (id, b, leftId, rightId) :: rest =>
goUp(rest, nodes + (id -> DTBranch(nodes(leftId), nodes(rightId), b)))
}
val rootId #:: restIds = ids
val (branches, leafs) = goDown(Seq((a, b, rootId)), Seq(), Map(), restIds)
goUp(branches, leafs)(rootId)
}
// try it out
def split(xs: Seq[Int], b: Int) =
if (xs.size > 1) {
val (left, right) = xs.splitAt(xs.size / 2)
Some((left, right, b + 1))
} else {
None
}
val tree = mktree(0 to 1000, 0, split _, Stream.from(0))
println(tree)
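The array layout described at the start of this answer (children of node i at 2*i+1 and 2*i+2, filled bottom-up from right to left) can be illustrated on its own. A sketch that builds a sum tree instead of a decision tree, assuming the number of leaves is a power of two so the leaves end up in left-to-right order; HeapIndexing and buildSums are illustrative names:

```scala
object HeapIndexing {
  // Array-backed binary tree: children of node i sit at 2*i + 1 and 2*i + 2,
  // and the parent of node i > 0 sits at (i - 1) / 2.
  def left(i: Int): Int = 2 * i + 1
  def right(i: Int): Int = 2 * i + 2
  def parent(i: Int): Int = (i - 1) / 2

  // Leaves occupy the last n slots; internal nodes are filled from right to
  // left, so both children of i are always computed before i itself.
  def buildSums(leaves: Vector[Int]): Array[Int] = {
    val n = leaves.size
    val tree = new Array[Int](2 * n - 1)
    for (k <- 0 until n) tree(n - 1 + k) = leaves(k)
    for (i <- n - 2 to 0 by -1) tree(i) = tree(left(i)) + tree(right(i))
    tree
  }
}
```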