Reading the Spark method sortByKey:
sortByKey([ascending], [numTasks]) When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument.
Is it possible to return just N results? So instead of returning all results, just return the top 10. I could convert the sorted collection to an Array and use the take method, but since this is an O(N) operation, is there a more efficient method?

If you only need the top 10, use rdd.top(10). It avoids sorting, so it is faster. rdd.top makes one parallel pass through the data, collecting the top N in each partition in a heap, then merges the heaps. It is an O(rdd.count) operation. Sorting would be O(rdd.count log rdd.count) and would incur a lot of data transfer: it does a shuffle, so all of the data would be transmitted over the network.
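To illustrate the idea of a per-partition heap followed by a merge, here is a minimal sketch on plain Scala collections. It is not Spark's implementation: TopNSketch, topOfPartition and top are hypothetical names, each inner Seq stands in for one partition, and the merge step simply sorts the few surviving candidates.

```scala
import scala.collection.mutable

object TopNSketch {
  // Keep the n largest elements of one "partition" using a bounded min-heap.
  def topOfPartition(part: Iterator[Int], n: Int): Seq[Int] = {
    require(n > 0, "n must be positive")
    // Reversing the ordering turns Scala's max-heap PriorityQueue into a min-heap,
    // so `head` is the smallest element kept so far.
    val heap = mutable.PriorityQueue.empty[Int](Ordering[Int].reverse)
    part.foreach { x =>
      if (heap.size < n) heap.enqueue(x)
      else if (x > heap.head) { heap.dequeue(); heap.enqueue(x) }
    }
    heap.toSeq
  }

  // Merge the per-partition candidates (at most n per partition) and keep the
  // overall top n, largest first.
  def top(partitions: Seq[Seq[Int]], n: Int): Seq[Int] =
    partitions
      .flatMap(p => topOfPartition(p.iterator, n))
      .sorted(Ordering[Int].reverse)
      .take(n)
}
```

Each element is looked at once and only a handful of candidates per partition survive to the merge, which is why no full sort or shuffle of the data is needed.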

Most likely you have already perused the source code:
class OrderedRDDFunctions {
  // <snip>
  def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
    val part = new RangePartitioner(numPartitions, self, ascending)
    val shuffled = new ShuffledRDD[K, V, P](self, part)
    shuffled.mapPartitions(iter => {
      val buf = iter.toArray
      if (ascending) {
        buf.sortWith((x, y) => x._1 < y._1).iterator
      } else {
        buf.sortWith((x, y) => x._1 > y._1).iterator
      }
    }, preservesPartitioning = true)
  }
}
And, as you say, the entire data must go through the shuffle stage - as seen in the snippet.
However, your concern about subsequently invoking take(K) may not be so accurate. This operation does NOT cycle through all N items:
/**
 * Take the first num elements of the RDD. It works by first scanning one partition, and use the
 * results from that partition to estimate the number of additional partitions needed to satisfy
 * the limit.
 */
def take(num: Int): Array[T] = {
  // <snip>
}
So then, it would seem:
O(myRdd.take(K)) << O(myRdd.sortByKey()) ~= O(myRdd.sortByKey().take(K)) (at least for small K) << O(myRdd.sortByKey().collect())

Another option, at least from PySpark 1.2.0, is the use of takeOrdered.
In ascending order:
rdd.takeOrdered(10)
In descending order:
rdd.takeOrdered(10, lambda x: -x)
Top k values for k,v pairs:
rdd.takeOrdered(10, lambda kv: -kv[1])


Optimize Sorting Iterable Values after grouping in Spark

I have an RDD[(String, (Int, Int))] and need to get the top 10 values (tuples) for each key after sorting. I tried:
val sortedRDD = rdd.groupByKey.mapValues( x => x.toList.sortWith((x,y) => <<sorting logic>>).take(10))
This throws an OutOfMemoryException because the Iterable[(Int, Int)] is large for some keys. How should I handle this? Is there a way to do this without using .groupByKey()?
You should use aggregateByKey instead of groupByKey to perform the sorting and "trimming" (that keeps only top 10) while grouping instead of grouping into potentially-huge groups and only then mapping the result.
Here's how this could look:
// your sorting logic:
val sortingFunction: ((Int, Int), (Int, Int)) => Boolean = ???
val N = 10
val sortedRDD = rdd.aggregateByKey(List[(Int, Int)]())(
  // first function: seqOp, how to add another item of the group to the result
  {
    case (topSoFar, candidate) if topSoFar.size < N => candidate :: topSoFar
    case (topTen, candidate) => (candidate :: topTen).sortWith(sortingFunction).take(N)
  },
  // second function: combOp, how to combine two partial results created by seqOp
  { case (list1, list2) => (list1 ++ list2).sortWith(sortingFunction).take(N) }
)
Notice that per group, we always create values that are 10 items or less.
NOTE: performance can possibly be improved by performing fewer "sort" operations (we sort the same list again and again whenever we add another item or list). To avoid that, you can consider using a "sorted set" with a limited capacity (see Limited SortedSet) as the value, so that each addition efficiently adds or discards the new value without re-sorting.
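One way the capacity-limited sorted set could look is sketched below using the standard immutable TreeSet. BoundedTopN, best, add and merge are hypothetical names, and the ranking (larger second component first, ties broken by the first component) is only an assumed stand-in for your own sorting logic. Because this is a set, exact duplicate pairs are kept only once.

```scala
import scala.collection.immutable.TreeSet

object BoundedTopN {
  type Pair = (Int, Int)

  // Assumed ranking for illustration: larger second component first,
  // ties broken by the first component so that the order is total.
  val best: Ordering[Pair] = Ordering.by((p: Pair) => (-p._2, p._1))

  val empty: TreeSet[Pair] = TreeSet.empty[Pair](best)

  // Insert one candidate and evict the worst element if the capacity n is exceeded.
  def add(top: TreeSet[Pair], x: Pair, n: Int): TreeSet[Pair] = {
    val grown = top + x
    if (grown.size > n) grown - grown.last else grown
  }

  // Combine two partial results by feeding one set into the other.
  def merge(a: TreeSet[Pair], b: TreeSet[Pair], n: Int): TreeSet[Pair] =
    b.foldLeft(a)((acc, x) => add(acc, x, n))
}
```

With aggregateByKey, BoundedTopN.empty would play the role of the zero value, add the role of seqOp and merge the role of combOp; each insertion then costs O(log N) instead of a full sort.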

Spark: flatMap/reduceByKey seems to be quite slow with Long keys on some distributions

I'm using Spark to process some corpora and I need to count the occurrence of each 2-gram. I started with counting tuples (wordID1, wordID2) and it worked fine except for the large memory usage and gc overhead due to the substantial number of small tuple objects. Then I tried to pack a pair of Ints into a Long, and the gc overhead did reduce greatly, but the run time also increased several times.
I ran some small experiments with random data on different distributions. It seems that the performance issue only occurs on exponential distributed data.
import scala.util.Random

// lines of word IDs
val data = (1 to 5000) map { _ =>
  (1 to 1000) map { _ => (-1000 * Math.log(Random.nextDouble)).toInt }
}
// count Tuples, fast
sc parallelize(data) flatMap { line =>
val first = line.iterator
val second = line.iterator.drop(1)
for (pair <- first zip(second))
yield (pair, 1L)
} reduceByKey { _ + _ } count()
// count Long, slow
sc parallelize(data) flatMap { line =>
val first = line.iterator
val second = line.iterator.drop(1)
for ((a, b) <- first zip(second))
yield ((a.toLong << 32) | b, 1L)
} reduceByKey { _ + _ } count()
The job is split into two stages, flatMap() and count(). When counting Tuple2s, flatMap() takes about 6s and count() takes about 2s, while when counting Longs, flatMap() takes 18s and count() takes 10s.
It doesn't make sense to me, as Longs should impose less overhead than Tuple2. Does Spark have some specializations for Long keys which happen to perform even worse for some specific distributions?
Thanks to @SarveshKumarSingh's hint, I finally solved the problem. It is not Spark's specialization for Long that triggers the issue but Java's, and Spark doesn't address it properly.
Java's hashCode() for Long is quite simple (it XORs the upper and lower 32-bit halves of the value), and Spark's default HashPartitioner simply partitions keys according to their hashCode() values modulo the number of partitions. This makes Spark's default partitioning quite sensitive to the distribution of Long keys, especially when the number of partitions is relatively small. In my case the situation deteriorates because the Long keys are constructed by concatenating pairs of Ints.
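To make the hashCode point concrete, here is a small sketch in plain Scala, without Spark. LongKeySkew, pack, javaHash and partition are hypothetical helper names; partition only mimics the "hashCode modulo number of partitions" idea described above and assumes non-negative word IDs, as in the experiment.

```scala
object LongKeySkew {
  // Pack two non-negative Ints into one Long, as in the question.
  def pack(a: Int, b: Int): Long = (a.toLong << 32) | b

  // java.lang.Long.hashCode is (int)(value ^ (value >>> 32)),
  // so for a packed pair of non-negative Ints it collapses to a ^ b.
  def javaHash(key: Long): Int = java.lang.Long.valueOf(key).hashCode

  // Non-negative modulo of the hash code, mimicking a hash partitioner.
  def partition(key: Long, numPartitions: Int): Int = {
    val m = javaHash(key) % numPartitions
    if (m < 0) m + numPartitions else m
  }
}
```

For example, the distinct keys pack(1, 2), pack(2, 1) and pack(3, 0) all hash to 3, and every key of the form pack(a, a) hashes to 0, so frequent 2-grams built from small word IDs can pile up in a few partitions.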
The solution is quite straightforward, as we just need to somehow "shuffle" the keys so that keys with similar frequencies are distributed evenly across partitions.
The simplest way is to map each key to another unique value using some perfect hash function, and convert it back when the original key is required. This approach involves only small code changes, but might not perform very well. I achieved performance similar to the count-by-tuple approach using the following mappings.
val newKey = oldKey * 6364136223846793005L + 1442695040888963407L
val oldKey = (newKey - 1442695040888963407L) * -4568919932995229531L
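A quick way to convince yourself that the two mappings above undo each other is sketched below. KeyScrambler, scramble and unscramble are hypothetical names; the constants are the ones from the snippet above. Long arithmetic wraps around modulo 2^64, and the second multiplier is the modular inverse of the first, so the mapping is a bijection on Longs.

```scala
object KeyScrambler {
  private val Mul = 6364136223846793005L
  private val Inc = 1442695040888963407L
  // Modular inverse of Mul modulo 2^64, written as a signed Long.
  private val MulInverse = -4568919932995229531L

  // Forward mapping: spreads nearby keys far apart.
  def scramble(oldKey: Long): Long = oldKey * Mul + Inc

  // Inverse mapping: recovers the original key exactly.
  def unscramble(newKey: Long): Long = (newKey - Inc) * MulInverse
}
```

Keys would be scrambled before reduceByKey and unscrambled only when the original pair of word IDs is needed again.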
A more effective way is to replace the default HashPartitioner. I used the following partitioner between flatMap and reduceByKey and achieved a twofold performance boost on real-world data.
val prevRDD = // ... flatMap ...
val nParts = prevRDD.partitioner match {
  case Some(p) => p.numPartitions
  case None => prevRDD.partitions.size
}
prevRDD partitionBy (new Partitioner {
  override def getPartition(key: Any): Int = {
    val rawMod = LongHash(key.asInstanceOf[Long]) % numPartitions
    rawMod + (if (rawMod < 0) numPartitions else 0)
  }
  override def numPartitions: Int = nParts
}) reduceByKey { _ + _ }
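def LongHash(v: Long) = { // the 64bit mix function from Murmurhash3
  // note: the reference MurmurHash3 finalizer uses unsigned shifts (>>> in Scala)
  var k = v
  k ^= k >> 33
  k *= 0xff51afd7ed558ccdL
  k ^= k >> 33
  k *= 0xc4ceb9fe1a85ec53L
  k ^= k >> 33
  k.toInt // assumed ending: the snippet is cut off here, and getPartition needs an Int
}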

Selecting every 3rd element from a huge RDD [duplicate]

I'm looking for a way to split an RDD into two or more RDDs. The closest I've seen is Scala Spark: Split collection into several RDD? which is still a single RDD.
If you're familiar with SAS, something like this:
data work.split1, work.split2;
set work.preSplit;
if (condition1)
output work.split1
else if (condition2)
output work.split2
which resulted in two distinct data sets. It would have to be immediately persisted to get the results I intend...
It is not possible to yield multiple RDDs from a single transformation*. If you want to split an RDD you have to apply a filter for each split condition. For example:
def even(x): return x % 2 == 0
def odd(x): return not even(x)
rdd = sc.parallelize(range(20))
rdd_odd, rdd_even = (rdd.filter(f) for f in (odd, even))
If you have only a binary condition and computation is expensive you may prefer something like this:
kv_rdd = rdd.map(lambda x: (x, odd(x)))
rdd_odd = kv_rdd.filter(lambda kv: kv[1]).keys()
rdd_even = kv_rdd.filter(lambda kv: not kv[1]).keys()
This means only a single predicate computation per element, but it requires an additional pass over all the data.
It is important to note that as long as the input RDD is properly cached and there are no additional assumptions regarding data distribution, there is no significant difference in time complexity between a repeated filter and a for-loop with nested if-else.
With N elements and M conditions, the number of operations you have to perform is clearly proportional to N times M. In the case of the for-loop it should be closer to (N + MN) / 2, and the repeated filter is exactly NM, but at the end of the day it is nothing else than O(NM). You can see my discussion** with Jason Lenderman to read about some pros and cons.
At a very high level you should consider two things.
First, Spark transformations are lazy: until you execute an action, your RDD is not materialized.
Why does it matter? Going back to my example:
rdd_odd, rdd_even = (rdd.filter(f) for f in (odd, even))
If later I decide that I need only rdd_odd then there is no reason to materialize rdd_even.
If you take a look at your SAS example, to compute work.split2 you need to materialize both the input data and work.split1.
Second, RDDs provide a declarative API. When you use filter or map, it is completely up to the Spark engine how this operation is performed. As long as the functions passed to transformations are free of side effects, this creates multiple possibilities to optimize the whole pipeline.
At the end of the day this case is not special enough to justify its own transformation.
This map-with-filter pattern is actually used in core Spark. See my answer to How does Sparks RDD.randomSplit actually split the RDD and the relevant part of the randomSplit method.
If the only goal is to achieve a split on input, it is possible to use the partitionBy clause of DataFrameWriter with the text output format:
def makePairs(row: T): (String, String) = ???
.map(makePairs).toDF("key", "value")
* There are only 3 basic types of transformations in Spark:
RDD[T] => RDD[T]
RDD[T] => RDD[U]
(RDD[T], RDD[U]) => RDD[W]
where T, U, W can be either atomic types or products / tuples (K, V). Any other operation has to be expressed using some combination of the above. You can check the original RDD paper for more details.
*** See also Scala Spark: Split collection into several RDD?
As other posters mentioned above, there is no single, native RDD transform that splits RDDs, but here are some "multiplex" operations that can efficiently emulate a wide variety of "splitting" on RDDs, without reading multiple times:
Some methods specific to random splitting:
Methods are available from open source silex project:
A blog post explaining how they work:
def muxPartitions[U :ClassTag](n: Int, f: (Int, Iterator[T]) => Seq[U],
  persist: StorageLevel): Seq[RDD[U]] = {
  val mux = self.mapPartitionsWithIndex { case (id, itr) =>
    Iterator.single(f(id, itr))
  }.persist(persist)
  // each partition of mux holds one Seq; output j takes element j of that Seq
  Vector.tabulate(n) { j => mux.mapPartitions { itr => Iterator.single(itr.next()(j)) } }
}
def flatMuxPartitions[U :ClassTag](n: Int, f: (Int, Iterator[T]) => Seq[TraversableOnce[U]],
  persist: StorageLevel): Seq[RDD[U]] = {
  val mux = self.mapPartitionsWithIndex { case (id, itr) =>
    Iterator.single(f(id, itr))
  }.persist(persist)
  Vector.tabulate(n) { j => mux.mapPartitions { itr => itr.next()(j).toIterator } }
}
As mentioned elsewhere, these methods do involve a trade-off of memory for speed, because they operate by computing entire partition results "eagerly" instead of "lazily." Therefore, it is possible for these methods to run into memory problems on large partitions, where more traditional lazy transforms will not.
One way is to use a custom partitioner to partition the data depending upon your filter condition. This can be achieved by extending Partitioner and implementing something similar to the RangePartitioner.
A map partitions can then be used to construct multiple RDDs from the partitioned RDD without reading all the data.
val filtered = partitioned.mapPartitions { iter => {
  new Iterator[Int](){
    override def hasNext: Boolean = {
      if(rangeOfPartitionsToKeep.contains(TaskContext.get().partitionId)) {
        iter.hasNext
      } else {
        false
      }
    }
    override def next():Int = iter.next()
  }
}}
Just be aware that the number of partitions in the filtered RDDs will be the same as the number in the partitioned RDD so a coalesce should be used to reduce this down and remove the empty partitions.
If you split an RDD using the randomSplit API call, you get back an array of RDDs.
If you want 5 RDDs returned, pass in 5 weight values.
val sourceRDD = sc.parallelize(1 to 100, 4)
val seedValue = 5
val splitRDD = sourceRDD.randomSplit(Array(1.0,1.0,1.0,1.0,1.0), seedValue)
res7: Array[Int] = Array(1, 6, 11, 12, 20, 29, 40, 62, 64, 75, 77, 83, 94, 96, 100)

How to write an efficient groupBy-size filter in Scala, can be approximate

Given a List[Int] in Scala, I wish to get the Set[Int] of all Ints which appear at least thresh times. I can do this using groupBy or foldLeft, then filter. For example:
val thresh = 3
val myList = List(1,2,3,2,1,4,3,2,1)
myList.foldLeft(Map[Int,Int]()){case(m, i) => m + (i -> (m.getOrElse(i, 0) + 1))}.filter(_._2 >= thresh).keys
will give Set(1,2).
Now suppose the List[Int] is very large. How large it's hard to say but in any case this seems wasteful as I don't care about each of the Ints frequencies, and I only care if they're at least thresh. Once it passed thresh there's no need to check anymore, just add the Int to the Set[Int].
The question is: can I do this more efficiently for a very large List[Int],
a) if I need a true, accurate result (no room for mistakes)
b) if the result can be approximate, e.g. by using some Hashing trick or Bloom Filters, where Set[Int] might include some false-positives, or whether {the frequency of an Int > thresh} isn't really a Boolean but a Double in [0-1].
First of all, you can't do better than O(N), as you need to check each element of your initial array at least once. Your current approach is O(N), presuming that operations on IntMap are effectively constant-time.
Now what you can try in order to increase efficiency:
Update the map only when the current counter value is less than or equal to the threshold. This will eliminate a huge number of the most expensive operations: map updates.
Try a faster map instead of IntMap. If you know that the values of the initial List are in a fixed range, you can use an Array instead of IntMap (with the index as the key). Another possible option is a mutable HashMap with sufficient initial capacity. As my benchmark shows, it actually makes a significant difference.
As @ixx proposed, after incrementing the value in the map, check whether it's equal to thresh and in that case add it immediately to the result list. This will save you one linear traversal (which appears to be not that significant for large input).
I don't see how any approximate solution can be faster (only if you ignore some elements at random). Otherwise it will still be O(N).
I created a microbenchmark to measure the actual performance of different implementations. For sufficiently large input and output, Ixx's suggestion of immediately adding elements to the result list doesn't produce a significant improvement. However, a similar approach can be used to eliminate unnecessary Map updates (which appear to be the most expensive operation).
Results of benchmarks (avg run times on 1000000 elems with pre-warming):
Authors solution:
447 ms
Ixx solution:
412 ms
Ixx solution2 (eliminated excessive map writes):
150 ms
My solution:
57 ms
My solution involves using mutable HashMap instead of immutable IntMap and includes all other possible optimizations.
Ixx's updated solution:
val tuple = (Map[Int, Int](), List[Int]())
val res = myList.foldLeft(tuple) {
  case ((m, s), i) =>
    val count = m.getOrElse(i, 0) + 1
    (if (count <= thresh) m + (i -> count) else m, if (count == thresh) i :: s else s)
}
My solution:
val map = new mutable.HashMap[Int, Int]()
val res = new ListBuffer[Int]
myList.foreach {
  i =>
    val c = map.getOrElse(i, 0) + 1
    if (c == thresh) {
      res += i
    }
    if (c <= thresh) {
      map(i) = c
    }
}
The full microbenchmark source is available here.
You could use foldLeft to collect the matching items, like this:
val tuple = (Map[Int,Int](), List[Int]())
myList.foldLeft(tuple) {
  case((m, s), i) => {
    val count = (m.getOrElse(i, 0) + 1)
    (m + (i -> count), if (count == thresh) i :: s else s)
  }
}
I could measure a performance improvement of about 40% with a small list, so it's definitely an improvement...
Edited to use List and prepend, which takes constant time (see comments).
If by "more efficiently" you mean the space efficiency (in extreme case when the list is infinite), there's a probabilistic data structure called Count Min Sketch to estimate the frequency of items inside it. Then you can discard those with frequency below your threshold.
There's a Scala implementation from Algebird library.
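For intuition, a hand-rolled Count-Min Sketch could look roughly like the following. This is a minimal sketch, not the Algebird API: CountMinSketch, ApproxThreshold and atLeast are hypothetical names, the depth and width defaults are arbitrary, and the row hashes are derived from the standard library's MurmurHash3 helpers.

```scala
import scala.util.hashing.MurmurHash3

// depth rows of width counters; each row uses a differently seeded hash.
final class CountMinSketch(depth: Int, width: Int) {
  private val table = Array.ofDim[Long](depth, width)

  // Column of `item` in row `row`, using the row index as the hash seed.
  private def column(row: Int, item: Int): Int =
    Math.floorMod(MurmurHash3.finalizeHash(MurmurHash3.mix(row + 1, item), 1), width)

  def add(item: Int): Unit =
    for (r <- 0 until depth) table(r)(column(r, item)) += 1

  // Minimum over the rows: collisions can only inflate a counter,
  // so the estimate is never below the true count.
  def estimate(item: Int): Long =
    (0 until depth).map(r => table(r)(column(r, item))).min
}

object ApproxThreshold {
  // Returns every item whose estimated count reaches thresh (may include false positives).
  def atLeast(xs: Seq[Int], thresh: Int, depth: Int = 4, width: Int = 1024): Set[Int] = {
    val cms = new CountMinSketch(depth, width)
    xs.foreach(cms.add)
    xs.iterator.filter(x => cms.estimate(x) >= thresh).toSet
  }
}
```

The memory used by the counters is fixed at depth times width regardless of how many distinct Ints appear. The result always contains every item that truly reaches the threshold, plus possibly a few that do not.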
You can change your foldLeft example a bit by using a mutable.Set that is built incrementally and at the same time used as a filter for iterating over your Seq via withFilter. However, because I'm using withFilter I cannot use foldLeft and have to make do with foreach and a mutable map:
import scala.collection.mutable
def getItems[A](in: Seq[A], threshold: Int): Set[A] = {
  val counts: mutable.Map[A, Int] = mutable.Map.empty
  val result: mutable.Set[A] = mutable.Set.empty
  in.withFilter(!result(_)).foreach { x =>
    counts.update(x, counts.getOrElse(x, 0) + 1)
    if (counts(x) >= threshold) {
      result += x
    }
  }
  result.toSet
}
So, this would discard items that have already been added to the result set while running through the Seq the first time, because withFilter filters the Seq in the appended function (map, flatMap, foreach) rather than returning a filtered Seq.
I changed my solution to not use Seq.count, which was stupid, as Aivean correctly pointed out.
Using Aivean's microbenchmark I can see that it is still slightly slower than his approach, but still better than the author's first approach.
Authors solution
Ixx solution:
Ixx solution2 (eliminated excessive map writes):
Sascha Kolbergs solution:
Aivean solution:

Sort a list by an ordered index

Let us assume that I have the following two sequences:
val index = Seq(2,5,1,4,7,6,3)
val unsorted = Seq(7,6,5,4,3,2,1)
The first is the index by which the second should be sorted. My current solution is to traverse over the index and construct a new sequence with the found elements from the unsorted sequence.
val sorted = index.foldLeft(Seq[Int]()) { (s, num) =>
  s ++ Seq(unsorted.find(_ == num).get)
}
But this solution seems very inefficient and error-prone to me. On every iteration it searches the complete unsorted sequence. And if the index and the unsorted list aren't in sync, then either an error will be thrown or an element will be omitted. In both cases, the not in sync elements should be appended to the ordered sequence.
Is there a more efficient and solid solution for this problem? Or is there a sort algorithm which fits into this paradigm?
Note: This is a constructed example. In reality I would like to sort a list of mongodb documents by an ordered list of document Id's.
Update 1
I've selected the answer from Marius Danila because it seems the fastest and most Scala-ish solution for my problem. It doesn't come with a solution for not-in-sync items, but this could be easily implemented.
So here is the updated solution:
import scala.collection.immutable.HashMap
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

def sort[T: ClassTag, Key](index: Seq[Key], unsorted: Seq[T], key: T => Key): Seq[T] = {
  val positionMapping = HashMap(index.zipWithIndex: _*)
  val inSync = new Array[T](unsorted.size)
  val notInSync = new ArrayBuffer[T]()
  for (item <- unsorted) {
    if (positionMapping.contains(key(item))) {
      inSync(positionMapping(key(item))) = item
    } else {
      notInSync += item
    }
  }
  inSync.filterNot(_ == null) ++ notInSync
}
Update 2
The approach suggested by seems the correct answer. It also doesn't consider the not in sync issue, but this can also be easily implemented.
val index: Seq[String]
val entities: Seq[Foo]
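val idToEntityMap = entities.map(e => e.id -> e).toMap // assumes Foo exposes its identifier as `id`
val sorted = index.flatMap(idToEntityMap.get) // ids without a matching entity are skipped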
val result = sorted ++ entities.filterNot(sorted.toSet)
Why do you want to sort the collection when you already have a sorted index collection? You can just use map.
Concerning "In reality I would like to sort a list of mongodb documents by an ordered list of document Id's":
val ids: Seq[String]
val entities: Seq[Foo]
val idToEntityMap = entities.map(e => e.id -> e).toMap // assumes Foo exposes its identifier as `id`
ids.map(idToEntityMap)
This may not exactly map to your use case, but Googlers may find this useful:
scala> val ids = List(3, 1, 0, 2)
ids: List[Int] = List(3, 1, 0, 2)
scala> val unsorted = List("third", "second", "fourth", "first")
unsorted: List[String] = List(third, second, fourth, first)
scala> val sorted = ids map unsorted
sorted: List[String] = List(first, second, third, fourth)
I do not know the language that you are using, but irrespective of the language this is how I would solve the problem.
From the first list (here 'index'), create a hash table taking the document id as the key and the position of the document in the sorted order as the value.
Then, when traversing the list of documents, I would look up each document id in the hash table to get the position it should have in the sorted order, and use that position to place the document into pre-allocated memory.
Note: if the number of documents is small, then instead of a hash table you could use a pre-allocated table and index it directly by the document id.
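Since the question is about Scala, the steps above could be sketched roughly as follows. SortByIndex, sortByIndex and idOf are hypothetical names; as an added assumption, documents whose id does not appear in the index (or whose slot is already taken) are appended at the end in their original order.

```scala
import scala.collection.mutable.ArrayBuffer

object SortByIndex {
  def sortByIndex[Id, Doc](index: Seq[Id], docs: Seq[Doc])(idOf: Doc => Id): Seq[Doc] = {
    // Hash table: document id -> position in the sorted order.
    val position: Map[Id, Int] = index.zipWithIndex.toMap
    // Pre-allocated slots, one per index entry; None means "nothing placed yet".
    val slots = Array.fill[Option[Doc]](index.size)(None)
    val leftovers = ArrayBuffer.empty[Doc]
    docs.foreach { d =>
      position.get(idOf(d)) match {
        case Some(p) if slots(p).isEmpty => slots(p) = Some(d)
        case _ => leftovers += d // id not in the index, or slot already filled
      }
    }
    // Drop empty slots (index entries without a document) and append the leftovers.
    slots.toSeq.flatten ++ leftovers
  }
}
```

For the sequences in the question, SortByIndex.sortByIndex(index, unsorted)(identity) would return the elements in index order, using one pass over each sequence.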
Flat Mapping the index over the unsorted list seems to be a safer version (if the index isn't found it's just dropped since find returns a None):
index.flatMap(i => unsorted.find(_ == i))
It still has to traverse the unsorted list every time (worst case this is O(n^2)). With your example I'm not sure that there's a more efficient solution.
In this case you can use zip-sort-unzip:
(unsorted zip index).sortWith(_._2 < _._2).unzip._1
Btw, if you can, better solution would be to sort list on db side using $orderBy.
Let's start from the beginning.
Besides the fact that you're rescanning the unsorted list each time, the Seq object will create a List collection by default. So in the foldLeft you're appending an element at the end of the list each time, which makes the whole fold an O(N^2) operation.
An improvement would be
val sorted_rev = index.foldLeft(Seq[Int]()) { (s, num) =>
  unsorted.find(_ == num).get +: s
}
val sorted = sorted_rev.reverse
But that is still an O(N^2) algorithm. We can do better.
The following sort function should work:
def sort[T: ClassTag, Key](index: Seq[Key], unsorted: Seq[T], key: T => Key): Seq[T] = {
  val positionMapping = HashMap(index.zipWithIndex: _*) //1
  val arr = new Array[T](unsorted.size) //2
  for (item <- unsorted) { //3
    val position = positionMapping(key(item))
    arr(position) = item
  }
  arr //6
}
The function sorts a list of items unsorted by a sequence of indexes index where the key function will be used to extract the id from the objects you're trying to sort.
Line 1 creates a reverse index - mapping each object id to its final position.
Line 2 allocates the array which will hold the sorted sequence. We're using an array since we need constant-time random-position set performance.
The loop that starts at line 3 will traverse the sequence of unsorted items and place each item in its intended position using the positionMapping reverse index.
Line 6 will return the array converted implicitly to a Seq using the WrappedArray wrapper.
Since our reverse-index is an immutable HashMap, lookup should take constant-time for regular cases. Building the actual reverse-index takes O(N_Index) time where N_Index is the size of the index sequence. Traversing the unsorted sequence takes O(N_Unsorted) time where N_Unsorted is the size of the unsorted sequence.
So the complexity is O(max(N_Index, N_Unsorted)), which I guess is the best you can do in the circumstances.
For your particular example, you would call the function like so:
val sorted = sort(index, unsorted, identity[Int])
For the real case, it would probably be like this:
val sorted = sort(idList, unsorted, obj =>
The best I can do is to create a Map from the unsorted data, and use map lookups (basically the hashtable suggested by a previous poster). The code looks like:
val unsortedAsMap = unsorted.map(x => x -> x).toMap
index.map(unsortedAsMap)
Or, if there's a possibility of hash misses:
val unsortedAsMap = unsorted.map(x => x -> x).toMap
index.flatMap(unsortedAsMap.get)
It's O(n) in time*, but you're swapping time for space, as it uses O(n) space.
For a slightly more sophisticated version, that handles missing values, try:
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
val unsortedAsMap = new java.util.LinkedHashMap[Int, Int]
for (i <- unsorted) unsortedAsMap.put(i, i)
val newBuffer = ListBuffer.empty[Int]
for (i <- index) {
  // containsKey avoids comparing an unboxed Int against null
  if (unsortedAsMap.containsKey(i)) {
    unsortedAsMap.remove(i)
    newBuffer += i
  }
  // Not sure what to do for "else"
}
// whatever is left was not in the index; append it in insertion order
for ((k, v) <- unsortedAsMap.asScala) newBuffer += v
If it's a MongoDB database in the first place, you might be better retrieving documents directly from the database by index, so something like:
*technically it's O(n log n), as Scala's standard immutable map is O(log n), but you could always use a mutable map, which is O(1)