Optimizing climatology calculation with Xarray and Dask

I just wanted to share my experience with a toy project I use for teaching xarray + dask (I only teach the basic premises, as I don’t feel very confident about what’s going on under the hood). Here is the notebook.

Anyways, the short summary is:

  • averaging over multiple files of a large (30+ GB) dataset on a local cluster works well, BUT dask “pollutes” my swap memory: it’s unclear to me why this happens, given that workers with limited memory are not supposed to fill my swap
  • using a for loop across files + dask at the individual file level works best for me. The code is less elegant, but memory usage is most efficient for a comparable compute time. This is much better than not using dask at all (which is good)
  • .groupby across multiple files just doesn’t work for me. To reach the result in this tweet I used the for loop strategy.
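A minimal sketch of the per-file loop strategy from the second bullet, assuming each monthly file fits in memory once loaded: every file is reduced to a per-hour sum and count, and the two are combined only at the end. The synthetic arrays stand in for the monthly files and the helper name `hourly_climatology` is hypothetical; with xarray one would open each file in turn and feed its hour labels and values into the same accumulation.

```python
import numpy as np

def hourly_climatology(files):
    """Mean for each hour of the day, accumulated one file at a time.

    `files` is an iterable of (hours, data) pairs: `hours` holds the hour of
    day (0-23) of each time step and `data` has time as its first axis.
    Only one file plus the running sums are held in memory at once.
    """
    total = None
    count = np.zeros(24, dtype=np.int64)
    for hours, data in files:
        if total is None:
            total = np.zeros((24,) + data.shape[1:], dtype=np.float64)
        # np.add.at is unbuffered, so repeated hour labels all accumulate
        np.add.at(total, hours, data)
        count += np.bincount(hours, minlength=24)
    # broadcast the per-hour counts over the spatial dimensions
    return total / count.reshape((24,) + (1,) * (total.ndim - 1))
```

Because only sums and counts are carried between files, the result equals the mean over the concatenated record, while peak memory stays at roughly one file.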

If you have more successful stories with .groupby, let me know! I am aware of the tremendous efforts happening in this space, and I’m looking forward to it :wink:

Thanks for sharing @fmaussion! I think many of us share some of these frustrations around this very common workflow that still doesn’t quite “just work” out of the box with xarray and dask.

Regarding the Dask swap memory issue, I feel like that really needs to be raised with Dask itself. Have you opened an issue somewhere?

import flox.xarray

flox.xarray.xarray_reduce(ds.tp, ds.time.dt.hour, func="mean", method="map-reduce")

should work really well. You could even adapt xhistogram for this.

I started adding groupby “user stories” here. We could add yours :slight_smile:

I know for a fact that the HDF5 library, used for reading netCDF files, does not support multi-threaded access at all. It assumes that only one thread in a process calls into it at a time, even when each thread is reading from a completely unrelated data source. I’m not sure how well .open_dataset protects against that in a Dask context. Trying a Dask configuration with 1 thread per worker would be helpful.
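A sketch of the one-thread-per-worker suggestion for a local cluster. `LocalCluster(n_workers=..., threads_per_worker=..., memory_limit=...)` and `Client` are part of `dask.distributed`; the helper name and its defaults (4 workers of 2GB each, mirroring the “4 processes, 8GB total” setup mentioned later in the thread) are assumptions for illustration.

```python
def start_local_cluster(n_workers=4, memory_limit="2GB"):
    """Start a local Dask cluster with a single thread per worker process.

    With one thread per process, at most one thread in each process calls
    into the netCDF/HDF5 libraries at a time; parallelism comes from the
    separate worker processes instead.
    """
    # imported here so that defining this helper does not require dask
    from dask.distributed import Client, LocalCluster

    cluster = LocalCluster(
        n_workers=n_workers,
        threads_per_worker=1,        # no concurrent library calls per process
        memory_limit=memory_limit,   # applies to each worker, not the total
    )
    return Client(cluster)
```

Calling `client = start_local_cluster()` before opening the files makes that client the default scheduler for subsequent computations.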

You might also have sub-optimal chunking. Without knowing the data format it’s hard to say whether the chunking is appropriate for a given data source: latitude=50 is kinda narrow, unless your data is stored in tall, narrow chunks on disk. And if these chunks are pulled from one file per timestamp, that can cause a lot of IO overhead and possibly blow up your task graph, leading to memory issues.
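To put numbers on the chunking question, a small hypothetical helper can compute the memory of one chunk and the number of chunks. The shapes come from this thread (one year of hourly data on the 721 x 1440 grid, chunked as (744, 50, 1440) or as one chunk per monthly file); 4 bytes per value (float32 after decoding) and regular 744-step time chunks are assumptions.

```python
import math

def chunk_stats(shape, chunks, itemsize=4):
    """Return (bytes per full chunk, number of chunks) for a regular chunking."""
    nbytes = math.prod(chunks) * itemsize
    nchunks = math.prod(math.ceil(s / c) for s, c in zip(shape, chunks))
    return nbytes, nchunks

year = (8760, 721, 1440)
for chunks in [(744, 50, 1440), (744, 721, 1440)]:
    nbytes, nchunks = chunk_stats(year, chunks)
    print(chunks, f"{nbytes / 1e6:.0f} MB per chunk, {nchunks} chunks")
```

This gives about 214 MB per chunk and 180 chunks for latitude=50, versus about 3.1 GB per chunk and 12 chunks for whole-file chunks.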

Thanks all for your suggestions!

Let’s start with the pure .mean() swap issue.

Thanks! I tried it as well, no big change as far as I can see. Rendered notebook (including rich dask output) is here

This results in a swap overfill. You can see a screen recording here

How do I know how the data is stored? In the notebook I offer a link to download the data if anyone is interested.

I haven’t yet. It’s a “minor” issue (it works, albeit with some unwanted swap use), and may be related to Linux in particular?

Thanks @dcherian ! After the swap, let’s continue with groupby :wink:

Unfortunately, flox doesn’t do it, or maybe I’m doing something wrong.

Rendered notebook here.

Screen recording of what’s happening here.


Yeah, I think this is the classic distributed failure: it’s memory-efficient to execute the graph depth-first, but the emergent behaviour of dask+distributed is to execute it breadth-first. That means it executes a lot of open_dataset tasks before reducing, using up more memory than necessary (see “Better Shuffling in Dask: a Proof-of-Concept” by Coiled, under “root task overproduction”).
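The memory cost of root task overproduction can be illustrated with a deliberately simplified toy model (not how distributed actually schedules tasks): count how many chunk-sized arrays are alive at once when every load runs before any reduction (breadth-first) versus when each loaded chunk is folded into the running result right away (depth-first). The function name is hypothetical.

```python
def peak_chunks_in_memory(n_chunks, depth_first):
    """Peak number of chunk-sized arrays held at once while summing n_chunks.

    Toy model: the running result counts as one array, each load (root)
    task adds one array, and folding a chunk into the result frees it.
    """
    held = peak = 1                  # the running result
    for _ in range(n_chunks):
        held += 1                    # a load task produces one chunk
        peak = max(peak, held)
        if depth_first:
            held -= 1                # reduce it into the result immediately
    # breadth-first: every chunk is loaded before any reduction runs,
    # so all of them are held at the same time
    return peak
```

For 12 monthly chunks, the breadth-first order peaks at 13 arrays and the depth-first order at 2, whatever the chunk size.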

flox lets you easily construct graphs that work well when executed depth-first, so at the moment it’s really up to distributed to schedule and compute them properly. Previously, xarray’s groupby would make a very convoluted graph that compounded distributed’s problems.

I think you just need more than 2GB per worker to get this to work right now.

Is this a case where the new inline_array kwarg for open_dataset (pydata/xarray pull request #6566 by TomNicholas) might help by removing the root task?

ncdump -h -s FILE

will give you output like this:

netcdf ERA5_HiRes_Hourly_tp_2000_01 {
dimensions:
	longitude = 1440 ;
	latitude = 721 ;
	time = 744 ;
variables:
	float longitude(longitude) ;
		longitude:units = "degrees_east" ;
		longitude:long_name = "longitude" ;
	float latitude(latitude) ;
		latitude:units = "degrees_north" ;
		latitude:long_name = "latitude" ;
	int time(time) ;
		time:units = "hours since 1900-01-01 00:00:00.0" ;
		time:long_name = "time" ;
		time:calendar = "gregorian" ;
	short tp(time, latitude, longitude) ;
		tp:scale_factor = 1.06652958510645e-06 ;
		tp:add_offset = 0.0349459083855981 ;
		tp:_FillValue = -32767s ;
		tp:missing_value = -32767s ;
		tp:units = "m" ;
		tp:long_name = "Total precipitation" ;

// global attributes:
		:Conventions = "CF-1.6" ;
		:history = "2022-05-04 02:24:39 GMT by grib_to_netcdf-2.24.3: /opt/ecmwf/mars-client/bin/grib_to_netcdf -S param -o /cache/data8/adaptor.mars.internal-1651631053.6801858-1345-10-118f35a2-9614-4768-9c7a-41753b78a6d1.nc /cache/tmp/118f35a2-9614-4768-9c7a-41753b78a6d1-adaptor.mars.internal-1651631001.944344-1345-14-tmp.grib" ;
		:_Format = "64-bit offset" ;
}

But since there is no _ChunkSizes attribute, your native chunk is the whole file, i.e. 744x721x1440 (time x latitude x longitude). However, since your data is not using any compression (“64-bit offset” is a netCDF3 format, which supports neither chunking nor compression), I’m guessing that accessing it in smaller chunks is probably fine anyway…

Things to try would be:

  • rechunking after loading instead of as part of loading
  • chunking along the time dimension instead of latitude
  • not rechunking at all

I investigated this on a cloud cluster.

Step 1: Ingest Data using Pangeo Forge

To start, I ingested your data into Google Cloud Storage in the Zarr format using the following Pangeo Forge recipe

import os
from pangeo_forge_recipes.patterns import ConcatDim, FilePattern
from pangeo_forge_recipes.recipes import XarrayZarrRecipe
from pangeo_forge_recipes.storage import CacheFSSpecTarget, FSSpecTarget, MetadataTarget, StorageConfig
import gcsfs
import pandas as pd

# after 2020 data become netCDF4 files 🤦‍♂️ - this causes problems
#dates = pd.date_range('2000-01-01', '2021-12-01', freq='MS')
dates = pd.date_range('2000-01-01', '2019-12-01', freq='MS')

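def format_url(time):
    # The URL template was not preserved in the original post: it should
    # return the source URL of the monthly file that contains `time`.
    raise NotImplementedError("URL template not shown in the original post")

pattern = FilePattern(
    format_url,
    ConcatDim("time", dates),
)

recipe = XarrayZarrRecipe(
    pattern,
    target_chunks={'time': 24},
    subset_inputs={'time': 4},
)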

fs = gcsfs.GCSFileSystem(skip_instance_cache=True, use_listings_cache=False)

target = FSSpecTarget(fs, f"{os.environ['SCRATCH_BUCKET']}/ERA5_HiRes_Hourly.zarr")
cache = CacheFSSpecTarget(fs, f"{os.environ['SCRATCH_BUCKET']}/ERA5_HiRes_Hourly/cache")
metadata = MetadataTarget(fs, f"{os.environ['SCRATCH_BUCKET']}/ERA5_HiRes_Hourly/metadata")

recipe.storage_config = StorageConfig(target, cache, metadata)

delayed = recipe.to_dask()  # build a Dask graph for the whole recipe
delayed.compute()           # execute it

This was pretty slow, since I was only able to get about 5 MB/s from the server. But it worked!

Step 2: Flox with “Cohorts”

I was able to get the computation to run with the following cluster settings

from dask_gateway import Gateway
g = Gateway()
options = g.cluster_options()
options.environment = {"MALLOC_TRIM_THRESHOLD_": "0"}
options.worker_memory = 40
gc = g.new_cluster(cluster_options=options)

At this point my cluster will have 800 GB of memory. This matches a rule of thumb that I have observed in this type of workload: it will only really work if you have more memory in the cluster than the total dataset size! I consider this a serious performance limitation of our stack for this type of computation. It means that we can’t actually do “streaming” style processing, where the cluster memory is much smaller than the actual data.
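A quick check of this rule of thumb, using the 20-year shape (175320, 721, 1440) quoted later in the thread. The values are stored on disk as 16-bit integers (`short tp` in the file header); decoding to float32 in memory is an assumption, and so is the worker count of 20, inferred from 800 GB at 40 GB per worker.

```python
import math

shape = (175320, 721, 1440)     # 20 years of hourly data on the 0.25 degree grid
n_values = math.prod(shape)

on_disk_gb = n_values * 2 / 1e9     # int16 as stored in the files
in_memory_gb = n_values * 4 / 1e9   # assuming float32 after decoding
cluster_gb = 20 * 40                # assumed 20 workers x 40 GB each

print(f"{on_disk_gb:.0f} GB on disk, {in_memory_gb:.0f} GB decoded, {cluster_gb} GB cluster")
```

The decoded array comes to roughly 728 GB, just under the 800 GB of cluster memory, which is consistent with the observation that the computation only succeeds once the cluster can hold the whole dataset.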

These settings allowed the following computation to complete:

import xarray as xr
import os
import flox.xarray

url = f"{os.environ['SCRATCH_BUCKET']}/ERA5_HiRes_Hourly.zarr"
ds = xr.open_dataset(url, engine="zarr", chunks={})

# method = "map-reduce" # <- cluster ran out of memory
method = "cohorts"

tpm = flox.xarray.xarray_reduce(ds.tp, ds.time.dt.hour,  func="mean",  method=method)

The workers got totally saturated with memory and started spilling to disk, but they did manage to recover and finish the computation in about 15 minutes.

tpm[0].plot(robust=True, figsize=(14,8))

I am now investigating using rechunker. This operation is basically embarrassingly parallel in space, similar to timeseries analysis at each point. So I am hoping that, by rechunking the data to be contiguous in time, and using flox with method="map-reduce", I will get a much happier (and faster) dask computation.

Doesn’t this skip making a dask array? If so, flox won’t help…

That’s fantastic @rabernat, thanks a lot! (Sorry for the netCDF4 files, that would be my fault. Interesting that reading from a URL works differently in this case.)

Obviously the 800 GB memory requirement is crazy, but the 15 minute run time is actually quite good! I’m curious to see how much better this can get. This could definitely become a showcase for Pangeo, because the animations are pretty.

No. With chunks={}, Zarr chunks are translated directly to Dask chunks; chunks=None skips making a Dask array. I agree it is confusing syntax! :upside_down_face:

I think it is ECMWF’s fault. They changed their format at some point. Xarray can handle both formats fine (when using the netCDF4 engine), but with Pangeo Forge there are some quirks involved in reading data from object storage. Because we read the data using fsspec file-like objects, we can’t use the netCDF4 engine. So we need to either use the scipy engine (for netCDF3) or the h5netcdf engine (for netCDF4). The package currently assumes that all the files you may want to read in a single recipe are of the same type.

Something is very wrong with the I/O side of things.

If I use dask.array.ones to mimic the original netCDF collection with the same chunks and the same cluster settings (4 processes, 8GB total), flox with map-reduce finishes in just over a minute. This fits my intuition (phew!). No rechunking should be necessary to compute an hourly climatology here.

@rabernat, I think your original inline_array suggestion was right on track.

Notebook | Performance report | Dashboard

We may be comparing apples and oranges. Your data shape is:

shape = (8760, 721, 1440)
chunks = (744, 50, 1440)
dtype = np.float32

But the original data’s native chunks are more like

shape = (175320, 721, 1440)
chunks = (744, 721, 1440)

When I ran my test with an 800 GB cluster, I was using

shape = (175320, 721, 1440)
chunks = (24, 721, 1440)

Also, ones may take shortcuts because it uses the same dask token for every chunk. Try with random data instead.

@dcherian there is one big difference in your toy notebook though:

all hours (groups) are present in each chunk;

That’s not correct in the multi-file case. Each file is a monthly file, so the dataset has 12 chunks in the time dimension, of varying size (depending on the length of the month).
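The irregular monthly chunking can be reproduced with the standard library: each monthly file holds (days in month) x 24 hourly steps. A sketch for the year 2000, a leap year (the helper name is hypothetical):

```python
import calendar

def monthly_time_chunks(year):
    """Number of hourly time steps in each monthly file of `year`."""
    return tuple(calendar.monthrange(year, month)[1] * 24 for month in range(1, 13))

chunks_2000 = monthly_time_chunks(2000)
print(chunks_2000)
```

The chunk lengths are 744, 720 or 696 (February) and add up to 8784 hours. Each length is a whole number of days, i.e. a multiple of 24.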

The workflow you suggest in your notebook is the one I eventually used to get the job done (loop over the individual files, store the result, and merge later)

Oops, I may have misread your example. Just to clarify:

  • Ryan uses a 20-year climatology, which is a HUGE dataset
  • Deepak uses a 1-year climatology, which, I believe, should work on a laptop.

I followed this notebook that was using flox with this dataset… Why did we switch to the 20-year dataset :stuck_out_tongue:

Using random data with float64 (not exactly right, but passing dtype raises an error) takes 3 min on my laptop, with lots of spilling. Again, this is because more “random_sample” tasks get executed than necessary before reducing (i.e. root task overproduction). If these were fused with the blockwise reduction (“groupby-nanmean-chunk”) that happens later, memory use would be much lower. I didn’t try any of the new memory tricks.

Ryan, I can’t read your scratch bucket. Can we make this more public somehow? Or what chunk sizes are you using?

Ah, no wonder: this is bad for a map-reduce groupby. Because there is one element per group (i.e. one data point for each hour) per chunk, the blockwise reduction does nothing (input = output). Then we stitch 4 chunks together (memory use is now at least 4x the chunk size) and reduce again (back to 1x the chunk size). We keep repeating these steps until the end.
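How much the first (blockwise) step of a map-reduce groupby shrinks a chunk can be estimated by comparing the number of time steps in the chunk with the number of distinct hours it contains. A small sketch, assuming hourly data whose first step falls on hour 0 (the helper name is hypothetical):

```python
def blockwise_reduction_factor(chunk_len, period=24, start=0):
    """Ratio of input to output length for the per-chunk groupby reduction.

    Time step t belongs to group (start + t) % period; within one chunk the
    blockwise reduction emits one value per distinct group present.
    """
    groups = {(start + t) % period for t in range(chunk_len)}
    return chunk_len / len(groups)

for chunk_len in (6, 24, 744):
    print(chunk_len, blockwise_reduction_factor(chunk_len))
```

Chunks of 24 steps give a factor of 1 (no reduction at all), while monthly chunks of 744 steps shrink by a factor of 31 before anything is combined. Chunks of 6 steps also give a factor of 1.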

You could try “split-reduce”, which is what standard xarray does: it splits each chunk into 24 new chunks (one per group) and carries on from there. This probably increases the number of tasks too much to work well.

I would call .chunk({"time": 6}) for a 4x reduction in chunk size, and then use method="cohorts" so that we get some effective reductions early on in the graph. Obviously it would be better if the Zarr dataset were chunked that way to begin with.
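Roughly, “cohorts” means finding sets of chunks that contain the same groups, so that each set can be reduced on its own. A simplified sketch of that idea (not flox’s actual implementation), assuming one year of hourly data starting at hour 0: after rechunking to 6 steps, chunk k holds a fixed block of 6 consecutive hours, so the chunks fall into 4 cohorts made of every 4th chunk.

```python
from collections import defaultdict

def find_cohorts(n_steps, chunk_len, period=24):
    """Group chunk indices by the set of labels (hour of day) they contain."""
    cohorts = defaultdict(list)
    for k, start in enumerate(range(0, n_steps, chunk_len)):
        stop = min(start + chunk_len, n_steps)
        labels = frozenset(t % period for t in range(start, stop))
        cohorts[labels].append(k)
    return dict(cohorts)

cohorts = find_cohorts(n_steps=8760, chunk_len=6)
for labels, chunk_ids in cohorts.items():
    print(sorted(labels), chunk_ids[:3], len(chunk_ids))
```

Each cohort can then be reduced independently, combining only chunks that share the same hours, instead of stitching unrelated chunks together.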

Basically, for time grouping where the groups are periodic with period T, you want a chunk size C > T with “map-reduce”, or C < T with “cohorts”. If C ~ T, it’s just bad memory-wise (can we call C/T the flocking number?).
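The rule of thumb above can be written down as a tiny helper. This is a hypothetical function encoding the heuristic as stated in this thread, not part of flox, and treating “C ~ T” as “within a factor of 2” is an arbitrary assumption for illustration.

```python
def suggest_method(chunk_len, period):
    """Suggest a flox method for periodic groups (heuristic from this thread).

    C / T well above 1: each chunk holds several full periods, so the
    blockwise step of map-reduce already shrinks it.
    C / T well below 1: each chunk holds only a few groups, which cohorts
    can exploit.
    C / T near 1: neither reduces much early on; rechunking is suggested.
    """
    ratio = chunk_len / period      # the "flocking number" C / T
    if ratio >= 2:
        return "map-reduce"
    if ratio <= 0.5:
        return "cohorts"
    return "rechunk first"
```

For example, monthly chunks (C = 744, T = 24) map to "map-reduce", 6-hour chunks to "cohorts", and daily chunks (C = T = 24) to "rechunk first".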