diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py
index 6c9bee17a..9bc104aa7 100644
--- a/backend/app/api/routes/collections.py
+++ b/backend/app/api/routes/collections.py
@@ -8,6 +8,7 @@
 from app.api.deps import SessionDep, AuthContextDep
 from app.api.permissions import Permission, require_permission
+from app.core.rate_monitor import monitor_rate
 from app.core.telemetry import log_context
 from app.crud import (
     CollectionCrud,
     CollectionJobCrud,
@@ -85,7 +86,10 @@ def list_collections(
     description=load_description("collections/create.md"),
     response_model=APIResponse[CollectionJobImmediatePublic],
     callbacks=collection_callback_router.routes,
-    dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))],
+    dependencies=[
+        Depends(require_permission(Permission.REQUIRE_PROJECT)),
+        Depends(monitor_rate("collections")),
+    ],
 )
 def create_collection(
     session: SessionDep,
diff --git a/backend/app/api/routes/evaluations/evaluation.py b/backend/app/api/routes/evaluations/evaluation.py
index 591f5d985..216afe95c 100644
--- a/backend/app/api/routes/evaluations/evaluation.py
+++ b/backend/app/api/routes/evaluations/evaluation.py
@@ -12,6 +12,7 @@
 )
 
 from app.api.deps import AuthContextDep, SessionDep
+from app.core.rate_monitor import monitor_rate
 from app.crud.evaluations import list_evaluation_runs as list_evaluation_runs_crud
 from app.crud.evaluations.core import group_traces_by_question_id
 from app.models.evaluation import EvaluationRunPublic
@@ -34,7 +35,10 @@
     "",
     description=load_description("evaluation/create_evaluation.md"),
     response_model=APIResponse[EvaluationRunPublic],
-    dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))],
+    dependencies=[
+        Depends(require_permission(Permission.REQUIRE_PROJECT)),
+        Depends(monitor_rate("evaluations")),
+    ],
 )
 def evaluate(
     session: SessionDep,
diff --git a/backend/app/api/routes/llm.py b/backend/app/api/routes/llm.py
index 8106046e7..8ce96cf76 100644
--- a/backend/app/api/routes/llm.py
+++ b/backend/app/api/routes/llm.py
@@ -9,6 +9,7 @@
 from app.api.permissions import Permission, require_permission
 from app.core.cloud.storage import get_cloud_storage
+from app.core.rate_monitor import monitor_rate
 from app.core.telemetry import log_context
 from app.crud.jobs import JobCrud
 from app.crud.llm import get_llm_calls_by_job_id
 from app.models import (
@@ -22,7 +23,6 @@
 from app.services.llm.jobs import start_job
 from app.utils import APIResponse, validate_callback_url, load_description
 
-
 logger = logging.getLogger(__name__)
 
 router = APIRouter(tags=["LLM"])
@@ -50,7 +50,10 @@ def llm_callback_notification(body: APIResponse[LLMCallResponse]):
     description=load_description("llm/llm_call.md"),
     response_model=APIResponse[LLMJobImmediatePublic],
     callbacks=llm_callback_router.routes,
-    dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))],
+    dependencies=[
+        Depends(require_permission(Permission.REQUIRE_PROJECT)),
+        Depends(monitor_rate("llm_call")),
+    ],
 )
 def llm_call(
     _current_user: AuthContextDep, session: SessionDep, request: LLMCallRequest
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index 3088d8002..8d4e44d5c 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -151,6 +151,11 @@ def AWS_S3_BUCKET(self) -> str:
     BACKEND_SERVICE_NAME: str = "kaapi-backend"
     CRON_SERVICE_NAME: str = "kaapi-cron"
 
+    # Rate-monitor alert thresholds (requests per project per minute)
+    THRESHOLD_LLM_CALL_RATE: int = 15
+    THRESHOLD_COLLECTIONS_RATE: int = 3
+    THRESHOLD_EVALUATIONS_RATE: int = 3
+
     # Celery Configuration
     CELERY_WORKER_CONCURRENCY: int | None = None
     CELERY_WORKER_MAX_TASKS_PER_CHILD: int = 150
diff --git a/backend/app/core/rate_monitor.py b/backend/app/core/rate_monitor.py
new file mode 100644
index 000000000..ff0da9f6f
--- /dev/null
+++ b/backend/app/core/rate_monitor.py
@@ -0,0 +1,125 @@
+import logging
+import time
+from collections.abc import Callable
+from typing import Literal
+
+import redis
+
+from app.api.deps import AuthContextDep
+from app.core.config import settings
+from app.core.telemetry import record_rate_threshold
+
+logger = logging.getLogger(__name__)
+
+# Route categories whose request rate is monitored.
+RateCategory = Literal["llm_call", "collections", "evaluations"]
+
+# Per-category request thresholds, read from settings at import time.
+THRESHOLDS: dict[RateCategory, int] = {
+    "llm_call": settings.THRESHOLD_LLM_CALL_RATE,
+    "collections": settings.THRESHOLD_COLLECTIONS_RATE,
+    "evaluations": settings.THRESHOLD_EVALUATIONS_RATE,
+}
+
+# TTL for every counter key: twice the one-minute bucket width, so a counter
+# outlives the minute it covers and Redis then drops it without cleanup code.
+_EXPIRATION_SECONDS = 120
+
+# Monitoring is best-effort: bound Redis connects and commands so an
+# unreachable or stalled Redis cannot hold up the request being monitored.
+_REDIS_TIMEOUT_SECONDS = 1.0
+
+_redis_client: redis.Redis = redis.from_url(
+    settings.REDIS_URL,
+    decode_responses=True,
+    socket_connect_timeout=_REDIS_TIMEOUT_SECONDS,
+    socket_timeout=_REDIS_TIMEOUT_SECONDS,
+)
+
+
+def increment_and_get_count(key: str) -> int | None:
+    """Increment the counter stored at ``key`` and return its new value.
+
+    INCR and EXPIRE are sent in one pipeline, so every increment also resets
+    the key's time-to-live to ``_EXPIRATION_SECONDS``.
+
+    Returns:
+        The counter value after the increment, or ``None`` if any error occurs
+        (the error is logged, never raised to the caller).
+    """
+    try:
+        pipe = _redis_client.pipeline()
+        pipe.incr(key)
+        pipe.expire(key, _EXPIRATION_SECONDS)
+        count, _ = pipe.execute()
+        return count
+    except Exception:
+        # Deliberately broad: counting must never fail the request it observes.
+        logger.exception(
+            "[increment_and_get_count] Error incrementing count for %s", key
+        )
+        return None
+
+
+def monitor_rate(category: RateCategory) -> Callable[[AuthContextDep], None]:
+    """Build a FastAPI dependency that monitors the request rate of ``category``.
+
+    The returned checker counts requests per project in one-minute buckets and
+    records a telemetry alert on the request that first exceeds the category's
+    threshold within a bucket. It only observes traffic and never rejects a
+    request.
+
+    Usage:
+        dependencies=[
+            Depends(require_permission(Permission.REQUIRE_PROJECT)),
+            Depends(monitor_rate("{category}")),
+        ]
+    """
+
+    def _checker(auth_context: AuthContextDep) -> None:
+        project = auth_context.project
+        if project is None:
+            return
+
+        threshold = THRESHOLDS.get(category)
+        if threshold is None:
+            logger.warning(
+                "[monitor_rate] No threshold defined for category %s", category
+            )
+            return
+
+        # Fixed one-minute window: the bucket index is part of the key, so
+        # each minute starts counting from a fresh key.
+        minute_bucket = int(time.time() // 60)
+        redis_key = f"rate_monitor:{category}:{project.id}:{minute_bucket}"
+
+        try:
+            count = increment_and_get_count(redis_key)
+            # Equality, not ">": alert once per bucket, on the first breach.
+            if count is not None and count == threshold + 1:
+                logger.warning(
+                    "[monitor_rate] Rate threshold exceeded for %s in project %s: "
+                    "count=%s",
+                    category,
+                    project.id,
+                    count,
+                )
+                record_rate_threshold(
+                    project_id=project.id,
+                    project_name=project.name,
+                    category=category,
+                    request_count=count,
+                    threshold=threshold,
+                )
+
+        except redis.RedisError as e:
+            # Defensive only: increment_and_get_count already swallows errors.
+            logger.error(
+                "[monitor_rate] Redis unavailable, skipping rate check "
+                "(project_id=%s category=%s)",
+                project.id,
+                category,
+                exc_info=e,
+            )
+
+    return _checker
diff --git a/backend/app/core/telemetry.py b/backend/app/core/telemetry.py
index 963c3086b..99d2fc959 100644
--- a/backend/app/core/telemetry.py
+++ b/backend/app/core/telemetry.py
@@ -453,6 +453,34 @@ def record_stale_pending_jobs(
         )
 
 
+def record_rate_threshold(
+    *,
+    project_id: int,
+    project_name: str | None,
+    category: str,
+    request_count: int,
+    threshold: int,
+) -> None:
+    """Emit rate threshold exceeded event to Sentry."""
+
+    try:
+        if not sentry_sdk.get_client().is_active():
+            return
+        with sentry_sdk.push_scope() as scope:
+            scope.set_tag("alert.type", "threshold_rate_monitor")
+            scope.set_tag("tenant.project_id", project_id)
+            scope.set_tag("route_category", category)
+            scope.set_extra("request_count", request_count)
+            scope.set_extra("threshold", threshold)
+            sentry_sdk.capture_message(
+                f"[Threshold-Monitor] {category} rate limit exceeded for project {project_id} | {project_name}: {request_count} req/min "
+                f"(limit {threshold}/min)",
+                level="warning",
+            )
+    except Exception as e:
+        logger.exception("[record_rate_threshold] Failed to emit alert", exc_info=e)
+
+
 def flush_telemetry(timeout_millis: int = 10000) -> None:
     """Force-flush OTel spans into Sentry, then flush Sentry's transport.
 
diff --git a/backend/app/tests/core/test_rate_monitor.py b/backend/app/tests/core/test_rate_monitor.py
new file mode 100644
index 000000000..39221039c
--- /dev/null
+++ b/backend/app/tests/core/test_rate_monitor.py
@@ -0,0 +1,229 @@
+"""Tests for rate_monitor.py and the record_rate_threshold telemetry helper.
+All Redis and Sentry calls are mocked; no real Redis or Sentry connection is used.
+"""
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import redis
+
+from app.core import rate_monitor, telemetry
+
+
+def _auth_context(project_id: int | None = 1, project_name: str = "Acme"):
+    """Build a minimal stand-in for AuthContext.
+
+    monitor_rate's checker only reads auth_context.project.id and .name,
+    so a SimpleNamespace is enough — no DB or real models required.
+    """
+    project = (
+        None
+        if project_id is None
+        else SimpleNamespace(id=project_id, name=project_name)
+    )
+    return SimpleNamespace(project=project)
+
+
+# ---------------------------------------------------------------------------
+# increment_and_get_count
+# ---------------------------------------------------------------------------
+
+
+class TestIncrementAndGetCount:
+    def test_returns_count_and_sets_expiry(self):
+        """Pipeline runs INCR + EXPIRE and returns the incremented value."""
+        pipe = MagicMock()
+        pipe.execute.return_value = [5, True]  # [incr_result, expire_result]
+        fake_redis = MagicMock()
+        fake_redis.pipeline.return_value = pipe
+
+        with patch.object(rate_monitor, "_redis_client", fake_redis):
+            count = rate_monitor.increment_and_get_count("some-key")
+
+        assert count == 5
+        pipe.incr.assert_called_once_with("some-key")
+        pipe.expire.assert_called_once_with(
+            "some-key", rate_monitor._EXPIRATION_SECONDS
+        )
+
+    def test_returns_none_on_redis_error(self):
+        """Any Redis failure is caught and returns None rather than raising."""
+        fake_redis = MagicMock()
+        fake_redis.pipeline.side_effect = redis.RedisError("boom")
+
+        with patch.object(rate_monitor, "_redis_client", fake_redis):
+            count = rate_monitor.increment_and_get_count("some-key")
+
+        assert count is None
+
+
+# ---------------------------------------------------------------------------
+# monitor_rate / _checker
+# ---------------------------------------------------------------------------
+
+
+class TestMonitorRate:
+    def test_skips_when_no_project(self):
+        """No project on the request → nothing counted, no Redis call."""
+        checker = rate_monitor.monitor_rate("llm_call")
+
+        with patch.object(rate_monitor, "increment_and_get_count") as inc:
+            checker(_auth_context(project_id=None))
+
+        inc.assert_not_called()
+
+    def test_skips_when_no_threshold_for_category(self):
+        """Unknown category has no threshold → return early, no alert."""
+        checker = rate_monitor.monitor_rate("unknown")  # type: ignore[arg-type]
+
+        with (
+            patch.object(rate_monitor, "increment_and_get_count") as inc,
+            patch.object(rate_monitor, "record_rate_threshold") as record,
+        ):
+            checker(_auth_context())
+
+        inc.assert_not_called()
+        record.assert_not_called()
+
+    def test_no_alert_when_under_threshold(self):
+        """Count at or below the threshold does not alert."""
+        checker = rate_monitor.monitor_rate("collections")
+        threshold = rate_monitor.THRESHOLDS["collections"]
+
+        with (
+            patch.object(
+                rate_monitor, "increment_and_get_count", return_value=threshold
+            ),
+            patch.object(rate_monitor, "record_rate_threshold") as record,
+        ):
+            checker(_auth_context())
+
+        record.assert_not_called()
+
+    def test_no_alert_when_already_breached(self):
+        """Only the first breach (threshold + 1) alerts; later counts stay silent."""
+        checker = rate_monitor.monitor_rate("collections")
+        threshold = rate_monitor.THRESHOLDS["collections"]
+
+        with (
+            patch.object(
+                rate_monitor, "increment_and_get_count", return_value=threshold + 2
+            ),
+            patch.object(rate_monitor, "record_rate_threshold") as record,
+        ):
+            checker(_auth_context())
+
+        record.assert_not_called()
+
+    def test_alerts_when_over_threshold(self):
+        """First breach (threshold + 1) records a Sentry alert with project details."""
+        checker = rate_monitor.monitor_rate("llm_call")
+        threshold = rate_monitor.THRESHOLDS["llm_call"]
+        over = threshold + 1
+
+        with (
+            patch.object(rate_monitor, "increment_and_get_count", return_value=over),
+            patch.object(rate_monitor, "record_rate_threshold") as record,
+        ):
+            checker(_auth_context(project_id=616, project_name="Acme"))
+
+        record.assert_called_once_with(
+            project_id=616,
+            project_name="Acme",
+            category="llm_call",
+            request_count=over,
+            threshold=threshold,
+        )
+
+    def test_no_alert_when_count_is_none(self):
+        """increment returning None (Redis down) is treated as no breach."""
+        checker = rate_monitor.monitor_rate("llm_call")
+
+        with (
+            patch.object(rate_monitor, "increment_and_get_count", return_value=None),
+            patch.object(rate_monitor, "record_rate_threshold") as record,
+        ):
+            checker(_auth_context())
+
+        record.assert_not_called()
+
+    def test_redis_error_is_swallowed(self):
+        """A RedisError from increment must not propagate out of the checker."""
+        checker = rate_monitor.monitor_rate("llm_call")
+
+        with patch.object(
+            rate_monitor,
+            "increment_and_get_count",
+            side_effect=redis.RedisError("down"),
+        ):
+            # Should not raise.
+            checker(_auth_context())
+
+
+# ---------------------------------------------------------------------------
+# telemetry.record_rate_threshold
+# ---------------------------------------------------------------------------
+
+
+class TestRecordRateThreshold:
+    def test_emits_warning_message_with_tags(self):
+        """When Sentry is active, a warning message is captured with tags/extras."""
+        client = MagicMock()
+        client.is_active.return_value = True
+        scope = MagicMock()
+        scope_cm = MagicMock()
+        scope_cm.__enter__.return_value = scope
+
+        with (
+            patch.object(telemetry.sentry_sdk, "get_client", return_value=client),
+            patch.object(telemetry.sentry_sdk, "push_scope", return_value=scope_cm),
+            patch.object(telemetry.sentry_sdk, "capture_message") as capture,
+        ):
+            telemetry.record_rate_threshold(
+                project_id=616,
+                project_name="Acme",
+                category="llm_call",
+                request_count=16,
+                threshold=15,
+            )
+
+        capture.assert_called_once()
+        assert capture.call_args.kwargs["level"] == "warning"
+        scope.set_tag.assert_any_call("alert.type", "threshold_rate_monitor")
+        scope.set_tag.assert_any_call("tenant.project_id", 616)
+        scope.set_extra.assert_any_call("request_count", 16)
+        scope.set_extra.assert_any_call("threshold", 15)
+
+    def test_noop_when_sentry_inactive(self):
+        """No Sentry client → nothing is captured."""
+        client = MagicMock()
+        client.is_active.return_value = False
+
+        with (
+            patch.object(telemetry.sentry_sdk, "get_client", return_value=client),
+            patch.object(telemetry.sentry_sdk, "capture_message") as capture,
+        ):
+            telemetry.record_rate_threshold(
+                project_id=1,
+                project_name="Acme",
+                category="llm_call",
+                request_count=16,
+                threshold=15,
+            )
+
+        capture.assert_not_called()
+
+    def test_swallows_exceptions(self):
+        """An error inside Sentry emission must never propagate."""
+        client = MagicMock()
+        client.is_active.side_effect = RuntimeError("sentry exploded")
+
+        with patch.object(telemetry.sentry_sdk, "get_client", return_value=client):
+            # Should not raise.
+            telemetry.record_rate_threshold(
+                project_id=1,
+                project_name="Acme",
+                category="llm_call",
+                request_count=16,
+                threshold=15,
+            )