diff --git a/tests/conftest.py b/tests/conftest.py index 6db6308..dca2179 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ import json as _json import tempfile +from collections.abc import Sequence from pathlib import Path import pytest @@ -35,8 +36,9 @@ def __init__(self): class _FakeCursor: - def __init__(self, store: _FakeStore): + def __init__(self, store: _FakeStore, pool: FakePool | None = None): self._s = store + self._pool = pool self.rowcount = 0 self._row = None self._rows: list = [] @@ -48,6 +50,8 @@ def __exit__(self, *_): pass def execute(self, sql: str, params=()): + if self._pool is not None: + self._pool.call_log.append((sql, params)) self._row = None self._rows = [] self.rowcount = 0 @@ -149,7 +153,7 @@ def fetchall(self): class _FakeConn: def __init__(self, store: _FakeStore, pool: FakePool | None = None): - self._cur = _FakeCursor(store) + self._cur = _FakeCursor(store, pool) self._pool = pool self.rollback_called = False @@ -187,6 +191,26 @@ def __init__(self): self._store = _FakeStore() self.fail_on_commit = False self.rollback_count = 0 + self.call_log: list[tuple[str, Sequence]] = [] + + def calls_matching(self, sql_fragment: str) -> list[tuple[str, Sequence]]: + """Return all recorded (sql, params) pairs whose normalised SQL contains sql_fragment.""" + frag = " ".join(sql_fragment.split()).upper() + return [ + (sql, params) for sql, params in self.call_log if frag in " ".join(sql.split()).upper() + ] + + def assert_param_at(self, sql_fragment: str, position: int, expected) -> None: + """Assert that the last call matching sql_fragment has params[position] == expected.""" + matches = self.calls_matching(sql_fragment) + assert matches, f"No SQL call matching {sql_fragment!r} in call_log" + _, params = matches[-1] + assert 0 <= position < len(params), ( + f"params has only {len(params)} element(s); position {position} is out of range" + ) + assert params[position] == expected, ( + f"params[{position}] expected {expected!r}, got {params[position]!r}" + ) def getconn(self): return _FakeConn(self._store, self) diff --git a/tests/test_storage.py b/tests/test_storage.py index ded2160..31466b4 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -415,3 +415,65 @@ def test_matches_paper_with_none_number_never_paper_matched(self, fake_pool): assert paper.number is None result = wl.matches_for_users([paper], []) assert "U1" not in result + + +# ── FakePool parameter capture ──────────────────────────────────────────────── + + +class TestFakePoolParamCapture: + """FakePool captures bound parameters; column-value binding errors are detectable.""" + + def test_call_log_records_sql_and_params(self, fake_pool): + wl = UserWatchlist(fake_pool) + wl.add("U1", "alice") + calls = fake_pool.calls_matching("INSERT INTO user_watchlist") + assert calls + assert calls[-1][1] # params sequence is non-empty + + def test_watchlist_user_id_is_first_param(self, fake_pool): + wl = UserWatchlist(fake_pool) + wl.add("U1", "alice") + fake_pool.assert_param_at("INSERT INTO user_watchlist", 0, "U1") + + def test_watchlist_entry_is_second_param(self, fake_pool): + wl = UserWatchlist(fake_pool) + wl.add("U1", "alice") + fake_pool.assert_param_at("INSERT INTO user_watchlist", 1, "alice") + + def test_watchlist_entry_type_is_third_param(self, fake_pool): + wl = UserWatchlist(fake_pool) + wl.add("U1", "2300") + fake_pool.assert_param_at("INSERT INTO user_watchlist", 2, "paper") + + def test_paper_cache_key_is_first_param(self, fake_pool): + cache = PaperCache(fake_pool) + cache.write({"a": 1}) + fake_pool.assert_param_at("INSERT INTO paper_cache", 0, "wg21_index") + + def test_discovered_url_is_first_param(self, fake_pool): + state = ProbeState(fake_pool) + url = "https://isocpp.org/files/papers/D2300R11.pdf" + state.mark_discovered(url, last_modified_ts=42.0) + fake_pool.assert_param_at("INSERT INTO discovered_urls", 0, url) + + def test_discovered_last_modified_is_second_param(self, fake_pool): + state = ProbeState(fake_pool) + url = "https://isocpp.org/files/papers/D2300R11.pdf" + state.mark_discovered(url, last_modified_ts=42.0) + fake_pool.assert_param_at("INSERT INTO discovered_urls", 1, 42.0) + + def test_assert_param_at_out_of_bounds_raises_assertion_error(self, fake_pool): + wl = UserWatchlist(fake_pool) + wl.add("U1", "alice") + with pytest.raises(AssertionError, match="out of range"): + fake_pool.assert_param_at("INSERT INTO user_watchlist", 99, "anything") + + def test_regression_swapped_user_id_and_entry_would_fail(self, fake_pool): + """ + Prove assert_param_at catches binding transposition. + This test exercises the FakePool test helper itself, not production code. + """ + wl = UserWatchlist(fake_pool) + wl.add("U1", "alice") + with pytest.raises(AssertionError): + fake_pool.assert_param_at("INSERT INTO user_watchlist", 0, "alice")