Best practice advice on parallel processing a suite of zarr files with Dask and xarray

Hello all, I have a suite of about 300 zarr files totaling roughly 40 TB that I'm looking for an efficient way to process in parallel. The files are structured as [ensemble, time, lat, lon], where I've added the ensemble dimension to the zarrs so they combine easily with xr.open_mfdataset(). This works well and gives you access to a very large, well-structured dataset, but it seems to only work for slicing and dicing and simple operations. If you try something complex, such as groupby operations, calculating percentiles, or almost anything that uses the xClim package, it seems to overwhelm Dask. Specifically, the Dask graph gets incredibly complicated, to the point that it either takes forever to resolve or just crashes.

I am on an HPC, so I've been able to get around this by using a SLURM job array with a Python script, with each instance of the script handling only one or two zarrs at a time. This works amazingly well since you can leverage the HPC resources in tandem, but I wonder if there is a similar approach using dask_jobqueue.SLURMCluster that could be used from a Jupyter Notebook. Perhaps the suite of zarrs could be processed in parallel using a Dask bag or delayed functions. Since it's on an HPC, I'd want to parallelize as much as possible, but I wonder whether processing large zarr files in xarray within a Dask bag gets you right back in the same boat: very complicated Dask graphs. I don't know if there is a Dask way to wall off each item in the bag (which is a bunch of xr.DataArray/Dask array processing), or if it ultimately comes down to creating one master Dask graph before any processing begins.

I'm sure the big data folks have run into this situation before, so I was wondering if there is any best-practice advice here. Ideally, I'd like to dedicate 32-64 CPUs to each zarr to calculate xClim metrics and then kick off a bunch of SLURM jobs (either via dask_jobqueue.SLURMCluster or SLURM scripts from the command line). Using a SLURM job array from the command line isolates the processing of each zarr, so I could continue with that, but I was wondering if there is an efficient, user-friendly way to do this from Jupyter that might be easier for novices.
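For reference, my current job-array approach looks roughly like this (the paths and the indicator call are simplified placeholders); each SLURM array task runs the script against a single zarr:

import glob
import os

import xarray as xr
import xclim

# One zarr store per SLURM array task, e.g. submitted with `sbatch --array=0-299 ...`
zarr_paths = sorted(glob.glob("/path/to/zarrs/*.zarr"))   # placeholder path
task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])
path = zarr_paths[task_id]

ds = xr.open_zarr(path)
out = xr.Dataset(
    dict(wetdays=xclim.atmos.wetdays(ds.pr, thresh="5 mm/d", freq="YS"))
)
out.to_zarr(f"/path/to/output/{os.path.basename(path)}", mode="w")   # placeholder output path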

Thanks

2 Likes

The details here really matter.

Do you have flox installed? Can you name the xclim metrics you're trying to calculate? Are the percentiles you're calculating grouped percentiles?

2 Likes

Yes, I have flox (0.10) installed. As a test, I used a subset of 72 concatenated zarr files and ran xclim.atmos.wetdays(), which has ~4 TB of input and ~40 GB of output. Most of the metrics I'm interested in are daily → annual counts (such as the number of days over XYZ). Even with a smaller subset, the Dask graph is still 1 GB and very sluggish until the workers start failing. I was using 512 threads and 1.5 TB of RAM, so it is not lacking for resources.

I suppose my question comes down to what are good strategies to use once the data become too large or the Dask graphs too complex. Are there ways to continue to use Dask, such as bags, or are you better off with many smaller instances of Python/Dask to keep the graph size from spiraling out? Dask must have some practical upper limits (which likely vary by problem type), so what are good approaches once you've reached those limits?

1 Like

I believe batching is the best short-term solution.
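Something like this untested sketch, run from a notebook, keeps each graph small by opening only one or two stores per iteration (the resource numbers, paths, and indicator call are placeholders based on your description):

import glob

import xarray as xr
import xclim
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

cluster = SLURMCluster(cores=32, processes=1, memory="128GB", walltime="04:00:00")  # placeholder resources
cluster.scale(8)
client = Client(cluster)

zarr_paths = sorted(glob.glob("/path/to/zarrs/*.zarr"))   # placeholder path
batch_size = 2                                            # one or two stores per graph

for i in range(0, len(zarr_paths), batch_size):
    batch = zarr_paths[i : i + batch_size]
    ds = xr.open_mfdataset(batch, engine="zarr")
    out = xr.Dataset(dict(wetdays=xclim.atmos.wetdays(ds.pr, freq="YS")))
    # Each to_zarr call builds and computes its own small, independent graph
    out.to_zarr(f"/path/to/output/batch_{i:03d}.zarr", mode="w")   # placeholder output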

Most of the metrics I'm interested in are daily → annual counts (such as the number of days over XYZ). Even with a smaller subset, the Dask graph is still 1 GB and very sluggish until the workers start failing.

Usually this kind of observation means there is a really bad inefficiency somewhere in the stack and it takes some work to track down (example).

The code for xclim.atmos.wetdays() is hard to follow, so I'm not sure what is going on. Can you create a minimal example with the same shape and chunk sizes as your "big array" to illustrate the problem, please?

I think this is as close as I can get with synthetic data.

import numpy as np
import dask.array as da
import xarray as xr
import xclim  # importing xclim also registers the xclim.atmos indicators used below
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# One single-process worker with 16 threads per SLURM job
cluster = SLURMCluster(cores=16,
                       processes=1,
                       memory="64GB",
                       account="gmegsc",
                       walltime="01:00:00",
                       queue="cpu")
cluster.scale(64)  # 64 jobs -> 1024 threads in total
client = Client(cluster)

scale = 20.

# Array dimensions: 72 ensemble members of daily data (2015-2100) on a 0.0625-degree grid
n_lon = 944
n_lat = 474
n_time = 31411
n_ensemble = 72

time = xr.date_range(start="2015-01-01", end="2100-12-31", freq='D', use_cftime=True)
lat = np.arange(23.90625, 53.46875+0.0625, 0.0625)
lon = np.arange(234.53125, 293.46875+0.0625, 0.0625)
ensemble = np.arange(n_ensemble)+1

# Chunk sizes: one ensemble member and 468 time steps per chunk (~30 MB blocks)
c_lon = 118
c_lat = 158
c_time = 468
c_ensemble = 1


ds = xr.Dataset(
    data_vars=dict(
        pr=(
            ["ensemble", "time", "lat", "lon"],
            (da.random.random(
                (n_ensemble, n_time, n_lat, n_lon),
                chunks=(c_ensemble, c_time, c_lat, c_lon),
            ) * scale).astype('float32'),
            dict(units="mm/d"),
        )
    ),
    coords=dict(
        lat=lat,
        lon=lon,
        time=time,
        ensemble=ensemble,
    ),
)


# Annual wet-day count; rechunk so the (much smaller) output has a single time chunk
ds_out = xr.Dataset(
    data_vars=dict(
        wetdays=xclim.atmos.wetdays(ds.pr, thresh='5 mm/d', freq='YS').assign_attrs({"units": "d"}),
    )
).chunk({"time": -1})

ds_out.to_zarr("large_xclim_sample.zarr")

client.close()
cluster.close()

This mimics appending 72 zarr files along an ensemble dimension. The real task graph is probably a bit more complicated because of the join, but this is close enough. This sample gives a 750 MB task-graph-size warning, which is lower than the 900 MB to 1 GB from the real data. My output files also have multiple xclim output variables in ds_out, which complicates the Dask graph further.

I ran this sample on our HPC with 1028 threads. It took about ten minutes to start and began dropping workers during the wait, but it did complete. It finished after about 20 minutes (~10 minutes of that just to resolve the task graph), with 8 workers dropping out along the way. The real version with the zarr data and a second variable fails.

So far my solution has been to break this up and process one zarr at a time (so ensemble=1) with a SLURM job array. That works very well, but running it via SLURMCluster from a Jupyter Notebook would also be nice if that were possible and practical.

Thanks

Yeah, the core problem is in flox: a simple ds.resample(time="YS").sum().compute() raises a warning about a 300 MB graph. It's embedding data that gets duplicated a lot because of the small chunk size.

Since you are calculating annual statistics and your input chunks are small (30 MB), let's rechunk (ideally you'd do this in your open_dataset call). You can use Xarray's new-ish TimeResampler objects to rechunk to a time frequency (slick!):

from xarray.groupers import TimeResampler

ds = ds.chunk(time=TimeResampler("10YS"))  # aim for 200-300 MB chunks

Once you do this, the problem is embarrassingly parallel. Your previous chunk size (468) did not match the "yearly" frequency, so there was some inter-block communication required.
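Applied to your synthetic example, the change is just this (the output store name is a placeholder):

from xarray.groupers import TimeResampler

# Each block now holds whole calendar years (about ten), so every annual bin
# is computed entirely within a single chunk and no inter-block communication is needed
ds = ds.chunk(time=TimeResampler("10YS"))

ds_out = xr.Dataset(
    dict(wetdays=xclim.atmos.wetdays(ds.pr, thresh="5 mm/d", freq="YS").assign_attrs(units="d"))
)
ds_out.to_zarr("large_xclim_sample_rechunked.zarr", mode="w")   # placeholder output name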

For a long time, I’ve wanted to figure out automated heuristics for this but haven’t had time to do so. (hah, I even started prototyping some automated rechunking here but never finished)

5 Likes

Thanks, I'll try the TimeResampler. I set the time chunk to 468 to reach a reasonable block size (30 MB) for data that will be served on a cloud store eventually. These data include leap years, so it wasn't as easy as just setting the chunk size to 365.
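If I understand it correctly, something like this should take care of the leap years for me automatically (untested):

from xarray.groupers import TimeResampler

# Frequency-based chunks follow the calendar, so leap years get 366-day chunks
chunked = ds.chunk(time=TimeResampler("YS"))
print(chunked.pr.chunksizes["time"][:4])   # expect (365, 366, 365, 365) for 2015-2018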

1 Like

Wow, thanks @dcherian! Talk about a teachable moment. With 10-year time chunks, the synthetic example dropped down to running in about 7 minutes with a much smaller task graph warning. I then ran the real data with a 25-year chunk, and the task graph warning went from about 1 GB to 250 MB. It's on an HPC, so the chunk size could increase to ~650 MB. The real data ran successfully in 3.5 minutes, though it pushed up against the memory cap a few times, so maybe a 20-year chunk would be a better target.

Thanks again

5 Likes

💪

Going back to your original question

I suppose my question comes down to what are good strategies to use once the data become too large

As this example illustrates, it is often a good idea to think about what your analysis is doing and how it lines up with the data layout (chunking). Sadly though, this can take a bunch of digging through various layers of code, so it’s not the most user-friendly.

1 Like

I set the time chunk to 468 to reach a reasonable block size (30 MB) for data that will be served on a cloud store eventually.

Isn't a 30 MB chunk size rather small, considering that Dask's default auto chunk size is 128 MiB and may be increased in a future release (see here)?

Chunk sizes for compute orchestration can be a lot larger than chunk sizes on disk. You’d want smaller chunk sizes on disk to allow efficient queries, but larger chunk sizes for parallel compute to keep the CPU happy.
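For example, the small chunks can stay on disk for the cloud store while you request much larger Dask blocks at open time (the store name is a placeholder; 4680 is ten of your 468-step on-disk time chunks, so the block boundaries line up with the stored chunks):

import xarray as xr

# Small chunks on disk for cheap range requests, big blocks for the computation
ds = xr.open_zarr("store.zarr", chunks={"time": 4680})   # placeholder store name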

2 Likes