[Git][debian-gis-team/flox][upstream] New upstream version 0.9.7

Antonio Valentino (@antonio.valentino) gitlab at salsa.debian.org
Fri May 10 10:14:56 BST 2024



Antonio Valentino pushed to branch upstream at Debian GIS Project / flox


Commits:
6f92dacc by Antonio Valentino at 2024-05-10T08:58:30+00:00
New upstream version 0.9.7
- - - - -


19 changed files:

- .github/workflows/ci-additional.yaml
- .github/workflows/ci.yaml
- .pre-commit-config.yaml
- asv_bench/asv.conf.json
- asv_bench/benchmarks/cohorts.py
- ci/docs.yml
- ci/environment.yml
- docs/source/user-stories.md
- + docs/source/user-stories/climatology-hourly-cubed.ipynb
- docs/source/user-stories/climatology-hourly.ipynb
- docs/source/user-stories/climatology.ipynb
- docs/source/user-stories/custom-aggregations.ipynb
- docs/source/user-stories/nD-bins.ipynb
- flox/aggregations.py
- flox/core.py
- flox/xrutils.py
- pyproject.toml
- tests/__init__.py
- tests/test_core.py


Changes:

=====================================
.github/workflows/ci-additional.yaml
=====================================
@@ -77,7 +77,7 @@ jobs:
           --ignore flox/tests \
           --cov=./ --cov-report=xml
       - name: Upload code coverage to Codecov
-        uses: codecov/codecov-action at v4.1.0
+        uses: codecov/codecov-action at v4.3.1
         with:
           file: ./coverage.xml
           flags: unittests
@@ -131,7 +131,7 @@ jobs:
           python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report
 
       - name: Upload mypy coverage to Codecov
-        uses: codecov/codecov-action at v4.1.0
+        uses: codecov/codecov-action at v4.3.1
         with:
           file: mypy_report/cobertura.xml
           flags: mypy


=====================================
.github/workflows/ci.yaml
=====================================
@@ -24,8 +24,11 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: ["ubuntu-latest", "windows-latest"]
+        os: ["ubuntu-latest"]
         python-version: ["3.9", "3.12"]
+        include:
+          - os: "windows-latest"
+            python-version: "3.12"
     steps:
       - uses: actions/checkout at v4
         with:
@@ -49,7 +52,7 @@ jobs:
         run: |
           pytest -n auto --cov=./ --cov-report=xml
       - name: Upload code coverage to Codecov
-        uses: codecov/codecov-action at v4.1.0
+        uses: codecov/codecov-action at v4.3.1
         with:
           file: ./coverage.xml
           flags: unittests
@@ -67,10 +70,8 @@ jobs:
       fail-fast: false
       matrix:
         python-version: ["3.12"]
-        env: ["no-xarray", "no-dask"]
+        env: ["no-dask"] # "no-xarray", "no-numba"
         include:
-          - env: "no-numba"
-            python-version: "3.12"
           - env: "minimal-requirements"
             python-version: "3.9"
     steps:
@@ -93,7 +94,7 @@ jobs:
         run: |
           python -m pytest -n auto --cov=./ --cov-report=xml
       - name: Upload code coverage to Codecov
-        uses: codecov/codecov-action at v4.1.0
+        uses: codecov/codecov-action at v4.3.1
         with:
           file: ./coverage.xml
           flags: unittests


=====================================
.pre-commit-config.yaml
=====================================
@@ -4,7 +4,7 @@ ci:
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: "v0.1.9"
+    rev: "v0.3.5"
     hooks:
       - id: ruff
         args: ["--fix", "--show-fixes"]
@@ -23,7 +23,7 @@ repos:
       - id: check-docstring-first
 
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.12.1
+    rev: 24.3.0
     hooks:
       - id: black
 
@@ -36,14 +36,14 @@ repos:
           - mdformat-myst
 
   - repo: https://github.com/nbQA-dev/nbQA
-    rev: 1.7.1
+    rev: 1.8.5
     hooks:
       - id: nbqa-black
       - id: nbqa-ruff
         args: [--fix]
 
   - repo: https://github.com/kynan/nbstripout
-    rev: 0.6.1
+    rev: 0.7.1
     hooks:
       - id: nbstripout
         args: [--extra-keys=metadata.kernelspec metadata.language_info.version]
@@ -56,12 +56,12 @@ repos:
           - tomli
 
   - repo: https://github.com/abravalheri/validate-pyproject
-    rev: v0.15
+    rev: v0.16
     hooks:
       - id: validate-pyproject
 
   - repo: https://github.com/rhysd/actionlint
-    rev: v1.6.26
+    rev: v1.6.27
     hooks:
       - id: actionlint
         files: ".github/workflows/"


=====================================
asv_bench/asv.conf.json
=====================================
@@ -27,6 +27,11 @@
   //     "python setup.py build",
   //     "PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}"
   // ],
+  //
+  "build_command": [
+    "python setup.py build",
+    "python -mpip wheel --no-deps --no-build-isolation --no-index -w {build_cache_dir} {build_dir}"
+  ],
 
   // List of branches to benchmark. If not provided, defaults to "master"
   // (for git) or "default" (for mercurial).


=====================================
asv_bench/benchmarks/cohorts.py
=====================================
@@ -1,3 +1,5 @@
+from functools import cached_property
+
 import dask
 import numpy as np
 import pandas as pd
@@ -11,6 +13,10 @@ class Cohorts:
     def setup(self, *args, **kwargs):
         raise NotImplementedError
 
+    @cached_property
+    def result(self):
+        return flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)[0]
+
     def containment(self):
         asfloat = self.bitmask().astype(float)
         chunks_per_label = asfloat.sum(axis=0)
@@ -43,26 +49,17 @@ class Cohorts:
             pass
 
     def time_graph_construct(self):
-        flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis, method="cohorts")
+        flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)
 
     def track_num_tasks(self):
-        result = flox.groupby_reduce(
-            self.array, self.by, func="sum", axis=self.axis, method="cohorts"
-        )[0]
-        return len(result.dask.to_dict())
+        return len(self.result.dask.to_dict())
 
     def track_num_tasks_optimized(self):
-        result = flox.groupby_reduce(
-            self.array, self.by, func="sum", axis=self.axis, method="cohorts"
-        )[0]
-        (opt,) = dask.optimize(result)
+        (opt,) = dask.optimize(self.result)
         return len(opt.dask.to_dict())
 
     def track_num_layers(self):
-        result = flox.groupby_reduce(
-            self.array, self.by, func="sum", axis=self.axis, method="cohorts"
-        )[0]
-        return len(result.dask.layers)
+        return len(self.result.dask.layers)
 
     track_num_tasks.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_tasks_optimized.unit = "tasks"  # type: ignore[attr-defined] # Lazy
@@ -193,6 +190,19 @@ class PerfectBlockwiseResampling(Cohorts):
         self.expected = pd.RangeIndex(self.by.max() + 1)
 
 
+class SingleChunk(Cohorts):
+    """Single chunk along reduction axis: always blockwise."""
+
+    def setup(self, *args, **kwargs):
+        index = pd.date_range("1959-01-01", freq="D", end="1962-12-31")
+        self.time = pd.Series(index)
+        TIME = len(self.time)
+        self.axis = (2,)
+        self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, -1))
+        self.by = codes_for_resampling(index, freq="5D")
+        self.expected = pd.RangeIndex(self.by.max() + 1)
+
+
 class OISST(Cohorts):
     def setup(self, *args, **kwargs):
         self.array = dask.array.ones((1, 14532), chunks=(1, 10))


=====================================
ci/docs.yml
=====================================
@@ -2,6 +2,8 @@ name: flox-doc
 channels:
   - conda-forge
 dependencies:
+  - cubed>=0.14.3
+  - cubed-xarray
   - dask-core
   - pip
   - xarray


=====================================
ci/environment.yml
=====================================
@@ -6,6 +6,7 @@ dependencies:
   - cachey
   - cftime
   - codecov
+  - cubed>=0.14.2
   - dask-core
   - pandas
   - numpy>=1.22


=====================================
docs/source/user-stories.md
=====================================
@@ -7,6 +7,7 @@
    user-stories/overlaps.md
    user-stories/climatology.ipynb
    user-stories/climatology-hourly.ipynb
+   user-stories/climatology-hourly-cubed.ipynb
    user-stories/custom-aggregations.ipynb
    user-stories/nD-bins.ipynb
 ```


=====================================
docs/source/user-stories/climatology-hourly-cubed.ipynb
=====================================
@@ -0,0 +1,106 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "0",
+   "metadata": {},
+   "source": [
+    "# More climatology reductions using Cubed\n",
+    "\n",
+    "This is the Cubed equivalent of [More climatology reductions](climatology-hourly.ipynb).\n",
+    "\n",
+    "The task is to compute an hourly climatology from an hourly dataset with 744 hours in each chunk, using the \"map-reduce\" strategy."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import cubed\n",
+    "import cubed.array_api as xp\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import xarray as xr\n",
+    "\n",
+    "import flox.xarray"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2",
+   "metadata": {},
+   "source": [
+    "## Create data\n",
+    "\n",
+    "Note that we use fewer lat/long points so the computation can be run locally."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "spec = cubed.Spec(allowed_mem=\"2GB\")\n",
+    "ds = xr.Dataset(\n",
+    "    {\n",
+    "        \"tp\": (\n",
+    "            (\"time\", \"latitude\", \"longitude\"),\n",
+    "            xp.ones((8760, 72, 144), chunks=(744, 5, 144), dtype=np.float32, spec=spec),\n",
+    "        )\n",
+    "    },\n",
+    "    coords={\"time\": pd.date_range(\"2021-01-01\", \"2021-12-31 23:59\", freq=\"h\")},\n",
+    ")\n",
+    "ds"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4",
+   "metadata": {},
+   "source": [
+    "## Computation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "hourly = flox.xarray.xarray_reduce(ds.tp, ds.time.dt.hour, func=\"mean\", reindex=True)\n",
+    "hourly"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "hourly.compute()"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}


=====================================
docs/source/user-stories/climatology-hourly.ipynb
=====================================
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "84e8bbee-90cc-4e6a-bf89-c56dc19c11ca",
+   "id": "0",
    "metadata": {},
    "source": [
     "# More climatology reductions\n",
@@ -26,7 +26,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "727f490e-906a-4537-ac5e-3c67985cd6d5",
+   "id": "1",
    "metadata": {},
    "outputs": [
     {
@@ -73,7 +73,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "6085684f-cafa-450c-8448-d5c9c1cbb55f",
+   "id": "2",
    "metadata": {},
    "outputs": [
     {
@@ -98,7 +98,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "5380714a-b35f-4fb0-8b3d-7528ef7a7595",
+   "id": "3",
    "metadata": {},
    "source": [
     "## Create data\n"
@@ -107,7 +107,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "2aa66559-b2dd-4b46-b32b-f1ce2270c3de",
+   "id": "4",
    "metadata": {},
    "outputs": [
     {
@@ -636,7 +636,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "3a350782-b747-4e5e-8b8b-15fab72c0d2c",
+   "id": "5",
    "metadata": {},
    "source": [
     "Here's just plain xarray: 10000 tasks and one chunk per hour in the output\n"
@@ -645,7 +645,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ecc77698-5879-4b7c-ad97-891fb104d295",
+   "id": "6",
    "metadata": {},
    "outputs": [
     {
@@ -1173,7 +1173,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "beccd9f8-ad62-4cd8-86cc-acfe14f13023",
+   "id": "7",
    "metadata": {},
    "source": [
     "And flox: 600 tasks and all hours in a single chunk\n"
@@ -1182,7 +1182,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "0a3da8e5-863a-4602-9176-0a9adc689563",
+   "id": "8",
    "metadata": {},
    "outputs": [
     {
@@ -1676,7 +1676,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "8aa1a641-1ce1-4264-96dc-d11bb1d4ab57",
+   "id": "9",
    "metadata": {},
    "outputs": [],
    "source": []
@@ -1684,7 +1684,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "e37c5aa2-c77a-4d87-8db4-5052c675c42d",
+   "id": "10",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1694,7 +1694,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "494766c2-305a-4518-b2b7-a85bcc7fd5b2",
+   "id": "11",
    "metadata": {},
    "source": [
     "View the performance report\n",


=====================================
docs/source/user-stories/climatology.ipynb
=====================================
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "4e9bf3f9-0952-493c-a8df-4a1d851c37a9",
+   "id": "0",
    "metadata": {},
    "source": [
     "# Strategies for climatology calculations\n",
@@ -15,7 +15,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "85ac0588-ff00-43cc-b952-7ab775b24e4a",
+   "id": "1",
    "metadata": {
     "tags": []
    },
@@ -31,7 +31,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "82f46621-1b6c-4a14-ac0f-3aa5121dad54",
+   "id": "2",
    "metadata": {},
    "source": [
     "Let's first create an example Xarray Dataset representing the OISST dataset,\n",
@@ -41,7 +41,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "9a91d2e2-bd6d-4b35-8002-5fac76c4c5b3",
+   "id": "3",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -56,18 +56,20 @@
   },
   {
    "cell_type": "markdown",
-   "id": "b7f519ee-e575-492c-a70b-8dad63a8c222",
+   "id": "4",
    "metadata": {},
    "source": [
     "To account for Feb-29 being present in some years, we'll construct a time vector to group by as \"mmm-dd\" string.\n",
     "\n",
-    "For more options, see https://strftime.org/"
+    "```{seealso}\n",
+    "For more options, see [this great website](https://strftime.org/).\n",
+    "```"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "3c42a618-47bc-4c83-a902-ec4cf3420180",
+   "id": "5",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -77,10 +79,10 @@
   },
   {
    "cell_type": "markdown",
-   "id": "6d913e7f-25bd-43c4-98b6-93bcb420c524",
+   "id": "6",
    "metadata": {},
    "source": [
-    "## map-reduce\n",
+    "## First, `method=\"map-reduce\"`\n",
     "\n",
     "The default\n",
     "[method=\"map-reduce\"](https://flox.readthedocs.io/en/latest/implementation.html#method-map-reduce)\n",
@@ -93,7 +95,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ef2a14de-7526-40e3-8a97-28e84d6d6f20",
+   "id": "7",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -107,10 +109,10 @@
   },
   {
    "cell_type": "markdown",
-   "id": "442ad701-ea45-4555-9550-ec9daecfbea3",
+   "id": "8",
    "metadata": {},
    "source": [
-    "## Rechunking for map-reduce\n",
+    "### Rechunking for map-reduce\n",
     "\n",
     "We can split each chunk along the `lat`, `lon` dimensions to make sure the\n",
     "output chunk sizes are more reasonable\n"
@@ -119,7 +121,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "322c7776-9a21-4115-8ac9-9c7c6c6e2c91",
+   "id": "9",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -133,13 +135,13 @@
   },
   {
    "cell_type": "markdown",
-   "id": "833f72eb-1501-4362-ae55-ec419c9f0ac1",
+   "id": "10",
    "metadata": {},
    "source": [
     "But what if we didn't want to rechunk the dataset so drastically (note the 10x\n",
     "increase in tasks). For that let's try `method=\"cohorts\"`\n",
     "\n",
-    "## method=cohorts\n",
+    "## `method=\"cohorts\"`\n",
     "\n",
     "We can take advantage of patterns in the groups here \"day of year\".\n",
     "Specifically:\n",
@@ -156,7 +158,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "a3bafc32-7e13-41b8-90eb-b27955393392",
+   "id": "11",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -170,7 +172,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "b4e1ba0b-20e5-466a-9199-38b47029a0ed",
+   "id": "12",
    "metadata": {},
    "source": [
     "By default cohorts doesn't work so well for this problem because the period\n",
@@ -185,7 +187,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "13ce5531-0d6c-4c89-bc44-dc2c24fa4e47",
+   "id": "13",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -200,7 +202,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "068b4109-b7f4-4c16-918d-9a18ff2ed183",
+   "id": "14",
    "metadata": {},
    "source": [
     "Looking more closely, we can see many cohorts with a single entry. "
@@ -209,7 +211,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "57983cd0-a2e0-4d16-abe6-9572f6f252bf",
+   "id": "15",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -218,7 +220,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "bcbdbb3b-2aed-4f3f-ad20-efabb52b5e68",
+   "id": "16",
    "metadata": {},
    "source": [
     "## Rechunking data for cohorts\n",
@@ -231,7 +233,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "40d393a5-7a4e-4d33-997b-4c422a0b8100",
+   "id": "17",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -240,7 +242,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "cd0033a3-d211-4aef-a284-c9fd3f75f6e4",
+   "id": "18",
    "metadata": {},
    "source": [
     "We'll choose to rechunk such that a single month in is a chunk. This is not too different from the current chunking but will help your periodicity problem"
@@ -249,7 +251,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "5914a350-a7db-49b3-9504-6d63ff874f5e",
+   "id": "19",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -259,7 +261,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "90a884bc-1b71-4874-8143-73b3b5c41458",
+   "id": "20",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -268,20 +270,20 @@
   },
   {
    "cell_type": "markdown",
-   "id": "12b7a27f-ebab-4673-bb9f-80620389994b",
+   "id": "21",
    "metadata": {},
    "source": [
-    "And now our cohorts contain more than one group\n"
+    "And now our cohorts contain more than one group, *and* there is a substantial reduction in number of cohorts **162 -> 12**\n"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "f522fb82-764d-4e4e-8337-a5123e3088f8",
+   "id": "22",
    "metadata": {},
    "outputs": [],
    "source": [
-    "preferrd_method, new_cohorts = flox.core.find_group_cohorts(\n",
+    "preferred_method, new_cohorts = flox.core.find_group_cohorts(\n",
     "    labels=codes,\n",
     "    chunks=(rechunked.chunksizes[\"time\"],),\n",
     ")\n",
@@ -292,7 +294,17 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "4e2b6f70-c057-4783-ad55-21b20ff27e7f",
+   "id": "23",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "preferred_method"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "24",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -301,7 +313,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "949ac39c-dd84-4375-a884-0c1c3c382a8f",
+   "id": "25",
    "metadata": {},
    "source": [
     "Now the groupby reduction **looks OK** in terms of number of tasks but remember\n",
@@ -311,7 +323,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "8f1e45f9-5b18-482a-8c76-66f81ff5710f",
+   "id": "26",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -320,7 +332,25 @@
   },
   {
    "cell_type": "markdown",
-   "id": "93c58969-5c99-4bc0-90ee-9cef468bf78b",
+   "id": "27",
+   "metadata": {},
+   "source": [
+    "flox's heuristics will choose `\"cohorts\"` automatically!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "28",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "flox.xarray.xarray_reduce(rechunked, day, func=\"mean\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "29",
    "metadata": {},
    "source": [
     "## How about other climatologies?\n",
@@ -331,7 +361,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "e559ea33-5499-48ff-9a2e-5141c3a69fea",
+   "id": "30",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -340,7 +370,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "a00de8eb-e414-4920-8dcd-b64afbf91b62",
+   "id": "31",
    "metadata": {},
    "source": [
     "This looks great. Why?\n",


=====================================
docs/source/user-stories/custom-aggregations.ipynb
=====================================
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "190a42b0-e1df-40dd-bb68-0f8ebacdc6f3",
+   "id": "0",
    "metadata": {},
    "source": [
     "# Custom Aggregations\n",
@@ -27,7 +27,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "8c6fcc42-b081-44fa-acf7-a95ec4ed75d2",
+   "id": "1",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -68,7 +68,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "4bf0f68a-cefc-454c-80cd-e60688958a87",
+   "id": "2",
    "metadata": {},
    "source": [
     "## A built-in reduction\n",
@@ -79,7 +79,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "c0a7f29f-311c-41fd-b03b-33ba7ffccfc6",
+   "id": "3",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -96,7 +96,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "d58c2984-0589-4730-848f-bb92817a4bd1",
+   "id": "4",
    "metadata": {},
    "source": [
     "## Aggregations\n",
@@ -112,7 +112,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "574b93ef-dd73-4a98-bd53-69119d5d97c0",
+   "id": "5",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -122,7 +122,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "8750f32d-9d77-4197-88bb-b7c1388cdcfe",
+   "id": "6",
    "metadata": {},
    "source": [
     "Here's how the mean Aggregation is created\n",
@@ -152,7 +152,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "ddb4663c-16bc-4f78-899d-490d0ec01452",
+   "id": "7",
    "metadata": {},
    "source": [
     "## Defining a custom aggregation\n",
@@ -202,7 +202,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "05b8a1e5-e865-4b25-8540-df5aa6c218e9",
+   "id": "8",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -220,7 +220,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "b356f4f2-ae22-4f56-89ec-50646136e2eb",
+   "id": "9",
    "metadata": {},
    "source": [
     "Now we create the `Aggregation`\n"
@@ -229,7 +229,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "07c0fc82-c77b-4472-9de7-3c4a7cf3e07e",
+   "id": "10",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -245,7 +245,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "899ece52-ebd4-47b4-8090-cbbb63f504a4",
+   "id": "11",
    "metadata": {},
    "source": [
     "And apply it!\n"
@@ -254,7 +254,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "df85a390-99dd-432f-b248-6160935deb52",
+   "id": "12",
    "metadata": {},
    "outputs": [],
    "source": [


=====================================
docs/source/user-stories/nD-bins.ipynb
=====================================
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "e970d800-c612-482a-bb3a-b1eb7ad53d88",
+   "id": "0",
    "metadata": {
     "tags": [],
     "user_expressions": []
@@ -28,7 +28,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "01f1a2ef-de62-45d0-a04e-343cd78debc5",
+   "id": "1",
    "metadata": {
     "tags": []
    },
@@ -46,7 +46,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "0be3e214-0cf0-426f-8ebb-669cc5322310",
+   "id": "2",
    "metadata": {
     "user_expressions": []
    },
@@ -56,7 +56,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "ce239000-e053-4fc3-ad14-e9e0160da869",
+   "id": "3",
    "metadata": {
     "user_expressions": []
    },
@@ -67,7 +67,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "7659c24e-f5a1-4e59-84c0-5ec965ef92d2",
+   "id": "4",
    "metadata": {
     "tags": []
    },
@@ -83,7 +83,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "da0c0ac9-ad75-42cd-a1ea-99069f5bef00",
+   "id": "5",
    "metadata": {
     "user_expressions": []
    },
@@ -94,7 +94,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "4601e744-5d22-447e-97ce-9644198d485e",
+   "id": "6",
    "metadata": {
     "tags": []
    },
@@ -110,7 +110,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "61c21c94-7b6e-46a6-b9c2-59d7b2d40c81",
+   "id": "7",
    "metadata": {
     "tags": [],
     "user_expressions": []
@@ -122,7 +122,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "863a1991-ab8d-47c0-aa48-22b422fcea8c",
+   "id": "8",
    "metadata": {
     "tags": []
    },
@@ -139,7 +139,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "e65ecaba-d1cc-4485-ae58-c390cb2ebfab",
+   "id": "9",
    "metadata": {
     "user_expressions": []
    },
@@ -174,7 +174,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "aa33ab2c-0ecf-4198-a033-2a77f5d83c99",
+   "id": "10",
    "metadata": {
     "tags": []
    },
@@ -186,7 +186,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "afcddcc1-dd57-461e-a649-1f8bcd30342f",
+   "id": "11",
    "metadata": {
     "tags": []
    },
@@ -217,7 +217,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "1661312a-dc61-4a26-bfd8-12c2dc01eb15",
+   "id": "12",
    "metadata": {
     "user_expressions": []
    },
@@ -234,7 +234,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "0e5801cb-a79c-4670-ad10-36bb19f1a6ff",
+   "id": "13",
    "metadata": {
     "tags": []
    },
@@ -250,7 +250,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "6c06c48b-316b-4a33-9bc3-921acd10bcba",
+   "id": "14",
    "metadata": {
     "user_expressions": []
    },
@@ -263,7 +263,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "2cf1295e-4585-48b9-ac2b-9e00d03b2b9a",
+   "id": "15",
    "metadata": {
     "tags": []
    },
@@ -283,7 +283,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "3539509b-d9b4-4342-a679-6ada6f285dfb",
+   "id": "16",
    "metadata": {
     "user_expressions": []
    },
@@ -296,7 +296,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "b1389d37-d76d-4a50-9dfb-8710258de3fd",
+   "id": "17",
    "metadata": {
     "tags": []
    },
@@ -316,7 +316,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "a98b5e60-94af-45ae-be1b-4cb47e2d77ba",
+   "id": "18",
    "metadata": {
     "user_expressions": []
    },
@@ -327,7 +327,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "053a8643-f6d9-4fd1-b014-230fa716449c",
+   "id": "19",
    "metadata": {
     "tags": []
    },
@@ -338,7 +338,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "619ba4c4-7c87-459a-ab86-c187d3a86c67",
+   "id": "20",
    "metadata": {
     "tags": [],
     "user_expressions": []


=====================================
flox/aggregations.py
=====================================
@@ -623,9 +623,11 @@ def _initialize_aggregation(
         "final": final_dtype,
         "numpy": (final_dtype,),
         "intermediate": tuple(
-            _normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
-            if int_dtype is None
-            else np.dtype(int_dtype)
+            (
+                _normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
+                if int_dtype is None
+                else np.dtype(int_dtype)
+            )
             for int_dtype, int_fv in zip(
                 agg.dtype_init["intermediate"], agg.fill_value["intermediate"]
             )


=====================================
flox/core.py
=====================================
@@ -8,6 +8,7 @@ import sys
 import warnings
 from collections import namedtuple
 from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial, reduce
 from itertools import product
 from numbers import Integral
@@ -17,6 +18,7 @@ from typing import (
     Callable,
     Literal,
     TypedDict,
+    TypeVar,
     Union,
     overload,
 )
@@ -38,7 +40,9 @@ from .aggregations import (
 )
 from .cache import memoize
 from .xrutils import (
+    is_chunked_array,
     is_duck_array,
+    is_duck_cubed_array,
     is_duck_dask_array,
     isnull,
     module_available,
@@ -63,10 +67,11 @@ if TYPE_CHECKING:
     except (ModuleNotFoundError, ImportError):
         Unpack: Any  # type: ignore[no-redef]
 
+    import cubed.Array as CubedArray
     import dask.array.Array as DaskArray
     from dask.typing import Graph
 
-    T_DuckArray = Union[np.ndarray, DaskArray]  # Any ?
+    T_DuckArray = Union[np.ndarray, DaskArray, CubedArray]  # Any ?
     T_By = T_DuckArray
     T_Bys = tuple[T_By, ...]
     T_ExpectIndex = pd.Index
@@ -93,9 +98,10 @@ if TYPE_CHECKING:
     T_MethodOpt = None | Literal["map-reduce", "blockwise", "cohorts"]
     T_IsBins = Union[bool | Sequence[bool]]
 
+T = TypeVar("T")
 
 IntermediateDict = dict[Union[str, Callable], Any]
-FinalResultsDict = dict[str, Union["DaskArray", np.ndarray]]
+FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]]
 FactorProps = namedtuple("FactorProps", "offset_group nan_sentinel nanmask")
 
 # This dummy axis is inserted using np.expand_dims
@@ -137,6 +143,10 @@ def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
     return result
 
 
+def identity(x: T) -> T:
+    return x
+
+
 def _issorted(arr: np.ndarray) -> bool:
     return bool((arr[:-1] <= arr[1:]).all())
 
@@ -245,34 +255,77 @@ def slices_from_chunks(chunks):
 
 
 def _compute_label_chunk_bitmask(labels, chunks, nlabels):
+    def make_bitmask(rows, cols):
+        data = np.broadcast_to(np.array(1, dtype=np.uint8), rows.shape)
+        return csc_array((data, (rows, cols)), dtype=bool, shape=(nchunks, nlabels))
+
     assert isinstance(labels, np.ndarray)
     shape = tuple(sum(c) for c in chunks)
     nchunks = math.prod(len(c) for c in chunks)
+    approx_chunk_size = math.prod(c[0] for c in chunks)
 
-    labels = np.broadcast_to(labels, shape[-labels.ndim :])
+    # Shortcut for 1D with size-1 chunks
+    if shape == (nchunks,):
+        rows_array = np.arange(nchunks)
+        cols_array = labels
+        mask = labels >= 0
+        return make_bitmask(rows_array[mask], cols_array[mask])
 
+    labels = np.broadcast_to(labels, shape[-labels.ndim :])
     cols = []
-    # Add one to handle the -1 sentinel value
-    label_is_present = np.zeros((nlabels + 1,), dtype=bool)
     ilabels = np.arange(nlabels)
-    for region in slices_from_chunks(chunks):
+
+    def chunk_unique(labels, slicer, nlabels, label_is_present=None):
+        if label_is_present is None:
+            label_is_present = np.empty((nlabels + 1,), dtype=bool)
+        label_is_present[:] = False
+        subset = labels[slicer]
         # This is a quite fast way to find unique integers, when we know how many there are
         # inspired by a similar idea in numpy_groupies for first, last
         # instead of explicitly finding uniques, repeatedly write True to the same location
-        subset = labels[region]
-        # The reshape is not strictly necessary but is about 100ms faster on a test problem.
         label_is_present[subset.reshape(-1)] = True
         # skip the -1 sentinel by slicing
         # Faster than np.argwhere by a lot
         uniques = ilabels[label_is_present[:-1]]
-        cols.append(uniques)
-        label_is_present[:] = False
+        return uniques
+
+    # TODO: refine this heuristic.
+    # The general idea is that with the threadpool, we repeatedly allocate memory
+    # for `label_is_present`. We trade that off against the parallelism across number of chunks.
+    # For large enough number of chunks (relative to number of labels), it makes sense to
+    # suffer the extra allocation in exchange for parallelism.
+    THRESHOLD = 2
+    if nlabels < THRESHOLD * approx_chunk_size:
+        logger.debug(
+            "Using threadpool since num_labels %s < %d * chunksize %s",
+            nlabels,
+            THRESHOLD,
+            approx_chunk_size,
+        )
+        with ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(chunk_unique, labels, slicer, nlabels)
+                for slicer in slices_from_chunks(chunks)
+            ]
+            cols = tuple(f.result() for f in futures)
+
+    else:
+        logger.debug(
+            "Using serial loop since num_labels %s > %d * chunksize %s",
+            nlabels,
+            THRESHOLD,
+            approx_chunk_size,
+        )
+        cols = []
+        # Add one to handle the -1 sentinel value
+        label_is_present = np.empty((nlabels + 1,), dtype=bool)
+        for region in slices_from_chunks(chunks):
+            uniques = chunk_unique(labels, region, nlabels, label_is_present)
+            cols.append(uniques)
     rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
     cols_array = np.concatenate(cols)
-    data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
-    bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
 
-    return bitmask
+    return make_bitmask(rows_array, cols_array)
 
 
 # @memoize
@@ -309,6 +362,7 @@ def find_group_cohorts(
     labels = np.asarray(labels)
 
     shape = tuple(sum(c) for c in chunks)
+    nchunks = math.prod(len(c) for c in chunks)
 
     # assumes that `labels` are factorized
     if expected_groups is None:
@@ -316,6 +370,10 @@ def find_group_cohorts(
     else:
         nlabels = expected_groups[-1] + 1
 
+    # 1. Single chunk, blockwise always
+    if nchunks == 1:
+        return "blockwise", {(0,): list(range(nlabels))}
+
     labels = np.broadcast_to(labels, shape[-labels.ndim :])
     bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)
 
@@ -343,21 +401,21 @@ def find_group_cohorts(
 
     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
 
-    # 1. Every group is contained to one block, use blockwise here.
+    # 2. Every group is contained to one block, use blockwise here.
     if bitmask.shape[CHUNK_AXIS] == 1 or (chunks_per_label == 1).all():
         logger.info("find_group_cohorts: blockwise is preferred.")
         return "blockwise", chunks_cohorts
 
-    # 2. Perfectly chunked so there is only a single cohort
+    # 3. Perfectly chunked so there is only a single cohort
     if len(chunks_cohorts) == 1:
         logger.info("Only found a single cohort. 'map-reduce' is preferred.")
         return "map-reduce", chunks_cohorts if merge else {}
 
-    # 3. Our dataset has chunksize one along the axis,
+    # 4. Our dataset has chunksize one along the axis,
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
-    # 4. Every chunk only has a single group, but that group might extend across multiple chunks
+    # 5. Every chunk only has a single group, but that group might extend across multiple chunks
     one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
-    # 5. Existing cohorts don't overlap, great for time grouping with perfect chunking
+    # 6. Existing cohorts don't overlap, great for time grouping with perfect chunking
     no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()
     if one_group_per_chunk or single_chunks or no_overlapping_cohorts:
         logger.info("find_group_cohorts: cohorts is preferred, chunking is perfect.")
@@ -390,6 +448,7 @@ def find_group_cohorts(
             sparsity, MAX_SPARSITY_FOR_COHORTS
         )
     )
+    # 7. Groups seem fairly randomly distributed, use "map-reduce".
     if sparsity > MAX_SPARSITY_FOR_COHORTS:
         if not merge:
             logger.info(
@@ -654,8 +713,7 @@ def factorize_(
     expected_groups: T_ExpectIndexOptTuple | None = None,
     reindex: bool = False,
     sort: bool = True,
-) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, None]:
-    ...
+) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, None]: ...
 
 
 @overload
@@ -667,8 +725,7 @@ def factorize_(
     reindex: bool = False,
     sort: bool = True,
     fastpath: Literal[False] = False,
-) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, FactorProps]:
-    ...
+) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, FactorProps]: ...
 
 
 @overload
@@ -680,8 +737,7 @@ def factorize_(
     reindex: bool = False,
     sort: bool = True,
     fastpath: bool = False,
-) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, FactorProps | None]:
-    ...
+) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, FactorProps | None]: ...
 
 
 def factorize_(
@@ -1424,7 +1480,10 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
 
 
 def subset_to_blocks(
-    array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
+    array: DaskArray,
+    flatblocks: Sequence[int],
+    blkshape: tuple[int] | None = None,
+    reindexer=identity,
 ) -> DaskArray:
     """
     Advanced indexing of .blocks such that we always get a regular array back.
@@ -1450,20 +1509,21 @@ def subset_to_blocks(
     index = _normalize_indexes(array, flatblocks, blkshape)
 
     if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
-        return array
+        return dask.array.map_blocks(reindexer, array, meta=array._meta)
 
     # These rest is copied from dask.array.core.py with slight modifications
     index = normalize_index(index, array.numblocks)
     index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)
 
-    name = "blocks-" + tokenize(array, index)
+    name = "groupby-cohort-" + tokenize(array, index)
     new_keys = array._key_array[index]
 
     squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
     chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))
 
     keys = itertools.product(*(range(len(c)) for c in chunks))
-    layer: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
+    layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}
+
     graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])
 
     return dask.array.Array(graph, name, chunks, meta=array)
@@ -1637,26 +1697,26 @@ def dask_groupby_agg(
 
         elif method == "cohorts":
             assert chunks_cohorts
+            block_shape = array.blocks.shape[-len(axis) :]
+
             reduced_ = []
             groups_ = []
             for blks, cohort in chunks_cohorts.items():
-                index = pd.Index(cohort)
-                subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
-                reindexed = dask.array.map_blocks(
-                    reindex_intermediates, subset, agg, index, meta=subset._meta
-                )
+                cohort_index = pd.Index(cohort)
+                reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
+                reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
                 # now that we have reindexed, we can set reindex=True explicitlly
                 reduced_.append(
                     tree_reduce(
                         reindexed,
                         combine=partial(combine, agg=agg, reindex=True),
-                        aggregate=partial(aggregate, expected_groups=index, reindex=True),
+                        aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True),
                     )
                 )
                 # This is done because pandas promotes to 64-bit types when an Index is created
                 # So we use the index to generate the return value for consistency with "map-reduce"
                 # This is important on windows
-                groups_.append(index.values)
+                groups_.append(cohort_index.values)
 
             reduced = dask.array.concatenate(reduced_, axis=-1)
             groups = (np.concatenate(groups_),)
@@ -1718,6 +1778,109 @@ def dask_groupby_agg(
     return (result, groups)
 
 
+def cubed_groupby_agg(
+    array: CubedArray,
+    by: T_By,
+    agg: Aggregation,
+    expected_groups: pd.Index | None,
+    axis: T_Axes = (),
+    fill_value: Any = None,
+    method: T_Method = "map-reduce",
+    reindex: bool = False,
+    engine: T_Engine = "numpy",
+    sort: bool = True,
+    chunks_cohorts=None,
+) -> tuple[CubedArray, tuple[np.ndarray | CubedArray]]:
+    import cubed
+    import cubed.core.groupby
+
+    # I think _tree_reduce expects this
+    assert isinstance(axis, Sequence)
+    assert all(ax >= 0 for ax in axis)
+
+    inds = tuple(range(array.ndim))
+
+    by_input = by
+
+    # Unifying chunks is necessary for argreductions.
+    # We need to rechunk before zipping up with the index
+    # let's always do it anyway
+    if not is_chunked_array(by):
+        # chunk numpy arrays like the input array
+        chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0))
+
+        by = cubed.from_array(by, chunks=chunks, spec=array.spec)
+    _, (array, by) = cubed.core.unify_chunks(array, inds, by, inds[-by.ndim :])
+
+    # Cubed's groupby_reduction handles the generation of "intermediates", and the
+    # "map-reduce" combination step, so we don't have to do that here.
+    # Only the equivalent of "_simple_combine" is supported, there is no
+    # support for "_grouped_combine".
+    labels_are_unknown = is_chunked_array(by_input) and expected_groups is None
+    do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown
+
+    assert do_simple_combine
+    assert method == "map-reduce"
+    assert expected_groups is not None
+    assert reindex is True
+    assert len(axis) == 1  # one axis/grouping
+
+    def _groupby_func(a, by, axis, intermediate_dtype, num_groups):
+        blockwise_method = partial(
+            _get_chunk_reduction(agg.reduction_type),
+            func=agg.chunk,
+            fill_value=agg.fill_value["intermediate"],
+            dtype=agg.dtype["intermediate"],
+            reindex=reindex,
+            user_dtype=agg.dtype["user"],
+            axis=axis,
+            expected_groups=expected_groups,
+            engine=engine,
+            sort=sort,
+        )
+        out = blockwise_method(a, by)
+        # Convert dict to one that cubed understands, dropping groups since they are
+        # known, and the same for every block.
+        return {f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"])}
+
+    def _groupby_combine(a, axis, dummy_axis, dtype, keepdims):
+        # this is similar to _simple_combine, except the dummy axis and concatenation is handled by cubed
+        # only combine over the dummy axis, to preserve grouping along 'axis'
+        dtype = dict(dtype)
+        out = {}
+        for idx, combine in enumerate(agg.simple_combine):
+            field = f"f{idx}"
+            out[field] = combine(a[field], axis=dummy_axis, keepdims=keepdims)
+        return out
+
+    def _groupby_aggregate(a):
+        # Convert cubed dict to one that _finalize_results works with
+        results = {"groups": expected_groups, "intermediates": a.values()}
+        out = _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
+        return out[agg.name]
+
+    # convert list of dtypes to a structured dtype for cubed
+    intermediate_dtype = [(f"f{i}", dtype) for i, dtype in enumerate(agg.dtype["intermediate"])]
+    dtype = agg.dtype["final"]
+    num_groups = len(expected_groups)
+
+    result = cubed.core.groupby.groupby_reduction(
+        array,
+        by,
+        func=_groupby_func,
+        combine_func=_groupby_combine,
+        aggregate_func=_groupby_aggregate,
+        axis=axis,
+        intermediate_dtype=intermediate_dtype,
+        dtype=dtype,
+        num_groups=num_groups,
+    )
+
+    groups = (expected_groups.to_numpy(),)
+
+    return (result, groups)
+
+
 def _collapse_blocks_along_axes(reduced: DaskArray, axis: T_Axes, group_chunks) -> DaskArray:
     import dask.array
     from dask.highlevelgraph import HighLevelGraph
@@ -1823,15 +1986,13 @@ def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys) -> None:
 @overload
 def _convert_expected_groups_to_index(
     expected_groups: tuple[None, ...], isbin: Sequence[bool], sort: bool
-) -> tuple[None, ...]:
-    ...
+) -> tuple[None, ...]: ...
 
 
 @overload
 def _convert_expected_groups_to_index(
     expected_groups: T_ExpectTuple, isbin: Sequence[bool], sort: bool
-) -> T_ExpectIndexTuple:
-    ...
+) -> T_ExpectIndexTuple: ...
 
 
 def _convert_expected_groups_to_index(
@@ -1919,13 +2080,11 @@ def _factorize_multiple(
 
 
 @overload
-def _validate_expected_groups(nby: int, expected_groups: None) -> tuple[None, ...]:
-    ...
+def _validate_expected_groups(nby: int, expected_groups: None) -> tuple[None, ...]: ...
 
 
 @overload
-def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroups) -> T_ExpectTuple:
-    ...
+def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroups) -> T_ExpectTuple: ...
 
 
 def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectOptTuple:
@@ -2165,6 +2324,7 @@ def groupby_reduce(
     nby = len(bys)
     by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
     any_by_dask = any(by_is_dask)
+    provided_expected = expected_groups is not None
 
     if (
         engine == "numbagg"
@@ -2240,6 +2400,7 @@ def groupby_reduce(
     nax = len(axis_)
 
     has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
+    has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)
 
     if _is_first_last_reduction(func):
         if has_dask and nax != 1:
@@ -2280,7 +2441,7 @@ def groupby_reduce(
     #     The only way to do this consistently is mask out using min_count
     #     Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
     if min_count is None:
-        if nax < by_.ndim or fill_value is not None:
+        if nax < by_.ndim or (fill_value is not None and provided_expected):
             min_count_: int = 1
         else:
             min_count_ = 0
@@ -2302,7 +2463,30 @@ def groupby_reduce(
     kwargs["engine"] = _choose_engine(by_, agg) if engine is None else engine
 
     groups: tuple[np.ndarray | DaskArray, ...]
-    if not has_dask:
+    if has_cubed:
+        if method is None:
+            method = "map-reduce"
+
+        if method != "map-reduce":
+            raise NotImplementedError(
+                "Reduction for Cubed arrays is only implemented for method 'map-reduce'."
+            )
+
+        partial_agg = partial(cubed_groupby_agg, **kwargs)
+
+        result, groups = partial_agg(
+            array,
+            by_,
+            expected_groups=expected_,
+            agg=agg,
+            reindex=reindex,
+            method=method,
+            sort=sort,
+        )
+
+        return (result, groups)
+
+    elif not has_dask:
         results = _reduce_blockwise(
             array, by_, agg, expected_groups=expected_, reindex=reindex, sort=sort, **kwargs
         )


=====================================
flox/xrutils.py
=====================================
@@ -8,7 +8,6 @@ from typing import Any, Optional
 
 import numpy as np
 import pandas as pd
-from numpy.core.multiarray import normalize_axis_index  # type: ignore[attr-defined]
 from packaging.version import Version
 
 try:
@@ -25,6 +24,37 @@ except ImportError:
     dask_array_type = ()  # type: ignore[assignment, misc]
 
 
+def module_available(module: str, minversion: Optional[str] = None) -> bool:
+    """Checks whether a module is installed without importing it.
+
+    Use this for a lightweight check and lazy imports.
+
+    Parameters
+    ----------
+    module : str
+        Name of the module.
+
+    Returns
+    -------
+    available : bool
+        Whether the module is installed.
+    """
+    has = importlib.util.find_spec(module) is not None
+    if has:
+        mod = importlib.import_module(module)
+        return Version(mod.__version__) >= Version(minversion) if minversion is not None else True
+    else:
+        return False
+
+
+if module_available("numpy", minversion="2.0.0"):
+    from numpy.lib.array_utils import (  # type: ignore[import-not-found]
+        normalize_axis_index,
+    )
+else:
+    from numpy.core.numeric import normalize_axis_index  # type: ignore[attr-defined]
+
+
 def asarray(data, xp=np):
     return data if is_duck_array(data) else xp.asarray(data)
 
@@ -37,11 +67,18 @@ def is_duck_array(value: Any) -> bool:
         hasattr(value, "ndim")
         and hasattr(value, "shape")
         and hasattr(value, "dtype")
-        and hasattr(value, "__array_function__")
-        and hasattr(value, "__array_ufunc__")
+        and (
+            (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
+            or hasattr(value, "__array_namespace__")
+        )
     )
 
 
+def is_chunked_array(x) -> bool:
+    """True if dask or cubed"""
+    return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks"))
+
+
 def is_dask_collection(x):
     try:
         import dask
@@ -56,6 +93,15 @@ def is_duck_dask_array(x):
     return is_duck_array(x) and is_dask_collection(x)
 
 
+def is_duck_cubed_array(x):
+    try:
+        import cubed
+
+        return is_duck_array(x) and isinstance(x, cubed.Array)
+    except ImportError:
+        return False
+
+
 class ReprObject:
     """Object that prints as the given value, for use with sentinel values."""
 
@@ -333,26 +379,3 @@ def nanlast(values, axis, keepdims=False):
         return np.expand_dims(result, axis=axis)
     else:
         return result
-
-
-def module_available(module: str, minversion: Optional[str] = None) -> bool:
-    """Checks whether a module is installed without importing it.
-
-    Use this for a lightweight check and lazy imports.
-
-    Parameters
-    ----------
-    module : str
-        Name of the module.
-
-    Returns
-    -------
-    available : bool
-        Whether the module is installed.
-    """
-    has = importlib.util.find_spec(module) is not None
-    if has:
-        mod = importlib.import_module(module)
-        return Version(mod.__version__) >= Version(minversion) if minversion is not None else True
-    else:
-        return False


=====================================
pyproject.toml
=====================================
@@ -121,6 +121,7 @@ module=[
     "asv_runner.*",
     "cachey",
     "cftime",
+    "cubed.*",
     "dask.*",
     "importlib_metadata",
     "numba",


=====================================
tests/__init__.py
=====================================
@@ -46,6 +46,7 @@ def LooseVersion(vstring):
 
 
 has_cftime, requires_cftime = _importorskip("cftime")
+has_cubed, requires_cubed = _importorskip("cubed")
 has_dask, requires_dask = _importorskip("dask")
 has_numba, requires_numba = _importorskip("numba")
 has_numbagg, requires_numbagg = _importorskip("numbagg")


=====================================
tests/test_core.py
=====================================
@@ -36,8 +36,10 @@ from . import (
     SCIPY_STATS_FUNCS,
     assert_equal,
     assert_equal_tuple,
+    has_cubed,
     has_dask,
     raise_if_dask_computes,
+    requires_cubed,
     requires_dask,
 )
 
@@ -61,6 +63,10 @@ else:
         return None
 
 
+if has_cubed:
+    import cubed
+
+
 DEFAULT_QUANTILE = 0.9
 
 if TYPE_CHECKING:
@@ -477,6 +483,49 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp
     assert_equal(expected, actual)
 
 
+@requires_cubed
+@pytest.mark.parametrize("reindex", [True])
+@pytest.mark.parametrize("func", ALL_FUNCS)
+@pytest.mark.parametrize("add_nan", [False, True])
+@pytest.mark.parametrize(
+    "shape, array_chunks, group_chunks",
+    [
+        ((12,), (3,), 3),  # form 1
+    ],
+)
+def test_groupby_agg_cubed(func, shape, array_chunks, group_chunks, add_nan, engine, reindex):
+    """Tests groupby_reduce with cubed arrays against groupby_reduce with numpy arrays"""
+
+    if func in ["first", "last"] or func in BLOCKWISE_FUNCS:
+        pytest.skip()
+
+    if "arg" in func and (engine in ["flox", "numbagg"] or reindex):
+        pytest.skip()
+
+    array = cubed.array_api.ones(shape, chunks=array_chunks)
+
+    labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
+    if add_nan:
+        labels = labels.astype(float)
+        labels[:3] = np.nan  # entire block is NaN when group_chunks=3
+        labels[-2:] = np.nan
+
+    kwargs = dict(
+        func=func,
+        expected_groups=[0, 1, 2],
+        fill_value=False if func in ["all", "any"] else 123,
+        reindex=reindex,
+    )
+
+    expected, _ = groupby_reduce(array.compute(), labels, engine="numpy", **kwargs)
+    actual, _ = groupby_reduce(array.compute(), labels, engine=engine, **kwargs)
+    assert_equal(actual, expected)
+
+    # TODO: raise_if_cubed_computes
+    actual, _ = groupby_reduce(array, labels, engine=engine, **kwargs)
+    assert_equal(expected, actual)
+
+
 def test_numpy_reduce_axis_subset(engine):
     # TODO: add NaNs
     by = labels2d
@@ -897,12 +946,12 @@ def test_verify_complex_cohorts(chunksize: int) -> None:
 @pytest.mark.parametrize("chunksize", (12,) + tuple(range(1, 13)) + (-1,))
 def test_method_guessing(chunksize):
     # just a regression test
-    labels = np.tile(np.arange(1, 13), 30)
+    labels = np.tile(np.arange(0, 12), 30)
     by = dask.array.from_array(labels, chunks=chunksize) - 1
     preferred_method, chunks_cohorts = find_group_cohorts(labels, by.chunks[slice(-1, None)])
     if chunksize == -1:
         assert preferred_method == "blockwise"
-        assert chunks_cohorts == {(0,): list(range(1, 13))}
+        assert chunks_cohorts == {(0,): list(range(12))}
     elif chunksize in (1, 2, 3, 4, 6):
         assert preferred_method == "cohorts"
         assert len(chunks_cohorts) == 12 // chunksize
@@ -911,6 +960,21 @@ def test_method_guessing(chunksize):
         assert chunks_cohorts == {}
 
 
+@requires_dask
+@pytest.mark.parametrize("ndim", [1, 2, 3])
+def test_single_chunk_method_is_blockwise(ndim):
+    for by_ndim in range(1, ndim + 1):
+        chunks = (5,) * (ndim - by_ndim) + (-1,) * by_ndim
+        assert len(chunks) == ndim
+        array = dask.array.ones(shape=(10,) * ndim, chunks=chunks)
+        by = np.zeros(shape=(10,) * by_ndim, dtype=int)
+        method, chunks_cohorts = find_group_cohorts(
+            by, chunks=[array.chunks[ax] for ax in range(-by.ndim, 0)]
+        )
+        assert method == "blockwise"
+        assert chunks_cohorts == {(0,): [0]}
+
+
 @requires_dask
 @pytest.mark.parametrize(
     "chunk_at,expected",
@@ -1401,14 +1465,18 @@ def test_normalize_block_indexing_2d(flatblocks, expected):
 
 @requires_dask
 def test_subset_block_passthrough():
+    from flox.core import identity
+
     # full slice pass through
     array = dask.array.ones((5,), chunks=(1,))
+    expected = dask.array.map_blocks(identity, array)
     subset = subset_to_blocks(array, np.arange(5))
-    assert subset.name == array.name
+    assert subset.name == expected.name
 
     array = dask.array.ones((5, 5), chunks=1)
+    expected = dask.array.map_blocks(identity, array)
     subset = subset_to_blocks(array, np.arange(25))
-    assert subset.name == array.name
+    assert subset.name == expected.name
 
 
 @requires_dask



View it on GitLab: https://salsa.debian.org/debian-gis-team/flox/-/commit/6f92dacc3855c823a7959924b7685340bb1b6566

-- 
View it on GitLab: https://salsa.debian.org/debian-gis-team/flox/-/commit/6f92dacc3855c823a7959924b7685340bb1b6566
You're receiving this email because of your account on salsa.debian.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://alioth-lists.debian.net/pipermail/pkg-grass-devel/attachments/20240510/00a20803/attachment-0001.htm>


More information about the Pkg-grass-devel mailing list