I wanted to memoize this:
def fib(n: Int): BigInt = if (n <= 1) 1 else fib(n-1) + fib(n-2)
println(fib(100)) // times out
So I wrote this, which surprisingly compiles and works (surprising because fib references itself in its own declaration):
import scala.collection.mutable
case class Memo[A,B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A) = cache getOrElseUpdate (x, f(x))
}
val fib: Memo[Int, BigInt] = Memo {
case 0 => 0
case 1 => 1
case n => fib(n-1) + fib(n-2)
}
println(fib(100)) // prints 100th fibonacci number instantly
But when I try to declare fib inside of a def, I get a compiler error:
def foo(n: Int) = {
val fib: Memo[Int, BigInt] = Memo {
case 0 => 0
case 1 => 1
case n => fib(n-1) + fib(n-2)
}
fib(n)
}
The above fails to compile with error: forward reference extends over definition of value fib
case n => fib(n-1) + fib(n-2)
Why does declaring the val fib inside a def fail, while declaring it outside, in class/object scope, works?
To clarify why I might want to declare the recursive memoized function in def scope, here is my solution to the subset sum problem:
/**
* Subset sum algorithm - can we achieve sum t using elements from s?
*
* @param s set of integers
* @param t target
* @return true iff there exists a subset of s that sums to t
*/
def subsetSum(s: Seq[Int], t: Int): Boolean = {
val max = s.scanLeft(0)((sum, i) => (sum + i) max sum) //max(i) = largest sum achievable from first i elements
val min = s.scanLeft(0)((sum, i) => (sum + i) min sum) //min(i) = smallest sum achievable from first i elements
val dp: Memo[(Int, Int), Boolean] = Memo { // dp(i,x) = can we achieve x using the first i elements?
case (_, 0) => true // 0 can always be achieved using empty set
case (0, _) => false // if empty set, non-zero cannot be achieved
case (i, x) if min(i) <= x && x <= max(i) => dp(i-1, x - s(i-1)) || dp(i-1, x) // try with/without s(i-1)
case _ => false // outside range otherwise
}
dp(s.length, t)
}
I found a better way to memoize using Scala:
def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {
override def apply(key: I) = getOrElseUpdate(key, f(key))
}
Now you can write fibonacci as follows:
lazy val fib: Int => BigInt = memoize {
case 0 => 0
case 1 => 1
case n => fib(n-1) + fib(n-2)
}
Here's one with multiple arguments (the choose function):
lazy val c: ((Int, Int)) => BigInt = memoize {
case (_, 0) => 1
case (n, r) if r > n/2 => c(n, n - r)
case (n, r) => c(n - 1, r - 1) + c(n - 1, r)
}
And here's the subset sum problem:
// is there a subset of s which has sum = t
def isSubsetSumAchievable(s: Vector[Int], t: Int) = {
// f is (i, j) => Boolean i.e. can the first i elements of s add up to j
lazy val f: ((Int, Int)) => Boolean = memoize {
case (_, 0) => true // 0 can always be achieved using empty list
case (0, _) => false // we can never achieve non-zero if we have empty list
case (i, j) =>
val k = i - 1 // try the kth element
f(k, j - s(k)) || f(k, j)
}
f(s.length, t)
}
EDIT: As discussed below, here is a thread-safe version
def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {self =>
override def apply(key: I) = self.synchronized(getOrElseUpdate(key, f(key)))
}
A class- or trait-level val compiles to a private field plus an accessor method, so a reference to fib inside its own initializer is just a method call, and the recursive definition is allowed.
Local vals, on the other hand, are plain local variables, so their definition cannot refer to themselves.
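A minimal sketch of what this means in practice, assuming the same Memo idea as in the question: declaring the local value as a lazy val lifts the forward-reference restriction, so the recursive memo can live inside a def. The names LocalMemoDemo, SimpleMemo and fibLocal are hypothetical and exist only for this illustration.

```scala
import scala.collection.mutable

object LocalMemoDemo {
  // Hypothetical stand-in for the question's Memo class.
  final class SimpleMemo[A, B](f: A => B) extends (A => B) {
    private val cache = mutable.Map.empty[A, B]
    def apply(x: A): B = cache.get(x) match {
      case Some(v) => v
      case None =>
        val v = f(x) // may recursively fill the cache for smaller keys
        cache(x) = v
        v
    }
  }

  def fibLocal(n: Int): BigInt = {
    // `lazy val` (not `val`) lets the initializer refer to fib itself.
    lazy val fib: SimpleMemo[Int, BigInt] = new SimpleMemo[Int, BigInt]({
      case 0 => BigInt(0)
      case 1 => BigInt(1)
      case k => fib(k - 1) + fib(k - 2)
    })
    fib(n)
  }
}
```

With a plain val in place of the lazy val, the same body is rejected with the forward-reference error quoted in the question.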
By the way, even if the def you defined worked, it wouldn't do what you expect. On every invocation of foo a new function object fib will be created and it will have its own backing map. What you should be doing instead is this (if you really want a def to be your public interface):
private val fib: Memo[Int, BigInt] = Memo {
case 0 => 0
case 1 => 1
case n => fib(n-1) + fib(n-2)
}
def foo(n: Int) = {
fib(n)
}
Scalaz has a solution for that, why not reuse it?
import scalaz.Memo
lazy val fib: Int => BigInt = Memo.mutableHashMapMemo {
case 0 => 0
case 1 => 1
case n => fib(n-2) + fib(n-1)
}
You can read more about memoization in Scalaz.
A mutable HashMap isn't thread-safe. Also, writing separate case clauses for the base conditions seems like unnecessary special handling; instead, a Map can be preloaded with the initial values and passed to the memoizer. The memoizer accepts a memo (an immutable Map) and a formula, and returns a recursive function. Its signature would look like this:
def memoize[I,O](memo: Map[I, O], formula: (I => O, I) => O): I => O
Now given a following Fibonacci formula,
def fib(f: Int => Int, n: Int) = f(n-1) + f(n-2)
fibonacci with Memoizer can be defined as
val fibonacci = memoize( Map(0 -> 0, 1 -> 1), fib)
where context agnostic general purpose Memoizer is defined as
def memoize[I, O](map: Map[I, O], formula: (I => O, I) => O): I => O = {
var memo = map
def recur(n: I): O = {
if( memo contains n) {
memo(n)
} else {
val result = formula(recur, n)
memo += (n -> result)
result
}
}
recur
}
Similarly, for factorial, a formula is
def fac(f: Int => Int, n: Int): Int = n * f(n-1)
and factorial with Memoizer is
val factorial = memoize( Map(0 -> 1, 1 -> 1), fac)
Inspiration: "Memoization", Chapter 4 of JavaScript: The Good Parts by Douglas Crockford
Related
I want to write a function that returns the n-fold composition of f applied to a parameter x, i.e. f(f(f ... f(x) ...)).
Here is my code:
def repeated(f: Int => Int, n: Int) = {
var tek: Int => Int = f
for (i <- 0 to n) {
tek = x => f(tek(x))
}
tek
}
I know this is not the right way to do this in Scala; I just want to find out what's happening behind the scenes.
Calling it like repeated(x => x + 1, 5)(1) results in a stack overflow.
What I have noticed in the debugger is that the line inside the for loop is executed after repeated has finished. It looks like lazy initialization; maybe the body of the for loop is a lambda passed by name?
In pure FP:
def repeated[A](f: A => A, n: Int): A => A =
(0 until n).foldLeft(identity[A] _)((ff, _) => ff.andThen(f))
(also works if n=0 - becomes identity)
Or, if you don't like iterating over a Range (which I think wouldn't be much less performant than alternatives), manual tail recursion:
import scala.annotation.tailrec
def repeated[A](f: A => A, n: Int): A => A = {
  @tailrec def aux(acc: A => A, n: Int): A => A = {
    if (n > 0) aux(acc.andThen(f), n - 1)
    else acc
  }
  aux(identity, n)
}
EDIT: there's also the Stream version, as @Karl Bielefeldt mentioned. It should be about as performant, but of course the best way to choose is to benchmark with your use case:
def repeated[A](f: A => A, n: Int): A => A =
Stream.iterate(identity[A] _)(_.andThen(f)).apply(n)
EDIT 2: if you have Cats:
def repeated[A](f: A => A, n: Int): A => A =
MonoidK[Endo].algebra[A].combineN(f, n)
Your x => f(tek(x)) is closing over the variable tek. Once the inner for-loop runs at least once, your tek becomes self-referential, because tek = x => f(tek(x)) calls itself, which causes unbounded recursion and StackOverflowError.
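To see the capture behaviour in isolation, here is a small sketch (ClosureCaptureDemo, offset and addOffset are made-up names): a Scala closure captures the variable itself rather than the value it held when the lambda was created, so a later assignment is visible inside the closure.

```scala
object ClosureCaptureDemo {
  // Calls the same closure before and after the captured var is reassigned.
  def run(): (Int, Int) = {
    var offset = 1
    val addOffset: Int => Int = x => x + offset // captures the variable, not the value 1
    val before = addOffset(10) // offset is 1 here
    offset = 100
    val after = addOffset(10) // the closure now sees offset == 100
    (before, after)
  }
}
```

In the question's loop the reassigned variable is the function itself, which is why the lambda ends up calling itself.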
If you wanted to do it with a for-loop, you could introduce a local immutable helper variable to break the recursion:
def repeated(f: Int => Int, n: Int) = {
var tek: Int => Int = identity
for (i <- 1 to n) {
val g = tek
tek = x => f(g(x))
}
tek
}
Note that your code applied f at least two times too many:
You didn't start with identity for n = 0
You iterated from 0 to n, that is, (n + 1) times.
A much simpler solution would have been:
def repeated[A](f: A => A, n: Int): A => A = { (a0: A) =>
var res: A = a0
for (i <- 1 to n) {
res = f(res)
}
res
}
Generally, the boilerplate of a while loop looks like this (r is the result I want, p is the predicate):
var r, p;
while(p()) {
(r, p) = compute(r)
}
I can convert it into a recursion to get rid of var:
def f(r: Int): Int = {
  val (nr, p) = compute(r)
  if (p()) f(nr) // keep going while p holds, as in the while loop
  else nr
}
Is there a built-in way to implement such logic? I know about Iterator.continually, but it still seems to need a var to store the side effect.
def compute(i: Int): (Int, () => Boolean) =
(i - 1) -> { () => i > 1 }
To create an immutable while you'll need iteration - a function that accepts state and returns new state of same type plus exit condition.
Iterator.continually
It's not the best solution - in my opinion this code is hard to read, but since you mentioned it:
val (r, p) = Iterator.continually(()).
scanLeft( 13 -> { () => true } ){
case ((r, p), _) => compute(r)
}.dropWhile{ case (r, p) => p() }.
next
// r: Int = 0
// p: () => Boolean = <function0>
You could use val (r, _) = since you don't need p.
If you want a solution with Iterator see this answer with Iterator.iterate.
Tail recursion
I guess this is an idiomatic solution. You could always rewrite your while loop as tail recursion with explicit state type:
@annotation.tailrec
def lastWhile[T](current: T)(f: T => (T, () => Boolean)): T = {
val (r, p) = f(current)
if (p()) lastWhile(r)(f)
else r
}
lastWhile(13){ compute }
// Int = 0
Scalaz unfold
In case you are using scalaz there is such method already. It produces a Stream, so you should get the last element.
At the end of an iteration you should produce an Option (None is an exit condition) with Pair of stream element (r) and next state (r, p()):
unfold(13 -> true) { case (r0, p0) =>
val (r, p) = compute(r0)
p0.option(r -> (r, p()))
}.last
// Int = 0
I don't know if this really answers the question of "built-in." I don't believe there's a solution that is simpler to implement or understand than your recursive routine. But here's another way of attacking the problem.
You can use Iterator.iterate to create an infinite iterator, and then find the first element that fails the predicate.
// Count until we are greater than 5
def compute(r: Int): (Int, () => Boolean) = {
(r + 1, () => (r < 5))
}
// Start at the beginning
val r = 1
val p = () => true
// Create an infinite iterator of computations
val it = Iterator.iterate((r, p))({
case (r, _) => compute(r)
})
// Find (optionally) the first element that fails p. Then get() the result.
val result = it.find({ case (_, p) => !p() })
.map { _._1 }
.get
I had a simple task: find the combination that occurs most often when we roll 4 six-sided dice and remove the one with the fewest points.
So, the question is: are there any Scala core classes to generate streams of Cartesian products? If not, how can it be implemented in the simplest and most effective way?
Here is the code and comparison with naive implementation in Scala:
object D extends App {
def dropLowest(a: List[Int]) = {
a diff List(a.min)
}
def cartesian(to: Int, times: Int): Stream[List[Int]] = {
def stream(x: List[Int]): Stream[List[Int]] = {
if (hasNext(x)) x #:: stream(next(x)) else Stream(x)
}
def hasNext(x: List[Int]) = x.exists(n => n < to)
def next(x: List[Int]) = {
def add(current: List[Int]): List[Int] = {
if (current.head == to) 1 :: add(current.tail) else current.head + 1 :: current.tail // here is a possible bug when we get maximal value, don't reuse this method
}
add(x.reverse).reverse
}
stream(Range(0, times).map(t => 1).toList)
}
def getResult(list: Stream[List[Int]]) = {
list.map(t => dropLowest(t).sum).groupBy(t => t).map(t => (t._1, t._2.size)).toMap
}
val list1 = cartesian(6, 4)
val list = for (i <- Range(1, 7); j <- Range(1,7); k <- Range(1, 7); l <- Range(1, 7)) yield List(i, j, k, l)
println(getResult(list1))
println(getResult(list.toStream) equals getResult(list1))
}
Thanks in advance
I think you can simplify your code by using flatMap :
val stream = (1 to 6).toStream
def cartesian(times: Int): Stream[Seq[Int]] = {
if (times == 0) {
Stream(Seq())
} else {
stream.flatMap { i => cartesian(times - 1).map(i +: _) }
}
}
Maybe a little bit more efficient (memory-wise) would be using Iterators instead:
val pool = (1 to 6)
def cartesian(times: Int): Iterator[Seq[Int]] = {
if (times == 0) {
Iterator(Seq())
} else {
pool.iterator.flatMap { i => cartesian(times - 1).map(i +: _) }
}
}
or even more concise by replacing the recursive calls by a fold :
def cartesian[A](list: Seq[Seq[A]]): Iterator[Seq[A]] =
list.foldLeft(Iterator(Seq[A]())) {
case (acc, l) => acc.flatMap(i => l.map(_ +: i))
}
and then:
cartesian(Seq.fill(4)(1 to 6)).map(dropLowest).toSeq.groupBy(i => i.sorted).mapValues(_.size).toSeq.sortBy(_._2).foreach(println)
(Note that you cannot use groupBy on Iterators, so Streams or even Lists are the way to go in any case; the above code is still valid since toSeq on an Iterator actually returns a lazy Stream.)
If you are interested in stats on the sums of the dice instead of the combinations, you can update the dropLowest function:
def dropLowest(l: Seq[Int]) = l.sum - l.min
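Regarding the note about groupBy on Iterators: if the goal is only to count occurrences, a single foldLeft over the Iterator avoids materializing the whole product. This is just a sketch; DiceCounts, countBy, dropLowestSum and sumCounts are hypothetical names, not part of the standard library.

```scala
object DiceCounts {
  // Lazily enumerate the Cartesian product of the given sequences.
  def cartesian[A](lists: Seq[Seq[A]]): Iterator[Seq[A]] =
    lists.foldLeft(Iterator(Seq.empty[A])) { (acc, l) =>
      acc.flatMap(prefix => l.map(prefix :+ _))
    }

  // Count occurrences per key in one pass, without groupBy.
  def countBy[A, K](items: Iterator[A])(key: A => K): Map[K, Int] =
    items.foldLeft(Map.empty[K, Int]) { (counts, item) =>
      val k = key(item)
      counts.updated(k, counts.getOrElse(k, 0) + 1)
    }

  // Sum of the dice after dropping the lowest one.
  def dropLowestSum(roll: Seq[Int]): Int = roll.sum - roll.min

  // How often each sum occurs over all rolls of `dice` dice with `sides` sides.
  def sumCounts(sides: Int, dice: Int): Map[Int, Int] =
    countBy(cartesian[Int](Seq.fill(dice)(1 to sides)))(dropLowestSum)
}
```

For the original problem, DiceCounts.sumCounts(6, 4).maxBy(_._2) would then give the most frequent sum together with its count.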
Is there any difference between this code:
for(term <- term_array) {
val list = hashmap.get(term)
...
}
and:
for(term <- term_array; val list = hashmap.get(term)) {
...
}
Inside the loop I'm changing the hashmap with something like this
hashmap.put(term, string :: list)
While checking for the head of list it seems to be outdated somehow when using the second code snippet.
The difference is that a value bound in the for-expression header (your second snippet) is a definition that the compiler threads through pattern matching, whereas a val in the loop body (your first snippet) is just a value inside a function literal. See Programming in Scala, Section 23.1 For Expressions:
for {
p <- persons // a generator
n = p.name // a definition
if (n startsWith "To") // a filter
} yield n
You see the real difference when you compile sources with scalac -Xprint:typer <filename>.scala:
object X {
val x1 = for (i <- (1 to 5); x = i*2) yield x
val x2 = for (i <- (1 to 5)) yield { val x = i*2; x }
}
After code transforming by the compiler you will get something like this:
private[this] val x1: scala.collection.immutable.IndexedSeq[Int] =
scala.this.Predef.intWrapper(1).to(5).map[(Int, Int), scala.collection.immutable.IndexedSeq[(Int, Int)]](((i: Int) => {
val x: Int = i.*(2);
scala.Tuple2.apply[Int, Int](i, x)
}))(immutable.this.IndexedSeq.canBuildFrom[(Int, Int)]).map[Int, scala.collection.immutable.IndexedSeq[Int]]((
(x$1: (Int, Int)) => (x$1: (Int, Int) @unchecked) match {
case (_1: Int, _2: Int)(Int, Int)((i @ _), (x @ _)) => x
}))(immutable.this.IndexedSeq.canBuildFrom[Int]);
private[this] val x2: scala.collection.immutable.IndexedSeq[Int] =
scala.this.Predef.intWrapper(1).to(5).map[Int, scala.collection.immutable.IndexedSeq[Int]](((i: Int) => {
val x: Int = i.*(2);
x
}))(immutable.this.IndexedSeq.canBuildFrom[Int]);
This can be simplified to:
val x1 = (1 to 5).map {i =>
val x: Int = i * 2
(i, x)
}.map {
case (i, x) => x
}
val x2 = (1 to 5).map {i =>
val x = i * 2
x
}
Defining variables in the for header makes sense if you want to use that variable in the for statement itself, like:
for (i <- is; a = something; if (a)) {
...
}
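And the reason why your list is outdated is that this translates to a map followed by a foreach, such as:
term_array.map { term =>
  val list = hashmap.get(term)
  (term, list)
}.foreach { case (term, list) =>
  ...
}
So by the time you reach ..., every list has already been fetched, and later iterations do not see the changes that earlier iterations made to the hashmap. The other example translates to: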
term_array.foreach {
term => val list= hashmap.get(term)
...
}
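The effect can be reproduced by writing the two translations out by hand. This is only a sketch under assumptions: ForDesugaringDemo and its method names are made up, getOrElse replaces get so that a missing key yields Nil, and the term array deliberately contains the same term twice.

```scala
import scala.collection.mutable

object ForDesugaringDemo {
  private val terms = Array("a", "a")

  // Lookup inside the body: each iteration sees earlier updates.
  def lookupInBody(): List[Int] = {
    val hashmap = mutable.Map.empty[String, List[String]]
    val sizes = mutable.ListBuffer.empty[Int]
    terms.foreach { term =>
      val list = hashmap.getOrElse(term, Nil)
      sizes += list.size
      hashmap.put(term, "x" :: list)
    }
    sizes.toList
  }

  // Lookup in a separate, earlier pass (map, then foreach):
  // every list is fetched before any update happens.
  def lookupUpFront(): List[Int] = {
    val hashmap = mutable.Map.empty[String, List[String]]
    val sizes = mutable.ListBuffer.empty[Int]
    terms.map { term => (term, hashmap.getOrElse(term, Nil)) }.foreach { case (term, list) =>
      sizes += list.size
      hashmap.put(term, "x" :: list)
    }
    sizes.toList
  }
}
```

lookupInBody() records list sizes 0 and 1, while lookupUpFront() records 0 and 0, because the second lookup happened before the first update.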
Each time a function is called, if its result for a given set of argument values is not yet memoized, I'd like to put the result into an in-memory table. One column is meant to store the result, the others to store the argument values.
How do I best implement this? Arguments are of diverse types, including some enums.
In C# I'd generally use DataTable. Is there an equivalent in Scala?
You could use a mutable.Map[TupleN[A1, A2, ..., AN], R] , or if memory is a concern, a WeakHashMap[1]. The definitions below (built on the memoization code from michid's blog) allow you to easily memoize functions with multiple arguments. For example:
import Memoize._
def reallySlowFn(i: Int, s: String): Int = {
Thread.sleep(3000)
i + s.length
}
val memoizedSlowFn = memoize(reallySlowFn _)
memoizedSlowFn(1, "abc") // returns 4 after about 3 seconds
memoizedSlowFn(1, "abc") // returns 4 almost instantly
Definitions:
/**
* A memoized unary function.
*
* @param f A unary function to memoize
* @tparam T the argument type
* @tparam R the return type
*/
class Memoize1[-T, +R](f: T => R) extends (T => R) {
import scala.collection.mutable
// map that stores (argument, result) pairs
private[this] val vals = mutable.Map.empty[T, R]
// Given an argument x,
// If vals contains x return vals(x).
// Otherwise, update vals so that vals(x) == f(x) and return f(x).
def apply(x: T): R = vals getOrElseUpdate (x, f(x))
}
object Memoize {
/**
* Memoize a unary (single-argument) function.
*
* @param f the unary function to memoize
*/
def memoize[T, R](f: T => R): (T => R) = new Memoize1(f)
/**
* Memoize a binary (two-argument) function.
*
* @param f the binary function to memoize
*
* This works by turning a function that takes two arguments of type
* T1 and T2 into a function that takes a single argument of type
* (T1, T2), memoizing that "tupled" function, then "untupling" the
* memoized function.
*/
def memoize[T1, T2, R](f: (T1, T2) => R): ((T1, T2) => R) =
Function.untupled(memoize(f.tupled))
/**
* Memoize a ternary (three-argument) function.
*
* @param f the ternary function to memoize
*/
def memoize[T1, T2, T3, R](f: (T1, T2, T3) => R): ((T1, T2, T3) => R) =
Function.untupled(memoize(f.tupled))
// ... more memoize methods for higher-arity functions ...
/**
* Fixed-point combinator (for memoizing recursive functions).
*/
def Y[T, R](f: (T => R) => T => R): (T => R) = {
lazy val yf: (T => R) = memoize(f(yf)(_))
yf
}
}
The fixed-point combinator (Memoize.Y) makes it possible to memoize recursive functions:
val fib: BigInt => BigInt = {
def fibRec(f: BigInt => BigInt)(n: BigInt): BigInt = {
if (n == 0) 1
else if (n == 1) 1
else (f(n-1) + f(n-2))
}
Memoize.Y(fibRec)
}
[1] WeakHashMap does not work well as a cache. See http://www.codeinstructions.com/2008/09/weakhashmap-is-not-cache-understanding.html and this related question.
The version suggested by anovstrup using a mutable Map is basically the same as in C#, and therefore easy to use.
But if you want, you can also use a more functional style. It uses immutable maps, which act as a kind of accumulator. Having tuples (instead of Int in the example) as keys works exactly as in the mutable case.
def fib(n:Int) = fibM(n, Map(0->1, 1->1))._1
def fibM(n:Int, m:Map[Int,Int]):(Int,Map[Int,Int]) = m.get(n) match {
case Some(f) => (f, m)
case None => val (f_1,m1) = fibM(n-1,m)
val (f_2,m2) = fibM(n-2,m1)
val f = f_1+f_2
(f, m2 + (n -> f))
}
Of course this is a little bit more complicated, but a useful technique to know (note that the code above aims for clarity, not for speed).
Being a newbie in this subject, I could not fully understand any of the examples given (but thanks anyway). Respectfully, I'd like to present my own solution in case someone comes here at the same level with the same problem. I think my code should be crystal clear to anybody with just very basic Scala knowledge.
def MyFunction(dt: DateTime, param: Int): Double = {
  val argsTuple = (dt, param)
  // return the cached result if present, otherwise compute and store it
  if (Memo.contains(argsTuple)) Memo(argsTuple) else Memoize(dt, param, MyRawFunction(dt, param))
}
def MyRawFunction(dt: DateTime, param: Int): Double = {
  1.0 // A heavy calculation/querying here
}
def Memoize(dt: DateTime, param: Int, result: Double): Double = {
  Memo += (dt, param) -> result
  result
}
val Memo = new scala.collection.mutable.HashMap[(DateTime, Int), Double]
Works perfectly. I'd appreciate critique if I've missed something.
When using a mutable map for memoization, keep in mind that it causes the typical concurrency problems, e.g. doing a get when a write has not completed yet. However, the thread-safe attempt at memoization below suggests that making it thread-safe is of little value, if any.
The following thread-safe code creates a memoized fibonacci function and starts a couple of threads (named 'a' through 'd') that call it. Run the code a few times (in the REPL) and you can easily see that f(2) set gets printed more than once. This means that thread A has started calculating f(2), but thread B has no idea of it and starts its own copy of the calculation. Such ignorance is pervasive while the cache is being built, because all threads see no sub-solution established yet and enter the else clause.
object ScalaMemoizationMultithread {
// do not use case class as there is a mutable member here
class Memo[-T, +R](f: T => R) extends (T => R) {
// don't even know what would happen if immutable.Map used in a multithreading context
private[this] val cache = new java.util.concurrent.ConcurrentHashMap[T, R]
def apply(x: T): R =
// no synchronized needed as there is no removal during memoization
if (cache containsKey x) {
Console.println(Thread.currentThread().getName() + ": f(" + x + ") get")
cache.get(x)
} else {
val res = f(x)
Console.println(Thread.currentThread().getName() + ": f(" + x + ") set")
cache.putIfAbsent(x, res) // atomic
res
}
}
object Memo {
def apply[T, R](f: T => R): T => R = new Memo(f)
def Y[T, R](F: (T => R) => T => R): T => R = {
lazy val yf: T => R = Memo(F(yf)(_))
yf
}
}
val fibonacci: Int => BigInt = {
def fiboF(f: Int => BigInt)(n: Int): BigInt = {
if (n <= 0) 1
else if (n == 1) 1
else f(n - 1) + f(n - 2)
}
Memo.Y(fiboF)
}
def main(args: Array[String]) = {
('a' to 'd').foreach(ch =>
new Thread(new Runnable() {
def run(): Unit = {
import scala.util.Random
val rand = new Random
(1 to 2).foreach(_ => {
Thread.currentThread().setName("Thread " + ch)
fibonacci(5)
})
}
}).start)
}
}
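One possible way to avoid the duplicated work observed above is to cache a lazily evaluated holder per key instead of the finished result, so that the thread that loses the putIfAbsent race waits for the winner's value rather than recomputing it. The sketch below is an assumption-laden illustration, not the code from this answer: OnceMemoDemo, Cell, OnceMemo and countEvaluations are hypothetical names, and it relies on Scala's lazy val initialization being thread-safe and on the recursion only ever depending on smaller keys (so waiting threads cannot form a cycle).

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

object OnceMemoDemo {
  // Holder whose value is computed at most once, on first access.
  final class Cell[R](compute: () => R) {
    lazy val value: R = compute()
  }

  // Hypothetical memoizer: `formula` receives the memoized function itself,
  // so recursive calls also go through the cache.
  final class OnceMemo[T, R](formula: (T => R, T) => R) extends (T => R) {
    private val cache = new ConcurrentHashMap[T, Cell[R]]()

    def apply(x: T): R = {
      val fresh = new Cell[R](() => formula(this, x))
      val prior = cache.putIfAbsent(x, fresh) // atomic: only one cell per key is stored
      (if (prior == null) fresh else prior).value
    }
  }

  // Calls fib(n) from several threads; returns the result and how many
  // times the formula body was actually evaluated.
  def countEvaluations(threadCount: Int, n: Int): (BigInt, Int) = {
    val evaluations = new AtomicInteger(0)
    val fib = new OnceMemo[Int, BigInt]((self, k) => {
      evaluations.incrementAndGet()
      if (k <= 1) BigInt(1) else self(k - 1) + self(k - 2)
    })
    val threads = (1 to threadCount).map { _ =>
      new Thread(new Runnable { def run(): Unit = { fib(n); () } })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    (fib(n), evaluations.get)
  }
}
```

The trade-off is that a thread may block while another thread finishes the same key, and every call allocates a small Cell even on a cache hit.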
In addition to Landei's answer, I also want to point out that the bottom-up (non-memoization) way of doing DP is possible in Scala, and the core idea is to use foldLeft(s).
Example for computing Fibonacci numbers
def fibo(n: Int) = (1 to n).foldLeft((0, 1)) {
(acc, i) => (acc._2, acc._1 + acc._2)
}._1
Example for longest increasing subsequence
def longestIncrSubseq[T](xs: List[T])(implicit ord: Ordering[T]) = {
xs.foldLeft(List[(Int, List[T])]()) {
(memo, x) =>
if (memo.isEmpty) List((1, List(x)))
else {
val resultIfEndsAtCurr = (memo, xs).zipped map {
(tp, y) =>
val len = tp._1
val seq = tp._2
if (ord.lteq(y, x)) { // current is greater than the previous end
(len + 1, x :: seq) // reversely recorded to avoid O(n)
} else {
(1, List(x)) // start over
}
}
memo :+ resultIfEndsAtCurr.maxBy(_._1)
}
}.maxBy(_._1)._2.reverse
}