From 5cce74d61082f7e69874cd1a37e27e857d73666a Mon Sep 17 00:00:00 2001
From: AS <aaron.spring@mpimet.mpg.de>
Date: Mon, 15 Jun 2020 11:46:50 +0200
Subject: [PATCH] plot stereo print

---
 pymistral/plot.py            | 44 ++++++++++++++++++++++++++++--------
 pymistral/tests/test_plot.py | 14 ++++++++++++
 2 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/pymistral/plot.py b/pymistral/plot.py
index cf16e15..a291a09 100644
--- a/pymistral/plot.py
+++ b/pymistral/plot.py
@@ -34,7 +34,7 @@ class CartopyMap(object):
         self,
         ax=None,
         proj=None,
-        plot_lon_lat_axis=True,
+        draw_lon_lat_labels=True,
         feature=None,
         plot_type='pcolormesh',
         rm_cyclic=True,
@@ -44,7 +44,7 @@ class CartopyMap(object):
             ax=ax,
             proj=proj,
             feature=feature,
-            plot_lon_lat_axis=plot_lon_lat_axis,
+            draw_lon_lat_labels=draw_lon_lat_labels,
             plot_type=plot_type,
             rm_cyclic=rm_cyclic,
             **kwargs,
@@ -55,7 +55,7 @@ class CartopyMap(object):
         ax=None,
         proj=ccrs.PlateCarree(),
         feature='land',
-        plot_lon_lat_axis=True,
+        draw_lon_lat_labels=True,
         plot_type='pcolormesh',
         rm_cyclic=True,
         **kwargs,
@@ -68,6 +68,22 @@ class CartopyMap(object):
         if isinstance(proj, str):
             proj = eval(f'ccrs.{proj}()')
 
+        stereo_maps = (
+            ccrs.Stereographic,
+            ccrs.NorthPolarStereo,
+            ccrs.SouthPolarStereo,
+        )
+        if isinstance(proj, stereo_maps):
+            round = True
+            import matplotlib.path as mpath
+
+            theta = np.linspace(0, 2 * np.pi, 100)
+            center, radius = [0.5, 0.5], 0.5
+            verts = np.vstack([np.sin(theta), np.cos(theta)]).T
+            circle = mpath.Path(verts * radius + center)
+        else:
+            round = False
+
         xda = self._obj
         # da, convert to da or error
         if not isinstance(xda, xr.DataArray):
@@ -99,12 +115,12 @@ class CartopyMap(object):
         assert lon, (lon, xda.coords)
         assert lat, (lat, xda.coords)
 
-        if plot_lon_lat_axis:
+        if draw_lon_lat_labels:
             draw_labels = True
         else:
             draw_labels = False
         if proj != ccrs.PlateCarree():
-            plot_lon_lat_axis = False
+            draw_lon_lat_labels = False
 
         if plot_type == 'contourf':
             rm_cyclic = False
@@ -117,6 +133,13 @@ class CartopyMap(object):
             kwargs['cbar_kwargs'] = {'shrink': 0.6}
 
         if ax is None:
+            if round:
+                print(
+                    'use stereo maps and facet plots by:',
+                    ' fig,ax=plt.subplots(subplot_kw='
+                    f"{'projection':ccrs.NorthPolarStereo()}\n"
+                    'd.isel(time=0).plot_map(proj=proj,ax=ax)',
+                )
             if single_plot:
                 axm = getattr(xda.plot, plot_type)(
                     lon,
@@ -134,6 +157,8 @@ class CartopyMap(object):
                     **kwargs,
                 )
         else:
+            if round:
+                ax.set_boundary(circle, transform=ax.transAxes)
             axm = getattr(xda.plot, plot_type)(
                 lon, lat, ax=ax, transform=ccrs.PlateCarree(), **kwargs
             )
@@ -151,14 +176,15 @@ class CartopyMap(object):
                     edgecolor='k',
                 )
 
-            if plot_lon_lat_axis:
+            if draw_lon_lat_labels:
                 _set_lon_lat_axis(axes, proj)
             else:
                 gl = axes.gridlines(draw_labels=draw_labels)
                 gl.top_labels = False
                 gl.right_labels = False
-                gl.xlines = False
-                gl.ylines = False
+                if proj not in stereo_maps:
+                    gl.xlines = False
+                    gl.ylines = False
 
         if single_plot:
             if ax is None:
@@ -178,7 +204,7 @@ def _rm_singul_lon(ds, lon='lon', lat='lat'):
     lons = ds[lon].values
     fixed_lons = lons.copy()
     for i, start in enumerate(np.argmax(np.abs(np.diff(lons)) > 180, axis=1)):
-        fixed_lons[i, start + 1:] += 360
+        fixed_lons[i, start + 1 :] += 360
     lons_da = xr.DataArray(fixed_lons, ds[lat].coords)
     ds = ds.assign_coords({lon: lons_da})
     return ds
diff --git a/pymistral/tests/test_plot.py b/pymistral/tests/test_plot.py
index 74947f2..89ee805 100644
--- a/pymistral/tests/test_plot.py
+++ b/pymistral/tests/test_plot.py
@@ -1,6 +1,7 @@
 import pytest
 import pymistral
 import cartopy.crs as ccrs
+import matplotlib.pyplot as plt
 
 projections = [ccrs.PlateCarree(), ccrs.Robinson()]
 plot_types = ['pcolormesh', 'contourf']
@@ -50,3 +51,16 @@ def test_plot_map_facet(da, rc, plot_type, proj):
 @pytest.mark.parametrize('projstr', [None, 'Robinson', 'AlbersEqualArea'])
 def test_plot_map_projstr(da, projstr):
     da.isel(time=0).plot_map(proj=projstr)
+
+
+@pytest.mark.parametrize(
+    'da',
+    [
+        pytest.lazy_fixture('rasm_da'),
+        pytest.lazy_fixture('air_temperature_da'),
+    ],
+)
+@pytest.mark.parametrize('projstr', ['NorthPolarStereo'])
+def test_plot_map_stereo(da, projstr):
+    fig, ax = plt.subplots(subplot_kw={'projection': ccrs.NorthPolarStereo()})
+    da.isel(time=0).plot_map(proj=projstr, ax=ax)
-- 
GitLab