diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 3c99f16ef7b7c4affeff8875925f55dfb9b227ae..835e3839ed6fcef89284f51b73c91f645511b223 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -58,9 +58,9 @@ pages:
     - /tmp/test/bin/python -m pip install -e .[test]
   script:
     - /tmp/test/bin/coverage run -m pytest
-    - /tmp/test/bin/coverage report
     - /tmp/test/bin/coverage html
     - /tmp/test/bin/coverage xml
+    - /tmp/test/bin/coverage report
   coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
   after_script:
     - mv htmlcov public
diff --git a/mypy.ini b/mypy.ini
index 9e7903c490334e1d738efb95eb733083aa60e644..86a516e91262914b0edee900ba9e30bd49f33685 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,6 +1,7 @@
 [mypy]
-files = src/rechunk_data
+files = src/rechunk_data/_rechunk.py, src/rechunk_data/__init__.pyi
+exclude = src/rechunk_data/tests/*.py, src/rechunk_data/__init__.py
 strict = False
+show_error_codes = True
 warn_unused_ignores = True
 warn_unreachable = True
-show_error_codes = True
diff --git a/setup.py b/setup.py
index a276578c9241b9fdcde6fc64d4f4a1cf511ec995..ae0a411fec4c8e4ef2171dc05b7a74d6ec392c29 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@ setup(
     entry_points={
         "console_scripts": [f"{find_key(key='PROGRAM_NAME')} = rechunk_data:cli"]
     },
-    install_requires=["argparse", "dask", "xarray", "h5netcdf", "netCDF4"],
+    install_requires=["argparse", "dask", "xarray", "h5netcdf", "netCDF4", "typing_extensions"],
     extras_require={
         "test": [
             "black",
diff --git a/src/rechunk_data/__init__.py b/src/rechunk_data/__init__.py
index bbdf815b88f134b42d2f0e1ec1930f75bb51ba76..2f849bf6a69335fb95d75c534d8a81fc45e1b579 100644
--- a/src/rechunk_data/__init__.py
+++ b/src/rechunk_data/__init__.py
@@ -10,7 +10,7 @@ from ._rechunk import (
     logger,
 )
 
-__version__ = "2206.0.3"
+__version__ = "2208.0.1"
 
 PROGRAM_NAME = "rechunk-data"
 
diff --git a/src/rechunk_data/__init__.pyi b/src/rechunk_data/__init__.pyi
index 6eca2dd0f35ebdd75a1577926ddadc1531026934..8a1130084a494d4de4e802e705343756cbf305a0 100644
--- a/src/rechunk_data/__init__.pyi
+++ b/src/rechunk_data/__init__.pyi
@@ -1,21 +1,38 @@
 import argparse
 import os
 from pathlib import Path
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from typing import (
+    Any,
+    Generator,
+    Mapping,
+    Dict,
+    Hashable,
+    List,
+    Optional,
+    Tuple,
+)
+from typing_extensions import Literal
 
 import xarray as xr
 
 def parse_args() -> argparse.Namespace: ...
 def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]: ...
 def _save_dataset(
-    dset: xr.Dataset, file_name: Path, encoding: Dict[str, Any], engine: str
+    dset: xr.Dataset,
+    file_name: Path,
+    encoding: Dict[Hashable, Dict[str, Any]],
+    engine: Literal["netcdf4", "h5netcdf"],
 ) -> None: ...
-def _rechunk_dataset(dset: xr.Dataset) -> Tuple[xr.Dataset, Dict[str, Any]]: ...
-def rechunk_dataset(dset: xr.Dataset) -> xr.Dataset: ...
+def _rechunk_dataset(
+    dset: xr.Dataset, engine: Literal["netcdf4", "h5netcdf"]
+) -> Tuple[xr.Dataset, Dict[Hashable, Dict[str, Any]]]: ...
+def rechunk_dataset(
+    dset: xr.Dataset, engine: Literal["netcdf4", "h5netcdf"] = ...
+) -> xr.Dataset: ...
 def rechunk_netcdf_file(
     input_path: os.PathLike,
     output_path: Optional[os.PathLike] = ...,
-    engine: str = ...,
+    engine: Literal["netcdf4", "h5netcdf"] = ...,
 ) -> None: ...
 def cli(
     argv: Optional[List[str]] = ...,
diff --git a/src/rechunk_data/_rechunk.py b/src/rechunk_data/_rechunk.py
index 785553ed8421960b368943a5a46062bf88535392..f70ad0c0efbb3e03b0e793262756a96c1e30a96e 100644
--- a/src/rechunk_data/_rechunk.py
+++ b/src/rechunk_data/_rechunk.py
@@ -3,7 +3,8 @@
 import os
 import logging
 from pathlib import Path
-from typing import cast, Any, Dict, Generator, Optional, Tuple
+from typing import cast, Any, Dict, Hashable, Generator, Optional, Tuple
+from typing_extensions import Literal
 
 from dask.utils import format_bytes
 from dask.array.core import Array
@@ -15,6 +16,33 @@ logging.basicConfig(
 )
 logger = logging.getLogger("rechunk-data")
 
+# Encoding keywords supported by each netcdf backend; all other keys are dropped.
+ENCODINGS = dict(
+    h5netcdf={
+        "_FillValue",
+        "complevel",
+        "chunksizes",
+        "dtype",
+        "zlib",
+        "compression_opts",
+        "shuffle",
+        "fletcher32",
+        "compression",
+        "contiguous",
+    },
+    netcdf4={
+        "contiguous",
+        "complevel",
+        "zlib",
+        "shuffle",
+        "_FillValue",
+        "least_significant_digit",
+        "chunksizes",
+        "fletcher32",
+        "dtype",
+    },
+)
+
 
 def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]:
     suffixes = [".nc", "nc4"]
@@ -34,14 +62,14 @@ def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]:
 def _save_dataset(
     dset: xr.Dataset,
     file_name: Path,
-    encoding: Dict[str, Any],
-    engine: str,
+    encoding: Dict[Hashable, Dict[str, Any]],
+    engine: Literal["netcdf4", "h5netcdf"],
     override: bool = False,
 ) -> None:
     if not encoding and not override:
         logger.debug("Chunk size already optimized for %s", file_name.name)
         return
-    logger.debug("Saving file ot %s", str(file_name))
+    logger.debug("Saving file to %s using %s engine", str(file_name), engine)
     try:
         dset.to_netcdf(
             file_name,
@@ -52,9 +80,19 @@
         logger.error("Saving to file failed: %s", str(error))
 
 
-def _rechunk_dataset(dset: xr.Dataset) -> Tuple[xr.Dataset, Dict[str, Any]]:
-    encoding: Dict[str, Dict[str, Any]] = {}
-    for var in map(str, dset.data_vars):
+def _rechunk_dataset(
+    dset: xr.Dataset,
+    engine: Literal["h5netcdf", "netcdf4"],
+) -> Tuple[xr.Dataset, Dict[Hashable, Dict[str, Any]]]:
+    encoding: Dict[Hashable, Dict[str, Any]] = {}
+    try:
+        _keywords = ENCODINGS[engine]
+    except KeyError as error:
+        raise ValueError(
+            f"Only the following engines are supported: {', '.join(ENCODINGS.keys())}"
+        ) from error
+    for data_var in dset.data_vars:
+        var = str(data_var)
         if not isinstance(dset[var].data, Array):
             logger.debug("Skipping rechunking variable %s", var)
             continue
@@ -78,31 +116,37 @@ def _rechunk_dataset(dset: xr.Dataset) -> Tuple[xr.Dataset, Dict[str, Any]]:
                 new_chunks,
             )
             logger.debug("Settings encoding of variable %s", var)
-        encoding[var] = {str(k): v for k, v in dset[var].encoding.items()}
-        encoding[var]["chunksizes"] = new_chunks
+        encoding[data_var] = {
+            str(k): v for k, v in dset[var].encoding.items() if str(k) in _keywords
+        }
+        encoding[data_var]["chunksizes"] = new_chunks
     return dset, encoding
 
 
-def rechunk_dataset(dset: xr.Dataset) -> xr.Dataset:
+def rechunk_dataset(
+    dset: xr.Dataset, engine: Literal["h5netcdf", "netcdf4"] = "h5netcdf"
+) -> xr.Dataset:
     """Rechunk a xarray dataset.
 
     Parameters
     ----------
     dset: xarray.Dataset
         Input dataset that is going to be rechunked
+    engine: str, default: h5netcdf
+        The netcdf engine used to create the new netcdf file.
 
     Returns
     -------
     xarray.Dataset: rechunked dataset
     """
-    data, _ = _rechunk_dataset(dset.chunk())
+    data, _ = _rechunk_dataset(dset.chunk(), engine)
     return data
 
 
 def rechunk_netcdf_file(
     input_path: os.PathLike,
     output_path: Optional[os.PathLike] = None,
-    engine: str = "h5netcdf",
+    engine: Literal["h5netcdf", "netcdf4"] = "h5netcdf",
 ) -> None:
     """Rechunk netcdf files.
 
@@ -115,10 +159,11 @@ def rechunk_netcdf_file(
         Output file/directory of the chunked netcdf file(s). Note: If
         ``input`` is a directory output should be a directory. If None
         given (default) the ``input`` is overidden.
-    engine: The netcdf engine used to create the new netcdf file.
+    engine: str, default: h5netcdf
+        The netcdf engine used to create the new netcdf file.
     """
     input_path = Path(input_path).expanduser().absolute()
     for input_file in _search_for_nc_files(input_path):
         logger.info("Working on file: %s", str(input_file))
         if output_path is None:
             output_file = input_file
@@ -130,7 +175,7 @@ def rechunk_netcdf_file(
             output_file.parent.mkdir(exist_ok=True, parents=True)
         try:
             with xr.open_mfdataset(str(input_file), parallel=True) as nc_data:
-                new_data, encoding = _rechunk_dataset(nc_data)
+                new_data, encoding = _rechunk_dataset(nc_data, engine)
                 if encoding:
                     logger.debug(
                         "Loading data into memory (%s).", format_bytes(new_data.nbytes)
diff --git a/src/rechunk_data/tests/test_rechunk_netcdf.py b/src/rechunk_data/tests/test_rechunk_netcdf.py
index fd7da6c98cadb429d3bdc568fc43c28ee2117245..81d02ff2e1b0dd3f11e74df07ac54a1e93ac8fa9 100644
--- a/src/rechunk_data/tests/test_rechunk_netcdf.py
+++ b/src/rechunk_data/tests/test_rechunk_netcdf.py
@@ -1,9 +1,11 @@
 """Test the actual rechunk method."""
 import logging
-from tempfile import NamedTemporaryFile, TemporaryDirectory
 from pathlib import Path
+import time
+from tempfile import NamedTemporaryFile, TemporaryDirectory
 
 import dask
+import pytest
 
 from rechunk_data import rechunk_netcdf_file, rechunk_dataset
 from rechunk_data._rechunk import _save_dataset
@@ -30,10 +32,12 @@ def test_rechunk_data_dir_without_overwrite(data_dir: Path) -> None:
 
 def test_rechunk_single_data_file(data_file: Path) -> None:
     """Testing rechunking of single data files."""
-    a_time = float(data_file.stat().st_atime)
+    a_time = float(data_file.stat().st_mtime)
+    # Make sure the rewritten file gets a strictly newer modification time.
+    time.sleep(0.5)
     with dask.config.set({"array.chunk-size": "1MiB"}):
         rechunk_netcdf_file(data_file)
-    assert a_time < float(data_file.stat().st_atime)
+    assert a_time < float(data_file.stat().st_mtime)
     with NamedTemporaryFile(suffix=".nc") as temp_file:
         rechunk_netcdf_file(data_file, Path(temp_file.name))
         assert Path(temp_file.name).exists()
@@ -61,3 +65,9 @@ def test_wrong_or_format(small_chunk_data, caplog) -> None:
         _save_dataset(small_chunk_data, temp_file, {"foo": "bar"}, "foo")
     _, loglevel, message = caplog.record_tuples[-1]
     assert loglevel == logging.ERROR
+
+
+def test_wrong_engine(small_chunk_data) -> None:
+
+    with pytest.raises(ValueError):
+        rechunk_dataset(small_chunk_data, engine="foo")
diff --git a/src/rechunk_data/tests/test_rechunking.py b/src/rechunk_data/tests/test_rechunking.py
index 994a6a39b92b6606ca95744eafabae06971a0f5d..945bb0d2e040b5e3f149304f2aee9374423c93e8 100644
--- a/src/rechunk_data/tests/test_rechunking.py
+++ b/src/rechunk_data/tests/test_rechunking.py
@@ -10,16 +10,16 @@ def test_rechunking_small_data(
 ) -> None:
     """Rechunking of small datasets should have no effect."""
     chunks = small_chunk_data[variable_name].data.chunksize
-    dset, encoding = _rechunk_dataset(small_chunk_data)
+    dset, encoding = _rechunk_dataset(small_chunk_data, "h5netcdf")
     assert dset[variable_name].data.chunksize == chunks
-    dset, encoding = _rechunk_dataset(small_chunk_data.load())
+    dset, encoding = _rechunk_dataset(small_chunk_data.load(), "h5netcdf")
     assert encoding == {}
 
 
 def test_empty_dataset(empty_data):
     """Test handling of empyt data."""
     with dask.config.set({"array.chunk-size": "0.01b"}):
-        dset, encoding = _rechunk_dataset(empty_data)
+        dset, encoding = _rechunk_dataset(empty_data, "h5netcdf")
         assert encoding == {}
 
 
@@ -32,6 +32,6 @@ def test_rechunking_large_data(
         .data.rechunk({0: "auto", 1: "auto", 2: None, 3: None})
         .chunksize
     )
-    dset, encoding = _rechunk_dataset(large_chunk_data)
+    dset, encoding = _rechunk_dataset(large_chunk_data, "h5netcdf")
     assert encoding[variable_name]["chunksizes"] == chunks
     assert dset[variable_name].data.chunksize == chunks