This is part 2 of a promised multi-part series on dynamic programming in Haskell. As a reminder, we’re using Zapis as a sample problem. In this problem, we are given a sequence of opening and closing brackets (parens, square brackets, and curly braces) with question marks, and have to compute the number of different ways in which the question marks could be replaced by brackets to create valid, properly nested bracket sequences.
Last time, we developed some code to efficiently solve this problem using a mutually recursive pair of a function and a lookup table represented by a lazy, immutable array. This solution is pretty good, but it leaves a few things to be desired:
- It requires defining both a function and a lazy, immutable array, and coming up with names for them.
- When defining the function, we have to remember to index into the array instead of calling the function recursively, and there is nothing that will warn us if we forget.
An impossible dream
Wouldn’t it be cool if we could just write the recursive function, and then have some generic machinery make it fast for us by automatically generating a memo table?
In other words, we’d like a magic memoization function, with a type something like this:
memo :: (i -> a) -> (i -> a)
Then we could just define our slow, recursive function normally, wave our magic memo wand over it, and get a fast version for free!
This sounds lovely, of course, but there are a few problems:
- Surely this magic memo function won't be able to work for any type i. Well, OK, we can add something like an Ix i constraint and/or extra arguments to make sure that values of type i can be used as (or converted to) array indices.
- How can memo possibly know how big of a table to allocate? One simple way to solve this would be to provide the table size as an extra explicit argument to memo. (In my next post we'll also explore some clever things we can do when we don't know in advance how big of a table we will need.)
- More fundamentally, though, our dream seems impossible: given a function i -> a, the only thing the memo function can do is call it on some input of type i; if the i -> a function is recursive then it will go off and do its recursive thing without ever consulting a memo table, defeating the entire purpose.
… or is it?
For now let's ignore the fact that our dream seems impossible and think about how we could write memo. The idea is to take the given (i -> a) function and first turn it into a lookup table storing a value of type a for each i; then return a new i -> a function which works by just doing a table lookup.
From my previous post we already have a function to create a table for a given function:
tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)
The inverse function, which turns an array back into a function, is just the array indexing operator, with extra parentheses around the i -> a to emphasize the shift in perspective:
(!) :: Ix i => Array i a -> (i -> a)
So we can define memo simply as the composition
memo :: Ix i => (i,i) -> (i -> a) -> (i -> a)
memo rng = (!) . tabulate rng
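As a quick sanity check of what memo does (and does not) promise, here is a small sketch. The function square and the names squares and squareMemo are illustrative only, not from the original code. Inside the given range, memo rng f agrees with f; outside it, the lookup simply fails with a runtime array-index error, since no table entry exists there.

```haskell
import Data.Array

-- The two definitions from the text.
tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)

memo :: Ix i => (i,i) -> (i -> a) -> (i -> a)
memo rng = (!) . tabulate rng

-- An illustrative, non-recursive function (not from the original post).
square :: Int -> Int
square x = x * x

-- tabulate stores square's value at every index in the range (0,5):
-- elems squares == [0,1,4,9,16,25]
squares :: Array Int Int
squares = tabulate (0,5) square

-- memo turns that table back into a function which agrees with square
-- on 0..5. Outside the range (e.g. squareMemo 6) the array lookup fails
-- with a runtime index error instead of computing a value.
squareMemo :: Int -> Int
squareMemo = memo (0,5) square
```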
This is nifty… but as we already saw, it doesn't help very much… right? For example, let's define a recursive (slow!) Fibonacci function, and apply memo to it:
{-# LANGUAGE LambdaCase #-}
fib :: Int -> Integer
fib = \case
  0 -> 0
  1 -> 1
  n -> fib (n-1) + fib (n-2)
fib' :: Int -> Integer
fib' = memo (0,1000) fib
As you can see from the following ghci session, calling, say, fib' 35 is still very slow the first time, since it simply calls fib 35, which does its usual exponential recursion. However, if we call fib' 35 a second time, we get the answer instantly:
λ> :set +s
λ> fib' 35
9227465
(4.18 secs, 3,822,432,984 bytes)
λ> fib' 35
9227465
(0.00 secs, 94,104 bytes)
This is better than nothing, but it's not really the point. We want it to be fast the first time by looking up intermediate results in the memo table. And trying to call fib' on bigger inputs is still going to be completely hopeless.
The punchline
All might seem hopeless at this point, but we actually have everything we need: all we have to do is stick the call to memo in the definition of fib itself!
fib :: Int -> Integer
fib = memo (0,1000) $ \case
  0 -> 0
  1 -> 1
  n -> fib (n-1) + fib (n-2)
Magically, fib is now fast:
λ> fib 35
9227465
(0.00 secs, 94,096 bytes)
λ> fib 1000
43466557686937456435688527675040625802564660517371780402481729089536555417949051890403879840079255169295922593080322634775209689623239873322471161642996440906533187938298969649928516003704476137795166849228875
(0.01 secs, 807,560 bytes)
This solves all our problems. We only have to write a single definition, which is a directly recursive function, so it's hard to mess it up. The only thing we have to change is to stick a call to memo (with an appropriate index range) on the front; the whole thing is elegant and short.
How does this even work, though? At first glance, it might seem like it will generate a new table with every recursive call to fib, which would obviously be a disaster. However, that's not what happens: there is only a single, top-level definition of fib, and it is defined as the function which looks up its input in a certain table. Because fib is defined without any arguments of its own (fib = memo ...), it is a single top-level value, so its table is allocated at most once, its entries are computed lazily on demand, and it is shared by every call. Every time we call fib we are calling that same, unique top-level function which is defined in terms of its (unique, top-level) table. So this ends up being equivalent to our previous solution, a mutually recursive pair of a function and a lookup table, but written in a much nicer, more compact way that doesn't require us to explicitly name the table.
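To make that equivalence concrete, here is a sketch of the same Fibonacci function with the table named explicitly, next to a deliberately bad variant for contrast. The names fibTable, fibBody, fibExplicit and fibRebuilt are illustrative, not from the original code, and plain equations are used instead of \case. In fibRebuilt, memo is applied underneath the argument n, so in principle every call builds a fresh table and the recursive calls never share results; GHC's full-laziness optimisation may float the table out when compiling with -O, but that is not something to rely on.

```haskell
import Data.Array

tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)

memo :: Ix i => (i,i) -> (i -> a) -> (i -> a)
memo rng = (!) . tabulate rng

-- Hand-desugared version of  fib = memo (0,1000) $ \case ...
-- One top-level table and one top-level function, each defined via the other.
fibTable :: Array Int Integer
fibTable = tabulate (0,1000) fibBody

fibExplicit :: Int -> Integer
fibExplicit = (fibTable !)

-- The recurrence itself; recursive calls go through fibExplicit,
-- i.e. through the single shared table.
fibBody :: Int -> Integer
fibBody 0 = 0
fibBody 1 = 1
fibBody n = fibExplicit (n-1) + fibExplicit (n-2)

-- Anti-pattern for contrast (hypothetical): the memo call sits under the
-- argument n, so each call may build a new table and nothing is shared.
-- Still correct, but potentially exponential, like the naive fib.
fibRebuilt :: Int -> Integer
fibRebuilt n = memo (0,1000) body n
  where
    body 0 = 0
    body 1 = 1
    body k = fibRebuilt (k-1) + fibRebuilt (k-2)
```

The only difference between fib and fibRebuilt is whether the memo call is shared across calls, which is exactly why writing fib as a top-level value matters.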
So here’s our final solution for Zapis. As you can see, the extra code we have to write in order to memoize our recurrence boils down to about five lines (two of which are type signatures and could be omitted). This is definitely a technique worth knowing!
{-# LANGUAGE LambdaCase #-}
import Control.Arrow
import Data.Array
main = interact $ lines >>> last >>> solve >>> format
format :: Integer -> String
format = show >>> reverse >>> take 5 >>> reverse
tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)
memo :: Ix i => (i,i) -> (i -> a) -> (i -> a)
memo rng = (!) . tabulate rng
solve :: String -> Integer
solve str = c (0,n)
  where
    n = length str
    s = listArray (0,n-1) str

    c :: (Int, Int) -> Integer
    c = memo ((0,0), (n,n)) $ \case
      (i,j)
        | i == j           -> 1
        | even i /= even j -> 0
        | otherwise        -> sum
            [ m (s!i) (s!k) * c (i+1,k) * c (k+1, j)
            | k <- [i+1, i+3 .. j-1]
            ]

    m '(' ')'                = 1
    m '[' ']'                = 1
    m '{' '}'                = 1
    m '?' '?'                = 3
    m b '?' | b `elem` "([{" = 1
    m '?' b | b `elem` ")]}" = 1
    m _ _                    = 0
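To gain some confidence in the recurrence (in particular the counts in m, such as 3 for a pair of question marks), one could compare solve against a brute-force count on short strings. The following is only a sketch; bruteForce and valid are hypothetical helpers, not part of the original solution. It tries every way of filling in the question marks, so it is exponential and only usable on small inputs.

```haskell
-- Hypothetical brute-force reference for Zapis (not from the original post):
-- try every way of replacing the '?' characters and count the results that
-- are properly nested. Exponential in the number of '?', so short inputs only.
bruteForce :: String -> Integer
bruteForce = fromIntegral . length . filter valid . mapM fill
  where
    -- each '?' may become any of the six bracket characters
    fill '?' = "()[]{}"
    fill c   = [c]

-- Stack-based check that a bracket string is properly nested.
valid :: String -> Bool
valid = go []
  where
    go stack (c:cs)
      | c `elem` "([{" = go (c:stack) cs   -- push an opening bracket
    go (o:stack) (c:cs)
      | matches o c    = go stack cs       -- pop a matching closer
    go stack []        = null stack        -- done: nothing left open
    go _ _             = False             -- mismatch or unmatched closer

    matches '(' ')' = True
    matches '[' ']' = True
    matches '{' '}' = True
    matches _   _   = False
```

For example, both bruteForce "????" and the memoized solve give 18: there are two nesting shapes, (..)(..) and ((..)), each with 3 × 3 choices of bracket type.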