Replace if-else with functional code - scala

I know that in functional style all if-else blocks are replaced by pattern matching. But how can I handle Maps with pattern matching in Scala? For example, how can I rewrite this code in a more functional style?
val myMap= getMap()
if(myMap.contains(x) && myMap.contains(y)) Some(myMap(x) + myMap(y))
else if(myMap.contains(x + y)){
val z = myMap(x + y)
if (z % 2 == 0) Some(z)
else None
}
else None

Using if-else is totally acceptable in functional programming, because if-else in Scala is merely an expression. The choice between if-else and pattern matching should mainly come down to readability.
Here's my try at rewriting your code. I'm actually not using pattern matching here, but a for-comprehension to sum the values.
def sumOfValues = for{
mx <- myMap.get(x)
my <- myMap.get(y)
} yield mx + my
def valueOfSumIfEven = myMap.get(x+y).filter(_ % 2 == 0)
sumOfValues orElse valueOfSumIfEven

To answer your question directly - you can filter the cases with an additional if:
getMap match {
case myMap if (myMap.contains(x) && myMap.contains(y)) =>
Some(myMap(x) + myMap(y))
case myMap if (myMap.contains(x+y)) =>
myMap(x+y) match {
case z if (z % 2 == 0) => Some(z)
case _ => None
}
case _ => None
}
(Edit: Scala does have else-if.) This is effectively a way of writing if-else-if-else; looking at the produced class files, that is what it actually compiles to. A plain if-else is equivalent to the ternary ?: in Java, and it implicitly returns Unit when the final else is missing.

How about:
myMap.get(x)
.zip(myMap.get(y))
.headOption.map(x => x._1 + x._2)
.orElse(myMap.get(x + y)
.filter(_ % 2 == 0))

Well, you could write your test as follows:
myMap.get(x).flatMap(xVal => myMap.get(y).map(_ + xVal))
.orElse(myMap.get(x+y).filter(_ % 2 == 0))
But what you have already may just be clearer to anyone trying to understand the intent. Note that the first line (from flatMap to the end) can also be written as a for-comprehension, as shown in @ziggystar's answer.
Maybe the modulo test part could be rewritten as a match, if that feels cleaner:
if(myMap.contains(x) && myMap.contains(y))
Some(myMap(x) + myMap(y))
else myMap.get(x + y) match {
case Some(z) if (z % 2 == 0) => Some(z)
case _ => None
}
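To answer the "how do I pattern match on a Map" part more directly, one option is to match on the pair of lookups rather than on the map itself. This is only a sketch: the helper name lookup and the sample map are made up for illustration, and it assumes Int keys and Int values.

```scala
// Hypothetical helper; the name, key type and value type are assumptions.
def lookup(myMap: Map[Int, Int], x: Int, y: Int): Option[Int] =
  (myMap.get(x), myMap.get(y)) match {
    case (Some(a), Some(b)) => Some(a + b)      // both keys are present
    case _ =>
      myMap.get(x + y) match {
        case Some(z) if z % 2 == 0 => Some(z)   // fallback: even value stored under x + y
        case _                     => None
      }
  }

val m = Map(1 -> 10, 2 -> 20, 7 -> 8, 9 -> 5)
val both     = lookup(m, 1, 2) // Some(30)
val fallback = lookup(m, 3, 4) // Some(8): key 7 holds an even value
val odd      = lookup(m, 4, 5) // None: key 9 holds an odd value
val missing  = lookup(m, 5, 6) // None: key 11 is absent
```

Matching on Map.get results avoids the contains-then-apply double lookup of the original code.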

Related

Use foldLeft to replace occurrences of character in String

I'm trying to learn functional Scala and working on a simple problem - replace occurrences of \' or \\ contained in a String:
Here is my code so far:
val data : String = "\' this is a test \\ "
data.toCharArray.foldLeft(""){ (x, y) => x match {
case Nil => y :: Nil
case head :: tail =>
if head == '\'' ''
else if head == '\\' ''
else head :: tail
}
There are multiple errors:
I've not understood something fundamental with fold?
Simple examples of foldLeft such as:
val sum = prices.foldLeft(0.0)(_ + _)
are understandable, but I'm unsure how to use foldLeft in a context where there are conditions. In the problem I posted, the condition is a match on a character.
There are several issues here, starting with some syntactic problems, like missing parentheses around the conditionals. The first real substantive issue is that the initial value (the "" in foldLeft("")) must be the same type as the accumulator, and as the return type. You seem to want a List[Char] as the return type, so you'll need to use something like List.empty[Char] as the initial value.
Next I'd strongly recommend using names like acc and c instead of x and y to indicate more clearly which is the accumulator and which is the current value.
Another issue is that '' also isn't valid Scala syntax—there is no empty character literal. I'll use '_' as the replacement just for the sake of example.
A working implementation might look like this:
val data: String = "\' this is a test \\ "
data.toCharArray.foldLeft(List.empty[Char]) { (acc, c) =>
c match {
case '\'' => acc :+ '_'
case '\\' => acc :+ '_'
case other => acc :+ other
}
}
Which yields:
val data: String = "' this is a test \ "
val res1: List[Char] = List(_, , t, h, i, s, , i, s, , a, , t, e, s, t, , _, )
Which I think is what you're aiming for?
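As a footnote, I'm assuming this is just an exercise, but it's worth noting that using a left fold for an operation like this is extremely inefficient, since you're building up a list by appending: each :+ on a List takes time proportional to the length of the list, so the whole fold is quadratic.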
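If the fold is kept, one common way around the appending cost is to prepend to the accumulator and reverse once at the end. The following is a sketch under the same assumption as above ('_' as an arbitrary replacement character); the names replaced and viaMap are illustrative.

```scala
val data: String = "\' this is a test \\ "

// Prepending with :: is constant time; the single reverse at the end restores the order.
val replaced: List[Char] =
  data.foldLeft(List.empty[Char]) { (acc, c) =>
    c match {
      case '\'' | '\\' => '_' :: acc
      case other       => other :: acc
    }
  }.reverse

// For a one-to-one replacement like this, a plain map does the same job without a fold.
val viaMap: String = data.map {
  case '\'' | '\\' => '_'
  case other       => other
}
```

Both variants visit each character once, so they stay linear in the length of the input.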
There are several errors in this code:
you haven't closed the lambda's bracket
you use List pattern matching on... well, a String, because x here is the result so far (so "" initially) and y is the current element of data (a Char)
This code should look like this:
val data : String = "\' this is a test \\ "
data.toCharArray.foldLeft("") { (result, ch) =>
if (ch == '\'' || ch == '\\') result
else result + ch
}

Scala for expressions in tail recursive form

I have a bit of a problem trying to come up with a valid way to convert a for-expression N-queens solution to a tail-recursive form while still preserving the idiomatic nature achieved by using the for syntax. Any ideas are more than welcome.
def place(boardSize: Int, n: Int): Solutions = n match {
case 0 => List(Nil)
case _ =>
for {
queens <- place(boardSize, n - 1)
y <- 1 to boardSize
queen = (n, y)
if (isSafe(queen, queens))
} yield queen :: queens
}
def isSafe(queen: Queen, others: List[Queen]) = {...}
What you're writing basically corresponds to what's called Depth-First Search (DFS).
Although a recursive implementation of DFS is easily written, it is not tail-recursive. Here's a proposal for a tail-recursive one. Note that I did not test this code, but it should at least give you an idea of how to proceed.
import scala.annotation.tailrec
def solve(): List[List[Int]] = {
@tailrec def solver(fringe: List[List[Int]], solutions: List[List[Int]]): List[List[Int]] = fringe match {
case Nil => solutions
case potentialSol :: fringeTail =>
if(potentialSol.length == n) // We found a solution
solver(fringe.tail, potentialSol.reverse :: solutions)
else { // Keep looking
val unused = (1 to n).toList filterNot potentialSol.contains
val children = for(u <- unused ; partial = u :: fringe.head if isValid(partial)) yield partial
solver(children ++ fringe.tail, solutions)
}
}
solver((1 to n).toList.map(List(_)), Nil).map(_.reverse)
}
If you're concerned about performance, note that this solution is very poor, because it uses slow operations on immutable data structures and because on the JVM you're better off using iteration where performance matters. It will start failing quite rapidly as n increases. Algorithmically, there are far better ways to solve N-Queens than DFS.
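If the main goal is to keep the for syntax while avoiding deep recursion, a different technique from the tail-recursive DFS above is to fold over the rows, extending every partial placement by one row per step. This is a sketch, not a tested replacement for the asker's code: the isSafe body is an assumption (the question elides it), and Queen is taken to be a (row, column) pair.

```scala
type Queen = (Int, Int)              // (row, column), as in the question
type Solutions = List[List[Queen]]

// Assumed safety check: no shared column and no shared diagonal.
// Rows cannot clash because each step places exactly one queen in a new row.
def isSafe(queen: Queen, others: List[Queen]): Boolean = {
  val (r, c) = queen
  others.forall { case (r2, c2) =>
    c != c2 && math.abs(r - r2) != math.abs(c - c2)
  }
}

// foldLeft replaces the recursion: the accumulator holds all safe placements for the rows seen so far.
def place(boardSize: Int, n: Int): Solutions =
  (1 to n).foldLeft(List(List.empty[Queen]): Solutions) { (partial, row) =>
    for {
      queens <- partial
      y      <- 1 to boardSize
      queen   = (row, y)
      if isSafe(queen, queens)
    } yield queen :: queens
  }

val fourQueens = place(4, 4)
```

The body of the for-expression is the same as in the question; only the recursive call place(boardSize, n - 1) has been replaced by the fold's accumulator.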

Scala: transform a collection, yielding 0..many elements on each iteration

Given a collection in Scala, I'd like to traverse this collection and for each object I'd like to emit (yield) from 0 to multiple elements that should be joined together into a new collection.
For example, I expect something like this:
val input = Range(0, 15)
val output = input.somefancymapfunction((x) => {
if (x % 3 == 0)
yield(s"${x}/3")
if (x % 5 == 0)
yield(s"${x}/5")
})
to build an output collection that will contain
(0/3, 0/5, 3/3, 5/5, 6/3, 9/3, 10/5, 12/3)
Basically, I want a superset of what filter (1 → 0..1) and map (1 → 1) allow: a 1 → 0..n mapping.
Solutions I've tried
Imperative solutions
Obviously, it's possible to do so in a non-functional manner, like:
val output = mutable.ListBuffer[String]()
input.foreach((x) => {
if (x % 3 == 0)
output += s"${x}/3"
if (x % 5 == 0)
output += s"${x}/5"
})
Flatmap solutions
I know of flatMap, but it either:
1) becomes really ugly if we're talking about arbitrary number of output elements:
val output = input.flatMap((x) => {
val v1 = if (x % 3 == 0) {
Some(s"${x}/3")
} else {
None
}
val v2 = if (x % 5 == 0) {
Some(s"${x}/5")
} else {
None
}
List(v1, v2).flatten
})
2) requires usage of mutable collections inside it:
val output = input.flatMap((x) => {
val r = ListBuffer[String]()
if (x % 3 == 0)
r += s"${x}/3"
if (x % 5 == 0)
r += s"${x}/5"
r
})
which is actually even worse than using a mutable collection from the very beginning, or
3) requires major logic overhaul:
val output = input.flatMap((x) => {
if (x % 3 == 0) {
if (x % 5 == 0) {
List(s"${x}/3", s"${x}/5")
} else {
List(s"${x}/3")
}
} else if (x % 5 == 0) {
List(s"${x}/5")
} else {
List()
}
})
which, IMHO, also looks ugly and requires duplicating the generating code.
Roll-your-own-map-function
Last, but not least, I can roll my own function of that kind:
def myMultiOutputMap[T, R](coll: TraversableOnce[T], func: (T, ListBuffer[R]) => Unit): List[R] = {
val out = ListBuffer[R]()
coll.foreach((x) => func.apply(x, out))
out.toList
}
which can be used almost like I want:
val output = myMultiOutputMap[Int, String](input, (x, out) => {
if (x % 3 == 0)
out += s"${x}/3"
if (x % 5 == 0)
out += s"${x}/5"
})
Am I really overlooking something and there's no such functionality in standard Scala collection libraries?
Similar questions
This question bears some similarity to Can I yield or map one element into many in Scala? — but that question discusses 1 element → 3 elements mapping, and I want 1 element → arbitrary number of elements mapping.
Final note
Please note that this is not a question about division or divisors; such conditions are included purely for illustrative purposes.
Rather than having a separate case for each divisor, put them in a container and iterate over them in a for comprehension:
val output = for {
n <- input
d <- Seq(3, 5)
if n % d == 0
} yield s"$n/$d"
Or equivalently in a collect nested in a flatMap:
val output = input.flatMap { n =>
Seq(3, 5).collect {
case d if n % d == 0 => s"$n/$d"
}
}
In the more general case where the different cases may have different logic, you can put each case in a separate partial function and iterate over the partial functions:
val output = for {
n <- input
f <- Seq[PartialFunction[Int, String]](
{case x if x % 3 == 0 => s"$x/3"},
{case x if x % 5 == 0 => s"$x/5"})
if f.isDefinedAt(n)
} yield f(n)
You can also use some functional library (e.g. scalaz) to express this:
import scalaz._, Scalaz._
def divisibleBy(byWhat: Int)(what: Int): List[String] =
(what % byWhat == 0).option(s"$what/$byWhat").toList
(0 to 15) flatMap (divisibleBy(3) _ |+| divisibleBy(5))
This uses the semigroup append operation |+|. For Lists this operation means a simple list concatenation. So for functions Int => List[String], this append operation will produce a function that runs both functions and appends their results.
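For readers without scalaz, the effect of |+| on such functions can be sketched in plain Scala. The helper append below is hypothetical and only stands in for the semigroup instance; the range 0 until 15 mirrors Range(0, 15) from the question.

```scala
def divisibleBy(byWhat: Int)(what: Int): List[String] =
  if (what % byWhat == 0) List(s"$what/$byWhat") else Nil

// Hypothetical stand-in for |+| on functions returning lists:
// run both functions on the same input and concatenate their results.
def append[A, B](f: A => List[B], g: A => List[B]): A => List[B] =
  a => f(a) ++ g(a)

val output = (0 until 15).flatMap(append[Int, String](divisibleBy(3), divisibleBy(5)))
```

Because append returns an ordinary function, further cases can be chained by nesting calls to it.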
If you have a complex computation during which you sometimes need to add elements to a global accumulator, you can use a popular approach called the Writer monad.
The preparation in Scala is somewhat bulky, but the results are extremely composable thanks to the Monad interface:
import scalaz.Writer
import scalaz.syntax.writer._
import scalaz.syntax.monad._
import scalaz.std.vector._
import scalaz.syntax.traverse._
type Messages[T] = Writer[Vector[String], T]
def yieldW(a: String): Messages[Unit] = Vector(a).tell
val output = Vector.range(0, 15).traverse { n =>
yieldW(s"$n / 3").whenM(n % 3 == 0) >>
yieldW(s"$n / 5").whenM(n % 5 == 0)
}.run._1
Here is my proposal for a custom function; it might be nicer with the pimp-my-library pattern:
def fancyMap[A, B](list: TraversableOnce[A])(fs: (A => Boolean, A => B)*) = {
def valuesForElement(elem: A) = fs collect { case (predicate, mapper) if predicate(elem) => mapper(elem) }
list flatMap valuesForElement
}
fancyMap[Int, String](0 to 15)((_ % 3 == 0, _ + "/3"), (_ % 5 == 0, _ + "/5"))
You can try collect:
val input = Range(0,15)
val output = input.flatMap { x =>
List(3,5) collect { case n if (x%n == 0) => s"${x}/${n}" }
}
System.out.println(output)
I would use a fold:
val input = Range(0, 15)
val output = input.foldLeft(List[String]()) {
case (acc, value) =>
val acc1 = if (value % 3 == 0) s"$value/3" :: acc else acc
val acc2 = if (value % 5 == 0) s"$value/5" :: acc1 else acc1
acc2
}.reverse
output contains
List(0/3, 0/5, 3/3, 5/5, 6/3, 9/3, 10/5, 12/3)
A fold takes an accumulator (acc), a collection, and a function. The function is called with the initial value of the accumulator, in this case an empty List[String], and each value of the collection. The function should return an updated accumulator.
On each iteration, we take the growing accumulator and, if the inside if statements are true, prepend the calculation to the new accumulator. The function finally returns the updated accumulator.
When the fold is done, it returns the final accumulator, but unfortunately, it is in reverse order. We simply reverse the accumulator with .reverse.
There is a nice paper on folds: A tutorial on the universality and expressiveness of fold, by Graham Hutton.

Scala: Implementing a mathematical recursion with convergence

Many numerical problems are of the form:
initialize: x_0 = ...
iterate: x_i+1 = function(x_i) until convergence, e.g.,
|| x_i+1 - x_i || < epsilon
I'm wondering whether there is a nice way to write such an algorithm using idiomatic Scala. The nature of the problem calls for an Iterator or Stream. However, my current take on this looks really ugly:
val xFinal = Iterator.iterate(xInit) { x_i =>
// update x_i+1
}.toList // necessary to pattern match within takeWhile
.sliding(2) // necessary since takeWhile needs pair-wise comparison
.takeWhile{ case x_i :: x_iPlus1 :: Nil => /* convergence condition */ }
.toList // since the outer container is still an Iterator
.last // to get the last element of the iteration
.last // to get x_iPlus1
This is not only ugly, the pattern matching in takeWhile also causes a warning. Obviously I do not have to pattern-match here, but I would love to keep a strong resemblance to the mathematical original.
Any ideas to make this look more beautiful?
The following minimalist (silly) example may nonetheless illustrate a useful framework to adapt:
def function (i:Int): Int = i+1
def iter (x0: Int): Int = {
val x1 = function(x0)
if (x1 - x0 == 1) x1 else iter(x1)
}
Here is my solution for the example of finding the square root using Newton's method, which reduces in this case to the Babylonian method:
import math.abs
val tol=0.00001
val desiredSqRoot=256
val xFinal = Iterator.iterate(1.0) { x => 0.5*(x+desiredSqRoot/x) }
def converged(l: Seq[Double]): Boolean = l match{
// Seq(...) matches any Seq, not only a List, so it does not depend on what sliding returns
case Seq(x_old, x_new) => abs(x_old-x_new)/x_old < tol
case _ => true
}
xFinal.sliding(2).dropWhile( x=> !converged(x) ).next.last
which results in:
scala> xFinal.sliding(2).dropWhile( x=> !converged(x) ).next.last
res23: Double = 16.00000000000039
In this example we know the value it should converge to, yet I've written the convergence criterion without this knowledge, because in general we don't know this.
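The same pattern can be wrapped in a small reusable helper that stays close to the mathematical formulation. This is a sketch: the name fixedPoint and its parameters are made up here, and it assumes the iteration does converge (otherwise the call never returns).

```scala
// Hypothetical helper: iterate `step` from `x0` until two successive
// values differ by less than `eps`, then return the later one.
def fixedPoint(x0: Double, eps: Double)(step: Double => Double): Double =
  Iterator.iterate(x0)(step)
    .sliding(2)
    .collectFirst { case Seq(prev, next) if math.abs(next - prev) < eps => next }
    .get // safe only under the convergence assumption stated above

// Babylonian method for the square root of 256
val root = fixedPoint(1.0, 1e-9)(x => 0.5 * (x + 256.0 / x))
```

Using collectFirst on the lazily evaluated Iterator avoids both the intermediate toList calls and the non-exhaustive match warning from the question.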

Convert normal recursion to tail recursion

I was wondering if there is some general method to convert a "normal" recursion with foo(...) + foo(...) as the last call to a tail-recursion.
For example (scala):
def pascal(c: Int, r: Int): Int = {
if (c == 0 || c == r) 1
else pascal(c - 1, r - 1) + pascal(c, r - 1)
}
A general solution in functional languages for converting a recursive function to a tail-call equivalent:
A simple way is to wrap the non tail-recursive function in the Trampoline monad.
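def pascalM(c: Int, r: Int): Trampoline[Int] = {
if (c == 0 || c == r) Trampoline.done(1)
else for {
// recurse through pascalM so that every call stays inside the trampoline
a <- Trampoline.suspend(pascalM(c - 1, r - 1))
b <- Trampoline.suspend(pascalM(c, r - 1))
} yield a + b
}
val pascal = pascalM(5, 10).run // column 5 of row 10; the column must not exceed the row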
So the pascal function no longer recurses on the call stack. Instead, the Trampoline monad builds a nested structure describing the computation that needs to be done. Finally, run is a tail-recursive function that walks through this tree-like structure, interpreting it, and returns the value once it reaches the base case.
A paper from Rúnar Bjarnason on the subject of Trampolines: Stackless Scala With Free Monads
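The Scala standard library ships a comparable trampoline in scala.util.control.TailCalls, so the same idea can be tried without an extra dependency. A minimal sketch, assuming 0 <= c <= r; the name pascalT is illustrative:

```scala
import scala.util.control.TailCalls._

// Each recursive call is wrapped in tailcall, so it is suspended
// rather than executed on the call stack; result runs the trampoline.
def pascalT(c: Int, r: Int): TailRec[Int] =
  if (c == 0 || c == r) done(1)
  else for {
    a <- tailcall(pascalT(c - 1, r - 1))
    b <- tailcall(pascalT(c, r - 1))
  } yield a + b

val value = pascalT(2, 4).result
```

This removes the stack-depth problem only; the number of calls is still exponential in the row number.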
In cases where there is a simple modification to the value of a recursive call, that operation can be moved to the front of the recursive function. The classic example of this is Tail recursion modulo cons, where a simple recursive function in this form:
def recur[A](...):List[A] = {
...
x :: recur(...)
}
which is not tail recursive, is transformed into
def recur[A](...): List[A] = {
def consRecur(..., consA: A): List[A] = {
consA :: ...
...
consRecur(..., ...)
}
...
consRecur(..., ...)
}
Alexlv's example is a variant of this.
This is such a well known situation that some compilers (I know of Prolog and Scheme examples but Scalac does not do this) can detect simple cases and perform this optimisation automatically.
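As a concrete instance of this kind of rewrite, here is a sketch that doubles every element of a list. Note that it uses the usual manual substitute, an accumulator that is reversed at the end, rather than the in-place trick an optimising compiler would apply; the function names are made up for illustration.

```scala
import scala.annotation.tailrec

// Not tail recursive: the cons happens after the recursive call returns.
def doubleAll(xs: List[Int]): List[Int] = xs match {
  case Nil    => Nil
  case h :: t => (h * 2) :: doubleAll(t)
}

// Tail recursive: the cons is moved in front of the call by building
// the result in an accumulator, which is reversed once at the end.
def doubleAllTR(xs: List[Int]): List[Int] = {
  @tailrec
  def loop(rest: List[Int], acc: List[Int]): List[Int] = rest match {
    case Nil    => acc.reverse
    case h :: t => loop(t, (h * 2) :: acc)
  }
  loop(xs, Nil)
}

val doubled = doubleAllTR(List(1, 2, 3))
```

The second version handles lists that are long enough to overflow the stack in the first one.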
Problems that combine multiple calls to a recursive function have no such simple solution. The TRMC optimisation is useless here, as you are simply moving the first recursive call to another non-tail position. The only way to reach a tail-recursive solution is to remove all but one of the recursive calls; how to do this is entirely context dependent, and it requires finding an entirely different approach to solving the problem.
As it happens, in some ways your example is similar to the classic Fibonacci sequence problem; in that case the naive but elegant doubly-recursive solution can be replaced by one which loops forward from the 0th number.
def fib (n: Long): Long = n match {
case 0 | 1 => n
case _ => fib( n - 2) + fib( n - 1 )
}
def fib (n: Long): Long = {
def loop(current: Long, next: Long, iteration: Long): Long = {
if (n == iteration)
current
else
loop(next, current + next, iteration + 1)
}
loop(0, 1, 0)
}
For the Fibonacci sequence, this is the most efficient approach (a streams-based solution is just a different expression of this solution that can cache results for subsequent calls). Now, you can also solve your problem by looping forward from c0/r0 (well, c0/r2) and calculating each row in sequence, the difference being that you need to cache the entire previous row. So while this has a similarity to fib, it differs dramatically in the specifics and is also significantly less efficient than your original, doubly-recursive solution.
Here's an approach for your pascal triangle example which can calculate pascal(30,60) efficiently:
def pascal(column: Long, row: Long):Long = {
type Point = (Long, Long)
type Points = List[Point]
type Triangle = Map[Point,Long]
def above(p: Point) = (p._1, p._2 - 1)
def aboveLeft(p: Point) = (p._1 - 1, p._2 - 1)
def find(ps: Points, t: Triangle): Long = ps match {
// Found the ultimate goal
case (p :: Nil) if t contains p => t(p)
// Found an intermediate point: pop the stack and carry on
case (p :: rest) if t contains p => find(rest, t)
// Hit a triangle edge, add it to the triangle
case ((c, r) :: _) if (c == 0) || (c == r) => find(ps, t + ((c,r) -> 1))
// Triangle contains (c - 1, r - 1)...
case (p :: _) if t contains aboveLeft(p) => if (t contains above(p))
// And it contains (c, r - 1)! Add to the triangle
find(ps, t + (p -> (t(aboveLeft(p)) + t(above(p)))))
else
// Does not contain(c, r -1). So find that
find(above(p) :: ps, t)
// If we get here, we don't have (c - 1, r - 1). Find that.
case (p :: _) => find(aboveLeft(p) :: ps, t)
}
require(column >= 0 && row >= 0 && column <= row)
(column, row) match {
case (c, r) if (c == 0) || (c == r) => 1
case p => find(List(p), Map())
}
}
It's efficient, but I think it shows how ugly complex recursive solutions can become as you deform them to become tail recursive. At this point, it may be worth moving to a different model entirely. Continuations or monadic gymnastics might be better.
You want a generic way to transform your function. There isn't one. There are helpful approaches, that's all.
I don't know how theoretical this question is, but a recursive implementation won't be efficient even with tail-recursion. Try computing pascal(30, 60), for example. I don't think you'll get a stack overflow, but be prepared to take a long coffee break.
Instead, consider using a Stream or memoization:
val pascal: Stream[Stream[Long]] =
(Stream(1L)
#:: (Stream from 1 map { i =>
// compute row i
(1L
#:: (pascal(i-1) // take the previous row
sliding 2 // and add adjacent values pairwise
collect { case Stream(a,b) => a + b }).toStream
++ Stream(1L))
}))
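For the memoization alternative mentioned above, a minimal sketch could look like the following. The name pascalMemo and the use of a mutable map as the cache are choices made here for illustration; the function assumes 0 <= c <= r and is not tail recursive, although its recursion depth is bounded by the row number.

```scala
import scala.collection.mutable

// Hypothetical helper: caches every (column, row) cell so each one is computed only once.
def pascalMemo(c: Int, r: Int): Long = {
  val cache = mutable.Map.empty[(Int, Int), Long]
  def go(c: Int, r: Int): Long =
    if (c == 0 || c == r) 1L
    else cache.get((c, r)) match {
      case Some(v) => v
      case None =>
        val v = go(c - 1, r - 1) + go(c, r - 1)
        cache((c, r)) = v // remember the cell for later lookups
        v
    }
  go(c, r)
}

val middle = pascalMemo(10, 20)
```

With the cache in place the number of additions grows roughly with c * r instead of exponentially.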
The accumulator approach
def pascal(c: Int, r: Int): Int = {
def pascalAcc(acc:Int, leftover: List[(Int, Int)]):Int = {
if (leftover.isEmpty) acc
else {
val (c1, r1) = leftover.head
// Edge.
if (c1 == 0 || c1 == r1) pascalAcc(acc + 1, leftover.tail)
// Safe checks.
else if (c1 < 0 || r1 < 0 || c1 > r1) pascalAcc(acc, leftover.tail)
// Add 2 other points to accumulator.
else pascalAcc(acc, (c1 , r1 - 1) :: ((c1 - 1, r1 - 1) :: leftover.tail ))
}
}
pascalAcc(0, List ((c,r) ))
}
It does not overflow the stack, but as Aaron mentioned, it's not fast for a big row and column.
Yes, it's possible. Usually it's done with the accumulator pattern: an internally defined function takes one additional argument, the so-called accumulator, which carries the intermediate result. Here is an example that counts the length of a list.
For example normal recursive version would look like this:
def length[A](xs: List[A]): Int = if (xs.isEmpty) 0 else 1 + length(xs.tail)
That's not a tail-recursive version. To eliminate the final addition, we have to accumulate the values somehow, for example with the accumulator pattern:
def length[A](xs: List[A]) = {
def inner(ys: List[A], acc: Int): Int = {
if (ys.isEmpty) acc else inner(ys.tail, acc + 1)
}
inner(xs, 0)
}
It's a bit longer to code, but I think the idea is clear. Of course you can do it without the inner function, but in that case you have to provide the initial acc value manually.
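One way to drop the inner function without forcing callers to pass the initial value is a default argument. This is a sketch; the enclosing object Lengths is only there so that the method cannot be overridden, which @tailrec requires.

```scala
import scala.annotation.tailrec

object Lengths {
  // acc defaults to 0, so callers can simply write Lengths.length(xs).
  @tailrec
  def length[A](xs: List[A], acc: Int = 0): Int =
    if (xs.isEmpty) acc else length(xs.tail, acc + 1)
}

val n = Lengths.length(List("a", "b", "c"))
```

The trade-off is that the accumulator becomes part of the public signature, so a caller could pass a non-zero starting value.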
I'm pretty sure it's not possible in the simple way you're looking for the general case, but it would depend on how elaborate you permit the changes to be.
A tail-recursive function must be re-writable as a while-loop, but try implementing, for example, a Fractal Tree using while-loops. It's possible, but you need to use an array or collection to store the state for each point, which substitutes for the data otherwise stored in the call stack.
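Applied to the pascal function from the question, that idea could be sketched as follows: a List of pending (column, row) points stands in for the call stack. The name pascalLoop is made up here, and the code assumes 0 <= c <= r.

```scala
// Hypothetical helper: an explicit work list replaces the call stack.
def pascalLoop(c: Int, r: Int): Int = {
  var pending: List[(Int, Int)] = List((c, r))
  var acc = 0
  while (pending.nonEmpty) {
    val (c1, r1) = pending.head
    pending = pending.tail
    if (c1 == 0 || c1 == r1) acc += 1 // an edge of the triangle contributes 1
    else pending = (c1 - 1, r1 - 1) :: (c1, r1 - 1) :: pending // both parents still need visiting
  }
  acc
}

val viaLoop = pascalLoop(2, 4)
```

Like the doubly-recursive original, this still visits an exponential number of points; it only moves the bookkeeping from the call stack to the heap.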
It's also possible to use trampolining.
It is indeed possible. The way I'd do this is to begin with List(1) and keep recursing till you get to the row you want.
Worth noticing that you can optimize it: if c==0 or c==r the value is one, and to calculate let's say column 3 of the 100th row you still only need to calculate the first three elements of the previous rows.
A working tail recursive solution would be this:
import scala.annotation.tailrec
def pascal(c: Int, r: Int): Int = {
@tailrec
def pascalAcc(c: Int, r: Int, acc: List[Int]): List[Int] = {
if (r == 0) acc
else pascalAcc(c, r - 1,
// from let's say 1 3 3 1 builds 0 1 3 3 1 0 , takes only the
// subset that matters (if asking for col c, no cols after c are
// used) and uses sliding to build (0 1) (1 3) (3 3) etc.
(0 +: acc :+ 0).take(c + 2)
.sliding(2, 1).map { x => x.reduce(_ + _) }.toList)
}
if (c == 0 || c == r) 1
else pascalAcc(c, r, List(1))(c)
}
The @tailrec annotation makes the compiler check that the function is actually tail recursive.
It could probably be optimized further: since the rows are symmetric, if c > r/2 then pascal(c, r) == pascal(r - c, r). That is left to the reader ;)