## Background
Rechunker's current "Push / Pull Consolidated" algorithm can be thought of as a mix of "split" and "combine" steps:
- _Combine_ `source_chunks` -> `read_chunks` (to fill up `max_mem` per chunk)
- _Split/Combine_ if `read_chunks != write_chunks`:
  - _Split_ `read_chunks` -> `int_chunks`
  - _Combine_ `int_chunks` -> `write_chunks`
- _Split_ `write_chunks` -> `target_chunks`
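For the split/combine step, the intermediate chunks must be no larger than either the read or write chunks along any axis. A minimal sketch of that rule, assuming (as rechunker's `_calculate_shared_chunks` helper appears to do) that the shared chunks are simply the elementwise minimum; the function name here is my own:

```python
def shared_chunks(read_chunks, write_chunks):
    # Take the smaller extent along each axis, so an intermediate chunk
    # never exceeds either the read chunks or the write chunks.
    return tuple(min(r, w) for r, w in zip(read_chunks, write_chunks))

# The read/write chunk pair from stage 1 of the ERA5 example below:
shared_chunks((93, 721, 1440), (350640, 10, 30))  # -> (93, 10, 30)
```

Note that when the read and write chunks are very different shapes (whole-image vs. whole-time-series), this elementwise minimum is tiny along every axis at once, which is exactly the failure mode described above.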
This is pretty clever, but can run into scalability issues. In particular, sometimes `int_chunks` must be very small, which results in a significant overhead for reading/writing small files (https://github.com/pangeo-data/rechunker/issues/80).
## Proposal
I think the right way to fix this is to extend Rechunker's algorithm to allow for multiple split/combine stages -- as many as necessary to avoid creating tiny intermediate chunks. This PR implements the math for such an algorithm, in a fully backwards-compatible fashion. Users can control the number of stages via the new `min_mem` parameter, which specifies a _minimum_ chunk size in bytes.
Multi-stage rechunking is not yet hooked up to any of Rechunker's executors, but that _should_ be relatively straightforward. However, I'm probably not going to do that myself, because I only need the math for rechunking inside [Xarray-Beam](https://github.com/google/xarray-beam) (this was an easier way to explore writing the Beam executor part).
Dask also does multi-stage rechunking, which brought significant efficiency gains (https://github.com/dask/dask/issues/417). I considered copying Dask's rechunk planning algorithm here, but it [involves a lot of complex logic](https://github.com/dask/dask/blob/640df6bde07a3eb6b006633f309eccaa50985287/dask/array/rechunk.py) so I decided to try replacing it with a simple heuristic instead. See `algorithm.py` for details.
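As a rough illustration of that heuristic, intermediate chunk shapes can be interpolated geometrically between the source and target along each axis, so every stage changes chunk sizes by a constant multiplicative factor and no stage's chunks become disproportionately tiny. This is a simplified sketch with my own naming, not the exact logic in `algorithm.py`:

```python
def geometric_stage_chunks(source_chunks, target_chunks, num_stages):
    """Sketch: geometrically interpolate chunk shapes across stages.

    With num_stages stages there are num_stages - 1 intermediate shapes;
    axis i of stage k is source[i]**(1 - f) * target[i]**f for f = k/num_stages.
    """
    stages = []
    for stage in range(1, num_stages):
        frac = stage / num_stages
        stages.append(tuple(
            round(s ** (1 - frac) * t ** frac)
            for s, t in zip(source_chunks, target_chunks)
        ))
    return stages

# e.g. a 1-D rechunk from chunks of 1 to chunks of 10000 in four stages
# passes through chunk sizes 10, 100 and 1000:
geometric_stage_chunks((1,), (10000,), num_stages=4)  # -> [(10,), (100,), (1000,)]
```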
## Example
My specific motivation for this PR is experimenting with rechunking on Pangeo's ERA5 single-level dataset, which contains a number of 1.5 TB variables (at least once decoded into float32). Rechunking these arrays with `shape=(350640, 721, 1440)` from "whole image" chunks `(31, 721, 1440)` to "whole time-series" chunks `(350640, 10, 10)` with Rechunker's current algorithm produces a very large number of small chunks. It works, but seems much slower than it should be.
I wrote a little Python script to compare Rechunker's current method (`min_mem=0`) with my proposed multi-stage method (`min_mem=10MB`) and Dask's rechunking method:
<details>
```python
import sys

import dask.array  # imported so that 'dask.array.rechunk' lands in sys.modules
import numpy as np

from rechunker import algorithm


def evaluate_stage_v2(shape, read_chunks, int_chunks, write_chunks):
    """Count read/write tasks for one stage of a rechunking plan."""
    tasks = algorithm.calculate_single_stage_io_ops(shape, read_chunks, write_chunks)
    # A stage with identical read and write chunks needs no reads; one with
    # identical read and intermediate chunks needs no writes.
    read_tasks = tasks if write_chunks != read_chunks else 0
    write_tasks = tasks if read_chunks != int_chunks else 0
    return read_tasks, write_tasks


def evaluate_plan(stages, shape, itemsize):
    total_reads = 0
    total_writes = 0
    for stage in stages:
        read_chunks, int_chunks, write_chunks = stage
        read_tasks, write_tasks = evaluate_stage_v2(
            shape, read_chunks, int_chunks, write_chunks,
        )
        total_reads += read_tasks
        total_writes += write_tasks
    return total_reads, total_writes


def print_summary(stages, shape, itemsize):
    for i, stage in enumerate(stages):
        print(f"stage={i}: " + " -> ".join(map(str, stage)))
        read_chunks, int_chunks, write_chunks = stage
        read_tasks, write_tasks = evaluate_stage_v2(
            shape, read_chunks, int_chunks, write_chunks,
        )
        print(f"  Tasks: {read_tasks} reads, {write_tasks} writes")
        print(f"  Split chunks: {itemsize*np.prod(int_chunks)/1e6:1.3f} MB")
    total_reads, total_writes = evaluate_plan(stages, shape, itemsize)
    print("Overall:")
    print(f"  Reads count: {total_reads:1.3e}")
    print(f"  Write count: {total_writes:1.3e}")


# dask.array.rechunk is the rechunk *function*, which shadows the submodule of
# the same name, so fetch the module itself out of sys.modules.
rechunk_module = sys.modules['dask.array.rechunk']


def dask_plan(shape, source_chunks, target_chunks, threshold=None):
    source_expanded = rechunk_module.normalize_chunks(source_chunks, shape)
    target_expanded = rechunk_module.normalize_chunks(target_chunks, shape)
    # Note: itemsize seems to be ignored, by default
    stages = rechunk_module.plan_rechunk(
        source_expanded, target_expanded, threshold=threshold, itemsize=4,
    )
    write_chunks = [tuple(s[0] for s in stage) for stage in stages]
    read_chunks = [source_chunks] + write_chunks[:-1]
    int_chunks = [algorithm._calculate_shared_chunks(r, w)
                  for r, w in zip(write_chunks, read_chunks)]
    return list(zip(read_chunks, int_chunks, write_chunks))


def rechunker_plan(shape, source_chunks, target_chunks, **kwargs):
    stages = algorithm.multistage_rechunking_plan(
        shape, source_chunks, target_chunks, **kwargs
    )
    # Add the initial combine step and the final split step.
    return (
        [(source_chunks, source_chunks, stages[0][0])]
        + list(stages)
        + [(stages[-1][-1], target_chunks, target_chunks)]
    )


itemsize = 4
shape = (350640, 721, 1440)
source_chunks = (31, 721, 1440)
target_chunks = (350640, 10, 10)

print(f"Total size: {itemsize*np.prod(shape)/1e12:.3} TB")
print(f"Source chunk count: {np.prod(shape)/np.prod(source_chunks):1.3e}")
print(f"Target chunk count: {np.prod(shape)/np.prod(target_chunks):1.3e}")
print()

print("Rechunker plan (min_mem=0, max_mem=500 MB):")
plan = rechunker_plan(
    shape, source_chunks, target_chunks, itemsize=4, min_mem=0, max_mem=int(500e6)
)
print_summary(plan, shape, itemsize=4)
print()

print("Rechunker plan (min_mem=10 MB, max_mem=500 MB):")
plan = rechunker_plan(
    shape, source_chunks, target_chunks, itemsize=4, min_mem=int(10e6),
    max_mem=int(500e6)
)
print_summary(plan, shape, itemsize=4)
print()

print("Dask plan (default):")
plan = dask_plan(shape, source_chunks, target_chunks)
print_summary(plan, shape, itemsize=4)
```
</details>
```
Total size: 1.46 TB
Source chunk count: 1.131e+04
Target chunk count: 1.038e+04

Rechunker plan (min_mem=0, max_mem=500 MB):
stage=0: (31, 721, 1440) -> (31, 721, 1440) -> (93, 721, 1440)
  Tasks: 11311 reads, 0 writes
  Split chunks: 128.742 MB
stage=1: (93, 721, 1440) -> (93, 10, 30) -> (350640, 10, 30)
  Tasks: 13213584 reads, 13213584 writes
  Split chunks: 0.112 MB
stage=2: (350640, 10, 30) -> (350640, 10, 10) -> (350640, 10, 10)
  Tasks: 10512 reads, 10512 writes
  Split chunks: 140.256 MB
Overall:
  Reads count: 1.324e+07
  Write count: 1.322e+07

Rechunker plan (min_mem=10 MB, max_mem=500 MB):
stage=0: (31, 721, 1440) -> (31, 721, 1440) -> (93, 721, 1440)
  Tasks: 11311 reads, 0 writes
  Split chunks: 128.742 MB
stage=1: (93, 721, 1440) -> (93, 173, 396) -> (1447, 173, 396)
  Tasks: 80220 reads, 80220 writes
  Split chunks: 25.485 MB
stage=2: (1447, 173, 396) -> (1447, 41, 109) -> (22528, 41, 109)
  Tasks: 96492 reads, 96492 writes
  Split chunks: 25.867 MB
stage=3: (22528, 41, 109) -> (22528, 10, 30) -> (350640, 10, 30)
  Tasks: 86864 reads, 86864 writes
  Split chunks: 27.034 MB
stage=4: (350640, 10, 30) -> (350640, 10, 10) -> (350640, 10, 10)
  Tasks: 10512 reads, 10512 writes
  Split chunks: 140.256 MB
Overall:
  Reads count: 2.854e+05
  Write count: 2.741e+05

Dask plan (default):
stage=0: (31, 721, 1440) -> (31, 721, 1440) -> (32, 721, 1440)
  Tasks: 21915 reads, 0 writes
  Split chunks: 128.742 MB
stage=1: (32, 721, 1440) -> (32, 160, 480) -> (302, 160, 480)
  Tasks: 180705 reads, 180705 writes
  Split chunks: 9.830 MB
stage=2: (302, 160, 480) -> (302, 80, 150) -> (2922, 80, 150)
  Tasks: 153720 reads, 153720 writes
  Split chunks: 14.496 MB
stage=3: (2922, 80, 150) -> (2922, 20, 60) -> (29220, 20, 60)
  Tasks: 128760 reads, 128760 writes
  Split chunks: 14.026 MB
stage=4: (29220, 20, 60) -> (29220, 10, 10) -> (350640, 10, 10)
  Tasks: 126144 reads, 126144 writes
  Split chunks: 11.688 MB
Overall:
  Reads count: 6.112e+05
  Write count: 5.893e+05
```
Comparing my new multi-stage algorithm (`min_mem=10MB`) to Rechunker's existing algorithm (`min_mem=0`), the multi-stage pipeline performs two extra dataset copies, but reduces the number of IO operations by ~50x.
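For intuition on where these IO-operation counts come from: each stage costs one read and one write per overlapping (read chunk, write chunk) pair. Here is a hypothetical 1-D stand-in for rechunker's `calculate_single_stage_io_ops` (the name and simplifications are my own, not the library's code):

```python
import math

def count_io_ops_1d(length, read_size, write_size):
    """Count (read chunk, write chunk) pairs whose extents overlap."""
    n_reads = math.ceil(length / read_size)
    ops = 0
    for i in range(n_reads):
        start = i * read_size
        stop = min(start + read_size, length)
        # Indices of the first and last write chunks touching [start, stop).
        first = start // write_size
        last = (stop - 1) // write_size
        ops += last - first + 1
    return ops

count_io_ops_1d(10, 5, 5)  # aligned chunks: 2 pairs
count_io_ops_1d(10, 2, 5)  # misaligned: 6 pairs
```

When the read and write chunks are wildly different shapes, nearly every read chunk overlaps nearly every write chunk, which is why the single-stage plan above needs ~1.3e7 operations.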
Comparing my new algorithm to Dask's, the plans actually look remarkably similar. My estimates suggest that my algorithm should involve about half the number of IO operations, but Dask's plan uses slightly "nicer" chunk sizes. I have no idea which is better in practice, and note that I'm using Dask's algorithm without adjusting any of its control knobs.
I have not yet benchmarked any of these algorithms on real rechunking tasks.