diff --git a/src/rechunk_data/_rechunk.py b/src/rechunk_data/_rechunk.py
index f591d4b4880bfa5563def6f93990aa2fe4d5a544..eef9ed368ffb91e021586381114dfebd98a96f3e 100644
--- a/src/rechunk_data/_rechunk.py
+++ b/src/rechunk_data/_rechunk.py
@@ -44,9 +44,7 @@ def _save_dataset(
     logger.debug("Saving file ot %s", str(file_name))
     try:
         dset.to_netcdf(
-            file_name,
-            engine=engine,
-            encoding=encoding,
+            file_name, engine=engine, encoding=encoding,
         )
     except Exception as error:
         logger.error("Saving to file failed: %s", str(error))
@@ -55,7 +53,6 @@
 def _rechunk_dataset(dset: xr.Dataset) -> Tuple[xr.Dataset, Dict[str, Any]]:
     encoding: Dict[str, Dict[str, Any]] = {}
     for var in map(str, dset.data_vars):
-        skip_var: bool = False
         if not isinstance(dset[var].data, Array):
             logger.debug("Skipping rechunking variable %s", var)
             continue
@@ -66,27 +63,18 @@ def _rechunk_dataset(dset: xr.Dataset) -> Tuple[xr.Dataset, Dict[str, Any]]:
                 chunks[i] = None
             else:
                 chunks[i] = "auto"
-        try:
-            old_chunks = dset[var].encoding.get("chunksizes")
-            new_chunks = dset[var].data.rechunk(chunks).chunksize
-            if new_chunks == old_chunks:
-                logger.debug("%s: chunk sizes already optimized, skipping", var)
-                skip_var = True
-                continue
-            dset[var] = dset[var].chunk(dict(zip(dset[var].dims, new_chunks)))
-            logger.debug(
-                "%s: old chunk size: %s, new chunk size: %s",
-                var,
-                old_chunks,
-                new_chunks,
-            )
-        except Exception as error:
-            logger.warning("Could not set chunk size for %s: %s", var, str(error))
+        old_chunks = dset[var].encoding.get("chunksizes")
+        new_chunks = dset[var].data.rechunk(chunks).chunksize
+        if new_chunks == old_chunks:
+            logger.debug("%s: chunk sizes already optimized, skipping", var)
             continue
-        if not skip_var:
-            logger.debug("Settings encoding of variable %s", var)
-            encoding[var] = {str(k): v for k, v in dset[var].encoding.items()}
-            encoding[var]["chunksizes"] = new_chunks
+        dset[var] = dset[var].chunk(dict(zip(dset[var].dims, new_chunks)))
+        logger.debug(
+            "%s: old chunk size: %s, new chunk size: %s", var, old_chunks, new_chunks,
+        )
+        logger.debug("Setting encoding of variable %s", var)
+        encoding[var] = {str(k): v for k, v in dset[var].encoding.items()}
+        encoding[var]["chunksizes"] = new_chunks
     return dset, encoding
 
 
@@ -102,7 +90,7 @@ def rechunk_dataset(dset: xr.Dataset) -> xr.Dataset:
     -------
     xarray.Dataset: rechunked dataset
     """
-    data, _ = _rechunk_dataset(dset)
+    data, _ = _rechunk_dataset(dset.chunk())
     return data
 
 
diff --git a/src/rechunk_data/tests/conftest.py b/src/rechunk_data/tests/conftest.py
index def679d948ba3a16d1083eb8e9584668a038c5c2..02d03b10aca9659dbdd861cd4953fa2eac8f1325
--- a/src/rechunk_data/tests/conftest.py
+++ b/src/rechunk_data/tests/conftest.py
@@ -31,6 +31,15 @@ def small_chunk() -> Generator[Tuple[int, int, int, int], None, None]:
     yield (1, 1, 24, 24)
 
 
+@pytest.fixture(scope="session")
+def empty_data() -> Generator[xr.Dataset, None, None]:
+    """Create an empty dataset."""
+
+    yield xr.Dataset(
+        {"tas_bnds": xr.DataArray(["hallo"], name="tas_bnds", dims=("lon",))}
+    )
+
+
 @pytest.fixture(scope="session")
 def large_chunk() -> Generator[Tuple[int, int, int, int], None, None]:
     """Define tuple for smaller chunks sizes."""
diff --git a/src/rechunk_data/tests/test_rechunking.py b/src/rechunk_data/tests/test_rechunking.py
index a8b3d7c6d9791eec1b0c1536a7ec4eb1974a8223..994a6a39b92b6606ca95744eafabae06971a0f5d
--- a/src/rechunk_data/tests/test_rechunking.py
+++ b/src/rechunk_data/tests/test_rechunking.py
@@ -1,5 +1,5 @@
 """Unit tests for rechunking the data."""
-
+import dask
 import xarray as xr
 from rechunk_data._rechunk import _rechunk_dataset
 
@@ -16,6 +16,13 @@ def test_rechunking_small_data(
     assert encoding == {}
 
 
+def test_empty_dataset(empty_data):
+    """Test handling of empty data."""
+    with dask.config.set({"array.chunk-size": "0.01b"}):
+        dset, encoding = _rechunk_dataset(empty_data)
+        assert encoding == {}
+
+
 def test_rechunking_large_data(
     large_chunk_data: xr.Dataset, variable_name: str
 ) -> None: