Xarray slow read on cluster

Hi there,

I am working on a weather forecasting project using the ERA5 datasets. The datasets are stored on a yearly basis in Zarr format, chunked along the time dimension (chunks={"time": 1, "lat": -1, "lon": -1, "pressure_level": -1}). I tested the loading time for a single time step with dataset.isel().load() and the results were pretty promising (around 0.4 s).
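For reference, here is a minimal sketch of that timing test; the store path and chunk sizes are placeholders standing in for my actual setup:

```python
import time
import xarray as xr

# Placeholder path; one yearly ERA5 store, chunked one time step per chunk.
ds = xr.open_zarr(
    "era5/2020.zarr",
    chunks={"time": 1, "lat": -1, "lon": -1, "pressure_level": -1},
)

t0 = time.perf_counter()
sample = ds.isel(time=0).load()  # pull a single time step into memory
print(f"loaded one time step in {time.perf_counter() - t0:.2f} s")
```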

However, when the job was submitted to the cluster with PyTorch DDP, reading was 10 times slower and it significantly slowed down the training: it now took more than 4 s to load a single time step, and the GPU spent most of its time waiting for data. I also tested running the job without data parallelism, which was just as slow. I tried the following methods I found online, but none of them helped (a combined sketch of these settings follows the list):

  • setting num_workers in dataloader to 0, 1, 2, 4, 8
  • setting multiprocessing_context='forkserver' for dataloader
  • dask.config.set(scheduler="synchronous")
  • dask.config.set(scheduler="threads", num_workers=x)
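Roughly, the knobs above look like this; DummyZarrDataset is just a stand-in for my real Zarr-backed dataset, only to show the settings that were tried:

```python
import dask
import numpy as np
from torch.utils.data import DataLoader, Dataset

class DummyZarrDataset(Dataset):
    """Stand-in for the real Zarr-backed dataset."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return np.zeros((721, 1440), dtype=np.float32)

loader = DataLoader(
    DummyZarrDataset(),
    batch_size=1,
    num_workers=4,                         # tried 0, 1, 2, 4, 8
    multiprocessing_context="forkserver",  # only valid when num_workers > 0
)

# Dask scheduler variants, tried one at a time:
dask.config.set(scheduler="synchronous")
# dask.config.set(scheduler="threads", num_workers=4)
```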

I also went through some posts about using xarray for ML but have not found anything relevant. We are using dsub as the scheduler. I am wondering whether folks here at Pangeo have any insights on this issue.


Where is your Zarr data stored? Is it on a Network File System (NFS) drive, Lustre, or something else? Typically network transfer would be the main bottleneck, and the first thing I would check.

Otherwise, you’re doing most things correctly already: Zarr :white_check_mark: Chunk size = access pattern :white_check_mark:. On the dask scheduler part, there’s sometimes an inherent conflict between PyTorch’s multiprocessing-based parallelism and other libraries that might be using multi-threading, which I’ve been trying to make sense of here, but you should be ok with Zarr. In fact, you could probably do away with the dask part and go with pure zarr-python/xarray, but I’d check your filesystem first as mentioned above :point_up:
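A minimal sketch of the xarray-without-dask route I had in mind (the store path and variable name are placeholders):

```python
import xarray as xr

# chunks=None asks xarray not to wrap the arrays in dask at all.
ds = xr.open_zarr("era5/2020.zarr", chunks=None)
sample = ds["t"].isel(time=0).values  # placeholder variable name
```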

Hi Weiji,

Thank you for your reply. I checked with my service provider and it seems like I/O is the problem. They also pointed out that dask threads may be racing for CPU when used with DDP. They recommended setting the environment variable DASK_THREADS_PER_WORKER to 1, which reduced the read time to 1.4 s. Our nodes have 8 GPUs and 48 CPU cores, and I guess switching to machines with a higher CPU count will give better results.
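For anyone following along, this is roughly how we applied the recommendation; the dask.config.set call is my assumed in-code equivalent of pinning the threaded scheduler to a single thread:

```python
import os

# Environment variable recommended by our provider; set before dask does any work.
os.environ["DASK_THREADS_PER_WORKER"] = "1"

import dask

# Assumed in-code equivalent: threaded scheduler limited to one thread.
dask.config.set(scheduler="threads", num_workers=1)
```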

Since you mentioned getting rid of dask and using pure xarray, I am not sure how this could be done. I tested xarray.open_zarr with chunks set to None (disabling dask) and it gave worse results. I am also working on reading Zarr directly with zarr-python. Although it requires much more effort, it is much faster, especially with operation fusing.
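To illustrate what I mean by reading Zarr directly and fusing operations, here is a rough sketch; the path, variable name, and normalisation constants are made up:

```python
import zarr

group = zarr.open_group("era5/2020.zarr", mode="r")  # placeholder path
t = group["t"]                                       # placeholder variable

def load_step(i, mean=250.0, std=30.0):
    # One direct chunk read plus normalisation in a single numpy pass,
    # skipping the xarray/dask indirection.
    arr = t[i]  # reads only the chunks for time index i
    return (arr - mean) / std

sample = load_step(0)
```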

I can definitely share my implementation of the PyTorch dataset and explain every mistake and problem I encountered while processing the data, if anyone is interested.

Yes, you’ll need to disable multi-threading from dask when using DDP (which uses multi-processing), unless you want to try out an experimental Python 3.13 GIL-free thread-based torch.DataLoader (see Improved Data Loading with Threads | NVIDIA Technical Blog).

Sorry, I should have been more clear. If you want to:

  • Stay with xarray - then disable dask’s multi-threading (as mentioned above). I thought you could use xr.open_zarr without dask and still open chunks, but it seems that I’m wrong.
  • Use pure zarr-python - yes, this is more effort, but it will generally be faster as you’ll skip the overhead of xarray+dask (see also Favorite way to go from netCDF (&xarray) to torch/TF/Jax et al). A rough sketch of this approach is below.
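For example, a minimal Zarr-backed PyTorch Dataset might look like this; the store path, variable name, and dtype handling are assumptions, not a tested implementation:

```python
import numpy as np
import torch
import zarr
from torch.utils.data import Dataset

class ZarrTimeStepDataset(Dataset):
    """Each item is one time step read straight from the Zarr store."""

    def __init__(self, store_path, variable):
        self.array = zarr.open_group(store_path, mode="r")[variable]

    def __len__(self):
        return self.array.shape[0]  # assumes time is the leading dimension

    def __getitem__(self, idx):
        step = self.array[idx]      # reads only this time step's chunks
        return torch.from_numpy(np.asarray(step, dtype=np.float32))

# dataset = ZarrTimeStepDataset("era5/2020.zarr", "t")  # placeholder path/variable
```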

Higher CPU count will only help to a certain extent. You’ll need to work out the ratio of how many CPUs it takes to keep 1 GPU saturated (see also Cloud native data loaders for machine learning using Zarr and Xarray | Earthmover). The math is something like this (a tiny back-of-the-envelope helper follows the examples):

  • If your model’s forward pass on one data chunk takes 0.2s on 1 GPU, and it takes 1.0s for 1 CPU to load the data, you would need 1.0 / 0.2 = 5 CPUs to keep the GPUs saturated. So 8 GPUs would require 40 CPUs.
  • If your model’s forward pass on one data chunk takes 0.05s on 1 GPU, and it takes 1.0s for 1 CPU to load the data, you would need 1.0 / 0.05 = 20 CPUs to keep the GPUs saturated. So 8 GPUs would require 160 CPUs.
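In code, the same back-of-the-envelope calculation:

```python
def cpus_needed(load_time_s, forward_time_s, n_gpus):
    """CPUs required to keep n_gpus busy, if one CPU takes load_time_s to
    produce a sample and one GPU consumes it in forward_time_s."""
    return n_gpus * load_time_s / forward_time_s

print(cpus_needed(load_time_s=1.0, forward_time_s=0.2, n_gpus=8))   # 40.0
print(cpus_needed(load_time_s=1.0, forward_time_s=0.05, n_gpus=8))  # 160.0
```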

As GPUs get faster and faster, you’ll realize that CPUs won’t be able to scale to saturate the GPUs. At that point, you’ll need to use something like kvikIO (see Zarr — kvikio 24.10.00 documentation) which implements direct-to-GPU reads from Zarr, but it will depend on whether your HPC cluster supports GPUDirect Storage. Also, network I/O bandwidth might be a limitation here.

Sure, I think others will be keen to see this too! There might be a few more optimizations we can try out, but I think you’re doing a good job figuring out and resolving most of the bottlenecks already :smile:


Yes, that would be good.