Scala function call without "." (dot) vs using "." (dot)

Can someone help me understand what is going on here. I have this definition for generating primes:
def primes: Stream[Long] = {
  2 #:: 3 #:: 5 #:: 7 #:: Stream.iterate(11L)(_ + 2).filter {
    n => primes takeWhile (p => p * p <= n) forall (n % _ != 0)
  }
}

def primes: Stream[Long] = {
  2 #:: 3 #:: 5 #:: 7 #:: Stream.iterate(11L)(_ + 2) filter {
    n => primes takeWhile (p => p * p <= n) forall (n % _ != 0)
  }
}
As you can see, both definitions are identical except that the second one has no . before filter, whereas the first one does.
The problem is that the first one runs as expected and gives us primes, but the second one throws a java.lang.StackOverflowError.
Can someone shed some light on this? What is being passed to filter in either case?
Scala version: 2.11.6
Java version: 1.8.0_121
This is the full program I used to test each one:
import scala.math.sqrt

object Main {
  def primes: Stream[Long] = {
    2 #:: 3 #:: 5 #:: 7 #:: Stream.iterate(11L)(_ + 2) filter {
      n => primes takeWhile (_ <= sqrt(n)) forall (n % _ != 0)
    }
  }

  def primes2: Stream[Long] = {
    2 #:: 3 #:: 5 #:: 7 #:: Stream.iterate(11L)(_ + 2).filter {
      n => primes2 takeWhile (p => p * p <= n) forall (n % _ != 0)
    }
  }

  def main(args: Array[String]): Unit = {
    println(primes.take(args.head.toInt).force)
  }
}

Dotless (infix) method calls whose names start with a letter, like filter, bind more loosely than symbolic operators such as #::. So the first version applies filter only to Stream.iterate(11L)(_ + 2), while the second applies it to all of 2 #:: 3 #:: 5 #:: 7 #:: Stream.iterate(11L)(_ + 2).
The reason the first one works is that the elements 2, 3, 5 and 7 sit in front of the filter, so when the filter runs, primes can already produce those elements for it to test against.
In the second version that's not the case, because the filter is applied to those elements as well, meaning they wouldn't appear in primes until the filter had returned true for them. But the filter needs to read elements from primes before it can return anything, so it descends into infinite recursion while trying to produce even the first element.
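A minimal sketch of the parse difference (withDot and withoutDot are hypothetical names, not from the question):
val s = Stream.iterate(11L)(_ + 2)        // 11, 13, 15, ...
val withDot    = 1L #:: s.filter(_ > 11)  // filter sees only s
val withoutDot = 1L #:: s filter (_ > 11) // parses as (1L #:: s) filter (_ > 11)
withDot.head    // 1  -- the prepended 1 was never filtered
withoutDot.head // 13 -- the prepended 1 was filtered out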

You need parentheses:
def primes: Stream[Long] = {
  2 #:: 3 #:: 5 #:: 7 #:: (Stream.iterate(11L)(_ + 2) filter {
    n => primes takeWhile (p => p * p <= n) forall (n % _ != 0)
  })
}
As a rule of thumb, I usually use dots everywhere. It's easier to read and makes this sort of bug less likely.

Convert normal recursion to tail recursion

I was wondering if there is some general method to convert a "normal" recursion with foo(...) + foo(...) as the last call to a tail-recursion.
For example (scala):
def pascal(c: Int, r: Int): Int = {
  if (c == 0 || c == r) 1
  else pascal(c - 1, r - 1) + pascal(c, r - 1)
}
A general solution for functional languages to convert a recursive function into a tail-call equivalent: a simple way is to wrap the non-tail-recursive function in the Trampoline monad.
// assumes scalaz 7's Trampoline is in scope (scalaz.Trampoline / scalaz.Free.Trampoline)
def pascalM(c: Int, r: Int): Trampoline[Int] = {
  if (c == 0 || c == r) Trampoline.done(1)
  else for {
    a <- Trampoline.suspend(pascalM(c - 1, r - 1))
    b <- Trampoline.suspend(pascalM(c, r - 1))
  } yield a + b
}

val result = pascalM(5, 10).run // column 5 of row 10 = 252
So the pascal function is not a recursive function anymore. However, the Trampoline monad is a nested structure of the computation that need to be done. Finally, run is a tail-recursive function that walks through the tree-like structure, interpreting it, and finally at the base case returns the value.
A paper by Rúnar Bjarnason on the subject of trampolines: Stackless Scala With Free Monads.
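If you'd rather avoid the Scalaz dependency, the standard library ships the same idea in scala.util.control.TailCalls; a minimal sketch under that assumption (pascalT is a hypothetical name):
import scala.util.control.TailCalls._

def pascalT(c: Int, r: Int): TailRec[Int] =
  if (c == 0 || c == r) done(1)
  else for {
    a <- tailcall(pascalT(c - 1, r - 1))
    b <- tailcall(pascalT(c, r - 1))
  } yield a + b

pascalT(5, 10).result // column 5 of row 10 = 252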
In cases where there is a simple modification to the value of a recursive call, that operation can be moved to the front of the recursive function. The classic example of this is Tail recursion modulo cons, where a simple recursive function in this form:
def recur[A](...): List[A] = {
  ...
  x :: recur(...)
}
which is not tail recursive, is transformed into
def recur[A](...): List[A] = {
  def consRecur(..., consA: A): List[A] = {
    consA :: ...
    ...
    consRecur(..., ...)
  }
  ...
  consRecur(..., ...)
}
Alexlv's example is a variant of this.
This is such a well-known situation that some compilers can detect simple cases and perform this optimisation automatically (I know of Prolog and Scheme examples, but Scalac does not do this).
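For concreteness, here's a sketch of that transformation for a hypothetical mapList (since immutable Scala has no mutable cons cell, the usual stand-in for true TMRC is an accumulator plus a final reverse):
def mapList[A, B](xs: List[A])(f: A => B): List[B] = xs match {
  case Nil    => Nil
  case h :: t => f(h) :: mapList(t)(f) // cons happens after the recursive call: not tail recursive
}

def mapListTR[A, B](xs: List[A])(f: A => B): List[B] = {
  @annotation.tailrec
  def loop(ys: List[A], acc: List[B]): List[B] = ys match {
    case Nil    => acc.reverse          // restore the original order at the end
    case h :: t => loop(t, f(h) :: acc) // cons now happens before the tail call
  }
  loop(xs, Nil)
}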
Problems combining multiple calls to recursive functions have no such simple solution. TMRC optimisation is useless, since you would simply be moving the first recursive call to another non-tail position. The only way to reach a tail-recursive solution is to remove all but one of the recursive calls; how to do this is entirely context dependent, and requires finding an entirely different approach to solving the problem.
As it happens, in some ways your example is similar to the classic Fibonacci sequence problem; in that case the naive but elegant doubly-recursive solution can be replaced by one which loops forward from the 0th number.
def fib(n: Long): Long = n match {
  case 0 | 1 => n
  case _ => fib(n - 2) + fib(n - 1)
}

def fib(n: Long): Long = {
  def loop(current: Long, next: => Long, iteration: Long): Long = {
    if (n == iteration) current
    else loop(next, current + next, iteration + 1)
  }
  loop(0, 1, 0)
}
For the Fibonacci sequence, this is the most efficient approach (a stream-based solution is just a different expression of this one that can cache results for subsequent calls). Now, you can also solve your problem by looping forward from c0/r0 (well, c0/r2) and calculating each row in sequence; the difference is that you need to cache the entire previous row. So while this has a similarity to fib, it differs dramatically in the specifics and is also significantly less efficient than your original, doubly-recursive solution.
Here's an approach for your Pascal's triangle example which can calculate pascal(30, 60) efficiently:
def pascal(column: Long, row: Long): Long = {
  type Point = (Long, Long)
  type Points = List[Point]
  type Triangle = Map[Point, Long]

  def above(p: Point) = (p._1, p._2 - 1)
  def aboveLeft(p: Point) = (p._1 - 1, p._2 - 1)

  def find(ps: Points, t: Triangle): Long = ps match {
    // Found the ultimate goal
    case (p :: Nil) if t contains p => t(p)
    // Found an intermediate point: pop the stack and carry on
    case (p :: rest) if t contains p => find(rest, t)
    // Hit a triangle edge, add it to the triangle
    case ((c, r) :: _) if (c == 0) || (c == r) => find(ps, t + ((c, r) -> 1))
    // Triangle contains (c - 1, r - 1)...
    case (p :: _) if t contains aboveLeft(p) =>
      if (t contains above(p))
        // And it contains (c, r - 1)! Add to the triangle
        find(ps, t + (p -> (t(aboveLeft(p)) + t(above(p)))))
      else
        // Does not contain (c, r - 1). So find that
        find(above(p) :: ps, t)
    // If we get here, we don't have (c - 1, r - 1). Find that.
    case (p :: _) => find(aboveLeft(p) :: ps, t)
  }

  require(column >= 0 && row >= 0 && column <= row)
  (column, row) match {
    case (c, r) if (c == 0) || (c == r) => 1
    case p => find(List(p), Map())
  }
}
It's efficient, but I think it shows how ugly complex recursive solutions can become as you deform them to become tail recursive. At this point, it may be worth moving to a different model entirely. Continuations or monadic gymnastics might be better.
You want a generic way to transform your function. There isn't one. There are helpful approaches, that's all.
I don't know how theoretical this question is, but a recursive implementation won't be efficient even with tail-recursion. Try computing pascal(30, 60), for example. I don't think you'll get a stack overflow, but be prepared to take a long coffee break.
Instead, consider using a Stream or memoization:
val pascal: Stream[Stream[Long]] =
  (Stream(1L)
    #:: (Stream from 1 map { i =>
      // compute row i
      (1L
        #:: (pascal(i - 1)  // take the previous row
          sliding 2         // and add adjacent values pairwise
          collect { case Stream(a, b) => a + b }).toStream
        ++ Stream(1L))
    }))
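A usage sketch, assuming the definition above (both row and column are 0-based):
pascal(4).toList // List(1, 4, 6, 4, 1)
pascal(60)(30)   // the pascal(30, 60) value mentioned above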
The accumulator approach
def pascal(c: Int, r: Int): Int = {
  def pascalAcc(acc: Int, leftover: List[(Int, Int)]): Int = {
    if (leftover.isEmpty) acc
    else {
      val (c1, r1) = leftover.head
      // Edge.
      if (c1 == 0 || c1 == r1) pascalAcc(acc + 1, leftover.tail)
      // Safety checks.
      else if (c1 < 0 || r1 < 0 || c1 > r1) pascalAcc(acc, leftover.tail)
      // Add the two other points to the accumulator.
      else pascalAcc(acc, (c1, r1 - 1) :: ((c1 - 1, r1 - 1) :: leftover.tail))
    }
  }
  pascalAcc(0, List((c, r)))
}
It does not overflow the stack, but as Aaron mentioned, it's not fast for large rows and columns.
Yes, it's possible. Usually it's done with the accumulator pattern, through some internally defined function which takes one additional argument, the so-called accumulator. Take counting the length of a list as an example; the normal recursive version would look like this:
def length[A](xs: List[A]): Int = if (xs.isEmpty) 0 else 1 + length(xs.tail)
That's not a tail-recursive version; in order to eliminate the last addition operation we have to accumulate the value along the way somehow, for example with the accumulator pattern:
def length[A](xs: List[A]) = {
  def inner(ys: List[A], acc: Int): Int = {
    if (ys.isEmpty) acc else inner(ys.tail, acc + 1)
  }
  inner(xs, 0)
}
It's a bit longer to code, but I think the idea is clear. Of course you can do it without an inner function, but in that case you have to provide the initial value of acc manually.
I'm pretty sure it's not possible for the general case in the simple way you're looking for, but it depends on how elaborate you permit the changes to be.
A tail-recursive function must be rewritable as a while loop, but try implementing, for example, a fractal tree using while loops. It's possible, but you need to use an array or collection to store the state for each point, which substitutes for the data otherwise stored in the call stack.
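A sketch of that idea with a hypothetical Tree type, where an explicit stack plays the role of the call stack:
sealed trait Tree
case class Leaf(value: Int) extends Tree
case class Node(left: Tree, right: Tree) extends Tree

def sumLeaves(root: Tree): Int = {
  var stack: List[Tree] = List(root) // substitutes for the call stack
  var acc = 0
  while (stack.nonEmpty) {
    val current = stack.head
    stack = stack.tail
    current match {
      case Leaf(v)    => acc += v
      case Node(l, r) => stack = l :: r :: stack
    }
  }
  acc
}

sumLeaves(Node(Leaf(1), Node(Leaf(2), Leaf(3)))) // 6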
It's also possible to use trampolining.
It is indeed possible. The way I'd do this is to begin with List(1) and keep recursing till you get to the row you want.
It's worth noticing that you can optimize it: if c == 0 or c == r the value is one, and to calculate, say, column 3 of the 100th row you only need the first three elements of each previous row.
A working tail-recursive solution would be this:
import scala.annotation.tailrec

def pascal(c: Int, r: Int): Int = {
  @tailrec
  def pascalAcc(c: Int, r: Int, acc: List[Int]): List[Int] = {
    if (r == 0) acc
    else pascalAcc(c, r - 1,
      // from, say, 1 3 3 1 builds 0 1 3 3 1 0, takes only the
      // subset that matters (if asking for col c, no cols after c are
      // used) and uses sliding to build (0 1) (1 3) (3 3) etc.
      (0 +: acc :+ 0).take(c + 2)
        .sliding(2, 1).map { x => x.reduce(_ + _) }.toList)
  }

  if (c == 0 || c == r) 1
  else pascalAcc(c, r, List(1))(c)
}
The @tailrec annotation makes the compiler check that the function really is tail recursive.
It could probably be optimized further: since the rows are symmetric, if c > r/2 then pascal(c, r) == pascal(r - c, r). But that's left to the reader ;)

Scala - can a 'for-yield' clause yield nothing for some condition?

In Scala, I want to write a function that yields the odd numbers within a given range. The function should print a log line when it skips an even number. The first version of the function is:
import scala.collection.mutable

def getOdds(N: Int): Traversable[Int] = {
  val list = new mutable.MutableList[Int]
  for (n <- 0 until N) {
    if (n % 2 == 1) {
      list += n
    } else {
      println("skip even number " + n)
    }
  }
  list
}
If I omit the logging, the implementation becomes very simple:
def getOddsWithoutPrint(N: Int) =
  for (n <- 0 until N if (n % 2 == 1)) yield n
However, I don't want to lose the logging part. How do I rewrite the first version more compactly? It would be great if it could be rewritten to something like this:
def IWantToDoSomethingSimilar(N: Int) =
  for (n <- 0 until N) if (n % 2 == 1) yield n else println("skip even number " + n)
def IWantToDoSomethingSimilar(N: Int) =
  for {
    n <- 0 until N
    if n % 2 != 0 || { println("skip even number " + n); false }
  } yield n
Using filter instead of a for expression would be slightly simpler though.
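For example, a sketch of that filter-based variant:
def IWantToDoSomethingSimilar(N: Int) =
  (0 until N) filter { n =>
    val keep = n % 2 == 1
    if (!keep) println("skip even number " + n)
    keep
  }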
If you want to keep your processing sequential (handling odds and evens in order, not separately), you can use something like this (edited):
def IWantToDoSomethingSimilar(N: Int) =
  (for (n <- 0 until N) yield {
    if (n % 2 == 1) {
      Option(n)
    } else {
      println("skip even number " + n)
      None
    }
    // Flatten transforms the Seq[Option[Int]] into a Seq[Int]
  }).flatten
EDIT, following the same concept, a shorter solution :
def IWantToDoSomethingSimilar(N: Int) =
  (0 until N) map {
    case n if n % 2 == 0 => println("skip even number " + n)
    case n => n
  } collect { case i: Int => i }
If you want to dig into a functional approach, something like the following is a good place to start.
First some common definitions:
// use scalaz 7
import scalaz._, Scalaz._

// transforms a function returning either E or B into a
// function returning an optional B and optionally writing a log of type E
def logged[A, E, B, F[_]](f: A => E \/ B)(
    implicit FM: Monoid[F[E]], FP: Pointed[F]): (A => Writer[F[E], Option[B]]) =
  (a: A) => f(a).fold(
    e => Writer(FP.point(e), None),
    b => Writer(FM.zero, Some(b)))

// helper for fixing the log storage format to List
def listLogged[A, E, B](f: A => E \/ B) = logged[A, E, B, List](f)

// shorthand for a String logger with List storage
type W[+A] = Writer[List[String], A]
Now all you have to do is write your filtering function:
def keepOdd(n: Int): String \/ Int =
  if (n % 2 == 1) \/.right(n) else \/.left(n + " was even")
You can try it instantly:
scala> List(5, 6) map(keepOdd)
res0: List[scalaz.\/[String,Int]] = List(\/-(5), -\/(6 was even))
Then you can use the traverse function to apply your function to a list of inputs, and collect both the logs written and the results:
scala> val x = List(5, 6).traverse[W, Option[Int]](listLogged(keepOdd))
x: W[List[Option[Int]]] = scalaz.WriterTFunctions$$anon$26#503d0400
// unwrap the results
scala> x.run
res11: (List[String], List[Option[Int]]) = (List(6 was even),List(Some(5), None))
// we may even drop the None-s from the output
scala> val (logs, results) = x.map(_.flatten).run
logs: List[String] = List(6 was even)
results: List[Int] = List(5)
I don't think this can be done easily with a for comprehension. But you could use partition.
def getOdds(N: Int) = {
  val (evens, odds) = 0 until N partition { x => x % 2 == 0 }
  evens foreach { x => println("skipping " + x) }
  odds
}
EDIT: To avoid printing the log messages after the partitioning is done, you can change the first line of the method like this:
val (evens, odds) = (0 until N).view.partition { x => x % 2 == 0 }

How to fix my Fibonacci stream in Scala

I defined a function to return Fibonacci stream as follows:
def fib: Stream[Int] = {
  Stream.cons(1,
    Stream.cons(2,
      (fib zip fib.tail) map { case (x, y) => println("%s + %s".format(x, y)); x + y }))
}
The function works OK, but it looks inefficient (see the output below):
scala> fib take 5 foreach println
1
2
1 + 2
3
1 + 2
2 + 3
5
1 + 2
1 + 2
2 + 3
3 + 5
8
So it looks like the function calculates the n-th Fibonacci number from the very beginning each time. Is that correct? How would you fix it?
That is because you have used a def. Try using a val:
lazy val fib: Stream[Int] =
  1 #:: 2 #:: (fib zip fib.tail map { case (x, y) => x + y })
Basically a def is a method; in your example you are calling the method each time, and each call constructs a new stream. The distinction between def and val has been covered on SO before, so I won't go into detail here. If you come from a Java background, it should be pretty clear.
This is another nice thing about Scala: in Java, methods may be recursive but values may not be; in Scala, both values and types can be recursive.
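A quick sketch that makes the difference visible (s and t are hypothetical names):
def s: Stream[Int] = 1 #:: s      // a method: each call builds a fresh cons cell
lazy val t: Stream[Int] = 1 #:: t // a value: one self-referencing, memoized cell

s eq s // false -- two distinct streams, so computed elements are never shared
t eq t // true  -- the same stream, so computed elements are reused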
You can do it the other way:
lazy val fibs = {
  def f(a: Int, b: Int): Stream[Int] = a #:: f(b, a + b)
  f(0, 1)
}
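For example:
fibs.take(8).toList // List(0, 1, 1, 2, 3, 5, 8, 13)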

Pattern matching and infinite streams

So, I'm working to teach myself Scala, and one of the things I've been playing with is the Stream class. I tried to use a naïve translation of the classic Haskell version of Dijkstra's solution to the Hamming number problem:
object LazyHammingBad {
  private def merge(a: Stream[BigInt], b: Stream[BigInt]): Stream[BigInt] =
    (a, b) match {
      case (x #:: xs, y #:: ys) =>
        if (x < y) x #:: merge(xs, b)
        else if (y < x) y #:: merge(a, ys)
        else x #:: merge(xs, ys)
    }

  val numbers: Stream[BigInt] =
    1 #:: merge(numbers map { _ * 2 },
      merge(numbers map { _ * 3 }, numbers map { _ * 5 }))
}
Taking this for a spin in the interpreter led quickly to disappointment:
scala> LazyHammingBad.numbers.take(10).toList
java.lang.StackOverflowError
I decided to see whether other people had solved the problem in Scala using the Haskell approach, and adapted this solution from Rosetta Code:
object LazyHammingGood {
  private def merge(a: Stream[BigInt], b: Stream[BigInt]): Stream[BigInt] =
    if (a.head < b.head) a.head #:: merge(a.tail, b)
    else if (b.head < a.head) b.head #:: merge(a, b.tail)
    else a.head #:: merge(a.tail, b.tail)

  val numbers: Stream[BigInt] =
    1 #:: merge(numbers map { _ * 2 },
      merge(numbers map { _ * 3 }, numbers map { _ * 5 }))
}
This one worked nicely, but I still wonder how I went wrong in LazyHammingBad. Does using #:: to destructure x #:: xs force the evaluation of xs for some reason? Is there any way to use pattern matching safely with infinite streams, or do you just have to use head and tail if you don't want things to blow up?
a match { case x #:: xs => ... } is about the same as val (x, xs) = (a.head, a.tail). So the difference between the bad version and the good one is that the bad version calls a.tail and b.tail right at the start, instead of just using them to build the tail of the resulting stream. Furthermore, when you use them to the right of #:: (not pattern matching, but building the result, as in #:: merge(a, b.tail)), you are not actually calling merge; that happens only later, when the tail of the returned Stream is accessed. So in the good version, a call to merge does not call tail at all; in the bad version, it calls it right at the start.
Now if you consider numbers, or even a simplified version, say 1 #:: merge(numbers, anotherStream): when you call tail on that (as take(10) will), merge has to be evaluated. That calls tail on numbers, which calls merge with numbers as a parameter, which calls tail on numbers, which calls merge, which calls tail...
By contrast, in super-lazy Haskell, pattern matching does barely any work. When you write case l of (x:xs), it evaluates l just enough to know whether it is an empty list or a cons.
If it is indeed a cons, x and xs are available as two thunks that will eventually give access to the content later. The closest equivalent in Scala would be to just test isEmpty.
Note also that in Scala's Stream, while the tail is lazy, the head is not. When you have a (non-empty) Stream, its head has to be known. This means that when you take the tail of a stream, itself a stream, its head, i.e. the second element of the original stream, has to be computed. This is sometimes problematic, but in your example you fail before even getting there.
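You can observe the forcing directly (a sketch; s is a hypothetical stream):
val s = 1 #:: { println("tail forced!"); Stream(2) }
s.head                         // 1; prints nothing, the tail is still a thunk
s match { case x #:: xs => x } // prints "tail forced!" -- the extractor calls s.tail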
Note that you can do what you want by defining a better pattern matcher for Stream. Here's a bit I just pulled together in a Scala Worksheet:
object HammingTest {
  // A convenience object for stream pattern matching
  object #:: {
    class TailWrapper[+A](s: Stream[A]) {
      def unwrap = s.tail
    }
    object TailWrapper {
      implicit def unwrap[A](wrapped: TailWrapper[A]) = wrapped.unwrap
    }
    def unapply[A](s: Stream[A]): Option[(A, TailWrapper[A])] = {
      if (s.isEmpty) None
      else Some(s.head, new TailWrapper(s))
    }
  }

  def merge(a: Stream[BigInt], b: Stream[BigInt]): Stream[BigInt] =
    (a, b) match {
      case (x #:: xs, y #:: ys) =>
        if (x < y) x #:: merge(xs, b)
        else if (y < x) y #:: merge(a, ys)
        else x #:: merge(xs, ys)
    }                         //> merge: (a: Stream[BigInt], b: Stream[BigInt])Stream[BigInt]

  lazy val numbers: Stream[BigInt] =
    1 #:: merge(numbers map { _ * 2 }, merge(numbers map { _ * 3 }, numbers map { _ * 5 }))
                              //> numbers : Stream[BigInt] = <lazy>

  numbers.take(10).toList     //> res0: List[BigInt] = List(1, 2, 3, 4, 5, 6, 8, 9, 10, 12)
}
Now you just need to make sure that Scala finds your object #:: instead of the one in Stream.class whenever it's doing pattern matching. To facilitate that, it might be best to use a different name like #>: or ##:: and then just remember to always use that name when pattern matching.
If you ever need to match the empty stream, use case Stream.Empty. Using case Stream() will attempt to evaluate your entire stream there in the pattern match, which will lead to sadness.

Converting a sequence of map operations to a for-comprehension

I read in Programming in Scala section 23.5 that map, flatMap and filter operations can always be converted into for-comprehensions and vice-versa.
We're given the following equivalence:
def map[A, B](xs: List[A], f: A => B): List[B] =
  for (x <- xs) yield f(x)
I have a value calculated from a series of map operations:
val r = (1 to 100).map { i => (1 to 100).map { i % _ == 0 } }
  .map { _.foldLeft(false)(_ ^ _) }
  .map { case true => "open"; case _ => "closed" }
I'm wondering what this would look like as a for-comprehension. How do I translate it?
(If it's helpful, in words this is:
take integers from 1 to 100
for each, create a list of 100 boolean values
fold each list with an XOR operator, back into a boolean
yield a list of 100 Strings "open" or "closed" depending on the boolean
I imagine there is a standard way to translate map operations and the details of the actual functions in them is not important. I could be wrong though.)
Is this the kind of translation you're looking for?
for {
  i <- 1 to 100
  x = (1 to 100).map(i % _ == 0)
  y = x.foldLeft(false)(_ ^ _)
  z = y match { case true => "open"; case _ => "closed" }
} yield z
If desired, the map in the definition of x could also be translated to an "inner" for-comprehension.
In retrospect, a series of chained map calls is sort of trivial, in that you could equivalently call map once with composed functions:
s.map(f).map(g).map(h) == s.map(f andThen g andThen h)
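Applied to the question's pipeline, the three maps collapse into one (a sketch; divisible and open are just illustrative names):
val r = (1 to 100).map { i =>
  val divisible = (1 to 100).map(i % _ == 0)
  val open = divisible.foldLeft(false)(_ ^ _)
  if (open) "open" else "closed"
}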
I find for-comprehensions to be a bigger win when flatMap and filter are involved. Consider
for (i <- 1 to 3;
     j <- 1 to 3 if (i + j) % 2 == 0;
     k <- 1 to 3) yield i ^ j ^ k
versus
(1 to 3).flatMap { i =>
  (1 to 3).filter(j => (i + j) % 2 == 0).flatMap { j =>
    (1 to 3).map { k => i ^ j ^ k }
  }
}