I am a bit new to Scala, and I have a program here that does everything I would like it to, except one thing I cannot seem to figure out. I have expressions, and I simplify each expression based on a few simple math rules. Then I prompt the user for bindings to variables, substitute the variables with the integers the user types in, and evaluate the expression. That all works like I want it to. However, I still want to simplify/evaluate the expression when I do not substitute a specific variable with an integer. For example, I simply evaluate the expression without asking the user for a value for x, and I remove the statement that binds x in my Environment. The problem is that I get a scala.MatchError when the eval method calls env on the x in my expression. To make it clearer what I am trying to do, take the expression in my example, (x + (x * (y - (z / 2)))): by substituting only the variables y and z with 2, the result I would like to see is (x + x). I would appreciate any advice/help possible! My program at this point looks like this:
// Mutable slots for the user-supplied bindings; `env` reads them on every lookup.
var n1, n2, n3 = 0

type Environment = String => Int

/** The sample expression x + x * (y - z / 2). */
lazy val exp: Tree = {
  val zHalf = Divide(Var("z"), Const(2))
  Sum(Var("x"), Times(Var("x"), Minus(Var("y"), zHalf)))
}

/** Current binding of a variable name; a name not listed here throws scala.MatchError. */
lazy val env: Environment = name =>
  name match {
    case "x" => n1 // take this line out to not bind x
    case "y" => n2
    case "z" => n3
  }
/** Arithmetic expression tree.
  *
  * Sealed, with final cases, so every `match` over a Tree is checked for exhaustiveness by the
  * compiler. Each node's toString renders it fully parenthesised, e.g. "(x + 2)".
  */
sealed abstract class Tree

/** l + r */
final case class Sum(l: Tree, r: Tree) extends Tree {
  override def toString: String = s"($l + $r)"
}

/** l - r */
final case class Minus(l: Tree, r: Tree) extends Tree {
  override def toString: String = s"($l - $r)"
}

/** l * r */
final case class Times(l: Tree, r: Tree) extends Tree {
  override def toString: String = s"($l * $r)"
}

/** l / r */
final case class Divide(l: Tree, r: Tree) extends Tree {
  override def toString: String = s"($l / $r)"
}

/** A named variable, printed as its bare name. */
final case class Var(n: String) extends Tree {
  override def toString: String = n
}

/** An integer literal. */
final case class Const(v: Int) extends Tree {
  override def toString: String = v.toString
}
/** Applies one round of algebraic identities to `t`.
  *
  * A rule fires only when the node's children already have the matching shape (a literal 0 or 1,
  * or two structurally equal operands). Otherwise both children are simplified and the node is
  * rebuilt, without re-checking the rules on the rebuilt node.
  */
def simplify(t: Tree): Tree = t match {
  case Times(a, b) =>
    (a, b) match {
      case (Const(1), _)                 => simplify(b)
      case (_, Const(1))                 => simplify(a)
      case (Const(0), _) | (_, Const(0)) => Const(0)
      case _                             => Times(simplify(a), simplify(b))
    }
  case Sum(a, b) =>
    (a, b) match {
      case (Const(0), _) => simplify(b)
      case (_, Const(0)) => simplify(a)
      case _             => Sum(simplify(a), simplify(b))
    }
  case Minus(a, b) =>
    if (b == Const(0)) simplify(a)
    else if (a == b) Const(0)
    else Minus(simplify(a), simplify(b))
  case Divide(a, b) =>
    if (a == Const(0)) Const(0)
    else if (b == Const(1)) simplify(a)
    else if (a == b) Const(1)
    else Divide(simplify(a), simplify(b))
  case other => other
}
/** Evaluates `t` to an Int, resolving variables through `env`.
  *
  * A variable that `env` does not bind surfaces as the scala.MatchError raised by `env`.
  * Integer division truncates, and division by zero throws ArithmeticException.
  */
def eval(t: Tree): Int = {
  // Evaluates the left operand, then the right one, then combines them.
  def combine(l: Tree, r: Tree)(op: (Int, Int) => Int): Int = op(eval(l), eval(r))

  t match {
    case Const(v)     => v
    case Var(n)       => env(n)
    case Sum(l, r)    => combine(l, r)(_ + _)
    case Minus(l, r)  => combine(l, r)(_ - _)
    case Times(l, r)  => combine(l, r)(_ * _)
    case Divide(l, r) => combine(l, r)(_ / _)
  }
}
// Show the expression before and after simplification.
println("Expression: " + exp)
println("Simplified: " + simplify(exp))
// Read the bindings. Predef.readInt() is deprecated since Scala 2.11 in favour of
// scala.io.StdIn.readInt() (and is gone from Predef in 2.13).
println("Enter the binding for x.") //take this line out to not bind x
n1 = scala.io.StdIn.readInt() //take this line out to not bind x
println("Enter the binding for y.")
n2 = scala.io.StdIn.readInt()
println("Enter the binding for z.")
n3 = scala.io.StdIn.readInt()
println("Evaluation: " + eval(exp))
Here is a relatively clean solution. I've only included the pieces I modified.
type Environment = String => Option[Int]

/** Binding lookup for the partial evaluator.
  *
  * Returns Some(current value) for "y" and "z", and None for any other name, so those names stay
  * symbolic. The values are read from n2/n3 at call time.
  */
lazy val env: Environment = variable =>
  if (variable == "y") Some(n2)
  else if (variable == "z") Some(n3)
  else None
/** Simplifies `t` bottom-up, substituting every variable that `env` binds.
  *
  * Children are simplified first, so the identities and constant folding below see
  * already-reduced operands. Unbound variables stay as `Var`, which allows partial evaluation:
  * for example (x + (x * (y - (z / 2)))) with y = z = 2 becomes (x + x).
  */
def simplify(t: Tree): Tree = {
  // Pass 1: recurse into the children and substitute bound variables.
  val reducedTree = t match {
    case Times(l, r) => Times(simplify(l), simplify(r))
    case Sum(l, r) => Sum(simplify(l), simplify(r))
    case Minus(l, r) => Minus(simplify(l), simplify(r))
    case Divide(l, r) => Divide(simplify(l), simplify(r))
    case Var(n) => env(n).map(Const(_)).getOrElse(t)
    case _ => t
  }
  // Pass 2: apply identities and fold constants on the rebuilt node.
  reducedTree match {
    case Times(Const(1), r) => r
    case Times(l, Const(1)) => l
    case Times(Const(0), r) => Const(0)
    case Times(l, Const(0)) => Const(0)
    case Times(Const(l), Const(r)) => Const(l * r)
    case Sum(Const(0), r) => r
    case Sum(l, Const(0)) => l
    case Sum(Const(l), Const(r)) => Const(l + r)
    case Minus(l, Const(0)) => l
    case Minus(l, r) if l == r => Const(0)
    case Minus(Const(l), Const(r)) => Const(l - r)
    case Divide(Const(0), r) => Const(0)
    case Divide(l, Const(1)) => l
    case Divide(l, r) if l == r => Const(1)
    // Folding n / 0 would throw ArithmeticException during simplification.
    // Leave the node unreduced instead, so simplify never crashes on it.
    case Divide(Const(l), Const(r)) if r != 0 => Const(l / r)
    case _ => reducedTree
  }
}
I changed your code a little bit to make it work as you want:
/** Arithmetic expression tree.
  *
  * Sealed, with final cases, so every `match` over a Tree is checked for exhaustiveness by the
  * compiler. Each node's toString renders it fully parenthesised, e.g. "(x + 2)".
  */
sealed abstract class Tree

/** l + r */
final case class Sum(l: Tree, r: Tree) extends Tree {
  override def toString: String = s"($l + $r)"
}

/** l - r */
final case class Minus(l: Tree, r: Tree) extends Tree {
  override def toString: String = s"($l - $r)"
}

/** l * r */
final case class Times(l: Tree, r: Tree) extends Tree {
  override def toString: String = s"($l * $r)"
}

/** l / r */
final case class Divide(l: Tree, r: Tree) extends Tree {
  override def toString: String = s"($l / $r)"
}

/** A named variable, printed as its bare name. */
final case class Var(n: String) extends Tree {
  override def toString: String = n
}

/** An integer literal. */
final case class Const(v: Int) extends Tree {
  override def toString: String = v.toString
}
/** Demo of partial evaluation.
  *
  * Variables bound in `env` are substituted and folded; unbound ones stay symbolic.
  */
object ExprEval {
  lazy val exp: Tree = Sum(Var("x"), Times(Var("x"), Minus(Var("y"), Divide(Var("z"), Const(2)))))

  /** Variable bindings read by `simplify` and `eval`; reassigned by `main` between demos. */
  var env = Map[String, Int]()

  /** Simplifies `t`, replacing variables bound in `env` by constants and folding what it can.
    *
    * @param recursive when true and no rule matches `t` as-is, the children are simplified
    *                  first and the rebuilt node gets one more pass with `recursive = false`,
    *                  so the rules can fire on the simplified children without looping
    * @return the simplified tree
    */
  def simplify(t: Tree, recursive : Boolean = true): Tree = {
    t match {
      case Sum(Const(0), r) => simplify(r)
      case Sum(l, Const(0)) => simplify(l)
      case Sum(Const(l), Const(r)) => Const(l + r)
      // Rewrite x + x as 2 * x, then simplify that too. Otherwise a bound x would be returned
      // un-substituted inside Times(2, x).
      case Sum(Var(l), Var(r)) if l == r => simplify(Times(Const(2), Var(l)))
      case Sum(l, r) if recursive => simplify(Sum(simplify(l), simplify(r)), recursive = false)
      case Minus(l, Const(0)) => simplify(l)
      case Minus(l, r) if l == r => Const(0)
      case Minus(Const(l), Const(r)) => Const(l - r)
      case Minus(l, r) if recursive => simplify(Minus(simplify(l), simplify(r)), recursive = false)
      case Times(Const(1), r) => simplify(r)
      case Times(l, Const(1)) => simplify(l)
      case Times(Const(0), r) => Const(0)
      case Times(l, Const(0)) => Const(0)
      case Times(Const(l), Const(r)) => Const(l * r)
      case Times(l, r) if recursive => simplify(Times(simplify(l), simplify(r)), recursive = false)
      case Divide(Const(0), r) => Const(0)
      case Divide(l, Const(1)) => simplify(l)
      case Divide(l, r) if l == r => Const(1)
      // Fold constant division only when it is defined. n / 0 is left in place instead of
      // throwing ArithmeticException from simplify.
      case Divide(Const(l), Const(r)) if r != 0 => Const(l / r)
      case Divide(l, r) if recursive => simplify(Divide(simplify(l), simplify(r)), recursive = false)
      case Var(n) => env.get(n).map(Const).getOrElse(Var(n))
      case _ => t
    }
  }

  /** Fully evaluates `t` using `env`.
    *
    * Throws NoSuchElementException for a variable missing from `env` (Map.apply), and
    * ArithmeticException on integer division by zero.
    */
  def eval(t: Tree): Int = t match {
    case Sum(l, r) => eval(l) + eval(r)
    case Minus(l, r) => eval(l) - eval(r)
    case Times(l, r) => eval(l) * eval(r)
    case Divide(l, r) => eval(l) / eval(r)
    case Var(n) => env(n)
    case Const(v) => v
  }

  // Explicit `: Unit =` replaces the deprecated procedure syntax.
  def main(args: Array[String]): Unit = {
    env = Map()
    println(s"env: $env, exp : $exp, simplified : ${simplify(exp)}")
    env = Map("y" -> 2, "z" -> 2)
    println(s"env: $env, exp : $exp, simplified : ${simplify(exp)}")
    env = Map("z" -> 4)
    println(s"env: $env, exp : $exp, simplified : ${simplify(exp)}")
    env = Map("x" -> 3, "y" -> 2, "z" -> 2)
    println(s"env: $env, exp : $exp, simplified : ${simplify(exp)}")
  }
}
If no variables are bound, it returns a simplified original expression tree. If all variables are bound, it produces the same result as eval().
Output:
env: Map(), exp : (x + (x * (y - (z / 2)))), simplified : (x + (x * (y - (z / 2))))
env: Map(y -> 2, z -> 2), exp : (x + (x * (y - (z / 2)))), simplified : (2 * x)
env: Map(z -> 4), exp : (x + (x * (y - (z / 2)))), simplified : (x + (x * (y - 2)))
env: Map(x -> 3, y -> 2, z -> 2), exp : (x + (x * (y - (z / 2)))), simplified : 6
Related
I wrote a method to calculate the maximum depth of a binary tree.
I would like to write a tail recursive method.
I thought of using lists, but I didn't find solutions
This is my method that is not tail recursive:
def depth: Int = {
  // Plain structural recursion: a Var is depth 0 and every connective adds one level
  // on top of its deepest operand. Not tail-recursive.
  def iter(f: FormulaWff): Int = {
    // Depth of a binary connective over `a` and `b`.
    def deeper(a: FormulaWff, b: FormulaWff): Int = 1 + Math.max(iter(a), iter(b))

    f match {
      case Var(_)          => 0
      case Not(e1)         => 1 + iter(e1)
      case And(e1, e2)     => deeper(e1, e2)
      case Or(e1, e2)      => deeper(e1, e2)
      case Implies(e1, e2) => deeper(e1, e2)
    }
  }
  iter(this)
}
Try
import scala.util.control.TailCalls.{TailRec, done, tailcall}
/** Propositional formula AST.
  *
  * Sealed, with final cases, so the match in `depth` is exhaustiveness-checked.
  */
sealed trait FormulaWff {

  /** Depth of this formula: a Var is 0 and each connective adds 1 over its deepest operand.
    *
    * The recursion runs through scala.util.control.TailCalls trampolines (`tailcall` + `.result`)
    * instead of the JVM call stack.
    */
  def depth: Int = {
    def iter(f: FormulaWff): TailRec[Int] = {
      // Depth of a binary connective: one more than its deeper operand.
      def hlp(e1: FormulaWff, e2: FormulaWff): TailRec[Int] = for {
        x <- tailcall(iter(e1))
        y <- tailcall(iter(e2))
      } yield 1 + Math.max(x, y)

      f match {
        case Var(_) => done(0)
        case Not(e1) => for {
          x <- tailcall(iter(e1))
        } yield 1 + x
        case And(e1, e2) => hlp(e1, e2)
        case Or(e1, e2) => hlp(e1, e2)
        case Implies(e1, e2) => hlp(e1, e2)
      }
    }
    iter(this).result
  }
}
final case class Var(s: String) extends FormulaWff
final case class Not(e: FormulaWff) extends FormulaWff
final case class And(e1: FormulaWff, e2: FormulaWff) extends FormulaWff
final case class Or(e1: FormulaWff, e2: FormulaWff) extends FormulaWff
final case class Implies(e1: FormulaWff, e2: FormulaWff) extends FormulaWff
Direct solution
/** Propositional formula AST. */
sealed trait FormulaWff
final case class Var(s: String) extends FormulaWff
final case class Not(e: FormulaWff) extends FormulaWff
final case class And(e1: FormulaWff, e2: FormulaWff) extends FormulaWff
final case class Or(e1: FormulaWff, e2: FormulaWff) extends FormulaWff
final case class Implies(e1: FormulaWff, e2: FormulaWff) extends FormulaWff

/** Evaluation state of a frame: not visited yet, children pushed, or depth known. */
sealed trait Result
case object NotExpanded extends Result
case object Expanded extends Result
final case class Calculated(value: Int) extends Result

/** One entry of the explicit work list: a sub-formula and how far its evaluation has got. */
final case class Frame(arg: FormulaWff, res: Result)

/** Depth of `f` (a Var is 0, each connective adds 1 over its deepest operand).
  *
  * Computed without JVM recursion, by repeatedly rewriting an explicit list of frames with the
  * @tailrec helpers below.
  */
def depth(f: FormulaWff): Int = step1(List(Frame(f, NotExpanded))) match {
  // step1 only returns once the list has collapsed to the single calculated root frame.
  case Frame(_, Calculated(res)) :: Nil => res
}

/** Runs `step` passes until the list is a single calculated frame. */
@scala.annotation.tailrec
def step1(stack: List[Frame]): List[Frame] = {
  val x = step(stack, Nil)
  x match {
    case Frame(_, Calculated(_)) :: Nil => x
    case _ => step1(x)
  }
}

/** One left-to-right pass over `stack`; output is built reversed in `acc` and flipped at the end.
  *
  * The rewrites, in order of precedence:
  *  - Expanding a node emits the node (marked Expanded) followed by its children. After the final
  *    reverse, the children therefore sit directly before their parent.
  *  - Calculated children immediately followed by their Expanded parent collapse into the
  *    parent's Calculated depth.
  *  - A Var is always depth 0.
  *  - Every other frame is copied through unchanged.
  */
@scala.annotation.tailrec
def step(stack: List[Frame], acc: List[Frame]): List[Frame] = {
  stack match {
    case Frame(_, Calculated(res1)) :: Frame(_, Calculated(res2)) :: Frame(And(e1, e2), Expanded) :: frames =>
      step(frames, Frame(And(e1, e2), Calculated(1 + math.max(res1, res2))) :: acc)
    case Frame(_, Calculated(res1)) :: Frame(_, Calculated(res2)) :: Frame(Or(e1, e2), Expanded) :: frames =>
      step(frames, Frame(Or(e1, e2), Calculated(1 + math.max(res1, res2))) :: acc)
    case Frame(_, Calculated(res1)) :: Frame(_, Calculated(res2)) :: Frame(Implies(e1, e2), Expanded) :: frames =>
      step(frames, Frame(Implies(e1, e2), Calculated(1 + math.max(res1, res2))) :: acc)
    case Frame(_, Calculated(res1)) :: Frame(Not(e1), Expanded) :: frames =>
      step(frames, Frame(Not(e1), Calculated(1 + res1)) :: acc)
    case Frame(Var(s), _) :: frames =>
      step(frames, Frame(Var(s), Calculated(0)) :: acc)
    case Frame(Not(e1), NotExpanded) :: frames =>
      step(frames, Frame(Not(e1), Expanded) :: Frame(e1, NotExpanded) :: acc)
    case Frame(And(e1, e2), NotExpanded) :: frames =>
      step(frames, Frame(And(e1, e2), Expanded) :: Frame(e1, NotExpanded) :: Frame(e2, NotExpanded) :: acc)
    case Frame(Or(e1, e2), NotExpanded) :: frames =>
      step(frames, Frame(Or(e1, e2), Expanded) :: Frame(e1, NotExpanded) :: Frame(e2, NotExpanded) :: acc)
    case Frame(Implies(e1, e2), NotExpanded) :: frames =>
      step(frames, Frame(Implies(e1, e2), Expanded) :: Frame(e1, NotExpanded) :: Frame(e2, NotExpanded) :: acc)
    case Frame(arg, Expanded) :: frames => step(frames, Frame(arg, Expanded) :: acc)
    case Frame(arg, Calculated(res)) :: frames => step(frames, Frame(arg, Calculated(res)) :: acc)
    case Nil => acc.reverse
  }
}
How to make tree mapping tail-recursive?
I'm trying to create a Gen for a BST with ScalaCheck, but when I call the .sample method it gives me a java.lang.NullPointerException. Where did I go wrong?
/** Binary tree of Ints. Sealed trait with final cases, so matches on it are exhaustiveness-checked. */
sealed trait Tree
final case class Node(left: Tree, right: Tree, v: Int) extends Tree
case object Leaf extends Tree
import org.scalacheck._
import Gen._
import Arbitrary.arbitrary
// Question code: ScalaCheck generators for a BST-shaped Tree (reported to throw on .sample).
case class GenerateBST() {
// Value for a parent of `left`/`right`: drawn from arbitrary[Int] and filtered with suchThat so it
// lies above the left root and below the right root, when those roots exist.
def genValue(left: Tree, right: Tree): Gen[Int] = (left, right) match {
case (Node(_, _, min), Node(_, _, max)) => arbitrary[Int].suchThat(x => x > min && x < max)
case (Node(_, _, min), Leaf) => arbitrary[Int].suchThat(x => x > min)
case (Leaf, Node(_, _, max)) => arbitrary[Int].suchThat(x => x < max)
case (Leaf, Leaf) => arbitrary[Int]
}
// NOTE(review): genNode is a val, so this for-comprehension runs during construction and calls
// genTree, which reads genNode while it is still unassigned (null). The generator captured here
// therefore presumably contains a null Gen, which is the likely source of the NullPointerException
// on .sample. Wrapping genTree in Gen.lzy defers that read.
val genNode = for {
left <- genTree
right <- genTree
v <- genValue(left, right)
} yield Node(left, right, v)
// Either a Leaf or a Node, chosen by oneOf; mutually recursive with genNode.
def genTree: Gen[Tree] = oneOf(const(Leaf), genNode)
}
// Per the question above, this call fails with java.lang.NullPointerException.
new GenerateBST().genTree.sample
Because of the way you're defining the generator recursively for a recursive data type, you need to use Gen.lzy at the top:
// Gen.lzy defers building the generator until it is used, so genNode is read after initialization.
def genTree: Gen[Tree] = Gen.lzy(Gen.oneOf(Gen.const(Leaf), genNode))
As an unrelated side note, using suchThat in your generator definitions should generally only be a last resort. It means that sample will often fail (about a third of the time with the fixed version of your code), and more importantly that if you someday want to create arbitrary functions resulting in Tree, you're going to see a lot of horrible org.scalacheck.Gen$RetrievalError: couldn't generate value exceptions.
In this case you can avoid suchThat pretty easily by using Gen.chooseNum and swapping the left and right sides if they're in the wrong order:
/** Binary tree of Ints. Sealed trait with final cases, so matches on it are exhaustiveness-checked. */
sealed trait Tree
final case class Node(left: Tree, right: Tree, v: Int) extends Tree
case object Leaf extends Tree
import org.scalacheck.{ Arbitrary, Gen }
/** ScalaCheck generators for Trees that avoid `suchThat`, so generation never has to discard values. */
object GenerateBST {

  /** Orders two candidate children by their root values, so a parent value fits strictly between them.
    *
    * When both are nodes whose roots differ by at most 1, the right root is bumped by one and
    * rechecked until there is a gap.
    *
    * NOTE(review): only the two roots are compared or adjusted, and the subtrees below them are
    * untouched. A bumped or swapped child may therefore contain values on the wrong side of the
    * parent, so the result is not guaranteed to be a BST at every level. Confirm whether that matters.
    */
  @scala.annotation.tailrec
  def swapIfNeeded(l: Tree, r: Tree): (Tree, Tree) = (l, r) match {
    // If the two trees don't have space between them, we bump one and recheck.
    // (`n @ Node(...)` binds the whole right node so it can be copied.)
    case (Node(_, _, x), n @ Node(_, _, y)) if math.abs(x - y) <= 1 =>
      swapIfNeeded(l, n.copy(v = y + 1))
    // If the values are in the wrong order, swap:
    case (Node(_, _, x), Node(_, _, y)) if x > y => (r, l)
    // Otherwise do nothing:
    case (_, _) => (l, r)
  }

  /** Value for a parent of `left`/`right`: strictly above left's root and below right's root, when present. */
  def genValue(left: Tree, right: Tree): Gen[Int] = (left, right) match {
    case (Node(_, _, min), Node(_, _, max)) => Gen.chooseNum(min + 1, max - 1)
    case (Node(_, _, min), Leaf) => Gen.chooseNum(min + 1, Int.MaxValue)
    case (Leaf, Node(_, _, max)) => Gen.chooseNum(Int.MinValue, max - 1)
    case (Leaf, Leaf) => Arbitrary.arbitrary[Int]
  }

  /** A Node built from two generated subtrees, put in order, and a value that fits between their roots. */
  val genNode: Gen[Node] = for {
    l0 <- genTree
    r0 <- genTree
    (left, right) = swapIfNeeded(l0, r0)
    v <- genValue(left, right)
  } yield Node(left, right, v)

  /** Leaf or Node. Gen.lzy defers reading genNode, which breaks the initialization cycle between the two. */
  def genTree: Gen[Tree] = Gen.lzy(Gen.oneOf(Gen.const(Leaf), genNode))
}
Now you can use Arbitrary[Whatever => Tree] without worrying about generator failures:
scala> implicit val arbTree: Arbitrary[Tree] = Arbitrary(GenerateBST.genTree)
arbTree: org.scalacheck.Arbitrary[Tree] = org.scalacheck.ArbitraryLowPriority$$anon$1@606abb0e
scala> val f = Arbitrary.arbitrary[Int => Tree].sample.get
f: Int => Tree = org.scalacheck.GenArities$$Lambda$7109/289518656@13eefeaf
scala> f(1)
res0: Tree = Leaf
scala> f(2)
res1: Tree = Node(Leaf,Leaf,-20313200)
scala> f(3)
res2: Tree = Leaf
scala> f(4)
res3: Tree = Node(Node(Leaf,Leaf,-850041807),Leaf,-1)
How can I emulate the following behavior in Scala, i.e. keep folding while certain conditions on the accumulator are met?
/** Requested API (declaration only): fold left from `z` with `op`, but stop as soon as the next
  * accumulator would fail `p`, returning the last accumulator that satisfied it.
  * For example, Seq(1, 2, 3, 4) with `_ < 3` and `+` gives 1.
  * `A` is presumably the element type of the collection this is added to.
  */
def foldLeftWhile[B](z: B, p: B => Boolean)(op: (B, A) => B): B
For example
scala> val seq = Seq(1, 2, 3, 4)
seq: Seq[Int] = List(1, 2, 3, 4)
scala> seq.foldLeftWhile(0, _ < 3) { (acc, e) => acc + e }
res0: Int = 1
scala> seq.foldLeftWhile(0, _ < 7) { (acc, e) => acc + e }
res1: Int = 6
UPDATES:
Based on @Dima's answer, I realized that my intention was a little bit side-effectful. So I made it consistent with takeWhile, i.e. there is no advancement if the predicate does not match, and I added some more examples to make it clearer. (Note: that will not work with Iterators.)
First, note that your example seems wrong. If I understand correctly what you describe, the result should be 1 (the last value on which the predicate _ < 3 was satisfied), not 6.
The simplest way to do this is using a return statement, which is very frowned upon in Scala, but I thought I'd mention it for the sake of completeness.
/** Folds `seq` from `z` with `op`, stopping before the first accumulator that fails `p`.
  *
  * Returns the last accumulator that passed, or `z` if the very first step already fails or
  * `seq` is empty.
  */
def foldLeftWhile[A, B](seq: Seq[A], z: B, p: B => Boolean)(op: (B, A) => B): B = seq.foldLeft(z) { case (b, a) =>
  val result = op(b, a)
  // Non-local return out of foldLeftWhile: abandons the rest of the fold.
  // It is implemented with a control exception and is deprecated in Scala 3.
  if (!p(result)) return b
  result
}
Since we want to avoid using return, scanLeft might be a possibility:
// lastOption: if even z fails p, takeWhile yields nothing, so fall back to z instead of throwing from .last.
seq.toStream.scanLeft(z)(op).takeWhile(p).lastOption.getOrElse(z)
This is a little wasteful, because it accumulates all (matching) results.
You could use iterator instead of toStream to avoid that, but Iterator does not have .last for some reason, so you'd have to scan through it an extra time explicitly:
// With no .last on Iterator, carry the most recent passing value through a fold (z if none pass).
seq.iterator.scanLeft(z)(op).takeWhile(p).foldLeft(z)((_, latest) => latest)
It is pretty straightforward to define what you want in scala. You can define an implicit class which will add your function to any TraversableOnce (that includes Seq).
implicit class FoldLeftWhile[A](trav: TraversableOnce[A]) {
  /** Left fold that applies `op` only while the running accumulator satisfies `where`.
    *
    * From then on the accumulator is passed along untouched. The whole input is still traversed.
    */
  def foldLeftWhile[B](init: B)(where: B => Boolean)(op: (B, A) => B): B =
    trav.foldLeft(init) { (acc, next) =>
      if (!where(acc)) acc
      else op(acc, next)
    }
}
Seq(1, 2, 3, 4).foldLeftWhile(0)(acc => acc < 3)(_ + _)
Update, since the question was modified:
implicit class FoldLeftWhile[A](trav: TraversableOnce[A]) {
  /** Left fold that stops at the first step whose result fails `where`.
    *
    * Returns the last accumulator that satisfied it (or `init` if the first step already fails).
    * Later elements are still visited, but `op` is not applied to them.
    */
  def foldLeftWhile[B](init: B)(where: B => Boolean)(op: (B, A) => B): B = {
    // State: (last accepted accumulator, whether we have stopped).
    trav.foldLeft((init, false)) { (state, next) =>
      val (acc, stopped) = state
      if (stopped) state
      else {
        // Reuse r rather than calling op a second time: a second call is wasted work and
        // wrong for side-effecting ops.
        val r = op(acc, next)
        if (where(r)) (r, false) else (acc, true)
      }
    }._1
  }
}
Note that I split your (z: B, p: B => Boolean) into two separate parameter lists. That's just a personal Scala style preference.
What about this:
/** Folds `xs` from `z` with `op` while the new accumulator satisfies `p`.
  *
  * Returns the last accumulator that satisfied `p` (or `z` if none did or `xs` is empty).
  * Stops consuming elements at the first failure.
  */
def foldLeftWhile[A, B](z: B, xs: Seq[A], p: B => Boolean)(op: (B, A) => B): B = {
  @scala.annotation.tailrec
  def go(acc: B, l: Seq[A]): B = l match {
    case h +: t =>
      val nacc = op(acc, h)
      // Continue from nacc itself (applying op again would double-count h).
      // On failure return the previous, still-valid accumulator.
      if (p(nacc)) go(nacc, t) else acc
    case _ => acc
  }
  go(z, xs)
}
// Sum 1..6 for as long as the running total stays <= 3, then print it.
val a = Seq(1, 2, 3, 4, 5, 6)
val r = foldLeftWhile(0, a, (x: Int) => x <= 3)((acc, e) => acc + e)
println(r)
Iterate recursively on the collection while the predicate is true, and then return the accumulator.
You can try it on ScalaFiddle.
After a while I received a lot of good looking answers. So, I combined them to this single post
A very concise solution by @Dima:
implicit class FoldLeftWhile[A](seq: Seq[A]) {
  /** Last running fold value (starting with `z`) that satisfies `p`; `z` if not even `z` does. */
  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    // z, op(z, a0), op(op(z, a0), a1), ... computed lazily by the Stream.
    val partials = seq.toStream.scanLeft(z)(op)
    partials.takeWhile(p).lastOption.fold(z)(identity)
  }
}
By @ElBaulP (I modified it a little bit to match @Dima's comment):
implicit class FoldLeftWhile[A](seq: Seq[A]) {
  /** Folds from `z` while each new accumulator satisfies `p`.
    *
    * Returns the last one that did (or `z`), and stops at the first element whose result fails.
    */
  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    @scala.annotation.tailrec
    def foldLeftInternal(acc: B, rest: Seq[A]): B = rest match {
      // `+:` deconstructs any Seq; `::` would only match Lists and silently return z for e.g. Vector.
      case x +: tail =>
        val newAcc = op(acc, x)
        if (p(newAcc))
          foldLeftInternal(newAcc, tail)
        else
          acc
      case _ => acc
    }
    foldLeftInternal(z, seq)
  }
}
Answer by me (involving side effects)
implicit class FoldLeftWhile[A](seq: Seq[A]) {
  /** Folds from `z` while each new accumulator satisfies `p`.
    *
    * Returns the last one that did (or `z`). Uses a mutable accumulator updated inside `map`.
    * Because the pipeline runs on an Iterator, `op` is applied lazily and stops at the first
    * failing result. A strict `seq.map` would instead run `op` over every element.
    */
  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    var accumulator = z
    seq.iterator
      .map { e =>
        accumulator = op(accumulator, e)
        accumulator
      }
      .takeWhile(p)
      // Keep the latest value that passed; z when none did.
      .foldLeft(z)((_, acc) => acc)
  }
}
First example: predicate on each element
First, you can use an inner tail-recursive function:
implicit class TravExt[A](seq: TraversableOnce[A]) {
  /** Folds elements into `z` with `op` for as long as each element satisfies `f`.
    *
    * Stops at (and excludes) the first element that fails `f`.
    */
  def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = {
    @scala.annotation.tailrec
    def rec(rest: List[A], acc: B): B = rest match {
      case head :: tail if f(head) => rec(tail, op(head, acc))
      case _ => acc
    }
    // Convert to a List first: matching `::` directly on a TraversableOnce only works for Lists
    // and silently returned z for anything else (Vector, Set, ...). For a List this is a no-op.
    rec(seq.toList, z)
  }
}
Or short version
implicit class TravExt[A](seq: TraversableOnce[A]) {
  /** Folds elements into `z` with `op` for as long as each element satisfies `f`.
    *
    * Stops at (and excludes) the first element that fails `f`. Tail recursion goes through the
    * implicit conversion on `tail`, which @tailrec accepts because the method is final.
    */
  @scala.annotation.tailrec
  final def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = seq.toList match {
    // toList so non-List inputs (e.g. Vector) are folded too, instead of falling through to z.
    case head :: tail if f(head) => tail.foldLeftWhile(op(head, z), f)(op)
    case _ => z
  }
}
Then use it
val a = List(1, 2, 3, 4, 5, 6).foldLeftWhile(0, x => x < 3)((x, acc) => x + acc)
//a == 3
Second example: predicate on the accumulator value:
implicit class TravExt[A](seq: TraversableOnce[A]) {
  /** Folds elements into `z` with `op`, checking `f` on the accumulator before each step.
    *
    * Once the accumulator fails `f` it is returned as-is.
    * `f` tests the accumulator, so its type is `B => Boolean` (the earlier `A => Boolean`
    * could not be applied to `z: B` and did not compile).
    */
  def foldLeftWhile[B](z: B, f: B => Boolean)(op: (A, B) => B): B = {
    @scala.annotation.tailrec
    def rec(rest: List[A], acc: B): B = rest match {
      case _ if !f(acc) => acc
      case head :: tail => rec(tail, op(head, acc))
      case _ => acc
    }
    // toList so non-List inputs are folded too; a no-op for Lists.
    rec(seq.toList, z)
  }
}
Or short version
implicit class TravExt[A](seq: TraversableOnce[A]) {
  /** Folds elements into `z` with `op`, checking `f` on the accumulator before each step.
    *
    * Once the accumulator fails `f` it is returned as-is. `f` takes the accumulator
    * (`B => Boolean`); the earlier `A => Boolean` did not type-check against `z: B`.
    */
  @scala.annotation.tailrec
  final def foldLeftWhile[B](z: B, f: B => Boolean)(op: (A, B) => B): B = seq.toList match {
    case _ if !f(z) => z
    // toList so non-List inputs are folded too; recursion continues on the List tail.
    case head :: tail => tail.foldLeftWhile(op(head, z), f)(op)
    case _ => z
  }
}
Simply use a branch condition on the accumulator:
// foldLeft takes only the start value. The "while" part lives in the branch: once acc reaches 3 it is carried unchanged.
seq.foldLeft(0) { (acc, e) => if (acc < 3) acc + e else acc }
However, this will still run through every element of the sequence.
I am learning Functional Programming in Scala, and quite often I need to trace a function evaluation in order to better understand how it works.
For example, having the following function:
/** Right fold over the Cons/Nil list: combines each element with the fold of everything after it. */
def foldRight[A,B](l: List[A], z: B)(f: (A, B) => B): B =
  l match {
    case Cons(head, rest) => f(head, foldRight(rest, z)(f))
    case Nil              => z
  }
For the following call:
foldRight(Cons(1, Cons(2, Cons(3, Nil))), 0)((x, acc) => x + acc)
I would like to get printed its evaluation trace, like this:
foldRight(Cons(1, Cons(2, Cons(3, Nil))), 0)(_ + _)
1 + foldRight(Cons(2, Cons(3, Nil)), 0)(_ + _)
1 + (2 + foldRight(Cons(3, Nil), 0)(_ + _))
1 + (2 + (3 + (foldRight(Nil, 0)(_ + _))))
1 + (2 + (3 + (0)))
6
Currently I am doing it either manually or by injecting ugly prints. How can I achieve that in a convenient, elegant way?
I assumed that Cons and :: are the same operations.
If you don't mind getting only the current element and the accumulator you can do the following:
/** Sum of `x` and `y` that first logs both operands; meant as a fold's combining function. */
def printable(x:Int, y:Int): Int = {
  println(s"Curr: $x Acc:$y")
  x + y
}
// printable eta-expands to the (Int, Int) => Int combining function foldRight expects.
foldRight(List(1, 2, 3, 4), 0)(printable)
//> Curr: 4 Acc:0
//| Curr: 3 Acc:4
//| Curr: 2 Acc:7
//| Curr: 1 Acc:9
//| res0: Int = 10
If you want the whole "stacks trace" this will give you the output you asked, although it is far from elegant:
// Traced foldRight: computes the same result as a plain foldRight. It also prints one line per
// element of l.tail showing the partially expanded expression, e.g. "1 + (2 + foldRight(List(3),0))".
def foldRight[A, B](l: List[A], z: B)(f: (A, B) => B): B = {
// Left-hand text of the trace so far; seeded with the first element ("" for an empty list).
var acc = if (l.isEmpty) "" else l.head.toString
// Appends the next element as a newly opened operand: "<acc> + (<x>".
def newAcc(acc: String, x: A) = acc + " + (" + x
// Right-hand text: the remaining list and z, followed by (l.size - size + 1) closing parens.
// `size` is the length of the list at the current recursion level.
def rightSide(xs: List[A], z: B, size: Int) = xs.toString + "," + z + ")" * (l.size - size + 1)
def printDebug(left: String, right: String) = println(left + " + foldRight(" + right)
// Recursive worker over l.tail: updates `acc`, prints the current step, then folds the rest.
// NOTE(review): l.size / la.size are recomputed on every step (O(n) each), so tracing is O(n^2);
// fine for debugging-sized lists.
def go(la: List[A], z: B)(f: (A, B) => B): B = la match {
case Nil => z
case x :: xs => {
acc = newAcc(acc, x)
printDebug(acc, rightSide(xs, z, la.size))
f(x, go(xs, z)(f))
}
}
// The head already seeded `acc`, so only the tail goes through `go`.
if (l.isEmpty) z
else f(l.head, go(l.tail, z)(f))
}
Note: to get rid of the variable 'acc' you can make a second accumulator in the 'go' function
This one also returns the output you asked for, but doesn't obscure foldRight.
/** Records the elements visited by its own `foldRight`, so the evaluation can be rendered
  * afterwards as a sequence of partially expanded expressions.
  *
  * @param z start value shown in the rendered trace; it should match the `z` passed to foldRight
  */
class Trace[A](z: A) {
  /** Elements in the order foldRight visited them; accumulates across foldRight calls. */
  var list = List[A]()

  /** Appends `x` to the recorded trace (List append is O(n), acceptable for debug-sized lists). */
  def store(x: A) = {
    list = list :+ x
  }

  /** Renders step `level` of the trace.
    *
    * The first `level` recorded elements become "x + (" prefixes, and the rest are shown as a
    * pending foldRight call. At `level == list.size` there is no pending call, so the trailing
    * " + (" is dropped.
    */
  def getTrace(level: Int): String = {
    val left = list.take(level).map(x => s"$x + (").mkString
    val right = list.drop(level).map(x => s"$x,").mkString
    if (right.isEmpty)
      s"${left.dropRight(4)}" + ")" * (list.size - 1)
    else
      // Close exactly the `level` parens opened in `left`. The previous count,
      // list.size - level - 1, left intermediate lines unbalanced.
      s"${left}foldRight(List(${right.init}), $z)" + ")" * level
  }

  /** All steps 0..list.size, one per line. */
  def getFullTrace: String =
    { for (i <- 0 to list.size) yield getTrace(i) }.mkString("\n")

  /** Ordinary foldRight that records each element before recursing past it. */
  def foldRight(l: List[A], z: A)(f: (A, A) => A): A = l match {
    case Nil => z
    case x :: xs => store(x); f(x, foldRight(xs, z)(f))
  }
}
// Trace a fold of 1..4 starting from 0, then render every step of it.
val start = 0
val t = new Trace(start)
t.foldRight(List(1, 2, 3, 4), start)((x, acc) => x + acc)
t.getFullTrace
Let's see an example (it's a naive example but sufficient to illustrate the problem).
/** Returns the lone element of a one-element list, or a pair for a two-element list.
  *
  * The result is typed Any, so callers must cast it. Other lengths throw scala.MatchError.
  */
def produce(l: List[Int]) : Any =
  l match {
    case x :: Nil      => x
    case x :: y :: Nil => (x, y)
  }
// produce(...) is statically typed Any, so the caller must downcast; a wrong guess only fails at runtime.
val client1 : Int = produce(List(1)).asInstanceOf[Int]
Drawback: the client needs to cast!
/** Applies the caller-chosen extractor `f` to `l`; the result type `A` is whatever `f` returns. */
def produce2[A](l: List[Int])(f: List[Int] => A) = f(l)
// Extractors for produce2. Both are partial: head / tail.head throw on lists that are too short.
val toOne: List[Int] => Int = xs => xs.head
val toTwo: List[Int] => (Int, Int) = xs => (xs.head, xs.tail.head)
val client2: Int = produce2(List(1))(toOne)
Drawback: no type safety, i.e. we can call toTwo with a singleton List.
Is there a better solution ?
If you only have two possible return values you could use Either:
/** Left(element) for a singleton list, Right(pair) for a two-element list; other lengths throw scala.MatchError. */
def produce(l : List[Any]) : Either[Any, (Any, Any)] = l match {
  case only :: Nil            => Left(only)
  case first :: second :: Nil => Right((first, second))
}
If you don't want to create an Either, you could pass a function to transform each case:
/** Dispatches to `sf` for a one-element list or `pf` for a two-element list; other lengths throw scala.MatchError. */
def produce[A](l : List[Int])(sf: Int => A)(pf: (Int, Int) => A): A = l match {
  case x :: Nil      => sf(x)
  case x :: y :: Nil => pf(x, y)
}
Will this work?
/** Pads a list of at most two Ints into a pair, using None for missing slots.
  *
  * Longer lists throw scala.MatchError.
  */
def produce(l: List[Int]) = {
  l match {
    case x :: Nil => (x, None)
    case x :: y :: Nil => (x, y)
    case Nil => (None, None)
  }
}
or even better, to avoid match errors on lists longer than 2 elements:
/** Pair built from the first two elements (None for missing slots).
  *
  * Elements beyond the second are ignored, so no length raises scala.MatchError.
  */
def produce(l: List[Int]) =
  l match {
    case x :: Nil => (x, None)
    case first :: second :: _ => (first, second)
    case Nil => (None, None)
  }