Commit 05df30de authored by Martin Bergemann

Merge branch 'filter-engine-keywords' into 'main'

Filter for only allowed engine keywords

See merge request !11
parents 8804dd08 f3ab2592
Pipeline #21291 passed
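
For orientation, a minimal usage sketch of the public API after this change (the sketch is not part of the merge request itself and the file names are illustrative): rechunk_dataset and rechunk_netcdf_file now take an engine keyword limited to "h5netcdf" or "netcdf4", defaulting to "h5netcdf"; any other value raises a ValueError.

    import xarray as xr
    from rechunk_data import rechunk_dataset, rechunk_netcdf_file

    # Rechunk an already opened dataset; "h5netcdf" is the default engine.
    dset = xr.open_dataset("data.nc", chunks="auto")  # illustrative input file
    rechunked = rechunk_dataset(dset, engine="h5netcdf")

    # Rechunk a netCDF file (or a directory of files) on disk.
    rechunk_netcdf_file("data.nc", output_path="rechunked.nc", engine="netcdf4")

    # Any other engine name, e.g. engine="zarr", raises a ValueError.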
@@ -58,9 +58,9 @@ pages:
     - /tmp/test/bin/python -m pip install -e .[test]
   script:
     - /tmp/test/bin/coverage run -m pytest
-    - /tmp/test/bin/coverage report
     - /tmp/test/bin/coverage html
     - /tmp/test/bin/coverage xml
+    - /tmp/test/bin/coverage report
   coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
   after_script:
     - mv htmlcov public
......
 [mypy]
-files = src/rechunk_data
-exclude = src/rechunk_data/tests/*.py, src/rechunk_data/__init__.py
+files = src/rechunk_data/_rechunk.py, src/rechunk_data/__init__.pyi
 strict = False
+show_error_codes = True
 warn_unused_ignores = True
 warn_unreachable = True
-show_error_codes = True
@@ -37,7 +37,7 @@ setup(
     entry_points={
         "console_scripts": [f"{find_key(key='PROGRAM_NAME')} = rechunk_data:cli"]
     },
-    install_requires=["argparse", "dask", "xarray", "h5netcdf", "netCDF4"],
+    install_requires=["argparse", "dask", "xarray", "h5netcdf", "netCDF4", "typing_extensions"],
     extras_require={
         "test": [
             "black",
......
@@ -10,7 +10,7 @@ from ._rechunk import (
     logger,
 )

-__version__ = "2206.0.3"
+__version__ = "2208.0.1"
 PROGRAM_NAME = "rechunk-data"
......
 import argparse
 import os
 from pathlib import Path
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from typing import (
+    Any,
+    Generator,
+    Mapping,
+    Dict,
+    Hashable,
+    List,
+    Optional,
+    Tuple,
+)
+from typing_extensions import Literal
 import xarray as xr

 def parse_args() -> argparse.Namespace: ...
 def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]: ...
 def _save_dataset(
-    dset: xr.Dataset, file_name: Path, encoding: Dict[str, Any], engine: str
+    dset: xr.Dataset,
+    file_name: Path,
+    encoding: Dict[Hashable, Dict[str, Any]],
+    engine: Literal["netcdf4", "h5netcdf"],
 ) -> None: ...
-def _rechunk_dataset(dset: xr.Dataset) -> Tuple[xr.Dataset, Dict[str, Any]]: ...
-def rechunk_dataset(dset: xr.Dataset) -> xr.Dataset: ...
+def _rechunk_dataset(
+    dset: xr.Dataset, engine: Literal["netcdf4", "h5netcdf"]
+) -> Tuple[xr.Dataset, Dict[Hashable, Dict[str, Any]]]: ...
+def rechunk_dataset(
+    dset: xr.Dataset, engine: Literal["netcdf4", "h5netcdf"] = ...
+) -> xr.Dataset: ...
 def rechunk_netcdf_file(
     input_path: os.PathLike,
     output_path: Optional[os.PathLike] = ...,
-    engine: str = ...,
+    engine: Literal["netcdf4", "h5netcdf"] = ...,
 ) -> None: ...
 def cli(
     argv: Optional[List[str]] = ...,
......
@@ -3,7 +3,8 @@
 import os
 import logging
 from pathlib import Path
-from typing import cast, Any, Dict, Generator, Optional, Tuple
+from typing import cast, Any, Dict, Hashable, Generator, Optional, Tuple
+from typing_extensions import Literal

 from dask.utils import format_bytes
 from dask.array.core import Array
@@ -15,6 +16,32 @@ logging.basicConfig(
 )
 logger = logging.getLogger("rechunk-data")

+ENCODINGS = dict(
+    h5netcdf={
+        "_FillValue",
+        "complevel",
+        "chunksizes",
+        "dtype",
+        "zlib",
+        "compression_opts",
+        "shuffle",
+        "fletcher32",
+        "compression",
+        "contiguous",
+    },
+    netcdf4={
+        "contiguous",
+        "complevel",
+        "zlib",
+        "shuffle",
+        "_FillValue",
+        "least_significant_digit",
+        "chunksizes",
+        "fletcher32",
+        "dtype",
+    },
+)
+
 def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]:
     suffixes = [".nc", "nc4"]
@@ -34,14 +61,15 @@ def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]:
 def _save_dataset(
     dset: xr.Dataset,
     file_name: Path,
-    encoding: Dict[str, Any],
-    engine: str,
+    encoding: Dict[Hashable, Dict[str, Any]],
+    engine: Literal["netcdf4", "h5netcdf"],
     override: bool = False,
 ) -> None:
     if not encoding and not override:
         logger.debug("Chunk size already optimized for %s", file_name.name)
         return
-    logger.debug("Saving file ot %s", str(file_name))
+    print(str(file_name))
+    logger.debug("Saving file to %s using %s engine", str(file_name), engine)
     try:
         dset.to_netcdf(
             file_name,
@@ -52,9 +80,19 @@ def _save_dataset(
         logger.error("Saving to file failed: %s", str(error))


-def _rechunk_dataset(dset: xr.Dataset) -> Tuple[xr.Dataset, Dict[str, Any]]:
-    encoding: Dict[str, Dict[str, Any]] = {}
-    for var in map(str, dset.data_vars):
+def _rechunk_dataset(
+    dset: xr.Dataset,
+    engine: Literal["h5netcdf", "netcdf4"],
+) -> Tuple[xr.Dataset, Dict[Hashable, Dict[str, Any]]]:
+    encoding: Dict[Hashable, Dict[str, Any]] = {}
+    try:
+        _keywords = ENCODINGS[engine]
+    except KeyError as error:
+        raise ValueError(
+            "Only the following engines are supported: ', '.join(ENCODINGS.keys())"
+        ) from error
+    for data_var in dset.data_vars:
+        var = str(data_var)
         if not isinstance(dset[var].data, Array):
             logger.debug("Skipping rechunking variable %s", var)
             continue
@@ -78,31 +116,37 @@ def _rechunk_dataset(dset: xr.Dataset) -> Tuple[xr.Dataset, Dict[str, Any]]:
                 new_chunks,
             )
         logger.debug("Settings encoding of variable %s", var)
-        encoding[var] = {str(k): v for k, v in dset[var].encoding.items()}
-        encoding[var]["chunksizes"] = new_chunks
+        encoding[data_var] = {
+            str(k): v for k, v in dset[var].encoding.items() if str(k) in _keywords
+        }
+        encoding[data_var]["chunksizes"] = new_chunks
     return dset, encoding


-def rechunk_dataset(dset: xr.Dataset) -> xr.Dataset:
+def rechunk_dataset(
+    dset: xr.Dataset, engine: Literal["h5netcdf", "netcdf4"] = "h5netcdf"
+) -> xr.Dataset:
     """Rechunk a xarray dataset.

     Parameters
     ----------
     dset: xarray.Dataset
         Input dataset that is going to be rechunked
+    engine: str, default: h5netcdf
+        The netcdf engine used to create the new netcdf file.

     Returns
     -------
     xarray.Dataset: rechunked dataset
     """
-    data, _ = _rechunk_dataset(dset.chunk())
+    data, _ = _rechunk_dataset(dset.chunk(), engine)
     return data


 def rechunk_netcdf_file(
     input_path: os.PathLike,
     output_path: Optional[os.PathLike] = None,
-    engine: str = "h5netcdf",
+    engine: Literal["h5netcdf", "netcdf4"] = "h5netcdf",
 ) -> None:
     """Rechunk netcdf files.
@@ -115,10 +159,12 @@ def rechunk_netcdf_file(
         Output file/directory of the chunked netcdf file(s). Note: If ``input``
         is a directory output should be a directory. If None given (default)
         the ``input`` is overidden.
-    engine: The netcdf engine used to create the new netcdf file.
+    engine: str, default: h5netcdf
+        The netcdf engine used to create the new netcdf file.
     """
     input_path = Path(input_path).expanduser().absolute()
     for input_file in _search_for_nc_files(input_path):
+        print(input_file)
         logger.info("Working on file: %s", str(input_file))
         if output_path is None:
             output_file = input_file
@@ -130,7 +176,7 @@ def rechunk_netcdf_file(
         output_file.parent.mkdir(exist_ok=True, parents=True)
         try:
             with xr.open_mfdataset(str(input_file), parallel=True) as nc_data:
-                new_data, encoding = _rechunk_dataset(nc_data)
+                new_data, encoding = _rechunk_dataset(nc_data, engine)
                 if encoding:
                     logger.debug(
                         "Loading data into memory (%s).", format_bytes(new_data.nbytes)
......
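
To make the new allow-list concrete, here is a small sketch of the per-variable filtering that _rechunk_dataset now performs before handing the encoding to to_netcdf (only ENCODINGS comes from the package; the raw encoding values below are invented for illustration):

    from rechunk_data._rechunk import ENCODINGS

    # Encoding as xarray might report it for a variable read from disk.
    raw_encoding = {
        "zlib": True,
        "complevel": 4,
        "chunksizes": (1, 12, 180, 360),
        "source": "/tmp/data.nc",   # not a valid to_netcdf encoding keyword
        "coordinates": "lat lon",   # dropped as well
    }

    # Keep only the keywords the chosen engine accepts.
    allowed = ENCODINGS["h5netcdf"]
    filtered = {key: value for key, value in raw_encoding.items() if key in allowed}
    print(filtered)  # {'zlib': True, 'complevel': 4, 'chunksizes': (1, 12, 180, 360)}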
"""Test the actual rechunk method.""" """Test the actual rechunk method."""
import logging import logging
from tempfile import NamedTemporaryFile, TemporaryDirectory
from pathlib import Path from pathlib import Path
import time
from tempfile import NamedTemporaryFile, TemporaryDirectory
import dask import dask
import pytest
from rechunk_data import rechunk_netcdf_file, rechunk_dataset from rechunk_data import rechunk_netcdf_file, rechunk_dataset
from rechunk_data._rechunk import _save_dataset from rechunk_data._rechunk import _save_dataset
@@ -30,10 +32,11 @@ def test_rechunk_data_dir_without_overwrite(data_dir: Path) -> None:
 def test_rechunk_single_data_file(data_file: Path) -> None:
     """Testing rechunking of single data files."""
-    a_time = float(data_file.stat().st_atime)
+    a_time = float(data_file.stat().st_mtime)
+    time.sleep(0.5)
     with dask.config.set({"array.chunk-size": "1MiB"}):
         rechunk_netcdf_file(data_file)
-    assert a_time < float(data_file.stat().st_atime)
+    assert a_time < float(data_file.stat().st_mtime)
     with NamedTemporaryFile(suffix=".nc") as temp_file:
         rechunk_netcdf_file(data_file, Path(temp_file.name))
         assert Path(temp_file.name).exists()
@@ -61,3 +64,9 @@ def test_wrong_or_format(small_chunk_data, caplog) -> None:
         _save_dataset(small_chunk_data, temp_file, {"foo": "bar"}, "foo")
     _, loglevel, message = caplog.record_tuples[-1]
     assert loglevel == logging.ERROR
+
+
+def test_wrong_engine(small_chunk_data) -> None:
+    with pytest.raises(ValueError):
+        rechunk_dataset(small_chunk_data, engine="foo")
@@ -10,16 +10,16 @@ def test_rechunking_small_data(
 ) -> None:
     """Rechunking of small datasets should have no effect."""
     chunks = small_chunk_data[variable_name].data.chunksize
-    dset, encoding = _rechunk_dataset(small_chunk_data)
+    dset, encoding = _rechunk_dataset(small_chunk_data, "h5netcdf")
     assert dset[variable_name].data.chunksize == chunks
-    dset, encoding = _rechunk_dataset(small_chunk_data.load())
+    dset, encoding = _rechunk_dataset(small_chunk_data.load(), "h5netcdf")
     assert encoding == {}


 def test_empty_dataset(empty_data):
     """Test handling of empyt data."""
     with dask.config.set({"array.chunk-size": "0.01b"}):
-        dset, encoding = _rechunk_dataset(empty_data)
+        dset, encoding = _rechunk_dataset(empty_data, "h5netcdf")
     assert encoding == {}
@@ -32,6 +32,6 @@ def test_rechunking_large_data(
         .data.rechunk({0: "auto", 1: "auto", 2: None, 3: None})
         .chunksize
     )
-    dset, encoding = _rechunk_dataset(large_chunk_data)
+    dset, encoding = _rechunk_dataset(large_chunk_data, "h5netcdf")
     assert encoding[variable_name]["chunksizes"] == chunks
     assert dset[variable_name].data.chunksize == chunks