{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ed52c85",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "from pathlib import Path\n",
    "\n",
    "import intake\n",
    "import numpy as np\n",
    "import xarray as xr\n",
    "\n",
    "import easygems.healpix as egh\n",
    "import easygems.remap as egr\n",
    "\n",
    "\n",
    "# Load C5 input data (3-hourly output; a single dask chunk along \"cell\" so the\n",
    "# whole horizontal grid is available to the remapping kernel at once)\n",
    "cat = intake.open_catalog(\"https://data.nextgems-h2020.eu/catalog.yaml\")\n",
    "ds = cat.ICON.C5.AMIP_CNTL(time=\"PT3H\", chunks={\"cell\": -1}).to_dask().pipe(egh.attach_coords)\n",
    "ds = ds[[\"hus2m\", \"pr\", \"rlds\", \"rsds\", \"sfcwind\", \"tas\"]]\n",
    "\n",
    "# Load ICON target grid; convert cell-center coordinates from radians to degrees\n",
    "# and wrap longitudes into [0, 360) to match the extended source grid below.\n",
    "grid = xr.open_dataset(\"/pool/data/ICON/grids/public/mpim/0054/icon_grid_0054_R02B08_G.nc\")\n",
    "icon_lon, icon_lat = np.degrees(grid.clon) % 360, np.degrees(grid.clat)\n",
    "\n",
    "# Periodically extend longitude to the east and west (prevent interpolation gaps on the date line)\n",
    "lon_periodic = np.hstack((ds.lon - 360, ds.lon, ds.lon + 360))\n",
    "lat_periodic = np.hstack((ds.lat, ds.lat, ds.lat))\n",
    "\n",
    "# Dirty hack to fill holes at the poles: snap near-polar source points onto the\n",
    "# poles so the Delaunay triangulation covers the full latitude range.\n",
    "lat_periodic[lat_periodic > 89.5] = 90.0\n",
    "lat_periodic[lat_periodic < -89.5] = -90.0\n",
    "\n",
    "# Cache the (expensive) Delaunay weights on disk and recompute only when the\n",
    "# cache file is missing; delete the file to force a recomputation.\n",
    "weights_file = Path(\"healpix_weights.nc\")\n",
    "if weights_file.exists():\n",
    "    # Load pre-computed weights\n",
    "    weights = xr.open_dataset(weights_file)\n",
    "else:\n",
    "    # Compute weights\n",
    "    weights = egr.compute_weights_delaunay(\n",
    "        points=(lon_periodic, lat_periodic),\n",
    "        xi=(icon_lon, icon_lat),\n",
    "    )\n",
    "\n",
    "    # Fold source indices from the periodic extension back into their valid range\n",
    "    weights = weights.assign(src_idx=weights.src_idx % ds.lon.size)\n",
    "    weights.to_netcdf(weights_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e15341c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_encoding(dataset):\n",
    "    \"\"\"Build a NetCDF encoding dict for the remapped forcing variables.\n",
    "\n",
    "    Uses Blosc-Zstd compression with chunks of one time step times the full\n",
    "    horizontal grid.  Only variables whose dims are exactly (\"time\", \"ncells\")\n",
    "    get an entry; anything else (dimension coordinates, auxiliary variables)\n",
    "    is skipped so ``to_netcdf`` cannot fail on a chunk-shape mismatch.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        var: {\n",
    "            \"compression\": \"blosc_zstd\",\n",
    "            \"chunksizes\": (1, dataset.sizes[\"ncells\"]),  # (time, ncells)\n",
    "        }\n",
    "        for var in dataset.variables\n",
    "        if dataset[var].dims == (\"time\", \"ncells\")\n",
    "    }\n",
    "\n",
    "# Apply remapping weights to the full dataset.  ``weights`` is a Dataset, i.e.\n",
    "# a mapping of variable name -> array, and is forwarded as the keyword\n",
    "# arguments of ``egr.apply_weights``.\n",
    "ds_remap = xr.apply_ufunc(\n",
    "    egr.apply_weights,\n",
    "    ds,\n",
    "    kwargs=weights,\n",
    "    keep_attrs=True,\n",
    "    input_core_dims=[[\"cell\"]],\n",
    "    output_core_dims=[[\"ncells\"]],\n",
    "    output_dtypes=[\"f4\"],\n",
    "    vectorize=True,\n",
    "    dask=\"parallelized\",\n",
    "    dask_gufunc_kwargs={\n",
    "        \"output_sizes\": {\"ncells\": weights.sizes[\"tgt_idx\"]},\n",
    "    },\n",
    ")\n",
    "\n",
    "# Store output in NetCDF, one month per file (AMIP period 1979-1996).\n",
    "encoding = get_encoding(ds_remap)  # loop-invariant: compute once, not 216 times\n",
    "for y, m in itertools.product(range(1979, 1997), range(1, 13)):\n",
    "    print(y, m)\n",
    "    ds_remap.sel(\n",
    "        time=f\"{y}-{m:02d}\"\n",
    "    ).to_netcdf(\n",
    "        f\"/scratch/m/m300575/jsbach_forcing_{y}{m:02d}.nc\",\n",
    "        encoding=encoding,\n",
    "        mode=\"w\",\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "main",
   "language": "python",
   "name": "main"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}