Using Dask to parallelize plotting

@ahuang11 mentioned an interesting use case in the pyviz Gitter: using Dask to parallelize the generation of a bunch of plots.

I hadn’t done that before with Dask, so I tried it out:

import xarray as xr
import dask
import matplotlib.pyplot as plt

ds = xr.tutorial.open_dataset('air_temperature').load()

@dask.delayed
def plot(ds, time):
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])

tasks = [plot(ds, time) for time in ds['time'].values]

# just do the first few
dask.compute(tasks[:5], scheduler='processes', num_workers=4)

which produced the PNGs just fine (and quickly!), but some of them got extra colorbars:

Here is the reproducible Jupyter Notebook

Can someone help me figure out what’s going on here and how to fix it?


Likely you need to call plt.figure() in the function, which will force matplotlib to create a new figure for each plot.


Indeed @ahuang11, that worked great! So this is a nice little demo:

import xarray as xr
import dask
import matplotlib.pyplot as plt

ds = xr.tutorial.open_dataset('air_temperature').load()

@dask.delayed
def plot(ds, time):
    plt.figure()
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])

tasks = [plot(ds, time) for time in ds['time'].values]

# just do the first few
dask.compute(tasks[:5], scheduler='processes', num_workers=4)


My theory is that it might be more efficient if you select from the dataset beforehand, so the entire dataset’s contents don’t have to be passed to each separate process; but I guess it doesn’t matter if it’s lazily loaded.
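For example, a minimal sketch of that idea, selecting each timestep eagerly so only a single 2D slice is handed to each task (a small synthetic dataset stands in for the tutorial data here, so nothing needs downloading; explicit figure handles are used so it also behaves with the threaded scheduler):

```python
import os

import numpy as np
import pandas as pd
import xarray as xr
import dask
import matplotlib
matplotlib.use("Agg")  # headless backend, no display needed
import matplotlib.pyplot as plt

# small synthetic stand-in for the air_temperature tutorial dataset
times = pd.date_range("2013-01-01", periods=4, freq="6h")
ds = xr.Dataset(
    {"air": (("time", "lat", "lon"), np.random.rand(4, 5, 6))},
    coords={"time": times},
)

@dask.delayed
def plot(da_slice, label):
    # da_slice is a single, pre-selected timestep: only this 2D slice
    # goes to the task, not the whole dataset
    fig, ax = plt.subplots()
    da_slice.plot(ax=ax)
    fig.savefig(f"{label}.png")
    plt.close(fig)
    return label

# select eagerly, while building the tasks
tasks = [plot(ds["air"].sel(time=t), str(t)[:13]) for t in ds["time"].values]
labels = dask.compute(tasks, scheduler="threads", num_workers=2)[0]  # or scheduler="processes"
all_saved = all(os.path.exists(f"{lab}.png") for lab in labels)
```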

Use xr.map_blocks instead to avoid sending too much data to the workers. This application is why I worked on it :grinning_face_with_smiling_eyes:

See example here: Using dask to save frames in parallel · Issue #6 · jbusecke/xmovie · GitHub (replicated below)

def save_image(block):
    import cartopy.crs as ccrs
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    if sum(block.shape) > 0:
        # workaround 1:
        # xarray passes a zero-shaped array to infer what this function returns.
        # we can't run plot, so avoid doing that
        f = plt.figure()
        ax = f.subplots(1, 1, subplot_kw={"projection": ccrs.PlateCarree()}, squeeze=True)

        # xarray plotting goodness is available here!
        block.plot(ax=ax, robust=True, vmin=5, vmax=28, cmap=mpl.cm.Spectral_r, cbar_kwargs={"extend": "both"})

        # on pangeo.io, this will need some tweaking to work with gcsfs.
        # haven't tried that. On cheyenne, it works beautifully.
        f.savefig(f"images/aqua/{block.time.values[0]}.png", dpi=180)
        plt.close(f)

    # workaround 2:
    # map_blocks expects to receive an xarray object back.
    # Just send back one small value. If we send back "block", that's like
    # computing the whole dataset!
    return block["time"]


# I want to animate in time, so chunk so that there is 1 block per timestep.
tasks = merged.sst.chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(save_image)
tasks.compute()

I have had great success using Dask to plot snapshots in parallel for the daily outputs of a model run, then stitched them together into a little movie.

Snow Movie

Didn’t know about the dask.delayed decorator. Will try that next.

Thank you for sharing!


Thanks Joachim for contributing this! Welcome to the Pangeo discourse!

I tested

import xarray as xr
import dask
import matplotlib.pyplot as plt

@dask.delayed
def plot(ds, time):
    plt.figure()
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])
    plt.close()

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 1000))
tasks = [plot(ds, time) for time in ds['time'].values]
dask.compute(tasks, scheduler='processes', num_workers=4)

vs your method

import xarray as xr
import dask
import matplotlib.pyplot as plt

def plot(ds):
    time = ds['time']
    if sum(ds.shape) > 0:
        plt.figure()
        ds.plot()
        plt.savefig(str(time.values[0])[:16])
        plt.close()
    return time

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 1000))
tasks = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot)
tasks.compute(num_workers=4, scheduler='processes')

They seem to use more or less the same amount of memory when using scheduler='processes', but the dask.delayed method was 20 seconds faster.

Using threads, dask.delayed creates multiple colorbars, while map_blocks errors out with ValueError: Given element not contained in the stack.


@ahuang11, I tried your .map_blocks code above as well, and in addition to the issues you raise, I found problems with the frames too:

@dcherian , any thoughts on how to fix/work around?

If you use explicit figure and axes handles, it works fine with threads; otherwise you’ll get a mess.

import dask
import matplotlib.pyplot as plt

import xarray as xr


def plot(ds):
    time = ds["time"]
    if sum(ds.shape) > 0:
        f, ax = plt.subplots(1,1)
        ds.plot(ax=ax)
        f.savefig(str(time.values[0])[:16])
        plt.close(f)
    return time


ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(0, 10))
tasks = ds["air"].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot)
tasks.compute(num_workers=4, scheduler="threads")

and I don’t see any issues with the output either with threads or processes.

I’m using


INSTALLED VERSIONS
------------------
commit: fe036ae443ecc202a04877b67526133a48963b43
python: 3.8.6 | packaged by conda-forge | (default, Jan 25 2021, 23:21:18) 
[GCC 9.3.0]
python-bits: 64
OS: Linux
OS-release: 5.8.0-40-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.6
libnetcdf: 4.7.4

xarray: 0.16.3.dev119+g8a3912c72.d20210202
pandas: 1.2.1
numpy: 1.20.0
scipy: 1.5.3
netCDF4: 1.5.5.1
pydap: installed
h5netcdf: 0.8.1
h5py: 3.1.0
Nio: None
zarr: 2.6.1
cftime: 1.4.1
nc_time_axis: 1.2.0
PseudoNetCDF: None
rasterio: 1.2.0
cfgrib: None
iris: 2.4.0
bottleneck: 1.3.2
dask: 2021.01.1
distributed: 2021.01.1
matplotlib: 3.3.4
cartopy: 0.18.0
seaborn: 0.11.1
numbagg: None
pint: 0.16.1
setuptools: 49.6.0.post20210108
pip: 21.0
conda: 4.9.2
pytest: 6.2.2
IPython: 7.20.0
sphinx: 3.4.3

Re: the delayed solution: I have generally found that you shouldn’t pass dask collections to delayed functions; I think it computes the whole thing and sends the result to the function. For example, see

import xarray as xr
import dask
import matplotlib.pyplot as plt

@dask.delayed
def plot(ds, time):
    import dask
    # check whether ds arrives already computed (no longer a dask collection)
    if not dask.base.is_dask_collection(ds):
        raise ValueError
    plt.figure()
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])
    plt.close()

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 10))
tasks = [plot(ds, time) for time in ds['time'].values]
dask.compute(tasks, scheduler='processes', num_workers=4)

This raises the ValueError, so this pattern will only work if your dataset is small enough. Note that dask.array.Array.to_delayed() exists; this use case would be better served by implementing a Dataset.to_delayed().
