Swarm: preview and call for collaboration

For about a month now I have been working on building a game[1], tentatively titled Swarm. It’s nowhere near finished, but it has at least reached a point where I’m not embarrassed to show it off. I would love to hear feedback, and I would especially love to have others contribute! Read on for more details.

Swarm is a 2D tile-based resource gathering game, but with a twist: the only way you can interact with the world is by building and programming robots. And there’s another twist: the kinds of commands your robots can execute, and the kinds of programming language features they can interpret, depend on what devices they have installed; and you can create new devices only by gathering resources. So you start out with only very basic capabilities and have to bootstrap your way into more sophisticated forms of exploration and resource collection.

I guess you could say it’s kind of like a cross between Minecraft, Factorio, and Karel the Robot, but with a much cooler programming language (lambda calculus + polymorphism + recursion + exceptions + a command monad for first-class imperative programs + a bunch of other stuff).

The game is far from complete, and especially needs a lot more depth in terms of the kinds of devices and levels of abstraction you can build. But for me at least, it has already crossed the line into something that is actually somewhat fun to play.

If it sounds interesting to you, give it a spin! Take a look at the README and the tutorial. If you’re interested in contributing to development, check out the CONTRIBUTING file and the GitHub issue tracker, which I have populated with a plethora of tasks of varying difficulty. This could be a great project to contribute to especially if you’re relatively new to Haskell; I try to keep everything well-organized and well-commented, and am happy to help guide new contributors.

  1. Can you tell I am on sabbatical?

Posted in haskell, projects | 5 Comments

Competitive programming in Haskell: Codeforces Educational Round 114

Yesterday morning I competed in Educational Round 114 on codeforces.com, using only Haskell. Codeforces is somewhat annoying for Haskell since it does not support as many Haskell libraries as Open Kattis (e.g. no unordered-containers, split, or vector); on the other hand, a lot of really top competitive programmers are active there, and I enjoy occasionally participating in a timed contest like this when I am able.

WARNING: here be spoilers! Stop reading now if you’d like to try solving the contest problems yourself. (However, Codeforces has an editorial with explanations and solutions already posted, so I’m not giving anything away that isn’t already public.) I’m going to post my (unedited) code for each problem, but without all the imports and LANGUAGE extensions and whatnot; hopefully that stuff should be easy to infer.

Problem A – Regular Bracket Sequences

In this problem, we are given a number n and asked to produce any n distinct balanced bracket sequences of length 2n. I immediately just coded up a simple recursive function to generate all possible bracket sequences of length 2n, and then called take n on it. Thanks to laziness this works great. I missed that there is an even simpler solution: just generate the list ()()()()..., (())()()..., ((()))()..., i.e. where the kth bracket sequence starts with k nested pairs of brackets followed by n-k singleton pairs. However, I solved it in only four minutes anyway so it didn’t really matter!

readB = C.unpack >>> read

main = C.interact $
  C.lines >>> drop 1 >>> concatMap (readB >>> solve) >>> C.unlines

bracketSeqs 0 = [""]
bracketSeqs n =
  [ "(" ++ s1 ++ ")" ++ s2
  | k <- [0 .. n-1]
  , s1 <- bracketSeqs k
  , s2 <- bracketSeqs (n - k - 1)
  ]

solve n = map C.pack . take n $ bracketSeqs n
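For comparison, the simpler construction described above (the kth sequence is k nested pairs followed by n-k singleton pairs) fits in a few lines. This is a sketch with a hypothetical name, simpleSeqs, producing plain Strings rather than the ByteStrings used in the contest code:

```haskell
-- For k = 1 .. n, the k-th sequence is k nested pairs "((...))"
-- followed by (n - k) singleton pairs "()". The sequences are distinct
-- because each one starts with a different number of '(' characters.
simpleSeqs :: Int -> [String]
simpleSeqs n =
  [ replicate k '(' ++ replicate k ')' ++ concat (replicate (n - k) "()")
  | k <- [1 .. n]
  ]
```

For example, simpleSeqs 3 yields "()()()", "(())()" and "((()))".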

Problem B – Combinatorics Homework

In this problem, we are given numbers a, b, c, and m, and asked whether it is possible to create a string of a A’s, b B’s, and c C’s, such that there are exactly m adjacent pairs of equal letters. This problem requires doing a little bit of combinatorial analysis to come up with a simple Boolean expression in terms of a, b, c, and m; there’s not much to say about it from a Haskell point of view. You can refer to the editorial posted on Codeforces if you want to understand the solution.

readB = C.unpack >>> read

main = C.interact $
  C.lines >>> drop 1 >>> map (C.words >>> map readB >>> solve >>> bool "NO" "YES") >>> C.unlines

solve :: [Int] -> Bool
solve [a,b,c,m] = a + b + c - m >= 3 && m >= z - (x+y) - 1
  where
    [x,y,z] = sort [a,b,c]
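A closed form like this is easy to get subtly wrong, so it is worth cross-checking against brute force on tiny inputs. The sketch below (all names hypothetical) restates the condition from solve as possible, enumerates every arrangement of the letters, and counts equal adjacent pairs directly; for small positive counts the two should agree.

```haskell
import Data.List (sort)

-- The closed-form condition from solve, restated on four arguments:
-- at most a+b+c-3 equal pairs (each letter in one block), and at least
-- z - (x + y) - 1 (the most frequent letter needs x + y separators).
possible :: Int -> Int -> Int -> Int -> Bool
possible a b c m = a + b + c - m >= 3 && m >= z - (x + y) - 1
  where
    [x, y, z] = sort [a, b, c]

-- Every string with exactly a A's, b B's and c C's (exponential: tiny inputs only).
arrangements :: Int -> Int -> Int -> [String]
arrangements 0 0 0 = [""]
arrangements a b c =
  [ 'A' : s | a > 0, s <- arrangements (a - 1) b c ] ++
  [ 'B' : s | b > 0, s <- arrangements a (b - 1) c ] ++
  [ 'C' : s | c > 0, s <- arrangements a b (c - 1) ]

-- Number of adjacent positions holding equal letters.
equalPairs :: String -> Int
equalPairs s = length (filter id (zipWith (==) s (drop 1 s)))

-- Brute force: does some arrangement have exactly m equal adjacent pairs?
bruteForce :: Int -> Int -> Int -> Int -> Bool
bruteForce a b c m = any ((== m) . equalPairs) (arrangements a b c)
```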

Problem C – Slay the Dragon

This problem was super annoying and I still haven’t solved it. The idea is that you have a bunch of “heroes”, each with a numeric strength, and there is a dragon described by two numbers: its attack level and its defense level. You have to pick one hero to fight the dragon, whose strength must be at least the dragon’s defense; all the rest of the heroes will stay behind to defend your castle, and their combined strength must be at least the dragon’s attack. This might not be possible, of course, so you can first spend money to level up any of your heroes, at a rate of one coin per strength point; the task is to find the minimum amount of money you must spend.

The problem hinges on doing some case analysis. It took me a good while to come up with something that I think is correct. I spent too long trying to solve it just by thinking hard; I really should have tried formal program derivation much earlier. It’s easy to write down a formal specification of the correct answer which involves looping over every hero and taking a minimum, and this can be manipulated into a form that doesn’t need to do any looping.

In the end it comes down to (for example) finding the hero with the smallest strength greater than or equal to the dragon’s defense, and the hero with the largest strength less than or equal to it (though one of these may not exist). The intended way to solve the problem is to sort the heroes by strength and use binary search; instead, I put all the heroes in an IntSet and used the lookupGE and lookupLE functions.
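Here is a small sketch of that derivation, using plain lists instead of an IntSet and hypothetical names throughout. cost is the price of sending hero h: the coins needed to raise h to the dragon’s defense, plus the coins needed to raise everyone else’s total to its attack (matching the expressions in the contest code). specAnswer is the loop-over-every-hero specification; candidateAnswer looks only at the weakest hero at or above the defense and the strongest hero at or below it. They agree because the cost is nondecreasing in h above the defense and nonincreasing in h below it.

```haskell
-- Coins needed if hero h fights a dragon with defense df and attack atk,
-- where total is the combined strength of all heroes.
cost :: Int -> Int -> Int -> Int -> Int
cost total df atk h = max 0 (df - h) + max 0 (atk - (total - h))

-- Specification: try every hero (O(n) per dragon).
specAnswer :: [Int] -> (Int, Int) -> Int
specAnswer hs (df, atk) = minimum [ cost total df atk h | h <- hs ]
  where
    total = sum hs

-- Only the two heroes closest to df matter:
--   for h >= df the cost is max 0 (atk - total + h), nondecreasing in h;
--   for h <= df, raising h by 1 saves 1 coin and costs at most 1 coin.
candidateAnswer :: [Int] -> (Int, Int) -> Int
candidateAnswer hs (df, atk) = minimum (map (cost total df atk) cands)
  where
    total = sum hs
    ge = [ h | h <- hs, h >= df ]
    le = [ h | h <- hs, h <= df ]
    cands = [ minimum ge | not (null ge) ] ++ [ maximum le | not (null le) ]
```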

However, besides my floundering around getting the case analysis wrong at first, I got tripped up by two other things: first, it turns out that on the Codeforces judging hardware, Int is only 32 bits, which is not big enough for this problem! I know this because my code was failing on the third test case, and when I changed it to use Int64 instead of Int (which means I also had to switch to Data.Set instead of Data.IntSet), it failed on the sixth test case instead. The other problem is that my code was too slow: in fact, it timed out on the sixth test case rather than getting it wrong per se. I guess Data.Set and Int64 just have too much overhead.

Anyway, here is my code, which I think is correct, but is too slow.

data TC = TC { heroes :: ![Int64], dragons :: ![Dragon] }
data Dragon = Dragon { defense :: !Int64, attack :: !Int64 }

main = C.interact $
  runScanner tc >>> solve >>> map (show >>> C.pack) >>> C.unlines

tc :: Scanner TC
tc = do
  hs <- numberOf int64
  ds <- numberOf (Dragon <$> int64 <*> int64)
  return $ TC hs ds

solve :: TC -> [Int64]
solve (TC hs ds) = map fight ds
  where
    heroSet = S.fromList hs
    total = foldl' (+) 0 hs
    fight (Dragon df atk) = minimum $
      [ max 0 (atk - (total - hero)) | Just hero <- [mheroGE] ]
      ++
      [ df - hero + max 0 (atk - (total - hero)) | Just hero <- [mheroLE]]
      where
        mheroGE = S.lookupGE df heroSet
        mheroLE = S.lookupLE df heroSet

I’d like to come back to this later. Using something like vector to sort and then do binary search on the heroes would probably be faster, but vector is not supported on Codeforces. I’ll probably end up manually implementing binary search on top of something like Data.Array.Unboxed. Doing a binary search on an array also means we can get away with doing only a single search, since the two heroes we are looking for must be right next to each other in the array.
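A sketch of the single binary search on top of Data.Array.Unboxed might look like the following (lowerBound and neighbours are hypothetical names). lowerBound returns the first position whose value is at least x in a sorted array, and the two candidate heroes are at that position and the one just before it. If the array contains x itself, the element before it is strictly smaller than x; including it as an extra candidate is harmless, because a weaker hero below the defense can never be cheaper than the hero whose strength equals the defense.

```haskell
import Data.Array.Unboxed (UArray, bounds, listArray, (!))

-- First index i in [lo, hi] with arr ! i >= x, or hi + 1 if there is none.
-- Assumes arr is sorted in ascending order.
lowerBound :: UArray Int Int -> Int -> Int
lowerBound arr x = go lo (hi + 1)
  where
    (lo, hi) = bounds arr
    -- Invariant: everything before l is < x; everything at or after r is >= x.
    go l r
      | l >= r = l
      | arr ! m >= x = go l m
      | otherwise = go (m + 1) r
      where
        m = (l + r) `div` 2

-- The values just before and at the lower-bound position, when they exist.
neighbours :: UArray Int Int -> Int -> [Int]
neighbours arr x = [ arr ! j | j <- [i - 1, i], j >= lo, j <= hi ]
  where
    (lo, hi) = bounds arr
    i = lowerBound arr x
```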

Edited to add: I tried creating an unboxed array and implementing my own binary search over it; however, my solution is still too slow. At this point I think the problem is the sorting. Instead of calling sort on the list of heroes, we probably need to implement our own quicksort or something like that over a mutable array. That doesn’t really sound like much fun so I’m probably going to forget about it for now.

Problem D – The Strongest Build

In this problem, we consider a set of k-tuples, where the value for each slot in a tuple is chosen from among a list of possible values unique to that slot (the values for a slot are given to us in sorted order). For example, perhaps the first slot has the possible values 1, 2, 3, the second slot has possible values 5, 8, and the third slot has possible values 4, 7, 16. In this case there would be 3 \times 2 \times 3 possible tuples, ranging from (1,5,4) up to (3,8,16). We are also given a list of forbidden tuples, and then asked to find a non-forbidden tuple with the largest possible sum.

If the list of slot options is represented as a list of lists, with the first list representing the choices for the first slot, and so on, then we could use sequence to turn this into the list of all possible tuples. Hence, a naive solution could look like this:

solve :: Set [Int] -> [[Int]] -> [Int]
solve forbidden =
  head . filter (`S.notMember` forbidden) . sortOn (Down . sum) . sequence

Of course, this is much too slow. The problem is that although k (the size of the tuples) is limited to at most 10, there can be up to 2 \cdot 10^5 choices for each slot (the choices themselves can be up to 10^8). The list of all possible tuples could thus be truly enormous: in theory there could be up to (2 \cdot 10^5)^{10} \approx 10^{53} of them, and generating and then sorting them all is out of the question.

We can think of the tuples as forming a lattice, where the children of a tuple t are all the tuples obtained by downgrading exactly one slot of t to the next smaller choice. Then the intended solution is to realize that the largest non-forbidden tuple must either be the top element of the lattice (the tuple with the maximum possible value for every slot), OR a child of one of the forbidden tuples. This is easy to see by contradiction: any tuple other than the top element has at least one parent, every parent has a greater total value, and so if the tuple is not the child of a forbidden tuple, it has a non-forbidden parent that beats it. So we can just iterate over all the forbidden tuples (there are at most 10^5), generate all possible children (at most 10) for each one, and take the maximum over those children and the top element, ignoring any candidates that are themselves forbidden.
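For contrast with my solution below, the intended approach can be sketched as follows. bestBuild is a hypothetical name; it assumes the slot values are given in increasing order, builds are lists of 1-based choice indices (as in the problem’s output), and at least one build is not forbidden. Per-slot arrays make each strength lookup O(1).

```haskell
import Data.Array (Array, listArray, (!))
import Data.List (maximumBy)
import Data.Ord (comparing)
import qualified Data.Set as S

-- slots:  the values for each slot, in increasing order.
-- banned: the forbidden builds, as lists of 1-based choice indices.
-- Returns a non-forbidden build with maximum total strength.
bestBuild :: [[Int]] -> [[Int]] -> [Int]
bestBuild slots banned = maximumBy (comparing strength) candidates
  where
    bannedSet = S.fromList banned
    slotArrs = [ listArray (1, length vs) vs | vs <- slots ] :: [Array Int Int]
    -- Top of the lattice: the largest choice in every slot.
    top = map length slots
    -- Children: downgrade exactly one slot to its next smaller choice.
    children t = [ take k t ++ [c - 1] ++ drop (k + 1) t
                 | (k, c) <- zip [0 ..] t, c > 1 ]
    -- The top element plus all children of forbidden builds, minus forbidden ones.
    candidates = filter (`S.notMember` bannedSet) (top : concatMap children banned)
    strength t = sum (zipWith (!) slotArrs t)
```

For instance, with slots [[1,2,3],[1,5],[2,4,6]] and forbidden builds [3,2,3] and [3,2,2], the best remaining build is [2,2,3], with strength 2 + 5 + 6 = 13.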

However, that’s not how I solved it! I started thinking from the naive solution above, and wondered whether there is a way to do sortOn (Down . sum) . sequence more efficiently, by interleaving the sorting and the generation. If it can be done lazily enough, then we could just search through the beginning of the generated ordered list of tuples for the first non-forbidden one, without having to actually generate the entire list. Indeed, this reminded me very much of Richard Bird’s implementation of the Sieve of Eratosthenes (see p. 11 of that PDF). The basic idea is to make a function which takes a list of choices for a slot, and a (recursively generated) list of tuples sorted by decreasing sum, and combines each choice with every tuple, merging the results so they are still sorted. However, the key is that when combining the best possible choice for the slot with the largest tuple in the list, we can just immediately return the resulting tuple as the first (best) tuple in the output list, without needing to involve it in any merging operation. This affords just enough laziness to get the whole thing off the ground. I’m not going to explain it in more detail than that; you can study the code below if you like.

I’m quite pleased that this worked, though it’s definitely an instance of me making things more complicated than necessary.

data TC = TC { slots :: [[Choice]], banned :: [[Int]] }

tc = do
  n <- int
  TC <$> (n >< (zipWith Choice [1 ..] <$> numberOf int)) <*> numberOf (n >< int)

main = C.interact $
  runScanner tc >>> solve >>> map (show >>> C.pack) >>> C.unwords

solve :: TC -> [Int]
solve TC{..} = choices . fromJust $ find ((`S.notMember` bannedSet) . choices) bs
  where
    bannedSet = S.fromList banned
    revSlots = map reverse slots
    bs = builds revSlots

data Choice = Choice { index :: !Int, value :: !Int }

data Build = Build { strength :: !Int, choices :: [Int] }
  deriving (Eq, Show, Ord)

singletonBuild :: Choice -> Build
singletonBuild (Choice i v) = Build v [i]

mkBuild xs = Build (sum xs) xs

-- Pre: all input lists are sorted descending.
-- All possible builds, sorted in descending order of strength.
builds :: [[Choice]] -> [Build]
builds []     = []
builds (i:is) = chooseFrom i (builds is)

chooseFrom :: [Choice] -> [Build] -> [Build]
chooseFrom [] _  = []
chooseFrom xs [] = map singletonBuild xs
chooseFrom (x:xs) (b:bs) = addToBuild x b : mergeBuilds (map (addToBuild x) bs) (chooseFrom xs (b:bs))

addToBuild :: Choice -> Build -> Build
addToBuild (Choice i v) (Build s xs) = Build (v+s) (i:xs)

mergeBuilds xs [] = xs
mergeBuilds [] ys = ys
mergeBuilds (x:xs) (y:ys) = case compare (strength x) (strength y) of
  GT -> x : mergeBuilds xs (y:ys)
  _  -> y : mergeBuilds (x:xs) ys

Problems E and F

I didn’t even get to these problems during the contest; I spent too long fighting with problem C and implementing my overly complicated solution to problem D. I might attempt to solve them in Haskell too; if I do, I’ll write about them in another blog post!

Posted in competitive programming, haskell | 2 Comments

Automatically updated, cached views with lens

Recently I discovered a nice way to deal with records where certain fields of the record cache some expensive function of other fields, using the lens library. I very highly doubt I am the first person to ever think of this, but I don’t think I’ve seen it written down anywhere. I’d be very happy to learn of similar approaches elsewhere.

The problem

Suppose we have some kind of record data structure, and an expensive-to-calculate function which computes some kind of “view”, or summary value, for the record. Like this:

data Record = Record
  { field1 :: A, field2 :: B, field3 :: C }

expensiveView :: A -> B -> C -> D
expensiveView = ...

(Incidentally, I went back and forth on whether to put real code or only pseudocode in this post; in the end, I decided on pseudocode. Hopefully it should be easy to apply in real situations.)

If we need to refer to the summary value often, we might like to cache the result of the expensive function in the record:

data Record = Record
  { field1 :: A, field2 :: B, field3 :: C, cachedView :: D }

expensiveView :: A -> B -> C -> D
expensiveView = ...

However, this has several drawbacks:

  1. Every time we produce a new Record value by updating one or more fields, we have to remember to also update the cached view. This is easy to miss, especially in a large codebase, and will most likely result in bugs that are very difficult to track down.

  2. Actually, it gets worse: what if we already have a large codebase that is creating updated Record values in various places? We now have to comb through the codebase looking for such places and modifying them to update the cachedView field too. Then we cross our fingers and hope we didn’t miss any.

  3. Finally, there is nothing besides comments and naming conventions to prevent us from accidentally modifying the cachedView field directly.

The point is that our Record type now has an associated invariant, and invariants which are not automatically enforced by the API and/or type system are Bad ™.

Lens to the rescue

If you don’t want to use lens, you can stop reading now. (Honestly, given the title, I’m not even sure why you read this far.) In my case, I was already using it heavily, and I had a lightbulb moment when I realized how I could leverage it to add a safe cached view to a data type without modifying the rest of my codebase at all!

The basic idea is this:

  1. Add a field to hold the cached value as before.
  2. Don’t use lens’s TemplateHaskell utilities to automatically derive lenses for all the fields. Instead, declare them manually, such that they automatically update the cached field on every set operation.
  3. For the field with the cached value itself, declare a Getter, not a Lens.
  4. Do not export the constructor or field projections for your data type; export only the type and the lenses.

In pseudocode, it looks something like this:

module Data.Record
  (Record, field1, field2, field3, cachedView) where

import Control.Lens

data Record = Record
  { _field1 :: A, _field2 :: B, _field3 :: C, _cachedView :: D }

expensiveView :: A -> B -> C -> D
expensiveView = ...

recache :: Record -> Record
recache r = r { _cachedView = expensiveView (_field1 r) (_field2 r) (_field3 r) }

cachingLens :: (Record -> a) -> (Record -> a -> Record) -> Lens' Record a
cachingLens get set = lens get (\r a -> recache $ set r a)

field1 :: Lens' Record A
field1 = cachingLens _field1 (\r x -> r { _field1 = x })

field2 :: Lens' Record B
field2 = cachingLens _field2 (\r x -> r { _field2 = x })

field3 :: Lens' Record C
field3 = cachingLens _field3 (\r x -> r { _field3 = x })

cachedView :: Getter Record D
cachedView = to _cachedView

This solves all the problems! (1) We never have to remember to update the cached field; using a lens to modify the value of another field will automatically cause the cached view to be recomputed as well. (3) We can’t accidentally set the cached field, since it only has a Getter, not a Lens. In fact, this even solves (2), the problem of having to update the rest of our codebase: if we are already using lens to access fields in the record (as I was), then the rest of the codebase doesn’t have to change at all! And if we aren’t using lens already, then the typechecker will infallibly guide us to all the places we have to fix; once our code typechecks again, we know we have caught every single access to the record in the codebase.

Variant for only a few fields

What if we have a large record, and the cached summary value only depends on a few of the fields? In that case, we can save a bit of work for ourselves by getting lens to auto-generate lenses for the other fields, and only handcraft lenses for the fields that are actually involved. Like this:

{-# LANGUAGE TemplateHaskell #-}

data Record = Record
  { _field1 :: A, _field2 :: B, _cachedView :: C, ... }

expensiveView :: A -> B -> C
expensiveView = ...

let exclude = ['_field1, '_field2, '_cachedView] in
  makeLensesWith
    (lensRules & lensField . mapped . mapped %~ \fn n ->
      if n `elem` exclude then [] else fn n)
    ''Record

field1 :: Lens' Record A
field1 = ... similar to before ...

field2 :: Lens' Record B
field2 = ...

cachedView :: Getter Record C
cachedView = to _cachedView

But what about the lens laws?

You might worry that having a lens for one field automatically update the value of another field might break the lens laws somehow, but it’s perfectly legal, as we can check.

  1. view l (set l v s) ≡ v clearly holds: setting the cachedView on the side doesn’t change the fact that we get back out whatever we put into, say, field1.
  2. set l v' (set l v s) ≡ set l v' s also clearly holds. On the left-hand side, the cached summary value will simply get overwritten in the same way that the other field does.
  3. set l (view l s) s ≡ s is actually a bit more subtle. If we view the value of field1, then set it with the same value again, how do we know the value of the overall record s doesn’t change? In particular, could we end up with a different cachedView even though field1 is the same? But in fact, in this specific scenario (putting the same value back into a field that we just read), the value of the cachedView won’t change. This depends on two facts: first, that the expensiveView is a deterministic function which always returns the same summary value for the same input record. Of course this is guaranteed by the fact that it’s a pure function. Second, we must maintain the invariant that the cachedView is always up-to-date, so that recomputing the summary value after setting a field to the same value it already had will simply produce the same summary value again, because we know the summary value was correct to begin with. And of course, maintaining this invariant is the whole point; it’s guaranteed by the way we only export the lenses (and only a Getter for the cachedView) and not the record constructor.
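To see the pattern and the three laws on something concrete, here is a small runnable instance. To keep it dependency-free, it defines a few lines of van Laarhoven lens machinery by hand instead of importing Control.Lens; this is an assumption for illustration, and the hand-rolled lens, view and set are meant to mirror the library functions of the same names, with lens’s Getter replaced by a plain function. The list-summing expensiveView and the mkRecord smart constructor are hypothetical.

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

-- Minimal van Laarhoven lenses (stand-in for Control.Lens).
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s

lens :: (s -> a) -> (s -> a -> s) -> Lens' s a
lens getter setter f s = setter s <$> f (getter s)

view :: Lens' s a -> s -> a
view l = getConst . l Const

set :: Lens' s a -> a -> s -> s
set l a = runIdentity . l (const (Identity a))

-- In a real module only Record, mkRecord, field1, field2 and cachedView
-- would be exported, so the invariant cannot be broken from outside.
data Record = Record { _field1 :: [Int], _field2 :: [Int], _cachedView :: Int }
  deriving (Eq, Show)

-- Hypothetical "expensive" view: the total of both lists.
expensiveView :: [Int] -> [Int] -> Int
expensiveView xs ys = sum xs + sum ys

recache :: Record -> Record
recache r = r { _cachedView = expensiveView (_field1 r) (_field2 r) }

-- Hypothetical smart constructor: builds a Record with an up-to-date cache.
mkRecord :: [Int] -> [Int] -> Record
mkRecord xs ys = recache (Record xs ys 0)

-- Every set through one of these lenses recomputes the cache.
cachingLens :: (Record -> a) -> (Record -> a -> Record) -> Lens' Record a
cachingLens getter setter = lens getter (\r a -> recache (setter r a))

field1, field2 :: Lens' Record [Int]
field1 = cachingLens _field1 (\r x -> r { _field1 = x })
field2 = cachingLens _field2 (\r x -> r { _field2 = x })

-- Read-only access to the cache (plays the role of `to _cachedView`).
cachedView :: Record -> Int
cachedView = _cachedView
```

Setting field1 to [10] on a record built from [1,2] and [3] gives a cached view of 13 without any extra bookkeeping at the call site.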

And that’s it! I’ve been using this approach very successfully in a current project (the same project that got me to implement Hindley-Milner with unification-fd—watch this space for an announcement soon!). If you know of similar approaches that have been written about elsewhere, or if you end up using this technique in your own project, I’d love to hear about it.

Posted in haskell | 10 Comments

Competitive programming in Haskell: Kadane’s algorithm

I will be giving a talk on competitive programming in Haskell tomorrow, September 10, at Haskell Love. You should attend! It’s free!

In my last competitive programming post, I challenged you to solve Purple Rain. We are presented with a linear sequence of cells, each colored either red or blue, and we are supposed to find the (contiguous) segment of cells with the maximal absolute difference between the number of red and blue. For example, in one of the sample inputs the solution is the segment from cell 3 to cell 7 (the cells are numbered from 1), which has four red cells compared to only one blue, for an absolute difference of three; no other segment does better.

Transforming the problem

The obvious way to do this is to generate a list of all segments, compute the absolute difference between the number of red and blue cells for each, and take the maximum. However, that approach is doomed to exceed the time limit in any programming language: it would take O(n^3) time (O(n^2) possible segments times O(n) to sum each one), and the problem states that n can be up to 10^5. With 10^8 operations per second as a good rule of thumb, O(n^3) with n = 10^5, that is, on the order of 10^{15} operations, is clearly too slow. (In fact, any time you see an input size of 10^5, it is a dead giveaway that you are expected to find an O(n) or O(n \lg n) solution. 10^5 is big enough that an O(n^2) solution, at roughly 10^{10} operations, is prohibitively slow, but not so big that I/O itself becomes the bottleneck.)

The first insight is that we can transform this into the classic problem of finding the maximum sum subarray (also known as the maximum segment sum; either way I will abbreviate it as MSS) in two steps: first, turn each red cell into a 1, and each blue into -1. The sum of a segment then tells us how many more red than blue cells there are. Now, we actually want the biggest absolute difference between red and blue; but if we have an algorithm to find the MSS we can just run it twice: once to find the maximum excess of red over blue, and again with 1 and -1 flipped to find the maximum excess of blue over red.
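The reduction itself is only a few lines. In this sketch, toStep and maxAbsDiff are hypothetical names, the cells are assumed to arrive as a String of 'R' and 'B' characters, and the MSS solver is passed in as a parameter so any implementation can be plugged in. Since |x| = max x (-x), the largest absolute difference is the larger of the best segment sum and the best segment sum of the negated list.

```haskell
-- Red counts +1 and blue -1, so a segment's sum is (#red - #blue).
toStep :: Char -> Int
toStep 'R' = 1
toStep 'B' = -1
toStep c = error ("unexpected cell: " ++ [c])

-- Largest |#red - #blue| over all contiguous segments, given any solver for
-- the maximum segment sum problem (empty segment allowed, so result >= 0).
maxAbsDiff :: ([Int] -> Int) -> String -> Int
maxAbsDiff mss cells = max (mss steps) (mss (map negate steps))
  where
    steps = map toStep cells
```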

The MSS problem has a long history in the functional programming community, being one of the flagship problems to demonstrate the techniques of program derivation in the style of the Bird-Meertens Formalism, aka Squiggol and The Algebra of Programming. It is possible to start out with a naive-but-obviously-correct O(n^3) implementation, and do a series of equational transformations to turn it into an efficient O(n) algorithm! If you’ve never seen that kind of thing before, I highly recommend checking it out; the Wikipedia page on the Bird-Meertens Formalism, linked above, is a good place to start. Certainly getting good at such derivations can be a handy skill when doing competitive programming in Haskell. But in any case, today I want to approach the problem from a different point of view, namely, coming up with a good functional equivalent to an existing imperative algorithm.
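For reference, the usual naive starting point for such a derivation can be written directly in Haskell (segments and mssSpec are illustrative names). It is obviously correct and hopelessly slow, which makes it a handy oracle for testing faster versions on small inputs.

```haskell
import Data.List (inits, tails)

-- All contiguous segments of a list, including (repeated) empty ones.
segments :: [a] -> [[a]]
segments = concatMap inits . tails

-- O(n^2) segments, each summed in O(n): O(n^3) overall.
-- The empty segment has sum 0, so the result is never negative.
mssSpec :: [Int] -> Int
mssSpec = maximum . map sum . segments
```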

Kadane’s algorithm

Kadane’s algorithm, first proposed by Jay Kadane sometime in the late 1970s, is a linear-time algorithm for solving the MSS problem. It is actually quite simple to implement (the tricky part is understanding why it works!).

The idea is to loop through an array while keeping track of two things: a current value cur, and a best value. The best value is just the greatest value cur has ever taken on, so keeping it updated is easy: on every loop, we compare cur to best, and save the value of cur into best if it is higher. To keep cur updated, we simply add each new array value to it—but if it ever falls below zero, we just reset cur to zero. Here is some Java code:

public static int kadane(int[] a) {
    int best = 0, cur = 0;
    for (int i = 0; i < a.length; i++) {
        cur += a[i];
        if (cur < 0) cur = 0;
        if (cur > best) best = cur;
    }
    return best;
}

Again, it is not at all obvious why this works, though putting in the effort to understand a proof is well worth the time. That is not the purpose of this blog post, however, so I’ll leave you to read about it on your own!

Translating Kadane’s algorithm to Haskell

In the imperative version, we iterate through a list, keep track of a current value, and also keep track of the best value we have seen so far. It is possible to translate this directly to Haskell: create a record with two fields, one for the current thing and one for the best thing, then iterate through the list with foldl', doing the appropriate update at each step:

data State s = State { curThing :: s, bestThing :: s }

-- Given a way to update the current s value with the next list
-- element of type a, update a State s value which keeps track of the
-- current s value as well as the best s value seen so far.
step :: Ord s => (s -> a -> s) -> (State s -> a -> State s)
step update (State cur best) a = State next (max best next)
  where
    next = update cur a

bestIntermediate :: Ord s => (s -> a -> s) -> s -> [a] -> s
bestIntermediate update init = bestThing . foldl' (step update) (State init init)

But there’s a much better way! Note that the update function has the right type to be used with foldl'. However, if we just computed foldl' update init directly, we would get only the single s value at the very end, whereas our goal is to get the best out of all the intermediate values. No problem: a scan is just a fold that returns all the intermediate values instead of only the final one! So instead of all this complicated and quasi-imperative State stuff, we just do a scanl' followed by taking the maximum:

bestIntermediate :: Ord s => (s -> a -> s) -> s -> [a] -> s
bestIntermediate update init = maximum . scanl' update init

Ah, much better! Using bestIntermediate, we can now translate Kadane’s algorithm as follows:

kadane1 :: [Int] -> Int
kadane1 = bestIntermediate next 0
  where
    next s a = max 0 (s + a)

Whenever I write down an algorithm like this in Haskell—especially if I have “translated” it from an existing, imperative algorithm—I like to figure out how I can generalize it as much as possible. What structure is assumed of the inputs that makes the algorithm work? Can we replace some concrete monomorphic types with polymorphic ones? What type class constraints are needed? Often these sorts of generalizations make for supposedly tricky competitive programming problems. For example, if you have only ever seen Dijkstra’s algorithm presented in terms of finding shortest paths by summing edge weights, it takes quite a bit of insight to realize that the algorithm really just finds the “best” path with respect to operations that play nicely with each other in certain ways. For example, if we use the operations max and min in place of min and (+), Dijkstra’s algorithm finds the path with the maximum bottleneck. (This will probably end up being its own blog post at some point…)

Anyway, how can we generalize kadane1? The obvious starting point is that if we just let GHC infer a type for kadane1, we would get something more general:

kadane2 :: (Num a, Ord a) => [a] -> a
kadane2 = bestIntermediate next 0
  where
    next s a = max 0 (s + a)

The only thing we do with the input list elements is add them and compare them; we also need 0 to have the same type as the input list elements. So the algorithm works for anything that has Ord and Num instances.

But wait—do we really need Num if all we are using is 0 and +? Really we are using the monoid of integers under addition. So we can generalize again, to any ordered monoid:

kadane :: (Monoid a, Ord a) => [a] -> a
kadane = bestIntermediate next mempty
  where
    next s a = max mempty (s <> a)

In fact, if you study the proof of Kadane’s algorithm, you will see that this works just so long as the monoid operation interacts nicely with the ordering, that is, if x < y implies x \diamond z < y \diamond z and z \diamond x < z \diamond y for all z (this is what is usually meant by an “ordered monoid”).
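As a quick check that the monoid-generic version reduces to the original when the monoid is integer addition, Data.Monoid’s Sum wrapper can be used directly: its mempty is Sum 0, its <> adds, and its Ord instance compares the wrapped numbers. The sketch repeats the two definitions (with the initial value renamed to start) so that it runs on its own:

```haskell
import Data.List (scanl')
import Data.Monoid (Sum (..))

bestIntermediate :: Ord s => (s -> a -> s) -> s -> [a] -> s
bestIntermediate update start = maximum . scanl' update start

kadane :: (Monoid a, Ord a) => [a] -> a
kadane = bestIntermediate next mempty
  where
    -- Extend the current segment, or reset to the empty segment if it goes negative.
    next s a = max mempty (s <> a)
```

For example, kadane (map Sum [-2, 1, -3, 4, -1, 2, 1, -5, 4]) is Sum 6, from the segment 4, -1, 2, 1.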

Finding the best segment

So far, our code finds the best segment sum, but it doesn’t tell us which segment it was that was best—and for this problem we are actually supposed to output the starting and ending indices of the best segment, not the maximum red-blue difference itself.

If I were doing this in Java, I would probably just add several more variables: one to record where the segment currently being considered starts (which gets reset to i+1 when cur is reset to 0), and two to record the start and end indices of the best segment so far. This gets kind of ugly. Conceptually, the values actually belong in triples representing a segment: the start and end index together with the sum. In Java, it would be too heavyweight to construct a class to store these three values together, so in practice I would just do it with a mess of individual variables as described above. Fortunately, in Haskell, this is very lightweight, and we should of course create a data type to represent a segment.

It’s also worth noting that in Haskell, we were naturally led to make a polymorphic bestIntermediate function, which will work just as well with a segment type as it does Int. Only our kadane function itself will have to change. We will make a data type to represent segments, with an appropriate Ord instance to specify when one segment is better than another, and we will update the next helper function to update a segment instead of just updating a sum.

The solution

So let’s get started! First, some LANGUAGE pragmas and imports we will need, and a basic solution framework.

{-# LANGUAGE DerivingStrategies         #-}
{-# LANGUAGE DerivingVia                #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase                 #-}
{-# LANGUAGE ViewPatterns               #-}

import           Control.Arrow         ((>>>))
import           Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as C
import           Data.List             (scanl')
import           Data.Semigroup

showB :: Show a => a -> C.ByteString
showB = show >>> C.pack

main = C.interact $ solve >>> format

Next, a data type to represent a segment, along with an Ord instance, and a function to format the answer for this problem. Note that we make all the fields of Segment strict (this makes a big difference in runtime—you should pretty much always use strict fields unless there is a really good reason not to). Notice also that the Ord instance is where we encode the problem’s specific instructions about how to break ties: “If there are multiple possible answers, print the one that has the Westernmost (smallest-numbered) starting section. If there are multiple answers with the same Westernmost starting section, print the one with the Westernmost ending section.” So we first compare Segments by sum, then by left index, then by right index, being careful to reverse the order of comparison for the indices, since better for indices means smaller.

-- The segment [l,r) (i.e. l inclusive, r exclusive), with its sum
data Segment a = I { l :: !Int, r :: !Int, s :: !a } deriving (Eq, Show)

instance Ord a => Ord (Segment a) where
  compare (I l1 r1 s1) (I l2 r2 s2)
    = compare s1 s2 <> compare l2 l1 <> compare r2 r1

format :: Segment a -> ByteString
format (I l r _) = C.unwords [showB l, showB (r-1)]
  -- (r-1) since r is exclusive but we are supposed to show
  -- the index of the last element
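As a quick sanity check of the tie-breaking rules, here is a small standalone sketch that repeats the Segment type and its Ord instance from above, using Sum Int (from Data.Semigroup) for the sums. The list tied and the value bestTied are hypothetical examples, not part of the solution.

```haskell
import Data.Semigroup (Sum (..))

-- Repeated from above: the segment [l,r) with its sum.
data Segment a = I { l :: !Int, r :: !Int, s :: !a } deriving (Eq, Show)

-- A bigger sum is better; on ties, a smaller (Westernmost) left index,
-- and then a smaller right index, is better.
instance Ord a => Ord (Segment a) where
  compare (I l1 r1 s1) (I l2 r2 s2)
    = compare s1 s2 <> compare l2 l1 <> compare r2 r1

-- Hypothetical candidates that all have the same sum, 2.
tied :: [Segment (Sum Int)]
tied = [I 3 6 (Sum 2), I 1 5 (Sum 2), I 1 3 (Sum 2)]

-- The Westernmost start wins, then the Westernmost end,
-- so this equals I 1 3 (Sum 2).
bestTied :: Segment (Sum Int)
bestTied = maximum tied
```

Because the index comparisons are flipped, maximum does the right thing without any special-casing in the algorithm itself.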

And now for Kadane’s algorithm itself. The bestIntermediate function is unchanged. kadane changes to start with an appropriate “empty” segment, and to use a next function that updates the current segment [l,r) based on the next list element a (the one at index r, since r is exclusive). The right index always gets incremented to r+1. If the current sum combined with a is “negative” (that is, less than mempty), we reset the left index to r+1 as well (making the segment empty), and the running sum to mempty. Otherwise, we leave the left index unchanged and add a to the running sum. I also added an argument to indicate the index of the first list element, since in this problem the list is supposed to be indexed from 1, but future problems might be indexed from 0, and I would never remember which it is. (I suppose we could also take a list of pairs (i,a) where i is the index of a. That would even work for non-consecutive indices.)

bestIntermediate :: Ord s => (s -> a -> s) -> s -> [a] -> s
bestIntermediate update init = maximum . scanl' update init

kadane :: (Monoid a, Ord a) => Int -> [a] -> Segment a
kadane ix = bestIntermediate next (I ix ix mempty)
  where
    -- Extend the segment [l,r) with a, the element at index r.
    next (I l r s) a
      | s<>a < mempty = I (r+1) (r+1) mempty  -- reset to an empty segment
      | otherwise     = I l     (r+1) (s<>a)
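To watch next at work, here is a self-contained sketch that repeats Segment, bestIntermediate, and kadane, with next lifted to the top level (and its variables renamed to lo, hi, and acc) so the intermediate segments can be inspected. The monoid is Sum Int, and example and steps are hypothetical names for illustration. With the 1-indexed input [1, -3, 2, 2, -1, 3, -5], the running sum drops below zero at the second element, so the segment restarts at index 3, and the best segment found is [3,7), i.e. elements 3 through 6 with sum 6.

```haskell
import Data.List      (scanl')
import Data.Semigroup (Sum (..))

data Segment a = I { l :: !Int, r :: !Int, s :: !a } deriving (Eq, Show)

instance Ord a => Ord (Segment a) where
  compare (I l1 r1 s1) (I l2 r2 s2)
    = compare s1 s2 <> compare l2 l1 <> compare r2 r1

bestIntermediate :: Ord s => (s -> a -> s) -> s -> [a] -> s
bestIntermediate update z = maximum . scanl' update z

-- Extend [lo,hi) with the element x at index hi.
next :: (Monoid a, Ord a) => Segment a -> a -> Segment a
next (I lo hi acc) x
  | acc <> x < mempty = I (hi + 1) (hi + 1) mempty  -- start over after x
  | otherwise         = I lo       (hi + 1) (acc <> x)

kadane :: (Monoid a, Ord a) => Int -> [a] -> Segment a
kadane ix = bestIntermediate next (I ix ix mempty)

-- Hypothetical input, indexed from 1.
example :: [Sum Int]
example = map Sum [1, -3, 2, 2, -1, 3, -5]

-- Every intermediate segment that maximum chooses among.
steps :: [Segment (Sum Int)]
steps = scanl' next (I 1 1 mempty) example
```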

Finally, we can write the solve function itself. We map R’s and B’s in the input to 1 and -1, respectively, to generate a list of 1’s and -1’s called mountain. Then we call kadane on mountain and on map negate mountain, and pick whichever gives us a better answer. But wait, not quite! Remember that kadane needs a Monoid instance for the list elements, and Int has none. So we can whip up a quick newtype, Step, that has the right instances (note how deriving Num also allows us to use the literal values 1 and -1 as Step values). DerivingVia is quite handy in situations like this.

solve :: ByteString -> Segment Step
solve (C.init -> b) = max (kadane 1 mountain) (kadane 1 (map negate mountain))
  where
    -- R's become 1 and B's become -1 (C.init drops the trailing newline).
    mountain = map (\case { 'R' -> 1; 'B' -> -1 }) (C.unpack b)

newtype Step = Step Int
  deriving (Semigroup, Monoid) via Sum Int
  deriving (Eq, Ord, Num)
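Here is a minimal standalone sketch of the Step newtype on its own, to show what the via clause buys: Semigroup and Monoid come from Sum Int, so <> is addition and mempty is Step 0, while Num comes from the underlying Int. It spells out the deriving strategies explicitly and adds a Show instance for convenience; toSteps and total are hypothetical helpers (toSteps uses if rather than the partial \case from solve).

```haskell
{-# LANGUAGE DerivingStrategies         #-}
{-# LANGUAGE DerivingVia                #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

import Data.Semigroup (Sum (..))

-- Semigroup/Monoid borrowed from Sum Int: (<>) adds, mempty is Step 0.
newtype Step = Step Int
  deriving (Semigroup, Monoid) via Sum Int
  deriving stock (Eq, Ord, Show)
  deriving newtype Num

-- Hypothetical helper: 'R' is a step of 1, anything else a step of -1.
toSteps :: String -> [Step]
toSteps = map (\c -> if c == 'R' then 1 else -1)

-- Combine steps with the derived Monoid instance.
total :: [Step] -> Step
total = mconcat
```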

Next time: Modulo Solitaire and fast BFS

For next time, I invite you to solve Modulo Solitaire. Warning: this is a straightforward BFS problem; the issue is getting a Haskell solution to run fast enough! I struggled with this for quite some time before coming up with something that worked. My ultimate goal in this case is to develop a polished library with an elegant, functional API for doing BFS, but that runs fast under the hood. I’m very curious to see how others might approach the problem.

Posted in competitive programming, haskell

Implementing Hindley-Milner with the unification-fd library

For a current project, I needed to implement type inference for a Hindley-Milner-based type system. (More about that project in an upcoming post!) If you don’t know, Hindley-Milner is what you get when you add polymorphism to the simply typed lambda calculus, but only allow ∀ to show up at the very outermost layer of a type. This is the fundamental basis for many real-world type systems (e.g. OCaml or Haskell without RankNTypes enabled).

One of the core operations in any Hindley-Milner type inference algorithm is unification, where we take two types that might contain unification variables (think “named holes”) and try to make them equal, which might fail, or might provide more information about the values that the unification variables should take. For example, if we try to unify a -> Int and Char -> b, we will learn that a = Char and b = Int; on the other hand, trying to unify a -> Int and (b, Char) will fail, because there is no way to make those two types equal (the first is a function type whereas the second is a pair).
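To make the two examples concrete, here is a minimal hand-rolled unifier. It is not unification-fd, just a sketch over a hypothetical little type language with Int, Char, functions, pairs, and named unification variables. It threads a substitution through a Map, performs the occurs check, and signals failure with Nothing.

```haskell
import qualified Data.Map as M

-- Hypothetical type language for this sketch only.
data Ty = TVar String | TInt | TChar | TFun Ty Ty | TPair Ty Ty
  deriving (Eq, Show)

type Subst = M.Map String Ty

-- Follow variable bindings until reaching an unbound variable or a constructor.
walk :: Subst -> Ty -> Ty
walk sub (TVar v) = maybe (TVar v) (walk sub) (M.lookup v sub)
walk _   t        = t

-- Does variable v occur in t (after resolving bindings)?
occurs :: Subst -> String -> Ty -> Bool
occurs sub v t = case walk sub t of
  TVar w    -> v == w
  TFun a b  -> occurs sub v a || occurs sub v b
  TPair a b -> occurs sub v a || occurs sub v b
  _         -> False

-- Extend the substitution so the two types become equal, or fail.
unify :: Ty -> Ty -> Subst -> Maybe Subst
unify t1 t2 sub = case (walk sub t1, walk sub t2) of
  (TVar a, TVar b) | a == b  -> Just sub
  (TVar a, t)                -> bind a t
  (t, TVar a)                -> bind a t
  (TInt, TInt)               -> Just sub
  (TChar, TChar)             -> Just sub
  (TFun a1 b1, TFun a2 b2)   -> unify a1 a2 sub >>= unify b1 b2
  (TPair a1 b1, TPair a2 b2) -> unify a1 a2 sub >>= unify b1 b2
  _                          -> Nothing  -- mismatched constructors
  where
    bind a t
      | occurs sub a t = Nothing  -- occurs check: would be an infinite type
      | otherwise      = Just (M.insert a t sub)

-- Apply a substitution all the way down.
resolve :: Subst -> Ty -> Ty
resolve sub t = case walk sub t of
  TFun a b  -> TFun (resolve sub a) (resolve sub b)
  TPair a b -> TPair (resolve sub a) (resolve sub b)
  t'        -> t'
```

Unifying a -> Int with Char -> b yields a = Char and b = Int, while unifying a -> Int with (b, Char) returns Nothing because TFun and TPair never match. Roughly speaking, this is the bookkeeping that unification-fd handles for us, generically over any structure functor.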

I’ve implemented this from scratch before, and it was a great learning experience, but I wasn’t looking forward to implementing it again. But then I remembered the unification-fd library and wondered whether I could use it to simplify the implementation. Long story short, although the documentation for unification-fd claims it can be used to implement Hindley-Milner, I couldn’t find any examples online, and apparently neither could anyone else. So I set out to make my own example, and you’re reading it. It turns out that unification-fd is incredibly powerful, but using it can be a bit finicky, so I hope this example can be helpful to others who wish to use it. Along the way, resources I found especially helpful include this basic unification-fd tutorial by the author, Wren Romano, as well as a blog post by Roman Cheplyaka, and the unification-fd Haddock documentation itself. I also referred to the Wikipedia page on Hindley-Milner, which is extremely thorough.

This blog post is rendered automatically from a literate Haskell file; you can find the complete working source code and blog post on GitHub. I’m always happy to receive comments, fixes, or suggestions for improvement.

Prelude: A Bunch of Extensions and Imports

We will make GHC and other people’s libraries work very hard for us. You know the drill.

> {-# LANGUAGE DeriveAnyClass        #-}
> {-# LANGUAGE DeriveFoldable        #-}
> {-# LANGUAGE DeriveFunctor         #-}
> {-# LANGUAGE DeriveGeneric         #-}
> {-# LANGUAGE DeriveTraversable     #-}
> {-# LANGUAGE FlexibleContexts      #-}
> {-# LANGUAGE FlexibleInstances     #-}
> {-# LANGUAGE GADTs                 #-}
> {-# LANGUAGE LambdaCase            #-}
> {-# LANGUAGE MultiParamTypeClasses #-}
> {-# LANGUAGE PatternSynonyms       #-}
> {-# LANGUAGE StandaloneDeriving    #-}
> {-# LANGUAGE UndecidableInstances  #-}
> import           Control.Category ((>>>))
> import           Control.Monad.Except
> import           Control.Monad.Reader
> import           Data.Foldable              (fold)
> import           Data.Functor.Identity
> import           Data.List                  (intercalate)
> import           Data.Map                   (Map)
> import qualified Data.Map                   as M
> import           Data.Maybe
> import           Data.Set                   (Set, (\\))
> import qualified Data.Set                   as S
> import           Prelude                    hiding (lookup)
> import           Text.Printf
> import           Text.Parsec
> import           Text.Parsec.Expr
> import           Text.Parsec.Language       (emptyDef)
> import           Text.Parsec.String
> import qualified Text.Parsec.Token          as L
> import           Control.Unification        hiding ((=:=), applyBindings)
> import qualified Control.Unification        as U
> import           Control.Unification.IntVar
> import           Data.Functor.Fixedpoint
> import           GHC.Generics               (Generic1)
> import           System.Console.Repline

Representing our types

We’ll be implementing a language with lambdas, application, and let-expressions, as well as natural numbers with an addition operation, just to give us a base type and something to do with it. So we will have a natural number type and function types, along with polymorphism, i.e. type variables and forall. (Adding more features like sum and product types, additional base types, recursion, etc. is left as an exercise for the reader!)

So notionally, we want something like this:

type Var = String
data Type = TyVar Var | TyNat | TyFun Type Type

However, when using unification-fd, we have to encode our Type data type (i.e. the thing we want to do unification on) using a “two-level type” (see Tim Sheard’s original paper).

> type Var = String
> data TypeF a = TyVarF Var | TyNatF | TyFunF a a
>   deriving (Show, Eq, Functor, Foldable, Traversable, Generic1, Unifiable)
> type Type = Fix TypeF

TypeF is a “structure functor” that just defines a single level of structure; notice TypeF is not recursive at all, but uses the type parameter a to mark the places where a recursive instance would usually be. unification-fd provides a Fix type to “tie the knot” and make it recursive. (I’m not sure why unification-fd defines its own Fix type instead of using the one from Data.Fix, but perhaps the reason is that it was written before Data.Fix existed—unification-fd was first published in 2007!)
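To see the two-level encoding in isolation, here is a self-contained sketch that declares its own Fix and cata (unification-fd's Data.Functor.Fixedpoint provides counterparts) and folds a Type into a string. The render function and its parenthesized output format are hypothetical.

```haskell
{-# LANGUAGE DeriveFunctor #-}

type Var = String

-- One layer of type structure; `a` marks where subterms go.
data TypeF a = TyVarF Var | TyNatF | TyFunF a a
  deriving (Show, Functor)

-- Tie the knot: a Type is a TypeF whose subterms are Types.
newtype Fix f = Fix (f (Fix f))

type Type = Fix TypeF

-- Fold one layer at a time, bottom-up.
cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg (Fix t) = alg (fmap (cata alg) t)

-- Only this algebra mentions TypeF's constructors; cata does the recursion.
render :: Type -> String
render = cata $ \t -> case t of
  TyVarF v     -> v
  TyNatF       -> "nat"
  TyFunF t1 t2 -> "(" ++ t1 ++ " -> " ++ t2 ++ ")"

-- (a -> nat) -> a
example :: Type
example = Fix (TyFunF (Fix (TyFunF (Fix (TyVarF "a")) (Fix TyNatF))) (Fix (TyVarF "a")))
```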

We have to derive a whole bunch of instances for TypeF which fortunately we get for free. Note in particular Generic1 and Unifiable: Unifiable is a type class from unification-fd which describes how to match up values of our type. Thanks to the work of Roman Cheplyaka, there is a default implementation for Unifiable based on a Generic1 instance—which GHC derives for us in turn—and the default implementation works great for our purposes.
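For intuition about what the Generic1-based default has to produce, here is a hand-written function with what I understand to be the shape of Unifiable's zipMatch method, t a -> t a -> Maybe (t (Either a (a, a))); check the Haddocks for your version of unification-fd. It is a standalone sketch that does not import the library, and zipMatchTypeF is a hypothetical name. Matching constructors yield one layer in which Right pairs mark subterms that still need unifying; mismatched constructors yield Nothing. Note how TyVarF only matches a TyVarF with the same name, which is what will later let type variables behave as rigid constants.

```haskell
{-# LANGUAGE DeriveFunctor #-}

type Var = String

data TypeF a = TyVarF Var | TyNatF | TyFunF a a
  deriving (Eq, Show, Functor)

-- Compare a single layer; Right (x, y) means "now unify x with y".
zipMatchTypeF :: TypeF a -> TypeF a -> Maybe (TypeF (Either a (a, a)))
zipMatchTypeF (TyVarF x) (TyVarF y)
  | x == y                                  = Just (TyVarF x)
zipMatchTypeF TyNatF TyNatF                 = Just TyNatF
zipMatchTypeF (TyFunF a1 b1) (TyFunF a2 b2) = Just (TyFunF (Right (a1, a2)) (Right (b1, b2)))
zipMatchTypeF _ _                           = Nothing  -- different constructors
```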

unification-fd also provides a second type for tying the knot, called UTerm, defined like so:

data UTerm t v
  = UVar  !v               -- ^ A unification variable.
  | UTerm !(t (UTerm t v)) -- ^ Some structure containing subterms.

It’s similar to Fix, except it also adds unification variables of some type v. (If it means anything to you, note that UTerm is actually the free monad over t.) We also define a version of Type using UTerm, which we will use during type inference:

> type UType = UTerm TypeF IntVar
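The free-monad remark has a concrete reading: for a UTerm-like type, >>= replaces each variable with a term, which is exactly substitution. The sketch below re-declares such a type (without the strictness annotations) and writes the instances by hand rather than relying on whatever instances unification-fd provides; render, example, and the trimmed TypeF are hypothetical.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- A trimmed TypeF, without type variables, for this sketch only.
data TypeF a = TyNatF | TyFunF a a
  deriving (Functor)

-- A standalone copy of UTerm's shape.
data UTerm t v = UVar v | UTerm (t (UTerm t v))

instance Functor t => Functor (UTerm t) where
  fmap f (UVar v)  = UVar (f v)
  fmap f (UTerm t) = UTerm (fmap (fmap f) t)

instance Functor t => Applicative (UTerm t) where
  pure = UVar
  mf <*> mx = mf >>= \f -> fmap f mx

-- Bind plugs a term in for every variable: substitution.
instance Functor t => Monad (UTerm t) where
  UVar v  >>= k = k v
  UTerm t >>= k = UTerm (fmap (>>= k) t)

render :: UTerm TypeF String -> String
render (UVar v)               = v
render (UTerm TyNatF)         = "nat"
render (UTerm (TyFunF t1 t2)) = "(" ++ render t1 ++ " -> " ++ render t2 ++ ")"

-- u0 -> u1, substituting u0 := nat -> nat and leaving u1 alone.
example :: UTerm TypeF String
example = UTerm (TyFunF (UVar "u0") (UVar "u1")) >>= \v ->
  if v == "u0" then UTerm (TyFunF (UTerm TyNatF) (UTerm TyNatF)) else UVar v
```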

IntVar is a type provided by unification-fd representing variables as Int values, with a mapping from variables to bindings stored in an IntMap. unification-fd also provides an STVar type which implements variables via STRefs; I presume using STVars would be faster (since no intermediate lookups in an IntMap are required), but it would force us to work in the ST monad. For now I will just stick with IntVar, which makes things simpler.

At this point you might wonder: why do we have type variables in our definition of TypeF, but also use UTerm to add unification variables? Can’t we just get rid of the TyVarF constructor, and let UTerm provide the variables? Well, type variables and unification variables are subtly different, though intimately related. A type variable is actually part of a type, whereas a unification variable is not itself a type, but only stands for some type which is (as yet) unknown. After we are completely done with type inference, we won’t have a UTerm any more, but we might have a type like forall a. a -> a which still contains type variables, so we need a way to represent them. We could only get rid of the TyVarF constructor if we were doing type inference for a language without polymorphism (and yes, unification could still be helpful in such a situation—for example, to do full type reconstruction for the simply typed lambda calculus, where lambdas do not have type annotations).

Polytype represents a polymorphic type, with a forall at the front and a list of bound type variables (note that regular monomorphic types can be represented as Forall [] ty). We don’t need to make an instance of Unifiable for Polytype, since we never unify polytypes, only (mono)types. However, we can have polytypes with unification variables in them, so we need two versions, one containing a Type and one containing a UType.

> data Poly t = Forall [Var] t
>   deriving (Eq, Show, Functor)
> type Polytype  = Poly Type
> type UPolytype = Poly UType

Finally, for convenience, we can make a bunch of pattern synonyms that let us work with Type and UType just as if they were directly recursive types. This isn’t required; I just like not having to write Fix and UTerm everywhere.

> pattern TyVar :: Var -> Type
> pattern TyVar v = Fix (TyVarF v)
> pattern TyNat :: Type
> pattern TyNat = Fix TyNatF
> pattern TyFun :: Type -> Type -> Type
> pattern TyFun t1 t2 = Fix (TyFunF t1 t2)
> pattern UTyNat :: UType
> pattern UTyNat = UTerm TyNatF
> pattern UTyFun :: UType -> UType -> UType
> pattern UTyFun t1 t2 = UTerm (TyFunF t1 t2)
> pattern UTyVar :: Var -> UType
> pattern UTyVar v = UTerm (TyVarF v)
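Here is a self-contained sketch of the synonyms in use, with a hand-rolled Fix so it runs without unification-fd. One addition not in the original code is a COMPLETE pragma, which tells GHC's pattern-match checker that TyVar, TyNat, and TyFun together cover every Type, so functions matching only on the synonyms do not trigger incomplete-pattern warnings. The arity function is a hypothetical example.

```haskell
{-# LANGUAGE PatternSynonyms #-}

type Var = String

data TypeF a = TyVarF Var | TyNatF | TyFunF a a

newtype Fix f = Fix (f (Fix f))

type Type = Fix TypeF

pattern TyVar :: Var -> Type
pattern TyVar v = Fix (TyVarF v)
pattern TyNat :: Type
pattern TyNat = Fix TyNatF
pattern TyFun :: Type -> Type -> Type
pattern TyFun t1 t2 = Fix (TyFunF t1 t2)
{-# COMPLETE TyVar, TyNat, TyFun #-}

-- Direct recursion, written as if Type were an ordinary recursive type:
-- count the arrows along the right spine.
arity :: Type -> Int
arity (TyFun _ res) = 1 + arity res
arity (TyVar _)     = 0
arity TyNat         = 0
```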


Here’s a data type to represent expressions. There’s nothing much interesting to see here, since we don’t need to do anything fancy with expressions. Note that lambdas don’t have type annotations, but let-expressions can have an optional polytype annotation, which will let us talk about checking polymorphic types in addition to inferring them (a lot of presentations of Hindley-Milner don’t talk about this).

> data Expr where
>   EVar  :: Var -> Expr
>   EInt  :: Integer -> Expr
>   EPlus :: Expr -> Expr -> Expr
>   ELam  :: Var -> Expr -> Expr
>   EApp  :: Expr -> Expr -> Expr
>   ELet  :: Var -> Maybe Polytype -> Expr -> Expr -> Expr

Normally at this point we would write parsers and pretty-printers for types and expressions, but that’s boring and has very little to do with unification-fd, so I’ve left those to the end. Let’s get on with the interesting bits!

Type inference infrastructure

Before we get to the type inference algorithm proper, we’ll need to develop a bunch of infrastructure. First, here’s the concrete monad we will be using for type inference. The ReaderT Ctx will keep track of variables and their types; ExceptT TypeError of course allows us to fail with type errors; and IntBindingT is a monad transformer provided by unification-fd which supports various operations such as generating fresh variables and unifying things. Note, for reasons that will become clear later, it’s very important that the IntBindingT is on the bottom of the stack, and the ExceptT comes right above it. Beyond that we can add whatever we like on top.

> type Infer = ReaderT Ctx (ExceptT TypeError (IntBindingT TypeF Identity))

Normally, I would prefer to write everything in a “capability style” where the code is polymorphic in the monad, and just specifies what capabilities/effects it needs (either just using mtl classes directly, or using an effects library like polysemy or fused-effects), but the way the unification-fd API is designed seems to make that a bit tricky.

A type context is a mapping from variable names to polytypes; we also have a function for looking up the type of a variable in the context, and a function for running a local subcomputation in an extended context.

> type Ctx = Map Var UPolytype
> lookup :: Var -> Infer UType
> lookup x = do
>   ctx <- ask
>   maybe (throwError $ UnboundVar x) instantiate (M.lookup x ctx)
> withBinding :: MonadReader Ctx m => Var -> UPolytype -> m a -> m a
> withBinding x ty = local (M.insert x ty)

The lookup function throws an error if the variable is not in the context, and otherwise returns a UType. Conversion from the UPolytype stored in the context to a UType happens via a function called instantiate, which opens up the UPolytype and replaces each of the variables bound by the forall with a fresh unification variable. We will see the implementation of instantiate later.

We will often need to recurse over UTypes. We could just write directly recursive functions ourselves, but there is a better way. Although the unification-fd library provides a function cata for doing a fold over a term built with Fix, it doesn’t provide a counterpart for UTerm; but no matter, we can write one ourselves, like so:

> ucata :: Functor t => (v -> a) -> (t a -> a) -> UTerm t v -> a
> ucata f _ (UVar v) = f v
> ucata f g (UTerm t) = g (fmap (ucata f g) t)
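As a usage sketch, here is a standalone copy of ucata (with a UTerm-shaped type declared locally) used to count unification variables and collect type variable names. countUVars, tyVars, and example are hypothetical helpers; tyVars mirrors the ucata fold in the FreeVars instance below, with a list in place of a Set.

```haskell
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor  #-}

type Var = String

data TypeF a = TyVarF Var | TyNatF | TyFunF a a
  deriving (Functor, Foldable)

data UTerm t v = UVar v | UTerm (t (UTerm t v))

ucata :: Functor t => (v -> a) -> (t a -> a) -> UTerm t v -> a
ucata f _ (UVar v)  = f v
ucata f g (UTerm t) = g (fmap (ucata f g) t)

-- 1 for each unification variable; every other layer just sums its children.
countUVars :: UTerm TypeF v -> Int
countUVars = ucata (const 1) sum

-- Type variable names only; unification variables contribute nothing.
tyVars :: UTerm TypeF v -> [Var]
tyVars = ucata (const []) alg
  where
    alg (TyVarF x) = [x]
    alg layer      = concat layer

-- Hypothetical term: u0 -> (s0 -> u1), with Int unification variables.
example :: UTerm TypeF Int
example = UTerm (TyFunF (UVar 0) (UTerm (TyFunF (UTerm (TyVarF "s0")) (UVar 1))))
```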

Now, we can write some utilities for finding free variables. Inexplicably, IntVar does not have an Ord instance (even though it is literally just a newtype over Int), so we have to derive one if we want to store them in a Set. Notice that our freeVars function finds free unification variables and free type variables; I will talk about why we need this later (this is something I got wrong at first!).

> deriving instance Ord IntVar
> class FreeVars a where
>   freeVars :: a -> Infer (Set (Either Var IntVar))

Finding the free variables in a UType is our first application of ucata. First, to find the free unification variables, we just use the getFreeVars function provided by unification-fd and massage the output into the right form. To find free type variables, we fold using ucata: we ignore unification variables, capture a singleton set in the TyVarF case, and in the recursive case we call fold, which will turn a TypeF (Set ...) into a Set ... using the Monoid instance for Set, i.e. union.

> instance FreeVars UType where
>   freeVars ut = do
>     fuvs <- fmap (S.fromList . map Right) . lift . lift $ getFreeVars ut
>     let ftvs = ucata (const S.empty)
>                      (\case {TyVarF x -> S.singleton (Left x); f -> fold f})
>                      ut
>     return $ fuvs `S.union` ftvs

Why don’t we just find free unification variables with ucata at the same time as the free type variables, and forget about using getFreeVars? Well, I looked at the source, and getFreeVars is actually a complicated beast. I’m really not sure what it’s doing, and I don’t trust that just manually getting the unification variables ourselves would be doing the right thing!

Now we can leverage the above instance to find free variables in UPolytypes and type contexts. For a UPolytype, we of course have to subtract off any variables bound by the forall.

> instance FreeVars UPolytype where
>   freeVars (Forall xs ut) = (\\ (S.fromList (map Left xs))) <$> freeVars ut
> instance FreeVars Ctx where
>   freeVars = fmap S.unions . mapM freeVars . M.elems

And here’s a simple utility function to generate fresh unification variables, built on top of the freeVar function provided by unification-fd:

> fresh :: Infer UType
> fresh = UVar <$> lift (lift freeVar)

One thing to note is the annoying calls to lift we have to do in the definition of FreeVars for UType, and in the definition of fresh. The getFreeVars and freeVar functions provided by unification-fd have to run in a monad which is an instance of BindingMonad, and there are no BindingMonad instances for mtl transformers. We could write our own instances so that these functions would work in our Infer monad automatically, but honestly that sounds like a lot of work. Sprinkling a few lifts here and there isn’t so bad.

Next, a data type to represent type errors, and an instance of the Fallible class, needed so that unification-fd can inject errors into our error type when it encounters unification errors. Basically we just need to provide two specific constructors to represent an “occurs check” failure (i.e. an infinite type), or a unification mismatch failure.

> data TypeError where
>   UnboundVar   :: String -> TypeError
>   Infinite     :: IntVar -> UType -> TypeError
>   Mismatch     :: TypeF UType -> TypeF UType -> TypeError
> instance Fallible TypeF IntVar TypeError where
>   occursFailure   = Infinite
>   mismatchFailure = Mismatch

The =:= operator provided by unification-fd is how we unify two types. It has a kind of bizarre type:

(=:=) :: ( BindingMonad t v m, Fallible t v e
         , MonadTrans em, Functor (em m), MonadError e (em m))
      => UTerm t v -> UTerm t v -> em m (UTerm t v)

(Apparently I am not the only one who thinks this type is bizarre; the unification-fd source code contains the comment -- TODO: what was the reason for the MonadTrans madness?)

I had to stare at this for a while to understand it. It says that the output will be in some BindingMonad (such as IntBindingT), and there must be a single error monad transformer on top of it, with an error type that implements Fallible. So =:= can return ExceptT TypeError (IntBindingT TypeF Identity) UType, but it cannot be used directly in our Infer monad, because that has a ReaderT on top of the ExceptT. So I just made my own version with an extra lift to get it to work directly in the Infer monad. While we’re at it, we’ll make a lifted version of applyBindings, which has the same issue.

> (=:=) :: UType -> UType -> Infer UType
> s =:= t = lift $ s U.=:= t
> applyBindings :: UType -> Infer UType
> applyBindings = lift . U.applyBindings

Converting between mono- and polytypes

Central to the way Hindley-Milner works is the way we move back and forth between polytypes and monotypes. First, let’s see how to turn UPolytypes into UTypes, hinted at earlier in the definition of the lookup function. To instantiate a UPolytype, we generate a fresh unification variable for each variable bound by the Forall, and then substitute them throughout the type.

> instantiate :: UPolytype -> Infer UType
> instantiate (Forall xs uty) = do
>   xs' <- mapM (const fresh) xs
>   return $ substU (M.fromList (zip (map Left xs) xs')) uty

The substU function can substitute for either kind of variable in a UType (right now we only need it to substitute for type variables, but we will need it to substitute for unification variables later). Of course, it is implemented via ucata. In the variable cases we make sure to leave the variable alone if it is not a key in the given substitution mapping. In the recursive non-variable case, we just roll up the TypeF UType into a UType by applying UTerm. This is the power of ucata: we can deal with all the boring recursive cases in one fell swoop. This function won’t have to change if we add new types to the language in the future.

> substU :: Map (Either Var IntVar) UType -> UType -> UType
> substU m = ucata
>   (\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
>   (\case
>       TyVarF v -> fromMaybe (UTyVar v) (M.lookup (Left v) m)
>       f -> UTerm f
>   )
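To see one Map keyed by Either Var IntVar handle both kinds of variable, here is a standalone sketch of substU. It substitutes plain Int for IntVar and re-declares UTerm and ucata so it runs without unification-fd; render and example are hypothetical.

```haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase    #-}

import qualified Data.Map   as M
import           Data.Maybe (fromMaybe)

type Var = String

data TypeF a = TyVarF Var | TyNatF | TyFunF a a
  deriving (Functor)

data UTerm t v = UVar v | UTerm (t (UTerm t v))

-- Int stands in for unification-fd's IntVar here.
type UType = UTerm TypeF Int

ucata :: Functor t => (v -> a) -> (t a -> a) -> UTerm t v -> a
ucata f _ (UVar v)  = f v
ucata f g (UTerm t) = g (fmap (ucata f g) t)

-- Left keys replace type variables, Right keys replace unification variables;
-- anything not in the map is left alone.
substU :: M.Map (Either Var Int) UType -> UType -> UType
substU m = ucata
  (\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
  (\case
      TyVarF v -> fromMaybe (UTerm (TyVarF v)) (M.lookup (Left v) m)
      f        -> UTerm f)

render :: UType -> String
render = ucata (\v -> 'u' : show v) $ \case
  TyVarF v     -> v
  TyNatF       -> "nat"
  TyFunF t1 t2 -> "(" ++ t1 ++ " -> " ++ t2 ++ ")"

-- Hypothetical term: a -> u0.
example :: UType
example = UTerm (TyFunF (UTerm (TyVarF "a")) (UVar 0))
```

Substituting for Left "a" is what instantiate does; substituting for Right keys is the other half the general version allows.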

There is one other way to convert a UPolytype to a UType, which happens when we want to check that an expression has a polymorphic type specified by the user. For example, let foo : forall a. a -> a = \x.3 in ... should be a type error, because the user specified that foo should have type forall a. a -> a, but then gave the implementation \x.3 which is too specific. In this situation we can’t just instantiate the polytype—that would create a unification variable for a, and while typechecking \x.3 it would unify a with nat. But in this case we don’t want a to unify with nat—it has to be held entirely abstract, because the user’s claim is that this function will work for any type a.

Rather than unification variables, we want to generate what are known as Skolem variables, which do not unify with anything other than themselves. So how can we get unification-fd to do that? It does not have any built-in notion of Skolem variables. What we can do instead is to just embed the variables within the UType as UTyVars instead of UVars! unification-fd does not even know those are variables; it just sees them as another rigid part of the structure that must be matched exactly, just as a TyFun has to match another TyFun. The one remaining issue is that we need to generate fresh Skolem variables; it certainly would not do to have them collide with Skolem variables from some other forall. We could carry around our own supply of unique names in the Infer monad for this purpose, which would probably be the “proper” way to do things; but for now I did something more expedient: just get unification-fd to generate fresh unification variables, then rip the (unique! fresh!) Ints out of them and use those to make our Skolem variables.

> skolemize :: UPolytype -> Infer UType
> skolemize (Forall xs uty) = do
>   xs' <- mapM (const fresh) xs
>   return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty
>   where
>     toSkolem (UVar v) = UTyVar (mkVarName "s" v)

When unification-fd generates fresh IntVars it seems that it starts at minBound :: Int and increments, so we can add maxBound + 1 to get numbers that look reasonable. Again, this is not very principled, but for this toy example, who cares?

> mkVarName :: String -> IntVar -> Var
> mkVarName nm (IntVar v) = nm ++ show (v + (maxBound :: Int) + 1)
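The offset works because, with GHC's two's-complement Int, minBound + maxBound + 1 == 0, so the k-th fresh Int, minBound + k, is displayed as k. Since the addition associates to the left, v + maxBound cannot overflow for values near minBound. A quick standalone check, assuming (as above) that the supply starts at minBound and using plain Int in place of IntVar:

```haskell
-- Plain Int stands in for the IntVar wrapper.
mkVarName :: String -> Int -> String
mkVarName nm v = nm ++ show (v + (maxBound :: Int) + 1)

-- The first three "fresh" Ints, if the supply starts at minBound.
firstNames :: [String]
firstNames = map (mkVarName "a") [minBound, minBound + 1, minBound + 2]
```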

Next, how do we convert from a UType back to a UPolytype? This happens when we have inferred the type of a let-bound variable and go to put it in the context; typically, Hindley-Milner systems generalize the inferred type to a polytype. If a unification variable still occurs free in a type, it means it was not constrained at all, so we can universally quantify over it. However, we have to be careful: unification variables that occur in some type that is already in the context do not count as free, because we might later discover that they need to be constrained.

Also, just before we do the generalization, it’s very important that we use applyBindings. unification-fd has been collecting a substitution from unification variables to types, but for efficiency’s sake it does not actually apply the substitution until we ask it to, by calling applyBindings. Any unification variables which still remain after applyBindings really are unconstrained so far. So after applyBindings, we get the free unification variables from the type, subtract off any unification variables which are free in the context, and close over the remaining variables with a forall, substituting normal type variables for them. It does not particularly matter if these type variables are fresh (so long as they are distinct). But we can’t look only at unification variables! We have to look at free type variables too (this is the reason that our freeVars function needs to find both free type and unification variables). Why is that? Well, we might have some free type variables floating around if we previously generated some Skolem variables while checking a polymorphic type. (A term which illustrates this behavior is \y. let x : forall a. a -> a = y in x 3.) Free Skolem variables should also be generalized over.

> generalize :: UType -> Infer UPolytype
> generalize uty = do
>   uty' <- applyBindings uty
>   ctx <- ask
>   tmfvs  <- freeVars uty'
>   ctxfvs <- freeVars ctx
>   let fvs = S.toList $ tmfvs \\ ctxfvs
>       xs  = map (either id (mkVarName "a")) fvs
>   return $ Forall xs (substU (M.fromList (zip fvs (map UTyVar xs))) uty')
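Stripped of unification-fd, the heart of generalize is a set difference followed by a renaming. Here is a standalone sketch over a hypothetical UType with both named type variables and Int-numbered unification variables, assuming applyBindings has already been applied and that the context's free variables have been collected into a Set. As in the real version, free type variables (such as leftover Skolems) are quantified too.

```haskell
import           Data.Set (Set, (\\))
import qualified Data.Set as S

type Var = String

-- Hypothetical monotypes: named type variables plus Int-numbered
-- unification variables (standing in for UTerm TypeF IntVar).
data UType = TyVar Var | UVar Int | TyNat | TyFun UType UType
  deriving (Eq, Show)

data Poly = Forall [Var] UType
  deriving (Eq, Show)

-- Free type variables (Left) and unification variables (Right).
freeVars :: UType -> Set (Either Var Int)
freeVars (TyVar x)   = S.singleton (Left x)
freeVars (UVar u)    = S.singleton (Right u)
freeVars TyNat       = S.empty
freeVars (TyFun a b) = freeVars a `S.union` freeVars b

-- Quantify over whatever is free in the type but not in the context,
-- replacing unification variable u by a type variable named "a<u>".
generalize :: Set (Either Var Int) -> UType -> Poly
generalize ctxFvs ty = Forall xs (subst ty)
  where
    fvs = S.toList (freeVars ty \\ ctxFvs)
    xs  = map (either id (\u -> "a" ++ show u)) fvs
    ren = zip fvs xs
    subst t = case t of
      TyVar x   -> maybe t TyVar (lookup (Left x) ren)
      UVar u    -> maybe t TyVar (lookup (Right u) ren)
      TyNat     -> TyNat
      TyFun a b -> TyFun (subst a) (subst b)
```

A unification variable mentioned in the context stays a bare UVar, since we might still learn more about it.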

Finally, we need a way to convert Polytypes entered by the user into UPolytypes, and a way to convert the final UPolytype back into a Polytype. unification-fd provides functions unfreeze :: Fix t -> UTerm t v and freeze :: UTerm t v -> Maybe (Fix t) to convert between terms built with UTerm (with unification variables) and Fix (without unification variables). Converting to UPolytype is easy: we just use unfreeze to convert the underlying Type to a UType.

> toUPolytype :: Polytype -> UPolytype
> toUPolytype = fmap unfreeze

When converting back, notice that freeze returns a Maybe; it fails if there are any unification variables remaining. So we must be careful to only use fromUPolytype when we know there are no unification variables remaining in a polytype. In fact, we will use this only at the very top level, after generalizing the type that results from inference over a top-level term. Since at the top level we only perform inference on closed terms, in an empty type context, the final generalize step will generalize over all the remaining free unification variables, since there will be no free variables in the context.

> fromUPolytype :: UPolytype -> Polytype
> fromUPolytype = fmap (fromJust . freeze)

Type inference

Finally, the type inference algorithm proper! First, to check that an expression has a given type, we infer the type of the expression and then demand (via =:=) that the inferred type must be equal to the given one. Note that =:= actually returns a UType, and it can apparently be more efficient to use the result of =:= in preference to either of the arguments to it (although they will all give equivalent results). However, in our case this doesn’t seem to make much difference.

> check :: Expr -> UType -> Infer ()
> check e ty = do
>   ty' <- infer e
>   _ <- ty =:= ty'
>   return ()

And now for the infer function. The EVar, EInt, and EPlus cases are straightforward.

> infer :: Expr -> Infer UType
> infer (EVar x)      = lookup x
> infer (EInt _)      = return UTyNat
> infer (EPlus e1 e2) = do
>   check e1 UTyNat
>   check e2 UTyNat
>   return UTyNat

For an application EApp e1 e2, we infer the types funTy and argTy of e1 and e2 respectively, then demand that funTy =:= UTyFun argTy resTy for a fresh unification variable resTy. Again, =:= returns a more efficient UType which is equivalent to funTy, but we don’t need to use that type directly (we return resTy instead), so we just discard the result.

> infer (EApp e1 e2) = do
>   funTy <- infer e1
>   argTy <- infer e2
>   resTy <- fresh
>   _ <- funTy =:= UTyFun argTy resTy
>   return resTy

For a lambda, we make up a fresh unification variable for the type of the argument, then infer the type of the body under an extended context. Notice how we promote the freshly generated unification variable to a UPolytype by wrapping it in Forall []; we do not generalize it, since that would turn it into forall a. a! We just want the lambda argument to have the bare unification variable as its type.

> infer (ELam x body) = do
>   argTy <- fresh
>   withBinding x (Forall [] argTy) $ do
>     resTy <- infer body
>     return $ UTyFun argTy resTy

For a let expression without a type annotation, we infer the type of the definition, then generalize it and add it to the context to infer the type of the body. It is traditional for Hindley-Milner systems to generalize let-bound things this way (although note that with -XMonoLocalBinds enabled, which is automatically implied by -XGADTs or -XTypeFamilies, GHC generalizes a local let-binding only if it is closed, roughly meaning it does not mention any non-closed locally bound variables from an enclosing scope).

> infer (ELet x Nothing xdef body) = do
>   ty <- infer xdef
>   pty <- generalize ty
>   withBinding x pty $ infer body
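For comparison, GHC itself behaves this way on a closed let-binding. In this hypothetical snippet, ident is generalized to forall a. a -> a, so the body can use it at both Char and Bool; a closed binding like this one is generalized even with MonoLocalBinds on.

```haskell
-- ident is let-bound and closed, so it is generalized to
-- forall a. a -> a and can be used at two different types.
both :: (Char, Bool)
both = let ident = \x -> x in (ident 'c', ident True)

-- By contrast, a lambda-bound variable only gets a monotype, so this
-- version is rejected (Char and Bool cannot be unified):
-- bad = (\ident -> (ident 'c', ident True)) (\x -> x)
```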

For a let expression with a type annotation, we skolemize it and check the definition with the skolemized type; the rest is the same as the previous case.

> infer (ELet x (Just pty) xdef body) = do
>   let upty = toUPolytype pty
>   upty' <- skolemize upty
>   check xdef upty'
>   withBinding x upty $ infer body

Running the Infer monad

We need a way to run computations in the Infer monad. This is a bit fiddly, and it took me a long time to put all the pieces together. (But typed holes are sooooo great! It would have taken me way longer without them…) I’ve written the definition of runInfer using the left-to-right function composition operator (>>>), i.e. (.) with its arguments flipped, so that the pipeline flows from top to bottom and I can intersperse it with explanation.

> runInfer :: Infer UType -> Either TypeError Polytype
> runInfer

The first thing we do is use applyBindings to make sure that we substitute for any unification variables that we know about. This results in another Infer UType.

>   =   (>>= applyBindings)

We can now generalize over any unification variables that are left, and then convert from UPolytype to Polytype. Again, this conversion is safe because at this top level we know we will be in an empty context, so the generalization step will definitely get rid of all the remaining unification variables.

>   >>> (>>= (generalize >>> fmap fromUPolytype))

Now all that’s left is to interpret the layers of our Infer monad one by one. As promised, we start with an empty type context.

>   >>> flip runReaderT M.empty
>   >>> runExceptT
>   >>> evalIntBindingT
>   >>> runIdentity
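The same peel-one-layer-at-a-time pattern works for any transformer stack. Here is a standalone sketch with a hypothetical Toy stack (ReaderT over ExceptT over Identity, from the transformers package) and a runner written in the same >>> style; the outermost layer, ReaderT, is run first.

```haskell
import Control.Arrow              ((>>>))
import Control.Monad.Trans.Class  (lift)
import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE)
import Control.Monad.Trans.Reader (ReaderT, ask, runReaderT)
import Data.Functor.Identity      (Identity, runIdentity)

-- A hypothetical stack: an Int environment, String errors, no base effects.
type Toy = ReaderT Int (ExceptT String Identity)

-- Divide by the divisor stored in the environment, failing on zero.
divideBy :: Int -> Toy Int
divideBy n = do
  d <- ask
  if d == 0
    then lift (throwE "division by zero")
    else return (n `div` d)

-- Peel the layers from the outside in, just like runInfer.
runToy :: Int -> Toy a -> Either String a
runToy env
  =   flip runReaderT env
  >>> runExceptT
  >>> runIdentity
```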

Finally, we can make a top-level function to infer a polytype for an expression, just by composing infer and runInfer.

> inferPolytype :: Expr -> Either TypeError Polytype
> inferPolytype = runInfer . infer


To be able to test things out, we can make a very simple REPL that takes input from the user and tries to parse, typecheck, and interpret it, printing either the results or an appropriate error message.

> eval :: String -> IO ()
> eval s = case parse expr "" s of
>   Left err -> print err
>   Right e -> case inferPolytype e of
>     Left tyerr -> putStrLn $ pretty tyerr
>     Right ty   -> do
>       putStrLn $ pretty e ++ " : " ++ pretty ty
>       when (ty == Forall [] TyNat) $ putStrLn $ pretty (interp e)
> main :: IO ()
> main = evalRepl (const (pure "HM> ")) (liftIO . eval) [] Nothing Nothing (Word (const (return []))) (return ()) (return Exit)

Here are a few examples to try out:

HM> 2 + 3
2 + 3 : nat
HM> \x. x
\x. x : forall a0. a0 -> a0
HM> \x.3
\x. 3 : forall a0. a0 -> nat
HM> \x. x + 1
\x. x + 1 : nat -> nat
HM> (\x. 3) (\y.y)
(\x. 3) (\y. y) : nat
HM> \x. y
Unbound variable y
HM> \x. x x
Infinite type u0 = u0 -> u1
HM> 3 3
Can't unify nat and nat -> u0
HM> let foo : forall a. a -> a = \x.3 in foo 5
Can't unify s0 and nat
HM> \f.\g.\x. f (g x)
\f. \g. \x. f (g x) : forall a2 a3 a4. (a3 -> a4) -> (a2 -> a3) -> a2 -> a4
HM> let f : forall a. a -> a = \x.x in let y : forall b. b -> b -> b = \z.\q. f z in y 2 3
let f : forall a. a -> a = \x. x in let y : forall b. b -> b -> b = \z. \q. f z in y 2 3 : nat
HM> \y. let x : forall a. a -> a = y in x 3
\y. let x : forall a. a -> a = y in x 3 : forall s1. (s1 -> s1) -> nat
HM> (\x. let y = x in y) (\z. \q. z)
(\x. let y = x in y) (\z. \q. z) : forall a1 a2. a1 -> a2 -> a1

And that’s it! Feel free to play around with this yourself, and adapt the code for your own projects if it’s helpful. And please do report any typos or bugs that you find.

Below, for completeness, you will find a simple, recursive, environment-passing interpreter, along with code for parsing and pretty-printing. I don’t give any commentary on them because, for the most part, they are straightforward and have nothing to do with unification-fd. But you are certainly welcome to look at them if you want to see how they work. The one interesting thing to say about the parser for types is that it checks that types entered by the user do not contain any free variables, and fails if they do. The parser is not really the right place to do this check, but again, it was expedient for this toy example. Also, I tend to use megaparsec for serious projects, but I had some parsec code for parsing something similar to this toy language lying around, so I just reused that.


> data Value where
>   VInt :: Integer -> Value
>   VClo :: Var -> Expr -> Env -> Value
> type Env = Map Var Value
> interp :: Expr -> Value
> interp = interp' M.empty
> interp' :: Env -> Expr -> Value
> interp' env (EVar x) = fromJust $ M.lookup x env
> interp' _   (EInt n) = VInt n
> interp' env (EPlus ea eb) =
>   case (interp' env ea, interp' env eb) of
>     (VInt va, VInt vb) -> VInt (va + vb)
>     _ -> error "Impossible! interp' EPlus on non-Ints"
> interp' env (ELam x body) = VClo x body env
> interp' env (EApp fun arg) =
>   case interp' env fun of
>     VClo x body env' ->
>       interp' (M.insert x (interp' env arg) env') body
>     _ -> error "Impossible! interp' EApp on non-closure"
> interp' env (ELet x _ xdef body) =
>   let xval = interp' env xdef
>   in  interp' (M.insert x xval env) body


> lexer :: L.TokenParser u
> lexer = L.makeTokenParser emptyDef
>   { L.reservedNames = ["let", "in", "forall", "nat"] }
> parens :: Parser a -> Parser a
> parens = L.parens lexer
> identifier :: Parser String
> identifier = L.identifier lexer
> reserved :: String -> Parser ()
> reserved = L.reserved lexer
> reservedOp :: String -> Parser ()
> reservedOp = L.reservedOp lexer
> symbol :: String -> Parser String
> symbol = L.symbol lexer
> integer :: Parser Integer
> integer = L.natural lexer
> parseAtom :: Parser Expr
> parseAtom
>   =   EVar  <$> identifier
>   <|> EInt  <$> integer
>   <|> ELam  <$> (symbol "\\" *> identifier)
>             <*> (symbol "." *> parseExpr)
>   <|> ELet  <$> (reserved "let" *> identifier)
>             <*> optionMaybe (symbol ":" *> parsePolytype)
>             <*> (symbol "=" *> parseExpr)
>             <*> (reserved "in" *> parseExpr)
>   <|> parens parseExpr
> parseApp :: Parser Expr
> parseApp = chainl1 parseAtom (return EApp)
> parseExpr :: Parser Expr
> parseExpr = buildExpressionParser table parseApp
>   where
>     table = [ [ Infix (EPlus <$ reservedOp "+") AssocLeft ]
>             ]
> parsePolytype :: Parser Polytype
> parsePolytype = do
>   pty@(Forall xs ty) <- parsePolytype'
>   let fvs :: Set Var
>       fvs = flip cata ty $ \case
>         TyVarF x       -> S.singleton x
>         TyNatF         -> S.empty
>         TyFunF xs1 xs2 -> xs1 `S.union` xs2
>       unbound = fvs \\ S.fromList xs
>   unless (S.null unbound) $ fail $ "Unbound type variables: " ++ unwords (S.toList unbound)
>   return pty
> parsePolytype' :: Parser Polytype
> parsePolytype' =
>   Forall <$> (fromMaybe [] <$> optionMaybe (reserved "forall" *> many identifier <* symbol "."))
>           <*> parseType
> parseTypeAtom :: Parser Type
> parseTypeAtom =
>   (TyNat <$ reserved "nat") <|> (TyVar <$> identifier) <|> parens parseType
> parseType :: Parser Type
> parseType = buildExpressionParser table parseTypeAtom
>   where
>     table = [ [ Infix (TyFun <$ symbol "->") AssocRight ] ]
> expr :: Parser Expr
> expr = spaces *> parseExpr <* eof

Pretty printing

> type Prec = Int
> class Pretty p where
>   pretty :: p -> String
>   pretty = prettyPrec 0
>   prettyPrec :: Prec -> p -> String
>   prettyPrec _ = pretty
> instance Pretty (t (Fix t)) => Pretty (Fix t) where
>   prettyPrec p = prettyPrec p . unFix
> instance Pretty t => Pretty (TypeF t) where
>   prettyPrec _ (TyVarF v) = v
>   prettyPrec _ TyNatF = "nat"
>   prettyPrec p (TyFunF ty1 ty2) =
>     mparens (p > 0) $ prettyPrec 1 ty1 ++ " -> " ++ prettyPrec 0 ty2
> instance (Pretty (t (UTerm t v)), Pretty v) => Pretty (UTerm t v) where
>   pretty (UTerm t) = pretty t
>   pretty (UVar v)  = pretty v
> instance Pretty Polytype where
>   pretty (Forall [] t) = pretty t
>   pretty (Forall xs t) = unwords ("forall" : xs) ++ ". " ++ pretty t
> mparens :: Bool -> String -> String
> mparens True  = ("("++) . (++")")
> mparens False = id
> instance Pretty Expr where
>   prettyPrec _ (EVar x) = x
>   prettyPrec _ (EInt i) = show i
>   prettyPrec p (EPlus e1 e2) =
>     mparens (p>1) $
>       prettyPrec 1 e1 ++ " + " ++ prettyPrec 2 e2
>   prettyPrec p (ELam x body) =
>     mparens (p>0) $
>       "\\" ++ x ++ ". " ++ prettyPrec 0 body
>   prettyPrec p (ELet x mty xdef body) =
>     mparens (p>0) $
>       "let " ++ x ++ maybe "" (\ty -> " : " ++ pretty ty) mty
>             ++ " = " ++ prettyPrec 0 xdef
>             ++ " in " ++ prettyPrec 0 body
>   prettyPrec p (EApp e1 e2) =
>     mparens (p>2) $
>       prettyPrec 2 e1 ++ " " ++ prettyPrec 3 e2
> instance Pretty IntVar where
>   pretty = mkVarName "u"
> instance Pretty TypeError where
>   pretty (UnboundVar x)     = printf "Unbound variable %s" x
>   pretty (Infinite x ty)    = printf "Infinite type %s = %s" (pretty x) (pretty ty)
>   pretty (Mismatch ty1 ty2) = printf "Can't unify %s and %s" (pretty ty1) (pretty ty2)
> instance Pretty Value where
>   pretty (VInt n) = show n
>   pretty (VClo x body env)
>     = printf "<%s: %s %s>"
>       x (pretty body) (pretty env)
> instance Pretty Env where
>   pretty env = "[" ++ intercalate ", " bindings ++ "]"
>     where
>       bindings = map prettyBinding (M.assocs env)
>       prettyBinding (x, v) = x ++ " -> " ++ pretty v
Posted in haskell, teaching

Competitive programming in Haskell: monoidal accumulation

In my last competitive programming post, I challenged you to solve Please, Go First. In that problem, we are presented with a hypothetical scenario with people waiting in a queue for a ski lift. Each person is part of a friend group (possibly just themselves), but friend groups are not necessarily consecutive in line; when someone gets to the top they will wait for the last person in their friend group to arrive before skiing. We are asked to consider how much waiting time could be saved if people start letting others go ahead of them in line as long as it doesn’t cost them any waiting time and decreases the waiting time for the others.

There is actually a bit of ambiguity that we need to resolve first; to be honest, it’s not the most well-written problem statement. Consider this scenario, with three people in group A and two in group b:


Consider the person labelled b_1. Should they let A_2 pass? Letting A_2 pass would not change b_1’s waiting time: they have to wait for b_2 anyway and it does not matter whether they do the waiting at the top or bottom of the mountain. But it would not immediately change A_2’s waiting time, either: they still have to wait for A_3. What the problem literally says is “someone lets another pass if doing this doesn’t change his own total waiting time, but saves time for the other person”, so taking this literally would seem to imply that in this scenario b_1 does not let A_2 pass. However, the given example inputs and outputs imply that in this scenario b_1 should let A_2 pass; indeed, right after doing so, b_1 can then let A_3 pass as well, which saves time for both A_3 and A_2. So in the end, it seems we really want to say something like “x should let y pass if it doesn’t increase x’s waiting time and will eventually save time for y”.

The solution idea

It took me an embarrassingly long time to come up with the following key insight: after doing this process as much as possible, I claim that (1) all the friends within each friend group will be consecutive, and (2) the groups will be sorted by the original position of the last person in each group. To see why claim (2) is true, note that whenever someone is last in their friend group, moving backward in the line always increases their waiting time; so any two people who are both last in their friend group will never pass each other, since it would make the waiting time worse for the one who moves backward. That means the people who are last in their friend group will always remain in the same relative order. As for claim (1), I thought about it for a while and am so far unable to come up with a short, convincing proof, though I still believe it is true (and my solution based on it was accepted). If anyone has a good way to show why this must be true, I’d love to hear about it in the comments.

My second key insight is that the total amount of time saved for a given friend group depends only on (1) how many people are in the group and (2) how many places the last person in the group gets to move up. (There are other ways to solve the problem; more on those below.) In particular, the total time saved for the group will be the product of these two numbers, times five minutes. It’s irrelevant how many places someone moves if they are not last in their group, because they have to wait until that last person arrives, and it makes no difference if they do their waiting at the top or bottom of the mountain.

My solution

So here’s my solution, based on the above insights. First, let’s set up the main pipeline to read the input, solve each test case, and produce the output.

main = C.interact $
  runScanner (numberOf (int *> str)) >>> map (solve >>> showB) >>> C.unlines

showB is just a utility function I’ve recently added to my solution template which calls show and then converts the result to a ByteString using pack.
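The post doesn't show showB itself; assuming the lazy Data.ByteString.Lazy.Char8 module is imported as C, a plausible one-line version is:

```haskell
import qualified Data.ByteString.Lazy.Char8 as C

-- Render a value with show, then pack the String into a ByteString.
-- (A guess at showB; the template's actual definition is not shown.)
showB :: Show a => a -> C.ByteString
showB = C.pack . show
```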

For a given test case, we need to first do a pass through the lift queue in order to accumulate some information about friend groups: for each group, we need to know how big it is, as well as the index of the last member of the group. In an imperative language, we would make accumulator variables to hold this information (probably two maps, aka dictionaries), and then iterate through the queue, imperatively updating the accumulator variables for each item. We can translate that approach more or less mechanically into Haskell, by having an update function that takes a single item and a tuple of accumulators as input, and returns a new tuple of accumulators as output. This is the approach taken by Aaron Allen, and sometimes that’s the best way to do something like this. However, in this particular scenario—looping over a list and accumulating some information—the accumulators are often monoidal, which gives us much nicer tools to work with, such as foldMap and Data.Map.fromListWith (<>).
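To make the contrast concrete, here is a hypothetical sketch computing the two facts we need (group size and index of the last member) for a single group c, first with an explicit update function threaded through foldl', then with foldMap over a pair of monoids. The function names are made up, and I use Max rather than the First trick described below, so this illustrates the style rather than reproducing the post's code.

```haskell
import Data.List (foldl')
import Data.Semigroup (Max (..), Sum (..))

-- Input: queue positions paired with group labels, e.g. zip [0 ..] "AbA".

-- Accumulator style: an explicit update function threads a tuple
-- (size, index of last member) for group c through a left fold.
groupStatsLoop :: Char -> [(Int, Char)] -> (Int, Int)
groupStatsLoop c = foldl' step (0, -1)
  where
    step (size, lastIx) (i, d)
      | d == c    = (size + 1, max lastIx i)
      | otherwise = (size, lastIx)

-- Monoidal style: map each member to a pair of monoids and let foldMap
-- combine them; the Monoid instances do all the bookkeeping.
-- (Max Int is a Monoid because Int is Bounded.)
groupStatsMonoid :: Char -> [(Int, Char)] -> (Int, Int)
groupStatsMonoid c items = (getSum size, getMax lastIx)
  where
    (size, lastIx) =
      foldMap (\(i, _) -> (Sum 1, Max i)) (filter ((== c) . snd) items)
```

The monoidal version never mentions an initial state or an update step; adding a third statistic is just a matter of adding a third component to the tuple.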

We’ll make a type Group to represent the needed information about a friend group: the number of people and the index of the last person. We can use DerivingVia to create an appropriate Semigroup instance for it (in this case we actually don’t need Monoid since there is no such thing as an empty group). Note that we use First Int instead of the expected Last Int; this is explained below.

newtype Group = Group { unGroup :: (Int, Int) }
  deriving Semigroup via (Sum Int, First Int)
  deriving Show

Now we can write the code to calculate the total time saved for a given starting queue.

solve :: ByteString -> Int
solve (C.unpack -> queue) = timeSaved
  where

We first map over the queue and turn each item into a singleton Group (imap is a utility to do an indexed map, with type (Int -> a -> b) -> [a] -> [b]); then we use M.fromListWith (<>) to build a Map associating each distinct character to a Group. The Semigroup instance will take care of summing the number of friends and keeping only the last index in each group. Note that fromListWith is implemented via a left fold, which explains why we needed to use First Int instead of Last Int: the list items will actually be combined in reverse order. (Alternatively, we could use Last Int and M.fromListWith (flip (<>)); of course, this is only something we need to worry about when using a non-commutative Semigroup).
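The ordering subtlety is easy to check in GHCi. In this sketch imap is a hypothetical stand-in (the post's own definition isn't shown), and the sample string "abab" is made up; the behaviour of M.fromListWith, which combines a new value with an existing one as f new old, is as documented for Data.Map.

```haskell
import qualified Data.Map as M
import Data.Semigroup (First (..), Last (..))

-- Hypothetical imap: an indexed map over a list
imap :: (Int -> a -> b) -> [a] -> [b]
imap f = zipWith f [0 ..]

-- fromListWith f combines a new value with the existing one as f new old,
-- so with (++) the later list element ends up in front: "yx", not "xy".
orderDemo :: M.Map Char String
orderDemo = M.fromListWith (++) [('a', "x"), ('a', "y")]

-- Hence (<>) on First keeps the index that appears later in the list...
lastViaFirst :: M.Map Char Int
lastViaFirst =
  fmap getFirst (M.fromListWith (<>) (imap (\i c -> (c, First i)) "abab"))

-- ...while Last needs flip (<>) to give the same result.
lastViaLast :: M.Map Char Int
lastViaLast =
  fmap getLast (M.fromListWith (flip (<>)) (imap (\i c -> (c, Last i)) "abab"))
```

Both lastViaFirst and lastViaLast map 'a' to 2 and 'b' to 3, the index of the last occurrence of each character.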

    groupInfo :: Map Char Group
    groupInfo = queue >$> imap (\i c -> (c, Group (1, i))) >>> M.fromListWith (<>)

Now we can sort the queue by index of the last member of each friend group, producing its final form:

    sortedQueue = sortOn ((groupInfo!) >>> unGroup >>> snd) queue

Computing the total time saved is now just a matter of figuring out how many places the last person in each group moved and summing the time saved by each friend group:

    timeSaved = sortedQueue >$> zip [0 :: Int ..]   -- final positions
      >>> groupBy ((==) `on` snd)                   -- put groups together
      >>> map (last >>> timeSaveForGroup) >>> sum
        -- get the time save based on the last person in each group

    timeSaveForGroup (i,c) = 5 * size * (idx - i)
      where
        Group (size, idx) = groupInfo!c
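Since the code above depends on the post's Scanner, ByteString input, imap and >$> helpers, here is a self-contained sketch of the same computation on a plain String, handy for experimenting in GHCi. The name solveString is made up, zipWith f [0 ..] stands in for imap, and ordinary (.) and ($) replace >$> and >>>; otherwise it follows the steps above.

```haskell
{-# LANGUAGE DerivingVia #-}

import Data.Function (on)
import Data.List (groupBy, sortOn)
import Data.Map (Map, (!))
import qualified Data.Map as M
import Data.Semigroup (First (..), Sum (..))

-- Group size and index of the last member, combined as in the post
newtype Group = Group { unGroup :: (Int, Int) }
  deriving Semigroup via (Sum Int, First Int)

-- Total minutes saved for a queue given as a String of group labels
solveString :: String -> Int
solveString queue = timeSaved
  where
    -- zipWith f [0 ..] plays the role of imap
    groupInfo :: Map Char Group
    groupInfo =
      M.fromListWith (<>) (zipWith (\i c -> (c, Group (1, i))) [0 ..] queue)

    -- Final queue: groups ordered by the original index of their last member
    sortedQueue = sortOn (snd . unGroup . (groupInfo !)) queue

    timeSaved =
      sum . map (timeSaveForGroup . last) . groupBy ((==) `on` snd) $
        zip [0 :: Int ..] sortedQueue

    -- 5 minutes per group member, per place the group's last member moved up
    timeSaveForGroup (i, c) = 5 * size * (idx - i)
      where
        Group (size, idx) = groupInfo ! c
```

On the queue "AbA", the lone b moves up one place, saving 5 minutes; on "bAAb", the last A moves up one place, saving 5 minutes for each of the two As, for 10 in total.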

This is not the fastest way to solve the problem (in fact, my solution is the slowest of the five Haskell solutions so far!), but I wanted to illustrate this technique of accumulating over a list using a Semigroup and M.fromListWith. foldMap can be used similarly when we need just a single result value rather than a Map of some sort.

Other solutions

Several people linked to their own solutions. I already mentioned Aaron Allen’s solution above. Anurudh Peduri’s solution works by computing the initial and final wait time for each group and subtracting; notably, it simply sorts the groups alphabetically, not by index of the final member of the group. I don’t quite understand it, but I think this works because the initial and final wait times would change by the same amount when permuting the groups in line, so ultimately this cancels out.

Tim Put’s solution is by far the fastest (and, in my opinion, the cleverest). For each friend in a friend group, it computes the number of people in other friend groups who stand between them and the last person in their group (using a clever combination of functions including ByteString.elemIndices). Each such person represents a potential time save of 5 minutes, all of which will be realized once the groups are all consecutive. Hence all we have to do is sum these numbers and multiply by 5. It is instructive to think about why this works. It does not compute the actual time saved by each group, just the potential time save represented by each group. That potential time save might be realized by the group itself (if the last person in the group gets to move up) or by a different group (if someone in the group lets others go ahead of them). Ultimately, though, it does not matter how much time is saved by each group, only the total amount of time saved.
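Here is a deliberately naive, quadratic list-based sketch of that counting idea. It is my own reconstruction for illustration, not Tim Put's code (which works on ByteStrings and is far faster), and potentialSave is a made-up name.

```haskell
import Data.List (elemIndices)

-- Sum, over every person, the number of people from other groups standing
-- strictly between them and the last member of their own group; each such
-- person is a potential 5-minute saving.
potentialSave :: String -> Int
potentialSave queue = 5 * sum (zipWith between [0 ..] queue)
  where
    -- Index of the last person in group c (c occurs in queue, so this is safe)
    lastIdx c = last (elemIndices c queue)
    -- People of other groups at positions i+1 .. lastIdx c - 1
    between i c =
      length (filter (/= c) (take (lastIdx c - i - 1) (drop (i + 1) queue)))
```

On "AbA" and "bAAb" it gives 5 and 10, the same totals as the group-sorting approach.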

Next time: Purple Rain

For next time, I invite you to solve Purple Rain. This problem has a solution which is “well known” in competitive programming (if you need a hint, ybbx hc Xnqnar’f Nytbevguz); the challenge is to translate it into idiomatic (and, ideally, reusable) Haskell.

Posted in competitive programming, haskell

Types versus sets in math and programming languages

For several years I have been designing and implementing a functional teaching language especially for use in the context of a Discrete Mathematics course. The idea is for students to be exposed to some functional and statically-typed programming early in their computer science education, and to give them a fun and concrete way to see the connections between the concepts they learn in a Discrete Math course and computation. I am not the first to think of combining FP + Discrete Math, but I think there is an opportunity to do it really well with a language designed expressly for the purpose. (And, who am I kidding, designing and implementing a language is just plain fun.)

Of course the language has an expressive static type system, with base types like natural numbers, rationals, Booleans, and Unicode characters, as well as sum and product types, lists, strings, and the ability to define arbitrary recursive types. It also has built-in types and syntax for finite sets. For example,

A : Set ℕ
A = {1, 3, 6}

(Incidentally, I will be using Unicode syntax since it looks nice, but there are also ASCII equivalents for everything.) Sets support the usual operations like union, intersection, and difference, as well as set comprehension notation. The intention is that this will provide a rich playground for students to play around with the basic set theory that is typically taught in a discrete math class.

But wait…

Hopefully the above all seems pretty normal if you are used to programming in a statically typed language. Unfortunately, there is something here that I suspect is going to be deeply confusing to students. I am so used to it that it took me a long time to realize what was wrong; maybe you have not realized it either. (Well, perhaps I gave it away with the title of the blog post…)

In a math class, we typically tell students that \mathbb{N} is a set. But in Disco, ℕ is a type and something like {1,2,3} is a set. If you have been told that \mathbb{N} is a set, the distinction is going to seem very weird and artificial to you. For example, right now in Disco, you can ask whether {1,2} is a subset of {1,2,3}:

Disco> {1,2} ⊆ {1,2,3}

But if you try to ask whether {1,2} is a subset of ℕ, you get a syntax error:

Disco> {1,2} ⊆ ℕ
1 | {1,2} ⊆ ℕ
  |          ^
keyword "ℕ" cannot be used as an identifier

Now, we could try various things to improve this particular example—at the very least, make it fail more gracefully. But the fundamental question remains: what is the distinction between types and sets, and why is it important? If it’s not important, we should get rid of it; if it is important, then I need to be able to explain it to students!

We could try to completely get rid of the distinction, but this seems like it would lead directly to a dependent type system and refinement types. Refinement types are super cool but I really don’t think I want to go there (Disco’s type system is already complicated enough).

However, I think there actually is an important distinction; this blog post is my first attempt at crystallizing my thoughts on the distinction and how I plan to explain it to students.

Types vs sets

So what is the difference between sets and types? The slogan is that types are intensional, whereas sets are extensional. (I won’t actually use those words with my students.) That is:

  • Sets are characterized by the \in relation: we can ask which items are elements of a set and which are not.
  • Types, on the other hand, are characterized by how elements of the type are built: we can construct elements of a type (and deconstruct them) in certain ways specific to the type.

This seems kind of symmetric, but it is not. You can’t ask whether a thing is an element of a set if you don’t know how to even make or talk about any things in the first place. So types are prior to sets: types provide a universe of values, constructed in orderly ways, that we can work with; only then can we start picking out certain values to place them in a set.

Of course, this all presupposes some kind of type theory as foundational. I am aware that one can instead take axiomatic set theory as a foundation and build everything up from the empty set. But I’m building a typed functional programming language, so naturally I’m taking type theory as foundational! More importantly, however, it’s what almost every working mathematician does in practice. No one actually works or thinks in terms of axiomatic set theory (besides set theorists). Even in a typical math class, some sets are special. Before we can talk about the set {1,3,6}, we have to introduce the special set \mathbb{N} so we know what 1, 3, and 6 are. Before we can talk about the set {(1,1), (3,5), (6,8)} we have to introduce the special Cartesian product operation on sets so we know what tuples are. And so on. We can think of types as a language for describing this prior class of special sets.

Explaining things to students

So what will I actually say to my students? First of all, when introducing the language, I will tell them about various built-in primitive types like naturals, rationals, booleans, and characters. I won’t make a big deal about it, and I don’t think I will need to: for the most part they will have already seen a language like Python or Java with types for primitive values.

When we get to talking about sets, however (usually the second unit, after starting with propositional logic), we will define sets as collections of values, and I will explicitly point out the similarity to types. I will tell them that types are special built-in sets with rules for building their elements. We will go on to talk about disjoint union and Cartesian product, and practice building elements of sum and product types. (When we later get to recursion, they will therefore have the tools they need to start building recursive types such as lists and trees.)

The other thing to mention will be the way that when we write the type of a set, as in Set ℕ, we have to write down the type of the elements: in other words, the universe, or ambient set from which the elements are chosen. When introducing set theory, traditionally one mentions universe sets only when talking about the set complement operation; but the fact is that mathematicians always have some universe set in mind when describing a given set.

Now, coming back to the example of {1,2} ⊆ ℕ, it would still be confusing for students if this is a syntax error, and I have some ideas about how to make it work. Briefly, the idea is to allow types to be used in expressions (but not the other way around!), with T : Set T. If I tell them that types are special sets, then logically they will expect to be able to use them as such! However, this is an extremely nontrivial change: it means that Disco would now be able to represent infinite sets, requiring sets to be internally represented via a deep embedding, rather than simply storing their elements (as is currently the case). For example, 2 ∈ (ℕ \ {3,5}) should evaluate to true, but we obviously can’t just enumerate all the elements of ℕ \ {3,5} since there are infinitely many. More on this in a future post, perhaps!
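Purely to illustrate what a deep embedding means here (this is a hypothetical sketch, not how Disco represents sets now or plans to), a set can be a syntax tree of set operations, with membership decided by recursion on the tree instead of by listing elements. NatSet, member and finiteSubsetOf are made-up names.

```haskell
-- A hypothetical deep embedding of sets of natural numbers: a set is a
-- description built from set operations, not a list of its elements.
data NatSet
  = AllNat                 -- the type ℕ used as a set
  | Finite [Integer]       -- an explicitly listed finite set
  | Union NatSet NatSet
  | Inter NatSet NatSet
  | Diff NatSet NatSet     -- set difference

-- Membership is decided structurally, so infinite sets pose no problem.
member :: Integer -> NatSet -> Bool
member n AllNat      = n >= 0
member n (Finite xs) = n `elem` xs
member n (Union a b) = member n a || member n b
member n (Inter a b) = member n a && member n b
member n (Diff a b)  = member n a && not (member n b)

-- A finite set is a subset of s if each of its elements is a member of s.
finiteSubsetOf :: [Integer] -> NatSet -> Bool
finiteSubsetOf xs s = all (`member` s) xs
```

With this representation, 2 ∈ (ℕ \ {3,5}) becomes member 2 (Diff AllNat (Finite [3, 5])), which is True, and {1,2} ⊆ ℕ becomes finiteSubsetOf [1, 2] AllNat.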

Posted in projects, teaching

Competitive programming in Haskell: folding folds

Now that the semester is over—and I will be on sabbatical in the fall!—you can expect a lot more competitive programming in Haskell posts. In a previous post, I challenged you to solve Origami. j0sejuan took me up on the challenge, as did Aaron Allen and Ryan Yates; if you still want to try it, go do it before reading on!

In the problem, we start with a square sheet of paper and are given a series of folds to perform in sequence; each fold is specified as a line, and we fold whatever is on one side of the line across onto the other side. Given some query points, we have to compute how thick the resulting origami design is at each point.


The first order of business is some computational geometry relating to lines in 2D (this code can all be found in Geom.hs). Here I am following Victor Lecomte’s excellent Handbook of geometry for competitive programmers, which I think I’ve mentioned before. I’ll try to give a bit of explanation, but if you want full explanations and proofs you should consult that document.

The equation of a line ax + by = c can be thought of as the set of all points (x,y) whose dot product with the vector (a,b) is a constant c. This will in fact be a line perpendicular to the vector (a,b), where c determines the distance of the line from the origin. Alternatively, we can think of the vector (b,-a), which is perpendicular to (a,b) and thus parallel to the line; the line now consists of all points (x,y) whose 2D cross product with (b,-a) is the constant c (since (b,-a) \times (x,y) = by - (-a)x = ax + by; note that the order matters). Either representation would work, but I will follow Lecomte in choosing the second: we represent a line by a vector giving its direction, and a scalar offset.

data L2 s = L2 { getDirection :: !(V2 s), getOffset :: !s }
type L2D = L2 Double

There are a few ways to construct a line: from an equation ax + by = c, or from two points which lie on the line. The first is easy, given the above discussion. For the second, given points p and q, we can easily construct the direction of the line as v = q - p. Then to get the constant c, we simply use the fact that c is the cross product of the direction vector with any point on the line, say, p (of course q would also work).

lineFromEquation :: Num s => s -> s -> s -> L2 s
lineFromEquation a b c = L2 (V2 b (-a)) c

lineFromPoints :: Num s => P2 s -> P2 s -> L2 s
lineFromPoints p q = L2 v (v `cross` p)
  where
    v = q ^-^ p

Now we can write some functions to decide where a given point lies with respect to a line. First, the side function computes ax + by - c = (b,-a) \times (x,y) - c for any point p = (x,y).

side :: Num s => L2 s -> P2 s -> s
side (L2 v c) p = cross v p - c

Of course, for points that lie on the line, this quantity will be zero. We can also classify points p as lying to the left or right of the line (looking in the direction of v) depending on whether side l p is positive or negative, respectively.

onLine :: (Num s, Eq s) => L2 s -> P2 s -> Bool
onLine l p = side l p == 0

leftOf :: (Num s, Ord s) => L2 s -> P2 s -> Bool
leftOf l p = side l p > 0

rightOf :: (Num s, Ord s) => L2 s -> P2 s -> Bool
rightOf l p = side l p < 0

The last piece we will need to solve the problem is a way to reflect a point across a line. toProjection l p computes the vector perpendicular to l which points from p to l, and reflectAcross works by adding toProjection l p to p twice. I won’t derive the definition of toProjection, but the basic idea is to start with a vector perpendicular to the direction of the line (perp v) and scale it by a factor related to side l p. (Intuitively, it makes sense that ax + by - c tells us something about the distance from (x,y) to the line; the farther away (x,y) is from the line, the farther ax + by is from c.) See Lecomte for the full details.

toProjection :: Fractional s => L2 s -> P2 s -> V2 s
toProjection l@(L2 v _) p = (-side l p / normSq v) *^ perp v

project :: Fractional s => L2 s -> P2 s -> P2 s
project l p = p ^+^ toProjection l p

reflectAcross :: Fractional s => L2 s -> P2 s -> P2 s
reflectAcross l p = p ^+^ (2 *^ toProjection l p)
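For experimenting without the linear package or the post's Geom module, here is a self-contained sketch of the same line functions. The small V2 type and its operators are my own stand-ins, assumed to behave like linear's V2, (^+^), (^-^), (*^) and perp and the post's cross and normSq; the line functions follow the definitions above.

```haskell
-- Stand-ins (assumed, not from the post) for linear's V2 and operators
data V2 s = V2 !s !s deriving (Eq, Show)

type P2 s = V2 s

infixl 6 ^+^, ^-^
infixl 7 *^

(^+^), (^-^) :: Num s => V2 s -> V2 s -> V2 s
V2 a b ^+^ V2 c d = V2 (a + c) (b + d)
V2 a b ^-^ V2 c d = V2 (a - c) (b - d)

(*^) :: Num s => s -> V2 s -> V2 s
k *^ V2 a b = V2 (k * a) (k * b)

-- 2D cross product: (a,b) x (c,d) = ad - bc
cross :: Num s => V2 s -> V2 s -> s
cross (V2 a b) (V2 c d) = a * d - b * c

-- Rotate a vector 90 degrees counterclockwise
perp :: Num s => V2 s -> V2 s
perp (V2 a b) = V2 (negate b) a

normSq :: Num s => V2 s -> s
normSq (V2 a b) = a * a + b * b

-- The line functions, following the definitions above
data L2 s = L2 { getDirection :: !(V2 s), getOffset :: !s }

lineFromPoints :: Num s => P2 s -> P2 s -> L2 s
lineFromPoints p q = L2 v (v `cross` p)
  where
    v = q ^-^ p

side :: Num s => L2 s -> P2 s -> s
side (L2 v c) p = cross v p - c

leftOf :: (Num s, Ord s) => L2 s -> P2 s -> Bool
leftOf l p = side l p > 0

toProjection :: Fractional s => L2 s -> P2 s -> V2 s
toProjection l@(L2 v _) p = (negate (side l p) / normSq v) *^ perp v

reflectAcross :: Fractional s => L2 s -> P2 s -> P2 s
reflectAcross l p = p ^+^ (2 *^ toProjection l p)
```

For example, reflecting (1,0) across the diagonal line through (0,0) and (1,1) gives (0,1), and reflecting (2,3) across the x-axis gives (2,-3).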

Folding origami

Finally we can solve the problem! First, some imports and input parsing.

{-# LANGUAGE RecordWildCards #-}

import           Control.Arrow
import qualified Data.ByteString.Lazy.Char8 as C

import           Geom
import           ScannerBS

main = C.interact $
  runScanner tc >>> solve >>> map (show >>> C.pack) >>> C.unlines

data TC = TC { steps :: [L2D], holes :: [P2D] }

tc = TC <$> numberOf (lineFromPoints <$> p2 double <*> p2 double) <*> numberOf (p2 double)

solve :: TC -> [Int]
solve TC{..} = map countLayers holes
  where

For countLayers, the idea is to work backwards from a given query point to find all its preimages, that is, the points that will eventually map to that point under the folds. Then we can just count how many of those points lie (strictly) inside the original square of paper.

    inSquare (V2 x y) = 0 < x && x < 1000 && 0 < y && y < 1000

For a given point and fold, there are two possibilities, depending on which side of the fold line the point falls on. If the point falls on the fold or to the right of it, then it has no preimages (we always fold from right to left, so after the fold, there will be no paper on the right side of the line, and the problem specifies that points exactly on a folded edge do not count). Hence we can just discard such a point. On the other hand, if the point lies on the left side of the line, then the point has two preimages: the point itself, and its reflection across the fold line.

    preimage :: L2D -> P2D -> [P2D]
    preimage l p
      | leftOf l p = [p, reflectAcross l p]
      | otherwise  = []

So we keep a set of points, starting with the singleton query point, and for each fold (in order from last to first) we find the preimage of every point in the set under the fold. We actually use lists of points instead of sets, because (1) we won’t ever get any collisions (actually, the more I think about this, the less sure I am!) and (2) it lets us use the actual list monad instead of making some ad-hoc Set monad operations.

    countLayers :: P2D -> Int
    countLayers q = length . filter inSquare $ foldr (\l -> (>>= preimage l)) [q] steps
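The fold-of-folds pattern doesn't depend on the geometry, so here is a hypothetical one-dimensional analogue: a paper strip covering (0, 1000), folded at a point by flipping the part to the right onto the left. The names countLayersWith, fold1D and inStrip are made up, with fold1D playing the role of preimage.

```haskell
-- Count preimages of q that land on the original paper. The steps are
-- listed in folding order; foldr applies the last fold's preimage first.
countLayersWith :: (a -> Bool) -> [a -> [a]] -> a -> Int
countLayersWith onPaper steps q =
  length . filter onPaper $ foldr (\pre -> (>>= pre)) [q] steps

-- Preimages of x under folding a strip at position f, flipping the part
-- to the right of f onto the left.
fold1D :: Double -> Double -> [Double]
fold1D f x
  | x < f     = [x, 2 * f - x]  -- x itself and its mirror image across f
  | otherwise = []              -- no paper at or to the right of f

-- The original strip, the open interval (0, 1000)
inStrip :: Double -> Bool
inStrip x = 0 < x && x < 1000
```

For example, folding at 500 and then at 250 leaves four layers at position 100: working backwards, 100 unfolds to 100 and 400, which in turn unfold to 100, 900, 400 and 600.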

It is very satisfying to use a fold to process a list of folds!

Next time: Please, Go First

For next time, I invite you to solve Please, Go First.

Posted in competitive programming, haskell

Average and median via optimization

This is certainly not a new observation, and I’m sure it is folklore and/or contained in various textbooks, but it was amusing to derive it independently.

Suppose we have a finite set of real numbers \{x_1, \dots, x_n\}, and we want to pick a value m which is somehow “in the middle” of the x_i. The punchline is that

  • if we want to minimize the sum of the squared distances from m to each x_i, we should pick m to be the average of the x_i;
  • if we want to minimize the sum of the absolute distances from m to each x_i, we should pick m to be the median of the x_i.

The first of these is tricky to understand intuitively but easy to derive; the second is intuitively straightforward but trying to derive it leads to an interesting twist.

Average = minimizing sum of squared distances

Let’s not worry about why we would want to minimize the sum of squared distances; there are good reasons and it’s not the point. I don’t know about you, but I find it difficult to reason intuitively about how and why to pick m to minimize this sum of squared differences. If you know of an intuitive way to explain this, I would love to hear about it! But in any case, it is easy to derive using some straightforward calculus.

Let \displaystyle S(m) = \sum_i(m - x_i)^2 denote the sum of squared distances from a given m to each of the x_i. Taking the derivative of S with respect to m, we find

\displaystyle \frac{d}{dm} S(m) = \sum_i 2(m - x_i).

Setting the derivative equal to zero, we can first divide through by the factor of 2, yielding

\displaystyle 0 = \sum_i (m - x_i)

Since m does not depend on i, the right-hand side is just n copies of m less the sum of the x_i, that is, 0 = nm - \sum_i x_i. Solving for m yields

\displaystyle m = \frac{1}{n} \sum_i x_i

as expected: the value of m which minimizes the sum of squared distances to the x_i is their average, that is, the sum of the x_i divided by the size of the set. (The second derivative of S is 2n > 0, so this critical point is indeed a minimum.)
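As a quick numerical sanity check of this derivation, we can compute S and its derivative directly. This is only a sketch: the names `sumSq`, `dSumSq`, and `average` are hypothetical, and the sample data [1, 2, 6] is arbitrary.

```haskell
-- | S(m): the sum of squared distances from m to each x_i.
sumSq :: [Double] -> Double -> Double
sumSq xs m = sum [(m - x) ^ (2 :: Int) | x <- xs]

-- | dS/dm = sum of 2 (m - x_i).
dSumSq :: [Double] -> Double -> Double
dSumSq xs m = sum [2 * (m - x) | x <- xs]

-- | The average: the sum of the x_i divided by n (assumes a nonempty list).
average :: [Double] -> Double
average xs = sum xs / fromIntegral (length xs)
```

For example, `average [1, 2, 6]` is 3, `dSumSq [1, 2, 6] 3` is 0, and `sumSq [1, 2, 6] 3` is 4 + 1 + 9 = 14; every other point on a grid such as `[0, 0.5 .. 8]` gives a larger sum.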

Median = minimizing sum of absolute distances

Now suppose we want to minimize the sum of absolute distances instead, that is,

\displaystyle S(m) = \sum_i |m - x_i|

In this scenario, it is much easier to reason out the correct answer. Start with some arbitrary m, and imagine nudging it to the right by some small amount \Delta x (small enough that m does not pass any of the x_i). m’s distances to any points on its left will each increase by \Delta x, and its distances to any points on its right will each decrease by the same amount. Therefore, if there are more x_i to the left of m, then the overall sum of distances will increase; if there are more x_i to the right, then the overall sum will decrease. So, to find m which minimizes the sum of absolute differences, we want the same number of x_i on the left and the right, that is, we want the median. Note that if n is odd, then we must pick m to be exactly equal to the x_i in the middle; if n is even, then we can pick m to be anywhere in the interval between the middle two x_i, including its endpoints.
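To see the even case concretely, here is a small sketch that computes the sum of absolute distances together with the range of medians. The names `sumAbs` and `medianRange` are hypothetical, and the input list is assumed to be nonempty.

```haskell
import Data.List (sort)

-- | S(m): the sum of absolute distances from m to each x_i.
sumAbs :: [Double] -> Double -> Double
sumAbs xs m = sum [abs (m - x) | x <- xs]

-- | The range of medians: a single point when n is odd, and the closed
--   interval between the two middle values when n is even.
medianRange :: [Double] -> (Double, Double)
medianRange xs
  | odd n     = (mid, mid)
  | otherwise = (ys !! (h - 1), ys !! h)
  where
    ys  = sort xs
    n   = length xs
    h   = n `div` 2
    mid = ys !! h -- the middle element (0-indexed) when n is odd
```

For [1, 2, 6, 10], `medianRange` gives (2, 6), and `sumAbs` evaluates to 13 at every m from 2 to 6 but to 15 at m = 1 or m = 7, so the whole middle interval minimizes the sum.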

Just for fun, can we derive this answer using calculus, like we did for minimizing squared differences? There is a wrinkle, of course, which is that the absolute value function is not differentiable everywhere: it has a sharp corner at zero. But we won’t let that stop us! Clearly the derivative of |x| is -1 when x < 0 and 1 when x > 0. So it seems reasonable to just assign the derivative a value of 0 at x = 0. Algebraically, we can define

\displaystyle \frac{d}{dx} |x| = [x > 0] - [x < 0]

where [P] is equal to 1 when the proposition P is true, and 0 when it is false (this notation is called the Iverson bracket). So when x > 0 we get [x > 0] - [x < 0] = 1 - 0 = 1; when x < 0 we get 0 - 1 = -1; and when x = 0 both propositions are false so we get 0 - 0 = 0.
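The bracket definition translates directly into code. In this sketch, `iverson` and `dAbs` are hypothetical names rather than library functions; on ordinary (non-NaN) inputs the result agrees with Haskell's built-in `signum`, which also returns 0 at 0.

```haskell
-- | The Iverson bracket [P]: 1 when P is true, 0 when it is false.
iverson :: Bool -> Int
iverson b = if b then 1 else 0

-- | The derivative of |x| as defined above, with the value 0 chosen at x = 0.
dAbs :: Double -> Int
dAbs x = iverson (x > 0) - iverson (x < 0)
```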

Armed with this definition, we can differentiate S with respect to m:

\displaystyle \frac{d}{dm} S(m) = \frac{d}{dm} \sum_i |m - x_i| = \sum_i [m > x_i] - \sum_i [m < x_i]

Clearly, this is zero when \displaystyle \sum_i [m > x_i] = \sum_i [m < x_i], that is, when there are the same number of x_i on either side of m.

The curious thing to me is that even though the derivative of |x| is undefined when x = 0, it seems like it “wants” to be 0 here. In general, if we assign the value k to the derivative at x = 0, the derivative of S becomes

\displaystyle \frac{d}{dm} S(m) = \sum_i [m > x_i] + k \sum_i [m = x_i] - \sum_i [m < x_i]

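When n is odd and 0 < |k| < 2 (a range that includes every nonzero slope between -1 and 1 one might plausibly assign at the corner), there are no values of m for which this derivative is zero: when m is not one of the x_i, the derivative is a difference of two counts adding up to the odd number n, so it is odd; when m equals some x_i, the two counts add up to n - 1, so their difference is even, and adding k cannot make it zero. This makes it more difficult to find the minimum.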
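To make the comparison concrete, here is a sketch that evaluates the generalized derivative for a chosen k and searches for zeros. Since the derivative is constant between consecutive x_i, it suffices to test each x_i, one point inside each gap, and one point beyond each end. The names `dS`, `candidates`, and `zeros` are hypothetical, and the input is assumed to be a nonempty list of distinct values (matching the set in the text).

```haskell
import Data.List (sort)

-- | Derivative of S(m) = sum_i |m - x_i| when the derivative of |x|
--   at x = 0 is assigned the value k.  With k = 0 this is exactly
--   sum_i [m > x_i] - sum_i [m < x_i].
dS :: Double -> [Double] -> Double -> Double
dS k xs m = count (m >) + k * count (m ==) - count (m <)
  where
    -- number of x_i satisfying the given predicate
    count p = fromIntegral (length (filter p xs))

-- | One representative m from every region on which dS is constant:
--   each x_i, the midpoint of each gap, and a point beyond each end.
candidates :: [Double] -> [Double]
candidates xs =
  [head ys - 1, last ys + 1] ++ ys ++ zipWith (\a b -> (a + b) / 2) ys (tail ys)
  where
    ys = sort xs

-- | Candidate values of m at which the derivative is exactly zero.
zeros :: Double -> [Double] -> [Double]
zeros k xs = [m | m <- candidates xs, dS k xs m == 0]
```

For the odd-length list [1, 2, 6], k = 0 yields the single zero m = 2 (the median), while k = 0.5, 1, or -1 yields none. For [1, 2, 6, 10] with k = 0, the only zero among the candidates is the gap point 4, which stands in for the whole open interval between 2 and 6.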

Posted in math | 1 Comment

Competitive programming in Haskell: folding challenge

As of this writing, I am the only person who has solved Origami on Open Kattis (according to the problem’s statistics page). This is a shame since it is a very nice problem! Of course, I solved it in Haskell.

I challenge you to solve it—bonus points for using a fold in your solution!

Posted in competitive programming, haskell | 8 Comments