Scala and State Monad

I have been trying to understand the State Monad. Not so much how it is used, though that is not always easy to find, either. But every discussion I find of the State Monad has basically the same information and there is always something I don't understand.
Take this post, for example. In it the author has the following:
case class State[S, A](run: S => (A, S)) {
...
def flatMap[B](f: A => State[S, B]): State[S, B] =
State(s => {
val (a, t) = run(s)
f(a) run t
})
...
}
I can see that the types line up correctly. However, I don't understand the second run at all.
Perhaps I am looking at the whole purpose of this monad incorrectly. I got the impression from the HaskellWiki that the State monad was kind of like a state-machine with the run allowing for transitions (though, in this case, the state-machine doesn't really have fixed state transitions like most state machines). If that is the case then in the above code (a, t) would represent a single transition. The application of f would represent a modification of that value and State (generating a new State object). That leaves me completely confused as to what the second run is all about. It would appear to be a second 'transition'. But that doesn't make any sense to me.
I can see that calling run on the resulting State object produces a new (A, S) pair which, of course, is required for the types to line up. But I don't really see what this is supposed to be doing.
So, what is really going on here? What is the concept being modeled here?
Edit: 12/22/2015
So, it appears I am not expressing my issue very well. Let me try this.
In the same blog post we see the following code for map:
def map[B](f: A => B): State[S, B] =
State(s => {
val (a, t) = run(s)
(f(a), t)
})
Obviously there is only a single call to run here.
The model I have been trying to reconcile is that a call to run moves the state we are keeping forward by a single state-change. This seems to be the case in map. However, in flatMap we have two calls to run. If my model was correct that would result in 'skipping over' a state change.
To make use of the example @Filippo provided below, the first call to run would return (1, List(2,3,4,5)) and the second would return (2, List(3,4,5)), effectively skipping over the first one. Since, in his example, this was followed immediately by a call to map, this would have resulted in (Map(a->2, b->3), List(4,5)).
Apparently that is not what is happening. So my whole model is incorrect. What is the correct way to reason about this?
2nd Edit: 12/22/2015
I just tried doing what I said in the REPL. And my instincts were correct which leaves me even more confused.
scala> val v = State(head[Int]).flatMap { a => State(head[Int]) }
v: State[List[Int],Int] = State(<function1>)
scala> v.run(List(1,2,3,4,5))
res2: (Int, List[Int]) = (2,List(3, 4, 5))
So, this implementation of flatMap does skip over a state. Yet when I run @Filippo's example I get the same answer he does. What is really happening here?

To understand the "second run" let's analyse it "backwards".
The signature def flatMap[B](f: A => State[S, B]): State[S, B] suggests that we need to run a function f and return its result.
To execute function f we need to give it an A. Where do we get one?
Well, we have run that can give us A out of S, so we need an S.
Because of that we do: s => val (a, t) = run(s) ....
We read it as "given an S, execute the run function, which produces an A and a new S". And this is our "first" run.
Now we have an A and we can execute f. That's what we wanted and f(a) gives us a new State[S, B].
If we do that then we have a function which takes S and returns State[S, B]:
(s: S) => {
val (a, t) = run(s)
f(a) // State[S, B]
}
But function S => State[S, B] isn't what we want to return! We want to return just State[S, B].
How do we do that? We can wrap this function into State:
State(s => ... f(a))
But it doesn't work because State takes S => (B, S), not S => State[S, B].
So we need to get (B, S) out of State[S, B].
We do it by just calling its run method and providing it with the state we just produced on the previous step!
And it is our "second" run.
So as a result we have the following transformation performed by a flatMap:
s => // when a state is provided
val (a, t) = run(s) // produce an `A` and a new state value
val resState = f(a) // produce a new `State[S, B]`
resState.run(t) // return `(B, S)`
This gives us S => (B, S) and we just wrap it with the State constructor.
Another way of looking at these "two runs" is:
first - we transform the state ourselves with "our" run function
second - we pass that transformed state to the function f and let it do its own transformation.
So we are "chaining" state transformations one after another. And that's exactly what monads do: they give us the ability to sequence computations.
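The two runs described above can be sketched as a small self-contained program. This is only an illustration under stated assumptions: the wrapper object FlatMapSketch and the tick computation are hypothetical names that do not appear in the blog post, and State is re-declared locally so the snippet compiles on its own.

```scala
object FlatMapSketch {
  // Local copy of the State type discussed above.
  case class State[S, A](run: S => (A, S)) {
    def map[B](f: A => B): State[S, B] =
      State { s =>
        val (a, t) = run(s)
        (f(a), t)
      }

    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State { s =>
        val (a, t) = run(s) // "first" run: this computation, on the incoming state
        val next = f(a)     // use the value to choose the next computation
        next.run(t)         // "second" run: the next computation, on the updated state
      }
  }

  // Hypothetical computation: yields the current counter and increments it.
  val tick: State[Int, Int] = State(n => (n, n + 1))

  // Two ticks chained with flatMap; the second tick sees the state left by the first.
  val twoTicks: State[Int, (Int, Int)] =
    tick.flatMap(first => tick.map(second => (first, second)))

  val outcome: ((Int, Int), Int) = twoTicks.run(10) // ((10, 11), 12)
}
```

Nothing is skipped here: each of the two runs belongs to a different State value, and each advances the counter exactly once.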

The state monad boils down to this function from one state to another state (plus A):
type StatefulComputation[S, +A] = S => (A, S)
The implementation mentioned by Tony in that blog post "captures" that function in the run field of the case class:
case class State[S, A](run: S => (A, S))
The flatMap implementation, which binds one state to another, calls 2 different runs:
// the `run` on the actual `state`
val (a: A, nextState: S) = run(s)
// the `run` on the bound `state`
f(a).run(nextState)
EDIT: Example of flatMap between 2 State values
Consider a function that simply calls .head on a List to get A, and .tail to get the next state S
// stateful computation: `S => (A, S)` where `S` is `List[A]`
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)
A simple binding of 2 State(head[Int]):
// flatmap example
val result = for {
a <- State(head[Int])
b <- State(head[Int])
} yield Map('a' -> a,
'b' -> b)
The expected behaviour of the for-comprehension is to "extract" the first element of a list into a and the second one into b. The resulting state S would be the remaining tail of the run list:
scala> result.run(List(1, 2, 3, 4, 5))
(Map(a -> 1, b -> 2),List(3, 4, 5))
How? Calling the "stateful computation" head[Int] that is in run on some state s:
s => run(s)
That gives the head (A) and the tail (S) of the list. Now we need to pass the tail to the next State(head[Int])
f(a).run(t)
Where f is in the flatMap signature:
def flatMap[B](f: A => State[S, B]): State[S, B]
Maybe to better understand what is f in this example, we should de-sugar the for-comprehension to:
val result = State(head[Int]).flatMap {
a => State(head[Int]).map {
b => Map('a' -> a, 'b' -> b)
}
}
With f(a) we pass a into the function and with run(t) we pass the modified state.
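To make the trace concrete, the sketch below performs the same two steps by hand, without flatMap, and arrives at the same pair as the for-comprehension. The wrapper object HeadTraceSketch and the names start, afterFirst, afterSecond and manual are hypothetical and only used for illustration; State and head are copied from the snippets above.

```scala
object HeadTraceSketch {
  case class State[S, A](run: S => (A, S)) {
    def map[B](f: A => B): State[S, B] =
      State { s =>
        val (a, t) = run(s)
        (f(a), t)
      }

    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State { s =>
        val (a, t) = run(s)
        f(a).run(t)
      }
  }

  // stateful computation: `S => (A, S)` where `S` is `List[A]`
  def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

  // The for-comprehension from the answer.
  val result: State[List[Int], Map[Char, Int]] = for {
    a <- State(head[Int])
    b <- State(head[Int])
  } yield Map('a' -> a, 'b' -> b)

  // The same steps done manually: each State is run once, on the state left by the previous one.
  val start = List(1, 2, 3, 4, 5)
  val (a, afterFirst) = State(head[Int]).run(start)       // (1, List(2, 3, 4, 5))
  val (b, afterSecond) = State(head[Int]).run(afterFirst) // (2, List(3, 4, 5))
  val manual = (Map('a' -> a, 'b' -> b), afterSecond)
}
```

Running result on start gives the same value as manual, which is why the flatMap with two calls to run yields (Map(a -> 1, b -> 2), List(3, 4, 5)).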

I have accepted @AlexyRaga's answer to my question. I think @Filippo's answer was very good as well and, in fact, gave me some additional food for thought. Thanks to both of you.
I think the conceptual difficulty I was having was mostly to do with what the run method 'means'. That is, what is its purpose and result. I was looking at it as a 'transition' function (from one state to the next). And, after a fashion, that is what it does. However, it doesn't transition from a given (this) state to the next state. Instead, it takes an initial State and returns the (this) state's value and a new 'current' state (not the next state in the state-transition sequence).
That is why the flatMap method is implemented the way it is. When you generate a new State then you need the current value/state pair from it based on the passed-in initial state which can then be wrapped in a new State object as a function. You are not really transitioning to a new state. Just re-wrapping the generated state in a new State object.
I was too steeped in traditional state machines to see what was going on here.
Thanks, again, everyone.

Related

Scala: How Does flatMap Work In This Case?

//this class holds method run that let us give a state and obtain from it an
//output value and next state
case class State[S, +A](run: S => (A, S)) {
//returns a state for which the run function will return (b,s) instead of (a,s)
def map[B](f: A => B): State[S, B] =
flatMap(a => unit(f(a)))
//returns a state for which the run function will return (f(a,b),s)
def map2[B, C](sb: State[S, B])(f: (A, B) => C): State[S, C] =
flatMap(a => sb.map(b => f(a, b)))
def flatMap[B](f: A => State[S, B]): State[S, B] =
State(s => {
val (a, s1) = run(s)
val sB = f(a)
sB.run(s1)
})
}
I am new to functional programming and this code makes my head spin. I tried to understand what flatMap really does, but it looks so abstract that I couldn't follow it. For example, here map is defined in terms of flatMap, but I can't see how to explain that implementation. How does it work?
This chapter is about pure functional state, and describes how the state is updated in a sequence of operations in a functional way, that is to say, by means of functional composition. So you don't have to create a global state variable that is mutated by a sequence of method executions, as you probably would in an OOP fashion.
Having said that, this wonderful book explains each concept declaring an abstraction to represent an effect, and some primitives which are going to be used to build more complex functions. In this chapter, the effect is the State. It is represented as a function that takes a previous state and yields a value and a new state, as you can see with the run function declaration. So if you want to compose functions that propagate some state in a pure functional fashion, all of them need to return a new functional State.
With flatMap, which is a primitive of the State type, you can compose functions and, at the same time, propagate some internal state.
For example, if you have two functions:
case class Intermediate(v: Double = 0)
def execute1(a: Int): State[Int, Intermediate] = {
State { s =>
/** Do whatever */
(Intermediate(a * a), s + 1)
}
}
def execute2(a: Int): State[Int, Intermediate] = {
State { s =>
/** Do whatever */
(Intermediate(a * a), s + 1)
}
}
As you can see, each function consists of a State function that takes a previous state, executes some logic, and returns a tuple: an Intermediate value, which holds the result yielded by the function, and an Int, which tracks the number of functions that have been executed.
Now, if you apply flat map to execute both functions sequentially:
val result: State[Int, Intermediate] =
execute1(2) flatMap (r => execute2(r.v.toInt))
the result val is another State that you can see as the point of entry to run the pipeline:
/** 0 would be the initial state **/
result.run(0)
the result is:
(Intermediate(16.0),2)
How this works under the hood:
First you have the function execute1. This function returns a State[Int, Intermediate]. I use Intermediate to store the result of each execution (the yielded value) and an Int to count the number of functions executed in the flatMap sequential composition.
Then I apply flatMap after the function call, which is the flatMap of the State type.
Remember the definition:
def flatMap[B](f: A => State[S, B]): State[S, B] =
State(s => {
/* This function is executed in a State context */
val (a, s1) = run(s)
val sB = f(a)
sB.run(s1)
})
The execution happens inside a state context (note that the flatMap function is inside the State type), which is the state function returned by execute1. Maybe this representation is easier:
val st1 = execute1(2)
st1.flatMap(r => execute2(r.v.toInt))
So when run(s) is executed, it actually executes the function state returned by the execute1(2) function. It gives the first (a, s1) result.
Then you have to execute the second function, execute2. You can visualize this function as the input parameter of the flatMap function. It is executed with the value yielded by execute1(2) (the Intermediate value, left side of the tuple). This execution returns another state function that uses the state (right side of the tuple) from execute1 to create the final tuple, with the value yielded by execute2 and the Int corresponding to the number of executions. Note that each function increments the internal state by one.
Note that the result is another State. This has a run method that waits for the initial state.
I have tried to explain the workflow of function calls. It is a little hard if you are not used to thinking in a functional style, but it is only functional composition, and it becomes familiar after a while.
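As for map being defined in terms of flatMap, a minimal sketch may help. It assumes that unit, which the snippet in the question uses but does not show, wraps a value in a State that leaves the state untouched; that definition, the wrapper object MapViaFlatMapSketch and the counter example are assumptions made for illustration.

```scala
object MapViaFlatMapSketch {
  case class State[S, +A](run: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State(s => {
        val (a, s1) = run(s)
        f(a).run(s1)
      })

    // map = run this computation, then continue with a computation
    // that only yields f(a) and passes the state through unchanged.
    def map[B](f: A => B): State[S, B] =
      flatMap(a => State.unit(f(a)))
  }

  object State {
    // Assumed definition: yield `a` and return the incoming state as-is.
    def unit[S, A](a: A): State[S, A] = State(s => (a, s))
  }

  // Hypothetical computation: yields n * 10 and moves the state to n + 1.
  val counter: State[Int, Int] = State(n => (n * 10, n + 1))

  val described: State[Int, String] = counter.map(v => s"value $v")

  val out: (String, Int) = described.run(3) // ("value 30", 4)
}
```

Because the State produced by unit does not touch the state, the second run inside flatMap changes only the yielded value, which is exactly what map is expected to do.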

Is there a map function that also applies a context?

Is there a map like function in Scala that renders: M[A] -> M[B] but also allows you to specify some context info kinda like in scanLeft? For instance, something like:
def mapWithContext(context: C, f: (C, A) => (C, B)): M[B]
f would be passed 1) the context info of type C 2) the current element in the list of type A
and it would then returns 1) the updated context info C 2) the transformed list element of type B. However, unlike scanLeft it would then discard the context info from the collection it is building and only return the B's
Not exactly, but the new (Scala 2.13) unfold() method comes close. It's available on almost all collection types, mutable and immutable. Here's the signature for List.
def unfold(init: S)(f: (S) => Option[(A, S)]): List[A]
It takes an initial state and continually applies f() until it returns None.
So what you could do is make the initial state a tuple (S, M[A]) and make the f() function f: ((S, M[A])) => Option[(B, (S, M[A]))]. The f() code would peel off the head of M[A] to produce each B element until M[A] is empty, at which point f() returns None.
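A minimal sketch of that idea, specialised to List and assuming Scala 2.13 or later. The helper mapWithContext follows the shape asked for in the question but is hypothetical, not a library method; runningTotals is an illustrative use.

```scala
object MapWithContextSketch {
  // Hypothetical helper: threads a context C through the list and keeps only the Bs.
  def mapWithContext[C, A, B](xs: List[A], context: C)(f: (C, A) => (C, B)): List[B] =
    List.unfold[B, (C, List[A])]((context, xs)) {
      case (_, Nil) => None                // input exhausted: stop unfolding
      case (c, head :: tail) =>
        val (nextContext, b) = f(c, head)  // transform the element, update the context
        Some((b, (nextContext, tail)))     // emit b, carry on with the rest
    }

  // Example: the context is the running sum, and each element is replaced by that sum.
  val runningTotals: List[Int] =
    mapWithContext(List(1, 2, 3, 4), 0) { (sum, x) =>
      val next = sum + x
      (next, next)
    }
  // runningTotals == List(1, 3, 6, 10)
}
```

The context never appears in the result; it only lives in the unfold state, next to the not yet consumed part of the input.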

Avoiding downcasts in generic classes

Noticing that my code was essentially iterating over a list and updating a value in a Map, I first created a trivial helper method which took a function for the transformation of the map value and return an updated map. As the program evolved, it gained a few other Map-transformation functions, so it was natural to turn it into an implicit value class that adds methods to scala.collection.immutable.Map[A, B]. That version works fine.
However, there's nothing about the methods that require a specific map implementation and they would seem to apply to a scala.collection.Map[A, B] or even a MapLike. So I would like it to be generic in the map type as well as the key and value types. This is where it all goes pear-shaped.
My current iteration looks like this:
implicit class RichMap[A, B, MapType[A, B] <: collection.Map[A, B]](
val self: MapType[A, B]
) extends AnyVal {
def updatedWith(k: A, f: B => B): MapType[A, B] =
self updated (k, f(self(k)))
}
This code does not compile because self updated (k, f(self(k))) is a scala.collection.Map[A, B], which is not a MapType[A, B]. In other words, the return type of self.updated is as if self's type was the upper type bound rather than the actual declared type.
I can "fix" the code with a downcast:
def updatedWith(k: A, f: B => B): MapType[A, B] =
self.updated(k, f(self(k))).asInstanceOf[MapType[A, B]]
This does not feel satisfactory because downcasting is a code smell and indicates misuse of the type system. In this particular case it would seem that the value will always be of the cast-to type, and that the whole program compiles and runs correctly with this downcast supports this view, but it still smells.
So, is there a better way to write this code to have scalac correctly infer types without using a downcast, or is this a compiler limitation and a downcast is necessary?
[Edited to add the following.]
My code which uses this method is somewhat more complex and messy as I'm still exploring a few ideas, but an example minimum case is the computation of a frequency distribution as a side-effect with code roughly like this:
var counts = Map.empty[Int, Int] withDefaultValue 0
for (item <- items) {
// loads of other gnarly item-processing code
counts = counts updatedWith (count, 1 + _)
}
There are three answers to my question at the time of writing.
One boils down to just letting updatedWith return a scala.collection.Map[A, B] anyway. Essentially, it takes my original version that accepted and returned an immutable.Map[A, B], and makes the type less specific. In other words, it's still insufficiently generic and sets policy on which types the caller uses. I can certainly change the type on the counts declaration, but that is also a code smell to work around a library returning the wrong type, and all it really does is move the downcast into the caller's code. So I don't really like this answer at all.
The other two are variations on CanBuildFrom and builders in that they essentially iterate over the map to produce a modified copy. One inlines a modified updated method, whereas the other calls the original updated and appends it to the builder and thus appears to make an extra temporary copy. Both are good answers which solve the type correctness problem, although the one that avoids an extra copy is the better of the two from a performance standpoint and I prefer it for that reason. The other is however shorter and arguably more clearly shows intent.
In the case of a hypothetical immutable Map that shares large trees in a similar vein to List, this copying would break the sharing and reduce performance and so it would be preferable to use the existing updated without performing copies. However, Scala's immutable maps don't appear to do this and so copying (once) seems to be the pragmatic solution that is unlikely to make any difference in practice.
Yes! Use CanBuildFrom. This is how the Scala collections library infers the closest collection type to the one you want. You need implicit evidence of CanBuildFrom[From, Elem, To], where From is the type of collection you're starting with, Elem is the type contained within the collection, and To is the end result you want. The CanBuildFrom supplies a Builder to which you can add elements, and when you're done, you can call Builder#result() to get the completed collection of the appropriate type.
In this case:
From = MapType[A, B]
Elem = (A, B) // The type actually contained in maps
To = MapType[A, B]
Implementation:
import scala.collection.generic.CanBuildFrom
implicit class RichMap[A, B, MapType[A, B] <: collection.Map[A, B]](
val self: MapType[A, B]
) extends AnyVal {
def updatedWith(k: A, f: B => B)(implicit cbf: CanBuildFrom[MapType[A, B], (A, B), MapType[A, B]]): MapType[A, B] = {
val builder = cbf()
builder ++= self.updated(k, f(self(k)))
builder.result()
}
}
scala> val m = collection.concurrent.TrieMap(1 -> 2, 5 -> 3)
m: scala.collection.concurrent.TrieMap[Int,Int] = TrieMap(1 -> 2, 5 -> 3)
scala> m.updatedWith(1, _ + 10)
res1: scala.collection.concurrent.TrieMap[Int,Int] = TrieMap(1 -> 12, 5 -> 3)
Please note that the updated method returns the Map class rather than a generic type, so I would say you should be fine returning Map as well. But if you really want to return a proper type, you could have a look at the implementation of List.updated.
I've written a small example. I'm not sure it covers all the cases, but it works in my tests. I also used mutable Map, because it was harder for me to test immutable, but I guess it can be easily converted.
implicit class RichMap[A, B, MapType[x, y] <: Map[x, y]](val self: MapType[A, B]) extends AnyVal {
import scala.collection.generic.CanBuildFrom
def updatedWith[R >: B](k: A, f: B => R)(implicit bf: CanBuildFrom[MapType[A, B], (A, R), MapType[A, R]]): MapType[A, R] = {
val b = bf(self)
for ((key, value) <- self) {
if (key != k) {
b += (key -> value)
} else {
b += (key -> f(value))
}
}
b.result()
}
}
import scala.collection.immutable.{TreeMap, HashMap}
val map1 = HashMap(1 -> "s", 2 -> "d").updatedWith(2, _.toUpperCase()) // map1 type is HashMap[Int, String]
val map2 = TreeMap(1 -> "s", 2 -> "d").updatedWith(2, _.toUpperCase()) // map2 type is TreeMap[Int, String]
val map3 = HashMap(1 -> "s", 2 -> "d").updatedWith(2, _.asInstanceOf[Any]) // map3 type is HashMap[Int, Any]
Please also note that the CanBuildFrom pattern is much more powerful and this example doesn't use all of its power. Thanks to CanBuildFrom some operations can change the type of collection completely: for example, the type of BitSet(1, 3, 5, 7) map {_.toString } is actually SortedSet[String].

Scala lambda expression

(It seems like I ask too many basic questions, but Scala is very difficult, especially when it comes to the language feature or syntactic sugar, and I come from a Java background..)
I am watching this video: Scalaz State Monad on YouTube. The lecturer wrote this code:
trait State[S, +A] {
def run(initial: S):(S, A)
def map[B](f: A=>B): State[S, B] =
State { s =>
val (s1, a) = run(s)
(s1, f(a))
}
}
He said the State on the fourth line is a "lambda" expression (I may have misheard). I'm a bit confused about this. He did have an object State, which looks like this:
object State {
def apply[S, A](f: S=>(S, A)): State[S, A] =
new State[S, A] {
def run(i: S) = f(i)
}
}
The lecturer said he was invoking the apply method of object State by calling it like that, and yes, he did pass in a function that returns a correct tuple, but what is that { s =>? Where the hell did he get the s, and is this about the implicit mechanism in Scala?
Also, the function that is passed into map already changes A to B, so how can it still pass the generics check (from [S, A] to [S, B], assuming A and B are two different types, or are they the same type? If so, why use B and not A in the map function)?
The video link is here and is set to be at 18:17
http://youtu.be/Jg3Uv_YWJqI?t=18m17s
Thank you (Monad is a bit daunting to anyone that's not so familiar with functional programming)
You are on the right track here.
He passes the function as a parameter, but what is a function in general? A function takes an argument and returns a value. The returned value here is the tuple (s1, f(a)), but we also need to refer to the function's argument. s in your example is that function argument.
It is the same as
def map[B](f: A=>B): State[S, B] = {
def buildState(s: S) = {
val (s1, a) = run(s)
(s1, f(a))
}
State(buildState)
}
But since we don't need this function anywhere else, we just inline it in the apply call.
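As a sketch of the same point, the snippet below completes the trait and companion from the question and builds one State value in two equivalent ways: with the brace-and-lambda syntax and with an explicit call to apply. The wrapper object LambdaSketch and the names counter, sameCounter and lengths are hypothetical, introduced only for illustration.

```scala
object LambdaSketch {
  trait State[S, +A] {
    def run(initial: S): (S, A)

    def map[B](f: A => B): State[S, B] =
      State { s =>              // `s` is just the parameter of an anonymous function
        val (s1, a) = run(s)
        (s1, f(a))
      }
  }

  object State {
    def apply[S, A](f: S => (S, A)): State[S, A] =
      new State[S, A] {
        def run(i: S): (S, A) = f(i)
      }
  }

  // Brace syntax: the lambda `n => ...` is the single argument to State.apply.
  val counter: State[Int, String] = State { n => (n + 1, s"tick $n") }

  // The same call written out in full.
  val sameCounter: State[Int, String] =
    State.apply((n: Int) => (n + 1, s"tick $n"))

  // map turns a State[Int, String] into a State[Int, Int]; A and B are different types here.
  val lengths: State[Int, Int] = counter.map(_.length)
  // lengths.run(5) == (6, 6), because "tick 5" has six characters
}
```

No implicits are involved: the compiler infers the type of the lambda parameter from the expected type of the argument to apply.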

Folding on Type without Monoid Instance

I'm working on this Functional Programming in Scala exercise:
// But what if our list has an element type that doesn't have a Monoid instance?
// Well, we can always map over the list to turn it into a type that does.
As I understand this exercise, it means that, if we have a Monoid of type B, but our input List is of type A, then we need to convert the List[A] to List[B], and then call foldLeft.
def foldMap[A, B](as: List[A], m: Monoid[B])(f: A => B): B = {
val bs = as.map(f)
bs.foldLeft(m.zero)((s, i) => m.op(s, i))
}
Does this understanding and code look right?
First I'd simplify the syntax of the body a bit:
def foldMap[A, B](as: List[A], m: Monoid[B])(f: A => B): B =
as.map(f).foldLeft(m.zero)(m.op)
Then I'd move the monoid instance into its own implicit parameter list:
def foldMap[A, B](as: List[A])(f: A => B)(implicit m: Monoid[B]): B =
as.map(f).foldLeft(m.zero)(m.op)
See the original "Type Classes as Objects and Implicits" paper for more detail about how Scala implements type classes using implicit parameter resolution, or this answer by Rex Kerr that I've also linked above.
Next I'd switch the order of the other two parameter lists:
def foldMap[A, B](f: A => B)(as: List[A])(implicit m: Monoid[B]): B =
as.map(f).foldLeft(m.zero)(m.op)
In general you want to place the parameter lists containing parameters that change the least often first, in order to make partial application more useful. In this case there may only be one possible useful value of A => B for any A and B, but there are lots of values of List[A].
For example, switching the order allows us to write the following (which assumes a monoid instance for Bar):
val fooSum: List[Foo] => Bar = foldMap(fooToBar)
Finally, as a performance optimization (mentioned by stew above), you could avoid creating an intermediate list by moving the application of f into the fold:
def foldMap[A, B](f: A => B)(as: List[A])(implicit m: Monoid[B]): B =
as.foldLeft(m.zero) {
case (acc, a) => m.op(acc, f(a))
}
This is equivalent and more efficient, but to my eye much less clear, so I'd suggest treating it like any optimization—if you need it, use it, but think twice about whether the gains are really worth the loss of clarity.
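For completeness, here is a minimal sketch that puts the final version together with a Monoid definition so it can be run. The Monoid trait follows the shape used in the exercise (op and zero); the intAddition instance, the wrapper object FoldMapSketch and the totalLength example are assumptions made for illustration.

```scala
object FoldMapSketch {
  // Monoid in the shape used by the exercise.
  trait Monoid[A] {
    def op(a1: A, a2: A): A
    def zero: A
  }

  // Assumed instance: integers under addition.
  implicit val intAddition: Monoid[Int] = new Monoid[Int] {
    def op(a1: Int, a2: Int): Int = a1 + a2
    val zero: Int = 0
  }

  // Single pass: f is applied inside the fold, so no intermediate list is built.
  def foldMap[A, B](f: A => B)(as: List[A])(implicit m: Monoid[B]): B =
    as.foldLeft(m.zero) {
      case (acc, a) => m.op(acc, f(a))
    }

  // Partial application: fix the mapping function first, supply lists later.
  val totalLength: List[String] => Int = foldMap((s: String) => s.length)

  val result: Int = totalLength(List("a", "bb", "ccc")) // 6
}
```

With the function parameter first, totalLength can be reused on any List[String], and the Monoid[Int] instance is picked up implicitly at the point where totalLength is defined.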