Question on DASK efficiency

Hello all,

I would like to ask the Dask community whether it is worth trying to optimize the computation time of a function applied to a single file. Concretely, I have the following temperature file, which I tried to open lazily with dask, using chunks (just as a trial):

import xarray as xr

chunks     = {'time':1,'j':30,'i':20}
MLTF       = xr.open_dataset('/home/REGIONS/',chunks=chunks)
MLT        = MLTF.MLT
lon, lat, time = MLTF.i, MLTF.j, MLTF.time

where MLT (the temperature variable) has the following characteristics:

<xarray.DataArray 'MLT' (time: 5413, j: 131, i: 101)>
dask.array<open_dataset-aa8def32f2e58a7766e4a76da077651fMLT, shape=(5413, 131, 101), dtype=float32, chunksize=(1, 30, 20), chunktype=numpy.ndarray>
Coordinates:
  * i        (i) int64 1179 1180 1181 1182 1183 ... 1275 1276 1277 1278 1279
  * j        (j) int64 569 570 571 572 573 574 575 ... 694 695 696 697 698 699
  * time     (time) datetime64[ns] 2001-02-01 2001-02-02 ... 2015-11-30
Attributes:
    standard_name:  MLT
    long_name:      User-Defined Mixed Layer Temperature
    units:          degC

I need to apply a function to it from the xmhw package, which is made for xarray objects (see xmhw_demo.ipynb in the coecms/xmhw repository on GitHub).

So I started by trying to use dask.delayed:

x          = delayed(threshold(MLT,climatologyPeriod=[2001,2015],tdim='time'))
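A side note on the line above: delayed(threshold(...)) calls threshold immediately and only wraps its result, so nothing is actually deferred. To defer a call, the function itself is wrapped, as in delayed(threshold)(MLT, ...). A minimal sketch of the difference, using a hypothetical stand-in function slow_square in place of threshold:

```python
from dask import delayed

calls = []  # records every time the function actually runs


def slow_square(x):
    """Hypothetical stand-in for an expensive function such as threshold()."""
    calls.append(x)
    return x * x


# Eager: slow_square(3) runs right here; delayed() only wraps the result 9
eager = delayed(slow_square(3))
print(len(calls))  # 1

# Lazy: delayed(slow_square) wraps the function; calling it only records a task
lazy = delayed(slow_square)(4)
print(len(calls))  # still 1, nothing has run yet

print(lazy.compute())  # the task runs now and returns 16
print(len(calls))      # 2
```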

However, this computation took a very long time for some reason (it actually never finished, and I had to stop it). So I would like to ask whether processing a 3D array like MLT (time: 5413, j: 131, i: 101) can be optimized in any way when applying any of the functions from this repository (coecms/xmhw on GitHub), or whether the file is simply too big and the function too complicated to run fast.

Thank you in advance for your time and help,
kind regards,



This is a very interesting issue. I believe you're suffering here from an excess of black box. The whole process, reading then computing, should be fast even with a single thread. Indeed, you only have about 71.6 million float32 elements, i.e. about 0.29 GB of data. Reading 0.29 GB should take less than a second, or a few seconds if the data are scattered across many files. Assuming that the threshold function does 20 operations per element (a wild assumption), that is 1.4 GFlop to perform. With current CPUs delivering tens of GFlop/s per core, this should be done in less than a second.
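The back-of-the-envelope numbers above can be checked in a few lines. The 20 operations per element is the same rough assumption as in the text, not a measured figure for threshold():

```python
import numpy as np

nt, ny, nx = 5413, 131, 101            # MLT shape (time, j, i)
n_elements = nt * ny * nx              # number of values in the array
n_bytes = n_elements * np.dtype("float32").itemsize

ops_per_element = 20                   # rough assumption, not measured
total_flop = n_elements * ops_per_element

print(f"{n_elements:,} elements")       # 71,619,403 elements
print(f"{n_bytes / 1e9:.2f} GB")        # 0.29 GB
print(f"{total_flop / 1e9:.1f} GFlop")  # 1.4 GFlop
```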

Having said that, I think it would be very useful to separate the I/O issue from the computation issue.

Therefore, I would first read the data and make sure they are in RAM, e.g. by printing the values, which forces them to be read. If this takes more than a few seconds, then you have a problem.
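Then I'd send this array to threshold(). But because you want to work along the time axis, I'd transpose the array to make the data contiguous along time, MLT_flipped = np.transpose(MLT, [1, 2, 0]), and call threshold(MLT_flipped). (On a NumPy array the transpose alone only returns a view with permuted strides; a copy, e.g. with np.ascontiguousarray, is what actually makes time contiguous in memory.)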
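A small NumPy sketch of the difference between a transposed view and a contiguous copy. The zero-filled array is a stand-in for MLT, and the time axis is shortened so the demo stays light; the spatial sizes are the real ones:

```python
import numpy as np

nt, ny, nx = 10, 131, 101  # short time axis, real spatial size
mlt = np.zeros((nt, ny, nx), dtype="float32")  # stand-in for MLT, time first

flipped_view = np.transpose(mlt, (1, 2, 0))        # only the strides change, no copy
flipped_copy = np.ascontiguousarray(flipped_view)  # data rearranged in memory

print(flipped_view.shape, flipped_view.flags["C_CONTIGUOUS"])  # (131, 101, 10) False
print(flipped_copy.shape, flipped_copy.flags["C_CONTIGUOUS"])  # (131, 101, 10) True

# Bytes between two consecutive time steps of the same grid cell
print(flipped_view.strides[-1], flipped_copy.strides[-1])      # 52924 4
```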

Black boxes are very powerful, but they can also prevent you from realizing that the operation you want to do is very simple and should be very fast.

Hope it helps,



Hello Guillaume,

Thank you very much for your detailed explanation. I am starting to understand a bit more now.

So I did exactly as you suggested:

MLTF       = xr.open_dataset('/home//REGIONS/')
MLT        = MLTF.MLT
print(...)  ## this took only a few seconds, as you said

Then I transposed MLT as you suggested, and got the following error when I applied the threshold function:

MLT_flipped = np.transpose(MLT, [1, 2, 0])
(131, 101, 5413)
# Applying the function #
----> 1 B=threshold(MLT_flipped,tdim='time')

~/.conda/envs/MHWENV/lib/python3.8/site-packages/xmhw/ in threshold(temp, tdim, climatologyPeriod, pctile, windowHalfWidth, smoothPercentile, smoothPercentileWidth, maxPadLength, coldSpells, tstep, anynans, skipna)
    126     # Returns an array stacked on all dimensions excluded time
    127     # Land cells are removed and new dimensions are (time,cell)
--> 128     ts = land_check(temp, tdim=tdim, anynans=anynans)
    130     # check if the calendar attribute is present in time dimension

~/.conda/envs/MHWENV/lib/python3.8/site-packages/xmhw/ in land_check(temp, tdim, anynans)
    511         if len(temp[d]) == 0:
    512             raise XmhwException(f"Dimension {d} has 0 lenght, exiting")
--> 513     ts = temp.stack(cell=(dims))
    514     # drop cells that have all/any nan values along time
    515     how = "all"

~/.conda/envs/MHWENV/lib/python3.8/site-packages/xarray/core/ in stack(self, dimensions, **dimensions_kwargs)
   1869         DataArray.unstack
   1870         """
-> 1871         ds = self._to_temp_dataset().stack(dimensions, **dimensions_kwargs)
   1872         return self._from_temp_dataset(ds)

~/.conda/envs/MHWENV/lib/python3.8/site-packages/xarray/core/ in _to_temp_dataset(self)
    422     def _to_temp_dataset(self) -> Dataset:
--> 423         return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False)
    425     def _from_temp_dataset(

~/.conda/envs/MHWENV/lib/python3.8/site-packages/xarray/core/ in _to_dataset_whole(self, name, shallow_copy)
    474         coord_names = set(self._coords)
--> 475         dataset = Dataset._construct_direct(variables, coord_names, indexes=indexes)
    476         return dataset

~/.conda/envs/MHWENV/lib/python3.8/site-packages/xarray/core/ in _construct_direct(cls, variables, coord_names, dims, attrs, indexes, encoding, file_obj)
    872         """
    873         if dims is None:
--> 874             dims = calculate_dimensions(variables)
    875         obj = object.__new__(cls)
    876         obj._variables = variables

~/.conda/envs/MHWENV/lib/python3.8/site-packages/xarray/core/ in calculate_dimensions(variables)
    204                 last_used[dim] = k
    205             elif dims[dim] != size:
--> 206                 raise ValueError(
    207                     "conflicting sizes for dimension %r: "
    208                     "length %s on %r and length %s on %r"

ValueError: conflicting sizes for dimension 'time': length 131 on <this-array> and length 5413 on 'time'

I guess this means the function wants the dimensions in a specific order? Also, from what I read, the threshold function itself already uses dask.delayed internally (see "Setting up dask" in the xmhw 0.8.2.dev2 documentation), which makes it even stranger that it takes so long to run.
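The error message shows that, after np.transpose, the axis still labelled 'time' has length 131: the data were reordered but the dimension names were not. Transposing by name with xarray's own DataArray.transpose moves data and labels together. A minimal sketch on a small synthetic array with the same dimension names (values are random placeholders):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Small synthetic stand-in for MLT with dimensions (time, j, i)
time = pd.date_range("2001-02-01", periods=10, freq="D")
mlt = xr.DataArray(
    np.random.default_rng(0).normal(size=(10, 4, 3)).astype("float32"),
    dims=("time", "j", "i"),
    coords={"time": time, "j": np.arange(569, 573), "i": np.arange(1179, 1182)},
    name="MLT",
)

# Transpose by dimension name: labels and data stay consistent
mlt_flipped = mlt.transpose("j", "i", "time")
print(mlt_flipped.dims, mlt_flipped.shape)  # ('j', 'i', 'time') (4, 3, 10)
```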

I do not know if this has to do with the very high resolution of the model data I am using (horizontal resolution of 0.01°, approximately 1 km). Could that be a problem for the computation, or is it irrelevant?

Kind regards,

Hi @oceanusofi,

Several things here:

  • First, if you're working with Xarray on dask arrays, you don't want to use delayed! We can safely guess that the xmhw package already uses the functionality of dask arrays to distribute processing, but it looks like you've reached this conclusion too.
  • Second, as @groullet said, this file is probably small, less than half a GB.

    • Are you sure you need to split it into a Dask array?
    • Did you try to open the file without the chunks argument, to use plain NumPy and not Dask?
    • Why did you split the file into such small chunks? This would probably lead to a lot of overhead from the Dask scheduler: I'm counting more than 700,000 chunks!
    • Are you sure the chunking along each dimension is optimal for the computation you're trying to do?
  • Third, if you really need Dask, have you instantiated a Distributed cluster, or do you only work with plain Dask? To get more insight into what is happening, I encourage you to at least use a LocalCluster (don't forget to connect a client to it). This would let you see what is happening on the Dask Dashboard.
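A minimal sketch of the third point, assuming dask.distributed is installed: create a LocalCluster, connect a Client to it, and dask computations then run on the cluster's workers. The worker and thread counts are placeholders. processes=False keeps the workers as threads in the current process, which is convenient for a quick test; the default uses separate worker processes, which in a script should be started under an `if __name__ == "__main__":` guard.

```python
import dask.array as da
from dask.distributed import Client, LocalCluster

# Placeholder sizes: match workers/threads to the cores you actually have
cluster = LocalCluster(n_workers=1, threads_per_worker=4, processes=False)
client = Client(cluster)          # connecting a client makes dask use this cluster
print(client.dashboard_link)      # URL of the Dask dashboard (empty if unavailable)

x = da.ones((1000, 1000), chunks=(250, 250))
total = x.sum().compute()         # executed by the cluster's workers
print(total)                      # 1000000.0

client.close()
cluster.close()
```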


@oceanusofi Great, the I/O is sound. The problem is on the computation side. Again, you're hitting the black-box issue. Ideally you would send the array to threshold(), but this function is doing things behind your back, like relying on dask. Given the size of your array, you want to be able to run without dask at all, because the computation should take a fraction of a second.

I can see several routes:

  1. You don't transpose, and send the native MLT array to threshold(), hoping the black box is smart enough to overcome the fact that the data aren't contiguous along the time axis.
  2. See if you can deactivate dask in threshold().
  3. You implement a threshold() function yourself. With a name like that, I hardly believe the algorithm is more than 10 lines, but I might be wrong. With 3) you may get the impression you are "reinventing the wheel". Don't let that stop you: if your knife looks shiny but is completely unfit for the job, it's better to use your own, which you can sharpen as much as you want.
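Whether threshold() exposes its own switch for route 2 is not shown in this thread. As a generic sketch of what "without dask" can mean, Dask itself can be told to run single-threaded, or the data can be loaded into memory first so that plain NumPy does the work (for an xarray object, MLT.load() replaces the dask array by a NumPy one). The random array is a small placeholder:

```python
import dask
import dask.array as da
import numpy as np

x = da.random.random((100, 50, 40), chunks=(100, 10, 10))  # small dask-backed stand-in

# (a) Run the dask graph on a single thread, in order: handy for debugging and profiling
with dask.config.set(scheduler="synchronous"):
    mean_sync = x.mean(axis=0).compute()

# (b) Materialize the data once as a NumPy array and continue without dask
x_np = x.compute()
mean_np = x_np.mean(axis=0)

print(np.allclose(mean_sync, mean_np))  # True
```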


I agree 100% that these data are much too small to benefit from Dask.

Hello @geynard,

Thanks for your response. I will try to answer your questions one by one.

  1. Indeed, the xmhw package uses dask.delayed internally; I discovered this later. So I corrected my code, and I no longer wrap the threshold function in dask.delayed.

2. The problem is not opening the file with xr.open_dataset; it appears when I apply the threshold function from the xmhw package. The package documentation ("Setting up dask" in the xmhw 0.8.2.dev2 documentation) advises chunking the data when running the function on a big grid, perhaps because threshold computes a lot of things for every grid point simultaneously (I am not sure of the details of the function). All I know is that when I applied it to the original file like this:

MLTF       = xr.open_dataset('/home/REGIONS/')
MLT        = MLTF.MLT
Out[44]: (5413, 131, 101)
B             = threshold(MLT,climatologyPeriod=[2001,2015],tdim='time')

It never finished, and that was exactly my point. I did not quite understand, and therefore I am not sure when (i.e. at which stage of my programming process) or how exactly I should apply the chunking:

A) Should it be when I open my file (in which case I chose arbitrary chunks, just to see whether it would make a difference later when applying the threshold function)?

B) Or does chunking simply mean that I have to run the threshold function on a subset of my grid each time, to be more efficient?
Chunking is something that confuses me, so maybe I did not quite understand what I was supposed to do.

My question to you now is: how exactly do you count 700,000 chunks? Can you show me how you got this number? That would help me understand a bit more of what I did.

3. I have used a Distributed cluster before, but as I am new to dask I am not quite sure which function I can use here to split the tasks across different workers/threads. I have used the map function before but, if I am not mistaken, it is no longer valid in dask distributed? Could you perhaps give me some ideas on how to distribute the threshold function over different subsets of my grid on different workers/threads?

Also, since I am working remotely on a cluster, I have trouble opening a Firefox browser, so I cannot really open the Dask Dashboard. Any other ideas?

Thanks in advance,

@groullet - thanks for sharing your perspective and suggestions. It’s great to have your input on this forum!

I disagree with the characterization of Dask and Numpy as “black boxes”. They are useful high-level libraries that abstract away some of the complexity of scientific computing. Abstractions are needed in modern scientific computing, and Python excels at providing such abstractions. While it is important to understand what’s under the hood and how to do things at a low level, a big part of the success of the Pangeo community has been in helping people learn how to efficiently leverage such abstractions, particularly Xarray, to be more productive.

Dask and Numpy have excellent documentation that explains how to use the libraries efficiently. In this case, carefully reading the Best Practices page of the Dask documentation would help resolve the question of when and how to use Dask.

In this case, it sounds like the problem is with the threshold function in the xmhw package. No computational package should require the use of Dask. Have you reached out to the package authors?

As others have said, your data are not big. Just download them onto your laptop. Then you can avoid all of the complexity of working on the cluster.


I see your point, @groullet and @rabernat.

So the threshold function actually does a bunch of things internally, such as calculating daily climatologies and daily climatological percentile values for every single point, in very specific ways. So it is a big and rather complex algorithm.

The alternative to the threshold function would be to loop over the 131 x 101 points to get the desired output.
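For comparison, a heavily simplified sketch of a per-cell daily climatology and percentile computed without any Python loop over points: xarray's groupby on the day of year handles all cells at once. This is not equivalent to xmhw's threshold() (no smoothing window, no climatology-period selection, no special leap-day handling); it only illustrates the vectorized pattern, on random placeholder data:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Synthetic daily series on a tiny 3 x 2 grid (includes the leap year 2004)
time = pd.date_range("2001-01-01", "2004-12-31", freq="D")
rng = np.random.default_rng(0)
temp = xr.DataArray(
    rng.normal(20, 2, size=(time.size, 3, 2)),
    dims=("time", "j", "i"),
    coords={"time": time},
)

# Both reductions cover every grid cell in one call, no loop over points
seas = temp.groupby("time.dayofyear").mean("time")                 # daily climatology
thresh = temp.groupby("time.dayofyear").quantile(0.9, dim="time")  # daily 90th percentile

print(seas.sizes)
print(thresh.sizes)
```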

Yes, I have reached out to the package authors, and they also seem to advise chunking. But I will definitely consider your opinions; there seem to be simpler ways to do what I am looking for.

Thanks a lot all of you for your input so far.


@oceanusofi Okay, threshold() does real work; you don't want to recode it ;) BUT, given all the statistics it computes, it seems obvious (at least to me) that you need the time axis to be the last one (i.e. have the data contiguous along this axis). I'd try threshold() on a single time series, but watch out, this is where Python is tricky: use not a view of MLT but a real copy of it, like

mlt1d = MLT[:, 0, 0].copy()

The tricky part is that MLT[:,0,0] is a view, i.e. a wrapper that behaves as if you're working with one vector, but two consecutive times are separated in memory by 131 x 101 x 4 bytes = 52.9 kB; this is where the slowness comes from.
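The stride numbers above can be checked with plain NumPy. A minimal sketch with a zero-filled stand-in for MLT (time first, time axis shortened, real spatial size); for an xarray object, MLT.values gives the underlying NumPy array:

```python
import numpy as np

nt, ny, nx = 100, 131, 101  # short time axis, real spatial size
mlt = np.zeros((nt, ny, nx), dtype="float32")  # stand-in for MLT (time first)

series_view = mlt[:, 0, 0]          # a view into mlt, no data copied
series_copy = mlt[:, 0, 0].copy()   # a new, compact 1-D array

print(series_view.strides)  # (52924,) bytes between consecutive time steps
print(series_copy.strides)  # (4,)
print(np.shares_memory(series_view, mlt), np.shares_memory(series_copy, mlt))  # True False
```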

@rabernat Sure, no one wants to recode numpy or scipy! And frankly, most algorithms in numpy and scipy are black boxes, for the best. My point is that when it comes to big data analysis, a bit of low-level culture may help you better appreciate how to assemble the tools. Data contiguity is not so esoteric, and it matters a lot for performance. This discussion should continue in another thread!


Ok, thanks a lot for these suggestions, @groullet. I will try that as well.


And if this is fast, which I bet it is, then your problem becomes embarrassingly parallel, meaning you can parallelize over the 131 x 101 points, with the copy trick…


I agree it's an interesting topic that deserves its own thread! Based on my own experience, data contiguity in memory does not matter much for the types of things we do in ocean / weather / climate data analysis. That's because, when the algorithms are coded well*, we are almost always I/O bound, not compute / memory bound. I agree strongly that data contiguity on disk matters a huge amount, particularly for distributed / streaming calculations on big data. That's why we have devoted so much energy to efficient chunked storage libraries like Zarr, and to tools like Rechunker that allow you to transform this structure.

*And of course the catch is how to code algorithms well, so that they can be efficiently reused by others. Scipy and numpy do a good job, but there is no guarantee that a third-party package like xmhw will provide such efficient algorithms.

Can you possibly elaborate on this? It is not clear to me what you mean by parallelizing now, or how to do it.

Also, I have to take into account the format of the final file I want: say, multiple time-series files that I will process at a later stage, or a single netCDF file containing all the information at once. I am under the impression that if I use .copy() for every time series, I will then need a for loop over all 101 x 131 grid points, which will take some time, no?


@oceanusofi Sure, here is what I have in mind. Not very Pangeo-style, I know; @rabernat will blame me for not using dask :) At least you have the idea and a sketch of it.

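import numpy as np
import multiprocessing as mp

nprocs = 4                      # number of worker processes (Pool uses processes, not threads)
ny, nx, nt = 131, 101, 5413
# Fixed seed so every worker sees the same data, even with the "spawn" start method
rng = np.random.default_rng(0)
mlt_flipped = rng.normal(size=(ny, nx, nt)).astype("f")  # time is the last, contiguous axis

def threshold(j, i):
    """ compute three basic statistics on the time series"""
    x = mlt_flipped[j, i, :]
    return (np.mean(x), np.std(x), np.quantile(x, 0.95))

if __name__ == "__main__":      # required on platforms that spawn new processes
    with mp.Pool(nprocs) as pool:
        output = np.array(pool.starmap(threshold, np.ndindex(ny, nx)))

    output.shape = (ny, nx, 3)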

Use more processes (but not more than the number of cores on one node).


Ok, thank you very much for this. The last part, the mp.Pool, does the parallel computing over multiple processes, which I need to understand better. Thanks.

Hi @groullet, I have never properly benchmarked it, but maybe a Numba universal function ("Creating NumPy universal functions" in the Numba 0.50.1 documentation) with target='parallel' might be slightly simpler? (And maybe it can use the CPU's AVX instructions too, if the CPU has them, which might make it even faster…)

Hi again,

So all the suggestions here were good; I'll just answer specific parts. But again, you probably shouldn't need to split/chunk the data at first: first try calling your library on a sub-grid (a single time series) to assess how much time the xmhw package will take.

The way to follow @groullet's code template with Dask is to build chunks that span the full time dimension, e.g.

chunks     = {'time': -1, 'j': 'auto', 'i': 'auto'}   # the spatial dimensions of this file are named j and i
# or  chunks     = {'time': -1, 'j': 10, 'i': 10} if you want to specify something explicit
MLTF       = xr.open_dataset('/home/REGIONS/',chunks=chunks)

Then the code should stream these chunks, processing them as individual tasks (independently on each of your cores), provided the logic operates along the same axes; but this really depends on the threshold code.
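Whether threshold() streams chunks this way depends on its code. As a generic illustration of the pattern, xarray's apply_ufunc with dask="parallelized" applies a NumPy function to each chunk independently, provided the core dimension (here time) is not split across chunks. The 90th percentile is a placeholder computation and the data are random:

```python
import numpy as np
import pandas as pd
import xarray as xr

ny, nx, nt = 131, 101, 365
mlt = xr.DataArray(
    np.random.default_rng(0).normal(20, 2, size=(nt, ny, nx)).astype("float32"),
    dims=("time", "j", "i"),
    coords={"time": pd.date_range("2001-01-01", periods=nt, freq="D")},
).chunk({"time": -1, "j": 10, "i": 10})  # each chunk holds complete time series


def p90_along_time(x):
    # x is a NumPy block; apply_ufunc moves the core dim 'time' to the last axis
    return np.nanpercentile(x, 90, axis=-1).astype(x.dtype)


p90 = xr.apply_ufunc(
    p90_along_time,
    mlt,
    input_core_dims=[["time"]],
    dask="parallelized",          # one task per (j, i) chunk
    output_dtypes=[mlt.dtype],
)
result = p90.compute()
print(result.dims, result.shape)  # ('j', 'i') (131, 101)
```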

You should first try to understand why, perhaps by taking a smaller example (even though this one is already not big).

Ideally yes, but use a chunking scheme that makes sense! Here, it seems, that means complete time series.

Chunking is how dask arrays work. Try printing a dask array (its shape and chunks) to understand. If the underlying library code is well implemented, you shouldn't have to loop or use multiprocessing as @groullet proposed: the threshold code should run the underlying processing on each of your dask array chunks.

The calculation is the following: 5413 chunks along time (chunk size 1) × ceil(131 / 30) = 5 along j × ceil(101 / 20) = 6 along i, i.e. 5413 × 5 × 6 = 162,390 chunks.
I miscalculated previously, but this is still a lot of chunks! And they are not aligned with the computation.
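Dask can report the chunk count directly. A minimal sketch that builds lazy zero arrays with the same shape as MLT (nothing is allocated) and compares the original chunking with full-time-series chunks of 10 x 10 cells:

```python
import math

import dask.array as da

shape = (5413, 131, 101)  # (time, j, i)

# Chunking from the first post: partial chunks at the edges count too
small = da.zeros(shape, dtype="float32", chunks=(1, 30, 20))
print(small.numblocks, small.npartitions)  # (5413, 5, 6) 162390

# Same count by hand, with ceilings for the partial edge chunks
print(math.prod(math.ceil(s / c) for s, c in zip(shape, (1, 30, 20))))  # 162390

# Complete time series per chunk, 10 x 10 cells
full_time = da.zeros(shape, dtype="float32", chunks=(-1, 10, 10))
print(full_time.numblocks, full_time.npartitions)  # (1, 14, 11) 154
```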

threshold should do it for you based on your dask array chunks if it uses them correctly. So choosing chunks is all you should have to do.

Debugging Dask without a dashboard is mission impossible. Just try it locally on your laptop; your data are small :)

So @geynard @rabernat

it turns out, after a few tests, that my .nc files were made in a way (contiguous?) such that access is fastest along the time axis and slowest along the x, y axes:

CPU times: user 2.99 ms, sys: 1.01 ms, total: 3.99 ms
Wall time: 3.38 ms

In [15]: %%time
CPU times: user 29.3 ms, sys: 126 ms, total: 155 ms
Wall time: 139 ms

I know this is the opposite of the conventional access pattern, so I was wondering whether someone could advise me on how to chunk my files to speed up the processing. When I use large chunks like this:

MLTF      = xr.open_dataset('/home/REGIONS/')
MLT       = MLTF.MLT
lon, lat, time = MLTF.i, MLTF.j, MLTF.time
MLT       = MLT.chunk({'time':-1,'j':'auto','i':'auto'}) 
dask.array<xarray-<this-array>, shape=(5413, 131, 101), dtype=float32, chunksize=(5413, 78, 78), chunktype=numpy.ndarray>

I notice that it then takes ages to apply the threshold function.
So, although I chunk my data as shown above (which does not help the processing afterwards at all), to speed up the reading of my files I loop through groups of lat and lon indices like this:

## Split the lat/lon index arrays into groups of 5 ##
ilat      = np.arange(0, len(lat), 1)
ilon      = np.arange(0, len(lon), 1)
latc      = list(iterutils.get_chunks(ilat, 5))
lonc      = list(iterutils.get_chunks(ilon, 5))


In [32]: latc
[array([0, 1, 2, 3, 4]),
 array([5, 6, 7, 8, 9]),
 array([10, 11, 12, 13, 14]),
 array([15, 16, 17, 18, 19]),
 array([20, 21, 22, 23, 24]),
 array([25, 26, 27, 28, 29]),
 array([30, 31, 32, 33, 34]),
 array([35, 36, 37, 38, 39]),
 array([40, 41, 42, 43, 44]),
 array([45, 46, 47, 48, 49]),
 array([50, 51, 52, 53, 54]),
 array([55, 56, 57, 58, 59]),
 array([60, 61, 62, 63, 64]),
 array([65, 66, 67, 68, 69]),
 array([70, 71, 72, 73, 74]),
 array([75, 76, 77, 78, 79]),
 array([80, 81, 82, 83, 84]),
 array([85, 86, 87, 88, 89]),
 array([90, 91, 92, 93, 94]),
 array([95, 96, 97, 98, 99]),
 array([100, 101, 102, 103, 104]),
 array([105, 106, 107, 108, 109]),
 array([110, 111, 112, 113, 114]),
 array([115, 116, 117, 118, 119]),
 array([120, 121, 122, 123, 124]),
 array([125, 126, 127, 128, 129]),
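The same index groups can also be built with NumPy slicing alone, without iterutils. A minimal sketch; index_groups is a hypothetical helper name, and the last group is shorter when the length is not a multiple of the group size:

```python
import numpy as np


def index_groups(n, size):
    """Split range(n) into consecutive groups of `size` indices (hypothetical helper)."""
    idx = np.arange(n)
    return [idx[k:k + size] for k in range(0, n, size)]


latc = index_groups(131, 5)
lonc = index_groups(101, 5)
print(len(latc), latc[0], latc[-1])  # 27 [0 1 2 3 4] [130]
print(len(lonc), lonc[-1])           # 21 [100]
```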

After a few tests, I noticed that groups of 5 x 5 lat/lon points make the following program run faster:

from time import perf_counter
import logging
logger = logging.getLogger('catch_all')

### Run the loop over all lat/lon groups ###
for y in range(len(latc)):
    for x in range(len(lonc)):
        lat1, lat2 = latc[y][0], latc[y][-1]
        lon1, lon2 = lonc[x][0], lonc[x][-1]
        # isel excludes the stop index, hence the +1
        temp = MLT.isel(j=slice(lat1, lat2 + 1), i=slice(lon1, lon2 + 1))
        try:
            tstart = perf_counter()
            C = threshold(temp, climatologyPeriod=[2001, 2015], tdim='time', skipna=True)
            C = C.reindex_like(temp)
            t1 = perf_counter() - tstart   # time spent on this group
            del C
        except Exception:
            # e.g. when every cell of the group is land or NaN
            logger.exception('XmhwException: All points of grid are either land or NaN')
            print("I am missing")

This is faster than looping over groups of 10 lon x 10 lat grid points. Any ideas on how to speed this up, and why this is happening?

Could I use numba to speed up the for loop here perhaps?


I also tried the following (in case anyone has any advice, I would be happy to hear it):

import logging
from time import perf_counter

import numpy as np
import xarray as xr
from dask.distributed import Client, LocalCluster

## One local cluster (10 worker processes x 2 threads) and one client connected to it ##
cluster   = LocalCluster(n_workers=10, threads_per_worker=2, processes=True)
client    = Client(cluster)

MLTF      = xr.open_dataset('/home/REGIONS/')
MLT       = MLTF.MLT
lon, lat, time = MLTF.i, MLTF.j, MLTF.time
MLT       = MLT.chunk({'time': -1, 'j': 10, 'i': 10})
dask.array<xarray-<this-array>, shape=(5413, 131, 101), dtype=float32, chunksize=(5413, 10, 10), chunktype=numpy.ndarray>

## Split the lat/lon index arrays into groups of 5 ##
ilat      = np.arange(0, len(lat), 1)
ilon      = np.arange(0, len(lon), 1)
latc      = list(iterutils.get_chunks(ilat, 5))
lonc      = list(iterutils.get_chunks(ilon, 5))

logger = logging.Logger('catch_all')

### Climatology and threshold function for one lat/lon group ###

def CLIM(MLT, lat1, lat2, lon1, lon2):
    # isel excludes the stop index, hence the +1
    temp = MLT.isel(j=slice(lat1, lat2 + 1), i=slice(lon1, lon2 + 1))
    if temp.isnull().all():
        # Land/NaN-only group: build an all-NaN placeholder shaped like threshold's output
        print("I am missing", "XmhwException: All points of grid are either land or NaN")
        shape     = (365, int(lat2) - int(lat1) + 1, int(lon2) - int(lon1) + 1)
        data_vars = dict(thresh=(['doy', 'j', 'i'], np.full(shape, np.nan)),
                         seas=(['doy', 'j', 'i'], np.full(shape, np.nan)))
        # assumption: the j/i coordinate values are taken from the group itself
        coords    = dict(doy=(['doy'], np.arange(1, 366, 1)), j=(['j'], temp.j.values),
                         i=(['i'], temp.i.values), quantile=0.9)
        attrs     = dict(source="xmhw code:")
        C         = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
    else:
        print("I am working")
        C = threshold(temp, climatologyPeriod=[2001, 2015], tdim='time', skipna=True)
        C = C.reindex_like(temp)
    return C

## Submit one future per lat/lon group, then collect the results ##
tstart  = perf_counter()
futures = []
for j in range(len(latc)):
    for i in range(len(lonc)):
        futures.append(client.submit(CLIM, MLT, latc[j][0], latc[j][-1], lonc[i][0], lonc[i][-1]))
results = client.gather(futures)
t1      = perf_counter() - tstart

It seems to have reduced the running time a little compared to the previous script. Still, I expect it to take about 2 hours. Any ideas for optimization?