I want to first explain why the code in the question assigns level numbers. This will lead us directly to two different solutions, one based on caching and one based on doing two traversals at once. Finally, I show how the second solution relates to the solutions provided by other answers.
What has to be changed in the code from the question?
The code in the question assigns the level number to each node. We can understand why the code behaves like that by looking at the recursive case of the number' function:
number' a (Node x xl xr) = Node (a,x) (number' (a+1) xl) (number' (a+1) xr)
Note that we use the same number, a + 1, for both recursive calls. So the root nodes of both subtrees get assigned the same number. If we want each node to have a different number, we had better pass different numbers to the recursive calls.
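To see the level numbering concretely, here is a minimal sketch. The Tree type and the Empty case are assumed to match the question; the helper labels and the tree example are made up for illustration.

```haskell
-- Tree type assumed to match the one in the question.
data Tree a = Empty | Node a (Tree a) (Tree a) deriving (Show)

-- Recursive case as in the question: both recursive calls receive a + 1.
number' :: Int -> Tree a -> Tree (Int, a)
number' _ Empty = Empty
number' a (Node x xl xr) = Node (a, x) (number' (a + 1) xl) (number' (a + 1) xr)

-- Hypothetical helper: the assigned numbers in left-to-right pre-order.
labels :: Tree (Int, a) -> [Int]
labels Empty = []
labels (Node (n, _) xl xr) = n : labels xl ++ labels xr

-- Made-up example: 'a' has the children 'b' and 'd', and 'b' has a left child 'c'.
example :: Tree Char
example = Node 'a' (Node 'b' (Node 'c' Empty Empty) Empty) (Node 'd' Empty Empty)

-- labels (number' 1 example) evaluates to [1,2,3,2]:
-- 'b' and 'd' sit on the same level and therefore share the number 2.
```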
What number should we pass to the recursive call?
If we want to assign the numbers according to a left-to-right pre-order traversal, then a + 1 is correct for the recursive call on the left subtree, but not for the recursive call on the right subtree. Instead, we want to leave out enough numbers to annotate the whole left subtree, and then start annotating the right subtree with the next number.
How many numbers do we need to reserve for the left subtree? That depends on the subtree's size, as computed by this function:
size :: Tree a -> Int
size Empty = 0
size (Node _ xl xr) = 1 + size xl + size xr
Back to the recursive case of the number' function. The smallest number annotated somewhere in the left subtree is a + 1. The biggest number annotated somewhere in the left subtree is a + size xl. So the smallest number available for the right subtree is a + size xl + 1. This reasoning leads to the following implementation of the recursive case for number' that works correctly:
number' :: Int -> Tree a -> Tree (Int, a)
number' a Empty = Empty
number' a (Node x xl xr) = Node (a,x) (number' (a+1) xl) (number' (a + size xl + 1) xr)
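As a quick check of the arithmetic, here is a self-contained sketch that runs this version on a small tree. The Tree type is assumed to match the question; labels and example are hypothetical helpers added for the demonstration.

```haskell
-- Tree type assumed to match the one in the question.
data Tree a = Empty | Node a (Tree a) (Tree a) deriving (Show)

size :: Tree a -> Int
size Empty = 0
size (Node _ xl xr) = 1 + size xl + size xr

-- The right subtree starts after the numbers reserved for the left subtree.
number' :: Int -> Tree a -> Tree (Int, a)
number' _ Empty = Empty
number' a (Node x xl xr) =
  Node (a, x) (number' (a + 1) xl) (number' (a + size xl + 1) xr)

-- Hypothetical helper: the assigned numbers in left-to-right pre-order.
labels :: Tree (Int, a) -> [Int]
labels Empty = []
labels (Node (n, _) xl xr) = n : labels xl ++ labels xr

-- Made-up example: 'a' has the children 'b' and 'd', and 'b' has a left child 'c'.
example :: Tree Char
example = Node 'a' (Node 'b' (Node 'c' Empty Empty) Empty) (Node 'd' Empty Empty)

-- At the root, a = 1 and size xl = 2, so the right subtree starts at 1 + 2 + 1 = 4.
-- labels (number' 1 example) evaluates to [1,2,3,4].
```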
Unfortunately, there is a problem with this solution: It is unnecessarily slow.
Why is the solution with size slow?
The function size traverses the whole tree. The function number' also traverses the whole tree, and it calls size on every left subtree. Each of these calls traverses the whole subtree. So overall, size gets executed more than once on the same node, even though it always returns the same value. In the worst case, a tree that only grows to the left, this makes number' quadratic in the number of nodes.
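To get a feeling for the repeated work, the following sketch counts how many Node constructors all the size calls made by number' visit in total. The names sizeWork and leftSpine are hypothetical helpers, not part of the question, and the Tree type is assumed to match it.

```haskell
-- Tree type assumed to match the one in the question.
data Tree a = Empty | Node a (Tree a) (Tree a) deriving (Show)

size :: Tree a -> Int
size Empty = 0
size (Node _ xl xr) = 1 + size xl + size xr

-- Hypothetical helper: number' calls size on the left subtree of every node,
-- and each such call visits exactly (size xl) Node constructors.
sizeWork :: Tree a -> Int
sizeWork Empty = 0
sizeWork (Node _ xl xr) = size xl + sizeWork xl + sizeWork xr

-- Hypothetical worst case: a tree with n nodes that only grows to the left.
leftSpine :: Int -> Tree ()
leftSpine n
  | n <= 0    = Empty
  | otherwise = Node () (leftSpine (n - 1)) Empty

-- sizeWork (leftSpine 100) evaluates to 99 + 98 + ... + 0 = 4950,
-- so the extra work grows quadratically with the number of nodes.
```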
How can we avoid traversing the tree when calling size?
I know two solutions: Either we avoid traversing the tree in the implementation of size by caching the sizes of all trees, or we avoid calling size in the first place by numbering the nodes and computing the size in one traversal.
How can we compute the size without traversing the tree?
We cache the size in every tree node:
data Tree a = Empty | Node Int a (Tree a) (Tree a) deriving (Show)
size :: Tree a -> Int
size Empty = 0
size (Node n _ _ _) = n
Note that in the Node case of size, we just return the cached size. So this case is not recursive, size does not traverse the tree, and the problem with our implementation of number' above goes away.
But the information about the size has to come from somewhere! Every time we create a Node, we have to provide the correct size to fill the cache. We can offload this task to smart constructors:
empty :: Tree a
empty = Empty
node :: a -> Tree a -> Tree a -> Tree a
node x xl xr = Node (size xl + size xr + 1) x xl xr
leaf :: a -> Tree a
leaf x = Node 1 x Empty Empty
Only node is really necessary, but I added the other two for completeness. If we always use one of these three functions to create a tree, the cached size information will always be correct.
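Here is a small usage sketch for the smart constructors. The checker cacheOk and the trees example and broken are hypothetical additions that only illustrate why going through node matters; leaf is written in terms of node here, which gives the same result as the definition above.

```haskell
data Tree a = Empty | Node Int a (Tree a) (Tree a) deriving (Show)

size :: Tree a -> Int
size Empty = 0
size (Node n _ _ _) = n

node :: a -> Tree a -> Tree a -> Tree a
node x xl xr = Node (size xl + size xr + 1) x xl xr

leaf :: a -> Tree a
leaf x = node x Empty Empty

-- Hypothetical checker: every cached size agrees with the sizes of its subtrees.
cacheOk :: Tree a -> Bool
cacheOk Empty = True
cacheOk (Node n _ xl xr) = n == 1 + size xl + size xr && cacheOk xl && cacheOk xr

-- Built only with smart constructors, so the caches are correct.
example :: Tree Char
example = node 'a' (node 'b' (leaf 'c') Empty) (leaf 'd')

-- Built with the raw constructor and a wrong size, so the cache is wrong.
broken :: Tree Char
broken = Node 7 'a' Empty Empty

-- size example evaluates to 4 without traversing the tree;
-- cacheOk example is True and cacheOk broken is False.
```

In a real module, one would probably export only the smart constructors and keep Node hidden, so that a tree like broken cannot be built from outside.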
Here is the version of number' that works with these definitions:
number' :: Int -> Tree a -> Tree (Int, a)
number' a Empty = Empty
number' a (Node _ x xl xr) = node (a,x) (number' (a+1) xl) (number' (a + size xl + 1) xr)
We have to adjust two things: When pattern matching on Node, we ignore the size information. And when creating a Node, we use the smart constructor node.
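Putting the pieces of the caching solution together gives the following self-contained sketch. The wrapper number is assumed to start at 1; labels and example are hypothetical helpers for the demonstration.

```haskell
data Tree a = Empty | Node Int a (Tree a) (Tree a) deriving (Show)

-- Constant time: just read the cached size.
size :: Tree a -> Int
size Empty = 0
size (Node n _ _ _) = n

node :: a -> Tree a -> Tree a -> Tree a
node x xl xr = Node (size xl + size xr + 1) x xl xr

leaf :: a -> Tree a
leaf x = node x Empty Empty

number' :: Int -> Tree a -> Tree (Int, a)
number' _ Empty = Empty
number' a (Node _ x xl xr) =
  node (a, x) (number' (a + 1) xl) (number' (a + size xl + 1) xr)

-- Assumption: numbering starts at 1.
number :: Tree a -> Tree (Int, a)
number = number' 1

-- Hypothetical helper: the assigned numbers in left-to-right pre-order.
labels :: Tree (Int, a) -> [Int]
labels Empty = []
labels (Node _ (n, _) xl xr) = n : labels xl ++ labels xr

example :: Tree Char
example = node 'a' (node 'b' (leaf 'c') Empty) (leaf 'd')

-- labels (number example) evaluates to [1,2,3,4],
-- and the numbered tree still caches the correct size 4.
```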
That works fine, but it has the drawback of having to change the definition of trees. On the one hand, caching the size might be a good idea anyway, but on the other hand, it uses some memory and it forces the trees to be finite, because the cached size of an infinite tree cannot be computed. What if we want to implement a fast number' without changing the definition of trees? This brings us to the second solution I promised.
How can we number the tree without computing the size?
We cannot. But we can number the tree and compute the size in a single traversal, avoiding the multiple size calls.
number' :: Int -> Tree a -> (Int, Tree (Int, a))
Already in the type signature, we see that this version of number' computes two pieces of information: The first component of the result tuple is the size of the tree, and the second component is the annotated tree.
number' a Empty = (0, Empty)
number' a (Node x xl xr) = (sl + sr + 1, Node (a, x) yl yr) where
  (sl, yl) = number' (a + 1) xl
  (sr, yr) = number' (a + sl + 1) xr
The implementation decomposes the tuples from the recursive calls and composes the components of the result. Note that sl is like size xl from the previous solution, and sr is like size xr. We also have to name the annotated subtrees: yl is the left subtree with node numbers, so it is like number' ... xl in the previous solution, and yr is the right subtree with node numbers, so it is like number' ... xr in the previous solution.
We also have to change number to only return the second component of the result of number':
number :: Tree a -> Tree (Int, a)
number = snd . number' 1
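For reference, the fragments of this solution assemble into the following self-contained sketch. The Tree type is assumed to match the question; labels and example are hypothetical helpers for trying it out.

```haskell
-- Tree type assumed to match the one in the question.
data Tree a = Empty | Node a (Tree a) (Tree a) deriving (Show)

-- Returns the size of the tree together with the numbered tree.
number' :: Int -> Tree a -> (Int, Tree (Int, a))
number' _ Empty = (0, Empty)
number' a (Node x xl xr) = (sl + sr + 1, Node (a, x) yl yr)
  where
    (sl, yl) = number' (a + 1) xl       -- size and numbering of the left subtree
    (sr, yr) = number' (a + sl + 1) xr  -- the right subtree starts after the left one

number :: Tree a -> Tree (Int, a)
number = snd . number' 1

-- Hypothetical helper: the assigned numbers in left-to-right pre-order.
labels :: Tree (Int, a) -> [Int]
labels Empty = []
labels (Node (n, _) xl xr) = n : labels xl ++ labels xr

-- Made-up example: 'a' has the children 'b' and 'd', and 'b' has a left child 'c'.
example :: Tree Char
example = Node 'a' (Node 'b' (Node 'c' Empty Empty) Empty) (Node 'd' Empty Empty)

-- fst (number' 1 example) evaluates to 4, the size of the tree,
-- and labels (number example) evaluates to [1,2,3,4].
```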
I think that in a way, this is the clearest solution.
What else could we improve?
The previous solution works by returning the size of the subtree. That information is then used to compute the next available node number. Instead, we could also return the next available node number directly.
number' a Empty = (a, Empty)
number' a (Node x xl xr) = (ar, Node (a, x) yl yr) where
  (al, yl) = number' (a + 1) xl
  (ar, yr) = number' al xr
Note that al is like a + sl + 1 in the previous solution, and ar is like a + sl + sr + 1. Clearly, this change avoids some additions.
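A complete sketch of this variant could look as follows. The type signature is the same as before, but the first component now means the next free number instead of the size. The Tree type is assumed to match the question; labels and example are hypothetical helpers.

```haskell
-- Tree type assumed to match the one in the question.
data Tree a = Empty | Node a (Tree a) (Tree a) deriving (Show)

-- Returns the next available number together with the numbered tree.
number' :: Int -> Tree a -> (Int, Tree (Int, a))
number' a Empty = (a, Empty)
number' a (Node x xl xr) = (ar, Node (a, x) yl yr)
  where
    (al, yl) = number' (a + 1) xl  -- al: first number not used by the left subtree
    (ar, yr) = number' al xr       -- ar: first number not used by the whole tree

number :: Tree a -> Tree (Int, a)
number = snd . number' 1

-- Hypothetical helper: the assigned numbers in left-to-right pre-order.
labels :: Tree (Int, a) -> [Int]
labels Empty = []
labels (Node (n, _) xl xr) = n : labels xl ++ labels xr

-- Made-up example: 'a' has the children 'b' and 'd', and 'b' has a left child 'c'.
example :: Tree Char
example = Node 'a' (Node 'b' (Node 'c' Empty Empty) Empty) (Node 'd' Empty Empty)

-- fst (number' 1 example) evaluates to 5, the next free number,
-- and labels (number example) evaluates to [1,2,3,4].
```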
This is essentially the solution from Sergey's answer, and I would expect that this is the version most Haskellers would write. You could also hide the manipulations of a, al and ar in a state monad, but I don't think that really helps for such a small example. The answer by Ankur shows what it would look like.
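To give a rough idea of the state monad variant, here is a sketch that does not depend on any package outside base. The State type below is a minimal hand-rolled stand-in for the one from Control.Monad.State, and the names fresh and numberM are hypothetical; this is not necessarily how Ankur's answer is written.

```haskell
-- Tree type assumed to match the one in the question.
data Tree a = Empty | Node a (Tree a) (Tree a) deriving (Show)

-- Minimal hand-rolled state monad, standing in for Control.Monad.State.
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f (State g) = State $ \s -> let (x, s') = g s in (f x, s')

instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  State f <*> State g = State $ \s ->
    let (h, s')  = f s
        (x, s'') = g s'
    in (h x, s'')

instance Monad (State s) where
  State g >>= f = State $ \s -> let (x, s') = g s in runState (f x) s'

-- Hypothetical helper: return the current counter and increment it.
fresh :: State Int Int
fresh = State $ \n -> (n, n + 1)

-- The counter plays the role of a, al and ar, but is threaded implicitly.
numberM :: Tree a -> State Int (Tree (Int, a))
numberM Empty = pure Empty
numberM (Node x xl xr) = do
  n  <- fresh
  yl <- numberM xl
  yr <- numberM xr
  pure (Node (n, x) yl yr)

number :: Tree a -> Tree (Int, a)
number t = fst (runState (numberM t) 1)

-- Hypothetical helper: the assigned numbers in left-to-right pre-order.
labels :: Tree (Int, a) -> [Int]
labels Empty = []
labels (Node (n, _) xl xr) = n : labels xl ++ labels xr

example :: Tree Char
example = Node 'a' (Node 'b' (Node 'c' Empty Empty) Empty) (Node 'd' Empty Empty)

-- labels (number example) evaluates to [1,2,3,4].
```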