[Git][debian-gis-team/flox][upstream] New upstream version 0.8.6

Antonio Valentino (@antonio.valentino) gitlab at salsa.debian.org
Sat Jan 6 17:06:49 GMT 2024



Antonio Valentino pushed to branch upstream at Debian GIS Project / flox


Commits:
8f81c002 by Antonio Valentino at 2024-01-06T15:35:19+00:00
New upstream version 0.8.6
- - - - -


10 changed files:

- .github/workflows/benchmarks.yml
- .github/workflows/ci.yaml
- .github/workflows/pypi.yaml
- .github/workflows/testpypi-release.yaml
- .pre-commit-config.yaml
- asv_bench/benchmarks/cohorts.py
- docs/source/user-stories/climatology.ipynb
- flox/core.py
- tests/test_core.py
- tests/test_xarray.py


Changes:

=====================================
.github/workflows/benchmarks.yml
=====================================
@@ -64,7 +64,7 @@ jobs:
           cp benchmarks/README_CI.md benchmarks.log .asv/results/
         working-directory: ${{ env.ASV_DIR }}
 
-      - uses: actions/upload-artifact@v3
+      - uses: actions/upload-artifact@v4
         if: always()
         with:
           name: asv-benchmark-results-${{ runner.os }}


=====================================
.github/workflows/ci.yaml
=====================================
@@ -148,4 +148,8 @@ jobs:
           python -m pytest -n auto \
               xarray/tests/test_groupby.py \
               xarray/tests/test_units.py::TestDataArray::test_computation_objects \
-              xarray/tests/test_units.py::TestDataset::test_computation_objects
+              xarray/tests/test_units.py::TestDataArray::test_grouped_operations \
+              xarray/tests/test_units.py::TestDataArray::test_resample \
+              xarray/tests/test_units.py::TestDataset::test_computation_objects \
+              xarray/tests/test_units.py::TestDataset::test_grouped_operations \
+              xarray/tests/test_units.py::TestDataset::test_resample


=====================================
.github/workflows/pypi.yaml
=====================================
@@ -10,7 +10,7 @@ jobs:
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: "3.x"
       - name: Install dependencies


=====================================
.github/workflows/testpypi-release.yaml
=====================================
@@ -21,7 +21,7 @@ jobs:
         with:
           fetch-depth: 0
 
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@v5
         name: Install Python
         with:
           python-version: "3.11"
@@ -53,7 +53,7 @@ jobs:
             echo "✅ Looks good"
           fi
 
-      - uses: actions/upload-artifact@v3
+      - uses: actions/upload-artifact@v4
         with:
           name: releases
           path: dist
@@ -62,11 +62,11 @@ jobs:
     needs: build-artifacts
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@v5
         name: Install Python
         with:
           python-version: "3.11"
-      - uses: actions/download-artifact@v3
+      - uses: actions/download-artifact@v4
         with:
           name: releases
           path: dist


=====================================
.pre-commit-config.yaml
=====================================
@@ -4,13 +4,13 @@ ci:
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: "v0.0.292"
+    rev: "v0.1.9"
     hooks:
       - id: ruff
         args: ["--fix", "--show-fixes"]
 
   - repo: https://github.com/pre-commit/mirrors-prettier
-    rev: "v3.0.3"
+    rev: "v4.0.0-alpha.8"
     hooks:
       - id: prettier
 
@@ -23,7 +23,7 @@ repos:
       - id: check-docstring-first
 
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.9.1
+    rev: 23.12.1
     hooks:
       - id: black
 
@@ -36,7 +36,7 @@ repos:
           - mdformat-myst
 
   - repo: https://github.com/nbQA-dev/nbQA
-    rev: 1.7.0
+    rev: 1.7.1
     hooks:
       - id: nbqa-black
       - id: nbqa-ruff


=====================================
asv_bench/benchmarks/cohorts.py
=====================================
@@ -11,8 +11,23 @@ class Cohorts:
     def setup(self, *args, **kwargs):
         raise NotImplementedError
 
+    def chunks_cohorts(self):
+        return flox.core.find_group_cohorts(
+            self.by,
+            [self.array.chunks[ax] for ax in self.axis],
+            expected_groups=self.expected,
+        )
+
+    def bitmask(self):
+        chunks = [self.array.chunks[ax] for ax in self.axis]
+        return flox.core._compute_label_chunk_bitmask(self.by, chunks, self.expected[-1] + 1)
+
     def time_find_group_cohorts(self):
-        flox.core.find_group_cohorts(self.by, [self.array.chunks[ax] for ax in self.axis])
+        flox.core.find_group_cohorts(
+            self.by,
+            [self.array.chunks[ax] for ax in self.axis],
+            expected_groups=self.expected,
+        )
         # The cache clear fails dependably in CI
         # Not sure why
         try:
@@ -58,10 +73,15 @@ class NWMMidwest(Cohorts):
     def setup(self, *args, **kwargs):
         x = np.repeat(np.arange(30), 150)
         y = np.repeat(np.arange(30), 60)
-        self.by = x[np.newaxis, :] * y[:, np.newaxis]
+        by = x[np.newaxis, :] * y[:, np.newaxis]
+
+        self.by = flox.core._factorize_multiple(
+            (by,), expected_groups=(None,), any_by_dask=False, reindex=False
+        )[0][0]
 
         self.array = dask.array.ones(self.by.shape, chunks=(350, 350))
         self.axis = (-2, -1)
+        self.expected = pd.RangeIndex(self.by.max() + 1)
 
 
 class ERA5Dataset:
@@ -81,13 +101,15 @@ class ERA5Dataset:
 class ERA5DayOfYear(ERA5Dataset, Cohorts):
     def setup(self, *args, **kwargs):
         super().__init__()
-        self.by = self.time.dt.dayofyear.values
+        self.by = self.time.dt.dayofyear.values - 1
+        self.expected = pd.RangeIndex(self.by.max() + 1)
 
 
-class ERA5DayOfYearRechunked(ERA5DayOfYear, Cohorts):
-    def setup(self, *args, **kwargs):
-        super().setup()
-        self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 24))
+# class ERA5DayOfYearRechunked(ERA5DayOfYear, Cohorts):
+#     def setup(self, *args, **kwargs):
+#         super().setup()
+#         self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 24))
+#         self.expected = pd.RangeIndex(self.by.max() + 1)
 
 
 class ERA5MonthHour(ERA5Dataset, Cohorts):
@@ -101,7 +123,8 @@ class ERA5MonthHour(ERA5Dataset, Cohorts):
             reindex=False,
         )
         # Add one so the rechunk code is simpler and makes sense
-        self.by = ret[0][0] + 1
+        self.by = ret[0][0]
+        self.expected = pd.RangeIndex(self.by.max() + 1)
 
 
 class ERA5MonthHourRechunked(ERA5MonthHour, Cohorts):
@@ -117,7 +140,8 @@ class PerfectMonthly(Cohorts):
         self.time = pd.Series(pd.date_range("1961-01-01", "2018-12-31 23:59", freq="M"))
         self.axis = (-1,)
         self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 4))
-        self.by = self.time.dt.month.values
+        self.by = self.time.dt.month.values - 1
+        self.expected = pd.RangeIndex(self.by.max() + 1)
 
     def rechunk(self):
         self.array = flox.core.rechunk_for_cohorts(
@@ -125,10 +149,10 @@ class PerfectMonthly(Cohorts):
         )
 
 
-class PerfectMonthlyRechunked(PerfectMonthly):
-    def setup(self, *args, **kwargs):
-        super().setup()
-        super().rechunk()
+# class PerfectMonthlyRechunked(PerfectMonthly):
+#     def setup(self, *args, **kwargs):
+#         super().setup()
+#         super().rechunk()
 
 
 class ERA5Google(Cohorts):
@@ -137,4 +161,27 @@ class ERA5Google(Cohorts):
         self.time = pd.Series(pd.date_range("1959-01-01", freq="6H", periods=TIME))
         self.axis = (2,)
         self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, 1))
-        self.by = self.time.dt.day.values
+        self.by = self.time.dt.day.values - 1
+        self.expected = pd.RangeIndex(self.by.max() + 1)
+
+
+def codes_for_resampling(group_as_index, freq):
+    s = pd.Series(np.arange(group_as_index.size), group_as_index)
+    grouped = s.groupby(pd.Grouper(freq=freq))
+    first_items = grouped.first()
+    counts = grouped.count()
+    codes = np.repeat(np.arange(len(first_items)), counts)
+    return codes
+
+
+class PerfectBlockwiseResampling(Cohorts):
+    """Perfectly chunked for blockwise resampling."""
+
+    def setup(self, *args, **kwargs):
+        index = pd.date_range("1959-01-01", freq="D", end="1962-12-31")
+        self.time = pd.Series(index)
+        TIME = len(self.time)
+        self.axis = (2,)
+        self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, 10))
+        self.by = codes_for_resampling(index, freq="5D")
+        self.expected = pd.RangeIndex(self.by.max() + 1)


=====================================
docs/source/user-stories/climatology.ipynb
=====================================
@@ -22,8 +22,6 @@
    "outputs": [],
    "source": [
     "import dask.array\n",
-    "import matplotlib.pyplot as plt\n",
-    "import numpy as np\n",
     "import pandas as pd\n",
     "import xarray as xr\n",
     "\n",
@@ -56,6 +54,27 @@
     "oisst"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "b7f519ee-e575-492c-a70b-8dad63a8c222",
+   "metadata": {},
+   "source": [
+    "To account for Feb-29 being present in some years, we'll construct a time vector to group by as a \"mmm-dd\" string.\n",
+    "\n",
+    "For more options, see https://strftime.org/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3c42a618-47bc-4c83-a902-ec4cf3420180",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "day = oisst.time.dt.strftime(\"%h-%d\").rename(\"day\")\n",
+    "day"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "6d913e7f-25bd-43c4-98b6-93bcb420c524",
@@ -80,7 +99,7 @@
    "source": [
     "flox.xarray.xarray_reduce(\n",
     "    oisst,\n",
-    "    oisst.time.dt.dayofyear,\n",
+    "    day,\n",
     "    func=\"mean\",\n",
     "    method=\"map-reduce\",\n",
     ")"
@@ -106,7 +125,7 @@
    "source": [
     "flox.xarray.xarray_reduce(\n",
     "    oisst.chunk({\"lat\": -1, \"lon\": 120}),\n",
-    "    oisst.time.dt.dayofyear,\n",
+    "    day,\n",
     "    func=\"mean\",\n",
     "    method=\"map-reduce\",\n",
     ")"
@@ -143,7 +162,7 @@
    "source": [
     "flox.xarray.xarray_reduce(\n",
     "    oisst,\n",
-    "    oisst.time.dt.dayofyear,\n",
+    "    day,\n",
     "    func=\"mean\",\n",
     "    method=\"cohorts\",\n",
     ")"
@@ -160,10 +179,7 @@
     "[click here](https://flox.readthedocs.io/en/latest/implementation.html#method-cohorts)).\n",
     "Now we have the opposite problem: the chunk sizes on the output are too small.\n",
     "\n",
-    "Looking more closely, We can see the cohorts that `flox` has detected are not\n",
-    "really cohorts, each cohort is a single group label. We've replicated Xarray's\n",
-    "current strategy; what flox calls\n",
-    "[\"split-reduce\"](https://flox.readthedocs.io/en/latest/implementation.html#method-split-reduce-xarray-s-current-groupby-strategy)\n"
+    "Let us inspect the cohorts"
    ]
   },
   {
@@ -173,112 +189,81 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "flox.core.find_group_cohorts(\n",
-    "    labels=oisst.time.dt.dayofyear.data,\n",
+    "# integer codes for each \"day\"\n",
+    "codes, _ = pd.factorize(day.data)\n",
+    "cohorts = flox.core.find_group_cohorts(\n",
+    "    labels=codes,\n",
     "    chunks=(oisst.chunksizes[\"time\"],),\n",
-    ").values()"
+    ")\n",
+    "print(len(cohorts))"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "bcbdbb3b-2aed-4f3f-ad20-efabb52b5e68",
+   "id": "068b4109-b7f4-4c16-918d-9a18ff2ed183",
    "metadata": {},
    "source": [
-    "## Rechunking data for cohorts\n",
-    "\n",
-    "Can we fix the \"out of phase\" problem by rechunking along time?\n",
-    "\n",
-    "First lets see where the current chunk boundaries are\n"
+    "Looking more closely, we can see many cohorts with a single entry. "
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "90a884bc-1b71-4874-8143-73b3b5c41458",
+   "id": "57983cd0-a2e0-4d16-abe6-9572f6f252bf",
    "metadata": {},
    "outputs": [],
    "source": [
-    "array = oisst.data\n",
-    "labels = oisst.time.dt.dayofyear.data\n",
-    "axis = oisst.get_axis_num(\"time\")\n",
-    "oldchunks = array.chunks[axis]\n",
-    "oldbreaks = np.insert(np.cumsum(oldchunks), 0, 0)\n",
-    "labels_at_breaks = labels[oldbreaks[:-1]]\n",
-    "labels_at_breaks"
+    "cohorts.values()"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "4b2573e5-0d30-4cb8-b5af-751b824f0689",
+   "id": "bcbdbb3b-2aed-4f3f-ad20-efabb52b5e68",
    "metadata": {},
    "source": [
-    "Now we'll use a convenient function `rechunk_for_cohorts` to rechunk the `oisst`\n",
-    "dataset along time. We'll ask it to rechunk so that a new chunk starts at each\n",
-    "of the elements\n",
+    "## Rechunking data for cohorts\n",
     "\n",
-    "```\n",
-    "[244, 264, 284, 304, 324, 344, 364,  19,  39,  59,  79,  99, 119,\n",
-    " 139, 159, 179, 199, 219, 239]\n",
-    "```\n",
+    "Can we fix the \"out of phase\" problem by rechunking along time?\n",
     "\n",
-    "These are labels at the chunk boundaries in the first year of data. We are\n",
-    "forcing that chunking pattern to repeat as much as possible. We also tell the\n",
-    "function to ignore any existing chunk boundaries.\n"
+    "First let's see where the current chunk boundaries are"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "a9ab6382-e93b-49e9-8e2e-1ba526046aea",
+   "id": "40d393a5-7a4e-4d33-997b-4c422a0b8100",
    "metadata": {},
    "outputs": [],
    "source": [
-    "rechunked = flox.xarray.rechunk_for_cohorts(\n",
-    "    oisst,\n",
-    "    dim=\"time\",\n",
-    "    labels=oisst.time.dt.dayofyear,\n",
-    "    force_new_chunk_at=[\n",
-    "        244,\n",
-    "        264,\n",
-    "        284,\n",
-    "        304,\n",
-    "        324,\n",
-    "        344,\n",
-    "        364,\n",
-    "        19,\n",
-    "        39,\n",
-    "        59,\n",
-    "        79,\n",
-    "        99,\n",
-    "        119,\n",
-    "        139,\n",
-    "        159,\n",
-    "        179,\n",
-    "        199,\n",
-    "        219,\n",
-    "        239,\n",
-    "    ],\n",
-    "    ignore_old_chunks=True,\n",
-    ")\n",
-    "rechunked"
+    "oisst.chunksizes[\"time\"][:10]"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "570d869b-9612-4de9-83ee-336a35c1fdad",
+   "id": "cd0033a3-d211-4aef-a284-c9fd3f75f6e4",
+   "metadata": {},
+   "source": [
+    "We'll choose to rechunk such that each month is a single chunk. This is not too different from the current chunking but will help our periodicity problem."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5914a350-a7db-49b3-9504-6d63ff874f5e",
    "metadata": {},
+   "outputs": [],
    "source": [
-    "We see that chunks are mostly 20 elements long in time with some differences\n"
+    "newchunks = xr.ones_like(day).astype(int).resample(time=\"M\").count()"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "86bb4461-d921-40f8-9ff7-8d6e7e8c7e4b",
+   "id": "90a884bc-1b71-4874-8143-73b3b5c41458",
    "metadata": {},
    "outputs": [],
    "source": [
-    "plt.plot(rechunked.chunksizes[\"time\"], marker=\"x\", ls=\"none\")"
+    "rechunked = oisst.chunk(time=tuple(newchunks.data))"
    ]
   },
   {
@@ -296,10 +281,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "flox.core.find_group_cohorts(\n",
-    "    labels=rechunked.time.dt.dayofyear.data,\n",
+    "new_cohorts = flox.core.find_group_cohorts(\n",
+    "    labels=codes,\n",
     "    chunks=(rechunked.chunksizes[\"time\"],),\n",
-    ").values()"
+    ")\n",
+    "# one cohort per month!\n",
+    "len(new_cohorts)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4e2b6f70-c057-4783-ad55-21b20ff27e7f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "new_cohorts.values()"
    ]
   },
   {
@@ -318,7 +315,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "flox.xarray.xarray_reduce(rechunked, rechunked.time.dt.dayofyear, func=\"mean\", method=\"cohorts\")"
+    "flox.xarray.xarray_reduce(rechunked, day, func=\"mean\", method=\"cohorts\")"
    ]
   },
   {


=====================================
flox/core.py
=====================================
@@ -214,8 +214,41 @@ def slices_from_chunks(chunks):
     return product(*slices)
 
 
-@memoize
-def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
+def _compute_label_chunk_bitmask(labels, chunks, nlabels):
+    assert isinstance(labels, np.ndarray)
+    shape = tuple(sum(c) for c in chunks)
+    nchunks = math.prod(len(c) for c in chunks)
+
+    labels = np.broadcast_to(labels, shape[-labels.ndim :])
+
+    rows = []
+    cols = []
+    # Add one to handle the -1 sentinel value
+    label_is_present = np.zeros((nlabels + 1,), dtype=bool)
+    ilabels = np.arange(nlabels)
+    for idx, region in enumerate(slices_from_chunks(chunks)):
+        # This is a quite fast way to find unique integers, when we know how many there are
+        # inspired by a similar idea in numpy_groupies for first, last
+        # instead of explicitly finding uniques, repeatedly write True to the same location
+        subset = labels[region]
+        # The reshape is not strictly necessary but is about 100ms faster on a test problem.
+        label_is_present[subset.reshape(-1)] = True
+        # skip the -1 sentinel by slicing
+        # Faster than np.argwhere by a lot
+        uniques = ilabels[label_is_present[:-1]]
+        rows.append(np.full_like(uniques, idx))
+        cols.append(uniques)
+        label_is_present[:] = False
+    rows_array = np.concatenate(rows)
+    cols_array = np.concatenate(cols)
+    data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
+    bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
+
+    return bitmask
+
+
+# @memoize
+def find_group_cohorts(labels, chunks, expected_groups: None | pd.RangeIndex = None) -> dict:
     """
     Finds groups labels that occur together aka "cohorts"
 
@@ -230,9 +263,8 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
         represents NaNs.
     chunks : tuple
         chunks of the array being reduced
-    merge : bool, optional
-        Attempt to merge cohorts when one cohort's chunks are a subset
-        of another cohort's chunks.
+    expected_groups: pd.RangeIndex (optional)
+        Used to extract the largest label expected
 
     Returns
     -------
@@ -243,108 +275,90 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     labels = np.asarray(labels)
 
     shape = tuple(sum(c) for c in chunks)
-    nchunks = math.prod(len(c) for c in chunks)
 
     # assumes that `labels` are factorized
-    nlabels = labels.max() + 1
+    if expected_groups is None:
+        nlabels = labels.max() + 1
+    else:
+        nlabels = expected_groups[-1] + 1
 
     labels = np.broadcast_to(labels, shape[-labels.ndim :])
+    bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)
+
+    CHUNK_AXIS, LABEL_AXIS = 0, 1
+    chunks_per_label = bitmask.sum(axis=CHUNK_AXIS)
+
+    # can happen when `expected_groups` is passed but not all labels are present
+    # (binning, resampling)
+    present_labels = np.arange(bitmask.shape[LABEL_AXIS])
+    present_labels_mask = chunks_per_label != 0
+    if not present_labels_mask.all():
+        present_labels = present_labels[present_labels_mask]
+        bitmask = bitmask[..., present_labels_mask]
+        chunks_per_label = chunks_per_label[present_labels_mask]
 
-    rows = []
-    cols = []
-    # Add one to handle the -1 sentinel value
-    label_is_present = np.zeros((nlabels + 1,), dtype=bool)
-    ilabels = np.arange(nlabels)
-    for idx, region in enumerate(slices_from_chunks(chunks)):
-        # This is a quite fast way to find unique integers, when we know how many there are
-        # inspired by a similar idea in numpy_groupies for first, last
-        # instead of explicitly finding uniques, repeatedly write True to the same location
-        subset = labels[region]
-        # The reshape is not strictly necessary but is about 100ms faster on a test problem.
-        label_is_present[subset.reshape(-1)] = True
-        # skip the -1 sentinel by slicing
-        uniques = ilabels[label_is_present[:-1]]
-        rows.append([idx] * len(uniques))
-        cols.append(uniques)
-        label_is_present[:] = False
-    rows_array = np.concatenate(rows)
-    cols_array = np.concatenate(cols)
-    data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
-    bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
     label_chunks = {
-        lab: bitmask.indices[slice(bitmask.indptr[lab], bitmask.indptr[lab + 1])]
-        for lab in range(nlabels)
+        present_labels[idx]: bitmask.indices[slice(bitmask.indptr[idx], bitmask.indptr[idx + 1])]
+        for idx in range(bitmask.shape[LABEL_AXIS])
     }
 
-    ## numpy bitmask approach, faster than finding uniques, but lots of memory
-    # bitmask = np.zeros((nchunks, nlabels), dtype=bool)
-    # for idx, region in enumerate(slices_from_chunks(chunks)):
-    #     bitmask[idx, labels[region]] = True
-    # bitmask = bitmask[:, :-1]
-    # chunk = np.arange(nchunks)  # [:, np.newaxis] * bitmask
-    # label_chunks = {lab: chunk[bitmask[:, lab]] for lab in range(nlabels - 1)}
-
-    ## Pandas GroupBy approach, quite slow!
-    # which_chunk = np.empty(shape, dtype=np.int64)
-    # for idx, region in enumerate(slices_from_chunks(chunks)):
-    #     which_chunk[region] = idx
-    # which_chunk = which_chunk.reshape(-1)
-    # raveled = labels.reshape(-1)
-    # # these are chunks where a label is present
-    # label_chunks = pd.Series(which_chunk).groupby(raveled).unique()
-
-    # These invert the label_chunks mapping so we know which labels occur together.
+    # Invert the label_chunks mapping so we know which labels occur together.
     def invert(x) -> tuple[np.ndarray, ...]:
-        arr = label_chunks.get(x)
-        return tuple(arr)  # type: ignore [arg-type] # pandas issue?
+        arr = label_chunks[x]
+        return tuple(arr)
 
     chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
 
-    # If our dataset has chunksize one along the axis,
-    # then no merging is possible.
+    # No merging is possible when
+    # 1. Our dataset has chunksize one along the axis,
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
-    one_group_per_chunk = (bitmask.sum(axis=1) == 1).all()
-    if not one_group_per_chunk and not single_chunks and merge:
-        # First sort by number of chunks occupied by cohort
-        sorted_chunks_cohorts = dict(
-            sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
-        )
+    # 2. Every chunk only has a single group, but that group might extend across multiple chunks
+    one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
+    # 3. Every group is contained to one block, we should be using blockwise here.
+    every_group_one_block = (chunks_per_label == 1).all()
+    # 4. Existing cohorts don't overlap, great for time grouping with perfect chunking
+    no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()
+
+    if every_group_one_block or one_group_per_chunk or single_chunks or no_overlapping_cohorts:
+        return chunks_cohorts
 
-        # precompute needed metrics for the quadratic loop below.
-        items = tuple((k, len(k), set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
-
-        merged_cohorts = {}
-        merged_keys: set[tuple] = set()
-
-        # Now we iterate starting with the longest number of chunks,
-        # and then merge in cohorts that are present in a subset of those chunks
-        # I think this is suboptimal and must fail at some point.
-        # But it might work for most cases. There must be a better way...
-        for idx, (k1, len_k1, set_k1, v1) in enumerate(items):
-            if k1 in merged_keys:
-                continue
-            new_key = set_k1
-            new_value = v1
-            # iterate in reverse since we expect small cohorts
-            # to be most likely merged in to larger ones
-            for k2, len_k2, set_k2, v2 in reversed(items[idx + 1 :]):
-                if k2 not in merged_keys:
-                    if (len(set_k2 & new_key) / len_k2) > 0.75:
-                        new_key |= set_k2
-                        new_value += v2
-                        merged_keys.update((k2,))
-            sorted_ = sorted(new_value)
-            merged_cohorts[tuple(sorted(new_key))] = sorted_
-            if idx == 0 and (len(sorted_) == nlabels) and (np.array(sorted_) == ilabels).all():
-                break
-
-        # sort by first label in cohort
-        # This will help when sort=True (default)
-        # and we have to resort the dask array
-        return dict(sorted(merged_cohorts.items(), key=lambda kv: kv[1][0]))
+    # Containment = |Q & S| / |Q|
+    #  - |X| is the cardinality of set X
+    #  - Q is the query set being tested
+    #  - S is the existing set
+    MIN_CONTAINMENT = 0.75  # arbitrary
+    asfloat = bitmask.astype(float)
+    containment = ((asfloat.T @ asfloat) / chunks_per_label).tocsr()
+    mask = containment.data < MIN_CONTAINMENT
+    containment.data[mask] = 0
+    containment.eliminate_zeros()
+
+    # Iterate over labels, beginning with those with most chunks
+    order = np.argsort(containment.sum(axis=LABEL_AXIS))[::-1]
+    merged_cohorts = {}
+    merged_keys = set()
+    # TODO: we can optimize this to loop over chunk_cohorts instead
+    #       by zeroing out rows that are already in a cohort
+    for rowidx in order:
+        cohort_ = containment.indices[
+            slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
+        ]
+        cohort = [elem for elem in cohort_ if elem not in merged_keys]
+        if not cohort:
+            continue
+        merged_keys.update(cohort)
+        allchunks = (label_chunks[member] for member in cohort)
+        chunk = tuple(set(itertools.chain(*allchunks)))
+        merged_cohorts[chunk] = cohort
 
-    else:
-        return chunks_cohorts
+    actual_ngroups = np.concatenate(tuple(merged_cohorts.values())).size
+    expected_ngroups = bitmask.shape[LABEL_AXIS]
+    assert expected_ngroups == actual_ngroups, (expected_ngroups, actual_ngroups)
+
+    # sort by first label in cohort
+    # This will help when sort=True (default)
+    # and we have to resort the dask array
+    return dict(sorted(merged_cohorts.items(), key=lambda kv: kv[1][0]))
 
 
 def rechunk_for_cohorts(
@@ -1399,7 +1413,7 @@ def dask_groupby_agg(
     array: DaskArray,
     by: T_By,
     agg: Aggregation,
-    expected_groups: T_ExpectIndexOpt,
+    expected_groups: pd.RangeIndex | None,
     axis: T_Axes = (),
     fill_value: Any = None,
     method: T_Method = "map-reduce",
@@ -1419,7 +1433,7 @@ def dask_groupby_agg(
     name = f"groupby_{agg.name}"
 
     if expected_groups is None and reindex:
-        expected_groups = _get_expected_groups(by, sort=sort)
+        raise ValueError
     if method == "cohorts":
         assert reindex is False
 
@@ -1506,17 +1520,15 @@ def dask_groupby_agg(
     group_chunks: tuple[tuple[int | float, ...]]
 
     if method in ["map-reduce", "cohorts"]:
-        combine: Callable[..., IntermediateDict]
-        if do_simple_combine:
-            combine = partial(_simple_combine, reindex=reindex)
-            combine_name = "simple-combine"
-        else:
-            combine = partial(_grouped_combine, engine=engine, sort=sort)
-            combine_name = "grouped-combine"
+        combine: Callable[..., IntermediateDict] = (
+            partial(_simple_combine, reindex=reindex)
+            if do_simple_combine
+            else partial(_grouped_combine, engine=engine, sort=sort)
+        )
 
         tree_reduce = partial(
             dask.array.reductions._tree_reduce,
-            name=f"{name}-reduce-{method}-{combine_name}",
+            name=f"{name}-reduce-{method}",
             dtype=array.dtype,
             axis=axis,
             keepdims=True,
@@ -1548,7 +1560,7 @@ def dask_groupby_agg(
 
         elif method == "cohorts":
             chunks_cohorts = find_group_cohorts(
-                by_input, [array.chunks[ax] for ax in axis], merge=True
+                by_input, [array.chunks[ax] for ax in axis], expected_groups=expected_groups
             )
             reduced_ = []
             groups_ = []


=====================================
tests/test_core.py
=====================================
@@ -4,12 +4,14 @@ import itertools
 import warnings
 from functools import partial, reduce
 from typing import TYPE_CHECKING, Callable
+from unittest.mock import MagicMock, patch
 
 import numpy as np
 import pandas as pd
 import pytest
 from numpy_groupies.aggregate_numpy import aggregate
 
+import flox
 from flox import xrutils
 from flox.aggregations import Aggregation, _initialize_aggregation
 from flox.core import (
@@ -303,6 +305,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         if chunks == -1:
             params.extend([("blockwise", None)])
 
+        combine_error = RuntimeError("This combine should not have been called.")
         for method, reindex in params:
             call = partial(
                 groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
@@ -312,13 +315,22 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
                 with pytest.raises(NotImplementedError):
                     call()
                 continue
-            actual, *groups = call()
-            if method != "blockwise":
+
+            if method == "blockwise":
+                # no combine necessary
+                mocks = {
+                    "_simple_combine": MagicMock(side_effect=combine_error),
+                    "_grouped_combine": MagicMock(side_effect=combine_error),
+                }
+            else:
                 if "arg" not in func:
                     # make sure we use simple combine
-                    assert any("simple-combine" in key for key in actual.dask.layers.keys())
+                    mocks = {"_grouped_combine": MagicMock(side_effect=combine_error)}
                 else:
-                    assert any("grouped-combine" in key for key in actual.dask.layers.keys())
+                    mocks = {"_simple_combine": MagicMock(side_effect=combine_error)}
+
+            with patch.multiple(flox.core, **mocks):
+                actual, *groups = call()
             for actual_group, expect in zip(groups, expected_groups):
                 assert_equal(actual_group, expect, tolerance)
             if "arg" in func:
@@ -832,24 +844,16 @@ def test_rechunk_for_blockwise(inchunks, expected):
 
 @requires_dask
 @pytest.mark.parametrize(
-    "expected, labels, chunks, merge",
+    "expected, labels, chunks",
     [
-        [[[0, 1, 2, 3]], [0, 1, 2, 0, 1, 2, 3], (3, 4), True],
-        [[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 4), False],
-        [[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1), False],
-        [[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1), True],
-        [[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1), True],
-        [[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1), False],
-        [
-            [[0], [1, 2, 3, 4], [5]],
-            np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]),
-            (4, 8, 4, 9, 4),
-            True,
-        ],
+        [[[0, 1, 2, 3]], [0, 1, 2, 0, 1, 2, 3], (3, 4)],
+        [[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1)],
+        [[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1)],
+        [[[0], [1, 2, 3, 4], [5]], np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]), (4, 8, 4, 9, 4)],
     ],
 )
-def test_find_group_cohorts(expected, labels, chunks: tuple[int], merge: bool) -> None:
-    actual = list(find_group_cohorts(labels, (chunks,), merge).values())
+def test_find_group_cohorts(expected, labels, chunks: tuple[int]) -> None:
+    actual = list(find_group_cohorts(labels, (chunks,)).values())
     assert actual == expected, (actual, expected)
 
 


=====================================
tests/test_xarray.py
=====================================
@@ -367,26 +367,26 @@ def test_func_is_aggregation():
         xarray_reduce(ds.Tair, ds.time.dt.month, func=mean, skipna=False)
 
 
-@requires_dask
-def test_cache():
-    pytest.importorskip("cachey")
+# @requires_dask
+# def test_cache():
+#     pytest.importorskip("cachey")
 
-    from flox.cache import cache
+#     from flox.cache import cache
 
-    ds = xr.Dataset(
-        {
-            "foo": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
-            "bar": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
-        },
-        coords={"labels": ("y", np.repeat([1, 2], 10))},
-    )
+#     ds = xr.Dataset(
+#         {
+#             "foo": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
+#             "bar": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
+#         },
+#         coords={"labels": ("y", np.repeat([1, 2], 10))},
+#     )
 
-    cache.clear()
-    xarray_reduce(ds, "labels", func="mean", method="cohorts")
-    assert len(cache.data) == 1
+#     cache.clear()
+#     xarray_reduce(ds, "labels", func="mean", method="cohorts")
+#     assert len(cache.data) == 1
 
-    xarray_reduce(ds, "labels", func="mean", method="blockwise")
-    assert len(cache.data) == 2
+#     xarray_reduce(ds, "labels", func="mean", method="blockwise")
+#     assert len(cache.data) == 2
 
 
 @requires_dask
@@ -629,3 +629,30 @@ def test_groupby_2d_dataset():
         expected.counts.dims == actual.counts.dims
     )  # https://github.com/pydata/xarray/issues/8292
     xr.testing.assert_identical(expected, actual)
+
+
+@pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False))
+def test_resampling_missing_groups(chunk):
+    # Regression test for https://github.com/pydata/xarray/issues/8592
+    time_coords = pd.to_datetime(
+        ["2018-06-13T03:40:36", "2018-06-13T05:50:37", "2018-06-15T03:02:34"]
+    )
+
+    latitude_coords = [0.0]
+    longitude_coords = [0.0]
+
+    data = [[[1.0]], [[2.0]], [[3.0]]]
+
+    da = xr.DataArray(
+        data,
+        coords={"time": time_coords, "latitude": latitude_coords, "longitude": longitude_coords},
+        dims=["time", "latitude", "longitude"],
+    )
+    if chunk:
+        da = da.chunk(time=1)
+    # Without chunking the dataarray, it works:
+    with xr.set_options(use_flox=False):
+        expected = da.resample(time="1D").mean()
+    with xr.set_options(use_flox=True):
+        actual = da.resample(time="1D").mean()
+    xr.testing.assert_identical(expected, actual)



View it on GitLab: https://salsa.debian.org/debian-gis-team/flox/-/commit/8f81c002503d92af8c1ed446788e0f73ff7efe05

-- 
View it on GitLab: https://salsa.debian.org/debian-gis-team/flox/-/commit/8f81c002503d92af8c1ed446788e0f73ff7efe05
You're receiving this email because of your account on salsa.debian.org.


-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://alioth-lists.debian.net/pipermail/pkg-grass-devel/attachments/20240106/d89f8b66/attachment-0001.htm>


More information about the Pkg-grass-devel mailing list