Scala Extractor with Argument - scala

Is there a syntax in Scala that allows extractors to take a customization argument? This example is a bit contrived. Suppose I have a binary search tree on integers, and I want to match on the current node if its value is divisible by some custom value.
Using F# active patterns, I can do the following:
type Tree =
| Node of int * Tree * Tree
| Empty
let (|NodeDivisibleBy|_|) x t =
match t with
| Empty -> None
| Node(y, l, r) -> if y % x = 0 then Some((l, r)) else None
let doit = function
| NodeDivisibleBy(2)(l, r) -> printfn "Matched two: %A %A" l r
| NodeDivisibleBy(3)(l, r) -> printfn "Matched three: %A %A" l r
| _ -> printfn "Nada"
[<EntryPoint>]
let main args =
let t10 = Node(10, Node(1, Empty, Empty), Empty)
let t15 = Node(15, Node(1, Empty, Empty), Empty)
doit t10
doit t15
0
In Scala, I can do something similar, but not quite what I want:
sealed trait Tree
case object Empty extends Tree
case class Node(v: Int, l: Tree, r: Tree) extends Tree
object NodeDivisibleBy {
def apply(x: Int) = new {
def unapply(t: Tree) = t match {
case Empty => None
case Node(y, l, r) => if (y % x == 0) Some((l, r)) else None
}
}
}
def doit(t: Tree) {
// I would prefer to not need these two lines.
val NodeDivisibleBy2 = NodeDivisibleBy(2)
val NodeDivisibleBy3 = NodeDivisibleBy(3)
t match {
case NodeDivisibleBy2(l, r) => println("Matched two: " + l + " " + r)
case NodeDivisibleBy3(l, r) => println("Matched three: " + l + " " + r)
case _ => println("Nada")
}
}
val t10 = Node(10, Node(1, Empty, Empty), Empty)
val t15 = Node(15, Node(1, Empty, Empty), Empty)
doit(t10)
doit(t15)
It would be great if I could do:
case NodeDivisibleBy(2)(l, r) => println("Matched two: " + l + " " + r)
case NodeDivisibleBy(3)(l, r) => println("Matched three: " + l + " " + r)
but this is a compile time error: '=>' expected but '(' found.
Thoughts?

From the spec:
SimplePattern ::= StableId ‘(’ [Patterns] ‘)’
An extractor pattern x(p1, ..., pn) where n ≥ 0 is of the same
syntactic form as a constructor pattern. However, instead of a case
class, the stable identifier x denotes an object which has a member
method named unapply or unapplySeq that matches the pattern.
and:
A stable identifier is a path which ends in an identifier.
i.e., not an expression like NodeDivisibleBy(2).
So no, this isn't possible in any straightforward way in Scala, and personally I think that's just fine: having to write the following (which by the way I'd probably define in the NodeDivisibleBy object and import):
val NodeDivisibleBy2 = NodeDivisibleBy(2)
val NodeDivisibleBy3 = NodeDivisibleBy(3)
is a small price to pay for the increased readability of not having to decipher arbitrary expressions in the case clause.
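A minimal sketch of that suggestion, defining the parameterized extractors once and importing them as stable identifiers. It assumes a small named class, DivisibleBy, in place of the question's anonymous `new { ... }`; the names ExtractorDemo, By2, By3 and describe are made up for illustration.

```scala
object ExtractorDemo {
  sealed trait Tree
  case object Empty extends Tree
  final case class Node(v: Int, l: Tree, r: Tree) extends Tree

  // Hypothetical named class (not in the question): it carries the divisor
  // and exposes an ordinary unapply.
  final class DivisibleBy(x: Int) {
    def unapply(t: Tree): Option[(Tree, Tree)] = t match {
      case Node(y, l, r) if y % x == 0 => Some((l, r))
      case _                           => None
    }
  }

  object NodeDivisibleBy {
    def apply(x: Int): DivisibleBy = new DivisibleBy(x)
    // Stable identifiers, defined once and imported where needed.
    val By2: DivisibleBy = apply(2)
    val By3: DivisibleBy = apply(3)
  }

  import NodeDivisibleBy.{By2, By3}

  def describe(t: Tree): String = t match {
    case By2(l, r) => s"Matched two: $l $r"
    case By3(l, r) => s"Matched three: $l $r"
    case _         => "Nada"
  }
}
```

The case clauses stay short because By2 and By3 are plain vals, which is exactly what the StableId rule in the grammar allows.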

As Travis Brown has noted, it's not really possible in Scala.
What I do in that scenario is separate the decomposition from the test, using an alias and a guard.
val DivisibleBy = (n: Node, x: Int) => (n.v % x == 0)
def doit(t: Tree) = t match {
case n @ Node(y, l, r) if DivisibleBy(n,2) => println("Matched two: " + l + " " + r)
case n @ Node(y, l, r) if DivisibleBy(n,3) => println("Matched three: " + l + " " + r)
case _ => println("Nada")
}
Defining a separate DivisibleBy is obviously complete overkill in this simple case, but may help readability in more complex scenarios in a similar way as F# active patterns do.
You could also define divisibleBy as a method of Node and have:
case class Node(v: Int, l: Tree, r: Tree) extends Tree {
def divisibleBy(o:Int) = (v % o)==0
}
def doit(t: Tree) = t match {
case n @ Node(y, l, r) if n divisibleBy 2 => println("Matched two: " + l + " " + r)
case n @ Node(y, l, r) if n divisibleBy 3 => println("Matched three: " + l + " " + r)
case _ => println("Nada")
}
which I think is more readable (if more verbose) than the F# version.

Bind together the case class, predicate and arguments, then match on the result as usual.
case class Foo(i: Int)
class Testable(val f: Foo, val ds: List[Int])
object Testable {
def apply(f: Foo, ds: List[Int]) = new Testable(f, ds)
def unapply(t: Testable): Option[(Foo, List[Int])] = {
val xs = t.ds filter (t.f.i % _ == 0)
if (xs.nonEmpty) Some((t.f, xs)) else None
}
}
object Test extends App {
val f = Foo(100)
Testable(f, List(3,5,20)) match {
case Testable(f, 3 :: Nil) => println(s"$f matched three")
case Testable(Foo(i), 5 :: Nil) if i < 50
=> println(s"$f matched five")
case Testable(f, ds) => println(s"$f matched ${ds mkString ","}")
case _ => println("Nothing")
}
}

Late answer, but there is a scalac plugin providing syntax ~(extractorWithParam(p), bindings), with all safety from compilation: https://github.com/cchantep/acolyte/tree/master/scalac-plugin#match-component

As far as I know, the answer is no.
I also use the approach shown above for this case.

Related

Scala - standard recursive pattern match on list

Given:
def readLines(x: String): List[String] = Source.fromFile(x).getLines.toList
def toKVMap(fName : String): Map[String,String] =
readLines(fName).map(x => x.split(',')).map { case Array(x, y) => (x, y) }.toMap
I want to be able to take a string and a list of files of replacements and replace bracketed items. So if I have:
replLines("Hello",["cat"]) and cat contains ello,i!, I want to get back Hi!
I tried:
def replLines(inpQ : String, y : List[String]): String = y match {
case Nil => inpQ
case x::xs => replLines(toKVMap(x).fold(inpQ) {
case ((str: String), ((k: String), (v: String))) =>
str.replace("[" + k + "]", v).toString
}, xs)
}
I think the syntax is close, but not quite there. What have I done wrong?
What you're looking for is most likely this (note the foldLeft[String] instead of fold):
def replLines(inpQ: String, y: List[String]): String = y match {
case Nil => inpQ
case x :: xs => replLines(toKVMap(x).foldLeft[String](inpQ) {
case ((str: String), ((k: String), (v: String))) =>
str.replace("[" + k + "]", v)
}, xs)
}
fold widens the type of its initial argument too far: the accumulator has to be a supertype of the map's (String, String) entries, so it is inferred as Serializable, not String. foldLeft (and foldRight, if you prefer to start your replacements from the end) lets you explicitly specify the type you fold on.
EDIT: In fact, you don't even need a recursive pattern matching at all, as you can map your replacements directly to the list:
def replLines2(inpQ: String, y: List[String]): String =
y.flatMap(toKVMap).foldLeft[String](inpQ) {
case (str, (k, v)) => str.replace(s"[$k]", v)
}
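For experimenting without files on disk, here is a small sketch of the same foldLeft idea. It assumes an in-memory stand-in for toKVMap (the names ReplaceDemo, tables and lookup are hypothetical) and that the input marks replaceable items with square brackets.

```scala
object ReplaceDemo {
  // Hypothetical stand-in for toKVMap: replacement tables are kept in memory
  // instead of being read from files.
  val tables: Map[String, Map[String, String]] = Map(
    "cat" -> Map("ello" -> "i!"),
    "dog" -> Map("name" -> "Rex")
  )

  def lookup(fName: String): Map[String, String] =
    tables.getOrElse(fName, Map.empty)

  // foldLeft infers the accumulator type from `input`, so it stays a String.
  def replLines(input: String, files: List[String]): String =
    files.flatMap(lookup).foldLeft(input) {
      case (str, (k, v)) => str.replace(s"[$k]", v)
    }
}
```

Calling `ReplaceDemo.replLines("H[ello]", List("cat"))` replaces the bracketed key with its value from the "cat" table.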

avoiding a for loop to reach a tail recursive state in scala

I have written the following method to print a tree of arbitrary arity (depth-first search).
def treeString[A](tree: Tree[A]): String = {
def DFS(t: Tree[A], acc: String = "", visited: List[String] = Nil, toClose: Int = 0): String = t match {
case Leaf(id, terminal) if !visited.contains(id) => {
acc + " " + terminal.name + "[" + id.toString + "] " + ")"*toClose
}
case Node(id, op, args @ _*) if !visited.contains(id) => {
var state = acc + " " + op.name + "[" + id.toString + "] " + "("
var args_visited: List[String] = List(id)
for (i <- args.tail) {
state = DFS(i, state , visited ::: args_visited) + " , "
args_visited = args_visited ++ List(i.id)
}
DFS(args.head, state , visited ++ args_visited, toClose + 1)
}
case _ => acc
}
DFS(tree)
}
The Scala compiler claims this function is not tail recursive. However, at the tail position I have the proper DFS(..) call. All of the DFS(..) calls made in the loop will be done in an iterative fashion, and thus be stack safe.
I understand that if the tree is infinitely deep a stack overflow will occur. Do we have a technique to deal with these cases?
I have to agree with @VictorMoroz. The problem is that your statement:
state = DFS(i, state , visited ::: args_visited) + " , "
is not in tail position. It creates a new stack frame, so your method is no longer tail recursive.
Without diving deep into your data structures, here is a way to traverse a tree depth-first in a tail-recursive manner.
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A]
case class Element(id:Int, name:String)
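import scala.annotation.tailrec
def DFS[T](t: Tree[T]): List[Element] = {
  @tailrec
  def _dfs(rest: List[Tree[T]], visited: List[Element]): List[Element] =
    rest match {
      case Nil => visited // stack is empty: traversal is complete (elements are in reverse visiting order)
      case Leaf(v: Element) :: rs => _dfs(rs, v :: visited) // record the leaf, continue with the rest of the stack
      case Leaf(_) :: rs => _dfs(rs, visited) // skip leaves that do not hold an Element
      case Node(l, r) :: rs => _dfs(l :: r :: rs, visited) // push both children onto the stack
    }
  _dfs(List(t), Nil)
}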
The key is to use an explicit stack - here in form of a List.
I understand that if the tree is infinitely deep a stack overflow will
occur.
This is correct. In most programming languages, each method call reserves some stack memory. Put simply, tail recursion allows the compiler to discard the previous state, because the caller's stack frame is no longer needed.
Bear in mind that DFS cannot guarantee finding the global optimum and should not be used if that is required.
Postscriptum:
The @tailrec annotation asks the compiler to check at compile time whether your method can be optimized (Scala optimizes self-recursive tail calls by default).
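The same explicit-stack idea carries over to trees of arbitrary arity. The sketch below assumes a hypothetical Rose type (it is not the asker's Tree[A]) and only collects node ids in pre-order, to show where the children go on the stack.

```scala
import scala.annotation.tailrec

object ExplicitStackDfs {
  // Hypothetical n-ary tree used only for this illustration.
  final case class Rose(id: String, children: List[Rose] = Nil)

  def preorder(root: Rose): List[String] = {
    @tailrec
    def loop(stack: List[Rose], acc: List[String]): List[String] = stack match {
      case Nil          => acc.reverse // nothing left to visit
      case node :: rest =>
        // Children are pushed in front of the pending nodes, so they are
        // visited before the siblings of the current node.
        loop(node.children ::: rest, node.id :: acc)
    }
    loop(List(root), Nil)
  }
}
```

Because the only recursive call is the last action of each branch, @tailrec accepts it and the traversal depth is bounded by heap, not by the call stack.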
I managed to reach a tail-recursive state.
Despite the elegance of the code, I truly question the cost/benefit of a functional approach. Most of the time it is not obvious how to manually pass the stack through the algorithm.
Anyway, the real problem here lies in the JVM's handling of recursion. The Scala compiler still tries to give us an escape hatch with its tail-recursion feature, but it is often a nightmare to get things working.
def treeString[A](tree: Tree[A]): String = {
def dropWhile[A](l: List[A], f: A => Boolean): List[A] =
l match {
case h :: t if f(h) => dropWhile(t, f)
case _ => l
}
@tailrec
def DFS(toVisit: List[Tree[A]], visited: List[String] = Nil, acc: String = "", n_args: List[Int] = Nil): String = toVisit match {
case Leaf(id, terminal) :: tail if !visited.contains(id) => {
val n_args_filtered = dropWhile[Int](n_args, x => x == 1)
val acc_to_pass = acc + " " + terminal.name + "[" + id.toString + "] " + ")" * (n_args.length - n_args_filtered.length) + "," * {if (n_args_filtered.length > 0) 1 else 0}
val n_args_to_pass = {if (n_args_filtered.length > 0 )List(n_args_filtered.head - 1) ::: n_args_filtered.tail else Nil}
DFS(toVisit.tail, visited ::: List(id), acc_to_pass, n_args_to_pass)
}
case Node(id, op, args @ _*) :: tail if !visited.contains(id) => {
DFS(args.toList ::: toVisit.tail, visited ::: List(id), acc + " " + op.name + " (", List(args.length ) ::: n_args)
}
case Nil => acc
}
DFS(List(tree))
}

Does Scala continuation plugin support nested shift?

I am going through the following Shift/Reset tutorial: http://www.is.ocha.ac.jp/~asai/cw2011tutorial/main-e.pdf.
I got pretty good results so far in translating the OchaCaml examples to Scala (all the way up to section 2.11). But now I seem to have hit a wall. The code from the paper from Asai/Kiselyov defines the following recursive function (this is OchaCaml - I think):
(* a_normal : term_t => term_t *)
let rec a_normal term = match term with
Var (x) -> Var (x)
| Lam (x, t) -> Lam (x, reset (fun () -> a_normal t))
| App (t1, t2) ->
shift (fun k ->
let t = gensym () in (* generate fresh variable *)
App (Lam (t, (* let expression *)
k (Var (t))), (* continue with new variable *)
App (a_normal t1, a_normal t2))) ;;
The function is supposed to A-normalize a lambda expression. This is my Scala translation:
// section 2.11
object ShiftReset extends App {
import scala.util.continuations._
sealed trait Term
case class Var(x: String) extends Term
case class Lam(x: String, t: Term) extends Term
case class App(t1: Term, t2: Term) extends Term
val gensym = {
var i = 0
() => { i += 1; "t" + i }
}
def a_normal(t: Term): Term @cps[Term] = t match {
case Var(x) => Var(x)
case Lam(x, t) => Lam(x, reset(a_normal(t)))
case App(t1, t2) => shift{ (k:Term=>Term) =>
val t = gensym()
val u = Lam(t, k(Var(t)))
val v = App(a_normal(t1), a_normal(t2))
App(u, v): Term
}
}
}
I am getting the following compilation error:
found : ShiftReset.Term @scala.util.continuations.cpsSynth
@scala.util.continuations.cpsParam[ShiftReset.Term,ShiftReset.Term]
required: ShiftReset.Term
case App(t1, t2) => shift{ (k:Term=>Term) =>
^
one error found
I think the plugin is telling me that it cannot deal with nested shift... Is there a way to make the code compile (either a basic error I've overlooked or some workaround)? I have tried converting the pattern match to an if/else if/else and introducing more local variables, but I got the same error.
Alternately, would I have more luck using Haskell and the Cont monad (+ the shift/reset from here) or is there going to be the same type of limitation with nested shift? I am adding the Haskell tag as well, since I don't mind switching to Haskell to go through the rest of the tutorial.
Edit: Thanks to James who figured out which line the continuation plugin could not deal with and how to tweak it, it now works. Using the version in his answer and the following formatting code:
def format(t: Term): String = t match {
case Var(x) => s"$x"
case Lam(x, t) => s"\\$x.${format(t)}"
case App(Lam(x, t1), t2) => s"let $x = ${format(t2)} in ${format(t1)}"
case App(Var(x), Var(y)) => s"$x$y"
case App(Var(x), t2) => s"$x (${format(t2)})"
case App(t1, t2) => s"(${format(t1)}) (${format(t2)})"
}
I get the output that the paper mentions (though I don't grasp yet how the continuation actually manages it):
sCombinator:
\x.\y.\z.(xz) (yz)
reset{a_normal(sCombinator)}:
\x.\y.\z.let t1 = xz in let t2 = yz in let t3 = t1t2 in t3
The problem is the line:
val v = App(a_normal(t1), a_normal(t2))
I'm not certain, but I think the type inferencer is getting confused by the fact that a_normal returns a Term @cps[Term], but we're inside a shift so the continuation isn't annotated the same way.
It will compile if you pull the line up out of the shift block:
def a_normal(t: Term): Term @cps[Term] = t match {
case Var(x) => Var(x)
case Lam(x, t) => Lam(x, reset(a_normal(t)))
case App(t1, t2) =>
val v = App(a_normal(t1), a_normal(t2))
shift{ (k:Term=>Term) =>
val t = gensym()
val u = Lam(t, k(Var(t)))
App(u, v): Term
}
}
Regarding nested shifts in general, you can definitely do this if each nested continuation has a compatible type:
object NestedShifts extends App {
import scala.util.continuations._
def foo(x: Int): Int @cps[Unit] = shift { k: (Int => Unit) =>
k(x)
}
reset {
val x = foo(1)
println("x: " + x)
val y = foo(2)
println("y: " + y)
val z = foo(foo(3))
println("z: " + z)
}
}
This program prints to stdout:
x: 1
y: 2
z: 3
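To see what the captured continuation is doing in a_normal, it may help to write the continuation out by hand. The following is only a sketch in plain Scala without the continuations plugin: it assumes explicit continuation-passing style in place of shift/reset, and the names CpsANormal, normalize and go are hypothetical. The parameter k stands for what shift would capture, and passing identity stands for reset.

```scala
object CpsANormal {
  sealed trait Term
  final case class Var(x: String) extends Term
  final case class Lam(x: String, t: Term) extends Term
  final case class App(t1: Term, t2: Term) extends Term

  // Same pretty-printing rules as the format function in the question.
  def format(t: Term): String = t match {
    case Var(x)              => x
    case Lam(x, b)           => "\\" + x + "." + format(b)
    case App(Lam(x, t1), t2) => "let " + x + " = " + format(t2) + " in " + format(t1)
    case App(Var(x), Var(y)) => x + y
    case App(Var(x), t2)     => x + " (" + format(t2) + ")"
    case App(t1, t2)         => "(" + format(t1) + ") (" + format(t2) + ")"
  }

  def normalize(term: Term): Term = {
    var counter = 0
    def gensym(): String = { counter += 1; "t" + counter }

    // k is the rest of the computation, i.e. what shift would hand us.
    def go(t: Term)(k: Term => Term): Term = t match {
      case Var(x)       => k(Var(x))
      case Lam(x, body) => k(Lam(x, go(body)(identity))) // identity delimits like reset
      case App(t1, t2)  =>
        go(t1) { n1 =>
          go(t2) { n2 =>
            val t = gensym()                    // fresh variable
            App(Lam(t, k(Var(t))), App(n1, n2)) // let t = n1 n2 in k(t)
          }
        }
    }
    go(term)(identity)
  }
}
```

Each application wraps "the rest of the computation" inside a let binding, which is why the nested applications of the S combinator come out as a chain of lets.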

How to achieve the expression simplification with distributivity rule in Scala?

I want to write a Scala program that simplifies this mathematical expression using the distributivity rule:
a*b+a*c = a(b+c)
I quickly wrote the following code in order to solve this example:
object Test {
sealed abstract class Expr
case class Var(name: String) extends Expr
case class BinOp(operator: String, left: Expr, right: Expr) extends Expr
def main(args: Array[String]) {
val expr = BinOp("+", BinOp("*", Var("a"), Var("b")), BinOp("*", Var("a"), Var("c")))
println(simplify(expr)) //outputs "a*(b + c)"
}
def simplify(expr: Expr) : String = expr match {
case BinOp("+", BinOp("*", Var(x), Var(a)), BinOp("*", Var(y), Var(b))) if (x == y) => "" + x + "*(" + a + " + " + b + ")"
case _ => "" //no matter for the test since I test the first case statically
}
}
Is there a better way to achieve this?
What is the best way to manage the order of operand without duplicating cases for each combination (would be ugly...)? Indeed, what about these expressions:
a*b+a*c = a(b+c)
a*b+c*a = a(b+c)
b*a+a*c = a(b+c)
b*a+c*a = a(b+c)
If Expr obeys the commutative law, it must be:
def simplify(expr: Expr) : String = expr match {
  case expr @ BinOp("+", BinOp("*", Var(x), Var(a)), BinOp("*", Var(y), Var(b))) => {
    // given one factor of the left product, return the other one (if `that` occurs in it)
    def another(that: String) = {
      Seq((x, a), (a, x)) find (_._1 == that) map (_._2)
    }
    val byY = another(y).map(z => BinOp("*", Var(y), BinOp("+", Var(z), Var(b)))) // factor out y
    val byB = another(b).map(z => BinOp("*", Var(b), BinOp("+", Var(z), Var(y)))) // factor out b
    (byY orElse byB getOrElse expr).toString
  }
  case _ => "" //no matter for the test since I test the first case statically
}
byY and byB have the same structure, so this is not ideal; you could probably factor out the shared code. :P
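One way to avoid writing a case per operand order is to enumerate both orderings of each product and look for a shared factor. This is just a sketch under that assumption, not a general simplifier: it returns an Expr rather than a String, handles only a single sum of two products, and the object name Distribute is made up.

```scala
object Distribute {
  sealed abstract class Expr
  final case class Var(name: String) extends Expr
  final case class BinOp(operator: String, left: Expr, right: Expr) extends Expr

  def simplify(expr: Expr): Expr = expr match {
    case BinOp("+", BinOp("*", l1, r1), BinOp("*", l2, r2)) =>
      // Try every (factor, rest) ordering of both products and keep the
      // first pair whose factors are equal.
      val candidates = for {
        (f1, rest1) <- List((l1, r1), (r1, l1))
        (f2, rest2) <- List((l2, r2), (r2, l2))
        if f1 == f2
      } yield BinOp("*", f1, BinOp("+", rest1, rest2))
      candidates.headOption.getOrElse(expr)
    case other => other // nothing to factor out
  }
}
```

All four orderings listed in the question (a*b+a*c, a*b+c*a, b*a+a*c, b*a+c*a) then reduce to the same tree for a*(b+c).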

could not find implicit value for evidence parameter of type Ordered[T]

I have been stuck on this for about 4 hours now. I want to get a List of Pair[String, Int] ordered by their Int value. The function partition works fine, and so should bestN, but when loading this into my interpreter I get:
<console>:15: error: could not find implicit value for evidence parameter of type Ordered[T]
on my predicate. Does someone see what the problem is? I am really desperate at the moment...
This is the code:
def partition[T : Ordered](pred: (T)=>Boolean, list:List[T]): Pair[List[T],List[T]] = {
list.foldLeft(Pair(List[T](),List[T]()))((pair,x) => if(pred(x))(pair._1, x::pair._2) else (x::pair._1, pair._2))
}
def bestN[T <% Ordered[T]](list:List[T], n:Int): List[T] = {
list match {
case pivot::other => {
println("pivot: " + pivot)
val (smaller,bigger) = partition(pivot <, list)
val s = smaller.size
println(smaller)
if (s == n) smaller
else if (s+1 == n) pivot::smaller
else if (s < n) bestN(bigger, n-s-1)
else bestN(smaller, n)
}
case Nil => Nil
}
}
class OrderedPair[T, V <% Ordered[V]] (t:T, v:V) extends Pair[T,V](t,v) with Ordered[OrderedPair[T,V]] {
def this(p:Pair[T,V]) = this(p._1, p._2)
override def compare(that:OrderedPair[T,V]) : Int = this._2.compare(that._2)
}
Edit: The first function divides a List into two by applying the predicate to every member; the bestN function should return a List of the lowest n members of the given list. The class is there to make Pairs comparable. In this case, what I want to do is:
val z = List(Pair("alfred",1),Pair("peter",4),Pair("Xaver",1),Pair("Ulf",2),Pair("Alfons",6),Pair("Gulliver",3))
with this given List I want to get for example with:
bestN(z, 3)
the result:
(("alfred",1), ("Xaver",1), ("Ulf",2))
It looks like you don't need an Ordered T on your partition function, since it just invokes the predicate function. The context bound T : Ordered demands an implicit value of type Ordered[T], and none exists, hence the error.
The following (presumably) doesn't work correctly; it merely compiles. Other matters for code review would be the extra braces and stuff like that.
package evident
object Test extends App {
def partition[T](pred: (T)=>Boolean, list:List[T]): Pair[List[T],List[T]] = {
list.foldLeft(Pair(List[T](),List[T]()))((pair,x) => if(pred(x))(pair._1, x::pair._2) else (x::pair._1, pair._2))
}
def bestN[U,V<%Ordered[V]](list:List[(U,V)], n:Int): List[(U,V)] = {
list match {
case pivot::other => {
println(s"pivot: $pivot and rest ${other mkString ","}")
def cmp(a: (U,V), b: (U,V)) = (a: OrderedPair[U,V]) < (b: OrderedPair[U,V])
val (smaller,bigger) = partition(((x:(U,V)) => cmp(x, pivot)), list)
//val (smaller,bigger) = list partition ((x:(U,V)) => cmp(x, pivot))
println(s"smaller: ${smaller mkString ","} and bigger ${bigger mkString ","}")
val s = smaller.size
if (s == n) smaller
else if (s+1 == n) pivot::smaller
else if (s < n) bestN(bigger, n-s-1)
else bestN(smaller, n)
}
case Nil => Nil
}
}
implicit class OrderedPair[T, V <% Ordered[V]](tv: (T,V)) extends Pair(tv._1, tv._2) with Ordered[OrderedPair[T,V]] {
override def compare(that:OrderedPair[T,V]) : Int = this._2.compare(that._2)
}
val z = List(Pair("alfred",1),Pair("peter",4),Pair("Xaver",1),Pair("Ulf",2),Pair("Alfons",6),Pair("Gulliver",3))
println(bestN(z, 3))
}
I found the partition function hard to read; you need a function to partition all the parens. Here are a couple of formulations, which also use the convention that results accepted by the filter go left, rejects go right.
def partition[T](p: T => Boolean, list: List[T]) =
((List.empty[T], List.empty[T]) /: list) { (s, t) =>
if (p(t)) (t :: s._1, s._2) else (s._1, t :: s._2)
}
def partition2[T](p: T => Boolean, list: List[T]) =
((List.empty[T], List.empty[T]) /: list) {
case ((is, not), t) if p(t) => (t :: is, not)
case ((is, not), t) => (is, t :: not)
}
// like List.partition
def partition3[T](p: T => Boolean, list: List[T]) = {
import collection.mutable.ListBuffer
val is, not = new ListBuffer[T]
for (t <- list) (if (p(t)) is else not) += t
(is.toList, not.toList)
}
This might be closer to what the original code intended:
def bestN[U, V <% Ordered[V]](list: List[(U,V)], n: Int): List[(U,V)] = {
require(n >= 0)
require(n <= list.length)
if (n == 0) Nil
else if (n == list.length) list
else list match {
case pivot :: other =>
println(s"pivot: $pivot and rest ${other mkString ","}")
def cmp(x: (U,V)) = x._2 < pivot._2
val (smaller, bigger) = partition(cmp, other) // other partition cmp
println(s"smaller: ${smaller mkString ","} and bigger ${bigger mkString ","}")
val s = smaller.size
if (s == n) smaller
else if (s == 0) pivot :: bestN(bigger, n - 1)
else if (s < n) smaller ::: bestN(pivot :: bigger, n - s)
else bestN(smaller, n)
case Nil => Nil
}
}
Arrow notation is more usual:
val z = List(
"alfred" -> 1,
"peter" -> 4,
"Xaver" -> 1,
"Ulf" -> 2,
"Alfons" -> 6,
"Gulliver" -> 3
)
I suspect I am missing something, but I'll post a bit of code anyway.
For bestN, you know you can just do this?
val listOfPairs = List(Pair("alfred",1),Pair("peter",4),Pair("Xaver",1),Pair("Ulf",2),Pair("Alfons",6),Pair("Gulliver",3))
val bottomThree = listOfPairs.sortBy(_._2).take(3)
Which gives you:
List((alfred,1), (Xaver,1), (Ulf,2))
And for the partition function, you can just do this (say you wanted all pairs lower than 4):
val partitioned = listOfPairs.partition(_._2 < 4)
Which gives (all lower than 4 on the left, the rest on the right):
(List((alfred,1), (Xaver,1), (Ulf,2), (Gulliver,3)),List((peter,4), (Alfons,6)))
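If you still want a reusable bestN, a context bound on Ordering (rather than Ordered) is the form the compiler can actually resolve for types like Int. A minimal sketch along those lines, assuming a key-extraction function as an extra parameter; the names BestNDemo and key are illustrative, and it simply sorts instead of doing the pivot-based selection from the question:

```scala
object BestNDemo {
  // B: Ordering asks for an implicit Ordering[B], which the standard
  // library provides for Int, String, tuples, and so on.
  def bestN[A, B: Ordering](list: List[A], n: Int)(key: A => B): List[A] =
    list.sortBy(key).take(n)

  val z: List[(String, Int)] = List(
    "alfred" -> 1, "peter" -> 4, "Xaver" -> 1,
    "Ulf" -> 2, "Alfons" -> 6, "Gulliver" -> 3
  )
}
```

Since sortBy is stable, pairs with equal values keep their original relative order, so `BestNDemo.bestN(BestNDemo.z, 3)(_._2)` yields the alfred, Xaver and Ulf pairs.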
Just to share with you: this works! Thanks a lot to everyone who helped me, you're all great!
object Test extends App {
def partition[T](pred: (T)=>Boolean, list:List[T]): Pair[List[T],List[T]] = {
list.foldLeft(Pair(List[T](),List[T]()))((pair,x) => if(pred(x))(pair._1, x::pair._2) else (x::pair._1, pair._2))
}
def bestN[U,V<%Ordered[V]](list:List[(U,V)], n:Int): List[(U,V)] = {
list match {
case pivot::other => {
def cmp(a: (U,V), b: (U,V)) = (a: OrderedPair[U,V]) <= (b: OrderedPair[U,V])
val (smaller,bigger) = partition(((x:(U,V)) => cmp(pivot, x)), list)
val s = smaller.size
//println(n + " :" + s)
//println("size:" + smaller.size + "Pivot: " + pivot + " Smaller part: " + smaller + " bigger: " + bigger)
if (s == n) smaller
else if (s+1 == n) pivot::smaller
else if (s < n) bestN(bigger, n-s)
else bestN(smaller, n)
}
case Nil => Nil
}
}
class OrderedPair[T, V <% Ordered[V]](tv: (T,V)) extends Pair(tv._1, tv._2) with Ordered[OrderedPair[T,V]] {
override def compare(that:OrderedPair[T,V]) : Int = this._2.compare(that._2)
}
implicit final def OrderedPair[T, V <% Ordered[V]](p : Pair[T, V]) : OrderedPair[T,V] = new OrderedPair(p)
val z = List(Pair("alfred",1),Pair("peter",1),Pair("Xaver",1),Pair("Ulf",2),Pair("Alfons",6),Pair("Gulliver",3))
println(bestN(z, 3))
println(bestN(z, 4))
println(bestN(z, 1))
}