Using Dask to parallelize plotting

@ahuang11 shared an interesting use case in the PyViz Gitter: using Dask to parallelize the generation of a bunch of plots.

I hadn’t done that before with Dask so I tried it out:

import xarray as xr
import dask
import matplotlib.pyplot as plt

ds = xr.tutorial.open_dataset('air_temperature').load()

@dask.delayed
def plot(ds, time):
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])

tasks = [plot(ds, time) for time in ds['time'].values]

# just do first few 
dask.compute(tasks[:5], scheduler='processes', num_workers=4)

which produced the pngs just fine (and quickly!), but I got extra colorbars on some of them.

Here is the reproducible Jupyter Notebook

Can someone help me figure out what’s going on here and how to fix it?


You likely need to call plt.figure() inside the function, which will force Matplotlib to create a new figure for each task.


Indeed @ahuang11 , that worked great! So this is a nice little demo:

import xarray as xr
import dask
import matplotlib.pyplot as plt

ds = xr.tutorial.open_dataset('air_temperature').load()

@dask.delayed
def plot(ds, time):
    plt.figure()
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])

tasks = [plot(ds, time) for time in ds['time'].values]

# just do first few 
dask.compute(tasks[:5], scheduler='processes', num_workers=4)


My theory is that it might be more efficient if you select from the dataset beforehand, so the entire dataset’s contents don’t have to be passed to the separate processes, but I guess it doesn’t matter if it’s lazily loaded.
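That theory can be sketched with synthetic data: selecting the single timestep before wrapping the call in dask.delayed means each task only has to serialize one small slice instead of the whole dataset. (The array sizes and the `summarize` helper below are made up for illustration; the original example plots rather than computes a mean.)

```python
import numpy as np
import xarray as xr
import dask

# small synthetic stand-in for the tutorial dataset
ds = xr.Dataset(
    {"air": (("time", "lat", "lon"), np.random.rand(10, 4, 5))},
    coords={"time": np.arange(10)},
)

@dask.delayed
def summarize(da_slice):
    # each task receives only one timestep's worth of data
    return float(da_slice.mean())

# select BEFORE building the delayed call, so each task is sent
# a single (4, 5) slice instead of the full dataset
tasks = [summarize(ds["air"].sel(time=t)) for t in ds["time"].values]
results = dask.compute(tasks, scheduler="threads")[0]
print(len(results))  # 10 per-timestep means
```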

Use xr.map_blocks instead to avoid sending too much data to the workers. This application is why I worked on it :grinning_face_with_smiling_eyes:

See example here: Using dask to save frames in parallel · Issue #6 · jbusecke/xmovie · GitHub (replicated below)

def save_image(block):
    import cartopy.crs as ccrs
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    if sum(block.shape) > 0:
        # workaround 1:
        # xarray passes a zero shaped array to infer what this function returns. 
        # we can't run plot, so avoid doing that
        f = plt.figure()
        ax = f.subplots(1, 1, subplot_kw={"projection": ccrs.PlateCarree()}, squeeze=True)

        # xarray plotting goodness is available here!
        block.plot(ax=ax, robust=True, vmin=5, vmax=28, cmap=mpl.cm.Spectral_r, cbar_kwargs={"extend": "both"})
   
        # on pangeo.io, this will need some tweaking to work with gcsfs.
        # haven't tried that. On cheyenne, it works beautifully.
        f.savefig(f"images/aqua/{block.time.values[0]}.png", dpi=180)
        plt.close(f)

    # workaround 2:
    # map_blocks expects to receive an xarray thing back.
    # Just send back one value. If we send back "block" that's like computing the whole dataset!
    return block["time"]
    

# I want to animate in time, so chunk so that there is 1 block per timestep.
tasks = merged.sst.chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(save_image)
tasks.compute()

I have had great success using Dask to parallel plot snapshots for daily outputs of a model run.
Then plugged it together for a little movie.

Snow Movie

Didn’t know about the @dask.delayed decorator. Will try that next.

Thank you for sharing!


Thanks Joachim for contributing this! Welcome to the Pangeo discourse!

I tested

import xarray as xr
import dask
import matplotlib.pyplot as plt

@dask.delayed
def plot(ds, time):
    plt.figure()
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])
    plt.close()

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 1000))
tasks = [plot(ds, time) for time in ds['time'].values]
dask.compute(tasks, scheduler='processes', num_workers=4)

vs your method

import xarray as xr
import dask
import matplotlib.pyplot as plt

def plot(ds):
    time = ds['time']
    if sum(ds.shape) > 0:
        plt.figure()
        ds.plot()
        plt.savefig(str(time.values[0])[:16])
        plt.close()
    return time

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 1000))
tasks = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot)
tasks.compute(num_workers=4, scheduler='processes')

They seem to use more or less the same amount of memory when using scheduler='processes', but the dask.delayed method was 20 seconds faster.

Using threads, dask.delayed creates multiple colorbars, while map_blocks errors out with ValueError: Given element not contained in the stack.


@ahuang11 , I tried your .map_blocks code above also, and in addition to the issues you raise, I found the frames themselves have issues as well.

@dcherian , any thoughts on how to fix/work around?

If you use explicit figure and axes handles, it works fine with threads; otherwise you’ll get a mess:

import dask
import matplotlib.pyplot as plt

import xarray as xr


def plot(ds):
    time = ds["time"]
    f, ax = plt.subplots(1,1)
    ds.plot(ax=ax)
    f.savefig(str(time.values[0])[:16])
    plt.close(f)
    return time


ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(0, 10))
tasks = ds["air"].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot, template=xr.ones_like(ds.time).chunk({"time": 1}))
tasks.compute(num_workers=4, scheduler="threads")

and I don’t see any issues with the output either with threads or processes.

I’m using


INSTALLED VERSIONS
------------------
commit: fe036ae443ecc202a04877b67526133a48963b43
python: 3.8.6 | packaged by conda-forge | (default, Jan 25 2021, 23:21:18) 
[GCC 9.3.0]
python-bits: 64
OS: Linux
OS-release: 5.8.0-40-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.6
libnetcdf: 4.7.4

xarray: 0.16.3.dev119+g8a3912c72.d20210202
pandas: 1.2.1
numpy: 1.20.0
scipy: 1.5.3
netCDF4: 1.5.5.1
pydap: installed
h5netcdf: 0.8.1
h5py: 3.1.0
Nio: None
zarr: 2.6.1
cftime: 1.4.1
nc_time_axis: 1.2.0
PseudoNetCDF: None
rasterio: 1.2.0
cfgrib: None
iris: 2.4.0
bottleneck: 1.3.2
dask: 2021.01.1
distributed: 2021.01.1
matplotlib: 3.3.4
cartopy: 0.18.0
seaborn: 0.11.1
numbagg: None
pint: 0.16.1
setuptools: 49.6.0.post20210108
pip: 21.0
conda: 4.9.2
pytest: 6.2.2
IPython: 7.20.0
sphinx: 3.4.3

Re: the delayed solution; I have generally found that you shouldn’t pass dask collections to delayed functions; I think it computes the whole thing and sends it to the function. For example, see:

import xarray as xr
import dask
import matplotlib.pyplot as plt

@dask.delayed
def plot(ds, time):
    
    import dask
    if not dask.base.is_dask_collection(ds):
        raise ValueError
    plt.figure()
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])
    plt.close()

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 10))
tasks = [plot(ds, time) for time in ds['time'].values]
dask.compute(tasks, scheduler='processes', num_workers=4)

This raises the ValueError, so the pattern only works if your dataset is small enough to ship to each task. Note that dask.array.Array.to_delayed() exists; this use case would be better served by implementing a Dataset.to_delayed().
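For reference, here is a minimal sketch of the Array.to_delayed() route mentioned above, using a toy array and a hypothetical block_sum helper: each chunk becomes its own Delayed object, so nothing is computed up front and each task receives only its own block.

```python
import numpy as np
import dask
import dask.array as da

# toy array: three row-chunks of shape (1, 4)
x = da.from_array(np.arange(12).reshape(3, 4), chunks=(1, 4))

# one Delayed object per chunk; no computation happens here
blocks = x.to_delayed().ravel()

@dask.delayed
def block_sum(block):
    # each task sees only its own (1, 4) chunk as a numpy array
    return int(block.sum())

row_sums = dask.compute([block_sum(b) for b in blocks])[0]
print(row_sums)  # [6, 22, 38]
```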


Ah, it needs some tweaks to work with distributed.

I needed

import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.use("agg")

in my plot function.
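Put together, a plot function that sets the backend inside the task might look like this sketch (toy data and hypothetical frame file names; the explicit figure handle follows the earlier advice about threads):

```python
import numpy as np
import dask

@dask.delayed
def plot(arr, idx):
    # select the non-interactive backend inside the task, so every
    # worker (including distributed ones) picks it up
    import matplotlib as mpl
    mpl.use("agg")
    import matplotlib.pyplot as plt

    # explicit figure/axes handles keep tasks from clobbering each other
    fig, ax = plt.subplots()
    ax.imshow(arr)
    fig.savefig(f"frame_{idx}.png")
    plt.close(fig)
    return idx

tasks = [plot(np.random.rand(5, 5), i) for i in range(3)]
dask.compute(tasks, scheduler="threads")
```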


Regarding:

“Re: the delayed solution; I have generally found that you shouldn’t pass dask collections to delayed functions; I think it computes the whole thing and sends it to the function.”

map_blocks seems to be the same:

import xarray as xr
import dask
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.use("agg")

def plot(ds):
    time = ds['time']
    
    if not dask.base.is_dask_collection(ds):
        raise ValueError

    if sum(ds.shape) > 0:
        plt.figure()
        ax = plt.axes()
        ds.plot(ax=ax)
        plt.savefig(str(time.values[0])[:16])
        plt.close()
    return time

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 1000))
tasks = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot)
tasks.compute(num_workers=4, scheduler='processes')

If I measured it correctly using mprof, it seems map_blocks actually uses more memory:
Delayed:

import xarray as xr
import dask
import matplotlib.pyplot as plt

@dask.delayed
def plot(ds, time):
    plt.figure()
    ax = plt.axes()
    ds.sel(time=time)['air'].plot(ax=ax)
    plt.savefig(str(time)[:16])
    plt.close()


if __name__ == "__main__":
    ds = xr.tutorial.open_dataset('air_temperature').isel(
        time=slice(0, 1000))
    tasks = [plot(ds, time) for time in ds['time'].values]
    dask.compute(tasks, scheduler='processes', num_workers=4)

Map Blocks:

import xarray as xr
import dask
import matplotlib.pyplot as plt

def plot(ds):
    time = ds['time']
    if sum(ds.shape) > 0:
        plt.figure()
        ax = plt.axes()
        ds.plot(ax=ax)
        plt.savefig(str(time.values[0])[:16])
        plt.close()
    return time

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 1000))
tasks = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot)
tasks.compute(num_workers=4, scheduler='processes')

Also, it seems to take a teeny bit longer to run as well (50 seconds vs 48)... er, maybe not: I reran the delayed version and now it’s 52 seconds. It might just be my computer overheating from running these tests, but the memory usage is consistent.

Using multiple threads doesn’t seem to speed things up at all.