Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 37 additions & 21 deletions dataframely/_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from abc import ABCMeta
from copy import copy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import Any

import polars as pl

Expand All @@ -24,6 +24,7 @@

_COLUMN_ATTR = "__dataframely_columns__"
_RULE_ATTR = "__dataframely_rules__"
_USE_ATTR_NAMES = "__dataframely_use_attribute_names__"

ORIGINAL_COLUMN_PREFIX = "__DATAFRAMELY_ORIGINAL__"

Expand Down Expand Up @@ -95,13 +96,27 @@ def __new__(
bases: tuple[type[object], ...],
namespace: dict[str, Any],
*args: Any,
use_attribute_names: bool = False,
**kwargs: Any,
) -> SchemaMeta:
result = Metadata()

for base in bases:
result.update(mcs._get_metadata_recursively(base))
result.update(mcs._get_metadata(namespace))

# Copy columns defined in current namespace to avoid mutating shared objects.
# Set _name based on this class's use_attribute_names setting.
for attr, value in list(namespace.items()):
if isinstance(value, Column) and not attr.startswith("__"):
col = copy(value)
col._name = attr if use_attribute_names else (col.alias or attr)
namespace[attr] = col

result.update(
mcs._get_metadata(namespace, use_attribute_names=use_attribute_names)
)
namespace[_COLUMN_ATTR] = result.columns
namespace[_USE_ATTR_NAMES] = use_attribute_names
cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs)

# Assign rules retroactively as we only encounter rule factories in the result
Expand Down Expand Up @@ -177,33 +192,29 @@ def __new__(

return cls

if not TYPE_CHECKING:
# Only define __getattribute__ at runtime to allow type checkers to properly
# validate attribute access. When TYPE_CHECKING is True, type checkers will use
# the default metaclass behavior which correctly identifies non-existent attributes.
def __getattribute__(cls, name: str) -> Any:
val = super().__getattribute__(name)
# Dynamically set the name of the column if it is a `Column` instance.
if isinstance(val, Column):
val._name = val.alias or name
return val

@staticmethod
def _get_metadata_recursively(kls: type[object]) -> Metadata:
result = Metadata()
for base in kls.__bases__:
result.update(SchemaMeta._get_metadata_recursively(base))
result.update(SchemaMeta._get_metadata(kls.__dict__)) # type: ignore
use_attr_names = getattr(kls, _USE_ATTR_NAMES, False)
result.update(
SchemaMeta._get_metadata(kls.__dict__, use_attribute_names=use_attr_names) # type: ignore
)
return result

@staticmethod
def _get_metadata(source: dict[str, Any]) -> Metadata:
def _get_metadata(
source: dict[str, Any], *, use_attribute_names: bool = False
) -> Metadata:
result = Metadata()
for attr, value in {
k: v for k, v in source.items() if not k.startswith("__")
}.items():
if isinstance(value, Column):
result.columns[value.alias or attr] = value
# When use_attribute_names=True, use attr as key; otherwise use alias or attr
col_name = attr if use_attribute_names else (value.alias or attr)
result.columns[col_name] = value
if isinstance(value, RuleFactory):
# We must ensure that custom rules do not clash with internal rules.
if attr == "primary_key":
Expand Down Expand Up @@ -238,11 +249,7 @@ def column_names(cls) -> list[str]:
@classmethod
def columns(cls) -> dict[str, Column]:
"""The column definitions of this schema."""
columns: dict[str, Column] = getattr(cls, _COLUMN_ATTR)
for name in columns.keys():
# Dynamically set the name of the columns.
columns[name]._name = name
return columns
return getattr(cls, _COLUMN_ATTR)

@classmethod
def primary_key(cls) -> list[str]:
Expand All @@ -258,3 +265,12 @@ def _validation_rules(cls, *, with_cast: bool) -> dict[str, Rule]:
@classmethod
def _schema_validation_rules(cls) -> dict[str, Rule]:
    """Return the rules stored on this class under ``_RULE_ATTR``."""
    stored_rules: dict[str, Rule] = getattr(cls, _RULE_ATTR)
    return stored_rules

@classmethod
def _alias_mapping(cls) -> dict[str, str]:
    """Map each column alias to the name the column carries in this schema.

    Only columns that define an alias which differs from their ``_name`` are
    included; columns without an alias, or whose name already equals the alias,
    are left out.

    Returns:
        A dictionary keyed by alias with the column's ``_name`` as value.
    """
    alias_to_name: dict[str, str] = {}
    for column in cls.columns().values():
        alias = column.alias
        if alias is None or alias == column._name:
            # Nothing to undo for this column.
            continue
        alias_to_name[alias] = column._name
    return alias_to_name
26 changes: 26 additions & 0 deletions dataframely/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,32 @@ def cast(
return lf.collect() # type: ignore
return lf # type: ignore

@overload
@classmethod
def undo_aliases(cls, df: pl.DataFrame, /) -> pl.DataFrame: ...

@overload
@classmethod
def undo_aliases(cls, df: pl.LazyFrame, /) -> pl.LazyFrame: ...

@classmethod
def undo_aliases(
    cls, df: pl.DataFrame | pl.LazyFrame, /
) -> pl.DataFrame | pl.LazyFrame:
    """Rename aliased columns of a frame back to their attribute names.

    The renaming is driven by ``cls._alias_mapping()``, i.e. every alias
    (e.g., "price ($)") is replaced by the name it maps to (e.g., "price").

    Args:
        df: The data frame (eager or lazy) whose columns should be renamed.

    Returns:
        A frame of the same kind as ``df`` with aliases replaced by attribute
        names. Columns without aliases are left unchanged.
    """
    alias_to_name = cls._alias_mapping()
    # The rename is non-strict: ``strict=False`` is passed through to polars.
    return df.rename(alias_to_name, strict=False)

# --------------------------------- SERIALIZATION -------------------------------- #

@classmethod
Expand Down
132 changes: 132 additions & 0 deletions tests/columns/test_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,135 @@ def test_alias_unset() -> None:
no_alias_col = dy.Int32()
assert no_alias_col.alias is None
assert no_alias_col.name == ""


def test_alias_use_attribute_names() -> None:
    """Each class in a hierarchy names its own columns by its own setting.

    Classes that do not pass ``use_attribute_names`` (``MySchema3`` and
    ``MySchema5``) are expected to name their columns by alias, regardless of
    what their parent class passed.
    """

    class MySchema1(dy.Schema, use_attribute_names=True):
        price = dy.Int64(alias="price ($)")

    class MySchema2(MySchema1, use_attribute_names=False):
        price2 = dy.Int64(alias="price2 ($)")

    class MySchema3(MySchema2):
        price3 = dy.Int64(alias="price3 ($)")

    class MySchema4(MySchema3, use_attribute_names=True):
        price4 = dy.Int64(alias="price4 ($)")

    class MySchema5(MySchema4):
        price5 = dy.Int64(alias="price5 ($)")

    # Attribute -> expected column name, listed in definition order.
    expected_names = {
        "price": "price",
        "price2": "price2 ($)",
        "price3": "price3 ($)",
        "price4": "price4",
        "price5": "price5 ($)",
    }

    for attribute, expected in expected_names.items():
        assert getattr(MySchema5, attribute).name == expected

    assert MySchema5.column_names() == list(expected_names.values())


def test_alias_mapping() -> None:
    """``_alias_mapping`` maps aliases to attribute names and skips unaliased columns."""

    class MySchema(dy.Schema, use_attribute_names=True):
        price = dy.Int64(alias="price ($)")
        production_rank = dy.Int64(alias="Production rank")
        no_alias = dy.Int64()

    expected = {
        "price ($)": "price",
        "Production rank": "production_rank",
    }
    actual = MySchema._alias_mapping()

    # ``no_alias`` has no alias and therefore must not appear in the mapping.
    assert actual == expected


def test_alias_mapping_empty() -> None:
    """A schema without any aliased column yields an empty alias mapping."""

    class NoAliasSchema(dy.Schema):
        a = dy.Int64()
        b = dy.String()

    mapping = NoAliasSchema._alias_mapping()
    assert mapping == {}


def test_undo_aliases() -> None:
    """``undo_aliases`` renames aliased columns, keeps data, and is non-strict."""

    class MySchema(dy.Schema, use_attribute_names=True):
        price = dy.Int64(alias="price ($)")
        production_rank = dy.Int64(alias="Production rank")

    df = pl.DataFrame({"price ($)": [100], "Production rank": [1]})
    result = MySchema.undo_aliases(df)
    assert result.columns == ["price", "production_rank"]
    # Renaming must not alter the values stored in the columns.
    assert result.to_dict(as_series=False) == {
        "price": [100],
        "production_rank": [1],
    }

    # ``undo_aliases`` renames with ``strict=False``: a frame that lacks one of
    # the aliased columns must not raise, and columns unknown to the schema
    # must pass through untouched.
    partial = pl.DataFrame({"price ($)": [100], "other": ["x"]})
    partial_result = MySchema.undo_aliases(partial)
    assert partial_result.columns == ["price", "other"]


def test_undo_aliases_lazy() -> None:
    """``undo_aliases`` also accepts a lazy frame and renames its columns."""

    class MySchema(dy.Schema, use_attribute_names=True):
        price = dy.Int64(alias="price ($)")

    lazy_input = pl.LazyFrame({"price ($)": [100]})
    renamed = MySchema.undo_aliases(lazy_input)

    # Materialize the lazy result before inspecting its columns.
    collected = renamed.collect()
    assert collected.columns == ["price"]


def test_inherited_column_keeps_parent_name() -> None:
    """A column inherited from a parent schema keeps the name the parent gave it."""

    class Parent(dy.Schema, use_attribute_names=True):
        price = dy.Int64(alias="price ($)")

    class Child(Parent, use_attribute_names=False):
        quantity = dy.Int64(alias="qty")

    # ``price`` was named under Parent's use_attribute_names=True, and that name
    # is what both the parent and the child expose.
    for schema in (Parent, Child):
        assert schema.price.name == "price"

    # The column defined on Child follows Child's use_attribute_names=False.
    assert Child.quantity.name == "qty"

    # column_names agrees with the per-column names above.
    parent_names = Parent.column_names()
    child_names = Child.column_names()
    assert parent_names == ["price"]
    assert child_names == ["price", "qty"]


def test_shared_column_object_is_copied() -> None:
    """Reusing one column object in two schemas gives each schema its own copy."""
    shared = dy.Int64(alias="price ($)")

    class Schema1(dy.Schema, use_attribute_names=True):
        price = shared

    class Schema2(dy.Schema, use_attribute_names=False):
        price = shared

    # The two schemas resolve the name independently of one another.
    name_by_attribute = Schema1.price.name
    name_by_alias = Schema2.price.name
    assert name_by_attribute == "price"
    assert name_by_alias == "price ($)"

    # The object both classes were built from still has no name assigned.
    assert shared._name == ""


def test_shared_column_in_inheritance() -> None:
    """One column object used in both a parent and its child schema."""
    shared = dy.Int64(alias="price ($)")

    class Parent(dy.Schema, use_attribute_names=True):
        price = shared

    class Child(Parent, use_attribute_names=False):
        price2 = shared

    # ``price`` is named by Parent's setting, on Parent and when inherited.
    for schema in (Parent, Child):
        assert schema.price.name == "price"

    # ``price2`` is defined on Child and therefore named by Child's setting.
    assert Child.price2.name == "price ($)"

    parent_names = Parent.column_names()
    child_names = Child.column_names()
    assert parent_names == ["price"]
    assert child_names == ["price", "price ($)"]
Loading