Quantile calculations on reshaped arrays: using Dask + xarray + S3 and still having memory problems

I would like to compute statistics for 30 years' worth of gridded data. Because I'm calculating quantile statistics along the "time" axis, the full time series of each grid point has to be kept together. The data are stored in an S3 bucket, and I need to save the statistics as a NetCDF file. I'm getting a memory error despite using Dask, so I think I'm not understanding something fundamental about how Dask operates.

The workflow is as follows:

Process the data locally, flattening each field to 1-D so that it can be stored as a dataframe in an S3 bucket. Time is saved separately.

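    import numpy as np
    import pandas as pd

    # dl (data loading) and s3 (upload) are project-specific helpers
    bathy, lat, lon = dl.load_bathy()  # the grid is the same every year

    for year in years:
        time, hs, tp, dp = dl.load_bulk_params(year)

        # Flip upside down. Note: np.flipud reverses axis 0; if the arrays are
        # shaped (time, lat, lon), that is the time axis, and flipping latitude
        # would need np.flip(hs, axis=1) instead.
        # ravel() flattens in C order (last axis varies fastest).
        hs = np.flipud(hs).ravel()
        tp = np.flipud(tp).ravel()
        dp = np.flipud(dp).ravel()

        dataframe = pd.DataFrame(data={'hs': hs, 'tp': tp, 'dp': dp})
        time_dataframe = pd.DataFrame({'time': time})
        # (writing dataframe and time_dataframe to Parquet at local_path omitted)
        s3.upload(local_path, remote_path)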
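A quick NumPy check of the flattening step, using a small array that is assumed (this is an assumption, not confirmed by the post) to be shaped (time, lat, lon) like the real data: np.flipud reverses axis 0, and a C-order ravel() followed by a reshape to the same shape restores the array, which is what the later reshape relies on.

```python
import numpy as np

# Hypothetical small stand-in for one year of data, ASSUMED shaped (time, lat, lon).
n_time, n_lat, n_lon = 3, 4, 5
hs = np.arange(n_time * n_lat * n_lon, dtype=float).reshape(n_time, n_lat, n_lon)

# np.flipud reverses axis 0, which for this shape is time, not latitude.
time_reversed = np.array_equal(np.flipud(hs), hs[::-1, :, :])

# Flipping latitude instead requires axis=1.
flipped_lat = np.flip(hs, axis=1)
lat_reversed = np.array_equal(flipped_lat, hs[:, ::-1, :])

# ravel() flattens in C order (last axis fastest), so the 1-D column is
# time-major, and reshaping with the same (time, lat, lon) shape round-trips.
flat = flipped_lat.ravel()
round_trip = np.array_equal(flat.reshape((n_time, n_lat, n_lon)), flipped_lat)

# Each time step occupies a contiguous run of n_lat * n_lon values.
first_step_contiguous = np.array_equal(flat[: n_lat * n_lon], flipped_lat[0].ravel())

print(time_reversed, lat_reversed, round_trip, first_step_contiguous)  # True True True True
```

If the loader returns arrays in a different axis order, both the flip axis and the later reshape would need to match that order.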

Load the data back from S3:

    import dask.dataframe as dd

    ddf = dd.read_parquet('s3://.../params*.parquet')
    time_ddf = dd.read_parquet('s3://.../time*.parquet')

Convert the parameter column from a Dask dataframe into a Dask array, reshape and rechunk it, and wrap it in xarray, so that I can use xarray's functionality to group by year, season, etc. and to save the result as NetCDF. (Is this the memory-intensive part, where I should be careful about chunk sizes?)

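    import xarray as xr

    # time coordinate as an in-memory NumPy array
    timevec = time_ddf['time'].values.compute()

    # 1-D column -> (time, lat, lon), then rechunk so that each chunk holds
    # the full time series for a 5x5 lat/lon patch
    arr = ddf[param].to_dask_array(lengths=True).reshape(
        (len(timevec), len(lat), len(lon))).rechunk((len(timevec), 5, 5))

    hs_xr = xr.DataArray(arr, dims=['time', 'lat', 'lon'],
                         coords={'time': timevec, 'lat': lat, 'lon': lon})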
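To see why the reshape-plus-rechunk step can be expensive, this NumPy-only sketch counts, for a toy grid, how many 1-D input chunks (like those produced by to_dask_array) each output chunk of shape (n_time, 2, 2) draws its data from. All sizes are hypothetical (the block size of 2 stands in for the 5 used above), and the sketch only models which elements end up where; it does not reproduce Dask's actual task graph.

```python
import numpy as np

# Toy sizes (hypothetical). The flattened column is time-major:
# flat index = t * n_lat * n_lon + i * n_lon + j.
n_time, n_lat, n_lon = 6, 4, 4
input_chunk_len = n_lat * n_lon  # hypothetical: each 1-D chunk holds one time step
block = 2                        # spatial block size of the output chunks

flat_index = np.arange(n_time * n_lat * n_lon).reshape(n_time, n_lat, n_lon)

touched = []
for i0 in range(0, n_lat, block):
    for j0 in range(0, n_lon, block):
        # Output chunk: the full time axis for one block x block patch.
        idx = flat_index[:, i0:i0 + block, j0:j0 + block]
        # Which 1-D input chunks contain these elements?
        source_chunks = np.unique(idx // input_chunk_len)
        touched.append(len(source_chunks))

n_input_chunks = flat_index.size // input_chunk_len
print(n_input_chunks, touched)  # 6 [6, 6, 6, 6]
```

Whenever each 1-D input chunk covers at least one full time step (with about 24 million elements per chunk this depends on the grid size), every output chunk needs data from every input chunk, so the rechunk behaves like an all-to-all shuffle rather than a chunk-by-chunk operation.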

Calculate the necessary statistics using map_blocks and a custom function:

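    import dask.array as da
    import numpy as np

    def dask_percentile(arr, q, axis=0):
        # np.percentile needs each cell's whole time series in one block
        if len(arr.chunks[axis]) > 1:
            msg = ('Input array cannot be chunked along the percentile '
                   'dimension.')
            raise ValueError(msg)
        # apply np.percentile block by block; the reduced axis is dropped
        return da.map_blocks(np.percentile, arr, axis=axis, q=q,
                             drop_axis=axis)

    computed_stats_ds = hs_xr.reduce(dask_percentile, dim='time', q=90).compute()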

The process gets “Killed”.
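Separately from the memory problem, the block-wise logic that dask_percentile relies on can be checked with plain NumPy. This sketch uses random data of a hypothetical size and emulates map_blocks with a loop over (n_time, 5, 5) blocks: percentiles computed per spatial block match the whole-array result, while splitting the time axis does not.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the (time, lat, lon) array.
n_time, n_lat, n_lon = 200, 10, 10
data = rng.normal(size=(n_time, n_lat, n_lon))
q = 90

# Reference: percentile over the whole array at once.
expected = np.percentile(data, q, axis=0)

# Emulate map_blocks with chunks of shape (n_time, 5, 5): each block holds the
# full time series, so np.percentile(block, q, axis=0) is exact for its cells.
block = 5
result = np.empty((n_lat, n_lon))
for i0 in range(0, n_lat, block):
    for j0 in range(0, n_lon, block):
        chunk = data[:, i0:i0 + block, j0:j0 + block]
        result[i0:i0 + block, j0:j0 + block] = np.percentile(chunk, q, axis=0)

spatial_blocks_exact = np.allclose(result, expected)

# Splitting the time axis breaks this: averaging the 90th percentiles of the
# two halves of the record is not the 90th percentile of the full record.
halves = np.split(data, 2, axis=0)
naive = np.mean([np.percentile(h, q, axis=0) for h in halves], axis=0)
time_split_exact = np.allclose(naive, expected)

print(spatial_blocks_exact, time_split_exact)
```

This is why the function refuses input chunked along time: with time in a single chunk, the spatial chunk size only changes how much memory each task uses, not the result.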

Should I use a distributed Dask cluster (i.e. start a Client)? I have tried this too, but I'm working on a remote machine, so I cannot access the Dask dashboard (or can I?). With the client, memory overflows and the workers are killed. I could reconfigure the Dask parameters to allow higher memory usage.

I thought, though, that each worker would process one chunk at a time. Are the reshaping and rechunking taking too much memory? What am I misunderstanding about Dask?

Another option would be to save everything locally as gridded files and then load them with xarray's open_mfdataset, but I'd rather avoid this because I'd have to increase the disk space of the remote machine. Perhaps the whole workflow should change so that I save the data in a gridded format such as Zarr?

Thank you so much for taking the time to respond and help me!

First thing: how big is the dataset? How much memory do your workers have, and is that realistic for the current division of the data?

One column (parameter) of the dataframe is 96 GB.
With "100 MB" as the chunk-size argument, one chunk is 23,946,648 elements, which at 8 bytes per element is approx. 191 MB?
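The chunk-size arithmetic can be checked directly. This sketch also estimates how many such chunks make up the 96 GB column, assuming decimal units (1 GB = 10^9 bytes, 1 MB = 10^6 bytes); binary units would shift the numbers slightly.

```python
# Figures quoted above: 23,946,648 elements per chunk, 8 bytes per element
# (float64), and a 96 GB column. Decimal units are an ASSUMPTION.
elements_per_chunk = 23_946_648
bytes_per_element = 8
column_bytes = 96e9

chunk_bytes = elements_per_chunk * bytes_per_element
chunk_mb = chunk_bytes / 1e6      # decimal megabytes
chunk_mib = chunk_bytes / 2**20   # binary mebibytes
n_chunks = column_bytes / chunk_bytes

print(chunk_bytes, round(chunk_mb, 1), round(chunk_mib, 1), round(n_chunks))
# 191573184 191.6 182.7 501
```

So 191 MB is correct in decimal megabytes (about 183 MiB), and the 96 GB column corresponds to roughly 500 chunks of that size.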

Running lscpu on Linux:

    16 cores, 1 thread per core
    Caches (sum of all):
      L1d: 512 KiB (16 instances)
      L1i: 512 KiB (16 instances)
      L2:  64 MiB (16 instances)