SparkSQL: Avg based on a column after GroupBy - scala

I have an RDD of student grades. I need to group them by the first column, university, and then show the average number of students per course for each university, like this. What is the easiest way to write this query?
+----------+---------------+
|university|avg of students|
+----------+---------------+
|       MIT|              3|
| Cambridge|           2.66|
+----------+---------------+
Here is the dataset.
case class grade(university: String, courseId: Int, studentId: Int, grade: Double)
val grades = List(
grade("Cambridge", 1, 1001, 4),
grade("Cambridge", 1, 1004, 4),
grade("Cambridge", 2, 1006, 3.5),
grade("Cambridge", 2, 1004, 3.5),
grade("Cambridge", 2, 1002, 3.5),
grade("Cambridge", 3, 1006, 3.5),
grade("Cambridge", 3, 1007, 5),
grade("Cambridge", 3, 1008, 4.5),
grade("MIT", 1, 1001, 4),
grade("MIT", 1, 1002, 4),
grade("MIT", 1, 1003, 4),
grade("MIT", 1, 1004, 4),
grade("MIT", 1, 1005, 3.5),
grade("MIT", 2, 1009, 2))

1) First groupBy university
2) then get course count per university
3) then groupBy courseId
4) then get student count per course
grades.groupBy(_.university).map { case (k, v) =>
val courseCount = v.map(_.courseId).distinct.length
val studentCountPerCourse = v.groupBy(_.courseId).map { case (k, v) => v.length }.sum
k -> (studentCountPerCourse.toDouble / courseCount.toDouble)
}
Scala REPL
scala> val grades = List(
grade("Cambridge", 1, 1001, 4),
grade("Cambridge", 1, 1004, 4),
grade("Cambridge", 2, 1006, 3.5),
grade("Cambridge", 2, 1004, 3.5),
grade("Cambridge", 2, 1002, 3.5),
grade("Cambridge", 3, 1006, 3.5),
grade("Cambridge", 3, 1007, 5),
grade("Cambridge", 3, 1008, 4.5),
grade("MIT", 1, 1001, 4),
grade("MIT", 1, 1002, 4),
grade("MIT", 1, 1003, 4),
grade("MIT", 1, 1004, 4),
grade("MIT", 1, 1005, 3.5),
grade("MIT", 2, 1009, 2))
// grades: List[grade] = List(...)
scala> grades.groupBy(_.university).map { case (k, v) =>
val courseCount = v.map(_.courseId).distinct.length
val studentCountPerCourse = v.groupBy(_.courseId).map { case (k, v) => v.length }.sum
k -> (studentCountPerCourse.toDouble / courseCount.toDouble)
}
// res2: Map[String, Double] = Map("MIT" -> 3.0, "Cambridge" -> 2.6666666666666665)

// Note: this averages the courseId values per university (17 / 8 = 2.125 for Cambridge),
// not the number of students per course.
gradesRdd.map { case Grade(university, courseId, _, _) => (university, courseId) }
  .mapValues(x => (x, 1))
  .reduceByKey((x, y) => (x._1 + y._1, x._2 + y._2))
  .mapValues(y => 1.0 * y._1 / y._2).collect
// res73: Array[(String, Double)] = Array((Cambridge,2.125), (MIT,1.1666666666666667))
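For comparison, here is a minimal sketch of the two-stage aggregation the question asks for: first count the rows per (university, courseId) pair, then average those counts per university. It uses plain Scala collections rather than Spark so that it runs on its own. The object name and the `Grade` case class are illustrative stand-ins, and on an RDD the two `groupBy` steps would presumably be expressed as `reduceByKey` calls (not shown or tested here).

```scala
object AvgStudentsPerCourse {
  // Hypothetical record type mirroring the question's `grade` case class.
  final case class Grade(university: String, courseId: Int, studentId: Int, grade: Double)

  def avgStudentsPerCourse(grades: Seq[Grade]): Map[String, Double] = {
    // Stage 1: count the rows for each (university, courseId) key.
    val perCourse: Map[(String, Int), Int] =
      grades
        .groupBy(g => (g.university, g.courseId))
        .map { case (key, rows) => key -> rows.size }

    // Stage 2: average those per-course counts for each university.
    perCourse.toSeq
      .groupBy { case ((university, _), _) => university }
      .map { case (university, entries) =>
        val counts = entries.map(_._2)
        university -> counts.sum.toDouble / counts.size
      }
  }
}
```

Each row is counted once per course, so a student enrolled in two courses contributes to both counts, which matches the expected output in the question.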

Related

removing duplicate cycles of directed graph from a list in scala

I have the collection of lists shown below.
List(4, 0, 1, 2, 4)
List(4, 0, 1, 3, 4)
List(4, 0, 2, 3, 4)
List(4, 3, 2, 3, 4)
List(4, 3, 4, 3, 4)
List(0, 1, 2, 4, 0)
List(0, 1, 3, 4, 0)
List(0, 2, 3, 4, 0)
List(1, 2, 4, 0, 1)
List(1, 3, 4, 0, 1)
List(3, 4, 0, 1, 3)
List(3, 4, 0, 2, 3)
List(3, 2, 3, 2, 3)
List(3, 4, 3, 2, 3)
List(3, 2, 3, 4, 3)
List(3, 4, 3, 4, 3)
List(2, 3, 4, 0, 2)
List(2, 4, 0, 1, 2)
List(2, 3, 2, 3, 2)
List(2, 3, 4, 3, 2)
These lists are the individual cycles of length 4 in a directed graph. I want to count the unique cycles in the given lists that do not contain a smaller cycle. For example, List(4,0,1,2,4) and List(0,1,2,4,0) form the same cycle. Another example: List(2,3,2,3,2) iterates over 2 and 3 only and does not form a cycle of length 4.
From this collection we can say that List(0, 1, 2, 4, 0), List(0, 1, 3, 4, 0), and List(0, 2, 3, 4, 0) are the unique paths, so the total would be 3.
List(0, 1, 2, 4, 0) and List(4,0,1,2,4) are the same cycle, so we take only one of them.
I tried to use filter but was unable to find any logic to do this.
The following should work:
val input = List(List(4, 0, 1, 2, 4),List(4, 0, 1, 3, 4) ,List(4, 0, 2, 3, 4) ,List(4, 3, 2, 3, 4) ,List(4, 3, 4, 3, 4) ,
List(0, 1, 2, 4, 0) ,List(0, 1, 3, 4, 0) ,List(0, 2, 3, 4, 0) ,List(1, 2, 4, 0, 1) ,List(1, 3, 4, 0, 1) ,List(3, 4, 0, 1, 3) ,
List(3, 4, 0, 2, 3) ,List(3, 2, 3, 2, 3) ,List(3, 4, 3, 2, 3) ,List(3, 2, 3, 4, 3) ,List(3, 4, 3, 4, 3) ,
List(2, 3, 4, 0, 2) ,List(2, 4, 0, 1, 2) ,List(2, 3, 2, 3, 2), List(2, 3, 4, 3, 2))
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

// Rotate a path so that it starts at its smallest vertex.
def rotateArray(xs: List[Int]): List[Int] =
  xs.splitAt(xs.indexOf(xs.min)) match { case (x, y) => y ++ x }

val uniquePaths = mutable.Set[List[Int]]()
val indexes = ListBuffer[Int]()
input.zipWithIndex.foreach { case (list, index) =>
  if (list.head == list.last) {
    val list1 = rotateArray(list.tail)
    // Keep only cycles that visit four distinct vertices.
    if (list1.toSet.size == 4) {
      if (!uniquePaths.contains(list1))
        indexes.append(index)
      uniquePaths.add(list1)
    }
  }
}
indexes.foreach(x => println(input(x)))
...freehand red cycles to the rescue.
Here are two different cycles on the same four vertices, which shows that sorting is insufficient:
The sketch assumes that all the points are vertices of a fully connected graph (edges omitted), and is supposed to show that the cycles [0, 1, 2, 3, 0] and [0, 2, 1, 3, 0] are not the same, despite the fact that if you sort the sets, you obtain [0, 1, 2, 3] in both cases.
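The point about sorting can be checked with a small sketch. The object and value names below are illustrative only. It compares the two cycles from the paragraph above: their sorted vertex lists are identical, but once each cycle is rotated to start at its smallest vertex the two stay different, while a plain rotation of the first cycle collapses onto it.

```scala
object CycleOrderDemo {
  // Rotate a closed cycle (first == last) so that it starts at its smallest vertex.
  def canonical(cycle: List[Int]): List[Int] = {
    val open = cycle.init                          // drop the repeated end vertex
    val (before, from) = open.splitAt(open.indexOf(open.min))
    val rotated = from ++ before
    rotated :+ rotated.head                        // close the cycle again
  }

  val c1 = List(0, 1, 2, 3, 0)
  val c2 = List(2, 1, 3, 0, 2)  // a rotation of 0 -> 2 -> 1 -> 3 -> 0
  val c3 = List(1, 2, 3, 0, 1)  // a rotation of c1

  val sameVertices = c1.init.sorted == c2.init.sorted   // true: sorting cannot tell them apart
  val sameCycle    = canonical(c1) == canonical(c2)     // false: the edge order differs
  val sameRotation = canonical(c1) == canonical(c3)     // true: same cycle, different start
}
```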
Here is what might work instead:
Throw away all the paths which go through the same vertex more than once by filtering out all the paths that do not consist of four distinct elements.
Rotate the path representation into canonical form (e.g. starting at the vertex with minimum position).
Compute the set of canonical representations, retaining only the unique paths.
Here is what the implementation might look like:
def canonicalize(cycle: List[Int]) = {
val t = cycle.tail
val (b, a) = t.splitAt(t.zipWithIndex.minBy(_._1)._2)
val ab = (a ++ b)
ab :+ (ab.head)
}
val cycles = List(
List(4, 0, 1, 2, 4),
List(4, 0, 1, 3, 4),
List(4, 0, 2, 3, 4),
List(4, 3, 2, 3, 4),
List(4, 3, 4, 3, 4),
List(0, 1, 2, 4, 0),
List(0, 1, 3, 4, 0),
List(0, 2, 3, 4, 0),
List(1, 2, 4, 0, 1),
List(1, 3, 4, 0, 1),
List(3, 4, 0, 1, 3),
List(3, 4, 0, 2, 3),
List(3, 2, 3, 2, 3),
List(3, 4, 3, 2, 3),
List(3, 2, 3, 4, 3),
List(3, 4, 3, 4, 3),
List(2, 3, 4, 0, 2),
List(2, 4, 0, 1, 2),
List(2, 3, 2, 3, 2),
List(2, 3, 4, 3, 2)
)
val unique = cycles.filter(_.toSet.size == 4).map(canonicalize).toSet
unique foreach println
Output:
List(0, 1, 2, 4, 0)
List(0, 1, 3, 4, 0)
List(0, 2, 3, 4, 0)
Line-by-line example of what canonicalize does:
tail removes the duplicate vertex: [2, 1, 0, 4, 2] -> [1, 0, 4, 2]
splitAt finds the minimum vertex and cuts the list: [1, 0, 4, 2] -> ([1], [0, 4, 2])
a ++ b rebuilds the rotated list: [0, 4, 2, 1]
:+ appends the minimum vertex to the end: [0, 4, 2, 1, 0]
1) Drop the last element from each list (it is redundant).
2) Rotate each list so that it starts at the smallest ID.
3) Sort the loops by length, shortest first.
4) Now you can use lexical matching: if loop[i] contains any of loop[0..i-1], drop it.

Spark RDDs: How to join value in a Map to a row in an RDD

I have a csv file that I am loading into Spark as an RDD with:
val home_rdd = sc.textFile("hdfs://path/to/home_data.csv")
val home_parsed = home_rdd.map(row => row.split(",").map(_.trim))
val home_header = home_parsed.first
val home_data = home_parsed.filter(_(0) != home_header(0))
home_data then is:
scala> home_data
res17: org.apache.spark.rdd.RDD[Array[String]] = MapPartitionsRDD[3] at filter at <console>:30
scala> home_data.take(3)
res20: Array[Array[String]] = Array(Array("7129300520", "20141013T000000", 221900, "3", "1", 1180, 5650, "1", 0, 0, 3, 7, 1180, 0, 1955, 0, "98178", 47.5112, -122.257, 1340, 5650), Array("6414100192", "20141209T000000", 538000, "3", "2.25", 2570, 7242, "2", 0, 0, 3, 7, 2170, 400, 1951, 1991, "98125", 47.721, -122.319, 1690, 7639), Array("5631500400", "20150225T000000", 180000, "2", "1", 770, 10000, "1", 0, 0, 3, 6, 770, 0, 1933, 0, "98028", 47.7379, -122.233, 2720, 8062))
I also have a CSV that maps zip codes to neighborhoods. I load it as an RDD and collect it into a Map[String, String] with:
val zip_rdd = sc.textFile("hdfs://path/to/zipcodes.csv")
val zip_parsed = zip_rdd.map(row => row.split(",").map(_.trim))
val zip_header = zip_parsed.first
val zip_data = zip_parsed.filter(_(0) != zip_header(0))
val zip_map = zip_data.map(row => (row(0), row(1))).collectAsMap
val zip_ind = home_header.indexOf("zipcode") //to get the zipcode column in home_data
Where:
scala> zip_map.take(3)
res21: scala.collection.Map[String,String] = Map(98151 -> Seattle, 98052 -> Redmond, 98104 -> Seattle)
What I am trying to do next is iterate through home_data and use the zipcode value in each row (at zip_ind = 16) to fetch the neighborhood value from zip_map and append that value to the end of the row.
val zip_processed = home_data.map(row => row :+ zip_map.get(row(zip_ind)))
But every lookup in zip_map fails, so it only appends None to the end of each row in home_data:
scala> zip_processed.take(3)
res19: Array[Array[java.io.Serializable]] = Array(Array("7129300520", "20141013T000000", 221900, "3", "1", 1180, 5650, "1", 0, 0, 3, 7, 1180, 0, 1955, 0, "98178", 47.5112, -122.257, 1340, 5650, None), Array("6414100192", "20141209T000000", 538000, "3", "2.25", 2570, 7242, "2", 0, 0, 3, 7, 2170, 400, 1951, 1991, "98125", 47.721, -122.319, 1690, 7639, None), Array("5631500400", "20150225T000000", 180000, "2", "1", 770, 10000, "1", 0, 0, 3, 6, 770, 0, 1933, 0, "98028", 47.7379, -122.233, 2720, 8062, None))
I am trying to debug this, but am not sure why it's failing at zip_map.get(row(zip_ind)).
I am fairly green with Scala so maybe I am making some bad assumptions, but trying to figure out how to better understand what is happening in the map function.
Map.get() returns None when there is no match. You can use getOrElse to append the Map value with a fall-back:
val home_data = sc.parallelize(Array(
Array("7129300520", "20141013T000000", 221900, "3", "1", 1180, 5650, "1", 0, 0, 3, 7, 1180, 0, 1955, 0, "98178", 47.5112, -122.257, 1340, 5650),
Array("6414100192", "20141209T000000", 538000, "3", "2.25", 2570, 7242, "2", 0, 0, 3, 7, 2170, 400, 1951, 1991, "98125", 47.721, -122.319, 1690, 7639),
Array("5631500400", "20150225T000000", 180000, "2", "1", 770, 10000, "1", 0, 0, 3, 6, 770, 0, 1933, 0, "98028", 47.7379, -122.233, 2720, 8062)
))
val zip_ind = 16
val zip_map: Map[String, String] = Map("98178" -> "A", "98028" -> "B")
val zip_processed = home_data.map(row => row :+ zip_map.getOrElse(row(zip_ind).toString, "N/A"))
zip_processed.collect
// res1: Array[Array[Any]] = Array(
// Array(7129300520, 20141013T000000, 221900, 3, 1, 1180, 5650, 1, 0, 0, 3, 7, 1180, 0, 1955, 0, 98178, 47.5112, -122.257, 1340, 5650, A),
// Array(6414100192, 20141209T000000, 538000, 3, 2.25, 2570, 7242, 2, 0, 0, 3, 7, 2170, 400, 1951, 1991, 98125, 47.721, -122.319, 1690, 7639, N/A),
// Array(5631500400, 20150225T000000, 180000, 2, 1, 770, 10000, 1, 0, 0, 3, 6, 770, 0, 1933, 0, 98028, 47.7379, -122.233, 2720, 8062, B)
// )
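One possible reason for the failed lookups, offered as an assumption rather than a diagnosis: the REPL output in the question prints quotes around some values, which might mean the quote characters are part of the parsed strings. In that case the key `"98178"` (with quotes) would never match the map key `98178`. The sketch below uses plain Scala arrays instead of an RDD. The names `ZipLookupDemo`, `normalize` and `appendNeighborhood`, and the sample rows, are hypothetical.

```scala
object ZipLookupDemo {
  // Hypothetical lookup table; the column index is an assumption for this sample.
  val zipMap: Map[String, String] = Map("98178" -> "A", "98028" -> "B")
  val zipInd = 1

  // Strip surrounding whitespace and double quotes before the lookup.
  def normalize(key: String): String =
    key.trim.stripPrefix("\"").stripSuffix("\"")

  // Append the neighborhood, or "N/A" when the zip code is unknown.
  def appendNeighborhood(row: Array[String]): Array[String] =
    row :+ zipMap.getOrElse(normalize(row(zipInd)), "N/A")
}
```

With this helper, a row whose zip code field is the quoted string `"98178"` gets `A` appended, while `zipMap.get` on the same quoted string returns `None`.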

Finding values greater than * in a map list

My data is a Map[String, List[Int]]. The String is the key ("SK1", "SK2", etc.), and the value is a list of numbers from 0-9.
Below is my current method, which shows the full lists. How do I edit it to show only the "SK*"s whose value is greater than that of the selected "SK*"? The value of a stock is the last element of its list, which I already have a function to find (it is used by the handleFive menu option). To clarify: I need to find the last element (I already have that function) and then display only the stocks greater than the selected stock.
Handler for the menu options
def handleFive(): Boolean = {
mnuShowSingleDataStock(currentStockLevel)
true
}
def handleSeven(): Boolean = {
mnuShowPointsForStock(allStockLevel)
true
}
Functions that invoke and interact with the user
// Returns a single result, not a list
def mnuShowSingleDataStock(f: (String) => (String,Int)) = {
print("Stock > ")
val data = f(readLine)
println(s"${data._1}: ${data._2}")
}
//Returns a list value
def mnuShowPointsForStock(f: (String) => (String,List[Int])) = {
print("Stock > ")
val data = f(readLine)
println(s"${data._1}: ${data._2}")
}
I am not sure how to edit this. Currently it shows ALL of the values in the list; I only want to return values greater than the selected value.
//Show last element in the list, most current
def currentStockLevel (stock: String): (String, Int) = {
(stock, mapdata.get (stock).map(findLast(_)).getOrElse(0))
}
//Unsure how to change this to only return values greater than the selected one, not everything
def currentStockLevel (stock: String): (String, List[Int]) = {
(stock, mapdata.get (stock).map(findLast(_)).getOrElse(0))
}
My current mapped list - THIS IS MAPDATA
val mapdata = Map(
"SK1" -> List(9, 7, 2, 0, 7, 3, 7, 9, 1, 2, 8, 1, 9, 6, 5, 3, 2, 2, 7, 2, 8, 5, 4, 5, 1, 6, 5, 2, 4, 1),
"SK2" -> List(0, 7, 6, 3, 3, 3, 1, 6, 9, 2, 9, 7, 8, 7, 3, 6, 3, 5, 5, 2, 9, 7, 3, 4, 6, 3, 4, 3, 4, 1),
"SK3" -> List(8, 7, 1, 8, 0, 5, 8, 3, 5, 9, 7, 5, 4, 7, 9, 8, 1, 4, 6, 5, 6, 6, 3, 6, 8, 8, 7, 4, 0, 6),
"SK4" -> List(2, 9, 5, 7, 0, 8, 6, 6, 7, 9, 0, 1, 3, 1, 6, 0, 0, 1, 3, 8, 5, 4, 0, 9, 7, 1, 4, 5, 2, 8),
"SK5" -> List(2, 6, 8, 0, 3, 5, 5, 2, 5, 9, 4, 5, 3, 5, 7, 8, 8, 2, 5, 9, 3, 8, 6, 7, 8, 7, 4, 1, 2, 3),
"SK6" -> List(2, 7, 5, 9, 1, 9, 8, 4, 1, 7, 3, 7, 0, 8, 4, 5, 9, 2, 4, 4, 8, 7, 9, 2, 2, 7, 9, 1, 6, 9),
"SK7" -> List(6, 9, 5, 0, 0, 0, 0, 5, 8, 3, 8, 7, 1, 9, 6, 1, 5, 3, 4, 7, 9, 5, 5, 9, 1, 4, 4, 0, 2, 0),
"SK8" -> List(2, 8, 8, 3, 1, 1, 0, 8, 5, 9, 0, 3, 1, 6, 8, 7, 9, 6, 7, 7, 0, 9, 5, 2, 5, 0, 2, 1, 8, 6),
"SK9" -> List(7, 1, 8, 8, 4, 4, 2, 2, 7, 4, 0, 6, 9, 5, 5, 4, 9, 1, 8, 6, 3, 4, 8, 2, 7, 9, 7, 2, 6, 6)
)
The Map[String, List[Int]] type has a filterKeys(f: String => Boolean) method, in order to keep only the keys satisfying a given predicate.
A possible solution would be
import scala.util.Try

// get int value from stock if of the form "SK<int>"
def stockInt(stock: String): Option[Int] =
  Try(stock.drop(2).toInt).filter(_ => stock.startsWith("SK")).toOption

// we keep the keys in the return, so that you do not get unordered results
// (order is not assured by Map)
def currentStockLevel(stock: String): (String, Map[String, Int]) = {
  val maybeN = stockInt(stock)
  def isGreater(other: String) = (for {
    o <- stockInt(other)
    n <- maybeN
  } yield o > n).getOrElse(true) // if any key is not in the form of SK*, assume it is greater than the original stock
  (
    stock,
    // .toMap materializes the result, since newer Scala versions return a view here
    mapdata.filterKeys(isGreater(_)).mapValues(findLast(_)).toMap
  )
}
Another possibility, if you are sure to have only "SK" keys, is to use SortedMap, which uses a SortedSet for its keys, so that you are sure to have key-value pairs ordered as you want them to be.
In that case, a solution would be
import scala.collection.immutable.SortedMap

//put all values in mapdata in a SortedMap
val sortedMap = SortedMap[String, List[Int]]() ++ mapdata
def currentStockLevel(stock: String): (String, List[Int]) = {
  (
    stock,
    // compare on the key (_._1); the entries are (key, value) pairs
    sortedMap.dropWhile(_._1 <= stock).toList.map(_._2).map(findLast(_))
  )
}
EDIT (after comments on what is expected as a return):
If I understand well what you are trying to do, you want to filter on the values rather than the keys. This is not a problem, Map also has a filter(p: ((K, V)) => Boolean): Map[K, V] method to do just that:
def currentHigherStockLevel(stock: String): Map[String, Int] = {
  val current = mapdata.get(stock).map(findLast).getOrElse(0) // if stock is not in the keySet, we keep all keys, by keeping those greater than 0.
  mapdata.mapValues(findLast).filter {
    case (_, last) => last > current // `val` is a reserved word, so the binding is named `last`
  }.toMap
}
This returns a Map[String, Int] where the values are the last ones that are greater than the one given as parameter (we keep their keys because they will probably be useful).
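As a worked example of filtering on the last value, here is a self-contained sketch. The sample map keeps only the final three readings of a few lists from mapdata to stay short, `findLast` is a stand-in for the asker's helper (assumed to return the final element of the list), and `higherThan` is a hypothetical name.

```scala
object HigherStockDemo {
  // Shortened sample: only the final three readings of some lists from mapdata.
  val sample: Map[String, List[Int]] = Map(
    "SK1" -> List(2, 4, 1),
    "SK3" -> List(4, 0, 6),
    "SK4" -> List(5, 2, 8),
    "SK5" -> List(1, 2, 3),
    "SK7" -> List(0, 2, 0)
  )

  // Stand-in for the asker's findLast helper (assumed to return the final element).
  def findLast(xs: List[Int]): Int = xs.last

  // Keep only the stocks whose last value exceeds that of the selected stock.
  def higherThan(stock: String, data: Map[String, List[Int]]): Map[String, Int] = {
    val current = data.get(stock).map(findLast).getOrElse(0)
    data
      .map { case (key, values) => key -> findLast(values) }
      .filter { case (_, last) => last > current }
  }
}
```

For the selected stock "SK5" (last value 3), only "SK3" (6) and "SK4" (8) remain in the result.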
If the key strings are things like "SK9" and "SK10", then you have to cut the digits out, convert them to Int, and compare/filter on those. But if your keys are kept in a completely consistent, zero-padded format ("SK001", "SK002" ... "SK009", "SK010" ... "SK099", "SK100", etc.), then you can use simple string comparisons to filter for just what you want.
mapdata.filterKeys(_ >= stock).values // an Iterable[List[Int]]
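The difference between string order and numeric order for unpadded keys can be seen in a short sketch. The helper name `stockNumber` and the sample keys are illustrative.

```scala
object StockKeyOrderDemo {
  // Parse the numeric suffix of keys such as "SK10"; None for anything else.
  def stockNumber(key: String): Option[Int] =
    if (key.startsWith("SK")) scala.util.Try(key.drop(2).toInt).toOption else None

  val keys = List("SK9", "SK10", "SK2")

  val byString = keys.sorted               // lexicographic: "SK10" sorts before "SK2"
  val byNumber = keys.sortBy(stockNumber)  // ordered by the numeric suffix

  // As plain strings, "SK10" is NOT greater than "SK9".
  val stringCompare  = "SK10" > "SK9"
  val numericCompare = stockNumber("SK10").exists(n => stockNumber("SK9").exists(n > _))
}
```

Here `byString` is `List("SK10", "SK2", "SK9")` while `byNumber` is `List("SK2", "SK9", "SK10")`, which is why unpadded keys need the numeric conversion.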

Creating a list of ints from a txt file

I have an external list in a txt file. I need to take the first string on each line and use it as a key (that part works), and then I need the list of numbers that follow it. However, I only get the first number. What have I done wrong?
So the current output is SK1, 9 - SK2, 0 etc., when I need the full list, not just the first number.
I am using Scala in IntelliJ.
/**
* Created by Andre on 10/11/2016.
*/
import scala.io.Source
import scala.io.StdIn.readInt
import scala.io.StdIn.readLine
import scala.collection.immutable.ListMap
object StockMarket extends App{
// APPLICATION LOGIC
// reads the data from text file
val mapdata = readFile("data.txt")
// print data to check it's been read in correctly
println(mapdata)
// *******************************************************************************************************************
// UTILITY FUNCTIONS
// reads data file - comma separated file
def readFile(filename: String): Map[String, Int] = {
// create buffer to build up map as we read each line
var mapBuffer: Map[String, Int] = Map()
try {
for (line <- Source.fromFile(filename).getLines()) { // for each line
val splitline = line.split(",").map(_.trim).toList // split line at , and convert to List
// add element to map buffer
// splitline is line from file as List, e.g. List(Bayern Munich, 24)
// use head as key
// tail is a list, but need just the first (only in this case) element, so use head of tail and convert to int
mapBuffer = mapBuffer ++ Map(splitline.head -> splitline.tail.head.toInt)
}
} catch {
case ex: Exception => println("Sorry, an exception happened.")
}
mapBuffer
}
}
My external List
SK1, 9, 7, 2, 0, 7, 3, 7, 9, 1, 2, 8, 1, 9, 6, 5, 3, 2, 2, 7, 2, 8, 5, 4, 5, 1, 6, 5, 2, 4, 1
SK2, 0, 7, 6, 3, 3, 3, 1, 6, 9, 2, 9, 7, 8, 7, 3, 6, 3, 5, 5, 2, 9, 7, 3, 4, 6, 3, 4, 3, 4, 1
SK4, 2, 9, 5, 7, 0, 8, 6, 6, 7, 9, 0, 1, 3, 1, 6, 0, 0, 1, 3, 8, 5, 4, 0, 9, 7, 1, 4, 5, 2, 8
SK5, 2, 6, 8, 0, 3, 5, 5, 2, 5, 9, 4, 5, 3, 5, 7, 8, 8, 2, 5, 9, 3, 8, 6, 7, 8, 7, 4, 1, 2, 3
SK6, 2, 7, 5, 9, 1, 9, 8, 4, 1, 7, 3, 7, 0, 8, 4, 5, 9, 2, 4, 4, 8, 7, 9, 2, 2, 7, 9, 1, 6, 9
SK7, 6, 9, 5, 0, 0, 0, 0, 5, 8, 3, 8, 7, 1, 9, 6, 1, 5, 3, 4, 7, 9, 5, 5, 9, 1, 4, 4, 0, 2, 0
SK8, 2, 8, 8, 3, 1, 1, 0, 8, 5, 9, 0, 3, 1, 6, 8, 7, 9, 6, 7, 7, 0, 9, 5, 2, 5, 0, 2, 1, 8, 6
SK9, 7, 1, 8, 8, 4, 4, 2, 2, 7, 4, 0, 6, 9, 5, 5, 4, 9, 1, 8, 6, 3, 4, 8, 2, 7, 9, 7, 2, 6, 6
Here is your code with minimal changes:
// I split it on two functions just to facilitate testing:
def readFile(filename: String): Map[String, List[Int]] = {
processInput(Source.fromFile(filename).getLines)
}
def processInput(lines: Iterator[String]): Map[String, List[Int]] = {
var mapBuffer: Map[String, List[Int]] = Map()
try {
for (line <- lines) {
val splitline = line.split(",").map(_.trim).toList
// here instead of taking .tail.head, we map over the tail (all numbers):
mapBuffer = mapBuffer + (splitline.head -> splitline.tail.map(_.toInt))
}
} catch {
case ex: Exception => println("Sorry, an exception happened.")
}
mapBuffer
}
And here is a solution that I believe is more idiomatic Scala:
import scala.util.Try
def processInput(lines: Iterator[String]): Map[String, List[Int]] = {
Try {
lines.foldLeft( Map[String, List[Int]]() ) { (acc, line) =>
val splitline = line.split(",").map(_.trim).toList
acc.updated(splitline.head, splitline.tail.map(_.toInt))
}
}.getOrElse {
println("Sorry, an exception happened.")
Map()
}
}
The main differences are:
- not using var
- not using a mutable Map (by the way, you don't need a var to mutate it)
- using foldLeft to iterate and accumulate the Map instead of for
- using scala.util.Try instead of try-catch
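A further variant, sketched here as an alternative rather than as the answerer's code, maps each line to a key/value pair and builds the Map in one step with `toMap`. The object and function names are hypothetical, and the sketch deliberately has no error handling: a non-numeric field would throw a NumberFormatException.

```scala
object ParseStockLines {
  // Turn "KEY, n1, n2, ..." lines into key -> list of ints.
  def parse(lines: Iterator[String]): Map[String, List[Int]] =
    lines.map { line =>
      val fields = line.split(",").map(_.trim).toList
      // head is the key, the whole tail (not just its head) holds the numbers
      fields.head -> fields.tail.map(_.toInt)
    }.toMap
}
```

Usage with two shortened lines in the same format as the data file: `ParseStockLines.parse(Iterator("SK1, 9, 7, 2", "SK2, 0, 7, 6"))` gives `Map("SK1" -> List(9, 7, 2), "SK2" -> List(0, 7, 6))`.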

How do I accumulate results without using a mutable ArrayBuffer?

The code at the end of this question replaces the zeros with the numbers from 1 to 9 that are not already present, using each number once. For a given sequence of numbers, List(0, 0, 1, 5, 0, 0, 8, 0, 0), it returns the following result. There are 720 permutations in total.
List(2, 3, 1, 5, 4, 6, 8, 7, 9)
List(2, 3, 1, 5, 4, 6, 8, 9, 7)
List(2, 3, 1, 5, 4, 7, 8, 6, 9)
List(2, 3, 1, 5, 4, 7, 8, 9, 6)
List(2, 3, 1, 5, 4, 9, 8, 6, 7)
List(2, 3, 1, 5, 4, 9, 8, 7, 6)
List(2, 3, 1, 5, 6, 4, 8, 7, 9)
...
My question is: how do I convert my code so that it does NOT use an ArrayBuffer (coll) as temporary storage, and the final result is returned from the function (search0) instead?
Thanks
/lim/
import collection.mutable.ArrayBuffer
object ScratchPad extends App {
def search(l : List[Int]) : ArrayBuffer[List[Int]] = {
def search0(la : List[Int], pos : Int, occur : List[Int], coll : ArrayBuffer[List[Int]]) : Unit = {
if (pos == l.length) {println(la); coll += la }
val bal = (1 to 9) diff occur
if (!bal.isEmpty) {
la(pos) match {
case 0 => bal map { x => search0(la.updated(pos, x), pos + 1, x :: occur, coll)}
case n => if (occur contains n) Nil else search0(la, pos + 1, n :: occur, coll)
}
}
}
val coll = ArrayBuffer[List[Int]]()
search0(l, 0, Nil, coll)
coll
}
println(search(List(0, 0, 1, 5, 0, 0, 8, 0, 0)).size)
}
Here is a shorter solution using immutable collection:
scala> def search(xs: Seq[Int])(implicit ys: Seq[Int] = (1 to 9).diff(xs)): Seq[Seq[Int]] = ys match {
| case Seq() => Seq(xs)
| case _ => ys.flatMap(y => search(xs.updated(xs.indexOf(0), y))(ys.diff(Seq(y))))
| }
search: (xs: Seq[Int])(implicit ys: Seq[Int])Seq[Seq[Int]]
scala> search(List(0, 0, 1, 5, 0, 0, 8, 0, 0)).size
res0: Int = 720
scala> search(List(0, 0, 1, 5, 0, 0, 8, 0, 0)) take 10 foreach println
List(2, 3, 1, 5, 4, 6, 8, 7, 9)
List(2, 3, 1, 5, 4, 6, 8, 9, 7)
List(2, 3, 1, 5, 4, 7, 8, 6, 9)
List(2, 3, 1, 5, 4, 7, 8, 9, 6)
List(2, 3, 1, 5, 4, 9, 8, 6, 7)
List(2, 3, 1, 5, 4, 9, 8, 7, 6)
List(2, 3, 1, 5, 6, 4, 8, 7, 9)
List(2, 3, 1, 5, 6, 4, 8, 9, 7)
List(2, 3, 1, 5, 6, 7, 8, 4, 9)
List(2, 3, 1, 5, 6, 7, 8, 9, 4)
An even shorter solution:
scala> :paste
// Entering paste mode (ctrl-D to finish)
def search(xs: Seq[Int], ys: Seq[Int] = 1 to 9): Seq[Seq[Int]] = ys.diff(xs) match {
case Seq() => Seq(xs)
case zs => zs.flatMap(z => search(xs.updated(xs.indexOf(0), z), zs))
}
// Exiting paste mode, now interpreting.
search: (xs: Seq[Int], ys: Seq[Int])Seq[Seq[Int]]
scala> search(List(0, 0, 1, 5, 0, 0, 8, 0, 0)).size
res1: Int = 720
scala> search(List(0, 0, 1, 5, 0, 0, 8, 0, 0)) take 10 foreach println
List(2, 3, 1, 5, 4, 6, 8, 7, 9)
List(2, 3, 1, 5, 4, 6, 8, 9, 7)
List(2, 3, 1, 5, 4, 7, 8, 6, 9)
List(2, 3, 1, 5, 4, 7, 8, 9, 6)
List(2, 3, 1, 5, 4, 9, 8, 6, 7)
List(2, 3, 1, 5, 4, 9, 8, 7, 6)
List(2, 3, 1, 5, 6, 4, 8, 7, 9)
List(2, 3, 1, 5, 6, 4, 8, 9, 7)
List(2, 3, 1, 5, 6, 7, 8, 4, 9)
List(2, 3, 1, 5, 6, 7, 8, 9, 4)
Naive refactoring of your code using only immutable data structures:
object ScratchPad extends App {
def search(l: List[Int]): List[List[Int]] = {
def search0(la: List[Int], pos: Int, occur: List[Int]): List[List[Int]] =
if (pos == l.length)
List(la)
else {
val bal = (1 to 9) diff occur
if (pos < l.length && !bal.isEmpty)
la(pos) match {
case 0 => bal.toList flatMap {x => search0(la.updated(pos, x), pos + 1, x :: occur)}
case n => if (occur contains n) List.empty[List[Int]] else search0(la, pos + 1, n :: occur)
}
else
List.empty[List[Int]]
}
search0(l, 0, Nil)
}
val result = search(List(0, 0, 1, 5, 0, 0, 8, 0, 0))
result foreach println
println(result.size)
}
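Another way to avoid the mutable buffer, sketched here as an alternative technique and not taken from the answers above, is to compute the missing digits once and let the standard library enumerate their permutations. Each permutation is then written into the zero slots from left to right. The sketch assumes the input has no repeated non-zero digits and that the number of zeros equals the number of missing digits. The names `PermutationFill` and `fill` are hypothetical.

```scala
object PermutationFill {
  // Replace each 0 in `template` with the next value from `values`, left to right.
  def fill(template: List[Int], values: List[Int]): List[Int] =
    template.foldLeft((List.empty[Int], values)) {
      case ((acc, v :: rest), 0) => (v :: acc, rest)      // zero slot: consume one value
      case ((acc, remaining), x) => (x :: acc, remaining) // fixed digit: keep it
    }._1.reverse

  def search(template: List[Int]): List[List[Int]] = {
    // Digits from 1 to 9 that do not occur in the template yet.
    val missing = (1 to 9).toList.diff(template)
    missing.permutations.map(fill(template, _)).toList
  }
}
```

For `List(0, 0, 1, 5, 0, 0, 8, 0, 0)` there are six missing digits, so `search` returns 6! = 720 lists, the first being `List(2, 3, 1, 5, 4, 6, 8, 7, 9)`.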