from virtualizarr import open_virtual_dataset
import glob
import xarray as xr
import dask
from distributed import Client
# Start a dask.distributed client (default arguments); the delayed
# open_virtual_dataset calls built later in this script are computed on it.
# NOTE(review): dask.distributed recommends creating the Client under an
# `if __name__ == "__main__":` guard when a script may start worker
# processes — confirm this file is only ever executed directly.
client=Client()

def get_unique_vars(selection):
    """Collect the distinct data-variable names across ``selection``.

    Every element of ``selection`` must expose a ``data_vars`` attribute that
    iterates over variable names. Duplicates are removed; the order of the
    returned list is arbitrary because it is produced from a set.
    """
    seen = {name for dataset in selection for name in dataset.data_vars}
    return list(seen)

def get_time_series_of_single_vars(
    unique_vars, selection, keep_vars=("time", "height", "height_2")
):
    """Build one time-concatenated dataset per variable name.

    For every name in ``unique_vars``, the datasets in ``selection`` that hold
    it as a data variable are reduced to that variable plus ``keep_vars``.
    They are then combined along ``time`` with ``xr.combine_nested``, using
    ``coords="minimal"`` and ``compat="override"``.

    Parameters
    ----------
    unique_vars : iterable of str
        Variable names to extract, one output dataset each.
    selection : list
        Datasets in the order in which they are concatenated along ``time``.
    keep_vars : iterable of str, optional
        Extra variable/coordinate names kept next to each extracted variable.

    Returns
    -------
    list
        One combined dataset per entry of ``unique_vars``, in the same order.

    Raises
    ------
    ValueError
        If a name in ``unique_vars`` is a data variable of none of the datasets.
    """
    keep = set(keep_vars)
    vds = []
    for uv in unique_vars:
        print(uv)
        datasets = [ds for ds in selection if uv in ds.data_vars]
        if not datasets:
            raise ValueError(f"Variable {uv!r} not found in any dataset")
        wanted = keep | {uv}
        # Derive the drop list per dataset: the input files need not all carry
        # the same variables, so a list taken from the first file only could
        # leave unrelated variables in later files.
        drop_lists = [
            [name for name in ds.variables if name not in wanted]
            for ds in datasets
        ]
        print(sorted(set().union(*drop_lists)))
        trimmed = [
            ds.drop_vars(names) for ds, names in zip(datasets, drop_lists)
        ]
        virtual_ds = xr.combine_nested(
            trimmed,
            concat_dim=["time"],
            coords="minimal",
            compat="override",
            fill_value=-8.e+33,
        )
        vds.append(virtual_ds)
    return vds
# --- configuration -----------------------------------------------------------
dim = "3d"  # stream tag in the file names (the script also handles "2d")
time = "P1D"  # frequency tag in the file names
trunk_glob = f"/work/mh0287/m300575/ngc5004/ngc5004_atm_{dim}_{time}*"

# --- open all inputs as virtual datasets, in parallel on the dask client -----
virtual_datasets_all = [
    dask.delayed(open_virtual_dataset)(filepath)
    for filepath in sorted(glob.glob(trunk_glob))
    # The 2d file stamped 20200101T000000Z is excluded.
    # NOTE(review): the reason is not visible here — document it.
    if not (dim == "2d" and "20200101T000000Z.nc" in filepath)
]
# Fail early (before any remote access) instead of with an IndexError later.
if not virtual_datasets_all:
    raise FileNotFoundError(f"No input files match {trunk_glob!r}")
virtual_datasets_all = client.gather(client.compute(virtual_datasets_all))

# --- one time series per variable, then merge them ---------------------------
# healpix and the *_bnds variables are handled separately below.
unique_vars = sorted(
    set(get_unique_vars(virtual_datasets_all))
    - {"healpix", "height_bnds", "height_2_bnds"}
)
vds = get_time_series_of_single_vars(unique_vars, virtual_datasets_all)
virtual_ds = xr.combine_by_coords(vds)

# --- crs coordinate ----------------------------------------------------------
crs = xr.open_zarr(
    "https://eerie.cloud.dkrz.de/datasets/nextgems.ICON.ngc4008.PT15M_4/kerchunk"
)["crs"]
virtual_ds["crs"] = crs
virtual_ds = virtual_ds.assign_coords(crs=[1])
# Set the attributes AFTER assign_coords: assigning a plain list builds a new
# "crs" variable, so attrs set before that call would not survive.
virtual_ds["crs"].attrs = virtual_datasets_all[0]["healpix"].attrs

# --- vertical bounds ---------------------------------------------------------
if "height_bnds" in virtual_datasets_all[0].variables:
    virtual_ds["height_bnds"] = virtual_datasets_all[0]["height_bnds"]
if "height_2_bnds" in virtual_datasets_all[-1].variables:
    virtual_ds["height_2_bnds"] = virtual_datasets_all[-1]["height_2_bnds"]

# --- global attributes (tool/provenance entries from the first file dropped) -
virtual_ds.attrs = {
    key: value
    for key, value in virtual_datasets_all[0].attrs.items()
    if key not in {"CDI", "CDO", "NCO", "cdo_openmp_thread_number", "history"}
}

# --- write the kerchunk references -------------------------------------------
virtual_ds.virtualize.to_kerchunk(
    f"/work/bm1344/DKRZ/kerchunks_batched/ngc5004_atm_{dim}_{time}.parq",
    format="parquet",
)