Merge sort from "Programming Scala" causes stack overflow

A direct cut and paste of the following algorithm:
def msort[T](less: (T, T) => Boolean)
    (xs: List[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys
      case (_, Nil) => xs
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) x :: merge(xs1, ys)
        else y :: merge(xs, ys1)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    merge(msort(less)(ys), msort(less)(zs))
  }
}
causes a StackOverflowError on lists of 5000 elements.
Is there any way to optimize it so that this doesn't occur?

This happens because merge isn't tail-recursive: every element it merges adds another stack frame. You can fix it either by using a non-strict collection or by making merge tail-recursive.
The latter solution goes like this:
def msort[T](less: (T, T) => Boolean)
    (xs: List[T]): List[T] = {
  def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys.reverse ::: acc
      case (_, Nil) => xs.reverse ::: acc
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) merge(xs1, ys, x :: acc)
        else merge(xs, ys1, y :: acc)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    merge(msort(less)(ys), msort(less)(zs), Nil).reverse
  }
}
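As a quick check of the tail-recursive variant, here is a self-contained sketch that annotates merge with @tailrec (so compilation fails if the recursive call is not in tail position) and sorts a list much longer than 5000 elements. The object name MergeSortSketch and the input size are assumptions chosen for illustration, not part of the original answer.

```scala
import scala.annotation.tailrec

object MergeSortSketch {
  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // acc collects the merged elements in reverse order
    @tailrec
    def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
      (xs, ys) match {
        case (Nil, _) => ys.reverse ::: acc
        case (_, Nil) => xs.reverse ::: acc
        case (x :: xs1, y :: ys1) =>
          if (less(x, y)) merge(xs1, ys, x :: acc)
          else merge(xs, ys1, y :: acc)
      }
    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      merge(msort(less)(ys), msort(less)(zs), Nil).reverse
    }
  }
}

// Hypothetical input: 100000 pseudo-shuffled integers
val input = List.tabulate(100000)(i => (i * 7919) % 100003)
val sorted = MergeSortSketch.msort[Int](_ < _)(input)
```

msort itself still recurses, but only to a depth of roughly log2(n), so it is the merge step that needed the accumulator.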
Using non-strictness involves either passing parameters by-name, or using non-strict collections such as Stream. The following code uses Stream just to prevent stack overflow, and List elsewhere:
def msort[T](less: (T, T) => Boolean)
    (xs: List[T]): List[T] = {
  def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match {
    case (x :: xs, y :: ys) if less(x, y) => Stream.cons(x, merge(xs, right))
    case (x :: xs, y :: ys) => Stream.cons(y, merge(left, ys))
    case _ => if (left.isEmpty) right.toStream else left.toStream
  }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    merge(msort(less)(ys), msort(less)(zs)).toList
  }
}

Just playing around with Scala's TailCalls (trampolining support), which I suspect wasn't around when this question was originally posed. Here's a recursive immutable version of the merge in Rex's answer.
import scala.util.control.TailCalls._
def merge[T <% Ordered[T]](x:List[T],y:List[T]):List[T] = {
  def build(s:List[T],a:List[T],b:List[T]):TailRec[List[T]] = {
    if (a.isEmpty) {
      done(b.reverse ::: s)
    } else if (b.isEmpty) {
      done(a.reverse ::: s)
    } else if (a.head<b.head) {
      tailcall(build(a.head::s,a.tail,b))
    } else {
      tailcall(build(b.head::s,a,b.tail))
    }
  }
  build(List(),x,y).result.reverse
}
Runs just as fast as the mutable version on big List[Long]s on Scala 2.9.1 on 64bit OpenJDK (Debian/Squeeze amd64 on an i7).
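For readers unfamiliar with the done/tailcall pair used in that merge, the following minimal sketch shows the same trampolining mechanism on a pair of mutually recursive functions. The object name TrampolineSketch is an assumption for illustration; done, tailcall, TailRec and result are the members of scala.util.control.TailCalls.

```scala
import scala.util.control.TailCalls._

object TrampolineSketch {
  // Each recursive step is wrapped in tailcall(...), so it returns a
  // TailRec value to the trampoline instead of growing the call stack.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))
}

// .result runs the trampoline loop until a done(...) value is reached
val even = TrampolineSketch.isEven(100000).result
```

The trade-off is that every step allocates a small object on the heap, in exchange for constant stack depth.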

Just in case Daniel's solutions didn't make it clear enough, the problem is that merge's recursion is as deep as the length of the list, and it isn't tail-recursive, so it can't be converted into iteration.
Scala can convert Daniel's tail-recursive merge solution into something approximately equivalent to this:
def merge(xs: List[T], ys: List[T]): List[T] = {
  var acc:List[T] = Nil
  var decx = xs
  var decy = ys
  while (!decx.isEmpty || !decy.isEmpty) {
    (decx, decy) match {
      case (Nil, _) => { acc = decy.reverse ::: acc ; decy = Nil }
      case (_, Nil) => { acc = decx.reverse ::: acc ; decx = Nil }
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) { acc = x :: acc ; decx = xs1 }
        else { acc = y :: acc ; decy = ys1 }
    }
  }
  acc.reverse
}
but it keeps track of all the variables for you.
(A tail-recursive method is one that calls itself only as its very last step and returns the result of that call unchanged; it never calls itself and then does something with the result before returning. Also, the compiler can't apply this optimization if the method might be overridden in a subclass, so it generally only works for local methods, methods in objects, and methods that are final or private.)
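The distinction can be seen side by side in a small sketch. The names TailRecSketch, sum and sumNaive are assumptions for illustration; @tailrec is the standard annotation from scala.annotation, which makes compilation fail when the annotated method cannot be turned into a loop.

```scala
import scala.annotation.tailrec

object TailRecSketch {
  // Tail-recursive: the recursive call is the last thing done, so the
  // compiler rewrites it into a loop.
  @tailrec
  def sum(xs: List[Int], acc: Int = 0): Int = xs match {
    case Nil => acc
    case head :: tail => sum(tail, acc + head)
  }

  // Not tail-recursive: the addition happens after the recursive call
  // returns, so each element costs a stack frame. Annotating this
  // method with @tailrec would be a compile error.
  def sumNaive(xs: List[Int]): Int = xs match {
    case Nil => 0
    case head :: tail => head + sumNaive(tail)
  }
}

val total = TailRecSketch.sum(List.fill(1000000)(1))
```

Because sum lives in an object, it cannot be overridden, which is what lets the compiler apply the rewrite.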

Related

What to set as default when the type is A (Scala)

I have an exercise in Scala in which I have to transform a list like (a,a,a,b,c,d,d,e,a,a) into
((a,3),(b,1),(c,1),(d,2),(e,1),(a,2)).
I know that my algorithm is not correct yet, but I wanted to start with something.
The problem is that I don't know how to make the initial call to the helper (last line): whatever I pass as the previous argument, the compiler reports "required: A, found: Int" (or String, etc.).
previous is meant to hold the head from the previous iteration.
def compress[A](l: List[A]): List[(A, Int)] = {
  def compressHelper(l: List[A], acc: List[(A, Int)], previous: A, counter: Int): List[(A, Int)] = {
    l match {
      case head::tail => {
        if (head == previous) {
          compressHelper(tail, acc :+ (head, counter+1), head, counter+1)
        }
        else {
          compressHelper(tail, acc :+ (head, counter), head, 1)
        }
      }
      case Nil => acc
    }
  }
  compressHelper(l, List(), , 1)
}
You don't need to pass previous explicitly, just look at the accumulator:
def compress[A](l: List[A], acc: List[(A, Int)]=Nil): List[(A, Int)] =
  (l, acc) match {
    case (Nil, _) => acc.reverse
    case (head :: tail, (a, n) :: rest) if a == head =>
      compress(tail, (a, n+1) :: rest)
    case (head :: tail, _) => compress (tail, (head, 1) :: acc)
  }
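To see the accumulator-based approach run on the list from the question, here is a self-contained sketch. Wrapping it in an object named CompressSketch and adding @tailrec are assumptions made for illustration; the logic is the same as in the answer.

```scala
import scala.annotation.tailrec

object CompressSketch {
  // The head of acc is the run currently being counted, so no separate
  // "previous" parameter (and no default value of type A) is needed.
  @tailrec
  def compress[A](l: List[A], acc: List[(A, Int)] = Nil): List[(A, Int)] =
    (l, acc) match {
      case (Nil, _) => acc.reverse
      case (head :: tail, (a, n) :: rest) if a == head =>
        compress(tail, (a, n + 1) :: rest)
      case (head :: tail, _) =>
        compress(tail, (head, 1) :: acc)
    }
}

val runs = CompressSketch.compress(
  List("a", "a", "a", "b", "c", "d", "d", "e", "a", "a"))
// runs == List((a,3), (b,1), (c,1), (d,2), (e,1), (a,2))
```

Prepending to acc and reversing once at the end also avoids the repeated linear-time :+ appends in the question's version.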

Recursively handle nested lists in scala

I'm teaching myself Scala and trying to fatten my FP skills.
One of my references, Essentials of Programming Languages (available here), has a handy list of easy recursive functions. On page 27/50, we are asked to implement the swapper() function.
(swapper s1 s2 slist) returns a list the same as slist, but
with all occurrences of s1 replaced by s2 and all occurrences of s2 replaced by s1.
> (swapper 'a 'd '(a b c d))
(d b c a)
> (swapper 'a 'd '(a d () c d))
(d a () c a)
> (swapper 'x 'y '((x) y (z (x))))
((y) x (z (y)))
In scala, this is:
swapper("a", "d", List("a","b","c","d"))
swapper("a", "d", List("a","d",List(),"c","d"))
swapper("x", "y", List( List("x"), "y", List("z", List("x"))))
My Scala version handles all cases except the final x.
def swapper(a: Any, b: Any, lst: List[Any]): List[Any] ={
  def r(subList :List[Any], acc : List[Any]): List[Any] ={
    def swap (x :Any, xs: List[Any]) =
      if(x == a){
        r(xs, acc :+ b)
      } else if (x == b) {
        r(xs, acc :+ a)
      } else {
        r(xs, acc :+ x)
      }
    subList match {
      case Nil =>
        acc
      case List(x) :: xs =>
        r(xs, r(List(x), List()) +: acc)
      case x :: xs =>
        swap(x,xs)
      //case List(x) :: xs =>
    }
  }
  r(lst, List())
}
Instinctively, I think this is because I have no swap on the section "case List(x) :: xs" but I'm still struggling to fix it.
More difficult, still, this case breaks the tail-call optimization. How can I do this and where can I go to learn more about the general solution?
You can use this foldRight with pattern match approach:
def swapper(a:Any, b:Any, list:List[Any]):List[Any] =
  list.foldRight(List.empty[Any]) {
    case (item, acc) if item==a => b::acc
    case (item, acc) if item==b => a::acc
    case (item:List[Any], acc) => swapper(a, b, item)::acc
    case (item, acc) => item::acc
  }
or even simpler (thanks to @marcospereira):
def swapper(a:Any, b:Any, list:List[Any]):List[Any] =
  list.map {
    case item if item==a => b
    case item if item==b => a
    case item:List[Any] => swapper(a, b, item)
    case item => item
  }
A simpler way to solve this is to just use map (note that this version does not descend into nested lists):
def swapper[T](a: T, b: T, list: List[T]): List[T] = list.map { item =>
  if (item == a) b
  else if (item == b) a
  else item
}
This seems to work.
def swapper[T](a: T, b: T, lst: List[_]): List[_] = {
  val m = Map[T, T](a -> b, b -> a).withDefault(identity)
  def swap(arg: List[_]): List[_] = arg.map{
    case l: List[_] => swap(l)
    case x: T => m(x)
  }
  swap(lst)
}
The List elements are inconsistent, because each one might be a plain element or another List, so the type is List[Any], which is a sure sign that someone needs to rethink this data representation.
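One possible way to rethink the representation is a small algebraic data type in which an element is either an atom or a nested list. The sketch below is an assumption about what such a type could look like; the names SwapperSketch, SExp, Atom and SList are hypothetical and do not come from the book or the answers.

```scala
object SwapperSketch {
  // Hypothetical ADT: an element is either an atom or a nested list.
  sealed trait SExp
  final case class Atom(name: String) extends SExp
  final case class SList(items: List[SExp]) extends SExp

  // Backticks match against the values of the parameters a and b.
  def swapper(a: String, b: String, exp: SExp): SExp = exp match {
    case Atom(`a`)    => Atom(b)
    case Atom(`b`)    => Atom(a)
    case atom: Atom   => atom
    case SList(items) => SList(items.map(swapper(a, b, _)))
  }
}

import SwapperSketch._

// ((x) y (z (x))) from the exercise
val nested = SList(List(
  SList(List(Atom("x"))),
  Atom("y"),
  SList(List(Atom("z"), SList(List(Atom("x")))))))
val swapped = swapper("x", "y", nested)
```

With this encoding the compiler checks that every case is handled, and the recursion depth follows the nesting depth rather than the list length.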

scala parameterised merge sort - confusing error message

I am getting a compilation error when calling the (lt: (T,T) => Boolean) function.
The error message is "type mismatch; found : x.type (with underlying type T) required: T"
and the x parameter in lt(x,y) is underlined.
object sort {
  def msort[T](xs: List[T])(lt: (T, T) => Boolean): List[T] = {
    def merge[T](xs: List[T], ys: List[T]): List[T] = (xs, ys) match {
      case (Nil, ys) => ys
      case (xs, Nil) => xs
      case (x :: xs1, y :: ys1) => {
        if (lt(x, y)) x :: merge(xs1, ys)
        else y :: merge(xs, ys1)
      }
    }
    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (left, right) = xs splitAt n
      merge(msort(left)(lt), msort(right)(lt))
    }
  }
}
msort and merge have different type parameters. Remove the T type parameter from merge:
def merge(xs: List[T], ys: List[T]): List[T] = (xs, ys) match {
The [T] declares a new parameter unrelated to the first. You get the same error if you declare it as:
def merge[U](xs: List[U], ys: List[U]): List[U] = (xs, ys) match {
lt has type (T, T) => Boolean, and you're calling it with x and y, which have type U, so the types don't match.
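The shadowing can be reproduced in a much smaller sketch. The names ShadowingSketch, outer and firstTwoOrdered are hypothetical, chosen only to isolate the issue from the merge sort.

```scala
object ShadowingSketch {
  def outer[T](xs: List[T])(lt: (T, T) => Boolean): Boolean = {
    // No type parameter here: T refers to outer's T, so lt(x, y) type-checks.
    def firstTwoOrdered(ys: List[T]): Boolean = ys match {
      case x :: y :: _ => lt(x, y)
      case _           => true
    }
    // Declaring `def firstTwoOrdered[T](ys: List[T])` instead would
    // introduce a fresh T that shadows the outer one, and lt(x, y)
    // would fail with the "found: T, required: T" style of message.
    firstTwoOrdered(xs)
  }
}

val ordered = ShadowingSketch.outer(List(1, 2, 3))(_ < _)
```

Nested helper methods can usually just reuse the enclosing method's type parameters, which is why dropping [T] from merge is enough.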

Insert an element between each two adjacent elements of Seq

For example, I have Seq(1,2,3) and I want to get Seq(1,0,2,0,3)
The first thing that comes to mind is:
scala> Seq(1,2,3).flatMap(e => 0 :: e :: Nil).tail
res17: Seq[Int] = List(1, 0, 2, 0, 3)
Is there any better/more elegant option?
Here is another approach:
def intersperse[E](x: E, xs: Seq[E]): Seq[E] = (x, xs) match {
  case (_, Nil) => Nil
  case (_, Seq(x)) => Seq(x)
  // +: matches any Seq, not just List
  case (sep, y +: ys) => y +: sep +: intersperse(sep, ys)
}
which is safe over empty Seqs too.
Try a for comprehension:
for(i <- list; p <- List(0, i)) yield p
However you must somehow remove the first element (it yields: 0,1,0,2,0,3), either by:
(for(i <- list; p <- List(0, i)) yield p).tail
or:
list.head :: (for(i <- list.tail; p <- List(0, i)) yield p)
def intersperse[T](xs: List[T], item: T): List[T] = xs match {
  case Nil => xs
  case _ :: Nil => xs
  case a :: ys => a :: item :: intersperse(ys, item)
}
You can also use this extension:
implicit class SeqExtensions[A](val as: Seq[A]) extends AnyVal {
  def intersperse(a: A): Seq[A] = {
    val b = Seq.newBuilder[A]
    val it = as.iterator
    if (it.hasNext) {
      b += it.next()
      while(it.hasNext) {
        b += a
        b += it.next()
      }
    }
    b.result()
  }
}
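A compact variation on the flatMap idea from the question, sketched here under the assumption that a trailing separator is easier to drop than a leading one: dropRight(1) is safe on an empty Seq, whereas tail throws. The object name IntersperseSketch is hypothetical.

```scala
object IntersperseSketch {
  // Emit each element followed by the separator, then drop the last
  // separator. An empty input stays empty because dropRight never throws.
  def intersperse[A](xs: Seq[A], sep: A): Seq[A] =
    xs.flatMap(x => Seq(x, sep)).dropRight(1)
}

val spaced = IntersperseSketch.intersperse(Seq(1, 2, 3), 0)
// spaced == Seq(1, 0, 2, 0, 3)
```

This builds an intermediate collection of twice the length, so the builder-based extension is likely the better choice for large inputs.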

How to group messages by username?

A message class:
case class Message(username:String, content:String)
A message list:
val list = List(
Message("aaa", "111"),
Message("aaa","222"),
Message("bbb","333"),
Message("aaa", "444"),
Message("aaa", "555"))
How can I group the messages by username and get the following result?
List( "aaa"-> List(Message("aaa","111"), Message("aaa","222")),
      "bbb" -> List(Message("bbb","333")),
      "aaa" -> List(Message("aaa","444"), Message("aaa", "555")) )
That is, if a user posts several messages in a row, they should be grouped together until another user posts. The order should be kept.
I can't think of an easy way to do this with the provided Seq methods, but you can write your own pretty concisely with a fold:
def contGroupBy[A, B](s: List[A])(p: A => B) = (List.empty[(B, List[A])] /: s) {
  case (((k, xs) :: rest), y) if k == p(y) => (k, y :: xs) :: rest
  case (acc, y) => (p(y), y :: Nil) :: acc
}.reverse.map { case (k, xs) => (k, xs.reverse) }
Now contGroupBy(list)(_.username) gives you what you want.
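For reference, here is a self-contained sketch of the same fold written with foldLeft, since the /: operator is deprecated as of Scala 2.13. The wrapper object GroupSketch and the explicit return type are assumptions added for illustration; Message and the sample data follow the question.

```scala
object GroupSketch {
  final case class Message(username: String, content: String)

  def contGroupBy[A, B](s: List[A])(p: A => B): List[(B, List[A])] = {
    // Groups are built newest-first, with each group's items reversed too.
    val folded = s.foldLeft(List.empty[(B, List[A])]) {
      case ((k, xs) :: rest, y) if k == p(y) => (k, y :: xs) :: rest
      case (acc, y)                          => (p(y), y :: Nil) :: acc
    }
    // Restore the original order of groups and of items within each group.
    folded.reverse.map { case (k, xs) => (k, xs.reverse) }
  }
}

import GroupSketch._

val messages = List(
  Message("aaa", "111"), Message("aaa", "222"),
  Message("bbb", "333"),
  Message("aaa", "444"), Message("aaa", "555"))
val grouped = contGroupBy(messages)(_.username)
```

Unlike the standard groupBy, this keeps the two separate runs of "aaa" apart and preserves the input order.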
I tried to write code that works not only with Lists and can be written in operator notation. I came up with this:
object Grouper {
  import collection.generic.CanBuildFrom
  class GroupingCollection[A, C, CC[C]](ca: C)(implicit c2i: C => Iterable[A]) {
    def groupBySep[B](f: A => B)(implicit
      cbf: CanBuildFrom[C,(B, C),CC[(B,C)]],
      cbfi: CanBuildFrom[C,A,C]
    ): CC[(B, C)] =
      if (ca.isEmpty) cbf().result
      else {
        val iter = c2i(ca).iterator
        val outer = cbf()
        val inner = cbfi()
        val head = iter.next()
        var olda = f(head)
        inner += head
        for (a <- iter) {
          val fa = f(a)
          if (olda != fa) {
            outer += olda -> inner.result
            inner.clear()
          }
          inner += a
          olda = fa
        }
        outer += olda -> inner.result
        outer.result
      }
  }
  implicit def GroupingCollection[A, C[A]](ca: C[A])(
    implicit c2i: C[A] => Iterable[A]
  ): GroupingCollection[A, C[A], C] =
    new GroupingCollection[A, C[A],C](ca)(c2i)
}
Can be used (with Lists, Seqs, Arrays, ...) as:
list groupBySep (_.username)
def group(lst: List[Message], out: List[(String, List[Message])] = Nil)
  : List[(String, List[Message])] = lst match {
  case Nil => out.reverse
  case Message(u, c) :: xs =>
    val (same, rest) = lst span (_.username == u)
    group(rest, (u -> same) :: out)
}
Tail recursive version. Usage is simply group(list).
(List[Tuple2[String,List[Message]]]() /: list) {
  case (head :: tail, msg) if msg.username == head._1 =>
    (msg.username -> (msg :: head._2)) :: tail
  case (xs, msg) =>
    (msg.username -> List(msg)) :: xs
} map { t => t._1 -> t._2.reverse } reverse
Here's another method using pattern matching and recursion. It is not tail-recursive, since the result of the recursive call is prepended to, and it is probably not as efficient as those above due to the use of both takeWhile and dropWhile.
def groupBy(msgs: List[Message]): List[(String,List[Message])] = msgs match {
  case Nil => List()
  case head :: tail => (head.username ->
    (head :: tail.takeWhile(m => m.username == head.username))) +:
    groupBy(tail.dropWhile(m => m.username == head.username))
}