Why is this Scala prime generation so slow/memory intensive?

I run out of memory while finding the 10,001th prime number.
object Euler0007 {
def from(n: Int): Stream[Int] = n #:: from(n + 1)
def sieve(s: Stream[Int]): Stream[Int] = s.head #:: sieve(s.filter(_ % s.head != 0))
def primes = sieve(from(2))
def main(args: Array[String]): Unit = {
println(primes(10001))
}
}
Is this because after each "iteration" (is this the correct term in this context?) of primes, I increase the stack of functions to be called to get the next element by one?
One solution I've found on the web that doesn't resort to iteration (which I'd like to avoid, in order to get into functional programming and idiomatic Scala) is this (Problem 7):
lazy val ps: Stream[Int] = 2 #:: Stream.from(3).filter(i => ps.takeWhile(j => j * j <= i).forall(i % _ > 0))
From what I can see, this does not lead to this recursion-like way. Is this a good way to do it, or do you know of a better way?

One reason why this is slow is that it isn't the sieve of Eratosthenes. Read http://www.cs.hmc.edu/~oneill/papers/Sieve-JFP.pdf for a detailed explanation (the examples are in Haskell, but can be translated directly into Scala).
My old solution for Euler problem #7 wasn't the "true" sieve either, but it seems to work well enough for small numbers:
object Sieve {
val primes = 2 #:: sieve(3)
def sieve(n: Int) : Stream[Int] =
if (primes.takeWhile(p => p*p <= n).exists(n % _ == 0)) sieve(n + 2)
else n #:: sieve(n + 2)
def main(args: Array[String]): Unit = {
println(primes(10000)) //note that indexes are zero-based
}
}
I think the problem with your first version is that you have only defs and no val which collects the results and can be consulted by the generating function, so you always recalculate from scratch.

Yes, it is because you "increase the stack of functions to be called to get the next element by one" after each "iteration", i.e. you add a new filter on top of the stack of filters after producing each prime. That's way too many filters.
This means that each produced prime gets tested by all its preceding primes - but only those below its square root are really needed. For instance, to get the 10001-th prime, 104743, there will be 10000 filters created, at run-time. But there are just 66 primes below 323, the square root of 104743, so only 66 filters were really needed. All the 9934 others will be there needlessly, taking up memory, hard at work producing absolutely no added value.
This is the key deficiency of that "functional sieve", which seems to have originated in the 1970s code by David Turner, and later have found its way into the SICP book and other places. It is not that it's a trial division sieve (rather than the sieve of Eratosthenes). That's far too remote a concern for it. Trial division, when optimally implemented, is perfectly capable of producing the 10000th prime very fast.
The key deficiency of that code is that it does not postpone the creation of filters to the right moment, and ends up creating far too many of them.
Talking complexities now, the "old sieve" code is O(n^2), in n primes produced. The optimal trial division is O(n^1.5 / log^0.5(n)), and the sieve of Eratosthenes is O(n*log(n)*log(log(n))). As empirical orders of growth the first is seen typically as ~ n^2, the second as ~ n^1.45 and the third ~ n^1.2.
You can find Python generators-based code for optimal trial division implemented in this answer (2nd half of it). It was originally discussed here dealing with the Haskell equivalent of your sieve function.
Just as an illustration, a "readable pseudocode" :) for the old sieve is
primes = sieve [2..] where
sieve (x:xs) = x : sieve [ y | y <- xs, rem y x > 0 ]
-- list of 'y's, drawn from 'xs',
-- such that (y % x > 0)
and for optimal trial division (TD) sieve, synchronized on primes' squares,
primes = sieve [2..] primes where
sieve (x:xs) ps = x : (h ++ sieve [ y | y <- t, rem y p > 0 ] qs)
where
(p:qs) = ps -- 'p' is head elt in 'ps', and 'qs' the rest
(h,t) = span (< p*p) xs -- 'h' are elts below p^2 in 'xs'
-- and 't' are the rest
and for a sieve of Eratosthenes, devised by Richard Bird, as seen in that JFP article mentioned in another answer here,
primes = 2 : minus [3..]
(foldr (\p r-> p*p : union [p*p+p, p*p+2*p..] r) [] primes)
-- function of 'p' and 'r', that returns
-- a list with p^2 as its head elt, ...
Short and fast. (minus a b is a list a with all the elts of b progressively removed from it; union a b is a list a with all the elts of b progressively added to it without duplicates; both deal with ordered, non-decreasing lists. foldr is the right fold of a list. Because it is linear, this runs at ~ n^1.33; to make it run at ~ n^1.2, the tree-like folding function foldi can be used.)
The answer to your second question is also a yes. Your second code, re-written in same "pseudocode",
ps = 2 : [i | i <- [3..], all ((> 0).rem i) (takeWhile ((<= i).(^2)) ps)]
is very similar to the optimal TD sieve above - both arrange for each candidate to be tested by all primes below its square root. While the sieve arranges that with a run-time sequence of postponed filters, the latter definition re-fetches the needed primes anew for each candidate. One might be faster than another depending on a compiler, but both are essentially the same.
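For readers who would rather see the postponed-filter idea in Scala than in pseudocode, here is a rough sketch. It is only a guessed translation of the optimal TD pseudocode above, not code from the original answer: the names PostponedSieve and sieve are made up for illustration, and it uses the same Stream type as the question (deprecated in favour of LazyList from Scala 2.13 on).

```scala
object PostponedSieve {
  // Memoized stream of primes; it also feeds itself the base primes.
  lazy val primes: Stream[Int] = 2 #:: sieve(Stream.from(3), primes)

  // xs: remaining candidates, ps: base primes whose filters are not yet installed.
  private def sieve(xs: Stream[Int], ps: Stream[Int]): Stream[Int] = {
    val p = ps.head
    // Candidates below p*p have already passed every filter they need.
    val (h, t) = xs.span(_ < p * p)
    // The filter for p is created only now, when p*p has been reached.
    h #::: sieve(t.filter(_ % p != 0), ps.tail)
  }
}
```

With this arrangement, asking for PostponedSieve.primes(10000) should install only the filters for primes up to the square root of the result, rather than one filter per prime produced.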
And the third is also a yes: the sieve of Eratosthenes is better,
ps = 2 : 3 : minus [5,7..] (unionAll [[p*p, p*p+2*p..] | p <- drop 1 ps])
unionAll = foldi union' [] -- one possible implementation
union' (x:xs) ys = x : union xs ys
-- unconditionally produce first elt of the 1st arg
-- to avoid run-away access to infinite lists
It looks like it can be implemented in Scala too, judging by the similarity of the other code snippets (though I don't know Scala). unionAll here implements a tree-like folding structure, but could also be implemented with a sliding array, working segment by segment along the streams of primes' multiples.
TL;DR: yes, yes, and yes.

FWIW, here's a real Sieve of Eratosthenes:
def sieve(n: Int) = (2 to math.sqrt(n).toInt).foldLeft((2 to n).toSet) { (ps, x) =>
if (ps(x)) ps -- (x * x to n by x)
else ps
}
Here's an infinite stream of primes using a variation on the Sieve of Eratosthenes that preserves its fundamental properties:
case class Cross(next: Int, incr: Int)
def adjustCrosses(crosses: List[Cross], current: Int) = {
crosses map {
case cross @ Cross(`current`, incr) => cross copy (next = current + incr)
case unchangedCross => unchangedCross
}
}
def notPrime(crosses: List[Cross], current: Int) = crosses exists (_.next == current)
def sieve(s: Stream[Int], crosses: List[Cross]): Stream[Int] = {
val current #:: rest = s
if (notPrime(crosses, current)) sieve(rest, adjustCrosses(crosses, current))
else current #:: sieve(rest, Cross(current * current, current) :: crosses)
}
def primes = sieve(Stream from 2, Nil)
This is somewhat difficult to use, however, since each element of the Stream is composed using the crosses list, which has as many numbers as there have been primes up to a number, and it seems that, for some reason, these lists are being kept in memory for each number in the Stream.
For example, prompted by a comment, primes take 6000 contains 56993 would throw a GC exception whereas primes drop 5000 take 1000 contains 56993 would return a result rather fast on my tests.

Related

How to write an efficient groupBy-size filter in Scala, can be approximate

Given a List[Int] in Scala, I wish to get the Set[Int] of all Ints which appear at least thresh times. I can do this using groupBy or foldLeft, then filter. For example:
val thresh = 3
val myList = List(1,2,3,2,1,4,3,2,1)
myList.foldLeft(Map[Int,Int]()){case(m, i) => m + (i -> (m.getOrElse(i, 0) + 1))}.filter(_._2 >= thresh).keys
will give Set(1,2).
Now suppose the List[Int] is very large. How large it's hard to say but in any case this seems wasteful as I don't care about each of the Ints frequencies, and I only care if they're at least thresh. Once it passed thresh there's no need to check anymore, just add the Int to the Set[Int].
The question is: can I do this more efficiently for a very large List[Int],
a) if I need a true, accurate result (no room for mistakes)
b) if the result can be approximate, e.g. by using some Hashing trick or Bloom Filters, where Set[Int] might include some false-positives, or whether {the frequency of an Int > thresh} isn't really a Boolean but a Double in [0-1].
First of all, you can't do better than O(N), as you need to check each element of your initial array at least once. Your current approach is O(N), presuming that operations with IntMap are effectively constant.
Now what you can try in order to increase efficiency:
update the map only when the current counter value is less than or equal to the threshold. This eliminates a huge number of the most expensive operations: map updates
try a faster map instead of IntMap. If you know that the values of the initial List are in a fixed range, you can use an Array instead of IntMap (with the index as the key). Another possible option is a mutable HashMap with sufficient initial capacity. As my benchmark shows, it actually makes a significant difference
As @ixx proposed, after incrementing the value in the map, check whether it's equal to 3 and in that case add it immediately to the result list. This will save you one linear traversal (which appears to be not that significant for large input)
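The fixed-range Array suggestion can be sketched as follows. This is a minimal illustration, not the benchmarked code: the object name ArrayThreshold and the rangeSize parameter are hypothetical, and it assumes every value lies in 0 until rangeSize and that thresh is at least 1.

```scala
object ArrayThreshold {
  // Assumes 0 <= value < rangeSize for every element of xs, and thresh >= 1.
  def atLeast(xs: Seq[Int], thresh: Int, rangeSize: Int): Set[Int] = {
    val counts = new Array[Int](rangeSize) // index doubles as the key
    val result = Set.newBuilder[Int]
    xs.foreach { i =>
      // Stop counting once the threshold is reached: no further writes.
      if (counts(i) < thresh) {
        counts(i) += 1
        if (counts(i) == thresh) result += i
      }
    }
    result.result()
  }
}
```

For the list from the question, ArrayThreshold.atLeast(List(1,2,3,2,1,4,3,2,1), 3, 5) gives Set(1, 2).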
I don't see how any approximate solution can be faster (only if you ignore some elements at random). Otherwise it will still be O(N).
Update
I created a microbenchmark to measure the actual performance of different implementations. For sufficiently large input and output, Ixx's suggestion of immediately adding elements to the result list doesn't produce a significant improvement. However, a similar approach can be used to eliminate unnecessary Map updates (which appear to be the most expensive operation).
Results of benchmarks (avg run times on 1000000 elems with pre-warming):
Author's solution: 447 ms
Ixx solution: 412 ms
Ixx solution2 (eliminated excessive map writes): 150 ms
My solution: 57 ms
My solution involves using mutable HashMap instead of immutable IntMap and includes all other possible optimizations.
Ixx's updated solution:
val tuple = (Map[Int, Int](), List[Int]())
val res = myList.foldLeft(tuple) {
case ((m, s), i) =>
val count = m.getOrElse(i, 0) + 1
(if (count <= thresh) m + (i -> count) else m, if (count == thresh) i :: s else s)
}
My solution:
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
val map = new mutable.HashMap[Int, Int]()
val res = new ListBuffer[Int]
myList.foreach {
i =>
val c = map.getOrElse(i, 0) + 1
if (c == thresh) {
res += i
}
if (c <= thresh) {
map(i) = c
}
}
The full microbenchmark source is available here.
You could use foldLeft to collect the matching items, like this:
val tuple = (Map[Int,Int](), List[Int]())
myList.foldLeft(tuple) {
case((m, s), i) => {
val count = (m.getOrElse(i, 0) + 1)
(m + (i -> count), if (count == thresh) i :: s else s)
}
}
I could measure a performance improvement of about 40% with a small list, so it's definitely an improvement...
Edited to use List and prepend, which takes constant time (see comments).
If by "more efficiently" you mean the space efficiency (in extreme case when the list is infinite), there's a probabilistic data structure called Count Min Sketch to estimate the frequency of items inside it. Then you can discard those with frequency below your threshold.
There's a Scala implementation from Algebird library.
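To give a feel for the data structure, here is a hand-rolled toy Count-Min Sketch for Int keys. It is not Algebird's API; the class and object names, the hash scheme and the default sizes are all assumptions chosen for illustration. The one property it relies on is that the estimate never falls below the true count, so the resulting set can contain false positives but never misses a qualifying value.

```scala
import scala.util.Random

// Toy Count-Min Sketch (illustrative only, not the Algebird implementation).
final class CountMinSketch(depth: Int, width: Int, seed: Long = 42L) {
  private val prime = 2147483647L // 2^31 - 1
  private val rng = new Random(seed)
  // One (a, b) pair per row for h(x) = ((a * x + b) mod prime) mod width.
  private val coeffs: Vector[(Long, Long)] =
    Vector.fill(depth)((1L + rng.nextInt(Int.MaxValue - 1), rng.nextInt(Int.MaxValue).toLong))
  private val table = Array.ofDim[Int](depth, width)

  private def bucket(row: Int, item: Int): Int = {
    val (a, b) = coeffs(row)
    val h = ((a * item + b) % prime + prime) % prime // keep it non-negative
    (h % width).toInt
  }

  def add(item: Int): Unit =
    for (row <- 0 until depth) table(row)(bucket(row, item)) += 1

  // Minimum over rows; collisions can only inflate a counter.
  def estimate(item: Int): Int =
    (0 until depth).map(row => table(row)(bucket(row, item))).min
}

object ApproxThreshold {
  // Two passes: fill the sketch, then keep values whose estimate reaches thresh.
  def atLeast(xs: Seq[Int], thresh: Int): Set[Int] = {
    val sketch = new CountMinSketch(depth = 4, width = 1024)
    xs.foreach(sketch.add)
    xs.iterator.filter(x => sketch.estimate(x) >= thresh).toSet
  }
}
```

The sketch itself uses a fixed depth * width counters regardless of how many distinct values appear, which is where the space saving would come from on a very large input.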
You can change your foldLeft example a bit by using a mutable.Set that is built incrementally and at the same time used as a filter for iterating over your Seq via withFilter. However, because I'm using withFilter I cannot use foldLeft and have to make do with foreach and a mutable map:
import scala.collection.mutable
def getItems[A](in: Seq[A], threshold: Int): Set[A] = {
val counts: mutable.Map[A, Int] = mutable.Map.empty
val result: mutable.Set[A] = mutable.Set.empty
in.withFilter(!result(_)).foreach { x =>
counts.update(x, counts.getOrElse(x, 0) + 1)
if (counts(x) >= threshold) {
result += x
}
}
result.toSet
}
So, this discards items that have already been added to the result set while running through the Seq the first time, because withFilter filters the Seq in the appended function (map, flatMap, foreach) rather than returning a filtered Seq.
EDIT:
I changed my solution to not use Seq.count, which was stupid, as Aivean correctly pointed out.
Using Aivean's microbenchmark I can see that it is still slightly slower than his approach, but still better than the author's first approach.
Author's solution: 377
Ixx solution: 399
Ixx solution2 (eliminated excessive map writes): 110
Sascha Kolberg's solution: 72
Aivean solution: 54

Scala, Eratosthenes: Is there a straightforward way to replace a stream with an iteration?

I wrote a function that generates primes indefinitely (wikipedia: incremental sieve of Eratosthenes) using streams. It returns a stream, but it also merges streams of prime multiples internally to mark upcoming composites. The definition is concise, functional, elegant and easy to understand, if I do say so myself:
def primes(): Stream[Int] = {
def merge(a: Stream[Int], b: Stream[Int]): Stream[Int] = {
def next = a.head min b.head
Stream.cons(next, merge(if (a.head == next) a.tail else a,
if (b.head == next) b.tail else b))
}
def test(n: Int, compositeStream: Stream[Int]): Stream[Int] = {
if (n == compositeStream.head) test(n+1, compositeStream.tail)
else Stream.cons(n, test(n+1, merge(compositeStream, Stream.from(n*n, n))))
}
test(2, Stream.from(4, 2))
}
But, I get a "java.lang.OutOfMemoryError: GC overhead limit exceeded" when I try to generate the 1000th prime.
I have an alternative solution that returns an iterator over primes and uses a priority queue of tuples (multiple, prime used to generate multiple) internally to mark upcoming composites. It works well, but it takes about twice as much code, and I basically had to restart from scratch:
import scala.collection.mutable.PriorityQueue
def primes(): Iterator[Int] = {
// Tuple (composite, prime) is used to generate a prime's multiples
object CompositeGeneratorOrdering extends Ordering[(Long, Int)] {
def compare(a: (Long, Int), b: (Long, Int)) = b._1 compare a._1
}
var n = 2;
val composites = PriorityQueue(((n*n).toLong, n))(CompositeGeneratorOrdering)
def advance = {
while (n == composites.head._1) { // n is composite
while (n == composites.head._1) { // duplicate composites
val (multiple, prime) = composites.dequeue
composites.enqueue((multiple + prime, prime))
}
n += 1
}
assert(n < composites.head._1)
val prime = n
n += 1
composites.enqueue((prime.toLong * prime.toLong, prime))
prime
}
Iterator.continually(advance)
}
Is there a straightforward way to translate the code with streams to code with iterators? Or is there a simple way to make my first attempt more memory efficient?
It's easier to think in terms of streams; I'd rather start that way, then tweak my code if necessary.
I guess it's a bug in the current Stream implementation.
primes().drop(999).head works fine:
primes().drop(999).head
// Int = 7919
You'll get OutOfMemoryError with stored Stream like this:
val prs = primes()
prs.drop(999).head
// Exception in thread "main" java.lang.OutOfMemoryError: GC overhead limit exceeded
The problem here is with the implementation of class Cons: it contains not only the calculated tail, but also the function to calculate this tail, even when the tail has already been calculated and the function is not needed any more!
In this case the functions are extremely heavy, so you'll get an OutOfMemoryError even with 1000 functions stored.
We have to drop those functions somehow.
The intuitive fix fails:
val prs = primes().iterator.toStream
prs.drop(999).head
// Exception in thread "main" java.lang.OutOfMemoryError: GC overhead limit exceeded
With iterator on a Stream you'll get a StreamIterator, and with StreamIterator#toStream you'll get the initial heavy Stream back.
Workaround
So we have to convert it manually:
def toNewStream[T](i: Iterator[T]): Stream[T] =
if (i.hasNext) Stream.cons(i.next, toNewStream(i))
else Stream.empty
val prs = toNewStream(primes().iterator)
// Stream[Int] = Stream(2, ?)
prs.drop(999).head
// Int = 7919
In your first code, you should postpone the merging until the square of a prime is seen amongst the candidates. This will drastically reduce the number of streams in use, radically improving your memory usage issues. To get the 1000th prime, 7919, we only need to consider primes not above its square root, 88. That's just 23 primes/streams of their multiples, instead of 999 (22, if we ignore the evens from the outset). For the 10,000th prime, it's the difference between having 9999 streams of multiples and just 66. And for the 100,000th, only 189 are needed.
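The stream counts quoted above can be checked with a few lines of Scala. This is just a verification sketch; the object and method names are invented here, and it uses naive trial division because the inputs are tiny.

```scala
object StreamsNeeded {
  private def isPrime(k: Int): Boolean =
    k >= 2 && (2 to math.sqrt(k.toDouble).toInt).forall(k % _ != 0)

  // Number of primes p with p * p <= n, i.e. the streams of multiples
  // actually needed to sieve up to n.
  def basePrimes(n: Int): Int =
    (2 to math.sqrt(n.toDouble).toInt).count(isPrime)
}
```

StreamsNeeded.basePrimes(7919) gives 23, basePrimes(104729) gives 66 (104729 being the 10,000th prime) and basePrimes(1299709) gives 189.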
The trick is to separate the primes being consumed from the primes being produced, via a recursive invocation:
def primes(): Stream[Int] = {
def merge(a: Stream[Int], b: Stream[Int]): Stream[Int] = {
def next = a.head min b.head
Stream.cons(next, merge(if (a.head == next) a.tail else a,
if (b.head == next) b.tail else b))
}
def test(n: Int, q: Int,
compositeStream: Stream[Int],
primesStream: Stream[Int]): Stream[Int] = {
if (n == q) test(n+2, primesStream.tail.head*primesStream.tail.head,
merge(compositeStream,
Stream.from(q, 2*primesStream.head).tail),
primesStream.tail)
else if (n == compositeStream.head) test(n+2, q, compositeStream.tail,
primesStream)
else Stream.cons(n, test(n+2, q, compositeStream, primesStream))
}
Stream.cons(2, Stream.cons(3, Stream.cons(5,
test(7, 25, Stream.from(9, 6), primes().tail.tail))))
}
As an added bonus, there's no need to store the squares of primes as Longs. This will also be much faster and have better algorithmic complexity (time and space) as it avoids doing a lot of superfluous work. Ideone testing shows it runs at about ~ n^1.5..1.6 empirical orders of growth in producing up to n = 80,000 primes.
There's still an algorithmic problem here: the structure that is created here is still a linear left-leaning structure (((mults_of_2 + mults_of_3) + mults_of_5) + ...), with more frequently-producing streams situated deeper inside it (so the numbers have more levels to percolate through, going up). The right-leaning structure should be better, mults_of_2 + (mults_of_3 + (mults_of_5 + ...)). Making it a tree should bring a real improvement in time complexity (pushing it down typically to about ~ n^1.2..1.25). For a related discussion, see this haskellwiki page.
The "real" imperative sieve of Eratosthenes usually runs at around ~ n^1.1 (in n primes produced) and an optimal trial division sieve at ~ n^1.40..1.45. Your original code runs at about cubic time, or worse. Using an imperative mutable array is usually the fastest, working by segments (a.k.a. the segmented sieve of Eratosthenes).
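For comparison, a plain (non-segmented) mutable-array sieve might look like the sketch below. It is bounded rather than incremental, the name ArraySieve is made up, and it assumes limit is comfortably below Int.MaxValue so that neither the array size nor the running multiple overflows.

```scala
object ArraySieve {
  // All primes <= limit, using a mutable Boolean array of "composite" marks.
  def primesUpTo(limit: Int): Vector[Int] =
    if (limit < 2) Vector.empty
    else {
      val composite = new Array[Boolean](limit + 1)
      var p = 2
      while (p.toLong * p <= limit) {
        if (!composite(p)) {
          var m = p * p // smaller multiples were marked by smaller primes
          while (m <= limit) { composite(m) = true; m += p }
        }
        p += 1
      }
      (2 to limit).filterNot(i => composite(i)).toVector
    }
}
```

A segmented version would apply the same marking loop to one fixed-size window at a time, keeping only the base primes up to the square root of the current upper bound.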
In the context of your second code, this is how it is achieved in Python.
Is there a straightforward way to translate the code with streams to code with iterators? Or is there a simple way to make my first attempt more memory efficient?
@Will Ness has given you an improved answer using Streams, and has explained why your code takes so much memory and time (it adds streams too early and builds a left-leaning linear structure), but no one has completely answered the second (or perhaps main) part of your question: can a true incremental Sieve of Eratosthenes be implemented with Iterators?
First, we should properly credit this right-leaning algorithm, of which your first code is a crude (left-leaning) example (since it prematurely adds all prime composite streams to the merge operations). It is due to Richard Bird, as presented in the Epilogue of Melissa E. O'Neill's definitive paper on incremental Sieves of Eratosthenes.
Second, no, it isn't really possible to substitute Iterators for Streams in this algorithm, as it depends on moving through a stream without restarting the stream. Although one can access the head of an iterator (the current position), using the next value (skipping over the head) to generate the rest of the iteration as a stream requires building a completely new iterator, at a terrible cost in memory and time. However, we can use an Iterator to output the results of the sequence of primes in order to minimize memory use and make it easy to use iterator higher-order functions, as you will see in my code below.
Now, Will Ness has walked you through the principles of postponing adding prime composite streams to the calculations until they are needed. This works well when one is storing these in a structure such as a Priority Queue or a HashMap, and was even missed in the O'Neill paper. For the Richard Bird algorithm, however, this is not necessary, as future stream values will not be accessed until needed and so are not stored if the Streams are being properly lazily built (as is lazily and left-leaning). In fact, this algorithm doesn't even need the memoization and overheads of a full Stream, as each composite number culling sequence only moves forward without reference to any past primes, other than that one needs a separate source of the base primes, which can be supplied by a recursive call of the same algorithm.
For ready reference, let's list the Haskell code of the Richard Bird algorithms as follows:
primes = 2:([3..] `minus` composites)
  where
    composites = union [multiples p | p <- primes]
multiples n = map (n*) [n..]
(x:xs) `minus` (y:ys)
  | x < y = x:(xs `minus` (y:ys))
  | x == y = xs `minus` ys
  | x > y = (x:xs) `minus` ys
union = foldr merge []
  where
    merge (x:xs) ys = x:merge' xs ys
    merge' (x:xs) (y:ys)
      | x < y = x:merge' xs (y:ys)
      | x == y = x:merge' xs ys
      | x > y = y:merge' (x:xs) ys
In the following code I have simplified the 'minus' function (called "minusStrtAt") as we don't need to build a completely new stream but can incorporate the composite subtraction operation with the generation of the original (in my case odds only) sequence. I have also simplified the "union" function (renaming it as "mrgMltpls")
The stream operations are implemented as a non memoizing generic Co Inductive Stream (CIS) as a generic class where the first field of the class is the value of the current position of the stream and the second is a thunk (a zero argument function that returns the next value of the stream through embedded closure arguments to another function).
def primes(): Iterator[Long] = {
// generic class as a Co Inductive Stream element
class CIS[A](val v: A, val cont: () => CIS[A])
def mltpls(p: Long): CIS[Long] = {
val px2 = p * 2
def nxtmltpl(cmpst: Long): CIS[Long] =
new CIS(cmpst, () => nxtmltpl(cmpst + px2))
nxtmltpl(p * p)
}
def allMltpls(mps: CIS[Long]): CIS[CIS[Long]] =
new CIS(mltpls(mps.v), () => allMltpls(mps.cont()))
def merge(a: CIS[Long], b: CIS[Long]): CIS[Long] =
if (a.v < b.v) new CIS(a.v, () => merge(a.cont(), b))
else if (a.v > b.v) new CIS(b.v, () => merge(a, b.cont()))
else new CIS(b.v, () => merge(a.cont(), b.cont()))
def mrgMltpls(mlps: CIS[CIS[Long]]): CIS[Long] =
new CIS(mlps.v.v, () => merge(mlps.v.cont(), mrgMltpls(mlps.cont())))
def minusStrtAt(n: Long, cmpsts: CIS[Long]): CIS[Long] =
if (n < cmpsts.v) new CIS(n, () => minusStrtAt(n + 2, cmpsts))
else minusStrtAt(n + 2, cmpsts.cont())
// the following are recursive, where cmpsts uses oddPrms and
// oddPrms uses a delayed version of cmpsts in order to avoid a race
// as oddPrms will already have a first value when cmpsts is called to generate the second
def cmpsts(): CIS[Long] = mrgMltpls(allMltpls(oddPrms()))
def oddPrms(): CIS[Long] = new CIS(3, () => minusStrtAt(5L, cmpsts()))
Iterator.iterate(new CIS(2L, () => oddPrms()))
{(cis: CIS[Long]) => cis.cont()}
.map {(cis: CIS[Long]) => cis.v}
}
The above code generates the 100,000th prime (1299709) on ideone in about 1.3 seconds with about a 0.36 second overhead and has an empirical order of growth of about n^1.43 up to 600,000 primes. The memory use is negligible above that used by the program code.
The above code could be implemented using the built-in Scala Streams, but there is a performance and memory use overhead (of a constant factor) that this algorithm does not require. Using Streams would mean that one could use them directly without the extra Iterator generation code, but as this is used only for final output of the sequence, it doesn't cost much.
To implement some basic tree folding as Will Ness has suggested, one only needs to add a "pairs" function and hook it into the "mrgMltpls" function:
def primes(): Iterator[Long] = {
// generic class as a Co Inductive Stream element
class CIS[A](val v: A, val cont: () => CIS[A])
def mltpls(p: Long): CIS[Long] = {
val px2 = p * 2
def nxtmltpl(cmpst: Long): CIS[Long] =
new CIS(cmpst, () => nxtmltpl(cmpst + px2))
nxtmltpl(p * p)
}
def allMltpls(mps: CIS[Long]): CIS[CIS[Long]] =
new CIS(mltpls(mps.v), () => allMltpls(mps.cont()))
def merge(a: CIS[Long], b: CIS[Long]): CIS[Long] =
if (a.v < b.v) new CIS(a.v, () => merge(a.cont(), b))
else if (a.v > b.v) new CIS(b.v, () => merge(a, b.cont()))
else new CIS(b.v, () => merge(a.cont(), b.cont()))
def pairs(mltplss: CIS[CIS[Long]]): CIS[CIS[Long]] = {
val tl = mltplss.cont()
new CIS(merge(mltplss.v, tl.v), () => pairs(tl.cont()))
}
def mrgMltpls(mlps: CIS[CIS[Long]]): CIS[Long] =
new CIS(mlps.v.v, () => merge(mlps.v.cont(), mrgMltpls(pairs(mlps.cont()))))
def minusStrtAt(n: Long, cmpsts: CIS[Long]): CIS[Long] =
if (n < cmpsts.v) new CIS(n, () => minusStrtAt(n + 2, cmpsts))
else minusStrtAt(n + 2, cmpsts.cont())
// the following are recursive, where cmpsts uses oddPrms and
// oddPrms uses a delayed version of cmpsts in order to avoid a race
// as oddPrms will already have a first value when cmpsts is called to generate the second
def cmpsts(): CIS[Long] = mrgMltpls(allMltpls(oddPrms()))
def oddPrms(): CIS[Long] = new CIS(3, () => minusStrtAt(5L, cmpsts()))
Iterator.iterate(new CIS(2L, () => oddPrms()))
{(cis: CIS[Long]) => cis.cont()}
.map {(cis: CIS[Long]) => cis.v}
}
The above code generates the 100,000th prime (1299709) on ideone in about 0.75 seconds with about a 0.37 second overhead and has an empirical order of growth up to the 1,000,000th prime (15485863) of about n^1.09 (5.13 seconds). The memory use is negligible above that used by the program code.
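An empirical order of growth like this is usually estimated from two timing points as log(t2/t1) / log(n2/n1). The helper below is a small sketch of that calculation; the name GrowthOrder is invented, and subtracting the quoted overhead from both timings is my assumption about how the figure was derived, not something the answer states.

```scala
object GrowthOrder {
  // Exponent k in t ~ n^k, estimated from two (n, t) measurements.
  def empiricalOrder(n1: Double, t1: Double, n2: Double, t2: Double): Double =
    math.log(t2 / t1) / math.log(n2 / n1)
}
```

With the timings quoted above and the 0.37 second overhead removed, GrowthOrder.empiricalOrder(1e5, 0.75 - 0.37, 1e6, 5.13 - 0.37) comes out close to 1.1, in line with the stated figure.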
Note that the above codes are completely functional in that there is no mutable state used whatsoever, but that the Bird algorithm (or even the tree folding version) isn't as fast as using a Priority Queue or HashMap for larger ranges as the number of operations to handle the tree merging has a higher computational complexity than the log n overhead of the Priority Queue or the linear (amortized) performance of a HashMap (although there is a large constant factor overhead to handle the hashing so that advantage isn't really seen until some truly large ranges are used).
The reason that these codes use so little memory is that the CIS streams are formulated with no permanent reference to the start of the streams so that the streams are garbage collected as they are used, leaving only the minimal number of base prime composite sequence place holders, which as Will Ness has explained is very small - only 546 base prime composite number streams for generating the first million primes up to 15485863, each placeholder only taking a few 10's of bytes (eight for the Long number, eight for the 64-bit function reference, with another couple of eight bytes for the pointer to the closure arguments and another few bytes for function and class overheads, for a total per stream placeholder of perhaps 40 bytes, or a total of not much more than 20 Kilobytes for generating the sequence for a million primes).
If you just want an infinite stream of primes, this is the most elegant way in my opinion:
def primes = {
def sieve(from : Stream[Int]): Stream[Int] = from.head #:: sieve(from.tail.filter(_ % from.head != 0))
sieve(Stream.from(2))
}

Fibonacci Sequence fast implementation

I have written this function in Scala to calculate the Fibonacci number given a particular index n:
def fibonacci(n: Long): Long = {
if(n <= 1) n
else
fibonacci(n - 1) + fibonacci(n - 2)
}
However it is not efficient when calculating with large indexes. Therefore I need to implement a function using a tuple and this function should return two consecutive values as the result.
Can somebody give me any hints about this? I have never used Scala before. Thanks!
This question should maybe go to Mathematics.
There is an explicit formula for the Fibonacci sequence. If you need to calculate the Fibonacci number for n without the previous ones, this is much faster. You find it here (Binet's formula): http://en.wikipedia.org/wiki/Fibonacci_number
Here's a simple tail-recursive solution:
def fibonacci(n: Long): Long = {
def fib(i: Long, x: Long, y: Long): Long = {
if (i > 0) fib(i-1, x+y, x)
else x
}
fib(n, 0, 1)
}
The solution you posted takes exponential time since it creates two recursive invocation trees (fibonacci(n - 1) and fibonacci(n - 2)) at each step. By simply tracking the last two numbers, you can recursively compute the answer without any repeated computation.
Can you explain the middle part, why (i-1, x+y, x) etc. Sorry if I am asking too much but I hate to copy and paste code without knowing how it works.
It's pretty simple—but my poor choice of variable names might have made it confusing.
i is simply a counter saying how many steps we have left. If we're calculating the Mth (I'm using M since I already used n in my code) Fibonacci number, then i tells us how many more terms we have left to calculate before we reach the Mth term.
x is the mth term in the Fibonacci sequence, or F_m (where m = M - i).
y is the (m-1)th term in the Fibonacci sequence, or F_(m-1).
So, on the first call fib(n, 0, 1), we have i=M, x=0, y=1. If you look up the bidirectional Fibonacci sequence, you'll see that F_0 = 0 and F_(-1) = 1, which is why x=0 and y=1 here.
On the next recursive call, fib(i-1, x+y, x), we pass x+y as our next x value. This comes straight from the definition:
F_n = F_(n-1) + F_(n-2)
We pass x as the next y term, since our current F_(n-1) is the same as F_(n-2) for the next term.
On each step we decrement i since we're one step closer to the final answer.
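Since the question asks for a function that works on a tuple and returns two consecutive values, the same idea can be written as in the sketch below. The name FibPair and the use of BigInt (to avoid Long overflow for large n) are choices made for this illustration, and n is assumed to be non-negative.

```scala
object FibPair {
  // Returns (F(n), F(n+1)); assumes n >= 0.
  def fibPair(n: Int): (BigInt, BigInt) = {
    @scala.annotation.tailrec
    def loop(i: Int, pair: (BigInt, BigInt)): (BigInt, BigInt) =
      if (i == 0) pair
      else loop(i - 1, (pair._2, pair._1 + pair._2)) // (F(k), F(k+1)) -> (F(k+1), F(k+2))
    loop(n, (BigInt(0), BigInt(1)))
  }
}
```

FibPair.fibPair(10) gives (55, 89), so the Fibonacci number itself is the first component.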
I am assuming that you don't have saved values from previous computations. If so, it will be faster for you to use the direct formula using the golden ratio instead of the recursive definition. The formula can be found in the Wikipedia page for Fibonacci number:
floor(pow(phi, n)/root_of_5 + 0.5)
where phi = (1 + sqrt(5))/2.
I have no knowledge of programming in Scala. I am hoping someone on SO will upgrade my pseudo-code to actual Scala code.
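One possible Scala rendering of that pseudo-code is sketched below. The object name Binet is made up, and because it relies on Double arithmetic it should only be trusted for moderately small n; beyond that, rounding error makes the result inexact.

```scala
object Binet {
  private val sqrt5 = math.sqrt(5.0)
  private val phi = (1.0 + sqrt5) / 2.0

  // Closed-form Fibonacci via rounding; exact only while Double precision suffices.
  def fib(n: Int): Long =
    math.floor(math.pow(phi, n.toDouble) / sqrt5 + 0.5).toLong
}
```

For example, Binet.fib(10) gives 55 and Binet.fib(30) gives 832040.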
Update
Here's another solution again using Streams as below (getting Memoization for free) but a bit more intuitive (aka: without using zip/tail invocation on fibs Stream):
val fibs = Stream.iterate( (0,1) ) { case (a,b)=>(b,a+b) }.map(_._1)
that yields the same output as below for:
fibs take 5 foreach println
Scala supports Memoizations through Streams that is an implementation of lazy lists. This is a perfect fit for Fibonacci implementation which is actually provided as an example in the Scala Api for Streams. Quoting here:
import scala.math.BigInt
object Main extends App {
val fibs: Stream[BigInt] = BigInt(0) #:: BigInt(1) #:: fibs.zip(fibs.tail).map { n => n._1 + n._2 }
fibs take 5 foreach println
}
// prints
//
// 0
// 1
// 1
// 2
// 3

Generate a DAG from a poset using strictly functional programming

Here is my problem: I have a sequence S of (nonempty but possibly not distinct) sets s_i, and for each s_i need to know how many sets s_j in S (i ≠ j) are subsets of s_i.
I also need incremental performance: once I have all my counts, I may replace one set s_i by some subset of s_i and update the counts incrementally.
Performing all this using purely functional code would be a huge plus (I code in Scala).
As set inclusion is a partial ordering, I thought the best way to solve my problem would be to build a DAG that would represent the Hasse diagram of the sets, with edges representing inclusion, and join an integer value to each node representing the size of the sub-dag below the node plus 1. However, I have been stuck for several days trying to develop the algorithm that builds the Hasse diagram from the partial ordering (let's not talk about incrementality!), even though I thought it would be some standard undergraduate material.
Here is my data structure :
case class HNode[A] (
val v: A,
val child: List[HNode[A]]) {
val rank = 1 + child.map(_.rank).sum
}
My DAG is defined by a list of roots and some partial ordering:
class Hasse[A](val po: PartialOrdering[A], val roots: List[HNode[A]]) {
def +(v: A): Hasse[A] = new Hasse[A](po, add(v, roots))
private def collect(v: A, roots: List[HNode[A]], collected: List[HNode[A]]): List[HNode[A]] =
if (roots == Nil) collected
else {
val (subsets, remaining) = roots.partition(r => po.lteq(r.v, v))
collect(v, remaining.map(_.child).flatten, subsets.filter(r => !collected.exists(c => po.lteq(r.v, c.v))) ::: collected)
}
}
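The Hasse class above takes a PartialOrdering[A]. For sets ordered by inclusion, such an ordering could be sketched as follows. The name SubsetOrdering is an assumption made for this example and is not part of the original code.

```scala
object SubsetOrdering {
  // Partial order on sets: x <= y iff x is a subset of y.
  def apply[A]: PartialOrdering[Set[A]] = new PartialOrdering[Set[A]] {
    def lteq(x: Set[A], y: Set[A]): Boolean = x.subsetOf(y)

    // None when neither set contains the other (incomparable elements).
    def tryCompare(x: Set[A], y: Set[A]): Option[Int] =
      if (x == y) Some(0)
      else if (x.subsetOf(y)) Some(-1)
      else if (y.subsetOf(x)) Some(1)
      else None
  }
}
```

With this, SubsetOrdering[Int].tryCompare(Set(1), Set(2)) is None, which is what distinguishes a partial ordering from a total one.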
I am pretty stuck here. The best approach I have come up with to add a new value v to the DAG is:
1. Find all "root subsets" rs_i of v in the DAG, i.e., subsets of v such that no superset of rs_i is a subset of v. This can be done quite easily by performing a search (BFS or DFS) on the graph (collect function, possibly non-optimal or even flawed).
2. Build the new node n_v, the children of which are the previously found rs_i.
3. Now, let's find out where n_v should be attached: for a given list of roots, find the supersets of v. If none are found, add n_v to the roots and remove subsets of n_v from the roots. Otherwise, perform step 3 recursively on the supersets' children.
I have not yet fully implemented this algorithm, but it seems unnecessarily convoluted and suboptimal for my apparently simple problem. Is there some simpler algorithm available (Google was clueless on this)?
After some work, I finally ended up solving my problem, following my initial intuition. The collect method and rank evaluation were flawed, I rewrote them with tail-recursion as a bonus. Here is the code I obtained:
final case class HNode[A](
val v: A,
val child: List[HNode[A]]) {
val rank: Int = 1 + count(child, Set.empty)
@tailrec // requires: import scala.annotation.tailrec
private def count(stack: List[HNode[A]], c: Set[HNode[A]]): Int =
if (stack == Nil) c.size
else {
val head :: rem = stack
if (c(head)) count(rem, c)
else count(head.child ::: rem, c + head)
}
}
// ...
private def add(v: A, roots: List[HNode[A]]): List[HNode[A]] = {
val newNode = HNode(v, collect(v, roots, Nil))
attach(newNode, roots)
}
private def attach(n: HNode[A], roots: List[HNode[A]]): List[HNode[A]] =
if (roots.contains(n)) roots
else {
val (supersets, remaining) = roots.partition { r =>
// Strict superset to avoid creating cycles in case of equal elements
po.tryCompare(n.v, r.v) == Some(-1)
}
if (supersets.isEmpty) n :: remaining.filter(r => !po.lteq(r.v, n.v))
else {
supersets.map(s => HNode(s.v, attach(n, s.child))) ::: remaining
}
}
@tailrec
private def collect(v: A, stack: List[HNode[A]], collected: List[HNode[A]]): List[HNode[A]] =
if (stack == Nil) collected
else {
val head :: tail = stack
if (collected.exists(c => po.lteq(head.v, c.v))) collect(v, tail, collected)
else if (po.lteq(head.v, v)) collect(v, tail, head :: (collected.filter(c => !po.lteq(c.v, head.v))))
else collect(v, head.child ::: tail, collected)
}
I must now check some optimizations:
- cut off branches with totally distinct sets when collecting subsets (as Rex Kerr suggested)
- see if sorting the sets by size improves the process (as mitchus suggested)
The next problem is to work out the (worst-case) complexity of the add() operation.
With n the number of sets, and d the size of the largest set, the complexity will probably be O(n²d), but I hope it can be refined. Here is my reasoning: if no set is a subset of another, the DAG will be reduced to a sequence of roots/leaves. Thus, every time I try to add a node to the data structure, I still have to check for inclusion with each node already present (both in the collect and attach procedures). Over n insertions, this leads to 1 + 2 + … + n = n(n+1)/2 ∈ O(n²) inclusion checks.
Each set inclusion test is O(d), hence the result.
Suppose your DAG G contains a node v for each set, with attributes v.s (the set) and v.count (the number of instances of the set), including a node G.root with G.root.s = union of all sets (where G.root.count=0 if this set never occurs in your collection).
Then to count the number of distinct subsets of s you could do the following (in a bastardized mixture of Scala, Python and pseudo-code):
sum(apply(lambda x: x.count, get_subsets(s, G.root)))
where
get_subsets(s, v) :
if(v.s is not a subset of s, {},
union({v} :: apply(v.children, lambda x: get_subsets(s, x))))
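One possible Scala reading of the pseudo-code above is sketched below. The Node type and the function names are hypothetical. As written, the pseudo-code stops at any node whose set is not a subset of s; this sketch assumes the intent is to keep descending into the children in that case, and to take a node together with all of its descendants once it is a subset of s.

```scala
object SubsetDag {
  // Hypothetical node: s is the set, count its number of instances,
  // children are the nodes directly below it (subsets of s).
  final case class Node[A](s: Set[A], count: Int, children: List[Node[A]])

  // All nodes reachable from v whose set is a subset of s.
  def getSubsets[A](s: Set[A], v: Node[A]): Set[Node[A]] =
    if (v.s.subsetOf(s)) descendants(v) // everything below v is a subset too
    else v.children.flatMap(getSubsets(s, _)).toSet

  // v together with every node below it; the Set removes shared nodes.
  private def descendants[A](v: Node[A]): Set[Node[A]] = {
    val below: Set[Node[A]] = v.children.flatMap(descendants(_)).toSet
    below + v
  }

  // Sum of the instance counts of all subsets of s found under root.
  def countSubsets[A](s: Set[A], root: Node[A]): Int =
    getSubsets(s, root).toList.map(_.count).sum
}
```

Like the pseudo-code, this also counts the node for s itself if it is present, and it revisits shared nodes, so it is meant to illustrate the idea rather than to be efficient.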
In my opinion though, for performance reasons you would be better off abandoning this kind of purely functional solution... it works well on lists and trees, but beyond that the going gets tough.

Is Scala idiomatic coding style just a cool trap for writing inefficient code?

I sense that the Scala community has a bit of an obsession with writing "concise", "cool", "scala idiomatic", "one-liner" (if possible) code. This is immediately followed by a comparison to Java/imperative/ugly code.
While this (sometimes) leads to easy to understand code, it also leads to inefficient code for 99% of developers. And this is where Java/C++ is not easy to beat.
Consider this simple problem: Given a list of integers, remove the greatest element. Ordering does not need to be preserved.
Here is my version of the solution (It may not be the greatest, but it's what the average non-rockstar developer would do).
def removeMaxCool(xs: List[Int]) = {
val maxIndex = xs.indexOf(xs.max);
xs.take(maxIndex) ::: xs.drop(maxIndex+1)
}
It's Scala idiomatic, concise, and uses a few nice list functions. It's also very inefficient. It traverses the list at least 3 or 4 times.
Here is my totally uncool, Java-like solution. It's also what a reasonable Java developer (or Scala novice) would write.
def removeMaxFast(xs: List[Int]) = {
var res = ArrayBuffer[Int]()
var max = xs.head
var first = true;
for (x <- xs) {
if (first) {
first = false;
} else {
if (x > max) {
res.append(max)
max = x
} else {
res.append(x)
}
}
}
res.toList
}
Totally non-Scala idiomatic, non-functional, non-concise, but it's very efficient. It traverses the list only once!
So, if 99% of Java developers write more efficient code than 99% of Scala developers, this is a huge obstacle to cross for greater Scala adoption. Is there a way out of this trap?
I am looking for practical advice to avoid such "inefficiency traps" while keeping the implementation clear and concise.
Clarification: This question comes from a real-life scenario: I had to write a complex algorithm. First I wrote it in Scala, then I "had to" rewrite it in Java. The Java implementation was twice as long, and not that clear, but at the same time it was twice as fast. Rewriting the Scala code to be efficient would probably take some time and a somewhat deeper understanding of scala internal efficiencies (for vs. map vs. fold, etc)
Let's discuss a fallacy in the question:
So, if 99% of Java developers write more efficient code than 99% of
Scala developers, this is a huge obstacle to cross for greater Scala
adoption. Is there a way out of this trap?
This is presumed, with absolutely no evidence backing it up. If false, the question is moot.
Is there evidence to the contrary? Well, let's consider the question itself -- it doesn't prove anything, but shows things are not that clear.
Totally non-Scala idiomatic, non-functional, non-concise, but it's
very efficient. It traverses the list only once!
Of the four claims in the first sentence, the first three are true, and the fourth, as shown by user unknown, is false! And why is it false? Because, contrary to what the second sentence states, it traverses the list more than once.
The code calls the following methods on it:
res.append(max)
res.append(x)
and
res.toList
Let's consider first append.
append takes a vararg parameter. That means max and x are first encapsulated into a sequence of some type (a WrappedArray, in fact), and then passed as parameter. A better method would have been +=.
OK, append calls ++=, which delegates to +=. But first it calls ensureSize, which is the second mistake (+= calls that too -- ++= just optimizes that for multiple elements). The reason is that an Array is a fixed-size collection, which means that, at each resize, the whole Array must be copied!
So let's consider this. When you resize, Java first clears the memory by storing 0 in each element, then Scala copies each element of the previous array over to the new array. Since size doubles each time, this happens log(n) times, with the number of elements being copied increasing each time it happens.
Take for example n = 16. It does this four times, copying 1, 2, 4 and 8 elements respectively. Since Java has to clear each of these arrays, and each element must be read and written, each element copied represents 4 traversals of an element. Adding all we have (n - 1) * 4, or, roughly, 4 traversals of the complete list. If you count read and write as a single pass, as people often erroneously do, then it's still three traversals.
One can improve on this by initializing the ArrayBuffer with an initial size equal to the list that will be read, minus one, since we'll be discarding one element. To get this size, we need to traverse the list once, though.
Now let's consider toList. To put it simply, it traverses the whole list to create a new list.
So, we have 1 traversal for the algorithm, 3 or 4 traversals for resize, and 1 additional traversal for toList. That's 4 or 5 traversals.
The original algorithm is a bit difficult to analyse, because take, drop and ::: traverse a variable number of elements. Adding all together, however, it does the equivalent of 3 traversals. If splitAt was used, it would be reduced to 2 traversals. With 2 more traversals to get the maximum, we get 5 traversals -- the same number as the non-functional, non-concise algorithm!
So, let's consider improvements.
On the imperative algorithm, if one uses ListBuffer and +=, then all methods are constant-time, which reduces it to a single traversal.
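As an illustration of that suggestion, the imperative version could be rewritten with ListBuffer and += roughly as follows. This is a sketch rather than the code that was benchmarked; the names are made up, and returning Nil for an empty list is an assumption, since the question does not specify that case.

```scala
import scala.collection.mutable.ListBuffer

object RemoveMaxListBuffer {
  // Single pass: keep the largest element seen so far out of the buffer.
  // Like the original imperative version, ordering is not preserved.
  def removeMax(xs: List[Int]): List[Int] =
    if (xs.isEmpty) Nil // assumption: empty input yields empty output
    else {
      val res = ListBuffer.empty[Int]
      var max = xs.head
      for (x <- xs.tail) {
        if (x > max) { res += max; max = x } // old max is no longer the max
        else res += x
      }
      res.toList
    }
}
```

For example, removeMax(List(3, 1, 4, 1, 5, 9, 2, 6)) gives List(1, 3, 1, 4, 5, 2, 6).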
On the functional algorithm, it could be rewritten as:
val max = xs.max
val (before, _ :: after) = xs span (max !=)
before ::: after
That reduces it to a worst case of three traversals. Of course, there are other alternatives presented, based on recursion or fold, that solve it in one traversal.
And, most interesting of all, all of these algorithms are O(n), and the only one that almost (accidentally) ended up with worse complexity was the imperative one (because of array copying). On the other hand, the cache characteristics of the imperative one might well make it faster, because the data is contiguous in memory. That, however, is unrelated to either big-Oh or functional vs imperative, and it is just a matter of the data structures that were chosen.
So, if we actually go to the trouble of benchmarking, analyzing the results, considering performance of methods, and looking into ways of optimizing it, then we can find faster ways to do this in an imperative manner than in a functional manner.
But all this effort is very different from saying the average Java programmer code will be faster than the average Scala programmer code -- if the question is an example, that is simply false. And even discounting the question, we have seen no evidence that the fundamental premise of the question is true.
EDIT
First, let me restate my point, because it seems I wasn't clear. My point is that the code the average Java programmer writes may seem to be more efficient, but actually isn't. Or, put another way, traditional Java style doesn't gain you performance -- only hard work does, be it Java or Scala.
Next, I have a benchmark and results too, including almost all solutions suggested. Two interesting points about it:
Depending on list size, the creation of objects can have a bigger impact than multiple traversals of the list. The original functional code by Adrian takes advantage of the fact that lists are persistent data structures by not copying the elements right of the maximum element at all. If a Vector was used instead, both left and right sides would be mostly unchanged, which might lead to even better performance.
Even though user unknown and paradigmatic have similar recursive solutions, paradigmatic's is way faster. The reason for that is that he avoids pattern matching. Pattern matching can be really slow.
The benchmark code is here, and the results are here.
def removeOneMax (xs: List [Int]) : List [Int] = xs match {
case x :: Nil => Nil
case a :: b :: xs => if (a < b) a :: removeOneMax (b :: xs) else b :: removeOneMax (a :: xs)
case Nil => Nil
}
Here is a recursive method, which only iterates once. If you need performance, you have to think about it, if not, not.
You can make it tail-recursive in the standard way: giving an extra parameter carry, which is per default the empty List, and collects the result while iterating. That is, of course, a bit longer, but if you need performance, you have to pay for it:
import annotation.tailrec
@tailrec
def removeOneMax (xs: List [Int], carry: List [Int] = List.empty) : List [Int] = xs match {
case a :: b :: xs => if (a < b) removeOneMax (b :: xs, a :: carry) else removeOneMax (a :: xs, b :: carry)
case x :: Nil => carry
case Nil => Nil
}
I don't know what the chances are, that later compilers will improve slower map-calls to be as fast as while-loops. However: You rarely need high speed solutions, but if you need them often, you will learn them fast.
Do you know how big your collection has to be, to use a whole second for your solution on your machine?
As a one-liner, similar to Daniel C. Sobral's solution:
((Nil : List[Int], xs(0)) /: xs.tail) ((p, x)=> if (p._2 > x) (x :: p._1, p._2) else ((p._2 :: p._1), x))._1
but that is hard to read, and I didn't measure the effective performance. The normal pattern is (x /: xs) ((a, b) => /* something */). Here, x and a are pairs of List-so-far and max-so-far, which solves the problem to bring everything into one line of code, but isn't very readable. However, you can earn reputation on CodeGolf this way, and maybe someone likes to make a performance measurement.
And now to our big surprise, some measurements:
Here is an updated timing method that gets garbage collection out of the way and lets the HotSpot compiler warm up, together with a main and many methods from this thread, all in an object named PerfRemMax:
object PerfRemMax {
def timed (name: String, xs: List [Int]) (f: List [Int] => List [Int]) = {
val a = System.currentTimeMillis
val res = f (xs)
val z = System.currentTimeMillis
val delta = z-a
println (name + ": " + (delta / 1000.0))
res
}
def main (args: Array [String]) : Unit = {
val n = args(0).toInt
val funs : List [(String, List[Int] => List[Int])] = List (
"indexOf/take-drop" -> adrian1 _,
"arraybuf" -> adrian2 _, /* out of memory */
"paradigmatic1" -> pm1 _, /**/
"paradigmatic2" -> pm2 _,
// "match" -> uu1 _, /*oom*/
"tailrec match" -> uu2 _,
"foldLeft" -> uu3 _,
"buf-=buf.max" -> soc1 _,
"for/yield" -> soc2 _,
"splitAt" -> daniel1,
"ListBuffer" -> daniel2
)
val r = util.Random
val xs = (for (x <- 1 to n) yield r.nextInt (n)).toList
// With 1 Mio. as param, it starts with 100 000, 200k, 300k, ... 1Mio. cases.
// a) warmup
// b) look, where the process gets linear to size
funs.foreach (f => {
(1 to 10) foreach (i => {
timed (f._1, xs.take (n/10 * i)) (f._2)
compat.Platform.collectGarbage
});
println ()
})
}
I renamed all the methods, and had to modify uu2 a bit, to fit to the common method declaration (List [Int] => List [Int]).
From the long result, I only provide the output for 1M invocations:
scala -Dserver PerfRemMax 2000000
indexOf/take-drop: 0.882
arraybuf: 1.681
paradigmatic1: 0.55
paradigmatic2: 1.13
tailrec match: 0.812
foldLeft: 1.054
buf-=buf.max: 1.185
for/yield: 0.725
splitAt: 1.127
ListBuffer: 0.61
The numbers aren't completely stable; they depend on the sample size and vary a bit from run to run. For example, for 100k to 1M runs, in steps of 100k, the timing for splitAt was as follows:
splitAt: 0.109
splitAt: 0.118
splitAt: 0.129
splitAt: 0.139
splitAt: 0.157
splitAt: 0.166
splitAt: 0.749
splitAt: 0.752
splitAt: 1.444
splitAt: 1.127
The initial solution is already pretty fast. splitAt is a modification from Daniel, often faster, but not always.
The measurement was done on a single-core 2 GHz Centrino, running xUbuntu Linux, Scala-2.8 with Sun-Java-1.6 (desktop).
The two lessons for me are:
always measure your performance improvements; it is very hard to estimate it, if you don't do it on a daily basis
it is not only fun, to write functional code - sometimes the result is even faster
Here is a link to my benchmark code, if somebody is interested.
First of all, the behavior of the methods you presented is not the same. The first one keeps the element ordering, while the second one doesn't.
Second, among all the possible solutions that could qualify as "idiomatic", some are more efficient than others. Staying very close to your example, you can for instance use tail-recursion to eliminate variables and manual state management:
def removeMax1( xs: List[Int] ) = {
def rec( max: Int, rest: List[Int], result: List[Int]): List[Int] = {
if( rest.isEmpty ) result
else if( rest.head > max ) rec( rest.head, rest.tail, max :: result)
else rec( max, rest.tail, rest.head :: result )
}
rec( xs.head, xs.tail, List() )
}
or fold the list:
def removeMax2( xs: List[Int] ) = {
val result = xs.tail.foldLeft( xs.head -> List[Int]() ) {
(acc,x) =>
val (max,res) = acc
if( x > max ) x -> ( max :: res )
else max -> ( x :: res )
}
result._2
}
If you want to keep the original insertion order, you can (at the expense of having two passes, rather than one) without any effort write something like:
def removeMax3( xs: List[Int] ) = {
val max = xs.max
xs.filterNot( _ == max )
}
which is more clear than your first example.
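Note that filterNot drops every occurrence of the maximum, whereas the versions above remove exactly one element. If only the first occurrence should go while the order is kept, a variant based on span could look like the sketch below. The names are made up, and returning Nil for an empty list is an assumption.

```scala
object RemoveFirstMax {
  // Two passes: one to find the maximum, one to split at its first occurrence.
  def removeFirstMax(xs: List[Int]): List[Int] =
    if (xs.isEmpty) Nil // assumption: empty input yields empty output
    else {
      val max = xs.max
      val (before, rest) = xs.span(_ != max)
      before ::: rest.drop(1) // rest starts with the first max; skip it
    }
}
```

For example, removeFirstMax(List(1, 9, 2, 9)) gives List(1, 2, 9), while removeMax3 would give List(1, 2).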
The biggest inefficiency when you're writing a program is worrying about the wrong things. This is usually the wrong thing to worry about. Why?
Developer time is generally much more expensive than CPU time — in fact, there is usually a dearth of the former and a surplus of the latter.
Most code does not need to be very efficient because it will never be running on million-item datasets multiple times every second.
Most code does need to be bug free, and less code is less room for bugs to hide.
The example you gave is not very functional, actually. Here's what you are doing:
// Given a list of Int
def removeMaxCool(xs: List[Int]): List[Int] = {
// Find the index of the biggest Int
val maxIndex = xs.indexOf(xs.max);
// Then take the ints before and after it, and then concatenate them
xs.take(maxIndex) ::: xs.drop(maxIndex+1)
}
Mind you, it is not bad, but functional code is at its best when it describes what you want, instead of how you want it. As a minor criticism, if you used splitAt instead of take and drop you could improve it slightly.
Another way of doing it is this:
def removeMaxCool(xs: List[Int]): List[Int] = {
// the result is the folding of the tail over the head
// and an empty list
xs.tail.foldLeft(xs.head -> List[Int]()) {
// Where the accumulated list is increased by the
// lesser of the current element and the accumulated
// element, and the accumulated element is the maximum between them
case ((max, ys), x) =>
if (x > max) (x, max :: ys)
else (max, x :: ys)
// and of which we return only the accumulated list
}._2
}
Now, let's discuss the main issue. Is this code slower than the Java one? Most certainly! Is the Java code slower than a C equivalent? You can bet it is, JIT or no JIT. And if you write it directly in assembler, you can make it even faster!
But the cost of that speed is that you get more bugs, you spend more time trying to understand the code to debug it, and you have less visibility of what the overall program is doing as opposed to what a little piece of code is doing -- which might result in performance problems of its own.
So my answer is simple: if you think the speed penalty of programming in Scala is not worth the gains it brings, you should program in assembler. If you think I'm being radical, then I counter that you just chose the familiar as being the "ideal" trade off.
Do I think performance doesn't matter? Not at all! I think one of the main advantages of Scala is leveraging gains often found in dynamically typed languages with the performance of a statically typed language! Performance matters, algorithm complexity matters a lot, and constant costs matter too.
But, whenever there is a choice between performance and readability and maintainability, the latter is preferable. Sure, if performance must be improved, then there isn't a choice: you have to sacrifice something to it. And if there's no loss in readability/maintainability -- such as Scala vs dynamically typed languages -- sure, go for performance.
Lastly, to gain performance out of functional programming you have to know functional algorithms and data structures. Sure, 99% of Java programmers with 5-10 years experience will beat the performance of 99% of Scala programmers with 6 months experience. The same was true for imperative programming vs object oriented programming a couple of decades ago, and history shows it didn't matter.
EDIT
As a side note, your "fast" algorithm suffers from a serious problem: you use ArrayBuffer. That collection does not have constant time append, and has linear time toList. If you use ListBuffer instead, you get constant time append and toList.
For reference, here's how splitAt is defined in TraversableLike in the Scala standard library,
def splitAt(n: Int): (Repr, Repr) = {
val l, r = newBuilder
l.sizeHintBounded(n, this)
if (n >= 0) r.sizeHint(this, -n)
var i = 0
for (x <- this) {
(if (i < n) l else r) += x
i += 1
}
(l.result, r.result)
}
It's not unlike your example code of what a Java programmer might come up with.
I like Scala because, where performance matters, mutability is a reasonable way to go. The collections library is a great example; especially how it hides this mutability behind a functional interface.
Where performance isn't as important, such as some application code, the higher order functions in Scala's library allow great expressivity and programmer efficiency.
Out of curiosity, I picked an arbitrary large file in the Scala compiler (scala.tools.nsc.typechecker.Typers.scala) and counted something like 37 for loops, 11 while loops, 6 concatenations (++), and 1 fold (it happens to be a foldRight).
What about this?
def removeMax(xs: List[Int]) = {
val buf = xs.toBuffer
buf -= (buf.max)
}
A bit more ugly, but faster:
def removeMax(xs: List[Int]) = {
var max = xs.head
for ( x <- xs.tail )
yield {
if (x > max) { val result = max; max = x; result}
else x
}
}
Try this:
myList.foldLeft((List[Int](), None: Option[Int])) {
case ((_, None), x) => (List(), Some(x))
case ((Nil, Some(m)), x) => (List(math.min(x, m)), Some(math.max(x, m)))
case ((l, Some(m)), x) => (math.min(x, m) :: l, Some(math.max(x, m)))
}._1
Idiomatic, functional, traverses only once. Maybe somewhat cryptic if you are not used to functional-programming idioms.
Let's try to explain what is happening here. I will try to make it as simple as possible, lacking some rigor.
A fold is an operation on a List[A] (that is, a list that contains elements of type A) that will take an initial state s0: S (that is, an instance of a type S) and a function f: (S, A) => S (that is, a function that takes the current state and an element from the list, and gives the next state, ie, it updates the state according to the next element).
The operation will then iterate over the elements of the list, using each one to update the state according to the given function. In Java, it would be something like:
interface Function<T, R> { R apply(T t); }
class Pair<A, B> { ... }
<A, State> State fold(List<A> list, State s0, Function<Pair<A, State>, State> f) {
State s = s0;
for (A a: list) {
s = f.apply(new Pair<A, State>(a, s));
}
return s;
}
For example, if you want to add all the elements of a List[Int], the state would be the partial sum, that would have to be initialized to 0, and the new state produced by a function would simply add the current state to the current element being processed:
myList.fold(0)((partialSum, element) => partialSum + element)
Try to write a fold to multiply the elements of a list, then another one to find extreme values (max, min).
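For readers who want to compare notes on that exercise, here is one possible sketch. The names are made up for this example; maxOption returns an Option so that the empty list has a well-defined result.

```scala
object FoldExamples {
  // State: the partial product, starting from the neutral element 1.
  def product(xs: List[Int]): Int =
    xs.foldLeft(1)((acc, x) => acc * x)

  // State: the largest element seen so far, None before the first element.
  def maxOption(xs: List[Int]): Option[Int] =
    xs.foldLeft(Option.empty[Int]) {
      case (None, x)    => Some(x)
      case (Some(m), x) => Some(math.max(m, x))
    }
}
```

For example, product(List(1, 2, 3, 4)) is 24 and maxOption(List(3, 7, 2)) is Some(7).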
Now, the fold presented above is a bit more complex, since the state is composed of the new list being created along with the maximum element found so far. The function that updates the state is more or less straightforward once you grasp these concepts. It simply puts into the new list the minimum between the current maximum and the current element, while the other value goes to the current maximum of the updated state.
Coming up with this solution is a bit harder than understanding it (if you have no FP background). However, this is only to show you that such a solution exists and can be done. It's just a completely different mindset.
EDIT: As you see, the first and second case in the solution I proposed are used to setup the fold. It is equivalent to what you see in other answers when they do xs.tail.fold((xs.head, ...)) {...}. Note that the solutions proposed until now using xs.tail/xs.head don't cover the case in which xs is List(), and will throw an exception. The solution above will return List() instead. Since you didn't specify the behavior of the function on empty lists, both are valid.
Another option would be:
package code.array
object SliceArrays {
def main(args: Array[String]): Unit = {
println(removeMaxCool(Vector(1,2,3,100,12,23,44)))
}
def removeMaxCool(xs: Vector[Int]) = xs.filter(_ < xs.max)
}
This uses Vector instead of List, because Vector is more versatile and has better general performance and time complexity than List.
Consider the following collections operations:
head, tail, apply, update, prepend, append
Vector takes effectively constant time for all of these operations, as per the Scala docs:
"The operation takes effectively constant time, but this might depend on some assumptions such as maximum length of a vector or distribution of hash keys"
While List takes constant time only for head, tail and prepend operations.
Using
scalac -print
generates:
package code.array {
object SliceArrays extends Object {
def main(args: Array[String]): Unit = scala.Predef.println(SliceArrays.this.removeMaxCool(scala.`package`.Vector().apply(scala.Predef.wrapIntArray(Array[Int]{1, 2, 3, 100, 12, 23, 44})).$asInstanceOf[scala.collection.immutable.Vector]()));
def removeMaxCool(xs: scala.collection.immutable.Vector): scala.collection.immutable.Vector = xs.filter({
((x$1: Int) => SliceArrays.this.$anonfun$removeMaxCool$1(xs, x$1))
}).$asInstanceOf[scala.collection.immutable.Vector]();
final <artifact> private[this] def $anonfun$removeMaxCool$1(xs$1: scala.collection.immutable.Vector, x$1: Int): Boolean = x$1.<(scala.Int.unbox(xs$1.max(scala.math.Ordering$Int)));
def <init>(): code.array.SliceArrays.type = {
SliceArrays.super.<init>();
()
}
}
}
Another contender. This uses a ListBuffer, like Daniel's second offering, but shares the post-max tail of the original list, avoiding copying it.
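// requires: import scala.collection.mutable.ListBuffer
def shareTail(xs: List[Int]): List[Int] = {
val res = ListBuffer[Int]()
var maxTail = xs // suffix of xs that starts at the largest element seen so far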
var x = xs
while ( x != Nil ) {
if (x.head > maxTail.head) {
while (!(maxTail.head == x.head)) {
res += maxTail.head
maxTail = maxTail.tail
}
}
x = x.tail
}
res.prependToList(maxTail.tail)
}