Is xr.Dataset.map() parallelized with Dask?

Hi all, I have a dataset with many variables that I'd like to run through a function in parallel, such that each variable gets assigned to a Dask worker. What is the best route for this? I can run a for loop, collect the dask.delayed results in a list, and then call dask.compute() and xr.merge(). That works, but it doesn't balance the variables across workers well: I end up with some workers processing 10 variables and others none. Dataset.map() seems like the obvious solution, but it doesn't appear to execute in parallel in my testing.

Suggestions?

Could you share a reproducible example of what you’re trying to do using synthetic data?

Here’s an example of using map() in the way you describe:

import xarray as xr
import dask.array as da

# create a dataset
ds = xr.Dataset(
    {
        "A": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
        "B": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
    }
)

# apply map lazily
ds_double = ds.map(lambda x: 2 * x)

# compute results
ds_double.compute()

Dataset.map() does use a simple loop over variables, but that only means the setup of the various lazy dask tasks happens serially. When you call .compute(), there should still be an embarrassingly parallel dask task graph ready to be executed in parallel.
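As a quick sanity check (a minimal sketch reusing the toy dataset from the snippet above), you can confirm that map() has kept everything lazy and that the per-variable tasks are independent:

import dask.array as da
import xarray as xr

ds = xr.Dataset(
    {
        "A": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
        "B": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
    }
)

ds_double = ds.map(lambda x: 2 * x)

# still lazy: each variable is a dask array, nothing has been computed yet
print(type(ds_double["A"].data))  # <class 'dask.array.core.Array'>

# the tasks for "A" and "B" have no dependencies on each other, so a
# distributed scheduler is free to run them on different workers at compute time
print(len(dict(ds_double.__dask_graph__())))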

It is possible to parallelize the setting up of the dask tasks using dask.delayed, as you did, but that isn't the actual in-memory computation, so it shouldn't make much difference to performance unless you have a really extreme case (e.g. a Dataset with thousands of small variables). That's why xarray only bothers using dask.delayed internally for IO, where setting up the original dask arrays can be expensive - this is what the parallel=True argument to open_mfdataset does.
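For reference, that IO-side use of dask.delayed looks like this (the file pattern here is just a hypothetical placeholder):

import xarray as xr

# parallel=True wraps each per-file open call in dask.delayed, so the
# potentially expensive metadata reads happen in parallel rather than serially
ds = xr.open_mfdataset("data/*.nc", combine="by_coords", parallel=True)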

As Ryan says, I would be curious to see a reproducible example of what you’re trying to do, because this shouldn’t make much difference.

Note that this question recently came up in the context of DataTree.map_over_subtree (since renamed to .map_over_datasets), which you can think of as just a more complicated version of the same Dataset.map primitive.
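If you do go down the DataTree route, here is a minimal sketch of the renamed method (assuming a recent xarray release where DataTree is in the main namespace):

import dask.array as da
import xarray as xr

ds = xr.Dataset({"A": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10)))})

# build a small tree holding two copies of the toy dataset
tree = xr.DataTree.from_dict({"group1": ds, "group2": ds})

# map_over_datasets applies the function to the dataset at every node,
# analogous to Dataset.map applying it to every variable
doubled_tree = tree.map_over_datasets(lambda node_ds: 2 * node_ds)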

Thanks, here is some sample code. I’m taking 10 variables that need to go through a slow function, which I’m mocking up with a sleep here. So the worst case would be serial execution, taking 50 seconds.

import xarray as xr
import dask.array as da
import time
import dask
from dask.distributed import Client

client = Client(n_workers=5)

def double_fun(arr):
    # stand-in for a slow, per-variable computation
    time.sleep(5)
    return arr * 2

@dask.delayed
def delayed_double_fun(arr):
    # same stand-in, but wrapped in dask.delayed
    time.sleep(5)
    return arr * 2

# create a dataset
ds = xr.Dataset(
    {
        "A": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
        "B": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
        "C": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
        "D": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
        "E": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
        "F": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
        "G": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
        "H": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
        "I": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
        "J": (('y', 'x'), da.ones(shape=(10, 10), chunks=(10, 10))),
    }
)

st = time.time()
ds_double = ds.map(double_fun)
et = time.time()
elapsed_time = et - st
print('Map() total time:', elapsed_time, 'seconds')

st = time.time()

results = []
for var in ds.data_vars:
    results.append(delayed_double_fun(ds[var]))

ds_double = xr.merge(dask.compute(results)[0])

et = time.time()
elapsed_time = et - st
print('Dask.compute total time:', elapsed_time, 'seconds')

client.close()

As expected, I’m getting 50 s with Dataset.map(), but 5.6 s with the dask.delayed version. The real scenario is spatial averaging code, which seems particularly slow when I have forty-something small variables. I was just hoping for some way to process these variables in parallel. I did try the dask.delayed option, but it seems to blow out the memory or drop workers, so the Dask client becomes unstable. I’m not sure why the memory footprint is so large; it seems like it’s trying to pre-load all of the data assigned to a worker (i.e. multiple variables per worker at once, rather than loading them serially within the worker).