[Git][debian-gis-team/flox][upstream] New upstream version 0.5.10

Antonio Valentino (@antonio.valentino) gitlab at salsa.debian.org
Mon Oct 10 07:05:10 BST 2022



Antonio Valentino pushed to branch upstream at Debian GIS Project / flox


Commits:
93acea4e by Antonio Valentino at 2022-10-10T05:49:23+00:00
New upstream version 0.5.10
- - - - -


30 changed files:

- + .github/workflows/benchmarks.yml
- + .github/workflows/ci-additional.yaml
- .github/workflows/ci.yaml
- .pre-commit-config.yaml
- README.md
- asv.conf.json → asv_bench/asv.conf.json
- + asv_bench/benchmarks/README_CI.md
- + asv_bench/benchmarks/__init__.py
- benchmarks/combine.py → asv_bench/benchmarks/combine.py
- benchmarks/reduce.py → asv_bench/benchmarks/reduce.py
- − benchmarks/__init__.py
- ci/docs.yml
- ci/environment.yml
- ci/minimal-requirements.yml
- ci/no-dask.yml
- ci/no-xarray.yml
- docs/source/conf.py
- docs/source/index.md
- docs/source/user-stories.md
- + docs/source/user-stories/overlaps.md
- flox/aggregate_flox.py
- flox/aggregations.py
- flox/cache.py
- flox/core.py
- flox/xarray.py
- flox/xrutils.py
- pyproject.toml
- setup.cfg
- tests/test_core.py
- tests/test_xarray.py


Changes:

=====================================
.github/workflows/benchmarks.yml
=====================================
@@ -0,0 +1,78 @@
+name: Benchmark
+
+on:
+  pull_request:
+    types: [opened, reopened, synchronize, labeled]
+  workflow_dispatch:
+
+jobs:
+  benchmark:
+    # if: ${{ contains( github.event.pull_request.labels.*.name, 'run-benchmark') && github.event_name == 'pull_request' || github.event_name == 'workflow_dispatch' }}  # Run if the PR has been labelled correctly.
+    if: ${{ github.event_name == 'pull_request' || github.event_name == 'workflow_dispatch' }}  # Always run.
+    name: Linux
+    runs-on: ubuntu-20.04
+    env:
+      ASV_DIR: "./asv_bench"
+
+    steps:
+      # We need the full repo to avoid this issue
+      # https://github.com/actions/checkout/issues/23
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+
+      - name: Set up conda environment
+        uses: mamba-org/provision-with-micromamba@v13
+        with:
+          environment-file: ci/environment.yml
+          environment-name: flox-tests
+          cache-env: true
+          # extra-specs: |
+          #   python="${{ matrix.python-version }}"
+
+      # - name: Setup some dependencies
+        # shell: bash -l {0}
+        # run: |
+          # pip install asv
+          # sudo apt-get update -y
+
+      - name: Run benchmarks
+        shell: bash -l {0}
+        id: benchmark
+        env:
+          OPENBLAS_NUM_THREADS: 1
+          MKL_NUM_THREADS: 1
+          OMP_NUM_THREADS: 1
+          ASV_FACTOR: 1.5
+          ASV_SKIP_SLOW: 1
+        run: |
+          set -x
+          # ID this runner
+          asv machine --yes
+          echo "Baseline:  ${{ github.event.pull_request.base.sha }} (${{ github.event.pull_request.base.label }})"
+          echo "Contender: ${GITHUB_SHA} (${{ github.event.pull_request.head.label }})"
+          # Use mamba for env creation
+          # export CONDA_EXE=$(which mamba)
+          export CONDA_EXE=$(which conda)
+          # Run benchmarks for current commit against base
+          ASV_OPTIONS="--split --show-stderr --factor $ASV_FACTOR"
+          asv continuous $ASV_OPTIONS ${{ github.event.pull_request.base.sha }} ${GITHUB_SHA} \
+              | sed "/Traceback \|failed$\|PERFORMANCE DECREASED/ s/^/::error::/" \
+              | tee benchmarks.log
+          # Report and export results for subsequent steps
+          if grep "Traceback \|failed\|PERFORMANCE DECREASED" benchmarks.log > /dev/null ; then
+              exit 1
+          fi
+        working-directory: ${{ env.ASV_DIR }}
+
+      - name: Add instructions to artifact
+        if: always()
+        run: |
+          cp benchmarks/README_CI.md benchmarks.log .asv/results/
+        working-directory: ${{ env.ASV_DIR }}
+
+      - uses: actions/upload-artifact@v3
+        if: always()
+        with:
+          name: asv-benchmark-results-${{ runner.os }}
+          path: ${{ env.ASV_DIR }}/.asv/results
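
Note: the failure detection in the "Run benchmarks" step above is terse in its sed/grep form. The following is a minimal Python sketch of the same idea, not part of the workflow; the names `annotate_log`, `SED_PATTERN` and `GREP_PATTERN` are hypothetical. It mirrors the two patterns used above: sed prefixes matching lines with `::error::` (with `failed` anchored at end of line), while grep fails the step on `failed` anywhere in a line.

```python
import re

# Pattern used by the sed call: "failed" must be at the end of the line.
SED_PATTERN = re.compile(r"Traceback |failed$|PERFORMANCE DECREASED")
# Pattern used by the grep call: "failed" may appear anywhere in the line.
GREP_PATTERN = re.compile(r"Traceback |failed|PERFORMANCE DECREASED")


def annotate_log(lines):
    """Return (annotated_lines, should_fail) for an iterable of log lines."""
    annotated = []
    should_fail = False
    for line in lines:
        if SED_PATTERN.search(line):
            # GitHub Actions renders lines with this prefix as errors.
            line = "::error::" + line
        if GREP_PATTERN.search(line):
            should_fail = True
        annotated.append(line)
    return annotated, should_fail


if __name__ == "__main__":
    log = ["benchmark ok", "setup failed", "PERFORMANCE DECREASED."]
    out, fail = annotate_log(log)
    print("\n".join(out))
    print("exit 1" if fail else "exit 0")
```

Because the grep pattern is looser than the sed pattern, a line can fail the job without being marked as an error in the log.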


=====================================
.github/workflows/ci-additional.yaml
=====================================
@@ -0,0 +1,118 @@
+name: CI Additional
+on:
+  push:
+    branches:
+      - "*"
+  pull_request:
+    branches:
+      - "*"
+  workflow_dispatch: # allows you to trigger manually
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  detect-ci-trigger:
+    name: detect ci trigger
+    runs-on: ubuntu-latest
+    if: |
+      github.repository == 'xarray-contrib/flox'
+      && (github.event_name == 'push' || github.event_name == 'pull_request')
+    outputs:
+      triggered: ${{ steps.detect-trigger.outputs.trigger-found }}
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 2
+      - uses: xarray-contrib/ci-trigger@v1.1
+        id: detect-trigger
+        with:
+          keyword: "[skip-ci]"
+
+  doctest:
+    name: Doctests
+    runs-on: "ubuntu-latest"
+    needs: detect-ci-trigger
+    if: needs.detect-ci-trigger.outputs.triggered == 'false'
+    defaults:
+      run:
+        shell: bash -l {0}
+
+    env:
+      CONDA_ENV_FILE: ci/environment.yml
+      PYTHON_VERSION: "3.10"
+
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0 # Fetch all history for all branches and tags.
+
+      - name: set environment variables
+        run: |
+          echo "TODAY=$(date  +'%Y-%m-%d')" >> $GITHUB_ENV
+
+      - name: Setup micromamba
+        uses: mamba-org/provision-with-micromamba@34071ca7df4983ccd272ed0d3625818b27b70dcc
+        with:
+          environment-file: ${{env.CONDA_ENV_FILE}}
+          environment-name: flox-tests
+          extra-specs: |
+            python=${{env.PYTHON_VERSION}}
+          cache-env: true
+          cache-env-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}"
+
+      - name: Install flox
+        run: |
+          python -m pip install --no-deps -e .
+      - name: Version info
+        run: |
+          conda info -a
+          conda list
+      - name: Run doctests
+        run: |
+          python -m pytest --doctest-modules flox --ignore flox/tests
+
+  mypy:
+    name: Mypy
+    runs-on: "ubuntu-latest"
+    needs: detect-ci-trigger
+    if: needs.detect-ci-trigger.outputs.triggered == 'false'
+    defaults:
+      run:
+        shell: bash -l {0}
+    env:
+      CONDA_ENV_FILE: ci/environment.yml
+      PYTHON_VERSION: "3.10"
+
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0 # Fetch all history for all branches and tags.
+
+      - name: set environment variables
+        run: |
+          echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
+      - name: Setup micromamba
+        uses: mamba-org/provision-with-micromamba@34071ca7df4983ccd272ed0d3625818b27b70dcc
+        with:
+          environment-file: ${{env.CONDA_ENV_FILE}}
+          environment-name: xarray-tests
+          extra-specs: |
+            python=${{env.PYTHON_VERSION}}
+          cache-env: true
+          cache-env-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}"
+      - name: Install xarray
+        run: |
+          python -m pip install --no-deps -e .
+      - name: Version info
+        run: |
+          conda info -a
+          conda list
+      - name: Install mypy
+        run: |
+          python -m pip install mypy
+
+      - name: Run mypy
+        run: |
+          python -m mypy --install-types --non-interactive


=====================================
.github/workflows/ci.yaml
=====================================
@@ -34,7 +34,7 @@ jobs:
         run: |
           echo "PYTHON_VERSION=${{ matrix.python-version }}" >> $GITHUB_ENV
       - name: Set up conda environment
-        uses: mamba-org/provision-with-micromamba@v12
+        uses: mamba-org/provision-with-micromamba@v13
         with:
           environment-file: ci/environment.yml
           environment-name: flox-tests
@@ -49,7 +49,7 @@ jobs:
         run: |
           pytest -n auto --cov=./ --cov-report=xml
       - name: Upload code coverage to Codecov
-        uses: codecov/codecov-action@v3.1.0
+        uses: codecov/codecov-action@v3.1.1
         with:
           file: ./coverage.xml
           flags: unittests
@@ -78,7 +78,7 @@ jobs:
         with:
           fetch-depth: 0 # Fetch all history for all branches and tags.
       - name: Set up conda environment
-        uses: mamba-org/provision-with-micromamba@v12
+        uses: mamba-org/provision-with-micromamba@v13
         with:
           environment-file: ci/${{ matrix.env }}.yml
           environment-name: flox-tests
@@ -101,7 +101,7 @@ jobs:
     steps:
       - uses: actions/checkout@v3
       - name: Set up conda environment
-        uses: mamba-org/provision-with-micromamba@v12
+        uses: mamba-org/provision-with-micromamba@v13
         with:
           environment-file: ci/upstream-dev-env.yml
           environment-name: flox-tests
@@ -123,7 +123,7 @@ jobs:
           repository: 'pydata/xarray'
           fetch-depth: 0 # Fetch all history for all branches and tags.
       - name: Set up conda environment
-        uses: mamba-org/provision-with-micromamba@v12
+        uses: mamba-org/provision-with-micromamba@v13
         with:
           environment-file: ci/requirements/environment.yml
           environment-name: xarray-tests


=====================================
.pre-commit-config.yaml
=====================================
@@ -10,12 +10,12 @@ repos:
         - id: check-docstring-first
 
     - repo: https://github.com/psf/black
-      rev: 22.6.0
+      rev: 22.8.0
       hooks:
         - id: black
 
     - repo: https://github.com/PyCQA/flake8
-      rev: 4.0.1
+      rev: 5.0.4
       hooks:
         - id: flake8
 


=====================================
README.md
=====================================
@@ -1,6 +1,6 @@
-[![GitHub Workflow CI Status](https://img.shields.io/github/workflow/status/dcherian/flox/CI?logo=github&style=flat)](https://github.com/dcherian/flox/actions)
-[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/dcherian/flox/main.svg)](https://results.pre-commit.ci/latest/github/dcherian/flox/main)
-[![image](https://img.shields.io/codecov/c/github/dcherian/flox.svg?style=flat)](https://codecov.io/gh/dcherian/flox)
+[![GitHub Workflow CI Status](https://img.shields.io/github/workflow/status/xarray-contrib/flox/CI?logo=github&style=flat)](https://github.com/xarray-contrib/flox/actions)
+[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/xarray-contrib/flox/main.svg)](https://results.pre-commit.ci/latest/github/xarray-contrib/flox/main)
+[![image](https://img.shields.io/codecov/c/github/xarray-contrib/flox.svg?style=flat)](https://codecov.io/gh/xarray-contrib/flox)
 [![Documentation Status](https://readthedocs.org/projects/flox/badge/?version=latest)](https://flox.readthedocs.io/en/latest/?badge=latest)
 
 [![PyPI](https://img.shields.io/pypi/v/flox.svg?style=flat)](https://pypi.org/project/flox/)
@@ -16,7 +16,7 @@ It was motivated by
 
 1.  Dask Dataframe GroupBy
     [blogpost](https://blog.dask.org/2019/10/08/df-groupby)
-2.  numpy_groupies in Xarray
+2.  [numpy_groupies](https://github.com/ml31415/numpy-groupies) in Xarray
     [issue](https://github.com/pydata/xarray/issues/4473)
 
 (See a


=====================================
asv.conf.json → asv_bench/asv.conf.json
=====================================
@@ -11,7 +11,7 @@
 
     // The URL or local path of the source code repository for the
     // project being benchmarked
-    "repo": ".",
+    "repo": "..",
 
     // The Python project's subdirectory in your repo.  If missing or
     // the empty string, the project is assumed to be located at the root
@@ -37,7 +37,7 @@
     // determined from "repo" by looking at the protocol in the URL
     // (if remote), or by looking for special directories, such as
     // ".git" (if local).
-    // "dvcs": "git",
+    "dvcs": "git",
 
     // The tool to use to create environments.  May be "conda",
     // "virtualenv" or other value depending on the plugins in use.
@@ -48,10 +48,10 @@
 
     // timeout in seconds for installing any dependencies in environment
     // defaults to 10 min
-    //"install_timeout": 600,
+    "install_timeout": 600,
 
     // the base URL to show a commit for the project.
-    "show_commit_url": "http://github.com/dcherian/flox/commit/",
+    "show_commit_url": "http://github.com/xarray-contrib/flox/commit/",
 
     // The Pythons you'd like to test against.  If not provided, defaults
     // to the current version of Python used to run `asv`.
@@ -114,7 +114,7 @@
 
     // The directory (relative to the current directory) that benchmarks are
     // stored in.  If not provided, defaults to "benchmarks"
-    // "benchmark_dir": "benchmarks",
+    "benchmark_dir": "benchmarks",
 
     // The directory (relative to the current directory) to cache the Python
     // environments in.  If not provided, defaults to "env"


=====================================
asv_bench/benchmarks/README_CI.md
=====================================
@@ -0,0 +1,122 @@
+# Benchmark CI
+
+<!-- Author: @jaimergp -->
+<!-- Last updated: 2021.07.06 -->
+<!-- Describes the work done as part of https://github.com/scikit-image/scikit-image/pull/5424 -->
+
+## How it works
+
+The `asv` suite can be run for any PR on GitHub Actions (check workflow `.github/workflows/benchmarks.yml`) by adding a `run-benchmark` label to said PR. This will trigger a job that will run the benchmarking suite for the current PR head (merged commit) against the PR base (usually `main`).
+
+We use `asv continuous` to run the job, which runs a relative performance measurement. This means that there's no state to be saved and that regressions are only caught in terms of performance ratio (absolute numbers are available but they are not useful since we do not use stable hardware over time). `asv continuous` will:
+
+* Compile `scikit-image` for _both_ commits. We use `ccache` to speed up the process, and `mamba` is used to create the build environments.
+* Run the benchmark suite for both commits, _twice_  (since `processes=2` by default).
+* Generate a report table with performance ratios:
+    * `ratio=1.0` -> performance didn't change.
+    * `ratio<1.0` -> PR made it faster.
+    * `ratio>1.0` -> PR made it slower.
+
+Due to the sensitivity of the test, we cannot guarantee that false positives are not produced. In practice, values between `(0.7, 1.5)` are to be considered part of the measurement noise. When in doubt, running the benchmark suite one more time will provide more information about the test being a false positive or not.
+
+## Running the benchmarks on GitHub Actions
+
+1. On a PR, add the label `run-benchmark`.
+2. The CI job will be started. Checks will appear in the usual dashboard panel above the comment box.
+3. If more commits are added, the label checks will be grouped with the last commit checks _before_ you added the label.
+4. Alternatively, you can always go to the `Actions` tab in the repo and [filter for `workflow:Benchmark`](https://github.com/scikit-image/scikit-image/actions?query=workflow%3ABenchmark). Your username will be assigned to the `actor` field, so you can also filter the results with that if you need it.
+
+## The artifacts
+
+The CI job will also generate an artifact. This is the `.asv/results` directory compressed in a zip file. Its contents include:
+
+* `fv-xxxxx-xx/`. A directory for the machine that ran the suite. It contains three files:
+    * `<baseline>.json`, `<contender>.json`: the benchmark results for each commit, with stats.
+    * `machine.json`: details about the hardware.
+* `benchmarks.json`: metadata about the current benchmark suite.
+* `benchmarks.log`: the CI logs for this run.
+* This README.
+
+## Re-running the analysis
+
+Although the CI logs should be enough to get an idea of what happened (check the table at the end), one can use `asv` to run the analysis routines again.
+
+1. Uncompress the artifact contents in the repo, under `.asv/results`. That is, you should see `.asv/results/benchmarks.log`, not `.asv/results/something_else/benchmarks.log`. Write down the machine directory name for later.
+2. Run `asv show` to see your available results. You will see something like this:
+
+```
+$> asv show
+
+Commits with results:
+
+Machine    : Jaimes-MBP
+Environment: conda-py3.9-cython-numpy1.20-scipy
+
+    00875e67
+
+Machine    : fv-az95-499
+Environment: conda-py3.7-cython-numpy1.17-pooch-scipy
+
+    8db28f02
+    3a305096
+```
+
+3. We are interested in the commits for `fv-az95-499` (the CI machine for this run). We can compare them with `asv compare` and some extra options. `--sort ratio` will show largest ratios first, instead of alphabetical order. `--split` will produce three tables: improved, worsened, no changes. `--factor 1.5` tells `asv` to only complain if deviations are above a 1.5 ratio. `-m` is used to indicate the machine ID (use the one you wrote down in step 1). Finally, specify your commit hashes: baseline first, then contender!
+
+```
+$> asv compare --sort ratio --split --factor 1.5 -m fv-az95-499 8db28f02 3a305096
+
+Benchmarks that have stayed the same:
+
+       before           after         ratio
+     [8db28f02]       [3a305096]
+     <ci-benchmark-check~9^2>
+              n/a              n/a      n/a  benchmark_restoration.RollingBall.time_rollingball_ndim
+      1.23±0.04ms       1.37±0.1ms     1.12  benchmark_transform_warp.WarpSuite.time_to_float64(<class 'numpy.float64'>, 128, 3)
+       5.07±0.1μs       5.59±0.4μs     1.10  benchmark_transform_warp.ResizeLocalMeanSuite.time_resize_local_mean(<class 'numpy.float32'>, (192, 192, 192), (192, 192, 192))
+      1.23±0.02ms       1.33±0.1ms     1.08  benchmark_transform_warp.WarpSuite.time_same_type(<class 'numpy.float32'>, 128, 3)
+       9.45±0.2ms       10.1±0.5ms     1.07  benchmark_rank.Rank3DSuite.time_3d_filters('majority', (32, 32, 32))
+       23.0±0.9ms         24.6±1ms     1.07  benchmark_interpolation.InterpolationResize.time_resize((80, 80, 80), 0, 'symmetric', <class 'numpy.float64'>, True)
+         38.7±1ms         41.1±1ms     1.06  benchmark_transform_warp.ResizeLocalMeanSuite.time_resize_local_mean(<class 'numpy.float32'>, (2048, 2048), (192, 192, 192))
+       4.97±0.2μs       5.24±0.2μs     1.05  benchmark_transform_warp.ResizeLocalMeanSuite.time_resize_local_mean(<class 'numpy.float32'>, (2048, 2048), (2048, 2048))
+       4.21±0.2ms       4.42±0.3ms     1.05  benchmark_rank.Rank3DSuite.time_3d_filters('gradient', (32, 32, 32))
+
+...
+```
+
+If you want more details on a specific test, you can use `asv show`. Use `-b pattern` to filter which tests to show, and then specify a commit hash to inspect:
+
+```
+$> asv show -b time_to_float64 8db28f02
+
+Commit: 8db28f02 <ci-benchmark-check~9^2>
+
+benchmark_transform_warp.WarpSuite.time_to_float64 [fv-az95-499/conda-py3.7-cython-numpy1.17-pooch-scipy]
+  ok
+  =============== ============= ========== ============= ========== ============ ========== ============ ========== ============
+  --                                                                N / order
+  --------------- --------------------------------------------------------------------------------------------------------------
+      dtype_in       128 / 0     128 / 1      128 / 3     1024 / 0    1024 / 1    1024 / 3    4096 / 0    4096 / 1    4096 / 3
+  =============== ============= ========== ============= ========== ============ ========== ============ ========== ============
+    numpy.uint8    2.56±0.09ms   523±30μs   1.28±0.05ms   130±3ms     28.7±2ms    81.9±3ms   2.42±0.01s   659±5ms    1.48±0.01s
+    numpy.uint16   2.48±0.03ms   530±10μs   1.28±0.02ms   130±1ms    30.4±0.7ms   81.1±2ms    2.44±0s     653±3ms    1.47±0.02s
+   numpy.float32    2.59±0.1ms   518±20μs   1.27±0.01ms   127±3ms     26.6±1ms    74.8±2ms   2.50±0.01s   546±10ms   1.33±0.02s
+   numpy.float64   2.48±0.04ms   513±50μs   1.23±0.04ms   134±3ms     30.7±2ms    85.4±2ms   2.55±0.01s   632±4ms    1.45±0.01s
+  =============== ============= ========== ============= ========== ============ ========== ============ ========== ============
+  started: 2021-07-06 06:14:36, duration: 1.99m
+```
+
+## Other details
+
+### Skipping slow or demanding tests
+
+To minimize the time required to run the full suite, we trimmed the parameter matrix in some cases and, in others, directly skipped tests that ran for too long or require too much memory. Unlike `pytest`, `asv` does not have a notion of marks. However, you can `raise NotImplementedError` in the setup step to skip a test. In that vein, a new private function is defined at `benchmarks.__init__`: `_skip_slow`. This will check if the `ASV_SKIP_SLOW` environment variable has been defined. If set to `1`, it will raise `NotImplementedError` and skip the test. To implement this behavior in other tests, you can add the following attribute:
+
+```python
+from . import _skip_slow  # this function is defined in benchmarks.__init__
+
+def time_something_slow():
+    pass
+
+time_something_slow.setup = _skip_slow
+```


=====================================
asv_bench/benchmarks/__init__.py
=====================================
@@ -0,0 +1,28 @@
+import os
+
+
+def parameterized(names, params):
+    def decorator(func):
+        func.param_names = names
+        func.params = params
+        return func
+
+    return decorator
+
+
+def _skip_slow():
+    """
+    Use this function to skip slow or highly demanding tests.
+
+    Use it as a `Class.setup` method or a `function.setup` attribute.
+
+    Examples
+    --------
+    >>> from . import _skip_slow
+    >>> def time_something_slow():
+    ...     pass
+    ...
+    >>> time_something_slow.setup = _skip_slow
+    """
+    if os.environ.get("ASV_SKIP_SLOW", "0") == "1":
+        raise NotImplementedError("Skipping this test...")


=====================================
benchmarks/combine.py → asv_bench/benchmarks/combine.py
=====================================


=====================================
benchmarks/reduce.py → asv_bench/benchmarks/reduce.py
=====================================
@@ -6,7 +6,7 @@ import flox
 from . import parameterized
 
 N = 1000
-funcs = ["sum", "nansum", "mean", "nanmean", "argmax", "max"]
+funcs = ["sum", "nansum", "mean", "nanmean", "max"]
 engines = ["flox", "numpy"]
 
 


=====================================
benchmarks/__init__.py deleted
=====================================
@@ -1,7 +0,0 @@
-def parameterized(names, params):
-    def decorator(func):
-        func.param_names = names
-        func.params = params
-        return func
-
-    return decorator


=====================================
ci/docs.yml
=====================================
@@ -5,6 +5,7 @@ dependencies:
   - dask-core
   - pip
   - xarray
+  - numpy>=1.20
   - numpydoc
   - numpy_groupies
   - toolz
@@ -16,4 +17,4 @@ dependencies:
   - ipykernel
   - jupyter
   - pip:
-    - git+https://github.com/dcherian/flox
+    - git+https://github.com/xarray-contrib/flox


=====================================
ci/environment.yml
=====================================
@@ -2,11 +2,14 @@ name: flox-tests
 channels:
   - conda-forge
 dependencies:
+  - asv
   - cachey
   - codecov
   - dask-core
   - netcdf4
   - pandas
+  - numpy>=1.20
+  - matplotlib
   - pip
   - pytest
   - pytest-cov


=====================================
ci/minimal-requirements.yml
=====================================
@@ -8,7 +8,8 @@ dependencies:
   - pytest
   - pytest-cov
   - pytest-xdist
-  - numpy_groupies>=0.9.15
+  - numpy==1.20
+  - numpy_groupies==0.9.15
   - pandas
   - pooch
   - toolz


=====================================
ci/no-dask.yml
=====================================
@@ -5,6 +5,7 @@ dependencies:
   - codecov
   - netcdf4
   - pandas
+  - numpy>=1.20
   - pip
   - pytest
   - pytest-cov


=====================================
ci/no-xarray.yml
=====================================
@@ -5,6 +5,7 @@ dependencies:
   - codecov
   - netcdf4
   - pandas
+  - numpy>=1.20
   - pip
   - pytest
   - pytest-cov


=====================================
docs/source/conf.py
=====================================
@@ -43,8 +43,8 @@ extensions = [
 ]
 
 extlinks = {
-    "issue": ("https://github.com/dcherian/flox/issues/%s", "GH#"),
-    "pr": ("https://github.com/dcherian/flox/pull/%s", "GH#"),
+    "issue": ("https://github.com/xarray-contrib/flox/issues/%s", "GH#"),
+    "pr": ("https://github.com/xarray-contrib/flox/pull/%s", "GH#"),
 }
 
 templates_path = ["_templates"]


=====================================
docs/source/index.md
=====================================
@@ -2,9 +2,9 @@
 
 ## Overview
 
-[![GitHub Workflow CI Status](https://img.shields.io/github/workflow/status/dcherian/flox/CI?logo=github&style=flat)](https://github.com/dcherian/flox/actions)
-[![GitHub Workflow Code Style Status](https://img.shields.io/github/workflow/status/dcherian/flox/code-style?label=Code%20Style&style=flat)](https://github.com/dcherian/flox/actions)
-[![image](https://img.shields.io/codecov/c/github/dcherian/flox.svg?style=flat)](https://codecov.io/gh/dcherian/flox)
+[![GitHub Workflow CI Status](https://img.shields.io/github/workflow/status/xarray-contrib/flox/CI?logo=github&style=flat)](https://github.com/xarray-contrib/flox/actions)
+[![GitHub Workflow Code Style Status](https://img.shields.io/github/workflow/status/xarray-contrib/flox/code-style?label=Code%20Style&style=flat)](https://github.com/xarray-contrib/flox/actions)
+[![image](https://img.shields.io/codecov/c/github/xarray-contrib/flox.svg?style=flat)](https://codecov.io/gh/xarray-contrib/flox)
 [![PyPI](https://img.shields.io/pypi/v/flox.svg?style=flat)](https://pypi.org/project/flox/)
 [![Conda-forge](https://img.shields.io/conda/vn/conda-forge/flox.svg?style=flat)](https://anaconda.org/conda-forge/flox)
 


=====================================
docs/source/user-stories.md
=====================================
@@ -4,6 +4,7 @@
 .. toctree::
    :maxdepth: 1
 
+   user-stories/overlaps.md
    user-stories/climatology.ipynb
    user-stories/climatology-hourly.ipynb
    user-stories/custom-aggregations.ipynb


=====================================
docs/source/user-stories/overlaps.md
=====================================
@@ -0,0 +1,71 @@
+---
+jupytext:
+  text_representation:
+    format_name: myst
+kernelspec:
+  display_name: Python 3
+  name: python3
+---
+
+```{eval-rst}
+.. currentmodule:: flox
+```
+
+# Overlapping Groups
+
+This post is motivated by the problem of computing the [Meridional Overturning Circulation](https://en.wikipedia.org/wiki/Atlantic_meridional_overturning_circulation).
+One of the steps is a binned average over latitude, over regions of the World Ocean. Commonly we want to average
+globally, as well as over the Atlantic, and the Indo-Pacific. Generally group-by problems involve non-overlapping
+groups. In this example, the "global" group overlaps with the "Indo-Pacific" and "Atlantic" groups. Below we consider a simplified version of this problem.
+
+Consider the following labels:
+```{code-cell}
+import numpy as np
+import xarray as xr
+
+from flox.xarray import xarray_reduce
+
+labels = xr.DataArray(
+    [1, 2, 3, 1, 2, 3, 0, 0, 0],
+    dims="x",
+    name="label",
+)
+labels
+```
+
+These labels are non-overlapping. So when we reduce this data array over those labels along `x`
+```{code-cell}
+da = xr.ones_like(labels)
+da
+```
+we get (note the reduction over `x` is implicit here):
+
+```{code-cell}
+xarray_reduce(da, labels, func="sum")
+```
+
+Now let's _also_ calculate the `sum` where `labels` is either `1` or `2`.
+We could easily compute this using the grouped result but here we use this simple example for illustration.
+The trick is to add a new dimension with new labels (here `4`) in the appropriate locations.
+```{code-cell}
+# assign 4 where label == 1 or 2, and -1 otherwise
+newlabels = xr.where(labels.isin([1, 2]), 4, -1)
+
+# concatenate along a new dimension y;
+# note y is not present on da
+expanded = xr.concat([labels, newlabels], dim="y")
+expanded
+```
+
+Now we reduce over `x` _and_ the new dimension `y` (again implicitly) to get the appropriate sum under
+`label=4` (and `label=-1`). We can discard the value accumulated under `label=-1` later.
+```{code-cell}
+xarray_reduce(da, expanded, func="sum")
+```
+
+This way we compute all the reductions we need, in a single pass over the data.
+
+This technique generalizes to more complicated aggregations. The trick is to
+- generate appropriate labels
+- concatenate these new labels along a new dimension (`y`) absent on the object being reduced (`da`), and
+- reduce over that new dimension in addition to any others.


=====================================
flox/aggregate_flox.py
=====================================
@@ -5,6 +5,21 @@ import numpy as np
 from .xrutils import isnull
 
 
+def _prepare_for_flox(group_idx, array):
+    """
+    Sort the input array once to save time.
+    """
+    assert array.shape[-1] == group_idx.shape[0]
+    issorted = (group_idx[:-1] <= group_idx[1:]).all()
+    if issorted:
+        ordered_array = array
+    else:
+        perm = group_idx.argsort(kind="stable")
+        group_idx = group_idx[..., perm]
+        ordered_array = array[..., perm]
+    return group_idx, ordered_array
+
+
 def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dtype=None, out=None):
     """
     most of this code is from shoyer's gist
@@ -13,7 +28,7 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
     # assumes input is sorted, which I do in core._prepare_for_flox
     aux = group_idx
 
-    flag = np.concatenate(([True], aux[1:] != aux[:-1]))
+    flag = np.concatenate((np.array([True], like=array), aux[1:] != aux[:-1]))
     uniques = aux[flag]
     (inv_idx,) = flag.nonzero()
 
@@ -25,11 +40,11 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
     if out is None:
         out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
 
-    if (len(uniques) == size) and (uniques == np.arange(size)).all():
+    if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
         # The previous version of this if condition
         #     ((uniques[1:] - uniques[:-1]) == 1).all():
         # does not work when group_idx is [1, 2] for e.g.
-        # This happens  during binning
+        # This happens during binning
         op.reduceat(array, inv_idx, axis=axis, dtype=dtype, out=out)
     else:
         out[..., uniques] = op.reduceat(array, inv_idx, axis=axis, dtype=dtype)
@@ -91,8 +106,7 @@ def nanlen(group_idx, array, *args, **kwargs):
 def mean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
     if fill_value is None:
         fill_value = 0
-    out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
-    sum(group_idx, array, axis=axis, size=size, dtype=dtype, out=out)
+    out = sum(group_idx, array, axis=axis, size=size, dtype=dtype, fill_value=fill_value)
     out /= nanlen(group_idx, array, size=size, axis=axis, fill_value=0)
     return out
 
@@ -100,7 +114,6 @@ def mean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
 def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
     if fill_value is None:
         fill_value = 0
-    out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
-    nansum(group_idx, array, size=size, axis=axis, dtype=dtype, out=out)
+    out = nansum(group_idx, array, size=size, axis=axis, dtype=dtype, fill_value=fill_value)
     out /= nanlen(group_idx, array, size=size, axis=axis, fill_value=0)
     return out
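
Note on the aggregate_flox.py hunks above: `_np_grouped_op` relies on the group index being sorted, flags the positions where the label changes, and reduces each contiguous run (the real code uses `op.reduceat`). The refactored `mean` then divides the grouped sum by the grouped count. Below is a minimal pure-Python sketch of that idea, not the flox implementation. The function names are hypothetical, plain lists stand in for NumPy arrays, and returning NaN for empty groups is an assumption made for the sketch.

```python
def grouped_sum_sorted(group_idx, values, size, fill_value=0):
    """Sum `values` per group label, assuming `group_idx` is already sorted."""
    if len(group_idx) != len(values):
        raise ValueError("group_idx and values must have the same length")
    out = [fill_value] * size
    if not group_idx:
        return out
    # flag[i] is True where a new run of identical labels starts
    # (this mirrors the `flag` array built in the diff).
    flag = [True] + [a != b for a, b in zip(group_idx[1:], group_idx[:-1])]
    starts = [i for i, is_start in enumerate(flag) if is_start]
    uniques = [group_idx[i] for i in starts]
    stops = starts[1:] + [len(values)]
    # Reduce each contiguous run; labels that never occur keep fill_value.
    for label, lo, hi in zip(uniques, starts, stops):
        out[label] = sum(values[lo:hi])
    return out


def grouped_mean_sorted(group_idx, values, size):
    """Grouped mean as grouped sum divided by grouped count."""
    totals = grouped_sum_sorted(group_idx, values, size)
    counts = grouped_sum_sorted(group_idx, [1] * len(values), size)
    # Assumption for this sketch: empty groups yield NaN.
    return [t / c if c else float("nan") for t, c in zip(totals, counts)]


sums = grouped_sum_sorted([0, 0, 2, 2, 2], [1.0, 2.0, 3.0, 4.0, 5.0], size=4)
means = grouped_mean_sorted([0, 0, 2, 2, 2], [1.0, 2.0, 3.0, 4.0, 5.0], size=4)
```

Labels 1 and 3 never occur in the example, so their slots keep the fill value in `sums` and become NaN in `means`.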


=====================================
flox/aggregations.py
=====================================
@@ -46,6 +46,8 @@ def generic_aggregate(
             f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
         )
 
+    group_idx = np.asarray(group_idx, like=array)
+
     return method(
         group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs
     )
@@ -465,7 +467,7 @@ def _initialize_aggregation(
     func: str | Aggregation,
     array_dtype,
     fill_value,
-    min_count: int,
+    min_count: int | None,
     finalize_kwargs,
 ) -> Aggregation:
     if not isinstance(func, Aggregation):


=====================================
flox/cache.py
=====================================
@@ -8,4 +8,4 @@ try:
     cache = cachey.Cache(1e6)
     memoize = partial(cache.memoize, key=dask.base.tokenize)
 except ImportError:
-    memoize = lambda x: x
+    memoize = lambda x: x  # type: ignore
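
Note on the cache.py hunk above: the module falls back to an identity decorator when the optional caching dependencies cannot be imported. The added `# type: ignore` presumably only silences a type-checker complaint about the two branches binding `memoize` to different types; runtime behaviour is unchanged. Below is a minimal sketch of that optional-dependency pattern. The module name `_hypothetical_cache_backend` is made up so that the import fails and the fallback branch runs; it is not part of flox, cachey, or dask.

```python
from functools import partial

try:
    # Hypothetical backend, expected to be missing so the fallback is used.
    import _hypothetical_cache_backend as backend

    memoize = partial(backend.memoize, key=repr)
except ImportError:
    # Identity decorator: the wrapped function is returned unchanged,
    # so call sites can use @memoize whether or not caching is available.
    def memoize(func):
        return func


@memoize
def slow_square(x):
    return x * x


result = slow_square(4)
```

With the backend missing, `slow_square` is the original function and no caching takes place.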


=====================================
flox/core.py
=====================================
@@ -2,10 +2,21 @@ from __future__ import annotations
 
 import copy
 import itertools
+import math
 import operator
 from collections import namedtuple
 from functools import partial, reduce
-from typing import TYPE_CHECKING, Any, Callable, Dict, Mapping, Sequence, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Literal,
+    Mapping,
+    Sequence,
+    Union,
+)
 
 import numpy as np
 import numpy_groupies as npg
@@ -13,6 +24,7 @@ import pandas as pd
 import toolz as tlz
 
 from . import xrdtypes
+from .aggregate_flox import _prepare_for_flox
 from .aggregations import (
     Aggregation,
     _atleast_1d,
@@ -25,6 +37,18 @@ from .xrutils import is_duck_array, is_duck_dask_array, isnull
 if TYPE_CHECKING:
     import dask.array.Array as DaskArray
 
+    T_Func = Union[str, Callable]
+    T_Funcs = Union[T_Func, Sequence[T_Func]]
+    T_Axis = int
+    T_Axes = tuple[T_Axis, ...]
+    T_AxesOpt = Union[T_Axis, T_Axes, None]
+    T_Dtypes = Union[np.typing.DTypeLike, Sequence[np.typing.DTypeLike], None]
+    T_FillValues = Union[np.typing.ArrayLike, Sequence[np.typing.ArrayLike], None]
+    T_Engine = Literal["flox", "numpy", "numba"]
+    T_MethodCohorts = Literal["cohorts", "split-reduce"]
+    T_Method = Literal["map-reduce", "blockwise", T_MethodCohorts]
+    T_IsBins = Union[bool | Sequence[bool]]
+
 
 IntermediateDict = Dict[Union[str, Callable], Any]
 FinalResultsDict = Dict[str, Union["DaskArray", np.ndarray]]
@@ -44,26 +68,9 @@ def _is_arg_reduction(func: str | Aggregation) -> bool:
     return False
 
 
-def _prepare_for_flox(group_idx, array):
-    """
-    Sort the input array once to save time.
-    """
-    assert array.shape[-1] == group_idx.shape[0]
-    issorted = (group_idx[:-1] <= group_idx[1:]).all()
-    if issorted:
-        ordered_array = array
-    else:
-        perm = group_idx.argsort(kind="stable")
-        group_idx = group_idx[..., perm]
-        ordered_array = array[..., perm]
-    return group_idx, ordered_array
-
-
-def _get_expected_groups(by, sort, *, raise_if_dask=True) -> pd.Index | None:
+def _get_expected_groups(by, sort: bool) -> pd.Index:
     if is_duck_dask_array(by):
-        if raise_if_dask:
-            raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
-        return None
+        raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
     flatby = by.reshape(-1)
     expected = pd.unique(flatby[~isnull(flatby)])
     return _convert_expected_groups_to_index((expected,), isbin=(False,), sort=sort)[0]
@@ -78,11 +85,11 @@ def _get_chunk_reduction(reduction_type: str) -> Callable:
         raise ValueError(f"Unknown reduction type: {reduction_type}")
 
 
-def is_nanlen(reduction: str | Callable) -> bool:
+def is_nanlen(reduction: T_Func) -> bool:
     return isinstance(reduction, str) and reduction == "nanlen"
 
 
-def _move_reduce_dims_to_end(arr: np.ndarray, axis: Sequence) -> np.ndarray:
+def _move_reduce_dims_to_end(arr: np.ndarray, axis: T_Axes) -> np.ndarray:
     """Transpose `arr` by moving `axis` to the end."""
     axis = tuple(axis)
     order = tuple(ax for ax in np.arange(arr.ndim) if ax not in axis) + axis
@@ -92,7 +99,7 @@ def _move_reduce_dims_to_end(arr: np.ndarray, axis: Sequence) -> np.ndarray:
 
 def _collapse_axis(arr: np.ndarray, naxis: int) -> np.ndarray:
     """Reshape so that the last `naxis` axes are collapsed to one axis."""
-    newshape = arr.shape[:-naxis] + (np.prod(arr.shape[-naxis:]),)
+    newshape = arr.shape[:-naxis] + (math.prod(arr.shape[-naxis:]),)
     return arr.reshape(newshape)
 
 
@@ -131,7 +138,7 @@ def _get_optimal_chunks_for_groups(chunks, labels):
 
 
 @memoize
-def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
+def find_group_cohorts(labels, chunks, merge=True, method: T_MethodCohorts = "cohorts"):
     """
     Finds groups labels that occur together aka "cohorts"
 
@@ -172,7 +179,7 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
 
     #  Iterate over each block and create a new block of same shape with "chunk number"
     shape = tuple(array.blocks.shape[ax] for ax in axis)
-    blocks = np.empty(np.prod(shape), dtype=object)
+    blocks = np.empty(math.prod(shape), dtype=object)
     for idx, block in enumerate(array.blocks.ravel()):
         blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
     which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)
@@ -218,7 +225,13 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
 
 
 def rechunk_for_cohorts(
-    array, axis, labels, force_new_chunk_at, chunksize=None, ignore_old_chunks=False, debug=False
+    array,
+    axis: T_Axis,
+    labels,
+    force_new_chunk_at,
+    chunksize=None,
+    ignore_old_chunks=False,
+    debug=False,
 ):
     """
     Rechunks array so that each new chunk contains groups that always occur together.
@@ -303,7 +316,7 @@ def rechunk_for_cohorts(
         return array.rechunk({axis: newchunks})
 
 
-def rechunk_for_blockwise(array, axis, labels):
+def rechunk_for_blockwise(array, axis: T_Axis, labels):
     """
     Rechunks array so that group boundaries line up with chunk boundaries, allowing
     embarrassingly parallel group reductions.
@@ -336,7 +349,7 @@ def rechunk_for_blockwise(array, axis, labels):
 
 
 def reindex_(
-    array: np.ndarray, from_, to, fill_value=None, axis: int = -1, promote: bool = False
+    array: np.ndarray, from_, to, fill_value=None, axis: T_Axis = -1, promote: bool = False
 ) -> np.ndarray:
 
     if not isinstance(to, pd.Index):
@@ -389,19 +402,19 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
     """
     assert labels.ndim > 1
     offset: np.ndarray = (
-        labels + np.arange(np.prod(labels.shape[:-1])).reshape((*labels.shape[:-1], -1)) * ngroups
+        labels + np.arange(math.prod(labels.shape[:-1])).reshape((*labels.shape[:-1], -1)) * ngroups
     )
     # -1 indicates NaNs. preserve these otherwise we aggregate in the wrong groups!
     offset[labels == -1] = -1
-    size: int = np.prod(labels.shape[:-1]) * ngroups  # type: ignore
+    size: int = math.prod(labels.shape[:-1]) * ngroups  # type: ignore
     return offset, size
 
 
 def factorize_(
     by: tuple,
-    axis,
+    axis: T_AxesOpt,
     expected_groups: tuple[pd.Index, ...] = None,
-    reindex=False,
+    reindex: bool = False,
     sort=True,
     fastpath=False,
 ):
@@ -462,7 +475,7 @@ def factorize_(
         factorized.append(idx)
 
     grp_shape = tuple(len(grp) for grp in found_groups)
-    ngroups = np.prod(grp_shape)
+    ngroups = math.prod(grp_shape)
     if len(by) > 1:
         group_idx = np.ravel_multi_index(factorized, grp_shape, mode="wrap")
         # NaNs; as well as values outside the bins are coded by -1
@@ -505,13 +518,13 @@ def factorize_(
 def chunk_argreduce(
     array_plus_idx: tuple[np.ndarray, ...],
     by: np.ndarray,
-    func: Sequence[str],
+    func: T_Funcs,
     expected_groups: pd.Index | None,
-    axis: int | Sequence[int],
-    fill_value: Mapping[str | Callable, Any],
-    dtype=None,
+    axis: T_AxesOpt,
+    fill_value: T_FillValues,
+    dtype: T_Dtypes = None,
     reindex: bool = False,
-    engine: str = "numpy",
+    engine: T_Engine = "numpy",
     sort: bool = True,
 ) -> IntermediateDict:
     """
@@ -557,15 +570,15 @@ def chunk_argreduce(
 def chunk_reduce(
     array: np.ndarray,
     by: np.ndarray,
-    func: str | Callable | Sequence[str] | Sequence[Callable],
+    func: T_Funcs,
     expected_groups: pd.Index | None,
-    axis: int | Sequence[int] = None,
-    fill_value: Mapping[str | Callable, Any] = None,
-    dtype=None,
+    axis: T_AxesOpt = None,
+    fill_value: T_FillValues = None,
+    dtype: T_Dtypes = None,
     reindex: bool = False,
-    engine: str = "numpy",
-    kwargs=None,
-    sort=True,
+    engine: T_Engine = "numpy",
+    kwargs: Sequence[dict] | None = None,
+    sort: bool = True,
 ) -> IntermediateDict:
     """
     Wrapper for numpy_groupies aggregate that supports nD ``array`` and
@@ -596,28 +609,39 @@ def chunk_reduce(
     dict
     """
 
-    if dtype is not None:
-        assert isinstance(dtype, Sequence)
-    if fill_value is not None:
-        assert isinstance(fill_value, Sequence)
-
-    if isinstance(func, str) or callable(func):
-        func = (func,)  # type: ignore
+    if not (isinstance(func, str) or callable(func)):
+        funcs = func
+    else:
+        funcs = (func,)
+    nfuncs = len(funcs)
 
-    func: Sequence[str] | Sequence[Callable]
+    if isinstance(dtype, Sequence):
+        dtypes = dtype
+    else:
+        dtypes = (dtype,) * nfuncs
+    assert len(dtypes) >= nfuncs
 
-    nax = len(axis) if isinstance(axis, Sequence) else by.ndim
-    final_array_shape = array.shape[:-nax] + (1,) * (nax - 1)
-    final_groups_shape = (1,) * (nax - 1)
+    if isinstance(fill_value, Sequence):
+        fill_values = fill_value
+    else:
+        fill_values = (fill_value,) * nfuncs
+    assert len(fill_values) >= nfuncs
 
-    if isinstance(axis, Sequence) and len(axis) == 1:
-        axis = next(iter(axis))
+    if isinstance(kwargs, Sequence):
+        kwargss = kwargs
+    else:
+        kwargss = ({},) * nfuncs
+    assert len(kwargss) >= nfuncs
 
-    if not isinstance(fill_value, Sequence):
-        fill_value = (fill_value,)
+    if isinstance(axis, Sequence):
+        nax = len(axis)
+        if nax == 1:
+            axis = axis[0]
+    else:
+        nax = by.ndim
 
-    if kwargs is None:
-        kwargs = ({},) * len(func)
+    final_array_shape = array.shape[:-nax] + (1,) * (nax - 1)
+    final_groups_shape = (1,) * (nax - 1)
 
     # when axis is a tuple
     # collapse and move reduction dimensions to the end
@@ -637,7 +661,7 @@ def chunk_reduce(
     groups = groups[0]
 
     # always reshape to 1D along group dimensions
-    newshape = array.shape[: array.ndim - by.ndim] + (np.prod(array.shape[-by.ndim :]),)
+    newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
     array = array.reshape(newshape)
 
     assert group_idx.ndim == 1
@@ -667,10 +691,8 @@ def chunk_reduce(
     # we commonly have func=(..., "nanlen", "nanlen") when
     # counts are needed for the final result as well as for masking
     # optimize that out.
-    previous_reduction = None
-    for param in (fill_value, kwargs, dtype):
-        assert len(param) >= len(func)
-    for reduction, fv, kw, dt in zip(func, fill_value, kwargs, dtype):
+    previous_reduction: T_Func = ""
+    for reduction, fv, kw, dt in zip(funcs, fill_values, kwargss, dtypes):
         if empty:
             result = np.full(shape=final_array_shape, fill_value=fv)
         else:
@@ -678,16 +700,16 @@ def chunk_reduce(
                 result = results["intermediates"][-1]
 
             # fill_value here is necessary when reducing with "offset" groups
-            kwargs = dict(size=size, dtype=dt, fill_value=fv)
-            kwargs.update(kw)
+            kw_func = dict(size=size, dtype=dt, fill_value=fv)
+            kw_func.update(kw)
 
             if callable(reduction):
                 # passing a custom reduction for npg to apply per-group is really slow!
                 # So this `reduction` has to do the groupby-aggregation
-                result = reduction(group_idx, array, **kwargs)
+                result = reduction(group_idx, array, **kw_func)
             else:
                 result = generic_aggregate(
-                    group_idx, array, axis=-1, engine=engine, func=reduction, **kwargs
+                    group_idx, array, axis=-1, engine=engine, func=reduction, **kw_func
                 ).astype(dt, copy=False)
             if np.any(props.nanmask):
                 # remove NaN group label which should be last
@@ -699,7 +721,7 @@ def chunk_reduce(
     return results
 
 
-def _squeeze_results(results: IntermediateDict, axis: Sequence) -> IntermediateDict:
+def _squeeze_results(results: IntermediateDict, axis: T_Axes) -> IntermediateDict:
     # at the end we squeeze out extra dims
     groups = results["groups"]
     newresults: IntermediateDict = {"groups": [], "intermediates": []}
@@ -722,11 +744,11 @@ def _split_groups(array, j, slicer):
 def _finalize_results(
     results: IntermediateDict,
     agg: Aggregation,
-    axis: Sequence[int],
+    axis: T_Axes,
     expected_groups: pd.Index | None,
     fill_value: Any,
     reindex: bool,
-):
+) -> FinalResultsDict:
     """Finalize results by
     1. Squeezing out dummy dimensions
     2. Calling agg.finalize with intermediate results
@@ -740,7 +762,7 @@ def _finalize_results(
         squeezed["intermediates"] = squeezed["intermediates"][:-1]
 
     # finalize step
-    finalized: dict[str, DaskArray | np.ndarray] = {}
+    finalized: FinalResultsDict = {}
     if agg.finalize is None:
         finalized[agg.name] = squeezed["intermediates"][0]
     else:
@@ -776,7 +798,7 @@ def _aggregate(
     combine: Callable,
     agg: Aggregation,
     expected_groups: pd.Index | None,
-    axis: Sequence,
+    axis: T_Axes,
     keepdims,
     fill_value: Any,
     reindex: bool,
@@ -794,7 +816,7 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
 
 
 def _simple_combine(
-    x_chunk, agg: Aggregation, axis: Sequence, keepdims: bool, is_aggregate: bool = False
+    x_chunk, agg: Aggregation, axis: T_Axes, keepdims: bool, is_aggregate: bool = False
 ) -> IntermediateDict:
     """
     'Simple' combination of blockwise results.
@@ -808,12 +830,13 @@ def _simple_combine(
     """
     from dask.array.core import deepfirst
 
-    results = {"groups": deepfirst(x_chunk)["groups"]}
+    results: IntermediateDict = {"groups": deepfirst(x_chunk)["groups"]}
     results["intermediates"] = []
+    axis_ = axis[:-1] + (DUMMY_AXIS,)
     for idx, combine in enumerate(agg.combine):
-        array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis[:-1] + (DUMMY_AXIS,))
+        array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis_)
         assert array.ndim >= 2
-        result = getattr(np, combine)(array, axis=axis[:-1] + (DUMMY_AXIS,), keepdims=True)
+        result = getattr(np, combine)(array, axis=axis_, keepdims=True)
         if is_aggregate:
             # squeeze out DUMMY_AXIS if this is the last step i.e. called from _aggregate
             result = result.squeeze(axis=DUMMY_AXIS)
@@ -821,7 +844,7 @@ def _simple_combine(
     return results
 
 
-def _conc2(x_chunk, key1, key2=slice(None), axis=None) -> np.ndarray:
+def _conc2(x_chunk, key1, key2=slice(None), axis: T_Axes = None) -> np.ndarray:
     """copied from dask.array.reductions.mean_combine"""
     from dask.array.core import _concatenate2
     from dask.utils import deepmap
@@ -855,10 +878,10 @@ def listify_groups(x):
 def _grouped_combine(
     x_chunk,
     agg: Aggregation,
-    axis: Sequence,
+    axis: T_Axes,
     keepdims: bool,
-    neg_axis: Sequence,
-    engine: str,
+    neg_axis: T_Axes,
+    engine: T_Engine,
     is_aggregate: bool = False,
     sort: bool = True,
 ) -> IntermediateDict:
@@ -902,7 +925,7 @@ def _grouped_combine(
         # This happens when we are reducing along an axis with a single chunk.
         avoid_reduction = array_idx[0].shape[axis[0]] == 1
         if avoid_reduction:
-            results = {"groups": groups, "intermediates": list(array_idx)}
+            results: IntermediateDict = {"groups": groups, "intermediates": list(array_idx)}
         else:
             results = chunk_argreduce(
                 array_idx,
@@ -1001,7 +1024,9 @@ def split_blocks(applied, split_out, expected_groups, split_name):
     return intermediate, group_chunks
 
 
-def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engine, sort, reindex):
+def _reduce_blockwise(
+    array, by, agg, *, axis: T_Axes, expected_groups, fill_value, engine: T_Engine, sort, reindex
+) -> FinalResultsDict:
     """
     Blockwise groupby reduction that produces the final result. This code path is
     also used for non-dask array aggregations.
@@ -1032,7 +1057,7 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
         engine=engine,
         sort=sort,
         reindex=reindex,
-    )  # type: ignore
+    )
 
     if _is_arg_reduction(agg):
         results["intermediates"][0] = np.unravel_index(results["intermediates"][0], array.shape)[-1]
@@ -1048,14 +1073,14 @@ def dask_groupby_agg(
     by: DaskArray | np.ndarray,
     agg: Aggregation,
     expected_groups: pd.Index | None,
-    axis: Sequence = None,
+    axis: T_Axes = (),
     split_out: int = 1,
     fill_value: Any = None,
-    method: str = "map-reduce",
+    method: T_Method = "map-reduce",
     reindex: bool = False,
-    engine: str = "numpy",
+    engine: T_Engine = "numpy",
     sort: bool = True,
-) -> tuple[DaskArray, np.ndarray | DaskArray]:
+) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
 
     import dask.array
     from dask.array.core import slices_from_chunks
@@ -1157,18 +1182,21 @@ def dask_groupby_agg(
     else:
         intermediate = applied
         if expected_groups is None:
-            expected_groups = _get_expected_groups(by_input, sort=sort, raise_if_dask=False)
+            if is_duck_dask_array(by_input):
+                expected_groups = None
+            else:
+                expected_groups = _get_expected_groups(by_input, sort=sort)
         group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
 
     if method == "map-reduce":
         # these are negative axis indices useful for concatenating the intermediates
         neg_axis = tuple(range(-len(axis), 0))
 
-        combine = (
-            _simple_combine
-            if do_simple_combine
-            else partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort)
-        )
+        combine: Callable[..., IntermediateDict]
+        if do_simple_combine:
+            combine = _simple_combine
+        else:
+            combine = partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort)
 
         # reduced is really a dict mapping reduction name to array
         # and "groups" to an array of group labels
@@ -1207,13 +1235,12 @@ def dask_groupby_agg(
             groups_in_block = tuple(
                 np.intersect1d(by_input[slc], expected_groups) for slc in slices
             )
-        ngroups_per_block = tuple(len(groups) for groups in groups_in_block)
+        ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
         output_chunks = reduced.chunks[: -(len(axis))] + (ngroups_per_block,)
     else:
         raise ValueError(f"Unknown method={method}.")
 
     # extract results from the dict
-    result: dict = {}
     layer: dict[tuple, tuple] = {}
     ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
     if is_duck_dask_array(by_input) and expected_groups is None:
@@ -1225,7 +1252,7 @@ def dask_groupby_agg(
             (reduced.name, *first_block),
             "groups",
         )
-        groups = (
+        groups: tuple[np.ndarray | DaskArray] = (
             dask.array.Array(
                 HighLevelGraph.from_collections(groups_name, layer, dependencies=[reduced]),
                 groups_name,
@@ -1236,12 +1263,14 @@ def dask_groupby_agg(
     else:
         if method == "map-reduce":
             if expected_groups is None:
-                expected_groups = _get_expected_groups(by_input, sort=sort)
-            groups = (expected_groups.to_numpy(),)
+                expected_groups_ = _get_expected_groups(by_input, sort=sort)
+            else:
+                expected_groups_ = expected_groups
+            groups = (expected_groups_.to_numpy(),)
         else:
             groups = (np.concatenate(groups_in_block),)
 
-    layer: dict[tuple, tuple] = {}  # type: ignore
+    layer2: dict[tuple, tuple] = {}
     agg_name = f"{name}-{token}"
     for ochunk in itertools.product(*ochunks):
         if method == "blockwise":
@@ -1252,24 +1281,24 @@ def dask_groupby_agg(
                 inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
         else:
             inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1)
-        layer[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)
+        layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)
 
     result = dask.array.Array(
-        HighLevelGraph.from_collections(agg_name, layer, dependencies=[reduced]),
+        HighLevelGraph.from_collections(agg_name, layer2, dependencies=[reduced]),
         agg_name,
         chunks=output_chunks,
         dtype=agg.dtype[agg.name],
     )
 
-    return (result, *groups)
-
+    return (result, groups)
 
-def _validate_reindex(reindex: bool, func, method, expected_groups) -> bool:
-    if reindex is True and _is_arg_reduction(func):
-        raise NotImplementedError
 
-    if method == "blockwise" and reindex is True:
-        raise NotImplementedError
+def _validate_reindex(reindex: bool | None, func, method: T_Method, expected_groups) -> bool | None:
+    if reindex is True:
+        if _is_arg_reduction(func):
+            raise NotImplementedError
+        if method == "blockwise":
+            raise NotImplementedError
 
     if method == "blockwise" or _is_arg_reduction(func):
         reindex = False
@@ -1280,6 +1309,8 @@ def _validate_reindex(reindex: bool, func, method, expected_groups) -> bool:
     if method in ["split-reduce", "cohorts"] and reindex is False:
         raise NotImplementedError
 
+    # TODO: Should reindex be a bool-only at this point? Would've been nice but
+    # None's are relied on after this function as well.
     return reindex
 
 
@@ -1296,9 +1327,9 @@ def _assert_by_is_aligned(shape, by):
 
 
 def _convert_expected_groups_to_index(
-    expected_groups: tuple, isbin: bool, sort: bool
-) -> pd.Index | None:
-    out = []
+    expected_groups: Iterable, isbin: Sequence[bool], sort: bool
+) -> tuple[pd.Index | None, ...]:
+    out: list[pd.Index | None] = []
     for ex, isbin_ in zip(expected_groups, isbin):
         if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin):
             if sort:
@@ -1361,13 +1392,13 @@ def groupby_reduce(
     func: str | Aggregation,
     expected_groups: Sequence | np.ndarray | None = None,
     sort: bool = True,
-    isbin: bool = False,
-    axis=None,
+    isbin: T_IsBins = False,
+    axis: T_AxesOpt = None,
     fill_value=None,
     min_count: int | None = None,
     split_out: int = 1,
-    method: str = "map-reduce",
-    engine: str = "flox",
+    method: T_Method = "map-reduce",
+    engine: T_Engine = "numpy",
     reindex: bool | None = None,
     finalize_kwargs: Mapping | None = None,
 ) -> tuple[DaskArray, np.ndarray | DaskArray]:
@@ -1434,13 +1465,14 @@ def groupby_reduce(
             and is identical to xarray's default strategy.
     engine : {"flox", "numpy", "numba"}, optional
         Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
+          * ``"numpy"``:
+            Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``.
+            This is the default choice because it works for most array types.
           * ``"flox"``:
             Use an internal implementation where the data is sorted so that
             all members of a group occur sequentially, and then numpy.ufunc.reduceat
             is used for the reduction. This will fall back to ``numpy_groupies.aggregate_numpy``
             for a reduction that is not yet implemented.
-          * ``"numpy"``:
-            Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``.
           * ``"numba"``:
             Use the implementations in ``numpy_groupies.aggregate_numba``.
     reindex : bool, optional
@@ -1471,9 +1503,9 @@ def groupby_reduce(
         )
     reindex = _validate_reindex(reindex, func, method, expected_groups)
 
-    by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
-    nby = len(by)
-    by_is_dask = any(is_duck_dask_array(b) for b in by)
+    bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
+    nby = len(bys)
+    by_is_dask = any(is_duck_dask_array(b) for b in bys)
 
     if method in ["split-reduce", "cohorts"] and by_is_dask:
         raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
@@ -1482,54 +1514,58 @@ def groupby_reduce(
         array = np.asarray(array)
     array = array.astype(int) if np.issubdtype(array.dtype, bool) else array
 
-    if isinstance(isbin, bool):
-        isbin = (isbin,) * len(by)
+    if isinstance(isbin, Sequence):
+        isbins = isbin
+    else:
+        isbins = (isbin,) * nby
     if expected_groups is None:
-        expected_groups = (None,) * len(by)
+        expected_groups = (None,) * nby
 
-    _assert_by_is_aligned(array.shape, by)
+    _assert_by_is_aligned(array.shape, bys)
 
-    if len(by) == 1 and not isinstance(expected_groups, tuple):
+    if nby == 1 and not isinstance(expected_groups, tuple):
         expected_groups = (np.asarray(expected_groups),)
-    elif len(expected_groups) != len(by):
+    elif len(expected_groups) != nby:
         raise ValueError(
             f"Must have same number of `expected_groups` (received {len(expected_groups)}) "
-            f" and variables to group by (received {len(by)})."
+            f" and variables to group by (received {nby})."
         )
 
     # We convert to pd.Index since that lets us know if we are binning or not
     # (pd.IntervalIndex or not)
-    expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort)
+    expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)
 
     # TODO: could restrict this to dask-only
     factorize_early = (nby > 1) or (
-        any(isbin) and method in ["split-reduce", "cohorts"] and is_duck_dask_array(array)
+        any(isbins) and method in ["split-reduce", "cohorts"] and is_duck_dask_array(array)
     )
     if factorize_early:
-        by, final_groups, grp_shape = _factorize_multiple(
-            by, expected_groups, by_is_dask=by_is_dask, reindex=reindex
+        bys, final_groups, grp_shape = _factorize_multiple(
+            bys, expected_groups, by_is_dask=by_is_dask, reindex=reindex
         )
-        expected_groups = (pd.RangeIndex(np.prod(grp_shape)),)
+        expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)
 
-    assert len(by) == 1
-    by = by[0]
+    assert len(bys) == 1
+    by_ = bys[0]
     expected_groups = expected_groups[0]
 
     if axis is None:
-        axis = tuple(array.ndim + np.arange(-by.ndim, 0))
+        axis_ = tuple(array.ndim + np.arange(-by_.ndim, 0))
     else:
-        axis = np.core.numeric.normalize_axis_tuple(axis, array.ndim)  # type: ignore
+        # TODO: How come this function doesn't exist according to mypy?
+        axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim)  # type: ignore
+    nax = len(axis_)
 
-    if method in ["blockwise", "cohorts", "split-reduce"] and len(axis) != by.ndim:
+    if method in ["blockwise", "cohorts", "split-reduce"] and nax != by_.ndim:
         raise NotImplementedError(
             "Must reduce along all dimensions of `by` when method != 'map-reduce'."
             f"Received method={method!r}"
         )
 
     # TODO: make sure expected_groups is unique
-    if len(axis) == 1 and by.ndim > 1 and expected_groups is None:
+    if nax == 1 and by_.ndim > 1 and expected_groups is None:
         if not by_is_dask:
-            expected_groups = _get_expected_groups(by, sort)
+            expected_groups = _get_expected_groups(by_, sort)
         else:
             # When we reduce along all axes, we are guaranteed to see all
             # groups in the final combine stage, so everything works.
@@ -1542,13 +1578,14 @@ def groupby_reduce(
                 "Please provide ``expected_groups`` when not reducing along all axes."
             )
 
-    assert len(axis) <= by.ndim
-    if len(axis) < by.ndim:
-        by = _move_reduce_dims_to_end(by, -array.ndim + np.array(axis) + by.ndim)
-        array = _move_reduce_dims_to_end(array, axis)
-        axis = tuple(array.ndim + np.arange(-len(axis), 0))
+    assert nax <= by_.ndim
+    if nax < by_.ndim:
+        by_ = _move_reduce_dims_to_end(by_, tuple(-array.ndim + ax + by_.ndim for ax in axis_))
+        array = _move_reduce_dims_to_end(array, axis_)
+        axis_ = tuple(array.ndim + np.arange(-nax, 0))
+        nax = len(axis_)
 
-    has_dask = is_duck_dask_array(array) or is_duck_dask_array(by)
+    has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
 
     # When axis is a subset of possible values; then npg will
     # apply it to groups that don't exist along a particular axis (for e.g.)
@@ -1557,7 +1594,7 @@ def groupby_reduce(
     #     The only way to do this consistently is mask out using min_count
     #     Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
     if min_count is None:
-        if len(axis) < by.ndim or fill_value is not None:
+        if nax < by_.ndim or fill_value is not None:
             min_count = 1
 
     # TODO: set in xarray?
@@ -1566,20 +1603,24 @@ def groupby_reduce(
         # overwrite than when min_count is set
         fill_value = np.nan
 
-    kwargs = dict(axis=axis, fill_value=fill_value, engine=engine)
+    kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
     agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)
 
     if not has_dask:
         results = _reduce_blockwise(
-            array, by, agg, expected_groups=expected_groups, reindex=reindex, sort=sort, **kwargs
+            array, by_, agg, expected_groups=expected_groups, reindex=reindex, sort=sort, **kwargs
         )
         groups = (results["groups"],)
         result = results[agg.name]
 
     else:
+        if TYPE_CHECKING:
+            # TODO: How else to narrow that array.chunks is there?
+            assert isinstance(array, DaskArray)
+
         if agg.chunk[0] is None and method != "blockwise":
             raise NotImplementedError(
-                f"Aggregation {func.name!r} is only implemented for dask arrays when method='blockwise'."
+                f"Aggregation {agg.name!r} is only implemented for dask arrays when method='blockwise'."
                 f"\n\n Received: {func}"
             )
 
@@ -1591,25 +1632,25 @@ def groupby_reduce(
 
         if method in ["split-reduce", "cohorts"]:
             cohorts = find_group_cohorts(
-                by, [array.chunks[ax] for ax in axis], merge=True, method=method
+                by_, [array.chunks[ax] for ax in axis_], merge=True, method=method
             )
 
-            results = []
+            results_ = []
             groups_ = []
             for cohort in cohorts:
                 cohort = sorted(cohort)
                 # equivalent of xarray.DataArray.where(mask, drop=True)
-                mask = np.isin(by, cohort)
+                mask = np.isin(by_, cohort)
                 indexer = [np.unique(v) for v in np.nonzero(mask)]
                 array_subset = array
-                for ax, idxr in zip(range(-by.ndim, 0), indexer):
+                for ax, idxr in zip(range(-by_.ndim, 0), indexer):
                     array_subset = np.take(array_subset, idxr, axis=ax)
-                numblocks = np.prod([len(array_subset.chunks[ax]) for ax in axis])
+                numblocks = math.prod([len(array_subset.chunks[ax]) for ax in axis_])
 
                 # get final result for these groups
                 r, *g = partial_agg(
                     array_subset,
-                    by[np.ix_(*indexer)],
+                    by_[np.ix_(*indexer)],
                     expected_groups=pd.Index(cohort),
                     # First deep copy becasue we might be doping blockwise,
                     # which sets agg.finalize=None, then map-reduce (GH102)
@@ -1622,22 +1663,22 @@ def groupby_reduce(
                     sort=False,
                     # if only a single block along axis, we can just work blockwise
                     # inspired by https://github.com/dask/dask/issues/8361
-                    method="blockwise" if numblocks == 1 and len(axis) == by.ndim else "map-reduce",
+                    method="blockwise" if numblocks == 1 and nax == by_.ndim else "map-reduce",
                 )
-                results.append(r)
+                results_.append(r)
                 groups_.append(cohort)
 
             # concatenate results together,
             # sort to make sure we match expected output
             groups = (np.hstack(groups_),)
-            result = np.concatenate(results, axis=-1)
+            result = np.concatenate(results_, axis=-1)
         else:
-            if method == "blockwise" and by.ndim == 1:
-                array = rechunk_for_blockwise(array, axis=-1, labels=by)
+            if method == "blockwise" and by_.ndim == 1:
+                array = rechunk_for_blockwise(array, axis=-1, labels=by_)
 
-            result, *groups = partial_agg(
+            result, groups = partial_agg(
                 array,
-                by,
+                by_,
                 expected_groups=None if method == "blockwise" else expected_groups,
                 agg=agg,
                 reindex=reindex,


=====================================
flox/xarray.py
=====================================
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Hashable, Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Hashable, Iterable, Sequence, Union
 
 import numpy as np
 import pandas as pd
@@ -19,7 +19,10 @@ from .core import (
 from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric
 
 if TYPE_CHECKING:
-    from xarray import DataArray, Dataset, Resample
+    from xarray.core.resample import Resample
+    from xarray.core.types import T_DataArray, T_Dataset
+
+    Dims = Union[str, Iterable[Hashable], None]
 
 
 def _get_input_core_dims(group_names, dim, ds, grouper_dims):
@@ -51,17 +54,17 @@ def _restore_dim_order(result, obj, by):
 
 
 def xarray_reduce(
-    obj: Dataset | DataArray,
-    *by: DataArray | Iterable[str] | Iterable[DataArray],
+    obj: T_Dataset | T_DataArray,
+    *by: T_DataArray | Hashable,
     func: str | Aggregation,
     expected_groups=None,
     isbin: bool | Sequence[bool] = False,
     sort: bool = True,
-    dim: Hashable = None,
+    dim: Dims | ellipsis = None,
     split_out: int = 1,
     fill_value=None,
     method: str = "map-reduce",
-    engine: str = "flox",
+    engine: str = "numpy",
     keep_attrs: bool | None = True,
     skipna: bool | None = None,
     min_count: int | None = None,
@@ -125,13 +128,14 @@ def xarray_reduce(
             and is identical to xarray's default strategy.
     engine : {"flox", "numpy", "numba"}, optional
         Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
+          * ``"numpy"``:
+            Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``.
+            This is the default choice because it works for other array types.
           * ``"flox"``:
             Use an internal implementation where the data is sorted so that
             all members of a group occur sequentially, and then numpy.ufunc.reduceat
             is to used for the reduction. This will fall back to ``numpy_groupies.aggregate_numpy``
             for a reduction that is not yet implemented.
-          * ``"numpy"``:
-            Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``.
           * ``"numba"``:
             Use the implementations in ``numpy_groupies.aggregate_numba``.
     keep_attrs : bool, optional
@@ -171,12 +175,29 @@ def xarray_reduce(
 
     Examples
     --------
-    FIXME: Add docs.
+    >>> import xarray as xr
+    >>> from flox.xarray import xarray_reduce
+
+    >>> # Create a group index:
+    >>> labels = xr.DataArray(
+    ...     [1, 2, 3, 1, 2, 3, 0, 0, 0],
+    ...     dims="x",
+    ...     name="label",
+    ... )
+    >>> # Create a DataArray to apply the group index on:
+    >>> da = da = xr.ones_like(labels)
+    >>> # Sum all values in da that matches the elements in the group index:
+    >>> xarray_reduce(da, labels, func="sum")
+    <xarray.DataArray 'label' (label: 4)>
+    array([3, 2, 2, 2])
+    Coordinates:
+      * label    (label) int64 0 1 2 3
     """
 
     if skipna is not None and isinstance(func, Aggregation):
         raise ValueError("skipna must be None when func is an Aggregation.")
 
+    nby = len(by)
     for b in by:
         if isinstance(b, xr.DataArray) and b.name is None:
             raise ValueError("Cannot group by unnamed DataArrays.")
@@ -185,12 +206,15 @@ def xarray_reduce(
     if keep_attrs is None:
         keep_attrs = True
 
-    if isinstance(isbin, bool):
-        isbin = (isbin,) * len(by)
+    if isinstance(isbin, Sequence):
+        isbins = isbin
+    else:
+        isbins = (isbin,) * nby
+
     if expected_groups is None:
-        expected_groups = (None,) * len(by)
+        expected_groups = (None,) * nby
     if isinstance(expected_groups, (np.ndarray, list)):  # TODO: test for list
-        if len(by) == 1:
+        if nby == 1:
             expected_groups = (expected_groups,)
         else:
             raise ValueError("Needs better message.")
@@ -199,76 +223,86 @@ def xarray_reduce(
         raise NotImplementedError
 
     # eventually drop the variables we are grouping by
-    maybe_drop = [b for b in by if isinstance(b, str)]
+    maybe_drop = [b for b in by if isinstance(b, Hashable)]
     unindexed_dims = tuple(
         b
-        for b, isbin_ in zip(by, isbin)
-        if isinstance(b, str) and not isbin_ and b in obj.dims and b not in obj.indexes
+        for b, isbin_ in zip(by, isbins)
+        if isinstance(b, Hashable) and not isbin_ and b in obj.dims and b not in obj.indexes
     )
 
-    by: tuple[DataArray] = tuple(obj[g] if isinstance(g, str) else g for g in by)  # type: ignore
+    by_da = tuple(obj[g] if isinstance(g, Hashable) else g for g in by)
 
     grouper_dims = []
-    for g in by:
+    for g in by_da:
         for d in g.dims:
             if d not in grouper_dims:
                 grouper_dims.append(d)
 
-    if isinstance(obj, xr.DataArray):
-        ds = obj._to_temp_dataset()
-    else:
+    if isinstance(obj, xr.Dataset):
         ds = obj
+    else:
+        ds = obj._to_temp_dataset()
 
     ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])
 
     if dim is Ellipsis:
-        dim = tuple(obj.dims)
-        if by[0].name in ds.dims and not isbin[0]:
-            dim = tuple(d for d in dim if d != by[0].name)
+        if nby > 1:
+            raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.")
+        name_ = by_da[0].name
+        if name_ in ds.dims and not isbins[0]:
+            dim_tuple = tuple(d for d in obj.dims if d != name_)
+        else:
+            dim_tuple = tuple(obj.dims)
     elif dim is not None:
-        dim = _atleast_1d(dim)
+        dim_tuple = _atleast_1d(dim)
     else:
-        dim = tuple()
+        dim_tuple = tuple()
 
     # broadcast all variables against each other along all dimensions in `by` variables
     # don't exclude `dim` because it need not be a dimension in any of the `by` variables!
     # in the case where dim is Ellipsis, and by.ndim < obj.ndim
     # then we also broadcast `by` to all `obj.dims`
     # TODO: avoid this broadcasting
-    exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim)
-    ds, *by = xr.broadcast(ds, *by, exclude=exclude_dims)
+    exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)
+    ds_broad, *by_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)
 
-    if not dim:
-        dim = tuple(by[0].dims)
+    # all members of by_broad have the same dimensions
+    # so we just pull by_broad[0].dims if dim is None
+    if not dim_tuple:
+        dim_tuple = tuple(by_broad[0].dims)
 
-    if any(d not in grouper_dims and d not in obj.dims for d in dim):
+    if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
         raise ValueError(f"Cannot reduce over absent dimensions {dim}.")
 
-    dims_not_in_groupers = tuple(d for d in dim if d not in grouper_dims)
-    if dims_not_in_groupers == tuple(dim) and not any(isbin):
+    dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims)
+    if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins):
         # reducing along a dimension along which groups do not vary
         # This is really just a normal reduction.
         # This is not right when binning so we exclude.
-        if skipna and isinstance(func, str):
-            dsfunc = func[3:]
+        if isinstance(func, str):
+            dsfunc = func[3:] if skipna else func
         else:
-            dsfunc = func
+            raise NotImplementedError(
+                "func must be a string when reducing along a dimension not present in `by`"
+            )
         # TODO: skipna needs test
-        result = getattr(ds, dsfunc)(dim=dim, skipna=skipna)
+        result = getattr(ds_broad, dsfunc)(dim=dim_tuple, skipna=skipna)
         if isinstance(obj, xr.DataArray):
             return obj._from_temp_dataset(result)
         else:
             return result
 
-    axis = tuple(range(-len(dim), 0))
-    group_names = tuple(g.name if not binned else f"{g.name}_bins" for g, binned in zip(by, isbin))
-
-    group_shape = [None] * len(by)
-    expected_groups = list(expected_groups)
+    axis = tuple(range(-len(dim_tuple), 0))
 
     # Set expected_groups and convert to index since we need coords, sizes
     # for output xarray objects
-    for idx, (b, expect, isbin_) in enumerate(zip(by, expected_groups, isbin)):
+    expected_groups = list(expected_groups)
+    group_names: tuple[Any, ...] = ()
+    group_sizes: dict[Any, int] = {}
+    for idx, (b_, expect, isbin_) in enumerate(zip(by_broad, expected_groups, isbins)):
+        group_name = b_.name if not isbin_ else f"{b_.name}_bins"
+        group_names += (group_name,)
+
         if isbin_ and isinstance(expect, int):
             raise NotImplementedError(
                 "flox does not support binning into an integer number of bins yet."
@@ -277,13 +311,21 @@ def xarray_reduce(
             if isbin_:
                 raise ValueError(
                     f"Please provided bin edges for group variable {idx} "
-                    f"named {group_names[idx]} in expected_groups."
+                    f"named {group_name} in expected_groups."
                 )
-            expected_groups[idx] = _get_expected_groups(b.data, sort=sort, raise_if_dask=True)
-
-    expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort=sort)
-    group_shape = tuple(len(e) for e in expected_groups)
-    group_sizes = dict(zip(group_names, group_shape))
+            expect_ = _get_expected_groups(b_.data, sort=sort)
+        else:
+            expect_ = expect
+        expect_index = _convert_expected_groups_to_index((expect_,), (isbin_,), sort=sort)[0]
+
+        # The if-check is for type hinting mainly, it narrows down the return
+        # type of _convert_expected_groups_to_index to pure pd.Index:
+        if expect_index is not None:
+            expected_groups[idx] = expect_index
+            group_sizes[group_name] = len(expect_index)
+        else:
+            # This will never be reached
+            raise ValueError("expect_index cannot be None")
 
     def wrapper(array, *by, func, skipna, **kwargs):
         # Handle skipna here because I need to know dtype to make a good default choice.
@@ -295,7 +337,10 @@ def xarray_reduce(
             if "nan" not in func and func not in ["all", "any", "count"]:
                 func = f"nan{func}"
 
-        requires_numeric = func not in ["count", "any", "all"]
+        # Flox's count works with non-numeric and its faster than converting.
+        requires_numeric = func not in ["count", "any", "all"] or (
+            func == "count" and engine != "flox"
+        )
         if requires_numeric:
             is_npdatetime = array.dtype.kind in "Mm"
             is_cftime = _contains_cftime_datetimes(array)
@@ -310,7 +355,8 @@ def xarray_reduce(
 
         result, *groups = groupby_reduce(array, *by, func=func, **kwargs)
 
-        if requires_numeric:
+        # Output of count has an int dtype.
+        if requires_numeric and func != "count":
             if is_npdatetime:
                 return result.astype(dtype) + offset
             elif is_cftime:
@@ -325,20 +371,20 @@ def xarray_reduce(
     if isinstance(obj, xr.Dataset):
         # broadcasting means the group dim gets added to ds, so we check the original obj
         for k, v in obj.data_vars.items():
-            is_missing_dim = not (any(d in v.dims for d in dim))
+            is_missing_dim = not (any(d in v.dims for d in dim_tuple))
             if is_missing_dim:
                 missing_dim[k] = v
 
-    input_core_dims = _get_input_core_dims(group_names, dim, ds, grouper_dims)
-    input_core_dims += [input_core_dims[-1]] * (len(by) - 1)
+    input_core_dims = _get_input_core_dims(group_names, dim_tuple, ds_broad, grouper_dims)
+    input_core_dims += [input_core_dims[-1]] * (nby - 1)
 
     actual = xr.apply_ufunc(
         wrapper,
-        ds.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
-        *by,
+        ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
+        *by_broad,
         input_core_dims=input_core_dims,
         # for xarray's test_groupby_duplicate_coordinate_labels
-        exclude_dims=set(dim),
+        exclude_dims=set(dim_tuple),
         output_core_dims=[group_names],
         dask="allowed",
         dask_gufunc_kwargs=dict(output_sizes=group_sizes),
@@ -355,18 +401,18 @@ def xarray_reduce(
             "engine": engine,
             "reindex": reindex,
             "expected_groups": tuple(expected_groups),
-            "isbin": isbin,
+            "isbin": isbins,
             "finalize_kwargs": finalize_kwargs,
         },
     )
 
     # restore non-dim coord variables without the core dimension
     # TODO: shouldn't apply_ufunc handle this?
-    for var in set(ds.variables) - set(ds.dims):
-        if all(d not in ds[var].dims for d in dim):
-            actual[var] = ds[var]
+    for var in set(ds_broad.variables) - set(ds_broad.dims):
+        if all(d not in ds_broad[var].dims for d in dim_tuple):
+            actual[var] = ds_broad[var]
 
-    for name, expect, by_ in zip(group_names, expected_groups, by):
+    for name, expect, by_ in zip(group_names, expected_groups, by_broad):
         # Can't remove this till xarray handles IntervalIndex
         if isinstance(expect, pd.IntervalIndex):
             expect = expect.to_numpy()
@@ -374,8 +420,8 @@ def xarray_reduce(
             actual = actual.drop_vars(name)
         # When grouping by MultiIndex, expect is an pd.Index wrapping
         # an object array of tuples
-        if name in ds.indexes and isinstance(ds.indexes[name], pd.MultiIndex):
-            levelnames = ds.indexes[name].names
+        if name in ds_broad.indexes and isinstance(ds_broad.indexes[name], pd.MultiIndex):
+            levelnames = ds_broad.indexes[name].names
             expect = pd.MultiIndex.from_tuples(expect.values, names=levelnames)
             actual[name] = expect
             if Version(xr.__version__) > Version("2022.03.0"):
@@ -388,20 +434,19 @@ def xarray_reduce(
     if unindexed_dims:
         actual = actual.drop_vars(unindexed_dims)
 
-    if len(by) == 1:
+    if nby == 1:
         for var in actual:
-            if isinstance(obj, xr.DataArray):
-                template = obj
-            else:
+            if isinstance(obj, xr.Dataset):
                 template = obj[var]
+            else:
+                template = obj
+
             if actual[var].ndim > 1:
-                actual[var] = _restore_dim_order(actual[var], template, by[0])
+                actual[var] = _restore_dim_order(actual[var], template, by_broad[0])
 
     if missing_dim:
         for k, v in missing_dim.items():
-            missing_group_dims = {
-                dim: size for dim, size in group_sizes.items() if dim not in v.dims
-            }
+            missing_group_dims = {d: size for d, size in group_sizes.items() if d not in v.dims}
             # The expand_dims is for backward compat with xarray's questionable behaviour
             if missing_group_dims:
                 actual[k] = v.expand_dims(missing_group_dims).variable
@@ -415,9 +460,9 @@ def xarray_reduce(
 
 
 def rechunk_for_cohorts(
-    obj: DataArray | Dataset,
+    obj: T_DataArray | T_Dataset,
     dim: str,
-    labels: DataArray,
+    labels: T_DataArray,
     force_new_chunk_at,
     chunksize: int | None = None,
     ignore_old_chunks: bool = False,
@@ -462,7 +507,7 @@ def rechunk_for_cohorts(
     )
 
 
-def rechunk_for_blockwise(obj: DataArray | Dataset, dim: str, labels: DataArray):
+def rechunk_for_blockwise(obj: T_DataArray | T_Dataset, dim: str, labels: T_DataArray):
     """
     Rechunks array so that group boundaries line up with chunk boundaries, allowing
     embarassingly parallel group reductions.


=====================================
flox/xrutils.py
=====================================
@@ -19,7 +19,7 @@ try:
 
     dask_array_type = dask.array.Array
 except ImportError:
-    dask_array_type = ()
+    dask_array_type = ()  # type: ignore
 
 
 def asarray(data, xp=np):
@@ -98,7 +98,8 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool:
 
 
 def isnull(data):
-    data = np.asarray(data)
+    if not is_duck_array(data):
+        data = np.asarray(data)
     scalar_type = data.dtype.type
     if issubclass(scalar_type, (np.datetime64, np.timedelta64)):
         # datetime types use NaT for null


=====================================
pyproject.toml
=====================================
@@ -39,8 +39,12 @@ show_error_codes = true
 
 [[tool.mypy.overrides]]
 module=[
+    "cachey",
+    "cftime",
     "dask.*",
+    "importlib_metadata",
     "numpy_groupies",
+    "matplotlib.*",
     "pandas",
     "setuptools",
     "toolz"


=====================================
setup.cfg
=====================================
@@ -7,7 +7,7 @@ description = GroupBy operations for dask.array
 long_description = file: README.md
 long_description_content_type=text/markdown
 
-url = https://github.com/dcherian/flox
+url = https://github.com/xarray-contrib/flox
 classifiers =
     Development Status :: 4 - Beta
     License :: OSI Approved :: Apache Software License
@@ -27,6 +27,7 @@ include_package_data = True
 python_requires = >=3.8
 install_requires =
     pandas
+    numpy >= '1.20'
     numpy_groupies >= '0.9.15'
     toolz
 
@@ -56,3 +57,5 @@ per-file-ignores =
 exclude=
     .eggs
     doc
+builtins =
+    ellipsis


=====================================
tests/test_core.py
=====================================
@@ -504,7 +504,7 @@ def test_groupby_reduce_nans(chunks, axis, groups, expected_shape, engine):
         fill_value=0,
         engine=engine,
     )
-    assert_equal(result, np.zeros(expected_shape, dtype=np.int64))
+    assert_equal(result, np.zeros(expected_shape, dtype=np.intp))
 
     # now when subsets are NaN
     # labels = np.array([0, 0, 1, 1, 1], dtype=float)


=====================================
tests/test_xarray.py
=====================================
@@ -159,6 +159,9 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine):
         actual = xarray_reduce(da, "labels", "labels2", **kwargs)
     xr.testing.assert_identical(expected, actual)
 
+    with pytest.raises(NotImplementedError):
+        xarray_reduce(da, "labels", "labels2", dim=..., **kwargs)
+
 
 @requires_dask
 def test_dask_groupers_error():
@@ -420,7 +423,7 @@ def test_cache():
 
 @pytest.mark.parametrize("use_cftime", [True, False])
 @pytest.mark.parametrize("func", ["count", "mean"])
-def test_datetime_array_reduce(use_cftime, func):
+def test_datetime_array_reduce(use_cftime, func, engine):
 
     time = xr.DataArray(
         xr.date_range("2009-01-01", "2012-12-31", use_cftime=use_cftime),
@@ -428,7 +431,7 @@ def test_datetime_array_reduce(use_cftime, func):
         name="time",
     )
     expected = getattr(time.resample(time="YS"), func)()
-    actual = resample_reduce(time.resample(time="YS"), func=func, engine="flox")
+    actual = resample_reduce(time.resample(time="YS"), func=func, engine=engine)
     assert_equal(expected, actual)
 
 
@@ -457,7 +460,7 @@ def test_groupby_bins_indexed_coordinate():
 def test_mixed_grouping(chunk):
     if not has_dask and chunk:
         pytest.skip()
-    # regression test for https://github.com/dcherian/flox/pull/111
+    # regression test for https://github.com/xarray-contrib/flox/pull/111
     sa = 10
     sb = 13
     sc = 3



View it on GitLab: https://salsa.debian.org/debian-gis-team/flox/-/commit/93acea4e2ca30b5cf56ecb10f4a3d8ada5ebb5ec


