diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 86ea82140133b3d89876252316a133f374df4444..2f879708645bd9c06350c2681e7dbc344ac2e2cd 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -13,45 +13,45 @@ lint:
     - pip install .[test]
   script:
     - mypy
-    - black --check src
+    - black --check -l 79 src
     - pylint --fail-under 8.5 src/rechunk_data/__init__.py
 
-test_36:
+test_39:
   << : *py_test
   before_script:
-    - conda create -q -p /tmp/test python=3.6 pip dask -y
+    - conda create -q -p /tmp/test python=3.9 pip dask -y
     - /tmp/test/bin/python -m pip install -e .[test]
   script:
     - /tmp/test/bin/python -m pytest -vv
 
-test_37:
+test_310:
   << : *py_test
   before_script:
-    - conda create -q -p /tmp/test python=3.7 pip dask -y
+    - conda create -q -p /tmp/test python=3.10 pip dask -y
     - /tmp/test/bin/python -m pip install -e .[test]
   script:
     - /tmp/test/bin/python -m pytest -vv
 
-test_38:
-  << : *py_test
-  before_script:
-    - conda create -q -p /tmp/test python=3.8 pip dask -y
-    - /tmp/test/bin/python -m pip install -e .[test]
-  script:
-    - /tmp/test/bin/python -m pytest -vv
-
-test_39:
+test_311:
   << : *py_test
   before_script:
-    - conda create -q -p /tmp/test python=3.9 pip dask -y
+    - conda create -q -p /tmp/test python=3.11 pip dask -y
     - /tmp/test/bin/python -m pip install -e .[test]
   script:
     - /tmp/test/bin/python -m pytest -vv
 
-test_310:
+test_312:
   << : *py_test
   before_script:
-    - conda create -q -p /tmp/test python=3.10 pip dask -y
+    - conda create -q -p /tmp/test python=3.12 pip dask -y
+    - /tmp/test/bin/python -m pip install -e .[test]
+  script:
+    - /tmp/test/bin/python -m pytest -vv
+
+test_313:
+  << : *py_test
+  before_script:
+    - conda create -q -p /tmp/test python=3.13 pip dask -y
     - /tmp/test/bin/python -m pip install -e .[test]
   script:
     - /tmp/test/bin/python -m pytest -vv
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..73c495a161feaf01220a9a07aaddcb9fe53e7bd9
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,24 @@
+# makefile used for testing
+#
+#
+all: install test
+
+install:
+	python3 -m pip install -e .[test]
+
+test:
+	python3 -m pytest -vv \
+	    --cov=$(PWD)/src/rechunk_data --cov-report=html:coverage_report \
+	    --junitxml report.xml --cov-report xml \
+		$(PWD)/src/rechunk_data/tests
+	python3 -m coverage report
+
+format:
+	isort --profile=black src
+	black -t py310 -l 79 src
+
+lint:
+	mypy --install-types --non-interactive
+	isort --check-only --profile=black src
+	black --check -t py310 -l 79 src
+	flake8 src/rechunk_data --count --max-complexity=15 --max-line-length=88 --statistics --doctests
\ No newline at end of file
diff --git a/README.md b/README.md
index b75b5d68a901f69260102416628a46fbf95af2b1..fbef163304b0c92cdaab8ec314399d43f535e8d2 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # Rechunking NetCDF data.
 
-Rechunking of exsisting netcdf files to an optimal chunk size. This code provides
+Rechunking of existing netcdf files to an optimal chunk size. This code provides
 a simple command line interface (cli) to rechunk existing netcdf data to an
 optimal chunksize of around 128 MB.
 
@@ -19,31 +19,39 @@ Use the `--user` flag if you do not have super user rights and are not using `an
 ### Using the python module
 
 ```python
-from rechunk_data import rechunk_dataset
+from rechunk_data import rechunk_dataset, check_chunk_size
 import xarray as xr
 dset = xr.open_mfdataset("/data/*", parallel=True, combine="by_coords")
+check_chunk_size(dset)  # print the chunk sizes of the original setup
 new_data = rechunk_dataset(dset)
 ```
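+
+Files on disk can also be rechunked directly with `rechunk_netcdf_file`; the
+paths below are purely illustrative:
+
+```python
+from rechunk_data import rechunk_netcdf_file
+
+rechunk_netcdf_file("/data/input.nc", "/data/rechunked.nc")
+```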
 
 ### Using the command line interface:
 
 ```bash
-rechunk-data --help
-usage: rechunk-data [-h] [--output OUTPUT] [--netcdf-engine {h5netcdf,netcdf4}] [--skip-cf-convention] [-v] [-V] input
+rechunk-data --help
+usage: rechunk-data [-h] [-o OUTPUT] [--netcdf-engine {h5netcdf,netcdf4}] [--size SIZE] [--auto-chunks] [--skip-cf-convention] [--force-horizontal] [--check-chunk-size]
+                    [-v] [-V]
+                    input
 
-Rechunk input netcdf data to optimal chunk-size. approx. 126 MB per chunk
+Rechunk input netcdf data to optimal chunk-size. approx. 126 MB per chunk. By default only the time dimension is optimised.
 
 positional arguments:
   input                 Input file/directory. If a directory is given all ``.nc`` files in all sub directories will be processed
 
 options:
   -h, --help            show this help message and exit
-  --output OUTPUT       Output file/directory of the chunked netcdf file(s).
-                        Note: If ``input`` is a directory output should be a directory.
-                        If None given (default) the ``input`` is overidden. (default: None)
+  -o, --output OUTPUT   Output file/directory of the chunked netcdf file(s). Note: If ``input`` is a directory output should be a directory. If None given (default) the
+                        ``input`` is overridden. (default: None)
   --netcdf-engine {h5netcdf,netcdf4}
                         The netcdf engine used to create the new netcdf file. (default: netcdf4)
+  --size, -s SIZE       Specify chunk-size (in MiB). (default: None)
+  --auto-chunks, -ac    Allow Dask to determine optimal chunk sizes for all dimensions. (default: False)
   --skip-cf-convention  Do not assume assume data variables follow CF conventions. (default: False)
+  --force-horizontal, -fh
+                        Force horizontal chunking (~126 MB per chunk). (default: False)
+  --check-chunk-size, -c
+                        Check the chunk size of the input dataset (in MB). (default: False)
   -v                    Increase verbosity (default: 0)
   -V, --version         show program's version number and exit
 ```
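+
+A typical invocation might look like this (the paths are illustrative):
+
+```bash
+rechunk-data --size 64 --force-horizontal -o /scratch/rechunked /data/exp1
+```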
diff --git a/setup.py b/setup.py
index ae0a411fec4c8e4ef2171dc05b7a74d6ec392c29..3bbc8c6a1231d9a218445027f91fba1f9c7ec7b0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,9 +1,10 @@
 #!/usr/bin/env python3
-"""Setup script for packaging checkin."""
+"""Setup script for packaging check-in."""
 
 import json
 from pathlib import Path
-from setuptools import setup, find_packages
+
+from setuptools import find_packages, setup
 
 
 def read(*parts):
@@ -35,14 +36,25 @@ setup(
     packages=find_packages("src"),
     package_dir={"": "src"},
     entry_points={
-        "console_scripts": [f"{find_key(key='PROGRAM_NAME')} = rechunk_data:cli"]
+        "console_scripts": [
+            f"{find_key(key='PROGRAM_NAME')} = rechunk_data:cli"
+        ]
     },
-    install_requires=["argparse", "dask", "xarray", "h5netcdf", "netCDF4", "typing_extensions"],
+    install_requires=[
+        "argparse",
+        "dask",
+        "xarray",
+        "h5netcdf",
+        "netCDF4",
+        "typing_extensions",
+    ],
     extras_require={
         "test": [
             "black",
+            "isort",
             "mock",
             "mypy",
+            "ipython",
             "nbformat",
             "pytest",
             "pylint",
@@ -51,9 +63,10 @@ setup(
             "pytest-cov",
             "testpath",
             "types-mock",
+            "flake8",
         ],
     },
-    python_requires=">=3.6",
+    python_requires=">=3.9",
     classifiers=[
         "Development Status :: 3 - Alpha",
         "Environment :: Console",
diff --git a/src/rechunk_data/__init__.py b/src/rechunk_data/__init__.py
index 8e68bac7b079026b27d8d2479588bef99b35c82a..da4524a225ab240060e39b026eaa10c9083d54d6 100644
--- a/src/rechunk_data/__init__.py
+++ b/src/rechunk_data/__init__.py
@@ -4,13 +4,18 @@ import argparse
 import logging
 from pathlib import Path
 from typing import List, Optional
-from ._rechunk import (
-    rechunk_netcdf_file,
-    rechunk_dataset,
-    logger,
-)
 
-__version__ = "2310.0.1"
+from ._rechunk import (
+    check_chunk_size,
+    logger,
+    rechunk_dataset,
+    rechunk_netcdf_file,
+)
+
+__all__ = [
+    "rechunk_netcdf_file",
+    "check_chunk_size",
+    "rechunk_dataset",
+    "logger",
+]
+
+__version__ = "2503.0.1"
 PROGRAM_NAME = "rechunk-data"
 
 
@@ -21,7 +26,7 @@ def parse_args(argv: Optional[List[str]]) -> argparse.Namespace:
         prog=PROGRAM_NAME,
         description=(
             "Rechunk input netcdf data to optimal chunk-size."
-            " approx. 126 MB per chunk"
+            " approx. 126 MB per chunk. By default it only optimises time."
         ),
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
@@ -35,12 +40,13 @@ def parse_args(argv: Optional[List[str]]) -> argparse.Namespace:
         ),
     )
     parser.add_argument(
+        "-o",
         "--output",
         type=Path,
         help=(
             "Output file/directory of the chunked netcdf "
             "file(s). Note: If ``input`` is a directory output should be a"
-            " directory. If None given (default) the ``input`` is overidden."
+            " directory. If None given (default) the ``input`` is overridden."
         ),
         default=None,
     )
@@ -51,12 +57,40 @@ def parse_args(argv: Optional[List[str]]) -> argparse.Namespace:
         default="netcdf4",
         type=str,
     )
+    parser.add_argument(
+        "--size",
+        "-s",
+        help=("Specify chunk-size (in MiB)."),
+        default=None,
+        type=str,
+    )
+    parser.add_argument(
+        "--auto-chunks",
+        "-ac",
+        help="Allow Dask to determine optimal chunk sizes for all dimensions.",
+        action="store_true",
+        default=False,
+    )
     parser.add_argument(
         "--skip-cf-convention",
         help="Do not assume assume data variables follow CF conventions.",
         action="store_true",
         default=False,
     )
+    parser.add_argument(
+        "--force-horizontal",
+        "-fh",
+        help="Force horizontal chunking (~126 MB per chunk).",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--check-chunk-size",
+        "-c",
+        help="Check the chunk size of the input dataset (in MB).",
+        action="store_true",
+        default=False,
+    )
     parser.add_argument(
         "-v",
         action="count",
@@ -77,9 +111,18 @@ def parse_args(argv: Optional[List[str]]) -> argparse.Namespace:
 def cli(argv: Optional[List[str]] = None) -> None:
     """Command line interface calling the rechunking method."""
     args = parse_args(argv)
+    if args.check_chunk_size:
+        import xarray as xr
+
+        with xr.open_dataset(args.input, chunks={}) as dset:
+            check_chunk_size(dset)
+        return
     rechunk_netcdf_file(
         args.input,
         args.output,
         engine=args.netcdf_engine,
+        size=args.size,
+        auto_chunks=args.auto_chunks,
         decode_cf=args.skip_cf_convention is False,
+        force_horizontal=args.force_horizontal,
     )
diff --git a/src/rechunk_data/__init__.pyi b/src/rechunk_data/__init__.pyi
index 8a1130084a494d4de4e802e705343756cbf305a0..8bdf7f4d68818e2ed34f0cd299a8a33c335e4107 100644
--- a/src/rechunk_data/__init__.pyi
+++ b/src/rechunk_data/__init__.pyi
@@ -3,17 +3,16 @@ import os
 from pathlib import Path
 from typing import (
     Any,
-    Generator,
-    Mapping,
     Dict,
+    Generator,
     Hashable,
     List,
     Optional,
     Tuple,
 )
-from typing_extensions import Literal
 
 import xarray as xr
+from typing_extensions import Literal
 
 def parse_args() -> argparse.Namespace: ...
 def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]: ...
diff --git a/src/rechunk_data/_rechunk.py b/src/rechunk_data/_rechunk.py
index 311ed2508fb36b78904887a398e13d08c410904c..42384d1faf673b6a8066975110fe487d0ce8ec04 100644
--- a/src/rechunk_data/_rechunk.py
+++ b/src/rechunk_data/_rechunk.py
@@ -1,15 +1,16 @@
 """Rechunking module."""
 
-import os
 import logging
+import os
 from pathlib import Path
-from typing import cast, Any, Dict, Hashable, Generator, Optional, Tuple
-from typing_extensions import Literal
+from typing import Any, Dict, Generator, Hashable, Optional, Tuple, cast
 
-from dask.utils import format_bytes
-from dask.array.core import Array
+import dask
+import numpy as np
 import xarray as xr
-
+from dask.array.core import Array
+from dask.utils import format_bytes
+from typing_extensions import Literal
 
 logging.basicConfig(
     format="%(name)s - %(levelname)s - %(message)s", level=logging.ERROR
@@ -41,9 +42,27 @@ ENCODINGS = dict(
         "dtype",
     },
 )
+default_chunk_size = 126.0  # in MiB
 
 
 def _search_for_nc_files(input_path: Path) -> Generator[Path, None, None]:
+    """Yield all netCDF files in the given input_path.
+
+    If the input is a directory, search recursively for all files with the
+    suffixes .nc or .nc4. If the input is a file, yield only the file itself.
+    If the input is a path containing a glob pattern, the pattern is expanded
+    and the resulting directory is searched for netCDF files.
+
+    Parameters
+    ----------
+    input_path: Path
+        The path to search for netCDF files
+
+    Yields
+    ------
+    Path
+        The path to a netCDF file
+    """
     suffixes = [".nc", "nc4"]
     input_path = input_path.expanduser().absolute()
     if input_path.is_dir() and input_path.exists():
@@ -65,6 +84,28 @@ def _save_dataset(
     engine: Literal["netcdf4", "h5netcdf"],
     override: bool = False,
 ) -> None:
+    """
+    Save the given xarray dataset to a netCDF file with specified encoding.
+
+    Parameters
+    ----------
+    dset : xr.Dataset
+        The dataset to be saved.
+    file_name : Path
+        The path where the netCDF file will be saved.
+    encoding : Dict[Hashable, Dict[str, Any]]
+        Encoding options for each variable in the dataset.
+    engine : Literal["netcdf4", "h5netcdf"]
+        The engine to use for writing the netCDF file.
+    override : bool, optional
+        If True, save the dataset even if no encoding is provided. Default is False.
+
+    Returns
+    -------
+    None
+
+    Logs a debug message indicating successful save or an error message if saving fails.
+    """
     if not encoding and not override:
         logger.debug("Chunk size already optimized for %s", file_name.name)
         return
@@ -79,10 +120,199 @@ def _save_dataset(
         logger.error("Saving to file failed: %s", str(error))
 
 
+def _optimal_chunks(
+    time_size: int,
+    y_size: int,
+    x_size: int,
+    size: float = default_chunk_size,
+    dtype: np.dtype = np.dtype("float32"),
+    only_horizontal: bool = True,
+) -> Tuple[int, int, int]:
+    """
+    Compute optimal chunk sizes for the time, y (lat), and x (lon) dimensions
+    so that a single chunk is close to the requested size in MiB.
+    Optionally, chunk only the horizontal dimensions while keeping the time
+    dimension as a single chunk.
+
+    Parameters
+    ----------
+    time_size: int
+        Total size of the time dimension (kept as a single chunk).
+    y_size: int
+        Total size of the y (latitude) dimension.
+    x_size: int
+        Total size of the x (longitude) dimension.
+    size: float, default: 126
+        Desired chunk size in MiB.
+    dtype: np.dtype
+        Encoding dtype.
+    only_horizontal: bool
+        Whether to force chunking only in the horizontal dimensions.
+
+    Returns
+    -------
+    Tuple[int, int, int]
+        Optimal chunk sizes for time, y, and x.
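+
+    Examples
+    --------
+    Illustrative only, using the defaults (126 MiB target, ``float32`` data,
+    horizontal-only chunking):
+
+    >>> _optimal_chunks(100, 1100, 1200)
+    (100, 550, 600)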
+    """
+    dtype_size = dtype.itemsize
+    target_elements: float = (size * (1024**2)) / dtype_size
+
+    if y_size == 1 and x_size == 1:
+        only_horizontal = False
+    if only_horizontal:
+        factor = np.sqrt(target_elements / (time_size * y_size * x_size))
+        time_chunk = time_size
+    else:
+        factor = np.cbrt(target_elements / (time_size * y_size * x_size))
+        time_chunk = max(1, int(time_size * factor))
+    y_chunk = max(1, int(y_size * factor))
+    x_chunk = max(1, int(x_size * factor))
+
+    y_chunk = min(y_chunk, y_size)
+    x_chunk = min(x_chunk, x_size)
+    time_chunk = min(time_chunk, time_size)
+    return (time_chunk, y_chunk, x_chunk)
+
+
+def _map_chunks(
+    chunks: dict, dim_mapping: dict, chunk_tuple: Tuple[int, int, int]
+) -> dict:
+    """
+    Update chunk sizes in `chunks` using the dimension mapping.
+
+    Parameters
+    ----------
+    chunks: dict
+        Dictionary with initial chunking setup (e.g., {0: 'auto', 1: None, 2: None}).
+    dim_mapping: dict
+        Mapping of dimension names to chunk indices (e.g., {'time': 0, 'y': 1, 'x': 2}).
+    chunk_tuple: Tuple[int, int, int]
+        Desired chunk size for (time, y, x) dimensions.
+
+    Returns
+    -------
+    dict
+        Updated chunk dictionary with appropriate chunk sizes.
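+
+    Examples
+    --------
+    Illustrative only; the values follow the parameter descriptions above:
+
+    >>> _map_chunks(
+    ...     {0: "auto", 1: None, 2: None},
+    ...     {"time": 0, "lat": 1, "lon": 2},
+    ...     (10, 550, 600),
+    ... )
+    {0: 10, 1: 550, 2: 600}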
+    """
+    updated_chunks = chunks.copy()
+
+    for dim, index in dim_mapping.items():
+        if "time" in dim.lower():
+            updated_chunks[index] = chunk_tuple[0]
+        if "y" in dim.lower() or "lat" in dim.lower():
+            updated_chunks[index] = chunk_tuple[1]
+        if "x" in dim.lower() or "lon" in dim.lower():
+            updated_chunks[index] = chunk_tuple[2]
+
+    return updated_chunks
+
+
+def _horizontal_chunks(
+    da: xr.DataArray,
+    chunks: dict,
+    size: Optional[float] = None,
+    only_horizontal: bool = True,
+) -> dict:
+    """
+    Update chunk sizes, forcing horizontal chunking whenever possible.
+
+    .. note::
+
+        Level dimensions such as ``lev`` or ``depth`` are not taken into
+        account, which will likely result in a chunk size of 1 for those
+        dimensions.
+
+    Parameters
+    ----------
+    da: xr.DataArray
+        Data variable from the dataset.
+    chunks: dict
+        Original chunking dictionary.
+    size: float, default: None
+        Desired chunk size in MiB.
+    only_horizontal: bool, default: True
+        Whether to force chunking only in the horizontal dimensions,
+        i.e. disregarding time.
+
+    Returns
+    -------
+    dict
+        Updated chunk dictionary with appropriate chunk sizes.
+    """
+    try:
+        orig_size = dict(da.sizes)
+        chunk_order = dict(zip(da.dims, chunks))
+        time_size = next(
+            (
+                value
+                for key, value in orig_size.items()
+                if isinstance(key, str) and "time" in key.lower()
+            ),
+            1,
+        )
+
+        y_size = next(
+            (
+                value
+                for key, value in orig_size.items()
+                if isinstance(key, str)
+                and any(k in key.lower() for k in ["lat", "y"])
+            ),
+            1,
+        )
+
+        x_size = next(
+            (
+                value
+                for key, value in orig_size.items()
+                if isinstance(key, str)
+                and any(k in key.lower() for k in ["lon", "x"])
+            ),
+            1,
+        )
+        dtype = da.encoding.get("dtype", np.dtype("float32"))
+        chunk_tuple = _optimal_chunks(
+            time_size,
+            y_size,
+            x_size,
+            size=size if size is not None else default_chunk_size,
+            dtype=dtype,
+            only_horizontal=only_horizontal,
+        )
+        return _map_chunks(chunks, chunk_order, chunk_tuple)
+    except Exception as error:
+        logger.error("Error in _horizontal_chunks: %s", error, exc_info=True)
+        return chunks
+
+
 def _rechunk_dataset(
     dset: xr.Dataset,
     engine: Literal["h5netcdf", "netcdf4"],
+    size: Optional[int] = None,
+    auto_chunks: bool = False,
+    force_horizontal: bool = False,
 ) -> Tuple[xr.Dataset, Dict[Hashable, Dict[str, Any]]]:
+    """
+    Rechunk an xarray dataset.
+
+    Parameters
+    ----------
+    dset: xarray.Dataset
+        Input dataset that is going to be rechunked
+    engine: str, default: netcdf4
+        The netcdf engine used to create the new netcdf file.
+    size: int, default: None
+        Desired chunk size in MiB. If None, Dask's default is used (or
+        ``default_chunk_size`` when forcing horizontal chunking).
+    auto_chunks: bool, default: False
+        If True, Dask automatically determines the optimal chunk size.
+    force_horizontal: bool, default: False
+        If True, forces horizontal chunking whenever possible.
+
+    Returns
+    -------
+    Tuple[xr.Dataset, Dict[Hashable, Dict[str, Any]]]
+        A tuple containing the rechunked dataset and the updated encoding dictionary.
+    """
     encoding: Dict[Hashable, Dict[str, Any]] = {}
     try:
         _keywords = ENCODINGS[engine]
@@ -100,14 +330,39 @@ def _rechunk_dataset(
             logger.debug("Skipping rechunking variable %s", var)
             continue
         logger.debug("Rechunking variable %s", var)
-        chunks: Dict[int, Optional[str]] = {}
-        for i, dim in enumerate(map(str, dset[var].dims)):
-            if "lon" in dim.lower() or "lat" in dim.lower() or "bnds" in dim.lower():
-                chunks[i] = None
-            else:
-                chunks[i] = "auto"
         old_chunks = dset[var].encoding.get("chunksizes")
-        new_chunks = dset[var].data.rechunk(chunks).chunksize
+        chunks: Dict[int, Optional[str]] = {
+            i: "auto" for i, _ in enumerate(dset[var].dims)
+        }
+        if force_horizontal:
+            only_horizontal = not auto_chunks
+            chunks = _horizontal_chunks(
+                dset[var],
+                chunks,
+                size=size,
+                only_horizontal=only_horizontal,
+            )
+            new_chunks = dset[var].data.rechunk(chunks).chunksize
+        else:
+            if not auto_chunks:
+                for i, dim in enumerate(map(str, dset[var].dims)):
+                    if any(
+                        keyword in dim.lower()
+                        for keyword in ["lon", "lat", "bnds", "x", "y"]
+                    ):
+                        chunks[i] = None
+            if size:
+                with dask.config.set(array__chunk_size=f"{size}MiB"):
+                    new_chunks = (
+                        dset[var]
+                        .data.rechunk(chunks, balance=True, method="tasks")
+                        .chunksize
+                    )
+            else:
+                new_chunks = dset[var].data.rechunk(chunks).chunksize
         if new_chunks == old_chunks:
             logger.debug("%s: chunk sizes already optimized, skipping", var)
             continue
@@ -120,55 +375,91 @@ def _rechunk_dataset(
         )
         logger.debug("Settings encoding of variable %s", var)
         encoding[data_var] = {
-            str(k): v for k, v in dset[var].encoding.items() if str(k) in _keywords
+            str(k): v
+            for k, v in dset[var].encoding.items()
+            if str(k) in _keywords
         }
-        if engine != "netcdf4" or encoding[data_var].get("contiguous", False) is False:
+        if (
+            engine != "netcdf4"
+            or encoding[data_var].get("contiguous", False) is False
+        ):
             encoding[data_var]["chunksizes"] = new_chunks
     return dset, encoding
 
 
 def rechunk_dataset(
-    dset: xr.Dataset, engine: Literal["h5netcdf", "netcdf4"] = "netcdf4"
+    dset: xr.Dataset,
+    engine: Literal["h5netcdf", "netcdf4"] = "netcdf4",
+    size: Optional[int] = None,
+    auto_chunks: bool = False,
+    force_horizontal: bool = False,
 ) -> xr.Dataset:
-    """Rechunk a xarray dataset.
+    """
+    Rechunk an xarray dataset.
 
     Parameters
     ----------
-    dset: xarray.Dataset
-        Input dataset that is going to be rechunked
-    engine: str, default: netcdf4
-        The netcdf engine used to create the new netcdf file.
+    dset : xr.Dataset
+        The dataset to be rechunked.
+    engine : Literal["h5netcdf", "netcdf4"], optional
+        The engine to use for writing the netCDF file. Defaults to "netcdf4".
+    size : Optional[int], optional
+        The desired chunk size in MiB. If None, the default chunk size is used.
+        Defaults to None.
+    auto_chunks : bool, optional
+        If True, determine the chunk size automatically using Dask. Defaults to False.
+    force_horizontal : bool, optional
+        If True, force the chunk size to be in the horizontal dimensions
+        (y and x). Defaults to False.
 
     Returns
     -------
-    xarray.Dataset: rechunked dataset
+    xr.Dataset
+        The rechunked dataset.
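+
+    Examples
+    --------
+    A minimal usage sketch (the path is illustrative):
+
+    >>> import xarray as xr
+    >>> dset = xr.open_mfdataset("/data/*.nc")  # doctest: +SKIP
+    >>> rechunked = rechunk_dataset(dset, size=64)  # doctest: +SKIP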
     """
-    data, _ = _rechunk_dataset(dset.chunk(), engine)
+    data, _ = _rechunk_dataset(
+        dset.chunk(), engine, size, auto_chunks, force_horizontal
+    )
     return data
 
 
 def rechunk_netcdf_file(
     input_path: os.PathLike,
     output_path: Optional[os.PathLike] = None,
-    decode_cf: bool = True,
     engine: Literal["h5netcdf", "netcdf4"] = "netcdf4",
+    size: Optional[int] = None,
+    auto_chunks: bool = False,
+    decode_cf: bool = True,
+    force_horizontal: bool = False,
 ) -> None:
-    """Rechunk netcdf files.
+    """
+    Rechunk a netCDF file.
 
     Parameters
     ----------
-    input_path: os.PathLike
-        Input file/directory. If a directory is given all ``.nc`` in all sub
-        directories will be processed
-    output_path: os.PathLike
-        Output file/directory of the chunked netcdf file(s). Note: If ``input``
-        is a directory output should be a directory. If None given (default)
-        the ``input`` is overidden.
-    decode_cf: bool, default: True
-        Whether to decode these variables, assuming they were saved according
-        to CF conventions.
-    engine: str, default: netcdf4
-        The netcdf engine used to create the new netcdf file.
+    input_path : os.PathLike
+        The path to the netCDF file or directory to be rechunked.
+    output_path : Optional[os.PathLike], optional
+        The path to the directory or file where the rechunked data will be saved.
+        If None, the file is overwritten. Defaults to None.
+    engine : Literal["h5netcdf", "netcdf4"], optional
+        The engine to use for writing the netCDF file. Defaults to "netcdf4".
+    size : Optional[int], optional
+        The desired chunk size in MiB. If None, the default chunk size is used.
+        Defaults to None.
+    auto_chunks : bool, optional
+        If True, determine the chunk size automatically using Dask. Defaults to False.
+    decode_cf : bool, optional
+        Whether to decode CF conventions. Defaults to True.
+    force_horizontal : bool, optional
+        If True, force the chunk size to be in the horizontal dimensions
+        (y and x). Defaults to False.
+
+    Returns
+    -------
+    None
+        The function processes and saves the rechunked dataset(s) to the specified
+        ``output_path``.
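+
+    Examples
+    --------
+    A minimal usage sketch (the paths are illustrative):
+
+    >>> rechunk_netcdf_file(  # doctest: +SKIP
+    ...     "/data/exp1", "/scratch/rechunked", size=64
+    ... )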
     """
     input_path = Path(input_path).expanduser().absolute()
     for input_file in _search_for_nc_files(input_path):
@@ -187,7 +478,9 @@ def rechunk_netcdf_file(
                 parallel=True,
                 decode_cf=decode_cf,
             ) as nc_data:
-                new_data, encoding = _rechunk_dataset(nc_data, engine)
+                new_data, encoding = _rechunk_dataset(
+                    nc_data, engine, size, auto_chunks, force_horizontal
+                )
                 if encoding:
                     logger.debug(
                         "Loading data into memory (%s).",
@@ -208,3 +501,43 @@ def rechunk_netcdf_file(
             engine,
             override=output_file != input_file,
         )
+
+
+def check_chunk_size(dset: xr.Dataset) -> None:
+    """
+    Estimate and print the chunk size of each variable of a dataset in MB.
+
+    Parameters
+    ----------
+    dset: xr.Dataset
+        The dataset for which to estimate the chunk size.
+
+    Returns
+    -------
+    None
+        This function prints out the estimated chunk size for each variable in a
+        dataset. The chunk size is estimated by multiplying the size of a single
+        element of the data type with the product of all chunk sizes.
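+
+    Examples
+    --------
+    A minimal usage sketch (the file name is illustrative):
+
+    >>> import xarray as xr
+    >>> dset = xr.open_dataset("data.nc", chunks={})  # doctest: +SKIP
+    >>> check_chunk_size(dset)  # doctest: +SKIP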
+    """
+    for var in dset.data_vars:
+        print(f"\n{var}:\t{dict(dset[var].sizes)}")
+        dtype_size = np.dtype(dset[var].dtype).itemsize
+        chunksizes = dset[var].encoding.get("chunksizes")
+
+        if chunksizes is None:
+            print(
+                f"  ⚠️  Warning: No chunk sizes found for {var}, skipping...\n"
+            )
+            continue
+
+        chunksizes = tuple(
+            filter(lambda x: isinstance(x, (int, np.integer)), chunksizes)
+        )
+        chunk_size_bytes = np.prod(chunksizes) * dtype_size
+        chunk_size_mb = chunk_size_bytes / (1024**2)  # Convert to MB
+
+        chunks = dset[var].chunks or tuple(() for _ in dset[var].dims)
+
+        print(f"  * Chunk Sizes: {chunksizes}")
+        print(f"  * Estimated Chunk Size: {chunk_size_mb:.2f} MB")
+        print(f"  * Chunks: {dict(zip(dset[var].dims, map(tuple, chunks)))}\n")
diff --git a/src/rechunk_data/tests/conftest.py b/src/rechunk_data/tests/conftest.py
index 02d03b10aca9659dbdd861cd4953fa2eac8f1325..0a6612bd8acc8e638e64d284f12c3e80ea9f590d 100644
--- a/src/rechunk_data/tests/conftest.py
+++ b/src/rechunk_data/tests/conftest.py
@@ -1,12 +1,12 @@
 """pytest definitions to run the unittests."""
 
 from pathlib import Path
-from tempfile import TemporaryDirectory, NamedTemporaryFile
+from tempfile import NamedTemporaryFile, TemporaryDirectory
 from typing import Generator, Tuple
 
 import dask
-import pytest
 import numpy as np
+import pytest
 import xarray as xr
 
 
diff --git a/src/rechunk_data/tests/test_check_chunksize.py b/src/rechunk_data/tests/test_check_chunksize.py
new file mode 100644
index 0000000000000000000000000000000000000000..22d6edb206b6052c5e4ba680cc5ad9bf6cd51e2c
--- /dev/null
+++ b/src/rechunk_data/tests/test_check_chunksize.py
@@ -0,0 +1,37 @@
+"""Unit tests for checking the chunksize of the data."""
+
+import numpy as np
+import xarray as xr
+
+from rechunk_data import check_chunk_size
+
+
+def test_check_chunk_size(capsys) -> None:
+    """Test the check_chunk_size function with valid and invalid chunksizes."""
+
+    data1 = np.random.rand(100, 1100, 1200)
+    da1 = xr.DataArray(data1, dims=("time", "lat", "lon"), name="valid_var")
+    dset1 = xr.Dataset({"valid_var": da1})
+    dset1 = dset1.chunk({"time": 10, "lat": 550, "lon": 600})
+    dset1["valid_var"].encoding = {
+        "chunksizes": (10, 550, 600),
+        "dtype": "float32",
+        "zlib": True,
+    }
+
+    data2 = np.random.rand(100, 1100, 1200)
+    da2 = xr.DataArray(data2, dims=("time", "lat", "lon"), name="invalid_var")
+    dset2 = xr.Dataset({"invalid_var": da2})
+
+    combined_dset = xr.merge([dset1, dset2])
+
+    check_chunk_size(combined_dset)
+    captured = capsys.readouterr()
+
+    assert "valid_var" in captured.out
+    assert "invalid_var" in captured.out
+    assert (
+        "⚠️  Warning: No chunk sizes found for invalid_var, skipping..."
+        in captured.out
+    )
+    assert "Estimated Chunk Size: 25.18 MB" in captured.out
diff --git a/src/rechunk_data/tests/test_cli.py b/src/rechunk_data/tests/test_cli.py
index cdf7c51f8f101470adaa4885b2074c244b23a092..6ed1aaf8edaf1696623445bcc16f5722232fe81e 100644
--- a/src/rechunk_data/tests/test_cli.py
+++ b/src/rechunk_data/tests/test_cli.py
@@ -1,9 +1,12 @@
 """Unit tests for the cli."""
 
+import sys
+from io import StringIO
 from pathlib import Path
 from tempfile import TemporaryDirectory
 
 import pytest
+
 from rechunk_data import cli
 
 
@@ -20,3 +23,17 @@ def test_command_line_interface(data_dir: Path) -> None:
         cli([str(data_dir), "--output", temp_dir])
         new_files = sorted(Path(temp_dir).rglob("*.nc"))
     assert len(data_files) == len(new_files)
+
+
+def test_check_chunk_size(data_dir: Path) -> None:
+    """Test the --check-chunk-size argument."""
+
+    data_files = sorted(data_dir.rglob("*.nc"))[0]
+    captured_output = StringIO()
+    sys.stdout = captured_output
+    cli([str(data_files), "--check-chunk-size"])
+    sys.stdout = sys.__stdout__
+    output = captured_output.getvalue()
+    assert "tas:" in output
+    assert "Chunk Sizes: (1, 1, 24, 24)" in output
+    assert "Estimated Chunk Size: 0.00 MB" in output
diff --git a/src/rechunk_data/tests/test_rechunk_netcdf.py b/src/rechunk_data/tests/test_rechunk_netcdf.py
index 3544d3d6138bbc1c9b4204c9b206ac88b41708b2..3bc542c1f133b75935e719f80ee22cc54d3b7044 100644
--- a/src/rechunk_data/tests/test_rechunk_netcdf.py
+++ b/src/rechunk_data/tests/test_rechunk_netcdf.py
@@ -1,12 +1,14 @@
 """Test the actual rechunk method."""
+
 import logging
-from pathlib import Path
 import time
+from pathlib import Path
 from tempfile import NamedTemporaryFile, TemporaryDirectory
 
 import dask
 import pytest
-from rechunk_data import rechunk_netcdf_file, rechunk_dataset
+
+from rechunk_data import rechunk_dataset, rechunk_netcdf_file
 from rechunk_data._rechunk import _save_dataset
 
 
@@ -24,7 +26,9 @@ def test_rechunk_data_dir_without_overwrite(data_dir: Path) -> None:
     """Testing the creation of new datafiles from a folder."""
     with TemporaryDirectory() as temp_dir:
         rechunk_netcdf_file(data_dir, Path(temp_dir))
-        new_files = sorted(f.relative_to(temp_dir) for f in Path(temp_dir).rglob(".nc"))
+        new_files = sorted(
+            f.relative_to(temp_dir) for f in Path(temp_dir).rglob(".nc")
+        )
     old_files = sorted(f.relative_to(data_dir) for f in data_dir.rglob(".nc"))
     assert new_files == old_files
 
@@ -59,13 +63,15 @@ def test_wrong_or_format(small_chunk_data, caplog) -> None:
         _, loglevel, message = caplog.record_tuples[-1]
         assert loglevel == logging.ERROR
         assert "Error while" in message
-        _save_dataset(small_chunk_data, temp_file, {}, "foo")
+        _save_dataset(small_chunk_data, temp_file, {}, "foo")  # type: ignore[arg-type]
         _, loglevel, message = caplog.record_tuples[-1]
-        _save_dataset(small_chunk_data, temp_file, {"foo": "bar"}, "foo")
+        _save_dataset(
+            small_chunk_data, temp_file, {"foo": "bar"}, "foo"
+        )  # type: ignore[arg-type]
         _, loglevel, message = caplog.record_tuples[-1]
         assert loglevel == logging.ERROR
 
 
 def test_wrong_engine(small_chunk_data) -> None:
     with pytest.raises(ValueError):
-        rechunk_dataset(small_chunk_data, engine="foo")
+        rechunk_dataset(small_chunk_data, engine="foo")  # type: ignore[arg-type]
diff --git a/src/rechunk_data/tests/test_rechunking.py b/src/rechunk_data/tests/test_rechunking.py
index 945bb0d2e040b5e3f149304f2aee9374423c93e8..4f0e9e186b64b23b529a93f7360a2466d047b816 100644
--- a/src/rechunk_data/tests/test_rechunking.py
+++ b/src/rechunk_data/tests/test_rechunking.py
@@ -1,8 +1,14 @@
 """Unit tests for rechunking the data."""
+
 import dask
+import numpy as np
 import xarray as xr
 
-from rechunk_data._rechunk import _rechunk_dataset
+from rechunk_data._rechunk import (
+    _horizontal_chunks,
+    _optimal_chunks,
+    _rechunk_dataset,
+)
 
 
 def test_rechunking_small_data(
@@ -35,3 +41,121 @@ def test_rechunking_large_data(
     dset, encoding = _rechunk_dataset(large_chunk_data, "h5netcdf")
     assert encoding[variable_name]["chunksizes"] == chunks
     assert dset[variable_name].data.chunksize == chunks
+
+
+def test_optimal_chunks():
+    """Test the optimal chunk calculation."""
+
+    time_size = 100
+    y_size = 192
+    x_size = 384
+    dtype = np.dtype("float32")
+
+    time_chunk, y_chunk, x_chunk = _optimal_chunks(
+        time_size, y_size, x_size, dtype=dtype, only_horizontal=False
+    )
+    assert time_chunk <= time_size
+    assert y_chunk <= y_size
+    assert x_chunk <= x_size
+
+    time_chunk, y_chunk, x_chunk = _optimal_chunks(
+        time_size, y_size, x_size, dtype=dtype, only_horizontal=True
+    )
+    assert time_chunk == time_size
+    assert y_chunk <= y_size
+    assert x_chunk <= x_size
+
+    time_size = 100_000_000
+    y_size = 1
+    x_size = 1
+    time_chunk, y_chunk, x_chunk = _optimal_chunks(
+        time_size, y_size, x_size, dtype=dtype, only_horizontal=True
+    )
+    assert time_chunk < time_size
+
+
+def test_horizontal_chunks():
+    """Test horizontal chunking function and _rechunk_dataset() with
+    force_horizontal."""
+    data = np.random.rand(100, 1100, 1200)
+    da = xr.DataArray(data, dims=("time", "lat", "lon"), name="test_var")
+
+    chunks = {0: "auto", 1: None, 2: None}
+    updated_chunks = _horizontal_chunks(da, chunks)
+    assert updated_chunks[0] == 100
+    assert updated_chunks[1] < 1100
+    assert updated_chunks[2] < 1200
+
+    updated_chunks = _horizontal_chunks("invalid_data", chunks)
+    assert updated_chunks == chunks
+
+    dset = xr.Dataset({"test_var": da})
+    chunksizes = (10, 550, 1200)
+    dset = dset.chunk(
+        {"time": chunksizes[0], "lat": chunksizes[1], "lon": chunksizes[2]}
+    )
+    encoding = {
+        "test_var": {
+            "zlib": True,
+            "complevel": 4,
+            "shuffle": True,
+            "dtype": "float32",
+            "chunksizes": chunksizes,
+            "_FillValue": np.nan,
+        }
+    }
+    dset["test_var"].encoding = encoding
+    _, encoding = _rechunk_dataset(
+        dset, engine="netcdf4", force_horizontal=True
+    )
+    assert "test_var" in encoding
+
+    chunksizes_applied = encoding["test_var"].get("chunksizes", None)
+    assert chunksizes_applied is not None
+    assert chunksizes_applied[0] == 100
+    assert chunksizes_applied[1] < 1100
+    assert chunksizes_applied[2] < 1200
+
+    dset["test_var"].encoding = encoding
+    _, encoding = _rechunk_dataset(
+        dset, engine="netcdf4", force_horizontal=True, auto_chunks=True
+    )
+    chunksizes_applied = encoding["test_var"].get("chunksizes", None)
+    assert chunksizes_applied[0] < 100
+    assert chunksizes_applied[1] < 1100
+    assert chunksizes_applied[2] < 1200
+
+
+def test_auto_size_chunks():
+    """
+    Test the automatic chunking and size adjustment functionality.
+    """
+    data = np.random.rand(100, 1100, 1200)
+    da = xr.DataArray(data, dims=("time", "lat", "lon"), name="test_var")
+    dset = xr.Dataset({"test_var": da})
+    chunksizes = (1, 1, 1200)
+    dset = dset.chunk(
+        {"time": chunksizes[0], "lat": chunksizes[1], "lon": chunksizes[2]}
+    )
+    encoding = {
+        "test_var": {
+            "zlib": True,
+            "complevel": 4,
+            "shuffle": True,
+            "dtype": "float32",
+            "chunksizes": chunksizes,
+            "_FillValue": np.nan,
+        }
+    }
+    dset["test_var"].encoding = encoding
+    _, encoding = _rechunk_dataset(dset, engine="netcdf4", auto_chunks=True)
+    chunksizes_applied = encoding["test_var"].get("chunksizes", None)
+    assert chunksizes_applied[0] == 100
+    assert chunksizes_applied[2] == 1200
+
+    dset["test_var"].encoding = encoding
+    _, encoding = _rechunk_dataset(dset, engine="netcdf4", size=20)
+    chunksizes_applied = encoding["test_var"].get("chunksizes", None)
+    assert 1 < chunksizes_applied[0] < 100
+    assert 1 < chunksizes_applied[1] < 1100
+    assert chunksizes_applied[2] == 1200
diff --git a/src/rechunk_data/tests/test_search_files.py b/src/rechunk_data/tests/test_search_files.py
index cff79c039c46fbbc7c61bbd8b470b37320d27965..297ee6468ff6037bec8109593226b6b5126bed1f 100644
--- a/src/rechunk_data/tests/test_search_files.py
+++ b/src/rechunk_data/tests/test_search_files.py
@@ -1,4 +1,5 @@
 """Unit tests for searching for files."""
+
 from pathlib import Path
 
 from rechunk_data._rechunk import _search_for_nc_files