[PATCH] New function: find_k_nearest()
Willi Richert
w.richert at gmx.net
Tue Mar 31 14:56:55 UTC 2009
commit 0f4bc35b3801abef9642f5fea573b33c411dc3a5
Author: Willi Richert <w.richert at gmx.net>
Date: Tue Mar 31 16:54:31 2009 +0200
New function: find_k_nearest()
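Additionally:
- py-kdtree.hpp and py-kdtree.i are no longer checked in; they are
  generated from py-kdtree.hpp.tmpl and py-kdtree.i.tmpl by
  gen-swig-hpp.py (see Makefile)
- better checks for memory leaks in SWIG
- cleaned up Makefile (Python include flags via python-config)
- better Python tests
- fixed C++ test file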
Signed-off-by: Willi Richert <w.richert at gmx.net>
---
python-bindings/Makefile | 27 ++-
python-bindings/py-kdtree.hpp | 210 --------------
python-bindings/py-kdtree.hpp.tmpl | 265 +++++++++++++++++
python-bindings/py-kdtree.i | 545 ------------------------------------
python-bindings/py-kdtree.i.tmpl | 33 +++
python-bindings/py-kdtree_test.cpp | 67 +++--
python-bindings/py-kdtree_test.py | 229 +++++++++++++--
7 files changed, 553 insertions(+), 823 deletions(-)
delete mode 100644 python-bindings/py-kdtree.hpp
create mode 100644 python-bindings/py-kdtree.hpp.tmpl
delete mode 100644 python-bindings/py-kdtree.i
create mode 100644 python-bindings/py-kdtree.i.tmpl
diff --git a/python-bindings/Makefile b/python-bindings/Makefile
index 6aeb33c..ea96e5c 100644
--- a/python-bindings/Makefile
+++ b/python-bindings/Makefile
@@ -2,41 +2,50 @@
INCLUDE_DIR=..
CXX=g++
+PY_INCLUDE_DIR=$(shell python-config --includes)
+
# CPPFLAGS is used by the default rules. Using "override" and "+="
# allows the user to prepend things to CPPFLAGS on the command line.
override CPPFLAGS += -I$(INCLUDE_DIR) -pedantic -Wno-long-long -Wall -ansi -pedantic
# These options are set by the configure script.
-override CPPFLAGS += -DHAVE_CONFIG_H
+override CPPFLAGS += -DHAVE_CONFIG_H -fPIC
ifeq ($(MINUSG),1)
-override CPPFLAGS += -g
+override CPPFLAGS += -g
else
-override CPPFLAGS += -O3
+override CPPFLAGS += -O3
endif
ifeq ($(MINUSPG),1)
override CPPFLAGS += -pg
endif
+current: py-kdtree
+ python py-kdtree_test.py
# swig bindings
py-kdtree: _kdtree.so
cp _kdtree.so kdtree.py ../../
-py-kdtree_test: py-kdtree.hpp py-kdtree_test.cpp
- $(CXX) $(CPPFLAGS) -I/usr/include/python2.5 -c -o py-kdtree_test.o py-kdtree_test.cpp
+py-kdtree_test: py-kdtree.hpp py-kdtree_test.cpp
+ $(CXX) $(CPPFLAGS) $(PY_INCLUDE_DIR) -c -o py-kdtree_test.o py-kdtree_test.cpp
$(CXX) $(CPPFLAGS) $(LDLIBS) py-kdtree_test.o -o py-kdtree_test
-py-kdtree_wrap.cxx: py-kdtree.i py-kdtree.hpp
- swig -python -modern -c++ py-kdtree.i
+py-kdtree.hpp: py-kdtree.i.tmpl py-kdtree.hpp.tmpl gen-swig-hpp.py
+ python gen-swig-hpp.py
+py-kdtree.i: py-kdtree.i.tmpl py-kdtree.hpp.tmpl gen-swig-hpp.py
+ python gen-swig-hpp.py
+.PHONY: current py-kdtree  # these targets never create a file of that name
+py-kdtree_wrap.cxx: py-kdtree.i py-kdtree.hpp py-kdtree.i.tmpl py-kdtree.hpp.tmpl
+ swig -python -modern -c++ py-kdtree.i
_kdtree.so: py-kdtree_wrap.cxx
- $(CXX) $(CPPFLAGS) -c py-kdtree_wrap.cxx -I/usr/include/python2.5 -I$(INCLUDE_DIR)
+ $(CXX) $(CPPFLAGS) -c py-kdtree_wrap.cxx $(PY_INCLUDE_DIR)
$(CXX) $(CPPFLAGS) -shared py-kdtree_wrap.o $(LDLIBS) -o _kdtree.so
clean:
- rm -f test_kdtree *.so py-kdtree_wrap.cxx *.o py-kdtree_test kdtree.py *.pyc
+ rm -f test_kdtree *.so py-kdtree_wrap.cxx *.o dkdtree.py py-kdtree_test kdtree.py *.pyc py-kdtree.i py-kdtree.hpp
.PHONY: clean
diff --git a/python-bindings/py-kdtree.hpp b/python-bindings/py-kdtree.hpp
deleted file mode 100644
index ec78d33..0000000
--- a/python-bindings/py-kdtree.hpp
+++ /dev/null
@@ -1,210 +0,0 @@
-/** \file
- * Provides a Python interface for the libkdtree++.
- *
- * \author Willi Richert <w.richert at gmx.net>
- *
- *
- * This defines a proxy to a (int, int) -> long long KD-Tree. The long
- * long is needed to save a reference to Python's object id(). Thereby,
- * you can associate Python objects with 2D integer points.
- *
- * If you want to customize it you can adapt the following:
- *
- * * Dimension of the KD-Tree point vector.
- * * DIM: number of dimensions.
- * * operator==() and operator<<(): adapt to the number of comparisons
- * * py-kdtree.i: Add or adapt all usages of PyArg_ParseTuple() to reflect the
- * number of dimensions.
- * * adapt query_records in find_nearest() and count_within_range()
- * * Type of points.
- * * coord_t: If you want to have e.g. floats you have
- * to adapt all usages of PyArg_ParseTuple(): Change "i" to "f" e.g.
- * * Type of associated data.
- * * data_t: currently unsigned long long, which is "L" in py-kdtree.i
- * * PyArg_ParseTuple() has to be changed to reflect changes in data_t
- *
- */
-
-
-#ifndef _PY_KDTREE_H_
-#define _PY_KDTREE_H_
-
-#include <kdtree++/kdtree.hpp>
-
-#include <iostream>
-#include <vector>
-#include <limits>
-
-template <size_t DIM, typename COORD_T, typename DATA_T >
-struct record_t {
- static const size_t dim = DIM;
- typedef COORD_T coord_t;
- typedef DATA_T data_t;
-
- typedef coord_t point_t[dim];
-
- inline coord_t operator[](size_t const N) const { return point[N]; }
-
- point_t point;
- data_t data;
-};
-
-////////////////////////////////////////////////////////////////////////////////
-// Definition of (int, int) points that has an unsigned long long as payload
-////////////////////////////////////////////////////////////////////////////////
-#define RECORD_2il record_t<2, int, unsigned long long>
-#define KDTREE_TYPE_2il KDTree::KDTree<2, RECORD_2il, std::pointer_to_binary_function<RECORD_2il,int,double> >
-
-inline bool operator==(RECORD_2il const& A, RECORD_2il const& B) {
- return A.point[0] == B.point[0] && A.point[1] == B.point[1] && A.data == B.data;
-}
-
-std::ostream& operator<<(std::ostream& out, RECORD_2il const& T)
-{
- return out << '(' << T.point[0] << ',' << T.point[1] << '|' << T.data << ')';
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Definition of (int, int, int, int) points that has an unsigned long long as payload
-////////////////////////////////////////////////////////////////////////////////
-#define RECORD_4il record_t<4, int, unsigned long long>
-#define KDTREE_TYPE_4il KDTree::KDTree<4, RECORD_4il, std::pointer_to_binary_function<RECORD_4il,int,double> >
-
-inline bool operator==(RECORD_4il const& A, RECORD_4il const& B) {
- return A.point[0] == B.point[0] && A.point[1] == B.point[1] &&
- A.point[2] == B.point[2] && A.point[3] == B.point[3] &&
- A.data == B.data;
-}
-
-std::ostream& operator<<(std::ostream& out, RECORD_4il const& T)
-{
- return out << '(' << T.point[0] << ',' << T.point[1] << ',' << T.point[2] << ',' << T.point[3] << '|' << T.data << ')';
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Definition of (float) points that has an unsigned long long as payload
-////////////////////////////////////////////////////////////////////////////////
-#define RECORD_1fl record_t<1, float, unsigned long long>
-#define KDTREE_TYPE_1fl KDTree::KDTree<1, RECORD_1fl, std::pointer_to_binary_function<RECORD_1fl,int,double> >
-
-inline bool operator==(RECORD_1fl const& A, RECORD_1fl const& B) {
- return A.point[0] == B.point[0] &&
- A.data == B.data;
-}
-
-
-////////////////////////////////////////////////////////////////////////////////
-// Definition of (float, float, float) points that has an unsigned long long as payload
-////////////////////////////////////////////////////////////////////////////////
-#define RECORD_3fl record_t<3, float, unsigned long long>
-#define KDTREE_TYPE_3fl KDTree::KDTree<3, RECORD_3fl, std::pointer_to_binary_function<RECORD_3fl,int,double> >
-
-inline bool operator==(RECORD_3fl const& A, RECORD_3fl const& B) {
- return A.point[0] == B.point[0] && A.point[1] == B.point[1] &&
- A.point[2] == B.point[2] &&
- A.data == B.data;
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Definition of (float, float, float, float, float, float) points that has an unsigned long long as payload
-////////////////////////////////////////////////////////////////////////////////
-#define RECORD_6fl record_t<6, float, unsigned long long>
-#define KDTREE_TYPE_6fl KDTree::KDTree<6, RECORD_6fl, std::pointer_to_binary_function<RECORD_6fl,int,double> >
-
-inline bool operator==(RECORD_6fl const& A, RECORD_6fl const& B) {
- return A.point[0] == B.point[0] && A.point[1] == B.point[1] &&
- A.point[2] == B.point[2] && A.point[3] == B.point[3] &&
- A.point[4] == B.point[4] && A.point[5] == B.point[5] &&
- A.data == B.data;
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// END OF TYPE SPECIFIC DEFINITIONS
-////////////////////////////////////////////////////////////////////////////////
-
-
-template <class RECORD_T>
-inline double tac(RECORD_T r, int k) { return r[k]; }
-
-template <size_t DIM, typename COORD_T, typename DATA_T >
-class PyKDTree {
-public:
-
- typedef record_t<DIM, COORD_T, DATA_T> RECORD_T;
- typedef KDTree::KDTree<DIM, RECORD_T, std::pointer_to_binary_function<RECORD_T,int,double> > TREE_T;
- TREE_T tree;
-
- PyKDTree() : tree(std::ptr_fun(tac<RECORD_T>)) { };
-
- void add(RECORD_T T) { tree.insert(T); };
-
- /**
- Exact erase.
- */
- bool remove(RECORD_T T) {
- bool removed = false;
-
- typename TREE_T::const_iterator it = tree.find_exact(T);
- if (it!=tree.end()) {
- tree.erase_exact(T);
- removed = true;
- }
- return removed;
- };
-
- int size(void) { return tree.size(); }
-
- void optimize(void) { tree.optimise(); }
-
- RECORD_T* find_exact(RECORD_T T) {
- RECORD_T* found = NULL;
- typename TREE_T::const_iterator it = tree.find_exact(T);
- if (it!=tree.end())
- found = new RECORD_T(*it);
-
- return found;
- }
-
- size_t count_within_range(typename RECORD_T::point_t T, typename TREE_T::distance_type range) {
- RECORD_T query_record;
- memcpy(query_record.point, T, sizeof(COORD_T)*DIM);
-
- return tree.count_within_range(query_record, range);
- }
-
- std::vector<RECORD_T > find_within_range(typename RECORD_T::point_t T, typename TREE_T::distance_type range) {
- RECORD_T query_record;
- memcpy(query_record.point, T, sizeof(COORD_T)*DIM);
-
- std::vector<RECORD_T> v;
- tree.find_within_range(query_record, range, std::back_inserter(v));
- return v;
- }
-
- RECORD_T* find_nearest (typename RECORD_T::point_t T) {
- RECORD_T* found = NULL;
- RECORD_T query_record;
- memcpy(query_record.point, T, sizeof(COORD_T)*DIM);
-
- std::pair<typename TREE_T::const_iterator, typename TREE_T::distance_type> best =
- tree.find_nearest(query_record, std::numeric_limits<typename TREE_T::distance_type>::max());
-
- if (best.first!=tree.end()) {
- found = new RECORD_T(*best.first);
- }
- return found;
- }
-
- std::vector<RECORD_T >* get_all() {
- std::vector<RECORD_T>* v = new std::vector<RECORD_T>;
-
- for (typename TREE_T::const_iterator iter=tree.begin(); iter!=tree.end(); ++iter) {
- v->push_back(*iter);
- }
-
- return v;
- }
-
- size_t __len__() { return tree.size(); }
-};
-#endif //_PY_KDTREE_H_
diff --git a/python-bindings/py-kdtree.hpp.tmpl b/python-bindings/py-kdtree.hpp.tmpl
new file mode 100644
index 0000000..bebf403
--- /dev/null
+++ b/python-bindings/py-kdtree.hpp.tmpl
@@ -0,0 +1,265 @@
+/** \file
+ * Provides a Python interface for the libkdtree++.
+ *
+ * \author Willi Richert <w.richert at gmx.net>
+ *
+ *
+ * This defines a proxy to a (int/float)^DIM -> unsigned long long KD-Tree.
+ * The unsigned long long payload stores a reference to a Python object's
+ * id(), so you can associate Python objects with DIM-dimensional points.
+ *
+ * If you want to customize it you can adapt the following:
+ *
+ * * py-kdtree.hpp and py-kdtree.i are generated from py-kdtree.hpp.tmpl
+ *   and py-kdtree.i.tmpl by gen-swig-hpp.py (see Makefile); edit the
+ *   templates or the generator, not the generated files.
+ * * Dimension of the KD-Tree point vector.
+ *   * DIM: number of dimensions.
+ *   * operator==() and operator<<(): compare/print every coordinate.
+ *   * PyArg_ParseTuple()/Py_BuildValue() format strings need one code
+ *     per coordinate.
+ * * Type of points.
+ *   * coord_t: the format codes must match it, e.g. "i" for int, "f" for float.
+ * * Type of associated data.
+ *   * data_t: currently unsigned long long, which is "L" in the format strings.
+ *
+ */
+
+
+#ifndef _PY_KDTREE_H_
+#define _PY_KDTREE_H_
+
+#include <kdtree++/kdtree.hpp>
+
+#include <algorithm>
+#include <cstring>
+#include <iostream>
+#include <limits>
+#include <vector>
+#define MAX_RANGE std::numeric_limits<double>::max() // radius passed to find_nearest*() as "no limit"
+// A DIM-dimensional point of COORD_T coordinates plus a DATA_T payload.
+template <size_t DIM, typename COORD_T, typename DATA_T >
+struct record_t {
+ static const size_t dim = DIM;
+ typedef COORD_T coord_t;
+ typedef DATA_T data_t;
+ // Coordinate array type; PyKDTree's query methods take a point_t.
+ typedef coord_t point_t[dim];
+ // Coordinate accessor, read by libkdtree++ through tac() below.
+ inline coord_t operator[](size_t const N) const { return point[N]; }
+
+ point_t point;
+ data_t data;
+};
+// Radius type of the range queries (count_within_range/find_within_range).
+typedef double RANGE_T;
+// Presumably replaced by gen-swig-hpp.py with the per-type definitions.
+%%TMPL_HPP_DEFS%%
+
+ ////////////////////////////////////////////////////////////////////////////////
+ // END OF TYPE SPECIFIC DEFINITIONS
+ ////////////////////////////////////////////////////////////////////////////////
+
+// Accessor handed to libkdtree++ (via std::ptr_fun): returns coordinate k of r.
+template <class RECORD_T>
+inline double tac(RECORD_T r, int k) { return r[k]; }
+
+template <size_t DIM, typename COORD_T, typename DATA_T >
+class PyKDTree {
+public:
+  // Python-facing wrapper around a libkdtree++ tree of record_t values.
+  typedef record_t<DIM, COORD_T, DATA_T> RECORD_T;
+  typedef KDTree::KDTree<DIM, RECORD_T, std::pointer_to_binary_function<RECORD_T,int,double> > TREE_T;
+
+  TREE_T tree;
+
+  PyKDTree() : tree(std::ptr_fun(tac<RECORD_T>)) {}
+
+  void add(RECORD_T T) { tree.insert(T); }
+
+  /**
+   Exact erase: removes the record equal to T; returns false if none exists.
+   */
+  bool remove(RECORD_T T) {
+    bool removed = false;
+
+    typename TREE_T::const_iterator it = tree.find_exact(T);
+    if (it!=tree.end()) {
+      tree.erase_exact(T);
+      removed = true;
+    }
+    return removed;
+  }
+
+  int size(void) { return tree.size(); }
+
+  void optimize(void) { tree.optimise(); }
+  // Heap-allocated copy of the record equal to T, or NULL; the caller owns it.
+  RECORD_T* find_exact(RECORD_T T) {
+    RECORD_T* found = NULL;
+    typename TREE_T::const_iterator it = tree.find_exact(T);
+    if (it!=tree.end())
+      found = new RECORD_T(*it);
+
+    return found;
+  }
+  // Number of records within `range` of the point T.
+  size_t count_within_range(typename RECORD_T::point_t T, RANGE_T range) {
+    RECORD_T query_record = RECORD_T(); // value-initialise: the payload is never set for queries
+    memcpy(query_record.point, T, sizeof(COORD_T)*DIM);
+
+    return tree.count_within_range(query_record, range);
+  }
+  // Records within `range` of the point T, in a new vector owned by the caller.
+  std::vector<RECORD_T >* find_within_range(typename RECORD_T::point_t T, RANGE_T range) {
+    RECORD_T query_record = RECORD_T();
+    memcpy(query_record.point, T, sizeof(COORD_T)*DIM);
+
+    std::vector<RECORD_T> *v = new std::vector<RECORD_T>;
+    tree.find_within_range(query_record, range, std::back_inserter(*v));
+    return v;
+  }
+  // Heap-allocated copy of the record nearest to T, or NULL if none is found.
+  RECORD_T* find_nearest (typename RECORD_T::point_t T) {
+    RECORD_T* found = NULL;
+    RECORD_T query_record = RECORD_T();
+    memcpy(query_record.point, T, sizeof(COORD_T)*DIM);
+
+    std::pair<typename TREE_T::const_iterator, typename TREE_T::distance_type> best =
+      tree.find_nearest(query_record, MAX_RANGE);
+
+    if (best.first!=tree.end()) {
+      found = new RECORD_T(*best.first);
+    }
+    return found;
+  }
+
+  struct RECORD_DIST_T
+  {
+    // A record and its squared distance to the query; ordered by that distance.
+    RECORD_T r;
+    double dist;
+    RECORD_DIST_T(RECORD_T const& t, double d): r(t), dist(d) {}
+
+    bool operator< (const RECORD_DIST_T& other) const
+    {
+      return dist < other.dist;
+    }
+  };
+
+  struct compare_kNN
+  {
+    struct State
+    {
+      RECORD_T target;
+      bool initialised;
+      int counter;                  // number of predicate calls (not read here)
+
+      std::vector<RECORD_DIST_T> q; // best candidates so far, ascending dist
+      double best_dist, worst_dist; // dist of q.front() and q.back()
+
+
+      State(RECORD_T const& t ) :
+        target(t), initialised(false), best_dist(MAX_RANGE), worst_dist(MAX_RANGE)
+      {
+        counter=0;
+      }
+    };
+
+    // find_nearest_if() predicate collecting up to k_in_kNN nearest records in *state.
+    State *state;
+    size_t k_in_kNN;
+
+    compare_kNN(State *s, size_t k) : state(s), k_in_kNN(k) {}
+
+    bool operator()(RECORD_T const& test_record) const
+    {
+      bool res=false;
+      state->counter += 1;
+      double sqr_dist = 0.0;
+      for (size_t i=0; i<DIM; i++) {
+        double const d = double(test_record[i]) - double(state->target[i]);
+        sqr_dist += d * d; // no pow(); difference taken in double (no int overflow)
+      }
+      /* first look at a point... */
+      if (!state->initialised)
+      {
+        state->q.push_back(RECORD_DIST_T(test_record, sqr_dist));
+
+        state->initialised = true;
+        state->best_dist = state->worst_dist = sqr_dist;
+
+        res=true;
+      } else
+      {
+
+        if (sqr_dist < state->best_dist) /* better than our best? */
+        {
+          state->best_dist = sqr_dist;
+          if (state->q.size() == k_in_kNN)
+            state->q.pop_back(); /* full: drop the current worst */
+
+          state->q.push_back(RECORD_DIST_T(test_record, sqr_dist));
+          std::sort(state->q.begin(), state->q.end());
+          state->worst_dist = state->q.back().dist; /* the old worst may have been dropped */
+          /* pretend we weren't interested so it won't reduce the
+             search space to this point this is the key to making
+             kdtree keep looking for 2nd best */
+          res = false;
+        }
+        else if (sqr_dist < state->worst_dist || state->q.size() < k_in_kNN) /* better than our second-best? */
+        {
+          if (state->q.size() == k_in_kNN)
+          {
+            /* full: make room by dropping the current worst */
+            state->q.pop_back();
+          }
+
+          state->q.push_back(RECORD_DIST_T(test_record, sqr_dist));
+          std::sort(state->q.begin(), state->q.end());
+
+          state->worst_dist = (state->q.back()).dist;
+
+          res = true;
+        }
+      }
+
+      return res;
+    }
+
+  };
+  // One find_k_nearest() result: (record, squared distance to the query point).
+  typedef std::pair<RECORD_T, double> res_type;
+
+  std::vector<std::pair<RECORD_T, double> >* find_k_nearest (typename RECORD_T::point_t T, size_t k) {
+    std::vector<std::pair<RECORD_T, double> > *v = new std::vector<std::pair<RECORD_T, double> >;
+    RECORD_T query_record = RECORD_T();
+    memcpy(query_record.point, T, sizeof(COORD_T)*DIM);
+    if (k == 0) return v; // nothing requested; compare_kNN would otherwise keep one hit
+    typename compare_kNN::State state(query_record); // create the state
+    tree.find_nearest_if(query_record, MAX_RANGE, compare_kNN(&state, k));
+
+    if (state.initialised) {
+      // state.q is sorted by ascending squared distance; copy it out in order.
+      for (typename std::vector<RECORD_DIST_T>::const_iterator iter=state.q.begin(); iter!=state.q.end(); ++iter) {
+        v->push_back(res_type(iter->r, iter->dist));
+      }
+    }
+    return v;
+  }
+  // Every record in the tree, in a new vector owned by the caller.
+  std::vector<RECORD_T >* get_all() {
+    std::vector<RECORD_T>* v = new std::vector<RECORD_T>;
+
+    for (typename TREE_T::const_iterator iter=tree.begin(); iter!=tree.end(); ++iter) {
+      v->push_back(*iter);
+    }
+
+    return v;
+  }
+
+  size_t __len__() { return tree.size(); }
+
+};
+
+#endif //_PY_KDTREE_H_
diff --git a/python-bindings/py-kdtree.i b/python-bindings/py-kdtree.i
deleted file mode 100644
index cd45828..0000000
--- a/python-bindings/py-kdtree.i
+++ /dev/null
@@ -1,545 +0,0 @@
-/** \file
- * $Id$
- *
- * Provides a Python interface for the libkdtree++.
- *
- * \author Willi Richert <w.richert at gmx.net>
- *
- */
-
-%module kdtree
- //%include exception.i
-
-%{
-#define SWIG_FILE_WITH_INIT
-#include "py-kdtree.hpp"
-%}
-
-
-%ignore record_t::operator[];
-%ignore operator==;
-%ignore operator<<;
-%ignore KDTree::KDTree::operator=;
-%ignore tac;
-
-#define RECORD_2il record_t<2, int, unsigned long long> // cf. py-kdtree.hpp
-#define RECORD_4il record_t<4, int, unsigned long long> // cf. py-kdtree.hpp
-
-#define RECORD_1fl record_t<1, float, unsigned long long>
-#define RECORD_3fl record_t<3, float, unsigned long long>
-#define RECORD_6fl record_t<6, float, unsigned long long>
-
-////////////////////////////////////////////////////////////////////////////////
-// TYPE (int, int)
-////////////////////////////////////////////////////////////////////////////////
-%typemap(in) RECORD_2il (RECORD_2il temp) {
- if (PyTuple_Check($input)) {
-
- if (PyArg_ParseTuple($input,"(ii)L", &temp.point[0], &temp.point[1], &temp.data)!=0)
- {
- $1 = temp;
- } else {
- PyErr_SetString(PyExc_TypeError,"tuple must have 2 elements: (2dim int vector, long value)");
- return NULL;
- }
-
- } else {
- PyErr_SetString(PyExc_TypeError,"expected a tuple.");
- return NULL;
- }
- }
-
-%typemap(in) RECORD_2il::point_t (RECORD_2il::point_t point) {
- if (PyTuple_Check($input)) {
- if (PyArg_ParseTuple($input,"ii", &point[0], &point[1])!=0)
- {
- $1 = point;
- } else {
- PyErr_SetString(PyExc_TypeError,"tuple must contain 2 ints");
- return NULL;
- }
-
- } else {
- PyErr_SetString(PyExc_TypeError,"expected a tuple.");
- return NULL;
- }
- }
-
-%typemap(out) RECORD_2il * {
- RECORD_2il * r = $1;
- PyObject* py_result;
-
- if (r != NULL) {
-
- py_result = PyTuple_New(2);
- if (py_result==NULL) {
- PyErr_SetString(PyErr_Occurred(),"unable to create a tuple.");
- return NULL;
- }
-
- if (PyTuple_SetItem(py_result, 0, Py_BuildValue("(ii)", r->point[0], r->point[1]))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(a) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- }
-
- if (PyTuple_SetItem(py_result, 1, Py_BuildValue("L", r->data))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(b) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- }
- } else {
- py_result = Py_BuildValue("");
- }
-
- $result = py_result;
- }
-
-%typemap(out) std::vector<RECORD_2il >* {
- std::vector<RECORD_2il >* v = $1;
-
- PyObject* py_result = PyList_New(v->size());
- if (py_result==NULL) {
- PyErr_SetString(PyErr_Occurred(),"unable to create a list.");
- return NULL;
- }
- std::vector<RECORD_2il >::const_iterator iter = v->begin();
-
- for (size_t i=0; i<v->size(); i++, iter++) {
- if (PyList_SetItem(py_result, i, Py_BuildValue("(ii)L", (*iter).point[0], (*iter).point[1], (*iter).data))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(c) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- } else {
- //std::cout << "successfully set element " << *iter << std::endl;
- }
- }
-
- $result = py_result;
- }
-////////////////////////////////////////////////////////////////////////////////
-// TYPE (int, int, int, int)
-////////////////////////////////////////////////////////////////////////////////
-%typemap(in) RECORD_4il (RECORD_4il temp) {
- if (PyTuple_Check($input)) {
-
- if (PyArg_ParseTuple($input,"(iiii)L", &temp.point[0], &temp.point[1],
- &temp.point[2], &temp.point[3],
- &temp.data)!=0)
- {
- $1 = temp;
- } else {
- PyErr_SetString(PyExc_TypeError,"tuple must have 4 elements: (4dim int vector, long value)");
- return NULL;
- }
-
- } else {
- PyErr_SetString(PyExc_TypeError,"expected a tuple.");
- return NULL;
- }
- }
-
-%typemap(in) RECORD_4il::point_t (RECORD_4il::point_t point) {
- if (PyTuple_Check($input)) {
- if (PyArg_ParseTuple($input,"iiii", &point[0], &point[1], &point[2], &point[3])!=0)
- {
- $1 = point;
- } else {
- PyErr_SetString(PyExc_TypeError,"tuple must contain 4 ints");
- return NULL;
- }
-
- } else {
- PyErr_SetString(PyExc_TypeError,"expected a tuple.");
- return NULL;
- }
- }
-
-%typemap(out) RECORD_4il * {
- RECORD_4il * r = $1;
- PyObject* py_result;
-
- if (r != NULL) {
-
- py_result = PyTuple_New(2);
- if (py_result==NULL) {
- PyErr_SetString(PyErr_Occurred(),"unable to create a tuple.");
- return NULL;
- }
-
- if (PyTuple_SetItem(py_result, 0, Py_BuildValue("(iiii)",
- r->point[0], r->point[1], r->point[2], r->point[3]))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(a) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- }
-
- if (PyTuple_SetItem(py_result, 1, Py_BuildValue("L", r->data))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(b) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- }
- } else {
- py_result = Py_BuildValue("");
- }
-
- $result = py_result;
- }
-
-%typemap(out) std::vector<RECORD_4il >* {
- std::vector<RECORD_4il >* v = $1;
-
- PyObject* py_result = PyList_New(v->size());
- if (py_result==NULL) {
- PyErr_SetString(PyErr_Occurred(),"unable to create a list.");
- return NULL;
- }
- std::vector<RECORD_4il >::const_iterator iter = v->begin();
-
- for (size_t i=0; i<v->size(); i++, iter++) {
- if (PyList_SetItem(py_result, i, Py_BuildValue("(iiii)L",
- (*iter).point[0], (*iter).point[1], (*iter).point[2], (*iter).point[3],
- (*iter).data))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(c) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- } else {
- //std::cout << "successfully set element " << *iter << std::endl;
- }
- }
-
- $result = py_result;
- }
-
-////////////////////////////////////////////////////////////////////////////////
-// TYPE (float)
-////////////////////////////////////////////////////////////////////////////////
-%typemap(in) RECORD_1fl (RECORD_1fl temp) {
- if (PyTuple_Check($input)) {
-
- if (PyArg_ParseTuple($input,"(f)L",
- &temp.point[0],
- &temp.data)!=0)
- {
- $1 = temp;
- } else {
- PyErr_SetString(PyExc_TypeError,"tuple must have 2 elements: (1dim float tuple, long value)");
- return NULL;
- }
-
- } else {
- PyErr_SetString(PyExc_TypeError,"expected a tuple.");
- return NULL;
- }
- }
-
-%typemap(in) RECORD_1fl::point_t (RECORD_1fl::point_t point) {
- if (PyTuple_Check($input)) {
- if (PyArg_ParseTuple($input,"f",
- &point[0])!=0)
- {
- $1 = point;
- } else {
- PyErr_SetString(PyExc_TypeError,"tuple must contain 1 float");
- return NULL;
- }
-
- } else {
- PyErr_SetString(PyExc_TypeError,"expected a tuple.");
- return NULL;
- }
- }
-
-%typemap(out) RECORD_1fl * {
- RECORD_1fl * r = $1;
- PyObject* py_result;
-
- if (r != NULL) {
-
- py_result = PyTuple_New(2);
- if (py_result==NULL) {
- PyErr_SetString(PyErr_Occurred(),"unable to create a tuple.");
- return NULL;
- }
-
- if (PyTuple_SetItem(py_result, 0, Py_BuildValue("(f)",
- r->point[0]))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(a) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- }
-
- if (PyTuple_SetItem(py_result, 1, Py_BuildValue("L", r->data))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(b) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- }
- } else {
- py_result = Py_BuildValue("");
- }
-
- $result = py_result;
- }
-
-%typemap(out) std::vector<RECORD_1fl >* {
- std::vector<RECORD_1fl >* v = $1;
-
- PyObject* py_result = PyList_New(v->size());
- if (py_result==NULL) {
- PyErr_SetString(PyErr_Occurred(),"unable to create a list.");
- return NULL;
- }
- std::vector<RECORD_1fl >::const_iterator iter = v->begin();
-
- for (size_t i=0; i<v->size(); i++, iter++) {
- if (PyList_SetItem(py_result, i, Py_BuildValue("(f)L",
- (*iter).point[0],
- (*iter).data))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(c) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- } else {
- //std::cout << "successfully set element " << *iter << std::endl;
- }
- }
-
- $result = py_result;
- }
-
-////////////////////////////////////////////////////////////////////////////////
-// TYPE (float, float, float)
-////////////////////////////////////////////////////////////////////////////////
-%typemap(in) RECORD_3fl (RECORD_3fl temp) {
- if (PyTuple_Check($input)) {
-
- if (PyArg_ParseTuple($input,"(fff)L",
- &temp.point[0], &temp.point[1],
- &temp.point[2],
- &temp.data)!=0)
- {
- $1 = temp;
- } else {
- PyErr_SetString(PyExc_TypeError,"tuple must have 2 elements: (3dim float tuple, long value)");
- return NULL;
- }
-
- } else {
- PyErr_SetString(PyExc_TypeError,"expected a tuple.");
- return NULL;
- }
- }
-
-%typemap(in) RECORD_3fl::point_t (RECORD_3fl::point_t point) {
- if (PyTuple_Check($input)) {
- if (PyArg_ParseTuple($input,"fff",
- &point[0], &point[1],
- &point[2])!=0)
- {
- $1 = point;
- } else {
- PyErr_SetString(PyExc_TypeError,"tuple must contain 3 floats");
- return NULL;
- }
-
- } else {
- PyErr_SetString(PyExc_TypeError,"expected a tuple.");
- return NULL;
- }
- }
-
-%typemap(out) RECORD_3fl * {
- RECORD_3fl * r = $1;
- PyObject* py_result;
-
- if (r != NULL) {
-
- py_result = PyTuple_New(2);
- if (py_result==NULL) {
- PyErr_SetString(PyErr_Occurred(),"unable to create a tuple.");
- return NULL;
- }
-
- if (PyTuple_SetItem(py_result, 0, Py_BuildValue("(fff)",
- r->point[0], r->point[1], r->point[2]))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(a) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- }
-
- if (PyTuple_SetItem(py_result, 1, Py_BuildValue("L", r->data))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(b) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- }
- } else {
- py_result = Py_BuildValue("");
- }
-
- $result = py_result;
- }
-
-%typemap(out) std::vector<RECORD_3fl >* {
- std::vector<RECORD_3fl >* v = $1;
-
- PyObject* py_result = PyList_New(v->size());
- if (py_result==NULL) {
- PyErr_SetString(PyErr_Occurred(),"unable to create a list.");
- return NULL;
- }
- std::vector<RECORD_3fl >::const_iterator iter = v->begin();
-
- for (size_t i=0; i<v->size(); i++, iter++) {
- if (PyList_SetItem(py_result, i, Py_BuildValue("(fff)L",
- (*iter).point[0], (*iter).point[1],
- (*iter).point[2],
- (*iter).data))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(c) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- } else {
- //std::cout << "successfully set element " << *iter << std::endl;
- }
- }
-
- $result = py_result;
- }
-
-////////////////////////////////////////////////////////////////////////////////
-// TYPE (float, float, float, float, float, float)
-////////////////////////////////////////////////////////////////////////////////
-%typemap(in) RECORD_6fl (RECORD_6fl temp) {
- if (PyTuple_Check($input)) {
-
- if (PyArg_ParseTuple($input,"(ffffff)L",
- &temp.point[0], &temp.point[1],
- &temp.point[2], &temp.point[3],
- &temp.point[4], &temp.point[5],
- &temp.data)!=0)
- {
- $1 = temp;
- } else {
- PyErr_SetString(PyExc_TypeError,"tuple must have 2 elements: (6dim float tuple, long value)");
- return NULL;
- }
-
- } else {
- PyErr_SetString(PyExc_TypeError,"expected a tuple.");
- return NULL;
- }
- }
-
-%typemap(in) RECORD_6fl::point_t (RECORD_6fl::point_t point) {
- if (PyTuple_Check($input)) {
- if (PyArg_ParseTuple($input,"ffffff",
- &point[0], &point[1],
- &point[2], &point[3],
- &point[4], &point[5])!=0)
- {
- $1 = point;
- } else {
- PyErr_SetString(PyExc_TypeError,"tuple must contain 6 floats");
- return NULL;
- }
-
- } else {
- PyErr_SetString(PyExc_TypeError,"expected a tuple.");
- return NULL;
- }
- }
-
-%typemap(out) RECORD_6fl * {
- RECORD_6fl * r = $1;
- PyObject* py_result;
-
- if (r != NULL) {
-
- py_result = PyTuple_New(2);
- if (py_result==NULL) {
- PyErr_SetString(PyErr_Occurred(),"unable to create a tuple.");
- return NULL;
- }
-
- if (PyTuple_SetItem(py_result, 0, Py_BuildValue("(ffffff)",
- r->point[0], r->point[1],
- r->point[2], r->point[3],
- r->point[4], r->point[5]))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(a) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- }
-
- if (PyTuple_SetItem(py_result, 1, Py_BuildValue("L", r->data))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(b) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- }
- } else {
- py_result = Py_BuildValue("");
- }
-
- $result = py_result;
- }
-
-%typemap(out) std::vector<RECORD_6fl >* {
- std::vector<RECORD_6fl >* v = $1;
-
- PyObject* py_result = PyList_New(v->size());
- if (py_result==NULL) {
- PyErr_SetString(PyErr_Occurred(),"unable to create a list.");
- return NULL;
- }
- std::vector<RECORD_6fl >::const_iterator iter = v->begin();
-
- for (size_t i=0; i<v->size(); i++, iter++) {
- if (PyList_SetItem(py_result, i, Py_BuildValue("(ffffff)L",
- (*iter).point[0], (*iter).point[1],
- (*iter).point[2], (*iter).point[3],
- (*iter).point[4], (*iter).point[5],
- (*iter).data))==-1) {
- PyErr_SetString(PyErr_Occurred(),"(c) when setting element");
-
- Py_DECREF(py_result);
- return NULL;
- } else {
- //std::cout << "successfully set element " << *iter << std::endl;
- }
- }
-
- $result = py_result;
- }
-
-
-////////////////////////////////////////////////////////////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-
-%include "py-kdtree.hpp"
-
-%template () RECORD_2il;
-%template (KDTree_2Int) PyKDTree<2, int, unsigned long long>;
-
-%template () RECORD_4il;
-%template (KDTree_4Int) PyKDTree<4, int, unsigned long long>;
-
-%template () RECORD_1fl;
-%template (KDTree_1Float) PyKDTree<1, float, unsigned long long>;
-
-%template () RECORD_3fl;
-%template (KDTree_3Float) PyKDTree<3, float, unsigned long long>;
-
-%template () RECORD_6fl;
-%template (KDTree_6Float) PyKDTree<6, float, unsigned long long>;
diff --git a/python-bindings/py-kdtree.i.tmpl b/python-bindings/py-kdtree.i.tmpl
new file mode 100644
index 0000000..c065f72
--- /dev/null
+++ b/python-bindings/py-kdtree.i.tmpl
@@ -0,0 +1,33 @@
+/** \file
+ *
+ * Provides a Python interface to libkdtree++.
+ *
+ * \author Willi Richert <w.richert at gmx.net>
+ * The TMPL_* markers below are placeholders filled in at build time (presumably by the Makefile).
+ */
+
+%module kdtree
+
+%{
+#define SWIG_FILE_WITH_INIT
+#include "py-kdtree.hpp"
+%}
+
+
+%ignore record_t::operator[];
+%ignore operator==;
+%ignore operator<<;
+%ignore KDTree::KDTree::operator=;
+%ignore tac;
+// NOTE(review): the wrapped finders live in PyKDTree<...> and return heap pointers; confirm these names match them.
+%newobject KDTree::KDTree::find_nearest;
+%newobject KDTree::KDTree::find_k_nearest;
+%newobject KDTree::KDTree::find_exact;
+%newobject KDTree::KDTree::find_within_range;
+%newobject KDTree::KDTree::get_all;
+
+%%TMPL_BODY%%
+
+%include "py-kdtree.hpp"
+
+%%TMPL_PY_CLASS_DEF%%
diff --git a/python-bindings/py-kdtree_test.cpp b/python-bindings/py-kdtree_test.cpp
index df4df0c..1cbb3a5 100644
--- a/python-bindings/py-kdtree_test.cpp
+++ b/python-bindings/py-kdtree_test.cpp
@@ -6,23 +6,24 @@
#include <vector>
#include "py-kdtree.hpp"
+#include <string.h>
int main()
{
- KDTree_2Int t;
+ PyKDTree<2, int, unsigned long long> t;
- RECORD_2il c0 = { {5, 4} }; t.add(c0);
- RECORD_2il c1 = { {4, 2} }; t.add(c1);
- RECORD_2il c2 = { {7, 6} }; t.add(c2);
- RECORD_2il c3 = { {2, 2} }; t.add(c3);
- RECORD_2il c4 = { {8, 0} }; t.add(c4);
- RECORD_2il c5 = { {5, 7} }; t.add(c5);
- RECORD_2il c6 = { {3, 3} }; t.add(c6);
- RECORD_2il c7 = { {9, 7} }; t.add(c7);
- RECORD_2il c8 = { {2, 2} }; t.add(c8);
- RECORD_2il c9 = { {2, 0} }; t.add(c9);
+ RECORD_2iull c0 = { {5, 4} }; t.add(c0);
+ RECORD_2iull c1 = { {4, 2} }; t.add(c1);
+ RECORD_2iull c2 = { {7, 6} }; t.add(c2);
+ RECORD_2iull c3 = { {2, 2} }; t.add(c3);
+ RECORD_2iull c4 = { {8, 0} }; t.add(c4);
+ RECORD_2iull c5 = { {5, 7} }; t.add(c5);
+ RECORD_2iull c6 = { {3, 3} }; t.add(c6);
+ RECORD_2iull c7 = { {9, 7} }; t.add(c7);
+ RECORD_2iull c8 = { {2, 2} }; t.add(c8);
+ RECORD_2iull c9 = { {2, 0} }; t.add(c9);
std::cout << t.tree << std::endl;
@@ -34,9 +35,8 @@ int main()
t.optimize();
std::cout << std::endl << t.tree << std::endl;
-
int i=0;
- for (KDTREE_TYPE_2il::const_iterator iter=t.tree.begin(); iter!=t.tree.end(); ++iter, ++i);
+ for (KDTREE_TYPE_2iull::const_iterator iter=t.tree.begin(); iter!=t.tree.end(); ++iter, ++i);
std::cout << "iterator walked through " << i << " nodes in total" << std::endl;
if (i!=6)
{
@@ -44,7 +44,7 @@ int main()
return 1;
}
i=0;
- for (KDTREE_TYPE_2il::const_reverse_iterator iter=t.tree.rbegin(); iter!=t.tree.rend(); ++iter, ++i);
+ for (KDTREE_TYPE_2iull::const_reverse_iterator iter=t.tree.rbegin(); iter!=t.tree.rend(); ++iter, ++i);
std::cout << "reverse_iterator walked through " << i << " nodes in total" << std::endl;
if (i!=6)
{
@@ -52,33 +52,46 @@ int main()
return 1;
}
- RECORD_2il::point_t s = {5, 4};
- std::vector<RECORD_2il> v;
+ RECORD_2iull::point_t s = {5, 4};
+ std::vector<RECORD_2iull>* v;
unsigned int const RANGE = 3;
size_t count = t.count_within_range(s, RANGE);
std::cout << "counted " << count
- << " nodes within range " << RANGE << " of " << s << ".\n";
+ << " nodes within range " << RANGE << " of " << s << ".\n";
v = t.find_within_range(s, RANGE);
- std::cout << "found " << v.size() << " nodes within range " << RANGE
- << " of " << s << ":\n";
- std::vector<RECORD_2il>::const_iterator ci = v.begin();
- for (; ci != v.end(); ++ci)
+ std::cout << "found " << v->size() << " nodes within range " << RANGE
+ << " of " << s << ":\n";
+ std::vector<RECORD_2iull>::const_iterator ci = v->begin();
+ for (; ci != v->end(); ++ci)
std::cout << *ci << " ";
std::cout << "\n" << std::endl;
+ delete v;
- std::cout << "Nearest to " << s << ": " <<
- t.find_nearest(s) << std::endl;
+ RECORD_2iull *nearest = t.find_nearest(s);
+ if (nearest) std::cout << "Nearest to " << s << ": " << *nearest << std::endl; // print the record, not its address
+ delete nearest;
- RECORD_2il::point_t s2 = { 10, 10};
- std::cout << "Nearest to " << s2 << ": " <<
- t.find_nearest(s2) << std::endl;
+ RECORD_2iull::point_t s2 = { 10, 10};
+ nearest = t.find_nearest(s2);
+ if (nearest) std::cout << "Nearest to " << s2 << ": " << *nearest << std::endl; // print the record, not its address
+ delete nearest;
std::cout << std::endl;
-
std::cout << t.tree << std::endl;
+ // memory leak test: 1e10 add/remove cycles; memory use is presumably watched externally
+ PyKDTree<2, int, unsigned long long> t2;
+ for (long long j=0; j<10000000000LL; ++j) // long long: 1e10 exceeds INT_MAX (int would overflow = UB)
+ {
+ t2.add(c0);
+ t2.remove(c0);
+ if (j%10000==0)
+ std::cout << j << std::endl;
+ }
+
+
return 0;
}
diff --git a/python-bindings/py-kdtree_test.py b/python-bindings/py-kdtree_test.py
index 6b96a9a..d89858d 100644
--- a/python-bindings/py-kdtree_test.py
+++ b/python-bindings/py-kdtree_test.py
@@ -1,11 +1,21 @@
-#
-# $Id: py-kdtree_test.py 2268 2008-08-20 10:08:58Z richert $
-#
-import unittest
+import unittest, random, sys, os
-from kdtree import KDTree_2Int, KDTree_4Int, KDTree_1Float, KDTree_3Float, KDTree_6Float
+from kdtree import KDTree_2Int, KDTree_4Int, KDTree_2Float, KDTree_3Float, KDTree_4Float, KDTree_6Float
+#from kdtree import KDTree_2Float
+EPSILON=1E-5
+def almost_equal_records(r1, r2):
+ """True if r1, r2 have equal data and same-length coordinates that agree within EPSILON."""
+ p1, p2 = r1[0], r2[0]
+ d1, d2 = r1[1], r2[1]
+ if len(p1)!=len(p2) or d1!=d2: # zip() below would silently truncate differing lengths
+ return False
+ for a,b in zip(p1, p2):
+ if abs(a-b)>EPSILON:
+ return False
+
+ return True
class KDTree_2IntTestCase(unittest.TestCase):
def test_empty(self):
@@ -51,6 +61,28 @@ class KDTree_2IntTestCase(unittest.TestCase):
actual = nn.find_nearest((6, 6))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
+ def test_find_within_range(self):
+ nn = KDTree_6Float()
+
+ nn_id = {}
+
+ o1 = object()
+ nn.add(((1,1,0,0,0,0), id(o1)))
+ nn_id[id(o1)] = o1
+ o2 = object()
+ nn.add(((10,10,0,0,0,0), id(o2)))
+ nn_id[id(o2)] = o2
+ o3 = object()
+ nn.add(((4.1, 4.1,0,0,0,0), id(o3)))
+ nn_id[id(o3)] = o3
+
+ expected = set([long(id(o1)), long(id(o3))])
+ actual = set([ident
+ for _coord, ident
+ in nn.find_within_range((2.1,2.1,0,0,0,0), 3.9)])
+ self.assertTrue(expected==actual, "%s != %s"%(str(expected), str(actual)))
+
+
def test_remove(self):
class C:
def __init__(self, i):
@@ -76,6 +108,36 @@ class KDTree_2IntTestCase(unittest.TestCase):
self.assertTrue(nearest[1] == id(o1), "%s != %s"%(nearest[1], o1))
#self.assertTrue(nearest[1] is o1, "%s,%s is not %s"%(str(nearest[0]), str(nearest[1]), str((k1,id(o1)))))
+ def test_count_within_range(self):
+ nn = KDTree_2Int()
+
+ for p in [(0,0), (1,0), (0,1), (1,1)]:
+ nn.add((p, id(p)))
+
+ res = nn.count_within_range((0,0), 1.0)
+ self.assertEqual(3, res, "Counted %i points instead of %i"%(res, 3))
+
+ res = nn.count_within_range((0,0), 1.9)
+ self.assertEqual(4, res, "Counted %i points instead of %i"%(res, 4))
+
+ def test_memory_usage(self):
+ nn = KDTree_2Int()
+
+ class SO:
+ def __init__(self, point, region):
+ self.point = point
+ self.region = region
+
+ for i in range(100000): # range() requires an int; 1E5 is a float
+ so=SO((1,2), i)
+ nn.add((so.point, id(so)))
+ nn.remove((so.point, id(so)))
+ del so
+ if i%1000==0:
+ print i
+ os.system("free | grep Mem:")
+
+
class KDTree_4IntTestCase(unittest.TestCase):
def test_empty(self):
nn = KDTree_4Int()
@@ -145,51 +207,48 @@ class KDTree_4IntTestCase(unittest.TestCase):
self.assertTrue(nearest[1] == id(o1), "%s != %s"%(nearest[1], o1))
#self.assertTrue(nearest[1] is o1, "%s,%s is not %s"%(str(nearest[0]), str(nearest[1]), str((k1,id(o1)))))
-class KDTree_1FloatTestCase(unittest.TestCase):
+class KDTree_4FloatTestCase(unittest.TestCase):
def test_empty(self):
- nn = KDTree_1Float()
+ nn = KDTree_4Float()
self.assertEqual(0, nn.size())
- actual = nn.find_nearest((2.,))
+ actual = nn.find_nearest((0,0,2,3))
self.assertTrue(None==actual, "%s != %s"%(str(None), str(actual)))
def test_get_all(self):
- nn = KDTree_1Float()
+ nn = KDTree_4Float()
o1 = object()
- nn.add(((1.,), id(o1)))
+ nn.add(((0,0,1,1), id(o1)))
o2 = object()
- nn.add(((10.,), id(o2)))
+ nn.add(((0,0,10,10), id(o2)))
o3 = object()
- nn.add(((11.,), id(o3)))
+ nn.add(((0,0,11,11), id(o3)))
- self.assertEqual([((1.,), id(o1)), ((10.,), id(o2)), ((11.,), id(o3))], nn.get_all())
+ self.assertEqual([((0,0,1,1), id(o1)), ((0,0,10,10), id(o2)), ((0,0,11,11), id(o3))], nn.get_all())
self.assertEqual(3, len(nn))
- nn.remove(((10.,), id(o2)))
+ nn.remove(((0,0,10,10), id(o2)))
self.assertEqual(2, len(nn))
- self.assertEqual([((1.,), id(o1)), ((11.,), id(o3))], nn.get_all())
+ self.assertEqual([((0,0,1,1), id(o1)), ((0,0,11,11), id(o3))], nn.get_all())
def test_nearest(self):
- nn = KDTree_1Float()
+ nn = KDTree_4Float()
nn_id = {}
o1 = object()
- nn.add(((1.,), id(o1)))
+ nn.add(((0,0,1,1), id(o1)))
nn_id[id(o1)] = o1
o2 = object()
- nn.add(((10.,), id(o2)))
+ nn.add(((0,0,10,10), id(o2)))
nn_id[id(o2)] = o2
- o3 = object()
- nn.add(((4.1,), id(o3)))
- nn_id[id(o3)] = o3
- expected = o3
- actual = nn.find_nearest((2.9,))[1]
+ expected = o1
+ actual = nn.find_nearest((0,0,2,2))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
- expected = o3
- actual = nn.find_nearest((6.,))[1]
+ expected = o2
+ actual = nn.find_nearest((0,0,6,6))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
def test_remove(self):
@@ -198,13 +257,13 @@ class KDTree_1FloatTestCase(unittest.TestCase):
self.i = i
self.next = None
- nn = KDTree_1Float()
+ nn = KDTree_4Float()
- k1, o1 = (1.1,), C(7)
+ k1, o1 = (0,0,1,1), C(7)
self.assertFalse(nn.remove((k1, id(o1))), "This cannot be removed!")
nn.add((k1, id(o1)))
- k2, o2 = (1.1,), C(7)
+ k2, o2 = (0,0,1,1), C(7)
nn.add((k2, id(o2)))
self.assertEqual(2, nn.size())
@@ -216,6 +275,100 @@ class KDTree_1FloatTestCase(unittest.TestCase):
nearest = nn.find_nearest(k1)
self.assertTrue(nearest[1] == id(o1), "%s != %s"%(nearest[1], o1))
#self.assertTrue(nearest[1] is o1, "%s,%s is not %s"%(str(nearest[0]), str(nearest[1]), str((k1,id(o1)))))
+
+class KDTree_2FloatTestCase(unittest.TestCase):
+ def test_empty(self):
+ nn = KDTree_2Float()
+ self.assertEqual(0, nn.size())
+
+ actual = nn.find_nearest((2,3))
+ self.assertTrue(None==actual, "%s != %s"%(str(None), str(actual)))
+
+ o1 = object()
+ p1 = (1.0, 0.0)
+ nn.add((p1, id(o1)))
+
+ o2 = object()
+ p2 = (0.7, 0.7)
+ nn.add((p2, id(o2)))
+
+ actual = nn.find_nearest((0.0, 0.0))
+ expected = (p2, id(o2))
+ print actual, expected
+ self.assertTrue(almost_equal_records(actual, expected), msg="%s != %s"%(str(expected), str(actual)))
+
+ nn.remove((p2, id(o2)))
+ o3 = object()
+ p3 = (0.71, 0.71)
+ nn.add((p3, id(o3)))
+
+ actual = nn.find_nearest((0.0, 0.0))
+ expected = (p1, id(o1))
+ self.assertTrue(almost_equal_records(actual, expected), msg="%s != %s"%(str(expected), str(actual)))
+
+ def test_2_nearest(self):
+ nn = KDTree_2Float()
+ self.assertEqual(0, nn.size())
+
+ actual = nn.find_nearest((2,3))
+ self.assertTrue(None==actual, "%s != %s"%(str(None), str(actual)))
+
+ o1 = object()
+ p1 = (1.0, 0.0)
+ nn.add((p1, id(o1)))
+
+ o2 = object()
+ p2 = (0.7, 0.7)
+ nn.add((p2, id(o2)))
+
+ o3 = object()
+ p3 = (0.6, 0.6)
+ nn.add((p3, id(o3)))
+
+ r1, r2 = nn.find_k_nearest((0,0), 2)
+ self.assertTrue(almost_equal_records(r1, (p3, id(o3))), "%s != %s"%(r1, (p3, id(o3))))
+ self.assertTrue(almost_equal_records(r2, (p2, id(o2))), "%s != %s"%(r2, (p2, id(o2))))
+
+ o4 = object()
+ p4 = (0.5, 0.6)
+ nn.add((p4, id(o4)))
+ r1, r2 = nn.find_k_nearest((0,0), 2)
+ self.assertTrue(almost_equal_records(r1, (p4, id(o4))), "%s != %s"%(r1, (p4, id(o4))))
+ self.assertTrue(almost_equal_records(r2, (p3, id(o3))), "%s != %s"%(r2, (p3, id(o3))))
+
+ nn.remove((p3, id(o3)))
+ r1, r2 = nn.find_k_nearest((0,0), 2)
+ self.assertTrue(almost_equal_records(r1, (p4, id(o4))), "%s != %s"%(r1, (p4, id(o4))))
+ self.assertTrue(almost_equal_records(r2, (p2, id(o2))), "%s != %s"%(r2, (p2, id(o2))))
+
+ def _test_k_nearest(self, k, NUM):
+ nn = KDTree_2Float()
+ self.assertEqual(0, nn.size())
+
+ actual = nn.find_nearest((2,3))
+ self.assertTrue(None==actual, "%s != %s"%(str(None), str(actual)))
+
+ add_list = []
+ ref_list = []
+ for i in range(NUM):
+ o=object()
+ add_list.append(((i,i), id(o)))
+ ref_list.append(o) # keep o alive so its id() is not reused by a later object
+
+
+ random.shuffle(add_list)
+ for r in add_list:
+ nn.add(r)
+
+ self.assertEqual(NUM, nn.size())
+ nearest_res = nn.find_k_nearest((0,0), k)
+ self.assertEqual(min(k, NUM), len(nearest_res), "expected %i results, got %i"%(min(k, NUM), len(nearest_res)))
+ res_list = [p[0] for p in nearest_res]
+ for i in range(k):
+ self.assertTrue((i,i) in res_list, "%s not in %s"%(str((i,i)), str(res_list)))
+
+ def test_7_nearest(self):
+ self._test_k_nearest(k=7, NUM=100)
class KDTree_3FloatTestCase(unittest.TestCase):
def test_empty(self):
@@ -288,7 +441,19 @@ class KDTree_3FloatTestCase(unittest.TestCase):
nearest = nn.find_nearest(k1)
self.assertTrue(nearest[1] == id(o1), "%s != %s"%(nearest[1], o1))
#self.assertTrue(nearest[1] is o1, "%s,%s is not %s"%(str(nearest[0]), str(nearest[1]), str((k1,id(o1)))))
+
+ def test_count_within_range(self):
+ nn = KDTree_3Float()
+
+ for p in [(0,0,0), (1,0,0), (0,1,0), (1,1,0)]:
+ nn.add((p, id(p)))
+
+ res = nn.count_within_range((0,0,0), 1.1)
+ self.assertEqual(3, res, "Counted %i points instead of 3"%res)
+ res = nn.count_within_range((0,0,0), 1.9) # query point must be 3-D, like the tree
+ self.assertEqual(4, res, "Counted %i points instead of 4"%res)
+
class KDTree_6FloatTestCase(unittest.TestCase):
def test_empty(self):
nn = KDTree_6Float()
@@ -310,7 +475,8 @@ class KDTree_6FloatTestCase(unittest.TestCase):
self.assertEqual(3, len(nn))
nn.remove(((10,10,0,0,0,0), id(o2)))
- self.assertEqual(2, len(nn))
+ self.assertEqual(2, len(nn))
+
self.assertEqual([((1,1,0,0,0,0), id(o1)), ((11,11,0,0,0,0), id(o3))], nn.get_all())
def test_nearest(self):
@@ -364,7 +530,6 @@ class KDTree_6FloatTestCase(unittest.TestCase):
def suite():
return unittest.defaultTestLoader.loadTestsFromModule(sys.modules.get(__name__))
-
+
if __name__ == '__main__':
- unittest.main()
-
+ unittest.main(argv=sys.argv)
--
1.5.6.3
More information about the libkdtree-devel
mailing list