diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py index 4ed89f9d..2bac9fc2 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py @@ -102,6 +102,20 @@ async def execute(self, query: str, bindings: list | None = None): return result + async def execute_statement(self, statement): + """Execute a SQLAlchemy core construct (not a raw SQL string). + + Used for dialect-specific statements such as upserts where the SQL is + built with SQLAlchemy's dialect helpers rather than the grammar. + """ + conn = await self.get_connection() + result = await conn.execute(statement) + + if not self.transactions: + await conn.commit() + + return result + async def insert(self, query: str, bindings: list | None = None) -> int | None: result = await self.execute(query, bindings) diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py index 2f077a90..04c7a3e1 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py @@ -308,6 +308,63 @@ async def insert(self, values: dict | list) -> int | None: bindings = [val for row in values for val in row.values()] return await self.connection.insert(sql, bindings) + async def upsert( + self, + values: dict | list[dict], + unique_by: str | list[str], + update: str | list[str] | None = None, + ) -> int: + """Insert rows in a single statement, updating on a unique-key collision. + + Mirrors Eloquent's ``upsert``: every row is inserted in one bulk + statement; when a row collides on ``unique_by`` the columns in + ``update`` are updated instead. When ``update`` is omitted, every + non-unique column is updated. The SQL is dialect-specific (Postgres / + SQLite ``ON CONFLICT`` and MySQL ``ON DUPLICATE KEY UPDATE``). + """ + from fastapi_startkit.masoniteorm.query.upsert import build_upsert_statement + + if not values: + return 0 + + if isinstance(values, dict): + values = [values] + if isinstance(unique_by, str): + unique_by = [unique_by] + if isinstance(update, str): + update = [update] + + rows = [self._cast_upsert_row(row) for row in self._normalize_upsert_rows(values)] + + if update is None: + update = [column for column in rows[0] if column not in unique_by] + + statement = build_upsert_statement( + self.connection.config["driver"], + self.get_table_name(), + rows, + unique_by, + update, + ) + result = await self.connection.execute_statement(statement) + return result.rowcount + + def _normalize_upsert_rows(self, values: list[dict]) -> list[dict]: + """Give every row an identical, deterministically ordered column set.""" + columns: list[str] = [] + for row in values: + for key in row: + if key not in columns: + columns.append(key) + return [{column: row.get(column) for column in columns} for row in values] + + def _cast_upsert_row(self, row: dict) -> dict: + model = getattr(self, "_model", None) + caster = getattr(model, "caster", None) + if caster is None: + return dict(row) + return {key: caster.set(key, value) for key, value in row.items()} + async def insert_get_id( self, values: dict[str, Any] | list[dict[str, Any]], diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py index 2e6d377f..3084fce6 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py @@ -316,6 +316,15 @@ async def create(cls, attributes: dict): return instance + @classmethod + async def upsert( + cls, + values: dict | list[dict], + unique_by: str | list[str], + update: str | list[str] | None = None, + ) -> int: + return await cls.query().upsert(values, unique_by, update) + async def update(self, attributes: dict) -> bool: if not self._exists: return False diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/upsert.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/upsert.py new file mode 100644 index 00000000..3215f52a --- /dev/null +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/query/upsert.py @@ -0,0 +1,46 @@ +from sqlalchemy import column +from sqlalchemy import table as sa_table +from sqlalchemy.dialects.mysql import insert as mysql_insert +from sqlalchemy.dialects.postgresql import insert as postgres_insert +from sqlalchemy.dialects.sqlite import insert as sqlite_insert + +# Each supported driver maps to its dialect-specific INSERT builder so the +# generated statement can express the right "on conflict" upsert clause. +_INSERTERS = { + "postgres": postgres_insert, + "sqlite": sqlite_insert, + "mysql": mysql_insert, +} + + +def build_upsert_statement(driver, table_name, rows, unique_by, update_columns): + """Build a single bulk INSERT ... ON CONFLICT/DUPLICATE KEY statement. + + `rows` must already be normalised to share an identical column set so the + multi-row VALUES clause is well formed. + """ + inserter = _INSERTERS.get(driver) + if inserter is None: + raise ValueError(f"upsert() is not supported for the '{driver}' driver") + + columns = list(rows[0].keys()) + target = sa_table(table_name, *[column(name) for name in columns]) + statement = inserter(target).values(rows) + + if driver == "mysql": + if update_columns: + return statement.on_duplicate_key_update( + {name: statement.inserted[name] for name in update_columns} + ) + # No columns to update: keep a unique key untouched to make it a no-op. + return statement.on_duplicate_key_update( + {unique_by[0]: statement.inserted[unique_by[0]]} + ) + + if update_columns: + return statement.on_conflict_do_update( + index_elements=unique_by, + set_={name: statement.excluded[name] for name in update_columns}, + ) + + return statement.on_conflict_do_nothing(index_elements=unique_by) diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_builder_upsert.py b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_builder_upsert.py new file mode 100644 index 00000000..070342b8 --- /dev/null +++ b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_builder_upsert.py @@ -0,0 +1,77 @@ +from ...fixtures.model import User +from ..test_case import TestCase + + +class TestQueryBuilderUpsert(TestCase): + async def test_upsert_inserts_new_rows(self): + await User.query().upsert( + [ + {"email": "new1@test.com", "name": "New One", "is_admin": False}, + {"email": "new2@test.com", "name": "New Two", "is_admin": False}, + ], + unique_by=["email"], + ) + + one = await User.where("email", "new1@test.com").first() + two = await User.where("email", "new2@test.com").first() + assert one.name == "New One" + assert two.name == "New Two" + + async def test_upsert_accepts_a_single_dict(self): + await User.query().upsert( + {"email": "solo@test.com", "name": "Solo", "is_admin": False}, + unique_by="email", + ) + + user = await User.where("email", "solo@test.com").first() + assert user.name == "Solo" + + async def test_upsert_updates_on_conflict_without_duplicating(self): + await User.query().upsert( + {"email": "admin@admin.com", "name": "Renamed Admin", "is_admin": True}, + unique_by="email", + ) + + matches = await User.where("email", "admin@admin.com").get() + assert len(matches) == 1 + assert matches.first().name == "Renamed Admin" + + async def test_upsert_only_updates_listed_columns(self): + await User.query().upsert( + {"email": "admin@admin.com", "name": "Partial Update", "is_admin": False}, + unique_by="email", + update=["name"], + ) + + user = await User.where("email", "admin@admin.com").first() + assert user.name == "Partial Update" + # is_admin was excluded from `update`, so the seeded True is preserved. + assert user.is_admin is True + + async def test_upsert_without_update_list_updates_all_non_unique_columns(self): + await User.query().upsert( + {"email": "admin@admin.com", "name": "Full Update", "is_admin": False}, + unique_by="email", + ) + + user = await User.where("email", "admin@admin.com").first() + assert user.name == "Full Update" + assert user.is_admin is False + + async def test_upsert_mixes_insert_and_update_in_one_call(self): + await User.upsert( + [ + {"email": "admin@admin.com", "name": "Existing Updated", "is_admin": True}, + {"email": "fresh@test.com", "name": "Fresh Insert", "is_admin": False}, + ], + unique_by="email", + update=["name"], + ) + + existing = await User.where("email", "admin@admin.com").first() + fresh = await User.where("email", "fresh@test.com").first() + assert existing.name == "Existing Updated" + assert fresh.name == "Fresh Insert" + + async def test_upsert_returns_zero_for_empty_values(self): + assert await User.query().upsert([], unique_by="email") == 0 diff --git a/fastapi_startkit/tests/masoniteorm/test_upsert_statement.py b/fastapi_startkit/tests/masoniteorm/test_upsert_statement.py new file mode 100644 index 00000000..2fa977bc --- /dev/null +++ b/fastapi_startkit/tests/masoniteorm/test_upsert_statement.py @@ -0,0 +1,47 @@ +from unittest import TestCase + +from sqlalchemy.dialects import mysql, postgresql, sqlite + +from fastapi_startkit.masoniteorm.query.upsert import build_upsert_statement + +ROWS = [ + {"email": "a@a.com", "name": "A"}, + {"email": "b@b.com", "name": "B"}, +] + + +def compiled(statement, dialect): + return str(statement.compile(dialect=dialect)) + + +class TestBuildUpsertStatement(TestCase): + def test_postgres_uses_on_conflict_do_update(self): + statement = build_upsert_statement("postgres", "users", ROWS, ["email"], ["name"]) + sql = compiled(statement, postgresql.dialect()) + self.assertIn("ON CONFLICT (email) DO UPDATE SET name = excluded.name", sql) + + def test_sqlite_uses_on_conflict_do_update(self): + statement = build_upsert_statement("sqlite", "users", ROWS, ["email"], ["name"]) + sql = compiled(statement, sqlite.dialect()) + self.assertIn("ON CONFLICT (email) DO UPDATE SET name = excluded.name", sql) + + def test_mysql_uses_on_duplicate_key_update(self): + statement = build_upsert_statement("mysql", "users", ROWS, ["email"], ["name"]) + sql = compiled(statement, mysql.dialect()) + self.assertIn("ON DUPLICATE KEY UPDATE name = VALUES(name)", sql) + + def test_bulk_values_are_a_single_statement(self): + statement = build_upsert_statement("postgres", "users", ROWS, ["email"], ["name"]) + sql = compiled(statement, postgresql.dialect()) + # Two value tuples → one multi-row INSERT, not a loop. + self.assertEqual(sql.count("INSERT INTO"), 1) + self.assertIn("VALUES", sql) + + def test_empty_update_list_falls_back_to_do_nothing(self): + statement = build_upsert_statement("postgres", "users", ROWS, ["email"], []) + sql = compiled(statement, postgresql.dialect()) + self.assertIn("ON CONFLICT (email) DO NOTHING", sql) + + def test_unsupported_driver_raises(self): + with self.assertRaises(ValueError): + build_upsert_statement("oracle", "users", ROWS, ["email"], ["name"])