Asserting @tailrec on Option.getOrElse - scala

In the following, the line maybeNext.map{rec}.getOrElse(n) uses the Option monad to implement the recurse or escape pattern.
scala> @tailrec
     | def rec(n: Int): Int = {
     |   val maybeNext = if (n >= 99) None else Some(n+1)
     |   maybeNext.map{rec}.getOrElse(n)
     | }
Looks good, however:
<console>:7: error: could not optimize @tailrec annotated method:
it contains a recursive call not in tail position
def rec(n: Int): Int = {
^
I feel that the compiler should be able to sort out tail recursion in this case. It is equivalent to the following (somewhat repulsive, but compilable) sample:
scala> @tailrec
     | def rec(n: Int): Int = {
     |   val maybeNext = if (n >= 99) None else Some(n+1)
     |   if (maybeNext.isEmpty) n
     |   else rec(maybeNext.get)
     | }
rec: (n: Int)Int
Can anyone provide illumination here? Why can't the compiler figure it out? Is it a bug, or an oversight? Is the problem too difficult?
Edit: If I remove the @tailrec from the first example, the method compiles and the loop terminates. The last call is always getOrElse, which is equivalent to "if (option.isEmpty) defaultValue else recurse". I think this could and should be inferred by the compiler.

It is not a bug, it is not an oversight, and it is not a tail recursion.
Yes, you can write the code in a tail recursive manner, but that doesn't mean every equivalent algorithm can be made tail recursive. Let's take this code:
maybeNext.map{rec}.getOrElse(n)
First, the last call is to getOrElse(n). This call is not optional -- it is always made, and it is necessary to adjust the result. But let's ignore that.
The next-to-last call is to map{rec}, not to rec. In fact, rec is not called directly at all in your code! Some other function calls it (and, in fact, it is not the last call on map either), but not your function.
For something to be tail recursive, you need to be able to replace the call with a "goto", so to speak. Like this:
def rec(n: Int): Int = {
  BEGINNING:
  val maybeNext = if (n >= 99) None else Some(n+1)
  if (maybeNext.isEmpty) n
  else {
    n = maybeNext.get
    goto BEGINNING
  }
}
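For comparison, the goto pseudocode above corresponds to an ordinary loop. This is roughly what tail-call elimination turns the compilable isEmpty/get version into. It is an illustration written by hand, not the bytecode scalac actually emits, and GotoAsLoop is a hypothetical name:

```scala
object GotoAsLoop {
  // Hand-written loop equivalent of the tail-recursive isEmpty/get version:
  // "n = maybeNext.get; goto BEGINNING" becomes a reassignment plus another loop pass.
  def rec(start: Int): Int = {
    var n = start
    var done = false
    while (!done) {
      val maybeNext = if (n >= 99) None else Some(n + 1)
      if (maybeNext.isEmpty) done = true // escape: n is the result
      else n = maybeNext.get             // "jump" back to the top with the new n
    }
    n
  }
}
```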
How would that happen in the other code?
def rec(n: Int): Int = {
  BEGINNING:
  val maybeNext = if (n >= 99) None else Some(n+1)
  maybeNext.map{x => n = x; goto BEGINNING}.getOrElse(n)
}
The goto here is not inside rec. It is inside an anonymous Function1's apply, which, in turn, is called from inside Option's map, so a branch here would leave two stack frames behind on each call, assuming inter-method branching were possible in the first place.
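A common way to keep the recurse-or-escape shape and still satisfy @tailrec is to pattern match on the Option. The recursive call then sits directly in rec instead of inside a function passed to map. Below is a minimal sketch reusing the question's rec; the object name RecurseOrEscape is only for illustration:

```scala
import scala.annotation.tailrec

object RecurseOrEscape {
  // Same logic as maybeNext.map{rec}.getOrElse(n), but the recursive call
  // is the last thing rec itself does, so it is a genuine tail call.
  @tailrec
  def rec(n: Int): Int = {
    val maybeNext = if (n >= 99) None else Some(n + 1)
    maybeNext match {
      case Some(next) => rec(next) // recurse: nothing happens after this call
      case None       => n         // escape
    }
  }
}
```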

Related

Tail recursion and call by name / value

Learning Scala and functional programming in general. In the following tail-recursive factorial implementation:
def factorialTailRec(n: Int) : Int = {
  @tailrec
  def factorialRec(n: Int, f: => Int): Int = {
    if (n == 0) f else factorialRec(n - 1, n * f)
  }
  factorialRec(n, 1)
}
I wonder whether there is any benefit to having the second parameter called by value vs. called by name (as I have done). In the first case, every stack frame is burdened with a product. In the second case, if my understanding is correct, the entire chain of products will be carried over to the case if (n == 0) at the nth stack frame, so we will still have to perform the same number of multiplications. Unfortunately, this is not a product of the form a^n, which can be calculated in about log_2(n) steps through repeated squaring, but a product of terms that differ by 1 every time. So I can't see any possible way of optimizing the final product: it will still require the multiplication of O(n) terms.
Is this correct? Is call by value equivalent to call by name here, in terms of complexity?
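Let me expand a little on what you've already been told in the comments.
Roughly, the compiler turns a by-name parameter into a thunk (a Function0) and every use of it into a call of that thunk, so the method behaves like this:
@tailrec
def factorialTailRec(n: Int, f: () => Int): Int = {
  if (n == 0) {
    f() // <-- forcing the last thunk runs the whole chain: here we go deeper into the stack.
  } else {
    // nothing is multiplied yet: the new thunk only wraps the previous one
    factorialTailRec(n - 1, () => n * f())
  }
}
The recursive call is still a tail call, which is why @tailrec accepts it, but after n steps f is a chain of n nested thunks, and evaluating it needs n stack frames.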
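To see why forcing a by-name accumulator goes deep into the stack, here is a small sketch that builds the same kind of nested thunks explicitly. It uses foldLeft instead of a by-name parameter (an illustrative substitution) and records how deeply the thunks nest when forced. ThunkChain, depth and maxDepth are hypothetical names:

```scala
object ThunkChain {
  var depth = 0    // current nesting level while forcing the chain
  var maxDepth = 0 // deepest nesting observed

  // Builds n nested thunks without multiplying anything, like passing `n * f` by name.
  def build(n: Int): () => Int =
    (1 to n).foldLeft(() => 1) { (f, k) =>
      () => {
        depth += 1
        maxDepth = math.max(maxDepth, depth)
        val r = k * f() // forcing this thunk forces the previous one first
        depth -= 1
        r
      }
    }
}
```

Building the chain is a plain loop, but calling the result nests one frame per factor, so maxDepth ends up equal to n.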
Through experimentation I found out that with the call by name formalism, the method becomes... non-tail recursive! I made this example code to compare factorial tail-recursively, and factorial non-tail-recursively:
package example
import scala.annotation.tailrec
object Factorial extends App {
  val ITERS = 100000
  def factorialTailRec(n: Int) : Int = {
    @tailrec
    def factorialTailRec(n: Int, f: => Int): Int = {
      if (n == 0) f else factorialTailRec(n - 1, n * f)
    }
    factorialTailRec(n, 1)
  }
  for (i <- 1 to ITERS) println("factorialTailRec(" + i + ") = " + factorialTailRec(i))
  def factorial(n: Int) : Int = {
    if (n == 0) 1 else n * factorial(n - 1)
  }
  for (i <- 1 to ITERS) println("factorial(" + i + ") = " + factorial(i))
}
Observe that the inner factorialTailRec function takes its second argument by name, and yet the @tailrec annotation still does NOT produce a compile-time error!
I've been playing around with different values for the ITERS variable, and for a value of 100,000, I receive a... StackOverflowError!
(The result of zero is there because of overflow of Int.)
So I went ahead and changed the signature of factorialTailRec/2, to:
def factorialTailRec(n: Int, f: Int): Int
i.e. call by value for the argument f. This time, the part of the program that runs factorialTailRec finishes absolutely fine, whereas factorial/1, of course, still crashes at the exact same integer.
Very, very interesting. It seems that with call by name the stack growth is only postponed, not removed: the recursion itself runs as a loop, but once f is finally evaluated, the products still have to be computed all the way back along the call chain.
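For comparison, here is a self-contained sketch of the call-by-value variant. It makes two assumptions that are not in the original. The accumulator is a BigInt, so large results are not wrapped to 0 by Int overflow. And n is assumed to be non-negative:

```scala
import scala.annotation.tailrec

object FactorialByValue {
  // The accumulator is passed by value: each step multiplies immediately,
  // so nothing is left to evaluate when n reaches 0. Assumes n >= 0.
  def factorial(n: Int): BigInt = {
    @tailrec
    def loop(n: Int, acc: BigInt): BigInt =
      if (n == 0) acc else loop(n - 1, acc * n)
    loop(n, BigInt(1))
  }
}
```

With the by-value accumulator the stack depth stays constant, and the O(n) multiplications happen one per loop iteration instead of being deferred.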

Tail recursion - Scala (or any other language)

I have a question about tail recursion. As I understand it, tail recursion is when the last recursive call in a function delivers the result of the function. But when I have a function like this:
def func1(n: Int): Int = {
  if (n > 100) {
    n - 10
  }
  else {
    func1(func1(n + 11))
  }
}
Would it be tail recursion? For example:
func1(100) = func1(func1(111)) = func1(101) = 91
So the last recursive call would be func1(101), and it delivers the result, so that would be tail recursion, right? I'm a little confused. Thank you!
It's not tail-recursive. You could rewrite the code to look like this:
def func1(n: Int): Int = {
  if (n > 100) {
    n - 10
  }
  else {
    val f = func1(n + 11)
    func1(f)
  }
}
You can see that the call to func1 on line 6 (val f = func1(n + 11)) is not in tail position.
An easy way to check is to just try it: the @tailrec annotation (import scala.annotation.tailrec) will give you a compile-time error if your method is not tail recursive.
This is not tail recursive though because you have a recursive call in a non-tail position.
You have two recursive calls in your function. One is in tail position (it's the last call in the method), but the other produces the input to that call, and it isn't a tail call because something comes after it: the outer call. It's not enough to have one recursive call in tail position; every recursive call must be a tail call.
No, your example is not a case of tail recursion.
func1(func1(n + 11)) is a case of non-linear recursion (particularly, nested recursion).
Tail recursion is a particular case of linear recursion that can be immediately converted into iteration (loop), which is why it is interesting, as it allows easy optimization.
In your case, the call to the inner function is not the last operation in the function (there is still pending the call to the outer function), thus, it is not tail recursion.
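Nested recursion like func1 (this is McCarthy's 91 function) can still be made tail recursive, but only by changing the algorithm. One known approach counts, in an extra parameter, how many outer func1 calls are still pending. A sketch follows; func1Tail and pending are illustrative names:

```scala
import scala.annotation.tailrec

object McCarthy91 {
  // The question's nested version, kept for comparison (not tail recursive).
  def func1(n: Int): Int =
    if (n > 100) n - 10 else func1(func1(n + 11))

  // Tail-recursive version: `pending` counts the outer func1 calls still to apply.
  @tailrec
  def func1Tail(n: Int, pending: Int = 1): Int =
    if (pending == 0) n
    else if (n > 100) func1Tail(n - 10, pending - 1) // one pending call finished
    else func1Tail(n + 11, pending + 1)             // func1(func1(n + 11)): one more outer call pending
}
```

For example, func1Tail(100) steps through (100, 1), (111, 2), (101, 1), (91, 0) and returns 91, mirroring func1(100) = func1(func1(111)) = func1(101) = 91.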
Actually, a tail-recursive method is one that 'returns' either the result itself or the next invocation. If you can rewrite your algorithm that way, it can be run tail recursively, for example by a loop that keeps invoking the next step (a trampoline):
trait Recursive[R] {
  def oneIteration: Either[R, Recursive[R]]
}
object Recursive {
  // Runs one iteration at a time inside a while loop, so the stack never grows.
  def interpret[R](fn: Recursive[R]): R = {
    var current = fn
    var result: Option[R] = None
    while (result.isEmpty) {
      current.oneIteration match {
        case Left(r)     => result = Some(r) // done: this is the final result
        case Right(next) => current = next   // one more iteration to run
      }
    }
    result.get
  }
}
object Factorial extends App {
  def factorial(acc: BigDecimal, n: Int): Recursive[BigDecimal] = new Recursive[BigDecimal] {
    override def oneIteration: Either[BigDecimal, Recursive[BigDecimal]] = {
      if (n <= 1) {
        Left(acc)
      }
      else {
        Right(factorial(acc * n, n - 1))
      }
    }
  }
  val res = Recursive.interpret(factorial(1, 5))
  println(res)
}
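The standard library already ships a trampoline along these lines: scala.util.control.TailCalls. Its done, tailcall and TailRec.result roughly play the roles of Left, Right and interpret. The sketch below applies it to the nested func1 from the question; the object name is only for illustration:

```scala
import scala.util.control.TailCalls._

object McCarthy91Trampolined {
  // Each recursive step is wrapped in tailcall(...), so TailRec.result
  // evaluates the nested calls in a loop instead of on the JVM stack.
  def func1(n: Int): TailRec[Int] =
    if (n > 100) done(n - 10)
    else
      for {
        inner <- tailcall(func1(n + 11)) // func1(n + 11)
        outer <- tailcall(func1(inner))  // func1(func1(n + 11))
      } yield outer
}
```

Unlike the counter trick, this keeps the original nested structure of the algorithm and only changes how the calls are executed.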

How do you print an integer value within an object in scala?

object perMissing {
  def solution(A: Array[Int]): Int = {
    def findMissing(i: Int, L: List[Int]): Int = {
      if (L.isEmpty || L.head != i+1) {
        i+1
        println(i+1)
      }
      else findMissing(i+1, L.tail)
    }
    if (A.length == 0) 1
    else findMissing(0, A.toList.sorted)
  }
  solution(Array(2,3,1,5))
}
I'm new to the world of Scala. I come from Python and C world.
How do we print an integer value, e.g. for debugging? For instance, if I want to see the value of i in every iteration.
I compile my code using scalac and run it using scala.
According to the signature of your findMissing function, it should return an Int. However, in the implementation only one of the code paths (namely the else part) returns an Int. The if part returns Unit, because the call to println is the last expression of that block, so the code does not type-check. To fix this, print first and then return the increased value by putting it at the end of the block:
def findMissing(i: Int, l: List[Int]): Int = {
  val inc = i + 1
  if (l.isEmpty || l.head != inc) {
    println(inc)
    inc
  }
  else findMissing(inc, l.tail)
}
Since findMissing is tail recursive, you could additionally annotate it with @tailrec to ensure it will be compiled with tail call optimization.
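Putting the pieces together, here is a runnable sketch that also prints i on every call, which is what the question asked for. It adds an object with a main method, on the assumption that you run it with scalac and scala, which need an entry point. The original object only calls solution from its body, and that body never runs without one:

```scala
import scala.annotation.tailrec

object PerMissing {
  def solution(a: Array[Int]): Int = {
    @tailrec
    def findMissing(i: Int, l: List[Int]): Int = {
      println(s"i = $i")                      // debug output on every iteration
      val inc = i + 1
      if (l.isEmpty || l.head != inc) inc     // last expression is the Int result
      else findMissing(inc, l.tail)           // tail call
    }
    if (a.length == 0) 1
    else findMissing(0, a.toList.sorted)
  }

  def main(args: Array[String]): Unit =
    println(solution(Array(2, 3, 1, 5)))
}
```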

Find max value in a list recursively in scala

I'm new to Scala. Is there a better way to express this, using only the most basic knowledge possible?
def findMax(xs: List[Int]): Int = {
  xs match {
    case x :: tail => (if (tail.length==0) x else (if(x>findMax(tail)) x else (findMax(tail))))
  }
}
There are two problems here. First, you call tail.length, which is an O(N) operation, so in the worst case this will cost you N*N steps, where N is the length of the sequence. The second is that your function is not tail-recursive: you nest the findMax calls "from outside to inside". (On top of that, the else branch calls findMax(tail) a second time, which makes the running time exponential in the worst case, e.g. for an ascending list.)
The usual strategy for writing a correct recursive function is:
to think about each possible pattern case: here you have either the empty list Nil or the non-empty list head :: tail. This solves your first problem.
to carry along the temporary result (here the current guess of the maximum value) as another argument of the function. This solves your second problem.
This gives:
import scala.annotation.tailrec
@tailrec
def findMax(xs: List[Int], max: Int): Int = xs match {
  case head :: tail => findMax(tail, if (head > max) head else max)
  case Nil => max
}
val z = util.Random.shuffle((1 to 100).toList)
assert(findMax(z, Int.MinValue) == 100)
If you don't want to expose this additional argument, you can write an auxiliary inner function.
def findMax(xs: List[Int]): Int = {
  @tailrec
  def loop(ys: List[Int], max: Int): Int = ys match {
    case head :: tail => loop(tail, if (head > max) head else max)
    case Nil => max
  }
  loop(xs, Int.MinValue)
}
val z = util.Random.shuffle((1 to 100).toList)
assert(findMax(z) == 100)
For simplicity we return Int.MinValue if the list is empty. A better solution might be to throw an exception for this case.
The @tailrec annotation here is optional; it simply ensures that we really did define a tail-recursive function. Being tail recursive has the advantage that the function cannot overflow the stack, even if the list is extremely long.
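If returning Int.MinValue for an empty list feels wrong, an alternative to throwing is to return an Option. The loop can then be seeded with the first element instead of a sentinel. The sketch below reuses the same inner loop; FindMaxOption is an illustrative name:

```scala
import scala.annotation.tailrec

object FindMaxOption {
  // None for an empty list, Some(maximum) otherwise; no sentinel value needed.
  def findMax(xs: List[Int]): Option[Int] = {
    @tailrec
    def loop(ys: List[Int], max: Int): Int = ys match {
      case head :: tail => loop(tail, if (head > max) head else max)
      case Nil          => max
    }
    xs match {
      case Nil          => None
      case head :: tail => Some(loop(tail, head)) // seed with the first element
    }
  }
}
```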
Any time you're reducing a collection to a single value, consider using one of the fold functions instead of explicit recursion.
List(3,7,1).fold(Int.MinValue)(Math.max)
// 7
I'm new to Scala too (though I know Haskell!).
My attempt at this would be as below.
Note that I assume a non-empty list, since the max of an empty list does not make sense.
I first define a helper method which simply returns the max of 2 numbers.
def maxOf2 (x:Int, y:Int): Int = {
  if (x >= y) x
  else y
}
Armed with this simple function, we can build a recursive function to find the 'max' as below:
def findMax(xs: List[Int]): Int = {
  if (xs.tail.isEmpty)
    xs.head
  else
    maxOf2(xs.head, findMax(xs.tail))
}
I feel this is a pretty 'clear' (though not 'efficient') way to do it.
I wanted to make the concept of recursion obvious.
Hope this helps!
Elaborating on @fritz's answer: if you pass in an empty list, it will throw a java.lang.UnsupportedOperationException: tail of empty list
So, keeping the algorithm intact, I made this adjustment:
def max(xs: List[Int]): Int = {
  def maxOfTwo(x: Int, y: Int): Int = {
    if (x >= y) x else y
  }
  if (xs.isEmpty) throw new UnsupportedOperationException("What man?")
  else if (xs.size == 1) xs.head
  else maxOfTwo(xs.head, max(xs.tail))
}
@fritz, thanks for the answer.
Using pattern matching and recursion:
def top(xs: List[Int]): Int = xs match {
  case Nil => sys.error("no max in empty list")
  case x :: Nil => x
  case x :: xs => math.max(x, top(xs))
}
Pattern matching is used to decompose the list into head and rest. A single element list is denoted with x :: Nil. We recurse on the rest of the list and compare for maximum on the head item of the list at each recursive stage. To make the cases exhaustive (to make a well-defined function) we consider also empty lists (Nil).
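@scala.annotation.tailrec
def maxl(xl: List[Int]): Int = {                              // assumes a non-empty list
  if (xl.tail.isEmpty) xl.head                                // one element left: it is the max
  else if (xl.head > xl.tail.head) maxl(xl.head :: xl.tail.tail) // drop the smaller second element
  else maxl(xl.tail)                                          // drop the head
}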

How to iterate and add elements to a list in scala without changing state

I have this
val x = List(1,2,3,4,5,6,7,8,9)
I want to take the items of the list from 3 to the end of the list and create a new list with them without changing state,
and get this:
List(4,5,6,7,8,9)
As stated in @Jubobs' comment and @Yuval Itzchakov's answer, the answer is to use the scala.collection.immutable.List.drop(n: Int): List[A] method of lists, whose implementation you can find in src/library/scala/collection/immutable/List.scala.
However, because of the importance of List in the Scala ecosystem, this method is aggressively optimized for performance and not indicative of good Scala style. In particular, even though it is "externally pure", it does use mutation, side-effects and loops on the inside.
An alternative implementation that doesn't use any mutation, loops or side-effects might look like this:
override def drop(n: Int): List[A] =
  if (n == 0) this else tail drop n - 1
This is the trivial recursive implementation. Note: this will throw an exception if you try to drop more items than the list has, while the original one will handle that gracefully by returning the empty list. Re-introducing that behavior is trivial, though:
override def drop(n: Int): List[A] =
  if (n == 0 || isEmpty) this else tail drop n - 1
This method is not just recursive, but actually tail-recursive, so it will be as efficient as a while loop. We can have the compiler yell at us if that isn't true by adding the @scala.annotation.tailrec annotation to the method. The method also has to be final, because the compiler refuses to optimize a method that could be overridden:
@scala.annotation.tailrec final override def drop(n: Int): List[A] =
  if (n == 0 || isEmpty) this else tail drop n - 1
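User code can't actually override List.drop, so here is a standalone sketch of the same tail-recursive idea as a function on List. One deliberate difference, chosen to match the library drop: n <= 0 returns the list unchanged, whereas n == 0 alone would recurse to Nil for negative n. DropSketch is an illustrative name:

```scala
import scala.annotation.tailrec

object DropSketch {
  // Tail-recursive drop; methods of an object are effectively final, so @tailrec is allowed.
  @tailrec
  def drop[A](xs: List[A], n: Int): List[A] =
    if (n <= 0 || xs.isEmpty) xs else drop(xs.tail, n - 1)
}
```

Usage: DropSketch.drop(List(1,2,3,4,5,6,7,8,9), 3) gives List(4,5,6,7,8,9), the same result as x.drop(3).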
As mentioned in the comments, List.drop does exactly that:
val x = List(1,2,3,4,5,6,7,8,9)
val reduced = x.drop(3)
The implementation looks like this:
override def drop(n: Int): List[A] = {
  var these = this
  var count = n
  while (!these.isEmpty && count > 0) {
    these = these.tail
    count -= 1
  }
  these
}