[Git][debian-gis-team/flox][upstream] New upstream version 0.10.2

Antonio Valentino (@antonio.valentino) gitlab at salsa.debian.org
Mon Apr 7 06:55:04 BST 2025



Antonio Valentino pushed to branch upstream at Debian GIS Project / flox


Commits:
068a10f3 by Antonio Valentino at 2025-04-07T05:25:43+00:00
New upstream version 0.10.2
- - - - -


15 changed files:

- asv_bench/benchmarks/combine.py
- ci/docs.yml
- docs/source/user-stories.md
- + docs/source/user-stories/large-zonal-stats.ipynb
- flox/__init__.py
- flox/aggregate_flox.py
- flox/core.py
- flox/dask_array_ops.py
- flox/xarray.py
- flox/xrutils.py
- pyproject.toml
- tests/__init__.py
- tests/conftest.py
- tests/test_core.py
- tests/test_xarray.py


Changes:

=====================================
asv_bench/benchmarks/combine.py
=====================================
@@ -14,7 +14,11 @@ def _get_combine(combine):
     if combine == "grouped":
         return partial(flox.core._grouped_combine, engine="numpy")
     else:
-        return partial(flox.core._simple_combine, reindex=False)
+        try:
+            reindex = flox.ReindexStrategy(blockwise=False)
+        except AttributeError:
+            reindex = False
+        return partial(flox.core._simple_combine, reindex=reindex)
 
 
 class Combine:
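
The try/except above is a feature-detection guard: ``ReindexStrategy`` is only exported by the newer flox API (see the flox/__init__.py change below), so older releases raise AttributeError and the benchmark falls back to the boolean form. A minimal sketch of the same guard using hasattr, assuming only that the attribute is absent on older releases:

    import flox

    # probe for the new strategy API before using it; fall back to the
    # boolean `reindex` accepted by older flox releases
    if hasattr(flox, "ReindexStrategy"):
        reindex = flox.ReindexStrategy(blockwise=False)
    else:
        reindex = False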


=====================================
ci/docs.yml
=====================================
@@ -15,6 +15,7 @@ dependencies:
   - matplotlib-base
   - myst-parser
   - myst-nb
+  - sparse
   - sphinx
   - sphinx-remove-toctrees
   - furo>=2024.08


=====================================
docs/source/user-stories.md
=====================================
@@ -10,4 +10,5 @@
    user-stories/climatology-hourly-cubed.ipynb
    user-stories/custom-aggregations.ipynb
    user-stories/nD-bins.ipynb
+   user-stories/large-zonal-stats.ipynb
 ```


=====================================
docs/source/user-stories/large-zonal-stats.ipynb
=====================================
@@ -0,0 +1,194 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "0",
+   "metadata": {},
+   "source": [
+    "# Large Raster Zonal Statistics\n",
+    "\n",
+    "\"Zonal statistics\" spans a large range of problems. \n",
+    "\n",
+    "This one is inspired by [this issue](https://github.com/xarray-contrib/flox/issues/428), where a cell areas raster is aggregated over 6 different groupers and summed. Each array involved has shape 560_000 x 1440_000 and chunk size 10_000 x 10_000. Three of the groupers `tcl_year`, `drivers`, and `tcd_thresholds` have a small number of group labels (23, 5, and 7). \n",
+    "\n",
+    "The last 3 groupers are [GADM](https://gadm.org/) level 0, 1, 2 administrative area polygons rasterized to this grid; with 248, 86, and 854 unique labels respectively (arrays `adm0`, `adm1`, and `adm2`). These correspond to country-level, state-level, and county-level administrative boundaries. "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1",
+   "metadata": {},
+   "source": [
+    "## Example dataset"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2",
+   "metadata": {},
+   "source": [
+    "Here is a representative version of the dataset (in terms of size and chunk sizes)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import dask.array\n",
+    "import numpy as np\n",
+    "import xarray as xr\n",
+    "\n",
+    "from flox.xarray import xarray_reduce\n",
+    "\n",
+    "sizes = {\"y\": 560_000, \"x\": 1440_000}\n",
+    "chunksizes = {\"y\": 2_000, \"x\": 2_000}\n",
+    "dims = (\"y\", \"x\")\n",
+    "shape = tuple(sizes[d] for d in dims)\n",
+    "chunks = tuple(chunksizes[d] for d in dims)\n",
+    "\n",
+    "ds = xr.Dataset(\n",
+    "    {\n",
+    "        \"areas\": (dims, dask.array.ones(shape, chunks=chunks, dtype=np.float32)),\n",
+    "        \"tcl_year\": (\n",
+    "            dims,\n",
+    "            1 + dask.array.zeros(shape, chunks=chunks, dtype=np.float32),\n",
+    "        ),\n",
+    "        \"drivers\": (dims, 2 + dask.array.zeros(shape, chunks=chunks, dtype=np.float32)),\n",
+    "        \"tcd_thresholds\": (\n",
+    "            dims,\n",
+    "            3 + dask.array.zeros(shape, chunks=chunks, dtype=np.float32),\n",
+    "        ),\n",
+    "        \"adm0\": (dims, 4 + dask.array.ones(shape, chunks=chunks, dtype=np.float32)),\n",
+    "        \"adm1\": (dims, 5 + dask.array.zeros(shape, chunks=chunks, dtype=np.float32)),\n",
+    "        \"adm2\": (dims, 6 + dask.array.zeros(shape, chunks=chunks, dtype=np.float32)),\n",
+    "    }\n",
+    ")\n",
+    "ds"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4",
+   "metadata": {},
+   "source": [
+    "## Zonal Statistics"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5",
+   "metadata": {},
+   "source": [
+    "Next define the grouper arrays and expected group labels"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "by = (ds.tcl_year, ds.drivers, ds.tcd_thresholds, ds.adm0, ds.adm1, ds.adm2)\n",
+    "expected_groups = (\n",
+    "    np.arange(23),\n",
+    "    np.arange(1, 6),\n",
+    "    np.arange(1, 8),\n",
+    "    np.arange(248),\n",
+    "    np.arange(86),\n",
+    "    np.arange(854),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result = xarray_reduce(\n",
+    "    ds.areas,\n",
+    "    *by,\n",
+    "    expected_groups=expected_groups,\n",
+    "    func=\"sum\",\n",
+    ")\n",
+    "result"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8",
+   "metadata": {},
+   "source": [
+    "Formulating the three admin levels as orthogonal dimensions is quite wasteful --- not all countries have 86 states or 854 counties per state. \n",
+    "\n",
+    "We end up with one humoungous 56GB chunk, that is mostly empty.\n",
+    "\n",
+    "## We can do better using a sparse array\n",
+    "\n",
+    "Since the results are very sparse, we can instruct flox to constructing dense arrays of intermediate results on the full 23 x 5 x 7 x 248 x 86 x 854 output grid.\n",
+    "\n",
+    "```python\n",
+    "ReindexStrategy(\n",
+    "    # do not reindex to the full output grid at the blockwise aggregation stage\n",
+    "    blockwise=False,\n",
+    "    # when combining intermediate results after blockwise aggregation, reindex to the\n",
+    "    # common grid using a sparse.COO array type\n",
+    "    array_type=ReindexArrayType.SPARSE_COO\n",
+    ")\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from flox import ReindexArrayType, ReindexStrategy\n",
+    "\n",
+    "result = xarray_reduce(\n",
+    "    ds.areas,\n",
+    "    *by,\n",
+    "    expected_groups=expected_groups,\n",
+    "    func=\"sum\",\n",
+    "    reindex=ReindexStrategy(\n",
+    "        blockwise=False,\n",
+    "        array_type=ReindexArrayType.SPARSE_COO,\n",
+    "    ),\n",
+    ")\n",
+    "result"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "10",
+   "metadata": {},
+   "source": [
+    "The output is a sparse array (see the **Data type** section)! Note that the size of this array cannot be estimated without computing it.\n",
+    "\n",
+    "The computation runs smoothly with low memory."
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
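
As a quick sanity check on the "humongous chunk" claim in the notebook above, the dense output grid implied by these groupers can be sized directly; a sketch, assuming the float32 dtype used in the example:

    import math

    grp_shape = (23, 5, 7, 248, 86, 854)
    nelems = math.prod(grp_shape)       # 14_662_360_160 elements
    nbytes = nelems * 4                 # float32 -> ~58.6e9 bytes, about 54.6 GiB
    print(f"{nbytes / 2**30:.1f} GiB")  # in line with the ~56GB figure quoted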


=====================================
flox/__init__.py
=====================================
@@ -9,6 +9,8 @@ from .core import (
     groupby_scan,
     rechunk_for_blockwise,
     rechunk_for_cohorts,
+    ReindexStrategy,
+    ReindexArrayType,
 )  # noqa
 
 


=====================================
flox/aggregate_flox.py
=====================================
@@ -261,7 +261,8 @@ def ffill(group_idx, array, *, axis, **kwargs):
     (group_starts,) = flag.nonzero()
 
     # https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array
-    mask = isnull(array)
+    # copy needed since isnull() may return a read-only broadcast-trick array
+    mask = isnull(array).copy()
     # modified from the SO answer, just reset the index at the start of every group!
     mask[..., np.asarray(group_starts)] = False
 


=====================================
flox/core.py
=====================================
@@ -11,6 +11,8 @@ import warnings
 from collections import namedtuple
 from collections.abc import Callable, Sequence
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
+from enum import Enum, auto
 from functools import partial, reduce
 from itertools import product
 from numbers import Integral
@@ -22,6 +24,7 @@ from typing import (
     TypedDict,
     TypeVar,
     Union,
+    cast,
     overload,
 )
 
@@ -118,13 +121,81 @@ DUMMY_AXIS = -2
 logger = logging.getLogger("flox")
 
 
+class ReindexArrayType(Enum):
+    """
+    Enum describing which array type to reindex to.
+
+    These are enumerated, rather than accepting a constructor,
+    because we might want to optimize for specific array types,
+    and because they don't necessarily have the same signature.
+
+    For example, scipy.sparse.coo_array only supports a fill_value of 0.
+    """
+
+    AUTO = auto()
+    NUMPY = auto()
+    SPARSE_COO = auto()
+    # Sadly, scipy.sparse.coo_array only supports fill_value = 0
+    # SCIPY_SPARSE_COO = auto()
+    # SPARSE_GCXS = auto()
+
+    def is_same_type(self, other) -> bool:
+        match self:
+            case ReindexArrayType.AUTO:
+                return True
+            case ReindexArrayType.NUMPY:
+                return isinstance(other, np.ndarray)
+            case ReindexArrayType.SPARSE_COO:
+                import sparse
+
+                return isinstance(other, sparse.COO)
+
+
+@dataclass
+class ReindexStrategy:
+    """
+    Strategy for reindexing.
+
+    Attributes
+    ----------
+    blockwise: bool, optional
+        Whether to reindex at the blockwise step. Must be False for method="cohorts".
+    array_type: ReindexArrayType, optional
+        Whether to reindex to a different array type than the array being reduced.
+    """
+
+    # whether to reindex at the blockwise step
+    blockwise: bool | None
+    array_type: ReindexArrayType = ReindexArrayType.AUTO
+
+    def __post_init__(self):
+        if self.blockwise is True:
+            if self.array_type not in (ReindexArrayType.AUTO, ReindexArrayType.NUMPY):
+                raise ValueError("Setting reindex.blockwise=True not allowed for non-numpy array type.")
+
+    def set_blockwise_for_numpy(self):
+        self.blockwise = True if self.blockwise is None else self.blockwise
+
+    def get_dask_meta(self, other, *, fill_value, dtype) -> Any:
+        import dask
+
+        if self.array_type is ReindexArrayType.AUTO:
+            other_type = type(other._meta) if isinstance(other, dask.array.Array) else type(other)
+            return other_type([], dtype=dtype)
+        elif self.array_type is ReindexArrayType.NUMPY:
+            return np.ndarray([], dtype=dtype)
+        elif self.array_type is ReindexArrayType.SPARSE_COO:
+            import sparse
+
+            return sparse.COO.from_numpy(np.ones(shape=(0,) * other.ndim, dtype=dtype), fill_value=fill_value)
+
+
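
Note the __post_init__ guard above: reindexing to the full grid at the blockwise step only makes sense with dense numpy output, so a sparse target requires deferring the reindex. A sketch of the resulting API contract:

    from flox import ReindexArrayType, ReindexStrategy

    # allowed: defer reindexing, then reindex combined intermediates to sparse
    ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO)

    try:
        # blockwise reindexing is only allowed for AUTO/NUMPY array types
        ReindexStrategy(blockwise=True, array_type=ReindexArrayType.SPARSE_COO)
    except ValueError as e:
        print(e)  # "Setting reindex.blockwise=True not allowed for non-numpy array type."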
 class FactorizeKwargs(TypedDict, total=False):
     """Used in _factorize_multiple"""
 
     by: T_Bys
     axes: T_Axes
     fastpath: bool
-    expected_groups: T_ExpectIndexOptTuple | None
     reindex: bool
     sort: bool
 
@@ -665,10 +736,50 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
         return array.rechunk({axis: newchunks})
 
 
+def reindex_numpy(array, from_, to, fill_value, dtype, axis):
+    idx = from_.get_indexer(to)
+    indexer = [slice(None, None)] * array.ndim
+    indexer[axis] = idx
+    reindexed = array[tuple(indexer)]
+    if any(idx == -1):
+        if fill_value is None:
+            raise ValueError("Filling is required. fill_value cannot be None.")
+        indexer[axis] = idx == -1
+        reindexed = reindexed.astype(dtype, copy=False)
+        reindexed[tuple(indexer)] = fill_value
+    return reindexed
+
+
+def reindex_pydata_sparse_coo(array, from_, to, fill_value, dtype, axis):
+    import sparse
+
+    assert axis == -1
+
+    if fill_value is None:
+        raise ValueError("Filling is required. fill_value cannot be None.")
+    idx = to.get_indexer(from_)
+    assert (idx != -1).all()  # FIXME
+    shape = array.shape
+    ranges = np.broadcast_arrays(*np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx,))))
+    coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)
+
+    data = array.data if isinstance(array, sparse.COO) else array.reshape(-1)
+
+    reindexed = sparse.COO(
+        coords=coords,
+        data=data.astype(dtype, copy=False),
+        shape=(*array.shape[:axis], to.size),
+    )
+
+    return reindexed
+
+
 def reindex_(
     array: np.ndarray,
     from_,
     to,
+    *,
+    array_type: ReindexArrayType = ReindexArrayType.AUTO,
     fill_value: Any = None,
     axis: T_Axis = -1,
     promote: bool = False,
@@ -689,7 +800,7 @@ def reindex_(
 
     from_ = pd.Index(from_)
     # short-circuit for trivial case
-    if from_.equals(to):
+    if from_.equals(to) and array_type.is_same_type(array):
         return array
 
     if from_.dtype.kind == "O" and isinstance(from_[0], tuple):
@@ -697,19 +808,21 @@ def reindex_(
             "Currently does not support reindexing with object arrays of tuples. "
             "These occur when grouping by multi-indexed variables in xarray."
         )
-    idx = from_.get_indexer(to)
-    indexer = [slice(None, None)] * array.ndim
-    indexer[axis] = idx
-    reindexed = array[tuple(indexer)]
-    if any(idx == -1):
-        if fill_value is None:
-            raise ValueError("Filling is required. fill_value cannot be None.")
-        indexer[axis] = idx == -1
-        # This allows us to match xarray's type promotion rules
-        if fill_value is xrdtypes.NA or isnull(fill_value):
-            new_dtype, fill_value = xrdtypes.maybe_promote(reindexed.dtype)
-            reindexed = reindexed.astype(new_dtype, copy=False)
-        reindexed[tuple(indexer)] = fill_value
+    if fill_value is xrdtypes.NA or isnull(fill_value):
+        new_dtype, fill_value = xrdtypes.maybe_promote(array.dtype)
+    else:
+        new_dtype = array.dtype
+
+    if array_type is ReindexArrayType.AUTO:
+        # TODO: generalize here
+        # Right now, we effectively assume NEP-18 I think
+        # assert isinstance(array, np.ndarray)
+        array_type = ReindexArrayType.NUMPY
+
+    if array_type is ReindexArrayType.NUMPY:
+        reindexed = reindex_numpy(array, from_, to, fill_value, new_dtype, axis)
+    elif array_type is ReindexArrayType.SPARSE_COO:
+        reindexed = reindex_pydata_sparse_coo(array, from_, to, fill_value, new_dtype, axis)
     return reindexed
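
For the numpy path, reindex_numpy relies on pandas.Index.get_indexer returning -1 for labels in `to` that are missing from `from_`; those slots are then overwritten with fill_value. A small worked sketch with hypothetical labels:

    import numpy as np
    import pandas as pd

    from_ = pd.Index([2, 3, 5])            # groups found in a block
    to = pd.Index([1, 2, 3, 4, 5])         # full expected grid
    array = np.array([10.0, 20.0, 30.0])   # one value per found group

    idx = from_.get_indexer(to)            # array([-1, 0, 1, -1, 2])
    reindexed = array[idx]                 # -1 picks the last element...
    reindexed[idx == -1] = 0.0             # ...then missing slots get fill_value
    print(reindexed)                       # [ 0. 10. 20.  0. 30.]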
 
 
@@ -731,6 +844,67 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
     return offset, size
 
 
+def _factorize_single(by, expect, *, sort: bool, reindex: bool) -> tuple[pd.Index, np.ndarray]:
+    flat = by.reshape(-1)
+    if isinstance(expect, pd.RangeIndex):
+        # idx is a view of the original `by` array
+        # copy here so we don't have a race condition with the
+        # group_idx[nanmask] = nan_sentinel assignment later
+        # this is important in shared-memory parallelism with dask
+        # TODO: figure out how to avoid this
+        idx = flat.copy()
+        found_groups = cast(pd.Index, expect)
+        # TODO: fix by using masked integers
+        idx[idx > expect[-1]] = -1
+
+    elif isinstance(expect, pd.IntervalIndex):
+        if expect.closed == "both":
+            raise NotImplementedError
+        bins = np.concatenate([expect.left.to_numpy(), expect.right.to_numpy()[[-1]]])
+
+        # digitize is 0 or idx.max() for values outside the bounds of all intervals
+        # make it behave like pd.cut which uses -1:
+        if len(bins) > 1:
+            right = expect.closed_right
+            idx = np.digitize(
+                flat,
+                bins=bins.view(np.int64) if bins.dtype.kind == "M" else bins,
+                right=right,
+            )
+            idx -= 1
+            within_bins = flat <= bins.max() if right else flat < bins.max()
+            idx[~within_bins] = -1
+        else:
+            idx = np.zeros_like(flat, dtype=np.intp) - 1
+        found_groups = cast(pd.Index, expect)
+    else:
+        if expect is not None and reindex:
+            sorter = np.argsort(expect)
+            groups = expect[(sorter,)] if sort else expect
+            idx = np.searchsorted(expect, flat, sorter=sorter)
+            mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect))
+            if not sort:
+                # idx is the index in to the sorted array.
+                # if we didn't want sorting, unsort it back
+                idx[(idx == len(expect),)] = -1
+                idx = sorter[(idx,)]
+            idx[mask] = -1
+        else:
+            idx, groups = pd.factorize(flat, sort=sort)
+        found_groups = cast(pd.Index, groups)
+
+    return (found_groups, idx.reshape(by.shape))
+
+
+def _ravel_factorized(*factorized: np.ndarray, grp_shape: tuple[int, ...]) -> np.ndarray:
+    group_idx = np.ravel_multi_index(factorized, grp_shape, mode="wrap")
+    # NaNs; as well as values outside the bins are coded by -1
+    # Restore these after the raveling
+    nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized])
+    group_idx[nan_by_mask] = -1
+    return group_idx
+
+
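
_ravel_factorized uses mode="wrap" so that the -1 codes for missing values pass through np.ravel_multi_index without raising; the mask then restores them to -1. A small worked sketch:

    import numpy as np

    codes0 = np.array([0, 1, -1])
    codes1 = np.array([2, -1, 0])
    grp_shape = (2, 3)

    flat = np.ravel_multi_index((codes0, codes1), grp_shape, mode="wrap")
    mask = (codes0 == -1) | (codes1 == -1)
    flat[mask] = -1   # restore the missing-value code after raveling
    print(flat)       # [ 2 -1 -1]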
 @overload
 def factorize_(
     by: T_Bys,
@@ -740,7 +914,7 @@ def factorize_(
     expected_groups: T_ExpectIndexOptTuple | None = None,
     reindex: bool = False,
     sort: bool = True,
-) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, None]: ...
+) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, None]: ...
 
 
 @overload
@@ -752,7 +926,7 @@ def factorize_(
     reindex: bool = False,
     sort: bool = True,
     fastpath: Literal[False] = False,
-) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, FactorProps]: ...
+) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps]: ...
 
 
 @overload
@@ -764,7 +938,7 @@ def factorize_(
     reindex: bool = False,
     sort: bool = True,
     fastpath: bool = False,
-) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, FactorProps | None]: ...
+) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps | None]: ...
 
 
 def factorize_(
@@ -775,9 +949,9 @@ def factorize_(
     reindex: bool = False,
     sort: bool = True,
     fastpath: bool = False,
-) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, FactorProps | None]:
+) -> tuple[np.ndarray, tuple[pd.Index, ...], tuple[int, ...], int, int, FactorProps | None]:
     """
-    Returns an array of integer  codes  for groups (and associated data)
+    Returns an array of integer codes for groups (and associated data)
     by wrapping pd.cut and pd.factorize (depending on isbin).
     This method handles reindex and sort so that we don't spend time reindexing / sorting
     a possibly large results array. Instead we set up the appropriate integer codes (group_idx)
@@ -786,75 +960,32 @@ def factorize_(
     if expected_groups is None:
         expected_groups = (None,) * len(by)
 
-    factorized = []
-    found_groups = []
-    for groupvar, expect in zip(by, expected_groups):
-        flat = groupvar.reshape(-1)
-        if isinstance(expect, pd.RangeIndex):
-            # idx is a view of the original `by` array
-            # copy here so we don't have a race condition with the
-            # group_idx[nanmask] = nan_sentinel assignment later
-            # this is important in shared-memory parallelism with dask
-            # TODO: figure out how to avoid this
-            idx = flat.copy()
-            found_groups.append(np.array(expect))
-            # TODO: fix by using masked integers
-            idx[idx > expect[-1]] = -1
-
-        elif isinstance(expect, pd.IntervalIndex):
-            if expect.closed == "both":
-                raise NotImplementedError
-            bins = np.concatenate([expect.left.to_numpy(), expect.right.to_numpy()[[-1]]])
-
-            # digitize is 0 or idx.max() for values outside the bounds of all intervals
-            # make it behave like pd.cut which uses -1:
-            if len(bins) > 1:
-                right = expect.closed_right
-                idx = np.digitize(
-                    flat,
-                    bins=bins.view(np.int64) if bins.dtype.kind == "M" else bins,
-                    right=right,
-                )
-                idx -= 1
-                within_bins = flat <= bins.max() if right else flat < bins.max()
-                idx[~within_bins] = -1
-            else:
-                idx = np.zeros_like(flat, dtype=np.intp) - 1
-
-            found_groups.append(np.array(expect))
-        else:
-            if expect is not None and reindex:
-                sorter = np.argsort(expect)
-                groups = expect[(sorter,)] if sort else expect
-                idx = np.searchsorted(expect, flat, sorter=sorter)
-                mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect))
-                if not sort:
-                    # idx is the index in to the sorted array.
-                    # if we didn't want sorting, unsort it back
-                    idx[(idx == len(expect),)] = -1
-                    idx = sorter[(idx,)]
-                idx[mask] = -1
-            else:
-                idx, groups = pd.factorize(flat, sort=sort)
-
-            found_groups.append(np.array(groups))
-        factorized.append(idx.reshape(groupvar.shape))
+    if len(by) > 2:
+        with ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(partial(_factorize_single, sort=sort, reindex=reindex), groupvar, expect)
+                for groupvar, expect in zip(by, expected_groups)
+            ]
+            results = tuple(f.result() for f in futures)
+    else:
+        results = tuple(
+            _factorize_single(groupvar, expect, sort=sort, reindex=reindex)
+            for groupvar, expect in zip(by, expected_groups)
+        )
+    found_groups = tuple(r[0] for r in results)
+    factorized = [r[1] for r in results]
 
     grp_shape = tuple(len(grp) for grp in found_groups)
     ngroups = math.prod(grp_shape)
     if len(by) > 1:
-        group_idx = np.ravel_multi_index(factorized, grp_shape, mode="wrap")
-        # NaNs; as well as values outside the bins are coded by -1
-        # Restore these after the raveling
-        nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized])
-        group_idx[nan_by_mask] = -1
+        group_idx = _ravel_factorized(*factorized, grp_shape=grp_shape)
     else:
-        group_idx = factorized[0]
+        (group_idx,) = factorized
 
     if fastpath:
-        return group_idx, tuple(found_groups), grp_shape, ngroups, ngroups, None
+        return group_idx, found_groups, grp_shape, ngroups, ngroups, None
 
-    if len(axes) == 1 and groupvar.ndim > 1:
+    if len(axes) == 1 and by[0].ndim > 1:
         # Not reducing along all dimensions of by
         # this is OK because for 3D by and axis=(1,2),
         # we collapse to a 2D by and axis=-1
@@ -1011,7 +1142,7 @@ def chunk_reduce(
     # indices=[0,0,0]. This is necessary when combining block results
     # factorize can handle strings etc unlike digitize
     group_idx, grps, found_groups_shape, _, size, props = factorize_(
-        (by,), axes, expected_groups=(expected_groups,), reindex=reindex, sort=sort
+        (by,), axes, expected_groups=(expected_groups,), reindex=bool(reindex), sort=sort
     )
     (groups,) = grps
 
@@ -1048,7 +1179,7 @@ def chunk_reduce(
     results: IntermediateDict = {"groups": [], "intermediates": []}
     if reindex and expected_groups is not None:
         # TODO: what happens with binning here?
-        results["groups"] = expected_groups.to_numpy()
+        results["groups"] = expected_groups
     else:
         if empty:
             results["groups"] = np.array([np.nan])
@@ -1132,7 +1263,7 @@ def _finalize_results(
     agg: Aggregation,
     axis: T_Axes,
     expected_groups: pd.Index | None,
-    reindex: bool,
+    reindex: ReindexStrategy,
 ) -> FinalResultsDict:
     """Finalize results by
     1. Squeezing out dummy dimensions
@@ -1169,14 +1300,15 @@ def _finalize_results(
             finalized[agg.name] = np.where(count_mask, fill_value, finalized[agg.name])
 
     # Final reindexing has to be here to be lazy
-    if not reindex and expected_groups is not None:
+    if not reindex.blockwise and expected_groups is not None:
         finalized[agg.name] = reindex_(
             finalized[agg.name],
             squeezed["groups"],
             expected_groups,
             fill_value=fill_value,
+            array_type=reindex.array_type,
         )
-        finalized["groups"] = expected_groups.to_numpy()
+        finalized["groups"] = expected_groups
     else:
         finalized["groups"] = squeezed["groups"]
 
@@ -1192,11 +1324,11 @@ def _aggregate(
     axis: T_Axes,
     keepdims: bool,
     fill_value: Any,
-    reindex: bool,
+    reindex: ReindexStrategy,
 ) -> FinalResultsDict:
     """Final aggregation step of tree reduction"""
     results = combine(x_chunk, agg, axis, keepdims, is_aggregate=True)
-    return _finalize_results(results, agg, axis, expected_groups, reindex)
+    return _finalize_results(results, agg, axis, expected_groups, reindex=reindex)
 
 
 def _expand_dims(results: IntermediateDict) -> IntermediateDict:
@@ -1221,7 +1353,7 @@ def _simple_combine(
     agg: Aggregation,
     axis: T_Axes,
     keepdims: bool,
-    reindex: bool,
+    reindex: ReindexStrategy,
     is_aggregate: bool = False,
 ) -> IntermediateDict:
     """
@@ -1237,12 +1369,17 @@ def _simple_combine(
     from dask.array.core import deepfirst
     from dask.utils import deepmap
 
-    if not reindex:
+    if not reindex.blockwise:
         # We didn't reindex at the blockwise step
         # So now reindex before combining by reducing along DUMMY_AXIS
         unique_groups = _find_unique_groups(x_chunk)
         x_chunk = deepmap(
-            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups),
+            partial(
+                reindex_intermediates,
+                agg=agg,
+                unique_groups=unique_groups,
+                array_type=reindex.array_type,
+            ),
             x_chunk,
         )
     else:
@@ -1260,7 +1397,8 @@ def _simple_combine(
             result = combine(array, axis=axis_, keepdims=True)
         if is_aggregate:
             # squeeze out DUMMY_AXIS if this is the last step i.e. called from _aggregate
-            result = result.squeeze(axis=DUMMY_AXIS)
+            # can't just pass DUMMY_AXIS, because of sparse.COO
+            result = result.squeeze(range(result.ndim)[DUMMY_AXIS])
         results["intermediates"].append(result)
     return results
 
@@ -1280,7 +1418,9 @@ def _conc2(x_chunk, key1, key2=slice(None), axis: T_Axes | None = None) -> np.nd
     # return concatenate3(mapped)
 
 
-def reindex_intermediates(x: IntermediateDict, agg: Aggregation, unique_groups) -> IntermediateDict:
+def reindex_intermediates(
+    x: IntermediateDict, agg: Aggregation, unique_groups, array_type
+) -> IntermediateDict:
     new_shape = x["groups"].shape[:-1] + (len(unique_groups),)
     newx: IntermediateDict = {"groups": np.broadcast_to(unique_groups, new_shape)}
     newx["intermediates"] = tuple(
@@ -1289,6 +1429,7 @@ def reindex_intermediates(x: IntermediateDict, agg: Aggregation, unique_groups)
             from_=np.atleast_1d(x["groups"].squeeze()),
             to=pd.Index(unique_groups),
             fill_value=f,
+            array_type=array_type,
         )
         for v, f in zip(x["intermediates"], agg.fill_value["intermediate"])
     )
@@ -1323,7 +1464,9 @@ def _grouped_combine(
         # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
         unique_groups = _find_unique_groups(x_chunk)
         x_chunk = deepmap(
-            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups),
+            partial(
+                reindex_intermediates, agg=agg, unique_groups=unique_groups, array_type=ReindexArrayType.AUTO
+            ),
             x_chunk,
         )
 
@@ -1427,7 +1570,7 @@ def _reduce_blockwise(
     fill_value: Any,
     engine: T_Engine,
     sort: bool,
-    reindex: bool,
+    reindex: ReindexStrategy,
 ) -> FinalResultsDict:
     """
     Blockwise groupby reduction that produces the final result. This code path is
@@ -1455,7 +1598,7 @@ def _reduce_blockwise(
         kwargs=finalize_kwargs_,
         engine=engine,
         sort=sort,
-        reindex=reindex,
+        reindex=bool(reindex.blockwise),
         user_dtype=agg.dtype["user"],
     )
 
@@ -1593,16 +1736,17 @@ def _unify_chunks(array, by):
 def dask_groupby_agg(
     array: DaskArray,
     by: T_By,
+    *,
     agg: Aggregation,
     expected_groups: pd.RangeIndex | None,
+    reindex: ReindexStrategy,
     axis: T_Axes = (),
     fill_value: Any = None,
     method: T_Method = "map-reduce",
-    reindex: bool = False,
     engine: T_Engine = "numpy",
     sort: bool = True,
     chunks_cohorts=None,
-) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
+) -> tuple[DaskArray, tuple[pd.Index | np.ndarray | DaskArray]]:
     import dask.array
     from dask.array.core import slices_from_chunks
     from dask.highlevelgraph import HighLevelGraph
@@ -1616,10 +1760,10 @@ def dask_groupby_agg(
     inds = tuple(range(array.ndim))
     name = f"groupby_{agg.name}"
 
-    if expected_groups is None and reindex:
-        raise ValueError
-    if method == "cohorts":
-        assert reindex is False
+    if expected_groups is None and reindex.blockwise:
+        raise ValueError("reindex.blockwise must be False-y if expected_groups is not provided.")
+    if method == "cohorts" and reindex.blockwise:
+        raise ValueError("reindex.blockwise must be False-y if method is 'cohorts'.")
 
     by_input = by
 
@@ -1642,7 +1786,7 @@ def dask_groupby_agg(
     #    a. "_simple_combine": Where it makes sense, we tree-reduce the reduction,
     #        NOT the groupby-reduction for a speed boost. This is what xhistogram does (effectively),
     #        It requires that all blocks contain all groups after the initial blockwise step (1) i.e.
-    #        reindex=True, and we must know expected_groups
+    #        reindex.blockwise=True, and we must know expected_groups
     #    b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction.
     #       This allows us to discover groups at compute time, support argreductions, lower intermediate
     #       memory usage (but method="cohorts" would also work to reduce memory in some cases)
@@ -1662,9 +1806,9 @@ def dask_groupby_agg(
         blockwise_method = partial(
             _get_chunk_reduction(agg.reduction_type),
             func=agg.chunk,
+            reindex=reindex.blockwise,
             fill_value=agg.fill_value["intermediate"],
             dtype=agg.dtype["intermediate"],
-            reindex=reindex,
             user_dtype=agg.dtype["user"],
         )
         if do_simple_combine:
@@ -1676,7 +1820,7 @@ def dask_groupby_agg(
         partial(
             blockwise_method,
             axis=axis,
-            expected_groups=expected_groups if reindex else None,
+            expected_groups=expected_groups if reindex.blockwise else None,
             engine=engine,
             sort=sort,
         ),
@@ -1712,7 +1856,7 @@ def dask_groupby_agg(
             keepdims=True,
             concatenate=False,
         )
-        aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value)
+        aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex)
 
         # Each chunk of `reduced`` is really a dict mapping
         # 1. reduction name to array
@@ -1723,14 +1867,14 @@ def dask_groupby_agg(
             reduced = tree_reduce(
                 intermediate,
                 combine=partial(combine, agg=agg),
-                aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
+                aggregate=partial(aggregate, expected_groups=expected_groups),
             )
             if labels_are_unknown:
                 groups = _extract_unknown_groups(reduced, dtype=by.dtype)
                 group_chunks = ((np.nan,),)
             else:
                 assert expected_groups is not None
-                groups = (expected_groups.to_numpy(),)
+                groups = (expected_groups,)
                 group_chunks = ((len(expected_groups),),)
 
         elif method == "cohorts":
@@ -1744,22 +1888,28 @@ def dask_groupby_agg(
             for icohort, (blks, cohort) in enumerate(chunks_cohorts.items()):
                 cohort_index = pd.Index(cohort)
                 reindexer = (
-                    partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
+                    partial(
+                        reindex_intermediates,
+                        agg=agg,
+                        unique_groups=cohort_index,
+                        array_type=reindex.array_type,
+                    )
                     if do_simple_combine
                     else identity
                 )
                 subset = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array)
                 dsk |= subset.layer  # type: ignore[operator]
                # now that we have reindexed, we can set reindex=True explicitly
+                new_reindex = ReindexStrategy(blockwise=do_simple_combine, array_type=reindex.array_type)
                 _tree_reduce(
                     subset,
                     out_dsk=dsk,
                     name=out_name,
                     block_index=icohort,
                     axis=axis,
-                    combine=partial(combine, agg=agg, reindex=do_simple_combine, keepdims=True),
+                    combine=partial(combine, agg=agg, reindex=new_reindex, keepdims=True),
                     aggregate=partial(
-                        aggregate, expected_groups=cohort_index, reindex=do_simple_combine, keepdims=True
+                        aggregate, expected_groups=cohort_index, reindex=new_reindex, keepdims=True
                     ),
                 )
                 # This is done because pandas promotes to 64-bit types when an Index is created
@@ -1780,7 +1930,7 @@ def dask_groupby_agg(
 
     elif method == "blockwise":
         reduced = intermediate
-        if reindex:
+        if reindex.blockwise:
             if TYPE_CHECKING:
                 assert expected_groups is not None
             # TODO: we could have `expected_groups` be a dask array with appropriate chunks
@@ -1824,11 +1974,11 @@ def dask_groupby_agg(
         reduced,
         inds,
         adjust_chunks=dict(zip(out_inds, output_chunks)),
-        dtype=agg.dtype["final"],
         key=agg.name,
         name=f"{name}-{token}",
         concatenate=False,
         new_axes=new_axes,
+        meta=reindex.get_dask_meta(array, dtype=agg.dtype["final"], fill_value=agg.fill_value[agg.name]),
     )
 
     return (result, groups)
@@ -1839,14 +1989,14 @@ def cubed_groupby_agg(
     by: T_By,
     agg: Aggregation,
     expected_groups: pd.Index | None,
+    reindex: ReindexStrategy,
     axis: T_Axes = (),
     fill_value: Any = None,
     method: T_Method = "map-reduce",
-    reindex: bool = False,
     engine: T_Engine = "numpy",
     sort: bool = True,
     chunks_cohorts=None,
-) -> tuple[CubedArray, tuple[np.ndarray | CubedArray]]:
+) -> tuple[CubedArray, tuple[pd.Index | np.ndarray | CubedArray]]:
     import cubed
     import cubed.core.groupby
 
@@ -1882,7 +2032,7 @@ def cubed_groupby_agg(
         result = cubed.core.groupby.groupby_blockwise(
             array, by, axis=axis, func=_reduction_func, num_groups=num_groups
         )
-        groups = (expected_groups.to_numpy(),)
+        groups = (expected_groups,)
         return (result, groups)
 
     else:
@@ -1910,7 +2060,7 @@ def cubed_groupby_agg(
         assert do_simple_combine
         assert method == "map-reduce"
         assert expected_groups is not None
-        assert reindex is True
+        assert reindex.blockwise is True
         assert len(axis) == 1  # one axis/grouping
 
         def _groupby_func(a, by, axis, intermediate_dtype, num_groups):
@@ -1964,7 +2114,7 @@ def cubed_groupby_agg(
             num_groups=num_groups,
         )
 
-        groups = (expected_groups.to_numpy(),)
+        groups = (expected_groups,)
 
         return (result, groups)
 
@@ -2002,20 +2152,20 @@ def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
 
 
 def _validate_reindex(
-    reindex: bool | None,
+    reindex: ReindexStrategy | bool | None,
     func,
     method: T_MethodOpt,
     expected_groups,
     any_by_dask: bool,
     is_dask_array: bool,
     array_dtype: Any,
-) -> bool | None:
+) -> ReindexStrategy:
     # logger.debug("Entering _validate_reindex: reindex is {}".format(reindex))  # noqa
     def first_or_last():
         return func in ["first", "last"] or (_is_first_last_reduction(func) and array_dtype.kind != "f")
 
-    all_numpy = not is_dask_array and not any_by_dask
-    if reindex is True and not all_numpy:
+    all_eager = not is_dask_array and not any_by_dask
+    if reindex is True and not all_eager:
         if _is_arg_reduction(func):
             raise NotImplementedError
         if method == "cohorts" or (method == "blockwise" and not any_by_dask):
@@ -2023,39 +2173,44 @@ def _validate_reindex(
         if first_or_last():
             raise ValueError("reindex must be None or False when func is 'first' or 'last.")
 
-    if reindex is None:
+    if isinstance(reindex, ReindexStrategy):
+        reindex_ = reindex
+    else:
+        reindex_ = ReindexStrategy(blockwise=reindex)
+
+    if reindex_.blockwise is None:
         if method is None:
             # logger.debug("Leaving _validate_reindex: method = None, returning None")
-            return None
+            return ReindexStrategy(blockwise=None)
 
-        if all_numpy:
-            return True
+        if all_eager:
+            return ReindexStrategy(blockwise=True)
 
         if first_or_last():
             # have to do the grouped_combine since there's no good fill_value
             # Also needed for nanfirst, nanlast with no-NaN dtypes
-            return False
+            return ReindexStrategy(blockwise=False)
 
         if method == "blockwise":
             # for grouping by dask arrays, we set reindex=True
-            reindex = any_by_dask
+            reindex_ = ReindexStrategy(blockwise=any_by_dask)
 
         elif _is_arg_reduction(func):
-            reindex = False
+            reindex_ = ReindexStrategy(blockwise=False)
 
         elif method == "cohorts":
-            reindex = False
+            reindex_ = ReindexStrategy(blockwise=False)
 
         elif method == "map-reduce":
             if expected_groups is None and any_by_dask:
-                reindex = False
+                reindex_ = ReindexStrategy(blockwise=False)
             else:
-                reindex = True
+                reindex_ = ReindexStrategy(blockwise=True)
 
-    assert isinstance(reindex, bool)
+    assert isinstance(reindex_, ReindexStrategy)
     # logger.debug("Leaving _validate_reindex: reindex is {}".format(reindex))  # noqa
 
-    return reindex
+    return reindex_
 
 
 def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys) -> None:
@@ -2118,10 +2273,9 @@ def _factorize_multiple(
     expected_groups: T_ExpectIndexOptTuple,
     any_by_dask: bool,
     sort: bool = True,
-) -> tuple[tuple[np.ndarray], tuple[np.ndarray, ...], tuple[int, ...]]:
+) -> tuple[tuple[np.ndarray], tuple[pd.Index, ...], tuple[int, ...]]:
     kwargs: FactorizeKwargs = dict(
         axes=(),  # always (), we offset later if necessary.
-        expected_groups=expected_groups,
         fastpath=True,
         # This is the only way it makes sense I think.
         # reindex controls what's actually allocated in chunk_reduce
@@ -2135,34 +2289,36 @@ def _factorize_multiple(
         # unifying chunks will make sure all arrays in `by` are dask arrays
         # with compatible chunks, even if there was originally a numpy array
         inds = tuple(range(by[0].ndim))
-        chunks, by_ = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))
-
-        group_idx = dask.array.map_blocks(
-            _lazy_factorize_wrapper,
-            *by_,
-            chunks=tuple(chunks.values()),
-            meta=np.array((), dtype=np.int64),
-            **kwargs,
-        )
-
-        fg, gs = [], []
         for by_, expect in zip(by, expected_groups):
-            if expect is None:
-                if is_duck_dask_array(by_):
-                    raise ValueError("Please provide expected_groups when grouping by a dask array.")
+            if expect is None and is_duck_dask_array(by_):
+                raise ValueError("Please provide expected_groups when grouping by a dask array.")
 
-                found_group = pd.unique(by_.reshape(-1))
-            else:
-                found_group = expect.to_numpy()
+        found_groups = tuple(
+            pd.Index(pd.unique(by_.reshape(-1))) if expect is None else expect
+            for by_, expect in zip(by, expected_groups)
+        )
+        grp_shape = tuple(map(len, found_groups))
 
-            fg.append(found_group)
-            gs.append(len(found_group))
+        chunks, by_chunked = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))
+        group_idxs = [
+            dask.array.map_blocks(
+                _lazy_factorize_wrapper,
+                by_,
+                expected_groups=(expect_,),
+                meta=np.array((), dtype=np.int64),
+                **kwargs,
+            )
+            for by_, expect_ in zip(by_chunked, expected_groups)
+        ]
+        # This could be avoided, but we'd have to use `np.where`
+        # instead of `_ravel_factorized`, i.e. make an extra copy.
+        group_idx = dask.array.map_blocks(
+            _ravel_factorized, *group_idxs, grp_shape=grp_shape, chunks=tuple(chunks.values()), dtype=np.int64
+        )
 
-        found_groups = tuple(fg)
-        grp_shape = tuple(gs)
     else:
         kwargs["by"] = by
-        group_idx, found_groups, grp_shape, *_ = factorize_(**kwargs)
+        group_idx, found_groups, grp_shape, *_ = factorize_(**kwargs, expected_groups=expected_groups)
 
     return (group_idx,), found_groups, grp_shape
 
@@ -2280,7 +2436,7 @@ def groupby_reduce(
     min_count: int | None = None,
     method: T_MethodOpt = None,
     engine: T_EngineOpt = None,
-    reindex: bool | None = None,
+    reindex: ReindexStrategy | bool | None = None,
     finalize_kwargs: dict[Any, Any] | None = None,
 ) -> tuple[DaskArray, Unpack[tuple[np.ndarray | DaskArray, ...]]]:
     """
@@ -2293,7 +2449,7 @@ def groupby_reduce(
     *by : ndarray or DaskArray
         Array of labels to group over. Must be aligned with ``array`` so that
         ``array.shape[-by.ndim :] == by.shape`` or any disagreements in that
-        equality check are for dimensions of size 1 in `by`.
+        equality check are for dimensions of size 1 in ``by``.
     func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
             "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
             "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
@@ -2308,33 +2464,34 @@ def groupby_reduce(
         reductions when ``method`` is not ``"map-reduce"``. For ``"map-reduce"``, the groups
         are always sorted.
     axis : None or int or Sequence[int], optional
-        If None, reduce across all dimensions of by
-        Else, reduce across corresponding axes of array
-        Negative integers are normalized using array.ndim
+        If None, reduce across all dimensions of ``by``,
+        else reduce across corresponding axes of array.
+        Negative integers are normalized using ``array.ndim``.
     fill_value : Any
         Value to assign when a label in ``expected_groups`` is not present.
     dtype : data-type , optional
         DType for the output. Can be anything that is accepted by ``np.dtype``.
     min_count : int, default: None
         The required number of valid values to perform the operation. If
-        fewer than min_count non-NA values are present the result will be
-        NA. Only used if skipna is set to True or defaults to True for the
+        fewer than ``min_count`` non-NA values are present the result will be
+        NA. Only used if ``skipna`` is set to True or defaults to True for the
         array's dtype.
     method : {"map-reduce", "blockwise", "cohorts"}, optional
+        Note that the default for this arg is chosen using heuristics.
         Strategy for reduction of dask arrays only:
           * ``"map-reduce"``:
             First apply the reduction blockwise on ``array``, then
             combine a few neighbouring blocks, apply the reduction.
             Continue until finalizing. Usually, ``func`` will need
-            to be an Aggregation instance for this method to work.
+            to be an ``Aggregation`` instance for this method to work.
             Common aggregations are implemented.
           * ``"blockwise"``:
             Only reduce using blockwise and avoid aggregating blocks
             together. Useful for resampling-style reductions where group
-            members are always together. If  `by` is 1D,  `array` is automatically
+            members are always together. If ``by`` is 1D, ``array`` is automatically
             rechunked so that chunk boundaries line up with group boundaries
             i.e. each block contains all members of any group present
-            in that block. For nD `by`, you must make sure that all members of a group
+            in that block. For nD ``by``, you must make sure that all members of a group
             are present in a single block.
           * ``"cohorts"``:
             Finds group labels that tend to occur together ("cohorts"),
@@ -2359,12 +2516,15 @@ def groupby_reduce(
           * ``"numbagg"``:
             Use the reductions supported by ``numbagg.grouped``. This will fall back to ``numpy_groupies.aggregate_numpy``
             for a reduction that is not yet implemented.
-    reindex : bool, optional
-        Whether to "reindex" the blockwise results to ``expected_groups`` (possibly automatically detected).
+    reindex : ReindexStrategy | bool, optional
+        Whether to "reindex" the blockwise reduced results to ``expected_groups`` (possibly automatically detected).
         If True, the intermediate result of the blockwise groupby-reduction has a value for all expected groups,
         and the final result is a simple reduction of those intermediates. In nearly all cases, this is a significant
         boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the
         original block size. Avoid that by using ``method="cohorts"``. By default, it is turned off for argreductions.
+        By default, the type of ``array`` is preserved. You may optionally reindex to a sparse array type to further control memory
+        when ``expected_groups`` is very large. Pass a ``ReindexStrategy`` instance with the appropriate ``array_type``,
+        for example ``reindex=ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO)``.
     finalize_kwargs : dict, optional
         Kwargs passed to finalize the reduction such as ``ddof`` for var, std or ``q`` for quantile.
 
@@ -2450,7 +2610,7 @@ def groupby_reduce(
     expected_groups = _validate_expected_groups(nby, expected_groups)
 
     for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)):
-        if is_dask and (reindex or nby > 1) and expect is None:
+        if is_dask and (reindex.blockwise or nby > 1) and expect is None:
             raise ValueError(
                 f"`expected_groups` for array {idx} in `by` cannot be None since it is a dask.array."
             )
@@ -2583,7 +2743,7 @@ def groupby_reduce(
             by=by_,
             expected_groups=expected_,
             agg=agg,
-            reindex=bool(reindex),
+            reindex=reindex,
             method=method,
             sort=sort,
         )
@@ -2591,12 +2751,13 @@ def groupby_reduce(
         return (result, groups)
 
     elif not has_dask:
+        reindex.set_blockwise_for_numpy()
         results = _reduce_blockwise(
             array,
             by_,
             agg,
             expected_groups=expected_,
-            reindex=bool(reindex),
+            reindex=reindex,
             sort=sort,
             **kwargs,
         )
@@ -2655,6 +2816,7 @@ def groupby_reduce(
         )
 
         if TYPE_CHECKING:
+            assert isinstance(reindex, ReindexStrategy)
             assert method is not None
 
         # TODO: just do this in dask_groupby_agg
@@ -2673,7 +2835,7 @@ def groupby_reduce(
             by=by_,
             expected_groups=expected_,
             agg=agg,
-            reindex=bool(reindex),
+            reindex=reindex,
             method=method,
             chunks_cohorts=chunks_cohorts,
             sort=sort,
@@ -2699,9 +2861,13 @@ def groupby_reduce(
             groups_ = groups_[..., ~mask]
 
         # This reindex also handles bins with no data
-        result = reindex_(result, from_=groups_, to=expected_, fill_value=fill_value).reshape(
-            result.shape[:-1] + grp_shape
-        )
+        result = reindex_(
+            result,
+            from_=groups_,
+            to=expected_,
+            fill_value=fill_value,
+            array_type=ReindexArrayType.AUTO,  # just reindex the received array
+        ).reshape(result.shape[:-1] + grp_shape)
         groups = final_groups
 
     if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)):
@@ -2718,6 +2884,9 @@ def groupby_reduce(
             result = asdelta + offset
             result[nanmask] = np.timedelta64("NaT")
 
+    groups = map(
+        lambda g: g.to_numpy() if isinstance(g, pd.Index) and not isinstance(g, pd.RangeIndex) else g, groups
+    )
     return (result, *groups)
 
 


=====================================
flox/dask_array_ops.py
=====================================
@@ -4,13 +4,23 @@ from functools import lru_cache, partial
 from itertools import product
 from numbers import Integral
 
+import dask
+import pandas as pd
 from dask import config
+from dask.base import normalize_token
 from dask.blockwise import lol_tuples
+from packaging.version import Version
 from toolz import partition_all
 
 from .lib import ArrayLayer
 from .types import Graph
 
+if Version(dask.__version__) <= Version("2025.03.1"):
+    # workaround for https://github.com/dask/dask/issues/11862
+    @normalize_token.register(pd.RangeIndex)
+    def normalize_range_index(x):
+        return normalize_token(type(x)), x.start, x.stop, x.step, x.dtype, x.name
+
 
 # _tree_reduce and partial_reduce are copied from dask.array.reductions
 # They have been modified to work purely with graphs, and without creating new Array layers
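
The normalize_token registration above teaches dask's tokenizer to hash a RangeIndex by its defining fields. A sketch of the effect, assuming dask <= 2025.03.1 where the workaround is active:

    import pandas as pd
    from dask.base import tokenize

    # with the normalizer registered, structurally-equal RangeIndexes
    # produce identical, deterministic tokens
    assert tokenize(pd.RangeIndex(0, 10)) == tokenize(pd.RangeIndex(0, 10))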


=====================================
flox/xarray.py
=====================================
@@ -10,6 +10,7 @@ from packaging.version import Version
 
 from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
 from .core import (
+    ReindexStrategy,
     _convert_expected_groups_to_index,
     _get_expected_groups,
     _validate_expected_groups,
@@ -77,10 +78,10 @@ def xarray_reduce(
     keep_attrs: bool | None = True,
     skipna: bool | None = None,
     min_count: int | None = None,
-    reindex: bool | None = None,
+    reindex: ReindexStrategy | bool | None = None,
     **finalize_kwargs,
 ):
-    """GroupBy reduce operations on xarray objects using numpy-groupies
+    """GroupBy reduce operations on xarray objects using numpy-groupies.
 
     Parameters
     ----------
@@ -105,26 +106,27 @@ def xarray_reduce(
     dim : hashable
         dimension name along which to reduce. If None, reduces across all
         dimensions of `by`
-    fill_value
+    fill_value : Any
         Value used for missing groups in the output i.e. when one of the labels
         in ``expected_groups`` is not actually present in ``by``.
     dtype : data-type, optional
-        DType for the output. Can be anything accepted by ``np.dtype``.
+        DType for the output. Can be anything that is accepted by ``np.dtype``.
     method : {"map-reduce", "blockwise", "cohorts"}, optional
+        Note that the default for this arg is chosen using heuristics.
         Strategy for reduction of dask arrays only:
           * ``"map-reduce"``:
             First apply the reduction blockwise on ``array``, then
             combine a few neighbouring blocks, apply the reduction.
             Continue until finalizing. Usually, ``func`` will need
-            to be an Aggregation instance for this method to work.
+            to be an ``Aggregation`` instance for this method to work.
             Common aggregations are implemented.
           * ``"blockwise"``:
             Only reduce using blockwise and avoid aggregating blocks
             together. Useful for resampling-style reductions where group
-            members are always together. If  `by` is 1D,  `array` is automatically
+            members are always together. If ``by`` is 1D, ``array`` is automatically
             rechunked so that chunk boundaries line up with group boundaries
             i.e. each block contains all members of any group present
-            in that block. For nD `by`, you must make sure that all members of a group
+            in that block. For nD ``by``, you must make sure that all members of a group
             are present in a single block.
           * ``"cohorts"``:
             Finds group labels that tend to occur together ("cohorts"),
@@ -134,11 +136,11 @@ def xarray_reduce(
             'month', dayofyear' etc. Optimize chunking ``array`` for this
             method by first rechunking using ``rechunk_for_cohorts``
             (for 1D ``by`` only).
-    engine : {"flox", "numpy", "numba"}, optional
+    engine : {"flox", "numpy", "numba", "numbagg"}, optional
         Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
           * ``"numpy"``:
             Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``.
-            This is the default choice because it works for other array types.
+            This is the default choice because it works for most array types.
           * ``"flox"``:
             Use an internal implementation where the data is sorted so that
             all members of a group occur sequentially, and then numpy.ufunc.reduceat
@@ -161,13 +163,16 @@ def xarray_reduce(
         fewer than min_count non-NA values are present the result will be
         NA. Only used if skipna is set to True or defaults to True for the
         array's dtype.
-    reindex : bool, optional
-        Whether to "reindex" the blockwise results to `expected_groups` (possibly automatically detected).
+    reindex : ReindexStrategy | bool, optional
+        Whether to "reindex" the blockwise reduced results to ``expected_groups`` (possibly automatically detected).
         If True, the intermediate result of the blockwise groupby-reduction has a value for all expected groups,
         and the final result is a simple reduction of those intermediates. In nearly all cases, this is a significant
         boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the
-        original block size. Avoid that by using method="cohorts". By default, it is turned off for arg reductions.
-    **finalize_kwargs
+        original block size. Avoid that by using ``method="cohorts"``. By default, it is turned off for argreductions.
+        By default, the type of ``array`` is preserved. You may optionally reindex to a sparse array type to further control memory
+        when ``expected_groups`` is very large. Pass a ``ReindexStrategy`` instance with the appropriate ``array_type``,
+        for example ``reindex=ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO)``.
+    **finalize_kwargs : dict, optional
         kwargs passed to the finalize function, like ``ddof`` for var, std or ``q`` for quantile.
 
     Returns
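
A short usage sketch of the new ``reindex`` option documented above, mirroring the
``test_reindex_sparse`` test added below (``ReindexStrategy``, ``ReindexArrayType``,
and ``groupby_reduce`` are imported from ``flox.core``, as in that test):

    import dask.array
    import numpy as np
    import pandas as pd

    from flox.core import ReindexArrayType, ReindexStrategy, groupby_reduce

    array = dask.array.ones((2, 12), chunks=(-1, 3))
    by = dask.array.from_array(np.repeat(np.arange(6) * 2, 2), chunks=(3,))

    # Reindex each block's intermediate result to all 11 expected groups,
    # storing it as sparse.COO to bound memory when most groups are absent
    # from any single block.
    strategy = ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO)
    result, *_ = groupby_reduce(
        array,
        by,
        func="sum",
        expected_groups=pd.Index(np.arange(11)),
        reindex=strategy,
    )

Note that ``ReindexStrategy(blockwise=True, array_type=ReindexArrayType.SPARSE_COO)``
raises ``ValueError``; the new test below checks exactly this combination.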


=====================================
flox/xrutils.py
=====================================
@@ -152,7 +152,7 @@ def notnull(data):
     scalar_type = data.dtype.type
     if issubclass(scalar_type, np.bool_ | np.integer | np.character | np.void):
         # these types cannot represent missing values
-        return np.ones_like(data, dtype=bool)
+        return np.broadcast_to(np.array(True), data.shape)
     else:
         out = isnull(data)
         np.logical_not(out, out=out)
@@ -173,7 +173,7 @@ def isnull(data):
         return np.isnan(data)
     elif issubclass(scalar_type, np.bool_ | np.integer | np.character | np.void):
         # these types cannot represent missing values
-        return np.zeros_like(data, dtype=bool)
+        return np.broadcast_to(np.array(False), data.shape)
     else:
         # at this point, array should have dtype=object
         if isinstance(data, (np.ndarray, dask_array_type)):  # noqa
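
The ``np.broadcast_to`` change above is a memory optimization: broadcasting a 0-d
scalar produces a read-only, zero-strided view, so the all-True (or all-False) mask
no longer allocates one byte per element the way ``np.ones_like``/``np.zeros_like``
did. A minimal sketch of the difference:

    import numpy as np

    data = np.arange(12, dtype=np.int64).reshape(3, 4)

    dense_mask = np.ones_like(data, dtype=bool)               # allocates 12 bytes
    cheap_mask = np.broadcast_to(np.array(True), data.shape)  # constant-stride view

    assert cheap_mask.shape == dense_mask.shape
    assert cheap_mask.strides == (0, 0)        # no per-element storage
    assert not cheap_mask.flags.writeable      # read-only; callers must not mutate

One trade-off: downstream code that mutated the mask in place would now fail, since
the broadcast view is read-only.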


=====================================
pyproject.toml
=====================================
@@ -134,6 +134,7 @@ module=[
     "pandas",
     "setuptools",
     "scipy.*",
+    "sparse.*",
     "toolz.*",
 ]
 ignore_missing_imports = true
@@ -145,5 +146,5 @@ testpaths = ["tests"]
 
 
 [tool.codespell]
-ignore-words-list = "nd,nax"
+ignore-words-list = "nd,nax,coo"
 skip = "*.html"


=====================================
tests/__init__.py
=====================================
@@ -16,6 +16,13 @@ try:
 except ImportError:
     dask_array_type = ()  # type: ignore[assignment, misc]
 
+try:
+    import sparse
+
+    sparse_array_type = sparse.COO
+except ImportError:
+    sparse_array_type = ()
+
 
 try:
     import xarray as xr
@@ -48,6 +55,7 @@ def LooseVersion(vstring):
 has_cftime, requires_cftime = _importorskip("cftime")
 has_cubed, requires_cubed = _importorskip("cubed")
 has_dask, requires_dask = _importorskip("dask")
+has_sparse, requires_sparse = _importorskip("sparse")
 has_numba, requires_numba = _importorskip("numba")
 has_numbagg, requires_numbagg = _importorskip("numbagg")
 has_scipy, requires_scipy = _importorskip("scipy")
@@ -111,6 +119,13 @@ def assert_equal(a, b, tolerance=None):
     else:
         a_eager, b_eager = a, b
 
+    if has_sparse:
+        one_is_sparse = isinstance(a_eager, sparse_array_type) or isinstance(b_eager, sparse_array_type)
+        a_eager = a_eager.todense() if isinstance(a_eager, sparse_array_type) else a_eager
+        b_eager = b_eager.todense() if isinstance(b_eager, sparse_array_type) else b_eager
+    else:
+        one_is_sparse = False
+
     if a.dtype.kind in "SUMmO":
         np.testing.assert_equal(a_eager, b_eager)
     else:
@@ -118,7 +133,7 @@ def assert_equal(a, b, tolerance=None):
 
     if has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
         # does some validation of the dask graph
-        dask_assert_eq(a, b, equal_nan=True)
+        dask_assert_eq(a, b, equal_nan=True, check_type=not one_is_sparse)
 
 
 def assert_equal_tuple(a, b):
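
The ``sparse_array_type = ()`` fallback above works because ``isinstance`` with an
empty tuple of types always returns ``False``, so the densify-before-compare branch
in ``assert_equal`` is skipped cleanly when ``sparse`` is not installed (and
``check_type=not one_is_sparse`` relaxes dask's type check only when a side was
densified). A standalone sketch of the guard pattern, with ``maybe_densify`` as a
hypothetical helper name:

    import numpy as np

    try:
        import sparse

        sparse_array_type = sparse.COO
    except ImportError:
        # isinstance(x, ()) is always False, so the check below becomes a no-op
        sparse_array_type = ()

    def maybe_densify(x):
        # Densify sparse.COO for value comparison; pass everything else through.
        return x.todense() if isinstance(x, sparse_array_type) else x

    a = maybe_densify(np.ones((2, 2)))  # unchanged when the input is dense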


=====================================
tests/conftest.py
=====================================
@@ -30,15 +30,3 @@ settings.load_profile("default")
 )
 def engine(request):
     return request.param
-
-
-@pytest.fixture(
-    scope="module",
-    params=[
-        "flox",
-        "numpy",
-        pytest.param("numbagg", marks=requires_numbagg),
-    ],
-)
-def engine_no_numba(request):
-    return request.param


=====================================
tests/test_core.py
=====================================
@@ -19,6 +19,8 @@ from flox import xrutils
 from flox.aggregations import Aggregation, _initialize_aggregation
 from flox.core import (
     HAS_NUMBAGG,
+    ReindexArrayType,
+    ReindexStrategy,
     _choose_engine,
     _convert_expected_groups_to_index,
     _get_optimal_chunks_for_groups,
@@ -44,6 +46,7 @@ from . import (
     raise_if_dask_computes,
     requires_cubed,
     requires_dask,
+    requires_sparse,
 )
 
 logger = logging.getLogger("flox")
@@ -1637,7 +1640,7 @@ def test_validate_reindex_map_reduce(dask_expected, reindex, func, expected_grou
         is_dask_array=True,
         array_dtype=np.dtype("int32"),
     )
-    assert actual is dask_expected
+    assert actual == ReindexStrategy(blockwise=dask_expected)
 
     # always reindex with all numpy inputs
     actual = _validate_reindex(
@@ -1649,7 +1652,7 @@ def test_validate_reindex_map_reduce(dask_expected, reindex, func, expected_grou
         is_dask_array=False,
         array_dtype=np.dtype("int32"),
     )
-    assert actual
+    assert actual.blockwise
 
     actual = _validate_reindex(
         True,
@@ -1660,7 +1663,7 @@ def test_validate_reindex_map_reduce(dask_expected, reindex, func, expected_grou
         is_dask_array=False,
         array_dtype=np.dtype("int32"),
     )
-    assert actual
+    assert actual.blockwise
 
 
 def test_validate_reindex() -> None:
@@ -1699,7 +1702,7 @@ def test_validate_reindex() -> None:
                 any_by_dask=False,
                 is_dask_array=True,
                 array_dtype=np.dtype("int32"),
-            )
+            ).blockwise
             assert actual is False
 
     with pytest.raises(ValueError):
@@ -1721,7 +1724,7 @@ def test_validate_reindex() -> None:
         any_by_dask=True,
         is_dask_array=True,
         array_dtype=np.dtype("int32"),
-    )
+    ).blockwise
     assert _validate_reindex(
         None,
         "sum",
@@ -1730,7 +1733,7 @@ def test_validate_reindex() -> None:
         any_by_dask=True,
         is_dask_array=True,
         array_dtype=np.dtype("int32"),
-    )
+    ).blockwise
 
     kwargs = dict(
         method="blockwise",
@@ -1740,12 +1743,12 @@ def test_validate_reindex() -> None:
     )
 
     for func in ["nanfirst", "nanlast"]:
-        assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs)  # type: ignore[arg-type]
-        assert _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs)  # type: ignore[arg-type]
+        assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs).blockwise  # type: ignore[arg-type]
+        assert _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs).blockwise  # type: ignore[arg-type]
 
     for func in ["first", "last"]:
-        assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs)  # type: ignore[arg-type]
-        assert not _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs)  # type: ignore[arg-type]
+        assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs).blockwise  # type: ignore[arg-type]
+        assert not _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs).blockwise  # type: ignore[arg-type]
 
 
 @requires_dask
@@ -2023,20 +2026,18 @@ def test_datetime_minmax(engine) -> None:
 
 @pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
 def test_datetime_timedelta_first_last(engine, func) -> None:
-    import flox
-
     idx = 0 if "first" in func else -1
     idx1 = 2 if "first" in func else -1
 
     ## datetime
     dt = pd.date_range("2001-01-01", freq="d", periods=5).values
     by = np.ones(dt.shape, dtype=int)
-    actual, *_ = flox.groupby_reduce(dt, by, func=func, engine=engine)
+    actual, *_ = groupby_reduce(dt, by, func=func, engine=engine)
     assert_equal(actual, dt[[idx]])
 
     # missing group
     by = np.array([0, 2, 3, 3, 3])
-    actual, *_ = flox.groupby_reduce(
+    actual, *_ = groupby_reduce(
         dt, by, expected_groups=([0, 1, 2, 3],), func=func, engine=engine, fill_value=dtypes.NA
     )
     assert_equal(actual, [dt[0], np.datetime64("NaT"), dt[1], dt[idx1]])
@@ -2044,12 +2045,47 @@ def test_datetime_timedelta_first_last(engine, func) -> None:
     ## timedelta
     dt = dt - dt[0]
     by = np.ones(dt.shape, dtype=int)
-    actual, *_ = flox.groupby_reduce(dt, by, func=func, engine=engine)
+    actual, *_ = groupby_reduce(dt, by, func=func, engine=engine)
     assert_equal(actual, dt[[idx]])
 
     # missing group
     by = np.array([0, 2, 3, 3, 3])
-    actual, *_ = flox.groupby_reduce(
+    actual, *_ = groupby_reduce(
         dt, by, expected_groups=([0, 1, 2, 3],), func=func, engine=engine, fill_value=dtypes.NA
     )
     assert_equal(actual, [dt[0], np.timedelta64("NaT"), dt[1], dt[idx1]])
+
+
+@requires_dask
+@requires_sparse
+def test_reindex_sparse():
+    import sparse
+
+    array = dask.array.ones((2, 12), chunks=(-1, 3))
+    func = "sum"
+    expected_groups = pd.Index(np.arange(11))
+    by = dask.array.from_array(np.repeat(np.arange(6) * 2, 2), chunks=(3,))
+    dense = np.zeros((2, 11))
+    dense[..., np.arange(6) * 2] = 2
+    expected = sparse.COO.from_numpy(dense)
+
+    with pytest.raises(ValueError):
+        ReindexStrategy(blockwise=True, array_type=ReindexArrayType.SPARSE_COO)
+    reindex = ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO)
+
+    original_reindex = flox.core.reindex_
+
+    def mocked_reindex(*args, **kwargs):
+        res = original_reindex(*args, **kwargs)
+        if isinstance(res, dask.array.Array):
+            assert isinstance(res._meta, sparse.COO)
+        else:
+            assert isinstance(res, sparse.COO)
+        return res
+
+    with patch("flox.core.reindex_") as mocked_func:
+        mocked_func.side_effect = mocked_reindex
+        actual, *_ = groupby_reduce(array, by, func=func, reindex=reindex, expected_groups=expected_groups)
+        assert_equal(actual, expected)
+        # once during graph construction, 10 times afterward
+        assert mocked_func.call_count > 1


=====================================
tests/test_xarray.py
=====================================
@@ -34,8 +34,7 @@ np.random.seed(123)
 @pytest.mark.parametrize("min_count", [None, 1, 3])
 @pytest.mark.parametrize("add_nan", [True, False])
 @pytest.mark.parametrize("skipna", [True, False])
-def test_xarray_reduce(skipna, add_nan, min_count, engine_no_numba, reindex):
-    engine = engine_no_numba
+def test_xarray_reduce(skipna, add_nan, min_count, engine, reindex):
     if skipna is False and min_count is not None:
         pytest.skip()
 
@@ -91,11 +90,9 @@ def test_xarray_reduce(skipna, add_nan, min_count, engine_no_numba, reindex):
 # TODO: sort
 @pytest.mark.parametrize("pass_expected_groups", [True, False])
 @pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False))
-def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine_no_numba):
+def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine):
     if chunk and pass_expected_groups is False:
         pytest.skip()
-    engine = engine_no_numba
-
     arr = np.ones((4, 12))
     labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
     labels2 = np.array([1, 2, 2, 1])
@@ -140,10 +137,9 @@ def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine_no_
 
 @pytest.mark.parametrize("pass_expected_groups", [True, False])
 @pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False))
-def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine_no_numba):
+def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine):
     if chunk and pass_expected_groups is False:
         pytest.skip()
-    engine = engine_no_numba
 
     arr = np.ones((2, 12))
     labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
@@ -218,8 +214,7 @@ def test_xarray_reduce_cftime_var(engine, indexer, expected_groups, func):
 
 @requires_cftime
 @requires_dask
-def test_xarray_reduce_single_grouper(engine_no_numba):
-    engine = engine_no_numba
+def test_xarray_reduce_single_grouper(engine):
     # DataArray
     ds = xr.Dataset(
         {
@@ -326,8 +321,7 @@ def test_rechunk_for_blockwise(inchunks, expected):
 # TODO: dim=None, dim=Ellipsis, groupby unindexed dim
 
 
-def test_groupby_duplicate_coordinate_labels(engine_no_numba):
-    engine = engine_no_numba
+def test_groupby_duplicate_coordinate_labels(engine):
     # fix for http://stackoverflow.com/questions/38065129
     array = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])])
     expected = xr.DataArray([3, 3], [("x", [1, 2])])
@@ -335,8 +329,7 @@ def test_groupby_duplicate_coordinate_labels(engine_no_numba):
     assert_equal(expected, actual)
 
 
-def test_multi_index_groupby_sum(engine_no_numba):
-    engine = engine_no_numba
+def test_multi_index_groupby_sum(engine):
     # regression test for xarray GH873
     ds = xr.Dataset(
         {"foo": (("x", "y", "z"), np.ones((3, 4, 2)))},
@@ -362,8 +355,7 @@ def test_multi_index_groupby_sum(engine_no_numba):
 
 
 @pytest.mark.parametrize("chunks", (None, pytest.param(2, marks=requires_dask)))
-def test_xarray_groupby_bins(chunks, engine_no_numba):
-    engine = engine_no_numba
+def test_xarray_groupby_bins(chunks, engine):
     array = xr.DataArray([1, 1, 1, 1, 1], dims="x")
     labels = xr.DataArray([1, 1.5, 1.9, 2, 3], dims="x", name="labels")
 
@@ -532,11 +524,10 @@ def test_alignment_error():
 @pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")])
 @pytest.mark.parametrize("dtype", [np.float32, np.float64])
 @pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False))
-def test_dtype(add_nan, chunk, dtype, dtype_out, engine_no_numba):
-    if engine_no_numba == "numbagg":
+def test_dtype(add_nan, chunk, dtype, dtype_out, engine):
+    if engine == "numbagg":
         # https://github.com/numbagg/numbagg/issues/121
         pytest.skip()
-    engine = engine_no_numba
     xp = dask.array if chunk else np
     data = xp.linspace(0, 1, 48, dtype=dtype).reshape((4, 12))
 



View it on GitLab: https://salsa.debian.org/debian-gis-team/flox/-/commit/068a10f382451691351637876170efca93a7b9b5
