Efficiently slicing random windows for reduced xarray dataset

Hey, y’all! I have had some great success with cloud-based, zarr-backed xarray datasets. Thanks for all the great work.

I am soliciting some advice on how to efficiently create one xarray dataset from a much larger one. Specifically, I have a “spatial mosaic” dataset with sparse variables. I would like to create a chip training dataset by slicing out a window around every point that is non-nan for a specific variable and writing the result to a cloud-based, zarr-backed xarray dataset.

The idea is to find all non-nan points in a specific variable, slice a window around each point that includes all other variables, reduce this set of windows into a dataset of reasonable size, and then write it to a target dataset.

Here is the general idea in code:

import dask
import dask.array
import numpy as np
import xarray as xr

target_features = ['a','b']
times = [2019]
final_height = 10000
final_width = 10000
xy_chunksize = 1000
dtype = np.float32

##### Create a dummy dataset to demonstrate the source data
# dummies for a
a_data = dask.array.zeros(
    (len(times), final_width, final_height),
    chunks=[1, xy_chunksize, xy_chunksize],
    dtype=dtype,
)

# random nan for b
b = np.arange(final_height * final_width).astype(np.float32)
idx = np.random.randint(0, final_width * final_height, size=int(1e6))
b[idx] = np.nan
# tile (not repeat) so that each time step gets a full copy of the 2D field
b_data = np.tile(b, len(times)).reshape((-1, final_width, final_height))


spatial_dset = xr.Dataset(
    data_vars={
        'a': (["time", "x", "y"], a_data), 
        'b':(["time", "x", "y"], b_data)
    },
    coords={
        "lon": (["x"], np.arange(final_width)),
        "lat": (["y"], np.arange(final_height)),
        "times": (("time"), times),
    },
)

# create the expected chip datasets
chip_width = 3
n = int(np.isnan(spatial_dset.b).sum())
data_variables = list(spatial_dset.data_vars)

# establish the shape of each feature
dummies = dask.array.zeros(
    (n, chip_width, chip_width), chunks=[n, chip_width, chip_width], dtype=dtype
)

# establish the shape of the lat and lon coordinates
lon_dummies = dask.array.zeros((n, chip_width), dtype="float64")
lat_dummies = dask.array.zeros((n, chip_width), dtype="float64")
time_dummies = dask.array.zeros((n), dtype="<M8[ns]")
data_vars = {
    nm: (["i", "x", "y"], dummies) for nm in list(data_variables)
}
chip_target = xr.Dataset(
    data_vars=data_vars,
    coords={
        "lon": (["i", "x"], lon_dummies),
        "lat": (["i", "y"], lat_dummies),
        "i": range(n),
        "time": ("i", time_dummies),
    },
)

#write this target
chip_target.to_zarr("gs://...", compute=False)

This example is using nan as the sparse points of interest, but my actual use case will be non-nan. Anyways, since the sparse data is uniform over the entire spatial dataset, I could leverage something like map_blocks with a function like


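def slice_non_nan(block_dset):
    # chip_size (half-width of a chip) and synchronizer are assumed to be
    # defined elsewhere
    # positional (time, x, y) indices of the non-nan points in this block
    idx = np.where(~np.isnan(block_dset['b']))
    if len(idx[0]) > 0:
        chips = []
        for _, x, y in zip(*idx):
            # isel, because x and y are integer positions rather than labels
            chip = block_dset.isel(
                x=slice(x - chip_size, x + chip_size),
                y=slice(y - chip_size, y + chip_size),
            )
            chips.append(chip)
        # stack the chips along a new sample dimension
        block_chip_dset = xr.concat(chips, dim='i')
        block_chip_dset.to_zarr(
            'gs://path_to_target',
            append_dim='i',
            synchronizer=synchronizer,
        )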

and simply call

spatial_dset.map_blocks(slice_non_nan).compute()

Is there a more reasonable approach for this? My actual spatial_dset has dims frozen({'time': 2, 'x': 48248, 'y': 48050}) with 42 float32 variables, which is about 770gb. The feature of interest has about 10M non-nan examples that are uniformly spread out. From what I understand, the map_blocks method would be a hack since it is supposed to return a dataset. Any help on making this more efficient or links to relevant existing functions would be greatly appreciated!

Another way to do this is something like

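width = 3
target_name = "b"
# subset is assumed here to be a single-time slice of the dataset,
# so that arr has dims (variable, x, y)
arr = subset.to_array()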

focal_arr = np.lib.stride_tricks.sliding_window_view(
    arr, (width, width), axis=(1, 2))

target_band = np.where(arr.coords['variable'] == target_name)[0]
idx = np.where(~np.isnan(focal_arr[target_band, :, :, width // 2, width // 2]))

chips = focal_arr[:, idx[1], idx[2] ,:,:]

but I am afraid this will not scale well as np.asarray is called on arr and it would have to be brought into memory. I have thought about doing this over chunks like

import dask
from itertools import product

x_starts = np.cumsum([0] + list(dset.chunks['x'])[:-1])
x_start_step = zip(x_starts, dset.chunksizes['x'])
y_starts = np.cumsum([0] + list(dset.chunks['y'])[:-1])
y_start_step = zip(y_starts, dset.chunksizes['y'])

futures = []
for (x_start, x_step), (y_start, y_step) in list(product(x_start_step, y_start_step))[:2]:
    x_slice = slice(x_start, x_start + x_step)
    y_slice = slice(y_start, y_start + y_step)
    futures.append(dask.delayed(do_the_sliding_window_thing)(x_slice, y_slice))   

but tasks become a little chaotic and things do not seem to finish.

I haven’t read your question in detail, but .rolling wraps np.sliding_window_view and will dispatch to the dask version when necessary.

So you could call rolling.construct to construct a memory-efficient dask-aware view and then apply your function on that.
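
To make the construct suggestion concrete, here is a minimal sketch on a tiny in-memory dataset (the variable name b and the 4x5 grid are just placeholders). It assumes center=True is wanted, so that each pixel sits in the middle of its window; with the default center=False the window trails the pixel instead.

```python
import numpy as np
import xarray as xr

# Small stand-in for the mosaic: one variable on a 4x5 grid.
dset = xr.Dataset(
    {"b": (("x", "y"), np.arange(20, dtype="float32").reshape(4, 5))}
)

width = 3
# Build a windowed view: every (x, y) pixel gets an (x2, y2) window.
# Windows that reach past the edge of the grid are padded with nan.
windowed = dset.rolling({"x": width, "y": width}, center=True).construct(
    {"x": "x2", "y": "y2"}
)

print(dict(windowed["b"].sizes))

# Window around pixel (1, 1): the 3x3 block covering x=0..2, y=0..2.
print(windowed["b"].isel(x=1, y=1).values)
```

On a dask-backed dataset the same call stays lazy, so nothing is loaded until the windows are actually indexed or computed.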

Great! I missed the construct bit! I am seeing something closer to what I am looking for. I have yet to try and scale it out, but this is roughly what I am landing on

# center=True so that the window at (x, y) is centered on that pixel
rolling_obj = dset.rolling({'x': 3, 'y': 3}, center=True)
windowed_dset = rolling_obj.construct(window_dim={'x': 'x2', "y": 'y2'})

idx = np.where(~np.isnan(dset['b']))

# get a small example
t_slice = idx[0][:5]
x_slice = idx[1][:5]
y_slice = idx[2][:5]
windowed_dset.sel(time=t_slice, x=x_slice, y=y_slice).compute()

It looks like this also won’t work because the only way to slice out the nonnan values is by first stacking the x and y dims, which has a high memory requirement due to the expanded lat and lon indexes, which are float64.

The slicing example directly above performs orthogonal rather than pointwise (fancy) indexing, so it does not reshape the data into a list of windows, and I am not sure how to do that without stacking first. Any thoughts @dcherian ?

I have also tried using dask.array.overlap to explicitly set up delayed objects and doing this in dask array land, but that also requires too much memory. When I open the zarr xarray dataset and call to_array(), that is when I see this memory issue.

Does anyone know of a similar overlap function? Something like

dset = xr.open_zarr()
overlapped_dset = dset.overlap(
    darr, 
    depth=(0, chip_size, chip_size), 
    boundary="nearest"
)

where overlapped_dset has a new axis that enumerates the blocks, with nans padding any block that is smaller than the full block size? This might be pretty specific lol

This is another case where Dataset.to_delayed would be useful. Alas we do not have that yet.

So after rolling.construct you’ll have to use a map_blocks to apply a function that does BOTH the indexing and writing IIUC.

  • If it’s more convenient to have your function receive xarray objects, use xarray.map_blocks
  • If it’s more convenient to receive a plain dask array, use dask.array.map_blocks and apply it with xarray.apply_ufunc(..., dask="allowed").

Just return some dummy data to keep the functions happy.

You could also convert the dask arrays to delayed using windowed_dset["b"].data.to_delayed() and then pursue that approach.
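
A minimal sketch of the xarray.map_blocks variant, under stated assumptions: the function name index_and_write and the toy data are hypothetical, and the real indexing and writing would replace the placeholder side effect (appending to a Python list, which only works here because the synchronous scheduler runs everything in one process). Passing template tells xarray what the dummy output looks like, so the function is not first run on empty mock data to infer it.

```python
import numpy as np
import xarray as xr

seen = []  # placeholder for the real side effect (e.g. writing chips to zarr)


def index_and_write(block):
    # block is an in-memory xarray.Dataset covering a single chunk
    n_non_nan = int(block["b"].notnull().sum())
    seen.append(n_non_nan)  # the real indexing and writing would go here
    return block  # dummy return value to keep map_blocks happy


# Toy target with two non-nan points, split into four 2x2 chunks.
target = np.full((4, 4), np.nan)
target[0, 0] = 1.0
target[3, 2] = 1.0
dset = xr.Dataset({"b": (("x", "y"), target)}).chunk({"x": 2, "y": 2})

out = xr.map_blocks(index_and_write, dset, template=dset)
# Nothing runs until compute; the function is then called once per chunk.
out.compute(scheduler="synchronous")
print(sorted(seen))
```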

Thanks for some tips, @dcherian.

I am still struggling with this a bit. I keep finding myself needing to stack the windowed dataset in order to slice out the windows with non-nan centers, but that causes killed worker issues since the index for that stacked windowed dataset is huge.

The other option I want to return to is to simply slice out the window for every non-nan, but there are 400k non-nan points, which would result in 400k dset.sel() calls over a chunk. I could probably simplify this a bit by using smaller chunks, but that means more nans for non-nans near the border of the chunks…

Just wanted to say that the .rolling code here might fit into the scope of xbatcher (pangeo-data/xbatcher on GitHub: batch generation from xarray datasets). Cc @meghanrjones and @jhamman.

You are stacking for ML shape reasons, Leonard?

I am stacking in order to drop windows that have nans in their center. Yes, I essentially want all mxm windows with non-nan centers for training. Eventually, we will not restrict to centers and take any window with a non-nan anywhere for a sparse regression training scheme.

I can’t think of a way to extract windows with non-nan centers any other way. It’s easy to find the indices of all non-nans for the variable of interest, but given those indices and a window size, I do not know how to slice out an (n, m, m) xarray dataset (n would be the number of non-nan pixels and m would be the window width).
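
One possible way to get from the non-nan indices to an (n, m, m) dataset without stacking is xarray's vectorized (pointwise) indexing: when the indexers passed to isel are DataArrays that share a new dimension, the result is indexed point by point along that dimension. A minimal sketch, assuming a centered rolling window and small made-up data (the names a, b, and i are placeholders):

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
data = rng.random((2, 6, 6)).astype("float32")

# Sparse target: nan everywhere except two points.
target = np.full((2, 6, 6), np.nan, dtype="float32")
target[0, 2, 3] = 1.0
target[1, 4, 1] = 1.0

dims = ("time", "x", "y")
dset = xr.Dataset({"a": (dims, data), "b": (dims, target)})

width = 3
windowed = dset.rolling({"x": width, "y": width}, center=True).construct(
    {"x": "x2", "y": "y2"}
)

# Positional indices of the non-nan target pixels.
t_idx, x_idx, y_idx = np.where(~np.isnan(dset["b"].values))

# Indexers sharing the new dimension "i" are matched point by point,
# so this selects one window per non-nan pixel instead of an outer product.
chips = windowed.isel(
    time=xr.DataArray(t_idx, dims="i"),
    x=xr.DataArray(x_idx, dims="i"),
    y=xr.DataArray(y_idx, dims="i"),
)

print(dict(chips["a"].sizes))
```

Each entry along i is then one m x m window for every variable, with the non-nan target pixel at position (width // 2, width // 2).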

Once I write this to zarr/gcs, I have a training dataset for the team to start model development.

I haven’t seen this package yet. It looks extremely helpful.

Support for sparse targets like this “chip” example (given an mxm window, make a single-pixel prediction at the center, which is very common in biomass/canopy height regression) and splitting for training/test/validation (potentially based on masks) would be oh so nice.

Are you stacking inside the function being mapped? This is key. Stack is a reshape which can be very very inefficient in parallel settings.

Indeed I am, but I have chunk sizes of (2, 8192, 8192) over 42 variables, which gives a stacked (z) index of length 134M.
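
As a rough back-of-the-envelope check (assuming float32 data, the width of 7 used in the function below, and that the stacked windows end up materialized rather than staying a lazy view), the size of one chunk after windowing can be estimated like this:

```python
# Chunk shape and dataset layout quoted in this thread.
n_time, n_x, n_y = 2, 8192, 8192
n_vars = 42
width = 7
itemsize = 4  # bytes per float32 value

# One window per (time, x, y) position, i.e. the length of the stacked z index.
n_windows = n_time * n_x * n_y

# Every window holds width * width values.
bytes_per_var = n_windows * width * width * itemsize
gib_per_var = bytes_per_var / 2**30

print(n_windows)             # 134217728
print(gib_per_var)           # 24.5
print(gib_per_var * n_vars)  # 1029.0
```

Under those assumptions a single variable already exceeds the memory of a 16 GB worker, which would be consistent with workers being killed.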

I think that the success of this method will require smaller chunks, but as I stated above, that means more nans in windows near each chunk’s boundary…

This is roughly the function I am mapping:

width = 7
target_name = 'b'

def extract_chips(subdset):
    
    rolling_obj = subdset.rolling({"x": width, "y": width})
    windowed_dset = rolling_obj.construct(window_dim={"x": "x2", "y": "y2"})

    # stack over x and y in order to select windows with non-nan centers
    center = width // 2
    stacked_dset = windowed_dset.stack(z=("time", "x", "y"))

    # this seems to kill workers
    nonnan_idx = np.where(
        ~np.isnan(stacked_dset[target_name].sel(x2=center, y2=center))
    )[-1]

    # select the windows with non-nan centers!
    chips = stacked_dset.isel(z=nonnan_idx)
    
    # write out chips to zarr-backed xarray on gcs

I might try to map a pre-computed non-nan index before the windowed_dset:

idx = np.where(~np.isnan(subdset[target_name]))
flattened_idx = somehow_flatten_idx(idx)

and avoid the

    nonnan_idx = np.where(
        ~np.isnan(stacked_dset[target_name].sel(x2=center, y2=center))
    )[-1]

operation.
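
A minimal numpy-only sketch of what somehow_flatten_idx could look like, assuming the stacked z dimension follows row-major (C) order over (time, x, y), which is what stacking in that dimension order should produce. np.ravel_multi_index converts the per-dimension indices into flat positions:

```python
import numpy as np

# Toy (time, x, y) target with two non-nan points.
target = np.full((2, 3, 4), np.nan, dtype="float32")
target[0, 1, 2] = 1.0
target[1, 2, 3] = 1.0

# Per-dimension indices of the non-nan points.
idx = np.where(~np.isnan(target))

# Flat positions in C order: t * (n_x * n_y) + x * n_y + y
z_idx = np.ravel_multi_index(idx, target.shape)
print(z_idx)  # [ 6 23]

# The flat positions pick out the same points in the flattened array.
print(target.reshape(-1)[z_idx])
```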

Sorry if I’m being dense, but are you applying extract_chips with xarray.map_blocks? A full reproducible minimal example with dummy data would be useful.

Not at all. Yes, I am doing that. Well, I tried that and was struggling to see what was going on during the map so I have been doing something like

import dask
from itertools import product

x_starts = np.cumsum([0] + list(dset.chunks['x'])[:-1])
x_start_step = zip(x_starts, dset.chunksizes['x'])
y_starts = np.cumsum([0] + list(dset.chunks['y'])[:-1])
y_start_step = zip(y_starts, dset.chunksizes['y'])

futures = []
for (x_start, x_step), (y_start, y_step) in list(product(x_start_step, y_start_step))[:2]:
    x_slice = slice(x_start, x_start + x_step)
    y_slice = slice(y_start, y_start + y_step)
    futures.append(dask.delayed(extract_chips)(x_slice, y_slice))   

where extract_chips does the slicing.

I will put together a full, minimal example and post today or tomorrow.

@dcherian Here is a fully reproducible example with a dataset that has roughly the same number of non-nans in the target, has the same height and width, but fewer variables than my actual dataset. I write it out to zarr to prevent additional tasks, to replicate my exact process, and to isolate the compute issues in the chip extraction.

My dask cluster has 36 workers, each with 16 GB of memory and 4 cores, using HelmCluster.

import dask
import dask.array
import numpy as np
import xarray as xr


target_path = "gs://somewhere/out/there"
features = ['a','b','c','d','e','f']
times = [2018, 2019]
final_height = 40000
final_width = 40000
xy_chunksize = 8192
n_non_nans = 10e6
p_non_nan = 1 - (n_non_nans / (final_height * final_width))

dummy_data = dask.array.random.random(
    (len(times), final_width, final_height),
    chunks=[1, xy_chunksize, xy_chunksize],
)
dummy_data = dummy_data.astype('float32')

spatial_dset = xr.Dataset(
    data_vars={ ftr: (["time", "x", "y"], dummy_data) for ftr in features},
    coords={
        "lon": (["x"], np.arange(final_width)),
        "lat": (["y"], np.arange(final_height)),
        "times": (("time"), times),
    },
)

# fake sparse non-nan value data in target
spatial_dset = spatial_dset.assign(
    {'target': xr.where(spatial_dset['f'] < p_non_nan, np.nan, True)}
)

# write to zarr somewhere
spatial_dset.to_zarr(target_path, group='example_mosaic')

# inspect and check
rt_spatial_dset = xr.open_zarr(target_path, group='example_mosaic')
idx = np.where(~np.isnan(rt_spatial_dset['target']))
assert np.abs(len(idx[1]) / len(times) - n_non_nans) < 5000


def extract_write_chip_dset(
    x_slice: slice,
    y_slice: slice,
    width: int,
    target_name: str,
):
    if width % 2 != 1:
        raise ValueError("Width must be odd for non-nan to be at the center.")

    dset = xr.open_zarr(target_path, group='example_mosaic')
    subdset = dset.sel(x=x_slice, y=y_slice)

    # temporarily assign lat and lon to variables indexed by x or y
    subdset = subdset.assign(
        {"lat2": ("y", subdset.lat.data), "lon2": ("x", subdset.lon.data)}
    )

    # use a rolling object to build a windowed view over the entire spatial mosaic
    rolling_obj = subdset.rolling({"x": width, "y": width})
    windowed_dset = rolling_obj.construct(window_dim={"x": "x2", "y": "y2"})


    # stack over x and y in order to select windows with non-nan centers
    center = width // 2
    stacked_dset = windowed_dset.stack(z=("time", "x", "y"))
    #return stacked_dset.dims
    nonnan_idx = np.where(
        ~np.isnan(stacked_dset[target_name].sel(x2=center, y2=center))
    )[-1]
    chips = stacked_dset.isel(z=nonnan_idx)
    
    # should write to a zarr chip dataset here, but I return the dims to
    # first inspect that we can extract chips in the first place, something
    # that map_blocks does not allow us to do.

    return chips.dims


import dask
from itertools import product

dset = rt_spatial_dset
dset = dset.unify_chunks()
x_starts = np.cumsum([0] + list(dset.chunks['x'])[:-1])
x_start_step = zip(x_starts, dset.chunksizes['x'])
y_starts = np.cumsum([0] + list(dset.chunks['y'])[:-1])
y_start_step = zip(y_starts, dset.chunksizes['y'])

futures = []
for (x_start, x_step), (y_start, y_step) in list(product(x_start_step, y_start_step))[:5]:
    x_slice = slice(x_start, x_start + x_step)
    y_slice = slice(y_start, y_start + y_step)
    futures.append(dask.delayed(extract_write_chip_dset)(x_slice, y_slice, 3, 'target'))  


results = dask.compute(*futures)

Before I try to use map_blocks, I am (as I have alluded to above) manually creating delayed tasks over chunks explicitly in order to inspect the chips first. The issue with map_blocks is that it is a bit harder to inspect what is going on in the function I am mapping. I also have gotten into this habit of reading and slicing a dataset in the function, which is why I pass slices instead of passing a sliced dataset.

Anyways, I get an issue here

...
     25 stacked_dset = windowed_dset.stack(z=("time", "x", "y"))
     26 #return stacked_dset.dims
---> 27 nonnan_idx = np.where(
     28     ~np.isnan(stacked_dset[target_name].sel(x2=center, y2=center))
     29 )[-1]
     30 chips = stacked_dset.isel(z=nonnan_idx)
     32 return chips.dims
...
KilledWorker: ("('getitem-overlap-reshape-transpose-invert-f9ec729b6f057728d617ffb36be4dc89', 1)", <WorkerState 'tcp://10.100.15.4:45683', status: closed, memory: 0, processing: 3>)

The other issue I am encountering is that when things do work out (with smaller widths and chunksizes), I get more nan windows near the border of the chunks. So, I lose data, which is not ideal.
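
One way to avoid losing windows at chunk borders is to read each chunk together with a halo of width // 2 pixels from its neighbours and only keep windows centered inside the original chunk. A minimal numpy-only sketch of computing such padded slices along one dimension (the helper name halo_slices is made up for illustration):

```python
import numpy as np


def halo_slices(chunk_sizes, halo, total):
    """Yield (padded, core) slices for each chunk along one dimension.

    padded selects the chunk plus up to `halo` pixels on each side,
    clipped to [0, total). core locates the original chunk inside
    the padded block.
    """
    starts = np.cumsum([0] + list(chunk_sizes)[:-1])
    for start, size in zip(starts, chunk_sizes):
        start, size = int(start), int(size)
        stop = start + size
        padded = slice(max(start - halo, 0), min(stop + halo, total))
        offset = start - padded.start
        core = slice(offset, offset + size)
        yield padded, core


# Chunks of 4, 4 and 2 pixels with a 1-pixel halo (i.e. width = 3).
for padded, core in halo_slices((4, 4, 2), halo=1, total=10):
    print(padded, core)
```

The padded slices could replace x_slice and y_slice in the loop above, with the core slices used afterwards to drop windows centered in the halo so that no chip is extracted twice.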

I hope this makes the problem clearer. And thanks for working with me a bit here!

And as an attempt to prevent running np.where over the stacked dataset, I modified selecting windows with non-nan by determining what the corresponding z-index would be:

def extract_write_chip_dset(
    x_slice: slice,
    y_slice: slice,
    width: int,
    target_name: str,
):
    if width % 2 != 1:
        raise ValueError("Width must be odd for non-nan to be at the center.")

    dset = xr.open_zarr(target_path, group='example_mosaic')
    subdset = dset.sel(x=x_slice, y=y_slice)

    # temporarily assign lat and lon to variables indexed by x or y
    subdset = subdset.assign(
        {"lat2": ("y", subdset.lat.data), "lon2": ("x", subdset.lon.data)}
    )

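    # use a rolling object to build a windowed view over the entire spatial mosaic
    # center=True so the window at (x, y) is centered on that pixel, which the
    # z-index lookup below relies on (the default window trails the pixel)
    rolling_obj = subdset.rolling({"x": width, "y": width}, center=True)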
    windowed_dset = rolling_obj.construct(window_dim={"x": "x2", "y": "y2"})


    # stack over x and y in order to select windows with non-nan centers
    stacked_dset = windowed_dset.stack(z=("time", "x", "y"))

    # assuming we stack with (time, x, y), determine the z index of non-nans
    non_nan_idx = np.where(~np.isnan(subdset[target_name]))
    t_idx, x_idx, y_idx = non_nan_idx
    x_dims = subdset.sizes['x']
    y_dims = subdset.sizes['y']
    # row-major position of each (time, x, y) point along the stacked z dim
    z_idx = t_idx * x_dims * y_dims + x_idx * y_dims + y_idx
    chips = stacked_dset.isel(z=z_idx)

    return chips.mean().compute()

but I am still getting killed workers, now without any pointer to code. I call .mean().compute() to make sure I can actually bring those chips into memory, which will have to happen for the write to zarr.

np.where will compute on the whole subdset?

Indeed. If compute has not been called it will be triggered, but we are only doing this over a single variable, which corresponds to a 256mb chunk.