Machine Learning training with large Xarray Datasets

Hi all,

I’d like to share a common workflow I face in land cover mapping tasks and hear your thoughts on how best to handle part of it with Xarray/Dask.

The general idea is to use a global land cover product as ground truth for a reference year (e.g., ESA WorldCover 2021), and then train a machine learning model to predict land cover in other years based on Earth Observation data.

My setup typically looks like this:

  • I have a huge, lazy xarray.Dataset of EO features — for example, yearly aggregations (e.g., median, percentiles) of Sentinel-2 L2A bands over a large region.
  • I also have a lazy xarray.DataArray of labels covering the same region. These labels are derived from the land cover product (after some cleaning) and are fairly dense, i.e., most pixels carry a valid class.
  • I want to train an in-memory ML model such as Random Forest, XGBoost, or LightGBM; these tree-based models work well for this kind of task. A rough sketch of this setup is shown right after this list.
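
For concreteness, here is a minimal sketch of what the inputs look like. The paths, variable names, and Zarr format are placeholders, not my literal pipeline:

```python
import xarray as xr

# Hypothetical lazy inputs; paths and variable names are illustrative only.
features = xr.open_zarr("s2_l2a_2021_yearly_features.zarr")        # Dataset of per-band yearly aggregations
labels = xr.open_zarr("worldcover_2021_cleaned.zarr")["lc_class"]  # DataArray of land cover class codes

# Both live on the same grid (dims assumed to be ("y", "x")), so the labels
# can be attached as an extra variable. Everything stays Dask-backed and lazy.
data = features.assign(label=labels)
```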

The challenge is that I obviously can’t use all the available labeled pixels due to memory constraints. So I’m looking to:

  • Perform random stratified subsampling of the labeled pixels, ideally keeping everything distributed/lazy until the actual sampling step.
  • Then compute and bring only the sampled pixels into memory for training the model. A rough sketch of what I have in mind follows this list.
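
For context, here is a rough sketch of the kind of approach I have in mind, continuing from the snippet above. The nodata value (0), the sampling fraction, the column names, and the per-partition strategy are all assumptions rather than a settled design:

```python
from sklearn.ensemble import RandomForestClassifier

# Flatten pixels into a lazy table: one row per pixel, with one column per
# feature plus the "label" column attached earlier and the x/y coordinates.
ddf = data.to_dask_dataframe()

# Drop unlabeled pixels lazily (assuming 0 encodes "no label" after cleaning).
ddf = ddf[ddf["label"] != 0]

def stratified_sample(pdf, frac, seed=42):
    """Sample the same fraction of rows from every class within one partition."""
    return (
        pdf.groupby("label", group_keys=False)
           .apply(lambda g: g.sample(frac=frac, random_state=seed))
    )

# Sampling per partition keeps the operation embarrassingly parallel, but it
# only approximates a globally stratified sample (and only if chunks are
# spatially mixed enough that every partition sees every class).
sampled = ddf.map_partitions(stratified_sample, frac=0.001)

# Only at this point does the sampled subset actually come into memory.
train = sampled.compute()

feature_cols = [c for c in train.columns if c not in ("x", "y", "label")]
model = RandomForestClassifier(n_estimators=200, n_jobs=-1)
model.fit(train[feature_cols], train["label"])
```

This "flatten to a Dask DataFrame and sample per partition" pattern seems workable, but it only preserves class balance approximately and gives no control over exact per-class counts (rare classes especially), which is part of what prompts the questions below.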

Has anyone tackled something similar? Do you have suggestions on:

  • Efficiently implementing stratified sampling across Dask-backed Xarray objects?
  • Any tools or best practices you recommend for subsampling while retaining class balance?
  • Potential gotchas in performing this kind of reduction before loading the sampled data into memory?

Thanks in advance for any insights!