# KDTree example

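**Note:** SciPy already provides a kd-tree implementation, `scipy.spatial.KDTree` (see https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.KDTree.html), as well as the faster `scipy.spatial.cKDTree`. It is recommended to use those instead of the code below.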

This is an example of how to construct and search a kd-tree in Python with NumPy. kd-trees are used, for example, to search for neighbouring data points in multidimensional space. Searching the kd-tree for the nearest neighbours of all n points has O(n log n) complexity with respect to sample size.

### Building a kd-tree

In [ ]:
#!python numbers=disable

# Copyleft 2008 Sturla Molden
# University of Oslo

#import psyco
#psyco.full()

import numpy

def kdtree(data, leafsize=10):
    """
    build a kd-tree for O(n log n) nearest neighbour search

    input:
    data:       2D ndarray, shape = (ndim, ndata), preferentially C order.
                NOTE: the columns of `data` are reordered IN PLACE while the
                tree is built; pass a copy if the original column order must
                be kept (knn_search does this).
    leafsize:   max. number of data points to leave in a leaf

    output:
    kd-tree:    list of 6-tuples
                (leaf_idx, leaf_data, left_hrect, right_hrect, left, right);
                tree[0] is the root.
                Internal node: leaf_idx and leaf_data are None,
                left_hrect/right_hrect are the (2, ndim) bounding boxes
                (row 0 = minimum, row 1 = maximum) of the two children and
                left/right are the children's positions in the list.
                Leaf: leaf_idx holds the column indices of its points in the
                original `data`, leaf_data a copy of those columns, and the
                remaining fields are (None, None, 0, 0).
    """
    ndim = data.shape[0]
    ndata = data.shape[1]

    # find bounding hyper-rectangle: row 0 = per-dimension min, row 1 = max
    hrect = numpy.zeros((2, ndim))
    hrect[0, :] = data.min(axis=1)
    hrect[1, :] = data.max(axis=1)

    # create root of kd-tree: split along dimension 0 at the median
    idx = numpy.argsort(data[0, :], kind='mergesort')
    data[:, :] = data[:, idx]
    half = ndata // 2
    splitval = data[0, half]

    left_hrect = hrect.copy()
    right_hrect = hrect.copy()
    left_hrect[1, 0] = splitval
    right_hrect[0, 0] = splitval

    tree = [(None, None, left_hrect, right_hrect, None, None)]

    # work items: (points, original column indices, depth, parent position,
    #              True if this is the parent's left child)
    stack = [(data[:, :half], idx[:half], 1, 0, True),
             (data[:, half:], idx[half:], 1, 0, False)]

    # recursively split data in halves using hyper-rectangles:
    while stack:

        # pop data off stack
        data, didx, depth, parent, leftbranch = stack.pop()
        ndata = data.shape[1]
        # the node created for this work item will be appended here
        nodeptr = len(tree)

        # update parent node so it points at this node
        _didx, _data, _left_hrect, _right_hrect, left, right = tree[parent]
        if leftbranch:
            tree[parent] = (_didx, _data, _left_hrect, _right_hrect, nodeptr, right)
        else:
            tree[parent] = (_didx, _data, _left_hrect, _right_hrect, left, nodeptr)

        if ndata <= leafsize:
            # leaf node: keep copies of the indices and points
            tree.append((didx.copy(), data.copy(), None, None, 0, 0))

        else:
            # not a leaf: split at the median of dimension depth % ndim
            splitdim = depth % ndim
            idx = numpy.argsort(data[splitdim, :], kind='mergesort')
            data[:, :] = data[:, idx]
            didx = didx[idx]
            half = ndata // 2
            stack.append((data[:, :half], didx[:half], depth + 1, nodeptr, True))
            stack.append((data[:, half:], didx[half:], depth + 1, nodeptr, False))
            splitval = data[splitdim, half]
            # this node's own box is the parent's box on the side we came from
            parent_hrect = _left_hrect if leftbranch else _right_hrect
            left_hrect = parent_hrect.copy()
            right_hrect = parent_hrect.copy()
            left_hrect[1, splitdim] = splitval
            right_hrect[0, splitdim] = splitval
            # append node to tree (children are linked when they are popped)
            tree.append((None, None, left_hrect, right_hrect, None, None))

    return tree


### Searching a kd-tree

In [ ]:
#!python numbers=disable

def intersect(hrect, r2, centroid):
    """
    Test whether the hyper-rectangle hrect overlaps the hypersphere
    around centroid whose squared radius is r2.

    hrect[0, :] holds the lower corner and hrect[1, :] the upper corner.
    The test clamps centroid into the box and checks that the squared
    distance to that clamped point is strictly below r2.  centroid itself
    is left unchanged.
    """
    lower = hrect[0, :]
    upper = hrect[1, :]
    closest = centroid.copy()
    # clamp coordinates that fall below the box ...
    below = closest < lower
    closest[below] = lower[below]
    # ... and those that fall above it
    above = closest > upper
    closest[above] = upper[above]
    offset = closest - centroid
    return (offset ** 2).sum() < r2

def quadratic_knn_search(data, lidx, ldata, K):
    """ find K nearest neighbours of data among ldata

    Exhaustive (linear-scan) search, used on the kd-tree leaves.

    input:
    data:   the query point, either a 1D array of length ndim or a 2D array
            of shape (ndim, m). In the 2D case the first min(m, ndata)
            columns are used, and callers are expected to pass the query
            point in every column. A single (ndim, 1) column broadcasts.
    lidx:   1D array of indices labelling the columns of ldata
    ldata:  2D array of candidate points, shape (ndim, ndata)
    K:      number of neighbours wanted; capped at ndata

    output:
    list of (squared distance, index) tuples, nearest first
    """
    data = numpy.asarray(data)
    ndata = ldata.shape[1]
    param = ldata.shape[0]
    K = min(K, ndata)
    if data.ndim == 1:
        # a single point: make it a column so it broadcasts against ldata
        data = data.reshape((param, 1))
    sqd = ((ldata - data[:, :ndata]) ** 2).sum(axis=0)
    idx = numpy.argsort(sqd, kind='mergesort')[:K]
    # list() so callers can concatenate results (zip is lazy on Python 3)
    return list(zip(sqd[idx], lidx[idx]))

def search_kdtree(tree, datapoint, K):
    """ find the k nearest neighbours of datapoint in a kdtree

    input:
    tree:       kd-tree as returned by kdtree()
    datapoint:  the query point, either a 2D array whose first column is
                the point (e.g. shape (ndim, 1)) or a 1D array of length ndim
    K:          number of neighbours wanted

    output:
    list of K (squared distance, index) tuples sorted by distance; slots
    that could not be filled hold (inf, None)
    """
    if K <= 0:
        return []
    if datapoint.ndim == 1:
        datapoint = datapoint.reshape((-1, 1))
    _datapt = datapoint[:, 0]
    stack = [tree[0]]  # start at the root node
    knn = [(numpy.inf, None)] * K
    while stack:

        leaf_idx, leaf_data, left_hrect, \
            right_hrect, left, right = stack.pop()

        # leaf: exhaustive search; merge only if it beats the current K-th best
        if leaf_idx is not None:
            _knn = quadratic_knn_search(datapoint, leaf_idx, leaf_data, K)
            if _knn and _knn[0][0] < knn[-1][0]:
                knn = sorted(knn + _knn)[:K]

        # not a leaf: descend only into boxes that may hold a point closer
        # than the current K-th best squared distance
        else:

            # check left branch
            if intersect(left_hrect, knn[-1][0], _datapt):
                stack.append(tree[left])

            # check right branch
            if intersect(right_hrect, knn[-1][0], _datapt):
                stack.append(tree[right])
    return knn

def knn_search(data, K, leafsize=2048):
    """ find the K nearest neighbours for data points in data,
    using an O(n log n) kd-tree

    input:
    data:      2D ndarray, shape = (ndim, ndata); not modified (the tree
               is built on a copy)
    K:         number of neighbours per point (the point itself excluded)
    leafsize:  max. number of points in a kd-tree leaf

    output:
    list with one entry per column of data, each a list of K
    (squared distance, index) tuples, nearest first
    """
    param = data.shape[0]
    ndata = data.shape[1]

    # build kdtree
    tree = kdtree(data.copy(), leafsize=leafsize)

    # search kdtree
    knn = []
    for i in range(ndata):
        # query as a single (ndim, 1) column; it broadcasts against each leaf
        _data = data[:, i].reshape((param, 1))
        # ask for K+1 and drop the first hit, which is normally the point
        # itself (with duplicate points it may be one of the duplicates)
        _knn = search_kdtree(tree, _data, K + 1)
        knn.append(_knn[1:])

    return knn

def radius_search(tree, datapoint, radius):
    """ find all points within radius of datapoint

    input:
    tree:       kd-tree as returned by kdtree()
    datapoint:  the query point, an array with ndim elements
    radius:     search radius (Euclidean distance, boundary inclusive)

    output:
    list of (distance, index) tuples, in no particular order
    """
    datapoint = numpy.asarray(datapoint).ravel()
    r2 = radius ** 2
    stack = [tree[0]]  # start at the root node
    inside = []
    while stack:

        leaf_idx, leaf_data, left_hrect, \
            right_hrect, left, right = stack.pop()

        # leaf: test every point it holds
        if leaf_idx is not None:
            param = leaf_data.shape[0]
            distance = numpy.sqrt(((leaf_data - datapoint.reshape((param, 1))) ** 2).sum(axis=0))
            near = numpy.where(distance <= radius)[0]
            if len(near):
                inside += list(zip(distance[near], leaf_idx[near]))

        # not a leaf: descend only into boxes that come within radius.
        # intersect() uses a strict '<' and would prune a box whose closest
        # point lies exactly on the sphere, although such a point counts as
        # inside; hence the inclusive test here.
        else:

            if _hrect_sqdist(left_hrect, datapoint) <= r2:
                stack.append(tree[left])

            if _hrect_sqdist(right_hrect, datapoint) <= r2:
                stack.append(tree[right])

    return inside


def _hrect_sqdist(hrect, point):
    """Squared distance from point to the hyper-rectangle hrect (0 inside)."""
    closest = numpy.clip(point, hrect[0, :], hrect[1, :])
    return ((closest - point) ** 2).sum()


### Quadratic search for small data sets

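In contrast to the kd-tree, a straightforward exhaustive search has quadratic complexity with respect to sample size. It can be faster than using a kd-tree when the sample size is very small; on my computer that is approximately 500 samples or less. Note that the function below has the same name as the kd-tree based `knn_search` above, so it replaces that version when both are defined.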

In [ ]:
#!python numbers=disable

def knn_search(data, K):
    """ find the K nearest neighbours for data points in data,
    using O(n**2) search

    input:
    data:   2D ndarray, shape = (ndim, ndata)
    K:      number of neighbours per point (the point itself excluded)

    output:
    list with one entry per column of data, each a list of at most K
    (squared distance, index) tuples, nearest first
    """
    ndata = data.shape[1]
    idx = numpy.arange(ndata)
    knn = []
    for i in range(ndata):
        # pass the query as an (ndim, 1) column so it broadcasts against data
        _knn = quadratic_knn_search(data[:, i:i + 1], idx, data, K + 1)  # see above
        # drop the first hit, normally the point itself (distance 0)
        knn.append(_knn[1:])
    return knn


### Parallel search for large data sets

While creating a kd-tree is very fast, searching it can be time consuming. Due to Python's dreaded "Global Interpreter Lock" (GIL), threads cannot be used to conduct multiple searches in parallel. That is, Python threads can be used for asynchrony but not for parallel execution of Python code. However, we can use multiple processes (multiple interpreters). The pyprocessing package makes this easy. It has an API similar to Python's threading and Queue standard modules, but works with processes instead of threads. Beginning with Python 2.6, pyprocessing is included in Python's standard library as the "multiprocessing" module. Using multiple processes has a small overhead, including process creation, process startup, IPC, and process termination. However, because processes run in separate address spaces, no memory contention is incurred. In the following example, the overhead of using multiple processes is very small compared to the computation, giving a speed-up close to the number of CPUs on the computer.

In [ ]:
#!python numbers=disable

try:
    import multiprocessing as processing
except ImportError:
    # Python < 2.6: fall back to the stand-alone pyprocessing package
    import processing

import ctypes, os

def __num_processors():
    """
    Return the number of logical processors on this computer (at least 1).

    Uses multiprocessing.cpu_count() when available. Otherwise it falls back
    to the NUMBER_OF_PROCESSORS environment variable on Windows, or to
    os.sysconf('SC_NPROCESSORS_ONLN') elsewhere, and finally to 1.
    """
    try:
        import multiprocessing
        return max(1, multiprocessing.cpu_count())
    except (ImportError, NotImplementedError):
        pass
    try:
        if os.name == 'nt':  # Windows
            return max(1, int(os.environ['NUMBER_OF_PROCESSORS']))
        # POSIX systems (Linux, *BSD, Apple)
        return max(1, int(os.sysconf('SC_NPROCESSORS_ONLN')))
    except (KeyError, ValueError, OSError, AttributeError):
        return 1

def __search_kdtree(tree, data, K, leafsize):
    """
    Worker-side search: find the K nearest neighbours of every column of
    data in tree.

    input:
    tree:      kd-tree as returned by kdtree()
    data:      2D ndarray, shape = (ndim, m), a chunk of query points
    K:         number of neighbours per point (the point itself excluded)
    leafsize:  kept for signature compatibility. The query is passed as an
               (ndim, 1) column that broadcasts, so it is no longer
               repeated leafsize times.

    output:
    list of m lists of (squared distance, index) tuples, nearest first
    """
    param = data.shape[0]
    ndata = data.shape[1]
    knn = []
    for i in range(ndata):
        _data = data[:, i].reshape((param, 1))
        _knn = search_kdtree(tree, _data, K + 1)
        # drop the first hit, normally the query point itself
        knn.append(_knn[1:])
    return knn

def __remote_process(rank, qin, qout, tree, K, leafsize):
    # Worker loop: runs forever and is stopped by the parent via terminate().
    # `rank` identifies the worker but is not used here.
    while True:
        chunk_no, chunk = qin.get()  # blocks until a chunk arrives
        result = __search_kdtree(tree, chunk, K, leafsize)
        qout.put((chunk_no, result))  # tag result with its chunk number

def knn_search_parallel(data, K, leafsize=2048):

    """ find the K nearest neighbours for data points in data,
    using an O(n log n) kd-tree, exploiting all logical
    processors on the computer

    input:
    data:      2D ndarray, shape = (ndim, ndata); not modified (the tree
               is built on a copy)
    K:         number of neighbours per point (the point itself excluded)
    leafsize:  max. number of points in a kd-tree leaf

    output:
    list with one neighbour list per column of data, in column order
    """

    ndata = data.shape[1]
    nproc = __num_processors()
    # build kdtree (kdtree() reorders its input in place, hence the copy)
    tree = kdtree(data.copy(), leafsize=leafsize)
    # about four chunks per process for load balancing, >= 100 points each
    chunk_size = max(100, ndata // (4 * nproc))
    nchunks = (ndata + chunk_size - 1) // chunk_size
    # set up a pool of processes
    qin = processing.Queue()
    qout = processing.Queue()
    pool = [processing.Process(target=__remote_process,
                               args=(rank, qin, qout, tree, K, leafsize))
            for rank in range(nproc)]
    for p in pool:
        p.start()
    try:
        # put data chunks in input queue, tagged with their chunk number
        for nc in range(nchunks):
            cur = nc * chunk_size
            qin.put((nc, data[:, cur:cur + chunk_size]))
        results = [qout.get() for _ in range(nchunks)]
    finally:
        # workers loop forever: always stop them, even if an error occurred
        for p in pool:
            p.terminate()
        for p in pool:
            p.join()
    # chunks finish in arbitrary order; restore it by chunk number
    results.sort(key=lambda item: item[0])
    knn = []
    for _, chunk_knn in results:
        knn.extend(chunk_knn)
    return knn


### Running the code

The following shows how to run the example code (including how input data should be formatted):

In [ ]:
#!python numbers=disable

# wall-clock timer; time.clock was removed in Python 3.8
from timeit import default_timer as clock

def test(K=11, ndata=10000, ndim=12):
    """Run knn_search for K neighbours on ndata uniform random points in
    [0, 10)^ndim, stored as a (ndim, ndata) array, and return the result."""
    data = 10 * numpy.random.rand(ndim, ndata)
    return knn_search(data, K)

if __name__ == '__main__':
    t0 = clock()
    test()
    t1 = clock()
    # parenthesise the difference: '%' binds tighter than '-'
    print("Elapsed time %.2f seconds" % (t1 - t0))

#import profile          # using Python's profiler is not useful if you are
#profile.run('test()')   # running the parallel search.

In [ ]: