From 8607b662e6ca94d8f1a2c5d654e8ef7162d04e35 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Tue, 9 Jun 2026 20:16:35 -0700 Subject: [PATCH] feat: async user function --- .github/scripts/__init__.py | 1 - .github/scripts/ci-checks.sh | 62 ---- .github/scripts/lintcommit.py | 267 ------------------ .github/scripts/parse_sdk_branch.py | 31 -- .github/scripts/tests/__init__.py | 1 - .github/scripts/tests/test_lintcommit.py | 259 ----------------- .../scripts/tests/test_parse_sdk_branch.py | 84 ------ .github/workflows/ci.yml | 14 - .github/workflows/ecr-release.yml | 130 --------- .github/workflows/notify-issues.yml | 23 -- .github/workflows/notify-pr.yml | 24 -- .github/workflows/notify-release.yml | 29 -- .github/workflows/pypi-publish.yml | 9 +- .github/workflows/test-parser.yml | 24 -- README.md | 25 ++ .../run_in_child_context.py | 6 +- .../src/step/step.py | 6 +- .../wait_for_callback/wait_for_callback.py | 8 +- .../wait_for_condition/wait_for_condition.py | 6 +- .../README.md | 25 ++ .../async_tools.py | 49 ++++ .../context.py | 26 +- .../execution.py | 13 +- .../operation/callback.py | 2 +- .../aws_durable_execution_sdk_python/state.py | 3 +- .../aws_durable_execution_sdk_python/types.py | 15 +- .../tests/async_tools_test.py | 22 ++ .../tests/execution_test.py | 51 ++++ 28 files changed, 227 insertions(+), 988 deletions(-) delete mode 100644 .github/scripts/__init__.py delete mode 100755 .github/scripts/ci-checks.sh delete mode 100644 .github/scripts/lintcommit.py delete mode 100644 .github/scripts/parse_sdk_branch.py delete mode 100644 .github/scripts/tests/__init__.py delete mode 100644 .github/scripts/tests/test_lintcommit.py delete mode 100644 .github/scripts/tests/test_parse_sdk_branch.py delete mode 100644 .github/workflows/ecr-release.yml delete mode 100644 .github/workflows/notify-issues.yml delete mode 100644 .github/workflows/notify-pr.yml delete mode 100644 .github/workflows/notify-release.yml delete mode 100644 .github/workflows/test-parser.yml create mode 100644 packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/async_tools.py create mode 100644 packages/aws-durable-execution-sdk-python/tests/async_tools_test.py diff --git a/.github/scripts/__init__.py b/.github/scripts/__init__.py deleted file mode 100644 index 8b137891..00000000 --- a/.github/scripts/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/.github/scripts/ci-checks.sh b/.github/scripts/ci-checks.sh deleted file mode 100755 index 29ed4476..00000000 --- a/.github/scripts/ci-checks.sh +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env bash - -set -e - -REPO_ROOT="$(cd "$(dirname "$0")/../.." && pwd)" -cd "$REPO_ROOT" - -# --- Core SDK checks --- -echo "==========================================" -echo "Running checks for aws-durable-execution-sdk-python" -echo "==========================================" - -hatch run dev-core:cov -echo "SUCCESS: tests + coverage (core)" - -hatch run dev-core:typecheck -echo "SUCCESS: typings (core)" - -# --- OTel SDK checks --- -echo "==========================================" -echo "Running checks for aws-durable-execution-sdk-python-otel" -echo "==========================================" - -hatch run dev-otel:cov -echo "SUCCESS: tests + coverage (otel)" - -hatch run dev-otel:typecheck -echo "SUCCESS: typings (otel)" - -# --- Examples checks --- -echo "==========================================" -echo "Running checks for examples" -echo "==========================================" - -hatch run dev-examples:test -echo "SUCCESS: tests (examples)" - -# --- Formatting / linting (per package) --- -PACKAGES=( - "packages/aws-durable-execution-sdk-python" - "packages/aws-durable-execution-sdk-python-otel" - "packages/aws-durable-execution-sdk-python-examples" -) - -for package_dir in "${PACKAGES[@]}"; do - full_path="$REPO_ROOT/$package_dir" - if [ -d "$full_path" ]; then - echo "==========================================" - echo "Running formatting/linting for $package_dir" - echo "==========================================" - cd "$full_path" - hatch fmt - echo "SUCCESS: linting/fmt ($package_dir)" - else - echo "WARNING: $package_dir does not exist, skipping fmt" - fi -done - -cd "$REPO_ROOT" - -# --- Commit message validation --- -hatch run python .github/scripts/lintcommit.py diff --git a/.github/scripts/lintcommit.py b/.github/scripts/lintcommit.py deleted file mode 100644 index 255ea0ec..00000000 --- a/.github/scripts/lintcommit.py +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env python3 -# Checks that commit messages conform to conventional commits -# (https://www.conventionalcommits.org/). -# -# To run tests: -# -# python -m pytest .github/scripts/tests/test_lintcommit.py - -from __future__ import annotations - -import argparse -import re -import subprocess -import sys -from dataclasses import dataclass, field - -TYPES: set[str] = { - "build", - "chore", - "ci", - "deps", - "docs", - "feat", - "fix", - "perf", - "refactor", - "style", - "test", -} - -MAX_SUBJECT_LENGTH: int = 50 -MAX_SCOPE_LENGTH: int = 30 -MAX_BODY_LINE_LENGTH: int = 72 - - -def validate_subject(subject_line: str) -> str | None: - """Validate a commit message subject line. - - Returns None if valid, else an error message string. - """ - parts: list[str] = subject_line.split(":", maxsplit=1) - - if len(parts) < 2: - return "missing colon (:) char" - - type_scope: str = parts[0] - subject: str = parts[1].strip() - - # Parse type and optional scope: type or type(scope) - scope: str | None = None - commit_type: str = type_scope - - if "(" in type_scope: - paren_start: int = type_scope.index("(") - commit_type = type_scope[:paren_start] - - if not type_scope.endswith(")"): - return "must be formatted like type(scope):" - - scope = type_scope[paren_start + 1 : -1] - - if " " in commit_type: - return f'type contains whitespace: "{commit_type}"' - - if commit_type not in TYPES: - return f'invalid type "{commit_type}"' - - if scope is not None: - if len(scope) > MAX_SCOPE_LENGTH: - return f"invalid scope (must be <={MAX_SCOPE_LENGTH} chars)" - - if re.search(r"[^- a-z0-9]", scope): - return f'invalid scope (must be lowercase, ascii only): "{scope}"' - - if len(subject) == 0: - return "empty subject" - - if len(subject) > MAX_SUBJECT_LENGTH: - return f"invalid subject (must be <={MAX_SUBJECT_LENGTH} chars)" - - if subject.endswith("."): - return "subject must not end with a period" - - return None - - -def validate_body(body: str) -> list[str]: - """Validate the body of a commit message. - - Returns a list of warnings (not hard errors) for body issues. - """ - warnings: list[str] = [] - for i, line in enumerate(body.splitlines(), start=1): - if len(line) > MAX_BODY_LINE_LENGTH: - warnings.append( - f"body line {i} exceeds {MAX_BODY_LINE_LENGTH} chars ({len(line)} chars)" - ) - return warnings - - -def validate_message(message: str) -> tuple[str | None, list[str]]: - """Validate a full commit message (subject + optional body). - - Returns (error, warnings) where error is None if the subject is valid. - """ - lines: list[str] = message.strip().splitlines() - if not lines: - return ("empty commit message", []) - - subject_line: str = lines[0] - error: str | None = validate_subject(subject_line) - - warnings: list[str] = [] - # Check for blank line between subject and body - body_start: int = 2 - if len(lines) > 1 and lines[1].strip() != "": - warnings.append("missing blank line between subject and body") - body_start = 1 - - if len(lines) > body_start: - body: str = "\n".join(lines[body_start:]) - warnings.extend(validate_body(body)) - - return (error, warnings) - - -@dataclass -class CommitResult: - """Result of validating a single commit.""" - - sha: str - subject: str - error: str | None = None - warnings: list[str] = field(default_factory=list) - - -@dataclass -class LintResult: - """Result of linting a range of commits.""" - - commits: list[CommitResult] = field(default_factory=list) - skipped: bool = False - skip_reason: str = "" - empty: bool = False - git_error: str = "" - - @property - def has_errors(self) -> bool: - return bool(self.git_error) or any(c.error for c in self.commits) - - -def lint_range(git_range: str, *, skip_dirty_check: bool = False) -> LintResult: - """Validate commit messages in a git range (e.g. 'origin/main..HEAD'). - - Args: - git_range: A git revision range like 'origin/main..HEAD'. - skip_dirty_check: When True, skip the uncommitted changes check - (useful in CI where the worktree may be clean by definition). - - Returns: - A LintResult with per-commit validation results. - """ - if not skip_dirty_check: - status = subprocess.run( - ["git", "status", "--porcelain"], - capture_output=True, - text=True, - check=False, - ) - if status.stdout.strip(): - return LintResult( - skipped=True, - skip_reason=( - "uncommitted changes detected, skipping commit message validation.\n" - "Commit your changes and re-run to validate." - ), - ) - - result = subprocess.run( - ["git", "log", "--no-merges", git_range, "-z", "--format=%H%n%B"], - capture_output=True, - text=True, - check=False, - ) - if result.returncode != 0: - return LintResult(git_error=result.stderr.strip()) - - if not result.stdout.strip(): - return LintResult(empty=True) - - commits: list[CommitResult] = [] - for record in result.stdout.split("\0"): - if not record.strip(): - continue - sha, _, message = record.partition("\n") - message = message.strip() - if not message: - continue - - error, warnings = validate_message(message) - subject = message.splitlines()[0] - commits.append( - CommitResult( - sha=sha[:7], - subject=subject, - error=error, - warnings=warnings, - ) - ) - - return LintResult(commits=commits) - - -def write_output(lint_result: LintResult, git_range: str) -> None: - """Write lint results to stdout/stderr.""" - if lint_result.skipped: - print(f"WARNING: {lint_result.skip_reason}") - return - - if lint_result.git_error: - print(f"git log failed: {lint_result.git_error}", file=sys.stderr) - return - - if lint_result.empty: - print(f"No commits in range {git_range}") - return - - for commit in lint_result.commits: - if commit.error: - print(f"FAIL {commit.sha}: {commit.subject}", file=sys.stderr) - print(f" Error: {commit.error}", file=sys.stderr) - else: - print(f"PASS {commit.sha}: {commit.subject}") - - for warning in commit.warnings: - print(f" Warning: {warning}") - - -def run_range(git_range: str, *, skip_dirty_check: bool = False) -> None: - """Validate commit messages in a git range and exit on errors.""" - lint_result = lint_range(git_range, skip_dirty_check=skip_dirty_check) - write_output(lint_result, git_range) - if lint_result.has_errors: - sys.exit(1) - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Lint commit messages for conventional commits compliance." - ) - parser.add_argument( - "--range", - default=None, - dest="git_range", - help="Validate all commits in a git revision range (e.g. 'origin/main..HEAD'). " - "Skips the uncommitted-changes check (useful in CI).", - ) - args = parser.parse_args() - - if args.git_range is not None: - run_range(args.git_range, skip_dirty_check=True) - else: - run_range("origin/main..HEAD") - - -if __name__ == "__main__": - main() diff --git a/.github/scripts/parse_sdk_branch.py b/.github/scripts/parse_sdk_branch.py deleted file mode 100644 index 1967085e..00000000 --- a/.github/scripts/parse_sdk_branch.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python3 - -import os -import re - - -def parse_sdk_branch(pr_body: str, default_ref: str = "main") -> str: - """Parse PR body for TESTING_SDK_BRANCH and return the branch reference.""" - pattern = re.compile(r"(?i)TESTING_SDK_BRANCH\s*[:=]\s*(\S+)", re.MULTILINE) - - match = pattern.search(pr_body) - if match: - ref = match.group(1).strip() - if ref: - return ref - - return default_ref - - -def main(): - pr_body = os.environ.get("PR_BODY", "") - ref = parse_sdk_branch(pr_body) - - github_output = os.environ.get("GITHUB_OUTPUT") - if github_output: - with open(github_output, "a", encoding="utf-8") as f: - f.write(f"testing_ref={ref}\n") - - -if __name__ == "__main__": - main() diff --git a/.github/scripts/tests/__init__.py b/.github/scripts/tests/__init__.py deleted file mode 100644 index 8b137891..00000000 --- a/.github/scripts/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/.github/scripts/tests/test_lintcommit.py b/.github/scripts/tests/test_lintcommit.py deleted file mode 100644 index 42c42334..00000000 --- a/.github/scripts/tests/test_lintcommit.py +++ /dev/null @@ -1,259 +0,0 @@ -#!/usr/bin/env python3 - -from __future__ import annotations - -import os -import sys -from subprocess import CompletedProcess -from unittest.mock import patch - -sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) - -import pytest -from lintcommit import lint_range, validate_message, validate_subject - -# region validate_subject: valid subjects - - -def test_valid_feat() -> None: - assert validate_subject("feat: add new feature") is None - - -def test_valid_fix() -> None: - assert validate_subject("fix: resolve issue") is None - - -def test_valid_fix_with_scope() -> None: - assert validate_subject("fix(sdk): resolve issue") is None - - -def test_valid_build() -> None: - assert validate_subject("build: update build process") is None - - -def test_valid_chore() -> None: - assert validate_subject("chore: update dependencies") is None - - -def test_valid_ci() -> None: - assert validate_subject("ci: configure ci/cd") is None - - -def test_valid_deps() -> None: - assert validate_subject("deps: bump aws-sdk group with 5 updates") is None - - -def test_valid_docs() -> None: - assert validate_subject("docs: update documentation") is None - - -def test_valid_feat_with_scope() -> None: - assert validate_subject("feat(sdk): add new feature") is None - - -def test_valid_feat_scope_bar() -> None: - assert validate_subject("feat(sdk): bar") is None - - -def test_valid_feat_foo() -> None: - assert validate_subject("feat: foo") is None - - -def test_valid_fix_foo() -> None: - assert validate_subject("fix: foo") is None - - -# region validate_subject: invalid subjects - - -def test_invalid_type() -> None: - assert validate_subject("config: foo") == 'invalid type "config"' - - -def test_missing_colon() -> None: - assert validate_subject("invalid title") == "missing colon (:) char" - - -def test_period_at_end() -> None: - assert validate_subject("feat: add thing.") == "subject must not end with a period" - - -def test_empty_subject() -> None: - assert validate_subject("feat: ") == "empty subject" - - -def test_subject_too_long() -> None: - long_subject: str = "feat: " + "a" * 51 - result = validate_subject(long_subject) - assert result is not None - assert "invalid subject" in result - - -def test_type_with_whitespace() -> None: - assert validate_subject("fe at: foo") == 'type contains whitespace: "fe at"' - - -def test_scope_not_closed() -> None: - assert validate_subject("feat(sdk: foo") == "must be formatted like type(scope):" - - -def test_scope_too_long() -> None: - long_scope: str = "a" * 31 - result = validate_subject(f"feat({long_scope}): foo") - assert result is not None - assert "invalid scope" in result - - -def test_scope_uppercase() -> None: - result = validate_subject("feat(SDK): foo") - assert result is not None - assert "invalid scope" in result - - -def test_subject_uppercase() -> None: - assert validate_subject("feat: Add new feature") == "subject must be lowercase" - - -# region validate_message - - -def test_valid_subject_only() -> None: - error, warnings = validate_message("feat: add thing") - assert error is None - assert warnings == [] - - -def test_valid_with_body() -> None: - error, warnings = validate_message("feat: add thing\n\nThis is the body.") - assert error is None - assert warnings == [] - - -def test_missing_blank_line() -> None: - _, warnings = validate_message("feat: add thing\nNo blank line.") - assert "missing blank line between subject and body" in warnings - - -def test_missing_blank_line_body_still_checked() -> None: - _, warnings = validate_message("feat: add thing\n" + "x" * 80) - assert "missing blank line between subject and body" in warnings - assert any("exceeds 72 chars" in w for w in warnings), ( - "body line length should be checked even without blank line" - ) - - -def test_long_body_line() -> None: - _, warnings = validate_message("feat: add thing\n\n" + "x" * 80) - assert len(warnings) == 1 - assert "exceeds 72 chars" in warnings[0] - - -def test_empty_message() -> None: - error, _ = validate_message("") - assert error == "empty commit message" - - -def test_invalid_subject_in_message() -> None: - error, _ = validate_message("invalid title") - assert error == "missing colon (:) char" - - -# region lint_range - - -def _make_git_log_output(*messages: str) -> str: - """Build fake ``git log --no-merges -z --format=%H%n%B`` output. - - Records are separated by null characters. - """ - records: list[str] = [] - for i, msg in enumerate(messages): - sha = f"abc{i:04d}" + "0" * 33 # 40-char fake SHA - records.append(f"{sha}\n{msg}\n") - return "\0".join(records) - - -def _completed( - stdout: str = "", stderr: str = "", returncode: int = 0 -) -> CompletedProcess[str]: - """Shorthand for a ``subprocess.CompletedProcess``.""" - return CompletedProcess( - args=[], returncode=returncode, stdout=stdout, stderr=stderr - ) - - -@patch("subprocess.run") -def test_lint_range_all_valid(mock_run) -> None: - log_output = _make_git_log_output( - "feat: add new feature", - "fix(sdk): resolve issue", - ) - mock_run.return_value = _completed(stdout=log_output) - - result = lint_range("origin/main..HEAD", skip_dirty_check=True) - - assert not result.has_errors - assert len(result.commits) == 2 - assert all(c.error is None for c in result.commits) - - -@patch("subprocess.run") -def test_lint_range_with_invalid_commit(mock_run) -> None: - log_output = _make_git_log_output( - "feat: add new feature", - "bad commit no colon", - ) - mock_run.return_value = _completed(stdout=log_output) - - result = lint_range("origin/main..HEAD", skip_dirty_check=True) - - assert result.has_errors - assert result.commits[0].error is None - assert result.commits[1].error == "missing colon (:) char" - - -@patch("subprocess.run") -def test_lint_range_empty(mock_run) -> None: - mock_run.return_value = _completed(stdout="") - - result = lint_range("origin/main..HEAD", skip_dirty_check=True) - - assert result.empty - assert not result.has_errors - - -@patch("subprocess.run") -def test_lint_range_git_failure(mock_run) -> None: - mock_run.return_value = _completed(returncode=1, stderr="fatal: bad range") - - result = lint_range("bad..range", skip_dirty_check=True) - - assert result.has_errors - assert result.git_error == "fatal: bad range" - - -@patch("subprocess.run") -def test_lint_range_dirty_worktree_skips(mock_run) -> None: - """When skip_dirty_check=False and worktree is dirty, validation is skipped.""" - mock_run.return_value = _completed(stdout=" M .github/scripts/lintcommit.py\n") - - result = lint_range("origin/main..HEAD", skip_dirty_check=False) - - assert result.skipped - assert "uncommitted changes" in result.skip_reason - # git log should never have been called (only git status) - mock_run.assert_called_once() - - -@patch("subprocess.run") -def test_lint_range_warnings_collected(mock_run) -> None: - log_output = _make_git_log_output( - "feat: add thing\n\n" + "x" * 80, - ) - mock_run.return_value = _completed(stdout=log_output) - - result = lint_range("origin/main..HEAD", skip_dirty_check=True) - - assert not result.has_errors - assert len(result.commits) == 1 - assert any("exceeds 72 chars" in w for w in result.commits[0].warnings) diff --git a/.github/scripts/tests/test_parse_sdk_branch.py b/.github/scripts/tests/test_parse_sdk_branch.py deleted file mode 100644 index d4586512..00000000 --- a/.github/scripts/tests/test_parse_sdk_branch.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python3 - -import os -import sys - -sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) - -from parse_sdk_branch import parse_sdk_branch - - -def test_parse_sdk_branch(): - test_cases = [ - # Basic cases - ("TESTING_SDK_BRANCH = feature/test", "feature/test"), - ("TESTING_SDK_BRANCH: feature/test", "feature/test"), - ("TESTING_SDK_BRANCH=feature/test", "feature/test"), - ("testing_sdk_branch: feature/test", "feature/test"), - # Complex PR body with backticks and contractions - ( - """Updated the script to safely parse the testing SDK branch from the PR body, handling case insensitivity and whitespace. - -The goal here is to fix the usage of backticks such as in `foo`, and contractions that we've been using such as `we've` - -``` -plus of course the usage of multiple backticks to include code -``` - -TESTING_SDK_BRANCH = main - -By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.""", - "main", - ), - # Edge cases with markdown and special characters - ( - """# PR Title - -Some `code` and we've got contractions here. - -```python -def test(): - return "test" -``` - -TESTING_SDK_BRANCH: feature/fix-backticks - -More text with `inline code` and don't forget contractions.""", - "feature/fix-backticks", - ), - # Multiple occurrences (should take first) - ( - """TESTING_SDK_BRANCH = first-branch - -Some text here. - -TESTING_SDK_BRANCH = second-branch""", - "first-branch", - ), - # Whitespace variations - (" TESTING_SDK_BRANCH = feature/spaces ", "feature/spaces"), - ("TESTING_SDK_BRANCH:feature/no-space", "feature/no-space"), - # Default cases - ("No branch specified", "main"), - ("", "main"), - ("Just some random text", "main"), - # Case with backticks in branch name - ("TESTING_SDK_BRANCH = feature/fix-`backticks`", "feature/fix-`backticks`"), - # Case with contractions in surrounding text - ( - "We've updated this and TESTING_SDK_BRANCH = feature/test and we're done", - "feature/test", - ), - ] - - for input_text, expected in test_cases: - result = parse_sdk_branch(input_text) - # Assert is expected in test functions - assert result == expected, ( # noqa: S101 - f"Expected '{expected}' but got '{result}' for input: {input_text[:50]}..." - ) - - -if __name__ == "__main__": - test_parse_sdk_branch() - sys.exit(0) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2e4f2992..a7d2973a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,20 +13,6 @@ on: branches: [ main ] jobs: - lint-commits: - if: github.event_name == 'pull_request' && github.actor != 'dependabot[bot]' - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - fetch-depth: 0 - - name: Set up Python - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: "3.12" - - name: Lint commit messages - run: python .github/scripts/lintcommit.py --range "origin/${{ github.event.pull_request.base.ref }}..${{ github.event.pull_request.head.sha }}" - build: runs-on: ubuntu-latest strategy: diff --git a/.github/workflows/ecr-release.yml b/.github/workflows/ecr-release.yml deleted file mode 100644 index 55ffa028..00000000 --- a/.github/workflows/ecr-release.yml +++ /dev/null @@ -1,130 +0,0 @@ -name: Upload Testing SDK Emulator Image - -on: - release: - types: [published] - -permissions: - contents: read - id-token: write - -env: - package_path: packages/aws-durable-execution-sdk-python-testing - aws_region: us-east-1 - ecr_repository_name: durable-functions/aws-durable-execution-emulator - -jobs: - build-and-upload-image-to-ecr: - runs-on: ubuntu-latest - outputs: - full_image_arm64: ${{ steps.build-publish.outputs.full_image_arm64 }} - full_image_x86_64: ${{ steps.build-publish.outputs.full_image_x86_64 }} - ecr_registry_repository: ${{ steps.build-publish.outputs.ecr_registry_repository }} - version: ${{ steps.version.outputs.VERSION }} - strategy: - matrix: - include: - - arch: x86_64 - platform: linux/amd64 - - arch: arm64 - platform: linux/arm64 - - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - ref: ${{ github.event.release.tag_name }} - - - name: Set up Python - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: "3.13" - - - name: Install Hatch - run: python -m pip install --upgrade hatch==1.16.5 - - - name: Set up QEMU for multi-platform builds - if: matrix.arch == 'arm64' - uses: docker/setup-qemu-action@v3 - with: - platforms: arm64 - - - name: Build distribution - working-directory: ${{ env.package_path }} - run: hatch build - - - name: Get version from __about__.py - id: version - run: | - VERSION=$(grep "^__version__" "${{ env.package_path }}/src/aws_durable_execution_sdk_python_testing/__about__.py" | cut -d'"' -f2) - echo "VERSION=$VERSION" - echo "VERSION=${VERSION}" >> "$GITHUB_OUTPUT" - - - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: ${{ secrets.ECR_UPLOAD_IAM_ROLE_ARN }} - aws-region: ${{ env.aws_region }} - - - name: Login to Amazon ECR - id: login-ecr-public - uses: aws-actions/amazon-ecr-login@v2 - with: - registry-type: public - - - name: Build, tag, and push image to Amazon ECR - id: build-publish - shell: bash - env: - ECR_REGISTRY: ${{ steps.login-ecr-public.outputs.registry }} - ECR_REPOSITORY: ${{ env.ecr_repository_name }} - PER_ARCH_IMAGE_TAG: v${{ steps.version.outputs.VERSION }}-${{ matrix.arch }} - run: | - docker build --platform "${{ matrix.platform }}" --provenance false "${{ env.package_path }}" -f "${{ env.package_path }}/Dockerfile" -t "$ECR_REGISTRY/$ECR_REPOSITORY:$PER_ARCH_IMAGE_TAG" - docker push "$ECR_REGISTRY/$ECR_REPOSITORY:$PER_ARCH_IMAGE_TAG" - echo "ecr_registry_repository=$ECR_REGISTRY/$ECR_REPOSITORY" >> "$GITHUB_OUTPUT" - echo "full_image_${{ matrix.arch }}=$ECR_REGISTRY/$ECR_REPOSITORY:$PER_ARCH_IMAGE_TAG" >> "$GITHUB_OUTPUT" - - create-ecr-manifest-per-arch: - runs-on: ubuntu-latest - needs: [build-and-upload-image-to-ecr] - steps: - - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: ${{ secrets.ECR_UPLOAD_IAM_ROLE_ARN }} - aws-region: ${{ env.aws_region }} - - - name: Login to Amazon ECR - uses: aws-actions/amazon-ecr-login@v2 - with: - registry-type: public - - - name: Create and push explicit version manifest - run: | - docker manifest create "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}" \ - "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}" \ - "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}" - docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}" \ - "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}" \ - --arch arm64 \ - --os linux - docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}" \ - "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}" \ - --arch amd64 \ - --os linux - docker manifest push "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}:v${{ needs.build-and-upload-image-to-ecr.outputs.version }}" - - - name: Create and push latest manifest - run: | - docker manifest create "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}" \ - "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}" \ - "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}" - docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}" \ - "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_arm64 }}" \ - --arch arm64 \ - --os linux - docker manifest annotate "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}" \ - "${{ needs.build-and-upload-image-to-ecr.outputs.full_image_x86_64 }}" \ - --arch amd64 \ - --os linux - docker manifest push "${{ needs.build-and-upload-image-to-ecr.outputs.ecr_registry_repository }}" diff --git a/.github/workflows/notify-issues.yml b/.github/workflows/notify-issues.yml deleted file mode 100644 index c4d89657..00000000 --- a/.github/workflows/notify-issues.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Notify Slack - Issues - -on: - issues: - types: [opened, reopened] - -permissions: {} - -jobs: - notify: - runs-on: ubuntu-latest - steps: - - name: Send issue notification to Slack - uses: slackapi/slack-github-action@45a88b9581bfab2566dc881e2cd66d334e621e2c # v3.0.3 - with: - webhook: ${{ secrets.SLACK_WEBHOOK_URL_ISSUE }} - webhook-type: incoming-webhook - payload: | - { - "action": "${{ github.event.action }}", - "issue_url": "${{ github.event.issue.html_url }}", - "package_name": "${{ github.repository }}" - } diff --git a/.github/workflows/notify-pr.yml b/.github/workflows/notify-pr.yml deleted file mode 100644 index 0335e6a6..00000000 --- a/.github/workflows/notify-pr.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: Notify Slack - Pull Requests - -on: - pull_request_target: - types: [opened, reopened, ready_for_review] - -permissions: {} - -jobs: - notify: - runs-on: ubuntu-latest - steps: - - name: Send pull request notification to Slack - if: github.event.pull_request.draft == false - uses: slackapi/slack-github-action@45a88b9581bfab2566dc881e2cd66d334e621e2c # v3.0.3 - with: - webhook: ${{ secrets.SLACK_WEBHOOK_URL_PR }} - webhook-type: incoming-webhook - payload: | - { - "action": "${{ github.event.action }}", - "pr_url": "${{ github.event.pull_request.html_url }}", - "package_name": "${{ github.repository }}" - } diff --git a/.github/workflows/notify-release.yml b/.github/workflows/notify-release.yml deleted file mode 100644 index 778db7e9..00000000 --- a/.github/workflows/notify-release.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: Notify Slack - Release - -on: - workflow_call: - inputs: - tag_name: - required: true - type: string - release_url: - required: true - type: string - -permissions: {} - -jobs: - notify: - runs-on: ubuntu-latest - steps: - - name: Send release notification to Slack - uses: slackapi/slack-github-action@45a88b9581bfab2566dc881e2cd66d334e621e2c # v3.0.3 - with: - webhook: ${{ secrets.SLACK_WEBHOOK_URL_RELEASE }} - webhook-type: incoming-webhook - payload: | - { - "tag_name": "${{ inputs.tag_name }}", - "release_url": "${{ inputs.release_url }}", - "package_name": "${{ github.repository }}" - } diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml index 42c66f83..672c66ca 100644 --- a/.github/workflows/pypi-publish.yml +++ b/.github/workflows/pypi-publish.yml @@ -80,11 +80,4 @@ jobs: with: packages-dir: dist/ - notify-release: - if: always() && contains(needs.pypi-publish.result, 'success') - needs: [pypi-publish] - uses: ./.github/workflows/notify-release.yml - with: - tag_name: ${{ github.event.release.tag_name }} - release_url: ${{ github.event.release.html_url }} - secrets: inherit + diff --git a/.github/workflows/test-parser.yml b/.github/workflows/test-parser.yml deleted file mode 100644 index 4d6249a0..00000000 --- a/.github/workflows/test-parser.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: Test Parser - -on: - pull_request: - paths: - - '.github/scripts/parse_sdk_branch.py' - - '.github/scripts/tests/**' - push: - branches: [ main ] - paths: - - '.github/scripts/parse_sdk_branch.py' - - '.github/scripts/tests/**' - -permissions: - contents: read - -jobs: - test-parser: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - - name: Run parser tests - run: python .github/scripts/tests/test_parse_sdk_branch.py diff --git a/README.md b/README.md index 3a772717..97cea0ad 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Build reliable, long-running AWS Lambda workflows with checkpointed steps, waits - **Child contexts** - Structure complex workflows into isolated subflows - **Replay-safe logging** - Use `context.logger` for structured, de-duplicated logs - **Local and cloud testing** - Validate workflows with the testing SDK +- **Async Python support** - Use `async def` for handlers, steps, child contexts, callback submitters, and wait-for-condition checks ## 📦 Packages @@ -66,6 +67,30 @@ def handler(event: dict, context: DurableContext) -> dict: return {"status": "approved", "order_id": order_id} ``` +Async callables are supported anywhere the SDK accepts user code. The public Durable APIs stay synchronous, so async work is awaited transparently for you: + +```python +import asyncio + +from aws_durable_execution_sdk_python import ( + DurableContext, + StepContext, + durable_execution, + durable_step, +) + +@durable_step +async def fetch_order(step_ctx: StepContext, order_id: str) -> dict: + await asyncio.sleep(0) + step_ctx.logger.info("Fetched order", extra={"order_id": order_id}) + return {"order_id": order_id, "status": "ready"} + +@durable_execution +async def handler(event: dict, context: DurableContext) -> dict: + order = context.step(fetch_order(event["order_id"]), name="fetch_order") + return {"order": order} +``` + ## 📚 Documentation The complete documentation for the AWS Durable Execution SDK for Python lives on the AWS Documentation site: diff --git a/packages/aws-durable-execution-sdk-python-examples/src/run_in_child_context/run_in_child_context.py b/packages/aws-durable-execution-sdk-python-examples/src/run_in_child_context/run_in_child_context.py index 9e5a665a..fc711f8f 100644 --- a/packages/aws-durable-execution-sdk-python-examples/src/run_in_child_context/run_in_child_context.py +++ b/packages/aws-durable-execution-sdk-python-examples/src/run_in_child_context/run_in_child_context.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any from aws_durable_execution_sdk_python.context import ( @@ -12,11 +13,12 @@ def multiply_by_two(value: int) -> int: @durable_with_child_context -def child_operation(ctx: DurableContext, value: int) -> int: +async def child_operation(ctx: DurableContext, value: int) -> int: + await asyncio.sleep(0) return ctx.step(lambda _: multiply_by_two(value), name="multiply") @durable_execution -def handler(_event: Any, context: DurableContext) -> str: +async def handler(_event: Any, context: DurableContext) -> str: result = context.run_in_child_context(child_operation(5)) return f"Child context result: {result}" diff --git a/packages/aws-durable-execution-sdk-python-examples/src/step/step.py b/packages/aws-durable-execution-sdk-python-examples/src/step/step.py index 3249040a..3ec92f7e 100644 --- a/packages/aws-durable-execution-sdk-python-examples/src/step/step.py +++ b/packages/aws-durable-execution-sdk-python-examples/src/step/step.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any from aws_durable_execution_sdk_python.context import ( @@ -9,11 +10,12 @@ @durable_step -def add_numbers(_step_context: StepContext, a: int, b: int) -> int: +async def add_numbers(_step_context: StepContext, a: int, b: int) -> int: + await asyncio.sleep(0) return a + b @durable_execution -def handler(_event: Any, context: DurableContext) -> int: +async def handler(_event: Any, context: DurableContext) -> int: result: int = context.step(add_numbers(5, 3)) return result diff --git a/packages/aws-durable-execution-sdk-python-examples/src/wait_for_callback/wait_for_callback.py b/packages/aws-durable-execution-sdk-python-examples/src/wait_for_callback/wait_for_callback.py index bac1eb36..4af2accd 100644 --- a/packages/aws-durable-execution-sdk-python-examples/src/wait_for_callback/wait_for_callback.py +++ b/packages/aws-durable-execution-sdk-python-examples/src/wait_for_callback/wait_for_callback.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any from aws_durable_execution_sdk_python.config import Duration, WaitForCallbackConfig @@ -8,14 +9,17 @@ from aws_durable_execution_sdk_python.execution import durable_execution -def external_system_call(_callback_id: str, _context: WaitForCallbackContext) -> None: +async def external_system_call( + _callback_id: str, _context: WaitForCallbackContext +) -> None: """Simulate calling an external system with callback ID.""" + await asyncio.sleep(0) # In real usage, this would make an API call to an external system # passing the callback_id for the system to call back when done @durable_execution -def handler(_event: Any, context: DurableContext) -> str: +async def handler(_event: Any, context: DurableContext) -> str: config = WaitForCallbackConfig( timeout=Duration.from_seconds(120), heartbeat_timeout=Duration.from_seconds(60) ) diff --git a/packages/aws-durable-execution-sdk-python-examples/src/wait_for_condition/wait_for_condition.py b/packages/aws-durable-execution-sdk-python-examples/src/wait_for_condition/wait_for_condition.py index 37befe6a..7f409877 100644 --- a/packages/aws-durable-execution-sdk-python-examples/src/wait_for_condition/wait_for_condition.py +++ b/packages/aws-durable-execution-sdk-python-examples/src/wait_for_condition/wait_for_condition.py @@ -1,5 +1,6 @@ """Example demonstrating wait-for-condition pattern.""" +import asyncio from typing import Any from aws_durable_execution_sdk_python.context import DurableContext @@ -12,11 +13,12 @@ @durable_execution -def handler(_event: Any, context: DurableContext) -> int: +async def handler(_event: Any, context: DurableContext) -> int: """Handler demonstrating wait-for-condition pattern.""" - def condition_function(state: int, _) -> int: + async def condition_function(state: int, _) -> int: """Increment state by 1.""" + await asyncio.sleep(0) return state + 1 def wait_strategy(state: int, attempt: int) -> dict[str, Any]: diff --git a/packages/aws-durable-execution-sdk-python/README.md b/packages/aws-durable-execution-sdk-python/README.md index 3a772717..97cea0ad 100644 --- a/packages/aws-durable-execution-sdk-python/README.md +++ b/packages/aws-durable-execution-sdk-python/README.md @@ -19,6 +19,7 @@ Build reliable, long-running AWS Lambda workflows with checkpointed steps, waits - **Child contexts** - Structure complex workflows into isolated subflows - **Replay-safe logging** - Use `context.logger` for structured, de-duplicated logs - **Local and cloud testing** - Validate workflows with the testing SDK +- **Async Python support** - Use `async def` for handlers, steps, child contexts, callback submitters, and wait-for-condition checks ## 📦 Packages @@ -66,6 +67,30 @@ def handler(event: dict, context: DurableContext) -> dict: return {"status": "approved", "order_id": order_id} ``` +Async callables are supported anywhere the SDK accepts user code. The public Durable APIs stay synchronous, so async work is awaited transparently for you: + +```python +import asyncio + +from aws_durable_execution_sdk_python import ( + DurableContext, + StepContext, + durable_execution, + durable_step, +) + +@durable_step +async def fetch_order(step_ctx: StepContext, order_id: str) -> dict: + await asyncio.sleep(0) + step_ctx.logger.info("Fetched order", extra={"order_id": order_id}) + return {"order_id": order_id, "status": "ready"} + +@durable_execution +async def handler(event: dict, context: DurableContext) -> dict: + order = context.step(fetch_order(event["order_id"]), name="fetch_order") + return {"order": order} +``` + ## 📚 Documentation The complete documentation for the AWS Durable Execution SDK for Python lives on the AWS Documentation site: diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/async_tools.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/async_tools.py new file mode 100644 index 00000000..fb60278b --- /dev/null +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/async_tools.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import asyncio +import inspect +import queue +import threading +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar, cast + + +T = TypeVar("T") + + +def resolve_awaitable(value: T | Awaitable[T]) -> T: + if inspect.isawaitable(value): + return run_awaitable(cast(Awaitable[T], value)) + return value + + +def invoke_callable(func: Callable[..., T | Awaitable[T]], *args, **kwargs) -> T: + return resolve_awaitable(func(*args, **kwargs)) + + +def run_awaitable(awaitable: Awaitable[T]) -> T: + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(awaitable) + + return _run_awaitable_in_thread(awaitable) + + +def _run_awaitable_in_thread(awaitable: Awaitable[T]) -> T: + result_queue: queue.Queue[tuple[bool, T | BaseException]] = queue.Queue(maxsize=1) + + def runner() -> None: + try: + result_queue.put((True, asyncio.run(awaitable))) + except BaseException as exc: # noqa: BLE001 + result_queue.put((False, exc)) + + thread = threading.Thread(target=runner, name="dex-async-user-code", daemon=True) + thread.start() + success, payload = result_queue.get() + thread.join() + + if success: + return cast(T, payload) + raise cast(BaseException, payload) diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py index 00e575d2..f4b72ba1 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/context.py @@ -61,7 +61,7 @@ if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Awaitable, Callable, Sequence from aws_durable_execution_sdk_python.concurrency.models import BatchResult from aws_durable_execution_sdk_python.state import CheckpointedResult @@ -95,8 +95,8 @@ class ExecutionContext: def durable_step( - func: Callable[Concatenate[StepContext, Params], T], -) -> Callable[Params, Callable[[StepContext], T]]: + func: Callable[Concatenate[StepContext, Params], T | Awaitable[T]], +) -> Callable[Params, Callable[[StepContext], T | Awaitable[T]]]: """Wrap your callable into a named function that a Durable step can run.""" def wrapper(*args, **kwargs): @@ -110,8 +110,8 @@ def function_with_arguments(context: StepContext): def durable_with_child_context( - func: Callable[Concatenate[DurableContext, Params], T], -) -> Callable[Params, Callable[[DurableContext], T]]: + func: Callable[Concatenate[DurableContext, Params], T | Awaitable[T]], +) -> Callable[Params, Callable[[DurableContext], T | Awaitable[T]]]: """Wrap your callable into a Durable child context.""" def wrapper(*args, **kwargs): @@ -127,7 +127,7 @@ def function_with_arguments(child_context: DurableContext): def durable_parallel_branch( name: str | None = None, ) -> Callable[ - [Callable[Concatenate[DurableContext, Params], T]], + [Callable[Concatenate[DurableContext, Params], T | Awaitable[T]]], Callable[Params, ParallelBranch[T]], ]: """Wrap your callable into a named ParallelBranch for use with context.parallel(). @@ -157,7 +157,7 @@ def fetch_orders(ctx: DurableContext, user_id: str) -> list: """ def decorator( - func: Callable[Concatenate[DurableContext, Params], T], + func: Callable[Concatenate[DurableContext, Params], T | Awaitable[T]], ) -> Callable[Params, ParallelBranch[T]]: def wrapper(*args, **kwargs) -> ParallelBranch[T]: def function_with_arguments(ctx: DurableContext) -> T: @@ -171,8 +171,8 @@ def function_with_arguments(ctx: DurableContext) -> T: def durable_wait_for_callback( - func: Callable[Concatenate[str, WaitForCallbackContext, Params], T], -) -> Callable[Params, Callable[[str, WaitForCallbackContext], T]]: + func: Callable[Concatenate[str, WaitForCallbackContext, Params], T | Awaitable[T]], +) -> Callable[Params, Callable[[str, WaitForCallbackContext], T | Awaitable[T]]]: """Wrap your callable into a wait_for_callback submitter function. This decorator allows you to define a submitter function with additional @@ -596,7 +596,7 @@ def parallel_in_child_context() -> BatchResult[T]: def run_in_child_context( self, - func: Callable[[DurableContext], T], + func: Callable[[DurableContext], T | Awaitable[T]], name: str | None = None, config: ChildConfig | None = None, ) -> T: @@ -646,7 +646,7 @@ def callable_with_child_context(): def step( self, - func: Callable[[StepContext], T], + func: Callable[[StepContext], T | Awaitable[T]], name: str | None = None, config: StepConfig | None = None, ) -> T: @@ -699,7 +699,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None: def wait_for_callback( self, - submitter: Callable[[str, WaitForCallbackContext], None], + submitter: Callable[[str, WaitForCallbackContext], Any], name: str | None = None, config: WaitForCallbackConfig | None = None, ) -> Any: @@ -716,7 +716,7 @@ def wait_in_child_context(context: DurableContext): def wait_for_condition( self, - check: Callable[[T, WaitForConditionCheckContext], T], + check: Callable[[T, WaitForConditionCheckContext], T | Awaitable[T]], config: WaitForConditionConfig[T], name: str | None = None, ) -> T: diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py index afb710e9..a3dc9a9b 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/execution.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any +from aws_durable_execution_sdk_python.async_tools import invoke_callable from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.exceptions import ( BackgroundThreadError, @@ -37,7 +38,7 @@ from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus if TYPE_CHECKING: - from collections.abc import Callable, MutableMapping + from collections.abc import Awaitable, Callable, MutableMapping from mypy_boto3_lambda import LambdaClient as Boto3LambdaClient @@ -160,7 +161,11 @@ def from_durable_execution_invocation_input( def durable_execution( - func: Callable[[Any, DurableContext], Any] | None = None, + func: ( + Callable[[Any, DurableContext], Any] + | Callable[[Any, DurableContext], Awaitable[Any]] + | None + ) = None, *, boto3_client: Boto3LambdaClient | None = None, plugins: list[DurableInstrumentationPlugin] | None = None, @@ -294,7 +299,9 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: logger.debug( "%s entering user-space...", invocation_input.durable_execution_arn ) - user_future = executor.submit(func, input_event, durable_context) + user_future = executor.submit( + invoke_callable, func, input_event, durable_context + ) logger.debug( "%s waiting for user code completion...", diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/callback.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/callback.py index 67c51ebc..f5a35be1 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/callback.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/operation/callback.py @@ -149,7 +149,7 @@ def execute(self, checkpointed_result: CheckpointedResult) -> str: def wait_for_callback_handler( context: DurableContext, - submitter: Callable[[str, WaitForCallbackContext], None], + submitter: Callable[[str, WaitForCallbackContext], Any], name: str | None = None, config: WaitForCallbackConfig | None = None, ) -> Any: diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py index 7fcfadcc..743d0ccc 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py @@ -13,6 +13,7 @@ from threading import Lock from typing import TYPE_CHECKING, Callable, Any +from aws_durable_execution_sdk_python.async_tools import resolve_awaitable from aws_durable_execution_sdk_python.exceptions import ( BackgroundThreadError, CallableRuntimeError, @@ -936,7 +937,7 @@ def wrapper(*args, **kwargs): operation_identifier, is_replay_children, attempt ) try: - result = user_function(*args, **kwargs) + result = resolve_awaitable(user_function(*args, **kwargs)) self._plugin_executor.on_user_function_end(start_info, None) return result except SuspendExecution as e: diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/types.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/types.py index 90080b09..e8eed445 100644 --- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/types.py +++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/types.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: - from collections.abc import Callable, Mapping, Sequence + from collections.abc import Awaitable, Callable, Mapping, Sequence from aws_durable_execution_sdk_python.config import ( BatchedInput, @@ -95,7 +95,7 @@ class DurableContext(Protocol): @abstractmethod def step( self, - func: Callable[[StepContext], T], + func: Callable[[StepContext], T | Awaitable[T]], name: str | None = None, config: StepConfig | None = None, ) -> T: @@ -105,7 +105,7 @@ def step( @abstractmethod def run_in_child_context( self, - func: Callable[[DurableContext], T], + func: Callable[[DurableContext], T | Awaitable[T]], name: str | None = None, config: ChildConfig | None = None, ) -> T: @@ -116,7 +116,10 @@ def run_in_child_context( def map( self, inputs: Sequence[U], - func: Callable[[DurableContext, U | BatchedInput[Any, U], int, Sequence[U]], T], + func: Callable[ + [DurableContext, U | BatchedInput[Any, U], int, Sequence[U]], + T | Awaitable[T], + ], name: str | None = None, config: MapConfig | None = None, ) -> BatchResult[T]: @@ -126,7 +129,9 @@ def map( @abstractmethod def parallel( self, - functions: Sequence[Callable[[DurableContext], T] | ParallelBranch[T]], + functions: Sequence[ + Callable[[DurableContext], T | Awaitable[T]] | ParallelBranch[T] + ], name: str | None = None, config: ParallelConfig | None = None, ) -> BatchResult[T]: diff --git a/packages/aws-durable-execution-sdk-python/tests/async_tools_test.py b/packages/aws-durable-execution-sdk-python/tests/async_tools_test.py new file mode 100644 index 00000000..b72e7a37 --- /dev/null +++ b/packages/aws-durable-execution-sdk-python/tests/async_tools_test.py @@ -0,0 +1,22 @@ +import asyncio + +from aws_durable_execution_sdk_python.async_tools import invoke_callable + + +def test_invoke_callable_runs_async_callable(): + async def async_callable() -> str: + await asyncio.sleep(0) + return "async-result" + + assert invoke_callable(async_callable) == "async-result" + + +def test_invoke_callable_runs_async_callable_from_running_loop(): + async def async_callable() -> str: + await asyncio.sleep(0) + return "nested-async-result" + + async def main() -> str: + return invoke_callable(async_callable) + + assert asyncio.run(main()) == "nested-async-result" diff --git a/packages/aws-durable-execution-sdk-python/tests/execution_test.py b/packages/aws-durable-execution-sdk-python/tests/execution_test.py index ed79bedf..69d5fec7 100644 --- a/packages/aws-durable-execution-sdk-python/tests/execution_test.py +++ b/packages/aws-durable-execution-sdk-python/tests/execution_test.py @@ -1,5 +1,6 @@ """Tests for execution.""" +import asyncio import datetime import json import time @@ -2975,6 +2976,56 @@ def test_handler(event: Any, context: DurableContext) -> dict: assert len(execution_end_calls) == 0 +def test_durable_execution_supports_async_handler(): + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + @durable_execution + async def test_handler(event: Any, context: DurableContext) -> dict: + await asyncio.sleep(0) + context.logger.info("handled async invocation") + return {"result": "async-success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + assert json.loads(result["Result"]) == {"result": "async-success"} + + +def test_durable_execution_supports_async_steps_inside_async_handler(): + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + async def async_step(_step_context) -> str: + await asyncio.sleep(0) + return "async-step-success" + + @durable_execution + async def test_handler(event: Any, context: DurableContext) -> dict: + await asyncio.sleep(0) + step_result = context.step(async_step, name="async-step") + return {"step_result": step_result} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + assert json.loads(result["Result"]) == {"step_result": "async-step-success"} + + def test_durable_execution_with_plugins_retryable_error(): """Test that plugins receive invocation end with RETRY status on retryable error.""" mock_client = Mock(spec=DurableServiceClient)