How to use chisel dsptools with floats - scala

I need to convert a Float32 into a Chisel FixedPoint, perform some computation and convert back FixedPoint to Float32.
For example, I need the following:
val a = 3.1F
val b = 2.2F
val res = a * b // REPL returns res: Float 6.82
Now, I do this:
import chisel3.experimental.FixedPoint
val fp_tpe = FixedPoint(6.W, 2.BP)
val a_fix = a.Something (fp_tpe) // convert a to FixPoint
val b_fix = b.Something (fp_tpe) // convert b to FixPoint
val res_fix = a_fix * b_fix
val res0 = res_fix.Something (fp_tpe) // convert back to Float
As a result, I'd expect the delta to be within some small range, e.g.
val eps = 1e-4
assert ( abs(res - res0) < eps, "The error is too big")
Can anyone provide a working example of the Chisel3 FixedPoint class for the pseudocode above?
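For intuition, here is a plain-Scala sketch (no Chisel involved, and not part of any Chisel API) of what rounding to 2 fractional bits does to these inputs, which is why a non-zero eps is needed:
// Plain Scala: quantize a Double to `bp` fractional bits and back.
def quantize(x: Double, bp: Int): Double = math.round(x * (1 << bp)).toDouble / (1 << bp)
val aQ = quantize(3.1, 2)  // 3.0  (3.1 is not representable with 2 fractional bits)
val bQ = quantize(2.2, 2)  // 2.25
val resQ = aQ * bQ         // 6.75, not 6.82, hence the need for a tolerance (or more fractional bits)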

Take a look at the following code:
import chisel3._
import chisel3.core.FixedPoint
import dsptools._
class FPMultiplier extends Module {
val io = IO(new Bundle {
val a = Input(FixedPoint(6.W, binaryPoint = 2.BP))
val b = Input(FixedPoint(6.W, binaryPoint = 2.BP))
val c = Output(FixedPoint(12.W, binaryPoint = 4.BP))
})
io.c := io.a * io.b
}
class FPMultiplierTester(c: FPMultiplier) extends DspTester(c) {
//
// This will PASS, there is sufficient precision to model the inputs
//
poke(c.io.a, 3.25)
poke(c.io.b, 2.5)
step(1)
expect(c.io.c, 8.125)
//
// This will FAIL, there is not sufficient precision to model the inputs
// But this is only caught on the output; this is likely the right approach
// because you can't really pass wrong-precision data into hardware.
//
poke(c.io.a, 3.1)
poke(c.io.b, 2.2)
step(1)
expect(c.io.c, 6.82)
}
object FPMultiplierMain {
def main(args: Array[String]): Unit = {
iotesters.Driver.execute(Array("-fiv"), () => new FPMultiplier) { c =>
new FPMultiplierTester(c)
}
}
}
I'd also suggest looking at ParameterizedAdder in dsptools, which gives you a feel for how to write hardware modules that you can pass different types. Generally you start with DspReal, confirm the model, and then start experimenting with FixedPoint sizes that return results with the desired precision.
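As a minimal sketch of that pattern, assuming dsptools' Ring type class and its implicit operators (the same machinery the generic FPMul further down uses), a type-parameterized adder might look like:
import chisel3._
import dsptools.numbers.Ring
import dsptools.numbers.implicits._

// Sketch: a module parameterized over any type with a Ring instance
// (FixedPoint, DspReal, DspComplex, ...). The widths come from the
// generator you pass in, so the same code is reused at different precisions.
class ParamAdder[T <: Data: Ring](gen: T) extends Module {
  val io = IO(new Bundle {
    val a = Input(gen.cloneType)
    val b = Input(gen.cloneType)
    val c = Output(gen.cloneType)
  })
  io.c := io.a + io.b
}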

For others' benefit, here is an improved version of @Chick's solution, rewritten in more abstract Scala with a configurable DSP tolerance.
package my_pkg
import chisel3._
import chisel3.core.{FixedPoint => FP}
import dsptools.{DspTester, DspTesterOptions, DspTesterOptionsManager}
class FPGenericIO (inType:FP, outType:FP) extends Bundle {
val a = Input(inType)
val b = Input(inType)
val c = Output(outType)
}
class FPMul (inType:FP, outType:FP) extends Module {
val io = IO(new FPGenericIO(inType, outType))
io.c := io.a * io.b
}
class FPMulTester(c: FPMul) extends DspTester(c) {
val uut = c.io
// This will PASS, there is sufficient precision to model the inputs
poke(uut.a, 3.25)
poke(uut.b, 2.5)
step(1)
expect(uut.c, 3.25*2.5)
// This will FAIL unless you increase the tolerance, which defaults to eps = 0.0
poke(uut.a, 3.1)
poke(uut.b, 2.2)
step(1)
expect(uut.c, 3.1*2.2)
}
object FPUMain extends App {
val fpInType = FP(8.W, 4.BP)
val fpOutType = FP(12.W, 6.BP)
// Update default DspTester options and increase tolerance
val opts = new DspTesterOptionsManager {
dspTesterOptions = DspTesterOptions(
fixTolLSBs = 2,
genVerilogTb = false,
isVerbose = true
)
}
dsptools.Driver.execute (() => new FPMul(fpInType, fpOutType), opts) {
c => new FPMulTester(c)
}
}

Here's my ultimate DSP multiplier implementation, which should support both FixedPoint and DspComplex numbers. @ChickMarkley, how do I update this class to implement complex multiplication?
package my_pkg
import chisel3._
import dsptools.numbers.{Ring,DspComplex}
import dsptools.numbers.implicits._
import dsptools.{DspContext}
import chisel3.core.{FixedPoint => FP}
import dsptools.{DspTester, DspTesterOptions, DspTesterOptionsManager}
class FPGenericIO[A <: Data:Ring, B <: Data:Ring] (inType:A, outType:B) extends Bundle {
val a = Input(inType.cloneType)
val b = Input(inType.cloneType)
val c = Output(outType.cloneType)
override def cloneType = (new FPGenericIO(inType, outType)).asInstanceOf[this.type]
}
class FPMul[A <: Data:Ring, B <: Data:Ring] (inType:A, outType:B) extends Module {
val io = IO(new FPGenericIO(inType, outType))
DspContext.withNumMulPipes(3) {
io.c := io.a * io.b
}
}
class FPMulTester[A <: Data:Ring, B <: Data:Ring](c: FPMul[A,B]) extends DspTester(c) {
val uut = c.io
//
// This will PASS, there is sufficient precision to model the inputs
//
poke(uut.a, 3.25)
poke(uut.b, 2.5)
step(1)
expect(uut.c, 3.25*2.5)
//
// This will FAIL, there is not sufficient precision to model the inputs
// But this is only caught on the output; this is likely the right approach
// because you can't really pass wrong-precision data into hardware.
//
poke(uut.a, 3.1)
poke(uut.b, 2.2)
step(1)
expect(uut.c, 3.1*2.2)
}
object FPUMain extends App {
val fpInType = FP(8.W, 4.BP)
val fpOutType = FP(12.W, 6.BP)
//val comp = DspComplex[Double] // How to declare a complex DSP type ?
val opts = new DspTesterOptionsManager {
dspTesterOptions = DspTesterOptions(
fixTolLSBs = 0,
genVerilogTb = false,
isVerbose = true
)
}
dsptools.Driver.execute (() => new FPMul(fpInType, fpOutType), opts) {
//dsptools.Driver.execute (() => new FPMul(comp, comp), opts) { // <-- this won't compile
c => new FPMulTester(c)
}
}
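Regarding the commented-out DspComplex line: my understanding (an assumption, not verified against this exact dsptools version) is that a complex generator is built from two component generators, roughly as follows:
// Assumption: DspComplex.apply(real, imag) builds a complex generator whose
// Ring instance lets the generic FPMul above elaborate unchanged.
// (DspComplex is already imported above.)
val compInType  = DspComplex(FP(8.W, 4.BP), FP(8.W, 4.BP))
val compOutType = DspComplex(FP(12.W, 6.BP), FP(12.W, 6.BP))
// dsptools.Driver.execute(() => new FPMul(compInType, compOutType), opts) { c => ... }
// The tester would also need complex pokes/expects (e.g. breeze.math.Complex)
// rather than the plain Doubles used in FPMulTester above.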

Related

Apache Flink (Scala) Rate-Balanced Source Function

I know in advance that this question is a little long, but I did my best to simplify it and make it approachable. Please provide feedback and I will try to act on it.
I am new to Flink, and I am trying to use it for a pipeline that will update each item of a large (TBs) dataset continuously. I wish to update some higher priority items more often but want to update all items as fast as possible.
Would it be possible to have a different source for each priority tier (e.g. high, medium, low) to start the update process and read from the higher-priority sources more often? The other approach I've thought of is a custom SourceFunction that would have a reader for each file and emit according to the rates I set. The first approach didn't seem feasible, so I am trying the second but am stuck.
This is what I've tried so far:
import scala.collection.mutable
import org.apache.flink.streaming.api.functions.source.SourceFunction
import org.apache.hadoop.fs.{FileSystem, Path}
import java.util.concurrent.locks.ReadWriteLock
import java.util.concurrent.locks.ReentrantReadWriteLock
import java.util.concurrent.atomic.AtomicBoolean
/** Implements logic to ensure higher-priority values are pushed more often.
*
* The relative priority is the relative run frequency. For example, if tier C
* is half the priority of tier B, then C is pushed half as often.
*
* Type parameters:
*   - T: tier type (e.g. an enum)
*   - V: emitted value type (e.g. ItemToUpdate)
*
* How to read objects from CSV rows:
*   - https://nightlies.apache.org/flink/flink-docs-release-1.15/docs/connectors/datastream/formats/csv/
*
* How to read objects from Parquet:
*   - https://nightlies.apache.org/flink/flink-docs-release-1.16/docs/connectors/datastream/formats/parquet/
*/
abstract class TierPriorityEmitter[T, V](
val emitRates: mutable.Map[T, Float],
val updateCounts: mutable.Map[T, Long],
val iterators: mutable.Map[T, Iterator[V]]
) extends SourceFunction[V] {
@transient
private val running = new AtomicBoolean(false)
def this(
emitRates: Map[T, Float]
) {
this(
mutable.Map() ++= emitRates,
mutable.Map() ++= (for ((k, v) <- emitRates) yield (k, 0L)).toMap,
mutable.Map.empty[T, Iterator[V]]
)
for ((k, v) <- emitRates) {
val it = getTierIterator(k)
if (it.hasNext) {
iterators.put(k, it)
} else {
updateCounts.remove(k)
this.emitRates.remove(k)
}
}
// rescale rates in case any were dropped, or priorities were given instead of normalized rates
val s = this.emitRates.foldLeft(0.0)(_ + _._2).toFloat // sum
for ((k, v) <- this.emitRates) {
this.emitRates.put(k, v / s)
}
}
/** Return iterator for reading a specific priority tier from source file.
* Called again every pass over each priority tier.
*
* For example, may return an iterator to a specific file or an iterator that
* filters to the specified tier. Must avoid storing the entire file in memory
* as it is prohibitively large. TODO FIXME how to read? (One possible shape is sketched after this class.)
*/
def getTierIterator(tier: T): Iterator[V] = ???
/** Determine which class of road should be emitted next.
*/
def nextToEmit(): T = {
/*
```python
total_updates = sum(updateCounts.values())
if total_updates == 0:
return max(emitRates.keys(), key=lambda k: emitRates[k])
actualEmitRates = {k: v/total_updates for k, v in updateCounts}
return min(emitRates.keys(), key=lambda k: actualEmitRates[k] - emitRates[k])
```
*/
val totalUpdates = updateCounts.foldLeft(0.0)(_ + _._2) // sum
if (totalUpdates == 0) {
return emitRates.maxBy(_._2)._1
}
val actualEmitRates =
(for ((k, v) <- updateCounts) yield (k, v / totalUpdates)).toMap
return emitRates.minBy(t => actualEmitRates(t._1) - t._2)._1
}
/** Emit to trigger updates.
*
* @param ctx
*/
override def run(ctx: SourceFunction.SourceContext[V]): Unit = {
running.set(true)
while (running.get()) {
val tier = nextToEmit()
val it = iterators(tier)
// zero length iterators are removed in the constructor
ctx.collect(it.next())
updateCounts.put(tier, updateCounts(tier) + 1)
if (!it.hasNext) {
iterators.put(tier, getTierIterator(tier))
}
}
}
override def cancel(): Unit = {
running.set(false)
}
}
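For the getTierIterator TODO above, here is one possible shape, sketched under assumptions: a hypothetical one-file-per-tier layout, emitting raw lines (parsing them into your value type is left out). scala.io.Source.getLines() is lazy, so the file is streamed rather than held in memory.
// Sketch only: hypothetical file layout. Note this does not by itself fix the
// SourceFunction serializability issue, since the returned iterator still
// ends up in the `iterators` field.
class LineFileTierEmitter[T](emitRates: Map[T, Float])
    extends TierPriorityEmitter[T, String](emitRates) {
  override def getTierIterator(tier: T): Iterator[String] = {
    val path = s"/data/tiers/$tier.txt"   // hypothetical: one plain-text file per tier
    scala.io.Source.fromFile(path).getLines()
  }
}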
I have this unit test.
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.BeforeAndAfter
import org.apache.flink.api.common.time.Time
import org.apache.flink.streaming.util.SourceOperatorTestHarness
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeinfo.TypeHint
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.streaming.api.operators.SourceOperatorFactory
import org.apache.flink.api.common.eventtime.WatermarkStrategy
import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder
import org.apache.flink.api.connector.source.Source
import org.apache.flink.api.connector.source.SourceReader
import scala.collection.mutable.Queue
import org.apache.flink.test.util.MiniClusterWithClientResource
import org.apache.flink.test.util.MiniClusterResourceConfiguration
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
/** Class for testing abstract TierPriorityEmitter
*/
class StringPriorityEmitter(priorities: Map[Char, Float]) extends TierPriorityEmitter[Char, String](priorities) with Serializable {
override def getTierIterator(tier: Char): Iterator[String] = {
return (for (i <- (1 to 9)) yield tier + i.toString()).iterator
}
}
class TestTierPriorityEmitter
extends AnyFlatSpec
with Matchers with BeforeAndAfter {
val flinkCluster = new MiniClusterWithClientResource(
new MiniClusterResourceConfiguration.Builder()
.setNumberSlotsPerTaskManager(1)
.setNumberTaskManagers(1)
.build
)
before {
flinkCluster.before()
}
after {
flinkCluster.after()
}
"TierPriorityEmitter" should "emit according to priority" in {
// be sure emit rates are floats and sum to 1
val emitter: TierPriorityEmitter[Char, String] = new StringPriorityEmitter(Map('A' -> 3f, 'B' -> 2f, 'C' -> 1f))
// should set emit rates based on our priorities
(emitter.emitRates('A') > emitter.emitRates('B') && emitter.emitRates('B') > emitter.emitRates('C') && emitter.emitRates('C') > 0) shouldBe(true)
(emitter.emitRates('A') + emitter.emitRates('B') + emitter.emitRates('C')) shouldEqual(1.0)
(emitter.emitRates('A') / emitter.emitRates('C')) shouldEqual(3.0)
// should output according to the assigned rates
// TODO use a SourceOperatorTestHarness instead. I couldn't get one to work.
val res = new Queue[String]()
for (_ <- 1 to 15) {
val tier = emitter.nextToEmit()
val it = emitter.iterators(tier)
// zero length iterators are removed in the constructor
res.enqueue(it.next())
emitter.updateCounts.put(tier, emitter.updateCounts(tier) + 1)
if (!it.hasNext) {
emitter.iterators.put(tier, emitter.getTierIterator(tier))
}
}
res shouldEqual(Queue("A1", "B1", "C1", "A2", "B2", "A3", "B3", "A4", "C2", "A5", "B4", "A6", "B5", "A7", "C3"))
}
"TierPriorityEmitter" should "be a source" in {
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setParallelism(1) // FIXME not enough resources for 2
val emitter: TierPriorityEmitter[Char, String] = new StringPriorityEmitter(Map('A' -> 3, 'B' -> 2, 'C' -> 1))
// FIXME not serializable
val source = env.addSource(
emitter
)
source.returns(TypeInformation.of(new TypeHint[String]() {}))
val sink = new TestSink[String]("Strings")
source.addSink(sink)
// Execute program, beginning computation.
env.execute("Test TierPriorityEmitter")
sink.getResults should contain allOf ("A1", "B1", "C1", "A2", "B2", "A3", "B3", "A4", "C2", "A5", "B4", "A6", "B5", "A7", "C3")
}
}
There are a couple obvious problems that I can't figure out how to resolve:
How to read the input files to implement the TierPriorityEmitter as a working (serializable) SourceFunction?
I know that Iterators probably aren't the right class to use, but I am not sure what to try next. I am not very experienced with Scala.
Optimization advice?
Although the mutex introduces synchronization overhead, rate balancing should still result in the best downstream performance, since the later steps will be slower. Any advice on how to improve this would be appreciated.

fixed point support on Chisel hdl

I am new to Chisel HDL, and I found that Chisel does provide a fixed-point representation (I found this link:
Fixed Point Arithmetic in Chisel HDL).
When I try it in Chisel, it actually doesn't work:
import Chisel._
class Toy extends Module {
val io = new Bundle {
val in0 = SFix(4, 12).asInput
val in1 = SFix(4, 12).asInput
val out = SFix(4, 16).asOutput
val oraw = Bits(OUTPUT, width=128)
}
val int_result = -io.in0 * (io.in0 + io.in1)
io.out := int_result
io.oraw := int_result.raw
}
class ToyTest(c: Toy) extends Tester(c) {
for (i <- 0 until 20) {
val i0 = 0.5
val i1 = 0.25
poke(c.io.in0, i0)
poke(c.io.in1, i1)
val res = -i0 * (i0+i1)
step(1)
expect(c.io.out, res)
}
}
object Toy {
def main(args: Array[String]): Unit = {
val tutArgs = args.slice(1, args.length)
chiselMainTest(tutArgs, () => Module(new Toy())) {
c => new ToyTest(c)
}
}
}
which produces the following error:
In my build.sbt file, I choose the latest Chisel release with:
libraryDependencies += "edu.berkeley.cs" %% "chisel" % "latest.release"
According to the Chisel code, SFix seems to be deprecated; Fixed should be used instead.
I modified your code to use it, but there is a problem with poke and expect: it seems that Fixed is not yet supported by poke and expect.
import Chisel._
class Toy extends Module {
val io = new Bundle {
val in0 = Fixed(INPUT, 4, 12)
val in1 = Fixed(INPUT, 4, 12)
val out = Fixed(OUTPUT, 8, 24)
val oraw = Bits(OUTPUT, width=128)
}
val int_result = -io.in0 * (io.in0 + io.in1)
io.out := int_result
io.oraw := int_result.asUInt()
}
class ToyTest(c: Toy) extends Tester(c) {
for (i <- 0 until 20) {
val i0 = Fixed(0.5, 4, 12)
val i1 = Fixed(0.25, 4, 12)
c.io.in0 := i0
c.io.in1 := i1
//poke(c.io.in0, i0)
//poke(c.io.in1, i1)
val res = -i0 * (i0+i1)
step(1)
//expect(c.io.out, res)
}
}
object Toy {
def main(args: Array[String]): Unit = {
val tutArgs = args.slice(1, args.length)
chiselMainTest(tutArgs, () => Module(new Toy())) {
c => new ToyTest(c)
}
}
}
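As a side note on the arithmetic here: in a fixed-point representation the raw value is the real value scaled by 2^fracWidth, and multiplying two such values adds their fractional widths (12 + 12 = 24 for the Toy inputs). A plain-Scala sketch of that arithmetic (no Chisel):
// Raw fixed-point value = real value * 2^fracWidth; multiplying raw values
// adds their fractional widths (12 + 12 = 24 here).
def toRaw(x: Double, fracWidth: Int): BigInt = BigInt(math.round(x * (1L << fracWidth)))
def fromRaw(raw: BigInt, fracWidth: Int): Double = raw.toDouble / (1L << fracWidth)

val i0 = toRaw(0.5, 12)          // 2048
val i1 = toRaw(0.25, 12)         // 1024
fromRaw(-i0 * (i0 + i1), 24)     // -0.375 == -0.5 * (0.5 + 0.25), matching the Toy expression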

Akka Streams: State in a flow

I want to read multiple big files using Akka Streams to process each line. Imagine that each line consists of an (identifier -> value) pair. If a new identifier is found, I want to save it and its value in the database; otherwise, if the identifier has already been found while processing the stream of lines, I want to save only the value. For that, I think I need some kind of recursive stateful flow in order to keep the identifiers that have already been found in a Map. I think I'd receive in this flow a pair of (newLine, contextWithIdentifiers).
I've just started to look into Akka Streams. I guess I can manage myself to do the stateless processing stuff but I have no clue about how to keep the contextWithIdentifiers. I'd appreciate any pointers to the right direction.
Maybe something like statefulMapConcat can help you:
import akka.actor.ActorSystem
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.{Sink, Source}
import scala.util.Random._
import scala.math.abs
import scala.concurrent.ExecutionContext.Implicits.global
implicit val system = ActorSystem()
implicit val materializer = ActorMaterializer()
//encapsulating your input
case class IdentValue(id: Int, value: String)
//some random generated input
val identValues = List.fill(20)(IdentValue(abs(nextInt()) % 5, "valueHere"))
val stateFlow = Flow[IdentValue].statefulMapConcat{ () =>
//state with already processed ids
var ids = Set.empty[Int]
identValue => if (ids.contains(identValue.id)) {
//save value to DB
println(identValue.value)
List(identValue)
} else {
//save both to database
println(identValue)
ids = ids + identValue.id
List(identValue)
}
}
Source(identValues)
.via(stateFlow)
.runWith(Sink.seq)
.onSuccess { case identValue => println(identValue) }
A few years later, here is an implementation I wrote if you only need a 1-to-1 mapping (not 1-to-N):
import akka.stream.stage.{GraphStage, GraphStageLogic}
import akka.stream.{Attributes, FlowShape, Inlet, Outlet}
object StatefulMap {
def apply[T, O](converter: => T => O) = new StatefulMap[T, O](converter)
}
class StatefulMap[T, O](converter: => T => O) extends GraphStage[FlowShape[T, O]] {
val in = Inlet[T]("StatefulMap.in")
val out = Outlet[O]("StatefulMap.out")
val shape = FlowShape.of(in, out)
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
val f = converter
setHandler(in, () => push(out, f(grab(in))))
setHandler(out, () => pull(in))
}
}
Test (and demo):
behavior of "StatefulMap"
class Counter extends (Any => Int) {
var count = 0
override def apply(x: Any): Int = {
count += 1
count
}
}
it should "not share state among substreams" in {
val result = await {
Source(0 until 10)
.groupBy(2, _ % 2)
.via(StatefulMap(new Counter()))
.fold(Seq.empty[Int])(_ :+ _)
.mergeSubstreams
.runWith(Sink.seq)
}
result.foreach(_ should be(1 to 5))
}

Chisel: Access to Module Parameters from Tester

How does one access the parameters used to construct a Module from inside the Tester that is testing it?
In the test below I am passing the parameters explicitly both to the Module and to the Tester. I would prefer not to have to pass them to the Tester but instead extract them from the module that was also passed in.
Also I am new to scala/chisel so any tips on bad techniques I'm using would be appreciated :).
import Chisel._
import math.pow
class TestA(dataWidth: Int, arrayLength: Int) extends Module {
val dataType = Bits(INPUT, width = dataWidth)
val arrayType = Vec(gen = dataType, n = arrayLength)
val io = new Bundle {
val i_valid = Bool(INPUT)
val i_data = dataType
val i_array = arrayType
val o_valid = Bool(OUTPUT)
val o_data = dataType.flip
val o_array = arrayType.flip
}
io.o_valid := io.i_valid
io.o_data := io.i_data
io.o_array := io.i_array
}
class TestATests(c: TestA, dataWidth: Int, arrayLength: Int) extends Tester(c) {
val maxData = pow(2, dataWidth).toInt
for (t <- 0 until 16) {
val i_valid = rnd.nextInt(2)
val i_data = rnd.nextInt(maxData)
val i_array = List.fill(arrayLength)(rnd.nextInt(maxData))
poke(c.io.i_valid, i_valid)
poke(c.io.i_data, i_data)
(c.io.i_array, i_array).zipped foreach {
(element,value) => poke(element, value)
}
expect(c.io.o_valid, i_valid)
expect(c.io.o_data, i_data)
(c.io.o_array, i_array).zipped foreach {
(element,value) => expect(element, value)
}
step(1)
}
}
object TestAObject {
def main(args: Array[String]): Unit = {
val tutArgs = args.slice(0, args.length)
val dataWidth = 5
val arrayLength = 6
chiselMainTest(tutArgs, () => Module(
new TestA(dataWidth=dataWidth, arrayLength=arrayLength))){
c => new TestATests(c, dataWidth=dataWidth, arrayLength=arrayLength)
}
}
}
If you make the arguments dataWidth and arrayLength members of TestA you can just reference them. In Scala this can be accomplished by inserting val into the argument list:
class TestA(val dataWidth: Int, val arrayLength: Int) extends Module ...
Then you can reference them from the test as members, e.g. c.dataWidth or c.arrayLength.
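With that change, the tester no longer needs its own copies of the parameters; a minimal sketch:
// Sketch: with dataWidth/arrayLength promoted to vals, the tester reads them
// straight off the module instance instead of taking duplicate arguments.
class TestATests(c: TestA) extends Tester(c) {
  val maxData = pow(2, c.dataWidth).toInt
  for (t <- 0 until 16) {
    val i_data  = rnd.nextInt(maxData)
    val i_array = List.fill(c.arrayLength)(rnd.nextInt(maxData))
    // ... pokes and expects as in the original tester ...
  }
}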

Creating serializable objects from Scala source code at runtime

To embed Scala as a "scripting language", I need to be able to compile text fragments to simple objects, such as Function0[Unit] that can be serialised to and deserialised from disk and which can be loaded into the current runtime and executed.
How would I go about this?
Say for example, my text fragment is (purely hypothetical):
Document.current.elements.headOption.foreach(_.open())
This might be wrapped into the following complete text:
package myapp.userscripts
import myapp.DSL._
object UserFunction1234 extends Function0[Unit] {
def apply(): Unit = {
Document.current.elements.headOption.foreach(_.open())
}
}
What comes next? Should I use IMain to compile this code? I don't want to use the normal interpreter mode, because the compilation should be "context-free" and not accumulate requests.
What I need to get hold of from the compilation is, I guess, the binary class file? In that case, serialisation is straightforward (byte array). How would I then load that class into the runtime and invoke the apply method?
What happens if the code compiles to multiple auxiliary classes? The example above contains a closure _.open(). How do I make sure I "package" all those auxiliary things into one object to serialize and class-load?
Note: Given that Scala 2.11 is imminent and the compiler API probably changed, I am happy to receive hints as how to approach this problem on Scala 2.11
Here is one idea: use a regular Scala compiler instance. Unfortunately it seems to require the use of hard disk files both for input and output. So we use temporary files for that. The output will be zipped up in a JAR which will be stored as a byte array (that would go into the hypothetical serialization process). We need a special class loader to retrieve the class again from the extracted JAR.
The following assumes Scala 2.10.3 with the scala-compiler library on the class path:
import scala.tools.nsc
import java.io._
import scala.annotation.tailrec
Wrapping the user-provided code in a function class with a synthetic name that is incremented for each new fragment:
val packageName = "myapp"
var userCount = 0
def mkFunName(): String = {
val c = userCount
userCount += 1
s"Fun$c"
}
def wrapSource(source: String): (String, String) = {
val fun = mkFunName()
val code = s"""package $packageName
|
|class $fun extends Function0[Unit] {
| def apply(): Unit = {
| $source
| }
|}
|""".stripMargin
(fun, code)
}
A function to compile a source fragment and return the byte array of the resulting jar:
/** Compiles a source code consisting of a body which is wrapped in a `Function0`
* apply method, and returns the function's class name (without package) and the
* raw jar file produced in the compilation.
*/
def compile(source: String): (String, Array[Byte]) = {
val set = new nsc.Settings
val d = File.createTempFile("temp", ".out")
d.delete(); d.mkdir()
set.d.value = d.getPath
set.usejavacp.value = true
val compiler = new nsc.Global(set)
val f = File.createTempFile("temp", ".scala")
val out = new BufferedOutputStream(new FileOutputStream(f))
val (fun, code) = wrapSource(source)
out.write(code.getBytes("UTF-8"))
out.flush(); out.close()
val run = new compiler.Run()
run.compile(List(f.getPath))
f.delete()
val bytes = packJar(d)
deleteDir(d)
(fun, bytes)
}
def deleteDir(base: File): Unit = {
base.listFiles().foreach { f =>
if (f.isFile) f.delete()
else deleteDir(f)
}
base.delete()
}
Note: Doesn't handle compiler errors yet!
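One way to surface compiler errors (a sketch; it assumes the Scala 2.10 scala-compiler API and would replace the `new nsc.Global(set)` line inside `compile` above) is to pass a StoreReporter to Global and check it after the run:
import scala.tools.nsc.reporters.StoreReporter

// Collect compiler messages instead of printing them, then fail fast on errors.
val reporter = new StoreReporter
val compiler = new nsc.Global(set, reporter)
val run      = new compiler.Run()
run.compile(List(f.getPath))
if (reporter.hasErrors) {
  val msgs = reporter.infos.collect {
    case info if info.severity == reporter.ERROR => s"${info.pos}: ${info.msg}"
  }
  throw new Exception("Compilation failed:\n" + msgs.mkString("\n"))
}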
The packJar method uses the compiler output directory and produces an in-memory jar file from it:
// cf. http://stackoverflow.com/questions/1281229
def packJar(base: File): Array[Byte] = {
import java.util.jar._
val mf = new Manifest
mf.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
val bs = new java.io.ByteArrayOutputStream
val out = new JarOutputStream(bs, mf)
def add(prefix: String, f: File): Unit = {
val name0 = prefix + f.getName
val name = if (f.isDirectory) name0 + "/" else name0
val entry = new JarEntry(name)
entry.setTime(f.lastModified())
out.putNextEntry(entry)
if (f.isFile) {
val in = new BufferedInputStream(new FileInputStream(f))
try {
val buf = new Array[Byte](1024)
@tailrec def loop(): Unit = {
val count = in.read(buf)
if (count >= 0) {
out.write(buf, 0, count)
loop()
}
}
loop()
} finally {
in.close()
}
}
out.closeEntry()
if (f.isDirectory) f.listFiles.foreach(add(name, _))
}
base.listFiles().foreach(add("", _))
out.close()
bs.toByteArray
}
A utility function that takes the byte array found in deserialization and creates a map from class names to class byte code:
def unpackJar(bytes: Array[Byte]): Map[String, Array[Byte]] = {
import java.util.jar._
import scala.annotation.tailrec
val in = new JarInputStream(new ByteArrayInputStream(bytes))
val b = Map.newBuilder[String, Array[Byte]]
@tailrec def loop(): Unit = {
val entry = in.getNextJarEntry
if (entry != null) {
if (!entry.isDirectory) {
val name = entry.getName
// cf. http://stackoverflow.com/questions/8909743
val bs = new ByteArrayOutputStream
var i = 0
while (i >= 0) {
i = in.read()
if (i >= 0) bs.write(i)
}
val bytes = bs.toByteArray
b += mkClassName(name) -> bytes
}
loop()
}
}
loop()
in.close()
b.result()
}
def mkClassName(path: String): String = {
require(path.endsWith(".class"))
path.substring(0, path.length - 6).replace("/", ".")
}
A suitable class loader:
class MemoryClassLoader(map: Map[String, Array[Byte]]) extends ClassLoader {
override protected def findClass(name: String): Class[_] =
map.get(name).map { bytes =>
println(s"defineClass($name, ...)")
defineClass(name, bytes, 0, bytes.length)
} .getOrElse(super.findClass(name)) // throws exception
}
And a test case which contains additional classes (closures):
val exampleSource =
"""val xs = List("hello", "world")
|println(xs.map(_.capitalize).mkString(" "))
|""".stripMargin
def test(fun: String, cl: ClassLoader): Unit = {
val clName = s"$packageName.$fun"
println(s"Resolving class '$clName'...")
val clazz = Class.forName(clName, true, cl)
println("Instantiating...")
val x = clazz.newInstance().asInstanceOf[() => Unit]
println("Invoking 'apply':")
x()
}
locally {
println("Compiling...")
val (fun, bytes) = compile(exampleSource)
val map = unpackJar(bytes)
println("Classes found:")
map.keys.foreach(k => println(s" '$k'"))
val cl = new MemoryClassLoader(map)
test(fun, cl) // should call `defineClass`
test(fun, cl) // should find cached class
}