From 9229f9a0fbb813d68cd47363b34f5ae867a30a7e Mon Sep 17 00:00:00 2001
From: AS <aaron.spring@mpimet.mpg.de>
Date: Sun, 14 Jun 2020 18:12:32 +0200
Subject: [PATCH] draw_lon_lat_labels

---
 ci/pymistral.yml  |  2 +-
 pymistral/plot.py | 23 +++++++++++------------
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/ci/pymistral.yml b/ci/pymistral.yml
index 7399ca3..63435d5 100644
--- a/ci/pymistral.yml
+++ b/ci/pymistral.yml
@@ -42,4 +42,4 @@ dependencies:
     - pre-commit
     - pytest-lazy-fixture
     - pytest-tldr
-    + git+https://github.com/dask/dask-labextension.git
+    - git+https://github.com/dask/dask-labextension.git
diff --git a/pymistral/plot.py b/pymistral/plot.py
index fdff3cb..cf16e15 100644
--- a/pymistral/plot.py
+++ b/pymistral/plot.py
@@ -85,17 +85,6 @@ class CartopyMap(object):
         ), (xda.ndim, xda.dims)
         single_plot = True if xda.ndim == 2 else False
 
-        stereo_maps = (
-            ccrs.Stereographic,
-            ccrs.NorthPolarStereo,
-            ccrs.SouthPolarStereo,
-        )
-        if isinstance(proj, stereo_maps):
-            raise ValueError(
-                'Not implemented, see'
-                'https://github.com/luke-gregor/xarray_tools/accessors.py#L222'
-            )
-
         # find whether curv or not
         curv = False
         lon = None
@@ -110,6 +99,10 @@ class CartopyMap(object):
         assert lon, (lon, xda.coords)
         assert lat, (lat, xda.coords)
 
+        if plot_lon_lat_axis:
+            draw_labels = True
+        else:
+            draw_labels = False
         if proj != ccrs.PlateCarree():
             plot_lon_lat_axis = False
 
@@ -160,6 +153,12 @@ class CartopyMap(object):
 
             if plot_lon_lat_axis:
                 _set_lon_lat_axis(axes, proj)
+            else:
+                gl = axes.gridlines(draw_labels=draw_labels)
+                gl.top_labels = False
+                gl.right_labels = False
+                gl.xlines = False
+                gl.ylines = False
 
         if single_plot:
             if ax is None:
@@ -179,7 +178,7 @@ def _rm_singul_lon(ds, lon='lon', lat='lat'):
     lons = ds[lon].values
     fixed_lons = lons.copy()
     for i, start in enumerate(np.argmax(np.abs(np.diff(lons)) > 180, axis=1)):
-        fixed_lons[i, start + 1 :] += 360
+        fixed_lons[i, start + 1:] += 360
     lons_da = xr.DataArray(fixed_lons, ds[lat].coords)
     ds = ds.assign_coords({lon: lons_da})
     return ds
-- 
GitLab