Approach for averaging an ensemble from nested Zarr archives?

tl;dr: Trying to average an ensemble of CMIP6 data and running into compute issues with Dask when creating the new, averaged dataset. Looking for best practices on how to approach this problem.

I have a large (~20TB) nested zarr archive that has the structure {model}/{variable}/{time_period}, where each group is its own consolidated zarr archive. The data has been rechunked such that it fits our use case of long time series for specific areas of interest.

An example of the tree:
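Something like this (placeholder names, following the structure described above; each leaf group is its own consolidated store):

{root}/
├── {model_1}/                   # 20 models
│   ├── {variable_1}/            # 9 variables per model
│   │   ├── {time_period_1}/     # 2 time periods, each a consolidated zarr store
│   │   └── {time_period_2}/
│   ├── {variable_2}/
│   │   └── ...
│   └── ...
├── {model_2}/
│   └── ...
└── ...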

Accessing the entire dataset for an individual location takes a long time, as we’re using xarray and need to read and merge 360 different datasets – model (20) x variable (9) x time_period (2). In many instances, we just want to take the ensemble average – ds = ds.mean(dim="model"). Therefore it would be beneficial to create a model-averaged zarr archive that we could access much faster.
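For reference, the access pattern looks roughly like this (a sketch only – the helper name and loop structure are illustrative, assuming consolidated stores at {root}/{model}/{variable}/{time_period}):

import xarray as xr

def open_ensemble(root, models, variables, time_periods):
    # open every {model}/{variable}/{time_period} group, concatenate the time
    # periods, merge the variables, and stack the models along a new dimension
    per_model = []
    for model in models:
        per_variable = []
        for var in variables:
            pieces = [
                xr.open_zarr(f"{root}/{model}/{var}/{period}", consolidated=True)
                for period in time_periods
            ]
            per_variable.append(xr.concat(pieces, dim="time"))
        per_model.append(xr.merge(per_variable))
    return xr.concat(per_model, dim="model").assign_coords(model=list(models))

# e.g. ds = open_ensemble("s3://my-bucket/archive", models, variables, time_periods)  # 360 stores
# ds_mean = ds.mean(dim="model")                                                      # ensemble average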

I attempted to do this using Dask’s cloudprovider setup, which I also used for rechunking; example code is below. This simply does not work because the compute graph seems to be too large. It takes a huge amount of memory (>64 GB) on the client before the graph is even sent to the scheduler, and even then the client just seems to hang and never computes anything, and I end up getting errors like distributed.worker - CRITICAL - Error trying close worker in response to broken internal state. Forcibly exiting worker NOW.

I assume I’m just not approaching this problem in the correct way, or perhaps I should be trying to use a tool other than xarray to accomplish this task. Any suggestions on a better way to approach creating an average across a large set of Zarr archives would be appreciated.

import datetime

from dask_cloudprovider.aws import FargateCluster
from dask.distributed import Client

# open the rechunked archive (user helper) and build the lazy ensemble mean
ds = get_data_from_rechunked_zarr()
ds_mean = ds.mean(dim="model")
ds_mean = ds_mean.chunk(dict(time=31411, lat=5, lon=5, scenario=-1))

time_string = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")  # just used to identify logs
cpu_spec = 1024
worker_mem = cpu_spec * 8
cluster = FargateCluster(
    n_workers=256,
    worker_cpu=cpu_spec,
    worker_mem=worker_mem,
    image="pangeo/pangeo-notebook",
    cloudwatch_logs_group=f"dask-climo-{time_string}",
    environment=get_aws_credentials(),
    scheduler_timeout="15 minutes",
)
client = Client(cluster)

zarr_destination = "s3://{my_zarr_destination}"
out = ds_mean.to_zarr(zarr_destination, mode="w", compute=False, consolidated=True)
fut = client.compute(out)

20 TB problem on a ~2 TB cluster (256 workers x 8 GB)?

Do you know how much data is going to each worker?

Will the mean for just 1 model work?
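E.g. something like this (hypothetical quick test, reusing the names and placeholder destination from the snippet above):

# subset to a single model and run the same pipeline end-to-end
ds_one = ds.isel(model=[0])
ds_one.mean(dim="model").chunk(dict(time=31411, lat=5, lon=5, scenario=-1)).to_zarr(
    "s3://{my_zarr_destination}-single-model-test", mode="w", consolidated=True
)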

First, we would need to see what the chunking of ds and ds_mean is before this rechunking step:

ds_mean = ds_mean.chunk(dict(time=31411, lat=5, lon=5, scenario=-1))
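E.g. printing these before the .chunk() call (a quick sketch):

print(ds.chunks)       # chunk layout of the merged source dataset
print(ds_mean.chunks)  # chunk layout after .mean(dim="model"), before .chunk()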

Then, I would try “p2p rechunking” and report here: Share your experiences with P2P shuffling · dask/distributed · Discussion #7509 · GitHub or in that Discourse thread.

For reference, the chunks of the original dataset are dict(time=31411, lat=5, lon=5, scenario=1, model=1) and I’m converting to dict(time=31411, lat=5, lon=5, scenario=-1), averaging across the model dimension.

Using p2p did help in getting this started. I ended up signing up for and using Coiled, since it makes spin-up and tear-down simpler. I was able to run the code on a small subset (2 models, 2 variables) using the code below. However, when I try to increase the size of the dataset (even with similar scaling of the cluster), I get a Dask error like: CancelledError: _finalize_store-629d0e32-d4c9-4380-9a08-4277617ab4ae. I think it may just have something to do with the size of the graph.

The scaled-down dataset that does run successfully has ~550,000 tasks sent to the cluster. The jobs that fail have over a million, and running on the entire dataset would mean nearly 25 million tasks. I’m going to dig in a little deeper and hopefully report back here, though any thoughts or suggestions are appreciated.

import dask

with dask.config.set(
    {
        "array.rechunk.method": "p2p",
        "optimization.fuse.active": False,
        "distributed.comm.timeouts.connect": "60s",
    }
):
    out = ds_mean.to_zarr(zarr_destination, mode="w", consolidated=True, compute=False)
    out.compute()

(Dask maintainer here :wave:)

Error trying close worker in response to broken internal state.

First of all, if you still have logs or anything that led to this, please open an upstream issue in dask/distributed. This should be impossible and is a very critical issue we’d like to fix (these things happen very rarely and are hard-to-reproduce race conditions).
If you don’t, not a big deal. These things are hard to reproduce and logging may not even be that helpful. If you run into this more frequently, let’s talk!

Using p2p did help in getting this started. I ended up signing up and using coiled.io

Thanks, this is actually helpful for us to debug the issue! I’ll try to have a look shortly (I’m in an EU timezone, approaching EOD).

For such a large graph, I suspect that disabling the fusing is already very helpful. This should reduce the memory footprint on the cluster side and kick things off a bit more quickly.

BTW, you should wrap your entire code with the dask.config setting. I suspect that you didn’t actually use P2P for the rechunking, and the only benefit you’ve seen is from disabling the fuse optimization.

Specifically, this line

ds_mean = ds_mean.chunk(dict(time=31411, lat=5, lon=5, scenario=-1))

has to be wrapped with the ctx manager.

Or you can just set the config globally using

dask.config.set(...) (i.e. without ctx manager but before you are generating the graph / calling chunk)

Similarly, the connect timeout should be set before you create the cluster. Ideally, you set all of this when you start up your notebook or run your script. This way you’re not accidentally setting it at the wrong time (the UX is not great here, that’s on us).
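For example, a minimal sketch of that (same settings as in the snippet above, set once at the top of the script):

import dask

# set the config once, globally, before creating the cluster and before
# calling .mean() / .chunk(), so the settings apply when the graph is built
dask.config.set(
    {
        "array.rechunk.method": "p2p",
        "optimization.fuse.active": False,
        "distributed.comm.timeouts.connect": "60s",
    }
)

# ... then create the cluster, open the data, build ds_mean, and call
# ds_mean.to_zarr(zarr_destination, mode="w", consolidated=True, compute=False)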