Vantage Point Tree, generalization of the BK-tree

The data structure that I want to talk about is the Vantage Point Tree, a generalization of the BK-tree (which is eloquently reviewed in Damn Cool Algorithms).

Each node of the tree contains one of the input points (the vantage point) and a radius. The left child holds all points that are closer to the node’s point than the radius; the right child holds all of the points that are farther away. Apart from a distance function between items, the tree needs no other knowledge about the items in it.
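To make the node layout concrete, here is a minimal sketch assuming the items are strings compared by Levenshtein edit distance (a true metric, and the one BK-trees are commonly used with). The names `editDistance`, `VpNodeSketch` and `inInnerChild` are hypothetical; the point is only that deciding which child a point belongs to needs nothing but the distance function.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Levenshtein edit distance: the minimum number of single-character
// insertions, deletions and substitutions that turn a into b.
double editDistance(const std::string& a, const std::string& b) {
    std::vector<std::size_t> prev(b.size() + 1), cur(b.size() + 1);
    for (std::size_t j = 0; j <= b.size(); ++j) prev[j] = j;
    for (std::size_t i = 1; i <= a.size(); ++i) {
        cur[0] = i;
        for (std::size_t j = 1; j <= b.size(); ++j) {
            std::size_t subst = prev[j - 1] + (a[i - 1] == b[j - 1] ? 0 : 1);
            cur[j] = std::min({prev[j] + 1, cur[j - 1] + 1, subst});
        }
        std::swap(prev, cur);  // prev now holds the row just computed
    }
    return static_cast<double>(prev[b.size()]);
}

// Hypothetical node: one vantage point and a radius. Nothing else about
// the items is stored or needed.
struct VpNodeSketch {
    std::string point;
    double radius;
};

// True if q belongs under the left (inner) child, i.e. it is closer to
// the vantage point than the radius.
bool inInnerChild(const VpNodeSketch& node, const std::string& q) {
    return editDistance(node.point, q) < node.radius;
}
```

For example, with a node holding "book" and radius 2, "books" (distance 1) goes to the left child and "cake" (distance 4) goes to the right.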


How searching a VP-Tree works

Let us examine one of these nodes in detail and see what happens to it during a recursive search for the nearest neighbours of a target.

Suppose we want to find the two nearest neighbours of the target, marked with the red X. Since we have no results yet, the node’s center p is the closest candidate so far, and we add it to the list of results (it might be bumped out later). At the same time, we update our variable tau, which tracks the distance to the farthest point in our results.

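Next, we have to decide whether to search the left or the right child first. We may end up having to search both, so the order matters: searching the more promising child first tends to shrink tau sooner, which lets us skip more of the tree later.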

Since the target lies inside the node’s outer shell (its distance from p is less than the radius), we search the left child first, which contains all of the points closer than the radius. There we find the blue point. We still need two results, so we keep it, and since it is a little farther away than tau, we update tau to its distance.

Do we need to continue the search? We know that we have considered all the points within distance radius of p. However, the target is closer to the outer shell than it is to the farthest point we have found, so there could be closer points just outside the shell. We do need to descend into the right child, where we find the green point.
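The reasoning here is the triangle inequality at work. Writing $q$ for the target, $p$ for the node’s point, $r$ for its radius and $\tau$ for the distance to the farthest result kept so far, every point $x$ under each child is bounded as follows:

```latex
\begin{aligned}
d(p,x) \ge r \;&\implies\; d(q,x) \ge d(p,x) - d(p,q) \ge r - d(p,q) && \text{(right child)}\\
d(p,x) < r \;&\implies\; d(q,x) \ge d(p,q) - d(p,x) > d(p,q) - r && \text{(left child)}
\end{aligned}
```

So the right child can only hold a point closer than $\tau$ when $r - d(p,q) < \tau$, which is exactly the situation described above (the shell is nearer than the farthest result), and the left child only when $d(p,q) - r < \tau$.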

If, however, we had already reached our goal of collecting the k nearest points, and the target were farther from the outer shell than the farthest point we have collected, then we could have stopped looking without visiting the right child at all. Pruning whole subtrees this way is where the savings over a linear search come from.
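The two checks described above can be written as a pair of predicates. This is a sketch under the assumption that `dist` is the target’s distance to the node’s point, `radius` is the node’s radius and `tau` is the current farthest-result distance (unbounded until k results have been collected); the function names and `kNoBound` are hypothetical. A third helper picks the order: the side the target falls on goes first, since it is the most likely to shrink `tau` before the other side is tested.

```cpp
#include <cassert>
#include <limits>

// Value of tau before k results have been collected: nothing can be pruned yet.
const double kNoBound = std::numeric_limits<double>::infinity();

// Could the inner (left) child hold a point closer than tau?
// Its points lie within radius of the vantage point, so by the triangle
// inequality they are at least dist - radius away from the target.
bool mayVisitInner(double dist, double tau, double radius) {
    return dist - tau <= radius;
}

// Could the outer (right) child hold a point closer than tau?
// Its points are at least radius - dist away from the target.
bool mayVisitOuter(double dist, double tau, double radius) {
    return dist + tau >= radius;
}

// True if the inner child should be searched first (target inside the shell).
bool innerFirst(double dist, double radius) {
    return dist < radius;
}
```

The comparisons are non-strict, so a borderline case errs on the side of visiting a child; the listing below uses the same comparisons in its search().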

Implementation

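Here is an implementation of the VP tree in C++. The recursive search() function decides whether to follow the left child, the right child, or both. To maintain the list of results efficiently, we use a priority queue ordered by distance, so the farthest result is always on top and can be replaced cheaply when a closer point turns up.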
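Before the full listing, here is a minimal sketch of just the results bookkeeping, assuming results are identified by an integer index. `std::priority_queue` is a max-heap by default, so with items ordered by distance its `top()` is the farthest of the current results, which is exactly tau once k results are held. `Candidate` and `KBest` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <queue>
#include <vector>

struct Candidate {
    int index;
    double dist;
    // order by distance so the priority_queue keeps the farthest on top
    bool operator<(const Candidate& o) const { return dist < o.dist; }
};

// Keeps the k closest candidates offered so far.
class KBest {
public:
    explicit KBest(std::size_t k) : k_(k) {}

    // Distance a new candidate must beat: infinity until k results exist,
    // then the distance of the farthest kept result.
    double tau() const {
        return heap_.size() < k_ ? std::numeric_limits<double>::infinity()
                                 : heap_.top().dist;
    }

    void offer(int index, double dist) {
        if (k_ == 0 || dist >= tau()) return;  // no better than what we have
        if (heap_.size() == k_) heap_.pop();   // bump out the farthest result
        heap_.push(Candidate{index, dist});
    }

    // Returns the results nearest first; empties the heap.
    std::vector<Candidate> drain() {
        std::vector<Candidate> out(heap_.size());
        for (std::size_t i = out.size(); i > 0; --i) {
            out[i - 1] = heap_.top();  // farthest remaining goes last
            heap_.pop();
        }
        return out;
    }

private:
    std::size_t k_;
    std::priority_queue<Candidate> heap_;
};
```

This mirrors what the listing does with `_tau` and `HeapItem`, packaged as a small class so that the "bump out the farthest" step is explicit.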

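I tried it out on a database of cities, and the VP tree search was much faster than a linear search through all the points: in the run shown below, the tree search took 36 against 143212 for the linear search (in the program’s timing units). You can download the C++ program that uses the VP tree for this purpose here: amrita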
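The source does not show its linear search, so the following is only an assumption about how such a baseline might look. It measures the distance to every item and keeps the k smallest with `std::partial_sort`; because it takes any distance function, it can also cross-check a VP tree’s answers, as the two result lists in the output below do. `linearKnn` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Brute-force k nearest neighbours: O(n log k) per query, no preprocessing.
// Returns (distance, index) pairs sorted nearest first.
template <typename T, typename Dist>
std::vector<std::pair<double, std::size_t>>
linearKnn(const std::vector<T>& items, const T& target, std::size_t k, Dist distance) {
    std::vector<std::pair<double, std::size_t>> all;
    all.reserve(items.size());
    for (std::size_t i = 0; i < items.size(); ++i)
        all.emplace_back(distance(items[i], target), i);
    k = std::min(k, all.size());
    // only the first k positions need to end up sorted
    std::partial_sort(all.begin(), all.begin() + k, all.end());
    all.resize(k);
    return all;
}
```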

It is worth repeating that you must use a distance metric that satisfies the triangle inequality. I spent a lot of time wondering why my VP tree was not working. It turns out that I had not bothered to take the square root in the distance calculation. This step is needed to satisfy the requirements of a metric space: if the straight-line distances satisfy $a \le b + c$, it does not follow that $a^2 \le b^2 + c^2$.
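A three-point example makes the failure concrete. For the collinear points 0, 1 and 2, Euclidean distance gives $d(0,2) = 2 \le d(0,1) + d(1,2) = 2$, but squared distance gives $4 > 1 + 1$. The sketch below (all names hypothetical) checks the triangle inequality over every triple of a small point set, a cheap sanity test to run on a candidate distance function before building a VP tree with it.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct Pt { double x, y; };

double euclidean(const Pt& a, const Pt& b) {
    return std::hypot(a.x - b.x, a.y - b.y);
}

// Not a metric: skipping the square root breaks the triangle inequality.
double squaredEuclidean(const Pt& a, const Pt& b) {
    double dx = a.x - b.x, dy = a.y - b.y;
    return dx * dx + dy * dy;
}

// True if d(a,c) <= d(a,b) + d(b,c) holds for every triple of points,
// with a small tolerance for floating-point rounding.
template <typename Dist>
bool satisfiesTriangleInequality(const std::vector<Pt>& pts, Dist d) {
    for (std::size_t i = 0; i < pts.size(); ++i)
        for (std::size_t j = 0; j < pts.size(); ++j)
            for (std::size_t k = 0; k < pts.size(); ++k)
                if (d(pts[i], pts[k]) > d(pts[i], pts[j]) + d(pts[j], pts[k]) + 1e-12)
                    return false;
    return true;
}
```

Passing a random sample of real data points instead of a hand-picked line is a reasonable way to catch the same mistake in practice, although a passing sample does not prove the function is a metric.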

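Here is the output of the program when you search for the cities nearest to a given latitude and longitude. Each city is followed by its distance from the query point; the VP tree search and the linear search return the same eight cities.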

Create took 15484122
Search took 36
ca,waterloo,Waterloo,08,43.4666667,-80.5333333
 0.0141501
ca,kitchener,Kitchener,08,43.45,-80.5
 0.025264
ca,bridgeport,Bridgeport,08,43.4833333,-80.4833333
 0.0396333
ca,elmira,Elmira,08,43.6,-80.55
 0.137071
ca,baden,Baden,08,43.4,-80.6666667
 0.161756
ca,floradale,Floradale,08,43.6166667,-80.5833333
 0.163351
ca,preston,Preston,08,43.4,-80.35
 0.181762
ca,ayr,Ayr,08,43.2833333,-80.45
 0.195739
---
Linear search took 143212
ca,waterloo,Waterloo,08,43.4666667,-80.5333333
 0.0141501
ca,kitchener,Kitchener,08,43.45,-80.5
 0.025264
ca,bridgeport,Bridgeport,08,43.4833333,-80.4833333
 0.0396333
ca,elmira,Elmira,08,43.6,-80.55
 0.137071
ca,baden,Baden,08,43.4,-80.6666667
 0.161756
ca,floradale,Floradale,08,43.6166667,-80.5833333
 0.163351
ca,preston,Preston,08,43.4,-80.35
 0.181762
ca,ayr,Ayr,08,43.2833333,-80.45
 0.195739

Construction

I’m too lazy to implement a delete or insert function. Instead, the whole tree is built at once by repeatedly partitioning the data, which is simple and keeps the tree balanced. We build the tree from the top down from an array of items. For each node, we first choose a point at random as the vantage point, and then partition the remaining points into two sets around the median of their distances from it: the left child gets the points closer than the median, and the right child gets the points that are farther away. That median distance becomes the node’s radius. Then we recursively repeat this until we have run out of points.

// A VP-Tree implementation
// Based on "Data Structures and Algorithms for Nearest Neighbor Search" by Peter N. Yianilos

#include <stdlib.h>
#include <stdio.h>
#include <algorithm>
#include <vector>
#include <queue>
#include <limits>

// T is the item type; distance must be a true metric on T
// (in particular it must satisfy the triangle inequality).
template<typename T, double (*distance)( const T&, const T& )>
class VpTree
{
public:
    VpTree() : _root(0) {}

    ~VpTree() {
        delete _root;
    }

    // Build the tree from scratch; any previous tree is discarded.
    void create( const std::vector<T>& items ) {
        delete _root;
        _items = items;
        _root = buildFromPoints(0, items.size());
    }

    // Find the k nearest neighbours of target. Results come back nearest
    // first, with the matching distances in *distances.
    void search( const T& target, int k, std::vector<T>* results,
        std::vector<double>* distances)
    {
        std::priority_queue<HeapItem> heap;

        // tau: distance to the farthest result kept so far (no bound yet)
        _tau = std::numeric_limits<double>::max();
        search( _root, target, k, heap );

        results->clear(); distances->clear();

        // the max-heap yields the farthest result first...
        while( !heap.empty() ) {
            results->push_back( _items[heap.top().index] );
            distances->push_back( heap.top().dist );
            heap.pop();
        }

        // ...so reverse to get nearest first
        std::reverse( results->begin(), results->end() );
        std::reverse( distances->begin(), distances->end() );
    }

private:
    std::vector<T> _items;
    double _tau;

    struct Node
    {
        int index;         // position of the vantage point in _items
        double threshold;  // median distance from the vantage point (the radius)
        Node* left;        // points closer than threshold
        Node* right;       // points at least threshold away

        Node() :
            index(0), threshold(0.), left(0), right(0) {}

        ~Node() {
            delete left;
            delete right;
        }
    }* _root;

    struct HeapItem {
        HeapItem( int index, double dist) :
            index(index), dist(dist) {}
        int index;
        double dist;
        // larger distance = higher priority, so top() is the farthest result
        bool operator<( const HeapItem& o ) const {
            return dist < o.dist;
        }
    };

    // Orders items by their distance from a fixed reference item.
    struct DistanceComparator
    {
        const T& item;
        DistanceComparator( const T& item ) : item(item) {}
        bool operator()(const T& a, const T& b) {
            return distance( item, a ) < distance( item, b );
        }
    };

    // Builds a subtree over _items[lower, upper).
    Node* buildFromPoints( int lower, int upper )
    {
        if ( upper == lower ) {
            return NULL;
        }

        Node* node = new Node();
        node->index = lower;

        if ( upper - lower > 1 ) {

            // choose an arbitrary point and move it to the start
            int i = (int)((double)rand() / RAND_MAX * (upper - lower - 1) ) + lower;
            std::swap( _items[lower], _items[i] );

            int median = ( upper + lower ) / 2;

            // partition the remaining points around the median distance
            std::nth_element(
                _items.begin() + lower + 1,
                _items.begin() + median,
                _items.begin() + upper,
                DistanceComparator( _items[lower] ));

            // the median distance becomes this node's radius
            node->threshold = distance( _items[lower], _items[median] );

            node->left = buildFromPoints( lower + 1, median );
            node->right = buildFromPoints( median, upper );
        }

        return node;
    }

    void search( Node* node, const T& target, int k,
                 std::priority_queue<HeapItem>& heap )
    {
        if ( node == NULL ) return;

        double dist = distance( _items[node->index], target );
        //printf("dist=%g tau=%g\n", dist, _tau );

        // keep this vantage point if it beats the current k-th result
        if ( dist < _tau ) {
            if ( heap.size() == (size_t)k ) heap.pop();
            heap.push( HeapItem(node->index, dist) );
            if ( heap.size() == (size_t)k ) _tau = heap.top().dist;
        }

        if ( node->left == NULL && node->right == NULL ) {
            return;
        }

        // search the side the target falls on first, then the other side
        // only if the ball of radius tau around the target crosses the shell
        if ( dist < node->threshold ) {
            if ( dist - _tau <= node->threshold ) {
                search( node->left, target, k, heap );
            }
            if ( dist + _tau >= node->threshold ) {
                search( node->right, target, k, heap );
            }
        } else {
            if ( dist + _tau >= node->threshold ) {
                search( node->right, target, k, heap );
            }
            if ( dist - _tau <= node->threshold ) {
                search( node->left, target, k, heap );
            }
        }
    }
};

P.S.: I tried to explain the Vantage Point Tree concept in a lucid way! I hope it benefits readers :)