I'm seeing what seems to be a very obvious bug with scalacheck, such that if it's really there I can't see how people use it for recursive data structures.
This program fails with a StackOverflowError before scalacheck takes over, while constructing the Arbitrary value. Note that the Tree type and the generator for Trees are taken verbatim from this scalacheck tutorial.
package treegen
import org.scalacheck._
import Prop._
class TreeProperties extends Properties("Tree") {
trait Tree
case class Node(left: Tree, right: Tree) extends Tree
case class Leaf(x: Int) extends Tree
val ints = Gen.choose(-100, 100)
def leafs: Gen[Leaf] = for {
x <- ints
} yield Leaf(x)
def nodes: Gen[Node] = for {
left <- trees
right <- trees
} yield Node(left, right)
def trees: Gen[Tree] = Gen.oneOf(leafs, nodes)
implicit lazy val arbTree: Arbitrary[Tree] = Arbitrary(trees)
property("vacuous") = forAll { t: Tree => true }
}
object Main extends App {
(new TreeProperties).check
}
What's stranger is that changes that shouldn't affect anything seem to alter the program so that it works. For example, if you change the definition of trees to this, it passes without any problem:
def trees: Gen[Tree] = for {
x <- Gen.oneOf(0, 1)
t <- if (x == 0) {leafs} else {nodes}
} yield t
Even stranger, if you alter the binary tree structure so that the value is stored on Nodes and not on Leafs, and alter the leafs and nodes definition to be:
def leafs: Gen[Leaf] = Gen.value(Leaf())
def nodes: Gen[Node] = for {
x <- ints // Note: be sure to ask for x first, or it'll StackOverflow later, inside scalacheck code!
left <- trees
right <- trees
} yield Node(left, right, x)
It also then works fine.
What's going on here? Why is constructing the Arbitrary value initially causing a stack overflow? Why does it seem that scalacheck generators are so sensitive to minor changes that shouldn't affect the control flow of the generators?
Why isn't my expression above with the oneOf(0, 1) exactly equivalent to the original oneOf(leafs, nodes) ?
The problem is that when Scala evaluates trees, it ends up in an endless recursion since trees is defined in terms of itself (via nodes). However, when you put some other expression than trees as the first part of your for-expression in nodes, Scala will delay the evaluation of the rest of the for-expression (wrapped up in chains of map and flatMap calls), and the infinite recursion will not happen.
Just as pedrofurla says, if oneOf were non-strict this would probably not happen (since Scala wouldn't evaluate the arguments immediately). However, you can use Gen.lzy to be explicit about the laziness. lzy takes any generator and delays the evaluation of that generator until it is really used. So the following change solves your problem:
def trees: Gen[Tree] = Gen.lzy(Gen.oneOf(leafs, nodes))
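To see why the by-name wrapper matters without pulling in scalacheck, here is a minimal sketch with a toy generator type. Everything in it (G, oneOf, lzy, the seed-halving flatMap) is a made-up stand-in, not scalacheck's real implementation; it only mirrors the strict-versus-lazy argument evaluation described above.

```scala
object LazyGenSketch {
  // Toy stand-in for a generator: a pure function from an Int seed to a value.
  final case class G[A](run: Int => A) {
    def map[B](f: A => B): G[B] = G[B](s => f(run(s)))
    // Shrinks the seed on each step so that recursive generators terminate.
    def flatMap[B](f: A => G[B]): G[B] = G[B](s => f(run(s / 2)).run(s / 3))
  }

  // Strict in both arguments (like the oneOf in the question):
  // a and b are fully built before oneOf itself runs.
  def oneOf[A](a: G[A], b: G[A]): G[A] =
    G[A](s => if (s % 2 == 0) a.run(s) else b.run(s))

  // By-name argument: the wrapped expression is evaluated only when run is called.
  def lzy[A](g: => G[A]): G[A] = G[A](s => g.run(s))

  sealed trait Tree
  final case class Leaf(x: Int) extends Tree
  final case class Node(left: Tree, right: Tree) extends Tree

  def leafs: G[Tree] = G[Tree](s => Leaf(s))
  def nodes: G[Tree] = trees.flatMap(l => trees.map[Tree](r => Node(l, r)))

  // Without lzy, `def trees = oneOf(leafs, nodes)` would call nodes, which calls
  // trees, which calls nodes again ... before any generator object exists.
  def trees: G[Tree] = lzy(oneOf(leafs, nodes))
}
```

With the lzy wrapper, LazyGenSketch.trees returns immediately and the recursion only unfolds while a value is being generated, where the shrinking seed bounds it.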
Even though following Rickard Nilsson's answer above got rid of the constant StackOverflowError on program startup, I'd still hit a StackOverflowError about one time out of three once I actually asked scalacheck to check the properties. (I changed Main above to run .check 40 times, and would see it succeed twice, then fail with a stack overflow, then succeed twice, etc.)
Eventually I had to put in a hard block to the depth of the recursion and this is what I guess I'll be doing when using scalacheck on recursive data structures in the future:
def leafs: Gen[Leaf] = for {
x <- ints
} yield Leaf(x)
def genNode(level: Int): Gen[Node] = for {
left <- genTree(level)
right <- genTree(level)
} yield Node(left, right)
def genTree(level: Int): Gen[Tree] = if (level >= 100) {leafs}
else {leafs | genNode(level + 1)}
lazy val trees: Gen[Tree] = genTree(0)
With this change, scalacheck never runs into a StackOverflowError.
A slight generalization of the approach in Daniel Martin's own answer is to use sized. Something like (untested):
def genTree(): Gen[Tree] = Gen.sized { size => genTree0(size) }
// Mutually recursive methods need explicit result types
def genTree0(maxDepth: Int): Gen[Tree] =
  if (maxDepth == 0) leafs else Gen.oneOf(leafs, genNode(maxDepth))
def genNode(maxDepth: Int): Gen[Node] = for {
  depthL <- Gen.choose(0, maxDepth - 1)
  depthR <- Gen.choose(0, maxDepth - 1)
  left <- genTree0(depthL)
  right <- genTree0(depthR)
} yield Node(left, right)
def leafs = for {
x <- ints
} yield Leaf(x)
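The idea shared by both depth-limited answers (carry a depth budget and force a leaf when it runs out) can be checked in plain Scala. This is only a sketch: it uses scala.util.Random directly instead of Gen, and BoundedTreeSketch, genTree and depth are names invented for the illustration.

```scala
import scala.util.Random

object BoundedTreeSketch {
  sealed trait Tree
  final case class Leaf(x: Int) extends Tree
  final case class Node(left: Tree, right: Tree) extends Tree

  // Builds a random tree whose depth never exceeds maxDepth:
  // once the budget is used up, only leaves can be produced.
  def genTree(maxDepth: Int, rnd: Random): Tree =
    if (maxDepth <= 0 || rnd.nextBoolean()) Leaf(rnd.nextInt(201) - 100) // value in [-100, 100]
    else Node(genTree(maxDepth - 1, rnd), genTree(maxDepth - 1, rnd))

  // Depth of a tree, counting a single leaf as depth 0.
  def depth(t: Tree): Int = t match {
    case Leaf(_)    => 0
    case Node(l, r) => 1 + math.max(depth(l), depth(r))
  }
}
```

Because the recursion depth is bounded by maxDepth, the stack use is bounded too, which is the property the unbounded oneOf(leafs, nodes) generator lacks.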
I am starting to use the state monad to clean up my code. I have got it working for my problem where I process a transaction called CDR and modify the state accordingly.
It is working perfectly fine for individual transactions, using this function to perform the state update.
def addTraffic(cdr: CDR): Network => Network = ...
Here is an example:
scala> val processed: (CDR) => State[Network, Long] = cdr =>
| for {
| m <- init
| _ <- modify(Network.addTraffic(cdr))
| p <- get
| } yield p.count
processed: CDR => scalaz.State[Network,Long] = $$Lambda$4372/1833836780@1258d5c0
scala> val r = processed(("122","celda 1", 3))
r: scalaz.State[Network,Long] = scalaz.IndexedStateT$$anon$13@4cc4bdde
scala> r.run(Network.empty)
res56: scalaz.Id.Id[(Network, Long)] = (Network(Map(122 -> (1,0.0)),Map(celda 1 -> (1,0.0)),Map(1 -> Map(1 -> 3)),1,true),1)
What I want to do now is chain a number of transactions coming from an iterator. I have found something that works quite well, but its state transitions take no inputs (the state changes through the RNG):
import scalaz._
import scalaz.std.list.listInstance
type RNG = scala.util.Random
val f = (rng:RNG) => (rng, rng.nextInt)
val intGenerator: State[RNG, Int] = State(f)
val rng42 = new scala.util.Random
val applicative = Applicative[({type l[Int] = State[RNG,Int]})#l]
// To generate the first 5 Random integers
val chain: State[RNG, List[Int]] = applicative.sequence(List.fill(5)(intGenerator))
val chainResult: (RNG, List[Int]) = chain.run(rng42)
chainResult._2.foreach(println)
I have unsuccessfully tried to adapt this, but I cannot get the type signatures to match because my state function requires the cdr (transaction) input.
Thanks
TL;DR
You can use traverse from the Traverse type-class on a collection (e.g. List) of CDRs, using a function with this signature: CDR => State[Network, Long]. The result will be a State[Network, List[Long]]. Alternatively, if you don't care about the List[Long] there, you can use traverse_ instead, which will return State[Network, Unit]. Finally, should you want to "aggregate" the results T as they come along, and T forms a Monoid, you can use foldMap from Foldable, which will return State[Network, T], where T is the combined (i.e. folded) result of all Ts in your chain.
A code example
Now some more details, with code examples. I will answer this using Cats State rather than Scalaz, as I never used the latter, but the concept is the same and, if you still have problems, I will dig out the correct syntax.
Assume that we have the following data types and imports to work with:
import cats.implicits._
import cats.data.State
case class Position(x : Int = 0, y : Int = 0)
sealed trait Move extends Product
case object Up extends Move
case object Down extends Move
case object Left extends Move
case object Right extends Move
As it is clear, the Position represents a point in a 2D plane and a Move can move such point up, down, left or right.
Now, let's create a method that will allow us to see where we are at a given time:
def whereAmI : State[Position, String] = State.inspect{ s => s.toString }
and a method to change our position, given a Move:
def move(m : Move) : State[Position, String] = State{ s =>
m match {
case Up => (s.copy(y = s.y + 1), "Up!")
case Down => (s.copy(y = s.y - 1), "Down!")
case Left => (s.copy(x = s.x - 1), "Left!")
case Right => (s.copy(x = s.x + 1), "Right!")
}
}
Notice that this will return a String, with the name of the move followed by an exclamation mark. This is just to simulate the type change from Move to something else, and show how the results will be aggregated. More on this in a bit.
Now let's try to play with our methods:
val positions : State[Position, List[String]] = for{
pos1 <- whereAmI
_ <- move(Up)
_ <- move(Right)
_ <- move(Up)
pos2 <- whereAmI
_ <- move(Left)
_ <- move(Left)
pos3 <- whereAmI
} yield List(pos1,pos2,pos3)
And we can feed it an initial Position and see the result:
positions.runA(Position()).value // List(Position(0,0), Position(1,2), Position(-1,2))
(you can ignore the .value there, it's a quirk due to the fact that State[S,A] is really just an alias for StateT[Eval,S,A])
As you can see, this behaves as you would expect, and you can create different "blueprints" (e.g. sequences of state modifications), which will be applied once an initial state is provided.
Now, to actually answer your question, say we have a List[Move] and we want to apply them sequentially to an initial state, and get the result: we use traverse from the Traverse type-class.
val moves = List(Down, Down, Left, Up)
val result : State[Position, List[String]] = moves.traverse(move)
result.run(Position()).value // (Position(-1,-1),List(Down!, Down!, Left!, Up!))
Alternatively, should you not need the A at all (the List in your case), you can use traverse_ instead of traverse, and the result type will be:
val result_ : State[Position, Unit] = moves.traverse_(move)
result_.run(Position()).value // (Position(-1,-1),())
Finally, if your A type in State[S,A] forms a Monoid, then you could also use foldMapM, the monadic variant of foldMap from Foldable, to combine (i.e. fold) all As as they are calculated. A trivial example (probably useless, because this will just concatenate all Strings) would be this:
val folded : State[Position,String] = moves.foldMapM(move)
folded.run(Position()).value // (Position(-1,-1),Down!Down!Left!Up!)
Whether this final approach is useful or not to you, really depends on what A you have and if it makes sense to combine it.
And this should be all you need in your scenario.
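If it helps to see what traverse does underneath, here is a library-free sketch. The State class below is a hand-rolled minimal version (not the Cats or Scalaz one), and addAmount is a hypothetical stand-in for a transaction step such as the CDR update in the question.

```scala
object StateTraverseSketch {
  // Minimal State: a function from a state to (new state, result).
  final case class State[S, A](run: S => (S, A)) {
    def map[B](f: A => B): State[S, B] =
      State[S, B] { s => val (s1, a) = run(s); (s1, f(a)) }
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State[S, B] { s => val (s1, a) = run(s); f(a).run(s1) }
  }

  // Runs f on every input from left to right, threading the state through
  // and collecting the results in input order.
  def traverse[S, A, B](inputs: List[A])(f: A => State[S, B]): State[S, List[B]] =
    inputs.foldRight(State[S, List[B]](s => (s, Nil))) { (a, acc) =>
      f(a).flatMap(b => acc.map(b :: _))
    }

  // Hypothetical transaction step: add the amount to a running total and
  // return the total after the update.
  def addAmount(amount: Int): State[Int, Int] =
    State[Int, Int](total => (total + amount, total + amount))
}
```

Running StateTraverseSketch.traverse(List(1, 2, 3))(n => StateTraverseSketch.addAmount(n)).run(0) yields the final total together with the list of intermediate totals. This simple version is not stack-safe for very long lists, which is one reason to prefer the library implementations.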
I'm trying to implement a functional Breadth First Search in Scala to compute the distances between a given node and all the other nodes in an unweighted graph. I've used a State Monad for this with the signature as :-
case class State[S,A](run:S => (A,S))
Other functions such as map, flatMap, sequence, modify etc etc are similar to what you'd find inside a standard State Monad.
Here's the code :-
case class Node(label: Int)
case class BfsState(q: Queue[Node], nodesList: List[Node], discovered: Set[Node], distanceFromSrc: Map[Node, Int]) {
val isTerminated = q.isEmpty
}
case class Graph(adjList: Map[Node, List[Node]]) {
def bfs(src: Node): (List[Node], Map[Node, Int]) = {
val initialBfsState = BfsState(Queue(src), List(src), Set(src), Map(src -> 0))
val output = bfsComp(initialBfsState)
(output.nodesList,output.distanceFromSrc)
}
@tailrec
private def bfsComp(currState:BfsState): BfsState = {
if (currState.isTerminated) currState
else bfsComp(searchNode.run(currState)._2)
}
private def searchNode: State[BfsState, Unit] = for {
node <- State[BfsState, Node](s => {
val (n, newQ) = s.q.dequeue
(n, s.copy(q = newQ))
})
s <- get
_ <- sequence(adjList(node).filter(!s.discovered(_)).map(n => {
modify[BfsState](s => {
s.copy(s.q.enqueue(n), n :: s.nodesList, s.discovered + n, s.distanceFromSrc + (n -> (s.distanceFromSrc(node) + 1)))
})
}))
} yield ()
}
Please can you advise on :-
Should the State Transition on dequeue in the searchNode function be a member of BfsState itself?
How do I make this code more performant/concise/readable?
First off, I suggest moving all the private defs related to bfs into bfs itself. This is the convention for methods that are solely used to implement another.
Second, I suggest simply not using State for this matter. State (like most monads) is about composition. It is useful when you have many things that all need access to the same global state. In this case, BfsState is specialized to bfs, will likely never be used anywhere else (it might be a good idea to move the class into bfs too), and the State itself is always run, so the outer world never sees it. (In many cases, this is fine, but here the scope is too small for State to be useful.) It'd be much cleaner to pull the logic of searchNode into bfsComp itself.
Third, I don't understand why you need both nodesList and discovered, when you can just call _.toList on discovered once you've done your computation. I've left it in in my reimplementation, though, in case there's more to this code that you haven't displayed.
def bfsComp(old: BfsState): BfsState = {
  if (old.q.isEmpty) old // You don't need isTerminated, I think
  else {
    val (currNode, newQ) = old.q.dequeue
    val newState = old.copy(q = newQ)
    // Recurse on the updated state until the queue is empty
    bfsComp(
      adjList(currNode)
        .filterNot(old.discovered) // Set[T] <: T => Boolean and filterNot means you don't need to write !old.discovered(_)
        .foldLeft(newState) { case (BfsState(q, nodes, discovered, distance), adjNode) =>
          BfsState(
            q.enqueue(adjNode),
            adjNode :: nodes,
            discovered + adjNode,
            distance + (adjNode -> (distance(currNode) + 1))
          )
        }
    )
  }
}
def bfs(src: Node): (List[Node], Map[Node, Int]) = {
// I suggest moving BfsState and bfsComp into this method
val output = bfsComp(BfsState(Queue(src), List(src), Set(src), Map(src -> 0)))
(output.nodesList, output.distanceFromSrc)
// Could get rid of nodesList and say output.discovered.toList
}
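For reference, here is a self-contained variation of the same State-free approach that can be pasted and run as is. It is a sketch rather than a drop-in replacement: it takes the adjacency list as a parameter, keeps only the distance map (a node counts as discovered once it has a distance), and the names BfsSketch and loop are invented for the example.

```scala
import scala.annotation.tailrec
import scala.collection.immutable.Queue

object BfsSketch {
  final case class Node(label: Int)

  // Returns the distance (in edges) from src to every reachable node.
  def bfs(adjList: Map[Node, List[Node]], src: Node): Map[Node, Int] = {
    @tailrec
    def loop(q: Queue[Node], dist: Map[Node, Int]): Map[Node, Int] =
      q.dequeueOption match {
        case None => dist
        case Some((curr, rest)) =>
          // Neighbours that have no distance yet are newly discovered.
          val fresh = adjList.getOrElse(curr, Nil).distinct.filterNot(n => dist.contains(n))
          loop(rest ++ fresh, dist ++ fresh.map(n => n -> (dist(curr) + 1)))
      }
    loop(Queue(src), Map(src -> 0))
  }
}
```

For a diamond-shaped graph 1 -> {2, 3}, 2 -> 4, 3 -> 4, this gives distance 0 for node 1, 1 for nodes 2 and 3, and 2 for node 4.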
In the event that you think you do have a good reason for using State here, here are my thoughts.
You use def searchNode. The point of a State is that it is pure and immutable, so it should be a val, or else you reconstruct the same State every use.
You write:
node <- State[BfsState, Node](s => {
val (n, newQ) = s.q.dequeue
(n, s.copy(q = newQ))
})
First off, Scala's syntax was designed so that you don't need to have both a () and {} surrounding an anonymous function:
node <- State[BfsState, Node] { s =>
// ...
}
Second, this doesn't look quite right to me. One benefit of using for-syntax is that the anonymous functions are hidden from you and there is minimal indentation. I'd just write it out
oldState <- get
(node, newQ) = oldState.q.dequeue
newState = oldState.copy(q = newQ)
Footnote: would it make sense to make Node an inner class of Graph? Just a suggestion.
I have a question about writing recursive algorithms in a functional style. I will use Scala for my example here, but the question applies to any functional language.
I am doing a depth-first enumeration of an n-ary tree where each node has a label and a variable number of children. Here is a simple implementation that prints the labels of the leaf nodes.
case class Node[T](label:T, ns:Node[T]*)
def dfs[T](r:Node[T]):Seq[T] = {
if (r.ns.isEmpty) Seq(r.label) else for (n<-r.ns;c<-dfs(n)) yield c
}
val r = Node('a, Node('b, Node('d), Node('e, Node('f))), Node('c))
dfs(r) // returns Seq[Symbol] = ArrayBuffer('d, 'f, 'c)
Now say that sometimes I want to be able to give up on parsing oversize trees by throwing an exception. Is this possible in a functional language? Specifically is this possible without using mutable state? That seems to depend on what you mean by "oversize". Here is a purely functional version of the algorithm that throws an exception when it tries to handle a tree with a depth of 3 or greater.
def dfs[T](r:Node[T], d:Int = 0):Seq[T] = {
require(d < 3)
if (r.ns.isEmpty) Seq(r.label) else for (n<-r.ns;c<-dfs(n, d+1)) yield c
}
But what if a tree is oversized because it is too broad rather than too deep? Specifically what if I want to throw an exception the n-th time the dfs() function is called recursively regardless of how deep the recursion goes? The only way I can see how to do this is to have a mutable counter that is incremented with each call. I can't see how to do it without a mutable variable.
I'm new to functional programming and have been working under the assumption that anything you can do with mutable state can be done without, but I don't see the answer here. The only thing I can think to do is write a version of dfs() that returns a view over all the nodes in the tree in depth-first order.
def dfs[T](r:Node[T]):TraversableView[T, Traversable[_]] = ...
Then I could impose my limit by saying dfs(r).take(n), but I don't see how to write this function. In Python I'd just create a generator by yielding nodes as I visited them, but I don't see how to achieve the same effect in Scala. (Scala's equivalent to a Python-style yield statement appears to be a visitor function passed in as a parameter, but I can't figure out how to write one of these that will generate a sequence view.)
EDIT Getting close to the answer.
Here is a function that returns a Stream of nodes in depth-first order.
def dfs[T](r: Node[T]): Stream[Node[T]] = {
(r #:: Stream.empty /: r.ns)(_ ++ dfs(_))
}
That is almost it. The only problem is that Stream memoizes all results, which is a waste of memory. I want a traversable view. The following is the idea, but does not compile.
def dfs[T](r: Node[T]): TraversableView[Node[T], Traversable[Node[T]]] = {
(Traversable(r).view /: r.ns)(_ ++ dfs(_))
}
It gives a "found TraversableView[Node[T], Traversable[Node[T]]], required TraversableView[Node[T], Traversable[_]]" error for the ++ operator. If I change the return type to TraversableView[Node[T], Traversable[_]], I get the same problem with the "found" and "required" clauses switched. So there's some magic type variance incantation I haven't lit upon yet, but this is close.
It can be done: you just have to write some code to actually iterate through the children in the way you want (as opposed to relying on for).
More explicitly, you'll have to write code to iterate through a list of children and check if the "depth" crossed your threshold. Here's some Haskell code (I'm really sorry, I'm not fluent in Scala, but this can probably be easily transliterated):
http://ideone.com/O5gvhM
In this code, I've basically replaced the for loop with an explicit recursive version. This allows me to stop the recursion if too many nodes have already been visited (i.e., limit is not positive). When I recurse to examine the next child, I subtract the number of nodes the dfs of the previous child visited and set this as the limit for the next child.
Functional languages are fun, but they're a huge leap from imperative programming. It really makes you pay attention to the concept of state, because all of it is excruciatingly explicit in the arguments when you go functional.
EDIT: Explaining this a bit more.
I ended up converting from "print just the leaf nodes" (which was the original algorithm from the OP) to "print all nodes". This enabled me to have access to the number of nodes the subcall visited through the length of the resulting list. If you want to stick to the leaf nodes, you'll have to carry around how many nodes you have already visited:
http://ideone.com/cIQrna
EDIT again: To clear up this answer, I'm putting all the Haskell code on ideone, and I've transliterated my Haskell code to Scala, so this can stay here as the definitive answer to the question:
case class Node[T](label:T, children:Seq[Node[T]])
case class TraversalResult[T](num_visited:Int, labels:Seq[T])
def dfs[T](node:Node[T], limit:Int):TraversalResult[T] =
limit match {
case 0 => TraversalResult(0, Nil)
case limit =>
node.children match {
case Nil => TraversalResult(1, List(node.label))
case children => {
val result = traverse(node.children, limit - 1)
TraversalResult(result.num_visited + 1, result.labels)
}
}
}
def traverse[T](children:Seq[Node[T]], limit:Int):TraversalResult[T] =
limit match {
case 0 => TraversalResult(0, Nil)
case limit =>
children match {
case Nil => TraversalResult(0, Nil)
case first :: rest => {
val trav_first = dfs(first, limit)
val trav_rest =
traverse(rest, limit - trav_first.num_visited)
TraversalResult(
trav_first.num_visited + trav_rest.num_visited,
trav_first.labels ++ trav_rest.labels
)
}
}
}
val n = Node(0, List(
Node(1, List(Node(2, Nil), Node(3, Nil))),
Node(4, List(Node(5, List(Node(6, Nil))))),
Node(7, Nil)
))
for (i <- 1 to 8)
println(dfs(n, i))
Output:
TraversalResult(1,List())
TraversalResult(2,List())
TraversalResult(3,List(2))
TraversalResult(4,List(2, 3))
TraversalResult(5,List(2, 3))
TraversalResult(6,List(2, 3))
TraversalResult(7,List(2, 3, 6))
TraversalResult(8,List(2, 3, 6, 7))
P.S. this is my first attempt at Scala, so the above probably contains some horrid non-idiomatic code. I'm sorry.
You can convert breadth into depth by passing along an index or taking the tail:
// Recursive methods need an explicit result type
def suml(xs: List[Int], total: Int = 0): Int = xs match {
  case Nil => total
  case x :: rest => suml(rest, total+x)
}
def suma(xs: Array[Int], from: Int = 0, total: Int = 0): Int = {
  if (from >= xs.length) total
  else suma(xs, from+1, total + xs(from))
}
In the latter case, you already have something to limit your breadth if you want; in the former, just add a width or somesuch.
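As a sketch of what "just add a width" could look like for the list version: the extra parameter counts down with every element consumed. The name sumLimited is made up, and it returns None instead of throwing, which is an assumption on my part about how you'd want to signal an oversized input.

```scala
import scala.annotation.tailrec

object WidthLimitSketch {
  // Sums a list, but gives up (None) if it holds more than maxWidth elements.
  @tailrec
  def sumLimited(xs: List[Int], maxWidth: Int, total: Int = 0): Option[Int] = xs match {
    case Nil                => Some(total)
    case _ if maxWidth <= 0 => None // budget exhausted with elements left over
    case x :: rest          => sumLimited(rest, maxWidth - 1, total + x)
  }
}
```

No mutable counter is needed: the remaining budget is simply another argument of the recursive call.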
The following implements a lazy depth-first search over nodes in a tree.
import collection.TraversableView
case class Node[T](label: T, ns: Node[T]*)
def dfs[T](r: Node[T]): TraversableView[Node[T], Traversable[Node[T]]] =
(Traversable[Node[T]](r).view /: r.ns) {
(a, b) => (a ++ dfs(b)).asInstanceOf[TraversableView[Node[T], Traversable[Node[T]]]]
}
This prints the labels of all the nodes in depth-first order.
val r = Node('a, Node('b, Node('d), Node('e, Node('f))), Node('c))
dfs(r).map(_.label).force
// returns Traversable[Symbol] = List('a, 'b, 'd, 'e, 'f, 'c)
This does the same thing, quitting after 3 nodes have been visited.
dfs(r).take(3).map(_.label).force
// returns Traversable[Symbol] = List('a, 'b, 'd)
If you want only leaf nodes you can use filter, and so forth.
Note that the fold clause of the dfs function requires an explicit asInstanceOf cast. See "Type variance error in Scala when doing a foldLeft over Traversable views" for a discussion of the Scala typing issues that necessitate this.
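An Iterator gives a similar lazy, non-memoizing traversal without the cast, and it does not depend on TraversableView, which as far as I know is no longer available from Scala 2.13 on. This is a sketch of an alternative, not the code above rewritten; the labels are plain Strings here instead of Symbols.

```scala
object IteratorDfsSketch {
  final case class Node[T](label: T, ns: Node[T]*)

  // Depth-first traversal as an Iterator: children are visited only when the
  // consumer asks for more elements, and nothing is memoized.
  def dfs[T](r: Node[T]): Iterator[Node[T]] =
    Iterator.single(r) ++ r.ns.iterator.flatMap(n => dfs(n))
}
```

With the same tree as above, dfs(r).map(_.label).toList visits all six nodes in depth-first order, and dfs(r).take(3) stops after the first three.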
Much like this question:
Functional code for looping with early exit
Say the code is
def findFirst[T](objects: List[T]):T = {
for (obj <- objects) {
if (expensiveFunc(obj) != null) return /*???*/ Some(obj)
}
None
}
How to yield a single element from a for loop like this in scala?
I do not want to use find, as proposed in the original question; I am curious about whether and how it could be implemented using the for loop.
* UPDATE *
First, thanks for all the comments, but I guess I was not clear in the question. I am shooting for something like this:
val seven = for {
x <- 1 to 10
if x == 7
} return x
And that does not compile. The two errors are:
- return outside method definition
- method main has return statement; needs result type
I know find() would be better in this case; I am just learning and exploring the language. And in a more complex case with several iterators, I think finding with for can actually be useful.
Thanks commenters, I'll start a bounty to make up for the bad posing of the question :)
If you want to use a for loop, which uses a nicer syntax than chained invocations of .find, .filter, etc., there is a neat trick. Instead of iterating over strict collections like list, iterate over lazy ones like iterators or streams. If you're starting with a strict collection, make it lazy with, e.g. .toIterator.
Let's see an example.
First let's define a "noisy" int that will show us when it is invoked:
def noisyInt(i : Int) = () => { println("Getting %d!".format(i)); i }
Now let's fill a list with some of these:
val l = List(1, 2, 3, 4).map(noisyInt)
We want to look for the first element which is even.
val r1 = for(e <- l; v = e() ; if v % 2 == 0) yield v
The above line results in:
Getting 1!
Getting 2!
Getting 3!
Getting 4!
r1: List[Int] = List(2, 4)
...meaning that all elements were accessed. That makes sense, given that the resulting list contains all even numbers. Let's iterate over an iterator this time:
val r2 = (for(e <- l.toIterator; v = e() ; if v % 2 == 0) yield v)
This results in:
Getting 1!
Getting 2!
r2: Iterator[Int] = non-empty iterator
Notice that the loop was executed only up to the point where it could figure out whether the result was an empty or non-empty iterator.
To get the first result, you can now simply call r2.next.
If you want a result of an Option type, use:
if(r2.hasNext) Some(r2.next) else None
Edit Your second example in this encoding is just:
val seven = (for {
x <- (1 to 10).toIterator
if x == 7
} yield x).next
...of course, you should be sure that there is always at least a solution if you're going to use .next. Alternatively, use headOption, defined for all Traversables, to get an Option[Int].
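To check the "only as far as needed" claim without reading console output, here is a small sketch that counts how many elements the guard inspects. The counter is a var used purely for observation, and LazyForSketch and firstEven are names made up for this example.

```scala
object LazyForSketch {
  // Returns the first even element (if any) together with the number of
  // elements the guard had to inspect to find it.
  def firstEven(xs: List[Int]): (Option[Int], Int) = {
    var inspected = 0 // only here to observe how far the loop ran
    val it = for {
      x <- xs.iterator
      if { inspected += 1; x % 2 == 0 }
    } yield x
    val first = if (it.hasNext) Some(it.next()) else None
    (first, inspected)
  }
}
```

For List(1, 2, 3, 4) the guard runs twice before the first match is returned; for a list with no even element it runs once per element and the result is None.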
You can turn your list into a stream, so that any filters that the for-loop contains are only evaluated on demand. However, yielding from the stream will always return a stream, and what you want is, I suppose, an option. So, as a final step, you can check whether the resulting stream has at least one element, and return its head as an option. The headOption function does exactly that.
def findFirst[T](objects: List[T], expensiveFunc: T => Boolean): Option[T] =
(for (obj <- objects.toStream if expensiveFunc(obj)) yield obj).headOption
Why not do exactly what you sketched above, that is, return from the loop early? If you are interested in what Scala actually does under the hood, run your code with -print. Scala desugars the loop into a foreach and then uses an exception to leave the foreach prematurely.
So what you are trying to do is break out of a loop once your condition is satisfied. The answers to "How do I break out of a loop in Scala?" might be what you are looking for.
Overall, a for comprehension in Scala is translated into map, flatMap and filter operations, so it is not possible to break out of these functions unless you throw an exception.
If you are wondering, this is how find is implemented in LinearSeqOptimized.scala, which List inherits:
override /*IterableLike*/
def find(p: A => Boolean): Option[A] = {
var these = this
while (!these.isEmpty) {
if (p(these.head)) return Some(these.head)
these = these.tail
}
None
}
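Both points can be tried out in a few lines. The sketch below first shows a for comprehension next to the method calls it is rewritten into, then an early exit from a strict for loop using scala.util.control.Breaks from the standard library, which is implemented with an exception. ForDesugarSketch and firstMatch are illustrative names, and the var is there only to carry the result out of the loop.

```scala
import scala.util.control.Breaks.{break, breakable}

object ForDesugarSketch {
  val xs = List(1, 2, 3, 4, 5, 6)

  // A for comprehension with a guard and a yield ...
  val sugared = for (x <- xs if x % 2 == 0) yield x * 10
  // ... corresponds to withFilter followed by map.
  val desugared = xs.withFilter(x => x % 2 == 0).map(x => x * 10)

  // Early exit from a for loop over a strict collection, using the standard
  // library's exception-based break.
  def firstMatch(p: Int => Boolean): Option[Int] = {
    var found: Option[Int] = None
    breakable {
      for (x <- xs) {
        if (p(x)) { found = Some(x); break() }
      }
    }
    found
  }
}
```

This works, but it trades the early return for a mutable variable, so for a simple search find or a lazy iterator is usually the cleaner choice.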
This is a horrible hack. But it would get you the result you wished for.
Idiomatically you'd use a Stream or View and just compute the parts you need.
def findFirst[T](objects: List[T]): T = {
def expensiveFunc(o : T) = // unclear what should be returned here
case class MissusedException(val data: T) extends Exception
try {
(for (obj <- objects) {
if (expensiveFunc(obj) != null) throw new MissusedException(obj)
})
objects.head // T must be returned from loop, dummy
} catch {
case MissusedException(obj) => obj
}
}
Why not something like
object Main {
def main(args: Array[String]): Unit = {
val seven = (for (
x <- 1 to 10
if x == 7
) yield x).headOption
}
}
The variable seven will be an Option holding Some(value) if a value satisfies the condition.
I hope this helps.
Here are a few implementations without 'return', I think:
object TakeWhileLoop extends App {
println("first non-null: " + func(Seq(null, null, "x", "y", "z")))
def func[T](seq: Seq[T]): T = if (seq.isEmpty) null.asInstanceOf[T] else
seq(seq.takeWhile(_ == null).size)
}
object OptionLoop extends App {
println("first non-null: " + func(Seq(null, null, "x", "y", "z")))
def func[T](seq: Seq[T], index: Int = 0): T = if (seq.isEmpty) null.asInstanceOf[T] else
Option(seq(index)) getOrElse func(seq, index + 1)
}
object WhileLoop extends App {
println("first non-null: " + func(Seq(null, null, "x", "y", "z")))
def func[T](seq: Seq[T]): T = if (seq.isEmpty) null.asInstanceOf[T] else {
var i = 0
def obj = seq(i)
while (obj == null)
i += 1
obj
}
}
objects.iterator.filter { obj => expensiveFunc(obj) != null }.next()
The trick is to get a lazily evaluated view on the collection, either an iterator, a Stream, or objects.view. The filter will only execute as far as needed.