Dynamic programming in Haskell: automatic memoization

This is part 2 of a promised multi-part series on dynamic programming in Haskell. As a reminder, we’re using Zapis as a sample problem. In this problem, we are given a sequence of opening and closing brackets (parens, square brackets, and curly braces) with question marks, and have to compute the number of different ways in which the question marks could be replaced by brackets to create valid, properly nested bracket sequences.

Last time, we developed some code to efficiently solve this problem using a mutually recursive pair of a function and a lookup table represented by a lazy, immutable array. This solution is pretty good, but it leaves a few things to be desired:

  • It requires defining both a function and a lazy, immutable array, and coming up with names for them.
  • When defining the function, we have to remember to index into the array instead of calling the function recursively, and there is nothing that will warn us if we forget.

An impossible dream

Wouldn’t it be cool if we could just write the recursive function, and then have some generic machinery make it fast for us by automatically generating a memo table?

In other words, we’d like a magic memoization function, with a type something like this:

memo :: (i -> a) -> (i -> a)

Then we could just define our slow, recursive function normally, wave our magic memo wand over it, and get a fast version for free!

This sounds lovely, of course, but there are a few problems:

  • Surely this magic memo function won’t be able to work for any type i. Well, OK, we can add something like an Ix i constraint and/or extra arguments to make sure that values of type i can be used as (or converted to) array indices.

  • How can memo possibly know how big of a table to allocate? One simple way to solve this would be to provide the table size as an extra explicit argument to memo. (In my next post we’ll also explore some clever things we can do when we don’t know in advance how big of a table we will need.)

  • More fundamentally, though, our dream seems impossible: given a function i -> a, the only thing the memo function can do is call it on some input of type i; if the i -> a function is recursive then it will go off and do its recursive thing without ever consulting a memo table, defeating the entire purpose.

… or is it?

For now let’s ignore the fact that our dream seems impossible and think about how we could write memo. The idea is to take the given (i -> a) function and first turn it into a lookup table storing a value of type a for each i; then return a new i -> a function which works by just doing a table lookup.

From my previous post we already have a function to create a table for a given function:

tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)

The inverse function, which turns an array back into a function, is just the array indexing operator, with extra parentheses around the i -> a to emphasize the shift in perspective:

(!) :: Ix i => Array i a -> (i -> a)

So we can define memo simply as the composition

memo :: Ix i => (i,i) -> (i -> a) -> (i -> a)
memo rng = (!) . tabulate rng

This is nifty… but as we already saw, it doesn’t help very much… right? For example, let’s define a recursive (slow!) Fibonacci function, and apply memo to it:

{-# LANGUAGE LambdaCase #-}

fib :: Int -> Integer
fib = \case
  0 -> 0
  1 -> 1
  n -> fib (n-1) + fib (n-2)

fib' :: Int -> Integer
fib' = memo (0,1000) fib

As you can see from the following ghci session, calling, say, fib' 35 is still very slow the first time, since it simply calls fib 35 which does its usual exponential recursion. However, if we call fib' 35 a second time, we get the answer instantly:

λ> :set +s
λ> fib' 35
9227465
(4.18 secs, 3,822,432,984 bytes)
λ> fib' 35
9227465
(0.00 secs, 94,104 bytes)

This is better than nothing, but it’s not really the point. We want it to be fast the first time by looking up intermediate results in the memo table. And trying to call fib' on bigger inputs is still going to be completely hopeless.

The punchline

All might seem hopeless at this point, but we actually have everything we need—all we have to do is just stick the call to memo in the definition of fib itself!

fib :: Int -> Integer
fib = memo (0,1000) $ \case
  0 -> 0
  1 -> 1
  n -> fib (n-1) + fib (n-2)

Magically, fib is now fast:

λ> fib 35
9227465
(0.00 secs, 94,096 bytes)
λ> fib 1000
43466557686937456435688527675040625802564660517371780402481729089536555417949051890403879840079255169295922593080322634775209689623239873322471161642996440906533187938298969649928516003704476137795166849228875
(0.01 secs, 807,560 bytes)

This solves all our problems. We only have to write a single definition, which is a directly recursive function, so it’s hard to mess it up. The only thing we have to change is to stick a call to memo (with an appropriate index range) on the front; the whole thing is elegant and short.

How does this even work, though? At first glance, it might seem like it will generate a new table with every recursive call to fib, which would obviously be a disaster. However, that’s not what happens: there is only a single, top-level definition of fib, and it is defined as the function which looks up its input in a certain table. Every time we call fib we are calling that same, unique top-level function which is defined in terms of its (unique, top-level) table. So this ends up being equivalent to our previous solution—there is a mutually recursive pair of a function and a lookup table—but written in a much nicer, more compact way that doesn’t require us to explicitly name the table.
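To make that equivalence concrete, here is a minimal sketch of what the memoized fib amounts to once the table is given a name. The names fibTable, fibStep, and fibMemo are made up for this illustration and do not appear in the solution.

```haskell
import Data.Array

-- The single shared lookup table: one lazily evaluated cell per index.
fibTable :: Array Int Integer
fibTable = listArray (0, 1000) (map fibStep [0 .. 1000])

-- The recurrence itself; its "recursive" calls go through fibMemo,
-- and therefore through the table.
fibStep :: Int -> Integer
fibStep 0 = 0
fibStep 1 = 1
fibStep n = fibMemo (n - 1) + fibMemo (n - 2)

-- The function callers use: nothing more than an index into the table.
fibMemo :: Int -> Integer
fibMemo = (fibTable !)
```

Because fibTable is an ordinary top-level value, it is built at most once, and each cell is evaluated at most once, the first time something looks it up.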

So here’s our final solution for Zapis. As you can see, the extra code we have to write in order to memoize our recurrence boils down to about five lines (two of which are type signatures and could be omitted). This is definitely a technique worth knowing!

{-# LANGUAGE LambdaCase #-}

import Control.Arrow
import Data.Array

main = interact $ lines >>> last >>> solve >>> format

format :: Integer -> String
format = show >>> reverse >>> take 5 >>> reverse

tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)

memo :: Ix i => (i,i) -> (i -> a) -> (i -> a)
memo rng = (!) . tabulate rng

solve :: String -> Integer
solve str = c (0,n)
  where
    n = length str
    s = listArray (0,n-1) str

    c :: (Int, Int) -> Integer
    c = memo ((0,0), (n,n)) $ \case
      (i,j)
        | i == j           -> 1
        | even i /= even j -> 0
        | otherwise        -> sum
          [ m (s!i) (s!k) * c (i+1,k) * c (k+1, j)
          | k <- [i+1, i+3 .. j-1]
          ]

m '(' ')'                = 1
m '[' ']'                = 1
m '{' '}'                = 1
m '?' '?'                = 3
m b '?' | b `elem` "([{" = 1
m '?' b | b `elem` ")]}" = 1
m _ _                    = 0
Posted in competitive programming, haskell | 3 Comments

Dynamic programming in Haskell: lazy immutable arrays

This is part 1 of a promised multi-part series on dynamic programming in Haskell. As a reminder, we’re using Zapis as a sample problem. In this problem, we are given a sequence of opening and closing brackets (parens, square brackets, and curly braces) with question marks, and have to compute the number of different ways in which the question marks could be replaced by brackets to create valid, properly nested bracket sequences.

Last time, we developed a recurrence for this problem and saw some naive, directly recursive Haskell code for computing it. Although this naive version is technically correct, it is much too slow, so our goal is to implement it more efficiently.

Mutable arrays?

Someone coming from an imperative background might immediately reach for some kind of mutable array, e.g. STUArray. Every time we call the function, we check whether the corresponding array index has already been filled in. If so, we simply return the stored value; if not, we compute the value recursively, and then fill in the array before returning it.
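For comparison, here is a rough sketch of that imperative style, applied to Fibonacci numbers rather than Zapis to keep it short. The names fibST and go are hypothetical, and I am assuming that -1 can serve as a "not yet computed" marker and that the results fit in an Int (so n should be at most 90 or so).

```haskell
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, newArray, readArray, writeArray)

-- Fibonacci with an explicit mutable memo table.
fibST :: Int -> Int
fibST n = runST $ do
  table <- newArray (0, max 1 n) (-1)  -- every cell starts as "empty"
  go table n

-- Look in the table first; compute and store only on a miss.
go :: STUArray s Int Int -> Int -> ST s Int
go table k
  | k < 2 = pure k
  | otherwise = do
      cached <- readArray table k
      if cached >= 0
        then pure cached
        else do
          a <- go table (k - 1)
          b <- go table (k - 2)
          let v = a + b
          writeArray table k v
          pure v
```

Note how much of the code is bookkeeping (the sentinel, the read, the write) rather than the recurrence itself.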

This would work, but there is a better way!

Immutable arrays

While mutable arrays occasionally have their place, we can surprisingly often get away with immutable arrays, where we completely define the array up front and then only use it for fast lookups afterwards.

  • If the type of the array elements is suitable, and we can initialize the array elements all at once from a list using some kind of formula, map, scan, etc., we should use UArray since it is much faster than Array.
  • However, UArray is strict in the elements, and the elements must be of a type that can be stored unboxed. If we need a more complex element type, or we need to compute the array recursively (where some elements depend on other elements), we can use Array.

What about the vector library, you ask? Well, it’s a very nice library, and quite fast, but unfortunately it is not available on many judging platforms, so I tend to stick to array to be safe. However, if you’re doing something like Advent of Code or Project Euler where you get to run the code on your own machine, then you should definitely reach for vector.

Lazy, recursive, immutable arrays

In my previous post on topsort we already saw the basic idea: since Arrays are lazy in their elements, we can define them recursively; the Haskell runtime then takes care of computing the elements in a suitable order. Previously, we saw this applied to automatically compute a topological sort, but more generally, we can use it to fill out a table of values for any recurrence.
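As a tiny illustration of a recursively defined array (not part of the Zapis solution; the name fibs and the bound 100 are arbitrary choices for this sketch), here is a table of Fibonacci numbers whose elements refer to other elements of the same array:

```haskell
import Data.Array

-- Each element is a lazy thunk that may refer to other elements of fibs.
fibs :: Array Int Integer
fibs = listArray (0, 100) [ f i | i <- [0 .. 100] ]
  where
    f :: Int -> Integer
    f 0 = 0
    f 1 = 1
    f i = fibs ! (i - 1) + fibs ! (i - 2)
```

Asking for fibs ! 50 forces only the elements it depends on, and each of those is computed at most once.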

So, as a first attempt, let’s just replace our recursive c function from last time with an array. I’ll only show the solve function for now; the rest of the code remains the same. (Spoiler alert: this solution works, but it’s ugly. We’ll develop much better solutions later.)

solve :: String -> Integer
solve str = c!(0,n)
  where
    n = length str
    s = listArray (0,n-1) str

    c :: Array (Int, Int) Integer
    c = array ((0,0),(n,n)) $
      [ ((i,i), 1) | i <- [0..n] ]
      ++
      [ ((i,j),0) | i <- [0..n], j <- [0..n], even i /= even j ]
      ++
      [ ((i,j),v)
      | i <- [0..n], j <- [0..n], i /= j, even i == even j
      , let v = sum [ m (s!i) (s!k) * c!(i+1,k) * c!(k+1,j) | k <- [i+1, i+3 .. j-1]]
      ]

We use the array function to create an array, which takes first a pair of indices specifying the index range, and then a list of (index, value) pairs. (The listArray function can also be particularly useful, when we have a list of values which are already in index order, as in the definition of s.)
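To see the two constructors side by side, here is a small sketch that builds the same five-element array both ways. The names squaresA and squaresL are invented for this example.

```haskell
import Data.Array

-- array: explicit (index, value) pairs, which may be listed in any order.
squaresA :: Array Int Int
squaresA = array (0, 4) [ (i, i * i) | i <- [4, 3 .. 0] ]

-- listArray: just the values, which must already be in index order.
squaresL :: Array Int Int
squaresL = listArray (0, 4) [0, 1, 4, 9, 16]
```

Both definitions denote the same array, so squaresA == squaresL evaluates to True.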

This solution is accepted, and it’s quite fast (0.04s for me). However, it’s really ugly, and although it’s conceptually close to our directly recursive function from before, the code is almost unrecognizably different. It’s ugly that we have to repeat conditions like i /= j and even i == even j, and binders like i <- [0..n]; the multiple list comprehensions and nested pairs like ((i,j),v) are kind of ugly, and the fact that this is implementing a recurrence is completely obscured.

However, I included this solution as a first step because for a long time, after I learned about using lazy immutable arrays to implement dynamic programming in Haskell, this was the kind of solution I wrote! Indeed, if you just think about the idea of creating a recursively defined array, this might be the kind of thing you come up with: we define an array c using the array function, then we have to list all its elements, and we get to refer to c along the way.

Mutual recursion to the rescue

Most of the ugliness comes from losing sight of the fact that there is a function mapping indices to values: we simply listed out all the function’s input/output pairs without getting to use any of Haskell’s very nice facilities for defining functions! So we can clean up the code considerably if we make a mutually recursive pair of an array and a function: the array values are defined using the function, and the function definition can look up values in the array.

solve :: String -> Integer
solve str = cA!(0,n)
  where
    n = length str
    s = listArray (0,n-1) str

    cA :: Array (Int, Int) Integer
    cA = array ((0,0),(n,n)) $
      [ ((i,j), c (i,j)) | i <- [0 .. n], j <- [0 .. n] ]

    c :: (Int, Int) -> Integer
    c (i,j)
      | i == j           = 1
      | even i /= even j = 0
      | otherwise        = sum
        [ m (s!i) (s!k) * cA ! (i+1,k) * cA ! (k+1,j)
        | k <- [i+1, i+3 .. j-1]
        ]

Much better! The c function looks much the same as our naive version from before, with the one difference that instead of calling itself recursively, it looks up values in the array cA. The array, in turn, is simply defined as a lookup table for the outputs of the function.

Generalized tabulation

One nice trick we can use to simplify the code a bit more is to use the range function to generate the list of all valid array indices, and then just map the c function over this. This also allows us to use the listArray function, since we know that the range will generate the indices in the right order.

cA :: Array (Int, Int) Integer
cA = listArray rng $ map c (range rng)
  where
    rng = ((0,0), (n,n))
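If you want to convince yourself about the index order, here is a small sketch on a 2-by-3 index range. The names indices and table, and the formula 10*i + j, are just for illustration.

```haskell
import Data.Array

-- range enumerates pair indices with the second component varying fastest:
-- (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).
indices :: [(Int, Int)]
indices = range ((0, 0), (1, 2))

-- Because listArray fills cells in that same order, each value lands
-- at the index it was computed from.
table :: Array (Int, Int) Int
table = listArray ((0, 0), (1, 2)) (map (\(i, j) -> 10 * i + j) indices)
```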

In fact, we can abstract this into a useful little function to create a lookup table for a function:

tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)

(We can generalize this even more to make it work for UArray as well as Array, but I’ll stop here for now. And yes, I intentionally named this to echo the tabulate function from the adjunctions package; Array i is indeed a representable functor, though it’s not really possible to express without dependent types.)
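For the curious, one way that generalization might look is sketched below, using the IArray class that both Array and UArray implement. The names tabulateI and squares are hypothetical, and this version only makes sense for UArray when the elements do not depend on each other, since UArray is strict in its elements.

```haskell
import Data.Array.Unboxed

-- Same idea as tabulate, but polymorphic in the array type.
tabulateI :: (IArray arr a, Ix i) => (i, i) -> (i -> a) -> arr i a
tabulateI rng f = listArray rng (map f (range rng))

-- An unboxed table of squares, built without any recursion.
squares :: UArray Int Int
squares = tabulateI (0, 10) (\i -> i * i)
```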

The solution so far

Putting it all together, here’s our complete solution so far. It’s pretty good, and in fact it’s organized in a very similar way to Soumik Sarkar’s dynamic programming solution to Chemist’s Vows. (However, there’s an even better solution coming in my next post!)

import Control.Arrow
import Data.Array

main = interact $ lines >>> last >>> solve >>> format

format :: Integer -> String
format = show >>> reverse >>> take 5 >>> reverse

tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)

solve :: String -> Integer
solve str = cA!(0,n)
  where
    n = length str
    s = listArray (0,n-1) str

    cA :: Array (Int, Int) Integer
    cA = tabulate ((0,0),(n,n)) c

    c :: (Int, Int) -> Integer
    c (i,j)
      | i == j           = 1
      | even i /= even j = 0
      | otherwise        = sum
        [ m (s!i) (s!k) * cA ! (i+1,k) * cA ! (k+1,j)
        | k <- [i+1, i+3 .. j-1]
        ]

m '(' ')'                = 1
m '[' ']'                = 1
m '{' '}'                = 1
m '?' '?'                = 3
m b '?' | b `elem` "([{" = 1
m '?' b | b `elem` ")]}" = 1
m _ _                    = 0

Coming up next: automatic memoization!

So what’s not to like about this solution? Well, I still don’t like the fact that we have to define a mutually recursive array and function. Conceptually, I want to name them both c (or whatever) since they are really isomorphic representations of the exact same mathematical function. It’s annoying that I have to make up a name like cA or c' or whatever for one of them. I also don’t like that we have to remember to do array lookups instead of recursive calls in the function—and if we forget, Haskell will not complain! It will just be really slow.

Next time, we’ll see how to use some clever ideas from Conal Elliott’s MemoTrie package (which themselves ultimately came from a paper by Ralf Hinze) to solve these remaining issues and end up with some really beautiful code!

Posted in competitive programming, haskell | 2 Comments

Competitive programming in Haskell: introduction to dynamic programming

In my previous post, I challenged you to solve Zapis. In this problem, we are given a sequence of opening and closing brackets (parens, square brackets, and curly braces) with question marks, and have to compute the number of different ways in which the question marks could be replaced by brackets to create valid, properly nested bracket sequences.

For example, given (??), the answer is 4: we could replace the question marks with any matched pair (either (), [], or {}), or we could replace them with )(, resulting in ()().

An annoying aside

One very annoying thing to mention about this problem is that it requires us to output the last 5 digits of the answer. At first, I interpreted that to mean “output the answer modulo 10^5”, which would be a standard sort of condition for a combinatorics problem, but that’s not quite the same thing, in a very annoying way: for example, if the answer is 2, we are supposed to output 2; but if the answer is 1000000002, we are supposed to output 00002, not 2! So simply computing the answer modulo 10^5 is not good enough; if we get a final answer of 2, we don’t know whether we are supposed to pad it with zeros. I could imagine keeping track of both the result modulo 10^5 along with a Boolean flag telling us whether the number has ever overflowed; we have to pad with zeros iff the flag is set at the end. I’m pretty sure this would work. But for this problem, it turns out that the final answer is at most “only” about 100 digits, so we can just compute the answer exactly as an Integer and then literally show the last 5 digits.
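The difference between the two interpretations is easy to see in a couple of lines. In this sketch, lastFive is the same pipeline as the format function used in the solutions, written with ordinary composition, and modOnly is a made-up name for the tempting but wrong alternative.

```haskell
-- Literally the last five characters of the decimal representation.
lastFive :: Integer -> String
lastFive = reverse . take 5 . reverse . show

-- Reducing modulo 10^5 throws away the leading zeros.
modOnly :: Integer -> String
modOnly n = show (n `mod` 100000)
```

For an answer of 2 both give "2", but for 1000000002, lastFive gives "00002" while modOnly gives "2".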

A recurrence

Now, how to compute the answer? For this kind of problem the first step is to come up with a recurrence. Let s[0 \dots n-1] be the given string, and let c(i,j) be the number of ways to turn the substring s[i \dots j-1] into a properly nested sequence of brackets, so ultimately we want to compute the value of c(0,n). (Note we make c(i,j) correspond to the substring which includes i but excludes j, which means, for example, that the length of the substring is j-i.) First, some base cases:

  • c(i,i) = 1 since the empty string always counts as properly nested.
  • c(i,j) = 0 if i and j have different parity, since any properly nested string must have even length.

Otherwise, s[i] had better be an opening bracket of some kind, and we can try matching it with each of s[i+1], s[i+3], s[i+5], …, s[j-1]. In general, matching s[i] with s[k] can be done in either 0, 1, or 3 ways depending on whether they are proper opening and closing brackets and whether any question marks are involved; then we have c(i+1,k) ways to make the substring between s[i] and s[k] properly nested, and c(k+1,j) ways for the rest of the string following s[k]. These are all independent, so we multiply them. Overall, we get this:

c(i,j) = \begin{cases} 1 & i = j \\ 0 & i \not \equiv j \pmod 2 \\ \displaystyle \sum_{k \in [i+1, i+3, \dots, j-1]} m(s[i], s[k]) \cdot c(i+1,k) \cdot c(k+1,j) & \text{otherwise} \end{cases}

where m(x,y) counts the number of ways to make x and y into a matching pair of brackets: it returns 0 if the two characters cannot possibly be a matching open-close pair (either because they do not match or because one of them is the wrong way around); 1 if they match, and at most one of them is a question mark; and 3 if both are question marks.

How do we come up with such recurrences in the first place? Unfortunately, Haskell doesn’t really make this any easier—it requires some experience and insight. However, what we can say is that Haskell makes it very easy to directly code a recurrence as a recursive function, to play with it and ensure that it gives correct results for small input values.
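One way to play with small inputs is to compare the recurrence against a brute-force count. The sketch below is such a checker; countBrute and balanced are hypothetical helpers, not part of any solution here. It tries all six brackets for every question mark, so it is only usable on very short strings.

```haskell
-- Count valid completions by trying every replacement of every '?'.
countBrute :: String -> Int
countBrute = length . filter balanced . mapM options
  where
    options '?' = "()[]{}"
    options c   = [c]

-- Standard stack-based check for properly nested brackets.
balanced :: String -> Bool
balanced = go []
  where
    go stack [] = null stack
    go stack (c : cs)
      | c `elem` "([{" = go (c : stack) cs
      | otherwise = case stack of
          (o : rest) | matches o c -> go rest cs
          _                        -> False

    matches '(' ')' = True
    matches '[' ']' = True
    matches '{' '}' = True
    matches _   _   = False
```

For instance, countBrute "(??)" is 4, which agrees with the example at the start of this post.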

A naive solution

To that end, if we directly code up our recurrence in Haskell, we get the following naive solution:

import Control.Arrow
import Data.Array

main = interact $ lines >>> last >>> solve >>> format

format :: Integer -> String
format = show >>> reverse >>> take 5 >>> reverse

solve :: String -> Integer
solve str = c (0,n)
  where
    n = length str
    s = listArray (0,n-1) str

    c :: (Int, Int) -> Integer
    c (i,j)
      | i == j           = 1
      | even i /= even j = 0
      | otherwise        = sum
        [ m (s!i) (s!k) * c (i+1,k) * c (k+1,j)
        | k <- [i+1, i+3 .. j-1]
        ]

m '(' ')'                = 1
m '[' ']'                = 1
m '{' '}'                = 1
m '?' '?'                = 3
m b '?' | b `elem` "([{" = 1
m '?' b | b `elem` ")]}" = 1
m _ _                    = 0

This solution is correct, but much too slow—it passes the first four test cases but then fails with a Time Limit Exceeded error. In fact, it takes exponential time in the length of the input string, because it has a classic case of overlapping subproblems. Our goal is to compute the same function, but in a way that is actually efficient.

Dynamic programming, aka memoizing recurrences

I hate the name “dynamic programming”—it conveys zero information about the thing that it names, and was essentially invented as a marketing gimmick. Dynamic programming is really just memoizing recurrences in order to compute them more efficiently. By memoizing we mean caching some kind of mapping from input to output values, so that we only have to compute a function once for each given input value; on subsequent calls with a repeated input we can just look up the corresponding output. There are many, many variations on the theme, but memoizing recurrences is really the heart of it.

In imperative languages, dynamic programming is often carried out by filling in tables via nested loops—the fact that there is a recurrence involved is obscured by the implementation. However, in Haskell, our goal will be to write code that is as close as possible to the above naive recursive version, but still actually efficient. Over the next few posts we will discuss several techniques for doing just that.

  • In part 1, we will explore the basic idea of using lazy, recursive, immutable arrays (which we have already seen in a previous post).
  • In part 2, we will use ideas from Conal Elliott’s MemoTrie package (and ultimately from a paper by Ralf Hinze) to clean up the code and make it a lot closer to the naive version.
  • In part 3, we’ll discuss how to memoize functions with infinite (or just very large) domains.
  • There may very well end up being more parts… we’ll see where it ends up!

Along the way I’ll also drop more links to relevant background. This will ultimately end up as a chapter in the book I’m slowly writing, and I’d like to make it into the definitive reference on dynamic programming in Haskell—so any thoughts, comments, links, etc. are most welcome!

Posted in competitive programming, haskell | 7 Comments

Competitive programming in Haskell: parsing with an NFA

In my previous post, I challenged you to solve Chemist’s Vows. In this problem, we have to decide which words can be made by concatenating atomic element symbols. So this is another parsing problem; but unlike the previous problem, element symbols are not prefix-free. For example, B and Be are both element symbols. So, if we see BE..., we don’t immediately know whether we should parse it as Be, or as B followed by an element that starts with E (such as Er).

A first try

A parsing problem, eh? Haskell actually shines in this area because of its nice parser combinator libraries. The Kattis environment does in fact have the parsec package available; and even on platforms that don’t have parsec, we can always use the Text.ParserCombinators.ReadP module that comes in base. So let’s try throwing one of those packages at the problem and see what happens!

If we try using parsec, we immediately run into problems; honestly, I don’t even know how to solve the problem using parsec. The problem is that <|> represents left-biased choice. If we parse p1 <|> p2 and parser p1 succeeds, then we will never consider p2. But for this parsing problem, because the symbols are not prefix-free, sometimes we can’t know which of two options we should have picked until later.

ReadP, on the other hand, explicitly has both biased and unbiased choice operators, and can return a list of possible parses instead of just a single parse. That sounds promising! Here’s a simple attempt using ReadP: to parse a single element, we use an unbiased choice over all the element names; then we use many parseElement <* eof to parse each word, and check whether there are any successful parses at all.

{-# LANGUAGE OverloadedStrings #-}

import           Control.Arrow
import           Data.Bool
import qualified Data.ByteString.Lazy.Char8   as C
import           Text.ParserCombinators.ReadP (ReadP, choice, eof, many,
                                               readP_to_S, string)

main = C.interact $
  C.lines >>> drop 1 >>> map (solve >>> bool "NO" "YES") >>> C.unlines

solve :: C.ByteString -> Bool
solve s = case readP_to_S (many parseElement <* eof) (C.unpack s) of
  [] -> False
  _  -> True

elements :: [String]
elements = words $
  "h he li be b c n o f ne na mg al si p s cl ar k ca sc ti v cr mn fe co ni cu zn ga ge as se br kr rb sr y zr nb mo tc ru rh pd ag cd in sn sb te i xe cs ba hf ta w re os ir pt au hg tl pb bi po at rn fr ra rf db sg bh hs mt ds rg cn fl lv la ce pr nd pm sm eu gd tb dy ho er tm yb lu ac th pa u np pu am cm bk cf es fm md no lr"

parseElement :: ReadP String
parseElement = choice (map string elements)

Unfortunately, this fails with a Time Limit Exceeded error (it takes longer than the allotted 5 seconds). The problem is that backtracking and trying every possible parse like this is super inefficient. One of the secret test inputs is almost certainly constructed so that there are an exponential number of ways to parse some prefix of the input, but no way to parse the entire thing. As a simple example, the string crf can be parsed as either c rf (carbon + rutherfordium) or cr f (chromium + fluorine), so by repeating crf n times we can make a string of length 3n which has 2^n different parses. If we fed this string to the ReadP solution above, it would quickly succeed with more or less the first thing that it tried. However, if we stick a letter on the end that does not occur in any element symbol (such as q), the result will be an unparseable string, and the ReadP solution will spend a very long time backtracking through exponentially many parses that all ultimately fail.
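Here is a small sketch that counts the parses directly. To keep it self-contained it uses only the four symbols involved in the crf example instead of the full element list; symbol and countParses are names invented for this illustration.

```haskell
import Text.ParserCombinators.ReadP
  (ReadP, choice, eof, many, readP_to_S, string)

-- Just the symbols needed for the example: c, cr, f, rf.
symbol :: ReadP String
symbol = choice (map string ["c", "cr", "f", "rf"])

-- Number of distinct ways to split the whole input into symbols.
countParses :: String -> Int
countParses = length . readP_to_S (many symbol <* eof)
```

With this, countParses (concat (replicate 3 "crf")) is 8, and appending a q to that input drops the count to 0, but only after every one of those partial parses has been explored.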

Solution

The key insight is that we don’t really care about all the different possible parses; we only care whether the given string is parseable at all. At any given point in the string, there are only two possible states we could be in: we could be finished reading one element symbol and about to start reading the next one, or we could be in the middle of reading a two-letter element symbol. We can just scan through the string and keep track of the set of (at most two) possible states; in other words, we will simulate an NFA which accepts the language of strings composed of element symbols.

First, some setup as before.

{-# LANGUAGE OverloadedStrings #-}

import           Control.Arrow              ((>>>))
import           Data.Array                 (Array, accumArray, (!))
import           Data.Bool                  (bool)
import qualified Data.ByteString.Lazy.Char8 as C
import           Data.List                  (partition, nub)
import           Data.Set                   (Set)
import qualified Data.Set                   as S

main = C.interact $
  C.lines >>> drop 1 >>> map (solve >>> bool "NO" "YES") >>> C.unlines

elements :: [String]
elements = words $
  "h he li be b c n o f ne na mg al si p s cl ar k ca sc ti v cr mn \
  \fe co ni cu zn ga ge as se br kr rb sr y zr nb mo tc ru rh pd ag cd \
  \in sn sb te i xe cs ba hf ta w re os ir pt au hg tl pb bi po at rn \
  \fr ra rf db sg bh hs mt ds rg cn fl lv la ce pr nd pm sm eu gd tb dy \
  \ho er tm yb lu ac th pa u np pu am cm bk cf es fm md no lr"

Now, let’s split the element symbols into one-letter and two-letter symbols:

singles, doubles :: [String]
(singles, doubles) = partition ((==1).length) elements

We can now make boolean lookup arrays that tell us whether a given letter occurs as a single-letter element symbol (single) and whether a given letter occurs as the first letter of a two-letter symbol (lead). We also make a Set of all two-letter element symbols, for fast lookup.

mkAlphaArray :: [Char] -> Array Char Bool
mkAlphaArray cs = accumArray (||) False ('a', 'z') (zip cs (repeat True))

single, lead :: Array Char Bool
[single, lead] = map (mkAlphaArray . map head) [singles, doubles]

doubleSet :: Set String
doubleSet = S.fromList doubles

Now for simulating the NFA itself. There are two states we can be in: START means we are about to start and/or have just finished reading an element symbol; SEEN c means we have seen the first character of some element (c) and are waiting to see another.

data State = START | SEEN Char
  deriving (Eq, Ord, Show)

Our transition function takes a character c and a state and returns a set of all possible next states (we just use a list since these sets will be very small). If we are in the START state, we could end up in the START state again if c is a single-letter element symbol; we could also end up in the SEEN c state if c is the first letter of any two-letter element symbol. On the other hand, if we are in the SEEN x state, then we have to check whether xc is a valid element symbol; if so, we return to START.

delta :: Char -> State -> [State]
delta c START    = [START | single!c] ++ [SEEN c | lead!c]
delta c (SEEN x) = [START | [x,c] `S.member` doubleSet]

We can now extend delta to act on a set of states, giving us the set of all possible resulting states; the drive function then iterates this one-letter transition over an entire input string. Finally, to solve the problem, we start with the singleton set [START], call drive using the input string, and check whether START (which is also the only accepting state) is an element of the resulting set of states.

trans :: Char -> [State] -> [State]
trans c sts = nub (sts >>= delta c)

drive :: C.ByteString -> ([State] -> [State])
drive = C.foldr (\c -> (trans c >>>)) id

solve :: C.ByteString -> Bool
solve s = START `elem` drive s [START]

And that’s it! This solution is accepted in 0.27 seconds (out of a maximum allowed 5 seconds).

For next time

  • If you want to practice the concepts from my past couple posts, give Haiku a try.
  • For my next post, I challenge you to solve Zapis!
Posted in competitive programming, haskell | 12 Comments

New ko-fi page: help me attend ICFP!

tl;dr: if you appreciate my past or ongoing contributions to the Haskell community, please consider helping me get to ICFP by donating via my new ko-fi page!

ko-fi

Working at a small liberal arts institution has some tremendous benefits (close interaction with motivated students, freedom to pursue the projects I want rather than jump through a bunch of hoops to get tenure, fantastic colleagues), and I love my job. But there are also downsides; the biggest ones for me are the difficulty of securing enough travel funding, and, relatedly, the difficulty of cultivating and maintaining collaborations.

I would really like to be able to attend ICFP in Seattle this September; the last time I was able to attend ICFP in person was 2019 in Berlin. With transportation, lodging, food, and registration fees, it will probably come to about $3000. I can get a grant from my institution to pay for up to $1200, but that still leaves a big gap.

As I was brainstorming other sources of funding, it dawned on me that there are probably many people who have been positively impacted by my contributions to the Haskell community (e.g. CIS 194, the Typeclassopedia, diagrams, split, MonadRandom, burrito metaphors…) and/or would like to support my ongoing work (competitive programming in Haskell, swarm, disco, ongoing package maintenance…) and would be happy to donate a bit.

So, to that end, I have set up a ko-fi page.

  • If you have been positively impacted by my contributions and would like to help me get to ICFP this fall, one-time donations — even very small amounts — are greatly appreciated! I’m not going to promise any particular rewards, but if you’re at ICFP I will definitely find you to say thanks!

  • Thinking beyond this fall, ideally this could even become a reliable source of funding to help me travel to ICFP or other collaboration opportunities every year. To that end, if you’re willing to sign up for a recurring monthly donation, that would be amazing — think of it as supporting my ongoing work: blog posts (and book) on competitive programming in Haskell, Swarm and Disco development, and ongoing package maintenance. I will post updates on ko-fi with things I’m thinking about and working on; I am also going to try to publish more frequent blog posts, at least in the near term.

Thank you, friends — I hope to see many people in Seattle! Next up: back to your regularly scheduled competitive programming!


Competitive programming in Haskell: tries

In my previous post, I challenged you to solve Alien Math, which is about reading numbers in some base B, but with a twist. We are given a list of B strings representing the names of the digits 0 through B-1, and a single string describing a number, consisting of concatenated digit names. For example, if B = 3 and the names of the digits are zero, one, two, then we might be given a string like twotwozerotwoone, which we should interpret as 22021_3 = 223_{10}. Crucially, we are also told that the digit names are prefix-free, that is, no digit name is a prefix of any other. But other than that, the digit names could be really weird: they could be very different lengths, some digit names could occur as substrings (just not prefixes) of others, digit names could share common prefixes, and so on. So this is really more of a parsing problem than a math problem; once we have parsed the string as a list of digits, converting from base B is the easy part.

One simple way we can do this is to define a map from digit names to digits, and simply look up each prefix of the given string until we find a hit, then chop off that prefix and start looking at successive prefixes of the remainder. This takes something like O(n^2 \lg n) time in the worst case (I think)—but this is actually fine since n is at most 300. This solution is accepted and runs in 0.00 seconds for me.
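
A minimal sketch of that simple approach might look like the following. The names decodeNaive and digitNames are hypothetical, and the sketch works on String rather than ByteString; it is meant only to illustrate the prefix-lookup idea, not to reproduce the accepted solution.

```haskell
import qualified Data.Map as M

-- Hypothetical helper: look up longer and longer prefixes of the input
-- until one is a digit name, emit that digit, and continue with the rest.
decodeNaive :: M.Map String Integer -> String -> [Integer]
decodeNaive _     [] = []
decodeNaive names s  = go 1
  where
    go k
      | k > length s                        = []  -- no prefix matches: give up
      | Just d <- M.lookup (take k s) names = d : decodeNaive names (drop k s)
      | otherwise                           = go (k + 1)

-- The digit names from the example above, for base 3.
digitNames :: M.Map String Integer
digitNames = M.fromList (zip ["zero", "one", "two"] [0 ..])
```

With these definitions, decodeNaive digitNames "twotwozerotwoone" yields the digit list [2,2,0,2,1]. Because the names are prefix-free, the first prefix that matches is the only one that can.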

Tries

However, I want to talk about a more sophisticated solution that has better asymptotic time complexity and generalizes nicely to other problems. Reading a sequence of strings from a prefix-free set should make you think of Huffman coding, if you’ve ever seen that before. In general, the idea is to define a trie containing all the digit names, with each leaf storing the corresponding digit. We can then scan through the input one character at a time, keeping track of our current position in the trie, and emit a digit (and restart at the root) every time we reach a leaf. This should run in O(n) time.

Let’s see some generic Haskell code for tries (this code can also be found at byorgey/comprog-hs/Trie.hs on GitHub). First, some imports, a data type definition, and emptyTrie and foldTrie for convenience:

module Trie where

import           Control.Monad              ((>=>))
import qualified Data.ByteString.Lazy.Char8 as C
import           Data.List                  (foldl')
import           Data.Map                   (Map, (!))
import qualified Data.Map                   as M
import           Data.Maybe                 (fromMaybe, isJust)

data Trie a = Trie
  { trieSize :: !Int
  , value    :: !(Maybe a)
  , children :: !(Map Char (Trie a))
  }
  deriving Show

emptyTrie :: Trie a
emptyTrie = Trie 0 Nothing M.empty

-- | Fold a trie into a summary value.
foldTrie :: (Int -> Maybe a -> Map Char r -> r) -> Trie a -> r
foldTrie f (Trie n b m) = f n b (M.map (foldTrie f) m)

A trie has a cached size (we could easily generalize this to store any sort of monoidal annotation), a possible value (i.e. the value associated with the empty string key, if any), and a map from characters to child tries. The cached size is not needed for this problem, but is included since I needed it for some other problems.

Now for inserting a key/value pair into a Trie. This code honestly took me a while to get right! We fold over the given string key, producing for each key suffix a function which will insert that key suffix into a trie. We have to be careful to correctly update the size, which depends on whether the key being inserted already exists—so the recursive go function actually returns a pair of a new Trie and an Int representing the change in size.

-- | Insert a new key/value pair into a trie, updating the size
--   appropriately.
insert :: C.ByteString -> a -> Trie a -> Trie a
insert w a t = fst (go w t)
  where
    go = C.foldr
      (\c insSuffix (Trie n v m) ->
         let (t', ds) = insSuffix (fromMaybe emptyTrie (M.lookup c m))
         in  (Trie (n+ds) v (M.insert c t' m), ds)
      )
      (\(Trie n v m) ->
         let ds = if isJust v then 0 else 1
         in  (Trie (n+ds) (Just a) m, ds)
      )

Now we can create an entire Trie in one go by folding over a list of key/value pairs with insert:

-- | Create an initial trie from a list of key/value pairs.  If there
--   are multiple pairs with the same key, later pairs override
--   earlier ones.
mkTrie :: [(C.ByteString, a)] -> Trie a
mkTrie = foldl' (flip (uncurry insert)) emptyTrie

A few lookup functions: one to look up a single character and return the corresponding child trie, and then on top of that we can build one to look up the value associated to an entire string key.

-- | Look up a single character in a trie, returning the corresponding
--   child trie (if any).
lookup1 :: Char -> Trie a -> Maybe (Trie a)
lookup1 c = M.lookup c . children

-- | Look up a string key in a trie, returning the corresponding value
--   (if any).
lookup :: C.ByteString -> Trie a -> Maybe a
lookup = C.foldr ((>=>) . lookup1) value

Finally, a function that often comes in handy for using a trie to decode a prefix-free code. It takes an input string and looks it up character by character; every time it encounters a key which exists in the trie, it emits the corresponding value and then starts over at the root of the trie.

decode :: Trie a -> C.ByteString -> [a]
decode t = reverse . snd . C.foldl' step (t, [])
  where
    step (s, as) c =
      let Just s' = lookup1 c s
      in  maybe (s', as) (\a -> (t, a:as)) (value s')

These tries are limited to string keys, since that is most useful in a competitive programming context, but it is of course possible to make much more general sorts of tries — see Hinze, Generalizing Generalized Tries.

Solution

Finally, we can use our generic tries to solve the problem: read the input, build a trie mapping digit names to values, use the decode function to read the given number, and finally interpret the resulting list of digits in the given base.

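{-# LANGUAGE RecordWildCards #-}

import           Control.Arrow              ((>>>))
import qualified Data.ByteString.Lazy.Char8 as C
import           Data.List                  (foldl')
import           ScannerBS
import           Trie

main = C.interact $ runScanner tc >>> solve >>> showB

showB :: Show a => a -> C.ByteString
showB = show >>> C.pack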

data TC = TC { base :: Integer, digits :: [C.ByteString], number :: C.ByteString }
  deriving (Eq, Show)

tc :: Scanner TC
tc = do
  base <- integer
  TC base <$> (fromIntegral base >< str) <*> str

solve :: TC -> Integer
solve TC{..} = foldl' (\n d -> n*base + d) 0 (decode t number)
  where
    t = mkTrie (zip digits [0 :: Integer ..])
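
To make the base conversion in solve concrete, here is a small sketch that records the intermediate accumulator values of the same fold. The name hornerSteps is hypothetical; it just replaces foldl' with scanl so that each step is visible.

```haskell
-- Same step function as the fold in solve, but keeping every intermediate value.
hornerSteps :: Integer -> [Integer] -> [Integer]
hornerSteps b = scanl (\n d -> n * b + d) 0
```

For the digits 2,2,0,2,1 in base 3, hornerSteps 3 [2,2,0,2,1] evaluates to [0,2,8,24,74,223]; the last element, 223, agrees with the worked example at the start of the post.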

Practice problems

Here are a few other problems where you can profitably make use of tries. Some of these can be solved directly using the Trie code given above; others may require some modifications or enhancements to the basic concept.

For next time

For next time, I challenge you to solve Chemist’s vows!


Competitive programming in Haskell: topsort via laziness

In my previous post, I challenged you to solve Letter Optimization. In this problem, we have a directed acyclic graph where each vertex represents a person, and there is an edge p -> q when person p sends their finished envelopes to person q. Also:

  • Some people may send their envelopes to multiple other people, in which case they send a certain percentage of their output to each.
  • Each person has a maximum speed at which they are able to process envelopes, measured in envelopes per second.
  • The people with no inputs are assumed to have an infinite stack of envelopes and therefore work at their maximum speed.
  • There are guaranteed to be no cycles.

The problem is to figure out which people are actually working at their maximum speed. Of course, the reason this is interesting is that the rate at which a person can work is partially determined by the rate at which envelopes are coming to them, which depends on the rates at which people before them in the pipeline are working, and so on.

The typical way to solve this would be to first topologically sort the people (e.g. using a DFS or Kahn’s Algorithm), then fill in the speed of each person in order of the topological sort. That way, when we calculate each person’s rate, we already know the rates of anyone that sends them input. This can also be thought of as a particularly simple form of dynamic programming.

Get rid of topological sort with this one neat trick

However, there is a nice trick we can use in Haskell to save ourselves a bunch of work: instead of doing an explicit topological sort, we can simply define a lazy, recursive array or map with the final values we want; laziness will take care of evaluating the array or map entries in the correct order. Essentially, we are co-opting the Haskell runtime into doing a topological sort for us!

Let’s see some code! First, some pragmas, imports, and boring utility functions. (For an explanation of the Scanner import, see this post and also this one.)

{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections   #-}

import           Control.Arrow              (second, (>>>))
import           Data.Array
import qualified Data.ByteString.Lazy.Char8 as C
import           Data.Set                   (Set)
import qualified Data.Set                   as S
import           ScannerBS

infixl 0 >$>
(>$>) = flip ($)

showB :: Show a => a -> C.ByteString
showB = show >>> C.pack

Now for some data types to represent the input, and some code to parse it.

data Person = Person { maxSpeed :: Double, sends :: [(Int, Double)] }
  deriving (Eq, Show)
data TC = TC { n :: Int, people :: Array Int Person }
  deriving (Eq, Show)

tc :: Scanner TC
tc = do
  n <- int
  people <- listArray (1,n) <$> (n >< (Person <$> double <*> numberOf ((,) <$> int <*> ((/100) <$> double))))
  return TC{..}

As an aside, notice how I use a record wildcard to create the output TC value. I find this a quick, simple, and consistent way to structure my scanning code, without having to come up with multiple names for the same thing. I don’t know whether I would ever use it in production code; I’ll leave that to others for debate.

To solve the problem, we take an array production holding the computed production speeds for each person (we’ll see how to build it in a minute), and extract the people who are working at their max speed.

main = C.interact $ runScanner tc >>> solve >>> map showB >>> C.unwords

solve :: TC -> [Int]
solve TC{..} =
  production >$> assocs >>>
  filter (\(p,u) -> abs (u - maxSpeed (people!p)) < 0.0001) >>>
  map fst

How do we compute the array of production speeds? First, we build a map from each person to their set of inputs:

 where
  -- inputMap!p = set of people from whom p gets input, with percentage for each
  inputMap :: Array Int (Set (Int,Double))
  inputMap = accumArray (flip S.insert) S.empty (1,n) (concatMap getInputs (assocs people))

  getInputs :: (Int, Person) -> [(Int, (Int, Double))]
  getInputs (p, Person _ ss) = map (second (p,)) ss

Now we create a lazy, recursive Array that maps each person to their production speed. Notice how the definition of production refers to itself: this works because the Array type is lazy in the values stored in the array. The values are not computed until we actually demand the value stored for a particular index; the Haskell runtime then goes off to compute it, which may involve demanding the values at other indices, and so on.

  production :: Array Int Double
  production = array (1,n)
    [ (p,u)
    | p <- [1 .. n]
    , let m = maxSpeed (people!p)
          i = (inputMap!p) >$> S.toList >>> map (\(x,pct) -> pct * (production!x)) >>> sum
          u = if S.null (inputMap!p) then m else min i m
    ]

For each person p, m is their maximum speed, i is the sum of all the production coming from their inputs (depending on their inputs’ own production speeds), and the person’s production speed u is the minimum of their input and their maximum speed (or simply their maximum speed if they have no inputs).
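
The same trick works for any computation over a DAG. As a smaller, self-contained illustration (unrelated to the envelope problem; the names longestFrom and exampleDag are made up for this sketch), here is the length of the longest path starting at each vertex, again defined via a lazy, self-referential array with no explicit topological sort:

```haskell
import Data.Array

-- Hypothetical example: g ! v lists the successors of v in a DAG; the
-- result maps each vertex to the length of the longest path starting there.
longestFrom :: Array Int [Int] -> Array Int Int
longestFrom g = dist
  where
    -- dist refers to itself; each entry is a thunk that is forced on demand.
    dist = listArray (bounds g)
      [ maximum (0 : [1 + dist ! w | w <- g ! v]) | v <- range (bounds g) ]

-- Edges: 1 -> 2, 1 -> 3, 2 -> 4, 3 -> 4.
exampleDag :: Array Int [Int]
exampleDag = listArray (1, 4) [[2, 3], [4], [4], []]
```

Here elems (longestFrom exampleDag) is [2,1,1,0]. If the graph had a cycle, demanding an entry on the cycle would loop forever, so the approach relies on the acyclicity guarantee.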

For next time

For next time, I challenge you to solve Alien Math!


Competitive programming in Haskell challenge: Letter Optimization

Now that I’ve wrapped up my series of posts on Infinite 2D Array (challenge, part 1, part 2, part 3), I’d like to tackle another topic related to competitive programming in Haskell. I won’t tell you what the topic is quite yet; for now, I challenge you to use Haskell to solve Letter Optimization! Feel free to post thoughts, questions, solutions, etc. in the comments (so don’t look at the comments if you don’t want spoilers!). In an upcoming post I will discuss my solution.


Competitive programming in Haskell: Infinite 2D array, Level 4

In a previous post, I challenged you to solve Infinite 2D Array using Haskell. After deriving a formula for F_{x,y} that involves only a linear number of terms, last time we discussed how to efficiently calculate Fibonacci numbers and binomial coefficients modulo a prime. Today, we’ll finally see some actual Haskell code for solving this problem.

The code is not very long, and seems rather simple, but what it doesn’t show is the large amount of time and effort I spent trying different versions until I figured out how to make it fast enough. Later in the post I will share some lessons learned.

Modular arithmetic

When a problem requires a fixed modulus like this, I typically prefer using a newtype M with a Num instance that does all operations using modular arithmetic, as explained in this post. However, that approach has a big downside: we cannot (easily) store our own newtype values in an unboxed array (UArray), since that requires defining a bunch of instances by hand. And the speedup we get from using unboxed vs boxed arrays is significant, especially for this problem.

So instead I just made some standalone functions to do arithmetic modulo 10^9 + 7:

p :: Int
p = 10^9 + 7

padd :: Int -> Int -> Int
padd x y = (x + y) `mod` p

pmul :: Int -> Int -> Int
pmul x y = (x*y) `mod` p

What about modular inverses? At first I defined a modular inverse operation based on my own implementation of the extended Euclidean Algorithm, but at some point I did some benchmarking and realized that my egcd function was taking up the majority of the runtime, so I replaced it with a highly optimized version taken from the arithmoi package. Rather than pasting in the code I will let you go look at it yourself if you’re interested.

Given the efficient extendedGCD, we can now define modular inversion like so:

inv :: Int -> Int
inv a = y `mod` p
  where
    (_,_,y) = extendedGCD p a
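
The code later in this post also uses a pdiv function whose definition is not shown. As a rough, self-contained sketch, it can presumably be defined as multiplication by the inverse. The egcd function here is a plain textbook recursive version standing in for the optimized extendedGCD from arithmoi, so it illustrates the idea rather than the performance; the definition of pdiv is likewise an assumption.

```haskell
p :: Int
p = 10^9 + 7

pmul :: Int -> Int -> Int
pmul x y = (x * y) `mod` p

-- Textbook extended Euclid (not the optimized arithmoi version):
-- egcd a b = (g, s, t) with s*a + t*b == g == gcd a b.
egcd :: Int -> Int -> (Int, Int, Int)
egcd a 0 = (a, 1, 0)
egcd a b = (g, t, s - q * t)
  where
    (q, r)    = a `divMod` b
    (g, s, t) = egcd b r

-- x*p + y*a == 1, so y is the inverse of a modulo p (for 0 < a < p).
inv :: Int -> Int
inv a = y `mod` p
  where
    (_, _, y) = egcd p a

-- Assumed definition: divide by multiplying with the modular inverse.
pdiv :: Int -> Int -> Int
pdiv x y = x `pmul` inv y
```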

Fibonacci numbers and factorials

We want to compute Fibonacci numbers and factorials modulo p = 10^9 + 7 and put them in tables so we can quickly look them up later. The simplest way to do this is to generate an infinite list of each (via the standard knot-tying approach in the case of Fibonacci numbers, and scanl' in the case of factorials) and then put them into an appropriate UArray:

fibList :: [Int]
fibList = 0 : 1 : zipWith padd fibList (tail fibList)

fib :: UArray Int Int
fib = listArray (0, 10^6) fibList

fac :: UArray Int Int
fac = listArray (0, 2*10^6) (scanl' pmul 1 [1 ..])

I should mention that at one point I defined fib this way instead:

fib' :: Array Int Int
fib' = array (0, 10^6) $ (0,0):(1,1):[ (i, (fib'!(i-1)) `padd` (fib'!(i-2))) | i <- [2 .. 10^6]]

This takes advantage of the fact that boxed arrays are lazy in their values (and can hence be constructed recursively) to directly define the array via dynamic programming. But this version is much slower, and uglier to boot! (If we really needed to initialize an unboxed array using recursion/dynamic programming, we could do that via runSTUArray, but it would be overkill for this problem.)
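
For the curious, a runSTUArray version might look roughly like the sketch below. The function name fibST and the choice to take the table size as a parameter are assumptions made for illustration; the idea is simply to fill a mutable unboxed array in index order and freeze it at the end.

```haskell
import Control.Monad      (forM_, when)
import Data.Array.ST      (newArray, readArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, (!))

p :: Int
p = 10^9 + 7

-- Hypothetical: Fibonacci numbers modulo p for indices 0 .. n, built by
-- mutating an unboxed array inside ST and freezing it at the end.
fibST :: Int -> UArray Int Int
fibST n = runSTUArray $ do
  arr <- newArray (0, n) 0                -- F_0 = 0 everywhere initially
  when (n >= 1) $ writeArray arr 1 1      -- F_1 = 1
  forM_ [2 .. n] $ \i -> do
    a <- readArray arr (i - 1)
    b <- readArray arr (i - 2)
    writeArray arr i ((a + b) `mod` p)
  return arr
```

For example, fibST 10 ! 10 is 55.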

Binomial coefficients modulo a prime

We can now efficiently compute binomial coefficients using fac and pdiv (modular division, i.e. multiplication by the modular inverse given by inv), like so:

mbinom :: Int -> Int -> Int
mbinom m k = (fac!m) `pdiv` ((fac!k) `pmul` (fac!(m-k)))

As mentioned in a previous post, this only works since the modulus is prime; otherwise, more complex techniques would be needed.

We could also precompute all inverse factorials, and then we can get rid of the pdiv call in mbinom (remember that pmul is very fast, whereas pdiv has to call extendedGCD):

ifac :: UArray Int Int
ifac = listArray (0, 2*10^6) (scanl' pdiv 1 [1 ..])

mbinom' :: Int -> Int -> Int
mbinom' m k = (fac!m) `pmul` (ifac!k) `pmul` (ifac!(m-k))

For this particular problem, it doesn’t make much difference either way, since the total number of pdiv calls stays about the same. But this can be an important optimization for problems where the number of calls to mbinom will be much larger than the max size of its arguments.

Putting it all together

Finally, we can put all the pieces together to solve the problem like so:

main = interact $ words >>> map read >>> solve >>> show

solve :: [Int] -> Int
solve [x,y] =
  sum [ (fib!k) `pmul` mbinom (x-k+y-1) (x-k) | k <- [1 .. x]] `padd`
  sum [ (fib!k) `pmul` mbinom (y-k+x-1) (y-k) | k <- [1 .. y]]

Lessons learned

The fact that the above code is fairly short (besides extendedGCD) belies the amount of time I spent trying to optimize it. Here are some things I learned while benchmarking and comparing different versions.

First, we should try really, really hard to use unboxed arrays (UArray) instead of boxed arrays (Array). Boxed arrays have one distinct advantage, which is that they can be constructed lazily, and hence recursively. This helps a lot for dynamic programming problems (which I have a lot to write about at some point in the future). But otherwise, they introduce a ton of overhead.

In this particular problem, committing to use UArray meant (1) using explicit modular operations like padd and pmul instead of a newtype, and (2) constructing the fib array by calculating a list of values and then using it to construct the array, instead of defining the array via recursion/DP.

The optimized implementation of extendedGCD makes a big difference, too, which makes sense: a majority of the computation time for this problem is spent running it (via pdiv). I don’t know what general lesson to draw from this other than to affirm the value of profiling to figure out where optimizations would help the most.

I tried a whole bunch of other things which turn out to make very little difference in comparison to the above optimizations. For example:

  • Optimizing padd and pmul to conditionally avoid an expensive mod operation when the arguments are not too big: this sped things up a tiny bit but not much.

  • Rewriting everything in terms of a tail-recursive loop that computes the required Fibonacci numbers and binomial coefficients incrementally, and hence does not require any lookup arrays:

-- Requires the BangPatterns extension.
solve' :: [Int] -> Int
solve' [x,y] = go x y 0 1 0 1 (mbinom' (x+y-2) (x-1)) `padd`
               go y x 0 1 0 1 (mbinom' (x+y-2) (y-1))
  where
    -- Invariants:
    --   s  = sum so far
    --   k  = current k
    --   f' = F_{k-1}
    --   f  = F_k
    --   bx  = binom (x-k+y-1) (x-k)
    go x y !s !k !f' !f !bx
      | k > x     = s
      | otherwise
      = go x y (s `padd` (bx `pmul` f)) (k+1)
           f (f' `padd` f) ((bx `pdiv` (x-k+y-1)) `pmul` (x-k))

    mbinom' n k = fac' n `pdiv` (fac' k `pmul` fac' (n-k))
    fac' k = foldl' pmul 1 [1 .. k]

This version is super ugly and erases most of the benefits of using Haskell in the first place, so I am happy to report that it runs in exactly the same amount of time as the solution I described earlier.
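
For reference, the first optimization in the list above (skipping mod when it is not needed) might look something like the following sketch. The names padd' and pmul' are hypothetical, and both versions assume their arguments already lie in the range 0 to p-1.

```haskell
p :: Int
p = 10^9 + 7

-- The sum of two residues is below 2*p, so one conditional subtraction suffices.
padd' :: Int -> Int -> Int
padd' x y = if s >= p then s - p else s
  where
    s = x + y

-- Only pay for mod when the product actually reaches p.
pmul' :: Int -> Int -> Int
pmul' x y = if xy < p then xy else xy `mod` p
  where
    xy = x * y
```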


Subtracting natural numbers: types and usability

For several years now I have been working on a functional teaching language for discrete mathematics, called Disco. It has a strong static type system, subtyping, equirecursive algebraic types, built-in property-based testing, and mathematically-inspired syntax. If you want to know more about it in general, you can check out the GitHub repo, or give it a try on replit.com.

In this post I want to write about a particular usability issue surrounding the type of the subtraction operation, partly because I think some might find it interesting, and partly because forcing myself to clearly articulate possible solutions may help me come to a good resolution.

The problem with subtraction

Disco supports four basic numeric types: natural numbers \mathbb{N}, integers \mathbb{Z}, “fractional” numbers \mathbb{F} (i.e. nonnegative rationals), and rational numbers \mathbb{Q}. These types form a subtyping lattice, with natural numbers being a subtype of both integers and fractionals, and integers and fractionals in turn being subtypes of the rationals. All the numeric types support addition and multiplication; the integers allow negation/subtraction, the fractionals allow reciprocals/division, and rationals allow both.

So what is the type of x - y? Clearly it has to be either \mathbb{Z} or \mathbb{Q}; that’s the whole point. Natural numbers and fractional numbers are not closed under subtraction; \mathbb{Z} and \mathbb{Q} are precisely what we get when we start with \mathbb{N} or \mathbb{F} and decide to allow subtraction, i.e. when we throw in additive inverses for everything.

However, this is one of the single biggest things that trips up students. As an example, consider the following function definition:

fact_bad : N -> N
fact_bad(0) = 1
fact_bad(n) = n * fact_bad(n-1)

This looks perfectly reasonable on the surface, and would work flawlessly at runtime. However, it does not typecheck: the argument to the recursive call must be of type \mathbb{N}, but since n-1 uses subtraction, it cannot have that type.

This is very annoying in practice for several reasons. The most basic reason is that, in my experience at least, it is very common: students often write functions like this without thinking about the fact that they happened to use subtraction along the way, and are utterly baffled when the function does not type check. This case is also extra annoying since it would work at runtime: we can clearly reason that if n is a natural number that is not zero, then it must be 1 or greater, and hence n-1 will in fact be a natural number. Because of Rice’s Theorem, we know that every decidable type system must necessarily exclude some programs as untypeable which nonetheless do not “go wrong”, i.e. exhibit no undesirable behavior when evaluated. The above fact_bad function is a particularly irksome example.

To be clear, there is nothing wrong with the type system, which is working exactly as intended. Rather, the problem lies in the fact that this is a common and confusing issue for students.

Implementing factorial

You may be wondering how it is even possible to implement something like factorial at all without being able to subtract natural numbers. In fact, there are two good ways to implement it, but they don’t necessarily solve the problem of student confusion.

  • One solution is to use an arithmetic pattern and match on n+1 instead of n, like this:

    fact_ok1 : N -> N
    fact_ok1(0) = 1
    fact_ok1(n+1) = (n+1) * fact_ok1(n)

    This works, and it’s theoretically well-motivated, but feels somewhat unsatisfying: both because we have to repeat n+1 and because this style of definition probably feels foreign to anyone except those who have played with a Nat algebraic data type (which excludes the vast majority of Discrete Math students).

  • Another solution is to use a saturating subtraction operator, x \mathbin{\dot -} y = \max(0, x - y). In Disco this operator is written x .- y. Unlike normal subtraction, it can have the type \mathbb{N} \times \mathbb{N} \to \mathbb{N}, so we can rewrite the factorial function this way:

    fact_ok2 : N -> N
    fact_ok2(0) = 1
    fact_ok2(n) = n * fact_ok2(n .- 1)

    The .- operator is also theoretically well-motivated, being the “monus” operator for the commutative monoid of natural numbers under addition. However, in my experience, students are annoyed and confused by this. They often do not understand when and why they are supposed to use .-. Of course, better error messages could help here, as could better pedagogy. This is actually my current approach: this semester I talked about the difference between \mathbb{N} and \mathbb{Z} very early, hitting on the fact that \mathbb{N} is not closed under subtraction, and explicitly made them explore the use of the .- operator in their first homework assignment. We’ll see how it goes!
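
For readers who think more easily in Haskell than in Disco, the saturating subtraction approach can be mimicked with the Natural type. This is only an analogy, not how Disco implements .- internally, and the names monus and factMonus are made up for the sketch.

```haskell
import Numeric.Natural (Natural)

-- Saturating subtraction on naturals: max(0, x - y), computed without
-- ever forming a negative intermediate value.
monus :: Natural -> Natural -> Natural
monus x y
  | x >= y    = x - y
  | otherwise = 0

-- Analogue of fact_ok2: the recursive call stays within the naturals.
factMonus :: Natural -> Natural
factMonus 0 = 1
factMonus n = n * factMonus (n `monus` 1)
```

Here monus 3 5 is 0, monus 5 3 is 2, and factMonus 5 is 120.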

Some tempting and expedient, but wrong, solutions

One solution that sounds nice on the surface is to just pun the notation: why not just have a single operator -, but make it behave like .- on types without negation (\mathbb{N} and \mathbb{F}), and like normal subtraction on \mathbb{Z} and \mathbb{Q}? That way students wouldn’t have to remember to use one version or the other, they can just use subtraction and have it do the right thing depending on the type.

This would be sound from a type system point of view; that is, we would never be able to produce, say, a negative value with type \mathbb{N}. However, in the presence of subtyping and type inference, there is a subtle problem from a semantics point of view. To understand the problem, consider the following function:

f : N -> Z
f(n) = (-3) * (n - 5)

What is the output of f(3)? Most people would say it should be (-3) * (3 - 5) = (-3) * (-2) = 6. However, if the behavior of subtraction depends on its type, it would also be sound for f(3) to output 0! The input 3 and the constant 5 can both be given the type \mathbb{N}, in which case the subtraction would act as a saturating subtraction and result in 0.

What’s going on here? Conceptually, one of the jobs of type inference, when subtyping is involved, is to decide where to insert type coercions. (Practically speaking, in Disco, such coercions are always no-ops; for example, all numeric values are represented as Rational, so 3 : N and 3 : Q have the same runtime representation.) An important guiding principle is that the semantics of a program should not depend on where coercions are inserted, and type-dependent-subtraction violates this principle. f(3) evaluates to either 6 or 0, depending on whether a coercion from \mathbb{N} to \mathbb{Z} is inserted inside or outside the subtraction. Violating this principle can make it very difficult for anyone (let alone students!) to understand the semantics of a given program: at worst it is ambiguous or undefined; at best, it depends on understanding where coercions will be inserted.
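
The two readings of f can be spelled out in Haskell, where the coercion has to be written explicitly. This is an illustrative sketch only: Natural and Integer stand in for N and Z, toInteger plays the role of the inserted coercion, and the function names are hypothetical.

```haskell
import Numeric.Natural (Natural)

-- Saturating subtraction, standing in for type-dependent "-" at type N.
monus :: Natural -> Natural -> Natural
monus x y
  | x >= y    = x - y
  | otherwise = 0

-- Coercion inserted inside the subtraction: subtract as integers.
coerceFirst :: Natural -> Integer
coerceFirst n = (-3) * (toInteger n - 5)

-- Coercion inserted outside the subtraction: subtract as naturals first.
subtractFirst :: Natural -> Integer
subtractFirst n = (-3) * toInteger (n `monus` 5)
```

On the input 3, coerceFirst gives 6 while subtractFirst gives 0, even though both are plausible elaborations of the same source program. The two agree whenever the input is at least 5, which is exactly why the discrepancy is easy to miss.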

What about having - always mean subtraction, but crash at runtime if we try to subtract natural numbers and get something less than 0? That way we can use it as long as we “know it is safe” (as in the factorial example). Unfortunately, this has the exact same issue, which the above example with f(3) still illustrates perfectly: f(3) can either evaluate to 6 or crash, depending on exactly where coercions are inserted.

Typechecking heuristics?

Another interesting option would be to make typechecking a bit smarter, so that instead of only keeping track of the type of each variable, we also sometimes keep track of values we statically know a variable can and can’t have in a certain context. We could then use this information to allow subtraction to have a type like \mathbb{N} \times \mathbb{N} \to \mathbb{N} as long as we can statically prove it is safe. For example, after matching on 0 in the first line of fact_bad, in the second line we know n cannot be 0, and we could imagine using this information to decide that the expression n - 1 is safe. This scheme would not change the semantics of any existing programs; it would only allow some additional programs to typecheck which did not before.

Of course, this would never be complete—there would always be examples of Disco programs where we can prove that a certain subtraction is safe but the heuristics don’t cover it. But it might still go a long way towards making this kind of thing less annoying. On the other hand, it makes errors even more mysterious when they do happen, and hard to understand when a program will and won’t typecheck. Perhaps it is best to just double down on the pedagogy and get students to understand the difference between \mathbb{N} and \mathbb{Z}!

Division?

As a final aside, note that we have the same issue with division: x / y is only allowed at types \mathbb{F} or \mathbb{Q}. If we want to divide integers, we can use a different built-in operator, //, which does integer division, i.e. “rounds down”. However, this is not nearly as bad an issue in practice, both because some students are already used to the idea of integer division (e.g. Python makes the same distinction), and because wanting to divide integers does not come up nearly as often, in practice, as wanting to subtract natural numbers.
