cannot select a specific column for ReduceByKey operation Spark - scala

I created the DataFrame shown below. I want to apply a map-reduce style count to the 'title' column, but I run into problems when I use the reduceByKey function.
+-------+--------------------+------------+-----------+
|project| title|requests_num|return_size|
+-------+--------------------+------------+-----------+
| aa|%CE%92%CE%84_%CE%...| 1| 4854|
| aa|%CE%98%CE%B5%CF%8...| 1| 4917|
| aa|%CE%9C%CF%89%CE%A...| 1| 4832|
| aa|%CE%A0%CE%B9%CE%B...| 1| 4828|
| aa|%CE%A3%CE%A4%CE%8...| 1| 4819|
| aa|%D0%A1%D0%BE%D0%B...| 1| 4750|
| aa| 271_a.C| 1| 4675|
| aa|Battaglia_di_Qade...| 1| 4765|
| aa| Category:User_th| 1| 4770|
| aa| Chiron_Elias_Krase| 1| 4694|
| aa|County_Laois/en/Q...| 1| 4752|
| aa| Dassault_rafaele| 2| 9372|
| aa|Dyskusja_wikiproj...| 1| 4824|
| aa| E.Desv| 1| 4662|
| aa|Enclos-apier/fr/E...| 1| 4772|
| aa|File:Wiktionary-l...| 1| 10752|
| aa|Henri_de_Sourdis/...| 1| 4748|
| aa|Incentive_Softwar...| 1| 4777|
| aa|Indonesian_Wikipedia| 1| 4679|
| aa| Main_Page| 5| 266946|
+-------+--------------------+------------+-----------+
I tried this, but it doesn't work:
dataframe.select("title").map(word => (word,1)).reduceByKey(_+_);
It seems that I should first convert the DataFrame to a list, then use map to generate key-value pairs (word, 1), and finally sum the values for each key.
I found a method on Stack Overflow for converting a DataFrame to a list, for example:
val text =dataframe.select("title").map(r=>r(0).asInstanceOf[String]).collect()
but an error occurs:
scala> val text = dataframe.select("title").map(r=>r(0).asInstanceOf[String]).collect()
2018-04-08 21:49:35 WARN NettyRpcEnv:66 - Ignored message: HeartbeatResponse(false)
2018-04-08 21:49:35 WARN Executor:87 - Issue communicating with driver in heartbeater
org.apache.spark.rpc.RpcTimeoutException: Futures timed out after [10 seconds]. This timeout is controlled by spark.executor.heartbeatInterval
at org.apache.spark.rpc.RpcTimeout.org$apache$spark$rpc$RpcTimeout$$createRpcTimeoutException(RpcTimeout.scala:47)
at org.apache.spark.rpc.RpcTimeout$$anonfun$addMessageIfTimeout$1.applyOrElse(RpcTimeout.scala:62)
at org.apache.spark.rpc.RpcTimeout$$anonfun$addMessageIfTimeout$1.applyOrElse(RpcTimeout.scala:58)
at scala.runtime.AbstractPartialFunction.apply(AbstractPartialFunction.scala:36)
at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:76)
at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:92)
at org.apache.spark.executor.Executor.org$apache$spark$executor$Executor$$reportHeartBeat(Executor.scala:785)
at org.apache.spark.executor.Executor$$anon$2$$anonfun$run$1.apply$mcV$sp(Executor.scala:814)
at org.apache.spark.executor.Executor$$anon$2$$anonfun$run$1.apply(Executor.scala:814)
at org.apache.spark.executor.Executor$$anon$2$$anonfun$run$1.apply(Executor.scala:814)
at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1988)
at org.apache.spark.executor.Executor$$anon$2.run(Executor.scala:814)
at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:511)
at java.util.concurrent.FutureTask.runAndReset(FutureTask.java:308)
at java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask.access$301(ScheduledThreadPoolExecutor.java:180)
at java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask.run(ScheduledThreadPoolExecutor.java:294)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.util.concurrent.TimeoutException: Futures timed out after [10 seconds]
at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:219)
at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223)
at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:201)
at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
... 14 more
java.lang.OutOfMemoryError: Java heap space
at org.apache.spark.sql.execution.SparkPlan$$anon$1.next(SparkPlan.scala:280)
at org.apache.spark.sql.execution.SparkPlan$$anon$1.next(SparkPlan.scala:276)
at scala.collection.Iterator$class.foreach(Iterator.scala:893)
at org.apache.spark.sql.execution.SparkPlan$$anon$1.foreach(SparkPlan.scala:276)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeCollect$1.apply(SparkPlan.scala:298)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeCollect$1.apply(SparkPlan.scala:297)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:297)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3272)
at org.apache.spark.sql.Dataset$$anonfun$collect$1.apply(Dataset.scala:2722)
at org.apache.spark.sql.Dataset$$anonfun$collect$1.apply(Dataset.scala:2722)
at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3253)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:77)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3252)
at org.apache.spark.sql.Dataset.collect(Dataset.scala:2722)
... 16 elided
scala> val text = dataframe.select("title").map(r=>r(0).asInstanceOf[String]).collect()
java.lang.OutOfMemoryError: GC overhead limit exceeded
at org.apache.spark.sql.execution.SparkPlan$$anon$1.next(SparkPlan.scala:280)
at org.apache.spark.sql.execution.SparkPlan$$anon$1.next(SparkPlan.scala:276)
at scala.collection.Iterator$class.foreach(Iterator.scala:893)
at org.apache.spark.sql.execution.SparkPlan$$anon$1.foreach(SparkPlan.scala:276)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeCollect$1.apply(SparkPlan.scala:298)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeCollect$1.apply(SparkPlan.scala:297)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:297)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3272)
at org.apache.spark.sql.Dataset$$anonfun$collect$1.apply(Dataset.scala:2722)
at org.apache.spark.sql.Dataset$$anonfun$collect$1.apply(Dataset.scala:2722)
at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3253)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:77)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3252)
at org.apache.spark.sql.Dataset.collect(Dataset.scala:2722)
... 16 elided

Collecting your DataFrame into a Scala collection pulls every row onto the driver, so it only works for datasets that fit in driver memory (hence the OutOfMemoryError). Instead, you could convert the DataFrame to an RDD and then apply map and reduceByKey as below:
val df = Seq(
("aa", "271_a.C", 1, 4675),
("aa", "271_a.C", 1, 4400),
("aa", "271_a.C", 1, 4600),
("aa", "Chiron_Elias_Krase", 1, 4694),
("aa", "Chiron_Elias_Krase", 1, 4500)
).toDF("project", "title", "request_num", "return_size")
import org.apache.spark.sql.functions._
import org.apache.spark.sql.Row
val rdd = df.rdd.
map{ case Row(_, title: String, _, _) => (title, 1) }.
reduceByKey(_ + _)
rdd.collect
// res1: Array[(String, Int)] = Array((Chiron_Elias_Krase,2), (271_a.C,3))
You could also transform your DataFrame directly using groupBy:
df.groupBy($"title").agg(count($"title").as("count")).
show
// +------------------+-----+
// | title|count|
// +------------------+-----+
// | 271_a.C| 3|
// |Chiron_Elias_Krase| 2|
// +------------------+-----+
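As a side note on why the original attempt fails: select("title") yields a Dataset of Row objects, and reduceByKey is defined on pair RDDs, not on Datasets. The sketch below shows one way to bridge the two, assuming the title column holds non-null strings. The names pages and titleCounts and the sample rows are illustrative, not taken from the question's data.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("title-count").getOrCreate()
import spark.implicits._

// Small stand-in for the question's DataFrame (values are illustrative).
val pages = Seq(
  ("aa", "Main_Page", 5, 266946),
  ("aa", "Main_Page", 1, 4800),
  ("aa", "E.Desv", 1, 4662)
).toDF("project", "title", "requests_num", "return_size")

// select(...) returns Rows; .as[String] narrows the single column to Dataset[String],
// and .rdd exposes the pair-RDD API that provides reduceByKey.
val titleCounts = pages.select("title").as[String]
  .rdd
  .map(title => (title, 1))
  .reduceByKey(_ + _)

// Look at a bounded sample instead of collecting everything to the driver.
titleCounts.take(10).foreach(println)
```

On a large dataset, prefer take(n) or writing the result out over collect, for the same memory reason as above.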


How to extract a previous and next line sentence in spark?

I'm analyzing a log file for customer impact analysis using Apache Spark. In the log file, the timestamp is on one line, the customer's details are on another line, and the cause of the error is on yet another line. I want the output in a single file, with each extracted record combined into one line. Here is my log file:
2018-10-15 05:24:00.102 ERROR 7 --- [DefaultMessageListenerContainer-2] c.l.p.a.c.event.listener.MQListener : The ABC/CDE object received for the xyz event was not valid. e_id=11111111, s_id=111, e_name=ABC
com.xyz.abc.pqr.exception.PNotVException: The r received from C was invalid/lacks mandatory fields. S_id: 123, P_Id: 123456789, R_Number: 12345678
at com.xyz.abc.pqr.mprofile.CPServiceImpl.lambda$bPByC$1(CPServiceImpl.java:240)
at java.util.ArrayList.forEach(ArrayList.java:1249)
rContainer.doInvokeListener(AbstractMessageListenerContainer.java:721)
at org.springframework.jms.listener.AbstractMessageListenerContainer.invokeListener(AbstractMessageListenerContainer.java:681)
at org.springframework.jms.listener.AbstractMessageListenerContainer.doExecuteListener(AbstractMessageListenerContainer.java:651)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.IllegalArgumentException: Invalid D because cm: null and pk: null were missing.
at com.xyz.abc.pqr.mp.DD.resolveDetailsFromCDE(DD.java:151)
at com.xyz.abc.pqr.mp.DD.<init>(DD.java:35)
at java.util.ArrayList.forEach(ArrayList.java:1249)
2018-10-15 05:24:25.136 ERROR 7 --- [DefaultMessageListenerContainer-1] c.l.p.a.c.event.listener.MQListener : The ABC/CDE object received for the C1 event was not valid. entity_id=2222222, s_id=3333, event_name=CDE
com.xyz.abc.pqr.PNotVException: The r received from C was invalid/lacks mandatory fields. S_id: 123, P_Id: 123456789, R_Number: 12345678
at com.xyz.abc.pqr.mp.CSImpl.lambda$buildABCByCo$1(CSImpl.java:240)
at java.util.ArrayList.forEach(ArrayList.java:1249)
at com.xyz.abc.pqr.event.handler.DHandler.handle(CDEEventHandler.java:45)
at sun.reflect.GMA.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at org.springframework.aop.support.AopUtils.invokeJoinpointUsingReflection(AopUtils.java:333)
at org.springframework.messaging.handler.invocation.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:197)
at org.springframework.messaging.handler.invocation.InvocableHandlerMethod.invoke(InvocableHandlerMethod.java:115)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.NullPointerException: null
You can use the DataFrame API to do this in a few ways. Here is one:
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
val rd = sc.textFile("/FileStore/tables/log.txt").zipWithIndex.map{case (r, i) => Row(r, i)}
val schema = StructType(StructField("logs", StringType, false) :: StructField("id", LongType, false) :: Nil)
val df = spark.sqlContext.createDataFrame(rd, schema)
df.show
+--------------------+---+
| logs| id|
+--------------------+---+
|2018-10-15 05:24:...| 0|
| | 1|
|com.xyz.abc.pqr.e...| 2|
| at com.xyz.ab...| 3|
| at java.util....| 4|
| rContainer.do...| 5|
| at org.spring...| 6|
| at org.spring...| 7|
| at java.lang....| 8|
|Caused by: java.l...| 9|
| at com.xyz.ab...| 10|
| at com.xyz.ab...| 11|
| at com.xyz.ab...| 12|
| at java.util....| 13|
| | 14|
|2018-10-15 05:24:...| 15|
| | 16|
|com.xyz.abc.pqr.P...| 17|
val df1 = df.filter($"logs".contains("c.l.p.a.c.event.listener.MQListener")).withColumn("logs",regexp_replace($"logs","ERROR.*","")).sort("id")
df1.show
+--------------------+---+
| logs| id|
+--------------------+---+
|2018-10-15 05:24:...| 0|
|2018-10-15 05:24:...| 15|
+--------------------+---+
val df2 = df.filter($"logs".contains("PrescriptionNotValidException:")).withColumn("logs",regexp_replace($"logs","(.*?)mandatory fields.","")).sort("id")
df2.show
+--------------------+---+
| logs| id|
+--------------------+---+
| StoreId: 123, Co...| 2|
| StoreId: 234, Co...| 17|
+--------------------+---+
val df3 = df.filter($"logs".contains("Caused by: java.lang.")).sort("id")
df3.show
df1.select("logs").collect.toSeq.zip(df2.select("logs").collect.toSeq).zip(df3.select("logs").collect.toSeq)
+--------------------+---+
| logs| id|
+--------------------+---+
|Caused by: java.l...| 9|
|Caused by: java.l...| 28|
+--------------------+---+
df3: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [logs: string, id: bigint]
res71: Seq[((org.apache.spark.sql.Row, org.apache.spark.sql.Row), org.apache.spark.sql.Row)] = ArrayBuffer((([2018-10-15 05:24:00.102 ],[ StoreId: 123, Co Patient Id: 123456789, Rx Number: 12345678]),[Caused by: java.lang.IllegalArgumentException: Invalid Dispense Object because compound: null and pack: null were missing.]), (([2018-10-15 05:24:25.136 ],[ StoreId: 234, Co Patient Id: 999999, Rx Number: 45555]),[Caused by: java.lang.NullPointerException: null]))
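The zip above relies on each filter returning the same number of rows in the same order, and it collects to the driver. A possible alternative is to tag every line with the id of the most recent timestamp line and then aggregate per record. This is only a sketch under assumptions: the sample lines are made up, a record is assumed to start with a line beginning with a yyyy-MM-dd date, and the names lines, tagged and records are hypothetical. Note that a window without partitionBy moves all rows to a single partition.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").appName("log-records").getOrCreate()
import spark.implicits._

// Illustrative lines only; id plays the role of zipWithIndex in the answer above.
val lines = Seq(
  ("2018-10-15 05:24:00.102 ERROR 7 --- first event", 0L),
  ("com.xyz.SomeException: S_id: 123", 1L),
  ("Caused by: java.lang.IllegalArgumentException: Invalid D", 2L),
  ("2018-10-15 05:24:25.136 ERROR 7 --- second event", 3L),
  ("com.xyz.SomeException: S_id: 234", 4L),
  ("Caused by: java.lang.NullPointerException: null", 5L)
).toDF("logs", "id")

// Assumption: a line that starts with a date opens a new record.
val startsRecord = $"logs".rlike("^\\d{4}-\\d{2}-\\d{2} ")

// Running max over the ids of record-opening lines gives, for every line,
// the id of the record it belongs to.
val upToHere = Window.orderBy("id").rowsBetween(Window.unboundedPreceding, Window.currentRow)
val tagged = lines.withColumn("record", max(when(startsRecord, $"id")).over(upToHere))

// One row per record: the timestamp line and the earliest "Caused by" line.
val records = tagged.groupBy("record").agg(
  max(when(startsRecord, $"logs")).as("event"),
  min(when($"logs".startsWith("Caused by:"), struct($"id", $"logs"))).getField("logs").as("cause")
).orderBy("record")

records.show(false)
```

Records without a "Caused by" line get a null cause here instead of shifting the pairing, which is the main reason to prefer grouping over zipping.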

minBy equivalent in Spark Dataframe [duplicate]

This question already has answers here:
How to select the first row of each group?
I'm looking for an equivalent of the minBy aggregate function for Spark DataFrames, or do I need to aggregate manually? Any thoughts? Thanks.
https://prestodb.io/docs/current/functions/aggregate.html#min_by
There is no direct function to get 'min_by' values from a DataFrame.
It is a two-stage operation in Spark: first group by the column, then apply the min function to get the minimum value of each numeric column for each group.
scala> val inputDF = Seq(("a", 1),("b", 2), ("b", 3), ("a", 4), ("a", 5)).toDF("id", "count")
inputDF: org.apache.spark.sql.DataFrame = [id: string, count: int]
scala> inputDF.show()
+---+-----+
| id|count|
+---+-----+
| a| 1|
| b| 2|
| b| 3|
| a| 4|
| a| 5|
+---+-----+
scala> inputDF.groupBy($"id").min("count").show()
+---+----------+
| id|min(count)|
+---+----------+
| b| 2|
| a| 1|
+---+----------+
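Note that Presto's min_by(x, y) returns the value of x on the row where y is smallest, which groupBy(...).min(...) alone does not give you. A common way to approximate it is to take the minimum of a struct, since structs compare field by field. The sketch below assumes a hypothetical extra column label; the names events and minByCount are illustrative. Newer Spark releases may also ship a built-in min_by aggregate, so check the function reference for your version.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").appName("min-by").getOrCreate()
import spark.implicits._

// Hypothetical data: for each id, find the label that goes with the smallest count.
val events = Seq(("a", 1, "x"), ("b", 2, "y"), ("b", 3, "z"), ("a", 4, "w"))
  .toDF("id", "count", "label")

// min(struct(count, label)) picks the struct with the smallest count
// (ties are broken by label); the fields are unpacked afterwards.
val minByCount = events
  .groupBy($"id")
  .agg(min(struct($"count", $"label")).as("m"))
  .select($"id", $"m.count".as("min_count"), $"m.label".as("label_at_min"))

minByCount.orderBy("id").show()
```

The ordering field has to come first in the struct, because that is the field the comparison looks at first.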

Spark: Parquet DataFrame operations fail when forcing schema on read

(Spark 2.0.2)
The problem here arises when you have parquet files with different schemas and force the schema during read. Even though you can print the schema and run show() without errors, you cannot apply any filtering logic on the missing columns.
Here are the two example schemata:
// assuming you are running this code in a spark REPL
import spark.implicits._
case class Foo(i: Int)
case class Bar(i: Int, j: Int)
So Bar includes all the fields of Foo and adds one more (j). In real life this arises when you start with schema Foo, later decide that you need more fields, and end up with schema Bar.
Let's simulate the two different parquet files.
// assuming you are on a Mac or Linux OS
spark.createDataFrame(Foo(1)::Nil).write.parquet("/tmp/foo")
spark.createDataFrame(Bar(1,2)::Nil).write.parquet("/tmp/bar")
What we want here is to always read data using the more generic schema Bar. That is, rows written with schema Foo should have j set to null.
case 1: We read a mix of both schemas
spark.read.option("mergeSchema", "true").parquet("/tmp/foo", "/tmp/bar").show()
+---+----+
| i| j|
+---+----+
| 1| 2|
| 1|null|
+---+----+
spark.read.option("mergeSchema", "true").parquet("/tmp/foo", "/tmp/bar").filter($"j".isNotNull).show()
+---+---+
| i| j|
+---+---+
| 1| 2|
+---+---+
case 2: We only have Bar data
spark.read.parquet("/tmp/bar").show()
+---+---+
| i| j|
+---+---+
| 1| 2|
+---+---+
case 3: We only have Foo data
scala> spark.read.parquet("/tmp/foo").show()
+---+
| i|
+---+
| 1|
+---+
The problematic case is 3, where our resulting schema is of type Foo and not of Bar. Since we migrate to schema Bar, we want to always get schema Bar from our data (old and new).
The suggested solution would be to define the schema programmatically to always be Bar. Let's see how to do this:
val barSchema = org.apache.spark.sql.Encoders.product[Bar].schema
//barSchema: org.apache.spark.sql.types.StructType = StructType(StructField(i,IntegerType,false), StructField(j,IntegerType,false))
Running show() works great:
scala> spark.read.schema(barSchema).parquet("/tmp/foo").show()
+---+----+
| i| j|
+---+----+
| 1|null|
+---+----+
However, if you try to filter on the missing column j, things fail.
scala> spark.read.schema(barSchema).parquet("/tmp/foo").filter($"j".isNotNull).show()
17/09/07 18:13:50 ERROR Executor: Exception in task 0.0 in stage 230.0 (TID 481)
java.lang.IllegalArgumentException: Column [j] was not found in schema!
at org.apache.parquet.Preconditions.checkArgument(Preconditions.java:55)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.getColumnDescriptor(SchemaCompatibilityValidator.java:181)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.validateColumn(SchemaCompatibilityValidator.java:169)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.validateColumnFilterPredicate(SchemaCompatibilityValidator.java:151)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.visit(SchemaCompatibilityValidator.java:91)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.visit(SchemaCompatibilityValidator.java:58)
at org.apache.parquet.filter2.predicate.Operators$NotEq.accept(Operators.java:194)
at org.apache.parquet.filter2.predicate.SchemaCompatibilityValidator.validate(SchemaCompatibilityValidator.java:63)
at org.apache.parquet.filter2.compat.RowGroupFilter.visit(RowGroupFilter.java:59)
at org.apache.parquet.filter2.compat.RowGroupFilter.visit(RowGroupFilter.java:40)
at org.apache.parquet.filter2.compat.FilterCompat$FilterPredicateCompat.accept(FilterCompat.java:126)
at org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups(RowGroupFilter.java:46)
at org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.initialize(SpecificParquetRecordReaderBase.java:110)
at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.initialize(VectorizedParquetRecordReader.java:109)
at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anonfun$buildReader$1.apply(ParquetFileFormat.scala:381)
at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anonfun$buildReader$1.apply(ParquetFileFormat.scala:355)
at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.nextIterator(FileScanRDD.scala:168)
at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:109)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.scan_nextBatch$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:377)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:231)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:225)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:99)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:322)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
The issue is due to Parquet filter pushdown, which is not handled correctly in parquet-mr versions before 1.9.0.
You can check https://issues.apache.org/jira/browse/PARQUET-389 for more details.
You can either upgrade the parquet-mr version or add a new column and base the filter on the new column.
For example:
val dfNew = df.withColumn("new_j", when($"j".isNotNull, $"j").otherwise(lit(null)))
dfNew.filter($"new_j".isNotNull)
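Since the failure happens inside the pushed-down filter, another workaround that is sometimes used is to switch off Parquet filter pushdown for the session, at the cost of scanning more data. The sketch below is an assumption on my part and not something verified against Spark 2.0.2; the temporary directory and the names fooPath and withJ are illustrative.

```scala
import java.nio.file.Files
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val spark = SparkSession.builder().master("local[*]").appName("parquet-pushdown").getOrCreate()
import spark.implicits._

// Evaluate predicates in Spark after the scan instead of handing them to parquet-mr.
spark.conf.set("spark.sql.parquet.filterPushdown", "false")

// Write a file that only has column i (the "Foo" layout) to a fresh temp location.
val fooPath = Files.createTempDirectory("foo-parquet").resolve("foo").toString
Seq(1).toDF("i").write.parquet(fooPath)

// Read it back with the wider "Bar" layout; j is declared nullable.
val barSchema = StructType(Seq(
  StructField("i", IntegerType, nullable = true),
  StructField("j", IntegerType, nullable = true)))

val withJ = spark.read.schema(barSchema).parquet(fooPath).filter($"j".isNotNull)
println(withJ.count())
```

If pushdown matters for performance on other queries, set the option back to "true" once the affected read is done.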
On Spark 1.6 this worked fine; the schema was retrieved differently and a HiveContext was used:
val barSchema = ScalaReflection.schemaFor[Bar].dataType.asInstanceOf[StructType]
println(s"barSchema: $barSchema")
hiveContext.read.schema(barSchema).parquet("tmp/foo").filter($"j".isNotNull).show()
Result is:
barSchema: StructType(StructField(i,IntegerType,false), StructField(j,IntegerType,false))
+---+----+
| i| j|
+---+----+
| 1|null|
+---+----+
What worked for me is to use the createDataFrame API with an RDD[Row] and the new schema (with at least the new columns being nullable).
// Make the columns nullable (probably you don't need to make them all nullable)
val barSchemaNullable = org.apache.spark.sql.types.StructType(
barSchema.map(_.copy(nullable = true)).toArray)
// We create the df (but this is not what you want to use, since it still has the same issue)
val df = spark.read.schema(barSchemaNullable).parquet("/tmp/foo")
// Here is the final API call that gives a working DataFrame
val fixedDf = spark.createDataFrame(df.rdd, barSchemaNullable)
fixedDf.filter($"j".isNotNull).show()
+---+---+
| i| j|
+---+---+
+---+---+

How to use DataFrame Window expressions and withColumn and not to change partition?

For some reason I have to convert an RDD to a DataFrame and then do some work on the DataFrame.
My interface is RDD-based, so I have to convert the DataFrame back to an RDD. When I use df.withColumn, the number of partitions changes to 1, so I have to repartition and sortBy the RDD.
Is there a cleaner solution?
This is my code:
val rdd = sc.parallelize(List(1,3,2,4,5,6,7,8),4)
val partition = rdd.getNumPartitions
println(partition + "rdd")
val df=rdd.toDF()
val rdd2=df.rdd
val result = rdd.toDF("col1")
.withColumn("csum", sum($"col1").over(Window.orderBy($"col1")))
.withColumn("rownum", row_number().over(Window.orderBy($"col1")))
.withColumn("avg", $"csum"/$"rownum").rdd
println(result.getNumPartitions + "rdd2")
Let's make this as simple as possible. We will generate the same data in 4 partitions:
scala> val df = spark.range(1,9,1,4).toDF
df: org.apache.spark.sql.DataFrame = [id: bigint]
scala> df.show
+---+
| id|
+---+
| 1|
| 2|
| 3|
| 4|
| 5|
| 6|
| 7|
| 8|
+---+
scala> df.rdd.getNumPartitions
res13: Int = 4
We don't need 3 window functions to prove this, so let's do it with one :
scala> import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.expressions.Window
scala> val df2 = df.withColumn("csum", sum($"id").over(Window.orderBy($"id")))
df2: org.apache.spark.sql.DataFrame = [id: bigint, csum: bigint]
What's happening here is that we didn't just add a column: we computed a cumulative sum over a window of the data. Since you haven't provided a partition column, the window function moves all the data to a single partition, and you even get a warning from Spark:
scala> df2.rdd.getNumPartitions
17/06/06 10:05:53 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
res14: Int = 1
scala> df2.show
17/06/06 10:05:56 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+---+----+
| id|csum|
+---+----+
| 1| 1|
| 2| 3|
| 3| 6|
| 4| 10|
| 5| 15|
| 6| 21|
| 7| 28|
| 8| 36|
+---+----+
So let's add now a column to partition on. We will create a new DataFrame just for the sake of demonstration :
scala> val df3 = df.withColumn("x", when($"id"<5,lit("a")).otherwise("b"))
df3: org.apache.spark.sql.DataFrame = [id: bigint, x: string]
It has indeed the same number of partitions that we defined explicitly on df :
scala> df3.rdd.getNumPartitions
res18: Int = 4
Let's perform our window operation using the column x to partition :
scala> val df4 = df3.withColumn("csum", sum($"id").over(Window.orderBy($"id").partitionBy($"x")))
df4: org.apache.spark.sql.DataFrame = [id: bigint, x: string ... 1 more field]
scala> df4.show
+---+---+----+
| id| x|csum|
+---+---+----+
| 5| b| 5|
| 6| b| 11|
| 7| b| 18|
| 8| b| 26|
| 1| a| 1|
| 2| a| 3|
| 3| a| 6|
| 4| a| 10|
+---+---+----+
The window function will repartition our data using the default number of partitions set in spark configuration.
scala> df4.rdd.getNumPartitions
res20: Int = 200
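The 200 above comes from the spark.sql.shuffle.partitions setting, so one way to influence the result is to change that setting before running the window. A minimal sketch, assuming a local session; the value 4 is only chosen to match the input partitioning, and on newer releases adaptive query execution may coalesce the shuffle into fewer partitions than configured.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").appName("window-partitions").getOrCreate()
import spark.implicits._

// The partitioned window shuffles by x; the shuffle width is taken from this setting.
spark.conf.set("spark.sql.shuffle.partitions", "4")

val df3 = spark.range(1, 9, 1, 4).toDF("id")
  .withColumn("x", when($"id" < 5, lit("a")).otherwise("b"))

val df4 = df3.withColumn("csum", sum($"id").over(Window.partitionBy($"x").orderBy($"id")))

// Number of partitions after the window's shuffle.
println(df4.rdd.getNumPartitions)
```

The setting is session-wide, so it also affects joins and aggregations that run afterwards.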
I was just reading about controlling the number of partitions when using groupBy aggregation, at https://jaceklaskowski.gitbooks.io/mastering-spark-sql/spark-sql-performance-tuning-groupBy-aggregation.html, and it seems the same trick works with Window. In my code I'm defining a window like
windowSpec = Window \
.partitionBy('colA', 'colB') \
.orderBy('timeCol') \
.rowsBetween(1, 1)
and then doing
next_event = F.lead('timeCol', 1).over(windowSpec)
and creating a dataframe via
df2 = df.withColumn('next_event', next_event)
and indeed, it has 200 partitions. But, if I do
df2 = df.repartition(10, 'colA', 'colB').withColumn('next_event', next_event)
it has 10!

Select column by name with multiple aggregate columns after pivot with Spark Scala

I am trying to aggregate multiple columns after a pivot in Scala Spark 2.0.1:
scala> val df = List((1, 2, 3, None), (1, 3, 4, Some(1))).toDF("a", "b", "c", "d")
df: org.apache.spark.sql.DataFrame = [a: int, b: int ... 2 more fields]
scala> df.show
+---+---+---+----+
| a| b| c| d|
+---+---+---+----+
| 1| 2| 3|null|
| 1| 3| 4| 1|
+---+---+---+----+
scala> val pivoted = df.groupBy("a").pivot("b").agg(max("c"), max("d"))
pivoted: org.apache.spark.sql.DataFrame = [a: int, 2_max(`c`): int ... 3 more fields]
scala> pivoted.show
+---+----------+----------+----------+----------+
| a|2_max(`c`)|2_max(`d`)|3_max(`c`)|3_max(`d`)|
+---+----------+----------+----------+----------+
| 1| 3| null| 4| 1|
+---+----------+----------+----------+----------+
I am unable to select or rename those columns so far:
scala> pivoted.select("3_max(`d`)")
org.apache.spark.sql.AnalysisException: syntax error in attribute name: 3_max(`d`);
scala> pivoted.select("`3_max(`d`)`")
org.apache.spark.sql.AnalysisException: syntax error in attribute name: `3_max(`d`)`;
scala> pivoted.select("`3_max(d)`")
org.apache.spark.sql.AnalysisException: cannot resolve '`3_max(d)`' given input columns: [2_max(`c`), 3_max(`d`), a, 2_max(`d`), 3_max(`c`)];
There must be a simple trick here, any ideas? Thanks.
This seems like a bug: the backticks in the column names cause the problem. One fix would be to remove the backticks from the column names:
val pivotedNewName = pivoted.columns.foldLeft(pivoted)((df, col) =>
df.withColumnRenamed(col, col.replace("`", "")))
Now you can select by column names as normal:
pivotedNewName.select("2_max(c)").show
+--------+
|2_max(c)|
+--------+
| 3|
+--------+
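Another option is to avoid the backticks in the first place by aliasing each aggregate before the pivot, since the generated column names are built from the pivot value and the aggregate's name. A sketch, assuming the same data as in the question; the aliases max_c and max_d are my own choice.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").appName("pivot-alias").getOrCreate()
import spark.implicits._

val df = List((1, 2, 3, None), (1, 3, 4, Some(1))).toDF("a", "b", "c", "d")

// Aliased aggregates give plain column names such as 2_max_c and 3_max_d.
val pivoted = df.groupBy("a").pivot("b").agg(max("c").as("max_c"), max("d").as("max_d"))

pivoted.printSchema()
pivoted.select("3_max_d").show()
```

With plain names, select and withColumnRenamed work without any escaping.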