You Could Have Invented The State Monad

Posted on 17 June 2016

I’m attempting NICTA/course a second time. I gave up the last time because none of the State exercises were making sense and I found myself leaning so heavily on the solutions that I wasn’t actually learning anything. This time I was much better prepared after watching lots of CanFPG talks, reading lots of blog posts and writing a little Haskell, and I easily cleared the State hurdle. In fact, I’m now going to demonstrate how you (yes, you) could have come up with it (with a little help).

The fundamental insight of state is that a stateful computation can be represented by a function that takes a state of type s and returns a pair of a result of type a and a new state of type s:

newtype State s a = State { runState :: s -> (a,s) }
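
To make the type less abstract, here is a minimal sketch of two values of it. The names tick and pop are my own illustrations and not part of the course:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- Hypothetical example: report the current counter and increment it.
tick :: State Int Int
tick = State (\n -> (n, n + 1))

-- Hypothetical example: pop the head of a list used as a stack.
pop :: State [Int] (Maybe Int)
pop = State (\xs -> case xs of
                      []       -> (Nothing, [])
                      (y : ys) -> (Just y, ys))

-- runState tick 5        == (5, 6)
-- runState pop [1, 2, 3] == (Just 1, [2, 3])
```

Nothing happens until runState feeds in an initial state. A State value is only a description of a step.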

Given such a type, what would its Functor instance look like?

instance Functor (State s) where
  (<$>) :: (a -> b) -> State s a -> State s b

Our implementation should be another State that takes a value s0, passes it to the second argument sa (resulting in (a, s1)) and calls the function fn on a:

  (<$>) fn (State sa) = State (\s0 -> let (a, s1) = sa s0 in (fn a, s1))

This is a State that takes s0 and returns (b, s1), which is exactly what we wanted.
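
If you are following along with plain GHC and base rather than the course's own classes, note that (<$>) in base is an ordinary function defined in terms of fmap, so the instance has to define fmap instead. A sketch under that assumption, with tick as a made-up example value:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  -- Same body as the (<$>) version, attached to fmap.
  fmap fn (State sa) = State (\s0 -> let (a, s1) = sa s0 in (fn a, s1))

-- Hypothetical example: report the current counter and increment it.
tick :: State Int Int
tick = State (\n -> (n, n + 1))

-- Mapping changes the result but leaves the state transition alone.
doubled :: State Int Int
doubled = (* 2) <$> tick

-- runState doubled 5 == (10, 6)
```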

Let’s look at the Applicative instance next:

instance Applicative (State s) where
  pure  :: a -> State s a
  (<*>) :: State s (a -> b) -> State s a -> State s b

The implementation for pure explains where the a in our State comes from. Given some a, return a State that, when fed a value s, results in (a,s). It practically writes itself.

  pure a = State (\s -> (a,s))

(<*>) is a bit trickier, because we’re dealing with both the State the function is in and the State its argument is in. The implementation should be a State that takes a value s0, feeds it to sa to get (fn, s1), feeds s1 to sb to get (a, s2), and calls fn on a:

  (<*>) (State sa) (State sb) =
    State (\s0 -> let (fn, s1) = sa s0
                      (a,  s2) = sb s1
                  in (fn a, s2))

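The hardest thing is remembering to thread the state from s0 through sa and then sb so that we don’t lose any state on the way. We can usually follow the types, but they don’t help in this specific case: s0, s1 and s2 all have type s, so a version that feeds s0 to both sa and sb type-checks just as happily.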
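
To see how little the types help, here is a sketch of the correct implementation next to a wrong one with the same signature. I have written them as plain functions (apGood and apBad, names of my own choosing) so that both can live in one file:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

-- Correct: s1 from the first computation is fed to the second.
apGood :: State s (a -> b) -> State s a -> State s b
apGood (State sa) (State sb) =
  State (\s0 -> let (fn, s1) = sa s0
                    (a,  s2) = sb s1
                in (fn a, s2))

-- Also type-checks, but feeds s0 to both and throws s1 away.
apBad :: State s (a -> b) -> State s a -> State s b
apBad (State sa) (State sb) =
  State (\s0 -> let (fn, _s1) = sa s0
                    (a,  s2)  = sb s0
                in (fn a, s2))

-- Hypothetical example values: a counter, and a counter that returns
-- a function waiting for a second number to pair with.
tick :: State Int Int
tick = State (\n -> (n, n + 1))

tickFn :: State Int (Int -> (Int, Int))
tickFn = State (\n -> ((,) n, n + 1))

-- runState (apGood tickFn tick) 0 == ((0, 1), 2)
-- runState (apBad  tickFn tick) 0 == ((0, 0), 1)
```

The wrong version silently loses the first increment, and the compiler has no way to object.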

Finally, let’s look at the Monad instance:

instance Monad (State s) where
  (>>=) :: State s a -> (a -> State s b) -> State s b

As with all our previous implementations, it has the form:

  (>>=) (State sa) fn = State (\s0 -> let ??? in ???)

We know that we need to feed s0 to sa to get an a to pass to fn:

  (>>=) (State sa) fn =
    State (\s0 -> let (a, s1) = sa s0
                      ???     = fn a
                  in ???)

The result of fn a is a State sb but we need to return a tuple of (b, s). We can obtain one by feeding s1 to sb:

  (>>=) (State sa) fn =
    State (\s0 -> let (a, s1)  = sa s0
                      State sb = fn a
                  in sb s1)

Success!
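
For anyone compiling against base instead of the course's classes, a Monad instance also needs the Functor and Applicative instances in scope. Here is a self-contained sketch under that assumption, using fmap in place of (<$>) as the class method. The names tick and twoTicks are illustrative only:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap fn (State sa) = State (\s0 -> let (a, s1) = sa s0 in (fn a, s1))

instance Applicative (State s) where
  pure a = State (\s -> (a, s))
  State sa <*> State sb =
    State (\s0 -> let (fn, s1) = sa s0
                      (a,  s2) = sb s1
                  in (fn a, s2))

instance Monad (State s) where
  State sa >>= fn =
    State (\s0 -> let (a, s1)  = sa s0
                      State sb = fn a
                  in sb s1)

-- Hypothetical example: report the current counter and increment it.
tick :: State Int Int
tick = State (\n -> (n, n + 1))

-- Each bind hands the updated state to the next step.
twoTicks :: State Int (Int, Int)
twoTicks = tick >>= \a -> tick >>= \b -> pure (a, b)

-- runState twoTicks 10 == ((10, 11), 12)
```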

Let’s define a few functions to make our lives easier. get returns a State that, when fed some s, returns (s,s). This exposes the current state as the result, so that later steps can read it and compute a new one:

get :: State s s
get = State (\s -> (s, s))

put takes a new state s and returns a State that ignores whatever state is passed to it later and installs s instead:

put :: s -> State s ()
put s = State (\_ -> ((),s))

Sometimes we want the s and not the a:

exec :: State s a -> s -> s
exec (State sa) s = snd $ sa s

At other times we want the a and not the s:

eval :: State s a -> s -> a
eval (State sa) s = fst $ sa s

With all this machinery in place, we can do this:

Prelude> exec (do i <- get; put (i+1); return ()) 0
1
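
A slightly larger sketch shows the difference between eval and exec. It assumes base's class layout (fmap as the Functor method), and fresh and threeLabels are hypothetical helpers of mine:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap fn (State sa) = State (\s0 -> let (a, s1) = sa s0 in (fn a, s1))

instance Applicative (State s) where
  pure a = State (\s -> (a, s))
  State sa <*> State sb =
    State (\s0 -> let (fn, s1) = sa s0
                      (a,  s2) = sb s1
                  in (fn a, s2))

instance Monad (State s) where
  State sa >>= fn =
    State (\s0 -> let (a, s1)  = sa s0
                      State sb = fn a
                  in sb s1)

get :: State s s
get = State (\s -> (s, s))

put :: s -> State s ()
put s = State (\_ -> ((), s))

exec :: State s a -> s -> s
exec (State sa) s = snd $ sa s

eval :: State s a -> s -> a
eval (State sa) s = fst $ sa s

-- Hypothetical helper: hand out the current counter value and bump it.
fresh :: State Int Int
fresh = do
  n <- get
  put (n + 1)
  pure n

-- Three labels in a row; the counter is threaded behind the scenes.
threeLabels :: State Int [Int]
threeLabels = do
  a <- fresh
  b <- fresh
  c <- fresh
  pure [a, b, c]

-- eval threeLabels 0 == [0, 1, 2]
-- exec threeLabels 0 == 3
```

eval keeps the labels that were handed out, while exec keeps the final counter.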

I still couldn’t believe that this worked the first time I tried it, so let’s desugar this:

    do i <- get; put (i+1); return ()
 == get >>= \i -> put (i+1) >>= \_ -> pure ()
 == State (\s -> (s, s))    >>= \i ->
    State (\_ -> ((), i+1)) >>= \_ ->
    State (\s -> ((), s))

Let’s simplify from the bottom up. By the definition of (>>=):

    (>>=) (State (\_ -> ((), i+1))) (\_ -> (State (\s -> ((), s)))) =
      State (\s0 -> let (a, s1)  = (\_ -> ((), i+1)) s0
                     -- (a, s1)  = ((), i+1)
                        State sb = (\_ -> (State (\s -> ((), s)))) a
                     --       sb = (\s -> ((), s))
                    in sb s1)
                     -- ((), i+1)
 == State (\s0 -> ((), i+1))
 == State (\_  -> ((), i+1))

Plugging that back in, we have:

State (\s -> (s,s)) >>= \i -> State (\_ -> ((), i+1))

Which we can simplify in the same way:

    (>>=) (State (\s -> (s,s))) (\i -> State (\_ -> ((), i+1))) =
      State (\s0 -> let (a, s1)  = (\s -> (s,s)) s0
                     -- (a, s1)  = (s0, s0)
                        State sb = (\i -> State (\_ -> ((), i+1))) a
                     --       sb = (\_ -> ((), s0+1))
                    in sb s1)
                     -- ((), s0+1)
 == State (\s0 -> ((), s0+1))
 == State (\i  -> ((), i+1))

Finally, we have:

    exec (State (\i -> ((), i+1))) 0
 == snd $ runState (State (\i -> ((), i+1))) 0
 == snd $ (\i -> ((), i+1)) 0
 == snd $ ((), 1)
 == 1
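
The hand derivation can also be spot-checked by running each stage. This is only a sketch: it assumes base's class layout, the names original, halfway, simplified and agree are mine, and agreeing on a handful of inputs is a sanity check, not a proof:

```haskell
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap fn (State sa) = State (\s0 -> let (a, s1) = sa s0 in (fn a, s1))

instance Applicative (State s) where
  pure a = State (\s -> (a, s))
  State sa <*> State sb =
    State (\s0 -> let (fn, s1) = sa s0
                      (a,  s2) = sb s1
                  in (fn a, s2))

instance Monad (State s) where
  State sa >>= fn =
    State (\s0 -> let (a, s1)  = sa s0
                      State sb = fn a
                  in sb s1)

get :: State s s
get = State (\s -> (s, s))

put :: s -> State s ()
put s = State (\_ -> ((), s))

-- The do block we started from.
original :: State Int ()
original = do
  i <- get
  put (i + 1)
  return ()

-- After simplifying the inner bind.
halfway :: State Int ()
halfway = State (\s -> (s, s)) >>= \i -> State (\_ -> ((), i + 1))

-- The fully simplified form.
simplified :: State Int ()
simplified = State (\i -> ((), i + 1))

-- True when all three stages agree on a few sample starting states.
agree :: Bool
agree = all same [-3 .. 3]
  where
    same s = runState original s == runState halfway s
          && runState halfway s  == runState simplified s
```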

This is my favourite thing about Haskell: the fact that it is built on abstractions that can be reasoned about in such a rigorous manner.

In fact, with some inspired renaming, you too could have invented the continuation monad.