Implementing Enumeratum support for Swagger - scala

I'm using Swagger to annotate my API, and our API relies heavily on enumeratum. If I don't do anything, Swagger doesn't recognize these types and just reports them as object.
For example, I have this code that works:
sealed trait Mode extends EnumEntry
object Mode extends Enum[Mode] {
override def values = findValues
case object Initial extends Mode
case object Delta extends Mode
}
@ApiModel
case class Foobar(
  @ApiModelProperty(dataType = "string", allowableValues = "Initial,Delta")
  mode: Mode
)
However, I would like to avoid repeating the values, as some of my types have many more of them than this example, and I don't want to keep the annotation in sync manually.
The problem is that @ApiModel wants a compile-time constant for reference, so I can't do something like reference = Mode.values.mkString(",").
I did try a macro annotation with macro paradise, so that I could write:
@EnumeratumApiModel(Mode)
sealed trait Mode extends EnumEntry
object Mode extends Enum[Mode] {
override def values = findValues
case object Initial extends Mode
case object Delta extends Mode
}
...but it doesn't work because the macro pass can't access the Mode object.
What solution do I have to avoid repeating the values in the annotation?

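(This includes code, so it is too big for a comment.)
You wrote: "I tried, that wouldn't work because the @ApiModel annotation wants a String constant as a value (and not a reference to a constant)."
This piece of code compiles just fine for me. Notice that you must not specify the type explicitly: only a final val with no explicit type is a constant value definition, which the compiler can inline wherever a constant is required.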
import io.swagger.annotations._
import enumeratum._
@ApiModel(reference = Mode.reference)
sealed trait Mode extends EnumEntry
object Mode extends Enum[Mode] {
final val reference = "enum(Initial,Delta)" // this works!
//final val reference: String = "enum(Initial,Delta)" // surprisingly this doesn't!
override def values = findValues
case object Initial extends Mode
case object Delta extends Mode
}
So it should be enough to have another macro that generates such a reference string. I assume you already have one (or you can create one based on the code of EnumMacros.findValuesImpl).
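The constant-value point can be checked without Swagger on the classpath. The sketch below rests on stand-ins: a hand-written sealed trait replaces enumeratum, and the JDK annotation java.lang.SuppressWarnings plays the role of a Swagger annotation, because scalac likewise accepts only constant arguments for any Java annotation. The object ModeSync and its inSync method are hypothetical, added to show a runtime guard that the constant still matches the values.

```scala
// Stand-in for the enumeratum-based Mode (no enumeratum dependency here).
sealed trait Mode
object Mode {
  case object Initial extends Mode
  case object Delta extends Mode

  // Hypothetical stand-in for enumeratum's `findValues`, listed by hand.
  val values: List[Mode] = List(Initial, Delta)

  // Constant value definition: `final val` with no explicit type.
  // Writing `final val AllowableValues: String = ...` would make it an
  // ordinary val, and the annotation below would no longer compile.
  final val AllowableValues = "Initial,Delta"
}

// java.lang.SuppressWarnings is a Java annotation, so scalac only accepts
// compile-time constants as its arguments (as with Swagger's annotations).
@SuppressWarnings(Array(Mode.AllowableValues))
final case class Foobar(mode: Mode)

object ModeSync {
  // Runtime guard (e.g. for a unit test) that the constant matches `values`;
  // a case object's toString is its simple name.
  def inSync: Boolean =
    Mode.values.map(_.toString).mkString(",") == Mode.AllowableValues
}
```

A guard like this does not remove the duplication, but it turns a silent mismatch into a failing test.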
Update
Here is some proof-of-concept code showing that this can actually work. First, start with the following macro annotations:
import scala.language.experimental.macros
import scala.annotation.{StaticAnnotation, compileTimeOnly}
import scala.reflect.macros.whitebox.Context
import scala.collection.immutable._
@compileTimeOnly("enable macro to expand macro annotations")
class SwaggerEnumContainer extends StaticAnnotation {
def macroTransform(annottees: Any*) = macro SwaggerEnumMacros.genListString
}
@compileTimeOnly("enable macro to expand macro annotations")
class SwaggerEnumValue(val readOnly: Boolean = false, val required: Boolean = false) extends StaticAnnotation {
def macroTransform(annottees: Any*) = macro SwaggerEnumMacros.genParamAnnotation
}
class SwaggerEnumMacros(val c: Context) {
import c.universe._
def genListString(annottees: c.Expr[Any]*): c.Expr[Any] = {
val result = annottees.map(_.tree).toList match {
case (xxx@q"object $name extends ..$parents { ..$body }") :: Nil =>
val enclosingObject = xxx.asInstanceOf[ModuleDef]
val q"${tq"$pname[..$ptargs]"}(...$pargss)" = parents.head
val enumTraitIdent = ptargs.head.asInstanceOf[Ident]
val subclassSymbols: List[TermName] = enclosingObject.impl.body.foldLeft(List.empty[TermName])((list, innerTree) => {
innerTree match {
case innerObj: ModuleDefApi =>
val innerParentIdent = innerObj.impl.parents.head.asInstanceOf[Ident]
if (enumTraitIdent.name.equals(innerParentIdent.name))
innerObj.name :: list
else
list
case _ => list
}
})
val reference = subclassSymbols.map(n => n.encodedName.toString).mkString(",")
q"""
object $name extends ..$parents {
final val allowableValues = $reference
..$body
}
"""
}
c.Expr[Any](result)
}
def genParamAnnotation(annottees: c.Expr[Any]*): c.Expr[Any] = {
val annotationParams: AnnotationParams = extractAnnotationParameters(c.prefix.tree)
val baseSwaggerAnnot =
q""" new ApiModelProperty(
dataType = "string",
allowableValues = Mode.allowableValues
) """.asInstanceOf[Apply] // why I have to force cast?
val swaggerAnnot: c.universe.Apply = annotationParams.addArgsTo(baseSwaggerAnnot)
annottees.map(_.tree).toList match {
// field definition
case List(param: ValDef) => c.Expr[Any](decorateValDef(param, swaggerAnnot))
// field in a case class = constructor param
case (param: ValDef) :: (rest@(_ :: _)) => decorateConstructorVal(param, rest, swaggerAnnot)
case _ => c.abort(c.enclosingPosition, "SwaggerEnumValue is expected to be used for value definitions")
}
}
def decorateValDef(valDef: ValDef, swaggerAnnot: Apply): ValDef = {
val q"$mods val $name: $tpt = $rhs" = valDef
val newMods: Modifiers = mods.mapAnnotations(al => swaggerAnnot :: al)
q"$newMods val $name: $tpt = $rhs"
}
def decorateConstructorVal(annottee: c.universe.ValDef, expandees: List[Tree], swaggerAnnot: Apply): c.Expr[Any] = {
val q"$_ val $tgtName: $_ = $_" = annottee
val outputs = expandees.map {
case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" => {
// paramss is a list of lists (one per parameter list), so map inside map
val newParams: List[List[ValDef]] = paramss.map(_.map({
case valDef: ValDef if valDef.name == tgtName => decorateValDef(valDef, swaggerAnnot)
case otherParam => otherParam
}))
q"$mods class $tpname[..$tparams] $ctorMods(...$newParams) extends { ..$earlydefns } with ..$parents { $self => ..$stats }"
}
case otherTree => otherTree
}
c.Expr[Any](Block(outputs, Literal(Constant(()))))
}
case class AnnotationParams(readOnly: Boolean, required: Boolean) {
def customCopy(name: String, value: Any) = {
name match {
case "readOnly" => copy(readOnly = value.asInstanceOf[Boolean])
case "required" => copy(required = value.asInstanceOf[Boolean])
case _ => c.abort(c.enclosingPosition, s"Unknown parameter '$name'")
}
}
def addArgsTo(annot: Apply): Apply = {
val additionalArgs: List[AssignOrNamedArg] = List(
AssignOrNamedArg(q"readOnly", q"$readOnly"),
AssignOrNamedArg(q"required", q"$required")
)
Apply(annot.fun, annot.args ++ additionalArgs)
}
}
private def extractAnnotationParameters(tree: Tree): AnnotationParams = tree match {
case ap: Apply =>
val argNames = Array("readOnly", "required")
val defaults = AnnotationParams(readOnly = false, required = false)
ap.args.zipWithIndex.foldLeft(defaults)((acc, argAndIndex) => argAndIndex match {
case (lit: Literal, index: Int) => acc.customCopy(argNames(index), c.eval(c.Expr[Any](lit)))
case (namedArg: AssignOrNamedArg, _: Int) =>
val q"$name = $lit" = namedArg
acc.customCopy(name.asInstanceOf[Ident].name.toString, c.eval(c.Expr[Any](lit)))
case _ => c.abort(c.enclosingPosition, "Failed to parse annotation params: " + argAndIndex)
})
}
}
And then you can do this:
sealed trait Mode extends EnumEntry
@SwaggerEnumContainer
object Mode extends Enum[Mode] {
override def values = findValues
case object Initial extends Mode
case object Delta extends Mode
}
@ApiModel
case class Foobar(@ApiModelProperty(dataType = "string", allowableValues = Mode.allowableValues) mode: Mode)
Or you can do this, which I think is a bit cleaner:
@ApiModel
case class Foobar2(
@SwaggerEnumValue mode: Mode,
@SwaggerEnumValue(true) mode2: Mode,
@SwaggerEnumValue(required = true) mode3: Mode,
i: Int, s: String = "abc") {
@SwaggerEnumValue
val modeField: Mode = Mode.Delta
}
Note that this is still only a proof of concept. Known deficiencies include:
- @SwaggerEnumContainer can't handle the case where allowableValues is already defined with some placeholder value (which might be nicer for the IDE).
- @SwaggerEnumValue supports only two of the attributes available in the original @ApiModelProperty.
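The tree walk that genListString performs can be tried outside a macro by running similar logic on an untyped tree built with runtime quasiquotes. This is a sketch under assumptions: it needs scala-reflect on the classpath, EnumTreeSketch and allowableValues are hypothetical names, and, like the macro, it only recognizes inner objects whose first parent is written as a plain identifier.

```scala
import scala.reflect.runtime.universe._

object EnumTreeSketch {
  // Untyped tree shaped like the annottee of @SwaggerEnumContainer.
  // Building it does not typecheck, so Enum and findValues need not exist.
  val example: Tree = q"""
    object Mode extends Enum[Mode] {
      override def values = findValues
      case object Initial extends Mode
      case object Delta extends Mode
      object Helpers
    }
  """

  // Comma-separated names of the inner objects whose first parent is the
  // enum's type argument (Mode in `extends Enum[Mode]`), in source order.
  def allowableValues(enumObject: Tree): String = enumObject match {
    case ModuleDef(_, _, Template(parents, _, body)) =>
      val enumTypeName: Option[Name] = parents.headOption.collect {
        case AppliedTypeTree(_, Ident(name) :: _) => name
      }
      body.collect {
        case ModuleDef(_, name, Template(Ident(parent) :: _, _, _))
            if enumTypeName.contains(parent) =>
          name.decodedName.toString
      }.mkString(",")
    case other =>
      throw new IllegalArgumentException(s"expected an object definition, got: $other")
  }
}
```

Here object Helpers is skipped because it does not extend Mode, and the values definition is skipped because it is not an object.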

Related

Scala whitebox macro how to check if class fields are of type of a case class

I am trying to generate a case class from a given case class that strips Option from the fields. It needs to do this recursively, so if the field itself is a case class, then it must remove Option from its fields as well.
So far I have managed to do it for the case where none of the fields is a case class. But for the recursion I need to get the ClassTag of the field if it's a case class, and I have no idea how to do this. It seems that all I can access is the syntax tree before type checking (which I guess makes sense, considering the final source code isn't formed yet). But I am wondering whether it's possible to achieve this in some way.
Here is my code, with the missing part as a comment.
import scala.annotation.StaticAnnotation
import scala.collection.mutable
import scala.reflect.macros.blackbox.Context
import scala.language.experimental.macros
import scala.annotation.compileTimeOnly
class RemoveOptionFromFields extends StaticAnnotation {
def macroTransform(annottees: Any*): Any = macro RemoveOptionFromFields.impl
}
object RemoveOptionFromFields {
def impl(c: Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._
def modifiedClass(classDecl: ClassDef, compDeclOpt: Option[ModuleDef]) = {
val result = classDecl match {
case q"case class $className(..$fields) extends ..$parents { ..$body }" =>
val fieldsWithoutOption = fields.map {
case ValDef(mods, name, tpt, rhs) =>
tpt.children match {
case List(first, second) if first.toString() == "Option" =>
// Check if `second` is a case class?
// Get its fields if so
val innerType = tpt.children(1)
ValDef(mods, name, innerType, rhs)
case _ =>
ValDef(mods, name, tpt, rhs)
}
}
val withOptionRemovedFromFieldsClassDecl = q"case class WithOptionRemovedFromFields(..$fieldsWithoutOption)"
val newCompanionDecl = compDeclOpt.fold(
q"""
object ${className.toTermName} {
$withOptionRemovedFromFieldsClassDecl
}
"""
) {
compDecl =>
val q"object $obj extends ..$bases { ..$body }" = compDecl
q"""
object $obj extends ..$bases {
..$body
$withOptionRemovedFromFieldsClassDecl
}
"""
}
q"""
$classDecl
$newCompanionDecl
"""
}
c.Expr[Any](result)
}
annottees.map(_.tree) match {
case (classDecl: ClassDef) :: Nil => modifiedClass(classDecl, None)
case (classDecl: ClassDef) :: (compDecl: ModuleDef) :: Nil => modifiedClass(classDecl, Some(compDecl))
case _ => c.abort(c.enclosingPosition, "This annotation only supports classes")
}
}
}
Not sure I understand what kind of recursion you need. Suppose we have two case classes, the first annotated (referring to the second) and the second not annotated:
@RemoveOptionFromFields
case class MyClass1(mc: Option[MyClass2])
case class MyClass2(i: Option[Int])
What should be the result?
Currently the annotation transforms into
case class MyClass1(mc: Option[MyClass2])
object MyClass1 {
case class WithOptionRemovedFromFields(mc: MyClass2)
}
case class MyClass2(i: Option[Int])
Regarding "if the field itself is a case class then it must remove Option from its fields as well":
A macro annotation can rewrite only the annotated class and its companion; it can't rewrite other classes. In my example with two classes, the annotation can modify MyClass1 and its companion but can't rewrite MyClass2 or its companion. For that, MyClass2 would have to be annotated itself.
Within a scope, macro annotations are expanded before the scope is typechecked, so at rewriting time the trees are untyped. If you need some trees to be typed (so that you can find their symbols), you can use c.typecheck.
Scala macros: What is the difference between typed (aka typechecked) and untyped Trees
To check that some class is a case class, you can use symbol.isClass && symbol.asClass.isCaseClass.
How to check if some T is a case class at compile time in Scala?
You hardly need ClassTags here.
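The same symbol-level checks can be tried with runtime reflection before moving them into the macro. This is a minimal sketch, assuming scala-reflect on the classpath; CaseClassInspection, isCaseClass and optionFields are hypothetical names, and at runtime a TypeTag stands in for the symbol that c.typecheck would give you inside the macro.

```scala
import scala.reflect.runtime.universe._

object CaseClassInspection {
  final case class MyClass2(i: Option[Int], s: String)

  // Same test as in the macro: a class symbol flagged as a case class.
  def isCaseClass[T: TypeTag]: Boolean = {
    val sym = typeOf[T].typeSymbol
    sym.isClass && sym.asClass.isCaseClass
  }

  // Names of the case accessors whose result type is an Option,
  // in declaration order (decls.sorted).
  def optionFields[T: TypeTag]: List[String] =
    typeOf[T].decls.sorted.collect {
      case m: MethodSymbol if m.isCaseAccessor && m.returnType <:< typeOf[Option[Any]] =>
        m.name.decodedName.toString
    }
}
```

Filtering on MethodSymbol keeps only the accessor methods, not the underlying private fields, which mirrors the isMethod && isCaseAccessor filter in the macro version.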
One more complication arises when MyClass1 and MyClass2 are in the same scope:
@RemoveOptionFromFields
case class MyClass1(mc: Option[MyClass2])
case class MyClass2(i: Option[Int])
Then, when the macro annotation for MyClass1 is expanded, the scope isn't typechecked yet, so it's impossible to typecheck the tree of the field definition mc: Option[MyClass2] (class MyClass2 is not known yet). If the classes are in different scopes, it works:
{
@RemoveOptionFromFields
case class MyClass1(mc: Option[MyClass2])
}
case class MyClass2(i: Option[Int])
This is a modified version of your code (for now, it just prints the fields of the second class):
import scala.annotation.StaticAnnotation
import scala.reflect.macros.blackbox
import scala.language.experimental.macros
import scala.annotation.compileTimeOnly
@compileTimeOnly("enable macro annotations")
class RemoveOptionFromFields extends StaticAnnotation {
def macroTransform(annottees: Any*): Any = macro RemoveOptionFromFields.impl
}
object RemoveOptionFromFields {
def impl(c: blackbox.Context)(annottees: c.Tree*): c.Tree = {
import c.universe._
def modifiedClass(classDecl: ClassDef, compDeclOpt: Option[ModuleDef]) = {
classDecl match {
case q"$mods class $className[..$tparams] $ctorMods(..$fields) extends { ..$earlydefns } with ..$parents { $self => ..$body }"
if mods.hasFlag(Flag.CASE) =>
val fieldsWithoutOption = fields.map {
case field@q"$mods val $name: $tpt = $rhs" =>
tpt match {
case tq"$first[..${List(second)}]" =>
val firstType = c.typecheck(tq"$first", mode = c.TYPEmode, silent = true) match {
case EmptyTree => println(s"can't typecheck $first while expanding @RemoveOptionFromFields for $className"); NoType
case t => t.tpe
}
if (firstType <:< typeOf[Option[_]].typeConstructor) {
val secondSymbol = c.typecheck(tq"$second", mode = c.TYPEmode, silent = true) match {
case EmptyTree => println(s"can't typecheck $second while expanding @RemoveOptionFromFields for $className"); NoSymbol
case t => t.symbol
}
if (secondSymbol.isClass && secondSymbol.asClass.isCaseClass) {
val secondClassFields = secondSymbol.typeSignature.decls.toList.filter(s => s.isMethod && s.asMethod.isCaseAccessor)
secondClassFields.foreach(s =>
c.typecheck(q"$s", silent = true) match {
case EmptyTree => println(s"can't typecheck $s while expanding @RemoveOptionFromFields for $className")
case t => println(s"field ${t.symbol} of type ${t.tpe}, subtype of Option: ${t.tpe <:< typeOf[Option[_]]}")
}
)
}
q"$mods val $name: $second = $rhs"
} else field
case _ =>
field
}
}
val withOptionRemovedFromFieldsClassDecl = q"case class WithOptionRemovedFromFields(..$fieldsWithoutOption)"
val newCompanionDecl = compDeclOpt.fold(
q"""
object ${className.toTermName} {
$withOptionRemovedFromFieldsClassDecl
}
"""
) {
compDecl =>
val q"$mods object $obj extends { ..$earlydefns } with ..$bases { $self => ..$body }" = compDecl
q"""
$mods object $obj extends { ..$earlydefns } with ..$bases { $self =>
..$body
$withOptionRemovedFromFieldsClassDecl
}
"""
}
q"""
$classDecl
$newCompanionDecl
"""
}
}
annottees match {
case (classDecl: ClassDef) :: Nil => modifiedClass(classDecl, None)
case (classDecl: ClassDef) :: (compDecl: ModuleDef) :: Nil => modifiedClass(classDecl, Some(compDecl))
case _ => c.abort(c.enclosingPosition, "This annotation only supports classes")
}
}
}

type parameter mismatch with WeakTypeTag reflection + quasiquoting (I think!)

Inspired by travisbrown, I'm trying to use a macro to create some "smart constructors".
Given
package mypkg
sealed trait Hello[A]
case class Ohayo[A,B](a: (A,B)) extends Hello[A]
and
val smartConstructors = FreeMacros.liftConstructors[Hello]
The macro should find all the subclasses of Hello, look at their constructors, and extract a few elements to populate this tree for the "smart constructor":
q"""
def $methodName[..$typeParams](...$paramLists): $baseType =
$companionSymbol[..$typeArgs](...$argLists)
"""
I hoped to get:
val smartConstructors = new {
def ohayo[A, B](a: (A, B)): Hello[A] = Ohayo[A, B](a)
}
but instead get:
error: type mismatch;
found : (A(in class Ohayo), B(in class Ohayo))
required: ((some other)A(in class Ohayo), (some other)B(in class Ohayo))
val liftedConstructors = FreeMacros.liftConstructors[Hello]
At a glance, the tree looks ok to me:
scala> q" new { ..$wellTyped }"
res1: u.Tree =
{
final class $anon extends scala.AnyRef {
def <init>() = {
super.<init>();
()
};
def ohayo[A, B](a: (A, B)): net.arya.constructors.Hello[A] = Ohayo[A, B](a)
};
new $anon()
}
but I guess it isn't, in some way that doesn't show up when the tree is printed. If I naively try to freshen up the typeParams with info.typeParams.map(p => TypeName(p.name.toString)), I get "can't splice A as type parameter" when I do the quasiquoting.
Where am I going wrong? Thanks for taking a look.
-Arya
import scala.language.experimental.macros
import scala.reflect.api.Universe
import scala.reflect.macros.whitebox
class FreeMacros(val c: whitebox.Context) {
import c.universe._
import FreeMacros._
def liftedImpl[F[_]](implicit t: c.WeakTypeTag[F[_]]): Tree = {
val atc = t.tpe
val childSymbols: Set[ClassSymbol] = subCaseClassSymbols(c.universe)(atc.typeSymbol.asClass)
val wellTyped = childSymbols.toList.map(ctorsForSymbol(c.universe)(atc))
q"new { ..${wellTyped} }"
}
}
object FreeMacros {
def liftConstructors[F[_]]: Any = macro FreeMacros.liftedImpl[F]
def smartName(name: String): String = (
name.toList match {
case h :: t => h.toLower :: t
case Nil => Nil
}
).mkString
def subCaseClassSymbols(u: Universe)(root: u.ClassSymbol): Set[u.ClassSymbol] = {
val subclasses = root.knownDirectSubclasses
val cast = subclasses.map(_.asInstanceOf[u.ClassSymbol])
val partitioned = cast.partition(_.isCaseClass)
partitioned match {
case (caseClasses, regularClasses) => caseClasses ++ regularClasses.flatMap(r => subCaseClassSymbols(u)(r))
}
}
def ctorsForSymbol(u: Universe)(atc: u.Type)(caseClass: u.ClassSymbol): u.Tree = {
import u._
import internal._
// these didn't help
// def clearTypeSymbol(s: Symbol): TypeSymbol = internal.newTypeSymbol(NoSymbol, s.name.toTypeName, s.pos, if(s.isImplicit)Flag.IMPLICIT else NoFlags)
// def clearTypeSymbol2(s: Symbol): TypeSymbol = internal.newTypeSymbol(NoSymbol, s.name.toTypeName, NoPosition, if(s.isImplicit)Flag.IMPLICIT else NoFlags)
// def clearTypeDef(d: TypeDef): TypeDef = internal.typeDef(clearTypeSymbol(d.symbol))
val companionSymbol: Symbol = caseClass.companion
val info: Type = caseClass.info
val primaryCtor: Symbol = caseClass.primaryConstructor
val method = primaryCtor.asMethod
val typeParams = info.typeParams.map(internal.typeDef(_))
// val typeParams = info.typeParams.map(s => typeDef(newTypeSymbol(NoSymbol, s.name.toTypeName, NoPosition, NoFlags)))
// val typeParams = info.typeParams.map(s => internal.typeDef(clearTypeSymbol2(s)))
val typeArgs = info.typeParams.map(_.name)
val paramLists = method.paramLists.map(_.map(internal.valDef(_)))
val argLists = method.paramLists.map(_.map(_.asTerm.name))
val baseType = info.baseType(atc.typeSymbol)
val List(returnType) = baseType.typeArgs
val methodName = TermName(smartName(caseClass.name.toString))
val wellTyped =
q"""
def $methodName[..$typeParams](...$paramLists): $baseType =
$companionSymbol[..$typeArgs](...$argLists)
"""
wellTyped
}
}
P.S. I have been experimenting with toolbox.untypecheck / typecheck per this article but haven't found a working combination.
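You need to build the type parameters with
clas.typeArgs.map(_.toString).map(name => {
TypeDef(Modifiers(Flag.PARAM), TypeName(name), List(), TypeBoundsTree(EmptyTree, EmptyTree))
})
instead of
info.typeParams.map(p => TypeName(p.name.toString))
Here is my code:
import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context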
object GetSealedSubClass {
def ol3[T]: Any = macro GetSealedSubClassImpl.ol3[T]
}
class GetSealedSubClassImpl(val c: Context) {
import c.universe._
def showInfo(s: String) =
c.info(c.enclosingPosition, s.split("\n").mkString("\n |---macro info---\n |", "\n |", ""), true)
def ol3[T: c.WeakTypeTag]: c.universe.Tree = {
//get all sub class
val subClass = c.weakTypeOf[T]
.typeSymbol.asClass.knownDirectSubclasses
.map(e => e.asClass.toType)
//check that the type parameter is a sealed class with at least one known subclass
if (subClass.size < 1)
c.abort(c.enclosingPosition, s"${c.weakTypeOf[T]} is not a sealed class")
// get sub class constructor params
val subConstructorParams = subClass.map { e =>
//get constructor
e.members.filter(_.isConstructor)
//if the class has several constructors, you need to pick the primary one instead of taking head
.map(_.asMethod).head
//get function param list
}.map(_.asMethod.paramLists.head)
.map(_.map(e => q"""${e.name.toTermName}:${e.info} """))
val outfunc = subClass zip subConstructorParams map {
case (clas, parm) =>
q"def smartConstructors[..${
clas.typeArgs.map(_.toString).map(name => {
TypeDef(Modifiers(Flag.PARAM), TypeName(name), List(), TypeBoundsTree(EmptyTree, EmptyTree))
})
}](..${parm})=${clas.typeSymbol.name.toTermName} (..${parm})"
}
val outClass =
q"""
object Term{
..${outfunc}
}
"""
showInfo(show(outClass))
q"""{
$outClass
Term
}
"""
}
}
Use it like this:
sealed trait Hello[A]
case class Ohayo[A, B](a: (A, B)) extends Hello[A]
object GetSealed extends App {
val a = GetSealedSubClass.ol3[Hello[_]]
val b=a.asInstanceOf[ {def smartConstructors[A, B](a: (A, B)): Ohayo[A, B]}].smartConstructors(1, 2).a
println(b)
}
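The key change in this answer is splicing TypeDef trees built from plain names into ..$tparams, instead of bare TypeNames (which gave "can't splice A as type parameter") or the case class's own type-parameter symbols (which gave the mismatch). The sketch below shows only that step with runtime quasiquotes. It assumes scala-reflect on the classpath; TypeParamSketch, typeParamDefs and typeParamNames are hypothetical names, and the resulting tree is untyped, so nothing is checked against a real Hello or Ohayo.

```scala
import scala.reflect.runtime.universe._

object TypeParamSketch {
  // Hypothetical helper: fresh, untyped type-parameter trees built from plain
  // names, mirroring the TypeDef(Modifiers(Flag.PARAM), ...) call in the answer.
  def typeParamDefs(names: List[String]): List[TypeDef] =
    names.map { name =>
      TypeDef(Modifiers(Flag.PARAM), TypeName(name), List(), TypeBoundsTree(EmptyTree, EmptyTree))
    }

  // TypeDefs (not TypeNames) can be spliced into ..$tparams.
  val smartCtor: Tree = {
    val tparams = typeParamDefs(List("A", "B"))
    q"def ohayo[..$tparams](a: (A, B)): Hello[A] = Ohayo[A, B](a)"
  }

  // Reads the type-parameter names back out of a method definition tree.
  def typeParamNames(tree: Tree): List[String] = tree match {
    case q"$mods def $name[..$tparams](...$paramss): $tpt = $body" =>
      tparams.map(_.name.toString)
    case _ => Nil
  }
}
```

Because A and B in the parameter and result types are plain identifiers here, the typechecker resolves them to the new method's own type parameters once the tree is returned from a macro.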

Scala annotations are not found

I have a case class with annotated fields, like this:
case class Foo(@alias("foo") bar: Int)
I have a macro that processes the declaration of this class:
val (className, access, fields, bases, body) = classDecl match {
case q"case class $n $m(..$ps) extends ..$bs { ..$ss }" => (n, m, ps, bs, ss)
case _ => abort
}
Later, I search for the aliased fields, as follows:
val aliases = fields.asInstanceOf[List[ValDef]].flatMap {
field => field.symbol.annotations.collect {
//deprecated version:
//case annotation if annotation.tpe <:< cv.weakTypeOf[alias] =>
case annotation if annotation.tree.tpe <:< c.weakTypeOf[alias] =>
//deprecated version:
//annotation.scalaArgs.head match {
annotation.tree.children.tail.head match {
case Literal(Constant(param: String)) => (param, field.name)
}
}
}
However, the list of aliases ends up being empty. I have determined that field.symbol.annotations.size is, in fact, 0, despite the annotation clearly sitting on the field.
Any idea of what's wrong?
EDIT
Answering the first two comments:
(1) I tried mods.annotations, but that didn't work. It actually returns a List[Tree] rather than the List[Annotation] returned by symbol.annotations. Perhaps I didn't modify the code correctly, but the immediate effect was an exception during macro expansion. I'll try to play with it some more.
(2) The class declaration is grabbed while processing an annotation macro slapped on the case class.
The complete code follows. The usage is illustrated in the test code further below.
package com.xxx.util.macros
import scala.collection.immutable.HashMap
import scala.language.experimental.macros
import scala.annotation.StaticAnnotation
import scala.reflect.macros.whitebox
trait Mapped {
def $(key: String) = _vals.get(key)
protected def +=(key: String, value: Any) =
_vals += ((key, value))
private var _vals = new HashMap[String, Any]
}
class alias(val key: String) extends StaticAnnotation
class aliased extends StaticAnnotation {
def macroTransform(annottees: Any*): Any = macro aliasedMacro.impl
}
object aliasedMacro {
def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._
val (classDecl, compDecl) = annottees.map(_.tree) match {
case (clazz: ClassDef) :: Nil => (clazz, None)
case (clazz: ClassDef) :: (comp: ModuleDef) :: Nil => (clazz, Some(comp))
case _ => abort(c, "@aliased must annotate a class")
}
val (className, access, fields, bases, body) = classDecl match {
case q"case class $n $m(..$ps) extends ..$bs { ..$ss }" => (n, m, ps, bs, ss)
case _ => abort(c, "@aliased is only supported on case class")
}
val mappings = fields.asInstanceOf[List[ValDef]].flatMap {
field => field.symbol.annotations.collect {
case annotation if annotation.tree.tpe <:< c.weakTypeOf[alias] =>
annotation.tree.children.tail.head match {
case Literal(Constant(param: String)) =>
q"""this += ($param, ${field.name})"""
}
}
}
val classCode = q"""
case class $className $access(..$fields) extends ..$bases {
..$body; ..$mappings
}"""
c.Expr(compDecl match {
case Some(compCode) => q"""$compCode; $classCode"""
case None => q"""$classCode"""
})
}
protected def abort(c: whitebox.Context, message: String) =
c.abort(c.enclosingPosition, message)
}
The test code:
package test.xxx.util.macros
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
import com.xxx.util.macros._
@aliased
case class Foo(@alias("foo") foo: Int,
@alias("BAR") bar: String,
baz: String) extends Mapped
@RunWith(classOf[JUnitRunner])
class MappedTest extends FunSuite {
val foo = 13
val bar = "test"
val obj = Foo(foo, bar, "extra")
test("field aliased with its own name") {
assertResult(Some(foo))(obj $ "foo")
}
test("field aliased with another string") {
assertResult(Some(bar))(obj $ "BAR")
assertResult(None)(obj $ "bar")
}
test("unaliased field") {
assertResult(None)(obj $ "baz")
}
}
Thanks for the suggestions! In the end, using field.mods.annotations did help. This is how:
val mappings = fields.asInstanceOf[List[ValDef]].flatMap {
field => field.mods.annotations.collect {
case Apply(Select(New(Ident(TypeName("alias"))), termNames.CONSTRUCTOR),
List(Literal(Constant(param: String)))) =>
q"""this += ($param, ${field.name})"""
}
}
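This works because the annottee is untyped: the field trees have no symbols yet, so the annotation exists only as a syntax tree in field.mods.annotations. The matching logic can be tried with runtime quasiquotes. This is a sketch assuming scala-reflect on the classpath; AliasSketch, classDecl and aliases are hypothetical names.

```scala
import scala.reflect.runtime.universe._

object AliasSketch {
  // Untyped tree, like the annottee a macro annotation receives.
  val classDecl: Tree =
    q"""case class Foo(@alias("foo") foo: Int, @alias("BAR") bar: String, baz: String)"""

  // (alias key, field name) for each constructor parameter annotated with
  // @alias("..."). The annotation is matched syntactically by its simple name,
  // so a fully qualified or renamed alias annotation would not be found.
  def aliases(tree: Tree): List[(String, String)] = tree match {
    case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" =>
      paramss.flatten.collect { case field: ValDef => field }.flatMap { field =>
        field.mods.annotations.collect {
          case Apply(Select(New(Ident(TypeName("alias"))), termNames.CONSTRUCTOR),
                     List(Literal(Constant(key: String)))) =>
            (key, field.name.decodedName.toString)
        }
      }
    case _ => Nil
  }
}
```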

Scala compile time macro, filter argument list based on (parent) type

I have a method/constructor that takes a bunch of parameters.
Using a Scala macro, I can of course extract the Tree representing the type of those parameters.
But I cannot find out how to convert this tree into something "useful", i.e. something whose parent types I can get, which I can check for being a primitive, etc.
Let's say I have a concrete type C and I want all parameters that inherit from C or are subtypes of Seq[C].
For a bit of context:
case cd@q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$_ } with ..$_ { $self => ..$stats }" :: tail =>
// extract everything that is subtype of C or Seq[C]
val cs = paramss.head.map {
case q"$mods val $name: $tpt = $default" => ???
}
Everything that goes in a macro should be typechecked, right? So $tpt should have a "type"?
How do I get it and what exactly do I get back?
This is how I determine whether a type inherits from Iterable but not from Map, and whether a type inherits from Map:
val iterableType = typeOf[Iterable[_]].typeSymbol
val mapType = typeOf[Map[_, _]].typeSymbol
def isIterable(tpe: Type): Boolean = tpe.baseClasses.contains(iterableType) && !isMap(tpe)
def isMap(tpe: Type): Boolean = tpe.baseClasses.contains(mapType)
I use this on e.g. the fields in a class's primary constructor:
// tpe is a Type; decls and paramLists are the Scala 2.11+ names for the deprecated declarations and paramss
val fields = tpe.decls.collectFirst {
case m: MethodSymbol if m.isPrimaryConstructor => m
}.get.paramLists.head
val iterableFields = fields.filter(f => isIterable(f.typeSignature))
val mapFields = fields.filter(f => isMap(f.typeSignature))
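For the original question (parameters that are a C or a subtype of Seq[C]), a <:< check on each parameter type follows the same pattern as the baseClasses test above. This is a runtime-reflection sketch, assuming scala-reflect on the classpath; the classes C, D and Holder and the method paramsOfTypeC are hypothetical. Inside a macro annotation you would first need a typed tree or symbol (for example via c.typecheck), since annottees are untyped.

```scala
import scala.reflect.runtime.universe._

object ParamFilterSketch {
  class C
  class D extends C
  class Holder(val a: D, val n: Int, val cs: List[D], val s: Seq[String])

  // Names of the primary-constructor parameters whose type conforms to C or
  // to Seq[C] (List[D] qualifies because List and Seq are covariant).
  def paramsOfTypeC[T: TypeTag]: List[String] = {
    val ctor = typeOf[T].decls.collectFirst {
      case m: MethodSymbol if m.isPrimaryConstructor => m
    }.getOrElse(throw new IllegalArgumentException(s"${typeOf[T]} has no primary constructor"))
    ctor.paramLists.head.collect {
      case p if p.typeSignature <:< typeOf[C] || p.typeSignature <:< typeOf[Seq[C]] =>
        p.name.decodedName.toString
    }
  }
}
```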
I haven't used them in anger, but it looks like you're using macro annotations.
No, you get untyped annotees. As http://docs.scala-lang.org/overviews/macros/annotations.html puts it:
"macro annottees are untyped, so that we can change their signatures"
Given a client:
package thingy
class C
@testThing class Thingy(val c: C, i: Int) { def j = 2*i }
object Test extends App {
Console println Thingy(42).stuff
}
then the macro implementation sees:
val t = (annottees map (c typecheck _.tree)) match {
case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$_ } with ..$_ { $self => ..$stats }" :: _ =>
val p @ q"$mods val $name: ${tpt: Type} = $default" = paramss.head.head
Console println showRaw(tpt)
Console println tpt.baseClasses
toEmit
}
where the tree has been typechecked explicitly and the type is annotated in the extraction:
case q"val $name: ${tpt: Type} = ???" =>
I realize my spelling of annotee differs from the official docs. I just can't figure out why it should be different from devotee.
Here is the toy skeleton code, for other beginners:
package thingy
import scala.annotation.StaticAnnotation
import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context
class testThing extends StaticAnnotation {
def macroTransform(annottees: Any*): Any = macro testThing.impl
}
object testThing {
def impl(c: Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._
val toEmit = q"""{
class Thingy(i: Int) {
def stuff = println(i)
}
object Thingy {
def apply(x: Int) = new Thingy(x)
}
()
}"""
def dummy = Literal(Constant(()))
//val t = (annottees map (_.tree)) match {
val t = (annottees map (c typecheck _.tree)) match {
case Nil =>
c.abort(c.enclosingPosition, "No test target")
dummy
case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$_ } with ..$_ { $self => ..$stats }" :: _ =>
//val p @ q"$mods val $name: $tpt = $default" = paramss.head.head
val p @ q"$mods val $name: ${tpt: Type} = $default" = paramss.head.head
Console println showRaw(tpt)
Console println tpt.baseClasses
toEmit
/*
case (classDeclaration: ClassDef) :: Nil =>
println("No companion provided")
toEmit
case (classDeclaration: ClassDef) :: (companionDeclaration: ModuleDef) :: Nil =>
println("Companion provided")
toEmit
*/
case _ => c.abort(c.enclosingPosition, "Invalid test target")
dummy
}
c.Expr[Any](t)
}
}

How can I splice in a type and a default value in Scala quasiquotes?

I'm trying to make a type-provider that gives updated case classes.
How might I splice in a type and a default value (or omit the default value)?
def impl(c: Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._
import Flag._
val result = {
annottees.map(_.tree).toList match {
case q"$mods class $name[..$tparams](..$first)(...$rest) extends ..$parents { $self => ..$body }" :: Nil =>
val valType = //TODO
val valDefault = //TODO
val helloVal = q"""val x: $valType = $valDefault"""
q"$mods class $name[..$tparams](..$first, $helloVal)(...$rest) extends ..$parents { $self => ..$body }"
}
}
c.Expr[Any](result)
}
What I've tried:
Simply val valType = q"String", but then I get an error as if a default value was not found: not enough arguments for method apply.
I've also tried splicing in a val defined as typeOf[String], and I've also tried splicing lists of ValDefs into my q"$mods class... (like I've seen done into a q"def... in some similar questions on this site), but in each case there is a typer error.
Any tips? Thanks very much for looking.
You can use the tq interpolator in the definition of valType to create the type tree.
The rest is a little trickier. It seems to work just fine if you define the extra parameter directly:
q"""
$mods class $name[..$tparams](
..$first,
val x: $valType = $valDefault
)(...$rest) extends ..$parents { $self => ..$body }
"""
But when you define $helloVal and then plug that in you end up without the default parameter flag. You could write a helper like this:
def makeDefault(valDef: ValDef) = valDef match {
case ValDef(mods, name, tpt, rhs) => ValDef(
Modifiers(
mods.flags | DEFAULTPARAM, mods.privateWithin, mods.annotations
),
name, tpt, rhs
)
}
Now you can write the following:
val valType = tq"String"
val valDefault = q""""foo""""
val helloVal = makeDefault(q"val x: $valType = $valDefault")
And everything should work as expected.
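The effect of makeDefault can be observed directly on the modifier flags. This is a sketch using runtime quasiquotes, assuming scala-reflect on the classpath; DefaultParamSketch and its vals are hypothetical names, and the default value is built as Literal(Constant("foo")) rather than with a quasiquote.

```scala
import scala.reflect.runtime.universe._

object DefaultParamSketch {
  // Same helper as in the answer: copy a ValDef, adding the DEFAULTPARAM flag.
  def makeDefault(valDef: ValDef): ValDef = valDef match {
    case ValDef(mods, name, tpt, rhs) =>
      ValDef(Modifiers(mods.flags | Flag.DEFAULTPARAM, mods.privateWithin, mods.annotations), name, tpt, rhs)
  }

  val valType: Tree = tq"String"                  // type tree via the tq interpolator
  val valDefault: Tree = Literal(Constant("foo")) // the string literal "foo"

  // A val built on its own is a plain definition, not a defaulted parameter.
  val plainVal: ValDef = q"val x: $valType = $valDefault" match { case v: ValDef => v }
  val helloVal: ValDef = makeDefault(plainVal)
}
```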