For comprehension and number of function creation - scala

Recently I had an interview for a Scala developer position. I was asked the following question:
// matrix 100x100 (content unimportant)
val matrix = Seq.tabulate(100, 100) { case (x, y) => x + y }
// A
for {
row <- matrix
elem <- row
} print(elem)
// B
val func = print _
for {
row <- matrix
elem <- row
} func(elem)
and the question was: Which implementation, A or B, is more efficient?
We all know that for comprehensions can be translated to
// A
matrix.foreach(row => row.foreach(elem => print(elem)))
// B
matrix.foreach(row => row.foreach(func))
B can be written as matrix.foreach(row => row.foreach(print _))
The supposedly correct answer is B, because A will create the function for print 100 times more often.
I have checked Language Specification but still fail to understand the answer. Can somebody explain this to me?

In short:
Example A is faster in theory, in practice you shouldn't be able to measure any difference though.
Long answer:
As you already found out
for {xs <- xxs; x <- xs} f(x)
is translated to
xxs.foreach(xs => xs.foreach(x => f(x)))
This is explained in §6.19 SLS:
A for loop
for ( p <- e; p' <- e' ... ) e''
where ... is a (possibly empty) sequence of generators, definitions, or guards, is translated to
e .foreach { case p => for ( p' <- e' ... ) e'' }
Now when one writes a function literal, one gets a new instance every time the literal is evaluated (§6.23 SLS). This means that
xs.foreach(x => f(x))
is equivalent to
xs.foreach(new scala.Function1 { def apply(x: T) = f(x)})
When you introduce a local function value
val g = f _; xxs.foreach(xs => xs.foreach(x => g(x)))
you are not introducing an optimization because you still pass a function literal to foreach. In fact the code is slower because the inner foreach is translated to
xs.foreach(new scala.Function1 { def apply(x: T) = g.apply(x) })
where an additional call to the apply method of g happens. Though, you can optimize when you write
val g = f _; xxs.foreach(xs => xs.foreach(g))
because the inner foreach now is translated to
xs.foreach(g)
which means that the function g itself is passed to foreach.
This would mean that B is faster in theory, because no anonymous function needs to be created each time the body of the for comprehension is executed. However, the optimization mentioned above (passing the function directly to foreach) is not applied to for comprehensions: as the spec says, the translation includes the creation of function literals, so unnecessary function objects are always created. (The compiler could optimize that as well, but it doesn't, because optimizing for comprehensions is difficult and still does not happen in 2.11.) All in all, A is more efficient, but B would be more efficient if it were written without a for comprehension (and without a function literal for the innermost function).
Nevertheless, all of these rules can only be applied in theory, because in practice there is the backend of scalac and the JVM itself which both can do optimizations - not to mention optimizations that are done by the CPU. Furthermore your example contains a syscall that is executed on every iteration - it is probably the most expensive operation here that outweighs everything else.
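To make the forms discussed above concrete, here is a minimal sketch that runs the for loop, its nested-foreach translation, and the hand-written variant that passes one function value directly. The object and method names are made up for illustration, and the matrix is shrunk to 3x3. It only checks that all three perform the same calls in the same order; it says nothing about speed.

```scala
import scala.collection.mutable.ArrayBuffer

object ForDesugarDemo {
  val matrix: Seq[Seq[Int]] = Seq.tabulate(3, 3) { case (x, y) => x + y }

  // Variant A: for loop with the call written directly in the body
  def viaFor(): List[Int] = {
    val out = ArrayBuffer.empty[Int]
    for { row <- matrix; elem <- row } out += elem
    out.toList
  }

  // The nested foreach with function literals that the for loop is translated to
  def viaForeachLiteral(): List[Int] = {
    val out = ArrayBuffer.empty[Int]
    matrix.foreach(row => row.foreach(elem => out += elem))
    out.toList
  }

  // Hand-written form: one function value, created once and passed directly
  def viaFunctionValue(): List[Int] = {
    val out = ArrayBuffer.empty[Int]
    val g: Int => Unit = x => out += x
    matrix.foreach(row => row.foreach(g))
    out.toList
  }
}
```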

I'd agree with sschaef and say that A is the more efficient option.
Looking at the generated class files we get the following anonymous functions and their apply methods:
anonfun$2 -- row => row.foreach(new anonfun$2$$anonfun$1)
anonfun$2$$anonfun$1 -- elem => print(elem)
i.e. matrix.foreach(row => row.foreach(elem => print(elem)))
anonfun$3 -- x => print(x)
anonfun$4 -- row => row.foreach(new anonfun$4$$anonfun$2)
anonfun$4$$anonfun$2 -- elem => func(elem)
i.e. matrix.foreach(row => row.foreach(elem => func(elem)))
where func is just another indirection before calling to print. In addition func needs to be looked up, i.e. through a method call on an instance (this.func()) for each row.
So for Method B, one extra object is created (func) and there is one additional function call per element.
The most efficient option would be
matrix.foreach(row => row.foreach(func))
as this has the least number of objects created and does exactly as you would expect.

Method A is nearly 30% faster than method B.
Implementation Details
I added method C (two while loops) and D (fold, sum). I also increased the size of the matrix and used an IndexedSeq instead. Also I replaced the print with something less heavy (sum all entries).
Strangely the while construct is not the fastest. But if one uses Array instead of IndexedSeq it becomes the fastest by a large margin (factor 5, no boxing anymore). Using explicitly boxed integers, methods A, B, C are all equally fast. In particular they are faster by 50% compared to the implicitly boxed versions of A, B.
verify results
>$ scala -version
Scala code runner version 2.11.0 -- Copyright 2002-2013, LAMP/EPFL
Code excerpt
val matrix = IndexedSeq.tabulate(1000, 1000) { case (x, y) => x + y }

// A: for loop with the summation directly in the body
def variantA(): Int = {
  var r = 0
  for {
    row <- matrix
    elem <- row
  } r += elem
  r
}

// B: same loop, but the body calls a function value
def variantB(): Int = {
  var r = 0
  val f = (x: Int) => r += x
  for {
    row <- matrix
    elem <- row
  } f(elem)
  r
}

// C: two nested while loops
def variantC(): Int = {
  var r = 0
  var i1 = 0
  while (i1 < matrix.size) {
    var i2 = 0
    val row = matrix(i1)
    while (i2 < row.size) {
      r += row(i2)
      i2 += 1
    }
    i1 += 1
  }
  r
}

// D: fold over the rows, summing each one
def variantD(): Int = matrix.foldLeft(0)(_ + _.sum)


How to count the number of iterations in a for comprehension in Scala?

I am using a for comprehension on a stream and I would like to know how many iterations it took to get to the final results.
In code:
var count = 0
for {
xs <- xs_generator
x <- xs
count = count + 1 //doesn't work!!
if (x prop)
yield x
Is there a way to achieve this?
Edit: If you don't want to return only the first item, but the entire stream of solutions, take a look at the second part.
Edit-2: Shorter version with zipWithIndex appended.
It's not entirely clear what you are attempting to do. To me it seems as if you are trying to find something in a stream of lists, and additionally save the number of checked elements.
If this is what you want, consider doing something like this:
/** Returns `x` that satisfies predicate `prop`
  * as well as the total number of tested `x`s
  */
def findTheX(): (Int, Int) = {
  val xs_generator = Stream.from(1).map(a => (1 to a).toList).take(1000)
  var count = 0
  def prop(x: Int): Boolean = x % 317 == 0
  for (xs <- xs_generator; x <- xs) {
    count += 1
    if (prop(x)) {
      return (x, count)
    }
  }
  throw new Exception("No solution exists")
}

println(findTheX())
// prints:
// (317,50403)
Several important points:
Scala's for-comprehensions have nothing to do with Python's "yield". Just in case you thought they did: re-read the documentation on for-comprehensions.
There is no built-in syntax for breaking out of for-comprehensions. It's better to wrap it into a function, and then call return. There is also breakable though, but it works with Exceptions.
The function returns the found item and the total count of checked items, therefore the return type is (Int, Int).
The throw at the end, after the for-comprehension, ensures that the last expression has type Nothing <: (Int, Int) instead of Unit, which is not a subtype of (Int, Int).
Think twice when you want to use Stream for such purposes in this way: after generating the first few elements, the Stream holds them in memory. This might lead to "GC-overhead limit exceeded"-errors if the Stream isn't used properly.
Just to emphasize it again: the yield in Scala for-comprehensions is unrelated to Python's yield. Scala has no built-in support for coroutines and generators. You don't need them as often as you might think, but it requires some readjustment.
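The breakable alternative mentioned above can be sketched as follows. This is only an illustration under a few assumptions: the object name, the limit parameter, and the use of an Iterator instead of a Stream are my own choices, and the predicate is the same x % 317 == 0 used in this answer.

```scala
import scala.util.control.Breaks.{break, breakable}

object BreakableSearch {
  def prop(x: Int): Boolean = x % 317 == 0

  /** Returns the first match (if any) and the number of elements tested. */
  def findFirst(limit: Int): (Option[Int], Int) = {
    var count = 0
    var found = Option.empty[Int]
    breakable {
      // iterate over 1, then 1..2, then 1..3, and so on, up to 1..limit
      for (a <- Iterator.range(1, limit + 1); x <- 1 to a) {
        count += 1
        if (prop(x)) { found = Some(x); break() }
      }
    }
    (found, count)
  }
}
```

With limit = 1000 this should give (Some(317), 50403), the same pair as the return-based version; with a limit below 317 no match is found and the count is simply the total number of elements visited.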
I've re-read your question again. In case that you want an entire stream of solutions together with a counter of how many different xs have been checked, you might use something like that instead:
val xs_generator = Stream.from(1).map(a => (1 to a).toList)
var count = 0
def prop(x: Int): Boolean = x % 317 == 0
val xsWithCounter = for {
xs <- xs_generator;
x <- xs
_ = { count = count + 1 }
if (prop(x))
} yield (x, count)
println(xsWithCounter.take(10).toList)
// prints:
// List(
// (317,50403), (317,50721), (317,51040), (317,51360), (317,51681),
// (317,52003), (317,52326), (317,52650), (317,52975), (317,53301)
// )
Note the _ = { ... } part. There is a limited number of things that can occur in a for-comprehension:
generators (the x <- things)
filters/guards (if-s)
value definitions
Here, we sort-of abuse the value-definition syntax to update the counter. We use the block { count = count + 1 } as the right hand side of the assignment. It returns Unit. Since we don't need the result of the block, we use _ as the left hand side of the assignment. In this way, this block is executed once for every x.
If mutating the counter is not your main goal, you can of course use the zipWithIndex directly:
val xsWithCounter =
xs_generator.flatten.zipWithIndex.filter{x => prop(x._1)}
It gives almost the same result as the previous version, but the indices are shifted by -1 (it's the indices, not the number of tried x-s).
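If the already-visited elements do not need to be kept, the same idea can be sketched with an Iterator, which also sidesteps the memory concern about stored Streams raised earlier. The names below are illustrative, and adding 1 to the index turns it into the number of elements tested so far.

```scala
object CountedSearch {
  def prop(x: Int): Boolean = x % 317 == 0

  // all candidates in order: 1, then 1,2, then 1,2,3, ...
  def candidates: Iterator[Int] =
    Iterator.from(1).flatMap(a => (1 to a).iterator)

  // pairs of (match, number of candidates tested so far)
  def matchesWithCount: Iterator[(Int, Int)] =
    candidates.zipWithIndex.collect { case (x, i) if prop(x) => (x, i + 1) }
}
```

Taking the first three results with CountedSearch.matchesWithCount.take(3).toList should reproduce the first three pairs listed above.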

Getting an error trying to map through a list in Scala

I'm trying to print out all the factors of every number in a list.
Here is my code:
def main(args: Array[String])
val list_of_numbers = List(1,4,6)
def get_factors(list_of_numbers:List[Int]) : Int =
return list_of_numbers.foreach{(1 to _).filter {divisor => _ % divisor == 0}}
I want the end result to contain a single list that will hold all the numbers which are factors of any of the numbers in the list. So the final result should be (1,2,3,4,6). Right now, I get the following error:
error: missing parameter type for expanded function ((x$1) =>$1))
return list_of_numbers.foreach{(1 to _).filter {divisor => _ % divisor == 0}}
How can I fix this?
You can only use _ shorthand once in a function (except for some special cases), and even then not always.
Try spelling it out instead:
list_of_numbers.foreach { n =>
  (1 to n).filter { divisor => n % divisor == 0 }
}
This will compile.
There are other problems with your code though.
For example, foreach returns Unit, but your method declares Int as its return type.
Perhaps, you wanted a .map rather than .foreach, but that would still be a List, not an Int.
A few things are wrong here.
First, foreach takes a function A => Unit as an argument, meaning that it's really just for causing side effects.
Second, your use of _: you can use _ only when the function uses each argument exactly once.
Lastly your expected output seems to be getting rid of duplicates (1 is a factor for all 3 inputs, but it only appears once).
list_of_numbers flatMap { i => (1 to i) filter {i % _ == 0 }} distinct
will do what you are looking for.
flatMap takes a function from A => List[B] and produces a simple List[B] as output, list.distinct gets rid of the duplicates.
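As a small worked sketch of the flatMap plus distinct approach (the object and method names are illustrative, not from the question):

```scala
object Factors {
  // all divisors of n, in ascending order
  def factorsOf(n: Int): List[Int] = (1 to n).filter(n % _ == 0).toList

  // factors of every number in the list, flattened, with duplicates removed
  def allFactors(numbers: List[Int]): List[Int] =
    numbers.flatMap(factorsOf).distinct
}
```

Note that distinct keeps the order of first occurrence, so Factors.allFactors(List(1, 4, 6)) gives List(1, 2, 4, 3, 6); append .sorted if the ascending order (1, 2, 3, 4, 6) matters.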
Actually, there are several problems with your code.
First, foreach is a method which yields Unit (like void in Java). You want to yield something so you should use a for comprehension.
Second, in your divisor-test function, you've specified both the unnamed parameter ("_") and the named parameter (divisor).
The third problem is that you expect the result to be Int (in the code) but List[Int] in your description.
The following code will do what you want (although it will repeat factors, so you might want to pass it through distinct before using the result):
def main(args: Array[String]) {
  val list_of_numbers = List(1, 4, 6)
  def get_factors(list_of_numbers: List[Int]) = for (n <- list_of_numbers; r = 1 to n; f <- r.filter(n%_ == 0)) yield f
}
Note that you need two generators ("<-") in the for comprehension in order that you end up with simply a List. If you instead implemented the filter part in the yield expression, you would get a List[List[Int]].
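The difference between the two shapes can be seen side by side in this sketch (names are illustrative):

```scala
object ForFactors {
  val numbers = List(1, 4, 6)

  // two generators: the result is flattened into a single List[Int]
  val flat: List[Int] =
    for (n <- numbers; f <- (1 to n).filter(n % _ == 0)) yield f

  // one generator with the filter in the yield: one inner list per input number
  val nested: List[List[Int]] =
    for (n <- numbers) yield (1 to n).filter(n % _ == 0).toList
}
```

Here flat is List(1, 1, 2, 4, 1, 2, 3, 6), with the repeated factors still present, while nested is List(List(1), List(1, 2, 4), List(1, 2, 3, 6)).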

Scala, Eratosthenes: Is there a straightforward way to replace a stream with an iteration?

I wrote a function that generates primes indefinitely (wikipedia: incremental sieve of Eratosthenes) using streams. It returns a stream, but it also merges streams of prime multiples internally to mark upcoming composites. The definition is concise, functional, elegant and easy to understand, if I do say so myself:
def primes(): Stream[Int] = {
  def merge(a: Stream[Int], b: Stream[Int]): Stream[Int] = {
    def next = a.head min b.head
    Stream.cons(next, merge(if (a.head == next) a.tail else a,
                            if (b.head == next) b.tail else b))
  }
  def test(n: Int, compositeStream: Stream[Int]): Stream[Int] = {
    if (n == compositeStream.head) test(n+1, compositeStream.tail)
    else Stream.cons(n, test(n+1, merge(compositeStream, Stream.from(n*n, n))))
  }
  test(2, Stream.from(4, 2))
}
But, I get a "java.lang.OutOfMemoryError: GC overhead limit exceeded" when I try to generate the 1000th prime.
I have an alternative solution that returns an iterator over primes and uses a priority queue of tuples (multiple, prime used to generate multiple) internally to mark upcoming composites. It works well, but it takes about twice as much code, and I basically had to restart from scratch:
import scala.collection.mutable.PriorityQueue
def primes(): Iterator[Int] = {
  // Tuple (composite, prime) is used to generate a prime's multiples
  object CompositeGeneratorOrdering extends Ordering[(Long, Int)] {
    def compare(a: (Long, Int), b: (Long, Int)) = b._1 compare a._1
  }
  var n = 2
  val composites = PriorityQueue(((n*n).toLong, n))(CompositeGeneratorOrdering)
  def advance = {
    while (n == composites.head._1) { // n is composite
      while (n == composites.head._1) { // duplicate composites
        val (multiple, prime) = composites.dequeue
        composites.enqueue((multiple + prime, prime))
      }
      n += 1
    }
    assert(n < composites.head._1)
    val prime = n
    n += 1
    composites.enqueue((prime.toLong * prime.toLong, prime))
    prime // the prime just found
  }
  Iterator.continually(advance)
}
Is there a straightforward way to translate the code with streams to code with iterators? Or is there a simple way to make my first attempt more memory efficient?
It's easier to think in terms of streams; I'd rather start that way, then tweak my code if necessary.
I guess it's a bug in current Stream implementation.
primes().drop(999).head works fine:
// Int = 7919
You'll get an OutOfMemoryError with a stored Stream like this:
val prs = primes()
prs.drop(999).head
// Exception in thread "main" java.lang.OutOfMemoryError: GC overhead limit exceeded
The problem here is the implementation of class Cons: it contains not only the calculated tail, but also the function that calculates this tail, even after the tail has been calculated and the function is no longer needed!
In this case the functions are extremely heavy, so you'll get an OutOfMemoryError even with only 1000 functions stored.
We have to drop those functions somehow.
The intuitive fix fails:
val prs = primes().iterator.toStream
prs.drop(999).head
// Exception in thread "main" java.lang.OutOfMemoryError: GC overhead limit exceeded
Calling iterator on a Stream gives you a StreamIterator, and StreamIterator#toStream gives you back the initial heavy Stream.
So we have to convert it manually:
def toNewStream[T](i: Iterator[T]): Stream[T] =
  if (i.hasNext) Stream.cons(i.next(), toNewStream(i))
  else Stream.empty
val prs = toNewStream(primes().iterator)
// Stream[Int] = Stream(2, ?)
prs.drop(999).head
// Int = 7919
In your first code, you should postpone the merging until the square of a prime is seen amongst the candidates. This will drastically reduce the number of streams in use, radically improving your memory usage issues. To get the 1000th prime, 7919, we only need to consider primes not above its square root, 88. That's just 23 primes/streams of their multiples, instead of 999 (22, if we ignore the evens from the outset). For the 10,000th prime, it's the difference between having 9999 streams of multiples and just 66. And for the 100,000th, only 189 are needed.
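The counts quoted above can be checked with a small sketch. It uses a plain array sieve (not the stream-based code) purely to do the arithmetic; the object name, method names and the sieveLimit parameter are my own, and sieveLimit just has to be large enough to contain the n-th prime.

```scala
object BasePrimes {
  // plain array sieve; returns all primes <= limit
  def primesUpTo(limit: Int): Vector[Int] = {
    val composite = new Array[Boolean](limit + 1)
    var i = 2
    while (i.toLong * i <= limit) {
      if (!composite(i)) {
        var j = i * i
        while (j <= limit) { composite(j) = true; j += i }
      }
      i += 1
    }
    (2 to limit).filterNot(k => composite(k)).toVector
  }

  // number of primes not above the square root of the n-th prime,
  // i.e. how many streams of multiples are needed with postponed merging
  def basePrimesNeeded(n: Int, sieveLimit: Int): Int = {
    val ps = primesUpTo(sieveLimit)
    require(ps.length >= n, "sieveLimit too small for the requested n")
    val nth = ps(n - 1)
    ps.count(p => p.toLong * p <= nth)
  }
}
```

For example, BasePrimes.basePrimesNeeded(1000, 10000) gives 23 and BasePrimes.basePrimesNeeded(10000, 200000) gives 66, matching the figures in the text.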
The trick is to separate the primes being consumed from the primes being produced, via a recursive invocation:
def primes(): Stream[Int] = {
  def merge(a: Stream[Int], b: Stream[Int]): Stream[Int] = {
    def next = a.head min b.head
    Stream.cons(next, merge(if (a.head == next) a.tail else a,
                            if (b.head == next) b.tail else b))
  }
  def test(n: Int, q: Int,
           compositeStream: Stream[Int],
           primesStream: Stream[Int]): Stream[Int] = {
    // n reached the square of the next base prime: only now start merging its multiples
    if (n == q) test(n+2, primesStream.tail.head*primesStream.tail.head,
                     merge(compositeStream,
                           Stream.from(q, 2*primesStream.head).tail),
                     primesStream.tail)
    else if (n == compositeStream.head) test(n+2, q, compositeStream.tail,
                                             primesStream)
    else Stream.cons(n, test(n+2, q, compositeStream, primesStream))
  }
  Stream.cons(2, Stream.cons(3, Stream.cons(5,
    test(7, 25, Stream.from(9, 6), primes().tail.tail))))
}
As an added bonus, there's no need to store the squares of primes as Longs. This will also be much faster and have better algorithmic complexity (time and space) as it avoids doing a lot of superfluous work. Ideone testing shows it runs at about ~ n^1.5..1.6 empirical orders of growth in producing up to n = 80,000 primes.
There's still an algorithmic problem here: the structure that is created here is still a linear left-leaning structure (((mults_of_2 + mults_of_3) + mults_of_5) + ...), with more frequently-producing streams situated deeper inside it (so the numbers have more levels to percolate through, going up). The right-leaning structure should be better, mults_of_2 + (mults_of_3 + (mults_of_5 + ...)). Making it a tree should bring a real improvement in time complexity (pushing it down typically to about ~ n^1.2..1.25). For a related discussion, see this haskellwiki page.
The "real" imperative sieve of Eratosthenes usually runs at around ~ n^1.1 (in n primes produced) and an optimal trial division sieve at ~ n^1.40..1.45. Your original code runs at about cubic time, or worse. Using an imperative mutable array is usually the fastest, working by segments (a.k.a. the segmented sieve of Eratosthenes).
In the context of your second code, this is how it is achieved in Python.
Is there a straightforward way to translate the code with streams to code with iterators? Or is there a simple way to make my first attempt more memory efficient?
@Will Ness has given you an improved answer using Streams and has explained why your code takes so much memory and time (it adds streams too early and builds a left-leaning linear structure), but no one has completely answered the second (or perhaps main) part of your question: whether a true incremental Sieve of Eratosthenes can be implemented with Iterators.
First, we should properly credit this right-leaning algorithm of which your first code is a crude (left-leaning) example (since it prematurely adds all prime composite streams to the merge operations), which is due to Richard Bird as in the Epilogue of Melissa E. O'Neill's definitive paper on incremental Sieve's of Eratosthenes.
Second, no, it isn't really possible to substitute Iterator's for Stream's in this algorithm as it depends on moving through a stream without restarting the stream, and although one can access the head of an iterator (the current position), using the next value (skipping over the head) to generate the rest of the iteration as a stream requires building a completely new iterator at a terrible cost in memory and time. However, we can use an Iterator to output the results of the sequence of primes in order to minimize memory use and make it easy to use iterator higher order functions, as you will see in my code below.
Now Will Ness has walked you through the principles of postponing adding prime composite streams to the calculations until they are needed, which works well when one is storing these in a structure such as a Priority Queue or a HashMap and was even missed in the O'Neill paper, but for the Richard Bird algorithm this is not necessary as future stream values will not be accessed until needed so are not stored if the Streams are being properly lazily built (as is lazily and left-leaning). In fact, this algorithm doesn't even need the memoization and overheads of a full Stream as each composite number culling sequence only moves forward without reference to any past primes other than one needs a separate source of the base primes, which can be supplied by a recursive call of the same algorithm.
For ready reference, let's list the Haskell code of the Richard Bird algorithms as follows:
primes = 2:([3..] `minus` composites)
composites = union [multiples p | p <- primes]
multiples n = map (n*) [n..]
(x:xs) `minus` (y:ys)
  | x < y = x:(xs `minus` (y:ys))
  | x == y = xs `minus` ys
  | x > y = (x:xs) `minus` ys
union = foldr merge []
merge (x:xs) ys = x:merge' xs ys
merge' (x:xs) (y:ys)
  | x < y = x:merge' xs (y:ys)
  | x == y = x:merge' xs ys
  | x > y = y:merge' (x:xs) ys
In the following code I have simplified the 'minus' function (called "minusStrtAt") as we don't need to build a completely new stream but can incorporate the composite subtraction operation with the generation of the original (in my case odds only) sequence. I have also simplified the "union" function (renaming it as "mrgMltpls")
The stream operations are implemented as a non memoizing generic Co Inductive Stream (CIS) as a generic class where the first field of the class is the value of the current position of the stream and the second is a thunk (a zero argument function that returns the next value of the stream through embedded closure arguments to another function).
def primes(): Iterator[Long] = {
  // generic class as a Co Inductive Stream element
  class CIS[A](val v: A, val cont: () => CIS[A])
  def mltpls(p: Long): CIS[Long] = {
    val px2 = p * 2
    def nxtmltpl(cmpst: Long): CIS[Long] =
      new CIS(cmpst, () => nxtmltpl(cmpst + px2))
    nxtmltpl(p * p)
  }
  def allMltpls(mps: CIS[Long]): CIS[CIS[Long]] =
    new CIS(mltpls(mps.v), () => allMltpls(mps.cont()))
  def merge(a: CIS[Long], b: CIS[Long]): CIS[Long] =
    if (a.v < b.v) new CIS(a.v, () => merge(a.cont(), b))
    else if (a.v > b.v) new CIS(b.v, () => merge(a, b.cont()))
    else new CIS(b.v, () => merge(a.cont(), b.cont()))
  def mrgMltpls(mlps: CIS[CIS[Long]]): CIS[Long] =
    new CIS(mlps.v.v, () => merge(mlps.v.cont(), mrgMltpls(mlps.cont())))
  def minusStrtAt(n: Long, cmpsts: CIS[Long]): CIS[Long] =
    if (n < cmpsts.v) new CIS(n, () => minusStrtAt(n + 2, cmpsts))
    else minusStrtAt(n + 2, cmpsts.cont())
  // the following are recursive, where cmpsts uses oddPrms and
  // oddPrms uses a delayed version of cmpsts in order to avoid a race
  // as oddPrms will already have a first value when cmpsts is called to generate the second
  def cmpsts(): CIS[Long] = mrgMltpls(allMltpls(oddPrms()))
  def oddPrms(): CIS[Long] = new CIS(3L, () => minusStrtAt(5L, cmpsts()))
  Iterator.iterate(new CIS(2L, () => oddPrms())) { (cis: CIS[Long]) => cis.cont() }
    .map { (cis: CIS[Long]) => cis.v }
}
The above code generates the 100,000th prime (1299709) on ideone in about 1.3 seconds with about a 0.36 second overhead and has an empirical computational complexity to 600,000 primes of about 1.43. The memory use is negligible above that used by the program code.
The above code could be implemented using the built-in Scala Streams, but there is a performance and memory use overhead (of a constant factor) that this algorithm does not require. Using Streams would mean that one could use them directly without the extra Iterator generation code, but as this is used only for final output of the sequence, it doesn't cost much.
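For readers unfamiliar with the CIS idea, here is a stripped-down sketch of just the cell class and the Iterator wrapper, applied to a trivial sequence of odd numbers instead of primes. The names CisDemo, oddsFrom and toIterator are invented for this illustration.

```scala
object CisDemo {
  // one cell of a non-memoizing stream: a value plus a thunk producing the next cell
  class CIS[A](val v: A, val cont: () => CIS[A])

  // odd numbers starting at n: n, n + 2, n + 4, ...
  def oddsFrom(n: Long): CIS[Long] = new CIS(n, () => oddsFrom(n + 2))

  // walk the cells with an Iterator, reading one value per step
  def toIterator[A](start: CIS[A]): Iterator[A] =
    Iterator.iterate(start)(_.cont()).map(_.v)
}
```

Because each call to cont() builds a fresh cell and nothing caches it, calling cont() twice on the same cell recomputes the successor; that is the trade-off against a memoizing Stream.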
To implement some basic tree folding as Will Ness has suggested, one only needs to add a "pairs" function and hook it into the "mrgMltpls" function:
def primes(): Iterator[Long] = {
  // generic class as a Co Inductive Stream element
  class CIS[A](val v: A, val cont: () => CIS[A])
  def mltpls(p: Long): CIS[Long] = {
    val px2 = p * 2
    def nxtmltpl(cmpst: Long): CIS[Long] =
      new CIS(cmpst, () => nxtmltpl(cmpst + px2))
    nxtmltpl(p * p)
  }
  def allMltpls(mps: CIS[Long]): CIS[CIS[Long]] =
    new CIS(mltpls(mps.v), () => allMltpls(mps.cont()))
  def merge(a: CIS[Long], b: CIS[Long]): CIS[Long] =
    if (a.v < b.v) new CIS(a.v, () => merge(a.cont(), b))
    else if (a.v > b.v) new CIS(b.v, () => merge(a, b.cont()))
    else new CIS(b.v, () => merge(a.cont(), b.cont()))
  // merge adjacent streams of multiples pairwise to build a tree-like fold
  def pairs(mltplss: CIS[CIS[Long]]): CIS[CIS[Long]] = {
    val tl = mltplss.cont()
    new CIS(merge(mltplss.v, tl.v), () => pairs(tl.cont()))
  }
  def mrgMltpls(mlps: CIS[CIS[Long]]): CIS[Long] =
    new CIS(mlps.v.v, () => merge(mlps.v.cont(), mrgMltpls(pairs(mlps.cont()))))
  def minusStrtAt(n: Long, cmpsts: CIS[Long]): CIS[Long] =
    if (n < cmpsts.v) new CIS(n, () => minusStrtAt(n + 2, cmpsts))
    else minusStrtAt(n + 2, cmpsts.cont())
  // the following are recursive, where cmpsts uses oddPrms and
  // oddPrms uses a delayed version of cmpsts in order to avoid a race
  // as oddPrms will already have a first value when cmpsts is called to generate the second
  def cmpsts(): CIS[Long] = mrgMltpls(allMltpls(oddPrms()))
  def oddPrms(): CIS[Long] = new CIS(3L, () => minusStrtAt(5L, cmpsts()))
  Iterator.iterate(new CIS(2L, () => oddPrms())) { (cis: CIS[Long]) => cis.cont() }
    .map { (cis: CIS[Long]) => cis.v }
}
The above code generates the 100,000th prime (1299709) on ideone in about 0.75 seconds with about a 0.37 second overhead and has an empirical computational complexity to the 1,000,000th prime (15485863) of about 1.09 (5.13 seconds). The memory use is negligible above that used by the program code.
Note that the above codes are completely functional in that there is no mutable state used whatsoever, but that the Bird algorithm (or even the tree folding version) isn't as fast as using a Priority Queue or HashMap for larger ranges as the number of operations to handle the tree merging has a higher computational complexity than the log n overhead of the Priority Queue or the linear (amortized) performance of a HashMap (although there is a large constant factor overhead to handle the hashing so that advantage isn't really seen until some truly large ranges are used).
The reason that these codes use so little memory is that the CIS streams are formulated with no permanent reference to the start of the streams so that the streams are garbage collected as they are used, leaving only the minimal number of base prime composite sequence place holders, which as Will Ness has explained is very small - only 546 base prime composite number streams for generating the first million primes up to 15485863, each placeholder only taking a few 10's of bytes (eight for the Long number, eight for the 64-bit function reference, with another couple of eight bytes for the pointer to the closure arguments and another few bytes for function and class overheads, for a total per stream placeholder of perhaps 40 bytes, or a total of not much more than 20 Kilobytes for generating the sequence for a million primes).
If you just want an infinite stream of primes, this is the most elegant way in my opinion:
def primes = {
  def sieve(from : Stream[Int]): Stream[Int] = from.head #:: sieve(from.tail.filter(_ % from.head != 0))
  sieve(Stream.from(2))
}

Convert normal recursion to tail recursion

I was wondering if there is some general method to convert a "normal" recursion with foo(...) + foo(...) as the last call to a tail-recursion.
For example (scala):
def pascal(c: Int, r: Int): Int = {
  if (c == 0 || c == r) 1
  else pascal(c - 1, r - 1) + pascal(c, r - 1)
}
A general solution for functional languages to convert recursive function to a tail-call equivalent:
A simple way is to wrap the non tail-recursive function in the Trampoline monad.
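// assumes scalaz's Trampoline is in scope
def pascalM(c: Int, r: Int): Trampoline[Int] = {
  if (c == 0 || c == r) Trampoline.done(1)
  else for {
    // recurse into pascalM itself, so every step is suspended on the trampoline
    a <- Trampoline.suspend(pascalM(c - 1, r - 1))
    b <- Trampoline.suspend(pascalM(c, r - 1))
  } yield a + b
}
val pascal = pascalM(5, 10).run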
So the pascal function is not a recursive function anymore. However, the Trampoline monad is a nested structure of the computation that need to be done. Finally, run is a tail-recursive function that walks through the tree-like structure, interpreting it, and finally at the base case returns the value.
A paper by Rúnar Bjarnason on the subject of Trampolines: Stackless Scala With Free Monads
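The same shape can be tried without an external library by using scala.util.control.TailCalls from the standard library, which plays a comparable role here. This is a sketch of the technique rather than the scalaz code above; the object name and the pascal wrapper are my own.

```scala
import scala.util.control.TailCalls.{TailRec, done, tailcall}

object TrampolinedPascal {
  // builds a description of the computation instead of recursing on the call stack
  def pascalT(c: Int, r: Int): TailRec[Int] =
    if (c == 0 || c == r) done(1)
    else for {
      a <- tailcall(pascalT(c - 1, r - 1))
      b <- tailcall(pascalT(c, r - 1))
    } yield a + b

  // result interprets the description in a loop
  def pascal(c: Int, r: Int): Int = pascalT(c, r).result
}
```

This keeps the stack shallow, but it still performs the same exponential number of additions as the naive version, so it addresses stack depth only, not running time.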
In cases where there is a simple modification to the value of a recursive call, that operation can be moved to the front of the recursive function. The classic example of this is Tail recursion modulo cons, where a simple recursive function in this form:
def recur[A](...): List[A] = {
  x :: recur(...)
}
which is not tail recursive, is transformed into
def recur[A](...): List[A] = {
  def consRecur(..., consA: A): List[A] = {
    consA :: ...
    consRecur(..., ...)
  }
}
Alexlv's example is a variant of this.
This is such a well known situation that some compilers (I know of Prolog and Scheme examples but Scalac does not do this) can detect simple cases and perform this optimisation automatically.
Problems combining multiple calls to recursive functions have no such simple solution. TRMC optimisation is useless, as you are simply moving the first recursive call to another non-tail position. The only way to reach a tail-recursive solution is to remove all but one of the recursive calls; how to do this is entirely context dependent but requires finding an entirely different approach to solving the problem.
As it happens, in some ways your example is similar to the classic Fibonacci sequence problem; in that case the naive but elegant doubly-recursive solution can be replaced by one which loops forward from the 0th number.
// naive doubly-recursive version
def fib (n: Long): Long = n match {
  case 0 | 1 => n
  case _ => fib( n - 2) + fib( n - 1 )
}

// tail-recursive version that loops forward from the 0th number
def fib (n: Long): Long = {
  def loop(current: Long, next: Long, iteration: Long): Long = {
    if (n == iteration)
      current
    else
      loop(next, current + next, iteration + 1)
  }
  loop(0, 1, 0)
}
For the Fibonacci sequence, this is the most efficient approach (a streams-based solution is just a different expression of this solution that can cache results for subsequent calls). Now, you can also solve your problem by looping forward from c0/r0 (well, c0/r2) and calculating each row in sequence, the difference being that you need to cache the entire previous row. So while this has a similarity to fib, it differs dramatically in the specifics and is also significantly less efficient than your original, doubly-recursive solution.
Here's an approach for your pascal triangle example which can calculate pascal(30,60) efficiently:
def pascal(column: Long, row: Long): Long = {
  type Point = (Long, Long)
  type Points = List[Point]
  type Triangle = Map[Point, Long]
  def above(p: Point) = (p._1, p._2 - 1)
  def aboveLeft(p: Point) = (p._1 - 1, p._2 - 1)
  def find(ps: Points, t: Triangle): Long = ps match {
    // Found the ultimate goal
    case (p :: Nil) if t contains p => t(p)
    // Found an intermediate point: pop the stack and carry on
    case (p :: rest) if t contains p => find(rest, t)
    // Hit a triangle edge, add it to the triangle
    case ((c, r) :: _) if (c == 0) || (c == r) => find(ps, t + ((c,r) -> 1))
    // Triangle contains (c - 1, r - 1)...
    case (p :: _) if t contains aboveLeft(p) =>
      if (t contains above(p))
        // And it contains (c, r - 1)! Add to the triangle
        find(ps, t + (p -> (t(aboveLeft(p)) + t(above(p)))))
      else
        // Does not contain (c, r - 1). So find that
        find(above(p) :: ps, t)
    // If we get here, we don't have (c - 1, r - 1). Find that.
    case (p :: _) => find(aboveLeft(p) :: ps, t)
  }
  require(column >= 0 && row >= 0 && column <= row)
  (column, row) match {
    case (c, r) if (c == 0) || (c == r) => 1
    case p => find(List(p), Map())
  }
}
It's efficient, but I think it shows how ugly complex recursive solutions can become as you deform them to become tail recursive. At this point, it may be worth moving to a different model entirely. Continuations or monadic gymnastics might be better.
You want a generic way to transform your function. There isn't one. There are helpful approaches, that's all.
I don't know how theoretical this question is, but a recursive implementation won't be efficient even with tail-recursion. Try computing pascal(30, 60), for example. I don't think you'll get a stack overflow, but be prepared to take a long coffee break.
Instead, consider using a Stream or memoization:
val pascal: Stream[Stream[Long]] =
  (Stream(1L)
    #:: (Stream from 1 map { i =>
      // compute row i
      (1L
        #:: (pascal(i-1) // take the previous row
               sliding 2 // and add adjacent values pairwise
               collect { case Stream(a,b) => a + b }).toStream
        ++ Stream(1L))
    }))
The accumulator approach
def pascal(c: Int, r: Int): Int = {
  def pascalAcc(acc: Int, leftover: List[(Int, Int)]): Int = {
    if (leftover.isEmpty) acc
    else {
      val (c1, r1) = leftover.head
      // Edge.
      if (c1 == 0 || c1 == r1) pascalAcc(acc + 1, leftover.tail)
      // Safe checks.
      else if (c1 < 0 || r1 < 0 || c1 > r1) pascalAcc(acc, leftover.tail)
      // Add 2 other points to accumulator.
      else pascalAcc(acc, (c1, r1 - 1) :: ((c1 - 1, r1 - 1) :: leftover.tail))
    }
  }
  pascalAcc(0, List((c, r)))
}
It does not overflow the stack, but, as Aaron mentioned, it's not fast for a big row and column.
Yes, it's possible. Usually it's done with the accumulator pattern: an internally defined function takes one additional argument, the so-called accumulator, which carries the intermediate result. Here is an example that counts the length of a list.
For example, the normal recursive version would look like this:
def length[A](xs: List[A]): Int = if (xs.isEmpty) 0 else 1 + length(xs.tail)
That's not a tail recursive version. In order to eliminate the last addition operation we have to accumulate the values somehow, for example with the accumulator pattern:
def length[A](xs: List[A]) = {
  def inner(ys: List[A], acc: Int): Int = {
    if (ys.isEmpty) acc else inner(ys.tail, acc + 1)
  }
  inner(xs, 0)
}
It's a bit longer to code, but I think the idea is clear. Of course you can do it without the inner function, but in that case you have to provide the initial value of acc manually.
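As a variation on the version without an inner function, a default argument can supply the initial accumulator value so that callers do not have to pass it. This is a small sketch, not the answer's own code; the wrapper object ListLength is a name made up for the example.

```scala
import scala.annotation.tailrec

object ListLength {
  // acc defaults to 0, so callers can simply write length(xs).
  // @tailrec makes the compiler reject the method if the recursive
  // call is not in tail position.
  @tailrec
  def length[A](xs: List[A], acc: Int = 0): Int =
    if (xs.isEmpty) acc else length(xs.tail, acc + 1)
}
```

The drawback is that the accumulator becomes part of the public signature, which is the reason the inner function is usually preferred.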
I'm pretty sure it's not possible in the simple way you're looking for in the general case, but it would depend on how elaborate you permit the changes to be.
A tail-recursive function must be re-writable as a while-loop, but try implementing, for example, a Fractal Tree using while-loops. It's possible, but you need to use an array or collection to store the state for each point, which substitutes for the data otherwise stored in the call stack.
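To illustrate the idea of a collection standing in for the call stack, here is a sketch that uses the Pascal triangle from the question rather than a fractal tree. The object name PascalLoop and the variable names are assumptions for the example; like the naive recursion, it takes exponential time and is only meant to show the transformation.

```scala
object PascalLoop {
  def pascal(c: Int, r: Int): Long = {
    require(c >= 0 && r >= 0 && c <= r)
    // The list of pending points plays the role of the call stack.
    var pending = List((c, r))
    var acc = 0L
    while (pending.nonEmpty) {
      val (c1, r1) = pending.head
      pending = pending.tail
      // An edge contributes 1 to the total.
      if (c1 == 0 || c1 == r1) acc += 1
      // Otherwise push the two points above, to be visited later.
      else pending = (c1, r1 - 1) :: (c1 - 1, r1 - 1) :: pending
    }
    acc
  }
}
```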
It's also possible to use trampolining.
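One way to try trampolining is the standard library's scala.util.control.TailCalls. The sketch below applies it to the Pascal triangle from the question; the object name PascalTrampoline is an assumption for the example. It keeps the stack shallow by building the computation on the heap, but it does not make the algorithm any faster.

```scala
import scala.util.control.TailCalls._

object PascalTrampoline {
  // Each recursive call is wrapped in tailcall, so it is suspended
  // instead of being executed on the JVM stack.
  def pascal(c: Int, r: Int): TailRec[Long] =
    if (c == 0 || c == r) done(1L)
    else for {
      left  <- tailcall(pascal(c - 1, r - 1))
      above <- tailcall(pascal(c, r - 1))
    } yield left + above
}
```

Calling PascalTrampoline.pascal(c, r).result runs the trampoline and returns the value.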
It is indeed possible. The way I'd do this is to begin with List(1) and keep recursing till you get to the row you want.
Worth noticing that you can optimize it: if c==0 or c==r the value is one, and to calculate let's say column 3 of the 100th row you still only need to calculate the first three elements of the previous rows.
A working tail recursive solution would be this:
import scala.annotation.tailrec

def pascal(c: Int, r: Int): Int = {
  @tailrec
  def pascalAcc(c: Int, r: Int, acc: List[Int]): List[Int] = {
    if (r == 0) acc
    else pascalAcc(c, r - 1,
      // from let's say 1 3 3 1 builds 0 1 3 3 1 0 , takes only the
      // subset that matters (if asking for col c, no cols after c are
      // used) and uses sliding to build (0 1) (1 3) (3 3) etc.
      (0 +: acc :+ 0).take(c + 2)
        .sliding(2, 1).map { x => x.reduce(_ + _) }.toList)
  }
  if (c == 0 || c == r) 1
  else pascalAcc(c, r, List(1))(c)
}
The annotation @tailrec makes the compiler check that the function is actually tail recursive.
It could probably be optimized further: since the rows are symmetric, if c > r/2 then pascal(c, r) == pascal(r - c, r). That is left to the reader ;)
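A possible sketch of that symmetry optimization follows. It is written as a fold over the rows rather than as the explicit recursion used above, and it uses Long instead of Int; the object name PascalSymmetric and the variable k are assumptions for the example.

```scala
object PascalSymmetric {
  def pascal(c: Int, r: Int): Long = {
    require(c >= 0 && r >= 0 && c <= r)
    // Row r is symmetric, so use the smaller of the two mirrored columns.
    val k = math.min(c, r - c)
    // Build rows 1..r, keeping only the first k + 1 entries of each row.
    val row = (1 to r).foldLeft(List(1L)) { (prev, _) =>
      (0L +: prev :+ 0L).take(k + 2).sliding(2).map(_.sum).toList
    }
    row(k)
  }
}
```

With this change a request for a column near the right edge costs about as much as its mirror image near the left edge.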

generating permutations with scalacheck

I have some generators like this:
val fooRepr = oneOf(a, b, c, d, e)
val foo = for (s <- choose(1, 5); c <- listOfN(s, fooRepr)) yield c.mkString("$")
This leads to duplicates ... I might get two a's, etc. What I really want is to generate a random permutation with exactly 0 or 1 of each of a, b, c, d, or e (with at least one of something), in any order.
I was thinking there must be an easy way, but I'm struggling to even find a hard way. :)
Edited: Ok, this seems to work:
val foo = for (s <- choose(1, 5);
c <- permute(s, a, b, c, d, e)) yield c.mkString("$")
def permute[T](n: Int, gs: Gen[T]*): Gen[Seq[T]] = {
val perm = Random.shuffle(gs.toList)
for {
is <- pick(n, 1 until gs.size)
xs <- sequence[List,T](
} yield xs
...borrowing heavily from Gen.pick.
Thanks for your help, -Eric
Rex, thanks for clarifying exactly what I'm trying to do, and that's useful code, but perhaps not so nice with scalacheck, particularly if the generators in question are quite complex. In my particular case the generators a, b, c, etc. are generating huge strings.
Anyhow, there was a bug in my solution above; what worked for me is below. I put a tiny project demonstrating how to do this at github
The guts of it is below. If there's a better way, I'd love to know it...
package powerset
import org.scalacheck._
import org.scalacheck.Gen._
import org.scalacheck.Gen
import scala.util.Random
object PowersetPermutations extends Properties("PowersetPermutations") {
def a: Gen[String] = value("a")
def b: Gen[String] = value("b")
def c: Gen[String] = value("c")
def d: Gen[String] = value("d")
def e: Gen[String] = value("e")
val foo = for (s <- choose(1, 5);
c <- permute(s, a, b, c, d, e)) yield c.mkString
def permute[T](n: Int, gs: Gen[T]*): Gen[Seq[T]] = {
val perm = Random.shuffle(gs.toList)
for {
is <- pick(n, 0 until gs.size)
xs <- sequence[List, T](
} yield xs
implicit def arbString: Arbitrary[String] = Arbitrary(foo)
property("powerset") = Prop.forAll {
a: String => println(a); true
You're not describing a permutation, but the power set (minus the empty set). Edit: you're describing a combination of a power set and a permutation. The power set of an indexed set N is isomorphic to 2^N, so we can simply write (in Scala alone; maybe you want to alter this for use with ScalaCheck):
def powerSet[X](xs: List[X]) = {
  val xis = xs.zipWithIndex
  (for (j <- 1 until (1<<xs.length)) yield {
    for ((x,i) <- xis if ((j & (1<<i)) != 0)) yield x
  }).toList
}
to generate all possible subsets given a set. Of course, explicit generation of power sets is unwise if the original set contains more than a handful of elements. If you don't want to generate all of them, just pass in a random number from 1 until (1<<xs.length) and run the inner loop. (Switch to Long if there are 33-64 elements, and to BitSet if there are more yet.) You can then permute the result to switch the order around if you wish.
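The random-number variant can be sketched as below, in plain Scala without ScalaCheck. The names RandomSubset, randomSubset and randomPermutedSubset are made up for the example, and the sketch assumes between 1 and 30 elements so that the bitmask fits in a positive Int.

```scala
import scala.util.Random

object RandomSubset {
  // Picks a random non-empty subset of xs by drawing a bitmask j with
  // 1 <= j < 2^n and keeping the elements whose bit is set.
  def randomSubset[X](xs: List[X], rnd: Random = new Random): List[X] = {
    require(xs.nonEmpty && xs.length <= 30)
    val j = 1 + rnd.nextInt((1 << xs.length) - 1)
    for ((x, i) <- xs.zipWithIndex if (j & (1 << i)) != 0) yield x
  }

  // Shuffles the chosen subset: a random-order selection without duplicates.
  def randomPermutedSubset[X](xs: List[X], rnd: Random = new Random): List[X] =
    rnd.shuffle(randomSubset(xs, rnd))
}
```

For example, RandomSubset.randomPermutedSubset(List("a", "b", "c", "d", "e")) returns between one and five of the letters, each at most once, in random order.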
Edit: there's another way to do this if you can generate permutations easily and you can add a dummy argument: make your list one longer, with a Stop token. Then permute and .takeWhile(_ != Stop). Ta-da! Permutations of arbitrary length. (Filter out the zero-length answer if need be.)
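A rough sketch of the Stop-token idea, again in plain Scala: None stands in for the Stop token and Some(x) wraps a real element. The names StopTokenPermutation and randomPrefixPermutation are assumptions for the example.

```scala
import scala.util.Random

object StopTokenPermutation {
  // Appends the Stop token (None), shuffles everything, and keeps only
  // the elements that come before the token. The result may be empty.
  def randomPrefixPermutation[X](xs: List[X], rnd: Random = new Random): List[X] =
    rnd.shuffle(xs.map(Some(_)) :+ None)
      .takeWhile(_.isDefined)
      .collect { case Some(x) => x }
}
```

If the empty result is not wanted, the caller can retry until the returned list is non-empty.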