Dask RESAMPLING (can't handle large files...)

Hello All! First post (thanks for the referral @rsignell)

My colleague and I are working on cooking up a version of SSEBop (an energy-balance evapotranspiration model) that takes advantage of Dask parallel processing, inspired by the Pangeo tools… We’ve been getting stuck when we try to resample rasters: we keep running into memory errors…

Here is the workflow:

  1. Read in the rasters with rioxarray.open_rasterio(filepath, chunks=True) (see the sketch just below this list)
  2. Resample several rasters to 25km and 100km with Nearest Neighbor
  3. Do some raster logic/math
  4. Output an ET fraction image.
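
To be concrete, step 1 looks roughly like this (the file path is just a placeholder for one of our global rasters):

import rioxarray as rxr

# lazily open a global 1 km raster as a chunked (dask-backed) DataArray
da_1km = rxr.open_rasterio("global_1km.tif", chunks=True).squeeze("band", drop=True)
print(da_1km.chunks)  # confirms the data is chunked rather than loaded into memory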

We’re prototyping this on EROS VDI and our personal laptops. The input files are global rasters at 1 km resolution, so they are ~2 GB each. We run into issues as soon as we try to reproject even a single raster. Originally we tried resampling as recommended by rioxarray, but that doesn’t work at all because we immediately hit a memory error:

https://corteva.github.io/rioxarray/stable/examples/resampling.html
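
For context, the pattern from that page looks roughly like this (the target CRS and the 25 km resolution here are placeholders for our actual parameters):

from rasterio.enums import Resampling

# reprojection-based resampling with nearest neighbor; this is the call that
# blows up memory for us on the full 1 km global raster
da_25km = da_1km.rio.reproject(da_1km.rio.crs,
                               resolution=(25000, 25000),
                               resampling=Resampling.nearest)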

I figured that’s because the reproject needs to load everything into memory at once, so I was looking for a way around the rasterio resampling-via-reprojection approach. Then I found this pyresample example (based on a question by @rabernat):

https://github.com/pytroll/pyresample/issues/206

I tried using the pyresample method to do my resampling in a parallel fashion (since it is dask-enabled), but I watched my dask-distributed workers struggle with the workload and ultimately fail to resample this large 1 km raster to 25 km.

I also tried the default threaded scheduler, and no luck there either…

I thought Dask was supposed to automate this stuff behind the scenes and queue up jobs in bite-size chunks so you don’t hit memory errors?? I’m certainly not having that experience. We’ll soon have more compute resources at our disposal and move into AWS (where I don’t have a TON of experience), but I thought this would be possible on a laptop. It’s nice to be able to model at least ONE scene without messing with the cloud stuff all the time. What am I missing? How do people do resampling workflows quickly?

We could resample by other means and preprocess the rasters we need, but the vision was to do it on the fly with fast parallel tools, and this ain’t cutting it hahaha. ArcPy can resample this stuff on a laptop, so there must be a way we can get it going with open source, right?

Much love,
Gabe

Hi @Gabe-Parrish and welcome to the forum! We would love to help you with this issue.

It’s concerning that you can’t even process a single scene without crashing / running out of memory. My strong recommendation would be to not move to Dask until you have a single standalone task actually working. Dask is not magic. It is good for parallelizing workflows over multiple CPUs or machines in a cluster. But it cannot change the basic memory requirements of your calculation. For Dask to work, you need to split your problem into many small pieces that can be handled in memory. It sounds to me like you have not yet found the right chunking strategy to split up your problem into small pieces.

It sounds like you are trying to resample a global 1km raster image. Resample to what? What is your source / target grid / CRS? Are you trying to do a global → global reprojection of the data? If so, that can be a hugely expensive operation. It may be that the only solution is to run the operation on a huge computer with tons of memory. Regridding is complicated to parallelize because every grid point potentially maps to every other grid point. Dask can’t just magically figure out how to do this. You would need a custom algorithm.

If you share more specific details, we can probably be more helpful!

@Gabe-Parrish , as @rabernat said, thanks for posting here!

I notice you don’t mention reprojection in this bullet list, only resampling a 1 km grid to 25 km. You mention using nearest neighbor, but would taking the mean be preferable? It should be straightforward to open the data with something like chunks={'x': 2500, 'y': 2500} and then create a list of delayed tasks, where each task computes 25 km x 25 km block means on its chunk, probably using the reshape approach detailed in this SO answer.

The list of delayed tasks could then be executed using Dask bag.

Or maybe map_blocks would be more appropriate?
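
Something like this minimal sketch of the map_blocks variant, assuming the raster dimensions divide evenly by both the chunk size and the 25x factor (the file name is hypothetical):

import rioxarray as rxr


def block_mean(block, factor=25):
    # the reshape trick: fold each (m*factor, n*factor) chunk into
    # (m, factor, n, factor) and average over the two factor axes
    h, w = block.shape
    return block.reshape(h // factor, factor, w // factor, factor).mean(axis=(1, 3))


da_1km = rxr.open_rasterio("global_1km.tif",
                           chunks={"x": 2500, "y": 2500}).squeeze("band", drop=True)
# each 2500x2500 chunk becomes one 100x100 chunk of 25 km block means
da_25km = da_1km.data.map_blocks(block_mean, dtype=da_1km.dtype, chunks=(100, 100))
result = da_25km.compute()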

Adding to what @rsignell mentioned, the average can be computed in linear time, while nearest-neighbour interpolation tends to be greedier, with runtimes between logarithmic and quadratic complexity (sometimes with a similar impact on memory usage).

How you organize your data can also affect these runtimes. A brief discussion of time and memory complexity and design choices can be found here.

For block_mean resampling (e.g. from 1km to 25km), @ocefpaf reminded me of the xarray.coarsen method.

So actually this problem is as easy as:

da_25km = da_1km.coarsen(x=25, y=25, boundary='pad').mean()

Here’s a complete example notebook!

Hello everyone, I’m back after a long hiatus, but I need to pick up where I left off on this question!

To start off, I owe you all a debt for getting me a lot farther with your replies. I wrote two or three functions that were really useful in my workflow and do the job nicely… One relies on the coarsen() method that @rsignell suggested.

import dask
import xarray as xr


def coarsen_dask_arr(arr: xr.DataArray, scaling_value: int, resample_alg='average', mask=None):
    """
    Coarsen a chunked (dask-backed) DataArray by an integer factor.
    Big help with the implementation from Rich Signell.

    :param arr: DataArray opened as: rxr.open_rasterio(path, chunks=True).squeeze('band', drop=True)
    :param scaling_value: a whole integer. In the SSEBop context 5, 25 or 100.
    :param resample_alg: 'average', 'mean' or 'sum' accepted.
    :param mask: optional mask (currently unused).
    :return: dask-backed DataArray that has been resampled by the scaling factor.
    """
    # Avoid dask's "creating large chunks" warning triggered by the internal reshape.
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        # https://docs.xarray.dev/en/v2022.06.0/generated/xarray.core.rolling.DataArrayCoarsen.construct.html#xarray.core.rolling.DataArrayCoarsen.construct
        if resample_alg in ('average', 'mean'):
            dsub = arr.coarsen(x=scaling_value, y=scaling_value, boundary='pad').mean()
        elif resample_alg == 'sum':
            # is this correct?
            dsub = arr.coarsen(x=scaling_value, y=scaling_value, boundary='pad').sum()
        else:
            raise ValueError(
                f"{resample_alg!r} is not supported; options are: 'average', 'mean' or 'sum'")
        return dsub
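
For reference, a typical call looks something like this (the path is hypothetical):

import rioxarray as rxr

tmax_1km = rxr.open_rasterio("tmax_1km.tif", chunks=True).squeeze("band", drop=True)
tmax_25km = coarsen_dask_arr(tmax_1km, scaling_value=25, resample_alg="average")
tmax_25km = tmax_25km.compute()  # nothing is actually computed until this point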

This allowed us to coarsen our arrays and do zonal statistics really nicely. Another huge discovery was that you can open VRTs (GDAL virtual rasters) as dask arrays, which let me open all the files much more easily… Here is a function that resamples files to the same extent and resolution and returns a list of dask-backed arrays…

import dask
import rasterio
import rioxarray as riox
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT


def normalize_to_std_grid_dask(inputs, nodatas=[], sample_file=None,
                               resamplemethod='nearest', outdtype='float64', overwrite=True):
    """
    Uses a rasterio virtual raster (WarpedVRT) to standardize grids of different CRS,
    resolution and extent to the grid of a sample raster.

    :param inputs: a list of (daily) raster input files for SSEBop.
    :param nodatas: a list of nodata values, one per input file.
    :param sample_file: the raster whose CRS, transform and shape define the target grid.
    :return: list of chunked (dask-backed) DataArrays
    """
    outputs = []

    # Pull the target grid definition from the sample raster.
    with rasterio.open(sample_file) as src:
        out_meta = src.meta
        crs = out_meta['crs']
        transform = out_meta['transform']
        cols = out_meta['width']
        rows = out_meta['height']

    if resamplemethod == 'nearest':
        rs = Resampling.nearest
    elif resamplemethod == 'average':
        rs = Resampling.average
    else:
        raise ValueError('only nearest-neighbor and average resampling are supported at this time')

    for i, warpfile in enumerate(inputs):
        print(f'warping {warpfile}\n with nodata value: {nodatas[i]}')
        # Source datasets must be opened in read-only ('r') mode.
        with rasterio.open(warpfile, 'r') as src:
            # Create the virtual raster on the standard grid taken from the sample tiff,
            # with a suitable nodata value and output dtype.
            with WarpedVRT(src, resampling=rs,
                           crs=crs,
                           transform=transform,
                           height=rows,
                           width=cols,
                           nodata=nodatas[i],
                           dtype=outdtype) as vrt:
                with dask.config.set(**{'array.slicing.split_large_chunks': True}):
                    # Open the warped VRT lazily as a chunked DataArray.
                    data = riox.open_rasterio(vrt, chunks={'x': 1000, 'y': 1000}).squeeze('band', drop=True)
                    outputs.append(data)
    return outputs
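
And a sketch of how it gets called (file names and nodata values are placeholders):

files = ["tmax_2020001.tif", "ndvi_2020001.tif"]
nodatas = [-9999, -3000]

# warp both rasters onto the grid of the sample file and get lazy arrays back
tmax_da, ndvi_da = normalize_to_std_grid_dask(files, nodatas=nodatas,
                                              sample_file="sample_grid.tif",
                                              resamplemethod="nearest")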

I never figured out how to make my earlier approaches with pyresample or rioxarray resampling work.
I also needed to bilinearly resample raster files, and this xarray interp approach worked for me but was memory-intensive:

cfactor.load()
cfactor_bilinear_ds = cfactor.interp(y=tmax_ds['y'], x=tmax_ds['x'], method='linear')

This executes a linear interpolation in space, which was neat! I’m putting it here because people may want to know about it. With files of the size I’m working with, ~2 GB rasters (i.e. the Earth at 1 km resolution), I was only able to get the linear interpolation above to work with 64 GB of RAM, since .load() pulls the whole array into memory first. My workaround on a lower-memory device is to use a rasterio WarpedVRT and virtually warp the file with bilinear interpolation, which works OK (sketched below)… So thanks everybody, you really did set me down the right track.
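
Here is roughly what that WarpedVRT workaround looks like, a sketch with hypothetical paths that mirrors the function above but uses bilinear resampling:

import rasterio
import rioxarray as riox
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT

# the target grid comes from the tmax raster
with rasterio.open("tmax_1km.tif") as sample:
    meta = sample.meta

with rasterio.open("cfactor.tif") as src:
    with WarpedVRT(src, resampling=Resampling.bilinear,
                   crs=meta["crs"], transform=meta["transform"],
                   height=meta["height"], width=meta["width"]) as vrt:
        # read the bilinearly warped data as a chunked DataArray
        cfactor_bilinear = riox.open_rasterio(
            vrt, chunks={"x": 1000, "y": 1000}).squeeze("band", drop=True)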

There are still issues with unmanaged memory and warnings that I get, but I’ll put more detail in a separate question and link it here, because I think it is related but separate…


@Gabe-Parrish glad you’ve gotten things to work! If you are able to share the full code, or at the very least the specific data files you are working with, that would help diagnose the issues you’re encountering.

A big consideration for memory consumption is going to be (1) the chunks you specify for dask and (2) how the rasters are internally chunked (or “tiled”). If you’re dealing with many rasters that are not on a matched grid and tiling scheme to begin with, it is impossible to align your xarray chunks with the underlying datasets. In the worst case the rasters are not tiled at all and you have to read the whole thing into memory for any operation…
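
If it helps, a quick way to check how a GeoTIFF is structured internally (the file name is just an example):

import rasterio

with rasterio.open("example.tif") as src:
    print(src.block_shapes)  # e.g. [(512, 512)] if tiled, [(1, width)] if striped
    print(src.profile.get("blockxsize"), src.profile.get("blockysize"))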

Your normalize_to_std_grid_dask function is effectively what tools like https://stackstac.readthedocs.io are doing, with some dask optimizations behind the scenes, so I’d have a look at that! Instead of a VRT you’d need to create STAC metadata for your files, but the concept is the same: reproject/subset upon reading the file.


Hi @scottyhq, intriguing point about internally chunked/tiled datasets. How do you handle situations where you are dealing with lots of GeoTIFFs that aren’t internally chunked, but you want them to be? In other words, how do you rewrite files with a tiling scheme?

I would share the data files but they are very large and on a local system…