From f337788752001b5579c3a5b6ad7ce54d1ce51b79 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Tue, 9 Jun 2026 15:00:19 -0700 Subject: [PATCH] feat: add otel logger injection --- .../examples-catalog.json | 15 + .../src/otel/otel_logger_example.py | 67 +++++ .../template.yaml | 18 ++ .../test/otel/test_otel_logger_example.py | 32 ++ .../__init__.py | 2 + .../logger.py | 86 ++++++ .../plugin.py | 93 ++++-- .../tests/test_logger.py | 183 ++++++++++++ .../tests/test_plugin.py | 276 +++++++++++++++++- .../context.py | 9 +- .../execution.py | 11 +- .../plugin.py | 53 +++- .../tests/plugin_test.py | 119 +++++++- 13 files changed, 927 insertions(+), 37 deletions(-) create mode 100644 packages/aws-durable-execution-sdk-python-examples/src/otel/otel_logger_example.py create mode 100644 packages/aws-durable-execution-sdk-python-examples/test/otel/test_otel_logger_example.py create mode 100644 packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/logger.py create mode 100644 packages/aws-durable-execution-sdk-python-otel/tests/test_logger.py diff --git a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json index d921bce..8e5f32f 100644 --- a/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json +++ b/packages/aws-durable-execution-sdk-python-examples/examples-catalog.json @@ -624,6 +624,21 @@ "ExecutionTimeout": 300 }, "path": "./src/plugin/execution_with_otel.py" + }, + { + "name": "Otel Logger Example", + "description": "Demonstrates OTel-enriched logging correlated to durable spans", + "handler": "otel_logger_example.handler", + "integration": true, + "durableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + }, + "loggingConfig": { + "ApplicationLogLevel": "INFO", + "LogFormat": "JSON" + }, + "path": "./src/otel/otel_logger_example.py" } ] } diff --git a/packages/aws-durable-execution-sdk-python-examples/src/otel/otel_logger_example.py b/packages/aws-durable-execution-sdk-python-examples/src/otel/otel_logger_example.py new file mode 100644 index 0000000..8490a4a --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-examples/src/otel/otel_logger_example.py @@ -0,0 +1,67 @@ +"""Demonstrates OTel-enriched logging in a durable execution. + +The DurableExecutionOtelPlugin wraps the execution logger (enrich_logger=True +by default) so every log line emitted through context.logger / step_context.logger +is automatically enriched with the active OpenTelemetry trace context +(otel.trace_id, otel.span_id, otel.trace_sampled). This lets logs correlate to +the spans the plugin emits without any user code changes. + +Logs emitted: +- at the top level correlate to the invocation span +- inside a step correlate to that step's span +- inside a child context correlate to the child-context span +""" + +from typing import Any + +from aws_durable_execution_sdk_python_otel import DurableExecutionOtelPlugin +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter + +from aws_durable_execution_sdk_python import StepContext +from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_step, + durable_with_child_context, +) +from aws_durable_execution_sdk_python.execution import durable_execution + + +tracer_provider = TracerProvider() +trace.set_tracer_provider(tracer_provider) + +# enrich_logger defaults to True, so the execution logger is wrapped with OTel +# trace context injection (otel.trace_id, otel.span_id, otel.trace_sampled). +otel = DurableExecutionOtelPlugin(tracer_provider) + + +@durable_step +def greet(step_context: StepContext, name: str) -> str: + # Logged inside a step: enriched with this step's span_id. + # Note: avoid reserved LogRecord keys (e.g. "name") in extra. + step_context.logger.info("Greeting inside step", extra={"greeting_name": name}) + return f"hello {name}" + + +@durable_with_child_context +def greet_in_child(child_context: DurableContext, name: str) -> str: + # Logged inside a child context: enriched with the child-context span_id. + child_context.logger.info("Entering child context") + result: str = child_context.step(greet(name), name="child-greet") + child_context.logger.info("Leaving child context", extra={"result": result}) + return result + + +@durable_execution(plugins=[otel]) +def handler(_event: Any, context: DurableContext) -> str: + # Logged at the top level: enriched with the invocation span_id. + context.logger.info("Workflow started") + + top: str = context.step(greet("world"), name="top-greet") + nested: str = context.run_in_child_context( + greet_in_child("nested"), name="child-context" + ) + + context.logger.info("Workflow completed", extra={"top": top, "nested": nested}) + return f"{top} | {nested}" diff --git a/packages/aws-durable-execution-sdk-python-examples/template.yaml b/packages/aws-durable-execution-sdk-python-examples/template.yaml index d56b1af..d469e71 100644 --- a/packages/aws-durable-execution-sdk-python-examples/template.yaml +++ b/packages/aws-durable-execution-sdk-python-examples/template.yaml @@ -1013,6 +1013,24 @@ "ExecutionTimeout": 300 } } + }, + "OtelLoggerExample": { + "Type": "AWS::Serverless::Function", + "Properties": { + "CodeUri": "build/", + "Handler": "otel_logger_example.handler", + "Description": "Demonstrates OTel-enriched logging correlated to durable spans", + "Role": { + "Fn::GetAtt": [ + "DurableFunctionRole", + "Arn" + ] + }, + "DurableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + } + } } } } \ No newline at end of file diff --git a/packages/aws-durable-execution-sdk-python-examples/test/otel/test_otel_logger_example.py b/packages/aws-durable-execution-sdk-python-examples/test/otel/test_otel_logger_example.py new file mode 100644 index 0000000..e366002 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-examples/test/otel/test_otel_logger_example.py @@ -0,0 +1,32 @@ +"""Tests for the OTel-enriched logger example.""" + +import pytest + +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import OperationType +from src.otel import otel_logger_example +from test.conftest import deserialize_operation_payload + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=otel_logger_example.handler, + lambda_function_name="Otel Logger Example", +) +def test_otel_logger_example(durable_runner): + """Verify the OTel logger example runs and produces the expected result.""" + with durable_runner: + result = durable_runner.run(input="{}", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + assert deserialize_operation_payload(result.result) == "hello world | hello nested" + + # The top-level step is named "top-greet". + top_step = result.get_step("top-greet") + assert deserialize_operation_payload(top_step.result) == "hello world" + + # The child context wraps a nested step, so a CONTEXT operation exists. + context_ops = [ + op for op in result.operations if op.operation_type is OperationType.CONTEXT + ] + assert len(context_ops) >= 1 diff --git a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/__init__.py b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/__init__.py index 7ba31ca..d688482 100644 --- a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/__init__.py +++ b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/__init__.py @@ -9,6 +9,7 @@ from aws_durable_execution_sdk_python_otel.deterministic_id_generator import ( DeterministicIdGenerator, ) +from aws_durable_execution_sdk_python_otel.logger import OtelEnrichedLogger from aws_durable_execution_sdk_python_otel.plugin import ( DurableExecutionOtelPlugin, ) @@ -19,6 +20,7 @@ "ContextExtractor", "DeterministicIdGenerator", "DurableExecutionOtelPlugin", + "OtelEnrichedLogger", "w3c_client_context_extractor", "xray_context_extractor", ] diff --git a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/logger.py b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/logger.py new file mode 100644 index 0000000..ff76eb3 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/logger.py @@ -0,0 +1,86 @@ +"""OTel-enriched logger for durable executions. + +Provides a LoggerInterface wrapper that injects OpenTelemetry trace context +(trace_id, span_id, trace_sampled) into every log message's extra dict. This +enables log-trace correlation in observability backends without changing user code. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING + +from opentelemetry.trace import TraceFlags + + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python.types import LoggerInterface + + from aws_durable_execution_sdk_python_otel.plugin import DurableExecutionOtelPlugin + + +class OtelEnrichedLogger: + """LoggerInterface wrapper that injects OTel trace context into log extra fields. + + The span context is resolved by the plugin via get_current_span_context(), + which returns the active operation span inside steps and the invocation span + for top-level handler code. + + Injected fields: + - otel.trace_id: 32-char hex trace identifier + - otel.span_id: 16-char hex span identifier + - otel.trace_sampled: boolean indicating if the trace is sampled + + Args: + inner: The underlying logger to delegate to after enrichment. + plugin: The OTel plugin instance that resolves the current span context. + """ + + def __init__( + self, inner: LoggerInterface, plugin: DurableExecutionOtelPlugin + ) -> None: + self._inner = inner + self._plugin = plugin + + def debug( + self, msg: object, *args: object, extra: Mapping[str, object] | None = None + ) -> None: + self._inner.debug(msg, *args, extra=self._enrich(extra)) + + def info( + self, msg: object, *args: object, extra: Mapping[str, object] | None = None + ) -> None: + self._inner.info(msg, *args, extra=self._enrich(extra)) + + def warning( + self, msg: object, *args: object, extra: Mapping[str, object] | None = None + ) -> None: + self._inner.warning(msg, *args, extra=self._enrich(extra)) + + def error( + self, msg: object, *args: object, extra: Mapping[str, object] | None = None + ) -> None: + self._inner.error(msg, *args, extra=self._enrich(extra)) + + def exception( + self, msg: object, *args: object, extra: Mapping[str, object] | None = None + ) -> None: + self._inner.exception(msg, *args, extra=self._enrich(extra)) + + def _enrich(self, extra: Mapping[str, object] | None) -> dict[str, object]: + """Inject OTel trace context into the extra dict. + + trace_id, span_id, and trace_sampled come from the span context resolved + by the plugin, so the values always match the exported spans. + """ + enriched: dict[str, object] = dict(extra) if extra else {} + + span_context = self._plugin.get_current_span_context() + if span_context and span_context.is_valid: + enriched["otel.trace_id"] = format(span_context.trace_id, "032x") + enriched["otel.span_id"] = format(span_context.span_id, "016x") + enriched["otel.trace_sampled"] = bool( + span_context.trace_flags & TraceFlags.SAMPLED + ) + + return enriched diff --git a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py index a2dda25..cb29044 100644 --- a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py +++ b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py @@ -7,19 +7,6 @@ import threading from typing import TYPE_CHECKING, Any -from opentelemetry import trace, context -from opentelemetry.context import Context -from opentelemetry.sdk.trace.sampling import TraceIdRatioBased -from opentelemetry.trace import ( - Tracer, - StatusCode, - SpanContext, - Span, - TracerProvider, - Link, - TraceFlags, -) - from aws_durable_execution_sdk_python.lambda_service import OperationType from aws_durable_execution_sdk_python.plugin import ( DurableInstrumentationPlugin, @@ -27,10 +14,23 @@ InvocationStartInfo, OperationEndInfo, OperationStartInfo, - UserFunctionStartInfo, UserFunctionEndInfo, UserFunctionOutcome, + UserFunctionStartInfo, +) +from opentelemetry import context, trace +from opentelemetry.context import Context +from opentelemetry.sdk.trace.sampling import TraceIdRatioBased +from opentelemetry.trace import ( + Link, + Span, + SpanContext, + StatusCode, + TraceFlags, + Tracer, + TracerProvider, ) + from aws_durable_execution_sdk_python_otel.context_extractors import ( ContextExtractor, xray_context_extractor, @@ -39,9 +39,11 @@ DeterministicIdGenerator, operation_id_to_span_id, ) +from aws_durable_execution_sdk_python_otel.logger import OtelEnrichedLogger + if TYPE_CHECKING: - pass + from aws_durable_execution_sdk_python.types import LoggerInterface logger = logging.getLogger(__name__) @@ -84,6 +86,7 @@ def __init__( context_extractor: ContextExtractor | None = None, sampling_rate: float = 1.0, instrument_name: str = DEFAULT_INSTRUMENT_NAME, + enrich_logger: bool = True, ) -> None: """Initialize the plugin with an OpenTelemetry tracer provider. @@ -91,6 +94,7 @@ def __init__( deterministic ID generator and sampling strategy so spans for a durable execution share stable trace and logical operation identifiers. """ + self._enrich_logger = enrich_logger self._context_extractor: ContextExtractor = ( context_extractor or xray_context_extractor ) @@ -109,6 +113,24 @@ def __init__( self._operation_spans: dict[str | None, Span] = {} self._operation_spans_lock = threading.RLock() + def wrap_logger(self, logger: LoggerInterface) -> LoggerInterface | None: + """Wrap the execution logger to inject OTel trace context. + + When enrich_logger is enabled (default), returns an OtelEnrichedLogger + that adds trace_id, span_id, and trace_sampled to every log message. + Idempotent: returns None if the logger is already an OtelEnrichedLogger. + + Args: + logger: The current logger interface from the execution context. + + Returns: + An OtelEnrichedLogger wrapping the input, or None if disabled or + already wrapped. + """ + if not self._enrich_logger or isinstance(logger, OtelEnrichedLogger): + return None + return OtelEnrichedLogger(inner=logger, plugin=self) + def _set_span(self, operation_id: str | None, span: Span) -> None: """Register the active span for an operation ID.""" with self._operation_spans_lock: @@ -124,6 +146,34 @@ def _get_span(self, operation_id: str | None) -> Span | None: with self._operation_spans_lock: return self._operation_spans.get(operation_id) + def get_current_span_context(self) -> SpanContext | None: + """Return the span context to use for log correlation. + + Resolution order: + 1. The span attached to the OTel thread-local context. Inside a step or + child context this is the active operation span (attached in + on_user_function_start), and between operations it is the enclosing + operation span (restored in on_user_function_end). + 2. The invocation span from the plugin registry. This is the path used + for top-level handler code: the invocation span is never attached to + the worker thread's context, so the registry is the only way to + resolve it. + + Returns: + A valid SpanContext, or None if no span is active. + """ + span_context = trace.get_current_span().get_span_context() + if span_context and span_context.is_valid: + return span_context + + invocation_span = self._get_span(None) + if invocation_span: + invocation_context = invocation_span.get_span_context() + if invocation_context and invocation_context.is_valid: + return invocation_context + + return None + # ------------------------------------------------------------------ # Context resolution # ------------------------------------------------------------------ @@ -250,7 +300,7 @@ def on_invocation_start(self, info: InvocationStartInfo) -> None: self._start_span( operation_id=None, - name=f"invocation", + name="invocation", attributes=self._extract_attributes(info), ) @@ -407,7 +457,16 @@ def on_user_function_end(self, info: UserFunctionEndInfo) -> None: if end_timestamp is not None and end_timestamp == info.start_time: end_timestamp += datetime.timedelta(microseconds=1) self._end_span(info.operation_id, end_timestamp) - # We don't call context.detach because the next operation will override it anyway + # Restore the enclosing operation span as current so code that runs + # after this operation (e.g. between steps in a child context) + # correlates to its enclosing operation, not the operation that just + # ended. For a top-level operation (parent_id is None) this is the + # invocation span; for a nested operation it is the parent context span. + parent_span = self._get_span(info.parent_id) or self._get_span(None) + if parent_span: + context.attach( + trace.set_span_in_context(parent_span, self._extracted_context) + ) def _extract_attributes(self, info: Any) -> dict[str, str]: """Extract durable execution fields as OpenTelemetry span attributes. diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_logger.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_logger.py new file mode 100644 index 0000000..b548026 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_logger.py @@ -0,0 +1,183 @@ +"""Tests for the OTel-enriched logger.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from unittest.mock import Mock + +from aws_durable_execution_sdk_python.lambda_service import OperationType +from aws_durable_execution_sdk_python.plugin import ( + InvocationStartInfo, + UserFunctionStartInfo, +) +from opentelemetry.context import Context +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +from aws_durable_execution_sdk_python_otel.logger import OtelEnrichedLogger +from aws_durable_execution_sdk_python_otel.plugin import DurableExecutionOtelPlugin + + +START_TIME = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) +EXECUTION_ARN = "arn:aws:lambda:us-west-2:123456789012:function:workflow:$LATEST" + + +def _create_plugin( + enrich_logger: bool = True, +) -> tuple[DurableExecutionOtelPlugin, InMemorySpanExporter]: + """Create a plugin wired to an in-memory span exporter.""" + exporter = InMemorySpanExporter() + trace_provider = TracerProvider() + trace_provider.add_span_processor(SimpleSpanProcessor(exporter)) + plugin = DurableExecutionOtelPlugin( + trace_provider=trace_provider, + context_extractor=lambda _: Context(), + enrich_logger=enrich_logger, + ) + return plugin, exporter + + +def _invocation_start_info() -> InvocationStartInfo: + """Create standard invocation start info for tests.""" + return InvocationStartInfo( + request_id="request-1", + execution_arn=EXECUTION_ARN, + start_time=START_TIME, + is_first_invocation=True, + ) + + +def _user_function_start_info(operation_id: str) -> UserFunctionStartInfo: + """Create standard user function start info for tests.""" + return UserFunctionStartInfo( + operation_id=operation_id, + operation_type=OperationType.STEP, + sub_type=None, + name="fetch-user", + parent_id=None, + start_time=START_TIME, + is_replay_children=False, + attempt=1, + ) + + +def test_wrap_logger_returns_enriched_logger_when_enabled(): + """Verify wrap_logger wraps the logger when enrich_logger is enabled.""" + plugin, _ = _create_plugin(enrich_logger=True) + inner = Mock() + + wrapped = plugin.wrap_logger(inner) + + assert isinstance(wrapped, OtelEnrichedLogger) + + +def test_wrap_logger_returns_none_when_disabled(): + """Verify wrap_logger is a no-op when enrich_logger is disabled.""" + plugin, _ = _create_plugin(enrich_logger=False) + inner = Mock() + + assert plugin.wrap_logger(inner) is None + + +def test_wrap_logger_is_idempotent(): + """Verify wrap_logger does not double-wrap an already-wrapped logger.""" + plugin, _ = _create_plugin(enrich_logger=True) + inner = Mock() + + wrapped = plugin.wrap_logger(inner) + assert plugin.wrap_logger(wrapped) is None + + +def test_enriched_logger_delegates_to_inner(): + """Verify all log levels delegate to the underlying logger.""" + plugin, _ = _create_plugin() + inner = Mock() + logger = OtelEnrichedLogger(inner=inner, plugin=plugin) + + logger.debug("debug message") + logger.info("info message") + logger.warning("warning message") + logger.error("error message") + logger.exception("exception message") + + inner.debug.assert_called_once() + inner.info.assert_called_once() + inner.warning.assert_called_once() + inner.error.assert_called_once() + inner.exception.assert_called_once() + + +def test_enriched_logger_injects_trace_id_from_invocation_span(): + """Verify trace_id is injected from the plugin's invocation span.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + inner = Mock() + logger = OtelEnrichedLogger(inner=inner, plugin=plugin) + + logger.info("hello") + + _, kwargs = inner.info.call_args + extra = kwargs["extra"] + assert "otel.trace_id" in extra + assert len(extra["otel.trace_id"]) == 32 + assert "otel.span_id" in extra + assert len(extra["otel.span_id"]) == 16 + assert "otel.trace_sampled" in extra + + +def test_enriched_logger_uses_operation_span_inside_user_function(): + """Verify span_id reflects the active operation span during user code.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + operation_id = "step-1" + plugin.on_user_function_start(_user_function_start_info(operation_id)) + + inner = Mock() + logger = OtelEnrichedLogger(inner=inner, plugin=plugin) + logger.info("inside step") + + _, kwargs = inner.info.call_args + operation_span = plugin._get_span(operation_id) + expected_span_id = format(operation_span.get_span_context().span_id, "016x") + assert kwargs["extra"]["otel.span_id"] == expected_span_id + + +def test_enriched_logger_preserves_existing_extra(): + """Verify caller-provided extra fields are preserved alongside otel fields.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + inner = Mock() + logger = OtelEnrichedLogger(inner=inner, plugin=plugin) + + logger.info("hello", extra={"order_id": "123"}) + + _, kwargs = inner.info.call_args + assert kwargs["extra"]["order_id"] == "123" + assert "otel.trace_id" in kwargs["extra"] + + +def test_enriched_logger_handles_none_extra(): + """Verify None extra is handled without error.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + inner = Mock() + logger = OtelEnrichedLogger(inner=inner, plugin=plugin) + + logger.info("hello", extra=None) + + _, kwargs = inner.info.call_args + assert isinstance(kwargs["extra"], dict) + + +def test_enriched_logger_passes_positional_args(): + """Verify positional format args are forwarded to the inner logger.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + inner = Mock() + logger = OtelEnrichedLogger(inner=inner, plugin=plugin) + + logger.info("hello %s %s", "a", "b") + + args, _ = inner.info.call_args + assert args == ("hello %s %s", "a", "b") diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py index 5fb8a43..6644f20 100644 --- a/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py +++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py @@ -5,11 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from datetime import UTC, datetime -from opentelemetry.context import Context -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter - +import pytest from aws_durable_execution_sdk_python.lambda_service import ( InvocationStatus, OperationStatus, @@ -24,6 +20,13 @@ UserFunctionOutcome, UserFunctionStartInfo, ) +from opentelemetry import context as otel_context +from opentelemetry import trace +from opentelemetry.context import Context +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + from aws_durable_execution_sdk_python_otel.deterministic_id_generator import ( operation_id_to_span_id, ) @@ -35,6 +38,20 @@ EXECUTION_ARN = "arn:aws:lambda:us-west-2:123456789012:function:workflow:$LATEST" +@pytest.fixture(autouse=True) +def _reset_otel_context(): + """Reset the OTel thread-local context before and after each test. + + The plugin attaches spans via context.attach() without ever detaching, + so state would otherwise leak between tests running on the same thread. + """ + token = otel_context.attach(Context()) + try: + yield + finally: + otel_context.detach(token) + + def _create_plugin() -> tuple[DurableExecutionOtelPlugin, InMemorySpanExporter]: """Create a plugin wired to an in-memory span exporter.""" exporter = InMemorySpanExporter() @@ -70,6 +87,48 @@ def _invocation_end_info() -> InvocationEndInfo: ) +def _user_function_start_info( + operation_id: str, + attempt: int = 1, + parent_id: str | None = None, + operation_type: OperationType = OperationType.STEP, +) -> UserFunctionStartInfo: + """Create standard user function start info for tests.""" + return UserFunctionStartInfo( + operation_id=operation_id, + operation_type=operation_type, + sub_type=None, + name=f"step-{operation_id}", + parent_id=parent_id, + start_time=START_TIME, + is_replay_children=False, + attempt=attempt, + ) + + +def _user_function_end_info( + operation_id: str, + outcome: UserFunctionOutcome = UserFunctionOutcome.SUCCEEDED, + attempt: int = 1, + parent_id: str | None = None, + operation_type: OperationType = OperationType.STEP, +) -> UserFunctionEndInfo: + """Create standard user function end info for tests.""" + return UserFunctionEndInfo( + operation_id=operation_id, + operation_type=operation_type, + sub_type=None, + name=f"step-{operation_id}", + parent_id=parent_id, + start_time=START_TIME, + is_replay_children=False, + attempt=attempt, + outcome=outcome, + end_time=END_TIME, + error=None, + ) + + def test_invocation_start_and_end_emit_invocation_span(): """Verify invocation lifecycle callbacks create and finish the root span.""" plugin, exporter = _create_plugin() @@ -223,3 +282,210 @@ def update_span(index: int) -> None: with plugin._operation_spans_lock: assert plugin._operation_spans == {} + + +# ---------------------------------------------------------------------- +# on_user_function_end restores the invocation span to the context +# ---------------------------------------------------------------------- +def test_user_function_end_restores_invocation_span(): + """Verify the invocation span is current again after a step completes.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + invocation_span_id = plugin._get_span(None).get_span_context().span_id + + operation_id = "step-1" + plugin.on_user_function_start(_user_function_start_info(operation_id)) + # Inside the step, the current span is the operation span. + assert ( + trace.get_current_span().get_span_context().span_id + == operation_id_to_span_id(operation_id) + ) + + plugin.on_user_function_end(_user_function_end_info(operation_id)) + + # After the step, the invocation span is restored. + assert trace.get_current_span().get_span_context().span_id == invocation_span_id + + +def test_user_function_end_restores_invocation_span_on_failure(): + """Verify the invocation span is restored even when the step fails.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + invocation_span_id = plugin._get_span(None).get_span_context().span_id + + operation_id = "step-fail" + plugin.on_user_function_start(_user_function_start_info(operation_id)) + plugin.on_user_function_end( + _user_function_end_info(operation_id, outcome=UserFunctionOutcome.FAILED) + ) + + assert trace.get_current_span().get_span_context().span_id == invocation_span_id + + +def test_user_function_end_restores_invocation_span_across_multiple_steps(): + """Verify between-step context is the invocation span across many steps.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + invocation_span_id = plugin._get_span(None).get_span_context().span_id + + for index in range(3): + operation_id = f"step-{index}" + plugin.on_user_function_start(_user_function_start_info(operation_id)) + plugin.on_user_function_end(_user_function_end_info(operation_id)) + # Between each step, the invocation span is the current span. + assert trace.get_current_span().get_span_context().span_id == invocation_span_id + + +# ---------------------------------------------------------------------- +# get_current_span_context resolves the right span context +# ---------------------------------------------------------------------- +def test_get_current_span_context_returns_none_before_invocation_start(): + """Verify no span context is returned when nothing is active.""" + plugin, _ = _create_plugin() + + assert plugin.get_current_span_context() is None + + +def test_get_current_span_context_returns_invocation_span_at_top_level(): + """Verify top-level code resolves to the invocation span context.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + + span_context = plugin.get_current_span_context() + invocation_span = plugin._get_span(None) + assert span_context is not None + assert span_context.span_id == invocation_span.get_span_context().span_id + + +def test_get_current_span_context_returns_operation_span_inside_step(): + """Verify code inside a step resolves to the operation span context.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + operation_id = "step-1" + plugin.on_user_function_start(_user_function_start_info(operation_id)) + + span_context = plugin.get_current_span_context() + assert span_context is not None + assert span_context.span_id == operation_id_to_span_id(operation_id) + + +def test_get_current_span_context_returns_invocation_span_between_steps(): + """Verify between-step code resolves back to the invocation span context.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + operation_id = "step-1" + plugin.on_user_function_start(_user_function_start_info(operation_id)) + plugin.on_user_function_end(_user_function_end_info(operation_id)) + + span_context = plugin.get_current_span_context() + invocation_span = plugin._get_span(None) + assert span_context is not None + assert span_context.span_id == invocation_span.get_span_context().span_id + + +# ---------------------------------------------------------------------- +# on_user_function_end restores the ENCLOSING operation span (nested case) +# ---------------------------------------------------------------------- +def test_user_function_end_restores_parent_context_span_for_nested_step(): + """Verify ending a nested step restores its enclosing child-context span. + + Inside a child context, code that runs after an inner step (e.g. between + inner steps) must correlate to the child context span, not the invocation. + """ + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + + # Enter a child context (CONTEXT operation at the top level). + context_id = "ctx-1" + plugin.on_user_function_start( + _user_function_start_info(context_id, operation_type=OperationType.CONTEXT) + ) + context_span_id = trace.get_current_span().get_span_context().span_id + + # Run an inner step whose parent is the child context. + inner_step_id = "ctx-1-step" + plugin.on_user_function_start( + _user_function_start_info(inner_step_id, parent_id=context_id) + ) + assert ( + trace.get_current_span().get_span_context().span_id + == operation_id_to_span_id(inner_step_id) + ) + + plugin.on_user_function_end( + _user_function_end_info(inner_step_id, parent_id=context_id) + ) + + # After the inner step, the enclosing child-context span is current again, + # NOT the invocation span. + assert trace.get_current_span().get_span_context().span_id == context_span_id + assert ( + trace.get_current_span().get_span_context().span_id + != plugin._get_span(None).get_span_context().span_id + ) + + +def test_user_function_end_falls_back_to_invocation_when_parent_missing(): + """Verify a top-level step (parent_id=None) restores the invocation span.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + invocation_span_id = plugin._get_span(None).get_span_context().span_id + + operation_id = "step-1" + plugin.on_user_function_start(_user_function_start_info(operation_id)) + plugin.on_user_function_end(_user_function_end_info(operation_id)) + + assert trace.get_current_span().get_span_context().span_id == invocation_span_id + + +def test_get_current_span_context_returns_context_span_between_nested_steps(): + """Verify between-step code inside a child context resolves to that context. + + This is the log-correlation path: after an inner step completes, + get_current_span_context must return the enclosing child-context span so + logs emitted between inner steps correlate to the child context. + """ + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + + context_id = "ctx-1" + plugin.on_user_function_start( + _user_function_start_info(context_id, operation_type=OperationType.CONTEXT) + ) + context_span = plugin._get_span(context_id) + + inner_step_id = "ctx-1-step" + plugin.on_user_function_start( + _user_function_start_info(inner_step_id, parent_id=context_id) + ) + plugin.on_user_function_end( + _user_function_end_info(inner_step_id, parent_id=context_id) + ) + + span_context = plugin.get_current_span_context() + assert span_context is not None + assert span_context.span_id == context_span.get_span_context().span_id + assert span_context.span_id != plugin._get_span(None).get_span_context().span_id + + +def test_nested_steps_restore_context_span_across_multiple_iterations(): + """Verify each inner step restores the child-context span between iterations.""" + plugin, _ = _create_plugin() + plugin.on_invocation_start(_invocation_start_info()) + + context_id = "ctx-1" + plugin.on_user_function_start( + _user_function_start_info(context_id, operation_type=OperationType.CONTEXT) + ) + context_span_id = trace.get_current_span().get_span_context().span_id + + for index in range(3): + inner_step_id = f"ctx-1-step-{index}" + plugin.on_user_function_start( + _user_function_start_info(inner_step_id, parent_id=context_id) + ) + plugin.on_user_function_end( + _user_function_end_info(inner_step_id, parent_id=context_id) + ) + # Between each inner step, the child-context span is current. + assert trace.get_current_span().get_span_context().span_id == context_span_id diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py index 00e575d..2f9bc21 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py @@ -395,9 +395,14 @@ def _resolve_step_name(name: str | None, func: Callable) -> str | None: return name or getattr(func, "_original_name", None) def set_logger(self, new_logger: LoggerInterface): - """Set the logger for the current context.""" + """Set the logger for the current context. + + If plugins are registered, the logger will be wrapped by plugin logger + enrichment (e.g., OTel trace context injection) before being applied. + """ + wrapped = self.state._plugin_executor.wrap_logger(new_logger) self.logger = Logger.from_log_info( - logger=new_logger, + logger=wrapped, info=self._log_info, ) diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py index afb710e..56a46d6 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py @@ -9,7 +9,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any - from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.exceptions import ( BackgroundThreadError, @@ -21,14 +20,14 @@ SuspendExecution, ) from aws_durable_execution_sdk_python.lambda_service import ( + DurableExecutionInvocationOutput, DurableServiceClient, ErrorObject, + InvocationStatus, LambdaClient, Operation, OperationType, OperationUpdate, - InvocationStatus, - DurableExecutionInvocationOutput, ) from aws_durable_execution_sdk_python.plugin import ( DurableInstrumentationPlugin, @@ -36,6 +35,7 @@ ) from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus + if TYPE_CHECKING: from collections.abc import Callable, MutableMapping @@ -273,6 +273,11 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: state=execution_state, lambda_context=context ) + # Trigger plugin logger wrapping on the root context's default logger + # (e.g., OTel trace context injection). Child contexts inherit the + # already-wrapped logger and do not re-wrap. + durable_context.set_logger(durable_context.logger.get_logger()) + # Use ThreadPoolExecutor for concurrent execution of user code and background checkpoint processing with ( ThreadPoolExecutor( diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py index 0deff94..91faec9 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py @@ -10,17 +10,18 @@ from aws_durable_execution_sdk_python.exceptions import SuspendExecution from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ( - OperationType, - OperationStatus, - OperationAction, - OperationSubType, + DurableExecutionInvocationOutput, ErrorObject, InvocationStatus, Operation, + OperationAction, + OperationStatus, + OperationSubType, + OperationType, OperationUpdate, - DurableExecutionInvocationOutput, ) -from aws_durable_execution_sdk_python.types import LambdaContext +from aws_durable_execution_sdk_python.types import LambdaContext, LoggerInterface + logger = logging.getLogger(__name__) @@ -191,8 +192,20 @@ def on_user_function_end(self, info: UserFunctionEndInfo) -> None: """ pass - # Todo: further discussions required to finalize the following interface - # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass + def wrap_logger(self, logger: LoggerInterface) -> LoggerInterface | None: + """Optionally wrap the execution logger to enrich log output. + + Called once per invocation after the root DurableContext is created. + Return a wrapped logger to add plugin-specific fields to log output, + or None to leave the logger unchanged. + + Args: + logger: The current logger interface used by the execution context. + + Returns: + A wrapped LoggerInterface, or None to keep the existing logger. + """ + pass class PluginExecutor: @@ -201,6 +214,30 @@ def __init__(self, plugins: list[DurableInstrumentationPlugin] | None): self._executor: ThreadPoolExecutor | None = None self._invocation_status: InvocationStartInfo | None = None + def wrap_logger(self, current_logger: LoggerInterface) -> LoggerInterface: + """Chain all plugin logger wrappers, returning the final wrapped logger. + + Each plugin's wrap_logger is called in order. If a plugin returns a + wrapped logger, it becomes the input for the next plugin. + + Args: + current_logger: The current logger interface from the DurableContext. + + Returns: + The final logger after all plugins have had a chance to wrap it. + """ + for plugin in self._plugins: + try: + wrapped = plugin.wrap_logger(current_logger) + if wrapped is not None: + current_logger = wrapped + except Exception: + logger.exception( + "Plugin %s wrap_logger exception ignored", + plugin.__class__.__name__, + ) + return current_logger + @contextlib.contextmanager def run(self): if self._plugins: diff --git a/packages/aws-durable-execution-sdk-python/tests/plugin_test.py b/packages/aws-durable-execution-sdk-python/tests/plugin_test.py index b26365c..30bcda1 100644 --- a/packages/aws-durable-execution-sdk-python/tests/plugin_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/plugin_test.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock from aws_durable_execution_sdk_python.lambda_service import ( + DurableExecutionInvocationOutput, ErrorObject, InvocationStatus, OperationAction, OperationStatus, OperationSubType, OperationType, - DurableExecutionInvocationOutput, ) from aws_durable_execution_sdk_python.plugin import ( DurableInstrumentationPlugin, @@ -19,9 +19,9 @@ OperationEndInfo, OperationStartInfo, PluginExecutor, + UserFunctionEndInfo, UserFunctionOutcome, UserFunctionStartInfo, - UserFunctionEndInfo, ) @@ -715,6 +715,102 @@ def test_ready_is_not_terminal(self): self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.READY)) +class TestPluginExecutorWrapLogger(unittest.TestCase): + """Tests for PluginExecutor.wrap_logger.""" + + def test_no_plugins_returns_logger_unchanged(self): + """With no plugins, the logger is returned as-is.""" + executor = PluginExecutor(plugins=None) + logger = MagicMock() + + self.assertIs(executor.wrap_logger(logger), logger) + + def test_empty_plugins_returns_logger_unchanged(self): + """With an empty plugin list, the logger is returned as-is.""" + executor = PluginExecutor(plugins=[]) + logger = MagicMock() + + self.assertIs(executor.wrap_logger(logger), logger) + + def test_plugin_returning_none_leaves_logger_unchanged(self): + """A plugin returning None does not replace the logger.""" + executor = PluginExecutor(plugins=[_NoOpPlugin()]) + logger = MagicMock() + + self.assertIs(executor.wrap_logger(logger), logger) + + def test_plugin_wrapping_replaces_logger(self): + """A plugin returning a wrapped logger replaces the original.""" + wrapped = MagicMock() + executor = PluginExecutor(plugins=[_WrappingPlugin(wrapped)]) + logger = MagicMock() + + self.assertIs(executor.wrap_logger(logger), wrapped) + + def test_multiple_plugins_chain_wrappers(self): + """Each plugin wraps the output of the previous plugin in order.""" + first_wrap = MagicMock(name="first") + second_wrap = MagicMock(name="second") + plugin1 = _WrappingPlugin(first_wrap) + plugin2 = _WrappingPlugin(second_wrap) + executor = PluginExecutor(plugins=[plugin1, plugin2]) + logger = MagicMock(name="original") + + result = executor.wrap_logger(logger) + + # plugin1 receives the original logger + self.assertIs(plugin1.received, logger) + # plugin2 receives plugin1's wrapped logger + self.assertIs(plugin2.received, first_wrap) + # final result is plugin2's wrapper + self.assertIs(result, second_wrap) + + def test_plugin_returning_none_passes_original_to_next(self): + """A plugin returning None passes the unchanged logger to the next plugin.""" + wrapped = MagicMock(name="wrapped") + noop = _NoOpPlugin() + wrapping = _WrappingPlugin(wrapped) + executor = PluginExecutor(plugins=[noop, wrapping]) + logger = MagicMock(name="original") + + result = executor.wrap_logger(logger) + + # The wrapping plugin still receives the original logger + self.assertIs(wrapping.received, logger) + self.assertIs(result, wrapped) + + def test_plugin_exception_is_swallowed_and_chain_continues(self): + """If a plugin's wrap_logger raises, it is logged and the chain continues.""" + wrapped = MagicMock(name="wrapped") + failing = _WrapLoggerFailingPlugin() + wrapping = _WrappingPlugin(wrapped) + executor = PluginExecutor(plugins=[failing, wrapping]) + logger = MagicMock(name="original") + + with self.assertLogs( + "aws_durable_execution_sdk_python.plugin", level=logging.ERROR + ): + result = executor.wrap_logger(logger) + + # The failing plugin did not break the chain; wrapping plugin still ran + self.assertIs(wrapping.received, logger) + self.assertIs(result, wrapped) + + def test_all_plugins_failing_returns_original_logger(self): + """If every plugin fails, the original logger is returned unchanged.""" + executor = PluginExecutor( + plugins=[_WrapLoggerFailingPlugin(), _WrapLoggerFailingPlugin()] + ) + logger = MagicMock(name="original") + + with self.assertLogs( + "aws_durable_execution_sdk_python.plugin", level=logging.ERROR + ): + result = executor.wrap_logger(logger) + + self.assertIs(result, logger) + + # endregion PluginExecutor Tests @@ -752,6 +848,25 @@ def on_user_function_end(self, info: UserFunctionEndInfo) -> None: self.calls.append(f"user_function_end:{info.operation_id}") +class _WrappingPlugin(DurableInstrumentationPlugin): + """Plugin that wraps the logger with a fixed replacement and records input.""" + + def __init__(self, replacement) -> None: + self._replacement = replacement + self.received = None + + def wrap_logger(self, logger): + self.received = logger + return self._replacement + + +class _WrapLoggerFailingPlugin(DurableInstrumentationPlugin): + """Plugin whose wrap_logger always raises.""" + + def wrap_logger(self, logger): + raise RuntimeError("boom") + + class _FailingPlugin(DurableInstrumentationPlugin): """Plugin that raises on every hook call."""