Scala - splitting dataframe based on number of rows - scala

I have a spark dataframe that has approximately a million records. I'm trying to split this dataframe into multiple small dataframes where each of these dataframes has a maximum rowCount of 20,000 (Each of these dataframes should have a row count of 20,000 except the last dataframe which may or may not have 20,000). Can you help me with this? Thank you.

OK, maybe not the most efficient way, but here it is. You can create a new column that numbers every row (in case you don't have a unique id column). Here we basically iterate over the whole dataframe, select batches of 20k rows, and add them to a list of DataFrames.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{monotonically_increasing_id, row_number}
import spark.implicits._

var index = 0L
val subsetSize = 20000
var listOfDF: List[DataFrame] = List()
// monotonically_increasing_id() is unique but not consecutive, so derive a
// consecutive 0-based row number from it. A window without partitionBy moves
// all rows to a single partition. Skip the withColumn step if you already
// have a consecutive id per row.
val df = spark.table("your_table")
  .withColumn("rowNum", row_number().over(Window.orderBy(monotonically_increasing_id())) - 1)
val totalRows = df.count()

def returnSubDF(fromIndex: Long, toIndex: Long): DataFrame =
  df.filter($"rowNum" >= fromIndex && $"rowNum" < toIndex)

while (index < totalRows) {
  listOfDF = listOfDF :+ returnSubDF(index, index + subsetSize)
  index += subsetSize
}
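The batch boundaries used by a loop like the one above can be checked without Spark. The following is a minimal sketch in plain Scala, assuming consecutive 0-based row numbers; the object BatchRanges and the function batchRanges are hypothetical names for illustration and not part of any Spark API.

```scala
object BatchRanges {
  // Half-open [from, to) row-number ranges that cover `totalRows` rows in
  // batches of at most `batchSize` rows. Hypothetical helper, plain Scala.
  def batchRanges(totalRows: Long, batchSize: Long): Seq[(Long, Long)] = {
    require(batchSize > 0, "batchSize must be positive")
    (0L until totalRows by batchSize).map { from =>
      (from, math.min(from + batchSize, totalRows))
    }
  }

  def main(args: Array[String]): Unit = {
    // 45,000 rows in batches of 20,000: two full batches and one partial batch
    batchRanges(45000L, 20000L).foreach(println)
  }
}
```

Each pair could then be used as the fromIndex and toIndex of one filter, so only the last batch can be smaller than the batch size.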


Spark: Performant way to find top n values

I have a large dataset and I would like to find rows with n highest values.
id, count
id1, 10
id2, 15
id3, 5
The only method I can think of is using row_number without partition like
val window = Window.orderBy(desc("count"))
df.withColumn("row_number", row_number over window).filter(col("row_number") <= n)
but this is in no way performant when the data contains millions or billions of rows because it pushes the data into one partition and I get OOM.
Has anyone managed to come up with a performant solution?
I see two methods to improve your algorithm performance. First is to use sort and limit to retrieve the top n rows. The second is to develop your custom Aggregator.
Sort and Limit method
You sort your dataframe and then you take the first n rows:
val n: Int = ???
import org.apache.spark.sql.functions.desc
df.orderBy(desc("count")).limit(n)
Spark optimizes this sequence of transformations by first sorting each partition, taking the first n rows of each partition, gathering them on a final partition, and then sorting again and taking the final first n rows. You can check this by calling explain() on the transformations. You get the following execution plan:
== Physical Plan ==
TakeOrderedAndProject(limit=3, orderBy=[count#8 DESC NULLS LAST], output=[id#7,count#8])
+- LocalTableScan [id#7, count#8]
You can also confirm it by looking at how the TakeOrderedAndProject step is executed in limit.scala in Spark's source code (case class TakeOrderedAndProjectExec, method doExecute).
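The idea behind this plan (top n per partition first, then top n of the combined candidates) can be sketched in plain Scala. This is only an illustration under the assumption that each inner Seq stands in for the data of one partition; TopNSketch and topN are hypothetical names, and this is not how Spark implements the step internally.

```scala
object TopNSketch {
  // Top n of each partition, then top n of the combined candidates.
  // `partitions` stands in for the data held by each Spark partition.
  def topN(partitions: Seq[Seq[Int]], n: Int): Seq[Int] = {
    val descending = Ordering[Int].reverse
    // At most n candidates survive per partition
    val candidates = partitions.flatMap(_.sorted(descending).take(n))
    // The final sort only sees p * n values instead of the whole dataset
    candidates.sorted(descending).take(n)
  }

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(10, 3, 7), Seq(15, 1), Seq(5, 12, 9))
    println(topN(partitions, 3))
  }
}
```

Because every partition forwards at most n values, the final step never has to hold the whole dataset, which is why this pattern avoids the single-partition window.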
Custom Aggregator method
For custom aggregator, you create an Aggregator that will populate and update an ordered array of top n rows.
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder
import scala.collection.mutable.ArrayBuffer
case class Record(id: String, count: Int)
case class TopRecords(limit: Int) extends Aggregator[Record, ArrayBuffer[Record], Seq[Record]] {
  def zero: ArrayBuffer[Record] = ArrayBuffer.empty[Record]

  // Insert the record at its sorted position (descending count), keeping at most `limit` records
  def reduce(topRecords: ArrayBuffer[Record], currentRecord: Record): ArrayBuffer[Record] = {
    val insertIndex = topRecords.lastIndexWhere(p => p.count > currentRecord.count)
    if (topRecords.length < limit) {
      topRecords.insert(insertIndex + 1, currentRecord)
    } else if (insertIndex < limit - 1) {
      topRecords.insert(insertIndex + 1, currentRecord)
      topRecords.remove(topRecords.length - 1)
    }
    topRecords
  }

  // Merge two buffers that are already sorted by descending count
  def merge(topRecords1: ArrayBuffer[Record], topRecords2: ArrayBuffer[Record]): ArrayBuffer[Record] = {
    val merged = ArrayBuffer.empty[Record]
    while (merged.length < limit && (topRecords1.nonEmpty || topRecords2.nonEmpty)) {
      if (topRecords1.isEmpty) {
        merged.append(topRecords2.remove(0))
      } else if (topRecords2.isEmpty) {
        merged.append(topRecords1.remove(0))
      } else if (topRecords2.head.count < topRecords1.head.count) {
        merged.append(topRecords1.remove(0))
      } else {
        merged.append(topRecords2.remove(0))
      }
    }
    merged
  }

  def finish(reduction: ArrayBuffer[Record]): Seq[Record] = reduction
  def bufferEncoder: Encoder[ArrayBuffer[Record]] = ExpressionEncoder[ArrayBuffer[Record]]
  def outputEncoder: Encoder[Seq[Record]] = ExpressionEncoder[Seq[Record]]
}
And then you apply this aggregator on your dataframe, and flatten the aggregation result:
val n: Int = ???
import sparkSession.implicits._
df.as[Record].select(TopRecords(n).toColumn).flatMap(record => record)
Method comparison
To compare those two methods, let's say we want to take top n rows of a dataframe that is distributed on p partitions, each partition having around k records. So dataframe has size p·k. Which gives the following complexity (subject to errors):
Compared quantities: total number of operations, memory consumption (on executor), memory consumption (on final executor)
Current code:
Sort and Limit: O(p·k·log(k) + p·n·log(p·n))
Custom Aggregator: O(k) + O(n)
So regarding the number of operations, Custom Aggregator is the most performant. However, this method is by far the most complex and implies lots of serialization/deserialization, so it may be less performant than Sort and Limit in certain cases.
You have two methods to efficiently take top n rows, Sort and Limit and Custom Aggregator. To select which one to use, you should benchmark those two methods with your real dataframe. If after benchmarking Sort and Limit is a bit slower than Custom aggregator, I would select Sort and Limit as its code is a lot easier to maintain.
Convert to rdd
mapPartitions sorting the data, take N
Convert to df
Then sort and rank and take top N. Unlikely you will have OOM
Actual example, slightly updated, roll your own approach for posterity.
import org.apache.spark.sql.functions._
import spark.sqlContext.implicits._
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField, StructType}
// 1. Create data
val data = Seq(("James ","","Smith","36636","M",33000),
  ("Michael ","Rose","","40288","M",14000),
  ("Robert ","","Williams","42114","M",40),
  ("Robert ","","Williams","42114","M",540),
  ("Robert ","","Zeedong","42114","M",40000000),
  ("Maria ","Anne","Jones","39192","F",300),
  ("Maria ","Anne","Vangelis","39192","F",1300))
val columns = Seq("firstname","middlename","lastname","dob","gender","val")
val df = data.toDF(columns:_*)
//2. Any number of partitions, and sort that partition. Combiner function like Hadoop.
val df2 = df.repartition(1000,col("lastname")).sortWithinPartitions(desc("val"))
//3. Take top N per partition. Thus num partitions x 2 in this case. The take(n) is the top n per partition. No OOM.
val rdd2 = df2.rdd.mapPartitions(_.take(2))
//4. Ghastly Row to DF work-arounds.
val schema = new StructType()
.add(StructField("f", StringType, true))
.add(StructField("m", StringType, true))
.add(StructField("l", StringType, true))
.add(StructField("d", StringType, true))
.add(StructField("g", StringType, true))
.add(StructField("v", IntegerType, true))
val df3 = spark.createDataFrame(rdd2, schema)
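//5. Sort and take top(n) = 2 and Bob's your uncle. The Reduce after Combine.
df3.sort(col("v").desc).limit(2).show()
Returns for top 2 desc:
+-------+---+-------+-----+---+--------+
|      f|  m|      l|    d|  g|       v|
+-------+---+-------+-----+---+--------+
|Robert |   |Zeedong|42114|  M|40000000|
| James |   |  Smith|36636|  M|   33000|
+-------+---+-------+-----+---+--------+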

Iterate rows and columns in Spark dataframe

I have the following Spark dataframe that is created dynamically:
val sf1 = StructField("name", StringType, nullable = true)
val sf2 = StructField("sector", StringType, nullable = true)
val sf3 = StructField("age", IntegerType, nullable = true)
val fields = List(sf1,sf2,sf3)
val schema = StructType(fields)
val row1 = Row("Andy","aaa",20)
val row2 = Row("Berta","bbb",30)
val row3 = Row("Joe","ccc",40)
val data = Seq(row1,row2,row3)
val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
df.createOrReplaceTempView("people")
val sqlDF = spark.sql("SELECT * FROM people")
Now, I need to iterate over each row and column in sqlDF to print each column. This is my attempt:
sqlDF.foreach { row =>
  row.foreach { col => println(col) }
}
row is of type Row, but it is not iterable, which is why this code throws a compilation error at row.foreach. How can I iterate over each column in a Row?
Suppose you have a Dataframe like the one below:
+-----+------+---+
| name|sector|age|
+-----+------+---+
| Andy|   aaa| 20|
|Berta|   bbb| 30|
|  Joe|   ccc| 40|
+-----+------+---+
To loop over your Dataframe and extract its elements, you can choose one of the approaches below.
Approach 1 - Loop using foreach
Accessing fields by name inside a foreach loop does not work on a plain dataframe. To do this, first define the schema of the dataframe with a case class and then convert the dataframe to a typed Dataset of that class.
import spark.implicits._
import org.apache.spark.sql._
case class cls_Employee(name:String, sector:String, age:Int)
val df = Seq(cls_Employee("Andy","aaa", 20), cls_Employee("Berta","bbb", 30), cls_Employee("Joe","ccc", 40)).toDF()
df.as[cls_Employee].take(df.count.toInt).foreach(t => println(s"name=${t.name},sector=${t.sector},age=${t.age}"))
Approach 2 - Loop using rdd
Use rdd.collect on top of your Dataframe. The row variable will hold each row of the Dataframe as an RDD Row. To get each element from a row, use row.mkString(","), which joins the values of the row with commas. With the built-in split function you can then access each column value by index.
for (row <- df.rdd.collect) {
  val fields = row.mkString(",").split(",")
  val name = fields(0)
  val sector = fields(1)
  val age = fields(2)
}
Note that there are two drawbacks to this approach.
1. If there is a , in a column value, the data will be wrongly split into the adjacent column.
2. rdd.collect is an action that returns all the data to the driver's memory. The driver may not have enough memory to hold it, in which case the application fails.
I would recommend using Approach 1.
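The first drawback is easy to reproduce without Spark. The sketch below is plain Scala and only mimics the mkString/split round trip on a sequence of values; SplitPitfall and roundTrip are hypothetical names used for illustration.

```scala
object SplitPitfall {
  // Joins the values with "," and splits the result again,
  // which is the same round trip that Approach 2 performs on a Row.
  def roundTrip(values: Seq[Any]): Array[String] =
    values.mkString(",").split(",")

  def main(args: Array[String]): Unit = {
    val row = Seq("Smith, Andy", "aaa", 20) // the first value contains a comma
    println(roundTrip(row).length)          // 4 fields come back instead of 3
    println(roundTrip(row)(1))              // " Andy" rather than "aaa"
  }
}
```

Reading the values by position or by field name avoids the problem, because the values are never serialized to a delimited string.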
Approach 3 - Using where and select
You can directly use where and select, which loop internally and find the data. The if condition guards against an index-out-of-bounds exception when no row matches.
var name = ""
if (df.where($"name" === "Andy").select(col("name")).collect().length >= 1)
  name = df.where($"name" === "Andy").select(col("name")).collect()(0).get(0).toString
Approach 4 - Using temp tables
You can register the dataframe as a temp table, which is kept in Spark's memory. Then you can query it with a select statement as in any other database, collect the result, and save it in a variable.
df.createOrReplaceTempView("student")
name = sqlContext.sql("select name from student where name='Andy'").collect()(0).toString().replace("[","").replace("]","")
You can convert a Row to a Seq with toSeq. Once it is a Seq, you can iterate over it as usual with foreach, map or whatever you need.
sqlDF.foreach { row =>
  row.toSeq.foreach { col => println(col) }
}
You should use mkString on your Row:
sqlDF.foreach { row =>
  println(row.mkString(","))
}
But note that this will be printed inside the executors' JVMs, so normally you won't see the output (unless you work with master = local).
sqlDF.foreach is not working for me, but Approach 1 from @Sarath Avanavu's answer works, although it sometimes changed the order of the records.
I found one more way that works:
df.collect().foreach { row =>
  println(row.mkString(","))
}
You should iterate over the partitions, which allows Spark to process the data in parallel, and you can do a foreach on each row inside the partition.
You can further group the data in a partition into batches if need be.
sqlDF.foreachPartition { partitionedRows: Iterator[Model1] =>
  if (partitionedRows.hasNext) {
    partitionedRows.grouped(numberOfRowsPerBatch).foreach { batch =>
      batch.foreach { row =>
        // process the row here
      }
    }
  }
}
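How grouped cuts an iterator into batches can be seen in plain Scala, without Spark. This is a minimal sketch; GroupedSketch and batches are hypothetical names, and the integers stand in for the rows of one partition.

```scala
object GroupedSketch {
  // Splits an iterator into batches of at most `batchSize` elements.
  // The last batch is smaller when the element count is not a multiple of batchSize.
  def batches[A](rows: Iterator[A], batchSize: Int): List[List[A]] =
    rows.grouped(batchSize).map(_.toList).toList

  def main(args: Array[String]): Unit = {
    // Prints List(List(1, 2, 3), List(4, 5, 6), List(7))
    println(batches(Iterator(1, 2, 3, 4, 5, 6, 7), 3))
  }
}
```

Inside foreachPartition the batches would normally be consumed one at a time rather than materialized into a list, so that only one batch is held in memory.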
This worked fine for me
sqlDF.collect().foreach(row => row.toSeq.foreach(col => println(col)))
simple collect result and then apply foreach
My solution uses for, because that is what I needed:
Solution 1:
case class campos_tablas(name:String, sector:String, age:Int)
for (row <- df.as[campos_tablas].take(df.count.toInt)) {
  // use row.name, row.sector and row.age here
}
Solution 2:
for (row <- df.take(df.count.toInt)) {
  // use row(0), row(1) and row(2) here
}
Let's assume resultDF is the Dataframe.
val resultDF: DataFrame = ??? // your DataFrame
var itr = 0
val resultRow = resultDF.count
val resultSet = resultDF.collectAsList
var col1 = 0
var col2 = ""
var col3 = 0L
while (itr < resultRow) {
  col1 = resultSet.get(itr).getInt(0)
  col2 = resultSet.get(itr).getString(1) // if the column holds a String value
  col3 = resultSet.get(itr).getLong(2) // if the column holds a Long value
  // Write other logic for your code //
  itr = itr + 1
}

Check every column in a spark dataframe has a certain value

Can we check to see if every column in a spark dataframe contains a certain string(example "Y") using Spark-SQL or scala?
I have tried the following but don't think it is working properly.
df.select(col("*")).filter("'*' =='Y'")
You can do something like this to keep the rows where a column contains 'Y'. Because the per-column results are unioned, a row is kept when at least one column matches, and it appears once for every matching column:
//Get all columns
val columns: Array[String] = df.columns
//For each column, keep the rows with 'Y'
val seqDfs: Seq[DataFrame] = columns.map(name => df.filter(s"$name == 'Y'"))
//Union all the dataframes together into one final dataframe
val output: DataFrame = seqDfs.reduceRight(_ union _)
You can use data frame method columns to get all column's names
val columnNames: Array[String] = df.columns
and then add all filters in a loop
var filteredDf = df.select(col("*"))
for(name <- columnNames) {
  filteredDf = filteredDf.filter(s"$name =='Y'")
}
or you can create a SQL query using the same approach.
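Building the condition as a single SQL string could look like the sketch below. It is plain Scala and only assembles the text; WhereClause, allEqual and anyEqual are hypothetical names, and the sketch assumes that the value contains no single quote and the column names contain no backtick.

```scala
object WhereClause {
  // One "`column` = 'value'" condition per column name.
  private def conditions(columnNames: Seq[String], value: String): Seq[String] =
    columnNames.map(name => s"`$name` = '$value'")

  // Conditions joined with AND: rows where every column equals the value.
  def allEqual(columnNames: Seq[String], value: String): String =
    conditions(columnNames, value).mkString(" AND ")

  // Conditions joined with OR: rows where at least one column equals the value.
  def anyEqual(columnNames: Seq[String], value: String): String =
    conditions(columnNames, value).mkString(" OR ")

  def main(args: Array[String]): Unit = {
    println(allEqual(Seq("a", "b", "c"), "Y"))
    println(anyEqual(Seq("a", "b", "c"), "Y"))
  }
}
```

The resulting string can be passed to the String variant of filter, for example df.filter(WhereClause.allEqual(df.columns, "Y")), so the whole check runs as one filter instead of one per column.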
If you want to keep every row in which any of the columns is equal to 1 (or anything else), you can dynamically create a query like this (PySpark):
from pyspark.sql.functions import col, lit
cols = [col(c) == lit(1) for c in patients.columns]
query = cols[0]
for c in cols[1:]:
    query |= c
res = patients.filter(query)
It's a bit verbose, but it is very clear what is happening. A more elegant version would be:
from functools import reduce
res = patients.filter(reduce(lambda x, y: x | y, (col(c) == lit(1) for c in patients.columns)))

Spark Dataframe select based on column index

How do I select all the columns of a dataframe that has certain indexes in Scala?
For example if a dataframe has 100 columns and i want to extract only columns (10,12,13,14,15), how to do the same?
The following selects all columns from dataframe df whose names are listed in the Array colNames:
df = df.select(colNames.head, colNames.tail: _*)
Similarly, if there is a colNos array such as
colNos = Array(10,20,25,45)
how do I transform the above to fetch only the columns at those specific indexes?
You can map over columns:
import org.apache.spark.sql.functions.col
df.select(colNos map df.columns map col: _*)
or:
df.select(colNos map (df.columns andThen col): _*)
or:
df.select(colNos map (col _ compose df.columns): _*)
All the methods shown above are equivalent and don't impose performance penalty. Following mapping:
colNos map df.columns
is just a local Array access (constant time access for each index) and choosing between String or Column based variant of select doesn't affect the execution plan:
val df = Seq((1, 2, 3 ,4, 5, 6)).toDF
val colNos = Seq(0, 3, 5)
df.select(colNos map df.columns map col: _*).explain
== Physical Plan ==
LocalTableScan [_1#46, _4#49, _6#51]
df.select("_1", "_4", "_6").explain
== Physical Plan ==
LocalTableScan [_1#46, _4#49, _6#51]
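The index-to-name lookup itself is ordinary array access and can be tried without Spark. The following plain-Scala sketch adds a bounds check, which the one-liners above do not have; ColumnsByIndex and namesAt are hypothetical names, and the array stands in for df.columns.

```scala
object ColumnsByIndex {
  // Looks up column names by position and fails fast on an invalid index.
  // `columns` stands in for df.columns, `colNos` for the wanted positions.
  def namesAt(columns: Array[String], colNos: Seq[Int]): Seq[String] = {
    colNos.foreach { i =>
      require(i >= 0 && i < columns.length, s"column index $i is out of range")
    }
    colNos.map(i => columns(i))
  }

  def main(args: Array[String]): Unit = {
    val columns = Array("_1", "_2", "_3", "_4", "_5", "_6")
    // Prints List(_1, _4, _6)
    println(namesAt(columns, Seq(0, 3, 5)))
  }
}
```

The names returned this way can then be mapped to columns and passed to select, as in the answer above.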
@user6910411's answer above works like a charm, and the number of tasks and the logical plan are similar to my approach below, but my approach is a bit faster.
I would suggest going with column names rather than column numbers. Column names are much safer and much lighter than using numbers. You can use the following solution:
val colNames = Seq("col1", "col2" ...... "col99", "col100")
val selectColNames = Seq("col1", "col3", .... selected column names ... )
val selectCols = selectColNames.map(name => df.col(name))
df = df.select(selectCols: _*)
If you are hesitant to write all the 100 column names then there is a shortcut method too
val colNames = df.schema.fieldNames
Example: Grab first 14 columns of Spark Dataframe by Index using Scala.
import org.apache.spark.sql.functions.col
// Gives array of names by index (first 14 cols for example)
val sliceCols = df.columns.slice(0, 14)
// Maps names & selects columns in dataframe
val subset_df = df.select(sliceCols.map(name => col(name)): _*)
You cannot simply do this (as I tried and failed):
// Gives array of names by index (first 14 cols for example)
val sliceCols = df.columns.slice(0, 14)
// Maps names & selects columns in dataframe
val subset_df = df.select(sliceCols: _*)
The reason is that you have to convert your Array[String] to an Array[org.apache.spark.sql.Column] in order for the select to work.
OR Wrap it in a function using Currying (high five to my colleague for this):
// Subsets Dataframe to using beg_val & end_val index.
def subset_frame(beg_val:Int=0, end_val:Int)(df: DataFrame): DataFrame = {
  val sliceCols = df.columns.slice(beg_val, end_val)
  df.select(sliceCols.map(name => col(name)): _*)
}
// Get first 25 columns as subsetted dataframe
val subset_df:DataFrame = df_.transform(subset_frame(0, 25))

How to add corresponding Integer values in 2 different DataFrames

I have two DataFrames in my code with exactly the same dimensions, let's say 1,000,000 x 50. I need to add the corresponding values in both dataframes. How can I achieve that?
One option would be to add another column with ids, union both DataFrames and then use reduceByKey. But is there a more elegant way?
Your approach is good. Another option is to take the RDDs, zip them together, and then iterate over the pairs to sum the columns and create a new dataframe using either of the original dataframe schemas.
Assuming the data types of all the columns are integer, this code snippet should work. Please note that this was done in Spark 2.1.0.
import spark.implicits._
val a: DataFrame = spark.sparkContext.parallelize(Seq(
(1, 2),
(3, 6)
)).toDF("column_1", "column_2")
val b: DataFrame = spark.sparkContext.parallelize(Seq(
(3, 4),
(1, 5)
)).toDF("column_1", "column_2")
// Merge rows
import org.apache.spark.sql.Row
val rows = a.rdd.zip(b.rdd).map {
  case (rowLeft, rowRight) => {
    val totalColumns = rowLeft.schema.fields.size
    val summedRow = for(i <- (0 until totalColumns)) yield rowLeft.getInt(i) + rowRight.getInt(i)
    Row.fromSeq(summedRow)
  }
}
// Create new data frame
val ab: DataFrame = spark.createDataFrame(rows, a.schema) // use any of the schemas
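The element-wise addition performed on the zipped rows can be sketched in plain Scala, with nested sequences standing in for the two dataframes. ZipSum and addRows are hypothetical names; the sketch assumes both sides have the same number of rows and columns and that the rows are in the same order.

```scala
object ZipSum {
  // Pairs up rows by position, then pairs up values by position and adds them.
  def addRows(left: Seq[Seq[Int]], right: Seq[Seq[Int]]): Seq[Seq[Int]] = {
    require(left.length == right.length, "both sides need the same number of rows")
    left.zip(right).map { case (rowLeft, rowRight) =>
      require(rowLeft.length == rowRight.length, "rows need the same number of columns")
      rowLeft.zip(rowRight).map { case (x, y) => x + y }
    }
  }

  def main(args: Array[String]): Unit = {
    val a = Seq(Seq(1, 2), Seq(3, 6))
    val b = Seq(Seq(3, 4), Seq(1, 5))
    // Prints List(List(4, 6), List(4, 11))
    println(addRows(a, b))
  }
}
```

On RDDs the same pairing is stricter: zip expects both RDDs to have the same number of partitions and the same number of elements in each partition.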
So, I experimented with the performance of my solution vs yours. I tested with 100000 rows of 50 columns each. With your approach there are 51 columns, the extra one being the ID column. On a single machine (no cluster), my solution seems to work a bit faster.
The union and group by approach takes about 5598 milliseconds,
whereas my solution takes about 5378 milliseconds.
My assumption is that the first solution takes a bit more time because of the union operation on the two dataframes.
Here are the methods which I created for testing the approaches.
def option_1()(implicit spark: SparkSession): Unit = {
  import spark.implicits._
  val a: DataFrame = getDummyData(withId = true)
  val b: DataFrame = getDummyData(withId = true)
  val allData = a.union(b)
  val result = allData.groupBy($"id").agg(allData.columns.collect({ case col if col != "id" => (col, "sum") }).toMap)
}
def option_2()(implicit spark: SparkSession): Unit = {
  val a: DataFrame = getDummyData()
  val b: DataFrame = getDummyData()
  // Merge rows
  val rows = a.rdd.zip(b.rdd).map {
    case (rowLeft, rowRight) => {
      val totalColumns = rowLeft.schema.fields.size
      val summedRow = for (i <- (0 until totalColumns)) yield rowLeft.getInt(i) + rowRight.getInt(i)
      Row.fromSeq(summedRow)
    }
  }
  // Create new data frame
  val result: DataFrame = spark.createDataFrame(rows, a.schema) // use any of the schemas
}
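Millisecond figures like the ones quoted above can be collected with a small timing helper. The answer does not show how its measurements were taken, so the following is only a sketch of one possible way; Timing and timed are hypothetical names, and the helper is plain Scala.

```scala
object Timing {
  // Runs `block` once and returns its result together with the
  // elapsed wall-clock time in milliseconds.
  def timed[A](block: => A): (A, Long) = {
    val start = System.nanoTime()
    val result = block
    val elapsedMs = (System.nanoTime() - start) / 1000000L
    (result, elapsedMs)
  }

  def main(args: Array[String]): Unit = {
    val (sum, elapsedMs) = timed((1 to 1000000).map(_.toLong).sum)
    println(s"sum=$sum took $elapsedMs ms")
  }
}
```

Spark transformations are lazy, so a timed block only measures real work when it ends with an action such as count() on the resulting dataframe. A single run is also noisy; repeating the measurement gives a more reliable comparison.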