I want to auto-generate REST API models in Scala using scalameta annotation macros. Specifically, given:
@Resource case class User(
@get id : Int,
@get @post @patch name : String,
@get @post email : String,
registeredOn : Long
)
I want to generate:
object User {
case class Get(id: Int, name: String, email: String)
case class Post(name: String, email: String)
case class Patch(name: Option[String])
}
trait UserRepo {
def getAll: Seq[User.Get]
def get(id: Int): User.Get
def create(request: User.Post): User.Get
def replace(id: Int, request: User.Put): User.Get
def update(id: Int, request: User.Patch): User.Get
def delete(id: Int): User.Get
}
I have something working here: https://github.com/pathikrit/metarest
Specifically I am doing this:
import scala.collection.immutable.Seq
import scala.collection.mutable
import scala.annotation.{StaticAnnotation, compileTimeOnly}
import scala.meta._
class get extends StaticAnnotation
class put extends StaticAnnotation
class post extends StaticAnnotation
class patch extends StaticAnnotation
@compileTimeOnly("@metarest.Resource not expanded")
class Resource extends StaticAnnotation {
inline def apply(defn: Any): Any = meta {
val (cls: Defn.Class, companion: Defn.Object) = defn match {
case Term.Block(Seq(cls: Defn.Class, companion: Defn.Object)) => (cls, companion)
case cls: Defn.Class => (cls, q"object ${Term.Name(cls.name.value)} {}")
case _ => abort("@metarest.Resource must annotate a class")
}
val paramsWithAnnotation = for {
Term.Param(mods, name, decltype, default) <- cls.ctor.paramss.flatten
seenMods = mutable.Set.empty[String]
modifier <- mods if seenMods.add(modifier.toString)
(tpe, defArg) <- modifier match {
case mod"@get" | mod"@put" | mod"@post" => Some(decltype -> default)
case mod"@patch" =>
val optDeclType = decltype.collect({case tpe: Type => targ"Option[$tpe]"})
val defaultArg = default match {
case Some(term) => q"Some($term)"
case None => q"None"
}
Some(optDeclType -> Some(defaultArg))
case _ => None
}
} yield modifier -> Term.Param(Nil, name, tpe, defArg)
val models = paramsWithAnnotation
.groupBy(_._1.toString)
.map({case (verb, pairs) =>
val className = Type.Name(verb.stripPrefix("@").capitalize)
val classParams = pairs.map(_._2)
q"case class $className[..${cls.tparams}] (..$classParams)"
})
val newCompanion = companion.copy(
templ = companion.templ.copy(stats = Some(
companion.templ.stats.getOrElse(Nil) ++ models
))
)
Term.Block(Seq(cls, newCompanion))
}
}
I am unhappy with the following snip of code:
modifier match {
case mod"@get" | mod"@put" | mod"@post" => ...
case mod"@patch" => ...
case _ => None
}
The above code does "stringly" pattern matching on the annotations I have. Is there any way to reuse the exact annotations I have defined and pattern match on these instead:
class get extends StaticAnnotation
class put extends StaticAnnotation
class post extends StaticAnnotation
class patch extends StaticAnnotation
It's possible to replace the stringly typed mod"@get" pattern with a get() extractor using a bit of runtime reflection (at compile time).
In addition, let's say we also want to allow users to fully qualify the annotation as @metarest.get or @_root_.metarest.get.
All the following code examples assume import scala.meta._. The tree structures of @get, @metarest.get and @_root_.metarest.get are
@ mod"@get".structure
res4: String = """ Mod.Annot(Ctor.Ref.Name("get"))
"""
@ mod"@metarest.get".structure
res5: String = """
Mod.Annot(Ctor.Ref.Select(Term.Name("metarest"), Ctor.Ref.Name("get")))
"""
@ mod"@_root_.metarest.get".structure
res6: String = """
Mod.Annot(Ctor.Ref.Select(Term.Select(Term.Name("_root_"), Term.Name("metarest")), Ctor.Ref.Name("get")))
"""
The selectors are either Ctor.Ref.Select or Term.Select and the names are either Term.Name or Ctor.Ref.Name.
Let's first create a custom selector extractor
object Select {
def unapply(tree: Tree): Option[(Term, Name)] = tree match {
case Term.Select(a, b) => Some(a -> b)
case Ctor.Ref.Select(a, b) => Some(a -> b)
case _ => None
}
}
Then create a few helper utilities
object ParamAnnotation {
import scala.reflect.ClassTag // needed for the ClassTag context bounds below
/* isSuffix(c, a.b.c) // true
* isSuffix(b.c, a.b.c) // true
* isSuffix(a.b.c, a.b.c) // true
* isSuffix(_root_.a.b.c, a.b.c) // true
* isSuffix(d.c, a.b.c) // false
*/
def isSuffix(maybeSuffix: Term, fullName: Term): Boolean =
(maybeSuffix, fullName) match {
case (a: Name, b: Name) => a.value == b.value
case (Select(q"_root_", a), b: Name) => a.value == b.value
case (a: Name, Select(_, b)) => a.value == b.value
case (Select(aRest, a), Select(bRest, b)) =>
a.value == b.value && isSuffix(aRest, bRest)
case _ => false
}
// Returns true if `mod` matches the tree structure of `@T`
def modMatchesType[T: ClassTag](mod: Mod): Boolean = mod match {
case Mod.Annot(term: Term.Ref) =>
isSuffix(term, termRefForType[T])
case _ => false
}
// Parses `T.getClass.getName` into a Term.Ref
// Uses runtime reflection, but this happens only at compile time.
def termRefForType[T](implicit ev: ClassTag[T]): Term.Ref =
ev.runtimeClass.getName.parse[Term].get.asInstanceOf[Term.Ref]
}
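The suffix rules in isSuffix can also be checked in isolation, without scalameta on the classpath. Below is a minimal sketch that models a qualified name as a list of segments instead of a tree. The name isSuffixOf and the list encoding are made up for this illustration and are not part of scalameta or metarest; the only assumption carried over is that a name starting with _root_ must match the full name exactly, as in the comment table above.

```scala
// Hypothetical stand-in for the tree-based isSuffix above:
// a qualified name such as a.b.c is modelled as List("a", "b", "c").
def isSuffixOf(maybeSuffix: List[String], fullName: List[String]): Boolean =
  maybeSuffix match {
    // _root_ anchors the name at the top level, so the rest must be the full name.
    case "_root_" :: rest => rest == fullName
    // An empty name never matches.
    case Nil => false
    // Otherwise any trailing run of segments is accepted.
    case segments => fullName.endsWith(segments)
  }

val full = List("a", "b", "c")
val shortForm = isSuffixOf(List("c"), full)
val rootedForm = isSuffixOf(List("_root_", "a", "b", "c"), full)
val wrongPackage = isSuffixOf(List("d", "c"), full)
```

Here shortForm and rootedForm are true and wrongPackage is false, mirroring the first, fourth and fifth rows of the comment table.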
With this setup, we can add a companion object to the get definition with an
unapply boolean extractor
class get extends StaticAnnotation
object get {
def unapply(mod: Mod): Boolean = ParamAnnotation.modMatchesType[get](mod)
}
Doing the same for post and put, we can now write
// before
case mod"@get" | mod"@put" | mod"@post" => Some(decltype -> default)
// after
case get() | put() | post() => Some(decltype -> default)
Note that this approach will still not work if the user renames the annotation on import, for example
import metarest.{get => GET}
I would recommend aborting if an annotation does not match what you expected
// before
case _ => None
// after
case unexpected => abort(s"Unexpected modifier $unexpected. Expected one of: put, get, post")
PS. The object get { def unapply(mod: Mod): Boolean = ... } part is boilerplate that could be generated by some @ParamAnnotation macro annotation, for example @ParamAnnotation class get extends StaticAnnotation
Related
I have the following 3 case classes:
case class Profile(name: String,
age: Int,
bankInfoData: BankInfoData,
userUpdatedFields: Option[UserUpdatedFields])
case class BankInfoData(accountNumber: Int,
bankAddress: String,
bankNumber: Int,
contactPerson: String,
phoneNumber: Int,
accountType: AccountType)
case class UserUpdatedFields(contactPerson: String,
phoneNumber: Int,
accountType: AccountType)
AccountType is just an enum, but I've included it anyway:
sealed trait AccountType extends EnumEntry
object AccountType extends Enum[AccountType] {
val values: IndexedSeq[AccountType] = findValues
case object Personal extends AccountType
case object Business extends AccountType
}
My task: I need to write a function that takes a Profile and compares UserUpdatedFields (all of its fields) with SOME of the fields in BankInfoData. The function should find which fields were updated.
So I wrote this function:
def findDiff(profile: Profile): Seq[String] = {
var listOfFieldsThatChanged: List[String] = List.empty
if (profile.bankInfoData.contactPerson != profile.userUpdatedFields.get.contactPerson){
listOfFieldsThatChanged = listOfFieldsThatChanged :+ "contactPerson"
}
if (profile.bankInfoData.phoneNumber != profile.userUpdatedFields.get.phoneNumber) {
listOfFieldsThatChanged = listOfFieldsThatChanged :+ "phoneNumber"
}
if (profile.bankInfoData.accountType != profile.userUpdatedFields.get.accountType) {
listOfFieldsThatChanged = listOfFieldsThatChanged :+ "accountType"
}
listOfFieldsThatChanged
}
val profile =
Profile(
"nir",
34,
BankInfoData(1, "somewhere", 2, "john", 123, AccountType.Personal),
Some(UserUpdatedFields("lee", 321, AccountType.Personal))
)
findDiff(profile)
It works, but I wanted something cleaner. Any suggestions?
Every case class extends the Product interface, so we can use it to convert case classes into sets of (field, value) elements. Then we can use set operations to find the difference. For example,
def findDiff(profile: Profile): Seq[String] = {
val userUpdatedFields = profile.userUpdatedFields.get
val bankInfoData = profile.bankInfoData
val updatedFieldsMap = userUpdatedFields.productElementNames.zip(userUpdatedFields.productIterator).toMap
val bankInfoDataMap = bankInfoData.productElementNames.zip(bankInfoData.productIterator).toMap
val bankInfoDataSubsetMap = bankInfoDataMap.view.filterKeys(updatedFieldsMap.keys.toList.contains)
(bankInfoDataSubsetMap.toSet diff updatedFieldsMap.toSet).toList.map { case (field, value) => field }
}
Now findDiff(profile) should output List(phoneNumber, contactPerson) (the order is not guaranteed, since the result comes from a set). Note that we are using productElementNames from Scala 2.13 to get the field names, which we then zip with the corresponding values
userUpdatedFields.productElementNames.zip(userUpdatedFields.productIterator)
Also we rely on filterKeys and diff.
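The same idea can be packaged as a helper that works for any two case classes. This is only a sketch, assuming Scala 2.13 or later (for productElementNames); fieldMap, changedFields, Stored and Update are hypothetical names used just for this example.

```scala
// Field name -> value for any case class instance (Scala 2.13+).
def fieldMap(p: Product): Map[String, Any] =
  p.productElementNames.zip(p.productIterator).toMap

// Names of the fields of `update` whose values differ from the
// same-named fields of `current`. Fields that `current` lacks are ignored.
def changedFields(current: Product, update: Product): Set[String] = {
  val currentValues = fieldMap(current)
  fieldMap(update).iterator.collect {
    case (name, value) if currentValues.get(name).exists(_ != value) => name
  }.toSet
}

// Hypothetical, simplified stand-ins for BankInfoData and UserUpdatedFields.
case class Stored(accountNumber: Int, contactPerson: String, phoneNumber: Int)
case class Update(contactPerson: String, phoneNumber: Int)

val changed = changedFields(Stored(1, "john", 123), Update("lee", 123))
```

Returning a Set makes it explicit that the order of the reported field names carries no meaning.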
A simple improvement would be to introduce a trait
trait Fields {
val contactPerson: String
val phoneNumber: Int
val accountType: AccountType
def findDiff(that: Fields): Seq[String] = Seq(
Some(contactPerson).filter(_ != that.contactPerson).map(_ => "contactPerson"),
Some(phoneNumber).filter(_ != that.phoneNumber).map(_ => "phoneNumber"),
Some(accountType).filter(_ != that.accountType).map(_ => "accountType")
).flatten
}
case class BankInfoData(accountNumber: Int,
bankAddress: String,
bankNumber: Int,
contactPerson: String,
phoneNumber: Int,
accountType: AccountType) extends Fields
case class UserUpdatedFields(contactPerson: String,
phoneNumber: Int,
accountType: AccountType) extends Fields
so that it is possible to call
BankInfoData(...).findDiff(UserUpdatedFields(...))
If you want to improve this further and avoid naming all the fields multiple times, shapeless could be used to do it at compile time, for example. Not exactly the same, but something like this to get started. Or use reflection to do it at runtime, like this answer.
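For the runtime route, a rough sketch with plain Java reflection could look like the following. It assumes that each constructor parameter of a case class is backed by a declared field of the same name, which holds for ordinary case classes but is an implementation detail rather than a documented guarantee. fieldsOf, changedByReflection, Account and AccountPatch are hypothetical names for this example only.

```scala
// Reads the declared fields of an instance via Java reflection.
// Synthetic fields (for example an outer-instance reference) are skipped.
def fieldsOf(x: AnyRef): Map[String, Any] =
  x.getClass.getDeclaredFields
    .filterNot(_.isSynthetic)
    .map { f =>
      f.setAccessible(true) // case class fields are private in the bytecode
      f.getName -> f.get(x)
    }
    .toMap

// Names of the fields of `update` whose values differ from the
// same-named fields of `current`.
def changedByReflection(current: AnyRef, update: AnyRef): Set[String] = {
  val currentValues = fieldsOf(current)
  fieldsOf(update).iterator.collect {
    case (name, value) if currentValues.get(name).exists(_ != value) => name
  }.toSet
}

// Hypothetical, simplified stand-ins for the case classes in the question.
case class Account(owner: String, phone: Int, branch: String)
case class AccountPatch(owner: String, phone: Int)

val reflectedDiff =
  changedByReflection(Account("john", 123, "north"), AccountPatch("john", 321))
```

Unlike the trait version, this does not need the field names to be spelled out, at the cost of losing compile-time checking.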
That would be a very easy task if there were an easy way to convert a case class to a map. Unfortunately, case classes don't offer that functionality out of the box in Scala 2.12 (as Mario has mentioned, it is easy to achieve in Scala 2.13).
There's a library called shapeless, that offers some generic programming utilities. For example, we could write an extension function toMap using Record and ToMap from shapeless:
object Mappable {
implicit class RichCaseClass[X](val x: X) extends AnyVal {
import shapeless._
import ops.record._
def toMap[L <: HList](
implicit gen: LabelledGeneric.Aux[X, L],
toMap: ToMap[L]
): Map[String, Any] =
toMap(gen.to(x)).map{
case (k: Symbol, v) => k.name -> v
}
}
}
Then we could use it for findDiff:
def findDiff(profile: Profile): Seq[String] = {
import Mappable._
profile match {
case Profile(_, _, bankInfo, Some(userUpdatedFields)) =>
val bankInfoMap = bankInfo.toMap
userUpdatedFields.toMap.toList.flatMap{
case (k, v) if bankInfoMap.get(k).exists(_ != v) => Some(k)
case _ => None
}
case _ => Seq()
}
}
I have a macro annotation that I use to inject implicit type class to a companion method.
@MyMacro case class MyClass[T](a: String, b: Int, t: T)
Most of the time it works as expected, but it breaks when I use a context bound:
@MyMacro case class MyClass[T: TypeClass](a: String, b: Int, t: T)
// private[this] not allowed for case class parameters
This error was described on SO and reported as a bug.
Thing is: macros (v1) are no longer maintained, so I cannot expect that this will be fixed.
So what I wanted to know is: can I fix this myself within a macro? Is this change done to AST in a way that I could somehow undo it? I would like to try repairing it within a macro instead of forcing all users to rewrite their code to ...(implicit tc: TypeClass[T]).
class AnnotationType() extends scala.annotation.StaticAnnotation {
def macroTransform(annottees: Any*): Any = macro AnnotationTypeImpl.impl
}
class AnnotationTypeImpl(val c: blackbox.Context) {
import c.universe._
def impl(annottees: Tree*): Tree = {
val tree = annottees.head.asInstanceOf[ClassDef]
val newTree = tree match {
case ClassDef(mods, name, tparams, impl @ Template(parents, self, body)) =>
val newBody = body.map {
case ValDef(mods, name, tpt, rhs) =>
// look here
// the flag of `private[this]` is Flag.PRIVATE | Flag.LOCAL
// the flag of `private` is Flag.PRIVATE
// drop Flag.LOCAL in Modifiers.flags , it will change `private[this]` to `private`
val newMods =
if(mods.hasFlag(Flag.IMPLICIT))
mods.asInstanceOf[scala.reflect.internal.Trees#Modifiers].&~(Flag.LOCAL.asInstanceOf[Long]).&~(Flag.CASEACCESSOR.asInstanceOf[Long]).asInstanceOf[Modifiers]
else
mods
ValDef(newMods, name, tpt, rhs)
case e => e
}
ClassDef(mods, name, tparams, Template(parents, self, newBody))
}
println(show(tree))
println(show(newTree))
q"..${newTree +: annottees.tail}"
}
}
// test
@AnnotationType()
case class AnnotationTypeTest[T: Option](a: T){
def option = implicitly[Option[T]]
}
object AnnotationTypeTest {
def main(args: Array[String]): Unit = {
implicit val x = Option(1)
println(AnnotationTypeTest(100))
println(AnnotationTypeTest(100).option)
println(AnnotationTypeTest(100).copy(a =2222))
println(AnnotationTypeTest(100).copy(a =2222)(Some(999)).option)
}
}
I'm using Swagger to annotate my API, and in our API we rely a lot on enumeratum. If I don't do anything, Swagger won't recognize it and just calls it object.
For example, I have this code that works:
sealed trait Mode extends EnumEntry
object Mode extends Enum[Mode] {
override def values = findValues
case object Initial extends Mode
case object Delta extends Mode
}
@ApiModel
case class Foobar(
@ApiModelProperty(dataType = "string", allowableValues = "Initial,Delta")
mode: Mode
)
However, I would like to avoid repeating the values as some of my types have many more than this example; I don't want to manually keep that in sync.
The problem is that @ApiModel wants a constant for reference, so I can't do something like reference = Mode.values.mkString(",").
I did try a macro with macro paradise, so that I could write:
@EnumeratumApiModel(Mode)
sealed trait Mode extends EnumEntry
object Mode extends Enum[Mode] {
override def values = findValues
case object Initial extends Mode
case object Delta extends Mode
}
...but it doesn't work because the macro pass can't access the Mode object.
What solution do I have to avoid repeating the values in the annotation?
This includes code so is too big for a comment.
I tried that; it wouldn't work because the @ApiModel annotation wants a String constant as a value (and not a reference to a constant).
This piece of code compiles just fine for me (notice how you should avoid explicitly specifying the type):
import io.swagger.annotations._
import enumeratum._
@ApiModel(reference = Mode.reference)
sealed trait Mode extends EnumEntry
object Mode extends Enum[Mode] {
final val reference = "enum(Initial,Delta)" // this works!
//final val reference: String = "enum(Initial,Delta)" // surprisingly this doesn't!
override def values = findValues
case object Initial extends Mode
case object Delta extends Mode
}
So it seems to be enough to have another macro that generates such a reference string, and I assume you already have one (or you can create one based on the code of EnumMacros.findValuesImpl).
Update
Here is some proof-of-concept code showing that this can actually work. First, start with the following macro annotations:
import scala.language.experimental.macros
import scala.annotation.{StaticAnnotation, compileTimeOnly}
import scala.reflect.macros.whitebox.Context
import scala.collection.immutable._
@compileTimeOnly("enable macro to expand macro annotations")
class SwaggerEnumContainer extends StaticAnnotation {
def macroTransform(annottees: Any*) = macro SwaggerEnumMacros.genListString
}
@compileTimeOnly("enable macro to expand macro annotations")
class SwaggerEnumValue(val readOnly: Boolean = false, val required: Boolean = false) extends StaticAnnotation {
def macroTransform(annottees: Any*) = macro SwaggerEnumMacros.genParamAnnotation
}
class SwaggerEnumMacros(val c: Context) {
import c.universe._
def genListString(annottees: c.Expr[Any]*): c.Expr[Any] = {
val result = annottees.map(_.tree).toList match {
case (xxx @ q"object $name extends ..$parents { ..$body }") :: Nil =>
val enclosingObject = xxx.asInstanceOf[ModuleDef]
val q"${tq"$pname[..$ptargs]"}(...$pargss)" = parents.head
val enumTraitIdent = ptargs.head.asInstanceOf[Ident]
val subclassSymbols: List[TermName] = enclosingObject.impl.body.foldLeft(List.empty[TermName])((list, innerTree) => {
innerTree match {
case innerObj: ModuleDefApi =>
val innerParentIdent = innerObj.impl.parents.head.asInstanceOf[Ident]
if (enumTraitIdent.name.equals(innerParentIdent.name))
innerObj.name :: list
else
list
case _ => list
}
})
val reference = subclassSymbols.map(n => n.encodedName.toString).mkString(",")
q"""
object $name extends ..$parents {
final val allowableValues = $reference
..$body
}
"""
}
c.Expr[Any](result)
}
def genParamAnnotation(annottees: c.Expr[Any]*): c.Expr[Any] = {
val annotationParams: AnnotationParams = extractAnnotationParameters(c.prefix.tree)
val baseSwaggerAnnot =
q""" new ApiModelProperty(
dataType = "string",
allowableValues = Mode.allowableValues
) """.asInstanceOf[Apply] // why I have to force cast?
val swaggerAnnot: c.universe.Apply = annotationParams.addArgsTo(baseSwaggerAnnot)
annottees.map(_.tree).toList match {
// field definition
case List(param: ValDef) => c.Expr[Any](decorateValDef(param, swaggerAnnot))
// field in a case class = constructor param
case (param: ValDef) :: (rest @ (_ :: _)) => decorateConstructorVal(param, rest, swaggerAnnot)
case _ => c.abort(c.enclosingPosition, "SwaggerEnumValue is expected to be used for value definitions")
}
}
def decorateValDef(valDef: ValDef, swaggerAnnot: Apply): ValDef = {
val q"$mods val $name: $tpt = $rhs" = valDef
val newMods: Modifiers = mods.mapAnnotations(al => swaggerAnnot :: al)
q"$newMods val $name: $tpt = $rhs"
}
def decorateConstructorVal(annottee: c.universe.ValDef, expandees: List[Tree], swaggerAnnot: Apply): c.Expr[Any] = {
val q"$_ val $tgtName: $_ = $_" = annottee
val outputs = expandees.map {
case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => {
// paramss is a 2d array so map inside map
val newParams: List[List[ValDef]] = paramss.map(_.map({
case valDef: ValDef if valDef.name == tgtName => decorateValDef(valDef, swaggerAnnot)
case otherParam => otherParam
}))
q"$mods class $tpname[..$tparams] $ctorMods(...$newParams) extends { ..$earlydefns } with ..$parents { $self => ..$stats }"
}
case otherTree => otherTree
}
c.Expr[Any](Block(outputs, Literal(Constant(()))))
}
case class AnnotationParams(readOnly: Boolean, required: Boolean) {
def customCopy(name: String, value: Any) = {
name match {
case "readOnly" => copy(readOnly = value.asInstanceOf[Boolean])
case "required" => copy(required = value.asInstanceOf[Boolean])
case _ => c.abort(c.enclosingPosition, s"Unknown parameter '$name'")
}
}
def addArgsTo(annot: Apply): Apply = {
val additionalArgs: List[AssignOrNamedArg] = List(
AssignOrNamedArg(q"readOnly", q"$readOnly"),
AssignOrNamedArg(q"required", q"$required")
)
Apply(annot.fun, annot.args ++ additionalArgs)
}
}
private def extractAnnotationParameters(tree: Tree): AnnotationParams = tree match {
case ap: Apply =>
val argNames = Array("readOnly", "required")
val defaults = AnnotationParams(readOnly = false, required = false)
ap.args.zipWithIndex.foldLeft(defaults)((acc, argAndIndex) => argAndIndex match {
case (lit: Literal, index: Int) => acc.customCopy(argNames(index), c.eval(c.Expr[Any](lit)))
case (namedArg: AssignOrNamedArg, _: Int) =>
val q"$name = $lit" = namedArg
acc.customCopy(name.asInstanceOf[Ident].name.toString, c.eval(c.Expr[Any](lit)))
case _ => c.abort(c.enclosingPosition, "Failed to parse annotation params: " + argAndIndex)
})
}
}
And then you can do this:
sealed trait Mode extends EnumEntry
@SwaggerEnumContainer
object Mode extends Enum[Mode] {
override def values = findValues
case object Initial extends Mode
case object Delta extends Mode
}
@ApiModel
case class Foobar(@ApiModelProperty(dataType = "string", allowableValues = Mode.allowableValues) mode: Mode)
Or you can do this which I think is a bit cleaner
@ApiModel
case class Foobar2(
@SwaggerEnumValue mode: Mode,
@SwaggerEnumValue(true) mode2: Mode,
@SwaggerEnumValue(required = true) mode3: Mode,
i: Int, s: String = "abc") {
@SwaggerEnumValue
val modeField: Mode = Mode.Delta
}
Note that this is still only a POC. Known deficiencies include:
@SwaggerEnumContainer can't handle the case where a placeholder allowableValues is already defined with some fake value (which might be nicer for the IDE)
@SwaggerEnumValue only supports two attributes out of those available in the original @ApiModelProperty
I have a case class with annotated fields, like this:
case class Foo(@alias("foo") bar: Int)
I have a macro that processes the declaration of this class:
val (className, access, fields, bases, body) = classDecl match {
case q"case class $n $m(..$ps) extends ..$bs { ..$ss }" => (n, m, ps, bs, ss)
case _ => abort
}
Later, I search for the aliased fields, as follows:
val aliases = fields.asInstanceOf[List[ValDef]].flatMap {
field => field.symbol.annotations.collect {
//deprecated version:
//case annotation if annotation.tpe <:< cv.weakTypeOf[alias] =>
case annotation if annotation.tree.tpe <:< c.weakTypeOf[alias] =>
//deprecated version:
//annotation.scalaArgs.head match {
annotation.tree.children.tail.head match {
case Literal(Constant(param: String)) => (param, field.name)
}
}
}
However, the list of aliases ends up being empty. I have determined that field.symbol.annotations.size is, in fact, 0, despite the annotation clearly sitting on the field.
Any idea of what's wrong?
EDIT
Answering the first two comments:
(1) I tried mods.annotations, but that didn't work. That actually returns List[Tree] instead of List[Annotation], returned by symbol.annotations. Perhaps I didn't modify the code correctly, but the immediate effect was an exception during macro expansion. I'll try to play with it some more.
(2) The class declaration is grabbed while processing an annotation macro slapped on the case class.
The complete code follows. The usage is illustrated in the test code further below.
package com.xxx.util.macros
import scala.collection.immutable.HashMap
import scala.language.experimental.macros
import scala.annotation.StaticAnnotation
import scala.reflect.macros.whitebox
trait Mapped {
def $(key: String) = _vals.get(key)
protected def +=(key: String, value: Any) =
_vals += ((key, value))
private var _vals = new HashMap[String, Any]
}
class alias(val key: String) extends StaticAnnotation
class aliased extends StaticAnnotation {
def macroTransform(annottees: Any*): Any = macro aliasedMacro.impl
}
object aliasedMacro {
def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._
val (classDecl, compDecl) = annottees.map(_.tree) match {
case (clazz: ClassDef) :: Nil => (clazz, None)
case (clazz: ClassDef) :: (comp: ModuleDef) :: Nil => (clazz, Some(comp))
case _ => abort(c, "@aliased must annotate a class")
}
val (className, access, fields, bases, body) = classDecl match {
case q"case class $n $m(..$ps) extends ..$bs { ..$ss }" => (n, m, ps, bs, ss)
case _ => abort(c, "@aliased is only supported on case class")
}
val mappings = fields.asInstanceOf[List[ValDef]].flatMap {
field => field.symbol.annotations.collect {
case annotation if annotation.tree.tpe <:< c.weakTypeOf[alias] =>
annotation.tree.children.tail.head match {
case Literal(Constant(param: String)) =>
q"""this += ($param, ${field.name})"""
}
}
}
val classCode = q"""
case class $className $access(..$fields) extends ..$bases {
..$body; ..$mappings
}"""
c.Expr(compDecl match {
case Some(compCode) => q"""$compCode; $classCode"""
case None => q"""$classCode"""
})
}
protected def abort(c: whitebox.Context, message: String) =
c.abort(c.enclosingPosition, message)
}
The test code:
package test.xxx.util.macros
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
import com.xxx.util.macros._
@aliased
case class Foo(@alias("foo") foo: Int,
@alias("BAR") bar: String,
baz: String) extends Mapped
@RunWith(classOf[JUnitRunner])
class MappedTest extends FunSuite {
val foo = 13
val bar = "test"
val obj = Foo(foo, bar, "extra")
test("field aliased with its own name") {
assertResult(Some(foo))(obj $ "foo")
}
test("field aliased with another string") {
assertResult(Some(bar))(obj $ "BAR")
assertResult(None)(obj $ "bar")
}
test("unaliased field") {
assertResult(None)(obj $ "baz")
}
}
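Thanks for the suggestions! In the end, using field.mods.annotations did help. The trees passed to a macro annotation have not been typechecked yet, so field.symbol has no annotations to report, while the annotations are still present syntactically in the modifiers. This is how: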
val mappings = fields.asInstanceOf[List[ValDef]].flatMap {
field => field.mods.annotations.collect {
case Apply(Select(New(Ident(TypeName("alias"))), termNames.CONSTRUCTOR),
List(Literal(Constant(param: String)))) =>
q"""this += ($param, ${field.name})"""
}
}
Is it possible to perform a pattern match whose result conforms to a type parameter of the outer method? E.g. given:
trait Key[A] {
def id: Int
def unapply(k: Key[_]): Boolean = k.id == id // used for Fail2
def apply(thunk: => A): A = thunk // used for Fail3
}
trait Ev[A] {
def pull[A1 <: A](key: Key[A1]): Option[A1]
}
trait Test extends Ev[AnyRef] {
val key1 = new Key[String] { def id = 1 }
val key2 = new Key[Symbol] { def id = 2 }
}
Is there an implementation of Test (its pull method) which uses a pattern match on the key argument and returns Option[A1] for each key checked, without the use of asInstanceOf?
Some pathetic tries:
class Fails1 extends Test {
def pull[A1 <: AnyRef](key: Key[A1]): Option[A1] = key match {
case `key1` => Some("hallo")
case `key2` => Some('welt)
}
}
class Fails2 extends Test {
def pull[A1 <: AnyRef](key: Key[A1]): Option[A1] = key match {
case key1() => Some("hallo")
case key2() => Some('welt)
}
}
class Fails3 extends Test {
def pull[A1 <: AnyRef](key: Key[A1]): Option[A1] = key match {
case k @ key1() => Some(k("hallo"))
case k @ key2() => Some(k('welt))
}
}
None of them works, obviously... The only solution is to cast:
class Ugly extends Test {
def pull[A1 <: AnyRef](key: Key[A1]): Option[A1] = key match {
case `key1` => Some("hallo".asInstanceOf[A1])
case `key2` => Some('welt .asInstanceOf[A1])
}
}
val u = new Ugly
u.pull(u.key1)
u.pull(u.key2)
The problem is indeed that pattern matching ignores all erased types. However, there is a little implicit trickery that one could employ. The following will preserve the type resolution provided by the match for the return type.
abstract class UnErased[A]
implicit case object UnErasedString extends UnErased[String]
implicit case object UnErasedSymbol extends UnErased[Symbol]
class UnErasedTest extends Test {
def pull[ A1 <: AnyRef ]( key: Key[ A1 ])(implicit unErased: UnErased[A1]): Option[ A1 ] = unErased match {
case UnErasedString if key1.id == key.id => Some( "hallo" )
case UnErasedSymbol if key2.id == key.id => Some( 'welt )
case _ => None
}
}
val u = new UnErasedTest
println( u.pull( u.key1 ) )
println( u.pull( u.key2 ) )
This is, however, nearly equivalent to just defining separate subclasses of Key. I find the following method preferable; however, it may not work if existing code uses Key[String] that you can't change to the necessary KeyString (or if it is too much work to change).
trait KeyString extends Key[String]
trait KeySymbol extends Key[Symbol]
trait Test extends Ev[ AnyRef ] {
val key1 = new KeyString { def id = 1 }
val key2 = new KeySymbol { def id = 2 }
}
class SubTest extends Test {
def pull[ A1 <: AnyRef ]( key: Key[ A1 ]): Option[ A1 ] = key match {
case k: KeyString if key1.id == k.id => Some( "hallo" )
case k: KeySymbol if key2.id == k.id => Some( 'welt )
case _ => None
}
}
val s = new SubTest
println( s.pull( s.key1 ) )
println( s.pull( s.key2 ) )
I provide here an extended example (that shows more of my context) based on the closed types approach of Neil Essy's answer:
trait KeyLike { def id: Int }
trait DispatchCompanion {
private var cnt = 0
sealed trait Value
sealed trait Key[V <: Value] extends KeyLike {
val id = cnt // automatic incremental ids
cnt += 1
}
}
trait Event[V] {
def apply(): Option[V] // simple imperative invocation for testing
}
class EventImpl[D <: DispatchCompanion, V <: D#Value](
disp: Dispatch[D], key: D#Key[V]) extends Event[V] {
def apply(): Option[V] = disp.pull(key)
}
trait Dispatch[D <: DispatchCompanion] {
// factory method for events
protected def event[V <: D#Value](key: D#Key[V]): Event[V] =
new EventImpl[D, V](this, key)
def pull[V <: D#Value](key: D#Key[V]): Option[V]
}
Then the following scenario compiles with not too much clutter:
object Test extends DispatchCompanion {
case class Renamed(before: String, now: String) extends Value
case class Moved (before: Int , now: Int ) extends Value
private case object renamedKey extends Key[Renamed]
private case object movedKey extends Key[Moved ]
}
class Test extends Dispatch[Test.type] {
import Test._
val renamed = event(renamedKey)
val moved = event(movedKey )
// some dummy propagation for testing
protected def pullRenamed: (String, String) = ("doesn't", "matter")
protected def pullMoved : (Int , Int ) = (3, 4)
def pull[V <: Value](key: Key[V]): Option[V] = key match {
case _: renamedKey.type => val p = pullRenamed; Some(Renamed(p._1, p._2))
case _: movedKey.type => val p = pullMoved; Some(Moved( p._1, p._2))
}
}
...and yields the desired results:
val t = new Test
t.renamed()
t.moved()
Now the only thing I don't get and I find ugly is that my cases must be of the form
case _: keyCaseObject.type =>
and cannot be
case keyCaseObject =>
which I would very much prefer. Any ideas where this limitation comes from?