Is idiomatic Scala coding style just a cool trap for writing inefficient code?

I sense that the Scala community has something of an obsession with writing "concise", "cool", "Scala-idiomatic", "one-liner" (if possible) code. This is immediately followed by a comparison to Java/imperative/ugly code.
While this (sometimes) leads to easy-to-understand code, it also leads to inefficient code for 99% of developers. And this is where Java/C++ is not easy to beat.
Consider this simple problem: Given a list of integers, remove the greatest element. Ordering does not need to be preserved.
Here is my version of the solution (It may not be the greatest, but it's what the average non-rockstar developer would do).
def removeMaxCool(xs: List[Int]) = {
  val maxIndex = xs.indexOf(xs.max);
  xs.take(maxIndex) ::: xs.drop(maxIndex+1)
}
It's Scala idiomatic, concise, and uses a few nice list functions. It's also very inefficient. It traverses the list at least 3 or 4 times.
Here is my totally uncool, Java-like solution. It's also what a reasonable Java developer (or Scala novice) would write.
import scala.collection.mutable.ArrayBuffer

def removeMaxFast(xs: List[Int]) = {
  var res = ArrayBuffer[Int]()
  var max = xs.head
  var first = true;
  for (x <- xs) {
    if (first) {
      first = false;
    } else {
      if (x > max) {
        res.append(max)
        max = x
      } else {
        res.append(x)
      }
    }
  }
  res.toList
}
Totally non-Scala idiomatic, non-functional, non-concise, but it's very efficient. It traverses the list only once!
So, if 99% of Java developers write more efficient code than 99% of Scala developers, this is a huge obstacle to cross for greater Scala adoption. Is there a way out of this trap?
I am looking for practical advice to avoid such "inefficiency traps" while keeping the implementation clear and concise.
Clarification: This question comes from a real-life scenario: I had to write a complex algorithm. First I wrote it in Scala, then I "had to" rewrite it in Java. The Java implementation was twice as long, and not that clear, but at the same time it was twice as fast. Rewriting the Scala code to be efficient would probably take some time and a somewhat deeper understanding of Scala's internal efficiencies (for vs. map vs. fold, etc.).

Let's discuss a fallacy in the question:
So, if 99% of Java developers write more efficient code than 99% of Scala developers, this is a huge obstacle to cross for greater Scala adoption. Is there a way out of this trap?
This is presumed, with absolutely no evidence backing it up. If false, the question is moot.
Is there evidence to the contrary? Well, let's consider the question itself -- it doesn't prove anything, but shows things are not that clear.
Totally non-Scala idiomatic, non-functional, non-concise, but it's very efficient. It traverses the list only once!
Of the four claims in the first sentence, the first three are true, and the fourth, as shown by user unknown, is false! And why is it false? Because, contrary to what the second sentence states, it traverses the list more than once.
The code calls the following methods on it:
res.append(max)
res.append(x)
and
res.toList
Let's consider first append.
append takes a vararg parameter. That means max and x are first encapsulated into a sequence of some type (a WrappedArray, in fact), and then passed as parameter. A better method would have been +=.
OK, append calls ++=, which delegates to +=. But first, it calls ensureSize, which is the second mistake (+= calls that too; ++= just optimizes it for multiple elements). This matters because an Array is a fixed-size collection, which means that, at each resize, the whole Array must be copied!
So let's consider this. When you resize, Java first clears the memory by storing 0 in each element, then Scala copies each element of the previous array over to the new array. Since size doubles each time, this happens log(n) times, with the number of elements being copied increasing each time it happens.
Take for example n = 16. It does this four times, copying 1, 2, 4 and 8 elements respectively. Since Java has to clear each of these arrays, and each element must be read and written, each element copied represents 4 traversals of an element. Adding it all up, we have (n - 1) * 4, or, roughly, 4 traversals of the complete list. If you count read and write as a single pass, as people often erroneously do, then it's still three traversals.
One can improve on this by initializing the ArrayBuffer with an initial size equal to the list that will be read, minus one, since we'll be discarding one element. To get this size, we need to traverse the list once, though.
Now let's consider toList. To put it simply, it traverses the whole list to create a new list.
So, we have 1 traversal for the algorithm, 3 or 4 traversals for resize, and 1 additional traversal for toList. That's 4 or 5 traversals.
The original algorithm is a bit difficult to analyse, because take, drop and ::: traverse a variable number of elements. Adding all together, however, it does the equivalent of 3 traversals. If splitAt was used, it would be reduced to 2 traversals. With 2 more traversals to get the maximum, we get 5 traversals -- the same number as the non-functional, non-concise algorithm!
So, let's consider improvements.
On the imperative algorithm, if one uses ListBuffer and +=, then all methods are constant-time, which reduces it to a single traversal.
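A minimal sketch of that single-traversal imperative variant, assuming (like the question's code) a non-empty list; the object and method names are illustrative, not taken from any answer:

```scala
import scala.collection.mutable.ListBuffer

object RemoveMaxSinglePass {
  // One traversal: ListBuffer's += and toList both run in constant time,
  // so there is no array resizing and no final copy.
  def removeMaxListBuffer(xs: List[Int]): List[Int] = {
    val res = ListBuffer[Int]()
    var max = xs.head // throws NoSuchElementException on an empty list
    for (x <- xs.tail) {
      if (x > max) { res += max; max = x } // the old maximum becomes an ordinary element
      else res += x
    }
    res.toList // the final maximum is never appended, so exactly one copy is dropped
  }
}
```

Only one occurrence of the maximum is removed, and like the question's version the remaining elements keep their relative order.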
On the functional algorithm, it could be rewritten as:
val max = xs.max
val (before, _ :: after) = xs span (max != _)
before ::: after
That reduces it to a worst case of three traversals. Of course, there are other alternatives presented, based on recursion or fold, that solve it in one traversal.
And, most interesting of all, all of these algorithms are O(n), and the only one that almost (accidentally) incurred a worse complexity was the imperative one (because of array copying). On the other hand, the cache characteristics of the imperative one might well make it faster, because the data is contiguous in memory. That, however, is unrelated to either big-O or functional vs. imperative; it is just a matter of the data structures that were chosen.
So, if we actually go to the trouble of benchmarking, analyzing the results, considering performance of methods, and looking into ways of optimizing it, then we can find faster ways to do this in an imperative manner than in a functional manner.
But all this effort is very different from saying the average Java programmer's code will be faster than the average Scala programmer's code. If the question is an example, that is simply false. And even discounting the question, we have seen no evidence that the fundamental premise of the question is true.
EDIT
First, let me restate my point, because it seems I wasn't clear. My point is that the code the average Java programmer writes may seem to be more efficient, but actually isn't. Or, put another way, traditional Java style doesn't gain you performance -- only hard work does, be it Java or Scala.
Next, I have a benchmark and results too, including almost all solutions suggested. Two interesting points about it:
Depending on list size, the creation of objects can have a bigger impact than multiple traversals of the list. The original functional code by Adrian takes advantage of the fact that lists are persistent data structures by not copying the elements right of the maximum element at all. If a Vector was used instead, both left and right sides would be mostly unchanged, which might lead to even better performance.
Even though user unknown and paradigmatic have similar recursive solutions, paradigmatic's is way faster. The reason for that is that he avoids pattern matching. Pattern matching can be really slow.
The benchmark code is here, and the results are here.

def removeOneMax (xs: List [Int]) : List [Int] = xs match {
  case x :: Nil => Nil
  case a :: b :: xs => if (a < b) a :: removeOneMax (b :: xs) else b :: removeOneMax (a :: xs)
  case Nil => Nil
}
Here is a recursive method, which only iterates once. If you need performance, you have to think about it; if not, don't.
You can make it tail-recursive in the standard way: by adding an extra parameter, carry, which defaults to the empty List and collects the result while iterating. That is, of course, a bit longer, but if you need performance, you have to pay for it:
import scala.annotation.tailrec

@tailrec
def removeOneMax (xs: List [Int], carry: List [Int] = List.empty) : List [Int] = xs match {
  case a :: b :: xs => if (a < b) removeOneMax (b :: xs, a :: carry) else removeOneMax (a :: xs, b :: carry)
  case x :: Nil => carry
  case Nil => Nil
}
I don't know how likely it is that future compilers will make map calls as fast as while loops. However, you rarely need high-speed solutions, and if you need them often, you will learn them fast.
Do you know how big your collection has to be, to use a whole second for your solution on your machine?
As a one-liner, similar to Daniel C. Sobral's solution:
((Nil : List[Int], xs(0)) /: xs.tail) ((p, x)=> if (p._2 > x) (x :: p._1, p._2) else ((p._2 :: p._1), x))._1
but that is hard to read, and I didn't measure its actual performance. The normal pattern is (x /: xs) ((a, b) => /* something */). Here, x and a are pairs of list-so-far and max-so-far, which solves the problem of bringing everything into one line of code, but isn't very readable. However, you can earn reputation on CodeGolf this way, and maybe someone would like to make a performance measurement.
And now, to our big surprise, some measurements:
An updated timing method, a main that gets garbage collection out of the way and lets the HotSpot compiler warm up, and many methods from this thread, all together in an object named PerfRemMax:
object PerfRemMax {
  def timed (name: String, xs: List [Int]) (f: List [Int] => List [Int]) = {
    val a = System.currentTimeMillis
    val res = f (xs)
    val z = System.currentTimeMillis
    val delta = z-a
    println (name + ": " + (delta / 1000.0))
    res
  }
  def main (args: Array [String]) : Unit = {
    val n = args(0).toInt
    val funs : List [(String, List[Int] => List[Int])] = List (
      "indexOf/take-drop" -> adrian1 _,
      "arraybuf" -> adrian2 _, /* out of memory */
      "paradigmatic1" -> pm1 _, /**/
      "paradigmatic2" -> pm2 _,
      // "match" -> uu1 _, /*oom*/
      "tailrec match" -> uu2 _,
      "foldLeft" -> uu3 _,
      "buf-=buf.max" -> soc1 _,
      "for/yield" -> soc2 _,
      "splitAt" -> daniel1,
      "ListBuffer" -> daniel2
    )
    val r = util.Random
    val xs = (for (x <- 1 to n) yield r.nextInt (n)).toList
    // With 1 million as the parameter, it runs on 100k, 200k, 300k, ... 1M elements.
    // a) warmup
    // b) look where the running time becomes linear in the size
    funs.foreach (f => {
      (1 to 10) foreach (i => {
        timed (f._1, xs.take (n/10 * i)) (f._2)
        compat.Platform.collectGarbage
      });
      println ()
    })
  }
  // The compared methods (adrian1, adrian2, pm1, ...) are defined in the linked benchmark code.
}
I renamed all the methods and had to modify uu2 a bit to fit the common method signature (List [Int] => List [Int]).
From the long output, I only show the results for 1M invocations (times in seconds, as printed by timed):
scala -Dserver PerfRemMax 2000000
indexOf/take-drop: 0.882
arraybuf: 1.681
paradigmatic1: 0.55
paradigmatic2: 1.13
tailrec match: 0.812
foldLeft: 1.054
buf-=buf.max: 1.185
for/yield: 0.725
splitAt: 1.127
ListBuffer: 0.61
The numbers aren't completely stable; they depend on the sample size and vary a bit from run to run. For example, for 100k to 1M runs, in steps of 100k, the timing for splitAt was as follows:
splitAt: 0.109
splitAt: 0.118
splitAt: 0.129
splitAt: 0.139
splitAt: 0.157
splitAt: 0.166
splitAt: 0.749
splitAt: 0.752
splitAt: 1.444
splitAt: 1.127
The initial solution is already pretty fast. splitAt is a modification from Daniel, often faster, but not always.
The measurement was done on a single-core 2 GHz Centrino, running Xubuntu Linux, Scala 2.8 with Sun Java 1.6 (desktop).
The two lessons for me are:
Always measure your performance improvements; it is very hard to estimate them if you don't do it on a daily basis.
Writing functional code is not only fun; sometimes the result is even faster.
Here is a link to my benchmark code, if anybody is interested.

First of all, the behavior of the methods you presented is not the same. The first one keeps the element ordering, while the second one doesn't.
Second, among all the possible solutions which could be qualified as "idiomatic", some are more efficient than others. Staying very close to your example, you can for instance use tail recursion to eliminate variables and manual state management:
def removeMax1( xs: List[Int] ) = {
  def rec( max: Int, rest: List[Int], result: List[Int]): List[Int] = {
    if( rest.isEmpty ) result
    else if( rest.head > max ) rec( rest.head, rest.tail, max :: result)
    else rec( max, rest.tail, rest.head :: result )
  }
  rec( xs.head, xs.tail, List() )
}
or fold the list:
def removeMax2( xs: List[Int] ) = {
  val result = xs.tail.foldLeft( xs.head -> List[Int]() ) {
    (acc,x) =>
      val (max,res) = acc
      if( x > max ) x -> ( max :: res )
      else max -> ( x :: res )
  }
  result._2
}
If you want to keep the original insertion order, you can easily write something like the following (at the expense of two passes rather than one):
def removeMax3( xs: List[Int] ) = {
  val max = xs.max
  xs.filterNot( _ == max )
}
which is clearer than your first example, although, unlike the other versions, it removes every occurrence of the maximum rather than just one.
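If exactly one occurrence should be removed, a span-based variant still preserves order in two passes. A minimal sketch, assuming a non-empty list (the name removeFirstMax is illustrative):

```scala
object RemoveFirstMax {
  // Two passes: one to find the maximum, one span up to its first occurrence.
  // Order is preserved and only that single occurrence is dropped.
  def removeFirstMax(xs: List[Int]): List[Int] = {
    val max = xs.max // throws on an empty list
    val (before, rest) = xs.span(_ != max)
    before ::: rest.drop(1) // rest starts with the first max; skip just that element
  }
}
```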

The biggest inefficiency when you're writing a program is worrying about the wrong things. This is usually the wrong thing to worry about. Why?
Developer time is generally much more expensive than CPU time — in fact, there is usually a dearth of the former and a surplus of the latter.
Most code does not need to be very efficient because it will never be running on million-item datasets multiple times every second.
Most code does need to be bug-free, and less code leaves less room for bugs to hide.

The example you gave is not very functional, actually. Here's what you are doing:
// Given a list of Int
def removeMaxCool(xs: List[Int]): List[Int] = {
  // Find the index of the biggest Int
  val maxIndex = xs.indexOf(xs.max);
  // Then take the ints before and after it, and then concatenate them
  xs.take(maxIndex) ::: xs.drop(maxIndex+1)
}
Mind you, it is not bad, but functional code is at its best when it describes what you want, instead of how you want it done. As a minor criticism, if you used splitAt instead of take and drop you could improve it slightly.
Another way of doing it is this:
def removeMaxCool(xs: List[Int]): List[Int] = {
  // the result is the folding of the tail over the head
  // and an empty list
  xs.tail.foldLeft(xs.head -> List[Int]()) {
    // Where the accumulated list is increased by the
    // lesser of the current element and the accumulated
    // element, and the accumulated element is the maximum between them
    case ((max, ys), x) =>
      if (x > max) (x, max :: ys)
      else (max, x :: ys)
    // and of which we return only the accumulated list
  }._2
}
Now, let's discuss the main issue. Is this code slower than the Java one? Most certainly! Is the Java code slower than a C equivalent? You can bet it is, JIT or no JIT. And if you write it directly in assembler, you can make it even faster!
But the cost of that speed is that you get more bugs, you spend more time trying to understand the code to debug it, and you have less visibility of what the overall program is doing as opposed to what a little piece of code is doing -- which might result in performance problems of its own.
So my answer is simple: if you think the speed penalty of programming in Scala is not worth the gains it brings, you should program in assembler. If you think I'm being radical, then I counter that you just chose the familiar as being the "ideal" trade off.
Do I think performance doesn't matter? Not at all! I think one of the main advantages of Scala is combining gains often found in dynamically typed languages with the performance of a statically typed language! Performance matters, algorithm complexity matters a lot, and constant costs matter too.
But whenever there is a choice between performance and readability/maintainability, the latter is preferable. Sure, if performance must be improved, then there isn't a choice: you have to sacrifice something to it. And if there's no loss in readability/maintainability (such as Scala vs. dynamically typed languages), sure, go for performance.
Lastly, to gain performance out of functional programming you have to know functional algorithms and data structures. Sure, 99% of Java programmers with 5-10 years experience will beat the performance of 99% of Scala programmers with 6 months experience. The same was true for imperative programming vs object oriented programming a couple of decades ago, and history shows it didn't matter.
EDIT
As a side note, your "fast" algorithm suffers from a serious problem: you use ArrayBuffer. That collection does not have constant-time append (only amortized constant time, with occasional array copies), and its toList is linear. If you use ListBuffer instead, you get constant-time append and toList.

For reference, here's how splitAt is defined in TraversableLike in the Scala standard library,
def splitAt(n: Int): (Repr, Repr) = {
  val l, r = newBuilder
  l.sizeHintBounded(n, this)
  if (n >= 0) r.sizeHint(this, -n)
  var i = 0
  for (x <- this) {
    (if (i < n) l else r) += x
    i += 1
  }
  (l.result, r.result)
}
It's not unlike your example code of what a Java programmer might come up with.
I like Scala because, where performance matters, mutability is a reasonable way to go. The collections library is a great example; especially how it hides this mutability behind a functional interface.
Where performance isn't as important, such as some application code, the higher order functions in Scala's library allow great expressivity and programmer efficiency.
Out of curiosity, I picked an arbitrary large file in the Scala compiler (scala.tools.nsc.typechecker.Typers.scala) and counted something like 37 for loops, 11 while loops, 6 concatenations (++), and 1 fold (it happens to be a foldRight).

What about this?
def removeMax(xs: List[Int]) = {
  val buf = xs.toBuffer
  buf -= (buf.max)
}
A bit more ugly, but faster:
def removeMax(xs: List[Int]) = {
  var max = xs.head
  for ( x <- xs.tail )
  yield {
    if (x > max) { val result = max; max = x; result}
    else x
  }
}

Try this:
myList.foldLeft((List[Int](), None: Option[Int])) {
  case ((_, None), x) => (List(), Some(x))
  case ((Nil, Some(m)), x) => (List(math.min(x, m)), Some(math.max(x, m)))
  case ((l, Some(m)), x) => (math.min(x, m) :: l, Some(math.max(x, m)))
}._1
Idiomatic, functional, traverses only once. Maybe somewhat cryptic if you are not used to functional-programming idioms.
Let's try to explain what is happening here. I will try to make it as simple as possible, lacking some rigor.
A fold is an operation on a List[A] (that is, a list that contains elements of type A) that will take an initial state s0: S (that is, an instance of a type S) and a function f: (S, A) => S (that is, a function that takes the current state and an element from the list, and gives the next state, i.e., it updates the state according to the next element).
The operation will then iterate over the elements of the list, using each one to update the state according to the given function. In Java, it would be something like:
interface Function<T, R> { R apply(T t); }
class Pair<A, B> { ... }
<A, State> State fold(List<A> list, State s0, Function<Pair<A, State>, State> f) {
    State s = s0;
    for (A a : list) {
        s = f.apply(new Pair<A, State>(a, s));
    }
    return s;
}
For example, if you want to add all the elements of a List[Int], the state would be the partial sum, that would have to be initialized to 0, and the new state produced by a function would simply add the current state to the current element being processed:
myList.fold(0)((partialSum, element) => partialSum + element)
Try to write a fold to multiply the elements of a list, then another one to find extreme values (max, min).
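One possible answer to those two exercises, using foldLeft as above (the object and method names are illustrative, and extremes assumes a non-empty list):

```scala
object FoldExercises {
  // Product: the state is the partial product, starting from 1.
  def product(xs: List[Int]): Int = xs.foldLeft(1)((acc, x) => acc * x)

  // Extremes: the state is a (min, max) pair seeded with the head.
  def extremes(xs: List[Int]): (Int, Int) =
    xs.tail.foldLeft((xs.head, xs.head)) { case ((lo, hi), x) =>
      (math.min(lo, x), math.max(hi, x))
    }
}
```

Seeding extremes with the head avoids needing a sentinel value; the Option-based setup in the fold above is the alternative when empty lists must be handled.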
Now, the fold presented above is a bit more complex, since the state is composed of the new list being created along with the maximum element found so far. The function that updates the state is more or less straightforward once you grasp these concepts. It simply puts into the new list the minimum between the current maximum and the current element, while the other value goes to the current maximum of the updated state.
What is harder than understanding this (if you have no FP background) is coming up with this solution. However, this is only to show you that it exists and can be done. It's just a completely different mindset.
EDIT: As you see, the first and second case in the solution I proposed are used to setup the fold. It is equivalent to what you see in other answers when they do xs.tail.fold((xs.head, ...)) {...}. Note that the solutions proposed until now using xs.tail/xs.head don't cover the case in which xs is List(), and will throw an exception. The solution above will return List() instead. Since you didn't specify the behavior of the function on empty lists, both are valid.

Another option would be:
package code.array
object SliceArrays {
def main(args: Array[String]): Unit = {
println(removeMaxCool(Vector(1,2,3,100,12,23,44)))
}
def removeMaxCool(xs: Vector[Int]) = xs.filter(_ < xs.max)
}
Using Vector instead of List, the reason is that Vector is more versatile and has a better general performance and time complexity if compared to List.
Consider the following collections operations:
head, tail, apply, update, prepend, append
Vector takes an amortized constant time for all operations, as per Scala docs:
"The operation takes effectively constant time, but this might depend on some assumptions such as maximum length of a vector or distribution of hash keys"
While List takes constant time only for head, tail and prepend operations.
Compiling it with scalac -print generates:
package code.array {
object SliceArrays extends Object {
def main(args: Array[String]): Unit = scala.Predef.println(SliceArrays.this.removeMaxCool(scala.`package`.Vector().apply(scala.Predef.wrapIntArray(Array[Int]{1, 2, 3, 100, 12, 23, 44})).$asInstanceOf[scala.collection.immutable.Vector]()));
def removeMaxCool(xs: scala.collection.immutable.Vector): scala.collection.immutable.Vector = xs.filter({
((x$1: Int) => SliceArrays.this.$anonfun$removeMaxCool$1(xs, x$1))
}).$asInstanceOf[scala.collection.immutable.Vector]();
final <artifact> private[this] def $anonfun$removeMaxCool$1(xs$1: scala.collection.immutable.Vector, x$1: Int): Boolean = x$1.<(scala.Int.unbox(xs$1.max(scala.math.Ordering$Int)));
def <init>(): code.array.SliceArrays.type = {
SliceArrays.super.<init>();
()
}
}
}
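The output above shows xs.max being evaluated inside $anonfun$removeMaxCool$1, i.e., once per element, so the filter does quadratic work. A minimal sketch that hoists the maximum out of the predicate (the object name is illustrative; like the original, it assumes a non-empty Vector and drops every element equal to the maximum):

```scala
object SliceArraysHoisted {
  // Computes the maximum once, then filters in a second pass: O(n) overall.
  def removeMaxCool(xs: Vector[Int]): Vector[Int] = {
    val max = xs.max // throws on an empty Vector
    xs.filter(_ < max)
  }
}
```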

Another contender. This uses a ListBuffer, like Daniel's second offering, but shares the post-max tail of the original list, avoiding copying it.
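import scala.collection.mutable.ListBuffer

def shareTail(xs: List[Int]): List[Int] = {
  val res = ListBuffer[Int]()
  var maxTail = xs // the sublist that starts at the current maximum
  var x = xs
  while ( x != Nil ) {
    if (x.head > maxTail.head) {
      // copy the elements before the new maximum into the buffer
      while (!(maxTail.head == x.head)) {
        res += maxTail.head
        maxTail = maxTail.tail
      }
    }
    x = x.tail
  }
  // share everything after the maximum with the original list
  res.prependToList(maxTail.tail)
}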

Related

How to write an efficient groupBy-size filter in Scala, can be approximate

Given a List[Int] in Scala, I wish to get the Set[Int] of all Ints which appear at least thresh times. I can do this using groupBy or foldLeft, then filter. For example:
val thresh = 3
val myList = List(1,2,3,2,1,4,3,2,1)
myList.foldLeft(Map[Int,Int]()){case(m, i) => m + (i -> (m.getOrElse(i, 0) + 1))}.filter(_._2 >= thresh).keys
will give Set(1,2).
Now suppose the List[Int] is very large. How large is hard to say, but in any case this seems wasteful, as I don't care about each Int's frequency; I only care whether it is at least thresh. Once an Int has passed thresh, there's no need to check it anymore; just add it to the Set[Int].
The question is: can I do this more efficiently for a very large List[Int],
a) if I need a true, accurate result (no room for mistakes)
b) if the result can be approximate, e.g. by using some hashing trick or Bloom filters, where the Set[Int] might include some false positives, or where {the frequency of an Int > thresh} isn't really a Boolean but a Double in [0, 1].
First of all, you can't do better than O(N), as you need to check each element of your initial list at least once. Your current approach is O(N), presuming that operations on IntMap are effectively constant.
Now what you can try in order to increase efficiency:
update map only when current counter value is less or equal to threshold. This will eliminate huge number of most expensive operations — map updates
try a faster map instead of IntMap. If you know that the values of the initial List are in a fixed range, you can use an Array instead of IntMap (with the index as the key). Another possible option is a mutable HashMap with sufficient initial capacity. As my benchmark shows, it actually makes a significant difference
as Ixx proposed, after incrementing the value in the map, check whether it's equal to the threshold and, in this case, add it immediately to the result list. This will save you one linear traversal (appears to be not that significant for large input)
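Following the Array suggestion in the list above, a minimal sketch that assumes every value is known to lie in [0, maxValue] and that thresh >= 1 (atLeast and maxValue are hypothetical names). It also stops incrementing a counter once it reaches the threshold and records each value exactly once:

```scala
object BoundedCounts {
  // Counts live in a plain Array indexed by the value itself (assumes 0 <= x <= maxValue).
  def atLeast(xs: List[Int], thresh: Int, maxValue: Int): Set[Int] = {
    val counts = new Array[Int](maxValue + 1)
    val found = Set.newBuilder[Int]
    for (x <- xs) {
      val c = counts(x)
      if (c < thresh) {                   // no more writes once the threshold is reached
        counts(x) = c + 1
        if (c + 1 == thresh) found += x   // added the moment it crosses the threshold
      }
    }
    found.result()
  }
}
```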
I don't see how any approximate solution can be faster (only if you ignore some elements at random). Otherwise it will still be O(N).
Update
I created a microbenchmark to measure the actual performance of different implementations. For sufficiently large input and output, Ixx's suggestion of immediately adding elements to the result list doesn't produce a significant improvement. However, a similar approach could be used to eliminate unnecessary Map updates (which appear to be the most expensive operation).
Results of benchmarks (average run times on 1,000,000 elements, with pre-warming):
Author's solution: 447 ms
Ixx solution: 412 ms
Ixx solution2 (eliminated excessive map writes): 150 ms
My solution: 57 ms
My solution involves using mutable HashMap instead of immutable IntMap and includes all other possible optimizations.
Ixx's updated solution:
val tuple = (Map[Int, Int](), List[Int]())
val res = myList.foldLeft(tuple) {
  case ((m, s), i) =>
    val count = m.getOrElse(i, 0) + 1
    // stop updating the map once the count has passed thresh
    (if (count <= thresh) m + (i -> count) else m, if (count == thresh) i :: s else s)
}
My solution:
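import scala.collection.mutable
import scala.collection.mutable.ListBuffer

val map = new mutable.HashMap[Int, Int]()
val res = new ListBuffer[Int]
myList.foreach {
  i =>
    val c = map.getOrElse(i, 0) + 1
    if (c == thresh) {
      res += i   // reached the threshold: record it exactly once
    }
    if (c <= thresh) {
      map(i) = c // no need to keep counting past the threshold
    }
}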
The full microbenchmark source is available here.
You could use the foldleft to collect the matching items, like this:
val tuple = (Map[Int,Int](), List[Int]())
myList.foldLeft(tuple) {
  case((m, s), i) => {
    val count = (m.getOrElse(i, 0) + 1)
    (m + (i -> count), if (count == thresh) i :: s else s)
  }
}
I could measure a performance improvement of about 40% with a small list, so it's definitely an improvement...
Edited to use List and prepend, which takes constant time (see comments).
If by "more efficiently" you mean the space efficiency (in extreme case when the list is infinite), there's a probabilistic data structure called Count Min Sketch to estimate the frequency of items inside it. Then you can discard those with frequency below your threshold.
There's a Scala implementation in the Algebird library.
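To illustrate the idea only, here is a minimal hand-rolled count-min sketch for Int keys. It is not Algebird's API; the class name, the width/depth parameters and the simple ((a * x + b) mod p) mod width hashing are illustrative assumptions. Each row's counter can only overcount, so filtering on estimate >= thresh may give false positives but never false negatives:

```scala
import scala.util.Random

final class CountMinSketch(width: Int, depth: Int, seed: Long = 42L) {
  private val Prime = 2147483647L // 2^31 - 1
  private val rnd = new Random(seed)
  // One (a, b) pair per row for the hash ((a * x + b) mod p) mod width.
  private val mults = Array.fill(depth)(1L + rnd.nextInt(Int.MaxValue - 1))
  private val offsets = Array.fill(depth)(rnd.nextInt(Int.MaxValue).toLong)
  private val table = Array.ofDim[Long](depth, width)

  private def bucket(row: Int, x: Int): Int = {
    // floorMod keeps negative keys in range; all products stay below 2^63.
    val key = java.lang.Math.floorMod(x.toLong, Prime)
    val h = java.lang.Math.floorMod(mults(row) * key + offsets(row), Prime)
    (h % width).toInt
  }

  // Increments one counter per row.
  def add(x: Int): Unit =
    for (row <- 0 until depth) table(row)(bucket(row, x)) += 1

  // The smallest counter is the least polluted by collisions.
  def estimate(x: Int): Long =
    (0 until depth).map(row => table(row)(bucket(row, x))).min
}
```

Memory stays at width * depth counters however many distinct values appear; a wider table reduces overestimation, and more rows reduce the chance that every row collides.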
You can change your foldLeft example a bit, using a mutable.Set that is built incrementally and at the same time used as a filter for iterating over your Seq via withFilter. However, because I'm using withFilter, I cannot use foldLeft and have to make do with foreach and a mutable map:
import scala.collection.mutable

def getItems[A](in: Seq[A], threshold: Int): Set[A] = {
  val counts: mutable.Map[A, Int] = mutable.Map.empty
  val result: mutable.Set[A] = mutable.Set.empty
  in.withFilter(!result(_)).foreach { x =>
    counts.update(x, counts.getOrElse(x, 0) + 1)
    if (counts(x) >= threshold) {
      result += x
    }
  }
  result.toSet
}
So, this discards items that have already been added to the result set while running through the Seq the first time, because withFilter filters the Seq in the appended function (map, flatMap, foreach) rather than returning a filtered Seq.
EDIT:
I changed my solution to not use Seq.count, which was stupid, as Aivean correctly pointed out.
Using Aivean's microbenchmark, I can see that it is still slightly slower than his approach, but still better than the author's first approach.
Author's solution: 377
Ixx solution: 399
Ixx solution2 (eliminated excessive map writes): 110
Sascha Kolberg's solution: 72
Aivean's solution: 54

Combination of elements

Problem:
Given a Seq seq and an Int n.
I basically want all combinations of the elements up to size n. The order matters, meaning e.g. [1,2] is different from [2,1].
def combinations[T](seq: Seq[T], size: Int) = ...
Example:
combinations(List(1,2,3), 0)
//Seq(Seq())
combinations(List(1,2,3), 1)
//Seq(Seq(), Seq(1), Seq(2), Seq(3))
combinations(List(1,2,3), 2)
//Seq(Seq(), Seq(1), Seq(2), Seq(3), Seq(1,2), Seq(2,1), Seq(1,3), Seq(3,1),
//Seq(2,3), Seq(3,2))
...
What I have so far:
import scala.annotation.tailrec

def combinations[T](seq: Seq[T], size: Int) = {
  @tailrec
  def inner(seq: Seq[T], soFar: Seq[Seq[T]]): Seq[Seq[T]] = seq match {
    case head +: tail => inner(tail, soFar ++ {
      val insertList = Seq(head)
      for {
        comb <- soFar
        if comb.size < size
        index <- 0 to comb.size
      } yield comb.patch(index, insertList, 0)
    })
    case _ => soFar
  }
  inner(seq, IndexedSeq(IndexedSeq.empty))
}
What would be your approach to this problem? This method will be called a lot and therefore it should be made most efficient.
There are methods in the library like subsets or combinations (yes, I chose the same name), which return iterators. I also thought about that, but I have no idea yet how to design this lazily.
Not sure if this is efficient enough for your purpose but it's the simplest approach.
def combinations[T](seq: Seq[T], size: Int) : Seq[Seq[T]] = {
  // start at 0 so the empty sequence is included, as in the question's examples
  (0 to size).flatMap(i => seq.combinations(i).flatMap(_.permutations))
}
edit:
to make it lazy you can use a view
def combinations[T](seq: Seq[T], size: Int) : Iterable[Seq[T]] = {
  // start at 0 so the empty sequence is included, as in the question's examples
  (0 to size).view.flatMap(i => seq.combinations(i).flatMap(_.permutations))
}
From permutations theory we know that the number of permutations of K elements taken from a set of N elements is
N! / (N - K)!
(see http://en.wikipedia.org/wiki/Permutation)
Therefore, if you want to build them all, you will have
algorithm complexity = number of permutations * cost of building each permutation
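A small sketch of that count, written as the falling product N * (N-1) * ... * (N-K+1) so no large factorials are computed; summing over K = 0..n also gives the length of the full result of the question's combinations(seq, n) (function names are illustrative):

```scala
object ArrangementCounts {
  // Number of ordered selections of k elements out of n: n! / (n - k)!.
  def arrangements(n: Int, k: Int): BigInt =
    if (k < 0 || k > n) BigInt(0)
    else (0 until k).map(i => BigInt(n - i)).product // empty product is 1 when k == 0

  // Total over all sizes 0..maxSize, i.e. how many sequences the full result contains.
  def totalUpTo(n: Int, maxSize: Int): BigInt =
    (0 to maxSize).map(k => arrangements(n, k)).sum
}
```

For the question's example, totalUpTo(3, 2) is 1 + 3 + 6 = 10, matching the ten sequences listed.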
The potential optimization of the algorithm lies in minimizing the cost of building each permutation, by using a data structure that has some appending / prepending operation that runs in O(1).
You are using an IndexedSeq, which is a collection optimized for O(1) random access. When collections are optimized for random access they are backed by arrays. When using such collections (this is also valid for Java's ArrayList) you give up the guarantee of a low-cost insertion operation: sometimes the array won't be big enough and the collection will have to create a new one and copy all the elements.
When using instead linked lists (such as scala List, which is the default implementation for Seq) you take the opposite choice: you give up constant time access for constant time insertion. In particular, scala List is a recursive data structure with constant time insertion at the front.
So if you are looking for high performance and you need the collection to be available eagerly, use Seq.empty instead of IndexedSeq.empty and at each iteration prepend the new element at the head of the Seq. If you need something lazy, use Stream, which will minimize memory usage. Another strategy worth exploring is to create an empty IndexedSeq for your first iteration, but not through IndexedSeq.empty. Instead, use the builder and try to provide an array which has the right size (N! / (N-K)!).
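Combining the two ideas of this answer (prepending to List and being lazy), a sketch that enumerates every ordered selection of up to size elements through an Iterator. It assumes elements at different positions count as distinct, so unlike seq.combinations duplicate values are not merged; arrangements is a hypothetical name:

```scala
object LazyArrangements {
  // Lazily yields every ordered selection of 0..size elements, shortest first.
  def arrangements[T](seq: Seq[T], size: Int): Iterator[List[T]] = {
    val indexed = seq.toIndexedSeq
    // All selections of length k from positions not yet used, built by O(1) prepends.
    def ofLength(k: Int, used: Set[Int]): Iterator[List[T]] =
      if (k == 0) Iterator.single(Nil)
      else indexed.indices.iterator.filterNot(used).flatMap { i =>
        ofLength(k - 1, used + i).map(indexed(i) :: _)
      }
    (0 to math.min(size, indexed.size)).iterator.flatMap(k => ofLength(k, Set.empty))
  }
}
```

Nothing is materialized until the iterator is consumed, so a caller that only needs the first few results pays only for those.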

Scala - folding on values that result from object interaction

In Scala I have a list of objects that represent points and contain x and y values. The list describes a path that goes through all these points sequentially. My question is how to use folding on that list in order to find the total length of the path? Or maybe there is even a better functional or Scala way to do this?
What I have come up with is this:
def distance = (0 /: wps)(Waypoint.distance(_, _))
but of course this is totally wrong, because distance returns Float but accepts two Waypoint objects.
UPDATE:
Thanks for the proposed solutions! They are definitely interesting, but I think that this is too much functional for real-time calculations that may become heavy. So far I have come up with these lines:
val distances = for(i <- 0 until wps.size - 1) yield wps(i).distanceTo(wps(i + 1))
val distance = (0f /: distances)(_ + _)
I feel this is a fair imperative/functional mix that is both fast and keeps the distances between consecutive waypoints for later reference, which is also a benefit in my case.
UPDATE 2: Actually, to determine which is faster, I will have to benchmark all the proposed solutions on all types of sequences.
This should work.
(wps, wps drop 1).zipped.map(Waypoint.distance).sum
Don't know if fold can be used here, but try this:
wps.sliding(2).map(segment => Waypoint.distance(segment(0), segment(1))).sum
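wps.sliding(2) returns an iterator over all consecutive pairs. Or, if you prefer pattern matching (Seq(start, end) matches the pairs whatever the concrete sequence type, whereas start :: end :: Nil only matches when wps is a List):
wps.sliding(2).collect{case Seq(start, end) => Waypoint.distance(start, end)}.sum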
BTW, consider defining:
def distanceTo(to: Waypoint): Float
on the Waypoint class directly rather than on the companion object: it looks more object-oriented and will allow you to write nice DSL-like code:
point1.distanceTo(point2)
or even:
point1 distanceTo point2
wps.sliding(2).collect {
  case Seq(start, end) => start distanceTo end
}.sum
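A self-contained sketch of the zip and sliding versions side by side. The question only says a waypoint has x and y values, so the `Waypoint` case class and its Euclidean `distanceTo` below are assumptions, as are the names `PathLength`, `lengthZip` and `lengthSliding`.

```scala
object PathLength {
  // Hypothetical Waypoint: a point with Float coordinates.
  final case class Waypoint(x: Float, y: Float) {
    def distanceTo(to: Waypoint): Float =
      math.hypot((to.x - x).toDouble, (to.y - y).toDouble).toFloat
  }

  // Pair each point with its successor. zip stops at the shorter side,
  // so empty and single-point paths have length 0.
  def lengthZip(wps: Seq[Waypoint]): Float =
    wps.zip(wps.drop(1)).map { case (a, b) => a distanceTo b }.sum

  // sliding(2) version. A one-point path yields a single window of size 1,
  // which the Seq(p, q) pattern skips instead of throwing.
  def lengthSliding(wps: Seq[Waypoint]): Float =
    wps.sliding(2).collect { case Seq(p, q) => p distanceTo q }.sum
}
```

The zip version also works unchanged for a Vector or an Array-backed Seq, which matters for the benchmarking across sequence types mentioned in the question.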
Your comment "too much functional for real-time calculations that may become heavy" makes this interesting. Benchmarking and profiling are critical, since you don't want to write a bunch of hard-to-maintain code for the sake of performance, only to find out that it's not a performance-critical part of your application in the first place! Or, even worse, to find out that your performance optimizations make things worse for your specific workload.
The best-performing implementation will depend on your specifics (How long are the paths? How many cores are on the system?). But I think blending imperative and functional approaches may give you the worst of both worlds. You could lose out on both readability and performance if you're not careful!
I would very slightly modify missingfaktor's answer to let you benefit from parallel collections. The fact that simply adding .par could give you a tremendous performance boost demonstrates the power of sticking with functional programming!
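def distancePar(wps: collection.GenSeq[Waypoint]): Double = {
  // GenSeq and .par as in Scala 2.9-2.12 (2.13+ needs the separate scala-parallel-collections module);
  // distance is Waypoint.distance, imported into scope
  val parwps = wps.par
  parwps.zip(parwps drop 1).map { case (a, b) => distance(a, b) }.sum
}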
My guess is that this would work best if you have several cores to throw at the problem and wps tends to be somewhat long. If you have few cores or short paths, then parallelism will probably hurt more than it helps.
The other extreme would be a fully imperative solution. Writing imperative implementations of individual performance-critical functions is usually acceptable, so long as you avoid shared mutable state. But once you get used to FP, you'll find this sort of function more difficult to write and maintain. It's also not easy to parallelize.
def distanceImp(wps: collection.GenSeq[Waypoint]): Double = {
  if (wps.size <= 1) {
    0.0
  } else {
    var r = 0.0
    var here = wps.head
    var remaining = wps.tail
    while (!remaining.isEmpty) {
      r += distance(here, remaining.head)
      here = remaining.head
      remaining = remaining.tail
    }
    r
  }
}
Finally, if you're looking for a middle ground between FP and imperative, you might try recursion. I haven't profiled it, but my guess is that this will be roughly equivalent to the imperative solution in terms of performance.
def distanceRec(wps: collection.GenSeq[Waypoint]): Double = {
  @annotation.tailrec
  def helper(acc: Double, here: Waypoint, remaining: collection.GenSeq[Waypoint]): Double =
    if (remaining.isEmpty)
      acc
    else
      helper(acc + distance(here, remaining.head), remaining.head, remaining.tail)
  if (wps.size <= 1)
    0.0
  else
    helper(0.0, wps.head, wps.tail)
}
If you are doing indexing of any kind, you want to be using Vector, not List:
scala> def timed(op: => Unit) = { val start = System.nanoTime; op; (System.nanoTime - start) / 1e9 }
timed: (op: => Unit)Double
scala> val l = List.fill(100000)(1)
scala> val v = Vector.fill(100000)(1)
scala> timed { var t = 0; for (i <- 0 until l.length - 1) yield t += l(i) + l(i + 1) }
res2: Double = 16.252194583
scala> timed { var t = 0; for (i <- 0 until v.length - 1) yield t += v(i) + v(i + 1) }
res3: Double = 0.047047654
ListBuffer offers fast appends, but it doesn't offer fast random access.
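If the data has to stay in a List, another option is to avoid indexing altogether. A minimal sketch comparing the two, with hypothetical names (`AdjacentSums`, `byIndex`, `byZip`): the indexed loop is fine on a Vector, while the zip version traverses a List in a single linear pass.

```scala
object AdjacentSums {
  // Indexed version: fine on Vector (effectively constant-time apply),
  // but quadratic on a List because xs(i) walks i cells every time.
  def byIndex(xs: IndexedSeq[Int]): Long = {
    var t = 0L
    for (i <- 0 until xs.length - 1) t += xs(i) + xs(i + 1)
    t
  }

  // Index-free version: pairs each element with its successor in one pass.
  def byZip(xs: List[Int]): Long =
    xs.zip(xs.drop(1)).foldLeft(0L) { case (acc, (a, b)) => acc + a + b }
}
```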

Infinite streams in Scala

Say I have a function, for example the old favourite
def factorial(n:Int) = (BigInt(1) /: (1 to n)) (_*_)
Now I want to find the biggest value of n for which factorial(n) fits in a Long. I could do
(1 to 100) takeWhile (factorial(_) <= Long.MaxValue) last
This works, but the 100 is an arbitrary large number; what I really want on the left hand side is an infinite stream that keeps generating higher numbers until the takeWhile condition is met.
I've come up with
val s = Stream.continually(1).zipWithIndex.map(p => p._1 + p._2)
but is there a better way?
(I'm also aware I could get a solution recursively but that's not what I'm looking for.)
Stream.from(1)
creates a stream starting from 1 and incrementing by 1. It's all in the API docs.
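On Scala 2.13 and later, Stream is deprecated in favour of LazyList, which has the same `from` method. A minimal sketch of the question's takeWhile/last idea on top of it, assuming Scala 2.13+ (on older versions substitute `Stream.from(1)`); `InfiniteRange` and `largestFittingInLong` are hypothetical names.

```scala
object InfiniteRange {
  def factorial(n: Int): BigInt = (1 to n).foldLeft(BigInt(1))(_ * _)

  // Elements of the LazyList are computed on demand; takeWhile yields a finite
  // prefix that ends before the first factorial that no longer fits in a Long.
  def largestFittingInLong: Int =
    LazyList.from(1).takeWhile(n => factorial(n) <= BigInt(Long.MaxValue)).last
}
```

Here `largestFittingInLong` evaluates to 20, since 20! is below Long.MaxValue and 21! is above it.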
A Solution Using Iterators
You can also use an Iterator instead of a Stream. A Stream memoizes every value it computes, so as long as something references its head, all of those values stay in memory. If you plan to visit each value only once, an iterator is the more efficient approach. The downside of the iterator is its mutability, though.
There are some nice convenience methods for creating Iterators defined on its companion object, such as Iterator.from, Iterator.iterate and Iterator.continually.
Edit
Unfortunately, there's no short (library-supported) way I know of to achieve the Iterator equivalent of
Stream.from(1) takeWhile (factorial(_) <= Long.MaxValue) last
The approach I take to advance an Iterator by a certain number of elements, or until a condition fails, is drop(n: Int) or dropWhile:
Iterator.from(1).dropWhile( factorial(_) <= Long.MaxValue).next() - 1
The - 1 works for this special purpose but is not a general solution. It should be no problem, though, to implement a last method on an Iterator using the pimp-my-library pattern. The catch is that taking the last element of an infinite Iterator never terminates, so it should be implemented as a method like lastWith that integrates the takeWhile.
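A minimal sketch of such an enrichment, assuming Scala 2.10+ implicit classes. `IteratorSyntax`, `LastWhileOps` and `lastWhile` are hypothetical names, not part of the standard library. It returns an Option so that an iterator whose first element already fails the predicate is handled without exceptions.

```scala
object IteratorSyntax {
  // Hypothetical enrichment of Iterator ("pimp my library").
  implicit class LastWhileOps[A](it: Iterator[A]) {
    // Consumes elements while p holds and returns the last one that satisfied p.
    // It stops at the first failing element, so it terminates on an infinite
    // iterator as long as p eventually fails. The iterator is consumed.
    def lastWhile(p: A => Boolean): Option[A] = {
      var last: Option[A] = None
      var going = true
      while (going && it.hasNext) {
        val a = it.next()
        if (p(a)) last = Some(a) else going = false
      }
      last
    }
  }
}
```

With `import IteratorSyntax._` in scope, `Iterator.from(1).lastWhile(n => factorial(n) <= Long.MaxValue)` should give `Some(20)`.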
An ugly workaround can be done using sliding, which is implemented for Iterator:
scala> Iterator.from(1).sliding(2).dropWhile(_.tail.head < 10).next.head
res12: Int = 9
As @ziggystar pointed out, a Stream keeps previously computed values in memory, so using an Iterator is a great improvement.
To further improve the answer, I would argue that "infinite streams" are usually computed (or can be computed) from previously computed values. If this is the case (and in your factorial stream it definitely is), I would suggest using Iterator.iterate instead.
It would look roughly like this:
scala> val it = Iterator.iterate((1,BigInt(1))){case (i,f) => (i+1,f*(i+1))}
it: Iterator[(Int, scala.math.BigInt)] = non-empty iterator
then, you could do something like:
scala> it.find(_._2 >= Long.MaxValue).map(_._1).get - 1
res0: Int = 20
Or use @ziggystar's sliding solution.
Another easy example that comes to mind would be Fibonacci numbers:
scala> val it = Iterator.iterate((1,1)){case (a,b) => (b,a+b)}.map(_._1)
it: Iterator[Int] = non-empty iterator
In these cases, you're not computing each new element from scratch, but rather doing a single step of work (one multiplication or one addition) per new element, which improves your running time even more.
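The two Iterator.iterate examples above, gathered into one runnable sketch (assuming Scala 2.12 or later; `IterateExamples` and its members are hypothetical names). The Fibonacci numbers use Long here so the sequence can run longer before overflowing.

```scala
object IterateExamples {
  // (n, n!) pairs: each step does one BigInt multiplication instead of
  // recomputing the factorial from scratch.
  def factorials: Iterator[(Int, BigInt)] =
    Iterator.iterate((1, BigInt(1))) { case (i, f) => (i + 1, f * (i + 1)) }

  // Largest n whose factorial still fits in a Long: find the first n that
  // does not fit (find on this infinite iterator returns Some or never ends).
  def largestFittingInLong: Int =
    factorials.find { case (_, f) => f > BigInt(Long.MaxValue) }.map(_._1).get - 1

  // Fibonacci numbers 1, 1, 2, 3, 5, ... computed from the previous pair.
  def fibs: Iterator[Long] =
    Iterator.iterate((1L, 1L)) { case (a, b) => (b, a + b) }.map(_._1)
}
```

`largestFittingInLong` evaluates to 20, and `fibs.take(10).toList` to `List(1, 1, 2, 3, 5, 8, 13, 21, 34, 55)`.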
The original "factorial" function is not optimal, since factorials are computed from scratch every time. The simplest/immutable implementation using memoization is like this:
val f : Stream[BigInt] = 1 #:: (Stream.from(1) zip f).map { case (x,y) => x * y }
And now, the answer can be computed like this:
println( "count: " + (f takeWhile (_<Long.MaxValue)).length )
The following variant tests the next integer rather than the current one, so it finds and returns the last valid number directly:
Iterator.from(1).find(i => factorial(i+1) > Long.MaxValue).get
Using .get here is acceptable, since find on an infinite iterator either returns a Some or never terminates; it can never return None.

Why is Clojure much faster than Scala on a recursive add function?

A friend gave me this code snippet in Clojure
(defn sum [coll acc] (if (empty? coll) acc (recur (rest coll) (+ (first coll) acc))))
(time (sum (range 1 9999999) 0))
and asked me how does it fare against a similar Scala implementation.
The Scala code I've written looks like this:
def from(n: Int): Stream[Int] = Stream.cons(n, from(n+1))
val ints = from(1).take(9999998)
def add(a: Stream[Int], b: Long): Long = {
  if (a.isEmpty) b else add(a.tail, b + a.head)
}
val t1 = System.currentTimeMillis()
println(add(ints, 0))
val t2 = System.currentTimeMillis()
println((t2 - t1).asInstanceOf[Float] + " msecs")
Bottom line: the Clojure code runs in about 1.8 seconds on my machine and uses less than 5MB of heap, while the Scala code runs in about 12 seconds, and 512MB of heap isn't enough (it finishes the computation if I set the heap to 1GB).
So I'm wondering why Clojure is so much faster and slimmer in this particular case. Do you have a Scala implementation that behaves similarly in terms of speed and memory usage?
Please refrain from religious remarks; my interest lies primarily in finding out what makes Clojure so fast in this case and whether there's a faster implementation of the algorithm in Scala. Thanks.
First, Scala only optimises tail calls if you invoke it with -optimise. Edit: It seems Scala will always optimise tail-call recursions if it can, even without -optimise.
Second, Stream and Range are two very different things. A Range has a beginning and an end, and its projection has just a counter and the end. A Stream is a list which will be computed on demand. Since you are summing all of ints, you'll compute, and therefore allocate, the whole Stream.
A closer code would be:
import scala.annotation.tailrec
def add(r: Range) = {
  @tailrec
  def f(i: Iterator[Int], acc: Long): Long =
    if (i.hasNext) f(i, acc + i.next()) else acc
  f(r.iterator, 0)
}
def time(f: => Unit): Unit = {
  val t1 = System.currentTimeMillis()
  f
  val t2 = System.currentTimeMillis()
  println((t2 - t1).toFloat + " msecs")
}
Normal run:
scala> time(println(add(1 to 9999999)))
49999995000000
563.0 msecs
On Scala 2.7 you need "elements" instead of "iterator", and there's no "tailrec" annotation (that annotation is used only to complain if a definition can't be optimized with tail recursion), so you'll need to strip "@tailrec" as well as "import scala.annotation.tailrec" from the code.
Also, some considerations on alternate implementations. The simplest:
scala> time(println(1 to 9999999 reduceLeft (_+_)))
-2014260032
640.0 msecs
On average, over multiple runs here, it is slower. It's also incorrect, because it accumulates in an Int, which overflows. A correct one:
scala> time(println((1 to 9999999 foldLeft 0L)(_+_)))
49999995000000
797.0 msecs
That's slower still, running here. I honestly wouldn't have expected it to run slower, but each iteration calls the function being passed. Once you consider that, it's a pretty good time compared to the recursive version.
Clojure's range does not memoize, Scala's Stream does. Totally different data structures with totally different results. Scala does have a non-memoizing Range structure, but it's currently kind of awkward to work with in this simple recursive way. Here's my take on the whole thing.
Using Clojure 1.0 on an older box, which is slow, I get 3.6 seconds
user=> (defn sum [coll acc] (if (empty? coll) acc (recur (rest coll) (+ (first coll) acc))))
#'user/sum
user=> (time (sum (range 1 9999999) 0))
"Elapsed time: 3651.751139 msecs"
49999985000001
A literal translation to Scala requires a bit of supporting code, starting with a timer:
def time[T](x : => T) = {
  val start = System.nanoTime : Double
  val result = x
  val duration = (System.nanoTime : Double) - start
  println("Elapsed time " + duration / 1000000.0 + " msecs")
  result
}
It's good to make sure that it's right:
scala> time (Thread sleep 1000)
Elapsed time 1000.277967 msecs
Now we need an unmemoized Range with similar semantics to Clojure's
case class MyRange(start : Int, end : Int) {
  def isEmpty = start >= end
  // sys.error; older Scala versions used Predef.error
  def first = if (!isEmpty) start else sys.error("empty range")
  def rest = new MyRange(start + 1, end)
}
From that, add follows directly:
def add(a: MyRange, b: Long): Long = {
  if (a.isEmpty) b else add(a.rest, b + a.first)
}
And it times much faster than Clojure's on the same box
scala> time(add(MyRange(1, 9999999), 0))
Elapsed time 252.526784 msecs
res1: Long = 49999985000001
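To have the compiler confirm that add really becomes a loop, rather than assuming it, the same code can carry the standard `@tailrec` annotation, which makes compilation fail if the recursive call cannot be optimized. A self-contained sketch; the wrapper object `MyRangeSum` is a hypothetical name, and `sys.error` replaces the older `Predef.error`.

```scala
import scala.annotation.tailrec

object MyRangeSum {
  // Same unmemoized range as above: only the current start and the end.
  case class MyRange(start: Int, end: Int) {
    def isEmpty: Boolean = start >= end
    def first: Int = if (!isEmpty) start else sys.error("empty range")
    def rest: MyRange = MyRange(start + 1, end)
  }

  // @tailrec turns "this should use constant stack" into a compile-time check.
  @tailrec
  def add(a: MyRange, b: Long): Long =
    if (a.isEmpty) b else add(a.rest, b + a.first)
}
```

`MyRangeSum.add(MyRangeSum.MyRange(1, 9999999), 0L)` should return the same 49999985000001 as the Clojure version.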
Using Scala's standard library Range, you can do a fold. It's not as fast as simple primitive recursion, but it's less code and still faster than the Clojure recursive version (at least on my box).
scala> time((1 until 9999999 foldLeft 0L)(_ + _))
Elapsed time 1995.566127 msecs
res2: Long = 49999985000001
Contrast with a fold over a memoized Stream
scala> time((Stream from 1 take 9999998 foldLeft 0L)(_ + _))
Elapsed time 3879.991318 msecs
res3: Long = 49999985000001
I suspect it's due to how Clojure handles tail-call optimizations. Since the JVM doesn't natively perform this optimization (and both Clojure and Scala run on it), Clojure optimizes tail recursion through the recur keyword. From the Clojure site:
In functional languages looping and iteration are replaced/implemented via recursive function calls. Many such languages guarantee that function calls made in tail position do not consume stack space, and thus recursive loops utilize constant space. Since Clojure uses the Java calling conventions, it cannot, and does not, make the same tail call optimization guarantees. Instead, it provides the recur special operator, which does constant-space recursive looping by rebinding and jumping to the nearest enclosing loop or function frame. While not as general as tail-call-optimization, it allows most of the same elegant constructs, and offers the advantage of checking that calls to recur can only happen in a tail position.
EDIT: Scala optimizes tail calls also, as long as they're in a certain form. However, as the previous link shows, Scala can only do this for very simple cases:
In fact, this is a feature of the Scala compiler called tail call optimization. It optimizes away the recursive call. This feature works only in simple cases as above, though. If the recursion is indirect, for example, Scala cannot optimize tail calls, because of the limited JVM instruction set.
Without actually compiling and decompiling your code to see what JVM instructions are produced, I suspect it's just not one of those simple cases (as Michael put it, due to having to fetch a.tail on each recursive step) and thus Scala just can't optimize it.
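For the indirect case the quote mentions, the standard library offers a workaround that is a different technique from compiler tail-call optimization: trampolining with `scala.util.control.TailCalls`, which keeps the pending calls on the heap instead of the stack. A minimal sketch with a hypothetical mutually recursive pair, `isEven` and `isOdd`:

```scala
import scala.util.control.TailCalls._

object MutualRecursion {
  // isEven and isOdd call each other (indirect recursion), which scalac cannot
  // turn into a loop. Wrapping each call in tailcall and each final value in
  // done lets .result run them in a loop, so deep recursion doesn't overflow.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))
}
```

`MutualRecursion.isEven(1000000).result` returns true. The price is an allocation per call, so this is about stack safety, not speed.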
Profiled this example of yours and it seems that the class Stream (well... some anonymous function related to it; I forgot its name as VisualVM crashed on me) occupies most of the heap. It's related to the fact that Streams in Scala do leak memory; see Scala Trac #692. Fixes are due in Scala 2.8. EDIT: Daniel's comment rightly pointed out that it is not related to this bug. It's because "val ints points to the Stream head, the garbage collector can't collect anything" [Daniel]. I found the comments in this bug report nice to read though, in relation to this question.
In your add function, you are holding a reference to a.head, so the garbage collector cannot collect the head, leading to a stream that holds 9999998 elements in the end, none of which can be GC-ed.
[A little interlude]
You may also keep copies of the tails you keep passing; I am not sure how Streams deal with that. If you used a List, the tails would not be copied. For example:
val xs = List(1,2,3)
val ys = 1 :: xs
val zs = 2 :: xs
Here, both ys and zs 'share' the same tail, at least heap-wise (ys.tail eq zs.tail, aka reference equality yields true).
[This little interlude was to make the point that passing a lot of tails is not a really bad thing in principle :), they are not copied, at least for lists]
An alternative implementation (which runs quite fast, and I think it is more clear than the pure functional one) is to use an imperative approach:
def addTo(n: Int, init: Int): Long = {
  var sum = init.toLong
  for(i <- 1 to n) sum += i
  sum
}
scala> addTo(9999998, 0)
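In Scala it is quite OK to use an imperative approach for performance and clarity (at least to me, this version of add is clearer in its intent). For even more conciseness, you could write the following (with a Long seed for the fold, since reduceLeft(_ + _) would accumulate in an Int and overflow for this n):
(1 to 9999998).foldLeft(0L)(_ + _)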
(runs a bit slower, but still reasonable and doesn't blow the memory up)
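A minimal sketch of both variants side by side, with hypothetical names (`AddTo`, `addToWhile`, `addToFold`). The while loop is an assumption about what "even more imperative" could look like: a for over a Range desugars to `Range.foreach` with a closure that captures `sum`, whereas the while loop compiles to a plain JVM loop.

```scala
object AddTo {
  // Same contract as addTo above, as a plain while loop without a closure.
  // Assumes n < Int.MaxValue, otherwise i <= n would never become false.
  def addToWhile(n: Int, init: Int): Long = {
    var sum = init.toLong
    var i = 1
    while (i <= n) {
      sum += i
      i += 1
    }
    sum
  }

  // The one-liner, seeded with a Long so the running total cannot overflow Int.
  def addToFold(n: Int, init: Int): Long =
    (1 to n).foldLeft(init.toLong)(_ + _)
}
```

Both should return 49999985000001 for `(9999998, 0)`; which one is faster on a given JVM is something to measure rather than assume.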
I believe that Clojure might be faster as it is fully functional, therefore more optimisations are possible than with Scala (which blends functional, OO and imperative). I am not very familiar with Clojure though.
Hope this helps :)