Using Dask to parallelize plotting

@ahuang11 shared an interesting use case in the PyViz Gitter: using Dask to parallelize the generation of a bunch of plots.

I hadn’t done that before with Dask so I tried it out:

import xarray as xr
import dask
import matplotlib.pyplot as plt

ds = xr.tutorial.open_dataset('air_temperature').load()

@dask.delayed
def plot(ds, time):
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])

tasks = [plot(ds, time) for time in ds['time'].values]

# just do first few 
dask.compute(tasks[:5], scheduler='processes', num_workers=4)

which produced the PNGs just fine (and quickly!), but I got extra colorbars on some:

Here is the reproducible Jupyter Notebook

Can someone help me figure out what's going on here and how to fix it?


Likely you need to call plt.figure() in the function, which will force Matplotlib to create a new figure each time.


Indeed @ahuang11, that worked great! So this is a nice little demo:

import xarray as xr
import dask
import matplotlib.pyplot as plt

ds = xr.tutorial.open_dataset('air_temperature').load()

@dask.delayed
def plot(ds, time):
    plt.figure()
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])

tasks = [plot(ds, time) for time in ds['time'].values]

# just do first few 
dask.compute(tasks[:5], scheduler='processes', num_workers=4)


My theory is that it might be more efficient if you select the data beforehand, so the entire dataset's contents don't have to be passed to the separate processes; but I guess it doesn't matter if it's lazily loaded.
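A minimal sketch of what I mean, adapting the demo above (untested, just to illustrate selecting the time step before building each task):

import xarray as xr
import dask
import matplotlib.pyplot as plt

ds = xr.tutorial.open_dataset('air_temperature').load()

@dask.delayed
def plot(da_slice, time):
    # da_slice is already a single 2-D time slice, so only that slice
    # gets shipped to the worker process
    plt.figure()
    da_slice.plot()
    plt.savefig(str(time)[:16])
    plt.close()

tasks = [plot(ds['air'].sel(time=time), time) for time in ds['time'].values]
dask.compute(tasks[:5], scheduler='processes', num_workers=4)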

Use xr.map_blocks instead to avoid sending too much data to the workers. This application is why I worked on it :grinning_face_with_smiling_eyes:

See example here: Using dask to save frames in parallel · Issue #6 · jbusecke/xmovie · GitHub (replicated below)

def save_image(block):
    import cartopy.crs as ccrs
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    if sum(block.shape) > 0:
        # workaround 1:
        # xarray passes a zero shaped array to infer what this function returns. 
        # we can't run plot, so avoid doing that
        f = plt.figure()
        ax = f.subplots(1, 1, subplot_kw={"projection": ccrs.PlateCarree()}, squeeze=True)

        # xarray plotting goodness is available here!
        block.plot(ax=ax, robust=True, vmin=5, vmax=28, cmap=mpl.cm.Spectral_r, cbar_kwargs={"extend": "both"})
   
        # on pangeo.io, this will need some tweaking to work with gcsfs.
        # haven't tried that. On cheyenne, it works beautifully.
        f.savefig(f"images/aqua/{block.time.values[0]}.png", dpi=180)
        plt.close(f)

    # workaround 2:
    # map_blocks expects to receive an xarray thing back.
    # Just send back one value. If we send back "block" that's like computing the whole dataset!
    return block["time"]
    

# I want to animate in time, so chunk so that there is 1 block per timestep.
tasks = merged.sst.chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(save_image)
tasks.compute()

I have had great success using Dask to plot snapshots in parallel for the daily outputs of a model run, then stitched them together into a little movie.

Snow Movie

Didn’t know about the dask decorator. Will try that next.

Thank you for sharing!


Thanks Joachim for contributing this! Welcome to the Pangeo discourse!

I tested

import xarray as xr
import dask
import matplotlib.pyplot as plt

@dask.delayed
def plot(ds, time):
    plt.figure()
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])
    plt.close()

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 1000))
tasks = [plot(ds, time) for time in ds['time'].values]
dask.compute(tasks, scheduler='processes', num_workers=4)

vs your method

import xarray as xr
import dask
import matplotlib.pyplot as plt

def plot(ds):
    time = ds['time']
    if sum(ds.shape) > 0:
        plt.figure()
        ds.plot()
        plt.savefig(str(time.values[0])[:16])
        plt.close()
    return time

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 1000))
tasks = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot)
tasks.compute(num_workers=4, scheduler='processes')

They seem to use more or less the same amount of memory when using scheduler='processes', but the dask.delayed method was 20 seconds faster.

Using threads, dask.delayed creates multiple colorbars, while map_blocks errors out with ValueError: Given element not contained in the stack.


@ahuang11, I also tried your .map_blocks code above, and in addition to the issues you raise, I found the frames themselves have issues:

@dcherian, any thoughts on how to fix or work around this?

If you use explicit figure and axes handles, it works fine with threads; otherwise you'll get a mess:

import dask
import matplotlib.pyplot as plt

import xarray as xr


def plot(ds):
    time = ds["time"]
    f, ax = plt.subplots(1,1)
    ds.plot(ax=ax)
    f.savefig(str(time.values[0])[:16])
    plt.close(f)
    return time


ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(0, 10))
tasks = ds["air"].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot, template=xr.ones_like(da.time).chunk({"time": 1}))
tasks.compute(num_workers=4, scheduler="threads")

and I don’t see any issues with the output either with threads or processes.

I’m using


INSTALLED VERSIONS
------------------
commit: fe036ae443ecc202a04877b67526133a48963b43
python: 3.8.6 | packaged by conda-forge | (default, Jan 25 2021, 23:21:18) 
[GCC 9.3.0]
python-bits: 64
OS: Linux
OS-release: 5.8.0-40-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.6
libnetcdf: 4.7.4

xarray: 0.16.3.dev119+g8a3912c72.d20210202
pandas: 1.2.1
numpy: 1.20.0
scipy: 1.5.3
netCDF4: 1.5.5.1
pydap: installed
h5netcdf: 0.8.1
h5py: 3.1.0
Nio: None
zarr: 2.6.1
cftime: 1.4.1
nc_time_axis: 1.2.0
PseudoNetCDF: None
rasterio: 1.2.0
cfgrib: None
iris: 2.4.0
bottleneck: 1.3.2
dask: 2021.01.1
distributed: 2021.01.1
matplotlib: 3.3.4
cartopy: 0.18.0
seaborn: 0.11.1
numbagg: None
pint: 0.16.1
setuptools: 49.6.0.post20210108
pip: 21.0
conda: 4.9.2
pytest: 6.2.2
IPython: 7.20.0
sphinx: 3.4.3

Re: the delayed solution; I have generally found that you shouldn't pass dask collections to delayed functions: I think Dask computes the whole thing and sends it to the function. For example, see

import xarray as xr
import dask
import matplotlib.pyplot as plt

@dask.delayed
def plot(ds, time):
    
    import dask
    if not dask.base.is_dask_collection(ds):
        raise ValueError
    plt.figure()
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])
    plt.close()

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 10))
tasks = [plot(ds, time) for time in ds['time'].values]
dask.compute(tasks, scheduler='processes', num_workers=4)

This raises the ValueError, so it'll only work if your dataset is small enough. Note that dask.array.Array.to_delayed() exists; this use case would be better served by implementing Dataset.to_delayed().
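For illustration, here's a rough, untested sketch of what going through dask.array.Array.to_delayed() on the underlying dask array could look like; it assumes one block per time step after chunking with time=1, and carries the time labels separately since each block arrives as a plain numpy array:

import dask
import xarray as xr

ds = xr.tutorial.open_dataset('air_temperature').isel(time=slice(0, 10))
da = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1})

@dask.delayed
def plot_block(block, time):
    # block is a numpy array of shape (1, lat, lon); no xarray metadata here
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.pcolormesh(block[0])
    fig.savefig(str(time)[:16])
    plt.close(fig)

blocks = da.data.to_delayed().ravel()  # one Delayed object per chunk, in time order
tasks = [plot_block(b, t) for b, t in zip(blocks, da['time'].values)]
dask.compute(tasks, scheduler='processes', num_workers=4)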


Ah, it needs some tweaks to work with distributed.

I needed

import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.use("agg")

in my plot function.
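For concreteness, a minimal sketch of the earlier plot function with that tweak applied (putting mpl.use("agg") before the pyplot import inside the function, so each worker picks up the non-interactive backend):

def plot(ds):
    import matplotlib as mpl
    mpl.use("agg")  # non-interactive backend; set before pyplot is used
    import matplotlib.pyplot as plt

    time = ds["time"]
    f, ax = plt.subplots(1, 1)
    ds.plot(ax=ax)
    f.savefig(str(time.values[0])[:16])
    plt.close(f)
    return time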


About:

"""
Re: the delayed solution; I have generally found that you shouldn't pass dask collections to delayed functions: I think Dask computes the whole thing and sends it to the function.
"""

map_blocks seems to be the same:

import xarray as xr
import dask
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.use("agg")

def plot(ds):
    time = ds['time']
    
    if not dask.base.is_dask_collection(ds):
        raise ValueError

    if sum(ds.shape) > 0:
        plt.figure()
        ax = plt.axes()
        ds.plot(ax=ax)
        plt.savefig(str(time.values[0])[:16])
        plt.close()
    return time

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 1000))
tasks = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot)
tasks.compute(num_workers=4, scheduler='processes')

If I measured it correctly using mprof, it seems map_blocks actually uses more memory:
Delayed:

import xarray as xr
import dask
import matplotlib.pyplot as plt

@dask.delayed
def plot(ds, time):
    plt.figure()
    ax = plt.axes()
    ds.sel(time=time)['air'].plot(ax=ax)
    plt.savefig(str(time)[:16])
    plt.close()


if __name__ == "__main__":
    ds = xr.tutorial.open_dataset('air_temperature').isel(
        time=slice(0, 1000))
    tasks = [plot(ds, time) for time in ds['time'].values]
    dask.compute(tasks, scheduler='processes', num_workers=4)

Map Blocks:

import xarray as xr
import dask
import matplotlib.pyplot as plt

def plot(ds):
    time = ds['time']
    if sum(ds.shape) > 0:
        plt.figure()
        ax = plt.axes()
        ds.plot(ax=ax)
        plt.savefig(str(time.values[0])[:16])
        plt.close()
    return time

ds = xr.tutorial.open_dataset('air_temperature').isel(
    time=slice(0, 1000))
tasks = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot)
tasks.compute(num_workers=4, scheduler='processes')

It also seems to take a teeny bit longer to run (50 seconds vs. 48); er, maybe not: I reran the delayed version and now it's 52 seconds. That might just be my computer overheating from running these tests, but the memory usage is consistent.

Using multiple threads doesn't seem to speed things up at all.

Hi everyone,

Thanks for your comments and replies. I have been working on something similar: I have a large dataset for which I would like to create some plots (apply a function) at every timestamp. I have tried the same approach, and when the number of plots/tasks is small it works perfectly. However, if the number of plots/tasks is large enough (10,000 timestamps), a RuntimeError arises.
Something like:

RuntimeError: dictionary changed size during iteration

Any ideas what I could be doing wrong? I've tried the dask.delayed approach but haven't had any luck.

Maybe use scheduler='processes'; otherwise some race conditions may arise.

A note to future folks who find this thread: I was running into issues with one of my datasets where my map_blocks code was really slow, and after a lot of debugging I wasn't seeing obvious improvements over running in serial. But then I applied the same code to a dataset where (1) each chunk was smaller (the first dataset had chunk sizes of ~100 MB, the second ~2 MB), (2) each chunk had a size of 1 in the time dimension (as opposed to a size of 2 for the first dataset), and (3) I was plotting a lot less data (<1 MB of data in each figure vs. ~100 MB). If your code is unexpectedly slow, consider exploring these factors in greater detail.
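If you want to sanity-check those factors on your own dataset, a quick way (an illustrative sketch, not from the original post) is to inspect the chunk layout and the approximate per-time-step size before calling map_blocks:

import xarray as xr

ds = xr.tutorial.open_dataset('air_temperature')
da = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1})

print(da.chunks)                            # chunk lengths along each dimension
print(da.isel(time=0).nbytes / 1e6, "MB")   # approximate size of one time-step chunk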

Here’s my latest iteration:

  • uses matplotlib.use("Agg") to make multithreading work
  • selects time outside plot to prevent sending too much data
  • sets vmin/vmax to prevent the colorbar jumping around

I prefer this over map_blocks because map_blocks requires an xarray object to be returned.

import xarray as xr
import dask
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import imageio.v3 as iio

@dask.delayed
def plot(ds_sel, time):
    fig, ax = plt.subplots()
    ds_sel['air'].plot(ax=ax, vmin=220, vmax=340)
    path = f"test_{time}.png"
    fig.savefig(path)
    plt.close(fig)
    return path

ds = xr.tutorial.open_dataset('air_temperature').isel(time=slice(0, 100))
tasks = [plot(ds.sel(time=time), time) for time in ds['time'].values]
paths = dask.compute(tasks, scheduler="threads")[0]
images = [iio.imread(path) for path in paths]
iio.imwrite("test.gif", images, fps=60, loop=True)

(the resulting test.gif animation)


After tinkering a bit more, I discovered that I can stream the output directly to the GIF; it writes and animates 1000 images in 35 seconds.

from io import BytesIO
import xarray as xr
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import imageio.v3 as iio
from dask.distributed import Client


def process(ds_sel, time):
    fig, ax = plt.subplots()
    ds_sel["air"].plot(ax=ax, vmin=220, vmax=340)
    with BytesIO() as buf:
        fig.savefig(buf, format="png")
        plt.close(fig)
        image = iio.imread(buf.getvalue())
    return image


client = Client()
display(client)
ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(0, 1000))
futures = [
    client.submit(process, ds.sel(time=time), time) for time in ds["time"].values
]

with iio.imopen("test.gif", "w", extension=".gif") as f:
    for future in futures:
        image = future.result()
        f.write(image)

@ahuang11, I tried this out and it worked perfectly: 40 seconds with 10 workers, 20 seconds with 20 workers. Awesome!
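For anyone wondering how to control the worker count, it can be set when creating the Client; a minimal sketch (the numbers are just example values):

from dask.distributed import Client

# Local cluster with an explicit number of worker processes
client = Client(n_workers=20, threads_per_worker=1)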


Awesome! I'm writing a new package that encapsulates the animation part so users only have to pass in a list, an output path, and an optional callable.
