Chisel: How to change module parameters from command line? - scala

I have many modules with multiple parameters. Take as a toy example a modified version of the GCD in the template:
class GCD (len: Int = 16, validHigh: Boolean = true) extends Module {
val io = IO(new Bundle {
val value1 = Input(UInt(len.W))
val value2 = Input(UInt(len.W))
val loadingValues = Input(Bool())
val outputGCD = Output(UInt(len.W))
val outputValid = Output(Bool())
})
val x = Reg(UInt())
val y = Reg(UInt())
when(x > y) { x := x - y }
.otherwise { y := y - x }
when(io.loadingValues) {
x := io.value1
y := io.value2
}
io.outputGCD := x
if (validHigh) {
io.outputValid := (y === 0.U)
} else {
io.outputValid := (y =/= 0.U)
}
}
To test or synthesize many different designs, I want to change the values from the command line when I call the tester or the generator apps. Preferably, like this:
[generation or test command] --len 12 --validHigh false
but this or something similar would also be okay
[generation or test command] --param "len=12" --param "validHigh=false"
After some trial and error, I came up with a solution that looks like this:
gcd.scala
package gcd
import firrtl._
import chisel3._
case class GCDConfig(
len: Int = 16,
validHigh: Boolean = true
)
class GCD (val conf: GCDConfig = GCDConfig()) extends Module {
val io = IO(new Bundle {
val value1 = Input(UInt(conf.len.W))
val value2 = Input(UInt(conf.len.W))
val loadingValues = Input(Bool())
val outputGCD = Output(UInt(conf.len.W))
val outputValid = Output(Bool())
})
val x = Reg(UInt())
val y = Reg(UInt())
when(x > y) { x := x - y }
.otherwise { y := y - x }
when(io.loadingValues) {
x := io.value1
y := io.value2
}
io.outputGCD := x
if (conf.validHigh) {
io.outputValid := y === 0.U
} else {
io.outputValid := y =/= 0.U
}
}
trait HasParams {
self: ExecutionOptionsManager =>
var params: Map[String, String] = Map()
parser.note("Design Parameters")
parser.opt[Map[String, String]]('p', "params")
.valueName("k1=v1,k2=v2")
.foreach { v => params = v }
.text("Parameters of Design")
}
object GCD {
def apply(params: Map[String, String]): GCD = {
new GCD(params2conf(params))
}
def params2conf(params: Map[String, String]): GCDConfig = {
var conf = new GCDConfig
for ((k, v) <- params) {
(k, v) match {
case ("len", _) => conf = conf.copy(len = v.toInt)
case ("validHigh", _) => conf = conf.copy(validHigh = v.toBoolean)
case _ =>
}
}
conf
}
}
object GCDGen extends App {
val optionsManager = new ExecutionOptionsManager("gcdgen")
with HasChiselExecutionOptions with HasFirrtlOptions with HasParams
optionsManager.parse(args) match {
case true =>
chisel3.Driver.execute(optionsManager, () => GCD(optionsManager.params))
case _ =>
ChiselExecutionFailure("could not parse results")
}
}
and for tests
GCDSpec.scala
package gcd
import chisel3._
import firrtl._
import chisel3.tester._
import org.scalatest.FreeSpec
import chisel3.experimental.BundleLiterals._
import chiseltest.internal._
import chiseltest.experimental.TestOptionBuilder._
object GCDTest extends App {
val optionsManager = new ExecutionOptionsManager("gcdtest") with HasParams
optionsManager.parse(args) match {
case true =>
//println(optionsManager.commonOptions.programArgs)
(new GCDSpec(optionsManager.params)).execute()
case _ =>
ChiselExecutionFailure("could not parse results")
}
}
class GCDSpec(params: Map[String, String] = Map()) extends FreeSpec with ChiselScalatestTester {
"Gcd should calculate proper greatest common denominator" in {
test(GCD(params)) { dut =>
dut.io.value1.poke(95.U)
dut.io.value2.poke(10.U)
dut.io.loadingValues.poke(true.B)
dut.clock.step(1)
dut.io.loadingValues.poke(false.B)
while (dut.io.outputValid.peek().litToBoolean != dut.conf.validHigh) {
dut.clock.step(1)
}
dut.io.outputGCD.expect(5.U)
}
}
}
This way, I can generate different designs and test them with
sbt 'runMain gcd.GCDGen --params "len=12,validHigh=false"'
sbt 'test:runMain gcd.GCDTest --params "len=12,validHigh=false"'
But there are a couple of problems or annoyances with this solution:
It uses deprecated features (ExecutionOptionsManager and HasFirrtlOptions). I'm not sure whether this solution can be ported to the new FirrtlStage infrastructure.
There is a lot of boilerplate involved. It becomes tedious to write new case classes and params2conf functions for every module, and to rewrite both whenever a parameter is added or removed.
Having to write conf.x instead of x everywhere. But I guess this is unavoidable, because Scala has nothing like Python's kwargs.
Is there a better way or one that is at least not deprecated?

Good question.
I think you have pretty much everything right. I don't usually find that I need the command line to alter my tests; my development cycle is usually just to poke values into the test params directly and run. I use IntelliJ, which seems to make that easy (but that may only suit my habits and the scale of projects I work on).
But I would like to offer you a suggestion that will get you away from the ExecutionOptions style, as that is going away fast.
My example code below consists of two files, inline. The first contains a few library-like tools that use the modern annotation idioms and, I believe, minimize boilerplate. They rely on string-based lookups, but that is fixable.
The second contains your GCD and GCDSpec, slightly modified to pull out the params a bit differently. At the bottom of the second file is some very minimal boilerplate that gives you the command-line access you want.
Good luck, I hope this is mostly self-explanatory.
First file:
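// The second file imports GenericTesterStage and TestParams from chiseltest, so this file lives in that package
package chiseltest
import chisel3.stage.ChiselCli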
import firrtl.AnnotationSeq
import firrtl.annotations.{Annotation, NoTargetAnnotation}
import firrtl.options.{HasShellOptions, Shell, ShellOption, Stage, Unserializable}
import firrtl.stage.FirrtlCli
trait TesterAnnotation {
this: Annotation =>
}
case class TestParams(params: Map[String, String] = Map.empty) {
val defaults: collection.mutable.HashMap[String, String] = new collection.mutable.HashMap()
def getInt(key: String): Int = params.getOrElse(key, defaults(key)).toInt
def getBoolean(key: String): Boolean = params.getOrElse(key, defaults(key)).toBoolean
def getString(key: String): String = params.getOrElse(key, defaults(key))
}
case class TesterParameterAnnotation(paramString: TestParams)
extends TesterAnnotation
with NoTargetAnnotation
with Unserializable
object TesterParameterAnnotation extends HasShellOptions {
val options = Seq(
new ShellOption[Map[String, String]](
longOption = "param-string",
toAnnotationSeq = (a: Map[String, String]) => Seq(TesterParameterAnnotation(TestParams(a))),
helpText = """a comma-separated, space-free list of additional parameters, e.g. --param-string "k1=7,k2=dog" """
)
)
}
trait TesterCli {
this: Shell =>
Seq(TesterParameterAnnotation).foreach(_.addOptions(parser))
}
class GenericTesterStage(thunk: (TestParams, AnnotationSeq) => Unit) extends Stage {
val shell: Shell = new Shell("chiseltest") with TesterCli with ChiselCli with FirrtlCli
def run(annotations: AnnotationSeq): AnnotationSeq = {
val params = annotations.collectFirst { case TesterParameterAnnotation(p) => p }.getOrElse(TestParams())
thunk(params, annotations)
annotations
}
}
Second file:
import chisel3._
import chisel3.tester._
import chiseltest.experimental.TestOptionBuilder._
import chiseltest.{ChiselScalatestTester, GenericTesterStage, TestParams}
import firrtl._
import firrtl.options.StageMain
import org.scalatest.freespec.AnyFreeSpec
case class GCD(testParams: TestParams) extends Module {
val bitWidth = testParams.getInt("len")
val validHigh = testParams.getBoolean("validHigh")
val io = IO(new Bundle {
val value1 = Input(UInt(bitWidth.W))
val value2 = Input(UInt(bitWidth.W))
val loadingValues = Input(Bool())
val outputGCD = Output(UInt(bitWidth.W))
val outputValid = Output(Bool())
})
val x = Reg(UInt())
val y = Reg(UInt())
when(x > y) { x := x - y }.otherwise { y := y - x }
when(io.loadingValues) {
x := io.value1
y := io.value2
}
io.outputGCD := x
if (validHigh) {
io.outputValid := y === 0.U
} else {
io.outputValid := y =/= 0.U
}
}
class GCDSpec(params: TestParams, annotations: AnnotationSeq = Seq()) extends AnyFreeSpec with ChiselScalatestTester {
"Gcd should calculate proper greatest common denominator" in {
test(GCD(params)).withAnnotations(annotations) { dut =>
dut.io.value1.poke(95.U)
dut.io.value2.poke(10.U)
dut.io.loadingValues.poke(true.B)
dut.clock.step(1)
dut.io.loadingValues.poke(false.B)
while (dut.io.outputValid.peek().litToBoolean != dut.validHigh) {
dut.clock.step(1)
}
dut.io.outputGCD.expect(5.U)
}
}
}
class GcdTesterStage
extends GenericTesterStage((params, annotations) => {
params.defaults ++= Seq("len" -> "16", "validHigh" -> "false")
(new GCDSpec(params, annotations)).execute()
})
object GcdTesterStage extends StageMain(new GcdTesterStage)
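The answer notes that the string-based lookups are fixable. One possible direction, sketched here as an assumption and not part of chiseltest, is an immutable variant of TestParams: defaults and command-line values are plain Maps merged once, conversion failures name the offending key, and keys that have no default (likely typos) can be reported. TypedParams, typed and unknownKeys are hypothetical names.

```scala
import scala.util.Try

// Hypothetical immutable variant of TestParams (not a chiseltest API):
// defaults and command-line overrides are both plain Maps, merged once.
final case class TypedParams(defaults: Map[String, String], overrides: Map[String, String] = Map.empty) {
  private val merged: Map[String, String] = defaults ++ overrides

  private def raw(key: String): String =
    merged.getOrElse(key, throw new NoSuchElementException(s"no value or default for parameter '$key'"))

  // Convert the string value, reporting the key and value if conversion fails.
  private def typed[A](key: String, typeName: String)(convert: String => A): A = {
    val value = raw(key)
    Try(convert(value)).getOrElse(
      throw new IllegalArgumentException(s"parameter '$key' = '$value' is not a valid $typeName"))
  }

  def getString(key: String): String = raw(key)
  def getInt(key: String): Int = typed(key, "Int")(_.toInt)
  def getBoolean(key: String): Boolean = typed(key, "Boolean")(_.toBoolean)

  // Keys given on the command line that have no default: most likely typos.
  def unknownKeys: Set[String] = overrides.keySet -- defaults.keySet
}
```

In GcdTesterStage one could then construct TypedParams from the defaults and the --param-string Map instead of mutating params.defaults, and fail early when unknownKeys is non-empty.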

Based on http://blog.echo.sh/2013/11/04/exploring-scala-macros-map-to-case-class-conversion.html, I found another way to remove the params2conf boilerplate, using Scala macros. I also extended Chick's answer with Verilog generation, since that was also part of the original question. A full repository of my solution can be found on GitHub.
Basically, there are four files:
The macro that converts a map to a case class:
package mappable
import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context
trait Mappable[T] {
def toMap(t: T): Map[String, String]
def fromMap(map: Map[String, String]): T
}
object Mappable {
implicit def materializeMappable[T]: Mappable[T] = macro materializeMappableImpl[T]
def materializeMappableImpl[T: c.WeakTypeTag](c: Context): c.Expr[Mappable[T]] = {
import c.universe._
val tpe = weakTypeOf[T]
val companion = tpe.typeSymbol.companion
val fields = tpe.decls.collectFirst {
case m: MethodSymbol if m.isPrimaryConstructor => m
}.get.paramLists.head
val (toMapParams, fromMapParams) = fields.map { field =>
val name = field.name.toTermName
val decoded = name.decodedName.toString
val returnType = tpe.decl(name).typeSignature
val fromMapLine = returnType match {
case NullaryMethodType(res) if res =:= typeOf[Int] => q"map($decoded).toInt"
case NullaryMethodType(res) if res =:= typeOf[String] => q"map($decoded)"
case NullaryMethodType(res) if res =:= typeOf[Boolean] => q"map($decoded).toBoolean"
case _ => c.abort(c.enclosingPosition, s"Mappable: unsupported type for field $decoded")
}
(q"$decoded -> t.$name.toString", fromMapLine)
}.unzip
c.Expr[Mappable[T]] { q"""
new Mappable[$tpe] {
def toMap(t: $tpe): Map[String, String] = Map(..$toMapParams)
def fromMap(map: Map[String, String]): $tpe = $companion(..$fromMapParams)
}
""" }
}
}
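To make the macro's output concrete, here is roughly what it should expand to for GCDConfig, written by hand. This is a sketch: the Mappable trait is redeclared so the snippet compiles on its own, and GCDConfigMappable and withOverrides are illustrative names. It also shows why GenericParameterCliStage merges mapify(default) with the CLI map: the generated fromMap looks up every field, so a partial map from --params alone would throw.

```scala
// Redeclared so the sketch compiles on its own; same shape as mappable.Mappable.
trait Mappable[T] {
  def toMap(t: T): Map[String, String]
  def fromMap(map: Map[String, String]): T
}

case class GCDConfig(len: Int = 16, validHigh: Boolean = true)

// Hand-written equivalent of materializeMappable[GCDConfig]: one
// "name -> value.toString" entry per constructor field, and a companion
// apply that parses each field back according to its type.
object GCDConfigMappable extends Mappable[GCDConfig] {
  def toMap(t: GCDConfig): Map[String, String] =
    Map("len" -> t.len.toString, "validHigh" -> t.validHigh.toString)

  def fromMap(map: Map[String, String]): GCDConfig =
    GCDConfig(map("len").toInt, map("validHigh").toBoolean)

  // Mirrors GenericParameterCliStage.run: fromMap needs every key, so the
  // defaults are turned into a Map first and the --params values overwrite them.
  def withOverrides(default: GCDConfig, cli: Map[String, String]): GCDConfig =
    fromMap(toMap(default) ++ cli)
}
```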
Library like tools:
package cliparams
import chisel3.stage.{ChiselStage, ChiselGeneratorAnnotation, ChiselCli}
import firrtl.AnnotationSeq
import firrtl.annotations.{Annotation, NoTargetAnnotation}
import firrtl.options.{HasShellOptions, Shell, ShellOption, Stage, Unserializable, StageMain}
import firrtl.stage.FirrtlCli
import mappable._
trait SomeAnnotaion {
this: Annotation =>
}
case class ParameterAnnotation(map: Map[String, String])
extends SomeAnnotaion
with NoTargetAnnotation
with Unserializable
object ParameterAnnotation extends HasShellOptions {
val options = Seq(
new ShellOption[Map[String, String]](
longOption = "params",
toAnnotationSeq = (a: Map[String, String]) => Seq(ParameterAnnotation(a)),
helpText = """a comma-separated, space-free list of additional parameters, e.g. --params "k1=7,k2=dog" """
)
)
}
trait ParameterCli {
this: Shell =>
Seq(ParameterAnnotation).foreach(_.addOptions(parser))
}
class GenericParameterCliStage[P: Mappable](thunk: (P, AnnotationSeq) => Unit, default: P) extends Stage {
def mapify(p: P) = implicitly[Mappable[P]].toMap(p)
def materialize(map: Map[String, String]) = implicitly[Mappable[P]].fromMap(map)
val shell: Shell = new Shell("chiseltest") with ParameterCli with ChiselCli with FirrtlCli
def run(annotations: AnnotationSeq): AnnotationSeq = {
val params = annotations
.collectFirst {case ParameterAnnotation(map) => materialize(mapify(default) ++ map.toSeq)}
.getOrElse(default)
thunk(params, annotations)
annotations
}
}
The GCD source file
// See README.md for license details.
package gcd
import firrtl._
import chisel3._
import chisel3.stage.{ChiselStage, ChiselGeneratorAnnotation}
import firrtl.options.{StageMain}
// Both have to be imported
import mappable._
import cliparams._
case class GCDConfig(
len: Int = 16,
validHigh: Boolean = true
)
/**
* Compute GCD using subtraction method.
* Subtracts the smaller from the larger until register y is zero.
* value in register x is then the GCD
*/
class GCD (val conf: GCDConfig = GCDConfig()) extends Module {
val io = IO(new Bundle {
val value1 = Input(UInt(conf.len.W))
val value2 = Input(UInt(conf.len.W))
val loadingValues = Input(Bool())
val outputGCD = Output(UInt(conf.len.W))
val outputValid = Output(Bool())
})
val x = Reg(UInt())
val y = Reg(UInt())
when(x > y) { x := x - y }
.otherwise { y := y - x }
when(io.loadingValues) {
x := io.value1
y := io.value2
}
io.outputGCD := x
if (conf.validHigh) {
io.outputValid := y === 0.U
} else {
io.outputValid := y =/= 0.U
}
}
class GCDGenStage extends GenericParameterCliStage[GCDConfig]((params, annotations) => {
(new chisel3.stage.ChiselStage).execute(
Array("-X", "verilog"),
Seq(ChiselGeneratorAnnotation(() => new GCD(params))))}, GCDConfig())
object GCDGen extends StageMain(new GCDGenStage)
and the tests
// See README.md for license details.
package gcd
import chisel3._
import firrtl._
import chisel3.tester._
import org.scalatest.FreeSpec
import chisel3.experimental.BundleLiterals._
import chiseltest.internal._
import chiseltest.experimental.TestOptionBuilder._
import firrtl.options.{StageMain}
import mappable._
import cliparams._
class GCDSpec(params: GCDConfig, annotations: AnnotationSeq = Seq()) extends FreeSpec with ChiselScalatestTester {
"Gcd should calculate proper greatest common denominator" in {
test(new GCD(params)).withAnnotations(annotations) { dut =>
dut.io.value1.poke(95.U)
dut.io.value2.poke(10.U)
dut.io.loadingValues.poke(true.B)
dut.clock.step(1)
dut.io.loadingValues.poke(false.B)
while (dut.io.outputValid.peek().litToBoolean != dut.conf.validHigh) {
dut.clock.step(1)
}
dut.io.outputGCD.expect(5.U)
}
}
}
class GCDTestStage extends GenericParameterCliStage[GCDConfig]((params, annotations) => {
(new GCDSpec(params, annotations)).execute()}, GCDConfig())
object GCDTest extends StageMain(new GCDTestStage)
Both generation and tests can be parameterized via the CLI, as in the original question:
sbt 'runMain gcd.GCDGen --params "len=12,validHigh=false"'
sbt 'test:runMain gcd.GCDTest --params "len=12,validHigh=false"'
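On the conf.x annoyance from the question: since conf is a val, importing its members inside the class body lets the fields be used unqualified, and named arguments with defaults on the config case class give kwargs-like call sites. A plain-Scala sketch, where GcdParams is a hypothetical stand-in for the Chisel module and the config field names are assumed not to clash with other names used in the module:

```scala
case class GCDConfig(len: Int = 16, validHigh: Boolean = true)

// Hypothetical stand-in for the GCD module body; only parameter access is shown.
class GcdParams(val conf: GCDConfig = GCDConfig()) {
  import conf._ // len and validHigh are now usable without the "conf." prefix

  val width: Int = len
  // Mirrors the validHigh branch of the GCD: valid when y is zero, or inverted.
  def outputValid(yIsZero: Boolean): Boolean = if (validHigh) yIsZero else !yIsZero
}
```

Call sites such as GCDConfig(len = 12) override only the named field and keep the other defaults, which is close to what the question wanted from kwargs.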

Related

Can you implement dsinfo in Scala 3? (Can Scala 3 macros get info about their context?)

The dsinfo library lets you access the names of values from the context of where a function is written using Scala 2 macros. The example they give is that if you have something like
val name = myFunction(x, y)
myFunction will actually be passed the name of its val in addition to the other arguments, i.e., myFunction("name", x, y).
This is very useful for DSLs where you'd like named values for error reporting or other kinds of encoding. The only other option seems to be passing the name explicitly as a String, which can lead to unintentional mismatches.
Is this possible with Scala 3 macros, and if so, how do you "climb up" the tree at the macro's use location to find its id?
In Scala 3 there is no c.macroApplication, only Position.ofMacroExpansion, which gives a position rather than a tree. But we can analyze Symbol.spliceOwner.maybeOwner. This assumes that scalacOptions += "-Yretain-trees" is enabled.
import scala.annotation.experimental
import scala.quoted.*
object Macro {
inline def makeCallWithName[T](inline methodName: String): T =
${makeCallWithNameImpl[T]('methodName)}
@experimental
def makeCallWithNameImpl[T](methodName: Expr[String])(using Quotes, Type[T]): Expr[T] = {
import quotes.reflect.*
println(Position.ofMacroExpansion.sourceCode)//Some(twoargs(1, "one"))
val methodNameStr = methodName.valueOrAbort
val strs = methodNameStr.split('.')
val moduleName = strs.init.mkString(".")
val moduleSymbol = Symbol.requiredModule(moduleName)
val shortMethodName = strs.last
val ident = Ident(TermRef(moduleSymbol.termRef, shortMethodName))
val (ownerName, ownerRhs) = Symbol.spliceOwner.maybeOwner.tree match {
case ValDef(name, tpt, Some(rhs)) => (name, rhs)
case DefDef(name, paramss, tpt, Some(rhs)) => (name, rhs)
case t => report.errorAndAbort(s"can't find RHS of ${t.show}")
}
val treeAccumulator = new TreeAccumulator[Option[Tree]] {
override def foldTree(acc: Option[Tree], tree: Tree)(owner: Symbol): Option[Tree] = tree match {
case Apply(fun, args) if fun.symbol.fullName == "App$.twoargs" =>
Some(Apply(ident, Literal(StringConstant(ownerName)) :: args))
case _ => foldOverTree(acc, tree)(owner)
}
}
treeAccumulator.foldTree(None, ownerRhs)(ownerRhs.symbol)
.getOrElse(report.errorAndAbort(s"can't find twoargs in RHS: ${ownerRhs.show}"))
.asExprOf[T]
}
}
Usage:
package mypackage
case class TwoArgs(name : String, i : Int, s : String)
import mypackage.TwoArgs
object App {
inline def twoargs(i: Int, s: String) =
Macro.makeCallWithName[TwoArgs]("mypackage.TwoArgs.apply")
def x() = twoargs(1, "one") // TwoArgs("x", 1, "one")
def aMethod() = {
val y = twoargs(2, "two") // TwoArgs("y", 2, "two")
}
val z = Some(twoargs(3, "three")) // Some(TwoArgs("z", 3, "three"))
}
dsinfo also handles the name twoargs at the call site (as template $macro), but I didn't implement this. I guess the name (if needed) can be obtained from Position.ofMacroExpansion.sourceCode.
Update: here is an implementation that also handles the name of the inline method (e.g. twoargs), using Scalameta + SemanticDB in addition to Scala 3 macros.
import mypackage.TwoArgs
object App {
inline def twoargs(i: Int, s: String) =
Macro.makeCallWithName[TwoArgs]("mypackage.TwoArgs.apply")
inline def twoargs1(i: Int, s: String) =
Macro.makeCallWithName[TwoArgs]("mypackage.TwoArgs.apply")
def x() = twoargs(1, "one") // TwoArgs("x", 1, "one")
def aMethod() = {
val y = twoargs(2, "two") // TwoArgs("y", 2, "two")
}
val z = Some(twoargs1(3, "three")) // Some(TwoArgs("z", 3, "three"))
}
package mypackage
case class TwoArgs(name : String, i : Int, s : String)
import scala.annotation.experimental
import scala.quoted.*
object Macro {
inline def makeCallWithName[T](inline methodName: String): T =
${makeCallWithNameImpl[T]('methodName)}
@experimental
def makeCallWithNameImpl[T](methodName: Expr[String])(using Quotes, Type[T]): Expr[T] = {
import quotes.reflect.*
val position = Position.ofMacroExpansion
val scalaFile = position.sourceFile.getJPath.getOrElse(
report.errorAndAbort(s"maybe virtual file, can't find path to position $position")
)
val inlineMethodSymbol =
new SemanticdbInspector(scalaFile)
.getInlineMethodSymbol(position.start, position.end)
.getOrElse(report.errorAndAbort(s"can't find Scalameta symbol at position (${position.startLine},${position.startColumn})..(${position.endLine},${position.endColumn})=$position"))
val methodNameStr = methodName.valueOrAbort
val strs = methodNameStr.split('.')
val moduleName = strs.init.mkString(".")
val moduleSymbol = Symbol.requiredModule(moduleName)
val shortMethodName = strs.last
val ident = Ident(TermRef(moduleSymbol.termRef, shortMethodName))
val owner = Symbol.spliceOwner.maybeOwner
val macroApplication: Option[Tree] = {
val (ownerName, ownerRhs) = owner.tree match {
case ValDef(name, tpt, Some(rhs)) => (name, rhs)
case DefDef(name, paramss, tpt, Some(rhs)) => (name, rhs)
case t => report.errorAndAbort(s"can't find RHS of ${t.show}")
}
val treeAccumulator = new TreeAccumulator[Option[Tree]] {
override def foldTree(acc: Option[Tree], tree: Tree)(owner: Symbol): Option[Tree] = tree match {
case Apply(fun, args) if tree.pos == position /* fun.symbol.fullName == inlineMethodSymbol */ =>
Some(Apply(ident, Literal(StringConstant(ownerName)) :: args))
case _ => foldOverTree(acc, tree)(owner)
}
}
treeAccumulator.foldTree(None, ownerRhs)(ownerRhs.symbol)
}
val res = macroApplication
.getOrElse(report.errorAndAbort(s"can't find application of $inlineMethodSymbol in RHS of $owner"))
report.info(res.show)
res.asExprOf[T]
}
}
import java.nio.file.{Path, Paths}
import scala.io
import scala.io.BufferedSource
import scala.meta.*
import scala.meta.interactive.InteractiveSemanticdb
import scala.meta.internal.semanticdb.{ClassSignature, Locator, Range, SymbolInformation, SymbolOccurrence, TextDocument, TypeRef}
class SemanticdbInspector(val scalaFile: Path) {
val scalaFileStr = scalaFile.toString
var textDocuments: Seq[TextDocument] = Seq()
Locator(
Paths.get(scalaFileStr + ".semanticdb")
)((path, textDocs) => {
textDocuments ++= textDocs.documents
})
val bufferedSource: BufferedSource = io.Source.fromFile(scalaFileStr)
val source = try bufferedSource.mkString finally bufferedSource.close()
extension (tree: Tree) {
def occurence: Option[SymbolOccurrence] = {
val treeRange = Range(tree.pos.startLine, tree.pos.startColumn, tree.pos.endLine, tree.pos.endColumn)
textDocuments.flatMap(_.occurrences)
.find(_.range.exists(occurrenceRange => treeRange == occurrenceRange))
}
def info: Option[SymbolInformation] = occurence.flatMap(_.symbol.info)
}
extension (symbol: String) {
def info: Option[SymbolInformation] = textDocuments.flatMap(_.symbols).find(_.symbol == symbol)
}
def getInlineMethodSymbol(startOffset: Int, endOffset: Int): Option[String] = {
def translateScalametaToMacro3(symbol: String): String =
symbol
.stripPrefix("_empty_/")
.stripSuffix("().")
.replace(".", "$.")
.replace("/", ".")
dialects.Scala3(source).parse[Source].get.collect {
case t @ Term.Apply(fun, args) if t.pos.start == startOffset && t.pos.end == endOffset =>
fun.info.map(_.symbol)
}.headOption.flatten.map(translateScalametaToMacro3)
}
}
lazy val scala3V = "3.1.3"
lazy val scala2V = "2.13.8"
lazy val scalametaV = "4.5.13"
lazy val root = project
.in(file("."))
.settings(
name := "scala3demo",
version := "0.1.0-SNAPSHOT",
scalaVersion := scala3V,
libraryDependencies ++= Seq(
"org.scalameta" %% "scalameta" % scalametaV cross CrossVersion.for3Use2_13,
"org.scalameta" % s"semanticdb-scalac_$scala2V" % scalametaV,
),
scalacOptions ++= Seq(
"-Yretain-trees",
),
semanticdbEnabled := true,
)
By the way, SemanticDB can't be replaced by TASTy here: when a macro in App is being expanded, the file App.scala.semanticdb already exists (it's generated early, in the frontend phase of compilation), but App.tasty doesn't yet (it appears only once App has been compiled, i.e. after the macro has been expanded, in the pickler phase).
A .scala.semanticdb file appears even if the .scala file doesn't compile (e.g. if there is an error in macro expansion), but a .tasty file won't.
scala.meta parent of parent of Defn.Object
Is it possible to using macro to modify the generated code of structural-typing instance invocation?
Scala conditional compilation
Macro annotation to override toString of Scala function
How to merge multiple imports in scala?
How to get the type of a variable with scalameta if the decltpe is empty?
See also https://github.com/lampepfl/dotty-macro-examples/tree/main/accessEnclosingParameters
Simplified version:
import scala.quoted.*
inline def makeCallWithName[T](inline methodName: String): T =
${makeCallWithNameImpl[T]('methodName)}
def makeCallWithNameImpl[T](methodName: Expr[String])(using Quotes, Type[T]): Expr[T] = {
import quotes.reflect.*
val position = Position.ofMacroExpansion
val methodNameStr = methodName.valueOrAbort
val strs = methodNameStr.split('.')
val moduleName = strs.init.mkString(".")
val moduleSymbol = Symbol.requiredModule(moduleName)
val shortMethodName = strs.last
val ident = Ident(TermRef(moduleSymbol.termRef, shortMethodName))
val owner0 = Symbol.spliceOwner.maybeOwner
val ownerName = owner0.tree match {
case ValDef(name, _, _) => name
case DefDef(name, _, _, _) => name
case t => report.errorAndAbort(s"unexpected tree shape: ${t.show}")
}
val owner = if owner0.isLocalDummy then owner0.maybeOwner else owner0
val macroApplication: Option[Tree] = {
val treeAccumulator = new TreeAccumulator[Option[Tree]] {
override def foldTree(acc: Option[Tree], tree: Tree)(owner: Symbol): Option[Tree] = tree match {
case _ if tree.pos == position => Some(tree)
case _ => foldOverTree(acc, tree)(owner)
}
}
treeAccumulator.foldTree(None, owner.tree)(owner)
}
val res = macroApplication.getOrElse(
report.errorAndAbort("can't find macro application")
) match {
case Apply(_, args) => Apply(ident, Literal(StringConstant(ownerName)) :: args)
case t => report.errorAndAbort(s"unexpected shape of macro application: ${t.show}")
}
report.info(res.show)
res.asExprOf[T]
}

Get fully qualified method name in scala macros

I use Scala macros and match Apply, and I would like to get the fully qualified name of the function being called.
Examples:
println("") -> scala.Predef.println
scala.Predef.println("") -> scala.Predef.println
class Abc {
def met(): Unit = ???
}
case class X() {
def met(): Unit = ???
def abc(): Abc = ???
}
val a = new Abc()
val x = X()
a.met() -> Abc.met
new Abc().met() -> Abc.met
X() -> X.apply
X().met() -> X.met
x.met() -> X.met
x.abc.met() -> Abc.met
On the left is what I have in code, and on the right of the arrow is what I want to get. Is it possible? And how?
Here is the macro:
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
object ExampleMacro {
final val useFullyQualifiedName = false
def methodName(param: Any): String = macro debugParameters_Impl
def debugParameters_Impl(c: blackbox.Context)(param: c.Expr[Any]): c.Expr[String] = {
import c.universe._
param.tree match {
case Apply(Select(t, TermName(methodName)), _) =>
val baseClass = t.tpe.resultType.baseClasses.head // there may be a better way than this line
val className = if (useFullyQualifiedName) baseClass.fullName else baseClass.name
c.Expr[String](Literal(Constant(className + "." + methodName)))
case _ => sys.error("Not a method call: " + show(param.tree))
}
}
}
Usage of the macro:
object Main {
def main(args: Array[String]): Unit = {
class Abc {
def met(): Unit = ???
}
case class X() {
def met(): Unit = ???
def abc(): Abc = ???
}
val a = new Abc()
val x = X()
import sk.ygor.stackoverflow.q53326545.macros.ExampleMacro.methodName
println(methodName(Main.main(Array("foo", "bar"))))
println(methodName(a.met()))
println(methodName(new Abc().met()))
println(methodName(X()))
println(methodName(X().met()))
println(methodName(x.met()))
println(methodName(x.abc().met()))
println(methodName("a".getClass))
}
}
The source code for this example has the following properties:
it is a multi-module sbt project, because macros have to be in a separate compilation unit from the classes that use them
the macro module depends explicitly on scala-reflect: libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value

Scalacheck Shrink

I am fairly new to ScalaCheck (and to Scala entirely), so this may have a fairly simple solution.
I am using ScalaCheck to generate tests for an AST and to verify that the writer/parser work. I have these files:
AST.scala
package com.test
object Operator extends Enumeration {
val Add, Subtract, Multiply, Divide = Value
}
sealed trait AST
case class Operation(left: AST, op: Operator.Value, right: AST) extends AST
case class Literal(value: Int) extends AST
GenLiteral.scala
import com.test.{AST, Literal}
import org.scalacheck._
import Shrink._
import Prop._
import Arbitrary.arbitrary
object GenLiteral extends Properties("AST::Literal") {
property("Verify parse/write") = forAll(genLiteral){ (node) =>
// val string_version = node.writeToString() // AST -> String
// val result = Parse(string_version) // String -> AST
true
}
def genLiteral: Gen[Literal] = for {
value <- arbitrary[Int]
} yield Literal(value)
implicit def shrinkLiteral: Shrink[AST] = Shrink {
case Literal(value) =>
for {
reduced <- shrink(value)
} yield Literal(reduced)
}
}
GenOperation.scala
import com.test.{AST, Operation, Operator}
import org.scalacheck._
import Gen._
import Shrink._
import Prop._
import GenLiteral._
object GenOperation extends Properties("AST::Operation") {
property("Verify parse/write") = forAll(genOperation){ (node) =>
// val string_version = node.writeToString() // AST -> String
// val result = Parse(string_version) // String -> AST
true
}
def genOperation: Gen[Operation] = for {
left <- oneOf(genOperation, genLiteral)
right <- oneOf(genOperation, genLiteral)
op <- oneOf(Operator.values.toSeq)
} yield Operation(left,op,right)
implicit def shrinkOperation: Shrink[AST] = Shrink {
case Operation(l,o,r) =>
(
for {
ls <- shrink(l)
rs <- shrink(r)
} yield Operation(ls, o, rs)
) append (
for {
ls <- shrink(l)
} yield Operation(ls, o, r)
) append (
for {
rs <- shrink(r)
} yield Operation(l, o, rs)
) append shrink(l) append shrink(r)
}
}
With the code above, I get the following error:
ambiguous implicit values:
both method shrinkLiteral in object GenLiteral of type => org.scalacheck.Shrink[com.test.AST]
and method shrinkOperation in object GenOperation of type => org.scalacheck.Shrink[com.test.AST]
match expected type org.scalacheck.Shrink[com.test.AST]
ls <- shrink(l)
How do I write the shrink methods for this?
You have two implicit instances of Shrink[AST] and so the compiler complains about ambiguous implicit values.
You could rewrite your code as:
implicit def shrinkLiteral: Shrink[Literal] = Shrink {
case Literal(value) => shrink(value).map(Literal)
}
implicit def shrinkOperation: Shrink[Operation] = Shrink {
case Operation(l,o,r) =>
shrink(l).map(Operation(_, o, r)) append
shrink(r).map(Operation(l, o, _))
}
implicit def shrinkAST: Shrink[AST] = Shrink {
// an Operation can also shrink to either of its subtrees
case o: Operation => Stream(o.left, o.right) append shrink(o)
case l: Literal => shrink(l)
}
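The underlying rule is that implicit search must find exactly one Shrink[AST]. The sketch below reproduces the pattern without ScalaCheck, so it runs on its own: Shrinker is a hypothetical, minimal type class (invariant, like Shrink), and the operator is a String instead of the Operator enumeration. There is one instance per concrete node type plus a single dispatching instance for the sealed trait.

```scala
sealed trait AST
case class Operation(left: AST, op: String, right: AST) extends AST
case class Literal(value: Int) extends AST

// Hypothetical stand-in for org.scalacheck.Shrink; invariant in A, so a
// Shrinker[AST] is never a candidate when a Shrinker[Literal] is required.
trait Shrinker[A] {
  def apply(a: A): List[A]
}

object Shrinker {
  def shrink[A](a: A)(implicit s: Shrinker[A]): List[A] = s(a)

  // Simplified integer shrinking: try 0, then half the value.
  implicit def shrinkInt: Shrinker[Int] = new Shrinker[Int] {
    def apply(n: Int): List[Int] = List(0, n / 2).distinct.filter(_ != n)
  }

  // One instance per concrete node type...
  implicit def shrinkLiteral: Shrinker[Literal] = new Shrinker[Literal] {
    def apply(l: Literal): List[Literal] = shrink(l.value).map(Literal(_))
  }

  implicit def shrinkOperation: Shrinker[Operation] = new Shrinker[Operation] {
    // o.left and o.right are typed as AST, so these calls use shrinkAST below.
    def apply(o: Operation): List[Operation] =
      shrink(o.left).map(l => o.copy(left = l)) ++
        shrink(o.right).map(r => o.copy(right = r))
  }

  // ...and exactly one Shrinker[AST], which dispatches on the runtime subtype.
  // Explicit type arguments pin the concrete instance, so this one is never
  // re-selected recursively. An Operation may also shrink to its subtrees.
  implicit def shrinkAST: Shrinker[AST] = new Shrinker[AST] {
    def apply(a: AST): List[AST] = a match {
      case o: Operation => List(o.left, o.right) ++ shrink[Operation](o)
      case l: Literal   => shrink[Literal](l)
    }
  }
}
```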

how to print variable name and value using a scala macro?

I am sure there is a more elegant way of writing the following macro which prints the name and value of a variable:
def mprintx(c: Context)(linecode: c.Expr[Any]): c.Expr[Unit] = {
import c.universe._
val namez = (c.enclosingImpl match {
case ClassDef(mods, name, tparams, impl) =>
c.universe.reify(c.literal(name.toString).splice)
case ModuleDef(mods, name, impl) =>
c.universe.reify(c.literal(name.toString).splice)
case _ => c.abort(c.enclosingPosition, "NoEnclosingClass")
}).toString match {
case r_name(n) => n
case _ => "Unknown?"
}
val msg = linecode.tree.productIterator.toList.last.toString.replaceAll("scala.*\\]", "").replaceAll(namez+"\\.this\\.", "").replaceAll("List", "")
reify(myPrintDln(c.Expr[String](Literal(Constant(msg))).splice+" ---> "+linecode.splice))
}
def myPrintIt(linecode: Any) = macro mprintx
called by the following program:
object Zabi2 extends App {
val l = "zab"
val kol = 345
var zub = List("2", 89)
val zubi = List(zub,l,kol)
printIt(l)
printIt(l, kol, (l, zub),zubi)
}
which prints:
l ---> zab
(l, kol, (l, zub), zubi) ---> (zab,345,(zab,List(2, 89)),List(List(2, 89), zab, 345))
Thanks in advance for your help.
Here is a macro to print expressions and their values:
package mymacro
import scala.annotation.compileTimeOnly
import scala.language.experimental.macros
import scala.reflect.macros.whitebox
@compileTimeOnly("DebugPrint is available only during compile-time")
class DebugPrint(val c: whitebox.Context) {
import c.universe._
def impl(args: c.Expr[Any]*): c.Tree = {
val sep = ", "
val colon = "="
val trees = args.map(expr => expr.tree).toList
val ctxInfo = s"${c.internal.enclosingOwner.fullName}:${c.enclosingPosition.line}: "
val treeLits = trees.zipWithIndex.map {
case (tree, i) => Literal(Constant((if (i != 0) sep else ctxInfo) + tree + colon))
}
q"""
System.err.println(StringContext(..$treeLits, "").s(..$trees))
"""
}
}
@compileTimeOnly("DebugPrint is available only during compile-time")
object DebugPrint {
def apply(args: Any*): Any = macro DebugPrint.impl
}
Example:
package myapp
import mymacro.DebugPrint
case class Person(name: String, age: Int)
object Main extends App {
val x = 5
val y = "example"
val person = Person("Derp", 20)
DebugPrint(x, y, person, person.name, person.age)
def func() = {
val x = 5
val y = "example"
val person = Person("Derp", 20)
DebugPrint(x, y, person, person.name, person.age)
}
func()
}
Output:
myapp.Main.<local Main>:12: Main.this.x=5, Main.this.y=example, Main.this.person=Person(Derp,20), Main.this.person.name=Derp, Main.this.person.age=20
myapp.Main.func:18: x=5, y=example, person=Person(Derp,20), person.name=Derp, person.age=20
Works well with Scala 2.12.12.
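The macro prints nothing by itself. It rewrites the call into ordinary code. For `DebugPrint(x, y)`, the quasiquote builds a `StringContext` whose literal parts are the context prefix plus each argument's source text followed by `=`, with `", "` between arguments and a trailing empty part. `.s` then splices in the runtime values. The following hand-written equivalent of that expansion runs without any macro. The prefix `myapp.Demo:3: ` and the names `DebugPrintExpansion` and `render` are made up for illustration.

```scala
object DebugPrintExpansion {
  // Hand-written equivalent of what DebugPrint(x, y) expands to: the literal
  // parts carry the source text of each argument plus "=" and ", " separators,
  // and StringContext.s interleaves the runtime values between them.
  // "myapp.Demo:3: " is a hypothetical stand-in for the macro's ctxInfo prefix.
  def render(x: Int, y: String): String =
    StringContext("myapp.Demo:3: x=", ", y=", "").s(x, y)

  def main(args: Array[String]): Unit =
    System.err.println(render(5, "example")) // myapp.Demo:3: x=5, y=example
}
```

There is always one more literal part than there are arguments, which is why the macro appends the empty string `""` after `treeLits`.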

Instantiating a case class with default args via reflection

I need to be able to instantiate various case classes through reflection, both by figuring out the argument types of the constructor and by invoking the constructor with all default arguments.
I've come as far as this:
import reflect.runtime.{universe => ru}
val m = ru.runtimeMirror(getClass.getClassLoader)
case class Bar(i: Int = 33)
val tpe = ru.typeOf[Bar]
val classBar = tpe.typeSymbol.asClass
val cm = m.reflectClass(classBar)
val ctor = tpe.declaration(ru.nme.CONSTRUCTOR).asMethod
val ctorm = cm.reflectConstructor(ctor)
// figuring out arg types
val arg1 = ctor.paramss.head.head
arg1.typeSignature =:= ru.typeOf[Int] // true
// etc.
// instantiating with given args
val p = ctorm(33)
Now the missing part:
val p2 = ctorm() // IllegalArgumentException: wrong number of arguments
How can I create p2 with the default arguments of Bar, i.e. the reflective equivalent of Bar()?
In the linked question, the :power REPL uses the internal API, so defaultGetterName is not available here and the getter names have to be constructed by hand. An adaptation of @som-snytt's answer:
def newDefault[A](implicit t: reflect.ClassTag[A]): A = {
import reflect.runtime.{universe => ru, currentMirror => cm}
val clazz = cm.classSymbol(t.runtimeClass)
val mod = clazz.companionSymbol.asModule
val im = cm.reflect(cm.reflectModule(mod).instance)
val ts = im.symbol.typeSignature
val mApply = ts.member(ru.newTermName("apply")).asMethod
val syms = mApply.paramss.flatten
val args = syms.zipWithIndex.map { case (p, i) =>
val mDef = ts.member(ru.newTermName(s"apply$$default$$${i+1}")).asMethod
im.reflectMethod(mDef)()
}
im.reflectMethod(mApply)(args: _*).asInstanceOf[A]
}
case class Foo(bar: Int = 33)
val f = newDefault[Foo] // ok
Is this really the shortest path?
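On Scala 2.11 and later, `companionSymbol`, `newTermName` and `paramss` are deprecated in favor of `companion`, `TermName` and `paramLists`. Below is the same approach written with the newer names. It is a sketch under stated assumptions: scala-reflect is on the classpath, the companion has a single non-overloaded `apply`, and every parameter has a default. `Sample`, `DefaultInstance` and `newDefault` are illustrative names, not part of any library.

```scala
import scala.reflect.runtime.universe._

// Hypothetical example class; every parameter has a default value.
case class Sample(n: Int = 33, label: String = "default")

object DefaultInstance {
  private val mirror = runtimeMirror(getClass.getClassLoader)

  // Calls the companion's apply with the values returned by
  // apply$default$1, apply$default$2, ... (the compiler-generated getters).
  // Assumes one non-overloaded apply whose parameters all have defaults.
  def newDefault[A: TypeTag]: A = {
    val companionType = typeOf[A].companion
    val applyMethod = companionType.decl(TermName("apply")).asMethod
    val module = typeOf[A].typeSymbol.companion.asModule
    val instance = mirror.reflect(mirror.reflectModule(module).instance)
    val args = applyMethod.paramLists.head.indices.map { i =>
      val getter = companionType.member(TermName(s"apply$$default$$${i + 1}"))
      require(getter != NoSymbol, s"parameter ${i + 1} of apply has no default")
      instance.reflectMethod(getter.asMethod)()
    }
    instance.reflectMethod(applyMethod)(args: _*).asInstanceOf[A]
  }
}
```

Usage: `DefaultInstance.newDefault[Sample]` should give the same value as `Sample()`.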
Not minimized... and not endorsing...
scala> import scala.reflect.runtime.universe
import scala.reflect.runtime.universe
scala> import scala.reflect.internal.{ Definitions, SymbolTable, StdNames }
import scala.reflect.internal.{Definitions, SymbolTable, StdNames}
scala> val ds = universe.asInstanceOf[Definitions with SymbolTable with StdNames]
ds: scala.reflect.internal.Definitions with scala.reflect.internal.SymbolTable with scala.reflect.internal.StdNames = scala.reflect.runtime.JavaUniverse@52a16a10
scala> val n = ds.newTermName("foo")
n: ds.TermName = foo
scala> ds.nme.defaultGetterName(n,1)
res1: ds.TermName = foo$default$1
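If you would rather not depend on scala-reflect or its internal API, a different technique is plain Java reflection against the `<method>$default$<n>` names visible in the REPL output above. This is a sketch under stated assumptions. It relies on how the Scala 2 compiler lays out case classes rather than on a documented API, so Scala 3 or case classes nested inside other classes may break it. `Widget`, `JavaReflectDefaults` and `newDefault` are made-up names.

```scala
// Hypothetical example class; kept top-level so its companion compiles to "Widget$".
case class Widget(size: Int = 3, name: String = "w")

object JavaReflectDefaults {
  // Relies on Scala 2's compilation scheme (an assumption, not a documented API):
  // the companion is the class "<Name>$" with a static MODULE$ field, and the
  // default of parameter i is returned by its method apply$default$i.
  def newDefault[A](cls: Class[A]): A = {
    val companionClass = Class.forName(cls.getName + "$", true, cls.getClassLoader)
    val companion = companionClass.getField("MODULE$").get(null)
    // Pick the apply overload that returns A, skipping the erased Function bridge.
    val applyMethod = companionClass.getMethods
      .find(m => m.getName == "apply" && m.getReturnType == cls)
      .getOrElse(throw new NoSuchMethodException(companionClass.getName + ".apply"))
    // Default values come back boxed; Method.invoke unboxes them for primitive params.
    val args = (1 to applyMethod.getParameterCount).map { i =>
      companionClass.getMethod(s"apply$$default$$$i").invoke(companion)
    }
    cls.cast(applyMethod.invoke(companion, args: _*))
  }
}
```

The trade-off is that nothing checks these names at compile time. If a parameter has no default, `getMethod` throws `NoSuchMethodException`.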
Here's a working version that you can copy into your codebase:
import scala.reflect.api
import scala.reflect.api.{TypeCreator, Universe}
import scala.reflect.runtime.universe._
object Maker {
val mirror = runtimeMirror(getClass.getClassLoader)
var makerRunNumber = 1
def apply[T: TypeTag]: T = {
val method = typeOf[T].companion.decl(TermName("apply")).asMethod
val params = method.paramLists.head
val args = params.map { param =>
makerRunNumber += 1
param.info match {
case t if t <:< typeOf[Enumeration#Value] => chooseEnumValue(convert(t).asInstanceOf[TypeTag[_ <: Enumeration]])
case t if t =:= typeOf[Int] => makerRunNumber
case t if t =:= typeOf[Long] => makerRunNumber
case t if t =:= typeOf[java.util.Date] => new java.util.Date()
case t if t <:< typeOf[Option[_]] => None
case t if t =:= typeOf[String] && param.name.decodedName.toString.toLowerCase.contains("email") => s"random-$makerRunNumber@give.asia"
case t if t =:= typeOf[String] => s"arbitrary-$makerRunNumber"
case t if t =:= typeOf[Boolean] => false
case t if t <:< typeOf[Seq[_]] => List.empty
case t if t <:< typeOf[Map[_, _]] => Map.empty
// Add more special cases here.
case t if isCaseClass(t) => apply(convert(t))
case t => throw new Exception(s"Maker doesn't support generating $t")
}
}
val obj = mirror.reflectModule(typeOf[T].typeSymbol.companion.asModule).instance
mirror.reflect(obj).reflectMethod(method)(args:_*).asInstanceOf[T]
}
def chooseEnumValue[E <: Enumeration: TypeTag]: E#Value = {
val parentType = typeOf[E].asInstanceOf[TypeRef].pre
val valuesMethod = parentType.baseType(typeOf[Enumeration].typeSymbol).decl(TermName("values")).asMethod
val obj = mirror.reflectModule(parentType.termSymbol.asModule).instance
mirror.reflect(obj).reflectMethod(valuesMethod)().asInstanceOf[E#ValueSet].head
}
def convert(tpe: Type): TypeTag[_] = {
TypeTag.apply(
runtimeMirror(getClass.getClassLoader),
new TypeCreator {
override def apply[U <: Universe with Singleton](m: api.Mirror[U]) = {
tpe.asInstanceOf[U # Type]
}
}
)
}
def isCaseClass(t: Type) = {
t.companion.decls.exists(_.name.decodedName.toString == "apply") &&
t.decls.exists(_.name.decodedName.toString == "copy")
}
}
And, when you want to use it, you can call:
val user = Maker[User]
val user2 = Maker[User].copy(email = "someemail@email.com")
The code above generates arbitrary, unique values. They are not truly random, so it is best suited for tests.
It works with Enumeration values and nested case classes, and you can easily extend it to support other special types.
Read our full blog post here: https://give.engineering/2018/08/24/instantiate-case-class-with-arbitrary-value.html
This is the most complete example of how to create a case class via reflection with default constructor parameters (GitHub source):
import scala.reflect.runtime.universe
import scala.reflect.internal.{Definitions, SymbolTable, StdNames}
object Main {
def newInstanceWithDefaultParameters(className: String): Any = {
val runtimeMirror: universe.Mirror = universe.runtimeMirror(getClass.getClassLoader)
val ds = universe.asInstanceOf[Definitions with SymbolTable with StdNames]
val classSymbol = runtimeMirror.staticClass(className)
val classMirror = runtimeMirror.reflectClass(classSymbol)
val moduleSymbol = runtimeMirror.staticModule(className)
val moduleMirror = runtimeMirror.reflectModule(moduleSymbol)
val moduleInstanceMirror = runtimeMirror.reflect(moduleMirror.instance)
val defaultValueMethodSymbols = moduleMirror.symbol.info.members
.filter(_.name.toString.startsWith(ds.nme.defaultGetterName(ds.newTermName("apply"), 1).toString.dropRight(1)))
.toSeq
.reverse
.map(_.asMethod)
val defaultValueMethods = defaultValueMethodSymbols.map(moduleInstanceMirror.reflectMethod).toList
val primaryConstructorMirror = classMirror.reflectConstructor(classSymbol.primaryConstructor.asMethod)
primaryConstructorMirror.apply(defaultValueMethods.map(_.apply()): _*)
}
def main(args: Array[String]): Unit = {
val instance = newInstanceWithDefaultParameters(classOf[Bar].getName)
println(instance)
}
}
case class Bar(i: Int = 33)
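The variant below avoids the cast to `scala.reflect.internal` and the `.reverse` on `members`, whose ordering the code above depends on. Instead, it builds the getter name for parameter `i` directly as `apply$default$i`. Like the original, it assumes a top-level case class with a companion, scala-reflect on the classpath, and a default for every parameter. `Gadget`, `ByClassName` and `newInstanceWithDefaults` are illustrative names.

```scala
import scala.reflect.runtime.universe._

// Hypothetical top-level example class (staticClass expects its full name).
case class Gadget(id: Int = 7, tag: String = "g")

object ByClassName {
  // Same goal as newInstanceWithDefaultParameters, using only the public
  // reflection API and looking up each default getter by parameter position.
  def newInstanceWithDefaults(className: String): Any = {
    val mirror = runtimeMirror(getClass.getClassLoader)
    val classSymbol = mirror.staticClass(className)
    val moduleSymbol = mirror.staticModule(className)
    val module = mirror.reflect(mirror.reflectModule(moduleSymbol).instance)
    val ctor = classSymbol.primaryConstructor.asMethod
    val args = ctor.paramLists.head.zipWithIndex.map { case (param, i) =>
      val getter = moduleSymbol.info.member(TermName(s"apply$$default$$${i + 1}"))
      require(getter != NoSymbol, s"parameter ${param.name} of $className has no default")
      module.reflectMethod(getter.asMethod)()
    }
    mirror.reflectClass(classSymbol).reflectConstructor(ctor)(args: _*)
  }
}
```

A missing default now fails with a message that names the parameter, rather than with an argument-count mismatch at the constructor call.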