Is tail recursion optimization guaranteed in Scala when it is possible?

Suppose I have the following code:
def foo(x:Int):Unit = {
  if (x == 1) println ("done")
  else foo(scala.util.Random.nextInt(10))
}
Is it guaranteed that the compiler does tail recursion optimization?

Yes. To know for sure, add the @tailrec annotation (from scala.annotation) to your method. The compiler will then report an error if it cannot compile the method using tail recursion.
import scala.annotation.tailrec

@tailrec
def foo(x:Int):Unit = {
  if (x == 1) println ("done")
  else foo(scala.util.Random.nextInt(10))
}
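One informal way to see the optimization at work is to recurse far deeper than a default JVM stack would normally allow. The sketch below is illustrative only: countDown and the depth of 1,000,000 are hypothetical choices for this example, and a version that was not rewritten into a loop would be expected to throw StackOverflowError at that depth on a default-sized stack.

```scala
import scala.annotation.tailrec

object TailRecCheck {
  // Hypothetical example: counts how many steps it takes to reach zero.
  // @tailrec makes compilation fail if the recursive call is not in tail position.
  @tailrec
  def countDown(n: Int, steps: Int = 0): Int =
    if (n <= 0) steps
    else countDown(n - 1, steps + 1) // last action, so scalac compiles it as a loop
}
```

Calling TailRecCheck.countDown(1000000) should return 1000000 without growing the stack.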

No, the Unit return type is irrelevant.
scala> @tailrec def f(i: Int) { if (i >= 0) { println(i); f(i - 1) } }
f: (i: Int)Unit
But:
scala> @tailrec def f(i: Int) { if (i >= 0) { f(i - 1); println(".") } }
<console>:11: error: could not optimize @tailrec annotated method f:
it contains a recursive call not in tail position
The recursive call must be the last thing the method does; the return type does not matter.
Your code in the question is fine, but the title of the question is misleading.
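To illustrate that the return type is irrelevant, here is a sketch of a @tailrec method that returns Long rather than Unit (sumTo is a hypothetical name). The accumulator acc carries the partial result, so nothing is left to do after the recursive call returns.

```scala
import scala.annotation.tailrec

object NonUnitTailRec {
  // Sums 1..n using an accumulator; the return type is Long, not Unit.
  @tailrec
  def sumTo(n: Int, acc: Long = 0L): Long =
    if (n <= 0) acc
    else sumTo(n - 1, acc + n) // tail call: its result is returned unchanged
}
```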

Related

Why return in getOrElse makes tail recursion not possible?

I am confused by the following code: it is artificial, but I still think it is tail recursive. The compiler does not agree and produces an error message:
@annotation.tailrec
def listSize(l : Seq[Any], s: Int = 0): Int = {
  if (l.isEmpty) {
    None.getOrElse( return s )
  }
  listSize(l.tail, s + 1)
}
How does the code above make tail recursion impossible? Why is the compiler telling me:
could not optimize @tailrec annotated method listSize: it contains a recursive call not in tail position
Similar code (with return inside map) compiles fine:
@annotation.tailrec
def listSize(l : Seq[Any], s: Int = 0): Int = {
  if (l.isEmpty) {
    Some(()).map( return s )
  }
  listSize(l.tail, s + 1)
}
Even the code obtained by inlining None.getOrElse by hand compiles fine:
@annotation.tailrec
def listSize(l : Seq[Any], s: Int = 0): Int = {
  if (l.isEmpty) {
    if (None.isEmpty) {
      return s
    } else None.get
  }
  listSize(l.tail, s + 1)
}
On the other hand, code with slightly modified map is rejected by the compiler and produces the error:
@annotation.tailrec
def listSize(l : Seq[Any], s: Int = 0): Int = {
  if (l.isEmpty) {
    Some(()).map( x => return s )
  }
  listSize(l.tail, s + 1)
}
It happens because the return in your first snippet is a non-local one: getOrElse takes its argument by name, so return s ends up inside a compiler-generated closure. Scala uses exceptions to compile non-local return expressions, so the compiler transforms the code from this:
@annotation.tailrec
def listSize(l : Seq[Any], s: Int = 0): Int = {
  if (l.isEmpty) {
    None.getOrElse( return s )
  }
  listSize(l.tail, s + 1)
}
into something similar to this (you can inspect it by compiling with scalac -Xprint:tailcalls):
def listSize2(l : Seq[Any], s: Int = 0): Int = {
  val key = new Object
  try {
    if (l.isEmpty) {
      None.getOrElse {
        // the non-local `return s` becomes a throw carrying s
        throw new scala.runtime.NonLocalReturnControl(key, s)
      }
    }
    listSize2(l.tail, s + 1)
  } catch {
    case e: scala.runtime.NonLocalReturnControl[Int @unchecked] =>
      if (e.key == key)
        e.value
      else
        throw e
  }
}
The final point is that recursive calls are not tail calls when wrapped in try/catch blocks. Basically, this contrived example:
def self(a: Int): Int = {
  try {
    self(a)
  } catch {
    case e: Exception => 0
  }
}
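is akin to this, which is obviously not tail-recursive:
def self(a: Int): Int = {
  // The frame must stay alive to inspect the result of the recursive call,
  // just as it must stay alive to catch an exception thrown by it.
  if (self(a) > 0) {
    0 // ...
  } else {
    0 // ...
  }
}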
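A common way around this, sketched below under the assumption that the exception only needs to be handled once for the whole computation, is to keep the try/catch outside the recursive helper. sumParsed and loop are hypothetical names chosen for this example.

```scala
import scala.annotation.tailrec

object TryOutsideLoop {
  // Parses and sums strings; returns None if any element is not a number.
  def sumParsed(xs: List[String]): Option[Int] = {
    @tailrec
    def loop(rest: List[String], acc: Int): Int = rest match {
      case Nil    => acc
      case h :: t => loop(t, acc + h.toInt) // h.toInt may throw NumberFormatException
    }
    // The try/catch wraps the whole loop once, so every recursive call
    // inside loop stays in tail position.
    try Some(loop(xs, 0))
    catch { case _: NumberFormatException => None }
  }
}
```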
There are certain cases where you can optimize this (down to two stack frames, if not one), but there doesn't seem to be a universal rule for this kind of situation.
Also, the return expression in this snippet is not a non-local return, which is why the function can be optimized:
@annotation.tailrec
def listSize(l : Seq[Any], s: Int = 0): Int = {
  if (l.isEmpty) {
    // `return` happens _before_ `map` is called.
    Some(()).map( return s )
  }
  listSize(l.tail, s + 1)
}
The above works because, in Scala, return e is an expression, not a statement. Its type is Nothing, which is a subtype of every type, including Unit => X, the parameter type that map requires. Evaluation is then simple: e is returned from the enclosing function before map is even called, because arguments are evaluated before the method call. It may be confusing because you might expect map(return e) to be parsed as map(_ => return e), but it is not.
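A small sketch of that evaluation order; headOrDefault is a hypothetical name, and the code mirrors the snippet above for illustration rather than as recommended style.

```scala
object EagerReturn {
  // `return d` is evaluated while computing map's argument, so the method
  // exits with d before map runs. No closure is created, so the return is local.
  def headOrDefault(l: Seq[Int], d: Int): Int = {
    if (l.isEmpty) {
      Some(()).map(return d)
    }
    l.head
  }
}
```

For an empty sequence the method returns d; otherwise it falls through to l.head.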
This is almost surely a bug in the compiler, or a partially implemented feature.
It most likely has to do with how return inside an expression is implemented in Scala. Non-local returns are implemented using exceptions: when the return executes, a NonLocalReturnControl exception is thrown, and the whole method body is wrapped in a try-catch. I bet x => return s is converted into a nested expression which, once wrapped in a try-catch, confuses the compiler when it determines whether it can apply @tailrec. I would go so far as to say that using @tailrec together with non-local return should be avoided.
Read more about the implementation of return in Scala in this blog post or in this question.
Using return always breaks tail recursion. You should change your code into something like this:
import scala.annotation.tailrec

@tailrec
def listSize(l : Seq[Any], s: Int = 0): Int = {
  l match {
    case _ +: tail => listSize(tail, s + 1) // works for any Seq, not only List
    case _ => s                             // empty sequence
  }
}
I can't try it out right now, but will this fix the problem?
@annotation.tailrec
def listSize(l : Seq[Any], s: Int = 0): Int = {
  if (l.isEmpty) {
    None.getOrElse( return s )
  } else {
    listSize(l.tail, s + 1)
  }
}
Using if-else instead of just if will ensure the if expression always returns something.
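For comparison, here is a sketch that drops return entirely. With if/else used as an expression, the recursive call sits in tail position and there is no closure or exception handler for the compiler to generate, so @tailrec should accept it.

```scala
import scala.annotation.tailrec

object ListSize {
  @tailrec
  def listSize(l: Seq[Any], s: Int = 0): Int =
    if (l.isEmpty) s             // plain value, no `return`
    else listSize(l.tail, s + 1) // tail call in the else branch
}
```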

Max elements with tail function in Scala

I'm messing around with the assignments on Coursera's Functional Programming course and I've stumbled upon something weird. This problem requires you to find the max of a List of integers using only the methods isEmpty, head, and tail. My solution is a recursive function that catches an UnsupportedOperationException if there are no more elements. The solution doesn't seem to work, however, and I think it is because the exception is never caught.
/**
 * This method returns the largest element in a list of integers. If the
 * list `xs` is empty it throws a `java.util.NoSuchElementException`.
 *
 * You can use the same methods of the class `List` as mentioned above.
 *
 * ''Hint:'' Again, think of a recursive solution instead of using looping
 * constructs. You might need to define an auxiliary method.
 *
 * @param xs A list of natural numbers
 * @return The largest element in `xs`
 * @throws java.util.NoSuchElementException if `xs` is an empty list
 */
def max(xs: List[Int]): Int =
{
  def maxOfTwo(value1: Int, value2: Int) = {
    if(value1 > value2) value1
    else value2
  }
  println(xs.size)
  try { maxOfTwo(xs.head, max(xs.tail)) }
  catch { case noTail: UnsupportedOperationException => xs.head }
}
When I use the following code, which just replaces UnsupportedOperationException with Exception, everything works perfectly. Am I missing something here?
def max(xs: List[Int]): Int =
{
  def maxOfTwo(value1: Int, value2: Int) = {
    if(value1 > value2) value1
    else value2
  }
  println(xs.size)
  try { maxOfTwo(xs.head, max(xs.tail)) }
  catch { case noTail: Exception => xs.head }
}
I think this would be better:
import scala.annotation.tailrec

def max(xs: List[Int]): Option[Int] = {
  @tailrec
  def go(l: List[Int], x: Int): Int = {
    l match {
      case Nil => x
      case h :: t => if (x > h) go(t, x) else go(t, h)
    }
  }
  if (xs.isEmpty) None else Some(go(xs.tail, xs.head))
}
The result type is Option because the list can be empty.
UPDATE:
It fails when UnsupportedOperationException is used because accessing xs.head on an empty list throws NoSuchElementException, which you would also need to catch. It works with Exception because Exception is a base class of both of these exceptions.
You can't catch java.util.NoSuchElementException with an UnsupportedOperationException pattern.
By the way, your code throws an exception twice: the second one is thrown by the catch block when it invokes xs.head.
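A quick way to confirm which exception each accessor throws on an empty List (this reflects the standard library's List; other collection types may behave differently):

```scala
import scala.util.Try

object EmptyListExceptions {
  val empty: List[Int] = Nil

  // Record the class name of the exception each call throws, if any.
  val headFailure: Option[String] = Try(empty.head).failed.toOption.map(_.getClass.getName)
  val tailFailure: Option[String] = Try(empty.tail).failed.toOption.map(_.getClass.getName)
}
```

headFailure should name java.util.NoSuchElementException and tailFailure java.lang.UnsupportedOperationException, which is why a pattern for only the latter misses the head failure.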
How about just thinking functionally?
def sum(xs: List[Int]): Int = {
  if (xs.isEmpty) 0 else xs.head + sum (xs.tail)
}
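Staying within the original assignment's constraints (only isEmpty, head and tail, and a NoSuchElementException for an empty list), a tail-recursive sketch could look like this; maxLoop is a hypothetical helper name.

```scala
import scala.annotation.tailrec

object ListMax {
  def max(xs: List[Int]): Int = {
    // Carry the largest value seen so far; the recursive call is the last action.
    @tailrec
    def maxLoop(rest: List[Int], best: Int): Int =
      if (rest.isEmpty) best
      else maxLoop(rest.tail, if (rest.head > best) rest.head else best)

    if (xs.isEmpty) throw new NoSuchElementException("max of empty list")
    else maxLoop(xs.tail, xs.head)
  }
}
```

No exception is used for control flow here: the empty case is checked once up front.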

How to ensure tail recursion consistently

I'd like a function to be optimized for tail recursion; without the optimization, it fails with a stack overflow.
Sample code:
import scala.util.Try
import scala.annotation.tailrec
object Main {
  val trials = 10
  @tailrec
  val gcd : (Int, Int) => Int = {
    case (a,b) if (a == b) => a
    case (a,b) if (a > b) => gcd (a-b,b)
    case (a,b) if (b > a) => gcd (a, b-a)
  }
  def main(args : Array[String]) : Unit = {
    testTailRec()
  }
  def testTailRec() {
    val outputs : List[Boolean] = Range(0, trials).toList.map(_ + 6000) map { x =>
      Try( gcd(x, 1) ).toOption.isDefined
    }
    outputTestResult(outputs)
  }
  def outputTestResult(source : List[Boolean]) = {
    val failed = source.count(_ == false)
    val initial = source.takeWhile(_ == false).length
    println( s"totally $failed failures, $initial of which at the beginning")
  }
}
Running it will produce the following output:
[info] Running Main
[info] totally 2 failures, 2 of which at the beginning
So the first two runs are performed without optimization and are aborted halfway by a stack overflow, and only later invocations produce the desired result.
There is a workaround: you need to warm up the function with fake runs before actually utilizing it. But it seems clumsy and highly inconvenient. Are there any other means to ensure my recursive function would be optimized for tail recursion before it runs for first time?
update:
I was told to use a two-step definition:
@tailrec
def gcd_worker(a: Int, b: Int): Int = {
  if (a == b) a
  else if (a > b) gcd_worker(a-b,b)
  else gcd_worker(a, b-a)
}
val gcd : (Int,Int) => Int = gcd_worker(_,_)
I would prefer to keep a clean functional-style definition if possible.
I do not think @tailrec applies to a function defined as a val at all. Change it to a def and it will run without errors.
From what I understand, @tailrec [1] needs to be on a method, not a field. I was able to make this tail recursive in the REPL with the following change:
@tailrec
def gcd(a: Int, b: Int): Int = {
  if (a == b) a
  else if (a > b) gcd(a-b,b)
  else gcd(a, b-a)
}
[1] http://www.scala-lang.org/api/current/index.html#scala.annotation.tailrec
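If you want to keep gcd as a function value, one option (a sketch; loop is a hypothetical name) is to put the @tailrec method inside the val's initializer, so the recursion is compiled into a loop from the very first call. Like the code above, it assumes both arguments are positive; with a zero argument the subtraction loop never terminates.

```scala
import scala.annotation.tailrec

object Gcd {
  // A function value whose body delegates to a local @tailrec method.
  val gcd: (Int, Int) => Int = {
    @tailrec
    def loop(a: Int, b: Int): Int =
      if (a == b) a
      else if (a > b) loop(a - b, b)
      else loop(a, b - a)

    (a, b) => loop(a, b)
  }
}
```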

Tail recursion - will this make optimal use of frame and how do I check if compiling as tail recursive?

In the code below (quite trivial max and sum of lists), I have a recursive function called at the end of a method. Will the Scala compiler treat this as tail recursive and optimize the stack frame usage? How can I verify this?
package example
import common._
object Lists {
  def sum(xs: List[Int]): Int = {
    def recSum(current: Int, remaining: List[Int]): Int = {
      if (remaining.isEmpty) current else recSum(current + remaining.head, remaining.drop(1))
    }
    recSum(0, xs)
  }
  def max(xs: List[Int]): Int = {
    def recMax(current: Int, remaining: List[Int], firstIteration: Boolean): Int = {
      if(remaining.isEmpty){
        current
      }else{
        val newMax = if (firstIteration || remaining.head>current) remaining.head else current
        recMax(newMax, remaining.drop(1), false)
      }
    }
    if (xs.isEmpty) throw new NoSuchElementException else recMax(0, xs, true)
  }
}
Add @tailrec before the function definition to make the compiler report an error on methods that are not tail-recursive :)
Also, you can expect the function to be as efficient as an imperative loop (i.e. a while loop) once the compiler has optimized it this way.
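As a sketch, here are the question's helpers with the annotation added. Two changes are assumptions of this example: drop(1) is replaced by tail (equivalent on a non-empty list), and the firstIteration flag is replaced by seeding recMax with xs.head. If either recursive call were not in tail position, compilation would fail; summing a long list is an extra, informal runtime check.

```scala
import scala.annotation.tailrec

object ListsTailRec {
  def sum(xs: List[Int]): Int = {
    @tailrec
    def recSum(current: Int, remaining: List[Int]): Int =
      if (remaining.isEmpty) current
      else recSum(current + remaining.head, remaining.tail)
    recSum(0, xs)
  }

  def max(xs: List[Int]): Int = {
    // Seeding with the first element removes the need for a firstIteration flag.
    @tailrec
    def recMax(current: Int, remaining: List[Int]): Int =
      if (remaining.isEmpty) current
      else recMax(if (remaining.head > current) remaining.head else current, remaining.tail)

    if (xs.isEmpty) throw new NoSuchElementException
    else recMax(xs.head, xs.tail)
  }
}
```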

Summing values in a List

I'm trying to write a Scala function that recursively sums the values in a list. Here is what I have so far:
def sum(xs: List[Int]): Int = {
  val num = List(xs.head)
  if(!xs.isEmpty) {
    sum(xs.tail)
  }
  0
}
I don't know how to sum the individual Int values as part of the function. I am considering defining a new function within sum and using a local variable that accumulates the values as the List is iterated over. But this seems like an imperative approach. Is there an alternative?
Also you can avoid using recursion directly and use some basic abstractions instead:
val l = List(1, 3, 5, 11, -1, -3, -5)
l.foldLeft(0)(_ + _) // same as l.foldLeft(0)((a,b) => a + b)
foldLeft is like reduce() in Python. There is also foldRight, which is also known as accumulate (e.g. in SICP).
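A short sketch of how the folds relate, using the list above; the second pair uses subtraction, chosen here only because it is not associative and so shows the different grouping of foldLeft and foldRight.

```scala
object Folds {
  val l = List(1, 3, 5, 11, -1, -3, -5)

  val viaFoldLeft: Int  = l.foldLeft(0)(_ + _)  // ((0 + 1) + 3) + ...
  val viaFoldRight: Int = l.foldRight(0)(_ + _) // 1 + (3 + (... + 0))
  val viaReduce: Int    = l.reduce(_ + _)       // no initial value; throws on an empty list

  // With a non-associative operator the two folds differ:
  val leftMinus: Int  = List(1, 2, 3).foldLeft(0)(_ - _)  // ((0 - 1) - 2) - 3 = -6
  val rightMinus: Int = List(1, 2, 3).foldRight(0)(_ - _) // 1 - (2 - (3 - 0)) = 2
}
```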
With recursion I often find it worthwhile to think about how you'd describe the process in English, as that often translates to code without too much complication. So...
"How do I calculate the sum of a list of integers recursively?"
"Well, what's the sum of a list, 3 :: restOfList?"
"What's restOfList?"
"It could be anything, you don't know. But remember, we're being recursive - and don't you have a function to calculate the sum of a list?"
"Oh right! Well then the sum would be 3 + sum(restOfList)."
"That's right. But now your only problem is that every sum is defined in terms of another call to sum(), so you'll never be able to get an actual value out. You'll need some sort of base case that everything will actually reach, and that you can provide a value for."
"Hmm, you're right." Thinks...
"Well, since your lists are getting shorter and shorter, what's the shortest possible list?"
"The empty list?"
"Right! And what's the sum of an empty list of ints?"
"Zero - I get it now. So putting it together, the sum of an empty list is zero, and the sum of any other list is its first element added to the sum of the rest of it."
And indeed, the code could read almost exactly like that last sentence:
def sumList(xs: List[Int]): Int = { // recursive methods need an explicit result type
  if (xs.isEmpty) 0
  else xs.head + sumList(xs.tail)
}
(The pattern matching versions, such as that proposed by Kim Stebel, are essentially identical to this, they just express the conditions in a more "functional" way.)
Here's the "standard" recursive approach:
def sum(xs: List[Int]): Int = {
  xs match {
    case x :: tail => x + sum(tail) // if there is an element, add it to the sum of the tail
    case Nil => 0 // if there are no elements, then the sum is 0
  }
}
And here's a tail-recursive function. It is more efficient than the non-tail-recursive version because the compiler turns it into a loop that doesn't push a new frame on the stack for every recursive call:
import scala.annotation.tailrec

def sum(xs: List[Int]): Int = {
  @tailrec
  def inner(xs: List[Int], accum: Int): Int = {
    xs match {
      case x :: tail => inner(tail, accum + x)
      case Nil => accum
    }
  }
  inner(xs, 0)
}
You can't make it any easier:
val list = List(3, 4, 12)
println(list.sum) // result will be 19
Hope it helps :)
Your code is good, but you don't need the temporary value num. In Scala, if is an expression that returns a value, and that value becomes the result of the sum function. So your code can be refactored to:
def sum(xs: List[Int]): Int = {
  if(xs.isEmpty) 0
  else xs.head + sum(xs.tail)
}
If the list is empty, return 0; otherwise, add the head to the sum of the rest of the list.
The canonical implementation with pattern matching:
def sum(xs: List[Int]): Int = xs match { // explicit result type required for recursion
  case Nil => 0
  case x :: tail => x + sum(tail)
}
This isn't tail recursive, but it's easy to understand.
Building heavily on @Kim's answer:
def sum(xs: List[Int]): Int = {
  if (xs.isEmpty) throw new IllegalArgumentException("Empty list provided for sum operation")
  def inner(xs: List[Int]): Int = {
    xs match {
      case Nil => 0
      case x :: tail => x + inner(tail)
    }
  }
  inner(xs)
}
The inner function is recursive, and an appropriate exception is raised when an empty list is provided.
If you are required to write a recursive function using isEmpty, head and tail, and also to throw an exception for an empty list argument:
def sum(xs: List[Int]): Int =
  if (xs.isEmpty) throw new IllegalArgumentException("sum of empty list")
  else if (xs.tail.isEmpty) xs.head
  else xs.head + sum(xs.tail)
def sum(xs: List[Int]): Int = {
  def loop(accum: Int, xs: List[Int]): Int = {
    if (xs.isEmpty) accum
    else loop(accum + xs.head, xs.tail)
  }
  loop(0,xs)
}
def sum(xs: List[Int]): Int = xs.sum
scala> sum(List(1,3,7,5))
res1: Int = 16
scala> sum(List())
res2: Int = 0
To add another possible answer, here is a solution I came up with that is a slight variation of @jgaw's answer and uses the @tailrec annotation:
import scala.annotation.tailrec

def sum(xs: List[Int]): Int = {
  if (xs.isEmpty) throw new Exception // May want to replace this with a case class or other error value
  @tailrec
  def go(l: List[Int], acc: Int): Int = {
    if (l.tail == Nil) l.head + acc // no elements after the current one: end of the list
    else go(l.tail, l.head + acc) // move to the next element, adding the current one to the accumulator
  }
  go(xs, 0)
}
A quick note regarding the check for an empty list: when programming functionally, it is preferable not to throw exceptions and instead return something else (another value, a function, a case class, etc.) to handle errors elegantly and keep flowing through the path of execution rather than stopping it with an exception. I threw one in the example above since we're just looking at recursively summing items in a list.
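A sketch of that alternative, assuming Either[String, Int] is an acceptable result type (the error message is arbitrary):

```scala
import scala.annotation.tailrec

object SafeSum {
  // Returns Left for an empty list instead of throwing.
  def sum(xs: List[Int]): Either[String, Int] = {
    @tailrec
    def go(l: List[Int], acc: Int): Int =
      if (l.isEmpty) acc else go(l.tail, acc + l.head)

    if (xs.isEmpty) Left("sum of empty list") else Right(go(xs, 0))
  }
}
```

Callers then decide what to do with the Left case, for example with match or getOrElse, instead of catching an exception.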
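I tried the following method, without using the substitution approach:
def sum(xs: List[Int]) = {
  val listSize = xs.size
  // a counts down from the list size; b accumulates the sum so far
  def loop(a:Int,b:Int):Int={
    if(a==0||xs.isEmpty)
      b
    else
      loop(a-1,xs(a-1)+b) // note: indexing a List with xs(i) is O(i)
  }
  loop(listSize,0)
}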