2012-03-09 21:07 UTC

A quick disclaimer: I’m still learning Haskell so I may have misunderstood some things or described them incorrectly. I wrote this in one sitting and there are probably a few typos in the text and maybe even some technical errors. Let me know if you find any. Now on with the show.

I found the state monad difficult to understand at first and from my searches it seems that I’m not the only one. On a superficial level it is conceptually simple: pass around the state to keep your functions pure. If the state itself is passed to a function then there are no side effects because each input will be mapped to a single output and we can trust our function to work as expected. Simple, right?
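As a minimal sketch of what “passing the state around” looks like without any monad at all, here is a counter threaded through two calls by hand. The names `freshLabel` and `twoLabels` are made up for this illustration and don’t come from the wikibook.

```haskell
-- A pure function that takes the current counter (the "state") and returns
-- a result together with the next counter. Hypothetical example names.
freshLabel :: Int -> (String, Int)
freshLabel n = ("x" ++ show n, n + 1)

-- Thread the counter through two calls explicitly: each call receives the
-- state returned by the previous one.
twoLabels :: Int -> ((String, String), Int)
twoLabels n0 =
  let (a, n1) = freshLabel n0
      (b, n2) = freshLabel n1
  in  ((a, b), n2)
```

In ghci, `twoLabels 0` gives `(("x0","x1"),2)`. The state monad is essentially a way to hide this `n0`, `n1`, `n2` plumbing.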

The difficulty lay in understanding what was really going on under the surface of the state monad. Basically, this section is for anyone who keeps looking at something like this:

rollDie :: State StdGen Int
rollDie = do
  generator <- get
  let ( value, newGenerator ) = randomR (1,6) generator
  put newGenerator
  return value

and wondering how the `get` and `put` just seem to magically pull state out of thin air, alter it, then push it back before returning a value that appears to be unrelated to this pushing and pulling. I’ll run through the key elements to understanding what’s actually happening then I’ll try to put them all together in a way that makes everything clear.

The code in this section is taken from the Haskell wikibook’s state monad page. Note that I’ve changed the names in the definitions of `evalState` and `execState` to make it clearer that they accept a state as an argument. I’ve also formatted this code so that it will compile if you want to play around with it.

import System.Random
import Control.Monad (ap, liftM)

newtype State state result = State { runState :: state -> (result, state) }

-- GHC 7.10 and later require Functor and Applicative instances for every Monad.
instance Functor (State state_type) where
  fmap = liftM

instance Applicative (State state_type) where
  pure = return
  (<*>) = ap

instance Monad (State state_type) where
  --return :: result -> State state result
  return r = State ( \s -> (r, s) )

  --(>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
  processor >>= processorGenerator = State $ \state ->
    let (result, state') = runState processor state
    in runState (processorGenerator result) state'

put newState = State $ \_ -> ((), newState)

get = State $ \state -> (state, state)

evalState stateMonad state = fst ( runState stateMonad state )

execState stateMonad state = snd ( runState stateMonad state )

type GeneratorState = State StdGen

rollDie :: GeneratorState Int
rollDie = do
  generator <- get
  let ( value, newGenerator ) = randomR (1,6) generator
  put newGenerator
  return value

--evalState rollDie (mkStdGen 0)

The state monad doesn’t actually hold a state. It holds something that can transform a state. Let’s look at the type:

newtype State state result = State { runState :: state -> (result, state) }

The record syntax adds some sugar that gets in the way, so let’s rewrite the above without it:

newtype State state result = State ( state -> (result, state) )

runState :: (State state result) -> (state -> (result,state))
runState (State f) = f

So, the state monad contains something that takes a state and returns a tuple consisting of a result and a state. `runState` is just an accessor function to get the function inside the state monad. We need this because the function is wrapped in the `State` constructor and we don’t want to have to pattern-match against it all the time.

Actually, it’s probably not even right to say that it “contains” a function. It really *is* a function, but it’s wrapped in a constructor to give it a type that we can work with.

So, what does the state transformer function inside of the state monad do? It takes a state and returns a tuple containing a result and a new state. That’s all.
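Here is a small sketch of a state monad value built directly from a function, to show that it really is nothing more than a wrapped state transformer. `tick` is a made-up example and isn’t part of the wikibook code.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

-- Hypothetical example: the result is the current counter and the new state
-- is the counter plus one.
tick :: State Int Int
tick = State (\n -> (n, n + 1))
```

Nothing happens until we unwrap the function and hand it a state: `runState tick 5` gives `(5,6)`.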

`do` notation is syntactic sugar and it makes it easier for people from an imperative programming background to get started with monads, but it also obfuscates the functional nature of the code. It tricks you into thinking about the code as though it were imperative and thus it appears that some “assignments” are completely unrelated to others. I believe that this can impede the learning of Haskell.

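Let’s look at the section of code that confuses people:

rollDie :: State StdGen Int
rollDie = do
  generator <- get
  let ( value, newGenerator ) = randomR (1,6) generator
  put newGenerator
  return value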

With imperative eyes this looks simple. The type of `rollDie` is a state monad and walking through it we see that it gets a state, passes it to `randomR`, gets a random value along with a new state, puts the state back wherever it got it, then returns the value.

OK, so where did the state come from? How did `get` magically pull the state out of thin air and how did `put` put it back? The function is pure. It isn’t changing something outside of itself, and it doesn’t accept any arguments. What’s it actually doing?

Let’s look at the get and put functions again:

put newState = State $ \_ -> ((), newState)

get = State $ \state -> (state, state)

`get` returns a state monad, but looking at the code we see that it takes no arguments and always returns the same thing. It’s not returning some specific state that’s tracking what we’re doing. It’s just a generic state monad with a function that accepts a state and returns a 2-tuple with the unaltered state in both positions. Again, it takes no arguments and always returns the same thing. (Well, not strictly… the state type variable gets instantiated to whatever type is needed, but that’s irrelevant here. Read up about type variables and polymorphic types if you want to know more.)

`put` takes a state as an argument and returns a state monad. The function inside the state monad, as we saw above in the ‘The “State Monad” Is A State Transformer’ section, is a function that accepts a state and returns a tuple containing a result and a new state. Looking at the definition of `put`, we see that this state transforming function ignores its argument (a state) and returns a tuple with an empty result and the state that we passed to `put`.

It’s not “storing” a state anywhere. It’s actually getting a state monad that can transform all states (of the same type) to a tuple with an empty result and the state that we passed to put.

Be careful not to confuse `put` with the state transforming function inside the state monad that it returns. `put` is a function that accepts an argument and returns a state monad. The state monad that it returns contains a function that accepts a state, ignores it, and returns a tuple containing the state that we originally passed to `put`.
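To make this concrete, here is a self-contained sketch that runs `get` and `put` on their own. The definitions are the ones from the text with type signatures added; the `getExample`, `putExample` and `putExample'` names are only for this illustration.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

put :: state -> State state ()
put newState = State $ \_ -> ((), newState)

get :: State state state
get = State $ \state -> (state, state)

-- `get` copies whatever state it is eventually given into the result slot.
getExample :: (Int, Int)
getExample = runState get 7          -- (7,7)

-- `put 3` ignores the incoming state entirely and always yields 3 as the
-- new state, with () as the result.
putExample, putExample' :: ((), Int)
putExample  = runState (put 3) 7     -- ((),3)
putExample' = runState (put 3) 100   -- ((),3)
```

Neither function touches any state until `runState` unwraps it and we supply one ourselves.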

So how does that preserve state and pass it around outside of the function? What is going on here?

It’s time to get rid of some syntactic sugar by removing the `do` notation. This is what the compiler does behind the scenes when it compiles our code. Transforming the notation, the `rollDie` function becomes:

rollDie :: State StdGen Int
rollDie = get >>= \generator ->
  let ( value, newGenerator ) = randomR (1,6) generator
  in put newGenerator >> return value

If the transformation isn’t clear, read up on `do` notation.

`(>>=)`

Now that we have the function expressed in terms of `(>>=)` and `(>>)` (the “bind” and “then” operators, respectively), we need to look back at how `(>>=)` is defined for state monads (you should read up on these operators if you are not familiar with them, but if that’s the case then you might want to come back to state monads later). As we’ve just seen above, `get` is returning a state monad (not a state), so the first argument to the first `(>>=)` is indeed a state monad. Let’s change the names in the definition above to make it a bit clearer too:

(>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
stateTransformer >>= stateTransformerGenerator = State $ \state ->
  let (result, state') = runState stateTransformer state
  in runState (stateTransformerGenerator result) state'

Let’s run through this and make some sense of it. The first argument is a state monad and the second argument is a function that takes a result and returns a state monad that might have a different result type. `(>>=)` takes these two arguments and returns a new state monad.

The definition begins with `State`, which is the constructor of the state monad. This just ensures that the following lambda expression, which must be a state transforming function, will be wrapped up in a state monad.

In the second line, `runState stateTransformer` pulls out the state transforming function from the first argument to `(>>=)`, namely `stateTransformer`. This is the function that accepts a state and returns a tuple with a result and a new state. The returned function is applied to `state` and returns the tuple `(result, state')`, where `state'` is a new state.

Before continuing, remind yourself where `state` is coming from. Remember that we are defining a state transforming function that accepts a state and returns a tuple, which we are then wrapping in the state monad. `state` is therefore a state that will be passed to this function later. It’s saying “When you’re passed a state, take that state and pass it to the state transformer in the state monad of the first argument, then call the result and new state in the tuple `result` and `state'`, respectively.” You can think of it as a set of instructions of what to do with a state if it’s passed. They’re not actually being done yet.

In the final line, we take the result from above (`result`) and we feed it into the second argument of `(>>=)`, namely `stateTransformerGenerator`. This returns a new state monad and we then pull out the state transforming function in it with our accessor, `runState`. We now have a function that accepts a state and returns a tuple. We then pass this function the new state generated in the previous line (`state'`) and get the tuple that we’re expecting to complete the definition of a state transformer started by the lambda expression on the first line (`\state ->`).

So, what do we have at the end? We have a state transformer that takes a state, runs it through one state transformer to get a new state and a result, then uses the result to get another state transformer from our “state transformer generator”, and then runs the new state through that too to get a third state and a new result. We’re threading the state through a succession of state transformers (as we would thread a string through beads) and using the result of each transformation to determine the next state transformer.
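The threading can be seen with a small self-contained sketch. To keep it short I’ve written bind as an ordinary function called `bindState` (a made-up name) so that no `Monad` instance is needed; its body is the same as the definition above. `tick` and `addToState` are hypothetical transformers over an `Int` counter.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

-- Same body as (>>=) for state monads, as a plain function.
bindState :: State s a -> (a -> State s b) -> State s b
bindState stateTransformer stateTransformerGenerator = State $ \state ->
  let (result, state') = runState stateTransformer state
  in  runState (stateTransformerGenerator result) state'

-- Result: the current counter. New state: counter + 1.
tick :: State Int Int
tick = State (\n -> (n, n + 1))

-- A "state transformer generator": given a result k, build a transformer
-- that adds k to the state and reports what it did.
addToState :: Int -> State Int String
addToState k = State (\n -> ("added " ++ show k, n + k))

combined :: State Int String
combined = tick `bindState` addToState
```

`runState combined 10` gives `("added 10",21)`: `tick` turns 10 into the result 10 and the state 11, then `addToState 10` is run on the state 11.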

`get` and `put`

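Now we need to fit this back into our definition of `rollDie` to see what it’s doing. Let’s look at the definition again (no need to scroll back):

rollDie :: State StdGen Int
rollDie = get >>= \generator ->
  let ( value, newGenerator ) = randomR (1,6) generator
  in put newGenerator >> return value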

The rollDie function will return a state monad that can take a “random number generator state” and return an integer between 1 and 6, inclusive, along with a new “random number generator state”, which can then be threaded through rollDie again to get another number and another new state.

Let’s replace `generator` and `newGenerator`, which are the “random number generator states”, with `gState` and `gState'`. Let’s also replace `get` and `put newGenerator` with what they return.

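-- Parentheses replace ($) here, otherwise each lambda body would swallow
-- everything to its right, including the (>>=) and (>>).
rollDie :: State StdGen Int
rollDie = State (\state -> (state, state)) >>= \gState ->
  let ( value, gState' ) = randomR (1,6) gState
  in State (\_ -> ((), gState')) >> return value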

Make sure that you understand the replacements by looking back at the definitions.

`(>>)`

There is still one last thing that we need to know before we can put all of this together. What does `(>>)` (then) do in this context? You’re probably familiar with it from the IO monad, which lets you do something like this:

(putStrLn "foo") >> (putStrLn "bar") >> (putStrLn "baz")

or in the more familiar `do` notation:

do
  putStrLn "foo"
  putStrLn "bar"
  putStrLn "baz"

With IO, this just prints “foo”, then it prints “bar”, then it prints “baz”, each on a new line. `(>>)` seems to just sequence the functions and there doesn’t seem to be any connection between them, but this is just a consequence of the definition of `(>>)` for IO monads. Actually, it’s a consequence of the definition of `(>>=)`.

`(>>)` is actually defined in terms of `(>>=)`:

m >> n = m >>= \_ -> n
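For state monads this means that the state is still threaded from `m` to `n` even though the result of `m` is thrown away. Here is a minimal sketch of that, again using plain functions (`bindState` and `thenState` are made-up names standing in for `(>>=)` and `(>>)`) so that it is self-contained.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

-- Stand-in for (>>=) on state monads, same body as in the text.
bindState :: State s a -> (a -> State s b) -> State s b
bindState m f = State $ \state ->
  let (result, state') = runState m state
  in  runState (f result) state'

-- Stand-in for (>>): ignore the result of m, but keep the state it produced.
thenState :: State s a -> State s b -> State s b
thenState m n = m `bindState` \_ -> n

put :: state -> State state ()
put newState = State $ \_ -> ((), newState)

get :: State state state
get = State $ \state -> (state, state)

-- put 1 discards the incoming 99; get then sees the state 1, not 99.
example :: (Int, Int)
example = runState (put 1 `thenState` get) 99
```

`example` evaluates to `(1,1)`. The `()` result of `put 1` is ignored, but the new state that it produced is what `get` receives.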

Let’s replace `(>>)` in our `rollDie` function with this definition and take another look at it.

-- Parentheses replace ($) here so that each lambda ends where intended.
rollDie :: State StdGen Int
rollDie = State (\state -> (state, state)) >>= \gState ->
  let ( value, gState' ) = randomR (1,6) gState
  in State (\_ -> ((), gState')) >>= \_ -> return value

So far, so good.

OK, we’re almost there. It’s time to wrap our heads around this. Here are the definitions of `return` and `(>>=)` again so you don’t have to scroll back:

return r = State ( \s -> (r, s) )

(>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
stateTransformer >>= stateTransformerGenerator = State $ \state ->
  let (result, state') = runState stateTransformer state
  in runState (stateTransformerGenerator result) state'

Looking at the definition of `rollDie` above in terms of `(>>=)`, we see that it’s defined as `x >>= y >>= z`, where

x = State $ \state -> (state, state)

y = \gState ->
  let ( value, gState' ) = randomR (1,6) gState
  in State $ \_ -> ((), gState')

z = \_ -> return value

The monad laws let us break this up as either `(x >>= y) >>= z` or `x >>= (\w -> y w >>= z)`. To avoid confusion about how `value` is getting passed around, we’ll use the latter. The lambda must be a “state transforming function generator”, i.e. something with the type `(result_a -> State state result_b)`, so that it can be passed to `x >>=`. We can rewrite it as `\w -> (y w >>= z)` to make it clear how we should apply the bind operator inside the lambda function.

\w -> (y w >>= z) = \w -> State $ \state ->
  let (result, state') = runState (y w) state
  in runState (z result) state'

The thing to note here is that the first argument to the bind operator is `y w`, not `y`. `y` takes a state and returns a state monad, and the first argument of the bind operator must be a state monad, not a function that generates a state monad from a state (that’s what the second argument to the bind operator is, and that is why we could evaluate `x >>= y` first).

Let’s step through some evaluations:

y = \gState ->
  let ( value, gState' ) = randomR (1,6) gState
  in State $ \_ -> ((), gState')

y w =
  let ( value, gState' ) = randomR (1,6) w
  in State $ \_ -> ((), gState')

runState (y w) =
  let ( value, gState' ) = randomR (1,6) w
  in \_ -> ((), gState')

(runState (y w)) state =
  let ( value, gState' ) = randomR (1,6) w
  in ((), gState')

`runState` pulls the state transformer function out of the state monad, which is equivalent to removing the constructor from the definition. Passing `state` to this function returns a tuple. Because the function was defined as `\_ -> ((), gState')`, it completely ignores the state that we pass it, but it still returns the tuple `((),gState')`. Let’s put this back into our definition above:

\w -> (y w >>= z) = \w -> State $ \state ->
  let ( value, gState' ) = randomR (1,6) w
      (result, state') = ((),gState')
  in runState (z result) state'

Now we step through evaluations starting with `z` by first substituting the definition of `return`:

z = \_ -> return value
z = \_ -> State ( \s -> (value, s) )
z result = State ( \s -> (value, s) )
runState (z result) = \s -> (value, s)
(runState (z result)) state' = (value, state')

`z` is a function that ignores its argument and returns a state monad. That state monad contains a function that accepts a state and returns a tuple with a fixed value and an unchanged state. As before, `runState` unwraps the monad to get the function inside it (that accepts a state and returns a tuple). We then pass `state'` to that function to get the tuple, which is the fixed `value` and `state'`, unchanged.

Now we substitute back into our previous expression.

\w -> (y w >>= z) = \w -> State $ \state ->
  let ( value, gState' ) = randomR (1,6) w
      (result, state') = ((),gState')
  in (value, state')

which we can simplify to

\w -> (y w >>= z) = \w -> State $ \_ -> randomR (1,6) w

because `state` and `result` are ignored in the definition, and `gState'` is only used in the definition of `state'`. As before, work through it to make sure you understand the transformation. Now we substitute into `x >>= (\w -> y w >>= z)`, again using the definition of `(>>=)`:

x >>= (\w -> y w >>= z) = State $ \state ->
  let (result, state') = runState x state
  in runState ((\w -> y w >>= z) result) state'

and step through some more evaluations:

\w -> y w >>= z = (\w -> State $ \_ -> randomR (1,6) w)
(\w -> y w >>= z) result = State $ \_ -> randomR (1,6) result
runState ((\w -> y w >>= z) result) = \_ -> randomR (1,6) result
(runState ((\w -> y w >>= z) result)) state' = randomR (1,6) result

and thus:

x >>= (\w -> y w >>= z) = State $ \state ->
  let (result, state') = runState x state
  in randomR (1,6) result

Now we just need to run through the evaluations starting with `x` (using `s` instead of `state` in the lambda expression from above to avoid confusion with `state` here).

x = State $ \s -> (s, s)
runState x = \s -> (s, s)
(runState x) state = (state, state)

and finally we substitute for `x` and then simplify

x >>= (\w -> y w >>= z) = State $ \state ->
  let (result, state') = (state,state)
  in randomR (1,6) result

x >>= (\w -> y w >>= z) = State $ \state -> randomR (1,6) state

Both `result` and `state'` are equal to `state`, so we replace them then remove the tuple assignment, which is redundant. Now remember that `x >>= (\w -> y w >>= z)` is equal to `x >>= y >>= z`, which is just our definition of `rollDie`.

In the end, we’re left with

rollDie = State $ \state -> randomR (1,6) state

Let’s check what the type of `randomR (1,6) state` will be:

ghci> :m System.Random
ghci> :t randomR
randomR :: (Random a, RandomGen g) => (a, a) -> g -> (a, g)

In our case, `a` is `Int` and `g` is `StdGen`, our random generator state, so `randomR (1,6) state` will return a tuple containing an `Int` and a new random generator state. So our function accepts a state and returns a tuple containing a result and a new state, i.e. a state transformer.

`State` wraps this state transformer function into a state monad and so our `rollDie` function returns a state monad, as it’s supposed to.
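If you want to convince yourself that the `do` version and the simplified version really are the same thing, here is a self-contained sketch that compares them. To avoid depending on the `random` package I’ve swapped in a toy generator, `toyRoll`, which is NOT `randomR`: it is a made-up stand-in with the same shape (state in, value and new state out), and the generator state is just an `Int`. The names `rollDieDo`, `rollDieDirect` and `agree` are also only for this illustration.

```haskell
import Control.Monad (ap, liftM)

newtype State state result = State { runState :: state -> (result, state) }

instance Functor (State s) where
  fmap = liftM

instance Applicative (State s) where
  pure r = State (\s -> (r, s))
  (<*>) = ap

instance Monad (State s) where
  processor >>= processorGenerator = State $ \state ->
    let (result, state') = runState processor state
    in  runState (processorGenerator result) state'

put :: s -> State s ()
put newState = State $ \_ -> ((), newState)

get :: State s s
get = State $ \state -> (state, state)

-- Toy stand-in for `randomR (1,6)` (an assumption, not the real generator):
-- returns a value between 1 and 6 and a new Int "generator state".
toyRoll :: Int -> (Int, Int)
toyRoll g = (g `mod` 6 + 1, (g * 75 + 74) `mod` 65537)

-- The do-notation version, as in the text but using the toy generator.
rollDieDo :: State Int Int
rollDieDo = do
  generator <- get
  let (value, newGenerator) = toyRoll generator
  put newGenerator
  return value

-- The fully simplified version that the derivation arrives at.
rollDieDirect :: State Int Int
rollDieDirect = State $ \state -> toyRoll state

-- Compare the two on a range of starting states.
agree :: Bool
agree = all (\g -> runState rollDieDo g == runState rollDieDirect g) [0 .. 1000]
```

`agree` evaluates to `True`, which is what the step-by-step substitution above predicts.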

So, what can we do with it?

evalState rollDie (mkStdGen 0)

`mkStdGen 0` returns a new random generator state (0 is the seed). `evalState`, which we saw way up above, takes a state monad and a state, pulls out the state transforming function from the state monad, applies it to the state to get the result and new state tuple, then pulls out the result from the tuple and returns it, which gives us an integer between 1 and 6, inclusive.
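As a quick sketch of the difference between `evalState` and `execState`, here they are applied to a simple counter. The two functions are defined as in the code at the top (with type signatures added); `tick` is a made-up example.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

-- Keep only the result from the (result, state) tuple.
evalState :: State state result -> state -> result
evalState stateMonad state = fst ( runState stateMonad state )

-- Keep only the new state from the (result, state) tuple.
execState :: State state result -> state -> state
execState stateMonad state = snd ( runState stateMonad state )

-- Hypothetical example: result is the current counter, new state is counter + 1.
tick :: State Int Int
tick = State (\n -> (n, n + 1))
```

`evalState tick 5` gives `5` (the result) and `execState tick 5` gives `6` (the new state).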

Yeah, that isn’t all that exciting. Don’t worry though, we didn’t go through all of that just to generate a single digit then call it a day and head home.

Going over the Haskell wikibook’s state monad page again, we find a definition for `rollDice`, which is a state monad whose result is a pair of integers instead of a single integer (i.e. a simulation of rolling 2 dice):

rollDice :: State StdGen (Int, Int)
rollDice = liftM2 (,) rollDie rollDie

If you’re wondering what `liftM2` is, it’s just a convenience function defined in `Control.Monad` that leads to more obfuscation in this case, so let’s rewrite `rollDice` without it:

rollDice :: State StdGen (Int, Int)
rollDice = rollDie >>= \result_1 ->
  rollDie >>= \result_2 ->
  return (result_1, result_2)

Look back at the definition of `(>>=)` and `return` to understand how this works. Here’s the same thing in `do` notation:

rollDice' :: State StdGen (Int, Int)
rollDice' = do
  result_1 <- rollDie
  result_2 <- rollDie
  return (result_1, result_2)

It appears deviously simple, doesn’t it? Let’s roll some dice:

ghci> evalState rollDice (mkStdGen 423752)
(5,1)

Note that `rollDice'` would give the same result with the same seed (423752). It really is the same function.
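To see that the two rolls really do use different generator states, here is a self-contained sketch that compares the monadic version with the same thing threaded by hand. As before, `toyRoll` is a made-up stand-in for `randomR (1,6)` with an `Int` as the generator state (an assumption for the sake of a dependency-free example), and `toyRollDie`, `toyRollDice` and `toyRollDiceByHand` are hypothetical names.

```haskell
import Control.Monad (ap, liftM)

newtype State state result = State { runState :: state -> (result, state) }

instance Functor (State s) where
  fmap = liftM

instance Applicative (State s) where
  pure r = State (\s -> (r, s))
  (<*>) = ap

instance Monad (State s) where
  processor >>= processorGenerator = State $ \state ->
    let (result, state') = runState processor state
    in  runState (processorGenerator result) state'

-- Toy stand-in for `randomR (1,6)`, not the real generator.
toyRoll :: Int -> (Int, Int)
toyRoll g = (g `mod` 6 + 1, (g * 75 + 74) `mod` 65537)

toyRollDie :: State Int Int
toyRollDie = State toyRoll

-- Two rolls in the state monad: the state is threaded by (>>=).
toyRollDice :: State Int (Int, Int)
toyRollDice = do
  result_1 <- toyRollDie
  result_2 <- toyRollDie
  return (result_1, result_2)

-- The same two rolls with the generator state passed along explicitly.
toyRollDiceByHand :: Int -> ((Int, Int), Int)
toyRollDiceByHand g0 =
  let (result_1, g1) = toyRoll g0
      (result_2, g2) = toyRoll g1
  in  ((result_1, result_2), g2)
```

`runState toyRollDice 0` and `toyRollDiceByHand 0` both give `((1,3),5624)`. The second roll sees the state left behind by the first, which is why the two dice can differ.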

Finally, here’s the working code again with both definitions of `rollDice` and two `roll` functions. The `roll` functions accept a seed (`Int`) for the random generator and return a pair of `Int`s generated using `randomR`.

import System.Random
import Control.Monad (ap, liftM)

newtype State state result = State { runState :: state -> (result, state) }

-- GHC 7.10 and later require Functor and Applicative instances for every Monad.
instance Functor (State state_type) where
  fmap = liftM

instance Applicative (State state_type) where
  pure = return
  (<*>) = ap

instance Monad (State state_type) where
  --return :: result -> State state result
  return r = State ( \s -> (r, s) )

  --(>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
  processor >>= processorGenerator = State $ \state ->
    let (result, state') = runState processor state
    in runState (processorGenerator result) state'

put newState = State $ \_ -> ((), newState)

get = State $ \state -> (state, state)

evalState stateMonad state = fst ( runState stateMonad state )

execState stateMonad state = snd ( runState stateMonad state )

type GeneratorState = State StdGen

rollDie :: GeneratorState Int
rollDie = do
  generator <- get
  let ( value, newGenerator ) = randomR (1,6) generator
  put newGenerator
  return value

--evalState rollDie (mkStdGen 0)

rollDice :: State StdGen (Int, Int)
rollDice = rollDie >>= \result_1 ->
  rollDie >>= \result_2 ->
  return (result_1, result_2)

rollDice' :: State StdGen (Int, Int)
rollDice' = do
  result_1 <- rollDie
  result_2 <- rollDie
  return (result_1, result_2)

roll :: Int -> (Int,Int)
roll = \x -> evalState rollDice (mkStdGen x)

roll' :: Int -> (Int,Int)
roll' = \x -> evalState rollDice' (mkStdGen x)

I hope this helps you to understand the state monad. Feel free to contact me if you would like to give feedback such as suggestions for improvements or corrections (both technical and orthographical), or just to let me know if you found it useful.

- Contact
- echo xyne.archlinux.ca | sed 's/\./@/'