Skip to content
Snippets Groups Projects
Commit af2ca7b8 authored by Sergei Petrov's avatar Sergei Petrov
Browse files

added function to compute SSDs convergence analysis

parent 9bcf4bdd
No related branches found
No related tags found
No related merge requests found
/arrays
import os
from pathlib import Path
import xarray as xr
import pandas as pd
import numpy as np
from matplotlib.ticker import ScalarFormatter, LogLocator, FuncFormatter
import matplotlib.pyplot as plt
from scipy.stats import linregress
def file_show_year(file_path=''):
file_path = Path(file_path)
file_name = file_path.name
if not file_path.exists():
raise FileNotFoundError(f"File '{file_path}' does not exist! Abort 'extract_year'!")
try:
# Open dataset without loading data variables
with xr.open_dataset(file_path, decode_times=True) as ds:
# Check for time coordinate
if 'time' not in ds.coords:
raise ValueError(f"File '{file_path}' does not contain a 'time' coordinate! Abort 'extract_year'!")
# Access only the first time value
first_time = ds.coords['time'].values[0]
# Convert to pandas datetime and extract the year
first_timestamp = pd.to_datetime(first_time)
return int(first_timestamp.year)
except Exception as e:
raise RuntimeError(f"An error occurred while processing '{file_path}'! Abort 'extract_year'!\n{e}")
def compute_monthly_means(file_path='',
variable='',
resample_op='',
resample_freq='',
savepath=''):
from pathlib import Path
import xarray as xr
file_path = Path(file_path)
savepath = Path(savepath)
if not file_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
# If output already exists, load and return it directly
if savepath.exists():
print(f"📄 Monthly mean already exists, loading: {savepath}")
with xr.open_dataset(savepath) as ds:
return ds[variable]
# Else, compute and save
with xr.open_dataset(file_path) as ds:
if variable not in ds:
raise KeyError(f"Variable '{variable}' not found in {file_path}")
data = ds[variable]
# Resample to desired frequency (e.g., daily mean)
if resample_op == 'mean':
data_resampled = data.resample(time=resample_freq).mean()
elif resample_op == 'sum':
data_resampled = data.resample(time=resample_freq).sum()
else:
raise ValueError(f"Unsupported resampling operator: {resample_op}")
# Compute monthly mean
monthly_means = data_resampled.resample(time='ME').mean()
# Create output directory if needed
savepath.parent.mkdir(parents=True, exist_ok=True)
# Save to NetCDF
monthly_means.to_dataset(name=variable).to_netcdf(savepath)
print(f"💾 Saved monthly mean: {savepath}")
return monthly_means
def extract_monthly_means(sim_to_period={}, # map: sim id -> the desired period [year_start, year_end]
var_to_oper={}, # map var name -> operator for daily resampling (e.g. ['1D', 'mean'], ['1D','sum'])
project_dir='', # the base directory where model outputs are stored
store_dir=''): # the output folder to store monthly-mean netcdfs
sim_to_var_to_year = {
sim: {
var: {
year: None for year in np.arange(sim_to_period[sim][0], sim_to_period[sim][1]+1)
} for var in var_to_oper.keys()
} for sim in sim_to_period.keys()
}
for sim in sim_to_period.keys():
for var in var_to_oper.keys():
for year in np.arange(sim_to_period[sim][0], sim_to_period[sim][1]+1):
# check if output file already exist
out_file = os.path.join(store_dir,f'{sim}_{var}_{year}.nc')
if os.path.isfile(out_file):
# dont need to compute anything, just load the file
print(f'\"{out_file}\" exists, load the monthly netcdf file!')
with xr.open_dataset(out_file) as ds:
sim_to_var_to_year[sim][var][year] = out_file
else:
# check that model output dir exists
current_dir = os.path.join(project_dir, sim, 'post/yearly/',var)
if not os.path.isdir(current_dir):
raise Exception(f'Directory \"{current_dir}\" does NOT exist!')
# find the file for current year
current_file = ''
for file in sorted((os.listdir(current_dir))):
if (file_show_year(file_path=os.path.join(current_dir, file)) == year):
current_file = os.path.join(current_dir, file)
break
if not current_file:
raise Exception(f'Could not find the model output for the year {year}!')
# extract and store the monthly mean values
data = compute_monthly_means(file_path=current_file,
variable=var,
resample_op=var_to_oper[var][1],
resample_freq=var_to_oper[var][0],
savepath=out_file)
sim_to_var_to_year[sim][var][year] = out_file
return sim_to_var_to_year
def compute_and_plot_ssd_loglog(sim_pair=[],
list_of_variables=[],
Nmonth_averaging=[],
sim_to_var_to_year={},
var_to_units={}):
sim1, sim2 = sim_pair
for var in list_of_variables:
monthly_diffs = []
years = sorted(sim_to_var_to_year[sim1][var].keys())
for year in years:
ds1 = xr.open_dataset(sim_to_var_to_year[sim1][var][year])[var]
ds2 = xr.open_dataset(sim_to_var_to_year[sim2][var][year])[var]
diff = ds1 - ds2 # shape [12, lat, lon]
monthly_diffs.append(diff)
diff_all_months = xr.concat(monthly_diffs, dim='time') # shape [T, lat, lon]
n_total_months = diff_all_months.sizes['time']
all_x = []
all_y = []
for K in Nmonth_averaging:
if n_total_months < K:
continue
n_blocks = n_total_months // K
truncated = diff_all_months.isel(time=slice(0, n_blocks * K))
reshaped = truncated.data.reshape((n_blocks, K, *truncated.shape[1:]))
block_means = np.nanmean(reshaped, axis=1) # shape [n_blocks, lat, lon]
ssd_list = np.nanstd(block_means, axis=(1, 2)) # std dev per block over space
# Store values for plotting
all_x.append(np.full_like(ssd_list, K, dtype=float))
all_y.append(ssd_list)
# Flatten
all_x = np.concatenate(all_x)
all_y = np.concatenate(all_y)
# === Plotting ===
plt.figure(figsize=(8, 6))
plt.scatter(all_x, all_y, color='red', marker='x', label=f'{sim1} - {sim2}')
# Log-log linear regression
log_x = np.log(all_x)
log_y = np.log(all_y)
slope, intercept, _, _, _ = linregress(log_x, log_y)
trendline = np.exp(intercept) * all_x ** slope
plt.plot(all_x, trendline, 'r-.', linewidth=2.0)
plt.text(0.14, 0.95, f'log(y) = {slope:.2f} * log(x) + {intercept:.2f}',
transform=plt.gca().transAxes, fontsize=15, verticalalignment='top', color='red')
plt.xscale('log')
plt.yscale('log')
ax = plt.gca()
ax.xaxis.set_major_formatter(ScalarFormatter())
ax.ticklabel_format(style='plain', axis='x')
ax.yaxis.set_major_formatter(ScalarFormatter())
ax.yaxis.set_minor_locator(LogLocator(base=10.0, subs=np.arange(1, 10)*0.1, numticks=10))
ax.yaxis.set_minor_formatter(FuncFormatter(lambda y, _: f'{y:.3}'))
plt.tick_params(axis='both', which='major', length=8, width=1.5)
plt.tick_params(axis='both', which='minor', length=4, width=1)
# Grid
ax.grid(True, which='major', axis='y', linestyle='--', linewidth=1.5)
ax.grid(True, which='minor', axis='y', linestyle=':', linewidth=0.7)
ax.grid(True, which='major', axis='x', linestyle='--', linewidth=1.5)
plt.xticks(Nmonth_averaging, fontsize=10)
plt.xlabel('Number of months in averaging', fontsize=12)
plt.ylabel('SSD for timespan', fontsize=12)
plt.title(f'SSD in log-log scale for "{var}", {var_to_units[var]}', fontsize=20)
plt.legend(loc='upper right')
plt.tight_layout()
plt.show()
Source diff could not be displayed: it is too large. Options to address this: view the blob.
test
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment