Efficient bootstrapping with Xarray

Hello,

I have a multidimensional DataArray backed by Dask on which I would like to perform statistical bootstrapping. There is one large (100k) dimension n from which I would like to draw 1000 new samples of the same size (with replacement) and calculate the variance of each sample. From what I've read in the Dask docs, dask.array.random.choice only works on 1D arrays, so I'm generating the indices first and then using advanced indexing.
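A small NumPy sketch of this index-then-index approach shows why it is memory hungry. The sizes are made-up stand-ins, and np.take is used here as a plain-NumPy equivalent of the isel call in the example further down. Each resample is a full-size copy of the array before the variance reduces it:

```python
import numpy as np

# Small stand-in for the real array: the bootstrap dimension "n" is axis 1.
rng = np.random.default_rng(0)
data = rng.random((3, 1_000, 2))
n_axis = 1
n_size = data.shape[n_axis]

# Draw one bootstrap sample of indices (with replacement), the same size as n.
idx = rng.integers(0, n_size, size=n_size)

# Advanced (integer-array) indexing along n returns a brand-new array ...
sample = np.take(data, idx, axis=n_axis)

# ... that is exactly as large as the original and shares no memory with it.
print(sample.nbytes == data.nbytes)
print(np.shares_memory(sample, data))

# The variance over n is tiny by comparison, but the full copy exists first.
sample_var = sample.var(axis=n_axis)
print(sample_var.shape)
```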

The problem is that advanced indexing creates a new copy of the original array, which is about 13GB in my case, so the computation quickly runs out of memory. Unfortunately, I don't know how to avoid that, even though immediately after selecting the sample I reduce the n dimension by calculating the variance, so the result is tiny.
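As an aside not raised in the thread: because the variance only needs sums over n, a resample does not have to be materialized at all. For one point in the other dimensions, let x_1, ..., x_n be the values along n, let a resample draw indices i_1, ..., i_n with replacement, and let c_j be the number of times index j was drawn (so the c_j sum to n). The resample's mean and population variance (ddof=0, the default of both np.var and xarray's .var()) can then be written in terms of the counts:

$$\bar{x}^{*} = \frac{1}{n}\sum_{k=1}^{n} x_{i_k} = \frac{1}{n}\sum_{j=1}^{n} c_j\, x_j$$

$$\operatorname{Var}^{*} = \frac{1}{n}\sum_{k=1}^{n} \left(x_{i_k} - \bar{x}^{*}\right)^2 = \frac{1}{n}\sum_{j=1}^{n} c_j\, x_j^{2} - \left(\bar{x}^{*}\right)^2$$

The count vector (c_1, ..., c_n) of one resample is multinomially distributed with n trials and equal probabilities 1/n, so it can be drawn directly instead of drawing indices.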

The workaround I chose is simply calling compute on each sample, but that uses only about 10% of my CPU and the overall calculation is quite slow. Below is a minimal example. Do you have any ideas how to improve this solution, if that is even possible? I know Dask struggles with operations like sort, so maybe this just isn't very parallelizable?

import numpy as np
import xarray as xr
from dask.distributed import Client, LocalCluster
import dask.array as da

sample_count = 100
n_size = 1000000

if __name__ == "__main__":
    # setup cluster
    cluster = LocalCluster(
        host="127.0.0.1",  # scheduler address 127.0.0.1:8786
        scheduler_port=8786,
        processes=True,
        n_workers=2,
        threads_per_worker=8,
        timeout="2s",
    )
    client = Client(cluster)

    # create dummy data (normally I read it from a zarr file)
    projections = xr.DataArray(
        da.random.random((4, 10, 2, n_size, 2, 11)), dims=("a", "b", "c", "n", "d", "e")
    )

    # Keep n in a single chunk (my real data has n in a single chunk)
    projections = projections.chunk(
        {**{dim: "auto" for dim in projections.dims}, "n": -1}
    )

    # create random indices for each sample
    idx = xr.DataArray(
        np.random.randint(
            0,
            projections.n.size,
            (sample_count, projections.n.size),
        ),
        dims=("sample", "n"),
    )

    result = []
    for i in idx:
        # I want to get rid of this compute call
        result.append(projections.isel(n=i).var(dim="n").compute())

    # I also tried calling projections.isel(n=idx), but the task graph takes forever to generate

    xr.concat(
        result,
        dim="sample",
    ).chunk("auto").to_dataset(
        name="var"
    ).to_zarr("result", mode="w")
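One middle ground, sketched here under the assumption that building a graph per sample is cheap compared with running it: keep the loop, but hand Dask a batch of lazy samples per dask.compute call so that several samples run concurrently instead of one at a time. The names bootstrap_var and batch_size are hypothetical, and the sizes are small stand-ins. Peak memory grows with batch_size, because every sample in flight is still a full-size copy:

```python
import dask
import dask.array as da
import numpy as np
import xarray as xr

sample_count = 8
n_size = 10_000
rng = np.random.default_rng(0)

# Small stand-in for `projections`, with n in a single chunk as in the post.
projections = xr.DataArray(
    da.random.random((4, 3, n_size), chunks=(1, 3, -1)), dims=("a", "b", "n")
)
idx = xr.DataArray(
    rng.integers(0, n_size, size=(sample_count, n_size)), dims=("sample", "n")
)


def bootstrap_var(projections, idx, batch_size):
    """Compute the per-sample variance in batches: one dask.compute per batch."""
    results = []
    for start in range(0, idx.sizes["sample"], batch_size):
        batch = idx.isel(sample=slice(start, start + batch_size))
        # Build the lazy variances for every sample in the batch ...
        lazy = [projections.isel(n=i).var(dim="n") for i in batch]
        # ... and let Dask schedule them together instead of one at a time.
        results.extend(dask.compute(*lazy))
    return xr.concat(results, dim="sample")


result = bootstrap_var(projections, idx, batch_size=4)
print(result.sizes)
```

batch_size would need tuning against the available memory: with a 13GB array and 128GB of RAM, only a handful of samples can be in flight at once.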

Do you have enough RAM to just load projections into memory? That would probably be a lot faster. As long as you’re using the dask threaded scheduler, your workers should be able to sample from the array using shared memory.
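A minimal sketch of the shared-memory idea with plain threads, assuming the array has already been loaded (for example with projections.values): every thread reads the same array, and only the per-sample copy and its small variance are private. Whether this actually keeps several cores busy depends on how much of the indexing and reduction NumPy runs with the GIL released, so it is worth profiling rather than assuming. The sizes and max_workers are arbitrary stand-ins:

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

rng = np.random.default_rng(0)
data = rng.random((4, 3, 10_000))  # in-memory stand-in for projections.values
n_axis = 2
n_size = data.shape[n_axis]
sample_count = 8

# Pre-draw all index vectors so the result does not depend on thread timing.
indices = rng.integers(0, n_size, size=(sample_count, n_size))


def sample_var(idx):
    # Each call makes one full-size copy of `data`, then reduces it over n.
    return np.take(data, idx, axis=n_axis).var(axis=n_axis)


# Threads share `data`; roughly max_workers full-size copies can be alive at once.
with ThreadPoolExecutor(max_workers=4) as pool:
    variances = np.stack(list(pool.map(sample_var, indices)))

print(variances.shape)
```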

I do have 128GB of RAM, and projections is 13GB, so that's not a problem. The thing is, loading it does not give any significant speedup, and the CPU is barely used, even when doing everything in NumPy like so:

import numpy as np
import xarray as xr

sample_count = 10
n_size = 1000000

rng = np.random.default_rng()

# create dummy data (normally I read it from a zarr file)
projections = xr.DataArray(
    rng.random((4, 10, 2, n_size, 2, 11)), dims=("a", "b", "c", "n", "d", "e")
)

result = []
n_axis = projections.get_axis_num(dim="n")
dims = list(projections.dims)
coords = projections.coords
del dims[n_axis]

for i in range(sample_count):
    sample = rng.choice(projections.values, axis=n_axis, size=n_size)
    result.append(np.var(sample, axis=n_axis))

xr.concat(
    [xr.DataArray(r, dims=dims, coords=coords) for r in result],
    dim="sample",
).chunk("auto").to_dataset(name="var").to_zarr("result", mode="w")

Sharing memory does not help, because each sample is a new 13GB array in memory.
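A possible way around the per-sample copy, not something proposed in the thread: draw bootstrap counts from a multinomial distribution instead of indices and compute the weighted first and second moments over n for all samples at once with two matrix products. This is a sketch on a small stand-in array; the names and sizes are assumptions. Centering the data first keeps the E[x^2] - E[x]^2 formula numerically stable, since the variance is unchanged by subtracting a constant along n:

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.random((4, 3, 2_000))  # stand-in for projections.values
n_axis = 2
n_size = data.shape[n_axis]
sample_count = 50

# Bootstrap counts: row s says how often each element of n was drawn in sample s.
# Multinomial(n, 1/n) counts have the same distribution as n draws with replacement.
counts = rng.multinomial(n_size, np.full(n_size, 1.0 / n_size), size=sample_count)

# Flatten all other dimensions into rows so one large matrix product does the work.
x = np.moveaxis(data, n_axis, -1)
other_shape = x.shape[:-1]
x = x.reshape(-1, n_size)
x = x - x.mean(axis=1, keepdims=True)  # center to limit cancellation

w = counts.T / n_size      # (n, sample_count) weights
mean = x @ w               # (points, sample_count) resample means
mean_sq = (x * x) @ w      # (points, sample_count) resample E[x^2]
variances = (mean_sq - mean**2).T.reshape(sample_count, *other_shape)
```

The result could be wrapped back with xr.DataArray(variances, dims=("sample", *dims)), using the dims list from the code above. The trade-offs: x and x * x are one-off full-size temporaries rather than one copy per sample, and the weight matrix has n × sample_count entries (about 8GB as float64 for n = 1,000,000 and 1000 samples), so for the real sizes the samples would likely need to be processed in batches. Matrix products usually run on a multithreaded BLAS in common NumPy builds, which may also help with the low CPU usage, though this is untested on the real data.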

@aaronspring and I have written some resampling/bootstrapping code for climpred that got wrapped into xskillscore here: xskillscore/resampling.py (xarray-contrib/xskillscore on GitHub). This might be helpful for your use case!
