diff --git a/dataframely/_base_schema.py b/dataframely/_base_schema.py index 7ae64e3d..b875e6bf 100644 --- a/dataframely/_base_schema.py +++ b/dataframely/_base_schema.py @@ -21,12 +21,12 @@ else: from typing_extensions import Self - _COLUMN_ATTR = "__dataframely_columns__" _RULE_ATTR = "__dataframely_rules__" ORIGINAL_COLUMN_PREFIX = "__DATAFRAMELY_ORIGINAL__" + # --------------------------------------- UTILS -------------------------------------- # @@ -84,6 +84,25 @@ class Metadata: rules: dict[str, RuleFactory] = field(default_factory=dict) def update(self, other: Self) -> None: + """Merge another Metadata instance into this one. + + Overlapping keys are allowed if and only if they refer to the *same* underlying + object. This accommodates multiple-inheritance / diamond patterns where the same + base schema is visited more than once. + """ + # Detect conflicting column definitions: same name, different Column instance + duplicated_column_names = self.columns.keys() & other.columns.keys() + conflicting_columns = { + name + for name in duplicated_column_names + if self.columns[name] is not other.columns[name] + } + if conflicting_columns: + raise ImplementationError( + f"Columns {conflicting_columns} are duplicated with conflicting definitions." + ) + + # All clear self.columns.update(other.columns) self.rules.update(other.rules) @@ -203,6 +222,8 @@ def _get_metadata(source: dict[str, Any]) -> Metadata: k: v for k, v in source.items() if not k.startswith("__") }.items(): if isinstance(value, Column): + if (col_name := value.alias or attr) in result.columns: + raise ImplementationError(f"Column {col_name!r} is duplicated.") result.columns[value.alias or attr] = value if isinstance(value, RuleFactory): # We must ensure that custom rules do not clash with internal rules. diff --git a/tests/columns/test_alias.py b/tests/columns/test_alias.py index 0dd364a7..7276f049 100644 --- a/tests/columns/test_alias.py +++ b/tests/columns/test_alias.py @@ -2,8 +2,10 @@ # SPDX-License-Identifier: BSD-3-Clause import polars as pl +import pytest import dataframely as dy +from dataframely.exc import ImplementationError class AliasSchema(dy.Schema): @@ -36,3 +38,21 @@ def test_alias_unset() -> None: no_alias_col = dy.Int32() assert no_alias_col.alias is None assert no_alias_col.name == "" + + +def test_duplicate_alias_same_schema() -> None: + with pytest.raises(ImplementationError, match="'a' is duplicated"): + + class MySchema(dy.Schema): + a = dy.Int64(alias="a") + b = dy.String(alias="a") + + +def test_duplicate_alias_inherited_schema() -> None: + class MySchema(dy.Schema): + a = dy.Int64(alias="a") + + with pytest.raises(ImplementationError, match="'a'.*duplicated"): + + class MySchema2(MySchema): + b = dy.Int64(alias="a")