## Data structure challenge: finding the rightmost empty slot

Suppose we have a sequence of slots indexed from 1 to $n$. Each slot can be either empty or full, and all start out empty. We want to repeatedly do the following operation:

• Given an index $i$, find the rightmost empty slot at or before index $i$, and mark it full.

We can also think of this in terms of two more fundamental operations:

• Mark a given index $i$ as full.
• Given an index $i$, find the greatest index $j$ such that $j \leq i$ and $j$ is empty (or $0$ if there is no such $j$).

The simplest possible approach would be to use an array of booleans; then marking a slot full is trivial, and finding the rightmost empty slot before index $i$ can be done with a linear scan leftwards from $i$. But the challenge is this:

Can you think of a data structure to support both operations in $O(\lg n)$ time or better?

You can think of this in either functional or imperative terms. I know of two solutions, which I’ll share in a subsequent post, but I’m curious to see what people will come up with.

Note that in my scenario, slots never become empty again after becoming full. As an extra challenge, what if we relax this to allow setting slots arbitrarily?
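For concreteness, here is what the naive Boolean-array baseline might look like in Haskell (a sketch only; the names are mine, not part of the challenge). Note that with a persistent `Data.Array`, even marking is $O(n)$ because of the copy-on-update; an imperative mutable array would make marking $O(1)$, with the query still a linear scan.

```haskell
import Data.Array

-- Slots indexed from 1 to n; True = full.
type Slots = Array Int Bool

emptySlots :: Int -> Slots
emptySlots n = listArray (1, n) (replicate n False)

-- Rightmost empty index j <= i, or 0 if there is no such j.
rightmostEmpty :: Slots -> Int -> Int
rightmostEmpty slots i = go i
  where
    go 0 = 0
    go j | slots ! j = go (j - 1)  -- full, keep scanning left
         | otherwise = j

-- The combined operation: find the rightmost empty slot at or
-- before i, mark it full, and return its index (0 if none).
fillRightmost :: Slots -> Int -> (Int, Slots)
fillRightmost slots i = (j, if j == 0 then slots else slots // [(j, True)])
  where
    j = rightmostEmpty slots i
```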


## Competitive Programming in Haskell: modular arithmetic, part 2

In my last post I wrote about modular exponentiation and egcd. In this post, I consider the problem of solving modular equivalences, building on code from the previous post.

### Solving linear congruences

A linear congruence is a modular equivalence of the form

$ax \equiv b \pmod m$.

Let’s write a function to solve such equivalences for $x$. We want a pair of integers $y$ and $k$ such that $x$ is a solution to $ax \equiv b \pmod m$ if and only if $x \equiv y \pmod k$. This isn’t hard to write in the end, but takes a little bit of thought to do it properly.

First of all, if $a$ and $m$ are relatively prime (that is, $\gcd(a,m) = 1$) then we know from the last post that $a$ has an inverse modulo $m$; multiplying both sides by $a^{-1}$ yields the solution $x \equiv a^{-1} b \pmod m$.

OK, but what if $\gcd(a,m) > 1$? In this case there might not be any solutions at all. For example, $2x \equiv 3 \pmod 4$ has no solutions: any even number is equivalent to $0$ or $2$ modulo $4$, so there is no value of $x$ whose double is equivalent to $3$. On the other hand, $2x \equiv 2 \pmod 4$ is fine: it holds for any odd value of $x$, that is, $x \equiv 1 \pmod 2$. In fact, it is easy to see that any common divisor of $a$ and $m$ must also divide $b$ for there to be any solutions. And when the GCD of $a$ and $m$ does divide $b$, we can simply divide the whole equivalence through by the GCD (including dividing the modulus $m$!) and then solve the resulting equivalence.

-- solveMod a b m solves ax = b (mod m), returning a pair (y,k) (with
-- 0 <= y < k) such that x is a solution iff x = y (mod k).
solveMod :: Integer -> Integer -> Integer -> Maybe (Integer, Integer)
solveMod a b m
  | g == 1         = Just ((b * inverse m a) `mod` m, m)
  | b `mod` g == 0 = solveMod (a `div` g) (b `div` g) (m `div` g)
  | otherwise      = Nothing
  where
    g = gcd a m
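We can check this against the examples above. Here is a standalone version of the check, with egcd and inverse repeated from the previous post so the snippet compiles on its own:

```haskell
-- egcd and inverse from part 1, repeated so this snippet runs standalone
egcd :: Integer -> Integer -> (Integer, Integer, Integer)
egcd a 0 = (abs a, signum a, 0)
egcd a b = (g, y, x - (a `div` b) * y)
  where
    (g,x,y) = egcd b (a `mod` b)

inverse :: Integer -> Integer -> Integer
inverse m a = y `mod` m
  where
    (_,_,y) = egcd m a

solveMod :: Integer -> Integer -> Integer -> Maybe (Integer, Integer)
solveMod a b m
  | g == 1         = Just ((b * inverse m a) `mod` m, m)
  | b `mod` g == 0 = solveMod (a `div` g) (b `div` g) (m `div` g)
  | otherwise      = Nothing
  where
    g = gcd a m

-- solveMod 2 2 4 == Just (1,2)   (any odd x works)
-- solveMod 2 3 4 == Nothing      (no solutions)
```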

### Solving systems of congruences with CRT

In its most basic form, the Chinese remainder theorem (CRT) says that if we have a system of two modular equations

$\begin{array}{rcl}x &\equiv& a \pmod m \\ x &\equiv& b \pmod n\end{array}$

then as long as $m$ and $n$ are relatively prime, there is a unique solution for $x$ modulo the product $mn$; that is, the system of two equations is equivalent to a single equation of the form

$x \equiv c \pmod {mn}.$

We first compute the Bézout coefficients $u$ and $v$ such that $mu + nv = 1$ using egcd, and then compute the solution as $c = anv + bmu$. Indeed,

$c = anv + bmu = a(1 - mu) + bmu = a - amu + bmu = a + (b-a)mu$

and hence $c \equiv a \pmod m$; similarly $c \equiv b \pmod n$.

However, this is not quite general enough: we want to still be able to say something useful even if $\gcd(m,n) > 1$. I won’t go through the whole proof, but it turns out that there is a solution if and only if $a \equiv b \pmod {\gcd(m,n)}$, and we can just divide everything through by $g = \gcd(m,n)$, as we did for solving linear congruences. Here’s the code:

-- gcrt2 (a,n) (b,m) solves the pair of modular equations
--
--   x = a (mod n)
--   x = b (mod m)
--
-- It returns a pair (c, k) such that all solutions for x satisfy x =
-- c (mod k), that is, solutions are of the form x = kt + c for
-- integer t.
gcrt2 :: (Integer, Integer) -> (Integer, Integer) -> Maybe (Integer, Integer)
gcrt2 (a,n) (b,m)
  | a `mod` g == b `mod` g = Just (((a*v*m + b*u*n) `div` g) `mod` k, k)
  | otherwise              = Nothing
  where
    (g,u,v) = egcd n m
    k = (m*n) `div` g

From here we can bootstrap ourselves into solving systems of more than two equations, by iteratively combining two equations into one.

-- gcrt solves a system of modular equations.  Each equation x = a
-- (mod n) is given as a pair (a,n).  Returns a pair (z, k) such that
-- solutions for x satisfy x = z (mod k), that is, solutions are of
-- the form x = kt + z for integer t.
gcrt :: [(Integer, Integer)] -> Maybe (Integer, Integer)
gcrt []         = Nothing
gcrt [e]        = Just e
gcrt (e1:e2:es) = gcrt2 e1 e2 >>= \e -> gcrt (e:es)
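As a quick test, the classic example $x \equiv 2 \pmod 3$, $x \equiv 3 \pmod 5$, $x \equiv 2 \pmod 7$ should yield $x \equiv 23 \pmod{105}$. Here is a standalone version (egcd, gcrt2, and gcrt repeated from above so the snippet compiles on its own):

```haskell
egcd :: Integer -> Integer -> (Integer, Integer, Integer)
egcd a 0 = (abs a, signum a, 0)
egcd a b = (g, y, x - (a `div` b) * y)
  where
    (g,x,y) = egcd b (a `mod` b)

gcrt2 :: (Integer, Integer) -> (Integer, Integer) -> Maybe (Integer, Integer)
gcrt2 (a,n) (b,m)
  | a `mod` g == b `mod` g = Just (((a*v*m + b*u*n) `div` g) `mod` k, k)
  | otherwise              = Nothing
  where
    (g,u,v) = egcd n m
    k = (m*n) `div` g

gcrt :: [(Integer, Integer)] -> Maybe (Integer, Integer)
gcrt []         = Nothing
gcrt [e]        = Just e
gcrt (e1:e2:es) = gcrt2 e1 e2 >>= \e -> gcrt (e:es)

-- gcrt [(2,3), (3,5), (2,7)] == Just (23, 105)
-- gcrt [(1,2), (2,4)]        == Nothing   (x cannot be both odd and even)
```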

### Practice problems

And here are a bunch of problems for you to practice!


## What would Dijkstra do? Proving the associativity of min

This semester I’m teaching a Discrete Mathematics course. Recently, I assigned my students a homework problem from the textbook that asked them to prove that the binary $\min$ operator on the real numbers is associative, that is, for all real numbers $a$, $b$, and $c$,

$\min(a, \min(b,c)) = \min(\min(a,b), c)$.

You might like to pause for a minute to think about how you would prove this! Of course, how you prove it depends on how you define $\min$, so you might like to think about that too.

The book expected them to do a proof by cases, with some sort of case split on the order of $a$, $b$, and $c$. What they turned in was mostly pretty good, actually, but while grading it I became disgusted with the whole thing and thought there has to be a better way.

I was reminded of an example of Dijkstra’s that I remember reading. So I asked myself—what would Dijkstra do? The thing I remember reading may have, in fact, been this exact proof, but I couldn’t remember any details and I still can’t find it now, so I had to (re-)work out the details, guided only by some vague intuitions.

Dijkstra would certainly advocate proving associativity of $\min$ using a calculational approach. Dijkstra would also advocate using a symmetric infix operator symbol for a commutative and associative operation, so let’s adopt the symbol $\downarrow$ for $\min$. ($\sqcap$ would also be a reasonable choice, though I find it less mnemonic.)

How can we calculate with $\downarrow$? We have to come up with some way to characterize it that allows us to transform expressions involving $\downarrow$ into something else more fundamental. The most obvious definition would be “$a \downarrow b = a$ if $a \leq b$, and $b$ otherwise”. However, although this is a fantastic implementation of $\downarrow$ if you actually want to run it, it is not so great for reasoning about $\downarrow$, precisely because it involves doing a case split on whether $a \leq b$. This is the definition that leads to the ugly proof by cases.

How else could we define it? The usual more mathematically sophisticated way to define it would be as a greatest lower bound, that is, “$x = a \downarrow b$ if and only if $x \leq a$ and $x \leq b$ and $x$ is the greatest such number, that is, for any other $y$ such that $y \leq a$ and $y \leq b$, we have $y \leq x$.” However, this is a bit roundabout and also not so conducive to calculation.

My first epiphany was that the best way to characterize $\downarrow$ is by its relationship to $\leq$. After one or two abortive attempts, I hit upon the right idea:

$(a \leq b \downarrow c) \leftrightarrow (a \leq b \land a \leq c)$

That is, an arbitrary $a$ is less than or equal to the minimum of $b$ and $c$ precisely when it is less than or equal to both. In fact, this completely characterizes $\downarrow$, and is equivalent to the second definition given above.1 (You should try convincing yourself of this!)
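As a quick empirical sanity check (not a substitute for convincing yourself!), we can test the characterization exhaustively on a small range of integers in Haskell:

```haskell
-- The characterization: a <= min b c  iff  a <= b && a <= c.
charHolds :: Int -> Int -> Int -> Bool
charHolds a b c = (a <= min b c) == (a <= b && a <= c)

-- Check it for all triples drawn from a small range.
checkAll :: Bool
checkAll = and [charHolds a b c | a <- r, b <- r, c <- r]
  where
    r = [-3 .. 3]
```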

But how do we get anywhere from $a \downarrow (b \downarrow c)$ by itself? We need to somehow introduce a thing which is less than or equal to it, so we can apply our characterization. My second epiphany was that equality of real numbers can also be characterized by having the same “downsets”, i.e. two real numbers are equal if and only if the sets of real numbers less than or equal to them are the same. That is,

$(x = y) \leftrightarrow (\forall z.\; (z \leq x) \leftrightarrow (z \leq y))$

Now the proof almost writes itself. Let $z \in \mathbb{R}$ be arbitrary; we calculate as follows:

$\begin{array}{cl} & z \leq a \downarrow (b \downarrow c) \\ \leftrightarrow & \\ & z \leq a \land (z \leq b \downarrow c) \\ \leftrightarrow & \\ & z \leq a \land (z \leq b \land z \leq c) \\ \leftrightarrow & \\ & (z \leq a \land z \leq b) \land z \leq c \\ \leftrightarrow & \\ & (z \leq a \downarrow b) \land z \leq c \\ \leftrightarrow & \\ & z \leq (a \downarrow b) \downarrow c \end{array}$

Of course this uses our characterization of $\downarrow$ via its relationship to $\leq$, along with the fact that $\land$ is associative. Since we have proven that $z \leq a \downarrow (b \downarrow c)$ if and only if $z \leq (a \downarrow b) \downarrow c$ for arbitrary $z$, therefore $a \downarrow (b \downarrow c) = (a \downarrow b) \downarrow c$.

1. Thinking about it later, I realized that this should not be surprising: it’s just characterizing $\downarrow$ as the categorical product, i.e. meet, i.e. greatest lower bound, in the poset of real numbers ordered by the usual $\leq$.


## Competitive Programming in Haskell: modular arithmetic, part 1

Modular arithmetic comes up a lot in computer science, and so it’s no surprise that it is featured, either explicitly or implicitly, in many competitive programming problems.

As a brief aside, to be good at competitive programming it’s not enough to have a library of code at your disposal (though it certainly helps!). You must also deeply understand the code in your library and how it works—so that you know when it is applicable, what the potential pitfalls are, how to debug when things don’t work, and how to make modifications to the code to fit some new problem. I will try to explain all the code I exhibit here—why, and not just how, it works. But you’ll also ultimately be better off if you write your own code rather than using mine! Read my explanations for ideas, and then go see if you can replicate the functionality you need.

### Modular exponentiation

We start with a simple implementation of modular exponentiation, that is, computing $b^e \pmod m$, via repeated squaring. This comes up occasionally in both number theory problems (unsurprisingly) and combinatorics problems (because such problems often ask for a very large answer to be given modulo $10^9+7$ or some other large prime).

This works via the recurrence

$\begin{array}{rcl}x^0 &=& 1 \\[0.5em] x^{2n} &=& (x^n)^2 \\[0.5em] x^{2n+1} &=& x \cdot (x^n)^2\end{array}$

and using the fact that taking the remainder $\pmod m$ commutes with multiplication.

modexp :: Integer -> Integer -> Integer -> Integer
modexp _ 0 _ = 1
modexp b e m
  | even e    = (r*r) `mod` m
  | otherwise = (b*r*r) `mod` m
  where
    r = modexp b (e `div` 2) m

This could probably be slightly optimized, but it’s hardly worth it; since the number of multiplications performed is proportional to the logarithm of the exponent, it’s pretty much instantaneous for any inputs that would be used in practice.

However, there’s another technique, obvious in retrospect, that I have recently discovered. Many competitive programming problems ask you to compute the answer modulo some fixed number (usually a large prime). In this context, all arithmetic operations are going to be carried out modulo the same value. With Haskell’s great facilities for cheap abstraction it makes perfect sense to write something like this:

m :: Integer
m = 10^9 + 7   -- or whatever the modulus is supposed to be

-- Make a newtype for integers mod m
newtype M = M { unM :: Integer }
  deriving (Eq, Ord)

instance Show M where show = show . unM

-- Do all arithmetic operations mod m
instance Num M where
  fromInteger n = M (n `mod` m)
  (M a) + (M b) = M ((a + b) `mod` m)
  (M a) - (M b) = M ((a - b) `mod` m)
  (M a) * (M b) = M ((a * b) `mod` m)
  abs    = undefined  -- make the warnings stop
  signum = undefined

The fun thing is that now the normal exponentiation operator (^) does modular exponentiation for free! It is implemented using repeated squaring so it’s quite efficient. You can now write all your code using the M type with normal arithmetic operations, and it will all be carried out mod m automatically.
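To illustrate (with the definitions above repeated so the snippet compiles standalone), here are powers and factorials mod $m$ with no explicit `mod` in sight. Note that $2^{m-1} \equiv 1 \pmod m$ drops out for free by Fermat's little theorem, since $m = 10^9+7$ is prime:

```haskell
m :: Integer
m = 10^9 + 7

newtype M = M { unM :: Integer }
  deriving (Eq, Ord)

instance Show M where show = show . unM

instance Num M where
  fromInteger n = M (n `mod` m)
  (M a) + (M b) = M ((a + b) `mod` m)
  (M a) - (M b) = M ((a - b) `mod` m)
  (M a) * (M b) = M ((a * b) `mod` m)
  abs    = undefined
  signum = undefined

-- 2^n mod m: (^) on M does repeated squaring, reducing mod m each step
pow2 :: Integer -> Integer
pow2 n = unM (2 ^ n)

-- n! mod m, using the ordinary product function
fact :: Integer -> Integer
fact n = unM (product (map fromInteger [1 .. n]))
```

Even `pow2 (10^9 + 6)` is effectively instantaneous, since only a logarithmic number of multiplications is performed and every intermediate result stays below $m^2$.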

Here are a couple problems for you to try:

### Extended gcd

Beyond modular exponentiation, the workhorse of many number theory problems is the extended Euclidean Algorithm. It not only computes the GCD $g$ of $a$ and $b$, but also computes $x$ and $y$ such that $ax + by = g$ (which are guaranteed to exist by Bézout’s identity).

First, let’s recall how to compute the GCD via Euclid’s Algorithm:

gcd a 0 = abs a
gcd a b = gcd b (a `mod` b)

I won’t explain how this works here; you can go read about it at the link above, and it is well-covered elsewhere. But let’s think how we would find appropriate values $x$ and $y$ at the same time. Suppose the recursive call gcd b (a mod b), in addition to returning the greatest common divisor $g$, were to also return values $x$ and $y$ such that $bx + (a \bmod b)y = g$. Then our goal is to find $x'$ and $y'$ such that $ax' + by' = g$, which we can compute as follows:

$\begin{array}{rcl}g &=& bx + (a \bmod b)y \\[0.5em] &=& bx + (a - b\lfloor a/b \rfloor)y \\[0.5em] &=& bx + ay - b\lfloor a/b \rfloor y = ay + b(x - \lfloor a/b \rfloor y)\end{array}$

Hence $x' = y$ and $y' = x - \lfloor a/b \rfloor y$. Note the key step of writing $a \bmod b = a - b \lfloor a/b \rfloor$: If we take the integer quotient of $a$ divided by $b$ and then multiply by $b$ again, we don’t necessarily get $a$ back exactly, but what we do get is the next smaller multiple of $b$. Subtracting this from the original $a$ gives $a \bmod b$.

-- egcd a b = (g,x,y)
--   g is the gcd of a and b, and ax + by = g
egcd :: Integer -> Integer -> (Integer, Integer, Integer)
egcd a 0 = (abs a, signum a, 0)
egcd a b = (g, y, x - (a `div` b) * y)
  where
    (g,x,y) = egcd b (a `mod` b)

Finally, egcd allows us to find modular inverses. The modular inverse of $a \pmod m$ is a number $b$ such that $ab \equiv 1 \pmod m$, which will exist as long as $\gcd(m,a) = 1$: in that case, by Bézout’s identity, there exist $x$ and $y$ such that $mx + ay = 1$, and hence $mx + ay \equiv 0 + ay \equiv ay \equiv 1 \pmod m$ (since $mx \equiv 0 \pmod m$). So $y$ is the desired modular inverse of $a$.

-- inverse m a  is the multiplicative inverse of a mod m.
inverse :: Integer -> Integer -> Integer
inverse m a = y `mod` m
  where
    (_,_,y) = egcd m a

Of course, this assumes that $m$ and $a$ are relatively prime; if not it will silently give a bogus answer. If you’re concerned about that you could check that the returned GCD is 1 and throw an error otherwise.
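One possible safe variant along those lines (a sketch; inverseMaybe is my name, not from this post), with egcd repeated so the snippet stands alone:

```haskell
-- egcd as above, repeated so this snippet runs standalone
egcd :: Integer -> Integer -> (Integer, Integer, Integer)
egcd a 0 = (abs a, signum a, 0)
egcd a b = (g, y, x - (a `div` b) * y)
  where
    (g,x,y) = egcd b (a `mod` b)

-- Returns Nothing when a has no inverse mod m (i.e. gcd m a /= 1).
inverseMaybe :: Integer -> Integer -> Maybe Integer
inverseMaybe m a = case egcd m a of
  (1, _, y) -> Just (y `mod` m)
  _         -> Nothing
```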

And here are a few problems for you to try!

In part 2 I’ll consider the task of solving modular equations.

## Unexpected benefits of version control

Last night I dreamed that I had shaved off my mustache, and my wife was not happy about it. But in the dream, I couldn’t remember when or why I had shaved. Suddenly, it occurred to me: when I merged those old git commits I must have accidentally merged one that updated my mustache. “Thank goodness for version control,” I thought, as I went to revert that commit. Then I woke up.


## Competitive Programming in Haskell: primes and factoring

Number theory is a topic that comes up fairly regularly in competitive programming, and it’s a very nice fit for Haskell. I’ve developed a bunch of code over the years that regularly comes in handy. None of this is particularly optimized, and it’s definitely no match for a specialized library like arithmoi, but in a competitive programming context it usually does the trick!

A few imports first:

import           Control.Arrow
import           Data.List     (group, sort)
import           Data.Map      (Map)
import qualified Data.Map      as M

### Primes

We start with a basic definition of the list of primes, made with a simple recursive sieve, but with one very big optimization: when we find a prime $p$, instead of simply filtering out all the multiples of $p$ in the rest of the list, we first take all the numbers less than $p^2$ and pass them through without testing; composite numbers less than $p^2$ would have already been filtered out by a smaller prime.

primes :: [Integer]
primes = 2 : sieve primes [3..]
  where
    sieve (p:ps) xs =
      let (h,t) = span (< p*p) xs
      in  h ++ sieve ps (filter ((/= 0) . (`mod` p)) t)

I got this code from the Haskell wiki page on prime numbers. On my machine this allows us to find all the primes up to one million in about 4 seconds. Not blazing fast by any means, and of course this is not actually a true sieve—but it’s short, relatively easy to remember, and works just fine for many purposes. (There are some competitive programming problems requiring a true sieve, but I typically solve those in Java. Maybe someday I will figure out a concise way to solve them in Haskell.)

### Factoring

Now that we have our list of primes, we can write a function to find prime factorizations:

listFactors :: Integer -> [Integer]
listFactors = go primes
  where
    go _      1 = []
    go (p:ps) n
      | p*p > n        = [n]
      | n `mod` p == 0 = p : go (p:ps) (n `div` p)
      | otherwise      = go ps n

This is relatively straightforward. Note how we stop when the next prime is greater than the square root of the number being tested, because if there were a prime factor we would have already found it by that point.

### …and related functions

Finally we can use listFactors to build a few other useful functions:

factor :: Integer -> Map Integer Int
factor = listFactors >>> group >>> map (head &&& length) >>> M.fromList

divisors :: Integer -> [Integer]
divisors = factor >>> M.assocs >>> map (\(p,k) -> take (k+1) (iterate (*p) 1))
  >>> sequence >>> map product

totient :: Integer -> Integer
totient = factor >>> M.assocs >>> map (\(p,k) -> p^(k-1) * (p-1)) >>> product

factor yields a Map whose keys are unique primes and whose values are the corresponding powers; for example, factor 600 = M.fromList [(2,3), (3,1), (5,2)], corresponding to the factorization $600 = 2^3 \cdot 3 \cdot 5^2$. It works by grouping together like prime factors (note that listFactors guarantees to generate a sorted list of prime factors), counting each group, and building a Map.

divisors n generates a list of all divisors of n. It works by generating all powers of each prime from $p^0$ up to $p^k$, and combining them in all possible ways using sequence. Note it does not guarantee to generate the divisors in order.

totient implements the Euler totient function: totient n says how many numbers from 1 to n are relatively prime to n. To understand how it works, see this series of four blog posts I wrote on my other blog: part 1, part 2, part 3, part 4.
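Putting it all together (with the imports and definitions above repeated so the snippet compiles standalone), a few sample outputs:

```haskell
import           Control.Arrow
import           Data.List     (group, sort)
import           Data.Map      (Map)
import qualified Data.Map      as M

primes :: [Integer]
primes = 2 : sieve primes [3..]
  where
    sieve (p:ps) xs =
      let (h,t) = span (< p*p) xs
      in  h ++ sieve ps (filter ((/= 0) . (`mod` p)) t)

listFactors :: Integer -> [Integer]
listFactors = go primes
  where
    go _      1 = []
    go (p:ps) n
      | p*p > n        = [n]
      | n `mod` p == 0 = p : go (p:ps) (n `div` p)
      | otherwise      = go ps n

factor :: Integer -> Map Integer Int
factor = listFactors >>> group >>> map (head &&& length) >>> M.fromList

divisors :: Integer -> [Integer]
divisors = factor >>> M.assocs >>> map (\(p,k) -> take (k+1) (iterate (*p) 1))
  >>> sequence >>> map product

totient :: Integer -> Integer
totient = factor >>> M.assocs >>> map (\(p,k) -> p^(k-1) * (p-1)) >>> product

-- factor 600          == M.fromList [(2,3),(3,1),(5,2)]
-- sort (divisors 12)  == [1,2,3,4,6,12]
-- totient 10          == 4
```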

### Problems

Here are a few problems for you to try (ordered roughly from easier to more difficult):

In a subsequent post I’ll continue on the number theory theme and talk about modular arithmetic.

## Counting inversions via rank queries

In a post from about a year ago, I explained an algorithm for counting the number of inversions of a sequence in $O(n \lg n)$ time. As a reminder, given a sequence $a_1, a_2, \dots, a_n$, an inversion is a pair of positions $i, j$ such that $a_i$ and $a_j$ are in the “wrong order”, that is, $i < j$ but $a_i > a_j$. There can be up to $n(n-1)/2$ inversions in the worst case, so we cannot hope to count them in faster than quadratic time by simply incrementing a counter. In my previous post, I explained one way to count inversions in $O(n \lg n)$ time, using a variant of merge sort.

I recently learned of an entirely different algorithm for achieving the same result. (In fact, I learned of it when I gave this problem on an exam and a student came up with an unexpected solution!) This solution does not use a divide-and-conquer approach at all, but hinges on a clever data structure.

Suppose we have a bag of values (i.e. a collection where duplicates are allowed) on which we can perform the following two operations:

1. Insert a new value into the bag.
2. Count how many values in the bag are strictly greater than a given value.

We’ll call the second operation a rank query because it really amounts to finding the rank or index of a given value in the bag—how many values are greater than it (and thus how many are less than or equal to it)?

If we can do these two operations in logarithmic time (i.e. logarithmic in the number of values in the bag), then we can count inversions in $O(n \lg n)$ time. Can you see how before reading on? You might also like to think about how we could actually implement a data structure that supports these operations.

### Counting inversions with bags and rank queries

So, let’s see how to use a bag with logarithmic insertion and rank queries to count inversions. Start with an empty bag. For each element in the sequence, see how many things in the bag are strictly greater than it, and add this count to a running total; then insert the element into the bag, and repeat with the next element. That is, for each element we compute the number of inversions of which it is the right end, by counting how many elements that came before it (and are hence in the bag already) are strictly greater than it. It’s easy to see that this will count every inversion exactly once. It’s also easy to see that it will take $O(n \lg n)$ time: for each of the $n$ elements, we do two $O(\lg n)$ operations (one rank query and one insertion).
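As an executable specification to test against later (my code, not from the post), here is the obvious quadratic count, which literally enumerates all pairs of positions in the "wrong order":

```haskell
-- O(n^2) reference implementation: count pairs i < j with a_i > a_j.
inversionsNaive :: Ord a => [a] -> Int
inversionsNaive xs =
  length [ () | (i, a) <- zip [0 :: Int ..] xs
              , (j, b) <- zip [0 :: Int ..] xs
              , i < j, a > b ]
```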

In fact, we can do a lot more with this data structure than just count inversions; it sometimes comes in handy for competitive programming problems. More in a future post, perhaps!

So how do we implement this magical data structure? First of all, we can use a balanced binary search tree to store the values in the bag; clearly this will allow us to insert in logarithmic time. However, a plain binary search tree wouldn’t allow us to quickly count the number of values strictly greater than a given query value. The trick is to augment the tree so that each node also caches the size of the subtree rooted at that node, being careful to maintain these counts while inserting and balancing.

### Augmented red-black trees in Haskell

Let’s see some code! In Haskell, probably the easiest type of balanced BST to implement is a red-black tree. (If I were implementing this in an imperative language I might use splay trees instead, but they are super annoying to implement in Haskell, at least as far as I know. I will definitely take you out for a social beverage of your choice if you can show me an elegant Haskell implementation of splay trees! This is cool but somehow feels too complex.) However, this isn’t going to be some fancy, type-indexed, correct-by-construction implementation of red-black trees, although that is certainly fun. I am actually going to implement left-leaning red-black trees, mostly following Sedgewick; see those slides for more explanation and proof. This is one of the simplest ways I know to implement red-black trees (though it’s not necessarily the most efficient).

First, a red-black tree is either empty, or a node with a color (which we imagine as the color of the incoming edge), a cached size, a value, and two subtrees.

> {-# LANGUAGE PatternSynonyms #-}
>
> data Color = R | B
>   deriving Show
>
> otherColor :: Color -> Color
> otherColor R = B
> otherColor B = R
>
> data RBTree a
>   = Empty
>   | Node Color Int (RBTree a) a (RBTree a)
>   deriving Show


To make some of the tree manipulation code easier to read, we make some convenient patterns for matching on the structure of a tree when we don’t care about the values or cached sizes: ANY matches any tree and its subtrees, while RED and BLACK only match on nodes of the appropriate color. We also make a function to extract the cached size of a subtree.

> pattern ANY   l r <- Node _ _ l _ r
> pattern RED   l r <- Node R _ l _ r
> pattern BLACK l r <- Node B _ l _ r
>
> size :: RBTree a -> Int
> size Empty            = 0
> size (Node _ n _ _ _) = n


The next thing to implement is the workhorse of most balanced binary tree implementations: rotations. The fiddliest bit here is managing the cached sizes appropriately. When rotating, the size of the root node remains unchanged, but the new child node, as compared to the original, has lost one subtree and gained another. Note also that we will only ever rotate around red edges, so we pattern-match on the color as a sanity check, although this is not strictly necessary. The error cases below should never happen.

> rotateL :: RBTree a -> RBTree a
> rotateL (Node c n t1 x (Node R m t2 y t3))
>   = Node c n (Node R (m + size t1 - size t3) t1 x t2) y t3
> rotateL _ = error "rotateL on non-rotatable tree!"
>
> rotateR :: RBTree a -> RBTree a
> rotateR (Node c n (Node R m t1 x t2) y t3)
>   = Node c n t1 x (Node R (m - size t1 + size t3) t2 y t3)
> rotateR _ = error "rotateR on non-rotatable tree!"


To recolor a node, we just flip its color. We can then split a tree with two red subtrees by recoloring all three nodes. (The “split” terminology comes from the isomorphism between red-black trees and 2-3-4 trees; red edges can be thought of as “gluing” nodes together into a larger node, and this recoloring operation corresponds to splitting a 4-node into three 2-nodes.)

> recolor :: RBTree a -> RBTree a
> recolor Empty            = Empty
> recolor (Node c n l x r) = Node (otherColor c) n l x r
>
> split :: RBTree a -> RBTree a
> split (Node c n l@(RED _ _) x r@(RED _ _))
>   = (Node (otherColor c) n (recolor l) x (recolor r))
> split _ = error "split on non-splittable tree!"


Finally, we implement a function to “fix up” the invariants by doing rotations as necessary: if we have two red subtrees we don’t touch them; if we have only one right red subtree we rotate it to the left (this is where the name “left-leaning” comes from), and if we have a left red child which itself has a left red child, we rotate right. (This function probably seems quite mysterious on its own; see Sedgewick for some nice pictures which explain it very well!)

> fixup :: RBTree a -> RBTree a
> fixup t@(ANY (RED _ _) (RED _ _)) = t
> fixup t@(ANY _         (RED _ _)) = rotateL t
> fixup t@(ANY (RED (RED _ _) _) _) = rotateR t
> fixup t = t


We can finally implement insertion. First, to insert into an empty tree, we create a red node with size 1.

> insert :: Ord a => a -> RBTree a -> RBTree a
> insert a Empty = Node R 1 Empty a Empty


If we encounter a node with two red children, we perform a split before continuing. This may violate the red-black invariants above us, but we will fix it up later on our way back up the tree.

> insert a t@(ANY (RED _ _) (RED _ _)) = insert a (split t)


Otherwise, we compare the element to be inserted with the root, insert on the left or right as appropriate, increment the cached size, and fixup the result. Notice that we don’t stop recursing upon encountering a value that is equal to the value to be inserted, because our goal is to implement a bag rather than a set. Here I have chosen to put values equal to the root in the left subtree, but it really doesn’t matter.

> insert a (Node c n l x r)
>   | a <= x    = fixup (Node c (n+1) (insert a l) x r)
>   | otherwise = fixup (Node c (n+1) l x (insert a r))


### Implementing rank queries

Now, thanks to the cached sizes, we can count the values greater than a query value.

> numGT :: Ord a => RBTree a -> a -> Int


The empty tree contains 0 values strictly greater than anything.

> numGT Empty _ = 0


For a non-empty tree, we distinguish two cases:

> numGT (Node _ n l x r) q


If the query value q is less than the root, then we know that the root along with everything in the right subtree is strictly greater than q, so we can just add 1 + size r without recursing into the right subtree. We also recurse into the left subtree to count any values greater than q it contains.

>   | q < x     = numGT l q + 1 + size r


Otherwise, if q is greater than or equal to the root, any values strictly greater than q must be in the right subtree, so we recurse to count them.

>   | otherwise = numGT r q


By inspection we can see that numGT calls itself at most once, moving one level down the tree with each recursive call, so it makes a logarithmic number of calls, with only a constant amount of work at each call—thanks to the fact that size takes only constant time to look up a cached value.

### Counting inversions

Finally, we can put together the pieces to count inversions. The code is quite simple: recurse through the list with an accumulating red-black tree, doing a rank query on each value, and sum the results.

> inversions :: Ord a => [a] -> Int
> inversions = go Empty
>   where
>     go _ []     = 0
>     go t (a:as) = numGT t a + go (insert a t) as


Let’s try it out!

λ> inversions [3,5,1,4,2]
6
λ> inversions [2,2,2,2,2,1]
5
λ> :set +s
λ> inversions [3000, 2999 .. 1]
4498500
(0.19 secs, 96,898,384 bytes)

It seems to work, and is reasonably fast!

### Exercises

1. Further augment each node with a counter representing the number of copies of the given value which are contained in the bag, and maintain the invariant that each distinct value occurs in only a single node.

2. Rewrite inversions without a recursive helper function, using a scan, a zip, and a fold.

3. It should be possible to implement bags with rank queries using fingertrees instead of building our own custom balanced tree type (though it seems kind of overkill).

4. My intuition tells me that it is not possible to count inversions faster than $n \lg n$. Prove it.
