Fast Functional Equivalent Of For Loop Scala - scala

I have the following two code snippets in Scala:
/* Iterative */
for (i <- max to sum by min) {
if (sum % i == 0) validBusSize(i, L, 0)
}
/* Functional */
List.range(max, sum + 1, min)
.filter(sum % _ == 0)
.map(validBusSize(_, L, 0))
Both these code snippets are part of otherwise identical objects. However, when I run my code on Hackerrank, the object with the iterative snippet takes a maximum of 1.45 seconds, while the functional snippet causes the code to take > 7 seconds, which is a timeout.
I'd like to know if it's possible to rewrite the for loop functionally while retaining the speed. I took a look at the Stream container, but again I'll have to call filter before map, instead of computing each validBusSize sequentially.
Thanks!
Edit:
/* Full Code */
import scala.io.StdIn.readLine
object BusStation {
def main(args: Array[String]) {
readLine
val L = readLine.split(" ").map(_.toInt).toList
val min = L.min
val max = L.max
val sum = L.foldRight(0)(_ + _)
/* code under consideration */
for (i <- max to sum by min) {
if (sum % i == 0) validBusSize(i, L, 0)
}
}
def validBusSize(size: Int, L: List[Int], curr: Int) {
L match {
case Nil if (curr == size) => print(size + " ")
case head::tail if (curr < size) =>
validBusSize(size, tail, curr + head)
case head::tail if (curr == size) => validBusSize(size, tail, head)
case head::tail if (curr > size) => return
}
}
}

Right now, your best bet for fast functional code is tail-recursive functions:
@annotation.tailrec
def getBusSizes(i: Int, sum: Int, step: Int) {
if (i <= sum) {
if (sum % i == 0) validBusSize(i, L, 0)
getBusSizes(i + step, sum, step)
}
}
Various other things will be sort of fast-ish, but for something like this where there's mostly simple math, the overhead from the generic interface will be sizable. With a tail-recursive function you'll get a while loop underneath. (You don't need the annotation to make it tail-recursive; that just causes the compilation to fail if it can't. The optimization happens whether the annotation is there or not.)

So apparently the following worked:
Replacing the List.range(max, sum + 1, min) with a Range object, (max to sum by min). Going to ask another question about why this works, though.
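One plausible reason, offered as a sketch rather than a measured explanation: List.range builds the whole list up front, and every filter and map on it builds another list, whereas a Range only stores its bounds and step. The example below keeps the pipeline style but runs it over the range's iterator, so no intermediate collection is materialized. splitsEvenly and busSizes are hypothetical names; splitsEvenly stands in for validBusSize and returns a Boolean instead of printing.

```scala
object RangeSketch {
  // Hypothetical stand-in for validBusSize: true when xs can be cut into
  // consecutive groups that each sum to exactly `size`.
  @annotation.tailrec
  def splitsEvenly(size: Int, xs: List[Int], curr: Int): Boolean = xs match {
    case Nil                          => curr == size
    case head :: tail if curr < size  => splitsEvenly(size, tail, curr + head)
    case head :: tail if curr == size => splitsEvenly(size, tail, head)
    case _                            => false // overshot the target size
  }

  def busSizes(xs: List[Int]): List[Int] = {
    val sum = xs.sum
    (xs.max to sum by xs.min).iterator // lazy: candidates are produced one at a time
      .filter(sum % _ == 0)
      .filter(splitsEvenly(_, xs, 0))
      .toList
  }
}
```

Calling RangeSketch.busSizes(List(1, 2, 1, 1, 1, 2, 1, 3)) checks each candidate size in turn, just as the for loop does.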

Consider converting the range into a parallel collection by calling par on it, for instance like this
(max to sum by min).par
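This may improve performance, especially for large ranges where each call to validBusSize is expensive. Note that validBusSize prints its result, so with a parallel range the order of the output is no longer guaranteed.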
Thus in the proposed for comprehension,
for ( i <- (max to sum by min).par ) {
if (sum % i == 0) validBusSize(i, L, 0)
}

Related

Scala : How to break out of a nested for comprehension

I'm trying to write some code as below -
def kthSmallest(matrix: Array[Array[Int]], k: Int): Int = {
val pq = new PriorityQueue[Int]() //natural ordering
var count = 0
for (
i <- matrix.indices;
j <- matrix(0).indices
) yield {
pq += matrix(i)(j)
count += 1
} //This would yield Any!
pq.dequeue() //kth smallest.
}
My question is that I only want to loop while count is less than k (something like takeWhile(count != k)), but as I'm also inserting elements into the priority queue in the yield, this won't work in the current state.
My other option is to write a nested loop and return once count reaches k. Is it possible to do this with yield? I could not find an idiomatic way of doing it yet. Any pointers would be helpful.
It's not idiomatic for Scala to use vars or break loops. You can go for recursion, lazy evaluation or duct tape a break, giving up on some performance (just like return, it's implemented as an Exception, and won't perform well enough). Here are the options broken down:
Use recursion - recursive algorithms are the analog of loops in functional languages
def kthSmallest(matrix: Array[Array[Int]], k: Int): Int = {
val pq = new PriorityQueue[Int]() //natural ordering
@tailrec
def fillQueue(i: Int, j: Int, count: Int): Unit =
if (count >= k || i >= matrix.length) ()
else {
pq += matrix(i)(j)
fillQueue(
if (j >= matrix(i).length - 1) i + 1 else i,
if (j >= matrix(i).length - 1) 0 else j + 1,
count + 1)
}
fillQueue(0, 0, 0)
pq.dequeue() //kth smallest.
}
Use a lazy structure, as chengpohi suggested, although this doesn't read much like a pure function. In this case I'd suggest using an Iterator instead of a Stream, since iterators don't memoize the steps they've gone through (which may save memory for large matrices).
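A minimal sketch of the Iterator variant, assuming the goal is only to mirror the question's loop and stop after k insertions (it makes no claim about whether the queue then holds the kth smallest element). IteratorSketch and firstK are hypothetical names.

```scala
import scala.collection.mutable.PriorityQueue

object IteratorSketch {
  // Inserts only the first k elements (in row-major order) into the queue.
  def firstK(matrix: Array[Array[Int]], k: Int): PriorityQueue[Int] = {
    val pq = PriorityQueue.empty[Int] // natural ordering, as in the question
    matrix.iterator          // rows, produced lazily
      .flatMap(_.iterator)   // elements of each row, still lazy
      .take(k)               // stop after k elements; the rest is never visited
      .foreach(pq += _)
    pq
  }
}
```

Because the iterator is consumed once and not memoized, nothing beyond the k visited elements is retained.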
For those desperately willing to use break, Scala supports it in an attachable fashion (note the performance caveat mentioned above):
import scala.util.control.Breaks._
breakable {
  // loop code
  break()
}
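For completeness, here is a small runnable sketch of the breakable pattern applied to a nested for loop. BreakSketch and takeFirst are hypothetical names, and collecting into a ListBuffer is only an illustration of stopping after k elements.

```scala
import scala.collection.mutable.ListBuffer
import scala.util.control.Breaks.{break, breakable}

object BreakSketch {
  // Collects the first k elements of the matrix in row-major order.
  def takeFirst(matrix: Array[Array[Int]], k: Int): List[Int] = {
    val buf = ListBuffer.empty[Int]
    breakable {
      for (row <- matrix; x <- row) {
        if (buf.size == k) break() // throws a control exception caught by breakable
        buf += x
      }
    }
    buf.toList
  }
}
```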
There is a way using the Stream lazy evaluation to do this. Since for yield is equal to flatMap, you can convert for yield to flatMap with Stream:
matrix.indices.toStream.flatMap(i => {
matrix(0).indices.toStream.map(j => {
pq += matrix(i)(j)
count += 1
})
}).takeWhile(_ => count <= k)
Use toStream to convert the collection to a Stream. Because a Stream is evaluated lazily, takeWhile can test count and stop the iteration early without evaluating the remaining elements.

Convert normal recursion to tail recursion

I was wondering if there is some general method to convert a "normal" recursion with foo(...) + foo(...) as the last call to a tail-recursion.
For example (scala):
def pascal(c: Int, r: Int): Int = {
if (c == 0 || c == r) 1
else pascal(c - 1, r - 1) + pascal(c, r - 1)
}
A general solution for functional languages to convert recursive function to a tail-call equivalent:
A simple way is to wrap the non tail-recursive function in the Trampoline monad.
def pascalM(c: Int, r: Int): Trampoline[Int] = {
if (c == 0 || c == r) Trampoline.done(1)
else for {
a <- Trampoline.suspend(pascalM(c - 1, r - 1))
b <- Trampoline.suspend(pascalM(c, r - 1))
} yield a + b
}
val pascal = pascalM(5, 10).run
So the pascal function is not a recursive function anymore. However, the Trampoline monad is a nested structure of the computation that needs to be done. Finally, run is a tail-recursive function that walks through the tree-like structure, interpreting it, and finally at the base case returns the value.
A paper from Rúnar Bjarnason on the subject of Trampolines: Stackless Scala With Free Monads
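The Trampoline type above comes from an external library that the answer does not name. As a hedged alternative, the same idea can be sketched with scala.util.control.TailCalls from the standard library, which is a different API (done, tailcall, result) but the same technique. PascalTailCalls and pascalT are hypothetical names.

```scala
import scala.util.control.TailCalls._

object PascalTailCalls {
  // Each recursive call is wrapped in tailcall, so it is described as data
  // instead of being executed on the JVM call stack.
  def pascalT(c: Int, r: Int): TailRec[Int] =
    if (c == 0 || c == r) done(1)
    else for {
      a <- tailcall(pascalT(c - 1, r - 1))
      b <- tailcall(pascalT(c, r - 1))
    } yield a + b

  // result interprets the suspended computation in a loop.
  def pascal(c: Int, r: Int): Int = pascalT(c, r).result
}
```

This keeps the stack flat but does not reduce the amount of work: the number of suspended calls still grows as quickly as in the doubly-recursive version.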
In cases where there is a simple modification to the value of a recursive call, that operation can be moved to the front of the recursive function. The classic example of this is Tail recursion modulo cons, where a simple recursive function in this form:
def recur[A](...):List[A] = {
...
x :: recur(...)
}
which is not tail recursive, is transformed into
def recur[A](...): List[A] = {
  def consRecur(..., consA: A): List[A] = {
    consA :: ...
    ...
    consRecur(..., ...)
  }
  ...
  consRecur(..., ...)
}
Alexlv's example is a variant of this.
This is such a well known situation that some compilers (I know of Prolog and Scheme examples but Scalac does not do this) can detect simple cases and perform this optimisation automatically.
Problems combining multiple calls to recursive functions have no such simple solution. TRMC optimisation is useless, as you are simply moving the first recursive call to another non-tail position. The only way to reach a tail-recursive solution is to remove all but one of the recursive calls; how to do this is entirely context dependent but requires finding an entirely different approach to solving the problem.
As it happens, in some ways your example is similar to the classic Fibonacci sequence problem; in that case the naive but elegant doubly-recursive solution can be replaced by one which loops forward from the 0th number.
def fib (n: Long): Long = n match {
case 0 | 1 => n
case _ => fib( n - 2) + fib( n - 1 )
}
def fib (n: Long): Long = {
def loop(current: Long, next: Long, iteration: Long): Long = {
if (n == iteration)
current
else
loop(next, current + next, iteration + 1)
}
loop(0, 1, 0)
}
For the Fibonacci sequence, this is the most efficient approach (a streams-based solution is just a different expression of this solution that can cache results for subsequent calls). Now, you can also solve your problem by looping forward from c0/r0 (well, c0/r2) and calculating each row in sequence, the difference being that you need to cache the entire previous row. So while this has a similarity to fib, it differs dramatically in the specifics and is also significantly less efficient than your original, doubly-recursive solution.
Here's an approach for your pascal triangle example which can calculate pascal(30,60) efficiently:
def pascal(column: Long, row: Long):Long = {
type Point = (Long, Long)
type Points = List[Point]
type Triangle = Map[Point,Long]
def above(p: Point) = (p._1, p._2 - 1)
def aboveLeft(p: Point) = (p._1 - 1, p._2 - 1)
def find(ps: Points, t: Triangle): Long = ps match {
// Found the ultimate goal
case (p :: Nil) if t contains p => t(p)
// Found an intermediate point: pop the stack and carry on
case (p :: rest) if t contains p => find(rest, t)
// Hit a triangle edge, add it to the triangle
case ((c, r) :: _) if (c == 0) || (c == r) => find(ps, t + ((c,r) -> 1))
// Triangle contains (c - 1, r - 1)...
case (p :: _) if t contains aboveLeft(p) => if (t contains above(p))
// And it contains (c, r - 1)! Add to the triangle
find(ps, t + (p -> (t(aboveLeft(p)) + t(above(p)))))
else
// Does not contain(c, r -1). So find that
find(above(p) :: ps, t)
// If we get here, we don't have (c - 1, r - 1). Find that.
case (p :: _) => find(aboveLeft(p) :: ps, t)
}
require(column >= 0 && row >= 0 && column <= row)
(column, row) match {
case (c, r) if (c == 0) || (c == r) => 1
case p => find(List(p), Map())
}
}
It's efficient, but I think it shows how ugly complex recursive solutions can become as you deform them to become tail recursive. At this point, it may be worth moving to a different model entirely. Continuations or monadic gymnastics might be better.
You want a generic way to transform your function. There isn't one. There are helpful approaches, that's all.
I don't know how theoretical this question is, but a recursive implementation won't be efficient even with tail-recursion. Try computing pascal(30, 60), for example. I don't think you'll get a stack overflow, but be prepared to take a long coffee break.
Instead, consider using a Stream or memoization:
val pascal: Stream[Stream[Long]] =
(Stream(1L)
#:: (Stream from 1 map { i =>
// compute row i
(1L
#:: (pascal(i-1) // take the previous row
sliding 2 // and add adjacent values pairwise
collect { case Stream(a,b) => a + b }).toStream
++ Stream(1L))
}))
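The memoization alternative mentioned above might look like the following sketch. It assumes a single-threaded caller and uses a plain mutable Map as the cache; PascalMemo is a hypothetical name.

```scala
import scala.collection.mutable

object PascalMemo {
  // Cache of already computed interior points, keyed by (column, row).
  private val cache = mutable.Map.empty[(Int, Int), BigInt]

  def pascal(c: Int, r: Int): BigInt =
    if (c == 0 || c == r) BigInt(1)
    else cache.get((c, r)) match {
      case Some(v) => v
      case None =>
        val v = pascal(c - 1, r - 1) + pascal(c, r - 1)
        cache((c, r)) = v // remember the value so each point is computed once
        v
    }
}
```

The recursion depth is bounded by the row number, so pascal(30, 60) is well within reach, and every point of the triangle is evaluated at most once.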
The accumulator approach
def pascal(c: Int, r: Int): Int = {
def pascalAcc(acc:Int, leftover: List[(Int, Int)]):Int = {
if (leftover.isEmpty) acc
else {
val (c1, r1) = leftover.head
// Edge.
if (c1 == 0 || c1 == r1) pascalAcc(acc + 1, leftover.tail)
// Safe checks.
else if (c1 < 0 || r1 < 0 || c1 > r1) pascalAcc(acc, leftover.tail)
// Add 2 other points to accumulator.
else pascalAcc(acc, (c1 , r1 - 1) :: ((c1 - 1, r1 - 1) :: leftover.tail ))
}
}
pascalAcc(0, List ((c,r) ))
}
It does not overflow the stack, but as Aaron mentioned, it is not fast for large rows and columns.
Yes, it's possible. Usually it's done with the accumulator pattern: an internally defined function takes one additional argument, the so-called accumulator, which carries the partial result. Counting the length of a list is a simple example.
For example normal recursive version would look like this:
def length[A](xs: List[A]): Int = if (xs.isEmpty) 0 else 1 + length(xs.tail)
That's not tail recursive. To eliminate the final addition, we have to accumulate the value as we go, for example with the accumulator pattern:
def length[A](xs: List[A]) = {
def inner(ys: List[A], acc: Int): Int = {
if (ys.isEmpty) acc else inner(ys.tail, acc + 1)
}
inner(xs, 0)
}
It's a bit longer to code, but I think the idea is clear. Of course you can do it without the inner function, but in that case the caller has to provide the initial value of acc manually.
I'm pretty sure it's not possible in the simple way you're looking for in the general case, but it would depend on how elaborate you permit the changes to be.
A tail-recursive function must be rewritable as a while loop, but try implementing, for example, a fractal tree using while loops. It's possible, but you need to use an array or collection to store the state for each point, which substitutes for the data otherwise stored on the call stack.
It's also possible to use trampolining.
It is indeed possible. The way I'd do this is to begin with List(1) and keep recursing till you get to the row you want.
Worth noticing that you can optimize it: if c==0 or c==r the value is one, and to calculate let's say column 3 of the 100th row you still only need to calculate the first three elements of the previous rows.
A working tail recursive solution would be this:
def pascal(c: Int, r: Int): Int = {
@tailrec
def pascalAcc(c: Int, r: Int, acc: List[Int]): List[Int] = {
if (r == 0) acc
else pascalAcc(c, r - 1,
// from let's say 1 3 3 1 builds 0 1 3 3 1 0 , takes only the
// subset that matters (if asking for col c, no cols after c are
// used) and uses sliding to build (0 1) (1 3) (3 3) etc.
(0 +: acc :+ 0).take(c + 2)
.sliding(2, 1).map { x => x.reduce(_ + _) }.toList)
}
if (c == 0 || c == r) 1
else pascalAcc(c, r, List(1))(c)
}
The @tailrec annotation makes the compiler check that the function is actually tail recursive.
It could probably be optimized further: since the rows are symmetric, if c > r/2 then pascal(c, r) == pascal(r - c, r). That is left to the reader ;)

Testing whether an ordered infinite stream contains a value

I have an infinite Stream of primes primeStream (starting at 2 and increasing). I also have another stream of Ints s which increase in magnitude and I want to test whether each of these is prime.
What is an efficient way to do this? I could define
def isPrime(n: Int) = n == primeStream.dropWhile(_ < n).head
but this seems inefficient since it needs to iterate over the whole stream each time.
Implementation of primeStream (shamelessly copied from elsewhere):
val primeStream: Stream[Int] =
2 #:: primeStream.map{ i =>
Stream.from(i + 1)
.find{ j =>
primeStream.takeWhile{ k => k * k <= j }
.forall{ j % _ > 0 }
}.get
}
If the question is about implementing isPrime, then you should do as suggested by rossum. Even with division costing more than an equality test, and with primes being denser for lower values of n, it would be asymptotically much faster. Moreover, it is very fast when testing non-primes that have a small divisor (as most numbers do).
It may be different if you want to test primality of elements of another increasing Stream. You may consider something akin to a merge sort. You did not state how you want to get your result, here as a stream of Boolean, but it should not be too hard to adapt to something else.
/**
* Returns a stream of boolean, whether element at the corresponding position
* in xs belongs in ys. xs and ys must both be increasing streams.
*/
def belong[A: Ordering](xs: Stream[A], ys: Stream[A]): Stream[Boolean] = {
if (xs.isEmpty) Stream.empty
else if (ys.isEmpty) xs.map(_ => false)
else Ordering[A].compare(xs.head, ys.head) match {
case less if less < 0 => false #:: belong(xs.tail, ys)
case equal if equal == 0 => true #:: belong(xs.tail, ys.tail)
case greater if greater > 0 => belong(xs, ys.tail)
}
}
So you may do belong(yourStream, primeStream)
Yet it is not obvious that this solution will be better than a separate testing of primality for each number in turn, stopping at square root. Especially if yourStream is fast increasing compared to primes, so you have to compute many primes in vain, just to keep up. And even less so if there is no reason to suspect that elements in yourStream tend to be primes or have only large divisors.
You only need to read your prime stream as far as sqrt(s).
As you retrieve each p from the prime stream check if p evenly divides s.
This will give you a trial division method of prime checking.
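A minimal sketch of that trial-division check follows. It assumes its own simple definition of primeStream (not the one from the question) so that the block is self-contained; TrialDivision is a hypothetical name.

```scala
object TrialDivision {
  // Primes are produced lazily; each candidate is tested against the
  // primes already known, so the two definitions can refer to each other.
  lazy val primeStream: Stream[Int] =
    2 #:: Stream.from(3, 2).filter(isPrime)

  // Only primes p with p * p <= n are read from the stream.
  def isPrime(n: Int): Boolean =
    n >= 2 && primeStream.takeWhile(p => p.toLong * p <= n).forall(n % _ != 0)
}
```

The toLong in the comparison is there to avoid Int overflow when n is close to Int.MaxValue.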
To solve the general question of determining whether an ordered finite list consists entirely of elements of an ordered but infinite stream:
The simplest way is
candidate.toSet subsetOf infiniteList.takeWhile( _ <= candidate.last).toSet
but if the candidate is large, that requires a lot of space and it is O(n log n) instead of the O(n) it could be. The O(n) way is
def acontains(a : Int, b : Iterator[Int]) : Boolean = {
while (b.hasNext) {
val c = b.next
if (c == a) {
return true
}
if (c > a) {
return false
}
}
return false
}
def scontains(candidate: List[Int], infiniteList: Stream[Int]) : Boolean = {
val it = candidate.iterator
val il = infiniteList.iterator
while (it.hasNext) {
if (!acontains(it.next, il)) {
return false
}
}
return true
}
(Incidentally, if some helpful soul could propose a more Scalicious way to write the foregoing, I'd appreciate it.)
EDIT:
In the comments, the inestimable Luigi Plinge pointed out that I could just write:
def scontains(candidate: List[Int], infiniteStream: Stream[Int]) = {
val it = infiniteStream.iterator
candidate.forall(i => it.dropWhile(_ < i).next == i)
}

What is the fastest way to write Fibonacci function in Scala?

I've looked over a few implementations of Fibonacci function in Scala starting from a very simple one, to the more complicated ones.
I'm not entirely sure which one is the fastest. I'm leaning towards the impression that the ones that uses memoization is faster, however I wonder why Scala doesn't have a native memoization.
Can anyone enlighten me toward the best and fastest (and cleanest) way to write a fibonacci function?
The fastest versions are the ones that deviate from the usual addition scheme in some way. A very fast calculation, similar to fast binary exponentiation, is based on these formulas:
F(2n-1) = F(n)² + F(n-1)²
F(2n) = (2F(n-1) + F(n))*F(n)
Here is some code using it:
def fib(n:Int):BigInt = {
def fibs(n:Int):(BigInt,BigInt) = if (n == 1) (1,0) else {
val (a,b) = fibs(n/2)
val p = (2*b+a)*a
val q = a*a + b*b
if(n % 2 == 0) (p,q) else (p+q,p)
}
fibs(n)._1
}
Even though this is not very optimized (e.g. the inner loop is not tail recursive), it will beat the usual additive implementations.
For me, the simplest approach defines a recursive inner function named tail:
def fib: Stream[Long] = {
def tail(h: Long, n: Long): Stream[Long] = h #:: tail(n, h + n)
tail(0, 1)
}
This doesn't need to build any Tuple objects for the zip and is easy to understand syntactically.
Scala does have memoization in the form of Streams.
val fib: Stream[BigInt] = 0 #:: 1 #:: fib.zip(fib.tail).map(p => p._1 + p._2)
scala> fib take 100 mkString " "
res22: String = 0 1 1 2 3 5 8 13 21 34 55 89 144 233 377 610 987 1597 2584 4181 ...
Stream is a LinearSeq so you might like to convert it to an IndexedSeq if you're doing a lot of fib(42) type calls.
However, I would question what your use case is for a Fibonacci function. It will overflow Long in fewer than 100 terms, so larger terms aren't much use for anything. The smaller terms you can just stick in a table and look them up if speed is paramount. So the details of the computation probably don't matter much, since for the smaller terms they're all quick.
If you really want to know the results for very big terms, then it depends on whether you just want one-off values (use Landei's solution) or, if you're making a sufficient number of calls, you may want to pre-compute the whole lot. The problem here is that, for example, the 100,000th element is over 20,000 digits long. So we're talking gigabytes of BigInt values, which will crash your JVM if you try to hold them in memory. You could sacrifice accuracy and make things more manageable. You could have a partial-memoization strategy (say, memoize every 100th term) which makes a suitable memory / speed trade-off. There is no clear answer for what is the fastest: it depends on your usage and resources.
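The lookup table suggested above for the small terms could be sketched like this. FibTable is a hypothetical name; the table stops at index 92 because the next Fibonacci number no longer fits in a signed 64-bit Long.

```scala
object FibTable {
  // F(0) .. F(92): every Fibonacci number that fits in a Long.
  val table: Array[Long] = {
    val t = new Array[Long](93)
    t(1) = 1L
    for (i <- 2 until t.length) t(i) = t(i - 1) + t(i - 2)
    t
  }

  // Constant-time lookup; throws ArrayIndexOutOfBoundsException outside 0..92.
  def fib(n: Int): Long = table(n)
}
```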
This could work. It takes O(1) space and O(n) time to calculate a number, but has no caching.
object Fibonacci {
def fibonacci(i : Int) : Int = {
def h(last : Int, cur: Int, num : Int) : Int = {
if ( num == 0) cur
else h(cur, last + cur, num - 1)
}
if (i < 0) - 1
else if (i == 0 || i == 1) 1
else h(1,2,i - 2)
}
def main(args: Array[String]){
(0 to 10).foreach( (x : Int) => print(fibonacci(x) + " "))
}
}
The answers using Stream (including the accepted answer) are very short and idiomatic, but they aren't the fastest. Streams memoize their values (which isn't necessary in iterative solutions), and even if you don't keep the reference to the stream, a lot of memory may be allocated and then immediately garbage-collected. A good alternative is to use an Iterator: it doesn't cause memory allocations, is functional in style, short and readable.
def fib(n: Int) = Iterator.iterate((BigInt(0), BigInt(1))) { case (a, b) => (b, a+b) }.
map(_._1).drop(n).next
A slightly simpler tail-recursive solution that can calculate Fibonacci for large values of n. The Int version is faster but limited: integer overflow occurs when n > 46.
def tailRecursiveBig(n :Int) : BigInt = {
@tailrec
def aux(n : Int, next :BigInt, acc :BigInt) :BigInt ={
if(n == 0) acc
else aux(n-1, acc + next,next)
}
aux(n,1,0)
}
This has already been answered, but hopefully you will find my experience helpful. I had a lot of trouble getting my mind around Scala infinite streams. Then I watched Paul Agron's presentation, where he gave very good suggestions: (1) implement your solution with basic Lists first, and (2) if you are going to generify your solution with parameterized types, create a solution with simple types like Ints first.
Using that approach, I came up with a really simple (and, for me, easy to understand) solution:
def fib(h: Int, n: Int) : Stream[Int] = { h #:: fib(n, h + n) }
var x = fib(0,1)
println (s"results: ${(x take 10).toList}")
To get to the above solution I first created, as per Paul's advice, the "for-dummy's" version, based on simple lists:
def fib(h: Int, n: Int) : List[Int] = {
if (h > 100) {
Nil
} else {
h :: fib(n, h + n)
}
}
Notice that I short-circuited the list version, because if I didn't it would run forever. But who cares? ;^) It is just an exploratory bit of code.
The code below is both fast and able to compute with high input indices. On my computer it returns the 10^6:th Fibonacci number in less than two seconds. The algorithm is in a functional style but does not use lists or streams. Rather, it is based on the equality \phi^n = F_{n-1} + F_n*\phi, for \phi the golden ratio. (This is a version of "Binet's formula".) The problem with using this equality is that \phi is irrational (involving the square root of five) so it will diverge due to finite-precision arithmetics if interpreted naively using Float-numbers. However, since \phi^2 = 1 + \phi it is easy to implement exact computations with numbers of the form a + b\phi for a and b integers, and this is what the algorithm below does. (The "power" function has a bit of optimization in it but is really just iteration of the "mult"-multiplication on such numbers.)
type Zphi = (BigInt, BigInt)
val phi = (0, 1): Zphi
val mult: (Zphi, Zphi) => Zphi = {
(z, w) => (z._1*w._1 + z._2*w._2, z._1*w._2 + z._2*w._1 + z._2*w._2)
}
val power: (Zphi, Int) => Zphi = {
case (base, ex) if (ex >= 0) => _power((1, 0), base, ex)
case _ => sys.error("no negative power plz")
}
val _power: (Zphi, Zphi, Int) => Zphi = {
case (t, b, e) if (e == 0) => t
case (t, b, e) if ((e & 1) == 1) => _power(mult(t, b), mult(b, b), e >> 1)
case (t, b, e) => _power(t, mult(b, b), e >> 1)
}
val fib: Int => BigInt = {
case n if (n < 0) => 0
case n => power(phi, n)._2
}
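The multiplication rule used by mult can be checked by hand. Writing z = a + b\phi and w = c + d\phi and substituting \phi^2 = 1 + \phi (the symbols a, b, c, d are introduced here only for this derivation):

```latex
\begin{aligned}
(a + b\phi)(c + d\phi) &= ac + (ad + bc)\phi + bd\,\phi^2 \\
                       &= ac + (ad + bc)\phi + bd\,(1 + \phi) \\
                       &= (ac + bd) + (ad + bc + bd)\,\phi
\end{aligned}
```

The two components match the pair returned by mult, and since \phi^n = F_{n-1} + F_n\phi, the second component of power(phi, n) is F_n.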
EDIT: An implementation which is more efficient and in a sense also more idiomatic is based on Typelevel's Spire library for numeric computations and abstract algebra. One can then paraphrase the above code in a way much closer to the mathematical argument (We do not need the whole ring-structure but I think it's "morally correct" to include it). Try running the following code:
import spire.implicits._
import spire.algebra._
case class S(fst: BigInt, snd: BigInt) {
override def toString = s"$fst + $snd"++"φ"
}
object S {
implicit object SRing extends Ring[S] {
def zero = S(0, 0): S
def one = S(1, 0): S
def plus(z: S, w: S) = S(z.fst + w.fst, z.snd + w.snd): S
def negate(z: S) = S(-z.fst, -z.snd): S
def times(z: S, w: S) = S(z.fst * w.fst + z.snd * w.snd
, z.fst * w.snd + z.snd * w.fst + z.snd * w.snd)
}
}
object Fibo {
val phi = S(0, 1)
val fib: Int => BigInt = n => (phi pow n).snd
def main(arg: Array[String]) {
println( fib(1000000) )
}
}

Scala performance - Sieve

Right now, I am trying to learn Scala. I've started small, writing some simple algorithms. I've encountered some problems when I wanted to implement the Sieve algorithm for finding all prime numbers lower than a certain threshold.
My implementation is:
import scala.math
object Sieve {
// Returns all prime numbers until maxNum
def getPrimes(maxNum : Int) = {
def sieve(list: List[Int], stop : Int) : List[Int] = {
list match {
case Nil => Nil
case h :: list if h <= stop => h :: sieve(list.filterNot(_ % h == 0), stop)
case _ => list
}
}
val stop : Int = math.sqrt(maxNum).toInt
sieve((2 to maxNum).toList, stop)
}
def main(args: Array[String]) = {
val ap = printf("%d ", (_:Int));
// works
getPrimes(1000).foreach(ap(_))
// works
getPrimes(100000).foreach(ap(_))
// out of memory
getPrimes(1000000).foreach(ap(_))
}
}
Unfortunately, it fails when I want to compute all the prime numbers smaller than 1000000 (1 million). I am receiving an OutOfMemory error.
Do you have any idea how to optimize the code, or how I can implement this algorithm in a more elegant fashion?
PS: I've done something very similar in Haskell, and there I didn't encounter any issues.
I would go with an infinite Stream. Using a lazy data structure lets you code pretty much like in Haskell. It automatically reads more "declaratively" than the code you wrote.
import Stream._
val primes = 2 #:: sieve(3)
def sieve(n: Int) : Stream[Int] =
if (primes.takeWhile(p => p*p <= n).exists(n % _ == 0)) sieve(n + 2)
else n #:: sieve(n + 2)
def getPrimes(maxNum : Int) = primes.takeWhile(_ < maxNum)
Obviously, this isn't the most performant approach. Read The Genuine Sieve of Eratosthenes for a good explanation (it's Haskell, but not too difficult). For real big ranges you should consider the Sieve of Atkin.
The code in question is not tail recursive, so Scala cannot optimize the recursion away. Also, Haskell is non-strict by default, so you can hardly compare it to Scala. For instance, whereas Haskell benefits from foldRight, Scala benefits from foldLeft.
There are many Scala implementations of Sieve of Eratosthenes, including some in Stack Overflow. For instance:
(n: Int) => (2 to n) |> (r => r.foldLeft(r.toSet)((ps, x) => if (ps(x)) ps -- (x * x to n by x) else ps))
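To illustrate the tail-recursion point made above, here is a hedged sketch of the question's sieve rewritten with an accumulator so that the recursive call is in tail position. TailRecSieve is a hypothetical name. This removes the stack growth only; it still builds a new list on every filtering pass, so it is not a claim about fixing the memory use.

```scala
object TailRecSieve {
  def getPrimes(maxNum: Int): List[Int] = {
    val stop = math.sqrt(maxNum.toDouble).toInt

    // acc holds the primes found so far, in reverse order.
    @annotation.tailrec
    def sieve(rest: List[Int], acc: List[Int]): List[Int] = rest match {
      case h :: t if h <= stop => sieve(t.filterNot(_ % h == 0), h :: acc)
      case _                   => acc.reverse ::: rest // everything left is prime
    }

    sieve((2 to maxNum).toList, Nil)
  }
}
```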
The following answer is about a 100 times faster than the "one-liner" answer using a Set (and the results don't need sorting to ascending order) and is more of a functional form than the other answer using an array although it uses a mutable BitSet as a sieving array:
object SoE {
def makeSoE_Primes(top: Int): Iterator[Int] = {
val topndx = (top - 3) / 2
val nonprms = new scala.collection.mutable.BitSet(topndx + 1)
def cullp(i: Int) = {
import scala.annotation.tailrec; val p = i + i + 3
@tailrec def cull(c: Int): Unit = if (c <= topndx) { nonprms += c; cull(c + p) }
cull((p * p - 3) >>> 1)
}
(0 to (Math.sqrt(top).toInt - 3) >>> 1).filterNot { nonprms }.foreach { cullp }
Iterator.single(2) ++ (0 to topndx).filterNot { nonprms }.map { i: Int => i + i + 3 }
}
}
It can be tested by the following code:
object Main extends App {
import SoE._
val top_num = 10000000
val strt = System.nanoTime()
val count = makeSoE_Primes(top_num).size
val end = System.nanoTime()
println(s"Successfully completed without errors. [total ${(end - strt) / 1000000} ms]")
println(f"Found $count primes up to $top_num" + ".")
println("Using one large mutable BitSet and functional code.")
}
With the results from the above as follows:
Successfully completed without errors. [total 294 ms]
Found 664579 primes up to 10000000.
Using one large mutable BitSet and functional code.
There is an overhead of about 40 milliseconds for even small sieve ranges, and there are various non-linear responses with increasing range as the size of the BitSet grows beyond the different CPU caches.
It looks like List isn't very efficient space-wise. You can get an out-of-memory exception by doing something like this
(1 to 2000000).toList
I "cheated" and used a mutable array. Didn't feel dirty at all.
def primesSmallerThan(n: Int): List[Int] = {
val nonprimes = Array.tabulate(n + 1)(i => i == 0 || i == 1)
val primes = new collection.mutable.ListBuffer[Int]
for (x <- nonprimes.indices if !nonprimes(x)) {
primes += x
for (y <- x * x until nonprimes.length by x if (x * x) > 0) {
nonprimes(y) = true
}
}
primes.toList
}