From 74f6c7c515408705d04e8e4a7b5216656aff43ee Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 18 Mar 2026 11:21:05 +1100 Subject: [PATCH 1/7] implement numpy.join unjoin ops and add tests claude used to discover itertools.accumulate for prefix sum calculations and implement Stack.unjoin --- .../pipeline/operations/numpy/join.py | 29 ++--- .../tests/operations/numpy/test_numpy_join.py | 102 ++++++++++++++++++ 2 files changed, 118 insertions(+), 13 deletions(-) create mode 100644 packages/pipeline/tests/operations/numpy/test_numpy_join.py diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/join.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/join.py index 72e519a8..13cf670b 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/join.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/join.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from itertools import accumulate from typing import Optional, Any import numpy as np @@ -23,8 +23,6 @@ class Stack(Joiner): """ Stack a tuple of np.ndarray's - - Currently cannot undo this operation """ _override_interface = ["Delayed", "Serial"] @@ -40,14 +38,16 @@ def join(self, sample: tuple[Any, ...]) -> np.ndarray: return np.stack(sample, self.axis) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + """Unstacks a stacked sample""" + # np.stack(..., axis=None) is equivalent to np.stack(..., axis=0) + axis = self.axis if self.axis is not None else 0 + # move the stacked axis to the zeroth axis and convert to tuple + return tuple(np.moveaxis(sample, axis, 0)) class VStack(Joiner): """ Vertically Stack a tuple of np.ndarray's - - Currently cannot undo this operation """ _override_interface = ["Delayed", "Serial"] @@ -56,22 +56,23 @@ class VStack(Joiner): def __init__(self): super().__init__() self.record_initialisation() + self.offsets = None # stores the vertical offset where each joined array is def join(self, sample: tuple[Any, ...]) -> np.ndarray: """Join sample""" + # stores + self.offsets = tuple(accumulate(arr.shape[0] for arr in sample[:-1])) return np.vstack( sample, ) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + return tuple(np.vsplit(sample, self.offsets)) class HStack(Joiner): """ Horizontally Stack a tuple of np.ndarray's - - Currently cannot undo this operation """ _override_interface = ["Delayed", "Serial"] @@ -80,22 +81,22 @@ class HStack(Joiner): def __init__(self): super().__init__() self.record_initialisation() + self.offsets = None def join(self, sample: tuple[Any, ...]) -> np.ndarray: """Join sample""" + self.offsets = tuple(accumulate(arr.shape[1] for arr in sample[:-1])) return np.hstack( sample, ) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + return tuple(np.hsplit(sample, self.offsets)) class Concatenate(Joiner): """ Concatenate a tuple of np.ndarray's - - Currently cannot undo this operation """ _override_interface = ["Delayed", "Serial"] @@ -105,10 +106,12 @@ def __init__(self, axis: Optional[int] = None): super().__init__() self.record_initialisation() self.axis = axis + self.offsets = None def join(self, sample: tuple[Any, ...]) -> np.ndarray: """Join sample""" + self.offsets = tuple(accumulate(arr.shape[self.axis] for arr in sample[:-1])) return np.concatenate(sample, self.axis) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + return tuple(np.split(sample, self.offsets, axis=self.axis)) diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_join.py b/packages/pipeline/tests/operations/numpy/test_numpy_join.py new file mode 100644 index 00000000..9be155bf --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_join.py @@ -0,0 +1,102 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.numpy import join + +import numpy as np +import pytest + + +def test_stacks(): + """Tests that join.Stack reproduces np.stack behaviour.""" + + numpy_arrays = ( + np.array(range(6)).reshape((2, 3)), + np.array(range(6, 12)).reshape((2, 3)), + ) + + for stack_axis in range(3): + stack = join.Stack(axis=stack_axis) + result = stack.join(numpy_arrays) + expected = np.stack(numpy_arrays, axis=stack_axis) + assert np.array_equal(result, expected), f"Stack(axis={stack_axis}).join() did not reproduce np.stack" + unjoined_result = stack.unjoin(result) + assert isinstance( + unjoined_result, tuple + ), f"Stack(axis={stack_axis}).unjoin() did not unjoin the input sample into tuples." + for arr_undo, arr in zip(unjoined_result, numpy_arrays, strict=True): + assert np.array_equal(arr_undo, arr), f"Stack(axis={stack_axis}).unjoin() did not return original arrays" + + +@pytest.fixture +def concat_array_data(): + return ( + np.array(range(6)), + np.array(range(6, 18)), + ) + + +@pytest.mark.parametrize( + ("joiner", "equiv_np_op", "input_shapes"), + ( + (join.VStack, np.vstack, ((1, 3, 2), (2, 3, 2))), + (join.HStack, np.hstack, ((3, 1, 2), (3, 2, 2))), + ), +) +def test_vstack(joiner, equiv_np_op, input_shapes, concat_array_data): + """Tests that join.XStack reproduces np.xstack behaviour.""" + + input_arrays = tuple(arr.reshape(shape) for arr, shape in zip(concat_array_data, input_shapes)) + + stack = joiner() + result = stack.join(input_arrays) + expected = equiv_np_op(input_arrays) + assert np.array_equal( + result, expected + ), f"{joiner.__name__}.join() did not reproduce {equiv_np_op.__name__} behaviour." + unjoined_result = stack.unjoin(result) + assert isinstance( + unjoined_result, tuple + ), f"{joiner.__name__}.unjoin() did not unjoin the input sample into tuples." + for arr_undo, arr in zip(unjoined_result, input_arrays, strict=True): + assert np.array_equal(arr_undo, arr), f"{joiner.__name__}.unjoin() did not return original arrays." + + +@pytest.mark.parametrize( + ("concat_axis", "input_shapes"), + ( + (0, ((1, 3, 2), (2, 3, 2))), + (1, ((3, 1, 2), (3, 2, 2))), + (2, ((3, 2, 1), (3, 2, 2))), + ), +) +def test_concatenate(concat_axis, input_shapes, concat_array_data): + """Tests that join.Concatenate reproduces np.concatenate behaviour.""" + + input_arrays = tuple(arr.reshape(shape) for arr, shape in zip(concat_array_data, input_shapes)) + + stack = join.Concatenate(axis=concat_axis) + result = stack.join(input_arrays) + expected = np.concatenate(input_arrays, axis=concat_axis) + assert np.array_equal( + result, expected + ), f"Concatenate(axis={concat_axis}) did not reproduce np.concatenate behaviour." + unjoined_result = stack.unjoin(result) + assert isinstance( + unjoined_result, tuple + ), f"Concatenate(axis={concat_axis}).unjoin() did not unjoin the input sample into tuples." + for arr_undo, arr in zip(unjoined_result, input_arrays, strict=True): + assert np.array_equal( + arr_undo, arr, f"Concatenate(axis={concat_axis}).unjoin() did not return original arrays." + ) From 8722b50bff9ce80d90d00f210a5072d0aa676fcb Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 18 Mar 2026 12:31:21 +1100 Subject: [PATCH 2/7] implement dask unjoin and add tests --- .../pipeline/operations/dask/join.py | 33 ++++-- .../tests/operations/dask/test_dask_join.py | 111 ++++++++++++++++++ 2 files changed, 131 insertions(+), 13 deletions(-) create mode 100644 packages/pipeline/tests/operations/dask/test_dask_join.py diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/join.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/join.py index 0045efde..dae15c0c 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/join.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/join.py @@ -15,6 +15,7 @@ # type: ignore[reportPrivateImportUsage] +from itertools import accumulate from typing import Optional, Any import dask.array as da @@ -26,14 +27,12 @@ class Stack(Joiner, DaskOperation): """ Stack a tuple of da.Array's - - Currently cannot undo this operation """ _override_interface = ["Serial"] _numpy_counterpart = "join.Stack" - def __init__(self, axis: Optional[int] = None): + def __init__(self, axis: Optional[int] = 0): super().__init__() self.record_initialisation() self.axis = axis @@ -43,14 +42,14 @@ def join(self, sample: tuple[Any, ...]) -> da.Array: return da.stack(sample, self.axis) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + """Unstacks a stacked sample""" + # move the stacked axis to the zeroth axis and convert to tuple + return tuple(da.moveaxis(sample, self.axis, 0)) class VStack(Joiner, DaskOperation): """ Vertically Stack a tuple of da.Array's - - Currently cannot undo this operation """ _override_interface = ["Serial"] @@ -59,22 +58,24 @@ class VStack(Joiner, DaskOperation): def __init__(self): super().__init__() self.record_initialisation() + self.offsets = None def join(self, sample: tuple[Any, ...]) -> da.Array: """Join sample""" + self.offsets = tuple(accumulate(arr.shape[0] for arr in sample[:-1])) return da.vstack( sample, ) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + start = (0,) + self.offsets + ends = self.offsets + (sample.shape[0],) + return tuple(sample[start:end] for start, end in zip(start, ends, strict=True)) class HStack(Joiner, DaskOperation): """ Horizontally Stack a tuple of da.Array's - - Currently cannot undo this operation """ _override_interface = ["Serial"] @@ -83,22 +84,24 @@ class HStack(Joiner, DaskOperation): def __init__(self): super().__init__() self.record_initialisation() + self.offsets = None def join(self, sample: tuple[Any, ...]) -> da.Array: """Join sample""" + self.offsets = tuple(accumulate(arr.shape[1] for arr in sample[:-1])) return da.hstack( sample, ) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + start = (0,) + self.offsets + ends = self.offsets + (sample.shape[1],) + return tuple(sample[:, start:end, ...] for start, end in zip(start, ends, strict=True)) class Concatenate(Joiner, DaskOperation): """ Concatenate a tuple of da.Array's - - Currently cannot undo this operation """ _override_interface = ["Serial"] @@ -108,10 +111,14 @@ def __init__(self, axis: Optional[int] = None): super().__init__() self.record_initialisation() self.axis = axis + self.offsets = None def join(self, sample: tuple[Any, ...]) -> da.Array: """Join sample""" + self.offsets = tuple(accumulate(arr.shape[self.axis] for arr in sample[:-1])) return da.concatenate(sample, self.axis) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + start = (0,) + self.offsets + ends = self.offsets + (sample.shape[self.axis],) + return tuple(da.take(sample, slice(start, end), self.axis) for start, end in zip(start, ends, strict=True)) diff --git a/packages/pipeline/tests/operations/dask/test_dask_join.py b/packages/pipeline/tests/operations/dask/test_dask_join.py new file mode 100644 index 00000000..82390322 --- /dev/null +++ b/packages/pipeline/tests/operations/dask/test_dask_join.py @@ -0,0 +1,111 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.dask import join + +import dask.array as da +import numpy as np +import pytest + + +def test_stacks(): + """Tests that join.Stack reproduces da.stack behaviour.""" + + numpy_arrays = ( + da.array(range(6)).reshape((2, 3)), + da.array(range(6, 12)).reshape((2, 3)), + ) + + for stack_axis in range(3): + stack = join.Stack(axis=stack_axis) + result = stack.join(numpy_arrays) + expected = da.stack(numpy_arrays, axis=stack_axis) + assert np.array_equal( + result.compute(), expected.compute() + ), f"Stack(axis={stack_axis}).join() did not reproduce da.stack" + unjoined_result = stack.unjoin(result) + assert isinstance( + unjoined_result, tuple + ), f"Stack(axis={stack_axis}).unjoin() did not unjoin the input sample into tuples." + for arr_undo, arr in zip(unjoined_result, numpy_arrays, strict=True): + assert np.array_equal( + arr_undo.compute(), arr.compute() + ), f"Stack(axis={stack_axis}).unjoin() did not return original arrays" + + +@pytest.fixture +def concat_array_data(): + return ( + da.array(range(6)), + da.array(range(6, 18)), + ) + + +@pytest.mark.parametrize( + ("joiner", "equiv_np_op", "input_shapes"), + ( + (join.VStack, da.vstack, ((1, 3, 2), (2, 3, 2))), + (join.HStack, da.hstack, ((3, 1, 2), (3, 2, 2))), + ), +) +def test_vstack(joiner, equiv_np_op, input_shapes, concat_array_data): + """Tests that join.XStack reproduces da.xstack behaviour.""" + + input_arrays = tuple(arr.reshape(shape) for arr, shape in zip(concat_array_data, input_shapes)) + + stack = joiner() + result = stack.join(input_arrays) + expected = equiv_np_op(input_arrays) + assert np.array_equal( + result.compute(), expected.compute() + ), f"{joiner.__name__}.join() did not reproduce {equiv_np_op.__name__} behaviour." + unjoined_result = stack.unjoin(result) + assert isinstance( + unjoined_result, tuple + ), f"{joiner.__name__}.unjoin() did not unjoin the input sample into tuples." + for arr_undo, arr in zip(unjoined_result, input_arrays, strict=True): + assert np.array_equal( + arr_undo.compute(), arr.compute() + ), f"{joiner.__name__}.unjoin() did not return original arrays." + + +@pytest.mark.parametrize( + ("concat_axis", "input_shapes"), + ( + (0, ((1, 3, 2), (2, 3, 2))), + (1, ((3, 1, 2), (3, 2, 2))), + (2, ((3, 2, 1), (3, 2, 2))), + ), +) +def test_concatenate(concat_axis, input_shapes, concat_array_data): + """Tests that join.Concatenate reproduces da.concatenate behaviour.""" + + input_arrays = tuple(arr.reshape(shape) for arr, shape in zip(concat_array_data, input_shapes)) + + stack = join.Concatenate(axis=concat_axis) + result = stack.join(input_arrays) + expected = da.concatenate(input_arrays, axis=concat_axis) + assert np.array_equal( + result.compute(), expected.compute() + ), f"Concatenate(axis={concat_axis}) did not reproduce da.concatenate behaviour." + unjoined_result = stack.unjoin(result) + assert isinstance( + unjoined_result, tuple + ), f"Concatenate(axis={concat_axis}).unjoin() did not unjoin the input sample into tuples." + for arr_undo, arr in zip(unjoined_result, input_arrays, strict=True): + assert np.array_equal( + arr_undo.compute(), + arr.compute(), + f"Concatenate(axis={concat_axis}).unjoin() did not return original arrays.", + ) From 3bb85b1073b25071d2e4d56e125f386e65da12f7 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 18 Mar 2026 12:45:59 +1100 Subject: [PATCH 3/7] use claude to unify numpy/dask join test funcs --- .../tests/operations/dask/test_dask_join.py | 131 ++++++------------ .../tests/operations/numpy/test_numpy_join.py | 123 ++++++---------- 2 files changed, 88 insertions(+), 166 deletions(-) diff --git a/packages/pipeline/tests/operations/dask/test_dask_join.py b/packages/pipeline/tests/operations/dask/test_dask_join.py index 82390322..79bcd7e3 100644 --- a/packages/pipeline/tests/operations/dask/test_dask_join.py +++ b/packages/pipeline/tests/operations/dask/test_dask_join.py @@ -12,100 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pyearthtools.pipeline.operations.dask import join +from functools import partial + +from pyearthtools.pipeline.operations.dask.join import Stack, VStack, HStack, Concatenate import dask.array as da import numpy as np import pytest -def test_stacks(): - """Tests that join.Stack reproduces da.stack behaviour.""" - - numpy_arrays = ( - da.array(range(6)).reshape((2, 3)), - da.array(range(6, 12)).reshape((2, 3)), - ) - - for stack_axis in range(3): - stack = join.Stack(axis=stack_axis) - result = stack.join(numpy_arrays) - expected = da.stack(numpy_arrays, axis=stack_axis) - assert np.array_equal( - result.compute(), expected.compute() - ), f"Stack(axis={stack_axis}).join() did not reproduce da.stack" - unjoined_result = stack.unjoin(result) - assert isinstance( - unjoined_result, tuple - ), f"Stack(axis={stack_axis}).unjoin() did not unjoin the input sample into tuples." - for arr_undo, arr in zip(unjoined_result, numpy_arrays, strict=True): - assert np.array_equal( - arr_undo.compute(), arr.compute() - ), f"Stack(axis={stack_axis}).unjoin() did not return original arrays" - - -@pytest.fixture -def concat_array_data(): - return ( - da.array(range(6)), - da.array(range(6, 18)), - ) +def _arrays(*shapes): + """Create dask arrays with given shapes whose elements are sequential integers.""" + offset = 0 + result = [] + for shape in shapes: + size = int(np.prod(shape)) + result.append(da.array(range(offset, offset + size)).reshape(shape)) + offset += size # offset ensures next array has different contents + return tuple(result) +# this parameterizations passes in the joiner class to test, with an appropriate axis as needed. +# It compares the joined result to an equivalent dask function, partially initialised with axis as needed. +# The shape of the input array passed to the test is adjusted based on the joiner. @pytest.mark.parametrize( - ("joiner", "equiv_np_op", "input_shapes"), - ( - (join.VStack, da.vstack, ((1, 3, 2), (2, 3, 2))), - (join.HStack, da.hstack, ((3, 1, 2), (3, 2, 2))), - ), + ("joiner", "equiv_op", "input_arrays"), + [ + pytest.param(Stack(axis=0), partial(da.stack, axis=0), _arrays((2, 3), (2, 3)), id="Stack-axis0"), + pytest.param(Stack(axis=1), partial(da.stack, axis=1), _arrays((2, 3), (2, 3)), id="Stack-axis1"), + pytest.param(Stack(axis=2), partial(da.stack, axis=2), _arrays((2, 3), (2, 3)), id="Stack-axis2"), + pytest.param(VStack(), da.vstack, _arrays((1, 3, 2), (2, 3, 2)), id="VStack"), + pytest.param(HStack(), da.hstack, _arrays((3, 1, 2), (3, 2, 2)), id="HStack"), + pytest.param( + Concatenate(axis=0), partial(da.concatenate, axis=0), _arrays((1, 3, 2), (2, 3, 2)), id="Concatenate-axis0" + ), + pytest.param( + Concatenate(axis=1), partial(da.concatenate, axis=1), _arrays((3, 1, 2), (3, 2, 2)), id="Concatenate-axis1" + ), + pytest.param( + Concatenate(axis=2), partial(da.concatenate, axis=2), _arrays((3, 2, 1), (3, 2, 2)), id="Concatenate-axis2" + ), + ], ) -def test_vstack(joiner, equiv_np_op, input_shapes, concat_array_data): - """Tests that join.XStack reproduces da.xstack behaviour.""" - - input_arrays = tuple(arr.reshape(shape) for arr, shape in zip(concat_array_data, input_shapes)) - - stack = joiner() - result = stack.join(input_arrays) - expected = equiv_np_op(input_arrays) - assert np.array_equal( - result.compute(), expected.compute() - ), f"{joiner.__name__}.join() did not reproduce {equiv_np_op.__name__} behaviour." - unjoined_result = stack.unjoin(result) - assert isinstance( - unjoined_result, tuple - ), f"{joiner.__name__}.unjoin() did not unjoin the input sample into tuples." - for arr_undo, arr in zip(unjoined_result, input_arrays, strict=True): - assert np.array_equal( - arr_undo.compute(), arr.compute() - ), f"{joiner.__name__}.unjoin() did not return original arrays." - - -@pytest.mark.parametrize( - ("concat_axis", "input_shapes"), - ( - (0, ((1, 3, 2), (2, 3, 2))), - (1, ((3, 1, 2), (3, 2, 2))), - (2, ((3, 2, 1), (3, 2, 2))), - ), -) -def test_concatenate(concat_axis, input_shapes, concat_array_data): - """Tests that join.Concatenate reproduces da.concatenate behaviour.""" - - input_arrays = tuple(arr.reshape(shape) for arr, shape in zip(concat_array_data, input_shapes)) - - stack = join.Concatenate(axis=concat_axis) - result = stack.join(input_arrays) - expected = da.concatenate(input_arrays, axis=concat_axis) - assert np.array_equal( - result.compute(), expected.compute() - ), f"Concatenate(axis={concat_axis}) did not reproduce da.concatenate behaviour." - unjoined_result = stack.unjoin(result) - assert isinstance( - unjoined_result, tuple - ), f"Concatenate(axis={concat_axis}).unjoin() did not unjoin the input sample into tuples." - for arr_undo, arr in zip(unjoined_result, input_arrays, strict=True): - assert np.array_equal( - arr_undo.compute(), - arr.compute(), - f"Concatenate(axis={concat_axis}).unjoin() did not return original arrays.", - ) +def test_join(joiner, equiv_op, input_arrays): + """Tests that joiners reproduce their dask equivalents and are reversible.""" + name = type(joiner).__name__ + + result = joiner.join(input_arrays) + expected = equiv_op(input_arrays) + assert np.array_equal(result.compute(), expected.compute()), f"{name}.join() did not reproduce expected behaviour." + + unjoined = joiner.unjoin(result) + assert isinstance(unjoined, tuple), f"{name}.unjoin() did not return a tuple." + for arr_undo, arr in zip(unjoined, input_arrays, strict=True): + assert np.array_equal(arr_undo.compute(), arr.compute()), f"{name}.unjoin() did not return original arrays." diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_join.py b/packages/pipeline/tests/operations/numpy/test_numpy_join.py index 9be155bf..9ebc684c 100644 --- a/packages/pipeline/tests/operations/numpy/test_numpy_join.py +++ b/packages/pipeline/tests/operations/numpy/test_numpy_join.py @@ -12,91 +12,56 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pyearthtools.pipeline.operations.numpy import join +from functools import partial + +from pyearthtools.pipeline.operations.numpy.join import Stack, VStack, HStack, Concatenate import numpy as np import pytest -def test_stacks(): - """Tests that join.Stack reproduces np.stack behaviour.""" - - numpy_arrays = ( - np.array(range(6)).reshape((2, 3)), - np.array(range(6, 12)).reshape((2, 3)), - ) - - for stack_axis in range(3): - stack = join.Stack(axis=stack_axis) - result = stack.join(numpy_arrays) - expected = np.stack(numpy_arrays, axis=stack_axis) - assert np.array_equal(result, expected), f"Stack(axis={stack_axis}).join() did not reproduce np.stack" - unjoined_result = stack.unjoin(result) - assert isinstance( - unjoined_result, tuple - ), f"Stack(axis={stack_axis}).unjoin() did not unjoin the input sample into tuples." - for arr_undo, arr in zip(unjoined_result, numpy_arrays, strict=True): - assert np.array_equal(arr_undo, arr), f"Stack(axis={stack_axis}).unjoin() did not return original arrays" - - -@pytest.fixture -def concat_array_data(): - return ( - np.array(range(6)), - np.array(range(6, 18)), - ) +def _arrays(*shapes): + """Create numpy arrays with given shapes whose elements are sequential integers.""" + offset = 0 + result = [] + for shape in shapes: + size = int(np.prod(shape)) + result.append(np.arange(offset, offset + size).reshape(shape)) + offset += size + return tuple(result) +# this parameterizations passes in the joiner class to test, with an appropriate axis as needed. +# It compares the joined result to an equivalent numpy function, partially initialised with axis as needed. +# The shape of the input array passed to the test is adjusted based on the joiner. @pytest.mark.parametrize( - ("joiner", "equiv_np_op", "input_shapes"), - ( - (join.VStack, np.vstack, ((1, 3, 2), (2, 3, 2))), - (join.HStack, np.hstack, ((3, 1, 2), (3, 2, 2))), - ), + ("joiner", "equiv_op", "input_arrays"), + [ + pytest.param(Stack(axis=0), partial(np.stack, axis=0), _arrays((2, 3), (2, 3)), id="Stack-axis0"), + pytest.param(Stack(axis=1), partial(np.stack, axis=1), _arrays((2, 3), (2, 3)), id="Stack-axis1"), + pytest.param(Stack(axis=2), partial(np.stack, axis=2), _arrays((2, 3), (2, 3)), id="Stack-axis2"), + pytest.param(VStack(), np.vstack, _arrays((1, 3, 2), (2, 3, 2)), id="VStack"), + pytest.param(HStack(), np.hstack, _arrays((3, 1, 2), (3, 2, 2)), id="HStack"), + pytest.param( + Concatenate(axis=0), partial(np.concatenate, axis=0), _arrays((1, 3, 2), (2, 3, 2)), id="Concatenate-axis0" + ), + pytest.param( + Concatenate(axis=1), partial(np.concatenate, axis=1), _arrays((3, 1, 2), (3, 2, 2)), id="Concatenate-axis1" + ), + pytest.param( + Concatenate(axis=2), partial(np.concatenate, axis=2), _arrays((3, 2, 1), (3, 2, 2)), id="Concatenate-axis2" + ), + ], ) -def test_vstack(joiner, equiv_np_op, input_shapes, concat_array_data): - """Tests that join.XStack reproduces np.xstack behaviour.""" - - input_arrays = tuple(arr.reshape(shape) for arr, shape in zip(concat_array_data, input_shapes)) - - stack = joiner() - result = stack.join(input_arrays) - expected = equiv_np_op(input_arrays) - assert np.array_equal( - result, expected - ), f"{joiner.__name__}.join() did not reproduce {equiv_np_op.__name__} behaviour." - unjoined_result = stack.unjoin(result) - assert isinstance( - unjoined_result, tuple - ), f"{joiner.__name__}.unjoin() did not unjoin the input sample into tuples." - for arr_undo, arr in zip(unjoined_result, input_arrays, strict=True): - assert np.array_equal(arr_undo, arr), f"{joiner.__name__}.unjoin() did not return original arrays." - - -@pytest.mark.parametrize( - ("concat_axis", "input_shapes"), - ( - (0, ((1, 3, 2), (2, 3, 2))), - (1, ((3, 1, 2), (3, 2, 2))), - (2, ((3, 2, 1), (3, 2, 2))), - ), -) -def test_concatenate(concat_axis, input_shapes, concat_array_data): - """Tests that join.Concatenate reproduces np.concatenate behaviour.""" - - input_arrays = tuple(arr.reshape(shape) for arr, shape in zip(concat_array_data, input_shapes)) - - stack = join.Concatenate(axis=concat_axis) - result = stack.join(input_arrays) - expected = np.concatenate(input_arrays, axis=concat_axis) - assert np.array_equal( - result, expected - ), f"Concatenate(axis={concat_axis}) did not reproduce np.concatenate behaviour." - unjoined_result = stack.unjoin(result) - assert isinstance( - unjoined_result, tuple - ), f"Concatenate(axis={concat_axis}).unjoin() did not unjoin the input sample into tuples." - for arr_undo, arr in zip(unjoined_result, input_arrays, strict=True): - assert np.array_equal( - arr_undo, arr, f"Concatenate(axis={concat_axis}).unjoin() did not return original arrays." - ) +def test_join(joiner, equiv_op, input_arrays): + """Tests that joiners reproduce their numpy equivalents and are reversible.""" + name = type(joiner).__name__ + + result = joiner.join(input_arrays) + expected = equiv_op(input_arrays) + assert np.array_equal(result, expected), f"{name}.join() did not reproduce expected behaviour." + + unjoined = joiner.unjoin(result) + assert isinstance(unjoined, tuple), f"{name}.unjoin() did not return a tuple." + for arr_undo, arr in zip(unjoined, input_arrays, strict=True): + assert np.array_equal(arr_undo, arr), f"{name}.unjoin() did not return original arrays." From da8f6511a6c2c5ee6a13474f8f2cfaa6704e8f7d Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 18 Mar 2026 14:17:55 +1100 Subject: [PATCH 4/7] implement xarray merge unjoin implemented only for dataarray and dataset inputs, despite merge accepting dicts too. The implementation should preserve attributes. --- .../pipeline/operations/xarray/join.py | 19 +++- .../operations/xarray/test_xarray_join.py | 91 +++++++++++++++++++ 2 files changed, 107 insertions(+), 3 deletions(-) create mode 100644 packages/pipeline/tests/operations/xarray/test_xarray_join.py diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py index c9c17623..a4a9f68f 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py @@ -27,7 +27,7 @@ class Merge(Joiner): """ Merge a tuple of xarray object's. - Currently cannot undo this operation + Currently can only undo this operation with xr.Dataset and xr.DataArray inputs. """ _override_interface = "Serial" @@ -36,13 +36,26 @@ def __init__(self, merge_kwargs: Optional[dict[str, Any]] = None): super().__init__() self.record_initialisation() self._merge_kwargs = merge_kwargs + self._input_structure: list[tuple[Union[str, list[str]], dict]] = [] def join(self, sample: tuple[Union[xr.Dataset, xr.DataArray], ...]) -> xr.Dataset: """Join sample""" + self._input_structure = [ + (item.name, item.attrs) if isinstance(item, xr.DataArray) else (list(item.data_vars), item.attrs) + for item in sample + ] return xr.merge(sample, **(self._merge_kwargs or {})) - def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + def unjoin(self, sample: xr.Dataset) -> tuple: + result = [] + for keys, attrs in self._input_structure: + if isinstance(keys, str): + da = sample[keys] + da.attrs = attrs + result.append(da) + else: + result.append(xr.Dataset({k: sample[k] for k in keys}, attrs=attrs)) + return tuple(result) class LatLonInterpolate(Joiner): diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_join.py b/packages/pipeline/tests/operations/xarray/test_xarray_join.py new file mode 100644 index 00000000..d472d7ed --- /dev/null +++ b/packages/pipeline/tests/operations/xarray/test_xarray_join.py @@ -0,0 +1,91 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.xarray.join import Merge + +import numpy as np +import xarray as xr + + +def test_merge(): + coords = {"x": [1, 2, 3], "y": [4, 5, 6]} + + da = xr.DataArray( + np.arange(9).reshape(3, 3), + coords=coords, + dims=["x", "y"], + name="alpha", + attrs={"source": "model", "units": "K"}, + ) + + ds = xr.Dataset( + { + "beta": xr.DataArray(np.arange(9, 18).reshape(3, 3), coords=coords, dims=["x", "y"]), + "gamma": xr.DataArray(np.arange(18, 27).reshape(3, 3), coords=coords, dims=["x", "y"]), + }, + attrs={"source": "model", "resolution": "1deg"}, # "source" overlaps with da + ) + + sample = (da, ds) + + joiner = Merge() + + result = joiner.join(sample) + + assert result["alpha"].equals(da), "Merge.join didn't merge objects correctly." + assert result["beta"].equals(ds["beta"]), "Merge.join didn't merge objects correctly." + assert result["gamma"].equals(ds["gamma"]), "Merge.join didn't merge objects correctly." + assert result.attrs == da.attrs, "Merge.join result didn't preserve first object's attributes" + assert result.attrs != ds.attrs, "Merge.join didn't discard second object's attributes" + + unjoined = joiner.unjoin(result) + + assert isinstance(unjoined, tuple), "Merge.unjoin didn't result in a tuple." + for d_undo, d_orig in zip(unjoined, sample, strict=True): + assert isinstance(d_undo, type(d_orig)) + assert d_undo.equals(d_orig), "Merge.unjoin didn't restore objects." + assert d_undo.attrs == d_orig.attrs, "Merge.unjoin didn't preserve attributes." + + # test passing kwargs to xr.merge + # should combine attributes + joiner = Merge(merge_kwargs={"combine_attrs": "no_conflicts"}) + + result = joiner.join(sample) + + assert result["alpha"].equals(da), 'passing combine_attrs="no_conflict" to Merge didn\'t merge object correctly.' + assert result["beta"].equals( + ds["beta"] + ), 'passing combine_attrs="no_conflict" to Merge didn\'t merge object correctly.' + assert result["gamma"].equals( + ds["gamma"] + ), 'passing combine_attrs="no_conflict" to Merge didn\'t merge object correctly.' + assert result.attrs == ( + da.attrs | ds.attrs + ), 'passing combine_attrs="no_conflict" to Merge didn\'t unionise attributes.' + + unjoined = joiner.unjoin(result) + + assert isinstance( + unjoined, tuple + ), 'passing combine_attrs="no_conflict" to Merge didn\'t result in a tuple when unjoining.' + for d_undo, d_orig in zip(unjoined, sample, strict=True): + assert isinstance( + d_undo, type(d_orig) + ), "passing combine_attrs=\"no_conflict\" to Merge didn't preserve object's type when unjoining." + assert d_undo.equals( + d_orig + ), 'passing combine_attrs="no_conflict" to Merge didn\'t restore object when unjoining.' + assert ( + d_undo.attrs == d_orig.attrs + ), 'passing combine_attrs="no_conflict" to Merge didn\'t preserve attributes when unjoining.' From e4be8c8994de797ff940afbe63dd38cc77842458 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 18 Mar 2026 15:42:07 +1100 Subject: [PATCH 5/7] add tests for xarray LatLonInterpolate claude used to reduce redundant test code. also moved input checking to class init instead of in join method --- .../pipeline/operations/xarray/join.py | 9 ++- .../operations/xarray/test_xarray_join.py | 67 ++++++++++++++++++- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py index a4a9f68f..1fc68c7c 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py @@ -81,9 +81,14 @@ def __init__( ): super().__init__() - self.raise_if_dimensions_wrong(reference_dataset) - self.record_initialisation() + + if reference_dataset is None and reference_index is None: + raise ValueError("No reference dataset or reference index set") + elif reference_dataset is not None and reference_index is not None: + raise ValueError("Only one of reference_dataset or reference_index should be set") + elif reference_dataset: + self.raise_if_dimensions_wrong(reference_dataset) self.reference_dataset = reference_dataset self.reference_index = reference_index self.interpolation_method = interpolation_method diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_join.py b/packages/pipeline/tests/operations/xarray/test_xarray_join.py index d472d7ed..8f0d740e 100644 --- a/packages/pipeline/tests/operations/xarray/test_xarray_join.py +++ b/packages/pipeline/tests/operations/xarray/test_xarray_join.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pyearthtools.pipeline.operations.xarray.join import Merge +from pyearthtools.pipeline.operations.xarray.join import Merge, LatLonInterpolate import numpy as np import xarray as xr +import pytest + def test_merge(): coords = {"x": [1, 2, 3], "y": [4, 5, 6]} @@ -89,3 +91,66 @@ def test_merge(): assert ( d_undo.attrs == d_orig.attrs ), 'passing combine_attrs="no_conflict" to Merge didn\'t preserve attributes when unjoining.' + + +def _make_ds(var_name, data, lat, lon, lat_name="latitude", lon_name="longitude"): + """Create a Dataset with lat/lon coords.""" + return xr.Dataset( + {var_name: xr.DataArray(data, dims=[lat_name, lon_name])}, + coords={lat_name: lat, lon_name: lon}, + ) + + +EXPECTED_INTERPOLATED = np.array([[9.0, 9.0, 10.0], [9.0, 9.0, 10.0], [11.0, 11.0, 12.0]]) + + +@pytest.mark.parametrize( + ("lat_name", "lon_name", "joiner_factory"), + [ + pytest.param( + "latitude", "longitude", lambda ds_ref: LatLonInterpolate(reference_index=0), id="reference_index" + ), + pytest.param("lat", "lon", lambda ds_ref: LatLonInterpolate(reference_dataset=ds_ref), id="reference_dataset"), + ], +) +def test_latlon_interpolate_join(lat_name, lon_name, joiner_factory): + """Tests that LatLonInterpolate merges and interpolates datasets to the reference grid.""" + ds_ref = _make_ds("ds1", np.arange(9).reshape(3, 3), [0.0, 1.0, 2.0], [0.0, 1.0, 2.0], lat_name, lon_name) + ds_coarse = _make_ds("ds2", np.arange(9, 13).reshape(2, 2), [-0.25, 2.25], [-0.25, 2.25], lat_name, lon_name) + + result = joiner_factory(ds_ref).join((ds_ref, ds_coarse)) + + assert "ds1" in result.data_vars + assert "ds2" in result.data_vars + # astype is needed because interp changes datatype + assert ds_ref["ds1"].equals(result["ds1"].astype(int)) + assert ds_ref.coords.equals(result["ds2"].coords) + assert np.array_equal(result["ds2"].values, EXPECTED_INTERPOLATED) + + +def test_latlon_interpolate_errors(): + """Tests that LatLonInterpolate raises errors for invalid configurations.""" + ds_ref = _make_ds("ds1", np.arange(9).reshape(3, 3), [0.0, 1.0, 2.0], [0.0, 1.0, 2.0]) + ds_coarse = _make_ds("ds2", np.arange(9, 13).reshape(2, 2), [-0.25, 2.25], [-0.25, 2.25]) + + with pytest.raises(ValueError): + LatLonInterpolate() + + with pytest.raises(ValueError): + LatLonInterpolate(reference_dataset=ds_ref, reference_index=0) + + with pytest.raises(ValueError): + LatLonInterpolate(reference_dataset=ds_ref.rename({"latitude": "abc", "longitude": "123"})) + + joiner = LatLonInterpolate(reference_index=0) + joiner.reference_index = None + with pytest.raises(ValueError): + joiner.join((ds_ref, ds_coarse)) + + +def test_latlon_interpolate_unjoin_not_implemented(): + """Tests that LatLonInterpolate.unjoin raises NotImplementedError.""" + ds_ref = _make_ds("ds1", np.arange(9).reshape(3, 3), [0.0, 1.0, 2.0], [0.0, 1.0, 2.0]) + joiner = LatLonInterpolate(reference_dataset=ds_ref) + with pytest.raises(NotImplementedError): + joiner.unjoin(ds_ref) From 339138e64da2bea1b45c9d0f202148e39da8387a Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Tue, 24 Mar 2026 10:27:21 +1100 Subject: [PATCH 6/7] complete tests for xarray join covers Concatenate InterpLike GeospatialTimeSeriesMerge --- .../operations/xarray/test_xarray_join.py | 200 ++++++++++++++++-- 1 file changed, 180 insertions(+), 20 deletions(-) diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_join.py b/packages/pipeline/tests/operations/xarray/test_xarray_join.py index 8f0d740e..9d145828 100644 --- a/packages/pipeline/tests/operations/xarray/test_xarray_join.py +++ b/packages/pipeline/tests/operations/xarray/test_xarray_join.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pyearthtools.pipeline.operations.xarray.join import Merge, LatLonInterpolate +from pyearthtools.pipeline.operations.xarray.join import ( + Merge, + LatLonInterpolate, + GeospatialTimeSeriesMerge, + InterpLike, + Concatenate, +) import numpy as np import xarray as xr @@ -93,15 +99,18 @@ def test_merge(): ), 'passing combine_attrs="no_conflict" to Merge didn\'t preserve attributes when unjoining.' -def _make_ds(var_name, data, lat, lon, lat_name="latitude", lon_name="longitude"): - """Create a Dataset with lat/lon coords.""" +def _make_ds(var_name, data, lat, lon, time=None, lat_name="latitude", lon_name="longitude"): + """Create a Dataset with latitude, longitude, and time coords.""" + time = time or [0] return xr.Dataset( - {var_name: xr.DataArray(data, dims=[lat_name, lon_name])}, - coords={lat_name: lat, lon_name: lon}, + {var_name: xr.DataArray(data, dims=["time", lat_name, lon_name])}, + coords={"time": time, lat_name: lat, lon_name: lon}, ) -EXPECTED_INTERPOLATED = np.array([[9.0, 9.0, 10.0], [9.0, 9.0, 10.0], [11.0, 11.0, 12.0]]) +@pytest.fixture +def ds_ref(): + return _make_ds(var_name="var_ref", data=np.arange(9).reshape(1, 3, 3), lat=[0.0, 1.0, 2.0], lon=[0.0, 1.0, 2.0]) @pytest.mark.parametrize( @@ -115,23 +124,38 @@ def _make_ds(var_name, data, lat, lon, lat_name="latitude", lon_name="longitude" ) def test_latlon_interpolate_join(lat_name, lon_name, joiner_factory): """Tests that LatLonInterpolate merges and interpolates datasets to the reference grid.""" - ds_ref = _make_ds("ds1", np.arange(9).reshape(3, 3), [0.0, 1.0, 2.0], [0.0, 1.0, 2.0], lat_name, lon_name) - ds_coarse = _make_ds("ds2", np.arange(9, 13).reshape(2, 2), [-0.25, 2.25], [-0.25, 2.25], lat_name, lon_name) + ds_ref = _make_ds( + var_name="var1", + data=np.arange(9).reshape(1, 3, 3), + lat=[0.0, 1.0, 2.0], + lon=[0.0, 1.0, 2.0], + lat_name=lat_name, + lon_name=lon_name, + ) + ds_coarse = _make_ds( + var_name="var2", + data=np.arange(9, 13).reshape(1, 2, 2), + lat=[-0.25, 2.25], + lon=[-0.25, 2.25], + lat_name=lat_name, + lon_name=lon_name, + ) result = joiner_factory(ds_ref).join((ds_ref, ds_coarse)) - assert "ds1" in result.data_vars - assert "ds2" in result.data_vars + assert "var1" in result.data_vars + assert "var2" in result.data_vars # astype is needed because interp changes datatype - assert ds_ref["ds1"].equals(result["ds1"].astype(int)) - assert ds_ref.coords.equals(result["ds2"].coords) - assert np.array_equal(result["ds2"].values, EXPECTED_INTERPOLATED) + assert ds_ref["var1"].equals(result["var1"].astype(int)) + assert ds_ref.coords.equals(result["var2"].coords) + + expected_interp = np.array([[9.0, 9.0, 10.0], [9.0, 9.0, 10.0], [11.0, 11.0, 12.0]]) + assert np.array_equal(result["var2"].squeeze("time").values, expected_interp) -def test_latlon_interpolate_errors(): +def test_latlon_interpolate_errors(ds_ref): """Tests that LatLonInterpolate raises errors for invalid configurations.""" - ds_ref = _make_ds("ds1", np.arange(9).reshape(3, 3), [0.0, 1.0, 2.0], [0.0, 1.0, 2.0]) - ds_coarse = _make_ds("ds2", np.arange(9, 13).reshape(2, 2), [-0.25, 2.25], [-0.25, 2.25]) + ds_coarse = _make_ds(var_name="var2", data=np.arange(9, 13).reshape(1, 2, 2), lat=[-0.25, 2.25], lon=[-0.25, 2.25]) with pytest.raises(ValueError): LatLonInterpolate() @@ -147,10 +171,146 @@ def test_latlon_interpolate_errors(): with pytest.raises(ValueError): joiner.join((ds_ref, ds_coarse)) + # unjoin not implemented + with pytest.raises(NotImplementedError): + joiner.unjoin(ds_ref) + + +def test_geospatial_timeseries_merge_join(ds_ref): + """Tests that GeospatialTimeSeriesMerge interpolates and merges datasets.""" + da_coarse = xr.DataArray( + np.arange(9, 13).reshape(1, 2, 2), + dims=["time", "latitude", "longitude"], + coords={"time": [0], "latitude": [-0.25, 2.25], "longitude": [-0.25, 2.25]}, + name="var2", + ) + + joiner = GeospatialTimeSeriesMerge(reference_index=0) + result = joiner.join((ds_ref, da_coarse)) + + assert "var_ref" in result.data_vars + assert "var2" in result.data_vars + assert ( + result["var2"].shape == ds_ref["var_ref"].shape + ), "GeospatialTimeSeriesMerge did not interpolate to reference grid shape." + assert tuple(result.latitude.values) == tuple(ds_ref.latitude.values) + assert tuple(result.longitude.values) == tuple(ds_ref.longitude.values) + + +def test_geospatial_timeseries_merge_errors(ds_ref): + """Tests that GeospatialTimeSeriesMerge raises errors for invalid inputs.""" + ds_no_time = _make_ds( + var_name="var2", data=np.arange(9, 18).reshape(1, 3, 3), lat=ds_ref.latitude.values, lon=ds_ref.longitude.values + ).drop_dims("time") + + # fail when trying to join without setting reference + with pytest.raises(ValueError): + GeospatialTimeSeriesMerge().join((ds_ref, ds_ref)) + + joiner = GeospatialTimeSeriesMerge(reference_dataset=ds_ref) + # fail when trying to join datasets and one doesn't have the time dim + with pytest.raises(ValueError): + joiner.join((ds_no_time, ds_ref)) + with pytest.raises(ValueError): + joiner.join((ds_ref, ds_no_time)) + + # fail when trying to unjoin + with pytest.raises(NotImplementedError): + GeospatialTimeSeriesMerge().unjoin(None) + + +def test_interplike(ds_ref): + + da_coarse = xr.DataArray( + np.arange(9, 13).reshape(2, 2), + dims=["latitude", "longitude"], + coords={"latitude": [-0.25, 2.25], "longitude": [-0.25, 2.25]}, + name="var1", + ) + da_fine = xr.DataArray( + np.arange(13, 29).reshape(4, 4), + dims=["latitude", "longitude"], + coords={"latitude": [0.0, 0.67, 1.33, 2.0], "longitude": [0.0, 0.67, 1.33, 2.0]}, + name="var2", + ) + + # test default interpolation method (nearest) + joiner = InterpLike(reference_dataset=ds_ref) + result = joiner.join([da_coarse, da_fine]) + expected_nearest = { + "var1": np.array([[9.0, 9.0, 10.0], [9.0, 9.0, 10.0], [11.0, 11.0, 12.0]]), + "var2": np.array([[13.0, 14.0, 16.0], [17.0, 18.0, 20.0], [25.0, 26.0, 28.0]]), + } + for ds in ("var1", "var2"): + assert ds in result.data_vars, f"{ds} missing from joined dataset" + assert (1,) + result[ds].shape == ds_ref[ + "var_ref" + ].shape, f"InterpLike didn't interpolate {ds} onto ds_ref's coords" + assert np.array_equal(expected_nearest[ds], result[ds].values), f"Interplike didn't interpolate {ds}'s values" + + # test linear interpolation method + joiner = InterpLike(reference_dataset=ds_ref, method="linear") + result = joiner.join([da_coarse, da_fine]) + expected_linear = { + "var1": np.array([[9.3, 9.7, 10.1], [10.1, 10.5, 10.9], [10.9, 11.3, 11.7]]), + "var2": np.array([[13.0, 14.5, 16.0], [19.0, 20.5, 22.0], [25.0, 26.5, 28.0]]), + } + for ds in ("var1", "var2"): + assert np.allclose(expected_linear[ds], result[ds].values) + + # test reference index + joiner = InterpLike(reference_index=0) + result = joiner.join([ds_ref, da_coarse, da_fine]) + assert "var_ref" in result.data_vars, "InterpLike didn't preserve reference dataset" + assert ds_ref["var_ref"].equals(result["var_ref"].astype(int)), "InterpLike didn't reproduce reference" + for ds in ("var1", "var2"): + assert ds in result.data_vars, f"{ds} missing from joined dataset" + assert (1,) + result[ds].shape == ds_ref[ + "var_ref" + ].shape, f"InterpLike didn't interpolate {ds} onto ds_ref's coords" + assert np.array_equal(expected_nearest[ds], result[ds].values), f"Interplike didn't interpolate {ds}'s values" + + +def test_interplike_errors(ds_ref): + joiner = InterpLike() + with pytest.raises(ValueError): + joiner.join([ds_ref]) -def test_latlon_interpolate_unjoin_not_implemented(): - """Tests that LatLonInterpolate.unjoin raises NotImplementedError.""" - ds_ref = _make_ds("ds1", np.arange(9).reshape(3, 3), [0.0, 1.0, 2.0], [0.0, 1.0, 2.0]) - joiner = LatLonInterpolate(reference_dataset=ds_ref) with pytest.raises(NotImplementedError): joiner.unjoin(ds_ref) + + +def test_concatenate(): + # test with dataarrays + da1 = xr.DataArray(np.arange(6).reshape((2, 3)), coords={"x": range(2), "y": range(3)}) + da2 = xr.DataArray(np.arange(6, 18).reshape((4, 3)), coords={"x": range(4), "y": range(3)}) + joiner = Concatenate(concat_dim="x") + result = joiner.join([da1, da2]) + assert np.array_equal(result.values, np.arange(18).reshape((6, 3))) + + # test with datasets + ds1 = xr.Dataset({"var1": da1}) + ds2 = xr.Dataset({"var2": da2}) + result = joiner.join([ds1, ds2]) + expected = np.vstack((da1.values, np.full((4, 3), np.nan))) + assert np.array_equal(expected, result["var1"].values, equal_nan=True) + expected = np.vstack((np.full((2, 3), np.nan), da2.values)) + assert np.array_equal(expected, result["var2"].values, equal_nan=True) + + # test concat kwargs (dim kwarg should be ignored) + joiner = Concatenate(concat_dim="x", concat_kwargs={"fill_value": 0, "dim": "y"}) + result = joiner.join([ds1, ds2]) + expected = np.vstack((da1.values, np.zeros((4, 3)))) + assert np.array_equal( + expected, + result["var1"].values, + ) + expected = np.vstack((np.zeros((2, 3)), da2.values)) + assert np.array_equal( + expected, + result["var2"].values, + ) + + # unjoin not implemented: returns the input + joiner = Concatenate(concat_dim="x") + assert ds1.equals(joiner.unjoin(ds1)) From 197ca7f98f51dacd49d37d2b44fb08cfb5a5cae7 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Tue, 24 Mar 2026 10:33:00 +1100 Subject: [PATCH 7/7] updated doc strings on unimplemented unjoins --- .../src/pyearthtools/pipeline/operations/xarray/join.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py index 1fc68c7c..7870ece7 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py @@ -67,6 +67,8 @@ class LatLonInterpolate(Joiner): It assumed the dimensions 'latitude', 'longitude', 'time', and 'level' will be present. 'lat' or 'lon' may also be used for convenience. + + Currently cannot undo this operation. Raises NotImplementedError if undo is attempted. """ _override_interface = "Serial" @@ -172,6 +174,8 @@ class GeospatialTimeSeriesMerge(Joiner): This joiner is more strict about the merging and interpolating, and also raises more informative error messages when it runs into trouble. + + Currently cannot undo this operation. Raises NotImplementedError if undo is attempted. """ _override_interface = "Serial" @@ -238,7 +242,7 @@ class InterpLike(Joiner): """ Merge a tuple of xarray object's. - Currently cannot undo this operation + Currently cannot undo this operation. Raises NotImplementedError if undo is attempted. """ _override_interface = "Serial" @@ -280,7 +284,7 @@ class Concatenate(Joiner): """ Concatenate a tuple of xarray object's - Currently cannot undo this operation + Currently cannot undo this operation. Unjoining a sample returns the same sample. """ _override_interface = "Serial"