Lazy sorts of iterators in Scala?

I've read that in Haskell, when sorting an iterator, it only evaluates as much of the qsort as necessary to return the number of values actually evaluated on the resulting iterator (i.e., it is lazy: once it has completed the LHS of the first pivot and can return one value, it can provide that one value on a call to "next" on the iterator and not continue pivoting unless next is called again).
For example, in Haskell, head (qsort list) is O(n) on average. It just finds the minimum value in the list, and doesn't sort the rest of the list unless the rest of the result of qsort list is accessed.
Is there a way to do this in Scala? I want to use sortWith on a collection but only sort as much as necessary, so that I can write mySeq.sortWith(_ < _).take(3) without having to complete the whole sort.
I'd like to know whether other sort functions (like sortBy) can be used lazily, how to ensure laziness, and where to find documentation about when sorts in Scala are or are not lazily evaluated.
UPDATE/EDIT: I'm ideally looking for ways to do this with standard sorting functions like sortWith. I'd rather not have to implement my own version of quicksort just to get lazy evaluation. Shouldn't this be built into the standard library, at least for collections like Stream that support laziness?

I've used Scala's priority queue implementation to solve this kind of partial sorting problem:
import scala.collection.mutable.PriorityQueue
val q = PriorityQueue(1289, 12, 123, 894, 1)(Ordering.Int.reverse)
Now we can call dequeue:
scala> q.dequeue
res0: Int = 1
scala> q.dequeue
res1: Int = 12
scala> q.dequeue
res2: Int = 123
It costs O(n) to build the queue and O(k log n) to take the first k elements.
Unfortunately PriorityQueue doesn't iterate in priority order, but it's not too hard to write an iterator that calls dequeue.
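Such an iterator can be sketched as follows. PriorityQueueIterator and lazySorted are hypothetical names, not library API. The sketch copies the input into a PriorityQueue whose Ordering is reversed (PriorityQueue dequeues the largest element first) and dequeues one element per call to next(), so taking the first k elements performs only k dequeues.

```scala
import scala.collection.mutable.PriorityQueue

object PriorityQueueIterator {
  // Hypothetical helper: yields the elements of xs in ascending order,
  // doing one dequeue per call to next().
  def lazySorted[A](xs: Iterable[A])(implicit ord: Ordering[A]): Iterator[A] = {
    // PriorityQueue is a max-heap, so reverse the ordering to get the smallest first
    val q = PriorityQueue.empty[A](ord.reverse)
    q ++= xs
    new Iterator[A] {
      def hasNext: Boolean = q.nonEmpty
      def next(): A = q.dequeue()
    }
  }
}
```

For example, PriorityQueueIterator.lazySorted(Seq(1289, 12, 123, 894, 1)).take(3).toList gives the same three values as the dequeue calls above.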

As an example, I created an implementation of lazy quicksort that creates a lazy tree structure (instead of producing a result list). This structure can be asked for any i-th element in O(n) time, or for a slice of k elements. Asking for the same element again (or a nearby element) takes only O(log n), as the tree structure built in the previous step is reused. Traversing all elements takes O(n log n) time. (All assuming that we've chosen reasonable pivots.)
The key is that subtrees are not built right away, they're delayed in a lazy computation. So when asking only for a single element, the root node is computed in O(n), then one of its sub-nodes in O(n/2) etc. until the required element is found, taking O(n + n/2 + n/4 ...) = O(n). When the tree is fully evaluated, picking any element takes O(log n) as with any balanced tree.
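Under the answer's assumption of reasonable pivots, where each step roughly halves the part of the data still to be searched, the cost of finding a single element is a geometric series:

```latex
T(n) = n + \frac{n}{2} + \frac{n}{4} + \cdots \le n \sum_{k=0}^{\infty} 2^{-k} = 2n = O(n)
```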
Note that the implementation of build is quite inefficient. I wanted it to be simple and as easy to understand as possible. The important thing is that it has the proper asymptotic bounds.
import collection.immutable.Traversable // Scala 2.12 collections; in 2.13 Traversable is only a deprecated alias for Iterable
object LazyQSort {
/**
* Represents a value that is evaluated at most once.
*/
final protected class Thunk[+A](init: => A) extends Function0[A] {
override lazy val apply: A = init;
}
implicit protected def toThunk[A](v: => A): Thunk[A] = new Thunk(v);
implicit protected def fromThunk[A](t: Thunk[A]): A = t.apply;
// -----------------------------------------------------------------
/**
* A lazy binary tree that keeps a list of sorted elements.
* Subtrees are created lazily using `Thunk`s, so only
* the necessary part of the whole tree is created for
* each operation.
*
* Most notably, accessing any i-th element using `apply`
* takes O(n) time and traversing all the elements
* takes O(n * log n) time.
*/
sealed abstract class Tree[+A]
extends Function1[Int,A] with Traversable[A]
{
override def apply(i: Int) = findNth(this, i);
override def head: A = apply(0);
override def last: A = apply(size - 1);
def max: A = last;
def min: A = head;
override def slice(from: Int, until: Int): Traversable[A] =
LazyQSort.slice(this, from, until);
// We could implement more Traversable's methods here ...
}
final protected case class Node[+A](
pivot: A, leftSize: Int, override val size: Int,
left: Thunk[Tree[A]], right: Thunk[Tree[A]]
) extends Tree[A]
{
override def foreach[U](f: A => U): Unit = {
left.foreach(f);
f(pivot);
right.foreach(f);
}
override def isEmpty: Boolean = false;
}
final protected case object Leaf extends Tree[Nothing] {
override def foreach[U](f: Nothing => U): Unit = {}
override def size: Int = 0;
override def isEmpty: Boolean = true;
}
// -----------------------------------------------------------------
/**
* Finds i-th element of the tree.
*/
@annotation.tailrec
protected def findNth[A](tree: Tree[A], n: Int): A =
tree match {
case Leaf => throw new ArrayIndexOutOfBoundsException(n);
case Node(pivot, lsize, _, l, r)
=> if (n == lsize) pivot
else if (n < lsize) findNth(l, n)
else findNth(r, n - lsize - 1);
}
/**
* Cuts a given subinterval from the data.
*/
def slice[A](tree: Tree[A], from: Int, until: Int): Traversable[A] =
tree match {
case Leaf => Leaf
case Node(pivot, lsize, size, l, r) => {
lazy val sl = slice(l, from, until);
lazy val sr = slice(r, from - lsize - 1, until - lsize - 1);
if ((until <= 0) || (from >= size)) Leaf // empty
else if (until <= lsize) sl
else if (from > lsize) sr
else sl ++ Seq(pivot) ++ sr
}
}
// -----------------------------------------------------------------
/**
* Builds a tree from a given sequence of data.
*/
def build[A](data: Seq[A])(implicit ord: Ordering[A]): Tree[A] =
if (data.isEmpty) Leaf
else {
// selecting a pivot is traditionally a complex matter,
// for simplicity we take the middle element here
val pivotIdx = data.size / 2;
val pivot = data(pivotIdx);
// this is far from perfect, but still linear
val (l, r) = data.patch(pivotIdx, Seq.empty, 1).partition(ord.lteq(_, pivot));
Node(pivot, l.size, data.size, { build(l) }, { build(r) });
}
}
// ###################################################################
/**
* Tests some operations and prints results to stdout.
*/
object LazyQSortTest extends App {
import util.Random
import LazyQSort._
def trace[A](name: String, comp: => A): A = {
val start = System.currentTimeMillis();
val r: A = comp;
val end = System.currentTimeMillis();
println("-- " + name + " took " + (end - start) + "ms");
return r;
}
{
val n = 1000000;
val rnd = Random.shuffle(0 until n);
val tree = build(rnd);
trace("1st element", println(tree.head));
// Second element is much faster since most of the required
// structure is already built
trace("2nd element", println(tree(1)));
trace("Last element", println(tree.last));
trace("Median element", println(tree(n / 2)));
trace("Median + 1 element", println(tree(n / 2 + 1)));
trace("Some slice", for(i <- tree.slice(n/2, n/2+30)) println(i));
trace("Traversing all elements", for(i <- tree) i);
trace("Traversing all elements again", for(i <- tree) i);
}
}
The output will be something like
0
-- 1st element took 268ms
1
-- 2nd element took 0ms
999999
-- Last element took 39ms
500000
-- Median element took 122ms
500001
-- Median + 1 element took 0ms
500000
...
500029
-- Some slice took 6ms
-- Traversing all elements took 7904ms
-- Traversing all elements again took 191ms
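For comparison, the same delayed-work idea can be sketched much more briefly with Scala 2.13's LazyList instead of a hand-built tree. This is a different, simpler technique than the answer's: it gives lazy in-order traversal but no reuse for random access. LazyQuickSort and lazyQuickSort are hypothetical names, and the sketch uses the first element as the pivot, so already-sorted input is its worst case.

```scala
object LazyQuickSort {
  // Lazy quicksort on LazyList (hypothetical helper, first-element pivot).
  // Building the result partitions the input and recurses eagerly into the
  // "smaller" side only; every "larger" side is deferred until the traversal
  // actually reaches it. Recursion depth on the smaller side is O(log n) for
  // random input but O(n) for sorted input.
  def lazyQuickSort[A](xs: List[A])(implicit ord: Ordering[A]): LazyList[A] = xs match {
    case Nil => LazyList.empty
    case pivot :: rest =>
      val (smaller, larger) = rest.partition(ord.lt(_, pivot))
      // #::: appends the deferred right-hand side lazily
      lazyQuickSort(smaller) #::: (pivot #:: lazyQuickSort(larger))
  }
}
```

With random input, reaching the first element costs O(n) on average, and lazyQuickSort(list).take(k).toList only partitions the deferred sublists it actually needs.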

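You could use a Stream to build something like that. Here is a simple example; it can definitely be made better, but it works as an illustration. It is essentially a lazy selection sort: each element is extracted in O(n), so taking the first k elements costs O(k * n).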
def extractMin(xs: List[Int]) = {
def extractMin(xs: List[Int], min: Int, rest: List[Int]): (Int,List[Int]) = xs match {
case Nil => (min, rest)
case head :: tail if head > min => extractMin(tail, min, head :: rest)
case head :: tail => extractMin(tail, head, min :: rest)
}
if(xs.isEmpty) throw new NoSuchElementException("List is empty")
else extractMin(xs.tail, xs.head, Nil)
}
def lazySort(xs: List[Int]): Stream[Int] = xs match {
case Nil => Stream.empty
case _ =>
val (min, rest) = extractMin(xs)
min #:: lazySort(rest)
}

Related

Migrate a Traversable that uses a visitor to an Iterable in Scala 2.13

The migration guide to Scala 2.13 explains that Traversable has been removed and that Iterable should be used instead. This change is particularly annoying for one project, which is using a visitor to implement the foreach method in the Node class of a tree:
case class Node(val subnodes: Seq[Node]) extends Traversable[Node] {
override def foreach[A](f: Node => A) = Visitor.visit(this, f)
}
object Visitor {
def visit[A](n: Node, f: Node => A): Unit = {
f(n)
for (sub <- n.subnodes) {
visit(sub, f)
}
}
}
object Main extends App {
val a = Node(Seq())
val b = Node(Seq())
val c = Node(Seq(a, b))
for (Node(subnodes) <- c) {
Console.println("Visiting a node with " + subnodes.length + " subnodes")
}
}
Output:
Visiting a node with 2 subnodes
Visiting a node with 0 subnodes
Visiting a node with 0 subnodes
An easy fix to migrate to Scala 2.13 is to first store the visited elements in a queue (remaining), which is then used to return an iterator:
import scala.collection.mutable
case class Node(val subnodes: Seq[Node]) extends Iterable[Node] {
override def iterator: Iterator[Node] = {
val remaining = mutable.Queue.empty[Node]
Visitor.visit(this, item => remaining.enqueue(item))
remaining.iterator
}
}
// Same Visitor object
// Same Main object
This solution has the disadvantage that it introduces new allocations that put pressure on the GC, because the number of visited elements is usually quite large.
Do you have suggestions on how to migrate from Traversable into Iterable, using the existing visitor but without introducing new allocations?
As you noticed, you need to extend Iterable instead of Traversable. You can do it like this:
case class Node(name: String, subnodes: Seq[Node]) extends Iterable[Node] {
override def iterator: Iterator[Node] = Iterator(this) ++ subnodes.iterator.flatMap(_.iterator) // lazy: no intermediate Seq of descendants
}
val a = Node("a", Seq())
val b = Node("b", Seq())
val c = Node("c", Seq(a, b))
val d = Node("d", Seq(c))
for (node @ Node(name, _) <- d) {
Console.println("Visiting node " + name + " with " + node.subnodes.length + " subnodes")
}
outputs:
Visiting node d with 1 subnodes
Visiting node c with 2 subnodes
Visiting node a with 0 subnodes
Visiting node b with 0 subnodes
Then you can do more operations such as:
d.count(_.subnodes.length > 1)
Code run at Scastie.
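A variant of the same pre-order traversal can keep an explicit stack of child iterators instead of composing iterators with ++ and flatMap, so the extra state is bounded by the depth of the tree rather than its size. It still does not reuse the visitor. Tree is a hypothetical stand-in for Node, written as a plain class with its own toString:

```scala
// Hypothetical tree type for this sketch.
final class Tree(val name: String, val children: Seq[Tree]) extends Iterable[Tree] {
  // Pre-order traversal driven by an explicit stack of child iterators;
  // the stack never holds more iterators than the tree is deep.
  override def iterator: Iterator[Tree] = new Iterator[Tree] {
    private var stack: List[Iterator[Tree]] = List(Iterator.single(Tree.this))

    def hasNext: Boolean = {
      // Drop exhausted iterators until one with remaining nodes is on top.
      while (stack.nonEmpty && !stack.head.hasNext) stack = stack.tail
      stack.nonEmpty
    }

    def next(): Tree = {
      if (!hasNext) throw new NoSuchElementException("no more nodes")
      val node = stack.head.next()
      // Visit this node's children before the rest of its siblings.
      if (node.children.nonEmpty) stack = node.children.iterator :: stack
      node
    }
  }

  override def toString: String = s"Tree($name)"
}
```

Built like the d tree above, it visits d, c, a and b in that order, and d.count(_.children.length > 1) is 1.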
Here is an example showing that your code can be implemented with LazyList, without the visitor:
case class Node(val subnodes: Seq[Node]) {
def recursiveMap[A](f: Node => A): LazyList[A] = {
def expand(node: Node): LazyList[Node] = node #:: LazyList.from(node.subnodes).flatMap(expand)
expand(this).map(f)
}
}
val a = Node(Seq())
val b = Node(Seq())
val c = Node(Seq(a, b))
val lazyList = c.recursiveMap { node =>
println("computing value")
"Visiting a node with " + node.subnodes.length + " subnodes"
}
println("started computing values")
lazyList.iterator.foreach(println)
output
started computing values
computing value
Visiting a node with 2 subnodes
computing value
Visiting a node with 0 subnodes
computing value
Visiting a node with 0 subnodes
If you don't store a reference to lazyList yourself and only keep the iterator, the JVM will be able to garbage-collect values as you go.
We ended up writing a minimal Traversable trait, implementing just the methods that are used in our codebase. This way there is no additional overhead and the visitor's logic doesn't need to be changed.
import scala.collection.mutable
/** A trait for traversable collections. */
trait Traversable[+A] {
self =>
/** Applies a function to all elements of the collection. */
def foreach[B](f: A => B): Unit
/** Creates a filter of this traversable collection. */
def withFilter(p: A => Boolean): Traversable[A] = new WithFilter(p)
class WithFilter(p: A => Boolean) extends Traversable[A] {
/** Applies a function to all filtered elements of the outer collection. */
def foreach[U](f: A => U): Unit = {
for (x <- self) {
if (p(x)) f(x)
}
}
/** Further refines the filter of this collection. */
override def withFilter(q: A => Boolean): WithFilter = {
new WithFilter(x => p(x) && q(x))
}
}
/** Finds the first element of this collection for which the given partial
* function is defined, and applies the partial function to it.
*/
def collectFirst[B](pf: PartialFunction[A, B]): Option[B] = {
for (x <- self) {
if (pf.isDefinedAt(x)) {
return Some(pf(x))
}
}
None
}
/** Builds a new collection by applying a partial function to all elements
* of this collection on which the function is defined.
*/
def collect[B](pf: PartialFunction[A, B]): Iterable[B] = {
val elements = mutable.Queue.empty[B]
for (x <- self) {
if (pf.isDefinedAt(x)) {
elements.append(pf(x))
}
}
elements
}
}

Is it a library bug in a functional language when a function with the same name but for different collections produces different side effects?

I'm using Scala 2.13.1 and evaluate my examples in a worksheet.
At first, I define two functions that return the range of a to (z-1) as a stream or respectively a lazy list.
def streamRange(a: Int, z: Int): Stream[Int] = {
print(a + " ")
if (a >= z) Stream.empty else a #:: streamRange(a + 1, z)
}
def lazyListRange(a: Int, z: Int): LazyList[Int] = {
print(a + " ")
if (a >= z) LazyList.empty else a #:: lazyListRange(a + 1, z)
}
Then I call both functions, take a Stream/LazyList of 3 elements and convert them to List:
streamRange(1, 10).take(3).toList // prints 1 2 3
lazyListRange(1, 10).take(3).toList // prints 1 2 3 4
Here I do the same again:
val stream1 = streamRange(1, 10) // prints 1
val stream2 = stream1.take(3)
stream2.toList // prints 2 3
val lazyList1 = lazyListRange(1,10) // prints 1
val lazyList2 = lazyList1.take(3)
lazyList2.toList // prints 2 3 4
The 1 is printed because the function is visited and the print statement is at the start. No surprise.
But I don't understand why the additional 4 is printed for the lazy list and not for the stream.
My assumption is that at the point where 3 is to be concatenated with the next function call, the LazyList version visits the function, whereas in the Stream version the function is not visited. Otherwise the 4 would not have been printed.
It seems like unintended behaviour; at least it is unexpected. But would this difference in side effects be considered a bug, or just a detail of how Stream and LazyList are evaluated differently?
Stream implements #:: using Deferrer:
implicit def toDeferrer[A](l: => Stream[A]): Deferrer[A] = new Deferrer[A](() => l)
final class Deferrer[A] private[Stream] (private val l: () => Stream[A]) extends AnyVal {
/** Construct a Stream consisting of a given first element followed by elements
* from another Stream.
*/
def #:: [B >: A](elem: B): Stream[B] = new Cons(elem, l())
/** Construct a Stream consisting of the concatenation of the given Stream and
* another Stream.
*/
def #:::[B >: A](prefix: Stream[B]): Stream[B] = prefix lazyAppendedAll l()
}
where Cons:
final class Cons[A](override val head: A, tl: => Stream[A]) extends Stream[A] {
Whereas LazyList implements #:: with its own Deferrer:
implicit def toDeferrer[A](l: => LazyList[A]): Deferrer[A] = new Deferrer[A](() => l)
final class Deferrer[A] private[LazyList] (private val l: () => LazyList[A]) extends AnyVal {
/** Construct a LazyList consisting of a given first element followed by elements
* from another LazyList.
*/
def #:: [B >: A](elem: => B): LazyList[B] = newLL(sCons(elem, l()))
/** Construct a LazyList consisting of the concatenation of the given LazyList and
* another LazyList.
*/
def #:::[B >: A](prefix: LazyList[B]): LazyList[B] = prefix lazyAppendedAll l()
}
where sCons:
@inline private def sCons[A](hd: A, tl: LazyList[A]): State[A] = new State.Cons[A](hd, tl)
and Cons:
final class Cons[A](val head: A, val tail: LazyList[A]) extends State[A]
It means that on the very definition level:
Stream lazily evaluates its tail's creation
LazyList lazily evaluates its tail's content
The difference is noticeable, among other things, in side effects, which neither of these is made for.
If you want to handle potentially infinite sequences of impure computations, use a proper streaming library: Akka Streams, FS2, ZIO Streams. Built-in streams/lazy lists are made for pure computations, and if you step into impure territory you should assume that no guarantees regarding side effects are provided.
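If the goal is only to observe when elements are demanded, a minimal sketch (assuming Scala 2.13's LazyList, whose #:: takes the element by name as quoted above) moves the side effect from the function body into the element expression. ElementEffects and produce are hypothetical names, and the behaviour still relies on implementation details, in line with the caveat above.

```scala
import scala.collection.mutable.ListBuffer

object ElementEffects {
  // Records which elements have actually been evaluated.
  val log = ListBuffer.empty[Int]

  // Hypothetical helper: the question's side effect, attached to the element.
  private def produce(a: Int): Int = { log += a; a }

  // Same range as lazyListRange, but produce(a) is passed by name to #::,
  // so it only runs once that element is demanded.
  def lazyListRange(a: Int, z: Int): LazyList[Int] =
    if (a >= z) LazyList.empty else produce(a) #:: lazyListRange(a + 1, z)
}
```

Here lazyListRange(1, 10).take(3).toList records only 1, 2 and 3: lazyListRange(4, 10) may still be called when the third cell is forced, but its element is never demanded.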

Scala Probabilistic Priority Queue - dequeue with probability by priority

I have a priority queue, which holds several tasks, each task with a numeric non-unique priority, as follows:
import scala.collection.mutable
class Task(val name: String, val priority: Int) {
override def toString = s"Task(name=$name, priority=$priority)"
}
val task_a = new Task("a", 5)
val task_b = new Task("b", 1)
val task_c = new Task("c", 5)
val pq: mutable.PriorityQueue[Task] =
new mutable.PriorityQueue()(Ordering.by(_.priority))
pq.enqueue(task_a)
pq.enqueue(task_b)
pq.enqueue(task_c)
I want to get the next task:
pq.dequeue()
But this way, I'll always get back task a, even though there's also task c with the same priority.
How to get one of the items with the highest priority randomly? That is to get either task a or task c, with 50/50 chance.
How to get any of the items randomly, with probability according to priority? That is to get 45% task a, 10% task b, and 45% task c.
This should be a good starting point:
import scala.collection.immutable.SortedMap
abstract class ProbPriorityQueue[V] {
protected type K
protected implicit def ord: Ordering[K]
protected val impl: SortedMap[K, Set[V]]
protected val priority: V => K
def isEmpty: Boolean = impl.isEmpty
def dequeue: Option[(V, ProbPriorityQueue[V])] = {
if (isEmpty) {
None
} else {
// I wish Scala allowed us to collapse these operations...
val k = impl.lastKey
val s = impl(k)
val v = s.head
val s2 = s - v
val impl2 = if (s2.isEmpty)
impl - k
else
impl.updated(k, s2)
Some((v, ProbPriorityQueue.create(impl2, priority)))
}
}
}
object ProbPriorityQueue {
def apply[K: Ordering, V](vs: V*)(priority: V => K): ProbPriorityQueue[V] = {
val impl = vs.foldLeft(SortedMap[K, Set[V]]()) {
case (acc, v) =>
val k = priority(v)
acc get k map { s => acc.updated(k, s + v) } getOrElse (acc + (k -> Set(v)))
}
create(impl, priority)
}
private def create[K0, V](impl0: SortedMap[K0, Set[V]], priority0: V => K0)(implicit ord0: Ordering[K0]): ProbPriorityQueue[V] =
new ProbPriorityQueue[V] {
type K = K0
def ord = ord0
val impl = impl0
val priority = priority0
}
}
I didn't implement the select function, which would produce a value with weighted probability, but that shouldn't be too hard to do. In order to implement that function, you will need an additional mapping function (similar to priority) which has type K => Double, where Double is the probability weight attached to a particular key bucket. This makes everything somewhat messier, so it didn't seem worth bothering about.
Also, this seems like a remarkably specific set of requirements. You're either doing a very interesting bit of distributed scheduling, or homework.
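The two selections the question asks for can be sketched independently of the ProbPriorityQueue above. TaskPicker, pickTopUniform and pickWeighted are hypothetical names, Task repeats the question's class, and priorities are assumed to be positive Ints whose sum fits in an Int. Removing the chosen task from the queue is not covered.

```scala
import scala.util.Random

// Same shape as the question's Task.
final class Task(val name: String, val priority: Int) {
  override def toString = s"Task(name=$name, priority=$priority)"
}

object TaskPicker {
  // Picks uniformly among the tasks that share the highest priority.
  def pickTopUniform(tasks: Seq[Task], rnd: Random): Task = {
    require(tasks.nonEmpty, "no tasks")
    val top = tasks.map(_.priority).max
    val candidates = tasks.filter(_.priority == top)
    candidates(rnd.nextInt(candidates.size))
  }

  // Picks a task with probability priority / (sum of all priorities).
  def pickWeighted(tasks: Seq[Task], rnd: Random): Task = {
    require(tasks.nonEmpty && tasks.forall(_.priority > 0), "need positive priorities")
    val total = tasks.map(_.priority).sum
    var r = rnd.nextInt(total) // uniform in [0, total)
    val it = tasks.iterator
    var chosen = it.next()
    // Each task owns a slice of [0, total) as wide as its priority.
    while (r >= chosen.priority) {
      r -= chosen.priority
      chosen = it.next()
    }
    chosen
  }
}
```

With priorities 5, 1 and 5, pickWeighted returns a, b and c with probability 5/11, 1/11 and 5/11 (about 45%, 9% and 45%), and pickTopUniform returns a or c with probability 1/2 each.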

Why does Stream.foldLeft take two elements before operating?

foldLeft needs only one element from the collection before operating. So why does it try to resolve two of them? Couldn't it be just a little bit lazier?
def stream(i: Int): Stream[Int] =
if (i < 100) {
println("taking")
i #:: stream(i + 1)
} else Stream.empty
scala> stream(97).foldLeft(0) { case (acc, i) =>
println("using")
acc + i
}
taking
taking
using
taking
using
using
res0: Int = 294
I ask this because I have built a stream around a mutable priority queue, wherein the iteration of the fold can inject new members into the stream. It starts off with one value and during the first iteration injects more values. But those other values are never seen because the stream has already been resolved to empty in position 2 before the first iteration.
I can only explain why it's happening. Here is the source of Stream's #:: (Cons):
final class Cons[+A](hd: A, tl: => Stream[A]) extends Stream[A] {
override def isEmpty = false
override def head = hd
@volatile private[this] var tlVal: Stream[A] = _
@volatile private[this] var tlGen = tl _
def tailDefined: Boolean = tlGen eq null
override def tail: Stream[A] = {
if (!tailDefined)
synchronized {
if (!tailDefined) {
tlVal = tlGen()
tlGen = null
}
}
tlVal
}
}
So you can see that head is always calculated (isn't lazy). Here is foldLeft:
override final def foldLeft[B](z: B)(op: (B, A) => B): B = {
if (this.isEmpty) z
else tail.foldLeft(op(z, head))(op)
}
You can see that tail is called here, which means that "head of tail" (second element) becomes calculated automatically (as it requires your stream function to be called again to generate tail). So the better question isn't "why second" - the question is why Stream always calculates its first element. I don't know the answer, but believe that scala-library's implementation could be improved just by making head lazy inside Cons, so you could pass someLazyCalculation #:: stream(i + 1).
Note that either way your stream function will be called twice, but the second approach gives you a way to avoid the automatic calculation of the second head by providing a lazy value as the head. Something like this could work then (currently it doesn't):
def stream(i: Int): Stream[Int] =
if (i < 100) {
lazy val ii = {
println("taking")
i
}
ii #:: stream(i + 1)
} else Stream.empty
P.S. It's probably not a good idea to build an (eventually) immutable collection around a mutable one.
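Scala 2.13's LazyList takes the element of #:: by name, so the lazy head suggested above is available there. A minimal sketch, assuming Scala 2.13: the question's stream as a LazyList, with the side effect moved into the element via a hypothetical helper produce (LazyHeadFold is also a hypothetical name).

```scala
import scala.collection.mutable.ListBuffer

object LazyHeadFold {
  // Records when elements are produced ("taking") and consumed ("using").
  val log = ListBuffer.empty[String]

  // Hypothetical helper holding the question's side effect.
  private def produce(i: Int): Int = { log += "taking"; i }

  // produce(i) is passed by name to #::, so it runs only when the fold
  // actually reaches that element.
  def stream(i: Int): LazyList[Int] =
    if (i < 100) produce(i) #:: stream(i + 1)
    else LazyList.empty

  // The fold from the question, logging each consumed element.
  def sum(from: Int): Int = stream(from).foldLeft(0) { (acc, i) =>
    log += "using"
    acc + i
  }
}
```

With this version, LazyHeadFold.sum(97) is 294 and the log alternates "taking" and "using", instead of taking the second element before the first is used.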

What type to use to store an in-memory mutable data table in Scala?

Each time a function is called, if its result for a given set of argument values is not yet memoized, I'd like to put the result into an in-memory table. One column is meant to store a result, others to store argument values.
How do I best implement this? Arguments are of diverse types, including some enums.
In C# I'd generally use DataTable. Is there an equivalent in Scala?
You could use a mutable.Map[TupleN[A1, A2, ..., AN], R], or, if memory is a concern, a WeakHashMap [1]. The definitions below (built on the memoization code from michid's blog) allow you to easily memoize functions with multiple arguments. For example:
import Memoize._
def reallySlowFn(i: Int, s: String): Int = {
Thread.sleep(3000)
i + s.length
}
val memoizedSlowFn = memoize(reallySlowFn _)
memoizedSlowFn(1, "abc") // returns 4 after about 3 seconds
memoizedSlowFn(1, "abc") // returns 4 almost instantly
Definitions:
/**
* A memoized unary function.
*
* @param f A unary function to memoize
* @tparam T the argument type
* @tparam R the return type
*/
class Memoize1[-T, +R](f: T => R) extends (T => R) {
import scala.collection.mutable
// map that stores (argument, result) pairs
private[this] val vals = mutable.Map.empty[T, R]
// Given an argument x,
// If vals contains x return vals(x).
// Otherwise, update vals so that vals(x) == f(x) and return f(x).
def apply(x: T): R = vals getOrElseUpdate (x, f(x))
}
object Memoize {
/**
* Memoize a unary (single-argument) function.
*
* @param f the unary function to memoize
*/
def memoize[T, R](f: T => R): (T => R) = new Memoize1(f)
/**
* Memoize a binary (two-argument) function.
*
* #param f the binary function to memoize
*
* This works by turning a function that takes two arguments of type
* T1 and T2 into a function that takes a single argument of type
* (T1, T2), memoizing that "tupled" function, then "untupling" the
* memoized function.
*/
def memoize[T1, T2, R](f: (T1, T2) => R): ((T1, T2) => R) =
Function.untupled(memoize(f.tupled))
/**
* Memoize a ternary (three-argument) function.
*
* @param f the ternary function to memoize
*/
def memoize[T1, T2, T3, R](f: (T1, T2, T3) => R): ((T1, T2, T3) => R) =
Function.untupled(memoize(f.tupled))
// ... more memoize methods for higher-arity functions ...
/**
* Fixed-point combinator (for memoizing recursive functions).
*/
def Y[T, R](f: (T => R) => T => R): (T => R) = {
lazy val yf: (T => R) = memoize(f(yf)(_))
yf
}
}
The fixed-point combinator (Memoize.Y) makes it possible to memoize recursive functions:
val fib: BigInt => BigInt = {
def fibRec(f: BigInt => BigInt)(n: BigInt): BigInt = {
if (n == 0) 1
else if (n == 1) 1
else (f(n-1) + f(n-2))
}
Memoize.Y(fibRec)
}
[1] WeakHashMap does not work well as a cache. See http://www.codeinstructions.com/2008/09/weakhashmap-is-not-cache-understanding.html and this related question.
The version suggested by anovstrup using a mutable Map is basically the same as in C#, and therefore easy to use.
But if you want, you can also use a more functional style. It uses immutable maps, which act as a kind of accumulator. Using tuples (instead of Int in the example) as keys works exactly as in the mutable case.
def fib(n:Int) = fibM(n, Map(0->1, 1->1))._1
def fibM(n:Int, m:Map[Int,Int]):(Int,Map[Int,Int]) = m.get(n) match {
case Some(f) => (f, m)
case None => val (f_1,m1) = fibM(n-1,m)
val (f_2,m2) = fibM(n-2,m1)
val f = f_1+f_2
(f, m2 + (n -> f))
}
Of course this is a little bit more complicated, but a useful technique to know (note that the code above aims for clarity, not for speed).
Being a newbie in this subject, I could not fully understand any of the examples given (but would like to thank you anyway). Respectfully, I'd present my own solution in case someone at the same level with the same problem comes here. I think my code is crystal clear for anybody with just very basic Scala knowledge.
def MyFunction(dt : DateTime, param : Int) : Double =
{
val argsTuple = (dt, param)
if(Memo.contains(argsTuple)) Memo(argsTuple) else Memoize(dt, param, MyRawFunction(dt, param))
}
def MyRawFunction(dt : DateTime, param : Int) : Double =
{
1.0 // A heavy calculation/querying here
}
def Memoize(dt : DateTime, param : Int, result : Double) : Double =
{
Memo += ((dt, param) -> result)
result
}
val Memo = new scala.collection.mutable.HashMap[(DateTime, Int), Double]
Works perfectly. I'd appreciate critique if I've missed something.
When using a mutable map for memoization, keep in mind that it can cause typical concurrency problems, e.g. doing a get while a write has not yet completed. However, an attempt at thread-safe memoization suggests that doing so is of little value, if any.
The following thread-safe code creates a memoized Fibonacci function and starts a few threads (named 'a' through 'd') that call it. Run the code a couple of times (in the REPL) and you can easily see "f(2) set" printed more than once. This means that thread A has started calculating f(2) but thread B has no idea of it and starts its own copy of the calculation. Such ignorance is pervasive during the construction phase of the cache, because all threads see no established sub-solution and enter the else branch.
object ScalaMemoizationMultithread {
// do not use case class as there is a mutable member here
class Memo[-T, +R](f: T => R) extends (T => R) {
// don't even know what would happen if immutable.Map used in a multithreading context
private[this] val cache = new java.util.concurrent.ConcurrentHashMap[T, R]
def apply(x: T): R =
// no synchronized needed as there is no removal during memoization
if (cache containsKey x) {
Console.println(Thread.currentThread().getName() + ": f(" + x + ") get")
cache.get(x)
} else {
val res = f(x)
Console.println(Thread.currentThread().getName() + ": f(" + x + ") set")
cache.putIfAbsent(x, res) // atomic
res
}
}
object Memo {
def apply[T, R](f: T => R): T => R = new Memo(f)
def Y[T, R](F: (T => R) => T => R): T => R = {
lazy val yf: T => R = Memo(F(yf)(_))
yf
}
}
val fibonacci: Int => BigInt = {
def fiboF(f: Int => BigInt)(n: Int): BigInt = {
if (n <= 0) 1
else if (n == 1) 1
else f(n - 1) + f(n - 2)
}
Memo.Y(fiboF)
}
def main(args: Array[String]) = {
('a' to 'd').foreach(ch =>
new Thread(new Runnable() {
def run(): Unit = {
(1 to 2).foreach(_ => {
Thread.currentThread().setName("Thread " + ch)
fibonacci(5)
})
}
}).start())
}
}
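The duplicate work comes from computing f(x) before claiming the key. One way around it, sketched here as an alternative rather than as part of the answer above, is to store a lazily evaluated cell with putIfAbsent: all threads then share one cell per key, and the thread-safe initialisation of a lazy val makes the others wait instead of recomputing. OnceCell and OnceMemo are hypothetical names. The sketch assumes the recursion is well-founded (as with Fibonacci); keys that depend on each other cyclically could deadlock.

```scala
import java.util.concurrent.ConcurrentHashMap

// Hypothetical helper: runs `compute` at most once. Lazy val initialisation
// is thread-safe, so other threads asking for `value` wait for the result.
final class OnceCell[R](compute: => R) {
  lazy val value: R = compute
}

// Hypothetical variant of Memo that caches cells instead of results.
final class OnceMemo[T, R](f: T => R) extends (T => R) {
  private[this] val cache = new ConcurrentHashMap[T, OnceCell[R]]

  def apply(x: T): R = {
    val fresh = new OnceCell(f(x))             // f(x) is not evaluated here
    val existing = cache.putIfAbsent(x, fresh) // atomic: the first cell wins
    (if (existing eq null) fresh else existing).value
  }
}

object OnceMemo {
  // Same fixed-point trick as Memo.Y, for memoizing recursive functions.
  def Y[T, R](F: (T => R) => T => R): T => R = {
    lazy val yf: T => R = new OnceMemo[T, R](x => F(yf)(x))
    yf
  }
}
```

Used with the fiboF definition from the answer, each f(n) should then be computed exactly once, however many threads call the function.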
In addition to Landei's answer, I also want to point out that the bottom-up (non-memoized) way of doing DP is possible in Scala; the core idea is to use foldLeft(s).
Example for computing Fibonacci numbers
def fibo(n: Int) = (1 to n).foldLeft((0, 1)) {
(acc, i) => (acc._2, acc._1 + acc._2)
}._1
Example for longest increasing subsequence
def longestIncrSubseq[T](xs: List[T])(implicit ord: Ordering[T]) = {
xs.foldLeft(List[(Int, List[T])]()) {
(memo, x) =>
if (memo.isEmpty) List((1, List(x)))
else {
val resultIfEndsAtCurr = (memo, xs).zipped map {
(tp, y) =>
val len = tp._1
val seq = tp._2
if (ord.lteq(y, x)) { // current is not smaller than the previous end
(len + 1, x :: seq) // reversely recorded to avoid O(n)
} else {
(1, List(x)) // start over
}
}
memo :+ resultIfEndsAtCurr.maxBy(_._1)
}
}.maxBy(_._1)._2.reverse
}