Traverse/fold a nested case class in Scala without boilerplate code

I have some case classes for a mix of sum and product types:
sealed trait Leaf
case class GoodLeaf(value: Int) extends Leaf
case object BadLeaf extends Leaf
case class Middle(left: Leaf, right: Leaf)
case class Container(leaf: Leaf)
case class Top(middle : Middle, container: Container, extraLeaves : List[Leaf])
I want to do some fold-like operations with this Top structure. Examples include:
Count the occurrences of BadLeaf
Sum all the values in the GoodLeafs
Here is some code that does the operations:
object Top {
def fold[T](accu: T)(f : (T, Leaf) => T)(top: Top) = {
val allLeaves = top.container.leaf :: top.middle.left :: top.middle.right :: top.extraLeaves
allLeaves.foldLeft(accu)(f)
}
private def countBadLeaf(count: Int, leaf : Leaf) = leaf match {
case BadLeaf => count + 1
case _ => count
}
def countBad(top: Top): Int = fold(0)(countBadLeaf)(top)
private def sumGoodLeaf(count: Int, leaf : Leaf) = leaf match {
case GoodLeaf(v) => count + v
case _ => count
}
def sumGoodValues(top: Top) = fold(0)(sumGoodLeaf)(top)
}
The real life structure I am dealing with is significantly more complicated than the example I made up. Are there any techniques that could help me avoid writing lots of boilerplate code?
I already have the cats library as a dependency, so a solution that uses that lib would be preferred. I am open to including new dependencies in order to solve this problem.
For my particular example, the definition is not recursive, but I'd be interested in seeing a solution that works for recursive definitions also.

You could just create a function returning all the leaves of a Top, like you did with allLeaves, and then work with a List[Leaf] (using all the existing fold and other functions that the Scala standard library, Cats, etc. provide).
For example:
def topLeaves(top: Top): List[Leaf] =
top.container.leaf :: top.middle.left :: top.middle.right :: top.extraLeaves
val isBadLeaf: Leaf => Boolean = {
case BadLeaf => true
case _ => false
}
val leafValue: Leaf => Int = {
case GoodLeaf(v) => v
case _ => 0
}
Which you could use as
import cats.implicits._
// or
// import cats.instances.int._
// import cats.instances.list._
// import cats.syntax.foldable._
val leaves = topLeaves(someTop)
val badCount = leaves.count(isBadLeaf)
val badAndGood = leaves.partition(isBadLeaf) // (List[Leaf], List[Leaf])
val sumLeaves = leaves.foldMap(leafValue)
I am not sure if this helps with your actual use case? In general, with a heterogeneous structure (like your Top) you probably want to convert it into something more homogeneous (like a List[Leaf] or a Tree[Leaf]) that you can fold over.
If you have a recursive structure you could look at some talks about recursion schemes (with the Matryoshka library in Scala).
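For the recursive case, the same "flatten to a list, then fold" idea still applies; a minimal sketch (the RTree, RLeaf and RBranch names here are made up for illustration, they are not part of your model):
sealed trait RTree
case class RLeaf(value: Int) extends RTree
case class RBranch(children: List[RTree]) extends RTree

// Flatten the recursive structure once...
def leaves(t: RTree): List[RLeaf] = t match {
  case l: RLeaf          => List(l)
  case RBranch(children) => children.flatMap(leaves)
}

// ...then every fold comes for free, e.g. with Cats:
// import cats.implicits._
// leaves(tree).foldMap(_.value)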

Related

How to make tree mapping tail-recursive?

Suppose I have a tree data structure like this:
trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node
Suppose also I've got a function to map over leaves:
def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = root match {
case ln: LeafNode => f(ln)
case bn: BranchNode => BranchNode(bn.name, bn.children.map(ch => mapLeaves(ch, f)))
}
Now I am trying to make this function tail-recursive but am having a hard time figuring out how to do it. I've read this answer but still don't know how to make that binary-tree solution work for a multiway tree.
How would you rewrite mapLeaves to make it tail-recursive?
"Call stack" and "recursion" are merely popular design patterns that later got incorporated into most programming languages (and thus became mostly "invisible"). There is nothing that prevents you from reimplementing both with heap data structures. So, here is "the obvious" 1960's TAOCP retro-style solution:
trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node
def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = {
case class Frame(name: String, mapped: List[Node], todos: List[Node])
@annotation.tailrec
def step(stack: List[Frame]): Node = stack match {
// "return / pop a stack-frame"
case Frame(name, done, Nil) :: tail => {
val ret = BranchNode(name, done.reverse)
tail match {
case Nil => ret
case Frame(tn, td, tt) :: more => {
step(Frame(tn, ret :: td, tt) :: more)
}
}
}
case Frame(name, done, x :: xs) :: tail => x match {
// "recursion base"
case l @ LeafNode(_) => step(Frame(name, f(l) :: done, xs) :: tail)
// "recursive call"
case BranchNode(n, cs) => step(Frame(n, Nil, cs) :: Frame(name, done, xs) :: tail)
}
case Nil => throw new Error("shouldn't happen")
}
root match {
case l @ LeafNode(_) => f(l)
case b @ BranchNode(n, cs) => step(List(Frame(n, Nil, cs)))
}
}
The tail-recursive step function takes a reified stack with "stack frames". A "stack frame" stores the name of the branch node that is currently being processed, a list of child nodes that have already been processed, and the list of the remaining nodes that still must be processed later. This roughly corresponds to an actual stack frame of your recursive mapLeaves function.
With this data structure, returning from recursive calls corresponds to deconstructing a Frame object and either returning the final result or at least making the stack one frame shorter; recursive calls correspond to a step that prepends a Frame to the stack; and the base case (invoking f on leaves) does not create or remove any frames.
Once one understands how the usually invisible stack frames are represented explicitly, the translation is straightforward and mostly mechanical.
Example:
val example = BranchNode("x", List(
  BranchNode("y", List(
    LeafNode("a"),
    LeafNode("b")
  )),
  BranchNode("z", List(
    LeafNode("c"),
    BranchNode("v", List(
      LeafNode("d"),
      LeafNode("e")
    ))
  ))
))
println(mapLeaves(example, { case LeafNode(n) => LeafNode(n.toUpperCase) }))
Output (indented):
BranchNode(x,List(
  BranchNode(y,List(
    LeafNode(A),
    LeafNode(B)
  )),
  BranchNode(z,List(
    LeafNode(C),
    BranchNode(v,List(
      LeafNode(D),
      LeafNode(E)
    ))
  ))
))
It might be easier to implement this using a technique called a trampoline.
If you use it, you can have two functions calling each other in mutual recursion (with tailrec you are limited to a single self-recursive function). As with tailrec, the recursion is transformed into a plain loop.
Trampolines are implemented in the Scala standard library in scala.util.control.TailCalls.
import scala.util.control.TailCalls.{TailRec, done, tailcall}
def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = {
//two inner functions doing mutual recursion
//iterates recursively over children of node
def iterate(nodes: List[Node]): TailRec[List[Node]] = {
nodes match {
case x :: xs => tailcall(deepMap(x)) // mutually recursive call to deepMap, which maps this child node
.flatMap(node => iterate(xs).map(node :: _)) // TailRec can be flatMapped, so we chain the rest of the list
case Nil => done(Nil)
}
}
//recursively visits all branches
def deepMap(node: Node): TailRec[Node] = {
node match {
case ln: LeafNode => done(f(ln))
case bn: BranchNode => tailcall(iterate(bn.children))
.map(BranchNode(bn.name, _)) // mutually recursive call to iterate over the children
}
}
deepMap(root).result // unwrap the result into a plain Node
}
Instead of TailCalls you could also use Eval from Cats or Trampoline from scalaz.
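For illustration, here is a minimal sketch of the same mutual recursion written with cats.Eval (assuming the Node/LeafNode/BranchNode definitions above): Eval.defer plays the role of tailcall, Eval.now of done, and .value of .result.
import cats.Eval

def mapLeavesEval(root: Node, f: LeafNode => LeafNode): Node = {
  // iterate and deepMap call each other; every recursive step is suspended in Eval.defer,
  // and Eval's trampolined flatMap/map keep the call stack flat.
  def iterate(nodes: List[Node]): Eval[List[Node]] = nodes match {
    case x :: xs => Eval.defer(deepMap(x)).flatMap(node => iterate(xs).map(node :: _))
    case Nil     => Eval.now(Nil)
  }
  def deepMap(node: Node): Eval[Node] = node match {
    case ln: LeafNode   => Eval.now(f(ln))
    case bn: BranchNode => Eval.defer(iterate(bn.children)).map(BranchNode(bn.name, _))
  }
  deepMap(root).value
}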
With the TailCalls implementation above, the function worked without problems:
def build(counter: Int): Node = {
if (counter > 0) {
BranchNode("branch", List(build(counter-1)))
} else {
LeafNode("leaf")
}
}
val root = build(4000)
mapLeaves(root, x => x.copy(name = x.name.reverse)) // no problems
When I ran that example with your implementation it caused java.lang.StackOverflowError as expected.

Nested Scala case classes to/from CSV

There are many nice libraries for writing/reading Scala case classes to/from CSV files. I'm looking for something that goes beyond that, which can handle nested cases classes. For example, here a Match has two Players:
case class Player(name: String, ranking: Int)
case class Match(place: String, winner: Player, loser: Player)
val matches = List(
Match("London", Player("Jane",7), Player("Fred",23)),
Match("Rome", Player("Marco",19), Player("Giulia",3)),
Match("Paris", Player("Isabelle",2), Player("Julien",5))
)
I'd like to effortlessly (no boilerplate!) write/read matches to/from this CSV:
place,winner.name,winner.ranking,loser.name,loser.ranking
London,Jane,7,Fred,23
Rome,Marco,19,Giulia,3
Paris,Isabelle,2,Julien,5
Note the automated header line using the dot "." to form the column name for a nested field, e.g. winner.ranking. I'd be delighted if someone could demonstrate a simple way to do this (say, using reflection or Shapeless).
[Motivation. During data analysis it's convenient to have a flat CSV to play around with, for sorting, filtering, etc., even when case classes are nested. And it would be nice if you could load nested case classes back from such files.]
Since a case-class is a Product, getting the values of the various fields is relatively easy. Getting the names of the fields/columns does require using Java reflection.
The following function takes a list of case-class instances and returns a list of rows, each of which is a list of strings. It uses recursion to get the values and headers of child case-class instances.
def toCsv(p: List[Product]): List[List[String]] = {
def header(c: Class[_], prefix: String = ""): List[String] = {
c.getDeclaredFields.toList.flatMap { field =>
val name = prefix + field.getName
if (classOf[Product].isAssignableFrom(field.getType)) header(field.getType, name + ".")
else List(name)
}
}
def flatten(p: Product): List[String] =
p.productIterator.flatMap {
case p: Product => flatten(p)
case v: Any => List(v.toString)
}.toList
header(classOf[Match]) :: p.map(flatten)
}
However, constructing case-classes from CSV is far more involved, requiring reflection to get the types of the various fields, to create the values from the CSV strings, and to construct the case-class instances.
For simplicity (not saying the code is simple, just so it won't be further complicated), I assume that the order of columns in the CSV is the same as if the file was produced by the toCsv(...) function above.
The following function starts by creating a list of "instructions for how to process a single CSV row" (the instructions are also used to verify that the column headers in the CSV match the case-class properties). The instructions are then used to recursively produce one CSV row at a time.
import scala.reflect.ClassTag
def fromCsv[T <: Product](csv: List[List[String]])(implicit tag: ClassTag[T]): List[T] = {
trait Instruction {
val name: String
val header = true
}
case class BeginCaseClassField(name: String, clazz: Class[_]) extends Instruction {
override val header = false
}
case class EndCaseClassField(name: String) extends Instruction {
override val header = false
}
case class IntField(name: String) extends Instruction
case class StringField(name: String) extends Instruction
case class DoubleField(name: String) extends Instruction
def scan(c: Class[_], prefix: String = ""): List[Instruction] = {
c.getDeclaredFields.toList.flatMap { field =>
val name = prefix + field.getName
val fType = field.getType
if (fType == classOf[Int]) List(IntField(name))
else if (fType == classOf[Double]) List(DoubleField(name))
else if (fType == classOf[String]) List(StringField(name))
else if (classOf[Product].isAssignableFrom(fType)) BeginCaseClassField(name, fType) :: scan(fType, name + ".")
else throw new IllegalArgumentException(s"Unsupported field type: $fType")
} :+ EndCaseClassField(prefix)
}
def produce(instructions: List[Instruction], row: List[String], argAccumulator: List[Any]): (List[Instruction], List[String], List[Any]) = instructions match {
case IntField(_) :: tail => produce(tail, row.drop(1), argAccumulator :+ row.head.toString.toInt)
case StringField(_) :: tail => produce(tail, row.drop(1), argAccumulator :+ row.head.toString)
case DoubleField(_) :: tail => produce(tail, row.drop(1), argAccumulator :+ row.head.toString.toDouble)
case BeginCaseClassField(_, clazz) :: tail =>
val (instructionRemaining, rowRemaining, constructorArgs) = produce(tail, row, List.empty)
val newCaseClass = clazz.getConstructors.head.newInstance(constructorArgs.map(_.asInstanceOf[AnyRef]): _*)
produce(instructionRemaining, rowRemaining, argAccumulator :+ newCaseClass)
case EndCaseClassField(_) :: tail => (tail, row, argAccumulator)
case Nil if row.isEmpty => (Nil, Nil, argAccumulator)
case Nil => throw new IllegalArgumentException("Not all values from CSV row were used")
}
val instructions = BeginCaseClassField(".", tag.runtimeClass) :: scan(tag.runtimeClass)
assert(csv.head == instructions.filter(_.header).map(_.name), "CSV header doesn't match target case-class fields")
csv.drop(1).map(row => produce(instructions, row, List.empty)._3.head.asInstanceOf[T])
}
I've tested this using:
case class Player(name: String, ranking: Int, price: Double)
case class Match(place: String, winner: Player, loser: Player)
val matches = List(
Match("London", Player("Jane", 7, 12.5), Player("Fred", 23, 11.1)),
Match("Rome", Player("Marco", 19, 13.54), Player("Giulia", 3, 41.8)),
Match("Paris", Player("Isabelle", 2, 31.7), Player("Julien", 5, 16.8))
)
val csv = toCsv(matches)
val matchesFromCsv = fromCsv[Match](csv)
assert(matches == matchesFromCsv)
Obviously this should be optimized and hardened if you ever want to use this for production...

Traverse and Modify a Heterogeneous Directed Acyclic Graph in Scala

I have the following directed acyclic graph data structure, implemented in Scala as follows:
abstract class Node // Generic abstract node
/** Various kind of leaf nodes */
case class LeafNodeA(x: String) extends Node
case class LeafNodeB(x: Int) extends Node
/** Various kind of inner nodes */
case class InnerNode1(x: String, depRoleA: Node) extends Node
case class InnerNode2(x: String, y: Double, depRoleX: Node, depRoleY: Node) extends Node
case class InnerNode3(x: List[Int], depRoleA: Node, y: Int,
depRoleB: Node, depRoleG: Node) extends Node
In this structure a node can be a dependency of multiple nodes, therefore it is not a Tree but a Directed Acyclic Graph. In addition the structure is not even balanced (nodes have different numbers of dependencies).
The problem of traversal
Notice that I have called the various dependency fields of the case classes by different names, since they represent different roles in the dependencies (for example, depRoleX has a different role than depRoleY in an InnerNode2 node). For this reason I don't think it is possible to store the dependencies of each node in a List[Node], as in any trivial tree/DAG implementation you can find out there, because the meaning of each dependency field is different.
Of course when I traverse this structure I have to do pattern matching in order to understand the type of node I am dealing with at the current recursion step:
// Random function which returns the list of all the String attributes of the nodes
def getAllStrings(dag: Node): List[String] = {
dag match {
case LeafNodeA(x) => List(x)
case LeafNodeB(_) => List()
case InnerNode1(x, dr) => List(x) ::: getAllStrings(dr)
case InnerNode2(x, _, dX, dY) => List(x) ::: getAllStrings(dX) ::: getAllStrings(dY)
case InnerNode3(_, dA, _, dB, dG) => getAllStrings(dA) ::: getAllStrings(dB) ::: getAllStrings(dG)
}
}
Now suppose that instead of these 5 relatively simple node types I have around 20 types of node. The previous function would become extremely long and repetitive (a case statement for each node type). Even worse: every time I want to do a traversal I have to do the same thing.
Thinking about this problem I came up with two solutions.
External method for traversal
The first (obvious) way to deal with this is to modularize the previous method by defining a generic DAG traversal function:
object DAGManipulator {
def getDependencies(dag: Node): List[Node] = {
dag match {
case LeafNodeA(_) => List()
case LeafNodeB(_) => List()
case InnerNode1(_, dr) => List(dr)
case InnerNode2(_, _, dX, dY) => List(dX, dY)
case InnerNode3(_, dA, _, dB, dG) => List(dA, dB, dG)
}
}
}
In this way, every time I need the dependencies of a node I can rely on this static function.
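For example, the generic traversal can then be reused so that each query only states what it extracts (collectNodes and allStrings below are just illustrative sketch names):
// Collect every reachable node by repeatedly asking DAGManipulator for dependencies
// (shared nodes are visited once per incoming edge, just like in the recursive version).
def collectNodes(dag: Node): List[Node] =
  dag :: DAGManipulator.getDependencies(dag).flatMap(collectNodes)

// A query then becomes a simple collect over the flattened node list.
def allStrings(dag: Node): List[String] =
  collectNodes(dag).collect {
    case LeafNodeA(x)           => x
    case InnerNode1(x, _)       => x
    case InnerNode2(x, _, _, _) => x
  }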
Abstract class method for getting the dependencies
The second solution I came up with is to give every node an additional method, in the following way:
abstract class Node {
  def getDependencies: List[Node]
}
case class LeafNodeA(x: String) extends Node {
  override def getDependencies: List[Node] = List()
}
case class LeafNodeB(x: Int) extends Node {
  override def getDependencies: List[Node] = List()
}
/** Various kind of inner nodes */
case class InnerNode1(x: String, depRoleA: Node) extends Node {
  override def getDependencies: List[Node] = List(depRoleA)
}
case class InnerNode2(x: String, y: Double, depRoleX: Node, depRoleY: Node) extends Node {
  override def getDependencies: List[Node] = List(depRoleX, depRoleY)
}
case class InnerNode3(x: List[Int], depRoleA: Node, y: Int,
                      depRoleB: Node, depRoleG: Node) extends Node {
  override def getDependencies: List[Node] = List(depRoleA, depRoleB, depRoleG)
}
I don't like either of the previous solutions:
The first one must be updated every time a new node type is added to the hierarchy. In addition to this, it delegates a fundamental feature of the DAG structure (traversal) to an external object, which I find very unpleasant from the software engineering point of view.
The second solution in my opinion is even worse, because every node type basically has to redundantly state its dependencies (once in its fields and once in the getDependencies method). I find this very ugly and prone to programming errors.
Do you have a better solution to this problem?
The problem of updating
The second problem I have to deal with is the updating/modification of the data structure.
Suppose that I have a DAG defined in the following way.
val l1 = LeafNodeB(1)
val dag =
InnerNode3(List(1, 2, 3),
InnerNode1("InnerNode1", LeafNodeA("leafA1")),
1, l1, InnerNode2("InnerNode2", 2, l1, LeafNodeA("leafA2")))
corresponding to a structure in which l1 is shared (it is a dependency of both the InnerNode3 and the InnerNode2), which is what makes this a DAG rather than a tree.
Suppose that I want to change the LeafNodeA("leafA1") (which is a dependency of the InnerNode1) to, for example, l1, which is a LeafNodeB.
This is the kind of operation that I need to do:
def modify(dag: Node): Node = {
dag match {
case x: InnerNode1 => if(x.x == "InnerNode1") x.copy(depRoleA = l1) else x
case x: LeafNodeB => x
case x: LeafNodeA => x
case x: InnerNode2 => x.copy(depRoleX = modify(x.depRoleX), depRoleY = modify(x.depRoleY))
case x: InnerNode3 => x.copy(depRoleA = modify(x.depRoleA), depRoleB = modify(x.depRoleB), depRoleG = modify(x.depRoleG))
}
}
Again, consider the possibility of having more than 20 node types: this update method becomes impractical, and the same goes for every other update method I can think of.
In addition, this time I did not come up with a strategy for factoring out/modularizing this "recursive traversal update" of the nested structure. I have to check every possible node type in order to understand how to use the copy method of the various case classes.
Do you have a better solution/design for this update strategy?
To address this issue:
"The first one must be updated every time a new node type is added to the hierarchy. In addition to this, it delegates a fundamental feature of the DAG structure (traversal) to an external object, which I find very unpleasant from the software engineering point of view."
I think this is not really a drawback; it's something that's pretty core to the OO vs FP clash inherent in Scala. I would say that having your node classes be "dumb" data holders, with a separate code path for traversing them, is a good thing. And sure, you have to add a line there every time you add a node type, but the compiler can warn you about that if you don't.
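A small sketch of that compiler-warning point, with a trimmed-down two-node hierarchy (this assumes the base class is declared sealed, which the question's abstract class Node is not as written):
sealed abstract class SealedNode
case class SealedLeaf(x: Int) extends SealedNode
case class SealedInner(dep: SealedNode) extends SealedNode

def getDeps(n: SealedNode): List[SealedNode] = n match {
  case SealedLeaf(_) => Nil
  // Omitting the SealedInner case here would make scalac emit a
  // "match may not be exhaustive" warning at compile time.
  case SealedInner(dep) => List(dep)
}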
Anyway, this might be overkill and it's a bit of an undertaking, but you may want to look into Matryoshka, which generalizes recursive data structures like this. It requires a bit of plumbing to translate your data types into the scheme that is expected, and define a functor for that:
abstract class NodeF[+A] // Generic abstract node
/** Various kind of leaf nodes */
case class LeafNodeA(x: String) extends NodeF[Nothing]
case class LeafNodeB(x: Int) extends NodeF[Nothing]
/** Various kind of inner nodes */
case class InnerNode1[A](x: String, depRoleA: A) extends NodeF[A]
case class InnerNode2[A](x: String, y: Double, depRoleX: A, depRoleY: A) extends NodeF[A]
case class InnerNode3[A](x: List[Int], depRoleA: A, y: Int,
depRoleB: A, depRoleG: A) extends NodeF[A]
implicit val nodeFunctor: Functor[NodeF] = new Functor[NodeF] {
def map[A, B](fa: NodeF[A])(f: A => B): NodeF[B] = fa match {
case LeafNodeA(x) => LeafNodeA(x)
case LeafNodeB(x) => LeafNodeB(x)
case InnerNode1(x, depA) => InnerNode1(x, f(depA))
case InnerNode2(x, y, depX, depY) => InnerNode2(x, y, f(depX), f(depY))
case InnerNode3(x, depA, y, depB, depG) => InnerNode3(x, f(depA), y, f(depB), f(depG))
}
}
But then it essentially hides the recursion from you and you can more easily define these kinds of things:
type FixNode = Fix[NodeF]
def someExprGeneric[T](implicit T : Corecursive.Aux[T, NodeF]): T =
InnerNode2("hello", 1.0, InnerNode1("world", LeafNodeA("!").embed).embed, LeafNodeB(1).embed).embed
val someExpr = someExprGeneric[FixNode]
def getStrings: Algebra[NodeF, List[String]] = {
case LeafNodeA(x) => List(x)
case LeafNodeB(_) => List()
case InnerNode1(x, depA) => x :: depA
case InnerNode2(x, _, depX, depY) => x :: depX ::: depY
case InnerNode3(_, depA, _, depB, depG) => depA ::: depB ::: depG
}
someExpr.cata(getStrings) // List("hello", "world", "!")
Perhaps that's not that much cleaner than what you have, but it at least separates the recursive traversal logic from the "single step" evaluation logic. But I think where it shines a bit more is when updating:
def expandToUniverse: Algebra[NodeF, FixNode] = {
case InnerNode1("world", dep) => InnerNode1("universe", dep).embed
case x => x.embed
}
someExpr.cata(expandToUniverse).cata(getStrings) // List("hello", "universe", "!")
Because you've delegated out that recursion, you only have to implement the case(s) you actually care about.

How to match against the pattern of a partial function's case definition in a Scala macro?

As part of a macro, I want to manipulate the case definitions of a partial function.
To do so, I use a Transformer to manipulate the case definitions of the partial function and a Traverser to inspect the patterns of the case definitions:
def myMatchImpl[A: c.WeakTypeTag, B: c.WeakTypeTag](c: Context)
(expr: c.Expr[A])(patterns: c.Expr[PartialFunction[A, B]]): c.Expr[B] = {
import c.universe._
val transformer = new Transformer {
override def transformCaseDefs(trees: List[CaseDef]) = trees map {
case caseDef @ CaseDef(pattern, guard, body) => {
// println(show(pattern))
val traverser = new Traverser {
override def traverse(tree: Tree) = tree match {
// match against a specific pattern
}
}
traverser.traverse(pattern)
}
}
}
val transformedPartialFunction = transformer.transform(patterns.tree)
c.Expr[B](q"$transformedPartialFunction($expr)")
}
Now let us assume the interesting data I want to match against is represented by the class Data (which is part of the object Example):
case class Data(x: Int, y: String)
When now invoking the macro on the example below
abstract class Foo
case class Bar(data: Data) extends Foo
case class Baz(string: String, data: Data) extends Foo
def test(foo: Foo) = myMatch(foo){
case Bar(Data(x,y)) => y
case Baz(_, Data(x,y)) => y
}
the patterns of the case definitions of the partial function are transformed by the compiler as follows (the Foo, Bar, and Baz classes are members of the object Example, too):
(data: Example.Data)Example.Bar((x: Int, y: String)Example.Data((x # _), (y # _)))
(string: String, data: Example.Data)Example.Baz(_, (x: Int, y: String)Example.Data((x # _), (y # _)))
This is the result of printing the patterns as hinted at in the macro above (using show); the raw abstract syntax trees (printed using showRaw) look like this:
Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Bar)), List(Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Data)), List(Bind(newTermName("x"), Ident(nme.WILDCARD)), Bind(newTermName("y"), Ident(nme.WILDCARD))))))
Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Baz)), List(Ident(nme.WILDCARD), Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Data)), List(Bind(newTermName("x"), Ident(nme.WILDCARD)), Bind(newTermName("y"), Ident(nme.WILDCARD))))))
How do I write a pattern-quote which matches against these trees?
First of all, there is a special flavor of quasiquotes specifically for CaseDefs called cq:
override def transformCaseDefs(trees: List[CaseDef]) = trees map {
case caseDef @ cq"$pattern if $guard => $body" => ...
}
Secondly, you should use pq to deconstruct patterns:
pattern match {
case pq"$name # $nested" => ...
case pq"$extractor($arg1, $arg2: _*)" => ...
...
}
If you are interested in internals of trees that are used for pattern matching they are created by patvarTransformer defined in TreeBuilder.scala
On the other hand, if you are working with UnApply trees (which are produced after typechecking) I have bad news for you: quasiquotes currently don't support them. Follow SI-7789 to get notified when this is fixed.
After Den Shabalin pointed out that quasiquotes can't be used in this particular setting, I managed to find a pattern which matches against the patterns of a partial function's case definitions.
The key problem is that the constructor we want to match against (in our example Data) is stored in the TypeTree of the Apply node. Matching against a tree wrapped up in a TypeTree is a bit tricky, since the only extractor of this class (TypeTree()) isn't very helpful for this particular task. Instead we have to select the wrapped-up tree using the original method:
override def transform(tree: Tree) = tree match {
case Apply(constructor @ TypeTree(), args) => constructor.original match {
case Select(_, sym) if (sym == newTermName("Data")) => ...
}
}
In our use case the wrapped-up tree is a Select node, and we can now check whether the symbol of this node is the one we are looking for.

map over structure with only partial match

I have a tree-like structure of abstract classes and case classes representing an Abstract Syntax Tree of a small language.
For the top abstract class I've implemented a map method:
abstract class AST {
...
def map(f: (AST => AST)): AST = {
val b1 = this match {
case s: STRUCTURAL => s.smap(f) // structural node for example IF(expr,truebranch,falsebranch)
case _ => this // leaf, like ASSIGN(x,2)
}
f(b1)
}
...
The smap of a structural node (here the IF node, with trueb and falseb branches) is defined like this:
override def smap(f: AST => AST) = {
this.copy(trueb = trueb.map(f), falseb = falseb.map(f))
}
Now I'm writing different "transformations" to insert, remove, and change nodes in the AST.
For example, remove adjacent NOP nodes from blocks:
def handle_list(l: List[AST]): List[AST] = l match {
case (NOP::NOP::tl) => handle_list(tl)
case h::tl => h::handle_list(tl)
case Nil => Nil
}
ast.map {
case BLOCK(listofstatements) => handle_list(listofstatements)
}
If I write it like this, I end up with a MatchError, and I can "fix it" by changing the above map call to:
ast.map {
case BLOCK(listofstatements) => handle_list(listofstatements)
case a => a
}
Should I just live with all those case a => a, or could I improve my map method (or other parts) in some way?
Make the argument to map a PartialFunction:
def map(f: PartialFunction[AST, AST]): AST = {
val idAST: PartialFunction[AST, AST] = {case a => a}
val g = f.orElse(idAST)
val b1 = this match {
case s: STRUCTURAL => s.smap(g)
case _ => this
}
g(b1)
}
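With that change, the call site from the question compiles without a catch-all case, since unmatched nodes fall through to the identity case inside map (a sketch; it assumes the rewritten statement list should be wrapped back into a BLOCK node):
val cleaned = ast.map {
  case BLOCK(listofstatements) => BLOCK(handle_list(listofstatements))
}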
If tree transformations are more than a minor aspect of your project, I highly recommend you use Kiama's Rewriter module to implement them. It implements Stratego-like strategy-driven transformations. It has a very rich set of strategies and strategy combinators that permit a complete separation of traversal logic (which for the vast majority of cases can be taken "off the shelf" from the supplied strategies and combinators) from (local) transformations (which are specific to your AST and you supply, of course).