Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,21 @@ async def insert_get_id(self, query: str, bindings: list | None = None) -> int |
result = await self.execute(query, bindings)
return getattr(result, "lastrowid", None)

async def upsert(self, statement) -> int:
"""Execute a prepared SQLAlchemy upsert construct.

``statement`` is a fully-bound Core ``Insert`` construct (the dialect
specific ON CONFLICT / ON DUPLICATE KEY clause is already attached), so
no extra bindings are needed. Returns the number of affected rows.
"""
conn = await self.get_connection()
result = await conn.execute(statement)

if not self.transactions:
await conn.commit()

return result.rowcount

async def update(self, query: str, bindings: list | None = None) -> int:
result = await self.execute(query, bindings)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import inspect
from typing import TYPE_CHECKING, Any

import pendulum

from fastapi_startkit.masoniteorm.expressions.expressions import (
JoinClause,
QueryExpression,
Expand Down Expand Up @@ -318,6 +320,101 @@ async def insert_get_id(

return await self.connection.insert_get_id(sql, bindings)

async def upsert(
self,
values: dict | list,
unique_by: str | list,
update: list | None = None,
) -> int:
"""Eloquent-style bulk upsert emitting a single dialect-correct statement.

- postgres / sqlite: ``INSERT ... ON CONFLICT (unique_by) DO UPDATE SET ...``
- mysql: ``INSERT ... ON DUPLICATE KEY UPDATE ...``

Arguments:
values -- A dict or list of dicts to insert or update.
unique_by -- Column name(s) used to detect a conflict.
update -- Columns to write on conflict. Defaults to every provided
column except the ``unique_by`` keys.

Returns:
The number of rows affected.
"""
self.set_action("bulk_create")

if not values:
return 0

if isinstance(values, dict):
values = [values]

if isinstance(unique_by, str):
unique_by = [unique_by]

statement = self._build_upsert_statement(values, unique_by, update)
return await self.connection.upsert(statement)

def _build_upsert_statement(self, values: list, unique_by: list, update: list | None):
"""Build a dialect-specific SQLAlchemy upsert construct for ``values``."""
from sqlalchemy import column as sa_column, table as sa_table

caster = getattr(self._model, "caster", None)
casts = getattr(caster, "casts", {}) or {}

# Apply set-casts the same way the insert path does (get_attributes_for_insert).
def cast_row(row: dict) -> dict:
return {key: caster.set(key, value) if key in casts else value for key, value in row.items()}

rows = [cast_row(row) for row in values]

# SQLAlchemy needs a consistent column set across every row for a single
# multi-row VALUES clause; missing keys default to NULL.
columns = sorted({key for row in rows for key in row})
rows = [{key: row.get(key) for key in columns} for row in rows]

# Default the update set to everything but the conflict keys.
candidate_update = columns if update is None else update
update_columns = [col for col in candidate_update if col not in unique_by]

# Refresh updated_at on the UPDATE branch only — keeps the INSERT branch
# consistent with the current insert() path, which never injects timestamps.
refresh_updated_at = bool(getattr(self._model, "__timestamps__", False)) and "updated_at" in casts

# Every column the statement references must exist on the table construct.
table_columns = list(dict.fromkeys([*columns, *update_columns]))
if refresh_updated_at and "updated_at" not in table_columns:
table_columns.append("updated_at")

target = sa_table(self._table, *[sa_column(name) for name in table_columns])
dialect = self.connection.engine.dialect.name

if dialect == "mysql":
from sqlalchemy.dialects.mysql import insert as mysql_insert

statement = mysql_insert(target).values(rows)
set_ = {col: statement.inserted[col] for col in update_columns}
if refresh_updated_at:
set_["updated_at"] = caster.set("updated_at", pendulum.now("UTC"))
if not set_:
# Nothing to update — emulate a no-op so the statement stays valid.
set_ = {unique_by[0]: statement.inserted[unique_by[0]]}
return statement.on_duplicate_key_update(set_)

if dialect in ("postgresql", "postgres"):
from sqlalchemy.dialects.postgresql import insert as conflict_insert
elif dialect == "sqlite":
from sqlalchemy.dialects.sqlite import insert as conflict_insert
else:
raise NotImplementedError(f"upsert() is not supported for the '{dialect}' driver.")

statement = conflict_insert(target).values(rows)
set_ = {col: statement.excluded[col] for col in update_columns}
if refresh_updated_at:
set_["updated_at"] = caster.set("updated_at", pendulum.now("UTC"))
if not set_:
return statement.on_conflict_do_nothing(index_elements=unique_by)
return statement.on_conflict_do_update(index_elements=unique_by, set_=set_)

async def update(self, values: dict) -> int:
updates = [UpdateQueryExpression(col, val) for col, val in values.items()]
grammar = self.grammar()
Expand Down
14 changes: 14 additions & 0 deletions fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,20 @@ async def first_or_create(cls, search: dict, attributes: dict | None = None) ->
async def update_or_create(cls, search: dict, attributes: dict | None = None) -> "Model":
return await cls.query().update_or_create(search, attributes)

@classmethod
async def upsert(
cls,
values: dict | list,
unique_by: str | list,
update: list | None = None,
) -> int:
"""Bulk insert ``values``, updating existing rows on a ``unique_by`` conflict.

Emits a single dialect-correct statement (postgres/sqlite ON CONFLICT,
mysql ON DUPLICATE KEY UPDATE). Returns the number of affected rows.
"""
return await cls.query().upsert(values, unique_by, update)

@classmethod
async def create(cls, attributes: dict):
instance = cls().new_model_instance(attributes)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from unittest.mock import AsyncMock

from sqlalchemy.dialects import mysql

from ..test_case import TestCase
from ..fixtures.db import DB
from ...fixtures.model import User


class TestMySQLQueryBuilderUpsert(TestCase):
"""MySQL upsert compiles to a single ON DUPLICATE KEY UPDATE statement.

Mirrors the other MySQL builder tests, which assert generated SQL without a
live server.
"""

def _compile(self, statement):
return str(statement.compile(dialect=mysql.dialect()))

async def test_upsert_compiles_on_duplicate_key_update(self):
statement = User.query()._build_upsert_statement(
[
{"email": "a@test.com", "name": "A", "is_admin": False},
{"email": "b@test.com", "name": "B", "is_admin": True},
],
["email"],
None,
)

compiled = self._compile(statement)
# Single multi-row INSERT with a MySQL-style ON DUPLICATE KEY UPDATE.
assert compiled.count("INSERT INTO") == 1
assert "ON DUPLICATE KEY UPDATE" in compiled
assert "name = VALUES(name)" in compiled
assert "is_admin = VALUES(is_admin)" in compiled
# The conflict key is never part of the update set.
assert "email = VALUES(email)" not in compiled

async def test_upsert_respects_explicit_update_columns(self):
statement = User.query()._build_upsert_statement(
[{"email": "a@test.com", "name": "A", "is_admin": False}],
["email"],
["name"],
)

compiled = self._compile(statement)
assert "name = VALUES(name)" in compiled
assert "is_admin = VALUES(is_admin)" not in compiled

async def test_upsert_invokes_connection_once(self):
captured = AsyncMock(return_value=2)
DB.connection("mysql").upsert = captured

await User.query().upsert(
[{"email": "a@test.com", "name": "A"}, {"email": "b@test.com", "name": "B"}],
unique_by="email",
)

captured.assert_called_once()
54 changes: 54 additions & 0 deletions fastapi_startkit/tests/masoniteorm/postgres/models/test_upsert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from ..test_case import TestCase
from ...fixtures.model import User


class TestPostgresUpsert(TestCase):
async def test_upsert_inserts_new_records(self):
affected = await User.upsert(
[
{"email": "u1@example.com", "name": "U1", "is_admin": False},
{"email": "u2@example.com", "name": "U2", "is_admin": True},
],
unique_by="email",
)
self.assertGreaterEqual(affected, 1)

u1 = await User.where("email", "u1@example.com").first()
u2 = await User.where("email", "u2@example.com").first()
self.assertEqual(u1.name, "U1")
self.assertEqual(u2.name, "U2")
self.assertTrue(u2.is_admin)

async def test_upsert_updates_existing_on_conflict(self):
await User.upsert({"email": "dup@example.com", "name": "Original", "is_admin": False}, unique_by="email")
await User.upsert({"email": "dup@example.com", "name": "Updated", "is_admin": True}, unique_by="email")

rows = await User.where("email", "dup@example.com").get()
self.assertEqual(len(rows), 1)
self.assertEqual(rows.first().name, "Updated")
self.assertTrue(rows.first().is_admin)

async def test_upsert_respects_update_subset(self):
await User.upsert({"email": "sub@example.com", "name": "Keep", "is_admin": False}, unique_by="email")
await User.upsert(
{"email": "sub@example.com", "name": "Ignored", "is_admin": True},
unique_by="email",
update=["is_admin"],
)

user = await User.where("email", "sub@example.com").first()
self.assertEqual(user.name, "Keep")
self.assertTrue(user.is_admin)

async def test_upsert_refreshes_updated_at_on_update_branch(self):
await User.upsert(
{"email": "ts@example.com", "name": "TS", "is_admin": False, "updated_at": "2000-01-01 00:00:00"},
unique_by="email",
)
before = await User.where("email", "ts@example.com").first()
self.assertEqual(before.updated_at.year, 2000)

await User.upsert({"email": "ts@example.com", "name": "TS2", "is_admin": False}, unique_by="email")
after = await User.where("email", "ts@example.com").first()
self.assertEqual(after.name, "TS2")
self.assertNotEqual(after.updated_at.year, 2000)
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from unittest.mock import AsyncMock

from sqlalchemy.dialects import sqlite

from ..fixtures.db import DB
from ...fixtures.model import User
from ..test_case import TestCase


class TestQueryBuilderUpsert(TestCase):
async def test_upsert_inserts_new_records(self):
affected = await User.upsert(
[
{"email": "u1@test.com", "name": "U1", "is_admin": False},
{"email": "u2@test.com", "name": "U2", "is_admin": True},
],
unique_by="email",
)
assert affected >= 1

u1 = await User.where("email", "u1@test.com").first()
u2 = await User.where("email", "u2@test.com").first()
assert u1.name == "U1"
assert u2.name == "U2"
assert u2.is_admin is True

async def test_upsert_updates_existing_on_conflict(self):
await User.upsert({"email": "dup@test.com", "name": "Original", "is_admin": False}, unique_by="email")
await User.upsert({"email": "dup@test.com", "name": "Updated", "is_admin": True}, unique_by="email")

rows = await User.where("email", "dup@test.com").get()
assert len(rows) == 1
assert rows.first().name == "Updated"
assert rows.first().is_admin is True

async def test_upsert_respects_update_subset(self):
await User.upsert({"email": "sub@test.com", "name": "Keep", "is_admin": False}, unique_by="email")
# Only ``is_admin`` is listed for update, so ``name`` must remain unchanged.
await User.upsert(
{"email": "sub@test.com", "name": "Ignored", "is_admin": True},
unique_by="email",
update=["is_admin"],
)

user = await User.where("email", "sub@test.com").first()
assert user.name == "Keep"
assert user.is_admin is True

async def test_upsert_refreshes_updated_at_on_update_branch(self):
# Explicit updated_at is honoured on the INSERT branch.
await User.upsert(
{"email": "ts@test.com", "name": "TS", "is_admin": False, "updated_at": "2000-01-01 00:00:00"},
unique_by="email",
)
before = await User.where("email", "ts@test.com").first()
assert before.updated_at.year == 2000

# __timestamps__ is enabled on User, so the UPDATE branch refreshes updated_at.
await User.upsert({"email": "ts@test.com", "name": "TS2", "is_admin": False}, unique_by="email")
after = await User.where("email", "ts@test.com").first()
assert after.name == "TS2"
assert after.updated_at.year != 2000

async def test_upsert_emits_single_on_conflict_statement(self):
captured = AsyncMock(return_value=2)
DB.connection("sqlite").upsert = captured

await User.query().upsert(
[{"email": "a@test.com", "name": "A"}, {"email": "b@test.com", "name": "B"}],
unique_by="email",
)

captured.assert_called_once()
(statement,) = captured.call_args[0]
compiled = str(statement.compile(dialect=sqlite.dialect()))
# A single multi-row VALUES clause with a sqlite ON CONFLICT ... excluded update.
assert compiled.count("INSERT INTO") == 1
assert "ON CONFLICT (email)" in compiled
assert "excluded.name" in compiled
Loading