Getting and setting state in Scala

Here is some code from the Functional Programming in Scala book:
import State._
case class State[S, +A](run: S => (A, S)) {
def map[B](f: A => B): State[S, B] =
flatMap(a => unit(f(a)))
def map2[B, C](sb: State[S, B])(f: (A, B) => C): State[S, C] =
flatMap(a => sb.map(b => f(a, b)))
def flatMap[B](f: A => State[S, B]): State[S, B] = State(s => {
val (a, s1) = run(s)
f(a).run(s1)
})
}
object State {
type Rand[A] = State[RNG, A]
def unit[S, A](a: A): State[S, A] =
State(s => (a, s))
// The idiomatic solution is expressed via foldRight
def sequenceViaFoldRight[S, A](sas: List[State[S, A]]): State[S, List[A]] =
sas.foldRight(unit[S, List[A]](List.empty[A]))((f, acc) => f.map2(acc)(_ :: _))
// This implementation uses a loop internally and is the same recursion
// pattern as a left fold. It is quite common with left folds to build
// up a list in reverse order, then reverse it at the end.
// (We could also use a collection.mutable.ListBuffer internally.)
def sequence[S, A](sas: List[State[S, A]]): State[S, List[A]] = {
def go(s: S, actions: List[State[S, A]], acc: List[A]): (List[A], S) =
actions match {
case Nil => (acc.reverse, s)
case h :: t => h.run(s) match {
case (a, s2) => go(s2, t, a :: acc)
}
}
State((s: S) => go(s, sas, List()))
}
// We can also write the loop using a left fold. This is tail recursive like the
// previous solution, but it reverses the list _before_ folding it instead of after.
// You might think that this is slower than the `foldRight` solution since it
// walks over the list twice, but it's actually faster! The `foldRight` solution
// technically has to also walk the list twice, since it has to unravel the call
// stack, not being tail recursive. And the call stack will be as tall as the list
// is long.
def sequenceViaFoldLeft[S, A](l: List[State[S, A]]): State[S, List[A]] =
l.reverse.foldLeft(unit[S, List[A]](List()))((acc, f) => f.map2(acc)(_ :: _))
def modify[S](f: S => S): State[S, Unit] = for {
s <- get // Gets the current state and assigns it to `s`.
_ <- set(f(s)) // Sets the new state to `f` applied to `s`.
} yield ()
def get[S]: State[S, S] = State(s => (s, s))
def set[S](s: S): State[S, Unit] = State(_ => ((), s))
}
I have spent hours thinking about why the get and set methods look the way they do, but I just don't understand.
Could anyone enlighten me, please?

The key is the definition of the case class:
case class State[S, +A](run: S => (A, S))
The stateful computation is expressed by the run function. This function represents a transition from one state S to another state S, and A is a value we can produce while moving from one state to the other.
Now, how can we get the state S out of the state monad? We can make a transition that doesn't move to a different state and materialises the state as A, using the function s => (s, s):
def get[S]: State[S, S] = State(s => (s, s))
How do we set the state? All we need is a function that ignores the incoming state and moves to a given state s, i.e. ??? => (???, s):
def set[S](s: S): State[S, Unit] = State(_ => ((), s))
EDIT I would like to add an example to see get and set in action:
val statefullComputationsCombined = for {
a <- State.get[Int]
b <- State.set(10)
c <- State.get[Int]
d <- State.set(100)
e <- State.get[Int]
} yield (a, c, e)
Without looking further down this answer, what is the type of statefullComputationsCombined?
It must be a State[S, A], right? S is Int, but what is A? Because we yield (a, c, e), A must be a 3-tuple made of the A values of the flatMap steps (<-).
We said that get "fills" A with the state S, so a, c and e are of type S, i.e. Int. b and d are Unit because of def set[S](s: S): State[S, Unit].
val statefullComputationsCombined: State[Int, (Int, Int, Int)] = for ...
To use statefullComputationsCombined we need to run it:
statefullComputationsCombined.run(1)._1 == (1,10,100)
If we want the state at the end of the computation:
statefullComputationsCombined.run(1)._2 == 100
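To see why get followed by set is a useful pair, here is a minimal sketch of a counter. It repeats a trimmed-down copy of the State type so that it is self-contained. The names StateCounterSketch, nextId, nextIdDesugared and twoIds are made up for this illustration and are not from the book.

```scala
object StateCounterSketch {
  // Trimmed-down copy of the State type from the question.
  case class State[S, +A](run: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] = State { s =>
      val (a, s1) = run(s)
      f(a).run(s1)
    }
    def map[B](f: A => B): State[S, B] = flatMap(a => State.unit(f(a)))
  }

  object State {
    def unit[S, A](a: A): State[S, A] = State(s => (a, s))
    def get[S]: State[S, S] = State(s => (s, s))
    def set[S](s: S): State[S, Unit] = State(_ => ((), s))
  }

  // Hypothetical helper: return the current counter, then increment it.
  val nextId: State[Int, Int] = for {
    n <- State.get[Int]   // read the state into the result slot
    _ <- State.set(n + 1) // replace the state, the result is ()
  } yield n

  // The same computation written with explicit flatMap/map calls.
  val nextIdDesugared: State[Int, Int] =
    State.get[Int].flatMap(n => State.set(n + 1).map(_ => n))

  // Each use of nextId sees the state left behind by the previous one.
  val twoIds: State[Int, (Int, Int)] = for {
    a <- nextId
    b <- nextId
  } yield (a, b)
}
```

Nothing happens until run is called with an initial state. StateCounterSketch.twoIds.run(5) threads the counter through both steps and returns the pair of ids together with the final counter.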

Related

Type mismatch on moving method from companion object to a class

I'm new to Scala, so the answer might be obvious.
While going through FP in Scala, I noticed that the unit method is placed on the companion object rather than on the class, so I tried moving it to the class.
It resulted in an error which I don't understand, can someone explain it to me?
When you uncomment the unit method on the class and comment it out on the companion object, it results in the error noted in the code comment below (required A, found B).
Why is that?
import State._
case class State[S, A](run: S => (A, S)) {
def map[B](f: A => B): State[S, B] =
flatMap(a => unit(f(a)))
def map2[B, C](sb: State[S, B])(f: (A, B) => C): State[S, C] =
flatMap(a => sb.map(b => f(a, b)))
def flatMap[B](f: A => State[S, B]): State[S, B] = State(s => {
val (a, s1) = run(s)
f(a).run(s1)
})
// This results in an error here:
// flatMap(a => unit(f(a))) ... Required A, found B
// def unit(a: A): State[S, A] = State(s => (a, s)) // ERROR
}
object State {
// Comment out when uncommenting above
def unit[S, A](a: A): State[S, A] = State(s => (a, s))
}
The error explains the problem, but here is the breakdown.
Start with map:
def map[B](f: A => B): State[S, B] =
flatMap(a => unit(f(a)))
Since f returns an unknown type B, this is calling unit with a value of type B. But unit is defined to take A, which is a type parameter of the case class:
def unit(a: A): State[S, A] = State(s => (a, s))
Since B might not be A, the compiler complains.
The code inside the companion object is different:
object State {
def unit[T, U](a: U): State[T, U] = State(s => (a, s))
}
I have re-named the type parameters to make it clear that they are now unrelated to the type parameters of the case class. And since the argument type of unit can be any type U it is OK to pass a value of type B to it.
To fix the definition of unit inside the case class, give it a type parameter:
case class State[S, A](run: S => (A, S)) {
// ...
def unit[T](a: T): State[S, T] = State(s => (a, s))
}
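As a quick check of the fix, here is a small self-contained sketch with unit defined on the class and given its own type parameter. The object name UnitOnClassSketch and the value helloLength are only for illustration.

```scala
object UnitOnClassSketch {
  case class State[S, A](run: S => (A, S)) {
    // `T` is a fresh type parameter, independent of the class-level `A`,
    // so `unit` can wrap a value of any type, including `B` inside `map`.
    def unit[T](a: T): State[S, T] = State(s => (a, s))

    def flatMap[B](f: A => State[S, B]): State[S, B] = State { s =>
      val (a, s1) = run(s)
      f(a).run(s1)
    }

    // Here `unit` is called with a `B`, which now type-checks.
    def map[B](f: A => B): State[S, B] = flatMap(a => unit(f(a)))
  }

  // A State[Int, String] mapped to a State[Int, Int].
  val helloLength: State[Int, Int] =
    State((s: Int) => ("hello", s)).map(_.length)
}
```

One drawback of this placement is that you need an existing State instance before you can call unit, which is probably why the book keeps it on the companion object.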

Programming a state monad in Scala

I borrow the theory of what a state monad looks like from Philip Wadler's Monads for Functional Programming:
type M a = State → (a, State)
type State = Int
unit :: a → M a
unit a = λx. (a, x)
(*) :: M a → (a → M b) → M b
m * k = λx.
let (a, y) = m x in
let (b, z) = k a y in
(b, z)
The way I would like to use a state monad is as follows:
Given a list L I want different parts of my code to get this list and update this list by adding new elements at its end.
I guess the above would be modified as:
type M = State → (List[Data], State)
type State = List[Data]
def unit(a: List[Data]) = (x: State) => (a,x)
def star(m: M, k: List[Data] => M): M = {
(x: M) =>
val (a,y) = m(x)
val (b,z) = k(a)(y)
(b,z)
}
def get = ???
def update = ???
How do I fill in the details, i.e.:
How do I instantiate my hierarchy to work on a concrete list?
How do I implement get and update in terms of the above?
Finally, how would I do this using Scala's syntax with flatMap and unit?
Your M is defined incorrectly. It should take a/A as a parameter, like so:
type M[A] = State => (A, State)
You've also missed that type parameter elsewhere.
unit should have a signature like this:
def unit[A](a: A): M[A]
star should have a signature like this:
def star[A, B](m: M[A], k: A => M[B]): M[B]
Hopefully, that makes the functions more clear.
Your implementation of unit was pretty much the same:
def unit[A](a: A): M[A] = x => (a, x)
However, in star, the parameter of your lambda (x) is of type State, not M, because M[B] is basically State => (B, State). The rest you got right:
def star[A, B](m: M[A])(k: A => M[B]): M[B] =
(x: State) => {
val (a, y) = m(x)
val (b, z) = k(a)(y)
(b, z)
}
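With unit and star in place, get and update for the list use case from the question could look roughly like this. It is a sketch under a few assumptions: Data is taken to be String purely for illustration, update appends a single element, and the names ListStateSketch and program are made up.

```scala
object ListStateSketch {
  type Data = String // assumption: stands in for the asker's Data type
  type State = List[Data]
  type M[A] = State => (A, State)

  def unit[A](a: A): M[A] = x => (a, x)

  def star[A, B](m: M[A])(k: A => M[B]): M[B] =
    (x: State) => {
      val (a, y) = m(x) // run the first step
      k(a)(y)           // feed its result and new state to the next step
    }

  // Returns the current list as the result and leaves the state unchanged.
  def get: M[State] = x => (x, x)

  // Appends one element to the end of the list; the result is just ().
  def update(d: Data): M[Unit] = x => ((), x :+ d)

  // Append "a", then "b", then read the list back.
  val program: M[State] =
    star(update("a"))(_ => star(update("b"))(_ => get))
}
```

To work on a concrete list, apply the resulting function to the initial state, for example ListStateSketch.program(Nil).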
Edit: According to @Luis Miguel Mejia Suarez:
It would probably be easier to implement if you make your State a class and define flatMap inside it. And you can define unit in the companion object.
He suggested final class State[S, A](val run: S => (A, S)), which would also allow you to use infix functions like >>=.
Another way to do it would be to define State as a type alias for a function S => (A, S) and extend it using an implicit class.
type State[S, A] = S => (A, S)
object State {
//This is basically "return"
def unit[S, A](a: A): State[S, A] = s => (a, s)
}
implicit class StateOps[S, A](private val runState: S => (A, S)) {
//You can rename this to ">>=" or "flatMap"
def *[B](k: A => State[S, B]): State[S, B] = s => {
val (a, s2) = runState(s)
k(a)(s2)
}
}
If your definition of get is "set the result value to the state and leave the state unchanged" (borrowed from the Haskell Wiki), then you can implement it like this:
def get[S]: State[S, S] = s => (s, s)
If you mean that you want to extract the state (in this case a List[Data]), you can use execState (define it in StateOps):
def execState(s: S): S = runState(s)._2
Here's a terrible example of how you can add elements to a List.
def addToList(n: Int)(list: List[Int]): (Unit, List[Int]) = ((), n :: list)
def fillList(n: Int): State[List[Int], Unit] =
n match {
case 0 => s => ((), s)
case n => fillList(n - 1) * (_ => addToList(n))
}
println(fillList(10)(List.empty)) gives us this (the second element can be extracted with execState):
((),List(10, 9, 8, 7, 6, 5, 4, 3, 2, 1))

Understanding Scala for comprehension in this example code

I'm new to the Scala world and have some questions regarding following code.
sealed trait Input
case object Coin extends Input
case object Turn extends Input
case class Machine(locked: Boolean, candies: Int, coins: Int)
object Candy {
def update = (i: Input) => (s: Machine) =>
(i, s) match {
case (_, Machine(_, 0, _)) => s
case (Coin, Machine(false, _, _)) => s
case (Turn, Machine(true, _, _)) => s
case (Coin, Machine(true, candy, coin)) =>
Machine(false, candy, coin + 1)
case (Turn, Machine(false, candy, coin)) =>
Machine(true, candy - 1, coin)
}
def simulateMachine(inputs: List[Input]): State[Machine, (Int, Int)] = for {
_ <- sequence(inputs map (modify[Machine] _ compose update))
s <- get
} yield (s.coins, s.candies)
}
case class State[S, +A](run: S => (A, S)) {
def map[B](f: A => B): State[S, B] =
flatMap(a => unit(f(a)))
def map2[B,C](sb: State[S, B])(f: (A, B) => C): State[S, C] =
flatMap(a => sb.map(b => f(a, b)))
def flatMap[B](f: A => State[S, B]): State[S, B] = State(s => {
val (a, s1) = run(s)
f(a).run(s1)
})
}
object State {
def modify[S](f: S => S): State[S, Unit] = for {
s <- get // Gets the current state and assigns it to `s`.
_ <- set(f(s)) // Sets the new state to `f` applied to `s`.
} yield ()
def get[S]: State[S, S] = State(s => (s, s))
def set[S](s: S): State[S, Unit] = State(_ => ((), s))
def sequence[S,A](sas: List[State[S, A]]): State[S, List[A]] =
sas.foldRight(unit[S, List[A]](List()))((f, acc) => f.map2(acc)(_ :: _))
def unit[S, A](a: A): State[S, A] =
State(s => (a, s))
}
1. In the simulateMachine method, what are the first and second _? The first one looks like an ignored output; do the two stand for the same thing?
2. The get comes from State, so how does it get the state of the Machine?
3. yield outputs a tuple, so how does that match the return type State[Machine, (Int, Int)]?
4. How do I call this simulateMachine method? It looks like I need to initialise the Machine state somewhere.
I don't think this document gives me a proper understanding of what is happening under the hood. How can I understand it better?
Here's a slightly more verbose version of simulateMachine that should be a bit easier to understand.
def simulateMachine(inputs: List[Input]): State[Machine, (Int, Int)] = for {
_ <- State.sequence( // this _ means "discard this 'unpacked' value"
inputs.map(
(State.modify[Machine] _).compose(update) // this _ is an eta-expansion
)
)
s <- State.get
} yield (s.coins, s.candies)
"Unpacked" value: for { value <- container } ... basically "unpacks" value from container. The semantics of unpacking differ between containers. The most straightforward is the List or Seq semantics, where each item of the collection becomes value, and the for in this case becomes just an iteration. Other notable "containers" are:
Option - for semantics: value present or not + modifications of the value if present
Try - for semantics: successful computation or not + modifications of the successful result
Future - for semantics: defining a chain of computations to perform when the value becomes present
Practically, each monad defines the way to "unpack" and "iterate" (explaining what a monad is is hard :) ). Your State is practically a monad, even though it doesn't declare this explicitly. By the way, there is a canonical State monad in third-party libraries such as cats or scalaz.
Eta-expansion is explained in this SO answer, but in short it converts a function/method into a function object. Roughly equivalent to x => State.modify[Machine](x)
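A tiny standalone sketch of eta-expansion followed by compose, unrelated to State. The names double, parse and parseThenDouble are invented for the illustration.

```scala
object EtaExpansionSketch {
  // A method; it is not itself a function value.
  def double(n: Int): Int = n * 2

  // `double _` eta-expands the method into a function value Int => Int.
  val doubleFn: Int => Int = double _

  // Writing the lambda by hand gives an equivalent function.
  val doubleLambda: Int => Int = x => double(x)

  val parse: String => Int = _.trim.toInt

  // `f.compose(g)` builds x => f(g(x)), so `parse` runs first, then `double`.
  val parseThenDouble: String => Int = (double _).compose(parse)
}
```

In simulateMachine the same shape appears: update runs first and turns an Input into a Machine => Machine, and then modify turns that function into a State[Machine, Unit].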
Or a completely desugared version:
def simulateMachine(inputs: List[Input]): State[Machine, (Int, Int)] = {
val simulationStep: Input => State[Machine, Unit] = (State.modify[Machine] _).compose(update) // this _ is an eta-expansion
State.sequence(inputs.map(input => simulationStep(input)))
.flatMap(_ => State.get) // `_ => xyz` creates a function that does not care about its input
.map(machine => (machine.coins, machine.candies)) // same order as the `yield` in the original
}
Note that the _ in _ => State.get is yet another use of _: here it creates a function that doesn't care about its argument.
Now, answering questions as they are asked:
1. See above :)
2. If you look at the desugared version, you'll see that State.get is passed as a "callback" to flatMap. There are a few hoops to jump through to trace it (I'll leave it as an exercise for you :)), but essentially what it does is replace the "result" part of the state (A) with the state itself (S). The trick here is that it does not get the Machine at all; it just "chains" a call that copies the Machine from the S part of the state to the A part.
3. I think the desugared version should be pretty self-explanatory. Regarding the for-comprehension version: for fixes the "context" using the first call in the chain, so everything else (including yield) is executed in that "context"; specifically, it's fixed to State[_, _]. yield is essentially equivalent to calling map at the end of the chain, and in State[_, _] semantics, map replaces the A part but preserves the S part.
4. Run it with an initial Machine: Candy.simulateMachine(List(Coin, Turn, Coin, Turn)).run(Machine(false, 2, 0))
P.S. your current code still misses State.set method. I'm assuming it's just def set[S](s: => S): State[S, S] = State(_ => (s, s)) - at least adding this makes your snippet compile in my environment.

How to Remove Explicit Casting

How do I remove explicit casting asInstanceOf[XList[B]] in Cons(f(a), b).asInstanceOf[XList[B]] inside map function? Or perhaps redesign reduce and map functions altogether? Thanks
trait XList[+A]
case object Empty extends XList[Nothing]
case class Cons[A](x: A, xs: XList[A]) extends XList[A]
object XList {
def apply[A](as: A*):XList[A] = if (as.isEmpty) Empty else Cons(as.head, apply(as.tail: _*))
def empty[A]: XList[A] = Empty
}
def reduce[A, B](f: B => A => B)(b: B)(xs: XList[A]): B = xs match {
case Empty => b
case Cons(y, ys) => reduce(f)(f(b)(y))(ys)
}
def map[A, B](f: A => B)(xs: XList[A]): XList[B] = reduce((b: XList[B]) => (a: A) => Cons(f(a), b).asInstanceOf[XList[B]])(XList.empty[B])(xs)
You can merge two argument lists into one by replacing )( by ,:
def reduce[A, B](f: B => A => B, b: B)(xs: XList[A]): B = xs match {
case Empty => b
case Cons(y, ys) => reduce(f, f(b)(y))(ys)
}
def map[A, B](f: A => B)(xs: XList[A]): XList[B] =
reduce((b: XList[B]) => (a: A) => Cons(f(a), b), XList.empty[B])(xs)
This will force the type inference algorithm to consider both first arguments of reduce before making up its mind about what B is supposed to be.
You can either widen Cons to a XList[B] at the call site by providing the type parameters explicitly:
def map[A, B](f: A => B)(xs: XList[A]): XList[B] =
reduce[A, XList[B]]((b: XList[B]) => (a: A) => Cons(f(a), b))(XList.empty[B])(xs)
Or use type ascription:
def map[A, B](f: A => B)(xs: XList[A]): XList[B] =
reduce((b: XList[B]) => (a: A) => Cons(f(a), b): XList[B])(XList.empty[B])(xs)
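For reference, here is the type-ascription variant as a self-contained sketch that keeps the original curried reduce. One thing it makes visible: because reduce walks the list from the left and map prepends with Cons, the mapped list comes out in reverse order. The object name XListSketch and the value doubled are only for illustration.

```scala
object XListSketch {
  sealed trait XList[+A]
  case object Empty extends XList[Nothing]
  case class Cons[A](x: A, xs: XList[A]) extends XList[A]

  def reduce[A, B](f: B => A => B)(b: B)(xs: XList[A]): B = xs match {
    case Empty       => b
    case Cons(y, ys) => reduce(f)(f(b)(y))(ys)
  }

  // The ascription `: XList[B]` widens Cons[B] to XList[B], so the compiler
  // infers reduce's B as XList[B] and no cast is needed.
  def map[A, B](f: A => B)(xs: XList[A]): XList[B] =
    reduce((b: XList[B]) => (a: A) => (Cons(f(a), b): XList[B]))(Empty)(xs)

  // Input 1, 2, 3 yields 6, 4, 2 because each element is prepended.
  val doubled: XList[Int] =
    map((n: Int) => n * 2)(Cons(1, Cons(2, Cons(3, Empty))))
}
```

If the original order matters, one option is to reverse the result afterwards or to fold from the right instead.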
As a side note, reduce is traditionally stricter in its signature than what you've written. In Scala's collection library, reduce looks like this:
def reduce[B >: A](op: (B, B) => B): B
It implicitly requires a non-empty collection to begin with. What you've implemented is similar in structure to a foldLeft, which has this signature (from Scala's collection library):
def foldLeft[B](z: B)(op: (B, A) => B): B

Partition a sequence of disjunctions in Scalaz

What is the best way to partition Seq[A \/ B] into (Seq[A], Seq[B]) using Scalaz?
There is a method, separate, defined in MonadPlus. This typeclass is a combination of a Monad with PlusEmpty (a generalized Monoid). So you need to define the following instances:
1) MonadPlus[Seq]
implicit val seqmp = new MonadPlus[Seq] {
def plus[A](a: Seq[A], b: => Seq[A]): Seq[A] = a ++ b
def empty[A]: Seq[A] = Seq.empty[A]
def point[A](a: => A): Seq[A] = Seq(a)
def bind[A, B](fa: Seq[A])(f: (A) => Seq[B]): Seq[B] = fa.flatMap(f)
}
Seq is already monadic, so point and bind are easy; empty and plus are monoid operations, and Seq is a free monoid.
2) Bifoldable[\/]
implicit val bife = new Bifoldable[\/] {
def bifoldMap[A, B, M](fa: \/[A, B])(f: (A) => M)(g: (B) => M)(implicit F: Monoid[M]): M = fa match {
case \/-(r) => g(r)
case -\/(l) => f(l)
}
def bifoldRight[A, B, C](fa: \/[A, B], z: => C)(f: (A, => C) => C)(g: (B, => C) => C): C = fa match {
case \/-(r) => g(r, z)
case -\/(l) => f(l, z)
}
}
Also easy, standard folding, but for type constructors with two parameters.
Now you can use separate:
val seq: Seq[String \/ Int] = List(\/-(10), -\/("wrong"), \/-(22), \/-(1), -\/("exception"))
scala> seq.separate
res2: (Seq[String], Seq[Int]) = (List(wrong, exception),List(10, 22, 1))
Update
Thanks to Kenji Yoshida, there is a Bitraverse[\/], so you need only MonadPlus.
And a simple solution using foldLeft:
seq.foldLeft((Seq.empty[String], Seq.empty[Int])){ case ((as, ai), either) =>
either match {
case \/-(r) => (as, ai :+ r)
case -\/(l) => (as :+ l, ai)
}
}
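If you can work with the standard library's Either instead of Scalaz's \/, Scala 2.13 and later offer partitionMap, which does the same split in one call. This is a different technique from the Scalaz separate shown above, and the sketch below assumes the data has already been converted to Either. The names PartitionSketch, errors and numbers are made up.

```scala
object PartitionSketch {
  // Assumption: the values are standard-library Either, not Scalaz \/.
  val seq: Seq[Either[String, Int]] =
    Seq(Right(10), Left("wrong"), Right(22), Right(1), Left("exception"))

  // partitionMap sends Left values to the first collection and
  // Right values to the second, preserving their relative order.
  val (errors, numbers) = seq.partitionMap(identity)
}
```

Like the foldLeft version, this keeps the elements of each side in their original order.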