map over structure with only partial match - scala

I have a tree-like structure of abstract classes and case classes representing an Abstract Syntax Tree of a small language.
For the top abstract class I've implemented a method map:
abstract class AST {
...
def map(f: (AST => AST)): AST = {
val b1 = this match {
case s: STRUCTURAL => s.smap(f) // structural node for example IF(expr,truebranch,falsebranch)
case _ => this // leaf, like ASSIGN(x,2)
}
f(b1)
}
...
The smap is defined like:
override def smap(f: AST => AST) = {
this.copy(trueb = trueb.map(f), falseb = falseb.map(f))
}
Now I'm writing different "transformations" to insert, remove and change nodes in the AST.
For example, remove adjacent NOP nodes from blocks:
def handle_list(l: List[AST]): List[AST] = l match { // recursive methods need an explicit result type
case (NOP::NOP::tl) => handle_list(tl)
case h::tl => h::handle_list(tl)
case Nil => Nil
}
ast.map {
case BLOCK(listofstatements) => BLOCK(handle_list(listofstatements)) // wrap the cleaned list back into a BLOCK so the result is an AST
}
If I write it like this, I end up with a MatchError, and I can "fix it" by changing the above map to:
ast.map {
case BLOCK(listofstatements) => BLOCK(handle_list(listofstatements))
case a => a
}
Should I just live with all those case a => a, or could I improve my map method (or other parts) in some way?

Make the argument to map a PartialFunction:
def map(f: PartialFunction[AST, AST]): AST = {
val idAST: PartialFunction[AST, AST] = {case a => a}
val g = f.orElse(idAST)
val b1 = this match {
case s: STRUCTURAL => s.smap(g)
case _ => this
}
g(b1)
}
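To make the suggestion above concrete, here is a minimal self-contained sketch. The tiny AST (NOP, ASSIGN, BLOCK), the Structural base class and the helper dropNopPairs are hypothetical stand-ins for the asker's types, not the original code. Note that smap has to accept a PartialFunction as well, and that this sketch uses applyOrElse as the fallback where the answer uses orElse; both leave unmatched nodes untouched.

```scala
object PartialMapSketch {
  sealed abstract class AST {
    // Rewrites children first, then applies f where it is defined;
    // nodes that f does not handle are returned unchanged.
    def map(f: PartialFunction[AST, AST]): AST = {
      val rebuilt: AST = this match {
        case s: Structural => s.smap(f)
        case leaf          => leaf
      }
      f.applyOrElse(rebuilt, (a: AST) => a)
    }
  }

  // Hypothetical base class for nodes that have children.
  sealed abstract class Structural extends AST {
    def smap(f: PartialFunction[AST, AST]): AST
  }

  case object NOP extends AST
  final case class ASSIGN(name: String, value: Int) extends AST
  final case class BLOCK(stmts: List[AST]) extends Structural {
    def smap(f: PartialFunction[AST, AST]): AST =
      copy(stmts = stmts.map(_.map(f)))
  }

  // Same logic as handle_list in the question: drops adjacent NOP pairs.
  def dropNopPairs(l: List[AST]): List[AST] = l match {
    case NOP :: NOP :: tl => dropNopPairs(tl)
    case h :: tl          => h :: dropNopPairs(tl)
    case Nil              => Nil
  }

  // No trailing "case a => a" is needed any more.
  val cleaned: AST =
    BLOCK(List(NOP, NOP, ASSIGN("x", 2), BLOCK(List(NOP, NOP)))).map {
      case BLOCK(stmts) => BLOCK(dropNopPairs(stmts))
    }
}
```

The transformation only names the case it cares about; every other node falls through to the identity fallback inside map.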

If tree transformations are more than a minor aspect of your project, I highly recommend you use Kiama's Rewriter module to implement them. It implements Stratego-like strategy-driven transformations. It has a very rich set of strategies and strategy combinators that permit a complete separation of traversal logic (which for the vast majority of cases can be taken "off the shelf" from the supplied strategies and combinators) from (local) transformations (which are specific to your AST and you supply, of course).

Related

How to make tree mapping tail-recursive?

Suppose I have a tree data structure like this:
trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node
Suppose also I've got a function to map over leaves:
def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = root match {
case ln: LeafNode => f(ln)
case bn: BranchNode => BranchNode(bn.name, bn.children.map(ch => mapLeaves(ch, f)))
}
Now I am trying to make this function tail-recursive but am having a hard time figuring out how to do it. I've read this answer but still don't know how to make that binary tree solution work for a multiway tree.
How would you rewrite mapLeaves to make it tail-recursive?
"Call stack" and "recursion" are merely popular design patterns that later got incorporated into most programming languages (and thus became mostly "invisible"). There is nothing that prevents you from reimplementing both with heap data structures. So, here is "the obvious" 1960's TAOCP retro-style solution:
trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node
def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = {
case class Frame(name: String, mapped: List[Node], todos: List[Node])
@annotation.tailrec
def step(stack: List[Frame]): Node = stack match {
// "return / pop a stack-frame"
case Frame(name, done, Nil) :: tail => {
val ret = BranchNode(name, done.reverse)
tail match {
case Nil => ret
case Frame(tn, td, tt) :: more => {
step(Frame(tn, ret :: td, tt) :: more)
}
}
}
case Frame(name, done, x :: xs) :: tail => x match {
// "recursion base"
case l @ LeafNode(_) => step(Frame(name, f(l) :: done, xs) :: tail)
// "recursive call"
case BranchNode(n, cs) => step(Frame(n, Nil, cs) :: Frame(name, done, xs) :: tail)
}
case Nil => throw new Error("shouldn't happen")
}
root match {
case l @ LeafNode(_) => f(l)
case b @ BranchNode(n, cs) => step(List(Frame(n, Nil, cs)))
}
}
The tail-recursive step function takes a reified stack with "stack frames". A "stack frame" stores the name of the branch node that is currently being processed, a list of child nodes that have already been processed, and the list of the remaining nodes that still must be processed later. This roughly corresponds to an actual stack frame of your recursive mapLeaves function.
With this data structure,
returning from recursive calls corresponds to deconstructing a Frame object, and either returning the final result, or at least making the stack one frame shorter.
recursive calls correspond to a step that prepends a Frame to the stack
base case (invoking f on leaves) does not create or remove any frames
Once one understands how the usually invisible stack frames are represented explicitly, the translation is straightforward and mostly mechanical.
Example:
val example = BranchNode("x", List(
BranchNode("y", List(
LeafNode("a"),
LeafNode("b")
)),
BranchNode("z", List(
LeafNode("c"),
BranchNode("v", List(
LeafNode("d"),
LeafNode("e")
))
))
))
println(mapLeaves(example, { case LeafNode(n) => LeafNode(n.toUpperCase) }))
Output (indented):
BranchNode(x,List(
BranchNode(y,List(
LeafNode(A),
LeafNode(B)
)),
BranchNode(z,List(
LeafNode(C),
BranchNode(v,List(
LeafNode(D),
LeafNode(E)
))
))
))
It might be easier to implement it using a technique called trampolining.
If you use it, you can have two functions that call each other in mutual recursion (with tailrec you are limited to a single function). As with tailrec, the recursion is then executed as a plain loop instead of growing the call stack.
Trampolines are implemented in the Scala standard library in scala.util.control.TailCalls.
import scala.util.control.TailCalls.{TailRec, done, tailcall}
def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = {
//two inner functions doing mutual recursion
//iterates recursively over children of node
def iterate(nodes: List[Node]): TailRec[List[Node]] = {
nodes match {
case x :: xs => tailcall(deepMap(x)) //it calls with mutual recursion deepMap which maps over children of node
.flatMap(node => iterate(xs).map(node :: _)) //you can flat map over TailRec
case Nil => done(Nil)
}
}
//recursively visits all branches
def deepMap(node: Node): TailRec[Node] = {
node match {
case ln: LeafNode => done(f(ln))
case bn: BranchNode => tailcall(iterate(bn.children))
.map(BranchNode(bn.name, _)) //calls mutually iterate
}
}
deepMap(root).result //unwrap result to plain node
}
Instead of TailCalls you could also use Eval from Cats or Trampoline from scalaz.
With that implementation, the function worked without problems on a deeply nested tree:
def build(counter: Int): Node = {
if (counter > 0) {
BranchNode("branch", List(build(counter-1)))
} else {
LeafNode("leaf")
}
}
val root = build(4000)
mapLeaves(root, x => x.copy(name = x.name.reverse)) // no problems
When I ran that example with your implementation it caused java.lang.StackOverflowError as expected.
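The mutual-recursion point is easier to see in isolation. The following is a small sketch, separate from the tree code, that uses only done, tailcall and result from scala.util.control.TailCalls; the object name TrampolineSketch is made up for illustration.

```scala
import scala.util.control.TailCalls.{TailRec, done, tailcall}

object TrampolineSketch {
  // Two functions that call each other. Each call is wrapped in tailcall,
  // so it is suspended on the heap instead of pushing a JVM stack frame.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))

  // .result runs the suspended calls in a loop.
  val millionIsEven: Boolean = isEven(1000000).result
}
```

Written as plain mutually recursive methods, a depth of one million would be expected to overflow the stack on a default JVM; with the trampoline the depth is limited by heap space rather than stack size.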

Scala: Find and update one element in a list

I am trying to find an elegant way to do:
val l = List(1,2,3)
val (item, idx) = l.zipWithIndex.find(predicate)
val updatedItem = updating(item)
l.update(idx, updatedItem)
Can I do it all in one operation? Find the item and, if it exists, replace it with the updated value, keeping it in place.
I could do:
l.map{ i =>
if (predicate(i)) {
updating(i)
} else {
i
}
}
but that's pretty ugly.
The other complication is that I want to update only the first element that matches the predicate.
Edit: Attempt:
implicit class UpdateList[A](l: List[A]) {
def filterMap(p: A => Boolean)(update: A => A): List[A] = {
l.map(a => if (p(a)) update(a) else a)
}
def updateFirst(p: A => Boolean)(update: A => A): List[A] = {
val found = l.zipWithIndex.find { case (item, _) => p(item) }
found match {
case Some((item, idx)) => l.updated(idx, update(item))
case None => l
}
}
}
I don't know of any way to do this in one pass over the collection without using a mutable variable. With two passes you can do it using foldLeft, as in:
def updateFirst[A](list:List[A])(predicate:A => Boolean, newValue:A):List[A] = {
list.foldLeft((List.empty[A], predicate))((acc, it) => {acc match {
case (nl,pr) => if (pr(it)) (newValue::nl, _ => false) else (it::nl, pr)
}})._1.reverse
}
The idea is that foldLeft allows passing additional data through iteration. In this particular implementation I change the predicate to the fixed one that always returns false. Unfortunately you can't build a List from the head in an efficient way so this requires another pass for reverse.
I believe it is obvious how to do it using a combination of map and a var.
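For completeness, a sketch of that map-plus-var variant might look as follows. The signature mirrors the updateFirst attempt in the question; the flag name alreadyUpdated and the enclosing object are made up here. It relies on List.map being strict and visiting elements in order.

```scala
object UpdateFirstSketch {
  def updateFirst[A](list: List[A])(p: A => Boolean)(update: A => A): List[A] = {
    // Local mutable flag: flips once the first matching element is rewritten,
    // so later matches are copied through unchanged.
    var alreadyUpdated = false
    list.map { a =>
      if (!alreadyUpdated && p(a)) {
        alreadyUpdated = true
        update(a)
      } else a
    }
  }
}
```

The var never escapes the method, so callers still see a pure function.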
Note: List.map performs like a single pass over the list only because the standard library uses mutation internally. In particular, the cons class :: is declared as
final case class ::[B](override val head: B, private[scala] var tl: List[B]) extends List[B] {
so tl is actually a var, and the map implementation exploits this to build a list from the head efficiently. The field is private[scala], so you can't use the same trick from outside the standard library. Unfortunately I don't see any other API call that would let you use this feature to reduce your problem to a single pass.
You can avoid .zipWithIndex() by using .indexWhere().
To improve complexity, use Vector so that l(idx) becomes effectively constant time.
val l = Vector(1,2,3)
val idx = l.indexWhere(predicate)
// indexWhere returns -1 when nothing matches, so guard before indexing
val result = if (idx >= 0) l.updated(idx, updating(l(idx))) else l
Reason for using scala.collection.immutable.Vector rather than List:
Scala's List is a linked list, which means elements are accessed in O(n) time. Scala's Vector is indexed, meaning elements can be read from any position in effectively constant time.
You may also consider mutable collections if you're modifying just one element in a very large collection.
https://docs.scala-lang.org/overviews/collections/performance-characteristics.html

Traverse and Modify an Heterogeneous Directed Acyclic Graph in Scala

Good morning everybody.
I have the following directed acyclic graph data structure, implemented in Scala as follows:
abstract class Node // Generic abstract node
/** Various kind of leaf nodes */
case class LeafNodeA(x: String) extends Node
case class LeafNodeB(x: Int) extends Node
/** Various kind of inner nodes */
case class InnerNode1(x: String, depRoleA: Node) extends Node
case class InnerNode2(x: String, y: Double, depRoleX: Node, depRoleY: Node) extends Node
case class InnerNode3(x: List[Int], depRoleA: Node, y: Int,
depRoleB: Node, depRoleG: Node) extends Node
In this structure a node can be a dependency of multiple nodes, therefore it is not a Tree but a Directed Acyclic Graph. In addition the structure is not even balanced (nodes have different numbers of dependencies).
The problem of traversal
Notice that I have given the various dependency fields of the case classes different names, since they represent different roles in the dependencies (for example, depRoleX has a different role than depRoleY in an InnerNode2 type node). For this reason I don't think it is possible to store the dependencies of each node in a List[Node] like in any trivial tree/DAG implementation you can find out there, because the meaning of each dependency field is different.
Of course when I traverse this structure I have to do pattern matching in order to understand the type of node I am dealing with at the current recursion step:
// Random function which returns the list of all the String attributes of the nodes
def getAllStrings(dag: Node): List[String] = {
dag match {
case LeafNodeA(x) => List(x)
case LeafNodeB(_) => List() // match the case class, not its companion object
case InnerNode1(x, dr) => List(x) ::: getAllStrings(dr)
case InnerNode2(x, _, dX, dY) => List(x) ::: getAllStrings(dX) ::: getAllStrings(dY)
case InnerNode3(_, dA, _, dB, dG) => getAllStrings(dA) ::: getAllStrings(dB) ::: getAllStrings(dG)
}
}
Now suppose that instead of these 5 relatively simple node types I have around 20 types of node. The previous function would become extremely long and repetitive (a case statement for each node type). Even worse: every time I want to do a traversal I have to do the same thing.
Thinking about this problem I came up with two solutions.
External method for traversal
The first (obvious) way to deal with this is to modularize the previous method defining a generic DAG traversal function
object DAGManipulator {
def getDependencies(dag: Node): List[Node] = {
dag match {
case _: LeafNodeA => List()
case _: LeafNodeB => List()
case InnerNode1(_, dr) => List(dr)
case InnerNode2(_, _, dX, dY) => List(dX, dY)
case InnerNode3(_, dA, _, dB, dG) => List(dA, dB, dG)
}
}
}
In this way, every time I need the dependencies of a node I can rely on this static function.
Abstract class method for getting the dependencies
The second solution I came up with is to give to every node an additional method in the following way:
abstract class Node {
def getDependencies : List[Node]
}
case class LeafNodeA(x: String) extends Node {
override def getDependencies : List[Node] = List()
}
case class LeafNodeB(x: Int) extends Node {
override def getDependencies : List[Node] = List()
}
/** Various kind of inner nodes */
case class InnerNode1(x: String, depRoleA: Node) extends Node {
override def getDependencies : List[Node] = List(depRoleA)
}
case class InnerNode2(x: String, y: Double, depRoleX: Node, depRoleY: Node) extends Node {
override def getDependencies : List[Node] = List(depRoleX, depRoleY)
}
case class InnerNode3(x: List[Int], depRoleA: Node, y: Int,
depRoleB: Node, depRoleG: Node) extends Node {
override def getDependencies : List[Node] = List(depRoleA, depRoleB, depRoleG)
}
I don't like any of the previous solutions:
The first one must be updated every time a new node type is added to the hierarchy. In addition to this, it delegates a fundamental feature of the DAG structure (traversal) to an external object, which I find very unpleasant from the software engineering point of view.
The second solution is, in my opinion, even worse because every node type basically has to state its dependencies redundantly (once in its fields and once in the getDependencies method). I find this very ugly and prone to programming errors.
Do you have a better solution to this problem?
The problem of updating
The second problem I have to deal with is the updating/modification of the data structure.
Suppose that I have a DAG defined in the following way.
val l1 = LeafNodeB(1)
val dag =
InnerNode3(List(1, 2, 3),
InnerNode1("InnerNode1", LeafNodeA("leafA1")),
1, l1, InnerNode2("InnerNode2", 2, l1, LeafNodeA("leafA2")))
corresponding to this structure.
Suppose that I want to change the LeafNodeA("leafA1") (which is a dependency of the InnerNode1) to, for example, l1, which is a LeafNodeB.
This is the kind of operation that I need to do:
def modify(dag: Node): Node = {
dag match {
case x: InnerNode1 => if (x.x == "InnerNode1") x.copy(depRoleA = l1) else x.copy(depRoleA = modify(x.depRoleA)) // keep recursing when this is not the target node
case x: LeafNodeB => x
case x: LeafNodeA => x
case x: InnerNode2 => x.copy(depRoleX = modify(x.depRoleX), depRoleY = modify(x.depRoleY))
case x: InnerNode3 => x.copy(depRoleA = modify(x.depRoleA), depRoleB = modify(x.depRoleB), depRoleG = modify(x.depRoleG))
}
}
Again, consider the possibility of having more than 20 node types. This update method would become impractical, and the same holds for every other update method I can think of.
In addition, this time I have not come up with any strategy for factoring out this "recursive traversal update" of the nested structure. I have to check for every possible node type in order to know how to use the copy method of the various case classes.
Do you have a better solution or design for this update strategy?
To address this issue:
The first one must be updated every time a new node type is added to
the hierarchy. In addition to this, it delegates a fundamental feature
of the DAG structure (traversal) to an external object, which I find
very unpleasant from the software engineering point of view.
I think this is not really a drawback, but this is something that's pretty core to the OO vs FP clash inherent in Scala. I would say your node classes being "dumb" data holders, and having a separate code path for traversing them, is a good thing. And sure, you have to add a line there every time you add a node type, but the compiler can warn you if you forget, provided the hierarchy is sealed.
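As a rough sketch of that "dumb data plus external traversal" style, assuming the hierarchy can be made sealed (the question's Node is not) and using only a subset of the node types: the function dependencies is the single place that knows each node's shape, and the generic helper collect, a name invented here, is built on top of it.

```scala
object SealedDagSketch {
  // sealed: the compiler can check matches on Node for exhaustiveness.
  sealed abstract class Node
  final case class LeafNodeA(x: String) extends Node
  final case class LeafNodeB(x: Int) extends Node
  final case class InnerNode1(x: String, depRoleA: Node) extends Node
  final case class InnerNode2(x: String, y: Double, depRoleX: Node, depRoleY: Node) extends Node

  // The one function that must grow when a node type is added.
  def dependencies(n: Node): List[Node] = n match {
    case _: LeafNodeA | _: LeafNodeB => Nil
    case InnerNode1(_, d)            => List(d)
    case InnerNode2(_, _, dX, dY)    => List(dX, dY)
  }

  // Generic pre-order traversal: applies pf wherever it is defined.
  // A node shared by several parents is visited once per parent.
  def collect[T](root: Node)(pf: PartialFunction[Node, T]): List[T] =
    pf.lift(root).toList ::: dependencies(root).flatMap(d => collect[T](d)(pf))

  val l1: Node = LeafNodeB(1)
  val dag: Node = InnerNode2("n2", 2.0, InnerNode1("n1", LeafNodeA("a")), l1)

  val strings: List[String] = collect(dag) {
    case LeafNodeA(x)           => x
    case InnerNode1(x, _)       => x
    case InnerNode2(x, _, _, _) => x
  }
}
```

Each new query then only lists the node types it is interested in, and the non-exhaustive-match warning on dependencies is what flags a forgotten node type.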
Anyway, this might be overkill and it's a bit of an undertaking, but you may want to look into Matryoshka, which generalizes recursive data structures like this. It requires a bit of plumbing to translate your data types into the scheme that is expected, and define a functor for that:
abstract class NodeF[+A] // Generic abstract node
/** Various kind of leaf nodes */
case class LeafNodeA(x: String) extends NodeF[Nothing]
case class LeafNodeB(x: Int) extends NodeF[Nothing]
/** Various kind of inner nodes */
case class InnerNode1[A](x: String, depRoleA: A) extends NodeF[A]
case class InnerNode2[A](x: String, y: Double, depRoleX: A, depRoleY: A) extends NodeF[A]
case class InnerNode3[A](x: List[Int], depRoleA: A, y: Int,
depRoleB: A, depRoleG: A) extends NodeF[A]
implicit val nodeFunctor: Functor[NodeF] = new Functor[NodeF] {
def map[A, B](fa: NodeF[A])(f: A => B): NodeF[B] = fa match {
case LeafNodeA(x) => LeafNodeA(x)
case LeafNodeB(x) => LeafNodeB(x)
case InnerNode1(x, depA) => InnerNode1(x, f(depA))
case InnerNode2(x, y, depX, depY) => InnerNode2(x, y, f(depX), f(depY))
case InnerNode3(x, depA, y, depB, depG) => InnerNode3(x, f(depA), y, f(depB), f(depG))
}
}
But then it essentially hides the recursion from you and you can more easily define these kinds of things:
type FixNode = Fix[NodeF]
def someExprGeneric[T](implicit T : Corecursive.Aux[T, NodeF]): T =
InnerNode2("hello", 1.0, InnerNode1("world", LeafNodeA("!").embed).embed, LeafNodeB(1).embed).embed
val someExpr = someExprGeneric[FixNode]
def getStrings: Algebra[NodeF, List[String]] = {
case LeafNodeA(x) => List(x)
case LeafNodeB(_) => List()
case InnerNode1(x, depA) => x :: depA
case InnerNode2(x, _, depX, depY) => x :: depX ::: depY
case InnerNode3(_, depA, _, depB, depG) => depA ::: depB ::: depG
}
someExpr.cata(getStrings) // List("hello", "world", "!")
Perhaps that's not that much cleaner than what you have, but it at least separates the recursive traversal logic from the "single step" evaluation logic. But I think where it shines a bit more is when updating:
def expandToUniverse: Algebra[NodeF, FixNode] = {
case InnerNode1("world", dep) => InnerNode1("universe", dep).embed
case x => x.embed
}
someExpr.cata(expandToUniverse).cata(getStrings) // List("hello", "universe", "!")
Because you've delegated out that recursion, you only have to implement the case(s) you actually care about.
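To show what the library is hiding, here is a dependency-free miniature of the same idea. It is only a sketch: Fix, mapF and cata below are hand-written stand-ins, not Matryoshka's actual definitions, and the two-constructor ExprF (Word, Pair) is a made-up type that is much smaller than NodeF. This naive cata is also not stack-safe.

```scala
object FixSketch {
  // Pattern functor: child positions are the type parameter A.
  sealed trait ExprF[+A]
  final case class Word(s: String) extends ExprF[Nothing]
  final case class Pair[A](label: String, left: A, right: A) extends ExprF[A]

  // Ties the recursive knot: a tree is an ExprF whose children are trees.
  final case class Fix[F[_]](unfix: F[Fix[F]])

  // Hand-written functor map for ExprF (applies f to every child slot).
  def mapF[A, B](fa: ExprF[A])(f: A => B): ExprF[B] = fa match {
    case Word(s)           => Word(s)
    case Pair(label, l, r) => Pair(label, f(l), f(r))
  }

  // Catamorphism: fold the children first, then apply the algebra once.
  def cata[A](t: Fix[ExprF])(alg: ExprF[A] => A): A =
    alg(mapF(t.unfix)((child: Fix[ExprF]) => cata[A](child)(alg)))

  val tree: Fix[ExprF] =
    Fix[ExprF](Pair("hello", Fix[ExprF](Word("a")), Fix[ExprF](Word("b"))))

  // A "query" algebra: children have already been folded to List[String].
  val strings: ExprF[List[String]] => List[String] = {
    case Word(s)           => List(s)
    case Pair(label, l, r) => label :: l ::: r
  }

  // An "update" algebra: only the interesting case is spelled out.
  val rename: ExprF[Fix[ExprF]] => Fix[ExprF] = {
    case Pair("hello", l, r) => Fix[ExprF](Pair("goodbye", l, r))
    case other               => Fix[ExprF](other)
  }

  val before: List[String] = cata(tree)(strings)
  val after: List[String]  = cata(cata(tree)(rename))(strings)
}
```

All recursion lives in cata, and the only code that enumerates every constructor is mapF, which plays the role of the Functor instance in the answer.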

Traverse/fold a nested case class in Scala without boilerplate code

I have some case classes for a mix of sum and product types:
sealed trait Leaf
case class GoodLeaf(value: Int) extends Leaf
case object BadLeaf extends Leaf
case class Middle(left: Leaf, right: Leaf)
case class Container(leaf: Leaf)
case class Top(middle : Middle, container: Container, extraLeaves : List[Leaf])
I want to do some fold-like operations with this Top structure. Examples include:
Count the occurrences of BadLeaf
Sum all the values in the GoodLeafs
Here is some code that does the operations:
object Top {
def fold[T](accu: T)(f : (T, Leaf) => T)(top: Top) = {
val allLeaves = top.container.leaf :: top.middle.left :: top.middle.right :: top.extraLeaves
allLeaves.foldLeft(accu)(f)
}
private def countBadLeaf(count: Int, leaf : Leaf) = leaf match {
case BadLeaf => count + 1
case _ => count
}
def countBad(top: Top): Int = fold(0)(countBadLeaf)(top)
private def sumGoodLeaf(count: Int, leaf : Leaf) = leaf match {
case GoodLeaf(v) => count + v
case _ => count
}
def sumGoodValues(top: Top) = fold(0)(sumGoodLeaf)(top)
}
The real life structure I am dealing with is significantly more complicated than the example I made up. Are there any techniques that could help me avoid writing lots of boilerplate code?
I already have the cats library as a dependency, so a solution that uses that lib would be preferred. I am open to including new dependencies in order to solve this problem.
For my particular example, the definition is not recursive, but I'd be interested in seeing a solution that works for recursive definitions also.
You could just create a function returning all leaves for Top, like you did with allLeaves, so you can just work with a List[Leaf] (with all the existing fold and other functions the Scala library, Cats, etc provide).
For example :
def topLeaves(top: Top): List[Leaf] =
top.container.leaf :: top.middle.left :: top.middle.right :: top.extraLeaves
val isBadLeaf: Leaf => Boolean = {
case BadLeaf => true
case _ => false
}
val leafValue: Leaf => Int = {
case GoodLeaf(v) => v
case _ => 0
}
Which you could use as
import cats.implicits._
// or
// import cats.instances.int._
// import cats.instances.list._
// import cats.syntax.foldable._
val leaves = topLeaves(someTop)
val badCount = leaves.count(isBadLeaf)
val badAndGood = leaves.partition(isBadLeaf) // (List[Leaf], List[Leaf])
val sumLeaves = leaves.foldMap(leafValue)
I am not sure if this helps with your actual use case. In general, with a heterogeneous structure (like your Top) you probably want to convert it to something more homogeneous (like a List[Leaf] or Tree[Leaf]) that you can fold over.
If you have a recursive structure you could look at some talks about recursion schemes (with the Matryoshka library in Scala).

How to match against the pattern of a partial function's case definition in a Scala macro?

As part of a macro, I want to manipulate the case definitions of a partial function.
To do so, I use a Transformer to manipulate the case definitions of the partial function and a Traverser to inspect the patterns of the case definitions:
def myMatchImpl[A: c.WeakTypeTag, B: c.WeakTypeTag](c: Context)
(expr: c.Expr[A])(patterns: c.Expr[PartialFunction[A, B]]): c.Expr[B] = {
import c.universe._
val transformer = new Transformer {
override def transformCaseDefs(trees: List[CaseDef]) = trees map {
case caseDef @ CaseDef(pattern, guard , body) => {
// println(show(pattern))
val traverser = new Traverser {
override def traverse(tree: Tree) = tree match {
// match against a specific pattern
}
}
traverser.traverse(pattern)
}
}
}
val transformedPartialFunction = transformer.transform(patterns.tree)
c.Expr[B](q"$transformedPartialFunction($expr)")
}
Now let us assume, the interesting data I want to match against is represented by the class Data (which is part of the object Example):
case class Data(x: Int, y: String)
When now invoking the macro on the example below
abstract class Foo
case class Bar(data: Data) extends Foo
case class Baz(string: String, data: Data) extends Foo
def test(foo: Foo) = myMatch(foo){
case Bar(Data(x,y)) => y
case Baz(_, Data(x,y)) => y
}
the patterns of the case definitions of the partial function are transformed by the compiler as following (the Foo, Bar, and Baz classes are members of the object Example, too):
(data: Example.Data)Example.Bar((x: Int, y: String)Example.Data((x @ _), (y @ _)))
(string: String, data: Example.Data)Example.Baz(_, (x: Int, y: String)Example.Data((x @ _), (y @ _)))
This is the result of printing the patterns as hinted in the macro above (using show), the raw abstract syntax trees (printed using showRaw) look like this:
Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Bar)), List(Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Data)), List(Bind(newTermName("x"), Ident(nme.WILDCARD)), Bind(newTermName("y"), Ident(nme.WILDCARD))))))
Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Baz)), List(Ident(nme.WILDCARD), Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Data)), List(Bind(newTermName("x"), Ident(nme.WILDCARD)), Bind(newTermName("y"), Ident(nme.WILDCARD))))))
How do I write a pattern-quote which matches against these trees?
First of all, there is a special flavor of quasiquotes specifically for CaseDefs called cq:
override def transformCaseDefs(trees: List[CaseDef]) = trees map {
case caseDef @ cq"$pattern if $guard => $body" => ...
}
Secondly, you should use pq to deconstruct patterns:
pattern match {
case pq"$name @ $nested" => ...
case pq"$extractor($arg1, $arg2: _*)" => ...
...
}
If you are interested in the internals of the trees that are used for pattern matching, they are created by patvarTransformer, defined in TreeBuilder.scala.
On the other hand, if you're working with UnApply trees (which are produced after typechecking), I have bad news for you: quasiquotes currently don't support them. Follow SI-7789 to get notified when this is fixed.
After Den Shabalin pointed out that quasiquotes can't be used in this particular setting, I managed to find a pattern which matches against the patterns of a partial function's case definitions.
The key problem is that the constructor we want to match against (in our example Data) is stored in the TypeTree of the Apply node. Matching against a tree wrapped up in a TypeTree is a bit tricky, since the only extractor of this class (TypeTree()) isn't very helpful for this particular task. Instead we have to select the wrapped-up tree using the original method:
override def transform(tree: Tree) = tree match {
case Apply(constructor @ TypeTree(), args) => constructor.original match {
case Select(_, sym) if (sym == newTermName("Data")) => ...
}
}
In our use case the wrapped up tree is a Select node and we can now check if the symbol of this node is the one we are looking for.