Rechunking large data at constant memory in Dask [experimental]

Hi everyone,

With release 2023.3.1, Dask offers a new method for rechunking arrays called P2P that rechunks large arrays at constant memory. This enables workloads that previously could not run at all or required you to rechunk your datasets ahead of time using tools like rechunker.

Early benchmarks on P2P rechunking look very promising:

  • Cluster memory usage remains constant over the entire duration of the rechunk
  • Many workloads that were previously impossible to run now succeed smoothly (not on the plot)

To help us improve on the current implementation, we would love for you to try this out and give us feedback on this topic!

The P2P rechunk method is currently available as an experimental feature in Dask. To use it, make sure to install distributed directly from GitHub using

pip install git+https://github.com/dask/distributed@main

to benefit from the latest improvements. You also need to wrap your code with the required Dask configuration:

# Note: "optimization.fuse.active" must be False for this to work for now 
with dask.config.set({"array.rechunk.method": "p2p", "optimization.fuse.active": False}):
	arr.rechunk(...)
	...
	client.compute(...) # Make sure to include the call to compute()
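
For context, a minimal end-to-end sketch could look like the following (the array shape, chunk sizes, and local client are illustrative assumptions, not recommendations):

import dask
import dask.array as da
from distributed import Client

client = Client()  # or connect to your existing cluster

# Illustrative input: a ~3 GiB array chunked along the first axis
arr = da.random.random((400, 1000, 1000), chunks=(1, 1000, 1000))

with dask.config.set({"array.rechunk.method": "p2p", "optimization.fuse.active": False}):
    rechunked = arr.rechunk((400, 100, 100))
    future = client.compute(rechunked.sum())  # compute() is triggered inside the context
    print(future.result())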

As an experimental feature, there are a few known caveats:

  • Long initialization time: When the rechunk first starts, it takes some time to initialize state, during which the dashboard looks frozen.
  • Memory usage with many chunks: While P2P rechunking only requires constant memory, it currently has a large memory overhead if you use a very large number of chunks (cf. test_rechunk_in_memory[small]).
  • da.store: Currently, P2P rechunking is incompatible with da.store due to dask/dask#10074. This may affect many storage APIs like to_zarr and others.

To help us, we are particularly interested in the following:

  • Which workloads do you typically run?
    • For example: # workers, worker CPU/memory, input chunks, output chunks
  • What is important for you when running them?
    • For example: memory usage, runtime performance, performance predictability
  • How do they perform?
    • Performance reports, memory samples, and dashboard screenshots are greatly appreciated!

To learn more, take a look at the pull request dask/distributed#7534 on GitHub or our blog post about P2P shuffling and rechunking at https://blog.coiled.io/blog/shuffling-large-data-at-constant-memory.


Thanks for sharing! I'm trying this out on a somewhat common workload that should stress exactly this situation: making a cloud-free mosaic of many satellite images by taking a median over time. We have a 4-D array (time, y, x, band) and want to take the median over time. With satellite imagery, the source array essentially always has a chunksize of time=1, and may or may not be chunked along the other dimensions.

To do the median computation, dask.array needs to rechunk the array to be contiguous in time. Essentially, the chunking on disk is exactly wrong for this operation, so we need to rechunk a lot of data.
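
In dask.array terms, the pattern is roughly the following (shapes and chunk sizes are made up for illustration; the real data comes from satellite imagery readers):

import dask
import dask.array as da

# (time, y, x, band): one time step per chunk, as is typical for satellite stacks
data = da.random.random((100, 2048, 2048, 4), chunks=(1, 1024, 1024, 4))

with dask.config.set({"array.rechunk.method": "p2p", "optimization.fuse.active": False}):
    # Make the array contiguous along time so the per-pixel median can be computed
    contiguous = data.rechunk((-1, 512, 512, 4))
    mosaic = da.median(contiguous, axis=0)  # cloud-free mosaic: median over time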

My first couple of attempts have consistently failed with a CancelledError, resulting from a TimeoutError.

---------------------------------------------------------------------------
CancelledError                            Traceback (most recent call last)
File /srv/conda/envs/notebook/lib/python3.10/asyncio/tasks.py:418, in wait_for()
    417 try:
--> 418     return fut.result()
    419 except exceptions.CancelledError as exc:

CancelledError: 

The above exception was the direct cause of the following exception:

TimeoutError                              Traceback (most recent call last)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/comm/core.py:330, in connect()
    329     handshake = await wait_for(comm.read(), time_left())
--> 330     await wait_for(comm.write(local_info), time_left())
    331 except Exception as exc:

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/utils.py:1812, in wait_for()
   1811 async def wait_for(fut: Awaitable[T], timeout: float) -> T:
-> 1812     return await asyncio.wait_for(fut, timeout)

File /srv/conda/envs/notebook/lib/python3.10/asyncio/tasks.py:420, in wait_for()
    419     except exceptions.CancelledError as exc:
--> 420         raise exceptions.TimeoutError() from exc
    422 waiter = loop.create_future()

TimeoutError: 

The above exception was the direct cause of the following exception:

OSError                                   Traceback (most recent call last)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_rechunk.py:41, in rechunk_transfer()
     40 try:
---> 41     return _get_worker_extension().add_partition(
     42         input,
     43         input_partition=input_chunk,
     44         shuffle_id=id,
     45         type=ShuffleType.ARRAY_RECHUNK,
     46         new=new,
     47         old=old,
     48     )
     49 except Exception as e:

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_worker_extension.py:632, in add_partition()
    631 shuffle = self.get_or_create_shuffle(shuffle_id, type=type, **kwargs)
--> 632 return sync(
    633     self.worker.loop,
    634     shuffle.add_partition,
    635     data=data,
    636     input_partition=input_partition,
    637 )

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/utils.py:412, in sync()
    411     typ, exc, tb = error
--> 412     raise exc.with_traceback(tb)
    413 else:

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/utils.py:385, in f()
    384     future = asyncio.ensure_future(future)
--> 385     result = yield future
    386 except Exception:

File /srv/conda/envs/notebook/lib/python3.10/site-packages/tornado/gen.py:769, in run()
    768 try:
--> 769     value = future.result()
    770 except Exception:

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_worker_extension.py:354, in add_partition()
    353 out = await self.offload(_)
--> 354 await self._write_to_comm(out)
    355 return self.run_id

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_worker_extension.py:152, in _write_to_comm()
    151 self.raise_if_closed()
--> 152 await self._comm_buffer.write(data)

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_buffer.py:189, in write()
    188 if self._exception:
--> 189     raise self._exception
    190 if not self._accepts_input or self._inputs_done:

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_buffer.py:107, in process()
    106 try:
--> 107     await self._process(id, shards)
    108     self.bytes_written += size

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_comms.py:71, in _process()
     70 with self.time("send"):
---> 71     await self.send(address, shards)

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_worker_extension.py:123, in send()
    122 self.raise_if_closed()
--> 123 return await self.rpc(address).shuffle_receive(
    124     data=to_serialize(shards),
    125     shuffle_id=self.id,
    126     run_id=self.run_id,
    127 )

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/core.py:1231, in send_recv_from_rpc()
   1230     kwargs["deserializers"] = self.deserializers
-> 1231 comm = await self.pool.connect(self.addr)
   1232 prev_name, comm.name = comm.name, "ConnectionPool." + key

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/core.py:1475, in connect()
   1474     raise
-> 1475 return await connect_attempt

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/core.py:1396, in _connect()
   1395 self._connecting_count += 1
-> 1396 comm = await connect(
   1397     addr,
   1398     timeout=timeout or self.timeout,
   1399     deserialize=self.deserialize,
   1400     **self.connection_args,
   1401 )
   1402 comm.name = "ConnectionPool"

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/comm/core.py:334, in connect()
    333         await comm.close()
--> 334     raise OSError(
    335         f"Timed out during handshake while connecting to {addr} after {timeout} s"
    336     ) from exc
    338 comm.remote_info = handshake

OSError: Timed out during handshake while connecting to tls://10.244.9.6:42079 after 30 s

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[6], line 8
      6 with distributed.performance_report("new.html"):
      7     with dask.config.set({"array.rechunk.method": "p2p", "optimization.fuse.active": False}):
----> 8         median = data.median(dim="time").compute()

File /srv/conda/envs/notebook/lib/python3.10/site-packages/xarray/core/dataarray.py:1089, in DataArray.compute(self, **kwargs)
   1070 """Manually trigger loading of this array's data from disk or a
   1071 remote source into memory and return a new array. The original is
   1072 left unaltered.
   (...)
   1086 dask.compute
   1087 """
   1088 new = self.copy(deep=False)
-> 1089 return new.load(**kwargs)

File /srv/conda/envs/notebook/lib/python3.10/site-packages/xarray/core/dataarray.py:1063, in DataArray.load(self, **kwargs)
   1045 def load(self: T_DataArray, **kwargs) -> T_DataArray:
   1046     """Manually trigger loading of this array's data from disk or a
   1047     remote source into memory and return this array.
   1048 
   (...)
   1061     dask.compute
   1062     """
-> 1063     ds = self._to_temp_dataset().load(**kwargs)
   1064     new = self._from_temp_dataset(ds)
   1065     self._variable = new._variable

File /srv/conda/envs/notebook/lib/python3.10/site-packages/xarray/core/dataset.py:746, in Dataset.load(self, **kwargs)
    743 import dask.array as da
    745 # evaluate all the dask arrays simultaneously
--> 746 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    748 for k, data in zip(lazy_data, evaluated_data):
    749     self.variables[k].data = data

File /srv/conda/envs/notebook/lib/python3.10/site-packages/dask/base.py:599, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    596     keys.append(x.__dask_keys__())
    597     postcomputes.append(x.__dask_postcompute__())
--> 599 results = schedule(dsk, keys, **kwargs)
    600 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/client.py:3168, in Client.get(self, dsk, keys, workers, allow_other_workers, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   3166         should_rejoin = False
   3167 try:
-> 3168     results = self.gather(packed, asynchronous=asynchronous, direct=direct)
   3169 finally:
   3170     for f in futures.values():

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/client.py:2328, in Client.gather(self, futures, errors, direct, asynchronous)
   2326 else:
   2327     local_worker = None
-> 2328 return self.sync(
   2329     self._gather,
   2330     futures,
   2331     errors=errors,
   2332     direct=direct,
   2333     local_worker=local_worker,
   2334     asynchronous=asynchronous,
   2335 )

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/utils.py:345, in SyncMethodMixin.sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    343     return future
    344 else:
--> 345     return sync(
    346         self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    347     )

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/utils.py:412, in sync(loop, func, callback_timeout, *args, **kwargs)
    410 if error:
    411     typ, exc, tb = error
--> 412     raise exc.with_traceback(tb)
    413 else:
    414     return result

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/utils.py:385, in sync.<locals>.f()
    383         future = wait_for(future, callback_timeout)
    384     future = asyncio.ensure_future(future)
--> 385     result = yield future
    386 except Exception:
    387     error = sys.exc_info()

File /srv/conda/envs/notebook/lib/python3.10/site-packages/tornado/gen.py:769, in Runner.run(self)
    766 exc_info = None
    768 try:
--> 769     value = future.result()
    770 except Exception:
    771     exc_info = sys.exc_info()

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/client.py:2191, in Client._gather(self, futures, errors, direct, local_worker)
   2189         exc = CancelledError(key)
   2190     else:
-> 2191         raise exception.with_traceback(traceback)
   2192     raise exc
   2193 if errors == "skip":

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_rechunk.py:50, in rechunk_transfer()
     41     return _get_worker_extension().add_partition(
     42         input,
     43         input_partition=input_chunk,
   (...)
     47         old=old,
     48     )
     49 except Exception as e:
---> 50     raise RuntimeError(f"rechunk_transfer failed during shuffle {id}") from e

RuntimeError: rechunk_transfer failed during shuffle cc9531b7820cc766da633788babf811c

I'm using this notebook. I suspect I'll be able to simplify things. I'm going to dig through the logs first to see what's going on (I don't think a worker is dying, but will confirm that).

Edit: Here are the logs from the worker that seemed to run into issues first: shuffle-logs.txt · GitHub. Things start OK, then we get some warnings about the GIL, and then we get the note that a TCP connection failed.

2023-03-20 15:14:46,236 - distributed.core - INFO - Starting established connection to tls://dask-351d74e2f7f644beb571a17408e81902.staging:8786
2023-03-20 15:14:59,341 - distributed.core - INFO - Event loop was unresponsive in Worker for 9.18s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.
2023-03-20 15:15:00,410 - distributed.utils_perf - INFO - full garbage collection released 47.80 MiB from 0 reference cycles (threshold: 9.54 MiB)
2023-03-20 15:15:03,907 - distributed.core - INFO - Event loop was unresponsive in Worker for 4.53s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.
2023-03-20 15:15:43,799 - distributed.core - INFO - Event loop was unresponsive in Worker for 3.59s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.
2023-03-20 15:15:48,464 - distributed.core - INFO - Event loop was unresponsive in Worker for 4.66s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.
2023-03-20 15:15:51,847 - distributed.core - INFO - Event loop was unresponsive in Worker for 3.38s.  This is often caused by long-running GIL-holding functions or moving large chunks of data. This can cause timeouts and instability.
2023-03-20 15:16:29,763 - distributed.comm.tcp - INFO - Connection from tls://10.244.10.9:44336 closed before handshake completed
2023-03-20 15:18:05,967 - distributed.shuffle._comms - ERROR - Shuffle cc9531b7820cc766da633788babf811c forgotten
Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_comms.py", line 71, in _process
    await self.send(address, shards)
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_worker_extension.py", line 122, in send
    self.raise_if_closed()
  File "/srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/shuffle/_worker_extension.py", line 163, in raise_if_closed
    raise self._exception
RuntimeError: Shuffle cc9531b7820cc766da633788babf811c forgotten

I guess you've tried with a small spatial extent, to see whether a smaller problem size does work?


Thanks for trying this out, and thank you for this excellent example. I ran this example on Coiled and it finished successfully, both with the new P2P rechunking and with the traditional task-based algorithm.

This is a memory sample of the runs:
[figure: rechunk_memory_sample, showing cluster memory over time for the P2P and task-based runs]

In terms of wall-time performance, P2P wins over tasks, taking only about 60% of the time. Investigating this a bit closer, the cluster is actually not doing anything for a very long time until the entire graph is submitted (the task-based graph has about 135k nodes while P2P only requires 13k). Adjusting for this, P2P and tasks require about the same time on the cluster, but P2P wins because the graph is much smaller, i.e. the computation starts much earlier.
The memory usage, however, is drastically different since P2P uses disk much more efficiently to keep the memory footprint low. There is a lot of room for improvement here. For instance, I noticed that the CPUs on the P2P run are only 50% utilized because the rasterio calls are actually rather slow and are bottlenecking, i.e. we could likely double the number of threads easily and get a nice performance kick out of this.

Now, this was running on Coiled on AWS using m6i.large instances. During early testing we saw similar connection failures that were correlated with "Event loop was unresponsive" messages. We particularly saw this on burstable VMs (performance is rather unpredictable, other VMs can steal CPU time, etc.). If you have control over the VM types you run on, this might be something to check, but either way this is something for us to figure out and make more robust.

As a temporary workaround, I suggest increasing the connection timeout, e.g.

dask.config.set({
    "distributed.comm.timeouts.connect": "60s"
})  # Or more generously 90s, 120s, etc.
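
If it is more convenient to set this outside the script, the same value can also be provided through Dask's environment variable configuration, e.g. DASK_DISTRIBUTED__COMM__TIMEOUTS__CONNECT="60s".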

Note about the data size: this example moves ~370 GiB. All workers combined need to have at least that much disk mounted, ideally a bit more since there is unfortunately some overhead. With the 16 m6i.large workers I came close to the disk limit. This might explain a killed worker, if that is indeed your problem.
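
For a rough sense of scale: ~370 GiB across 16 workers comes out to roughly 23 GiB of shuffle data per worker before overhead, so each worker's disk needs to be comfortably above that.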


Thanks Florian! I'll give it another shot with longer timeouts to confirm that that workaround fixes the issue I saw.

Here's my mental model for what might have happened on the cluster: The workers are putting more strain on the network (perhaps P2P uses more concurrent connections by design, or perhaps workers are just making better use of the hardware and so are able to do more on the network side). At some point the network got so congested that a worker connection (maybe worker -> worker, scheduler -> worker, or worker -> scheduler) timed out, and so the whole operation was (gracefully!) cancelled.

Does that roughly align with what you're thinking?

While it might be nice to have the P2P shuffle automatically consider that worker as suspect / dead (even if it isn't) and route around it, it's somewhat nice to loudly know that the cluster's network is becoming a bottleneck.

It's great to see the improvements here! I'll try this out again with higher timeouts, and will see what I can do on the network side.


As an aside for this specific computation, I'm reminded of an idea to push rechunking down into the I/O layer. Some file formats (like COGs in this case) support reading subsets of the file since they're internally chunked. It's perhaps a bit wasteful for Dask to download the data from Blob Storage, store an intermediate to disk, and then shuffle data between workers. In this case we might have been able to do the original reads from Blob Storage using our desired chunks by rewriting the I/O operation.

Whether or not that would be faster depends on a bunch of factors, but it's fun to think about (Pushing array rechunking into the I/O layer · Issue #8526 · dask/dask · GitHub, which I now see is a duplicate of Pushdown of slicing operator on arrays? · Issue #6288 · dask/dask · GitHub). And it's only an option when the rechunking happens immediately after reading (or perhaps only after operations that don't depend on the chunking structure).

Well, yes and no. The strain that is causing these connection failures stems from load on the servers/workers; specifically, it is about a blocked asyncio event loop (that's what the message is all about). You can see that the event loop is occasionally blocked for 3-4 s on that worker. Establishing a connection currently requires about "six ticks", i.e. if every tick takes 4 seconds we're already at 24 s. Add a bit of slow network and we're easily at the 30 s mark that aborts a connection attempt.
This is typically not "visible" because the ordinary worker<->worker connections retry, but P2P doesn't (yet).
The fix for this is

Increasing the timeout is a brute-force fix for this... I typically rather strongly discourage increasing this value since it can cause the system to appear sluggish in other circumstances.

While it might be nice to have the P2P shuffle automatically consider that worker as suspect / dead (even if it isn't) and route around it, it's somewhat nice to loudly know that the cluster's network is becoming a bottleneck.

Actually a P2P shuffle/rechunk will fail loudly if a worker dies during the operation. We haven't implemented graceful recovery from this, yet.

I implemented something like this once for a parquet shuffle. It works but requires quite a bit of fine-tuning to work efficiently (buffer sizes, row group sizes, etc.).

If I'm understanding you correctly, passing something like chunks=(-1, 1, 512, 512) to stackstac.stack will already do just this. We're not talking about a full shuffle here, just loading all the items from the same spatial region into one chunk. (I actually haven't tested this since the memory improvements in distributed; I'd be curious if that helps at all.)

Of course this doesn't happen automatically right now; you have to know to do it. I would love to do this automatically for users, but doing that with dask's current optimization framework is very hard. High Level Expressions · Issue #7933 · dask/dask · GitHub will help a lot with this. Then we can easily rewrite the upstream stack operation if we see a rechunk that comes after it, along with lots of other optimizations that aren't currently possible.

I took this for a spin today and had some sad experiences with dask & distributed 2024.5.1

I tried to rechunk a decade of the ERA-5 hourly dataset (weatherbench2). The input data are spatial maps, i.e. one time step per chunk.

I'm rechunking a decade of data so that approximately a year of data is in one chunk (time=8760). This does not succeed.

I ran this on Coiled and received "OSError(28, 'No space left on device')", which I don't understand. The dashboard offered no clues.

It does succeed if I only choose one year of data, but not five years.

Some coiled logs are here:

  1. Coiled
  2. Coiled

These errors are quite reproducible.

Reproducer script:

import coiled
from distributed import Client
import dask
import xarray as xr

output_chunks = {"time": 8760, "latitude": 72, "longitude": 144}
slicer = {"time": slice("2010", "2020")}
bucket = "YOUR_BUCKET_NAME"
prefix = "rechunked"

# public era-5 dataset
ds = xr.open_zarr(
    "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr"
)

dask.config.set({"array.rechunk.method": "p2p"})

cluster = coiled.Cluster()
cluster.adapt(maximum=200)
client = Client(cluster)
print("Dask Dashboard: ", client.dashboard_link)

subset = ds[["2m_temperature"]].sel(slicer)
chunk_str = "_".join(str(output_chunks[dim]) for dim in ["time", "latitude", "longitude"])
store_name = f"{subset.sizes['time']}_{chunk_str}.zarr"
(
    subset.chunk(output_chunks)
    .drop_encoding()
    .to_zarr(f"{bucket}/{prefix}/{store_name}", mode="w")
)

Thanks Deepak. We'll look into this. I opened an upstream issue for this here.

There are a couple of things we have to look into, but for starters, please do not use adapt when testing P2P. The way P2P is built, it locks in the "participating workers" once the rechunk starts, i.e. if the cluster scales up and new workers join, they will most likely just sit idle while the very few workers that started will overflow.

Edit: we could already confirm that the failure is indeed caused by the adaptive scaling. We'll have to think about ways to mitigate this and improve UX. If you use a static cluster (or one with at least a couple of workers / a min_workers value), you should be fine.
It's still slow and I am looking into this right now.
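
For example, something along these lines (the worker counts are placeholders, not a recommendation):

import coiled

# Static cluster sized up front, instead of adapt(maximum=...) starting from very few workers
cluster = coiled.Cluster(n_workers=100)

# Or, if adaptive scaling is needed, keep a reasonable floor of workers
# so that P2P locks in more than a handful of participants:
# cluster.adapt(minimum=100, maximum=200)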

Edit: We noticed that the graph submission took ages and have a fix for this: Improve graph submission time for P2P rechunking by avoiding unpack recursion into indices by fjetter · Pull Request #8672 · dask/distributed · GitHub. Dealing better with adaptivity is tracked here: Restart P2P if cluster significantly grew in size · Issue #8673 · dask/distributed · GitHub.


Thanks for the quick fixes! I can confirm it works well with a constant-sized cluster.