Xarray/Dask - Specify to the Dask resource manager that a given task uses a huge amount of RAM

Hi :wave:, I’ve already posted this question on the Dask Discourse, but I think some people here might have the solution, as you’re familiar with using Dask on geospatial data.

I’m having an issue distributing a deep learning process over a big raster image (I’m using xarray as a Dask interface).
I read a big image of about 1GB with xarray and chunk it into 10 different chunks of about 100MB each.
Then, for each chunk, I do some deep learning magic using ONNX Runtime. To do so I use xarray’s apply_ufunc (a dask gufunc wrapper, like map_blocks), which applies my inference function predict to each chunk.
At this point I get a nice computation graph with 10 tasks, one for each chunk.

However, my predict function uses a lot of RAM: for a 100MB chunk it peaks at around 300MB (because of the convolutions and other fancy stuff going on inside). This leads to high memory usage and a crash when running with little RAM: at the start of the computation each worker gets assigned too many tasks, because Dask thinks it can handle them, but then it blows up. With the 500MB memory_limit per worker used in the example below, two concurrent predict tasks already exceed what a worker can hold.

So my question is: how can I tell the Dask resource manager NOT to give too many predict tasks to my workers?

Currently I’m on a LocalCluster and I want to ensure that my workload can run with approximately 3GB of RAM.
I found in the Worker Resources section that I can specify resources for a specific submit (i.e. client.submit(process, d, resources={'MEMORY': 200e6})) or with a dask.annotate context. However, it seems that I can’t put such constraints on apply_ufunc tasks. EDIT: it does seem to work using the dask.annotate context (see the post below), but the scheduler still doesn’t handle it well.

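For reference, here is a rough, self-contained sketch of the two patterns mentioned above (the MEMORY quantities are arbitrary, process is a toy function, and I’m assuming that LocalCluster forwards the resources keyword to its workers):

# Sketch only (not my real pipeline): declare an abstract MEMORY resource on
# each worker, then constrain tasks against it per submit or via an annotation.
import numpy as np
import dask.array as da
from dask import annotate
from dask.distributed import Client, LocalCluster

def process(x):
    return x * 2

if __name__ == "__main__":
    # Assumed: extra keyword arguments to LocalCluster are passed to the workers,
    # so each worker advertises 1.5e9 units of the abstract MEMORY resource.
    cluster = LocalCluster(n_workers=2, threads_per_worker=1,
                           resources={"MEMORY": 1.5e9})
    client = Client(cluster)

    # 1) Per-future constraint on a plain submit.
    future = client.submit(process, np.ones((100, 100)), resources={"MEMORY": 200e6})
    print(future.result().sum())

    # 2) Annotation context around collection code (what I try with apply_ufunc below).
    with annotate(resources={"MEMORY": 200e6}):
        arr = da.ones((1000, 1000), chunks=(250, 250)).map_blocks(process)
    print(arr.sum().compute())
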
Notebook version of my example: notebook version

According to .__dask_graph__(), my annotation (here with 100MB) is well formed in the graph.

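Concretely, the check looked something like this, run on the image_features DataArray built in the example below (assuming the usual HighLevelGraph/Layer attributes; layer names will differ from run to run):

# Sketch: print the annotations attached to each layer of the high-level graph;
# the gufunc layer should carry something like {'resources': {'MEMORY': 100000000.0}}.
hlg = image_features.data.__dask_graph__()
for name, layer in hlg.layers.items():
    print(name, layer.annotations)
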
I don’t understand why, even with the resource annotation, the Dask scheduler gives too many tasks to my workers, even though it’s explicitly too much for them.

Here is a minimal reproducible example:

Define Cluster and workers

# Start the local Dask cluster and monitoring dashboard.
from dask.distributed import Client, LocalCluster

cluster = LocalCluster(
    # n_workers=4,
    # threads_per_worker=1,
    # processes=True,
    memory_limit="500MB",
)
client = Client(cluster)

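Not in the original snippet, but handy for this question: the dashboard link makes it easy to watch per-worker memory and task assignment while the example runs.

# Optional: open this URL in a browser to follow memory usage per worker.
print(client.dashboard_link)
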
Define custom function

With resnet18-v1-7.onnx from here (40MB)

import numpy as np
import onnxruntime as ort

def forward_onnx(image_tiled: np.ndarray) -> np.ndarray:
    # Collapse all leading (batch) dimensions into a single batch axis for ONNX.
    *batch, c, h, w = image_tiled.shape
    image_tiled_batched = image_tiled.reshape(int(np.prod(batch)), c, h, w)

    # Create the session inside the task so each worker loads its own copy of the model.
    model_session = ort.InferenceSession(
        "./resnet18-v1-7.onnx", providers=["CPUExecutionProvider"]
    )

    outputs = model_session.run(
        output_names=["resnetv15_dense0_fwd"],
        input_feed={"data": image_tiled_batched.astype(np.float32)},
    )
    # Restore the original batch dimensions, with the class scores as the last axis.
    output_unbatched = outputs[0].reshape(*batch, -1).astype(np.float32)
    return output_unbatched

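A quick sanity check (not in the original notebook) of the expected shapes, assuming the model file sits next to the script:

# Dummy batch of 2×2 tiles, each with 3 bands of 224×224 pixels; ResNet18 returns
# 1000 class scores per tile, so the output should have shape (2, 2, 1000).
dummy = np.random.rand(2, 2, 3, 224, 224).astype(np.float32)
print(forward_onnx(dummy).shape)
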
Load image

With image.tif from here (1GB)

import rioxarray
image = rioxarray.open_rasterio(  # type: ignore
    "./image.tif", 
    parse_coordinates=True, 
    #chunks={"x": "auto", "y": "auto"}
)

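For orientation (also not in the original notebook), the raw image comes in as (band, y, x) and is not yet a dask array at this point, since chunking only happens in the rolling/construct step below:

# Dimensions, shape and dtype of the freshly opened raster.
print(image.dims, image.shape, image.dtype)
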
Make tiles and rolling windows with rolling

shift = 224
input_size = 224
output_size = 1000

image_tiled = (
    # (band, y, x).
    image.transpose()  # type: ignore # (x, y, band).
    .rolling(  # Rolling object, future computation of sliding_window_view
        dim={"x": input_size, "y": input_size}
    )  # Rolling x->16, y->16.
    .construct(  # Construct the rolling view and apply stride
        x="x_tile", y="y_tile", stride=shift,
    )  # (x, y, band, x_tile, y_tile).
    .chunk(  # Auto chunk (chunk_size ~ jobs_memory).
        ("auto", "auto", -1, -1, -1), merge_chunks=False
    )  # (x(chunked), y(chunked), band, x_tile, y_tile).
)

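As another quick check, the shape and chunking of the tiled view can be inspected before building the inference graph; the chunk size is what drives the per-task memory peak discussed above:

# Each chunk holds a small batch of (band, y_tile, x_tile) windows.
print(image_tiled.sizes)
print(image_tiled.chunks)
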
Compute Deep Learning for each tile on each chunk (as a batch)

I tried to annotate that this operation can use up to 300MB per task, but it doesn’t seem to work.

# Submit the future features view into the computation graph (async).
from dask import annotate
import xarray as xr

with annotate(resources={'MEMORY': 0.300e9}):
    image_features: xr.DataArray = (
        xr.apply_ufunc(  # Call the dask-parallelized gufunc.
            forward_onnx,
            image_tiled,
            input_core_dims=[["band", "y_tile", "x_tile"]],
            output_core_dims=[["features"]],
            keep_attrs=True,
            dask="parallelized",
            output_dtypes=[np.float32],
            dask_gufunc_kwargs={"output_sizes": {"features": output_size}},
        )  # (x, y, features)
        .stack(xy=["x", "y"])  # (features, x×y('xy'))
        .transpose("xy", ...)  # (x×y('xy'), features)
        # .chunk(("auto", -1), merge_chunks=False)  # (xy(chunked), features)
    )

Compute

image_features_np = image_features.as_numpy()
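
If the computation survives the memory limit, the shape of the materialised result can be checked:

# Expected: (x*y, features) after the stack/transpose above.
print(image_features_np.shape)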