Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import json as _json
import tempfile
from collections.abc import Sequence
from pathlib import Path

import pytest
Expand All @@ -35,8 +36,9 @@ def __init__(self):


class _FakeCursor:
def __init__(self, store: _FakeStore):
def __init__(self, store: _FakeStore, pool: FakePool | None = None):
self._s = store
self._pool = pool
self.rowcount = 0
self._row = None
self._rows: list = []
Expand All @@ -48,6 +50,8 @@ def __exit__(self, *_):
pass

def execute(self, sql: str, params=()):
if self._pool is not None:
self._pool.call_log.append((sql, params))
self._row = None
self._rows = []
self.rowcount = 0
Expand Down Expand Up @@ -149,7 +153,7 @@ def fetchall(self):

class _FakeConn:
def __init__(self, store: _FakeStore, pool: FakePool | None = None):
self._cur = _FakeCursor(store)
self._cur = _FakeCursor(store, pool)
self._pool = pool
self.rollback_called = False

Expand Down Expand Up @@ -187,6 +191,26 @@ def __init__(self):
self._store = _FakeStore()
self.fail_on_commit = False
self.rollback_count = 0
self.call_log: list[tuple[str, Sequence]] = []

def calls_matching(self, sql_fragment: str) -> list[tuple[str, Sequence]]:
"""Return all recorded (sql, params) pairs whose normalised SQL contains sql_fragment."""
frag = " ".join(sql_fragment.split()).upper()
return [
(sql, params) for sql, params in self.call_log if frag in " ".join(sql.split()).upper()
]

def assert_param_at(self, sql_fragment: str, position: int, expected) -> None:
"""Assert that the last call matching sql_fragment has params[position] == expected."""
matches = self.calls_matching(sql_fragment)
assert matches, f"No SQL call matching {sql_fragment!r} in call_log"
_, params = matches[-1]
assert 0 <= position < len(params), (
f"params has only {len(params)} element(s); position {position} is out of range"
)
assert params[position] == expected, (
f"params[{position}] expected {expected!r}, got {params[position]!r}"
)

def getconn(self):
return _FakeConn(self._store, self)
Expand Down
62 changes: 62 additions & 0 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,65 @@ def test_matches_paper_with_none_number_never_paper_matched(self, fake_pool):
assert paper.number is None
result = wl.matches_for_users([paper], [])
assert "U1" not in result


# ── FakePool parameter capture ────────────────────────────────────────────────


class TestFakePoolParamCapture:
"""FakePool captures bound parameters; column-value binding errors are detectable."""

def test_call_log_records_sql_and_params(self, fake_pool):
wl = UserWatchlist(fake_pool)
wl.add("U1", "alice")
calls = fake_pool.calls_matching("INSERT INTO user_watchlist")
assert calls
assert calls[-1][1] # params sequence is non-empty

def test_watchlist_user_id_is_first_param(self, fake_pool):
wl = UserWatchlist(fake_pool)
wl.add("U1", "alice")
fake_pool.assert_param_at("INSERT INTO user_watchlist", 0, "U1")

def test_watchlist_entry_is_second_param(self, fake_pool):
wl = UserWatchlist(fake_pool)
wl.add("U1", "alice")
fake_pool.assert_param_at("INSERT INTO user_watchlist", 1, "alice")

def test_watchlist_entry_type_is_third_param(self, fake_pool):
wl = UserWatchlist(fake_pool)
wl.add("U1", "2300")
fake_pool.assert_param_at("INSERT INTO user_watchlist", 2, "paper")

def test_paper_cache_key_is_first_param(self, fake_pool):
cache = PaperCache(fake_pool)
cache.write({"a": 1})
fake_pool.assert_param_at("INSERT INTO paper_cache", 0, "wg21_index")

def test_discovered_url_is_first_param(self, fake_pool):
state = ProbeState(fake_pool)
url = "https://isocpp.org/files/papers/D2300R11.pdf"
state.mark_discovered(url, last_modified_ts=42.0)
fake_pool.assert_param_at("INSERT INTO discovered_urls", 0, url)

def test_discovered_last_modified_is_second_param(self, fake_pool):
state = ProbeState(fake_pool)
url = "https://isocpp.org/files/papers/D2300R11.pdf"
state.mark_discovered(url, last_modified_ts=42.0)
fake_pool.assert_param_at("INSERT INTO discovered_urls", 1, 42.0)

def test_assert_param_at_out_of_bounds_raises_assertion_error(self, fake_pool):
wl = UserWatchlist(fake_pool)
wl.add("U1", "alice")
with pytest.raises(AssertionError, match="out of range"):
fake_pool.assert_param_at("INSERT INTO user_watchlist", 99, "anything")

def test_regression_swapped_user_id_and_entry_would_fail(self, fake_pool):
"""
Prove assert_param_at catches binding transposition.
This test exercises the FakePool test helper itself, not production code.
"""
wl = UserWatchlist(fake_pool)
wl.add("U1", "alice")
with pytest.raises(AssertionError):
fake_pool.assert_param_at("INSERT INTO user_watchlist", 0, "alice")