This workflow is actually strikingly similar to the one we used for this paper:
Our code is online here: ocean-transport/surface_currents_ml on GitHub. That work used "stencils" (equivalent to your "chips") of 2x2, 3x3, and 4x4 points for training models at each grid point. We also had to drop the NaN points, which in our case corresponded to land.
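For anyone following along, the stencil idea with NaN dropping looks roughly like this in plain NumPy. It is a minimal sketch, assuming a 2D field with NaN on land; `valid_stencils` is a hypothetical helper for illustration, not a function from the surface_currents_ml repo. It uses stride-1 windows, so neighbouring stencils overlap.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def valid_stencils(field: np.ndarray, size: int) -> np.ndarray:
    """Return every size x size stencil of a 2D field that contains no NaNs.

    Hypothetical helper: stencils slide with stride 1, so they overlap.
    """
    # shape (ny - size + 1, nx - size + 1, size, size); a strided view, no copy yet
    windows = sliding_window_view(field, (size, size))
    # merge the two window-position axes into one "sample" axis (copies here)
    flat = windows.reshape(-1, size, size)
    # keep only stencils without any NaN (i.e. no land points)
    ok = ~np.isnan(flat).any(axis=(1, 2))
    return flat[ok]


field = np.arange(16, dtype=float).reshape(4, 4)
field[0, 0] = np.nan  # pretend one corner point is land
for size in (2, 3, 4):
    print(size, valid_stencils(field, size).shape)
# 2 (8, 2, 2)
# 3 (3, 3, 3)
# 4 (0, 4, 4)
```

For even sizes like 2x2 and 4x4 there is no single center point, so you have to pick a convention for which grid point each stencil is attached to.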
We experimented with workflows that used xbatcher, relying on its `input_overlap` feature to get the sliding windows. However, I don't think we ended up using that in the final workflow.
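As I understand xbatcher's semantics, `input_overlap` sets how many points adjacent windows share along each dimension, so consecutive windows start `input_dims - input_overlap` points apart. The NumPy sketch below reproduces only that stride arithmetic, without xbatcher; `overlapping_windows` is a hypothetical helper, not xbatcher's implementation, and windows that would run past the edge are simply dropped.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def overlapping_windows(field: np.ndarray, size: int, overlap: int) -> np.ndarray:
    """Cut a 2D field into size x size windows that share `overlap` points.

    Hypothetical helper: adjacent windows start (size - overlap) points apart.
    """
    if not 0 <= overlap < size:
        raise ValueError("overlap must satisfy 0 <= overlap < size")
    stride = size - overlap
    # take all stride-1 windows, then keep every `stride`-th one along each axis
    windows = sliding_window_view(field, (size, size))[::stride, ::stride]
    return windows.reshape(-1, size, size)


field = np.arange(36).reshape(6, 6)
print(overlapping_windows(field, 3, 2).shape)  # stride 1 -> (16, 3, 3)
print(overlapping_windows(field, 3, 0).shape)  # stride 3 -> (4, 3, 3)
```

Setting `overlap = size - 1` gives one window per grid point (away from the edges), which is the fully overlapping case.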
This notebook, train_models_stencil_in_space.ipynb on the master branch of ocean-transport/surface_currents_ml on GitHub, shows a way of accomplishing what you are looking for using just reshaping and stacking. However, I don't think it handles overlapping stencils.
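The general reshape-and-stack trick looks something like the following. This is a hypothetical NumPy helper, not code copied from the notebook, and it assumes a 2D field. Because every grid point lands in exactly one block, the stencils cannot overlap, which is the limitation mentioned above.

```python
import numpy as np


def blocks(field: np.ndarray, size: int) -> np.ndarray:
    """Split a 2D field into non-overlapping size x size blocks.

    Hypothetical helper using only reshape/transpose; rows and columns that
    don't fill a whole block are trimmed off first.
    """
    ny, nx = field.shape
    trimmed = field[: ny - ny % size, : nx - nx % size]
    by, bx = trimmed.shape[0] // size, trimmed.shape[1] // size
    # (by, size, bx, size) -> (by, bx, size, size) -> (by * bx, size, size)
    return (
        trimmed.reshape(by, size, bx, size)
        .transpose(0, 2, 1, 3)
        .reshape(-1, size, size)
    )


field = np.arange(16).reshape(4, 4)
print(blocks(field, 2)[1])
# [[2 3]
#  [6 7]]
```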
If your original data are in Zarr, you might consider not using dask at all when you open the data. This gives you more control over the dask graph. You might do something like this (warning: untested), which lazily constructs a dask array from `dask.delayed` calls:
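```python
import dask
import dask.array as dsa
import numpy as np
import xarray as xr

ds = xr.open_dataset("data.zarr", engine="zarr", chunks=None)  # don't chunk yet

half = 2  # chips are (2 * half + 1) x (2 * half + 1) = 5 x 5
size = 2 * half + 1
ny, nx = ds.sizes["y"], ds.sizes["x"]

# get the list of valid center points somehow; here from a mask with dims (y, x)
jj, ii = np.nonzero(ds.mask.notnull().transpose("y", "x").values)
# skip centers whose chip would run off the edge of the domain
keep = (jj >= half) & (jj < ny - half) & (ii >= half) & (ii < nx - half)
center_points = list(zip(jj[keep], ii[keep]))

# this operates on one 2D (y, x) DataArray at a time and returns a numpy array
@dask.delayed
def load_chip(da: xr.DataArray, j: int, i: int, half: int = 2) -> np.ndarray:
    chip = da.isel(y=slice(j - half, j + half + 1), x=slice(i - half, i + half + 1))
    return chip.values  # this triggers loading of just this chip

all_chips = [
    dsa.from_delayed(load_chip(ds["variable"], j, i, half), (size, size), dtype=ds["variable"].dtype)
    for j, i in center_points
]
big_array = dsa.stack(all_chips)  # shape (n_chips, size, size)
```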