Conditionally apply a function in Scala - How to write this function? - scala

I need to conditionally apply a function f1 to the elements of a collection, depending on the result of a function f2 that takes each element as an argument and returns a boolean. If f2(e) is true, f1(e) is applied; otherwise 'e' is returned "as is".
My intent is to write a general-purpose function able to work on any kind of collection.
c: C[E] // My collection
f1 = ( E => E ) // transformation function
f2 = ( E => Boolean ) // conditional function
I cannot come up with a solution. Here's my idea, but I'm afraid I'm in deep water:
/* Notice this code doesn't compile ~ partially pseudo-code */
def conditionallyApply[E, C[_]](c: C[E], f2: E => Boolean, f1: E => E): C[E] = {
@scala.annotation.tailrec
def loop(a: C[E], c: C[E]): C[E] = {
c match {
case Nil => a // Here head / tail just express the idea, but I want to use a generic collection
case head :: tail => loop(a ++ (if (f2(head)) f1(head) else head), tail)
}
}
loop(??, c) // how to get an empty collection of the same type as the one from the input?
}
Could any of you enlighten me?

This looks like a simple map of a Functor. Using scalaz:
def condMap[F[_],A](fa: F[A])(f: A => A, p: A => Boolean)(implicit F:Functor[F]) =
F.map(fa)(x => if (p(x)) f(x) else x)
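To show what the Functor constraint buys you without pulling in scalaz, here is a minimal sketch with a hand-rolled type class. The `Functor` trait, the two instances and the `CondMapSketch` object are hypothetical names made up for illustration; scalaz's real `Functor` offers far more than this.

```scala
object CondMapSketch {
  // Hypothetical minimal type class, standing in for scalaz's Functor.
  trait Functor[F[_]] {
    def map[A, B](fa: F[A])(f: A => B): F[B]
  }

  // Instances simply delegate to the container's own map.
  implicit val listFunctor: Functor[List] = new Functor[List] {
    def map[A, B](fa: List[A])(f: A => B): List[B] = fa.map(f)
  }

  implicit val optionFunctor: Functor[Option] = new Functor[Option] {
    def map[A, B](fa: Option[A])(f: A => B): Option[B] = fa.map(f)
  }

  // Same shape as the scalaz version: apply f only where p holds.
  def condMap[F[_], A](fa: F[A])(f: A => A, p: A => Boolean)(implicit F: Functor[F]): F[A] =
    F.map(fa)(x => if (p(x)) f(x) else x)
}
```

With this in scope, `CondMapSketch.condMap(List(1, 2, 3, 4))(_ + 1, _ % 2 == 0)` keeps the List type, and the same call works for an Option.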

Not sure why you would need scalaz for something so pedestrian.
// example collection and functions
val xs = 1 :: 2 :: 3 :: 4 :: Nil
def f1(v: Int) = v + 1
def f2(v: Int) = v % 2 == 0
// just conditionally transform inside a map
val transformed = xs.map(x => if (f2(x)) f1(x) else x)

Without using scalaz, you can use the CanBuildFrom pattern. This is exactly what is used in the standard collections library. Of course, in your specific case, this is probably over-engineered as a simple call to map is enough.
import scala.collection.generic._
def cmap[A, C[A] <: Traversable[A]](col: C[A])(f: A ⇒ A, p: A ⇒ Boolean)(implicit bf: CanBuildFrom[C[A], A, C[A]]): C[A] = {
val b = bf(col)
b.sizeHint(col)
for (x <- col) if(p(x)) b += f(x) else b += x
b.result
}
And now the usage:
scala> def f(i: Int) = 0
f: (i: Int)Int
scala> def p(i: Int) = i % 2 == 0
p: (i: Int)Boolean
scala> cmap(Seq(1, 2, 3, 4))(f, p)
res0: Seq[Int] = List(1, 0, 3, 0)
scala> cmap(List(1, 2, 3, 4))(f, p)
res1: List[Int] = List(1, 0, 3, 0)
scala> cmap(Set(1, 2, 3, 4))(f, p)
res2: scala.collection.immutable.Set[Int] = Set(1, 0, 3)
Observe how the return type is always the same as the one provided.
The function could be nicely encapsulated in an implicit class, using the "pimp my library" pattern.
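CanBuildFrom was removed in Scala 2.13, so the code above only compiles on older versions. A rough equivalent for 2.13 and later, assuming it is enough to support collections whose `map` preserves the collection type, can bound the type constructor by `IterableOps`. The names `MapIfSketch` and `mapIf` are made up for this sketch.

```scala
import scala.collection.IterableOps

object MapIfSketch {
  // CC is the collection's type constructor (List, Vector, Set, ...).
  // IterableOps[X, CC, CC[X]] guarantees that map returns a CC again.
  def mapIf[A, CC[X] <: IterableOps[X, CC, CC[X]]](col: CC[A])(p: A => Boolean)(f: A => A): CC[A] =
    col.map(x => if (p(x)) f(x) else x)
}
```

As with cmap, the static return type follows the input: a List gives a List, a Set gives a Set.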

For something like this you can use an implicit class. They were added just for this reason, to enhance libraries you can't change.
It would work like this:
object ImplicitStuff {
implicit class SeqEnhancer[A](s:Seq[A]) {
def transformIf( cond : A => Boolean)( f : A => A ):Seq[A] =
s.map{ x => if(cond(x)) f(x) else x }
}
def main(a:Array[String]) = {
val s = Seq(1,2,3,4,5,6,7)
println(s.transformIf(_ % 2 ==0){ _ * 2})
// result is (1, 4, 3, 8, 5, 12, 7)
}
}
Basically, if you call a method that does not exist on the object you're calling it on (in this case, a Seq), the compiler checks whether there's an implicit class in scope that implements it. At the call site it looks like a built-in method.

Related

In Scala 2.10 how to add each element in two generic lists together

I am trying to rewrite some java math classes into Scala, but am having an odd problem.
class Polynomials[@specialized T](val coefficients:List[T]) {
def +(operand:Polynomials[T]):Polynomials[T] = {
return new Polynomials[T](coefficients =
(operand.coefficients, this.coefficients).zipped.map(_ + _))
}
}
My problem may be similar to this question: How do I make a class generic for all Numeric Types?, but when I remove the @specialized I get the same error:
type mismatch; found : T required: String
The second underscore in the map function is highlighted for the error, but I don't think that is the problem.
What I want to do is have:
Polynomial(1, 2, 3) + Polynomial(2, 3, 4) return Polynomial(3, 5, 7)
And Polynomial(1, 2, 3, 5) + Polynomial(2, 3, 4) return Polynomial(3, 5, 7, 5)
For the second one I may have to pad the shorter list with zero elements in order to get this to work, but that is my goal on this function.
So, how can I get this function to compile, so I can test it?
List is not specialized, so there's not much point making the class specialized. Only Array is specialized.
class Poly[T](val coef: List[T]) {
def +(op: Poly[T])(implicit adder: (T,T) => T) =
new Poly(Poly.combine(coef, op.coef, adder))
}
object Poly {
def combine[A](a: List[A], b: List[A], f: (A,A) => A, part: List[A] = Nil): List[A] = {
a match {
case Nil => if (b.isEmpty) part.reverse else combine(b,a,f,part)
case x :: xs => b match {
case Nil => part.reverse ::: a
case y :: ys => combine(xs, ys, f, f(x,y) :: part)
}
}
}
}
Now we can
implicit val stringAdd = (s: String, t: String) => (s+t)
scala> val p = new Poly(List("red","blue"))
p: Poly[String] = Poly@555214b9
scala> val q = new Poly(List("fish","cat","dog"))
q: Poly[String] = Poly@20f5498f
scala> val r = p+q; r.coef
r: Poly[String] = Poly@180f471e
res0: List[String] = List(redfish, bluecat, dog)
You could also ask the class provide the adder rather than the + method, or you could subclass Function2 so that you don't pollute things with implicit addition functions.
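The "found: T, required: String" error in the question comes from `_ + _` on an unconstrained T: the only `+` the compiler can find is the string-concatenation one that Predef adds to every type. If the coefficients are always numeric, a sketch using the standard `Numeric` type class avoids a custom adder, and `zipAll` handles the zero padding the question asks for. The class name `Polynomial` is taken from the question's examples; the rest is an assumption about what is wanted.

```scala
// Numeric[T] supplies zero and plus for Int, Long, Double, BigDecimal, ...
final case class Polynomial[T](coefficients: List[T])(implicit num: Numeric[T]) {
  def +(that: Polynomial[T]): Polynomial[T] =
    Polynomial(
      coefficients
        // zipAll pads the shorter list with zero instead of truncating
        .zipAll(that.coefficients, num.zero, num.zero)
        .map { case (a, b) => num.plus(a, b) }
    )
}
```

With this, `Polynomial(List(1, 2, 3, 5)) + Polynomial(List(2, 3, 4))` gives `Polynomial(List(3, 5, 7, 5))`.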

Remove duplicates in List specifying equality function

I have a List[A]. What is an idiomatic way of removing duplicates given an equality function (a:A, b:A) => Boolean? I cannot generally override equals for A.
The way I can think now is creating a wrapping class AExt with overridden equals, then
list.map(new AExt(_)).distinct
But I wonder if there's a cleaner way.
There is a simple (simpler) way to do this:
list.groupBy(_.key).mapValues(_.head)
If you want you can use the resulting map instantly by replacing _.head by a function block like:
sameElements => { val observedItem = sameElements.head
new A (var1 = observedItem.firstAttr,
var2 = "SomethingElse") }
to return a new A for each distinct element.
There is only one minor problem. The above code (list.groupBy(_.key).mapValues(_.head)) does not express the intention of removing duplicates very well. For that reason it would be great to have a function like distinctIn[A](attr: A => B) or distinctBy[A](eq: (A, A) => Boolean).
Using the Foo and customEquals from missingfaktor's answer:
case class Foo(a: Int, b: Int)
val (a, b, c, d) = (Foo(3, 4), Foo(3, 1), Foo(2, 5), Foo(2, 5))
def customEquals(x: Foo, y: Foo) = x.a == y.a
(Seq(a, b, c, d).foldLeft(Seq[Foo]()) {
(unique, curr) => {
if (!unique.exists(customEquals(curr, _)))
curr +: unique
else
unique
}
}).reverse
If result ordering is important but it does not matter which of the duplicates is removed, then foldRight is preferable:
Seq(a, b, c, d).foldRight(Seq[Foo]()) {
(curr, unique) => {
if (!unique.exists(customEquals(curr, _)))
curr +: unique
else
unique
}
}
If you expect your lists to be quite long, I think I'd go via an intermediate collection that is a Set, since testing for presence (via exists or find) on a Seq is O(n).
Rather than write a custom equals; decide what property the elements are equal by. So instead of:
def myCustomEqual(a1: A, a2: A) = a1.foo == a2.foo && a1.bar == a2.bar
Make a Key. Like so:
type Key = (Foo, Bar)
def key(a: A) = (a.foo, a.bar)
Then you can add the keys to a Set to see whether you have come across them before.
var keys = Set.empty[Key]
((List.empty[A] /: as) { (l, a) =>
val k = key(a)
if (keys(k)) l else { keys += k; a +: l }
}).reverse
Of course, this solution has worse space complexity and potentially worse performance (as you are creating extra objects, the keys) in the case of very short lists. If you do not like the var in the fold, you might like to look at how you could achieve this using State and Traverse from scalaz 7.
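A lighter way to drop the var, without scalaz, is to carry the set of seen keys through the fold together with the result. This is only a sketch; `DistinctByKey` and `distinctByKey` are names invented here.

```scala
object DistinctByKey {
  // Keeps the first element for each key, preserving the input order.
  def distinctByKey[A, K](as: List[A])(key: A => K): List[A] = {
    val (_, kept) = as.foldLeft((Set.empty[K], List.empty[A])) {
      case ((seen, acc), a) =>
        val k = key(a)
        // Skip a if its key was already seen, otherwise record both.
        if (seen(k)) (seen, acc) else (seen + k, a :: acc)
    }
    kept.reverse // elements were prepended, so restore the original order
  }
}
```

For example, `DistinctByKey.distinctByKey(List((3, 4), (3, 1), (2, 5), (2, 5)))(_._1)` keeps `(3, 4)` and `(2, 5)`.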
scala> case class Foo(a: Int, b: Int)
defined class Foo
scala> val (a, b, c, d) = (Foo(3, 4), Foo(3, 1), Foo(2, 5), Foo(2, 5))
a: Foo = Foo(3,4)
b: Foo = Foo(3,1)
c: Foo = Foo(2,5)
d: Foo = Foo(2,5)
scala> def customEquals(x: Foo, y: Foo) = x.a == y.a
customEquals: (x: Foo, y: Foo)Boolean
scala> Seq(a, b, c, d) filter {
| var seq = Seq.empty[Foo]
| x => {
| if(seq.exists(customEquals(x, _))) {
| false
| } else {
| seq :+= x
| true
| }
| }
| }
res13: Seq[Foo] = List(Foo(3,4), Foo(2,5))
Starting with Scala 2.13, we can use the new distinctBy method, which returns the elements of a sequence with duplicates removed, where duplicates are determined by == after applying a transforming function f:
def distinctBy[B](f: (A) => B): List[A]
For instance:
// case class A(a: Int, b: String, c: Double)
// val list = List(A(1, "hello", 3.14), A(2, "world", 3.14), A(1, "hello", 12.3))
list.distinctBy(x => (x.a, x.b)) // List(A(1, "hello", 3.14), A(2, "world", 3.14))
list.distinctBy(_.c) // List(A(1, "hello", 3.14), A(1, "hello", 12.3))
case class Foo (a: Int, b: Int)
val x = List(Foo(3,4), Foo(3,1), Foo(2,5), Foo(2,5))
def customEquals(x : Foo, y: Foo) = (x.a == y.a && x.b == y.b)
x.foldLeft(Nil : List[Foo]) {(list, item) =>
val exists = list.find(x => customEquals(item, x))
if (exists.isEmpty) item :: list
else list
}.reverse
res0: List[Foo] = List(Foo(3,4), Foo(3,1), Foo(2,5))

Sequence with Streams in Scala

Suppose there is a sequence a[i] = f(a[i-1], a[i-2], ... a[i-k]). How would you code it using streams in Scala?
It should be possible to generalize this to any k by using an array for a plus a separate k parameter, and by giving the function, for instance, a rest (varargs) parameter.
def next(a1:Any, ..., ak:Any, f: (Any, ..., Any) => Any):Stream[Any] = {
val n = f(a1, ..., ak)
Stream.cons(n, next(a2, ..., n, f))
}
val myStream = next(init1, ..., initk, f)
To reach the 1000th element, call drop(1000) on the resulting stream.
An update to show how this could be done with varargs. Beware that there is no arity check for the passed function:
object Test extends App {
def next(a:Seq[Long], f: (Long*) => Long): Stream[Long] = {
val v = f(a: _*)
Stream.cons(v, next(a.tail ++ Array(v), f))
}
def init(firsts:Seq[Long], rest:Seq[Long], f: (Long*) => Long):Stream[Long] = {
rest match {
case Nil => next(firsts, f)
case x :: xs => Stream.cons(x,init(firsts, xs, f))
}
}
def sum(a:Long*):Long = {
a.sum
}
val myStream = init(Seq[Long](1,1,1), Seq[Long](1,1,1), sum)
myStream.take(12).foreach(println)
}
Is this OK?
(I use a[i] = f(a[i-k], a[i-k+1], ... a[i-1]) instead of a[i] = f(a[i-1], a[i-2], ... a[i-k]), since I prefer it this way.)
/**
Generates a Stream[T] from the given first k items and a function that maps k items to the next one.
*/
def getStream[T](f : T => Any,a : T*): Stream[T] = {
def invoke[T](fun: T => Any, es: T*): T = {
if(es.size == 1) fun.asInstanceOf[T=>T].apply(es.head)
else invoke(fun(es.head).asInstanceOf[T => Any],es.tail :_*)
}
Stream.iterate(a){ es => es.tail :+ invoke(f,es: _*)}.map{ _.head }
}
For example, the following code generates the Fibonacci sequence:
scala> val fn = (x: Int, y: Int) => x+y
fn: (Int, Int) => Int = <function2>
scala> val fib = getStream(fn.curried,1,1)
fib: Stream[Int] = Stream(1, ?)
scala> fib.take(10).toList
res0: List[Int] = List(1, 1, 2, 3, 5, 8, 13, 21, 34, 55)
The following code can generate a sequence {an} where a1 = 1, a2 = 2, a3 = 3, a(n+3) = a(n) + 2a(n+1) + 3a(n+2).
scala> val gn = (x: Int, y: Int, z: Int) => x + 2*y + 3*z
gn: (Int, Int, Int) => Int = <function3>
scala> val seq = getStream(gn.curried,1,2,3)
seq: Stream[Int] = Stream(1, ?)
scala> seq.take(10).toList
res1: List[Int] = List(1, 2, 3, 14, 50, 181, 657, 2383, 8644, 31355)
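If you can live with f receiving the whole window of the last k values as a Seq, the casts are not needed. Here is a minimal sketch of the same sliding-window idea; `Recurrence` and `recurrence` are made-up names, and it returns an Iterator rather than a Stream, which you can convert afterwards if you need one.

```scala
object Recurrence {
  // init holds the first k values; f maps the current window of k values
  // (oldest first) to the next value of the sequence.
  def recurrence[A](init: Seq[A])(f: Seq[A] => A): Iterator[A] = {
    require(init.nonEmpty, "need at least one initial value")
    // Slide the window: drop the oldest value, append the newly computed one,
    // and emit the head of each window.
    Iterator.iterate(init.toVector)(w => w.tail :+ f(w)).map(_.head)
  }
}
```

The two examples above become `Recurrence.recurrence(Seq(1, 1))(_.sum)` and `Recurrence.recurrence(Seq(1, 2, 3)) { case Seq(x, y, z) => x + 2*y + 3*z }`. The trade-off is that the arity of f is no longer checked by the compiler.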
The short answer, which is probably what you are looking for, is a pattern to define your Stream once you have fixed a chosen k for the arity of f (i.e. you have a fixed type for f). The following pattern gives you a Stream whose n-th element is the term a[n] of your sequence:
def recStreamK [A](f : A ⇒ A ⇒ ... A) (x1:A) ... (xk:A):Stream[A] =
x1 #:: recStreamK (f) (x2)(x3) ... (xk) (f(x1)(x2) ... (xk))
(credit : it is very close to the answer of andy petrella, except that the initial elements are set up correctly, and consequently the rank in the Stream matches that in the sequence)
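To make the pattern concrete, here it is written out for k = 2. This is just one instantiation; `RecStreams` and `recStream2` are names chosen for this sketch. It uses Stream to match the rest of this thread; on Scala 2.13 and later Stream is deprecated in favour of LazyList, which supports the same `#::` construction.

```scala
object RecStreams {
  // k = 2: emit x1, then continue with the window shifted by one,
  // i.e. (x2, f(x1)(x2)). The tail of #:: is evaluated lazily,
  // so the recursion does not run away.
  def recStream2[A](f: A => A => A)(x1: A)(x2: A): Stream[A] =
    x1 #:: recStream2(f)(x2)(f(x1)(x2))
}
```

Fibonacci is then `RecStreams.recStream2[Int](a => b => a + b)(1)(1)`.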
If you want to generalize over k, this is possible in a type-safe manner (with arity checking) in Scala, using prioritized overlapping implicits. The code (~80 lines) is available as a gist here. I'm afraid I got a little carried away and explained it in a detailed and overlong blog post there.
Unfortunately, we cannot generalize over number and be type safe at the same time. So we’ll have to do it all manually:
def seq2[T, U](initials: Tuple2[T, T]) = new {
def apply(fun: Function2[T, T, T]): Stream[T] = {
initials._1 #::
initials._2 #::
(apply(fun) zip apply(fun).tail).map {
case (a, b) => fun(a, b)
}
}
}
And we get def fibonacci = seq2((1, 1))(_ + _).
def seq3[T, U](initials: Tuple3[T, T, T]) = new {
def apply(fun: Function3[T, T, T, T]): Stream[T] = {
initials._1 #::
initials._2 #::
initials._3 #::
(apply(fun) zip apply(fun).tail zip apply(fun).tail.tail).map {
case ((a, b), c) => fun(a, b, c)
}
}
}
def tribonacci = seq3((1, 1, 1))(_ + _ + _)
… and up to 22.
I hope the pattern is getting clear somehow. (We could of course improve and exchange the initials tuple with separate arguments. This saves us a pair of parentheses later when we use it.) If some day in the future, the Scala macro language arrives, this hopefully will be easier to define.

abstracting over a collection

Recently, I wrote an iterator for a cartesian product of Anys. I started with a List of Lists, but realized that I could easily switch to the more abstract trait Seq.
I know, you like to see the code. :)
class Cartesian (val ll: Seq[Seq[_]]) extends Iterator [Seq[_]] {
def combicount: Int = (1 /: ll) (_ * _.length)
val last = combicount
var iter = 0
override def hasNext (): Boolean = iter < last
override def next (): Seq[_] = {
val res = combination (ll, iter)
iter += 1
res
}
def combination (xx: Seq [Seq[_]], i: Int): List[_] = xx match {
case Nil => Nil
case x :: xs => x (i % x.length) :: combination (xs, i / x.length)
}
}
And a client of that class:
object Main extends Application {
val illi = new Cartesian (List ("abc".toList, "xy".toList, "AB".toList))
// val ivvi = new Cartesian (Vector (Vector (1, 2, 3), Vector (10, 20)))
val issi = new Cartesian (Seq (Seq (1, 2, 3), Seq (10, 20)))
// val iaai = new Cartesian (Array (Array (1, 2, 3), Array (10, 20)))
(0 to 5).foreach (dummy => println (illi.next ()))
// (0 to 5).foreach (dummy => println (issi.next ()))
}
/*
List(a, x, A)
List(b, x, A)
List(c, x, A)
List(a, y, A)
List(b, y, A)
List(c, y, A)
*/
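For readers following along: `combination` treats the counter as a mixed-radix number, where each inner collection's length is the radix of one digit and the first collection is the fastest-changing digit. A small sketch of just that decoding step (the `MixedRadix` object and `decode` are names invented here, working on the lengths only):

```scala
object MixedRadix {
  // Splits i into one index per collection; lengths are the collection sizes.
  def decode(i: Int, lengths: List[Int]): List[Int] = lengths match {
    case Nil       => Nil
    case n :: rest => (i % n) :: decode(i / n, rest)
  }
}
```

For the lengths 3, 2, 2 of the example, `MixedRadix.decode(4, List(3, 2, 2))` is `List(1, 1, 0)`, which picks b, y, A: the fifth line of the output above.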
The code works well for Seq and Lists (which are Seqs), but of course not for Arrays or Vector, which aren't of type Seq, and don't have a cons-method '::'.
But the logic could be used for such collections too.
I could try to write an implicit conversion to and from Seq for Vector, Array, and such, or try to write my own similar implementation, or write a wrapper which transforms the collection to a Seq of Seq, calls 'hasNext' and 'next' for the inner collection, and converts the result to an Array, Vector or whatever. (I tried to implement such workarounds, but I have to admit it's not that easy. For a real-world problem I would probably rewrite the Iterator independently.)
However, the whole thing gets a bit out of control if I have to deal with Arrays of Lists or Lists of Arrays and other mixed cases.
What would be the most elegant way to write the algorithm in the broadest, possible way?
There are two solutions. The first is to not require the containers to be a subclass of some generic super class, but to be convertible to one (by using implicit function arguments). If the container is already a subclass of the required type, there's a predefined identity conversion which only returns it.
import collection.mutable.Builder
import collection.TraversableLike
import collection.generic.CanBuildFrom
import collection.mutable.SeqLike
class Cartesian[T, ST[T], TT[S]](val ll: TT[ST[T]])(implicit cbf: CanBuildFrom[Nothing, T, ST[T]], seqLike: ST[T] => SeqLike[T, ST[T]], traversableLike: TT[ST[T]] => TraversableLike[ST[T], TT[ST[T]]] ) extends Iterator[ST[T]] {
def combicount (): Int = (1 /: ll) (_ * _.length)
val last = combicount - 1
var iter = 0
override def hasNext (): Boolean = iter < last
override def next (): ST[T] = {
val res = combination (ll, iter, cbf())
iter += 1
res
}
def combination (xx: TT[ST[T]], i: Int, builder: Builder[T, ST[T]]): ST[T] =
if (xx.isEmpty) builder.result
else combination (xx.tail, i / xx.head.length, builder += xx.head (i % xx.head.length) )
}
This sort of works:
scala> new Cartesian[String, Vector, Vector](Vector(Vector("a"), Vector("xy"), Vector("AB")))
res0: Cartesian[String,Vector,Vector] = empty iterator
scala> new Cartesian[String, Array, Array](Array(Array("a"), Array("xy"), Array("AB")))
res1: Cartesian[String,Array,Array] = empty iterator
I needed to explicitly pass the types because of bug https://issues.scala-lang.org/browse/SI-3343
One thing to note is that this is better than using existential types, because calling next on the iterator returns the right type, and not Seq[Any].
There are several drawbacks here:
If the container is not a subclass of the required type, it is converted to one, which costs performance.
The algorithm is not completely generic. We need the types to be convertible to SeqLike or TraversableLike only to use a subset of the functionality these types offer, so writing a conversion function can be tricky.
What if some capabilities can be interpreted differently in different contexts? For example, a rectangle has two 'length' properties (width and height).
Now for the alternative solution. We note that we don't actually care about the types of collections, just their capabilities:
TT should have foldLeft, get(i: Int) (to get head/tail)
ST should have length, get(i: Int) and a Builder
So we can encode these:
trait HasGet[T, CC[_]] {
def get(cc: CC[T], i: Int): T
}
object HasGet {
implicit def seqLikeHasGet[T, CC[X] <: SeqLike[X, _]] = new HasGet[T, CC] {
def get(cc: CC[T], i: Int): T = cc(i)
}
implicit def arrayHasGet[T] = new HasGet[T, Array] {
def get(cc: Array[T], i: Int): T = cc(i)
}
}
trait HasLength[CC] {
def length(cc: CC): Int
}
object HasLength {
implicit def seqLikeHasLength[CC <: SeqLike[_, _]] = new HasLength[CC] {
def length(cc: CC) = cc.length
}
implicit def arrayHasLength[T] = new HasLength[Array[T]] {
def length(cc: Array[T]) = cc.length
}
}
trait HasFold[T, CC[_]] {
def foldLeft[A](cc: CC[T], zero: A)(op: (A, T) => A): A
}
object HasFold {
implicit def seqLikeHasFold[T, CC[X] <: SeqLike[X, _]] = new HasFold[T, CC] {
def foldLeft[A](cc: CC[T], zero: A)(op: (A, T) => A): A = cc.foldLeft(zero)(op)
}
implicit def arrayHasFold[T] = new HasFold[T, Array] {
def foldLeft[A](cc: Array[T], zero: A)(op: (A, T) => A): A = {
var i = 0
var result = zero
while (i < cc.length) {
result = op(result, cc(i))
i += 1
}
result
}
}
}
(Strictly speaking, HasFold is not required, since it can be implemented in terms of length and get, but I added it here so the algorithm translates more cleanly.)
Now the algorithm is:
class Cartesian[T, ST[_], TT[Y]](val ll: TT[ST[T]])(implicit cbf: CanBuildFrom[Nothing, T, ST[T]], stHasLength: HasLength[ST[T]], stHasGet: HasGet[T, ST], ttHasFold: HasFold[ST[T], TT], ttHasGet: HasGet[ST[T], TT], ttHasLength: HasLength[TT[ST[T]]]) extends Iterator[ST[T]] {
def combicount (): Int = ttHasFold.foldLeft(ll, 1)((a,l) => a * stHasLength.length(l))
val last = combicount - 1
var iter = 0
override def hasNext (): Boolean = iter < last
override def next (): ST[T] = {
val res = combination (ll, 0, iter, cbf())
iter += 1
res
}
def combination (xx: TT[ST[T]], j: Int, i: Int, builder: Builder[T, ST[T]]): ST[T] =
if (ttHasLength.length(xx) == j) builder.result
else {
val head = ttHasGet.get(xx, j)
val headLength = stHasLength.length(head)
combination (xx, j + 1, i / headLength, builder += stHasGet.get(head, (i % headLength) ))
}
}
And use:
scala> new Cartesian[String, Vector, List](List(Vector("a"), Vector("xy"), Vector("AB")))
res6: Cartesian[String,Vector,List] = empty iterator
scala> new Cartesian[String, Array, Array](Array(Array("a"), Array("xy"), Array("AB")))
res7: Cartesian[String,Array,Array] = empty iterator
Scalaz probably has all of this predefined for you, unfortunately, I don't know it well.
(again I need to pass the types because inference doesn't infer the right kind)
The benefit is that the algorithm is now completely generic, and there is no need for implicit conversions from Array to WrappedArray for it to work.
Exercise: define for tuples ;-)
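If you do not need the result to have the same type as the inner containers, a much shorter (and less general) sketch is to view everything as Seq and build the product lazily with a fold. `CartesianSketch` and `cartesian` are hypothetical names; note that here the last collection varies fastest, unlike the iterator in the question, and every combination comes back as a Vector.

```scala
object CartesianSketch {
  // Start with one empty combination and extend every prefix
  // by every element of the next collection.
  def cartesian[A](ll: Seq[Seq[A]]): Iterator[Vector[A]] =
    ll.foldLeft(Iterator(Vector.empty[A])) { (acc, xs) =>
      acc.flatMap(prefix => xs.iterator.map(x => prefix :+ x))
    }
}
```

Lists and Vectors can be mixed freely as inner collections; Arrays would have to be converted with toSeq first, which is exactly the conversion cost the type-class version avoids.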

How can I extend Scala collections with an argmax method?

I would like to add to all collections where it makes sense, an argMax method.
How to do it? Use implicits?
On Scala 2.8, this works:
val list = List(1, 2, 3)
def f(x: Int) = -x
val argMax = list max (Ordering by f)
As pointed out by mkneissl, this does not return the set of maximum points. Here's an alternate implementation that does, and tries to reduce the number of calls to f. If calls to f don't matter that much, see mkneissl's answer. Also, note that his answer is curried, which provides superior type inference.
def argMax[A, B: Ordering](input: Iterable[A], f: A => B) = {
val fList = input map f
val maxFList = fList.max
input.view zip fList filter (_._2 == maxFList) map (_._1) toSet
}
scala> argMax(-2 to 2, (x: Int) => x * x)
res15: scala.collection.immutable.Set[Int] = Set(-2, 2)
The argmax function (as I understand it from Wikipedia)
def argMax[A,B](c: Traversable[A])(f: A=>B)(implicit o: Ordering[B]): Traversable[A] = {
val max = (c map f).max(o)
c filter { f(_) == max }
}
If you really want, you can pimp it onto the collections
implicit def enhanceWithArgMax[A](c: Traversable[A]) = new {
def argMax[B](f: A=>B)(implicit o: Ordering[B]): Traversable[A] = ArgMax.argMax(c)(f)(o)
}
and use it like this
val l = -2 to 2
assert (argMax(l)(x => x*x) == List(-2,2))
assert (l.argMax(x => x*x) == List(-2,2))
(Scala 2.8)
Yes, the usual way would be to use the 'pimp my library' pattern to decorate your collection. For example (N.B. just as illustration, not meant to be a correct or working example):
trait PimpedList[A] {
val l: List[A]
//example argMax, not meant to be correct
def argMax[T <% Ordered[T]](f:T => T) = {error("your definition here")}
}
implicit def toPimpedList[A](xs: List[A]) = new PimpedList[A] {
val l = xs
}
scala> def f(i:Int):Int = 10
f: (i: Int) Int
scala> val l = List(1,2,3)
l: List[Int] = List(1, 2, 3)
scala> l.argMax(f)
java.lang.RuntimeException: your definition here
at scala.Predef$.error(Predef.scala:60)
at PimpedList$class.argMax(:12)
//etc etc...
Nice and easy? This gives the index of the maximum element:
val l = List(1,0,10,2)
l.zipWithIndex.maxBy(x => x._1)._2
You can add functions to an existing API in Scala by using the Pimp my Library pattern. You do this by defining an implicit conversion function. For example, I have a class Vector3 to represent 3D vectors:
class Vector3 (val x: Float, val y: Float, val z: Float)
Suppose I want to be able to scale a vector by writing something like: 2.5f * v. I can't directly add a * method to class Float of course, but I can supply an implicit conversion function like this:
implicit def scaleVector3WithFloat(f: Float) = new {
def *(v: Vector3) = new Vector3(f * v.x, f * v.y, f * v.z)
}
Note that this returns an object of a structural type (the new { ... } construct) that contains the * method.
I haven't tested it, but I guess you could do something like this:
implicit def argMaxImplicit[A](t: Traversable[A]) = new {
def argMax() = ...
}
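To fill in the "..." with something concrete, here is one possible sketch using an implicit class (available since Scala 2.10) instead of a structural type. `ArgMaxSyntax` and `ArgMaxOps` are names made up for this example, and returning a List of all maximizing elements is a design choice of the sketch, not something the question prescribes.

```scala
object ArgMaxSyntax {
  implicit class ArgMaxOps[A](private val xs: Iterable[A]) extends AnyVal {
    // Returns every element at which f attains its maximum, in iteration order.
    def argMax[B](f: A => B)(implicit ord: Ordering[B]): List[A] =
      if (xs.isEmpty) Nil
      else {
        // Evaluate f once per element and keep the result next to the element.
        val scored = xs.toList.map(a => (a, f(a)))
        val best = scored.map(_._2).max
        scored.collect { case (a, b) if ord.equiv(b, best) => a }
      }
  }
}
```

After `import ArgMaxSyntax._`, `(-2 to 2).argMax(x => x * x)` gives `List(-2, 2)`, and an empty collection gives `Nil` instead of throwing.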
Here's a way of doing so with the implicit builder pattern. It has the advantage over the previous solutions that it works with any Traversable, and returns a similar Traversable. Sadly, it's pretty imperative. If anyone wants to, it could probably be turned into a fairly ugly fold instead.
object RichTraversable {
implicit def traversable2RichTraversable[A](t: Traversable[A]) = new RichTraversable[A](t)
}
class RichTraversable[A](t: Traversable[A]) {
def argMax[That, C](g: A => C)(implicit bf : scala.collection.generic.CanBuildFrom[Traversable[A], A, That], ord:Ordering[C]): That = {
var minimum:C = null.asInstanceOf[C]
val repr = t.repr
val builder = bf(repr)
for(a<-t){
val test: C = g(a)
if(test == minimum || minimum == null){
builder += a
minimum = test
}else if (ord.gt(test, minimum)){
builder.clear
builder += a
minimum = test
}
}
builder.result
}
}
Set(-2, -1, 0, 1, 2).argMax(x=>x*x) == Set(-2, 2)
List(-2, -1, 0, 1, 2).argMax(x=>x*x) == List(-2, 2)
Here's a variant loosely based on @Daniel's accepted answer that also works for Sets.
def argMax[A, B: Ordering](input: GenIterable[A], f: A => B) : GenSet[A] = argMaxZip(input, f) map (_._1) toSet
def argMaxZip[A, B: Ordering](input: GenIterable[A], f: A => B): GenIterable[(A, B)] = {
if (input.isEmpty) Nil
else {
val fPairs = input map (x => (x, f(x)))
val maxF = fPairs.map(_._2).max
fPairs filter (_._2 == maxF)
}
}
One could also do a variant that produces (B, Iterable[A]), of course.
Based on other answers, you can pretty easily combine the strengths of each (minimal calls to f(), etc.). Here we have an implicit conversion for all Iterables (so they can just call .argmax() transparently), and a stand-alone method if for some reason that is preferred. ScalaTest tests to boot.
class Argmax[A](col: Iterable[A]) {
def argmax[B](f: A => B)(implicit ord: Ordering[B]): Iterable[A] = {
val mapped = col map f
val max = mapped max ord
(mapped zip col) filter (_._1 == max) map (_._2)
}
}
object MathOps {
implicit def addArgmax[A](col: Iterable[A]) = new Argmax(col)
def argmax[A, B](col: Iterable[A])(f: A => B)(implicit ord: Ordering[B]) = {
new Argmax(col) argmax f
}
}
class MathUtilsTests extends FunSuite {
import MathOps._
test("Can argmax with unique") {
assert((-10 to 0).argmax(_ * -1).toSet === Set(-10))
// or alternate calling syntax
assert(argmax(-10 to 0)(_ * -1).toSet === Set(-10))
}
test("Can argmax with multiple") {
assert((-10 to 10).argmax(math.pow(_, 2)).toSet === Set(-10, 10))
}
}