Note

Under development :) :) :)

Overview

In a previous post, we looked into what k-d trees are. In this post, we go deeper and attempt to implement a k-d tree in C++. The final code can be found here. We follow the implementation from the excellent book Advanced Algorithms and Data Structures by Marcello La Rocca, published by Manning Publications [2].

K-d trees part 2

Before we start implementing a k-d tree, let's recall that it is a binary search tree, i.e. a hierarchical data structure. Specifically, a k-d tree is a space-partitioning data structure for organizing points in a k-dimensional space [1]. Every node in a k-d tree represents a k-dimensional point [2]. At each level of the tree, only one coordinate, determined by the level, is used to decide whether a point goes to the left or to the right subtree. Furthermore, we will assume that the coordinates of a k-dimensional vector can be compared with each other.
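To make the "binary search tree on one coordinate per level" idea concrete, here is a minimal sketch of how a single node decides where a point goes. It assumes k = 2, points stored as std::array<double, 2>, and the common convention that the splitting coordinate at depth level is level % k (the same rule the partitioning code later in this post uses). The names split_axis and goes_left are hypothetical helpers, not part of the KDTree API.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// A k-dimensional point; k = 2 here purely for illustration.
using Point2 = std::array<double, 2>;

// At depth `level` only coordinate level % k is compared, so for k = 2
// the splitting axis cycles x, y, x, y, ...
std::size_t split_axis(std::size_t level, std::size_t k) {
    assert(k > 0);
    return level % k;
}

// Returns true if `query` belongs to the left subtree of a node storing
// `node_point` at depth `level`. In this sketch ties go to the right.
bool goes_left(const Point2& query, const Point2& node_point, std::size_t level) {
    const std::size_t axis = split_axis(level, query.size());
    return query[axis] < node_point[axis];
}
```

For a node storing (5, 5), the query (3, 7) goes left at depth 0 (comparing x: 3 < 5) but right at depth 1 (comparing y: 7 is not less than 5).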

Following [2], here is the exposed API:

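#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// assumption: uint_t is not defined in this listing; we take it
// to be an unsigned integer alias such as the one below
using uint_t = std::size_t;

template<typename NodeType>
class KDTree
{
public:

 typedef NodeType node_type;
 typedef typename node_type::data_type data_type;

 // create an empty tree for k-dimensional points
 KDTree(uint_t k);

 // create a tree for k-dimensional points from the range [begin, end)
 template<typename Iterator, typename SimilarityPolicy, typename ComparisonPolicy>
 KDTree(uint_t k, Iterator begin, Iterator end, 
        const SimilarityPolicy& sim_policy, 
        const ComparisonPolicy& comp_policy);

 bool empty()const noexcept;
 uint_t size()const noexcept;  // number of stored points
 uint_t dim()const noexcept;   // the dimension k

 // look up the node storing data, if any
 template<typename ComparisonPolicy>
 std::shared_ptr<node_type>
 search(const data_type& data, const ComparisonPolicy& comp_policy)const;

 // build the tree from the range [begin, end)
 template<typename Iterator, typename SimilarityPolicy, typename ComparisonPolicy>
 void build(Iterator begin, Iterator end,
             const SimilarityPolicy& sim_policy, 
             const ComparisonPolicy& comp_policy);

 // insert a single point
 template<typename ComparisonPolicy>
 std::shared_ptr<node_type>
 insert(const data_type& data, const ComparisonPolicy& comp_policy);

 // the n points nearest to data, each paired with a value
 // computed by the policy (e.g. its distance to data)
 template<typename ComparisonPolicy>
 std::vector<std::pair<typename ComparisonPolicy::value_type, typename NodeType::data_type>>
 nearest_search(const data_type& data, uint_t n, const ComparisonPolicy& calculator)const;

};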

The class above accepts the tree node type as a template parameter; the node type exposes the type of the data to be stored via data_type. In this sense, KDTree is a homogeneous container: all the points it stores have the same type.
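The post does not show the node type itself, so here is a minimal sketch of what one could look like. It is a hypothetical KDTreeNode: the only requirement visible in the API is the nested data_type typedef, and the constructor shape (level, data, left, right) is chosen to match the std::make_shared<NodeType>(level, data, nullptr, nullptr) calls in create_ further down. Using std::vector<double> as the point type is also an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical node type usable as KDTree<NodeType>.
template<typename DataType>
struct KDTreeNode {
    typedef DataType data_type;  // the type KDTree reads via node_type::data_type

    std::size_t level;                 // depth of the node; splitting coordinate is level % k
    data_type data;                    // the stored k-dimensional point
    std::shared_ptr<KDTreeNode> left;  // subtree on the lower side of the splitting coordinate
    std::shared_ptr<KDTreeNode> right; // subtree on the upper side of the splitting coordinate

    KDTreeNode(std::size_t level_, const data_type& data_,
               std::shared_ptr<KDTreeNode> left_,
               std::shared_ptr<KDTreeNode> right_)
        : level(level_), data(data_), left(std::move(left_)), right(std::move(right_)) {}
};

// Every node of a KDTree<Node> stores the same data_type,
// which is what makes the tree a homogeneous container.
using Node = KDTreeNode<std::vector<double>>;
```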

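According to the exposed API, we can construct a k-d tree in two ways: by specifying only the dimension k of the space, or by also passing a range of data to be stored in the tree. The first constructor creates an empty tree, which we can populate by calling either insert or, preferably, build. build is preferred because it splits the points on the median at every level and therefore produces a balanced tree. A plain insert, on the other hand, attaches each new point below an existing leaf, so the shape of the tree depends on the insertion order and can degenerate, e.g. into a chain when the points arrive sorted.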
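The difference between the two ways of populating the tree shows up in its height. The sketch below is independent of the KDTree class: it assumes k = 2, uses a hypothetical unique_ptr-based Node, a plain BST-style insert, and a median-based build, and simply measures the height of each resulting tree.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

using Point = std::array<double, 2>;

struct Node {
    Point p;
    std::unique_ptr<Node> left, right;
    explicit Node(const Point& pt) : p(pt) {}
};

// Plain insertion: follow coordinate level % 2 down to an empty slot.
void insert(std::unique_ptr<Node>& node, const Point& pt, std::size_t level = 0) {
    if (!node) {
        node = std::make_unique<Node>(pt);
        return;
    }
    const std::size_t axis = level % 2;
    insert(pt[axis] < node->p[axis] ? node->left : node->right, pt, level + 1);
}

using PointIt = std::vector<Point>::iterator;

// Median build: the median on the current axis becomes the root and the
// two halves are built recursively, keeping the height logarithmic.
std::unique_ptr<Node> build(PointIt begin, PointIt end, std::size_t level = 0) {
    assert(begin <= end);
    if (begin == end) return nullptr;
    const std::size_t axis = level % 2;
    const PointIt mid = begin + (end - begin) / 2;
    std::nth_element(begin, mid, end,
                     [axis](const Point& a, const Point& b) { return a[axis] < b[axis]; });
    auto node = std::make_unique<Node>(*mid);
    node->left = build(begin, mid, level + 1);
    node->right = build(mid + 1, end, level + 1);
    return node;
}

// Number of nodes on the longest root-to-leaf path.
std::size_t height(const std::unique_ptr<Node>& node) {
    return node ? 1 + std::max(height(node->left), height(node->right)) : 0;
}
```

With the 1023 points (0, 0), (1, 1), ..., (1022, 1022) inserted in order, every new point goes right at every level, so the tree becomes a chain of height 1023; building from the same points splits on the median and gives height 10.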

We can see that the exposed API does not have a remove or delete method. Typically, a k-d tree is constructed once and then remains as is. Furthermore, removing a node may result in an unbalanced tree, which means that fast lookups are no longer guaranteed. Although it is possible to re-balance the tree (see e.g. [2]), we won't pursue this path here. So let's concentrate on the rest of the methods, perhaps the most important of which is build.

In our implementation, we distinguish between similarity and comparison. A similarity metric, or policy, is used to decide whether two points are similar, i.e. close enough in the given metric. A comparison policy is used to compare individual coordinates of points. Thus, we use the similarity policy to search in the tree, and the comparison policy whenever a strict comparison of point coordinates is needed.
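A minimal sketch of what the two kinds of policy could look like follows. The call shape comp_policy(v1, v2, idx) mirrors how the comparison policy is invoked in partiion_on_median below; the similarity policy interface, the Euclidean metric and the tolerance member are assumptions made only for this illustration, as are the names CoordinateLess and EuclideanSimilarity.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Comparison policy: strict comparison of a single coordinate.
struct CoordinateLess {
    bool operator()(const std::vector<double>& v1, const std::vector<double>& v2,
                    std::size_t idx) const {
        assert(idx < v1.size() && idx < v2.size());
        return v1[idx] < v2[idx];
    }
};

// Similarity policy (hypothetical interface): two points are "similar"
// if their Euclidean distance does not exceed `tolerance`.
struct EuclideanSimilarity {
    double tolerance;

    bool operator()(const std::vector<double>& v1, const std::vector<double>& v2) const {
        assert(v1.size() == v2.size());
        double sum = 0.0;
        for (std::size_t i = 0; i < v1.size(); ++i) {
            const double d = v1[i] - v2[i];
            sum += d * d;
        }
        return std::sqrt(sum) <= tolerance;
    }
};
```

Keeping the two apart means the tree can order points by one coordinate at a time while the notion of "close enough" stays a separate, swappable decision.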

Building the tree

The build function accepts a range of iterators pointing to the data, a similarity policy and a comparison policy. It delegates all its work to create_, a private member function that implements the nuts and bolts of building the tree. The definition of create_ is shown below:

template<typename NodeType>
template<typename Iterator, typename SimilarityPolicy, 
         typename ComparisonPolicy>
void
KDTree<NodeType>::create_(Iterator begin, Iterator end, 
                          uint_t level, 
                          const SimilarityPolicy& sim_policy, 
                          const ComparisonPolicy& comp_policy){


    auto n_points = std::distance(begin, end);

    // nothing to do if no points
    // are given
    if(n_points == 0){
        return ;
    }

    if(n_points == 1){

        auto data = *begin;

        // create the root
        root_ = std::make_shared<NodeType>(level, data, nullptr, nullptr);
        ++n_nodes_;
        return;
    }

    // otherwise partition the range
    auto [median, left, right] = detail::partiion_on_median(begin, end, level, k_, comp_policy);

    // create root
    root_ = std::make_shared<NodeType>(level, median, nullptr, nullptr);
    ++n_nodes_;

    // create the left subtree from the points before the median
    auto left_tree = do_create_(left.first, left.second, level + 1, sim_policy, comp_policy);

    // create the right subtree from the points after the median
    auto right_tree = do_create_(right.first, right.second, level + 1, sim_policy, comp_policy);

    root_->left = left_tree;
    root_->right = right_tree;
}

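The implementation above is fairly straightforward. However, let's go over a few details. detail::partiion_on_median accepts a range of points, the current tree level, the dimension k and the comparison policy. It returns the median point at this level together with the iterator ranges of the points to the left and to the right of the median. The median becomes the root of the (sub)tree, and the two ranges are used to build its left and right subtrees.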

#include <algorithm>
#include <iterator>
#include <tuple>
#include <utility>

template<typename Iterator, typename ComparisonPolicy>
std::tuple<typename std::iterator_traits<Iterator>::value_type,
           std::pair<Iterator, Iterator>,
           std::pair<Iterator, Iterator>>
partiion_on_median(Iterator begin, Iterator end, 
                   uint_t level, uint_t k, 
                   const ComparisonPolicy& comp_policy){

    ...

    // the median index. For an odd number of points this is
    // the middle element; for an even number it is the upper
    // of the two middle elements
    auto median_idx = n_points / 2;

    // how to compare the data at the given
    // level. We use the level % k operation to decide
    // which coordinate to use
    auto compare = [&](const value_type& v1, const value_type& v2){
        auto idx = level % k;
        return comp_policy(v1, v2, idx); // e.g. v1[idx] < v2[idx]
    };

    // rearrange the elements (partial sorting).
    // std::nth_element requires random-access iterators
    std::nth_element(begin, begin + median_idx, end, compare);

    // get the data corresponding to the median
    auto median = *(begin + median_idx);

    // the ranges for the left and right sub-trees;
    // the median itself is excluded from both
    auto left = std::make_pair(begin, begin + median_idx);
    auto right = std::make_pair(begin + median_idx + 1, end);
    return std::make_tuple(median, left, right);

}
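Because the listing above elides its first lines, here is a separate, self-contained sketch of the same median partition step that can be compiled and tried on its own. It is not the author's code: the name split_on_median, the Point2 alias and the ByCoordinate comparator are hypothetical, and k = 2 is assumed in the usage note.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <tuple>
#include <utility>

using Point2 = std::array<double, 2>;

// Comparison policy with the shape comp(v1, v2, idx) -> bool.
struct ByCoordinate {
    bool operator()(const Point2& a, const Point2& b, std::size_t idx) const {
        return a[idx] < b[idx];
    }
};

// Returns the median on coordinate level % k plus the ranges before and
// after it. Requires random-access iterators, as std::nth_element does.
template<typename RandomIt, typename CoordLess>
std::tuple<typename std::iterator_traits<RandomIt>::value_type,
           std::pair<RandomIt, RandomIt>,
           std::pair<RandomIt, RandomIt>>
split_on_median(RandomIt begin, RandomIt end, std::size_t level, std::size_t k,
                const CoordLess& comp) {
    using value_type = typename std::iterator_traits<RandomIt>::value_type;
    const auto n_points = std::distance(begin, end);
    assert(n_points > 0 && k > 0);

    const std::size_t idx = level % k;             // splitting coordinate
    const RandomIt median = begin + n_points / 2;  // upper median if n_points is even

    std::nth_element(begin, median, end,
                     [&](const value_type& a, const value_type& b) { return comp(a, b, idx); });

    return std::make_tuple(*median,
                           std::make_pair(begin, median),      // [begin, median)
                           std::make_pair(median + 1, end));   // (median, end)
}
```

For the points (5, 1), (1, 9), (4, 4), (2, 7), (3, 2) at level 0, the x coordinates sorted are 1, 2, 3, 4, 5, so the median is (3, 2), and each side receives two points whose x coordinate is at most 3 (left) or at least 3 (right).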

The implementation above uses std::nth_element to partially sort the elements. According to the documentation:

nth_element is a partial sorting algorithm that rearranges elements in [first, last) such that:

  • The element pointed at by nth is changed to whatever element would occur in that position if [first, last) were sorted.
  • All of the elements before this new nth element are less than or equal to the elements after the new nth element.
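The two guarantees quoted above are easy to check on a small input. The sketch below (nth_element_holds is a hypothetical helper) compares the result of std::nth_element with a fully sorted copy. This is exactly why a partial sort suffices for finding the median: the standard only requires linear time on average for std::nth_element, compared with O(n log n) for a full sort.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Returns true if, after std::nth_element, position n holds the value a
// full sort would put there and nothing before n is greater than
// anything after n.
bool nth_element_holds(std::vector<int> v, std::size_t n) {
    assert(n < v.size());
    std::vector<int> sorted = v;
    std::sort(sorted.begin(), sorted.end());

    std::nth_element(v.begin(), v.begin() + n, v.end());

    if (v[n] != sorted[n]) return false;
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = n + 1; j < v.size(); ++j)
            if (v[i] > v[j]) return false;
    return true;
}
```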

We then build the left and right subtrees recursively using do_create_, passing level + 1 so that the next level splits on the next coordinate.
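do_create_ itself is not shown in this post. Judging from how create_ uses it, it returns the root of a subtree that the caller attaches as a left or right child. Below is a hypothetical sketch of such a helper (do_create), together with a checker for the k-d tree invariant: every point in a node's left subtree is at most, and every point in its right subtree at least, the node's value on the node's splitting coordinate. Point, Node, k = 2 in the usage note, and all function names are assumptions, not the author's implementation.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

using Point = std::array<double, 2>;

struct Node {
    std::size_t level;
    Point data;
    std::shared_ptr<Node> left, right;
    Node(std::size_t l, const Point& d, std::shared_ptr<Node> lt, std::shared_ptr<Node> rt)
        : level(l), data(d), left(std::move(lt)), right(std::move(rt)) {}
};

using It = std::vector<Point>::iterator;

// Builds the subtree for [begin, end) and returns its root (nullptr if empty).
std::shared_ptr<Node> do_create(It begin, It end, std::size_t level, std::size_t k) {
    assert(k > 0);
    if (begin == end) return nullptr;
    const std::size_t idx = level % k;
    const It median = begin + (end - begin) / 2;
    std::nth_element(begin, median, end,
                     [idx](const Point& a, const Point& b) { return a[idx] < b[idx]; });
    auto node = std::make_shared<Node>(level, *median, nullptr, nullptr);
    node->left = do_create(begin, median, level + 1, k);
    node->right = do_create(median + 1, end, level + 1, k);
    return node;
}

// True if every point of `sub` lies on the expected side of `split`
// along the splitting coordinate of `split`.
bool on_side(const std::shared_ptr<Node>& sub, const Node& split, std::size_t k, bool left) {
    if (!sub) return true;
    const std::size_t idx = split.level % k;
    const bool ok = left ? sub->data[idx] <= split.data[idx] : sub->data[idx] >= split.data[idx];
    return ok && on_side(sub->left, split, k, left) && on_side(sub->right, split, k, left);
}

// The invariant must hold at every node, not only at the root.
bool is_valid(const std::shared_ptr<Node>& node, std::size_t k) {
    if (!node) return true;
    return on_side(node->left, *node, k, true) && on_side(node->right, *node, k, false) &&
           is_valid(node->left, k) && is_valid(node->right, k);
}
```

For the six points (2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2), the sorted x coordinates are 2, 4, 5, 7, 8, 9; with the upper median convention the root is (7, 2), and is_valid holds for the whole tree because each recursive call only reorders points inside its own sub-range.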

References

  1. k-d tree.
  2. Marcello La Rocca, Advanced algorithms and data structures, Manning Publications.