How to make tree mapping tail-recursive? - scala

Suppose I have a tree data structure like this:
trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node
Suppose also I've got a function to map over leaves:
def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = root match {
case ln: LeafNode => f(ln)
case bn: BranchNode => BranchNode(bn.name, bn.children.map(ch => mapLeaves(ch, f)))
}
Now I am trying to make this function tail-recursive but am having a hard time figuring out how to do it. I've read this answer but still don't know how to make that binary tree solution work for a multiway tree.
How would you rewrite mapLeaves to make it tail-recursive?

"Call stack" and "recursion" are merely popular design patterns that later got incorporated into most programming languages (and thus became mostly "invisible"). There is nothing that prevents you from reimplementing both with heap data structures. So, here is "the obvious" 1960's TAOCP retro-style solution:
trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node
def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = {
case class Frame(name: String, mapped: List[Node], todos: List[Node])
@annotation.tailrec
def step(stack: List[Frame]): Node = stack match {
// "return / pop a stack-frame"
case Frame(name, done, Nil) :: tail => {
val ret = BranchNode(name, done.reverse)
tail match {
case Nil => ret
case Frame(tn, td, tt) :: more => {
step(Frame(tn, ret :: td, tt) :: more)
}
}
}
case Frame(name, done, x :: xs) :: tail => x match {
// "recursion base"
case l @ LeafNode(_) => step(Frame(name, f(l) :: done, xs) :: tail)
// "recursive call"
case BranchNode(n, cs) => step(Frame(n, Nil, cs) :: Frame(name, done, xs) :: tail)
}
case Nil => throw new Error("shouldn't happen")
}
root match {
case l @ LeafNode(_) => f(l)
case BranchNode(n, cs) => step(List(Frame(n, Nil, cs)))
}
}
The tail-recursive step function takes a reified stack with "stack frames". A "stack frame" stores the name of the branch node that is currently being processed, a list of child nodes that have already been processed, and the list of the remaining nodes that still must be processed later. This roughly corresponds to an actual stack frame of your recursive mapLeaves function.
With this data structure:
Returning from a recursive call corresponds to deconstructing a Frame object, and either returning the final result or at least making the stack one frame shorter.
A recursive call corresponds to a step that prepends a Frame to the stack.
The base case (invoking f on a leaf) does not create or remove any frames.
Once one understands how the usually invisible stack frames are represented explicitly, the translation is straightforward and mostly mechanical.
Example:
val example = BranchNode("x", List(
BranchNode("y", List(
LeafNode("a"),
LeafNode("b")
)),
BranchNode("z", List(
LeafNode("c"),
BranchNode("v", List(
LeafNode("d"),
LeafNode("e")
))
))
))
println(mapLeaves(example, { case LeafNode(n) => LeafNode(n.toUpperCase) }))
Output (indented):
BranchNode(x,List(
BranchNode(y,List(
LeafNode(A),
LeafNode(B)
)),
BranchNode(z,List(
LeafNode(C),
BranchNode(v,List(
LeafNode(D),
LeafNode(E)
))
))
))
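
As a smaller illustration of the same reified-stack idea, here is a minimal sketch that counts leaves instead of rebuilding the tree. Because nothing has to be reassembled on the way back up, a plain worklist of pending nodes is enough and no Frame class is needed. The name countLeaves is hypothetical and not part of the answer above; the tree types are repeated (as a sealed trait) only to keep the block self-contained.

```scala
import scala.annotation.tailrec

sealed trait Node { def name: String }
final case class BranchNode(name: String, children: List[Node]) extends Node
final case class LeafNode(name: String) extends Node

// Hypothetical helper: counts leaves using an explicit worklist instead of the call stack.
def countLeaves(root: Node): Int = {
  @tailrec
  def loop(todo: List[Node], count: Int): Int = todo match {
    case Nil                             => count                       // nothing left to visit
    case LeafNode(_) :: rest             => loop(rest, count + 1)       // "recursion base"
    case BranchNode(_, children) :: rest => loop(children ::: rest, count) // "recursive call"
  }
  loop(List(root), 0)
}

val small = BranchNode("x", List(LeafNode("a"), BranchNode("y", List(LeafNode("b"), LeafNode("c")))))
println(countLeaves(small)) // 3
```

The full mapLeaves needs the richer Frame structure only because each branch must remember which children are already mapped so it can be rebuilt when its frame is popped.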

It might be easier to implement it using a technique called trampolining.
With a trampoline you can write two functions that call each other (mutual recursion); with @tailrec you are limited to a single self-recursive function. As with @tailrec, the recursion does not grow the call stack: the pending calls are stored on the heap and executed in a loop.
Trampolines are implemented in the Scala standard library in scala.util.control.TailCalls.
import scala.util.control.TailCalls.{TailRec, done, tailcall}
def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = {
//two inner functions doing mutual recursion
//iterates recursively over children of node
def iterate(nodes: List[Node]): TailRec[List[Node]] = {
nodes match {
case x :: xs => tailcall(deepMap(x)) //it calls with mutual recursion deepMap which maps over children of node
.flatMap(node => iterate(xs).map(node :: _)) //you can flat map over TailRec
case Nil => done(Nil)
}
}
//recursively visits all branches
def deepMap(node: Node): TailRec[Node] = {
node match {
case ln: LeafNode => done(f(ln))
case bn: BranchNode => tailcall(iterate(bn.children))
.map(BranchNode(bn.name, _)) //calls mutually iterate
}
}
deepMap(root).result //unwrap result to plain node
}
Instead of TailCalls you could also use Eval from Cats or Trampoline from scalaz.
With that implementation, the function worked without problems on a deep tree:
def build(counter: Int): Node = {
if (counter > 0) {
BranchNode("branch", List(build(counter-1)))
} else {
LeafNode("leaf")
}
}
val root = build(4000)
mapLeaves(root, x => x.copy(name = x.name.reverse)) // no problems
When I ran the same example with your original implementation, it threw a java.lang.StackOverflowError, as expected.
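
For readers new to TailCalls, a minimal sketch of mutual recursion with the same API may help. The isEven/isOdd pair is the usual toy example and is not part of the answer above; done wraps a finished value, tailcall defers the next call, and result runs the trampoline.

```scala
import scala.util.control.TailCalls.{TailRec, done, tailcall}

// Two mutually recursive functions; neither could be annotated with @tailrec.
def isEven(n: Int): TailRec[Boolean] =
  if (n == 0) done(true) else tailcall(isOdd(n - 1))

def isOdd(n: Int): TailRec[Boolean] =
  if (n == 0) done(false) else tailcall(isEven(n - 1))

// 100000 nested calls would normally risk a StackOverflowError.
println(isEven(100000).result) // true
```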

Related

Scala - combine list of items into object

Say I have the following:
trait PropType
case class PropTypeA(value: String) extends PropType
case class PropTypeB(value: String) extends PropType
case class Item(
propTypeA: PropTypeA,
propTypeB: PropTypeB
)
and that I'm given a List[PropType]. How would I go about combining this into a List[Item]?
That is (and assuming we only have PropTypeA(name: String) and PropTypeB(name: String) to make this shorter / easier to follow hopefully) given this:
List[PropType](
PropTypeA("item1-propTypeA"),
PropTypeB("item1-propTypeB"),
PropTypeA("item2-propTypeA"),
PropTypeB("item2-propTypeB")
)
I'd like to get the equivalent of:
List[Item](
Item(PropTypeA("item1-propTypeA"), PropTypeB("item1-propTypeB")),
Item(PropTypeA("item2-propTypeA"), PropTypeB("item2-propTypeB"))
)
Kind of table building from linearized rows across columns, if that makes sense.
Note that in general there might be incomplete "rows", e.g. this:
List[PropType](
PropTypeA("item1-propTypeA"),
PropTypeB("item1-propTypeB"),
PropTypeB("itemPartialXXX-propTypeB"),
PropTypeA("itemPartialYYY-propTypeA"),
PropTypeA("item2-propTypeA"),
PropTypeB("item2-propTypeB")
)
should generate the same output as the above, with the logic being that PropTypeA always marks the start of a new row and thus everything "unused" is discarded.
How should I approach this?
Something like this will work with the examples you mentioned.
list.grouped(2).collect { case Seq(a: PropTypeA, b: PropTypeB) => Item(a,b) }.toList
However, it is unclear from your question what other cases you want to handle and how. For example, how exactly do you define a "partial" occurrence? Are there always two elements in reverse order? Can there be just one, or three? Can there be two As in a row? Or three? Or two Bs?
For example, A, A, A, A, B or B, B, A, A, B or just A?
Depending on how you answer those questions, you'll need to somehow "pre-filter" the list beforehand.
Here is an implementation based on the last phrase in your question: "PropTypeA always marks the start of a new row and thus everything "unused" is discarded." It only looks for instances where an A is immediately followed by B and discards everything else:
list.foldLeft(List.empty[PropType]) {
case ((a: PropTypeA) :: tail, b: PropTypeB) => b :: a :: tail
case ((b: PropTypeB) :: tail, a: PropTypeA) => a :: b :: tail
case (Nil, a: PropTypeA) => a :: Nil
case (_ :: tail, a: PropTypeA) => a :: tail
case (list, _) => list
}.reverse.grouped(2).collect {
case Seq(a: PropTypeA, b: PropTypeB) => Item(a,b)
}.toList
If you have more than just two types, then there are even more questions: what happens if the elements after A come in the wrong order, for example? What do you do with A,B,C,A,C,B?
But basically, it would be the same idea as above: if the next element is of the type you expect next in the sequence, add it to the result; otherwise discard the sequence and keep going.
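
A minimal sketch of that idea for three types, under stated assumptions: the types A, B, C and Row are hypothetical stand-ins, rows must arrive in exactly the order A, B, C, an A always starts a new row, and anything out of order discards the row in progress.

```scala
import scala.annotation.tailrec

// Hypothetical three-column version of the question's types.
sealed trait Prop
final case class A(value: String) extends Prop
final case class B(value: String) extends Prop
final case class C(value: String) extends Prop
final case class Row(a: A, b: B, c: C)

def rows(props: List[Prop]): List[Row] = {
  // partial holds the elements of the row currently being assembled, in order.
  @tailrec
  def go(rest: List[Prop], partial: List[Prop], acc: List[Row]): List[Row] =
    (rest, partial) match {
      case (Nil, _)                           => acc.reverse
      case ((a: A) :: tail, _)                => go(tail, List(a), acc)              // A always restarts a row
      case ((b: B) :: tail, List(a: A))       => go(tail, List(a, b), acc)           // expected next element
      case ((c: C) :: tail, List(a: A, b: B)) => go(tail, Nil, Row(a, b, c) :: acc)  // row complete
      case (_ :: tail, _)                     => go(tail, Nil, acc)                  // unexpected: discard partial row
    }
  go(props, Nil, Nil)
}

val input = List(A("1a"), B("1b"), C("1c"), A("2a"), C("2c"), B("2b"))
println(rows(input)) // List(Row(A(1a),B(1b),C(1c)))
```

With the input A,B,C,A,C,B only the first row survives, because the second A is followed by a C where a B was expected.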
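We can use a tail-recursive function to generate the list of the new type. Note that this version keeps incomplete rows, padding the missing property with an empty value, rather than discarding them.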
def transformType(proptypes: List[PropType]): List[Item] =
{
// tail-recursive helper; @tailrec requires `import scala.annotation.tailrec`
@tailrec
def transform(proptypes: List[PropType], items: List[Item]): List[Item]=
{
proptypes match {
case (first:PropTypeA) :: (second:PropTypeB) :: tail=> transform(tail, items :+ Item(first, second))
case (first:PropTypeA) :: (second:PropTypeA) :: tail => transform(second :: tail, items :+ Item(first, PropTypeB("")))
case (first:PropTypeB) :: tail => transform(tail, items :+ Item(PropTypeA(""), first))
case (first:PropTypeA) :: tail => transform(tail, items :+ Item(first, PropTypeB("")))
case _ => items
}
}
transform(proptypes, List.empty[Item])
}

How to convert tail recursive method to more Scala-like function?

In my code, I very often need to process a list by performing operations on an internal model. For each processed element, the model is returned and then a 'new' model is used for the next element of the list.
Usually, I implement this by using a tail recursive method:
def createCar(myModel: Model, record: Record[Any]): Either[CarError, Model] = {
record match {
case c: Car =>
// Do car stuff...
val newModel: Model = myModel.createCar(record)
Right(newModel)
case _ => Left(CarError())
}
}
@tailrec
def processCars(myModel: Model, records: List[Record[Any]]): Either[CarError, Model] =
records match {
case x :: xs =>
createCar(myModel, x) match {
case Right(m) => processCars(m, xs)
case e @ Left(_) => e
}
case Nil => Right(myModel)
}
Since I keep repeating this kind of pattern, I am searching for ways to make it more concise and more functional (i.e., the Scala way).
I have looked into foldLeft, but cannot get it to work with Either:
recordsList.foldLeft(myModel) { (m, r) =>
// Do car stuff...
Right(m)
}
Is foldLeft a proper replacement? How can I get it to work?
Following up on my earlier comment, here's how to use unfold() to get your result. [Note: requires Scala 2.13.x]
def processCars(myModel: Model
,records: List[Record[_]]
): Either[CarError, Model] =
LazyList.unfold((myModel,records)) { case (mdl,recs) =>
recs.headOption.map{
createCar(mdl, _).fold(Left(_) -> (mdl,Nil)
,m => Right(m) -> (m,recs.tail))
}
}.last
The advantages here are:
Early termination: iterating through the records stops after the first Left is returned or after all the records have been processed, whichever comes first.
Memory efficiency: since we're building a LazyList and nothing holds on to the head of the resulting list, every element except the last should be released for garbage collection immediately.
You can do that using foldLeft like this:
def processCars(myModel: Model, records: List[Record[Any]]): Either[CarError, Model] = {
records.foldLeft[Either[CarError, Model]](Right(myModel))((m, r) => {
m.fold(Left.apply, { model =>
createCar(model, r).fold(Left.apply, Right.apply)
})
})
}
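
Since Either is right-biased (Scala 2.12 and later), the nested fold calls can probably be collapsed into a single flatMap. The sketch below uses hypothetical stand-ins for Model, CarError and the record type (a plain String here), so only the shape of processCars carries over to the question's code.

```scala
// Hypothetical stand-ins for the question's types.
final case class CarError(msg: String)
final case class Model(cars: List[String])

// Hypothetical createCar: an empty record is treated as an error.
def createCar(model: Model, record: String): Either[CarError, Model] =
  if (record.nonEmpty) Right(Model(record :: model.cars))
  else Left(CarError("empty record"))

def processCars(model: Model, records: List[String]): Either[CarError, Model] =
  records.foldLeft[Either[CarError, Model]](Right(model)) { (acc, r) =>
    acc.flatMap(m => createCar(m, r)) // once acc is a Left, createCar is no longer called
  }

println(processCars(Model(Nil), List("a", "b")))     // Right(Model(List(b, a)))
println(processCars(Model(Nil), List("a", "", "c"))) // Left(CarError(empty record))
```

One design caveat: foldLeft still walks the whole list after the first Left, although it does no further work per element. If stopping early matters, the tail-recursive version from the question remains a reasonable choice.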

Scala list foreach, update list while in foreach loop

I just started working with scala and am trying to get used to the language. I was wondering if the following is possible:
I have a list of Instruction objects that I am looping over with the foreach method. Am I able to add elements to my Instruction list while I am looping over it? Here is a code example to explain what I want:
instructions.zipWithIndex.foreach { case (value, index) =>
value match {
case WhileStmt() => {
// ---> Here I want to add elements to the instructions list.
}
case IfStmt() => {
...
}
case _ => {
...
}
}
}
The idiomatic way, for rather complex iteration and replacement logic, would be something like this:
@tailrec
def idiomaticWay(list: List[Instruction],
acc: List[Instruction] = List.empty): List[Instruction] =
list match {
case WhileStmt() :: tail =>
// add element to head of acc
idiomaticWay(tail, CherryOnTop :: acc)
case IfStmt() :: tail =>
// add nothing here
idiomaticWay(tail, list.head :: acc)
case Nil => acc
}
val updatedList = idiomaticWay(List(WhileStmt(), IfStmt()))
println(updatedList) // List(IfStmt(), CherryOnTop)
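This solution works with an immutable list and returns a new immutable list whose contents depend on your logic. Note that acc is built by prepending, so the result comes out in reverse order (as the printed output shows); call .reverse on it if the original order matters.
If you really want to hack around (add, remove, etc.), you could use Java's ListIterator, which allows all of the operations mentioned above: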
import java.util
def hackWay(list: util.List[Instruction]): Unit = {
val iterator = list.listIterator()
while(iterator.hasNext) {
iterator.next() match {
case WhileStmt() =>
iterator.set(CherryOnTop)
case IfStmt() => // do nothing here
}
}
}
import collection.JavaConverters._
val instructions = new util.ArrayList[Instruction](List(WhileStmt(), IfStmt()).asJava)
hackWay(instructions)
println(instructions.asScala) // Buffer(CherryOnTop, IfStmt())
However, in the second case you do not really need Scala :( So my advice would be to stick to immutable data structures in Scala.
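
When the goal is to insert extra instructions next to certain elements (rather than replace them), flatMap over the immutable list is another option. This is a minimal sketch with hypothetical types: Instruction, WhileStmt and IfStmt mirror the names in the question, and Extra is an invented placeholder for whatever should be added.

```scala
sealed trait Instruction
final case class WhileStmt() extends Instruction
final case class IfStmt() extends Instruction
case object Extra extends Instruction // hypothetical instruction to insert

// Each element maps to a small list; flatMap concatenates them in order.
def expand(instructions: List[Instruction]): List[Instruction] =
  instructions.flatMap {
    case w: WhileStmt => List(w, Extra) // keep the statement and add one after it
    case other        => List(other)    // everything else is passed through unchanged
  }

println(expand(List(WhileStmt(), IfStmt()))) // List(WhileStmt(), Extra, IfStmt())
```

The original list is left untouched and the order of the elements is preserved.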

Split a list into a target element, and the rest of the list?

Let's say I have something like the following:
case class Thing(num: Int)
val xs = List(Thing(1), Thing(2), Thing(3))
What I'd like to do is separate the list into one particular value, and the rest of the list. The target value can be at any position in the list, or may not be present at all. The single value needs to be handled separately, after the other values are handled, so I can't simply use pattern matching.
What I have so far is this:
val (targetList, rest) = xs.partition(_.num == 2)
val targetEl = targetList match {
case x :: Nil => x
case _ => null
}
Is it possible to combine the two steps? Like
val (targetEl, rest) = xs.<some_method>
A note on handling order:
The reason that the target element must be handled last is that this is for use in a HTML template (Play framework). The other elements are looped through, and a HTML element is rendered for each. After that group of elements, another HTML element is created for the target element.
You can do it with pattern-matching in map, you just need multiple cases:
xs map {
case t @ Thing(1) => // do something with thing 1
case t => // do something with the other things
}
To handle the OP's extra requirements:
xs map {
case t @ Thing(num) if num != 1 => // do something with things that are not "1"
case t => // do something with thing 1
}
The following produces a tuple of two lists, split on some condition.
case class Thing(num: Int)
val xs = List(Thing(1), Thing(2), Thing(3))
val partitioned = xs.foldLeft((List.empty[Thing], List.empty[Thing]))((x, y) => y match {
case t @ Thing(1) => (x._1, t :: x._2)
case t => (t :: x._1, x._2)
})
//(List(Thing(3), Thing(2)),List(Thing(1)))
Try this:
val (targetEl, rest) = (xs.head, xs.tail)
It works for a non-empty list. The Nil case must be handled separately.
After some experimentation, I've come up with the following, which is almost what I'm looking for:
var (maybeTargetEl, rest) = xs
.foldLeft((Option.empty[Thing], List[Thing]())) { case ((opt, ls), x) =>
if (x.num == 1)
(Some(x), ls)
else
(opt, x :: ls)
}
The target value is still wrapped in a container, but at least it guarantees a single value.
After that I can do
rest map <some_method>
maybeTargetEl map <some_other_method>
If the order of the original list is important:
var (maybeTargetEl, rest) = xs.
foldLeft((Option.empty[Thing], ListBuffer[Thing]())){ case ((opt, lb), x) =>
if (x.num == 1)
(Some(x), lb)
else
(opt, lb += x)
} match {
case (opt, lb) => (opt, lb.toList)
}
@evanjdooner Your solution with fold works if the target element is present only once. If you want to extract only one occurrence of the target element:
def find[T](xs: List[T], target: T, prefix: List[T]): (T, List[T]) = xs match {
// backticks compare against the existing value instead of binding a new variable
case `target` :: tail => (target, prefix.reverse ::: tail) // prefix was accumulated in reverse
case other :: tail => find(tail, target, other :: prefix)
case Nil => throw new Exception("Not found")
}
val (el, rest) = find(xs, target, Nil)
Sorry, I can't add it as a comment.
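
An alternative that avoids the exception and keeps the original order is to split the list at the first match with span. This is a minimal sketch; extractFirst is a hypothetical helper name, and it returns an Option so the "not present" case is explicit.

```scala
case class Thing(num: Int)

// Hypothetical helper: removes the first element satisfying p, if any.
def extractFirst[A](xs: List[A])(p: A => Boolean): (Option[A], List[A]) = {
  val (before, from) = xs.span(x => !p(x)) // from starts at the first match, or is empty
  (from.headOption, before ::: from.drop(1))
}

val xs = List(Thing(1), Thing(2), Thing(3))
val (targetEl, rest) = extractFirst(xs)(_.num == 2)
println(targetEl) // Some(Thing(2))
println(rest)     // List(Thing(1), Thing(3))
```

Later duplicates of the target stay in the rest list, since only the first occurrence is removed.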

map over structure with only partial match

I have a tree-like structure of abstract classes and case classes representing an Abstract Syntax Tree of a small language.
For the top abstract class I've implemented a method map:
abstract class AST {
...
def map(f: (AST => AST)): AST = {
val b1 = this match {
case s: STRUCTURAL => s.smap(f) // structural node for example IF(expr,truebranch,falsebranch)
case _ => this // leaf, like ASSIGN(x,2)
}
f(b1)
}
...
The smap is defined like:
override def smap(f: AST => AST) = {
this.copy(trueb = trueb.map(f), falseb = falseb.map(f))
}
Now I'm writing different "transformations" to insert, remove and change nodes in the AST.
For example, remove adjacent NOP nodes from blocks:
def handle_list(l: List[AST]): List[AST] = l match { // a recursive method needs an explicit result type
case (NOP::NOP::tl) => handle_list(tl)
case h::tl => h::handle_list(tl)
case Nil => Nil
}
ast.map {
case BLOCK(listofstatements) => handle_list(listofstatements)
}
If I write like this, I end up with MatchError and I can "fix it" by changing the above map to:
ast.map {
case BLOCK(listofstatements) => handle_list(listofstatements)
case a => a
}
Should I just live with all those case a => a lines, or could I improve my map method (or other parts) in some way?
Make the argument to map a PartialFunction:
def map(f: PartialFunction[AST, AST]): AST = {
val idAST: PartialFunction[AST, AST] = {case a => a}
val g = f.orElse(idAST)
val b1 = this match {
case s: STRUCTURAL => s.smap(g)
case _ => this
}
g(b1)
}
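
The same fallback can also be written with applyOrElse, which avoids building the combined function by hand. The sketch below uses a tiny hypothetical expression tree (Expr, Num, Add) and a standalone transform function in place of the question's AST classes, so it shows the pattern rather than a drop-in replacement.

```scala
// Hypothetical miniature AST, standing in for the question's classes.
sealed trait Expr
final case class Num(n: Int) extends Expr
final case class Add(l: Expr, r: Expr) extends Expr

// Bottom-up map: children first, then the node itself; unmatched nodes are kept as-is.
def transform(e: Expr)(f: PartialFunction[Expr, Expr]): Expr = {
  val rebuilt = e match {
    case Add(l, r) => Add(transform(l)(f), transform(r)(f)) // structural node
    case leaf      => leaf                                   // leaf node
  }
  f.applyOrElse(rebuilt, (x: Expr) => x)
}

// Callers only list the cases they care about; no `case a => a` is needed.
val simplified = transform(Add(Add(Num(0), Num(2)), Num(3))) {
  case Add(Num(0), r) => r
}
println(simplified) // Add(Num(2),Num(3))
```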
If tree transformations are more than a minor aspect of your project, I highly recommend you use Kiama's Rewriter module to implement them. It implements Stratego-like strategy-driven transformations. It has a very rich set of strategies and strategy combinators that permit a complete separation of traversal logic (which for the vast majority of cases can be taken "off the shelf" from the supplied strategies and combinators) from (local) transformations (which are specific to your AST and you supply, of course).