Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,20 @@ async def execute(self, query: str, bindings: list | None = None):

return result

async def execute_statement(self, statement):
"""Execute a SQLAlchemy core construct (not a raw SQL string).

Used for dialect-specific statements such as upserts where the SQL is
built with SQLAlchemy's dialect helpers rather than the grammar.
"""
conn = await self.get_connection()
result = await conn.execute(statement)

if not self.transactions:
await conn.commit()

return result

async def insert(self, query: str, bindings: list | None = None) -> int | None:
result = await self.execute(query, bindings)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,63 @@ async def insert(self, values: dict | list) -> int | None:
bindings = [val for row in values for val in row.values()]
return await self.connection.insert(sql, bindings)

async def upsert(
self,
values: dict | list[dict],
unique_by: str | list[str],
update: str | list[str] | None = None,
) -> int:
"""Insert rows in a single statement, updating on a unique-key collision.

Mirrors Eloquent's ``upsert``: every row is inserted in one bulk
statement; when a row collides on ``unique_by`` the columns in
``update`` are updated instead. When ``update`` is omitted, every
non-unique column is updated. The SQL is dialect-specific (Postgres /
SQLite ``ON CONFLICT`` and MySQL ``ON DUPLICATE KEY UPDATE``).
"""
from fastapi_startkit.masoniteorm.query.upsert import build_upsert_statement

if not values:
return 0

if isinstance(values, dict):
values = [values]
if isinstance(unique_by, str):
unique_by = [unique_by]
if isinstance(update, str):
update = [update]

rows = [self._cast_upsert_row(row) for row in self._normalize_upsert_rows(values)]

if update is None:
update = [column for column in rows[0] if column not in unique_by]

statement = build_upsert_statement(
self.connection.config["driver"],
self.get_table_name(),
rows,
unique_by,
update,
)
result = await self.connection.execute_statement(statement)
return result.rowcount

def _normalize_upsert_rows(self, values: list[dict]) -> list[dict]:
"""Give every row an identical, deterministically ordered column set."""
columns: list[str] = []
for row in values:
for key in row:
if key not in columns:
columns.append(key)
return [{column: row.get(column) for column in columns} for row in values]

def _cast_upsert_row(self, row: dict) -> dict:
model = getattr(self, "_model", None)
caster = getattr(model, "caster", None)
if caster is None:
return dict(row)
return {key: caster.set(key, value) for key, value in row.items()}

async def insert_get_id(
self,
values: dict[str, Any] | list[dict[str, Any]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,15 @@ async def create(cls, attributes: dict):

return instance

@classmethod
async def upsert(
cls,
values: dict | list[dict],
unique_by: str | list[str],
update: str | list[str] | None = None,
) -> int:
return await cls.query().upsert(values, unique_by, update)

async def update(self, attributes: dict) -> bool:
if not self._exists:
return False
Expand Down
46 changes: 46 additions & 0 deletions fastapi_startkit/src/fastapi_startkit/masoniteorm/query/upsert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from sqlalchemy import column
from sqlalchemy import table as sa_table
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as postgres_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

# Each supported driver maps to its dialect-specific INSERT builder so the
# generated statement can express the right "on conflict" upsert clause.
_INSERTERS = {
"postgres": postgres_insert,
"sqlite": sqlite_insert,
"mysql": mysql_insert,
}


def build_upsert_statement(driver, table_name, rows, unique_by, update_columns):
"""Build a single bulk INSERT ... ON CONFLICT/DUPLICATE KEY statement.

`rows` must already be normalised to share an identical column set so the
multi-row VALUES clause is well formed.
"""
inserter = _INSERTERS.get(driver)
if inserter is None:
raise ValueError(f"upsert() is not supported for the '{driver}' driver")

columns = list(rows[0].keys())
target = sa_table(table_name, *[column(name) for name in columns])
statement = inserter(target).values(rows)

if driver == "mysql":
if update_columns:
return statement.on_duplicate_key_update(
{name: statement.inserted[name] for name in update_columns}
)
# No columns to update: keep a unique key untouched to make it a no-op.
return statement.on_duplicate_key_update(
{unique_by[0]: statement.inserted[unique_by[0]]}
)

if update_columns:
return statement.on_conflict_do_update(
index_elements=unique_by,
set_={name: statement.excluded[name] for name in update_columns},
)

return statement.on_conflict_do_nothing(index_elements=unique_by)
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from ...fixtures.model import User
from ..test_case import TestCase


class TestQueryBuilderUpsert(TestCase):
async def test_upsert_inserts_new_rows(self):
await User.query().upsert(
[
{"email": "new1@test.com", "name": "New One", "is_admin": False},
{"email": "new2@test.com", "name": "New Two", "is_admin": False},
],
unique_by=["email"],
)

one = await User.where("email", "new1@test.com").first()
two = await User.where("email", "new2@test.com").first()
assert one.name == "New One"
assert two.name == "New Two"

async def test_upsert_accepts_a_single_dict(self):
await User.query().upsert(
{"email": "solo@test.com", "name": "Solo", "is_admin": False},
unique_by="email",
)

user = await User.where("email", "solo@test.com").first()
assert user.name == "Solo"

async def test_upsert_updates_on_conflict_without_duplicating(self):
await User.query().upsert(
{"email": "admin@admin.com", "name": "Renamed Admin", "is_admin": True},
unique_by="email",
)

matches = await User.where("email", "admin@admin.com").get()
assert len(matches) == 1
assert matches.first().name == "Renamed Admin"

async def test_upsert_only_updates_listed_columns(self):
await User.query().upsert(
{"email": "admin@admin.com", "name": "Partial Update", "is_admin": False},
unique_by="email",
update=["name"],
)

user = await User.where("email", "admin@admin.com").first()
assert user.name == "Partial Update"
# is_admin was excluded from `update`, so the seeded True is preserved.
assert user.is_admin is True

async def test_upsert_without_update_list_updates_all_non_unique_columns(self):
await User.query().upsert(
{"email": "admin@admin.com", "name": "Full Update", "is_admin": False},
unique_by="email",
)

user = await User.where("email", "admin@admin.com").first()
assert user.name == "Full Update"
assert user.is_admin is False

async def test_upsert_mixes_insert_and_update_in_one_call(self):
await User.upsert(
[
{"email": "admin@admin.com", "name": "Existing Updated", "is_admin": True},
{"email": "fresh@test.com", "name": "Fresh Insert", "is_admin": False},
],
unique_by="email",
update=["name"],
)

existing = await User.where("email", "admin@admin.com").first()
fresh = await User.where("email", "fresh@test.com").first()
assert existing.name == "Existing Updated"
assert fresh.name == "Fresh Insert"

async def test_upsert_returns_zero_for_empty_values(self):
assert await User.query().upsert([], unique_by="email") == 0
47 changes: 47 additions & 0 deletions fastapi_startkit/tests/masoniteorm/test_upsert_statement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from unittest import TestCase

from sqlalchemy.dialects import mysql, postgresql, sqlite

from fastapi_startkit.masoniteorm.query.upsert import build_upsert_statement

ROWS = [
{"email": "a@a.com", "name": "A"},
{"email": "b@b.com", "name": "B"},
]


def compiled(statement, dialect):
return str(statement.compile(dialect=dialect))


class TestBuildUpsertStatement(TestCase):
def test_postgres_uses_on_conflict_do_update(self):
statement = build_upsert_statement("postgres", "users", ROWS, ["email"], ["name"])
sql = compiled(statement, postgresql.dialect())
self.assertIn("ON CONFLICT (email) DO UPDATE SET name = excluded.name", sql)

def test_sqlite_uses_on_conflict_do_update(self):
statement = build_upsert_statement("sqlite", "users", ROWS, ["email"], ["name"])
sql = compiled(statement, sqlite.dialect())
self.assertIn("ON CONFLICT (email) DO UPDATE SET name = excluded.name", sql)

def test_mysql_uses_on_duplicate_key_update(self):
statement = build_upsert_statement("mysql", "users", ROWS, ["email"], ["name"])
sql = compiled(statement, mysql.dialect())
self.assertIn("ON DUPLICATE KEY UPDATE name = VALUES(name)", sql)

def test_bulk_values_are_a_single_statement(self):
statement = build_upsert_statement("postgres", "users", ROWS, ["email"], ["name"])
sql = compiled(statement, postgresql.dialect())
# Two value tuples → one multi-row INSERT, not a loop.
self.assertEqual(sql.count("INSERT INTO"), 1)
self.assertIn("VALUES", sql)

def test_empty_update_list_falls_back_to_do_nothing(self):
statement = build_upsert_statement("postgres", "users", ROWS, ["email"], [])
sql = compiled(statement, postgresql.dialect())
self.assertIn("ON CONFLICT (email) DO NOTHING", sql)

def test_unsupported_driver_raises(self):
with self.assertRaises(ValueError):
build_upsert_statement("oracle", "users", ROWS, ["email"], ["name"])
Loading