Zarr parallel region writes using multiple processes

Hello! I’m pretty new to the geospatial data world – and by extension some of the awesome libraries that this community has helped to curate. Apologies in advance, therefore, if my question is somewhat basic.

I’m having trouble incorporating Zarr parallel region writes into my code, am I doing something very obviously not right? I’ve shared the current iteration below with relevant snippets highlighted. As the docstring states the function downloads an image from GEE in several patches. The downloads themselves are parallelized and I’m trying to do the same with the region writes to improve the overall performance.

Unfortunately, the code just hangs when I run it / I don’t get any sort of error. If I do the region writes sequentially, everything works perfectly fine, it’s just when I try to parallelize it. I’ve tried to tinker with it in lots of different ways but nothing seems to be working. The last change I had made was to include the Zarr ProcessSynchronizer.

I’ve also set numcodecs.blosc.use_threads = False as I understood from the documentation that this was necessary when using multiple processes to parallelize writes.

Any and all feedback appreciated in advance.

def download_image(image: Image, patch_size: int = 4000, parallel: bool = False, bands: list = None,
                    retry: int = 5, zarr_path: str = None) -> Dataset:
    '''
        Download an image. Image bigger than a certain size are split into smaller patches and downloaded and 
        then re-combined.

        Parameters
        ----------
        - image: Image - Google Earth Image object. Can be the string ID as well    
        - patch_size: int default=4000 - Size of smaller patch to split into
        - bands: list optional - List of bands to download
        - parallel: bool optional - Download in parallel
        - retry: int default=5 - Retry count with exponential backoff
        - zarr_path: str optional - Download to a zarr dataset. Recommended for large images.

        Returns 
        -------
        Xarray Dataset
    '''
    tmp_path = tempfile.mkdtemp()

    if type(image) == str:
        image = ee.Image(image)

    if bands:
        image = image.select(bands)

    image_meta = image.getInfo()
    print(f'Downloading image {image_meta["id"]}')
    print(f'\nImage has {len(image_meta["bands"])} bands -')
    pprint(image_meta["bands"])

    crs_transform = image_meta['bands'][0]['crs_transform']

    ds = xr.Dataset({
    'HH': (['y', 'x'], da.empty((17409, 22254))),  
    'HV': (['y', 'x'], da.empty((17409, 22254))),
    'LIN': (['y', 'x'], da.empty((17409, 22254))),
    'MSK': (['y', 'x'], da.empty((17409, 22254)))
    })  

    ds.coords['x'] = ('x', da.arange(22254))  
    ds.coords['y'] = ('y', da.arange(17409))

    ds.coords['x'] = ds.coords['x'] * crs_transform[0] + crs_transform[2]
    ds.coords['y'] = ds.coords['y'] * crs_transform[4] + crs_transform[5]

    ds.coords['y'] = ds.coords['y'][::-1]  
    ds.coords['x'] = ds.coords['x'] + (crs_transform[0] / 2)  
    ds.coords['y'] = ds.coords['y'] + (crs_transform[4] / 2)

    ds = ds.chunk({'x': patch_size, 'y': patch_size})
    
    ds.to_zarr(zarr_path, compute=False)

    for band_meta in image_meta['bands']:
        print(f'\nProcessing band {band_meta["id"]}')
        # Clipped images may have an origin offset
        if 'origin' in band_meta:
            crs_transform[2] = crs_transform[2] + crs_transform[0] * band_meta['origin'][0]
            crs_transform[5] = crs_transform[5] + crs_transform[4] * band_meta['origin'][1]
        crs = band_meta['crs']
        x_len, y_len = band_meta['dimensions']

        print(f'Image to download has size {x_len, y_len}')
      
        x_count = math.ceil(x_len/patch_size)
        y_count = math.ceil(y_len/patch_size)

        print(f'Count of patches to download is {x_count, y_count} with patch size {patch_size}')

        synchronizer = zarr.ProcessSynchronizer('/root/cyclops/tmp')

        def _download_patch_part(params):
                image, download_url_params, x, y, retry_count, band_id = params
                url = image.getDownloadURL(download_url_params)
                print(f'Downloading patch {x, y} with dimensions {download_url_params["dimensions"]} from {url}')
                
                for i in range(retry_count):
                    try:
                        response = r.get(url)
                        with ZipFile(BytesIO(response.content)) as zf:
                            img = zf.filelist[0]
                            filename = os.path.join(tmp_path, f'{img.filename}-patch-{x}-{y}.tif')
                            img.filename = filename
                            zf.extract(img)

                        **with rio_open(f'/root/cyclops{filename}') as tif:**
**                            data = tif.read(1)**

**                        # Calculate coordinates**
**                        x_min = x * patch_size**
**                        x_max = x_min + tif.width**
**                        y_min = y * patch_size**
**                        y_max = y_min + tif.height**

**                        patch_ds = xr.Dataset(**
**                            {band_id: (['y', 'x'], data)},**
**                            coords={'x': da.arange(x_min, x_max), 'y': da.arange(y_min, y_max)}**
**                        )**

**                        selection = {'x': slice(x_min, x_max), 'y': slice(y_min, y_max)}**
**                        **
**                        patch_ds.to_zarr(zarr_path, region=selection, mode='r+', synchronizer=synchronizer**

                    except Exception as e:
                        sleep(2**(6+i))
                        print(f'Error downloading patch {x, y} - {e}. Retrying ({i+1}/{retry_count})...')
                        sleep(i)

                raise Exception(f'Error downloading patch {x, y}.')

        # Download image in parts
        patch_parts = []
        for x in range(x_count):
            for y in range(y_count):
                crs_transform_patch = crs_transform.copy()
                crs_transform_patch[2] = crs_transform[2] + crs_transform[0] * patch_size * x
                crs_transform_patch[5] = crs_transform[5] + crs_transform[4] * patch_size * y
                x_patch_size = patch_size if (patch_size * (x+1)) < x_len else x_len - patch_size * x
                y_patch_size = patch_size if (patch_size * (y+1)) < y_len else y_len - patch_size * y
                download_url_params = dict(crs=crs, crs_transform=crs_transform_patch, dimensions=(x_patch_size, y_patch_size))
                patch_parts.append((image.select(band_meta['id']), download_url_params, x, y, retry, band_meta['id']))    

        **if parallel:**
**            print('Downloading patches in parallel...')**
**            with ProcessingPool(4) as pool:**
**                pool.map(_download_patch_part, patch_parts)**
        else:    
            pool.map(_download_patch_part, patch_parts)
        
    ds = xr.open_dataset(zarr_path)
    ds = ds.rio.write_crs(crs)
    ds.to_zarr('example.zarr')

    return xr.open_dataset('example.zarr')