Efficiently slicing random windows for reduced xarray dataset

So what matters is the memory used by the stacking plus the memory used by the data itself (and by the coordinate calculation, too): how many of these end up on a worker at once?

It fails even when a single worker has a single task. The stacked windows might still be a view, but I'm afraid the coords or dims might not be views. If at least one of those gets copied and is float64, that would be 15x15x8000**2*8/1e9 = 115 GB for 8000x8000 chunks with 15x15 windows.
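The back-of-envelope number above can be reproduced in a few lines. This is a minimal sketch, assuming float64 (8-byte) values and ignoring edge windows; the helper name materialized_window_gb is hypothetical. The second half uses NumPy's sliding_window_view (NumPy 1.20 or later) on a small array to show that a windowed view shares memory with its source, while its logical size is roughly window² times larger, which is what gets allocated if anything downstream forces a copy.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def materialized_window_gb(chunk: int, window: int, itemsize: int = 8) -> float:
    """Approximate GB needed to copy out every window x window chip of a
    chunk x chunk array (hypothetical helper; edge windows ignored)."""
    return window * window * chunk**2 * itemsize / 1e9


print(materialized_window_gb(8000, 15))  # 115.2 for float64, 8000x8000, 15x15

# Small demo: the view itself is free, but its logical size is ~window**2 larger.
a = np.zeros((800, 800), dtype="float64")
view = sliding_window_view(a, (15, 15))  # shape (786, 786, 15, 15), no copy
print(np.may_share_memory(view, a))      # True: still a view into `a`
print(view.nbytes / a.nbytes)            # size ratio if it were materialized
```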

I’ve tried the same thing with a more finely chunked dataset, but then there are so many tasks that it doesn't finish in a reasonable time, even on a subset. My guess is that, with that many tasks, the workers end up exchanging all that data rather than each task running on the worker that already holds its chunks.

Yes, we might need a hardcore dask expert to say whether this is possible at all without much bigger workers.

@dcherian @rabernat do you know if there is any hope for me here? Is there another approach to extracting these windows that I am not thinking about?

This workflow is actually strikingly similar to the one we used for this paper:

Our code is online here: GitHub - ocean-transport/surface_currents_ml. That work used “stencils” (equivalent to your “chips”) of 2x2, 3x3, and 4x4 for training models at each point. And we also had to drop the NaN points (which in our case corresponded to land).

We experimented with workflows that used xbatcher. We used the input_overlap feature of xbatcher to achieve the sliding windows. However, I don’t think we ended up using that for the final workflow.

This notebook - surface_currents_ml/train_models_stencil_in_space.ipynb at master · ocean-transport/surface_currents_ml · GitHub - shows a way of accomplishing what you are looking for using just reshaping and stacking. However, I don’t think it handles the overlapping stencils.
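To make the reshaping-versus-overlap distinction concrete, here is a NumPy-only sketch (not the notebook's code): reshape tiles a 2-D field into non-overlapping stencils, while sliding_window_view yields one overlapping stencil per interior center point, after which stencils whose center is NaN (land, in the paper's case) are dropped. The toy field, its NaN point, and the 3x3 stencil size are made up for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

ny, nx, s = 6, 9, 3                       # toy grid; s = stencil size (odd)
field = np.arange(ny * nx, dtype="float64").reshape(ny, nx)
field[2, 4] = np.nan                      # pretend this point is land

# Non-overlapping stencils via reshape: (ny//s) * (nx//s) tiles of shape (s, s).
tiles = field.reshape(ny // s, s, nx // s, s).swapaxes(1, 2).reshape(-1, s, s)

# Overlapping stencils: one (s, s) window per interior point, still a view.
windows = sliding_window_view(field, (s, s))   # (ny-s+1, nx-s+1, s, s)
half = s // 2
centers = windows[:, :, half, half]             # value at each window's center
valid = ~np.isnan(centers)                       # drop stencils centered on NaN
stencils = windows[valid]                        # (n_valid, s, s), a copy

print(tiles.shape, windows.shape, stencils.shape)  # (6, 3, 3) (4, 7, 3, 3) (27, 3, 3)
```

Note that the overlapping version keeps stencils that merely contain a NaN off-center; whether those should also be dropped depends on the model.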

If your original data are Zarr, you might consider not using dask at all when you open the data. This gives you more control over the dask graph. You might do something like this (warning: untested pseudocode), which constructs a dask array lazily via dask.delayed:

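```python
import numpy as np
import xarray as xr
import dask
import dask.array as dsa

ds = xr.open_dataset("data.zarr", engine="zarr", chunks=None)  # don't chunk yet

# get the list of valid points somehow; assumes `mask` has dims ("y", "x")
center_points = np.where(ds.mask.notnull())

chip_size = 2  # half-width; each chip is (2 * chip_size + 1) on a side
chip_shape = (2 * chip_size + 1, 2 * chip_size + 1)

# this operates on one 2-D (y, x) DataArray at a time and returns a numpy array
# note: centers within chip_size of an edge give truncated chips; filter them first
@dask.delayed
def load_chip(da: xr.DataArray, j, i, chip_size=2) -> np.ndarray:
    chip = da.isel(
        x=slice(i - chip_size, i + chip_size + 1),
        y=slice(j - chip_size, j + chip_size + 1),
    )
    return chip.values  # this triggers loading

all_chips = [
    dsa.from_delayed(load_chip(ds["variable"], j, i, chip_size), chip_shape,
                     dtype=ds["variable"].dtype)
    for j, i in zip(*center_points)  # unpack the (j_indices, i_indices) tuple
]

big_array = dsa.stack(all_chips)
```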

Thanks a ton for the tips, @rabernat! It is greatly appreciated. I will take a look at those examples and try the snippet of code to see what happens.

@rabernat I ended up not taking the from_delayed approach, since dask kept getting upset about the large number of tasks. I might revisit your suggested approach, but for now I chunked the spatial-mosaic xarray dataset over x and y and passed each chunk to roughly this function:

```python
import numpy as np
import xarray as xr


def extract_write_chunks(x_slice, y_slice, ...):
    """..."""
    dset = xr.open_zarr(store=store, group=group)
    subset = dset.sel(x=x_slice, y=y_slice)

    # (t, x, y) indices of every non-NaN target value in this chunk
    non_nan_idx = np.where(~np.isnan(subset[target_name]))
    indexes = np.stack(non_nan_idx).T

    # pull everything into plain NumPy once; indexing NumPy inside the loop
    # is much cheaper than calling .sel/.isel for every chip
    da = subset.to_array()
    array = np.asarray(da)
    da_lats = np.asarray(da.lat)
    da_lons = np.asarray(da.lon)
    da_times = np.asarray(da.times)

    non_null = []
    lats = []
    lons = []
    times = []
    for t, x, y in indexes:
        chip = array[:, t, x - pw : x + pw + 1, y - pw : y + pw + 1]
        # keep only full-size chips; points within pw of an edge give
        # truncated or empty slices
        if chip.shape[-2:] == (2 * pw + 1, 2 * pw + 1):
            lats.append(da_lats[y - pw : y + pw + 1])
            lons.append(da_lons[x - pw : x + pw + 1])
            times.append(da_times[t])
            non_null.append(chip)

    # write out chips to zarr
```

I found that indexing and slicing out all the chips was much, much faster after converting the data to a DataArray and then to NumPy arrays, rather than using .sel or .isel for each chip.
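Once the data are NumPy arrays, the per-point Python loop can also be replaced by a single fancy-indexing step on a sliding-window view. This is a sketch under stated assumptions: array has dims (variable, time, x, y) and the target has dims (time, x, y), matching the slicing in the loop above; extract_chips is a hypothetical name, and pw is the half-width as in the original. Points closer than pw to an edge are dropped, like the shape check in the loop.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def extract_chips(array: np.ndarray, target: np.ndarray, pw: int):
    """Return (chips, t, x, y) for every non-NaN target point whose full
    (2*pw+1) x (2*pw+1) window fits inside the array (hypothetical helper).

    array: (n_vars, n_time, nx, ny); target: (n_time, nx, ny)
    chips: (n_chips, n_vars, 2*pw+1, 2*pw+1)
    """
    w = 2 * pw + 1
    t, x, y = np.nonzero(~np.isnan(target))
    nx, ny = target.shape[1:]
    # Keep only centers whose window does not run off an edge.
    keep = (x >= pw) & (x < nx - pw) & (y >= pw) & (y < ny - pw)
    t, x, y = t[keep], x[keep], y[keep]
    # View of shape (n_vars, n_time, nx-w+1, ny-w+1, w, w); no copy yet.
    windows = sliding_window_view(array, (w, w), axis=(2, 3))
    # The window starting at (x - pw, y - pw) is centered on (x, y); this
    # fancy index copies only the selected chips.
    chips = windows[:, t, x - pw, y - pw]  # (n_vars, n_chips, w, w)
    return np.moveaxis(chips, 1, 0), t, x, y


# Usage with random data, checked against a plain slicing loop.
rng = np.random.default_rng(0)
arr = rng.normal(size=(3, 2, 20, 30))
tgt = np.where(rng.random((2, 20, 30)) < 0.3, np.nan, 1.0)
chips, t, x, y = extract_chips(arr, tgt, pw=2)
loop = [arr[:, ti, xi - 2 : xi + 3, yi - 2 : yi + 3] for ti, xi, yi in zip(t, x, y)]
print(chips.shape, np.array_equal(chips, np.stack(loop)))
```

The returned t, x, y indices can then be used to gather the matching times and coordinates in one step, instead of appending them inside the loop.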

We have successfully built 1 TB chip datasets in roughly 8 minutes or less on the cluster. Feels pretty good!

I have also implemented operations to support stratified splits, spatially balanced datasets, etc.
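For readers wondering what a spatially aware split can look like, here is one common pattern (not necessarily what was implemented here): group chips into square spatial blocks by their center index and assign whole blocks to either train or test, so that neighboring, overlapping chips do not leak across the split. The helper block_split, the block size, and the test fraction are illustrative assumptions.

```python
import numpy as np


def block_split(x, y, block, test_frac=0.2, seed=0):
    """Boolean (train, test) masks that assign whole spatial blocks of
    chips to one side (hypothetical helper).

    x, y: integer center indices of each chip; block: block edge length.
    """
    bx, by = np.asarray(x) // block, np.asarray(y) // block
    raw_id = bx * (by.max() + 1) + by  # unique id per (bx, by) block
    block_ids, inverse = np.unique(raw_id, return_inverse=True)
    n_test = max(1, int(round(test_frac * block_ids.size)))
    rng = np.random.default_rng(seed)
    test_blocks = rng.choice(block_ids.size, size=n_test, replace=False)
    test = np.isin(inverse, test_blocks)
    return ~test, test


# Usage on made-up chip centers from a 1000 x 1000 grid.
rng = np.random.default_rng(1)
x = rng.integers(0, 1000, size=5000)
y = rng.integers(0, 1000, size=5000)
train, test = block_split(x, y, block=100)
print(train.sum(), test.sum())
```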

Somewhat-related:

Now I need to figure out how the heck to serve this data to a TensorFlow model on a single worker (with 1-8 GPUs). That prompted my post over on the Dask Discourse about building tf.data.Datasets from dask arrays or delayed objects. Currently I'm just writing TFRecords to GCS with dask and reading them back with the tf.data API, but it is not as fast as I would like. I'm hoping to offload all the work to other machines/workers and have the GPU machine focus only on receiving the transferred data, decoding it, and loading it onto the GPUs. I looked at some of the TensorFlow models you trained above, but those datasets didn't look large enough to warrant this. I need to speed this up, ultimately so that I can bootstrap the training runs.
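One generic way to keep the GPU machine busy only with decoding and transfer is to overlap reading with consumption: load the next few shards on background threads while the current one is being used. Below is a minimal standard-library sketch, assuming a hypothetical load_shard(i) that returns one batch (for example by reading one TFRecord file or Zarr chunk from GCS); fake_load only simulates a slow read. A generator like this could be wrapped with tf.data.Dataset.from_generator(..., output_signature=...), although tf.data's own prefetch and parallel interleave may already cover much of this.

```python
import time
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Deque, Iterator, TypeVar

T = TypeVar("T")


def prefetching_generator(load_shard: Callable[[int], T], n_shards: int,
                          depth: int = 4) -> Iterator[T]:
    """Yield load_shard(0), ..., load_shard(n_shards - 1) in order while
    keeping up to `depth` loads in flight on background threads."""
    with ThreadPoolExecutor(max_workers=depth) as pool:
        pending: Deque[Future] = deque()
        next_i = 0
        while next_i < n_shards and len(pending) < depth:  # prime the queue
            pending.append(pool.submit(load_shard, next_i))
            next_i += 1
        while pending:
            result = pending.popleft().result()  # blocks only if not done yet
            if next_i < n_shards:                # refill so I/O keeps going
                pending.append(pool.submit(load_shard, next_i))
                next_i += 1
            yield result


def fake_load(i: int) -> list:
    """Stand-in for a slow remote read (hypothetical)."""
    time.sleep(0.05)
    return [i] * 3


start = time.perf_counter()
batches = list(prefetching_generator(fake_load, n_shards=8, depth=4))
elapsed = time.perf_counter() - start
print(batches[:2], f"{elapsed:.2f}s")
```

With four loads in flight, the eight simulated 50 ms reads overlap instead of running back to back.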

Has anyone tried to integrate dask with tensorflow datasets? Has anyone seen any working approaches here?

@Leonard_Strnad - you may be interested in the Xbatcher project. Currently, it provides a batch generator API and some prototype ML data loaders (for pytorch and tensorflow). On our road map is to tune dataloaders so they play nice with Dask when feeding data to gpu-backed models.

cc @maxrjones and @weiji14 who have been using/developing xbatcher lately.
