Creating sum tree of binary tree scala - scala

For a homework assignment I wrote some Scala code with the following classes and object (used for modeling a binary tree):
object Tree {
def fold[B](t: Tree, e: B, n: (Int, B, B) => B): B = t match {
case Node(value, l, r) => n(value,fold(l,e,n),fold(r,e,n))
case _ => e
}
def sumTree(t: Tree): Tree =
fold(t, Nil(), (a, b: Tree, c: Tree) => {
val left = b match {
case Node(value, _, _) => value
case _ => 0
}
val right = c match {
case Node(value, _, _) => value
case _ => 0
}
Node(a+left+right,b,c)
})
}
abstract case class Tree
case class Node(value: Int, left: Tree, right: Tree) extends Tree
case class Nil extends Tree
My question is about the sumTree function, which creates a new tree in which each node's value is the sum of the values of its children plus its own value.
I find it rather ugly and wonder whether there is a better way to do this. With top-down recursion this would be easier, but I could not come up with such a function.
I have to implement the fold function, with the signature shown in the code, and use it to calculate sumTree.
I have the feeling this can be implemented in a better way. Do you have any suggestions?

First of all, if I may say so, I believe you've done a very good job. I can suggest a couple of slight changes to your code:
abstract class Tree
case class Node(value: Int, left: Tree, right: Tree) extends Tree
case object Nil extends Tree
Tree doesn't need to be a case class. Besides, a case class that other case classes extend (a non-leaf class in the hierarchy) is deprecated, because the automatically generated methods can behave erroneously.
Nil is a singleton and is best defined as a case object instead of a case class.
Additionally, consider marking the superclass Tree as sealed. sealed tells the compiler that the class can only be extended from within the same source file. This lets the compiler emit a warning whenever a match expression on a Tree is not exhaustive, in other words, when it doesn't cover all possible cases.
sealed abstract class Tree
The next couple of improvements could be made to sumTree:
def sumTree(t: Tree) = {
// create a helper function to extract Tree value
val nodeValue: Tree=>Int = {
case Node(v,_,_) => v
case _ => 0
}
// parametrise fold with Tree to aid type inference further down the line
fold[Tree](t,Nil,(acc,l,r)=>Node(acc + nodeValue(l) + nodeValue(r) ,l,r))
}
The nodeValue helper can also be defined as a method (the notation I used above works because a sequence of cases in curly braces is treated as a function literal):
def nodeValue (t:Tree) = t match {
case Node(v,_,_) => v
case _ => 0
}
The next small improvement is parametrising the fold call with Tree (fold[Tree]). Because the Scala type inferencer works through the expression sequentially from left to right, telling it early that we're going to deal with Trees lets us omit type annotations on the function literal that is passed to fold later on.
So here is the full code including suggestions:
sealed abstract class Tree
case class Node(value: Int, left: Tree, right: Tree) extends Tree
case object Nil extends Tree
object Tree {
def fold[B](t: Tree, e: B, n: (Int, B, B) => B): B = t match {
case Node(value, l, r) => n(value,fold(l,e,n),fold(r,e,n))
case _ => e
}
def sumTree(t: Tree) = {
val nodeValue: Tree=>Int = {
case Node(v,_,_) => v
case _ => 0
}
fold[Tree](t,Nil,(acc,l,r)=>Node(acc + nodeValue(l) + nodeValue(r) ,l,r))
}
}
The recursion you came up with goes in the only direction that lets you traverse the tree and produce a modified copy of an immutable data structure. Because individual nodes are immutable, every object needed to construct a node has to exist before the node is constructed, so the leaf nodes must be created before the nodes above them, and the root last.

As Vlad writes, your solution has about the only general shape you can have with such a fold.
Still there is a way to get rid of the node value matching, not only factor it out. And personally I would prefer it that way.
You use match because not every result you get from a recursive fold carries a sum with it. Yes, not every Tree can carry it, Nil has no place for a value, but your fold is not limited to Trees, is it?
So let's have:
case class TreePlus[A](value: A, tree: Tree)
Now we can fold it like this:
def sumTree(t: Tree) = fold[TreePlus[Int]](t, TreePlus(0, Nil), (v, l, r) => {
  val sum = v + l.value + r.value
  TreePlus(sum, Node(sum, l.tree, r.tree))
}).tree
Of course the TreePlus is not really needed as we have the canonical product Tuple2 in the standard library.
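To make the Tuple2 remark concrete, here is a minimal sketch of the same fold with a plain pair in place of TreePlus. It assumes the sealed Tree, Node and Nil definitions and the fold signature from the earlier answer. The wrapper object TupleSumTree is only there to keep the snippet self-contained and is not part of the original code.

```scala
object TupleSumTree {
  sealed abstract class Tree
  case class Node(value: Int, left: Tree, right: Tree) extends Tree
  case object Nil extends Tree

  def fold[B](t: Tree, e: B, n: (Int, B, B) => B): B = t match {
    case Node(value, l, r) => n(value, fold(l, e, n), fold(r, e, n))
    case Nil               => e
  }

  // Each fold result is a pair: (value at the root of the rebuilt subtree, rebuilt subtree).
  // The empty tree contributes (0, Nil), so no pattern match on the children is needed.
  def sumTree(t: Tree): Tree =
    fold[(Int, Tree)](t, (0, Nil), (v, l, r) => {
      val sum = v + l._1 + r._1
      (sum, Node(sum, l._2, r._2))
    })._2
}
```

For a tree with root 1, a left leaf 2, and a right node 3 that has a right leaf 4, this should give a root of 10 and a right child of 7, with the leaves unchanged.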

Your solution is probably more efficient (certainly uses less stack), but here's a recursive solution, fwiw
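def sum(tree: Tree): Tree = tree match {
  case Nil => Nil
  case Node(a, b, c) =>
    // sum the subtrees first, then build the new node from their totals
    val left = sum(b)
    val right = sum(c)
    Node(a + total(left) + total(right), left, right)
}
// value stored at the root of a tree, or 0 for the empty tree
def total(tree: Tree): Int = tree match {
  case Nil => 0
  case Node(a, _, _) => a
}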

You've probably turned in your homework already, but I think it's still worth pointing out that the way your code (and the code in other people's answers) looks is a direct result of how you modeled binary trees. If, instead of using an algebraic data type (Tree, Node, Nil), you had gone with a recursive type definition, you wouldn't have had to use pattern matching to decompose your binary trees. Here's my definition of a binary tree:
case class Tree[A](value: A, left: Option[Tree[A]], right: Option[Tree[A]])
As you can see there's no need for Node or Nil here (the latter is just glorified null anyway - you don't want anything like this in your code, do you?).
With such definition, fold is essentially a one-liner:
def fold[A,B](t: Tree[A], z: B)(op: (A, B, B) => B): B =
op(t.value, t.left map (fold(_, z)(op)) getOrElse z, t.right map (fold(_, z)(op)) getOrElse z)
And sumTree is also short and sweet:
def sumTree(tree: Tree[Int]) = fold(tree, None: Option[Tree[Int]]) { (value, left, right) =>
Some(Tree(value + valueOf(left, 0) + valueOf(right, 0), left , right))
}.get
where valueOf helper is defined as:
def valueOf[A](ot: Option[Tree[A]], df: A): A = ot map (_.value) getOrElse df
No pattern matching needed anywhere - all because of a nice recursive definition of binary trees.
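For anyone who wants to try this out, here is a small self-contained sketch that puts the pieces above together. The wrapper object OptionTreeDemo and the leaf helper are hypothetical names added only for the example. Everything else follows the definitions in this answer.

```scala
object OptionTreeDemo {
  case class Tree[A](value: A, left: Option[Tree[A]], right: Option[Tree[A]])

  def fold[A, B](t: Tree[A], z: B)(op: (A, B, B) => B): B =
    op(t.value,
       t.left.map(fold(_, z)(op)).getOrElse(z),
       t.right.map(fold(_, z)(op)).getOrElse(z))

  // Value at the root of an optional subtree, or the default when there is no subtree.
  def valueOf[A](ot: Option[Tree[A]], df: A): A = ot.map(_.value).getOrElse(df)

  def sumTree(tree: Tree[Int]): Tree[Int] =
    fold(tree, None: Option[Tree[Int]]) { (value, left, right) =>
      Some(Tree(value + valueOf(left, 0) + valueOf(right, 0), left, right))
    }.get

  // Hypothetical convenience constructor for a node without children.
  def leaf(v: Int): Tree[Int] = Tree(v, None, None)
}
```

Note that in this encoding a Tree always holds at least one value, so there is no empty tree to pass to sumTree.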

Related

Declaring an empty pattern matching function

How do I declare an empty/case-less pattern matching function to satisfy a type definition? I'm happy with the function just throwing a runtime exception if it's ever called.
I'm working through the Scala tutorial for Java programmers, in which I have a working function that performs variable substitution on a mathematical expression represented as a tree, given a String => Int map. I want to call the same code path even when no variables should exist (in this case after taking a derivative of the expression), but I can't find a concise way to satisfy the type requirements. Here's the full code I have that works but feels wrong:
abstract class Tree
case class Sum(l: Tree, r: Tree) extends Tree
case class Var(n: String) extends Tree
case class Const(v: Int) extends Tree
object CalculatorPatternsPrime {
def eval(tree: Tree, env: String => Int): Int = tree match {
case Sum(l, r) => eval(l, env) + eval(r, env)
case Var(n) => env(n)
case Const(v) => v
}
def eval(tree: Tree): Int = eval(tree, { case "ignore" => -1 })
def derive(tree: Tree, v: String): Tree = tree match {
case Sum(l, r) => Sum(derive(l, v), derive(r, v))
case Var(n) if (n == v) => Const(1)
case _ => Const(0)
}
def main(args: Array[String]): Unit = {
val env: String => Int = { case "x" => 5 case "y" => 7 }
val tree = Sum(
Sum(Const(7), Var("y")),
Sum(Var("x"), Var("x"))
)
println(eval(tree, env))
println(derive(tree, "x"))
println(eval(derive(tree, "x")))
}
}
As you can see I have a dummy { case "ignore" => -1 } to make the type system happy, and the code works fine but I feel like there must be a better way to do this. Here are two alternatives I've considered:
Just writing out a full method body to eval(tree: Tree) instead of trying to call eval(tree: Tree, env: String => Int) but this duplicates the code handling the Sum and Const cases.
Making env an optional/union type and letting it throw an NPE.
What is the idiomatic approach here?
There are a few ways to go about this:
First, as jwvh noted, you can eliminate the single-argument eval by using a default argument for the second parameter in the two-argument eval.
The question then arises, what should that default argument be?
env is a String => Int, which is shorthand for Function1[String, Int]. Function1 is declared as Function1[-T1, +R]: contravariant in the argument type and covariant in the result type. For our purposes, this means that any function which accepts a supertype of String (including String itself) and returns a subtype of Int (including Int itself) will work.
Since you've said you're OK with throwing, this is a reasonable default function:
{ a: Any => throw new AssertionError(s"shouldn't have looked up $a in the environment") }
Assuming that some other component of the system is ensuring that if there's a Var expression in the tree passed to eval, there's always an appropriate entry in the environment, this might be the most honest thing to do: something important in your system isn't maintaining an invariant, so trying to reason about it in your system might just make things worse.
That function works because it has the type Any => Nothing, which is a subtype of String => Int: you can pass it a String (String is a subtype of Any) and it will never have a result that's not an Int or a subtype of an Int (it doesn't have a result).
As an alternative, you could also use PartialFunction.empty as a default, which will throw a MatchError if it's ever called. This fits using a partial function literal (which is what a bare { case... } block is).
So I would define one of the following (given the alias type Env = String => Int):
val emptyEnv: Env = { a: Any => throw new AssertionError(s"shouldn't have looked up $a in the environment") }
or
val emptyEnv: Env = PartialFunction.empty
And then define eval as:
def eval(tree: Tree, env: Env = emptyEnv): Int
As a side note, I would strongly recommend making Tree sealed:
sealed abstract class Tree
This limits where classes extending Tree can be defined and gives you a stronger guarantee that invariants about Trees are enforced.
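Putting the suggestions together, a minimal sketch could look like the following. The Tree classes, eval and derive are taken from the question. The wrapper object EvalDefaultEnv and the sample value are hypothetical and only there to make the snippet self-contained.

```scala
object EvalDefaultEnv {
  sealed abstract class Tree
  case class Sum(l: Tree, r: Tree) extends Tree
  case class Var(n: String) extends Tree
  case class Const(v: Int) extends Tree

  type Env = String => Int

  // An Any => Nothing function conforms to String => Int, so it can serve as the default.
  val emptyEnv: Env =
    (a: Any) => throw new AssertionError(s"shouldn't have looked up $a in the environment")

  // The single-argument overload is gone; callers without variables just omit env.
  def eval(tree: Tree, env: Env = emptyEnv): Int = tree match {
    case Sum(l, r) => eval(l, env) + eval(r, env)
    case Var(n)    => env(n)
    case Const(v)  => v
  }

  def derive(tree: Tree, v: String): Tree = tree match {
    case Sum(l, r)        => Sum(derive(l, v), derive(r, v))
    case Var(n) if n == v => Const(1)
    case _                => Const(0)
  }

  // Same expression as in the question: (7 + y) + (x + x)
  val sample: Tree = Sum(Sum(Const(7), Var("y")), Sum(Var("x"), Var("x")))
}
```

With x = 5 and y = 7, eval(sample, env) should give 24, and eval(derive(sample, "x")) should give 2 without ever touching the environment. Evaluating a tree that still contains a Var with the default environment fails with the AssertionError.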

regular shaped tree fold left scala implementation

I'm trying to implement a tail-recursive foldLeft function for a regular shaped tree. The exercise comes from the book, "The Science of Functional Programming," in exercise 3.3.5.3.
Until now, I was able to do the exercises but I don't know what I'm missing in this one.
Here is the definition of the regular-shaped tree:
sealed trait RTree[A]
final case class Leaf[A](x: A) extends RTree[A]
final case class Branch[A](xs: RTree[(A,A)]) extends RTree[A]
The method signature and expected result:
@tailrec
def foldLeft[A,R](t: RTree[A])(init: R)(f: (R,A)=>R): R= ???
foldLeft(Branch(Branch(Leaf(((1,2),(3,4))))))(0)(_+_)
//10
The biggest problem so far is that I don't know how to match and access the element inside the Branch case class. I can only match Leaf and Branch (and not the leaves inside a branch), and because of that the recursion never ends.
Not sure if this helps, but for now I only have a non-tail-recursive implementation.
def foldLeft[A, R](t: RTree[A])(init: R)(f: (R, A) => R): R = {
t match {
case Leaf(value) => f(init, value)
case Branch(tree) =>
foldLeft(tree)(init) {
case (result, (left, right)) => f(f(result, left), right)
}
}
}
UPDATE: As pointed out in the comments on this answer, this is actually a tail-recursive implementation. Sorry for the confusion.
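To see how the function passed down the recursion grows one pair level per Branch, here is a small self-contained sketch built from the definitions above. It leaves the annotation off and makes no claim about how the compiler treats the recursive call. The wrapper object RTreeFoldDemo and the example value are hypothetical names.

```scala
object RTreeFoldDemo {
  sealed trait RTree[A]
  final case class Leaf[A](x: A) extends RTree[A]
  final case class Branch[A](xs: RTree[(A, A)]) extends RTree[A]

  def foldLeft[A, R](t: RTree[A])(init: R)(f: (R, A) => R): R = t match {
    case Leaf(value) => f(init, value)
    case Branch(tree) =>
      // The subtree holds pairs, so wrap f in a function that consumes a pair left to right.
      foldLeft(tree)(init) { case (result, (left, right)) => f(f(result, left), right) }
  }

  // The tree from the question: two Branch levels around a leaf of nested pairs.
  val example: RTree[Int] = Branch(Branch(Leaf(((1, 2), (3, 4)))))
}
```

Folding example with 0 and _ + _ should give 10, as expected in the question. Folding it with an empty string and string concatenation shows the left-to-right order of the elements.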

A way to avoid asInstanceOf in Scala

I have this hierarchy of traits and classes in Scala:
trait A
trait B[T] extends A {
def v: T
}
case class C(v:Int) extends B[Int]
case class D(v:String) extends B[String]
val l:List[A] = C(1) :: D("a") :: Nil
l.foreach(t => println(t.asInstanceOf[B[_]].v))
I cannot change the type hierarchy or the type of the list.
Is there a better way to avoid the asInstanceOf[B[_]] statement?
You might try pattern matching.
l.collect{case x :B[_] => println(x.v)}
You might try something like this:
for (x <- l.view; y <- Some(x).collect { case b: B[_] => b }) println(y.v)
It doesn't require any isInstanceOf or asInstanceOf, and never crashes, even if your list contains As that aren't B[_]s. It also doesn't create any lengthy lists as intermediate results, only small short-lived Options.
A less concise but also much less surprising solution:
for (x <- l) {
x match {
case b: B[_] => println(b.v)
case _ => /* do nothing */
}
}
If you could change the type of l to List[B[_]], this would be the preferable solution.
I think the most idiomatic way to do it would be to supply B with an extractor object and pattern match for B values:
object B {
def unapply[T](arg: B[T]): Some[T] = Some(arg.v)
}
l.collect{case B(x) => println(x)}
If B is declared in a source file you can't alter, you might need a different name for the extractor object.
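As a rough sketch of that last case, the extractor below lives in a separately named object instead of a companion. The names BValue and ExtractorDemo, and the extra case object E (an A that is not a B), are hypothetical and only there for illustration.

```scala
object ExtractorDemo {
  trait A
  trait B[T] extends A { def v: T }
  case class C(v: Int) extends B[Int]
  case class D(v: String) extends B[String]
  case object E extends A // hypothetical: an A that is not a B

  // Stand-alone extractor, usable when B's own source file cannot be changed.
  object BValue {
    def unapply[T](arg: B[T]): Some[T] = Some(arg.v)
  }

  val l: List[A] = List(C(1), D("a"), E)

  // Elements that are not a B fail the pattern and are dropped by collect.
  val values: List[Any] = l.collect { case BValue(x) => x }
}
```

Here collect gathers the extracted values instead of printing them, which makes the result easy to inspect.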

How can I make my immutable binary search-tree generic in Scala?

I am a newcomer to Scala. I'm trying to develop my own immutable binary search tree.
First, I developed a binary search tree that holds Int values in its nodes. After that, I decided to develop a generic binary search tree.
When I compiled the following code, I got the error message shown below it.
trait GenericBST[+T] {
def add[TT >: T](x: T): GenericBST[TT] = this match {
case Empty => Branch(x, Empty, Empty)
case Branch(d, l, r) =>
if(d > x) Branch(d, l.add(x), r)
else if(d < x) Branch(d, l, r.add(x))
else this
}
}
case class Branch[+T](x: T, left: GenericBST[T], right: GenericBST[T]) extends GenericBST[T]
case object Empty extends GenericBST[Nothing]
error: value < is not a member of type parameter T.
The error makes sense, but how can I fix it?
Keep in mind that I am a newcomer to Scala, so please explain this in detail.
T represents any type, but in order to use > and < you need a type for which ordering makes sense.
In Scala terms, this means you have to restrict the element type to those types for which an Ordering exists. You can use a context bound or, equivalently, require an implicit parameter ord of type Ordering[B], where B is the type parameter of add (TT in your code):
trait GenericBST[+A] {
def add[B >: A](x: B)(implicit ord: Ordering[B]): GenericBST[B] = {
import ord.mkOrderingOps
this match {
case Empty => Branch(x, Empty, Empty)
case Branch(e, l, r) =>
if (e > x) Branch(e, l.add(x), r)
else if (e < x) Branch(e, l, r.add(x))
else this
}
}
}
case class Branch[+A](x: A, left: GenericBST[A], right: GenericBST[A]) extends GenericBST[A]
case object Empty extends GenericBST[Nothing]
Importing ord.mkOrderingOps allows for the syntax
e > x
instead of
ord.gt(e, x)
You could also use a context bound directly, but it would require some extra work to get the implicit ord in scope (and it's arguably less readable):
def add[B >: A : Ordering](x: B): GenericBST[B] = {
val ord = implicitly[Ordering[B]]
import ord.mkOrderingOps
...
}
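To check that the ordering-aware add behaves like a binary search tree, here is a small self-contained sketch based on the first version above. The toList in-order traversal, the sample value and the wrapper object BstDemo are hypothetical additions that are not part of the answer's code.

```scala
object BstDemo {
  trait GenericBST[+A] {
    def add[B >: A](x: B)(implicit ord: Ordering[B]): GenericBST[B] = {
      import ord.mkOrderingOps
      this match {
        case Empty => Branch(x, Empty, Empty)
        case Branch(e, l, r) =>
          if (e > x) Branch(e, l.add(x), r)
          else if (e < x) Branch(e, l, r.add(x))
          else this // equal element: the tree is returned unchanged
      }
    }

    // Hypothetical helper: in-order traversal, which yields the elements in sorted order.
    def toList: List[A] = this match {
      case Empty           => List.empty
      case Branch(x, l, r) => l.toList ::: (x :: r.toList)
    }
  }
  case class Branch[+A](x: A, left: GenericBST[A], right: GenericBST[A]) extends GenericBST[A]
  case object Empty extends GenericBST[Nothing]

  // Insert a few Ints, including a duplicate 8.
  val sample: GenericBST[Int] =
    List(5, 3, 8, 1, 4, 8).foldLeft(Empty: GenericBST[Int])((tree, x) => tree.add(x))
}
```

The in-order traversal of sample should come out as 1, 3, 4, 5, 8, with the duplicate 8 ignored. The same add also works for any other type with an Ordering in scope, such as String.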
Unrelated to your question, but you might be wondering why I used A and B in my example rather than T and TT. According to the official style guide:
For simple type parameters, a single upper-case letter (from the English alphabet) should be used, starting with A (this is different than the Java convention of starting with T)

Generically rewriting Scala case classes

Is it possible to generically replace arguments in a case class? More specifically, say I wanted a substitute function that received a "find" case class and a "replace" case class (like the left and right sides of a grammar rule) as well as a target case class, and the function would return a new case class with arguments of the find case class replaced with the replace case class? The function could also simply take a case class (Product?) and a function to be applied to all arguments/products of the case class.
Obviously, given a specific case class, I could use unapply and apply -- but what's the best/easiest/etc way to generically (given any case class) write this sort of function?
I'm wondering if there is a good solution using Scala 2.10 reflection features or Iso.hlist from shapeless.
For example, what I really want to be able to do is, given classes like the following...
class Op[T]
case class From(x:Op[Int]) extends Op[Int]
case class To(x:Op[Int]) extends Op[Int]
case class Target(a:Op[Int], b:Op[Int]) extends ...
// and lots of other similar case classes
... have a function that can take an arbitrary case class and return a copy of it with any elements of type From replaced with instances of type To.
If you'll pardon the plug, I think you'll find that the rewriting component of our Kiama language processing library is perfect for this kind of purpose. It provides a very powerful form of strategic programming.
Here is a complete solution that rewrites Froms to Tos in a tree made from case class instances.
import org.kiama.rewriting.Rewriter
class Op[T]
case class Leaf (i : Int) extends Op[Int]
case class From (x : Op[Int]) extends Op[Int]
case class To (x : Op[Int]) extends Op[Int]
case class Target1 (a : Op[Int], b : Op[Int]) extends Op[Int]
case class Target2 (c : Op[Int]) extends Op[Int]
object Main extends Rewriter {
def main (args : Array[String]) : Unit = {
val replaceFromsWithTos =
everywhere {
rule {
case From (x) => To (x)
}
}
val t1 = Target1 (From (Leaf (1)), To (Leaf (2)))
val t2 = Target2 (Target1 (From (Leaf (3)), Target2 (From (Leaf (4)))))
println (rewrite (replaceFromsWithTos) (t1))
println (rewrite (replaceFromsWithTos) (t2))
}
}
The output is
Target1(To(Leaf(1)),To(Leaf(2)))
Target2(Target1(To(Leaf(3)),Target2(To(Leaf(4)))))
The idea of the replaceFromsWithTos value is that the rule construct lifts a partial function so that it can operate on any kind of value. In this case the partial function is only defined at From nodes, replacing them with To nodes. The everywhere combinator says "apply my argument to all nodes in the tree, leaving unchanged places where the argument does not apply".
Much more can be done than this kind of simple rewrite. See the main Kiama rewriting documentation for the gory detail, including links to some more examples.
I experimented a bit with shapeless and was able to come up with the following, relatively generic way of converting one case class into another:
import shapeless._ /* shapeless 1.2.3-SNAPSHOT */
case class From(s: String, i: Int)
case class To(s: String, i: Int)
implicit def fromIso = Iso.hlist(From.apply _, From.unapply _)
implicit def toIso = Iso.hlist(To.apply _, To.unapply _)
implicit def convert[A, B, L <: HList]
(a: A)
(implicit srcIso: Iso[A, L],
dstIso: Iso[B, L])
: B =
dstIso.from(srcIso.to(a))
val f1 = From("Hi", 7)
val t1 = convert(f1)(fromIso, toIso)
println("f1 = " + f1) // From("Hi", 7)
println("t1 = " + t1) // To("Hi", 7)
However, I was not able to get the implicits right. Ideally,
val t1: To = f1
would be sufficient, or maybe
val t1 = convert(f1)
Another nice improvement would be to get rid of the need of having to explicitly declare iso-implicits (fromIso, toIso) for each case class.
I don't think you'll really find a better way than just using unapply/apply through pattern matching:
someValue match {
case FindCaseClass(a, b, c) => ReplaceCaseClass(a, b, c)
// . . .
}
You have to write out the rules to associate FindCaseClass with ReplaceCaseClass somehow, and although you might be able to do it a little more succinctly by somehow just using the names, this has the added benefit of also checking the number and types of the case class fields at compile time to make sure everything matches just right.
There is probably some way to do this automatically using the fact that all case classes extend Product, but the fact that productElement(n) returns Any might make it a bit of a pain—I think that's where reflection would have to come in. Here's a little something to get you started:
case class From(i: Int, s: String, xs: Seq[Nothing])
case class To(i: Int, s: String, xs: Seq[Nothing])
val iter = From(5,"x",Nil).productIterator
val f = To.curried
iter.foldLeft(f: Any) { _.asInstanceOf[Any => Any](_) }
// res0: Any = To(5,x,List())
But really, I think you're better off with the pattern-matching version.
Edit: Here is a version with the relevant code refactored into a method:
case class From(i: Int, s: String, xs: Seq[Nothing])
case class To(i: Int, s: String, xs: Seq[Nothing])
type Curryable = { def curried: _ => _ }
def recase(from: Product, to: Curryable) = {
val iter = from.productIterator
val f = to.curried
iter.foldLeft(f: Any) { _.asInstanceOf[Any => Any](_) }
}
recase(From(5,"x",Nil), To)
// res0: Any = To(5,x,List())