I apologize in advance if I’m missing something obvious with this problem - I’m having trouble finding resources that give a direct answer to the issues I’m facing.
In short, I have several different zarr datasets hosted on AWS that I need to work with. Taking a small subset of ERA5 as an example:
import xarray as xr
import s3fs
import dask

# Lazily open the zarr store on S3; chunks='auto' lets dask pick chunk sizes
store = s3fs.S3FileSystem(False, <credentials>).get_mapper(<s3://ERA5 bucket>)
data = xr.open_dataset(store, chunks='auto', engine='zarr')
data
> <xarray.Dataset>
> Dimensions: (latitude: 721, longitude: 1440, time: 72)
> Coordinates:
> * latitude (latitude) float32 90.0 89.75 89.5 89.25 ... -89.5 -89.75 -90.0
> * longitude (longitude) float32 -180.0 -179.8 -179.5 ... 179.2 179.5 179.8
> * time (time) datetime64[ns] 2015-04-01 ... 2015-04-03T23:00:00
> Data variables: (12/15)
> msdwlwrf (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> msdwswrf (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> skt (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> stl1 (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> stl2 (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> stl3 (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> ... ...
> swvl3 (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> swvl4 (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> t2m (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> tp (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> u10 (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> v10 (time, latitude, longitude) float32 dask.array<chunksize=(24, 721, 1440), meta=np.ndarray>
> Attributes:
> Conventions: CF-1.6
> history: 2022-12-09 04:38:31 GMT by grib_to_netcdf-2.25.1: /opt/ecmw...
My end goal is to operate over batches of samples from this dataset and others for ML applications, where each sample is a window surrounding a given spatial/temporal location. So for example I might have samples shaped (time=10, latitude=7, longitude=7) which represents 10 consecutive time steps with a 7x7 spatial window surrounding a given pixel lat/lon.
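For concreteness, a single sample would just be a small slice like the one below (the index values are placeholders I made up for illustration):

# One hypothetical sample: 10 consecutive time steps with a 7x7 spatial
# window centred on (lat_idx, lon_idx). The index values are made up.
t_idx, lat_idx, lon_idx = 0, 360, 720
sample = data.isel(
    time=slice(t_idx, t_idx + 10),
    latitude=slice(lat_idx - 3, lat_idx + 4),
    longitude=slice(lon_idx - 3, lon_idx + 4),
)  # -> dims (time: 10, latitude: 7, longitude: 7)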
Seemingly easy enough, but I’m having difficulty figuring out how best to organize the operations so that memory usage, disk (network) access, and computation time are minimized.
One part of this problem is that if too many pixels are missing in a window, the window is invalid; similarly, if any window in the timeseries is invalid, the entire sample is invalid. A naïve implementation would look something like:
# These are somewhat large windows, but I want to trigger
# a memory error here if tasks aren't executed as expected
windows = {
    'latitude'  : 50,
    'longitude' : 50,
    'time'      : 20,
}
center = {
    'latitude'  : True,
    'longitude' : True,
    'time'      : False,
}
# Create a view of the multidimensional windows
job = data.rolling(center=center, **windows)
job = job.construct({'latitude': 'a', 'longitude': 'b', 'time': 'c'})
# Stack coordinates since we only care about spatial/temporal windows now
with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    job = job.stack(llt=['latitude', 'longitude', 'time'])
# Now we filter out invalid samples
# (int16 here: a 50x50 window can have up to 2500 missing pixels, which would overflow int8)
dimsum = lambda x, dims: x.sum(dims, skipna=False, dtype='int16')
invalid = dimsum(job.isnull(), ['a', 'b']) > 10  # Invalid if > 10 missing pixels
valid = dimsum(invalid, ['c']) == 0              # Series must have 0 invalid windows
# This actually triggers computation, so it won't run beyond this point
job = job.where(valid, drop=True)
# Now just check the first valid sample
job = job.isel({'llt': 0})
Great, except that trying to compute this results in attempts to allocate 4+ TiB arrays at various points. At least some of those make sense: the original chunk size is (24, 721, 1440), so taking a rolling window over that single chunk would be problematic. Even if I load the data with something like (10, 10, 10) chunking, though, it can’t operate over windows of this size because the data is automatically rechunked to (29, 51, 55, 50, 50, 20).
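(By "(10, 10, 10) chunking" I just mean passing explicit per-dimension chunk sizes at load time, roughly like this:)

data = xr.open_dataset(
    store,
    chunks={'time': 10, 'latitude': 10, 'longitude': 10},
    engine='zarr',
)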
It seems that in the minimum chunk size case, it should only need to load one sample array of size (50, 50, 20) at a time, check whether it’s valid, and if not, discard it and check the next window (reloading the larger block of original data if necessary and pulling out the relevant single window). I can’t get it to do even that, though - despite trying to rechunk in various places, reorder operations, reformulate the logic, pass a plain NumPy mask into job.where, etc.
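To spell out the access pattern I’m expecting, here is a brute-force sketch of the semantics (the helper name, loop order, and the choice of the 't2m' variable are all just placeholders, not anything built in):

def first_valid_sample(da, nt=20, ny=50, nx=50, max_missing=10):
    # Walk candidate windows one at a time, loading only ~nt*ny*nx values per
    # iteration, and return the first sample in which every time step has at
    # most `max_missing` missing pixels.
    for t in range(da.sizes['time'] - nt + 1):
        for i in range(da.sizes['latitude'] - ny + 1):
            for j in range(da.sizes['longitude'] - nx + 1):
                window = da.isel(
                    time=slice(t, t + nt),
                    latitude=slice(i, i + ny),
                    longitude=slice(j, j + nx),
                ).load()
                missing_per_step = window.isnull().sum(dim=['latitude', 'longitude'])
                if bool((missing_per_step <= max_missing).all()):
                    return window
    return None

# e.g. on a single variable
sample = first_valid_sample(data['t2m'])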
Is there something I’m misunderstanding about how this should be approached?