From 81b65aa6e1ce4d155e2f869df2456e8891ff4f4f Mon Sep 17 00:00:00 2001 From: Jack Oberman Date: Tue, 2 Jun 2026 18:52:57 -0400 Subject: [PATCH 1/2] feat: Add optional sqlalchemy_use_enum for dy.Enum Allow dy.Enum to emit sqlalchemy.Enum for to_sqlalchemy_columns when sqlalchemy_use_enum=True, with optional sqlalchemy_enum_name and column- name defaults for PostgreSQL native enum types. Closes #354 Co-authored-by: Cursor --- dataframely/columns/enum.py | 52 ++++++++++++++++++++++ docs/guides/features/sql-generation.md | 21 +++++++++ tests/column_types/test_enum.py | 40 +++++++++++++++++ tests/columns/test_sqlalchemy_columns.py | 56 ++++++++++++++++++++++++ 4 files changed, 169 insertions(+) diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py index 6c63f2c..e77fa6e 100644 --- a/dataframely/columns/enum.py +++ b/dataframely/columns/enum.py @@ -33,6 +33,8 @@ def __init__( alias: str | None = None, metadata: dict[str, Any] | None = None, description: str | None = None, + sqlalchemy_use_enum: bool = False, + sqlalchemy_enum_name: str | None = None, ): """ Args: @@ -68,6 +70,15 @@ def __init__( names, the specified alias is the only valid name. metadata: A dictionary of metadata to attach to the column. description: A human-readable description of the column. + sqlalchemy_use_enum: When ``True``, map this column to :class:`sqlalchemy.Enum` + in :meth:`~dataframely.Schema.to_sqlalchemy_columns` instead of + ``CHAR`` / ``VARCHAR``. Use this for PostgreSQL native enum types and + Alembic schema drift detection. Defaults to ``False`` (string columns). + sqlalchemy_enum_name: Optional name for the SQLAlchemy / database enum type + when ``sqlalchemy_use_enum=True``. If omitted and ``categories`` is a + Python :class:`enum.Enum` subclass, SQLAlchemy uses the enum class name + (lowercased). Otherwise the SQL column name from + :meth:`~dataframely.Schema.to_sqlalchemy_columns` is used. """ super().__init__( nullable=nullable, @@ -78,7 +89,11 @@ def __init__( metadata=metadata, description=description, ) + self.sqlalchemy_use_enum = sqlalchemy_use_enum + self.sqlalchemy_enum_name = sqlalchemy_enum_name + self._enum_class: type[enum.Enum] | None = None if isclass(categories) and issubclass(categories, enum.Enum): + self._enum_class = categories categories = (item.value for item in categories) self.categories = list(categories) @@ -91,12 +106,49 @@ def validate_dtype(self, dtype: PolarsDataType) -> bool: return False return self.categories == dtype.categories.to_list() + def sqlalchemy_column(self, name: str, dialect: sa.Dialect) -> sa.Column: + if self.sqlalchemy_use_enum: + return sa.Column( + name, + self._sqlalchemy_enum_type(dialect, column_name=name), + nullable=self.nullable, + primary_key=self.primary_key, + unique=self.unique, + autoincrement=False, + ) + return super().sqlalchemy_column(name, dialect) + def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: + if self.sqlalchemy_use_enum: + column_name = self._name or None + return self._sqlalchemy_enum_type(dialect, column_name=column_name) category_lengths = [len(c) for c in self.categories] if all(length == category_lengths[0] for length in category_lengths): return sa.CHAR(category_lengths[0]) return sa.String(max(category_lengths)) + def _sqlalchemy_enum_type( + self, _dialect: sa.Dialect, *, column_name: str | None + ) -> sa_TypeEngine: + length = max(len(c) for c in self.categories) + kwargs: dict[str, Any] = {"length": length} + name = self.sqlalchemy_enum_name + if self._enum_class is not None: + if name is not None: + kwargs["name"] = name + return sa.Enum(self._enum_class, **kwargs) + if name is None: + name = column_name + if name is None: + raise ValueError( + "sqlalchemy_enum_name is required for dy.Enum with string categories " + "and sqlalchemy_use_enum=True when not building columns via " + "Schema.to_sqlalchemy_columns(). Alternatively, pass a Python " + "enum.Enum class as categories." + ) + kwargs["name"] = name + return sa.Enum(*self.categories, **kwargs) + @property def pyarrow_dtype(self) -> pa.DataType: if len(self.categories) <= 2**8 - 1: diff --git a/docs/guides/features/sql-generation.md b/docs/guides/features/sql-generation.md index e84d2fd..fe7b11b 100644 --- a/docs/guides/features/sql-generation.md +++ b/docs/guides/features/sql-generation.md @@ -81,6 +81,27 @@ the maximal length of the string is inferred from the regular expression if poss maximal lengths can be particularly important for primary key columns. Some database systems, such as Microsoft SQL Server, do not allow `VARCHAR(max)` columns (unbounded strings) to be used as primary keys. ``` +## Native SQL enums (optional) + +By default, {class}`~dataframely.Enum` maps to fixed-length `CHAR` or `VARCHAR` columns so stored values remain plain strings. For PostgreSQL setups that use database-level `ENUM` types (for example with Alembic autogenerate), set `sqlalchemy_use_enum=True`: + +```python +from enum import StrEnum + +import dataframely as dy + + +class Status(StrEnum): + PENDING = "pending" + APPROVED = "approved" + + +class Staged(dy.Schema): + status = dy.Enum(Status, sqlalchemy_use_enum=True) +``` + +When `categories` is a Python `enum.Enum` subclass, SQLAlchemy uses the enum class name (lowercased) as the database enum type name. For string category lists, the SQL column name is used by default; override it with `sqlalchemy_enum_name` if needed. On dialects without native enums (such as Microsoft SQL Server), SQLAlchemy falls back to `VARCHAR` with a check constraint. + ## Collections of multiple tables If you have an entire `dy.Collection`, it's also easy to generate one table for each member table of the collection. diff --git a/tests/column_types/test_enum.py b/tests/column_types/test_enum.py index 85078a3..9de0103 100644 --- a/tests/column_types/test_enum.py +++ b/tests/column_types/test_enum.py @@ -108,3 +108,43 @@ def test_sequences_and_enums( S = create_schema("test", {"x": dy.Enum(categories1)}) df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(categories2))}) S.validate(df) + + +def test_matches_sqlalchemy_use_enum() -> None: + expr = pl.element() + assert dy.Enum(["a", "b"]).matches(dy.Enum(["a", "b"]), expr) + assert not dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches( + dy.Enum(["a", "b"]), expr + ) + assert dy.Enum(["a", "b"], sqlalchemy_use_enum=True).matches( + dy.Enum(["a", "b"], sqlalchemy_use_enum=True), expr + ) + + +def test_matches_sqlalchemy_enum_name() -> None: + expr = pl.element() + assert not dy.Enum( + ["a", "b"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="one", + ).matches( + dy.Enum( + ["a", "b"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="two", + ), + expr, + ) + + +def test_as_dict_from_dict_sqlalchemy_enum_flags() -> None: + column = dy.Enum( + ["a", "b"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="my_enum", + ) + data = column.as_dict(pl.element()) + restored = dy.Enum.from_dict(data) + assert restored.sqlalchemy_use_enum is True + assert restored.sqlalchemy_enum_name == "my_enum" + assert restored.categories == ["a", "b"] diff --git a/tests/columns/test_sqlalchemy_columns.py b/tests/columns/test_sqlalchemy_columns.py index 6731202..af25070 100644 --- a/tests/columns/test_sqlalchemy_columns.py +++ b/tests/columns/test_sqlalchemy_columns.py @@ -1,6 +1,8 @@ # Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause +from enum import Enum + import pytest import dataframely as dy @@ -171,3 +173,57 @@ def test_raise_for_object_column(dialect: Dialect) -> None: NotImplementedError, match="SQL column cannot have 'Object' type." ): dy.Object().sqlalchemy_dtype(dialect) + + +class _Status(str, Enum): + PENDING = "pending" + APPROVED = "approved" + + +@pytest.mark.parametrize( + ("column", "dialect", "datatype"), + [ + ( + dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True), + PGDialect_psycopg2(), + "a", + ), + ( + dy.Enum( + ["foo", "bar"], + sqlalchemy_use_enum=True, + sqlalchemy_enum_name="my_status", + ), + PGDialect_psycopg2(), + "my_status", + ), + (dy.Enum(_Status, sqlalchemy_use_enum=True), PGDialect_psycopg2(), "_status"), + ( + dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True), + MSDialect_pyodbc(), + "VARCHAR(3)", + ), + ], +) +def test_enum_sqlalchemy_native(column: Column, dialect: Dialect, datatype: str) -> None: + schema = create_schema("test", {"a": column}) + columns = schema.to_sqlalchemy_columns(dialect) + assert len(columns) == 1 + assert columns[0].type.compile(dialect) == datatype + + +def test_enum_sqlalchemy_native_string_categories_use_column_name() -> None: + class TestSchema(dy.Schema): + status = dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True) + + column = TestSchema.columns()["status"] + assert column.sqlalchemy_dtype(PGDialect_psycopg2()).compile( + PGDialect_psycopg2() + ) == "status" + + +def test_enum_sqlalchemy_native_string_categories_requires_name_without_column( +) -> None: + column = dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True) + with pytest.raises(ValueError, match="sqlalchemy_enum_name is required"): + column.sqlalchemy_dtype(PGDialect_psycopg2()) From 58e0727d39439f8a964bc869509837bb792f3133 Mon Sep 17 00:00:00 2001 From: Jack Oberman Date: Tue, 2 Jun 2026 19:01:31 -0400 Subject: [PATCH 2/2] fix: Align SQLAlchemy Enum values with dy.Enum categories --- dataframely/columns/enum.py | 16 +++++++--------- docs/guides/features/sql-generation.md | 4 ++-- tests/columns/test_sqlalchemy_columns.py | 7 +++++++ 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py index e77fa6e..f1d2699 100644 --- a/dataframely/columns/enum.py +++ b/dataframely/columns/enum.py @@ -78,7 +78,9 @@ def __init__( when ``sqlalchemy_use_enum=True``. If omitted and ``categories`` is a Python :class:`enum.Enum` subclass, SQLAlchemy uses the enum class name (lowercased). Otherwise the SQL column name from - :meth:`~dataframely.Schema.to_sqlalchemy_columns` is used. + :meth:`~dataframely.Schema.to_sqlalchemy_columns` is used. For Python + enums, persisted values are the enum members' ``.value`` strings (not + member names), matching :attr:`categories`. """ super().__init__( nullable=nullable, @@ -108,14 +110,9 @@ def validate_dtype(self, dtype: PolarsDataType) -> bool: def sqlalchemy_column(self, name: str, dialect: sa.Dialect) -> sa.Column: if self.sqlalchemy_use_enum: - return sa.Column( - name, - self._sqlalchemy_enum_type(dialect, column_name=name), - nullable=self.nullable, - primary_key=self.primary_key, - unique=self.unique, - autoincrement=False, - ) + column = super().sqlalchemy_column(name, dialect) + column.type = self._sqlalchemy_enum_type(dialect, column_name=name) + return column return super().sqlalchemy_column(name, dialect) def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: @@ -136,6 +133,7 @@ def _sqlalchemy_enum_type( if self._enum_class is not None: if name is not None: kwargs["name"] = name + kwargs["values_callable"] = lambda enum: [member.value for member in enum] return sa.Enum(self._enum_class, **kwargs) if name is None: name = column_name diff --git a/docs/guides/features/sql-generation.md b/docs/guides/features/sql-generation.md index fe7b11b..696e01e 100644 --- a/docs/guides/features/sql-generation.md +++ b/docs/guides/features/sql-generation.md @@ -86,12 +86,12 @@ maximal lengths can be particularly important for primary key columns. Some data By default, {class}`~dataframely.Enum` maps to fixed-length `CHAR` or `VARCHAR` columns so stored values remain plain strings. For PostgreSQL setups that use database-level `ENUM` types (for example with Alembic autogenerate), set `sqlalchemy_use_enum=True`: ```python -from enum import StrEnum +from enum import Enum import dataframely as dy -class Status(StrEnum): +class Status(str, Enum): PENDING = "pending" APPROVED = "approved" diff --git a/tests/columns/test_sqlalchemy_columns.py b/tests/columns/test_sqlalchemy_columns.py index af25070..9cd8fcf 100644 --- a/tests/columns/test_sqlalchemy_columns.py +++ b/tests/columns/test_sqlalchemy_columns.py @@ -212,6 +212,13 @@ def test_enum_sqlalchemy_native(column: Column, dialect: Dialect, datatype: str) assert columns[0].type.compile(dialect) == datatype +def test_enum_sqlalchemy_native_python_enum_uses_member_values() -> None: + column = dy.Enum(_Status, sqlalchemy_use_enum=True) + schema = create_schema("test", {"a": column}) + sa_type = schema.to_sqlalchemy_columns(PGDialect_psycopg2())[0].type + assert list(sa_type.enums) == column.categories + + def test_enum_sqlalchemy_native_string_categories_use_column_name() -> None: class TestSchema(dy.Schema): status = dy.Enum(["foo", "bar"], sqlalchemy_use_enum=True)