Saving larger-than-memory objects to zarr using dask and xarray

Hey all,

I’m trying to use dask to save a larger-than-memory zarr store. Concretely, I’m trying to:

  • Concatenate several NetCDF files (on the same geo grid, same variables) by time
  • Regrid them to a different grid
  • Save them as a zarr store, chunked so that there is only a single chunk along the time dimension.

(The resulting zarr store would then be used for further analysis that operates along the time dimension, so I want it unchunked in time to avoid the overhead of rechunking later.)

However, I’m having trouble setting up the workflow without causing memory use to balloon in the ds.to_zarr() call.

I’m trying to follow the Dask best practices (especially this one). A simplified version of the workflow is:

import dask
import xarray as xr
import numpy as np
import xesmf as xe
from distributed import Client

# Filenames to load (for example, historical and SSP2.45 files for the same GCM model / experiment / run)
fn_list = [fn1,fn2]

# Start dask client
client = Client()
display(client)

@dask.delayed
def load(fn_list):
   ds = xr.open_mfdataset(fn_list)
   return ds

@dask.delayed
def process(ds):
   # Do something to dataset, e.g., regridding
   ref_grid = xr.Dataset(coords = {'lat':np.arange(-89.5,89.6),
                                   'lon':np.arange(-179.5,179.6)})
   rgrd = xe.Regridder(ds,ref_grid,'conservative')
   
   ds = rgrd(ds)
   return ds

def workflow(fn_list):
   ds = load(fn_list)
   
   ds = process(ds)

   # Rechunk
   ds = ds.chunk({'time':-1,'lat':12,'lon':12})

   delayed = dask.delayed(ds.to_zarr)('test.zarr')
   return delayed

out = workflow(fn_list)
dask.compute(out)

From what I’ve gathered while researching this problem, something about how the task graph is set up causes the whole array to be loaded and sent to a single worker when dask.compute() reaches the .to_zarr() call (which then, of course, crashes the client).

I guess my primary question is - why does the .to_zarr() call need everything in memory / is it possible to set it up so that it doesn’t? I’d be grateful for any insight!

Versions, if relevant:

zarr == 2.18.3
xarray == 2024.9.0
dask == 2024.9.1

why does the .to_zarr() call need everything in memory / is it possible to set it up so that it doesn’t?

I think it’s because your workflow combines an operation that requires the full spatial extent of the data (regridding using xESMF) with an operation that requires the full temporal extent of the data (creating one chunk along the time dimension). The combination requires the full dataset to be loaded into memory. AFAIK whether it’s possible to avoid this depends on the chunking of the input datasets, but I expect your best bet is probably caching the regridded output before rechunking. I recommend checking out @norlandrhagen’s blog post on downscaling pipelines, which also includes a regrid and rechunk step.
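
To illustrate what I mean by caching, here’s a minimal sketch built on your snippet above (ds and rgrd are the dataset and regridder from your workflow; the intermediate store name and chunk sizes are placeholders, and it assumes the regridding itself stays lazy):

import xarray as xr

# 1) regrid with the input chunking (contiguous in space, chunked in time)
ds_regridded = rgrd(ds)

# 2) cache the regridded result to an intermediate zarr store
ds_regridded.to_zarr('regridded_intermediate.zarr', mode='w')

# 3) reopen the intermediate store and only then rechunk to a single time chunk
ds_cached = xr.open_zarr('regridded_intermediate.zarr')
ds_final = ds_cached.chunk({'time': -1, 'lat': 12, 'lon': 12})

# drop the chunk encoding carried over from the intermediate store so the
# final write uses the new chunking
for var in ds_final.variables.values():
    var.encoding.pop('chunks', None)

ds_final.to_zarr('test.zarr', mode='w')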

A few other comments

  • You probably want to check if using the defaults for open_mfdataset is the right choice for your data (see Stricter defaults for concat, combine, open_mfdataset · Issue #8778 · pydata/xarray · GitHub)
  • I view using both Dask Array and Delayed objects as a bit of an orange flag. I’m not sure why you’d need to do that in this case rather than just using Xarray’s integrated parallelism with Dask.Array objects.
  • XESMF without pre-generated weights is the slowest and most memory intensive of the commonly available regridding options (see Summary – Geospatial reprojection in Python (2024)). While conservative regridding is a good reason to put up with it, it’s worth considering separating out that step from the application of the regridder.
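
For example, here is a rough sketch of generating the xESMF weights once and reusing them later (the weight-file name is a placeholder, ds and ref_grid are as in the original snippet, and the exact save/reload API may vary a bit between xESMF versions):

import numpy as np
import xarray as xr
import xesmf as xe

ref_grid = xr.Dataset(coords={'lat': np.arange(-89.5, 89.6),
                              'lon': np.arange(-179.5, 179.6)})

# generate the conservative weights once (slow and memory-hungry) and save them
# note: 'conservative' also requires cell-bound coordinates (lat_b / lon_b) on
# both grids, omitted here for brevity
rgrd = xe.Regridder(ds, ref_grid, 'conservative')
rgrd.to_netcdf('conservative_weights.nc')

# later (or in a separate script), rebuild the regridder from the saved
# weights, which is cheap, and apply it lazily to the dask-backed dataset
rgrd = xe.Regridder(ds, ref_grid, 'conservative',
                    filename='conservative_weights.nc', reuse_weights=True)
ds_regridded = rgrd(ds)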

Thanks - super useful response and links, will check them out!

I think you’re making it too complicated. Your code should look something like this:

# will produce a dask.array backed dataset
ds = xr.open_mfdataset(list_of_all_the_netcdf_files)

# note: whatever package you use for regridding needs to be "lazy", i.e. work with Dask arrays
# without triggering computation
ds_regridded = regrid(ds)

ds_regridded.to_zarr("test.zarr")

This should all work in a streaming fashion, without overloading memory.

For lazy regridding, consider odc.geo.xr.xr_reproject. It’s pretty nifty.
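
A minimal sketch of what that could look like, assuming ds is a dask-backed dataset that already carries CRS information (the target grid, resolution, and resampling method below are illustrative placeholders, not recommendations):

from odc.geo.geobox import GeoBox
from odc.geo.xr import xr_reproject

# hypothetical 1-degree global lat/lon target grid
dst_geobox = GeoBox.from_bbox((-180, -90, 180, 90), crs='EPSG:4326', resolution=1.0)

# lazy, dask-friendly reprojection onto the target grid
ds_reprojected = xr_reproject(ds, dst_geobox, resampling='average')
ds_reprojected.to_zarr('reprojected.zarr', mode='w')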

Thanks for the suggestion @wietzesuijker! However, it’s important to recognize that regridding and reprojection are not necessarily the same thing. (@maxrjones’s talk here has a great overview: https://discourse.pangeo.io/t/pangeo-showcase-geospatial-reprojection-in-python-2024-whats-available-and-whats-next/4531.) Reprojection is defined a bit more narrowly in terms of moving a geospatial raster from one regular rectangular pixel grid and CRS to another, whereas regridding covers the more complex geometries (e.g. curvilinear grid) and resampling strategies (e.g. conservative) used with numerical modeling grids.

I recently had this problem, but the files I wanted to concat were zarr groups with different chunk sizes across the temporal and spatial dimensions, with overlapping data. The fastest solution was to rechunk the inputs to a standard chunk size when reading in the data, then rechunk the full dataset after the concat. While it’s not a one-to-one solution, it may be useful to attach the code to this thread given the overlapping concepts.

(For context, the data is streamflow mapped to river reaches, i.e. graph edges, and is ~100 GB total. Runtime was just over 3 minutes.)

import xarray as xr

data_paths = <insert data path here>

time_slices = [
    ('1980-01-01', '1987-12-31'),
    ('1987-12-31', '1994-12-31'),
    ('1994-12-31', '2001-12-31'),
    ('2001-12-31', '2008-12-31'),
    ('2008-12-31', '2015-12-31'),
    ('2015-12-31', '2019-12-31')
]

standard_chunks = {
    'edge_idx': 1000,
    'time': 1096   
}

datasets = []
for path, (start_time, end_time) in zip(data_paths, time_slices):
    ds = xr.open_zarr(path).sel(time=slice(start_time, end_time))
    ds = ds.chunk(standard_chunks)
    datasets.append(ds)

combined_ds = xr.concat(datasets, dim='time').chunk('auto')
combined_ds.to_zarr(new_data_path, mode='w')

Hi @ks905383! Most of the answers already covered the operations themselves, which will dramatically improve the memory footprint.

I just want to add that, if you aren’t using them already, consider setting up memory limits for your Dask cluster, as described here: Worker Memory Management — Dask.distributed 2024.11.2+9.g03a45a8 documentation

This has saved me (in addition to other optimizations) lots of headaches when porting code between environments in Alliance Canada, with different architectures/limits/policies.
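
For what it’s worth, a minimal sketch of what that can look like for a local cluster (the worker counts, limits, and thresholds below are arbitrary placeholders, not recommendations):

import dask
from distributed import Client

# optionally tune when workers spill to disk, pause, or get restarted
# (expressed as fractions of the per-worker memory limit)
dask.config.set({
    'distributed.worker.memory.target': 0.6,
    'distributed.worker.memory.spill': 0.7,
    'distributed.worker.memory.pause': 0.8,
    'distributed.worker.memory.terminate': 0.95,
})

# explicit per-worker memory limit for a local cluster
client = Client(n_workers=4, threads_per_worker=2, memory_limit='4GiB')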

Wanted to flag that the folks at Carbon Plan (@norlandrhagen from the link @maxrjones shared above) seem to have this figured out. What ended up working was adapting their process, which seems to basically save a temp zarr store for every step of the problem (rechunking to one chunk in space, regridding, rechunking to one chunk in time). Very fast and no memory issues.

In theory, Dask should be smart enough to stream all of these operations together without an intermediate write to disk. Maybe this would be a good workflow to share in Large GeoSpatial Benchmarks: First Pass.
