From a50e3167be2b05a3c3bee4e18441504998c907ef Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 6 Apr 2026 14:00:07 +0200 Subject: [PATCH 1/8] feat: add cast_value and scale_offset codecs Defines two new codecs that together provide a v3-native replacement for the existing `numcodecs.fixedscaleoffset` codec. The `cast_value` codec requires an optional dependency on the `cast-value-rs` package. --- pyproject.toml | 11 + src/zarr/codecs/__init__.py | 6 + src/zarr/codecs/cast_value.py | 325 +++++++++++++++++++++++++ src/zarr/codecs/scale_offset.py | 136 +++++++++++ tests/test_codecs/test_cast_value.py | 281 +++++++++++++++++++++ tests/test_codecs/test_scale_offset.py | 222 +++++++++++++++++ 6 files changed, 981 insertions(+) create mode 100644 src/zarr/codecs/cast_value.py create mode 100644 src/zarr/codecs/scale_offset.py create mode 100644 tests/test_codecs/test_cast_value.py create mode 100644 tests/test_codecs/test_scale_offset.py diff --git a/pyproject.toml b/pyproject.toml index 96932a9611..40480b3864 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ remote = [ gpu = [ "cupy-cuda12x", ] +cast-value-rs = ["cast-value-rs"] cli = ["typer"] optional = ["universal-pathlib"] @@ -190,6 +191,16 @@ run-benchmark = "pytest --benchmark-enable tests/benchmarks" serve-coverage-html = "python -m http.server -d htmlcov 8000" list-env = "pip list" +[tool.hatch.envs.cast-value] +template = "test" +features = ["cast-value-rs"] + +[[tool.hatch.envs.cast-value.matrix]] +python = ["3.12"] + +[tool.hatch.envs.cast-value.scripts] +run = "pytest tests/test_codecs/test_cast_value.py {args:}" + [tool.hatch.envs.gputest] template = "test" extra-dependencies = [ diff --git a/src/zarr/codecs/__init__.py b/src/zarr/codecs/__init__.py index 4c621290e7..756bd97ed2 100644 --- a/src/zarr/codecs/__init__.py +++ b/src/zarr/codecs/__init__.py @@ -2,6 +2,7 @@ from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle from zarr.codecs.bytes import BytesCodec, 
Endian +from zarr.codecs.cast_value import CastValue from zarr.codecs.crc32c_ import Crc32cCodec from zarr.codecs.gzip import GzipCodec from zarr.codecs.numcodecs import ( @@ -27,6 +28,7 @@ Zlib, Zstd, ) +from zarr.codecs.scale_offset import ScaleOffset from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation from zarr.codecs.transpose import TransposeCodec from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec @@ -38,9 +40,11 @@ "BloscCodec", "BloscShuffle", "BytesCodec", + "CastValue", "Crc32cCodec", "Endian", "GzipCodec", + "ScaleOffset", "ShardingCodec", "ShardingCodecIndexLocation", "TransposeCodec", @@ -50,12 +54,14 @@ ] register_codec("blosc", BloscCodec) +register_codec("cast_value", CastValue) register_codec("bytes", BytesCodec) # compatibility with earlier versions of ZEP1 register_codec("endian", BytesCodec) register_codec("crc32c", Crc32cCodec) register_codec("gzip", GzipCodec) +register_codec("scale_offset", ScaleOffset) register_codec("sharding_indexed", ShardingCodec) register_codec("zstd", ZstdCodec) register_codec("vlen-utf8", VLenUTF8Codec) diff --git a/src/zarr/codecs/cast_value.py b/src/zarr/codecs/cast_value.py new file mode 100644 index 0000000000..4b3a554566 --- /dev/null +++ b/src/zarr/codecs/cast_value.py @@ -0,0 +1,325 @@ +"""Cast-value array-to-array codec. + +Value-converts array elements to a new data type during encoding, +and back to the original data type during decoding, with configurable +rounding, out-of-range handling, and explicit scalar mappings. + +Requires the optional ``cast-value-rs`` package for the actual casting +logic. Install it with: ``pip install cast-value-rs``. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING, Literal, cast + +import numpy as np + +from zarr.abc.codec import ArrayArrayCodec +from zarr.core.common import JSON, parse_named_configuration +from zarr.core.dtype import get_data_type_from_json + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + from typing import Any, NotRequired, Self, TypedDict + + from zarr.core.array_spec import ArraySpec + from zarr.core.buffer import NDBuffer + from zarr.core.chunk_grids import ChunkGrid + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType + + class ScalarMapJSON(TypedDict): + encode: NotRequired[list[tuple[object, object]]] + decode: NotRequired[list[tuple[object, object]]] + + # Pre-parsed scalar map entry: (source_scalar, target_scalar) + ScalarMapEntry = tuple[np.integer[Any] | np.floating[Any], np.integer[Any] | np.floating[Any]] + +RoundingMode = Literal[ + "nearest-even", + "towards-zero", + "towards-positive", + "towards-negative", + "nearest-away", +] + +OutOfRangeMode = Literal["clamp", "wrap"] + + +# --------------------------------------------------------------------------- +# Scalar-map parsing helpers +# --------------------------------------------------------------------------- + + +def _extract_raw_map(data: ScalarMapJSON | None, direction: str) -> dict[str, str] | None: + """Extract raw string mapping from scalar_map JSON for 'encode' or 'decode'.""" + if data is None: + return None + raw: dict[str, str] = {} + pairs: list[tuple[object, object]] = data.get(direction, []) # type: ignore[assignment] + for src, tgt in pairs: + raw[str(src)] = str(tgt) + return raw or None + + +def _parse_map_entries( + mapping: Mapping[str, str], + src_dtype: ZDType[TBaseDType, TBaseScalar], + tgt_dtype: ZDType[TBaseDType, TBaseScalar], +) -> tuple[ScalarMapEntry, ...]: + """Pre-parse a scalar map dict into a tuple of (src, tgt) pairs. 
+ + Each entry's source value is deserialized using ``src_dtype`` and its target + value using ``tgt_dtype``, preserving full precision for both data types. + """ + entries: list[ScalarMapEntry] = [ + ( + src_dtype.from_json_scalar(src_str, zarr_format=3), # type: ignore[misc] + tgt_dtype.from_json_scalar(tgt_str, zarr_format=3), + ) + for src_str, tgt_str in mapping.items() + ] + return tuple(entries) + + +# --------------------------------------------------------------------------- +# Backend: cast-value-rs (optional) +# --------------------------------------------------------------------------- + +try: + from cast_value_rs import cast_array as _rs_cast_array + + _HAS_RUST_BACKEND = True +except ModuleNotFoundError: + _HAS_RUST_BACKEND = False + + +def _dtype_to_str(dtype: np.dtype) -> str: # type: ignore[type-arg] + return dtype.name + + +def _convert_scalar_map( + entries: Iterable[ScalarMapEntry] | None, +) -> list[tuple[int | float, int | float]] | None: + if entries is None: + return None + result: list[tuple[int | float, int | float]] = [] + for src, tgt in entries: + src_py: int | float = int(src) if isinstance(src, np.integer) else float(src) + tgt_py: int | float = int(tgt) if isinstance(tgt, np.integer) else float(tgt) + result.append((src_py, tgt_py)) + return result + + +def _cast_array_rs( + arr: np.ndarray, # type: ignore[type-arg] + *, + target_dtype: np.dtype, # type: ignore[type-arg] + rounding: RoundingMode, + out_of_range: OutOfRangeMode | None, + scalar_map_entries: Iterable[ScalarMapEntry] | None, +) -> np.ndarray: # type: ignore[type-arg] + return _rs_cast_array( # type: ignore[no-any-return] + arr=arr, + target_dtype=_dtype_to_str(target_dtype), + rounding_mode=rounding, + out_of_range_mode=out_of_range, + scalar_map_entries=_convert_scalar_map(scalar_map_entries), + ) + + +# --------------------------------------------------------------------------- +# Codec +# --------------------------------------------------------------------------- + + 
+@dataclass(frozen=True) +class CastValue(ArrayArrayCodec): + """Cast-value array-to-array codec. + + Value-converts array elements to a new data type during encoding, + and back to the original data type during decoding. + + Requires the ``cast-value-rs`` package for the actual casting logic. + + Parameters + ---------- + data_type : str + Target zarr v3 data type name (e.g. "uint8", "float32"). + rounding : RoundingMode + How to round when exact representation is impossible. Default is "nearest-even". + out_of_range : OutOfRangeMode or None + What to do when a value is outside the target's range. + None means error. "clamp" clips to range. "wrap" uses modular arithmetic + (only valid for integer types). + scalar_map : dict or None + Explicit mapping from input scalars to output scalars. + + References + ---------- + + - The `cast_value` codec spec: https://github.com/zarr-developers/zarr-extensions/tree/main/codecs/cast_value + """ + + is_fixed_size = True + + dtype: ZDType[TBaseDType, TBaseScalar] + rounding: RoundingMode + out_of_range: OutOfRangeMode | None + scalar_map: ScalarMapJSON | None + + def __init__( + self, + *, + data_type: str | ZDType[TBaseDType, TBaseScalar], + rounding: RoundingMode = "nearest-even", + out_of_range: OutOfRangeMode | None = None, + scalar_map: ScalarMapJSON | None = None, + ) -> None: + if isinstance(data_type, str): + zdtype = get_data_type_from_json(data_type, zarr_format=3) + else: + zdtype = data_type + object.__setattr__(self, "dtype", zdtype) + object.__setattr__(self, "rounding", rounding) + object.__setattr__(self, "out_of_range", out_of_range) + object.__setattr__(self, "scalar_map", scalar_map) + + @classmethod + def from_dict(cls, data: dict[str, JSON]) -> Self: + _, configuration_parsed = parse_named_configuration( + data, "cast_value", require_configuration=True + ) + return cls(**configuration_parsed) # type: ignore[arg-type] + + def to_dict(self) -> dict[str, JSON]: + config: dict[str, JSON] = {"data_type": 
cast("JSON", self.dtype.to_json(zarr_format=3))} + if self.rounding != "nearest-even": + config["rounding"] = self.rounding + if self.out_of_range is not None: + config["out_of_range"] = self.out_of_range + if self.scalar_map is not None: + config["scalar_map"] = cast("JSON", self.scalar_map) + return {"name": "cast_value", "configuration": config} + + def validate( + self, + *, + shape: tuple[int, ...], + dtype: ZDType[TBaseDType, TBaseScalar], + chunk_grid: ChunkGrid, + ) -> None: + source_native = dtype.to_native_dtype() + target_native = self.dtype.to_native_dtype() + for label, dt in [("source", source_native), ("target", target_native)]: + if not np.issubdtype(dt, np.integer) and not np.issubdtype(dt, np.floating): + raise ValueError( + f"The cast_value codec only supports integer and floating-point data types. " + f"Got {label} dtype {dt}." + ) + if self.out_of_range == "wrap" and not np.issubdtype(target_native, np.integer): + raise ValueError("out_of_range='wrap' is only valid for integer target types.") + + def _do_cast( + self, + arr: np.ndarray, # type: ignore[type-arg] + *, + target_dtype: np.dtype, # type: ignore[type-arg] + scalar_map_entries: Iterable[ScalarMapEntry] | None, + ) -> np.ndarray: # type: ignore[type-arg] + if not _HAS_RUST_BACKEND: + raise ImportError( + "The cast_value codec requires the 'cast-value-rs' package. " + "Install it with: pip install cast-value-rs" + ) + return _cast_array_rs( + arr, + target_dtype=target_dtype, + rounding=self.rounding, + out_of_range=self.out_of_range, + scalar_map_entries=scalar_map_entries, + ) + + def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: + """ + Update the fill value of the output spec by applying casting procedure. 
+ """ + target_zdtype = self.dtype + target_native = target_zdtype.to_native_dtype() + source_native = chunk_spec.dtype.to_native_dtype() + + fill = chunk_spec.fill_value + fill_arr = np.array([fill], dtype=source_native) + + encode_raw = _extract_raw_map(self.scalar_map, "encode") + encode_entries = ( + _parse_map_entries(encode_raw, chunk_spec.dtype, self.dtype) if encode_raw else None + ) + + new_fill_arr = self._do_cast( + fill_arr, target_dtype=target_native, scalar_map_entries=encode_entries + ) + new_fill = target_native.type(new_fill_arr[0]) + + return replace(chunk_spec, dtype=target_zdtype, fill_value=new_fill) + + def _encode_sync( + self, + chunk_array: NDBuffer, + _chunk_spec: ArraySpec, + ) -> NDBuffer | None: + arr = chunk_array.as_ndarray_like() + target_native = self.dtype.to_native_dtype() + + encode_raw = _extract_raw_map(self.scalar_map, "encode") + encode_entries = ( + _parse_map_entries(encode_raw, _chunk_spec.dtype, self.dtype) if encode_raw else None + ) + + result = self._do_cast( + np.asarray(arr), target_dtype=target_native, scalar_map_entries=encode_entries + ) + return chunk_array.__class__.from_ndarray_like(result) + + async def _encode_single( + self, + chunk_data: NDBuffer, + chunk_spec: ArraySpec, + ) -> NDBuffer | None: + return self._encode_sync(chunk_data, chunk_spec) + + def _decode_sync( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + arr = chunk_array.as_ndarray_like() + target_native = chunk_spec.dtype.to_native_dtype() + + decode_raw = _extract_raw_map(self.scalar_map, "decode") + decode_entries = ( + _parse_map_entries(decode_raw, self.dtype, chunk_spec.dtype) if decode_raw else None + ) + + result = self._do_cast( + np.asarray(arr), target_dtype=target_native, scalar_map_entries=decode_entries + ) + return chunk_array.__class__.from_ndarray_like(result) + + async def _decode_single( + self, + chunk_data: NDBuffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_data, 
chunk_spec) + + def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int: + source_itemsize = chunk_spec.dtype.to_native_dtype().itemsize + target_itemsize = self.dtype.to_native_dtype().itemsize + if source_itemsize == 0 or target_itemsize == 0: + raise ValueError( + "cast_value codec requires fixed-size data types. " + f"Got source itemsize={source_itemsize}, target itemsize={target_itemsize}." + ) + num_elements = input_byte_length // source_itemsize + return num_elements * target_itemsize diff --git a/src/zarr/codecs/scale_offset.py b/src/zarr/codecs/scale_offset.py new file mode 100644 index 0000000000..3e06fca4de --- /dev/null +++ b/src/zarr/codecs/scale_offset.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING + +import numpy as np + +from zarr.abc.codec import ArrayArrayCodec +from zarr.core.common import JSON, parse_named_configuration + +if TYPE_CHECKING: + from typing import Self + + from zarr.core.array_spec import ArraySpec + from zarr.core.buffer import NDBuffer + from zarr.core.chunk_grids import ChunkGrid + from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType + + +@dataclass(frozen=True) +class ScaleOffset(ArrayArrayCodec): + """Scale-offset array-to-array codec. + + Encodes values by subtracting an offset and multiplying by a scale factor. + Decodes by dividing by the scale and adding the offset. + + All arithmetic uses the input array's data type semantics (no implicit promotion). + + Parameters + ---------- + offset : float + Value subtracted during encoding. Default is 0. + scale : float + Value multiplied during encoding (after offset subtraction). Default is 1. 
+ """ + + is_fixed_size = True + + offset: int | float + scale: int | float + + def __init__(self, *, offset: object = 0, scale: object = 1) -> None: + if not isinstance(offset, int | float): + raise TypeError(f"offset must be a number, got {type(offset).__name__}") + if not isinstance(scale, int | float): + raise TypeError(f"scale must be a number, got {type(scale).__name__}") + object.__setattr__(self, "offset", offset) + object.__setattr__(self, "scale", scale) + + @classmethod + def from_dict(cls, data: dict[str, JSON]) -> Self: + _, configuration_parsed = parse_named_configuration( + data, "scale_offset", require_configuration=False + ) + configuration_parsed = configuration_parsed or {} + return cls(**configuration_parsed) + + def to_dict(self) -> dict[str, JSON]: + if self.offset == 0 and self.scale == 1: + return {"name": "scale_offset"} + config: dict[str, JSON] = {} + if self.offset != 0: + config["offset"] = self.offset + if self.scale != 1: + config["scale"] = self.scale + return {"name": "scale_offset", "configuration": config} + + def validate( + self, + *, + shape: tuple[int, ...], + dtype: ZDType[TBaseDType, TBaseScalar], + chunk_grid: ChunkGrid, + ) -> None: + native = dtype.to_native_dtype() + if not np.issubdtype(native, np.integer) and not np.issubdtype(native, np.floating): + raise ValueError( + f"scale_offset codec only supports integer and floating-point data types. " + f"Got {dtype}." 
+ ) + + def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: + native_dtype = chunk_spec.dtype.to_native_dtype() + fill = chunk_spec.fill_value + new_fill = (native_dtype.type(fill) - native_dtype.type(self.offset)) * native_dtype.type( # type: ignore[operator] + self.scale + ) + return replace(chunk_spec, fill_value=new_fill) + + def _decode_sync( + self, + chunk_array: NDBuffer, + _chunk_spec: ArraySpec, + ) -> NDBuffer: + arr = chunk_array.as_ndarray_like() + if np.issubdtype(arr.dtype, np.integer): + result = (arr // arr.dtype.type(self.scale)) + arr.dtype.type(self.offset) + else: + result = (arr / arr.dtype.type(self.scale)) + arr.dtype.type(self.offset) + if result.dtype != arr.dtype: + raise ValueError( + f"scale_offset decode changed dtype from {arr.dtype} to {result.dtype}. " + f"Arithmetic must preserve the data type." + ) + return chunk_array.__class__.from_ndarray_like(result) + + async def _decode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_array, chunk_spec) + + def _encode_sync( + self, + chunk_array: NDBuffer, + _chunk_spec: ArraySpec, + ) -> NDBuffer | None: + arr = chunk_array.as_ndarray_like() + result = (arr - arr.dtype.type(self.offset)) * arr.dtype.type(self.scale) + if result.dtype != arr.dtype: + raise ValueError( + f"scale_offset encode changed dtype from {arr.dtype} to {result.dtype}. " + f"Arithmetic must preserve the data type." 
+ ) + return chunk_array.__class__.from_ndarray_like(result) + + async def _encode_single( + self, + chunk_array: NDBuffer, + _chunk_spec: ArraySpec, + ) -> NDBuffer | None: + return self._encode_sync(chunk_array, _chunk_spec) + + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: + return input_byte_length diff --git a/tests/test_codecs/test_cast_value.py b/tests/test_codecs/test_cast_value.py new file mode 100644 index 0000000000..5b72096de8 --- /dev/null +++ b/tests/test_codecs/test_cast_value.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import numpy as np +import pytest + +from zarr.codecs.cast_value import CastValue + +# These tests require cast-value-rs. Skip the entire module if not installed. +pytest.importorskip("cast_value_rs") + + +@dataclass(frozen=True) +class Expect[TIn, TOut]: + """Model an input and an expected output value for a test case.""" + + input: TIn + expected: TOut + + +@dataclass(frozen=True) +class ExpectErr[TIn]: + """Model an input and an expected error message for a test case.""" + + input: TIn + msg: str + exception_cls: type[Exception] + + +# --------------------------------------------------------------------------- +# Serialization +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=CastValue(data_type="uint8"), + expected={"name": "cast_value", "configuration": {"data_type": "uint8"}}, + ), + Expect( + input=CastValue( + data_type="uint8", + rounding="towards-zero", + out_of_range="clamp", + scalar_map={"encode": [("NaN", 0)]}, + ), + expected={ + "name": "cast_value", + "configuration": { + "data_type": "uint8", + "rounding": "towards-zero", + "out_of_range": "clamp", + "scalar_map": {"encode": [("NaN", 0)]}, + }, + }, + ), + ], + ids=["minimal", "full"], +) +def test_to_dict(case: Expect[CastValue, dict[str, Any]]) -> None: + 
"""to_dict produces the expected JSON structure.""" + assert case.input.to_dict() == case.expected + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input={"name": "cast_value", "configuration": {"data_type": "float32"}}, + expected=("float32", "nearest-even", None), + ), + Expect( + input={ + "name": "cast_value", + "configuration": { + "data_type": "int16", + "rounding": "towards-zero", + "out_of_range": "clamp", + }, + }, + expected=("int16", "towards-zero", "clamp"), + ), + ], + ids=["defaults", "explicit"], +) +def test_from_dict(case: Expect[dict[str, Any], tuple[str, str, str | None]]) -> None: + """from_dict deserializes configuration with correct values and defaults.""" + codec = CastValue.from_dict(case.input) + dtype_name, rounding, out_of_range = case.expected + assert codec.dtype.to_native_dtype() == np.dtype(dtype_name) + assert codec.rounding == rounding + assert codec.out_of_range == out_of_range + + +def test_serialization_roundtrip() -> None: + """to_dict followed by from_dict produces an equal codec.""" + original = CastValue(data_type="int16", rounding="towards-zero", out_of_range="clamp") + restored = CastValue.from_dict(original.to_dict()) + assert original == restored + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input={"dtype": "complex128", "target": "float64"}, + msg="only supports integer and floating-point", + exception_cls=ValueError, + ), + ExpectErr( + input={"dtype": "int32", "target": "float32", "out_of_range": "wrap"}, + msg="only valid for integer", + exception_cls=ValueError, + ), + ], + ids=["complex-source", "wrap-float-target"], +) +def test_validation_rejects_invalid(case: ExpectErr[dict[str, Any]]) -> None: + """Invalid dtype or out_of_range combinations are rejected at array creation.""" + import zarr + + with 
pytest.raises(case.exception_cls, match=case.msg): + zarr.create_array( + store={}, + shape=(10,), + dtype=case.input["dtype"], + chunks=(10,), + filters=[ + CastValue( + data_type=case.input["target"], + out_of_range=case.input.get("out_of_range"), + ) + ], + compressors=None, + fill_value=0, + ) + + +def test_zero_itemsize_raises() -> None: + """Variable-length dtypes (itemsize=0) are rejected by compute_encoded_size.""" + from zarr.core.array_spec import ArrayConfig, ArraySpec + from zarr.core.buffer import default_buffer_prototype + from zarr.core.dtype.npy.string import VariableLengthUTF8 + + codec = CastValue(data_type="uint8") + spec = ArraySpec( + shape=(10,), + dtype=VariableLengthUTF8(), # type: ignore[arg-type] + fill_value="", + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + with pytest.raises(ValueError, match="fixed-size data types"): + codec.compute_encoded_size(100, spec) + + +# --------------------------------------------------------------------------- +# Encode / decode +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect(input=("float64", "float32"), expected=np.arange(50, dtype="float64")), + Expect(input=("float32", "float64"), expected=np.arange(50, dtype="float32")), + Expect(input=("int32", "int64"), expected=np.arange(50, dtype="int32")), + Expect(input=("int64", "int16"), expected=np.arange(50, dtype="int64")), + Expect(input=("float64", "int32"), expected=np.arange(50, dtype="float64")), + Expect(input=("int32", "float64"), expected=np.arange(50, dtype="int32")), + ], + ids=["f64→f32", "f32→f64", "i32→i64", "i64→i16", "f64→i32", "i32→f64"], +) +def test_encode_decode_roundtrip( + case: Expect[tuple[str, str], np.ndarray[Any, np.dtype[Any]]], +) -> None: + """Small integer data survives encode → decode for each dtype pair.""" + import zarr + + source_dtype, target_dtype = case.input + arr = 
zarr.create_array( + store={}, + shape=(50,), + dtype=source_dtype, + chunks=(50,), + filters=[CastValue(data_type=target_dtype)], + compressors=None, + fill_value=0, + ) + arr[:] = case.expected + np.testing.assert_array_equal(arr[:], case.expected) + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=np.array([1.7, -1.7, 2.5, -2.5], dtype="float64"), + expected=np.array([1, -1, 2, -2], dtype="float64"), + ), + ], + ids=["towards-zero"], +) +def test_float_to_int_rounding( + case: Expect[np.ndarray[Any, np.dtype[Any]], np.ndarray[Any, np.dtype[Any]]], +) -> None: + """Fractional float values are truncated towards zero when cast to int32.""" + import zarr + + arr = zarr.create_array( + store={}, + shape=case.input.shape, + dtype=case.input.dtype, + chunks=case.input.shape, + filters=[CastValue(data_type="int32", rounding="towards-zero", out_of_range="clamp")], + compressors=None, + fill_value=0, + ) + arr[:] = case.input + np.testing.assert_array_equal(arr[:], case.expected) + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=np.array([0, 200, -200], dtype="int32"), + expected=np.array([0, 127, -128], dtype="int32"), + ), + ], + ids=["int32→int8"], +) +def test_out_of_range_clamp( + case: Expect[np.ndarray[Any, np.dtype[Any]], np.ndarray[Any, np.dtype[Any]]], +) -> None: + """Values outside the int8 range are clamped to [-128, 127].""" + import zarr + + arr = zarr.create_array( + store={}, + shape=case.input.shape, + dtype=case.input.dtype, + chunks=case.input.shape, + filters=[CastValue(data_type="int8", out_of_range="clamp")], + compressors=None, + fill_value=0, + ) + arr[:] = case.input + np.testing.assert_array_equal(arr[:], case.expected) + + +def test_combined_with_scale_offset() -> None: + """scale_offset followed by cast_value compresses float64 into int16 and round-trips.""" + import zarr + from zarr.codecs.scale_offset import ScaleOffset + + arr = zarr.create_array( + store={}, + shape=(100,), + dtype="float64", + chunks=(100,), + 
filters=[ + ScaleOffset(offset=0, scale=10), + CastValue(data_type="int16", rounding="nearest-even", out_of_range="clamp"), + ], + compressors=None, + fill_value=0, + ) + data = np.arange(100, dtype="float64") * 0.1 + arr[:] = data + result = arr[:] + np.testing.assert_array_almost_equal(result, data, decimal=1) # type: ignore[arg-type] diff --git a/tests/test_codecs/test_scale_offset.py b/tests/test_codecs/test_scale_offset.py new file mode 100644 index 0000000000..47af287c1d --- /dev/null +++ b/tests/test_codecs/test_scale_offset.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import numpy as np +import pytest + +from zarr.codecs.scale_offset import ScaleOffset + + +@dataclass(frozen=True) +class Expect[TIn, TOut]: + """Model an input and an expected output value for a test case.""" + + input: TIn + expected: TOut + + +@dataclass(frozen=True) +class ExpectErr[TIn]: + """Model an input and an expected error message for a test case.""" + + input: TIn + msg: str + exception_cls: type[Exception] + + +# --------------------------------------------------------------------------- +# Serialization +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect(input=ScaleOffset(), expected={"name": "scale_offset"}), + Expect( + input=ScaleOffset(offset=5), + expected={"name": "scale_offset", "configuration": {"offset": 5}}, + ), + Expect( + input=ScaleOffset(scale=0.1), + expected={"name": "scale_offset", "configuration": {"scale": 0.1}}, + ), + Expect( + input=ScaleOffset(offset=5, scale=0.1), + expected={"name": "scale_offset", "configuration": {"offset": 5, "scale": 0.1}}, + ), + ], + ids=["default", "offset-only", "scale-only", "both"], +) +def test_to_dict(case: Expect[ScaleOffset, dict[str, Any]]) -> None: + """to_dict produces the expected JSON structure.""" + assert case.input.to_dict() == case.expected + + 
+@pytest.mark.parametrize( + "case", + [ + Expect(input={"name": "scale_offset"}, expected=(0, 1)), + Expect( + input={"name": "scale_offset", "configuration": {"offset": 3, "scale": 2}}, + expected=(3, 2), + ), + ], + ids=["no-config", "with-config"], +) +def test_from_dict(case: Expect[dict[str, Any], tuple[int | float, int | float]]) -> None: + """from_dict deserializes configuration with correct values and defaults.""" + codec = ScaleOffset.from_dict(case.input) + expected_offset, expected_scale = case.expected + assert codec.offset == expected_offset + assert codec.scale == expected_scale + + +def test_serialization_roundtrip() -> None: + """to_dict followed by from_dict produces an equal codec.""" + original = ScaleOffset(offset=7, scale=0.5) + restored = ScaleOffset.from_dict(original.to_dict()) + assert original == restored + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr(input={"offset": "bad"}, msg="offset must be a number", exception_cls=TypeError), + ExpectErr(input={"scale": [1, 2]}, msg="scale must be a number", exception_cls=TypeError), + ], + ids=["string-offset", "list-scale"], +) +def test_construction_rejects_non_numeric(case: ExpectErr[dict[str, Any]]) -> None: + """Non-numeric offset or scale is rejected at construction time.""" + with pytest.raises(case.exception_cls, match=case.msg): + ScaleOffset(**case.input) + + +@pytest.mark.parametrize( + "case", + [ + Expect(input={"offset": 5, "scale": 2}, expected=(5, 2)), + Expect(input={"offset": 0.5, "scale": 0.1}, expected=(0.5, 0.1)), + ], + ids=["int", "float"], +) +def test_construction_accepts_numeric( + case: Expect[dict[str, Any], tuple[int | float, int | float]], +) -> None: + """Integer and float values are accepted for both parameters.""" + codec = ScaleOffset(**case.input) + assert codec.offset == 
case.expected[0] + assert codec.scale == case.expected[1] + + +# --------------------------------------------------------------------------- +# Encode / decode +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("dtype", "offset", "scale"), + [ + ("float64", 10.0, 0.1), + ("float32", 5.0, 2.0), + ("int32", 0, 1), + ], + ids=["float64", "float32", "int32-identity"], +) +def test_encode_decode_roundtrip(dtype: str, offset: float, scale: float) -> None: + """Data survives encode → decode.""" + import zarr + + arr = zarr.create_array( + store={}, + shape=(100,), + dtype=dtype, + chunks=(100,), + filters=[ScaleOffset(offset=offset, scale=scale)], + compressors=None, + fill_value=0, + ) + data = np.arange(100, dtype=dtype) + arr[:] = data + np.testing.assert_array_almost_equal(arr[:], data) # type: ignore[arg-type] + + +def test_fill_value_transformed() -> None: + """Fill value is transformed through the encode formula and read back correctly.""" + import zarr + + arr = zarr.create_array( + store={}, + shape=(10,), + dtype="float64", + chunks=(10,), + filters=[ScaleOffset(offset=5, scale=2)], + compressors=None, + fill_value=10.0, + ) + # fill_value=10.0, encode: (10 - 5) * 2 = 10.0 stored + # Reading back without writing should return the original fill value + np.testing.assert_array_equal(arr[:], np.full(10, 10.0)) + + +def test_identity_is_noop() -> None: + """Default codec (offset=0, scale=1) is a no-op.""" + import zarr + + arr = zarr.create_array( + store={}, + shape=(50,), + dtype="float64", + chunks=(50,), + filters=[ScaleOffset()], + compressors=None, + fill_value=0, + ) + data = np.arange(50, dtype="float64") + arr[:] = data + np.testing.assert_array_equal(arr[:], data) + + +def test_rejects_complex_dtype() -> None: + """Complex dtypes are rejected at array creation time.""" + import zarr + + with pytest.raises(ValueError, match="only supports integer and floating-point"): + zarr.create_array( + 
store={}, + shape=(10,), + dtype="complex128", + chunks=(10,), + filters=[ScaleOffset(offset=1, scale=2)], + compressors=None, + fill_value=0, + ) + + +def test_dtype_preservation() -> None: + """Integer scale/offset arithmetic preserves the array dtype via floor division.""" + import zarr + + arr = zarr.create_array( + store={}, + shape=(10,), + dtype="int8", + chunks=(10,), + filters=[ScaleOffset(offset=1, scale=2)], + compressors=None, + fill_value=0, + ) + data = np.arange(10, dtype="int8") + arr[:] = data + # offset=1, scale=2: encode=(x-1)*2, decode=x//2+1 + result = arr[:] + expected = ((data - 1) * 2) // 2 + 1 + np.testing.assert_array_equal(result, expected.astype("int8")) From 94ab34a4c3d71e80f13e3e55c808ad6215d48a0d Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 6 Apr 2026 14:52:33 +0200 Subject: [PATCH 2/8] docs: changelog --- changes/3874.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/3874.feature.md diff --git a/changes/3874.feature.md b/changes/3874.feature.md new file mode 100644 index 0000000000..592c5b330d --- /dev/null +++ b/changes/3874.feature.md @@ -0,0 +1 @@ +Add `cast_value` and `scale_offset` codecs. 
\ No newline at end of file From 7f5f2b278ff4321be84cc64335745f5402d26d26 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 7 Apr 2026 23:18:16 +0200 Subject: [PATCH 3/8] chore: simplify scalar map handling --- src/zarr/codecs/cast_value.py | 153 ++++++++++----------------- tests/test_codecs/test_cast_value.py | 24 +++++ 2 files changed, 81 insertions(+), 96 deletions(-) diff --git a/src/zarr/codecs/cast_value.py b/src/zarr/codecs/cast_value.py index 4b3a554566..bf39f3f7a6 100644 --- a/src/zarr/codecs/cast_value.py +++ b/src/zarr/codecs/cast_value.py @@ -10,8 +10,9 @@ from __future__ import annotations +from collections.abc import Mapping from dataclasses import dataclass, replace -from typing import TYPE_CHECKING, Literal, cast +from typing import TYPE_CHECKING, Literal, TypedDict, cast import numpy as np @@ -20,8 +21,7 @@ from zarr.core.dtype import get_data_type_from_json if TYPE_CHECKING: - from collections.abc import Iterable, Mapping - from typing import Any, NotRequired, Self, TypedDict + from typing import NotRequired, Self from zarr.core.array_spec import ArraySpec from zarr.core.buffer import NDBuffer @@ -32,8 +32,6 @@ class ScalarMapJSON(TypedDict): encode: NotRequired[list[tuple[object, object]]] decode: NotRequired[list[tuple[object, object]]] - # Pre-parsed scalar map entry: (source_scalar, target_scalar) - ScalarMapEntry = tuple[np.integer[Any] | np.floating[Any], np.integer[Any] | np.floating[Any]] RoundingMode = Literal[ "nearest-even", @@ -46,88 +44,47 @@ class ScalarMapJSON(TypedDict): OutOfRangeMode = Literal["clamp", "wrap"] -# --------------------------------------------------------------------------- -# Scalar-map parsing helpers -# --------------------------------------------------------------------------- - +class ScalarMap(TypedDict): + """ + The normalized, in-memory form of a scalar map. 
+ """ -def _extract_raw_map(data: ScalarMapJSON | None, direction: str) -> dict[str, str] | None: - """Extract raw string mapping from scalar_map JSON for 'encode' or 'decode'.""" - if data is None: - return None - raw: dict[str, str] = {} - pairs: list[tuple[object, object]] = data.get(direction, []) # type: ignore[assignment] - for src, tgt in pairs: - raw[str(src)] = str(tgt) - return raw or None + encode: NotRequired[Mapping[str | float | int, str | float | int]] + decode: NotRequired[Mapping[str | float | int, str | float | int]] -def _parse_map_entries( - mapping: Mapping[str, str], - src_dtype: ZDType[TBaseDType, TBaseScalar], - tgt_dtype: ZDType[TBaseDType, TBaseScalar], -) -> tuple[ScalarMapEntry, ...]: - """Pre-parse a scalar map dict into a tuple of (src, tgt) pairs. +def parse_scalar_map(obj: ScalarMapJSON | ScalarMap) -> ScalarMap: + """ + Parse a scalar map into its normalized dict-of-dicts form. - Each entry's source value is deserialized using ``src_dtype`` and its target - value using ``tgt_dtype``, preserving full precision for both data types. + Accepts either the JSON form (lists of tuples) or an already-normalized form + (dicts). For example, ``{"encode": [("NaN", 0)]}`` becomes + ``{"encode": {"NaN": 0}}``. 
""" - entries: list[ScalarMapEntry] = [ - ( - src_dtype.from_json_scalar(src_str, zarr_format=3), # type: ignore[misc] - tgt_dtype.from_json_scalar(tgt_str, zarr_format=3), - ) - for src_str, tgt_str in mapping.items() - ] - return tuple(entries) + result: ScalarMap = {} + for direction in ("encode", "decode"): + if direction in obj: + entries = obj[direction] + if entries is not None: + if isinstance(entries, Mapping): + result[direction] = entries + else: + result[direction] = dict(entries) # type: ignore[arg-type] + return result # --------------------------------------------------------------------------- -# Backend: cast-value-rs (optional) +# Backend: cast-value-rs # --------------------------------------------------------------------------- try: - from cast_value_rs import cast_array as _rs_cast_array + from cast_value_rs import cast_array as cast_array_rs _HAS_RUST_BACKEND = True except ModuleNotFoundError: _HAS_RUST_BACKEND = False -def _dtype_to_str(dtype: np.dtype) -> str: # type: ignore[type-arg] - return dtype.name - - -def _convert_scalar_map( - entries: Iterable[ScalarMapEntry] | None, -) -> list[tuple[int | float, int | float]] | None: - if entries is None: - return None - result: list[tuple[int | float, int | float]] = [] - for src, tgt in entries: - src_py: int | float = int(src) if isinstance(src, np.integer) else float(src) - tgt_py: int | float = int(tgt) if isinstance(tgt, np.integer) else float(tgt) - result.append((src_py, tgt_py)) - return result - - -def _cast_array_rs( - arr: np.ndarray, # type: ignore[type-arg] - *, - target_dtype: np.dtype, # type: ignore[type-arg] - rounding: RoundingMode, - out_of_range: OutOfRangeMode | None, - scalar_map_entries: Iterable[ScalarMapEntry] | None, -) -> np.ndarray: # type: ignore[type-arg] - return _rs_cast_array( # type: ignore[no-any-return] - arr=arr, - target_dtype=_dtype_to_str(target_dtype), - rounding_mode=rounding, - out_of_range_mode=out_of_range, - 
scalar_map_entries=_convert_scalar_map(scalar_map_entries), - ) - - # --------------------------------------------------------------------------- # Codec # --------------------------------------------------------------------------- @@ -166,7 +123,7 @@ class CastValue(ArrayArrayCodec): dtype: ZDType[TBaseDType, TBaseScalar] rounding: RoundingMode out_of_range: OutOfRangeMode | None - scalar_map: ScalarMapJSON | None + scalar_map: ScalarMap | None def __init__( self, @@ -174,7 +131,7 @@ def __init__( data_type: str | ZDType[TBaseDType, TBaseScalar], rounding: RoundingMode = "nearest-even", out_of_range: OutOfRangeMode | None = None, - scalar_map: ScalarMapJSON | None = None, + scalar_map: ScalarMapJSON | ScalarMap | None = None, ) -> None: if isinstance(data_type, str): zdtype = get_data_type_from_json(data_type, zarr_format=3) @@ -183,7 +140,11 @@ def __init__( object.__setattr__(self, "dtype", zdtype) object.__setattr__(self, "rounding", rounding) object.__setattr__(self, "out_of_range", out_of_range) - object.__setattr__(self, "scalar_map", scalar_map) + if scalar_map is not None: + parsed = parse_scalar_map(scalar_map) + else: + parsed = None + object.__setattr__(self, "scalar_map", parsed) @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: @@ -199,7 +160,11 @@ def to_dict(self) -> dict[str, JSON]: if self.out_of_range is not None: config["out_of_range"] = self.out_of_range if self.scalar_map is not None: - config["scalar_map"] = cast("JSON", self.scalar_map) + json_map: dict[str, list[tuple[object, object]]] = {} + for direction in ("encode", "decode"): + if direction in self.scalar_map: + json_map[direction] = [(k, v) for k, v in self.scalar_map[direction].items()] + config["scalar_map"] = cast("JSON", json_map) return {"name": "cast_value", "configuration": config} def validate( @@ -225,21 +190,32 @@ def _do_cast( arr: np.ndarray, # type: ignore[type-arg] *, target_dtype: np.dtype, # type: ignore[type-arg] - scalar_map_entries: 
Iterable[ScalarMapEntry] | None, + scalar_map: Mapping[str | float | int, str | float | int] | None, ) -> np.ndarray: # type: ignore[type-arg] if not _HAS_RUST_BACKEND: raise ImportError( "The cast_value codec requires the 'cast-value-rs' package. " "Install it with: pip install cast-value-rs" ) - return _cast_array_rs( + scalar_map_entries: dict[float, float] | None = None + if scalar_map is not None: + scalar_map_entries = {float(k): float(v) for k, v in scalar_map.items()} + return cast_array_rs( # type: ignore[no-any-return] arr, target_dtype=target_dtype, - rounding=self.rounding, - out_of_range=self.out_of_range, + rounding_mode=self.rounding, + out_of_range_mode=self.out_of_range, scalar_map_entries=scalar_map_entries, ) + def _get_scalar_map( + self, direction: str + ) -> Mapping[str | float | int, str | float | int] | None: + """Extract the encode or decode mapping from scalar_map, or None.""" + if self.scalar_map is None: + return None + return self.scalar_map.get(direction) # type: ignore[return-value] + def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: """ Update the fill value of the output spec by applying casting procedure. 
@@ -251,13 +227,8 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: fill = chunk_spec.fill_value fill_arr = np.array([fill], dtype=source_native) - encode_raw = _extract_raw_map(self.scalar_map, "encode") - encode_entries = ( - _parse_map_entries(encode_raw, chunk_spec.dtype, self.dtype) if encode_raw else None - ) - new_fill_arr = self._do_cast( - fill_arr, target_dtype=target_native, scalar_map_entries=encode_entries + fill_arr, target_dtype=target_native, scalar_map=self._get_scalar_map("encode") ) new_fill = target_native.type(new_fill_arr[0]) @@ -271,13 +242,8 @@ def _encode_sync( arr = chunk_array.as_ndarray_like() target_native = self.dtype.to_native_dtype() - encode_raw = _extract_raw_map(self.scalar_map, "encode") - encode_entries = ( - _parse_map_entries(encode_raw, _chunk_spec.dtype, self.dtype) if encode_raw else None - ) - result = self._do_cast( - np.asarray(arr), target_dtype=target_native, scalar_map_entries=encode_entries + np.asarray(arr), target_dtype=target_native, scalar_map=self._get_scalar_map("encode") ) return chunk_array.__class__.from_ndarray_like(result) @@ -296,13 +262,8 @@ def _decode_sync( arr = chunk_array.as_ndarray_like() target_native = chunk_spec.dtype.to_native_dtype() - decode_raw = _extract_raw_map(self.scalar_map, "decode") - decode_entries = ( - _parse_map_entries(decode_raw, self.dtype, chunk_spec.dtype) if decode_raw else None - ) - result = self._do_cast( - np.asarray(arr), target_dtype=target_native, scalar_map_entries=decode_entries + np.asarray(arr), target_dtype=target_native, scalar_map=self._get_scalar_map("decode") ) return chunk_array.__class__.from_ndarray_like(result) diff --git a/tests/test_codecs/test_cast_value.py b/tests/test_codecs/test_cast_value.py index 5b72096de8..59285ee0a5 100644 --- a/tests/test_codecs/test_cast_value.py +++ b/tests/test_codecs/test_cast_value.py @@ -279,3 +279,27 @@ def test_combined_with_scale_offset() -> None: arr[:] = data result = arr[:] 
np.testing.assert_array_almost_equal(result, data, decimal=1) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input={"encode": [("NaN", 0)]}, + expected={"encode": {"NaN": 0}}, + ), + Expect( + input={"encode": [("NaN", 0)], "decode": [(0, "NaN")]}, + expected={"encode": {"NaN": 0}, "decode": {0: "NaN"}}, + ), + Expect( + input={"encode": {"NaN": 0}}, + expected={"encode": {"NaN": 0}}, + ), + ], + ids=["encode-only", "both-directions", "already-normalized"], +) +def test_parse_scalar_map(case: Expect[Any, Any]) -> None: + from zarr.codecs.cast_value import parse_scalar_map + + assert parse_scalar_map(case.input) == case.expected From b35d5a39c932471947d058ea59699443c1b7c401 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 7 Apr 2026 23:31:07 +0200 Subject: [PATCH 4/8] chore: coverage --- src/zarr/codecs/cast_value.py | 64 +++++++++++- src/zarr/codecs/scale_offset.py | 4 +- tests/test_codecs/test_cast_value.py | 151 +++++++++++++++++++++++++-- 3 files changed, 207 insertions(+), 12 deletions(-) diff --git a/src/zarr/codecs/cast_value.py b/src/zarr/codecs/cast_value.py index bf39f3f7a6..2a3552df5e 100644 --- a/src/zarr/codecs/cast_value.py +++ b/src/zarr/codecs/cast_value.py @@ -25,8 +25,8 @@ from zarr.core.array_spec import ArraySpec from zarr.core.buffer import NDBuffer - from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType + from zarr.core.metadata.v3 import ChunkGridMetadata class ScalarMapJSON(TypedDict): encode: NotRequired[list[tuple[object, object]]] @@ -85,6 +85,35 @@ def parse_scalar_map(obj: ScalarMapJSON | ScalarMap) -> ScalarMap: _HAS_RUST_BACKEND = False +def _check_scalar_representable( + value: str | float, + dtype: np.dtype, # type: ignore[type-arg] + direction: str, + role: str, +) -> None: + """Raise ``ValueError`` if *value* cannot be represented in *dtype*.""" + fval = float(value) + is_integer_dtype = np.issubdtype(dtype, np.integer) + if 
np.isnan(fval) and is_integer_dtype: + raise ValueError( + f"scalar_map {direction} {role} {value!r} is NaN, " + f"which is not representable in integer dtype {dtype}." + ) + if is_integer_dtype: + info = np.iinfo(dtype) + ival = int(fval) + if float(ival) != fval: + raise ValueError( + f"scalar_map {direction} {role} {value!r} is not an integer, " + f"which is required for integer dtype {dtype}." + ) + if ival < info.min or ival > info.max: + raise ValueError( + f"scalar_map {direction} {role} {value!r} is out of range " + f"for dtype {dtype} [{info.min}, {info.max}]." + ) + + # --------------------------------------------------------------------------- # Codec # --------------------------------------------------------------------------- @@ -172,7 +201,7 @@ def validate( *, shape: tuple[int, ...], dtype: ZDType[TBaseDType, TBaseScalar], - chunk_grid: ChunkGrid, + chunk_grid: ChunkGridMetadata, ) -> None: source_native = dtype.to_native_dtype() target_native = self.dtype.to_native_dtype() @@ -185,6 +214,30 @@ def validate( if self.out_of_range == "wrap" and not np.issubdtype(target_native, np.integer): raise ValueError("out_of_range='wrap' is only valid for integer target types.") + if self.scalar_map is not None: + self._validate_scalar_map(source_native, target_native) + + def _validate_scalar_map( + self, + source_native: np.dtype, # type: ignore[type-arg] + target_native: np.dtype, # type: ignore[type-arg] + ) -> None: + """Validate that scalar map entries are compatible with source/target dtypes.""" + assert self.scalar_map is not None + # For encode: keys are source values, values are target values. + # For decode: keys are target values, values are source values. 
+ direction_dtypes = { + "encode": (source_native, target_native), + "decode": (target_native, source_native), + } + for direction, (key_dtype, val_dtype) in direction_dtypes.items(): + if direction not in self.scalar_map: + continue + sub_map = self.scalar_map[direction] # type: ignore[literal-required] + for k, v in sub_map.items(): + _check_scalar_representable(k, key_dtype, direction, "key") + _check_scalar_representable(v, val_dtype, direction, "value") + def _do_cast( self, arr: np.ndarray, # type: ignore[type-arg] @@ -197,9 +250,12 @@ def _do_cast( "The cast_value codec requires the 'cast-value-rs' package. " "Install it with: pip install cast-value-rs" ) - scalar_map_entries: dict[float, float] | None = None + scalar_map_entries: dict[float | int, float | int] | None = None if scalar_map is not None: - scalar_map_entries = {float(k): float(v) for k, v in scalar_map.items()} + src_dtype = arr.dtype + to_src = int if np.issubdtype(src_dtype, np.integer) else float + to_tgt = int if np.issubdtype(target_dtype, np.integer) else float + scalar_map_entries = {to_src(k): to_tgt(v) for k, v in scalar_map.items()} return cast_array_rs( # type: ignore[no-any-return] arr, target_dtype=target_dtype, diff --git a/src/zarr/codecs/scale_offset.py b/src/zarr/codecs/scale_offset.py index 3e06fca4de..fae0a9babb 100644 --- a/src/zarr/codecs/scale_offset.py +++ b/src/zarr/codecs/scale_offset.py @@ -13,8 +13,8 @@ from zarr.core.array_spec import ArraySpec from zarr.core.buffer import NDBuffer - from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType + from zarr.core.metadata.v3 import ChunkGridMetadata @dataclass(frozen=True) @@ -70,7 +70,7 @@ def validate( *, shape: tuple[int, ...], dtype: ZDType[TBaseDType, TBaseScalar], - chunk_grid: ChunkGrid, + chunk_grid: ChunkGridMetadata, ) -> None: native = dtype.to_native_dtype() if not np.issubdtype(native, np.integer) and not np.issubdtype(native, np.floating): diff --git 
a/tests/test_codecs/test_cast_value.py b/tests/test_codecs/test_cast_value.py index 59285ee0a5..c7611f0c9d 100644 --- a/tests/test_codecs/test_cast_value.py +++ b/tests/test_codecs/test_cast_value.py @@ -8,8 +8,16 @@ from zarr.codecs.cast_value import CastValue -# These tests require cast-value-rs. Skip the entire module if not installed. -pytest.importorskip("cast_value_rs") +try: + import cast_value_rs # noqa: F401 + + _HAS_CAST_VALUE_RS = True +except ModuleNotFoundError: + _HAS_CAST_VALUE_RS = False + +requires_cast_value_rs = pytest.mark.skipif( + not _HAS_CAST_VALUE_RS, reason="cast-value-rs not installed" +) @dataclass(frozen=True) @@ -96,11 +104,22 @@ def test_from_dict(case: Expect[dict[str, Any], tuple[str, str, str | None]]) -> assert codec.out_of_range == out_of_range -def test_serialization_roundtrip() -> None: +@pytest.mark.parametrize( + "codec", + [ + CastValue(data_type="int16", rounding="towards-zero", out_of_range="clamp"), + CastValue( + data_type="uint8", + out_of_range="clamp", + scalar_map={"encode": [("NaN", 0)], "decode": [(0, "NaN")]}, + ), + ], + ids=["no-scalar-map", "with-scalar-map"], +) +def test_serialization_roundtrip(codec: CastValue) -> None: """to_dict followed by from_dict produces an equal codec.""" - original = CastValue(data_type="int16", rounding="towards-zero", out_of_range="clamp") - restored = CastValue.from_dict(original.to_dict()) - assert original == restored + restored = CastValue.from_dict(codec.to_dict()) + assert codec == restored # --------------------------------------------------------------------------- @@ -168,6 +187,7 @@ def test_zero_itemsize_raises() -> None: # --------------------------------------------------------------------------- +@requires_cast_value_rs @pytest.mark.parametrize( "case", [ @@ -200,6 +220,7 @@ def test_encode_decode_roundtrip( np.testing.assert_array_equal(arr[:], case.expected) +@requires_cast_value_rs @pytest.mark.parametrize( "case", [ @@ -229,6 +250,7 @@ def 
test_float_to_int_rounding( np.testing.assert_array_equal(arr[:], case.expected) +@requires_cast_value_rs @pytest.mark.parametrize( "case", [ @@ -258,6 +280,123 @@ def test_out_of_range_clamp( np.testing.assert_array_equal(arr[:], case.expected) +def test_compute_encoded_size() -> None: + """compute_encoded_size correctly scales byte length by itemsize ratio.""" + from zarr.core.array_spec import ArrayConfig, ArraySpec + from zarr.core.buffer import default_buffer_prototype + from zarr.core.dtype import get_data_type_from_json + + codec = CastValue(data_type="int16") + spec = ArraySpec( + shape=(10,), + dtype=get_data_type_from_json("float64", zarr_format=3), + fill_value=0, + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + # 10 float64 elements = 80 bytes -> 10 int16 elements = 20 bytes + assert codec.compute_encoded_size(80, spec) == 20 + + +@requires_cast_value_rs +def test_scalar_map_encode_decode_roundtrip() -> None: + """Scalar map entries are applied during encode and decode.""" + import zarr + + data = np.array([1.0, float("nan"), 3.0], dtype="float64") + arr = zarr.create_array( + store={}, + shape=data.shape, + dtype="float64", + chunks=data.shape, + filters=[ + CastValue( + data_type="int32", + rounding="nearest-even", + out_of_range="clamp", + scalar_map={"encode": [("NaN", -999)], "decode": [(-999, "NaN")]}, + ), + ], + compressors=None, + fill_value=1, + ) + arr[:] = data + result = np.asarray(arr[:]) + np.testing.assert_equal(result[0], 1.0) + np.testing.assert_equal(result[2], 3.0) + assert np.isnan(result[1]) + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input={ + "dtype": "int32", + "target": "int8", + "scalar_map": {"encode": [("NaN", 0)]}, + }, + msg="NaN.*not representable in integer dtype", + exception_cls=ValueError, + ), + ExpectErr( + input={ + "dtype": "int32", + "target": "float64", + "scalar_map": {"decode": [(0, "NaN")]}, + }, + msg="NaN.*not representable in integer 
dtype", + exception_cls=ValueError, + ), + ExpectErr( + input={ + "dtype": "float64", + "target": "int8", + "scalar_map": {"encode": [("NaN", 999)]}, + }, + msg="out of range", + exception_cls=ValueError, + ), + ExpectErr( + input={ + "dtype": "float64", + "target": "int8", + "scalar_map": {"encode": [("NaN", 1.5)]}, + }, + msg="not an integer", + exception_cls=ValueError, + ), + ], + ids=[ + "nan-key-for-int-source", + "nan-value-for-int-decode-target", + "encode-value-out-of-range", + "encode-value-not-integer", + ], +) +def test_scalar_map_validation_rejects_invalid(case: ExpectErr[dict[str, Any]]) -> None: + """Invalid scalar_map entries are rejected at array creation.""" + import zarr + + with pytest.raises(case.exception_cls, match=case.msg): + zarr.create_array( + store={}, + shape=(10,), + dtype=case.input["dtype"], + chunks=(10,), + filters=[ + CastValue( + data_type=case.input["target"], + out_of_range="clamp", + scalar_map=case.input["scalar_map"], + ) + ], + compressors=None, + fill_value=0, + ) + + +@requires_cast_value_rs def test_combined_with_scale_offset() -> None: """scale_offset followed by cast_value compresses float64 into int16 and round-trips.""" import zarr From 0e01a212af1e07e4e373e0cd8c005a3f35f3ed93 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 8 Apr 2026 00:01:24 +0200 Subject: [PATCH 5/8] chore: preserve JSON encoding of scale and offset parameters --- src/zarr/codecs/scale_offset.py | 45 +++++++++++++++-------- tests/test_codecs/test_scale_offset.py | 49 ++++++++++++++++++++++++-- 2 files changed, 76 insertions(+), 18 deletions(-) diff --git a/src/zarr/codecs/scale_offset.py b/src/zarr/codecs/scale_offset.py index fae0a9babb..38796a0b36 100644 --- a/src/zarr/codecs/scale_offset.py +++ b/src/zarr/codecs/scale_offset.py @@ -36,14 +36,14 @@ class ScaleOffset(ArrayArrayCodec): is_fixed_size = True - offset: int | float - scale: int | float + offset: int | float | str + scale: int | float | str def __init__(self, *, offset: 
object = 0, scale: object = 1) -> None: - if not isinstance(offset, int | float): - raise TypeError(f"offset must be a number, got {type(offset).__name__}") - if not isinstance(scale, int | float): - raise TypeError(f"scale must be a number, got {type(scale).__name__}") + if not isinstance(offset, int | float | str): + raise TypeError(f"offset must be a number or string, got {type(offset).__name__}") + if not isinstance(scale, int | float | str): + raise TypeError(f"scale must be a number or string, got {type(scale).__name__}") object.__setattr__(self, "offset", offset) object.__setattr__(self, "scale", scale) @@ -78,25 +78,38 @@ def validate( f"scale_offset codec only supports integer and floating-point data types. " f"Got {dtype}." ) + for name, value in [("offset", self.offset), ("scale", self.scale)]: + try: + dtype.from_json_scalar(value, zarr_format=3) + except (TypeError, ValueError) as e: + raise ValueError( + f"scale_offset {name} value {value!r} is not representable in dtype {native}." 
+ ) from e + + def _to_scalar(self, value: float | str, dtype: ZDType[TBaseDType, TBaseScalar]) -> TBaseScalar: + """Convert a JSON-form value to a numpy scalar using the given dtype.""" + return dtype.from_json_scalar(value, zarr_format=3) def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: - native_dtype = chunk_spec.dtype.to_native_dtype() + zdtype = chunk_spec.dtype fill = chunk_spec.fill_value - new_fill = (native_dtype.type(fill) - native_dtype.type(self.offset)) * native_dtype.type( # type: ignore[operator] - self.scale - ) + offset = self._to_scalar(self.offset, zdtype) + scale = self._to_scalar(self.scale, zdtype) + new_fill = (zdtype.to_native_dtype().type(fill) - offset) * scale # type: ignore[operator] return replace(chunk_spec, fill_value=new_fill) def _decode_sync( self, chunk_array: NDBuffer, - _chunk_spec: ArraySpec, + chunk_spec: ArraySpec, ) -> NDBuffer: arr = chunk_array.as_ndarray_like() + offset = self._to_scalar(self.offset, chunk_spec.dtype) + scale = self._to_scalar(self.scale, chunk_spec.dtype) if np.issubdtype(arr.dtype, np.integer): - result = (arr // arr.dtype.type(self.scale)) + arr.dtype.type(self.offset) + result = (arr // scale) + offset # type: ignore[operator] else: - result = (arr / arr.dtype.type(self.scale)) + arr.dtype.type(self.offset) + result = (arr / scale) + offset # type: ignore[operator] if result.dtype != arr.dtype: raise ValueError( f"scale_offset decode changed dtype from {arr.dtype} to {result.dtype}. 
" @@ -114,10 +127,12 @@ async def _decode_single( def _encode_sync( self, chunk_array: NDBuffer, - _chunk_spec: ArraySpec, + chunk_spec: ArraySpec, ) -> NDBuffer | None: arr = chunk_array.as_ndarray_like() - result = (arr - arr.dtype.type(self.offset)) * arr.dtype.type(self.scale) + offset = self._to_scalar(self.offset, chunk_spec.dtype) + scale = self._to_scalar(self.scale, chunk_spec.dtype) + result = (arr - offset) * scale # type: ignore[operator] if result.dtype != arr.dtype: raise ValueError( f"scale_offset encode changed dtype from {arr.dtype} to {result.dtype}. " diff --git a/tests/test_codecs/test_scale_offset.py b/tests/test_codecs/test_scale_offset.py index 47af287c1d..d943a398b5 100644 --- a/tests/test_codecs/test_scale_offset.py +++ b/tests/test_codecs/test_scale_offset.py @@ -89,10 +89,16 @@ def test_serialization_roundtrip() -> None: @pytest.mark.parametrize( "case", [ - ExpectErr(input={"offset": "bad"}, msg="offset must be a number", exception_cls=TypeError), - ExpectErr(input={"scale": [1, 2]}, msg="scale must be a number", exception_cls=TypeError), + ExpectErr( + input={"offset": [1, 2]}, + msg="offset must be a number or string", + exception_cls=TypeError, + ), + ExpectErr( + input={"scale": [1, 2]}, msg="scale must be a number or string", exception_cls=TypeError + ), ], - ids=["string-offset", "list-scale"], + ids=["list-offset", "list-scale"], ) def test_construction_rejects_non_numeric(case: ExpectErr[dict[str, Any]]) -> None: """Non-numeric offset or scale is rejected at construction time.""" @@ -201,6 +207,43 @@ def test_rejects_complex_dtype() -> None: ) +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input={"dtype": "int32", "offset": 1.5, "scale": 1}, + msg="offset value 1.5 is not representable", + exception_cls=ValueError, + ), + ExpectErr( + input={"dtype": "int32", "offset": 0, "scale": 0.5}, + msg="scale value 0.5 is not representable", + exception_cls=ValueError, + ), + ExpectErr( + input={"dtype": "int16", "offset": "NaN", 
"scale": 1}, + msg="offset value 'NaN' is not representable", + exception_cls=ValueError, + ), + ], + ids=["float-offset-for-int", "float-scale-for-int", "nan-offset-for-int"], +) +def test_rejects_unrepresentable_scale_offset(case: ExpectErr[dict[str, Any]]) -> None: + """Scale/offset values that can't be represented in the array dtype are rejected.""" + import zarr + + with pytest.raises(case.exception_cls, match=case.msg): + zarr.create_array( + store={}, + shape=(10,), + dtype=case.input["dtype"], + chunks=(10,), + filters=[ScaleOffset(offset=case.input["offset"], scale=case.input["scale"])], + compressors=None, + fill_value=0, + ) + + def test_dtype_preservation() -> None: """Integer scale/offset arithmetic preserves the array dtype via floor division.""" import zarr From 0d6e48d1018874616690cc19176059f29fa6ee3d Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 8 Apr 2026 09:15:57 +0200 Subject: [PATCH 6/8] chore: internal cleanup --- src/zarr/codecs/cast_value.py | 61 ++++++++++---------------- src/zarr/codecs/scale_offset.py | 2 +- tests/test_codecs/conftest.py | 20 +++++++++ tests/test_codecs/test_cast_value.py | 27 +++--------- tests/test_codecs/test_scale_offset.py | 20 +-------- 5 files changed, 51 insertions(+), 79 deletions(-) create mode 100644 tests/test_codecs/conftest.py diff --git a/src/zarr/codecs/cast_value.py b/src/zarr/codecs/cast_value.py index 2a3552df5e..08816763c8 100644 --- a/src/zarr/codecs/cast_value.py +++ b/src/zarr/codecs/cast_value.py @@ -44,13 +44,13 @@ class ScalarMapJSON(TypedDict): OutOfRangeMode = Literal["clamp", "wrap"] -class ScalarMap(TypedDict): +class ScalarMap(TypedDict, total=False): """ The normalized, in-memory form of a scalar map. 
""" - encode: NotRequired[Mapping[str | float | int, str | float | int]] - decode: NotRequired[Mapping[str | float | int, str | float | int]] + encode: Mapping[str | float | int, str | float | int] + decode: Mapping[str | float | int, str | float | int] def parse_scalar_map(obj: ScalarMapJSON | ScalarMap) -> ScalarMap: @@ -85,33 +85,18 @@ def parse_scalar_map(obj: ScalarMapJSON | ScalarMap) -> ScalarMap: _HAS_RUST_BACKEND = False -def _check_scalar_representable( - value: str | float, - dtype: np.dtype, # type: ignore[type-arg] - direction: str, - role: str, +def _check_representable( + value: JSON, + zdtype: ZDType[TBaseDType, TBaseScalar], + label: str, ) -> None: - """Raise ``ValueError`` if *value* cannot be represented in *dtype*.""" - fval = float(value) - is_integer_dtype = np.issubdtype(dtype, np.integer) - if np.isnan(fval) and is_integer_dtype: + """Raise ``ValueError`` if *value* cannot be parsed by *zdtype*.""" + try: + zdtype.from_json_scalar(value, zarr_format=3) + except (TypeError, ValueError, OverflowError) as e: raise ValueError( - f"scalar_map {direction} {role} {value!r} is NaN, " - f"which is not representable in integer dtype {dtype}." - ) - if is_integer_dtype: - info = np.iinfo(dtype) - ival = int(fval) - if float(ival) != fval: - raise ValueError( - f"scalar_map {direction} {role} {value!r} is not an integer, " - f"which is required for integer dtype {dtype}." - ) - if ival < info.min or ival > info.max: - raise ValueError( - f"scalar_map {direction} {role} {value!r} is out of range " - f"for dtype {dtype} [{info.min}, {info.max}]." - ) + f"{label} {value!r} is not representable in dtype {zdtype.to_native_dtype()}." 
+ ) from e # --------------------------------------------------------------------------- @@ -215,28 +200,30 @@ def validate( raise ValueError("out_of_range='wrap' is only valid for integer target types.") if self.scalar_map is not None: - self._validate_scalar_map(source_native, target_native) + self._validate_scalar_map(dtype, self.dtype) def _validate_scalar_map( self, - source_native: np.dtype, # type: ignore[type-arg] - target_native: np.dtype, # type: ignore[type-arg] + source_zdtype: ZDType[TBaseDType, TBaseScalar], + target_zdtype: ZDType[TBaseDType, TBaseScalar], ) -> None: """Validate that scalar map entries are compatible with source/target dtypes.""" assert self.scalar_map is not None # For encode: keys are source values, values are target values. # For decode: keys are target values, values are source values. - direction_dtypes = { - "encode": (source_native, target_native), - "decode": (target_native, source_native), + direction_dtypes: dict[ + str, tuple[ZDType[TBaseDType, TBaseScalar], ZDType[TBaseDType, TBaseScalar]] + ] = { + "encode": (source_zdtype, target_zdtype), + "decode": (target_zdtype, source_zdtype), } - for direction, (key_dtype, val_dtype) in direction_dtypes.items(): + for direction, (key_zdtype, val_zdtype) in direction_dtypes.items(): if direction not in self.scalar_map: continue sub_map = self.scalar_map[direction] # type: ignore[literal-required] for k, v in sub_map.items(): - _check_scalar_representable(k, key_dtype, direction, "key") - _check_scalar_representable(v, val_dtype, direction, "value") + _check_representable(k, key_zdtype, f"scalar_map {direction} key") + _check_representable(v, val_zdtype, f"scalar_map {direction} value") def _do_cast( self, diff --git a/src/zarr/codecs/scale_offset.py b/src/zarr/codecs/scale_offset.py index 38796a0b36..f1568ca9c5 100644 --- a/src/zarr/codecs/scale_offset.py +++ b/src/zarr/codecs/scale_offset.py @@ -81,7 +81,7 @@ def validate( for name, value in [("offset", self.offset), ("scale", 
self.scale)]: try: dtype.from_json_scalar(value, zarr_format=3) - except (TypeError, ValueError) as e: + except (TypeError, ValueError, OverflowError) as e: raise ValueError( f"scale_offset {name} value {value!r} is not representable in dtype {native}." ) from e diff --git a/tests/test_codecs/conftest.py b/tests/test_codecs/conftest.py new file mode 100644 index 0000000000..b654ab1ec0 --- /dev/null +++ b/tests/test_codecs/conftest.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Expect[TIn, TOut]: + """Model an input and an expected output value for a test case.""" + + input: TIn + expected: TOut + + +@dataclass(frozen=True) +class ExpectErr[TIn]: + """Model an input and an expected error message for a test case.""" + + input: TIn + msg: str + exception_cls: type[Exception] diff --git a/tests/test_codecs/test_cast_value.py b/tests/test_codecs/test_cast_value.py index c7611f0c9d..132339a6d7 100644 --- a/tests/test_codecs/test_cast_value.py +++ b/tests/test_codecs/test_cast_value.py @@ -1,11 +1,11 @@ from __future__ import annotations -from dataclasses import dataclass from typing import Any import numpy as np import pytest +from tests.test_codecs.conftest import Expect, ExpectErr from zarr.codecs.cast_value import CastValue try: @@ -20,23 +20,6 @@ ) -@dataclass(frozen=True) -class Expect[TIn, TOut]: - """Model an input and an expected output value for a test case.""" - - input: TIn - expected: TOut - - -@dataclass(frozen=True) -class ExpectErr[TIn]: - """Model an input and an expected error message for a test case.""" - - input: TIn - msg: str - exception_cls: type[Exception] - - # --------------------------------------------------------------------------- # Serialization # --------------------------------------------------------------------------- @@ -336,7 +319,7 @@ def test_scalar_map_encode_decode_roundtrip() -> None: "target": "int8", "scalar_map": {"encode": [("NaN", 0)]}, }, - 
msg="NaN.*not representable in integer dtype", + msg="not representable in dtype int32", exception_cls=ValueError, ), ExpectErr( @@ -345,7 +328,7 @@ def test_scalar_map_encode_decode_roundtrip() -> None: "target": "float64", "scalar_map": {"decode": [(0, "NaN")]}, }, - msg="NaN.*not representable in integer dtype", + msg="not representable in dtype int32", exception_cls=ValueError, ), ExpectErr( @@ -354,7 +337,7 @@ def test_scalar_map_encode_decode_roundtrip() -> None: "target": "int8", "scalar_map": {"encode": [("NaN", 999)]}, }, - msg="out of range", + msg="not representable in dtype int8", exception_cls=ValueError, ), ExpectErr( @@ -363,7 +346,7 @@ def test_scalar_map_encode_decode_roundtrip() -> None: "target": "int8", "scalar_map": {"encode": [("NaN", 1.5)]}, }, - msg="not an integer", + msg="not representable in dtype int8", exception_cls=ValueError, ), ], diff --git a/tests/test_codecs/test_scale_offset.py b/tests/test_codecs/test_scale_offset.py index d943a398b5..549db5ca43 100644 --- a/tests/test_codecs/test_scale_offset.py +++ b/tests/test_codecs/test_scale_offset.py @@ -1,31 +1,13 @@ from __future__ import annotations -from dataclasses import dataclass from typing import Any import numpy as np import pytest +from tests.test_codecs.conftest import Expect, ExpectErr from zarr.codecs.scale_offset import ScaleOffset - -@dataclass(frozen=True) -class Expect[TIn, TOut]: - """Model an input and an expected output value for a test case.""" - - input: TIn - expected: TOut - - -@dataclass(frozen=True) -class ExpectErr[TIn]: - """Model an input and an expected error message for a test case.""" - - input: TIn - msg: str - exception_cls: type[Exception] - - # --------------------------------------------------------------------------- # Serialization # --------------------------------------------------------------------------- From ee15c9ee135020e8b75fb488a5d7618d09f32d15 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 9 Apr 2026 15:54:26 +0200 Subject: 
[PATCH 7/8] fix: use explicit list of data type names

---
 src/zarr/codecs/cast_value.py        | 59 +++++++++++++++++++++++-----
 tests/test_codecs/test_cast_value.py |  7 ++--
 2 files changed, 53 insertions(+), 13 deletions(-)

diff --git a/src/zarr/codecs/cast_value.py b/src/zarr/codecs/cast_value.py
index 08816763c8..17c144acbf 100644
--- a/src/zarr/codecs/cast_value.py
+++ b/src/zarr/codecs/cast_value.py
@@ -12,7 +12,7 @@
 
 from collections.abc import Mapping
 from dataclasses import dataclass, replace
-from typing import TYPE_CHECKING, Literal, TypedDict, cast
+from typing import TYPE_CHECKING, Final, Literal, TypedDict, cast
 
 import numpy as np
 
@@ -53,6 +53,36 @@ class ScalarMap(TypedDict, total=False):
     decode: Mapping[str | float | int, str | float | int]
 
 
+PERMITTED_DATA_TYPE_NAMES: Final[set[str]] = {
+    "int2",
+    "int4",
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+    "uint2",
+    "uint4",
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+    "float4_e2m1fn",
+    "float6_e2m3fn",
+    "float6_e3m2fn",
+    "float8_e3m4",
+    "float8_e4m3",
+    "float8_e4m3b11fnuz",
+    "float8_e4m3fnuz",
+    "float8_e5m2",
+    "float8_e5m2fnuz",
+    "float8_e8m0fnu",
+    "bfloat16",
+    "float16",
+    "float32",
+    "float64",
+}
+
+
 def parse_scalar_map(obj: ScalarMapJSON | ScalarMap) -> ScalarMap:
     """
     Parse a scalar map into its normalized dict-of-dicts form.
@@ -151,6 +181,12 @@ def __init__(
             zdtype = get_data_type_from_json(data_type, zarr_format=3)
         else:
             zdtype = data_type
+        if zdtype.to_json(zarr_format=3) not in PERMITTED_DATA_TYPE_NAMES:
+            raise ValueError(
+                f"Invalid target data type {data_type!r}. "
+                f"cast_value codec only supports integer and floating-point data types. "
+                f"Got {zdtype}."
+ ) object.__setattr__(self, "dtype", zdtype) object.__setattr__(self, "rounding", rounding) object.__setattr__(self, "out_of_range", out_of_range) @@ -188,14 +224,13 @@ def validate( dtype: ZDType[TBaseDType, TBaseScalar], chunk_grid: ChunkGridMetadata, ) -> None: - source_native = dtype.to_native_dtype() - target_native = self.dtype.to_native_dtype() - for label, dt in [("source", source_native), ("target", target_native)]: - if not np.issubdtype(dt, np.integer) and not np.issubdtype(dt, np.floating): - raise ValueError( - f"The cast_value codec only supports integer and floating-point data types. " - f"Got {label} dtype {dt}." - ) + target_name = dtype.to_json(zarr_format=3) + if target_name not in PERMITTED_DATA_TYPE_NAMES: + raise ValueError( + f"The cast_value codec only supports integer and floating-point data types. " + f"Got dtype {target_name}." + ) + target_native = dtype.to_native_dtype() if self.out_of_range == "wrap" and not np.issubdtype(target_native, np.integer): raise ValueError("out_of_range='wrap' is only valid for integer target types.") @@ -318,6 +353,12 @@ async def _decode_single( return self._decode_sync(chunk_data, chunk_spec) def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int: + dtype_name = chunk_spec.dtype.to_json(zarr_format=3) + if dtype_name not in PERMITTED_DATA_TYPE_NAMES: + raise ValueError( + "cast_value codec only supports fixed-size integer and floating-point data types. " + f"Got source dtype: {chunk_spec.dtype}." 
+ ) source_itemsize = chunk_spec.dtype.to_native_dtype().itemsize target_itemsize = self.dtype.to_native_dtype().itemsize if source_itemsize == 0 or target_itemsize == 0: diff --git a/tests/test_codecs/test_cast_value.py b/tests/test_codecs/test_cast_value.py index 132339a6d7..6112c0e471 100644 --- a/tests/test_codecs/test_cast_value.py +++ b/tests/test_codecs/test_cast_value.py @@ -5,6 +5,7 @@ import numpy as np import pytest +import zarr from tests.test_codecs.conftest import Expect, ExpectErr from zarr.codecs.cast_value import CastValue @@ -119,7 +120,7 @@ def test_serialization_roundtrip(codec: CastValue) -> None: exception_cls=ValueError, ), ExpectErr( - input={"dtype": "int32", "target": "float32", "out_of_range": "wrap"}, + input={"dtype": "float32", "target": "int32", "out_of_range": "wrap"}, msg="only valid for integer", exception_cls=ValueError, ), @@ -128,8 +129,6 @@ def test_serialization_roundtrip(codec: CastValue) -> None: ) def test_validation_rejects_invalid(case: ExpectErr[dict[str, Any]]) -> None: """Invalid dtype or out_of_range combinations are rejected at array creation.""" - import zarr - with pytest.raises(case.exception_cls, match=case.msg): zarr.create_array( store={}, @@ -161,7 +160,7 @@ def test_zero_itemsize_raises() -> None: config=ArrayConfig(order="C", write_empty_chunks=True), prototype=default_buffer_prototype(), ) - with pytest.raises(ValueError, match="fixed-size data types"): + with pytest.raises(ValueError, match="fixed-size integer and floating-point data types"): codec.compute_encoded_size(100, spec) From d1cee7366d9b8ee70f19bed86575c49fbcc13fea Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 9 Apr 2026 15:55:30 +0200 Subject: [PATCH 8/8] docs: add comment --- src/zarr/codecs/cast_value.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/zarr/codecs/cast_value.py b/src/zarr/codecs/cast_value.py index 17c144acbf..da4333a9e7 100644 --- a/src/zarr/codecs/cast_value.py +++ b/src/zarr/codecs/cast_value.py @@ -53,6 
+53,7 @@ class ScalarMap(TypedDict, total=False): decode: Mapping[str | float | int, str | float | int] +# see https://github.com/zarr-developers/zarr-extensions/tree/main/codecs/cast_value PERMITTED_DATA_TYPE_NAMES: Final[set[str]] = { "int2", "int4",