Zarr parallel region writes using multiple processes

Hello! I’m pretty new to the geospatial data world – and by extension some of the awesome libraries that this community has helped to curate. Apologies in advance, therefore, if my question is somewhat basic.

I’m having trouble incorporating Zarr parallel region writes into my code; am I doing something obviously wrong? I’ve shared the current iteration below (the region-write helper and the parallel dispatch are the relevant parts). As the docstring states, the function downloads an image from GEE in several patches. The downloads themselves are parallelized, and I’m trying to do the same with the region writes to improve overall performance.

Unfortunately, the code just hangs when I run it; I don’t get any sort of error. If I do the region writes sequentially, everything works perfectly fine; it’s only when I try to parallelize them that it hangs. I’ve tinkered with it in lots of different ways, but nothing seems to work. The last change I made was to include the Zarr ProcessSynchronizer.

I’ve also set `numcodecs.blosc.use_threads = False`, as I understood from the documentation that this is necessary when using multiple processes to parallelize writes.
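
For reference, the relevant setting looks like this (a minimal snippet of what sits at the top of the module):

import numcodecs

# Disable Blosc's internal threading when multiple processes write to the same store
numcodecs.blosc.use_threads = False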

Any and all feedback appreciated in advance.

def download_image(image: Image, patch_size: int = 4000, parallel: bool = False, bands: list = None,
                    retry: int = 5, zarr_path: str = None) -> Dataset:
    '''
        Download an image. Images bigger than a certain size are split into smaller patches, downloaded, and
        then re-combined.

        Parameters
        ----------
        - image: Image - Google Earth Engine Image object. Can be a string asset ID as well
        - patch_size: int default=4000 - Size of smaller patch to split into
        - bands: list optional - List of bands to download
        - parallel: bool optional - Download in parallel
        - retry: int default=5 - Retry count with exponential backoff
        - zarr_path: str optional - Download to a zarr dataset. Recommended for large images.

        Returns 
        -------
        Xarray Dataset
    '''
    tmp_path = tempfile.mkdtemp()

    if isinstance(image, str):
        image = ee.Image(image)

    if bands:
        image = image.select(bands)

    image_meta = image.getInfo()
    print(f'Downloading image {image_meta["id"]}')
    print(f'\nImage has {len(image_meta["bands"])} bands -')
    pprint(image_meta["bands"])

    crs_transform = image_meta['bands'][0]['crs_transform']

    ds = xr.Dataset({
        'HH': (['y', 'x'], da.empty((17409, 22254))),
        'HV': (['y', 'x'], da.empty((17409, 22254))),
        'LIN': (['y', 'x'], da.empty((17409, 22254))),
        'MSK': (['y', 'x'], da.empty((17409, 22254)))
    })

    ds.coords['x'] = ('x', da.arange(22254))
    ds.coords['y'] = ('y', da.arange(17409))

    ds.coords['x'] = ds.coords['x'] * crs_transform[0] + crs_transform[2]
    ds.coords['y'] = ds.coords['y'] * crs_transform[4] + crs_transform[5]

    ds.coords['y'] = ds.coords['y'][::-1]  
    ds.coords['x'] = ds.coords['x'] + (crs_transform[0] / 2)  
    ds.coords['y'] = ds.coords['y'] + (crs_transform[4] / 2)

    ds = ds.chunk({'x': patch_size, 'y': patch_size})
    
    ds.to_zarr(zarr_path, compute=False)

    for band_meta in image_meta['bands']:
        print(f'\nProcessing band {band_meta["id"]}')
        # Clipped images may have an origin offset
        if 'origin' in band_meta:
            crs_transform[2] = crs_transform[2] + crs_transform[0] * band_meta['origin'][0]
            crs_transform[5] = crs_transform[5] + crs_transform[4] * band_meta['origin'][1]
        crs = band_meta['crs']
        x_len, y_len = band_meta['dimensions']

        print(f'Image to download has size {x_len, y_len}')
      
        x_count = math.ceil(x_len/patch_size)
        y_count = math.ceil(y_len/patch_size)

        print(f'Count of patches to download is {x_count, y_count} with patch size {patch_size}')

        synchronizer = zarr.ProcessSynchronizer('/root/cyclops/tmp')

        def _download_patch_part(params):
            image, download_url_params, x, y, retry_count, band_id = params
            url = image.getDownloadURL(download_url_params)
            print(f'Downloading patch {x, y} with dimensions {download_url_params["dimensions"]} from {url}')

            for i in range(retry_count):
                try:
                    response = r.get(url)
                    with ZipFile(BytesIO(response.content)) as zf:
                        img = zf.filelist[0]
                        filename = os.path.join(tmp_path, f'{img.filename}-patch-{x}-{y}.tif')
                        img.filename = filename
                        zf.extract(img)

                    with rio_open(f'/root/cyclops{filename}') as tif:
                        data = tif.read(1)
                        width, height = tif.width, tif.height

                    # Calculate the index range this patch covers in the full image
                    x_min = x * patch_size
                    x_max = x_min + width
                    y_min = y * patch_size
                    y_max = y_min + height

                    patch_ds = xr.Dataset(
                        {band_id: (['y', 'x'], data)},
                        coords={'x': da.arange(x_min, x_max), 'y': da.arange(y_min, y_max)}
                    )

                    selection = {'x': slice(x_min, x_max), 'y': slice(y_min, y_max)}

                    patch_ds.to_zarr(zarr_path, region=selection, mode='r+', synchronizer=synchronizer)
                    return
                except Exception as e:
                    print(f'Error downloading patch {x, y} - {e}. Retrying ({i+1}/{retry_count})...')
                    sleep(2 ** (6 + i))

            raise Exception(f'Error downloading patch {x, y}.')

        # Download image in parts
        patch_parts = []
        for x in range(x_count):
            for y in range(y_count):
                crs_transform_patch = crs_transform.copy()
                crs_transform_patch[2] = crs_transform[2] + crs_transform[0] * patch_size * x
                crs_transform_patch[5] = crs_transform[5] + crs_transform[4] * patch_size * y
                x_patch_size = patch_size if (patch_size * (x+1)) < x_len else x_len - patch_size * x
                y_patch_size = patch_size if (patch_size * (y+1)) < y_len else y_len - patch_size * y
                download_url_params = dict(crs=crs, crs_transform=crs_transform_patch, dimensions=(x_patch_size, y_patch_size))
                patch_parts.append((image.select(band_meta['id']), download_url_params, x, y, retry, band_meta['id']))    

        if parallel:
            print('Downloading patches in parallel...')
            with ProcessingPool(4) as pool:
                pool.map(_download_patch_part, patch_parts)
        else:
            for patch_part in patch_parts:
                _download_patch_part(patch_part)
        
    ds = xr.open_dataset(zarr_path)
    ds = ds.rio.write_crs(crs)
    ds.to_zarr('example.zarr')

    return xr.open_dataset('example.zarr')

Hey there! I’m not sure what the source of your error is, but it does sound like a doozy. One alternative:

High-level feedback from looking at your implementation: there is a mix of process and thread parallelism, as well as parallelism managed by both Python and Dask. To me, this many layers of parallelism is prone to locks (if not deadlocks) that slow things down in a hard-to-debug way.

I recommend using concurrency for the IO from EE (say, via concurrent.futures) combined with parallelism for computation via Dask.
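
Here’s a rough sketch of that shape, assuming a hypothetical download_patch helper (essentially your _download_patch_part minus the write) that returns the band name, the patch array, and the region slices. Downloads fan out over a thread pool since they are network-bound, while all Zarr region writes stay in a single process, so no synchronizer is needed:

from concurrent.futures import ThreadPoolExecutor, as_completed

import xarray as xr


def download_all_patches(patch_parts, zarr_path, download_patch, max_workers=8):
    # Threads handle the IO-bound downloads concurrently
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(download_patch, part) for part in patch_parts]
        # The main process is the only writer, so region writes need no locking
        for future in as_completed(futures):
            band_id, data, selection = future.result()
            patch_ds = xr.Dataset({band_id: (['y', 'x'], data)})
            patch_ds.to_zarr(zarr_path, region=selection, mode='r+')

With a single writer you can drop the ProcessSynchronizer entirely; if the writes themselves become the bottleneck, that is the point to hand them to Dask.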

cProfile and similar tools may help reveal where things are taking too long. If you can get a flame graph output, you may be able to uncover what is causing the slowdown. I’ve had good success with py-spy.
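
For example, something as simple as this, profiling one sequential run and printing the most expensive calls, can show where the time is going (the call inside the string is just a placeholder for however you invoke your function):

import cProfile
import pstats

# Profile a single sequential run and write the stats to a file
cProfile.run("download_image(image, zarr_path='test.zarr')", 'download.prof')

# Print the 20 entries with the highest cumulative time
pstats.Stats('download.prof').sort_stats('cumulative').print_stats(20)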