Any suggestions for efficiently operating over windows of data?

I apologize in advance if I’m missing something obvious with this problem - I’m having trouble finding resources which give a direct answer to the issues I’m facing.

In short, I have several different zarr datasets hosted on AWS that I need to work with. Taking a small subset of ERA5 as an example:

import xarray as xr
import s3fs, dask
store = s3fs.S3FileSystem(False, <credentials>).get_mapper(<s3://ERA5 bucket>)
data = xr.open_dataset(store, chunks='auto', engine='zarr')
data
>   <xarray.Dataset>
>   Dimensions:    (latitude: 721, longitude: 1440, time: 72)
>   Coordinates:
>     * latitude   (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
>     * longitude  (longitude) float32 -180.0 -179.8 -179.5 ... 179.2 179.5 179.8
>     * time       (time) datetime64[ns] 2015-04-01 ... 2015-04-03T23:00:00
>   Data variables: (12/15)
>       msdwlwrf   (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>       msdwswrf   (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>       skt        (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>       stl1       (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>       stl2       (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>       stl3       (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>       ...         ...
>       swvl3      (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>       swvl4      (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>       t2m        (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>       tp         (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>       u10        (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>       v10        (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
>   Attributes:
>       Conventions:  CF-1.6
>       history:      2022-12-09 04:38:31 GMT by grib_to_netcdf-2.25.1: /opt/ecmw...

My end goal is to operate over batches of samples from this dataset and others for ML applications, where each sample is a window surrounding a given spatial/temporal location. So for example I might have samples shaped (time=10, latitude=7, longitude=7) which represents 10 consecutive time steps with a 7x7 spatial window surrounding a given pixel lat/lon.
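
In isolation, one such sample is just a contiguous slice; for example (the center indices t, i, j below are arbitrary):

t, i, j = 20, 360, 720  # hypothetical sample center
sample = data.isel(
    time      = slice(t - 9, t + 1),   # 10 consecutive steps ending at t
    latitude  = slice(i - 3, i + 4),   # 7 pixels centered on i
    longitude = slice(j - 3, j + 4),   # 7 pixels centered on j
)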

Seemingly easy enough, but I’m having difficulty figuring out how best to organize the operations so that memory usage, disk (network) access, and computation time are all minimized.

One complication is that a window is invalid if too many of its pixels are missing, and a sample is invalid if any window in its time series is invalid. A naïve implementation looks something like this:

# These are somewhat large windows, but I want to trigger
# a memory error here if tasks aren't executed as expected
windows = { 
    'latitude'  : 50,
    'longitude' : 50,
    'time'      : 20,
}
center = {
    'latitude'  : True,
    'longitude' : True,
    'time'      : False,
} 

# Create a view of the multidimensional windows
job = data.rolling(center=center, **windows)
job = job.construct({'latitude':'a', 'longitude':'b', 'time':'c'})

# Stack coordinates since we only care about spatial/temporal windows now
with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    job = job.stack(llt=['latitude', 'longitude', 'time'])

# Now we filter out invalid samples
# int16 here: a 50x50 window can hold up to 2,500 missing pixels, which would overflow int8
dimsum  = lambda x, dims: x.sum(dims, skipna=False, dtype='int16')
invalid = dimsum(job.isnull(), ['a', 'b']) > 10 # Invalid if > 10 missing pixels
valid   = dimsum(invalid, ['c']) == 0 # Series must have 0 invalid windows

# This actually triggers computation, so it won't run beyond this point
job = job.where(valid, drop=True)

# Now just check the first valid sample
job = job.isel({'llt': 0})

Great, except that trying to compute this results in 4+ TiB arrays being allocated in various places. At least some of those allocations make sense; e.g. the original chunk size is (24, 721, 1440), so taking a rolling window over that single chunk would be problematic. But even if I load the data with something like (10, 10, 10) chunking, it still can’t operate over windows of this size, because the array is automatically rechunked to (29, 51, 55, 50, 50, 20).

It seems that, in the minimum-chunk-size case, it should only need to load one sample array of size (50, 50, 20) at a time, check whether it’s valid, and if not, discard it and check the next window (reloading the larger block of original data if necessary, and pulling out the relevant single window). I can’t get it to do even that, though, despite trying to rechunk in various places, reorder operations, reformulate the logic, pass a fake numpy mask to job.where, etc.
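
To spell out the access pattern I’m hoping for, here it is as an explicit loop over a single variable (purely illustrative - the helper name and thresholds are placeholders, and loading one window at a time like this is of course far too slow to use directly):

import numpy as np

def first_valid_sample(arr, wtime=20, wlat=50, wlon=50, max_missing=10):
    # Scan candidate samples one at a time, loading only one (20, 50, 50) block per step
    nt, nlat, nlon = arr.sizes['time'], arr.sizes['latitude'], arr.sizes['longitude']
    for t in range(nt - wtime + 1):
        for i in range(nlat - wlat + 1):
            for j in range(nlon - wlon + 1):
                win = arr.isel(time=slice(t, t + wtime),
                               latitude=slice(i, i + wlat),
                               longitude=slice(j, j + wlon)).compute()
                # Each time step's spatial window may have at most max_missing missing pixels
                missing = np.isnan(win.values).reshape(wtime, -1).sum(axis=1)
                if (missing <= max_missing).all():
                    return win
    return None

sample = first_valid_sample(data['t2m'])  # single variable; the real check would span all variables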

Is there something I’m misunderstanding about how this should be approached?


If it’s useful, here’s a snippet that can be dropped in to replace the S3 loading, so that the above can be run:

import numpy as np

def create_array(lat_size, lon_size, time_size, missing_pct=0.1):
    a = np.random.rand(time_size, lat_size, lon_size).astype('float32')
    i = np.random.randint(0, a.size, int(a.size * missing_pct))
    a.ravel()[i] = np.nan 
    return xr.DataArray(a, coords={
        'time'      : np.arange(time_size),
        'latitude'  : np.arange(lat_size), 
        'longitude' : np.arange(lon_size), 
    }) 

# Create a random dataset with 10 variables, shaped (lat=721, lon=1440, time=72)
data = xr.Dataset({f'var_{i}': create_array(721, 1440, 72) for i in range(10)})
data = data.chunk({'latitude': 10, 'longitude': 10, 'time': 10})

You may find some helpful information in Efficiently slicing random windows for reduced xarray dataset, which also considered efficiently selecting only “valid” examples.

We’re working on the xbatcher library for this type of use case - our roadmap for the next few months focuses on improving efficiency for the data loaders and we’re considering support for filtering examples.
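
For reference, generating fixed-size samples with xbatcher currently looks roughly like this (no validity filtering yet, and the exact signature may change as the API evolves):

import xbatcher

# Yield (time=10, latitude=7, longitude=7) samples from the dataset above
bgen = xbatcher.BatchGenerator(
    data,
    input_dims={'time': 10, 'latitude': 7, 'longitude': 7},
)
for sample in bgen:
    print(sample.dims)  # each sample is an xr.Dataset with dims (time: 10, latitude: 7, longitude: 7)
    break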


TL;DR of Efficiently slicing random windows for reduced xarray dataset, and how I ended up doing this:

  • build a dataset with dims (“i”, “y”, “x”), where “i” indexes the valid elements
  • drop down to job.data.to_delayed().ravel(), but be careful to apply da.overlap.overlap first if you are worried about “chips” near the boundary of chunks
  • do a first pass that only counts the valid “chips”, touching just the data needed to determine validity; this takes ~30 s on a cluster over a single band, using ~np.isnan() as our valid-pixel test
  • use those counts to provide deterministic shapes to a second pass that extracts the data: apply a delayed function over the delayed chunks of the xr dataset and iterate over each valid chip in plain numpy code (we literally slice the windows/chips out one by one)
  • reduce those delayed objects back to dask arrays with da.from_delayed, providing the shapes, plus da.concatenate if needed
  • the result is a lazy “chip” dask array you can use to create an xarray DataArray/Dataset (see the sketch just below)
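
Here’s a minimal 2-D sketch of that recipe, assuming a single band and a simple “no missing pixels” validity rule (the chip size, overlap depth, and the random test array are placeholders for your real data and thresholds):

import dask
import dask.array as da
import numpy as np

CHIP = 15         # chip edge length
HALO = CHIP // 2  # overlap depth, so chips centered near chunk edges aren't lost

# Stand-in for one band of the real dataset, e.g. ds['var_0'].data
band = da.random.random((2000, 2000), chunks=(500, 500))
band = da.where(band < 0.05, np.nan, band)

# Pad every chunk with data from its neighbours (NaN beyond the array edge)
padded = da.overlap.overlap(band, depth=HALO, boundary=np.nan)

def count_valid(block):
    # First pass: how many fully valid CHIP x CHIP windows does this block hold?
    view = np.lib.stride_tricks.sliding_window_view(block, (CHIP, CHIP))
    return int((~np.isnan(view).any(axis=(-2, -1))).sum())

def extract_valid(block):
    # Second pass: slice the valid windows out, shape (n_valid, CHIP, CHIP)
    view = np.lib.stride_tricks.sliding_window_view(block, (CHIP, CHIP))
    return view[~np.isnan(view).any(axis=(-2, -1))]

blocks = padded.to_delayed().ravel()
counts = dask.compute(*[dask.delayed(count_valid)(b) for b in blocks])

# The known counts give da.from_delayed deterministic shapes
chips = da.concatenate([
    da.from_delayed(dask.delayed(extract_valid)(b), shape=(n, CHIP, CHIP), dtype=band.dtype)
    for b, n in zip(blocks, counts) if n > 0
])

chips is now a lazy (n_total, CHIP, CHIP) array you can wrap in an xr.DataArray with dims (“i”, “y”, “x”). Since 2*HALO == CHIP - 1, each padded chunk yields exactly one candidate window per original pixel, so nothing is double-counted across chunks.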

This can take significant resources when there are many valid pixels and you actually call compute. That is a good stopping point for us, so we trigger computation by persisting to zarr for the follow-up ML operations; it keeps ML iteration fast and avoids re-referencing all of the tasks required to build the dataset. Our largest chip dataset is roughly (20e6, 64, 15, 15), extracted from (2 years, 64 variables, 50000 x, 50000 y), i.e. 20 million valid chips with 64 variables and 15x15 windows. They are sparse! It takes about 5 minutes with 32 workers, 4 threads, and 16 GB. We then build tfrecords, flatten to xgboost with dask, split + balance on coords, etc.
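
In code, continuing from the sketch above, that persistence step is just something like (the path and dim names are placeholders):

import xarray as xr

chip_ds = xr.DataArray(chips, dims=('i', 'y', 'x'), name='chips').to_dataset()
chip_ds.to_zarr('chips.zarr', mode='w')      # triggers the computation once
chips_on_disk = xr.open_zarr('chips.zarr')   # cheap to reopen for ML iteration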

Thanks very much, both of you! @Leonard_Strnad, the approach you describe is pretty much exactly the one I’ve arrived at since my original post. I’ve pasted the code below, with a few notes:

  • This is indeed using overlap, and so there aren’t any missing valid windows along chunk edges

  • The memory required is fairly minimal until the windows are actually pulled, at which point the full (chunk x window) numpy view has to be instantiated, e.g. (10, 20, 20, 21, 51, 51). This would be fixable by building up the window gradually rather than calling reshape on the view, but that unfortunately requires ‘take_along_axis’, which isn’t yet implemented in dask. If you (or anyone else) have any other suggestions for workarounds, I’d love to hear them.

  • I’m seeing pretty much the same performance you describe on my desktop at the current sizes: ~30 seconds to find all valid windows, and roughly constant time to extract the windows from there. There’s a direct trade-off between memory footprint and calculation time, controlled by the initial data chunk sizes.

from dask.diagnostics import ProgressBar

import bottleneck as bn
import dask.array as da 
import xarray as xr
import numpy as np
import dask
import os, psutil


def get_memory_usage():
    process = psutil.Process(os.getpid())
    return f'~{process.memory_info().rss/1e9:,.2f}Gb'


def create_array(
    dimensions  : dict, 
    missing_pct : float = 0.1, 
    random      : bool  = False,
    random_seed : int   = None,
) -> xr.DataArray:
    """ Create a DataArray

    Parameters
    ----------
    dimensions  : dict
        Dict mapping {coordinate: shape} for the DataArray.
    missing_pct : float
        Percentage of values that should be NaN.
    random      : bool
        Whether to use random floats, or just a monotonic range
        of numbers from [0, size of DataArray].
    random_seed : int
        Seed for the random generator.
    Returns
    -------
    xr.DataArray
        DataArray created using the requested parameters.

    """
    numpy_rng    = np.random.default_rng(random_seed)
    keys, shapes = zip(*dimensions.items())
    if random: a = numpy_rng.random(shapes, dtype='float32')
    else:      a = np.arange(np.prod(shapes), dtype='float32').reshape(shapes)
    i = numpy_rng.integers(0, a.size, int(a.size * missing_pct))
    a.ravel()[i] = np.nan 
    return xr.DataArray(a, coords=dict(zip(keys, map(np.arange, shapes))))



# Small configuration to verify correctness
small_config = {
    'data_kwargs'   : { 
        'dimensions' : { # Dimensions of the DataArray
            'time'      : 3,
            'latitude'  : 5,
            'longitude' : 5,
        }, 
        'random_seed' : 42,
    },
    'chunks' : { # Initial data chunking
        'time'      : 2, 
        'latitude'  : 2,
        'longitude' : 2,
    },
    'depth' : { # 2 x 3 x 3 window
        'time'      : 1, # 1 lookback step = window depth  2
        'latitude'  : 1, # 1 adjacent lats = window height 3
        'longitude' : 1, # 1 adjacent lons = window width  3
    },
}

# Large configuration to verify memory footprint / speed
large_config = {
    'data_kwargs' : {
        'dimensions'   : { # Dimensions of the DataArray
            'time'      : 72,
            'latitude'  : 721,
            'longitude' : 1440,
        },
        'random' : True,
    },
    'chunks' : { # Initial data chunking
        'time'      : 10,
        'latitude'  : 20,
        'longitude' : 20,
    },
    'depth' : { # 21 x 51 x 51 window
        'time'      : 20, # 20 lookback step = window depth  21
        'latitude'  : 25, # 25 adjacent lats = window height 51
        'longitude' : 25, # 25 adjacent lons = window width  51
    },
}


if __name__ == '__main__':
    ProgressBar().register()

    # Create our example data
    config = small_config
    data   = create_array(**config['data_kwargs'])
    print('\nGenerated DataArray:\n', data)
    print('\nCurrent memory usage:', get_memory_usage())

    # Convert depth (adjacency) to total window size
    # Time is only lookback steps, lat/lon is on either side of center
    depth_to_window = lambda k, v: (v+1) if k == 'time' else (v*2+1)

    dim_ax = dict(zip(data.dims, np.arange(len(data.dims))))
    depth  = {i: config['depth'][d] for i,d in enumerate(data.dims)}
    window = [depth_to_window(d, depth[i]) for i,d in enumerate(data.dims)]

    # Chunk the data and pull out the dask array
    data = data.chunk(config['chunks']).data
    print('\nDask data array:\n', data)


    # Define our reduction functions
    def chunking(chunk, *args, **kwargs):
        """ da.reduction chunking function; creates and filters windows """
        create_window = da.lib.stride_tricks.sliding_window_view

        time_ax = dim_ax['time']
        lat_ax  = dim_ax['latitude']
        lon_ax  = dim_ax['longitude']

        # Determine how many pixels are allowed to be missing in a lat/lon window
        allowed_missing = int(window[lat_ax] * window[lon_ax] * 0.1) # 10%

        # Only create windows for time lookback, and discard lookforward
        valid = [slice(None), slice(None), slice(None)]
        valid[time_ax] = slice(None, -(depth[time_ax]))
        chunk = da.from_array(chunk[tuple(valid)])

        # Create a mask for the valid windows, building
        # axes one at a time to minimize memory footprint
        invalid = da.isnan(chunk)
        invalid = create_window(invalid, window[lat_ax], axis=lat_ax).sum(-1)
        invalid = create_window(invalid, window[lon_ax], axis=lon_ax).sum(-1)
        invalid = invalid > allowed_missing # More missing than allowed 
        valid   = create_window(invalid, window[time_ax], axis=time_ax).sum(-1)
        valid   = valid == 0 # Zero invalid lat/lon windows in a sequence

        # Compute the valid mask in order to allow dask to 
        # understand the final shape of the filtered windows
        valid = valid.ravel().compute()

        # Create a view on the fully windowed data and pull out the valid
        view = create_window(chunk, window).reshape([-1] + window)
        return view[valid]


    def aggregation(chunk, *args, **kwargs):
        """ da.reduction aggregation function; aggregates windows """
        if isinstance(chunk, list):
            return da.vstack([aggregation(c, *args, **kwargs) for c in chunk])
        return chunk


    # Define our full window generation pipeline
    def pipeline():
        kwargs = {'concatenate': False, 'dtype': 'float32'}

        with dask.config.set(**{'array.slicing.split_large_chunks': True}):
            job = da.overlap.overlap(data, depth=depth, boundary=np.nan)
            job = da.reduction(job, chunking, aggregation, **kwargs)
            return job.compute().rechunk((10, -1, -1, -1))
    

    print('\nCalculating valid windows...')
    windows = pipeline()

    print('\nNumber of valid windows:', len(windows))
    print('Dask array:\n', windows)

    print('\nFirst valid window:')
    if data.size < 100: print(windows[0].compute())
    else:               print(windows[0].compute().shape)
    
    print('\nCurrent memory usage:', get_memory_usage())

    print('\nFirst 500 valid windows:')
    print(windows[:500].compute().shape)

And here’s an example run output using the large configuration:

> Generated DataArray:
>  <xarray.DataArray (time: 72, latitude: 721, longitude: 1440)>
> array([[[6.02052331e-01, 5.51831007e-01, 2.38153279e-01, ...,
>          5.42671919e-01, 1.44276798e-01, 6.79698229e-01],
>         [7.78109729e-01, 5.20548820e-01, 8.25416028e-01, ...,
>                     nan, 3.52441192e-01, 1.51431143e-01],
>         [           nan, 2.62471437e-01, 9.05505776e-01, ...,
>          9.31905925e-01, 7.44619429e-01, 1.70831382e-01],
>         ...,
>         [2.14767992e-01, 8.95660341e-01, 7.51004815e-01, ...,
>          6.88676715e-01,            nan, 7.45943606e-01],
>         [1.44425273e-01, 9.37121928e-01,            nan, ...,
>          4.46808994e-01, 9.83534813e-01, 9.34679627e-01],
>         [3.73181462e-01, 8.24000299e-01, 5.53744912e-01, ...,
>          5.80387652e-01,            nan, 9.54352319e-01]],
> 
>        [[8.78788590e-01, 1.27004385e-01, 3.12559545e-01, ...,
>          9.12928164e-01, 9.52983379e-01, 6.12718880e-01],
>         [7.15651393e-01, 7.28087008e-01, 9.14829373e-01, ...,
>          5.14161229e-01, 9.18503344e-01, 6.50941730e-02],
>         [           nan, 6.83702290e-01, 1.01656258e-01, ...,
>          7.97033966e-01, 4.55283523e-02, 4.91114080e-01],
> ...
>         [           nan, 8.57391715e-01, 2.75434434e-01, ...,
>          4.83412981e-01, 6.74598694e-01,            nan],
>         [8.42558146e-02, 6.00736558e-01, 7.57921457e-01, ...,
>          6.19906187e-03, 5.27651310e-02, 5.92216849e-02],
>         [           nan, 3.51869822e-01, 2.37307727e-01, ...,
>          7.46071398e-01, 4.05963302e-01, 3.40276062e-01]],
> 
>        [[1.12030089e-01, 3.90940666e-01, 8.49876702e-01, ...,
>          3.80216300e-01, 5.12348652e-01, 7.03737080e-01],
>         [9.59311366e-01, 3.27358663e-01, 5.67284942e-01, ...,
>          3.60775530e-01, 4.45301652e-01, 9.48904157e-01],
>         [7.53009439e-01,            nan, 7.29954839e-02, ...,
>          8.25474799e-01, 3.65721703e-01, 9.18886721e-01],
>         ...,
>         [3.08558464e-01, 9.46927965e-01, 5.40417254e-01, ...,
>          8.02780867e-01, 1.08663082e-01, 5.91851711e-01],
>         [9.21159208e-01, 3.93624127e-01, 4.32650268e-01, ...,
>                     nan, 4.59239066e-01, 4.17420208e-01],
>         [6.05253994e-01, 7.09035814e-01, 2.26445198e-02, ...,
>          7.90548265e-01, 9.64069128e-01, 6.18108034e-01]]], dtype=float32)
> Coordinates:
>   * time       (time) int32 0 1 2 3 4 5 6 7 8 9 ... 63 64 65 66 67 68 69 70 71
>   * latitude   (latitude) int32 0 1 2 3 4 5 6 7 ... 714 715 716 717 718 719 720
>   * longitude  (longitude) int32 0 1 2 3 4 5 6 ... 1434 1435 1436 1437 1438 1439
> 
> Current memory usage: ~0.39Gb
> 
> Dask data array:
>  dask.array<xarray-<this-array>, shape=(72, 721, 1440), dtype=float32, chunksize=(10, 20, 20), chunktype=numpy.ndarray>
> 
> Calculating valid windows...
> [########################################] | 100% Completed | 34.9s
> 
> Number of valid windows: 544413
> Dask array:
>  dask.array<rechunk-merge, shape=(544413, 21, 51, 51), dtype=float32, chunksize=(10, 21, 51, 51), chunktype=numpy.ndarray>
> 
> First valid window:
> [########################################] | 100% Completed |  1.8s
> (21, 51, 51)
> 
> Current memory usage: ~3.08Gb
> 
> First 500 valid windows:
> [########################################] | 100% Completed |  2.6s
> (500, 21, 51, 51)

Be aware that running the large config does require something like ~20 GB of memory when computing the first windows in the script, again due to the reshaping issue.