Yes, you’ll need to disable multi-threading in dask when using DDP (which uses multi-processing), unless you want to try out some experimental Python 3.13 GIL-free, thread-based torch.DataLoader (see Improved Data Loading with Threads | NVIDIA Technical Blog).
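If you stay on the dask route, a minimal sketch of turning off dask’s thread pool inside each DDP worker process could look like this (the Zarr path is a placeholder, not from your setup):

```python
import dask
import xarray as xr

# Run all dask operations single-threaded ("synchronous") in this process,
# so dask's own thread pool doesn't clash with DDP's multiprocessing.
dask.config.set(scheduler="synchronous")

# Placeholder store path; replace with your own Zarr dataset.
ds = xr.open_zarr("/path/to/data.zarr")
```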
Sorry, I should have been more clear. If you want to:

- Stay with `xarray` - then disable dask’s multi-threading (as mentioned above). I thought you could use `xr.open_zarr` without dask and still open chunks, but it seems that I’m wrong.
- Use pure `zarr-python` - yes, this is more effort, but will be generally faster as you’ll skip the overhead of `xarray` + `dask` (see also Favorite way to go from netCDF (& xarray) to torch/TF/Jax et al); there’s a rough sketch of this below.
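As an illustration of the pure `zarr-python` route, here is a minimal sketch of a torch Dataset that indexes a Zarr array directly; the class name, store path, and variable name are placeholders, not anything from your code:

```python
import torch
import zarr

class ZarrChunkDataset(torch.utils.data.Dataset):
    """Reads samples straight from a Zarr store, skipping xarray + dask."""

    def __init__(self, store_path: str, variable: str):
        # zarr.open_group is lazy; chunks are only read when indexed below
        self.array = zarr.open_group(store_path, mode="r")[variable]

    def __len__(self) -> int:
        return self.array.shape[0]

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Indexing returns a plain numpy array; no dask task graph or
        # xarray label-based indexing overhead
        return torch.from_numpy(self.array[idx])

# Usage (placeholder names):
# dataset = ZarrChunkDataset("/path/to/data.zarr", "some_variable")
# loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4)
```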
Higher CPU count will only help to a certain extent. You’ll need to work out the ratio of how many CPUs it takes to keep 1 GPU saturated (see also Cloud native data loaders for machine learning using Zarr and Xarray | Earthmover). The math is something like this:
- If your model’s forward pass on one data chunk takes 0.2s on 1 GPU, and it takes 1.0s for 1 CPU to load the data, you would need `1.0 / 0.2 = 5` CPUs to keep the GPU saturated. So 8 GPUs would require 40 CPUs.
- If your model’s forward pass on one data chunk takes 0.05s on 1 GPU, and it takes 1.0s for 1 CPU to load the data, you would need `1.0 / 0.05 = 20` CPUs to keep the GPU saturated. So 8 GPUs would require 160 CPUs.
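In code form, the same back-of-the-envelope estimate (the timings are the assumed example numbers above, not measurements):

```python
# Assumed example timings from the bullet points above
cpu_load_time = 1.0   # seconds for 1 CPU worker to load one data chunk
gpu_step_time = 0.2   # seconds for the model's forward pass on that chunk
num_gpus = 8

cpus_per_gpu = cpu_load_time / gpu_step_time  # 5.0
total_cpus = num_gpus * cpus_per_gpu          # 40.0
print(f"~{cpus_per_gpu:.0f} CPUs per GPU, ~{total_cpus:.0f} CPUs for {num_gpus} GPUs")
```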
As GPUs get faster and faster, you’ll realize that CPUs won’t be able to scale to saturate the GPUs. At that point, you’ll need to use something like kvikIO (see Zarr — kvikio 24.10.00 documentation), which implements direct-to-GPU reads from Zarr, but it will depend on whether your HPC cluster supports GPUDirect Storage. Also, network I/O bandwidth might be a limitation here.
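If GPUDirect Storage is available, a direct-to-GPU read with kvikIO might look roughly like the sketch below; this follows the pattern in the kvikio docs linked above, but the store path is a placeholder and the exact API should be double-checked against your installed kvikio/zarr versions:

```python
import cupy
import zarr
import kvikio.zarr

# GDSStore reads chunk bytes via GPUDirect Storage; passing a CuPy meta_array
# asks zarr to hand back CuPy (GPU) arrays instead of NumPy arrays.
store = kvikio.zarr.GDSStore("/path/to/data.zarr")  # placeholder path
z = zarr.open_array(store, mode="r", meta_array=cupy.empty(()))

gpu_chunk = z[:256, :256]  # cupy.ndarray living in GPU memory
```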
Sure, I think others will be keen to see this too! There might be a few more optimizations we can try out, but I think you’re doing a good job figuring out and resolving most of the bottlenecks already.