Scala: what is the most appropriate data structure for sorted subsets?

Given a large collection (let's call it 'a') of elements of type T (say, a Vector or List) and an evaluation function 'f' (say, (T) => Double) I would like to derive from 'a' a result collection 'b' that contains the N elements of 'a' that result in the highest value under f. The collection 'a' may contain duplicates. It is not sorted.
Maybe leaving the question of parallelizability (map/reduce etc.) aside for a moment, what would be the appropriate Scala data structure for compiling the result collection 'b'? Thanks for any pointers / ideas.
Notes:
(1) I guess my use case can be most concisely expressed as
val a = Vector( 9,2,6,1,7,5,2,6,9 ) // just an example
val f : (Int)=>Double = (n)=>n // evaluation function
val b = a.sortBy( f ).takeRight( N ) // sort ascending, keep the N highest
except that I do not want to sort the entire set.
(2) one option might be an iteration over 'a' that fills a TreeSet with 'manual' size bounding (reject anything worse than the worst item in the set, don't let the set grow beyond N). However, I would like to retain duplicates present in the original set in the result set, and so this may not work.
(3) if a sorted multi-set is the right data structure, is there a Scala implementation of this? Or a binary-sorted Vector or Array, if the result set is reasonably small?

You can use a priority queue:
def firstK[A](xs: Seq[A], k: Int)(implicit ord: Ordering[A]) = {
  // min-heap of the k best candidates seen so far (its head is the worst of them)
  val q = new scala.collection.mutable.PriorityQueue[A]()(ord.reverse)
  val (before, after) = xs.splitAt(k)
  q ++= before
  // each remaining element either evicts the current worst candidate or is discarded
  after.foreach(x => q += ord.max(x, q.dequeue))
  q.dequeueAll
}
We fill the queue with the first k elements and then compare each additional element to the head of the queue, swapping as necessary. This works as expected and retains duplicates:
scala> firstK(Vector(9, 2, 6, 1, 7, 5, 2, 6, 9), 4)
res14: scala.collection.mutable.Buffer[Int] = ArrayBuffer(6, 7, 9, 9)
And it doesn't sort the complete list. I've got an Ordering in this implementation, but adapting it to use an evaluation function would be pretty trivial.
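For instance, with the question's own data and evaluation function, an Ordering derived from f can be passed explicitly; a sketch using Ordering.by:
val a = Vector(9, 2, 6, 1, 7, 5, 2, 6, 9)
val f: Int => Double = n => n.toDouble
// Ordering.by(f) ranks elements by their value under f, so firstK
// keeps the k elements of `a` with the highest f-values.
firstK(a, 4)(Ordering.by(f)) // ArrayBuffer(6, 7, 9, 9)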

Functional Programming way to calculate something like a rolling sum

Let's say I have a list of numerics:
val list = List(4,12,3,6,9)
For every element in the list, I need to find the rolling sum, i.e. the final output should be:
List(4, 16, 19, 25, 34)
Is there any transformation that allows us to take as input two elements of the list (the current and the previous) and compute based on both?
Something like map(initial)((curr,prev) => curr+prev)
I want to achieve this without maintaining any shared global state.
EDIT: I would like to be able to do the same kinds of computation on RDDs.
You may use scanLeft:
list.scanLeft(0)(_ + _).tail
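scanLeft folds from the left but emits every intermediate accumulator, so its head is the initial value, which tail drops:
val list = List(4, 12, 3, 6, 9)
list.scanLeft(0)(_ + _)      // List(0, 4, 16, 19, 25, 34)
list.scanLeft(0)(_ + _).tail // List(4, 16, 19, 25, 34)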
The cumSum method below should work for any RDD[N], where N has an implicit Numeric[N] available, e.g. Int, Long, BigInt, Double, etc.
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD
def cumSum[N : Numeric : ClassTag](rdd: RDD[N]): RDD[N] = {
  val num = implicitly[Numeric[N]]
  val nPartitions = rdd.partitions.length
  // Sum each partition except the last, then prefix-scan those sums:
  // partitionCumSums(i) is the total of everything before partition i.
  val partitionCumSums = rdd.mapPartitionsWithIndex((index, iter) =>
    if (index == nPartitions - 1) Iterator.empty
    else Iterator.single(iter.foldLeft(num.zero)(num.plus))
  ).collect
   .scanLeft(num.zero)(num.plus)
  // Within each partition, scan forward from that partition's offset.
  rdd.mapPartitionsWithIndex((index, iter) =>
    if (iter.isEmpty) iter
    else {
      val start = num.plus(partitionCumSums(index), iter.next)
      iter.scanLeft(start)(num.plus)
    }
  )
}
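A hypothetical usage sketch (assuming a SparkContext named sc is in scope):
val rdd = sc.parallelize(List(4, 12, 3, 6, 9))
cumSum(rdd).collect() // Array(4, 16, 19, 25, 34)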
It should be fairly straightforward to generalize this method to any associative binary operator with a "zero" (i.e., any monoid). It is the associativity that is key for the parallelization; without it, you're generally stuck running through the entries of the RDD serially.
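As a sketch of that generalization (cumFold is an illustrative name, not a library method; it reuses the imports above), the structure is identical to cumSum, with the zero and the associative operator passed in explicitly:
// Prefix scan over an RDD for any monoid (zero, op).
// Only correct when `op` is associative and `zero` is its identity.
def cumFold[A: ClassTag](rdd: RDD[A])(zero: A)(op: (A, A) => A): RDD[A] = {
  val nPartitions = rdd.partitions.length
  // offsets(i) is the combined value of everything before partition i
  val offsets = rdd.mapPartitionsWithIndex((index, iter) =>
    if (index == nPartitions - 1) Iterator.empty
    else Iterator.single(iter.foldLeft(zero)(op))
  ).collect
   .scanLeft(zero)(op)
  rdd.mapPartitionsWithIndex((index, iter) =>
    if (iter.isEmpty) iter
    else {
      val start = op(offsets(index), iter.next)
      iter.scanLeft(start)(op)
    }
  )
}
For example, cumFold(rdd)(1)(_ * _) on an RDD[Int] would yield running products.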
I don't know what functionalities are supported by Spark RDDs, so I am not sure if this satisfies your conditions, because I don't know if zipWithIndex is supported (if the answer is not helpful, please let me know in a comment and I will delete my answer):
list.zipWithIndex.map { x => list.take(x._2 + 1).sum }
This code works for me; it sums up the elements. It takes the index of each list element and then sums the first index + 1 elements of the list (note the +1, since zipWithIndex starts at 0). Be aware that re-summing a prefix for every element makes this quadratic in the length of the list.
When printing it, I get the following:
List(4, 16, 19, 25, 34)

Breakdown of reduce function

I currently have:
x.collect()
# Result: Array[Int] = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
val a = x.reduce((x,y) => x+1)
# a: Int = 6
val b = x.reduce((x,y) => y + 1)
# b: Int = 12
I have tried to follow what has been said here (http://www.python-course.eu/lambda.php) but still don't quite understand what the individual operations are that lead to these answers.
Could anyone please explain the steps being taken here?
Thank you.
The reason is that the function (x, y) => x + 1 is not associative. reduce requires an associative function. This is necessary to avoid indeterminacy when combining results across different partitions.
You can think of the reduce() method as grabbing two elements from the collection, applying them to a function that results in a new element, and putting that new element back in the collection, replacing the two it grabbed. This is done repeatedly until there is only one element left. In other words, that previous result is going to get re-grabbed until there are no more previous results.
So you can see where (x,y) => x+1 results in a new value (x+1) which would be different from (x,y) => y+1. Cascade that difference through all the iterations and ...
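A plain-Scala sketch of that cascade; assuming the RDD above had two partitions of five elements each (consistent with the observed results), Spark reduces each partition locally and then combines the partial results with the same function:
val xs = (1 to 10).toList
// A sequential reduce applies the function strictly left to right:
xs.reduce((x, y) => x + 1) // 10
// With two partitions, each is reduced locally first...
val (p1, p2) = xs.splitAt(5)
val r1 = p1.reduce((x, y) => x + 1) // 5
val r2 = p2.reduce((x, y) => x + 1) // 10
// ...and the partial results are combined with the same function:
List(r1, r2).reduce((x, y) => x + 1) // 6, matching `a` above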

Scala: Side effects with collection transformations

I am trying to understand views in Scala via this link: http://docs.scala-lang.org/overviews/collections/views.html.
I didn't understand what it means for collection transformations to have (or not have) side effects.
Thank you
By having side effects is meant a situation where the code executed in a collection transformation closes over some external state, or affects anything other than the result of the transformation. Example:
val l = List(1, 2, 3, 4).view.map(x => {println(x); x + 1})
When you execute this code it will print nothing, because view delays the execution of map. Furthermore, each time you iterate over this list, map will be executed again, printing the values more times than desired.
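For example, traversing l twice runs the mapping (and the println) twice:
l.foreach(_ => ()) // prints 1 2 3 4
l.foreach(_ => ()) // prints 1 2 3 4 again: map re-runs on every traversal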
var counter = 0
val ll = for (i <- List(1, 2, 3, 4).view)
  yield { counter += 1; i + 1 }
println(counter) // 0
println(ll.toList) // this executes .force internally
println(counter) // 4
This behaves in the same way, but is even more surprising: counter increases only once the iteration actually happens, and since ll is lazy and delayed, that iteration may occur much deeper in the code, leaving counter equal to 0 until then.
Scala has immutable collections (everything in scala.collection.immutable). These collection types have no operations to modify them, only operations that return modified copies.
So for example this
Set(1) + 2
will give you a new Set containing 1 and 2, not modify the first set. The same holds for transformations such as map, flatMap and filter.
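For example:
val s1 = Set(1, 2, 3)
val s2 = s1.map(_ * 2) // Set(2, 4, 6); s1 is left unchanged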
Views do not change anything about that. The only difference between a view and the collection it is based on is that (most) operations on views are lazy, i.e. intermediate results are not computed.
val l1 = List(1,2)
val l2 = l1.map(x => x + 1) // a new List(2,3) is computed here
l2.foreach(println) // the elements of l2 are just printed
With views:
val v2 = l1.view.map(x => x + 1) // nothing is computed here
v2.foreach(println) // the values are computed step by step
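When the strict result is eventually needed, the view can be materialized (a sketch against the pre-2.13 view API):
val l3 = v2.force // List(2, 3): computes all elements at once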

Constructing BitSets in Scala from a predicate?

Suppose I want to construct a BitSet containing all integers from 0 until n satisfying some predicate f: Int => Boolean.
I could write something like
BitSet((0 until n):_*).filter(f)
which of course works. But it feels rather inefficient! I'm planning on doing this inside a pretty tight loop, and would like suggestions for more efficient ways.
This is the best I could come up with at the moment
BitSet((0 until n).view.filter(f):_*)
The view part makes the filter method lazy. This ensures that when the BitSet is created from the given sequence, it filters on the fly. Your original version builds a complete BitSet first and then constructs a second one from the filtered elements.
If performance is truly your major concern, the best option is probably to use a mutable.BitSet and a while loop, and then call toImmutable on the result.
val bitSet = {
  val tmp = new scala.collection.mutable.BitSet(n)
  var i = 0
  while (i < n) {
    if (f(i)) {
      tmp += i
    }
    i += 1
  }
  tmp.toImmutable
}
I think the most efficient "functional" way is to use foldLeft:
(0 until n).foldLeft(BitSet())((s, i) => if (f(i)) s + i else s)
It doesn't create an intermediate collection but constructs the result from scratch while filtering.
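For instance, with a hypothetical predicate keeping the even numbers:
import scala.collection.immutable.BitSet
val f: Int => Boolean = _ % 2 == 0
(0 until 10).foldLeft(BitSet())((s, i) => if (f(i)) s + i else s)
// BitSet(0, 2, 4, 6, 8)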
The first thing I thought is to use breakOut, but it doesn't work for filter:
scala> val set: BitSet = (0 until 10).filter(f)(collection.breakOut)
<console>:11: error: polymorphic expression cannot be instantiated to expected type;
found : [From, T, To]scala.collection.generic.CanBuildFrom[From,T,To]
required: Int
val set: BitSet = (0 until 10).filter(f)(collection.breakOut)
^
scala> val set: BitSet = (0 until 10).map(_+1)(collection.breakOut)
set: scala.collection.immutable.BitSet = BitSet(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
breakOut doesn't create an intermediate collection either, but because filter doesn't have a second (CanBuildFrom) parameter list, it can't be used there.
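One workaround, sketched against the Scala 2.12-era API: collect does take a CanBuildFrom in its second parameter list, so a partial function can fuse the filtering with breakOut:
// untested sketch: filter and build the BitSet in one pass via collect
val set: BitSet = (0 until n).collect { case i if f(i) => i }(collection.breakOut)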

Seq with maximal elements

I have a Seq and a function Int => Int. What I need to achieve is to take from the original Seq only those elements whose value under the function equals the maximum of the mapped sequence (the one I'll have after applying the given function):
def mapper: Int => Int = x => x * x
val s = Seq(-2, -2, 2, 2)
val themax = s.map(mapper).max
s.filter(mapper(_) == themax)
But this seems wasteful, since it has to map twice (once for the maximum, and again for the filter).
Is there a better way to do this (hopefully without an explicit loop)?
EDIT
The code has since been edited; in the original, the filter line was s.filter( mapper(_)==s.map(mapper).max ). As om-nom-nom pointed out, that evaluates s.map(mapper).max on every filter iteration, leading to quadratic complexity.
Here is a solution that does the mapping only once, using the foldLeft function:
The principle is to walk the seq once; for each element, if its mapped value is greater than every mapped value seen so far, start a new sequence with it; if it is equal to the current maximum, append it to the accumulated maximums; if it is less, keep the previously accumulated Seq of maximums unchanged.
def getMaxElems1(s: Seq[Int])(mapper: Int => Int): Seq[Int] =
  s.foldLeft(Seq[(Int, Int)]())((res, elem) => {
    val e2 = mapper(elem)
    if (res.isEmpty || e2 > res.head._2)
      Seq((elem, e2))        // new maximum: restart the accumulator
    else if (e2 == res.head._2)
      res ++ Seq((elem, e2)) // ties the current maximum: keep it too
    else
      res                    // below the maximum: ignore
  }).map(_._1)               // keep only the original elements
// test with your list
scala> getMaxElems1(s)(mapper)
res14: Seq[Int] = List(-2, -2, 2, 2)
//test with a list containing also non maximal elements
scala> getMaxElems1(Seq(-1, 2,0, -2, 1,-2))(mapper)
res15: Seq[Int] = List(2, -2, -2)
Remark: about complexity
The algorithm I present above has a complexity of O(N) for a list with N elements. However, the zip-based variant below does no more work:
the operation of mapping all elements is of complexity O(N)
the operation of computing the max is of complexity O(N)
the operation of zipping is of complexity O(N)
the operation of filtering the list against the max is of complexity O(N)
the operation of projecting back to the original elements is of complexity O(M), with M the number of maximal elements
So, in the end, the algorithm you presented in your question has the same asymptotic complexity as my answer's, and your solution is also clearer than mine. So, even if foldLeft is more powerful, for this operation I would recommend your idea, but zipping the original list and computing the map only once (especially if your map is more complicated than a simple square). Here is that solution, worked out with the help of *scala_newbie* in question/chat/comments.
def getMaxElems2(s: Seq[Int])(mapper: Int => Int): Seq[Int] = {
  val mappedS = s.map(mapper)                // map done only once
  val m = mappedS.max                        // find the max
  s.zip(mappedS).filter(_._2 == m).unzip._1  // keep elements whose image is the max
}
// test with your list
scala> getMaxElems2(s)(mapper)
res16: Seq[Int] = List(-2, -2, 2, 2)
//test with a list containing also non maximal elements
scala> getMaxElems2(Seq(-1, 2,0, -2, 1,-2))(mapper)
res17: Seq[Int] = List(2, -2, -2)