Efficient extraction of concurrent variables using rolling-window argmax indices

Hi,

I am working with CMIP6 daily piControl data. My objective is to identify the 5-day period of maximum precipitation (Rx5day) for each year, and then extract the concurrent daily data for other variables (temperature, omega, surface pressure) on those exact same 5 days.

For context on the data shapes, omega and temperature are pressure level data with dimensions ('time', 'plev', 'lat', 'lon'), while precipitation and surface pressure have dimensions ('time', 'lat', 'lon').

Currently, I am looping through each year, computing the index of the precipitation maximum, and then using a nested loop to extract the other variables via .isel().

Here is my current approach:

for k in range(start_y, end_y + 1):
    year_str = f"{k:04d}"
    prec_data_k = dataset_pr['pr'].sel(time=year_str)
    ps_data_k = ps_daily.sel(time=year_str)
    ta_data_k = ds_temp['ta'].sel(time=year_str)
    wap_data_k = ds_omega['wap'].sel(time=year_str)

    wap_data_k = wap_data_k.where(wap_data_k < 0)
        
    prec_rx5d_sum = prec_data_k.rolling(time=5, center=False).sum()
    
    prec_yMax_idx = prec_rx5d_sum.argmax(dim='time').compute()

    for i in range(5):
        day_idx = prec_yMax_idx - 4 + i       
        wap_cond_k = wap_data_k.isel(time=day_idx)
        ta_cond_k = ta_data_k.isel(time=day_idx)
        ps_cond_k = ps_data_k.isel(time=day_idx)
        ........................................

This code is taking a lot of time to process even a single year.

If I don't use .compute(), I hit the error below.

ValueError: Vectorized indexing with Dask arrays is not supported. Please pass a numpy array by calling ``.compute``. 

My question: Is there a better approach to overcome this? I am looking for an efficient way to extract these exact 5 days across all variables simultaneously without it taking this long to run.

Any advice or pointers would be helpful!

Thanks
Chaithra

Hey Chaithra,

The bottleneck is that inner loop: each .isel(time=day_idx) on a Dask-backed array with a full (lat, lon) index array triggers a separate graph evaluation. You're doing that 5 times per variable per year, so for 4 variables that's 20 Dask computations per year. That's where the time goes.

The fix is to .compute() each variable into memory once per year, then do the indexing in NumPy:

import numpy as np

for k in range(start_y, end_y + 1):
    year_str = f"{k:04d}"

    # Load the year's data into memory once per variable
    prec_k = dataset_pr['pr'].sel(time=year_str).compute()
    ps_k   = ps_daily.sel(time=year_str).compute()
    ta_k   = ds_temp['ta'].sel(time=year_str).compute()
    wap_k  = ds_omega['wap'].sel(time=year_str).compute()

    # Apply your omega filter before extraction
    wap_k = wap_k.where(wap_k < 0)

    # Rx5day argmax — shape (lat, lon)
    rx5d = prec_k.rolling(time=5, center=False).sum()
    max_idx = rx5d.argmax(dim='time').values  # numpy array

    # Build all 5 day indices at once — shape (5, lat, lon)
    time_indices = np.stack([max_idx - 4 + i for i in range(5)])
    time_indices = np.clip(time_indices, 0, prec_k.sizes['time'] - 1)

    # Grid indices for advanced indexing
    lat_idx, lon_idx = np.meshgrid(
        np.arange(max_idx.shape[0]),
        np.arange(max_idx.shape[1]),
        indexing='ij'
    )

    # --- 3D variables (time, lat, lon) → result: (5, lat, lon) ---
    pr_window = prec_k.values[time_indices, lat_idx[None], lon_idx[None]]
    ps_window = ps_k.values[time_indices, lat_idx[None], lon_idx[None]]

    # --- 4D variables (time, plev, lat, lon) → result: (5, plev, lat, lon) ---
    ta_window = ta_k.values[
        time_indices[:, None, :, :],
        np.arange(ta_k.sizes['plev'])[None, :, None, None],
        lat_idx[None, None, :, :],
        lon_idx[None, None, :, :]
    ]

    wap_window = wap_k.values[
        time_indices[:, None, :, :],
        np.arange(wap_k.sizes['plev'])[None, :, None, None],
        lat_idx[None, None, :, :],
        lon_idx[None, None, :, :]
    ]

    # Note on dimension order: NumPy advanced indexing produces
    # (5, plev, lat, lon) for the 4D variables. Your original loop
    # with xarray .isel() gives (5, lat, lon, plev) because xarray
    # reorders dimensions during vectorized indexing. The values are
    # identical — just the axis order differs. If your downstream
    # code expects the xarray ordering, uncomment these:
    # ta_window = np.transpose(ta_window, (0, 2, 3, 1))
    # wap_window = np.transpose(wap_window, (0, 2, 3, 1))

    # ... continue with your analysis

Two things make this faster. First, .compute() is called once per variable per year instead of being triggered 5 times inside the loop; that alone eliminates most of the overhead, since each .isel() on a Dask array evaluates the full graph. Second, the NumPy advanced indexing grabs all 5 days in one pass instead of 5 separate calls.
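If the broadcasting is hard to picture, here's a toy sketch of the same pattern on a tiny array (the shapes here are made up for illustration, not CMIP6 sizes):

```python
import numpy as np

# Toy data: 30 "days" on a 2x3 grid (illustrative shapes only)
rng = np.random.default_rng(0)
data = rng.random((30, 2, 3))

# Pretend the rolling-sum argmax landed on day 10 at every grid point
max_idx = np.full((2, 3), 10)

# All 5 day indices at once -> shape (5, 2, 3)
time_indices = np.stack([max_idx - 4 + i for i in range(5)])

lat_idx, lon_idx = np.meshgrid(np.arange(2), np.arange(3), indexing='ij')

# One advanced-indexing pass replaces 5 separate per-day extractions
window = data[time_indices, lat_idx[None], lon_idx[None]]

# Same values as indexing day-by-day
for i in range(5):
    assert np.array_equal(window[i], data[10 - 4 + i])

print(window.shape)  # (5, 2, 3)
```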

I tested this against the loop approach on synthetic data at CMIP6-like resolution with Dask-backed arrays, and it gives identical values for all four variables, including the NaN pattern from the omega filter. The speedup will depend on your data's chunk layout and I/O, but the reduction in Dask overhead alone should make a noticeable difference.

The np.clip on the time indices is a safety guard for the edge case where max_idx - 4 goes below zero. In practice this shouldn't happen, because rolling(time=5, center=False).sum() returns NaN for the first 4 timesteps, so argmax will always be ≥ 4, but it costs nothing and protects against surprises.
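You can verify that NaN/argmax behaviour on a tiny synthetic series (a sketch; xarray's argmax skips NaNs by default for float data):

```python
import numpy as np
import xarray as xr

pr = xr.DataArray(np.arange(10, dtype=float), dims='time')
rx5 = pr.rolling(time=5, center=False).sum()

print(rx5.values)  # first 4 entries are NaN, then 10, 15, 20, 25, 30, 35
print(int(rx5.argmax(dim='time')))  # 9 -- NaNs are skipped, so never < 4
```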

Let me know if you run into anything with the shapes; the 4D broadcasting can be fiddly to get right.

All the best,
Tom

Hi Tom,

This is very helpful, thank you! I really appreciate your help and the clear explanation.

The reduction in overhead is significant. Here is the timing for processing a single year now:

CPU times: user 3.57 s, sys: 3.34 s, total: 6.91 s
Wall time: 20.1 s

I still have one concern. For example, for CESM2 (r1i1p1f1), which has around 1200 years, this would take many hours (for a single realisation). Since I have multiple models, the total computation time becomes quite large.

Do you have any suggestions to further improve the performance?

Thanks in advance,
Chaithra

Hi again Chaithra,

My pleasure! Let me run some tests in the morning regarding your concern. You could potentially drop xarray, but then it's going to take more code to get the same result.

Talk tomorrow!

Tom

Hi again Chaithra,

Good to hear the extraction itself is working. The 20 s wall time vs 7 s CPU time tells you the bottleneck is now I/O and Dask overhead rather than the computation itself: each year triggers 4 separate .compute() calls, each of which builds and evaluates a Dask graph and reads from disk. Over 1200 years that's ~4800 graph evaluations.

The biggest single win is to batch the loading: instead of calling .compute() on one year at a time, load 10-50 years in one call and then loop through the years in memory:

import numpy as np

BATCH_SIZE = 10  # tune based on your available RAM

for batch_start in range(start_y, end_y + 1, BATCH_SIZE):
    batch_end = min(batch_start + BATCH_SIZE - 1, end_y)
    start_str = f"{batch_start:04d}"
    end_str = f"{batch_end:04d}"

    # One .compute() per batch instead of per year
    pr_batch = dataset_pr['pr'].sel(time=slice(start_str, end_str)).compute()
    ps_batch = ps_daily.sel(time=slice(start_str, end_str)).compute()
    ta_batch = ds_temp['ta'].sel(time=slice(start_str, end_str)).compute()
    wap_batch = ds_omega['wap'].sel(time=slice(start_str, end_str)).compute()
    wap_batch = wap_batch.where(wap_batch < 0)

    for k in range(batch_start, batch_end + 1):
        year_str = f"{k:04d}"

        # These are just views into the already-loaded data — no I/O
        prec_k = pr_batch.sel(time=year_str)
        ps_k = ps_batch.sel(time=year_str)
        ta_k = ta_batch.sel(time=year_str)
        wap_k = wap_batch.sel(time=year_str)

        rx5d = prec_k.rolling(time=5, center=False).sum()
        max_idx = rx5d.argmax(dim='time').values

        time_indices = np.stack([max_idx - 4 + i for i in range(5)])
        time_indices = np.clip(time_indices, 0, prec_k.sizes['time'] - 1)

        lat_idx, lon_idx = np.meshgrid(
            np.arange(max_idx.shape[0]),
            np.arange(max_idx.shape[1]),
            indexing='ij'
        )

        pr_window = prec_k.values[time_indices, lat_idx[None], lon_idx[None]]
        ps_window = ps_k.values[time_indices, lat_idx[None], lon_idx[None]]

        ta_window = ta_k.values[
            time_indices[:, None, :, :],
            np.arange(ta_k.sizes['plev'])[None, :, None, None],
            lat_idx[None, None, :, :],
            lon_idx[None, None, :, :]
        ]

        wap_window = wap_k.values[
            time_indices[:, None, :, :],
            np.arange(wap_k.sizes['plev'])[None, :, None, None],
            lat_idx[None, None, :, :],
            lon_idx[None, None, :, :]
        ]

        # ... your analysis here

This cuts the number of Dask graph evaluations from 4×1200 to 4×120 (with BATCH_SIZE=10). Since your wall time is dominated by I/O and graph overhead, that alone should bring the per-year cost down significantly.
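As a quick sanity check on the boundary arithmetic, the range/min pattern splits an inclusive year range into batches like this (illustrative years, not your actual start_y/end_y):

```python
BATCH_SIZE = 10
start_y, end_y = 1, 25  # illustrative inclusive range

batches = []
for batch_start in range(start_y, end_y + 1, BATCH_SIZE):
    batch_end = min(batch_start + BATCH_SIZE - 1, end_y)
    batches.append((batch_start, batch_end))

print(batches)  # [(1, 10), (11, 20), (21, 25)] -- last batch is shorter
```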

For BATCH_SIZE, it depends on your memory. With CESM2 at ~192×288 and 19 pressure levels, each year of all 4 variables is roughly 3 GB - the 4D fields (ta, wap) dominate since they carry the pressure level dimension. So a batch of 10 years needs ~30 GB, workable on most HPC nodes. If you’re on a smaller machine, BATCH_SIZE=5 at ~15 GB is safer.
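The ~3 GB figure comes from a simple byte count. Here's the arithmetic, assuming float32 data, a 192×288 grid, 19 pressure levels, and 365 days per year (these are assumptions; plug in your model's actual sizes):

```python
nlat, nlon, nplev, ndays = 192, 288, 19, 365  # assumed CESM2-like sizes
bytes_per_value = 4  # float32

var_3d = ndays * nlat * nlon * bytes_per_value          # pr, ps
var_4d = ndays * nplev * nlat * nlon * bytes_per_value  # ta, wap

year_total = 2 * var_3d + 2 * var_4d
print(f"{year_total / 1024**3:.1f} GB per year")  # ~3.0 GB
```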

For the multiple models part: since each model is completely independent, you can just run them in parallel, either as separate jobs on your cluster, or with a simple bash loop if you're on a single machine. No code changes needed; just launch python your_script.py --model CESM2 & for each one.

Let me know how the batching goes; I'm curious what per-year time you get with it.

Best,
Tom

Hi Tom,

You were absolutely right. I just ran a test for the first 20 years using BATCH_SIZE = 10, and the drop in I/O overhead is massive. Here are the total times for 20 years:

CPU times: user 40.7 s, sys: 1min 25s, total: 2min 6s
Wall time: 2min 38s

That brings the per-year wall time down to just ~7.9 seconds. I’ll be setting up a bash loop to run the rest of the models in parallel as you suggested.

Thanks so much for all your help!

Regards,

Chaithra

Hey Chaithra,

That's a solid result: 20 s down to ~8 s per year, and the wall time is now much closer to the CPU time! At that rate, you should be through the whole set pretty quickly.

Glad it worked out, it was a pleasure to assist! Good luck with the analysis!

All the best,

Tom