[Git][debian-gis-team/flox][upstream] New upstream version 0.9.11

Antonio Valentino (@antonio.valentino) gitlab@salsa.debian.org
Fri Sep 13 21:55:24 BST 2024



Antonio Valentino pushed to branch upstream at Debian GIS Project / flox


Commits:
4a49f04c by Antonio Valentino at 2024-09-13T18:14:23+00:00
New upstream version 0.9.11
- - - - -


25 changed files:

- .github/workflows/ci.yaml
- .pre-commit-config.yaml
- asv_bench/benchmarks/__init__.py
- asv_bench/benchmarks/cohorts.py
- asv_bench/benchmarks/combine.py
- asv_bench/benchmarks/reduce.py
- ci/docs.yml
- ci/environment.yml
- ci/no-dask.yml
- docs/source/conf.py
- docs/source/user-stories/climatology-hourly.ipynb
- flox/__init__.py
- flox/aggregate_flox.py
- flox/aggregations.py
- flox/core.py
- flox/xarray.py
- flox/xrdtypes.py
- flox/xrutils.py
- pyproject.toml
- tests/__init__.py
- tests/strategies.py
- tests/test_asv.py
- tests/test_core.py
- tests/test_properties.py
- tests/test_xarray.py


Changes:

=====================================
.github/workflows/ci.yaml
=====================================
@@ -26,7 +26,7 @@ jobs:
       matrix:
         os: ["ubuntu-latest"]
         env: ["environment"]
-        python-version: ["3.9", "3.12"]
+        python-version: ["3.10", "3.12"]
         include:
           - os: "windows-latest"
             env: "environment"
@@ -36,7 +36,7 @@ jobs:
             python-version: "3.12"
           - os: "ubuntu-latest"
             env: "minimal-requirements"
-            python-version: "3.9"
+            python-version: "3.10"
     steps:
       - uses: actions/checkout@v4
         with:
@@ -70,6 +70,7 @@ jobs:
       - name: Run Tests
         id: status
         run: |
+          python -c "import xarray; xarray.show_versions()"
           pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci
       - name: Upload code coverage to Codecov
         uses: codecov/codecov-action@v4.5.0
@@ -98,7 +99,7 @@ jobs:
     steps:
       - uses: actions/checkout at v4
         with:
-          repository: "pydata/xarray"
+          repository: "dcherian/xarray"
           fetch-depth: 0 # Fetch all history for all branches and tags.
       - name: Set up conda environment
         uses: mamba-org/setup-micromamba@v1
@@ -112,6 +113,7 @@ jobs:
             pint>=0.22
       - name: Install xarray
         run: |
+          git checkout flox-preserve-dtype
           python -m pip install --no-deps .
       - name: Install upstream flox
         run: |


=====================================
.pre-commit-config.yaml
=====================================
@@ -4,10 +4,11 @@ ci:
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: "v0.5.0"
+    rev: "v0.6.4"
     hooks:
       - id: ruff
         args: ["--fix", "--show-fixes"]
+      - id: ruff-format
 
   - repo: https://github.com/pre-commit/mirrors-prettier
     rev: "v4.0.0-alpha.8"
@@ -22,11 +23,6 @@ repos:
       - id: end-of-file-fixer
       - id: check-docstring-first
 
-  - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.4.2
-    hooks:
-      - id: black
-
   - repo: https://github.com/executablebooks/mdformat
     rev: 0.7.17
     hooks:
@@ -35,13 +31,6 @@ repos:
           - mdformat-black
           - mdformat-myst
 
-  - repo: https://github.com/nbQA-dev/nbQA
-    rev: 1.8.5
-    hooks:
-      - id: nbqa-black
-      - id: nbqa-ruff
-        args: [--fix]
-
   - repo: https://github.com/kynan/nbstripout
     rev: 0.7.1
     hooks:
@@ -56,7 +45,7 @@ repos:
           - tomli
 
   - repo: https://github.com/abravalheri/validate-pyproject
-    rev: v0.18
+    rev: v0.19
     hooks:
       - id: validate-pyproject
 


=====================================
asv_bench/benchmarks/__init__.py
=====================================
@@ -21,7 +21,6 @@ def _skip_slow():
     >>> from . import _skip_slow
     >>> def time_something_slow():
     ...     pass
-    ...
     >>> time_something.setup = _skip_slow
     """
     if os.environ.get("ASV_SKIP_SLOW", "0") == "1":


=====================================
asv_bench/benchmarks/cohorts.py
=====================================
@@ -67,7 +67,12 @@ class Cohorts:
     track_num_tasks.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_tasks_optimized.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_layers.unit = "layers"  # type: ignore[attr-defined] # Lazy
-    for f in [track_num_tasks, track_num_tasks_optimized, track_num_layers, track_num_cohorts]:
+    for f in [
+        track_num_tasks,
+        track_num_tasks_optimized,
+        track_num_layers,
+        track_num_cohorts,
+    ]:
         f.repeat = 1  # type: ignore[attr-defined] # Lazy
         f.rounds = 1  # type: ignore[attr-defined] # Lazy
         f.number = 1  # type: ignore[attr-defined] # Lazy
@@ -82,9 +87,7 @@ class NWMMidwest(Cohorts):
         y = np.repeat(np.arange(30), 60)
         by = x[np.newaxis, :] * y[:, np.newaxis]
 
-        self.by = flox.core._factorize_multiple((by,), expected_groups=(None,), any_by_dask=False)[
-            0
-        ][0]
+        self.by = flox.core._factorize_multiple((by,), expected_groups=(None,), any_by_dask=False)[0][0]
 
         self.array = dask.array.ones(self.by.shape, chunks=(350, 350))
         self.axis = (-2, -1)
@@ -101,7 +104,12 @@ class ERA5Dataset:
 
     def rechunk(self):
         self.array = flox.core.rechunk_for_cohorts(
-            self.array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
+            self.array,
+            -1,
+            self.by,
+            force_new_chunk_at=[1],
+            chunksize=48,
+            ignore_old_chunks=True,
         )
 
 
@@ -151,7 +159,12 @@ class PerfectMonthly(Cohorts):
 
     def rechunk(self):
         self.array = flox.core.rechunk_for_cohorts(
-            self.array, -1, self.by, force_new_chunk_at=[1], chunksize=4, ignore_old_chunks=True
+            self.array,
+            -1,
+            self.by,
+            force_new_chunk_at=[1],
+            chunksize=4,
+            ignore_old_chunks=True,
         )
 
 


=====================================
asv_bench/benchmarks/combine.py
=====================================
@@ -65,12 +65,8 @@ class Combine1d(Combine):
             * 2
         ]
 
-        self.x_chunk_reindexed = [
-            construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4
-        ]
+        self.x_chunk_reindexed = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4]
         self.kwargs = {
-            "agg": flox.aggregations._initialize_aggregation(
-                "sum", "float64", np.float64, 0, 0, {}
-            ),
+            "agg": flox.aggregations._initialize_aggregation("sum", "float64", np.float64, 0, 0, {}),
             "axis": (3,),
         }


=====================================
asv_bench/benchmarks/reduce.py
=====================================
@@ -7,7 +7,11 @@ import flox.aggregations
 
 N = 3000
 funcs = ["sum", "nansum", "mean", "nanmean", "max", "nanmax", "count"]
-engines = [None, "flox", "numpy"]  # numbagg is disabled for now since it takes ages in CI
+engines = [
+    None,
+    "flox",
+    "numpy",
+]  # numbagg is disabled for now since it takes ages in CI
 expected_groups = {
     "None": None,
     "bins": pd.IntervalIndex.from_breaks([1, 2, 4]),
@@ -17,9 +21,7 @@ expected_names = tuple(expected_groups)
 NUMBAGG_FUNCS = ["nansum", "nanmean", "nanmax", "count", "all"]
 numbagg_skip = []
 for name in expected_names:
-    numbagg_skip.extend(
-        list((func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS)
-    )
+    numbagg_skip.extend(list((func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS))
 
 
 def setup_jit():


=====================================
ci/docs.yml
=====================================
@@ -16,6 +16,7 @@ dependencies:
   - myst-parser
   - myst-nb
   - sphinx
+  - sphinx-remove-toctrees
   - furo>=2024.08
   - ipykernel
   - jupyter


=====================================
ci/environment.yml
=====================================
@@ -19,7 +19,6 @@ dependencies:
   - pytest-pretty
   - pytest-xdist
   - syrupy
-  - xarray
   - pre-commit
   - numpy_groupies>=0.9.19
   - pooch
@@ -27,3 +26,5 @@ dependencies:
   - numba
   - numbagg>=0.3
   - hypothesis
+  - pip:
+      - git+https://github.com/dcherian/xarray.git@flox-preserve-dtype


=====================================
ci/no-dask.yml
=====================================
@@ -14,7 +14,6 @@ dependencies:
   - pytest-pretty
   - pytest-xdist
   - syrupy
-  - xarray
   - numpydoc
   - pre-commit
   - numpy_groupies>=0.9.19
@@ -22,3 +21,5 @@ dependencies:
   - toolz
   - numba
   - numbagg>=0.3
+  - pip:
+      - git+https://github.com/dcherian/xarray.git@flox-preserve-dtype


=====================================
docs/source/conf.py
=====================================
@@ -40,6 +40,7 @@ extensions = [
     "sphinx.ext.napoleon",
     "myst_nb",
     "sphinx_codeautolink",
+    "sphinx_remove_toctrees",
 ]
 
 codeautolink_concat_default = True
@@ -54,6 +55,8 @@ source_suffix = [".rst"]
 master_doc = "index"
 language = "en"
 
+remove_from_toctrees = ["generated/*"]
+
 # General information about the project.
 project = "flox"
 current_year = datetime.datetime.now().year


=====================================
docs/source/user-stories/climatology-hourly.ipynb
=====================================
@@ -92,7 +92,6 @@
     "%load_ext watermark\n",
     "\n",
     "\n",
-    "\n",
     "%watermark -iv"
    ]
   },


=====================================
flox/__init__.py
=====================================
@@ -1,9 +1,15 @@
 #!/usr/bin/env python
 # flake8: noqa
 """Top-level module for flox ."""
+
 from . import cache
 from .aggregations import Aggregation, Scan  # noqa
-from .core import groupby_reduce, groupby_scan, rechunk_for_blockwise, rechunk_for_cohorts  # noqa
+from .core import (
+    groupby_reduce,
+    groupby_scan,
+    rechunk_for_blockwise,
+    rechunk_for_cohorts,
+)  # noqa
 
 
 def _get_version():


=====================================
flox/aggregate_flox.py
=====================================
@@ -89,10 +89,14 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
         idxshape = (q.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)
 
     lo_ = np.floor(
-        virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
+        virtual_index,
+        casting="unsafe",
+        out=np.empty(virtual_index.shape, dtype=np.int64),
     )
     hi_ = np.ceil(
-        virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
+        virtual_index,
+        casting="unsafe",
+        out=np.empty(virtual_index.shape, dtype=np.int64),
     )
     kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))
 
@@ -119,7 +123,15 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
 
 
 def _np_grouped_op(
-    group_idx, array, op, axis=-1, size=None, fill_value=None, dtype=None, out=None, **kwargs
+    group_idx,
+    array,
+    op,
+    axis=-1,
+    size=None,
+    fill_value=None,
+    dtype=None,
+    out=None,
+    **kwargs,
 ):
     """
     most of this code is from shoyer's gist


=====================================
flox/aggregations.py
=====================================
@@ -3,10 +3,10 @@ from __future__ import annotations
 import copy
 import logging
 import warnings
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from dataclasses import dataclass
 from functools import cached_property, partial
-from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
+from typing import TYPE_CHECKING, Any, Literal, TypedDict
 
 import numpy as np
 import pandas as pd
@@ -110,7 +110,13 @@ def generic_aggregate(
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
         result = method(
-            group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs
+            group_idx,
+            array,
+            axis=axis,
+            size=size,
+            fill_value=fill_value,
+            dtype=dtype,
+            **kwargs,
         )
     return result
 
@@ -238,9 +244,7 @@ class Aggregation:
         # The following are set by _initialize_aggregation
         self.finalize_kwargs: dict[Any, Any] = {}
         self.min_count: int = 0
-        self.new_dims_func: Callable = (
-            returns_empty_tuple if new_dims_func is None else new_dims_func
-        )
+        self.new_dims_func: Callable = returns_empty_tuple if new_dims_func is None else new_dims_func
         self.preserves_dtype = preserves_dtype
 
     @cached_property
@@ -386,11 +390,19 @@ nanstd = Aggregation(
 
 min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, preserves_dtype=True)
 nanmin = Aggregation(
-    "nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA, preserves_dtype=True
+    "nanmin",
+    chunk="nanmin",
+    combine="nanmin",
+    fill_value=dtypes.NA,
+    preserves_dtype=True,
 )
 max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, preserves_dtype=True)
 nanmax = Aggregation(
-    "nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA, preserves_dtype=True
+    "nanmax",
+    chunk="nanmax",
+    combine="nanmax",
+    fill_value=dtypes.NA,
+    preserves_dtype=True,
 )
 
 
@@ -482,10 +494,18 @@ nanargmin = Aggregation(
 first = Aggregation("first", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
 last = Aggregation("last", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
 nanfirst = Aggregation(
-    "nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA, preserves_dtype=True
+    "nanfirst",
+    chunk="nanfirst",
+    combine="nanfirst",
+    fill_value=dtypes.NA,
+    preserves_dtype=True,
 )
 nanlast = Aggregation(
-    "nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA, preserves_dtype=True
+    "nanlast",
+    chunk="nanlast",
+    combine="nanlast",
+    fill_value=dtypes.NA,
+    preserves_dtype=True,
 )
 
 all_ = Aggregation(
@@ -510,10 +530,18 @@ any_ = Aggregation(
 # Support statistical quantities only blockwise
 # The parallel versions will be approximate and are hard to implement!
 median = Aggregation(
-    name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.floating
+    name="median",
+    fill_value=dtypes.NA,
+    chunk=None,
+    combine=None,
+    final_dtype=np.floating,
 )
 nanmedian = Aggregation(
-    name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.floating
+    name="nanmedian",
+    fill_value=dtypes.NA,
+    chunk=None,
+    combine=None,
+    final_dtype=np.floating,
 )
 
 
@@ -521,12 +549,15 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
     return (Dim(name="quantile", values=q),)
 
 
+# if the input contains integers or floats smaller than float64,
+# the output data-type is float64. Otherwise, the output data-type is the same as that
+# of the input.
 quantile = Aggregation(
     name="quantile",
     fill_value=dtypes.NA,
     chunk=None,
     combine=None,
-    final_dtype=np.floating,
+    final_dtype=np.float64,
     new_dims_func=quantile_new_dims_func,
 )
 nanquantile = Aggregation(
@@ -534,15 +565,11 @@ nanquantile = Aggregation(
     fill_value=dtypes.NA,
     chunk=None,
     combine=None,
-    final_dtype=np.floating,
+    final_dtype=np.float64,
     new_dims_func=quantile_new_dims_func,
 )
-mode = Aggregation(
-    name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
-)
-nanmode = Aggregation(
-    name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
-)
+mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
+nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
 
 
 @dataclass
@@ -658,9 +685,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
             engine="flox",
             fill_value=agg.identity,
         )
-        result = AlignedArrays(
-            array=final_value[..., left.group_idx.size :], group_idx=right.group_idx
-        )
+        result = AlignedArrays(array=final_value[..., left.group_idx.size :], group_idx=right.group_idx)
     else:
         raise ValueError(f"Unknown binary op application mode: {agg.mode!r}")
 
@@ -780,10 +805,8 @@ def _initialize_aggregation(
         np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
     )
     final_dtype = dtypes._normalize_dtype(
-        dtype_ or agg.dtype_init["final"], array_dtype, fill_value
+        dtype_ or agg.dtype_init["final"], array_dtype, agg.preserves_dtype, fill_value
     )
-    if not agg.preserves_dtype:
-        final_dtype = dtypes._maybe_promote_int(final_dtype)
     agg.dtype = {
         "user": dtype,  # Save to automatically choose an engine
         "final": final_dtype,
@@ -794,9 +817,7 @@ def _initialize_aggregation(
                 if int_dtype is None
                 else np.dtype(int_dtype)
             )
-            for int_dtype, int_fv in zip(
-                agg.dtype_init["intermediate"], agg.fill_value["intermediate"]
-            )
+            for int_dtype, int_fv in zip(agg.dtype_init["intermediate"], agg.fill_value["intermediate"])
         ),
     }
 


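As an illustrative aside (not part of the diff above): a minimal sketch of the
new quantile dtype behaviour, assuming flox and numpy are installed. It uses the
public groupby_reduce entry point with finalize_kwargs={"q": ...}, as exercised
in tests/test_core.py further below; the dtypes in the comments are expectations,
not output copied from a run.

    import numpy as np
    import flox

    # int32 input: with final_dtype=np.float64, the grouped quantile comes back
    # as float64, mirroring numpy's rule that integers and floats narrower than
    # float64 are promoted to float64.
    array = np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)
    by = np.array([0, 0, 0, 1, 1, 1])
    result, groups = flox.groupby_reduce(
        array, by, func="quantile", finalize_kwargs={"q": 0.5}
    )
    print(result.dtype)  # expected: float64
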
=====================================
flox/core.py
=====================================
@@ -8,7 +8,7 @@ import operator
 import sys
 import warnings
 from collections import namedtuple
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial, reduce
 from itertools import product
@@ -16,8 +16,8 @@ from numbers import Integral
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     Literal,
+    TypeAlias,
     TypedDict,
     TypeVar,
     Union,
@@ -74,37 +74,37 @@ if TYPE_CHECKING:
     import dask.array.Array as DaskArray
     from dask.typing import Graph
 
-    T_DuckArray = Union[np.ndarray, DaskArray, CubedArray]  # Any ?
-    T_By = T_DuckArray
+    T_DuckArray: TypeAlias = np.ndarray | DaskArray | CubedArray  # Any ?
+    T_By: TypeAlias = T_DuckArray
     T_Bys = tuple[T_By, ...]
     T_ExpectIndex = pd.Index
     T_ExpectIndexTuple = tuple[T_ExpectIndex, ...]
-    T_ExpectIndexOpt = Union[T_ExpectIndex, None]
+    T_ExpectIndexOpt = T_ExpectIndex | None
     T_ExpectIndexOptTuple = tuple[T_ExpectIndexOpt, ...]
-    T_Expect = Union[Sequence, np.ndarray, T_ExpectIndex]
+    T_Expect = Sequence | np.ndarray | T_ExpectIndex
     T_ExpectTuple = tuple[T_Expect, ...]
-    T_ExpectOpt = Union[Sequence, np.ndarray, T_ExpectIndexOpt]
+    T_ExpectOpt = Sequence | np.ndarray | T_ExpectIndexOpt
     T_ExpectOptTuple = tuple[T_ExpectOpt, ...]
-    T_ExpectedGroups = Union[T_Expect, T_ExpectOptTuple]
-    T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
-    T_Func = Union[str, Callable]
-    T_Funcs = Union[T_Func, Sequence[T_Func]]
-    T_Agg = Union[str, Aggregation]
-    T_Scan = Union[str, Scan]
+    T_ExpectedGroups = T_Expect | T_ExpectOptTuple
+    T_ExpectedGroupsOpt = T_ExpectedGroups | None
+    T_Func = str | Callable
+    T_Funcs = T_Func | Sequence[T_Func]
+    T_Agg = str | Aggregation
+    T_Scan = str | Scan
     T_Axis = int
     T_Axes = tuple[T_Axis, ...]
-    T_AxesOpt = Union[T_Axis, T_Axes, None]
-    T_Dtypes = Union[np.typing.DTypeLike, Sequence[np.typing.DTypeLike], None]
-    T_FillValues = Union[np.typing.ArrayLike, Sequence[np.typing.ArrayLike], None]
+    T_AxesOpt = T_Axis | T_Axes | None
+    T_Dtypes = np.typing.DTypeLike | Sequence[np.typing.DTypeLike] | None
+    T_FillValues = np.typing.ArrayLike | Sequence[np.typing.ArrayLike] | None
     T_Engine = Literal["flox", "numpy", "numba", "numbagg"]
     T_EngineOpt = None | T_Engine
     T_Method = Literal["map-reduce", "blockwise", "cohorts"]
     T_MethodOpt = None | Literal["map-reduce", "blockwise", "cohorts"]
-    T_IsBins = Union[bool | Sequence[bool]]
+    T_IsBins = bool | Sequence[bool]
 
 T = TypeVar("T")
 
-IntermediateDict = dict[Union[str, Callable], Any]
+IntermediateDict = dict[str | Callable, Any]
 FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]]
 FactorProps = namedtuple("FactorProps", "offset_group nan_sentinel nanmask")
 
@@ -136,9 +136,7 @@ def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
     # The condition needs to be
     # len(found_groups) < size; if so we mask with fill_value (?)
     default_fv = DEFAULT_FILL_VALUE[func]
-    needs_masking = fill_value is not None and not np.array_equal(
-        fill_value, default_fv, equal_nan=True
-    )
+    needs_masking = fill_value is not None and not np.array_equal(fill_value, default_fv, equal_nan=True)
     groups = np.arange(size)
     if needs_masking:
         mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
@@ -164,9 +162,7 @@ def _is_arg_reduction(func: T_Agg) -> bool:
 
 
 def _is_minmax_reduction(func: T_Agg) -> bool:
-    return not _is_arg_reduction(func) and (
-        isinstance(func, str) and ("max" in func or "min" in func)
-    )
+    return not _is_arg_reduction(func) and (isinstance(func, str) and ("max" in func or "min" in func))
 
 
 def _is_first_last_reduction(func: T_Agg) -> bool:
@@ -254,8 +250,7 @@ def slices_from_chunks(chunks):
     """slightly modified from dask.array.core.slices_from_chunks to be lazy"""
     cumdims = [tlz.accumulate(operator.add, bds, 0) for bds in chunks]
     slices = (
-        (slice(s, s + dim) for s, dim in zip(starts, shapes))
-        for starts, shapes in zip(cumdims, chunks)
+        (slice(s, s + dim) for s, dim in zip(starts, shapes)) for starts, shapes in zip(cumdims, chunks)
     )
     return product(*slices)
 
@@ -396,9 +391,7 @@ def find_group_cohorts(
         chunks_per_label = chunks_per_label[present_labels_mask]
 
     label_chunks = {
-        present_labels[idx].item(): bitmask.indices[
-            slice(bitmask.indptr[idx], bitmask.indptr[idx + 1])
-        ]
+        present_labels[idx].item(): bitmask.indices[slice(bitmask.indptr[idx], bitmask.indptr[idx + 1])]
         for idx in range(bitmask.shape[LABEL_AXIS])
     }
 
@@ -510,9 +503,7 @@ def find_group_cohorts(
     for rowidx in order:
         if present_labels[rowidx] in merged_keys:
             continue
-        cohidx = containment.indices[
-            slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
-        ]
+        cohidx = containment.indices[slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])]
         cohort_ = present_labels[cohidx]
         cohort = [elem.item() for elem in cohort_ if elem not in merged_keys]
         if not cohort:
@@ -604,9 +595,7 @@ def rechunk_for_cohorts(
         else:
             next_break_is_close = False
 
-        if (not ignore_old_chunks and idx in oldbreaks) or (
-            counter >= chunksize and not next_break_is_close
-        ):
+        if (not ignore_old_chunks and idx in oldbreaks) or (counter >= chunksize and not next_break_is_close):
             divisions.append(idx)
             counter = 1
             continue
@@ -922,7 +911,10 @@ def chunk_argreduce(
 
     if reindex and expected_groups is not None:
         results["intermediates"][1] = reindex_(
-            results["intermediates"][1], results["groups"].squeeze(), expected_groups, fill_value=0
+            results["intermediates"][1],
+            results["groups"].squeeze(),
+            expected_groups,
+            fill_value=0,
         )
 
     assert results["intermediates"][0].shape == results["intermediates"][1].shape
@@ -1017,8 +1009,7 @@ def chunk_reduce(
     order = "C"
     if nax > 1:
         needs_broadcast = any(
-            group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
-            for ax in range(-nax, 0)
+            group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1 for ax in range(-nax, 0)
         )
         if needs_broadcast:
             # This is the dim=... case, it's a lot faster to ravel group_idx
@@ -1098,9 +1089,7 @@ def chunk_reduce(
                 result = result[..., :-1]
             # TODO: Figure out how to generalize this
             if reduction in ("quantile", "nanquantile"):
-                new_dims_shape = tuple(
-                    dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar
-                )
+                new_dims_shape = tuple(dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar)
             else:
                 new_dims_shape = tuple()
             result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape)
@@ -1168,7 +1157,10 @@ def _finalize_results(
     # Final reindexing has to be here to be lazy
     if not reindex and expected_groups is not None:
         finalized[agg.name] = reindex_(
-            finalized[agg.name], squeezed["groups"], expected_groups, fill_value=fill_value
+            finalized[agg.name],
+            squeezed["groups"],
+            expected_groups,
+            fill_value=fill_value,
         )
         finalized["groups"] = expected_groups.to_numpy()
     else:
@@ -1194,9 +1186,7 @@ def _aggregate(
 
 
 def _expand_dims(results: IntermediateDict) -> IntermediateDict:
-    results["intermediates"] = tuple(
-        np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"]
-    )
+    results["intermediates"] = tuple(np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"])
     return results
 
 
@@ -1238,7 +1228,8 @@ def _simple_combine(
         # So now reindex before combining by reducing along DUMMY_AXIS
         unique_groups = _find_unique_groups(x_chunk)
         x_chunk = deepmap(
-            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
+            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups),
+            x_chunk,
         )
     else:
         unique_groups = deepfirst(x_chunk)["groups"]
@@ -1280,7 +1271,10 @@ def reindex_intermediates(x: IntermediateDict, agg: Aggregation, unique_groups)
     newx: IntermediateDict = {"groups": np.broadcast_to(unique_groups, new_shape)}
     newx["intermediates"] = tuple(
         reindex_(
-            v, from_=np.atleast_1d(x["groups"].squeeze()), to=pd.Index(unique_groups), fill_value=f
+            v,
+            from_=np.atleast_1d(x["groups"].squeeze()),
+            to=pd.Index(unique_groups),
+            fill_value=f,
         )
         for v, f in zip(x["intermediates"], agg.fill_value["intermediate"])
     )
@@ -1315,7 +1309,8 @@ def _grouped_combine(
         # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
         unique_groups = _find_unique_groups(x_chunk)
         x_chunk = deepmap(
-            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
+            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups),
+            x_chunk,
         )
 
     # these are negative axis indices useful for concatenating the intermediates
@@ -1332,15 +1327,16 @@ def _grouped_combine(
 
         # We need to send the intermediate array values & indexes at the same time
         # intermediates are (value e.g. max, index e.g. argmax, counts)
-        array_idx = tuple(
-            _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) for idx in (0, 1)
-        )
+        array_idx = tuple(_conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) for idx in (0, 1))
 
         # for a single element along axis, we don't want to run the argreduction twice
         # This happens when we are reducing along an axis with a single chunk.
         avoid_reduction = array_idx[0].shape[axis[0]] == 1
         if avoid_reduction:
-            results: IntermediateDict = {"groups": groups, "intermediates": list(array_idx)}
+            results: IntermediateDict = {
+                "groups": groups,
+                "intermediates": list(array_idx),
+            }
         else:
             results = chunk_argreduce(
                 array_idx,
@@ -1387,12 +1383,8 @@ def _grouped_combine(
             array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis)
             if array.shape[-1] == 0:
                 # all empty when combined
-                results["intermediates"].append(
-                    np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=dtype)
-                )
-                results["groups"] = np.empty(
-                    shape=(1,) * (len(neg_axis) - 1) + (0,), dtype=groups.dtype
-                )
+                results["intermediates"].append(np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=dtype))
+                results["groups"] = np.empty(shape=(1,) * (len(neg_axis) - 1) + (0,), dtype=groups.dtype)
             else:
                 _results = chunk_reduce(
                     array,
@@ -1456,9 +1448,7 @@ def _reduce_blockwise(
     if _is_arg_reduction(agg):
         results["intermediates"][0] = np.unravel_index(results["intermediates"][0], array.shape)[-1]
 
-    result = _finalize_results(
-        results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex
-    )
+    result = _finalize_results(results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex)
     return result
 
 
@@ -1570,7 +1560,6 @@ def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]:
 
 
 def _unify_chunks(array, by):
-
     from dask.array import from_array, unify_chunks
 
     inds = tuple(range(array.ndim))
@@ -1653,9 +1642,7 @@ def dask_groupby_agg(
 
     if method == "blockwise":
         #  use the "non dask" code path, but applied blockwise
-        blockwise_method = partial(
-            _reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex
-        )
+        blockwise_method = partial(_reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex)
     else:
         # choose `chunk_reduce` or `chunk_argreduce`
         blockwise_method = partial(
@@ -1752,7 +1739,9 @@ def dask_groupby_agg(
                         reindexed,
                         combine=partial(combine, agg=agg, reindex=do_simple_combine),
                         aggregate=partial(
-                            aggregate, expected_groups=cohort_index, reindex=do_simple_combine
+                            aggregate,
+                            expected_groups=cohort_index,
+                            reindex=do_simple_combine,
                         ),
                     )
                 )
@@ -1882,9 +1871,7 @@ def cubed_groupby_agg(
         # let's always do it anyway
         if not is_chunked_array(by):
             # chunk numpy arrays like the input array
-            chunks = tuple(
-                array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0)
-            )
+            chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))
 
             by = cubed.from_array(by, chunks=chunks, spec=array.spec)
         _, (array, by) = cubed.core.unify_chunks(array, inds, by, inds[-by.ndim :])
@@ -1918,9 +1905,7 @@ def cubed_groupby_agg(
             out = blockwise_method(a, by)
             # Convert dict to one that cubed understands, dropping groups since they are
             # known, and the same for every block.
-            return {
-                f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"])
-            }
+            return {f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"])}
 
         def _groupby_combine(a, axis, dummy_axis, dtype, keepdims):
             # this is similar to _simple_combine, except the dummy axis and concatenation is handled by cubed
@@ -2003,18 +1988,14 @@ def _validate_reindex(
 ) -> bool | None:
     # logger.debug("Entering _validate_reindex: reindex is {}".format(reindex))  # noqa
     def first_or_last():
-        return func in ["first", "last"] or (
-            _is_first_last_reduction(func) and array_dtype.kind != "f"
-        )
+        return func in ["first", "last"] or (_is_first_last_reduction(func) and array_dtype.kind != "f")
 
     all_numpy = not is_dask_array and not any_by_dask
     if reindex is True and not all_numpy:
         if _is_arg_reduction(func):
             raise NotImplementedError
         if method == "cohorts" or (method == "blockwise" and not any_by_dask):
-            raise ValueError(
-                "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
-            )
+            raise ValueError("reindex=True is not a valid choice for method='blockwise' or method='cohorts'.")
         if first_or_last():
             raise ValueError("reindex must be None or False when func is 'first' or 'last.")
 
@@ -2144,9 +2125,7 @@ def _factorize_multiple(
         for by_, expect in zip(by, expected_groups):
             if expect is None:
                 if is_duck_dask_array(by_):
-                    raise ValueError(
-                        "Please provide expected_groups when grouping by a dask array."
-                    )
+                    raise ValueError("Please provide expected_groups when grouping by a dask array.")
 
                 found_group = pd.unique(by_.reshape(-1))
             else:
@@ -2177,7 +2156,7 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) ->
         return (None,) * nby
 
     if nby == 1 and not isinstance(expected_groups, tuple):
-        if isinstance(expected_groups, (pd.Index, np.ndarray)):
+        if isinstance(expected_groups, pd.Index | np.ndarray):
             return (expected_groups,)
         else:
             array = np.asarray(expected_groups)
@@ -2252,9 +2231,7 @@ def _choose_engine(by, agg: Aggregation):
         (isinstance(func, str) and "nan" in func) for func in agg.chunk
     )
     if HAS_NUMBAGG:
-        if agg.name in ["all", "any"] or (
-            not_arg_reduce and has_blockwise_nan_skipping and dtype is None
-        ):
+        if agg.name in ["all", "any"] or (not_arg_reduce and has_blockwise_nan_skipping and dtype is None):
             logger.debug("_choose_engine: Choosing 'numbagg'")
             return "numbagg"
 
@@ -2411,11 +2388,7 @@ def groupby_reduce(
     any_by_dask = any(by_is_dask)
     provided_expected = expected_groups is not None
 
-    if (
-        engine == "numbagg"
-        and _is_arg_reduction(func)
-        and (any_by_dask or is_duck_dask_array(array))
-    ):
+    if engine == "numbagg" and _is_arg_reduction(func) and (any_by_dask or is_duck_dask_array(array)):
         # There is only one test that fails, but I can't figure
         # out why without deep debugging.
         # just disable for now.
@@ -2515,9 +2488,7 @@ def groupby_reduce(
         # TODO: Does this depend on chunking of by?
         # For e.g., we could relax this if there is only one chunk along all
         # by dim != axis?
-        raise NotImplementedError(
-            "Please provide ``expected_groups`` when not reducing along all axes."
-        )
+        raise NotImplementedError("Please provide ``expected_groups`` when not reducing along all axes.")
 
     assert nax <= by_.ndim
     if nax < by_.ndim:
@@ -2580,7 +2551,13 @@ def groupby_reduce(
 
     elif not has_dask:
         results = _reduce_blockwise(
-            array, by_, agg, expected_groups=expected_, reindex=bool(reindex), sort=sort, **kwargs
+            array,
+            by_,
+            agg,
+            expected_groups=expected_,
+            reindex=bool(reindex),
+            sort=sort,
+            **kwargs,
         )
         groups = (results["groups"],)
         result = results[agg.name]
@@ -2627,7 +2604,13 @@ def groupby_reduce(
 
         # TODO: clean this up
         reindex = _validate_reindex(
-            reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype
+            reindex,
+            func,
+            method,
+            expected_,
+            any_by_dask,
+            is_duck_dask_array(array),
+            array.dtype,
         )
 
         if TYPE_CHECKING:
@@ -2798,9 +2781,7 @@ def groupby_scan(
     if expected_groups is not None:
         raise NotImplementedError("Setting `expected_groups` and binning is not supported yet.")
     expected_groups = _validate_expected_groups(nby, expected_groups)
-    expected_groups = _convert_expected_groups_to_index(
-        expected_groups, isbin=(False,) * nby, sort=False
-    )
+    expected_groups = _convert_expected_groups_to_index(expected_groups, isbin=(False,) * nby, sort=False)
 
     # Don't factorize early only when
     # grouping by dask arrays, and not having expected_groups
@@ -2918,7 +2899,12 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
 
     # 1. zip together group indices & array
     zipped = map_blocks(
-        _zip, by, array, dtype=array.dtype, meta=array._meta, name="groupby-scan-preprocess"
+        _zip,
+        by,
+        array,
+        dtype=array.dtype,
+        meta=array._meta,
+        name="groupby-scan-preprocess",
     )
 
     scan_ = partial(chunk_scan, agg=agg)


=====================================
flox/xarray.py
=====================================
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from collections.abc import Hashable, Iterable, Sequence
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any
 
 import numpy as np
 import pandas as pd
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
 
     from .core import T_ExpectedGroupsOpt, T_ExpectIndex, T_ExpectOpt
 
-    Dims = Union[str, Iterable[Hashable], None]
+    Dims = str | Iterable[Hashable] | None
 
 
 def _restore_dim_order(result, obj, by, no_groupby_reorder=False):
@@ -286,9 +286,7 @@ def xarray_reduce(
     try:
         xr.align(ds, *by_da, join="exact", copy=False)
     except ValueError as e:
-        raise ValueError(
-            "Object being grouped must be exactly aligned with every array in `by`."
-        ) from e
+        raise ValueError("Object being grouped must be exactly aligned with every array in `by`.") from e
 
     needs_broadcast = any(
         not set(grouper_dims).issubset(set(variable.dims)) for variable in ds.data_vars.values()
@@ -329,15 +327,11 @@ def xarray_reduce(
     group_names: tuple[Any, ...] = ()
     group_sizes: dict[Any, int] = {}
     for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups_valid, isbins)):
-        group_name = (
-            f"{b_.name}_bins" if isbin_ or isinstance(expect, pd.IntervalIndex) else b_.name
-        )
+        group_name = f"{b_.name}_bins" if isbin_ or isinstance(expect, pd.IntervalIndex) else b_.name
         group_names += (group_name,)
 
         if isbin_ and isinstance(expect, int):
-            raise NotImplementedError(
-                "flox does not support binning into an integer number of bins yet."
-            )
+            raise NotImplementedError("flox does not support binning into an integer number of bins yet.")
 
         expect1: T_ExpectOpt
         if expect is None:
@@ -448,7 +442,8 @@ def xarray_reduce(
         output_core_dims=[output_core_dims],
         dask="allowed",
         dask_gufunc_kwargs=dict(
-            output_sizes=output_sizes, output_dtypes=[dtype] if dtype is not None else None
+            output_sizes=output_sizes,
+            output_dtypes=[dtype] if dtype is not None else None,
         ),
         keep_attrs=keep_attrs,
         kwargs={
@@ -520,11 +515,12 @@ def xarray_reduce(
                 template = obj
 
             if actual[var].ndim > 1 + len(vector_dims):
-                no_groupby_reorder = isinstance(
-                    obj, xr.Dataset
-                )  # do not re-order dataarrays inside datasets
+                no_groupby_reorder = isinstance(obj, xr.Dataset)  # do not re-order dataarrays inside datasets
                 actual[var] = _restore_dim_order(
-                    actual[var], template, by_da[0], no_groupby_reorder=no_groupby_reorder
+                    actual[var].variable,
+                    template,
+                    by_da[0],
+                    no_groupby_reorder=no_groupby_reorder,
                 )
 
     if missing_dim:
@@ -625,13 +621,14 @@ def _rechunk(func, obj, dim, labels, **kwargs):
             if obj[var].chunks is not None:
                 obj[var] = obj[var].copy(
                     data=func(
-                        obj[var].data, axis=obj[var].get_axis_num(dim), labels=labels.data, **kwargs
+                        obj[var].data,
+                        axis=obj[var].get_axis_num(dim),
+                        labels=labels.data,
+                        **kwargs,
                     )
                 )
     else:
         if obj.chunks is not None:
-            obj = obj.copy(
-                data=func(obj.data, axis=obj.get_axis_num(dim), labels=labels.data, **kwargs)
-            )
+            obj = obj.copy(data=func(obj.data, axis=obj.get_axis_num(dim), labels=labels.data, **kwargs))
 
     return obj


=====================================
flox/xrdtypes.py
=====================================
@@ -150,9 +150,14 @@ def is_datetime_like(dtype):
     return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
 
 
-def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
+def _normalize_dtype(
+    dtype: DTypeLike, array_dtype: np.dtype, preserves_dtype: bool, fill_value=None
+) -> np.dtype:
     if dtype is None:
-        dtype = array_dtype
+        if not preserves_dtype:
+            dtype = _maybe_promote_int(array_dtype)
+        else:
+            dtype = array_dtype
     if dtype is np.floating:
         # mean, std, var always result in floating
         # but we preserve the array's dtype if it is floating


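As an illustrative aside (not part of the diff above): with this change, integer
promotion for non-dtype-preserving aggregations happens inside _normalize_dtype
(gated on preserves_dtype) rather than in _initialize_aggregation. A minimal,
hedged sketch of the expected public-facing behaviour, assuming flox and numpy
are installed:

    import numpy as np
    import flox

    array = np.arange(6, dtype=np.int32)
    by = np.array([0, 0, 1, 1, 2, 2])

    # "max" is declared with preserves_dtype=True, so the input dtype is kept.
    kept, _ = flox.groupby_reduce(array, by, func="max")
    # "sum" is not dtype-preserving, so small integer dtypes are widened
    # (expected int64) to reduce the risk of overflow.
    promoted, _ = flox.groupby_reduce(array, by, func="sum")

    print(kept.dtype)      # expected: int32
    print(promoted.dtype)  # expected: int64
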
=====================================
flox/xrutils.py
=====================================
@@ -4,14 +4,14 @@
 import datetime
 import importlib
 from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Any
 
 import numpy as np
 import pandas as pd
 from packaging.version import Version
 
 
-def module_available(module: str, minversion: Optional[str] = None) -> bool:
+def module_available(module: str, minversion: str | None = None) -> bool:
     """Checks whether a module is installed without importing it.
 
     Use this for a lightweight check and lazy imports.
@@ -137,7 +137,7 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool:
         include_0d = getattr(value, "ndim", None) == 0
     return (
         include_0d
-        or isinstance(value, (str, bytes, dict))
+        or isinstance(value, str | bytes | dict)
         or not (
             isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
             or hasattr(value, "__array_function__")
@@ -150,7 +150,7 @@ def notnull(data):
         data = np.asarray(data)
 
     scalar_type = data.dtype.type
-    if issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
+    if issubclass(scalar_type, np.bool_ | np.integer | np.character | np.void):
         # these types cannot represent missing values
         return np.ones_like(data, dtype=bool)
     else:
@@ -163,7 +163,7 @@ def isnull(data):
     if not is_duck_array(data):
         data = np.asarray(data)
     scalar_type = data.dtype.type
-    if issubclass(scalar_type, (np.datetime64, np.timedelta64)):
+    if issubclass(scalar_type, np.datetime64 | np.timedelta64):
         # datetime types use NaT for null
         # note: must check timedelta64 before integers, because currently
         # timedelta64 inherits from np.integer
@@ -171,12 +171,12 @@ def isnull(data):
     elif issubclass(scalar_type, np.inexact):
         # float types use NaN for null
         return np.isnan(data)
-    elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
+    elif issubclass(scalar_type, np.bool_ | np.integer | np.character | np.void):
         # these types cannot represent missing values
         return np.zeros_like(data, dtype=bool)
     else:
         # at this point, array should have dtype=object
-        if isinstance(data, (np.ndarray, dask_array_type)):
+        if isinstance(data, (np.ndarray, dask_array_type)):  # noqa
             return pd.isnull(data)
         else:
             # Not reachable yet, but intended for use with other duck array
@@ -275,9 +275,7 @@ def timedelta_to_numeric(value, datetime_unit="ns", dtype=float):
         try:
             a = pd.to_timedelta(value)
         except ValueError:
-            raise ValueError(
-                f"Could not convert {value!r} to timedelta64 using pandas.to_timedelta"
-            )
+            raise ValueError(f"Could not convert {value!r} to timedelta64 using pandas.to_timedelta")
         return py_timedelta_to_float(a, datetime_unit)
     else:
         raise TypeError(


=====================================
pyproject.toml
=====================================
@@ -3,7 +3,7 @@ name = "flox"
 description = "GroupBy operations for dask.array"
 license = {file = "LICENSE"}
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 keywords = ["xarray", "dask", "groupby"]
 classifiers = [
     "Development Status :: 4 - Beta",
@@ -11,7 +11,6 @@ classifiers = [
     "Natural Language :: English",
     "Operating System :: OS Independent",
     "Programming Language :: Python",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
@@ -60,12 +59,9 @@ fallback_version = "999"
 write_to = "flox/_version.py"
 write_to_template= '__version__ = "{version}"'
 
-[tool.black]
-line-length = 100
-target-version = ["py39"]
-
 [tool.ruff]
-target-version = "py39"
+line-length = 110
+target-version = "py310"
 builtins = ["ellipsis"]
 exclude = [
     ".eggs",
@@ -109,6 +105,10 @@ known-third-party = [
     "xarray"
 ]
 
+[tool.ruff.format]
+# Enable reformatting of code snippets in docstrings.
+docstring-code-format = true
+
 [tool.mypy]
 allow_redefinition = true
 files = "**/*.py"


=====================================
tests/__init__.py
=====================================
@@ -188,9 +188,9 @@ def dask_assert_eq(
     a_original = a
     b_original = b
 
-    if isinstance(a, (list, int, float)):
+    if isinstance(a, list | int | float):
         a = np.array(a)
-    if isinstance(b, (list, int, float)):
+    if isinstance(b, list | int | float):
         b = np.array(b)
 
     a, adt, a_meta, a_computed = _get_dt_meta_computed(


=====================================
tests/strategies.py
=====================================
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 import cftime
 import dask
@@ -26,22 +27,28 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
 
 
 # TODO: stop excluding everything but U
-array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
+array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
 by_dtype_st = supported_dtypes()
 
-NON_NUMPY_FUNCS = ["first", "last", "nanfirst", "nanlast", "count", "any", "all"] + list(
-    SCIPY_STATS_FUNCS
-)
+NON_NUMPY_FUNCS = [
+    "first",
+    "last",
+    "nanfirst",
+    "nanlast",
+    "count",
+    "any",
+    "all",
+] + list(SCIPY_STATS_FUNCS)
 SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]
 
-func_st = st.sampled_from(
-    [f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]
-)
+func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
 numeric_arrays = npst.arrays(
-    elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
+    elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
 )
 all_arrays = npst.arrays(
-    elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes()
+    elements={"allow_subnormal": False},
+    shape=npst.array_shapes(),
+    dtype=supported_dtypes(),
 )
 
 calendars = st.sampled_from(


=====================================
tests/test_asv.py
=====================================
@@ -7,9 +7,7 @@ pytest.importorskip("dask")
 from asv_bench.benchmarks import reduce
 
 
-@pytest.mark.parametrize(
-    "problem", [reduce.ChunkReduce1D, reduce.ChunkReduce2D, reduce.ChunkReduce2DAllAxes]
-)
+@pytest.mark.parametrize("problem", [reduce.ChunkReduce1D, reduce.ChunkReduce2D, reduce.ChunkReduce2DAllAxes])
 def test_reduce(problem) -> None:
     testcase = problem()
     testcase.setup()


=====================================
tests/test_core.py
=====================================
@@ -3,8 +3,9 @@ from __future__ import annotations
 import itertools
 import logging
 import warnings
+from collections.abc import Callable
 from functools import partial, reduce
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING
 from unittest.mock import MagicMock, patch
 
 import numpy as np
@@ -80,7 +81,7 @@ def _get_array_func(func: str) -> Callable:
 
         def npfunc(x, **kwargs):
             x = np.asarray(x)
-            return (~np.isnan(x)).sum()
+            return (~xrutils.isnull(x)).sum(**kwargs)
 
     elif func in ["nanfirst", "nanlast"]:
         npfunc = getattr(xrutils, func)
@@ -126,14 +127,24 @@ def test_alignment_error():
         ("sum", np.ones((12,)), nan_labels, [1, 4, 2]),  # form 1
         ("sum", np.ones((2, 12)), labels, [[3, 4, 5], [3, 4, 5]]),  # form 3
         ("sum", np.ones((2, 12)), nan_labels, [[1, 4, 2], [1, 4, 2]]),  # form 3
-        ("sum", np.ones((2, 12)), np.array([labels, labels]), [6, 8, 10]),  # form 1 after reshape
+        (
+            "sum",
+            np.ones((2, 12)),
+            np.array([labels, labels]),
+            [6, 8, 10],
+        ),  # form 1 after reshape
         ("sum", np.ones((2, 12)), np.array([nan_labels, nan_labels]), [2, 8, 4]),
         # (np.ones((12,)), np.array([labels, labels])),  # form 4
         ("count", np.ones((12,)), labels, [3, 4, 5]),  # form 1
         ("count", np.ones((12,)), nan_labels, [1, 4, 2]),  # form 1
         ("count", np.ones((2, 12)), labels, [[3, 4, 5], [3, 4, 5]]),  # form 3
         ("count", np.ones((2, 12)), nan_labels, [[1, 4, 2], [1, 4, 2]]),  # form 3
-        ("count", np.ones((2, 12)), np.array([labels, labels]), [6, 8, 10]),  # form 1 after reshape
+        (
+            "count",
+            np.ones((2, 12)),
+            np.array([labels, labels]),
+            [6, 8, 10],
+        ),  # form 1 after reshape
         ("count", np.ones((2, 12)), np.array([nan_labels, nan_labels]), [2, 8, 4]),
         ("nanmean", np.ones((12,)), labels, [1, 1, 1]),  # form 1
         ("nanmean", np.ones((12,)), nan_labels, [1, 1, 1]),  # form 1
@@ -215,9 +226,7 @@ def gen_array_by(size, func):
 @pytest.mark.parametrize("add_nan_by", [True, False])
 @pytest.mark.parametrize("func", ALL_FUNCS)
 def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
-    if ("arg" in func and engine in ["flox", "numbagg"]) or (
-        func in BLOCKWISE_FUNCS and chunks != -1
-    ):
+    if ("arg" in func and engine in ["flox", "numbagg"]) or (func in BLOCKWISE_FUNCS and chunks != -1):
         pytest.skip()
 
     array, by = gen_array_by(size, func)
@@ -237,7 +246,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         fill_value = np.nan
         tolerance = {"rtol": 1e-13, "atol": 1e-15}
     elif "quantile" in func:
-        finalize_kwargs = [{"q": DEFAULT_QUANTILE}, {"q": [DEFAULT_QUANTILE / 2, DEFAULT_QUANTILE]}]
+        finalize_kwargs = [
+            {"q": DEFAULT_QUANTILE},
+            {"q": [DEFAULT_QUANTILE / 2, DEFAULT_QUANTILE]},
+        ]
         fill_value = None
         tolerance = None
     else:
@@ -313,7 +325,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         combine_error = RuntimeError("This combine should not have been called.")
         for method, reindex in params:
             call = partial(
-                groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
+                groupby_reduce,
+                array,
+                *by,
+                method=method,
+                reindex=reindex,
+                **flox_kwargs,
             )
             if ("arg" in func or func in ["first", "last"]) and reindex is True:
                 # simple_combine with argreductions not supported right now
@@ -461,7 +478,9 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp
         labels[-2:] = np.nan
 
     kwargs = dict(
-        func=func, expected_groups=[0, 1, 2], fill_value=False if func in ["all", "any"] else 123
+        func=func,
+        expected_groups=[0, 1, 2],
+        fill_value=False if func in ["all", "any"] else 123,
     )
 
     expected, _ = groupby_reduce(array.compute(), labels, engine="numpy", **kwargs)
@@ -674,15 +693,16 @@ def test_first_last_disallowed_dask(func):
     # anything else is not.
     with pytest.raises(ValueError):
         groupby_reduce(
-            dask.array.empty((2, 3, 2), chunks=(-1, -1, 1)), np.ones((2,)), func=func, axis=-1
+            dask.array.empty((2, 3, 2), chunks=(-1, -1, 1)),
+            np.ones((2,)),
+            func=func,
+            axis=-1,
         )
 
 
 @requires_dask
 @pytest.mark.parametrize("func", ALL_FUNCS)
-@pytest.mark.parametrize(
-    "axis", [None, (0, 1, 2), (0, 1), (0, 2), (1, 2), 0, 1, 2, (0,), (1,), (2,)]
-)
+@pytest.mark.parametrize("axis", [None, (0, 1, 2), (0, 1), (0, 2), (1, 2), 0, 1, 2, (0,), (1,), (2,)])
 def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
     if ("arg" in func and engine in ["flox", "numbagg"]) or func in BLOCKWISE_FUNCS:
         pytest.skip()
@@ -797,7 +817,8 @@ def test_groupby_reduce_nans(reindex, chunks, axis, groups, expected_shape, engi
 
 @requires_dask
 @pytest.mark.parametrize(
-    "expected_groups, reindex", [(None, None), (None, False), ([0, 1, 2], True), ([0, 1, 2], False)]
+    "expected_groups, reindex",
+    [(None, None), (None, False), ([0, 1, 2], True), ([0, 1, 2], False)],
 )
 def test_groupby_all_nan_blocks_dask(expected_groups, reindex, engine):
     labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
@@ -848,11 +869,16 @@ def test_bad_npg_behaviour():
     # fmt: off
     array = np.array([[1] * 12, [1] * 12])
     # fmt: on
-    assert_equal(aggregate(labels, array, axis=-1, func="argmax"), np.array([[0, 5, 2], [0, 5, 2]]))
+    assert_equal(
+        aggregate(labels, array, axis=-1, func="argmax"),
+        np.array([[0, 5, 2], [0, 5, 2]]),
+    )
 
     assert (
         aggregate(
-            np.array([0, 1, 2, 0, 1, 2]), np.array([-np.inf, 0, 0, -np.inf, 0, 0]), func="max"
+            np.array([0, 1, 2, 0, 1, 2]),
+            np.array([-np.inf, 0, 0, -np.inf, 0, 0]),
+            func="max",
         )[0]
         == -np.inf
     )
@@ -900,13 +926,17 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
 
     with raise_if_dask_computes():
         actual, *groups = groupby_reduce(
-            array, labels, func="count", fill_value=0, engine=engine, method=method, **kwargs
+            array,
+            labels,
+            func="count",
+            fill_value=0,
+            engine=engine,
+            method=method,
+            **kwargs,
         )
     (groups_array,) = groups
     expected = np.array([3, 1, 0], dtype=np.intp)
-    for left, right in zip(
-        groups_array, pd.IntervalIndex.from_arrays([1, 2, 4], [2, 4, 5]).to_numpy()
-    ):
+    for left, right in zip(groups_array, pd.IntervalIndex.from_arrays([1, 2, 4], [2, 4, 5]).to_numpy()):
         assert left == right
     assert_equal(actual, expected)
 
@@ -940,7 +970,11 @@ def test_rechunk_for_blockwise(inchunks, expected):
         [[[0, 1, 2, 3]], [0, 1, 2, 0, 1, 2, 3], (3, 4)],
         [[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1)],
         [[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1)],
-        [[[0], [1, 2, 3, 4], [5]], np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]), (4, 8, 4, 9, 4)],
+        [
+            [[0], [1, 2, 3, 4], [5]],
+            np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]),
+            (4, 8, 4, 9, 4),
+        ],
     ],
 )
 def test_find_group_cohorts(expected, labels, chunks: tuple[int]) -> None:
@@ -1039,11 +1073,14 @@ def test_fill_value_behaviour(func, chunks, fill_value, engine):
     if chunks:
         array = dask.array.from_array(array, chunks)
     actual, _ = groupby_reduce(
-        array, by, func=func, engine=engine, fill_value=fill_value, expected_groups=[0, 1, 2, 3]
-    )
-    expected = np.array(
-        [fill_value, fill_value, npfunc([1.0, 1.0], axis=0), npfunc([1.0, 1.0], axis=0)]
+        array,
+        by,
+        func=func,
+        engine=engine,
+        fill_value=fill_value,
+        expected_groups=[0, 1, 2, 3],
     )
+    expected = np.array([fill_value, fill_value, npfunc([1.0, 1.0], axis=0), npfunc([1.0, 1.0], axis=0)])
     assert_equal(actual, expected)
 
 
@@ -1140,7 +1177,12 @@ def test_dtype_promotion(func, fill_value, expected, engine):
     by = [0, 1]
 
     actual, _ = groupby_reduce(
-        array, by, func=func, expected_groups=[1, 2], fill_value=fill_value, engine=engine
+        array,
+        by,
+        func=func,
+        expected_groups=[1, 2],
+        fill_value=fill_value,
+        engine=engine,
     )
     assert np.issubdtype(actual.dtype, expected)
 
@@ -1259,9 +1301,7 @@ def test_group_by_datetime_cubed(engine, method):
     assert_equal(expected, actual)
 
     edges = pd.date_range("1999-12-31", "2000-12-31", freq="ME").to_series().to_numpy()
-    actual, _ = groupby_reduce(
-        cubedarray, t.to_numpy(), isbin=True, expected_groups=edges, **kwargs
-    )
+    actual, _ = groupby_reduce(cubedarray, t.to_numpy(), isbin=True, expected_groups=edges, **kwargs)
     expected = data.resample("ME").mean().to_numpy()
     assert_equal(expected, actual)
 
@@ -1316,9 +1356,7 @@ def test_multiple_groupers_bins(chunk) -> None:
 
 
 @pytest.mark.parametrize("expected_groups", [None, (np.arange(5), [2, 3]), (None, [2, 3])])
-@pytest.mark.parametrize(
-    "by1", [np.arange(5)[:, None], np.broadcast_to(np.arange(5)[:, None], (5, 2))]
-)
+@pytest.mark.parametrize("by1", [np.arange(5)[:, None], np.broadcast_to(np.arange(5)[:, None], (5, 2))])
 @pytest.mark.parametrize(
     "by2",
     [
@@ -1341,9 +1379,7 @@ def test_multiple_groupers(chunk, by1, by2, expected_groups) -> None:
 
     # output from `count` is intp
     expected = np.ones((5, 2), dtype=np.intp)
-    actual, *_ = groupby_reduce(
-        array, by1, by2, axis=(0, 1), func="count", expected_groups=expected_groups
-    )
+    actual, *_ = groupby_reduce(array, by1, by2, axis=(0, 1), func="count", expected_groups=expected_groups)
     assert_equal(expected, actual)
 
 
@@ -1435,9 +1471,7 @@ def test_custom_aggregation_blockwise():
             dtype=dtype,
         )
 
-    agg_median = Aggregation(
-        name="median", numpy=grouped_median, fill_value=-1, chunk=None, combine=None
-    )
+    agg_median = Aggregation(name="median", numpy=grouped_median, fill_value=-1, chunk=None, combine=None)
 
     array = np.arange(100, dtype=np.float32).reshape(5, 20)
     by = np.ones((20,))
@@ -1480,7 +1514,12 @@ def test_dtype(func, dtype, engine):
     arr = np.ones((4, 12), dtype=dtype)
     labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
     actual, _ = groupby_reduce(
-        arr, labels, func=func, dtype=np.float64, engine=engine, finalize_kwargs=finalize_kwargs
+        arr,
+        labels,
+        func=func,
+        dtype=np.float64,
+        engine=engine,
+        finalize_kwargs=finalize_kwargs,
     )
     assert actual.dtype == np.dtype("float64")
 
@@ -1587,9 +1626,7 @@ def test_subset_block_2d(flatblocks, expectidx):
         [True, None, "sum", ([1], None), True],
     ],
 )
-def test_validate_reindex_map_reduce(
-    dask_expected, reindex, func, expected_groups, any_by_dask
-) -> None:
+def test_validate_reindex_map_reduce(dask_expected, reindex, func, expected_groups, any_by_dask) -> None:
     actual = _validate_reindex(
         reindex,
         func,
@@ -1720,12 +1757,20 @@ def test_1d_blockwise_sort_optimization():
     assert all("getitem" not in k for k in actual.dask)
 
     actual, _ = groupby_reduce(
-        array, time.dt.dayofyear.values[::-1], sort=True, method="blockwise", func="count"
+        array,
+        time.dt.dayofyear.values[::-1],
+        sort=True,
+        method="blockwise",
+        func="count",
     )
     assert any("getitem" in k for k in actual.dask.layers)
 
     actual, _ = groupby_reduce(
-        array, time.dt.dayofyear.values[::-1], sort=False, method="blockwise", func="count"
+        array,
+        time.dt.dayofyear.values[::-1],
+        sort=False,
+        method="blockwise",
+        func="count",
     )
     assert all("getitem" not in k for k in actual.dask.layers)
 
@@ -1760,9 +1805,7 @@ def test_negative_index_factorize_race_condition():
 @pytest.mark.parametrize("sort", [True, False])
 def test_expected_index_conversion_passthrough_range_index(sort):
     index = pd.RangeIndex(100)
-    actual = _convert_expected_groups_to_index(
-        expected_groups=(index,), isbin=(False,), sort=(sort,)
-    )
+    actual = _convert_expected_groups_to_index(expected_groups=(index,), isbin=(False,), sort=(sort,))
     assert actual[0] is index
 
 
@@ -1935,11 +1978,22 @@ def test_ffill_bfill(chunks, size, add_nan_by, func):
 def test_blockwise_nans():
     array = dask.array.ones((1, 10), chunks=2)
     by = np.array([-1, 0, -1, 1, -1, 2, -1, 3, 4, 4])
-    actual, actual_groups = flox.groupby_reduce(
-        array, by, func="sum", expected_groups=pd.RangeIndex(0, 5)
-    )
+    actual, actual_groups = flox.groupby_reduce(array, by, func="sum", expected_groups=pd.RangeIndex(0, 5))
     expected, expected_groups = flox.groupby_reduce(
         array.compute(), by, func="sum", expected_groups=pd.RangeIndex(0, 5)
     )
     assert_equal(expected_groups, actual_groups)
     assert_equal(expected, actual)
+
+
+@pytest.mark.parametrize("func", ["sum", "prod", "count", "nansum"])
+@pytest.mark.parametrize("engine", ["flox", "numpy"])
+def test_agg_dtypes(func, engine):
+    # regression test for GH388
+    counts = np.array([0, 2, 1, 0, 1])
+    group = np.array([1, 1, 1, 2, 2])
+    actual, _ = groupby_reduce(
+        counts, group, expected_groups=(np.array([1, 2]),), func=func, dtype="uint8", engine=engine
+    )
+    expected = _get_array_func(func)(counts, dtype="uint8")
+    assert actual.dtype == np.uint8 == expected.dtype


=====================================
tests/test_properties.py
=====================================
@@ -1,5 +1,6 @@
 import warnings
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 import pandas as pd
 import pytest
@@ -19,7 +20,7 @@ from flox.core import groupby_reduce, groupby_scan
 from flox.xrutils import notnull
 
 from . import assert_equal
-from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
+from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays
 from .strategies import chunks as chunks_strategy
 
 dask.config.set(scheduler="sync")
@@ -92,18 +93,18 @@ def test_groupby_reduce(data, array, func: str) -> None:
     flox_kwargs: dict[str, Any] = {}
     with np.errstate(invalid="ignore", divide="ignore"):
         actual, *_ = groupby_reduce(
-            array, by, func=func, axis=axis, engine="numpy", **flox_kwargs, finalize_kwargs=kwargs
+            array,
+            by,
+            func=func,
+            axis=axis,
+            engine="numpy",
+            **flox_kwargs,
+            finalize_kwargs=kwargs,
         )
 
         # numpy-groupies always does the calculation in float64
         if (
-            (
-                "var" in func
-                or "std" in func
-                or "sum" in func
-                or "mean" in func
-                or "quantile" in func
-            )
+            ("var" in func or "std" in func or "sum" in func or "mean" in func or "quantile" in func)
             and array.dtype.kind == "f"
             and array.dtype.itemsize != 8
         ):
@@ -195,8 +196,18 @@ def test_ffill_bfill_reverse(data, array: dask.array.Array) -> None:
 def test_first_last(data, array: dask.array.Array, func: str) -> None:
     by = data.draw(by_arrays(shape=(array.shape[-1],)))
 
-    INVERSES = {"first": "last", "last": "first", "nanfirst": "nanlast", "nanlast": "nanfirst"}
-    MATES = {"first": "nanfirst", "last": "nanlast", "nanfirst": "first", "nanlast": "last"}
+    INVERSES = {
+        "first": "last",
+        "last": "first",
+        "nanfirst": "nanlast",
+        "nanlast": "nanfirst",
+    }
+    MATES = {
+        "first": "nanfirst",
+        "last": "nanlast",
+        "nanfirst": "first",
+        "nanlast": "last",
+    }
     inverse = INVERSES[func]
     mate = MATES[func]
 
@@ -233,3 +244,25 @@ def test_first_last_useless(data, func):
     actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
     expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype)
     assert_equal(actual, expected)
+
+
+@given(
+    func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]),
+    engine=st.sampled_from(["numpy", "flox"]),
+    array_dtype=st.none() | array_dtypes,
+    dtype=st.none() | array_dtypes,
+)
+def test_agg_dtype_specified(func, array_dtype, dtype, engine):
+    # regression test for GH388
+    counts = np.array([0, 2, 1, 0, 1], dtype=array_dtype)
+    group = np.array([1, 1, 1, 2, 2])
+    actual, _ = groupby_reduce(
+        counts,
+        group,
+        expected_groups=(np.array([1, 2]),),
+        func=func,
+        dtype=dtype,
+        engine=engine,
+    )
+    expected = getattr(np, func)(counts, keepdims=True, dtype=dtype)
+    assert actual.dtype == expected.dtype


=====================================
tests/test_xarray.py
=====================================
@@ -24,7 +24,7 @@ if has_dask:
 
 # test against legacy xarray implementation
 # avoid some compilation overhead
-xr.set_options(use_flox=False, use_numbagg=False)
+xr.set_options(use_flox=False, use_numbagg=False, use_bottleneck=False)
 tolerance64 = {"rtol": 1e-15, "atol": 1e-18}
 np.random.seed(123)
 
@@ -49,7 +49,9 @@ def test_xarray_reduce(skipna, add_nan, min_count, engine_no_numba, reindex):
     labels2 = np.array([1, 2, 2, 1])
 
     da = xr.DataArray(
-        arr, dims=("x", "y"), coords={"labels2": ("x", labels2), "labels": ("y", labels)}
+        arr,
+        dims=("x", "y"),
+        coords={"labels2": ("x", labels2), "labels": ("y", labels)},
     ).expand_dims(z=4)
 
     expected = da.groupby("labels").sum(skipna=skipna, min_count=min_count)
@@ -98,7 +100,9 @@ def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine_no_
     labels2 = np.array([1, 2, 2, 1])
 
     da = xr.DataArray(
-        arr, dims=("x", "y"), coords={"labels2": ("x", labels2), "labels": ("y", labels)}
+        arr,
+        dims=("x", "y"),
+        coords={"labels2": ("x", labels2), "labels": ("y", labels)},
     ).expand_dims(z=4)
 
     if chunk:
@@ -177,9 +181,7 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine_n
     (None, (None, None), [[1, 2], [1, 2]]),
 )
 def test_validate_expected_groups(expected_groups):
-    da = xr.DataArray(
-        [1.0, 2.0], dims="x", coords={"labels": ("x", [1, 2]), "labels2": ("x", [1, 2])}
-    )
+    da = xr.DataArray([1.0, 2.0], dims="x", coords={"labels": ("x", [1, 2]), "labels2": ("x", [1, 2])})
     with pytest.raises(ValueError):
         xarray_reduce(
             da.chunk({"x": 1}),
@@ -196,12 +198,13 @@ def test_xarray_reduce_single_grouper(engine_no_numba):
     engine = engine_no_numba
     # DataArray
     ds = xr.Dataset(
-        {"Tair": (("time", "x", "y"), dask.array.ones((36, 205, 275), chunks=(9, -1, -1)))},
-        coords={
-            "time": xr.date_range(
-                "1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap"
+        {
+            "Tair": (
+                ("time", "x", "y"),
+                dask.array.ones((36, 205, 275), chunks=(9, -1, -1)),
             )
         },
+        coords={"time": xr.date_range("1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap")},
     )
     actual = xarray_reduce(ds.Tair, ds.time.dt.month, func="mean", engine=engine)
     expected = ds.Tair.groupby("time.month").mean()
@@ -380,12 +383,13 @@ def test_func_is_aggregation():
     from flox.aggregations import mean
 
     ds = xr.Dataset(
-        {"Tair": (("time", "x", "y"), dask.array.ones((36, 205, 275), chunks=(9, -1, -1)))},
-        coords={
-            "time": xr.date_range(
-                "1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap"
+        {
+            "Tair": (
+                ("time", "x", "y"),
+                dask.array.ones((36, 205, 275), chunks=(9, -1, -1)),
             )
         },
+        coords={"time": xr.date_range("1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap")},
     )
     expected = xarray_reduce(ds.Tair, ds.time.dt.month, func="mean")
     actual = xarray_reduce(ds.Tair, ds.time.dt.month, func=mean)
@@ -520,7 +524,10 @@ def test_dtype(add_nan, chunk, dtype, dtype_out, engine_no_numba):
         data,
         dims=("x", "t"),
         coords={
-            "labels": ("t", np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]))
+            "labels": (
+                "t",
+                np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]),
+            )
         },
         name="arr",
     )
@@ -642,7 +649,11 @@ def test_fill_value_xarray_binning():
 def test_groupby_2d_dataset():
     d = {
         "coords": {
-            "bit_index": {"dims": ("bit_index",), "attrs": {"name": "bit_index"}, "data": [0, 1]},
+            "bit_index": {
+                "dims": ("bit_index",),
+                "attrs": {"name": "bit_index"},
+                "data": [0, 1],
+            },
             "index": {"dims": ("index",), "data": [0, 6, 8, 10, 14]},
             "clifford": {"dims": ("index",), "attrs": {}, "data": [1, 1, 4, 10, 4]},
         },
@@ -664,18 +675,14 @@ def test_groupby_2d_dataset():
         expected = ds.groupby("clifford").mean()
     with xr.set_options(use_flox=True):
         actual = ds.groupby("clifford").mean()
-    assert (
-        expected.counts.dims == actual.counts.dims
-    )  # https://github.com/pydata/xarray/issues/8292
+    assert expected.counts.dims == actual.counts.dims  # https://github.com/pydata/xarray/issues/8292
     xr.testing.assert_identical(expected, actual)
 
 
 @pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False))
 def test_resampling_missing_groups(chunk):
     # Regression test for https://github.com/pydata/xarray/issues/8592
-    time_coords = pd.to_datetime(
-        ["2018-06-13T03:40:36", "2018-06-13T05:50:37", "2018-06-15T03:02:34"]
-    )
+    time_coords = pd.to_datetime(["2018-06-13T03:40:36", "2018-06-13T05:50:37", "2018-06-15T03:02:34"])
 
     latitude_coords = [0.0]
     longitude_coords = [0.0]
@@ -684,7 +691,11 @@ def test_resampling_missing_groups(chunk):
 
     da = xr.DataArray(
         data,
-        coords={"time": time_coords, "latitude": latitude_coords, "longitude": longitude_coords},
+        coords={
+            "time": time_coords,
+            "latitude": latitude_coords,
+            "longitude": longitude_coords,
+        },
         dims=["time", "latitude", "longitude"],
     )
     if chunk:
@@ -749,3 +760,26 @@ def test_direct_reduction(func):
     with xr.set_options(use_flox=False):
         expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs)
     xr.testing.assert_identical(expected, actual)
+
+
+@pytest.mark.parametrize("reduction", ["max", "min", "nanmax", "nanmin", "sum", "nansum", "prod", "nanprod"])
+def test_groupby_preserve_dtype(reduction):
+    # all groups are present, we should follow numpy exactly
+    ds = xr.Dataset(
+        {
+            "test": (
+                ["x", "y"],
+                np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16"),
+            )
+        },
+        coords={"idx": ("x", [1, 2, 1])},
+    )
+
+    kwargs = {"engine": "numpy"}
+    if "nan" in reduction:
+        kwargs["skipna"] = True
+    with xr.set_options(use_flox=True):
+        actual = getattr(ds.groupby("idx"), reduction.removeprefix("nan"))(**kwargs).test.dtype
+    expected = getattr(np, reduction)(ds.test.data, axis=0).dtype
+
+    assert actual == expected



View it on GitLab: https://salsa.debian.org/debian-gis-team/flox/-/commit/4a49f04c97f9feb2a15fd16224e7c3f9cbb19d18
