How does the Scala compiler perform implicit conversion?

I have a custom class, A, and I have defined some operations within the class as follows:
def +(that: A) = ...
def -(that: A) = ...
def *(that: A) = ...
def +(that: Double) = ...
def -(that: Double) = ...
def *(that: Double) = ...
In order to have something like 2.0 + x make sense when x is of type A, I have defined the following implicit class:
object A {
implicit class Ops (lhs: Double) {
def +(rhs: A) = ...
def -(rhs: A) = ...
def *(rhs: A) = ...
}
}
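For reference, here is a minimal runnable version of the setup described above. The question elides the method bodies, so this sketch assumes that A simply wraps a single Double. Because Ops lives in A's companion object, in Scala 2 an expression like 2.0 * x compiles without an import: when the Double receiver has no applicable *, the compiler also searches the companion of the argument type A for a conversion.

```scala
// Hypothetical concrete A: the question elides the bodies, so a single Double field is assumed.
final case class A(v: Double) {
  def +(that: A): A = A(v + that.v)
  def -(that: A): A = A(v - that.v)
  def *(that: A): A = A(v * that.v)
  def +(that: Double): A = A(v + that)
  def -(that: Double): A = A(v - that)
  def *(that: Double): A = A(v * that)
}

object A {
  // Used when a Double is the receiver and an A is the argument, e.g. 2.0 * x.
  implicit class Ops(lhs: Double) {
    def +(rhs: A): A = A(lhs + rhs.v)
    def -(rhs: A): A = A(lhs - rhs.v)
    def *(rhs: A): A = A(lhs * rhs.v)
  }
}
```

With these definitions, y + a * z from the question type-checks as y + A.Ops(a).*(z).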
This all works fine normally. Now I introduce a compiler plugin with a TypingTransformer that performs some optimizations. Specifically, let's say I have a ValDef:
val x = y + a * z
where x, y, and z are of type A, and a is a Double. Normally, this compiles fine. I put it through the optimizer, which uses quasiquotes to change y + a * z into something else. BUT in this particular example, the expression is unchanged (there are no optimizations to perform). Suddenly, the compiler no longer does an implicit conversion for a * z.
To summarize, I have a compiler plugin that takes an expression that would normally have implicit conversions applied to it. It creates a new expression via quasiquotes, which syntactically appears the same as the old expression. But for this new expression, the compiler fails to perform implicit conversion.
So my question is: how does the compiler determine that an implicit conversion must take place? Is there a specific flag or something that needs to be set in the AST that quasiquotes are failing to set?
UPDATE
The plugin phase looks something like this:
override def transform(tree: Tree) = tree match {
case ClassDef(classmods, classname, classtparams, impl) if classname.toString == "Module" => {
var implStatements: List[Tree] = List()
for (node <- impl.body) node match {
case DefDef(mods, name, tparams, vparamss, tpt, body) if name.toString == "loop" => {
var statements: List[Tree] = List()
for (statement <- body.children.dropRight(1)) statement match {
case Assign(opd, rhs) => {
val optimizedRHS = optimizeStatement(rhs)
statements = statements ++ List(Assign(opd, optimizedRHS))
}
case ValDef(mods, opd, tpt, rhs) => {
val optimizedRHS = optimizeStatement(rhs)
statements = statements ++
List(ValDef(mods, opd, tpt, optimizedRHS))
}
case Apply(Select(src1, op), List(src2)) if op.toString == "push" => {
val optimizedSrc2 = optimizeStatement(src2)
statements = statements ++
List(Apply(Select(src1, op), List(optimizedSrc2)))
}
case _ => statements = statements ++ List(statement)
}
val newBody = Block(statements, body.children.last)
implStatements = implStatements ++
List(DefDef(mods, name, tparams, vparamss, tpt, newBody))
}
case _ => implStatements = implStatements ++ List(node)
}
val newImpl = Template(impl.parents, impl.self, implStatements)
ClassDef(classmods, classname, classtparams, newImpl)
}
case _ => super.transform(tree)
}
def optimizeStatement(tree: Tree): Tree = {
// some logic that transforms
// 1.0 * x + 2.0 * (x + y)
// into
// 3.0 * x + 2.0 * y
// (i.e. distribute multiplication & collect like terms)
//
// returned trees are always newly created
// returned trees are created w/ quasiquotes
// something like
// 1.0 * x + 2.0 * y
// will return
// 1.0 * x + 2.0 * y
// (i.e. syntactically unchanged)
}
UPDATE 2
Please refer to this GitHub repo for a minimum working example: https://github.com/darsnack/compiler-plugin-demo
The issue is that a * z turns into a.<$times: error>(z) after I optimize the statement.

The issue is related to the pos field associated with trees. Even though everything is happening before the namer, and the tree with and without the compiler plugin is syntactically the same, the compiler will not be able to infer implicit conversion due to this pesky line in the compiler source:
val retry = typeErrors.forall(_.errPos != null) && (errorInResult(fun) || errorInResult(tree) || args.exists(errorInResult))
(credit to hrhino for finding this).
The solution is to always use treeCopy when creating a new tree so that all the internal flags/fields are copied:
case Assign(opd, rhs) => {
val optimizedRHS = optimizeStatement(rhs)
statements = statements ++ List(treeCopy.Assign(statement, opd, optimizedRHS))
}
And when generating a tree using quasiquotes, remember to set the position:
var optimizedNode = atPos(statement.pos.focus)(q"$optimizedSrc1.$newOp")
I updated my MWE GitHub repo with the fixed solution: https://github.com/darsnack/compiler-plugin-demo

Related

Performance around functional programming in scala

I'm working through the code below as a way to learn functional programming and Scala; I come from a Python background.
case class Point(x: Int, y:Int)
object Operation extends Enumeration {
type Operation = Value
val TurnOn, TurnOff, Toggle = Value
}
object Status extends Enumeration {
type Status = Value
val On, Off = Value
}
import Operation._, Status._
val inputs: List[String]
def parseInputs(s: String): (Point, Point, Operation)
The idea is that we have a light matrix (of Points), and each Point can be either On or Off, as described in Status.
My inputs are a series of commands, each asking to either TurnOn, TurnOff or Toggle all the lights from one Point to another Point (the rectangular area is defined by two points: the bottom-left corner and the upper-right corner).
My original solution is like this:
import scala.collection.mutable
type LightStatus = mutable.Map[Point, Status]
val lightStatus = mutable.Map[Point, Status]()
def updateStatus(p1: Point, p2: Point, op: Operation): Unit = {
(p1, p2) match {
case (Point(x1, y1), Point(x2, y2)) =>
for (x <- x1 to x2)
for (y <- y1 to y2) {
val p = Point(x, y)
val currentStatus = lightStatus.getOrElse(p, Off)
(op, currentStatus) match {
case (TurnOn, _) => lightStatus.update(p, On)
case (TurnOff, _) => lightStatus.update(p, Off)
case (Toggle, On) => lightStatus.update(p, Off)
case (Toggle, Off) => lightStatus.update(p, On)
}
}
}
}
for ((p1, p2, op) <- inputs.map(parseInputs)) {
updateStatus(p1, p2, op)
}
Now lightStatus is a map describing the end status of the entire matrix. This works, but it seems less functional to me because it uses a mutable Map instead of an immutable object, so I tried to refactor it into a more functional style and ended up with this:
inputs.flatMap(s => parseInputs(s) match {
case (Point(x1, y1), Point(x2, y2), op) =>
for (x <- x1 to x2;
y <- y1 to y2)
yield (Point(x, y), op)
}).foldLeft(Map[Point, Status]())((m, item) => {
item match {
case (p, op) =>
val currentStatus = m.getOrElse(p, Off)
(op, currentStatus) match {
case (TurnOn, _) => m.updated(p, On)
case (TurnOff, _) => m.updated(p, Off)
case (Toggle, On) => m.updated(p, Off)
case (Toggle, Off) => m.updated(p, On)
}
}
})
I have a couple of questions regarding this process:
My second version doesn't seem as clean and straightforward as the first. I'm not sure whether that's because I'm not familiar enough with functional programming or because I just wrote bad functional code.
Is there some way to simplify the syntax of the second piece, especially the (m, item) => ??? function in the foldLeft part? Something like (m, (point, operation)) => ??? gives me a syntax error.
The second piece of code takes significantly longer to run, which surprises me a bit since the two versions essentially do the same thing. As I don't have much Java background, any idea what might be causing the performance issue?
Many thanks!
From a Functional Programming perspective, your code suffers from the fact that...
The lightStatus Map "maintains state" and thus requires mutation.
A large "area" of status changes == a large number of data updates.
If you can accept each light status as a Boolean value, here's a design that requires no mutation and has fast status updates even over very large areas.
case class Point(x: Int, y:Int)
class LightGrid private (status: Point => Boolean) {
def apply(p: Point): Boolean = status(p)
private def isWithin(p:Point, ll:Point, ur:Point) =
ll.x <= p.x && ll.y <= p.y && p.x <= ur.x && p.y <= ur.y
//each light op returns a new LightGrid
def turnOn(lowerLeft: Point, upperRight: Point): LightGrid =
new LightGrid(point =>
isWithin(point, lowerLeft, upperRight) || status(point))
def turnOff(lowerLeft: Point, upperRight: Point): LightGrid =
new LightGrid(point =>
!isWithin(point, lowerLeft, upperRight) && status(point))
def toggle(lowerLeft: Point, upperRight: Point): LightGrid =
new LightGrid(point =>
isWithin(point, lowerLeft, upperRight) ^ status(point))
}
object LightGrid { //the public constructor
def apply(): LightGrid = new LightGrid(_ => false)
}
usage:
val ON = true
val OFF = false
val lg = LightGrid().turnOn(Point(2,2), Point(11,11)) //easy numbers
.turnOff(Point(8,8), Point(10,10))
.toggle(Point(1,1), Point(9,9))
lg(Point(1,1)) //ON
lg(Point(7,7)) //OFF
lg(Point(8,8)) //ON
lg(Point(9,9)) //ON
lg(Point(10,10)) //OFF
lg(Point(11,11)) //ON
lg(Point(12,12)) //OFF
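On the question's second point: Scala 2 does not accept a nested tuple pattern in a lambda's parameter list, which is why (m, (point, operation)) => ??? is a syntax error. A pattern-matching anonymous function, { case (m, (p, op)) => ... }, destructures both fold arguments in one step. The sketch below wraps the question's Point/Operation/Status definitions in a hypothetical FoldSketch object, and its applyAll helper (also hypothetical) takes already-parsed commands instead of calling parseInputs.

```scala
object FoldSketch {
  case class Point(x: Int, y: Int)

  object Operation extends Enumeration {
    type Operation = Value
    val TurnOn, TurnOff, Toggle = Value
  }

  object Status extends Enumeration {
    type Status = Value
    val On, Off = Value
  }

  import Operation._
  import Status._

  // Hypothetical helper: takes already-parsed commands instead of calling parseInputs.
  def applyAll(commands: List[(Point, Point, Operation)]): Map[Point, Status] =
    commands
      .flatMap { case (Point(x1, y1), Point(x2, y2), op) =>
        for (x <- x1 to x2; y <- y1 to y2) yield (Point(x, y), op)
      }
      .foldLeft(Map.empty[Point, Status]) {
        // One pattern destructures both the accumulator m and the (point, op) element.
        case (m, (p, TurnOn))  => m.updated(p, On)
        case (m, (p, TurnOff)) => m.updated(p, Off)
        case (m, (p, Toggle))  => m.updated(p, if (m.getOrElse(p, Off) == On) Off else On)
      }
}
```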

Simplify/DRY up a case statement in Scala for Twirl Templates

So I'm using Play Twirl templates (not within Play; in an independent project), and I have some templates that generate database DDLs. The following works:
if(config.params.showDDL.isSupplied) {
print( BigSenseServer.config.options("dbms") match {
case "mysql" => txt.mysql(
BigSenseServer.config.options("dbDatabase"),
InetAddress.getLocalHost().getCanonicalHostName,
BigSenseServer.config.options("dboUser"),
BigSenseServer.config.options("dboPass"),
BigSenseServer.config.options("dbUser"),
BigSenseServer.config.options("dbPass")
)
case "pgsql" => txt.pgsql(
BigSenseServer.config.options("dbDatabase"),
InetAddress.getLocalHost().getCanonicalHostName,
BigSenseServer.config.options("dboUser"),
BigSenseServer.config.options("dboPass"),
BigSenseServer.config.options("dbUser"),
BigSenseServer.config.options("dbPass")
)
case "mssql" => txt.mssql$.MODULE$(
BigSenseServer.config.options("dbDatabase"),
InetAddress.getLocalHost().getCanonicalHostName,
BigSenseServer.config.options("dboUser"),
BigSenseServer.config.options("dboPass"),
BigSenseServer.config.options("dbUser"),
BigSenseServer.config.options("dbPass")
)
})
System.exit(0)
}
But I have a lot of repeated statements. If I try to assign the result of the match to a variable and use the $.MODULE$ trick, I get an error saying my variable doesn't take parameters:
val b = BigSenseServer.config.options("dbms") match {
case "mysql" => txt.mysql$.MODULE$
case "pgsql" => txt.pgsql$.MODULE$
case "mssql" => txt.mssql$.MODULE$
}
b("string1","string2","string3","string4","string5","string6")
and the error:
BigSense/src/main/scala/io/bigsense/server/BigSenseServer.scala:32: play.twirl.api.BaseScalaTemplate[T,F] with play.twirl.api.Template6[A,B,C,D,E,F,Result] does not take parameters
What's the best way to simplify this Scala code?
EDIT: Final Solution using a combination of the answers below
The answers below suggest creating factory classes, but I really want to avoid that since I already have the Twirl-generated template objects. The partially applied functions gave me a better understanding of how to achieve this. It turns out all I needed to do was pick the apply methods and eta-expand them, combined with partial function application where necessary. The following works great:
if(config.params.showDDL.isSupplied) {
print((config.options("dbms") match {
case "pgsql" =>
txt.pgsql.apply _
case "mssql" =>
txt.mssql.apply _
case "mysql" =>
txt.mysql.apply(InetAddress.getLocalHost().getCanonicalHostName,
_:String, _:String, _:String,_:String, _:String)
})(
config.options("dbDatabase"),
config.options("dboUser"),
config.options("dboPass"),
config.options("dbUser"),
config.options("dbPass")
))
System.exit(0)
}
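The original error most likely arises because the match evaluates to the template objects themselves, and their inferred common type (the BaseScalaTemplate ... with Template6 ... in the error message) has no apply method of its own. A minimal sketch of the eta-expansion fix, using hypothetical stand-in objects mysql and pgsql simplified to three String parameters (not the real Twirl-generated ones):

```scala
object TemplateSketch {
  // Hypothetical stand-ins for Twirl-generated template objects, simplified to three parameters.
  object mysql {
    def apply(db: String, host: String, user: String): String = s"mysql:$db@$host/$user"
  }
  object pgsql {
    def apply(db: String, host: String, user: String): String = s"pgsql:$db@$host/$user"
  }

  def render(dbms: String, db: String, host: String, user: String): String = {
    // Matching on the objects themselves would infer a common supertype without an
    // apply method; eta-expanding each apply yields one shared function type instead.
    val template: (String, String, String) => String = dbms match {
      case "mysql" => mysql.apply _
      case "pgsql" => pgsql.apply _
    }
    // The common arguments are written only once.
    template(db, host, user)
  }
}
```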
You can try to use eta-expansion and partially applied functions.
Given a factory with some methods:
class Factory {
def mysql(i: Int, s: String) = s"x: $i/$s"
def pgsql(i: Int, s: String) = s"y: $i/$s"
def mssql(i: Int, j: Int, s: String) = s"z: $i/$j/$s"
}
You can abstract over the methods like this:
val factory = new Factory()
// Arguments required by all factory methods
val i = 5
val s = "Hello"
Seq("mysql", "pgsql", "mssql").foreach {
name =>
val f = name match {
case "mysql" =>
// Eta-expand: Convert method into function
factory.mysql _
case "pgsql" =>
factory.pgsql _
case "mssql" =>
// Argument for only one factory method
val j = 10
// Eta-expand, then apply function partially
factory.mssql(_ :Int, j, _: String)
}
// Fill in common arguments into the new function
val result = f(i, s)
println(name + " -> " + result)
}
As you can see in the "mssql" case, the arguments may even differ; yet the common arguments only need to be passed once. The foreach loop is just to test each case, the code in the body shows how to partially apply a function.
You can try to do this by using tupled to create a tupled version of the function.
object X {
def a(x : Int, y : Int, z : Int) = "A" + x + y + z
def b(x : Int, y : Int, z : Int) = "B" + x + y + z
def c(x : Int, y : Int, z : Int) = "C" + x + y + z
}
val selectedFunc = X.a _
selectedFunc.tupled((1, 2, 3)) //returns A123
More specifically, you would store your parameters in a tuple:
val params = (BigSenseServer.config.options("dbDatabase"),
InetAddress.getLocalHost().getCanonicalHostName) //etc.
and then in your match statement:
case "mysql" => (txt.mysql.apply _).tupled(params)
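Putting the two steps together, a minimal sketch in which the common arguments are collected into one tuple and the selected method is eta-expanded and tupled. The mysql/pgsql methods here are hypothetical stand-ins that share one parameter list:

```scala
object TupledSketch {
  // Hypothetical template methods that share one parameter list.
  def mysql(db: String, host: String, user: String): String = s"mysql $db $host $user"
  def pgsql(db: String, host: String, user: String): String = s"pgsql $db $host $user"

  def render(dbms: String, params: (String, String, String)): String = {
    // Eta-expand the selected method, then turn it into a function of one tuple.
    val f = dbms match {
      case "mysql" => (mysql _).tupled
      case "pgsql" => (pgsql _).tupled
    }
    f(params)
  }
}
```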

Beginner for loop in Scala: How do I declare a generic element?

I'm new to Scala and am having trouble with a simple generic for-loop declaration, where one instance of my class FinSet[T] is "unionized" with another instance of FinSet[T], other. Here is my current implementation of U (short for Union):
def U(other:FinSet[T]) = {
var otherList = other.toList
for(otherElem <- 0 until otherList.length){
this.+(otherElem)
}
this
}
When attempting to compile, I receive this error:
error: type mismatch:
found: : otherElem.type (with underlying type Int)
required : T
this.+(otherElem)
This is in class ListSet[T], which is an extension of the abstract class FinSet[T]. Both are shown here:
abstract class FinSet[T] protected () {
/* returns a list consisting of the set's elements */
def toList:List[T]
/* given a value x, it returns a new set consisting of x
and all the elements of this (set)
*/
def +(x:T):FinSet[T]
/* given a set other, it returns the union of this and other,
i.e., a new set consisting of all the elements of this and
all the elements of other
*/
def U(other:FinSet[T]):FinSet[T]
/* given a set other, it returns the intersection of this and other,
i.e., a new set consisting of all the elements that occur both
in this and in other
*/
def ^(other:FinSet[T]):FinSet[T]
/* given a set other, it returns the difference of this and other,
i.e., a new set consisting of all the elements of this that
do not occur in other
*/
def \(other:FinSet[T]):FinSet[T]
/* given a value x, it returns true if and only if x is an element of this
*/
def contains(x: T):Boolean
/* given a set other, it returns true if and only if this is included
in other, i.e., iff every element of this is an element of other
*/
def <=(other:FinSet[T]):Boolean =
false // replace this line with your implementation
override def toString = "{" ++ (toList mkString ", ") ++ "}"
// overrides the default definition of == (an alias of equals)
override def equals(other:Any):Boolean = other match {
// if other is an instance of FinSet[T] then ...
case o:FinSet[T] =>
// it is equal to this iff it includes and is included in this
(this <= o) && (o <= this)
case _ => false
}
}
And here, ListSet:
class ListSet[T] private (l: List[T]) extends FinSet[T] {
def this() = this(Nil)
// invariant: elems is a list with no repetitions
// storing all of the set's elements
private val elems = l
private def add(x:T, l:List[T]):List[T] = l match {
case Nil => x :: Nil
case y :: t => if (x == y) l else y :: add(x, t)
}
val toList =
elems
def +(x: T) =
this.toList.+(x)
def U(other:FinSet[T]) = {
var otherList = other.toList
for(otherElem <- 0 until otherList.length){
this.+(otherElem)
}
this
}
def ^(other:FinSet[T]) =
this
def \(other:FinSet[T]) =
this
def contains(x:T) =
false
}
Am I missing something obvious here?
In your for loop you are assigning Ints to otherElem (x until y produces a Range[Int], which effectively gives you an iteration over the Ints from x up to y), not members of otherList. What you want is something like:
def U(other:FinSet[T]) = {
for(otherElem <- other.toList){
this.+(otherElem)
}
this
}
EDIT:
Curiously, given your definitions of FinSet and ListSet (which I didn't see until after giving my initial answer), you ought to have some other issues with the above code (+ returns a List, not a FinSet, and you don't capture the result of + anywhere, so returning this at the end just gives back the original set, unless you are not using the standard Scala immutable List class. If not, which class are you using here?). If you are using the standard Scala immutable List class, then here is an alternative to consider:
def U(other:FinSet[T]) = new ListSet((this.toList ++ other.toList).distinct)
In general, it looks a bit like you are going to some trouble to produce mutable versions of the data structures you are interested in. I strongly encourage you to look into immutable data structures and how to work with them - they are much nicer and safer to work with once you understand the principles.
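A minimal sketch of the immutable approach, as a standalone simplification (it does not extend the assignment's FinSet, and only toList, contains, + and U are shown): + returns a new set instead of modifying this, and U threads the growing set through foldLeft so that no intermediate result is discarded.

```scala
// Standalone, simplified immutable set backed by a List (hypothetical; not the FinSet hierarchy).
final class ListSet[T] private (elems: List[T]) {
  def this() = this(Nil)

  def toList: List[T] = elems

  def contains(x: T): Boolean = elems.contains(x)

  // Returns a NEW set; `this` is never modified.
  def +(x: T): ListSet[T] =
    if (contains(x)) this else new ListSet(elems :+ x)

  // Union: each step's result becomes the accumulator for the next element.
  def U(other: ListSet[T]): ListSet[T] =
    other.toList.foldLeft[ListSet[T]](this)(_ + _)

  override def toString: String = elems.mkString("{", ", ", "}")
}
```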

declare variable in custom control structure in scala

I am wondering if there is a way to create a temporary variable in the parameter list of a custom control structure.
Essentially, I would like to create a control structure that looks something like the for loop, where I can create a variable, i, and have access to i in the loop body only:
for(i<- 1 to 100) {
//loop body can access i here
}
//i is not visible outside
I would like to do something similar in my code. For example,
customControl ( myVar <- "Task1") {
computation(myVar)
}
customControl ( myVar <- "Task2") {
computation(myVar)
}
def customControl (taskId:String) ( body: => Any) = {
Futures.future {
val result = body
result match {
case Some(x) =>
logger.info("Executed successfully")
x
case _ =>
logger.error(taskId + " failed")
None
}
}
}
Right now, I get around the problem by declaring a variable outside of the custom control structure, which doesn't look very elegant.
val myVar = "Task1"
customControl {
computation(myVar)
}
val myVar2 = "Task2"
customControl {
computation(myVar2 )
}
You could do something like this:
import scala.actors.Futures
def custom(t: String)(f: String => Any) = {
Futures.future {
val result = f(t)
result match {
case Some(x) =>
println("Executed successfully")
x
case _ =>
println(t + " failed")
None
}
}
}
And then you can get syntax like this, which isn't exactly what you asked for, but spares you declaring the variable on a separate line:
scala> custom("ss") { myvar => println("in custom " + myvar); myvar + "x" }
res7: scala.actors.Future[Any] = <function0>
in custom ss
ss failed
scala> custom("ss") { myvar => println("in custom " + myvar); Some(myvar + "x") }
in custom ss
Executed successfully
res8: scala.actors.Future[Any] = <function0>
scala>
Note that the built-in for (x <- expr) body is just syntactic sugar for
expr foreach (x => body)
Thus it might be possible to achieve what you want (using the existing for syntax) by defining a custom foreach method.
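Also note that strings already have a foreach method (it iterates over their characters), so for (x <- "Task1") would bind x to each Char rather than to the whole string. You could instead wrap the string, like this: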
case class T(t: String) {
def foreach(f: String => Unit): Unit = f(t)
}
This would enable you to write something like
for (x <- T("test"))
print(x)
(Note: you can also change the result type of f above from Unit to Any and it will still work.)
This is just a trivial (and useless) example, since now for (x <- T(y)) f(x) just abbreviates (or rather "enlongishes") f(y). But of course by changing the argument of f in the above definition of foreach from String to something else and doing a corresponding translation from the string t to this type, you could achieve more useful effects.
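Applied to the question's customControl, a sketch of that custom-foreach idea might look like the following. Task is a hypothetical wrapper and the ListBuffer stands in for the question's logger. The sketch runs the body synchronously; the original wrapped it in scala.actors.Futures, which has since been deprecated and removed from Scala, so a scala.concurrent.Future could play that role instead.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical wrapper: lets `for (myVar <- Task("Task1")) body` scope myVar to body.
final case class Task(id: String) {
  // The compiler rewrites `for (x <- Task(id)) body` into `Task(id).foreach(x => body)`.
  def foreach(body: String => Any): Unit =
    body(id) match {
      case Some(_) => Task.log += s"$id executed successfully"
      case _       => Task.log += s"$id failed"
    }
}

object Task {
  // Stand-in for the question's logger: records one message per task.
  val log: ListBuffer[String] = ListBuffer.empty[String]
}
```

Here myVar is visible only inside each loop body, which is the scoping the question asks for.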

Polish notation evaluate function

I am new to Scala and I am having a hard time defining (or, more precisely, translating from Ruby) code that evaluates calculations written in Polish notation,
e.g. (+ 3 2) or (- 4 (+ 3 2))
I successfully parse the string into the form ArrayBuffer(+, 3, 2) or ArrayBuffer(-, 4, ArrayBuffer(+, 3, 2)).
The problem starts when I try to define a recursive eval function, which simply takes an ArrayBuffer as its argument and returns an Int (the result of evaluating the application).
IN THE BASE CASE:
I want to simply check whether the 2nd element isInstanceOf[Int] and the 3rd element isInstanceOf[Int], then evaluate them together (depending on the operator sign, the 1st element) and return an Int.
However, if any of the elements is another ArrayBuffer, I simply want to reassign that element to the value returned by a recursive call to eval, like:
Storage(2) = eval(Storage(2)). (** that's why I am using a mutable ArrayBuffer **)
The error I get is:
scala.collection.mutable.ArrayBuffer cannot be cast to java.lang.Integer
I am of course not looking for any copy-and-paste answers, but for some advice and observations.
Constructive Criticism fully welcomed.
****** This is the testing code I am using only for the addition ******
def eval(Input: ArrayBuffer[Any]):Int = {
if(ArrayBuffer(2).isInstaceOf[ArrayBuffer[Any]]) {
ArrayBuffer(2) = eval(ArrayBuffer(2))
}
if(ArrayBuffer(3).isInstaceOf[ArrayBuffer[Any]]) {
ArrayBuffer(3) = eval(ArrayBuffer(3))
}
if(ArrayBuffer(2).isInstaceOf[Int] && ArrayBuffer(3).isInstanceOf[Int]) {
ArrayBuffer(2).asInstanceOf[Int] + ArrayBuffer(3).asInstanceOf[Int]
}
}
A few problems with your code:
ArrayBuffer(2) means "construct an ArrayBuffer with one element: 2". Nowhere in your code are you referencing your parameter Input. You would need to replace instances of ArrayBuffer(2) with Input(2) for this to work.
ArrayBuffer (and all collections in Scala) are 0-indexed, so if you want to access the second thing in the collection, you would do input(1).
If you leave the final if there, the compiler will complain, since your function won't always return an Int: if the input contained something unexpected, that last if would evaluate to false, and you have no else to fall back to.
Here's a direct rewrite of your code that fixes these issues:
def eval(input: ArrayBuffer[Any]):Int = {
if(input(1).isInstanceOf[ArrayBuffer[Any]])
input(1) = eval(input(1).asInstanceOf[ArrayBuffer[Any]])
if(input(2).isInstanceOf[ArrayBuffer[Any]])
input(2) = eval(input(2).asInstanceOf[ArrayBuffer[Any]])
input(1).asInstanceOf[Int] + input(2).asInstanceOf[Int]
}
(note also that variable names, like input, should be lowercased.)
That said, the procedure of replacing entries in your input with their evaluations is probably not the best route because it destroys the input in the process of evaluating. You should instead write a function that takes the ArrayBuffer and simply recurses through it without modifying the original.
You'll want your eval function to check for specific cases. Here's a simple implementation as a demonstration:
def eval(e: Seq[Any]): Int =
e match {
case Seq("+", a: Int, b: Int) => a + b
case Seq("+", a: Int, b: Seq[Any]) => a + eval(b)
case Seq("+", a: Seq[Any], b: Int) => eval(a) + b
case Seq("+", a: Seq[Any], b: Seq[Any]) => eval(a) + eval(b)
}
So you can see that for the simple case of (+ arg1 arg2), there are 4 cases. In each case, if the argument is an Int, we use it directly in the addition. If the argument itself is a sequence (like ArrayBuffer), then we recursively evaluate it before adding. Notice also that Scala's case syntax lets you do pattern matches on types, so you can skip the isInstanceOf and asInstanceOf stuff.
Now there are definitely style improvements you'd want to make down the line (like using Either instead of Any and not hard-coding the "+"), but this should get you on the right track.
And here's how you would use it:
scala> eval(Seq("+", 3, 2))
res0: Int = 5
scala> eval(Seq("+", 4, Seq("+", 3, 2)))
res1: Int = 9
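Following the suggestion above about not hard-coding the "+", here is a minimal sketch that looks operators up in a table and recurses without mutating its input. The operator table and the PolishEval name are illustrative. Matching on scala.collection.Seq rather than plain Seq is a deliberate choice: it lets both List and ArrayBuffer inputs match under Scala 2.13, where Seq means immutable.Seq.

```scala
object PolishEval {
  // Illustrative operator table: each symbol maps to a binary Int function.
  private val ops: Map[String, (Int, Int) => Int] = Map(
    "+" -> ((a: Int, b: Int) => a + b),
    "-" -> ((a: Int, b: Int) => a - b),
    "*" -> ((a: Int, b: Int) => a * b),
    "/" -> ((a: Int, b: Int) => a / b)
  )

  // Recursively evaluates a nested prefix expression; the input is never modified.
  def eval(e: Any): Int = e match {
    case n: Int => n
    case scala.collection.Seq(op: String, a, b) if ops.contains(op) =>
      ops(op)(eval(a), eval(b))
    case other =>
      throw new IllegalArgumentException(s"Cannot evaluate: $other")
  }
}
```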
Now, if you want to really take advantage of Scala features, you could use an Eval extractor:
object Eval {
def unapply(e: Any): Option[Int] = {
e match {
case i: Int => Some(i)
case Seq("+", Eval(a), Eval(b)) => Some(a + b)
}
}
}
And you'd use it like this:
scala> val Eval(result) = 2
result: Int = 2
scala> val Eval(result) = ArrayBuffer("+", 2, 3)
result: Int = 5
scala> val Eval(result) = ArrayBuffer("+", 2, ArrayBuffer("+", 2, 3))
result: Int = 7
Or you could wrap it in an eval function:
def eval(e: Any): Int = {
val Eval(result) = e
result
}
Here is my take on right to left stack-based evaluation:
def eval(expr: String): Either[Throwable, Int] = {
import java.lang.NumberFormatException
import scala.util.control.Exception._
def int(s: String) = catching(classOf[NumberFormatException]).opt(s.toInt)
val symbols = expr.replaceAll("""[^\d\+\-\*/ ]""", "").split(" ").toSeq
allCatch.either {
val results = symbols.foldRight(List.empty[Int]) {
(symbol, operands) => int(symbol) match {
case Some(op) => op :: operands
case None => val x :: y :: ops = operands
val result = symbol match {
case "+" => x + y
case "-" => x - y
case "*" => x * y
case "/" => x / y
}
result :: ops
}
}
results.head
}
}