Scala recursive macro? - scala

I was wondering whether Scala supports recursive macro expansion e.g. I am trying to write a lens library with a lensing macro that does this:
case class C(d: Int)
case class B(c: C)
case class A(b: B)
val a = A(B(C(10)))
val aa = lens(a)(_.b.c.d)(_ + 12)
assert(aa.b.c.d == 22)
Given lens(a)(_.b.c.d)(f), I want to transform it into a.copy(b = lens(a.b)(_.c.d)(f)).
EDIT:
I made some decent progress here
However, I cannot figure out a generic way to create an accessor out of a List[TermName]. For the example above, given List(TermName("b"), TermName("c"), TermName("d")), I want to generate the anonymous function _.b.c.d, i.e. (x: A) => x.b.c.d. How do I do that?
Basically, how can I write these lines in a generic fashion?

Actually I managed to make it work: https://github.com/pathikrit/sauron/blob/master/src/main/scala/com/github/pathikrit/sauron/package.scala
Here is the complete source:
package com.github.pathikrit
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
package object sauron {
def lens[A, B](obj: A)(path: A => B)(modifier: B => B): A = macro lensImpl[A, B]
def lensImpl[A, B](c: blackbox.Context)(obj: c.Expr[A])(path: c.Expr[A => B])(modifier: c.Expr[B => B]): c.Tree = {
import c.universe._
def split(accessor: c.Tree): List[c.TermName] = accessor match { // (_.p.q.r) -> List(p, q, r)
case q"$pq.$r" => split(pq) :+ r
case _: Ident => Nil
case _ => c.abort(c.enclosingPosition, s"Unsupported path element: $accessor")
}
def join(pathTerms: List[TermName]): c.Tree = (q"(x => x)" /: pathTerms) { // List(p, q, r) -> (_.p.q.r)
case (q"($arg) => $pq", r) => q"($arg) => $pq.$r"
}
path.tree match {
case q"($_) => $accessor" => split(accessor) match {
case p :: ps => q"$obj.copy($p = lens($obj.$p)(${join(ps)})($modifier))" // lens(a)(_.b.c)(f) = a.copy(b = lens(a.b)(_.c)(f))
case Nil => q"$modifier($obj)" // lens(x)(_)(f) = f(x)
}
case _ => c.abort(c.enclosingPosition, s"Path must have shape: _.a.b.c.(...), got: ${path.tree}")
}
}
}
And, yes, Scala does apply the same macro recursively.
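To make the recursion concrete, here is a hand-written sketch of what the repeated expansion should produce for the question's example, following the two rewrite rules in the comments of lensImpl. The object name LensExpansionSketch is only for illustration; no macro is involved, this is just the nested copy chain the expansion is expected to reduce to.

```scala
object LensExpansionSketch {
  case class C(d: Int)
  case class B(c: C)
  case class A(b: B)

  val a = A(B(C(10)))
  val f: Int => Int = _ + 12

  // Expected expansion steps of lens(a)(_.b.c.d)(f):
  //   lens(a)(_.b.c.d)(f)          ==> a.copy(b = lens(a.b)(_.c.d)(f))
  //   lens(a.b)(_.c.d)(f)          ==> a.b.copy(c = lens(a.b.c)(_.d)(f))
  //   lens(a.b.c)(_.d)(f)          ==> a.b.c.copy(d = lens(a.b.c.d)(x => x)(f))
  //   lens(a.b.c.d)(x => x)(f)     ==> f(a.b.c.d)
  // Substituting each step into the previous one gives:
  val aa: A = a.copy(b = a.b.copy(c = a.b.c.copy(d = f(a.b.c.d))))
}
```

Each level only rebuilds the case class on the path, so the other fields of A, B and C (if there were any) would be carried over unchanged by copy.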

Related

Get an scala.MatchError: f (of class scala.reflect.internal.Trees$Ident) when providing a lambda assigned to a val

I'm just discovering macros, for a use case in which I try to extract the lambda argument names from a function. To do so, I've defined this object (let's say in module A):
object MacroTest {
def getLambdaArgNames[A, B](f: A => B): String = macro getLambdaArgNamesImpl[A, B]
def getLambdaArgNamesImpl[A, B](c: Context)(f: c.Expr[A => B]): c.Expr[String] = {
import c.universe._
val Function(args, body) = f.tree
val names = args.map(_.name)
val argNames = names.mkString(", ")
val constant = Literal(Constant(argNames))
c.Expr[String](q"$constant")
}
}
Now, in another module, I'm trying to write a unit test to check the names of the arguments passed to a lambda:
class TestSomething extends AnyFreeSpec with Matchers {
"test" in {
val f = (e1: Expr[Int]) => e1 === 3
val argNames = MacroTest.getLambdaArgNames(f)
println(argNames)
assert(argNames === "e1")
}
}
But this code doesn't compile; it fails with:
scala.MatchError: f (of class scala.reflect.internal.Trees$Ident)
But if I pass the lambda directly to the function, as in MacroTest.getLambdaArgNames((e1: Expr[Int]) => e1 === 3), it works, so I'm lost as to why the first version does not compile.
Is there a way to fix that?
It's the same issue as in Scala macro inspect tree for anonymous function, where you commented: you really do need to pass the lambda itself to the macro, as you say
if I pass directly the lambda to the function like MacroTest.getLambdaArgNames((e1: Expr[Int]) => e1 === 3) it's working
When you write MacroTest.getLambdaArgNames(f), the AST argument (f: c.Expr[A => B] in getLambdaArgNamesImpl) just stores the identifier f and the lambda's AST isn't stored anywhere.
Alternatively, you can store the AST of the lambda in f, not the lambda itself:
val f = q"(e1: Expr[Int]) => e1 === 3"
in Scala 2,
val f = '{ (e1: Expr[Int]) => e1 === 3 }
in Scala 3, and then make getLambdaArgNames a normal function.
Try an approach with a Traverser:
def getLambdaArgNamesImpl[A, B](c: blackbox.Context)(f: c.Expr[A => B]): c.Expr[String] = {
import c.universe._
val arguments = f.tree match {
case Function(args, body) => args
case _ =>
var rhs: Option[Tree] = None
val traverser = new Traverser {
override def traverse(tree: Tree): Unit = {
tree match {
case q"$_ val f: $_ = $expr"
if tree.symbol == f.tree.symbol ||
(tree.symbol.isTerm && tree.symbol.asTerm.getter == f.tree.symbol) =>
rhs = Some(expr)
case _ => super.traverse(tree)
}
}
}
c.enclosingRun.units.foreach(unit => traverser.traverse(unit.body))
rhs match {
case Some(Function(args, body)) => args
case _ => c.abort(c.enclosingPosition, "can't find definition of val f")
}
}
val names = arguments.map(_.name)
val argNames = names.mkString(", ")
val constant = Literal(Constant(argNames))
c.Expr[String](q"$constant")
}
The case tree.symbol == f.tree.symbol matches when f is a local variable:
class TestSomething extends AnyFreeSpec with Matchers {
"test" in {
val f = (e1: Expr[Int]) => e1 === 3
...
The case tree.symbol.asTerm.getter == f.tree.symbol matches when f is a field in a class:
class TestSomething extends AnyFreeSpec with Matchers {
val f = (e1: Expr[Int]) => e1 === 3
"test" in {
...
Def Macro, pass parameter from a value
Creating a method definition tree from a method symbol and a body
Scala macro how to convert a MethodSymbol to DefDef with parameter default values?
How to get the runtime value of parameter passed to a Scala macro?

Functional patterns for better chaining of collect

I often find myself needing to chain collects where I want to do multiple collects in a single traversal. I also would like to return a "remainder" for things that don't match any of the collects.
For example:
sealed trait Animal
case class Cat(name: String) extends Animal
case class Dog(name: String, age: Int) extends Animal
val animals: List[Animal] =
List(Cat("Bob"), Dog("Spot", 3), Cat("Sally"), Dog("Jim", 11))
// Normal way
val cats: List[Cat] = animals.collect { case c: Cat => c }
val dogAges: List[Int] = animals.collect { case Dog(_, age) => age }
val rem: List[Animal] = Nil // No easy way to create this without repeated code
This really isn't great: it requires multiple iterations and there is no reasonable way to calculate the remainder. I could write a very complicated fold to pull this off, but it would be really nasty.
Instead, I usually opt for mutation which is fairly similar to the logic you would have in a fold:
import scala.collection.mutable.ListBuffer
// Ugly, hide the mutation away
val (cats2, dogsAges2, rem2) = {
// Lose some benefits of type inference
val cs = ListBuffer[Cat]()
val da = ListBuffer[Int]()
val rem = ListBuffer[Animal]()
// Bad separation of concerns, I have to merge all of my functions
animals.foreach {
case c: Cat => cs += c
case Dog(_, age) => da += age
case other => rem += other
}
(cs.toList, da.toList, rem.toList)
}
I don't like this one bit: it has worse type inference and worse separation of concerns, since I have to merge all of the various partial functions. It also requires a lot of code.
What I want are some useful patterns, like a collect that returns the remainder (I grant that partitionMap, new in 2.13, does this, but it is uglier). I could also use some form of pipe or map for operating on parts of tuples. Here are some made-up utilities:
implicit class ListSyntax[A](xs: List[A]) {
import scala.collection.mutable.ListBuffer
// Collect and return remainder
// A specialized form of new 2.13 partitionMap
def collectR[B](pf: PartialFunction[A, B]): (List[B], List[A]) = {
val rem = new ListBuffer[A]()
val res = new ListBuffer[B]()
val f = pf.lift
for (elt <- xs) {
f(elt) match {
case Some(r) => res += r
case None => rem += elt
}
}
(res.toList, rem.toList)
}
}
implicit class Tuple2Syntax[A, B](x: Tuple2[A, B]){
def chainR[C](f: B => C): Tuple2[A, C] = x.copy(_2 = f(x._2))
}
Now, I can write this in a way that could be done in a single traversal (with a lazy datastructure) and yet follows functional, immutable practice:
// Relatively pretty, can imagine lazy forms using a single iteration
val (cats3, (dogAges3, rem3)) =
animals.collectR { case c: Cat => c }
.chainR(_.collectR { case Dog(_, age) => age })
My question is, are there patterns like this? It smells like the type of thing that would be in a library like Cats, FS2, or ZIO, but I am not sure what it might be called.
Scastie link of code examples: https://scastie.scala-lang.org/Egz78fnGR6KyqlUTNTv9DQ
I wanted to see just how "nasty" a fold() would be.
val (cats
,dogAges
,rem) = animals.foldRight((List.empty[Cat]
,List.empty[Int]
,List.empty[Animal])) {
case (c:Cat, (cs,ds,rs)) => (c::cs, ds, rs)
case (Dog(_,d),(cs,ds,rs)) => (cs, d::ds, rs)
case (r, (cs,ds,rs)) => (cs, ds, r::rs)
}
Eye of the beholder I suppose.
How about defining a couple utility classes to help you with this?
case class ListCollect[A](list: List[A]) {
def partialCollect[B](f: PartialFunction[A, B]): ChainCollect[List[B], A] = {
val (cs, rem) = list.partition(f.isDefinedAt)
new ChainCollect((cs.map(f), rem))
}
}
case class ChainCollect[A, B](tuple: (A, List[B])) {
def partialCollect[C](f: PartialFunction[B, C]): ChainCollect[(A, List[C]), B] = {
val (cs, rem) = tuple._2.partition(f.isDefinedAt)
ChainCollect(((tuple._1, cs.map(f)), rem))
}
}
ListCollect is just meant to start the chain, and ChainCollect takes the previous remainder (the second element of the tuple) and tries to apply a PartialFunction to it, creating a new ChainCollect object. I'm not particularly fond of the nested tuples this produces, but you may be able to make it look a bit better if you use Shapeless's HLists.
val ((cats, dogs), rem) = ListCollect(animals)
.partialCollect { case c: Cat => c }
.partialCollect { case Dog(_, age) => age }
.tuple
Dotty's *: type makes this a bit easier:
opaque type ChainResult[Prev <: Tuple, Rem] = (Prev, List[Rem])
extension [P <: Tuple, R, N](chainRes: ChainResult[P, R]) {
def partialCollect(f: PartialFunction[R, N]): ChainResult[List[N] *: P, R] = {
val (cs, rem) = chainRes._2.partition(f.isDefinedAt)
(cs.map(f) *: chainRes._1, rem)
}
}
This does result in the output being reversed, but it doesn't have the ugly nesting of my previous approach:
val ((owls, dogs, cats), rem) = (EmptyTuple, animals)
.partialCollect { case c: Cat => c }
.partialCollect { case Dog(_, age) => age }
.partialCollect { case Owl(wisdom) => wisdom }
/* more animals */
case class Owl(wisdom: Double) extends Animal
case class Fly(isAnimal: Boolean) extends Animal
val animals: List[Animal] =
List(Cat("Bob"), Dog("Spot", 3), Cat("Sally"), Dog("Jim", 11), Owl(200), Fly(false))
And if you still don't like that, you can always define a few more helper methods to reverse the tuple, add the extension on a List without requiring an EmptyTuple to begin with, etc.
//Add this to the ChainResult extension
def end: Reverse[List[R] *: P] = {
def revHelp[A <: Tuple, R <: Tuple](acc: A, rest: R): RevHelp[A, R] =
rest match {
case EmptyTuple => acc.asInstanceOf[RevHelp[A, R]]
case h *: t => revHelp(h *: acc, t).asInstanceOf[RevHelp[A, R]]
}
revHelp(EmptyTuple, chainRes._2 *: chainRes._1)
}
//Helpful types for safety
type Reverse[T <: Tuple] = RevHelp[EmptyTuple, T]
type RevHelp[A <: Tuple, R <: Tuple] <: Tuple = R match {
case EmptyTuple => A
case h *: t => RevHelp[h *: A, t]
}
And now you can do this:
val (cats, dogs, owls, rem) = (EmptyTuple, animals)
.partialCollect { case c: Cat => c }
.partialCollect { case Dog(_, age) => age }
.partialCollect { case Owl(wisdom) => wisdom }
.end
Since you mentioned cats, I would also add solution using foldMap:
sealed trait Animal
case class Cat(name: String) extends Animal
case class Dog(name: String) extends Animal
case class Snake(name: String) extends Animal
val animals: List[Animal] = List(Cat("Bob"), Dog("Spot"), Cat("Sally"), Dog("Jim"), Snake("Billy"))
val map = animals.foldMap{ //Map(other -> List(Snake(Billy)), cats -> List(Cat(Bob), Cat(Sally)), dogs -> List(Dog(Spot), Dog(Jim)))
case d: Dog => Map("dogs" -> List(d))
case c: Cat => Map("cats" -> List(c))
case o => Map("other" -> List(o))
}
val tuples = animals.foldMap{ //(List(Dog(Spot), Dog(Jim)),List(Cat(Bob), Cat(Sally)),List(Snake(Billy)))
case d: Dog => (List(d), Nil, Nil)
case c: Cat => (Nil, List(c), Nil)
case o => (Nil, Nil, List(o))
}
Arguably it's more succinct than the fold version, but it has to combine partial results using monoids, so it won't be as performant.
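To show what "combining partial results using monoids" means here, the following is a minimal dependency-free sketch. The Monoid trait, the two instances and the foldMap function are hand-rolled stand-ins written for this illustration; they are not the Cats definitions, only an approximation of the same idea.

```scala
object FoldMapSketch {
  // Hand-rolled stand-in for a monoid type class (not the Cats one).
  trait Monoid[M] {
    def empty: M
    def combine(x: M, y: M): M
  }

  implicit def listMonoid[T]: Monoid[List[T]] = new Monoid[List[T]] {
    def empty: List[T] = Nil
    def combine(x: List[T], y: List[T]): List[T] = x ::: y
  }

  // A tuple of monoids is itself a monoid, combined component-wise.
  implicit def tuple3Monoid[A, B, C](
      implicit ma: Monoid[A], mb: Monoid[B], mc: Monoid[C]
  ): Monoid[(A, B, C)] = new Monoid[(A, B, C)] {
    def empty: (A, B, C) = (ma.empty, mb.empty, mc.empty)
    def combine(x: (A, B, C), y: (A, B, C)): (A, B, C) =
      (ma.combine(x._1, y._1), mb.combine(x._2, y._2), mc.combine(x._3, y._3))
  }

  // Map every element to a monoid value, then combine left to right.
  def foldMap[A, M](xs: List[A])(f: A => M)(implicit m: Monoid[M]): M =
    xs.foldLeft(m.empty)((acc, a) => m.combine(acc, f(a)))

  sealed trait Animal
  case class Cat(name: String) extends Animal
  case class Dog(name: String) extends Animal
  case class Snake(name: String) extends Animal

  val animals: List[Animal] =
    List(Cat("Bob"), Dog("Spot"), Cat("Sally"), Dog("Jim"), Snake("Billy"))

  val result: (List[Dog], List[Cat], List[Animal]) =
    foldMap[Animal, (List[Dog], List[Cat], List[Animal])](animals) {
      case d: Dog => (List(d), Nil, Nil)
      case c: Cat => (Nil, List(c), Nil)
      case o      => (Nil, Nil, List(o))
    }
}
```

In this sketch every element allocates a tuple of three lists, and every combine step appends to the accumulated lists with :::, which is where the extra cost compared with the hand-written fold comes from.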
This code is dividing a list into three sets, so the natural way to do this is to use partition twice:
val (cats, notCat) = animals.partitionMap{
case c: Cat => Left(c)
case x => Right(x)
}
val (dogAges, rem) = notCat.partitionMap {
case Dog(_, age) => Left(age)
case x => Right(x)
}
A helper method can simplify this
def partitionCollect[T, U](list: List[T])(pf: PartialFunction[T, U]): (List[U], List[T]) =
list.partitionMap {
case t if pf.isDefinedAt(t) => Left(pf(t))
case x => Right(x)
}
val (cats, notCat) = partitionCollect(animals) { case c: Cat => c }
val (dogAges, rem) = partitionCollect(notCat) { case Dog(_, age) => age }
This is clearly extensible to more categories, with the slight irritation of having to invent temporary variable names (which could be overcome by explicit n-way partition methods)
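As a sketch of what such an explicit n-way partition method could look like for n = 2, the helper below takes two partial functions and returns both result lists plus the remainder in a single traversal. The name partitionCollect2 and the extra Owl case class are made up for this example.

```scala
object PartitionSketch {
  sealed trait Animal
  case class Cat(name: String) extends Animal
  case class Dog(name: String, age: Int) extends Animal
  case class Owl(wisdom: Double) extends Animal

  // Hypothetical helper: try pfA first, then pfB, otherwise keep the element
  // in the remainder. foldRight keeps the original order in all three lists.
  def partitionCollect2[T, A, B](list: List[T])(
      pfA: PartialFunction[T, A],
      pfB: PartialFunction[T, B]
  ): (List[A], List[B], List[T]) =
    list.foldRight((List.empty[A], List.empty[B], List.empty[T])) {
      case (x, (accA, accB, rest)) =>
        pfA.lift(x) match {
          case Some(a) => (a :: accA, accB, rest)
          case None =>
            pfB.lift(x) match {
              case Some(b) => (accA, b :: accB, rest)
              case None    => (accA, accB, x :: rest)
            }
        }
    }

  val animals: List[Animal] =
    List(Cat("Bob"), Dog("Spot", 3), Owl(200), Cat("Sally"), Dog("Jim", 11))

  // No temporary "notCat" name is needed.
  val (cats, dogAges, rem) =
    partitionCollect2(animals)({ case c: Cat => c }, { case Dog(_, age) => age })
}
```

The trade-off is one such method per arity; the two-step partitionCollect version above stays simpler when only a couple of categories are involved.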

Can I generate Scala code from a template (of sorts)?

I know how to do this in Racket/Scheme/Lisp, but not in Scala. Is this something Scala macros can do?
I want to have a code template where X varies. If I had this code template:
def funcX(a: ArgsX): Try[Seq[RowX]] =
w.getThing() match {
case Some(t: Thing) => w.wrap(t){Detail.funcX(t, a)}
case _ => Failure(new MissingThingException)
}
and tokens Apple and Orange, a macro would take my template, replace the Xs, and produce:
def funcApple(a: ArgsApple): Try[Seq[RowApple]] =
w.getThing() match {
case Some(t: Thing) => w.wrap(t){Detail.funcApple(t, a)}
case _ => Failure(new MissingThingException)
}
def funcOrange(a: ArgsOrange): Try[Seq[RowOrange]] =
w.getThing() match {
case Some(t: Thing) => w.wrap(t){Detail.funcOrange(t, a)}
case _ => Failure(new MissingThingException)
}
Try a macro annotation with a tree transformer:
@compileTimeOnly("enable macro paradise")
class generate extends StaticAnnotation {
def macroTransform(annottees: Any*): Any = macro GenerateMacro.impl
}
object GenerateMacro {
def impl(c: whitebox.Context)(annottees: c.Tree*): c.Tree = {
import c.universe._
val trees = List("Apple", "Orange").map { s =>
val transformer = new Transformer {
override def transform(tree: Tree): Tree = tree match {
case q"$mods def $tname[..$tparams](...$paramss): $tpt = $expr" if tname.toString.contains("X") =>
val tname1 = TermName(tname.toString.replace("X", s))
val tparams1 = tparams.map(super.transform(_))
val paramss1 = paramss.map(_.map(super.transform(_)))
val tpt1 = super.transform(tpt)
val expr1 = super.transform(expr)
q"$mods def $tname1[..$tparams1](...$paramss1): $tpt1 = $expr1"
case q"${tname: TermName} " if tname.toString.contains("X") =>
val tname1 = TermName(tname.toString.replace("X", s))
q"$tname1"
case tq"${tpname: TypeName} " if tpname.toString.contains("X") =>
val tpname1 = TypeName(tpname.toString.replace("X", s))
tq"$tpname1"
case q"$expr.$tname " if tname.toString.contains("X") =>
val expr1 = super.transform(expr)
val tname1 = TermName(tname.toString.replace("X", s))
q"$expr1.$tname1"
case tq"$ref.$tpname " if tpname.toString.contains("X") =>
val ref1 = super.transform(ref)
val tpname1 = TypeName(tpname.toString.replace("X", s))
tq"$ref1.$tpname1"
case t => super.transform(t)
}
}
transformer.transform(annottees.head)
}
q"..$trees"
}
}
@generate
def funcX(a: ArgsX): Try[Seq[RowX]] =
w.getThing() match {
case Some(t: Thing) => w.wrap(t){Detail.funcX(t, a)}
case _ => Failure(new MissingThingException)
}
//Warning:scalac: {
// def funcApple(a: ArgsApple): Try[Seq[RowApple]] = w.getThing() match {
//    case Some((t @ (_: Thing))) => w.wrap(t)(Detail.funcApple(t, a))
// case _ => Failure(new MissingThingException())
// };
// def funcOrange(a: ArgsOrange): Try[Seq[RowOrange]] = w.getThing() match {
//    case Some((t @ (_: Thing))) => w.wrap(t)(Detail.funcOrange(t, a))
// case _ => Failure(new MissingThingException())
// };
// ()
//}
You can also try an approach with a type class:
def func[A <: Args](a: A)(implicit ar: ArgsRows[A]): Try[Seq[ar.R]] =
w.getThing() match {
case Some(t: Thing) => w.wrap(t){Detail.func(t, a)}
case _ => Failure(new MissingThingException)
}
trait ArgsRows[A <: Args] {
type R <: Row
}
object ArgsRows {
type Aux[A <: Args, R0 <: Row] = ArgsRows[A] { type R = R0 }
implicit val apple: Aux[ArgsApple, RowApple] = null
implicit val orange: Aux[ArgsOrange, RowOrange] = null
}
sealed trait Args
trait ArgsApple extends Args
trait ArgsOrange extends Args
trait Thing
sealed trait Row
trait RowApple extends Row
trait RowOrange extends Row
object Detail {
def func[A <: Args](t: Thing, a: A)(implicit ar: ArgsRows[A]): ar.R = ???
}
class MissingThingException extends Throwable
trait W {
def wrap[R <: Row](t: Thing)(r: R): Try[Seq[R]] = ???
def getThing(): Option[Thing] = ???
}
val w: W = ???
In my opinion, it looks like you could pass your funcX function as a higher-order function. You could also combine it with currying to make a "function factory":
def funcX[A](f: (Thing, A) => RowX)(a: A): Try[Seq[RowX]] =
w.getThing() match {
case Some(t: Thing) => w.wrap(t){f(t,a)}
case _ => Failure(new MissingThingException)
}
Then you could use it to create instances of funcApple or funcOrange:
val funcApple: ArgsApple => Try[Seq[RowX]] = funcX(Detail.funcApple)
val funcOrange: ArgsOrange => Try[Seq[RowX]] = funcX(Detail.funcOrange)
funcApple(argsApple)
funcOrange(argsOrange)
I assumed the signatures of Detail.funcApple and Detail.funcOrange are similar to (Thing, X) => RowX, but of course you could use different ones.
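A self-contained variant of this idea might look like the sketch below, which also makes the row type a type parameter so that each generated function keeps its own row type. Everything here is a hypothetical stand-in for the question's code: Thing, the Args and Row classes, Detail, and in particular the behaviour of w.wrap and w.getThing are assumptions made only so the example compiles and runs.

```scala
import scala.util.{Failure, Try}

object FuncFactorySketch {
  // Hypothetical stand-ins for the question's types.
  trait Thing
  case class ArgsApple(n: Int)
  case class ArgsOrange(s: String)
  case class RowApple(n: Int)
  case class RowOrange(s: String)
  class MissingThingException extends Exception("missing thing")

  object Detail {
    def funcApple(t: Thing, a: ArgsApple): RowApple = RowApple(a.n)
    def funcOrange(t: Thing, a: ArgsOrange): RowOrange = RowOrange(a.s)
  }

  object w {
    def getThing(): Option[Thing] = Some(new Thing {})
    // Assumed behaviour: evaluate the body and wrap the single row in a Seq.
    def wrap[R](t: Thing)(row: => R): Try[Seq[R]] = Try(Seq(row))
  }

  // The shared template, parameterised over argument type A and row type R.
  def funcFor[A, R](f: (Thing, A) => R)(a: A): Try[Seq[R]] =
    w.getThing() match {
      case Some(t) => w.wrap(t)(f(t, a))
      case None    => Failure(new MissingThingException)
    }

  def funcApple(a: ArgsApple): Try[Seq[RowApple]] =
    funcFor[ArgsApple, RowApple](Detail.funcApple)(a)

  def funcOrange(a: ArgsOrange): Try[Seq[RowOrange]] =
    funcFor[ArgsOrange, RowOrange](Detail.funcOrange)(a)
}
```

Each additional token then costs one short definition instead of a copy of the whole match expression.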
You may not actually need macros to achieve this; you can pattern match on a generic type like this:
import scala.util.{Success, Try}
def funcX[A](input :A) :Try[Seq[String]] = input match {
case x :String => Success(List(s"Input is a string: $input, call Detail.funcApple"))
case x :Int => Success(List(s"Input is an int, call Detail.funcOrange"))
}
scala> funcX("apple")
res3: scala.util.Try[Seq[String]] = Success(List(Input is a string: apple, call Detail.funcApple))
scala> funcX(11)
res4: scala.util.Try[Seq[String]] = Success(List(Input is an int, call Detail.funcOrange))

Can we have an array of by-name-parameter functions?

In Scala we have by-name parameters, which let us write
def foo[T](f: => T):T = {
f // invokes f
}
// use as:
foo(println("hello"))
I now want to do the same with an array of methods, that is I want to use them as:
def foo[T](f:Array[ => T]):T = { // does not work
f(0) // invokes f(0) // does not work
}
foo(println("hi"), println("hello")) // does not work
Is there any way to do what I want? The best I have come up with is:
def foo[T](f:() => T *):T = {
f(0)() // invokes f(0)
}
// use as:
foo(() => println("hi"), () => println("hello"))
or
def foo[T](f:Array[() => T]):T = {
f(0)() // invokes f(0)
}
// use as:
foo(Array(() => println("hi"), () => println("hello")))
EDIT: The proposed SIP-24 is not very useful as pointed out by Seth Tisue in a comment to this answer.
An example where this will be problematic is the following code of a utility function trycatch:
type unitToT[T] = ()=>T
def trycatch[T](list:unitToT[T] *):T = list.size match {
case i if i > 1 =>
try list.head()
catch { case t:Any => trycatch(list.tail: _*) }
case 1 => list(0)()
case _ => throw new Exception("call list must be non-empty")
}
Here trycatch takes a list of methods of type ()=>T and applies each element successively until it succeeds or the end is reached.
Now suppose I have two methods:
def getYahooRate(currencyA:String, currencyB:String):Double = ???
and
def getGoogleRate(currencyA:String, currencyB:String):Double = ???
that convert one unit of currencyA to currencyB and output Double.
I use trycatch as:
val usdEuroRate = trycatch(() => getYahooRate("USD", "EUR"),
() => getGoogleRate("USD", "EUR"))
I would have preferred:
val usdEuroRate = trycatch(getYahooRate("USD", "EUR"),
getGoogleRate("USD", "EUR")) // does not work
In the example above, I would like getGoogleRate("USD", "EUR") to be invoked only if getYahooRate("USD", "EUR") throws an exception. This is not the intended behavior of SIP-24.
Here is a solution, although with a few restrictions compared to direct call-by-name:
import scala.util.control.NonFatal
object Main extends App {
implicit class Attempt[+A](f: => A) {
def apply(): A = f
}
def tryCatch[T](attempts: Attempt[T]*): T = attempts.toList match {
case a :: b :: rest =>
try a()
catch {
case NonFatal(e) =>
tryCatch(b :: rest: _*)
}
case a :: Nil =>
a()
case Nil => throw new Exception("call list must be non-empty")
}
def a = println("Hi")
def b: Unit = sys.error("one")
def c = println("bye")
tryCatch(a, b, c)
def d: Int = sys.error("two")
def e = { println("here"); 45 }
def f = println("not here")
val result = tryCatch(d, e, f)
println("Result is " + result)
}
The restrictions are:
Using a block as an argument won't work; only the last expression of the block will be wrapped in an Attempt.
If the expression is of type Nothing (e.g., if b and d weren't annotated), the conversion to Attempt is not inserted since Nothing is a subtype of every type, including Attempt. Presumably the same would apply for an expression of type Null.
As of Scala 2.11.7, the answer is no. However, there is SIP-24, so in some future version your f: => T* version may be possible.
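For the specific fallback use case in the question, a different technique that avoids the problem altogether is scala.util.Try: both Try.apply and Try#orElse take their argument by name, so a later alternative is only evaluated when the earlier one fails. The sketch below uses hypothetical stubs for getYahooRate and getGoogleRate, plus a counter that exists only to make the laziness observable.

```scala
import scala.util.Try

object RateFallbackSketch {
  // Counts how often the fallback is actually evaluated (illustration only).
  var googleCalls = 0

  // Hypothetical stubs standing in for the real rate lookups.
  def getYahooRate(currencyA: String, currencyB: String): Double =
    throw new RuntimeException("Yahoo unavailable")

  def getGoogleRate(currencyA: String, currencyB: String): Double = {
    googleCalls += 1
    0.9
  }

  // The second Try is only evaluated because the first one fails.
  val usdEuroRate: Double =
    Try(getYahooRate("USD", "EUR"))
      .orElse(Try(getGoogleRate("USD", "EUR")))
      .get

  // When the first attempt succeeds, the fallback is never evaluated.
  private val callsBefore = googleCalls
  val first: Double = Try(1.0).orElse(Try(getGoogleRate("USD", "EUR"))).get
  val fallbackSkipped: Boolean = googleCalls == callsBefore
}
```

Unlike the trycatch helper, this does not give a single varargs call, but chains of orElse read in the same order and need no explicit () => wrappers.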

def macro inside case statement

I'd like to ask where def macros can be called and when they are expanded. I guess we can't just put an appropriately generated AST anywhere it fits?
For example, I want this:
(2,1) match {
case StandaloneMacros.permutations(1,2) => true ;
case (_,_) => false
}
become this after macro expansion
(2,1) match {
case (1,2) | (2,1) => true ;
case (_,_) => false
}
My macro permutations produces an Alternative of tuples. But when I run the first snippet, I get
macro method permutations is not a case class, nor does it have an unapply/unapplySeq member
I also tried defining a Permutations object with unapply macro method but got another error:
scala.reflect.internal.FatalError: unexpected tree: class scala.reflect.internal.Trees$Alternative
So: Is it possible to achieve at all?
I came up with a solution some time ago, and I thought I'd share it with you.
To achieve the above task I've used a Transformer and transformCaseDefs:
object Matcher {
def apply[A, B](expr: A)(patterns: PartialFunction[A, B]): B = macro apply_impl[A,B]
def apply_impl[A: c.WeakTypeTag, B: c.WeakTypeTag](c: Context)(expr: c.Expr[A])(patterns: c.Expr[PartialFunction[A, B]]): c.Expr[B] = {
import c.universe._
def allElemsAreLiterals(l: List[Tree]) = l forall {
case Literal(_) | Ident(_) => true
case _ => throw new Exception("this type of pattern is not supported")
}
val transformer = new Transformer {
override def transformCaseDefs(trees: List[CaseDef]) = trees.map {
case cas @ CaseDef(pat @ Apply(typTree, argList), guard, body) if allElemsAreLiterals(argList) =>
val permutations = argList.permutations.toList.map(t => q"(..$t)").map {
case Apply(_, args) => Apply(typTree, args)
}
val newPattern = Alternative(permutations)
CaseDef(newPattern, guard, body)
case x => x
}
}
c.Expr[B](q"${transformer.transform(patterns.tree)}($expr)")
}
}
It could be shorter, but you need to reuse the same TypeTree that was used in the original (pre-transformation) case statement.
This way, it can be used like this
val x = (1,2,3)
Matcher(x) {
case (2,3,1) => true
case _ => false
}
which is then translated to something like
val x = (1,2,3)
x match {
case (1,2,3) | (1,3,2) | (2,1,3) | (2,3,1) | (3,1,2) | (3,2,1) => true
case _ => false
}
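The set of alternatives comes from the standard permutations method on the list of pattern arguments. As a plain runtime illustration (no macro involved, and the names below are made up for the example), the same orderings can be computed and checked like this:

```scala
object PermutationSketch {
  val args = List(1, 2, 3)

  // One pattern alternative per ordering of the arguments.
  val alternatives: List[List[Int]] = args.permutations.toList

  // Runtime equivalent of the expanded match:
  // x matches if it equals any ordering of the original arguments.
  def matchesAnyOrder(x: (Int, Int, Int)): Boolean =
    alternatives.contains(List(x._1, x._2, x._3))
}
```

Note that permutations yields only distinct orderings, so a pattern with repeated literals such as (1, 1, 2) would produce fewer alternatives than the factorial of its length.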