From 9fc7cf7208cd827a8003ff734c2ba50eff9a4ce9 Mon Sep 17 00:00:00 2001
From: Bedram Tamang
Date: Thu, 18 Jun 2026 17:15:29 -0700
Subject: [PATCH] feat(orm): add async Eloquent-style upsert()
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add an async upsert() that emits a SINGLE bulk statement and is
dialect-correct across postgres, sqlite, and mysql using SQLAlchemy dialect
insert helpers:

- postgres/sqlite: INSERT ... ON CONFLICT (unique_by) DO UPDATE SET col = excluded.col
- mysql: INSERT ... ON DUPLICATE KEY UPDATE col = VALUES(col)

Exposed on both Model (async classmethod) and QueryBuilder. Signature mirrors
Eloquent: upsert(values, unique_by, update=None) where update defaults to all
provided columns except the unique_by keys.

Casts are applied per-row via the model Caster (same set-cast behavior as the
insert path).

Timestamps stay consistent with the current insert() path — no auto-injection
on INSERT; explicit created_at/updated_at in a row dict are honored;
updated_at is refreshed on the DO UPDATE branch only when __timestamps__ is
enabled and the model declares updated_at.

Execution routes through a new Connection.upsert() that runs the prepared
SQLAlchemy construct and returns the affected rowcount.

Tests per dialect: live sqlite (insert/update/subset/timestamp + single ON
CONFLICT statement), mysql SQL-compile assertions (no live server), and live
postgres (skippable via --ignore=tests/masoniteorm/postgres).

Co-Authored-By: Claude Opus 4.8
---
 .../masoniteorm/connections/connection.py     |  15 +++
 .../masoniteorm/models/builder.py             | 127 +++++++++++++++++++++++++++++++++++++
 .../masoniteorm/models/model.py               |  14 +++
 .../builder/test_mysql_builder_upsert.py      |  59 +++++++++++
 .../postgres/models/test_upsert.py            |  54 +++++++++++
 .../builder/test_sqlite_builder_upsert.py     |  79 +++++++++++++++
 6 files changed, 348 insertions(+)
 create mode 100644 fastapi_startkit/tests/masoniteorm/mysql/builder/test_mysql_builder_upsert.py
 create mode 100644 fastapi_startkit/tests/masoniteorm/postgres/models/test_upsert.py
 create mode 100644 fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_builder_upsert.py

diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py
index 4ed89f9d..a6a31e02 100644
--- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py
+++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/connections/connection.py
@@ -111,6 +111,21 @@ async def insert_get_id(self, query: str, bindings: list | None = None) -> int |
         result = await self.execute(query, bindings)
         return getattr(result, "lastrowid", None)
 
+    async def upsert(self, statement) -> int:
+        """Execute a prepared SQLAlchemy upsert construct.
+
+        ``statement`` is a fully-bound Core ``Insert`` construct (the dialect
+        specific ON CONFLICT / ON DUPLICATE KEY clause is already attached), so
+        no extra bindings are needed. Returns the number of affected rows.
+        """
+        conn = await self.get_connection()
+        result = await conn.execute(statement)
+
+        if not self.transactions:
+            await conn.commit()
+
+        return result.rowcount
+
     async def update(self, query: str, bindings: list | None = None) -> int:
         result = await self.execute(query, bindings)
 
diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py
index 2f077a90..cffe0f21 100644
--- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py
+++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py
@@ -1,6 +1,8 @@
 import inspect
 from typing import TYPE_CHECKING, Any
 
+import pendulum
+
 from fastapi_startkit.masoniteorm.expressions.expressions import (
     JoinClause,
     QueryExpression,
@@ -318,6 +320,131 @@ async def insert_get_id(
 
         return await self.connection.insert_get_id(sql, bindings)
 
+    async def upsert(
+        self,
+        values: dict | list,
+        unique_by: str | list,
+        update: str | list | None = None,
+    ) -> int:
+        """Eloquent-style bulk upsert emitting a single dialect-correct statement.
+
+        - postgres / sqlite: ``INSERT ... ON CONFLICT (unique_by) DO UPDATE SET ...``
+        - mysql / mariadb: ``INSERT ... ON DUPLICATE KEY UPDATE ...``
+
+        Arguments:
+            values -- A dict or list of dicts to insert or update.
+            unique_by -- Column name(s) used to detect a conflict.
+            update -- Column name(s) to write on conflict. Defaults to every
+                provided column except the ``unique_by`` keys.
+
+        Returns:
+            The affected row count as reported by the driver, or ``0`` when
+            ``values`` is empty (no statement is executed).
+
+        Raises:
+            ValueError -- If ``unique_by`` is empty.
+            NotImplementedError -- If the connection dialect is unsupported.
+        """
+        self.set_action("bulk_create")
+
+        if not values:
+            return 0
+
+        if isinstance(values, dict):
+            values = [values]
+
+        if isinstance(unique_by, str):
+            unique_by = [unique_by]
+
+        # A bare string would otherwise be iterated character by character.
+        if isinstance(update, str):
+            update = [update]
+
+        # No conflict target means no valid ON CONFLICT / no-op fallback column.
+        if not unique_by:
+            raise ValueError("upsert() requires at least one unique_by column.")
+
+        statement = self._build_upsert_statement(values, unique_by, update)
+        return await self.connection.upsert(statement)
+
+    def _build_upsert_statement(self, values: list, unique_by: list, update: list | None):
+        """Build a dialect-specific SQLAlchemy upsert construct for ``values``.
+
+        Arguments:
+            values -- List of row dicts; rows may carry different key sets.
+            unique_by -- Conflict-target column names.
+            update -- Columns to write on conflict, or ``None`` for every
+                provided column except the ``unique_by`` keys.
+
+        Returns:
+            A SQLAlchemy ``Insert`` construct with the dialect's conflict
+            clause attached.
+
+        Raises:
+            NotImplementedError -- If the dialect is not mysql/mariadb,
+                postgres or sqlite.
+        """
+        from sqlalchemy import column as sa_column, table as sa_table
+
+        caster = getattr(self._model, "caster", None)
+        casts = getattr(caster, "casts", {}) or {}
+
+        # Apply set-casts the same way the insert path does (get_attributes_for_insert).
+        def cast_row(row: dict) -> dict:
+            return {key: caster.set(key, value) if key in casts else value for key, value in row.items()}
+
+        rows = [cast_row(row) for row in values]
+
+        # SQLAlchemy needs a consistent column set across every row for a single
+        # multi-row VALUES clause; missing keys default to NULL.
+        columns = sorted({key for row in rows for key in row})
+        rows = [{key: row.get(key) for key in columns} for row in rows]
+
+        # Default the update set to everything but the conflict keys.
+        candidate_update = columns if update is None else update
+        update_columns = [col for col in candidate_update if col not in unique_by]
+
+        # Refresh updated_at on the UPDATE branch only — keeps the INSERT branch
+        # consistent with the current insert() path, which never injects timestamps.
+        refresh_updated_at = bool(getattr(self._model, "__timestamps__", False)) and "updated_at" in casts
+
+        # Every column the statement references must exist on the table construct.
+        table_columns = list(dict.fromkeys([*columns, *update_columns]))
+        if refresh_updated_at and "updated_at" not in table_columns:
+            table_columns.append("updated_at")
+
+        target = sa_table(self._table, *[sa_column(name) for name in table_columns])
+        dialect = self.connection.engine.dialect.name
+
+        # SQLAlchemy reports MariaDB as "mariadb"; it shares the MySQL insert
+        # construct and its ON DUPLICATE KEY UPDATE clause.
+        if dialect in ("mysql", "mariadb"):
+            from sqlalchemy.dialects.mysql import insert as mysql_insert
+
+            statement = mysql_insert(target).values(rows)
+            set_ = {col: statement.inserted[col] for col in update_columns}
+            if refresh_updated_at:
+                set_["updated_at"] = caster.set("updated_at", pendulum.now("UTC"))
+            if not set_:
+                # Nothing to update — emulate a no-op so the statement stays valid.
+                set_ = {unique_by[0]: statement.inserted[unique_by[0]]}
+            return statement.on_duplicate_key_update(set_)
+
+        if dialect in ("postgresql", "postgres"):
+            from sqlalchemy.dialects.postgresql import insert as conflict_insert
+        elif dialect == "sqlite":
+            from sqlalchemy.dialects.sqlite import insert as conflict_insert
+        else:
+            raise NotImplementedError(f"upsert() is not supported for the '{dialect}' driver.")
+
+        statement = conflict_insert(target).values(rows)
+        set_ = {col: statement.excluded[col] for col in update_columns}
+        if refresh_updated_at:
+            set_["updated_at"] = caster.set("updated_at", pendulum.now("UTC"))
+        if not set_:
+            return statement.on_conflict_do_nothing(index_elements=unique_by)
+        return statement.on_conflict_do_update(index_elements=unique_by, set_=set_)
+
     async def update(self, values: dict) -> int:
         updates = [UpdateQueryExpression(col, val) for col, val in values.items()]
         grammar = self.grammar()
diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py
index 2e6d377f..d71e1ae9 100644
--- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py
+++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py
@@ -309,6 +309,20 @@ async def first_or_create(cls, search: dict, attributes: dict | None = None) ->
     async def update_or_create(cls, search: dict, attributes: dict | None = None) -> "Model":
         return await cls.query().update_or_create(search, attributes)
 
+    @classmethod
+    async def upsert(
+        cls,
+        values: dict | list,
+        unique_by: str | list,
+        update: str | list | None = None,
+    ) -> int:
+        """Bulk insert ``values``, updating existing rows on a ``unique_by`` conflict.
+
+        Emits a single dialect-correct statement (postgres/sqlite ON CONFLICT,
+        mysql ON DUPLICATE KEY UPDATE). Returns the number of affected rows.
+        """
+        return await cls.query().upsert(values, unique_by, update)
+
     @classmethod
     async def create(cls, attributes: dict):
         instance = cls().new_model_instance(attributes)
diff --git a/fastapi_startkit/tests/masoniteorm/mysql/builder/test_mysql_builder_upsert.py b/fastapi_startkit/tests/masoniteorm/mysql/builder/test_mysql_builder_upsert.py
new file mode 100644
index 00000000..5b904532
--- /dev/null
+++ b/fastapi_startkit/tests/masoniteorm/mysql/builder/test_mysql_builder_upsert.py
@@ -0,0 +1,59 @@
+from unittest.mock import AsyncMock, patch
+
+from sqlalchemy.dialects import mysql
+
+from ..test_case import TestCase
+from ..fixtures.db import DB
+from ...fixtures.model import User
+
+
+class TestMySQLQueryBuilderUpsert(TestCase):
+    """MySQL upsert compiles to a single ON DUPLICATE KEY UPDATE statement.
+
+    Mirrors the other MySQL builder tests, which assert generated SQL without a
+    live server.
+    """
+
+    def _compile(self, statement):
+        return str(statement.compile(dialect=mysql.dialect()))
+
+    async def test_upsert_compiles_on_duplicate_key_update(self):
+        statement = User.query()._build_upsert_statement(
+            [
+                {"email": "a@test.com", "name": "A", "is_admin": False},
+                {"email": "b@test.com", "name": "B", "is_admin": True},
+            ],
+            ["email"],
+            None,
+        )
+
+        compiled = self._compile(statement)
+        # Single multi-row INSERT with a MySQL-style ON DUPLICATE KEY UPDATE.
+        assert compiled.count("INSERT INTO") == 1
+        assert "ON DUPLICATE KEY UPDATE" in compiled
+        assert "name = VALUES(name)" in compiled
+        assert "is_admin = VALUES(is_admin)" in compiled
+        # The conflict key is never part of the update set.
+        assert "email = VALUES(email)" not in compiled
+
+    async def test_upsert_respects_explicit_update_columns(self):
+        statement = User.query()._build_upsert_statement(
+            [{"email": "a@test.com", "name": "A", "is_admin": False}],
+            ["email"],
+            ["name"],
+        )
+
+        compiled = self._compile(statement)
+        assert "name = VALUES(name)" in compiled
+        assert "is_admin = VALUES(is_admin)" not in compiled
+
+    async def test_upsert_invokes_connection_once(self):
+        captured = AsyncMock(return_value=2)
+        # Scoped patch so the mock is removed from the connection afterwards.
+        with patch.object(DB.connection("mysql"), "upsert", captured):
+            await User.query().upsert(
+                [{"email": "a@test.com", "name": "A"}, {"email": "b@test.com", "name": "B"}],
+                unique_by="email",
+            )
+
+        captured.assert_called_once()
diff --git a/fastapi_startkit/tests/masoniteorm/postgres/models/test_upsert.py b/fastapi_startkit/tests/masoniteorm/postgres/models/test_upsert.py
new file mode 100644
index 00000000..511df778
--- /dev/null
+++ b/fastapi_startkit/tests/masoniteorm/postgres/models/test_upsert.py
@@ -0,0 +1,54 @@
+from ..test_case import TestCase
+from ...fixtures.model import User
+
+
+class TestPostgresUpsert(TestCase):
+    async def test_upsert_inserts_new_records(self):
+        affected = await User.upsert(
+            [
+                {"email": "u1@example.com", "name": "U1", "is_admin": False},
+                {"email": "u2@example.com", "name": "U2", "is_admin": True},
+            ],
+            unique_by="email",
+        )
+        self.assertGreaterEqual(affected, 1)
+
+        u1 = await User.where("email", "u1@example.com").first()
+        u2 = await User.where("email", "u2@example.com").first()
+        self.assertEqual(u1.name, "U1")
+        self.assertEqual(u2.name, "U2")
+        self.assertTrue(u2.is_admin)
+
+    async def test_upsert_updates_existing_on_conflict(self):
+        await User.upsert({"email": "dup@example.com", "name": "Original", "is_admin": False}, unique_by="email")
+        await User.upsert({"email": "dup@example.com", "name": "Updated", "is_admin": True}, unique_by="email")
+
+        rows = await User.where("email", "dup@example.com").get()
+        self.assertEqual(len(rows), 1)
+        self.assertEqual(rows.first().name, "Updated")
+        self.assertTrue(rows.first().is_admin)
+
+    async def test_upsert_respects_update_subset(self):
+        await User.upsert({"email": "sub@example.com", "name": "Keep", "is_admin": False}, unique_by="email")
+        await User.upsert(
+            {"email": "sub@example.com", "name": "Ignored", "is_admin": True},
+            unique_by="email",
+            update=["is_admin"],
+        )
+
+        user = await User.where("email", "sub@example.com").first()
+        self.assertEqual(user.name, "Keep")
+        self.assertTrue(user.is_admin)
+
+    async def test_upsert_refreshes_updated_at_on_update_branch(self):
+        await User.upsert(
+            {"email": "ts@example.com", "name": "TS", "is_admin": False, "updated_at": "2000-01-01 00:00:00"},
+            unique_by="email",
+        )
+        before = await User.where("email", "ts@example.com").first()
+        self.assertEqual(before.updated_at.year, 2000)
+
+        await User.upsert({"email": "ts@example.com", "name": "TS2", "is_admin": False}, unique_by="email")
+        after = await User.where("email", "ts@example.com").first()
+        self.assertEqual(after.name, "TS2")
+        self.assertNotEqual(after.updated_at.year, 2000)
diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_builder_upsert.py b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_builder_upsert.py
new file mode 100644
index 00000000..571f62ca
--- /dev/null
+++ b/fastapi_startkit/tests/masoniteorm/sqlite/builder/test_sqlite_builder_upsert.py
@@ -0,0 +1,79 @@
+from unittest.mock import AsyncMock, patch
+
+from sqlalchemy.dialects import sqlite
+
+from ..fixtures.db import DB
+from ...fixtures.model import User
+from ..test_case import TestCase
+
+
+class TestQueryBuilderUpsert(TestCase):
+    async def test_upsert_inserts_new_records(self):
+        affected = await User.upsert(
+            [
+                {"email": "u1@test.com", "name": "U1", "is_admin": False},
+                {"email": "u2@test.com", "name": "U2", "is_admin": True},
+            ],
+            unique_by="email",
+        )
+        assert affected >= 1
+
+        u1 = await User.where("email", "u1@test.com").first()
+        u2 = await User.where("email", "u2@test.com").first()
+        assert u1.name == "U1"
+        assert u2.name == "U2"
+        assert u2.is_admin is True
+
+    async def test_upsert_updates_existing_on_conflict(self):
+        await User.upsert({"email": "dup@test.com", "name": "Original", "is_admin": False}, unique_by="email")
+        await User.upsert({"email": "dup@test.com", "name": "Updated", "is_admin": True}, unique_by="email")
+
+        rows = await User.where("email", "dup@test.com").get()
+        assert len(rows) == 1
+        assert rows.first().name == "Updated"
+        assert rows.first().is_admin is True
+
+    async def test_upsert_respects_update_subset(self):
+        await User.upsert({"email": "sub@test.com", "name": "Keep", "is_admin": False}, unique_by="email")
+        # Only ``is_admin`` is listed for update, so ``name`` must remain unchanged.
+        await User.upsert(
+            {"email": "sub@test.com", "name": "Ignored", "is_admin": True},
+            unique_by="email",
+            update=["is_admin"],
+        )
+
+        user = await User.where("email", "sub@test.com").first()
+        assert user.name == "Keep"
+        assert user.is_admin is True
+
+    async def test_upsert_refreshes_updated_at_on_update_branch(self):
+        # Explicit updated_at is honoured on the INSERT branch.
+        await User.upsert(
+            {"email": "ts@test.com", "name": "TS", "is_admin": False, "updated_at": "2000-01-01 00:00:00"},
+            unique_by="email",
+        )
+        before = await User.where("email", "ts@test.com").first()
+        assert before.updated_at.year == 2000
+
+        # __timestamps__ is enabled on User, so the UPDATE branch refreshes updated_at.
+        await User.upsert({"email": "ts@test.com", "name": "TS2", "is_admin": False}, unique_by="email")
+        after = await User.where("email", "ts@test.com").first()
+        assert after.name == "TS2"
+        assert after.updated_at.year != 2000
+
+    async def test_upsert_emits_single_on_conflict_statement(self):
+        captured = AsyncMock(return_value=2)
+        # Scoped patch so the mock is removed from the connection afterwards.
+        with patch.object(DB.connection("sqlite"), "upsert", captured):
+            await User.query().upsert(
+                [{"email": "a@test.com", "name": "A"}, {"email": "b@test.com", "name": "B"}],
+                unique_by="email",
+            )
+
+        captured.assert_called_once()
+        (statement,) = captured.call_args[0]
+        compiled = str(statement.compile(dialect=sqlite.dialect()))
+        # A single multi-row VALUES clause with a sqlite ON CONFLICT ... excluded update.
+        assert compiled.count("INSERT INTO") == 1
+        assert "ON CONFLICT (email)" in compiled
+        assert "excluded.name" in compiled