Competitive programming in Haskell: building unordered trees

In my previous post I challenged you to solve Subway Tree System, which encodes trees by recording sequences of steps taken away from and towards the root while exploring the whole tree, and asks whether two such recordings denote the same tree. There are two main difficulties here: the first is how to parse the input; the second is how to compare two trees when we don’t care about the order of children at each node. Thanks to all of you who posted your solutions; I learned a lot. I often feel like my solution is obviously the “only” solution, but then when I see how others solve a problem I realize that the solution space is much larger than I thought!

My solution

Here’s my solution, with some commentary interspersed. First, some pragmas and imports and such:

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections     #-}

import           Control.Arrow
import           Data.Bool
import qualified Data.ByteString.Lazy.Char8  as C
import           Data.Function
import           Data.List
import           Data.List.Split
import           Data.Map                    (Map)
import qualified Data.Map                    as M

import           Text.Parsec
import           Text.Parsec.ByteString.Lazy
import           Text.Parsec.Char

My main then looks like this:

main = C.interact $
  C.lines >>> drop 1 >>> chunksOf 2 >>>
  map (solve >>> bool "different" "same") >>> C.unlines

The use of ByteString instead of String isn’t really necessary for this problem, just habit. I split the input into lines, group them in twos using Data.List.Split.chunksOf, solve each test case, and turn the output into different or same appropriately. (Data.Bool.bool is the fold/case analysis for the Bool type; I never use it in any other Haskell code but am unreasonably fond of it for this particular use case.) It would also be possible to use the Scanner abstraction instead of lines, drop, and chunksOf, as commenter blaisepascal2014 did. In some ways that would actually be nicer, but I often default to using these more basic tools in simple cases.


Now for parsing the trees. The parsing is not too bad, and several commenters essentially did it manually with a recursive function manipulating a stack and so on; the most creative used a tree zipper to literally walk around the tree being constructed, just like you are supposedly walking around a subway in the problem. However, the parsec package is available in the Kattis environment, so the easiest thing is to actually whip up a proper little parser. (I know of several other Kattis problems which can be nicely solved using parser combinators but would be annoying otherwise, for example, Calculator and Otpor. A rather fiendish but fun parsing puzzle is Learning to Code.)

readTree :: C.ByteString -> Tree
readTree = parse parseTree "" >>> either undefined id
  where
    parseTree    = Node     <$> parseForest
    parseForest  = fromList <$> many parseSubtree
    parseSubtree = char '0' *> parseTree <* char '1'

Of course I haven’t actually shown the definition of Tree, Node, or fromList yet, but hopefully you get the idea. either undefined id is justified here since the input is guaranteed to be well-formed, so the parser will never actually fail with a Left.
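For comparison, here is a minimal sketch of the hand-rolled approach, assuming plain String input and a throwaway ordered tree type. The names Rose, forest, and readRose are made up for this illustration and are not taken from any commenter's code; like the parsec version, it assumes the input is well-formed.

```haskell
-- A hypothetical ordered tree type, used only for this sketch.
newtype Rose = Rose [Rose]
  deriving (Eq, Show)

-- Read as many sibling subtrees as possible from the front of the input,
-- returning them together with the unconsumed remainder.  A '0' steps down
-- into a child; the matching '1' steps back up to the parent.
forest :: String -> ([Rose], String)
forest ('0' : s) =
  let (kids, afterKids) = forest s                  -- children of this subtree
      (sibs, rest)      = forest (drop 1 afterKids) -- skip the matching '1'
  in  (Rose kids : sibs, rest)
forest s = ([], s)                                  -- a '1' or end of input

-- The whole string describes the forest hanging off the root.
readRose :: String -> Rose
readRose = Rose . fst . forest
```

With these definitions, readRose "0101" is a root with two leaf children, whereas readRose "0011" is a root with a single child that itself has one leaf child.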

Unordered trees

The other difficulty is how to compare trees up to reordering children. Trying all permutations of the children at each node and seeing whether any match is obviously going to be much too slow! The key insight, and what this problem had in common with the one from my previous post, is that we can use an (automatically-derived) Ord instance to sort the children at each node into a canonical order. We don’t really need to know or care what order they end up in, which depends on the precise details of how the derived Ord instance works. The point is that sorting into some consistent order allows us to efficiently test whether two lists are permutations of each other.
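The sorting trick can be sketched in isolation. The helper name isPermutationOf is hypothetical, chosen just for this example; the point is only that comparing sorted copies takes O(n log n) comparisons instead of searching through n! orderings.

```haskell
import Data.List (sort)

-- Two lists are permutations of each other exactly when their sorted
-- versions are equal; which order `sort` picks does not matter, as long
-- as it is the same for both lists.
isPermutationOf :: Ord a => [a] -> [a] -> Bool
isPermutationOf xs ys = sort xs == sort ys
```

For instance, isPermutationOf "abc" "cab" holds, while isPermutationOf "aab" "abb" does not, since sorting preserves how many times each element occurs.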

I think everyone who posted a solution created some kind of function to “canonicalize” a tree, by first canonicalizing all subtrees and then sorting them. When I first solved this problem, however, I approached it along slightly different lines, hinted at by commenter Globules: can we define the Tree type in such a way that there is only a single representation for each tree-up-to-reordering?

My first idea was to use a Data.Set of children at each node, but this is subtly wrong, since it gets rid of duplicates! We don’t actually want a set of children at each node, but rather a bag (aka multiset). So I made a little Bag abstraction out of a Map. The magical thing is that GHC can still derive an Ord instance for my recursive tree type containing a newtype containing a Map containing trees! (OK, OK, it’s not really magic, but it still feels magic…)

Now, actually, I no longer think this is the best solution, but it’s interesting, so I’ll leave it. Later on I will show what I think is an even better solution.

newtype Tree = Node (Bag Tree)
  deriving (Eq, Ord)

newtype Bag a = Bag (Map a Int)
  deriving (Eq, Ord)

fromList :: Ord a => [a] -> Bag a
fromList = map (,1) >>> M.fromListWith (+) >>> Bag
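To see concretely why a set of children is not enough, here is a small sketch contrasting the two representations. The types SetTree and BagTree and their helper functions are hypothetical names introduced only for this comparison; the bag version mirrors the Map-based idea above without the Bag newtype.

```haskell
import qualified Data.Map as M
import qualified Data.Set as S

-- Children stored in a Set: duplicate subtrees collapse into one.
newtype SetTree = SetNode (S.Set SetTree)
  deriving (Eq, Ord)

-- Children stored as a Map from subtree to multiplicity: duplicates are counted.
newtype BagTree = BagNode (M.Map BagTree Int)
  deriving (Eq, Ord)

setNode :: [SetTree] -> SetTree
setNode = SetNode . S.fromList

bagNode :: [BagTree] -> BagTree
bagNode ts = BagNode (M.fromListWith (+) [(t, 1) | t <- ts])

setLeaf :: SetTree
setLeaf = setNode []

bagLeaf :: BagTree
bagLeaf = bagNode []
```

With the Set representation, setNode [setLeaf, setLeaf] == setNode [setLeaf] is True, so a root with two leaf children would wrongly be identified with a root that has only one. With the Map representation the analogous comparison is False, while reordering the children still yields equal values.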

The final piece is the solve function, which simply calls readTree on the two strings and compares the resulting (canonical!) Tree values for equality.

solve :: [C.ByteString] -> Bool
solve [t1,t2] = ((==) `on` readTree) t1 t2

A better way

I still think it’s a nice idea to have canonical-by-construction trees, rather than building ordered trees and then calling a separate function to canonicalize them afterwards. But inspired by several commenters’ solutions, I realized that rather than my complex Bag type, it’s much nicer to simply use a sorted list as the canonical representation of a Node’s bag of subtrees, and to use a smart constructor to build them:

newtype Tree = Node [Tree]
  deriving (Eq, Ord)

mkNode :: [Tree] -> Tree
mkNode = Node . sort

Then we just use mkNode instead of Node in the parser, and voilà! The canonicalization happens on the fly while parsing the tree. By contrast, if we write a separate canonicalization function, like

canonical :: Tree -> Tree
canonical (Node ts) = Node (map canonical (sort ts))

it is actually possible to get it wrong. In fact, I deliberately introduced a bug into the above function: can you see what it is?

All told, then, here is the (in my opinion) nicest solution that I know of:

{-# LANGUAGE OverloadedStrings #-}

import           Control.Arrow
import           Data.Bool
import qualified Data.ByteString.Lazy.Char8 as C
import           Data.Function
import           Data.List

import           Text.Parsec
import           ScannerBS                  hiding (many)

main = C.interact $
  runScanner (numberOf (two str)) >>>
  map (solve >>> bool "different" "same") >>> C.unlines

solve :: [C.ByteString] -> Bool
solve [t1,t2] = ((==) `on` readTree) t1 t2

newtype Tree = Node [Tree] deriving (Eq, Ord)

readTree :: C.ByteString -> Tree
readTree = parse parseTree "" >>> either undefined id
  where
    parseTree    = (Node . sort) <$> many parseSubtree
    parseSubtree = char '0' *> parseTree <* char '1'

Next problem

For Tuesday, I invite you to solve The Power of Substitution. Don’t let the high difficulty rating scare you; in my estimation it should be quite accessible if you know a bit of math and have been following along with some of my previous posts (YMMV). However, it’s not quite as obvious what the nicest way to write it in Haskell is.


About Brent

Associate Professor of Computer Science at Hendrix College. Functional programmer, mathematician, teacher, pianist, follower of Jesus.
This entry was posted in competitive programming, haskell.

26 Responses to Competitive programming in Haskell: building unordered trees

  1. Arnaud Bailly says:

    I am so stupid: I did not realize until now the string was simply a well-formed string of nested parens! Thanks to your parsec-based solution, it’s now obvious :)

  2. Arnaud Bailly says:

    Also, a quick and dirty “solution” to the Power of Substitution problem:
    I think laziness should make it efficient because in the end we’ll only need to compute and compare one item for each iteration of the message.

  3. Jason Hooper says:

    For Power of Substitution, the naive solution times out for me, so I used a lazily-evaluated array (bounds are (1,1)-(100,100), computing as-needed the distance from a number to its ciphered number). This brings the time down to 0.02s on Kattis. However, it’s not correct: the second test is failing. I suspect I’m off by a simple idea somewhere.

    • Brent says:

      I can see what is wrong but I will only give a hint if you want one. =)

      • Jason Hooper says:

        Sure, thanks! I figured it might be zero distances causing the LCM to be 0 overall, but that wasn’t it. Can you suggest an input that the code wouldn’t work on?

        • Brent says:

          Sure, how about

          1 4
          3 5
          2 3 1 5 4 6 7 8 9 …

          where the … means to continue counting up to 100. I’ll let you figure out what the answer should be.

  4. Arnaud Bailly says:

    This problem is fun, and I am completely stuck trying to go faster with mutable arrays, but to no avail. I suspect I am missing some important property related to permutation groups that alleviates the need for repeatedly comparing the lists… Eagerly awaiting the solution :)

  5. shaurya gupta says:

    I’m enjoying the current format of CP in Haskell. Here’s my solution to Power of Substitution.

    If we start walking from any point in the permutation graph, we will eventually land in a cycle.
    So I first brute force n (size of character set, here 100) steps; this ensures that I have either found my answer or I’m in a loop. After that, I find the length of the loop in which the character (i.e. node in permutation graph) lies and the distance between the current character and the final transformed character, and use CRT to find the final answer.

    I am sure there must be a lot more elegant way of writing this than I have written here.
    The main parts are:

    solve :: ([Int], [Int], V.Vector Int) -> Int
    solve (m, c, p) = either id undefined $ go m 0 >>= go'
      where
        go :: [Int] -> Int -> Either Int [Int]
        go m acc
          | m == c    = Left acc
          | acc == n  = Right m
          | otherwise = go (substitute m) (acc + 1)
        go' :: [Int] -> Either Int [Int]
        go' m =
          [(stepsTo substitute st nd, cycleSize substitute st) | (st, nd) <- …

    cycleSize :: (Eq a) => (a -> a) -> a -> Int
    cycleSize f x = 1 + stepsTo f (f x) x

    stepsTo :: (Eq a) => (a -> a) -> a -> a -> Int
    stepsTo f st nd = go 0 st
      where
        go acc x = if x == nd then acc else go (acc + 1) (f x)

    • shaurya gupta says:

      It would be nice if the comments systems had code blocks and the ability to make edits!

    • blaisepascal2014 says:

      I’m working on a similar idea, but I haven’t gotten my CRT code to work yet. I don’t have a full test, yet.

      The heart of my code is:

      solve kase = show $ fst $ foldr1 crt klPairs
        where
          -- Read each line as a list of numbers, shifted to be 0-based.
          [_,message,ctext,perm] = map (map (subtract 1 . read) . words) kase
          mcPairs = zip message ctext
          klPairs = map findKl mcPairs
          findKl (m, c) = (fromJust $ elemIndex c cycle, length cycle)
            where cycle = cycleFrom perm m

      cycleFrom perm m = m : takeWhile (/= m) (tail (iterate (perm!!) m))

      mcPairs is a list of pairs of message characters and cyphertext characters. klPairs is a list of pairs (k_i, l_i): how many iterations of the permutation are necessary to bring the given plaintext character to the cyphertext character, together with the length of that character's cycle. For each symbol in the message, we have an equation n ≡ k_i (mod l_i). Folding over the list of (k_i, l_i) pairs solves for n.

  6. Connor Baker says:

    First, thank you for introducing me to Control.Arrow — I learned about them from your previous two write-ups. Similarly, thank you for introducing me to the `bool` and `chunksOf` functions!

    My “solution” (it times out on the second secret test case, so not really a solution, I think) uses Vector’s `unsafeBackpermute` for most of the heavy lifting:

    module Main where

    import Prelude
    import Control.Arrow
    import Data.List.Split ( chunksOf )
    import Data.Maybe ( fromJust )
    import qualified Data.ByteString.Lazy.Char8 as C
    import qualified Data.Vector.Unboxed as V

    -- Zero index
    parse :: [C.ByteString] -> [V.Vector Int]
    parse =
      map (V.fromList . map (flip (-) 1 . (fst . fromJust . C.readInt)) . C.words)

    -- Vector's unsafeBackpermute does the heavy lifting for us
    solve :: [V.Vector Int] -> Int
    solve [ms, cs, ps] =
      (length . takeWhile (/= cs) . iterate (V.unsafeBackpermute ps)) ms
    solve _ = error "Malformed test case."

    main :: IO ()
    main =
      C.interact
        $   C.lines         -- Each line is an element of the list
        >>> drop 1          -- We don't need the number of test cases
        >>> chunksOf 4      -- Each test case has four lines
        >>> map
              (   drop 1    -- We don't need the first line of the test case
              >>> parse     -- Transform the input to data
              >>> solve     -- Manipulate the data
              >>> show      -- Transform back to text
              >>> C.pack    -- Transform to ByteString
              )
        >>> C.unlines       -- Print each result on a different line

    I’m very much looking forward to your writeup!

    • Connor Baker says:

      I don’t know why that gist is appearing inline (it didn’t before…) and I’m not able to edit my comment to try to replace it with just the link.

  7. Justin says:

    Struggling to get my solution running fast enough to pass the second test case. Does “CRT” mean the “Chinese Remainder Theorem” and is that the math trick that helps speed this up? Would love a hint :)

    • Brent says:

      Yes, and yes =). See my previous blog post on the topic!

      • Justin says:

        I think I figured out the connection. I started graphing the values for each iteration of E for a single character (e.g., m_1, m_2, etc.). I realized that, no matter where you start in the substitution table, you will eventually cycle back to that starting position. So you can calculate a “length” for each cycle, starting at a given index in the table. I realized those cycles were the connection to modular arithmetic. For example, if my cycle is 3 steps, then to find the minimum k necessary to get from a starting character to a final character, I can just find the position of the final character in the cycle (0, 1, … ).

        Solving for an entire message means finding where all those cycles line up (that’s where graphing helped). The CRT provides the tool for finding where that happens.

        For any given starting position in the substitution, if you apply E enough, you will cycle back to that starting position. Then the length of each cycle is the modulus on the RHS of your congruence relation (one relation for each character in the message). The modulee (?) is the index of the particular cipher character in that cycle. In equations:

        x = steps(m_1, c_1) (mod cycle_length(p[m_1]))
        x = steps(m_2, c_2) (mod cycle_length(p[m_2]))

        m_1, m_2 are the characters in the message, p is the substitution table, and c_1, c_2 are the characters in the target message. steps is the number of iterations from m_1 to c_1, starting at p[m_1] (and so on). cycle_length(p[m_1]) is the number of times you have to iterate, starting at p[m_1], until you return to p[m_1].

        Solving CRT for x gives you the number of steps necessary to produce the desired ciphertext.

        Now I just have to implement it!

        p.s. Also there could be some fundamental errors that I’m gonna run into …

      • Justin says:

        Did it! Thanks for the hint. I also used your `egcd` and `gcrt` code because that would have taken me another week to solve!

        Code at

        Passed kattis in 0.01 seconds!

  8. Globules says:

    I won’t post my code here because it’s currently a sprawling mess, and I’ve only tried it on some small test cases (which have worked). But, here’s the gist of it from a comment in the code:

    -- The plaintext is [m1, m2, ..., ml]; the ciphertext is [c1, c2, ..., cl].
    -- The encryption function is a permutation, which I treat as a product of
    -- cycles.  One round of encryption takes each element of its input to the next
    -- element in its corresponding cycle, wrapping around as necessary.
    -- Define the "initial distance" between the plaintext and final ciphertext to
    -- be [d1, d2, ..., dl], where each di is the number of forward steps required
    -- to get from mi to ci in the cycle to which they both belong.
    -- Let L be the length of the longest cycle.  (There may be more than one.)
    -- Let D be a di belonging to one of these longest cycles.  (The value of k that
    -- we seek will be at least D.  It could be 0.)
    -- Define the new distance to be [(d1 - D) mod l1, (d2 - D) mod l2, ..., (dl -
    -- D) mod ll].
    -- If all the di are 0 then k = D, and we're done.  Otherwise k = D + r*L, where
    -- r is the number of times we must take L steps until all the di are 0.
    -- Analogy: I think of the permutation as a set of connected gears, each gear
    -- corresponding to a cycle, and having as many teeth as the length of its
    -- cycle.  The central gear has L teeth, with tooth mi at the top.  We need to
    -- rotate it by D to get tooth ci to the top.  If all the other gears point to
    -- their respective ci then we're done.  Otherwise, we need to make enough full
    -- rotations (r, above) of the central gear until the other gears are in their
    -- final positions.

    Currently, I’m not using the CRT in my code, but would expect to use it to solve for r, above. I.e. a set of equations of the form Lr = di (mod li), using my notation from the comment.

    • Globules says:

      Ok, it’s still a bit of a mess, but passes Kattis. I’ll only include part of it, below. It uses Brent’s gcrt function.

      type I = Integer
      -- The orbit of one element of a permutation.
      orbit :: [I] -> I -> [I]
      orbit xs x = x : takeWhile (/= x) (drop 1 $ iterate (genericIndex xs) x)
      -- A permutation as a product of cycles.
      cycles :: [I] -> [[I]]
      cycles ps = go ps
        where go []     = []
              go (x:xs) = let os = orbit ps x in os : go (xs \\ os)
      -- Map each element of a permutation to the cycle to which it belongs.
      cyclesMap :: Ord a => [[a]] -> [a] -> [(a, [a])]
      cyclesMap css = map cycleMap
        where cycleMap x = (x, fromJust $ find (x `elem`) css)
      -- The number of steps needed to get from y to z in xs, including any necessary
      -- wraparound.
      steps :: Eq a => [a] -> a -> a -> I
      steps xs y z = let dy = fromIntegral $ fromJust $ elemIndex y xs
                         dz = fromIntegral $ fromJust $ elemIndex z xs
                     in (dz - dy) `mod` genericLength xs
      solve :: [[I]] -> I
      solve args =
        let [_, ms, cs, ps] = map (map (subtract 1)) args -- 0-based lists
            cycs = cycles ps
            msgCycles = cyclesMap cycs ms
            cphCycles = cyclesMap cycs cs
            -- stps is [(s, l)], where s is # steps to reach elt in cs, l is len of cycle
            stps = zipWith (\(m, xs) (c, _) -> (steps xs m c, genericLength xs)) msgCycles cphCycles
        in fst $ fromJust $ gcrt stps
