Scala pattern matching with disjunctions not working - scala

I am learning Scala and don't understand why the following is not working.
I want to refactor a (tested) mergeAndCount function which is part of a counting inversions algorithm to utilize pattern matching. Here is the unrefactored method:
def mergeAndCount(b: Vector[Int], c: Vector[Int]): (Int, Vector[Int]) = {
if (b.isEmpty && c.isEmpty)
(0, Vector())
else if (!b.isEmpty && (c.isEmpty || b.head < c.head)) {
val (count, r) = mergeAndCount(b drop 1, c)
(count, b.head +: r)
} else {
val (count, r) = mergeAndCount(b, c drop 1)
(count + b.length, c.head +: r)
}
}
Here is my refactored method mergeAndCount2, which works fine:
def mergeAndCount2(b: Vector[Int], c: Vector[Int]): (Int, Vector[Int]) = (b, c) match {
case (Vector(), Vector()) =>
(0, Vector())
case (bh +: br, Vector()) =>
val (count, r) = mergeAndCount2(br, c)
(count, bh +: r)
case (bh +: br, ch +: cr) if bh < ch =>
val (count, r) = mergeAndCount2(br, c)
(count, bh +: r)
case (_, ch +: cr) =>
val (count, r) = mergeAndCount2(b, cr)
(count + b.length, ch +: r)
}
However, as you can see, the second and third cases contain duplicate code. I therefore wanted to combine them using a disjunction, like this:
case (bh +: br, Vector()) | (bh +: br, ch +: cr) if bh < ch =>
val (count, r) = mergeAndCount2(br, c)
(count, bh +: r)
This gives me an error though (on the case line): illegal variable in pattern alternative.
What am I doing wrong?
Any help (also on style) is greatly appreciated.
Update: thanks to your suggestions here is my result:
@tailrec
def mergeAndCount3(b: Vector[Int], c: Vector[Int], acc : (Int, Vector[Int])): (Int, Vector[Int]) = (b, c) match {
case (Vector(), Vector()) =>
acc
case (bh +: br, _) if c.isEmpty || bh < c.head =>
mergeAndCount3(br, c, (acc._1, acc._2 :+ bh))
case (_, ch +: cr) =>
mergeAndCount3(b, cr, (acc._1 + b.length, acc._2 :+ ch))
}

When pattern matching with a pipe (|), you are not allowed to bind any variable other than the wildcard (_).
This is easy to understand: in the body of your case, what would be the actual type of bh or br for example if your two alternatives match different types?
Edit - from the Scala reference:
8.1.11 Pattern Alternatives
Syntax: Pattern ::= Pattern1 { ‘|’ Pattern1 }
A pattern alternative p1 | ... | pn consists of a number of alternative patterns pi. All alternative patterns are type checked with the expected type of the pattern. They may not bind variables other than wildcards. The alternative pattern matches a value v if at least one its alternatives matches v.
Edit after first comment - you can use the wildcard to match something like this for example:
try {
...
} catch {
case (_: NullPointerException | _: IllegalArgumentException) => ...
}
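To make the rule concrete, here is a minimal sketch of what alternatives may contain: literals, typed wildcards, and a single binder placed outside the alternatives with @. The outer binder is allowed because it names the whole matched value rather than a variable inside one branch. The names AlternativesDemo and describe are made up for this illustration.

```scala
object AlternativesDemo {
  // `describe` is a hypothetical helper, not part of the question's code.
  def describe(x: Any): String = x match {
    case 0 | 1                  => "zero or one"    // literal alternatives
    case _: String | _: Char    => "text-like"      // typed wildcards only
    case n @ (_: Int | _: Long) => s"integral: $n"  // binder outside the alternatives
    case _                      => "other"
  }
}
```

Writing `case (a: Int) | (a: Long)` instead would be rejected, because a would be bound inside the alternatives.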

If you think about it, looking at your case clause, how should the compiler know whether the case body is allowed to use ch and cr or not?
Questions like this make it very hard for the compiler to support disjunction and variable binding in the same case clause, so the combination is not allowed at all.
Your mergeAndCount2 function looks quite fine with respect to pattern matching. I think its most evident problem is that it is not tail-recursive and thus does not run in constant stack space. If you can solve this problem, you will probably end up with something that is less repetitive as well.

You can rewrite the case expression and move the disjunction into the guard (the if part):
case (bh +: br, cr) if cr.isEmpty || bh < cr.head =>
val (count, r) = mergeAndCount2(br, c)
(count, bh +: r)
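Put together, the guard-based version might look like the sketch below. It is the question's mergeAndCount2 with the two duplicated cases merged; the wrapping MergeDemo object is only there so the snippet can be pasted as-is, and the guard tests c directly instead of introducing a new name.

```scala
object MergeDemo {
  // Merges two sorted vectors and counts inversions between them.
  def mergeAndCount(b: Vector[Int], c: Vector[Int]): (Int, Vector[Int]) = (b, c) match {
    case (Vector(), Vector()) =>
      (0, Vector())
    // b is non-empty and its head goes first: either c is exhausted or bh is smaller.
    case (bh +: br, _) if c.isEmpty || bh < c.head =>
      val (count, r) = mergeAndCount(br, c)
      (count, bh +: r)
    // Otherwise take the head of c; it is inverted with every element still in b.
    case (_, ch +: cr) =>
      val (count, r) = mergeAndCount(b, cr)
      (count + b.length, ch +: r)
  }
}
```

For example, merging Vector(1, 3, 5) with Vector(2, 4, 6) should report three inversions: (3, 2), (5, 2) and (5, 4).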
Update:
You can simplify it a little further. Note that new elements are appended to r with :+, because r holds what has been merged so far; prepending would return the merged vector in reverse order:
@tailrec
def mergeAndCount3(b: Vector[Int], c: Vector[Int],
                   count: Int = 0, r: Vector[Int] = Vector()): (Int, Vector[Int]) =
  (b, c) match {
    case (bh +: br, _) if c.isEmpty || bh < c.head =>
      mergeAndCount3(br, c, count, r :+ bh)
    case (_, ch +: cr) =>
      mergeAndCount3(b, cr, count + b.length, r :+ ch)
    case _ => (count, r)
  }

Related

Scala fold right and fold left

I am trying to learn functional programming and Scala, so I'm reading "Functional Programming in Scala" by Chiusano and Bjarnason. I'm having trouble understanding what the fold left and fold right methods do in the case of a list. I've looked around here but I haven't found anything beginner friendly. The code provided by the book is:
def foldRight[A,B](as: List[A], z: B)(f: (A, B) => B): B = as match {
case Nil => z
case Cons(h, t) => f(h, foldRight(t, z)(f))
}
def foldLeft[A,B](l: List[A], z: B)(f: (B, A) => B): B = l match {
case Nil => z
case Cons(h,t) => foldLeft(t, f(z,h))(f)
}
Where Cons and Nil are:
case class Cons[+A](head: A, tail: List[A]) extends List[A]
case object Nil extends List[Nothing]
So what do fold left and fold right actually do? Why are they needed as "utility" methods? There are many other methods that use them, and I have trouble understanding those as well, since I don't get these two.
In my experience, one of the best ways to build intuition is to see how it works on very simple examples:
List(1, 3, 8).foldLeft(100)(_ - _) == ((100 - 1) - 3) - 8 == 88
List(1, 3, 8).foldRight(100)(_ - _) == 1 - (3 - (8 - 100)) == -94
As you can see, foldLeft/foldRight just passes each element of the list, together with the result of the previous application, to the operation in the second parameter list.
It should also be mentioned that, applied to the same list, the two are guaranteed to return equal results when the operation is associative and the start value is its identity element (such as 0 for addition). Subtraction is not associative, which is why the two results above differ.
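A small sketch of that point, using the standard library List: the same list folded with subtraction gives different answers from the left and from the right, while addition with start value 0 gives the same answer either way. The object and value names are only for illustration.

```scala
object FoldOrderDemo {
  val xs = List(1, 3, 8)

  // Subtraction is not associative, so direction matters.
  val leftSub  = xs.foldLeft(100)(_ - _)   // ((100 - 1) - 3) - 8
  val rightSub = xs.foldRight(100)(_ - _)  // 1 - (3 - (8 - 100))

  // Addition is associative and 0 is its identity, so both folds agree.
  val leftSum  = xs.foldLeft(0)(_ + _)
  val rightSum = xs.foldRight(0)(_ + _)
}
```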
Say you have a list of numbers, and you want to add them all up. How would you do that?
You add the first and the second, then take the result of that and add it to the third, take the result of that and add it to the fourth, and so on.
That's what fold lets you do.
List(1,2,3,4,5).foldLeft(0)(_ + _)
The "+" is the function you want to apply, with the first operand being the result of its application to the elements so far, and the second operand being the next element.
As you don't have a "result so far" for the first application, you provide a start value - in this case 0, as it is the identity element for addition.
Say you want to multiply all of your list elements. With fold, that'd be:
List(1,2,3,4,5).foldLeft(1)(_ * _)
Fold has its own Wikipedia page, which you might want to check.
Of course there are also ScalaDoc entries for foldLeft and foldRight.
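To tie this back to the book's definitions, here is a self-contained sketch that uses the hand-rolled List from the question. The sealed trait List[+A] line is an assumption (the question only shows Cons and Nil), and the helpers sum, product and length are written here purely as illustrations of why folds are useful as building blocks.

```scala
object BookList {
  // Assumed base trait; the question shows only the two cases below.
  sealed trait List[+A]
  case object Nil extends List[Nothing]
  final case class Cons[+A](head: A, tail: List[A]) extends List[A]

  // Combines from the right: f(a1, f(a2, ... f(an, z)))
  def foldRight[A, B](as: List[A], z: B)(f: (A, B) => B): B = as match {
    case Nil        => z
    case Cons(h, t) => f(h, foldRight(t, z)(f))
  }

  // Combines from the left: f(... f(f(z, a1), a2) ..., an)
  @annotation.tailrec
  def foldLeft[A, B](l: List[A], z: B)(f: (B, A) => B): B = l match {
    case Nil        => z
    case Cons(h, t) => foldLeft(t, f(z, h))(f)
  }

  // Each helper is a fold with a different start value and operation.
  def sum(xs: List[Int]): Int     = foldLeft(xs, 0)(_ + _)
  def product(xs: List[Int]): Int = foldLeft(xs, 1)(_ * _)
  def length[A](xs: List[A]): Int = foldRight(xs, 0)((_, n) => n + 1)
}
```

Once the recursion lives in the fold, each helper only has to state its start value and how to combine one element with the running result.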
Another way to visualise foldLeft and foldRight in Scala is through string concatenation, which clearly shows how each of them works. Consider the example below:
val listString = List("a", "b", "c") // : List[String] = List(a, b, c)
val leftFoldValue = listString.foldLeft("z")((acc, el) => acc + el) // : String = zabc
val rightFoldValue = listString.foldRight("z")((el, acc) => el + acc) // : String = abcz
Or, in shorthand:
val leftFoldValue = listString.foldLeft("z")(_ + _) // : String = zabc
val rightFoldValue = listString.foldRight("z")(_ + _) // : String = abcz
Explanation:
foldLeft works as ((("z" + "a") + "b") + "c") = (("za" + "b") + "c") = ("zab" + "c") = "zabc"
and foldRight as ("a" + ("b" + ("c" + "z"))) = ("a" + ("b" + "cz")) = ("a" + "bcz") = "abcz"
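The grouping can also be made visible by letting the fold build the parenthesised expression itself instead of evaluating it. This is just a sketch; the object and value names are made up.

```scala
object FoldTraceDemo {
  val letters = List("a", "b", "c")

  // foldLeft: the accumulator is the first argument and grows on the left.
  val leftTrace = letters.foldLeft("z")((acc, el) => s"($acc + $el)")

  // foldRight: the accumulator is the second argument and grows on the right.
  val rightTrace = letters.foldRight("z")((el, acc) => s"($el + $acc)")
}
```

Here leftTrace is "(((z + a) + b) + c)" and rightTrace is "(a + (b + (c + z)))".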

Scala: How to map a subset of a seq to a shorter seq

I am trying to map a subset of a sequence using another (shorter) sequence while preserving the elements that are not in the subset. A toy example below tries to give a flower to females only:
def giveFemalesFlowers(people: Seq[Person], flowers: Seq[Flower]): Seq[Person] = {
require(people.count(_.isFemale) == flowers.length)
magic(people, flowers)(_.isFemale)((p, f) => p.withFlower(f))
}
def magic(people: Seq[Person], flowers: Seq[Flower])(predicate: Person => Boolean)
(mapping: (Person, Flower) => Person): Seq[Person] = ???
Is there an elegant way to implement the magic?
Use an iterator over flowers and consume one element each time the predicate holds. The code would look like this:
val it = flowers.iterator
people.map(p => if (predicate(p)) p.withFlower(it.next()) else p)
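A runnable sketch of the iterator idea as an implementation of magic. The Person and Flower case classes are hypothetical stand-ins, since the question does not show them. The call to toList is a defensive choice of this sketch: it makes sure the map runs strictly and in order, so the iterator is consumed in the same order as the people.

```scala
object FlowerDemo {
  // Hypothetical model classes; the question only names their methods.
  final case class Flower(name: String)
  final case class Person(name: String, isFemale: Boolean, flower: Option[Flower] = None) {
    def withFlower(f: Flower): Person = copy(flower = Some(f))
  }

  def magic(people: Seq[Person], flowers: Seq[Flower])(predicate: Person => Boolean)
           (mapping: (Person, Flower) => Person): Seq[Person] = {
    val it = flowers.iterator
    // Take the next flower only for people who satisfy the predicate.
    people.toList.map(p => if (predicate(p)) mapping(p, it.next()) else p)
  }

  def giveFemalesFlowers(people: Seq[Person], flowers: Seq[Flower]): Seq[Person] = {
    require(people.count(_.isFemale) == flowers.length)
    magic(people, flowers)(_.isFemale)((p, f) => p.withFlower(f))
  }
}
```

The iterator is mutable state hidden inside magic, so this trades some purity for brevity; it.next() would throw if there were fewer flowers than matching people, which the require in giveFemalesFlowers rules out.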
What about zip (aka zipWith) ?
scala> val people = List("m","m","m","f","f","m","f")
people: List[String] = List(m, m, m, f, f, m, f)
scala> val flowers = List("f1","f2","f3")
flowers: List[String] = List(f1, f2, f3)
scala> def comb(xs:List[String],ys:List[String]):List[String] = (xs,ys) match {
| case (x :: xs, y :: ys) if x=="f" => (x+y) :: comb(xs,ys)
| case (x :: xs,ys) => x :: comb(xs,ys)
| case (Nil,Nil) => Nil
| }
scala> comb(people, flowers)
res1: List[String] = List(m, m, m, ff1, ff2, m, ff3)
If the order is not important, you can get this elegant code:
scala> val (men,women) = people.partition(_=="m")
men: List[String] = List(m, m, m, m)
women: List[String] = List(f, f, f)
scala> men ++ (women,flowers).zipped.map(_+_)
res2: List[String] = List(m, m, m, m, ff1, ff2, ff3)
I am going to presume you want to retain all the starting people (not simply filter out the females and lose the males), and in the original order, too.
Hmm, bit ugly, but what I came up with was:
def giveFemalesFlowers(people: Seq[Person], flowers: Seq[Flower]): Seq[Person] = {
require(people.count(_.isFemale) == flowers.length)
people.foldLeft((List[Person]() -> flowers)){ (acc, p) => p match {
case pp: Person if pp.isFemale => ( (pp.withFlower(acc._2.head) :: acc._1) -> acc._2.tail)
case pp: Person => ( (pp :: acc._1) -> acc._2)
} }._1.reverse
}
Basically, a fold-left, initialising the 'accumulator' with a pair made up of an empty list of people and the full list of flowers, then cycling through the people passed in.
If the current person is female, pass it the head of the current list of flowers (field 2 of the 'accumulator'), then set the updated accumulator to be the updated person prepended to the (growing) list of processed people, and the tail of the (shrinking) list of flowers.
If male, just prepend to the list of processed people, leaving the flowers unchanged.
By the end of the fold, field 2 of the 'accumulator' (the flowers) should be an empty list, while field one holds all the people (with any females having each received their own flower), in reverse order, so finish with ._1.reverse
Edit: attempt to clarify the code (and substitute a test more akin to @elm's to replace the match, too) - hope that makes it clearer what is going on, @Felix! (and no, no offence taken):
def giveFemalesFlowers(people: Seq[Person], flowers: Seq[Flower]): Seq[Person] = {
require(people.count(_.isFemale) == flowers.length)
val start: (List[Person], Seq[Flower]) = (List[Person](), flowers)
val result: (List[Person], Seq[Flower]) = people.foldLeft(start){ (acc, p) =>
val (pList, fList) = acc
if (p.isFemale) {
(p.withFlower(fList.head) :: pList, fList.tail)
} else {
(p :: pList, fList)
}
}
result._1.reverse
}
I'm obviously missing something but isn't it just
people map {
case p if p.isFemale => p.withFlower(f)
case p => p
}

Scala Pattern Matching Enigma

Here's my attempt at the 3rd problem (P03) of the 99 Problems in Scala (http://aperiodic.net/phil/scala/s-99/):
import scala.annotation._
// Find the nth element of a list.
// nth(2, List(1, 1, 2, 3, 5, 8)) = 2
object P03 {
  @tailrec def nth[A](n: Int, ls: List[A]): A = (n, ls) match {
    case (0, h :: t :: Nil) => h
    case (n, _ :: t) => nth(n - 1, t)
    case _ => println(n); throw new IllegalArgumentException
  }
}
The enigma is that this code prints -4 and throws an IllegalArgumentException
The solution of course is to change the first pattern to:
case (0, h :: _) => h
This now prints the correct answer 2
The question is why. What is the subtle difference between:
case (0, h :: t :: Nil) => h
and
case (0, h :: _) => h
Thanks!
The difference is that h :: t :: Nil matches only a list with exactly two elements (h and t; Nil marks the end of the list), while h :: _ matches every non-empty list, i.e. a list that has at least one element. If you check the :: class you'll see:
final case class ::[B](private var hd: B, private[scala] var tl: List[B]) extends List[B]
It has a head and a tail, where the head is the first element of the list and the tail is the rest. Matching on h :: t :: Nil means taking the first element of the list, then the first element of the tail, and then requiring a Nil. Matching on h :: _ means taking the head without caring what is left, as long as there is a head.
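The difference is easy to check with a small sketch. The describe helper is made up for illustration, and nth is the question's function with the first pattern corrected; the k > 0 guard is an addition of this sketch so that a negative index fails immediately instead of walking the list first.

```scala
import scala.annotation.tailrec

object NthDemo {
  // Hypothetical helper showing which lists each pattern accepts.
  def describe[A](ls: List[A]): String = ls match {
    case Nil           => "empty"
    case h :: t :: Nil => s"exactly two elements: $h and $t"
    case h :: _        => s"non-empty, head = $h"
  }

  @tailrec
  def nth[A](n: Int, ls: List[A]): A = (n, ls) match {
    case (0, h :: _)          => h              // any non-empty list will do
    case (k, _ :: t) if k > 0 => nth(k - 1, t)
    case _                    => throw new IllegalArgumentException(s"no element at offset $n")
  }
}
```

In the original code, nth(0, List(2, 3, 5, 8)) falls through the first case because the list has four elements rather than two, so the second case keeps decrementing n past zero until the list runs out.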

How to "find" sequence element and predicate result at once?

Suppose I have a function f(n:Int):Option[String]. I would like to find such 1 <= k <= 10 that f(k) is not None. I can code it as follows:
(1 to 10).find(k => f(k).isDefined)
Now I would like to know both k and f(k):
val k = (1 to 10).find(f(_).isDefined)
val s = f(k)
Unfortunately, this code invokes f(k) twice. How would you find k and f(k) at once?
My first try would be:
(1 to 10).view map {k => (k, f(k))} find {_._2.isDefined}
The use of view avoids creating an intermediate collection. Or, even better, with pattern matching and a partial function:
(1 to 10).view map {k => (k, f(k))} collectFirst {case (k, Some(v)) => (k, v)}
This returns Option[(Int, java.lang.String)] (None if no k with f(k) defined is found).
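Here is a sketch of the view plus collectFirst approach wrapped in a reusable function. The names FindDemo and firstDefined are made up, and f is passed in as a parameter so that the number of calls can be observed from outside.

```scala
object FindDemo {
  // Returns the first key whose f(key) is defined, together with the value.
  // The view keeps map lazy, so f is evaluated only for the keys that
  // collectFirst actually inspects, and once per inspected key.
  def firstDefined[A](keys: Iterable[Int])(f: Int => Option[A]): Option[(Int, A)] =
    keys.view.map(k => (k, f(k))).collectFirst { case (k, Some(v)) => (k, v) }
}
```

With a function that is defined from 4 upwards, FindDemo.firstDefined(1 to 10)(f) should stop after evaluating f for 1, 2, 3 and 4.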
You might also experiment with .zipWithIndex.
A bit shorter - just map and find:
// for testing
def f (n: Int): Option [String] =
if (n > 0) Some ((List.fill (n) ("" + n)).mkString) else None
(-5 to 5).map (i => (i, f(i))).find (e => e._2 != None)
// result in REPL
res67: Option[(Int, Option[String])] = Some((1,Some(1)))
A slightly more verbose version of Tomasz Nurkiewicz's solution:
val xs = (1 to 10).view
xs zip { xs map { f(_) } } collectFirst { case (k, Some(v)) => (k, v) }

Stable identifier required during pattern matching? (Scala)

Trying to produce a list of tuples showing prime factor multiplicity... the idea is to match each integer in a sorted list against the first value in a tuple, using the second value to count. Could probably do it more easily with takeWhile, but meh. Unfortunately my solution won't compile:
def primeFactorMultiplicity (primeFactors: List[Int]) = {
primeFactors.foldRight (List[(Int, Int)]()) ((a, b) => (a, b) match {
case (_, Nil) => (a, 1) :: b
case (b.head._1, _) => (a, b.head._2 + 1) :: b.tail
case _ => (a, 1) :: b
})
}
It says "error: stable identifier required, but b.head._1 found." But changing the second case line to the following works fine:
case (i, _) if (i == b.head._1) => (a, b.head._2 + 1) :: b.tail
Why is this, and why can't the compiler cope if there is such a simple fix?
A variable in a pattern captures the value in that position; it does not do a comparison. If the syntax worked at all, the result would be to put the value of a into b.head._1, overwriting the current value. The purpose of this is to let you use a pattern to pull something out of a complex structure.
b.head._1 is not a valid name to bind the result of the (x, y) tuple extractor to, and it is not a stable identifier either, because head is a method call rather than a val.
Try this instead:
case (x, _) if x == b.head._1 => (a, b.head._2 + 1) :: b.tail
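The two ways of comparing against an existing value inside a pattern can be sketched as follows: backticks around a stable identifier (here a method parameter), or a fresh variable plus a guard. The sketch also includes one possible way to write the function from the question by destructuring the accumulator instead of calling b.head; all names in StableIdDemo are illustrative.

```scala
object StableIdDemo {
  // Backticks turn the name into a comparison against the existing value.
  def classify(expected: Int, actual: Int): String = actual match {
    case `expected` => "same"
    case other      => s"different: $other"
  }

  // Without backticks a lowercase name would bind a new variable,
  // so the comparison has to go into a guard instead.
  def classifyWithGuard(expected: Int, actual: Int): String = actual match {
    case x if x == expected => "same"
    case other              => s"different: $other"
  }

  // Destructure the head of the accumulator and compare in a guard.
  def primeFactorMultiplicity(primeFactors: List[Int]): List[(Int, Int)] =
    primeFactors.foldRight(List.empty[(Int, Int)]) {
      case (p, (q, n) :: rest) if p == q => (p, n + 1) :: rest
      case (p, acc)                      => (p, 1) :: acc
    }
}
```

For the sorted input List(2, 2, 3, 5, 5, 5) this version yields List((2, 2), (3, 1), (5, 3)).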