diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py index 6c63f2c..f1d2699 100644 --- a/dataframely/columns/enum.py +++ b/dataframely/columns/enum.py @@ -33,6 +33,8 @@ def __init__( alias: str | None = None, metadata: dict[str, Any] | None = None, description: str | None = None, + sqlalchemy_use_enum: bool = False, + sqlalchemy_enum_name: str | None = None, ): """ Args: @@ -68,6 +70,17 @@ def __init__( names, the specified alias is the only valid name. metadata: A dictionary of metadata to attach to the column. description: A human-readable description of the column. + sqlalchemy_use_enum: When ``True``, map this column to :class:`sqlalchemy.Enum` + in :meth:`~dataframely.Schema.to_sqlalchemy_columns` instead of + ``CHAR`` / ``VARCHAR``. Use this for PostgreSQL native enum types and + Alembic schema drift detection. Defaults to ``False`` (string columns). + sqlalchemy_enum_name: Optional name for the SQLAlchemy / database enum type + when ``sqlalchemy_use_enum=True``. If omitted and ``categories`` is a + Python :class:`enum.Enum` subclass, SQLAlchemy uses the enum class name + (lowercased). Otherwise the SQL column name from + :meth:`~dataframely.Schema.to_sqlalchemy_columns` is used. For Python + enums, persisted values are the enum members' ``.value`` strings (not + member names), matching :attr:`categories`. 
""" super().__init__( nullable=nullable, @@ -78,7 +91,11 @@ def __init__( metadata=metadata, description=description, ) + self.sqlalchemy_use_enum = sqlalchemy_use_enum + self.sqlalchemy_enum_name = sqlalchemy_enum_name + self._enum_class: type[enum.Enum] | None = None if isclass(categories) and issubclass(categories, enum.Enum): + self._enum_class = categories categories = (item.value for item in categories) self.categories = list(categories) @@ -91,12 +108,45 @@ def validate_dtype(self, dtype: PolarsDataType) -> bool: return False return self.categories == dtype.categories.to_list() + def sqlalchemy_column(self, name: str, dialect: sa.Dialect) -> sa.Column: + if self.sqlalchemy_use_enum: + column = super().sqlalchemy_column(name, dialect) + column.type = self._sqlalchemy_enum_type(dialect, column_name=name) + return column + return super().sqlalchemy_column(name, dialect) + def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: + if self.sqlalchemy_use_enum: + column_name = self._name or None + return self._sqlalchemy_enum_type(dialect, column_name=column_name) category_lengths = [len(c) for c in self.categories] if all(length == category_lengths[0] for length in category_lengths): return sa.CHAR(category_lengths[0]) return sa.String(max(category_lengths)) + def _sqlalchemy_enum_type( + self, _dialect: sa.Dialect, *, column_name: str | None + ) -> sa_TypeEngine: + length = max(len(c) for c in self.categories) + kwargs: dict[str, Any] = {"length": length} + name = self.sqlalchemy_enum_name + if self._enum_class is not None: + if name is not None: + kwargs["name"] = name + kwargs["values_callable"] = lambda enum: [member.value for member in enum] + return sa.Enum(self._enum_class, **kwargs) + if name is None: + name = column_name + if name is None: + raise ValueError( + "sqlalchemy_enum_name is required for dy.Enum with string categories " + "and sqlalchemy_use_enum=True when not building columns via " + "Schema.to_sqlalchemy_columns(). 
Alternatively, pass a Python " + "enum.Enum class as categories." + ) + kwargs["name"] = name + return sa.Enum(*self.categories, **kwargs) + @property def pyarrow_dtype(self) -> pa.DataType: if len(self.categories) <= 2**8 - 1: diff --git a/docs/guides/features/sql-generation.md b/docs/guides/features/sql-generation.md index e84d2fd..696e01e 100644 --- a/docs/guides/features/sql-generation.md +++ b/docs/guides/features/sql-generation.md @@ -81,6 +81,27 @@ the maximal length of the string is inferred from the regular expression if poss maximal lengths can be particularly important for primary key columns. Some database systems, such as Microsoft SQL Server, do not allow `VARCHAR(max)` columns (unbounded strings) to be used as primary keys. ``` +## Native SQL enums (optional) + +By default, {class}`~dataframely.Enum` maps to a plain string column so stored values remain plain strings: a fixed-length `CHAR` column when all categories have the same length, and a `VARCHAR` column bounded by the longest category otherwise. For PostgreSQL setups that use database-level `ENUM` types (for example with Alembic autogenerate), set `sqlalchemy_use_enum=True`: + +```python +from enum import Enum + +import dataframely as dy + + +class Status(str, Enum): + PENDING = "pending" + APPROVED = "approved" + + +class Staged(dy.Schema): + status = dy.Enum(Status, sqlalchemy_use_enum=True) +``` + +When `categories` is a Python `enum.Enum` subclass, SQLAlchemy uses the enum class name (lowercased) as the database enum type name. For string category lists, the SQL column name is used by default; override it with `sqlalchemy_enum_name` if needed. On dialects without native enums (such as Microsoft SQL Server), SQLAlchemy falls back to a `VARCHAR` column sized to the longest category. Note that since SQLAlchemy 1.4 no `CHECK` constraint on the allowed values is generated for such columns by default, so the database itself does not reject other strings. + ## Collections of multiple tables If you have an entire `dy.Collection`, it's also easy to generate one table for each member table of the collection. 
diff --git a/tests/column_types/test_enum.py b/tests/column_types/test_enum.py index 85078a3..9de0103 100644 --- a/tests/column_types/test_enum.py +++ b/tests/column_types/test_enum.py @@ -108,3 +108,43 @@ def test_sequences_and_enums( S = create_schema("test", {"x": dy.Enum(categories1)}) df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(categories2))}) S.validate(df) + + +def test_matches_sqlalchemy_use_enum() -> None: + expr = pl.element() + assert dy.Enum(["a", "b"]).matches(dy.Enum(["a", "b"]), expr) + assert not dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches( + dy.Enum(["a", "b"]), expr + ) + assert dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches( + dy.Enum(["a", "b"], sqlalchemy_use_enum=True), expr + ) + + +def test_matches_sqlalchemy_enum_name() -> None: + expr = pl.element() + assert not dy.Enum( + ["a", "b"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="one", + ).matches( + dy.Enum( + ["a", "b"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="two", + ), + expr, + ) + + +def test_as_dict_from_dict_sqlalchemy_enum_flags() -> None: + column = dy.Enum( + ["a", "b"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="my_enum", + ) + data = column.as_dict(pl.element()) + restored = dy.Enum.from_dict(data) + assert restored.sqlalchemy_use_enum is True + assert restored.sqlalchemy_enum_name == "my_enum" + assert restored.categories == ["a", "b"] diff --git a/tests/columns/test_sqlalchemy_columns.py b/tests/columns/test_sqlalchemy_columns.py index 6731202..9cd8fcf 100644 --- a/tests/columns/test_sqlalchemy_columns.py +++ b/tests/columns/test_sqlalchemy_columns.py @@ -1,6 +1,8 @@ # Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause +from enum import Enum + import pytest import dataframely as dy @@ -171,3 +173,64 @@ def test_raise_for_object_column(dialect: Dialect) -> None: NotImplementedError, match="SQL column cannot have 'Object' type." 
): dy.Object().sqlalchemy_dtype(dialect) + + +class _Status(str, Enum): + PENDING = "pending" + APPROVED = "approved" + + +@pytest.mark.parametrize( + ("column", "dialect", "datatype"), + [ + ( + dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True), + PGDialect_psycopg2(), + "a", + ), + ( + dy.Enum( + ["foo", "bar"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="my_status", + ), + PGDialect_psycopg2(), + "my_status", + ), + (dy.Enum(_Status, sqlalchemy_use_enum=True), PGDialect_psycopg2(), "_status"), + ( + dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True), + MSDialect_pyodbc(), + "VARCHAR(3)", + ), + ], +) +def test_enum_sqlalchemy_native(column: Column, dialect: Dialect, datatype: str) -> None: + schema = create_schema("test", {"a": column}) + columns = schema.to_sqlalchemy_columns(dialect) + assert len(columns) == 1 + assert columns[0].type.compile(dialect) == datatype + + +def test_enum_sqlalchemy_native_python_enum_uses_member_values() -> None: + column = dy.Enum(_Status, sqlalchemy_use_enum=True) + schema = create_schema("test", {"a": column}) + sa_type = schema.to_sqlalchemy_columns(PGDialect_psycopg2())[0].type + assert list(sa_type.enums) == column.categories + + +def test_enum_sqlalchemy_native_string_categories_use_column_name() -> None: + class TestSchema(dy.Schema): + status = dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True) + + column = TestSchema.columns()["status"] + assert column.sqlalchemy_dtype(PGDialect_psycopg2()).compile( + PGDialect_psycopg2() + ) == "status" + + +def test_enum_sqlalchemy_native_string_categories_requires_name_without_column( +) -> None: + column = dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True) + with pytest.raises(ValueError, match="sqlalchemy_enum_name is required"): + column.sqlalchemy_dtype(PGDialect_psycopg2())