Commit f7a9e598 authored by Etor Lucio Eceiza

add: code coverage for new functions, linting

parent 5935922b
1 merge request: !15 add functionality to allow prioritizing horizontal chunking over temporal dimension
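For orientation, a minimal sketch of the behavior this merge request adds, patterned on the new tests further down in this diff. The force_horizontal keyword and the private _rechunk_dataset helper are taken from those tests; treat this as an illustration, not documented public API:

import numpy as np
import xarray as xr

from rechunk_data._rechunk import _rechunk_dataset

# A dataset whose horizontal (lat/lon) chunks still span the full grid,
# i.e. only the time dimension is split.
da = xr.DataArray(
    np.random.rand(100, 192, 384), dims=("time", "lat", "lon"), name="tas"
)
dset = xr.Dataset({"tas": da}).chunk({"time": 10, "lat": 192, "lon": 384})
dset["tas"].encoding = {"chunksizes": (10, 192, 384), "dtype": "float32", "zlib": True}

# force_horizontal=True asks the rechunker to split lat/lon rather than time.
dset, encoding = _rechunk_dataset(dset, engine="netcdf4", force_horizontal=True)
print(encoding["tas"]["chunksizes"])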
@@ -7,10 +7,17 @@ from typing import List, Optional
 from ._rechunk import (
     rechunk_netcdf_file,
     check_chunk_size,
-    rechunk_dataset,
+    rechunk_dataset,  # noqa
     logger,
 )
+__all__ = [
+    "rechunk_netcdf_file",
+    "check_chunk_size",
+    "rechunk_dataset",
+    "logger",
+]
 __version__ = "2503.0.1"
 PROGRAM_NAME = "rechunk-data"
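A quick import check of the public surface this hunk makes explicit (a sketch using only the names exported above):

import rechunk_data
from rechunk_data import rechunk_netcdf_file, check_chunk_size, rechunk_dataset, logger

print(rechunk_data.__version__)   # 2503.0.1
print(rechunk_data.PROGRAM_NAME)  # rechunk-data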
@@ -260,11 +260,10 @@ def _horizontal_chunks(
     dict
         Updated chunk dictionary with appropriate chunk sizes.
     """
-    orig_size = dict(da.sizes)
-    chunksizes_dict = dict(zip(da.dims, chunksizes))
-    chunk_order = dict(zip(da.dims, chunks))
     try:
+        orig_size = dict(da.sizes)
+        chunk_order = dict(zip(da.dims, chunks))
+        chunksizes_dict = dict(zip(da.dims, chunksizes))
         if _check_horizontal_unchanged(orig_size, chunksizes_dict):
             time_size = next(
                 (
"""Unit tests for checking the chunksize of the data."""
import xarray as xr
import numpy as np
from rechunk_data import check_chunk_size

def test_check_chunk_size(capsys) -> None:
"""Test the check_chunk_size function with valid and invalid chunksizes."""
data1 = np.random.rand(100, 1100, 1200)
da1 = xr.DataArray(data1, dims=("time", "lat", "lon"), name="valid_var")
dset1 = xr.Dataset({"valid_var": da1})
dset1 = dset1.chunk({"time": 10, "lat": 550, "lon": 600})
dset1["valid_var"].encoding = {
"chunksizes": (10, 550, 600),
"dtype": "float32",
"zlib": True,
}
data2 = np.random.rand(100, 1100, 1200)
da2 = xr.DataArray(data2, dims=("time", "lat", "lon"), name="invalid_var")
dset2 = xr.Dataset({"invalid_var": da2})
combined_dset = xr.merge([dset1, dset2])
check_chunk_size(combined_dset)
captured = capsys.readouterr()
assert "valid_var" in captured.out
assert "invalid_var" in captured.out
assert (
"⚠️ Warning: No chunk sizes found for invalid_var, skipping..."
in captured.out
)
assert "Estimated Chunk Size: 25.18 MB" in captured.out
@@ -2,9 +2,10 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from io import StringIO
import pytest
from rechunk_data import cli
import sys

def test_command_line_interface(data_dir: Path) -> None:
@@ -20,3 +21,17 @@ def test_command_line_interface(data_dir: Path) -> None:
cli([str(data_dir), "--output", temp_dir])
new_files = sorted(Path(temp_dir).rglob("*.nc"))
assert len(data_files) == len(new_files)

def test_check_chunk_size(data_dir: Path) -> None:
"""Test the --check-chunk-size argument."""
data_files = sorted(data_dir.rglob("*.nc"))[0]
captured_output = StringIO()
sys.stdout = captured_output
cli([str(data_files), "--check-chunk-size"])
sys.stdout = sys.__stdout__
output = captured_output.getvalue()
assert "tas:" in output
assert "Chunk Sizes: (1, 1, 24, 24)" in output
assert "Estimated Chunk Size: 0.00 MB" in output
@@ -2,8 +2,14 @@
 import dask
 import xarray as xr
 import numpy as np
-from rechunk_data._rechunk import _rechunk_dataset
+from rechunk_data._rechunk import (
+    _rechunk_dataset,
+    _optimal_chunks,
+    _check_horizontal_unchanged,
+    _horizontal_chunks,
+)

 def test_rechunking_small_data(
@@ -36,3 +42,97 @@ def test_rechunking_large_data(
dset, encoding = _rechunk_dataset(large_chunk_data, "h5netcdf")
assert encoding[variable_name]["chunksizes"] == chunks
assert dset[variable_name].data.chunksize == chunks

def test_optimal_chunks():
"""Test the optimal chunk calculation."""
time_size = 100
y_size = 192
x_size = 384
dtype = np.dtype("float32")
time_chunk, y_chunk, x_chunk = _optimal_chunks(
time_size, y_size, x_size, dtype=dtype, only_horizontal=False
)
assert time_chunk <= time_size
assert y_chunk <= y_size
assert x_chunk <= x_size
time_chunk, y_chunk, x_chunk = _optimal_chunks(
time_size, y_size, x_size, dtype=dtype, only_horizontal=True
)
assert time_chunk == time_size
assert y_chunk <= y_size
assert x_chunk <= x_size
time_size = 100_000_000
y_size = 1
x_size = 1
time_chunk, y_chunk, x_chunk = _optimal_chunks(
time_size, y_size, x_size, dtype=dtype, only_horizontal=True
)
assert time_chunk < time_size
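These assertions pin down a contract rather than an algorithm: chunks never exceed dimension sizes, only_horizontal=True keeps the time dimension whole while a chunk fits the byte budget, and even time is split once a single-timestep chunk would be too large. A behavior-compatible sketch that satisfies the three cases above; the halving strategy and the 128 MB budget are assumptions, not the library's actual implementation:

import numpy as np

TARGET_BYTES = 128 * 1024**2  # assumed chunk budget, not taken from the source

def optimal_chunks_sketch(time_size, y_size, x_size, dtype, only_horizontal):
    """Halve dimensions until one chunk fits the byte budget (illustration only)."""
    t, y, x = time_size, y_size, x_size
    while t * y * x * dtype.itemsize > TARGET_BYTES and (t > 1 or y > 1 or x > 1):
        if not only_horizontal and t > 1:
            t = max(1, t // 2)  # shrink time first when allowed
        elif y > 1 or x > 1:
            y, x = max(1, y // 2), max(1, x // 2)  # otherwise shrink lat/lon
        else:
            t = max(1, t // 2)  # fall back to time when lat/lon are exhausted
    return t, y, x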

def test_check_horizontal_unchanged():
"""Test detection of horizontal chunks remaining unchanged."""
orig_size = {"time": 100, "lat": 192, "lon": 384}
chunksizes = {"time": 50, "lat": 192, "lon": 384}
assert _check_horizontal_unchanged(orig_size, chunksizes) is True
chunksizes = {"time": 50, "lat": 192, "lon": 128}
assert _check_horizontal_unchanged(orig_size, chunksizes) is True
chunksizes = {"time": 50, "lat": 128, "lon": 128}
assert _check_horizontal_unchanged(orig_size, chunksizes) is False
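Read together, the three cases define the helper's contract: it returns True as soon as any horizontal dimension is still chunked at its full original extent. A behavior-equivalent sketch derived from these tests (the lat/lon dimension names are assumed from them; this is not the actual implementation):

def check_horizontal_unchanged_sketch(orig_size: dict, chunksizes: dict) -> bool:
    """True if any horizontal dim's chunk still spans the whole dimension."""
    return any(
        chunksizes.get(dim) == orig_size.get(dim) for dim in ("lat", "lon")
    )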

def test_horizontal_chunks():
"""Test horizontal chunking function and _rechunk_dataset() with
force_horizontal."""
data = np.random.rand(100, 1100, 1200)
da = xr.DataArray(data, dims=("time", "lat", "lon"), name="test_var")
chunks = {0: "auto", 1: None, 2: None}
chunksizes = (1, 100, 1200)
updated_chunks = _horizontal_chunks(da, chunks, chunksizes)
assert updated_chunks[0] == 100
assert updated_chunks[1] < 1100
assert updated_chunks[2] < 1200
chunksizes = (50, 100, 600)
updated_chunks = _horizontal_chunks(da, chunks, chunksizes)
assert updated_chunks == chunks
updated_chunks = _horizontal_chunks("invalid_data", chunks, chunksizes)
assert updated_chunks == chunks
dset = xr.Dataset({"test_var": da})
chunksizes = (10, 550, 1200)
dset = dset.chunk(
{"time": chunksizes[0], "lat": chunksizes[1], "lon": chunksizes[2]}
)
encoding = {
"test_var": {
"zlib": True,
"complevel": 4,
"shuffle": True,
"dtype": "float32",
"chunksizes": chunksizes,
"_FillValue": np.nan,
}
}
dset["test_var"].encoding = encoding
_, encoding = _rechunk_dataset(
dset, engine="netcdf4", force_horizontal=True
)
assert "test_var" in encoding
chunksizes_applied = encoding["test_var"].get("chunksizes", None)
assert chunksizes_applied is not None
assert chunksizes_applied[0] == 100
assert chunksizes_applied[1] < 1100
assert chunksizes_applied[2] < 1200