Scala case class hashCode with circular references

As Scala documentation says, "Case classes are compared by structure and not by reference".
I was trying to create a linked list with a loop in it, then add a node from this looped list to a mutable set. When using a regular class, it succeeded, but with a case class I hit a stack overflow.
I'm curious why Scala doesn't first check for reference before checking for structure. My thinking is that if the reference is the same then the structure is guaranteed to be the same, so this could serve as a shortcut as well as allow for circular references.
Code:
object CaseClassHashcodeExample {
def main(args: Array[String]): Unit = {
val head = new ListNode(0)
head.appendToTail(1)
head.appendToTail(2)
head.appendToTail(3)
head.next.next.next = head.next
val set = collection.mutable.Set[ListNode]()
set.add(head)
assert(set(head))
}
}
case class ListNode(var data: Int, var next: ListNode = null) {
def appendToTail(d: Int): Unit = {
val end = new ListNode(d)
var cur = this
while (cur.next != null)
cur = cur.next
cur.next = end
}
}

I'm curious why Scala doesn't first check for reference before checking for structure
Well, what exactly do you want it to check? The current generated code looks something like
def hashCode() = 31 * data.hashCode() + next.hashCode()
One thing you could try is
def hashCode() = 31 * data.hashCode() + (if (next eq this) 0 else next.hashCode())
but it wouldn't actually help: in your case it isn't next which is the same as this, it's next.next.
next.hashCode doesn't know it's called from this.hashCode and so can't compare its own next to this.
You could create a helper method taking a set of "seen" objects into account:
def hashCode() = hashCode(Nil)
def hashCode(seen: List[ListNode]) = if (seen.exists(_ eq this)) 0 else 31 * data.hashCode() + next.hashCode(this :: seen)
But this both greatly slows down the common case, and is hard to actually get right.
EDIT: for equals, reference equality is checked first in case classes.
class NeverEq {
override def equals(other: Any) = false
}
case class Test(x: NeverEq)
val x = Test(new NeverEq())
println(x == x)
prints true and would be false if only structure was compared.
But it doesn't actually help with circular references. Let's say you have a ListNode type without data to simplify which implements equality like this:
override def equals(other: Any) = other match {
case other: ListNode => (this eq other) || this.next == other.next
case _ => false
}
and want to check if node1 == node2 where node1.next is node2 and vice versa:
node1 <--> node2
node1 == node2 reduces to (node1 eq node2) || (node1.next == node2.next).
node1 eq node2 is false, so that reduces to node1.next == node2.next, that is node2 == node1.
node2 eq node1 is false, so that reduces to node2.next == node1.next, that is node1 == node2...
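One way to sidestep the recursion entirely is to keep the link out of the generated methods. The following is only a sketch under an assumption that is not in the original code: `next` is moved from the constructor into the class body. Only constructor parameters take part in a case class's generated equals, hashCode and toString, so none of them follow the cycle. The trade-off is that two nodes with the same `data` now compare equal regardless of what they point to, which may or may not be acceptable.

```scala
object CyclicListDemo {
  // Assumption: `next` lives in the class body rather than the constructor.
  // The compiler-generated equals/hashCode/toString only use constructor
  // parameters, so they never traverse `next`.
  case class ListNode(data: Int) {
    var next: ListNode = null
  }

  def main(args: Array[String]): Unit = {
    val head = ListNode(0)
    val second = ListNode(1)
    val third = ListNode(2)
    head.next = second
    second.next = third
    third.next = second // close the loop: second -> third -> second

    val set = collection.mutable.Set[ListNode]()
    set.add(head) // no stack overflow: hashCode only looks at `data`
    assert(set(head))
  }
}
```

If nodes must be distinguished by identity instead, a regular (non-case) class, as in the question, already behaves that way.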

Related

Generate all IP addresses given a string in scala

I was trying my hand at writing an IP generator given a string of digits. The generator would take as input a string of digits such as "17234" and return all possible IPs, as follows:
1.7.2.34
1.7.23.4
1.72.3.4
17.2.3.4
I attempted to write a snippet to do the generation as follows:
def genip(ip:String):Unit = {
def legal(ip:String):Boolean = (ip.size == 1) || (ip.size == 2) || (ip.size == 3)
def genips(ip:String,portion:Int,accum:String):Unit = portion match {
case 1 if legal(ip) => println(accum+ip)
case _ if portion > 1 => {
genips(ip.drop(1),portion-1,if(accum.size == 0) ip.take(1)+"." else accum+ip.take(1)+".")
genips(ip.drop(2),portion-1,if(accum.size == 0) ip.take(2)+"." else accum+ip.take(2)+".")
genips(ip.drop(3),portion-1,if(accum.size == 0) ip.take(3)+"." else accum+ip.take(3)+".")
}
case _ => return
}
genips(ip,4,"")
}
The idea is to partition the string into four octets and then further partition the octet into strings of size "1","2" and "3" and then recursively descend into the remaining string.
I am not sure if I am on the right track but it would be great if somebody could suggest a more functional way of accomplishing the same.
Thanks
Here is an alternative version of the attached code:
def generateIPs(digits : String) : Seq[String] = generateIPs(digits, 4)
private def generateIPs(digits : String, partsLeft : Int) : Seq[String] = {
if ( digits.size < partsLeft || digits.size > partsLeft * 3) {
Nil
} else if(partsLeft == 1) {
Seq(digits)
} else {
(1 to 3).map(n => generateIPs(digits.drop(n), partsLeft - 1)
.map(digits.take(n) + "." + _)
).flatten
}
}
println("Results:\n" + generateIPs("17234").mkString("\n"))
Major changes:
Methods now return a collection of strings (rather than Unit), so they are proper functions (rather than relying on side effects) and can be easily tested;
Avoiding repeating the same code 3 times depending on the size of the bunch of numbers we take;
Not passing the accumulated interim result as a method parameter - in this case it doesn't make sense, since you'll have at most 4 levels of recursion and the code is easier to read without it. You do lose tail recursion this way, though, so in many cases it might be reasonable to keep the accumulator.
Note: The last map statement is a good candidate to be replaced by for comprehension, which many developers find easier to read and reason about, though I will leave it as an exercise :)
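For readers who want to check their answer to that exercise, here is one possible for-comprehension version. It is a sketch: the wrapper object `IpGen` and the default argument are my own choices, and like the version above it only checks lengths, not that each part is below 256.

```scala
object IpGen {
  // Same logic as the map/flatten version, written as a for comprehension.
  def generateIPs(digits: String, partsLeft: Int = 4): Seq[String] =
    if (digits.length < partsLeft || digits.length > partsLeft * 3) Nil
    else if (partsLeft == 1) Seq(digits)
    else
      for {
        n    <- 1 to 3                                     // size of the current part
        rest <- generateIPs(digits.drop(n), partsLeft - 1) // all ways to split the remainder
      } yield digits.take(n) + "." + rest
}
```

Calling `IpGen.generateIPs("17234")` yields the four addresses listed in the question.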
Your code is the right idea; I'm not sure making it functional really helps anything, but I'll show both functional and side-effecting ways to do what you want. First, we'd like a good routine to chunk off some of the numbers, making sure an okay number are left for the rest of the chunking, and making sure they're in range for IPs:
def validSize(i: Int, len: Int, more: Int) = i + more <= len && i + 3*more >= len
def chunk(s: String, more: Int) = {
val parts = for (i <- 1 to 3 if validSize(i, s.length, more)) yield s.splitAt(i)
parts.filter(_._1.toInt < 256)
}
Now we need to use chunk recursively four times to generate the possibilities. Here's a solution that is mutable internally and iterative:
def genIPs(digits: String) = {
var parts = List(("", digits))
for (i <- 1 to 4) {
parts = parts.flatMap{ case (pre, post) =>
chunk(post, 4-i).map{ case (x,y) => (pre+x+".", y) }
}
}
parts.map(_._1.dropRight(1))
}
Here's one that recurses using Iterator:
def genIPs(digits: String) = Iterator.iterate(List((3,"",digits))){ _.flatMap{
case(j, pre, post) => chunk(post, j).map{ case(x,y) => (j-1, pre+x+".", y) }
}}.dropWhile(_.head._1 >= 0).next.map(_._2.dropRight(1))
The logic is the same either way. Here it is working:
scala> genIPs("1238516")
res2: List[String] = List(1.23.85.16, 1.238.5.16, 1.238.51.6,
12.3.85.16, 12.38.5.16, 12.38.51.6,
123.8.5.16, 123.8.51.6, 123.85.1.6)

Efficient way to fold list in scala, while avoiding allocations and vars

I have a bunch of items in a list, and I need to analyze the content to find out how many of them are "complete". I started out with partition, but then realized that I didn't need two lists back, so I switched to a fold:
val counts = groupRows.foldLeft( (0,0) )( (pair, row) =>
if(row.time == 0) (pair._1+1,pair._2)
else (pair._1, pair._2+1)
)
but I have a lot of rows to go through for a lot of parallel users, and it is causing a lot of GC activity (assumption on my part...the GC could be from other things, but I suspect this since I understand it will allocate a new tuple on every item folded).
for the time being, I've rewritten this as
var complete = 0
var incomplete = 0
list.foreach(row => if(row.time != 0) complete += 1 else incomplete += 1)
which fixes the GC, but introduces vars.
I was wondering if there was a way of doing this without using vars while also not abusing the GC?
EDIT:
Hard call on the answers I've received. A var implementation seems to be considerably faster on large lists (like by 40%) than even a tail-recursive optimized version that is more functional but should be equivalent.
The first answer from dhg seems to be on-par with the performance of the tail-recursive one, implying that the size pass is super-efficient...in fact, when optimized it runs very slightly faster than the tail-recursive one on my hardware.
The cleanest two-pass solution is probably to just use the built-in count method:
val complete = groupRows.count(_.time == 0)
val counts = (complete, groupRows.size - complete)
But you can do it in one pass if you use partition on an iterator:
val (complete, incomplete) = groupRows.iterator.partition(_.time == 0)
val counts = (complete.size, incomplete.size)
This works because the new returned iterators are linked behind the scenes and calling next on one will cause it to move the original iterator forward until it finds a matching element, but it remembers the non-matching elements for the other iterator so that they don't need to be recomputed.
Example of the one-pass solution:
scala> val groupRows = List(Row(0), Row(1), Row(1), Row(0), Row(0)).view.map{x => println(x); x}
scala> val (complete, incomplete) = groupRows.iterator.partition(_.time == 0)
Row(0)
Row(1)
complete: Iterator[Row] = non-empty iterator
incomplete: Iterator[Row] = non-empty iterator
scala> val counts = (complete.size, incomplete.size)
Row(1)
Row(0)
Row(0)
counts: (Int, Int) = (3,2)
I see you've already accepted an answer, but you rightly mention that that solution will traverse the list twice. The way to do it efficiently is with recursion.
def counts(xs: List[...], complete: Int = 0, incomplete: Int = 0): (Int,Int) =
xs match {
case Nil => (complete, incomplete)
case row :: tail =>
if (row.time == 0) counts(tail, complete + 1, incomplete)
else counts(tail, complete, incomplete + 1)
}
This is effectively just a customized fold, except we use 2 accumulators which are just Ints (primitives) instead of tuples (reference types). It should also be just as efficient as a while-loop with vars - in fact, the bytecode should be identical.
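To have the compiler confirm that the recursion really is turned into a loop, the method can be annotated with `@tailrec`. The sketch below assumes a hypothetical `Row` class with a `time` field (the question does not show the real type) and wraps everything in an object so it compiles as a standalone file.

```scala
import scala.annotation.tailrec

object CountDemo {
  // Hypothetical stand-in for the row type used in the question.
  final case class Row(time: Long)

  // @tailrec makes compilation fail if the recursive calls are not in tail position.
  @tailrec
  def counts(xs: List[Row], complete: Int = 0, incomplete: Int = 0): (Int, Int) =
    xs match {
      case Nil => (complete, incomplete)
      case row :: tail =>
        if (row.time == 0) counts(tail, complete + 1, incomplete)
        else counts(tail, complete, incomplete + 1)
    }
}
```

Only the final result allocates a tuple; the accumulators stay plain Ints throughout the traversal.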
Maybe it's just me, but I prefer using the various specialized folds (.size, .exists, .sum, .product) if they are available. I find it clearer and less error-prone than the heavy-duty power of general folds.
val complete = groupRows.view.filter(_.time==0).size
(complete, groupRows.length - complete)
How about this one? No import tax.
import scala.collection.generic.CanBuildFrom
import scala.collection.Traversable
import scala.collection.mutable.Builder
case class Count(n: Int, total: Int) {
def not = total - n
}
object Count {
implicit def cbf[A]: CanBuildFrom[Traversable[A], Boolean, Count] = new CanBuildFrom[Traversable[A], Boolean, Count] {
def apply(): Builder[Boolean, Count] = new Counter
def apply(from: Traversable[A]): Builder[Boolean, Count] = apply()
}
}
class Counter extends Builder[Boolean, Count] {
var n = 0
var ttl = 0
override def +=(b: Boolean) = { if (b) n += 1; ttl += 1; this }
override def clear() { n = 0 ; ttl = 0 }
override def result = Count(n, ttl)
}
object Counting extends App {
val vs = List(4, 17, 12, 21, 9, 24, 11)
val res: Count = vs map (_ % 2 == 0)
Console println s"${vs} have ${res.n} evens out of ${res.total}; ${res.not} were odd."
val res2: Count = vs collect { case i if i % 2 == 0 => i > 10 }
Console println s"${vs} have ${res2.n} evens over 10 out of ${res2.total}; ${res2.not} were smaller."
}
OK, inspired by the answers above, but really wanting to only pass over the list once and avoid GC, I decided that, in the face of a lack of direct API support, I would add this to my central library code:
class RichList[T](private val theList: List[T]) {
def partitionCount(f: T => Boolean): (Int, Int) = {
var matched = 0
var unmatched = 0
theList.foreach(r => { if (f(r)) matched += 1 else unmatched += 1 })
(matched, unmatched)
}
}
object RichList {
implicit def apply[T](list: List[T]): RichList[T] = new RichList(list)
}
Then in my application code (if I've imported the implicit), I can write var-free expressions:
val (complete, incomplete) = groupRows.partitionCount(_.time != 0)
and get what I want: an optimized GC-friendly routine that prevents me from polluting the rest of the program with vars.
However, I then saw Luigi's benchmark, and updated it to:
Use a longer list so that multiple passes on the list were more obvious in the numbers
Use a boolean function in all cases, so that we are comparing things fairly
http://pastebin.com/2XmrnrrB
The var implementation is definitely considerably faster, even though Luigi's routine should be identical (as one would expect with optimized tail recursion). Surprisingly, dhg's dual-pass original is just as fast (slightly faster if compiler optimization is on) as the tail-recursive one. I do not understand why.
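On Scala 2.10 and later, the wrapper class and its implicit conversion can be combined into a single implicit class. This is a sketch of the same `partitionCount` idea; the enclosing object name `ListSyntax` is made up for the example.

```scala
object ListSyntax {
  // An implicit class bundles the wrapper and the List => RichList conversion.
  implicit class RichList[T](private val theList: List[T]) {
    // Counts elements matching `f` and elements not matching it in one pass.
    def partitionCount(f: T => Boolean): (Int, Int) = {
      var matched = 0
      var unmatched = 0
      theList.foreach(r => if (f(r)) matched += 1 else unmatched += 1)
      (matched, unmatched)
    }
  }
}
```

After `import ListSyntax._`, an expression such as `List(1, 2, 3, 4, 5).partitionCount(_ % 2 == 0)` evaluates to `(2, 3)`, and the vars stay confined to the helper.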
It is slightly tidier to use a mutable accumulator pattern, like so, especially if you can re-use your accumulator:
case class Accum(var complete: Int = 0, var incomplete: Int = 0) {
def inc(compl: Boolean): this.type = {
if (compl) complete += 1 else incomplete += 1
this
}
}
val counts = groupRows.foldLeft( Accum() ){ (a, row) => a.inc( row.time == 0 ) }
If you really want to, you can hide your vars as private; if not, you still are a lot more self-contained than the pattern with vars.
You could just calculate it using the difference like so:
def counts(groupRows: List[Row]) = {
val complete = groupRows.foldLeft(0){ (count, row) =>
if(row.time == 0) count + 1 else count
}
(complete, groupRows.length - complete)
}

How can I make this method more Scalalicious

I have a function that calculates the left and right node values for some collection of treeNodes given a simple node.id, node.parentId association. It's very simple and works well enough...but, well, I am wondering if there is a more idiomatic approach. Specifically is there a way to track the left/right values without using some externally tracked value but still keep the tasty recursion.
/*
* A tree node
*/
case class TreeNode(val id:String, val parentId: String){
var left: Int = 0
var right: Int = 0
}
/*
* a method to compute the left/right node values
*/
def walktree(node: TreeNode) = {
/*
* increment state for the inner function
*/
var c = 0
/*
* A method to set the increment state
*/
def increment = { c+=1; c } // poo
/*
* the tasty inner method
* treeNodes is a List[TreeNode]
*/
def walk(node: TreeNode): Unit = {
node.left = increment
/*
* recurse on all direct descendants
*/
treeNodes filter( _.parentId == node.id) foreach (walk(_))
node.right = increment
}
walk(node)
}
walktree(someRootNode)
Edit -
The list of nodes is taken from a database. Pulling the nodes into a proper tree would take too much time. I am pulling a flat list into memory and all I have is an association via node id's as pertains to parents and children.
Adding left/right node values allows me to get a snapshot of all children (and children's children) with a single SQL query.
The calculation needs to run very quickly in order to maintain data integrity should parent-child associations change (which they do very frequently).
In addition to using the awesome Scala collections I've also boosted speed by using parallel processing for some pre/post filtering on the tree nodes. I wanted to find a more idiomatic way of tracking the left/right node values. After looking at the answer from @dhg it got even better. Using groupBy instead of a filter turns the algorithm (mostly?) linear instead of quadratic!
val treeNodeMap = treeNodes.groupBy(_.parentId).withDefaultValue(Nil)
def walktree(node: TreeNode) = {
def walk(node: TreeNode, counter: Int): Int = {
node.left = counter
node.right =
treeNodeMap(node.id)
.foldLeft(counter+1) {
(result, curnode) => walk(curnode, result) + 1
}
node.right
}
walk(node,1)
}
Your code appears to be calculating an in-order traversal numbering.
I think what you want to make your code better is a fold that carries the current value downward and passes the updated value upward. Note that it might also be worth it to do a treeNodes.groupBy(_.parentId) before walktree to prevent you from calling treeNodes.filter(...) every time you call walk.
val treeNodes = List(TreeNode("1","0"),TreeNode("2","1"),TreeNode("3","1"))
val treeNodeMap = treeNodes.groupBy(_.parentId).withDefaultValue(Nil)
def walktree2(node: TreeNode) = {
def walk(node: TreeNode, c: Int): Int = {
node.left = c
val newC =
treeNodeMap(node.id) // get the children without filtering
.foldLeft(c+1)((c, child) => walk(child, c) + 1)
node.right = newC
newC
}
walk(node, 1)
}
And it produces the same result:
scala> walktree2(TreeNode("0","-1"))
scala> treeNodes.map(n => "(%s,%s)".format(n.left,n.right))
res32: List[String] = List((2,7), (3,4), (5,6))
That said, I would completely rewrite your code as follows:
case class TreeNode( // class is now immutable; `walktree` returns a new tree
id: String,
value: Int, // value to be set during `walktree`
left: Option[TreeNode], // recursively-defined structure
right: Option[TreeNode]) // makes traversal much simpler
def walktree(node: TreeNode) = {
def walk(nodeOption: Option[TreeNode], c: Int): (Option[TreeNode], Int) = {
nodeOption match {
case None => (None, c) // if this child doesn't exist, do nothing
case Some(node) => // if this child exists, recursively walk
val (newLeft, cLeft) = walk(node.left, c) // walk the left side
val newC = cLeft + 1 // update the value
val (newRight, cRight) = walk(node.right, newC) // walk the right side
(Some(TreeNode(node.id, newC, newLeft, newRight)), cRight)
}
}
walk(Some(node), 0)._1
}
Then you can use it like this:
walktree(
TreeNode("1", -1,
Some(TreeNode("2", -1,
Some(TreeNode("3", -1, None, None)),
Some(TreeNode("4", -1, None, None)))),
Some(TreeNode("5", -1, None, None))))
To produce:
Some(TreeNode(1,4,
Some(TreeNode(2,2,
Some(TreeNode(3,1,None,None)),
Some(TreeNode(4,3,None,None)))),
Some(TreeNode(5,5,None,None))))
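Returning to the flat list of (id, parentId) nodes: if mutating the nodes is undesirable, the same groupBy-plus-fold idea can return the numbering as a Map instead. This is a sketch with a hypothetical `Node` class and method name `number`; it threads both the counter and the result map through the fold, so it allocates more than the mutable version.

```scala
object NestedSet {
  // Hypothetical immutable node: no left/right fields to mutate.
  final case class Node(id: String, parentId: String)

  // Returns id -> (left, right) for `root` and everything below it.
  def number(nodes: List[Node], root: Node): Map[String, (Int, Int)] = {
    val children = nodes.groupBy(_.parentId).withDefaultValue(Nil)

    // Result: (next free counter value, numbering accumulated so far).
    def walk(node: Node, left: Int, acc: Map[String, (Int, Int)]): (Int, Map[String, (Int, Int)]) = {
      val (right, withChildren) =
        children(node.id).foldLeft((left + 1, acc)) {
          case ((counter, m), child) => walk(child, counter, m)
        }
      (right + 1, withChildren + (node.id -> ((left, right))))
    }

    walk(root, 1, Map.empty)._2
  }
}
```

For the three example nodes used above and a root with id "0", this gives (2,7), (3,4) and (5,6) for nodes "1", "2" and "3", matching the mutable version.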
If I get your algorithm correctly:
def walktree(node: TreeNode, c: Int): Int = {
node.left = c
val c2 = treeNodes.filter(_.parentId == node.id).foldLeft(c + 1) {
(cur, n) => walktree(n, cur)
}
node.right = c2 + 1
c2 + 2
}
walktree(new TreeNode("", ""), 0)
Off-by-one errors are likely to occur.
Few random thoughts (better suited for http://codereview.stackexchange.com):
Try posting code that compiles... We have to guess that treeNodes is a sequence of TreeNode.
val is implicit for case classes:
case class TreeNode(val id: String, val parentId: String) {
Avoid explicit = and Unit for Unit functions:
def walktree(node: TreeNode) = {
def walk(node: TreeNode): Unit = {
Methods with side-effects should have ():
def increment = {c += 1; c}
This is terribly slow, consider storing list of children in the actual node:
treeNodes filter (_.parentId == node.id) foreach (walk(_))
More concise syntax would be treeNodes foreach walk:
treeNodes foreach (walk(_))

Scala: Group an Iterable into an Iterable of Iterables by a predicate

I have very large Iterators that I want to split into pieces. I have a predicate that looks at an item and returns true if it is the start of a new piece. I need the pieces to be Iterators, because even the pieces will not fit into memory. There are so many pieces that I would be wary of a recursive solution blowing out your stack. The situation is similar to this question, but I need Iterators instead of Lists, and the "sentinels" (items for which the predicate is true) occur (and should be included) at the beginning of a piece. The resulting iterators will only be used in order, though some may not be used at all, and they should only use O(1) memory. I imagine this means they should all share the same underlying iterator. Performance is important.
If I were to take a stab at a function signature, it would be this:
def groupby[T](iter: Iterator[T])(startsGroup: T => Boolean): Iterator[Iterator[T]] = ...
I would have loved to use takeWhile, but it loses the last element. I investigated span, but it buffers results. My current best idea involves BufferedIterator, but maybe there is a better way.
You'll know you've got it right because something like this doesn't crash your JVM:
groupby((1 to Int.MaxValue).iterator)(_ % (Int.MaxValue / 2) == 0).foreach(group => println(group.sum))
groupby((1 to Int.MaxValue).iterator)(_ % 10 == 0).foreach(group => println(group.sum))
You have an inherent problem. Iterable implies that you can get multiple iterators. Iterator implies that you can only pass through once. That means that your Iterable[Iterable[T]] should be able to produce Iterator[Iterable[T]]s. But when this returns an element--an Iterable[T]--and that asks for multiple iterators, the underlying single iterator can't comply without either caching the results of the list (too big) or calling the original iterable and going through absolutely everything again (very inefficient).
So, while you could do this, I think you should conceive of your problem in a different way.
If you could start with a Seq instead, you could grab subsets as ranges.
If you already know how you want to use your iterable, you could write a method
def process[T](source: Iterable[T])(starts: T => Boolean)(handlers: T => Unit *)
which increments through the set of handlers each time starts fires off a "true". If there's any way you can do your processing in one sweep, something like this is the way to go. (Your handlers will have to save state via mutable data structures or variables, however.)
If you can permit iteration on the outer list to break the inner list, you could have an Iterable[Iterator[T]] with the additional constraint that once you iterate to a later sub-iterator, all previous sub-iterators are invalid.
Here's a solution of the last type (from Iterator[T] to Iterator[Iterator[T]]; one can wrap this to make the outer layers Iterable instead).
class GroupedBy[T](source: Iterator[T])(starts: T => Boolean)
extends Iterator[Iterator[T]] {
private val underlying = source
private var saved: T = _
private var cached = false
private var starting = false
private def cacheNext() {
saved = underlying.next
starting = starts(saved)
cached = true
}
private def oops(): Nothing = throw new java.util.NoSuchElementException("empty iterator")
// Comment the next line if you do NOT want the first element to always start a group
if (underlying.hasNext) { cacheNext(); starting = true }
def hasNext = {
while (!(cached && starting) && underlying.hasNext) cacheNext()
cached && starting
}
def next = {
if (!(cached && starting) && !hasNext) oops()
starting = false
new Iterator[T] {
var presumablyMore = true
def hasNext = {
if (!cached && !starting && underlying.hasNext && presumablyMore) cacheNext()
presumablyMore = cached && !starting
presumablyMore
}
def next = {
if (presumablyMore && (cached || hasNext)) {
cached = false
saved
}
else oops()
}
}
}
}
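For comparison, the handler-based option mentioned earlier (a `process` method that moves to the next handler each time `starts` fires) might be sketched as follows. The details are assumptions of mine, not part of the signature above: the first element always opens the first group, and items in groups beyond the last handler are ignored.

```scala
import scala.collection.mutable.ListBuffer

object HandlerDemo {
  // Assumption: the first element opens group 0 even if starts(first) is false,
  // and groups without a corresponding handler are silently skipped.
  def process[T](source: Iterable[T])(starts: T => Boolean)(handlers: (T => Unit)*): Unit = {
    var current = -1 // index of the handler for the group being read
    source.foreach { item =>
      if (current < 0 || starts(item)) current += 1
      if (current < handlers.length) handlers(current)(item)
    }
  }

  def main(args: Array[String]): Unit = {
    val first = ListBuffer[Int]()
    val second = ListBuffer[Int]()
    // Handlers keep their state in mutable buffers, as noted above.
    process(List(1, 2, 10, 11, 20))(_ % 10 == 0)(
      x => { first += x; () },
      x => { second += x; () }
    )
    println(first)
    println(second)
  }
}
```

This makes a single pass and holds nothing but the current handler index, at the cost of requiring all handlers to be known up front.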
Here's my solution using BufferedIterator. It doesn't let you skip iterators correctly, but it's fairly simple and functional. The first element(s) go into a group even if !startsGroup(first).
def groupby[T](iter: Iterator[T])(startsGroup: T => Boolean): Iterator[Iterator[T]] =
new Iterator[Iterator[T]] {
val base = iter.buffered
override def hasNext = base.hasNext
override def next() = Iterator(base.next()) ++ new Iterator[T] {
override def hasNext = base.hasNext && !startsGroup(base.head)
override def next() = if (hasNext) base.next() else Iterator.empty.next()
}
}
Update: Keeping a little state lets you skip iterators and prevent people from messing with previous ones:
def groupby[T](iter: Iterator[T])(startsGroup: T => Boolean): Iterator[Iterator[T]] =
new Iterator[Iterator[T]] {
val base = iter.buffered
var prev: Iterator[T] = Iterator.empty
override def hasNext = base.hasNext
override def next() = {
while (prev.hasNext) prev.next() // Exhaust previous iterator; take* and drop* do NOT always work!! (Jira SI-5002?)
prev = Iterator(base.next()) ++ new Iterator[T] {
var hasMore = true
override def hasNext = { hasMore = hasMore && base.hasNext && !startsGroup(base.head) ; hasMore }
override def next() = if (hasNext) base.next() else Iterator.empty.next()
}
prev
}
}
If you are looking at memory constraints then the following will work. You can only use it if your underlying iterable object supports views. This implementation will iterate over the Iterable and then generate IterableViews which can then be iterated over. This implementation does not care if the very first element tests as a start group since it will be regardless.
def groupby[T](iter: Iterable[T])(startsGroup: T => Boolean): Iterable[Iterable[T]] = new Iterable[Iterable[T]] {
def iterator = new Iterator[Iterable[T]] {
val i = iter.iterator
var index = 0
var nextView: IterableView[T, Iterable[T]] = getNextView()
private def getNextView() = {
val start = index
var hitStartGroup = false
while ( i.hasNext && ! hitStartGroup ) {
val next = i.next()
index += 1
hitStartGroup = ( index > 1 && startsGroup( next ) )
}
if ( hitStartGroup ) {
if ( start == 0 ) iter.view( start, index - 1 )
else iter.view( start - 1, index - 1 )
} else { // hit end
if ( start == index ) null
else if ( start == 0 ) iter.view( start, index )
else iter.view( start - 1, index )
}
}
def hasNext = nextView != null
def next() = {
if ( nextView != null ) {
val next = nextView
nextView = getNextView()
next
} else null
}
}
}
You can maintain a low memory footprint by using Streams. Use result.toIterator if you need an iterator again.
With streams, there's no mutable state, only a single conditional and it's nearly as concise as Jay Hacker's solution.
def batchBy[A,B](iter: Iterator[A])(f: A => B): Stream[(B, Iterator[A])] = {
val base = iter.buffered
val empty = Stream.empty[(B, Iterator[A])]
def getBatch(key: B) = {
Iterator(base.next()) ++ new Iterator[A] {
def hasNext: Boolean = base.hasNext && (f(base.head) == key)
def next(): A = base.next()
}
}
def next(skipList: Option[Iterator[A]] = None): Stream[(B, Iterator[A])] = {
skipList.foreach{_.foreach{_=>}}
if (base.isEmpty) empty
else {
val key = f(base.head)
val batch = getBatch(key)
Stream.cons((key, batch), next(Some(batch)))
}
}
next()
}
I ran the tests:
scala> batchBy((1 to Int.MaxValue).iterator)(_ % (Int.MaxValue / 2) == 0)
.foreach{case(_,group) => println(group.sum)}
-1610612735
1073741823
-536870909
2147483646
2147483647
The second test prints too much to paste to Stack Overflow.
import scala.collection.mutable.ArrayBuffer
object GroupingIterator {
/**
* Create a new GroupingIterator with a grouping predicate.
*
* @param it The original iterator
* @param p Predicate controlling the grouping
* @tparam A Type of elements iterated
* @return A new GroupingIterator
*/
def apply[A](it: Iterator[A])(p: (A, IndexedSeq[A]) => Boolean): GroupingIterator[A] =
new GroupingIterator(it)(p)
}
/**
* Group elements in sequences of contiguous elements that satisfy a predicate. The predicate
* tests each single potential next element of the group with the help of the elements grouped so far.
* If it returns true, the potential next element is added to the group, otherwise
* a new group is started with the potential next element as first element
*
* @param self The original iterator
* @param p Predicate controlling the grouping
* @tparam A Type of elements iterated
*/
class GroupingIterator[+A](self: Iterator[A])(p: (A, IndexedSeq[A]) => Boolean) extends Iterator[IndexedSeq[A]] {
private[this] val source = self.buffered
private[this] val buffer: ArrayBuffer[A] = ArrayBuffer()
def hasNext: Boolean = source.hasNext
def next(): IndexedSeq[A] = {
if (hasNext)
nextGroup()
else
Iterator.empty.next()
}
private[this] def nextGroup(): IndexedSeq[A] = {
assert(source.hasNext)
buffer.clear()
buffer += source.next
while (source.hasNext && p(source.head, buffer)) {
buffer += source.next
}
buffer.toIndexedSeq
}
}

Scala Graph Cycle Detection Algo 'return' needed?

I have implemented a small cycle detection algorithm for a DAG in Scala.
The 'return' bothers me - I'd like to have a version without the return...possible?
def isCyclic() : Boolean = {
lock.readLock().lock()
try {
nodes.foreach(node => node.marker = 1)
nodes.foreach(node => {if (1 == node.marker && visit(node)) return true})
} finally {
lock.readLock().unlock()
}
false
}
private def visit(node: MyNode): Boolean = {
node.marker = 3
val nodeId = node.id
val children = vertexMap.getChildren(nodeId).toList.map(nodeId => id2nodeMap(nodeId))
children.foreach(child => {
if (3 == child.marker || (1 == child.marker && visit(child))) return true
})
node.marker = 2
false
}
Yes, by using '.exists' instead of 'foreach' + 'return':
http://www.scala-lang.org/api/current/index.html#scala.collection.immutable.Seq
def isCyclic() : Boolean = {
def visit(node: MyNode): Boolean = {
node.marker = 3
val nodeId = node.id
val children = vertexMap.getChildren(nodeId).toList.map(nodeId => id2nodeMap(nodeId))
val found = children.exists(child => (3 == child.marker || (1 == child.marker && visit(child))))
node.marker = 2
found
}
lock.readLock().lock()
try {
nodes.foreach(node => node.marker = 1)
nodes.exists(node => 1 == node.marker && visit(node))
} finally {
lock.readLock().unlock()
}
}
Summary:
I have written two solutions as generic FP functions which detect cycles within a directed graph. Per your implied preference, the use of an early return to escape the recursive function has been eliminated. The first, isCyclic, simply returns a Boolean as soon as the DFS (Depth First Search) repeats a node visit. The second, filterToJustCycles, returns a copy of the input Map filtered down to just the nodes involved in any/all cycles, and returns an empty Map when no cycles are found.
Details:
For the following, please consider a directed graph encoded as such:
val directedGraphWithCyclesA: Map[String, Set[String]] =
Map(
"A" -> Set("B", "E", "J")
, "B" -> Set("E", "F")
, "C" -> Set("I", "G")
, "D" -> Set("G", "L")
, "E" -> Set("H")
, "F" -> Set("G")
, "G" -> Set("L")
, "H" -> Set("J", "K")
, "I" -> Set("K", "L")
, "J" -> Set("B")
, "K" -> Set("B")
)
In both functions below, the type parameter N refers to whatever "Node" type you care to provide. It is important the provided "Node" type be both immutable and have stable equals and hashCode implementations (all of which occur automatically with use of immutable case classes).
The first function, isCyclic, is similar in nature to the version of the solution provided by @the-archetypal-paul. It assumes the directed graph has been transformed into a Map[N, Set[N]] where N is the identity of a node in the graph.
If you need to see how to generically transform your custom directed graph implementation into a Map[N, Set[N]], I have outlined a generic solution towards the end of this answer.
Calling the isCyclic function as such:
val isCyclicResult = isCyclic(directedGraphWithCyclesA)
will return:
`true`
No further information is provided. And the DFS (Depth First Search) is aborted at detection of the first repeated visit to a node.
def isCyclic[N](nsByN: Map[N, Set[N]]) : Boolean = {
def hasCycle(nAndNs: (N, Set[N]), visited: Set[N] = Set[N]()): Boolean =
if (visited.contains(nAndNs._1))
true
else
nAndNs._2.exists(
n =>
nsByN.get(n) match {
case Some(ns) =>
hasCycle((n, ns), visited + nAndNs._1)
case None =>
false
}
)
nsByN.exists(hasCycle(_))
}
The second function, filterToJustCycles, uses the set reduction technique to recursively filter away unreferenced root nodes in the Map. If there are no cycles in the supplied graph of nodes, then .isEmpty will be true on the returned Map. If however, there are any cycles, all of the nodes required to participate in any of the cycles are returned with all of the other non-cycle participating nodes filtered away.
Again, if you need to see how to generically transform your custom directed graph implementation into a Map[N, Set[N]], I have outlined a generic solution towards the end of this answer.
Calling the filterToJustCycles function as such:
val cycles = filterToJustCycles(directedGraphWithCyclesA)
will return:
Map(E -> Set(H), J -> Set(B), B -> Set(E), H -> Set(J, K), K -> Set(B))
It's trivial to then create a traversal across this Map to produce any or all of the various cycle pathways through the remaining nodes.
def filterToJustCycles[N](nsByN: Map[N, Set[N]]): Map[N, Set[N]] = {
  def recursive(nsByNRemaining: Map[N, Set[N]], referencedRootNs: Set[N] = Set[N]()): Map[N, Set[N]] = {
    val (referencedRootNsNew, nsByNRemainingNew) = {
      val referencedRootNsNewTemp =
        nsByNRemaining.values.flatten.toSet.intersect(nsByNRemaining.keySet)
      (
          referencedRootNsNewTemp
        , nsByNRemaining.collect {
            case (t, ts) if referencedRootNsNewTemp.contains(t) && referencedRootNsNewTemp.intersect(ts.toSet).nonEmpty =>
              (t, referencedRootNsNewTemp.intersect(ts.toSet))
          }
      )
    }
    if (referencedRootNsNew == referencedRootNs)
      nsByNRemainingNew
    else
      recursive(nsByNRemainingNew, referencedRootNsNew)
  }
  recursive(nsByN)
}
So, how does one generically transform a custom directed graph implementation into a Map[N, Set[N]]?
In essence, "Go Scala case classes!"
First, let's define an example case of a real node in a pre-existing directed graph:
class CustomNode (
    val equipmentIdAndType: String //"A387.Structure" - identity is embedded in a string and must be parsed out
  , val childrenNodes: List[String] //children are referenced by their "id.type" strings; even though Set is implied, for whatever reason this implementation used List
  , val otherImplementationNoise: Option[Any] = None
)
Again, this is just an example. Yours could involve subclassing, delegation, etc. The purpose is to have access to something that can fetch the two essential things needed to make this work:
the identity of a node; i.e. something that distinguishes it from all other nodes in the same directed graph
a collection of the identities of the immediate children of a specific node - if the specific node doesn't have any children, this collection will be empty
Next, we define a helper object, DirectedGraph, which will contain the infrastructure for the conversion:
Node: an adapter trait which will wrap CustomNode
toMap: a function which will take a List[CustomNode] and convert it to a Map[Node, Set[Node]] (which is type equivalent to our target type of Map[N, Set[N]])
Here's the code:
object DirectedGraph {
  trait Node[S, I] {
    def source: S
    def identity: I
    def children: Set[I]
  }
  def toMap[S, I, N <: Node[S, I]](ss: List[S], transformSToN: S => N): Map[N, Set[N]] = {
    val (ns, nByI) = {
      val iAndNs =
        ss.map(
          s => {
            val n =
              transformSToN(s)
            (n.identity, n)
          }
        )
      (iAndNs.map(_._2), iAndNs.toMap)
    }
    ns.map(n => (n, n.children.map(nByI(_)))).toMap
  }
}
Now we must write the actual adapter, CustomNodeAdapter, which will wrap each CustomNode instance. This adapter uses a case class in a very specific way: it specifies two constructor parameter lists. Only the first parameter list (identity) takes part in the compiler-generated equals and hashCode, so the adapter meets a Set's requirement that its members have correct equals and hashCode implementations without involving the wrapped CustomNode. For more details on why and how to use a case class this way, please see this StackOverflow question and answer:
object CustomNodeAdapter extends (CustomNode => CustomNodeAdapter) {
  def apply(customNode: CustomNode): CustomNodeAdapter =
    new CustomNodeAdapter(fetchIdentity(customNode))(customNode) {}
  def fetchIdentity(customNode: CustomNode): String =
    fetchIdentity(customNode.equipmentIdAndType)
  def fetchIdentity(eiat: String): String =
    eiat.takeWhile(char => char.isLetter || char.isDigit)
}
abstract case class CustomNodeAdapter(identity: String)(customNode: CustomNode) extends DirectedGraph.Node[CustomNode, String] {
  val children =
    customNode.childrenNodes.map(CustomNodeAdapter.fetchIdentity).toSet
  val source =
    customNode
}
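The effect of the second parameter list on equality can be checked in isolation. The following is a minimal sketch with a hypothetical Keyed class standing in for CustomNodeAdapter. It relies only on the standard Scala rule that the generated equals and hashCode of a case class consider the first parameter list alone.

```scala
object TwoParamListDemo {
  // Hypothetical stand-in for the adapter: `id` is the identity,
  // `payload` rides along in the second parameter list.
  case class Keyed(id: String)(val payload: List[String])

  val a = Keyed("B")(List("E.b", "F.b"))
  val b = Keyed("B")(Nil)

  // Same identity, different payloads: the two values are still equal.
  val sameIdentity: Boolean = a == b && a.hashCode == b.hashCode
  val asSet: Set[Keyed] = Set(a, b)
}
```

Here asSet contains a single element, which is the behaviour the Map[N, Set[N]] representation depends on: nodes are distinguished by identity alone, whatever else they carry.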
We now have the infrastructure in place. Let's define a "real world" directed graph consisting of CustomNode:
val customNodeDirectedGraphWithCyclesA =
  List(
      new CustomNode("A.x", List("B.a", "E.a", "J.a"))
    , new CustomNode("B.x", List("E.b", "F.b"))
    , new CustomNode("C.x", List("I.c", "G.c"))
    , new CustomNode("D.x", List("G.d", "L.d"))
    , new CustomNode("E.x", List("H.e"))
    , new CustomNode("F.x", List("G.f"))
    , new CustomNode("G.x", List("L.g"))
    , new CustomNode("H.x", List("J.h", "K.h"))
    , new CustomNode("I.x", List("K.i", "L.i"))
    , new CustomNode("J.x", List("B.j"))
    , new CustomNode("K.x", List("B.k"))
    , new CustomNode("L.x", Nil)
  )
Finally, we can now do the conversion which looks like this:
val transformCustomNodeDirectedGraphWithCyclesA =
  DirectedGraph.toMap[CustomNode, String, CustomNodeAdapter](customNodeDirectedGraphWithCyclesA, customNode => CustomNodeAdapter(customNode))
And we can take transformCustomNodeDirectedGraphWithCyclesA, which is of type Map[CustomNodeAdapter,Set[CustomNodeAdapter]], and submit it to the two original functions.
Calling the isCyclic function as such:
val isCyclicResult = isCyclic(transformCustomNodeDirectedGraphWithCyclesA)
will return:
`true`
Calling the filterToJustCycles function as such:
val cycles = filterToJustCycles(transformCustomNodeDirectedGraphWithCyclesA)
will return:
Map(
CustomNodeAdapter(B) -> Set(CustomNodeAdapter(E))
, CustomNodeAdapter(E) -> Set(CustomNodeAdapter(H))
, CustomNodeAdapter(H) -> Set(CustomNodeAdapter(J), CustomNodeAdapter(K))
, CustomNodeAdapter(J) -> Set(CustomNodeAdapter(B))
, CustomNodeAdapter(K) -> Set(CustomNodeAdapter(B))
)
And if needed, this Map can then be converted back to Map[CustomNode, List[CustomNode]]:
cycles.map {
  case (customNodeAdapter, customNodeAdapterChildren) =>
    (customNodeAdapter.source, customNodeAdapterChildren.toList.map(_.source))
}
If you have any questions, issues or concerns, please let me know and I will address them ASAP.
I think the problem can be solved without changing the state of the node with the marker field. The following is a rough sketch of what I think isCyclic should look like. I am currently storing the node objects that have been visited; you can store the node ids instead if the node doesn't have equality based on node id.
def isCyclic() : Boolean = nodes.exists(hasCycle(_, Seq.empty))
// the explicit Boolean result type is required because hasCycle is recursive
def hasCycle(node:Node, visited:Seq[Node]): Boolean = visited.contains(node) || children(node).exists(hasCycle(_, node +: visited))
def children(node:Node) = vertexMap.getChildren(node.id).toList.map(nodeId => id2nodeMap(nodeId))
Answer added just to show that the mutable-visited isn't too unreadable either (untested, though!)
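def isCyclic() : Boolean =
{
  var visited = HashSet[Node]()
  // Note: visited is shared by the whole traversal, so any node reached a second
  // time (via another path, or again from nodes.exists below) is reported,
  // even when the graph has no cycle.
  def hasCycle(node:Node): Boolean = {
    if (visited.contains(node)) {
      true
    } else {
      visited += node
      children(node).exists(hasCycle(_))
    }
  }
  nodes.exists(hasCycle(_))
}
def children(node:Node) = vertexMap.getChildren(node.id).toList.map(nodeId => id2nodeMap(nodeId))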
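One way to keep mutable bookkeeping while still telling a true cycle apart from a node that is merely reached twice is to track two sets: the nodes on the current DFS path and the nodes already fully explored. This is the standard DFS colouring technique rather than the approach taken in the answers here. The sketch below works on a plain Map[N, Set[N]] instead of the question's Node, vertexMap and id2nodeMap, and ColouredDfs is a hypothetical name.

```scala
import scala.collection.mutable

object ColouredDfs {
  def isCyclic[N](graph: Map[N, Set[N]]): Boolean = {
    val onPath = mutable.Set.empty[N] // nodes on the path currently being explored
    val done   = mutable.Set.empty[N] // nodes whose descendants were fully explored

    def hasCycle(n: N): Boolean =
      if (onPath.contains(n)) true     // back edge: n is its own ancestor
      else if (done.contains(n)) false // already explored, no need to repeat
      else {
        onPath += n
        val found = graph.getOrElse(n, Set.empty[N]).exists(hasCycle)
        onPath -= n
        done += n
        found
      }

    graph.keys.exists(hasCycle)
  }
}
```

Each node is expanded at most once, so the work is proportional to the number of nodes plus edges. The recursion depth still grows with the longest path, so very deep graphs would need an explicit stack.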
If p = node => node.marker==1 && visit(node), and assuming nodes is a List, you can pick any of the following:
nodes.filter(p).length>0
nodes.count(p)>0
nodes.exists(p) (I think the most relevant)
I am not sure of the relative complexity of each method and would appreciate a comment from fellow members of the community.
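On the question of relative cost, a small experiment helps. The sketch below counts how often the predicate runs under each of the three approaches on a strict List. The list contents and the names ShortCircuitDemo and evaluations are made up for illustration, and the predicate is a stand-in for the marker check above.

```scala
object ShortCircuitDemo {
  val nodes = List(1, 2, 3, 4, 5)

  // Runs `check` with a counting predicate and reports how many
  // elements the predicate was applied to.
  def evaluations(check: (List[Int], Int => Boolean) => Boolean): Int = {
    var count = 0
    val p: Int => Boolean = n => { count += 1; n == 2 }
    check(nodes, p)
    count
  }

  val viaFilter = evaluations((ns, p) => ns.filter(p).length > 0) // 5
  val viaCount  = evaluations((ns, p) => ns.count(p) > 0)         // 5
  val viaExists = evaluations((ns, p) => ns.exists(p))            // 2
}
```

All three are linear in the worst case, but exists stops at the first match, while filter and count always visit every element, and filter additionally builds an intermediate list. That matters more when the predicate is expensive or has side effects, as a predicate that calls visit(node) would.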