Running into TypeError when plotting with xarray

Hi! I’m running into an unusual error when I try to plot the result of the following code:

# Download reanalysis data
# Opened lazily from a remote THREDDS/OPeNDAP URL; decode_times=False leaves the time axis undecoded.
woa_o2 = xr.open_mfdataset(['https://www.ncei.noaa.gov/thredds-ocean/dodsC/ncei/woa/oxygen/all/1.00/woa18_all_o06_01.nc'], decode_times=False)
# Target horizontal grid for regridding, built from the observational lat/lon coordinates.
obs_grid = xr.Dataset({"lat": (["lat"], woa_o2.lat.data), "lon": (["lon"], woa_o2.lon.data)})
# Target vertical levels for the later vertical interpolation (the observational depth coordinate).
obs_lev = woa_o2.depth.data

# Download CMIP data
cat_url = "https://storage.googleapis.com/cmip6/pangeo-cmip6-noQC.json"
col = intake.open_esm_datastore(cat_url)
# Restrict the catalog to table 'Omon', the historical experiment, model MPI-ESM1-2-LR,
# member r1i1p1f1, grid label 'gn', and the variables thetao and so.
cat = col.search(table_id='Omon', activity_id='CMIP', experiment_id='historical', variable_id=['thetao','so'], 
                 grid_label='gn', member_id='r1i1p1f1', source_id='MPI-ESM1-2-LR')
# Load the matches as a dict of datasets; combined_preprocessing is passed as the preprocess hook
# (presumably the cmip6_preprocessing helper -- confirm against the imports).
cmip6_compiled = cat.to_dataset_dict(zarr_kwargs={'consolidated':True, 'decode_times': True, 'use_cftime': True}, preprocess=combined_preprocessing)

# Ocean-basin regions used to build the mask below.
basins = regionmask.defined_regions.natural_earth.ocean_basins_50
# Accumulates the processed variables. It is created once, before the outer loop,
# so entries persist across all datasets in cmip6_compiled.
subset_temp = {}

# NOTE: the dict key `k` is not used inside the loop body.
for k, ds in cmip6_compiled.items():
         
    # Drop the member dimension by taking its first entry, then unify the chunks.
    if 'member_id' in ds.dims:
        ds = ds.isel(member_id=0).unify_chunks() 
        
    # NOTE: `value` is unused; each variable is accessed as ds[key] instead.
    for key, value in ds.items():
        mask = merged_mask(basins,ds[key])
        # Keep only cells where the mask equals 2 (presumably the basin of interest --
        # verify the index against what merged_mask returns).
        var_masked = ds[key].where(mask == 2)
        # Regrid onto the observational lat/lon grid with the 'nearest_s2d' method.
        regridder = xe.Regridder(var_masked, obs_grid, 'nearest_s2d')
        var_regridded = regridder(var_masked)
        # Select the year 2014, then the lon/lat box, and average over time.
        var_subset = var_regridded.sel(time=slice('2014-01-01','2014-12-31'))
        var_subset = var_subset.sel(lon=slice(-130,-110),lat=slice(30,45)).mean(dim='time').squeeze()

        # Save data into xarray Dataset
        var_name = key + "_subset"
        subset_temp[var_name] = var_subset
        # The Dataset is rebuilt from the accumulating dict on every iteration.
        subset = xr.Dataset(subset_temp)

    # Interpolate into common vertical grid using xgcm
    # The 'Z' axis is declared non-periodic, with its center positions on the 'lev' coordinate.
    grid = Grid(subset, coords={'Z': {'center': 'lev'}}, periodic=False)

    # NOTE: `varvalue` is unused; each variable is accessed as subset[varname] instead.
    for varname, varvalue in subset.items(): 
        # Linearly interpolate along 'Z' onto the observational levels.
        var_interp = grid.transform(subset[varname], 'Z', obs_lev, target_data=None, method='linear')
        # Take the level lev=10 for plotting.
        # NOTE: var_subset_plot is overwritten on every iteration, so after the loops
        # it only holds the result for the last variable processed.
        var_subset_plot = var_interp.sel(lev=10).squeeze()

var_subset can be plotted without problems, but when I try the same with var_subset_plot:

# Create an 8x4 figure with a single axes that uses a Mercator map projection.
plt.figure(figsize=(8,4))
ax = plt.axes(projection=ccrs.Mercator())
# Draw lon/lat on that axes; transform=ccrs.PlateCarree() is passed for the data coordinates.
# NOTE(review): if var_subset_plot is dask-backed, its values are only computed when
# plotting needs them, so an error from an earlier lazy step would surface at this call.
var_subset_plot.plot.pcolormesh(yincrease=True, x='lon', y='lat', levels=31, ax=ax, robust=True, add_colorbar=False, transform=ccrs.PlateCarree())

the following error appears:

TypeError                                 Traceback (most recent call last)
<ipython-input-18-ed50ece6f080> in <module>
      1 plt.figure(figsize=(8,4))
      2 ax = plt.axes(projection=ccrs.Mercator())
----> 3 var_subset_plot.plot.pcolormesh(yincrease=True, x='lon', y='lat', levels=31, ax=ax, robust=True, add_colorbar=False, transform=ccrs.PlateCarree())

/srv/conda/envs/notebook/lib/python3.8/site-packages/xarray/plot/plot.py in plotmethod(_PlotMethods_obj, x, y, figsize, size, aspect, ax, row, col, col_wrap, xincrease, yincrease, add_colorbar, add_labels, vmin, vmax, cmap, colors, center, robust, extend, levels, infer_intervals, subplot_kws, cbar_ax, cbar_kwargs, xscale, yscale, xticks, yticks, xlim, ylim, norm, **kwargs)
   1306         for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]:
   1307             del allargs[arg]
-> 1308         return newplotfunc(**allargs)
   1309 
   1310     # Add to class _PlotMethods

/srv/conda/envs/notebook/lib/python3.8/site-packages/xarray/plot/plot.py in newplotfunc(darray, x, y, figsize, size, aspect, ax, row, col, col_wrap, xincrease, yincrease, add_colorbar, add_labels, vmin, vmax, cmap, center, robust, extend, levels, infer_intervals, colors, subplot_kws, cbar_ax, cbar_kwargs, xscale, yscale, xticks, yticks, xlim, ylim, norm, **kwargs)
   1170 
   1171         # Pass the data as a masked ndarray too
-> 1172         zval = darray.to_masked_array(copy=False)
   1173 
   1174         # Replace pd.Intervals if contained in xval or yval.

/srv/conda/envs/notebook/lib/python3.8/site-packages/xarray/core/dataarray.py in to_masked_array(self, copy)
   2785             Masked where invalid values (nan or inf) occur.
   2786         """
-> 2787         values = self.to_numpy()  # only compute lazy arrays once
   2788         isnull = pd.isnull(values)
   2789         return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)

/srv/conda/envs/notebook/lib/python3.8/site-packages/xarray/core/dataarray.py in to_numpy(self)
    666         DataArray.data
    667         """
--> 668         return self.variable.to_numpy()
    669 
    670     def as_numpy(self: T_DataArray) -> T_DataArray:

/srv/conda/envs/notebook/lib/python3.8/site-packages/xarray/core/variable.py in to_numpy(self)
   1074         # TODO first attempt to call .to_numpy() once some libraries implement it
   1075         if isinstance(data, dask_array_type):
-> 1076             data = data.compute()
   1077         if isinstance(data, cupy_array_type):
   1078             data = data.get()

/srv/conda/envs/notebook/lib/python3.8/site-packages/dask/base.py in compute(self, **kwargs)
    277         dask.base.compute
    278         """
--> 279         (result,) = compute(self, traverse=False, **kwargs)
    280         return result
    281 

/srv/conda/envs/notebook/lib/python3.8/site-packages/dask/base.py in compute(*args, **kwargs)
    559         postcomputes.append(x.__dask_postcompute__())
    560 
--> 561     results = schedule(dsk, keys, **kwargs)
    562     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    563 

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2682                     should_rejoin = False
   2683             try:
-> 2684                 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
   2685             finally:
   2686                 for f in futures.values():

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
   1991             else:
   1992                 local_worker = None
-> 1993             return self.sync(
   1994                 self._gather,
   1995                 futures,

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    837             return future
    838         else:
--> 839             return sync(
    840                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    841             )

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    338     if error[0]:
    339         typ, exc, tb = error[0]
--> 340         raise exc.with_traceback(tb)
    341     else:
    342         return result[0]

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/utils.py in f()
    322             if callback_timeout is not None:
    323                 future = asyncio.wait_for(future, callback_timeout)
--> 324             result[0] = yield future
    325         except Exception as exc:
    326             error[0] = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.8/site-packages/tornado/gen.py in run(self)
    760 
    761                     try:
--> 762                         value = future.result()
    763                     except Exception:
    764                         exc_info = sys.exc_info()

/srv/conda/envs/notebook/lib/python3.8/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1856                             exc = CancelledError(key)
   1857                         else:
-> 1858                             raise exception.with_traceback(traceback)
   1859                         raise exc
   1860                     if errors == "skip":

/srv/conda/envs/notebook/lib/python3.8/site-packages/dask/optimization.py in __call__()
    961         if not len(args) == len(self.inkeys):
    962             raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 963         return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))
    964 
    965     def __reduce__(self):

/srv/conda/envs/notebook/lib/python3.8/site-packages/dask/core.py in get()
    149     for key in toposort(dsk):
    150         task = dsk[key]
--> 151         result = _execute_task(task, cache)
    152         cache[key] = result
    153     result = _execute_task(out, cache)

/srv/conda/envs/notebook/lib/python3.8/site-packages/dask/core.py in _execute_task()
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg

TypeError: interp_1d_linear() got an unexpected keyword argument 'bypass_checks'

I’ve tried looking through xarray and dask but I can’t seem to find where the error is pointing to specifically. I would appreciate any insights, thank you!

Package versions:

xarray='0.19.0'
dask='2021.1.1'
dask-gateway='0.9.0'
intake='0.6.0'
intake-esm='2021.1.15'
cmip6-preprocessing='0.5.0'
xesmf='0.5.1'
xgcm='0.5.3.dev28+g95f4f33'
regionmask='0.8.1.dev1+g0872147'
Cartopy='0.19.0.post1'
matplotlib='3.3.4' 
1 Like

This looks like an xgcm error: xgcm/transform.py at master · xgcm/xgcm · GitHub

cc @jbusecke

1 Like

Oh wow, sorry about that! Because I could see the subset, I thought the transformation was okay and that the error was with xarray plot or dask :sweat_smile: :woman_facepalming:t4: Thanks for the tip, Deepak! I’ll double-check how I use xgcm transform.

Note: var_subset_plot plots fine if I neither update xarray (i.e. stay on xarray=0.16.2) nor use a dask cluster. So maybe xgcm/transform just needs some updating, similar to cmip6_preprocessing (PR 174)? I’ll try testing xgcm/transform/interp_1d_linear() with xarray=0.19.0.

Just an update: to avoid the error, first pin xgcm to a released version in the notebook:

# The leading "!" runs pip as a shell command from the notebook (IPython syntax, not plain Python).
!pip install xgcm=='0.5.2' --upgrade
# Import after installing. If xgcm was already imported in this session, a kernel restart
# may be needed for the new version to take effect -- verify with xgcm.__version__.
import xgcm
from xgcm import Grid

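Then include it in the pip-install worker plugin (PipInstallPlugin) when using dask, so that the workers get the same xarray and xgcm versions as the notebook: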

# Builds the plugin with the pinned package versions (same xgcm version as installed in the notebook).
# NOTE(review): PipPlugin is not defined in this snippet, and presumably the plugin still has
# to be registered with the dask client before the workers install anything -- confirm.
plugin = PipPlugin(['xarray==0.19.0', 'xgcm==0.5.2'])

Many thanks to @jsignell!

2 Likes