Advice on writing many slices from one remote zarr xarray to another

I have spent the last few weeks getting more and more familiar with xarray, zarr and dask. My goal has been to build a (small) chip datastore from an intermediate mosaic datastore that is much larger: 27 GB now, and it could grow to hundreds of GB. We have labels for small regions in many TIFF files. We use them to extract a chip dataset for training, and then we make inference over the entire mosaics. So I am hoping to use xarray for both training and inference.

Given the labels (targets, responses, ys, etc.), which have lat and lon coords, I slice out a chip (a window around the lat/lon coords) from the mosaic datastore and want to store this constant-size chip in a target chip datastore.
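A minimal sketch of the windowing step, assuming the label locations are already in the same coordinate system as the mosaic's x/y coordinates and that chips are square. `window_slices` and `half_width` are hypothetical names, not part of xarray. The function snaps the label to the nearest pixel and clamps the window at the array edges so every chip has the same size; the returned integer slices can be passed to `Dataset.isel`.

```python
import numpy as np


def window_slices(x_coords, y_coords, x0, y0, half_width):
    """Return integer (x, y) slices for a square window centred on (x0, y0).

    Hypothetical helper. Assumes 1-D coordinate arrays (e.g. ds.x.values),
    increasing or decreasing, in the same CRS as the label location.
    """
    size = 2 * half_width + 1
    if len(x_coords) < size or len(y_coords) < size:
        raise ValueError("window is larger than the mosaic")
    # Index of the pixel whose coordinate is nearest to the label location.
    ix = int(np.abs(np.asarray(x_coords) - x0).argmin())
    iy = int(np.abs(np.asarray(y_coords) - y0).argmin())
    # Clamp the start so the window stays inside the array and keeps a constant size.
    ix0 = min(max(ix - half_width, 0), len(x_coords) - size)
    iy0 = min(max(iy - half_width, 0), len(y_coords) - size)
    return slice(ix0, ix0 + size), slice(iy0, iy0 + size)
```

For example, `mosaic.isel(x=xs, y=ys)` with the returned slices would give a chip of `2 * half_width + 1` pixels on each side. Positional slices avoid the off-by-one differences in chip size that label-based `sel` ranges can produce with floating-point coordinates.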

The mosaic datastore looks like this:

[Image: repr of the mosaic dataset]

I struggled for a while to find a way to build spatial mosaics without adding tile as a dimension, but here we are.

And the target chip dataset looks like this (note the one extra data variable for the response):

[Image: repr of the target chip dataset]

My problem is illustrated in the code below:

from functools import partial

import xarray as xr
from dask.distributed import Client

client = Client()  # or Client(address) to connect to an existing cluster

write_to_store = True  # should this be True or False?
open_zarr_inside = False  # should this be True or False?

# "big" as in only 27 GB
big_dset = None
if not open_zarr_inside:
    # opened lazily: the variables are backed by dask arrays
    big_dset = xr.open_zarr("remote_gcs_path")

def build_chips(one_slice, write_to_store, open_zarr_inside, big_dset=None):
    x_slice, y_slice, specific_time, specific_tile = one_slice

    if open_zarr_inside:
        # opened lazily: the variables are backed by dask arrays
        big_dset = xr.open_zarr("remote_gcs_path")

    # The actual function is more complicated, since coords and dims are changed
    # to match the target, but this suffices to show the idea.
    small_slice = big_dset.sel(x=x_slice, y=y_slice, time=specific_time, tile=specific_tile)

    # Check 'neither' first: a non-empty string is truthy, so it would
    # otherwise fall into the write branch.
    if write_to_store == 'neither':
        # only do the slicing (e.g. to time it); nothing is written or returned
        return True
    if write_to_store:
        # write out to an existing target
        small_slice.to_zarr(..., region=...)
        return True
    return small_slice

# fast, no xarray or dask ops here; maybe 2500 slices
slices = generate_xyt_slice(...)

builder = partial(
    build_chips,
    write_to_store=write_to_store,
    open_zarr_inside=open_zarr_inside,
    big_dset=big_dset,
)
# map the partial (not the bare build_chips) so the keyword arguments are applied
futures = client.map(builder, slices)
results = client.gather(futures)

# If write_to_store is True, things write out, but I can only get about 100 slices
# working before things go wrong.
# If write_to_store is False, returning all the datasets to the client is not ideal.
if write_to_store is False:
    chips = xr.concat(results, dim='t')

I am struggling to scale this up beyond 100 slices, so I am reaching out for advice or best practices on the questions that have come up in this example:

  1. Do I pass in the opened mosaic dataset, or open it inside the function? That is, should open_zarr_inside be True or False? Opening it takes roughly 100 ms, so maybe opening inside is fine as long as the computation amortizes that cost?
  2. Do I keep writing out to the target within build_chips, or should I somehow use dask.delayed to collect all the datasets within dask distributed, merge them, and then write out? (I have tried many combinations of dask.delayed with no luck. My guess is that transferring xr.Datasets between the client and workers, or between workers, is a bad idea.)
  3. Perhaps I should pass batches of slices to build_chips and concatenate each batch inside build_chips, to reduce the number of to_zarr writes? This is the last option I can think of that I have not yet tried. It seems promising because it avoids transferring xr.Datasets and minimizes the number of writes to the target chip dataset.
  4. What are the general best practices for handling xr.Datasets in this context? Should we always avoid passing them between functions, delayed or not?
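On question 1, one pattern worth considering (a sketch, not what this thread settled on) is to open the store at most once per worker process and reuse the object, so the roughly 100 ms open cost is paid once per process rather than once per task. `make_cached_opener` is a hypothetical helper; in practice `open_fn` would be `xr.open_zarr`.

```python
from functools import lru_cache


def make_cached_opener(open_fn):
    """Wrap an opener (e.g. xr.open_zarr) so each process opens a given path once."""

    @lru_cache(maxsize=None)
    def cached_open(path):
        # The first call in a process pays the open cost; later calls with the
        # same path return the already-opened (lazy) dataset object.
        return open_fn(path)

    return cached_open
```

It would be used as `open_mosaic = make_cached_opener(xr.open_zarr)`, with `big_dset = open_mosaic("remote_gcs_path")` inside `build_chips`. This assumes the helper lives in a module the workers import, so the cache persists in each worker process; if it is defined in the driver script and shipped with every task, the cache may not survive between tasks. Several threads making their first call at the same moment may each open the store once.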

Any advice is greatly appreciated!

Is the scaling problem that it is too slow, or are you running into out-of-memory problems?

I think it's from far too many tasks: with 9 workers at 4 CPUs and 16 GB each, some workers get 10,000 tasks. At one point I also got an error from gcsfs saying there were too many requests to GCS.

The slices are pretty small.

Now I group the slice tasks by tile, send each group as a batch to be sliced and written, and load the tile dataset with .load(). I am seeing much better performance, which makes sense: a manageable task count, fewer slice downloads, and fewer writes with to_zarr(region=…).
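The grouping step described above can be sketched in plain Python. This assumes each slice is the same `(x_slice, y_slice, time, tile)` tuple that `build_chips` unpacks; `group_slices_by_tile` is a hypothetical name.

```python
from collections import defaultdict


def group_slices_by_tile(slices):
    """Group (x_slice, y_slice, time, tile) tuples into one list per tile."""
    groups = defaultdict(list)
    for s in slices:
        tile = s[3]  # tile is the last element, matching build_chips' unpacking
        groups[tile].append(s)
    # A plain dict keeps tiles in first-seen order.
    return dict(groups)
```

A per-tile task could then be submitted as `client.map(build_tile_chips, list(group_slices_by_tile(slices).items()))`, where `build_tile_chips` is a hypothetical function that loads its tile once, cuts every chip, and writes them together. With thousands of slices but only tens of tiles, this turns thousands of tiny tasks into one task per tile.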

Yes, definitely sounds like too many tasks in the first instance there.

How is this going now?

So far, so good. With Dask, I am able to process 22 tiles (about 27 GB) with 6 time elements (6 years; we measure tree growth, so our temporal resolution is low) into an intermediate mosaic dataset, and then build a chip datastore from that mosaic datastore, in under 3 minutes. I can also run inference over pixels with these datastores much faster than before, which makes experimentation much easier and quicker.

Indexing the mosaic target as (tile, time, x, y) lets me parallelize easily at the tile level, which is intuitive, and lets me build the chips for a tile after calling mosaics.sel(tile=tile).load(). This keeps the task count down, but I feel like I am not really taking advantage of Dask graphs, which are honestly still new to me. I have made a lot of progress just using client.map(), and I have not yet had any success using dask.delayed() for anything.
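When each per-tile task writes all of its chips at once, each task needs its own non-overlapping block along the chip dimension. A minimal sketch of computing those blocks, assuming the chips are stacked along `t` (the dimension used in the `xr.concat` call in the question) and that the number of chips per tile is known up front from the labels; `tile_regions` is a hypothetical helper.

```python
from itertools import accumulate


def tile_regions(chips_per_tile, dim="t"):
    """Map each tile to a non-overlapping `region` dict along the chip dimension.

    chips_per_tile: dict {tile: number_of_chips}, in the order the chips are laid out.
    """
    counts = list(chips_per_tile.values())
    # Start offset of each tile's block = running total of the preceding counts.
    starts = [0, *accumulate(counts)][:-1]
    return {
        tile: {dim: slice(start, start + n)}
        for tile, start, n in zip(chips_per_tile, starts, counts)
    }
```

The target store would first be created at its full length (for example by writing a lazily built template with `to_zarr(..., compute=False)`), after which each tile task could call `tile_chips.to_zarr(store, region=regions[tile])`; `tile_chips` and `store` are placeholders. Parallel region writes are only safe when regions do not share zarr chunks, so this also assumes the chip dimension is chunked one chip per chunk, or that block boundaries line up with chunk boundaries.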

I'm hoping to scale up and test 100 and then 1000 tiles (122 GB and 1.7 TB, respectively).

One downside of adding tile to the index is that it is harder to visualize the data as a spatially merged whole. On the other hand, given a location I can easily determine which tile to slice, without resorting to larger lat/lon comparisons.
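Finding the right tile for a label location can be as simple as a bounding-box lookup. A minimal sketch, assuming each tile's extent has been recorded once as `(min_lon, min_lat, max_lon, max_lat)` in the same coordinates as the labels; `find_tile` and `tile_bounds` are hypothetical names.

```python
def find_tile(tile_bounds, lon, lat):
    """Return the first tile whose bounding box contains (lon, lat), else None.

    tile_bounds: dict {tile: (min_lon, min_lat, max_lon, max_lat)}.
    """
    for tile, (min_lon, min_lat, max_lon, max_lat) in tile_bounds.items():
        # Half-open on the max side so a point on a shared edge maps to one tile.
        if min_lon <= lon < max_lon and min_lat <= lat < max_lat:
            return tile
    return None
```

A linear scan is cheap at the 100 to 1000 tile scale mentioned above; the result can be passed straight to `mosaics.sel(tile=...)`.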