Struggling with large dataset loading/reading using xarray

Hi Ryan, hi everyone,

Having some level of loops in a Python code is okay, performance-wise. In my example, operations are still vectorized on 3D arrays. I really believe we would collectively improve our codes if, instead of claiming to have “faster” codes, we simply gave the absolute speed of our codes in GFlop s⁻¹. For a single-threaded code, a minimum of 1.0 GFlop s⁻¹ should be our standard, and anything less considered problematic.
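As a rough illustration, a minimal sketch of how such a number could be measured for a simple vectorized operation (this is a toy example: it counts two floating-point operations per element and ignores allocation overhead and memory-bandwidth limits):

import numpy as np
from timeit import timeit

n = 20_000_000
x = np.random.rand(n)
y = np.random.rand(n)

# x * y + 1.0 performs one multiply and one add per element: 2*n flops in total
runtime = timeit(lambda: x * y + 1.0, number=10) / 10
print(f"achieved speed: {2 * n / runtime / 1e9:.2f} GFlop s-1")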

Compiling is not a magic bullet. You can have slow compiled code and fast interpreted code. To get performance out of compiled code, you need a good grip on how your data are laid out in memory. Memory layout is key to performance. It seems to be overlooked, and too many colleagues lack this basic knowledge. I agree with you that if you really want to harness the full power of your architecture, then yes, you should definitely compile, and numba is clearly a very good option for staying in the Python world.
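To make the layout point concrete, here is a minimal sketch comparing the same reduction on the same values stored in two different memory orders (a toy example; exact timings depend on the machine and cache sizes):

import numpy as np
from timeit import timeit

n = 4000
a_c = np.random.rand(n, n)      # row-major (C order), the numpy default
a_f = np.asfortranarray(a_c)    # same values, stored column-major

# Reducing along the axis that is contiguous in memory streams through consecutive
# addresses; reducing across it strides, which is typically slower even though the
# flop count is identical.
t_contiguous = timeit(lambda: a_c.sum(axis=1), number=50)
t_strided = timeit(lambda: a_f.sum(axis=1), number=50)
print(f"contiguous: {t_contiguous:.3f}s   strided: {t_strided:.3f}s")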

Cheers,
Guillaume


Yes, that is the realm of computer science, which only some of the people here will have ventured into.

Hi all.
Thanks to Xarray and dask.distributed, my data loaded in 20 minutes instead of more than 3 hours! However, the main challenge remains: how to avoid the loops that make the analysis take unbelievably long! I would appreciate any thoughts on how to get rid of the spatial loops in this sample script. The following example needs 1-2 minutes.

Thanks in advance,
Ali

import numpy as np
import xarray as xr

da = xr.tutorial.load_dataset("air_temperature", engine="netcdf4").air

for j in range(0, (da.sizes['lat'])):                              
    for k in range(0, (da.sizes['lon'])):
        for i in range(11, (da.sizes['time']-2900)):
            if da[i-1,j,k] >= da.isel(lat=j,lon=k).quantile([0.40], 'time'):
                L=3
            else:
                if da[i-2,j,k] >= da.isel(lat=j,lon=k).quantile([0.40], 'time'):
                    L=4
                else:
                    if da[i-3,j,k] >= da.isel(lat=j,lon=k).quantile([0.40], 'time'):
                        L=5

Hi Ali and welcome to the forum!

With Xarray and numpy in general, you almost never have to write these kinds of loops.

What do you want to achieve? What is the data analysis task you are trying to accomplish?

Many thanks, Ryan, for your warm reply. I believe Xarray would certainly be helpful here in avoiding loops.

I’m striving to detect flash droughts in model outputs, and this small piece is only one example.

The challenge I faced was that each condition can only compare one value at a time, so I thought it was impossible to do the comparison across all longitudes and latitudes at once.

We want to help you but we still need more info.

Say more about this. Where do quantiles come in? Can you say in words what the code in your original post is trying to accomplish? What is L?

My apologies if some information was missing; please let me know if I need to provide anything more. My original script is related to flash drought detection and requires soil moisture data. However, I came across the air temperature dataset in the tutorial and therefore changed my problem to an imaginary frost-day detection. So L would be the number of frost days.

import numpy as np
import xarray as xr

# Finding the duration of frost days 
# Output is a map showing mean frost days in each particular location 

da = xr.tutorial.load_dataset("air_temperature", engine="netcdf4").air

Dur_FR = [] 
Duration_FR = np.zeros((da.sizes['lat'], da.sizes['lon']))

for j in range(0, (da.sizes['lat'])):                              
    for k in range(0, (da.sizes['lon'])):
        # scalar quantiles (computed once per location), so the comparisons below give plain booleans
        q20 = da.isel(lat=j,lon=k).quantile(0.20, dim='time').item()
        q40 = da.isel(lat=j,lon=k).quantile(0.40, dim='time').item()
        
        for i in range(11, (da.sizes['time']-2900)):
            if da[i-1,j,k] <= q20 and da[i,j,k] <= q20 and da[i+1,j,k] >= q20:
                if da[i-2,j,k] > q20:
                    if da[i-2,j,k] >= q40:
                        L=3
                        Dur_FR.append(L)
                    else:
                        if da[i-3,j,k] >= q40:
                            L=4
                            Dur_FR.append(L)
                        else:
                            if da[i-4,j,k] >= q40:
                                L=5
                                Dur_FR.append(L)
                else:
                    if da[i-3,j,k] > q20:
                        L=4
                        Dur_FR.append(L)
                        
        # use NaN where no events were found (np.mean of an empty list warns and returns nan)
        Duration_FR[j,k] = np.mean(Dur_FR) if Dur_FR else np.nan
        Dur_FR = []

OK, this is a great example to start from. Here’s how you would calculate the number of frost days from a dataset of daily temperature, using Xarray semantics.

import xarray as xr
ds = xr.tutorial.load_dataset("air_temperature")
one_per_day = xr.ones_like(ds.air)
num_frost_days = one_per_day.where(ds.air <= 273.15).sum('time')
num_frost_days.plot()

[image: map of num_frost_days]

You’re using quantiles in your example, so maybe it’s helpful to use a quantile. Here is how you would count the number of days where the temperature is below the 0.2 quantile level (where the quantile is defined at each point in space).

air_q20 = ds.air.quantile(0.2, dim='time')  # has dims lat, lon
num_days_below_q20 = one_per_day.where(ds.air < air_q20).sum('time')
num_days_below_q20.plot()

[image: map of num_days_below_q20]

(Disclaimer: I don’t know if this is a scientifically useful calculation. I’m just making up a simple example to share.)

As you can see, Xarray automatically figures out how dimensions are related in different arrays. You don’t have to manually loop through points in space.
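For instance, reusing ds and air_q20 from above, a single expression compares every time step against the per-pixel threshold map, because Xarray matches the ('time', 'lat', 'lon') and ('lat', 'lon') dimensions by name:

# dims (time, lat, lon) combined with dims (lat, lon): broadcast by dimension name
cold_mask = ds.air < air_q20          # True wherever a day is below its local q20
anomaly_vs_q20 = ds.air - air_q20     # the same alignment works for arithmetic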


Ryan, fantastic example! Thank you so much for helping me ask my real question.

I wonder whether Xarray is awesome and flexible enough to calculate the mean duration of events! An event starts when the temperature drops from above the 40th percentile to below the 20th percentile (with an average decline rate larger than the 5th percentile), and if the declined temperature rises back above the 20th percentile, the event terminates. It should also last at least 3 days. The duration of such events is my question.

I can’t think how to do that off the top of my head, but I’m sure it is possible! Maybe it can be accomplished with some clever use of rolling.

You might also investigate apply_ufunc.

Good luck!
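To give a sense of the apply_ufunc pattern, here is a rough sketch that applies a made-up per-pixel metric (the longest consecutive run of days below the local 20th percentile) along the time dimension; the real flash-drought logic would go in place of longest_cold_run:

import xarray as xr

ds = xr.tutorial.load_dataset("air_temperature")    # same tutorial data as above
air_q20 = ds.air.quantile(0.2, dim="time")          # per-pixel threshold (lat, lon)

def longest_cold_run(temps, threshold):
    """Length of the longest consecutive run below `threshold` in a 1D series."""
    longest = run = 0
    for is_below in temps < threshold:
        run = run + 1 if is_below else 0
        longest = max(longest, run)
    return longest

result = xr.apply_ufunc(
    longest_cold_run,
    ds.air,
    air_q20,
    input_core_dims=[["time"], []],   # the function consumes the time dimension
    vectorize=True,                   # loop over lat/lon automatically (np.vectorize)
    output_dtypes=[int],
)
result.plot()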


Many thanks for your advice to use rolling and apply_ufunc. Before coming to the forum, I actually tried to implement apply_ufunc, but without much success. Now, after a couple of months of using loops, I am quite certain that without Xarray it would be impossible to work with such big climate data!

I would really appreciate your thoughts on using apply_ufunc to solve this problem. And I wonder if you could elaborate a bit more on what a clever use of rolling would look like.

apply_ufunc is an amazing function: I managed to write the analysis over time for one particular location, and then this function did the computation across all longitudes and latitudes! Here is a great example of using apply_ufunc along the time dimension in Xarray:

Hi @rabernat, it seems you are an expert in xarray.

I am building some functionality in Databricks with xarray and meteorological GRIBs. I have experienced super long times (around 16 seconds) doing a selection on a grid covering Western Europe, roughly 50 MB in size.
I am using this command: sp_grib.sel(latitude=node[0], longitude=node[1]).sp.values
Additionally, I have experienced much shorter times for equivalent GRIB files (in size and resolution).

Do you have any clue what is going on? Is this normal?

Thank you very much in advance.

@INAVAS, it looks like perhaps you are trying to extract a single value from a GRIB file? And perhaps you have a collection of these from which you want to extract the same point (e.g. to build a time series)?

If so, check this out: Accessing NetCDF and GRIB file collections as cloud-native virtual datasets using Kerchunk | by Peter Marsh | pangeo | Sep, 2022 | Medium

Thank you for your answer.

Exactly, I extract historical time series for a certain point in the grid from my own GRIB database located in an Azure Storage Container. I am processing these GRIBs from Azure Databricks.

The command sp_grib.sel(latitude=node[0], longitude=node[1]).sp.values takes 8x longer in Databricks than on my own PC. I am trying to figure out why.

Thank you, I will check the article after I send this message.

@INAVAS, the reason is likely that many small metadata transactions take place when you access the GRIB file, which is okay on regular filesystems but very slow when accessing object storage. The approach in the article extracts all the metadata in advance and stores it as a separate object. Subsequent data reads (using the Zarr library to extract the chunks of data from the GRIB files) are then much faster.


Ok.

I see that in the post the author is reading from a cloud dataset that is already loaded. Is there any chance I can read my already-downloaded GRIB files in the same manner?
Just keep reading my files, but in an optimised way using Zarr.

Thank you @rsignell


Yes, you can use kerchunk on local files – I’m not sure the reading will be faster, but you will still benefit from being able to treat collections of files as an aggregated dataset and it will be much easier to modify the attribute metadata for the dataset.
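A rough sketch of what that could look like for local files (the file names, dimension names, and kerchunk options here are assumptions, and the exact API has changed a bit between kerchunk versions):

import fsspec
import xarray as xr
from kerchunk.grib2 import scan_grib
from kerchunk.combine import MultiZarrToZarr

# hypothetical local GRIB files to aggregate
grib_paths = ["sp_2022_01.grib2", "sp_2022_02.grib2"]

# scan each file once; scan_grib returns one reference set per GRIB message group
refs = [r for path in grib_paths for r in scan_grib(path)]

# combine the references into one virtual dataset, concatenated along time
combined = MultiZarrToZarr(
    refs,
    concat_dims=["time"],
    identical_dims=["latitude", "longitude"],
).translate()

# open the combined references as a Zarr store; data chunks are still read from
# the original GRIB files, but all the metadata now comes from `combined`
fs = fsspec.filesystem("reference", fo=combined)
sp_grib = xr.open_zarr(fs.get_mapper(""), consolidated=False)

After that, the same .sel(latitude=..., longitude=...) extraction should no longer re-read the GRIB metadata on every access.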

Could I piggyback with a likely-silly question?

I’m working with the outputs of a large climate model (UKESM), which I’ve compressed into some nice Hovmöller diagrams, so the resulting data is 332 (lat) x 1800 (months). I was struggling to do anything with this data before using the .load() operator, which reads it into memory for me. However, reading the dataset with .load() takes up to a minute, which a) is a small problem because I have 12 of them, and b) surprises me, since 332x1800 really isn’t very big at all. (Trying to do something like a seasonal average without .load() is terribly slow; after .load() it’s reasonably fast, as I would expect.) Is there anything I can do to speed it up?

(Sidenote: I realize that the .pkl format may come with some problems when moving between systems and is probably not best practice - I should switch to netcdf).
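For what it is worth, a minimal sketch of that netCDF switch, assuming the pickled object is an xarray Dataset and reusing the path from the snippet below:

import pickle
import xarray as xr

mapdir = '/gpfs/home/mep22dku/scratch/SOZONE/MEDUSA/BSUB_extractions/EXTRACT/'

# one-time conversion of the pickled object to netCDF
ds = pickle.load(open(f'{mapdir}UKESM_1A_OCN_PCO2_lathovmoller_1950_2100.pkl', 'rb'))
ds.to_netcdf(f'{mapdir}UKESM_1A_OCN_PCO2_lathovmoller_1950_2100.nc')

# from then on, open it lazily with xarray
ds = xr.open_dataset(f'{mapdir}UKESM_1A_OCN_PCO2_lathovmoller_1950_2100.nc')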

import pickle
import time

time1 = time.time()
mapdir = '/gpfs/home/mep22dku/scratch/SOZONE/MEDUSA/BSUB_extractions/EXTRACT/'
UKESM_1A_pco2 = pickle.load(open(f'{mapdir}UKESM_1A_OCN_PCO2_lathovmoller_1950_2100.pkl', 'rb'))
UKESM_1A_pco2
#ukesm_T_SO = ukesm_T_SO.load()
time2 = time.time()
print(time2-time1)

time1 = time.time()
mapdir = '/gpfs/home/mep22dku/scratch/SOZONE/MEDUSA/BSUB_extractions/EXTRACT/'
UKESM_1A_pco2 = pickle.load(open(f'{mapdir}UKESM_1A_OCN_PCO2_lathovmoller_1950_2100.pkl', 'rb'))

UKESM_1A_pco2 = UKESM_1A_pco2.load()
time2 = time.time()
print(time2-time1)
0.04637885093688965
43.691916942596436

I agree that that amount of data should not take that long to load (I’d expect it to be pretty much instantaneous on a normal filesystem).

I’ve never used pickle files to serialize datasets, so I’d need a bit more information to figure out why it takes so long. Can you post the repr of your dataset? (print(UKESM_1A_pco2) or display(UKESM_1A_pco2) in ipython, plus maybe a screenshot of the output if you’re working in a notebook)
