Skip to content
Snippets Groups Projects
Commit 2687a49b authored by Etor Lucio Eceiza's avatar Etor Lucio Eceiza
Browse files

update tests

parent f10e9801
No related branches found
No related tags found
1 merge request: !15 "add functionality to allow prioritizing horizontal chunking over temporal dimension"
......@@ -3,17 +3,16 @@ import os
from pathlib import Path
from typing import (
Any,
Generator,
Mapping,
Dict,
Generator,
Hashable,
List,
Optional,
Tuple,
)
from typing_extensions import Literal
import xarray as xr
from typing_extensions import Literal
# Type-stub signatures (implementation lives in the package module).
# Parses command-line arguments for the rechunking CLI.
def parse_args() -> argparse.Namespace: ...
# Lazily yields paths to netCDF files under *input_path* — presumably
# matching *.nc; TODO confirm the actual glob pattern in the implementation.
def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]: ...
......
"""pytest definitions to run the unittests."""
from pathlib import Path
from tempfile import TemporaryDirectory, NamedTemporaryFile
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Generator, Tuple
import dask
import pytest
import numpy as np
import pytest
import xarray as xr
......
"""Unit tests for checking the chunksize of the data."""
import xarray as xr
import numpy as np
import xarray as xr
from rechunk_data import check_chunk_size
......
"""Unit tests for the cli."""
import sys
from io import StringIO
from pathlib import Path
from tempfile import TemporaryDirectory
from io import StringIO
import pytest
from rechunk_data import cli
import sys
def test_command_line_interface(data_dir: Path) -> None:
......
"""Test the actual rechunk method."""
import logging
from pathlib import Path
import time
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
import dask
import pytest
from rechunk_data import rechunk_netcdf_file, rechunk_dataset
from rechunk_data import rechunk_dataset, rechunk_netcdf_file
from rechunk_data._rechunk import _save_dataset
......@@ -62,13 +63,15 @@ def test_wrong_or_format(small_chunk_data, caplog) -> None:
_, loglevel, message = caplog.record_tuples[-1]
assert loglevel == logging.ERROR
assert "Error while" in message
_save_dataset(small_chunk_data, temp_file, {}, "foo")
_save_dataset(small_chunk_data, temp_file, {}, "foo") # type: ignore[arg-type]
_, loglevel, message = caplog.record_tuples[-1]
_save_dataset(small_chunk_data, temp_file, {"foo": "bar"}, "foo")
_save_dataset(
small_chunk_data, temp_file, {"foo": "bar"}, "foo"
) # type: ignore[arg-type]
_, loglevel, message = caplog.record_tuples[-1]
assert loglevel == logging.ERROR
def test_wrong_engine(small_chunk_data) -> None:
    """An unsupported backend engine name must be rejected with ValueError."""
    with pytest.raises(ValueError):
        rechunk_dataset(small_chunk_data, engine="foo")
        # NOTE(review): duplicated call below looks like diff residue (old/new
        # line of the same hunk) — the first call raises, so this never runs.
        rechunk_dataset(small_chunk_data, engine="foo")  # type: ignore[arg-type]
"""Unit tests for rechunking the data."""
import dask
import xarray as xr
import numpy as np
import xarray as xr
from rechunk_data._rechunk import (
_rechunk_dataset,
_optimal_chunks,
_check_horizontal_unchanged,
_horizontal_chunks,
_optimal_chunks,
_rechunk_dataset,
)
......@@ -75,20 +74,6 @@ def test_optimal_chunks():
assert time_chunk < time_size
def test_check_horizontal_unchanged():
    """Test detection of horizontal chunks remaining unchanged."""
    original_sizes = {"time": 100, "lat": 192, "lon": 384}
    # Each case: (proposed chunk sizes, expected "horizontal unchanged" flag).
    cases = (
        ({"time": 50, "lat": 192, "lon": 384}, True),
        ({"time": 50, "lat": 192, "lon": 128}, True),
        ({"time": 50, "lat": 128, "lon": 128}, False),
    )
    for proposed, expected in cases:
        assert _check_horizontal_unchanged(original_sizes, proposed) is expected
def test_horizontal_chunks():
"""Test horizontal chunking function and _rechunk_dataset() with
force_horizontal."""
......@@ -96,18 +81,12 @@ def test_horizontal_chunks():
da = xr.DataArray(data, dims=("time", "lat", "lon"), name="test_var")
chunks = {0: "auto", 1: None, 2: None}
chunksizes = (1, 100, 1200)
updated_chunks = _horizontal_chunks(da, chunks, chunksizes)
updated_chunks = _horizontal_chunks(da, chunks)
assert updated_chunks[0] == 100
assert updated_chunks[1] < 1100
assert updated_chunks[2] < 1200
chunksizes = (50, 100, 600)
updated_chunks = _horizontal_chunks(da, chunks, chunksizes)
assert updated_chunks == chunks
updated_chunks = _horizontal_chunks("invalid_data", chunks, chunksizes)
updated_chunks = _horizontal_chunks("invalid_data", chunks)
assert updated_chunks == chunks
dset = xr.Dataset({"test_var": da})
......@@ -136,3 +115,47 @@ def test_horizontal_chunks():
assert chunksizes_applied[0] == 100
assert chunksizes_applied[1] < 1100
assert chunksizes_applied[2] < 1200
dset["test_var"].encoding = encoding
_, encoding = _rechunk_dataset(
dset, engine="netcdf4", force_horizontal=True, auto_chunks=True
)
chunksizes_applied = encoding["test_var"].get("chunksizes", None)
assert chunksizes_applied[0] < 100
assert chunksizes_applied[1] < 1100
assert chunksizes_applied[2] < 1200
def test_auto_size_chunks():
    """
    Test the automatic chunking and size adjustment functionality.
    """
    dims = ("time", "lat", "lon")
    array = xr.DataArray(
        np.random.rand(100, 1100, 1200), dims=dims, name="test_var"
    )
    dset = xr.Dataset({"test_var": array})
    # Deliberately tiny chunks along time/lat so the rechunker has work to do.
    chunksizes = (1, 1, 1200)
    dset = dset.chunk(dict(zip(dims, chunksizes)))
    encoding = {
        "test_var": {
            "zlib": True,
            "complevel": 4,
            "shuffle": True,
            "dtype": "float32",
            "chunksizes": chunksizes,
            "_FillValue": np.nan,
        }
    }
    dset["test_var"].encoding = encoding
    # auto_chunks should expand the degenerate chunks to full-dimension sizes
    # along time and lon.
    _, encoding = _rechunk_dataset(dset, engine="netcdf4", auto_chunks=True)
    applied = encoding["test_var"].get("chunksizes", None)
    assert applied[0] == 100
    assert applied[2] == 1200
    dset["test_var"].encoding = encoding
    # An explicit size target should shrink time/lat chunks below the full
    # dimension while keeping lon contiguous.
    _, encoding = _rechunk_dataset(dset, engine="netcdf4", size=20)
    applied = encoding["test_var"].get("chunksizes", None)
    assert 1 < applied[0] < 100
    assert 1 < applied[1] < 1100
    assert applied[2] == 1200
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment