Tail Recursive Tree Traversal

I’ve written an ebook about Programming Language Concepts.

Once in a while you may face a situation that requires you to walk a tree, do some kind of processing on each node and collect a result (or do other things).

If the tree is arbitrarily large and arbitrarily branched, you generally have two options (as always): you can either traverse the tree imperatively, or write an elegant, intuitive algorithm that is functional and recursive in nature. Here, of course, we will take the second approach.

The recursive approach, however, also poses a problem: if the algorithm is not tail recursive, a deep enough tree may blow the stack. We of course want to avoid this, and that is the reason for this blog post.

What is Tail Recursion?

First, for those who don’t have a clear understanding of it yet, I’ll try to explain what tail recursion is and why it is preferable to other kinds of recursion. I’ll explain it in a few different ways that all mean the same thing, and finally give an example.

A recursive call is said to be tail recursive when its result is not involved in any further processing.

Rephrase 1:

A recursive call is tail recursive when its caller has nothing more to do with the result once it arrives.

Rephrase 2:

A recursive call is tail recursive when it is the last thing the caller does.

Rephrase 3:

A recursive call is tail recursive when the result of this call can immediately be returned from the caller without any further steps to be done by the caller.

An Example

Let’s consider an example:

Not tail recursive:

(defn not-tail-recursive
  [n]
  (if (> n 0)
    (+ n (not-tail-recursive (dec n)))
    n))

A tail recursive equivalent:

(defn tail-recursive
  [n acc]
  (if (> n 0)
    (tail-recursive (dec n) (+ acc n))
    acc))

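These two functions are not too useful (for a non-negative n, both compute the sum 1 + 2 + ... + n, the second one when started with an accumulator of 0), but they show the difference between tail recursion and non-tail recursion.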

The first function is not tail recursive, because after the recursive call returns, its result still needs to be processed further: it has to be added to n. That means the runtime environment needs to keep track of where each recursive call happened in order to be able to return to that point afterwards. This in turn means the stack grows larger the more deeply the recursive calls are nested.
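To see that stack growth in practice, here is a small REPL sketch. The input 1000000 is an arbitrary choice, and the exact depth at which the JVM gives up depends on the thread stack size (the -Xss setting), so treat the overflow as typical rather than guaranteed.

```clojure
;; Same definition as above: the addition happens only after the recursive call returns.
(defn not-tail-recursive
  [n]
  (if (> n 0)
    (+ n (not-tail-recursive (dec n)))
    n))

(not-tail-recursive 100)
;; => 5050

;; Every pending (+ n ...) keeps a stack frame alive, so a large n exhausts the stack.
(try
  (not-tail-recursive 1000000)
  (catch StackOverflowError _
    :stack-overflow))
;; => :stack-overflow (with a default-sized JVM stack)
```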

Now to the second function. That one is in fact tail recursive, as the result of the recursive call needs no further processing. The call is the very last action that tail-recursive performs, so its result can be returned immediately. Code built this way allows for something called Tail Call Optimization, which avoids consuming more stack space for each tail recursive call. Wherever that optimization is applied, we no longer have to worry about recursion depth, because the stack does not grow with each call.

Converting Non-Tail-Recursive to Tail-Recursive

So, how did we get from the first version to the second tail recursive one?
A common pattern is to introduce a so-called “accumulator” parameter that is passed into the function. It holds the value of the calculation up to that point, which lets us get rid of the calculation that would otherwise have to be done after the recursive call returns (in our case: the addition). As you can see, we moved that addition from after the recursive call (in a temporal sense) to before it, by passing the result of that very addition as an argument into the next call. The initial call then has to supply the starting value of the accumulator, e.g. (tail-recursive 10 0).
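To make the shift of the addition visible, here is a hand-written evaluation trace for n = 3, using the two functions from above. The trace lines are comments; the final expression just confirms that both versions agree.

```clojure
(defn not-tail-recursive
  [n]
  (if (> n 0)
    (+ n (not-tail-recursive (dec n)))
    n))

(defn tail-recursive
  [n acc]
  (if (> n 0)
    (tail-recursive (dec n) (+ acc n))
    acc))

;; Without an accumulator, the additions wait until the innermost call returns:
;;   (not-tail-recursive 3)
;;   (+ 3 (not-tail-recursive 2))
;;   (+ 3 (+ 2 (not-tail-recursive 1)))
;;   (+ 3 (+ 2 (+ 1 (not-tail-recursive 0))))
;;   (+ 3 (+ 2 (+ 1 0)))  => 6
;;
;; With an accumulator, each addition happens before the next call,
;; so nothing is left to do when a call returns:
;;   (tail-recursive 3 0)
;;   (tail-recursive 2 3)
;;   (tail-recursive 1 5)
;;   (tail-recursive 0 6)  => 6

(= (not-tail-recursive 3) (tail-recursive 3 0))
;; => true
```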

Tail Call Optimization in Clojure

Many compilers for functional programming languages detect on their own whether a recursive call is in tail position, so they can apply tail call optimization automatically.

Clojure, however, cannot do this automatically, because it runs on the JVM, which does not (yet) support tail call optimization. For that reason, Clojure provides the recur special form, which the developer uses explicitly to make a self-recursive tail call. So, the second function from above would actually be written as

(defn tail-recursive
  [n acc]
  (if (> n 0)
    (recur (dec n) (+ acc n))
    acc))

which enables tail call optimization: instead of pushing a new stack frame, recur rebinds n and acc and jumps back to the top of the function.
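A quick way to check the difference is to run both tail recursive variants with a large input. In this sketch, tail-recursive-self is a hypothetical name for the earlier version that calls itself by name; since Clojure does not optimize that call, it still overflows, while the recur version runs in constant stack space. The last function, sum-to (also a hypothetical name), shows that recur can also target a loop form, which keeps the accumulator out of the function’s signature.

```clojure
;; Earlier version: a tail call by name, which Clojure does NOT optimize.
(defn tail-recursive-self
  [n acc]
  (if (> n 0)
    (tail-recursive-self (dec n) (+ acc n))
    acc))

;; recur version from above: the stack does not grow with each step.
(defn tail-recursive
  [n acc]
  (if (> n 0)
    (recur (dec n) (+ acc n))
    acc))

(tail-recursive 1000000 0)
;; => 500000500000

(try
  (tail-recursive-self 1000000 0)
  (catch StackOverflowError _
    :stack-overflow))
;; => :stack-overflow (with a default-sized JVM stack)

;; loop/recur variant: the accumulator starts at 0 inside the function.
(defn sum-to
  [n]
  (loop [n   n
         acc 0]
    (if (> n 0)
      (recur (dec n) (+ acc n))   ; jumps back to the loop head
      acc)))

(sum-to 1000000)
;; => 500000500000
```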

The Plan for our Tail Recursive Tree Traversal Algorithm

So, to work out a tail recursive way of traversing trees, we’ll go through two steps:

  1. Start off with a recursive algorithm that is rather naive (i.e. not tail recursive)
  2. Convert that algorithm to its tail recursive equivalent

The naive Implementation

So let’s start off with a simple tree in the form of a nested vector that looks like this:

user> (def tree [1 2 [3 4 [5 6] [7 8]] [9 10]])
#'user/tree

For the sake of simplicity, let’s say we want to traverse this tree and calculate the sum of all numbers in it. Note that all leaves are numbers (integers, in fact), so we don’t have to handle special cases like encountering a leaf that is not a number.

Our first naive, but still functional implementation might look like the following:

user> (defn traverse1
  [tree]
  (if (coll? tree)
    (apply + (map traverse1 tree))
    tree))
#'user/traverse1

user> (traverse1 tree)
55

What are we doing here? Well, it’s not that complicated:
If tree is a collection, traverse each of its elements, add up the results and return that sum. If tree is not a collection, in our case we can assume it is a number, so we just return it unchanged. This seems to work fine, as the final result of 55 tells us.
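The result is correct, but every level of nesting in the tree adds to the call stack. As a sketch, here is traverse1 applied to deep-tree, a hypothetical input consisting of a single leaf 1 wrapped in 100,000 nested vectors (the depth is an arbitrary choice). On a default-sized JVM stack this typically ends in a StackOverflowError.

```clojure
(defn traverse1
  [tree]
  (if (coll? tree)
    (apply + (map traverse1 tree))
    tree))

;; [[[...[1]...]]]: the leaf 1 wrapped in 100,000 vectors.
(def deep-tree (nth (iterate vector 1) 100000))

(traverse1 [1 2 [3 4 [5 6] [7 8]] [9 10]])
;; => 55

(try
  (traverse1 deep-tree)
  (catch StackOverflowError _
    :stack-overflow))
;; => :stack-overflow (with a default-sized JVM stack)
```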

Improving it

However, this implementation is not tail recursive: the recursive calls happen inside map, and their results still have to be collected into a sequence and then added up by apply +, so none of the calls is the last thing its caller does. So how can we get rid of this inconvenience?

Using an accumulator as above does not seem that easy here, because we need not just one recursive call but arbitrarily many: we are dealing with a tree instead of a list.

The solution to this problem is something that you could call a stack (not the runtime stack, but an explicit one we create ourselves). When dealing with a node that has several child branches, we’ll put these on that stack and recurse:

user> (defn traverse2
  [[node & nodes] acc]
  (cond (coll? node) (recur (concat node nodes) acc)
        node (recur nodes (+ acc node))
        :else acc))
#'user/traverse2

user> (traverse2 tree 0)
55

As you can see, destructuring comes in handy here, as we can bind the first element of our stack to node while leaving the rest in nodes. The algorithm goes like this:
If node is a collection, push its elements onto the front of the stack (that is what (concat node nodes) does) and call yourself again.
If node is neither a collection nor something that evaluates to false (in our case, this only applies to numbers), add node to the current value of our computation (the accumulator, acc) and call yourself again with the rest of the stack and the new accumulator value.
Otherwise, node is nil because the stack is empty, so just return the value in our accumulator. Note that a nil or false leaf in the middle of the tree would also end the traversal early; with our all-number tree this cannot happen.

The fact that we can use recur for every recursive call confirms that the function is in fact tail recursive; otherwise the Clojure compiler would have rejected it with a “Can only recur from tail position” error.
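Running traverse2 against a deeply nested input shows the payoff. The sketch repeats the definition from above (with comments added) and again uses a hypothetical deep-tree, a single leaf 1 wrapped in 100,000 vectors. Because every recursive step is a recur, the recursion itself does not grow the call stack.

```clojure
(defn traverse2
  [[node & nodes] acc]
  (cond (coll? node) (recur (concat node nodes) acc)  ; push the children onto the stack
        node         (recur nodes (+ acc node))       ; leaf: add it to the accumulator
        :else        acc))                            ; stack is empty: done

;; [[[...[1]...]]]: the leaf 1 wrapped in 100,000 vectors.
(def deep-tree (nth (iterate vector 1) 100000))

(traverse2 [1 2 [3 4 [5 6] [7 8]] [9 10]] 0)
;; => 55

(traverse2 deep-tree 0)
;; => 1
```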

Depth-First or Breadth-First Traversal

By the way: you can easily choose the kind of traversal (depth-first or breadth-first) by swapping the stack for a queue. For our Clojure example, this means nothing more than replacing

(concat node nodes)

with

(concat nodes node)
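Since a sum comes out the same in either order, the difference is easier to see if we collect the leaves instead of adding them. The sketch below uses two hypothetical functions, leaves-dfs and leaves-bfs. The depth-first one keeps the stack approach of traverse2, wrapped in a loop so no accumulator argument is needed. For the breadth-first one I swapped in Clojure’s persistent queue, clojure.lang.PersistentQueue, instead of (concat nodes node): repeatedly wrapping the pending nodes in another concat nests lazy sequences, which can itself overflow the stack on large inputs, whereas peek, pop and conj on the queue do not build up such nesting.

```clojure
(def tree [1 2 [3 4 [5 6] [7 8]] [9 10]])

;; Depth-first: children go to the FRONT of the pending nodes (a stack).
(defn leaves-dfs
  [tree]
  (loop [[node & nodes] tree
         acc []]
    (cond (coll? node) (recur (concat node nodes) acc)
          node         (recur nodes (conj acc node))
          :else        acc)))

;; Breadth-first: children go to the BACK of the pending nodes (a queue).
(defn leaves-bfs
  [tree]
  (loop [queue (into clojure.lang.PersistentQueue/EMPTY tree)
         acc   []]
    (if (empty? queue)
      acc
      (let [node  (peek queue)   ; front of the queue
            queue (pop queue)]   ; queue without its front element
        (if (coll? node)
          (recur (into queue node) acc)     ; enqueue children at the back
          (recur queue (conj acc node)))))))

(leaves-dfs tree)
;; => [1 2 3 4 5 6 7 8 9 10]

(leaves-bfs tree)
;; => [1 2 3 4 9 10 5 6 7 8]
```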

Recap

So, whenever you are dealing with recursive tree traversal, you can turn your non-tail-recursive recursion into a tail recursive one by adding an accumulator that holds the current value of your computation, and by putting the nodes of your tree onto an explicit stack. That way you pop off and process one node at a time, instead of having to deal with several nodes at once.

In Clojure, the functions in clojure.walk (like prewalk) will most probably be sufficient for simple and common traversal problems like the one in this post. However, I have encountered problems that I wasn’t able to solve with these convenience functions, so I had to implement my own traversal algorithm. It’s always good to know how to tackle these problems yourself, for example when you have to solve them in languages other than Clojure.
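For comparison, here is a sketch of how the sum from this post might look with built-in helpers. clojure.walk/postwalk replaces every collection, bottom-up, by the sum of its already-summed elements; clojure.core/tree-seq lists all nodes depth-first, so we can keep the numbers and add them. Note that postwalk is an ordinary (non-tail) recursion over the tree, so the hand-written approach is still worth knowing for extremely deep input.

```clojure
(require '[clojure.walk :as walk])

(def tree [1 2 [3 4 [5 6] [7 8]] [9 10]])

;; postwalk visits children before their parent, so by the time f sees a
;; vector, its elements have already been replaced by their sums.
(walk/postwalk (fn [x] (if (coll? x) (apply + x) x)) tree)
;; => 55

;; tree-seq returns every node (branches and leaves) in depth-first order;
;; keep only the numbers and add them up.
(apply + (filter number? (tree-seq coll? seq tree)))
;; => 55
```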
