Scala test match case classes on some field - scala

I have
case class Foo(field1: String, field2: String, field3: String)
expected: Seq[Foo] = Seq(...)
result: Seq[Foo] = Seq(...)
I want to write a Scala matcher that compares the elements of a Seq without taking field1 into consideration.
So two elements match if their field2 and field3 values are equal; the value of field1 is not considered.
Example:
val expected = Seq(
Foo(
"f1", "f2", "f3"
)
)
val result = Seq(
Foo(
"fx", "f2", "f3"
)
)
These two sequences have to match.
result should matchWithoutId(expected)

The documentation is in "Using custom Matchers". For this example, a matcher could look like this:
def matchWithoutId(expected: Foo): Matcher[Foo] = Matcher { actual =>
MatchResult(
actual.field2 == expected.field2 && actual.field3 == expected.field3,
if (actual.field2 != expected.field2)
s"field2 of $actual was not equal to that of $expected"
else
s"field3 of $actual was not equal to that of $expected",
s"field2 and field3 of $actual were equal to those of $expected")
}
Adjust error messages to taste.
Or another approach (probably a better one in this case):
def matchWithoutId(expected: Foo): Matcher[Foo] = have(
'field2 (expected.field2),
'field3 (expected.field3)
)
Normally I'd say that naming properties with Symbols and relying on reflection should be avoided, but here a change to Foo will make the tests fail to compile anyway, because of the access to expected's fields.
EDIT: I missed that you want to compare Seq[Foo]s and not Foos, so Mario Galic's answer is probably the one you want. Still, this can hopefully be useful as well.

Working with "sequences" states
if you want to change how containership is determined for an element
type E, place an implicit Equality[E] in scope or use the explicitly
DSL.
so the following should work
(expected should contain theSameElementsAs (result)) (decided by fooEqualityWithoutId)
Note about contains usage
Note that when you use the explicitly DSL with contain you need to
wrap the entire contain expression in parentheses
Here is the full example
import org.scalactic._
import org.scalatest._
import org.scalatest.matchers.should.Matchers
import org.scalactic.Explicitly._
class CustomizeEqualitySeqSpec extends FlatSpec with Matchers with Explicitly {
case class Foo(field1: String, field2: String, field3: String)
val expected = Seq(Foo("f1", "f2", "f3"))
val result = Seq(Foo("fx", "f2", "f3"))
val fooEqualityWithoutId = new Equality[Foo] {
override def areEqual(a: Foo, b: Any): Boolean = b match {
case Foo(_, field2, field3) => a.field2 == field2 && a.field3 == field3
case _ => false
}
}
"Sequences" should "use custom Equality[E]" in {
(expected should contain theSameElementsAs (result)) (decided by fooEqualityWithoutId)
}
}
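For comparison, here is a library-free sketch of the same idea: project every element onto the fields that should take part in the comparison and compare the projections. The helper name withoutId is hypothetical and not part of ScalaTest.

```scala
case class Foo(field1: String, field2: String, field3: String)

// Hypothetical helper: keep only the fields that take part in the comparison.
def withoutId(foos: Seq[Foo]): Seq[(String, String)] =
  foos.map(foo => (foo.field2, foo.field3))

val expected = Seq(Foo("f1", "f2", "f3"))
val result   = Seq(Foo("fx", "f2", "f3"))

// Plain == on Seq is order-sensitive; sort both projections first
// if the order of elements should not matter.
val sameIgnoringField1 = withoutId(result) == withoutId(expected)
```

Inside a ScalaTest spec the projections can be compared with the ordinary matchers, for example withoutId(result) should contain theSameElementsAs withoutId(expected), so no custom Equality is needed. The trade-off is that failure messages show tuples rather than Foo instances.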

Related

Nested Scala case classes to/from CSV

There are many nice libraries for writing/reading Scala case classes to/from CSV files. I'm looking for something that goes beyond that and can handle nested case classes. For example, here a Match has two Players:
case class Player(name: String, ranking: Int)
case class Match(place: String, winner: Player, loser: Player)
val matches = List(
Match("London", Player("Jane",7), Player("Fred",23)),
Match("Rome", Player("Marco",19), Player("Giulia",3)),
Match("Paris", Player("Isabelle",2), Player("Julien",5))
)
I'd like to effortlessly (no boilerplate!) write/read matches to/from this CSV:
place,winner.name,winner.ranking,loser.name,loser.ranking
London,Jane,7,Fred,23
Rome,Marco,19,Giulia,3
Paris,Isabelle,2,Julien,5
Note the automated header line using the dot "." to form the column name for a nested field, e.g. winner.ranking. I'd be delighted if someone could demonstrate a simple way to do this (say, using reflection or Shapeless).
[Motivation. During data analysis it's convenient to have a flat CSV to play around with, for sorting, filtering, etc., even when case classes are nested. And it would be nice if you could load nested case classes back from such files.]
Since a case class is a Product, getting the values of the various fields is relatively easy. Getting the names of the fields/columns requires Java reflection (at least before Scala 2.13 added Product.productElementNames).
The following function takes a list of case-class instances and returns a list of rows, each of which is a list of strings. It uses recursion to get the values and headers of nested case-class instances. Note that the header is built from classOf[Match], so as written it only suits lists of Match.
def toCsv(p: List[Product]): List[List[String]] = {
def header(c: Class[_], prefix: String = ""): List[String] = {
c.getDeclaredFields.toList.flatMap { field =>
val name = prefix + field.getName
if (classOf[Product].isAssignableFrom(field.getType)) header(field.getType, name + ".")
else List(name)
}
}
def flatten(p: Product): List[String] =
p.productIterator.flatMap {
case p: Product => flatten(p)
case v: Any => List(v.toString)
}.toList
header(classOf[Match]) :: p.map(flatten)
}
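On Scala 2.13 or later, the dotted header names can also be derived from an instance without Java reflection, because Product exposes productElementNames. The following is a minimal sketch under that assumption; headerOf is a hypothetical name, and the two case classes are repeated only to keep the snippet self-contained.

```scala
case class Player(name: String, ranking: Int)
case class Match(place: String, winner: Player, loser: Player)

// Walk the element names and values in parallel. Any element that is itself
// a Product is treated as a nested case class (this also applies to tuples,
// Options and non-empty Lists, which are Products too).
def headerOf(p: Product, prefix: String = ""): List[String] =
  p.productElementNames.zip(p.productIterator).toList.flatMap {
    case (name, child: Product) => headerOf(child, prefix + name + ".")
    case (name, _)              => List(prefix + name)
  }

val header = headerOf(Match("London", Player("Jane", 7), Player("Fred", 23)))
// header.mkString(",") gives the header line of the CSV
```

Unlike the reflection-based header, this variant needs an instance rather than a Class, so it cannot produce a header for an empty list.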
However, constructing case classes from CSV is far more involved: it requires reflection to get the types of the various fields, to create the values from the CSV strings, and to construct the case-class instances.
For simplicity (not saying the code is simple, just so it won't be further complicated), I assume that the order of columns in the CSV is the same as if the file was produced by the toCsv(...) function above.
The following function starts by building a list of instructions for processing a single CSV row (the instructions are also used to verify that the column headers in the CSV match the case-class properties). The instructions are then applied recursively to each CSV row to produce one case-class instance per row.
import scala.reflect.ClassTag
def fromCsv[T <: Product](csv: List[List[String]])(implicit tag: ClassTag[T]): List[T] = {
trait Instruction {
val name: String
val header = true
}
case class BeginCaseClassField(name: String, clazz: Class[_]) extends Instruction {
override val header = false
}
case class EndCaseClassField(name: String) extends Instruction {
override val header = false
}
case class IntField(name: String) extends Instruction
case class StringField(name: String) extends Instruction
case class DoubleField(name: String) extends Instruction
def scan(c: Class[_], prefix: String = ""): List[Instruction] = {
c.getDeclaredFields.toList.flatMap { field =>
val name = prefix + field.getName
val fType = field.getType
if (fType == classOf[Int]) List(IntField(name))
else if (fType == classOf[Double]) List(DoubleField(name))
else if (fType == classOf[String]) List(StringField(name))
else if (classOf[Product].isAssignableFrom(fType)) BeginCaseClassField(name, fType) :: scan(fType, name + ".")
else throw new IllegalArgumentException(s"Unsupported field type: $fType")
} :+ EndCaseClassField(prefix)
}
def produce(instructions: List[Instruction], row: List[String], argAccumulator: List[Any]): (List[Instruction], List[String], List[Any]) = instructions match {
case IntField(_) :: tail => produce(tail, row.drop(1), argAccumulator :+ row.head.toString.toInt)
case StringField(_) :: tail => produce(tail, row.drop(1), argAccumulator :+ row.head.toString)
case DoubleField(_) :: tail => produce(tail, row.drop(1), argAccumulator :+ row.head.toString.toDouble)
case BeginCaseClassField(_, clazz) :: tail =>
val (instructionRemaining, rowRemaining, constructorArgs) = produce(tail, row, List.empty)
val newCaseClass = clazz.getConstructors.head.newInstance(constructorArgs.map(_.asInstanceOf[AnyRef]): _*)
produce(instructionRemaining, rowRemaining, argAccumulator :+ newCaseClass)
case EndCaseClassField(_) :: tail => (tail, row, argAccumulator)
case Nil if row.isEmpty => (Nil, Nil, argAccumulator)
case Nil => throw new IllegalArgumentException("Not all values from CSV row were used")
}
val instructions = BeginCaseClassField(".", tag.runtimeClass) :: scan(tag.runtimeClass)
assert(csv.head == instructions.filter(_.header).map(_.name), "CSV header doesn't match target case-class fields")
csv.drop(1).map(row => produce(instructions, row, List.empty)._3.head.asInstanceOf[T])
}
I've tested this using:
case class Player(name: String, ranking: Int, price: Double)
case class Match(place: String, winner: Player, loser: Player)
val matches = List(
Match("London", Player("Jane", 7, 12.5), Player("Fred", 23, 11.1)),
Match("Rome", Player("Marco", 19, 13.54), Player("Giulia", 3, 41.8)),
Match("Paris", Player("Isabelle", 2, 31.7), Player("Julien", 5, 16.8))
)
val csv = toCsv(matches)
val matchesFromCsv = fromCsv[Match](csv)
assert(matches == matchesFromCsv)
Obviously this should be optimized and hardened if you ever want to use this for production...

Conditionally map a nullable field to None value in Slick

In Slick 2.1, I need to perform a query/map operation where I convert a nullable field to None if it contains a certain non-null value. Not sure whether it matters or not, but in my case the column type in question is a mapped column type. Here is a code snippet which tries to illustrate what I'm trying to do. It won't compile, as the compiler doesn't like the None.
case class Record(field1: Int, field2: Int, field3: MyEnum)
sealed trait MyEnum
val MyValue: MyEnum = new MyEnum { }
// table is a TableQuery[Record]
table.map { r => (
r.field1,
r.field2,
Case If (r.field3 === MyValue) Then MyValue Else None // compile error on 'None'
)
}
The error is something like this:
type mismatch; found : None.type required: scala.slick.lifted.Column[MyEnum]
Actually, the reason I want to do this is that I want to perform a groupBy in which I count the number of records whose field3 contains a given value. I couldn't get the more complicated groupBy expression working, so I backed off to this simpler example which I still can't get working. If there's a more direct way to show me the groupBy expression, that would be fine too. Thanks!
Update
I tried the code suggested by @cvogt, but it produces a compile error. Here is an SSCCE in case anyone can spot what I'm doing wrong. Compilation fails with "value ? is not a member of Int":
import scala.slick.jdbc.JdbcBackend.Database
import scala.slick.driver.H2Driver
object ExpMain extends App {
val dbName = "mydb"
val db = Database.forURL(s"jdbc:h2:mem:${dbName};DB_CLOSE_DELAY=-1", driver = "org.h2.Driver")
val driver = H2Driver
import driver.simple._
class Exp(tag: Tag) extends Table[(Int, Option[Int])](tag, "EXP") {
def id = column[Int]("ID", O.PrimaryKey)
def value = column[Option[Int]]("VALUE")
def * = (id, value)
}
val exp = TableQuery[Exp]
db withSession { implicit session =>
exp.ddl.create
exp += (1, (Some(1)))
exp += (2, None)
exp += (3, (Some(4)))
exp.map { record =>
Case If (record.value === 1) Then 1.? Else None // this will NOT compile
//Case If (record.value === 1) Then Some(1) Else None // this will NOT compile
//Case If (record.value === 1) Then 1 Else 0 // this will compile
}.foreach {
println
}
}
}
I need to perform a query/map operation where I convert a nullable field to None if it contains a certain non-null value
Given the example data you have in the update, and pretending that 1 is the "certain" value you care about, I believe this is the output you expect:
None, None, Some(4)
(for rows with IDs 1, 2 and 3)
If I've understood the problem correctly, is this what you need...?
val q: Query[Column[Option[Int]], Option[Int], Seq] = exp.map { record =>
Case If (record.value === 1) Then (None: Option[Int]) Else (record.value)
}
...which equates to:
select (case when ("VALUE" = 1) then null else "VALUE" end) from "EXP"
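To sanity-check the intended semantics without a database, the same transformation can be modelled on plain Scala collections. This is only an in-memory sketch of what the CASE expression computes, not Slick code; the names rows and certain are made up for the illustration.

```scala
// In-memory stand-in for the EXP table: (ID, VALUE)
val rows: Seq[(Int, Option[Int])] = Seq((1, Some(1)), (2, None), (3, Some(4)))

// The "certain" value that should be turned into None
val certain = 1

// Equivalent of: case when VALUE = 1 then null else VALUE end
val nulled: Seq[Option[Int]] =
  rows.map { case (_, value) => value.filterNot(_ == certain) }

// The count the question was ultimately after: rows whose VALUE equals `certain`
val matching: Int = rows.count { case (_, value) => value.contains(certain) }
```

Here nulled is Seq(None, None, Some(4)), matching the output expected for rows 1, 2 and 3.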
You need to wrap MyValue in an Option, so that both outcomes of the conditional are options. In Slick 2.1 you use the .? operator for that. In Slick 3.0 it will likely be Rep.Some(...).
Try
Case If (r.field3 === MyValue) Then MyValue.? Else None
or
Case If (r.field3 === MyValue) Then MyValue.? Else (None: Option[MyEnum])

Scala Erasure Type Match and Use in Different Method

I have been trying to achieve this for a while, but even with Manifest and the reflection API it is still hard.
With Manifest and reflection, I can match a List[Any] against a class (List[A]), and I can also match on a type T, as in
http://daily-scala.blogspot.co.uk/2010/01/overcoming-type-erasure-in-matching-1.html
How save a TypeTag and then use it later to reattach the type to an Any (Scala 2.10)
But how can I determine the type of the input and then use it in a method?
Say,
object test {
val list : List[List[Any]] = List(
List(2.5, 3.6 ,7.9),
List("EUR","HKD", "USD")
)
def calculateString(in:List[String]) = {
println("It's a String List")
println(in)
}
def calculateDouble(in:List[Double]) = {
println("It's a Double List")
println(in)
}
def main( args: Array[String]){
list.foreach(l=> matchAndCalculate(l))
}
// Copy from Andrzej Jozwik it doesn't work, but it's good to demonstrate the idea
def matchAndCalculate(list:List[Any]) = list match {
case i if i.headOption.exists(_.isInstanceOf[Long]) => calculateLong(i)
case i if i.headOption.exists(_.isInstanceOf[String]) => calculateString(i)
}
}
Many Thanks
Harvey
PS: As Sarah pointed out, the only way may be to keep the type manifest when I first create the lists, before putting them into a more complex structure.
Here's the challenge: is it possible to cast a List[Any] back to (or match it as), say, a List[String], and pass it to a method like def dummyMethod(stringList: List[String]) without upsetting the compiler?
Unless you can change your data structure, Andrzej's solution is the only reasonable way to do this.
You can't really use type manifests, because you have two levels of indirection. You'd need a different type manifest for every member of the outer list. E.g., you could have a List[(List[Any], TypeTag[Any])], but there's no way to get compile-time information about every individual row out of a List unless you build that information at the time that you're constructing the lists.
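As a minimal sketch of "building that information at the time you construct the lists", each inner list can be wrapped together with a ClassTag captured at construction. TaggedList and dispatch are hypothetical names; the two calculate methods return strings here instead of printing, purely to make the result easy to inspect.

```scala
import scala.reflect.ClassTag

// Wrapper that remembers the element type of the list it was built from.
final class TaggedList[A](val items: List[A])(implicit val tag: ClassTag[A])

def calculateDouble(in: List[Double]): String = s"Double list: $in"
def calculateString(in: List[String]): String = s"String list: $in"

// The cast is justified by the tag that was captured when the wrapper was created.
def dispatch(t: TaggedList[_]): String =
  if (t.tag.runtimeClass == classOf[Double])
    calculateDouble(t.items.asInstanceOf[List[Double]])
  else if (t.tag.runtimeClass == classOf[String])
    calculateString(t.items.asInstanceOf[List[String]])
  else s"unsupported element type: ${t.tag}"

val doubles = new TaggedList(List(2.5, 3.6, 7.9))
val strings = new TaggedList(List("EUR", "HKD", "USD"))
val lists: List[TaggedList[_]] = List(doubles, strings)

val descriptions = lists.map(dispatch)
```

The tag only reflects the static element type at the construction site, so a list built as List[Any] would still end up in the unsupported branch.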
If you really wanted to carry along static type information, it would be easy to do this with implicits and just make each entry in your outer list a special class.
One simple variant might look like this:
class CalculableList[A](val elements: List[A])(implicit val calc: Calculator[A]) {
def calculate = calc(elements)
}
trait Calculator[-A] extends (List[A] => Unit)
implicit object StringCalc extends Calculator[String] {
def apply(in: List[String]): Unit = {
println("It's a String List")
println(in)
}
}
implicit object DoubleCalc extends Calculator[Double] {
def apply(in: List[Double]): Unit = {
println("It's a Double List")
println(in)
}
}
val list: List[CalculableList[_]] = List(
new CalculableList(List(1.0, 2.0, 3.0)),
new CalculableList(List("a", "b", "c"))
)
list foreach { _.calculate }
Another option for this kind of generic programming is to use Miles Sabin's Shapeless. This uses special data structures to let you construct arbitrary-sized tuples that can be treated like type-safe lists. It generates a data structure similar to a linked list with a generic type wrapper that keeps track of the type of each row, so you wouldn't want to use it unless your lists are fairly short. It's also a bit difficult to learn, maintain and understand—but it opens up some deep wizardry when you understand and use it appropriately.
I don't know enough about your use case to know whether Shapeless is advisable in this case.
In Shapeless for Scala 2.11, a solution would look something like this:
import shapeless._
val lists = List(1.0, 2.0, 3.0) ::
List("a", "b", "c") ::
HNil
object calc extends Poly1 {
implicit def doubleList = at[List[Double]] { in =>
println("It's a double list")
println(in)
}
implicit def stringList = at[List[String]] { in =>
println("It's a string list")
println(in)
}
}
lists map calc
def calculateString(in:List[String]) = {
println("It's a String List")
println(in)
}
def calculateDouble(in:List[Double]): Unit = {
println("It's a Double List")
println(in)
}
def castTo[T](t:T,list:List[Any]) = list.asInstanceOf[List[T]]
def matchAndCalculate(list:List[Any]) = list.headOption match {
case Some(x:Double) => calculateDouble(castTo(x,list))
case Some(x:String) => calculateString(castTo(x,list))
}
And check:
scala> matchAndCalculate(List(3.4))
It's a Double List
List(3.4)
scala> matchAndCalculate(List("3.4"))
It's a String List
List(3.4)
scala> val list : List[List[Any]] = List(
| List(2.5, 3.6 ,7.9),
| List("EUR","HKD", "USD")
| )
list: List[List[Any]] = List(List(2.5, 3.6, 7.9), List(EUR, HKD, USD))
scala> list.foreach(l=> matchAndCalculate(l))
It's a Double List
List(2.5, 3.6, 7.9)
It's a String List
List(EUR, HKD, USD)
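Matching on headOption only inspects the first element, so a mixed list such as List(1.0, "x") would be cast to List[Double] unchecked, and an empty list would raise a MatchError. A more defensive sketch, assuming the extra pass over the list is acceptable, checks every element before narrowing; narrow and classify are hypothetical names.

```scala
// Returns the list with a narrower element type only if every element matched.
def narrow[A](list: List[Any])(pf: PartialFunction[Any, A]): Option[List[A]] = {
  val typed = list.collect(pf)
  if (typed.size == list.size) Some(typed) else None
}

def classify(list: List[Any]): String = {
  val asDoubles = narrow(list) { case d: Double => d }
  val asStrings = narrow(list) { case s: String => s }
  if (list.isEmpty) "empty list"
  else asDoubles.map(ds => s"Double list: $ds")
    .orElse(asStrings.map(ss => s"String list: $ss"))
    .getOrElse("mixed or unsupported list")
}
```

Because narrow returns a properly typed List[Double] or List[String], the result can be passed to calculateDouble or calculateString without asInstanceOf.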

Allocation of Function Literals in Scala

I have a class that represents sales orders:
class SalesOrder(val f01:String, val f02:Int, ..., f50:Date)
The fXX fields are of various types. I am faced with the problem of creating an audit trail of my orders. Given two instances of the class, I have to determine which fields have changed. I have come up with the following:
class SalesOrder(val f01:String, val f02:Int, ..., val f50:Date){
def auditDifferences(that:SalesOrder): List[String] = {
def diff[A](fieldName:String, getField: SalesOrder => A) =
if(getField(this) != getField(that)) Some(fieldName) else None
val diffList = diff("f01", _.f01) :: diff("f02", _.f02) :: ...
:: diff("f50", _.f50) :: Nil
diffList.flatten
}
}
I was wondering what the compiler does with all the _.fXX functions: are they instantiated just once (statically) and shared by all instances of my class, or will they be instantiated every time I create an instance of my class?
My worry is that, since I will use a lot of SalesOrder instances, it may create a lot of garbage. Should I use a different approach?
One clean way of solving this problem would be to use the standard library's Ordering type class. For example:
class SalesOrder(val f01: String, val f02: Int, val f03: Char) {
def diff(that: SalesOrder) = SalesOrder.fieldOrderings.collect {
case (name, ord) if !ord.equiv(this, that) => name
}
}
object SalesOrder {
val fieldOrderings: List[(String, Ordering[SalesOrder])] = List(
"f01" -> Ordering.by(_.f01),
"f02" -> Ordering.by(_.f02),
"f03" -> Ordering.by(_.f03)
)
}
And then:
scala> val orderA = new SalesOrder("a", 1, 'a')
orderA: SalesOrder = SalesOrder@5827384f
scala> val orderB = new SalesOrder("b", 1, 'b')
orderB: SalesOrder = SalesOrder@3bf2e1c7
scala> orderA diff orderB
res0: List[String] = List(f01, f03)
You almost certainly don't need to worry about the performance of your original formulation, but this version is (arguably) nicer for unrelated reasons.
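Yes, on Scala 2.11 and earlier that creates 50 short-lived function objects per call. From Scala 2.12 on, lambdas that capture nothing, such as _.f01, are compiled through invokedynamic, and the JVM typically reuses a single instance per call site. Either way, I don't think you should be worried unless you have measured evidence that it causes a performance problem in your case.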
But I would define a method that transforms SalesOrder into a Map[String, Any], then you would just have
trait SalesOrder {
def fields: Map[String, Any]
}
def diff(a: SalesOrder, b: SalesOrder): Iterable[String] = {
val af = a.fields
val bf = b.fields
af.collect { case (key, value) if bf(key) != value => key }
}
If the field names are indeed just incremental numbers, you could simplify
trait SalesOrder {
def fields: Iterable[Any]
}
def diff(a: SalesOrder, b: SalesOrder): Iterable[String] =
(a.fields zip b.fields).zipWithIndex.collect {
case ((av, bv), idx) if av != bv => f"f${idx + 1}%02d"
}
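If allocation ever does turn out to matter, the accessor functions from the original formulation can be hoisted into the companion object, so that they are created once and shared by every SalesOrder. This is a sketch with three fields standing in for the fifty; the name accessors is hypothetical.

```scala
final class SalesOrder(val f01: String, val f02: Int, val f03: Char) {
  // Reuses the shared accessors; no function objects are created per call.
  def auditDifferences(that: SalesOrder): List[String] =
    SalesOrder.accessors.collect {
      case (name, get) if get(this) != get(that) => name
    }
}

object SalesOrder {
  // Built once, when the companion object is initialised.
  private val accessors: List[(String, SalesOrder => Any)] = List(
    ("f01", (o: SalesOrder) => o.f01),
    ("f02", (o: SalesOrder) => o.f02),
    ("f03", (o: SalesOrder) => o.f03)
  )
}

val orderA = new SalesOrder("a", 1, 'a')
val orderB = new SalesOrder("b", 1, 'b')
val changed = orderA.auditDifferences(orderB)
```

Returning Any from the accessors boxes primitive fields on each comparison, which is usually negligible next to the rest of an audit trail.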

Scala Macros: Checking for a certain annotation

Thanks to the answers to my previous question, I was able to create a function macro such that it returns a Map that maps each field name to its value of a class, e.g.
...
trait Model
case class User (name: String, age: Int, posts: List[String]) extends Model {
val numPosts: Int = posts.length
...
def foo = "bar"
...
}
So this command
val myUser = User("Foo", 25, List("Lorem", "Ipsum"))
myUser.asMap
returns
Map("name" -> "Foo", "age" -> 25, "posts" -> List("Lorem", "Ipsum"), "numPosts" -> 2)
This is where Tuples for the Map are generated (see Travis Brown's answer):
...
val pairs = weakTypeOf[T].declarations.collect {
case m: MethodSymbol if m.isAccessor =>
val name = c.literal(m.name.decoded)
val value = c.Expr(Select(model, m.name))
reify(name.splice -> value.splice).tree
}
...
Now I want to ignore fields that have the @transient annotation. How would I check if a method has a @transient annotation?
I'm thinking of modifying the snippet above as
val pairs = weakTypeOf[T].declarations.collect {
case m: MethodSymbol if m.isAccessor && !m.annotations.exists(???) =>
val name = c.literal(m.name.decoded)
val value = c.Expr(Select(model, m.name))
reify(name.splice -> value.splice).tree
}
but I can't find what I need to write in the exists part. How would I get @transient as an Annotation so I could pass it there?
Thanks in advance!
The annotation will be on the val itself, not on the accessor. The easiest way to access the val is through the accessed method on MethodSymbol:
def isTransient(m: MethodSymbol) = m.accessed.annotations.exists(
_.tpe =:= typeOf[scala.transient]
)
Now you can just write the following in your collect:
case m: MethodSymbol if m.isAccessor && !isTransient(m) =>
Note that the version of isTransient I've given here has to be defined in your macro, since it needs the imports from c.universe, but you could factor it out by adding a Universe argument if you're doing this kind of thing in several macros.
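Outside a macro, the same "annotation lives on the val, not on the accessor" point can be observed at runtime: scalac compiles @transient on a val to a JVM field with the transient modifier, which plain Java reflection can read. The sketch below is a runtime analogue rather than macro code; nonTransientFieldNames and the cache field are made up for the illustration.

```scala
import java.lang.reflect.Modifier

case class User(name: String, age: Int) {
  // Marked @transient, so the backing field carries the JVM transient modifier.
  @transient val cache: String = name.toUpperCase
}

// Names of the declared fields that are neither transient nor compiler-generated.
def nonTransientFieldNames(c: Class[_]): List[String] =
  c.getDeclaredFields.toList
    .filterNot(f => Modifier.isTransient(f.getModifiers) || f.isSynthetic)
    .map(_.getName)

val names = nonTransientFieldNames(classOf[User])
```

The order of getDeclaredFields is not guaranteed by the JVM, so names should be treated as a set rather than as an ordered list.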