Efficiently slicing random windows for reduced xarray dataset

Just wanted to say that the .rolling code here might fit into the scope of GitHub - pangeo-data/xbatcher: Batch generation from xarray datasets. Cc @maxrjones and @jhamman.


You are stacking for ML shape reasons, Leonard?

I am stacking in order to drop windows that have nans in their center. Yes, I essentially want all m×m windows with non-nan centers for training. Eventually we will not restrict to centers and will instead take any window with a non-nan anywhere, for a sparse regression training scheme.

I can’t think of a way to extract windows with non-nan centers any other way. It’s easy to find the indices of all non-nans for the variable of interest, but I do not know how, given those indices and a window size, to slice out an (n, m, m) xarray dataset (n would be the number of non-nan pixels and m would be the window width).
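
For a single in-memory 2D array, one way to get that (n, m, m) slice is numpy's sliding_window_view; here is a rough sketch with made-up data, ignoring the dask/chunking question for the moment:

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

m = 7
c = m // 2
arr = np.random.rand(100, 100)
arr[arr < 0.99] = np.nan                            # sparse non-nan targets

view = sliding_window_view(arr, (m, m))             # shape (H-m+1, W-m+1, m, m), no copy
xs, ys = np.where(~np.isnan(arr))
interior = (
    (xs >= c) & (xs < arr.shape[0] - c) & (ys >= c) & (ys < arr.shape[1] - c)
)
# (n, m, m) stack of windows whose centers are non-nan and not on the border
chips = view[xs[interior] - c, ys[interior] - c]
print(chips.shape)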

Once I write this to zarr/gcs, I have a training dataset for the team to start model development.

I haven’t seen this package yet. It looks extremely helpful.

Support for sparse targets like this “chip” example (given an m×m window, make a single-pixel prediction at the center; very common in biomass/canopy-height regression) and for splitting into training/test/validation sets (potentially based on masks) would be oh so nice.

Are you stacking inside the function being mapped? This is key. Stack is a reshape, which can be very, very inefficient in parallel settings.
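
A tiny illustration of that point with dummy data (nothing from the real dataset, just to see how stack changes the dask chunk structure):

import dask.array as dsa
import xarray as xr

# small dask-backed DataArray, chunked along all three dims
da = xr.DataArray(
    dsa.ones((4, 6, 6), chunks=(1, 3, 3)), dims=("time", "x", "y")
)
print(da.chunks)        # ((1, 1, 1, 1), (3, 3), (3, 3))

stacked = da.stack(z=("time", "x", "y"))
print(stacked.chunks)   # stack is a dask reshape: the new z chunks cut across the original blocks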


Indeed I am, but I have chunk sizes of (2, 8192, 8192) over 42 variables, which gives a stacked (z) index of length 134M.

I think that the success of this method will require smaller chunks, but as I stated above, that means more nans in windows near each chunk’s boundary…

This is roughly the function I am mapping:

import numpy as np

width = 7
target_name = 'b'

def extract_chips(subdset):
    # build an (x2, y2) window view over the spatial dims
    rolling_obj = subdset.rolling({"x": width, "y": width})
    windowed_dset = rolling_obj.construct(window_dim={"x": "x2", "y": "y2"})

    # stack over x and y in order to select windows with non-nan centers
    center = width // 2
    stacked_dset = windowed_dset.stack(z=("time", "x", "y"))

    # this seems to kill workers
    nonnan_idx = np.where(
        ~np.isnan(stacked_dset[target_name].sel(x2=center, y2=center))
    )[-1]

    # select the windows with non-nan centers!
    chips = stacked_dset.isel(z=nonnan_idx)

    # write out chips to zarr-backed xarray on gcs

I might try pre-computing the non-nan index before constructing windowed_dset:

idx = np.where(~np.isnan(subdset[target_name]))
flattened_idx = somehow_flatten_idx(idx)

and avoid the

    nonnan_idx = np.where(
        ~np.isnan(stacked_dset[target_name].sel(x2=center, y2=center))
    )[-1]

operation.
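
For what it's worth, a sketch of what that flattening could look like, assuming the same z=("time", "x", "y") stack order as above (the helper name here is made up):

import numpy as np

def flatten_idx(nonnan_idx, shape):
    # map (time, x, y) integer index arrays to positions along the stacked z
    # dimension, assuming z = ("time", "x", "y") in row-major order;
    # equivalent to t * (nx * ny) + x * ny + y
    return np.ravel_multi_index(nonnan_idx, shape)

# hypothetical usage, with the non-nan index computed on the unstacked data:
# idx = np.where(~np.isnan(subdset[target_name]))
# z_idx = flatten_idx(idx, subdset[target_name].shape)
# chips = stacked_dset.isel(z=z_idx)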


Sorry if I’m being dense, but are you applying extract_chips with xarray.map_blocks? A full reproducible minimal example with dummy data would be useful.


Not at all. Yes, I am doing that. Well, I tried that but was struggling to see what was going on during the map, so I have been doing something like

import dask
import numpy as np
from itertools import product

# chunk start offsets along x and y, paired with each chunk's size
x_starts = np.cumsum([0] + list(dset.chunks['x'])[:-1])
x_start_step = zip(x_starts, dset.chunksizes['x'])
y_starts = np.cumsum([0] + list(dset.chunks['y'])[:-1])
y_start_step = zip(y_starts, dset.chunksizes['y'])

futures = []
for (x_start, x_step), (y_start, y_step) in list(product(x_start_step, y_start_step))[:2]:
    x_slice = slice(x_start, x_start + x_step)
    y_slice = slice(y_start, y_start + y_step)
    futures.append(dask.delayed(extract_chips)(x_slice, y_slice))

where extract_chips does the slicing.

I will put together a full, minimal example and post today or tomorrow.


@dcherian Here is a fully reproducible example with a dataset that has roughly the same number of non-nans in the target and the same height and width as my actual dataset, but fewer variables. I write it out to zarr to avoid additional tasks, to replicate my exact process, and to isolate the compute issues to the chip extraction.

My Dask cluster has 36 workers, each with 16 GB and 4 cores, using HelmCluster.

import dask
import dask.array
import numpy as np
import xarray as xr


target_path = "gs://somewhere/out/there"
features = ['a','b','c','d','e','f']
times = [2018, 2019]
final_height = 40000
final_width = 40000
xy_chunksize = 8192
n_non_nans = 10e6
p_non_nan = 1 - (n_non_nans / (final_height * final_width))

dummy_data = dask.array.random.random(
    (len(times), final_width, final_height),
    chunks=[1, xy_chunksize, xy_chunksize],
)
dummy_data = dummy_data.astype('float32')

spatial_dset = xr.Dataset(
    data_vars={ ftr: (["time", "x", "y"], dummy_data) for ftr in features},
    coords={
        "lon": (["x"], np.arange(final_width)),
        "lat": (["y"], np.arange(final_height)),
        "times": (("time"), times),
    },
)

# fake sparse non-nan value data in target
spatial_dset = spatial_dset.assign(
    {'target': xr.where(spatial_dset['f'] < p_non_nan, np.nan, True)}
)

# write to zarr somewhere
spatial_dset.to_zarr(target_path, group='example_mosaic')

# inspect and check
rt_spatial_dset = xr.open_zarr(target_path, group='example_mosaic')
idx = np.where(~np.isnan(rt_spatial_dset['target']))
assert np.abs(len(idx[1]) / len(times) - n_non_nans) < 5000


def extract_write_chip_dset(
    x_slice: slice,
    y_slice: slice,
    width: int,
    target_name: str,
):
    if width % 2 != 1:
        raise ValueError("Width must be odd for non-nan to be at the center.")

    dset = xr.open_zarr(target_path, group='example_mosaic')
    subdset = dset.sel(x=x_slice, y=y_slice)

    # temporarily assign lat and lon to variables indexed by x or y
    subdset = subdset.assign(
        {"lat2": ("y", subdset.lat.data), "lon2": ("x", subdset.lon.data)}
    )

    # use a rolling object to build a windowed view over the entire spatial mosaic
    rolling_obj = subdset.rolling({"x": width, "y": width})
    windowed_dset = rolling_obj.construct(window_dim={"x": "x2", "y": "y2"})


    # stack over x and y in order to select windows with non-nan centers
    center = width // 2
    stacked_dset = windowed_dset.stack(z=("time", "x", "y"))
    #return stacked_dset.dims
    nonnan_idx = np.where(
        ~np.isnan(stacked_dset[target_name].sel(x2=center, y2=center))
    )[-1]
    chips = stacked_dset.isel(z=nonnan_idx)
    
    # should write to a zarr chip dataset here, but I return the dims to
    # first inspect that we can extract chips in the first place, something
    # that map_blocks does not allow us to do.

    return chips.dims


import dask
from itertools import product

dset = rt_spatial_dset
dset = dset.unify_chunks()
x_starts = np.cumsum([0] + list(dset.chunks['x'])[:-1])
x_start_step = zip(x_starts, dset.chunksizes['x'])
y_starts = np.cumsum([0] + list(dset.chunks['y'])[:-1])
y_start_step = zip(y_starts, dset.chunksizes['y'])

futures = []
for (x_start, x_step), (y_start, y_step) in list(product(x_start_step, y_start_step))[:5]:
    x_slice = slice(x_start, x_start + x_step)
    y_slice = slice(y_start, y_start + y_step)
    futures.append(dask.delayed(extract_write_chip_dset)(x_slice, y_slice, 3, 'target'))  


results = dask.compute(*futures)

Before I try map_blocks, I am (as I alluded to above) manually creating delayed tasks over the chunks so that I can inspect the chips first. The issue with map_blocks is that it is a bit harder to inspect what is going on inside the function being mapped. I have also gotten into the habit of opening and slicing the dataset inside the function, which is why I pass slices instead of a sliced dataset.

Anyway, I get an error here:

...
     25 stacked_dset = windowed_dset.stack(z=("time", "x", "y"))
     26 #return stacked_dset.dims
---> 27 nonnan_idx = np.where(
     28     ~np.isnan(stacked_dset[target_name].sel(x2=center, y2=center))
     29 )[-1]
     30 chips = stacked_dset.isel(z=nonnan_idx)
     32 return chips.dims
...
KilledWorker: ("('getitem-overlap-reshape-transpose-invert-f9ec729b6f057728d617ffb36be4dc89', 1)", <WorkerState 'tcp://10.100.15.4:45683', status: closed, memory: 0, processing: 3>)

The other issue I am encountering is that when things do work out (with smaller widths and chunk sizes), I get more nan windows near the borders of the chunks, so I lose data, which is not ideal.
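
A possible workaround, sketched below and untested: extend each chunk slice by a halo of half the window width so that border windows are complete, at the cost of reading a little extra data per task.

width = 15          # window width used in extract_write_chip_dset
halo = width // 2   # extra pixels to read on each side of each chunk

def haloed_slice(start, step, size):
    # extend the chunk slice by `halo` on each side, clipped to the mosaic bounds
    return slice(max(start - halo, 0), min(start + step + halo, size))

# in the driver loop above, the slices would become, e.g.:
#   x_slice = haloed_slice(x_start, x_step, dset.sizes["x"])
#   y_slice = haloed_slice(y_start, y_step, dset.sizes["y"])
# and inside the mapped function, keep only centers whose x/y fall within the
# original (un-haloed) chunk, so neighboring tasks do not emit the same window twice.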

I hope this makes the problem clearer. And thanks for working with me a bit here!


And in an attempt to avoid running np.where over the stacked dataset, I modified the window selection to determine the corresponding z-index directly from the non-nan locations in the unstacked data:

def extract_write_chip_dset(
    x_slice: slice,
    y_slice: slice,
    width: int,
    target_name: str,
):
    if width % 2 != 1:
        raise ValueError("Width must be odd for non-nan to be at the center.")

    dset = xr.open_zarr(target_path, group='example_mosaic')
    subdset = dset.sel(x=x_slice, y=y_slice)

    # temporarily assign lat and lon to variables indexed by x or y
    subdset = subdset.assign(
        {"lat2": ("y", subdset.lat.data), "lon2": ("x", subdset.lon.data)}
    )

    # use a rolling object to build a windowed view over the entire spatial mosaic
    rolling_obj = subdset.rolling({"x": width, "y": width})
    windowed_dset = rolling_obj.construct(window_dim={"x": "x2", "y": "y2"})

    # stack over x and y in order to select windows with non-nan centers
    stacked_dset = windowed_dset.stack(z=("time", "x", "y"))

    # assuming we stack with (time, x, y), determine the z index of non-nans
    # from the unstacked data (equivalent to np.ravel_multi_index)
    non_nan_idx = np.where(~np.isnan(subdset[target_name]))
    blah = np.stack(non_nan_idx).T
    x_dims = subdset.dims['x']
    y_dims = subdset.dims['y']
    z_idx = blah[:, 0] * x_dims * y_dims + blah[:, 1] * y_dims + blah[:, 2]
    chips = stacked_dset.isel(z=z_idx)

    return chips.mean().compute()

but I am still getting killed workers, this time without any pointer to code. I call .mean().compute() to make sure I can actually bring those chips into memory, which will have to happen for the write to zarr.

np.where will compute on the whole subdset?

Indeed. If compute has not been called, it will be triggered, but we are only doing this over a single variable, which corresponds to a ~256 MB chunk.

So there is the memory used in the stacking plus the memory used by the data itself… how many of these end up on a worker at once? And with the coordinate calculation on top, too.

It fails when only one worker has one task. The stack and window might still be a view, but I’m afraid there may not be a view into the coords or dims. If at least one of those gets copied as float64, that would be 15 × 15 × 8000² × 8 bytes ≈ 115 GB for 8000 × 8000 chunks with 15 × 15 windows.

I’ve tried to do the same thing with a smaller-chunked dataset, but there are many tasks and it takes too long to finish even on a subset. My guess is that there are too many tasks and all the workers end up exchanging data, rather than each chunk being loaded on the worker its task was sent to.

Yes, we might need a hardcore Dask expert to say whether this is possible at all without much bigger workers.

@dcherian @rabernat do you know if there is any hope for me here? Is there another approach to extracting these windows that I am not thinking about?

This workflow is actually strikingly similar to the one we used for this paper:

Our code is online here: GitHub - ocean-transport/surface_currents_ml. That work used “stencils” (equivalent to your “chips”) of 2x2, 3x3, and 4x4 for training models at each point. And we also had to drop the NaN points (which in our case corresponded to land).

We experimented with workflows that used xbatcher. We used the input_overlap feature of xbatcher to achieve the sliding windows. However, I don’t think we ended up using that for the final workflow.
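
For reference, a rough sketch of what that might look like with xbatcher's BatchGenerator (made-up dataset; parameter names as I recall them from the xbatcher docs, so treat this as an approximation):

import numpy as np
import xarray as xr
import xbatcher

# made-up example dataset; input_overlap = width - 1 gives stride-1 sliding windows
ds = xr.Dataset({"target": (("x", "y"), np.random.rand(64, 64))})
width = 5

bgen = xbatcher.BatchGenerator(
    ds,
    input_dims={"x": width, "y": width},
    input_overlap={"x": width - 1, "y": width - 1},
)

for batch in bgen:
    # each batch is an xarray Dataset with x/y of size `width`;
    # windows with a NaN center could be filtered out here
    pass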

This notebook - surface_currents_ml/train_models_stencil_in_space.ipynb at master · ocean-transport/surface_currents_ml · GitHub - shows a way of accomplishing what you are looking for using just reshaping and stacking. However, I don’t think it handles the overlapping stencils.

If your original data are Zarr, you might consider not actually using dask when you open the data. This gives you more control over the dask graph. You might do something like this (warning: untested pseudocode), which constructs a dask array lazily via delayed:

import numpy as np
import xarray as xr
import dask
import dask.array as dsa

ds = xr.open_dataset('data.zarr', engine='zarr', chunks=None)  # don't chunk yet

# get the list of valid points somehow
center_points = np.where(ds.mask.notnull())

# this operates on one DataArray at a time and returns a numpy array
@dask.delayed
def load_chip(da: xr.DataArray, j, i, chip_size=2) -> np.ndarray:
    # a (2 * chip_size + 1)-wide window centered on (j, i)
    chip = da.isel(
        x=slice(i - chip_size, i + chip_size + 1),
        y=slice(j - chip_size, j + chip_size + 1),
    )
    return chip.values  # this triggers loading

all_chips = [
    dsa.from_delayed(load_chip(ds["variable"], j, i), (5, 5), dtype=ds["variable"].dtype)
    for j, i in zip(*center_points)
]

big_array = dsa.stack(all_chips)

Thanks a ton for the tips, @rabernat! It is greatly appreciated. I will take a look at those examples and try the snippet of code to see what happens.

@rabernat I ended up not taking the from_delayed approach, since dask kept getting upset about the large number of tasks. I might revisit your suggested approach, but for now I chunked the spatial mosaic xarray dset over x and y and passed each chunk to roughly this function:

def extract_write_chunks(x_slice, y_slice, ...):
    """..."""
    dset = xr.open_zarr(store=store, group=group)
    subset = dset.sel(x=x_slice, y=y_slice)

    # (time, x, y) indices of non-nan target pixels in this chunk
    non_nan_idx = np.where(~np.isnan(subset[target_name]))
    indexes = np.stack(non_nan_idx).T

    # convert once to plain numpy; pw is the window half-width
    da = subset.to_array()
    array = np.asarray(da)
    da_lats = np.asarray(da.lat)
    da_lons = np.asarray(da.lon)
    da_times = np.asarray(da.times)

    non_null = []
    lats = []
    lons = []
    times = []
    for t, x, y in indexes:
        chip = array[:, t, x - pw : x + pw + 1, y - pw : y + pw + 1]
        # keep only full-size windows (drop windows clipped at the chunk boundary)
        if np.sum(chip.shape[-2:]) == ((pw * 2 + 1) * 2):
            lats.append(da_lats[y - pw : y + pw + 1])
            lons.append(da_lons[x - pw : x + pw + 1])
            times.append(da_times[t])
            non_null.append(chip)

    # write out chips to zarr

I found that indexing and slicing out all the chips was much, much faster after converting to a DataArray and then to a plain numpy array, rather than using .sel or .isel.
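
For illustration, the two access patterns being compared look roughly like this (dummy data and made-up sizes):

import numpy as np
import xarray as xr

# dummy stand-in for one chunk of the mosaic
subset = xr.Dataset(
    {v: (("time", "x", "y"), np.random.rand(2, 512, 512).astype("float32")) for v in "abc"}
)
pw, t, x, y = 3, 0, 100, 200

# (a) per-chip selection through xarray indexing
chip_a = subset.isel(time=t, x=slice(x - pw, x + pw + 1), y=slice(y - pw, y + pw + 1))

# (b) convert once to a plain numpy array, then use raw slicing per chip
array = np.asarray(subset.to_array())               # dims: (variable, time, x, y)
chip_b = array[:, t, x - pw : x + pw + 1, y - pw : y + pw + 1]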

We have successfully built 1 TB chip datasets in about eight minutes on the cluster. Feels pretty good!

I have also successfully implemented ops to support stratified splits, spatially balanced datasets, etc.

Somewhat-related:

Now I am at the point of needing to figure out how the heck to serve this data to a TensorFlow model on a single worker (with 1-8 GPUs). That has prompted my post over on the Dask Discourse about building tf.data.Datasets from dask arrays or delayed objects. I am currently building TFRecords with dask, writing them to GCS, and using the tf.data API to read and load them, but it is not as quick as I would like. I am hoping to offload all of the work to other machines/workers and have the GPU machine focus only on receiving transferred data, decoding it, and loading it onto the GPU. I looked at some of the TensorFlow models you trained above, but those datasets didn’t look large enough to warrant this. I need to speed this up, with the ultimate hope of bootstrapping the training runs.
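
One pattern that might be worth trying (untested sketch; the chip array shape and chunking below are made up): iterate over the dask blocks of the chip array in a Python generator and wrap it with tf.data.Dataset.from_generator, so the cluster does the reading and the GPU host only receives numpy blocks.

import dask.array as dsa
import tensorflow as tf

# hypothetical: the chip dataset re-opened as an (n, nvar, w, w) dask array
chips = dsa.random.random((100_000, 6, 15, 15), chunks=(4_096, 6, 15, 15)).astype("float32")

def block_generator():
    # compute one dask block at a time (on the cluster, if a distributed
    # client is active) and hand the resulting numpy array to tf.data
    for block in chips.to_delayed().ravel():
        yield block.compute()

tf_ds = (
    tf.data.Dataset.from_generator(
        block_generator,
        output_signature=tf.TensorSpec(shape=(None, 6, 15, 15), dtype=tf.float32),
    )
    .unbatch()
    .batch(256)
    .prefetch(tf.data.AUTOTUNE)
)

With a distributed client registered, block.compute() runs on the workers and only the resulting numpy blocks are shipped to the GPU machine.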

Has anyone tried to integrate dask with tensorflow datasets? Has anyone seen any working approaches here?

@Leonard_Strnad - you may be interested in the Xbatcher project. Currently it provides a batch generator API and some prototype ML data loaders (for PyTorch and TensorFlow). On our roadmap is tuning the data loaders so they play nicely with Dask when feeding data to GPU-backed models.

cc @maxrjones and @weiji14 who have been using/developing xbatcher lately.
