Efficiently slicing random windows for reduced xarray dataset

So there's the memory used in doing the stacking plus the memory used by the data itself…how many of these end up on a worker at once? And the coord calculation adds to that too.

It fails even when a single worker has only one task. The stacked window might still be a view of the data, but I'm afraid there isn't a view into the coords or dims. If even one of those gets copied and is float64, that would be 15*15*8000**2*8/1e9 ≈ 115 GB for 8000x8000 chunks with 15x15 windows.
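Spelled out, that back-of-the-envelope estimate is (assuming float64 and a fully materialized copy of the windowed coordinate):

window = 15          # 15x15 windows
chunk_side = 8000    # 8000x8000 chunks
bytes_per_value = 8  # float64
print(window * window * chunk_side**2 * bytes_per_value / 1e9)  # ~115 GB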

I’ve tried to do the same thing with a smaller-chunked dataset, but that produces many tasks and it takes too long to finish even on a subset. My guess is that with so many tasks, the workers end up exchanging all that data rather than each chunk being loaded on the worker its task was sent to.

Yes, might need a hardcore dask expert to say if this is possible at all without much bigger workers.

@dcherian @rabernat do you know if there is any hope for me here? Is there another approach to extracting these windows that I am not thinking about?

This workflow is actually strikingly similar to the one we used for this paper:

Our code is online here: https://github.com/ocean-transport/surface_currents_ml. That work used “stencils” (equivalent to your “chips”) of 2x2, 3x3, and 4x4 for training models at each point. And we also had to drop the NaN points (which in our case corresponded to land).

We experimented with workflows that used xbatcher. We used the input_overlap feature of xbatcher to achieve the sliding windows. However, I don’t think we ended up using that for the final workflow.
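For concreteness, the sliding windows via input_overlap look roughly like this (a minimal sketch, not our actual training code; the dimension names are assumptions):

import xbatcher

# 5x5 windows that slide by one point: overlap = window size - 1
bgen = xbatcher.BatchGenerator(
    ds,
    input_dims={'x': 5, 'y': 5},
    input_overlap={'x': 4, 'y': 4},
)

for batch in bgen:
    ...  # each batch is a 5x5 (in x/y) xarray Dataset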

This notebook - https://github.com/ocean-transport/surface_currents_ml/blob/master/train_models_stencil_in_space.ipynb - shows a way of accomplishing what you are looking for using just reshaping and stacking. However, I don’t think it handles the overlapping stencils.
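In rough xarray terms, the non-overlapping version of that approach boils down to something like this (a sketch, not the notebook's exact code):

# flatten the spatial dims into one sample dimension, then drop land points
stacked = ds.stack(sample=('y', 'x'))
training_points = stacked.dropna(dim='sample')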

If your original data are Zarr, you might consider not actually using dask when you open the data. This gives you more control over how the dask graph is constructed. You might do something like this (warning: untested pseudocode), which constructs a dask array lazily via dask.delayed:

import numpy as np
import xarray as xr
import dask
import dask.array as dsa

# open lazily with the zarr backend, but don't create dask arrays yet
ds = xr.open_dataset('data.zarr', engine='zarr', chunks=None)

# get the indices of the valid center points somehow
# (np.where returns a tuple of j-indices and i-indices)
center_points = np.where(ds.mask.notnull())

# this operates on one DataArray at a time and returns a numpy array
@dask.delayed
def load_chip(da: xr.DataArray, j, i, chip_size=2) -> np.ndarray:
    # chip_size=2 gives a 5x5 window centered on (j, i);
    # assumes the center points are far enough from the array edges
    chip = da.isel(
        x=slice(i - chip_size, i + chip_size + 1),
        y=slice(j - chip_size, j + chip_size + 1),
    )
    return chip.values  # this triggers loading

all_chips = [
    dsa.from_delayed(load_chip(ds["variable"], j, i), (5, 5), dtype=ds["variable"].dtype)
    for j, i in zip(*center_points)
]

big_array = dsa.stack(all_chips)
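From there you could rechunk and wrap the result back into xarray for downstream work, e.g. (again untested; names are placeholders):

# stacking leaves one chip per dask chunk; rechunk so each chunk holds many chips
big_array = big_array.rechunk({0: 10_000})

chips = xr.DataArray(big_array, dims=('sample', 'y', 'x'), name='chips')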

Thanks a ton for the tips, @rabernat! It is greatly appreciated. I will take a look at those examples and try the snippet of code to see what happens.