Skip to content
6 changes: 5 additions & 1 deletion backend/app/api/routes/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from app.api.deps import SessionDep, AuthContextDep
from app.api.permissions import Permission, require_permission
from app.core.telemetry import log_context
from app.core.rate_monitor import monitor_rate
from app.crud import (
CollectionCrud,
CollectionJobCrud,
Expand Down Expand Up @@ -85,7 +86,10 @@ def list_collections(
description=load_description("collections/create.md"),
response_model=APIResponse[CollectionJobImmediatePublic],
callbacks=collection_callback_router.routes,
dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))],
dependencies=[
Depends(require_permission(Permission.REQUIRE_PROJECT)),
Depends(monitor_rate("collections")),
],
)
def create_collection(
session: SessionDep,
Expand Down
6 changes: 5 additions & 1 deletion backend/app/api/routes/evaluations/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)

from app.api.deps import AuthContextDep, SessionDep
from app.core.rate_monitor import monitor_rate
from app.crud.evaluations import list_evaluation_runs as list_evaluation_runs_crud
from app.crud.evaluations.core import group_traces_by_question_id
from app.models.evaluation import EvaluationRunPublic
Expand All @@ -34,7 +35,10 @@
"",
description=load_description("evaluation/create_evaluation.md"),
response_model=APIResponse[EvaluationRunPublic],
dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))],
dependencies=[
Depends(require_permission(Permission.REQUIRE_PROJECT)),
Depends(monitor_rate("evaluations")),
],
)
def evaluate(
session: SessionDep,
Expand Down
7 changes: 5 additions & 2 deletions backend/app/api/routes/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from app.api.permissions import Permission, require_permission
from app.core.cloud.storage import get_cloud_storage
from app.core.telemetry import log_context
from app.core.rate_monitor import monitor_rate
from app.crud.jobs import JobCrud
from app.crud.llm import get_llm_calls_by_job_id
from app.models import (
Expand All @@ -22,7 +23,6 @@
from app.services.llm.jobs import start_job
from app.utils import APIResponse, validate_callback_url, load_description


logger = logging.getLogger(__name__)

router = APIRouter(tags=["LLM"])
Expand Down Expand Up @@ -50,7 +50,10 @@ def llm_callback_notification(body: APIResponse[LLMCallResponse]):
description=load_description("llm/llm_call.md"),
response_model=APIResponse[LLMJobImmediatePublic],
callbacks=llm_callback_router.routes,
dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))],
dependencies=[
Depends(require_permission(Permission.REQUIRE_PROJECT)),
Depends(monitor_rate("llm_call")),
],
)
def llm_call(
_current_user: AuthContextDep, session: SessionDep, request: LLMCallRequest
Expand Down
5 changes: 5 additions & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ def AWS_S3_BUCKET(self) -> str:
BACKEND_SERVICE_NAME: str = "kaapi-backend"
CRON_SERVICE_NAME: str = "kaapi-cron"

# Threshold Request Rate per minute
THRESHOLD_LLM_CALL_RATE: int = 15
THRESHOLD_COLLECTIONS_RATE: int = 3
THRESHOLD_EVALUATIONS_RATE: int = 3

# Celery Configuration
CELERY_WORKER_CONCURRENCY: int | None = None
CELERY_WORKER_MAX_TASKS_PER_CHILD: int = 150
Expand Down
98 changes: 98 additions & 0 deletions backend/app/core/rate_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import logging
import time

from collections.abc import Callable
from typing import Literal

import redis

from app.api.deps import AuthContextDep
from app.core.config import settings

from app.core.telemetry import record_rate_threshold

logger = logging.getLogger(__name__)

# Categories of rates we want to monitor
RateCategory = Literal["llm_call", "collections", "evaluations"]

# THRESHOLD NUMBERS
THRESHOLDS: dict[RateCategory, int] = {
"llm_call": settings.THRESHOLD_LLM_CALL_RATE,
"collections": settings.THRESHOLD_COLLECTIONS_RATE,
"evaluations": settings.THRESHOLD_EVALUATIONS_RATE,
}

# Delete record after 2 minutes from redis
_EXPIRATION_SECONDS = 120

_redis_client: redis.Redis = redis.from_url(settings.REDIS_URL, decode_responses=True)


# count incrementor after each request and get count
def increment_and_get_count(key: str) -> int | None:
"""Increment the count for the given key and return the new count.
The count will automatically expire after _EXPIRATION_SECONDS.
"""
try:
pipe = _redis_client.pipeline()
pipe.incr(key)
pipe.expire(key, _EXPIRATION_SECONDS)
Comment on lines +39 to +40
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

increment and expire are not atomic; what if increment executes, system crashes, expire does not execute -- key will remain in redis forever

def increment_and_get_count(key: str) -> int | None:
    try:
        # SET NX atomically creates the key with TTL only on first call.
        _redis_client.set(key, 0, ex=_EXPIRATION_SECONDS, nx=True)
        return _redis_client.incr(key)
    except Exception as e:
        logger.error(
            f"[increment_and_get_count] Error incrementing count for {key}: {e}"
        )
        return None

count, _ = pipe.execute()
return count
except Exception as e:
logger.error(
f"[increment_and_get_count] Error incrementing count for {key}: {e}"
)
return None


def monitor_rate(category: RateCategory) -> Callable[[AuthContextDep], None]:
"""Monitor the rate of events for the given category. If the rate exceeds the threshold, record it in telemetry.

Usage:
dependencies=[
Depends(require_permission(Permission.REQUIRE_PROJECT)),
Depends(monitor_rate("{category}")),
]
"""

def _checker(auth_context: AuthContextDep) -> None:
project = auth_context.project
if project is None:
return

threshold = THRESHOLDS.get(category, None)
if threshold is None:
logger.warning(
f"[monitor_rate] No threshold defined for category {category}"
)
return

minute_bucket = int(time.time() // 60)
redis_key = f"rate_monitor:{category}:{project.id}:{minute_bucket}"

try:
count = increment_and_get_count(redis_key)
if count is not None and count == threshold + 1:
logger.warning(
f"[monitor_rate] Rate threshold exceeded for {category} in project {project.id}: count={count}"
)
record_rate_threshold(
project_id=project.id,
project_name=project.name,
category=category,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
request_count=count,
threshold=threshold,
)

except redis.RedisError as e:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

increment_and_get_count returns None after an exception, so this redis.RedisError will practically never fire right? should remove the exception handler there and let this redis.RedisError handle it?

logger.error(
"[monitor_rate] Redis unavailable, skipping rate check "
"(project_id=%s category=%s)",
project.id,
category,
exc_info=e,
)

return _checker
28 changes: 28 additions & 0 deletions backend/app/core/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,34 @@ def record_stale_pending_jobs(
)


def record_rate_threshold(
*,
project_id: int,
project_name: str | None,
category: str,
request_count: int,
threshold: int,
) -> None:
"""Emit rate threshold exceeded event to Sentry."""

try:
if not sentry_sdk.get_client().is_active():
return
with sentry_sdk.push_scope() as scope:
scope.set_tag("alert.type", "threshold_rate_monitor")
scope.set_tag("tenant.project_id", project_id)
scope.set_tag("route_category", category)
scope.set_extra("request_count", request_count)
scope.set_extra("threshold", threshold)
sentry_sdk.capture_message(
f"[Threshold-Monitor] {category} rate limit exceeded for project {project_id} | {project_name}: {request_count} req/min "
f"(limit {threshold}/min)",
level="warning",
)
except Exception as e:
logger.exception("[record_rate_threshold] Failed to emit alert", exc_info=e)


def flush_telemetry(timeout_millis: int = 10000) -> None:
"""Force-flush OTel spans into Sentry, then flush Sentry's transport.

Expand Down
Loading
Loading