I am trying to find a tail recursive fold function for a binary tree. Given the following definitions:
// From the book "Functional Programming in Scala", page 45
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
Implementing a non tail recursive function is quite straightforward:
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B =
t match {
case Leaf(v) => map(v)
case Branch(l, r) =>
red(fold(l)(map)(red), fold(r)(map)(red))
}
But now I am struggling to find a tail-recursive fold function so that the annotation @annotation.tailrec can be used.
During my research I have found several examples where tail-recursive functions on a tree can, for example, compute the sum of all leaves using their own stack, which is basically a List[Tree[Int]]. But as far as I understand, this only works for addition because it does not matter whether you evaluate the left-hand or the right-hand side of the operator first. For a generalised fold, however, the order is quite relevant. To show my intention, here are some example trees:
val leafs = Branch(Leaf(1), Leaf(2))
val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)
Based on those trees I want to create a deep copy with the given fold method like:
val oldNewPairs =
trees.map(t => (t, fold(t)(Leaf(_): Tree[Int])(Branch(_, _))))
And then prove that the equality condition holds for all created copies:
val conditionHolds = oldNewPairs.forall(p => {
if (p._1 == p._2) true
else {
println(s"Original:\n${p._1}\nNew:\n${p._2}")
false
}
})
println("Condition holds: " + conditionHolds)
Could someone give me some pointers, please?
You can find the code used in this question at ScalaFiddle: https://scalafiddle.io/sf/eSKJyp2/15
You could reach a tail recursive solution if you stop using the function call stack and start using a stack managed by your code and an accumulator:
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
  // Marker placed after a branch's children: "combine the last two results".
  case object BranchStub extends Tree[Nothing]

  @tailrec // requires: import scala.annotation.tailrec
  def foldImp(toVisit: List[Tree[A]], acc: Vector[B]): Vector[B] =
    if (toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          val leafRes = map(v)
          foldImp(
            toVisit.tail,
            acc :+ leafRes
          )
        case Branch(l, r) =>
          foldImp(l :: r :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          foldImp(toVisit.tail, acc.dropRight(2) ++ Vector(acc.takeRight(2).reduce(red)))
      }
    }

  foldImp(t :: Nil, Vector.empty).head
}
The idea is to accumulate values from left to right, to keep track of the parent-child relation by introducing a stub node, and to reduce the last two elements of the accumulator with your red function whenever a stub node is encountered during the traversal.
This solution could be optimized, but it is already a tail-recursive implementation.
EDIT:
It can be slightly simplified by changing the accumulator data structure to a list seen as a stack:
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: List[B]): List[B] =
    if (toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          foldImp(
            toVisit.tail,
            map(v) :: acc
          )
        case Branch(l, r) =>
          // r is visited first, so l's result ends up on top of acc
          foldImp(r :: l :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          foldImp(toVisit.tail, acc.take(2).reduce(red) :: acc.drop(2))
      }
    }

  foldImp(t :: Nil, Nil).head
}
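As a variation on the same idea, here is a minimal sketch that keeps the stub out of the Tree hierarchy by using a separate work-item type, and that visits the left child first so map is applied left to right. The names Step, Visit, Combine, go and deep are made up for this illustration and do not come from the answer above.

```scala
import scala.annotation.tailrec

object TreeFoldDemo {
  sealed trait Tree[+A]
  final case class Leaf[A](value: A) extends Tree[A]
  final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

  // Hypothetical work items for the explicit stack: a subtree still to visit,
  // or a marker meaning "combine the two most recent results".
  sealed trait Step[+A]
  final case class Visit[A](tree: Tree[A]) extends Step[A]
  case object Combine extends Step[Nothing]

  def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
    @tailrec
    def go(todo: List[Step[A]], done: List[B]): B = todo match {
      case Visit(Leaf(v)) :: rest =>
        go(rest, map(v) :: done)
      case Visit(Branch(l, r)) :: rest =>
        // Left is visited first; Combine fires once both children are done.
        go(Visit(l) :: Visit(r) :: Combine :: rest, done)
      case Combine :: rest =>
        done match {
          // The right result was pushed last, so it sits on top of the stack.
          case r :: l :: others => go(rest, red(l, r) :: others)
          case _                => sys.error("Combine needs two results")
        }
      case Nil => done.head
    }
    go(List(Visit(t)), Nil)
  }

  // A left-leaning tree, deep enough that a naive recursive fold would
  // likely overflow a default-sized JVM stack.
  val deep: Tree[Int] =
    (1 to 100000).foldLeft(Leaf(0): Tree[Int])((acc, i) => Branch(acc, Leaf(i)))
}
```

With this version, fold(t)(v => Leaf(v): Tree[Int])(Branch(_, _)) rebuilds a copy of t, and fold(deep)(_.toLong)(_ + _) runs in constant stack space because the pending work lives in the two lists.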
Related
How do I rewrite the following loop (pattern) into Scala, either using built-in higher order functions or tail recursion?
This is an example of an iteration pattern where you do a computation (a comparison, for example) on two list elements, but only if the second one comes after the first one in the original input. Note that a +1 step is used here, but in general it could be +n.
public List<U> mapNext(List<T> list) {
  List<U> results = new ArrayList<>();
  for (int i = 0; i < list.size() - 1; i++) {
    for (int j = i + 1; j < list.size(); j++) {
      results.add(doSomething(list.get(i), list.get(j)));
    }
  }
  return results;
}
So far, I've come up with this in Scala:
def mapNext[T, U](list: List[T])(f: (T, T) => U): List[U] = {
@scala.annotation.tailrec
def loop(ix: List[T], jx: List[T], res: List[U]): List[U] = (ix, jx) match {
case (_ :: _ :: is, Nil) => loop(ix, ix.tail, res)
case (i :: _ :: is, j :: Nil) => loop(ix.tail, Nil, f(i, j) :: res)
case (i :: _ :: is, j :: js) => loop(ix, js, f(i, j) :: res)
case _ => res
}
loop(list, Nil, Nil).reverse
}
Edit:
To all contributors, I only wish I could accept every answer as solution :)
Here's my stab. I think it's pretty readable. The intuition is: for each head of the list, apply the function to the head and every other member of the tail. Then recurse on the tail of the list.
def mapNext[U, T](list: List[U], fun: (U, U) => T): List[T] = list match {
case Nil => Nil
case (first :: Nil) => Nil
case (first :: rest) => rest.map(fun(first, _: U)) ++ mapNext(rest, fun)
}
Here's a sample run
scala> mapNext(List(1, 2, 3, 4), (x: Int, y: Int) => x + y)
res6: List[Int] = List(3, 4, 5, 5, 6, 7)
This one isn't explicitly tail recursive, but an accumulator could easily be added to make it so.
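A minimal sketch of what such an accumulator version might look like, keeping the same signature. The helper name loop and the wrapper object MapNextAcc are assumptions made for this illustration. Results are prepended in reverse and the list is reversed once at the end, so no list concatenation is needed.

```scala
import scala.annotation.tailrec

object MapNextAcc {
  def mapNext[U, T](list: List[U], fun: (U, U) => T): List[T] = {
    @tailrec
    def loop(remaining: List[U], acc: List[T]): List[T] = remaining match {
      case first :: rest =>
        // Prepend fun(first, x) for every later x; acc is kept in reverse order.
        loop(rest, rest.foldLeft(acc)((a, x) => fun(first, x) :: a))
      case Nil =>
        acc.reverse
    }
    loop(list, Nil)
  }
}
```

Calling MapNextAcc.mapNext(List(1, 2, 3, 4), (x: Int, y: Int) => x + y) gives the same result as the sample run above.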
Recursion is certainly an option, but the standard library offers some alternatives that will achieve the same iteration pattern.
Here's a very simple setup for demonstration purposes.
val lst = List("a","b","c","d")
def doSomething(a:String, b:String) = a+b
And here's one way to get at what we're after.
val resA = lst.tails.toList.init.flatMap(tl=>tl.tail.map(doSomething(tl.head,_)))
// resA: List[String] = List(ab, ac, ad, bc, bd, cd)
This works but the fact that there's a map() within a flatMap() suggests that a for comprehension might be used to pretty it up.
val resB = for {
tl <- lst.tails
if tl.nonEmpty
h = tl.head
x <- tl.tail
} yield doSomething(h, x) // resB: Iterator[String] = non-empty iterator
resB.toList // List(ab, ac, ad, bc, bd, cd)
In both cases toList is used to convert the result back to the original collection type, which might not actually be necessary depending on what further processing of the collection is required.
Comeback Attempt:
After deleting my first attempt to give an answer I put some more thought into it and came up with another, at least shorter solution.
def mapNext[T, U](list: List[T])(f: (T, T) => U): List[U] = {
@tailrec // requires: import scala.annotation.tailrec
def loop(in: List[T], out: List[U]): List[U] = in match {
case Nil => out
case head :: tail => loop(tail, out ::: tail.map { f(head, _) } )
}
loop(list, Nil)
}
I would also like to recommend the enrich my library pattern for adding the mapNext function to the List api (or with some adjustments to any other collection).
object collection {
object Implicits {
implicit class RichList[A](private val underlying: List[A]) extends AnyVal {
def mapNext[U](f: (A, A) => U): List[U] = {
@tailrec
def loop(in: List[A], out: List[U]): List[U] = in match {
case Nil => out
case head :: tail => loop(tail, out ::: tail.map { f(head, _) } )
}
loop(underlying, Nil)
}
}
}
}
Then you can use the function like:
list.mapNext(doSomething)
Again, there is a downside, as concatenating lists is relatively expensive.
However, variable assignments inside for comprehensions can be quite inefficient, too (as the Dotty improvement task "Scala Wart: Convoluted de-sugaring of for-comprehensions" suggests).
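If the cost of the repeated ::: matters, one possible way around it is to collect the results in a mutable ListBuffer, which appends in constant time. This is only a sketch of that idea, not a replacement for the version above; the wrapper object MapNextBuffer and the local names out and in are assumptions for the example.

```scala
import scala.collection.mutable.ListBuffer

object MapNextBuffer {
  def mapNext[T, U](list: List[T])(f: (T, T) => U): List[U] = {
    val out = ListBuffer.empty[U] // constant-time append
    var in = list
    while (in.nonEmpty) {
      val head = in.head
      // Pair the current head with every element that comes after it.
      in.tail.foreach(x => out += f(head, x))
      in = in.tail
    }
    out.toList
  }
}
```

The mutation stays local to the method, so callers still see a pure function from List[T] to List[U].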
UPDATE
Now that I'm into this, I simply cannot let go :(
Concerning 'Note that the +1 step is used here, but in general, it could be +n.'
I extended my proposal with some parameters to cover more situations:
object collection {
object Implicits {
implicit class RichList[A](private val underlying: List[A]) extends AnyVal {
def mapNext[U](f: (A, A) => U): List[U] = {
@tailrec
def loop(in: List[A], out: List[U]): List[U] = in match {
case Nil => out
case head :: tail => loop(tail, out ::: tail.map { f(head, _) } )
}
loop(underlying, Nil)
}
def mapEvery[U](step: Int)(f: A => U) = {
@tailrec
def loop(in: List[A], out: List[U]): List[U] = {
in match {
case Nil => out.reverse
case head :: tail => loop(tail.drop(step), f(head) :: out)
}
}
loop(underlying, Nil)
}
def mapDrop[U](drop1: Int, drop2: Int, step: Int)(f: (A, A) => U): List[U] = {
@tailrec
def loop(in: List[A], out: List[U]): List[U] = in match {
case Nil => out
case head :: tail =>
loop(tail.drop(drop1), out ::: tail.drop(drop2).mapEvery(step) { f(head, _) } )
}
loop(underlying, Nil)
}
}
}
}
list // [a, b, c, d, ...]
.indices // [0, 1, 2, 3, ...]
.flatMap { i =>
val elem = list(i) // Don't redo access every iteration of the below map.
list.drop(i + 1) // Take only the inputs that come after the one we're working on
.map(doSomething(elem, _))
}
// Or with a monad-comprehension
for {
index <- list.indices
thisElem = list(index)
thatElem <- list.drop(index + 1)
} yield doSomething(thisElem, thatElem)
You start, not with the list, but with its indices. Then, you use flatMap, because each index goes to a list of elements. Use drop to take only the elements after the element we're working on, and map that list to actually run the computation. Note that this has terrible time complexity, because most operations here, indices/length, flatMap, map, are O(n) in the list size, and drop and apply are O(n) in the argument.
You can get better performance if you a) stop using a linked list (List is good for LIFO, sequential access, but Vector is better in the general case), and b) make this a tiny bit uglier:
val len = vector.length
(0 until len)
.flatMap { thisIdx =>
val thisElem = vector(thisIdx)
((thisIdx + 1) until len)
.map { thatIdx =>
doSomething(thisElem, vector(thatIdx))
}
}
// Or
val len = vector.length
for {
thisIdx <- 0 until len
thisElem = vector(thisIdx)
thatIdx <- (thisIdx + 1) until len
thatElem = vector(thatIdx)
} yield doSomething(thisElem, thatElem)
If you really need to, you can generalize either version of this code to all IndexedSeqs, by using some implicit CanBuildFrom parameters, but I won't cover that.
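Short of the CanBuildFrom route, a simpler sketch is to accept any IndexedSeq and return an IndexedSeq, without trying to preserve the concrete collection type of the input. The wrapper object IndexedPairs is an assumption for the example.

```scala
object IndexedPairs {
  // Works for Vector, Array-backed sequences, Range, and so on.
  // The result is whatever IndexedSeq the for comprehension over a Range builds.
  def mapNext[A, B](xs: IndexedSeq[A])(f: (A, A) => B): IndexedSeq[B] =
    for {
      i <- xs.indices
      j <- (i + 1) until xs.length
    } yield f(xs(i), xs(j))
}
```

Because indexed access is effectively constant time on these collections, the total work is proportional to the number of pairs produced.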
In cats, when a Monad is created using the Monad trait, an implementation of the method tailRecM must be provided.
Below is a scenario for which I found it impossible to provide a tail-recursive implementation of tailRecM:
sealed trait Tree[+A]
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
final case class Leaf[A](value: A) extends Tree[A]
implicit val treeMonad = new Monad[Tree] {
override def pure[A](value: A): Tree[A] = Leaf(value)
override def flatMap[A, B](initial: Tree[A])(func: A => Tree[B]): Tree[B] =
initial match {
case Branch(l, r) => Branch(flatMap(l)(func), flatMap(r)(func))
case Leaf(value) => func(value)
}
// @tailrec cannot be used here: the recursive calls inside flatMap are not in tail position
override def tailRecM[A, B](a: A)(func: (A) => Tree[Either[A, B]]): Tree[B] = {
func(a) match {
case Branch(l, r) =>
Branch(
flatMap(l) {
case Right(l) => pure(l)
case Left(l) => tailRecM(l)(func)
},
flatMap(r){
case Right(r) => pure(r)
case Left(r) => tailRecM(r)(func)
}
)
case Leaf(Left(value)) => tailRecM(value)(func)
case Leaf(Right(value)) => Leaf(value)
}
}
}
1) Given the example above, how can this tailRecM method be used to optimize the flatMap method call? Is the implementation of flatMap overridden or modified by tailRecM at compile time?
2) If tailRecM is not tail recursive, as above, will it still be more efficient than using the original flatMap method?
Please share your thoughts.
Sometimes there is a way to replace the call stack with an explicit list.
Here toVisit keeps track of the branches that are waiting to be processed,
and toCollect keeps the branches that are waiting to be merged until the corresponding branch has been processed.
override def tailRecM[A, B](a: A)(f: (A) => Tree[Either[A, B]]): Tree[B] = {
@tailrec // requires: import scala.annotation.tailrec
def go(toVisit: List[Tree[Either[A, B]]],
toCollect: List[Tree[B]]): List[Tree[B]] = toVisit match {
case (tree :: tail) =>
tree match {
case Branch(l, r) =>
l match {
case Branch(_, _) => go(l :: r :: tail, toCollect)
case Leaf(Left(a)) => go(f(a) :: r :: tail, toCollect)
case Leaf(Right(b)) => go(r :: tail, pure(b) +: toCollect)
}
case Leaf(Left(a)) => go(f(a) :: tail, toCollect)
case Leaf(Right(b)) =>
go(tail,
if (toCollect.isEmpty) pure(b) +: toCollect
else Branch(toCollect.head, pure(b)) :: toCollect.tail)
}
case Nil => toCollect
}
go(f(a) :: Nil, Nil).head
}
From the cats ticket on why to use tailRecM:
tailRecM won't blow the stack (like almost every JVM program it may OOM), for any of the Monads in cats.
and then
Without tailRecM (or recursive flatMap) being safe, libraries like
iteratee.io can't safely be written since they require monadic recursion.
and another ticket states that clients of cats.Monad should be aware that some monads don't have a stack-safe tailRecM:
tailRecM can still be used by those that are trying to get stack safety, so long as they understand that certain monads will not be able to give it to them
Relation between tailRecM and flatMap
To answer your first question, the following code is part of FlatMapLaws.scala, from cats-laws. It tests consistency between the flatMap and tailRecM methods.
/**
* It is possible to implement flatMap from tailRecM and map
* and it should agree with the flatMap implementation.
*/
def flatMapFromTailRecMConsistency[A, B](fa: F[A], fn: A => F[B]): IsEq[F[B]] = {
val tailRecMFlatMap = F.tailRecM[Option[A], B](Option.empty[A]) {
case None => F.map(fa) { a => Left(Some(a)) }
case Some(a) => F.map(fn(a)) { b => Right(b) }
}
F.flatMap(fa)(fn) <-> tailRecMFlatMap
}
This shows how to implement a flatMap from tailRecM and implicitly suggests that the compiler will not do such a thing automatically. It's up to the user of the Monad to decide when it makes sense to use tailRecM over flatMap.
This blog has nice Scala examples that explain when tailRecM comes in useful. It follows the PureScript article by Phil Freeman, which originally introduced the method.
It explains the downsides in using flatMap for monadic composition:
This characteristic of Scala limits the usefulness of monadic composition where flatMap can call monadic function f, which then can call flatMap etc..
In contrast with a tailRecM-based implementation:
This guarantees greater safety on the user of FlatMap typeclass, but it would mean that each the implementers of the instances would need to provide a safe tailRecM.
Many of the provided methods in cats leverage monadic composition. So, even if you don't use it directly, implementing tailRecM allows for more efficient composition with other monads.
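To make the shape of tailRecM concrete, here is a small standalone sketch specialised to Option. It deliberately does not depend on cats: the function below only mirrors the signature of cats' tailRecM, and the object name TailRecMOptionDemo and the countDown example are assumptions made for illustration.

```scala
import scala.annotation.tailrec

object TailRecMOptionDemo {
  // Same shape as tailRecM, but written directly for Option.
  @tailrec
  def tailRecM[A, B](a: A)(f: A => Option[Either[A, B]]): Option[B] =
    f(a) match {
      case None              => None              // the effect: abort with no result
      case Some(Left(next))  => tailRecM(next)(f) // Left: loop again with a new state
      case Some(Right(done)) => Some(done)        // Right: finished
    }

  // A monadic loop that would need one flatMap per step if written recursively.
  def countDown(n: Int): Option[String] =
    tailRecM[Int, String](n) { i =>
      if (i < 0) None
      else if (i == 0) Some(Right("done"))
      else Some(Left(i - 1))
    }
}
```

Here countDown(1000000) completes without growing the stack, because each step returns a Left that the loop consumes instead of nesting another flatMap call.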
Implementation for tree
In a different answer, @nazarii-bardiuk provides an implementation of tailRecM which is tail recursive but does not pass the flatMap/tailRecM consistency test mentioned above: the tree structure is not properly rebuilt after recursion. A fixed version is below:
def tailRecM[A, B](arg: A)(func: A => Tree[Either[A, B]]): Tree[B] = {
  @tailrec
  def loop(toVisit: List[Tree[Either[A, B]]],
           toCollect: List[Option[Tree[B]]]): List[Tree[B]] =
    toVisit match {
      case Branch(l, r) :: next =>
        // None marks the position of a branch so the shape can be rebuilt later.
        loop(l :: r :: next, None :: toCollect)
      case Leaf(Left(value)) :: next =>
        loop(func(value) :: next, toCollect)
      case Leaf(Right(value)) :: next =>
        loop(next, Some(pure(value)) :: toCollect)
      case Nil =>
        toCollect.foldLeft(Nil: List[Tree[B]]) { (acc, maybeTree) =>
          maybeTree.map(_ :: acc).getOrElse {
            val left :: right :: tail = acc
            Branch(left, right) :: tail
          }
        }
    }
  loop(List(func(arg)), Nil).head
}
(gist with test)
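For readers who want to try the consistency check without pulling in cats, the following is a self-contained sketch. It restates the tree type and the fixed tailRecM as plain functions and derives flatMap from tailRecM in the same way as the law quoted earlier. The names TreeTailRecMCheck and flatMapViaTailRecM are assumptions for this example, and the partial val pattern is replaced by an explicit match.

```scala
import scala.annotation.tailrec

object TreeTailRecMCheck {
  sealed trait Tree[+A]
  final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
  final case class Leaf[A](value: A) extends Tree[A]

  def pure[A](a: A): Tree[A] = Leaf(a)

  def flatMap[A, B](t: Tree[A])(f: A => Tree[B]): Tree[B] = t match {
    case Branch(l, r) => Branch(flatMap(l)(f), flatMap(r)(f))
    case Leaf(a)      => f(a)
  }

  def map[A, B](t: Tree[A])(f: A => B): Tree[B] = flatMap(t)(a => pure(f(a)))

  def tailRecM[A, B](arg: A)(func: A => Tree[Either[A, B]]): Tree[B] = {
    @tailrec
    def loop(toVisit: List[Tree[Either[A, B]]],
             toCollect: List[Option[Tree[B]]]): List[Tree[B]] =
      toVisit match {
        case Branch(l, r) :: next        => loop(l :: r :: next, None :: toCollect)
        case Leaf(Left(value)) :: next   => loop(func(value) :: next, toCollect)
        case Leaf(Right(value)) :: next  => loop(next, Some(pure(value)) :: toCollect)
        case Nil =>
          // Rebuild the tree: leaves are pushed, each None pops two subtrees.
          toCollect.foldLeft(List.empty[Tree[B]]) {
            case (acc, Some(tree))                 => tree :: acc
            case (left :: right :: tail, None)     => Branch(left, right) :: tail
            case (_, None)                         => sys.error("malformed collect stack")
          }
      }
    loop(List(func(arg)), Nil).head
  }

  // flatMap derived from tailRecM and map, following the cats law.
  def flatMapViaTailRecM[A, B](fa: Tree[A])(fn: A => Tree[B]): Tree[B] =
    tailRecM[Option[A], B](Option.empty[A]) {
      case None    => map(fa)(a => Left(Some(a)): Either[Option[A], B])
      case Some(a) => map(fn(a))(b => Right(b): Either[Option[A], B])
    }
}
```

For a lawful implementation, flatMapViaTailRecM(tree)(fn) and flatMap(tree)(fn) should produce equal trees for any tree and fn; a single example is of course not a proof.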
You're probably aware, but your example (as well as the answer by @nazarii-bardiuk) is used in the book Scala with Cats by Noel Welsh and Dave Gurnell (highly recommended).
In scalaz 7.2.6, I want to implement sequence on Disjunction, such that if there is one or more lefts, it returns a list of those, instead of taking only the first one (as in Disjunction.sequenceU):
import scalaz._, Scalaz._
List(1.right, 2.right, 3.right).sequence
res1: \/-(List(1, 2, 3))
List(1.right, "error2".left, "error3".left).sequence
res2: -\/(List(error2, error3))
I've implemented it as follows and it works, but it looks ugly. Is there a getRight method (such as in scala Either class, Right[String, Int](3).right.get)? And how to improve this code?
implicit class RichSequence[L, R](val l: List[\/[L, R]]) {
def getLeft(v: \/[L, R]):L = v match { case -\/(x) => x }
def getRight(v: \/[L, R]):R = v match { case \/-(x) => x }
def sequence: \/[List[L], List[R]] =
if (l.forall(_.isRight)) {
l.map(e => getRight(e)).right
} else {
l.filter(_.isLeft).map(e => getLeft(e)).left
}
}
Playing around I've implemented a recursive function for that, but the best option would be to use separate:
implicit class RichSequence[L, R](val l: List[\/[L, R]]) {
def sequence: \/[List[L], List[R]] = {
def seqLoop(left: List[L], right: List[R], list: List[\/[L, R]]): \/[List[L], List[R]] =
list match {
case (h :: t) =>
h match {
case -\/(e) => seqLoop(left :+ e, right, t)
case \/-(s) => seqLoop(left, right :+ s, t)
}
case Nil =>
if(left.isEmpty) \/-(right)
else -\/(left)
}
seqLoop(List(), List(), l)
}
def sequenceSeparate: \/[List[L], List[R]] = {
val (left, right) = l.separate[\/[L, R], L, R]
if(left.isEmpty) \/-(right)
else -\/(left)
}
}
The first one just collects the results and decides at the end what to do with them. The second is basically the same, except that the code is much simpler because separate does the collecting. I didn't think about performance here: I've used :+, which is O(n) on a List, so if you care, prepend and reverse at the end, or use some other collection.
You may also want to take a look at Validation and ValidationNel, which unlike Disjunction accumulate failures.
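The same "collect every left, otherwise every right" logic can be sketched with the standard library's Either, which avoids the partial getLeft/getRight helpers from the question. This is an illustration using scala.util.Either rather than scalaz's \/, and the names SequenceAll and sequenceAll are assumptions for the example.

```scala
object SequenceAll {
  def sequenceAll[L, R](xs: List[Either[L, R]]): Either[List[L], List[R]] = {
    // collect is total here: each partial function simply skips the other case.
    val lefts  = xs.collect { case Left(l) => l }
    val rights = xs.collect { case Right(r) => r }
    if (lefts.isEmpty) Right(rights) else Left(lefts)
  }
}
```

The list is traversed twice, which keeps the code short; a single foldRight producing both lists would do it in one pass.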
An example from the Lift cookbook; the pattern matching is a bit curious here.
serve("issues" / "by-state" prefix {
case "open" :: Nil XmlGet _ => <p>None open</p>
case "closed" :: Nil XmlGet _ => <p>None closed</p>
case "closed" :: Nil XmlDelete _ => <p>All deleted</p>
})
I don't understand what the XmlGet _ part is doing.
Could anyone explain a bit?
One of Scala's nice niche features is that many binary operations (e. g. f(x, y)) can be invoked from the infix position x f y. This applies to normal method calls:
case class InfixMethodCalls(x: Int) {
def wild(y: Int): Int = x + y
}
val infix = InfixMethodCalls(3)
infix wild 4
type constructors:
// A simple union type based on http://www.scalactic.org/
trait Or[A, B]
case class Good[A, B](value: A) extends Or[A, B]
case class Bad[A, B](value: B) extends Or[A, B]
def myMethod(x: Int Or String): Int
// This is the same as
def myMethod(x: Or[Int, String]): Int
and unapply / unapplySeq:
object InfixMagic {
def unapply(x: Any) = Option((List(x), x))
}
123 match {
case v :: Nil InfixMagic x => println(s"got v: $v and x: $x")
}
// is the same as
123 match {
case InfixMagic(v :: Nil, x) => println(s"got v: $v and x: $x")
}
So in the case of XmlGet this syntax here:
case "open" :: Nil XmlGet _ =>
is the same as:
case XmlGet("open" :: Nil, _) =>
And the _ is ignoring the Req parameter, which is the second part of the returned value from TestGet.unapply.
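To see the mechanism in isolation, here is a small sketch with a made-up request type and a made-up extractor that has the same shape of unapply result, a pair of path and request. Request, Get and route are hypothetical names for this illustration and are not Lift's actual classes.

```scala
object InfixExtractorDemo {
  // Hypothetical stand-in for a web request: a path and an HTTP method.
  final case class Request(path: List[String], method: String)

  // Hypothetical extractor: succeeds for GET requests and yields (path, request).
  object Get {
    def unapply(r: Request): Option[(List[String], Request)] =
      if (r.method == "GET") Some((r.path, r)) else None
  }

  def route(r: Request): String = r match {
    // Infix form; equivalent to: case Get("open" :: Nil, _)
    case "open" :: Nil Get _   => "None open"
    case "closed" :: Nil Get _ => "None closed"
    case _                     => "no match"
  }
}
```

The list pattern binds tighter than the alphanumeric extractor name, so "open" :: Nil is read as the first component and _ as the second.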
If you go through the RestHelper class of the Lift framework, you will be able to make a few assumptions.
XmlGet and XmlDelete extend the TestGet trait, which has an unapply method taking a Req argument. So this part basically means: check whether it is an XmlGet/XmlDelete method with any request.
How is the list separated from the second part? Good question. I suppose the implicits listStringToSuper and listServeMagic are used for this purpose.
https://github.com/lift/framework/blob/master/web/webkit/src/main/scala/net/liftweb/http/rest/RestHelper.scala
The following code is adapted from a paper (R. O. Bjarnason, Stackless Scala With Free Monads).
The title of the paper points to the purpose of the proposed data structures in general - that is to afford recursive processing in constant stack space, and to let the user express recursion in a clear way.
Specifically, my goal is to have a monadic structure that affords structural rewriting of an immutable Tree of Pairs (binary tree) or of Lists (n-ary tree) based on simple pattern matching in constant stack space when ascending.
sealed trait Free[S[+_], +A]{
private case class FlatMap[S[+_], A, +B](
a: Free[S, A],
f: A => Free[S, B]
) extends Free[S, B]
def map[B](f: A => B): Free[S, B] = this.flatMap((a:A) => Done[S, B](f(a)))
def flatMap[B](f: A => Free[S, B]): Free[S, B] = this match {
case FlatMap(a, g) => FlatMap(a, (x: Any) => g(x).flatMap(f))
case x => FlatMap(x, f)
}
@tailrec // requires: import scala.annotation.tailrec
final def resume(implicit S: Functor[S]): Either[S[Free[S, A]], A] = {
this match {
case Done(a) => Right(a)
case More(k) => Left(k)
case FlatMap(a, f) => a match {
case Done(a) => f(a).resume
case More(k) => Left(S.map(k)((x)=>x.flatMap(f)))
case FlatMap(b, g) => b.flatMap((x: Any) => g(x).flatMap(f)).resume
}
}
}
}
case class Done[S[+_], +A](a: A) extends Free[S, A]
case class More[S[+_], +A](k: S[Free[S, A]]) extends Free[S,A]
trait Functor[F[+_]] {
def map[A, B](m: F[A])(f: A => B): F[B]
}
type RoseTree[+A] = Free[List, A]
implicit object listFunctor extends Functor[List] {
def map[A, B](a: List[A])(f: A => B) = a.map(f)
}
var tree : Free[List, Int]= More(List(More(List(More(List(Done(1), Done(2))), More(List(Done(3), Done(4))))), More(List(More(List(Done(5), Done(6))), More(List(Done(7), Done(8)))))))
How is the rewriting achieved using Free?
Where is a hook for the pattern matcher? - The pattern matcher has to be exposed to each entire subtree when ascending!
Can this be done within a for block?
[The question was edited.]
Update: the answer below addresses an earlier version of the question, but is mostly still relevant.
First of all, your code isn't going to work as it is. You can either make everything invariant, or go with the variance annotations in the original paper. For the sake of simplicity I'll take the invariant route (see here for a complete example), but I've also just confirmed that the version in the paper will work on 2.10.2.
To answer your first question first: your binary tree type is isomorphic to BinTree[Int]. Before we can show this, though, we need a functor for our pair type:
implicit object pairFunctor extends Functor[Pair] {
def map[A, B](a: Pair[A])(f: A => B) = (f(a._1), f(a._2))
}
Now we can use resume, which we'll need for the conversion from BinTree back to T:
def from(tree: T): BinTree[Int] = tree match {
case L(i) => Done(i)
case F((l, r)) => More[Pair, Int]((from(l), from(r)))
}
def to(tree: BinTree[Int]): T = tree.resume match {
case Left((l, r)) => F((to(l), to(r)))
case Right(i) => L(i)
}
Now we can define your example tree:
var z = 0
def f(i: Int): T = if (i > 0) F((f(i - 1), f(i - 1))) else { z = z + 1; L(z) }
val tree = f(3)
Let's demonstrate our isomorphism and the monad for BinTree by replacing every leaf value with the tree containing its predecessor and successor:
val newTree = to(
from(tree).flatMap(i => More[Pair, Int]((Done(i - 1), Done(i + 1))))
)
After some reformatting, the result will look like this:
F((
F((
F((
F((L(0), L(2))),
F((L(1), L(3)))
)),
F((
F((L(2), L(4))),
F((L(3), L(5)))
)),
...
And so on, as expected.
For your second question: if you want to do the same kind of thing for a rose tree, you'd just replace the pair with a list (or a stream). You'll need to provide a functor instance for lists, as we did above for pairs, and then you've got a tree with Done(x) representing leaves and More(xs) for branches.
Your type needs map for the for-comprehension syntax to work. Fortunately you can write map in terms of flatMap and Done: just add the following to the definition of Free:
def map[B](f: A => B): Free[S, B] = this.flatMap(f andThen Done.apply)
Now the following is exactly the same as the newTree above:
val newTree = to(
for {
i <- from(tree)
m <- More[Pair, Int]((Done(i - 1), Done(i + 1)))
} yield m
)
The same thing will work with the Free[List, _] rose tree.
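As a rough illustration of the rose-tree case, here is a self-contained sketch. Note the assumptions: it uses a deliberately simplified, invariant Free with only Done and More (no FlatMap node and no resume), so unlike the version from the paper it is not stack-safe. The helpers leaf, node and leaves are made up for this example.

```scala
object FreeRoseDemo {
  trait Functor[F[_]] {
    def map[A, B](fa: F[A])(f: A => B): F[B]
  }

  implicit object listFunctor extends Functor[List] {
    def map[A, B](fa: List[A])(f: A => B): List[B] = fa.map(f)
  }

  // Simplified Free: flatMap recurses structurally, so deep trees can overflow.
  sealed trait Free[S[_], A] {
    def flatMap[B](f: A => Free[S, B])(implicit S: Functor[S]): Free[S, B] =
      this match {
        case Done(a) => f(a)
        case More(k) => More[S, B](S.map(k)(_.flatMap(f)))
      }
    def map[B](f: A => B)(implicit S: Functor[S]): Free[S, B] =
      flatMap(a => Done[S, B](f(a)))
  }
  final case class Done[S[_], A](a: A) extends Free[S, A]
  final case class More[S[_], A](k: S[Free[S, A]]) extends Free[S, A]

  type RoseTree[A] = Free[List, A]
  def leaf[A](a: A): RoseTree[A] = Done[List, A](a)
  def node[A](children: RoseTree[A]*): RoseTree[A] = More[List, A](children.toList)

  // Flattens a rose tree into its leaves, left to right.
  def leaves[A](t: RoseTree[A]): List[A] = t match {
    case Done(a)        => List(a)
    case More(children) => children.flatMap(c => leaves(c))
  }

  val tree: RoseTree[Int] = node(leaf(1), node(leaf(2), leaf(3)))

  // Replace every leaf by a node holding its predecessor and successor.
  val rewritten: RoseTree[Int] =
    for {
      i <- tree
      m <- node(leaf(i - 1), leaf(i + 1))
    } yield m
}
```

The for block desugars to flatMap and map on Free, so every leaf of tree is swapped for a two-element node while the surrounding list structure is preserved.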