Commit fd44103a authored by Martin Bergemann

Increase test coverage

parent 585d89e4
Merge request !10: Increase test coverage
Pipeline #19349 failed
@@ -44,9 +44,7 @@ def _save_dataset(
     logger.debug("Saving file to %s", str(file_name))
     try:
         dset.to_netcdf(
-            file_name,
-            engine=engine,
-            encoding=encoding,
+            file_name, engine=engine, encoding=encoding,
         )
     except Exception as error:
         logger.error("Saving to file failed: %s", str(error))
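The reformatted call above is where the per-variable encoding computed by _rechunk_dataset ends up. As a rough, self-contained sketch of what such an encoding mapping does (the file name, data, and chunk sizes below are made up, and the netcdf4 engine is assumed to be installed):

    import numpy as np
    import xarray as xr

    # Per-variable encoding controls how the data is written to disk,
    # e.g. on-disk chunk sizes and compression.
    dset = xr.Dataset({"tas": (("time", "lat"), np.zeros((100, 50)))})
    encoding = {"tas": {"chunksizes": (10, 50), "zlib": True}}
    dset.to_netcdf("example.nc", engine="netcdf4", encoding=encoding)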
@@ -55,7 +53,6 @@
 def _rechunk_dataset(dset: xr.Dataset) -> Tuple[xr.Dataset, Dict[str, Any]]:
     encoding: Dict[str, Dict[str, Any]] = {}
     for var in map(str, dset.data_vars):
-        skip_var: bool = False
         if not isinstance(dset[var].data, Array):
             logger.debug("Skipping rechunking variable %s", var)
             continue
@@ -66,27 +63,18 @@ def _rechunk_dataset(dset: xr.Dataset) -> Tuple[xr.Dataset, Dict[str, Any]]:
                 chunks[i] = None
             else:
                 chunks[i] = "auto"
-        try:
-            old_chunks = dset[var].encoding.get("chunksizes")
-            new_chunks = dset[var].data.rechunk(chunks).chunksize
-            if new_chunks == old_chunks:
-                logger.debug("%s: chunk sizes already optimized, skipping", var)
-                skip_var = True
-                continue
-            dset[var] = dset[var].chunk(dict(zip(dset[var].dims, new_chunks)))
-            logger.debug(
-                "%s: old chunk size: %s, new chunk size: %s",
-                var,
-                old_chunks,
-                new_chunks,
-            )
-        except Exception as error:
-            logger.warning("Could not set chunk size for %s: %s", var, str(error))
-            continue
-        if not skip_var:
-            logger.debug("Setting encoding of variable %s", var)
-            encoding[var] = {str(k): v for k, v in dset[var].encoding.items()}
-            encoding[var]["chunksizes"] = new_chunks
+        old_chunks = dset[var].encoding.get("chunksizes")
+        new_chunks = dset[var].data.rechunk(chunks).chunksize
+        if new_chunks == old_chunks:
+            logger.debug("%s: chunk sizes already optimized, skipping", var)
+            continue
+        dset[var] = dset[var].chunk(dict(zip(dset[var].dims, new_chunks)))
+        logger.debug(
+            "%s: old chunk size: %s, new chunk size: %s", var, old_chunks, new_chunks,
+        )
+        logger.debug("Setting encoding of variable %s", var)
+        encoding[var] = {str(k): v for k, v in dset[var].encoding.items()}
+        encoding[var]["chunksizes"] = new_chunks
     return dset, encoding
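The rewritten loop drops the try/except and the skip_var bookkeeping: it asks dask for an "auto" chunking proposal, skips the variable if the proposal matches the chunksizes already stored in its encoding, and otherwise rechunks and records the new sizes. A minimal sketch of that heuristic, with a made-up array shape:

    import dask.array as da

    # "auto" lets dask pick chunk sizes close to the configured
    # array.chunk-size target; comparing before and after tells us
    # whether a rechunk is actually needed.
    arr = da.zeros((720, 96, 192), chunks=(1, 96, 192))
    proposed = arr.rechunk("auto").chunksize
    if proposed == arr.chunksize:
        print("chunk sizes already optimized, skipping")
    else:
        print("rechunking", arr.chunksize, "->", proposed)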
@@ -102,7 +90,7 @@ def rechunk_dataset(dset: xr.Dataset) -> xr.Dataset:
     -------
     xarray.Dataset: rechunked dataset
     """
-    data, _ = _rechunk_dataset(dset)
+    data, _ = _rechunk_dataset(dset.chunk())
     return data
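The added dset.chunk() is what makes the public wrapper useful on plain NumPy-backed datasets: _rechunk_dataset skips any variable whose .data is not a dask Array, and .chunk() wraps every variable in one first. A small illustration with made-up data:

    import numpy as np
    import xarray as xr

    dset = xr.Dataset({"tas": (("time", "lat"), np.zeros((365, 96)))})
    print(type(dset["tas"].data).__name__)          # ndarray -> would be skipped
    print(type(dset.chunk()["tas"].data).__name__)  # Array   -> can be rechunked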
@@ -31,6 +31,15 @@ def small_chunk() -> Generator[Tuple[int, int, int, int], None, None]:
     yield (1, 1, 24, 24)


+@pytest.fixture(scope="session")
+def empty_data() -> Generator[xr.Dataset, None, None]:
+    """Create an empty dataset."""
+    yield xr.Dataset(
+        {"tas_bnds": xr.DataArray(["hallo"], name="tas_bnds", dims=("lon",))}
+    )
+
+
 @pytest.fixture(scope="session")
 def large_chunk() -> Generator[Tuple[int, int, int, int], None, None]:
     """Define tuple for larger chunk sizes."""
"""Unit tests for rechunking the data.""" """Unit tests for rechunking the data."""
import dask
import xarray as xr import xarray as xr
from rechunk_data._rechunk import _rechunk_dataset from rechunk_data._rechunk import _rechunk_dataset
@@ -16,6 +16,13 @@ def test_rechunking_small_data(
     assert encoding == {}


+def test_empty_dataset(empty_data):
+    """Test handling of empty data."""
+    with dask.config.set({"array.chunk-size": "0.01b"}):
+        dset, encoding = _rechunk_dataset(empty_data)
+    assert encoding == {}
+
+
 def test_rechunking_large_data(
     large_chunk_data: xr.Dataset, variable_name: str
 ) -> None:
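The dask.config.set context manager in the new test temporarily shrinks the size target that "auto" rechunking aims for; since the empty_data fixture holds a NumPy string array rather than a dask array, the variable is skipped either way and the returned encoding stays empty. The config mechanism itself can be seen in isolation (the sizes below are arbitrary):

    import dask
    import dask.array as da

    arr = da.zeros((1000, 1000))              # 8 MB of float64
    print(arr.rechunk("auto").chunksize)      # one big chunk under the default target
    with dask.config.set({"array.chunk-size": "8KiB"}):
        print(arr.rechunk("auto").chunksize)  # forced down to tiny chunks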