N-Tree Traversal with Scala Causes Stack Overflow - scala

I am attempting to return a list of widgets from an N-tree data structure. In my unit test, if I have roughly 2000 widgets, each with a single dependency, I encounter a stack overflow. What I think is happening is that the for loop prevents my tree traversal from being tail recursive. What's a better way of writing this in Scala? Here's my function:
protected def getWidgetTree(key: String) : ListBuffer[Widget] = {
def traverseTree(accumulator: ListBuffer[Widget], current: Widget) : ListBuffer[Widget] = {
accumulator.append(current)
if (!current.hasDependencies) {
accumulator
} else {
for (dependencyKey <- current.dependencies) {
if (accumulator.findIndexOf(_.name == dependencyKey) == -1) {
traverseTree(accumulator, getWidget(dependencyKey))
}
}
accumulator
}
}
traverseTree(ListBuffer[Widget](), getWidget(key))
}

The reason it's not tail-recursive is that you are making multiple recursive calls inside your function. To be tail-recursive, a recursive call can only be the last expression in the function body. After all, the whole point is that it works like a while-loop (and, thus, can be transformed into a loop). A loop can't call itself multiple times within a single iteration.
To do a tree traversal like this, you can use a queue to carry forward the nodes that need to be visited.
Assume we have this tree:
// 1
// / \
// 2 5
// / \
// 3 4
Represented with this simple data structure:
case class Widget(name: String, dependencies: List[String]) {
def hasDependencies = dependencies.nonEmpty
}
And we have this map pointing to each node:
val getWidget = List(
Widget("1", List("2", "5")),
Widget("2", List("3", "4")),
Widget("3", List()),
Widget("4", List()),
Widget("5", List()))
.map { w => w.name -> w }.toMap
Now we can rewrite your method to be tail-recursive:
import scala.annotation.tailrec
def getWidgetTree(key: String): List[Widget] = {
@tailrec
def traverseTree(queue: List[String], accumulator: List[Widget]): List[Widget] = {
queue match {
case currentKey :: queueTail => // the queue is not empty
val current = getWidget(currentKey) // get the element at the front
val newQueueItems = // filter out the dependencies already known
current.dependencies.filterNot(dependencyKey =>
accumulator.exists(_.name == dependencyKey) && !queue.contains(dependencyKey))
traverseTree(newQueueItems ::: queueTail, current :: accumulator) //
case Nil => // the queue is empty
accumulator.reverse // we're done
}
}
traverseTree(key :: Nil, List[Widget]())
}
And test it out:
for (k <- 1 to 5)
println(getWidgetTree(k.toString).map(_.name))
prints:
List(1, 2, 3, 4, 5)
List(2, 3, 4)
List(3)
List(4)
List(5)

For the same example as in @dhg's answer, an equivalent tail-recursive function with no mutable state (the ListBuffer) would be:
case class Widget(name: String, dependencies: List[String])
val getWidget = List(
Widget("1", List("2", "5")),
Widget("2", List("3", "4")),
Widget("3", List()),
Widget("4", List()),
Widget("5", List())).map { w => w.name -> w }.toMap
def getWidgetTree(key: String): List[Widget] = {
def addIfNotAlreadyContained(widgetList: List[Widget], widgetNameToAdd: String): List[Widget] = {
if (widgetList.find(_.name == widgetNameToAdd).isDefined) widgetList
else widgetList :+ getWidget(widgetNameToAdd)
}
@tailrec
def traverseTree(currentWidgets: List[Widget], acc: List[Widget]): List[Widget] = currentWidgets match {
case Nil => {
// If there are no more widgets in this branch return what we've traversed so far
acc
}
case Widget(name, Nil) :: rest => {
// If the first widget is a leaf traverse the rest and add the leaf to the list of traversed
traverseTree(rest, addIfNotAlreadyContained(acc, name))
}
case Widget(name, dependencies) :: rest => {
// If the first widget is a parent, traverse its children and the rest and add it to the list of traversed
traverseTree(dependencies.map(getWidget) ++ rest, addIfNotAlreadyContained(acc, name))
}
}
val root = getWidget(key)
traverseTree(root.dependencies.map(getWidget) :+ root, List[Widget]())
}
For the same test case
for (k <- 1 to 5)
println(getWidgetTree(k.toString).map(_.name))
Gives you:
List(2, 3, 4, 5, 1)
List(3, 4, 2)
List(3)
List(4)
List(5)
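Note that the order differs from the previous answer: the root comes last rather than first, so this is not a preorder traversal (nor a strict postorder one, since 2 still precedes its dependencies 3 and 4).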

Awesome, thanks! I didn't know about the @tailrec annotation; that's a pretty cool little gem. I had to tweak the solution a little because a widget with a self-reference was resulting in an endless loop. Also, newQueueItems was an Iterable when the call to traverseTree was expecting a List, so I had to call toList on it.
def getWidgetTree(key: String): List[Widget] = {
@tailrec
def traverseTree(queue: List[String], accumulator: List[Widget]): List[Widget] = {
queue match {
case currentKey :: queueTail => // the queue is not empty
val current = getWidget(currentKey) // get the element at the front
val newQueueItems = // filter out the dependencies already known
current.dependencies.filter(dependencyKey =>
!accumulator.exists(_.name == dependencyKey) && !queue.contains(dependencyKey)).toList
traverseTree(newQueueItems ::: queueTail, current :: accumulator) //
case Nil => // the queue is empty
accumulator.reverse // we're done
}
}
traverseTree(key :: Nil, List[Widget]())
}
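One possible refinement, sketched here rather than taken from the thread: accumulator.exists scans the whole accumulator for every dependency, so tracking the visited names in a Set keeps each lookup cheap and also covers self-references and cycles. The object name WidgetTraversal, the widgets map, and the sample data are assumptions made for illustration.

```scala
import scala.annotation.tailrec

object WidgetTraversal {
  case class Widget(name: String, dependencies: List[String])

  // Hypothetical lookup table: "a" references itself, and "b" points back to "a".
  val widgets: Map[String, Widget] = List(
    Widget("a", List("a", "b")),
    Widget("b", List("c", "a")),
    Widget("c", Nil)
  ).map(w => w.name -> w).toMap

  def getWidgetTree(key: String): List[Widget] = {
    @tailrec
    def loop(queue: List[String], seen: Set[String], acc: List[Widget]): List[Widget] =
      queue match {
        case Nil => acc.reverse                      // nothing left to visit
        case k :: rest if seen(k) => loop(rest, seen, acc) // already visited: skip
        case k :: rest =>
          val w = widgets(k)
          // push the dependencies in front of the remaining work and mark k as visited
          loop(w.dependencies ::: rest, seen + k, w :: acc)
      }
    loop(List(key), Set.empty, Nil)
  }
}
```

Here the duplicate check happens when a key is taken off the queue instead of when it is added, which keeps the filter logic out of the recursive call.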

Related

Conditional concatenation of iterator elements - A Scala idiomatic solution

I have an Iterator of Strings and would like to concatenate each element preceding one that matches a predicate, e.g. for an Iterator of
Iterator("a", "b", "c break", "d break", "e")
and a predicate of
!line.endsWith("break")
I would like to print out
(Group: 0): a-b-c break
(Group: 1): d break
(Group: 2): e
(without needing to hold in memory more than a single group at a time)
I know I can achieve this with an iterator like below, but there has to be a more "Scala" way of writing this, right?
import scala.collection.mutable.ListBuffer
object IteratingAndAccumulating extends App {
class AccumulatingIterator(lines: Iterator[String])extends Iterator[ListBuffer[String]] {
override def hasNext: Boolean = lines.hasNext
override def next(): ListBuffer[String] = getNextLine(lines, new ListBuffer[String])
def getNextLine(lines: Iterator[String], accumulator: ListBuffer[String]): ListBuffer[String] = {
val line = lines.next
accumulator += line
if (line.endsWith("break") || !lines.hasNext) accumulator
else getNextLine(lines, accumulator)
}
}
new AccumulatingIterator(Iterator("a", "b", "c break", "d break", "e"))
.map(_.mkString("-")).zipWithIndex.foreach{
case (conc, i) =>
println(s"(Group: $i): $conc")
}
}
many thanks,
Fil
Here is a simple solution if you don't mind loading the entire contents into memory at once:
val lines: List[List[String]] = it.foldLeft(List(List.empty[String])) {
case (head::tail, x) if predicate(x) => Nil :: (x::head) :: tail
case (head::tail, x) => (x::head ) :: tail
}.dropWhile(_.isEmpty).map(_.reverse).reverse
If you would rather iterate through the strings and groups one-by-one, it gets a little bit more involved:
// first "instrument" the iterator, by "demarcating" group boundaries with None:
val instrumented: Iterator[Option[String]] = it.flatMap {
case x if predicate(x) => Seq(Some(x), None)
case x => Seq(Some(x))
}
// And now, wrap it around into another iterator, constructing groups:
val lines: Iterator[Iterator[String]] = Iterator.continually {
instrumented.takeWhile(_.nonEmpty).flatten
}.takeWhile(_.nonEmpty)
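For comparison, here is a minimal sketch of a reusable helper that keeps only one group in memory at a time. The names GroupUntil and groupUntil are hypothetical, and isLast is assumed to be true for the element that closes a group (the opposite sense of the predicate written in the question). It uses a local var, so it trades purity for simplicity.

```scala
object GroupUntil {
  // Wraps an iterator so that each call to next() yields one group.
  // A group ends with the first element for which isLast holds, or when the input runs out.
  def groupUntil[A](it: Iterator[A])(isLast: A => Boolean): Iterator[List[A]] =
    new Iterator[List[A]] {
      def hasNext: Boolean = it.hasNext
      def next(): List[A] = {
        val group = List.newBuilder[A]
        var done = false
        while (!done && it.hasNext) {
          val x = it.next()
          group += x
          done = isLast(x)
        }
        group.result()
      }
    }
}
```

With the input from the question, GroupUntil.groupUntil(lines)(_.endsWith("break")).map(_.mkString("-")) can then be passed through zipWithIndex and printed as before.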

Working scala code using a var in a pure function. Is this possible without a var?

Is it possible (or even worthwhile) to write the code block below without a var? It works with a var. This is not for an interview; it's my first attempt at Scala (I come from Java).
The problem: Fit people as close to the front of a theatre as possible, while keeping each request (eg. Jones, 4 tickets) in a single theatre section. The theatre sections, starting at the front, are sized 6, 6, 3, 5, 5... and so on. I'm trying to accomplish this by putting together all of the potential groups of ticket requests, and then choosing the best fitting group per section.
Here are the classes. A SeatingCombination is one possible combination of SeatingRequest (just the IDs) and the sum of their ticketCount(s):
class SeatingCombination(val idList: List[Int], val seatCount: Int){}
class SeatingRequest(val id: Int, val partyName: String, val ticketCount: Int){}
class TheatreSection(val sectionSize: Int, rowNumber: Int, sectionNumber: Int) {
def id: String = rowNumber.toString + "_"+ sectionNumber.toString;
}
By the time we get to the below function...
1.) all of the possible combinations of SeatingRequest are in a list of SeatingCombination and ordered by descending size.
2.) all of the TheatreSection are listed in order.
def getSeatingMap(groups: List[SeatingCombination], sections: List[TheatreSection]): HashMap[Int, TheatreSection] = {
var seatedMap = new HashMap[Int, TheatreSection]
for (sect <- sections) {
val bestFitOpt = groups.find(g => { g.seatCount <= sect.sectionSize && !isAnyListIdInMap(seatedMap, g.idList) })
bestFitOpt.filter(_.idList.size > 0).foreach(_.idList.foreach(seatedMap.update(_, sect)))
}
seatedMap
}
def isAnyListIdInMap(map: HashMap[Int, TheatreSection], list: List[Int]): Boolean = {
(for (id <- list) yield !map.get(id).isEmpty).reduce(_ || _)
}
I wrote the rest of the program without a var, but in this iterative section it seems impossible. Maybe with my implementation strategy it's impossible. From what else I've read, a var in a pure function is still functional. But it's been bothering me I can't think of how to remove the var, because my textbook told me to try to avoid them, and I don't know what I don't know.
You can use foldLeft to iterate on sections with a running state (and again, inside, on your state to add iteratively all the ids in a section):
sections.foldLeft(Map.empty[Int, TheatreSection]){
case (seatedMap, sect) =>
val bestFitOpt = groups.find(g => g.seatCount <= sect.sectionSize && !isAnyListIdInMap(seatedMap, g.idList))
bestFitOpt.
filter(_.idList.size > 0).toList. //convert option to list
flatMap(_.idList). // flatten list from option and idList
foldLeft(seatedMap)(_ + (_ -> sect)) // add all ids to the map with sect as value
}
By the way, you can simplify the second method using exists and map.contains:
def isAnyListIdInMap(map: HashMap[Int, TheatreSection], list: List[Int]): Boolean = {
list.exists(id => map.contains(id))
}
list.exists(predicate: Int => Boolean) is a Boolean which is true if the predicate is true for any element in list.
map.contains(key) checks if map is defined at key.
If you want to be even more concise, you don't need to give a name to the argument of the predicate:
list.exists(map.contains)
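A small sketch of how this behaves, including the empty-list case. The object name ExistsDemo and the String values standing in for TheatreSection are assumptions made to keep the example self-contained.

```scala
object ExistsDemo {
  // Hypothetical seating state: request ids 1 and 4 are already seated.
  val seated: Map[Int, String] = Map(1 -> "1_1", 4 -> "1_2")

  // True if at least one id in the list is already a key of the map.
  def isAnyListIdInMap(map: Map[Int, String], list: List[Int]): Boolean =
    list.exists(map.contains)
}
```

Unlike the reduce(_ || _) version from the question, exists simply returns false for an empty list instead of throwing.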
Simply changing var to val should do it :)
I think you may be asking about getting rid of the mutable map, not the var (it doesn't need to be a var in your code).
Things like this are usually written recursively in Scala, or using foldLeft as the other answers suggest. Here is a recursive version:
@tailrec
def getSeatingMap(
groups: List[SeatingCombination],
sections: List[TheatreSection],
result: Map[Int, TheatreSection] = Map.empty): Map[Int, TheatreSection] = sections match {
case Nil => result
case head :: tail =>
val seated = groups
.iterator
.filter(_.idList.nonEmpty)
.filterNot(_.idList.exists(result.contains))
.find(_.seatCount <= head.sectionSize)
.fold(List.empty[(Int, TheatreSection)])(_.idList.map(id => id -> head)) // typed empty list, since fold(Nil) would infer Nil.type
getSeatingMap(groups, tail, result ++ seated)
}
By the way, I don't think you need to test every id in the list for presence in the map; it should suffice to look at the first one. You could probably also make it a bit more efficient if, instead of checking the map every time to see whether the group is already seated, you dropped it from the input list as soon as its section is assigned.
@tailrec
def selectGroup(
sect: TheatreSection,
groups: List[SeatingCombination],
result: List[SeatingCombination] = Nil
): (List[(Int, TheatreSection)], List[SeatingCombination]) = groups match {
case Nil => (Nil, result.reverse) // nothing fits: return the groups in their original order
case head :: tail
if (head.idList.nonEmpty && head.seatCount <= sect.sectionSize) => (head.idList.map(_ -> sect), result.reverse ++ tail)
case head :: tail => selectGroup(sect, tail, head :: result)
}
and then in getSeatingMap:
...
case head :: tail =>
val (seated, remaining) = selectGroup(head, groups)
getSeatingMap(remaining, tail, result ++ seated)
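Putting the two fragments together, a self-contained sketch could look like the following. The wrapper object Seating and the use of case classes are assumptions for the sake of a runnable example. One deliberate deviation, also an assumption: because the question's groups are all combinations of requests, dropping only the selected group would leave other combinations that share its ids, so this sketch removes those as well.

```scala
import scala.annotation.tailrec

object Seating {
  final case class SeatingCombination(idList: List[Int], seatCount: Int)
  final case class TheatreSection(sectionSize: Int, rowNumber: Int, sectionNumber: Int)

  // Finds the first group that fits the section; returns its (id -> section) pairs
  // and the groups that are still available afterwards.
  @tailrec
  def selectGroup(
      sect: TheatreSection,
      groups: List[SeatingCombination],
      skipped: List[SeatingCombination] = Nil
  ): (List[(Int, TheatreSection)], List[SeatingCombination]) = groups match {
    case Nil => (Nil, skipped.reverse)
    case head :: tail if head.idList.nonEmpty && head.seatCount <= sect.sectionSize =>
      val seatedIds = head.idList.toSet
      // drop every remaining combination that shares an id with the seated one
      val remaining = (skipped.reverse ++ tail).filterNot(_.idList.exists(seatedIds))
      (head.idList.map(_ -> sect), remaining)
    case head :: tail => selectGroup(sect, tail, head :: skipped)
  }

  @tailrec
  def getSeatingMap(
      groups: List[SeatingCombination],
      sections: List[TheatreSection],
      result: Map[Int, TheatreSection] = Map.empty
  ): Map[Int, TheatreSection] = sections match {
    case Nil => result
    case sect :: rest =>
      val (seated, remaining) = selectGroup(sect, groups)
      getSeatingMap(remaining, rest, result ++ seated)
  }
}
```

As in the question, the groups are expected to arrive sorted by descending seatCount, so the first fit is also the largest one.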
Here is how I was able to achieve it without the mutable.HashMap, following the suggestion in the comments to use foldLeft:
class SeatingCombination(val idList: List[Int], val seatCount: Int){}
class SeatingRequest(val id: Int, val partyName: String, val ticketCount: Int){}
class TheatreSection(val sectionSize: Int, rowNumber: Int, sectionNumber: Int) {
def id: String = rowNumber.toString + "_"+ sectionNumber.toString;
}
def getSeatingMap(groups: List[SeatingCombination], sections: List[TheatreSection]): Map[Int, TheatreSection] = {
sections.foldLeft(Map.empty[Int, TheatreSection]) { (m, sect) =>
val bestFitOpt = groups.find(g => {
g.seatCount <= sect.sectionSize && !isAnyListIdInMap(m, g.idList)
}).filter(_.idList.nonEmpty)
val newEntries = bestFitOpt.map(_.idList.map(_ -> sect)).getOrElse(List.empty)
m ++ newEntries
}
}
def isAnyListIdInMap(map: Map[Int, TheatreSection], list: List[Int]): Boolean = {
list.exists(map.contains) // unlike reduce(_ || _), this does not throw on an empty list
}

Simple functional way for grouping successive elements? [duplicate]

I'm trying to 'group' a string into segments. I guess this example explains it more succinctly:
scala> val str: String = "aaaabbcddeeeeeeffg"
... (do something)
res0: List("aaaa","bb","c","dd","eeeeee","ff","g")
I can think of a few ways to do this in an imperative style (with vars and stepping through the string to find groups), but I was wondering whether a better functional solution could be attained. I've been looking through the Scala API, but there doesn't seem to be anything that fits my needs.
Any help would be appreciated.
You can split the string recursively with span:
def s(x : String) : List[String] = if(x.size == 0) Nil else {
val (l,r) = x.span(_ == x(0))
l :: s(r)
}
Tail recursive:
@annotation.tailrec def s(x : String, y : List[String] = Nil) : List[String] = {
if(x.size == 0) y.reverse
else {
val (l,r) = x.span(_ == x(0))
s(r, l :: y)
}
}
It seems that all the other answers concentrate on collection operations, but a pure string and regex solution is much simpler:
str split """(?<=(\w))(?!\1)""" toList
In this regex I use a positive lookbehind and a negative lookahead for the captured char.
def group(s: String): List[String] = s match {
case "" => Nil
case s => s.takeWhile(_==s.head) :: group(s.dropWhile(_==s.head))
}
Edit: Tail recursive version:
def group(s: String, result: List[String] = Nil): List[String] = s match {
case "" => result reverse
case s => group(s.dropWhile(_==s.head), s.takeWhile(_==s.head) :: result)
}
This can be used just like the other one because the second parameter has a default value and thus doesn't have to be supplied.
Make it a one-liner:
scala> val str = "aaaabbcddddeeeeefff"
str: java.lang.String = aaaabbcddddeeeeefff
scala> str.groupBy(identity).map(_._2)
res: scala.collection.immutable.Iterable[String] = List(eeeee, fff, aaaa, bb, c, dddd)
UPDATE:
As @Paul mentioned the order, here is an updated version:
scala> str.groupBy(identity).toList.sortBy(_._1).map(_._2)
res: List[String] = List(aaaa, bb, c, dddd, eeeee, fff)
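One caveat worth illustrating: groupBy collects every occurrence of a character, whether or not the occurrences are adjacent, so it only matches the question's "successive elements" requirement when no character reappears later in the string. The sketch below contrasts it with a span-based split; the object name GroupByCaveat and the input "aabaa" are made up for illustration.

```scala
object GroupByCaveat {
  val input = "aabaa"

  // groupBy merges both runs of 'a' into a single group
  val byChar: List[String] = input.groupBy(identity).toList.sortBy(_._1).map(_._2)

  // span only takes the leading run, so non-adjacent runs stay separate
  def runs(s: String): List[String] =
    if (s.isEmpty) Nil
    else {
      val (run, rest) = s.span(_ == s.head)
      run :: runs(rest)
    }
}
```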
You could use some helper functions like this:
val str = "aaaabbcddddeeeeefff"
def zame(chars:List[Char]) = chars.span(_==chars.head) // span (not partition) so only the leading run is taken
def q(chars:List[Char]):List[List[Char]] = chars match {
case Nil => Nil
case rest =>
val (thesame,others) = zame(rest)
thesame :: q(others)
}
q(str.toList) map (_.mkString)
This should do the trick, right? No doubt it can be cleaned up into one-liners even further
A functional* solution using fold:
def group(s : String) : Seq[String] = {
s.tail.foldLeft(Seq(s.head.toString)) { case (carry, elem) =>
if ( carry.last(0) == elem ) {
carry.init :+ (carry.last + elem)
}
else {
carry :+ elem.toString
}
}
}
There is a lot of cost hidden in all those sequence operations performed on strings (via implicit conversion). I guess the real complexity heavily depends on the kind of Seq strings are converted to.
(*) AFAIK all or most operations in the collection library depend on iterators, IMHO an inherently unfunctional concept. But the code looks functional, at least.
Starting with Scala 2.13, List provides the unfold builder, which can be combined with String::span:
List.unfold("aaaabbaaacdeeffg") {
case "" => None
case rest => Some(rest.span(_ == rest.head))
}
// List[String] = List("aaaa", "bb", "aaa", "c", "d", "ee", "ff", "g")
or alternatively, coupled with Scala 2.13's Option#unless builder:
List.unfold("aaaabbaaacdeeffg") {
rest => Option.unless(rest.isEmpty)(rest.span(_ == rest.head))
}
// List[String] = List("aaaa", "bb", "aaa", "c", "d", "ee", "ff", "g")
Details:
Unfold (def unfold[A, S](init: S)(f: (S) => Option[(A, S)]): List[A]) is based on an internal state (init) which is initialized in our case with "aaaabbaaacdeeffg".
For each iteration, we span (def span(p: (Char) => Boolean): (String, String)) this internal state in order to find the prefix containing the same symbol and produce a (String, String) tuple which contains the prefix and the rest of the string. span is very fortunate in this context as it produces exactly what unfold expects: a tuple containing the next element of the list and the new internal state.
The unfolding stops when the internal state is "" in which case we produce None as expected by unfold to exit.
Edit: I should have read more carefully. The code below is not functional.
Sometimes, a little mutable state helps:
def group(s : String) = {
var tmp = ""
val b = Seq.newBuilder[String]
s.foreach { c =>
if ( tmp != "" && tmp.head != c ) {
b += tmp
tmp = ""
}
tmp += c
}
b += tmp
b.result
}
The runtime is O(n) if segments have at most constant length; tmp += c probably creates the most overhead. Use a string builder instead for a strict O(n) runtime.
group("aaaabbcddeeeeeeffg")
> Seq[String] = List(aaaa, bb, c, dd, eeeeee, ff, g)
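A minimal sketch of the string-builder variant suggested above, assuming the same grouping behaviour is wanted. The object name GroupWithBuilder is hypothetical; StringBuilder here is Scala's scala.collection.mutable.StringBuilder.

```scala
object GroupWithBuilder {
  def group(s: String): List[String] = {
    val groups = List.newBuilder[String]
    val current = new StringBuilder // holds the run being built
    s.foreach { c =>
      // a different character closes the current run
      if (current.nonEmpty && current.charAt(0) != c) {
        groups += current.toString
        current.clear()
      }
      current += c
    }
    // flush the last run, if any (an empty input yields no groups)
    if (current.nonEmpty) groups += current.toString
    groups.result()
  }
}
```

Appending to the builder avoids creating a new String for every character, which is where the string concatenation version spends its extra time.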
If you want to use the Scala API, you can use the built-in function for that:
str.groupBy(c => c).values
Or, if you want it sorted and in a list:
str.groupBy(c => c).values.toList.sorted

Combining multiple Lists of arbitrary length

I am looking for an approach to join multiple Lists in the following manner:
ListA a b c
ListB 1 2 3 4
ListC + # * § %
..
..
..
Resulting List: a 1 + b 2 # c 3 * 4 § %
In words: the elements are taken in sequential order, starting with the first list, and combined into the resulting list. There can be an arbitrary number of input lists of varying length.
I tried multiple approaches with variants of zip and sliding iterators, but none worked, and in particular none handled varying list lengths. There has to be an elegant way in Scala ;)
val lists = List(ListA, ListB, ListC)
lists.flatMap(_.zipWithIndex).sortBy(_._2).map(_._1)
It's pretty self-explanatory. It just zips each value with its position on its respective list, sorts by index, then pulls the values back out.
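Applied to the lists from the question, this can be sketched as follows. The object name Interleave is an assumption, and the numbers are written as strings only to keep a single element type. The approach relies on sortBy being a stable sort, so elements with the same index keep the order of their source lists.

```scala
object Interleave {
  val listA = List("a", "b", "c")
  val listB = List("1", "2", "3", "4")
  val listC = List("+", "#", "*", "§", "%")

  // Tag each element with its position, sort by position, then drop the tag.
  def interleave[A](lists: List[List[A]]): List[A] =
    lists.flatMap(_.zipWithIndex).sortBy(_._2).map(_._1)
}
```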
Here's how I would do it:
class ListTests extends FunSuite {
test("The three lists from his example") {
val l1 = List("a", "b", "c")
val l2 = List(1, 2, 3, 4)
val l3 = List("+", "#", "*", "§", "%")
// All lists together
val l = List(l1, l2, l3)
// Max length of a list (to pad the shorter ones)
val maxLen = l.map(_.size).max
// Wrap the elements in Option and pad with None
val padded = l.map { list => list.map(Some(_)) ++ Stream.continually(None).take(maxLen - list.size) }
// Transpose
val trans = padded.transpose
// Flatten the lists then flatten the options
val result = trans.flatten.flatten
// Voilà
assert(List("a", 1, "+", "b", 2, "#", "c", 3, "*", 4, "§", "%") === result)
}
}
Here's an imperative solution if efficiency is paramount:
def combine[T](xss: List[List[T]]): List[T] = {
val b = List.newBuilder[T]
var its = xss.map(_.iterator)
while (!its.isEmpty) {
its = its.filter(_.hasNext)
its.foreach(b += _.next)
}
b.result
}
You can use padTo, transpose, and flatten to good effect here:
lists.map(_.map(Some(_)).padTo(lists.map(_.length).max, None)).transpose.flatten.flatten
Here's a small recursive solution.
def flatList(lists: List[List[Any]]) = {
def loop(output: List[Any], xss: List[List[Any]]): List[Any] = (xss collect { case x :: xs => x }) match {
case Nil => output
case heads => loop(output ::: heads, xss.collect({ case x :: xs => xs }))
}
loop(List[Any](), lists)
}
And here is a simple streams approach which can cope with an arbitrary sequence of sequences, each of potentially infinite length.
def flatSeqs[A](ssa: Seq[Seq[A]]): Stream[A] = {
def seqs(xss: Seq[Seq[A]]): Stream[Seq[A]] = xss collect { case xs if !xs.isEmpty => xs } match {
case Nil => Stream.empty
case heads => heads #:: seqs(xss collect { case xs if !xs.isEmpty => xs.tail })
}
seqs(ssa).flatten
}
Here's something short but not exceedingly efficient:
def heads[A](xss: List[List[A]]) = xss.map(_.splitAt(1)).unzip
def interleave[A](xss: List[List[A]]) = Iterator.
iterate(heads(xss)){ case (_, tails) => heads(tails) }.
map(_._1.flatten).
takeWhile(! _.isEmpty).
flatten.toList
Here's a recursive solution that's O(n). The accepted solution (using sort) is O(nlog(n)). Some testing I've done suggests the second solution using transpose is also O(nlog(n)) due to the implementation of transpose. The use of reverse below looks suspicious (since it's an O(n) operation itself) but convince yourself that it either can't be called too often or on too-large lists.
def intercalate[T](lists: List[List[T]]) : List[T] = {
def intercalateHelper(newLists: List[List[T]], oldLists: List[List[T]], merged: List[T]): List[T] = {
(newLists, oldLists) match {
case (Nil, Nil) => merged
case (Nil, zss) => intercalateHelper(zss.reverse, Nil, merged)
case (Nil::xss, zss) => intercalateHelper(xss, zss, merged)
case ( (y::ys)::xss, zss) => intercalateHelper(xss, ys::zss, y::merged)
}
}
intercalateHelper(lists, List.empty, List.empty).reverse
}

How to use takeWhile with an Iterator in Scala

I have an Iterator of elements and I want to consume them until a condition is met in the next element, like:
val it = List(1,1,1,1,2,2,2).iterator
val res1 = it.takeWhile( _ == 1).toList
val res2 = it.takeWhile(_ == 2).toList
res1 gives the expected List(1,1,1,1), but res2 returns List(2,2) because the iterator had to consume the element at position 4 to check it.
I know that the list will be ordered, so there is no point in traversing the whole list like partition does. I would like to finish as soon as the condition is not met. Is there any clever way to do this with Iterators? I cannot call toList on the iterator because it comes from a very big file.
The simplest solution I found:
val it = List(1,1,1,1,2,2,2).iterator
val (r1, it2) = it.span( _ == 1)
println(s"group taken is: ${r1.toList}\n rest is: ${it2.toList}")
output:
group taken is: List(1, 1, 1, 1)
rest is: List(2, 2, 2)
Very short, but from then on you have to use the new iterator.
With any immutable collection it would be similar:
use takeWhile when you want only some prefix of the collection,
use span when you also need the rest.
With my other answer (which I've left separate as they are largely unrelated), I think you can implement groupWhen on Iterator as follows:
def groupWhen[A](itr: Iterator[A])(p: (A, A) => Boolean): Iterator[List[A]] = {
@annotation.tailrec
def groupWhen0(acc: Iterator[List[A]], itr: Iterator[A])(p: (A, A) => Boolean): Iterator[List[A]] = {
val (dup1, dup2) = itr.duplicate
val pref = ((dup1.sliding(2) takeWhile { case Seq(a1, a2) => p(a1, a2) }).zipWithIndex collect {
case (seq, 0) => seq
case (Seq(_, a), _) => Seq(a)
}).flatten.toList
val newAcc = if (pref.isEmpty) acc else acc ++ Iterator(pref)
if (dup2.nonEmpty)
groupWhen0(newAcc, dup2 drop (pref.length max 1))(p)
else newAcc
}
groupWhen0(Iterator.empty, itr)(p)
}
When I run it on an example:
println( groupWhen(List(1,1,1,1,3,4,3,2,2,2).iterator)(_ == _).toList )
I get List(List(1, 1, 1, 1), List(2, 2, 2))
I had a similar need, but the solution from @oxbow_lakes does not take into account the situation when the list has only one element, or even if the list contains elements that are not repeated. Also, that solution doesn't lend itself well to an infinite iterator (it wants to "see" all the elements before it gives you a result).
What I needed was the ability to group sequential elements that match a predicate, but also include the single elements (I can always filter them out if I don't need them). I needed those groups to be delivered continuously, without having to wait for the original iterator to be completely consumed before they are produced.
I came up with the following approach which works for my needs, and thought I should share:
implicit class IteratorEx[+A](itr: Iterator[A]) {
def groupWhen(p: (A, A) => Boolean): Iterator[List[A]] = new AbstractIterator[List[A]] {
val (it1, it2) = itr.duplicate
val ritr = new RewindableIterator(it1, 1)
override def hasNext = it2.hasNext
override def next() = {
val count = (ritr.rewind().sliding(2) takeWhile {
case Seq(a1, a2) => p(a1, a2)
case _ => false
}).length
(it2 take (count + 1)).toList
}
}
}
The above is using a few helper classes:
abstract class AbstractIterator[A] extends Iterator[A]
/**
* Wraps a given iterator to add the ability to remember the last 'remember' values
* From any position the iterator can be rewound (can go back) at most 'remember' values,
* such that when calling 'next()' the memoized values will be provided as if they have not
* been iterated over before.
*/
class RewindableIterator[A](it: Iterator[A], remember: Int) extends Iterator[A] {
private var memory = List.empty[A]
private var memoryIndex = 0
override def next() = {
if (memoryIndex < memory.length) {
val next = memory(memoryIndex)
memoryIndex += 1
next
} else {
val next = it.next()
memory = memory :+ next
if (memory.length > remember)
memory = memory drop 1
memoryIndex = memory.length
next
}
}
def canRewind(n: Int) = memoryIndex - n >= 0
def rewind(n: Int) = {
require(memoryIndex - n >= 0, "Attempted to rewind past 'remember' limit")
memoryIndex -= n
this
}
def rewind() = {
memoryIndex = 0
this
}
override def hasNext = it.hasNext
}
Example use:
List(1,2,2,3,3,3,4,5,5).iterator.groupWhen(_ == _).toList
gives: List(List(1), List(2, 2), List(3, 3, 3), List(4), List(5, 5))
If you want to filter out the single elements, just apply a filter or withFilter after groupWhen
Stream.continually(Random.nextInt(100)).iterator
.groupWhen(_ + _ == 100).withFilter(_.length > 1).take(3).toList
gives: List(List(34, 66), List(87, 13), List(97, 3))
You could use the toStream method on Iterator.
Stream is a lazy equivalent of List.
As you can see from the implementation of toStream, it creates a Stream without traversing the whole Iterator.
Stream keeps all elements in memory. You should confine references to the Stream to some local scope to prevent memory leaks.
With Stream you should use span like this:
val (res1, rest1) = stream.span(_ == 1)
val (res2, rest2) = rest1.span(_ == 2)
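On Scala 2.13 and later, Stream is deprecated in favour of LazyList, so the same idea can be sketched as below. This assumes Scala 2.13+; the object name LazySpan is made up. LazyList remembers the elements it has already evaluated, so nothing is lost between the two span calls.

```scala
object LazySpan {
  val it: Iterator[Int] = List(1, 1, 1, 1, 2, 2, 2).iterator

  // Wrap the iterator once; elements are pulled from it on demand and memoized.
  val lazyList: LazyList[Int] = it.to(LazyList)

  val (res1, rest1) = lazyList.span(_ == 1)
  val (res2, rest2) = rest1.span(_ == 2)
}
```

The same memory caveat applies: holding on to lazyList keeps every evaluated element reachable.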
I'm guessing a bit here, but from the statement "until a condition is met in the next element", it sounds like you might want to look at the groupWhen method on ListOps in scalaz:
scala> import scalaz.syntax.std.list._
import scalaz.syntax.std.list._
scala> List(1,1,1,1,2,2,2) groupWhen (_ == _)
res1: List[List[Int]] = List(List(1, 1, 1, 1), List(2, 2, 2))
Basically this "chunks" up the input sequence upon a condition (an (A, A) => Boolean) being met between an element and its successor. In the example above the condition is equality, so as long as an element is equal to its successor, they will be in the same chunk.
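If pulling in scalaz is not an option, the chunking behaviour described here can be approximated with the standard library. The following is only a sketch of the idea for strict Lists, not scalaz's implementation; the object name GroupWhen is an assumption.

```scala
object GroupWhen {
  // Keeps appending to the current chunk while p(previous, current) holds,
  // and starts a new chunk otherwise. Chunks are built in reverse and fixed up at the end.
  def groupWhen[A](xs: List[A])(p: (A, A) => Boolean): List[List[A]] =
    xs.foldLeft(List.empty[List[A]]) {
      case ((last :: chunk) :: done, x) if p(last, x) => (x :: last :: chunk) :: done
      case (done, x)                                  => List(x) :: done
    }.reverse.map(_.reverse)
}
```

Because it folds over a List, this version needs the whole input in memory, so it does not address the very-big-file constraint from the question.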