[Git][debian-gis-team/flox][upstream] New upstream version 0.9.11
Antonio Valentino (@antonio.valentino)
gitlab@salsa.debian.org
Fri Sep 13 21:55:24 BST 2024
Antonio Valentino pushed to branch upstream at Debian GIS Project / flox
Commits:
4a49f04c by Antonio Valentino at 2024-09-13T18:14:23+00:00
New upstream version 0.9.11
- - - - -
25 changed files:
- .github/workflows/ci.yaml
- .pre-commit-config.yaml
- asv_bench/benchmarks/__init__.py
- asv_bench/benchmarks/cohorts.py
- asv_bench/benchmarks/combine.py
- asv_bench/benchmarks/reduce.py
- ci/docs.yml
- ci/environment.yml
- ci/no-dask.yml
- docs/source/conf.py
- docs/source/user-stories/climatology-hourly.ipynb
- flox/__init__.py
- flox/aggregate_flox.py
- flox/aggregations.py
- flox/core.py
- flox/xarray.py
- flox/xrdtypes.py
- flox/xrutils.py
- pyproject.toml
- tests/__init__.py
- tests/strategies.py
- tests/test_asv.py
- tests/test_core.py
- tests/test_properties.py
- tests/test_xarray.py
Changes:
=====================================
.github/workflows/ci.yaml
=====================================
@@ -26,7 +26,7 @@ jobs:
matrix:
os: ["ubuntu-latest"]
env: ["environment"]
- python-version: ["3.9", "3.12"]
+ python-version: ["3.10", "3.12"]
include:
- os: "windows-latest"
env: "environment"
@@ -36,7 +36,7 @@ jobs:
python-version: "3.12"
- os: "ubuntu-latest"
env: "minimal-requirements"
- python-version: "3.9"
+ python-version: "3.10"
steps:
- uses: actions/checkout@v4
with:
@@ -70,6 +70,7 @@ jobs:
- name: Run Tests
id: status
run: |
+ python -c "import xarray; xarray.show_versions()"
pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci
- name: Upload code coverage to Codecov
uses: codecov/codecov-action@v4.5.0
@@ -98,7 +99,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
- repository: "pydata/xarray"
+ repository: "dcherian/xarray"
fetch-depth: 0 # Fetch all history for all branches and tags.
- name: Set up conda environment
uses: mamba-org/setup-micromamba@v1
@@ -112,6 +113,7 @@ jobs:
pint>=0.22
- name: Install xarray
run: |
+ git checkout flox-preserve-dtype
python -m pip install --no-deps .
- name: Install upstream flox
run: |
=====================================
.pre-commit-config.yaml
=====================================
@@ -4,10 +4,11 @@ ci:
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
- rev: "v0.5.0"
+ rev: "v0.6.4"
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]
+ - id: ruff-format
- repo: https://github.com/pre-commit/mirrors-prettier
rev: "v4.0.0-alpha.8"
@@ -22,11 +23,6 @@ repos:
- id: end-of-file-fixer
- id: check-docstring-first
- - repo: https://github.com/psf/black-pre-commit-mirror
- rev: 24.4.2
- hooks:
- - id: black
-
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.17
hooks:
@@ -35,13 +31,6 @@ repos:
- mdformat-black
- mdformat-myst
- - repo: https://github.com/nbQA-dev/nbQA
- rev: 1.8.5
- hooks:
- - id: nbqa-black
- - id: nbqa-ruff
- args: [--fix]
-
- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
hooks:
@@ -56,7 +45,7 @@ repos:
- tomli
- repo: https://github.com/abravalheri/validate-pyproject
- rev: v0.18
+ rev: v0.19
hooks:
- id: validate-pyproject
=====================================
asv_bench/benchmarks/__init__.py
=====================================
@@ -21,7 +21,6 @@ def _skip_slow():
>>> from . import _skip_slow
>>> def time_something_slow():
... pass
- ...
>>> time_something.setup = _skip_slow
"""
if os.environ.get("ASV_SKIP_SLOW", "0") == "1":
=====================================
asv_bench/benchmarks/cohorts.py
=====================================
@@ -67,7 +67,12 @@ class Cohorts:
track_num_tasks.unit = "tasks" # type: ignore[attr-defined] # Lazy
track_num_tasks_optimized.unit = "tasks" # type: ignore[attr-defined] # Lazy
track_num_layers.unit = "layers" # type: ignore[attr-defined] # Lazy
- for f in [track_num_tasks, track_num_tasks_optimized, track_num_layers, track_num_cohorts]:
+ for f in [
+ track_num_tasks,
+ track_num_tasks_optimized,
+ track_num_layers,
+ track_num_cohorts,
+ ]:
f.repeat = 1 # type: ignore[attr-defined] # Lazy
f.rounds = 1 # type: ignore[attr-defined] # Lazy
f.number = 1 # type: ignore[attr-defined] # Lazy
@@ -82,9 +87,7 @@ class NWMMidwest(Cohorts):
y = np.repeat(np.arange(30), 60)
by = x[np.newaxis, :] * y[:, np.newaxis]
- self.by = flox.core._factorize_multiple((by,), expected_groups=(None,), any_by_dask=False)[
- 0
- ][0]
+ self.by = flox.core._factorize_multiple((by,), expected_groups=(None,), any_by_dask=False)[0][0]
self.array = dask.array.ones(self.by.shape, chunks=(350, 350))
self.axis = (-2, -1)
@@ -101,7 +104,12 @@ class ERA5Dataset:
def rechunk(self):
self.array = flox.core.rechunk_for_cohorts(
- self.array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
+ self.array,
+ -1,
+ self.by,
+ force_new_chunk_at=[1],
+ chunksize=48,
+ ignore_old_chunks=True,
)
@@ -151,7 +159,12 @@ class PerfectMonthly(Cohorts):
def rechunk(self):
self.array = flox.core.rechunk_for_cohorts(
- self.array, -1, self.by, force_new_chunk_at=[1], chunksize=4, ignore_old_chunks=True
+ self.array,
+ -1,
+ self.by,
+ force_new_chunk_at=[1],
+ chunksize=4,
+ ignore_old_chunks=True,
)
=====================================
asv_bench/benchmarks/combine.py
=====================================
@@ -65,12 +65,8 @@ class Combine1d(Combine):
* 2
]
- self.x_chunk_reindexed = [
- construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4
- ]
+ self.x_chunk_reindexed = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4]
self.kwargs = {
- "agg": flox.aggregations._initialize_aggregation(
- "sum", "float64", np.float64, 0, 0, {}
- ),
+ "agg": flox.aggregations._initialize_aggregation("sum", "float64", np.float64, 0, 0, {}),
"axis": (3,),
}
=====================================
asv_bench/benchmarks/reduce.py
=====================================
@@ -7,7 +7,11 @@ import flox.aggregations
N = 3000
funcs = ["sum", "nansum", "mean", "nanmean", "max", "nanmax", "count"]
-engines = [None, "flox", "numpy"] # numbagg is disabled for now since it takes ages in CI
+engines = [
+ None,
+ "flox",
+ "numpy",
+] # numbagg is disabled for now since it takes ages in CI
expected_groups = {
"None": None,
"bins": pd.IntervalIndex.from_breaks([1, 2, 4]),
@@ -17,9 +21,7 @@ expected_names = tuple(expected_groups)
NUMBAGG_FUNCS = ["nansum", "nanmean", "nanmax", "count", "all"]
numbagg_skip = []
for name in expected_names:
- numbagg_skip.extend(
- list((func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS)
- )
+ numbagg_skip.extend(list((func, name, "numbagg") for func in funcs if func not in NUMBAGG_FUNCS))
def setup_jit():
=====================================
ci/docs.yml
=====================================
@@ -16,6 +16,7 @@ dependencies:
- myst-parser
- myst-nb
- sphinx
+ - sphinx-remove-toctrees
- furo>=2024.08
- ipykernel
- jupyter
=====================================
ci/environment.yml
=====================================
@@ -19,7 +19,6 @@ dependencies:
- pytest-pretty
- pytest-xdist
- syrupy
- - xarray
- pre-commit
- numpy_groupies>=0.9.19
- pooch
@@ -27,3 +26,5 @@ dependencies:
- numba
- numbagg>=0.3
- hypothesis
+ - pip:
+ - git+https://github.com/dcherian/xarray.git@flox-preserve-dtype
=====================================
ci/no-dask.yml
=====================================
@@ -14,7 +14,6 @@ dependencies:
- pytest-pretty
- pytest-xdist
- syrupy
- - xarray
- numpydoc
- pre-commit
- numpy_groupies>=0.9.19
@@ -22,3 +21,5 @@ dependencies:
- toolz
- numba
- numbagg>=0.3
+ - pip:
+ - git+https://github.com/dcherian/xarray.git@flox-preserve-dtype
=====================================
docs/source/conf.py
=====================================
@@ -40,6 +40,7 @@ extensions = [
"sphinx.ext.napoleon",
"myst_nb",
"sphinx_codeautolink",
+ "sphinx_remove_toctrees",
]
codeautolink_concat_default = True
@@ -54,6 +55,8 @@ source_suffix = [".rst"]
master_doc = "index"
language = "en"
+remove_from_toctrees = ["generated/*"]
+
# General information about the project.
project = "flox"
current_year = datetime.datetime.now().year
=====================================
docs/source/user-stories/climatology-hourly.ipynb
=====================================
@@ -92,7 +92,6 @@
"%load_ext watermark\n",
"\n",
"\n",
- "\n",
"%watermark -iv"
]
},
=====================================
flox/__init__.py
=====================================
@@ -1,9 +1,15 @@
#!/usr/bin/env python
# flake8: noqa
"""Top-level module for flox ."""
+
from . import cache
from .aggregations import Aggregation, Scan # noqa
-from .core import groupby_reduce, groupby_scan, rechunk_for_blockwise, rechunk_for_cohorts # noqa
+from .core import (
+ groupby_reduce,
+ groupby_scan,
+ rechunk_for_blockwise,
+ rechunk_for_cohorts,
+) # noqa
def _get_version():
=====================================
flox/aggregate_flox.py
=====================================
@@ -89,10 +89,14 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
idxshape = (q.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)
lo_ = np.floor(
- virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
+ virtual_index,
+ casting="unsafe",
+ out=np.empty(virtual_index.shape, dtype=np.int64),
)
hi_ = np.ceil(
- virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
+ virtual_index,
+ casting="unsafe",
+ out=np.empty(virtual_index.shape, dtype=np.int64),
)
kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))
@@ -119,7 +123,15 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
def _np_grouped_op(
- group_idx, array, op, axis=-1, size=None, fill_value=None, dtype=None, out=None, **kwargs
+ group_idx,
+ array,
+ op,
+ axis=-1,
+ size=None,
+ fill_value=None,
+ dtype=None,
+ out=None,
+ **kwargs,
):
"""
most of this code is from shoyer's gist
=====================================
flox/aggregations.py
=====================================
@@ -3,10 +3,10 @@ from __future__ import annotations
import copy
import logging
import warnings
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
-from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
+from typing import TYPE_CHECKING, Any, Literal, TypedDict
import numpy as np
import pandas as pd
@@ -110,7 +110,13 @@ def generic_aggregate(
with warnings.catch_warnings():
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
result = method(
- group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs
+ group_idx,
+ array,
+ axis=axis,
+ size=size,
+ fill_value=fill_value,
+ dtype=dtype,
+ **kwargs,
)
return result
@@ -238,9 +244,7 @@ class Aggregation:
# The following are set by _initialize_aggregation
self.finalize_kwargs: dict[Any, Any] = {}
self.min_count: int = 0
- self.new_dims_func: Callable = (
- returns_empty_tuple if new_dims_func is None else new_dims_func
- )
+ self.new_dims_func: Callable = returns_empty_tuple if new_dims_func is None else new_dims_func
self.preserves_dtype = preserves_dtype
@cached_property
@@ -386,11 +390,19 @@ nanstd = Aggregation(
min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, preserves_dtype=True)
nanmin = Aggregation(
- "nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA, preserves_dtype=True
+ "nanmin",
+ chunk="nanmin",
+ combine="nanmin",
+ fill_value=dtypes.NA,
+ preserves_dtype=True,
)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, preserves_dtype=True)
nanmax = Aggregation(
- "nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA, preserves_dtype=True
+ "nanmax",
+ chunk="nanmax",
+ combine="nanmax",
+ fill_value=dtypes.NA,
+ preserves_dtype=True,
)
@@ -482,10 +494,18 @@ nanargmin = Aggregation(
first = Aggregation("first", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
last = Aggregation("last", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
nanfirst = Aggregation(
- "nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA, preserves_dtype=True
+ "nanfirst",
+ chunk="nanfirst",
+ combine="nanfirst",
+ fill_value=dtypes.NA,
+ preserves_dtype=True,
)
nanlast = Aggregation(
- "nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA, preserves_dtype=True
+ "nanlast",
+ chunk="nanlast",
+ combine="nanlast",
+ fill_value=dtypes.NA,
+ preserves_dtype=True,
)
all_ = Aggregation(
@@ -510,10 +530,18 @@ any_ = Aggregation(
# Support statistical quantities only blockwise
# The parallel versions will be approximate and are hard to implement!
median = Aggregation(
- name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.floating
+ name="median",
+ fill_value=dtypes.NA,
+ chunk=None,
+ combine=None,
+ final_dtype=np.floating,
)
nanmedian = Aggregation(
- name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.floating
+ name="nanmedian",
+ fill_value=dtypes.NA,
+ chunk=None,
+ combine=None,
+ final_dtype=np.floating,
)
@@ -521,12 +549,15 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
return (Dim(name="quantile", values=q),)
+# if the input contains integers or floats smaller than float64,
+# the output data-type is float64. Otherwise, the output data-type is the same as that
+# of the input.
quantile = Aggregation(
name="quantile",
fill_value=dtypes.NA,
chunk=None,
combine=None,
- final_dtype=np.floating,
+ final_dtype=np.float64,
new_dims_func=quantile_new_dims_func,
)
nanquantile = Aggregation(
@@ -534,15 +565,11 @@ nanquantile = Aggregation(
fill_value=dtypes.NA,
chunk=None,
combine=None,
- final_dtype=np.floating,
+ final_dtype=np.float64,
new_dims_func=quantile_new_dims_func,
)
-mode = Aggregation(
- name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
-)
-nanmode = Aggregation(
- name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
-)
+mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
+nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
@dataclass
@@ -658,9 +685,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
engine="flox",
fill_value=agg.identity,
)
- result = AlignedArrays(
- array=final_value[..., left.group_idx.size :], group_idx=right.group_idx
- )
+ result = AlignedArrays(array=final_value[..., left.group_idx.size :], group_idx=right.group_idx)
else:
raise ValueError(f"Unknown binary op application mode: {agg.mode!r}")
@@ -780,10 +805,8 @@ def _initialize_aggregation(
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
)
final_dtype = dtypes._normalize_dtype(
- dtype_ or agg.dtype_init["final"], array_dtype, fill_value
+ dtype_ or agg.dtype_init["final"], array_dtype, agg.preserves_dtype, fill_value
)
- if not agg.preserves_dtype:
- final_dtype = dtypes._maybe_promote_int(final_dtype)
agg.dtype = {
"user": dtype, # Save to automatically choose an engine
"final": final_dtype,
@@ -794,9 +817,7 @@ def _initialize_aggregation(
if int_dtype is None
else np.dtype(int_dtype)
)
- for int_dtype, int_fv in zip(
- agg.dtype_init["intermediate"], agg.fill_value["intermediate"]
- )
+ for int_dtype, int_fv in zip(agg.dtype_init["intermediate"], agg.fill_value["intermediate"])
),
}
=====================================
flox/core.py
=====================================
@@ -8,7 +8,7 @@ import operator
import sys
import warnings
from collections import namedtuple
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from functools import partial, reduce
from itertools import product
@@ -16,8 +16,8 @@ from numbers import Integral
from typing import (
TYPE_CHECKING,
Any,
- Callable,
Literal,
+ TypeAlias,
TypedDict,
TypeVar,
Union,
@@ -74,37 +74,37 @@ if TYPE_CHECKING:
import dask.array.Array as DaskArray
from dask.typing import Graph
- T_DuckArray = Union[np.ndarray, DaskArray, CubedArray] # Any ?
- T_By = T_DuckArray
+ T_DuckArray: TypeAlias = np.ndarray | DaskArray | CubedArray # Any ?
+ T_By: TypeAlias = T_DuckArray
T_Bys = tuple[T_By, ...]
T_ExpectIndex = pd.Index
T_ExpectIndexTuple = tuple[T_ExpectIndex, ...]
- T_ExpectIndexOpt = Union[T_ExpectIndex, None]
+ T_ExpectIndexOpt = T_ExpectIndex | None
T_ExpectIndexOptTuple = tuple[T_ExpectIndexOpt, ...]
- T_Expect = Union[Sequence, np.ndarray, T_ExpectIndex]
+ T_Expect = Sequence | np.ndarray | T_ExpectIndex
T_ExpectTuple = tuple[T_Expect, ...]
- T_ExpectOpt = Union[Sequence, np.ndarray, T_ExpectIndexOpt]
+ T_ExpectOpt = Sequence | np.ndarray | T_ExpectIndexOpt
T_ExpectOptTuple = tuple[T_ExpectOpt, ...]
- T_ExpectedGroups = Union[T_Expect, T_ExpectOptTuple]
- T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
- T_Func = Union[str, Callable]
- T_Funcs = Union[T_Func, Sequence[T_Func]]
- T_Agg = Union[str, Aggregation]
- T_Scan = Union[str, Scan]
+ T_ExpectedGroups = T_Expect | T_ExpectOptTuple
+ T_ExpectedGroupsOpt = T_ExpectedGroups | None
+ T_Func = str | Callable
+ T_Funcs = T_Func | Sequence[T_Func]
+ T_Agg = str | Aggregation
+ T_Scan = str | Scan
T_Axis = int
T_Axes = tuple[T_Axis, ...]
- T_AxesOpt = Union[T_Axis, T_Axes, None]
- T_Dtypes = Union[np.typing.DTypeLike, Sequence[np.typing.DTypeLike], None]
- T_FillValues = Union[np.typing.ArrayLike, Sequence[np.typing.ArrayLike], None]
+ T_AxesOpt = T_Axis | T_Axes | None
+ T_Dtypes = np.typing.DTypeLike | Sequence[np.typing.DTypeLike] | None
+ T_FillValues = np.typing.ArrayLike | Sequence[np.typing.ArrayLike] | None
T_Engine = Literal["flox", "numpy", "numba", "numbagg"]
T_EngineOpt = None | T_Engine
T_Method = Literal["map-reduce", "blockwise", "cohorts"]
T_MethodOpt = None | Literal["map-reduce", "blockwise", "cohorts"]
- T_IsBins = Union[bool | Sequence[bool]]
+ T_IsBins = bool | Sequence[bool]
T = TypeVar("T")
-IntermediateDict = dict[Union[str, Callable], Any]
+IntermediateDict = dict[str | Callable, Any]
FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]]
FactorProps = namedtuple("FactorProps", "offset_group nan_sentinel nanmask")
@@ -136,9 +136,7 @@ def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
# The condition needs to be
# len(found_groups) < size; if so we mask with fill_value (?)
default_fv = DEFAULT_FILL_VALUE[func]
- needs_masking = fill_value is not None and not np.array_equal(
- fill_value, default_fv, equal_nan=True
- )
+ needs_masking = fill_value is not None and not np.array_equal(fill_value, default_fv, equal_nan=True)
groups = np.arange(size)
if needs_masking:
mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
@@ -164,9 +162,7 @@ def _is_arg_reduction(func: T_Agg) -> bool:
def _is_minmax_reduction(func: T_Agg) -> bool:
- return not _is_arg_reduction(func) and (
- isinstance(func, str) and ("max" in func or "min" in func)
- )
+ return not _is_arg_reduction(func) and (isinstance(func, str) and ("max" in func or "min" in func))
def _is_first_last_reduction(func: T_Agg) -> bool:
@@ -254,8 +250,7 @@ def slices_from_chunks(chunks):
"""slightly modified from dask.array.core.slices_from_chunks to be lazy"""
cumdims = [tlz.accumulate(operator.add, bds, 0) for bds in chunks]
slices = (
- (slice(s, s + dim) for s, dim in zip(starts, shapes))
- for starts, shapes in zip(cumdims, chunks)
+ (slice(s, s + dim) for s, dim in zip(starts, shapes)) for starts, shapes in zip(cumdims, chunks)
)
return product(*slices)
@@ -396,9 +391,7 @@ def find_group_cohorts(
chunks_per_label = chunks_per_label[present_labels_mask]
label_chunks = {
- present_labels[idx].item(): bitmask.indices[
- slice(bitmask.indptr[idx], bitmask.indptr[idx + 1])
- ]
+ present_labels[idx].item(): bitmask.indices[slice(bitmask.indptr[idx], bitmask.indptr[idx + 1])]
for idx in range(bitmask.shape[LABEL_AXIS])
}
@@ -510,9 +503,7 @@ def find_group_cohorts(
for rowidx in order:
if present_labels[rowidx] in merged_keys:
continue
- cohidx = containment.indices[
- slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
- ]
+ cohidx = containment.indices[slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])]
cohort_ = present_labels[cohidx]
cohort = [elem.item() for elem in cohort_ if elem not in merged_keys]
if not cohort:
@@ -604,9 +595,7 @@ def rechunk_for_cohorts(
else:
next_break_is_close = False
- if (not ignore_old_chunks and idx in oldbreaks) or (
- counter >= chunksize and not next_break_is_close
- ):
+ if (not ignore_old_chunks and idx in oldbreaks) or (counter >= chunksize and not next_break_is_close):
divisions.append(idx)
counter = 1
continue
@@ -922,7 +911,10 @@ def chunk_argreduce(
if reindex and expected_groups is not None:
results["intermediates"][1] = reindex_(
- results["intermediates"][1], results["groups"].squeeze(), expected_groups, fill_value=0
+ results["intermediates"][1],
+ results["groups"].squeeze(),
+ expected_groups,
+ fill_value=0,
)
assert results["intermediates"][0].shape == results["intermediates"][1].shape
@@ -1017,8 +1009,7 @@ def chunk_reduce(
order = "C"
if nax > 1:
needs_broadcast = any(
- group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
- for ax in range(-nax, 0)
+ group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1 for ax in range(-nax, 0)
)
if needs_broadcast:
# This is the dim=... case, it's a lot faster to ravel group_idx
@@ -1098,9 +1089,7 @@ def chunk_reduce(
result = result[..., :-1]
# TODO: Figure out how to generalize this
if reduction in ("quantile", "nanquantile"):
- new_dims_shape = tuple(
- dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar
- )
+ new_dims_shape = tuple(dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar)
else:
new_dims_shape = tuple()
result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape)
@@ -1168,7 +1157,10 @@ def _finalize_results(
# Final reindexing has to be here to be lazy
if not reindex and expected_groups is not None:
finalized[agg.name] = reindex_(
- finalized[agg.name], squeezed["groups"], expected_groups, fill_value=fill_value
+ finalized[agg.name],
+ squeezed["groups"],
+ expected_groups,
+ fill_value=fill_value,
)
finalized["groups"] = expected_groups.to_numpy()
else:
@@ -1194,9 +1186,7 @@ def _aggregate(
def _expand_dims(results: IntermediateDict) -> IntermediateDict:
- results["intermediates"] = tuple(
- np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"]
- )
+ results["intermediates"] = tuple(np.expand_dims(array, DUMMY_AXIS) for array in results["intermediates"])
return results
@@ -1238,7 +1228,8 @@ def _simple_combine(
# So now reindex before combining by reducing along DUMMY_AXIS
unique_groups = _find_unique_groups(x_chunk)
x_chunk = deepmap(
- partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
+ partial(reindex_intermediates, agg=agg, unique_groups=unique_groups),
+ x_chunk,
)
else:
unique_groups = deepfirst(x_chunk)["groups"]
@@ -1280,7 +1271,10 @@ def reindex_intermediates(x: IntermediateDict, agg: Aggregation, unique_groups)
newx: IntermediateDict = {"groups": np.broadcast_to(unique_groups, new_shape)}
newx["intermediates"] = tuple(
reindex_(
- v, from_=np.atleast_1d(x["groups"].squeeze()), to=pd.Index(unique_groups), fill_value=f
+ v,
+ from_=np.atleast_1d(x["groups"].squeeze()),
+ to=pd.Index(unique_groups),
+ fill_value=f,
)
for v, f in zip(x["intermediates"], agg.fill_value["intermediate"])
)
@@ -1315,7 +1309,8 @@ def _grouped_combine(
# I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
unique_groups = _find_unique_groups(x_chunk)
x_chunk = deepmap(
- partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
+ partial(reindex_intermediates, agg=agg, unique_groups=unique_groups),
+ x_chunk,
)
# these are negative axis indices useful for concatenating the intermediates
@@ -1332,15 +1327,16 @@ def _grouped_combine(
# We need to send the intermediate array values & indexes at the same time
# intermediates are (value e.g. max, index e.g. argmax, counts)
- array_idx = tuple(
- _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) for idx in (0, 1)
- )
+ array_idx = tuple(_conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) for idx in (0, 1))
# for a single element along axis, we don't want to run the argreduction twice
# This happens when we are reducing along an axis with a single chunk.
avoid_reduction = array_idx[0].shape[axis[0]] == 1
if avoid_reduction:
- results: IntermediateDict = {"groups": groups, "intermediates": list(array_idx)}
+ results: IntermediateDict = {
+ "groups": groups,
+ "intermediates": list(array_idx),
+ }
else:
results = chunk_argreduce(
array_idx,
@@ -1387,12 +1383,8 @@ def _grouped_combine(
array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis)
if array.shape[-1] == 0:
# all empty when combined
- results["intermediates"].append(
- np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=dtype)
- )
- results["groups"] = np.empty(
- shape=(1,) * (len(neg_axis) - 1) + (0,), dtype=groups.dtype
- )
+ results["intermediates"].append(np.empty(shape=(1,) * (len(axis) - 1) + (0,), dtype=dtype))
+ results["groups"] = np.empty(shape=(1,) * (len(neg_axis) - 1) + (0,), dtype=groups.dtype)
else:
_results = chunk_reduce(
array,
@@ -1456,9 +1448,7 @@ def _reduce_blockwise(
if _is_arg_reduction(agg):
results["intermediates"][0] = np.unravel_index(results["intermediates"][0], array.shape)[-1]
- result = _finalize_results(
- results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex
- )
+ result = _finalize_results(results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex)
return result
@@ -1570,7 +1560,6 @@ def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]:
def _unify_chunks(array, by):
-
from dask.array import from_array, unify_chunks
inds = tuple(range(array.ndim))
@@ -1653,9 +1642,7 @@ def dask_groupby_agg(
if method == "blockwise":
# use the "non dask" code path, but applied blockwise
- blockwise_method = partial(
- _reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex
- )
+ blockwise_method = partial(_reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex)
else:
# choose `chunk_reduce` or `chunk_argreduce`
blockwise_method = partial(
@@ -1752,7 +1739,9 @@ def dask_groupby_agg(
reindexed,
combine=partial(combine, agg=agg, reindex=do_simple_combine),
aggregate=partial(
- aggregate, expected_groups=cohort_index, reindex=do_simple_combine
+ aggregate,
+ expected_groups=cohort_index,
+ reindex=do_simple_combine,
),
)
)
@@ -1882,9 +1871,7 @@ def cubed_groupby_agg(
# let's always do it anyway
if not is_chunked_array(by):
# chunk numpy arrays like the input array
- chunks = tuple(
- array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0)
- )
+ chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))
by = cubed.from_array(by, chunks=chunks, spec=array.spec)
_, (array, by) = cubed.core.unify_chunks(array, inds, by, inds[-by.ndim :])
@@ -1918,9 +1905,7 @@ def cubed_groupby_agg(
out = blockwise_method(a, by)
# Convert dict to one that cubed understands, dropping groups since they are
# known, and the same for every block.
- return {
- f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"])
- }
+ return {f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"])}
def _groupby_combine(a, axis, dummy_axis, dtype, keepdims):
# this is similar to _simple_combine, except the dummy axis and concatenation is handled by cubed
@@ -2003,18 +1988,14 @@ def _validate_reindex(
) -> bool | None:
# logger.debug("Entering _validate_reindex: reindex is {}".format(reindex)) # noqa
def first_or_last():
- return func in ["first", "last"] or (
- _is_first_last_reduction(func) and array_dtype.kind != "f"
- )
+ return func in ["first", "last"] or (_is_first_last_reduction(func) and array_dtype.kind != "f")
all_numpy = not is_dask_array and not any_by_dask
if reindex is True and not all_numpy:
if _is_arg_reduction(func):
raise NotImplementedError
if method == "cohorts" or (method == "blockwise" and not any_by_dask):
- raise ValueError(
- "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
- )
+ raise ValueError("reindex=True is not a valid choice for method='blockwise' or method='cohorts'.")
if first_or_last():
raise ValueError("reindex must be None or False when func is 'first' or 'last.")
@@ -2144,9 +2125,7 @@ def _factorize_multiple(
for by_, expect in zip(by, expected_groups):
if expect is None:
if is_duck_dask_array(by_):
- raise ValueError(
- "Please provide expected_groups when grouping by a dask array."
- )
+ raise ValueError("Please provide expected_groups when grouping by a dask array.")
found_group = pd.unique(by_.reshape(-1))
else:
@@ -2177,7 +2156,7 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) ->
return (None,) * nby
if nby == 1 and not isinstance(expected_groups, tuple):
- if isinstance(expected_groups, (pd.Index, np.ndarray)):
+ if isinstance(expected_groups, pd.Index | np.ndarray):
return (expected_groups,)
else:
array = np.asarray(expected_groups)
@@ -2252,9 +2231,7 @@ def _choose_engine(by, agg: Aggregation):
(isinstance(func, str) and "nan" in func) for func in agg.chunk
)
if HAS_NUMBAGG:
- if agg.name in ["all", "any"] or (
- not_arg_reduce and has_blockwise_nan_skipping and dtype is None
- ):
+ if agg.name in ["all", "any"] or (not_arg_reduce and has_blockwise_nan_skipping and dtype is None):
logger.debug("_choose_engine: Choosing 'numbagg'")
return "numbagg"
@@ -2411,11 +2388,7 @@ def groupby_reduce(
any_by_dask = any(by_is_dask)
provided_expected = expected_groups is not None
- if (
- engine == "numbagg"
- and _is_arg_reduction(func)
- and (any_by_dask or is_duck_dask_array(array))
- ):
+ if engine == "numbagg" and _is_arg_reduction(func) and (any_by_dask or is_duck_dask_array(array)):
# There is only one test that fails, but I can't figure
# out why without deep debugging.
# just disable for now.
@@ -2515,9 +2488,7 @@ def groupby_reduce(
# TODO: Does this depend on chunking of by?
# For e.g., we could relax this if there is only one chunk along all
# by dim != axis?
- raise NotImplementedError(
- "Please provide ``expected_groups`` when not reducing along all axes."
- )
+ raise NotImplementedError("Please provide ``expected_groups`` when not reducing along all axes.")
assert nax <= by_.ndim
if nax < by_.ndim:
@@ -2580,7 +2551,13 @@ def groupby_reduce(
elif not has_dask:
results = _reduce_blockwise(
- array, by_, agg, expected_groups=expected_, reindex=bool(reindex), sort=sort, **kwargs
+ array,
+ by_,
+ agg,
+ expected_groups=expected_,
+ reindex=bool(reindex),
+ sort=sort,
+ **kwargs,
)
groups = (results["groups"],)
result = results[agg.name]
@@ -2627,7 +2604,13 @@ def groupby_reduce(
# TODO: clean this up
reindex = _validate_reindex(
- reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype
+ reindex,
+ func,
+ method,
+ expected_,
+ any_by_dask,
+ is_duck_dask_array(array),
+ array.dtype,
)
if TYPE_CHECKING:
@@ -2798,9 +2781,7 @@ def groupby_scan(
if expected_groups is not None:
raise NotImplementedError("Setting `expected_groups` and binning is not supported yet.")
expected_groups = _validate_expected_groups(nby, expected_groups)
- expected_groups = _convert_expected_groups_to_index(
- expected_groups, isbin=(False,) * nby, sort=False
- )
+ expected_groups = _convert_expected_groups_to_index(expected_groups, isbin=(False,) * nby, sort=False)
# Don't factorize early only when
# grouping by dask arrays, and not having expected_groups
@@ -2918,7 +2899,12 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
# 1. zip together group indices & array
zipped = map_blocks(
- _zip, by, array, dtype=array.dtype, meta=array._meta, name="groupby-scan-preprocess"
+ _zip,
+ by,
+ array,
+ dtype=array.dtype,
+ meta=array._meta,
+ name="groupby-scan-preprocess",
)
scan_ = partial(chunk_scan, agg=agg)
=====================================
flox/xarray.py
=====================================
@@ -1,7 +1,7 @@
from __future__ import annotations
from collections.abc import Hashable, Iterable, Sequence
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any
import numpy as np
import pandas as pd
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
from .core import T_ExpectedGroupsOpt, T_ExpectIndex, T_ExpectOpt
- Dims = Union[str, Iterable[Hashable], None]
+ Dims = str | Iterable[Hashable] | None
def _restore_dim_order(result, obj, by, no_groupby_reorder=False):
@@ -286,9 +286,7 @@ def xarray_reduce(
try:
xr.align(ds, *by_da, join="exact", copy=False)
except ValueError as e:
- raise ValueError(
- "Object being grouped must be exactly aligned with every array in `by`."
- ) from e
+ raise ValueError("Object being grouped must be exactly aligned with every array in `by`.") from e
needs_broadcast = any(
not set(grouper_dims).issubset(set(variable.dims)) for variable in ds.data_vars.values()
@@ -329,15 +327,11 @@ def xarray_reduce(
group_names: tuple[Any, ...] = ()
group_sizes: dict[Any, int] = {}
for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups_valid, isbins)):
- group_name = (
- f"{b_.name}_bins" if isbin_ or isinstance(expect, pd.IntervalIndex) else b_.name
- )
+ group_name = f"{b_.name}_bins" if isbin_ or isinstance(expect, pd.IntervalIndex) else b_.name
group_names += (group_name,)
if isbin_ and isinstance(expect, int):
- raise NotImplementedError(
- "flox does not support binning into an integer number of bins yet."
- )
+ raise NotImplementedError("flox does not support binning into an integer number of bins yet.")
expect1: T_ExpectOpt
if expect is None:
@@ -448,7 +442,8 @@ def xarray_reduce(
output_core_dims=[output_core_dims],
dask="allowed",
dask_gufunc_kwargs=dict(
- output_sizes=output_sizes, output_dtypes=[dtype] if dtype is not None else None
+ output_sizes=output_sizes,
+ output_dtypes=[dtype] if dtype is not None else None,
),
keep_attrs=keep_attrs,
kwargs={
@@ -520,11 +515,12 @@ def xarray_reduce(
template = obj
if actual[var].ndim > 1 + len(vector_dims):
- no_groupby_reorder = isinstance(
- obj, xr.Dataset
- ) # do not re-order dataarrays inside datasets
+ no_groupby_reorder = isinstance(obj, xr.Dataset) # do not re-order dataarrays inside datasets
actual[var] = _restore_dim_order(
- actual[var], template, by_da[0], no_groupby_reorder=no_groupby_reorder
+ actual[var].variable,
+ template,
+ by_da[0],
+ no_groupby_reorder=no_groupby_reorder,
)
if missing_dim:
@@ -625,13 +621,14 @@ def _rechunk(func, obj, dim, labels, **kwargs):
if obj[var].chunks is not None:
obj[var] = obj[var].copy(
data=func(
- obj[var].data, axis=obj[var].get_axis_num(dim), labels=labels.data, **kwargs
+ obj[var].data,
+ axis=obj[var].get_axis_num(dim),
+ labels=labels.data,
+ **kwargs,
)
)
else:
if obj.chunks is not None:
- obj = obj.copy(
- data=func(obj.data, axis=obj.get_axis_num(dim), labels=labels.data, **kwargs)
- )
+ obj = obj.copy(data=func(obj.data, axis=obj.get_axis_num(dim), labels=labels.data, **kwargs))
return obj
=====================================
flox/xrdtypes.py
=====================================
@@ -150,9 +150,14 @@ def is_datetime_like(dtype):
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
-def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
+def _normalize_dtype(
+ dtype: DTypeLike, array_dtype: np.dtype, preserves_dtype: bool, fill_value=None
+) -> np.dtype:
if dtype is None:
- dtype = array_dtype
+ if not preserves_dtype:
+ dtype = _maybe_promote_int(array_dtype)
+ else:
+ dtype = array_dtype
if dtype is np.floating:
# mean, std, var always result in floating
# but we preserve the array's dtype if it is floating
=====================================
flox/xrutils.py
=====================================
@@ -4,14 +4,14 @@
import datetime
import importlib
from collections.abc import Iterable
-from typing import Any, Optional
+from typing import Any
import numpy as np
import pandas as pd
from packaging.version import Version
-def module_available(module: str, minversion: Optional[str] = None) -> bool:
+def module_available(module: str, minversion: str | None = None) -> bool:
"""Checks whether a module is installed without importing it.
Use this for a lightweight check and lazy imports.
@@ -137,7 +137,7 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool:
include_0d = getattr(value, "ndim", None) == 0
return (
include_0d
- or isinstance(value, (str, bytes, dict))
+ or isinstance(value, str | bytes | dict)
or not (
isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
or hasattr(value, "__array_function__")
@@ -150,7 +150,7 @@ def notnull(data):
data = np.asarray(data)
scalar_type = data.dtype.type
- if issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
+ if issubclass(scalar_type, np.bool_ | np.integer | np.character | np.void):
# these types cannot represent missing values
return np.ones_like(data, dtype=bool)
else:
@@ -163,7 +163,7 @@ def isnull(data):
if not is_duck_array(data):
data = np.asarray(data)
scalar_type = data.dtype.type
- if issubclass(scalar_type, (np.datetime64, np.timedelta64)):
+ if issubclass(scalar_type, np.datetime64 | np.timedelta64):
# datetime types use NaT for null
# note: must check timedelta64 before integers, because currently
# timedelta64 inherits from np.integer
@@ -171,12 +171,12 @@ def isnull(data):
elif issubclass(scalar_type, np.inexact):
# float types use NaN for null
return np.isnan(data)
- elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
+ elif issubclass(scalar_type, np.bool_ | np.integer | np.character | np.void):
# these types cannot represent missing values
return np.zeros_like(data, dtype=bool)
else:
# at this point, array should have dtype=object
- if isinstance(data, (np.ndarray, dask_array_type)):
+ if isinstance(data, (np.ndarray, dask_array_type)): # noqa
return pd.isnull(data)
else:
# Not reachable yet, but intended for use with other duck array
@@ -275,9 +275,7 @@ def timedelta_to_numeric(value, datetime_unit="ns", dtype=float):
try:
a = pd.to_timedelta(value)
except ValueError:
- raise ValueError(
- f"Could not convert {value!r} to timedelta64 using pandas.to_timedelta"
- )
+ raise ValueError(f"Could not convert {value!r} to timedelta64 using pandas.to_timedelta")
return py_timedelta_to_float(a, datetime_unit)
else:
raise TypeError(
=====================================
pyproject.toml
=====================================
@@ -3,7 +3,7 @@ name = "flox"
description = "GroupBy operations for dask.array"
license = {file = "LICENSE"}
readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
keywords = ["xarray", "dask", "groupby"]
classifiers = [
"Development Status :: 4 - Beta",
@@ -11,7 +11,6 @@ classifiers = [
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python",
- "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
@@ -60,12 +59,9 @@ fallback_version = "999"
write_to = "flox/_version.py"
write_to_template= '__version__ = "{version}"'
-[tool.black]
-line-length = 100
-target-version = ["py39"]
-
[tool.ruff]
-target-version = "py39"
+line-length = 110
+target-version = "py310"
builtins = ["ellipsis"]
exclude = [
".eggs",
@@ -109,6 +105,10 @@ known-third-party = [
"xarray"
]
+[tool.ruff.format]
+# Enable reformatting of code snippets in docstrings.
+docstring-code-format = true
+
[tool.mypy]
allow_redefinition = true
files = "**/*.py"
=====================================
tests/__init__.py
=====================================
@@ -188,9 +188,9 @@ def dask_assert_eq(
a_original = a
b_original = b
- if isinstance(a, (list, int, float)):
+ if isinstance(a, list | int | float):
a = np.array(a)
- if isinstance(b, (list, int, float)):
+ if isinstance(b, list | int | float):
b = np.array(b)
a, adt, a_meta, a_computed = _get_dt_meta_computed(
=====================================
tests/strategies.py
=====================================
@@ -1,6 +1,7 @@
from __future__ import annotations
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
import cftime
import dask
@@ -26,22 +27,28 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
# TODO: stop excluding everything but U
-array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
+array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
by_dtype_st = supported_dtypes()
-NON_NUMPY_FUNCS = ["first", "last", "nanfirst", "nanlast", "count", "any", "all"] + list(
- SCIPY_STATS_FUNCS
-)
+NON_NUMPY_FUNCS = [
+ "first",
+ "last",
+ "nanfirst",
+ "nanlast",
+ "count",
+ "any",
+ "all",
+] + list(SCIPY_STATS_FUNCS)
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]
-func_st = st.sampled_from(
- [f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]
-)
+func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
numeric_arrays = npst.arrays(
- elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
+ elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
)
all_arrays = npst.arrays(
- elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes()
+ elements={"allow_subnormal": False},
+ shape=npst.array_shapes(),
+ dtype=supported_dtypes(),
)
calendars = st.sampled_from(
=====================================
tests/test_asv.py
=====================================
@@ -7,9 +7,7 @@ pytest.importorskip("dask")
from asv_bench.benchmarks import reduce
-@pytest.mark.parametrize(
- "problem", [reduce.ChunkReduce1D, reduce.ChunkReduce2D, reduce.ChunkReduce2DAllAxes]
-)
+ at pytest.mark.parametrize("problem", [reduce.ChunkReduce1D, reduce.ChunkReduce2D, reduce.ChunkReduce2DAllAxes])
def test_reduce(problem) -> None:
testcase = problem()
testcase.setup()
=====================================
tests/test_core.py
=====================================
@@ -3,8 +3,9 @@ from __future__ import annotations
import itertools
import logging
import warnings
+from collections.abc import Callable
from functools import partial, reduce
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import numpy as np
@@ -80,7 +81,7 @@ def _get_array_func(func: str) -> Callable:
def npfunc(x, **kwargs):
x = np.asarray(x)
- return (~np.isnan(x)).sum()
+ return (~xrutils.isnull(x)).sum(**kwargs)
elif func in ["nanfirst", "nanlast"]:
npfunc = getattr(xrutils, func)
@@ -126,14 +127,24 @@ def test_alignment_error():
("sum", np.ones((12,)), nan_labels, [1, 4, 2]), # form 1
("sum", np.ones((2, 12)), labels, [[3, 4, 5], [3, 4, 5]]), # form 3
("sum", np.ones((2, 12)), nan_labels, [[1, 4, 2], [1, 4, 2]]), # form 3
- ("sum", np.ones((2, 12)), np.array([labels, labels]), [6, 8, 10]), # form 1 after reshape
+ (
+ "sum",
+ np.ones((2, 12)),
+ np.array([labels, labels]),
+ [6, 8, 10],
+ ), # form 1 after reshape
("sum", np.ones((2, 12)), np.array([nan_labels, nan_labels]), [2, 8, 4]),
# (np.ones((12,)), np.array([labels, labels])), # form 4
("count", np.ones((12,)), labels, [3, 4, 5]), # form 1
("count", np.ones((12,)), nan_labels, [1, 4, 2]), # form 1
("count", np.ones((2, 12)), labels, [[3, 4, 5], [3, 4, 5]]), # form 3
("count", np.ones((2, 12)), nan_labels, [[1, 4, 2], [1, 4, 2]]), # form 3
- ("count", np.ones((2, 12)), np.array([labels, labels]), [6, 8, 10]), # form 1 after reshape
+ (
+ "count",
+ np.ones((2, 12)),
+ np.array([labels, labels]),
+ [6, 8, 10],
+ ), # form 1 after reshape
("count", np.ones((2, 12)), np.array([nan_labels, nan_labels]), [2, 8, 4]),
("nanmean", np.ones((12,)), labels, [1, 1, 1]), # form 1
("nanmean", np.ones((12,)), nan_labels, [1, 1, 1]), # form 1
@@ -215,9 +226,7 @@ def gen_array_by(size, func):
@pytest.mark.parametrize("add_nan_by", [True, False])
@pytest.mark.parametrize("func", ALL_FUNCS)
def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
- if ("arg" in func and engine in ["flox", "numbagg"]) or (
- func in BLOCKWISE_FUNCS and chunks != -1
- ):
+ if ("arg" in func and engine in ["flox", "numbagg"]) or (func in BLOCKWISE_FUNCS and chunks != -1):
pytest.skip()
array, by = gen_array_by(size, func)
@@ -237,7 +246,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
fill_value = np.nan
tolerance = {"rtol": 1e-13, "atol": 1e-15}
elif "quantile" in func:
- finalize_kwargs = [{"q": DEFAULT_QUANTILE}, {"q": [DEFAULT_QUANTILE / 2, DEFAULT_QUANTILE]}]
+ finalize_kwargs = [
+ {"q": DEFAULT_QUANTILE},
+ {"q": [DEFAULT_QUANTILE / 2, DEFAULT_QUANTILE]},
+ ]
fill_value = None
tolerance = None
else:
@@ -313,7 +325,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
combine_error = RuntimeError("This combine should not have been called.")
for method, reindex in params:
call = partial(
- groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
+ groupby_reduce,
+ array,
+ *by,
+ method=method,
+ reindex=reindex,
+ **flox_kwargs,
)
if ("arg" in func or func in ["first", "last"]) and reindex is True:
# simple_combine with argreductions not supported right now
@@ -461,7 +478,9 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp
labels[-2:] = np.nan
kwargs = dict(
- func=func, expected_groups=[0, 1, 2], fill_value=False if func in ["all", "any"] else 123
+ func=func,
+ expected_groups=[0, 1, 2],
+ fill_value=False if func in ["all", "any"] else 123,
)
expected, _ = groupby_reduce(array.compute(), labels, engine="numpy", **kwargs)
@@ -674,15 +693,16 @@ def test_first_last_disallowed_dask(func):
# anything else is not.
with pytest.raises(ValueError):
groupby_reduce(
- dask.array.empty((2, 3, 2), chunks=(-1, -1, 1)), np.ones((2,)), func=func, axis=-1
+ dask.array.empty((2, 3, 2), chunks=(-1, -1, 1)),
+ np.ones((2,)),
+ func=func,
+ axis=-1,
)
@requires_dask
@pytest.mark.parametrize("func", ALL_FUNCS)
-@pytest.mark.parametrize(
- "axis", [None, (0, 1, 2), (0, 1), (0, 2), (1, 2), 0, 1, 2, (0,), (1,), (2,)]
-)
+ at pytest.mark.parametrize("axis", [None, (0, 1, 2), (0, 1), (0, 2), (1, 2), 0, 1, 2, (0,), (1,), (2,)])
def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
if ("arg" in func and engine in ["flox", "numbagg"]) or func in BLOCKWISE_FUNCS:
pytest.skip()
@@ -797,7 +817,8 @@ def test_groupby_reduce_nans(reindex, chunks, axis, groups, expected_shape, engi
@requires_dask
@pytest.mark.parametrize(
- "expected_groups, reindex", [(None, None), (None, False), ([0, 1, 2], True), ([0, 1, 2], False)]
+ "expected_groups, reindex",
+ [(None, None), (None, False), ([0, 1, 2], True), ([0, 1, 2], False)],
)
def test_groupby_all_nan_blocks_dask(expected_groups, reindex, engine):
labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
@@ -848,11 +869,16 @@ def test_bad_npg_behaviour():
# fmt: off
array = np.array([[1] * 12, [1] * 12])
# fmt: on
- assert_equal(aggregate(labels, array, axis=-1, func="argmax"), np.array([[0, 5, 2], [0, 5, 2]]))
+ assert_equal(
+ aggregate(labels, array, axis=-1, func="argmax"),
+ np.array([[0, 5, 2], [0, 5, 2]]),
+ )
assert (
aggregate(
- np.array([0, 1, 2, 0, 1, 2]), np.array([-np.inf, 0, 0, -np.inf, 0, 0]), func="max"
+ np.array([0, 1, 2, 0, 1, 2]),
+ np.array([-np.inf, 0, 0, -np.inf, 0, 0]),
+ func="max",
)[0]
== -np.inf
)
@@ -900,13 +926,17 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
with raise_if_dask_computes():
actual, *groups = groupby_reduce(
- array, labels, func="count", fill_value=0, engine=engine, method=method, **kwargs
+ array,
+ labels,
+ func="count",
+ fill_value=0,
+ engine=engine,
+ method=method,
+ **kwargs,
)
(groups_array,) = groups
expected = np.array([3, 1, 0], dtype=np.intp)
- for left, right in zip(
- groups_array, pd.IntervalIndex.from_arrays([1, 2, 4], [2, 4, 5]).to_numpy()
- ):
+ for left, right in zip(groups_array, pd.IntervalIndex.from_arrays([1, 2, 4], [2, 4, 5]).to_numpy()):
assert left == right
assert_equal(actual, expected)
@@ -940,7 +970,11 @@ def test_rechunk_for_blockwise(inchunks, expected):
[[[0, 1, 2, 3]], [0, 1, 2, 0, 1, 2, 3], (3, 4)],
[[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1)],
[[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1)],
- [[[0], [1, 2, 3, 4], [5]], np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]), (4, 8, 4, 9, 4)],
+ [
+ [[0], [1, 2, 3, 4], [5]],
+ np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]),
+ (4, 8, 4, 9, 4),
+ ],
],
)
def test_find_group_cohorts(expected, labels, chunks: tuple[int]) -> None:
@@ -1039,11 +1073,14 @@ def test_fill_value_behaviour(func, chunks, fill_value, engine):
if chunks:
array = dask.array.from_array(array, chunks)
actual, _ = groupby_reduce(
- array, by, func=func, engine=engine, fill_value=fill_value, expected_groups=[0, 1, 2, 3]
- )
- expected = np.array(
- [fill_value, fill_value, npfunc([1.0, 1.0], axis=0), npfunc([1.0, 1.0], axis=0)]
+ array,
+ by,
+ func=func,
+ engine=engine,
+ fill_value=fill_value,
+ expected_groups=[0, 1, 2, 3],
)
+ expected = np.array([fill_value, fill_value, npfunc([1.0, 1.0], axis=0), npfunc([1.0, 1.0], axis=0)])
assert_equal(actual, expected)
@@ -1140,7 +1177,12 @@ def test_dtype_promotion(func, fill_value, expected, engine):
by = [0, 1]
actual, _ = groupby_reduce(
- array, by, func=func, expected_groups=[1, 2], fill_value=fill_value, engine=engine
+ array,
+ by,
+ func=func,
+ expected_groups=[1, 2],
+ fill_value=fill_value,
+ engine=engine,
)
assert np.issubdtype(actual.dtype, expected)
@@ -1259,9 +1301,7 @@ def test_group_by_datetime_cubed(engine, method):
assert_equal(expected, actual)
edges = pd.date_range("1999-12-31", "2000-12-31", freq="ME").to_series().to_numpy()
- actual, _ = groupby_reduce(
- cubedarray, t.to_numpy(), isbin=True, expected_groups=edges, **kwargs
- )
+ actual, _ = groupby_reduce(cubedarray, t.to_numpy(), isbin=True, expected_groups=edges, **kwargs)
expected = data.resample("ME").mean().to_numpy()
assert_equal(expected, actual)
@@ -1316,9 +1356,7 @@ def test_multiple_groupers_bins(chunk) -> None:
@pytest.mark.parametrize("expected_groups", [None, (np.arange(5), [2, 3]), (None, [2, 3])])
-@pytest.mark.parametrize(
- "by1", [np.arange(5)[:, None], np.broadcast_to(np.arange(5)[:, None], (5, 2))]
-)
+ at pytest.mark.parametrize("by1", [np.arange(5)[:, None], np.broadcast_to(np.arange(5)[:, None], (5, 2))])
@pytest.mark.parametrize(
"by2",
[
@@ -1341,9 +1379,7 @@ def test_multiple_groupers(chunk, by1, by2, expected_groups) -> None:
# output from `count` is intp
expected = np.ones((5, 2), dtype=np.intp)
- actual, *_ = groupby_reduce(
- array, by1, by2, axis=(0, 1), func="count", expected_groups=expected_groups
- )
+ actual, *_ = groupby_reduce(array, by1, by2, axis=(0, 1), func="count", expected_groups=expected_groups)
assert_equal(expected, actual)
@@ -1435,9 +1471,7 @@ def test_custom_aggregation_blockwise():
dtype=dtype,
)
- agg_median = Aggregation(
- name="median", numpy=grouped_median, fill_value=-1, chunk=None, combine=None
- )
+ agg_median = Aggregation(name="median", numpy=grouped_median, fill_value=-1, chunk=None, combine=None)
array = np.arange(100, dtype=np.float32).reshape(5, 20)
by = np.ones((20,))
@@ -1480,7 +1514,12 @@ def test_dtype(func, dtype, engine):
arr = np.ones((4, 12), dtype=dtype)
labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
actual, _ = groupby_reduce(
- arr, labels, func=func, dtype=np.float64, engine=engine, finalize_kwargs=finalize_kwargs
+ arr,
+ labels,
+ func=func,
+ dtype=np.float64,
+ engine=engine,
+ finalize_kwargs=finalize_kwargs,
)
assert actual.dtype == np.dtype("float64")
@@ -1587,9 +1626,7 @@ def test_subset_block_2d(flatblocks, expectidx):
[True, None, "sum", ([1], None), True],
],
)
-def test_validate_reindex_map_reduce(
- dask_expected, reindex, func, expected_groups, any_by_dask
-) -> None:
+def test_validate_reindex_map_reduce(dask_expected, reindex, func, expected_groups, any_by_dask) -> None:
actual = _validate_reindex(
reindex,
func,
@@ -1720,12 +1757,20 @@ def test_1d_blockwise_sort_optimization():
assert all("getitem" not in k for k in actual.dask)
actual, _ = groupby_reduce(
- array, time.dt.dayofyear.values[::-1], sort=True, method="blockwise", func="count"
+ array,
+ time.dt.dayofyear.values[::-1],
+ sort=True,
+ method="blockwise",
+ func="count",
)
assert any("getitem" in k for k in actual.dask.layers)
actual, _ = groupby_reduce(
- array, time.dt.dayofyear.values[::-1], sort=False, method="blockwise", func="count"
+ array,
+ time.dt.dayofyear.values[::-1],
+ sort=False,
+ method="blockwise",
+ func="count",
)
assert all("getitem" not in k for k in actual.dask.layers)
@@ -1760,9 +1805,7 @@ def test_negative_index_factorize_race_condition():
@pytest.mark.parametrize("sort", [True, False])
def test_expected_index_conversion_passthrough_range_index(sort):
index = pd.RangeIndex(100)
- actual = _convert_expected_groups_to_index(
- expected_groups=(index,), isbin=(False,), sort=(sort,)
- )
+ actual = _convert_expected_groups_to_index(expected_groups=(index,), isbin=(False,), sort=(sort,))
assert actual[0] is index
@@ -1935,11 +1978,22 @@ def test_ffill_bfill(chunks, size, add_nan_by, func):
def test_blockwise_nans():
array = dask.array.ones((1, 10), chunks=2)
by = np.array([-1, 0, -1, 1, -1, 2, -1, 3, 4, 4])
- actual, actual_groups = flox.groupby_reduce(
- array, by, func="sum", expected_groups=pd.RangeIndex(0, 5)
- )
+ actual, actual_groups = flox.groupby_reduce(array, by, func="sum", expected_groups=pd.RangeIndex(0, 5))
expected, expected_groups = flox.groupby_reduce(
array.compute(), by, func="sum", expected_groups=pd.RangeIndex(0, 5)
)
assert_equal(expected_groups, actual_groups)
assert_equal(expected, actual)
+
+
+ at pytest.mark.parametrize("func", ["sum", "prod", "count", "nansum"])
+ at pytest.mark.parametrize("engine", ["flox", "numpy"])
+def test_agg_dtypes(func, engine):
+ # regression test for GH388
+ counts = np.array([0, 2, 1, 0, 1])
+ group = np.array([1, 1, 1, 2, 2])
+ actual, _ = groupby_reduce(
+ counts, group, expected_groups=(np.array([1, 2]),), func=func, dtype="uint8", engine=engine
+ )
+ expected = _get_array_func(func)(counts, dtype="uint8")
+ assert actual.dtype == np.uint8 == expected.dtype
=====================================
tests/test_properties.py
=====================================
@@ -1,5 +1,6 @@
import warnings
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
import pandas as pd
import pytest
@@ -19,7 +20,7 @@ from flox.core import groupby_reduce, groupby_scan
from flox.xrutils import notnull
from . import assert_equal
-from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
+from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays
from .strategies import chunks as chunks_strategy
dask.config.set(scheduler="sync")
@@ -92,18 +93,18 @@ def test_groupby_reduce(data, array, func: str) -> None:
flox_kwargs: dict[str, Any] = {}
with np.errstate(invalid="ignore", divide="ignore"):
actual, *_ = groupby_reduce(
- array, by, func=func, axis=axis, engine="numpy", **flox_kwargs, finalize_kwargs=kwargs
+ array,
+ by,
+ func=func,
+ axis=axis,
+ engine="numpy",
+ **flox_kwargs,
+ finalize_kwargs=kwargs,
)
# numpy-groupies always does the calculation in float64
if (
- (
- "var" in func
- or "std" in func
- or "sum" in func
- or "mean" in func
- or "quantile" in func
- )
+ ("var" in func or "std" in func or "sum" in func or "mean" in func or "quantile" in func)
and array.dtype.kind == "f"
and array.dtype.itemsize != 8
):
@@ -195,8 +196,18 @@ def test_ffill_bfill_reverse(data, array: dask.array.Array) -> None:
def test_first_last(data, array: dask.array.Array, func: str) -> None:
by = data.draw(by_arrays(shape=(array.shape[-1],)))
- INVERSES = {"first": "last", "last": "first", "nanfirst": "nanlast", "nanlast": "nanfirst"}
- MATES = {"first": "nanfirst", "last": "nanlast", "nanfirst": "first", "nanlast": "last"}
+ INVERSES = {
+ "first": "last",
+ "last": "first",
+ "nanfirst": "nanlast",
+ "nanlast": "nanfirst",
+ }
+ MATES = {
+ "first": "nanfirst",
+ "last": "nanlast",
+ "nanfirst": "first",
+ "nanlast": "last",
+ }
inverse = INVERSES[func]
mate = MATES[func]
@@ -233,3 +244,25 @@ def test_first_last_useless(data, func):
actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype)
assert_equal(actual, expected)
+
+
+@given(
+ func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]),
+ engine=st.sampled_from(["numpy", "flox"]),
+ array_dtype=st.none() | array_dtypes,
+ dtype=st.none() | array_dtypes,
+)
+def test_agg_dtype_specified(func, array_dtype, dtype, engine):
+ # regression test for GH388
+ counts = np.array([0, 2, 1, 0, 1], dtype=array_dtype)
+ group = np.array([1, 1, 1, 2, 2])
+ actual, _ = groupby_reduce(
+ counts,
+ group,
+ expected_groups=(np.array([1, 2]),),
+ func=func,
+ dtype=dtype,
+ engine=engine,
+ )
+ expected = getattr(np, func)(counts, keepdims=True, dtype=dtype)
+ assert actual.dtype == expected.dtype
=====================================
tests/test_xarray.py
=====================================
@@ -24,7 +24,7 @@ if has_dask:
# test against legacy xarray implementation
# avoid some compilation overhead
-xr.set_options(use_flox=False, use_numbagg=False)
+xr.set_options(use_flox=False, use_numbagg=False, use_bottleneck=False)
tolerance64 = {"rtol": 1e-15, "atol": 1e-18}
np.random.seed(123)
@@ -49,7 +49,9 @@ def test_xarray_reduce(skipna, add_nan, min_count, engine_no_numba, reindex):
labels2 = np.array([1, 2, 2, 1])
da = xr.DataArray(
- arr, dims=("x", "y"), coords={"labels2": ("x", labels2), "labels": ("y", labels)}
+ arr,
+ dims=("x", "y"),
+ coords={"labels2": ("x", labels2), "labels": ("y", labels)},
).expand_dims(z=4)
expected = da.groupby("labels").sum(skipna=skipna, min_count=min_count)
@@ -98,7 +100,9 @@ def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine_no_
labels2 = np.array([1, 2, 2, 1])
da = xr.DataArray(
- arr, dims=("x", "y"), coords={"labels2": ("x", labels2), "labels": ("y", labels)}
+ arr,
+ dims=("x", "y"),
+ coords={"labels2": ("x", labels2), "labels": ("y", labels)},
).expand_dims(z=4)
if chunk:
@@ -177,9 +181,7 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine_n
(None, (None, None), [[1, 2], [1, 2]]),
)
def test_validate_expected_groups(expected_groups):
- da = xr.DataArray(
- [1.0, 2.0], dims="x", coords={"labels": ("x", [1, 2]), "labels2": ("x", [1, 2])}
- )
+ da = xr.DataArray([1.0, 2.0], dims="x", coords={"labels": ("x", [1, 2]), "labels2": ("x", [1, 2])})
with pytest.raises(ValueError):
xarray_reduce(
da.chunk({"x": 1}),
@@ -196,12 +198,13 @@ def test_xarray_reduce_single_grouper(engine_no_numba):
engine = engine_no_numba
# DataArray
ds = xr.Dataset(
- {"Tair": (("time", "x", "y"), dask.array.ones((36, 205, 275), chunks=(9, -1, -1)))},
- coords={
- "time": xr.date_range(
- "1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap"
+ {
+ "Tair": (
+ ("time", "x", "y"),
+ dask.array.ones((36, 205, 275), chunks=(9, -1, -1)),
)
},
+ coords={"time": xr.date_range("1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap")},
)
actual = xarray_reduce(ds.Tair, ds.time.dt.month, func="mean", engine=engine)
expected = ds.Tair.groupby("time.month").mean()
@@ -380,12 +383,13 @@ def test_func_is_aggregation():
from flox.aggregations import mean
ds = xr.Dataset(
- {"Tair": (("time", "x", "y"), dask.array.ones((36, 205, 275), chunks=(9, -1, -1)))},
- coords={
- "time": xr.date_range(
- "1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap"
+ {
+ "Tair": (
+ ("time", "x", "y"),
+ dask.array.ones((36, 205, 275), chunks=(9, -1, -1)),
)
},
+ coords={"time": xr.date_range("1980-09-01 00:00", "1983-09-18 00:00", freq="ME", calendar="noleap")},
)
expected = xarray_reduce(ds.Tair, ds.time.dt.month, func="mean")
actual = xarray_reduce(ds.Tair, ds.time.dt.month, func=mean)
@@ -520,7 +524,10 @@ def test_dtype(add_nan, chunk, dtype, dtype_out, engine_no_numba):
data,
dims=("x", "t"),
coords={
- "labels": ("t", np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]))
+ "labels": (
+ "t",
+ np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]),
+ )
},
name="arr",
)
@@ -642,7 +649,11 @@ def test_fill_value_xarray_binning():
def test_groupby_2d_dataset():
d = {
"coords": {
- "bit_index": {"dims": ("bit_index",), "attrs": {"name": "bit_index"}, "data": [0, 1]},
+ "bit_index": {
+ "dims": ("bit_index",),
+ "attrs": {"name": "bit_index"},
+ "data": [0, 1],
+ },
"index": {"dims": ("index",), "data": [0, 6, 8, 10, 14]},
"clifford": {"dims": ("index",), "attrs": {}, "data": [1, 1, 4, 10, 4]},
},
@@ -664,18 +675,14 @@ def test_groupby_2d_dataset():
expected = ds.groupby("clifford").mean()
with xr.set_options(use_flox=True):
actual = ds.groupby("clifford").mean()
- assert (
- expected.counts.dims == actual.counts.dims
- ) # https://github.com/pydata/xarray/issues/8292
+ assert expected.counts.dims == actual.counts.dims # https://github.com/pydata/xarray/issues/8292
xr.testing.assert_identical(expected, actual)
@pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False))
def test_resampling_missing_groups(chunk):
# Regression test for https://github.com/pydata/xarray/issues/8592
- time_coords = pd.to_datetime(
- ["2018-06-13T03:40:36", "2018-06-13T05:50:37", "2018-06-15T03:02:34"]
- )
+ time_coords = pd.to_datetime(["2018-06-13T03:40:36", "2018-06-13T05:50:37", "2018-06-15T03:02:34"])
latitude_coords = [0.0]
longitude_coords = [0.0]
@@ -684,7 +691,11 @@ def test_resampling_missing_groups(chunk):
da = xr.DataArray(
data,
- coords={"time": time_coords, "latitude": latitude_coords, "longitude": longitude_coords},
+ coords={
+ "time": time_coords,
+ "latitude": latitude_coords,
+ "longitude": longitude_coords,
+ },
dims=["time", "latitude", "longitude"],
)
if chunk:
@@ -749,3 +760,26 @@ def test_direct_reduction(func):
with xr.set_options(use_flox=False):
expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs)
xr.testing.assert_identical(expected, actual)
+
+
+ at pytest.mark.parametrize("reduction", ["max", "min", "nanmax", "nanmin", "sum", "nansum", "prod", "nanprod"])
+def test_groupby_preserve_dtype(reduction):
+ # all groups are present, we should follow numpy exactly
+ ds = xr.Dataset(
+ {
+ "test": (
+ ["x", "y"],
+ np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16"),
+ )
+ },
+ coords={"idx": ("x", [1, 2, 1])},
+ )
+
+ kwargs = {"engine": "numpy"}
+ if "nan" in reduction:
+ kwargs["skipna"] = True
+ with xr.set_options(use_flox=True):
+ actual = getattr(ds.groupby("idx"), reduction.removeprefix("nan"))(**kwargs).test.dtype
+ expected = getattr(np, reduction)(ds.test.data, axis=0).dtype
+
+ assert actual == expected
View it on GitLab: https://salsa.debian.org/debian-gis-team/flox/-/commit/4a49f04c97f9feb2a15fd16224e7c3f9cbb19d18