K-d trees part 2
Implement a k-d tree in C++
Under development :) :) :)
In a previous post, we looked into what k-d trees are. In this post we go a step further and implement a k-d tree in C++. The final code can be found here. We will follow the implementation from the excellent book Advanced Algorithms and Data Structures by Marcello La Rocca, published by Manning Publications [2].
Before we start implementing a k-d tree, let's recall that it is a binary search tree, i.e. a hierarchical data structure. Specifically, a k-d tree is a space-partitioning data structure for organizing points in a k-dimensional space [1]. Every node in the tree represents a k-dimensional point [2]. Furthermore, we will assume that the coordinates of the k-dimensional points can be compared with each other.
Following [2], here is the exposed API:
template<typename NodeType>
class KDTree
{
public:

    typedef NodeType node_type;
    typedef typename node_type::data_type data_type;

    KDTree(uint_t k);

    template<typename Iterator, typename SimilarityPolicy, typename ComparisonPolicy>
    KDTree(uint_t k, Iterator begin, Iterator end,
           const SimilarityPolicy& sim_policy,
           const ComparisonPolicy& comp_policy);

    bool empty()const noexcept;
    uint_t size()const noexcept;
    uint_t dim()const noexcept;

    template<typename ComparisonPolicy>
    std::shared_ptr<node_type>
    search(const data_type& data, const ComparisonPolicy& comp_policy)const;

    template<typename Iterator, typename SimilarityPolicy, typename ComparisonPolicy>
    void build(Iterator begin, Iterator end,
               const SimilarityPolicy& sim_policy,
               const ComparisonPolicy& comp_policy);

    template<typename ComparisonPolicy>
    std::shared_ptr<node_type>
    insert(const data_type& data, const ComparisonPolicy& comp_policy);

    template<typename ComparisonPolicy>
    std::vector<std::pair<typename ComparisonPolicy::value_type, typename NodeType::data_type>>
    nearest_search(const data_type& data, uint_t n, const ComparisonPolicy& calculator)const;
};
The class above accepts the tree node type as a template parameter, and the node type in turn exposes the type of the data to be stored. In this respect, the KDTree is a homogeneous container.
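The requirements on NodeType can be read off the way the class uses it: it exposes a data_type, it is constructed from a level, the data and two child pointers, and it has left and right members. A minimal sketch of a node that satisfies these requirements might look as follows. The name KDTreeNode, the alias chosen for uint_t and the helper split_coordinate are assumptions made for illustration and are not taken from the original code.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>

// Assumption: uint_t is an alias for an unsigned integer type.
using uint_t = std::size_t;

// Hypothetical node type. Member names are inferred from how
// the KDTree code in this post uses NodeType.
template<typename DataType>
struct KDTreeNode
{
    typedef DataType data_type;

    uint_t level;                       // depth of the node in the tree
    data_type data;                     // the k-dimensional point
    std::shared_ptr<KDTreeNode> left;   // points before the split
    std::shared_ptr<KDTreeNode> right;  // points after the split

    KDTreeNode(uint_t level_, const data_type& data_,
               std::shared_ptr<KDTreeNode> left_,
               std::shared_ptr<KDTreeNode> right_)
        : level(level_), data(data_),
          left(std::move(left_)), right(std::move(right_))
    {}
};

// Hypothetical helper: the coordinate a node splits on,
// using the level % k rule.
template<typename NodeType>
uint_t split_coordinate(const NodeType& node, uint_t k)
{
    assert(k > 0);
    return node.level % k;
}
```

With this sketch, KDTree<KDTreeNode<std::array<double, 2>>> would store two-dimensional points, and a node at level 3 in a two-dimensional tree would split on coordinate 1.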
According to the exposed API we can construct a k-d tree in two ways: by specifying the dimension k of the space, or by additionally passing a range of data to be stored in the tree. The first constructor creates an empty tree. We can populate this tree by calling either insert or, preferably, build. We will explain below why this is the preferred method.
We can see that the exposed API does not have a remove or delete method. Typically, a k-d tree is constructed once and remains as is. Furthermore, removing a node may leave the tree unbalanced, in which case the fast lookup no longer holds. Although it is possible to re-balance the tree, see e.g. [2], we won't pursue this path here. So let's concentrate on the rest of the methods, perhaps the most important of which is build.
In our implementation we distinguish between similarity and comparison. A similarity metric, or policy, is used to decide whether two points are similar, i.e. close enough in the given metric. A comparison policy is used to compare the coordinates of points. Thus, we use the similarity policy to search in the tree, and the comparison policy whenever strict comparison of point coordinates is needed.
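To make the distinction concrete, here is a minimal sketch of what the two kinds of policy could look like for two-dimensional points. The comparison policy follows the three-argument call comp_policy(v1, v2, idx) used when partitioning the data. The interface of the similarity policy is not shown in this post, so the struct EuclideanSimilarity, its tolerance member and its distance function are purely hypothetical, as are the names Point and CoordinateLess.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Assumption: uint_t is an alias for an unsigned integer type.
using uint_t = std::size_t;

// Hypothetical point type used only for this illustration.
using Point = std::array<double, 2>;

// Comparison policy: strict ordering of two points
// along a single coordinate.
struct CoordinateLess
{
    bool operator()(const Point& p1, const Point& p2, uint_t idx) const
    {
        assert(idx < p1.size());
        return p1[idx] < p2[idx];
    }
};

// Hypothetical similarity policy: two points are similar when
// their Euclidean distance does not exceed a tolerance.
struct EuclideanSimilarity
{
    typedef double value_type;

    double tolerance;

    value_type distance(const Point& p1, const Point& p2) const
    {
        value_type sum = 0.0;
        for(std::size_t i = 0; i < p1.size(); ++i){
            const value_type diff = p1[i] - p2[i];
            sum += diff * diff;
        }
        return std::sqrt(sum);
    }

    bool operator()(const Point& p1, const Point& p2) const
    {
        return distance(p1, p2) <= tolerance;
    }
};
```

The comparison policy only ever looks at one coordinate, which is what the tree needs when it decides whether to descend left or right. The similarity policy looks at the whole point, which is what a search needs when it decides whether it has found a match.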
The build function accepts a range of iterators pointing to the data, a similarity policy and a comparison policy. It delegates all its work to the create_ function, a private member function that implements the nuts and bolts of building the tree. The create_ function definition is shown below:
template<typename NodeType>
template<typename Iterator, typename SimilarityPolicy,
         typename ComparisonPolicy>
void
KDTree<NodeType>::create_(Iterator begin, Iterator end,
                          uint_t level,
                          const SimilarityPolicy& sim_policy,
                          const ComparisonPolicy& comp_policy){

    auto n_points = std::distance(begin, end);

    // nothing to do if no points
    // are given
    if(n_points == 0){
        return;
    }

    if(n_points == 1){
        auto data = *begin;
        // create the root
        root_ = std::make_shared<NodeType>(level, data, nullptr, nullptr);
        ++n_nodes_;
        return;
    }

    // otherwise partition the range
    auto [median, left, right] = detail::partiion_on_median(begin, end, level, k_, comp_policy);

    // create root
    root_ = std::make_shared<NodeType>(level, median, nullptr, nullptr);
    ++n_nodes_;

    // create the left subtree
    auto left_tree = do_create_(left.first, left.second, level + 1, sim_policy, comp_policy);
    // create the right subtree
    auto right_tree = do_create_(right.first, right.second, level + 1, sim_policy, comp_policy);
    root_->left = left_tree;
    root_->right = right_tree;
}
The implementation above is fairly straightforward. However, let's go over a few details. The detail::partiion_on_median function accepts a range of points, the current tree level, the dimension k and the comp_policy, and returns the median point at this level together with the ranges of data that lie to the left and to the right of the calculated median.
template<typename Iterator, typename ComparisonPolicy>
std::tuple<typename std::iterator_traits<Iterator>::value_type,
           std::pair<Iterator, Iterator>,
           std::pair<Iterator, Iterator>>
partiion_on_median(Iterator begin, Iterator end,
                   uint_t level, uint_t k,
                   const ComparisonPolicy& comp_policy){
    ...
    // the median index. For an even number of points this
    // picks the upper of the two middle elements
    // (both branches of the original ternary reduce to this)
    auto median_idx = n_points / 2;

    // how to compare the data at the given
    // level. We use the level % k operation to decide
    // which coordinate to use
    auto compare = [&](const value_type& v1, const value_type& v2){
        auto idx = level % k;
        return comp_policy(v1, v2, idx); //v1[idx] < v2[idx];
    };

    // rearrange the elements. Do partial sorting
    std::nth_element(begin, begin + median_idx, end, compare);

    // get the data corresponding to the median
    auto median = *(begin + median_idx);

    // the ranges for the left and right sub-trees;
    // the median itself belongs to neither
    auto left = std::make_pair(begin, begin + median_idx);
    auto right = std::make_pair(begin + median_idx + 1, end);
    return std::make_tuple(median, left, right);
}
The implementation above uses std::nth_element to partially sort the elements. According to the documentation:
nth_element is a partial sorting algorithm that rearranges elements in [first, last) such that:
- The element pointed at by nth is changed to whatever element would occur in that position if [first, last) were sorted.
- All of the elements before this new nth element are less than or equal to the elements after the new nth element.
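The two guarantees quoted above are exactly what a median split needs. The following sketch illustrates them on a small set of two-dimensional points. The names Point and partition_on_coordinate are made up for this example, and the comparison is hard-coded as a[idx] < b[idx] instead of going through a comparison policy.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical point type used only for this illustration.
using Point = std::array<double, 2>;

// Rearranges the points so that the median along coordinate idx
// ends up at the returned index. Points before that index are
// not greater than the median along idx, points after it are
// not smaller. Neither side is sorted.
std::size_t partition_on_coordinate(std::vector<Point>& points, std::size_t idx)
{
    assert(!points.empty() && idx < 2);

    const auto median_idx = static_cast<std::ptrdiff_t>(points.size() / 2);

    std::nth_element(points.begin(), points.begin() + median_idx, points.end(),
                     [idx](const Point& a, const Point& b){
                         return a[idx] < b[idx];
                     });

    return static_cast<std::size_t>(median_idx);
}
```

For the six points (7,2), (5,4), (9,6), (2,3), (4,7), (8,1) and idx equal to 0, the returned index is 3 and the point at that index is (7,2), since the sorted first coordinates are 2, 4, 5, 7, 8, 9. The order of the points on either side of the median is unspecified, which is fine because each side is partitioned again at the next level.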
We then recursively build the tree using do_create_.
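The body of do_create_ is not shown in this post. The sketch below illustrates the general recursive scheme under simplifying assumptions: a free function instead of a member function, a fixed two-dimensional point type, and a hard-coded coordinate comparison instead of the policies. The names Node, build_subtree and depth are hypothetical.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <memory>
#include <vector>

// Hypothetical point and node types used only for this illustration.
using Point = std::array<double, 2>;

struct Node
{
    std::size_t level;
    Point data;
    std::shared_ptr<Node> left;
    std::shared_ptr<Node> right;
};

// Recursively builds a subtree from [begin, end): the median along
// coordinate level % k becomes the node, and the two halves around
// it become the left and right subtrees.
template<typename Iterator>
std::shared_ptr<Node> build_subtree(Iterator begin, Iterator end,
                                    std::size_t level, std::size_t k)
{
    assert(k > 0);

    // an empty range yields no subtree
    if(begin == end){
        return nullptr;
    }

    const std::size_t idx = level % k;
    auto median = begin + std::distance(begin, end) / 2;

    std::nth_element(begin, median, end,
                     [idx](const Point& a, const Point& b){
                         return a[idx] < b[idx];
                     });

    auto node = std::make_shared<Node>();
    node->level = level;
    node->data = *median;
    node->left = build_subtree(begin, median, level + 1, k);
    node->right = build_subtree(median + 1, end, level + 1, k);
    return node;
}

// Number of levels in the subtree rooted at node.
std::size_t depth(const std::shared_ptr<Node>& node)
{
    return node ? 1 + std::max(depth(node->left), depth(node->right)) : 0;
}
```

Because every call splits its range at the median, the two halves differ in size by at most one point, so the resulting tree is balanced. For the six points (7,2), (5,4), (9,6), (2,3), (4,7), (8,1) with k equal to 2, the root holds (7,2) and the tree has three levels.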
[1] k-d tree.
[2] Marcello La Rocca, Advanced Algorithms and Data Structures, Manning Publications.