[Git][debian-gis-team/flox][upstream] New upstream version 0.6.0

Antonio Valentino (@antonio.valentino) gitlab at salsa.debian.org
Sat Oct 15 09:05:29 BST 2022



Antonio Valentino pushed to branch upstream at Debian GIS Project / flox


Commits:
69e40600 by Antonio Valentino at 2022-10-15T07:14:09+00:00
New upstream version 0.6.0
- - - - -


10 changed files:

- .github/workflows/ci-additional.yaml
- asv_bench/benchmarks/combine.py
- docs/source/implementation.md
- flox/aggregations.py
- flox/core.py
- flox/visualize.py
- flox/xarray.py
- tests/__init__.py
- tests/test_core.py
- tests/test_xarray.py


Changes:

=====================================
.github/workflows/ci-additional.yaml
=====================================
@@ -2,7 +2,7 @@ name: CI Additional
 on:
   push:
     branches:
-      - "*"
+      - "main"
   pull_request:
     branches:
       - "*"


=====================================
asv_bench/benchmarks/combine.py
=====================================
@@ -58,4 +58,4 @@ class Combine1d(Combine):
         ]
 
         self.x_chunk_cohorts = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4]
-        self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,), "neg_axis": (-1,)}
+        self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,)}


=====================================
docs/source/implementation.md
=====================================
@@ -13,7 +13,7 @@ or `xarray_reduce`.
 
 First we describe xarray's current strategy
 
-## `method="split-reduce"`: Xarray's current GroupBy strategy
+## Background: Xarray's current GroupBy strategy
 
 Xarray's current strategy is to find all unique group labels, index out each group,
 and then apply the reduction operation. Note that this only works if we know the group


=====================================
flox/aggregations.py
=====================================
@@ -55,10 +55,7 @@ def generic_aggregate(
 
 def _normalize_dtype(dtype, array_dtype, fill_value=None):
     if dtype is None:
-        if fill_value is not None and np.isnan(fill_value):
-            dtype = np.floating
-        else:
-            dtype = array_dtype
+        dtype = array_dtype
     if dtype is np.floating:
         # mean, std, var always result in floating
         # but we preserve the array's dtype if it is floating
@@ -68,6 +65,8 @@ def _normalize_dtype(dtype, array_dtype, fill_value=None):
             dtype = np.dtype("float64")
     elif not isinstance(dtype, np.dtype):
         dtype = np.dtype(dtype)
+    if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
+        dtype = np.result_type(dtype, fill_value)
     return dtype
 
 
@@ -465,6 +464,7 @@ aggregations = {
 
 def _initialize_aggregation(
     func: str | Aggregation,
+    dtype,
     array_dtype,
     fill_value,
     min_count: int | None,
@@ -484,10 +484,18 @@ def _initialize_aggregation(
     else:
         raise ValueError("Bad type for func. Expected str or Aggregation")
 
-    agg.dtype[func] = _normalize_dtype(agg.dtype[func], array_dtype, fill_value)
+    # np.dtype(None) == np.dtype("float64")!!!
+    # so check for not None
+    if dtype is not None and not isinstance(dtype, np.dtype):
+        dtype = np.dtype(dtype)
+
+    agg.dtype[func] = _normalize_dtype(dtype or agg.dtype[func], array_dtype, fill_value)
     agg.dtype["numpy"] = (agg.dtype[func],)
     agg.dtype["intermediate"] = [
-        _normalize_dtype(dtype, array_dtype) for dtype in agg.dtype["intermediate"]
+        _normalize_dtype(int_dtype, np.result_type(array_dtype, agg.dtype[func]), int_fv)
+        if int_dtype is None
+        else int_dtype
+        for int_dtype, int_fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
     ]
 
     # Replace sentinel fill values according to dtype


=====================================
flox/core.py
=====================================
@@ -6,17 +6,7 @@ import math
 import operator
 from collections import namedtuple
 from functools import partial, reduce
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    Literal,
-    Mapping,
-    Sequence,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Mapping, Sequence, Union
 
 import numpy as np
 import numpy_groupies as npg
@@ -37,8 +27,11 @@ from .xrutils import is_duck_array, is_duck_dask_array, isnull
 if TYPE_CHECKING:
     import dask.array.Array as DaskArray
 
+    T_ExpectedGroups = Union[Sequence, np.ndarray, pd.Index]
+    T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
     T_Func = Union[str, Callable]
     T_Funcs = Union[T_Func, Sequence[T_Func]]
+    T_Agg = Union[str, Aggregation]
     T_Axis = int
     T_Axes = tuple[T_Axis, ...]
     T_AxesOpt = Union[T_Axis, T_Axes, None]
@@ -60,7 +53,7 @@ FactorProps = namedtuple("FactorProps", "offset_group nan_sentinel nanmask")
 DUMMY_AXIS = -2
 
 
-def _is_arg_reduction(func: str | Aggregation) -> bool:
+def _is_arg_reduction(func: T_Agg) -> bool:
     if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
         return True
     if isinstance(func, Aggregation) and func.reduction_type == "argreduce":
@@ -68,6 +61,12 @@ def _is_arg_reduction(func: str | Aggregation) -> bool:
     return False
 
 
+def _is_minmax_reduction(func: T_Agg) -> bool:
+    return not _is_arg_reduction(func) and (
+        isinstance(func, str) and ("max" in func or "min" in func)
+    )
+
+
 def _get_expected_groups(by, sort: bool) -> pd.Index:
     if is_duck_dask_array(by):
         raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
@@ -138,7 +137,7 @@ def _get_optimal_chunks_for_groups(chunks, labels):
 
 
 @memoize
-def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "cohorts"):
+def find_group_cohorts(labels, chunks, merge: bool = True):
     """
     Finds groups labels that occur together aka "cohorts"
 
@@ -168,9 +167,6 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
     # To do this, we must have values in memory so casting to numpy should be safe
     labels = np.asarray(labels)
 
-    if method == "split-reduce":
-        return list(_get_expected_groups(labels, sort=False).to_numpy().reshape(-1, 1))
-
     # Build an array with the shape of labels, but where every element is the "chunk number"
     # 1. First subset the array appropriately
     axis = range(-labels.ndim, 0)
@@ -196,7 +192,7 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
     if merge:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
-            reversed(sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0])))
+            sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
         )
 
         items = tuple(sorted_chunks_cohorts.items())
@@ -219,9 +215,15 @@ def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "co
                     merged_cohorts[k1].extend(v2)
                     merged_keys.append(k2)
 
-        return merged_cohorts.values()
+        # make sure each cohort is sorted after merging
+        sorted_merged_cohorts = {k: sorted(v) for k, v in merged_cohorts.items()}
+        # sort by first label in cohort
+        # This will help when sort=True (default)
+        # and we have to resort the dask array
+        return dict(sorted(sorted_merged_cohorts.items(), key=lambda kv: kv[1][0]))
+
     else:
-        return chunks_cohorts.values()
+        return chunks_cohorts
 
 
 def rechunk_for_cohorts(
@@ -734,13 +736,6 @@ def _squeeze_results(results: IntermediateDict, axis: T_Axes) -> IntermediateDic
     return newresults
 
 
-def _split_groups(array, j, slicer):
-    """Slices out chunks when split_out > 1"""
-    results = {"groups": array["groups"][..., slicer]}
-    results["intermediates"] = [v[..., slicer] for v in array["intermediates"]]
-    return results
-
-
 def _finalize_results(
     results: IntermediateDict,
     agg: Aggregation,
@@ -790,6 +785,7 @@ def _finalize_results(
     else:
         finalized["groups"] = squeezed["groups"]
 
+    finalized[agg.name] = finalized[agg.name].astype(agg.dtype[agg.name], copy=False)
     return finalized
 
 
@@ -880,7 +876,6 @@ def _grouped_combine(
     agg: Aggregation,
     axis: T_Axes,
     keepdims: bool,
-    neg_axis: T_Axes,
     engine: T_Engine,
     is_aggregate: bool = False,
     sort: bool = True,
@@ -906,6 +901,9 @@ def _grouped_combine(
             partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
         )
 
+    # these are negative axis indices useful for concatenating the intermediates
+    neg_axis = tuple(range(-len(axis), 0))
+
     groups = _conc2(x_chunk, "groups", axis=neg_axis)
 
     if agg.reduction_type == "argreduce":
@@ -992,40 +990,17 @@ def _grouped_combine(
     return results
 
 
-def split_blocks(applied, split_out, expected_groups, split_name):
-    import dask.array
-    from dask.array.core import normalize_chunks
-    from dask.highlevelgraph import HighLevelGraph
-
-    chunk_tuples = tuple(itertools.product(*tuple(range(n) for n in applied.numblocks)))
-    ngroups = len(expected_groups)
-    group_chunks = normalize_chunks(np.ceil(ngroups / split_out), (ngroups,))
-    idx = tuple(np.cumsum((0,) + group_chunks[0]))
-
-    # split each block into `split_out` chunks
-    dsk = {}
-    for i in chunk_tuples:
-        for j in range(split_out):
-            dsk[(split_name, *i, j)] = (
-                _split_groups,
-                (applied.name, *i),
-                j,
-                slice(idx[j], idx[j + 1]),
-            )
-
-    # now construct an array that can be passed to _tree_reduce
-    intergraph = HighLevelGraph.from_collections(split_name, dsk, dependencies=(applied,))
-    intermediate = dask.array.Array(
-        intergraph,
-        name=split_name,
-        chunks=applied.chunks + ((1,) * split_out,),
-        meta=applied._meta,
-    )
-    return intermediate, group_chunks
-
-
 def _reduce_blockwise(
-    array, by, agg, *, axis: T_Axes, expected_groups, fill_value, engine: T_Engine, sort, reindex
+    array,
+    by,
+    agg: Aggregation,
+    *,
+    axis: T_Axes,
+    expected_groups,
+    fill_value,
+    engine: T_Engine,
+    sort,
+    reindex,
 ) -> FinalResultsDict:
     """
     Blockwise groupby reduction that produces the final result. This code path is
@@ -1068,18 +1043,99 @@ def _reduce_blockwise(
     return result
 
 
+def subset_to_blocks(
+    array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
+) -> DaskArray:
+    """
+    Advanced indexing of .blocks such that we always get a regular array back.
+
+    Parameters
+    ----------
+    array : dask.array
+    flatblocks : flat indices of blocks to extract
+    blkshape : shape of blocks with which to unravel flatblocks
+
+    Returns
+    -------
+    dask.array
+    """
+    if blkshape is None:
+        blkshape = array.blocks.shape
+
+    unraveled = np.unravel_index(flatblocks, blkshape)
+    normalized: list[Union[int, np.ndarray, slice]] = []
+    for ax, idx in enumerate(unraveled):
+        i = np.unique(idx).squeeze()
+        if i.ndim == 0:
+            normalized.append(i.item())
+        else:
+            if np.array_equal(i, np.arange(blkshape[ax])):
+                normalized.append(slice(None))
+            elif np.array_equal(i, np.arange(i[0], i[-1] + 1)):
+                normalized.append(slice(i[0], i[-1] + 1))
+            else:
+                normalized.append(i)
+    full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized)
+
+    # has no iterables
+    noiter = tuple(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
+    # has all iterables
+    alliter = {
+        ax: i if hasattr(i, "__len__") else slice(None) for ax, i in enumerate(full_normalized)
+    }
+
+    # apply everything but the iterables
+    if all(i == slice(None) for i in noiter):
+        return array
+
+    subset = array.blocks[noiter]
+
+    for ax, inds in alliter.items():
+        if isinstance(inds, slice):
+            continue
+        idxr = [slice(None, None)] * array.ndim
+        idxr[ax] = inds
+        subset = subset.blocks[tuple(idxr)]
+
+    return subset
+
+
+def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
+    import dask.array
+    from dask.highlevelgraph import HighLevelGraph
+
+    layer: dict[tuple, tuple] = {}
+    groups_token = f"group-{reduced.name}"
+    first_block = reduced.ndim * (0,)
+    layer[(groups_token, *first_block)] = (
+        operator.getitem,
+        (reduced.name, *first_block),
+        "groups",
+    )
+    groups: tuple[DaskArray] = (
+        dask.array.Array(
+            HighLevelGraph.from_collections(groups_token, layer, dependencies=[reduced]),
+            groups_token,
+            chunks=group_chunks,
+            meta=np.array([], dtype=dtype),
+        ),
+    )
+
+    return groups
+
+
 def dask_groupby_agg(
     array: DaskArray,
     by: DaskArray | np.ndarray,
     agg: Aggregation,
     expected_groups: pd.Index | None,
     axis: T_Axes = (),
-    split_out: int = 1,
     fill_value: Any = None,
     method: T_Method = "map-reduce",
     reindex: bool = False,
     engine: T_Engine = "numpy",
     sort: bool = True,
+    chunks_cohorts=None,
 ) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
 
     import dask.array
@@ -1090,19 +1146,14 @@ def dask_groupby_agg(
     assert isinstance(axis, Sequence)
     assert all(ax >= 0 for ax in axis)
 
-    if method == "blockwise" and (split_out > 1 or not isinstance(by, np.ndarray)):
-        raise NotImplementedError
-
-    if split_out > 1 and expected_groups is None:
-        # This could be implemented using the "hash_split" strategy
-        # from dask.dataframe
+    if method == "blockwise" and not isinstance(by, np.ndarray):
         raise NotImplementedError
 
     inds = tuple(range(array.ndim))
     name = f"groupby_{agg.name}"
-    token = dask.base.tokenize(array, by, agg, expected_groups, axis, split_out)
+    token = dask.base.tokenize(array, by, agg, expected_groups, axis)
 
-    if expected_groups is None and (reindex or split_out > 1):
+    if expected_groups is None and reindex:
         expected_groups = _get_expected_groups(by, sort=sort)
 
     by_input = by
@@ -1133,9 +1184,7 @@ def dask_groupby_agg(
     #       This allows us to discover groups at compute time, support argreductions, lower intermediate
     #       memory usage (but method="cohorts" would also work to reduce memory in some cases)
 
-    do_simple_combine = (
-        method != "blockwise" and reindex and not _is_arg_reduction(agg) and split_out == 1
-    )
+    do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
     if method == "blockwise":
         #  use the "non dask" code path, but applied blockwise
         blockwise_method = partial(
@@ -1148,18 +1197,18 @@ def dask_groupby_agg(
             func=agg.chunk,
             fill_value=agg.fill_value["intermediate"],
             dtype=agg.dtype["intermediate"],
-            reindex=reindex or (split_out > 1),
+            reindex=reindex,
         )
         if do_simple_combine:
             # Add a dummy dimension that then gets reduced over
             blockwise_method = tlz.compose(_expand_dims, blockwise_method)
 
     # apply reduction on chunk
-    applied = dask.array.blockwise(
+    intermediate = dask.array.blockwise(
         partial(
             blockwise_method,
             axis=axis,
-            expected_groups=expected_groups,
+            expected_groups=None if method in ["split-reduce", "cohorts"] else expected_groups,
             engine=engine,
             sort=sort,
         ),
@@ -1175,54 +1224,88 @@ def dask_groupby_agg(
         token=f"{name}-chunk-{token}",
     )
 
-    if split_out > 1:
-        intermediate, group_chunks = split_blocks(
-            applied, split_out, expected_groups, split_name=f"{name}-split-{token}"
-        )
-    else:
-        intermediate = applied
-        if expected_groups is None:
-            if is_duck_dask_array(by_input):
-                expected_groups = None
-            else:
-                expected_groups = _get_expected_groups(by_input, sort=sort)
-        group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
-
-    if method == "map-reduce":
-        # these are negative axis indices useful for concatenating the intermediates
-        neg_axis = tuple(range(-len(axis), 0))
+    if expected_groups is None:
+        if is_duck_dask_array(by_input):
+            expected_groups = None
+        else:
+            expected_groups = _get_expected_groups(by_input, sort=sort)
+    group_chunks: tuple[tuple[Union[int, float], ...]] = (
+        (len(expected_groups),) if expected_groups is not None else (np.nan,),
+    )
 
+    if method in ["map-reduce", "cohorts", "split-reduce"]:
         combine: Callable[..., IntermediateDict]
         if do_simple_combine:
             combine = _simple_combine
         else:
-            combine = partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort)
+            combine = partial(_grouped_combine, engine=engine, sort=sort)
 
-        # reduced is really a dict mapping reduction name to array
-        # and "groups" to an array of group labels
+        # Each chunk of `reduced`` is really a dict mapping
+        # 1. reduction name to array
+        # 2. "groups" to an array of group labels
         # Note: it does not make sense to interpret axis relative to
         # shape of intermediate results after the blockwise call
-        reduced = dask.array.reductions._tree_reduce(
-            intermediate,
-            aggregate=partial(
-                _aggregate,
-                combine=combine,
-                agg=agg,
-                expected_groups=None if split_out > 1 else expected_groups,
-                fill_value=fill_value,
-                reindex=reindex,
-            ),
+        tree_reduce = partial(
+            dask.array.reductions._tree_reduce,
             combine=partial(combine, agg=agg),
-            name=f"{name}-reduce",
+            name=f"{name}-reduce-{method}",
             dtype=array.dtype,
             axis=axis,
             keepdims=True,
             concatenate=False,
         )
-        output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
+        aggregate = partial(
+            _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
+        )
+        if method == "map-reduce":
+            reduced = tree_reduce(
+                intermediate,
+                aggregate=partial(aggregate, expected_groups=expected_groups),
+            )
+            if is_duck_dask_array(by_input) and expected_groups is None:
+                groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
+            else:
+                if expected_groups is None:
+                    expected_groups_ = _get_expected_groups(by_input, sort=sort)
+                else:
+                    expected_groups_ = expected_groups
+                groups = (expected_groups_.to_numpy(),)
+
+        elif method in ["cohorts", "split-reduce"]:
+            chunks_cohorts = find_group_cohorts(
+                by_input, [array.chunks[ax] for ax in axis], merge=True
+            )
+            reduced_ = []
+            groups_ = []
+            for blks, cohort in chunks_cohorts.items():
+                subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
+                if do_simple_combine:
+                    # reindex so that reindex can be set to True later
+                    reindexed = dask.array.map_blocks(
+                        reindex_intermediates,
+                        subset,
+                        agg=agg,
+                        unique_groups=cohort,
+                        meta=subset._meta,
+                    )
+                else:
+                    reindexed = subset
+
+                reduced_.append(
+                    tree_reduce(
+                        reindexed,
+                        aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
+                    )
+                )
+                groups_.append(cohort)
+
+            reduced = dask.array.concatenate(reduced_, axis=-1)
+            groups = (np.concatenate(groups_),)
+            group_chunks = (tuple(len(cohort) for cohort in groups_),)
+
     elif method == "blockwise":
         reduced = intermediate
-        # Here one input chunk → one output chunka
+        # Here one input chunk → one output chunks
         # find number of groups in each chunk, this is needed for output chunks
         # along the reduced axis
         slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
@@ -1235,41 +1318,17 @@ def dask_groupby_agg(
             groups_in_block = tuple(
                 np.intersect1d(by_input[slc], expected_groups) for slc in slices
             )
+        groups = (np.concatenate(groups_in_block),)
+
         ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
-        output_chunks = reduced.chunks[: -(len(axis))] + (ngroups_per_block,)
+        group_chunks = (ngroups_per_block,)
+
     else:
         raise ValueError(f"Unknown method={method}.")
 
     # extract results from the dict
-    layer: dict[tuple, tuple] = {}
+    output_chunks = reduced.chunks[: -len(axis)] + group_chunks
     ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
-    if is_duck_dask_array(by_input) and expected_groups is None:
-        groups_name = f"groups-{name}-{token}"
-        # we've used keepdims=True, so _tree_reduce preserves some dummy dimensions
-        first_block = len(ochunks) * (0,)
-        layer[(groups_name, *first_block)] = (
-            operator.getitem,
-            (reduced.name, *first_block),
-            "groups",
-        )
-        groups: tuple[np.ndarray | DaskArray] = (
-            dask.array.Array(
-                HighLevelGraph.from_collections(groups_name, layer, dependencies=[reduced]),
-                groups_name,
-                chunks=group_chunks,
-                dtype=by.dtype,
-            ),
-        )
-    else:
-        if method == "map-reduce":
-            if expected_groups is None:
-                expected_groups_ = _get_expected_groups(by_input, sort=sort)
-            else:
-                expected_groups_ = expected_groups
-            groups = (expected_groups_.to_numpy(),)
-        else:
-            groups = (np.concatenate(groups_in_block),)
-
     layer2: dict[tuple, tuple] = {}
     agg_name = f"{name}-{token}"
     for ochunk in itertools.product(*ochunks):
@@ -1280,7 +1339,8 @@ def dask_groupby_agg(
                 nblocks = tuple(len(array.chunks[ax]) for ax in axis)
                 inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
         else:
-            inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1)
+            inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)
+
         layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)
 
     result = dask.array.Array(
@@ -1309,6 +1369,9 @@ def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_gro
     if method in ["split-reduce", "cohorts"] and reindex is False:
         raise NotImplementedError
 
+    if method in ["split-reduce", "cohorts"] and reindex is None:
+        reindex = True
+
     # TODO: Should reindex be a bool-only at this point? Would've been nice but
     # None's are relied on after this function as well.
     return reindex
@@ -1327,7 +1390,7 @@ def _assert_by_is_aligned(shape, by):
 
 
 def _convert_expected_groups_to_index(
-    expected_groups: Iterable, isbin: Sequence[bool], sort: bool
+    expected_groups: T_ExpectedGroups, isbin: Sequence[bool], sort: bool
 ) -> tuple[pd.Index | None, ...]:
     out: list[pd.Index | None] = []
     for ex, isbin_ in zip(expected_groups, isbin):
@@ -1389,14 +1452,14 @@ def _factorize_multiple(by, expected_groups, by_is_dask, reindex):
 def groupby_reduce(
     array: np.ndarray | DaskArray,
     *by: np.ndarray | DaskArray,
-    func: str | Aggregation,
-    expected_groups: Sequence | np.ndarray | None = None,
+    func: T_Agg,
+    expected_groups: T_ExpectedGroupsOpt = None,
     sort: bool = True,
     isbin: T_IsBins = False,
     axis: T_AxesOpt = None,
     fill_value=None,
+    dtype: np.typing.DTypeLike = None,
     min_count: int | None = None,
-    split_out: int = 1,
     method: T_Method = "map-reduce",
     engine: T_Engine = "numpy",
     reindex: bool | None = None,
@@ -1428,13 +1491,13 @@ def groupby_reduce(
         Negative integers are normalized using array.ndim
     fill_value : Any
         Value to assign when a label in ``expected_groups`` is not present.
+    dtype: data-type , optional
+        DType for the output. Can be anything that is accepted by ``np.dtype``.
     min_count : int, default: None
         The required number of valid values to perform the operation. If
         fewer than min_count non-NA values are present the result will be
         NA. Only used if skipna is set to True or defaults to True for the
         array's dtype.
-    split_out : int, optional
-        Number of chunks along group axis in output (last axis)
     method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
         Strategy for reduction of dask arrays only:
           * ``"map-reduce"``:
@@ -1460,9 +1523,7 @@ def groupby_reduce(
             method by first rechunking using ``rechunk_for_cohorts``
             (for 1D ``by`` only).
           * ``"split-reduce"``:
-            Break out each group into its own array and then ``"map-reduce"``.
-            This is implemented by having each group be its own cohort,
-            and is identical to xarray's default strategy.
+            Same as "cohorts" and will be removed soon.
     engine : {"flox", "numpy", "numba"}, optional
         Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
           * ``"numpy"``:
@@ -1512,7 +1573,8 @@ def groupby_reduce(
 
     if not is_duck_array(array):
         array = np.asarray(array)
-    array = array.astype(int) if np.issubdtype(array.dtype, bool) else array
+    is_bool_array = np.issubdtype(array.dtype, bool)
+    array = array.astype(int) if is_bool_array else array
 
     if isinstance(isbin, Sequence):
         isbins = isbin
@@ -1604,7 +1666,7 @@ def groupby_reduce(
         fill_value = np.nan
 
     kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
-    agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)
+    agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
 
     if not has_dask:
         results = _reduce_blockwise(
@@ -1624,73 +1686,33 @@ def groupby_reduce(
                 f"\n\n Received: {func}"
             )
 
+        # TODO: just do this in dask_groupby_agg
         # we always need some fill_value (see above) so choose the default if needed
         if kwargs["fill_value"] is None:
             kwargs["fill_value"] = agg.fill_value[agg.name]
 
-        partial_agg = partial(dask_groupby_agg, split_out=split_out, **kwargs)
+        partial_agg = partial(dask_groupby_agg, **kwargs)
 
-        if method in ["split-reduce", "cohorts"]:
-            cohorts = find_group_cohorts(
-                by_, [array.chunks[ax] for ax in axis_], merge=True, method=method
-            )
-
-            results_ = []
-            groups_ = []
-            for cohort in cohorts:
-                cohort = sorted(cohort)
-                # equivalent of xarray.DataArray.where(mask, drop=True)
-                mask = np.isin(by_, cohort)
-                indexer = [np.unique(v) for v in np.nonzero(mask)]
-                array_subset = array
-                for ax, idxr in zip(range(-by_.ndim, 0), indexer):
-                    array_subset = np.take(array_subset, idxr, axis=ax)
-                numblocks = math.prod([len(array_subset.chunks[ax]) for ax in axis_])
-
-                # get final result for these groups
-                r, *g = partial_agg(
-                    array_subset,
-                    by_[np.ix_(*indexer)],
-                    expected_groups=pd.Index(cohort),
-                    # First deep copy becasue we might be doping blockwise,
-                    # which sets agg.finalize=None, then map-reduce (GH102)
-                    agg=copy.deepcopy(agg),
-                    # reindex to expected_groups at the blockwise step.
-                    # this approach avoids replacing non-cohort members with
-                    # np.nan or some other sentinel value, and preserves dtypes
-                    reindex=True,
-                    # sort controls the final output order so apply that at the end
-                    sort=False,
-                    # if only a single block along axis, we can just work blockwise
-                    # inspired by https://github.com/dask/dask/issues/8361
-                    method="blockwise" if numblocks == 1 and nax == by_.ndim else "map-reduce",
-                )
-                results_.append(r)
-                groups_.append(cohort)
+        if method == "blockwise" and by_.ndim == 1:
+            array = rechunk_for_blockwise(array, axis=-1, labels=by_)
 
-            # concatenate results together,
-            # sort to make sure we match expected output
-            groups = (np.hstack(groups_),)
-            result = np.concatenate(results_, axis=-1)
-        else:
-            if method == "blockwise" and by_.ndim == 1:
-                array = rechunk_for_blockwise(array, axis=-1, labels=by_)
-
-            result, groups = partial_agg(
-                array,
-                by_,
-                expected_groups=None if method == "blockwise" else expected_groups,
-                agg=agg,
-                reindex=reindex,
-                method=method,
-                sort=sort,
-            )
+        result, groups = partial_agg(
+            array,
+            by_,
+            expected_groups=None if method == "blockwise" else expected_groups,
+            agg=agg,
+            reindex=reindex,
+            method=method,
+            sort=sort,
+        )
 
         if sort and method != "map-reduce":
             assert len(groups) == 1
             sorted_idx = np.argsort(groups[0])
-            result = result[..., sorted_idx]
-            groups = (groups[0][sorted_idx],)
+            # This optimization helps specifically with resampling
+            if not (sorted_idx[1:] <= sorted_idx[:-1]).all():
+                result = result[..., sorted_idx]
+                groups = (groups[0][sorted_idx],)
 
     if factorize_early:
         # nan group labels are factorized to -1, and preserved
@@ -1700,4 +1722,7 @@ def groupby_reduce(
             result, from_=groups[0], to=expected_groups, fill_value=fill_value
         ).reshape(result.shape[:-1] + grp_shape)
         groups = final_groups
+
+    if _is_minmax_reduction(func) and is_bool_array:
+        result = result.astype(bool)
     return (result, *groups)


=====================================
flox/visualize.py
=====================================
@@ -136,10 +136,10 @@ def visualize_cohorts_2d(by, array, method="cohorts"):
     print("finding cohorts...")
     before_merged = find_group_cohorts(
         by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=False, method=method
-    )
+    ).values()
     merged = find_group_cohorts(
         by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=True, method=method
-    )
+    ).values()
     print("finished cohorts...")
 
     xticks = np.cumsum(array.chunks[-1])


=====================================
flox/xarray.py
=====================================
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING, Any, Hashable, Iterable, Sequence, Union
 
 import numpy as np
@@ -61,8 +62,8 @@ def xarray_reduce(
     isbin: bool | Sequence[bool] = False,
     sort: bool = True,
     dim: Dims | ellipsis = None,
-    split_out: int = 1,
     fill_value=None,
+    dtype: np.typing.DTypeLike = None,
     method: str = "map-reduce",
     engine: str = "numpy",
     keep_attrs: bool | None = True,
@@ -93,11 +94,11 @@ def xarray_reduce(
     dim : hashable
         dimension name along which to reduce. If None, reduces across all
         dimensions of `by`
-    split_out : int, optional
-        Number of output chunks along grouped dimension in output.
     fill_value
         Value used for missing groups in the output i.e. when one of the labels
         in ``expected_groups`` is not actually present in ``by``.
+    dtype: data-type, optional
+        DType for the output. Can be anything accepted by ``np.dtype``.
     method : {"map-reduce", "blockwise", "cohorts", "split-reduce"}, optional
         Strategy for reduction of dask arrays only:
           * ``"map-reduce"``:
@@ -123,9 +124,7 @@ def xarray_reduce(
             method by first rechunking using ``rechunk_for_cohorts``
             (for 1D ``by`` only).
           * ``"split-reduce"``:
-            Break out each group into its own array and then ``"map-reduce"``.
-            This is implemented by having each group be its own cohort,
-            and is identical to xarray's default strategy.
+            Same as "cohorts" and will be removed soon.
     engine : {"flox", "numpy", "numba"}, optional
         Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
           * ``"numpy"``:
@@ -387,13 +386,14 @@ def xarray_reduce(
         exclude_dims=set(dim_tuple),
         output_core_dims=[group_names],
         dask="allowed",
-        dask_gufunc_kwargs=dict(output_sizes=group_sizes),
+        dask_gufunc_kwargs=dict(
+            output_sizes=group_sizes, output_dtypes=[dtype] if dtype is not None else None
+        ),
         keep_attrs=keep_attrs,
         kwargs={
             "func": func,
             "axis": axis,
             "sort": sort,
-            "split_out": split_out,
             "fill_value": fill_value,
             "method": method,
             "min_count": min_count,
@@ -403,6 +403,7 @@ def xarray_reduce(
             "expected_groups": tuple(expected_groups),
             "isbin": isbins,
             "finalize_kwargs": finalize_kwargs,
+            "dtype": dtype,
         },
     )
 
@@ -561,6 +562,11 @@ def resample_reduce(
     **kwargs,
 ):
 
+    warnings.warn(
+        "flox.xarray.resample_reduce is now deprecated. Please use Xarray's resample method directly.",
+        DeprecationWarning,
+    )
+
     obj = resampler._obj
     dim = resampler._group_dim
 


=====================================
tests/__init__.py
=====================================
@@ -14,7 +14,7 @@ try:
 
     dask_array_type = da.Array
 except ImportError:
-    dask_array_type = ()
+    dask_array_type = ()  # type: ignore
 
 
 try:
@@ -22,7 +22,7 @@ try:
 
     xr_types = (xr.DataArray, xr.Dataset)
 except ImportError:
-    xr_types = ()
+    xr_types = ()  # type: ignore
 
 
 def _importorskip(modname, minversion=None):
@@ -80,25 +80,39 @@ def raise_if_dask_computes(max_computes=0):
     return dask.config.set(scheduler=scheduler)
 
 
-def assert_equal(a, b):
+def assert_equal(a, b, tolerance=None):
     __tracebackhide__ = True
 
     if isinstance(a, list):
         a = np.array(a)
     if isinstance(b, list):
         b = np.array(b)
+
     if isinstance(a, pd_types) or isinstance(b, pd_types):
         pd.testing.assert_index_equal(a, b)
-    elif has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types):
+        return
+    if has_xarray and isinstance(a, xr_types) or isinstance(b, xr_types):
         xr.testing.assert_identical(a, b)
-    elif has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
+        return
+
+    if tolerance is None and (
+        np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64)
+    ):
+        tolerance = {"atol": 1e-18, "rtol": 1e-15}
+    else:
+        tolerance = {}
+
+    if has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
         # sometimes it's nice to see values and shapes
         # rather than being dropped into some file in dask
-        np.testing.assert_allclose(a, b)
+        np.testing.assert_allclose(a, b, **tolerance)
         # does some validation of the dask graph
         da.utils.assert_eq(a, b, equal_nan=True)
     else:
-        np.testing.assert_allclose(a, b, equal_nan=True)
+        if a.dtype != b.dtype:
+            raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")
+
+        np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)
 
 
 @pytest.fixture(scope="module", params=["flox", "numpy", "numba"])


=====================================
tests/test_core.py
=====================================
@@ -1,4 +1,7 @@
+from __future__ import annotations
+
 from functools import reduce
+from typing import TYPE_CHECKING
 
 import numpy as np
 import pandas as pd
@@ -63,6 +66,9 @@ ALL_FUNCS = (
     pytest.param("nanmedian", marks=(pytest.mark.skip,)),
 )
 
+if TYPE_CHECKING:
+    from flox.core import T_Engine, T_ExpectedGroupsOpt, T_Func2
+
 
 def test_alignment_error():
     da = np.ones((12,))
@@ -73,7 +79,7 @@ def test_alignment_error():
 
 
 @pytest.mark.parametrize("dtype", (float, int))
-@pytest.mark.parametrize("chunk, split_out", [(False, 1), (True, 1), (True, 2), (True, 3)])
+@pytest.mark.parametrize("chunk", [False, True])
 @pytest.mark.parametrize("expected_groups", [None, [0, 1, 2], np.array([0, 1, 2])])
 @pytest.mark.parametrize(
     "func, array, by, expected",
@@ -101,8 +107,15 @@ def test_alignment_error():
     ],
 )
 def test_groupby_reduce(
-    array, by, expected, func, expected_groups, chunk, split_out, dtype, engine
-):
+    engine: T_Engine,
+    func: T_Func2,
+    array: np.ndarray,
+    by: np.ndarray,
+    expected: list[float],
+    expected_groups: T_ExpectedGroupsOpt,
+    chunk: bool,
+    dtype: np.typing.DTypeLike,
+) -> None:
     array = array.astype(dtype)
     if chunk:
         if not has_dask or expected_groups is None:
@@ -110,12 +123,12 @@ def test_groupby_reduce(
         array = da.from_array(array, chunks=(3,) if array.ndim == 1 else (1, 3))
         by = da.from_array(by, chunks=(3,) if by.ndim == 1 else (1, 3))
 
-    if "mean" in func:
-        expected = np.array(expected, dtype=float)
+    if func == "mean" or func == "nanmean":
+        expected_result = np.array(expected, dtype=float)
     elif func == "sum":
-        expected = np.array(expected, dtype=dtype)
+        expected_result = np.array(expected, dtype=dtype)
     elif func == "count":
-        expected = np.array(expected, dtype=int)
+        expected_result = np.array(expected, dtype=int)
 
     result, groups, = groupby_reduce(
         array,
@@ -123,11 +136,12 @@ def test_groupby_reduce(
         func=func,
         expected_groups=expected_groups,
         fill_value=123,
-        split_out=split_out,
         engine=engine,
     )
-    assert_equal(groups, [0, 1, 2])
-    assert_equal(expected, result)
+    g_dtype = by.dtype if expected_groups is None else np.asarray(expected_groups).dtype
+
+    assert_equal(groups, np.array([0, 1, 2], g_dtype))
+    assert_equal(expected_result, result)
 
 
 def gen_array_by(size, func):
@@ -169,8 +183,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
     if "var" in func or "std" in func:
         finalize_kwargs = finalize_kwargs + [{"ddof": 1}, {"ddof": 0}]
         fill_value = np.nan
+        tolerance = {"rtol": 1e-14, "atol": 1e-16}
     else:
         fill_value = None
+        tolerance = None
 
     for kwargs in finalize_kwargs:
         flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value)
@@ -191,7 +207,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
             assert_equal(actual_group, expect)
         if "arg" in func:
             assert actual.dtype.kind == "i"
-        assert_equal(actual, expected)
+        assert_equal(actual, expected, tolerance)
 
         if not has_dask:
             continue
@@ -200,10 +216,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
                 continue
             actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs)
             for actual_group, expect in zip(groups, expected_groups):
-                assert_equal(actual_group, expect)
+                assert_equal(actual_group, expect, tolerance)
             if "arg" in func:
                 assert actual.dtype.kind == "i"
-            assert_equal(actual, expected)
+            assert_equal(actual, expected, tolerance)
 
 
 @requires_dask
@@ -450,6 +466,11 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
         fill_value = False
     else:
         fill_value = 123
+
+    if "var" in func or "std" in func:
+        tolerance = {"rtol": 1e-14, "atol": 1e-16}
+    else:
+        tolerance = None
     # tests against the numpy output to make sure dask compute matches
     by = np.broadcast_to(labels2d, (3, *labels2d.shape))
     rng = np.random.default_rng(12345)
@@ -468,7 +489,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
         kwargs.pop("engine")
         expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
         assert_equal(expected_npg, expected)
-    assert_equal(actual, expected)
+    assert_equal(actual, expected, tolerance)
 
 
 @pytest.mark.parametrize("chunks", [None, (2, 2, 3)])
@@ -646,11 +667,11 @@ def test_rechunk_for_blockwise(inchunks, expected):
         [[[1, 2, 3, 4]], [1, 2, 3, 1, 2, 3, 4], (3, 4), True],
         [[[1, 2, 3], [4]], [1, 2, 3, 1, 2, 3, 4], (3, 4), False],
         [[[1], [2], [3], [4]], [1, 2, 3, 1, 2, 3, 4], (2, 2, 2, 1), False],
-        [[[3], [2], [1], [4]], [1, 2, 3, 1, 2, 3, 4], (2, 2, 2, 1), True],
+        [[[1], [2], [3], [4]], [1, 2, 3, 1, 2, 3, 4], (2, 2, 2, 1), True],
         [[[1, 2, 3], [4]], [1, 2, 3, 1, 2, 3, 4], (3, 3, 1), True],
         [[[1, 2, 3], [4]], [1, 2, 3, 1, 2, 3, 4], (3, 3, 1), False],
         [
-            [[2, 3, 4, 1], [5], [0]],
+            [[0], [1, 2, 3, 4], [5]],
             np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]),
             (4, 8, 4, 9, 4),
             True,
@@ -658,11 +679,7 @@ def test_rechunk_for_blockwise(inchunks, expected):
     ],
 )
 def test_find_group_cohorts(expected, labels, chunks, merge):
-    actual = list(find_group_cohorts(labels, (chunks,), merge, method="cohorts"))
-    assert actual == expected, (actual, expected)
-
-    actual = find_group_cohorts(labels, (chunks,), merge, method="split-reduce")
-    expected = [[label] for label in np.unique(labels)]
+    actual = list(find_group_cohorts(labels, (chunks,), merge).values())
     assert actual == expected, (actual, expected)
 
 
@@ -776,11 +793,9 @@ def test_cohorts_nd_by(func, method, axis, engine):
     assert_equal(actual, expected)
 
     actual, groups = groupby_reduce(array, by, sort=False, **kwargs)
-    if method == "cohorts":
-        assert_equal(groups, [4, 3, 40, 2, 31, 1, 30])
-    elif method in ("split-reduce", "map-reduce"):
+    if method == "map-reduce":
         assert_equal(groups, [1, 30, 2, 31, 3, 4, 40])
-    elif method == "blockwise":
+    else:
         assert_equal(groups, [1, 30, 2, 31, 3, 40, 4])
     reindexed = reindex_(actual, groups, pd.Index(sorted_groups))
     assert_equal(reindexed, expected)
@@ -843,16 +858,16 @@ def test_bool_reductions(func, engine):
 
 
 @requires_dask
-def test_map_reduce_blockwise_mixed():
+def test_map_reduce_blockwise_mixed() -> None:
     t = pd.date_range("2000-01-01", "2000-12-31", freq="D").to_series()
     data = t.dt.dayofyear
-    actual = groupby_reduce(
+    actual, _ = groupby_reduce(
         dask.array.from_array(data.values, chunks=365),
         t.dt.month,
         func="mean",
         method="split-reduce",
     )
-    expected = groupby_reduce(data, t.dt.month, func="mean")
+    expected, _ = groupby_reduce(data, t.dt.month, func="mean")
     assert_equal(expected, actual)
 
 
@@ -908,7 +923,7 @@ def test_factorize_values_outside_bins():
     assert_equal(expected, actual)
 
 
-def test_multiple_groupers():
+def test_multiple_groupers() -> None:
     actual, *_ = groupby_reduce(
         np.ones((5, 2)),
         np.arange(10).reshape(5, 2),
@@ -921,7 +936,7 @@ def test_multiple_groupers():
         reindex=True,
         func="count",
     )
-    expected = np.eye(5, 5)
+    expected = np.eye(5, 5, dtype=int)
     assert_equal(expected, actual)
 
 
@@ -1009,3 +1024,14 @@ def test_custom_aggregation_blockwise():
         method="blockwise",
     )
     assert_equal(expected, actual)
+
+
+@pytest.mark.parametrize("func", ALL_FUNCS)
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_dtype(func, dtype, engine):
+    if "arg" in func or func in ["any", "all"]:
+        pytest.skip()
+    arr = np.ones((4, 12), dtype=dtype)
+    labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
+    actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64)
+    assert actual.dtype == np.dtype("float64")


=====================================
tests/test_xarray.py
=====================================
@@ -24,6 +24,10 @@ except ValueError:
     pass
 
 
+tolerance64 = {"rtol": 1e-15, "atol": 1e-18}
+np.random.seed(123)
+
+
 @pytest.mark.parametrize("reindex", [None, False, True])
 @pytest.mark.parametrize("min_count", [None, 1, 3])
 @pytest.mark.parametrize("add_nan", [True, False])
@@ -250,10 +254,15 @@ def test_xarray_resample(chunklen, isdask, dataarray, engine):
         ds = ds.air
 
     resampler = ds.resample(time="M")
-    actual = resample_reduce(resampler, "mean", engine=engine)
+    with pytest.warns(DeprecationWarning):
+        actual = resample_reduce(resampler, "mean", engine=engine)
     expected = resampler.mean()
     xr.testing.assert_allclose(actual, expected)
 
+    with xr.set_options(use_flox=True):
+        actual = resampler.mean()
+    xr.testing.assert_allclose(actual, expected)
+
 
 @requires_dask
 def test_xarray_resample_dataset_multiple_arrays(engine):
@@ -488,3 +497,76 @@ def test_mixed_grouping(chunk):
         fill_value=0,
     )
     assert (r.sel(v1=[3, 4, 5]) == 0).all().data
+
+
+@pytest.mark.parametrize("add_nan", [True, False])
+@pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("chunk", (True, False))
+def test_dtype(add_nan, chunk, dtype, dtype_out, engine):
+    if chunk and not has_dask:
+        pytest.skip()
+
+    xp = dask.array if chunk else np
+    data = xp.linspace(0, 1, 48, dtype=dtype).reshape((4, 12))
+
+    if add_nan:
+        data[1, ...] = np.nan
+        data[0, [0, 2]] = np.nan
+
+    arr = xr.DataArray(
+        data,
+        dims=("x", "t"),
+        coords={
+            "labels": ("t", np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]))
+        },
+        name="arr",
+    )
+    kwargs = dict(func="mean", dtype=dtype_out, engine=engine)
+    actual = xarray_reduce(arr, "labels", **kwargs)
+    expected = arr.groupby("labels").mean(dtype="float64")
+
+    assert actual.dtype == np.dtype("float64")
+    assert actual.compute().dtype == np.dtype("float64")
+    xr.testing.assert_allclose(expected, actual, **tolerance64)
+
+    actual = xarray_reduce(arr.to_dataset(), "labels", **kwargs)
+    expected = arr.to_dataset().groupby("labels").mean(dtype="float64")
+
+    assert actual.arr.dtype == np.dtype("float64")
+    assert actual.compute().arr.dtype == np.dtype("float64")
+    xr.testing.assert_allclose(expected, actual.transpose("labels", ...), **tolerance64)
+
+
+@pytest.mark.parametrize("chunk", [True, False])
+@pytest.mark.parametrize("use_flox", [True, False])
+def test_dtype_accumulation(use_flox, chunk):
+    if chunk and not has_dask:
+        pytest.skip()
+
+    datetimes = pd.date_range("2010-01", "2015-01", freq="6H", inclusive="left")
+    samples = 10 + np.cos(2 * np.pi * 0.001 * np.arange(len(datetimes))) * 1
+    samples += np.random.randn(len(datetimes))
+    samples = samples.astype("float32")
+
+    nan_indices = np.random.default_rng().integers(0, len(samples), size=5_000)
+    samples[nan_indices] = np.nan
+
+    da = xr.DataArray(samples, dims=("time",), coords=[datetimes])
+    if chunk:
+        da = da.chunk(time=1024)
+
+    gb = da.groupby("time.month")
+
+    with xr.set_options(use_flox=use_flox):
+        expected = gb.reduce(np.nanmean)
+        actual = gb.mean()
+        xr.testing.assert_allclose(expected, actual)
+        assert np.issubdtype(actual.dtype, np.float32)
+        assert np.issubdtype(actual.compute().dtype, np.float32)
+
+        expected = gb.reduce(np.nanmean, dtype="float64")
+        actual = gb.mean(dtype="float64")
+        assert np.issubdtype(actual.dtype, np.float64)
+        assert np.issubdtype(actual.compute().dtype, np.float64)
+        xr.testing.assert_allclose(expected, actual, **tolerance64)



View it on GitLab: https://salsa.debian.org/debian-gis-team/flox/-/commit/69e4060018b277ffd562f0d910fc6f25ed953da4

-- 
View it on GitLab: https://salsa.debian.org/debian-gis-team/flox/-/commit/69e4060018b277ffd562f0d910fc6f25ed953da4
You're receiving this email because of your account on salsa.debian.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://alioth-lists.debian.net/pipermail/pkg-grass-devel/attachments/20221015/c2f7e18e/attachment-0001.htm>


More information about the Pkg-grass-devel mailing list