rearrange order of spark columns - scala

I have a spark dataframe with many columns. Using Spark and Scala, I would like to select the columns in a specified order, but I don't want to hardcode the desired order. In pseudo-code, I'd like to do something like:
val colNames = df.columns
val newOrder = colNames(colNames.length) ++ colNames(0:colNames.length-1)
How can I do this? Thanks!

You can do something like this:
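val df = Seq((1,2,3)).toDF("A","B","C")
// select(col: String, cols: String*): the last column first, then all the others
df.select(df.columns.last, df.columns.dropRight(1): _*).show
+---+---+---+
|  C|  A|  B|
+---+---+---+
|  3|  1|  2|
+---+---+---+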
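The same idea generalizes to moving any column to the front by name instead of by position. A minimal sketch, assuming a spark-shell style script with a local session; `moveToFront` is a hypothetical helper name, not a Spark API:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[*]").appName("reorder").getOrCreate()
import spark.implicits._

// Hypothetical helper: put `first` at the front and keep the remaining
// columns in their current order, without hardcoding the full column list.
def moveToFront(df: DataFrame, first: String): DataFrame = {
  val rest = df.columns.toSeq.filterNot(_ == first)
  df.select((first +: rest).map(col): _*)
}

val df = Seq((1, 2, 3)).toDF("A", "B", "C")
val reordered = moveToFront(df, "C")
reordered.show()
```

Because the remaining names come from `df.columns`, the helper keeps working when columns are added or removed upstream.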


Exploding column with index

I know that I can "explode" a column of type array like this:
import org.apache.spark.sql._
import org.apache.spark.sql.functions.explode
val explodedDf =
payloadLegsDf.withColumn("legs", explode(payloadLegsDf.col("legs")))
Now I have multiple rows, one for each item in the array.
Is there a way I can "explode with index"? So that there will be a new column that contains the index of the item in the original array?
(I can think of hacks to do this. First make the array field into an array of tuples of the original value and the index. Then do the explode. Then unpack the tuples. But is there a more elegant way?)
If you are using Spark 2.1+, the posexplode function can be used for that:
Creates a new row for each element with position in the given array or map column.
import spark.implicits._
import org.apache.spark.sql.functions.posexplode
val df = Seq(
  (1L, Array[String]("a", "b")),
  (2L, Array[String]("c", "d"))
).toDF("id", "items")
val res = df.select($"id", posexplode($"items"))
res.show
This will create two new columns, pos for position/index and col for the extracted value:
+---+---+---+
| id|pos|col|
+---+---+---+
|  1|  0|  a|
|  1|  1|  b|
|  2|  0|  c|
|  2|  1|  d|
+---+---+---+
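If you want names other than pos and col, a generator that produces two columns can take a multi-name alias through `Column.as(Seq(...))`. A minimal sketch, assuming a local session; the names `index` and `item` are just illustrative:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.posexplode

val spark = SparkSession.builder().master("local[*]").appName("posexplode").getOrCreate()
import spark.implicits._

val df = Seq(
  (1L, Array("a", "b")),
  (2L, Array("c", "d"))
).toDF("id", "items")

// posexplode yields two columns, so it is aliased with a Seq of two names
val res = df.select($"id", posexplode($"items").as(Seq("index", "item")))
res.show()
```

Rows whose array is empty or null produce no output rows here; Spark 2.2+ also has posexplode_outer, which keeps such rows with null position and value.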

scala: method that returns varargs

From a scala method, I want to return a variable number of Spark columns, like this:
def getColumns() : (Column*) = {...}
The idea is then to use it with Spark SQL, in a select(...) together with other columns such as "anotherColumns".
The thing is, I have about 30 queries that all have the same select clause, and I want to factor it out.
Any idea what to replace the ... with? I tried something like:
($"col1", "$col2")
but it doesn't compile.
Try this:
val df = Seq((1,2,3,4),(5,6,7,8)).toDF("a","b","c","d")
Convert the String column names to Spark Columns with map(col), then append the additional columns as required:
import org.apache.spark.sql.functions.col
val lstCols = List("a","b").map(col) ++ List(col("c"),col("d"))
df.select(lstCols: _*).show()
+---+---+---+---+
|  a|  b|  c|  d|
+---+---+---+---+
|  1|  2|  3|  4|
|  5|  6|  7|  8|
+---+---+---+---+
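Applied to the original question: a repeated type like `Column*` is only allowed for parameters, not as a method result type, so the shared select list can be returned as a `Seq[Column]` and expanded with `: _*` at each call site. A minimal sketch, assuming a local session; `commonColumns` is a hypothetical name standing in for the shared select clause:

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[*]").appName("varargs").getOrCreate()
import spark.implicits._

// Hypothetical shared select list; returning Seq[Column] (not Column*)
// lets every query reuse it.
def commonColumns: Seq[Column] = Seq(col("a"), col("b"))

val df = Seq((1, 2, 3, 4), (5, 6, 7, 8)).toDF("a", "b", "c", "d")

// Append any query-specific column, then expand into select's varargs.
val res = df.select(commonColumns :+ col("d"): _*)
res.show()
```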

How can I join a list of Spark dataframes together in Scala?

I have a Seq of Spark dataframes (i.e. Seq[org.apache.spark.sql.DataFrame]); it could contain 1 or many elements.
There is a list of columns that is common to each of those dataframes, and each dataframe also has some additional columns. What I would like to do is join all those dataframes together using the common columns in the join conditions (remember, the number of dataframes is unknown).
How can I join all these dataframes together? I guess I could foreach over them, but that doesn't seem very elegant. Can anyone come up with a more functional way of doing it? Edit: A recursive function would be better than a foreach; I'm working on that now and will post it here when done.
Here is some code that creates a list of n dataframes (n=3 in this case), each of which contains columns id & Product:
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
val conf = new SparkConf().setMaster("local[*]")
val spark = SparkSession.builder().appName("Feature Generator tests").config(conf).config("spark.sql.warehouse.dir", "/tmp/hive").enableHiveSupport().getOrCreate()
val df = spark.range(0, 1000).toDF().withColumn("Product", concat(lit("product"), col("id")))
val dataFrames = Seq(1,2,3).map(s => df.withColumn("_" + s.toString, lit(s)))
To clarify, dataFrames.head.columns returns Array[String] = Array(id, Product, _1).
How might I join those n dataframes together on columns id & Product so that the returned dataframe has columns Array[String] = Array(id, Product, _1, _2, _3)?
dataFrames is a List, so you can use the List.reduce method to join all the data frames together:
dataFrames.reduce(_.join(_, Seq("id", "Product"))).show
//| id| Product| _1| _2| _3|
//| 0| product0| 1| 2| 3|
//| 1| product1| 1| 2| 3|
//| 2| product2| 1| 2| 3|
//| 3| product3| 1| 2| 3|
//| 4| product4| 1| 2| 3|
//| ... more rows
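The same reduce can be wrapped in a small helper that also takes the join type, since Dataset.join accepts the using-columns and a join type string. A minimal sketch, assuming a local session; `joinAll` is a hypothetical name and the range is shrunk to 10 rows:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, concat, lit}

val spark = SparkSession.builder().master("local[*]").appName("join-all").getOrCreate()

// Hypothetical helper: pairwise-join every DataFrame on the shared key columns.
// Joining with usingColumns keeps a single copy of each key column.
def joinAll(dfs: Seq[DataFrame], keys: Seq[String], joinType: String = "inner"): DataFrame =
  dfs.reduce((left, right) => left.join(right, keys, joinType))

val base = spark.range(0, 10).toDF().withColumn("Product", concat(lit("product"), col("id")))
val dataFrames = Seq(1, 2, 3).map(s => base.withColumn("_" + s, lit(s)))

val joined = joinAll(dataFrames, Seq("id", "Product"))
joined.show()
```

A one-element Seq simply comes back unchanged, but reduce throws on an empty Seq; if that can happen, reduceOption returns an Option instead.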

Remove all records which are duplicate in spark dataframe

I have a spark dataframe with multiple columns in it. I want to find out and remove rows which have duplicated values in a column (the other columns can be different).
I tried using dropDuplicates(col_name), but it only drops the extra duplicate entries and still keeps one record in the dataframe. What I need is to remove all rows whose value in that column is duplicated.
I am using Spark 1.6 and Scala 2.10.
I would use window functions for this. Let's say you want to remove rows with a duplicated id:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.count
df
  .withColumn("cnt", count("*").over(Window.partitionBy($"id")))
  .where($"cnt" === 1)
  .drop("cnt")
This can also be done by grouping by the column (or columns) to check for duplicates, then aggregating and filtering the results.
Example dataframe df:
+---+---+
| id|num|
+---+---+
|  1|  1|
|  2|  2|
|  3|  3|
|  4|  4|
|  4|  5|
+---+---+
Grouping by the id column to remove its duplicates (the last two rows):
import org.apache.spark.sql.functions.{count, first}
val df2 = df.groupBy("id")
.agg(first($"num").as("num"), count($"id").as("count"))
.filter($"count" === 1)
.select("id", "num")
This will give you:
+---+---+
| id|num|
+---+---+
|  1|  1|
|  2|  2|
|  3|  3|
+---+---+
Alternatively, it can be done with a join. It will be slower, but if there are a lot of columns, you don't need to write first($"num").as("num") for each one to keep them.
val df2 = df.groupBy("id").agg(count($"id").as("count")).filter($"count" === 1).select("id")
val df3 = df.join(df2, Seq("id"), "inner")
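On Spark 2.0 or later (so not the asker's 1.6), the join variant can also be written as a left anti join against the duplicated ids, which keeps every other column without listing them. A minimal sketch, assuming a local session and the example data from above:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.count

val spark = SparkSession.builder().master("local[*]").appName("anti-join").getOrCreate()
import spark.implicits._

val df = Seq((1, 1), (2, 2), (3, 3), (4, 4), (4, 5)).toDF("id", "num")

// ids that occur more than once
val dupIds = df.groupBy("id").agg(count("*").as("count")).filter($"count" > 1).select("id")

// left_anti keeps only the rows of df whose id has no match in dupIds
val result = df.join(dupIds, Seq("id"), "left_anti")
result.show()
```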
I added a killDuplicates() method to the open source spark-daria library that uses @Raphael Roth's solution. Here's how to use the code:
import com.github.mrpowers.spark.daria.sql.DataFrameExt._
import org.apache.spark.sql.functions.col
df.killDuplicates(col("id"))
// you can also supply multiple Column arguments
df.killDuplicates(col("id"), col("another_column"))
Here's the code implementation:
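import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count}
object DataFrameExt {
  implicit class DataFrameMethods(df: DataFrame) {
    def killDuplicates(cols: Column*): DataFrame =
      df.withColumn("my_super_secret_count", count("*").over(Window.partitionBy(cols: _*)))
        .where(col("my_super_secret_count") === 1)
        .drop(col("my_super_secret_count"))
  }
}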
You might want to leverage the spark-daria library to keep this logic out of your codebase.

Split 1 column into 3 columns in spark scala

I have a dataframe in Spark using scala that has a column that I need split.
| a.b.c|
| d.e.f|
I need this column split out to look like this:
| a| b| c|
| d| e| f|
I'm using Spark 2.0.0
import sparkObject.spark.implicits._
import org.apache.spark.sql.functions.split
df.withColumn("_tmp", split($"columnToSplit", "\\.")).select(
  $"_tmp".getItem(0).as("col1"), $"_tmp".getItem(1).as("col2"), $"_tmp".getItem(2).as("col3"))
The important point to note here is that sparkObject is the object holding the SparkSession you have already initialized. So the first import statement has to be placed inline in the code, where the session is available, not before the class definition.
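To make the import placement concrete, here is a minimal sketch of a standalone application; the names SplitColumnApp and run are hypothetical. The implicits import needs a stable SparkSession value, so it sits inside the method that receives the session:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.split

object SplitColumnApp {
  def run(spark: SparkSession): DataFrame = {
    // Only valid once a SparkSession value is in scope, so it cannot
    // go at the top of the file next to the other imports.
    import spark.implicits._

    val df = Seq("a.b.c", "d.e.f").toDF("columnToSplit")
    df.withColumn("_tmp", split($"columnToSplit", "\\."))
      .select(
        $"_tmp".getItem(0).as("col1"),
        $"_tmp".getItem(1).as("col2"),
        $"_tmp".getItem(2).as("col3"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("split-column").getOrCreate()
    run(spark).show()
    spark.stop()
  }
}
```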
To do this programmatically, you can create a sequence of expressions with (0 until 3).map(i => col("temp").getItem(i).as(s"col$i")) (assuming you need 3 columns in the result) and then pass it to select with the : _* syntax:
df.withColumn("temp", split(col("columnToSplit"), "\\.")).select(
  (0 until 3).map(i => col("temp").getItem(i).as(s"col$i")): _*
).show
+----+----+----+
|col0|col1|col2|
+----+----+----+
|   a|   b|   c|
|   d|   e|   f|
+----+----+----+
To keep all columns:
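df.withColumn("temp", split(col("columnToSplit"), "\\.")).select(
  col("*") +: (0 until 3).map(i => col("temp").getItem(i).as(s"col$i")): _*
).show
+-------------+---------+----+----+----+
|columnToSplit|     temp|col0|col1|col2|
+-------------+---------+----+----+----+
|        a.b.c|[a, b, c]|   a|   b|   c|
|        d.e.f|[d, e, f]|   d|   e|   f|
+-------------+---------+----+----+----+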
If you are using pyspark, use a generator expression in place of the Scala map:
from pyspark.sql.functions import col, split
df = spark.createDataFrame([['a.b.c'], ['d.e.f']], ['columnToSplit'])
(df.withColumn('temp', split('columnToSplit', '\\.'))
   .select(*(col('temp').getItem(i).alias(f'col{i}') for i in range(3)))
   .show())
+----+----+----+
|col0|col1|col2|
+----+----+----+
|   a|   b|   c|
|   d|   e|   f|
+----+----+----+
A solution which avoids the select part. This is helpful when you just want to append the new columns:
import org.apache.spark.sql.functions.{col, split}
case class Message(others: String, text: String)
val r1 = Message("foo1", "a.b.c")
val r2 = Message("foo2", "d.e.f")
val records = Seq(r1, r2)
val df = spark.createDataFrame(records)
df.withColumn("col1", split(col("text"), "\\.").getItem(0))
.withColumn("col2", split(col("text"), "\\.").getItem(1))
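.withColumn("col3", split(col("text"), "\\.").getItem(2))
.show(false)
+------+-----+----+----+----+
|others|text |col1|col2|col3|
+------+-----+----+----+----+
|foo1  |a.b.c|a   |b   |c   |
|foo2  |d.e.f|d   |e   |f   |
+------+-----+----+----+----+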
Update: I highly recommend using Psidom's implementation to avoid splitting three times.
This appends columns to the original DataFrame and doesn't use select, and only splits once using a temporary column:
import spark.implicits._
import org.apache.spark.sql.functions.split
df.withColumn("_tmp", split($"columnToSplit", "\\."))
.withColumn("col1", $"_tmp".getItem(0))
.withColumn("col2", $"_tmp".getItem(1))
.withColumn("col3", $"_tmp".getItem(2))
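The repeated withColumn calls can also be generated with foldLeft so the number of parts is a parameter rather than three hand-written lines. A minimal sketch, assuming a local session and that the number of parts (n) is known up front:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, split}

val spark = SparkSession.builder().master("local[*]").appName("split-fold").getOrCreate()
import spark.implicits._

val df = Seq("a.b.c", "d.e.f").toDF("columnToSplit")
val n = 3 // assumed number of parts per value

// Split once into a temporary array column, then add one column per index.
val initial = df.withColumn("_tmp", split(col("columnToSplit"), "\\."))
val withParts: DataFrame = (0 until n)
  .foldLeft(initial) { (acc, i) =>
    acc.withColumn(s"col${i + 1}", col("_tmp").getItem(i))
  }
  .drop("_tmp")

withParts.show()
```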
This expands on Psidom's answer and shows how to do the split dynamically, without hardcoding the number of columns. This answer runs a query to calculate the number of columns.
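import spark.implicits._
import org.apache.spark.sql.functions.{col, max, size, split}
// sample data (assumed, borrowed from the question)
val df = Seq("a.b.c", "d.e.f").toDF("my_str")
  .withColumn("letters", split(col("my_str"), "\\."))
// the largest array size determines how many columns to create
val numCols = df
  .withColumn("letters_size", size($"letters"))
  .agg(max($"letters_size"))
  .head()
  .getInt(0)
df.select(
  (0 until numCols).map(i => $"letters".getItem(i).as(s"col$i")): _*
).show()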
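We can also write this with for/yield in Scala.
If you need more columns, just add their names to desiredColumn and play with it. :)
import org.apache.spark.sql.functions.{col, split}
val aDF = Seq("Deepak.Singh.Delhi").toDF("name")
val desiredColumn = Seq("name","Lname","City")
val colsize = desiredColumn.size
// split takes a regex, so the dot must be escaped ("\\."); a bare "." matches every character
val columList = for (i <- 0 until colsize) yield split(col("name"), "\\.").getItem(i).alias(desiredColumn(i))
aDF.select(columList: _*).show(false)
+------+-----+-----+
|name  |Lname|City |
+------+-----+-----+
|Deepak|Singh|Delhi|
+------+-----+-----+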
If you don't need the name column, drop it and just use withColumn.
Without using the select statement.
Let's assume we have a dataframe with a set of columns, and we want to split the column called name:
import spark.implicits._
val columns = Seq("name","age","address")
val data = Seq(("Amit.Mehta", 25, "1 Main st, Newark, NJ, 92537"),
  ("Rituraj.Mehta", 28, "3456 Walnut st, Newark, NJ, 94732"))
val dfFromData = spark.createDataFrame(data).toDF(columns:_*)
// map each Row to a tuple, splitting name into first and last name
val newDF = dfFromData.map(f => {
  val nameSplit = f.getAs[String](0).split("\\.").map(_.trim)
  (nameSplit(0), nameSplit(1), f.getAs[Int](1), f.getAs[String](2))
})
val finalDF = newDF.toDF("First Name","Last Name", "Age","Address")
finalDF.show(false)
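A typed variant of the same map uses case classes instead of positional getAs calls, so field names and types are checked at compile time. A minimal sketch; the case classes Person and SplitPerson and the object SplitNameExample are hypothetical, and they are declared at object level so Spark can derive encoders for them:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

object SplitNameExample {
  // Hypothetical row types for the input and the split output.
  case class Person(name: String, age: Int, address: String)
  case class SplitPerson(firstName: String, lastName: String, age: Int, address: String)

  def splitNames(spark: SparkSession, people: Seq[Person]): Dataset[SplitPerson] = {
    import spark.implicits._
    people.toDS().map { p =>
      // assumes every name has exactly one "." between first and last name
      val parts = p.name.split("\\.").map(_.trim)
      SplitPerson(parts(0), parts(1), p.age, p.address)
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("split-names").getOrCreate()
    val people = Seq(
      Person("Amit.Mehta", 25, "1 Main st, Newark, NJ, 92537"),
      Person("Rituraj.Mehta", 28, "3456 Walnut st, Newark, NJ, 94732"))
    splitNames(spark, people).show(false)
    spark.stop()
  }
}
```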