Commit 48ddbc33 authored by Aaron Spring

plot_map

parent af35396e
@@ -9,6 +9,7 @@ Check out our examples:
- easy access via `intake-esm` to CMIP5, CMIP6 and MiKlip output on `mistral`
- easy access via `intake` to ICDC observations on `mistral`
- grid handling of `MPIOM` via `xgcm`
- plotting the curvilinear `MPIOM` with `cartopy`: `xr.DataArray.plot_map()`
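
A minimal sketch of the new accessor (the file and variable names below are placeholders):

```python
import xarray as xr
import pymistral  # registers xr.DataArray.plot_map

da = xr.open_dataset('mpiom_data_2d_mm.nc')['sst']  # hypothetical MPIOM field
da.plot_map()  # curvilinear-aware plot, PlateCarree projection by default
```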
# Contact
@@ -41,7 +42,8 @@ HPC system. A detailed explanation is at the DKRZ website: <https://www.dkrz.de/
### pymistral_preload
Place the `jupyter_preload` file into your home directory on `mistral` and
change the conda environment name if necessary.
change the conda environment name if necessary. Alternatively, you can create your own `conda`: <https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html>
Here, don't use your `$HOME` on `mistral`; specify a path like `/work/yourgroup/m??????/miniconda3` instead.
### ./start-pymistral
@@ -25,6 +25,12 @@ dependencies:
- nb_conda_kernels
- ipywidgets
- pip
- cftime=1.0.3.4
- nc-time-axis
- xrviz
- xhistogram
- pip:
- https://github.com/jbusecke/cmip6_preprocessing
- https://github.com/xgcm/xgcm
- git+https://github.com/jbusecke/cmip6_preprocessing.git
- git+https://github.com/xgcm/xgcm.git
- git+https://github.com/xgcm/xrft.git
- git+https://github.com/mathause/regionmask.git
@@ -10,10 +10,5 @@ Available Modules:
"""
from . import setup
from . import plot
from . import variability
from . import hamocc
from . import mpiom
from . import slurm_post
from . import cmip
from . import hamocc, plot, setup, slurm_post
from .setup import cdo
# workaround until this works:
# # The library auto-initializes upon import.
# import pyessv
import glob
import itertools
import os
import cdo
import pandas as pd
import xarray as xr
from .setup import _squeeze_dims, cmip5_folder, my_system, tmp
cdo = cdo.Cdo(tempdir=tmp)
if my_system == 'local':
CV_basefolder = '/Users/aaron.spring/Coding/'
elif my_system == 'mistral':
CV_basefolder = '/home/mpim/m300524/'
travis = False
if os.getcwd().startswith('/home/travis/'): # workaround for travis
travis = True
CV_basefolder = os.getcwd()+'/'
# CMIP6
# read in all institutions
name = 'institution_id'
cvpath = CV_basefolder + 'CMIP6_CVs/CMIP6_' + name + '.json'
institution_ids = pd.read_json(cvpath).index[:-6].drop(
['CV_collection_version', 'CV_collection_modified'])
# read in all models
name = 'source_id'
cvpath = CV_basefolder + 'CMIP6_CVs/CMIP6_' + name + '.json'
model_ids = pd.read_json(cvpath).index[:-6].drop(
['CV_collection_version', 'CV_collection_modified']).values
# read in all sources
name = 'source_id'
cvpath = CV_basefolder + 'CMIP6_CVs/CMIP6_' + name + '.json'
source_ids = pd.read_json(cvpath)
# read in all activities/MIPs
name = 'activity_id'
cvpath = CV_basefolder + 'CMIP6_CVs/CMIP6_' + name + '.json'
mip_ids = pd.read_json(cvpath).index[:-6].drop(
['CV_collection_version', 'CV_collection_modified']).values
mip_table = pd.read_json(cvpath).drop(
['CV_collection_version', 'CV_collection_modified'])
mip_longnames = pd.read_json(cvpath)['activity_id'][:-6].drop(
['CV_collection_version', 'CV_collection_modified']).values
# read in experiments
name = 'experiment_id'
cvpath = CV_basefolder + 'CMIP6_CVs/CMIP6_' + name + '.json'
experiment_ids = pd.read_json(cvpath).index.drop(
['CV_collection_modified', 'CV_collection_version', 'author']).values
# wrappers using the above
def CMIP6_CV_model_participations(model):
"""Returns MIPs a model participates in.
Args:
model (str): model from model_ids
Returns:
(list) of MIPs a model participates in.
Example:
CMIP6_CV_model_participations('MPI-ESM1-2-HR')
"""
s = source_ids.loc[model].values[0]
return s['activity_participation']
def participation_of_models(mip):
"""Return a list of all CMIP6 models participating in a MIP.
Args:
mip (str): MIP from mip_ids
Example:
participation_of_models('C4MIP')
participation_of_models('DCPP')
"""
mip_models = []
for model in model_ids:
if mip in CMIP6_CV_model_participations(model):
mip_models.append(model)
return mip_models
# CMIP5 on mistral
if not travis:
cmip5_centers_mistral = os.listdir(cmip5_folder)
cmip5_models_mistral = {}
for center in cmip5_centers_mistral:
models = os.listdir('/'.join((cmip5_folder, center)))
cmip5_models_mistral[center] = models
cmip5_all_models_mistral = list(
itertools.chain.from_iterable(cmip5_models_mistral.values()))
def _get_path_cmip(base_folder=cmip5_folder,
model='MPI-ESM-LR',
center='MPI-M',
exp='historical',
period='mon',
varname='tos',
comp='ocean',
run_id='r1i1p1',
ending='.nc',
timestr='*',
**kwargs):
try:
path_v = sorted(
glob.glob('/'.join([
base_folder, center, model, exp, period, comp,
comp[0].upper() + period, run_id, 'v????????'
])))[-1]
return path_v + '/' + varname + '/' + '_'.join([
varname, comp[0].upper() + period, model, exp, run_id, timestr
]) + ending
except IndexError:  # no 'v????????' version directory found
return '/'.join([
base_folder, center, model, exp, period, comp,
comp[0].upper() + period, run_id
])
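# Example (hedged): with the defaults above, _get_path_cmip() resolves to a
# glob pattern like
#   <cmip5_folder>/MPI-M/MPI-ESM-LR/historical/mon/ocean/Omon/r1i1p1/
#       v????????/tos/tos_Omon_MPI-ESM-LR_historical_r1i1p1_*.nc
# picking the latest 'v????????' version directory if one exists, and falling
# back to the run directory otherwise.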
# wrapper to check which data is available
def find_cmip5_output(**kwargs):
"""Find available CMIP5 output on mistral. Returns model and center list."""
output_models = []
output_centers = []
for center in cmip5_centers_mistral:
for model in cmip5_models_mistral[center]:
filestr = _get_path_cmip(model=model, center=center, **kwargs)
if glob.glob(filestr) != []:
# print(model,center,'exists')
output_models.append(model)
output_centers.append(center)
print(len(output_models), 'models found')
return output_centers, output_models
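# Usage sketch (hedged; assumes CMIP5 output is reachable under cmip5_folder):
#   centers, models = find_cmip5_output(varname='tos', exp='historical')
# returns parallel lists of centers and models with matching files.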
# TODO: adapt for CMIP6, maybe with CMIP=5 arg
def load_cmip(base_folder=cmip5_folder,
model='MPI-ESM-LR',
center='MPI-M',
exp='historical',
period='mon',
varname='tos',
comp='ocean',
run_id='r1i1p1',
ending='.nc',
timestr='*',
operator='',
select=''):
"""Load a variable from CMIP5."""
ncfiles_cmip = _get_path_cmip(
base_folder=cmip5_folder,
model=model,
center=center,
exp=exp,
period=period,
varname=varname,
comp=comp,
run_id=run_id,
ending=ending,
timestr=timestr)
nfiles = len(glob.glob(ncfiles_cmip))
if nfiles == 0:
raise ValueError('no files found in ' + ncfiles_cmip)
# # TODO: check all args for reasonable inputs, check path exists explicitly
print('Load', nfiles, 'files from:', ncfiles_cmip)
if operator != '':
print('preprocessing: cdo', operator, ncfiles_cmip)
return xr.open_dataset(
cdo.addc(
'0',
input=operator + ' -select,name=' + varname + select + ' ' +
ncfiles_cmip,
options='-r')).squeeze()[varname]
else:
print('xr.open_mfdataset(' + ncfiles_cmip + ')[' + varname + ']')
return xr.open_mfdataset(ncfiles_cmip, concat_dim='time')[varname]
def load_cmip5_from_center_model_list(center_list=['MPI-M', 'NCAR'],
model_list=['MPI-ESM-LR', 'CCSM4'],
**cmip_kwargs):
data = []
for center, model in zip(center_list, model_list):
print('Load', center, model)
data.append(load_cmip(center=center, model=model, **cmip_kwargs))
data = xr.concat(data, 'model')
data['model'] = model_list
return data
def get_center_for_cmip5_model(model):
"""Get center name for a CMIP5 model based on CMIP5 centers and models found on mistral."""
for center in cmip5_centers_mistral:
if model in cmip5_models_mistral[center]:
return center
def load_cmip5_from_model_list(model_list=['MPI-ESM-LR', 'CCSM4'],
**cmip_kwargs):
"""Load CMIP5 output from mistral based on model_list.
experiment_id, variables, ... to be specified in **cmip_kwargs."""
data = []
ml = model_list.copy()
for model in model_list:
center = get_center_for_cmip5_model(model)
print(center, model)
filestr = _get_path_cmip(model=model, center=center, **cmip_kwargs)
if glob.glob(filestr) != []:
new = load_cmip(center=center, model=model, **cmip_kwargs)
new = _squeeze_dims(new)
data.append(new)
else:
print('not found', filestr)
ml.remove(model)
try:
data = xr.concat(data, 'model')
data['model'] = ml
except Exception:
print('concat failed: returning list of datasets')
return data
def load_cmip5_many_varnames(varnamelist=['tos', 'sos'], **cmip_kwargs):
"""Load many variables from varnamelist from CMIP5 output from mistral.
experiment_id, model_ids, ... to be specified in **cmip_kwargs."""
data = []
for varname in varnamelist:
print('Load', varname)
data.append(load_cmip(varname=varname, **cmip_kwargs))
data = xr.merge(data)
return data
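# Usage sketch (hedged; assumes the data exists on mistral):
#   tos = load_cmip(varname='tos', model='MPI-ESM-LR', center='MPI-M')
#   ds = load_cmip5_many_varnames(varnamelist=['tos', 'sos'], model='MPI-ESM-LR')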
@@ -49,3 +49,16 @@ def temfa_phofa(ds):
temfa = .6 * 1.066**(ds['tsw'] - 273.15)
phofa = ds['soflwac'] * 0.02
return temfa * phofa / (np.sqrt(phofa**2 + temfa**2))
r_ppmw2ppmv = 28.8 / 44.0095
CO2_to_C = 44.0095 / 12.0111
def convert_C(ds):
"""Converts CO2 from ppmw to ppmv and co2_flux to C."""
if 'CO2' in ds.data_vars:
ds = ds * 1e6 * r_ppmw2ppmv
if 'co2_fl' in ds.data_vars:
ds = ds / CO2_to_C
return ds
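# Worked example: r_ppmw2ppmv is the molar-mass ratio M_air / M_CO2
# (28.8 / 44.0095), so ds * 1e6 * r_ppmw2ppmv turns a CO2 mass mixing ratio
# into ppmv; CO2_to_C = 44.0095 / 12.0111 rescales a CO2 flux to carbon units.
#   ds = convert_C(xr.open_dataset('echam6_co2_mm.nc'))  # hypothetical file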
import xarray as xr
def standardize(ds, time_dim='year', index=True):
return (ds - ds.mean(time_dim)) / ds.std(time_dim)
mask_folder = '/work/mh0727/m300524/experiments/masks'
def calc_enso_index(ds, type='12', index=True, time_dim='time'):
if type == '12':
enso_weights = xr.open_dataset(
mask_folder +
'/GR15_lon_-90--80_lat_-10-0.weights.nc')['area'].squeeze()
del enso_weights['depth']
del enso_weights['time']
sst = ds['tos']
sst_clim = sst.groupby('time.month').mean(dim='time')
sst_anom = sst.groupby('time.month') - sst_clim
sst_anom_nino = (sst_anom * enso_weights).sum(['y', 'x'])
if index:
return standardize(sst_anom_nino, time_dim=time_dim)
else:
return sst_anom_nino
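# Usage sketch (hedged; requires the GR15 weights file under mask_folder):
#   nino12 = calc_enso_index(ds)             # standardized Nino1+2 index
#   anom = calc_enso_index(ds, index=False)  # raw area-weighted anomalies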
@@ -7,38 +7,109 @@ from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter
from matplotlib.ticker import MaxNLocator
def my_plot(data,
projection=ccrs.PlateCarree(),
coastline_color='gray',
curv=False,
**kwargs):
"""Wrap xr.plot with cartopy."""
plt.figure(figsize=(10, 5))
if curv:
data = _rm_singul_lon(data)
ax = plt.subplot(projection=projection)
data.plot.pcolormesh(
'lon', 'lat', ax=ax, transform=ccrs.PlateCarree(), **kwargs)
# data.plot.contourf('lon', 'lat', ax=ax,
# transform=ccrs.PlateCarree(), **kwargs)
ax.coastlines(color=coastline_color, linewidth=1.5)
if curv:
ax.add_feature(cp.feature.LAND, zorder=100, edgecolor='k')
if projection == ccrs.PlateCarree():
_set_lon_lat_axis(ax, projection)
def _rm_singul_lon(ds):
@xr.register_dataarray_accessor('plot_map')
class CartopyMap(object):
"""
Plot a 2D xr.DataArray on a cartopy GeoAxes, with ('xc', 'lon', 'longitude') taken as longitude and ('yc', 'lat', 'latitude') as latitude.
The default projection is PlateCarree; any other
cartopy.crs.<ProjectionName>() can be passed via `proj`.
If you would like to create a figure with multiple subplots,
you can pass an axes object to the function with the keyword argument `ax`,
BUT then you need to specify the projection when you create the axes:
plt.axes([x0, y0, w, h], projection=cartopy.crs.<ProjectionName>())
Additional keywords can be given to the function as you would to
the xr.DataArray.plot function. The only difference is that `robust`
is set to True by default.
The function returns the xarray plot object; features can be added to each
of its (Geo)Axes with:
ax.add_feature(cartopy.feature.<FeatureName>, **kwargs)
"""
def __init__(self, xarray_obj):
self._obj = xarray_obj
def __call__(self, ax=None, proj=ccrs.PlateCarree(), plot_lon_lat_axis=True, feature='land', plot_type='pcolormesh', **kwargs):
return self._cartopy(ax=ax, proj=proj, feature=feature, plot_lon_lat_axis=plot_lon_lat_axis, plot_type=plot_type, **kwargs)
def _cartopy(self, ax=None, proj=ccrs.PlateCarree(), feature='land', plot_lon_lat_axis=True, plot_type='pcolormesh', **kwargs):
xda = self._obj
# da, convert to da or error
if not isinstance(xda, xr.DataArray):
if len(xda.data_vars) == 1:
xda = xda[list(xda.data_vars)[0]]
else:
raise ValueError(
f'Please provide xr.DataArray, found {type(xda)}')
stereo_maps = (ccrs.Stereographic,
ccrs.NorthPolarStereo,
ccrs.SouthPolarStereo)
if isinstance(proj, stereo_maps):
raise ValueError(
'Not implemented, see https://github.com/luke-gregor/xarray_tools/blob/master/accessors.py#L222')
# find whether the grid is curvilinear and detect lon/lat coordinate names
curv = any(len(xda[c].dims) == 2 for c in xda.coords)
lon = [c for c in xda.coords if c in ['xc', 'lon', 'longitude']][0]
lat = [c for c in xda.coords if c in ['yc', 'lat', 'latitude']][0]
if isinstance(proj, ccrs.Robinson):
plot_lon_lat_axis = False
assert xda.ndim == 2 or (xda.ndim == 3 and ('col' in kwargs or 'row' in kwargs)) or (xda.ndim == 4 and 'col' in kwargs and 'row' in kwargs)
if curv:
xda = _rm_singul_lon(xda, lon=lon, lat=lat)
coastline_color = kwargs.pop('coastline_color', 'gray')
if 'robust' not in kwargs:
kwargs['robust'] = True
if 'cbar_kwargs' not in kwargs:
kwargs['cbar_kwargs'] = {'shrink': .6}
if ax is None:
axm = getattr(xda.plot, plot_type)(
lon, lat, subplot_kws={'projection': proj}, transform=ccrs.PlateCarree(), **kwargs)
else:
axm = getattr(xda.plot, plot_type)(
lon, lat, ax=ax, transform=ccrs.PlateCarree(), **kwargs)
# a FacetGrid exposes an array of axes; a single plot exposes one Axes
axes_list = axm.axes.flat if hasattr(axm.axes, 'flat') else [axm.axes]
for axes in axes_list:
axes.coastlines(color=coastline_color, linewidth=1.5)
if feature is not None:
axes.add_feature(getattr(cp.feature, feature.upper()),
zorder=100, edgecolor='k')
if plot_lon_lat_axis:
_set_lon_lat_axis(axes, proj)
return axm
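# Usage sketch for the accessor (hedged; file and variable names are placeholders):
#   import cartopy.crs as ccrs
#   import pymistral.plot  # registers xr.DataArray.plot_map
#   da = xr.open_dataset('mpiom_data_2d_mm.nc')['sst']
#   da.plot_map(proj=ccrs.Robinson(), feature='land', plot_type='contourf')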
def _rm_singul_lon(ds, lon='lon', lat='lat'):
"""Remove singularity from coordinates.
http://nbviewer.jupyter.org/gist/pelson/79cf31ef324774c97ae7
"""
lons = ds['lon'].values
lons = ds[lon].values
fixed_lons = lons.copy()
for i, start in enumerate(np.argmax(np.abs(np.diff(lons)) > 180, axis=1)):
fixed_lons[i, start + 1:] += 360
lons_da = xr.DataArray(fixed_lons, ds.lat.coords)
ds = ds.assign_coords(lon=lons_da)
lons_da = xr.DataArray(fixed_lons, ds[lat].coords)
ds = ds.assign_coords({lon: lons_da})
return ds
import glob
import os
import cdo
import numpy as np
import pandas as pd
import xarray as xr
from tqdm import tqdm_notebook
# relies on WORK and GROUP being exported in your bashrc
user = os.environ['LOGNAME']
my_system = None
cdo_mistral = True
group = 'mh0727'
try:
my_system = None
host = os.environ['HOSTNAME']
user = os.environ['USER']
work = os.environ['WORK']
group = os.environ['GROUP']
assert group == os.environ['GROUP']
for node in ['mlogin', 'mistralpp']:
if node in host:
my_system = 'mistral'
@@ -24,76 +22,40 @@ try:
except (KeyError, AssertionError):
my_system = 'local'
# setup folders for working on mistral
if my_system == 'mistral':
mistral_work = '/work/'
file_origin = mistral_work
work = mistral_work + group + '/' + user + '/'
tmp = file_origin + 'tmp'
work = f'{mistral_work}{group}/{user}/'
tmp = work + 'tmp'
if not os.path.exists(tmp):
os.makedirs(tmp)
# setup folder for working via sshfs_mistral in ~/mistral_work
elif my_system == 'local':
mistral_work = '/Users/aaron.spring/mistral_work/'
work = mistral_work + 'mh0727/m300524/' # group + '/' + user + '/'
file_origin = work
mistral_work = os.path.expanduser('~/mistral_work/')
work = f'{mistral_work}{group}/{user}/'
cdo_mistral = True
if cdo_mistral:
tmp = os.path.expanduser('~/tmp')
else:
tmp = file_origin + 'tmp'
tmp = work + 'tmp'
if not os.path.exists(tmp):
os.makedirs(tmp)
# start
cdo = cdo.Cdo(tempdir=tmp)
# TODO: load all cmip cmorized varnames?
# Make sure you have created the work/group/mxxxxxx/tmp dir
sample_file_dir = work + 'experiments/sample_files/'
# hamocc_data_2d_varnamelist = cdo.showname(
# input=sample_file_dir + 'hamocc_data_2d_*')[0].split()
# echam6_co2_varnamelist = cdo.showname(
# input=sample_file_dir + 'echam6_co2*')[0].split()
# mpiom_data_2d_varnamelist = cdo.showname(
# input=sample_file_dir + 'mpiom_data_2d_*')[0].split()
PM_path = file_origin + 'experiments/'
GE_path = file_origin + 'experiments/GE/'
# cmip5_folder = '/work/ik0555/cmip5/archive/CMIP5/output'
cmip6_folder = mistral_work+'ik1017/CMIP6/data/CMIP6'
cmip5_folder = mistral_work+'kd0956/CMIP5/data/cmip5/output1'
my_GE_path = file_origin + '160701_Grand_Ensemble/'
GE_post = my_GE_path + 'postprocessed/'
PM_post = PM_path + 'postprocessed/'
GE_folder = mistral_work+'mh1007'
def read_table_file(table_file_str):
"""Read partab/.codes file."""
table_file = pd.read_fwf(
table_file_str,
header=None,
names=[
'code', 'a', 'varname', 'b', 'c', 'long_name_and_unit', 'd', 'e'
])
table_file.index = table_file['code']
for a in 'abcde':
del table_file[a]
table_file['novarname'] = 'var' + table_file['code'].apply(str)
table_file['unit'] = table_file['long_name_and_unit'].str.split(
'[', expand=True).get(1)
table_file['long_name'] = table_file['long_name_and_unit'].str.split(
'[', expand=True).get(0)
table_file['unit'] = table_file['unit'].str.replace(']', '')
del table_file['long_name_and_unit']
return table_file
def set_table(ds, table_file_str):
"""Replace variables in ds with table."""
table = read_table_file(table_file_str)
table_dict = {}
for i in table.index:
key, item = table[['novarname', 'varname']].loc[i]
table_dict[key] = item
return ds.rename(table_dict)
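# Usage sketch (hedged; the .codes path is a placeholder):
#   ds = xr.open_dataset('expid_echam6_co2.nc')
#   ds = set_table(ds, sample_file_dir + 'log/expid_echam6_co2.codes')
# renames generic 'var<code>' variables to the varnames from the partab file.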
def remap_cdo(da):
if not isinstance(da, xr.core.dataset.Dataset):
da = da.to_dataset()
remap = cdo.remapbil(
'r360x180', input=da, returnXDataset=True, options='-P 8')
return remap
def _decode_ym_cftime_to_int(ds):
@@ -118,397 +80,18 @@ def _squeeze_dims(ds):
return ds.squeeze()
def _set_LY(ds, first=1):
"""Set integer time index starting with first."""
ds = ds.assign(time=np.arange(first, first + ds.time.size))
return ds
def _get_path(varname=None, exp='PM', prefix='ds', ta='ym', **kwargs):
"""Get postprocessed path."""
if exp == 'PM':
path = PM_path + 'postprocessed/'
elif exp == 'GE':
path = my_GE_path + 'postprocessed/'
suffix = ''
if prefix not in ['ds', 'control']:
for key, value in kwargs.items():
if prefix in ['skill']:
if str(key) in ['sig', 'bootstrap']:
continue
suffix += "_" + key + "_" + str(value)
filename = prefix + '_' + varname + '_' + ta + suffix + '.nc'
full_path = path + filename
return full_path
def save(ds, varname=None, exp='PM', prefix='ds', ta='ym', **kwargs):
"""Save xr.object to _get_path location."""
full_path = _get_path(
varname=varname, exp=exp, prefix=prefix, ta=ta, **kwargs)
print('save in:', full_path)
ds.to_netcdf(full_path)
def _set_mm_span(ds):
"""Set monthly mean time axis.
# TODO: make possible for year 2300 or 1100.
Starts in 1900 because of cftime limit."""
span = pd.date_range(start='1/1/1900', periods=ds.time.size, freq='M')
return ds.assign(time=span)
def yearmonmean(ds):
return ds.groupby('time.year').mean('time').rename({'year': 'time'})
def yearsum(ds):
return ds.groupby('time.year').sum('time').rename({'year': 'time'})
r_ppmw2ppmv = 28.8 / 44.0095
CO2_to_C = 44.0095 / 12.0111
def convert_C(ds):
"""Converts CO2 from ppmw to ppmv and co2_flux to C."""
if 'CO2' in ds.data_vars:
ds = ds * 1e6 * r_ppmw2ppmv
if 'co2_fl' in ds.data_vars:
ds = ds / CO2_to_C
return ds
def _get_codes_str(file_type):
return sample_file_dir + 'log/*' + file_type + '.codes'
def _get_GE_path(ext='hist',
m=1,
model='hamocc',
outdatatype='data_2d_mm',
timestr='*',
ending='.nc'):
return GE_path + ext + '/' + ext + str(m).zfill(
4) + '/outdata/' + model + '/' + ext + str(m).zfill(
4) + '_' + model + '_' + outdatatype + '_' + timestr + ending
def _get_PM_path(init=3014,
m=0,
model='hamocc',
outdatatype='data_2d_mm',
timestr='*',
ending='.nc',
control=False):
if control:
run_id = 'vga0214'
else:
run_id = 'asp_esmControl_ens' + str(init) + '_m' + str(m).zfill(3)
return PM_path + run_id + '/outdata/' + model + '/' + run_id + '_' + model + '_' + outdatatype + '_' + timestr + ending
def _get_GE_full_path(ext=['hist', 'rcp26'],
m=1,
model='hamocc',
outdatatype='data_2d_mm',
timestr='*',
ending='.nc'):
path_list = []
for ext in ext:
path_list.append(
_get_GE_path(
ext=ext,
m=m,
model=model,
outdatatype=outdatatype,
timestr=timestr,
ending=ending))
return ' '.join(path_list)
def _agg_over_time(file_str,
varnamelist,
options='',
cdo_op=' -yearmonmean ',
levelstr=''):
"""Aggregate files along time dimension. Converts to netcdf. Optional cdo
operator applicable."""
varnstr = ','.join(varnamelist)
return cdo.addc(
0,
input=cdo_op + ' -select,name=' + varnstr + levelstr + ' ' + file_str,
options='-r ' + options)
def load(varnamelist=['tos'],
exp='PM',
cdo_op='-yearmonmean ',
model='hamocc',
outdatatype='data_2d_mm',
ending='.nc',
levelstr='',
**kwargs):
"""Load variable. """
if exp == 'PM':
file_str = _get_PM_path(
model=model, outdatatype=outdatatype, ending=ending, **kwargs)
elif exp == 'GE':
file_str = _get_GE_path(
model=model, outdatatype=outdatatype, ending=ending, **kwargs)
else:
raise ValueError('unknown exp: ' + str(exp))
if ending == '.grb':
if model == 'echam6':
if outdatatype in ['co2_mm', 'co2']:
codes = _get_codes_str('echam6_co2')
options = '-f nc -t ' + codes
elif outdatatype in ['tracer', 'tracer_mm']:
codes = _get_codes_str('echam6_tracer')
options = '-f nc -t ' + codes
elif outdatatype == 'BOT_mm':
options = '-f nc -t echam6'
else:
raise ValueError('outdatatype not specified yet!')
else:
raise ValueError('model not specified yet!')
else:
options = ''
loaded = _agg_over_time(
file_str,
varnamelist,
options=options,
cdo_op=cdo_op,
levelstr=levelstr)
return loaded
def _load_PM(mmin=0,
mmax=9,
initlist=[3014],
varnamelist=['tos'],
curv=False,
exp='PM',
drop_none=False,
cdo_op='-yearmonmean ',
model='hamocc',
outdatatype='data_2d_mm',
ending='.nc',
levelstr='',
**kwargs):
if curv:
chunks = {'time': 21, 'x': 256, 'y': 220}
else:
chunks = {'time': 21, 'lat': 96, 'lon': 192}
dslist = []
for init in tqdm_notebook(
initlist, desc='initialization loop', leave=False):
many_member_ds = xr.concat([
xr.open_mfdataset(
load(
varnamelist=varnamelist,
exp=exp,
init=init,
m=m,
outdatatype=outdatatype,
model=model,
cdo_op=cdo_op,
levelstr=levelstr,
ending=ending,
**kwargs),
decode_times=False,
chunks=chunks,
preprocess=_squeeze_dims) for m in np.arange(mmin, mmax + 1)
],
dim='member')
many_member_ds = many_member_ds.assign(
member=np.arange(mmin, mmax + 1))
many_member_ds = _set_LY(many_member_ds)
dslist.append(many_member_ds)
ds = xr.concat(dslist, dim='initialization')
ds = ds.assign(initialization=initlist)
print(ds.nbytes / 1e9, 'GB')
print(ds.dims)
return ds
def _load_GE(memberlist=['rcp26', 'rcp45', 'rcp85'],
initlist=[1, 2, 3, 4, 5],
varnamelist=['sst'],
curv=False,
exp='GE',
drop_none=False,
cdo_op='-yearmonmean ',
model='mpiom',
outdatatype='data_2d_mm',
ending='.nc',
levelstr='',
**kwargs):
if curv:
chunks = {'time': 21, 'x': 256, 'y': 220}
else:
chunks = {'time': 21, 'lat': 96, 'lon': 192}
dslist = []
for m in tqdm_notebook(initlist, desc='initialization loop', leave=False):
many_rcp_ds = xr.concat([
xr.open_mfdataset(
load(
varnamelist=varnamelist,
exp=exp,
ext=rcp,
m=m,
outdatatype=outdatatype,
model=model,
cdo_op=cdo_op,
levelstr=levelstr,
ending=ending,
**kwargs),
decode_times=False,
chunks=chunks,
preprocess=_squeeze_dims) for rcp in memberlist
],
dim='member')
many_rcp_ds = many_rcp_ds.assign(member=memberlist)
many_rcp_ds = _set_LY(many_rcp_ds)
dslist.append(many_rcp_ds)
ds = xr.concat(dslist, dim='initialization')
ds = ds.assign(initialization=initlist)
print(ds.nbytes / 1e9, 'GB')
print(ds.dims)
return ds
def postprocess_PM(varnames,
initlist=[3014, 3023],
model='mpiom',
outdatatype='data_2d_mm',
levelstr='',
timestr='*',
ending='.nc',
curv=True):
"""Create lead year timeseries for perfect-model experiment.
Args:
varnames (list): variables to postprocess.
initlist (list): initialization years. Defaults to [3014, 3023].
model (str): model name. Defaults to 'mpiom'.
outdatatype (str): output stream. Defaults to 'data_2d_mm'.
levelstr (str): cdo level-selection string. Defaults to ''.
timestr (str): time wildcard. Defaults to '*'.
ending (str): file ending. Defaults to '.nc'.
curv (bool): curvilinear grid. Defaults to True.
Returns:
None. Saves monthly-mean (mm) and yearly-mean (ym) output.
"""
cdo_op = ' ' # ' -yearmonmean '
for control in [False, True]:
for varname in varnames:
print(varname, 'control =', control)
if control:
ds = load(
varnamelist=[varname],
cdo_op=cdo_op,
model=model,
outdatatype=outdatatype,
levelstr=levelstr,
timestr=timestr,
ending=ending,
control=True)
ds = _squeeze_dims(xr.open_dataset(ds))
ds = convert_C(ds)
else:
ds = _load_PM(
varnamelist=[varname],
initlist=initlist,
cdo_op=cdo_op,
model=model,
outdatatype=outdatatype,
levelstr=levelstr,
timestr=timestr,
curv=curv,
ending=ending)
ds = convert_C(ds)
ds = _set_mm_span(ds)
# save mm
ta = 'mm'
save(ds, exp='PM', varname=varname, control=control, ta=ta)
# save ym
if control:
ds = yearmonmean(ds)
# ds = _set_LY(ds, first=3000)
pass
else:
ds = _set_LY(yearmonmean(ds))
ta = 'ym'
save(ds, exp='PM', varname=varname, control=control, ta=ta)
def postprocess_GE(varnames,
memberlist=['rcp26', 'rcp45', 'rcp85'],
initlist=[1, 2, 3, 4, 5],
model='mpiom',
outdatatype='data_2d_mm',
levelstr='',
timestr='*',
ending='.nc',
curv=True):
"""Create lead year timeseries for Grand Ensemble experiment of a list of varnames for list of extensions and members.
def _set_LY(ds, first=1, dim='lead'):
"""Set integer lead index starting with first."""
return ds.assign({dim: np.arange(first, first + ds[dim].size)})
Args:
varnames (list): variables to postprocess.
memberlist (list): RCP extensions treated as members. Defaults to ['rcp26', 'rcp45', 'rcp85'].
initlist (list): ensemble member indices. Defaults to [1, 2, 3, 4, 5].
model (str): model name. Defaults to 'mpiom'.
outdatatype (str): output stream. Defaults to 'data_2d_mm'.
levelstr (str): cdo level-selection string. Defaults to ''.
timestr (str): time wildcard. Defaults to '*'.
ending (str): file ending. Defaults to '.nc'.
curv (bool): curvilinear grid. Defaults to True.
Returns:
nothing
def yearmean(ds, dim='time'):
return ds.groupby(f'{dim}.year').mean(dim).rename({'year': dim})
Saves:
ds (xr.Dataset): yearly-mean (ym) lead year timeseries ('time','member','ensemble').
"""
cdo_op = ' -yearmonmean '
control = False
for varname in varnames:
ds = _load_GE(
varnamelist=[varname],
memberlist=memberlist,
initlist=initlist,
cdo_op=cdo_op,
model=model,
outdatatype=outdatatype,
levelstr=levelstr,
timestr=timestr,
curv=curv,
ending=ending)
ds = convert_C(ds)
ds = _set_mm_span(ds)
# save mm
# ta = 'mm'
# save(ds, exp='GE', name=varname, control=control, ta=ta)
# save ym
# ds = _set_LY(yearmonmean(ds))
ta = 'ym'
save(ds, exp='GE', varname=varname, control=control, ta=ta)
def yearsum(ds, dim='time'):
return ds.groupby(f'{dim}.year').sum(dim).rename({'year': dim})
def merge_monitoring(exp):
"""Merge all monitoring files of an experiment into one file."""
pass
def standardize(ds, dim='time'):
return (ds-ds.mean(dim))/ds.std(dim)
import re
import cdo
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sb
import xarray as xr
from scipy.signal import detrend, periodogram, tukey
from scipy.stats import chi2, pearsonr
def Sef2014_Fig3_ACF(control,
varnamelist,
area='Tropical_Pacific',
period='ym'):
"""Plot persistence as Autocorrelation
function (ACF) from control simulation.
Reference
---------
- Séférian, Roland, Laurent Bopp, Marion Gehlen, Didier Swingedouw,
Juliette Mignot, Eric Guilyardi, and Jérôme Servonnat. “Multiyear
Predictability of Tropical Marine Productivity.” Proceedings of the National
Academy of Sciences 111, no. 32 (August 12, 2014): 11646–51.
https://doi.org/10/f6cgs3.
Parameters
----------
control : Dataset with year dimension
Input data
varnamelist : list
variables to be included
area : str
area of interest
period : str
period of interest
"""
plt.figure(figsize=(6, 4))
df = control.sel(
area=area, period=period)[varnamelist].to_dataframe()[varnamelist]
cmap = sb.color_palette("husl", len(varnamelist))
for i, var in enumerate(df.columns):
pd.plotting.autocorrelation_plot(df[var], label=var, color=cmap[i])
plt.xlim([1, 20])
plt.ylim([-.5, 1])
plt.ylabel('Persistence (ACF)')
plt.xlabel('Lag [year]')
plt.legend(ncol=2, loc='lower center')
plt.title((' ').join(('Autocorrelation function', area, period)))
def power_spectrum_markov(control,
varname,
unit=''):
fig, ax = plt.subplots(figsize=(10, 4))
s = control.to_dataframe()[varname]
P, power_spectrum, markov, low_ci, high_ci = create_power_spectrum(s)
plot_power_spectrum_markov(
P,
power_spectrum,
markov,
low_ci,
high_ci,
color='k',
ax=ax,
unit=unit)
def plot_power_spectrum_markov(P,
power_spectrum,
markov,
low_ci,
high_ci,
ax=None,
legend=False,
plot_ci=True,
unit='',
**kwargs):
if ax is None:
ax = plt.gca()
ax.plot(P, power_spectrum, **kwargs)
ax.plot(
P,
markov, # label='theoretical Markov spectrum',
alpha=.5,
ls='--')
if plot_ci:
ax.plot(P, low_ci, c='gray', alpha=.5, linestyle='--')
ax.plot(P, high_ci, c='gray', alpha=.5, ls='--')
ax.set_xlabel('Period [yr]')
ax.set_ylabel('Power [(' + unit + ')$^2$]')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlim([2, 200])
if legend:
ax.legend()
ax.set_title('Power spectrum')
def _get_pcs(anom,
neofs=15,
pcscaling=0,
curv=True):
from eofs.xarray import Eof
def get_anom(df):
return df - df.mean('time')
coslat = np.cos(np.deg2rad(anom.coords['lat'].values))
wgts = np.sqrt(coslat)[..., np.newaxis]
if curv:
wgts = None
else:
coslat = np.cos(np.deg2rad(anom.coords['lat'].values))
wgts = np.sqrt(coslat)[..., np.newaxis]
solver = Eof(anom, weights=wgts)
eofs = solver.eofsAsCorrelation(neofs=neofs)
# eofcov = solver.eofsAsCovariance(neofs=neofs)
pcs = solver.pcs(npcs=neofs, pcscaling=pcscaling)
eofs['mode'] = np.arange(1, eofs.mode.size+1)
pcs['mode'] = np.arange(1, pcs.mode.size+1)
return eofs, pcs
def _get_max_peak_period(P, power_spectrum, high_ci):
significant_peaks = power_spectrum.where(power_spectrum > high_ci)
max_period = significant_peaks.argmax()
return P[max_period]
def Sef2013_Fig4_power_spectrum_pcs(control3d,
neofs=5,
plot_eofs=True,
curv=True,
palette='Set2',
print_peak=True
):
eofs, pcs = _get_pcs(
control3d, pcscaling=1, neofs=neofs)
cmap = sb.color_palette(palette, pcs.mode.size)
if plot_eofs:
eofs.plot(col='mode', robust=True, yincrease=not curv)
plt.show()
pcs.to_dataframe().unstack().plot(color=cmap, figsize=(10, 4))
fig, ax = plt.subplots(figsize=(10, 4))
for i, mode in enumerate(pcs.mode):
P, power_spectrum, markov, low_ci, high_ci = create_power_spectrum(
pcs.sel(mode=mode).to_series())
plot_power_spectrum_markov(
P,
power_spectrum,
markov,
low_ci,
high_ci,
legend=True,
plot_ci=False,
ax=ax,
color=cmap[i],
label='PC' + str(int(mode)))
x = _get_max_peak_period(P, power_spectrum, high_ci)
if print_peak:
print('PC' + str(int(mode)),
'max peak at',
'{0:.2f}'.format(x),
'years.')
ax.axvline(
x=x,
c=cmap[i],
ls=':')
def corr_plot_2var(control,
varx='fgco2',
vary='po4os',
area='90S-35S',
period='ym'):
"""Plot the correlation between two variables."""
g = sb.jointplot(
x=varx,
y=vary,
data=control.sel(area=area, period=period).to_dataframe(),
kind='reg')
g.annotate(pearsonr)
def corrfunc(x, y, **kws):
"""Corr for corr_pairgrid."""
r, p = pearsonr(x, y)
ax = plt.gca()
ax.annotate(
"r = {:.2f}, p = {:.5f}".format(r, p),
xy=(.1, .9),
xycoords=ax.transAxes)
def corr_pairgrid(control,
varnamelist=['tos', 'sos', 'AMO'],
area='90S-35S',
period='ym'):
"""Plot pairgrid of variables from varnamelist."""
g = sb.PairGrid(
control.sel(area=area, period=period).to_dataframe()[varnamelist],
palette=["red"])
g.map_upper(plt.scatter, s=10)
g.map_diag(sb.distplot, kde=False)
g.map_lower(sb.kdeplot, cmap="Blues_d")
g.map_lower(corrfunc)
g.fig.subplots_adjust(top=0.9)
g.fig.suptitle((' ').join(('Correlations', area, period)))
def show_wavelets(control,
unit='',
cxmax=None):
s = control.to_series()
if cxmax is None:
cxmax = s.var() * 8
title = ' ' # (' ').join((varname, area, period))
label = ' ' # varname + ' ' + area
import pycwt as wavelet
from pycwt.helpers import find
from scipy.signal import detrend
dt = 1
N = s.size
t = s.index
dat_notrend = detrend(s.values)
std = dat_notrend.std() # Standard deviation
var = std**2 # Variance
dat_norm = dat_notrend / std # Normalized dataset
mother = wavelet.Morlet(6)
s0 = 2 * dt  # Starting scale; with dt = 1 year this is 2 years
dj = 1 / 12 # Twelve sub-octaves per octaves
J = 7 / dj # Seven powers of two with dj sub-octaves
# Lag-1 autocorrelation for red noise
alpha, _, _ = wavelet.ar1(dat_notrend)
wave, scales, freqs, coi, fft, fftfreqs = wavelet.cwt(
dat_norm, dt, dj, s0, J, mother)
iwave = wavelet.icwt(wave, scales, dt, dj, mother) * std
power = (np.abs(wave))**2
fft_power = np.abs(fft)**2
period = 1 / freqs
signif, fft_theor = wavelet.significance(
1.0, dt, scales, 0, alpha, significance_level=0.95, wavelet=mother)
sig95 = np.ones([1, N]) * signif[:, None]
sig95 = power / sig95
glbl_power = power.mean(axis=1)
dof = N - scales # Correction for padding at edges
glbl_signif, tmp = wavelet.significance(
var,
dt,
scales,
1,
alpha,
significance_level=0.95,
dof=dof,
wavelet=mother)
sel = find((period >= 2) & (period < 8))
Cdelta = mother.cdelta
scale_avg = (scales * np.ones((N, 1))).transpose()
# As in Torrence and Compo (1998) equation 24
scale_avg = power / scale_avg
scale_avg = var * dj * dt / Cdelta * scale_avg[sel, :].sum(axis=0)
scale_avg_signif, tmp = wavelet.significance(
var,
dt,
scales,
2,
alpha,
significance_level=0.95,
dof=[scales[sel[0]], scales[sel[-1]]],
wavelet=mother)
figprops = dict(figsize=(11, 8), dpi=72)
fig = plt.figure(**figprops)
ax = plt.axes([0.1, 0.75, 0.65, 0.2])
ax.plot(t, iwave, '-', linewidth=1, color=[0.5, 0.5, 0.5])
ax.plot(t, dat_notrend, 'k', linewidth=1.5)
ax.set_title('a) {}'.format(title))
ax.set_ylabel(r'{} [{}]'.format(label, unit))
bx = plt.axes([0.1, 0.37, 0.65, 0.28], sharex=ax)
levels = [0.0625, 0.125, 0.25, 0.5, 1, 2, 4, 8, 16]
bx.contourf(
t,
np.log2(period),
np.log2(power),
np.log2(levels),
extend='both',
cmap=plt.cm.viridis)
extent = [t.min(), t.max(), 0, max(period)]
bx.contour(
t,
np.log2(period),
sig95, [-99, 1],
colors='k',
linewidths=2,
extent=extent)
bx.fill(
np.concatenate([t, t[-1:] + dt, t[-1:] + dt, t[:1] - dt, t[:1] - dt]),
np.concatenate([
np.log2(coi), [1e-9],
np.log2(period[-1:]),
np.log2(period[-1:]), [1e-9]
]),
'k',
alpha=0.3,
hatch='x')
bx.set_title('b) {} Wavelet Power Spectrum ({})'.format(
label, mother.name))
bx.set_ylabel('Period (years)')
#
Yticks = 2**np.arange(
np.ceil(np.log2(period.min())), np.ceil(np.log2(period.max())))
bx.set_yticks(np.log2(Yticks))
bx.set_yticklabels(Yticks)
cx = plt.axes([0.77, 0.37, 0.2, 0.28], sharey=bx)
cx.plot(glbl_signif, np.log2(period), 'k--')
cx.plot(var * fft_theor, np.log2(period), '--', color='#cccccc')
cx.plot(
var * fft_power,
np.log2(1. / fftfreqs),
'-',
color='#cccccc',
linewidth=1.)
cx.plot(var * glbl_power, np.log2(period), 'k-', linewidth=1.5)
cx.set_title('c) Global Wavelet Spectrum')
cx.set_xlabel(r'Power [({})^2]'.format(unit))
#cx.set_xlim([0, glbl_power.max() + var])
cx.set_xlim([0, cxmax])
cx.set_ylim(np.log2([period.min(), period.max()]))
cx.set_yticks(np.log2(Yticks))
cx.set_yticklabels(Yticks)
plt.setp(cx.get_yticklabels(), visible=False)
dx = plt.axes([0.1, 0.07, 0.65, 0.2], sharex=ax)
dx.axhline(scale_avg_signif, color='k', linestyle='--', linewidth=1.)
dx.plot(t, scale_avg, 'k-', linewidth=1.5)
dx.set_title('d) {}--{} year scale-averaged power'.format(2, 8))
dx.set_xlabel('Time (year)')
dx.set_ylabel(r'Average variance [{}]'.format(unit))
ax.set_xlim([t.min(), t.max()])
def remap(da):
if not isinstance(da, xr.core.dataset.Dataset):
da = da.to_dataset()
remap = cdo.remapbil(
'r360x180', input=da, returnXDataset=True, options='-P 8')
return remap
def plot_Hovmoeller(control,
varname,
lats=[-35, -30, -25, -20, -15],
mean_dim='lat',
latstep=5,
remap='cdo',
**kwargs):
if remap == 'cdo':
remap = cdo.remapbil(
'r360x180',
input=control.to_dataset(),
returnXArray=varname,
options='-P 8')
elif remap == 'xesmf':
remap = 0  # TODO: xesmf regridding not implemented yet
fig, ax = plt.subplots(
ncols=len(lats), sharey=True, figsize=(5 * len(lats), 10))
for i, slat in enumerate(lats):
nlat = slat + latstep
if slat == lats[-1]:
colorbar = True
else:
colorbar = False
remap.sel(
lon=slice(150, 280), lat=slice(slat, nlat)).mean(mean_dim).plot(
ax=ax[i],
cmap='RdBu_r',
levels=11,
add_colorbar=colorbar,
**kwargs)
ax[i].set_title(str(slat) + '-' + str(nlat))
plt.tight_layout()
# s='GR15_lon_-150--120_lat_-10--35.mask.nc'
def _area_str_2_lons_lats(area):
lons = re.search('_lon_(.*)_lat', area).group(1)
lats = re.search('_lat_(.*).[mw]', area).group(1)
if lats[0] == '-':
lats = lats[1:]
if lons[0] == '-':
lons = lons[1:]
if '--' in lons:
lons = lons.replace('--', '-')
lonl, lonr = lons.split('-')
lonr = '-' + lonr
else:
lonl, lonr = lons.split('-')
if '--' in lats:
lats = lats.replace('--', '-')
latl, latr = lats.split('-')
latr = '-' + latr
else:
latl, latr = lats.split('-')
orilon = re.search('_lon_(.*)_lat', area).group(1)
if orilon[0] == '-':
lonl = '-' + lonl
orilat = re.search('_lat_(.*).[mw]', area).group(1)
if orilat[0] == '-':
latl = '-' + latl
return int(lonl), int(lonr), int(latl), int(latr)
def _lons_lats_2_area(lonl, lonr, latl, latr, grid='GR15', mask_weight='mask'):
return '_'.join(
(grid, 'lon', str(lonl))) + '-' + str(lonr) + '_' + 'lat' + '_' + str(
latl) + '-' + str(latr) + '.' + mask_weight + '.nc'
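# Worked example (round trip):
#   _area_str_2_lons_lats('GR15_lon_-150--120_lat_-10--35.mask.nc')
#     -> (-150, -120, -10, -35)
#   _lons_lats_2_area(-150, -120, -10, -35)
#     -> 'GR15_lon_-150--120_lat_-10--35.mask.nc'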
def _taper(x, p):
"""
Description needed here.
"""
window = tukey(len(x), p)
y = x * window
return y
def create_power_spectrum(s, pct=0.1, pLow=0.05):
"""
Create power spectrum with CI for a given pd.series.
Reference
---------
- /ncl-6.4.0-gccsys/lib/ncarg/nclscripts/csm/shea_util.ncl
Parameters
----------
s : pd.series
input time series
pct : float (default 0.10)
percent of the time series to be tapered. (0 <= pct <= 1). If pct = 0,
no tapering will be done. If pct = 1, the whole series is tapered.
Tapering should always be done.
pLow : float (default 0.05)
significance interval for markov red-noise spectrum
Returns
-------
p : np.ndarray
period
Pxx_den : np.ndarray
power spectrum
markov : np.ndarray
theoretical markov red noise spectrum
low_ci : np.ndarray
lower confidence interval
high_ci : np.ndarray
upper confidence interval
"""
# A value of 0.10 is common (tapering should always be done).
jave = 1  # smoothing; NOTE: values other than 1 do not work here
tapcf = 0.5 * (128 - 93 * pct) / (8 - 5 * pct)**2
wgts = np.linspace(1., 1., jave)
sdof = 2 / (tapcf * np.sum(wgts**2))
pHigh = 1 - pLow
data = s - s.mean()
# detrend
data = detrend(data)
data = _taper(data, pct)
# periodigram
timestep = 1
frequency, power_spectrum = periodogram(data, timestep)
Period = 1 / frequency
power_spectrum_smoothed = pd.Series(power_spectrum).rolling(jave, 1).mean()
# markov theo red noise spectrum
twopi = 2. * np.pi
r = s.autocorr()
temp = r * 2. * np.cos(twopi * frequency) # vector
mkov = 1. / (1 + r**2 - temp) # Markov model
sum1 = np.sum(mkov)
sum2 = np.sum(power_spectrum_smoothed)
scale = sum2 / sum1
xLow = chi2.ppf(pLow, sdof) / sdof
xHigh = chi2.ppf(pHigh, sdof) / sdof
# output
markov = mkov * scale # theor Markov spectrum
low_ci = markov * xLow # confidence
high_ci = markov * xHigh # interval
return Period, power_spectrum_smoothed, markov, low_ci, high_ci
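# Usage sketch (hedged):
#   s = control['tos'].to_series()
#   P, psd, markov, lo_ci, hi_ci = create_power_spectrum(s)
#   plot_power_spectrum_markov(P, psd, markov, lo_ci, hi_ci,
#                              ax=plt.gca(), color='k', unit='degC')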
def create_composites(anomaly_field, timeseries, threshold=1, dim='time'):
index_comp = xr.full_like(timeseries, 'none', dtype='U4')
index_comp[timeseries >= threshold] = 'pos'
index_comp[timeseries <= -threshold] = 'neg'
composite = anomaly_field.groupby(index_comp.rename('index')).mean(dim=dim)
return composite
def standardize(ds, dim='time'):
return (ds-ds.mean(dim))/ds.std(dim)
def composite_analysis(field, timeseries, threshold=1, plot=True, **plot_kwargs):
index = standardize(timeseries)
field = field - field.mean('time')
composite = create_composites(field, index, threshold=threshold)
if plot:
composite.sel(index='pos').plot(**plot_kwargs)
plt.show()
composite.sel(index='neg').plot(**plot_kwargs)
else:
return composite
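# Usage sketch (hedged): composite anomaly fields on a standardized index,
# keeping only times beyond +/- 1 standard deviation:
#   comp = composite_analysis(ds['tos'], index_timeseries, threshold=1, plot=False)
#   diff = comp.sel(index='pos') - comp.sel(index='neg')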