diff --git a/dataframely/_base_schema.py b/dataframely/_base_schema.py index 7ae64e3..870596f 100644 --- a/dataframely/_base_schema.py +++ b/dataframely/_base_schema.py @@ -8,7 +8,7 @@ from abc import ABCMeta from copy import copy from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import Any import polars as pl @@ -24,6 +24,7 @@ _COLUMN_ATTR = "__dataframely_columns__" _RULE_ATTR = "__dataframely_rules__" +_USE_ATTR_NAMES = "__dataframely_use_attribute_names__" ORIGINAL_COLUMN_PREFIX = "__DATAFRAMELY_ORIGINAL__" @@ -95,13 +96,27 @@ def __new__( bases: tuple[type[object], ...], namespace: dict[str, Any], *args: Any, + use_attribute_names: bool = False, **kwargs: Any, ) -> SchemaMeta: result = Metadata() + for base in bases: result.update(mcs._get_metadata_recursively(base)) - result.update(mcs._get_metadata(namespace)) + + # Copy columns defined in current namespace to avoid mutating shared objects. + # Set _name based on this class's use_attribute_names setting. + for attr, value in list(namespace.items()): + if isinstance(value, Column) and not attr.startswith("__"): + col = copy(value) + col._name = attr if use_attribute_names else (col.alias or attr) + namespace[attr] = col + + result.update( + mcs._get_metadata(namespace, use_attribute_names=use_attribute_names) + ) namespace[_COLUMN_ATTR] = result.columns + namespace[_USE_ATTR_NAMES] = use_attribute_names cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs) # Assign rules retroactively as we only encounter rule factories in the result @@ -177,33 +192,29 @@ def __new__( return cls - if not TYPE_CHECKING: - # Only define __getattribute__ at runtime to allow type checkers to properly - # validate attribute access. When TYPE_CHECKING is True, type checkers will use - # the default metaclass behavior which correctly identifies non-existent attributes. - def __getattribute__(cls, name: str) -> Any: - val = super().__getattribute__(name) - # Dynamically set the name of the column if it is a `Column` instance. - if isinstance(val, Column): - val._name = val.alias or name - return val - @staticmethod def _get_metadata_recursively(kls: type[object]) -> Metadata: result = Metadata() for base in kls.__bases__: result.update(SchemaMeta._get_metadata_recursively(base)) - result.update(SchemaMeta._get_metadata(kls.__dict__)) # type: ignore + use_attr_names = getattr(kls, _USE_ATTR_NAMES, False) + result.update( + SchemaMeta._get_metadata(kls.__dict__, use_attribute_names=use_attr_names) # type: ignore + ) return result @staticmethod - def _get_metadata(source: dict[str, Any]) -> Metadata: + def _get_metadata( + source: dict[str, Any], *, use_attribute_names: bool = False + ) -> Metadata: result = Metadata() for attr, value in { k: v for k, v in source.items() if not k.startswith("__") }.items(): if isinstance(value, Column): - result.columns[value.alias or attr] = value + # When use_attribute_names=True, use attr as key; otherwise use alias or attr + col_name = attr if use_attribute_names else (value.alias or attr) + result.columns[col_name] = value if isinstance(value, RuleFactory): # We must ensure that custom rules do not clash with internal rules. if attr == "primary_key": @@ -238,11 +249,7 @@ def column_names(cls) -> list[str]: @classmethod def columns(cls) -> dict[str, Column]: """The column definitions of this schema.""" - columns: dict[str, Column] = getattr(cls, _COLUMN_ATTR) - for name in columns.keys(): - # Dynamically set the name of the columns. - columns[name]._name = name - return columns + return getattr(cls, _COLUMN_ATTR) @classmethod def primary_key(cls) -> list[str]: @@ -258,3 +265,12 @@ def _validation_rules(cls, *, with_cast: bool) -> dict[str, Rule]: @classmethod def _schema_validation_rules(cls) -> dict[str, Rule]: return getattr(cls, _RULE_ATTR) + + @classmethod + def _alias_mapping(cls) -> dict[str, str]: + """Mapping from aliases to column identifier (attribute).""" + return { + col.alias: col._name + for col in cls.columns().values() + if col.alias is not None and col.alias != col._name + } diff --git a/dataframely/schema.py b/dataframely/schema.py index b64f061..c633719 100644 --- a/dataframely/schema.py +++ b/dataframely/schema.py @@ -820,6 +820,32 @@ def cast( return lf.collect() # type: ignore return lf # type: ignore + @overload + @classmethod + def undo_aliases(cls, df: pl.DataFrame, /) -> pl.DataFrame: ... + + @overload + @classmethod + def undo_aliases(cls, df: pl.LazyFrame, /) -> pl.LazyFrame: ... + + @classmethod + def undo_aliases( + cls, df: pl.DataFrame | pl.LazyFrame, / + ) -> pl.DataFrame | pl.LazyFrame: + """Rename columns from their alias names to their attribute names. + + This method renames columns that have aliases defined, mapping from the + alias (e.g., "price ($)") to the attribute name (e.g., "price"). + + Args: + df: The data frame whose columns should be renamed. + + Returns: + The data frame with columns renamed from aliases to attribute names. + Columns without aliases are left unchanged. + """ + return df.rename(cls._alias_mapping(), strict=False) + # --------------------------------- SERIALIZATION -------------------------------- # @classmethod diff --git a/tests/columns/test_alias.py b/tests/columns/test_alias.py index 0dd364a..9b8d2af 100644 --- a/tests/columns/test_alias.py +++ b/tests/columns/test_alias.py @@ -36,3 +36,135 @@ def test_alias_unset() -> None: no_alias_col = dy.Int32() assert no_alias_col.alias is None assert no_alias_col.name == "" + + +def test_alias_use_attribute_names() -> None: + class MySchema1(dy.Schema, use_attribute_names=True): + price = dy.Int64(alias="price ($)") + + class MySchema2(MySchema1, use_attribute_names=False): + price2 = dy.Int64(alias="price2 ($)") + + class MySchema3(MySchema2): + price3 = dy.Int64(alias="price3 ($)") + + class MySchema4(MySchema3, use_attribute_names=True): + price4 = dy.Int64(alias="price4 ($)") + + class MySchema5(MySchema4): + price5 = dy.Int64(alias="price5 ($)") + + assert MySchema5.price.name == "price" + assert MySchema5.price2.name == "price2 ($)" + assert MySchema5.price3.name == "price3 ($)" + assert MySchema5.price4.name == "price4" + assert MySchema5.price5.name == "price5 ($)" + + assert MySchema5.column_names() == [ + "price", + "price2 ($)", + "price3 ($)", + "price4", + "price5 ($)", + ] + + +def test_alias_mapping() -> None: + class MySchema(dy.Schema, use_attribute_names=True): + price = dy.Int64(alias="price ($)") + production_rank = dy.Int64(alias="Production rank") + no_alias = dy.Int64() + + # _alias_mapping returns alias -> attribute name mapping + assert MySchema._alias_mapping() == { + "price ($)": "price", + "Production rank": "production_rank", + } + + +def test_alias_mapping_empty() -> None: + class NoAliasSchema(dy.Schema): + a = dy.Int64() + b = dy.String() + + # No aliases means empty mapping + assert NoAliasSchema._alias_mapping() == {} + + +def test_undo_aliases() -> None: + class MySchema(dy.Schema, use_attribute_names=True): + price = dy.Int64(alias="price ($)") + production_rank = dy.Int64(alias="Production rank") + + df = pl.DataFrame({"price ($)": [100], "Production rank": [1]}) + result = MySchema.undo_aliases(df) + assert result.columns == ["price", "production_rank"] + + +def test_undo_aliases_lazy() -> None: + class MySchema(dy.Schema, use_attribute_names=True): + price = dy.Int64(alias="price ($)") + + lf = pl.LazyFrame({"price ($)": [100]}) + result = MySchema.undo_aliases(lf).collect() + assert result.columns == ["price"] + + +def test_inherited_column_keeps_parent_name() -> None: + """Inherited columns keep their _name from the parent class.""" + + class Parent(dy.Schema, use_attribute_names=True): + price = dy.Int64(alias="price ($)") + + class Child(Parent, use_attribute_names=False): + quantity = dy.Int64(alias="qty") + + # Parent column keeps its name based on parent's use_attribute_names=True + assert Parent.price.name == "price" + assert Child.price.name == "price" + + # Child's own column uses its use_attribute_names=False setting + assert Child.quantity.name == "qty" + + # column_names reflects the correct names + assert Parent.column_names() == ["price"] + assert Child.column_names() == ["price", "qty"] + + +def test_shared_column_object_is_copied() -> None: + """When a column object is reused, each schema gets its own copy.""" + col = dy.Int64(alias="price ($)") + + class Schema1(dy.Schema, use_attribute_names=True): + price = col + + class Schema2(dy.Schema, use_attribute_names=False): + price = col + + # Each schema has its own copy with the correct _name + assert Schema1.price.name == "price" + assert Schema2.price.name == "price ($)" + + # The original column is not mutated + assert col._name == "" + + +def test_shared_column_in_inheritance() -> None: + """Shared column used in parent and child schemas.""" + col = dy.Int64(alias="price ($)") + + class Parent(dy.Schema, use_attribute_names=True): + price = col + + class Child(Parent, use_attribute_names=False): + price2 = col + + # Parent's column uses parent's setting + assert Parent.price.name == "price" + # Inherited column in child keeps parent's setting + assert Child.price.name == "price" + # Child's own column uses child's setting + assert Child.price2.name == "price ($)" + + assert Parent.column_names() == ["price"] + assert Child.column_names() == ["price", "price ($)"]