I'd like a recursive function of mine to be optimized for tail recursion. Without the optimization, the function throws a StackOverflowError.
Sample code:
import scala.util.Try
import scala.annotation.tailrec
object Main {
  val trials = 10
  @tailrec
  val gcd : (Int, Int) => Int = {
    case (a,b) if (a == b) => a
    case (a,b) if (a > b) => gcd (a-b,b)
    case (a,b) if (b > a) => gcd (a, b-a)
  }
  def main(args : Array[String]) : Unit = {
    testTailRec()
  }
  def testTailRec() {
    val outputs : List[Boolean] = Range(0, trials).toList.map(_ + 6000) map { x =>
      Try( gcd(x, 1) ).toOption.isDefined
    }
    outputTestResult(outputs)
  }
  def outputTestResult(source : List[Boolean]) = {
    val failed = source.count(_ == false)
    val initial = source.takeWhile(_ == false).length
    println( s"totally $failed failures, $initial of which at the beginning")
  }
}
Running it will produce the following output:
[info] Running Main
[info] totally 2 failures, 2 of which at the beginning
So the first two runs are performed without optimization and are aborted half-way by a StackOverflowError, and only later invocations produce the desired result.
There is a workaround: warm the function up with fake runs before actually using it. But that seems clumsy and highly inconvenient. Is there any other way to ensure my recursive function is optimized for tail recursion before it runs for the first time?
update:
I was told to use a two-step definition
@tailrec
def gcd_worker(a: Int, b: Int): Int = {
  if (a == b) a
  else if (a > b) gcd_worker(a-b,b)
  else gcd_worker(a, b-a)
}
val gcd : (Int,Int) => Int = gcd_worker(_,_)
I would prefer to keep a clean functional-style definition if that is possible.
I do not think @tailrec applies to a function defined as a val at all. Change it to a def and it will run without errors.
From what I understand, @tailrec[1] needs to be on a method, not a field. I was able to get this to be tail recursive in the REPL by making the following change:
@tailrec
def gcd(a: Int, b: Int): Int = {
  if (a == b) a
  else if (a > b) gcd(a-b,b)
  else gcd(a, b-a)
}
[1] http://www.scala-lang.org/api/current/index.html#scala.annotation.tailrec
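If gcd should remain a function value, one possible arrangement (a sketch only, not something the answers above prescribe) is to put the recursion in a local def, where @tailrec does apply, and wrap it in a lambda. The names GcdSketch and loop are made up for illustration, and the subtraction-based algorithm assumes both arguments are positive.

```scala
import scala.annotation.tailrec

object GcdSketch {
  // gcd is still a function value; the recursion lives in a local method,
  // which is where the compiler can check and apply the optimization.
  val gcd: (Int, Int) => Int = {
    @tailrec
    def loop(a: Int, b: Int): Int =
      if (a == b) a
      else if (a > b) loop(a - b, b)
      else loop(a, b - a)

    (a, b) => loop(a, b)
  }
}
```

Because loop is compiled to a plain loop, GcdSketch.gcd(6000, 1) should not depend on any warm-up runs.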
Related
How to write an early-return piece of code in scala with no returns/breaks?
For example
for i in 0..10000000
  if expensive_operation(i)
    return i
return -1
How about
input.find(expensiveOperation).getOrElse(-1)
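A small sketch of why this counts as an early return: find stops calling the predicate as soon as it gets a match. The predicate below is a hypothetical stand-in that counts its own invocations, and FindSketch is a made-up name.

```scala
object FindSketch {
  var calls = 0 // counts predicate invocations, for demonstration only

  // Hypothetical stand-in for the expensive check: matches only 3.
  def expensiveOperation(i: Int): Boolean = {
    calls += 1
    i == 3
  }

  // The predicate is evaluated for 0, 1, 2 and 3, then the search stops.
  val result: Int = (0 to 10000000).find(expensiveOperation).getOrElse(-1)
}
```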
You can use dropWhile
Here an example:
Seq(2,6,8,3,5).dropWhile(_ % 2 == 0).headOption.getOrElse(default = -1) // -> 3
And here you find more scala-takewhile-example
With your example
(0 to 10000000).dropWhile(!expensive_operation(_)).headOption.getOrElse(default = -1)
Since you asked for the intuition needed to solve this kind of problem generically, let me start from the basics.
Scala is (among other things) a functional programming language, and that brings a very important concept: we write programs by composing expressions rather than statements.
Thus, for us the concept of a return value means the evaluation of an expression.
(Note this is related to the concept of referential transparency.)
val a = expr // a is bound to the evaluation of expr,
val b = (a, a) // and they are interchangeable, thus b === (expr, expr)
How does this relate to your question? We do not really have control structures, only complex expressions. For example, an if:
val a = if (expr) exprA else exprB // if is itself an expression, which evaluates to one of its branches.
Thus instead of doing something like this:
def foo(a: Int): Int = {
  if (a != 0) {
    val b = a * a
    return b
  }
  return -1
}
We would do something like:
def foo(a: Int): Int =
  if (a != 0)
    a * a
  else
    -1
This works because we can bind the whole if expression itself as the body of foo.
Now, returning to your specific question: how can we return early from a loop?
The answer is that you can't, at least not without mutation. But you can use a higher-level concept: instead of iterating, you can traverse something, and you can do that using recursion.
Thus, let's implement the find proposed by @Thilo ourselves, as a tail-recursive function.
(It is very important that the function is tail-recursive, so that the compiler optimizes it into something equivalent to a while loop; that way we will not blow up the stack.)
def find(start: Int, end: Int, step: Int = 1)(predicate: Int => Boolean): Option[Int] = {
  @annotation.tailrec
  def loop(current: Int): Option[Int] =
    if (current == end)
      None // Base case.
    else if (predicate(current))
      Some(current) // Early return.
    else
      loop(current + step) // Recursive step.
  loop(current = start)
}
find(0, 10000)(_ == 10)
// res: Option[Int] = Some(10)
Or we may generalize this a little bit more, let's implement find for Lists of any kind of elements.
def find[T](list: List[T])(predicate: T => Boolean): Option[T] = {
  @annotation.tailrec
  def loop(remaining: List[T]): Option[T] =
    remaining match {
      case Nil => None
      case t :: _ if (predicate(t)) => Some(t)
      case _ :: tail => loop(remaining = tail)
    }
  loop(remaining = list)
}
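The same shape can be sketched over an Iterator, which avoids building a List for large or unbounded inputs. This is an illustrative variant, not part of the answer above; IteratorFindSketch and findFirst are made-up names, and the iterator is consumed as a side effect.

```scala
import scala.annotation.tailrec

object IteratorFindSketch {
  // Returns the first element satisfying the predicate, or None if the
  // iterator is exhausted first. Elements are pulled one at a time.
  def findFirst[T](it: Iterator[T])(predicate: T => Boolean): Option[T] = {
    @tailrec
    def loop(): Option[T] =
      if (!it.hasNext) None // Base case.
      else {
        val candidate = it.next()
        if (predicate(candidate)) Some(candidate) // Early return.
        else loop() // Recursive step.
      }
    loop()
  }
}
```

For example, IteratorFindSketch.findFirst(Iterator.from(0))(_ > 1000000) terminates even though the iterator itself is infinite.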
This is not necessarily the best solution from a practical perspective but I still wanted to add it for educational purposes:
import scala.annotation.tailrec
def expensiveOperation(i: Int): Boolean = ???
@tailrec
def findFirstBy[T](f: (T) => Boolean)(xs: Seq[T]): Option[T] = {
  xs match {
    case Seq() => None
    case Seq(head, _*) if f(head) => Some(head)
    case Seq(_, tail @ _*) => findFirstBy(f)(tail)
  }
}
val result = findFirstBy(expensiveOperation)(Range(0, 10000000)).getOrElse(-1)
Please prefer collections methods (dropWhile, find, ...) in your production code.
There are a lot of better answers here, but I think a 'while' could work just fine in this situation.
So, this code
for i in 0..10000000
  if expensive_operation(i)
    return i
return -1
could be rewritten as
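var i = 0
var result = false
while (!result && i < 10000000) {
  result = expensive_operation(i) // test the current index first, so 0 is not skipped
  if (!result) i = i + 1
}
After the 'while', the variable 'result' tells whether the search succeeded; if it did, 'i' holds the first matching index.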
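Wrapped in a method, the loop can hand back the index or -1 directly, without any return or break. This is only a sketch; WhileSketch, firstMatch and the exclusive upper bound limit are assumptions made for illustration.

```scala
object WhileSketch {
  // Returns the first index in [0, limit) that satisfies p, or -1 if none does.
  def firstMatch(limit: Int)(p: Int => Boolean): Int = {
    var i = 0
    var found = false
    while (!found && i < limit) {
      if (p(i)) found = true // stop at the first hit, leaving i on it
      else i += 1
    }
    if (found) i else -1
  }
}
```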
In my Scala app, I have a function that calls a function which returns a result of type Future[T]. I need to pass the mapped result in my recursive function call. I want this to be tail recursive, but the map (or flatMap) is breaking the ability to do that. I get an error "Recursive call not in tail position."
Below is a simple example of this scenario. How can this be modified so that the call will be tail recursive (without subverting the benefits of Futures with an Await.result())?
import scala.annotation.tailrec
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
implicit val ec = scala.concurrent.ExecutionContext.global
object FactorialCalc {
  def factorial(n: Int): Future[Int] = {
    @tailrec
    def factorialAcc(acc: Int, n: Int): Future[Int] = {
      if (n <= 1) {
        Future.successful(acc)
      } else {
        val fNum = getFutureNumber(n)
        fNum.flatMap(num => factorialAcc(num * acc, num - 1))
      }
    }
    factorialAcc(1, n)
  }
  protected def getFutureNumber(n: Int) : Future[Int] = Future.successful(n)
}
Await.result(FactorialCalc.factorial(4), 5.seconds)
I might be mistaken, but your function doesn't need to be tail recursive in this case.
Tail recursion helps us to not consume the stack in case we use recursive functions. In your case, however, we are not actually consuming the stack in the way a typical recursive function would.
This is because the "recursive" call will happen asynchronously, on some thread from the execution context. So it is very likely that this recursive call won't even reside on the same stack as the first call.
The factorialAcc method will create the future object which will eventually trigger the "recursive" call asynchronously. After that, it is immediately popped from the stack.
So this isn't actually stack recursion: the stack doesn't grow in proportion to n but stays at a roughly constant size.
You can easily check this by throwing an exception at some point in the factorialAcc method and inspecting the stack trace.
I rewrote your program to obtain a more readable stack trace:
object Main extends App {
  import scala.concurrent.{Await, Future}
  import scala.concurrent.duration._
  implicit val ec = scala.concurrent.ExecutionContext.global
  def factorialAcc(acc: Int, n: Int): Future[Int] = {
    if (n == 97)
      throw new Exception("n is 97")
    if (n <= 1) {
      Future.successful(acc)
    } else {
      val fNum = getFutureNumber(n)
      fNum.flatMap(num => factorialAcc(num * acc, num - 1))
    }
  }
  def factorial(n: Int): Future[Int] = {
    factorialAcc(1, n)
  }
  protected def getFutureNumber(n: Int) : Future[Int] = Future.successful(n)
  val r = Await.result(factorial(100), 5.seconds)
  println(r)
}
And the output is:
Exception in thread "main" java.lang.Exception: n is 97
at test.Main$.factorialAcc(Main.scala:16)
at test.Main$$anonfun$factorialAcc$1.apply(Main.scala:23)
at test.Main$$anonfun$factorialAcc$1.apply(Main.scala:23)
at scala.concurrent.Future$$anonfun$flatMap$1.apply(Future.scala:278)
at scala.concurrent.Future$$anonfun$flatMap$1.apply(Future.scala:274)
at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:29)
at scala.concurrent.impl.ExecutionContextImpl$$anon$3.exec(ExecutionContextImpl.scala:107)
at scala.concurrent.forkjoin.ForkJoinTask.doExec(ForkJoinTask.java:262)
at scala.concurrent.forkjoin.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:975)
at scala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1478)
at scala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:104)
So you can see that the stack is actually short. If this was stack recursion you should have seen about 97 calls to the factorialAcc method. Instead, you see only one.
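Another way to probe the same claim is to run a Future-based recursion deep enough that ordinary stack recursion would be at risk. The sketch below sums 1 to n instead of computing a factorial (to avoid Int overflow); FutureRecursionSketch, sumAcc and run are made-up names, and it assumes the global execution context. It is not annotated with @tailrec, yet under these assumptions it should complete without a StackOverflowError.

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object FutureRecursionSketch {
  implicit val ec: ExecutionContext = ExecutionContext.global

  // Hypothetical asynchronous source, mirroring getFutureNumber from the question.
  def getFutureNumber(n: Int): Future[Int] = Future.successful(n)

  // Not tail recursive: the self-call sits inside the flatMap callback,
  // which the execution context runs later, outside the caller's frame.
  def sumAcc(acc: Long, n: Int): Future[Long] =
    if (n <= 0) Future.successful(acc)
    else getFutureNumber(n).flatMap(num => sumAcc(acc + num, num - 1))

  // Blocks only here, at the very edge, to read the final result.
  def run(n: Int): Long = Await.result(sumAcc(0L, n), 30.seconds)
}
```

Calling FutureRecursionSketch.run(100000) goes through one hundred thousand "recursive" steps, each scheduled as its own callback.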
How about using foldLeft instead?
def factorial(n: Int): Future[Int] = Future {
  (1 to n).foldLeft(1) { _ * _ }
}
Here's a foldLeft solution that calls another function that returns a future.
def factorial(n: Int): Future[Int] =
  (1 to n).foldLeft(Future.successful(1)) {
    (f, n) => f.flatMap(a => getFutureNumber(n).map(b => a * b))
  }
def getFutureNumber(n: Int) : Future[Int] = Future.successful(n)
Make factorialAcc return an Int and only wrap it in a future in the factorial function.
def factorial(n: Int): Future[Int] = {
  @tailrec
  def factorialAcc(acc: Int, n: Int): Int = {
    if (n <= 1) {
      acc
    } else {
      factorialAcc(n*acc,n-1)
    }
  }
  Future {
    factorialAcc(1, n)
  }
}
should probably work.
I want to use the IO monad.
But this code does not run with a large file.
I am getting a StackOverflowError.
I tried the -DXss option, but it throws the same error.
val main = for {
  l <- getFileLines(file)(collect[String, List]).map(_.run)
  _ <- l.traverse_(putStrLn)
} yield ()
How can I do it?
I wrote an Iteratee that outputs every element.
def putStrLn[E: Show]: IterV[E, IO[Unit]] = {
  import IterV._
  def step(i: IO[Unit])(input: Input[E]): IterV[E, IO[Unit]] =
    input(el = e => Cont(step(i >|> effects.putStrLn(e.shows))),
          empty = Cont(step(i)),
          eof = Done(i, EOF[E]))
  Cont(step(mzero[IO[Unit]]))
}
val main = for {
  i <- getFileLines(file)(putStrLn).map(_.run)
} yield i.unsafePerformIO
This gives the same result.
I think it is caused by the IO implementation.
This is because scalac is not optimizing loop inside getReaderLines for tail calls. loop is tail recursive, but I think the case anonymous function syntax gets in the way.
Edit: actually it's not even tail recursive: the wrapping in the IO monad causes at least one more call after the recursive call. When I was doing my testing yesterday, I was using similar code but had dropped the IO monad, and it was then possible to make the Iteratee tail recursive. The text below assumes no IO monad...
I happened to find that out yesterday while experimenting with iteratees. I think changing the signature of loop to this will help (so for the time being you may have to reimplement getFileLines and getReaderLines):
@annotation.tailrec
def loop(it: IterV[String, A]): IO[IterV[String, A]] = it match {
  // ...
}
We should probably report this to the scalaz folks (and maybe open an enhancement ticket for Scala).
This shows what happens (code vaguely similar to getReaderLines.loop):
@annotation.tailrec
def f(i: Int): Int = i match {
  case 0 => 0
  case x => f(x - 1)
}
// f: (i: Int)Int
@annotation.tailrec
def g: Int => Int = {
  case 0 => 0
  case x => g(x - 1)
}
/* error: could not optimize @tailrec annotated method g:
it contains a recursive call not in tail position
def g: Int => Int = {
^
*/
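One way to read that error: g takes no parameters and returns a function, so g(x - 1) is really g.apply(x - 1). The call to g comes first and apply runs after it, which is why g is not in tail position. The sketch below spells this out and shows one possible workaround with a local method; FunctionValueSketch, h and loop are made-up names, and the first definition is only an approximation of what the compiler generates.

```scala
import scala.annotation.tailrec

object FunctionValueSketch {
  // Roughly what `def g: Int => Int = { case ... }` amounts to.
  // The self-call to g is followed by .apply, so it cannot be the last action.
  def g: Int => Int = new Function1[Int, Int] {
    def apply(i: Int): Int = i match {
      case 0 => 0
      case x => g.apply(x - 1)
    }
  }

  // Possible workaround: recurse in a local method, where @tailrec applies,
  // and return a function that delegates to it.
  def h: Int => Int = {
    @tailrec
    def loop(i: Int): Int = i match {
      case 0 => 0
      case x => loop(x - 1)
    }
    i => loop(i)
  }
}
```

With this arrangement FunctionValueSketch.h(1000000) runs as a loop, while g still uses one stack frame per step.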