xr.DataArray.chunks, np.digitize and xr.DataArray.groupby, and dask


I am trying to filter our Zarr datasets, which are stored in an S3 bucket. We use {time: 250, y: 100, x: 100} chunking when writing our data to the Zarr store.
I have a couple of questions and would appreciate it if anybody has any input or suggestions.

  1. Just an observation, and I am not clear on why it happens: xr.DataArray.chunks information seems to be absent when the Zarr store it belongs to resides in the S3 bucket, but it is available when the Zarr store is on my local file system.

My environment (output of xr.show_versions()):

commit: None
python: 3.9.6 | packaged by conda-forge | (default, Jul 11 2021, 03:36:15)
[Clang 11.1.0 ]
python-bits: 64
OS: Darwin
OS-release: 20.6.0
machine: x86_64
processor: i386
byteorder: little
LC_ALL: None
LOCALE: ('en_US', 'UTF-8')
libhdf5: 1.12.1
libnetcdf: 4.8.1
xarray: 0.19.0
pandas: 1.3.4
numpy: 1.21.3
scipy: None
netCDF4: 1.5.7
pydap: None
h5netcdf: 0.11.0
h5py: 3.4.0
Nio: None
zarr: 2.6.1
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: None
dask: 2.30.0
distributed: 2.30.1
matplotlib: None
cartopy: None
seaborn: None
numbagg: None
pint: None
setuptools: 58.0.4
pip: 21.3.1
conda: None
pytest: None
IPython: 7.29.0
sphinx: None
  2. This is an optimization, or “is there a better way to do the same thing”, question. To process our data, we have to filter it along the time dimension for each spatial (x, y) point. Since our datasets are rather large, we process the data in chunks whose x and y sizes match the chunks we use when writing the data to Zarr. The following filter, which bins data based on the date_dt values and identifies invalid entries, takes about 2.1 seconds to process 100 spatial points (cube[:, y, 0:100]). That is rather slow given our data dimensions of (40981, 834, 834): at that rate one data variable would take roughly 4 hours (834 × 834 / 100 batches × 2.1 s ≈ 14,600 s), and we need to run the filter for multiple variables. So I wonder whether I am doing something inefficient with xarray here and whether there is a better, more efficient way to do it (our Matlab code seems to run a few orders of magnitude faster). Here is the code snippet that does the filtering:
import logging

import dask
import numpy as np
import s3fs
import xarray as xr

# Read cube from S3
cube_path = 's3://its-live-data/datacubes/v02/N60W130/ITS_LIVE_vel_EPSG3413_G0120_X-3250000_Y250000.zarr'
s3_in = s3fs.S3FileSystem(anon=True, skip_instance_cache=True)
cube_store = s3fs.S3Map(root=cube_path, s3=s3_in, check=False)
cube = xr.open_dataset(cube_store, decode_timedelta=False, engine='zarr', consolidated=True)

# Define edges of dt bins
DT_EDGE = [0, 32, 64, 128, 256, np.inf]


# Scalar relation between MAD and STD
MAD_STD_RATIO = 1.4826


def madFunction(x):
    """Compute median absolute deviation (MAD)."""
    return (np.fabs(x - x.median(dim='mid_date'))).median()

# Load date_dt into memory
dt = cube.date_dt.load()

# Load one x/y "chunk" into memory
block_size = 100
vx = cube.vx[:, 0:block_size, 0:block_size].load()

def filter_iteration(x0, dt):
    maxdt = np.nan
    invalid = xr.zeros_like(dt, dtype=bool)

    if np.all(x0.isnull()):
        # No data to process
        logging.info('No data to process')
        return (maxdt, invalid)

    # Filter NaN values out
    mask = ~x0.isnull()
    x0 = x0.where(mask, drop=True)
    x0_dt = dt.where(mask, drop=True)

    # Label each remaining time step with its dt bin (label 1 = smallest dt)
    np_digitize = np.digitize(x0_dt.values, DT_EDGE, right=False)
    dt_bin = xr.DataArray(np_digitize, dims='mid_date',
                          coords={'mid_date': x0['mid_date'].values}, name='dt_bin')
    groups = x0.groupby(dt_bin)

    # Are means significantly different for various dt groupings?
    median = groups.median()
    xmad = groups.map(madFunction)

    # Check if populations overlap (use first, smallest dt, bin as reference)
    std_dev = xmad * MAD_STD_RATIO
    minBound = median - std_dev
    maxBound = median + std_dev

    exclude = (minBound > maxBound[0]) | (maxBound < minBound[0])

    if np.any(exclude):
        # Bin labels of the excluded groups; label k covers DT_EDGE[k-1] <= dt < DT_EDGE[k].
        # Assumed intent: maxdt is the lower edge of the smallest excluded bin.
        excluded_bins = exclude['dt_bin'].values[exclude.values]
        maxdt = np.take(DT_EDGE, excluded_bins - 1).min()
        invalid = dt > maxdt

    return (maxdt, invalid)
tasks = [dask.delayed(filter_iteration)(vx.isel(x=i, y=0), dt) for i in range(0, block_size)]

# results = None
# with ProgressBar():  # Does not work with Client() scheduler
results = dask.compute(tasks)

# Process results (don't really do anything with the output while debugging)
for each_output in results[0]:
    iter_maxdt, iter_invalid = each_output

Last executed at 2022-01-14 18:48:10 in 2.21s
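Written out, the per-point test that filter_iteration performs is as follows (notation is mine, not from the code: v_t is the value at time step t, b is a dt bin defined by DT_EDGE, and r is the first, smallest-dt, bin that has data). The constant 1.4826 ≈ 1/Φ⁻¹(3/4) turns the MAD into an estimate of the standard deviation for normally distributed data:

$$
\tilde v_b = \operatorname*{median}_{t \in b} v_t, \qquad
\mathrm{MAD}_b = \operatorname*{median}_{t \in b} \lvert v_t - \tilde v_b \rvert, \qquad
\sigma_b = 1.4826\,\mathrm{MAD}_b
$$

$$
\text{bin } b \text{ is excluded if } \quad
\tilde v_b - \sigma_b > \tilde v_r + \sigma_r
\quad \text{or} \quad
\tilde v_b + \sigma_b < \tilde v_r - \sigma_r
$$

maxdt is then taken from DT_EDGE for the smallest excluded bin, and every observation with dt > maxdt is flagged as invalid.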

  3. I tried running the same function sequentially and don’t see much difference in runtime compared with the dask version. I ran the dask code in parallel on my laptop with 4 CPUs but see no speed-up. What could be the reason?

Sequential code:

for i in range(0, block_size):
    maxdt, invalid = filter_iteration(vx.isel(x=i, y=0), dt)

Last executed at 2022-01-14 18:48:04 in 2.26s
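Both runs take about 2.2 s per 100 points, which suggests the time goes into per-point Python and xarray overhead (one groupby per pixel) rather than into the arithmetic itself. A different technique from the code above is to vectorize across pixels: compute each dt bin's median and MAD for a whole (time, y, x) block at once with np.nanmedian, so a 100 × 100 block needs about five passes over the data instead of 10,000 groupby calls. This is a sketch, not a drop-in replacement; it assumes the block fits in memory as a NumPy array, that date_dt values are non-negative and not NaN, and (as in the comment in the function above) that maxdt is the lower DT_EDGE value of the smallest excluded bin. The function name vectorized_maxdt is hypothetical.

```python
import warnings

import numpy as np

# Same constants as in the question
DT_EDGE = [0, 32, 64, 128, 256, np.inf]
MAD_STD_RATIO = 1.4826


def vectorized_maxdt(v, dt):
    """Per-pixel maxdt for a whole block at once (hypothetical helper).

    v  : float array (time, y, x), NaN where there is no data
    dt : float array (time,), the date_dt values
    Returns a (y, x) array: lower edge of the smallest excluded dt bin, NaN if none.
    """
    labels = np.digitize(dt, DT_EDGE, right=False)  # bin labels 1..5 for dt >= 0
    nbins = len(DT_EDGE) - 1
    med = np.full((nbins,) + v.shape[1:], np.nan)
    mad = np.full_like(med, np.nan)

    with warnings.catch_warnings():
        # nanmedian warns on pixels where a bin holds only NaN; those stay NaN
        warnings.simplefilter("ignore", RuntimeWarning)
        for b in range(nbins):
            sel = v[labels == b + 1]  # all time steps that fall in bin b + 1
            if sel.shape[0] == 0:
                continue
            med[b] = np.nanmedian(sel, axis=0)
            mad[b] = np.nanmedian(np.abs(sel - med[b]), axis=0)

    std = mad * MAD_STD_RATIO
    lo, hi = med - std, med + std

    # Reference = first (smallest dt) bin with data at each pixel,
    # mirroring maxBound[0] / minBound[0] in the per-pixel groupby version
    ref = np.argmax(~np.isnan(med), axis=0)[None]
    lo_ref = np.take_along_axis(lo, ref, axis=0)
    hi_ref = np.take_along_axis(hi, ref, axis=0)

    # Comparisons with NaN are False, so bins without data are never excluded
    exclude = (lo > hi_ref) | (hi < lo_ref)

    lower_edges = np.asarray(DT_EDGE[:-1])[:, None, None]
    maxdt = np.where(exclude, lower_edges, np.inf).min(axis=0)
    maxdt[~exclude.any(axis=0)] = np.nan
    return maxdt
```

With the arrays from the question this would be called as `maxdt = vectorized_maxdt(vx.values, dt.values)`, and `dt.values[:, None, None] > maxdt` then gives the (time, y, x) invalid mask (False wherever maxdt is NaN). It has not been benchmarked here; np.nanmedian still sorts along time, so any speed-up would come mainly from removing the Python-level per-pixel loop.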

Many thanks in advance!


What happens with the processes scheduler?

@RichardScottOZ Using the “processes” scheduler results in a deadlock at runtime. According to the dask best practices:
If you’re doing mostly numeric work with Numpy, Pandas, Scikit-Learn, Numba, and other libraries that release the [GIL](https://docs.python.org/3/glossary.html#term-global-interpreter-lock), then use mostly threads. If you’re doing work on text data or Python collections like lists and dicts then use mostly processes.
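For point 3, assuming no distributed Client is active, dask.delayed runs on the threaded scheduler by default. If each task spends most of its time in Python-level code that holds the GIL (xarray groupby bookkeeping, many small NumPy calls on 1-D slices), the four threads mostly take turns, which would match the near-identical sequential and dask timings. The toy sketch below (python_heavy, run and compare_schedulers are hypothetical names) shows how to choose the scheduler per call with dask.compute(..., scheduler=...) and compare them on a GIL-bound workload:

```python
import time

import dask


def python_heavy(n):
    """Pure-Python loop: it holds the GIL for its whole runtime."""
    total = 0
    for i in range(n):
        total += i % 7
    return total


def run(scheduler, ntasks=8, n=200_000):
    # One delayed task per call; dask.delayed defaults to pure=False, so tasks are not merged
    tasks = [dask.delayed(python_heavy)(n) for _ in range(ntasks)]
    return dask.compute(*tasks, scheduler=scheduler)


def compare_schedulers():
    # Call this under `if __name__ == "__main__":` so the "processes"
    # scheduler can safely start worker processes on macOS and Windows.
    for scheduler in ("synchronous", "threads", "processes"):
        start = time.perf_counter()
        run(scheduler)
        print(f"{scheduler:>11}: {time.perf_counter() - start:.2f} s")
```

On a multi-core machine, "threads" would typically be no faster than "synchronous" for this workload, while "processes" can use several cores at the cost of serializing inputs and outputs between processes. The deadlock reported above with "processes" may be worth checking against the `__main__` guard, though that is only a guess.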