I am fairly new to ScalaCheck (and Scala entirely) so this may be a fairly simple solution
I am using ScalaCheck to generate tests for an AST and verifying that the writer/parser work. I have these files
AST.scala
package com.test
/** The binary operators that an Operation node can carry. */
object Operator extends Enumeration {
  val Add = Value
  val Subtract = Value
  val Multiply = Value
  val Divide = Value
}
// Root of the expression tree. It is sealed, so all node types are declared in this file.
sealed trait AST
// Inner node: holds an operator together with its left and right sub-trees.
case class Operation(left: AST, op: Operator.Value, right: AST) extends AST
// Leaf node wrapping a single Int.
case class Literal(value: Int) extends AST
GenLiteral.scala
import com.test.{AST, Literal}
import org.scalacheck._
import Shrink._
import Prop._
import Arbitrary.arbitrary
// ScalaCheck property suite for Literal nodes.
object GenLiteral extends Properties("AST::Literal") {
// The write/parse round trip is still commented out, so this property holds for every input.
property("Verify parse/write") = forAll(genLiteral){ (node) =>
// val string_version = node.writeToString() // AST -> String
// val result = Parse(string_version) // String -> AST
true
}
// Generates a Literal wrapping an arbitrary Int.
def genLiteral: Gen[Literal] = for {
value <- arbitrary[Int]
} yield Literal(value)
// NOTE(review): typed as Shrink[AST] but only matches Literal, so another AST node would
// hit a MatchError. It is also one of the two implicit Shrink[AST] instances named in the
// "ambiguous implicit values" error quoted further down.
implicit def shrinkLiteral: Shrink[AST] = Shrink {
case Literal(value) =>
for {
reduced <- shrink(value)
} yield Literal(reduced)
}
}
GenOperation.scala
import com.test.{AST, Operation}
import org.scalacheck._
import Gen._
import Shrink._
import Prop._
import GenLiteral._
// ScalaCheck property suite for Operation nodes.
object GenOperation extends Properties("AST::Operation") {
// The write/parse round trip is still commented out, so this property holds for every input.
property("Verify parse/write") = forAll(genOperation){ (node) =>
// val string_version = node.writeToString() // AST -> String
// val result = Parse(string_version) // String -> AST
true
}
// Generates an Operation whose children are either nested Operations or Literals.
// NOTE(review): genOperation is referenced eagerly inside oneOf with no size bound — this
// looks like it can recurse without terminating; confirm (Gen.lzy / Gen.sized may be needed).
// NOTE(review): Operator is used below but the imports shown only bring in AST and Operation.
def genOperation: Gen[Operation] = for {
left <- oneOf(genOperation, genLiteral)
right <- oneOf(genOperation, genLiteral)
op <- oneOf(Operator.values.toSeq)
} yield Operation(left,op,right)
// Shrink candidates, in order: both children shrunk, only left shrunk, only right shrunk,
// then the shrunk children themselves.
// NOTE(review): this is the second implicit Shrink[AST] (the first is imported from GenLiteral),
// which is what makes the shrink(l) / shrink(r) calls ambiguous.
implicit def shrinkOperation: Shrink[AST] = Shrink {
case Operation(l,o,r) =>
(
for {
ls <- shrink(l)
rs <- shrink(r)
} yield Operation(ls, o, rs)
) append (
for {
ls <- shrink(l)
} yield Operation(ls, o, r)
) append (
for {
rs <- shrink(r)
} yield Operation(l, o, rs)
) append shrink(l) append shrink(r)
}
}
In the example code I wrote (what is pasted above) I get the error
ambiguous implicit values:
both method shrinkLiteral in object GenLiteral of type => org.scalacheck.Shrink[com.test.AST]
and method shrinkOperation in object GenOperation of type => org.scalacheck.Shrink[com.test.AST]
match expected type org.scalacheck.Shrink[com.test.AST]
ls <- shrink(l)
How do I write the shrink methods for this?
You have two implicit instances of Shrink[AST] and so the compiler complains about ambiguous implicit values.
You could re-write your code as:
/** Shrinks a Literal by shrinking the wrapped Int and re-wrapping every candidate. */
implicit def shrinkLiteral: Shrink[Literal] = Shrink {
  case Literal(value) =>
    for (smaller <- shrink(value)) yield Literal(smaller)
}
/**
 * Shrinks an Operation while keeping its operator.
 *
 * Candidates, in order: left child shrunk, right child shrunk, then both children shrunk.
 * The bare children cannot be offered here because the result must be a stream of
 * Operation values; a Shrink[AST] is the place for those.
 */
implicit def shrinkOperation: Shrink[Operation] = Shrink {
  case Operation(l, o, r) =>
    shrink(l).map(Operation(_, o, r)) append
      shrink(r).map(Operation(l, o, _)) append
      // Replaces the `???` placeholder, which threw NotImplementedError once the
      // stream was forced this far.
      (for {
        ls <- shrink(l)
        rs <- shrink(r)
      } yield Operation(ls, o, rs))
}
/** Shrinker for any AST node: hands the node to the shrinker of its concrete type. */
implicit def shrinkAST: Shrink[AST] = Shrink {
  case literal: Literal => shrink(literal)
  case operation: Operation => shrink(operation)
}
Related
I have many modules with multiple parameters. Take as a toy example a modified version of the GCD in the template:
// Subtractive GCD circuit: the smaller of the x and y registers is subtracted from the larger.
// `len` sets the bit width of the value ports; `validHigh` picks the polarity of outputValid.
class GCD (len: Int = 16, validHigh: Boolean = true) extends Module {
val io = IO(new Bundle {
val value1 = Input(UInt(len.W))
val value2 = Input(UInt(len.W))
val loadingValues = Input(Bool())
val outputGCD = Output(UInt(len.W))
val outputValid = Output(Bool())
})
// Working registers, declared without an explicit width.
val x = Reg(UInt())
val y = Reg(UInt())
when(x > y) { x := x - y }
.otherwise { y := y - x }
// Placed after the subtract step; under Chisel's last-connect rule the load wins while
// loadingValues is set.
when(io.loadingValues) {
x := io.value1
y := io.value2
}
io.outputGCD := x
// validHigh = true: outputValid is (y == 0); otherwise it is (y != 0).
if (validHigh) {
io.outputValid := (y === 0.U)
} else {
io.outputValid := (y =/= 0.U)
}
}
To test or synthesize many different designs, I want to change the values from the command line when I call the tester or the generator apps. Preferably, like this:
[generation or test command] --len 12 --validHigh false
but this or something similar would also be okay
[generation or test command] --param "len=12" --param "validHigh=false"
After some trial and error, I came up with a solution that looks like this:
gcd.scala
package gcd
import firrtl._
import chisel3._
// Design parameters of the GCD module, each with a default.
case class GCDConfig(
// Bit width used for the value ports.
len: Int = 16,
// Polarity of outputValid: true means outputValid is (y == 0).
validHigh: Boolean = true
)
// Subtractive GCD circuit configured through a GCDConfig; `conf` is public so tests can read it.
class GCD (val conf: GCDConfig = GCDConfig()) extends Module {
val io = IO(new Bundle {
val value1 = Input(UInt(conf.len.W))
val value2 = Input(UInt(conf.len.W))
val loadingValues = Input(Bool())
val outputGCD = Output(UInt(conf.len.W))
val outputValid = Output(Bool())
})
// Working registers, declared without an explicit width.
val x = Reg(UInt())
val y = Reg(UInt())
when(x > y) { x := x - y }
.otherwise { y := y - x }
// Placed after the subtract step; under Chisel's last-connect rule the load wins while
// loadingValues is set.
when(io.loadingValues) {
x := io.value1
y := io.value2
}
io.outputGCD := x
// conf.validHigh = true: outputValid is (y == 0); otherwise it is (y != 0).
if (conf.validHigh) {
io.outputValid := y === 0.U
} else {
io.outputValid := y =/= 0.U
}
}
// Mix-in for an ExecutionOptionsManager that adds a -p / --params option taking
// "k1=v1,k2=v2" and stores the parsed map in `params`.
trait HasParams {
self: ExecutionOptionsManager =>
// Last parsed value of --params; stays empty when the option is not given.
var params: Map[String, String] = Map()
parser.note("Design Parameters")
parser.opt[Map[String, String]]('p', "params")
.valueName("k1=v1,k2=v2")
.foreach { v => params = v }
.text("Parameters of Design")
}
/** Factory helpers that build a GCD from string-valued command-line parameters. */
object GCD {
  /** Instantiates a GCD configured from `params`. */
  def apply(params: Map[String, String]): GCD = new GCD(params2conf(params))

  /**
   * Folds the recognised keys ("len", "validHigh") into a default GCDConfig.
   * Unknown keys are ignored.
   */
  def params2conf(params: Map[String, String]): GCDConfig =
    params.foldLeft(GCDConfig()) {
      case (conf, ("len", raw)) => conf.copy(len = raw.toInt)
      case (conf, ("validHigh", raw)) => conf.copy(validHigh = raw.toBoolean)
      case (conf, _) => conf
    }
}
/** Generator entry point: parses the CLI (including --params) and elaborates the GCD. */
object GCDGen extends App {
  val optionsManager = new ExecutionOptionsManager("gcdgen")
    with HasChiselExecutionOptions with HasFirrtlOptions with HasParams

  if (optionsManager.parse(args)) {
    chisel3.Driver.execute(optionsManager, () => GCD(optionsManager.params))
  } else {
    ChiselExecutionFailure("could not parse results")
  }
}
and for tests
GCDSpec.scala
package gcd
import chisel3._
import firrtl._
import chisel3.tester._
import org.scalatest.FreeSpec
import chisel3.experimental.BundleLiterals._
import chiseltest.internal._
import chiseltest.experimental.TestOptionBuilder._
/** Test entry point: parses --params from the CLI and runs GCDSpec with them. */
object GCDTest extends App {
  val optionsManager = new ExecutionOptionsManager("gcdtest") with HasParams

  if (optionsManager.parse(args)) {
    //println(optionsManager.commonOptions.programArgs)
    (new GCDSpec(optionsManager.params)).execute()
  } else {
    ChiselExecutionFailure("could not parse results")
  }
}
// Spec for a GCD built from string parameters: loads 95 and 10, steps the clock until
// outputValid reaches its configured active level, then expects 5 on outputGCD.
class GCDSpec(params: Map[String, String] = Map()) extends FreeSpec with ChiselScalatestTester {
"Gcd should calculate proper greatest common denominator" in {
test(GCD(params)) { dut =>
dut.io.value1.poke(95.U)
dut.io.value2.poke(10.U)
dut.io.loadingValues.poke(true.B)
// One clock edge with loadingValues high, then release it.
dut.clock.step(1)
dut.io.loadingValues.poke(false.B)
// Keep stepping while outputValid differs from the active level chosen by conf.validHigh.
while (dut.io.outputValid.peek().litToBoolean != dut.conf.validHigh) {
dut.clock.step(1)
}
dut.io.outputGCD.expect(5.U)
}
}
}
This way, I can generate different designs and test them with
sbt 'runMain gcd.GCDGen --params "len=12,validHigh=false"'
sbt 'test:runMain gcd.GCDTest --params "len=12,validHigh=false"'
But there are a couple of problems or annoyances with this solution:
It uses deprecated features (ExecutionOptionsManager and HasFirrtlOptions). I'm not sure if this solution is portable to the new FirrtlStage Infrastructure.
There is a lot of boilerplate involved. It becomes tedious to write new case classes and params2conf functions for every module and rewrite both when a parameter is added or removed.
Using conf.x instead of x all the time. But I guess, this is unavoidable because there is nothing like python's kwargs in Scala.
Is there a better way or one that is at least not deprecated?
Good Question.
I think you have pretty much everything right. I don't usually find that I need the command line to alter my tests; my development cycle is usually just poking values into the test params directly and running. I use IntelliJ, which seems to make that easy (but that may only work for my habits and the scale of projects I work on).
But I would like to offer you a suggestion that will get you away from the ExecutionOptions style, as that is going away fast.
In my example code below I offer basically two files, shown inline. The first contains a few library-like tools that use the modern annotation idioms and, I believe, minimize boilerplate. They rely on stringy matching, but that is fixable.
In the second, is your GCD, GCDSpec, slightly modified to pull out the params a bit differently. At the bottom of the second is some very minimal boiler plate that allows you to get the command line access you want.
Good luck, I hope this is mostly self explanatory.
First file:
import chisel3.stage.ChiselCli
import firrtl.AnnotationSeq
import firrtl.annotations.{Annotation, NoTargetAnnotation}
import firrtl.options.{HasShellOptions, Shell, ShellOption, Stage, Unserializable}
import firrtl.stage.FirrtlCli
// Marker trait for tester-related annotations; the self-type restricts it to Annotation subclasses.
trait TesterAnnotation {
this: Annotation =>
}
/**
 * String-keyed test parameters with typed accessors.
 *
 * A key is looked up in `params` first and in the mutable `defaults` table second.
 * A key present in neither makes the accessor throw.
 */
case class TestParams(params: Map[String, String] = Map.empty) {
  val defaults: collection.mutable.HashMap[String, String] =
    collection.mutable.HashMap.empty[String, String]

  /** Raw string for `key`: the explicit value if there is one, else the default. */
  private def lookup(key: String): String = params.get(key) match {
    case Some(explicit) => explicit
    case None => defaults(key)
  }

  def getInt(key: String): Int = lookup(key).toInt

  def getBoolean(key: String): Boolean = lookup(key).toBoolean

  def getString(key: String): String = lookup(key)
}
// Annotation that carries the TestParams parsed from the command line.
case class TesterParameterAnnotation(paramString: TestParams)
extends TesterAnnotation
with NoTargetAnnotation
with Unserializable
// Registers the --param-string option; its Map[String, String] value is wrapped in
// TestParams and emitted as a TesterParameterAnnotation.
object TesterParameterAnnotation extends HasShellOptions {
val options = Seq(
new ShellOption[Map[String, String]](
longOption = "param-string",
toAnnotationSeq = (a: Map[String, String]) => Seq(TesterParameterAnnotation(TestParams(a))),
helpText = """a comma separated, space free list of additional paramters, e.g. --param-string "k1=7,k2=dog" """
)
)
}
/** Shell mix-in that adds the tester parameter option to the command-line parser. */
trait TesterCli {
  this: Shell =>

  TesterParameterAnnotation.addOptions(parser)
}
/**
 * Stage that extracts the TestParams from the annotations (falling back to an empty
 * TestParams), passes them with the full annotation list to `thunk`, and returns the
 * annotations unchanged.
 */
class GenericTesterStage(thunk: (TestParams, AnnotationSeq) => Unit) extends Stage {
  val shell: Shell = new Shell("chiseltest") with TesterCli with ChiselCli with FirrtlCli

  def run(annotations: AnnotationSeq): AnnotationSeq = {
    val supplied = annotations.collectFirst { case TesterParameterAnnotation(found) => found }
    thunk(supplied.getOrElse(TestParams()), annotations)
    annotations
  }
}
Second File:
import chisel3._
import chisel3.tester._
import chiseltest.experimental.TestOptionBuilder._
import chiseltest.{ChiselScalatestTester, GenericTesterStage, TestParams}
import firrtl._
import firrtl.options.StageMain
import org.scalatest.freespec.AnyFreeSpec
// Subtractive GCD circuit whose bit width and valid polarity are read from TestParams
// under the keys "len" and "validHigh".
case class GCD(testParams: TestParams) extends Module {
val bitWidth = testParams.getInt("len")
val validHigh = testParams.getBoolean("validHigh")
val io = IO(new Bundle {
val value1 = Input(UInt(bitWidth.W))
val value2 = Input(UInt(bitWidth.W))
val loadingValues = Input(Bool())
val outputGCD = Output(UInt(bitWidth.W))
val outputValid = Output(Bool())
})
// Working registers, declared without an explicit width.
val x = Reg(UInt())
val y = Reg(UInt())
when(x > y) { x := x - y }.otherwise { y := y - x }
// Placed after the subtract step; under Chisel's last-connect rule the load wins while
// loadingValues is set.
when(io.loadingValues) {
x := io.value1
y := io.value2
}
io.outputGCD := x
// validHigh = true: outputValid is (y == 0); otherwise it is (y != 0).
if (validHigh) {
io.outputValid := y === 0.U
} else {
io.outputValid := y =/= 0.U
}
}
// Spec for the TestParams-driven GCD: loads 95 and 10, steps until outputValid reaches its
// active level, then expects 5. The annotations are forwarded to the tester.
class GCDSpec(params: TestParams, annotations: AnnotationSeq = Seq()) extends AnyFreeSpec with ChiselScalatestTester {
"Gcd should calculate proper greatest common denominator" in {
test(GCD(params)).withAnnotations(annotations) { dut =>
dut.io.value1.poke(95.U)
dut.io.value2.poke(10.U)
dut.io.loadingValues.poke(true.B)
// One clock edge with loadingValues high, then release it.
dut.clock.step(1)
dut.io.loadingValues.poke(false.B)
// Keep stepping while outputValid differs from the active level chosen by validHigh.
while (dut.io.outputValid.peek().litToBoolean != dut.validHigh) {
dut.clock.step(1)
}
dut.io.outputGCD.expect(5.U)
}
}
}
// Stage that installs the defaults ("len" -> "16", "validHigh" -> "false") and then runs GCDSpec.
class GcdTesterStage
extends GenericTesterStage((params, annotations) => {
params.defaults ++= Seq("len" -> "16", "validHigh" -> "false")
(new GCDSpec(params, annotations)).execute()
})
// Command-line entry point wrapping the stage above.
object GcdTesterStage extends StageMain(new GcdTesterStage)
Based on http://blog.echo.sh/2013/11/04/exploring-scala-macros-map-to-case-class-conversion.html, I was able to find another way of removing the params2conf boilerplate using scala macros. I also extended Chick's answer with verilog generation since that was also part of the original question. A full repository of my solution can be found on github.
Basically there are four files:
The macro that converts a map to a case class:
package mappable
import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context
// Type class converting a T to and from a Map[String, String].
trait Mappable[T] {
// T -> map of stringified values.
def toMap(t: T): Map[String, String]
// map of string values -> T.
def fromMap(map: Map[String, String]): T
}
// Companion holding the macro that materializes a Mappable[T] from T's primary constructor.
object Mappable {
implicit def materializeMappable[T]: Mappable[T] = macro materializeMappableImpl[T]
def materializeMappableImpl[T: c.WeakTypeTag](c: Context): c.Expr[Mappable[T]] = {
import c.universe._
val tpe = weakTypeOf[T]
val companion = tpe.typeSymbol.companion
// Parameters of the first parameter list of the primary constructor.
// NOTE(review): `.get` throws if T declares no primary constructor (e.g. a trait) — confirm T is always a class.
val fields = tpe.decls.collectFirst {
case m: MethodSymbol if m.isPrimaryConstructor => m
}.get.paramLists.head
// For each field build one "name -> t.name.toString" entry and one expression that reads the field back.
val (toMapParams, fromMapParams) = fields.map { field =>
val name = field.name.toTermName
val decoded = name.decodedName.toString
val returnType = tpe.decl(name).typeSignature
// Only Int, String and Boolean fields are converted back from their string form.
// NOTE(review): any other field type yields an empty tree here — presumably unsupported; verify.
val fromMapLine = returnType match {
case NullaryMethodType(res) if res =:= typeOf[Int] => q"map($decoded).toInt"
case NullaryMethodType(res) if res =:= typeOf[String] => q"map($decoded)"
case NullaryMethodType(res) if res =:= typeOf[Boolean] => q"map($decoded).toBoolean"
case _ => q""
}
(q"$decoded -> t.$name.toString", fromMapLine)
}.unzip
// Generated instance: toMap lists the entries, fromMap calls the companion with the converted values.
c.Expr[Mappable[T]] { q"""
new Mappable[$tpe] {
def toMap(t: $tpe): Map[String, String] = Map(..$toMapParams)
def fromMap(map: Map[String, String]): $tpe = $companion(..$fromMapParams)
}
""" }
}
}
Library like tools:
package cliparams
import chisel3.stage.{ChiselStage, ChiselGeneratorAnnotation, ChiselCli}
import firrtl.AnnotationSeq
import firrtl.annotations.{Annotation, NoTargetAnnotation}
import firrtl.options.{HasShellOptions, Shell, ShellOption, Stage, Unserializable, StageMain}
import firrtl.stage.FirrtlCli
import mappable._
// Marker trait for the parameter annotation below; the self-type restricts it to Annotation subclasses.
trait SomeAnnotaion {
this: Annotation =>
}
// Annotation carrying the raw key/value parameters given on the command line.
case class ParameterAnnotation(map: Map[String, String])
extends SomeAnnotaion
with NoTargetAnnotation
with Unserializable
/**
 * Registers the `--params` option. Its `Map[String, String]` value is emitted as a
 * [[ParameterAnnotation]].
 */
object ParameterAnnotation extends HasShellOptions {
  val options = Seq(
    new ShellOption[Map[String, String]](
      longOption = "params",
      toAnnotationSeq = (a: Map[String, String]) => Seq(ParameterAnnotation(a)),
      // The example now names the option actually registered above (it used to say --param-string).
      helpText = """a comma separated, space free list of additional parameters, e.g. --params "k1=7,k2=dog" """
    )
  )
}
/** Shell mix-in that adds the `params` option to the command-line parser. */
trait ParameterCli {
  this: Shell =>

  ParameterAnnotation.addOptions(parser)
}
/**
 * Stage that builds a parameter object of type `P` and hands it to `thunk`.
 *
 * When a ParameterAnnotation is present, its entries are laid over the map form of
 * `default` and the result is converted back to a `P`. Otherwise `default` is used as is.
 * The annotations are returned unchanged.
 */
class GenericParameterCliStage[P: Mappable](thunk: (P, AnnotationSeq) => Unit, default: P) extends Stage {
  /** Map form of a parameter object. */
  def mapify(p: P) = implicitly[Mappable[P]].toMap(p)

  /** Parameter object rebuilt from its map form. */
  def materialize(map: Map[String, String]) = implicitly[Mappable[P]].fromMap(map)

  val shell: Shell = new Shell("chiseltest") with ParameterCli with ChiselCli with FirrtlCli

  def run(annotations: AnnotationSeq): AnnotationSeq = {
    val overrides = annotations.collectFirst { case ParameterAnnotation(map) => map }
    val params = overrides match {
      case Some(given) => materialize(mapify(default) ++ given.toSeq)
      case None => default
    }
    thunk(params, annotations)
    annotations
  }
}
The GCD source file
// See README.md for license details.
package gcd
import firrtl._
import chisel3._
import chisel3.stage.{ChiselStage, ChiselGeneratorAnnotation}
import firrtl.options.{StageMain}
// Both have to be imported
import mappable._
import cliparams._
// Design parameters of the GCD module, each with a default.
case class GCDConfig(
// Bit width used for the value ports.
len: Int = 16,
// Polarity of outputValid: true means outputValid is (y == 0).
validHigh: Boolean = true
)
/**
* Compute GCD using subtraction method.
* Subtracts the smaller from the larger until register y is zero.
* value in register x is then the GCD
*/
class GCD (val conf: GCDConfig = GCDConfig()) extends Module {
val io = IO(new Bundle {
val value1 = Input(UInt(conf.len.W))
val value2 = Input(UInt(conf.len.W))
val loadingValues = Input(Bool())
val outputGCD = Output(UInt(conf.len.W))
val outputValid = Output(Bool())
})
// Working registers, declared without an explicit width.
val x = Reg(UInt())
val y = Reg(UInt())
when(x > y) { x := x - y }
.otherwise { y := y - x }
// Placed after the subtract step; under Chisel's last-connect rule the load wins while
// loadingValues is set.
when(io.loadingValues) {
x := io.value1
y := io.value2
}
io.outputGCD := x
// conf.validHigh = true: outputValid is (y == 0); otherwise it is (y != 0).
if (conf.validHigh) {
io.outputValid := y === 0.U
} else {
io.outputValid := y =/= 0.U
}
}
// Stage that runs ChiselStage with "-X verilog" on a GCD built from the CLI-supplied GCDConfig.
// Only the generator annotation is passed to ChiselStage; the stage's own `annotations` are not forwarded.
class GCDGenStage extends GenericParameterCliStage[GCDConfig]((params, annotations) => {
(new chisel3.stage.ChiselStage).execute(
Array("-X", "verilog"),
Seq(ChiselGeneratorAnnotation(() => new GCD(params))))}, GCDConfig())
// Command-line entry point wrapping the stage above.
object GCDGen extends StageMain(new GCDGenStage)
and the tests
// See README.md for license details.
package gcd
import chisel3._
import firrtl._
import chisel3.tester._
import org.scalatest.FreeSpec
import chisel3.experimental.BundleLiterals._
import chiseltest.internal._
import chiseltest.experimental.TestOptionBuilder._
import firrtl.options.{StageMain}
import mappable._
import cliparams._
/**
 * Spec for a GCD built from a GCDConfig.
 *
 * Loads 95 and 10, steps the clock until outputValid reaches the active level chosen by
 * `conf.validHigh`, then expects 5 on outputGCD.
 *
 * @param params      configuration used to build the device under test
 * @param annotations extra annotations (e.g. from the CLI) forwarded to the tester
 */
class GCDSpec(params: GCDConfig, annotations: AnnotationSeq = Seq()) extends FreeSpec with ChiselScalatestTester {
  "Gcd should calculate proper greatest common denominator" in {
    // `annotations` used to be accepted but dropped; it is now passed on to the tester.
    test(new GCD(params)).withAnnotations(annotations) { dut =>
      dut.io.value1.poke(95.U)
      dut.io.value2.poke(10.U)
      dut.io.loadingValues.poke(true.B)
      // One clock edge with loadingValues high, then release it.
      dut.clock.step(1)
      dut.io.loadingValues.poke(false.B)
      while (dut.io.outputValid.peek().litToBoolean != dut.conf.validHigh) {
        dut.clock.step(1)
      }
      dut.io.outputGCD.expect(5.U)
    }
  }
}
// Stage that runs GCDSpec with the CLI-supplied GCDConfig and the stage annotations.
class GCDTestStage extends GenericParameterCliStage[GCDConfig]((params, annotations) => {
(new GCDSpec(params, annotations)).execute()}, GCDConfig())
// Command-line entry point wrapping the stage above.
object GCDTest extends StageMain(new GCDTestStage)
Both, generation and tests can be parameterized via CLI as in the OQ:
sbt 'runMain gcd.GCDGen --params "len=12,validHigh=false"'
sbt 'test:runMain gcd.GCDTest --params "len=12,validHigh=false"'
Inspired by travisbrown, I'm trying to use a macro to create some "smart constructors".
Given
package mypkg
// Sealed base type whose case-class subtypes the macro should lift.
sealed trait Hello[A]
// Subtype with an extra type parameter B that does not appear in the Hello[A] parent.
case class Ohayo[A,B](a: (A,B)) extends Hello[A]
and
val smartConstructors = FreeMacros.liftConstructors[Hello]
The macro should find all the subclasses of Hello, look at their constructors, and extract a few elements to populate this tree for the "smart constructor":
q"""
def $methodName[..$typeParams](...$paramLists): $baseType =
$companionSymbol[..$typeArgs](...$argLists)
"""
I hoped to get:
val smartConstructors = new {
def ohayo[A, B](a: (A, B)): Hello[A] = Ohayo[A, B](a)
}
but instead get:
error: type mismatch;
found : (A(in class Ohayo), B(in class Ohayo))
required: ((some other)A(in class Ohayo), (some other)B(in class Ohayo))
val liftedConstructors = FreeMacros.liftConstructors[Hello]
At a glance, the tree looks ok to me:
scala> q" new { ..$wellTyped }"
res1: u.Tree =
{
final class $anon extends scala.AnyRef {
def <init>() = {
super.<init>();
()
};
def ohayo[A, B](a: (A, B)): net.arya.constructors.Hello[A] = Ohayo[A, B](a)
};
new $anon()
}
but I guess it invisibly isn't. If I naively try to freshen up the typeParams with info.typeParams.map(p => TypeName(p.name.toString)), I get "can't splice A as type parameter" when I do the quasiquoting.
Where am I going wrong? Thanks for taking a look.
-Arya
import scala.language.experimental.macros
import scala.reflect.api.Universe
import scala.reflect.macros.whitebox
// Macro bundle: builds an anonymous class containing the definitions generated for each
// case class found under F.
class FreeMacros(val c: whitebox.Context) {
import c.universe._
import FreeMacros._
def liftedImpl[F[_]](implicit t: c.WeakTypeTag[F[_]]): Tree = {
val atc = t.tpe
// All case-class descendants of F's class symbol.
val childSymbols: Set[ClassSymbol] = subCaseClassSymbols(c.universe)(atc.typeSymbol.asClass)
// NOTE(review): the result of unzip is a pair of collections, yet it is spliced below with ..$ — confirm this is intended.
val wellTyped = childSymbols.map(ctorsForSymbol(c.universe)(atc)).unzip
q"new { ..${wellTyped} }"
}
}
/** Companion of the macro bundle: the macro entry point plus universe-generic helpers. */
object FreeMacros {
  /** Macro entry point; expands via `FreeMacros.liftedImpl`. */
  def liftConstructors[F[_]]: Any = macro FreeMacros.liftedImpl[F]

  /** Lower-cases the first character of `name`; an empty string is returned unchanged. */
  def smartName(name: String): String = (
    name.toList match {
      case h :: t => h.toLower :: t
      case Nil => Nil
    }
  ).mkString

  /**
   * Collects the case classes below `root`: direct case-class subclasses are kept,
   * non-case subclasses are searched recursively.
   */
  def subCaseClassSymbols(u: Universe)(root: u.ClassSymbol): Set[u.ClassSymbol] = {
    val subclasses = root.knownDirectSubclasses
    val cast = subclasses.map(_.asInstanceOf[u.ClassSymbol])
    // Was `mapped.partition(...)`, but no `mapped` exists in scope; `cast` is the set built just above.
    val partitioned = cast.partition(_.isCaseClass)
    partitioned match {
      case (caseClasses, regularClasses) => caseClasses ++ regularClasses.flatMap(r => subCaseClassSymbols(u)(r))
    }
  }

  /**
   * Builds the "smart constructor" definition for `caseClass`: a method named after the
   * class (first letter lower-cased) that takes the primary constructor's parameters and
   * returns the class viewed as its `atc` base type.
   */
  // NOTE(review): the declared result is a pair of DefDefs but the body yields a single
  // definition; left untouched because the caller (liftedImpl) unzips the results — confirm.
  def ctorsForSymbol(u: Universe)(atc: u.Type)(caseClass: u.ClassSymbol): (u.DefDef, u.DefDef) = {
    import u._
    import internal._
    // these didn't help
    // def clearTypeSymbol(s: Symbol): TypeSymbol = internal.newTypeSymbol(NoSymbol, s.name.toTypeName, s.pos, if(s.isImplicit)Flag.IMPLICIT else NoFlags)
    // def clearTypeSymbol2(s: Symbol): TypeSymbol = internal.newTypeSymbol(NoSymbol, s.name.toTypeName, NoPosition, if(s.isImplicit)Flag.IMPLICIT else NoFlags)
    // def clearTypeDef(d: TypeDef): TypeDef = internal.typeDef(clearTypeSymbol(d.symbol))
    val companionSymbol: Symbol = caseClass.companion
    val info: Type = caseClass.info
    val primaryCtor: Symbol = caseClass.primaryConstructor
    val method = primaryCtor.asMethod
    // Type parameters are rebuilt from the class's own symbols.
    val typeParams = info.typeParams.map(internal.typeDef(_))
    // val typeParams = info.typeParams.map(s => typeDef(newTypeSymbol(NoSymbol, s.name.toTypeName, NoPosition, NoFlags)))
    // val typeParams = info.typeParams.map(s => internal.typeDef(clearTypeSymbol2(s)))
    val typeArgs = info.typeParams.map(_.name)
    val paramLists = method.paramLists.map(_.map(internal.valDef(_)))
    val argLists = method.paramLists.map(_.map(_.asTerm.name))
    val baseType = info.baseType(atc.typeSymbol)
    val List(returnType) = baseType.typeArgs
    val methodName = TermName(smartName(caseClass.name.toString))
    val wellTyped =
      q"""
def $methodName[..$typeParams](...$paramLists): $baseType =
$companionSymbol[..$typeArgs](...$argLists)
"""
    wellTyped
  }
}
P.S. I have been experimenting with toolbox.untypecheck / typecheck per this article but haven't found a working combination.
You need to use
clas.typeArgs.map(_.toString).map(name => {
TypeDef(Modifiers(Flag.PARAM),TypeName(name), List(),TypeBoundsTree(EmptyTree, EmptyTree))
}
to replace
info.typeParams.map(p => TypeName(p.name.toString))
Here is my code:
// Macro entry point; expands via GetSealedSubClassImpl.ol3.
object GetSealedSubClass {
def ol3[T]: Any = macro GetSealedSubClassImpl.ol3[T]
}
// Macro bundle that generates an object `Term` with one `smartConstructors` method per
// known direct subclass of T.
class GetSealedSubClassImpl(val c: Context) {
import c.universe._
// Emits `s` as a compiler info message, with a header line and a "|" margin on each line.
def showInfo(s: String) =
c.info(c.enclosingPosition, s.split("\n").mkString("\n |---macro info---\n |", "\n |", ""), true)
def ol3[T: c.WeakTypeTag]: c.universe.Tree = {
// Collect the types of all known direct subclasses of T.
val subClass = c.weakTypeOf[T]
.typeSymbol.asClass.knownDirectSubclasses
.map(e => e.asClass.toType)
// No known subclasses: T is treated as not being a sealed class, and expansion aborts.
if (subClass.size < 1)
c.abort(c.enclosingPosition, s"${c.weakTypeOf[T]} is not a sealed class")
// Build the parameter definitions of each subclass constructor.
val subConstructorParams = subClass.map { e =>
// Pick a constructor.
e.members.filter(_.isConstructor)
// If the class has several constructors, the primary one would have to be selected here.
.head.map(s => s.asMethod)
// Take the first parameter list.
}.map(_.asMethod.paramLists.head)
.map(_.map(e => q"""${e.name.toTermName}:${e.info} """))
// One method per subclass; type parameters are rebuilt from plain names as fresh TypeDefs.
val outfunc = subClass zip subConstructorParams map {
case (clas, parm) =>
q"def smartConstructors[..${
clas.typeArgs.map(_.toString).map(name => {
TypeDef(Modifiers(Flag.PARAM), TypeName(name), List(), TypeBoundsTree(EmptyTree, EmptyTree))
})
}](..${parm})=${clas.typeSymbol.name.toTermName} (..${parm})"
}
val outClass =
q"""
object Term{
..${outfunc}
}
"""
showInfo(show(outClass))
// The expansion declares the object and evaluates to it.
q"""{
$outClass
Term
}
"""
}
}
Use it like this:
// Sealed base type and its only case class, as in the question.
sealed trait Hello[A]
case class Ohayo[A, B](a: (A, B)) extends Hello[A]
// Demo: expands the macro, casts the Any result to a structural type exposing the
// generated method, calls it and prints the resulting field.
object GetSealed extends App {
val a = GetSealedSubClass.ol3[Hello[_]]
val b=a.asInstanceOf[ {def smartConstructors[A, B](a: (A, B)): Ohayo[A, B]}].smartConstructors(1, 2).a
println(b)
}
I have a case class with annotated fields, like this:
// Annotations are written with '@'; the '#' in the original text was a garbled '@'.
case class Foo(@alias("foo") bar: Int)
I have a macro that processes the declaration of this class:
// Destructure the case-class declaration into name, constructor modifiers, fields, parents and body.
val (className, access, fields, bases, body) = classDecl match {
case q"case class $n $m(..$ps) extends ..$bs { ..$ss }" => (n, m, ps, bs, ss)
case _ => abort
}
Later, I search for the aliased fields, as follows:
// For every field, collect (alias string, field name) pairs from its `alias` annotations.
val aliases = fields.asInstanceOf[List[ValDef]].flatMap {
field => field.symbol.annotations.collect {
//deprecated version:
//case annotation if annotation.tpe <:< cv.weakTypeOf[alias] =>
case annotation if annotation.tree.tpe <:< c.weakTypeOf[alias] =>
//deprecated version:
//annotation.scalaArgs.head match {
// The first argument of the annotation is expected to be a string literal.
annotation.tree.children.tail.head match {
case Literal(Constant(param: String)) => (param, field.name)
}
}
}
However, the list of aliases ends up being empty. I have determined that field.symbol.annotations.size is, in fact, 0, despite the annotation clearly sitting on the field.
Any idea of what's wrong?
EDIT
Answering the first two comments:
(1) I tried mods.annotations, but that didn't work. That actually returns List[Tree] instead of List[Annotation], returned by symbol.annotations. Perhaps I didn't modify the code correctly, but the immediate effect was an exception during macro expansion. I'll try to play with it some more.
(2) The class declaration is grabbed while processing an annotation macro slapped on the case class.
The complete code follows. The usage is illustrated in the test code further below.
package com.xxx.util.macros
import scala.collection.immutable.HashMap
import scala.language.experimental.macros
import scala.annotation.StaticAnnotation
import scala.reflect.macros.whitebox
/**
 * Mix-in giving a class a private string-keyed store of values.
 *
 * Entries are added from inside the class via the protected `+=` and read from outside via `$`.
 */
trait Mapped {
  // `HashMap.empty` rather than `new HashMap`: the constructor is not accessible on Scala 2.13.
  private var _vals: HashMap[String, Any] = HashMap.empty[String, Any]

  /** Value stored under `key`, if any. */
  def $(key: String): Option[Any] = _vals.get(key)

  /** Stores `value` under `key`, replacing any previous entry. */
  protected def +=(key: String, value: Any): Unit = {
    _vals = _vals.updated(key, value)
  }
}
// Field-level annotation naming the key under which the field should be registered.
class alias(val key: String) extends StaticAnnotation
// Class-level macro annotation; expansion is delegated to aliasedMacro.impl.
class aliased extends StaticAnnotation {
def macroTransform(annottees: Any*): Any = macro aliasedMacro.impl
}
/**
 * Implementation of the `@aliased` macro annotation.
 *
 * Rebuilds the annotated case class with one extra `this += (alias, field)` statement
 * per field that carries an `alias` annotation.
 */
object aliasedMacro {
  def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._
    // Accept the class alone, or the class followed by its companion object.
    val (classDecl, compDecl) = annottees.map(_.tree) match {
      case (clazz: ClassDef) :: Nil => (clazz, None)
      case (clazz: ClassDef) :: (comp: ModuleDef) :: Nil => (clazz, Some(comp))
      // The message used to read "#aliased"; the '#' was a garbled '@'.
      case _ => abort(c, "@aliased must annotate a class")
    }
    // Split the declaration into name, constructor modifiers, fields, parents and body.
    val (className, access, fields, bases, body) = classDecl match {
      case q"case class $n $m(..$ps) extends ..$bs { ..$ss }" => (n, m, ps, bs, ss)
      case _ => abort(c, "@aliased is only supported on case class")
    }
    // One registration statement per alias annotation found on a field.
    // NOTE(review): the surrounding discussion reports that field.symbol.annotations is
    // empty at this point; the follow-up in the text reads field.mods.annotations instead.
    val mappings = fields.asInstanceOf[List[ValDef]].flatMap {
      field => field.symbol.annotations.collect {
        case annotation if annotation.tree.tpe <:< c.weakTypeOf[alias] =>
          annotation.tree.children.tail.head match {
            case Literal(Constant(param: String)) =>
              q"""this += ($param, ${field.name})"""
          }
      }
    }
    val classCode = q"""
case class $className $access(..$fields) extends ..$bases {
..$body; ..$mappings
}"""
    // Emit the companion (when there is one) ahead of the rebuilt class.
    c.Expr(compDecl match {
      case Some(compCode) => q"""$compCode; $classCode"""
      case None => q"""$classCode"""
    })
  }

  /** Aborts macro expansion with `message` at the annotation site. */
  protected def abort(c: whitebox.Context, message: String) =
    c.abort(c.enclosingPosition, message)
}
The test code:
package test.xxx.util.macros
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
import com.xxx.util.macros._
// Test fixture: two aliased fields (one under its own name, one under "BAR") and one plain field.
// Annotations are written with '@'; the '#' in the original text was a garbled '@'.
@aliased
case class Foo(@alias("foo") foo: Int,
               @alias("BAR") bar: String,
               baz: String) extends Mapped
// Checks that aliased fields can be looked up by alias and unaliased ones cannot.
// '@RunWith' replaces the garbled '#RunWith'.
@RunWith(classOf[JUnitRunner])
class MappedTest extends FunSuite {
  val foo = 13
  val bar = "test"
  val obj = Foo(foo, bar, "extra")

  test("field aliased with its own name") {
    assertResult(Some(foo))(obj $ "foo")
  }

  test("field aliased with another string") {
    assertResult(Some(bar))(obj $ "BAR")
    // The field's own name is not a key when a different alias was given.
    assertResult(None)(obj $ "bar")
  }

  test("unaliased field") {
    assertResult(None)(obj $ "baz")
  }
}
Thanks for the suggestions! In the end, using field.mods.annotations did help. This is how:
// Reads the annotations from the field's modifiers (untyped trees) and matches the literal
// shape `new alias("...")`, yielding one registration statement per match.
val mappings = fields.asInstanceOf[List[ValDef]].flatMap {
field => field.mods.annotations.collect {
case Apply(Select(New(Ident(TypeName("alias"))), termNames.CONSTRUCTOR),
List(Literal(Constant(param: String)))) =>
q"""this += ($param, ${field.name})"""
}
}
I want to define a method, which will return a Future. And in this method, it will call another service which returns also a Future.
We have defined a BusinessResult to represent Success and Fail:
/** Minimal success/failure container with `map` and `flatMap`. */
object validation {

  /** Outcome of a business operation: a value of type `V` or a failure. */
  trait BusinessResult[+V] {
    def flatMap[T](f: V => BusinessResult[T]): BusinessResult[T]
    def map[T](f: V => T): BusinessResult[T]
  }

  /** Successful outcome holding the value `t`. */
  sealed case class Success[V](t: V) extends BusinessResult[V] {
    def flatMap[T](f: V => BusinessResult[T]): BusinessResult[T] = f(t)

    def map[T](f: V => T): BusinessResult[T] = flatMap[T](v => Success(f(v)))
  }

  /** Failed outcome carrying the message `e`; both combinators return the failure itself. */
  sealed case class Fail(e: String) extends BusinessResult[Nothing] {
    def flatMap[T](f: Nothing => BusinessResult[T]): BusinessResult[T] = this

    def map[T](f: Nothing => T): BusinessResult[T] = this
  }
}
And define the method:
import scala.concurrent._
import scala.concurrent.ExecutionContext.Implicits.global
import validation._
// Always succeeds with a fixed name.
def name: BusinessResult[String] = Success("my name")
// Stand-in for a remote call: returns its argument in a Future.
def externalService(name:String):Future[String] = future(name)
// First attempt: mixes a BusinessResult generator with a Future generator in one
// for-comprehension; this is the version reported below as not compiling.
def myservice:Future[Int] = {
for {
n <- name
res <- externalService(n)
} yield res match {
case "ok" => 1
case _ => 0
}
}
But this does not compile: the code in myservice cannot produce a Future[Int].
I also tried to wrap the name with Future:
// Second attempt: wraps the BusinessResult in a Future first, but the `n <- nn` generator
// is still a BusinessResult inside a Future comprehension; also reported as not compiling.
def myservice:Future[Int] = {
for {
nn <- Future.successful(name)
n <- nn
res <- externalService(n)
} yield res match {
case "ok" => 1
case _ => 0
}
}
This does not compile either.
I know there must be a lot of issues in this code. How can I adjust them to make it compilable?
If you replace n with some hardcoded string it works. The problem is that in the for comprehension the generator n <- name works on a BusinessResult[String]. As you probably already know, a for comprehension desugars to map, flatMap and filter, so the first part, n <- name, desugars to a call on name:
val test: BusinessResult[String] = name.map(x => x)
IntelliJ thinks n is a String, but the Scala compiler disagrees:
type mismatch;
[error] found : validation.BusinessResult[Nothing]
[error] required: scala.concurrent.Future[Int]
[error] n <- name
[error] ^
Easy solution could be to add a getter method to get back the string and do something like Option does:
/** Success/failure container with `map`, `flatMap` and an unsafe `getVal` accessor. */
object validation {

  /** Outcome of a business operation: a value of type `V` or a failure. */
  trait BusinessResult[+V] {
    def flatMap[T](f: V => BusinessResult[T]): BusinessResult[T]
    def map[T](f: V => T): BusinessResult[T]
    /** The contained value; throws when there is none. */
    def getVal: V
  }

  /** Successful outcome holding the value `t`. */
  sealed case class Success[V](t: V) extends BusinessResult[V] {
    def getVal: V = t

    def flatMap[T](f: V => BusinessResult[T]): BusinessResult[T] = f(t)

    def map[T](f: V => T): BusinessResult[T] = flatMap[T](v => Success(f(v)))
  }

  /** Failed outcome carrying the message `e`; `getVal` always throws. */
  sealed case class Fail(e: String) extends BusinessResult[Nothing] {
    def getVal = throw new Exception("some message")

    def flatMap[T](f: Nothing => BusinessResult[T]): BusinessResult[T] = this

    def map[T](f: Nothing => T): BusinessResult[T] = this
  }
}
/**
 * Extracts the name eagerly (throwing if it is a failure), calls the external service
 * with it, and maps the reply "ok" to 1 and anything else to 0.
 */
def myservice: Future[Int] =
  externalService(name.getVal).map {
    case "ok" => 1
    case _ => 0
  }
Note that you can't extract the name inside the for comprehension, since map on a String returns Chars.
I'm building a web-application using Play and Slick, and find myself in a situation where the user-facing forms are similar, but not exactly the same as the database model.
Hence I have two very similar case classes, and need to map from one to another (e.g. while filling the form for rendering an "update" view).
In the case I'm interested in, the database model case class is a super-set of the form case-class, i.e. the only difference between both is that the database model has two more fields (two identifiers, basically).
What I'm now wondering about is whether there'd be a way to build a small library (e.g. macro-driven) to automatically populate the form case class from the database case class based on the member names. I've seen that it may be possible to access this kind of information via reflection using Paranamer, but I'd rather not venture into this.
Here is a solution using Dynamic because I wanted to try it out. A macro would decide statically whether to emit an apply of a source value method, the default value method, or just to supply a literal. The syntax could look something like newFrom[C](k). (Update: see below for the macro.)
import scala.language.dynamics
// Dynamic mix-in: `obj.xxxFrom(source)` invokes method `xxx` on obj, filling its parameters
// by name from the case accessors of `source`, then from default arguments, then from a zero/null.
trait Invocable extends Dynamic {
import scala.reflect.runtime.currentMirror
import scala.reflect.runtime.universe._
def applyDynamic(method: String)(source: Any) = {
require(method endsWith "From")
def caseMethod(s: Symbol) = s.asTerm.isCaseAccessor && s.asTerm.isMethod
// Read every case accessor of `source` reflectively into (name, value) pairs.
val sm = currentMirror reflect source
val ms = sm.symbol.asClass.typeSignature.members filter caseMethod map (_.asMethod)
val values = ms map (m => (m.name, (sm reflectMethod m)()))
val im = currentMirror reflect this
// Strip the "From" suffix to get the target method name.
invokeWith(im, method dropRight 4, values.toMap)
}
def invokeWith(im: InstanceMirror, name: String, values: Map[Name, Any]): Any = {
val at = TermName(name)
val ts = im.symbol.typeSignature
val method = (ts member at).asMethod
// supplied value or defarg or default val for type of p
def valueFor(p: Symbol, i: Int): Any = {
if (values contains p.name) values(p.name)
// Otherwise look for a member named `<name>$default$<i+1>` holding the default argument.
else ts member TermName(s"$name$$default$$${i+1}") match {
case NoSymbol =>
// No default argument: 0 for Int, 0.0 for Double, ??? for other primitives, null for references.
if (p.typeSignature.typeSymbol.asClass.isPrimitive) {
if (p.typeSignature <:< typeOf[Int]) 0
else if (p.typeSignature <:< typeOf[Double]) 0.0
else ???
} else null
case defarg => (im reflectMethod defarg.asMethod)()
}
}
// Flatten all parameter lists, resolve each parameter, and invoke.
val args = (for (ps <- method.paramss; p <- ps) yield p).zipWithIndex map (p => valueFor(p._1,p._2))
(im reflectMethod method)(args: _*)
}
}
// Target type: `c` has a default argument, `d` has none.
case class C(a: String, b: Int, c: Double = 2.0, d: Double)
// Source type: shares the field names `a` and `b` with C.
case class K(b: Int, e: String, a: String)
// The companion mixes in Invocable so that `C applyFrom k` resolves to C.apply.
object C extends Invocable
// Demo: builds a C from a K by matching field names.
object Test extends App {
val res = C applyFrom K(8, "oh", "kay")
Console println res // C(kay,8,2.0,0.0)
}
Update: Here is the macro version, more for fun than for profit:
import scala.language.experimental.macros
import scala.reflect.macros._
import scala.collection.mutable.ListBuffer
// Macro entry point: builds a B from `source` by matching constructor parameter names.
def newFrom[A, B](source: A): B = macro newFrom_[A, B]
// Macro implementation: each constructor parameter of B is taken from the same-named case
// accessor of A, else from the companion's default-argument method, else 0 / 0.0 / null.
def newFrom_[A: c.WeakTypeTag, B: c.WeakTypeTag](c: Context)(source: c.Expr[A]): c.Expr[B] = {
import c.{ literal, literalNull }
import c.universe._
import treeBuild._
import nme.{ CONSTRUCTOR => Ctor }
def caseMethod(s: Symbol) = s.asTerm.isCaseAccessor && s.asTerm.isMethod
// Name of the synthetic default-argument method for parameter i of `name`.
def defaulter(name: Name, i: Int): String = s"${name.encoded}$$default$$${i+1}"
val noargs = List[c.Tree]()
// side effects: first evaluate the arg
val side = ListBuffer[c.Tree]()
val src = TermName(c freshName "src$")
side += ValDef(Modifiers(), src, TypeTree(source.tree.tpe), source.tree)
// take the arg as instance of a case class and use the case members
val a = implicitly[c.WeakTypeTag[A]].tpe
val srcs = (a.members filter caseMethod map (m => (m.name, m.asMethod))).toMap
// construct the target, using src fields, defaults (from the companion), or zero
val b = implicitly[c.WeakTypeTag[B]].tpe
val bm = b.typeSymbol.asClass.companionSymbol.asModule
val bc = bm.moduleClass.asClass.typeSignature
val ps = (b declaration Ctor).asMethod.paramss.flatten.zipWithIndex
val args: List[c.Tree] = ps map { case (p, i) =>
if (srcs contains p.name)
Select(Ident(src), p.name)
else bc member TermName(defaulter(Ctor, i)) match {
case NoSymbol =>
// No default argument: 0 for Int, 0.0 for Double, ??? for other primitives, null for references.
if (p.typeSignature.typeSymbol.asClass.isPrimitive) {
if (p.typeSignature <:< typeOf[Int]) literal(0).tree
else if (p.typeSignature <:< typeOf[Double]) literal(0.0).tree
else ???
} else literalNull.tree
case defarg => Select(mkAttributedRef(bm), defarg.name)
}
}
// Result: { val src$N = <source>; new B(<args>) }
c.Expr(Block(side.toList, Apply(Select(New(mkAttributedIdent(b.typeSymbol)), Ctor), args)))
}
With usage:
// Target type: `c` has a default argument, `d` has none.
case class C(a: String, b: Int, c: Double = 2.0, d: Double)
// Source type: `i()` is an ordinary method, not a case accessor.
case class K(b: Int, e: String, a: String) { def i() = b }
// Builds a C from a K at compile time via the macro.
val res = newFrom[K, C](K(8, "oh", "kay"))