Traverse and Modify a Heterogeneous Directed Acyclic Graph in Scala

Good morning everybody.
I have the following directed acyclic graph data structure, implemented in Scala as follows:
abstract class Node // Generic abstract node
/** Various kind of leaf nodes */
case class LeafNodeA(x: String) extends Node
case class LeafNodeB(x: Int) extends Node
/** Various kind of inner nodes */
case class InnerNode1(x: String, depRoleA: Node) extends Node
case class InnerNode2(x: String, y: Double, depRoleX: Node, depRoleY: Node) extends Node
case class InnerNode3(x: List[Int], depRoleA: Node, y: Int,
depRoleB: Node, depRoleG: Node) extends Node
In this structure a node can be a dependency of multiple nodes, therefore it is not a Tree but a Directed Acyclic Graph. In addition the structure is not even balanced (nodes have different numbers of dependencies).
The problem of traversal
Notice that I have given the dependency fields of the case classes different names, since they represent different roles in the dependencies (for example, depRoleX has a different role than depRoleY in an InnerNode2 node). For this reason I don't think it is possible to store the dependencies of each node in a List[Node], as in the usual tree/DAG implementations you can find out there, because the meaning of each dependency field is different.
Of course when I traverse this structure I have to do pattern matching in order to understand the type of node I am dealing with at the current recursion step:
// Example function which returns the list of all the String attributes of the nodes
def getAllStrings(dag: Node): List[String] = {
  dag match {
    case LeafNodeA(x) => List(x)
    case LeafNodeB(_) => List() // LeafNodeB(_) matches instances; a bare LeafNodeB would match the companion object
    case InnerNode1(x, dr) => List(x) ::: getAllStrings(dr)
    case InnerNode2(x, _, dX, dY) => List(x) ::: getAllStrings(dX) ::: getAllStrings(dY)
    case InnerNode3(_, dA, _, dB, dG) => getAllStrings(dA) ::: getAllStrings(dB) ::: getAllStrings(dG)
  }
}
Now suppose that instead of these 5 relatively simple node types I have around 20 types of node. The previous function would become extremely long and repetitive (a case statement for each node type). Even worse: every time I want to do a traversal I have to do the same thing.
Thinking about this problem I came up with two solutions.
External method for traversal
The first (obvious) way to deal with this is to modularize the previous method by defining a generic DAG traversal function:
object DAGManipulator {
  def getDependencies(dag: Node): List[Node] = {
    dag match {
      case LeafNodeA(_) => List()
      case LeafNodeB(_) => List()
      case InnerNode1(_, dr) => List(dr)
      case InnerNode2(_, _, dX, dY) => List(dX, dY)
      case InnerNode3(_, dA, _, dB, dG) => List(dA, dB, dG)
    }
  }
}
In this way, every time I need the dependencies of a node I can rely on this static function.
Abstract class method for getting the dependencies
The second solution I came up with is to give every node an additional method in the following way:
abstract class Node {
  def getDependencies: List[Node]
}
case class LeafNodeA(x: String) extends Node {
  override def getDependencies: List[Node] = List()
}
case class LeafNodeB(x: Int) extends Node {
  override def getDependencies: List[Node] = List()
}
/** Various kind of inner nodes */
case class InnerNode1(x: String, depRoleA: Node) extends Node {
  override def getDependencies: List[Node] = List(depRoleA)
}
case class InnerNode2(x: String, y: Double, depRoleX: Node, depRoleY: Node) extends Node {
  override def getDependencies: List[Node] = List(depRoleX, depRoleY)
}
case class InnerNode3(x: List[Int], depRoleA: Node, y: Int,
                      depRoleB: Node, depRoleG: Node) extends Node {
  override def getDependencies: List[Node] = List(depRoleA, depRoleB, depRoleG)
}
I don't like either of the previous solutions:
The first one must be updated every time a new node type is added to the hierarchy. In addition to this, it delegates a fundamental feature of the DAG structure (traversal) to an external object, which I find very unpleasant from the software engineering point of view.
The second solution is, in my opinion, even worse, because every node type has to state its dependencies redundantly (once in its fields and once in the getDependencies method). I find this very ugly and prone to programming errors.
Do you have a better solution to this problem?
The problem of updating
The second problem I have to deal with is the updating/modification of the data structure.
Suppose that I have a DAG defined in the following way.
val l1 = LeafNodeB(1)
val dag =
InnerNode3(List(1, 2, 3),
InnerNode1("InnerNode1", LeafNodeA("leafA1")),
1, l1, InnerNode2("InnerNode2", 2, l1, LeafNodeA("leafA2")))
corresponding to this structure.
Suppose that I want to change the LeafNodeA("leafA1") (which is a dependency of the InnerNode1) to, for example, l1, which is a LeafNodeB.
This is the kind of operation that I need to do:
def modify(dag: Node): Node = {
  dag match {
    case x: InnerNode1 =>
      if (x.x == "InnerNode1") x.copy(depRoleA = l1)
      else x.copy(depRoleA = modify(x.depRoleA)) // keep searching below non-matching InnerNode1 nodes
    case x: LeafNodeB => x
    case x: LeafNodeA => x
    case x: InnerNode2 => x.copy(depRoleX = modify(x.depRoleX), depRoleY = modify(x.depRoleY))
    case x: InnerNode3 => x.copy(depRoleA = modify(x.depRoleA), depRoleB = modify(x.depRoleB), depRoleG = modify(x.depRoleG))
  }
}
Again, consider the possibility of having more than 20 node types. This update method would become impractical, and the same holds for every other update method I can think of.
In addition, this time I did not come up with a different strategy for factoring out this "recursive traversal update" of the nested structure. I have to check for every possible node type in order to know how to use the copy method of the various case classes.
Do you have a better solution / design for this update strategy?

To address this issue:
The first one must be updated every time a new node type is added to
the hierarchy. In addition to this, it delegates a fundamental feature
of the DAG structure (traversal) to an external object, which I find
very unpleasant from the software engineering point of view.
I think this is not really a drawback; it is something pretty core to the OO vs FP clash inherent in Scala. I would say that having your node classes be "dumb" data holders, with a separate code path for traversing them, is a good thing. And sure, you have to add a line there every time you add a node type, but the compiler can warn you if you forget, provided the base class is sealed so that the match can be checked for exhaustiveness.
Anyway, this might be overkill and it's a bit of an undertaking, but you may want to look into Matryoshka, which generalizes recursive data structures like this. It requires a bit of plumbing to translate your data types into the scheme that is expected, and define a functor for that:
import scalaz.Functor // Matryoshka is built on Scalaz
abstract class NodeF[+A] // Generic abstract node
/** Various kind of leaf nodes */
case class LeafNodeA(x: String) extends NodeF[Nothing]
case class LeafNodeB(x: Int) extends NodeF[Nothing]
/** Various kind of inner nodes */
case class InnerNode1[A](x: String, depRoleA: A) extends NodeF[A]
case class InnerNode2[A](x: String, y: Double, depRoleX: A, depRoleY: A) extends NodeF[A]
case class InnerNode3[A](x: List[Int], depRoleA: A, y: Int,
                         depRoleB: A, depRoleG: A) extends NodeF[A]
implicit val nodeFunctor: Functor[NodeF] = new Functor[NodeF] {
  // Apply f to every dependency slot and leave the other fields untouched
  def map[A, B](fa: NodeF[A])(f: A => B): NodeF[B] = fa match {
    case LeafNodeA(x) => LeafNodeA(x)
    case LeafNodeB(x) => LeafNodeB(x)
    case InnerNode1(x, depA) => InnerNode1(x, f(depA))
    case InnerNode2(x, y, depX, depY) => InnerNode2(x, y, f(depX), f(depY))
    case InnerNode3(x, depA, y, depB, depG) => InnerNode3(x, f(depA), y, f(depB), f(depG))
  }
}
But then it essentially hides the recursion from you and you can more easily define these kinds of things:
import matryoshka._
import matryoshka.data.Fix
import matryoshka.implicits._
type FixNode = Fix[NodeF]
def someExprGeneric[T](implicit T : Corecursive.Aux[T, NodeF]): T =
  InnerNode2("hello", 1.0, InnerNode1("world", LeafNodeA("!").embed).embed, LeafNodeB(1).embed).embed
val someExpr = someExprGeneric[FixNode]
// One step of the fold: the dependencies have already been replaced by their results
def getStrings: Algebra[NodeF, List[String]] = {
  case LeafNodeA(x) => List(x)
  case LeafNodeB(_) => List()
  case InnerNode1(x, depA) => x :: depA
  case InnerNode2(x, _, depX, depY) => x :: depX ::: depY
  case InnerNode3(_, depA, _, depB, depG) => depA ::: depB ::: depG
}
someExpr.cata(getStrings) // List("hello", "world", "!")
Perhaps that's not that much cleaner than what you have, but it at least separates the recursive traversal logic from the "single step" evaluation logic. But I think where it shines a bit more is when updating:
def expandToUniverse: Algebra[NodeF, FixNode] = {
  case InnerNode1("world", dep) => InnerNode1("universe", dep).embed
  case x => x.embed
}
someExpr.cata(expandToUniverse).cata(getStrings) // List("hello", "universe", "!")
Because you've delegated out that recursion, you only have to implement the case(s) you actually care about.
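If pulling in Matryoshka feels like too much, the same split between "how to recurse" and "what to do at one node" can be sketched by hand. What follows is only an illustration on a cut-down version of the question's hierarchy, not a replacement for the library. The names Dag, mapChildren, children, transformUp and collectNodes are hypothetical and do not come from the question or from any library.

```scala
// Cut-down version of the question's node types; sealed so the compiler
// can check the matches below for exhaustiveness.
sealed trait Node
final case class LeafNodeA(x: String) extends Node
final case class LeafNodeB(x: Int) extends Node
final case class InnerNode1(x: String, depRoleA: Node) extends Node
final case class InnerNode2(x: String, y: Double, depRoleX: Node, depRoleY: Node) extends Node

// Hypothetical helper object: the only place that knows where each
// node type keeps its dependencies.
object Dag {
  // Rebuild a node with f applied to each of its direct dependencies.
  def mapChildren(n: Node)(f: Node => Node): Node = n match {
    case l: LeafNodeA  => l
    case l: LeafNodeB  => l
    case i: InnerNode1 => i.copy(depRoleA = f(i.depRoleA))
    case i: InnerNode2 => i.copy(depRoleX = f(i.depRoleX), depRoleY = f(i.depRoleY))
  }

  // Direct dependencies of a node, in declaration order.
  def children(n: Node): List[Node] = n match {
    case _: LeafNodeA | _: LeafNodeB => Nil
    case InnerNode1(_, a)            => List(a)
    case InnerNode2(_, _, x, y)      => List(x, y)
  }

  // Bottom-up rewrite: transform the dependencies first, then apply
  // the rule to the rebuilt node if the rule is defined for it.
  def transformUp(n: Node)(rule: PartialFunction[Node, Node]): Node = {
    val rebuilt = mapChildren(n)(c => transformUp(c)(rule))
    rule.applyOrElse(rebuilt, (m: Node) => m)
  }

  // Pre-order collection of whatever pf extracts from the nodes it matches.
  def collectNodes[A](n: Node)(pf: PartialFunction[Node, A]): List[A] =
    pf.lift(n).toList ::: children(n).flatMap(c => collectNodes(c)(pf))
}
```

With this in place, each traversal or update only spells out the cases it cares about, for example Dag.transformUp(dag) { case InnerNode1("world", d) => InnerNode1("universe", d) }. Note that this sketch treats the DAG as a tree: a node shared by several parents is visited (and rebuilt) once per path that reaches it.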

Related

Complexity of linked-lists concatenation

I'm still pretty much new to functional programming and experimenting with algebraic data types. I implemented LinkedList as follows:
sealed abstract class LinkedList[+T]{
def flatMap[T2](f: T => LinkedList[T2]) = LinkedList.flatMap(this)(f)
def foreach(f: T => Unit): Unit = LinkedList.foreach(this)(f)
}
final case class Next[T](t: T, linkedList: LinkedList[T]) extends LinkedList[T]
case object Stop extends LinkedList[Nothing]
object LinkedList{
private def connect[T](left: LinkedList[T], right: LinkedList[T]): LinkedList[T] = left match {
case Stop => right
case Next(t, l) => Next(t, connect(l, right))
}
private def flatMap[T, T2](list: LinkedList[T])(f: T => LinkedList[T2]): LinkedList[T2] = list match {
case Stop => Stop
case Next(t, l) => connect(f(t), flatMap(l)(f))
}
private def foreach[T](ll: LinkedList[T])(f: T => Unit): Unit = ll match {
case Stop => ()
case Next(t, l) =>
f(t)
foreach(l)(f)
}
}
What I want to measure is the complexity of the flatMap operation.
As far as I could prove by induction, the complexity of flatMap is O(n), where n is the length of the flatMapped list.
The question is: how can the list be designed so that concatenation takes constant time? What kind of class extending LinkedList should we design to achieve that? Or is there some other way?
One approach would be to keep track of the end of your list, and make your list nodes mutable so you can just modify the end to point to the start of the next list.
I don't see a way to do it with immutable nodes, since modifying what follows a tail node requires modifying everything before it. To get constant-time concatenation with immutable data structures, you need something other than a simple linked list.
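As a rough sketch of what "something other than a simple linked list" could look like, one option is to represent a concatenation as a tree node that just points at both operands, so that appending allocates a single node and the elements are only walked when the sequence is finally read. This is similar in spirit to Chain in Cats, but the code below is an independent illustration; the names Cat, Single and Append are hypothetical.

```scala
import scala.annotation.tailrec

// Immutable sequence with O(1) concatenation: ++ never walks its operands.
sealed trait Cat[+A] {
  def ++[B >: A](that: Cat[B]): Cat[B] = (this, that) match {
    case (Cat.Empty, r) => r
    case (l, Cat.Empty) => l
    case (l, r)         => Cat.Append(l, r) // one allocation, no traversal
  }

  // Reading the elements costs O(n); an explicit stack keeps it stack-safe.
  def toList: List[A] = {
    @tailrec
    def go(stack: List[Cat[A]], acc: List[A]): List[A] = stack match {
      case Nil                      => acc.reverse
      case Cat.Empty :: rest        => go(rest, acc)
      case Cat.Single(a) :: rest    => go(rest, a :: acc)
      case Cat.Append(l, r) :: rest => go(l :: r :: rest, acc)
    }
    go(List(this), Nil)
  }
}

object Cat {
  case object Empty extends Cat[Nothing]
  final case class Single[+A](a: A) extends Cat[A]
  final case class Append[+A](left: Cat[A], right: Cat[A]) extends Cat[A]

  def apply[A](as: A*): Cat[A] =
    as.foldLeft(Empty: Cat[A])((acc, a) => acc ++ Single(a))
}
```

The trade-off is that head and tail access are no longer constant time, so a flatMap built on ++ pays the traversal cost once, when the result is consumed, rather than at every concatenation.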

Functional Breadth First Search in Scala with the State Monad

I'm trying to implement a functional breadth-first search in Scala to compute the distances between a given node and all the other nodes in an unweighted graph. I've used a State monad for this, with the following signature:
case class State[S,A](run:S => (A,S))
Other functions such as map, flatMap, sequence, modify, etc. are similar to what you'd find in a standard State monad.
Here's the code:
import scala.annotation.tailrec
import scala.collection.immutable.Queue
case class Node(label: Int)
case class BfsState(q: Queue[Node], nodesList: List[Node], discovered: Set[Node], distanceFromSrc: Map[Node, Int]) {
  val isTerminated = q.isEmpty
}
case class Graph(adjList: Map[Node, List[Node]]) {
  def bfs(src: Node): (List[Node], Map[Node, Int]) = {
    val initialBfsState = BfsState(Queue(src), List(src), Set(src), Map(src -> 0))
    val output = bfsComp(initialBfsState)
    (output.nodesList,output.distanceFromSrc)
  }
  @tailrec
  private def bfsComp(currState:BfsState): BfsState = {
    if (currState.isTerminated) currState
    else bfsComp(searchNode.run(currState)._2)
  }
  private def searchNode: State[BfsState, Unit] = for {
    node <- State[BfsState, Node](s => {
      val (n, newQ) = s.q.dequeue
      (n, s.copy(q = newQ))
    })
    s <- get
    _ <- sequence(adjList(node).filter(!s.discovered(_)).map(n => {
      modify[BfsState](s => {
        s.copy(s.q.enqueue(n), n :: s.nodesList, s.discovered + n, s.distanceFromSrc + (n -> (s.distanceFromSrc(node) + 1)))
      })
    }))
  } yield ()
}
Could you please advise on the following:
Should the State Transition on dequeue in the searchNode function be a member of BfsState itself?
How do I make this code more performant/concise/readable?
First off, I suggest moving all the private defs related to bfs into bfs itself. This is the convention for methods that are solely used to implement another.
Second, I suggest simply not using State for this matter. State (like most monads) is about composition. It is useful when you have many things that all need access to the same global state. In this case, BfsState is specialized to bfs, will likely never be used anywhere else (it might be a good idea to move the class into bfs too), and the State itself is always run, so the outer world never sees it. (In many cases, this is fine, but here the scope is too small for State to be useful.) It'd be much cleaner to pull the logic of searchNode into bfsComp itself.
Third, I don't understand why you need both nodesList and discovered, when you can just call _.toList on discovered once you've done your computation. I've left it in in my reimplementation, though, in case there's more to this code that you haven't displayed.
def bfsComp(old: BfsState): BfsState = {
  if (old.q.isEmpty) old // You don't need isTerminated, I think
  else {
    val (currNode, newQ) = old.q.dequeue
    val newState = old.copy(q = newQ)
    val next = adjList(currNode)
      .filterNot(old.discovered) // Set[T] <: T => Boolean and filterNot means you don't need to write !s.discovered(_)
      .foldLeft(newState) { case (BfsState(q, nodes, discovered, distance), adjNode) =>
        BfsState(
          q.enqueue(adjNode),
          adjNode :: nodes,
          discovered + adjNode,
          distance + (adjNode -> (distance(currNode) + 1))
        )
      }
    bfsComp(next) // keep going until the queue is empty
  }
}
def bfs(src: Node): (List[Node], Map[Node, Int]) = {
// I suggest moving BfsState and bfsComp into this method
val output = bfsComp(BfsState(Queue(src), List(src), Set(src), Map(src -> 0)))
(output.nodesList, output.distanceFromSrc)
// Could get rid of nodesList and say output.discovered.toList
}
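For comparison, here is a minimal self-contained sketch of the same idea pushed one step further: no State, no BfsState class, and the distance map doubling as the "discovered" set. This is a variation for illustration rather than the code above; the method name distances and the helper loop are hypothetical, and it assumes that nodes missing from adjList simply have no neighbours.

```scala
import scala.annotation.tailrec
import scala.collection.immutable.Queue

final case class Node(label: Int)

final case class Graph(adjList: Map[Node, List[Node]]) {
  // Distance (in edges) from src to every node reachable from it.
  def distances(src: Node): Map[Node, Int] = {
    @tailrec
    def loop(q: Queue[Node], dist: Map[Node, Int]): Map[Node, Int] =
      q.dequeueOption match {
        case None => dist // queue exhausted: every reachable node has a distance
        case Some((curr, rest)) =>
          // Neighbours not seen so far; a node is "discovered" once it has a distance.
          val fresh = adjList.getOrElse(curr, Nil).distinct.filterNot(dist.contains)
          val d = dist(curr) + 1
          loop(rest ++ fresh, dist ++ fresh.map(n => n -> d))
      }
    loop(Queue(src), Map(src -> 0))
  }
}
```

Because loop is a local tail-recursive function, all the bookkeeping stays private to distances, which is the scoping the first suggestion is after.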
In the event that you think you do have a good reason for using State here, here are my thoughts.
You use def searchNode. The point of a State is that it is pure and immutable, so it should be a val, or else you reconstruct the same State every use.
You write:
node <- State[BfsState, Node](s => {
val (n, newQ) = s.q.dequeue
(n, s.copy(q = newQ))
})
First off, Scala's syntax was designed so that you don't need to have both a () and {} surrounding an anonymous function:
node <- State[BfsState, Node] { s =>
// ...
}
Second, this doesn't look quite right to me. One benefit of using for-syntax is that the anonymous functions are hidden from you and there is minimal indentation. I'd just write it out
oldState <- get
(node, newQ) = oldState.q.dequeue
newState = oldState.copy(q = newQ)
Footnote: would it make sense to make Node an inner class of Graph? Just a suggestion.

Traverse/fold a nested case class in Scala without boilerplate code

I have some case classes for a mix of sum and product types:
sealed trait Leaf
case class GoodLeaf(value: Int) extends Leaf
case object BadLeaf extends Leaf
case class Middle(left: Leaf, right: Leaf)
case class Container(leaf: Leaf)
case class Top(middle : Middle, container: Container, extraLeaves : List[Leaf])
I want to do some fold-like operations with this Top structure. Examples include:
Count the occurrences of BadLeaf
Sum all the values in the GoodLeafs
Here is some code that does the operations:
object Top {
def fold[T](accu: T)(f : (T, Leaf) => T)(top: Top) = {
val allLeaves = top.container.leaf :: top.middle.left :: top.middle.right :: top.extraLeaves
allLeaves.foldLeft(accu)(f)
}
private def countBadLeaf(count: Int, leaf : Leaf) = leaf match {
case BadLeaf => count + 1
case _ => count
}
def countBad(top: Top): Int = fold(0)(countBadLeaf)(top)
private def sumGoodLeaf(count: Int, leaf : Leaf) = leaf match {
case GoodLeaf(v) => count + v
case _ => count
}
def sumGoodValues(top: Top) = fold(0)(sumGoodLeaf)(top)
}
The real life structure I am dealing with is significantly more complicated than the example I made up. Are there any techniques that could help me avoid writing lots of boilerplate code?
I already have the cats library as a dependency, so a solution that uses that lib would be preferred. I am open to including new dependencies in order to solve this problem.
For my particular example, the definition is not recursive, but I'd be interested in seeing a solution that works for recursive definitions also.
You could just create a function returning all leaves for Top, like you did with allLeaves, so you can just work with a List[Leaf] (with all the existing fold and other functions the Scala library, Cats, etc provide).
For example:
def topLeaves(top: Top): List[Leaf] =
top.container.leaf :: top.middle.left :: top.middle.right :: top.extraLeaves
val isBadLeaf: Leaf => Boolean = {
case BadLeaf => true
case _ => false
}
val leafValue: Leaf => Int = {
case GoodLeaf(v) => v
case _ => 0
}
Which you could use as
import cats.implicits._
// or
// import cats.instances.int._
// import cats.instances.list._
// import cats.syntax.foldable._
val leaves = topLeaves(someTop)
val badCount = leaves.count(isBadLeaf)
val badAndGood = leaves.partition(isBadLeaf) // (List[Leaf], List[Leaf])
val sumLeaves = leaves.foldMap(leafValue)
I am not sure if this helps with your actual use case. In general, with a heterogeneous structure (like your Top) you probably want to convert it to something more homogeneous (like a List[Leaf] or Tree[Leaf]) that you can fold over.
If you have a recursive structure you could look at some talks about recursion schemes (with the Matryoshka library in Scala).

How to match against the pattern of a partial function's case definition in a Scala macro?

As part of a macro, I want to manipulate the case definitions of a partial function.
To do so, I use a Transformer to manipulate the case definitions of the partial function and a Traverser to inspect the patterns of the case definitions:
def myMatchImpl[A: c.WeakTypeTag, B: c.WeakTypeTag](c: Context)
(expr: c.Expr[A])(patterns: c.Expr[PartialFunction[A, B]]): c.Expr[B] = {
import c.universe._
val transformer = new Transformer {
override def transformCaseDefs(trees: List[CaseDef]) = trees map {
case caseDef @ CaseDef(pattern, guard , body) => {
// println(show(pattern))
val traverser = new Traverser {
override def traverse(tree: Tree) = tree match {
// match against a specific pattern
}
}
traverser.traverse(pattern)
}
}
}
val transformedPartialFunction = transformer.transform(patterns.tree)
c.Expr[B](q"$transformedPartialFunction($expr)")
}
Now let us assume, the interesting data I want to match against is represented by the class Data (which is part of the object Example):
case class Data(x: Int, y: String)
When now invoking the macro on the example below
abstract class Foo
case class Bar(data: Data) extends Foo
case class Baz(string: String, data: Data) extends Foo
def test(foo: Foo) = myMatch(foo){
case Bar(Data(x,y)) => y
case Baz(_, Data(x,y)) => y
}
the patterns of the case definitions of the partial function are transformed by the compiler as following (the Foo, Bar, and Baz classes are members of the object Example, too):
(data: Example.Data)Example.Bar((x: Int, y: String)Example.Data((x @ _), (y @ _)))
(string: String, data: Example.Data)Example.Baz(_, (x: Int, y: String)Example.Data((x @ _), (y @ _)))
This is the result of printing the patterns as hinted in the macro above (using show), the raw abstract syntax trees (printed using showRaw) look like this:
Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Bar)), List(Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Data)), List(Bind(newTermName("x"), Ident(nme.WILDCARD)), Bind(newTermName("y"), Ident(nme.WILDCARD))))))
Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Baz)), List(Ident(nme.WILDCARD), Apply(TypeTree().setOriginal(Select(This(newTypeName("Example")), Example.Data)), List(Bind(newTermName("x"), Ident(nme.WILDCARD)), Bind(newTermName("y"), Ident(nme.WILDCARD))))))
How do I write a pattern-quote which matches against these trees?
First of all, there is a special flavor of quasiquotes specifically for CaseDefs called cq:
override def transformCaseDefs(trees: List[CaseDef]) = trees map {
  case caseDef @ cq"$pattern if $guard => $body" => ...
}
Secondly, you should use pq to deconstruct patterns:
pattern match {
  case pq"$name @ $nested" => ...
  case pq"$extractor($arg1, $arg2: _*)" => ...
  ...
}
If you are interested in internals of trees that are used for pattern matching they are created by patvarTransformer defined in TreeBuilder.scala
On the other hand, if you are working with UnApply trees (which are produced after typechecking), I have bad news for you: quasiquotes currently don't support them. Follow SI-7789 to get notified when this is fixed.
After Den Shabalin pointed out that quasiquotes can't be used in this particular setting, I managed to find a pattern which matches against the patterns of a partial function's case definitions.
The key problem is that the constructor we want to match against (in our example Data) is stored in the TypeTree of the Apply node. Matching against a tree wrapped up in a TypeTree is a bit tricky, since the only extractor of this class (TypeTree()) isn't very helpful for this particular task. Instead we have to select the wrapped up tree using the original method:
override def transform(tree: Tree) = tree match {
  case Apply(constructor @ TypeTree(), args) => constructor.original match {
    case Select(_, sym) if (sym == newTermName("Data")) => ...
  }
}
In our use case the wrapped up tree is a Select node and we can now check if the symbol of this node is the one we are looking for.

map over structure with only partial match

I have a tree-like structure of abstract classes and case classes representing the abstract syntax tree of a small language.
For the top abstract class I've implemented a method map:
abstract class AST {
...
def map(f: (AST => AST)): AST = {
val b1 = this match {
case s: STRUCTURAL => s.smap(f) // structural node for example IF(expr,truebranch,falsebranch)
case _ => this // leaf, like ASSIGN(x,2)
}
f(b1)
}
...
The smap is defined like:
override def smap(f: AST => AST) = {
this.copy(trueb = trueb.map(f), falseb = falseb.map(f))
}
Now I'm writing different "transformations" to insert, remove and change nodes in the AST.
For example, to remove adjacent NOP nodes from blocks:
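def handle_list(l: List[AST]): List[AST] = l match { // a recursive method needs an explicit result type
  case (NOP::NOP::tl) => handle_list(tl)
  case h::tl => h::handle_list(tl)
  case Nil => Nil
}
ast.map {
  case BLOCK(listofstatements) => BLOCK(handle_list(listofstatements)) // wrap the list again so the result is an AST
}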
If I write it like this, I end up with a MatchError, which I can "fix" by changing the above map to:
ast.map {
  case BLOCK(listofstatements) => BLOCK(handle_list(listofstatements))
  case a => a
}
Should I just live with all those case a => a, or could I improve my map method (or other parts) in some way?
Make the argument to map a PartialFunction:
def map(f: PartialFunction[AST, AST]): AST = {
  val idAST: PartialFunction[AST, AST] = {case a => a}
  val g = f.orElse(idAST)
  val b1 = this match {
    case s: STRUCTURAL => s.smap(g)
    case _ => this
  }
  g(b1)
}
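To see the idea end to end, here is a minimal self-contained sketch on a toy AST. It is an illustration under assumptions, not the asker's real hierarchy: the node types are reduced to NOP, ASSIGN and BLOCK, smap is given a default that returns the node unchanged (so no STRUCTURAL check is needed), and Rewrites and dropNopPairs are hypothetical names. It uses applyOrElse from the standard PartialFunction API in place of orElse with an identity function.

```scala
sealed trait AST {
  // Rebuild this node with transformed children; leaves have none,
  // so the default simply returns the node itself.
  protected def smap(f: PartialFunction[AST, AST]): AST = this

  // Transform the children first, then apply f to the rebuilt node
  // where f is defined, falling back to the node unchanged otherwise.
  def map(f: PartialFunction[AST, AST]): AST =
    f.applyOrElse(smap(f), (a: AST) => a)
}

case object NOP extends AST
final case class ASSIGN(name: String, value: Int) extends AST
final case class BLOCK(stmts: List[AST]) extends AST {
  override protected def smap(f: PartialFunction[AST, AST]): AST =
    copy(stmts = stmts.map(_.map(f)))
}

object Rewrites {
  // Same shape as handle_list in the question: drops each pair of adjacent NOPs.
  def dropNopPairs(l: List[AST]): List[AST] = l match {
    case NOP :: NOP :: tl => dropNopPairs(tl)
    case h :: tl          => h :: dropNopPairs(tl)
    case Nil              => Nil
  }
}
```

A call such as ast.map { case BLOCK(ss) => BLOCK(Rewrites.dropNopPairs(ss)) } then needs no trailing case a => a, because nodes the partial function does not cover fall through to the default.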
If tree transformations are more than a minor aspect of your project, I highly recommend you use Kiama's Rewriter module to implement them. It implements Stratego-like strategy-driven transformations. It has a very rich set of strategies and strategy combinators that permit a complete separation of traversal logic (which for the vast majority of cases can be taken "off the shelf" from the supplied strategies and combinators) from (local) transformations (which are specific to your AST and you supply, of course).