diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 86ea82140133b3d89876252316a133f374df4444..2f879708645bd9c06350c2681e7dbc344ac2e2cd 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -13,45 +13,45 @@ lint:
     - pip install .[test]
   script:
     - mypy
-    - black --check src
+    - black --check -l 79 src
     - pylint --fail-under 8.5 src/rechunk_data/__init__.py
 
-test_36:
+test_39:
   << : *py_test
   before_script:
-    - conda create -q -p /tmp/test python=3.6 pip dask -y
+    - conda create -q -p /tmp/test python=3.9 pip dask -y
     - /tmp/test/bin/python -m pip install -e .[test]
   script:
     - /tmp/test/bin/python -m pytest -vv
 
-test_37:
+test_310:
   << : *py_test
   before_script:
-    - conda create -q -p /tmp/test python=3.7 pip dask -y
+    - conda create -q -p /tmp/test python=3.10 pip dask -y
     - /tmp/test/bin/python -m pip install -e .[test]
   script:
     - /tmp/test/bin/python -m pytest -vv
 
-test_38:
+test_311:
   << : *py_test
   before_script:
-    - conda create -q -p /tmp/test python=3.8 pip dask -y
+    - conda create -q -p /tmp/test python=3.11 pip dask -y
     - /tmp/test/bin/python -m pip install -e .[test]
   script:
     - /tmp/test/bin/python -m pytest -vv
 
-test_39:
+test_312:
   << : *py_test
   before_script:
-    - conda create -q -p /tmp/test python=3.9 pip dask -y
+    - conda create -q -p /tmp/test python=3.12 pip dask -y
     - /tmp/test/bin/python -m pip install -e .[test]
   script:
     - /tmp/test/bin/python -m pytest -vv
 
-test_310:
+test_313:
   << : *py_test
   before_script:
-    - conda create -q -p /tmp/test python=3.10 pip dask -y
+    - conda create -q -p /tmp/test python=3.13 pip dask -y
     - /tmp/test/bin/python -m pip install -e .[test]
   script:
     - /tmp/test/bin/python -m pytest -vv
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..73c495a161feaf01220a9a07aaddcb9fe53e7bd9
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,24 @@
+# makefile used for testing
+#
+#
+all: install test
+
+install:
+	python3 -m pip install -e .[test]
+
+test:
+	python3 -m pytest -vv \
+	    --cov=$(PWD)/src/rechunk_data --cov-report=html:coverage_report \
+	    --junitxml report.xml --cov-report xml \
+	    $(PWD)/src/rechunk_data/tests
+	python3 -m coverage report
+
+format:
+	isort --profile=black src
+	black -t py310 -l 79 src
+
+lint:
+	mypy --install-types --non-interactive
+	isort --check-only --profile=black src
+	black --check -t py310 -l 79 src
+	flake8 src/rechunk_data --count --max-complexity=15 --max-line-length=88 --statistics --doctests
\ No newline at end of file
diff --git a/README.md b/README.md
index b75b5d68a901f69260102416628a46fbf95af2b1..fbef163304b0c92cdaab8ec314399d43f535e8d2 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # Rechunking NetCDF data.
 
-Rechunking of exsisting netcdf files to an optimal chunk size. This code provides
+Rechunking of existing netcdf files to an optimal chunk size. This code provides
 a simple command line interface (cli) to rechunk existing netcdf data to an
 optimal chunksize of around 128 MB.
 
@@ -19,31 +19,39 @@ Use the `--user` flag if you do not have super user rights and are not using `an ### Using the python module ```python -from rechunk_data import rechunk_dataset +from rechunk_data import rechunk_dataset, check_chunk_size import xarray as xr dset = xr.open_mfdataset("/data/*", parallel=True, combine="by_coords") +check_chunk_size(dset) # to print the chunksizes of the original set-up new_data = rechunk_dataset(dset) ``` ### Using the command line interface: ```bash -rechunk-data --help -usage: rechunk-data [-h] [--output OUTPUT] [--netcdf-engine {h5netcdf,netcdf4}] [--skip-cf-convention] [-v] [-V] input +rechunk-data --help +usage: rechunk-data [-h] [-o OUTPUT] [--netcdf-engine {h5netcdf,netcdf4}] [--size SIZE] [--auto-chunks] [--skip-cf-convention] [--force-horizontal] [--check-chunk-size] + [-v] [-V] + input -Rechunk input netcdf data to optimal chunk-size. approx. 126 MB per chunk +Rechunk input netcdf data to optimal chunk-size. approx. 126 MB per chunk. By default it only optimises time. positional arguments: input Input file/directory. If a directory is given all ``.nc`` files in all sub directories will be processed options: -h, --help show this help message and exit - --output OUTPUT Output file/directory of the chunked netcdf file(s). - Note: If ``input`` is a directory output should be a directory. - If None given (default) the ``input`` is overidden. (default: None) + -o, --output OUTPUT Output file/directory of the chunked netcdf file(s). Note: If ``input`` is a directory output should be a directory. If None given (default) the + ``input`` is overridden. (default: None) --netcdf-engine {h5netcdf,netcdf4} The netcdf engine used to create the new netcdf file. (default: netcdf4) + --size, -s SIZE Specify chunk-size (in MiB). (default: None) + --auto-chunks, -ac Allow Dask to determine optimal chunk sizes for all dimensions. (default: False) --skip-cf-convention Do not assume assume data variables follow CF conventions. (default: False) + --force-horizontal, -fh + Force horizontal chunking (~126 MB per chunk). (default: False) + --check-chunk-size, -c + Check the chunk size of the input dataset (in MB). 
(default: False) -v Increase verbosity (default: 0) -V, --version show program's version number and exit ``` diff --git a/setup.py b/setup.py index ae0a411fec4c8e4ef2171dc05b7a74d6ec392c29..3bbc8c6a1231d9a218445027f91fba1f9c7ec7b0 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 -"""Setup script for packaging checkin.""" +"""Setup script for packaging check-in.""" import json from pathlib import Path -from setuptools import setup, find_packages + +from setuptools import find_packages, setup def read(*parts): @@ -35,14 +36,25 @@ setup( packages=find_packages("src"), package_dir={"": "src"}, entry_points={ - "console_scripts": [f"{find_key(key='PROGRAM_NAME')} = rechunk_data:cli"] + "console_scripts": [ + f"{find_key(key='PROGRAM_NAME')} = rechunk_data:cli" + ] }, - install_requires=["argparse", "dask", "xarray", "h5netcdf", "netCDF4", "typing_extensions"], + install_requires=[ + "argparse", + "dask", + "xarray", + "h5netcdf", + "netCDF4", + "typing_extensions", + ], extras_require={ "test": [ "black", + "isort", "mock", "mypy", + "ipython", "nbformat", "pytest", "pylint", @@ -51,9 +63,10 @@ setup( "pytest-cov", "testpath", "types-mock", + "flake8", ], }, - python_requires=">=3.6", + python_requires=">=3.9", classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Console", diff --git a/src/rechunk_data/__init__.py b/src/rechunk_data/__init__.py index 8e68bac7b079026b27d8d2479588bef99b35c82a..da4524a225ab240060e39b026eaa10c9083d54d6 100644 --- a/src/rechunk_data/__init__.py +++ b/src/rechunk_data/__init__.py @@ -4,13 +4,18 @@ import argparse import logging from pathlib import Path from typing import List, Optional -from ._rechunk import ( - rechunk_netcdf_file, - rechunk_dataset, - logger, -) -__version__ = "2310.0.1" +from ._rechunk import rechunk_dataset # noqa +from ._rechunk import check_chunk_size, logger, rechunk_netcdf_file + +__all__ = [ + "rechunk_netcdf_file", + "check_chunk_size", + "rechunk_dataset", + "logger", +] + +__version__ = "2503.0.1" PROGRAM_NAME = "rechunk-data" @@ -21,7 +26,7 @@ def parse_args(argv: Optional[List[str]]) -> argparse.Namespace: prog=PROGRAM_NAME, description=( "Rechunk input netcdf data to optimal chunk-size." - " approx. 126 MB per chunk" + " approx. 126 MB per chunk. By default it only optimises time." ), formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) @@ -35,12 +40,13 @@ def parse_args(argv: Optional[List[str]]) -> argparse.Namespace: ), ) parser.add_argument( + "-o", "--output", type=Path, help=( "Output file/directory of the chunked netcdf " "file(s). Note: If ``input`` is a directory output should be a" - " directory. If None given (default) the ``input`` is overidden." + " directory. If None given (default) the ``input`` is overridden." 
         ),
         default=None,
     )
@@ -51,12 +57,40 @@ def parse_args(argv: Optional[List[str]]) -> argparse.Namespace:
         default="netcdf4",
         type=str,
     )
+    parser.add_argument(
+        "--size",
+        "-s",
+        help=("Specify chunk-size (in MiB)."),
+        default=None,
+        type=int,
+    )
+    parser.add_argument(
+        "--auto-chunks",
+        "-ac",
+        help="Allow Dask to determine optimal chunk sizes for all dimensions.",
+        action="store_true",
+        default=False,
+    )
     parser.add_argument(
         "--skip-cf-convention",
         help="Do not assume assume data variables follow CF conventions.",
         action="store_true",
         default=False,
     )
+    parser.add_argument(
+        "--force-horizontal",
+        "-fh",
+        help="Force horizontal chunking (~126 MB per chunk).",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--check-chunk-size",
+        "-c",
+        help="Check the chunk size of the input dataset (in MB).",
+        action="store_true",
+        default=False,
+    )
     parser.add_argument(
         "-v",
         action="count",
@@ -77,9 +111,18 @@ def parse_args(argv: Optional[List[str]]) -> argparse.Namespace:
 def cli(argv: Optional[List[str]] = None) -> None:
     """Command line interface calling the rechunking method."""
     args = parse_args(argv)
+    if args.check_chunk_size:
+        import xarray as xr
+
+        with xr.open_dataset(args.input, chunks={}) as dset:
+            check_chunk_size(dset)
+        return
     rechunk_netcdf_file(
         args.input,
         args.output,
         engine=args.netcdf_engine,
+        size=args.size,
+        auto_chunks=args.auto_chunks,
         decode_cf=args.skip_cf_convention is False,
+        force_horizontal=args.force_horizontal,
     )
diff --git a/src/rechunk_data/__init__.pyi b/src/rechunk_data/__init__.pyi
index 8a1130084a494d4de4e802e705343756cbf305a0..8bdf7f4d68818e2ed34f0cd299a8a33c335e4107 100644
--- a/src/rechunk_data/__init__.pyi
+++ b/src/rechunk_data/__init__.pyi
@@ -3,17 +3,16 @@ import os
 from pathlib import Path
 from typing import (
     Any,
-    Generator,
-    Mapping,
     Dict,
+    Generator,
     Hashable,
     List,
     Optional,
     Tuple,
 )
-from typing_extensions import Literal
 
 import xarray as xr
+from typing_extensions import Literal
 
 def parse_args() -> argparse.Namespace: ...
 def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]: ...
diff --git a/src/rechunk_data/_rechunk.py b/src/rechunk_data/_rechunk.py
index 311ed2508fb36b78904887a398e13d08c410904c..42384d1faf673b6a8066975110fe487d0ce8ec04 100644
--- a/src/rechunk_data/_rechunk.py
+++ b/src/rechunk_data/_rechunk.py
@@ -1,15 +1,16 @@
 """Rechunking module."""
-import os
 import logging
+import os
 from pathlib import Path
-from typing import cast, Any, Dict, Hashable, Generator, Optional, Tuple
-from typing_extensions import Literal
+from typing import Any, Dict, Generator, Hashable, Optional, Tuple, cast
 
-from dask.utils import format_bytes
-from dask.array.core import Array
+import dask
+import numpy as np
 import xarray as xr
-
+from dask.array.core import Array
+from dask.utils import format_bytes
+from typing_extensions import Literal
 
 
 logging.basicConfig(
     format="%(name)s - %(levelname)s - %(message)s", level=logging.ERROR
@@ -41,9 +42,27 @@ ENCODINGS = dict(
         "dtype",
     },
 )
+default_chunk_size = 126.0  # in MiB
 
 
 def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]:
+    """Yield all netCDF files in the given input_path.
+
+    If the input is a directory, search recursively for all files with the
+    suffixes .nc or .nc4. If the input is a file, yield only the file itself.
+    If the input is a path with a glob pattern, construct it and search for
+    netCDF files in the resulting directory.
+ + Parameters + ---------- + input_path: Path + The path to search for netCDF files + + Yields + ------ + Path + The path to a netCDF file + """ suffixes = [".nc", "nc4"] input_path = input_path.expanduser().absolute() if input_path.is_dir() and input_path.exists(): @@ -65,6 +84,28 @@ def _save_dataset( engine: Literal["netcdf4", "h5netcdf"], override: bool = False, ) -> None: + """ + Save the given xarray dataset to a netCDF file with specified encoding. + + Parameters + ---------- + dset : xr.Dataset + The dataset to be saved. + file_name : Path + The path where the netCDF file will be saved. + encoding : Dict[Hashable, Dict[str, Any]] + Encoding options for each variable in the dataset. + engine : Literal["netcdf4", "h5netcdf"] + The engine to use for writing the netCDF file. + override : bool, optional + If True, save the dataset even if no encoding is provided. Default is False. + + Returns + ------- + None + + Logs a debug message indicating successful save or an error message if saving fails. + """ if not encoding and not override: logger.debug("Chunk size already optimized for %s", file_name.name) return @@ -79,10 +120,199 @@ def _save_dataset( logger.error("Saving to file failed: %s", str(error)) +def _optimal_chunks( + time_size: int, + y_size: int, + x_size: int, + size: float = default_chunk_size, + dtype: np.dtype = np.dtype("float32"), + only_horizontal: bool = True, +) -> Tuple[int, int, int]: + """ + Compute optimal chunk sizes for time, y (lat), and x (lon) adjusted to a + certain size in MB. + Optionally, allow forcing chunking only in horizontal dimensions + while keeping the time dimension as a single chunk. + + Parameters + ---------- + time_size: int + Total size of the time dimension (kept as a single chunk). + y_size: int + Total size of the y (latitude) dimension. + x_size: int + Total size of the x (longitude) dimension. + size: float, default: 126MB + Desired chunk size in MB. + dtype: np.dtype + Encoding dtype. + only_horizontal: bool + Whether to force chunking only in the horizontal dimensions. + + Returns + ------- + Tuple[int, int, int] + Optimal chunk sizes for time, y, and x. + """ + dtype_size = dtype.itemsize + target_elements: float = (size * (1024**2)) / dtype_size + + if y_size == 1 and x_size == 1: + only_horizontal = False + if only_horizontal: + factor = np.sqrt(target_elements / (time_size * y_size * x_size)) + time_chunk = time_size + else: + factor = np.cbrt(target_elements / (time_size * y_size * x_size)) + time_chunk = max(1, int(time_size * factor)) + y_chunk = max(1, int(y_size * factor)) + x_chunk = max(1, int(x_size * factor)) + + y_chunk = min(y_chunk, y_size) + x_chunk = min(x_chunk, x_size) + time_chunk = min(time_chunk, time_size) + return (time_chunk, y_chunk, x_chunk) + + +def _map_chunks( + chunks: dict, dim_mapping: dict, chunk_tuple: Tuple[int, int, int] +) -> dict: + """ + Update chunk sizes in `chunks` using the dimension mapping. + + Parameters + ---------- + chunks: dict + Dictionary with initial chunking setup (e.g., {0: 'auto', 1: None, 2: None}). + dim_mapping: dict + Mapping of dimension names to chunk indices (e.g., {'time': 0, 'y': 1, 'x': 2}). + chunk_tuple: Tuple[int, int, int] + Desired chunk size for (time, y, x) dimensions. + + Returns + ------- + dict + Updated chunk dictionary with appropriate chunk sizes. 
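+
+    Example
+    -------
+    Illustrative sketch only; the dimension names and chunk values below are
+    assumed, not taken from a real dataset:
+
+    >>> _map_chunks(
+    ...     {0: "auto", 1: None, 2: None},
+    ...     {"time": 0, "lat": 1, "lon": 2},
+    ...     (10, 550, 600),
+    ... )
+    {0: 10, 1: 550, 2: 600}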
+    """
+    updated_chunks = chunks.copy()
+
+    for dim, index in dim_mapping.items():
+        if "time" in dim.lower():
+            updated_chunks[index] = chunk_tuple[0]
+        if "y" in dim.lower() or "lat" in dim.lower():
+            updated_chunks[index] = chunk_tuple[1]
+        if "x" in dim.lower() or "lon" in dim.lower():
+            updated_chunks[index] = chunk_tuple[2]
+
+    return updated_chunks
+
+
+def _horizontal_chunks(
+    da: xr.DataArray,
+    chunks: dict,
+    size: Optional[float] = None,
+    only_horizontal: bool = True,
+) -> dict:
+    """
+    Update chunk sizes, forcing horizontal chunking whenever possible.
+
+    .. note::
+
+        This function does not take level dimensions such as ``lev`` or
+        ``depth`` into account when chunking, which will likely result in
+        a chunk size of 1 for those dimensions.
+
+    Parameters
+    ----------
+    da: xr.DataArray
+        Data variable from the dataset.
+    chunks: dict
+        Original chunking dictionary.
+    size: float, default: None
+        Desired chunk size in MB.
+    only_horizontal: bool, default: True
+        Whether to force chunking only in the horizontal dimensions,
+        e.g. disregarding time.
+
+    Returns
+    -------
+    dict
+        Updated chunk dictionary with appropriate chunk sizes.
+    """
+    try:
+        orig_size = dict(da.sizes)
+        chunk_order = dict(zip(da.dims, chunks))
+        time_size = next(
+            (
+                value
+                for key, value in orig_size.items()
+                if isinstance(key, str) and "time" in key.lower()
+            ),
+            1,
+        )
+
+        y_size = next(
+            (
+                value
+                for key, value in orig_size.items()
+                if isinstance(key, str)
+                and any(k in key.lower() for k in ["lat", "y"])
+            ),
+            1,
+        )
+
+        x_size = next(
+            (
+                value
+                for key, value in orig_size.items()
+                if isinstance(key, str)
+                and any(k in key.lower() for k in ["lon", "x"])
+            ),
+            1,
+        )
+        dtype = da.encoding.get("dtype", np.dtype("float32"))
+        chunk_tuple = _optimal_chunks(
+            time_size,
+            y_size,
+            x_size,
+            size=size if size is not None else default_chunk_size,
+            dtype=dtype,
+            only_horizontal=only_horizontal,
+        )
+        return _map_chunks(chunks, chunk_order, chunk_tuple)
+    except Exception as e:
+        logger.error(f"Error in _horizontal_chunks: {e}", exc_info=True)
+        return chunks
+
+
 def _rechunk_dataset(
     dset: xr.Dataset,
     engine: Literal["h5netcdf", "netcdf4"],
+    size: Optional[int] = None,
+    auto_chunks: bool = False,
+    force_horizontal: bool = False,
 ) -> Tuple[xr.Dataset, Dict[Hashable, Dict[str, Any]]]:
+    """
+    Rechunk an xarray dataset.
+
+    Parameters
+    ----------
+    dset: xarray.Dataset
+        Input dataset that is going to be rechunked
+    engine: str, default: netcdf4
+        The netcdf engine used to create the new netcdf file.
+    size: int, default: None
+        Desired chunk size in MB. If None, computed by Dask or default_chunk_size.
+    auto_chunks: bool, default: False
+        If True, Dask automatically determines the optimal chunk size.
+    force_horizontal: bool, default: False
+        If True, forces horizontal chunking whenever possible.
+
+    Returns
+    -------
+    Tuple[xr.Dataset, Dict[Hashable, Dict[str, Any]]]
+        A tuple containing the rechunked dataset and the updated encoding dictionary.
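+
+    Example
+    -------
+    Illustrative sketch only; the file names and the chunk size are assumed::
+
+        dset = xr.open_dataset("tas.nc", chunks={})
+        dset, enc = _rechunk_dataset(dset, "netcdf4", size=64)
+        dset.to_netcdf("tas_new.nc", engine="netcdf4", encoding=enc)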
+ """ encoding: Dict[Hashable, Dict[str, Any]] = {} try: _keywords = ENCODINGS[engine] @@ -100,14 +330,39 @@ def _rechunk_dataset( logger.debug("Skipping rechunking variable %s", var) continue logger.debug("Rechunking variable %s", var) - chunks: Dict[int, Optional[str]] = {} - for i, dim in enumerate(map(str, dset[var].dims)): - if "lon" in dim.lower() or "lat" in dim.lower() or "bnds" in dim.lower(): - chunks[i] = None - else: - chunks[i] = "auto" old_chunks = dset[var].encoding.get("chunksizes") - new_chunks = dset[var].data.rechunk(chunks).chunksize + chunks: Dict[int, Optional[str]] = { + i: "auto" for i, _ in enumerate(dset[var].dims) + } + if force_horizontal: + if auto_chunks: + only_horizontal = False + else: + only_horizontal = True + chunks = _horizontal_chunks( + dset[var], + chunks, + size=size, + only_horizontal=only_horizontal, + ) + new_chunks = dset[var].data.rechunk(chunks).chunksize + else: + if not auto_chunks: + for i, dim in enumerate(map(str, dset[var].dims)): + if any( + keyword in dim.lower() + for keyword in ["lon", "lat", "bnds", "x", "y"] + ): + chunks[i] = None + if size: + with dask.config.set(array__chunk_size=f"{size}MiB"): + new_chunks = ( + dset[var] + .data.rechunk(chunks, balance=True, method="tasks") + .chunksize + ) + else: + new_chunks = dset[var].data.rechunk(chunks).chunksize if new_chunks == old_chunks: logger.debug("%s: chunk sizes already optimized, skipping", var) continue @@ -120,55 +375,91 @@ def _rechunk_dataset( ) logger.debug("Settings encoding of variable %s", var) encoding[data_var] = { - str(k): v for k, v in dset[var].encoding.items() if str(k) in _keywords + str(k): v + for k, v in dset[var].encoding.items() + if str(k) in _keywords } - if engine != "netcdf4" or encoding[data_var].get("contiguous", False) is False: + if ( + engine != "netcdf4" + or encoding[data_var].get("contiguous", False) is False + ): encoding[data_var]["chunksizes"] = new_chunks return dset, encoding def rechunk_dataset( - dset: xr.Dataset, engine: Literal["h5netcdf", "netcdf4"] = "netcdf4" + dset: xr.Dataset, + engine: Literal["h5netcdf", "netcdf4"] = "netcdf4", + size: Optional[int] = None, + auto_chunks: bool = False, + force_horizontal: bool = False, ) -> xr.Dataset: - """Rechunk a xarray dataset. + """ + Rechunk a xarray dataset. Parameters ---------- - dset: xarray.Dataset - Input dataset that is going to be rechunked - engine: str, default: netcdf4 - The netcdf engine used to create the new netcdf file. + dset : xr.Dataset + The dataset to be rechunked. + engine : Literal["h5netcdf", "netcdf4"], optional + The engine to use for writing the netCDF file. Defaults to "netcdf4". + size : Optional[int], optional + The desired chunk size in MiB. If None, the default chunk size is used. + Defaults to None. + auto_chunks : bool, optional + If True, determine the chunk size automatically using Dask. Defaults to False. + force_horizontal : bool, optional + If True, force the chunk size to be in the horizontal dimensions + (y and x). Defaults to False. Returns ------- - xarray.Dataset: rechunked dataset + xr.Dataset + The rechunked dataset. 
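+
+    Example
+    -------
+    Illustrative sketch only; the input path and the options are assumed::
+
+        import xarray as xr
+
+        dset = xr.open_mfdataset("/data/*.nc", combine="by_coords")
+        new_dset = rechunk_dataset(dset, size=64, force_horizontal=True)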
""" - data, _ = _rechunk_dataset(dset.chunk(), engine) + data, _ = _rechunk_dataset( + dset.chunk(), engine, size, auto_chunks, force_horizontal + ) return data def rechunk_netcdf_file( input_path: os.PathLike, output_path: Optional[os.PathLike] = None, - decode_cf: bool = True, engine: Literal["h5netcdf", "netcdf4"] = "netcdf4", + size: Optional[int] = None, + auto_chunks: bool = False, + decode_cf: bool = True, + force_horizontal: bool = False, ) -> None: - """Rechunk netcdf files. + """ + Rechunk a netCDF file. Parameters ---------- - input_path: os.PathLike - Input file/directory. If a directory is given all ``.nc`` in all sub - directories will be processed - output_path: os.PathLike - Output file/directory of the chunked netcdf file(s). Note: If ``input`` - is a directory output should be a directory. If None given (default) - the ``input`` is overidden. - decode_cf: bool, default: True - Whether to decode these variables, assuming they were saved according - to CF conventions. - engine: str, default: netcdf4 - The netcdf engine used to create the new netcdf file. + input_path : os.PathLike + The path to the netCDF file or directory to be rechunked. + output_path : Optional[os.PathLike], optional + The path to the directory or file where the rechunked data will be saved. + If None, the file is overwritten. Defaults to None. + engine : Literal["h5netcdf", "netcdf4"], optional + The engine to use for writing the netCDF file. Defaults to "netcdf4". + size : Optional[int], optional + The desired chunk size in MiB. If None, the default chunk size is used. + Defaults to None. + auto_chunks : bool, optional + If True, determine the chunk size automatically using Dask. Defaults to False. + decode_cf : bool, optional + Whether to decode CF conventions. Defaults to True. + force_horizontal : bool, optional + If True, force the chunk size to be in the horizontal dimensions + (y and x). Defaults to False. + + Returns + ------- + None + The function processes and saves the rechunked dataset(s) to the specified + ``output_path``. """ input_path = Path(input_path).expanduser().absolute() for input_file in _search_for_nc_files(input_path): @@ -187,7 +478,9 @@ def rechunk_netcdf_file( parallel=True, decode_cf=decode_cf, ) as nc_data: - new_data, encoding = _rechunk_dataset(nc_data, engine) + new_data, encoding = _rechunk_dataset( + nc_data, engine, size, auto_chunks, force_horizontal + ) if encoding: logger.debug( "Loading data into memory (%s).", @@ -208,3 +501,43 @@ def rechunk_netcdf_file( engine, override=output_file != input_file, ) + + +def check_chunk_size(dset: xr.Dataset) -> None: + """ + Estimates the chunk size of a dataset in MB. + + Parameters + ---------- + dset: xr.Dataset + The dataset for which to estimate the chunk size. + + Returns + ------- + None + This function prints out the estimated chunk size for each variable in a + dataset. The chunk size is estimated by multiplying the size of a single + element of the data type with the product of all chunk sizes. 
+    """
+    for var in dset.data_vars:
+        print(f"\n{var}:\t{dict(dset[var].sizes)}")
+        dtype_size = np.dtype(dset[var].dtype).itemsize
+        chunksizes = dset[var].encoding.get("chunksizes")
+
+        if chunksizes is None:
+            print(
+                f" ⚠️ Warning: No chunk sizes found for {var}, skipping...\n"
+            )
+            continue
+
+        chunksizes = tuple(
+            filter(lambda x: isinstance(x, (int, np.integer)), chunksizes)
+        )
+        chunk_size_bytes = np.prod(chunksizes) * dtype_size
+        chunk_size_mb = chunk_size_bytes / (1024**2)  # Convert to MB
+
+        chunks = dset[var].chunks or tuple(() for _ in dset[var].dims)
+
+        print(f" * Chunk Sizes: {chunksizes}")
+        print(f" * Estimated Chunk Size: {chunk_size_mb:.2f} MB")
+        print(f" * Chunks: {dict(zip(dset[var].dims, map(tuple, chunks)))}\n")
diff --git a/src/rechunk_data/tests/conftest.py b/src/rechunk_data/tests/conftest.py
index 02d03b10aca9659dbdd861cd4953fa2eac8f1325..0a6612bd8acc8e638e64d284f12c3e80ea9f590d 100644
--- a/src/rechunk_data/tests/conftest.py
+++ b/src/rechunk_data/tests/conftest.py
@@ -1,12 +1,12 @@
 """pytest definitions to run the unittests."""
 
 from pathlib import Path
-from tempfile import TemporaryDirectory, NamedTemporaryFile
+from tempfile import NamedTemporaryFile, TemporaryDirectory
 from typing import Generator, Tuple
 
 import dask
-import pytest
 import numpy as np
+import pytest
 import xarray as xr
 
 
diff --git a/src/rechunk_data/tests/test_check_chunksize.py b/src/rechunk_data/tests/test_check_chunksize.py
new file mode 100644
index 0000000000000000000000000000000000000000..22d6edb206b6052c5e4ba680cc5ad9bf6cd51e2c
--- /dev/null
+++ b/src/rechunk_data/tests/test_check_chunksize.py
@@ -0,0 +1,37 @@
+"""Unit tests for checking the chunksize of the data."""
+
+import numpy as np
+import xarray as xr
+
+from rechunk_data import check_chunk_size
+
+
+def test_check_chunk_size(capsys) -> None:
+    """Test the check_chunk_size function with valid and invalid chunksizes."""
+
+    data1 = np.random.rand(100, 1100, 1200)
+    da1 = xr.DataArray(data1, dims=("time", "lat", "lon"), name="valid_var")
+    dset1 = xr.Dataset({"valid_var": da1})
+    dset1 = dset1.chunk({"time": 10, "lat": 550, "lon": 600})
+    dset1["valid_var"].encoding = {
+        "chunksizes": (10, 550, 600),
+        "dtype": "float32",
+        "zlib": True,
+    }
+
+    data2 = np.random.rand(100, 1100, 1200)
+    da2 = xr.DataArray(data2, dims=("time", "lat", "lon"), name="invalid_var")
+    dset2 = xr.Dataset({"invalid_var": da2})
+
+    combined_dset = xr.merge([dset1, dset2])
+
+    check_chunk_size(combined_dset)
+    captured = capsys.readouterr()
+
+    assert "valid_var" in captured.out
+    assert "invalid_var" in captured.out
+    assert (
+        "⚠️ Warning: No chunk sizes found for invalid_var, skipping..."
+ in captured.out + ) + assert "Estimated Chunk Size: 25.18 MB" in captured.out diff --git a/src/rechunk_data/tests/test_cli.py b/src/rechunk_data/tests/test_cli.py index cdf7c51f8f101470adaa4885b2074c244b23a092..6ed1aaf8edaf1696623445bcc16f5722232fe81e 100644 --- a/src/rechunk_data/tests/test_cli.py +++ b/src/rechunk_data/tests/test_cli.py @@ -1,9 +1,12 @@ """Unit tests for the cli.""" +import sys +from io import StringIO from pathlib import Path from tempfile import TemporaryDirectory import pytest + from rechunk_data import cli @@ -20,3 +23,17 @@ def test_command_line_interface(data_dir: Path) -> None: cli([str(data_dir), "--output", temp_dir]) new_files = sorted(Path(temp_dir).rglob("*.nc")) assert len(data_files) == len(new_files) + + +def test_check_chunk_size(data_dir: Path) -> None: + """Test the --check-chunk-size argument.""" + + data_files = sorted(data_dir.rglob("*.nc"))[0] + captured_output = StringIO() + sys.stdout = captured_output + cli([str(data_files), "--check-chunk-size"]) + sys.stdout = sys.__stdout__ + output = captured_output.getvalue() + assert "tas:" in output + assert "Chunk Sizes: (1, 1, 24, 24)" in output + assert "Estimated Chunk Size: 0.00 MB" in output diff --git a/src/rechunk_data/tests/test_rechunk_netcdf.py b/src/rechunk_data/tests/test_rechunk_netcdf.py index 3544d3d6138bbc1c9b4204c9b206ac88b41708b2..3bc542c1f133b75935e719f80ee22cc54d3b7044 100644 --- a/src/rechunk_data/tests/test_rechunk_netcdf.py +++ b/src/rechunk_data/tests/test_rechunk_netcdf.py @@ -1,12 +1,14 @@ """Test the actual rechunk method.""" + import logging -from pathlib import Path import time +from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory import dask import pytest -from rechunk_data import rechunk_netcdf_file, rechunk_dataset + +from rechunk_data import rechunk_dataset, rechunk_netcdf_file from rechunk_data._rechunk import _save_dataset @@ -24,7 +26,9 @@ def test_rechunk_data_dir_without_overwrite(data_dir: Path) -> None: """Testing the creation of new datafiles from a folder.""" with TemporaryDirectory() as temp_dir: rechunk_netcdf_file(data_dir, Path(temp_dir)) - new_files = sorted(f.relative_to(temp_dir) for f in Path(temp_dir).rglob(".nc")) + new_files = sorted( + f.relative_to(temp_dir) for f in Path(temp_dir).rglob(".nc") + ) old_files = sorted(f.relative_to(data_dir) for f in data_dir.rglob(".nc")) assert new_files == old_files @@ -59,13 +63,15 @@ def test_wrong_or_format(small_chunk_data, caplog) -> None: _, loglevel, message = caplog.record_tuples[-1] assert loglevel == logging.ERROR assert "Error while" in message - _save_dataset(small_chunk_data, temp_file, {}, "foo") + _save_dataset(small_chunk_data, temp_file, {}, "foo") # type: ignore[arg-type] _, loglevel, message = caplog.record_tuples[-1] - _save_dataset(small_chunk_data, temp_file, {"foo": "bar"}, "foo") + _save_dataset( + small_chunk_data, temp_file, {"foo": "bar"}, "foo" + ) # type: ignore[arg-type] _, loglevel, message = caplog.record_tuples[-1] assert loglevel == logging.ERROR def test_wrong_engine(small_chunk_data) -> None: with pytest.raises(ValueError): - rechunk_dataset(small_chunk_data, engine="foo") + rechunk_dataset(small_chunk_data, engine="foo") # type: ignore[arg-type] diff --git a/src/rechunk_data/tests/test_rechunking.py b/src/rechunk_data/tests/test_rechunking.py index 945bb0d2e040b5e3f149304f2aee9374423c93e8..4f0e9e186b64b23b529a93f7360a2466d047b816 100644 --- a/src/rechunk_data/tests/test_rechunking.py +++ b/src/rechunk_data/tests/test_rechunking.py @@ -1,8 +1,14 
@@ """Unit tests for rechunking the data.""" + import dask +import numpy as np import xarray as xr -from rechunk_data._rechunk import _rechunk_dataset +from rechunk_data._rechunk import ( + _horizontal_chunks, + _optimal_chunks, + _rechunk_dataset, +) def test_rechunking_small_data( @@ -35,3 +41,121 @@ def test_rechunking_large_data( dset, encoding = _rechunk_dataset(large_chunk_data, "h5netcdf") assert encoding[variable_name]["chunksizes"] == chunks assert dset[variable_name].data.chunksize == chunks + + +def test_optimal_chunks(): + """Test the optimal chunk calculation.""" + + time_size = 100 + y_size = 192 + x_size = 384 + dtype = np.dtype("float32") + + time_chunk, y_chunk, x_chunk = _optimal_chunks( + time_size, y_size, x_size, dtype=dtype, only_horizontal=False + ) + assert time_chunk <= time_size + assert y_chunk <= y_size + assert x_chunk <= x_size + + time_chunk, y_chunk, x_chunk = _optimal_chunks( + time_size, y_size, x_size, dtype=dtype, only_horizontal=True + ) + assert time_chunk == time_size + assert y_chunk <= y_size + assert x_chunk <= x_size + + time_size = 100_000_000 + y_size = 1 + x_size = 1 + time_chunk, y_chunk, x_chunk = _optimal_chunks( + time_size, y_size, x_size, dtype=dtype, only_horizontal=True + ) + assert time_chunk < time_size + + +def test_horizontal_chunks(): + """Test horizontal chunking function and _rechunk_dataset() with + force_horizontal.""" + data = np.random.rand(100, 1100, 1200) + da = xr.DataArray(data, dims=("time", "lat", "lon"), name="test_var") + + chunks = {0: "auto", 1: None, 2: None} + updated_chunks = _horizontal_chunks(da, chunks) + assert updated_chunks[0] == 100 + assert updated_chunks[1] < 1100 + assert updated_chunks[2] < 1200 + + updated_chunks = _horizontal_chunks("invalid_data", chunks) + assert updated_chunks == chunks + + dset = xr.Dataset({"test_var": da}) + chunksizes = (10, 550, 1200) + dset = dset.chunk( + {"time": chunksizes[0], "lat": chunksizes[1], "lon": chunksizes[2]} + ) + encoding = { + "test_var": { + "zlib": True, + "complevel": 4, + "shuffle": True, + "dtype": "float32", + "chunksizes": chunksizes, + "_FillValue": np.nan, + } + } + dset["test_var"].encoding = encoding + _, encoding = _rechunk_dataset( + dset, engine="netcdf4", force_horizontal=True + ) + assert "test_var" in encoding + + chunksizes_applied = encoding["test_var"].get("chunksizes", None) + assert chunksizes_applied is not None + assert chunksizes_applied[0] == 100 + assert chunksizes_applied[1] < 1100 + assert chunksizes_applied[2] < 1200 + + dset["test_var"].encoding = encoding + _, encoding = _rechunk_dataset( + dset, engine="netcdf4", force_horizontal=True, auto_chunks=True + ) + chunksizes_applied = encoding["test_var"].get("chunksizes", None) + assert chunksizes_applied[0] < 100 + assert chunksizes_applied[1] < 1100 + assert chunksizes_applied[2] < 1200 + + +def test_auto_size_chunks(): + """ + Test the automatic chunking and size adjustment functionality. 
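+
+    Both calls only inspect the encoding returned by ``_rechunk_dataset``;
+    the chunking of ``dset`` itself is left unchanged.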
+ """ + data = np.random.rand(100, 1100, 1200) + da = xr.DataArray(data, dims=("time", "lat", "lon"), name="test_var") + dset = xr.Dataset({"test_var": da}) + chunksizes = (1, 1, 1200) + dset = dset.chunk( + {"time": chunksizes[0], "lat": chunksizes[1], "lon": chunksizes[2]} + ) + encoding = { + "test_var": { + "zlib": True, + "complevel": 4, + "shuffle": True, + "dtype": "float32", + "chunksizes": chunksizes, + "_FillValue": np.nan, + } + } + dset["test_var"].encoding = encoding + _, encoding = _rechunk_dataset(dset, engine="netcdf4", auto_chunks=True) + chunksizes_applied = encoding["test_var"].get("chunksizes", None) + assert chunksizes_applied[0] == 100 + assert chunksizes_applied[2] == 1200 + + dset["test_var"].encoding = encoding + _, encoding = _rechunk_dataset(dset, engine="netcdf4", size=20) + chunksizes_applied = encoding["test_var"].get("chunksizes", None) + assert 1 < chunksizes_applied[0] < 100 + assert 1 < chunksizes_applied[1] < 1100 + assert chunksizes_applied[2] == 1200 diff --git a/src/rechunk_data/tests/test_search_files.py b/src/rechunk_data/tests/test_search_files.py index cff79c039c46fbbc7c61bbd8b470b37320d27965..297ee6468ff6037bec8109593226b6b5126bed1f 100644 --- a/src/rechunk_data/tests/test_search_files.py +++ b/src/rechunk_data/tests/test_search_files.py @@ -1,4 +1,5 @@ """Unit tests for searching for files.""" + from pathlib import Path from rechunk_data._rechunk import _search_for_nc_files