How to do string manipulation in Spark dataframe's columns value - scala

I am very new to Spark. I have to perform string manipulation operations and create a new column in a Spark dataframe. I have created UDF functions for the string manipulation, but for performance reasons I want to do this without UDFs. Following is my code and output. Could you please help me to do this in a better way?
object Demo2 extends Context {
  import org.apache.spark.sql.functions.udf

  def main(args: Array[String]): Unit = {
    import sparkSession.sqlContext.implicits._

    // Keys follow the fixed layout: bankInfo.<bank>.<custId>.<branch>.<district>.displayInfo
    val data = Seq(
      "bankInfo.SBI.C_1.Kothrud.Pune.displayInfo",
      "bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo",
      "bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo",
      "bankInfo.HDFC.C_4.Deccan.Pune.displayInfo"
    )
    val df = data.toDF("Key")
    println("Input Dataframe")
    df.show(false)

    // Local address: everything between the leading "bankInfo." prefix
    // and the trailing ".displayInfo" suffix.
    val get_local_address = udf { (key: String) =>
      val afterPrefix = key.substring(key.indexOf(".") + 1)
      afterPrefix.substring(0, afterPrefix.lastIndexOf("."))
    }

    // Address: the local address with its second-to-last segment (the branch)
    // removed, e.g. "SBI.C_1.Kothrud.Pune" -> "SBI.C_1.Pune".
    val get_address = udf { (key: String) =>
      val afterPrefix = key.substring(key.indexOf(".") + 1)
      val localAddress = afterPrefix.substring(0, afterPrefix.lastIndexOf("."))
      val districtDot = localAddress.lastIndexOf(".")
      val branchDot = localAddress.lastIndexOf(".", districtDot - 1)
      localAddress.substring(0, branchDot) + localAddress.substring(districtDot)
    }

    val df2 = df
      .withColumn("Local Address", get_local_address(df("Key")))
      .withColumn("Address", get_address(df("Key")))
    println("Output Dataframe")
    df2.show(false)
  }
}
Input Dataframe
+----------------------------------------------+
|Key |
+----------------------------------------------+
|bankInfo.SBI.C_1.Kothrud.Pune.displayInfo |
|bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |
|bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|
|bankInfo.HDFC.C_4.Deccan.Pune.displayInfo |
+----------------------------------------------+
Output Dataframe
+----------------------------------------------+-------------------------+---------------+
|Key |Local Address |Address |
+----------------------------------------------+-------------------------+---------------+
|bankInfo.SBI.C_1.Kothrud.Pune.displayInfo |SBI.C_1.Kothrud.Pune |SBI.C_1.Pune |
|bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |ICICI.C_2.TilakRoad.Pune |ICICI.C_2.Pune |
|bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|Axis.C_3.Santacruz.Mumbai|Axis.C_3.Mumbai|
|bankInfo.HDFC.C_4.Deccan.Pune.displayInfo |HDFC.C_4.Deccan.Pune |HDFC.C_4.Pune |
+----------------------------------------------+-------------------------+---------------+

Since you have a fixed-size array, you can structurize it and then concat as required -
Load the test data provided
// Embed the sample rows in a multiline string, then turn it into a one-column
// dataset so spark.read.csv can infer the header and schema from it.
val data =
"""
|Key
|bankInfo.SBI.C_1.Kothrud.Pune.displayInfo
|bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo
|bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo
|bankInfo.HDFC.C_4.Deccan.Pune.displayInfo
""".stripMargin
// Split into lines, strip the '|' margins and surrounding whitespace from each
// cell, and re-join with commas so each line becomes a CSV record.
val stringDS1 = data.split(System.lineSeparator())
.map(_.split("\\|").map(_.replaceAll("""^[ \t]+|[ \t]+$""", "")).mkString(","))
.toSeq.toDS()
val df1 = spark.read
.option("sep", ",")
.option("inferSchema", "true")
.option("header", "true")
.option("nullValue", "null")
.csv(stringDS1)
df1.show(false)
df1.printSchema()
/**
* +----------------------------------------------+
* |Key |
* +----------------------------------------------+
* |bankInfo.SBI.C_1.Kothrud.Pune.displayInfo |
* |bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |
* |bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|
* |bankInfo.HDFC.C_4.Deccan.Pune.displayInfo |
* +----------------------------------------------+
*
* root
* |-- Key: string (nullable = true)
*/
Derive the columns from the fixed format string column
// Split the key on '.' and expose the fixed positions as named struct fields
// (element_at is 1-based), then concat_ws only the parts each column needs —
// no UDF involved, so Catalyst can optimize the whole expression.
df1.select($"key", split($"key", "\\.").as("x"))
.withColumn("bankInfo",
expr(
"""
|named_struct('name', element_at(x, 2), 'cust_id', element_at(x, 3),
| 'branch', element_at(x, 4), 'dist', element_at(x, 5))
""".stripMargin))
.select($"key",
concat_ws(".", $"bankInfo.name", $"bankInfo.cust_id", $"bankInfo.branch", $"bankInfo.dist")
.as("Local_Address"),
concat_ws(".", $"bankInfo.name", $"bankInfo.cust_id", $"bankInfo.dist")
.as("Address"))
.show(false)
/**
* +----------------------------------------------+-------------------------+---------------+
* |key                                           |Local_Address            |Address        |
* +----------------------------------------------+-------------------------+---------------+
* |bankInfo.SBI.C_1.Kothrud.Pune.displayInfo     |SBI.C_1.Kothrud.Pune     |SBI.C_1.Pune   |
* |bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |ICICI.C_2.TilakRoad.Pune |ICICI.C_2.Pune |
* |bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|Axis.C_3.Santacruz.Mumbai|Axis.C_3.Mumbai|
* |bankInfo.HDFC.C_4.Deccan.Pune.displayInfo     |HDFC.C_4.Deccan.Pune     |HDFC.C_4.Pune  |
* +----------------------------------------------+-------------------------+---------------+
*/
// Same derivation as above, but using 0-based array indexing (x[1] ... x[4])
// inside named_struct instead of the 1-based element_at.
df1.select($"key", split($"key", "\\.").as("x"))
.withColumn("bankInfo",
expr("named_struct('name', x[1], 'cust_id', x[2], 'branch', x[3], 'dist', x[4])"))
.select($"key",
concat_ws(".", $"bankInfo.name", $"bankInfo.cust_id", $"bankInfo.branch", $"bankInfo.dist")
.as("Local_Address"),
concat_ws(".", $"bankInfo.name", $"bankInfo.cust_id", $"bankInfo.dist")
.as("Address"))
.show(false)
/**
* +----------------------------------------------+-------------------------+---------------+
* |key                                           |Local_Address            |Address        |
* +----------------------------------------------+-------------------------+---------------+
* |bankInfo.SBI.C_1.Kothrud.Pune.displayInfo     |SBI.C_1.Kothrud.Pune     |SBI.C_1.Pune   |
* |bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |ICICI.C_2.TilakRoad.Pune |ICICI.C_2.Pune |
* |bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|Axis.C_3.Santacruz.Mumbai|Axis.C_3.Mumbai|
* |bankInfo.HDFC.C_4.Deccan.Pune.displayInfo     |HDFC.C_4.Deccan.Pune     |HDFC.C_4.Pune  |
* +----------------------------------------------+-------------------------+---------------+
*/

Check below code.
scala> df.show(false)
+----------------------------------------------+
|Key |
+----------------------------------------------+
|bankInfo.SBI.C_1.Kothrud.Pune.displayInfo |
|bankInfo.ICICI.C_2.TilakRoad.Pune.displayInfo |
|bankInfo.Axis.C_3.Santacruz.Mumbai.displayInfo|
|bankInfo.HDFC.C_4.Deccan.Pune.displayInfo |
+----------------------------------------------+
scala> val maxLength = df.select(split($"key","\\.").as("keys")).withColumn("length",size($"keys")).select(max($"length").as("length")).map(_.getAs[Int](0)).collect.head
maxLength: Int = 6
scala> val address_except = Seq(0,3,maxLength-1)
address_except: Seq[Int] = List(0, 3, 5)
scala> val local_address_except = Seq(0,maxLength-1)
local_address_except: Seq[Int] = List(0, 5)
scala> def parse(column: Column,indexes:Seq[Int]) = (0 to maxLength).filter(i => !indexes.contains(i)).map(i => column(i)).reduce(concat_ws(".",_,_))
parse: (column: org.apache.spark.sql.Column, indexes: Seq[Int])org.apache.spark.sql.Column
scala> df.select(split($"key","\\.").as("keys")).withColumn("local_address",parse($"keys",local_address_except)).withColumn("address",parse($"keys",address_except)).show(false)
+-----------------------------------------------------+-------------------------+---------------+
|keys |local_address |address |
+-----------------------------------------------------+-------------------------+---------------+
|[bankInfo, SBI, C_1, Kothrud, Pune, displayInfo] |SBI.C_1.Kothrud.Pune |SBI.C_1.Pune |
|[bankInfo, ICICI, C_2, TilakRoad, Pune, displayInfo] |ICICI.C_2.TilakRoad.Pune |ICICI.C_2.Pune |
|[bankInfo, Axis, C_3, Santacruz, Mumbai, displayInfo]|Axis.C_3.Santacruz.Mumbai|Axis.C_3.Mumbai|
|[bankInfo, HDFC, C_4, Deccan, Pune, displayInfo] |HDFC.C_4.Deccan.Pune |HDFC.C_4.Pune |
+-----------------------------------------------------+-------------------------+---------------+

Related

How to compute an iterative computation with spark and scala

In the example below the code produce a computation that is applied systematically to the same set of the original records.
Instead, the code must use the previously computed value to produce the subsequent quantity.
package playground
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{KeyValueGroupedDataset, SparkSession}
// Question code: joins owners with invoices and applies qtty = 0.3 * qtty + invoice.qtty
// once per invoice against the ORIGINAL owner rows, instead of feeding each
// intermediate result into the next invoice's computation.
object basic2 extends App {
// Silence Spark's console logging so only the show() output is visible.
Logger.getLogger("org").setLevel(Level.OFF)
Logger.getLogger("akka").setLevel(Level.OFF)
val spark = SparkSession
.builder()
.appName("Sample app")
.master("local")
.getOrCreate()
import spark.implicits._
final case class Owner(car: String, pcode: String, qtty: Double)
final case class Invoice(car: String, pcode: String, qtty: Double)
val data = Seq(
Owner("A", "666", 80),
Owner("B", "555", 20),
Owner("A", "444", 50),
Owner("A", "222", 20),
Owner("C", "444", 20),
Owner("C", "666", 80),
Owner("C", "555", 120),
Owner("A", "888", 100)
)
val fleet = Seq(Invoice("A", "666", 15), Invoice("A", "888", 12))
val owners = spark.createDataset(data)
val invoices = spark.createDataset(fleet)
// Group the joined (Owner, Invoice) pairs by invoice, so each group holds every
// owner row matched to one invoice.
val gb: KeyValueGroupedDataset[Invoice, (Owner, Invoice)] = owners
.joinWith(invoices, invoices("car") === owners("car"), "inner")
.groupByKey(_._2)
gb.flatMapGroups {
case (fleet, group) ⇒
val subOwner: Vector[Owner] = group.toVector.map(_._1)
val calculatedRes = subOwner.filter(_.car == fleet.car)
// NOTE(review): c.qtty here is always the original owner quantity, which is why
// the output applies each invoice to the same base values instead of chaining.
calculatedRes.map(c => c.copy(qtty = .3 * c.qtty + fleet.qtty))
}
.show()
}
/**
* +---+-----+----+
* |car|pcode|qtty|
* +---+-----+----+
* | A| 666|39.0|
* | A| 444|30.0|
* | A| 222|21.0|
* | A| 888|45.0|
* | A| 666|36.0|
* | A| 444|27.0|
* | A| 222|18.0|
* | A| 888|42.0|
* +---+-----+----+
*
* +---+-----+----+
* |car|pcode|qtty|
* +---+-----+----+
* | A| 666|0.3 * 39.0 + 12|
* | A| 444|0.3 * 30.0 + 12|
* | A| 222|0.3 * 21.0 + 12|
* | A| 888|0.3 * 45.0 + 12|
* +---+-----+----+
*/
The second table above is showing the expected output. The first table is what the code of this question produces.
How to produce the expected output in an iterative way?
Notice that the order of computation doesn't matter, the results will be different but it is still a valid answer.
Check below code.
// Chains all invoice quantities for a car into the owner quantity:
// seed = 0.3 * ownersQtty + first invoice, then for each remaining invoice
// total = 0.3 * previousTotal + nextInvoiceQtty.
val getQtty = udf((invoicesQtty:Seq[Double],ownersQtty:Double) => {
invoicesQtty.tail.foldLeft((0.3 * ownersQtty + invoicesQtty.head))(
(totalIQ,nextInvoiceQtty) => 0.3 * totalIQ + nextInvoiceQtty
)
})
// Same fold, but builds the human-readable expression string instead of the value.
val getQttyStr = udf((invoicesQtty:Seq[Double],ownersQtty:Double) => {
val totalIQ = (0.3 * ownersQtty + invoicesQtty.head)
invoicesQtty.tail.foldLeft("")(
(data,nextInvoiceQtty) => {
s"0.3 * ${if(data.isEmpty) totalIQ else s"(${data})"} + ${nextInvoiceQtty}"
}
)
})
owners
.join(invoices, invoices("car") === owners("car"), "inner")
// Order before grouping so collect_list gathers invoices in a deterministic order
// (descending qtty) — the fold's result depends on that order.
.orderBy(invoices("qtty").desc)
.groupBy(owners("car"),owners("pcode"))
.agg(
collect_list(invoices("qtty")).as("invoices_qtty"),
first(owners("qtty")).as("owners_qtty")
)
.withColumn("qtty",getQtty($"invoices_qtty",$"owners_qtty"))
.withColumn("qtty_str",getQttyStr($"invoices_qtty",$"owners_qtty"))
.show(false)
Result
+---+-----+-------------+-----------+----+-----------------+
|car|pcode|invoices_qtty|owners_qtty|qtty|qtty_str |
+---+-----+-------------+-----------+----+-----------------+
|A |666 |[15.0, 12.0] |80.0 |23.7|0.3 * 39.0 + 12.0|
|A |888 |[15.0, 12.0] |100.0 |25.5|0.3 * 45.0 + 12.0|
|A |444 |[15.0, 12.0] |50.0 |21.0|0.3 * 30.0 + 12.0|
|A |222 |[15.0, 12.0] |20.0 |18.3|0.3 * 21.0 + 12.0|
+---+-----+-------------+-----------+----+-----------------+

Better Alternatives to EXCEPT Spark Scala

I have been told that EXCEPT is a very costly operation and one should always try to avoid using EXCEPT.
My Use Case -
val myFilter = "rollNo='11' AND class='10'"
val rawDataDf = spark.table(<table_name>)
val myFilteredDataframe = rawDataDf.where(myFilter)
val allOthersDataframe = rawDataDf.except(myFilteredDataframe)
But I am confused, in such use case , what are my alternatives ?
Use left anti join as below-
// Demonstrates replacing except() with a left anti join: keep rows of df that
// have no matching row in df2 when compared on all columns.
val df = spark.range(2).withColumn("name", lit("foo"))
df.show(false)
df.printSchema()
/**
* +---+----+
* |id |name|
* +---+----+
* |0  |foo |
* |1  |foo |
* +---+----+
*
* root
* |-- id: long (nullable = false)
* |-- name: string (nullable = false)
*/
val df2 = df.filter("id=0")
// Joining on every column name makes the anti join behave like set difference
// for these inputs (NOTE(review): unlike except(), an anti join does not dedupe).
df.join(df2, df.columns.toSeq, "leftanti")
.show(false)
/**
* +---+----+
* |id |name|
* +---+----+
* |1  |foo |
* +---+----+
*/

Converting a String which has a list of objects to a dataframe

I have a string such as
str=[{"A":120.0,"B":"0005236"},{"A":10.0,"B":"0005200"},
{"A":12.0,"B":"00042276"},{"A":20.0,"B":"00052000"}]
i am trying to convert it to a data frame...
+-------+--------------+
|packQty|gtin |
+-------+--------------+
|120.0 |0005236 |
|10.0 |0005200 |
|12.0 |00042276 |
|20.0 |00052000 |
+-------+--------------+
i have created a schema as
// Question code: the schema describes a single struct, not an array of structs,
// so from_json parses only the first JSON object of the input string.
val schema=new StructType()
.add("packQty",FloatType)
.add("gtin", StringType)
val df =Seq(str).toDF("testQTY")
val df2=df.withColumn("jsonData",from_json($"testQTY",schema)).select("jsonData.*")
this is returning me a data frame with only one record...
+-------+--------------+
|packQty|gtin |
+-------+--------------+
|120.0 |0005236|
+-------+--------------+
how to modify the schema so that i can get all the records
If it were an array then I could have used the explode() function to get the values, but I am getting this error.
cannot resolve 'explode(`gtins`)' due to data type mismatch: input to
function explode should be array or map type, not string;;
this is how the column is populated
+----------------------------------------------------------------------+
| gtins
|
+----------------------------------------------------------------------
|[{"packQty":120.0,"gtin":"000520"},{"packQty":10.0,"gtin":"0005200"}]
+----------------------------------------------------------------------+
Keep your schema inside ArrayType i.e ArrayType(new StructType().add("packQty",FloatType).add("gtin", StringType)) this will give you null values as schema column names are not matched with json data.
Change schema ArrayType(new StructType().add("packQty",FloatType).add("gtin", StringType)) to ArrayType(new StructType().add("A",FloatType).add("B", StringType)) & After parsing data rename the required columns.
Please check below code.
If column names are matched in both schema & JSON data.
scala> val json = Seq("""[{"A":120.0,"B":"0005236"},{"A":10.0,"B":"0005200"},{"A":12.0,"B":"00042276"},{"A":20.0,"B":"00052000"}]""").toDF("testQTY")
json: org.apache.spark.sql.DataFrame = [testQTY: string]
scala> val schema = ArrayType(StructType(StructField("A",DoubleType,true):: StructField("B",StringType,true) :: Nil))
schema: org.apache.spark.sql.types.ArrayType = ArrayType(StructType(StructField(A,DoubleType,true), StructField(B,StringType,true)),true)
scala> json.withColumn("jsonData",from_json($"testQTY",schema)).select(explode($"jsonData").as("jsonData")).select($"jsonData.A".as("packQty"),$"jsonData.B".as("gtin")).show(false)
+-------+--------+
|packQty|gtin |
+-------+--------+
|120.0 |0005236 |
|10.0 |0005200 |
|12.0 |00042276|
|20.0 |00052000|
+-------+--------+
If column names are not matched in both schema & JSON data.
scala> val json = Seq("""[{"A":120.0,"B":"0005236"},{"A":10.0,"B":"0005200"},{"A":12.0,"B":"00042276"},{"A":20.0,"B":"00052000"}]""").toDF("testQTY")
json: org.apache.spark.sql.DataFrame = [testQTY: string]
scala> val schema = ArrayType(StructType(StructField("packQty",DoubleType,true):: StructField("gtin",StringType,true) :: Nil)) // Column names are not matched with json & schema.
schema: org.apache.spark.sql.types.ArrayType = ArrayType(StructType(StructField(packQty,DoubleType,true), StructField(gtin,StringType,true)),true)
scala> json.withColumn("jsonData",from_json($"testQTY",schema)).select(explode($"jsonData").as("jsonData")).select($"jsonData.*").show(false)
+-------+----+
|packQty|gtin|
+-------+----+
|null |null|
|null |null|
|null |null|
|null |null|
+-------+----+
Alternative way of parsing json string into DataFrame using DataSet
scala> val json = Seq("""[{"A":120.0,"B":"0005236"},{"A":10.0,"B":"0005200"},{"A":12.0,"B":"00042276"},{"A":20.0,"B":"00052000"}]""").toDS // Creating DataSet from json string.
json: org.apache.spark.sql.Dataset[String] = [value: string]
scala> val schema = StructType(StructField("A",DoubleType,true):: StructField("B",StringType,true) :: Nil) // Creating schema.
schema: org.apache.spark.sql.types.StructType = StructType(StructField(A,DoubleType,true), StructField(B,StringType,true))
scala> spark.read.schema(schema).json(json).select($"A".as("packQty"),$"B".as("gtin")).show(false)
+-------+--------+
|packQty|gtin |
+-------+--------+
|120.0 |0005236 |
|10.0 |0005200 |
|12.0 |00042276|
|20.0 |00052000|
+-------+--------+
Specifying one more option -
val data = """[{"A":120.0,"B":"0005236"},{"A":10.0,"B":"0005200"},{"A":12.0,"B":"00042276"},{"A":20.0,"B":"00052000"}]"""
val df2 = Seq(data).toDF("gtins")
df2.show(false)
df2.printSchema()
/**
* +--------------------------------------------------------------------------------------------------------+
* |gtins |
* +--------------------------------------------------------------------------------------------------------+
* |[{"A":120.0,"B":"0005236"},{"A":10.0,"B":"0005200"},{"A":12.0,"B":"00042276"},{"A":20.0,"B":"00052000"}]|
* +--------------------------------------------------------------------------------------------------------+
*
* root
* |-- gtins: string (nullable = true)
*/
// from_json parses the string as an array of structs using a DDL schema string,
// and inline_outer explodes that array into one row per element; the
// 'as (packQty, gtin)' clause renames the struct fields in the same step.
df2.selectExpr("inline_outer(from_json(gtins, 'array<struct<A:double, B:string>>')) as (packQty, gtin)")
.show(false)
/**
* +-------+--------+
* |packQty|gtin |
* +-------+--------+
* |120.0 |0005236 |
* |10.0 |0005200 |
* |12.0 |00042276|
* |20.0 |00052000|
* +-------+--------+
*/

In spark iterate through each column and find the max length

I am new to spark scala and I have following situation as below
I have a table "TEST_TABLE" on cluster(can be hive table)
I am converting that to dataframe
as:
scala> val testDF = spark.sql("select * from TEST_TABLE limit 10")
Now the DF can be viewed as
scala> testDF.show()
COL1|COL2|COL3
----------------
abc|abcd|abcdef
a|BCBDFG|qddfde
MN|1234B678|sd
I want an output like below
COLUMN_NAME|MAX_LENGTH
COL1|3
COL2|8
COL3|6
Is this feasible to do so in spark scala?
Plain and simple:
import org.apache.spark.sql.functions._
val df = spark.table("TEST_TABLE")
// Single pass over the table: one aggregate column max(length(col)) per input column.
df.select(df.columns.map(c => max(length(col(c)))): _*)
You can try in the following way:
import org.apache.spark.sql.functions.{length, max}
import spark.implicits._
val df = Seq(("abc","abcd","abcdef"),
("a","BCBDFG","qddfde"),
("MN","1234B678","sd"),
(null,"","sd")).toDF("COL1","COL2","COL3")
// Cache because the map below runs one aggregation job per column over df.
df.cache()
// For each column, compute max(length(...)) on the driver and collect the
// (name, length) pairs into a two-column result DataFrame.
// NOTE(review): length(null) is null and is ignored by max; an all-null column
// would make .as[Int].first() fail — confirm inputs before relying on this.
val output = df.columns.map(c => (c, df.agg(max(length(df(s"$c")))).as[Int].first())).toSeq.toDF("COLUMN_NAME", "MAX_LENGTH")
+-----------+----------+
|COLUMN_NAME|MAX_LENGTH|
+-----------+----------+
| COL1| 3|
| COL2| 8|
| COL3| 6|
+-----------+----------+
I think it's a good idea to cache the input dataframe df to make the computation faster.
Here is one more way to get the report of column names in vertical
scala> val df = Seq(("abc","abcd","abcdef"),("a","BCBDFG","qddfde"),("MN","1234B678","sd")).toDF("COL1","COL2","COL3")
df: org.apache.spark.sql.DataFrame = [COL1: string, COL2: string ... 1 more field]
scala> df.show(false)
+----+--------+------+
|COL1|COL2 |COL3 |
+----+--------+------+
|abc |abcd |abcdef|
|a |BCBDFG |qddfde|
|MN |1234B678|sd |
+----+--------+------+
scala> val columns = df.columns
columns: Array[String] = Array(COL1, COL2, COL3)
scala> val df2 = columns.foldLeft(df) { (acc,x) => acc.withColumn(x,length(col(x))) }
df2: org.apache.spark.sql.DataFrame = [COL1: int, COL2: int ... 1 more field]
scala> val df3 = df2.select( columns.map(x => max(col(x))):_* )
scala> df3.show(false)
+---------+---------+---------+
|max(COL1)|max(COL2)|max(COL3)|
+---------+---------+---------+
|3 |8 |6 |
+---------+---------+---------+
scala> df3.flatMap( r => { (0 until r.length).map( i => (columns(i),r.getInt(i)) ) } ).show(false)
+----+---+
|_1 |_2 |
+----+---+
|COL1|3 |
|COL2|8 |
|COL3|6 |
+----+---+
scala>
To get the results into Scala collections, say Map()
scala> val result = df3.flatMap( r => { (0 until r.length).map( i => (columns(i),r.getInt(i)) ) } ).as[(String,Int)].collect.toMap
result: scala.collection.immutable.Map[String,Int] = Map(COL1 -> 3, COL2 -> 8, COL3 -> 6)
scala> result
res47: scala.collection.immutable.Map[String,Int] = Map(COL1 -> 3, COL2 -> 8, COL3 -> 6)
scala>

How to encode string values into numeric values in Spark DataFrame

I have a DataFrame with two columns:
df =
Col1 Col2
aaa bbb
ccc aaa
I want to encode String values into numeric values. I managed to do it in this way:
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
// Question code: each StringIndexer is fitted on a single column, so the index
// assigned to "aaa" in Col1 is independent of the index assigned in Col2 —
// exactly the inconsistency described below.
val indexer1 = new StringIndexer()
.setInputCol("Col1")
.setOutputCol("Col1Index")
.fit(df)
val indexer2 = new StringIndexer()
.setInputCol("Col2")
.setOutputCol("Col2Index")
.fit(df)
val indexed1 = indexer1.transform(df)
val indexed2 = indexer2.transform(df)
// One-hot encode the per-column indices; the vectors inherit the same mismatch.
val encoder1 = new OneHotEncoder()
.setInputCol("Col1Index")
.setOutputCol("Col1Vec")
val encoder2 = new OneHotEncoder()
.setInputCol("Col2Index")
.setOutputCol("Col2Vec")
val encoded1 = encoder1.transform(indexed1)
encoded1.show()
val encoded2 = encoder2.transform(indexed2)
encoded2.show()
The problem is that aaa is encoded in different ways in two columns.
How can I encode my DataFrame in order to get the new one correctly encoded, e.g.:
df_encoded =
Col1 Col2
1 2
3 1
Train single Indexer on both columns:
val df = Seq(("aaa", "bbb"), ("ccc", "aaa")).toDF("col1", "col2")
// Fit ONE indexer on the union of both columns so every label (e.g. "aaa")
// maps to the same index no matter which column it appears in.
val indexer = new StringIndexer().setInputCol("col").fit(
df.select("col1").toDF("col").union(df.select("col2").toDF("col"))
)
and apply copy on each column
import org.apache.spark.ml.param.ParamMap
// Re-use the single fitted model for each column by copying it with new
// input/output column params; both copies share one label-to-index mapping.
val result = Seq("col1", "col2").foldLeft(df){
(df, col) => indexer
.copy(new ParamMap()
.put(indexer.inputCol, col)
.put(indexer.outputCol, s"${col}_idx"))
.transform(df)
}
result.show
// +----+----+--------+--------+
// |col1|col2|col1_idx|col2_idx|
// +----+----+--------+--------+
// | aaa| bbb|     0.0|     1.0|
// | ccc| aaa|     2.0|     0.0|
// +----+----+--------+--------+
You can make your own transformer; the example below is my PySpark code.
Train a transformer model as clf
sindex_pro = StringIndexer(inputCol='StringCol',outputCol='StringCol_c',stringOrderType="frequencyDesc",handleInvalid="keep").fit(province_df)
define the self transformer load the clf
from pyspark.sql.functions import col
from pyspark.ml import Transformer
from pyspark.sql import DataFrame
class SelfSI(Transformer):
    """Wraps a fitted StringIndexer so it can be applied to an arbitrarily
    named column, by renaming that column to/from the column name the
    indexer was trained on ('StringCol')."""

    def __init__(self, clf, col_name):
        super(SelfSI, self).__init__()
        self.clf = clf            # the fitted StringIndexer model
        self.col_name = col_name  # the real input column to index

    def rename_col(self, df, invers=False):
        # Forward pass: rename <col_name> -> 'StringCol' before indexing.
        # Inverse pass (invers=True): rename 'StringCol' back to <col_name>
        # and the indexer output 'StringCol_c' to '<col_name>_c'.
        source = 'StringCol'
        target = self.col_name
        if invers:
            df = df.withColumnRenamed(source, target)
            source = target + '_c'
            target = 'StringCol_c'
        return df.withColumnRenamed(target, source)

    def _transform(self, df: DataFrame) -> DataFrame:
        # Rename in, run the wrapped indexer, rename back out.
        renamed = self.rename_col(df)
        indexed = self.clf.transform(renamed)
        return self.rename_col(indexed, invers=True)
define the model by your need transforming column name
pro_si = SelfSI(sindex_pro,'pro_name')
pro_si.transform(df_or)
#or pipline
model = Pipeline(stages=[pro_si,pro_si2]).fit(df_or)
model.transform(df_or)
#result like
province_name|city_name|province_name_c|city_name_c|
+-------------+---------+---------------+-----------+
| 河北| 保定| 23.0| 18.0|
| 河北| 张家| 23.0| 213.0|
| 河北| 承德| 23.0| 126.0|
| 河北| 沧州| 23.0| 6.0|
| 河北| 廊坊| 23.0| 26.0|
| 北京| 北京| 13.0| 107.0|
| 天津| 天津| 10.0| 85.0|
| 河北| 石家| 23.0| 185.0|