Scala: Detecting a Straight in a 5-card Poker hand using pattern matching

For those who don't know what a 5-card Poker Straight is: http://en.wikipedia.org/wiki/List_of_poker_hands#Straight
I'm writing a small Poker simulator in Scala to help me learn the language, and I've created a Hand class with 5 ordered Cards in it. Each Card has a Rank and Suit, both defined as Enumerations. The Hand class has methods to evaluate the hand rank, and one of them checks whether the hand contains a Straight (we can ignore Straight Flushes for the moment). I know there are a few nice algorithms for determining a Straight, but I wanted to see whether I could design something with Scala's pattern matching, so I came up with the following:
def isStraight() = {
def matchesStraight(ranks: List[Rank.Value]): Boolean = ranks match {
case head :: Nil => true
case head :: tail if (Rank(head.id + 1) == tail.head) => matchesStraight(tail)
case _ => false
}
matchesStraight(cards.map(_.rank).toList)
}
That works fine and is fairly readable, but I was wondering if there is any way to get rid of that if. I'd imagine something like the following, though I can't get it to work:
private def isStraight() = {
def matchesStraight(ranks: List[Rank.Value]): Boolean = ranks match {
case head :: Nil => true
case head :: next(head.id + 1) :: tail => matchesStraight(next :: tail)
case _ => false
}
matchesStraight(cards.map(_.rank).toList)
}
Any ideas? Also, as a side question, what is the general opinion on the inner matchesStraight definition? Should this rather be private or perhaps done in a different way?

You can't pass information to an extractor, and you can't use a value bound in one part of a pattern inside another part of the same pattern, except in the if guard, which exists to cover exactly these cases.
What you can do is create your own extractors to test these things, but it won't gain you much if there isn't any reuse.
For example:
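class SeqExtractor[A](f: A => Int) {
// unapply rather than unapplySeq, so that `case Straight(cards)` binds the whole sequence
def unapply(s: Seq[A]): Option[Seq[A]] =
if (s.map(f).sliding(2).forall { case Seq(a, b) => b == a + 1; case _ => true }) Some(s)
else None
}
// assumes rank is an Enumeration value, as in the question
val Straight = new SeqExtractor((_: Card).rank.id)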
Then you can use it like this:
listOfCards match {
case Straight(cards) => true
case _ => false
}
But, of course, all that you really want is that if statement in SeqExtractor. So, don't get too much in love with a solution, as you may miss simpler ways of doing stuff.

You could do something like:
val ids = ranks.map(_.id)
ids.max - ids.min == 4 && ids.distinct.length == 5
Handling aces correctly requires a bit of work, though.
Update: Here's a much better solution:
(ids zip ids.tail).forall{case (p,q) => q%13==(p+1)%13}
The % 13 in the comparison handles aces being both rank 1 and rank 14.
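A minimal sketch of one way to handle both ace positions explicitly, rather than through modular arithmetic. It assumes ranks are plain Ints from 2 to 14 with the ace encoded as 14; the StraightCheck object and its method names are made up for illustration and are not part of the code above.

```scala
object StraightCheck {
  // Assumption: ranks are plain Ints, 2..14, with the ace encoded as 14.
  def isConsecutive(sorted: Seq[Int]): Boolean =
    sorted.zip(sorted.tail).forall { case (p, q) => q == p + 1 }

  def isStraight(ranks: Seq[Int]): Boolean =
    ranks.size == 5 && {
      val aceHigh = ranks.sorted
      // Re-encode the ace as 1 to test the ace-low straight A2345.
      val aceLow = ranks.map(r => if (r == 14) 1 else r).sorted
      isConsecutive(aceHigh) || isConsecutive(aceLow)
    }
}
```

Sorting inside the check means the caller does not have to keep the hand ordered, and wrap-around hands such as Q K A 2 3 are rejected because neither encoding makes them consecutive.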

How about something like:
def isStraight(cards:List[Card]) = (cards zip cards.tail) forall { case (c1,c2) => c1.rank+1 == c2.rank}
val cards = List(Card(1),Card(2),Card(3),Card(4))
scala> isStraight(cards)
res2: Boolean = true

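This is a completely different approach, but it does use pattern matching. It produces unchecked warnings in the match clause, and they matter: because of type erasure the type arguments are not checked at runtime, so the first case only verifies that b through e are Next instances. A hand such as (two, four, six, eight, ten) would therefore also be reported as a straight. For the three hands in main below it does print the expected results: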
Straight !!! 34567
Straight !!! 34567
Sorry no straight this time
I ignored the suits for now, and I also ignored the possibility of an ace below a 2.
abstract class Rank {
def value : Int
}
case class Next[A <: Rank](a : A) extends Rank {
def value = a.value + 1
}
case class Two() extends Rank {
def value = 2
}
class Hand(a : Rank, b : Rank, c : Rank, d : Rank, e : Rank) {
val cards = List(a, b, c, d, e).sortWith(_.value < _.value)
}
object Hand{
def unapply(h : Hand) : Option[(Rank, Rank, Rank, Rank, Rank)] = Some((h.cards(0), h.cards(1), h.cards(2), h.cards(3), h.cards(4)))
}
object Poker {
val two = Two()
val three = Next(two)
val four = Next(three)
val five = Next(four)
val six = Next(five)
val seven = Next(six)
val eight = Next(seven)
val nine = Next(eight)
val ten = Next(nine)
val jack = Next(ten)
val queen = Next(jack)
val king = Next(queen)
val ace = Next(king)
def main(args : Array[String]) {
val simpleStraight = new Hand(three, four, five, six, seven)
val unsortedStraight = new Hand(four, seven, three, six, five)
val notStraight = new Hand (two, two, five, five, ace)
printIfStraight(simpleStraight)
printIfStraight(unsortedStraight)
printIfStraight(notStraight)
}
def printIfStraight[A](h : Hand) {
h match {
case Hand(a: A , b : Next[A], c : Next[Next[A]], d : Next[Next[Next[A]]], e : Next[Next[Next[Next[A]]]]) => println("Straight !!! " + a.value + b.value + c.value + d.value + e.value)
case Hand(a,b,c,d,e) => println("Sorry no straight this time")
}
}
}
If you are interested in more stuff like this google 'church numerals scala type system'

How about something like this?
def isStraight = {
cards.map(_.rank).toList match {
case first :: second :: third :: fourth :: fifth :: Nil if
first.id == second.id - 1 &&
second.id == third.id - 1 &&
third.id == fourth.id - 1 &&
fourth.id == fifth.id - 1 => true
case _ => false
}
}
You're still stuck with the if (which is in fact larger) but there's no recursion or custom extractors (which I believe you're using incorrectly with next and so is why your second attempt doesn't work).

If you're writing a poker program, you are already checking for n-of-a-kind. A hand is a straight when it has no n-of-a-kinds (n > 1) and the difference between the minimum denomination and the maximum is exactly four (the ace-low straight A2345 needs a separate check).

I was doing something like this a few days ago, for Project Euler problem 54. Like you, I had Rank and Suit as enumerations.
My Card class looks like this:
case class Card(rank: Rank.Value, suit: Suit.Value) extends Ordered[Card] {
def compare(that: Card) = that.rank compare this.rank
}
Note I gave it the Ordered trait so that we can easily compare cards later. Also, when parsing the hands, I sorted them from high to low using sorted, which makes assessing values much easier.
Here is my straight test which returns an Option value depending on whether it's a straight or not. The actual return value (a list of Ints) is used to determine the strength of the hand, the first representing the hand type from 0 (no pair) to 9 (straight flush), and the others being the ranks of any other cards in the hand that count towards its value. For straights, we're only worried about the highest ranking card.
Also, note that you can make a straight with Ace as low, the "wheel", or A2345.
case class Hand(cards: Array[Card]) {
...
def straight: Option[List[Int]] = {
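// compare the ids directly: Value.compare only guarantees the sign of the result, not the distance
if( cards.sliding(2).forall { case Array(x, y) => x.rank.id - y.rank.id == 1 } )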
Some(5 :: cards(0).rank.id :: 0 :: 0 :: 0 :: 0 :: Nil)
else if ( cards.map(_.rank.id).toList == List(12, 3, 2, 1, 0) )
Some(5 :: cards(1).rank.id :: 0 :: 0 :: 0 :: 0 :: Nil)
else None
}
}

Here is a complete idiomatic Scala hand classifier for all hands (handles 5-high straights):
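case class Card(rank: Int, suit: Int) { override def toString = s"${Card.ranks(rank)}${Card.suits(suit)}" }
// Assumed companion: supplies the Card.ranks used below and the Ordering needed by sorted, min and max
object Card {
val ranks = "23456789TJQKA"
val suits = "♣♠♦♥"
implicit val ordering: Ordering[Card] = Ordering.by[Card, Int](_.rank)
}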
object HandType extends Enumeration {
val HighCard, OnePair, TwoPair, ThreeOfAKind, Straight, Flush, FullHouse, FourOfAKind, StraightFlush = Value
}
case class Hand(hand: Set[Card]) {
val (handType, sorted) = {
def rankMatches(card: Card) = hand count (_.rank == card.rank)
val groups = hand groupBy rankMatches mapValues {_.toList.sorted}
val isFlush = (hand groupBy {_.suit}).size == 1
val isWheel = "A2345" forall {r => hand exists (_.rank == Card.ranks.indexOf(r))} // A,2,3,4,5 straight
val isStraight = groups.size == 1 && (hand.max.rank - hand.min.rank) == 4 || isWheel
val (isThreeOfAKind, isOnePair) = (groups contains 3, groups contains 2)
val handType = if (isStraight && isFlush) HandType.StraightFlush
else if (groups contains 4) HandType.FourOfAKind
else if (isThreeOfAKind && isOnePair) HandType.FullHouse
else if (isFlush) HandType.Flush
else if (isStraight) HandType.Straight
else if (isThreeOfAKind) HandType.ThreeOfAKind
else if (isOnePair && groups(2).size == 4) HandType.TwoPair
else if (isOnePair) HandType.OnePair
else HandType.HighCard
val kickers = ((1 until 5) flatMap groups.get).flatten.reverse
require(hand.size == 5 && kickers.size == 5)
(handType, if (isWheel) (kickers takeRight 4) :+ kickers.head else kickers)
}
}
object Hand {
import scala.math.Ordering.Implicits._
implicit val rankOrdering = Ordering by {hand: Hand => (hand.handType, hand.sorted)}
}

Related

Scala String Equality Question from Programming Interview

Since I liked programming in Scala, for my Google interview, I asked them to give me a Scala / functional programming style question. The Scala functional style question that I got was as follows:
You have two strings consisting of alphabetic characters as well as a special character representing the backspace symbol. Let's call this backspace character '/'. When you get to the keyboard, you type this sequence of characters, including the backspace/delete character. The solution you are to implement must check if the two sequences of characters produce the same output. For example, "abc", "aa/bc", "abb/c", "abcc/", "/abc", and "//abc" all produce the same output, "abc". Because this is a Scala / functional programming question, you must implement your solution in idiomatic Scala style.
I wrote the following code (it might not be exactly what I wrote, I'm just going off memory). Basically I just go linearly through the string, prepending characters to a list, and then I compare the lists.
def processString(string: String): List[Char] = {
string.foldLeft(List[Char]()){ case(accumulator: List[Char], char: Char) =>
accumulator match {
case head :: tail => if(char != '/') { char :: head :: tail } else { tail }
case emptyList => if(char != '/') { char :: emptyList } else { emptyList }
}
}
}
def solution(string1: String, string2: String): Boolean = {
processString(string1) == processString(string2)
}
So far so good? He then asked for the time complexity and I responded linear time (because you have to process each character once) and linear space (because you have to copy each element into a list). Then he asked me to do it in linear time, but with constant space. I couldn't think of a way to do it that was purely functional. He said to try using a function in the Scala collections library like "zip" or "map" (I explicitly remember him saying the word "zip").
Here's the thing. I think that it's physically impossible to do it in constant space without having any mutable state or side effects. Like I think that he messed up the question. What do you think?
Can you solve it in linear time, but with constant space?
This code takes O(N) time and needs only three integers of extra space:
def solution(a: String, b: String): Boolean = {
def findNext(str: String, pos: Int): Int = {
@annotation.tailrec
def rec(pos: Int, backspaces: Int): Int = {
if (pos == 0) -1
else {
val c = str(pos - 1)
if (c == '/') rec(pos - 1, backspaces + 1)
else if (backspaces > 0) rec(pos - 1, backspaces - 1)
else pos - 1
}
}
rec(pos, 0)
}
@annotation.tailrec
def rec(aPos: Int, bPos: Int): Boolean = {
val ap = findNext(a, aPos)
val bp = findNext(b, bPos)
(ap < 0 && bp < 0) ||
(ap >= 0 && bp >= 0 && (a(ap) == b(bp)) && rec(ap, bp))
}
rec(a.size, b.size)
}
The problem can be solved in linear time with constant extra space: if you scan from right to left, then you can be sure that the /-symbols to the left of the current position cannot influence the already processed symbols (to the right of the current position) in any way, so there is no need to store them.
At every point, you need to know only two things:
Where are you in the string?
How many symbols do you have to throw away because of the backspaces
That makes two integers for storing the positions, and one additional integer for temporarily storing the number of accumulated backspaces during the findNext invocation. That's a total of three integers of space overhead.
Intuition
Here is my attempt to formulate why the right-to-left scan gives you a O(1) algorithm:
The future cannot influence the past, therefore there is no need to remember the future.
The "natural time" in this problem flows from left to right. Therefore, if you scan from right to left, you are moving "from the future into the past", and therefore you don't need to remember the characters to the right of your current position.
Tests
Here is a randomized test, which makes me pretty sure that the solution is actually correct:
val rng = new util.Random(0)
def insertBackspaces(s: String): String = {
val n = s.size
val insPos = rng.nextInt(n)
val (pref, suff) = s.splitAt(insPos)
val c = ('a' + rng.nextInt(26)).toChar
pref + c + "/" + suff
}
def prependBackspaces(s: String): String = {
"/" * rng.nextInt(4) + s
}
def addBackspaces(s: String): String = {
var res = s
for (i <- 0 until 8)
res = insertBackspaces(res)
prependBackspaces(res)
}
for (i <- 1 until 1000) {
val s = "hello, world"
val t = "another string"
val s1 = addBackspaces(s)
val s2 = addBackspaces(s)
val t1 = addBackspaces(t)
val t2 = addBackspaces(t)
assert(solution(s1, s2))
assert(solution(t1, t2))
assert(!solution(s1, t1))
assert(!solution(s1, t2))
assert(!solution(s2, t1))
assert(!solution(s2, t2))
if (i % 100 == 0) {
println(s"Examples:\n$s1\n$s2\n$t1\n$t2")
}
}
A few examples that the test generates:
Examples:
/helly/t/oj/m/, wd/oi/g/x/rld
///e/helx/lc/rg//f/o, wosq//rld
/anotl/p/hhm//ere/t/ strih/nc/g
anotx/hb/er sw/p/tw/l/rip/j/ng
Examples:
//o/a/hellom/, i/wh/oe/q/b/rld
///hpj//est//ldb//y/lok/, world
///q/gd/h//anothi/k/eq/rk/ string
///ac/notherli// stri/ig//ina/n/g
Examples:
//hnn//ello, t/wl/oxnh///o/rld
//helfo//u/le/o, wna//ova//rld
//anolq/l//twl//her n/strinhx//g
/anol/tj/hq/er swi//trrq//d/ing
Examples:
//hy/epe//lx/lo, wr/v/t/orlc/d
f/hk/elv/jj//lz/o,wr// world
/anoto/ho/mfh///eg/r strinbm//g
///ap/b/notk/l/her sm/tq/w/rio/ng
Examples:
///hsm/y//eu/llof/n/, worlq/j/d
///gx//helf/i/lo, wt/g/orn/lq/d
///az/e/notm/hkh//er sm/tb/rio/ng
//b/aen//nother v/sthg/m//riv/ng
Seems to work just fine. So, I'd say that the Google-guy did not mess up, looks like a perfectly valid question.
You don't have to create the output to find the answer. You can iterate the two sequences at the same time and stop on the first difference. If you find no difference and both sequences terminate at the same time, they're equal, otherwise they're different.
But now consider sequences such as this one: aaaa/// to compare with a. You need to consume 6 elements from the left sequence and one element from the right sequence before you can assert that they're equal. That means that you would need to keep at least 5 elements in memory until you can verify that they're all deleted. But what if you iterated elements from the end? You would then just need to count the number of backspaces and then just ignoring as many elements as necessary in the left sequence without requiring to keep them in memory since you know they won't be present in the final output. You can achieve O(1) memory using these two tips.
I tried it and it seems to work:
def areEqual(s1: String, s2: String) = {
def charAt(s: String, index: Int) = if (index < 0) '#' else s(index)
@annotation.tailrec
def recSol(i1: Int, backspaces1: Int, i2: Int, backspaces2: Int): Boolean = (charAt(s1, i1), charAt(s2, i2)) match {
case ('/', _) => recSol(i1 - 1, backspaces1 + 1, i2, backspaces2)
case (_, '/') => recSol(i1, backspaces1, i2 - 1, backspaces2 + 1)
case ('#' , '#') => true
case (ch1, ch2) =>
if (backspaces1 > 0) recSol(i1 - 1, backspaces1 - 1, i2 , backspaces2 )
else if (backspaces2 > 0) recSol(i1 , backspaces1 , i2 - 1, backspaces2 - 1)
else ch1 == ch2 && recSol(i1 - 1, backspaces1 , i2 - 1, backspaces2 )
}
recSol(s1.length - 1, 0, s2.length - 1, 0)
}
Some tests (all pass, let me know if you have more edge cases in mind):
// examples from the question
val inputs = Array("abc", "aa/bc", "abb/c", "abcc/", "/abc", "//abc")
for (i <- 0 until inputs.length; j <- 0 until inputs.length) {
assert(areEqual(inputs(i), inputs(j)))
}
// more deletions than required
assert(areEqual("a///////b/c/d/e/b/b", "b"))
assert(areEqual("aa/a/a//a//a///b", "b"))
assert(areEqual("a/aa///a/b", "b"))
// not enough deletions
assert(!areEqual("aa/a/a//a//ab", "b"))
// too many deletions
assert(!areEqual("a", "a/"))
PS: just a few notes on the code itself:
Scala type inference is good enough so that you can drop types in the partial function inside your foldLeft
Nil is the idiomatic way to refer to the empty list case
Bonus:
I had something like Tim's solution in mind before implementing my idea, but I started early with pattern matching on characters only and it didn't fit well because some cases require the number of backspaces. In the end, I think a neater way to write it is a mix of pattern matching and if conditions. Below is my longer original solution; the one I gave above was refactored later:
def areEqual(s1: String, s2: String) = {
@annotation.tailrec
def recSol(c1: Cursor, c2: Cursor): Boolean = (c1.char, c2.char) match {
case ('/', '/') => recSol(c1.next, c2.next)
case ('/' , _) => recSol(c1.next, c2 )
case (_ , '/') => recSol(c1 , c2.next)
case ('#' , '#') => true
case (a , b) if (a == b) => recSol(c1.next, c2.next)
case _ => false
}
recSol(Cursor(s1, s1.length - 1), Cursor(s2, s2.length - 1))
}
private case class Cursor(s: String, index: Int) {
val char = if (index < 0) '#' else s(index)
def next = {
@annotation.tailrec
def recSol(index: Int, backspaces: Int): Cursor = {
if (index < 0 ) Cursor(s, index)
else if (s(index) == '/') recSol(index - 1, backspaces + 1)
else if (backspaces > 1) recSol(index - 1, backspaces - 1)
else Cursor(s, index - 1)
}
recSol(index, 0)
}
}
If the goal is minimal memory footprint, it's hard to argue against iterators.
def areSame(a :String, b :String) :Boolean = {
def getNext(ci :Iterator[Char], ignore :Int = 0) : Option[Char] =
if (ci.hasNext) {
val c = ci.next()
if (c == '/') getNext(ci, ignore+1)
else if (ignore > 0) getNext(ci, ignore-1)
else Some(c)
} else None
val ari = a.reverseIterator
val bri = b.reverseIterator
1 to a.length.max(b.length) forall(_ => getNext(ari) == getNext(bri))
}
On the other hand, when arguing FP principles it's hard to defend iterators, since they're all about maintaining state.
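A possible variation on the same idea, sketched under the assumption that hiding the mutable state inside a custom Iterator is acceptable: the right-to-left scan is packaged as an iterator of surviving characters, and the two iterators are compared with the standard sameElements method. The names BackspaceCompare, survivors and same are hypothetical.

```scala
object BackspaceCompare {
  // Yields the characters that survive the backspaces, from right to left.
  // Only an index and a pending-backspace counter are kept as state.
  def survivors(s: String): Iterator[Char] = new Iterator[Char] {
    private var pos = s.length - 1

    // Move pos to the next surviving character, or to -1 if none is left.
    private def settle(): Unit = {
      var skip = 0
      var done = false
      while (pos >= 0 && !done) {
        if (s(pos) == '/') { skip += 1; pos -= 1 }      // count a backspace
        else if (skip > 0) { skip -= 1; pos -= 1 }      // drop a deleted char
        else done = true                                // surviving char found
      }
    }

    settle()
    def hasNext: Boolean = pos >= 0
    def next(): Char = { val c = s(pos); pos -= 1; settle(); c }
  }

  // Stops at the first difference, so nothing is buffered.
  def same(a: String, b: String): Boolean =
    survivors(a).sameElements(survivors(b))
}
```

The comparison itself reads like a lazy zip over the two outputs, which may be close to what the interviewer was hinting at, although that is only a guess.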
Here is a version with a single recursive function and no additional classes or libraries. This is linear time and constant memory.
def compare(a: String, b: String): Boolean = {
@annotation.tailrec
def loop(aIndex: Int, aDeletes: Int, bIndex: Int, bDeletes: Int): Boolean = {
val aVal = if (aIndex < 0) None else Some(a(aIndex))
val bVal = if (bIndex < 0) None else Some(b(bIndex))
if (aVal.contains('/')) {
loop(aIndex - 1, aDeletes + 1, bIndex, bDeletes)
} else if (aDeletes > 0) {
loop(aIndex - 1, aDeletes - 1, bIndex, bDeletes)
} else if (bVal.contains('/')) {
loop(aIndex, 0, bIndex - 1, bDeletes + 1)
} else if (bDeletes > 0) {
loop(aIndex, 0, bIndex - 1, bDeletes - 1)
} else {
aVal == bVal && (aVal.isEmpty || loop(aIndex - 1, 0, bIndex - 1, 0))
}
}
loop(a.length - 1, 0, b.length - 1, 0)
}

Generate all IP addresses given a string in scala

I was trying my hand at writing an IP generator given a string of digits. The generator takes as input a string of digits such as "17234" and returns all possible IP addresses, as follows:
1.7.2.34
1.7.23.4
1.72.3.4
17.2.3.4
I attempted to write a snippet to do the generation as follows:
def genip(ip:String):Unit = {
def legal(ip:String):Boolean = (ip.size == 1) || (ip.size == 2) || (ip.size == 3)
def genips(ip:String,portion:Int,accum:String):Unit = portion match {
case 1 if legal(ip) => println(accum+ip)
case _ if portion > 1 => {
genips(ip.drop(1),portion-1,if(accum.size == 0) ip.take(1)+"." else accum+ip.take(1)+".")
genips(ip.drop(2),portion-1,if(accum.size == 0) ip.take(2)+"." else accum+ip.take(2)+".")
genips(ip.drop(3),portion-1,if(accum.size == 0) ip.take(3)+"." else accum+ip.take(3)+".")
}
case _ => return
}
genips(ip,4,"")
}
The idea is to partition the string into four octets and then further partition the octet into strings of size "1","2" and "3" and then recursively descend into the remaining string.
I am not sure if I am on the right track but it would be great if somebody could suggest a more functional way of accomplishing the same.
Thanks
Here is an alternative version of the attached code:
def generateIPs(digits : String) : Seq[String] = generateIPs(digits, 4)
private def generateIPs(digits : String, partsLeft : Int) : Seq[String] = {
if ( digits.size < partsLeft || digits.size > partsLeft * 3) {
Nil
} else if(partsLeft == 1) {
Seq(digits)
} else {
(1 to 3).map(n => generateIPs(digits.drop(n), partsLeft - 1)
.map(digits.take(n) + "." + _)
).flatten
}
}
println("Results:\n" + generateIPs("17234").mkString("\n"))
Major changes:
Methods now return a collection of strings (rather than Unit), so they are proper functions (rather than relying on side effects) and can be easily tested;
Avoiding repeating the same code 3 times depending on the size of the bunch of digits we take;
Not passing the accumulated interim result as a method parameter. In this case it doesn't make sense, since you'll have at most 4 levels of recursion and it's easier to read without it, though as you're losing tail recursion, in many cases it might be reasonable to keep it.
Note: The last map statement is a good candidate to be replaced by for comprehension, which many developers find easier to read and reason about, though I will leave it as an exercise :)
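For reference, one way the for-comprehension variant could look. This is a sketch that keeps the same length-based pruning as the code above; the IpGen wrapper object and the default argument are assumptions added to make it self-contained. Like the original, it does not check that each part is below 256.

```scala
object IpGen {
  def generateIPs(digits: String, partsLeft: Int = 4): Seq[String] =
    if (digits.length < partsLeft || digits.length > partsLeft * 3) Nil
    else if (partsLeft == 1) Seq(digits)
    else
      for {
        n    <- 1 to 3                                      // size of the next part
        rest <- generateIPs(digits.drop(n), partsLeft - 1)  // all ways to split the remainder
      } yield digits.take(n) + "." + rest
}
```

Calling IpGen.generateIPs("17234") yields the four addresses listed in the question.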
Your code is the right idea; I'm not sure making it functional really helps anything, but I'll show both functional and side-effecting ways to do what you want. First, we'd like a good routine to chunk off some of the numbers, making sure an okay number are left for the rest of the chunking, and making sure they're in range for IPs:
def validSize(i: Int, len: Int, more: Int) = i + more <= len && i + 3*more >= len
def chunk(s: String, more: Int) = {
val parts = for (i <- 1 to 3 if validSize(i, s.length, more)) yield s.splitAt(i)
parts.filter(_._1.toInt < 256)
}
Now we need to use chunk recursively four times to generate the possibilities. Here's a solution that is mutable internally and iterative:
def genIPs(digits: String) = {
var parts = List(("", digits))
for (i <- 1 to 4) {
parts = parts.flatMap{ case (pre, post) =>
chunk(post, 4-i).map{ case (x,y) => (pre+x+".", y) }
}
}
parts.map(_._1.dropRight(1))
}
Here's one that recurses using Iterator:
def genIPs(digits: String) = Iterator.iterate(List((3,"",digits))){ _.flatMap{
case(j, pre, post) => chunk(post, j).map{ case(x,y) => (j-1, pre+x+".", y) }
}}.dropWhile(_.head._1 >= 0).next.map(_._2.dropRight(1))
The logic is the same either way. Here it is working:
scala> genIPs("1238516")
res2: List[String] = List(1.23.85.16, 1.238.5.16, 1.238.51.6,
12.3.85.16, 12.38.5.16, 12.38.51.6,
123.8.5.16, 123.8.51.6, 123.85.1.6)

Abort early in a fold

What's the best way to terminate a fold early? As a simplified example, imagine I want to sum up the numbers in an Iterable, but if I encounter something I'm not expecting (say an odd number) I might want to terminate. This is a first approximation
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => None
}
}
However, this solution is pretty ugly (as in, if I did a .foreach and a return -- it'd be much cleaner and clearer) and worst of all, it traverses the entire iterable even if it encounters a non-even number.
So what would be the best way to write a fold like this, that terminates early? Should I just go and write this recursively, or is there a more accepted way?
My first choice would usually be to use recursion. It is only moderately less compact, is potentially faster (certainly no slower), and in early termination can make the logic more clear. In this case you need nested defs which is a little awkward:
def sumEvenNumbers(nums: Iterable[Int]) = {
def sumEven(it: Iterator[Int], n: Int): Option[Int] = {
if (it.hasNext) {
val x = it.next
if ((x % 2) == 0) sumEven(it, n+x) else None
}
else Some(n)
}
sumEven(nums.iterator, 0)
}
My second choice would be to use return, as it keeps everything else intact and you only need to wrap the fold in a def so you have something to return from--in this case, you already have a method, so:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
Some(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) return None
n+x
})
}
which in this particular case is a lot more compact than recursion (though we got especially unlucky with recursion since we had to do an iterable/iterator transformation). The jumpy control flow is something to avoid when all else is equal, but here it's not. No harm in using it in cases where it's valuable.
If I was doing this often and wanted it within the middle of a method somewhere (so I couldn't just use return), I would probably use exception-handling to generate non-local control flow. That is, after all, what it is good at, and error handling is not the only time it's useful. The only trick is to avoid generating a stack trace (which is really slow), and that's easy because the trait NoStackTrace and its child trait ControlThrowable already do that for you. Scala already uses this internally (in fact, that's how it implements the return from inside the fold!). Let's make our own (can't be nested, though one could fix that):
import scala.util.control.ControlThrowable
case class Returned[A](value: A) extends ControlThrowable {}
def shortcut[A](a: => A) = try { a } catch { case Returned(v) => v }
def sumEvenNumbers(nums: Iterable[Int]) = shortcut{
Option(nums.foldLeft(0){ (n,x) =>
if ((x % 2) != 0) throw Returned(None)
n+x
})
}
Here of course using return is better, but note that you could put shortcut anywhere, not just wrapping an entire method.
Next in line for me would be to re-implement fold (either myself or to find a library that does it) so that it could signal early termination. The two natural ways of doing this are to not propagate the value but an Option containing the value, where None signifies termination; or to use a second indicator function that signals completion. The Scalaz lazy fold shown by Kim Stebel already covers the first case, so I'll show the second (with a mutable implementation):
def foldOrFail[A,B](it: Iterable[A])(zero: B)(fail: A => Boolean)(f: (B,A) => B): Option[B] = {
val ii = it.iterator
var b = zero
while (ii.hasNext) {
val x = ii.next
if (fail(x)) return None
b = f(b,x)
}
Some(b)
}
def sumEvenNumbers(nums: Iterable[Int]) = foldOrFail(nums)(0)(_ % 2 != 0)(_ + _)
(Whether you implement the termination by recursion, return, laziness, etc. is up to you.)
I think that covers the main reasonable variants; there are some other options also, but I'm not sure why one would use them in this case. (Iterator itself would work well if it had a findOrPrevious, but it doesn't, and the extra work it takes to do that by hand makes it a silly option to use here.)
The scenario you describe (exit upon some unwanted condition) seems like a good use case for the takeWhile method. It is essentially filter, but should end upon encountering an element that doesn't meet the condition.
For example:
val list = List(2,4,6,8,6,4,2,5,3,2)
list.takeWhile(_ % 2 == 0) //result is List(2,4,6,8,6,4,2)
This will work just fine for Iterators/Iterables too. The solution I suggest for your "sum of even numbers, but break on odd" is:
list.iterator.takeWhile(_ % 2 == 0).foldLeft(...)
And just to prove that it's not wasting your time once it hits an odd number...
scala> val list = List(2,4,5,6,8)
list: List[Int] = List(2, 4, 5, 6, 8)
scala> def condition(i: Int) = {
| println("processing " + i)
| i % 2 == 0
| }
condition: (i: Int)Boolean
scala> list.iterator.takeWhile(condition _).sum
processing 2
processing 4
processing 5
res4: Int = 6
You can do what you want in a functional style using the lazy version of foldRight in scalaz. For a more in depth explanation, see this blog post. While this solution uses a Stream, you can convert an Iterable into a Stream efficiently with iterable.toStream.
import scalaz._
import Scalaz._
val str = Stream(2,1,2,2,2,2,2,2,2)
var i = 0 //only here for testing
val r = str.foldr(Some(0):Option[Int])((n,s) => {
println(i)
i+=1
if (n % 2 == 0) s.map(n+) else None
})
This only prints
0
1
which clearly shows that the anonymous function is only called twice (i.e. until it encounters the odd number). That is due to the definition of foldr, whose signature (in case of Stream) is def foldr[B](b: B)(f: (Int, => B) => B)(implicit r: scalaz.Foldable[Stream]): B. Note that the anonymous function takes a by-name parameter as its second argument, so it need not be evaluated.
Btw, you can still write this with the OP's pattern matching solution, but I find if/else and map more elegant.
Well, Scala does allow non local returns. There are differing opinions on whether or not this is a good style.
scala> def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
| nums.foldLeft (Some(0): Option[Int]) {
| case (None, _) => return None
| case (Some(s), n) if n % 2 == 0 => Some(s + n)
| case (Some(_), _) => None
| }
| }
sumEvenNumbers: (nums: Iterable[Int])Option[Int]
scala> sumEvenNumbers(2 to 10)
res8: Option[Int] = None
scala> sumEvenNumbers(2 to 10 by 2)
res9: Option[Int] = Some(30)
EDIT:
In this particular case, as @Arjan suggested, you can also do:
def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
nums.foldLeft (Some(0): Option[Int]) {
case (Some(s), n) if n % 2 == 0 => Some(s + n)
case _ => return None
}
}
You can use foldM from the cats library (as suggested by @Didac), but I suggest using Either instead of Option if you want to get the actual sum out.
bifoldMap is used to extract the result from Either.
import cats.implicits._
def sumEven(nums: Stream[Int]): Either[Int, Int] = {
nums.foldM(0) {
case (acc, n) if n % 2 == 0 => Either.right(acc + n)
case (acc, n) => {
println(s"Stopping on number: $n")
Either.left(acc)
}
}
}
examples:
println("Result: " + sumEven(Stream(2, 2, 3, 11)).bifoldMap(identity, identity))
> Stopping on number: 3
> Result: 4
println("Result: " + sumEven(Stream(2, 7, 2, 3)).bifoldMap(identity, identity))
> Stopping on number: 7
> Result: 2
Cats has a method called foldM which does short-circuiting (for Vector, List, Stream, ...).
It works as follows:
def sumEvenNumbers(nums: Stream[Int]): Option[Long] = {
import cats.implicits._
nums.foldM(0L) {
case (acc, c) if c % 2 == 0 => Some(acc + c)
case _ => None
}
}
If it finds an odd element, it returns None without computing the rest; otherwise it returns the sum of the even entries.
If you want to keep the sum accumulated until the first odd entry is found, you should use an Either[Long, Long] instead.
@Rex Kerr, your answer helped me, but I needed to tweak it to use Either:
def foldOrFail[A,B,C,D](map: B => Either[D, C])(merge: (A, C) => A)(initial: A)(it: Iterable[B]): Either[D, A] = {
val ii= it.iterator
var b= initial
while (ii.hasNext) {
val x= ii.next
map(x) match {
case Left(error) => return Left(error)
case Right(d) => b= merge(b, d)
}
}
Right(b)
}
You could try using a temporary var and using takeWhile. Here is a version.
var continue = true
// sample stream of 2's and then a stream of 3's.
val evenSum = (Stream.fill(10)(2) ++ Stream.fill(10)(3)).takeWhile(_ => continue)
.foldLeft(Option[Int](0)){
case (result,i) if i%2 != 0 =>
continue = false;
// return whatever is appropriate either the accumulated sum or None.
result
case (optionSum,i) => optionSum.map( _ + i)
}
The evenSum should be Some(20) in this case.
You can throw a well-chosen exception upon encountering your termination criterion, handling it in the calling code.
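A minimal sketch of what that could look like. The marker exception OddNumberFound and the EarlyExit wrapper are hypothetical names; NoStackTrace is the standard library trait that skips the costly stack capture.

```scala
import scala.util.control.NoStackTrace

object EarlyExit {
  // Hypothetical marker exception, used only to leave the fold early.
  private case object OddNumberFound extends RuntimeException with NoStackTrace

  def sumEvenNumbers(nums: Iterable[Int]): Option[Int] =
    try Some(nums.foldLeft(0) { (acc, n) =>
      if (n % 2 != 0) throw OddNumberFound  // abort: the remaining elements are never visited
      acc + n
    })
    catch { case OddNumberFound => None }
}
```

Keeping the throw and the catch inside one method means callers only ever see the Option result, never the exception.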
A more beautiful solution would be using span:
val (l, r) = numbers.span(_ % 2 == 0)
if(r.isEmpty) Some(l.sum)
else None
... but it traverses the list two times if all the numbers are even
Just for "academic" reasons (:
var headers = Source.fromFile(file).getLines().next().split(",")
var closeHeaderIdx = headers.takeWhile { s => !"Close".equals(s) }.foldLeft(0)((i, S) => i+1)
It takes twice as long as it should, but it is a nice one-liner.
If "Close" is not found, it returns
headers.size
Another (better) way is this one:
var headers = Source.fromFile(file).getLines().next().split(",").toList
var closeHeaderIdx = headers.indexOf("Close")
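The two variants differ in what they return when the header is missing. A small sketch, using an in-memory header line in place of the file (the header names and the object name HeaderIndexDemo are made up for illustration):

```scala
object HeaderIndexDemo {
  // Hypothetical header line standing in for the first line of the file.
  val headers: List[String] = "Date,Open,Close,Volume".split(",").toList

  // Counts the entries before the match; yields headers.size if there is none.
  def viaTakeWhile(name: String): Int = headers.takeWhile(_ != name).length

  // Yields -1 if there is no match.
  def viaIndexOf(name: String): Int = headers.indexOf(name)
}
```

Both agree when the header exists, so the choice mostly comes down to which "not found" value the calling code expects.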

How to find students with the best grades in a list?

Suppose, I have a list of Students. Students have fields like name, birth date, grade, etc. How would you find Students with the best grade in Scala?
For example:
List(Student("Mike", "A"), Student("Pete", "B"), Student("Paul", "A"))
I want to get
List(Student("Mike", "A"), Student("Paul", "A"))
Obviously, I can find the max grade ("A" in the list above) and then filter the list
students.filter(_.grade == max_grade)
This solution is O(N) but runs over the list twice. Can you suggest a better solution?
Running over the list twice is probably the best way to do it, but if you insist on a solution that only runs over it once, you can use a fold (this version also works on empty lists):
(List[Student]() /: list){ (best,next) => best match {
case Nil => next :: Nil
case x :: rest =>
if (betterGrade(x,next)) best
else if (betterGrade(next,x)) next :: Nil
else next :: best
}}
If you are unfamiliar with folds, they are described in an answer here. They're a general way of accumulating something as you pass over a collection (e.g. list). If you're unfamiliar with matching, you can do the same thing with isEmpty and head. If you want the students in the same order as they appeared in the original list, run .reverse at the end.
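For a version that can be pasted and run, here is a sketch of the same fold with the missing pieces filled in. The Student shape and the betterGrade rule (a letter grade that sorts earlier is better) are assumptions, since the question does not define them.

```scala
object BestGradeFold {
  // Assumed shape of the asker's class; only name and grade are modelled.
  final case class Student(name: String, grade: String)

  // Hypothetical comparison: "A" beats "B" because it sorts first.
  def betterGrade(a: Student, b: Student): Boolean = a.grade < b.grade

  def best(students: List[Student]): List[Student] = {
    val reversed = students.foldLeft(List.empty[Student]) { (best, next) =>
      best match {
        case Nil => next :: Nil
        case x :: _ =>
          if (betterGrade(x, next)) best          // current best still wins
          else if (betterGrade(next, x)) next :: Nil // new best, start over
          else next :: best                        // tie, keep both
      }
    }
    reversed.reverse // restore the original order
  }
}
```

With the list from the question, best returns Mike and Paul in their original order, and an empty input gives an empty result.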
Using foldLeft you traverse the list of students only once :
scala> students.foldLeft(List.empty[Student]) {
| case (Nil, student) => student::Nil
| case (list, student) if (list.head.grade == student.grade) => student::list
| case (list, student) if (list.head.grade > student.grade) => student::Nil
| case (list, _) => list
| }
res17: List[Student] = List(Student(Paul,A), Student(Mike,A))
There are already 6 answers, but I still feel compelled to add my take:
case class Lifted(grade: String, students: List[Student]) {
// compare on Strings may return any negative or positive Int, not just -1 or 1
def mergeBest(other: Lifted) = grade compare other.grade match {
case 0 => copy(students = other.students ::: students)
case c if c > 0 => other
case _ => this
}
}
This little case class will lift a Student into an object that keeps track of the best grade and a list cell containing the student. It also factors out the logic of the merge: if the new student has
the same grade => merge the new student into the result list, with the shorter one at the front for efficiency - otherwise result won't be O(n)
higher grade => replace current best with the new student
lower grade => keep the current best
The result can then be easily constructed with a reduceLeft:
val result = {
list.view.map(s => Lifted(s.grade, List(s)))
.reduceLeft((l1, l2) => l1.mergeBest(l2))
.students
}
// result: List[Student] = List(Student(Paul,A), Student(Mike,A))
PS. I'm upvoting your question, based on the sheer number of responses it generated.
You can use filter on the students list.
case class Student(grade: Int)
val students = ...
val minGrade = 5
students filter ( _.grade < minGrade)
This also works just fine for grades of type String.
After sorting, you can use takeWhile to avoid iterating a second time.
case class Student(grade: Int)
val list = List(Student(1), Student(3), Student(2), Student(1), Student(25), Student(0), Student (25))
val sorted = list.sortWith (_.grade > _.grade)
sorted.takeWhile (_.grade == sorted(0).grade)
It still sorts the whole list, including low grades such as 1, 3 and 0 that we aren't interested in, before skimming off the top, but the code is short.
update:
A second approach, which can be performed in parallel, is to split the list, take the maxima of each half, and then keep only the half with the higher grade, or both halves if they tie:
def takeMax (l: List[Student]) : List [Student] = l.size match {
case 0 => Nil
case 1 => l
case 2 => if (l(0).grade > l(1).grade) List (l(0)) else
if (l(0).grade < l(1).grade) List (l(1)) else List (l(0), l(1))
case _ => {
val left = takeMax (l.take (l.size / 2))
val right= takeMax (l.drop (l.size / 2))
if (left (0).grade > right(0).grade) left else
if (left (0).grade < right(0).grade) right else left ::: right }
}
Of course, we would like to abstract over the Student type and over the method used to compare two of them.
def takeMax [A] (l: List[A], cmp: ((A, A) => Int)) : List [A] = l.size match {
case 0 | 1 => l
case 2 => cmp (l(0), l(1)) match {
case 0 => l
case x => if (x > 0) List (l(0)) else List (l(1))
}
case _ => {
val left = takeMax (l.take (l.size / 2), cmp)
val right= takeMax (l.drop (l.size / 2), cmp)
cmp (left (0), right (0)) match {
case 0 => left ::: right
case x => if (x > 0) left else right }
}
}
def cmp (s1: Student, s2: Student) = s1.grade - s2.grade
takeMax (list, cmp)

What is the Scala way of finding whether all the elements of an Array have the same length?

I am new to Scala but have a lot of experience with Java, and I have some understanding of FP languages like Haskell.
I am wondering how to implement this in Scala. I have an array whose elements are all strings, and I want to know whether there is a way to do this check in an FP style. Here is my current version, which works...
def checkLength(vals: Array[String]): Boolean = {
var len = -1
for(x <- vals){
if(len < 0)
len = x.length()
else{
if (x.length() != len)
return false
else
len = x.length()
}
}
return true;
}
And I am pretty sure there is a better way of doing this in Scala/FP...
list.forall( str => str.size == list(0).size )
Edit: Here's a definition that's as general as possible and also allows checking whether a property other than length is the same for all elements:
def allElementsTheSame[T,U](f: T => U)(list: Seq[T]) = {
val first: Option[U] = list.headOption.map( f(_) )
list.forall( f(_) == first.get ) //safe to use get here!
}
type HasSize = { val size: Int }
val checkLength = allElementsTheSame((x: HasSize) => x.size)_
checkLength(Array( "123", "456") )
checkLength(List( List(1,2), List(3,4) ))
Since everyone seems to be so creative, I'll be creative too. :-)
def checkLength(vals: Array[String]): Boolean = vals.map(_.length).removeDuplicates.size <= 1
Mind you, removeDuplicates is named distinct as of Scala 2.8.
Tip: Use forall to determine whether all elements in the collection satisfy a certain predicate (e.g. equality of length).
If you know that your lists are always non-empty, a straight forall works well. If you don't, it's easy to add that in:
list match {
case x :: rest => rest forall (_.size == x.size)
case _ => true
}
Now lists of length zero return true instead of throwing exceptions.
list.groupBy{_.length}.size == 1
You convert the list into a map of groups of equal length strings. If all the strings have the same length, then the map will hold only one such group.
The nice thing about this solution is that you don't need to know anything about the length of the strings, and you don't need to compare them to, say, the first string. It also works on an empty list, in which case it returns false (if that's what you want).
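A small sketch of the groupBy check, mainly to make the empty-list behaviour concrete. The object and method names are invented for illustration.

```scala
object SameLengthGroupBy {
  // True only when there is exactly one group of lengths.
  def sameLength(list: List[String]): Boolean =
    list.groupBy(_.length).size == 1

  // Variant that treats an empty list as "all the same".
  def sameLengthOrEmpty(list: List[String]): Boolean =
    list.groupBy(_.length).size <= 1
}
```

An empty list produces an empty map, so the first version returns false for it; comparing with <= 1 instead flips that case to true if that suits the caller better.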
Here's another approach:
def check(list:List[String]) = list.foldLeft(true)(_ && list.head.length == _.length)
Just my €0.02
def allElementsEval[T, U](xs: Iterable[T])(f: T => U) =
if (xs.isEmpty) true
else {
val first = f(xs.head)
xs forall { f(_) == first }
}
This works with any Iterable, evaluates f the minimum number of times possible, and while the block can't be curried, the type inferencer can infer the block parameter type.
"allElementsEval" should "return true for an empty Iterable" in {
allElementsEval(List[String]()){ x => x.size } should be (true)
}
it should "eval the function at each item" in {
allElementsEval(List("aa", "bb", "cc")) { x => x.size } should be (true)
allElementsEval(List("aa", "bb", "ccc")) { x => x.size } should be (false)
}
it should "work on Vector and Array as well" in {
allElementsEval(Vector("aa", "bb", "cc")) { x => x.size } should be (true)
allElementsEval(Vector("aa", "bb", "ccc")) { x => x.size } should be (false)
allElementsEval(Array("aa", "bb", "cc")) { x => x.size } should be (true)
allElementsEval(Array("aa", "bb", "ccc")) { x => x.size } should be (false)
}
It's just a shame that head :: tail pattern matching fails so insidiously for Iterables.
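To illustrate that last remark, here is a sketch of how a head :: tail pattern behaves when the static type is Iterable. The function names are made up; the point is only that the pattern compiles but matches nothing except an actual List.

```scala
object IterableMatchDemo {
  // Compiles for any Iterable, but :: only matches non-empty Lists at runtime.
  def firstViaCons[T](xs: Iterable[T]): Option[T] = xs match {
    case head :: _ => Some(head)
    case _         => None // a Vector silently ends up here
  }

  // headOption works uniformly across collection types.
  def firstViaHeadOption[T](xs: Iterable[T]): Option[T] = xs.headOption
}
```

Passing a Vector to firstViaCons falls through to the default case without any warning, which is why the answer above uses isEmpty and head instead of pattern matching.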