diff --git a/packages/core/src/imednet/core/http/executor.py b/packages/core/src/imednet/core/http/executor.py index ce81abf13..83d52679d 100644 --- a/packages/core/src/imednet/core/http/executor.py +++ b/packages/core/src/imednet/core/http/executor.py @@ -21,9 +21,9 @@ wait_random_exponential, ) -from imednet.core.http.circuit_breaker import CircuitBreakerError, get_global_circuit_breaker from imednet.core.http.handlers import handle_response from imednet.core.http.monitor import RequestMonitor +from imednet.core.operations.circuit_breaker import CircuitBreakerError, get_global_circuit_breaker from imednet.core.retry import DefaultRetryPolicy, RetryPolicy, RetryState _SUPPRESSED_LOG_LEVEL = logging.CRITICAL + 1 diff --git a/packages/core/src/imednet/core/operations/__init__.py b/packages/core/src/imednet/core/operations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/core/src/imednet/core/http/circuit_breaker.py b/packages/core/src/imednet/core/operations/circuit_breaker.py similarity index 100% rename from packages/core/src/imednet/core/http/circuit_breaker.py rename to packages/core/src/imednet/core/operations/circuit_breaker.py diff --git a/packages/core/src/imednet/core/operations/executor.py b/packages/core/src/imednet/core/operations/executor.py new file mode 100644 index 000000000..3e2eff190 --- /dev/null +++ b/packages/core/src/imednet/core/operations/executor.py @@ -0,0 +1,149 @@ +""" +Universal execution wrapper that applies exponential backoff retries and circuit breaking +to any compliant operation. +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, TypeVar + +from tenacity import ( + AsyncRetrying, + RetryCallState, + RetryError, + Retrying, + stop_after_attempt, + wait_random_exponential, +) + +from imednet.core.operations.circuit_breaker import get_global_circuit_breaker +from imednet.core.operations.monitor import OperationMonitor + +if TYPE_CHECKING: + from opentelemetry.trace import Tracer +else: + Tracer = Any + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class OperationRetryPolicy(ABC): + @abstractmethod + def should_retry(self, exception: Exception) -> bool: + """Return True if the exception should trigger a retry.""" + pass + + +class DefaultOperationRetryPolicy(OperationRetryPolicy): + def should_retry(self, exception: Exception) -> bool: + # Default fallback: retry on any exception? + # Usually we only retry on specific ones, but for universal wrapper, we can allow everything or leave it configurable + return True + + +class UniversalExecutor: + """Execute arbitrary operations with retry, circuit breaking, and telemetry.""" + + def __init__( + self, + retries: int, + backoff_factor: float, + tracer: Optional[Tracer] = None, + retry_policy: Optional[OperationRetryPolicy] = None, + operation_name: str = "operation", + wait_strategy: Optional[Callable[[RetryCallState], float]] = None, + retry_predicate: Optional[Callable[[RetryCallState], bool]] = None, + **attributes: Any, + ) -> None: + self.retries = retries + self.backoff_factor = backoff_factor + self.tracer = tracer + self.retry_policy = retry_policy or DefaultOperationRetryPolicy() + self.operation_name = operation_name + self.attributes = attributes + self._jitter_wait = wait_random_exponential(multiplier=self.backoff_factor) + self.wait_strategy = wait_strategy or (lambda rs: float(self._jitter_wait(rs))) + self.retry_predicate = retry_predicate or self._should_retry_wrapper + + def _should_retry_wrapper(self, retry_state: RetryCallState) -> bool: + if retry_state.outcome and retry_state.outcome.failed: + exc = retry_state.outcome.exception() + if isinstance(exc, Exception): + return self.retry_policy.should_retry(exc) + return False + + def execute(self, func: Callable[[], T]) -> T: + """Synchronous execution.""" + get_global_circuit_breaker().check_request_allowed() + + retryer = Retrying( + stop=stop_after_attempt(self.retries + 1), + wait=self.wait_strategy, + retry=self.retry_predicate, + reraise=False, + ) + + with OperationMonitor(self.tracer, self.operation_name, **self.attributes) as monitor: + try: + result: Any = retryer(func) + get_global_circuit_breaker().record_success() + monitor.on_success() + return result + except RetryError as e: + get_global_circuit_breaker().record_failure() + cause = e.last_attempt.exception() if e.last_attempt else e + if isinstance(cause, Exception): + try: + monitor.on_retry_error(cause, self.retries) + except Exception as _exc: + if _exc is not cause: + raise + if cause is not None and cause is not e: + raise cause + raise + except Exception as e: + get_global_circuit_breaker().record_failure() + monitor.on_failure(e) + raise + + async def execute_async(self, func: Callable[[], Awaitable[T]]) -> T: + """Asynchronous execution.""" + get_global_circuit_breaker().check_request_allowed() + + retryer = AsyncRetrying( + stop=stop_after_attempt(self.retries + 1), + wait=self.wait_strategy, + retry=self.retry_predicate, + reraise=False, + ) + + async with OperationMonitor(self.tracer, self.operation_name, **self.attributes) as monitor: + try: + + async def _async_wrapper() -> T: + return await func() + + result: Any = await retryer(_async_wrapper) + get_global_circuit_breaker().record_success() + monitor.on_success() + return result + except RetryError as e: + get_global_circuit_breaker().record_failure() + cause = e.last_attempt.exception() if e.last_attempt else e + if isinstance(cause, Exception): + try: + monitor.on_retry_error(cause, self.retries) + except Exception as _exc: + if _exc is not cause: + raise + if cause is not None and cause is not e: + raise cause + raise + except Exception as e: + get_global_circuit_breaker().record_failure() + monitor.on_failure(e) + raise diff --git a/packages/core/src/imednet/core/operations/monitor.py b/packages/core/src/imednet/core/operations/monitor.py new file mode 100644 index 000000000..4a21f796e --- /dev/null +++ b/packages/core/src/imednet/core/operations/monitor.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import logging +import time +from contextlib import nullcontext +from typing import TYPE_CHECKING, Any, NoReturn, Optional + +if TYPE_CHECKING: + from opentelemetry.trace import Tracer +else: + Tracer = Any + +logger = logging.getLogger(__name__) + + +class OperationMonitor: + """Helper to handle generic operation monitoring (tracing, timing, logging).""" + + def __init__(self, tracer: Optional[Tracer], operation_name: str, **attributes: Any) -> None: + self.tracer = tracer + self.operation_name = operation_name + self.attributes = attributes + self.start_time: float = 0.0 + self.span: Any = None + self._cm: Any = None + + def _create_cm(self) -> Any: + if self.tracer: + return self.tracer.start_as_current_span( + self.operation_name, + attributes=self.attributes, + ) + return nullcontext() + + def __enter__(self) -> "OperationMonitor": + self._cm = self._create_cm() + self.span = self._cm.__enter__() + self.start_time = time.monotonic() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._cm: + self._cm.__exit__(exc_type, exc_val, exc_tb) + + async def __aenter__(self) -> "OperationMonitor": + self._cm = self._create_cm() + # Handle async context managers if the tracer supports them + if hasattr(self._cm, "__aenter__"): + self.span = await self._cm.__aenter__() + else: + self.span = self._cm.__enter__() + self.start_time = time.monotonic() + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._cm: + if hasattr(self._cm, "__aexit__"): + await self._cm.__aexit__(exc_type, exc_val, exc_tb) + else: + self._cm.__exit__(exc_type, exc_val, exc_tb) + + def on_success(self, **extra_attributes: Any) -> None: + latency = time.monotonic() - self.start_time + logger.info( + f"{self.operation_name} succeeded", + extra={**self.attributes, **extra_attributes, "latency": latency}, + ) + if self.span: + for k, v in extra_attributes.items(): + self.span.set_attribute(k, v) + self.span.set_attribute("status", "success") + + def on_retry_error(self, cause: Exception, retries: int) -> NoReturn: + logger.error( + f"{self.operation_name} failed after retries", + extra={**self.attributes, "retries": retries}, + ) + if self.span: + self.span.set_attribute("status", "failed") + self.span.set_attribute("retries", retries) + raise cause + + def on_failure(self, cause: Exception) -> None: + if self.span: + self.span.set_attribute("status", "failed") diff --git a/packages/core/src/imednet/core/operations/protocols.py b/packages/core/src/imednet/core/operations/protocols.py new file mode 100644 index 000000000..98e2ee284 --- /dev/null +++ b/packages/core/src/imednet/core/operations/protocols.py @@ -0,0 +1,17 @@ +from typing import Any, Awaitable, Protocol, TypeVar, runtime_checkable + +T = TypeVar("T", covariant=True) + + +@runtime_checkable +class OperationProtocol(Protocol[T]): + """Protocol for synchronous operations.""" + + def execute(self, *args: Any, **kwargs: Any) -> T: ... + + +@runtime_checkable +class AsyncOperationProtocol(Protocol[T]): + """Protocol for asynchronous operations.""" + + def execute(self, *args: Any, **kwargs: Any) -> Awaitable[T]: ... diff --git a/packages/core/src/imednet/integrations/sink_base.py b/packages/core/src/imednet/integrations/sink_base.py index 8691942d9..f21b55a42 100644 --- a/packages/core/src/imednet/integrations/sink_base.py +++ b/packages/core/src/imednet/integrations/sink_base.py @@ -114,6 +114,7 @@ class SinkConfig: extra: dict[str, Any] = field(default_factory=dict) quality_gate_enabled: bool = False min_schema_readiness_score: float = 100.0 + tracer: Optional[Any] = field(default=None, repr=False) # --------------------------------------------------------------------------- diff --git a/packages/plugins-sinks/src/imednet_sinks/document.py b/packages/plugins-sinks/src/imednet_sinks/document.py index 3d8fe7553..cb97e4e2a 100644 --- a/packages/plugins-sinks/src/imednet_sinks/document.py +++ b/packages/plugins-sinks/src/imednet_sinks/document.py @@ -179,63 +179,49 @@ def _connect(self) -> None: # ------------------------------------------------------------------ def write_batch(self, records: Sequence[Any], *, batch_id: str) -> int: - """Write *records* to MongoDB using upsert (idempotent) or insert. - - Parameters - ---------- - records: - Sequence of typed ``Record`` model instances. - batch_id: - Idempotency key (e.g. ``"MYSTUDY/FORM1/0"``). - - Returns - ------- - int - Number of records written (upserted or inserted). - """ + """Write *records* to MongoDB using upsert (idempotent) or insert.""" docs = [_record_to_document(r, self._study_key) for r in records] if not docs: return 0 - last_exc: Optional[Exception] = None - for attempt in range(self.config.max_retries + 1): - try: - if self.config.idempotent: - pymongo = _require_optional_dep("pymongo", "mongodb") - ops = [ - pymongo.UpdateOne( - {"_id": doc["_id"]}, - {"$set": doc}, - upsert=True, - ) - for doc in docs - ] - result = self._collection.bulk_write(ops, ordered=False) - written = len(docs) - else: - result = self._collection.insert_many(docs, ordered=False) - written = len(result.inserted_ids) - - logger.debug("Wrote batch %s (%d records)", batch_id, written) - return written - except Exception as exc: # noqa: BLE001 - last_exc = exc - if attempt < self.config.max_retries: - delay = self.config.retry_backoff * (2**attempt) - logger.warning( - "Batch %s attempt %d failed (%s); retrying in %.1fs", - batch_id, - attempt + 1, - exc, - delay, - ) - time.sleep(delay) + from imednet.core.operations.executor import UniversalExecutor - raise ExportBatchError( - f"Batch {batch_id!r} failed after {self.config.max_retries + 1} attempts: {last_exc}", + def execute_export() -> int: + if self.config.idempotent: + pymongo = _require_optional_dep("pymongo", "mongodb") + ops = [ + pymongo.UpdateOne( + {"_id": doc["_id"]}, + {"$set": doc}, + upsert=True, + ) + for doc in docs + ] + self._collection.bulk_write(ops, ordered=False) + written = len(docs) + else: + result = self._collection.insert_many(docs, ordered=False) + written = len(result.inserted_ids) + + logger.debug("Wrote batch %s (%d records)", batch_id, written) + return written + + executor = UniversalExecutor( + retries=self.config.max_retries, + backoff_factor=self.config.retry_backoff, + tracer=self.config.tracer, + operation_name="export_mongodb", batch_id=batch_id, ) + try: + return executor.execute(execute_export) + except Exception as exc: + raise ExportBatchError( + f"Batch {batch_id!r} failed after {self.config.max_retries + 1} attempts: {exc}", + batch_id=batch_id, + ) from exc + def flush(self) -> None: """No-op: MongoDB writes are committed per bulk operation.""" diff --git a/packages/plugins-sinks/src/imednet_sinks/graph.py b/packages/plugins-sinks/src/imednet_sinks/graph.py index f503cb986..02c772254 100644 --- a/packages/plugins-sinks/src/imednet_sinks/graph.py +++ b/packages/plugins-sinks/src/imednet_sinks/graph.py @@ -190,20 +190,7 @@ def _connect(self) -> None: # ------------------------------------------------------------------ def write_batch(self, records: Sequence[Any], *, batch_id: str) -> int: - """Write *records* to Neo4j using MERGE (idempotent) or CREATE. - - Parameters - ---------- - records: - Sequence of typed ``Record`` model instances. - batch_id: - Idempotency key (e.g. ``"MYSTUDY/FORM1/0"``). - - Returns - ------- - int - Number of records written. - """ + """Write *records* to Neo4j using MERGE (idempotent) or CREATE.""" rows = [_record_to_row(r, self._study_key) for r in records] if not rows: return 0 @@ -211,31 +198,30 @@ def write_batch(self, records: Sequence[Any], *, batch_id: str) -> int: cypher = _MERGE_RECORD_CYPHER if self.config.idempotent else _CREATE_RECORD_CYPHER cfg = self.config if isinstance(self.config, Neo4jSinkConfig) else Neo4jSinkConfig() - last_exc: Optional[Exception] = None - for attempt in range(self.config.max_retries + 1): - try: - with self._driver.session(database=cfg.database) as session: - session.run(cypher, rows=rows) - logger.debug("Wrote batch %s (%d records)", batch_id, len(rows)) - return len(rows) - except Exception as exc: # noqa: BLE001 - last_exc = exc - if attempt < self.config.max_retries: - delay = self.config.retry_backoff * (2**attempt) - logger.warning( - "Batch %s attempt %d failed (%s); retrying in %.1fs", - batch_id, - attempt + 1, - exc, - delay, - ) - time.sleep(delay) - - raise ExportBatchError( - f"Batch {batch_id!r} failed after {self.config.max_retries + 1} attempts: {last_exc}", + from imednet.core.operations.executor import UniversalExecutor + + def execute_export() -> int: + with self._driver.session(database=cfg.database) as session: + session.run(cypher, rows=rows) + logger.debug("Wrote batch %s (%d records)", batch_id, len(rows)) + return len(rows) + + executor = UniversalExecutor( + retries=self.config.max_retries, + backoff_factor=self.config.retry_backoff, + tracer=self.config.tracer, + operation_name="export_graph", batch_id=batch_id, ) + try: + return executor.execute(execute_export) + except Exception as exc: + raise ExportBatchError( + f"Batch {batch_id!r} failed after {self.config.max_retries + 1} attempts: {exc}", + batch_id=batch_id, + ) from exc + def flush(self) -> None: """No-op: Neo4j writes are committed per transaction.""" diff --git a/packages/plugins-sinks/src/imednet_sinks/warehouse.py b/packages/plugins-sinks/src/imednet_sinks/warehouse.py index 43aa0ad5d..704d223ed 100644 --- a/packages/plugins-sinks/src/imednet_sinks/warehouse.py +++ b/packages/plugins-sinks/src/imednet_sinks/warehouse.py @@ -234,24 +234,10 @@ def _connect(self) -> None: # ------------------------------------------------------------------ def write_batch(self, records: Sequence[Any], *, batch_id: str) -> int: - """Write *records* to Snowflake via Parquet staging + COPY INTO. - - Parameters - ---------- - records: - Sequence of typed ``Record`` model instances or plain dicts. - batch_id: - Idempotency key (e.g. ``"MYSTUDY/FORM1/0"``). - - Returns - ------- - int - Number of rows loaded. - """ + """Write *records* to Snowflake via Parquet staging + COPY INTO.""" if not records: return 0 - # 1. Convert to Parquet arrow_table = _records_to_arrow_table(records) safe_batch = batch_id.replace("/", "_").replace(" ", "_") local_path = Path(self._staging_dir) / f"{safe_batch}.parquet" @@ -261,13 +247,12 @@ def write_batch(self, records: Sequence[Any], *, batch_id: str) -> int: cfg = self._cfg stage_path = f"@{cfg.stage}/{cfg.stage_prefix}/{safe_batch}.parquet" - last_exc: Optional[Exception] = None - for attempt in range(self.config.max_retries + 1): + from imednet.core.operations.executor import UniversalExecutor + + def execute_export() -> int: + cur = self._conn.cursor() try: - cur = self._conn.cursor() - # 2. PUT to stage cur.execute(f"PUT file://{local_path} @{cfg.stage}/{cfg.stage_prefix}/") # nosem - # 3. COPY INTO table force_clause = "FORCE = FALSE" if self.config.idempotent else "FORCE = TRUE" cur.execute( f"COPY INTO {cfg.database}.{cfg.schema}.{cfg.table} " @@ -277,7 +262,6 @@ def write_batch(self, records: Sequence[Any], *, batch_id: str) -> int: f"{force_clause}" ) # nosem rows_loaded = len(records) - cur.close() logger.debug( "Loaded batch %s (%d rows) via stage %s", batch_id, @@ -286,24 +270,25 @@ def write_batch(self, records: Sequence[Any], *, batch_id: str) -> int: ) self._append_manifest(batch_id, stage_path, rows_loaded) return rows_loaded - except Exception as exc: # noqa: BLE001 - last_exc = exc - if attempt < self.config.max_retries: - delay = self.config.retry_backoff * (2**attempt) - logger.warning( - "Batch %s attempt %d failed (%s); retrying in %.1fs", - batch_id, - attempt + 1, - exc, - delay, - ) - time.sleep(delay) - - raise ExportBatchError( - f"Batch {batch_id!r} failed after {self.config.max_retries + 1} attempts: {last_exc}", + finally: + cur.close() + + executor = UniversalExecutor( + retries=self.config.max_retries, + backoff_factor=self.config.retry_backoff, + tracer=self.config.tracer, + operation_name="export_warehouse", batch_id=batch_id, ) + try: + return executor.execute(execute_export) + except Exception as exc: + raise ExportBatchError( + f"Batch {batch_id!r} failed after {self.config.max_retries + 1} attempts: {exc}", + batch_id=batch_id, + ) from exc + def flush(self) -> None: """No-op: each batch is committed individually.""" diff --git a/scripts/post_smoke_record.py b/scripts/post_smoke_record.py index af98d5bfa..fd3527897 100644 --- a/scripts/post_smoke_record.py +++ b/scripts/post_smoke_record.py @@ -42,7 +42,7 @@ logger = logging.getLogger(__name__) -POLL_TIMEOUT = 90 +POLL_TIMEOUT = 300 SKIP_EXIT_CODE = 2 @@ -154,7 +154,11 @@ def submit_record(sdk: ImednetSDK, study_key: str, record: Dict[str, Any], *, ti """Create ``record`` and return the resulting batch ID.""" job = sdk.records.create(study_key, [record]) if not job.batch_id: - if not job.state or job.state in ("COMPLETED", "SUCCESS") or (job.state and job.state.upper() in ("COMPLETED", "SUCCESS")): + if ( + not job.state + or job.state in ("COMPLETED", "SUCCESS") + or (job.state and job.state.upper() in ("COMPLETED", "SUCCESS")) + ): return "sync-created" raise RuntimeError(f"Record creation returned no batch ID: {job}") diff --git a/tests/conftest.py b/tests/conftest.py index 04781e0cd..e0ca749dd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,7 +59,7 @@ def reset_study_context_between_tests(): @pytest.fixture(autouse=True) def reset_circuit_breaker_between_tests(): - from imednet.core.http.circuit_breaker import get_global_circuit_breaker + from imednet.core.operations.circuit_breaker import get_global_circuit_breaker get_global_circuit_breaker().reset() yield diff --git a/tests/unit/core/operations/__init__.py b/tests/unit/core/operations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/core/operations/test_executor.py b/tests/unit/core/operations/test_executor.py new file mode 100644 index 000000000..7fea88362 --- /dev/null +++ b/tests/unit/core/operations/test_executor.py @@ -0,0 +1,75 @@ +import pytest + +from imednet.core.operations.circuit_breaker import get_global_circuit_breaker +from imednet.core.operations.executor import UniversalExecutor +from imednet.core.operations.protocols import OperationProtocol + + +class RESTTask(OperationProtocol[str]): + def __init__(self, fail_times=0): + self.fail_times = fail_times + self.attempts = 0 + + def execute(self) -> str: + self.attempts += 1 + if self.attempts <= self.fail_times: + raise ValueError("HTTP Error") + return "REST Success" + + +class NonRESTTask(OperationProtocol[str]): + def __init__(self, fail_times=0): + self.fail_times = fail_times + self.attempts = 0 + + def execute(self) -> str: + self.attempts += 1 + if self.attempts <= self.fail_times: + raise RuntimeError("DB Error") + return "Non-REST Success" + + +def test_universal_executor_supports_rest_and_non_rest(): + get_global_circuit_breaker().reset() + executor = UniversalExecutor(retries=2, backoff_factor=0.01) + + rest_task = RESTTask(fail_times=1) + result = executor.execute(rest_task.execute) + assert result == "REST Success" + assert rest_task.attempts == 2 + + non_rest_task = NonRESTTask(fail_times=1) + result = executor.execute(non_rest_task.execute) + assert result == "Non-REST Success" + assert non_rest_task.attempts == 2 + + +def test_universal_executor_fails_after_retries(): + get_global_circuit_breaker().reset() + executor = UniversalExecutor(retries=1, backoff_factor=0.01) + + task = RESTTask(fail_times=5) + with pytest.raises(ValueError, match="HTTP Error"): + executor.execute(task.execute) + assert task.attempts == 2 + + +@pytest.mark.asyncio +async def test_universal_executor_async(): + get_global_circuit_breaker().reset() + executor = UniversalExecutor(retries=1, backoff_factor=0.01) + + attempts = 0 + + async def failing_task(): + nonlocal attempts + attempts += 1 + raise ValueError("Async Error") + + # Use a regular function returning a coroutine, to match Callable[[], Awaitable[T]] + def task_factory(): + return failing_task() + + with pytest.raises(ValueError, match="Async Error"): + await executor.execute_async(task_factory) + assert attempts == 2