I hadn’t done that before with Dask so I tried it out:
import xarray as xr
import dask
import matplotlib.pyplot as plt

ds = xr.tutorial.open_dataset('air_temperature').load()

@dask.delayed
def plot(ds, time):
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])

tasks = [plot(ds, time) for time in ds['time'].values]

# just do the first few
dask.compute(tasks[:5], scheduler='processes', num_workers=4)
which produced the PNGs just fine (and quickly!), but I got extra colorbars on some:
Indeed @ahuang11, that worked great! So this is a nice little demo:
import xarray as xr
import dask
import matplotlib.pyplot as plt

ds = xr.tutorial.open_dataset('air_temperature').load()

@dask.delayed
def plot(ds, time):
    # start a fresh figure in each task so colorbars don't accumulate
    plt.figure()
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])

tasks = [plot(ds, time) for time in ds['time'].values]

# just do the first few
dask.compute(tasks[:5], scheduler='processes', num_workers=4)
My theory is that it might be more efficient to select the data you need before building the delayed call, so the entire dataset's contents don't have to be shipped to each worker process, but I guess it doesn't matter much if the data is lazily loaded.
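For example, here is a minimal sketch of what I mean, selecting the single timestep before building the delayed call so each task only receives one 2-D slice (the plot_slice name is just for illustration):

import xarray as xr
import dask
import matplotlib.pyplot as plt

ds = xr.tutorial.open_dataset('air_temperature').load()

@dask.delayed
def plot_slice(da2d, time):
    # da2d is already a single (lat, lon) slice, so only that slice
    # gets serialized and shipped to the worker process
    plt.figure()
    da2d.plot()
    plt.savefig(str(time)[:16])
    plt.close()

tasks = [plot_slice(ds['air'].sel(time=time), time) for time in ds['time'].values]

# just do the first few
dask.compute(tasks[:5], scheduler='processes', num_workers=4)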
def save_image(block):
    import cartopy.crs as ccrs
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    if sum(block.shape) > 0:
        # workaround 1:
        # xarray passes a zero-shaped array to infer what this function returns.
        # we can't run plot, so avoid doing that
        f = plt.figure()
        ax = f.subplots(1, 1, subplot_kw={"projection": ccrs.PlateCarree()}, squeeze=True)
        # xarray plotting goodness is available here!
        block.plot(ax=ax, robust=True, vmin=5, vmax=28, cmap=mpl.cm.Spectral_r, cbar_kwargs={"extend": "both"})
        # on pangeo.io, this will need some tweaking to work with gcsfs.
        # haven't tried that. On cheyenne, it works beautifully.
        f.savefig(f"images/aqua/{block.time.values[0]}.png", dpi=180)
        plt.close(f)

    # workaround 2:
    # map_blocks expects to receive an xarray thing back.
    # Just send back one value. If we send back "block", that's like computing the whole dataset!
    return block["time"]

# I want to animate in time, so chunk so that there is 1 block per timestep.
tasks = merged.sst.chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(save_image)
tasks.compute()
Re: the delayed solution: I have generally found that you shouldn't pass dask collections to delayed functions; I think it computes the whole thing and sends it to the function. For example, see:
import xarray as xr
import dask
import matplotlib.pyplot as plt

@dask.delayed
def plot(ds, time):
    import dask
    if not dask.base.is_dask_collection(ds):
        raise ValueError
    plt.figure()
    ds.sel(time=time)['air'].plot()
    plt.savefig(str(time)[:16])
    plt.close()

ds = xr.tutorial.open_dataset('air_temperature').isel(time=slice(0, 10))
tasks = [plot(ds, time) for time in ds['time'].values]
dask.compute(tasks, scheduler='processes', num_workers=4)
This raises the ValueError, so the delayed approach will only work if your dataset is small enough to be computed and shipped to each task. Note that dask.array.Array.to_delayed() exists; this use case would be better served by implementing a Dataset.to_delayed().
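For the record, here is a rough sketch of what the to_delayed() route can look like today (the plot_chunk helper is just illustrative); note that each delayed object evaluates to a bare numpy chunk, so you lose the xarray labels, which is exactly why a Dataset.to_delayed() would be nicer:

import xarray as xr
import dask
import matplotlib.pyplot as plt

ds = xr.tutorial.open_dataset('air_temperature').isel(time=slice(0, 10))
air = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1})

@dask.delayed
def plot_chunk(arr, time):
    # arr is a plain numpy array holding one (1, lat, lon) chunk
    plt.figure()
    plt.imshow(arr[0], origin="upper")
    plt.colorbar()
    plt.savefig(str(time)[:16])
    plt.close()

# one delayed object per chunk, laid out in (time, lat, lon) block order
chunks = air.data.to_delayed().ravel()
tasks = [plot_chunk(chunk, time) for chunk, time in zip(chunks, air['time'].values)]
dask.compute(tasks, scheduler='processes', num_workers=4)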
About: "Re: the delayed solution; I have generally found that you shouldn't pass dask collections to delayed functions, I think it computes the whole thing and sends it to the function."
map_blocks seems to behave the same way:
import xarray as xr
import dask
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.use("agg")

def plot(ds):
    time = ds['time']
    if not dask.base.is_dask_collection(ds):
        raise ValueError
    if sum(ds.shape) > 0:
        plt.figure()
        ax = plt.axes()
        ds.plot(ax=ax)
        plt.savefig(str(time.values[0])[:16])
        plt.close()
    return time

ds = xr.tutorial.open_dataset('air_temperature').isel(time=slice(0, 1000))
tasks = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot)
tasks.compute(num_workers=4, scheduler='processes')
import xarray as xr
import dask
import matplotlib.pyplot as plt

def plot(ds):
    time = ds['time']
    if sum(ds.shape) > 0:
        plt.figure()
        ax = plt.axes()
        ds.plot(ax=ax)
        plt.savefig(str(time.values[0])[:16])
        plt.close()
    return time

ds = xr.tutorial.open_dataset('air_temperature').isel(time=slice(0, 1000))
tasks = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1}).map_blocks(plot)
tasks.compute(num_workers=4, scheduler='processes')
It also seems to take a teeny bit longer to run (50 seconds vs. 48)... er, maybe not. I reran the delayed version and it's now 52 seconds. That might just be my computer overheating from running these tests, but the memory usage is consistent.
Using multiple threads doesn't seem to speed things up at all (presumably because the matplotlib work is mostly GIL-bound Python that the threaded scheduler can't parallelize).
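For reference, that is just the same compute call with the threaded scheduler swapped in:

tasks.compute(num_workers=4, scheduler='threads')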
Thanks for your comments and replies. I have been working on something similar: I have a large dataset for which I would like to create a plot (apply a function) at every timestamp. I have tried the same approach, and when the number of plots/tasks is small it works perfectly. However, when the number of plots/tasks gets large enough (10000 timestamps), a RuntimeError arises.
Something like:
RuntimeError: dictionary changed size during iteration
Any ideas of what I could be doing wrong? I've also tried the dask.delayed approach, but I haven't had any luck with it.
A note to future folks who find this thread: I was running into issues with one of my datasets where my map_blocks code was really slow, and after a lot of debugging I wasn't seeing obvious improvements over running in serial. But then I applied the same code to a dataset where (1) each chunk was smaller (the first dataset had chunk sizes of ~100 MB, the second ~2 MB), (2) each chunk had a size of 1 in the time dimension (as opposed to a size of 2 for the first dataset), and (3) I was simply plotting much less data (<1 MB of data in each figure vs. ~100 MB of data in each figure). If your code is unexpectedly slow, consider exploring these factors in greater detail.
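If it helps, here is a quick generic way to check those factors before calling map_blocks (shown on the tutorial dataset rather than the datasets described above):

import math
import xarray as xr

ds = xr.tutorial.open_dataset('air_temperature')
da = ds['air'].chunk({"time": 1, "lat": -1, "lon": -1})

nblocks = math.prod(da.data.numblocks)            # total number of blocks
print("chunk layout:", da.chunks)                 # chunk sizes per dimension
print("number of blocks:", nblocks)
print("approx MB per block:", da.nbytes / nblocks / 1e6)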