How to write early-return code in Scala without return or break?
For example, in pseudocode:
for i in 0..10000000
if expensive_operation(i)
return i
return -1
How about
input.find(expensiveOperation).getOrElse(-1)
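The following is a minimal sketch that makes the early exit visible. It assumes Scala 2.13, and expensiveOperation is a hypothetical cheap stand-in that counts its own calls. find on a Range stops at the first element that satisfies the predicate, so the predicate runs only a handful of times rather than ten million:

```scala
object FindEarlyExit {
  // Hypothetical stand-in for the question's expensive_operation;
  // it counts how often it is called so the early exit can be observed.
  var calls = 0
  def expensiveOperation(i: Int): Boolean = {
    calls += 1
    i * i > 50
  }

  // find stops at the first element for which the predicate holds.
  def firstMatch: Int =
    (0 until 10000000).find(expensiveOperation).getOrElse(-1)

  def main(args: Array[String]): Unit = {
    println(firstMatch) // 8, since 7 * 7 = 49 and 8 * 8 = 64
    println(calls)      // a handful of calls, not ten million
  }
}
```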
You can use dropWhile
Here is an example:
Seq(2, 6, 8, 3, 5).dropWhile(_ % 2 == 0).headOption.getOrElse(default = -1) // -> 3
You can find more examples in scala-takewhile-example.
With your example:
(0 to 10000000).dropWhile(!expensive_operation(_)).headOption.getOrElse(default = -1)
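dropWhile is also lazy on an Iterator, so the same pattern works for sources with no fixed end. This sketch assumes Scala 2.13 (for Iterator#nextOption), and expensiveOperation is a hypothetical predicate:

```scala
object DropWhileEarlyExit {
  // Hypothetical predicate standing in for expensive_operation.
  def expensiveOperation(i: Int): Boolean = i * i >= 1000

  // Iterator.from(0) is infinite, but dropWhile on an Iterator is lazy:
  // it only consumes elements until the predicate first fails.
  def firstMatch: Int =
    Iterator.from(0).dropWhile(i => !expensiveOperation(i)).next()

  // With a finite source, nextOption() yields None when nothing matches.
  def firstMatchBelow(n: Int): Int =
    Iterator.range(0, n).dropWhile(i => !expensiveOperation(i)).nextOption().getOrElse(-1)

  def main(args: Array[String]): Unit = {
    println(firstMatch)          // 32, since 31 * 31 = 961 and 32 * 32 = 1024
    println(firstMatchBelow(10)) // -1
  }
}
```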
Since you asked for an intuition for solving this problem generically, let me start from the basics.
Scala is (among other things) a functional programming language, and as such one concept is very important to us: we write programs by composing expressions rather than statements.
Thus, for us, a "return value" simply means the result of evaluating an expression.
(Note that this is related to the concept of referential transparency.)
val a = expr   // a is bound to the evaluation of expr,
val b = (a, a) // and they are interchangeable, thus b === (expr, expr)
How does this relate to your question? We don't really have control structures, only (possibly complex) expressions. For example, an if:
val a = if (expr) exprA else exprB // if itself is an expression, that returns other expressions.
Thus instead of doing something like this:
def foo(a: Int): Int = {
  if (a != 0) {
    val b = a * a
    return b
  }
  return -1
}
We would do something like:
def foo(a: Int): Int =
if (a != 0)
a * a
else
-1
This works because the whole if expression can be bound as the body of foo.
Now, back to your specific question: how can we return early from a loop?
The answer is that you can't, at least not without mutation. But you can use a higher-level concept: instead of iterating, you can traverse something, and you can do that using recursion.
So let's implement the find proposed by @Thilo ourselves, as a tail-recursive function.
(It is very important that the function is tail-recursive, so that the compiler turns it into the equivalent of a while loop; that way we will not blow up the stack.)
def find(start: Int, end: Int, step: Int = 1)(predicate: Int => Boolean): Option[Int] = {
  // Assumes a positive step; end is exclusive.
  @annotation.tailrec
  def loop(current: Int): Option[Int] =
    if (current >= end)
      None // Base case (>= so that a step larger than 1 cannot skip past end).
    else if (predicate(current))
      Some(current) // Early return.
    else
      loop(current + step) // Recursive step.
  loop(current = start)
}
find(0, 10000)(_ == 10)
// res: Option[Int] = Some(10)
Or we can generalize this a bit further and implement find for Lists of any element type.
def find[T](list: List[T])(predicate: T => Boolean): Option[T] = {
@annotation.tailrec
def loop(remaining: List[T]): Option[T] =
remaining match {
case Nil => None
case t :: _ if (predicate(t)) => Some(t)
case _ :: tail => loop(remaining = tail)
}
loop(remaining = list)
}
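The same tail-recursive pattern also carries over to a lazily produced source. The following is a minimal sketch, assuming Scala 2.13's LazyList (an illustrative extension, not part of the original answer); it can even search an infinite sequence:

```scala
import scala.annotation.tailrec

object LazyFind {
  // Same shape as the List version, but over a LazyList,
  // so elements are only produced as the search reaches them.
  def find[T](xs: LazyList[T])(predicate: T => Boolean): Option[T] = {
    @tailrec
    def loop(remaining: LazyList[T]): Option[T] =
      if (remaining.isEmpty) None                              // Base case.
      else if (predicate(remaining.head)) Some(remaining.head) // Early return.
      else loop(remaining.tail)                                // Recursive step.
    loop(xs)
  }

  def main(args: Array[String]): Unit = {
    println(find(LazyList.from(0))(_ > 41))      // Some(42)
    println(find(LazyList(1, 3, 5))(_ % 2 == 0)) // None
  }
}
```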
This is not necessarily the best solution from a practical perspective, but I still wanted to add it for educational purposes:
import scala.annotation.tailrec

def expensiveOperation(i: Int): Boolean = ??? // placeholder for the real predicate

@tailrec
def findFirstBy[T](f: T => Boolean)(xs: Seq[T]): Option[T] =
  xs match {
    case Seq() => None
    case Seq(head, _*) if f(head) => Some(head)
    case Seq(_, tail @ _*) => findFirstBy(f)(tail)
  }
val result = findFirstBy(expensiveOperation)(Range(0, 10000000)).getOrElse(-1)
Please prefer collection methods (dropWhile, find, ...) in your production code.
There are a lot of better answers here, but I think a while loop could work just fine in this situation.
So, this code
for i in 0..10000000
if expensive_operation(i)
return i
return -1
could be rewritten as
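var i = 0
var found = false
while (!found && i < 10000000) {
  found = expensive_operation(i)
  if (!found) i += 1
}
val result = if (found) i else -1
After the while loop, found tells whether the search succeeded, and result holds either the matching i or -1, just like the original code.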
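For this particular shape of problem (search 0 until n, return -1 when nothing matches), the standard library's indexWhere already follows the same -1 convention, so the loop and the flag can be dropped. The following is a sketch with a hypothetical stand-in predicate:

```scala
object IndexWhereSearch {
  // Hypothetical stand-in for expensive_operation.
  def expensiveOperation(i: Int): Boolean = i % 1000 == 999

  // For a Range starting at 0 with step 1, the index equals the value,
  // and indexWhere returns -1 when no element matches.
  def search(n: Int): Int = (0 until n).indexWhere(expensiveOperation)

  def main(args: Array[String]): Unit = {
    println(search(10000000)) // 999
    println(search(500))      // -1
  }
}
```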
Related
Suppose we have a Seq, val ourSeq = Seq(10,5,3,5,4).
I want to return a new list that reads from the left and stops when it sees a duplicate number (e.g. Seq(10,5,3), since 5 is repeated).
I was thinking of using foldLeft like this:
ourSeq.foldLeft(Seq())(op = (temp, curr) => {
if (!temp.contains(curr)) {
temp :+ curr
} else break
})
but as far as I understand, there is no way to break out of a foldLeft?
Although it can be accomplished with a foldLeft() without any breaking out, I would argue that fold is the wrong tool for the job.
I'm rather fond of unfold(), which was introduced in Scala 2.13.0.
val ourSeq = Seq(10,5,3,5,4)
Seq.unfold((Set.empty[Int],ourSeq)){ case (seen,ns) =>
Option.when(ns.nonEmpty && !seen(ns.head)) {
(ns.head, (seen+ns.head, ns.tail))
}
}
//res0: Seq[Int] = Seq(10, 5, 3)
You are correct that it's not possible to break out of foldLeft. It would theoretically be possible to get the correct result with foldLeft, but you're still going to iterate the whole data structure. It'll be better to use an algorithm that already understands how to terminate early, and since you want to take a prefix, takeWhile will suffice.
import scala.collection.mutable.Set
val ourSeq = Seq(10, 5, 3, 5, 4)
val seen: Set[Int] = Set()
val untilDups = ourSeq.takeWhile((x) => {
if (seen contains x) {
false
} else {
seen += x
true
}
})
print(untilDups)
If you wanted to be totally immutable about this, you could wrap the whole thing in some kind of lazy fold that uses an immutable Set to keep its data. And that's certainly how I'd do it in Haskell. But this is Scala; we have mutability, and we may as well use it locally when it suits us.
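Here are two further sketches, assuming Scala 2.13. The first is a compact form of the takeWhile answer that relies on mutable.Set#add returning false for an element that is already present. The second is the lazy, fully immutable variant mentioned above, built here from Iterator#scanLeft rather than a fold:

```scala
import scala.collection.mutable

object UniquePrefix {
  // mutable.Set#add returns true only if the element was not already
  // present, so it can serve directly as the takeWhile predicate.
  // The mutation stays local to this method.
  def withMutableSet[T](xs: Seq[T]): Seq[T] = {
    val seen = mutable.Set.empty[T]
    xs.takeWhile(seen.add)
  }

  // Pair each element with the set of elements seen *before* it, and stop
  // at the first element already in that set. Iterator operations are
  // lazy, so traversal ends at the first duplicate.
  def withImmutableSet[T](xs: Seq[T]): List[T] =
    xs.iterator
      .zip(xs.iterator.scanLeft(Set.empty[T])(_ + _))
      .takeWhile { case (x, seenBefore) => !seenBefore(x) }
      .map(_._1)
      .toList

  def main(args: Array[String]): Unit = {
    val ourSeq = Seq(10, 5, 3, 5, 4)
    println(withMutableSet(ourSeq))   // List(10, 5, 3)
    println(withImmutableSet(ourSeq)) // List(10, 5, 3)
  }
}
```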
This can be done using a recursive function:
def uniquePrefix[T](ourSeq: Seq[T]): List[T] = {
@annotation.tailrec
def loop(rem: List[T], res: List[T]): List[T] =
rem match {
case hd::tail if !res.contains(hd) =>
loop(tail, res :+ hd)
case _ =>
res
}
loop(ourSeq.toList, Nil)
}
This appears more complicated, but once you are familiar with the general pattern, recursive functions are simple to write and more powerful than fold operations.
If you are working on large collections, this version is more efficient because it is O(n):
def distinctPrefix[T](ourSeq: Seq[T]): List[T] = {
@annotation.tailrec
def loop(rem: List[T], found: Set[T], res: List[T]): List[T] =
rem match {
case hd::tail if !found.contains(hd) =>
loop(tail, found + hd, hd +: res)
case _ =>
res.reverse
}
loop(ourSeq.toList, Set.empty, Nil)
}
The version below works with any Seq; there are other options using Iterator etc., as described in the comments. You would need to be more specific about the type of the collection to write an optimised algorithm.
def uniquePrefix[T](ourSeq: Seq[T]): List[T] = {
@annotation.tailrec
def loop(rem: Seq[T], res: List[T]): List[T] =
rem.take(1) match {
case Seq(hd) if !res.contains(hd) =>
loop(rem.drop(1), res :+ hd)
case _ =>
res
}
loop(ourSeq, Nil)
}
Another option is to use inits:
ourSeq.inits.dropWhile(curr => curr.distinct.size != curr.size).next()
The aim of the method is to take elements from a list until the running sum of the taken elements would exceed a limit.
I've come up with two different implementations:
def take(l: List[Int], limit: Int): List[Int] = {
var sum = 0
l.takeWhile { e =>
sum += e
sum <= limit
}
}
It is straightforward, but it uses mutable state.
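def take(l: List[Int], limit: Int): List[Int] = {
  // Running sums: 0, e1, e1 + e2, ...
  val summed = l.toStream.scanLeft(0)(_ + _)
  val firstOver = summed.indexWhere(_ > limit)
  // If the limit is never exceeded, keep the whole list.
  if (firstOver < 0) l else l.take(firstOver - 1)
}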
It seems cleaner, but it's more verbose and perhaps less memory efficient because a stream is needed.
Is there a better way?
You could also do that in a single pass with a fold:
def take(l: List[Int], limit: Int): List[Int] =
  l.foldLeft((List.empty[Int], 0, false)) { case ((res, acc, done), next) =>
    if (done || acc + next > limit)
      (res, acc, true) // limit reached: skip everything after this point
    else
      (next :: res, acc + next, false)
  }._1.reverse
Because the standard lists aren't lazy, and neither is fold, this will always traverse the entire list. One alternative would be to use cats' iteratorFoldM instead for an implementation that short circuits once the limit is reached.
You could also write the short-circuiting fold directly using tail recursion, something along these lines:
def take(l: List[Int], limit: Int): List[Int] = {
  @annotation.tailrec
  def take0(list: List[Int], accList: List[Int], accSum: Int): List[Int] =
    list match {
      case h :: t if accSum + h <= limit =>
        take0(t, h :: accList, h + accSum)
      case _ => accList
    }
  take0(l, Nil, 0).reverse
}
Note that this second solution might be faster, but also less elegant as it requires additional effort to prove that the implementation terminates, something obvious when using a fold.
The first way is perfectly fine as the result of your function is still perfectly immutable.
On a side note, this is actually how many functions of the Scala collection library are implemented: they create a mutable builder for efficiency and return an immutable collection built from it.
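To illustrate that builder pattern applied to this problem, here is a minimal sketch (it is not taken from the standard library source). All mutation is confined to the method, and callers only ever see an immutable List:

```scala
object BuilderTake {
  // Locally mutable, externally immutable: the builder and running sum are
  // never visible to callers, and result() hands back an immutable List.
  def take(l: List[Int], limit: Int): List[Int] = {
    val builder = List.newBuilder[Int]
    var sum = 0
    var rest = l
    // Stop at the first element that would push the running sum past limit.
    while (rest.nonEmpty && sum + rest.head <= limit) {
      sum += rest.head
      builder += rest.head
      rest = rest.tail
    }
    builder.result()
  }

  def main(args: Array[String]): Unit =
    println(take(List(1, 2, 3, 4), 6)) // List(1, 2, 3)
}
```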
A functional way is to use recursive function and make sure it is stack safe.
Using only plain Scala:
import scala.annotation.tailrec
def take(l: List[Int], limit: Int) : List[Int] = {
@tailrec
def takeHelper(l:List[Int], limit:Int, r:List[Int]):List[Int] =
l match {
case h::t if (h <= limit ) => takeHelper(t, limit-h, r:+h)
case _ => r
}
takeHelper(l, limit, Nil)
}
If you can use scalaz Trampoline, it is a bit nicer:
import scalaz._
import scalaz.Scalaz._
import Free._
def take(l: List[Int], limit: Int): Trampoline[List[Int]] = {
l match {
case h :: t if (h <= limit) => suspend(take(t, limit - h)).map(h :: _)
case _ => return_(Nil)
}
}
println(take(List(1, 2, 3, 4, 0, 0, 1), 10).run)
println(take(List.fill(10000)(1), 100000000).run)
If you want to define your own custom control structure, you could also use something like this:
def customTake(con: => Boolean)(i: Int)(a: => List[Int])(body: => Unit): List[Int] = {
  if (con) {
    body // side effect, expected to eventually make con false
    customTake(con)(i + 1)(a)(body)
  }
  else {
    a.slice(0, i)
  }
}
then call it like this:
var j = 100
val t = customTake(j > 80)(0)((0 to 99).toList) {
j -= 1
}
println(t)
I think your second version is already pretty good. You might tweak it a little, like this:
val sums = l.toStream.scanLeft(0){_ + _} drop 1
l zip sums takeWhile {_._2 <= limit} map (_._1)
This way you aren't dealing with indices, which is usually a little easier to follow.
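Because Stream is deprecated as of Scala 2.13, the same zip-with-running-sums idea can be written with Iterator. Iterator is also lazy, so the traversal stops at the first sum over the limit. This is a sketch, assuming Scala 2.13:

```scala
object ZipScanTake {
  // Running sums are paired with the elements that produced them,
  // and takeWhile stops at the first sum that exceeds the limit.
  def take(l: List[Int], limit: Int): List[Int] = {
    val sums = l.iterator.scanLeft(0)(_ + _).drop(1) // e1, e1 + e2, ...
    l.iterator
      .zip(sums)
      .takeWhile { case (_, sum) => sum <= limit }
      .map { case (e, _) => e }
      .toList
  }

  def main(args: Array[String]): Unit =
    println(take(List(1, 2, 3, 4), 6)) // List(1, 2, 3)
}
```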
I know of at least two styles for writing tail-recursive functions. Take a sum function, for example:
def sum1(xs: List[Int]): Int = {
def loop(xs: List[Int], acc: Int): Int = xs match {
case Nil => acc
case x :: xs1 => loop(xs1, acc + x)
}
loop(xs, 0)
}
vs
def sum2(xs: List[Int], acc: Int = 0): Int = xs match {
case Nil => acc
case x :: xs1 => sum2(xs1, x + acc)
}
I've noticed the first style (internal loop function) much more commonly than the second. Is there any reason to prefer it or is the difference just a matter of style?
There are a couple of reasons to prefer the first notation.
First, it makes clear to your reader which part is the internal implementation and which is the external interface.
Second, in your example the seed value is simple enough to put straight into a default argument, but a seed value may be an object that is complicated to compute and needs more initialization than a default argument can comfortably express. If that initialization had to be done asynchronously, for example, you would definitely want to keep it out of the default value and manage it with Futures or similar.
Lastly, as Didier mentioned, the type of sum1 is a function from List[Int] to Int (which makes sense), while the type of sum2 is a function from (List[Int], Int) to Int, which is less meaningful. This also makes sum1 easier to pass around than sum2. For example, if you have an object that encapsulates a list of Ints and you want to register synthesizer functions (functions that compute a result from that list) over it, you can do:
class MyFancyList[T](val seed: List[T]) {
  type SyntFunction = List[T] => Any
  // Registered functions, kept in insertion order.
  private var functions = List.empty[SyntFunction]
  def addFunction(f: SyntFunction): Unit = functions = functions :+ f
  // Applies every registered function to the seed list.
  def computeAll: List[Any] = functions.map(f => f(seed))
}
And you can do:
def concatStrings(list:List[Int]) = {
val listOfStrings = for {
n <- list
}
yield {
n+""
}
listOfStrings.mkString
}
val x = new MyFancyList(List(1, 2, 3))
x.addFunction(sum1)
x.addFunction(concatStrings)
x.computeAll == List(6, "123")
but you can't add sum2 (not as easily at least)
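The following small sketch illustrates those last points, assuming Scala 2.13 (SumStyles, f1, f2 and leaked are illustrative names). sum1 converts directly to a List[Int] => Int, sum2 needs an explicit lambda, and sum2's accumulator is exposed to callers:

```scala
object SumStyles {
  def sum1(xs: List[Int]): Int = {
    @annotation.tailrec
    def loop(xs: List[Int], acc: Int): Int = xs match {
      case Nil      => acc
      case x :: xs1 => loop(xs1, acc + x)
    }
    loop(xs, 0)
  }

  @annotation.tailrec
  def sum2(xs: List[Int], acc: Int = 0): Int = xs match {
    case Nil      => acc
    case x :: xs1 => sum2(xs1, x + acc)
  }

  // sum1 eta-expands directly to List[Int] => Int.
  val f1: List[Int] => Int = sum1
  // sum2 takes two parameters, so an explicit lambda is needed to use the default.
  val f2: List[Int] => Int = xs => sum2(xs)

  // The accumulator is part of sum2's public signature,
  // so a caller can (accidentally) pass a non-zero seed.
  val leaked: Int = sum2(List(1, 2, 3), 10) // 16
}
```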
I'm trying to write a Scala function that recursively sums the values in a list. Here is what I have so far:
def sum(xs: List[Int]): Int = {
val num = List(xs.head)
if(!xs.isEmpty) {
sum(xs.tail)
}
0
}
I don't know how to sum the individual Int values as part of the function. I am considering defining a new function inside sum and using a local variable that accumulates the values as the List is iterated over, but this seems like an imperative approach. Is there an alternative?
Also you can avoid using recursion directly and use some basic abstractions instead:
val l = List(1, 3, 5, 11, -1, -3, -5)
l.foldLeft(0)(_ + _) // same as l.foldLeft(0)((a,b) => a + b)
foldLeft is like reduce() in Python. There is also foldRight, which is also known as accumulate (e.g. in SICP).
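To make the comparison concrete: Python's functools.reduce behaves like foldLeft when it is given an initial value, and like Scala's reduce when it is not. The following sketch shows the differences, including what happens on an empty list:

```scala
object FoldVsReduce {
  val l = List(1, 3, 5, 11, -1, -3, -5)

  // foldLeft takes an explicit start value, so it is defined for empty lists.
  val total: Int = l.foldLeft(0)(_ + _)                    // 11
  val emptyTotal: Int = List.empty[Int].foldLeft(0)(_ + _) // 0

  // reduce uses the first element as the start value and throws on an
  // empty list; reduceOption returns None instead.
  val reduced: Int = l.reduce(_ + _)                                  // 11
  val emptyReduced: Option[Int] = List.empty[Int].reduceOption(_ + _) // None

  // Associativity matters for non-commutative operations:
  val leftAssoc: Int = List(1, 3, 5).foldLeft(0)(_ - _)   // ((0 - 1) - 3) - 5 = -9
  val rightAssoc: Int = List(1, 3, 5).foldRight(0)(_ - _) // 1 - (3 - (5 - 0)) = 3
}
```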
With recursion I often find it worthwhile to think about how you'd describe the process in English, as that often translates to code without too much complication. So...
"How do I calculate the sum of a list of integers recursively?"
"Well, what's the sum of a list, 3 :: restOfList?
"What's restOfList?
"It could be anything, you don't know. But remember, we're being recursive - and don't you have a function to calculate the sum of a list?"
"Oh right! Well then the sum would be 3 + sum(restOfList).
"That's right. But now your only problem is that every sum is defined in terms of another call to sum(), so you'll never be able to get an actual value out. You'll need some sort of base case that everything will actually reach, and that you can provide a value for."
"Hmm, you're right." Thinks...
"Well, since your lists are getting shorter and shorter, what's the shortest possible list?"
"The empty list?"
"Right! And what's the sum of an empty list of ints?"
"Zero - I get it now. So putting it together, the sum of an empty list is zero, and the sum of any other list is its first element added to the sum of the rest of it.
And indeed, the code could read almost exactly like that last sentence:
def sumList(xs: List[Int]): Int = {
  if (xs.isEmpty) 0
  else xs.head + sumList(xs.tail)
}
(The pattern matching versions, such as that proposed by Kim Stebel, are essentially identical to this, they just express the conditions in a more "functional" way.)
Here's the "standard" recursive approach:
def sum(xs: List[Int]): Int = {
xs match {
case x :: tail => x + sum(tail) // if there is an element, add it to the sum of the tail
case Nil => 0 // if there are no elements, then the sum is 0
}
}
And, here's a tail-recursive function. It will be more efficient than a non-tail-recursive function because the compiler turns it into a while loop that doesn't require pushing a new frame on the stack for every recursive call:
def sum(xs: List[Int]): Int = {
@annotation.tailrec
def inner(xs: List[Int], accum: Int): Int = {
xs match {
case x :: tail => inner(tail, accum + x)
case Nil => accum
}
}
inner(xs, 0)
}
You can't make it any easier:
val list = List(3, 4, 12)
println(list.sum) // result will be 19
Hope it helps :)
Your code is good, but you don't need the temporary value num. In Scala, if is an expression that returns a value, and that value becomes the value of the sum function. So your code can be refactored to:
def sum(xs: List[Int]): Int = {
if(xs.isEmpty) 0
else xs.head + sum(xs.tail)
}
If the list is empty, return 0; otherwise, add the head to the sum of the rest of the list.
The canonical implementation with pattern matching:
def sum(xs: List[Int]): Int = xs match {
  case Nil => 0
  case x :: rest => x + sum(rest)
}
This isn't tail recursive, but it's easy to understand.
Building heavily on @Kim's answer:
def sum(xs: List[Int]): Int = {
  if (xs.isEmpty) throw new IllegalArgumentException("Empty list provided for sum operation")
  def inner(xs: List[Int]): Int =
    xs match {
      case Nil => 0
      case x :: tail => x + inner(tail)
    }
  inner(xs)
}
The inner function is recursive, and an appropriate exception is raised when an empty list is provided.
If you are required to write a recursive function using isEmpty, head and tail, and also throw exception in case empty list argument:
def sum(xs: List[Int]): Int =
if (xs.isEmpty) throw new IllegalArgumentException("sum of empty list")
else if (xs.tail.isEmpty) xs.head
else xs.head + sum(xs.tail)
def sum(xs: List[Int]): Int = {
def loop(accum: Int, xs: List[Int]): Int = {
if (xs.isEmpty) accum
else loop(accum + xs.head, xs.tail)
}
loop(0,xs)
}
def sum(xs: List[Int]): Int = xs.sum
scala> sum(List(1,3,7,5))
res1: Int = 16
scala> sum(List())
res2: Int = 0
To add another possible answer, here is a solution I came up with that is a slight variation of @jgaw's answer and uses the @tailrec annotation:
def sum(xs: List[Int]): Int = {
if (xs.isEmpty) throw new Exception // May want to tailor this to either some sort of case class or do something else
@annotation.tailrec
def go(l: List[Int], acc: Int): Int = {
if (l.tail == Nil) l.head + acc // If the current 'list' (current element in xs) does not have a tail (no more elements after), then we reached the end of the list.
else go(l.tail, l.head + acc) // Iterate to the next, add on the current accumulation
}
go(xs, 0)
}
Quick note regarding the checks for an empty list being passed in; when programming functionally, it is preferred to not throw any exceptions and instead return something else (another value, function, case class, etc.) to handle errors elegantly and to keep flowing through the path of execution rather than stopping it via an Exception. I threw one in the example above since we're just looking at recursively summing items in a list.
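Following that advice, here is a minimal sketch that returns an Either instead of throwing. The SafeSum name and the error message are illustrative choices:

```scala
object SafeSum {
  // The "empty list" case is encoded in the return type instead of an exception.
  def sum(xs: List[Int]): Either[String, Int] = {
    @annotation.tailrec
    def go(l: List[Int], acc: Int): Int = l match {
      case Nil          => acc
      case head :: tail => go(tail, acc + head)
    }
    if (xs.isEmpty) Left("sum of empty list") else Right(go(xs, 0))
  }
}
```

Callers then decide how to handle the error, for example with SafeSum.sum(xs).getOrElse(0) or by pattern matching on Left and Right.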
I tried the following method, without using the substitution approach:
def sum(xs: List[Int]) = {
val listSize = xs.size
def loop(a:Int,b:Int):Int={
if(a==0||xs.isEmpty)
b
else
loop(a-1,xs(a-1)+b)
}
loop(listSize,0)
}
With the intention of learning, and as a follow-up to this question, I've remained curious about idiomatic alternatives to explicit recursion for an algorithm that checks whether a list (or collection) is ordered. (I'm keeping things simple here by using an operator to compare and Int as the type; I'd like to look at the algorithm before delving into its generic version.)
The basic recursive version would be (by @Luigi Plinge):
def isOrdered(l:List[Int]): Boolean = l match {
case Nil => true
case x :: Nil => true
case x :: xs => x <= xs.head && isOrdered(xs)
}
A poor performing idiomatic way would be:
def isOrdered(l: List[Int]) = l == l.sorted
An alternative algorithm using fold:
def isOrdered(l: List[Int]) =
l.foldLeft((true, None:Option[Int]))((x,y) =>
(x._1 && x._2.map(_ <= y).getOrElse(true), Some(y)))._1
It has the drawback that it will compare for all n elements of the list even if it could stop earlier after finding the first out-of-order element. Is there a way to "stop" fold and therefore making this a better solution?
Any other (elegant) alternatives?
This will exit after the first element that is out of order. It should thus perform well, but I haven't tested that. It's also a lot more elegant in my opinion. :)
def sorted(l: List[Int]) = l.isEmpty || l.view.zip(l.tail).forall(x => x._1 <= x._2) // guard: Nil.tail throws
By "idiomatic", I assume you're talking about McBride and Paterson's "Idioms" in their paper Applicative Programming With Effects. :o)
Here's how you would use their idioms to check if a collection is ordered:
import scalaz._
import Scalaz._
case class Lte[A](v: A, b: Boolean)
implicit def lteSemigroup[A:Order] = new Semigroup[Lte[A]] {
def append(a1: Lte[A], a2: => Lte[A]) = {
lazy val b = a1.v lte a2.v
Lte(if (!a1.b || b) a1.v else a2.v, a1.b && b && a2.b)
}
}
def isOrdered[T[_]:Traverse, A:Order](ta: T[A]) =
ta.foldMapDefault(x => some(Lte(x, true))).fold(_.b, true)
Here's how this works:
Any data structure T[A] where there exists an implementation of Traverse[T], can be traversed with an Applicative functor, or "idiom", or "strong lax monoidal functor". It just so happens that every Monoid induces such an idiom for free (see section 4 of the paper).
A monoid is just an associative binary operation over some type, and an identity element for that operation. I'm defining a Semigroup[Lte[A]] (a semigroup is the same as a monoid, except without the identity element) whose associative operation tracks the lesser of two values and whether the left value is less than the right value. And of course Option[Lte[A]] is just the monoid generated freely by our semigroup.
Finally, foldMapDefault traverses the collection type T in the idiom induced by the monoid. The result b will contain true if each value was less than all the following ones (meaning the collection was ordered), or None if the T had no elements. Since an empty T is sorted by convention, we pass true as the second argument to the final fold of the Option.
As a bonus, this works for all traversable collections. A demo:
scala> val b = isOrdered(List(1,3,5,7,123))
b: Boolean = true
scala> val b = isOrdered(Seq(5,7,2,3,6))
b: Boolean = false
scala> val b = isOrdered(Map((2 -> 22, 33 -> 3)))
b: Boolean = true
scala> val b = isOrdered(some("hello"))
b: Boolean = true
A test:
import org.scalacheck._
scala> val p = forAll((xs: List[Int]) => (xs /== xs.sorted) ==> !isOrdered(xs))
p:org.scalacheck.Prop = Prop
scala> val q = forAll((xs: List[Int]) => isOrdered(xs.sorted))
q: org.scalacheck.Prop = Prop
scala> p && q check
+ OK, passed 100 tests.
And that's how you do idiomatic traversal to detect if a collection is ordered.
I'm going with this, which is pretty similar to Kim Stebel's, as a matter of fact.
def isOrdered(list: List[Int]): Boolean = (
  list
  sliding 2
  map {
    case List(a, b) => () => a <= b
    case _ => () => true // a single-element window is trivially ordered
  }
  forall (_())
)
In case you missed missingfaktor's elegant solution in the comments above:
Scala < 2.13.0
(l, l.tail).zipped.forall(_ <= _)
Scala 2.13.x+
l.lazyZip(l.tail).forall(_ <= _)
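This solution is very readable and will exit on the first out-of-order element. Note that l.tail throws on an empty list, so guard it with l.isEmpty || ... if empty input is possible.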
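Here is a generic variant of the same idea, sketched under the assumption of Scala 2.13. It takes an implicit Ordering and uses lteq, so equal neighbours count as ordered. It pairs each element with its successor via two iterators, which avoids calling .tail on an empty collection. forall on an Iterator still stops at the first out-of-order pair:

```scala
object IsOrdered {
  // Generic over any Ordering; safe on empty and single-element input.
  def isOrdered[A](xs: Seq[A])(implicit ord: Ordering[A]): Boolean =
    xs.iterator
      .zip(xs.iterator.drop(1)) // (x0, x1), (x1, x2), ...
      .forall { case (a, b) => ord.lteq(a, b) }

  def main(args: Array[String]): Unit = {
    println(isOrdered(List(1, 3, 5, 7, 123))) // true
    println(isOrdered(Seq(5, 7, 2, 3, 6)))    // false
  }
}
```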
The recursive version is fine, but limited to List (with limited changes, it would work well on LinearSeq).
If it were implemented in the standard library (which would make sense), it would probably live in IterableLike and have a completely imperative implementation (see, for instance, the method find).
You can interrupt the foldLeft with a return (in which case the accumulator only needs to be the previous element, not a Boolean):
import Ordering.Implicits._
def isOrdered[A: Ordering](seq: Seq[A]): Boolean = {
if (!seq.isEmpty)
seq.tail.foldLeft(seq.head){(previous, current) =>
if (previous > current) return false; current
}
true
}
But I don't see how it is any better, or any more idiomatic, than an imperative implementation. In fact, I'm not sure I wouldn't call it imperative.
Another solution could be
def isOrdered[A: Ordering](seq: Seq[A]): Boolean =
! seq.sliding(2).exists{s => s.length == 2 && s(0) > s(1)}
It is rather concise, and it could perhaps be called idiomatic, though I'm not sure. But I don't think it is very clear. Moreover, all of these methods would probably perform much worse than the imperative or tail-recursive version, and I don't think they add enough clarity to justify that.
Also you should have a look at this question.
To stop iteration, you can use Iteratee:
import scalaz._
import Scalaz._
import IterV._
import math.Ordering
import Ordering.Implicits._
implicit val ListEnumerator = new Enumerator[List] {
def apply[E, A](e: List[E], i: IterV[E, A]): IterV[E, A] = e match {
case List() => i
case x :: xs => i.fold(done = (_, _) => i,
cont = k => apply(xs, k(El(x))))
}
}
def sorted[E: Ordering] : IterV[E, Boolean] = {
def step(is: Boolean, e: E)(s: Input[E]): IterV[E, Boolean] =
s(el = e2 => if (is && e < e2)
Cont(step(is, e2))
else
Done(false, EOF[E]),
empty = Cont(step(is, e)),
eof = Done(is, EOF[E]))
def first(s: Input[E]): IterV[E, Boolean] =
s(el = e1 => Cont(step(true, e1)),
empty = Cont(first),
eof = Done(true, EOF[E]))
Cont(first)
}
scala> val s = sorted[Int]
s: scalaz.IterV[Int,Boolean] = scalaz.IterV$Cont$$anon$2@5e9132b3
scala> s(List(1,2,3)).run
res11: Boolean = true
scala> s(List(1,2,3,0)).run
res12: Boolean = false
You can split the List into two parts and check whether the last element of the first part is less than or equal to the first element of the second part. If so, you can check both parts, possibly in parallel. Here is the schematic idea, first without parallelism:
def isOrdered (l: List [Int]): Boolean = l.size/2 match {
case 0 => true
case m => {
val low = l.take (m)
val high = l.drop (m)
low.last <= high.head && isOrdered (low) && isOrdered (high)
}
}
And now with parallel, and using splitAt instead of take/drop:
def isOrdered (l: List[Int]): Boolean = l.size/2 match {
case 0 => true
case m => {
val (low, high) = l.splitAt (m)
low.last <= high.head && ! List (low, high).par.exists (x => isOrdered (x) == false)
}
}
def isSorted[A <: Ordered[A]](sequence: List[A]): Boolean = {
sequence match {
case Nil => true
case x::Nil => true
case x::y::rest => (x < y) && isSorted(y::rest)
}
}
Explain how it works.
My solution combines missingfaktor's solution with Ordering:
def isSorted[T](l: Seq[T])(implicit ord: Ordering[T]) = l.isEmpty || (l, l.tail).zipped.forall(ord.lt(_, _))
You can also plug in your own comparison, e.g.:
isSorted(dataList)(Ordering.by[Post, Date](_.lastUpdateTime))