[med-svn] [Git][python-team/packages/python-pynndescent][upstream] New upstream version 0.5.5

Andreas Tille (@tille) gitlab at salsa.debian.org
Sun Jan 2 19:23:21 GMT 2022



Andreas Tille pushed to branch upstream at Debian Python Team / packages / python-pynndescent


Commits:
7d135bb3 by Andreas Tille at 2022-01-02T18:59:03+01:00
New upstream version 0.5.5
- - - - -


19 changed files:

- PKG-INFO
- README.rst
- pynndescent.egg-info/PKG-INFO
- pynndescent.egg-info/SOURCES.txt
- pynndescent/__init__.py
- pynndescent/distances.py
- pynndescent/graph_utils.py
- pynndescent/optimal_transport.py
- pynndescent/pynndescent_.py
- pynndescent/rp_trees.py
- pynndescent/sparse.py
- pynndescent/sparse_nndescent.py
- + pynndescent/tests/conftest.py
- pynndescent/tests/test_distances.py
- pynndescent/tests/test_pynndescent_.py
- pynndescent/tests/test_rank.py
- pynndescent/utils.py
- requirements.txt
- setup.py


Changes:

=====================================
PKG-INFO
=====================================
@@ -1,6 +1,6 @@
 Metadata-Version: 1.2
 Name: pynndescent
-Version: 0.5.2
+Version: 0.5.5
 Summary: Nearest Neighbor Descent
 Home-page: http://github.com/lmcinnes/pynndescent
 Author: Leland McInnes
@@ -8,8 +8,13 @@ Author-email: leland.mcinnes at gmail.com
 Maintainer: Leland McInnes
 Maintainer-email: leland.mcinnes at gmail.com
 License: BSD
-Description: .. image:: https://travis-ci.org/lmcinnes/pynndescent.svg
-            :target: https://travis-ci.org/lmcinnes/pynndescent
+Description: .. image:: doc/pynndescent_logo.png
+          :width: 600
+          :align: center
+          :alt: PyNNDescent Logo
+        
+        .. image:: https://travis-ci.com/lmcinnes/pynndescent.svg
+            :target: https://travis-ci.com/lmcinnes/pynndescent
             :alt: Travis Build Status
         .. image:: https://ci.appveyor.com/api/projects/status/github/lmcinnes/pynndescent?branch=master&svg=true
             :target: https://ci.appveyor.com/project/lmcinnes/pynndescent
@@ -55,14 +60,14 @@ Description: .. image:: https://travis-ci.org/lmcinnes/pynndescent.svg
         `ann-benchmarks <https://github.com/erikbern/ann-benchmarks>`_ system puts it
         solidly in the mix of top performing ANN libraries:
         
-        **GIST-960 Euclidean**
+        **SIFT-128 Euclidean**
         
-        .. image:: https://camo.githubusercontent.com/142a48c992ba689b8ea9e62636b5281a97322f74/68747470733a2f2f7261772e6769746875622e636f6d2f6572696b6265726e2f616e6e2d62656e63686d61726b732f6d61737465722f726573756c74732f676973742d3936302d6575636c696465616e2e706e67
-            :alt: ANN benchmark performance for GIST 960 dataset
+        .. image:: https://pynndescent.readthedocs.io/en/latest/_images/sift.png
+            :alt: ANN benchmark performance for SIFT 128 dataset
         
         **NYTimes-256 Angular**
         
-        .. image:: https://camo.githubusercontent.com/6120a35a9db64104eaa1c95cb4803c2fc4cd2679/68747470733a2f2f7261772e6769746875622e636f6d2f6572696b6265726e2f616e6e2d62656e63686d61726b732f6d61737465722f726573756c74732f6e7974696d65732d3235362d616e67756c61722e706e67
+        .. image:: https://pynndescent.readthedocs.io/en/latest/_images/nytimes.png
             :alt: ANN benchmark performance for NYTimes 256 dataset
         
         While PyNNDescent is among fastest ANN library, it is also both easy to install (pip


=====================================
README.rst
=====================================
@@ -1,5 +1,10 @@
-.. image:: https://travis-ci.org/lmcinnes/pynndescent.svg
-    :target: https://travis-ci.org/lmcinnes/pynndescent
+.. image:: doc/pynndescent_logo.png
+  :width: 600
+  :align: center
+  :alt: PyNNDescent Logo
+
+.. image:: https://travis-ci.com/lmcinnes/pynndescent.svg
+    :target: https://travis-ci.com/lmcinnes/pynndescent
     :alt: Travis Build Status
 .. image:: https://ci.appveyor.com/api/projects/status/github/lmcinnes/pynndescent?branch=master&svg=true
     :target: https://ci.appveyor.com/project/lmcinnes/pynndescent
@@ -45,14 +50,14 @@ PyNNDescent provides fast approximate nearest neighbor queries. The
 `ann-benchmarks <https://github.com/erikbern/ann-benchmarks>`_ system puts it
 solidly in the mix of top performing ANN libraries:
 
-**GIST-960 Euclidean**
+**SIFT-128 Euclidean**
 
-.. image:: https://camo.githubusercontent.com/142a48c992ba689b8ea9e62636b5281a97322f74/68747470733a2f2f7261772e6769746875622e636f6d2f6572696b6265726e2f616e6e2d62656e63686d61726b732f6d61737465722f726573756c74732f676973742d3936302d6575636c696465616e2e706e67
-    :alt: ANN benchmark performance for GIST 960 dataset
+.. image:: https://pynndescent.readthedocs.io/en/latest/_images/sift.png
+    :alt: ANN benchmark performance for SIFT 128 dataset
 
 **NYTimes-256 Angular**
 
-.. image:: https://camo.githubusercontent.com/6120a35a9db64104eaa1c95cb4803c2fc4cd2679/68747470733a2f2f7261772e6769746875622e636f6d2f6572696b6265726e2f616e6e2d62656e63686d61726b732f6d61737465722f726573756c74732f6e7974696d65732d3235362d616e67756c61722e706e67
+.. image:: https://pynndescent.readthedocs.io/en/latest/_images/nytimes.png
     :alt: ANN benchmark performance for NYTimes 256 dataset
 
 While PyNNDescent is among fastest ANN library, it is also both easy to install (pip


=====================================
pynndescent.egg-info/PKG-INFO
=====================================
@@ -1,6 +1,6 @@
 Metadata-Version: 1.2
 Name: pynndescent
-Version: 0.5.2
+Version: 0.5.5
 Summary: Nearest Neighbor Descent
 Home-page: http://github.com/lmcinnes/pynndescent
 Author: Leland McInnes
@@ -8,8 +8,13 @@ Author-email: leland.mcinnes at gmail.com
 Maintainer: Leland McInnes
 Maintainer-email: leland.mcinnes at gmail.com
 License: BSD
-Description: .. image:: https://travis-ci.org/lmcinnes/pynndescent.svg
-            :target: https://travis-ci.org/lmcinnes/pynndescent
+Description: .. image:: doc/pynndescent_logo.png
+          :width: 600
+          :align: center
+          :alt: PyNNDescent Logo
+        
+        .. image:: https://travis-ci.com/lmcinnes/pynndescent.svg
+            :target: https://travis-ci.com/lmcinnes/pynndescent
             :alt: Travis Build Status
         .. image:: https://ci.appveyor.com/api/projects/status/github/lmcinnes/pynndescent?branch=master&svg=true
             :target: https://ci.appveyor.com/project/lmcinnes/pynndescent
@@ -55,14 +60,14 @@ Description: .. image:: https://travis-ci.org/lmcinnes/pynndescent.svg
         `ann-benchmarks <https://github.com/erikbern/ann-benchmarks>`_ system puts it
         solidly in the mix of top performing ANN libraries:
         
-        **GIST-960 Euclidean**
+        **SIFT-128 Euclidean**
         
-        .. image:: https://camo.githubusercontent.com/142a48c992ba689b8ea9e62636b5281a97322f74/68747470733a2f2f7261772e6769746875622e636f6d2f6572696b6265726e2f616e6e2d62656e63686d61726b732f6d61737465722f726573756c74732f676973742d3936302d6575636c696465616e2e706e67
-            :alt: ANN benchmark performance for GIST 960 dataset
+        .. image:: https://pynndescent.readthedocs.io/en/latest/_images/sift.png
+            :alt: ANN benchmark performance for SIFT 128 dataset
         
         **NYTimes-256 Angular**
         
-        .. image:: https://camo.githubusercontent.com/6120a35a9db64104eaa1c95cb4803c2fc4cd2679/68747470733a2f2f7261772e6769746875622e636f6d2f6572696b6265726e2f616e6e2d62656e63686d61726b732f6d61737465722f726573756c74732f6e7974696d65732d3235362d616e67756c61722e706e67
+        .. image:: https://pynndescent.readthedocs.io/en/latest/_images/nytimes.png
             :alt: ANN benchmark performance for NYTimes 256 dataset
         
         While PyNNDescent is among fastest ANN library, it is also both easy to install (pip


=====================================
pynndescent.egg-info/SOURCES.txt
=====================================
@@ -22,15 +22,8 @@ pynndescent.egg-info/not-zip-safe
 pynndescent.egg-info/requires.txt
 pynndescent.egg-info/top_level.txt
 pynndescent/tests/__init__.py
+pynndescent/tests/conftest.py
 pynndescent/tests/test_distances.py
 pynndescent/tests/test_pynndescent_.py
 pynndescent/tests/test_rank.py
-pynndescent/tests/__pycache__/__init__.cpython-37.pyc
-pynndescent/tests/__pycache__/__init__.cpython-38.pyc
-pynndescent/tests/__pycache__/test_distances.cpython-37.pyc
-pynndescent/tests/__pycache__/test_distances.cpython-38-pytest-6.2.2.pyc
-pynndescent/tests/__pycache__/test_pynndescent_.cpython-37.pyc
-pynndescent/tests/__pycache__/test_pynndescent_.cpython-38-pytest-6.2.2.pyc
-pynndescent/tests/__pycache__/test_rank.cpython-37.pyc
-pynndescent/tests/__pycache__/test_rank.cpython-38-pytest-6.2.2.pyc
 pynndescent/tests/test_data/cosine_hang.npy
\ No newline at end of file


=====================================
pynndescent/__init__.py
=====================================
@@ -3,6 +3,13 @@ import numba
 from .pynndescent_ import NNDescent, PyNNDescentTransformer
 
 # Workaround: https://github.com/numba/numba/issues/3341
-numba.config.THREADING_LAYER = "workqueue"
+if numba.config.THREADING_LAYER == "omp":
+    try:
+        from numba.np.ufunc import tbbpool
+
+        numba.config.THREADING_LAYER = "tbb"
+    except ImportError as e:
+        # might be a missing symbol due to e.g. tbb libraries missing
+        numba.config.THREADING_LAYER = "workqueue"
 
 __version__ = pkg_resources.get_distribution("pynndescent").version


=====================================
pynndescent/distances.py
=====================================
@@ -12,6 +12,7 @@ from pynndescent.optimal_transport import (
     network_simplex_core,
     total_cost,
     ProblemStatus,
+    sinkhorn_transport_plan,
 )
 
 _mock_identity = np.eye(2, dtype=np.float32)
@@ -22,7 +23,7 @@ FLOAT32_EPS = np.finfo(np.float32).eps
 FLOAT32_MAX = np.finfo(np.float32).max
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def euclidean(x, y):
     r"""Standard euclidean distance.
 
@@ -50,6 +51,7 @@ def euclidean(x, y):
         "dim": numba.types.intp,
         "i": numba.types.uint16,
     },
+
 )
 def squared_euclidean(x, y):
     r"""Squared euclidean distance.
@@ -66,7 +68,7 @@ def squared_euclidean(x, y):
     return result
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def standardised_euclidean(x, y, sigma=_mock_ones):
     r"""Euclidean distance standardised against a vector of standard
     deviations per coordinate.
@@ -81,7 +83,7 @@ def standardised_euclidean(x, y, sigma=_mock_ones):
     return np.sqrt(result)
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def manhattan(x, y):
     r"""Manhattan, taxicab, or l1 distance.
 
@@ -95,7 +97,7 @@ def manhattan(x, y):
     return result
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def chebyshev(x, y):
     r"""Chebyshev or l-infinity distance.
 
@@ -109,7 +111,7 @@ def chebyshev(x, y):
     return result
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def minkowski(x, y, p=2):
     r"""Minkowski distance.
 
@@ -128,7 +130,7 @@ def minkowski(x, y, p=2):
     return result ** (1.0 / p)
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def weighted_minkowski(x, y, w=_mock_ones, p=2):
     r"""A weighted version of Minkowski distance.
 
@@ -146,7 +148,7 @@ def weighted_minkowski(x, y, w=_mock_ones, p=2):
     return result ** (1.0 / p)
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def mahalanobis(x, y, vinv=_mock_identity):
     result = 0.0
 
@@ -164,7 +166,7 @@ def mahalanobis(x, y, vinv=_mock_identity):
     return np.sqrt(result)
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def hamming(x, y):
     result = 0.0
     for i in range(x.shape[0]):
@@ -174,7 +176,7 @@ def hamming(x, y):
     return float(result) / x.shape[0]
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def canberra(x, y):
     result = 0.0
     for i in range(x.shape[0]):
@@ -185,7 +187,7 @@ def canberra(x, y):
     return result
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def bray_curtis(x, y):
     numerator = 0.0
     denominator = 0.0
@@ -199,7 +201,7 @@ def bray_curtis(x, y):
         return 0.0
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def jaccard(x, y):
     num_non_zero = 0.0
     num_equal = 0.0
@@ -233,6 +235,7 @@ def jaccard(x, y):
         "dim": numba.types.intp,
         "i": numba.types.uint16,
     },
+
 )
 def alternative_jaccard(x, y):
     num_non_zero = 0.0
@@ -255,7 +258,7 @@ def correct_alternative_jaccard(v):
     return 1.0 - pow(2.0, -v)
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def matching(x, y):
     num_not_equal = 0.0
     for i in range(x.shape[0]):
@@ -266,7 +269,7 @@ def matching(x, y):
     return float(num_not_equal) / x.shape[0]
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def dice(x, y):
     num_true_true = 0.0
     num_not_equal = 0.0
@@ -282,7 +285,7 @@ def dice(x, y):
         return num_not_equal / (2.0 * num_true_true + num_not_equal)
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def kulsinski(x, y):
     num_true_true = 0.0
     num_not_equal = 0.0
@@ -300,7 +303,7 @@ def kulsinski(x, y):
         )
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def rogers_tanimoto(x, y):
     num_not_equal = 0.0
     for i in range(x.shape[0]):
@@ -311,7 +314,7 @@ def rogers_tanimoto(x, y):
     return (2.0 * num_not_equal) / (x.shape[0] + num_not_equal)
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def russellrao(x, y):
     num_true_true = 0.0
     for i in range(x.shape[0]):
@@ -325,7 +328,7 @@ def russellrao(x, y):
         return float(x.shape[0] - num_true_true) / (x.shape[0])
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def sokal_michener(x, y):
     num_not_equal = 0.0
     for i in range(x.shape[0]):
@@ -336,7 +339,7 @@ def sokal_michener(x, y):
     return (2.0 * num_not_equal) / (x.shape[0] + num_not_equal)
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def sokal_sneath(x, y):
     num_true_true = 0.0
     num_not_equal = 0.0
@@ -352,7 +355,7 @@ def sokal_sneath(x, y):
         return num_not_equal / (0.5 * num_true_true + num_not_equal)
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def haversine(x, y):
     if x.shape[0] != 2:
         raise ValueError("haversine is only defined for 2 dimensional graph_data")
@@ -362,7 +365,7 @@ def haversine(x, y):
     return 2.0 * np.arcsin(result)
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def yule(x, y):
     num_true_true = 0.0
     num_true_false = 0.0
@@ -384,7 +387,7 @@ def yule(x, y):
         )
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def cosine(x, y):
     result = 0.0
     norm_x = 0.0
@@ -418,6 +421,7 @@ def cosine(x, y):
         "dim": numba.types.intp,
         "i": numba.types.uint16,
     },
+
 )
 def alternative_cosine(x, y):
     result = 0.0
@@ -448,6 +452,7 @@ def alternative_cosine(x, y):
         "dim": numba.types.intp,
         "i": numba.types.uint16,
     },
+
 )
 def dot(x, y):
     result = 0.0
@@ -475,6 +480,7 @@ def dot(x, y):
         "dim": numba.types.intp,
         "i": numba.types.uint16,
     },
+
 )
 def alternative_dot(x, y):
     result = 0.0
@@ -546,7 +552,7 @@ def true_angular_from_alt_cosine(d):
     return 1.0 - (np.arccos(pow(2.0, -d)) / np.pi)
 
 
-@numba.njit(fastmath=True, cache=True)
+@numba.njit(fastmath=True)
 def correlation(x, y):
     mu_x = 0.0
     mu_y = 0.0
@@ -728,9 +734,7 @@ def kantorovich(x, y, cost=_dummy_cost, max_iter=100000):
     sub_cost = cost[row_mask, :][:, col_mask]
 
     node_arc_data, spanning_tree, graph = allocate_graph_structures(
-        a.shape[0],
-        b.shape[0],
-        False,
+        a.shape[0], b.shape[0], False
     )
     initialize_supply(a, -b, graph, node_arc_data.supply)
     initialize_cost(sub_cost, graph, node_arc_data.cost)
@@ -740,12 +744,7 @@ def kantorovich(x, y, cost=_dummy_cost, max_iter=100000):
         raise ValueError(
             "Kantorovich distance inputs must be valid probability distributions."
         )
-    solve_status = network_simplex_core(
-        node_arc_data,
-        spanning_tree,
-        graph,
-        max_iter,
-    )
+    solve_status = network_simplex_core(node_arc_data, spanning_tree, graph, max_iter)
     # if solve_status == ProblemStatus.MAX_ITER_REACHED:
     #     print("WARNING: RESULT MIGHT BE INACCURATE\nMax number of iteration reached!")
     if solve_status == ProblemStatus.INFEASIBLE:
@@ -761,6 +760,86 @@ def kantorovich(x, y, cost=_dummy_cost, max_iter=100000):
     return result
 
 
+@numba.njit(fastmath=True)
+def sinkhorn(x, y, cost=_dummy_cost, regularization=1.0):
+    row_mask = x != 0
+    col_mask = y != 0
+
+    a = x[row_mask].astype(np.float64)
+    b = y[col_mask].astype(np.float64)
+
+    a_sum = a.sum()
+    b_sum = b.sum()
+
+    a /= a_sum
+    b /= b_sum
+
+    sub_cost = cost[row_mask, :][:, col_mask]
+
+    transport_plan = sinkhorn_transport_plan(
+        x, y, cost=sub_cost, regularization=regularization
+    )
+    dim_i = transport_plan.shape[0]
+    dim_j = transport_plan.shape[1]
+    result = 0.0
+    for i in range(dim_i):
+        for j in range(dim_j):
+            result += transport_plan[i, j] * cost[i, j]
+
+    return result
+
+
+@numba.njit()
+def jensen_shannon_divergence(x, y):
+    result = 0.0
+    l1_norm_x = 0.0
+    l1_norm_y = 0.0
+    dim = x.shape[0]
+
+    for i in range(dim):
+        l1_norm_x += x[i]
+        l1_norm_y += y[i]
+
+    l1_norm_x += FLOAT32_EPS * dim
+    l1_norm_y += FLOAT32_EPS * dim
+
+    pdf_x = (x + FLOAT32_EPS) / l1_norm_x
+    pdf_y = (y + FLOAT32_EPS) / l1_norm_y
+    m = 0.5 * (pdf_x + pdf_y)
+
+    for i in range(dim):
+        result += 0.5 * (
+            pdf_x[i] * np.log(pdf_x[i] / m[i]) + pdf_y[i] * np.log(pdf_y[i] / m[i])
+        )
+
+    return result
+
+
+@numba.njit()
+def symmetric_kl_divergence(x, y):
+    result = 0.0
+    l1_norm_x = 0.0
+    l1_norm_y = 0.0
+    dim = x.shape[0]
+
+    for i in range(dim):
+        l1_norm_x += x[i]
+        l1_norm_y += y[i]
+
+    l1_norm_x += FLOAT32_EPS * dim
+    l1_norm_y += FLOAT32_EPS * dim
+
+    pdf_x = (x + FLOAT32_EPS) / l1_norm_x
+    pdf_y = (y + FLOAT32_EPS) / l1_norm_y
+
+    for i in range(dim):
+        result += pdf_x[i] * np.log(pdf_x[i] / pdf_y[i]) + pdf_y[i] * np.log(
+            pdf_y[i] / pdf_x[i]
+        )
+
+    return result
+
+
 named_distances = {
     # general minkowski distances
     "euclidean": euclidean,
@@ -785,14 +864,21 @@ named_distances = {
     "cosine": cosine,
     "dot": dot,
     "correlation": correlation,
-    "hellinger": hellinger,
     "haversine": haversine,
     "braycurtis": bray_curtis,
     "spearmanr": spearmanr,
-    "kantorovich": kantorovich,
-    "wasserstein": kantorovich,
     "tsss": tsss,
     "true_angular": true_angular,
+    # Distribution distances
+    "hellinger": hellinger,
+    "kantorovich": kantorovich,
+    "wasserstein": kantorovich,
+    "sinkhorn": sinkhorn,
+    "jensen-shannon": jensen_shannon_divergence,
+    "jensen_shannon": jensen_shannon_divergence,
+    "symmetric-kl": symmetric_kl_divergence,
+    "symmetric_kl": symmetric_kl_divergence,
+    "symmetric_kullback_liebler": symmetric_kl_divergence,
     # Binary distances
     "hamming": hamming,
     "jaccard": jaccard,


=====================================
pynndescent/graph_utils.py
=====================================
@@ -53,13 +53,7 @@ def create_component_search(index):
             "seed_scale": numba.types.float32,
         },
     )
-    def custom_search_closure(
-        query_points,
-        candidate_indices,
-        k,
-        epsilon,
-        visited,
-    ):
+    def custom_search_closure(query_points, candidate_indices, k, epsilon, visited):
         result = make_heap(query_points.shape[0], k)
         distance_scale = 1.0 + epsilon
 
@@ -180,8 +174,7 @@ def adjacency_matrix_representation(neighbor_indices, neighbor_distances):
     neighbor_distances[neighbor_distances == 0.0] = FLOAT32_EPS
 
     result.row = np.repeat(
-        np.arange(neighbor_indices.shape[0], dtype=np.int32),
-        neighbor_indices.shape[1],
+        np.arange(neighbor_indices.shape[0], dtype=np.int32), neighbor_indices.shape[1]
     )
     result.col = neighbor_indices.ravel()
     result.data = neighbor_distances.ravel()
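The reflowed `np.repeat` call builds the row coordinates of a COO matrix: with n points and k neighbours each, row index i is repeated k times, so `row`, `col = neighbor_indices.ravel()` and `data = neighbor_distances.ravel()` line up entry by entry. Zero distances are first replaced by `FLOAT32_EPS`, presumably so that a zero-distance neighbour is still stored as an explicit entry rather than disappearing as an implicit sparse zero. The hypothetical `knn_to_coo_triplets` below reproduces that layout with plain lists.

```python
FLOAT32_EPS = 2.0 ** -23  # same value as np.finfo(np.float32).eps


def knn_to_coo_triplets(neighbor_indices, neighbor_distances):
    """Flatten an n x k neighbour graph into COO (row, col, data) lists.

    Hypothetical helper: neighbor_indices[i][j] is the j-th neighbour of
    point i and neighbor_distances[i][j] its distance.
    """
    k = len(neighbor_indices[0])
    # Equivalent of np.repeat(np.arange(n), k): 0, 0, ..., 1, 1, ..., n-1
    row = [i for i in range(len(neighbor_indices)) for _ in range(k)]
    # Equivalent of .ravel() on a C-ordered array: row-major flattening
    col = [j for neighbours in neighbor_indices for j in neighbours]
    # Zero distances become FLOAT32_EPS so they survive as explicit entries
    data = [
        d if d != 0.0 else FLOAT32_EPS
        for dists in neighbor_distances
        for d in dists
    ]
    return row, col, data


indices = [[0, 2], [1, 0], [2, 1]]
distances = [[0.0, 1.5], [0.0, 0.5], [0.0, 2.0]]
row, col, data = knn_to_coo_triplets(indices, distances)
print(row)  # [0, 0, 1, 1, 2, 2]
print(col)  # [0, 2, 1, 0, 2, 1]
```

Unlike the upstream function, this sketch does not modify the distance input in place.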


=====================================
pynndescent/optimal_transport.py
=====================================
@@ -31,6 +31,9 @@ import numba
 from collections import namedtuple
 from enum import Enum, IntEnum
 
+_mock_identity = np.eye(2, dtype=np.float32)
+_mock_ones = np.ones(2, dtype=np.float32)
+_dummy_cost = np.zeros((2, 2), dtype=np.float64)
 
 # Accuracy tolerance and net supply tolerance
 EPSILON = 2.2204460492503131e-15
@@ -223,14 +226,9 @@ def find_join_node(source, target, succ_num, parent, in_arc):
         "second": numba.uint16,
         "result": numba.uint8,
         "in_arc": numba.uint32,
-    },
+    }
 )
-def find_leaving_arc(
-    join,
-    in_arc,
-    node_arc_data,
-    spanning_tree,
-):
+def find_leaving_arc(join, in_arc, node_arc_data, spanning_tree):
     source = node_arc_data.source
     target = node_arc_data.target
     flow = node_arc_data.flow
@@ -299,20 +297,8 @@ def find_leaving_arc(
 # Change _flow and _state vectors
 # locals: val, u
 # modifies: _state, _flow
-@numba.njit(
-    locals={
-        "u": numba.uint16,
-        "in_arc": numba.uint32,
-        "val": numba.float64,
-    },
-)
-def update_flow(
-    join,
-    leaving_arc_data,
-    node_arc_data,
-    spanning_tree,
-    in_arc,
-):
+@numba.njit(locals={"u": numba.uint16, "in_arc": numba.uint32, "val": numba.float64})
+def update_flow(join, leaving_arc_data, node_arc_data, spanning_tree, in_arc):
     source = node_arc_data.source
     target = node_arc_data.target
     flow = node_arc_data.flow
@@ -372,15 +358,9 @@ def update_flow(
         "new_stem": numba.uint16,
         "par_stem": numba.uint16,
         "in_arc": numba.uint32,
-    },
+    }
 )
-def update_spanning_tree(
-    spanning_tree,
-    leaving_arc_data,
-    join,
-    in_arc,
-    source,
-):
+def update_spanning_tree(spanning_tree, leaving_arc_data, join, in_arc, source):
 
     parent = spanning_tree.parent
     thread = spanning_tree.thread
@@ -863,12 +843,7 @@ def total_cost(flow, cost):
 
 
 @numba.njit(nogil=True)
-def network_simplex_core(
-    node_arc_data,
-    spanning_tree,
-    graph,
-    max_iter,
-):
+def network_simplex_core(node_arc_data, spanning_tree, graph, max_iter):
 
     # pivot_block = PivotBlock(
     #     max(np.int32(np.sqrt(graph.n_arcs)), 10),
@@ -949,3 +924,271 @@ def network_simplex_core(
             pi[i] -= max_pot
 
     return solution_status
+
+
+#######################################################
+# SINKHORN distances in various variations
+#######################################################
+
+
+@numba.njit(
+    fastmath=True,
+    parallel=True,
+    locals={"diff": numba.float32, "result": numba.float32},
+    cache=True,
+)
+def right_marginal_error(u, K, v, y):
+    uK = u @ K
+    result = 0.0
+    for i in numba.prange(uK.shape[0]):
+        diff = y[i] - uK[i] * v[i]
+        result += diff * diff
+    return np.sqrt(result)
+
+
+@numba.njit(
+    fastmath=True,
+    parallel=True,
+    locals={"diff": numba.float32, "result": numba.float32},
+    cache=True,
+)
+def right_marginal_error_batch(u, K, v, y):
+    uK = K.T @ u
+    result = 0.0
+    for i in numba.prange(uK.shape[0]):
+        for j in range(uK.shape[1]):
+            diff = y[j, i] - uK[i, j] * v[i, j]
+            result += diff * diff
+    return np.sqrt(result)
+
+
+@numba.njit(fastmath=True, parallel=True, cache=True)
+def transport_plan(K, u, v):
+    i_dim = K.shape[0]
+    j_dim = K.shape[1]
+    result = np.empty_like(K)
+    for i in numba.prange(i_dim):
+        for j in range(j_dim):
+            result[i, j] = u[i] * K[i, j] * v[j]
+
+    return result
+
+
+@numba.njit(fastmath=True, parallel=True, locals={"result": numba.float32}, cache=True)
+def relative_change_in_plan(old_u, old_v, new_u, new_v):
+    i_dim = old_u.shape[0]
+    j_dim = old_v.shape[0]
+    result = 0.0
+    for i in numba.prange(i_dim):
+        for j in range(j_dim):
+            old_uv = old_u[i] * old_v[j]
+            result += np.float32(np.abs(old_uv - new_u[i] * new_v[j]) / old_uv)
+
+    return result / (i_dim * j_dim)
+
+
+@numba.njit(fastmath=True, parallel=True, cache=True)
+def precompute_K_prime(K, x):
+    i_dim = K.shape[0]
+    j_dim = K.shape[1]
+    result = np.empty_like(K)
+    for i in numba.prange(i_dim):
+        if x[i] > 0.0:
+            x_i_inverse = 1.0 / x[i]
+        else:
+            x_i_inverse = INFINITY
+        for j in range(j_dim):
+            result[i, j] = x_i_inverse * K[i, j]
+
+    return result
+
+
+@numba.njit(fastmath=True, parallel=True, cache=True)
+def K_from_cost(cost, regularization):
+    i_dim = cost.shape[0]
+    j_dim = cost.shape[1]
+    result = np.empty_like(cost)
+    for i in numba.prange(i_dim):
+        for j in range(j_dim):
+            scaled_cost = cost[i, j] / regularization
+            result[i, j] = np.exp(-scaled_cost)
+
+    return result
+
+
+@numba.njit(fastmath=True, cache=True)
+def sinkhorn_iterations(
+    x, y, u, v, K, max_iter=1000, error_tolerance=1e-9, change_tolerance=1e-9
+):
+    K_prime = precompute_K_prime(K, x)
+
+    prev_u = u
+    prev_v = v
+
+    for iteration in range(max_iter):
+
+        next_v = y / (K.T @ u)
+
+        if np.any(~np.isfinite(next_v)):
+            break
+
+        next_u = 1.0 / (K_prime @ next_v)
+
+        if np.any(~np.isfinite(next_u)):
+            break
+
+        u = next_u
+        v = next_v
+
+        if iteration % 20 == 0:
+            # Check if values in plan have changed significantly since last 20 iterations
+            relative_change = relative_change_in_plan(prev_u, prev_v, next_u, next_v)
+            if relative_change <= change_tolerance:
+                break
+
+            prev_u = u
+            prev_v = v
+
+        if iteration % 10 == 0:
+            # Check if right marginal error is less than tolerance every 10 iterations
+            err = right_marginal_error(u, K, v, y)
+            if err <= error_tolerance:
+                break
+
+    return u, v
+
+
+@numba.njit(fastmath=True, cache=True)
+def sinkhorn_iterations_batch(x, y, u, v, K, max_iter=1000, error_tolerance=1e-9):
+    K_prime = precompute_K_prime(K, x)
+
+    for iteration in range(max_iter):
+
+        next_v = y.T / (K.T @ u)
+
+        if np.any(~np.isfinite(next_v)):
+            break
+
+        next_u = 1.0 / (K_prime @ next_v)
+
+        if np.any(~np.isfinite(next_u)):
+            break
+
+        u = next_u
+        v = next_v
+
+        if iteration % 10 == 0:
+            # Check if right marginal error is less than tolerance every 10 iterations
+            err = right_marginal_error_batch(u, K, v, y)
+            if err <= error_tolerance:
+                break
+
+    return u, v
+
+
+@numba.njit(fastmath=True, cache=True)
+def sinkhorn_transport_plan(
+    x,
+    y,
+    cost=_dummy_cost,
+    regularization=1.0,
+    max_iter=1000,
+    error_tolerance=1e-9,
+    change_tolerance=1e-9,
+):
+    dim_x = x.shape[0]
+    dim_y = y.shape[0]
+    u = np.full(dim_x, 1.0 / dim_x, dtype=cost.dtype)
+    v = np.full(dim_y, 1.0 / dim_y, dtype=cost.dtype)
+
+    K = K_from_cost(cost, regularization)
+    u, v = sinkhorn_iterations(
+        x,
+        y,
+        u,
+        v,
+        K,
+        max_iter=max_iter,
+        error_tolerance=error_tolerance,
+        change_tolerance=change_tolerance,
+    )
+
+    return transport_plan(K, u, v)
+
+
+@numba.njit(fastmath=True, cache=True)
+def sinkhorn_distance(x, y, cost=_dummy_cost, regularization=1.0):
+    transport_plan = sinkhorn_transport_plan(
+        x, y, cost=cost, regularization=regularization
+    )
+    dim_i = transport_plan.shape[0]
+    dim_j = transport_plan.shape[1]
+    result = 0.0
+    for i in range(dim_i):
+        for j in range(dim_j):
+            result += transport_plan[i, j] * cost[i, j]
+
+    return result
+
+
+@numba.njit(fastmath=True, parallel=True, cache=True)
+def sinkhorn_distance_batch(x, y, cost=_dummy_cost, regularization=1.0):
+    dim_x = x.shape[0]
+    dim_y = y.shape[0]
+
+    batch_size = y.shape[1]
+
+    u = np.full((dim_x, batch_size), 1.0 / dim_x, dtype=cost.dtype)
+    v = np.full((dim_y, batch_size), 1.0 / dim_y, dtype=cost.dtype)
+
+    K = K_from_cost(cost, regularization)
+    u, v = sinkhorn_iterations_batch(
+        x,
+        y,
+        u,
+        v,
+        K,
+    )
+
+    i_dim = K.shape[0]
+    j_dim = K.shape[1]
+    result = np.zeros(batch_size)
+    for i in range(i_dim):
+        for j in range(j_dim):
+            K_times_cost = K[i, j] * cost[i, j]
+            for batch in range(batch_size):
+                result[batch] += u[i, batch] * K_times_cost * v[j, batch]
+
+    return result
+
+
+def make_fixed_cost_sinkhorn_distance(cost, regularization=1.0):
+
+    K = K_from_cost(cost, regularization)
+    dim_x = K.shape[0]
+    dim_y = K.shape[1]
+
+    @numba.njit(fastmath=True)
+    def closure(x, y):
+        u = np.full(dim_x, 1.0 / dim_x, dtype=cost.dtype)
+        v = np.full(dim_y, 1.0 / dim_y, dtype=cost.dtype)
+
+        K = K_from_cost(cost, regularization)
+        u, v = sinkhorn_iterations(
+            x,
+            y,
+            u,
+            v,
+            K,
+        )
+
+        current_plan = transport_plan(K, u, v)
+
+        result = 0.0
+        for i in range(dim_x):
+            for j in range(dim_y):
+                result += current_plan[i, j] * cost[i, j]
+
+        return result
+
+    return closure
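The new Sinkhorn helpers follow the usual entropic-regularisation scheme: `K_from_cost` forms the kernel K = exp(-C / regularization), `sinkhorn_iterations` alternates v = y / (K^T u) and u = x / (K v) (it stores K' = K / x row-wise, so u = 1 / (K' v)), and `transport_plan` returns P_ij = u_i K_ij v_j, whose inner product with the cost is the reported distance. A pure-Python sketch of that loop follows. `sinkhorn_plan` and `sinkhorn_cost` are hypothetical names; the sketch assumes both inputs are already probability vectors with no zero entries (the role of the masking and normalisation in `sinkhorn` in `distances.py`), and it keeps only the right-marginal stopping test, omitting the relative-change check and the non-finite guards.

```python
import math


def sinkhorn_plan(a, b, cost, regularization=1.0, max_iter=1000, tol=1e-9):
    """Entropy-regularised transport plan by Sinkhorn scaling (sketch).

    a, b: probability vectors with positive entries.
    cost: len(a) x len(b) nested list of non-negative costs.
    """
    n, m = len(a), len(b)
    # Gibbs kernel K_ij = exp(-C_ij / regularization), as in K_from_cost
    K = [[math.exp(-cost[i][j] / regularization) for j in range(m)] for i in range(n)]
    u = [1.0 / n] * n
    v = [1.0 / m] * m
    for _ in range(max_iter):
        # v = b / (K^T u), then u = a / (K v)
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        # L2 error of the plan's column ("right") marginals against b
        err = math.sqrt(sum(
            (b[j] - v[j] * sum(u[i] * K[i][j] for i in range(n))) ** 2
            for j in range(m)
        ))
        if err <= tol:
            break
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


def sinkhorn_cost(a, b, cost, regularization=1.0):
    plan = sinkhorn_plan(a, b, cost, regularization)
    return sum(plan[i][j] * cost[i][j] for i in range(len(a)) for j in range(len(b)))


a = [0.5, 0.5]
b = [0.25, 0.75]
cost = [[0.0, 1.0], [1.0, 0.0]]
plan = sinkhorn_plan(a, b, cost, regularization=1.0)
print([round(sum(row), 6) for row in plan])  # [0.5, 0.5]
```

Smaller `regularization` values bring the plan closer to the unregularised (Kantorovich) optimum, but they make K more ill-conditioned and the iterations converge more slowly, which is why the upstream code also bails out when the scalings stop being finite.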


=====================================
pynndescent/pynndescent_.py
=====================================
@@ -9,12 +9,7 @@ import numpy as np
 from sklearn.utils import check_random_state, check_array
 from sklearn.preprocessing import normalize
 from sklearn.base import BaseEstimator, TransformerMixin
-from scipy.sparse import (
-    csr_matrix,
-    coo_matrix,
-    isspmatrix_csr,
-    vstack as sparse_vstack,
-)
+from scipy.sparse import csr_matrix, coo_matrix, isspmatrix_csr, vstack as sparse_vstack
 
 import heapq
 
@@ -30,7 +25,6 @@ from pynndescent.utils import (
     new_build_candidates,
     ts,
     simple_heap_push,
-    flagged_heap_push,
     checked_flagged_heap_push,
     has_been_visited,
     mark_visited,
@@ -64,7 +58,7 @@ FLOAT32_EPS = np.finfo(np.float32).eps
 EMPTY_GRAPH = make_heap(1, 1)
 
 
-@numba.njit(parallel=True)
+@numba.njit(parallel=True, cache=True)
 def generate_leaf_updates(leaf_block, dist_thresholds, data, dist):
 
     updates = [[(-1, -1, np.inf)] for i in range(leaf_block.shape[0])]
@@ -87,13 +81,7 @@ def generate_leaf_updates(leaf_block, dist_thresholds, data, dist):
     return updates
 
 
-@numba.njit(
-    locals={
-        "d": numba.float32,
-        "p": numba.int32,
-        "q": numba.int32,
-    }
-)
+@numba.njit(locals={"d": numba.float32, "p": numba.int32, "q": numba.int32}, cache=True)
 def init_rp_tree(data, dist, current_graph, leaf_array):
 
     n_leaves = leaf_array.shape[0]
@@ -107,12 +95,7 @@ def init_rp_tree(data, dist, current_graph, leaf_array):
         leaf_block = leaf_array[block_start:block_end]
         dist_thresholds = current_graph[1][:, 0]
 
-        updates = generate_leaf_updates(
-            leaf_block,
-            dist_thresholds,
-            data,
-            dist,
-        )
+        updates = generate_leaf_updates(leaf_block, dist_thresholds, data, dist)
 
         for j in range(len(updates)):
             for k in range(len(updates[j])):
@@ -121,8 +104,6 @@ def init_rp_tree(data, dist, current_graph, leaf_array):
                 if p == -1 or q == -1:
                     continue
 
-                # heap_push(current_graph, p, d, q, 1)
-                # heap_push(current_graph, q, d, p, 1)
                 checked_flagged_heap_push(
                     current_graph[1][p],
                     current_graph[0][p],
@@ -142,7 +123,9 @@ def init_rp_tree(data, dist, current_graph, leaf_array):
 
 
 @numba.njit(
-    fastmath=True, locals={"d": numba.float32, "idx": numba.int32, "i": numba.int32}
+    fastmath=True,
+    locals={"d": numba.float32, "idx": numba.int32, "i": numba.int32},
+    cache=True,
 )
 def init_random(n_neighbors, data, heap, dist, rng_state):
     for i in range(data.shape[0]):
@@ -150,7 +133,6 @@ def init_random(n_neighbors, data, heap, dist, rng_state):
             for j in range(n_neighbors - np.sum(heap[0][i] >= 0.0)):
                 idx = np.abs(tau_rand_int(rng_state)) % data.shape[0]
                 d = dist(data[idx], data[i])
-                # heap_push(heap, i, d, idx, 1)
                 checked_flagged_heap_push(
                     heap[1][i], heap[0][i], heap[2][i], d, idx, np.uint8(1)
                 )
@@ -158,25 +140,20 @@ def init_random(n_neighbors, data, heap, dist, rng_state):
     return
 
 
-@numba.njit()
+@numba.njit(cache=True)
 def init_from_neighbor_graph(heap, indices, distances):
     for p in range(indices.shape[0]):
         for k in range(indices.shape[1]):
             q = indices[p, k]
             d = distances[p, k]
-            # unchecked_heap_push(heap, p, d, q, 0)
-            flagged_heap_push(heap[0][p], heap[1][p], heap[2][p], q, d, 0)
+            checked_flagged_heap_push(heap[1][p], heap[0][p], heap[2][p], d, q, 0)
 
     return
 
 
-@numba.njit(parallel=True)
+@numba.njit(parallel=True, cache=True)
 def generate_graph_updates(
-    new_candidate_block,
-    old_candidate_block,
-    dist_thresholds,
-    data,
-    dist,
+    new_candidate_block, old_candidate_block, dist_thresholds, data, dist
 ):
 
     block_size = new_candidate_block.shape[0]
@@ -210,7 +187,7 @@ def generate_graph_updates(
     return updates
 
 
-@numba.njit()
+@numba.njit(cache=True)
 def process_candidates(
     data,
     dist,
@@ -219,6 +196,7 @@ def process_candidates(
     old_candidate_neighbors,
     n_blocks,
     block_size,
+    n_threads,
 ):
     c = 0
     n_vertices = new_candidate_neighbors.shape[0]
@@ -232,14 +210,10 @@ def process_candidates(
         dist_thresholds = current_graph[1][:, 0]
 
         updates = generate_graph_updates(
-            new_candidate_block,
-            old_candidate_block,
-            dist_thresholds,
-            data,
-            dist,
+            new_candidate_block, old_candidate_block, dist_thresholds, data, dist
         )
 
-        c += apply_graph_updates_low_memory(current_graph, updates)
+        c += apply_graph_updates_low_memory(current_graph, updates, n_threads)
 
     return c
 
@@ -259,15 +233,14 @@ def nn_descent_internal_low_memory_parallel(
     n_vertices = data.shape[0]
     block_size = 16384
     n_blocks = n_vertices // block_size
+    n_threads = numba.get_num_threads()
 
     for n in range(n_iters):
         if verbose:
             print("\t", n + 1, " / ", n_iters)
 
         (new_candidate_neighbors, old_candidate_neighbors) = new_build_candidates(
-            current_graph,
-            max_candidates,
-            rng_state,
+            current_graph, max_candidates, rng_state, n_threads
         )
 
         c = process_candidates(
@@ -278,6 +251,7 @@ def nn_descent_internal_low_memory_parallel(
             old_candidate_neighbors,
             n_blocks,
             block_size,
+            n_threads,
         )
 
         if c <= delta * n_neighbors * data.shape[0]:
@@ -301,6 +275,7 @@ def nn_descent_internal_high_memory_parallel(
     n_vertices = data.shape[0]
     block_size = 16384
     n_blocks = n_vertices // block_size
+    n_threads = numba.get_num_threads()
 
     in_graph = [
         set(current_graph[0][i].astype(np.int64))
@@ -312,9 +287,7 @@ def nn_descent_internal_high_memory_parallel(
             print("\t", n + 1, " / ", n_iters)
 
         (new_candidate_neighbors, old_candidate_neighbors) = new_build_candidates(
-            current_graph,
-            max_candidates,
-            rng_state,
+            current_graph, max_candidates, rng_state, n_threads
         )
 
         c = 0
@@ -327,11 +300,7 @@ def nn_descent_internal_high_memory_parallel(
             dist_thresholds = current_graph[1][:, 0]
 
             updates = generate_graph_updates(
-                new_candidate_block,
-                old_candidate_block,
-                dist_thresholds,
-                data,
-                dist,
+                new_candidate_block, old_candidate_block, dist_thresholds, data, dist
             )
 
             c += apply_graph_updates_high_memory(current_graph, updates, in_graph)
@@ -465,8 +434,7 @@ def diversify_csr(
                 if retained[l] == 1:
 
                     d = dist(
-                        source_data[current_indices[j]],
-                        source_data[current_indices[k]],
+                        source_data[current_indices[j]], source_data[current_indices[k]]
                     )
                     if current_data[l] > FLOAT32_EPS and d < current_data[j]:
                         if tau_rand(rng_state) < prune_probability:
@@ -605,6 +573,12 @@ class NNDescent(object):
         that no edges get removed, and larger values result in significantly more
         aggressive edge removal. A value of 1.0 will prune all edges that it can.
 
+    n_search_trees: int (optional, default=1)
+        The number of random projection trees to use in initializing searching or
+        querying.
+
+        .. deprecated:: 0.5.5
+
     tree_init: bool (optional, default=True)
         Whether to use random projection trees for initialization.
 
@@ -840,13 +814,7 @@ class NNDescent(object):
 
                 @numba.njit()
                 def _partial_dist_func(ind1, data1, ind2, data2):
-                    return _distance_func(
-                        ind1,
-                        data1,
-                        ind2,
-                        data2,
-                        *dist_args,
-                    )
+                    return _distance_func(ind1, data1, ind2, data2, *dist_args)
 
                 self._distance_func = _partial_dist_func
             else:
@@ -932,7 +900,10 @@ class NNDescent(object):
         if not hasattr(self, "_search_graph"):
             self._init_search_graph()
         if not hasattr(self, "_search_function"):
-            self._init_search_function()
+            if self._is_sparse:
+                self._init_sparse_search_function()
+            else:
+                self._init_search_function()
         result = self.__dict__.copy()
         if hasattr(self, "_rp_forest"):
             del result["_rp_forest"]
@@ -946,7 +917,10 @@ class NNDescent(object):
         self._search_forest = tuple(
             [renumbaify_tree(tree) for tree in d["_search_forest"]]
         )
-        self._init_search_function()
+        if self._is_sparse:
+            self._init_sparse_search_function()
+        else:
+            self._init_search_function()
 
     def _init_search_graph(self):
 
@@ -1118,8 +1092,7 @@ class NNDescent(object):
             print(
                 ts(),
                 "Degree pruning reduced edges from {} to {}".format(
-                    pre_prune_nnz,
-                    self._search_graph.nnz,
+                    pre_prune_nnz, self._search_graph.nnz
                 ),
             )
 
@@ -1171,7 +1144,7 @@ class NNDescent(object):
                 numba.types.Array(numba.types.int32, 1, "C", readonly=True)(
                     numba.types.Array(numba.types.float32, 1, "C", readonly=True),
                     numba.types.Array(numba.types.int64, 1, "C", readonly=False),
-                ),
+                )
             ],
             locals={"node": numba.types.uint32, "side": numba.types.boolean},
         )
@@ -1221,13 +1194,7 @@ class NNDescent(object):
                 "seed_scale": numba.types.float32,
             },
         )
-        def search_closure(
-            query_points,
-            k,
-            epsilon,
-            visited,
-            rng_state,
-        ):
+        def search_closure(query_points, k, epsilon, visited, rng_state):
 
             result = make_heap(query_points.shape[0], k)
             distance_scale = 1.0 + epsilon
@@ -1314,11 +1281,7 @@ class NNDescent(object):
         # Force compilation of the search function (hardcoded k, epsilon)
         query_data = self._raw_data[:1]
         _ = self._search_function(
-            query_data,
-            5,
-            0.0,
-            self._visited,
-            self.search_rng_state,
+            query_data, 5, 0.0, self._visited, self.search_rng_state
         )
 
     def _init_sparse_search_function(self):
@@ -1337,7 +1300,7 @@ class NNDescent(object):
                     numba.types.Array(numba.types.int32, 1, "C", readonly=True),
                     numba.types.Array(numba.types.float32, 1, "C", readonly=True),
                     numba.types.Array(numba.types.int64, 1, "C", readonly=False),
-                ),
+                )
             ],
             locals={"node": numba.types.uint32, "side": numba.types.boolean},
         )
@@ -1389,13 +1352,7 @@ class NNDescent(object):
             },
         )
         def search_closure(
-            query_inds,
-            query_indptr,
-            query_data,
-            k,
-            epsilon,
-            visited,
-            rng_state,
+            query_inds, query_indptr, query_data, k, epsilon, visited, rng_state
         ):
 
             n_query_points = query_indptr.shape[0] - 1
@@ -1424,9 +1381,7 @@ class NNDescent(object):
 
                 ############ Init ################
                 index_bounds = sparse_tree_search_closure(
-                    current_query_inds,
-                    current_query_data,
-                    internal_rng_state,
+                    current_query_inds, current_query_data, internal_rng_state
                 )
                 candidate_indices = tree_indices[index_bounds[0] : index_bounds[1]]
 
@@ -1616,11 +1571,7 @@ class NNDescent(object):
 
             query_data = np.asarray(query_data).astype(np.float32, order="C")
             result = self._search_function(
-                query_data,
-                k,
-                epsilon,
-                self._visited,
-                self.search_rng_state,
+                query_data, k, epsilon, self._visited, self.search_rng_state
             )
         else:
             # Sparse case
@@ -1787,10 +1738,12 @@ class PyNNDescentTransformer(BaseEstimator, TransformerMixin):
         that no edges get removed, and larger values result in significantly more
         aggressive edge removal. A value of 1.0 will prune all edges that it can.
 
-    n_search_trees: float (optional, default=1)
+    n_search_trees: int (optional, default=1)
         The number of random projection trees to use in initializing searching or
         querying.
 
+        .. deprecated:: 0.5.5
+
     search_epsilon: float (optional, default=0.1)
         When searching for nearest neighbors of a query point this values
         controls the trade-off between accuracy and search cost. Larger values
@@ -1975,8 +1928,7 @@ class PyNNDescentTransformer(BaseEstimator, TransformerMixin):
             print(ts(), "Constructing neighbor matrix")
         result = coo_matrix((n_samples_transform, self.n_samples_fit), dtype=np.float32)
         result.row = np.repeat(
-            np.arange(indices.shape[0], dtype=np.int32),
-            indices.shape[1],
+            np.arange(indices.shape[0], dtype=np.int32), indices.shape[1]
         )
         result.col = indices.ravel()
         result.data = distances.ravel()
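The transform hunk above builds the neighbor matrix in COO form. Each query row index is repeated once per neighbor, and the index and distance arrays are flattened row by row. Below is a NumPy-only sketch of the same layout. It assumes `indices` and `distances` are `(n_queries, k)` arrays from a k-NN query, filled here with toy values. The dense array at the end is only there to inspect the result.

```python
import numpy as np

# Toy k-NN result: 3 queries, k = 2 neighbors each, 3 fitted samples
indices = np.array([[0, 2], [1, 0], [2, 1]], dtype=np.int32)
distances = np.array([[0.5, 1.5], [0.7, 2.0], [0.3, 1.0]], dtype=np.float32)
n_samples_fit = 3

# One row entry per (query, neighbor) pair
row = np.repeat(np.arange(indices.shape[0], dtype=np.int32), indices.shape[1])
col = indices.ravel()
data = distances.ravel()
print(row.tolist())  # [0, 0, 1, 1, 2, 2]

# Dense view of the COO triplets, for inspection only
dense = np.zeros((indices.shape[0], n_samples_fit), dtype=np.float32)
dense[row, col] = data
```

If SciPy is available, the same triplets can be passed directly as `scipy.sparse.coo_matrix((data, (row, col)), shape=(n_queries, n_samples_fit))`.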

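`init_from_neighbor_graph` now calls `checked_flagged_heap_push` with the same argument order as the other call sites in this diff:

- the distance row (`heap[1]`)
- the index row (`heap[0]`)
- the flag row (`heap[2]`)
- the new distance, index and flag

The removed `flagged_heap_push` call took the index row first and passed the new index before the distance. The sketch below is a hypothetical pure-Python reimplementation of a bounded max-heap push that rejects duplicates. It is written for illustration and is not copied from `pynndescent.utils`.

```python
import numpy as np

def checked_flagged_push(priorities, indices, flags, p, n, f):
    """Keep the k smallest priorities in a fixed-size max-heap.

    priorities[0] holds the largest retained distance. Returns 1 if the
    candidate was inserted and 0 otherwise (hypothetical sketch).
    """
    size = priorities.shape[0]
    if p >= priorities[0]:
        return 0  # not closer than the current worst neighbor
    if n in indices:
        return 0  # already a neighbor: the "checked" part
    # Replace the root, then sift down to restore the max-heap property
    i = 0
    while True:
        left = 2 * i + 1
        right = left + 1
        if left >= size:
            break
        child = left
        if right < size and priorities[right] > priorities[left]:
            child = right
        if priorities[child] <= p:
            break
        priorities[i] = priorities[child]
        indices[i] = indices[child]
        flags[i] = flags[child]
        i = child
    priorities[i] = p
    indices[i] = n
    flags[i] = f
    return 1

k = 3
dists = np.full(k, np.inf, dtype=np.float32)
inds = np.full(k, -1, dtype=np.int32)
flags = np.zeros(k, dtype=np.uint8)
for d, q in [(0.5, 7), (0.2, 3), (0.9, 1), (0.5, 7), (0.1, 4)]:
    checked_flagged_push(dists, inds, flags, d, q, 1)
print(inds[np.argsort(dists)].tolist())  # [4, 3, 7]
```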

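The `__getstate__`/`__setstate__` hunks change how an index built on sparse data is unpickled: it now rebuilds the sparse search function instead of always building the dense one. Below is a hypothetical minimal class that shows this dispatch pattern. Its search functions are lambdas, which the standard `pickle` module cannot serialize. For that reason this sketch drops them in `__getstate__` and rebuilds them on load.

```python
import pickle

class SearchIndex:
    """Hypothetical stand-in for the NNDescent pickling logic."""

    def __init__(self, is_sparse):
        self._is_sparse = is_sparse
        self._init_search()

    def _init_search(self):
        # Dispatch on the data format, as the patched __setstate__ does
        if self._is_sparse:
            self._init_sparse_search_function()
        else:
            self._init_search_function()

    def _init_search_function(self):
        self._search_function = lambda: "dense search"

    def _init_sparse_search_function(self):
        self._search_function = lambda: "sparse search"

    def __getstate__(self):
        state = self.__dict__.copy()
        # Lambdas cannot be pickled, so drop the function and rebuild it on load
        del state["_search_function"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._init_search()

restored = pickle.loads(pickle.dumps(SearchIndex(is_sparse=True)))
print(restored._search_function())  # sparse search
```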
=====================================
pynndescent/rp_trees.py
=====================================
@@ -6,8 +6,6 @@ from warnings import warn
 import locale
 import numpy as np
 import numba
-from numba.core import types
-from numba.experimental import structref
 import scipy.sparse
 
 from pynndescent.sparse import sparse_mul, sparse_diff, sparse_sum, arr_intersect
@@ -548,13 +546,11 @@ def make_euclidean_tree(
         offsets.append(offset)
         children.append((np.int32(left_node_num), np.int32(right_node_num)))
         point_indices.append(np.array([-1], dtype=np.int32))
-        # print("Made a node in tree with", len(point_indices), "nodes")
     else:
         hyperplanes.append(np.array([-1.0], dtype=np.float32))
         offsets.append(-np.inf)
         children.append((np.int32(-1), np.int32(-1)))
         point_indices.append(indices)
-        # print("Made a leaf in tree with", len(point_indices), "nodes")
 
     return
 
@@ -795,9 +791,7 @@ def make_dense_tree(data, rng_state, leaf_size=30, angular=False):
             leaf_size,
         )
 
-    # print("Completed a tree")
     result = FlatTree(hyperplanes, offsets, children, point_indices, leaf_size)
-    # print("Tree type is:", numba.typeof(result))
     return result
 
 
@@ -856,6 +850,7 @@ def make_sparse_tree(inds, indptr, spdata, rng_state, leaf_size=30, angular=Fals
         "dim": numba.types.intp,
         "d": numba.types.uint16,
     },
+    cache=True,
 )
 def select_side(hyperplane, offset, point, rng_state):
     margin = offset
@@ -888,6 +883,7 @@ def select_side(hyperplane, offset, point, rng_state):
         ),
     ],
     locals={"node": numba.types.uint32, "side": numba.types.boolean},
+    cache=True,
 )
 def search_flat_tree(point, hyperplanes, offsets, children, indices, rng_state):
     node = 0
@@ -901,7 +897,7 @@ def search_flat_tree(point, hyperplanes, offsets, children, indices, rng_state):
     return indices[-children[node, 0] : -children[node, 1]]
 
 
-@numba.njit(fastmath=True)
+@numba.njit(fastmath=True, cache=True)
 def sparse_select_side(hyperplane, offset, point_inds, point_data, rng_state):
     margin = offset
 
@@ -929,7 +925,7 @@ def sparse_select_side(hyperplane, offset, point_inds, point_data, rng_state):
         return 1
 
 
-@numba.njit(locals={"node": numba.types.uint32})
+@numba.njit(locals={"node": numba.types.uint32}, cache=True)
 def search_sparse_flat_tree(
     point_inds, point_data, hyperplanes, offsets, children, indices, rng_state
 ):
@@ -975,7 +971,7 @@ def make_forest(
     # print(ts(), "Started forest construction")
     result = []
     if leaf_size is None:
-        leaf_size = max(10, n_neighbors)
+        leaf_size = max(10, np.int32(n_neighbors))
     if n_jobs is None:
         n_jobs = -1
 
@@ -1010,7 +1006,7 @@ def make_forest(
     return tuple(result)
 
 
-@numba.njit(nogil=True)
+@numba.njit(nogil=True, cache=True)
 def get_leaves_from_tree(tree):
     n_leaves = 0
     for i in range(len(tree.children)):
@@ -1124,7 +1120,7 @@ def recursive_convert_sparse(
         return node_num, leaf_start
 
 
-@numba.njit()
+@numba.njit(cache=True)
 def num_nodes_and_leaves(tree):
     n_nodes = 0
     n_leaves = 0
@@ -1138,7 +1134,7 @@ def num_nodes_and_leaves(tree):
     return n_nodes, n_leaves
 
 
-@numba.njit()
+@numba.njit(cache=True)
 def dense_hyperplane_dim(hyperplanes):
     for i in range(len(hyperplanes)):
         if hyperplanes[i].shape[0] > 1:
@@ -1147,7 +1143,7 @@ def dense_hyperplane_dim(hyperplanes):
     raise ValueError("No hyperplanes of adequate size were found!")
 
 
-@numba.njit()
+@numba.njit(cache=True)
 def sparse_hyperplane_dim(hyperplanes):
     max_dim = 0
     for i in range(len(hyperplanes)):
@@ -1226,6 +1222,7 @@ def renumbaify_tree(tree):
         "result": numba.float32,
         "i": numba.uint32,
     },
+    cache=True,
 )
 def score_tree(tree, neighbor_indices, data, rng_state):
     result = 0.0
@@ -1243,7 +1240,7 @@ def score_tree(tree, neighbor_indices, data, rng_state):
     return result / numba.float32(neighbor_indices.shape[0])
 
 
-@numba.njit(nogil=True, parallel=True, locals={"node": numba.int32})
+@numba.njit(nogil=True, parallel=True, locals={"node": numba.int32}, cache=True)
 def score_linked_tree(tree, neighbor_indices):
     result = 0.0
     n_nodes = len(tree.children)
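At a leaf, the `search_flat_tree` hunk returns `indices[-children[node, 0] : -children[node, 1]]`. Leaves therefore encode a slice of the flat index array as a pair of non-positive numbers, while internal nodes hold positive child node numbers. Below is a pure-NumPy sketch of the descent. It assumes a point goes left when its signed margin `offset + hyperplane · point` is positive, and that a near-tie picks a random side. NumPy's `default_rng` stands in for pynndescent's `tau_rand_int` state, and `EPS` is an assumed tie threshold.

```python
import numpy as np

EPS = 1e-8  # assumed tie threshold

def select_side(hyperplane, offset, point, rng):
    # Signed margin of the point relative to the splitting hyperplane
    margin = offset + float(np.dot(hyperplane, point))
    if abs(margin) < EPS:
        return int(rng.integers(2))  # near-tie: pick a side at random
    return 0 if margin > 0 else 1

def search_flat_tree(point, hyperplanes, offsets, children, indices, rng):
    node = 0
    # Internal nodes store positive child node numbers
    while children[node, 0] > 0:
        side = select_side(hyperplanes[node], offsets[node], point, rng)
        node = children[node, side]
    # Leaves store (-start, -end) into the flat indices array
    return indices[-children[node, 0] : -children[node, 1]]

# One split on a 1-D line: x > 0.5 goes to node 1, otherwise node 2
hyperplanes = np.array([[1.0], [0.0], [0.0]])
offsets = np.array([-0.5, 0.0, 0.0])
children = np.array([[1, 2], [0, -2], [-2, -5]])
indices = np.array([8, 9, 1, 2, 3])
rng = np.random.default_rng(0)
print(search_flat_tree(np.array([0.9]), hyperplanes, offsets, children, indices, rng))  # [8 9]
print(search_flat_tree(np.array([0.1]), hyperplanes, offsets, children, indices, rng))  # [1 2 3]
```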


=====================================
pynndescent/sparse.py
=====================================
@@ -8,7 +8,11 @@ import numpy as np
 import numba
 
 from pynndescent.utils import norm, tau_rand
-from pynndescent.distances import kantorovich
+from pynndescent.distances import (
+    kantorovich,
+    jensen_shannon_divergence,
+    symmetric_kl_divergence,
+)
 
 locale.setlocale(locale.LC_NUMERIC, "C")
 
@@ -16,14 +20,14 @@ FLOAT32_EPS = np.finfo(np.float32).eps
 FLOAT32_MAX = np.finfo(np.float32).max
 
 # Just reproduce a simpler version of numpy isclose (not numba supported yet)
-@numba.njit()
+@numba.njit(cache=True)
 def isclose(a, b, rtol=1.0e-5, atol=1.0e-8):
     diff = np.abs(a - b)
     return diff <= (atol + rtol * np.abs(b))
 
 
 # Just reproduce a simpler version of numpy unique (not numba supported yet)
-@numba.njit()
+@numba.njit(cache=True)
 def arr_unique(arr):
     aux = np.sort(arr)
     flag = np.concatenate((np.ones(1, dtype=np.bool_), aux[1:] != aux[:-1]))
@@ -31,7 +35,7 @@ def arr_unique(arr):
 
 
 # Just reproduce a simpler version of numpy union1d (not numba supported yet)
-@numba.njit()
+@numba.njit(cache=True)
 def arr_union(ar1, ar2):
     if ar1.shape[0] == 0:
         return ar2
@@ -43,7 +47,7 @@ def arr_union(ar1, ar2):
 
 # Just reproduce a simpler version of numpy intersect1d (not numba supported
 # yet)
-@numba.njit()
+@numba.njit(cache=True)
 def arr_intersect(ar1, ar2):
     aux = np.concatenate((ar1, ar2))
     aux.sort()
@@ -62,7 +66,7 @@ def arr_intersect(ar1, ar2):
             numba.types.Array(numba.types.float32, 1, "C", readonly=True),
             numba.types.Array(numba.types.int32, 1, "C", readonly=True),
             numba.types.Array(numba.types.float32, 1, "C", readonly=True),
-        ),
+        )
     ],
     fastmath=True,
     locals={
@@ -74,6 +78,7 @@ def arr_intersect(ar1, ar2):
         "j1": numba.types.int32,
         "j2": numba.types.int32,
     },
+    cache=True,
 )
 def sparse_sum(ind1, data1, ind2, data2):
     result_size = ind1.shape[0] + ind2.shape[0]
@@ -138,7 +143,7 @@ def sparse_sum(ind1, data1, ind2, data2):
     return result_ind, result_data
 
 
-@numba.njit()
+@numba.njit(cache=True)
 def sparse_diff(ind1, data1, ind2, data2):
     return sparse_sum(ind1, data1, ind2, -data2)
 
@@ -156,7 +161,7 @@ def sparse_diff(ind1, data1, ind2, data2):
             numba.types.Array(numba.types.float32, 1, "C", readonly=True),
             numba.types.Array(numba.types.int32, 1, "C", readonly=True),
             numba.types.Array(numba.types.float32, 1, "C", readonly=True),
-        ),
+        )
     ],
     fastmath=True,
     locals={
@@ -166,6 +171,7 @@ def sparse_diff(ind1, data1, ind2, data2):
         "j1": numba.types.int32,
         "j2": numba.types.int32,
     },
+    cache=True,
 )
 def sparse_mul(ind1, data1, ind2, data2):
     result_ind = numba.typed.List.empty_list(numba.types.int32)
@@ -194,7 +200,66 @@ def sparse_mul(ind1, data1, ind2, data2):
     return result_ind, result_data
 
 
+# Return dense vectors supported on the union of the non-zero valued indices
 @numba.njit()
+def dense_union(ind1, data1, ind2, data2):
+    result_ind = arr_union(ind1, ind2)
+    result_data1 = np.zeros(result_ind.shape[0], dtype=np.float32)
+    result_data2 = np.zeros(result_ind.shape[0], dtype=np.float32)
+
+    i1 = 0
+    i2 = 0
+    nnz = 0
+
+    # pass through both index lists
+    while i1 < ind1.shape[0] and i2 < ind2.shape[0]:
+        j1 = ind1[i1]
+        j2 = ind2[i2]
+
+        if j1 == j2:
+            val = data1[i1] + data2[i2]
+            if val != 0:
+                result_data1[nnz] = data1[i1]
+                result_data2[nnz] = data2[i2]
+                nnz += 1
+            i1 += 1
+            i2 += 1
+        elif j1 < j2:
+            val = data1[i1]
+            if val != 0:
+                result_data1[nnz] = data1[i1]
+                nnz += 1
+            i1 += 1
+        else:
+            val = data2[i2]
+            if val != 0:
+                result_data2[nnz] = data2[i2]
+                nnz += 1
+            i2 += 1
+
+    # pass over the tails
+    while i1 < ind1.shape[0]:
+        val = data1[i1]
+        if val != 0:
+            result_data1[nnz] = data1[i1]
+            nnz += 1
+        i1 += 1
+
+    while i2 < ind2.shape[0]:
+        val = data2[i2]
+        if val != 0:
+            result_data2[nnz] = data2[i2]
+            nnz += 1
+        i2 += 1
+
+    # truncate to the correct length in case there were zeros
+    result_data1 = result_data1[:nnz]
+    result_data2 = result_data2[:nnz]
+
+    return result_data1, result_data2
+
+
+@numba.njit(fastmath=True)
 def sparse_euclidean(ind1, data1, ind2, data2):
     _, aux_data = sparse_diff(ind1, data1, ind2, data2)
     result = 0.0
@@ -335,10 +400,7 @@ def sparse_jaccard(ind1, data1, ind2, data2):
         ),
     ],
     fastmath=True,
-    locals={
-        "num_non_zero": numba.types.intp,
-        "num_equal": numba.types.intp,
-    },
+    locals={"num_non_zero": numba.types.intp, "num_equal": numba.types.intp},
 )
 def sparse_alternative_jaccard(ind1, data1, ind2, data2):
     num_non_zero = arr_union(ind1, ind2).shape[0]
@@ -696,7 +758,24 @@ def sparse_kantorovich(ind1, data1, ind2, data2, ground_metric=dummy_ground_metr
     return kantorovich(data1, data2, cost_matrix)
 
 
-@numba.njit(parallel=True)
+# Because of the EPS values and the need to normalize after adding them (and then average those for jensen_shannon)
+# it seems like we might as well just take the dense union (dense vectors supported on the union of indices)
+# and call the dense distance functions
+
+
+@numba.njit()
+def sparse_jensen_shannon_divergence(ind1, data1, ind2, data2):
+    dense_data1, dense_data2 = dense_union(ind1, data1, ind2, data2)
+    return jensen_shannon_divergence(dense_data1, dense_data2)
+
+
+@numba.njit()
+def sparse_symmetric_kl_divergence(ind1, data1, ind2, data2):
+    dense_data1, dense_data2 = dense_union(ind1, data1, ind2, data2)
+    return symmetric_kl_divergence(dense_data1, dense_data2)
+
+
+@numba.njit(parallel=True, cache=True)
 def diversify(
     indices,
     distances,
@@ -751,7 +830,7 @@ def diversify(
     return indices, distances
 
 
-@numba.njit(parallel=True)
+@numba.njit(parallel=True, cache=True)
 def diversify_csr(
     graph_indptr,
     graph_indices,
@@ -821,8 +900,6 @@ sparse_named_distances = {
     "minkowski": sparse_minkowski,
     # Other distances
     "canberra": sparse_canberra,
-    "kantorovich": sparse_kantorovich,
-    "wasserstein": sparse_kantorovich,
     "braycurtis": sparse_bray_curtis,
     # Binary distances
     "hamming": sparse_hamming,
@@ -837,7 +914,15 @@ sparse_named_distances = {
     # Angular distances
     "cosine": sparse_cosine,
     "correlation": sparse_correlation,
+    # Distribution distances
+    "kantorovich": sparse_kantorovich,
+    "wasserstein": sparse_kantorovich,
     "hellinger": sparse_hellinger,
+    "jensen-shannon": sparse_jensen_shannon_divergence,
+    "jensen_shannon": sparse_jensen_shannon_divergence,
+    "symmetric-kl": sparse_symmetric_kl_divergence,
+    "symmetric_kl": sparse_symmetric_kl_divergence,
+    "symmetric_kullback_liebler": sparse_symmetric_kl_divergence,
 }
 
 sparse_need_n_features = (

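The comment above `sparse_jensen_shannon_divergence` describes the approach. Both sparse vectors are expanded into dense arrays over the union of their non-zero indices, and the dense divergence is then reused. Below is a sketch of both steps. It assumes sorted index arrays and non-negative weights. The smoothing constant `eps` and the textbook Jensen-Shannon formula are assumptions and may differ in detail from `pynndescent.distances.jensen_shannon_divergence`.

```python
import numpy as np

def dense_union(ind1, data1, ind2, data2):
    """Dense copies of two sparse vectors over the union of their indices."""
    union = np.union1d(ind1, ind2)  # sorted, unique
    dense1 = np.zeros(union.shape[0], dtype=np.float32)
    dense2 = np.zeros(union.shape[0], dtype=np.float32)
    # Every original index occurs in `union`, so searchsorted finds its slot
    dense1[np.searchsorted(union, ind1)] = data1
    dense2[np.searchsorted(union, ind2)] = data2
    # Drop slots where both vectors are zero (explicitly stored zeros)
    keep = (dense1 != 0) | (dense2 != 0)
    return dense1[keep], dense2[keep]

def jensen_shannon(p, q, eps=1e-8):
    # Smooth so that log() never sees an exact zero, then renormalize
    p = np.asarray(p, dtype=np.float64) + eps
    q = np.asarray(q, dtype=np.float64) + eps
    p /= p.sum()
    q /= q.sum()
    m = 0.5 * (p + q)
    return 0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m))

ind1 = np.array([0, 3, 7], dtype=np.int32)
data1 = np.array([0.2, 0.5, 0.3], dtype=np.float32)
ind2 = np.array([3, 5], dtype=np.int32)
data2 = np.array([0.6, 0.4], dtype=np.float32)
d1, d2 = dense_union(ind1, data1, ind2, data2)
print(d1, d2)
print(jensen_shannon(d1, d2))
```

The two versions differ in one case: a shared index whose two values cancel, such as `1.0` and `-1.0`. The loop in `dense_union` drops that entry, whereas this sketch keeps it. For non-negative distributions the two agree.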

=====================================
pynndescent/sparse_nndescent.py
=====================================
@@ -24,15 +24,8 @@ locale.setlocale(locale.LC_NUMERIC, "C")
 EMPTY_GRAPH = make_heap(1, 1)
 
 
-@numba.njit(parallel=True)
-def generate_leaf_updates(
-    leaf_block,
-    dist_thresholds,
-    inds,
-    indptr,
-    data,
-    dist,
-):
+@numba.njit(parallel=True, cache=True)
+def generate_leaf_updates(leaf_block, dist_thresholds, inds, indptr, data, dist):
 
     updates = [[(-1, -1, np.inf)] for i in range(leaf_block.shape[0])]
 
@@ -60,13 +53,7 @@ def generate_leaf_updates(
     return updates
 
 
-@numba.njit(
-    locals={
-        "d": numba.float32,
-        "p": numba.int32,
-        "q": numba.int32,
-    }
-)
+@numba.njit(locals={"d": numba.float32, "p": numba.int32, "q": numba.int32}, cache=True)
 def init_rp_tree(inds, indptr, data, dist, current_graph, leaf_array):
 
     n_leaves = leaf_array.shape[0]
@@ -81,12 +68,7 @@ def init_rp_tree(inds, indptr, data, dist, current_graph, leaf_array):
         dist_thresholds = current_graph[1][:, 0]
 
         updates = generate_leaf_updates(
-            leaf_block,
-            dist_thresholds,
-            inds,
-            indptr,
-            data,
-            dist,
+            leaf_block, dist_thresholds, inds, indptr, data, dist
         )
 
         for j in range(len(updates)):
@@ -96,8 +78,6 @@ def init_rp_tree(inds, indptr, data, dist, current_graph, leaf_array):
                 if p == -1 or q == -1:
                     continue
 
-                # heap_push(current_graph, p, d, q, 1)
-                # heap_push(current_graph, q, d, p, 1)
                 checked_flagged_heap_push(
                     current_graph[1][p],
                     current_graph[0][p],
@@ -118,11 +98,8 @@ def init_rp_tree(inds, indptr, data, dist, current_graph, leaf_array):
 
 @numba.njit(
     fastmath=True,
-    locals={
-        "d": numba.float32,
-        "i": numba.int32,
-        "idx": numba.int32,
-    },
+    locals={"d": numba.float32, "i": numba.int32, "idx": numba.int32},
+    cache=True,
 )
 def init_random(n_neighbors, inds, indptr, data, heap, dist, rng_state):
     n_samples = indptr.shape[0] - 1
@@ -138,7 +115,6 @@ def init_random(n_neighbors, inds, indptr, data, heap, dist, rng_state):
                 to_data = data[indptr[i] : indptr[i + 1]]
                 d = dist(from_inds, from_data, to_inds, to_data)
 
-                # heap_push(heap, i, d, idx, 1)
                 checked_flagged_heap_push(
                     heap[1][i], heap[0][i], heap[2][i], d, idx, np.uint8(1)
                 )
@@ -146,15 +122,9 @@ def init_random(n_neighbors, inds, indptr, data, heap, dist, rng_state):
     return
 
 
-@numba.njit(parallel=True)
+@numba.njit(parallel=True, cache=True)
 def generate_graph_updates(
-    new_candidate_block,
-    old_candidate_block,
-    dist_thresholds,
-    inds,
-    indptr,
-    data,
-    dist,
+    new_candidate_block, old_candidate_block, dist_thresholds, inds, indptr, data, dist
 ):
 
     block_size = new_candidate_block.shape[0]
@@ -216,15 +186,14 @@ def nn_descent_internal_low_memory_parallel(
     n_vertices = indptr.shape[0] - 1
     block_size = 16384
     n_blocks = n_vertices // block_size
+    n_threads = numba.get_num_threads()
 
     for n in range(n_iters):
         if verbose:
             print("\t", n + 1, " / ", n_iters)
 
         (new_candidate_neighbors, old_candidate_neighbors) = new_build_candidates(
-            current_graph,
-            max_candidates,
-            rng_state,
+            current_graph, max_candidates, rng_state, n_threads
         )
 
         c = 0
@@ -246,7 +215,7 @@ def nn_descent_internal_low_memory_parallel(
                 dist,
             )
 
-            c += apply_graph_updates_low_memory(current_graph, updates)
+            c += apply_graph_updates_low_memory(current_graph, updates, n_threads)
 
         if c <= delta * n_neighbors * n_vertices:
             if verbose:
@@ -271,6 +240,7 @@ def nn_descent_internal_high_memory_parallel(
     n_vertices = indptr.shape[0] - 1
     block_size = 16384
     n_blocks = n_vertices // block_size
+    n_threads = numba.get_num_threads()
 
     in_graph = [
         set(current_graph[0][i].astype(np.int64))
@@ -282,9 +252,7 @@ def nn_descent_internal_high_memory_parallel(
             print("\t", n + 1, " / ", n_iters)
 
         (new_candidate_neighbors, old_candidate_neighbors) = new_build_candidates(
-            current_graph,
-            max_candidates,
-            rng_state,
+            current_graph, max_candidates, rng_state, n_threads
         )
 
         c = 0
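Both `nn_descent_internal_*_parallel` functions split the vertices into blocks of `block_size = 16384`. They stop early once the number of heap updates `c` in an iteration falls to `delta * n_neighbors * n_vertices` or below. Below is a sketch of those two pieces of arithmetic. The hunks shown here do not include the block loop itself, so iterating over `n_blocks + 1` blocks to cover the remainder is an assumption. The helper names are hypothetical.

```python
def block_ranges(n_vertices, block_size=16384):
    # Assumption: n_blocks full blocks plus one final partial block
    n_blocks = n_vertices // block_size
    ranges = []
    for i in range(n_blocks + 1):
        start = i * block_size
        end = min(n_vertices, (i + 1) * block_size)
        if start < end:
            ranges.append((start, end))
    return ranges

def has_converged(n_updates, delta, n_neighbors, n_vertices):
    # Stop once at most a delta fraction of all k-NN slots changed
    return n_updates <= delta * n_neighbors * n_vertices

print(block_ranges(40000))  # [(0, 16384), (16384, 32768), (32768, 40000)]
print(has_converged(100, 0.001, 15, 10000))  # True
```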


=====================================
pynndescent/tests/conftest.py
=====================================
@@ -0,0 +1,63 @@
+import os
+import pytest
+import numpy as np
+from scipy import sparse
+
+# Making Random Seed as a fixture in case it would be
+# needed in tests for random states
+@pytest.fixture
+def seed():
+    return 189212  # 0b101110001100011100
+
+
+np.random.seed(189212)
+
+
+@pytest.fixture
+def spatial_data():
+    sp_data = np.random.randn(10, 20)
+    # Add some all zero graph_data for corner case test
+    sp_data = np.vstack([sp_data, np.zeros((2, 20))]).astype(np.float32, order="C")
+    return sp_data
+
+
+@pytest.fixture
+def binary_data():
+    bin_data = np.random.choice(a=[False, True], size=(10, 20), p=[0.66, 1 - 0.66])
+    # Add some all zero graph_data for corner case test
+    bin_data = np.vstack([bin_data, np.zeros((2, 20), dtype="bool")])
+    return bin_data
+
+
+@pytest.fixture
+def sparse_spatial_data(spatial_data, binary_data):
+    sp_sparse_data = sparse.csr_matrix(spatial_data * binary_data, dtype=np.float32)
+    sp_sparse_data.sort_indices()
+    return sp_sparse_data
+
+
+@pytest.fixture
+def sparse_binary_data(binary_data):
+    bin_sparse_data = sparse.csr_matrix(binary_data)
+    bin_sparse_data.sort_indices()
+    return bin_sparse_data
+
+
+@pytest.fixture
+def nn_data():
+    nndata = np.random.uniform(0, 1, size=(1000, 5))
+    # Add some all zero graph_data for corner case test
+    nndata = np.vstack([nndata, np.zeros((2, 5))])
+    return nndata
+
+
+@pytest.fixture
+def sparse_nn_data():
+    return sparse.random(1000, 50, density=0.5, format="csr")
+
+
+@pytest.fixture
+def cosine_hang_data():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    data_path = os.path.join(this_dir, "test_data/cosine_hang.npy")
+    return np.load(data_path)
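In the new conftest.py, `sparse_spatial_data` requests the `spatial_data` and `binary_data` fixtures and multiplies them elementwise. The boolean mask therefore zeroes out roughly two thirds of the Gaussian entries, and the two trailing all-zero rows stay empty. Below is a NumPy-only sketch of that composition that leaves out the conversion to `csr_matrix`. It uses a local `default_rng` generator in place of the module-level `np.random.seed` call, an assumption made so the snippet is reproducible on its own.

```python
import numpy as np

rng = np.random.default_rng(189212)  # assumption: local generator, not the global seed

# spatial_data: 10 Gaussian rows plus 2 all-zero rows, float32, C order
spatial = np.vstack([rng.standard_normal((10, 20)), np.zeros((2, 20))])
spatial = spatial.astype(np.float32, order="C")

# binary_data: True with probability 1 - 0.66, plus 2 all-False rows
binary = rng.random((10, 20)) < 1 - 0.66
binary = np.vstack([binary, np.zeros((2, 20), dtype=bool)])

# sparse_spatial_data wraps this product in a CSR matrix
masked = spatial * binary
print(masked.dtype, np.count_nonzero(masked), binary.sum())
```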


=====================================
pynndescent/tests/test_distances.py
=====================================
@@ -1,27 +1,30 @@
+import pytest
 import numpy as np
 from numpy.testing import assert_array_equal, assert_array_almost_equal
 import pynndescent.distances as dist
 import pynndescent.sparse as spdist
-from scipy import sparse, stats
+from scipy import stats
+from scipy.sparse import csr_matrix
 from sklearn.metrics import pairwise_distances
 from sklearn.neighbors import BallTree
-
-np.random.seed(42)
-spatial_data = np.random.randn(10, 20)
-spatial_data = np.vstack([spatial_data, np.zeros((2, 20))]).astype(
-    np.float32, order="C"
-)  # Add some all zero graph_data for corner case test
-binary_data = np.random.choice(a=[False, True], size=(10, 20), p=[0.66, 1 - 0.66])
-binary_data = np.vstack(
-    [binary_data, np.zeros((2, 20), dtype="bool")]
-)  # Add some all zero graph_data for corner case test
-sparse_spatial_data = sparse.csr_matrix(spatial_data * binary_data, dtype=np.float32)
-sparse_spatial_data.sort_indices()
-sparse_binary_data = sparse.csr_matrix(binary_data)
-sparse_binary_data.sort_indices()
-
-
-def spatial_check(metric):
+from sklearn.preprocessing import normalize
+
+
+@pytest.mark.parametrize(
+    "metric",
+    [
+        "euclidean",
+        "manhattan",
+        "chebyshev",
+        "minkowski",
+        "hamming",
+        "canberra",
+        "braycurtis",
+        "cosine",
+        "correlation",
+    ],
+)
+def test_spatial_check(spatial_data, metric):
     dist_matrix = pairwise_distances(spatial_data, metric=metric)
     # scipy is bad sometimes
     if metric == "braycurtis":
@@ -48,7 +51,21 @@ def spatial_check(metric):
     )
 
 
-def binary_check(metric):
+@pytest.mark.parametrize(
+    "metric",
+    [
+        "jaccard",
+        "matching",
+        "dice",
+        "kulsinski",
+        "rogerstanimoto",
+        "russellrao",
+        "sokalmichener",
+        "sokalsneath",
+        "yule",
+    ],
+)
+def test_binary_check(binary_data, metric):
     dist_matrix = pairwise_distances(binary_data, metric=metric)
     if metric in ("jaccard", "dice", "sokalsneath", "yule"):
         dist_matrix[np.where(~np.isfinite(dist_matrix))] = 0.0
@@ -74,7 +91,21 @@ def binary_check(metric):
     )
 
 
-def sparse_spatial_check(metric, decimal=6):
+@pytest.mark.parametrize(
+    "metric",
+    [
+        "euclidean",
+        "manhattan",
+        "chebyshev",
+        "minkowski",
+        "hamming",
+        "canberra",
+        "cosine",
+        "braycurtis",
+        "correlation",
+    ],
+)
+def test_sparse_spatial_check(sparse_spatial_data, metric, decimal=6):
     if metric in spdist.sparse_named_distances:
         dist_matrix = pairwise_distances(
             sparse_spatial_data.todense().astype(np.float32), metric=metric
@@ -127,10 +158,23 @@ def sparse_spatial_check(metric, decimal=6):
     )
 
 
-def sparse_binary_check(metric):
+@pytest.mark.parametrize(
+    "metric",
+    [
+        "jaccard",
+        "matching",
+        "dice",
+        "kulsinski",
+        "rogerstanimoto",
+        "russellrao",
+        "sokalmichener",
+        "sokalsneath",
+    ],
+)
+def test_sparse_binary_check(sparse_binary_data, metric):
     if metric in spdist.sparse_named_distances:
         dist_matrix = pairwise_distances(sparse_binary_data.todense(), metric=metric)
-    if metric in ("jaccard", "dice", "sokalsneath", "yule"):
+    if metric in ("jaccard", "dice", "sokalsneath"):
         dist_matrix[np.where(~np.isfinite(dist_matrix))] = 0.0
     if metric in ("kulsinski", "russellrao"):
         dist_matrix[np.where(~np.isfinite(dist_matrix))] = 1.0
@@ -178,147 +222,7 @@ def sparse_binary_check(metric):
     )
 
 
-def test_euclidean():
-    spatial_check("euclidean")
-
-
-def test_manhattan():
-    spatial_check("manhattan")
-
-
-def test_chebyshev():
-    spatial_check("chebyshev")
-
-
-def test_minkowski():
-    spatial_check("minkowski")
-
-
-def test_hamming():
-    spatial_check("hamming")
-
-
-def test_canberra():
-    spatial_check("canberra")
-
-
-def test_braycurtis():
-    spatial_check("braycurtis")
-
-
-def test_cosine():
-    spatial_check("cosine")
-
-
-def test_correlation():
-    spatial_check("correlation")
-
-
-def test_jaccard():
-    binary_check("jaccard")
-
-
-def test_matching():
-    binary_check("matching")
-
-
-def test_dice():
-    binary_check("dice")
-
-
-def test_kulsinski():
-    binary_check("kulsinski")
-
-
-def test_rogerstanimoto():
-    binary_check("rogerstanimoto")
-
-
-def test_russellrao():
-    binary_check("russellrao")
-
-
-def test_sokalmichener():
-    binary_check("sokalmichener")
-
-
-def test_sokalsneath():
-    binary_check("sokalsneath")
-
-
-def test_yule():
-    binary_check("yule")
-
-
-def test_sparse_euclidean():
-    sparse_spatial_check("euclidean")
-
-
-def test_sparse_manhattan():
-    sparse_spatial_check("manhattan")
-
-
-def test_sparse_chebyshev():
-    sparse_spatial_check("chebyshev")
-
-
-def test_sparse_minkowski():
-    sparse_spatial_check("minkowski")
-
-
-def test_sparse_hamming():
-    sparse_spatial_check("hamming")
-
-
-def test_sparse_canberra():
-    sparse_spatial_check("canberra")  # Be a little forgiving
-
-
-def test_sparse_cosine():
-    sparse_spatial_check("cosine")
-
-
-def test_sparse_correlation():
-    sparse_spatial_check("correlation")
-
-
-def test_sparse_jaccard():
-    sparse_binary_check("jaccard")
-
-
-def test_sparse_matching():
-    sparse_binary_check("matching")
-
-
-def test_sparse_dice():
-    sparse_binary_check("dice")
-
-
-def test_sparse_kulsinski():
-    sparse_binary_check("kulsinski")
-
-
-def test_sparse_rogerstanimoto():
-    sparse_binary_check("rogerstanimoto")
-
-
-def test_sparse_russellrao():
-    sparse_binary_check("russellrao")
-
-
-def test_sparse_sokalmichener():
-    sparse_binary_check("sokalmichener")
-
-
-def test_sparse_sokalsneath():
-    sparse_binary_check("sokalsneath")
-
-
-def test_sparse_braycurtis():
-    sparse_spatial_check("braycurtis")
-
-
-def test_seuclidean():
+def test_seuclidean(spatial_data):
     v = np.abs(np.random.randn(spatial_data.shape[1]))
     dist_matrix = pairwise_distances(spatial_data, metric="seuclidean", V=v)
     test_matrix = np.array(
@@ -337,7 +241,7 @@ def test_seuclidean():
     )
 
 
-def test_weighted_minkowski():
+def test_weighted_minkowski(spatial_data):
     v = np.abs(np.random.randn(spatial_data.shape[1]))
     dist_matrix = pairwise_distances(spatial_data, metric="wminkowski", w=v, p=3)
     test_matrix = np.array(
@@ -356,7 +260,7 @@ def test_weighted_minkowski():
     )
 
 
-def test_mahalanobis():
+def test_mahalanobis(spatial_data):
     v = np.cov(np.transpose(spatial_data))
     dist_matrix = pairwise_distances(spatial_data, metric="mahalanobis", VI=v)
     test_matrix = np.array(
@@ -375,7 +279,7 @@ def test_mahalanobis():
     )
 
 
-def test_haversine():
+def test_haversine(spatial_data):
     tree = BallTree(spatial_data[:, :2], metric="haversine")
     dist_matrix, _ = tree.query(spatial_data[:, :2], k=spatial_data.shape[0])
     test_matrix = np.array(
@@ -422,3 +326,49 @@ def test_alternative_distances():
             corrected_alt_distance = correction(alt_dist(x, y))
 
             assert np.isclose(true_distance, corrected_alt_distance)
+
+
+def test_jensen_shannon():
+    test_data = np.random.random(size=(10, 50))
+    test_data = normalize(test_data, norm="l1")
+    for i in range(test_data.shape[0]):
+        for j in range(i + 1, test_data.shape[0]):
+            m = (test_data[i] + test_data[j]) / 2.0
+            p = test_data[i]
+            q = test_data[j]
+            d1 = (
+                -np.sum(m * np.log(m))
+                + (np.sum(p * np.log(p)) + np.sum(q * np.log(q))) / 2.0
+            )
+            d2 = dist.jensen_shannon_divergence(p, q)
+            assert np.isclose(d1, d2, rtol=1e-4)
+
+
+def test_sparse_jensen_shannon():
+    test_data = np.random.random(size=(10, 100))
+    # sparsify
+    test_data[test_data <= 0.5] = 0.0
+    sparse_test_data = csr_matrix(test_data)
+    sparse_test_data = normalize(sparse_test_data, norm="l1")
+    test_data = normalize(test_data, norm="l1")
+
+    for i in range(test_data.shape[0]):
+        for j in range(i + 1, test_data.shape[0]):
+            m = (test_data[i] + test_data[j]) / 2.0
+            p = test_data[i]
+            q = test_data[j]
+            d1 = (
+                -np.sum(m[m > 0] * np.log(m[m > 0]))
+                + (
+                    np.sum(p[p > 0] * np.log(p[p > 0]))
+                    + np.sum(q[q > 0] * np.log(q[q > 0]))
+                )
+                / 2.0
+            )
+            d2 = spdist.sparse_jensen_shannon_divergence(
+                    sparse_test_data[i].indices,
+                    sparse_test_data[i].data,
+                    sparse_test_data[j].indices,
+                    sparse_test_data[j].data,
+                )
+            assert np.isclose(d1, d2, rtol=1e-3)

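The two new Jensen-Shannon tests compare pynndescent's dense and sparse implementations against the reference identity JS(p, q) = H(m) - (H(p) + H(q)) / 2, where m = (p + q) / 2 and H is the Shannon entropy in nats. Zero-probability terms contribute nothing (0 * log(0) is taken as 0), which is why the sparse test keeps only the entries with `m > 0`, `p > 0` and `q > 0`. A minimal pure-Python sketch of that reference computation, assuming `p` and `q` are already L1-normalized sequences of equal length; `entropy` and `js_divergence` are hypothetical helper names, not pynndescent functions:

```python
import math


def entropy(p):
    """Shannon entropy in nats, treating 0 * log(0) as 0."""
    return -sum(x * math.log(x) for x in p if x > 0)


def js_divergence(p, q):
    """Jensen-Shannon divergence of two probability vectors (natural log).

    Mirrors the reference formula in the tests:
    H(m) - (H(p) + H(q)) / 2, with m the elementwise midpoint of p and q.
    """
    if len(p) != len(q):
        raise ValueError("p and q must have the same length")
    m = [(a + b) / 2.0 for a, b in zip(p, q)]
    return entropy(m) - (entropy(p) + entropy(q)) / 2.0


if __name__ == "__main__":
    # Identical distributions have divergence 0.
    print(js_divergence([0.25, 0.75], [0.25, 0.75]))
    # Distributions with disjoint support reach the upper bound ln 2.
    print(js_divergence([1.0, 0.0], [0.0, 1.0]), math.log(2))
```

SciPy's `scipy.spatial.distance.jensenshannon` returns the square root of this divergence (the Jensen-Shannon distance), so square its result when cross-checking against it.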

=====================================
pynndescent/tests/test_pynndescent_.py
=====================================
@@ -1,44 +1,23 @@
 import os
 import io
 import re
+import pytest
 from contextlib import redirect_stdout
 
-from nose.tools import assert_greater_equal, assert_true, assert_equal
-from nose import SkipTest
-
 import numpy as np
-from scipy import sparse
 from sklearn.neighbors import KDTree
 from sklearn.neighbors import NearestNeighbors
 from sklearn.preprocessing import normalize
 import pickle
 import joblib
+import scipy
 
 from pynndescent import NNDescent, PyNNDescentTransformer
 
-np.random.seed(42)
-spatial_data = np.random.randn(10, 20)
-spatial_data = np.vstack(
-    [spatial_data, np.zeros((2, 20))]
-)  # Add some all zero graph_data for corner case test
-
-nn_data = np.random.uniform(0, 1, size=(1000, 5))
-nn_data = np.vstack(
-    [nn_data, np.zeros((2, 5))]
-)  # Add some all zero graph_data for corner case test
-# for_sparse_nn_data = np.random.uniform(0, 1, size=(1002, 500))
-# binary_nn_data = np.random.choice(a=[False, True], size=(1000, 500), p=[0.1, 1 - 0.1])
-# binary_nn_data = np.vstack(
-#     [binary_nn_data, np.zeros((2, 500))]
-# )  # Add some all zero graph_data for corner case test
-# sparse_nn_data = sparse.csr_matrix(for_sparse_nn_data * binary_nn_data)
-sparse_nn_data = sparse.random(1000, 50, density=0.5, format="csr")
-# sparse_nn_data = sparse.csr_matrix(nn_data)
-
-
-def test_nn_descent_neighbor_accuracy():
+
+def test_nn_descent_neighbor_accuracy(nn_data, seed):
     knn_indices, _ = NNDescent(
-        nn_data, "euclidean", {}, 10, random_state=np.random
+        nn_data, "euclidean", {}, 10, random_state=np.random.RandomState(seed)
     )._neighbor_graph
 
     tree = KDTree(nn_data)
@@ -49,16 +28,14 @@ def test_nn_descent_neighbor_accuracy():
         num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))
 
     percent_correct = num_correct / (nn_data.shape[0] * 10)
-    assert_greater_equal(
-        percent_correct,
-        0.98,
-        "NN-descent did not get 99% " "accuracy on nearest neighbors",
+    assert percent_correct >= 0.98, (
+        "NN-descent did not get 99% " "accuracy on nearest neighbors"
     )
 
 
-def test_angular_nn_descent_neighbor_accuracy():
+def test_angular_nn_descent_neighbor_accuracy(nn_data, seed):
     knn_indices, _ = NNDescent(
-        nn_data, "cosine", {}, 10, random_state=np.random
+        nn_data, "cosine", {}, 10, random_state=np.random.RandomState(seed)
     )._neighbor_graph
 
     angular_data = normalize(nn_data, norm="l2")
@@ -70,14 +47,16 @@ def test_angular_nn_descent_neighbor_accuracy():
         num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))
 
     percent_correct = num_correct / (nn_data.shape[0] * 10)
-    assert_greater_equal(
-        percent_correct,
-        0.98,
-        "NN-descent did not get 99% " "accuracy on nearest neighbors",
+    assert percent_correct >= 0.98, (
+        "NN-descent did not get 99% " "accuracy on nearest neighbors"
     )
 
 
-def test_sparse_nn_descent_neighbor_accuracy():
+@pytest.mark.skipif(
+    list(map(int, scipy.version.version.split("."))) < [1, 3, 0],
+    reason="requires scipy >= 1.3.0",
+)
+def test_sparse_nn_descent_neighbor_accuracy(sparse_nn_data, seed):
     knn_indices, _ = NNDescent(
         sparse_nn_data, "euclidean", n_neighbors=20, random_state=None
     )._neighbor_graph
@@ -90,14 +69,16 @@ def test_sparse_nn_descent_neighbor_accuracy():
         num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))
 
     percent_correct = num_correct / (sparse_nn_data.shape[0] * 10)
-    assert_greater_equal(
-        percent_correct,
-        0.85,
-        "Sparse NN-descent did not get 95% " "accuracy on nearest neighbors",
+    assert percent_correct >= 0.85, (
+        "Sparse NN-descent did not get 95% " "accuracy on nearest neighbors"
     )
 
 
-def test_sparse_angular_nn_descent_neighbor_accuracy():
+@pytest.mark.skipif(
+    list(map(int, scipy.version.version.split("."))) < [1, 3, 0],
+    reason="requires scipy >= 1.3.0",
+)
+def test_sparse_angular_nn_descent_neighbor_accuracy(sparse_nn_data):
     knn_indices, _ = NNDescent(
         sparse_nn_data, "cosine", {}, 20, random_state=None
     )._neighbor_graph
@@ -111,14 +92,12 @@ def test_sparse_angular_nn_descent_neighbor_accuracy():
         num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))
 
     percent_correct = num_correct / (sparse_nn_data.shape[0] * 10)
-    assert_greater_equal(
-        percent_correct,
-        0.85,
-        "Sparse angular NN-descent did not get 98% " "accuracy on nearest neighbors",
+    assert percent_correct >= 0.85, (
+        "Sparse angular NN-descent did not get 98% " "accuracy on nearest neighbors"
     )
 
 
-def test_nn_descent_query_accuracy():
+def test_nn_descent_query_accuracy(nn_data):
     nnd = NNDescent(nn_data[200:], "euclidean", n_neighbors=10, random_state=None)
     knn_indices, _ = nnd.query(nn_data[:200], k=10, epsilon=0.2)
 
@@ -130,14 +109,12 @@ def test_nn_descent_query_accuracy():
         num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))
 
     percent_correct = num_correct / (true_indices.shape[0] * 10)
-    assert_greater_equal(
-        percent_correct,
-        0.95,
-        "NN-descent query did not get 95% " "accuracy on nearest neighbors",
+    assert percent_correct >= 0.95, (
+        "NN-descent query did not get 95% " "accuracy on nearest neighbors"
     )
 
 
-def test_nn_descent_query_accuracy_angular():
+def test_nn_descent_query_accuracy_angular(nn_data):
     nnd = NNDescent(nn_data[200:], "cosine", n_neighbors=30, random_state=None)
     knn_indices, _ = nnd.query(nn_data[:200], k=10, epsilon=0.32)
 
@@ -149,14 +126,12 @@ def test_nn_descent_query_accuracy_angular():
         num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))
 
     percent_correct = num_correct / (true_indices.shape[0] * 10)
-    assert_greater_equal(
-        percent_correct,
-        0.95,
-        "NN-descent query did not get 95% " "accuracy on nearest neighbors",
+    assert percent_correct >= 0.95, (
+        "NN-descent query did not get 95% " "accuracy on nearest neighbors"
     )
 
 
-def test_sparse_nn_descent_query_accuracy():
+def test_sparse_nn_descent_query_accuracy(sparse_nn_data):
     nnd = NNDescent(
         sparse_nn_data[200:], "euclidean", n_neighbors=15, random_state=None
     )
@@ -170,14 +145,12 @@ def test_sparse_nn_descent_query_accuracy():
         num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))
 
     percent_correct = num_correct / (true_indices.shape[0] * 10)
-    assert_greater_equal(
-        percent_correct,
-        0.95,
-        "Sparse NN-descent query did not get 95% " "accuracy on nearest neighbors",
+    assert percent_correct >= 0.95, (
+        "Sparse NN-descent query did not get 95% " "accuracy on nearest neighbors"
     )
 
 
-def test_sparse_nn_descent_query_accuracy_angular():
+def test_sparse_nn_descent_query_accuracy_angular(sparse_nn_data):
     nnd = NNDescent(sparse_nn_data[200:], "cosine", n_neighbors=50, random_state=None)
     knn_indices, _ = nnd.query(sparse_nn_data[:200], k=10, epsilon=0.36)
 
@@ -191,14 +164,12 @@ def test_sparse_nn_descent_query_accuracy_angular():
         num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))
 
     percent_correct = num_correct / (true_indices.shape[0] * 10)
-    assert_greater_equal(
-        percent_correct,
-        0.95,
-        "Sparse NN-descent query did not get 95% " "accuracy on nearest neighbors",
+    assert percent_correct >= 0.95, (
+        "Sparse NN-descent query did not get 95% " "accuracy on nearest neighbors"
     )
 
 
-def test_transformer_equivalence():
+def test_transformer_equivalence(nn_data):
     N_NEIGHBORS = 15
     EPSILON = 0.15
     train = nn_data[:400]
@@ -225,7 +196,7 @@ def test_transformer_equivalence():
     assert np.allclose(Xt.data, dists_sorted.flat)
 
 
-def test_random_state_none():
+def test_random_state_none(nn_data, spatial_data):
     knn_indices, _ = NNDescent(
         nn_data, "euclidean", {}, 10, random_state=None
     )._neighbor_graph
@@ -238,10 +209,8 @@ def test_random_state_none():
         num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))
 
     percent_correct = num_correct / (spatial_data.shape[0] * 10)
-    assert_greater_equal(
-        percent_correct,
-        0.99,
-        "NN-descent did not get 99% " "accuracy on nearest neighbors",
+    assert percent_correct >= 0.99, (
+        "NN-descent did not get 99% " "accuracy on nearest neighbors"
     )
 
 
@@ -265,42 +234,44 @@ def test_deterministic():
 # https://github.com/lmcinnes/umap/issues/99
 # graph_data used is a cut-down version of that provided by @scharron
 # It contains lots of all-zero vectors and some other duplicates
-def test_rp_trees_should_not_stack_overflow_with_duplicate_data():
-    this_dir = os.path.dirname(os.path.abspath(__file__))
-    data_path = os.path.join(this_dir, "test_data/cosine_hang.npy")
-    data = np.load(data_path)
+def test_rp_trees_should_not_stack_overflow_with_duplicate_data(seed, cosine_hang_data):
 
     n_neighbors = 10
     knn_indices, _ = NNDescent(
-        data, "cosine", {}, n_neighbors, random_state=np.random, n_trees=20
+        cosine_hang_data,
+        "cosine",
+        {},
+        n_neighbors,
+        random_state=np.random.RandomState(seed),
+        n_trees=20,
     )._neighbor_graph
 
-    for i in range(data.shape[0]):
-        assert_equal(
-            len(knn_indices[i]),
-            len(np.unique(knn_indices[i])),
-            "Duplicate graph_indices in knn graph",
-        )
+    for i in range(cosine_hang_data.shape[0]):
+        assert len(knn_indices[i]) == len(
+            np.unique(knn_indices[i])
+        ), "Duplicate graph_indices in knn graph"
+
 
+def test_deduplicated_data_behaves_normally(seed, cosine_hang_data):
 
-def test_deduplicated_data_behaves_normally():
-    this_dir = os.path.dirname(os.path.abspath(__file__))
-    data_path = os.path.join(this_dir, "test_data/cosine_hang.npy")
-    data = np.unique(np.load(data_path), axis=0)
+    data = np.unique(cosine_hang_data, axis=0)
     data = data[~np.all(data == 0, axis=1)]
     data = data[:1000]
 
     n_neighbors = 10
     knn_indices, _ = NNDescent(
-        data, "cosine", {}, n_neighbors, random_state=np.random, n_trees=20
+        data,
+        "cosine",
+        {},
+        n_neighbors,
+        random_state=np.random.RandomState(seed),
+        n_trees=20,
     )._neighbor_graph
 
     for i in range(data.shape[0]):
-        assert_equal(
-            len(knn_indices[i]),
-            len(np.unique(knn_indices[i])),
-            "Duplicate graph_indices in knn graph",
-        )
+        assert len(knn_indices[i]) == len(
+            np.unique(knn_indices[i])
+        ), "Duplicate graph_indices in knn graph"
 
     angular_data = normalize(data, norm="l2")
     tree = KDTree(angular_data)
@@ -311,14 +282,12 @@ def test_deduplicated_data_behaves_normally():
         num_correct += np.sum(np.in1d(true_indices[i], knn_indices[i]))
 
     proportion_correct = num_correct / (data.shape[0] * n_neighbors)
-    assert_greater_equal(
-        proportion_correct,
-        0.95,
-        "NN-descent did not get 95%" " accuracy on nearest neighbors",
+    assert proportion_correct >= 0.95, (
+        "NN-descent did not get 95%" " accuracy on nearest neighbors"
     )
 
 
-def test_output_when_verbose_is_true():
+def test_output_when_verbose_is_true(spatial_data, seed):
     out = io.StringIO()
     with redirect_stdout(out):
         _ = NNDescent(
@@ -326,17 +295,17 @@ def test_output_when_verbose_is_true():
             metric="euclidean",
             metric_kwds={},
             n_neighbors=4,
-            random_state=np.random,
+            random_state=np.random.RandomState(seed),
             n_trees=5,
             n_iters=2,
             verbose=True,
         )
     output = out.getvalue()
-    assert_true(re.match("^.*5 trees", output, re.DOTALL))
-    assert_true(re.match("^.*2 iterations", output, re.DOTALL))
+    assert re.match("^.*5 trees", output, re.DOTALL)
+    assert re.match("^.*2 iterations", output, re.DOTALL)
 
 
-def test_no_output_when_verbose_is_false():
+def test_no_output_when_verbose_is_false(spatial_data, seed):
     out = io.StringIO()
     with redirect_stdout(out):
         _ = NNDescent(
@@ -344,48 +313,48 @@ def test_no_output_when_verbose_is_false():
             metric="euclidean",
             metric_kwds={},
             n_neighbors=4,
-            random_state=np.random,
+            random_state=np.random.RandomState(seed),
             n_trees=5,
             n_iters=2,
             verbose=False,
         )
     output = out.getvalue().strip()
-    assert_equal(len(output), 0)
+    assert len(output) == 0
 
 
 # same as the previous two test, but this time using the PyNNDescentTransformer
 # interface
-def test_transformer_output_when_verbose_is_true():
+def test_transformer_output_when_verbose_is_true(spatial_data, seed):
     out = io.StringIO()
     with redirect_stdout(out):
         _ = PyNNDescentTransformer(
             n_neighbors=4,
             metric="euclidean",
             metric_kwds={},
-            random_state=np.random,
+            random_state=np.random.RandomState(seed),
             n_trees=5,
             n_iters=2,
             verbose=True,
         ).fit_transform(spatial_data)
     output = out.getvalue()
-    assert_true(re.match("^.*5 trees", output, re.DOTALL))
-    assert_true(re.match("^.*2 iterations", output, re.DOTALL))
+    assert re.match("^.*5 trees", output, re.DOTALL)
+    assert re.match("^.*2 iterations", output, re.DOTALL)
 
 
-def test_transformer_output_when_verbose_is_false():
+def test_transformer_output_when_verbose_is_false(spatial_data, seed):
     out = io.StringIO()
     with redirect_stdout(out):
         _ = PyNNDescentTransformer(
             n_neighbors=4,
             metric="standardised_euclidean",
             metric_kwds={"sigma": np.ones(spatial_data.shape[1])},
-            random_state=np.random,
+            random_state=np.random.RandomState(seed),
             n_trees=5,
             n_iters=2,
             verbose=False,
         ).fit_transform(spatial_data)
     output = out.getvalue().strip()
-    assert_equal(len(output), 0)
+    assert len(output) == 0
 
 
 def test_pickle_unpickle():
@@ -394,18 +363,13 @@ def test_pickle_unpickle():
     x1 = seed.normal(0, 100, (1000, 50))
     x2 = seed.normal(0, 100, (1000, 50))
 
-    index1 = NNDescent(
-        x1,
-        "euclidean",
-        {},
-        10,
-        random_state=None,
-    )
+    index1 = NNDescent(x1, "euclidean", {}, 10, random_state=None)
     neighbors1, distances1 = index1.query(x2)
 
-    pickle.dump(index1, open("test_tmp.pkl", "wb"))
-    index2 = pickle.load(open("test_tmp.pkl", "rb"))
-    os.remove("test_tmp.pkl")
+    mem_temp = io.BytesIO()
+    pickle.dump(index1, mem_temp)
+    mem_temp.seek(0)
+    index2 = pickle.load(mem_temp)
 
     neighbors2, distances2 = index2.query(x2)
 
@@ -419,19 +383,13 @@ def test_compressed_pickle_unpickle():
     x1 = seed.normal(0, 100, (1000, 50))
     x2 = seed.normal(0, 100, (1000, 50))
 
-    index1 = NNDescent(
-        x1,
-        "euclidean",
-        {},
-        10,
-        random_state=None,
-        compressed=True,
-    )
+    index1 = NNDescent(x1, "euclidean", {}, 10, random_state=None, compressed=True)
     neighbors1, distances1 = index1.query(x2)
 
-    pickle.dump(index1, open("test_tmp.pkl", "wb"))
-    index2 = pickle.load(open("test_tmp.pkl", "rb"))
-    os.remove("test_tmp.pkl")
+    mem_temp = io.BytesIO()
+    pickle.dump(index1, mem_temp)
+    mem_temp.seek(0)
+    index2 = pickle.load(mem_temp)
 
     neighbors2, distances2 = index2.query(x2)
 
@@ -448,9 +406,10 @@ def test_transformer_pickle_unpickle():
     index1 = PyNNDescentTransformer(n_neighbors=10).fit(x1)
     result1 = index1.transform(x2)
 
-    pickle.dump(index1, open("test_tmp.pkl", "wb"))
-    index2 = pickle.load(open("test_tmp.pkl", "rb"))
-    os.remove("test_tmp.pkl")
+    mem_temp = io.BytesIO()
+    pickle.dump(index1, mem_temp)
+    mem_temp.seek(0)
+    index2 = pickle.load(mem_temp)
 
     result2 = index2.transform(x2)
 
@@ -464,18 +423,13 @@ def test_joblib_dump():
     x1 = seed.normal(0, 100, (1000, 50))
     x2 = seed.normal(0, 100, (1000, 50))
 
-    index1 = NNDescent(
-        x1,
-        "euclidean",
-        {},
-        10,
-        random_state=None,
-    )
+    index1 = NNDescent(x1, "euclidean", {}, 10, random_state=None)
     neighbors1, distances1 = index1.query(x2)
 
-    joblib.dump(index1, "test_tmp.dump")
-    index2 = joblib.load("test_tmp.dump")
-    os.remove("test_tmp.dump")
+    mem_temp = io.BytesIO()
+    joblib.dump(index1, mem_temp)
+    mem_temp.seek(0)
+    index2 = joblib.load(mem_temp)
 
     neighbors2, distances2 = index2.query(x2)
 

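The new `skipif` guards compare `scipy.version.version` by splitting it on "." and calling `int` on each piece. That works for plain releases such as "1.7.3", but a pre-release string such as "1.8.0rc1" makes `int("0rc1")` raise `ValueError`, and because decorator arguments are evaluated when the module is imported, the whole test module would then fail to collect. A minimal, more tolerant sketch; `numeric_version` is a hypothetical helper, not part of pynndescent or SciPy, and it deliberately treats "1.3.0rc1" as equal to "1.3.0":

```python
import re


def numeric_version(version_string, width=3):
    """Return the leading numeric release components of a version string.

    Parsing stops at the first component that is not a plain integer, so
    "1.8.0rc1" gives (1, 8, 0) and "2.0.0.dev0" gives (2, 0, 0). The result
    is truncated or zero-padded to `width` entries so tuples compare cleanly.
    """
    parts = []
    for piece in version_string.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # e.g. "dev0": no leading digits at all
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # e.g. "0rc1": keep the 0, drop the "rc1" suffix
    parts = parts[:width]
    parts += [0] * (width - len(parts))
    return tuple(parts)
```

With it, the guard could read `numeric_version(scipy.version.version) < (1, 3, 0)`. Where the `packaging` library is available, comparing `packaging.version.Version` objects is a fuller alternative that also orders pre-releases before their final release.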

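The pickle and joblib tests now serialize into an `io.BytesIO` buffer instead of writing `test_tmp.pkl` or `test_tmp.dump` to the working directory, so an exception between dump and cleanup can no longer leave a stray file behind, and concurrent test runs cannot overwrite each other's file. The pattern in isolation, as a minimal sketch with a plain dictionary standing in for an `NNDescent` index (the `roundtrip` helper and the dictionary contents are illustrative assumptions):

```python
import io
import pickle


def roundtrip(obj):
    """Serialize obj with pickle into an in-memory buffer and load it back.

    No file is created, so nothing has to be cleaned up afterwards.
    """
    buffer = io.BytesIO()
    pickle.dump(obj, buffer)
    # pickle.dump leaves the position at the end of the buffer; rewind so
    # that pickle.load starts reading from the first byte.
    buffer.seek(0)
    return pickle.load(buffer)


if __name__ == "__main__":
    # Hypothetical index state: neighbor indices and distances per point.
    state = {
        "indices": [[1, 2], [0, 2], [1, 0]],
        "distances": [[0.5, 1.0], [0.5, 0.7], [0.7, 1.0]],
    }
    restored = roundtrip(state)
    print(restored == state, restored is state)
```

`joblib.dump` and `joblib.load` also accept file objects, which is how the updated `test_joblib_dump` applies the same buffer technique.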
=====================================
pynndescent/tests/test_rank.py
=====================================
@@ -1,5 +1,6 @@
+import pytest
 import numpy as np
-from numpy.testing import assert_equal, assert_array_equal
+from numpy.testing import assert_array_equal
 
 from pynndescent.distances import rankdata
 
@@ -94,51 +95,132 @@ def test_big_tie():
         assert_array_equal(r, expected_rank * data, "test failed with n=%d" % n)
 
 
-# fmt: off
-_cases = (
-    # values, method, expected
-    (np.array([], np.float64), 'average', np.array([], np.float64)),
-    (np.array([], np.float64), 'min', np.array([], np.float64)),
-    (np.array([], np.float64), 'max', np.array([], np.float64)),
-    (np.array([], np.float64), 'dense', np.array([], np.float64)),
-    (np.array([], np.float64), 'ordinal', np.array([], np.float64)),
-    #
-    (np.array([100], np.float64), 'average', np.array([1.0], np.float64)),
-    (np.array([100], np.float64), 'min', np.array([1.0], np.float64)),
-    (np.array([100], np.float64), 'max', np.array([1.0], np.float64)),
-    (np.array([100], np.float64), 'dense', np.array([1.0], np.float64)),
-    (np.array([100], np.float64), 'ordinal', np.array([1.0], np.float64)),
-    # #
-    (np.array([100, 100, 100], np.float64), 'average', np.array([2.0, 2.0, 2.0], np.float64)),
-    (np.array([100, 100, 100], np.float64), 'min', np.array([1.0, 1.0, 1.0], np.float64)),
-    (np.array([100, 100, 100], np.float64), 'max', np.array([3.0, 3.0, 3.0], np.float64)),
-    (np.array([100, 100, 100], np.float64), 'dense', np.array([1.0, 1.0, 1.0], np.float64)),
-    (np.array([100, 100, 100], np.float64), 'ordinal', np.array([1.0, 2.0, 3.0], np.float64)),
-    #
-    (np.array([100, 300, 200], np.float64), 'average', np.array([1.0, 3.0, 2.0], np.float64)),
-    (np.array([100, 300, 200], np.float64), 'min', np.array([1.0, 3.0, 2.0], np.float64)),
-    (np.array([100, 300, 200], np.float64), 'max', np.array([1.0, 3.0, 2.0], np.float64)),
-    (np.array([100, 300, 200], np.float64), 'dense', np.array([1.0, 3.0, 2.0], np.float64)),
-    (np.array([100, 300, 200], np.float64), 'ordinal', np.array([1.0, 3.0, 2.0], np.float64)),
-    #
-    (np.array([100, 200, 300, 200], np.float64), 'average', np.array([1.0, 2.5, 4.0, 2.5], np.float64)),
-    (np.array([100, 200, 300, 200], np.float64), 'min', np.array([1.0, 2.0, 4.0, 2.0], np.float64)),
-    (np.array([100, 200, 300, 200], np.float64), 'max', np.array([1.0, 3.0, 4.0, 3.0], np.float64)),
-    (np.array([100, 200, 300, 200], np.float64), 'dense', np.array([1.0, 2.0, 3.0, 2.0], np.float64)),
-    (np.array([100, 200, 300, 200], np.float64), 'ordinal', np.array([1.0, 2.0, 4.0, 3.0], np.float64)),
-    #
-    (np.array([100, 200, 300, 200, 100], np.float64), 'average', np.array([1.5, 3.5, 5.0, 3.5, 1.5], np.float64)),
-    (np.array([100, 200, 300, 200, 100], np.float64), 'min', np.array([1.0, 3.0, 5.0, 3.0, 1.0], np.float64)),
-    (np.array([100, 200, 300, 200, 100], np.float64), 'max', np.array([2.0, 4.0, 5.0, 4.0, 2.0], np.float64)),
-    (np.array([100, 200, 300, 200, 100], np.float64), 'dense', np.array([1.0, 2.0, 3.0, 2.0, 1.0], np.float64)),
-    (np.array([100, 200, 300, 200, 100], np.float64), 'ordinal', np.array([1.0, 3.0, 5.0, 4.0, 2.0], np.float64)),
-    #
-    (np.array([10] * 30, np.float64), 'ordinal', np.arange(1.0, 31.0, dtype=np.float64)),
+@pytest.mark.parametrize(
+    "values,method,expected",
+    [  # values, method, expected
+        (np.array([], np.float64), "average", np.array([], np.float64)),
+        (np.array([], np.float64), "min", np.array([], np.float64)),
+        (np.array([], np.float64), "max", np.array([], np.float64)),
+        (np.array([], np.float64), "dense", np.array([], np.float64)),
+        (np.array([], np.float64), "ordinal", np.array([], np.float64)),
+        #
+        (np.array([100], np.float64), "average", np.array([1.0], np.float64)),
+        (np.array([100], np.float64), "min", np.array([1.0], np.float64)),
+        (np.array([100], np.float64), "max", np.array([1.0], np.float64)),
+        (np.array([100], np.float64), "dense", np.array([1.0], np.float64)),
+        (np.array([100], np.float64), "ordinal", np.array([1.0], np.float64)),
+        # #
+        (
+            np.array([100, 100, 100], np.float64),
+            "average",
+            np.array([2.0, 2.0, 2.0], np.float64),
+        ),
+        (
+            np.array([100, 100, 100], np.float64),
+            "min",
+            np.array([1.0, 1.0, 1.0], np.float64),
+        ),
+        (
+            np.array([100, 100, 100], np.float64),
+            "max",
+            np.array([3.0, 3.0, 3.0], np.float64),
+        ),
+        (
+            np.array([100, 100, 100], np.float64),
+            "dense",
+            np.array([1.0, 1.0, 1.0], np.float64),
+        ),
+        (
+            np.array([100, 100, 100], np.float64),
+            "ordinal",
+            np.array([1.0, 2.0, 3.0], np.float64),
+        ),
+        #
+        (
+            np.array([100, 300, 200], np.float64),
+            "average",
+            np.array([1.0, 3.0, 2.0], np.float64),
+        ),
+        (
+            np.array([100, 300, 200], np.float64),
+            "min",
+            np.array([1.0, 3.0, 2.0], np.float64),
+        ),
+        (
+            np.array([100, 300, 200], np.float64),
+            "max",
+            np.array([1.0, 3.0, 2.0], np.float64),
+        ),
+        (
+            np.array([100, 300, 200], np.float64),
+            "dense",
+            np.array([1.0, 3.0, 2.0], np.float64),
+        ),
+        (
+            np.array([100, 300, 200], np.float64),
+            "ordinal",
+            np.array([1.0, 3.0, 2.0], np.float64),
+        ),
+        #
+        (
+            np.array([100, 200, 300, 200], np.float64),
+            "average",
+            np.array([1.0, 2.5, 4.0, 2.5], np.float64),
+        ),
+        (
+            np.array([100, 200, 300, 200], np.float64),
+            "min",
+            np.array([1.0, 2.0, 4.0, 2.0], np.float64),
+        ),
+        (
+            np.array([100, 200, 300, 200], np.float64),
+            "max",
+            np.array([1.0, 3.0, 4.0, 3.0], np.float64),
+        ),
+        (
+            np.array([100, 200, 300, 200], np.float64),
+            "dense",
+            np.array([1.0, 2.0, 3.0, 2.0], np.float64),
+        ),
+        (
+            np.array([100, 200, 300, 200], np.float64),
+            "ordinal",
+            np.array([1.0, 2.0, 4.0, 3.0], np.float64),
+        ),
+        #
+        (
+            np.array([100, 200, 300, 200, 100], np.float64),
+            "average",
+            np.array([1.5, 3.5, 5.0, 3.5, 1.5], np.float64),
+        ),
+        (
+            np.array([100, 200, 300, 200, 100], np.float64),
+            "min",
+            np.array([1.0, 3.0, 5.0, 3.0, 1.0], np.float64),
+        ),
+        (
+            np.array([100, 200, 300, 200, 100], np.float64),
+            "max",
+            np.array([2.0, 4.0, 5.0, 4.0, 2.0], np.float64),
+        ),
+        (
+            np.array([100, 200, 300, 200, 100], np.float64),
+            "dense",
+            np.array([1.0, 2.0, 3.0, 2.0, 1.0], np.float64),
+        ),
+        (
+            np.array([100, 200, 300, 200, 100], np.float64),
+            "ordinal",
+            np.array([1.0, 3.0, 5.0, 4.0, 2.0], np.float64),
+        ),
+        #
+        (
+            np.array([10] * 30, np.float64),
+            "ordinal",
+            np.arange(1.0, 31.0, dtype=np.float64),
+        ),
+    ],
 )
-# fmt: on
-
-
-def test_cases():
-    for values, method, expected in _cases:
-        r = rankdata(values, method=method)
-        assert_array_equal(r, expected)
+def test_cases(values, method, expected):
+    r = rankdata(values, method=method)
+    assert_array_equal(r, expected)


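The parametrized cases above define how each tie-handling `method` ranks repeated values:
- `average` gives tied entries the mean of the ranks they span.
- `min` and `max` give the lowest and highest of those ranks.
- `dense` numbers the distinct values consecutively.
- `ordinal` breaks ties by order of appearance.

The following NumPy sketch is an illustration only and reproduces those expected arrays. `rankdata_sketch` is a hypothetical name, not the `rankdata` under test. It assumes 1-based ranks returned as float64, matching the expected arrays.

```python
import numpy as np


def rankdata_sketch(values, method="average"):
    """Rank 1-D data (1-based), assigning tied values according to ``method``.

    Hypothetical illustration only; not pynndescent's implementation.
    """
    a = np.asarray(values, dtype=np.float64).ravel()
    n = a.size
    if n == 0:
        return np.empty(0, dtype=np.float64)

    # A stable sort keeps tied values in input order, which is what
    # "ordinal" needs: earlier occurrences receive smaller ranks.
    order = np.argsort(a, kind="mergesort")
    if method == "ordinal":
        ranks = np.empty(n, dtype=np.float64)
        ranks[order] = np.arange(1, n + 1)
        return ranks

    # Mark where each run of equal values begins in sorted order.
    sorted_vals = a[order]
    starts_run = np.concatenate(([True], sorted_vals[1:] != sorted_vals[:-1]))
    group_sorted = np.cumsum(starts_run) - 1  # 0-based tie-group id per sorted slot
    group = np.empty(n, dtype=np.intp)
    group[order] = group_sorted

    starts = np.flatnonzero(starts_run)  # first sorted slot of each group
    ends = np.append(starts[1:], n)      # one past the last slot of each group

    if method == "dense":
        return (group + 1).astype(np.float64)
    if method == "min":
        return (starts[group] + 1).astype(np.float64)
    if method == "max":
        return ends[group].astype(np.float64)
    if method == "average":
        # Mean of the consecutive ranks starts+1 .. ends.
        return (starts[group] + 1 + ends[group]) / 2.0
    raise ValueError(f"unknown method: {method!r}")
```

The stable `mergesort` is the design choice that makes `ordinal` deterministic. The other four methods only need the start and end of each run of equal sorted values.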
=====================================
pynndescent/utils.py
=====================================
@@ -6,17 +6,17 @@ import time
 
 import numba
 from numba.core import types
-from numba.experimental import structref
+import numba.experimental.structref as structref
 import numpy as np
 
 
-@numba.njit("void(i8[:], i8)")
+@numba.njit("void(i8[:], i8)", cache=True)
 def seed(rng_state, seed):
     """Seed the random number generator with a given seed."""
     rng_state.fill(seed + 0xFFFF)
 
 
-@numba.njit("i4(i8[:])")
+@numba.njit("i4(i8[:])", cache=True)
 def tau_rand_int(state):
     """A fast (pseudo)-random number generator.
 
@@ -42,7 +42,7 @@ def tau_rand_int(state):
     return state[0] ^ state[1] ^ state[2]
 
 
-@numba.njit("f4(i8[:])")
+@numba.njit("f4(i8[:])", cache=True)
 def tau_rand(state):
     """A fast (pseudo)-random number generator for floats in the range [0,1]
 
@@ -69,9 +69,10 @@ def tau_rand(state):
     locals={
         "dim": numba.types.intp,
         "i": numba.types.uint32,
-        "result": numba.types.float32,
+        # "result": numba.types.float32, # This provides speed, but causes errors in corner cases
     },
     fastmath=True,
+    cache=True,
 )
 def norm(vec):
     """Compute the (standard l2) norm of a vector.
@@ -91,7 +92,7 @@ def norm(vec):
     return np.sqrt(result)
 
 
-@numba.njit()
+@numba.njit(cache=True)
 def rejection_sample(n_samples, pool_size, rng_state):
     """Generate n_samples many integers from 0 to pool_size such that no
     integer is selected twice. The duplication constraint is achieved via
@@ -147,31 +148,27 @@ class Heap(structref.StructRefProxy):
         return Heap_get_flags(self)
 
 
-@numba.njit
+@numba.njit(cache=True)
 def Heap_get_flags(self):
     return self.flags
 
 
-@numba.njit
+@numba.njit(cache=True)
 def Heap_get_distances(self):
     return self.distances
 
 
-@numba.njit
+@numba.njit(cache=True)
 def Heap_get_indices(self):
     return self.indices
 
 
-structref.define_proxy(
-    Heap,
-    HeapType,
-    ["indices", "distances", "flags"],
-)
+structref.define_proxy(Heap, HeapType, ["indices", "distances", "flags"])
 
 # Heap = namedtuple("Heap", ("indices", "distances", "flags"))
 
 
-@numba.njit()
+@numba.njit(cache=True)
 def make_heap(n_points, size):
     """Constructor for the numba enabled heap objects. The heaps are used
     for approximate nearest neighbor search, maintaining a list of potential
@@ -202,197 +199,7 @@ def make_heap(n_points, size):
     return result
 
 
-@numba.jit(
-    locals={
-        "indices": numba.types.int32[::1],
-        "weights": numba.types.float32[::1],
-        "is_new": numba.types.uint8[::1],
-        "i": numba.types.uint16,
-        "ic1": numba.types.uint16,
-        "ic2": numba.types.uint16,
-        "i_swap": numba.types.uint16,
-        "heap_size": numba.types.uint16,
-    }
-)
-def heap_push(heap, row, weight, index, flag):
-    """Push a new element onto the heap. The heap stores potential neighbors
-    for each graph_data point. The ``row`` parameter determines which graph_data point we
-    are addressing, the ``weight`` determines the distance (for heap sorting),
-    the ``index`` is the element to add, and the flag determines whether this
-    is to be considered a new addition.
-
-    Parameters
-    ----------
-    heap: ndarray generated by ``make_heap``
-        The heap object to push into
-
-    row: int
-        Which actual heap within the heap object to push to
-
-    weight: float
-        The priority value of the element to push onto the heap
-
-    index: int
-        The actual value to be pushed
-
-    flag: int
-        Whether to flag the newly added element or not.
-
-    Returns
-    -------
-    success: The number of new elements successfully pushed into the heap.
-    """
-    row = np.int32(row)
-    weight = np.float32(weight)
-    index = np.int32(index)
-    flag = np.uint8(flag)
-
-    indices = heap[0][row]
-    weights = heap[1][row]
-    is_new = heap[2][row]
-
-    if weight >= weights[0]:
-        return 0
-
-    # break if we already have this element.
-    for i in range(indices.shape[0]):
-        if index == indices[i]:
-            return 0
-
-    # insert val at position zero
-    weights[0] = weight
-    indices[0] = index
-    is_new[0] = flag
-
-    # descend the heap, swapping values until the max heap criterion is met
-    i = 0
-    while True:
-        ic1 = 2 * i + 1
-        ic2 = ic1 + 1
-
-        if ic1 >= indices.shape[0]:
-            break
-        elif ic2 >= indices.shape[0]:
-            if weights[ic1] > weight:
-                i_swap = ic1
-            else:
-                break
-        elif weights[ic1] >= weights[ic2]:
-            if weight < weights[ic1]:
-                i_swap = ic1
-            else:
-                break
-        else:
-            if weight < weights[ic2]:
-                i_swap = ic2
-            else:
-                break
-
-        weights[i] = weights[i_swap]
-        indices[i] = indices[i_swap]
-        is_new[i] = is_new[i_swap]
-
-        i = i_swap
-
-    weights[i] = weight
-    indices[i] = index
-    is_new[i] = flag
-
-    return 1
-
-
-@numba.jit(
-    locals={
-        "indices": numba.types.int32[::1],
-        "weights": numba.types.float32[::1],
-        "is_new": numba.types.uint8[::1],
-        "i": numba.types.uint16,
-        "ic1": numba.types.uint16,
-        "ic2": numba.types.uint16,
-        "i_swap": numba.types.uint16,
-        "heap_size": numba.types.uint16,
-    }
-)
-def unchecked_heap_push(heap, row, weight, index, flag):
-    """Push a new element onto the heap. The heap stores potential neighbors
-    for each graph_data point. The ``row`` parameter determines which graph_data point we
-    are addressing, the ``weight`` determines the distance (for heap sorting),
-    the ``index`` is the element to add, and the flag determines whether this
-    is to be considered a new addition.
-
-    Parameters
-    ----------
-    heap: ndarray generated by ``make_heap``
-        The heap object to push into
-
-    row: int
-        Which actual heap within the heap object to push to
-
-    weight: float
-        The priority value of the element to push onto the heap
-
-    index: int
-        The actual value to be pushed
-
-    flag: int
-        Whether to flag the newly added element or not.
-
-    Returns
-    -------
-    success: The number of new elements successfully pushed into the heap.
-    """
-    if weight >= heap[1][row, 0]:
-        return 0
-
-    indices = heap[0][row]
-    weights = heap[1][row]
-    is_new = heap[2][row]
-
-    # insert val at position zero
-    weights[0] = weight
-    indices[0] = index
-    is_new[0] = flag
-
-    heap_size = indices.shape[0]
-
-    # descend the heap, swapping values until the max heap criterion is met
-    i = 0
-    while True:
-        ic1 = 2 * i + 1
-        ic2 = ic1 + 1
-
-        if ic1 >= heap_size:
-            break
-        elif ic2 >= heap_size:
-            if weights[ic1] > weight:
-                i_swap = ic1
-            else:
-                break
-        elif weights[ic1] >= weights[ic2]:
-            if weight < weights[ic1]:
-                i_swap = ic1
-            else:
-                break
-        else:
-            if weight < weights[ic2]:
-                i_swap = ic2
-            else:
-                break
-
-        weights[i] = weights[i_swap]
-        indices[i] = indices[i_swap]
-        is_new[i] = is_new[i_swap]
-
-        i = i_swap
-
-    weights[i] = weight
-    indices[i] = index
-    is_new[i] = flag
-
-    return 1
-
-
-@numba.njit()
+@numba.njit(cache=True)
 def siftdown(heap1, heap2, elt):
     """Restore the heap property for a heap with an out of place element
     at position ``elt``. This works with a heap pair where heap1 carries
@@ -416,7 +223,7 @@ def siftdown(heap1, heap2, elt):
             elt = swap
 
 
-@numba.njit()
+@numba.njit(cache=True)
 def deheap_sort(heap):
     """Given an array of heaps (of graph_indices and weights), unpack the heap
     out to give and array of sorted lists of graph_indices and weights by increasing
@@ -460,51 +267,47 @@ def deheap_sort(heap):
     return indices.astype(np.int64), weights
 
 
-@numba.njit()
-def smallest_flagged(heap, row):
-    """Search the heap for the smallest element that is
-    still flagged.
-
-    Parameters
-    ----------
-    heap: array of shape (3, n_samples, n_neighbors)
-        The heaps to search
+# @numba.njit()
+# def smallest_flagged(heap, row):
+#     """Search the heap for the smallest element that is
+#     still flagged.
+#
+#     Parameters
+#     ----------
+#     heap: array of shape (3, n_samples, n_neighbors)
+#         The heaps to search
+#
+#     row: int
+#         Which of the heaps to search
+#
+#     Returns
+#     -------
+#     index: int
+#         The index of the smallest flagged element
+#         of the ``row``th heap, or -1 if no flagged
+#         elements remain in the heap.
+#     """
+#     ind = heap[0][row]
+#     dist = heap[1][row]
+#     flag = heap[2][row]
+#
+#     min_dist = np.inf
+#     result_index = -1
+#
+#     for i in range(ind.shape[0]):
+#         if flag[i] == 1 and dist[i] < min_dist:
+#             min_dist = dist[i]
+#             result_index = i
+#
+#     if result_index >= 0:
+#         flag[result_index] = 0.0
+#         return int(ind[result_index])
+#     else:
+#         return -1
 
-    row: int
-        Which of the heaps to search
 
-    Returns
-    -------
-    index: int
-        The index of the smallest flagged element
-        of the ``row``th heap, or -1 if no flagged
-        elements remain in the heap.
-    """
-    ind = heap[0][row]
-    dist = heap[1][row]
-    flag = heap[2][row]
-
-    min_dist = np.inf
-    result_index = -1
-
-    for i in range(ind.shape[0]):
-        if flag[i] == 1 and dist[i] < min_dist:
-            min_dist = dist[i]
-            result_index = i
-
-    if result_index >= 0:
-        flag[result_index] = 0.0
-        return int(ind[result_index])
-    else:
-        return -1
-
-
-@numba.njit(parallel=True, locals={"idx": numba.types.int64})
-def new_build_candidates(
-    current_graph,
-    max_candidates,
-    rng_state,
-):
+@numba.njit(parallel=True, locals={"idx": numba.types.int64}, cache=True)
+def new_build_candidates(current_graph, max_candidates, rng_state, n_threads):
     """Build a heap of candidate neighbors for nearest neighbor descent. For
     each vertex the candidate neighbors are any current neighbors, and any
     vertices that have the vertex as one of their nearest neighbors.
@@ -541,8 +344,6 @@ def new_build_candidates(
         (n_vertices, max_candidates), np.inf, dtype=np.float32
     )
 
-    n_threads = numba.get_num_threads()
-
     for n in numba.prange(n_threads):
         local_rng_state = rng_state + n
         for i in range(n_vertices):
@@ -558,10 +359,7 @@ def new_build_candidates(
                 if isn:
                     if i % n_threads == n:
                         checked_heap_push(
-                            new_candidate_priority[i],
-                            new_candidate_indices[i],
-                            d,
-                            idx,
+                            new_candidate_priority[i], new_candidate_indices[i], d, idx
                         )
                     if idx % n_threads == n:
                         checked_heap_push(
@@ -573,10 +371,7 @@ def new_build_candidates(
                 else:
                     if i % n_threads == n:
                         checked_heap_push(
-                            old_candidate_priority[i],
-                            old_candidate_indices[i],
-                            d,
-                            idx,
+                            old_candidate_priority[i], old_candidate_indices[i], d, idx
                         )
                     if idx % n_threads == n:
                         checked_heap_push(
@@ -601,14 +396,14 @@ def new_build_candidates(
     return new_candidate_indices, old_candidate_indices
 
 
-@numba.njit("b1(u1[::1],i4)")
+@numba.njit("b1(u1[::1],i4)", cache=True)
 def has_been_visited(table, candidate):
     loc = candidate >> 3
     mask = 1 << (candidate & 7)
     return table[loc] & mask
 
 
-@numba.njit("void(u1[::1],i4)")
+@numba.njit("void(u1[::1],i4)", cache=True)
 def mark_visited(table, candidate):
     loc = candidate >> 3
     mask = 1 << (candidate & 7)
@@ -626,6 +421,7 @@ def mark_visited(table, candidate):
         "ic2": numba.types.uint16,
         "i_swap": numba.types.uint16,
     },
+    cache=True,
 )
 def simple_heap_push(priorities, indices, p, n):
     if p >= priorities[0]:
@@ -682,6 +478,7 @@ def simple_heap_push(priorities, indices, p, n):
         "ic2": numba.types.uint16,
         "i_swap": numba.types.uint16,
     },
+    cache=True,
 )
 def checked_heap_push(priorities, indices, p, n):
     if p >= priorities[0]:
@@ -743,65 +540,7 @@ def checked_heap_push(priorities, indices, p, n):
         "ic2": numba.types.uint16,
         "i_swap": numba.types.uint16,
     },
-)
-def flagged_heap_push(priorities, indices, flags, p, n, f):
-    if p >= priorities[0]:
-        return 0
-
-    size = priorities.shape[0]
-
-    # insert val at position zero
-    priorities[0] = p
-    indices[0] = n
-    flags[0] = f
-
-    # descend the heap, swapping values until the max heap criterion is met
-    i = 0
-    while True:
-        ic1 = 2 * i + 1
-        ic2 = ic1 + 1
-
-        if ic1 >= size:
-            break
-        elif ic2 >= size:
-            if priorities[ic1] > p:
-                i_swap = ic1
-            else:
-                break
-        elif priorities[ic1] >= priorities[ic2]:
-            if p < priorities[ic1]:
-                i_swap = ic1
-            else:
-                break
-        else:
-            if p < priorities[ic2]:
-                i_swap = ic2
-            else:
-                break
-
-        priorities[i] = priorities[i_swap]
-        indices[i] = indices[i_swap]
-        flags[i] = flags[i_swap]
-
-        i = i_swap
-
-    priorities[i] = p
-    indices[i] = n
-    flags[i] = f
-
-    return 1
-
-
-@numba.njit(
-    "i4(f4[::1],i4[::1],u1[::1],f4,i4,u1)",
-    fastmath=True,
-    locals={
-        "size": numba.types.intp,
-        "i": numba.types.uint16,
-        "ic1": numba.types.uint16,
-        "ic2": numba.types.uint16,
-        "i_swap": numba.types.uint16,
-    },
+    cache=True,
 )
 def checked_flagged_heap_push(priorities, indices, flags, p, n, f):
     if p >= priorities[0]:
@@ -867,14 +606,15 @@ def checked_flagged_heap_push(priorities, indices, flags, p, n, f):
         "i": numba.uint32,
         "j": numba.uint32,
     },
+    cache=True,
 )
-def apply_graph_updates_low_memory(current_graph, updates):
+def apply_graph_updates_low_memory(current_graph, updates, n_threads):
 
     n_changes = 0
     priorities = current_graph[1]
     indices = current_graph[0]
     flags = current_graph[2]
-    n_threads = numba.get_num_threads()
+    # n_threads = numba.get_num_threads()
 
     for n in numba.prange(n_threads):
         for i in range(len(updates)):
@@ -885,33 +625,21 @@ def apply_graph_updates_low_memory(current_graph, updates):
                     continue
 
                 if p % n_threads == n:
-                    # added = heap_push(current_graph, p, d, q, 1)
                     added = checked_flagged_heap_push(
-                        priorities[p],
-                        indices[p],
-                        flags[p],
-                        d,
-                        q,
-                        1,
+                        priorities[p], indices[p], flags[p], d, q, 1
                     )
                     n_changes += added
 
                 if q % n_threads == n:
-                    # added = heap_push(current_graph, q, d, p, 1)
                     added = checked_flagged_heap_push(
-                        priorities[q],
-                        indices[q],
-                        flags[q],
-                        d,
-                        p,
-                        1,
+                        priorities[q], indices[q], flags[q], d, p, 1
                     )
                     n_changes += added
 
     return n_changes
 
 
-@numba.njit(locals={"p": numba.types.int64, "q": numba.types.int64})
+@numba.njit(locals={"p": numba.types.int64, "q": numba.types.int64}, cache=True)
 def apply_graph_updates_high_memory(current_graph, updates, in_graph):
 
     n_changes = 0
@@ -928,8 +656,7 @@ def apply_graph_updates_high_memory(current_graph, updates, in_graph):
             elif q in in_graph[p]:
                 pass
             else:
-                # added = unchecked_heap_push(current_graph, p, d, q, 1)
-                added = flagged_heap_push(
+                added = checked_flagged_heap_push(
                     current_graph[1][p],
                     current_graph[0][p],
                     current_graph[2][p],
@@ -945,8 +672,7 @@ def apply_graph_updates_high_memory(current_graph, updates, in_graph):
             if p == q or p in in_graph[q]:
                 pass
             else:
-                # added = unchecked_heap_push(current_graph, q, d, p, 1)
-                added = flagged_heap_push(
+                added = checked_flagged_heap_push(
                     current_graph[1][p],
                     current_graph[0][p],
                     current_graph[2][p],
@@ -962,7 +688,7 @@ def apply_graph_updates_high_memory(current_graph, updates, in_graph):
     return n_changes
 
 
-@numba.njit()
+@numba.njit(cache=True)
 def initalize_heap_from_graph_indices(heap, graph_indices, data, metric):
 
     for i in range(graph_indices.shape[0]):
@@ -970,12 +696,12 @@ def initalize_heap_from_graph_indices(heap, graph_indices, data, metric):
             j = graph_indices[i, idx]
             if j >= 0:
                 d = metric(data[i], data[j])
-                flagged_heap_push(heap[1][i], heap[0][i], heap[2][i], d, j, 1)
+                checked_flagged_heap_push(heap[1][i], heap[0][i], heap[2][i], d, j, 1)
 
     return heap
 
 
-@numba.njit(parallel=True)
+@numba.njit(parallel=True, cache=True)
 def sparse_initalize_heap_from_graph_indices(
     heap, graph_indices, data_indptr, data_indices, data_vals, metric
 ):
@@ -988,8 +714,7 @@ def sparse_initalize_heap_from_graph_indices(
             ind2 = data_indices[data_indptr[j] : data_indptr[j + 1]]
             data2 = data_vals[data_indptr[j] : data_indptr[j + 1]]
             d = metric(ind1, data1, ind2, data2)
-            # unchecked_heap_push(heap, i, d, j, 1)
-            flagged_heap_push(heap[0][i], heap[1][i], heap[2][i], j, d, 1)
+            checked_flagged_heap_push(heap[1][i], heap[0][i], heap[2][i], d, j, 1)
 
     return heap
 
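In `utils.py`, the removed `heap_push` and `flagged_heap_push` and the surviving `checked_flagged_heap_push` all keep, per point, a fixed-size max-heap of candidate neighbors. The root holds the current worst distance. A new candidate is admitted only if it beats that root, and it is then sifted down until each parent is at least as large as its children.

The plain-Python sketch below restates that logic. It assumes, as the removed `heap_push` does, that the "checked" variant also rejects an index already present in the row. `deheap_sort_row` is a hypothetical single-row analogue of `deheap_sort`, which unpacks heaps into increasing distance order. None of these names are pynndescent API, and the upstream versions are numba-compiled.

```python
import numpy as np


def checked_flagged_push(priorities, indices, flags, p, n, f):
    """Hypothetical pure-Python bounded max-heap push with a duplicate check.

    Slot 0 always holds the largest distance, so a candidate is admitted only
    if it beats the current worst neighbor and is not already present.
    Returns 1 if the heap changed, else 0.
    """
    if p >= priorities[0]:
        return 0
    size = priorities.shape[0]
    # Reject an index that is already stored in this row.
    for i in range(size):
        if indices[i] == n:
            return 0

    # Treat the root as a hole and sift the new element down, pulling the
    # larger child up each time, until the max-heap property holds again.
    i = 0
    while True:
        ic1 = 2 * i + 1
        ic2 = ic1 + 1
        if ic1 >= size:
            break
        elif ic2 >= size:
            if priorities[ic1] > p:
                i_swap = ic1
            else:
                break
        elif priorities[ic1] >= priorities[ic2]:
            if p < priorities[ic1]:
                i_swap = ic1
            else:
                break
        else:
            if p < priorities[ic2]:
                i_swap = ic2
            else:
                break
        priorities[i] = priorities[i_swap]
        indices[i] = indices[i_swap]
        flags[i] = flags[i_swap]
        i = i_swap

    priorities[i] = p
    indices[i] = n
    flags[i] = f
    return 1


def sift_down(priorities, indices, elt, size):
    """Restore the max-heap property of priorities[:size] starting at elt."""
    while 2 * elt + 1 < size:
        left = 2 * elt + 1
        right = left + 1
        swap = elt
        if priorities[swap] < priorities[left]:
            swap = left
        if right < size and priorities[swap] < priorities[right]:
            swap = right
        if swap == elt:
            break
        priorities[elt], priorities[swap] = priorities[swap], priorities[elt]
        indices[elt], indices[swap] = indices[swap], indices[elt]
        elt = swap


def deheap_sort_row(priorities, indices):
    """Heap-sort one row in place into increasing distance order."""
    for j in range(priorities.shape[0] - 1, 0, -1):
        # Move the current maximum to the end, then repair the shrunken heap.
        priorities[0], priorities[j] = priorities[j], priorities[0]
        indices[0], indices[j] = indices[j], indices[0]
        sift_down(priorities, indices, 0, j)
```

For example, with `k = 3` slots initialised to `inf`, pushing `(5.0, 10)`, `(1.0, 11)`, `(4.0, 12)` and `(2.0, 13)` keeps the three smallest distances. Pushing `(1.0, 11)` again returns 0 because index 11 is already stored. `(9.0, 14)` also returns 0 because it is worse than the current root.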

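`has_been_visited` and `mark_visited` in `utils.py` treat a `uint8` array as a bitset with one bit per point. `candidate >> 3` selects the byte and `1 << (candidate & 7)` selects the bit within it, so the table needs about one eighth of the memory of a boolean array.

The sketch below mirrors that indexing in plain Python. The diff elides the last line of `mark_visited`, so the in-place OR that sets the bit is an assumption. `make_visited_table`, `has_been_visited_py` and `mark_visited_py` are hypothetical names.

```python
import numpy as np


def make_visited_table(n_points):
    """Hypothetical helper: one bit per point, ceil(n_points / 8) bytes, all clear."""
    return np.zeros((n_points + 7) // 8, dtype=np.uint8)


def has_been_visited_py(table, candidate):
    loc = candidate >> 3          # which byte holds this point's bit
    mask = 1 << (candidate & 7)   # which bit inside that byte
    return bool(table[loc] & mask)


def mark_visited_py(table, candidate):
    loc = candidate >> 3
    mask = 1 << (candidate & 7)
    table[loc] |= mask            # assumed: set the bit in place
```

For instance, marking point 13 sets bit 5 of byte 1 (13 >> 3 == 1, 13 & 7 == 5) and leaves every other point unmarked.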

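This release changes `new_build_candidates` and `apply_graph_updates_low_memory` to receive `n_threads` from their caller instead of calling `numba.get_num_threads()` internally. Both functions split work by ownership. Inside `numba.prange(n_threads)`, worker `n` only writes to rows `r` with `r % n_threads == n`, so no two workers modify the same heap.

The sketch below imitates that pattern with `concurrent.futures` and plain lists instead of numba and heaps. It shows the partitioning, not the speed-up. `apply_updates_partitioned` is a hypothetical name.

```python
from concurrent.futures import ThreadPoolExecutor


def apply_updates_partitioned(n_rows, updates, n_threads):
    """Apply symmetric (p, q, d) updates with row ownership by modulo.

    Worker ``n`` only writes to rows ``r`` with ``r % n_threads == n``, so no
    two workers ever touch the same row and no locking is needed.
    """
    rows = [[] for _ in range(n_rows)]

    def worker(n):
        changes = 0
        for p, q, d in updates:
            if p % n_threads == n:      # this worker owns row p
                rows[p].append((d, q))
                changes += 1
            if q % n_threads == n:      # this worker owns row q
                rows[q].append((d, p))
                changes += 1
        return changes

    with ThreadPoolExecutor(max_workers=n_threads) as pool:
        n_changes = sum(pool.map(worker, range(n_threads)))
    return rows, n_changes
```

Each row is still updated in the original order of `updates`, so the result does not depend on `n_threads`. Running the sketch with one worker or with three produces identical rows.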
=====================================
requirements.txt
=====================================
@@ -1,4 +1,5 @@
 joblib
+numpy>=1.17
 scikit-learn>=0.18
 scipy>=1.0
 numba>=0.51.2


=====================================
setup.py
=====================================
@@ -8,7 +8,7 @@ def readme():
 
 configuration = {
     "name": "pynndescent",
-    "version": "0.5.2",
+    "version": "0.5.5",
     "description": "Nearest Neighbor Descent",
     "long_description": readme(),
     "classifiers": [



View it on GitLab: https://salsa.debian.org/python-team/packages/python-pynndescent/-/commit/7d135bb3a0913bb5fe61725a2f41d870b2a18c3f


More information about the debian-med-commit mailing list