I am working on a weather forecasting project using the ERA5 datasets. The datasets are stored per year in Zarr format, chunked along the time dimension (chunks={time: 1, lat: -1, lon: -1, pressure_level: -1}). I timed loading a single time step with dataset.isel().load() and the results were pretty promising (around 0.4 s).
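For reference, this is roughly how I timed it (a minimal sketch; the store path is a placeholder):

```python
import time
import xarray as xr

# Open one yearly ERA5 store; "era5_2020.zarr" is a placeholder path.
ds = xr.open_zarr("era5_2020.zarr", consolidated=True)

# Load a single time step (one full lat/lon/pressure_level chunk) and time it.
start = time.perf_counter()
sample = ds.isel(time=0).load()
print(f"loaded one time step in {time.perf_counter() - start:.2f} s")
```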
However, when the job was submitted to the cluster with PyTorch DDP, reading was 10 times slower and it significantly slowed down the training: it now took more than 4 s to load a single time step, and the GPU spent most of its time waiting for data. I also tested running the job without data parallelism, and it was just as slow. I've tested the following methods I found online and, apparently, they did not help:
setting num_workers in the DataLoader to 0, 1, 2, 4, 8 (the setup is sketched at the end of this post)
setting multiprocessing_context='forkserver' for the DataLoader
I also went through some posts about using xarray for ML but have not found anything relevant. We are using dsub as the scheduler. I am wondering whether folks here at Pangeo have any insights on this issue.
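For context, my Dataset/DataLoader setup looks roughly like this (heavily simplified; the path, variable handling and preprocessing are placeholders for what I actually do):

```python
import torch
import xarray as xr
from torch.utils.data import Dataset, DataLoader

class ERA5Dataset(Dataset):
    """Simplified sketch; the real dataset returns model-ready tensors."""

    def __init__(self, zarr_path: str):
        # Opening is lazy: only metadata is read here, no chunk data.
        self.ds = xr.open_zarr(zarr_path, consolidated=True)

    def __len__(self):
        return self.ds.sizes["time"]

    def __getitem__(self, idx):
        # One time step corresponds to one chunk per variable.
        arr = self.ds.isel(time=idx).to_array().load()
        return torch.as_tensor(arr.values)

dataset = ERA5Dataset("era5_2020.zarr")  # placeholder path

# The combinations I tried: num_workers in {0, 1, 2, 4, 8}, with and without forkserver.
loader = DataLoader(
    dataset,
    batch_size=1,
    num_workers=4,
    multiprocessing_context="forkserver",
    persistent_workers=True,
)
```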
Where is your Zarr data stored? Is it on a Network File System (NFS) drive, Lustre, or something else? Network transfer is typically the main bottleneck, and it is the first thing I would check.
Otherwise, you're doing most things correctly already: Zarr chunk size = access pattern. On the dask scheduler side, there is sometimes an inherent conflict between PyTorch's multiprocessing-based parallelism and other libraries that may be using multi-threading, which I've been trying to make sense of here, but you should be OK with Zarr. In fact, you could probably do away with the dask part and go with pure zarr-python/xarray, but I'd check your filesystem first, as mentioned above.
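For the pure zarr-python/xarray route, I was thinking of something along these lines (an untested sketch, so take it with a grain of salt; the path is a placeholder and `chunks=None` bypasses dask):

```python
import xarray as xr

# Default: variables are backed by dask arrays, one task per chunk.
ds_dask = xr.open_zarr("era5_2020.zarr", consolidated=True)

# chunks=None: no dask graph at all; .isel(...).load() goes straight
# through zarr into plain numpy arrays.
ds_plain = xr.open_zarr("era5_2020.zarr", consolidated=True, chunks=None)
sample = ds_plain.isel(time=0).load()
```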
Thank you for your reply. I checked with my service provider and it seems that I/O is indeed a problem. They also pointed out that dask threads may be racing for CPU when used with DDP. They recommended setting the environment variable DASK_THREADS_PER_WORKER to 1, which reduced the read time to 1.4 s. Our nodes have 8 GPUs and 48 CPU cores, and I guess switching to machines with a higher CPU count would give better results.
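In case it helps anyone else, the change boils down to something like this (the variable name is the one my provider gave me, and it has to be set before dask/xarray are imported, e.g. at the top of the training script or in the dsub job definition):

```python
import os

# Must be set before dask (or anything that imports dask) is loaded.
os.environ["DASK_THREADS_PER_WORKER"] = "1"

import xarray as xr  # noqa: E402  (imported after setting the variable on purpose)
```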
Since you mentioned getting rid of dask and using pure xarray, I am not sure how that could be done. I tested xarray.open_zarr with chunks set to None (disabling dask) and it gave worse results. I am also working on reading the Zarr store directly; although it requires much more effort, it is much faster, especially with operation fusing.
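A rough sketch of the direct-zarr path I am experimenting with (the variable names are just examples from my store):

```python
import numpy as np
import zarr

# Open the store read-only; no xarray/dask machinery in the hot path.
root = zarr.open_group("era5_2020.zarr", mode="r")

def load_step(idx: int) -> np.ndarray:
    # With chunks of one time step (and time as the leading dimension),
    # indexing a single time index reads exactly one chunk per variable.
    fields = [root[name][idx] for name in ("t", "u", "v", "z")]  # example variable names
    return np.stack(fields)
```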
I can definitely share my implementation of the PyTorch dataset and explain every mistake and problem I encountered while processing the data, if anyone is interested.
Yes, you'll need to disable multi-threading in dask when using DDP (which uses multi-processing). Unless you want to try out the experimental Python 3.13 GIL-free, thread-based torch.DataLoader - Improved Data Loading with Threads | NVIDIA Technical Blog
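A minimal way to do that, assuming you keep dask in the picture, is to switch dask to its single-threaded (synchronous) scheduler so each DataLoader worker process reads without spawning its own thread pool:

```python
import dask

# Make every dask computation in this process run single-threaded, so the
# DataLoader worker processes don't oversubscribe the CPU cores that DDP
# and the other workers are already using.
dask.config.set(scheduler="synchronous")
```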
Sorry, I should have been more clear. If you want to:
Stay with xarray - then disable dask's multi-threading (as mentioned above). I thought you could use xr.open_zarr without dask and still open chunks, but it seems that I'm wrong.
If your modelâs forward pass on one data chunk takes 0.2s on 1 GPU, and it takes 1.0s for 1 CPU to load the data, you would need 1.0 / 0.2 = 5 CPUs to keep the GPUs saturated. So 8 GPUs would require 40 CPUs.
If your modelâs forward pass on one data chunk takes 0.05s on 1 GPU, and it takes 1.0s for 1 CPU to load the data, you would need 1.0 / 0.05 = 20 CPUs to keep the GPUs saturated. So 8 GPUs would require 160 CPUs.
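Put differently, a back-of-the-envelope helper (nothing more than the arithmetic above):

```python
import math

def cpus_needed(t_load: float, t_step: float, n_gpus: int) -> int:
    """CPU data workers required per node so loading keeps up with the GPUs."""
    return math.ceil(t_load / t_step) * n_gpus

print(cpus_needed(t_load=1.0, t_step=0.2, n_gpus=8))   # -> 40
print(cpus_needed(t_load=1.0, t_step=0.05, n_gpus=8))  # -> 160
```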
As GPUs get faster and faster, you'll find that the CPUs won't be able to scale enough to saturate the GPUs. At that point, you'll need something like kvikIO (see Zarr - kvikio 24.10.00 documentation), which implements direct-to-GPU reads from Zarr, but that depends on whether your HPC cluster supports GPUDirect Storage. Network I/O bandwidth might also be a limitation here.
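Roughly what the kvikIO route looks like, based on my reading of their Zarr docs (a sketch only; the exact API, e.g. GDSStore and the meta_array argument, may differ between kvikIO/zarr versions, and it requires GPUDirect Storage support):

```python
import cupy
import zarr
import kvikio.zarr

# Read chunks straight into GPU memory; the path points at a single Zarr array
# and is a placeholder here.
store = kvikio.zarr.GDSStore("era5_2020.zarr/t")
arr = zarr.open_array(store=store, mode="r", meta_array=cupy.empty(()))
gpu_sample = arr[0]  # a cupy.ndarray already resident on the GPU
```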
Sure, I think others will be keen to see this too! There might be a few more optimizations we can try out, but I think you're doing a good job figuring out and resolving most of the bottlenecks already.