diff --git a/.env.example b/.env.example index 5cf7e7f..470aa65 100644 --- a/.env.example +++ b/.env.example @@ -33,6 +33,9 @@ MAX_UPLOAD_BYTES=2147483648 # Runtime cache and conversion limits. JOBS_MAX_CACHE=200 +TRANSCRIPTION_MAX_ACTIVE_JOBS=200 +TRANSCRIPTION_MAX_IN_FLIGHT_JOBS=4 +TRANSCRIPTION_MIN_FREE_DISK_BYTES=1073741824 FFMPEG_TIMEOUT_SEC=1800 # Optional idle model unload. Defaults to 180 seconds (3 minutes). Set to 0 @@ -42,14 +45,13 @@ FFMPEG_TIMEOUT_SEC=1800 # with the most free memory. MODEL_IDLE_TIMEOUT_SEC=180 -# Runtime mode for optional Rust-backed provider/kernel paths. -# off — default; use Python implementations. -# required — selected Rust-backed paths must run and hard-fail on import/call errors. +# Runtime mode for Rust-backed provider/kernel paths. +# required — default; selected Rust-backed paths must run and hard-fail on import/call errors. +# off — explicit rollback only; use Python implementations. # Currently selected paths: voiceprint scoring, result post-processing, and # artifact manifest helper contracts. -# CI/Docker packaging still validates the Rust extension directly even when -# the runtime default is off. -RUST_KERNEL_MODE=off +# CI/Docker packaging validates the Rust extension directly. +RUST_KERNEL_MODE=required # UID/GID the container process runs as. Must match the owner of DATA_DIR # and MODEL_CACHE_DIR on the host, otherwise writes fail. On a typical @@ -111,9 +113,10 @@ WHISPERX_ALIGN_CACHE_ONLY=0 # Noise reduction defaults. Omitting denoise_model in the API uses # DENOISE_MODEL; explicitly sending denoise_model=none disables denoising for # that request. DENOISE_SNR_THRESHOLD only gates DeepFilterNet skips; -# noisereduce runs whenever selected. +# noisereduce still respects DENOISE_MAX_AUDIO_DURATION_SEC. DENOISE_MODEL=none DENOISE_SNR_THRESHOLD=10.0 +DENOISE_MAX_AUDIO_DURATION_SEC=7200 # Speaker matching and diarization/embedding defaults. VOICEPRINT_THRESHOLD=0.75 @@ -121,3 +124,5 @@ EMBEDDING_DIM=256 PYANNOTE_MIN_DURATION_OFF=0.5 MIN_EMBED_DURATION=1.5 MAX_EMBED_DURATION=10.0 +EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC=1800 +WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC=7200 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 70e3afa..9bf809a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,6 +17,26 @@ jobs: - name: Run public release scan run: python voscript-api/scripts/public_release_scan.py --root . + architecture-gate: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Run architecture gate + run: python voscript-api/scripts/architecture_gate.py --root . --check + + docs-code-drift-gate: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Run docs/code drift gate + run: python voscript-api/scripts/docs_code_drift_gate.py --root . --check + lint: runs-on: ubuntu-latest steps: @@ -78,4 +98,5 @@ jobs: run: | pip-audit -r app/requirements.txt \ --ignore-vuln PYSEC-2022-42969 \ - --ignore-vuln CVE-2026-1839 + --ignore-vuln CVE-2026-1839 \ + --ignore-vuln CVE-2025-3000 diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml index e07c9d7..e294757 100644 --- a/.github/workflows/claude-code-review.yml +++ b/.github/workflows/claude-code-review.yml @@ -90,4 +90,4 @@ jobs: claude_args: | --model ${{ env.CLAUDE_MODEL }} - --max-turns 8 + --max-turns 16 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7d032e5..ed16ebc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -8,21 +8,117 @@ on: - 'v*' workflow_dispatch: +permissions: + contents: read + +env: + PYTHON_VERSION: "3.11" + RUST_KERNEL_MODE: required + jobs: - publish: + resolve-source: + name: resolve-source runs-on: ubuntu-latest - permissions: - contents: read - packages: write + outputs: + source-sha: ${{ steps.source.outputs.sha }} + source-ref: ${{ steps.source.outputs.ref }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Resolve immutable source ref + id: source + run: | + set -euo pipefail + SOURCE_SHA="$(git rev-parse HEAD)" + echo "sha=$SOURCE_SHA" >> "$GITHUB_OUTPUT" + echo "ref=${GITHUB_REF}" >> "$GITHUB_OUTPUT" + echo "Resolved release source ${GITHUB_REF} to ${SOURCE_SHA}" + + public-release-scan: + name: public-release-scan + runs-on: ubuntu-latest + needs: resolve-source + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ needs.resolve-source.outputs.source-sha }} + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + - name: Run public release scan + run: python voscript-api/scripts/public_release_scan.py --root . + + lint-format: + name: lint-format + runs-on: ubuntu-latest + needs: resolve-source + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ needs.resolve-source.outputs.source-sha }} + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: pip + - name: Install ruff + run: python -m pip install ruff + - name: Run lint and format checks + run: | + ruff check app/ --ignore E501 + ruff format --check app/ + unit-security: + name: unit-security + runs-on: ubuntu-latest + needs: resolve-source + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ needs.resolve-source.outputs.source-sha }} + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: pip + - name: Install test and security dependencies + run: | + python -m pip install pytest pytest-cov fastapi httpx numpy aiofiles starlette python-multipart pip-audit + - name: Run unit and security tests + env: + PYTEST_DISABLE_PLUGIN_AUTOLOAD: "1" + run: | + pytest tests/unit/ tests/test_security.py tests/test_voiceprint_db.py tests/test_job_service.py \ + -p pytest_cov \ + -v --tb=short --no-header + - name: Run pip-audit + run: | + pip-audit -r app/requirements.txt \ + --ignore-vuln PYSEC-2022-42969 \ + --ignore-vuln CVE-2026-1839 \ + --ignore-vuln CVE-2025-3000 + + rust-wheel: + name: rust-wheel + runs-on: ubuntu-latest + needs: + - resolve-source + - public-release-scan + - lint-format + - unit-security + outputs: + wheel-name: ${{ steps.wheel.outputs.name }} steps: - name: Checkout repository uses: actions/checkout@v4 + with: + ref: ${{ needs.resolve-source.outputs.source-sha }} - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: ${{ env.PYTHON_VERSION }} cache: pip - name: Install Rust toolchain @@ -31,14 +127,22 @@ jobs: - name: Install maturin run: python -m pip install "maturin>=1.13,<2" + - name: Check Rust formatting + run: cargo fmt --manifest-path crates/voscript_core/Cargo.toml -- --check + + - name: Run Rust clippy + run: cargo clippy --manifest-path crates/voscript_core/Cargo.toml --features python-bindings --all-targets -- -D warnings + + - name: Run Rust tests + run: cargo test --manifest-path crates/voscript_core/Cargo.toml + - name: Build Rust wheel run: python -m maturin build --release --manifest-path crates/voscript_core/Cargo.toml --features extension-module --out dist - - name: Stage Rust wheel for Docker context + - name: Verify wheel artifact id: wheel run: | set -euo pipefail - mkdir -p app/.wheelhouse wheel_count="$(find dist -maxdepth 1 -name 'voscript_core-*.whl' | wc -l | tr -d ' ')" if [ "$wheel_count" != "1" ]; then echo "Expected exactly one voscript_core wheel, found $wheel_count" >&2 @@ -46,9 +150,96 @@ jobs: exit 1 fi wheel_path="$(find dist -maxdepth 1 -name 'voscript_core-*.whl' -print -quit)" - cp "$wheel_path" app/.wheelhouse/ echo "name=$(basename "$wheel_path")" >> "$GITHUB_OUTPUT" + - name: Upload exact-ref wheel artifact + uses: actions/upload-artifact@v4 + with: + name: voscript-core-wheel-${{ needs.resolve-source.outputs.source-sha }} + path: dist/voscript_core-*.whl + if-no-files-found: error + retention-days: 1 + + docker-smoke: + name: docker-smoke + runs-on: ubuntu-latest + needs: + - resolve-source + - rust-wheel + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + ref: ${{ needs.resolve-source.outputs.source-sha }} + + - name: Download exact-ref wheel artifact + uses: actions/download-artifact@v4 + with: + name: voscript-core-wheel-${{ needs.resolve-source.outputs.source-sha }} + path: app/.wheelhouse + + - name: Verify downloaded wheel + run: | + set -euo pipefail + test -f "app/.wheelhouse/${{ needs.rust-wheel.outputs.wheel-name }}" + + - name: Build Docker image with exact-ref Rust wheel + run: | + docker build ./app \ + --build-arg "VOSCRIPT_CORE_WHEEL=${{ needs.rust-wheel.outputs.wheel-name }}" \ + -t voscript-release-smoke:${{ needs.resolve-source.outputs.source-sha }} + + - name: Run container Rust extension smoke + run: | + docker run --rm \ + -e RUST_KERNEL_MODE=required \ + voscript-release-smoke:${{ needs.resolve-source.outputs.source-sha }} \ + python -c "from providers.kernel_bridge import artifact_manifest_contract, core_smoke, postprocess_segments, status_payload_contract, voiceprint_score; result = core_smoke({'source': 'release'}); assert result['ok'] is True; decision = voiceprint_score({'query_embedding': [1.0, 0.0], 'candidates': [{'speaker_id': 'spk_release', 'name': 'Release', 'embedding': [1.0, 0.0], 'sample_count': 1, 'sample_spread': None}], 'threshold': 0.75, 'asnorm_threshold': 0.5, 'cohort': None}); assert decision['matched_id'] == 'spk_release'; processed = postprocess_segments({'aligned_segments': [{'start': 0.0, 'end': 1.0, 'text': 'hello', 'speaker': 'SPEAKER_00'}], 'speaker_map': {}}); assert processed['segments'][0]['speaker_label'] == 'SPEAKER_00'; manifest = artifact_manifest_contract({'manifest_version': 'artifact_manifest.v1', 'stable': [{'name': 'result', 'filename': 'result.json', 'role': 'primary_result', 'media_type': 'application/json', 'required_for_result': True}], 'optional': [], 'experimental': []}); assert manifest['stable'][0]['filename'] == 'result.json'; status = status_payload_contract({'status': 'queued', 'updated_at': '2026-06-09T00:00:00+00:00', 'filename': 'private/audio.wav'}); assert status['filename'] == 'audio.wav'" + + - name: Run container healthz smoke + run: | + set -euo pipefail + cid="$(docker run -d -e DEVICE=cpu -e ALLOW_NO_AUTH=1 -e RUST_KERNEL_MODE=required voscript-release-smoke:${{ needs.resolve-source.outputs.source-sha }})" + trap 'docker logs "$cid" || true; docker rm -f "$cid" >/dev/null 2>&1 || true' EXIT + for _ in $(seq 1 60); do + if docker exec "$cid" python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8780/healthz', timeout=2).read()" >/dev/null 2>&1; then + exit 0 + fi + sleep 2 + done + echo "Container did not pass /healthz smoke in time" >&2 + exit 1 + + publish: + name: publish + runs-on: ubuntu-latest + needs: + - resolve-source + - public-release-scan + - lint-format + - unit-security + - rust-wheel + - docker-smoke + permissions: + contents: read + packages: write + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + ref: ${{ needs.resolve-source.outputs.source-sha }} + + - name: Download exact-ref wheel artifact + uses: actions/download-artifact@v4 + with: + name: voscript-core-wheel-${{ needs.resolve-source.outputs.source-sha }} + path: app/.wheelhouse + + - name: Verify downloaded wheel + run: | + set -euo pipefail + test -f "app/.wheelhouse/${{ needs.rust-wheel.outputs.wheel-name }}" + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 @@ -67,10 +258,13 @@ jobs: - name: Compute image tags id: tags + env: + SOURCE_SHA: ${{ needs.resolve-source.outputs.source-sha }} run: | + set -euo pipefail GHCR_IMAGE="ghcr.io/$(echo '${{ github.repository }}' | tr '[:upper:]' '[:lower:]')" DOCKERHUB_IMAGE="${{ secrets.DOCKERHUB_USERNAME }}/voscript" - TAGS="$GHCR_IMAGE:latest,$DOCKERHUB_IMAGE:latest" + TAGS="$GHCR_IMAGE:latest,$DOCKERHUB_IMAGE:latest,$GHCR_IMAGE:sha-$SOURCE_SHA,$DOCKERHUB_IMAGE:sha-$SOURCE_SHA" # Tag-triggered builds (release or push-tag) get the version tag too. # Strip leading "v" for Docker Hub convention (0.7.0), keep raw ref for GHCR. if [ "${{ github.event_name }}" = "release" ] || [ "${{ github.event_name }}" = "push" ]; then @@ -80,14 +274,18 @@ jobs: fi echo "tags=$TAGS" >> "$GITHUB_OUTPUT" - - name: Build and push + - name: Build and push exact-ref image uses: docker/build-push-action@v6 with: context: ./app platforms: linux/amd64 push: true build-args: | - VOSCRIPT_CORE_WHEEL=${{ steps.wheel.outputs.name }} + VOSCRIPT_CORE_WHEEL=${{ needs.rust-wheel.outputs.wheel-name }} + labels: | + org.opencontainers.image.revision=${{ needs.resolve-source.outputs.source-sha }} + org.opencontainers.image.source=https://github.com/${{ github.repository }} + org.opencontainers.image.ref.name=${{ needs.resolve-source.outputs.source-ref }} tags: ${{ steps.tags.outputs.tags }} cache-from: type=gha cache-to: type=gha,mode=max diff --git a/.github/workflows/rust-foundation-heavy.yml b/.github/workflows/rust-foundation-heavy.yml index f8f1c40..392c152 100644 --- a/.github/workflows/rust-foundation-heavy.yml +++ b/.github/workflows/rust-foundation-heavy.yml @@ -2,7 +2,7 @@ name: Rust Foundation Heavy Gate on: pull_request: - types: [opened, reopened, ready_for_review] + types: [opened, reopened, ready_for_review, synchronize] branches: [main] push: branches: [main] @@ -21,16 +21,37 @@ env: RUST_KERNEL_MODE: required jobs: + resolve-source: + name: resolve-source + runs-on: ubuntu-latest + outputs: + source-sha: ${{ steps.source.outputs.sha }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + ref: ${{ github.event.inputs.ref || github.ref }} + fetch-depth: 0 + + - name: Resolve immutable source ref + id: source + run: | + set -euo pipefail + SOURCE_SHA="$(git rev-parse HEAD)" + echo "sha=$SOURCE_SHA" >> "$GITHUB_OUTPUT" + echo "Resolved heavy gate source to ${SOURCE_SHA}" + rust-wheel: name: rust-wheel runs-on: ubuntu-latest + needs: resolve-source outputs: wheel-name: ${{ steps.wheel.outputs.name }} steps: - name: Checkout repository uses: actions/checkout@v4 with: - ref: ${{ github.event.inputs.ref || github.ref }} + ref: ${{ needs.resolve-source.outputs.source-sha }} - name: Set up Python uses: actions/setup-python@v5 @@ -80,12 +101,14 @@ jobs: docker-packaging: name: docker-packaging runs-on: ubuntu-latest - needs: rust-wheel + needs: + - resolve-source + - rust-wheel steps: - name: Checkout repository uses: actions/checkout@v4 with: - ref: ${{ github.event.inputs.ref || github.ref }} + ref: ${{ needs.resolve-source.outputs.source-sha }} - name: Download internal wheel artifact uses: actions/download-artifact@v4 @@ -102,19 +125,19 @@ jobs: run: | docker build ./app \ --build-arg "VOSCRIPT_CORE_WHEEL=${{ needs.rust-wheel.outputs.wheel-name }}" \ - -t voscript-rust-foundation:${{ github.sha }} + -t voscript-rust-foundation:${{ needs.resolve-source.outputs.source-sha }} - name: Run container extension smoke run: | docker run --rm \ -e RUST_KERNEL_MODE=required \ - voscript-rust-foundation:${{ github.sha }} \ + voscript-rust-foundation:${{ needs.resolve-source.outputs.source-sha }} \ python -c "from providers.kernel_bridge import artifact_manifest_contract, core_smoke, postprocess_segments, status_payload_contract, voiceprint_score; result = core_smoke({'source': 'ci'}); assert result['ok'] is True; assert result['echoed']['source'] == 'ci'; decision = voiceprint_score({'query_embedding': [1.0, 0.0], 'candidates': [{'speaker_id': 'spk_ci', 'name': 'CI', 'embedding': [1.0, 0.0], 'sample_count': 1, 'sample_spread': None}], 'threshold': 0.75, 'asnorm_threshold': 0.5, 'cohort': None}); assert decision['matched_id'] == 'spk_ci'; assert decision['reason'] == 'matched'; processed = postprocess_segments({'aligned_segments': [{'start': 0.0, 'end': 1.0, 'text': 'hello', 'speaker': 'SPEAKER_00'}], 'speaker_map': {}}); assert processed['segments'][0]['speaker_label'] == 'SPEAKER_00'; assert processed['unique_speakers'] == ['SPEAKER_00']; manifest = artifact_manifest_contract({'manifest_version': 'artifact_manifest.v1', 'stable': [{'name': 'result', 'filename': 'result.json', 'role': 'primary_result', 'media_type': 'application/json', 'required_for_result': True}], 'optional': [], 'experimental': []}); assert manifest['stable'][0]['filename'] == 'result.json'; status = status_payload_contract({'status': 'queued', 'updated_at': '2026-06-09T00:00:00+00:00', 'filename': 'private/audio.wav'}); assert status['status'] == 'queued'; assert status['filename'] == 'audio.wav'" - name: Run health check smoke run: | set -euo pipefail - cid="$(docker run -d -e DEVICE=cpu -e ALLOW_NO_AUTH=1 voscript-rust-foundation:${{ github.sha }})" + cid="$(docker run -d -e DEVICE=cpu -e ALLOW_NO_AUTH=1 -e RUST_KERNEL_MODE=required voscript-rust-foundation:${{ needs.resolve-source.outputs.source-sha }})" trap 'docker logs "$cid" || true; docker rm -f "$cid" >/dev/null 2>&1 || true' EXIT for _ in $(seq 1 60); do if docker exec "$cid" python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8780/healthz', timeout=2).read()" >/dev/null 2>&1; then diff --git a/CLAUDE.md b/CLAUDE.md index 353b3fa..198df9a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -63,6 +63,14 @@ app/ - API / behavior docs must match the current implementation in `app/`; do not document fixed thresholds or legacy validation semantics after changing runtime behavior +## 文档与输出语言 +- 本仓库后续回答、报告、ADR、规则文档和内部架构说明以中文为主;但技术证据保持原文,不翻译 `git status` 输出、文件/函数/模块名、workflow/agent 名、命令、commit ID、`grep`/`cargo`/`pytest`/test 命令名和配置 key。 +- 不为了中文化而改写证据名称。例如 `app/pipeline/registry.py`、`PipelineRequest`、`source-guard test`、`import direction`、`dependency direction`、`RUST_KERNEL_MODE`、`cargo test` 这类名称按原文写。 +- 会影响判断的架构术语,不能在首次出现时只写英文抽象词。一个章节或文档内首次有意义使用时,要用项目语境说明:它在 VoScript 里具体指什么、为什么重要或为什么是问题、修复后会降低什么风险。 +- 需要按上一条解释的典型术语包括:`facade`、`DTO`、`owner`、`cycle`/`circular dependency`、`boundary`、`repository`、`usecase`、`orchestration`、`adapter`、`service`、`lifecycle`、`provider`、`gate`、`source-guard test`、`import direction`、`dependency direction`、`structural debt`/`architecture debt`。不要把普通命令和显而易见的工具名逐个过度解释。 +- 架构重的报告或文档在有帮助时使用这个形状:`人话结论`、`架构解释`、`技术证据`。 +- 本节是长期写作规则,不记录本轮进度、下一步切片或未完成状态。 + ## Tests - `tests/unit/`: default regression layer for architecture and failure-path coverage - `tests/test_security.py`: security baseline and non-live red-team regression diff --git a/Cargo.lock b/Cargo.lock index e034ead..3b5d1aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -127,7 +127,7 @@ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "voscript_core" -version = "0.8.4" +version = "0.8.5" dependencies = [ "pyo3", ] diff --git a/app/Dockerfile b/app/Dockerfile index 2d6bf7c..7c88606 100755 --- a/app/Dockerfile +++ b/app/Dockerfile @@ -58,7 +58,9 @@ RUN set -eu; \ elif ls /wheelhouse/voscript_core-*.whl >/dev/null 2>&1; then \ pip install --no-cache-dir /wheelhouse/voscript_core-*.whl; \ else \ - echo "No voscript_core wheel provided; building local source image without Rust extension."; \ + echo "ERROR: voscript_core wheel is required by default; build the Rust wheel into app/.wheelhouse or pass VOSCRIPT_CORE_WHEEL." >&2; \ + echo "Runtime rollback remains RUST_KERNEL_MODE=off after building a Rust-capable image." >&2; \ + exit 1; \ fi; \ rm -rf /wheelhouse diff --git a/app/api/routers/transcriptions.py b/app/api/routers/transcriptions.py index a8378a8..f8450fd 100644 --- a/app/api/routers/transcriptions.py +++ b/app/api/routers/transcriptions.py @@ -10,13 +10,6 @@ GET /api/export/{tr_id} """ -import json -import logging -import re -import uuid -from datetime import datetime, timezone -from pathlib import PurePosixPath -from threading import Thread from typing import Annotated from fastapi import APIRouter, File, Form, HTTPException @@ -25,25 +18,22 @@ from fastapi.responses import FileResponse, PlainTextResponse from api.deps import get_db, get_pipeline -from application.transcription_jobs import run_transcription -from config import MAX_UPLOAD_BYTES, TRANSCRIPTIONS_DIR, UPLOAD_CHUNK, UPLOADS_DIR -from infra.audio import ( - lookup_hash, - safe_log_filename, - safe_tr_dir, - save_upload_and_hash, +from application.transcription_submission import ( + TranscriptionSubmissionCommand, + TranscriptionSubmissionError, + submit_transcription_upload, +) +from application.transcription_records import ( + TranscriptionRecordError, + build_export_payload, + get_audio_artifact, + get_job_status, + list_transcriptions as list_transcription_records, + load_transcription_result, + reassign_speaker as reassign_transcription_speaker, ) -from infra.job_persistence import _atomic_write_json, _write_status -from infra.job_runtime import jobs, register_in_flight, unregister_in_flight -from pipeline.contracts import normalize_status_payload - -_SPK_ID_RE = re.compile(r"^spk_[A-Za-z0-9_-]{1,64}$") - -logger = logging.getLogger(__name__) router = APIRouter(prefix="/api") -_MISSING = object() -_EXPORT_CTRL_RE = re.compile(r"[\r\n\x00-\x1f\x7f]+") # --------------------------------------------------------------------------- @@ -51,56 +41,24 @@ # --------------------------------------------------------------------------- -def _format_srt_time(seconds: float) -> str: - # [CQ-M13] 防御 None / NaN / 负秒——SRT 不允许负时间戳,NaN 会导致 int() 抛异常。 - if seconds is None or seconds != seconds: # NaN 自身不等于自身 - seconds = 0.0 - seconds = max(0.0, float(seconds)) - h = int(seconds // 3600) - m = int((seconds % 3600) // 60) - s = int(seconds % 60) - ms = int((seconds % 1) * 1000) - return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}" - +def _raise_submission_http_error(exc: TranscriptionSubmissionError) -> None: + status_code = 413 if exc.reason == "upload_too_large" else 503 + raise HTTPException(status_code, str(exc)) from exc -def _format_timestamp(seconds: float) -> str: - if seconds is None or seconds != seconds: - seconds = 0.0 - seconds = max(0.0, float(seconds)) - m = int(seconds // 60) - s = int(seconds % 60) - return f"{m:02d}:{s:02d}" - -def _load_transcription_result(tr_id: str) -> dict: - """Load result.json for *tr_id* and downgrade corruption to HTTP 409.""" - - result_file = safe_tr_dir(tr_id) / "result.json" - if not result_file.exists(): - raise HTTPException(404, "Transcription not found") - try: - return json.loads(result_file.read_text(encoding="utf-8")) - except Exception as exc: - logger.warning("Corrupt result.json for %s: %s", tr_id, exc) - raise HTTPException(409, "Corrupt transcription artifact") from exc - - -def _sanitize_export_speaker_name(value: object) -> str: - """Collapse control chars so speaker names cannot inject export lines.""" - - return _EXPORT_CTRL_RE.sub(" ", str(value or "")).strip() - - -def _discard_bootstrap_job(job_id: str, save_path) -> None: - """Best-effort rollback for a job that never became the canonical owner.""" - jobs.pop(job_id, _MISSING) - save_path.unlink(missing_ok=True) - tr_dir = TRANSCRIPTIONS_DIR / job_id - (tr_dir / "status.json").unlink(missing_ok=True) - try: - tr_dir.rmdir() - except OSError: - pass +def _raise_record_http_error(exc: TranscriptionRecordError) -> None: + status_codes = { + "invalid_transcription_id": 400, + "job_not_found": 404, + "transcription_not_found": 404, + "corrupt_result": 409, + "missing_audio": 404, + "invalid_speaker_id": 422, + "missing_voiceprint": 404, + "segment_not_found": 404, + "unsupported_export_format": 400, + } + raise HTTPException(status_codes.get(exc.reason, 500), str(exc)) from exc # --------------------------------------------------------------------------- @@ -135,188 +93,52 @@ async def transcribe( pipeline = get_pipeline(request) voiceprint_db = get_db(request) - # Normalise empty string to None so pipeline treats it as auto-detect. - language = language.strip() if language else None - - job_id = f"tr_{datetime.now():%Y%m%d_%H%M%S}_{uuid.uuid4().hex[:6]}" - - safe_filename = PurePosixPath(file.filename or "upload").name or "upload" - # Strip control chars before using the name in paths/logs — PurePosixPath.name - # preserves newlines and ANSI escapes which would otherwise enable log injection. - safe_filename = safe_log_filename(safe_filename) or "upload" - save_path = UPLOADS_DIR / f"{job_id}_{safe_filename}" - - # PERF-C2: async write + streaming SHA-256 — no event-loop blockage on large uploads. try: - _size, file_hash = await save_upload_and_hash( - file, save_path, MAX_UPLOAD_BYTES, UPLOAD_CHUNK - ) - except ValueError as exc: - save_path.unlink(missing_ok=True) - raise HTTPException(413, str(exc)) from exc - - # Dedup: if identical audio was already transcribed, return existing result. - existing_id = lookup_hash(file_hash) - if existing_id: - save_path.unlink(missing_ok=True) - logger.info( - "Dedup hit: %s already transcribed as %s", safe_filename, existing_id + submission = await submit_transcription_upload( + TranscriptionSubmissionCommand( + file=file, + pipeline=pipeline, + voiceprint_db=voiceprint_db, + language=language, + min_speakers=min_speakers, + max_speakers=max_speakers, + denoise_model=denoise_model, + snr_threshold=snr_threshold, + no_repeat_ngram_size=no_repeat_ngram_size, + ), ) - return {"id": existing_id, "status": "completed", "deduplicated": True} - - jobs[job_id] = { - "status": "queued", - "filename": safe_filename, - "created_at": datetime.now(tz=timezone.utc).isoformat(), - } - # Persist status.json BEFORE registering in-flight or starting the thread. - # This ensures any concurrent requester that receives this job_id via the - # in-flight dedup path is guaranteed to find a durable record on disk. - if not _write_status(job_id, "queued", filename=safe_filename): - _discard_bootstrap_job(job_id, save_path) - raise HTTPException( - 503, "Failed to persist job state — disk error, retry later" - ) - - # In-flight dedup: same content arriving concurrently reuses the first job. - # Registered AFTER status.json exists so the returned job_id is always live. - if file_hash: - existing_job = register_in_flight(file_hash, job_id) - if existing_job: - # Another request already owns this hash and has a durable record. - # Undo our own setup and redirect to the existing job. - _discard_bootstrap_job(job_id, save_path) - logger.info( - "In-flight dedup: %s already processing as %s", - safe_filename, - existing_job, - ) - return {"id": existing_job, "status": "queued", "deduplicated": True} - # CD-C3: daemon=True ensures this thread does not prevent the process from - # exiting on SIGTERM — the OS will clean up in-progress transcriptions on - # shutdown rather than hanging indefinitely waiting for the thread to finish. - thread = Thread( - target=run_transcription, - args=( - job_id, - save_path, - language, - min_speakers, - max_speakers, - pipeline, - voiceprint_db, - denoise_model, - snr_threshold, - file_hash, - no_repeat_ngram_size if no_repeat_ngram_size >= 3 else 0, - ), - daemon=True, - ) - try: - thread.start() - except Exception as exc: - logger.exception("Failed to start transcription thread for %s", job_id) - jobs[job_id]["status"] = "failed" - jobs[job_id]["error"] = "Failed to start background transcription" - _write_status(job_id, "failed", error=str(exc), filename=safe_filename) - save_path.unlink(missing_ok=True) - if file_hash: - unregister_in_flight(file_hash, job_id) - raise HTTPException( - 503, "Failed to start background transcription — retry later" - ) from exc + except TranscriptionSubmissionError as exc: + _raise_submission_http_error(exc) - return {"id": job_id, "status": "queued"} + response = {"id": submission.job_id, "status": submission.status} + if submission.deduplicated: + response["deduplicated"] = True + return response @router.get("/jobs/{job_id}") async def get_job( job_id: Annotated[str, FPath(pattern=r"^tr_[A-Za-z0-9_-]{1,64}$")], ): - if job_id in jobs: - job = jobs[job_id] - resp = {"id": job_id, "status": job["status"], "filename": job.get("filename")} - if job["status"] == "completed": - resp["result"] = job["result"] - elif job["status"] == "failed": - resp["error"] = job.get("error") - return resp - - # AR-C2 fallback: process restarted — try reading persisted status.json. - status_path = TRANSCRIPTIONS_DIR / job_id / "status.json" - result_path = TRANSCRIPTIONS_DIR / job_id / "result.json" - - if status_path.exists(): - try: - status_data = normalize_status_payload(json.loads(status_path.read_text())) - except Exception: - raise HTTPException(404, "Job not found") - - current_status = status_data.get("status") - - if current_status == "completed" and result_path.exists(): - try: - result = json.loads(result_path.read_text(encoding="utf-8")) - except Exception: - result = None - return { - "id": job_id, - "status": "completed", - "filename": status_data.get("filename"), - "result": result, - } - - if current_status not in ("completed", "failed"): - # In-progress status persisted by a previous process that no longer - # owns this job — treat as a restart failure. - return { - "id": job_id, - "status": "failed", - "error": "Process restarted while job was in progress", - "filename": status_data.get("filename"), - } - - return { - "id": job_id, - "status": current_status, - "error": status_data.get("error"), - "filename": status_data.get("filename"), - } - - raise HTTPException(404, "Job not found") + try: + return get_job_status(job_id) + except TranscriptionRecordError as exc: + _raise_record_http_error(exc) @router.get("/transcriptions") async def list_transcriptions(): - results = [] - for tr_dir in sorted(TRANSCRIPTIONS_DIR.iterdir(), reverse=True): - if not tr_dir.is_dir(): - continue - result_file = tr_dir / "result.json" - if result_file.exists(): - try: - data = json.loads(result_file.read_text(encoding="utf-8")) - results.append( - { - "id": data["id"], - "filename": data["filename"], - "created_at": data["created_at"], - "segment_count": len(data["segments"]), - "speaker_count": len(data.get("unique_speakers", [])), - } - ) - except Exception as exc: - logger.warning( - "Skipping corrupt result.json in %s: %s", tr_dir.name, exc - ) - return results + return list_transcription_records() @router.get("/transcriptions/{tr_id}") async def get_transcription( tr_id: Annotated[str, FPath(pattern=r"^tr_[A-Za-z0-9_-]{1,64}$")], ): - return _load_transcription_result(tr_id) + try: + return load_transcription_result(tr_id) + except TranscriptionRecordError as exc: + _raise_record_http_error(exc) @router.get("/transcriptions/{tr_id}/audio") @@ -324,11 +146,11 @@ async def download_audio( tr_id: Annotated[str, FPath(pattern=r"^tr_[A-Za-z0-9_-]{1,64}$")], ): """Return the original uploaded audio file for this transcription.""" - data = _load_transcription_result(tr_id) - audio_file = UPLOADS_DIR / data["filename"] - if not audio_file.exists(): - raise HTTPException(404, "Original audio file not found") - return FileResponse(audio_file, filename=data["filename"]) + try: + audio = get_audio_artifact(tr_id) + except TranscriptionRecordError as exc: + _raise_record_http_error(exc) + return FileResponse(audio.path, filename=audio.filename) @router.put("/transcriptions/{tr_id}/segments/{seg_id}/speaker") @@ -339,39 +161,17 @@ async def reassign_speaker( speaker_name: str = Form(...), speaker_id: str = Form(None), ): - """Correct the speaker label on a single segment. - - Only the targeted segment is updated. unique_speakers is recalculated - from the full segments list to stay consistent. speaker_map is not - modified — it tracks the diarization-model matching result, not - manual per-segment corrections. - """ - if speaker_id: - if not _SPK_ID_RE.match(speaker_id): - raise HTTPException(422, "Invalid speaker_id format") - voiceprint_db = get_db(request) - if voiceprint_db.get_speaker(speaker_id) is None: - raise HTTPException(404, f"Voiceprint {speaker_id} not found") - - result_file = safe_tr_dir(tr_id) / "result.json" - data = _load_transcription_result(tr_id) - - seg = next((s for s in data["segments"] if s["id"] == seg_id), None) - if seg is None: - raise HTTPException(404, "Segment not found") - - seg["speaker_name"] = speaker_name - # Explicitly overwrite (including clear) any stale speaker_id from a - # previous diarization match so the corrected segment stays coherent. - seg["speaker_id"] = speaker_id or None - - # Keep unique_speakers consistent with the corrected segments list. - data["unique_speakers"] = sorted( - set(s["speaker_name"] for s in data["segments"] if s.get("speaker_name")) - ) - - _atomic_write_json(result_file, data, ensure_ascii=False, indent=2) - return {"ok": True} + voiceprint_db = get_db(request) if speaker_id else None + try: + return reassign_transcription_speaker( + tr_id, + seg_id, + speaker_name, + speaker_id, + voiceprint_db=voiceprint_db, + ) + except TranscriptionRecordError as exc: + _raise_record_http_error(exc) @router.get("/export/{tr_id}") @@ -379,36 +179,19 @@ async def export_transcription( tr_id: Annotated[str, FPath(pattern=r"^tr_[A-Za-z0-9_-]{1,64}$")], format: str = "srt", ): - result_file = safe_tr_dir(tr_id) / "result.json" - data = _load_transcription_result(tr_id) - segments = data["segments"] + try: + payload = build_export_payload(tr_id, format) + except TranscriptionRecordError as exc: + _raise_record_http_error(exc) - if format == "srt": - lines = [] - for i, seg in enumerate(segments, 1): - start = _format_srt_time(seg["start"]) - end = _format_srt_time(seg["end"]) - speaker_name = _sanitize_export_speaker_name(seg.get("speaker_name")) - lines.append(f"{i}\n{start} --> {end}\n[{speaker_name}] {seg['text']}\n") - return PlainTextResponse( - "\n".join(lines), - media_type="text/srt", - headers={"Content-Disposition": f'attachment; filename="{tr_id}.srt"'}, - ) - elif format == "txt": - lines = [] - for seg in segments: - ts = _format_timestamp(seg["start"]) - speaker_name = _sanitize_export_speaker_name(seg.get("speaker_name")) - lines.append(f"[{ts}] {speaker_name}: {seg['text']}") - return PlainTextResponse( - "\n".join(lines), - media_type="text/plain", - headers={"Content-Disposition": f'attachment; filename="{tr_id}.txt"'}, - ) - elif format == "json": + if payload.file_path is not None: return FileResponse( - result_file, media_type="application/json", filename=f"{tr_id}.json" + payload.file_path, + media_type=payload.media_type, + filename=payload.filename, ) - else: - raise HTTPException(400, "Unsupported format. Use: srt, txt, json") + return PlainTextResponse( + payload.text or "", + media_type=payload.media_type, + headers={"Content-Disposition": f'attachment; filename="{payload.filename}"'}, + ) diff --git a/app/api/routers/voiceprints.py b/app/api/routers/voiceprints.py index 395a2c0..01d4f48 100644 --- a/app/api/routers/voiceprints.py +++ b/app/api/routers/voiceprints.py @@ -11,7 +11,7 @@ from api.deps import get_db from config import TRANSCRIPTIONS_DIR -from infra.audio import safe_speaker_label, safe_tr_dir +from infra.audio import AudioPathError, safe_speaker_label, safe_tr_dir logger = logging.getLogger(__name__) @@ -34,8 +34,11 @@ async def enroll_speaker( voiceprint_db = get_db(request) # SEC-C2: validate both tr_id and speaker_label before building any path. - safe_label = safe_speaker_label(speaker_label) - emb_path = safe_tr_dir(tr_id) / f"emb_{safe_label}.npy" + try: + safe_label = safe_speaker_label(speaker_label) + emb_path = safe_tr_dir(tr_id) / f"emb_{safe_label}.npy" + except AudioPathError as exc: + raise HTTPException(400, str(exc)) from exc if not emb_path.exists(): raise HTTPException(404, "Embedding not found for this speaker label") # SEC-C1: allow_pickle=False prevents arbitrary code execution via diff --git a/app/application/admission.py b/app/application/admission.py new file mode 100644 index 0000000..ce80550 --- /dev/null +++ b/app/application/admission.py @@ -0,0 +1,192 @@ +"""Application-level runtime admission policy for transcription jobs.""" + +from __future__ import annotations + +import shutil +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Protocol + +from infra.job_runtime import ( + active_job_count, + in_flight_count, + lookup_in_flight, + release_active_job, + try_register_in_flight, + try_reserve_active_job, +) + + +@dataclass(frozen=True) +class AdmissionBudget: + max_active_jobs: int + max_in_flight_jobs: int + min_free_disk_bytes: int = 0 + + +class DiskUsage(Protocol): + free: int + + +@dataclass(frozen=True) +class MemorySensitiveStageLimits: + denoise_max_audio_duration_sec: float + embedding_preload_max_audio_duration_sec: float + whisperx_align_max_audio_duration_sec: float + + +@dataclass(frozen=True) +class RuntimeAdmissionSnapshot: + active_jobs: int + in_flight_jobs: int + free_disk_bytes: int | None = None + memory_sensitive_stage_limits: MemorySensitiveStageLimits | None = None + audio_duration_seconds: float | None = None + + +@dataclass(frozen=True) +class InFlightAdmission: + existing_job_id: str | None = None + registered: bool = False + + +class AdmissionRejectedError(RuntimeError): + """Raised when a transcription job exceeds configured runtime budgets.""" + + def __init__(self, reason: str, message: str) -> None: + super().__init__(message) + self.reason = reason + + +def _budget_enabled(value: int) -> bool: + return value > 0 + + +def data_disk_free_bytes( + path: Path, + *, + disk_usage: Callable[[Path], DiskUsage] | None = None, +) -> int: + """Read free bytes for the data disk without binding to a web framework.""" + + disk_usage = disk_usage or shutil.disk_usage + try: + return int(disk_usage(path).free) + except OSError as exc: + raise AdmissionRejectedError( + "data_disk_pressure", + f"Unable to inspect data disk free space for {path}", + ) from exc + + +def build_runtime_admission_snapshot( + *, + data_path: Path | None = None, + disk_usage: Callable[[Path], DiskUsage] | None = None, + memory_sensitive_stage_limits: MemorySensitiveStageLimits | None = None, + audio_duration_seconds: float | None = None, +) -> RuntimeAdmissionSnapshot: + free_disk_bytes = None + if data_path is not None: + free_disk_bytes = data_disk_free_bytes(data_path, disk_usage=disk_usage) + return RuntimeAdmissionSnapshot( + active_jobs=active_job_count(), + in_flight_jobs=in_flight_count(), + free_disk_bytes=free_disk_bytes, + memory_sensitive_stage_limits=memory_sensitive_stage_limits, + audio_duration_seconds=audio_duration_seconds, + ) + + +def ensure_transcription_admitted( + snapshot: RuntimeAdmissionSnapshot, + budget: AdmissionBudget, +) -> None: + """Reject new transcription work before background execution starts.""" + + if ( + _budget_enabled(budget.max_active_jobs) + and snapshot.active_jobs >= budget.max_active_jobs + ): + raise AdmissionRejectedError( + "active_job_budget_exceeded", + ( + "Transcription active job budget exceeded " + f"({snapshot.active_jobs}/{budget.max_active_jobs})" + ), + ) + if ( + _budget_enabled(budget.max_in_flight_jobs) + and snapshot.in_flight_jobs >= budget.max_in_flight_jobs + ): + raise AdmissionRejectedError( + "in_flight_job_budget_exceeded", + ( + "Transcription in-flight job budget exceeded " + f"({snapshot.in_flight_jobs}/{budget.max_in_flight_jobs})" + ), + ) + if _budget_enabled(budget.min_free_disk_bytes): + if snapshot.free_disk_bytes is None: + raise AdmissionRejectedError( + "data_disk_pressure", + "Unable to inspect data disk free space before admission", + ) + if snapshot.free_disk_bytes < budget.min_free_disk_bytes: + raise AdmissionRejectedError( + "data_disk_pressure", + ( + "Transcription data disk free space below admission budget " + f"({snapshot.free_disk_bytes}/{budget.min_free_disk_bytes})" + ), + ) + + +def find_in_flight_transcription(file_hash: str) -> str | None: + return lookup_in_flight(file_hash) + + +def reserve_transcription_admission( + job_id: str, + budget: AdmissionBudget, +) -> None: + reservation = try_reserve_active_job( + job_id, + max_entries=budget.max_active_jobs, + ) + if reservation.budget_exceeded: + raise AdmissionRejectedError( + "active_job_budget_exceeded", + ( + "Transcription active job budget exceeded " + f"({active_job_count()}/{budget.max_active_jobs})" + ), + ) + + +def release_transcription_admission(job_id: str) -> bool: + return release_active_job(job_id) + + +def admit_transcription_in_flight( + file_hash: str, + job_id: str, + budget: AdmissionBudget, +) -> InFlightAdmission: + registration = try_register_in_flight( + file_hash, + job_id, + max_entries=budget.max_in_flight_jobs, + ) + if registration.existing_job_id: + return InFlightAdmission(existing_job_id=registration.existing_job_id) + if registration.budget_exceeded: + raise AdmissionRejectedError( + "in_flight_job_budget_exceeded", + ( + "Transcription in-flight job budget exceeded " + f"({in_flight_count()}/{budget.max_in_flight_jobs})" + ), + ) + return InFlightAdmission(registered=registration.registered) diff --git a/app/application/transcription_jobs.py b/app/application/transcription_jobs.py index 3cbb6a5..9cdc632 100644 --- a/app/application/transcription_jobs.py +++ b/app/application/transcription_jobs.py @@ -4,13 +4,18 @@ import time from pathlib import Path +from application.admission import release_transcription_admission from config import ( TRANSCRIPTIONS_DIR, VOICEPRINT_THRESHOLD, ) from infra.audio import register_hash -from infra.job_persistence import _write_status -from infra.job_runtime import jobs, run_serialized_gpu_work, unregister_in_flight +from infra.job_persistence import write_job_status +from infra.job_runtime import ( + run_serialized_gpu_work, + unregister_in_flight, + update_runtime_job, +) logger = logging.getLogger(__name__) @@ -36,9 +41,9 @@ def run_transcription( """ def _record_status(status: str) -> None: - jobs[job_id]["status"] = status + update_runtime_job(job_id, {"status": status}) extra_filename = audio_path.name if status == "converting" else None - _write_status(job_id, status, filename=extra_filename) + write_job_status(job_id, status, filename=extra_filename) job_started = time.perf_counter() try: @@ -67,9 +72,8 @@ def _process_pipeline(): if file_hash: register_hash(file_hash, job_id) - jobs[job_id]["status"] = "completed" - jobs[job_id]["result"] = tr - _write_status(job_id, "completed") + update_runtime_job(job_id, {"status": "completed", "result": tr}) + write_job_status(job_id, "completed") logger.info( "transcription_job_timing status=completed elapsed_s=%.3f segment_count=%d speaker_count=%d", time.perf_counter() - job_started, @@ -85,8 +89,9 @@ def _process_pipeline(): time.perf_counter() - job_started, e.__class__.__name__, ) - jobs[job_id]["status"] = "failed" - jobs[job_id]["error"] = str(e) - _write_status(job_id, "failed", error=str(e)) + update_runtime_job(job_id, {"status": "failed", "error": str(e)}) + write_job_status(job_id, "failed", error=str(e)) if file_hash: unregister_in_flight(file_hash, job_id) + finally: + release_transcription_admission(job_id) diff --git a/app/application/transcription_records.py b/app/application/transcription_records.py new file mode 100644 index 0000000..800e18a --- /dev/null +++ b/app/application/transcription_records.py @@ -0,0 +1,313 @@ +"""Application usecases for persisted transcription records and artifacts.""" + +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from infra.job_runtime import get_runtime_job +from infra.transcription_records import ( + FilesystemTranscriptionRecordRepository, + TranscriptionRecordStorageError, +) + +logger = logging.getLogger(__name__) + +_SPK_ID_RE = re.compile(r"^spk_[A-Za-z0-9_-]{1,64}$") +_EXPORT_CTRL_RE = re.compile(r"[\r\n\x00-\x1f\x7f]+") +_MISSING = object() + + +@dataclass(frozen=True) +class TranscriptionRecordSettings: + transcriptions_dir: Path + uploads_dir: Path + + +@dataclass(frozen=True) +class AudioArtifact: + path: Path + filename: str + + +@dataclass(frozen=True) +class ExportPayload: + media_type: str + filename: str + text: str | None = None + file_path: Path | None = None + + +class TranscriptionRecordError(RuntimeError): + """Typed application error for record read/write failures.""" + + def __init__(self, reason: str, message: str) -> None: + super().__init__(message) + self.reason = reason + + +def default_record_settings() -> TranscriptionRecordSettings: + import config + + return TranscriptionRecordSettings( + transcriptions_dir=config.TRANSCRIPTIONS_DIR, + uploads_dir=config.UPLOADS_DIR, + ) + + +def _settings_or_default( + settings: TranscriptionRecordSettings | None, +) -> TranscriptionRecordSettings: + return settings or default_record_settings() + + +def _repository( + settings: TranscriptionRecordSettings, +) -> FilesystemTranscriptionRecordRepository: + return FilesystemTranscriptionRecordRepository( + transcriptions_dir=settings.transcriptions_dir, + uploads_dir=settings.uploads_dir, + ) + + +def _raise_record_error(exc: TranscriptionRecordStorageError) -> None: + raise TranscriptionRecordError(exc.reason, str(exc)) from exc + + +def _lookup_runtime_job(job_id: str, runtime_jobs: Any | None) -> Any: + if runtime_jobs is None: + return get_runtime_job(job_id, _MISSING) + if job_id in runtime_jobs: + return runtime_jobs[job_id] + return _MISSING + + +def get_job_status( + job_id: str, + *, + settings: TranscriptionRecordSettings | None = None, + runtime_jobs: Any | None = None, +) -> dict[str, Any]: + settings = _settings_or_default(settings) + repository = _repository(settings) + + job = _lookup_runtime_job(job_id, runtime_jobs) + if job is not _MISSING: + response = { + "id": job_id, + "status": job["status"], + "filename": job.get("filename"), + } + if job["status"] == "completed": + response["result"] = job["result"] + elif job["status"] == "failed": + response["error"] = job.get("error") + return response + + try: + snapshot = repository.job_status_snapshot(job_id) + except TranscriptionRecordStorageError as exc: + _raise_record_error(exc) + if snapshot is None: + raise TranscriptionRecordError("job_not_found", "Job not found") + + status_data = snapshot.status + current_status = status_data.get("status") + + if current_status == "completed" and snapshot.result_exists: + return { + "id": job_id, + "status": "completed", + "filename": status_data.get("filename"), + "result": snapshot.result, + } + + if current_status not in ("completed", "failed"): + return { + "id": job_id, + "status": "failed", + "error": "Process restarted while job was in progress", + "filename": status_data.get("filename"), + } + + return { + "id": job_id, + "status": current_status, + "error": status_data.get("error"), + "filename": status_data.get("filename"), + } + + +def list_transcriptions( + *, + settings: TranscriptionRecordSettings | None = None, +) -> list[dict[str, Any]]: + settings = _settings_or_default(settings) + repository = _repository(settings) + results: list[dict[str, Any]] = [] + for data in repository.iter_transcription_results(): + try: + results.append( + { + "id": data["id"], + "filename": data["filename"], + "created_at": data["created_at"], + "segment_count": len(data["segments"]), + "speaker_count": len(data.get("unique_speakers", [])), + } + ) + except Exception as exc: + logger.warning("Skipping malformed transcription result: %s", exc) + return results + + +def load_transcription_result( + tr_id: str, + *, + settings: TranscriptionRecordSettings | None = None, +) -> dict[str, Any]: + settings = _settings_or_default(settings) + repository = _repository(settings) + try: + return repository.load_result(tr_id) + except TranscriptionRecordStorageError as exc: + _raise_record_error(exc) + + +def get_audio_artifact( + tr_id: str, + *, + settings: TranscriptionRecordSettings | None = None, +) -> AudioArtifact: + settings = _settings_or_default(settings) + repository = _repository(settings) + data = load_transcription_result(tr_id, settings=settings) + try: + audio = repository.uploaded_audio_artifact(data.get("filename")) + except TranscriptionRecordStorageError as exc: + _raise_record_error(exc) + return AudioArtifact(path=audio.path, filename=audio.filename) + + +def reassign_speaker( + tr_id: str, + seg_id: int, + speaker_name: str, + speaker_id: str | None = None, + *, + voiceprint_db: Any | None = None, + settings: TranscriptionRecordSettings | None = None, +) -> dict[str, bool]: + settings = _settings_or_default(settings) + repository = _repository(settings) + if speaker_id: + if not _SPK_ID_RE.match(speaker_id): + raise TranscriptionRecordError( + "invalid_speaker_id", + "Invalid speaker_id format", + ) + if voiceprint_db is None or voiceprint_db.get_speaker(speaker_id) is None: + raise TranscriptionRecordError( + "missing_voiceprint", + f"Voiceprint {speaker_id} not found", + ) + + data = load_transcription_result(tr_id, settings=settings) + + segment = next((s for s in data["segments"] if s["id"] == seg_id), None) + if segment is None: + raise TranscriptionRecordError("segment_not_found", "Segment not found") + + segment["speaker_name"] = speaker_name + segment["speaker_id"] = speaker_id or None + data["unique_speakers"] = sorted( + set(s["speaker_name"] for s in data["segments"] if s.get("speaker_name")) + ) + + try: + repository.save_result(tr_id, data) + except TranscriptionRecordStorageError as exc: + _raise_record_error(exc) + return {"ok": True} + + +def build_export_payload( + tr_id: str, + export_format: str = "srt", + *, + settings: TranscriptionRecordSettings | None = None, +) -> ExportPayload: + settings = _settings_or_default(settings) + repository = _repository(settings) + data = load_transcription_result(tr_id, settings=settings) + segments = data["segments"] + + if export_format == "srt": + lines = [] + for index, segment in enumerate(segments, 1): + start = _format_srt_time(segment["start"]) + end = _format_srt_time(segment["end"]) + speaker_name = _sanitize_export_speaker_name(segment.get("speaker_name")) + lines.append( + f"{index}\n{start} --> {end}\n[{speaker_name}] {segment['text']}\n" + ) + return ExportPayload( + text="\n".join(lines), + media_type="text/srt", + filename=f"{tr_id}.srt", + ) + + if export_format == "txt": + lines = [] + for segment in segments: + timestamp = _format_timestamp(segment["start"]) + speaker_name = _sanitize_export_speaker_name(segment.get("speaker_name")) + lines.append(f"[{timestamp}] {speaker_name}: {segment['text']}") + return ExportPayload( + text="\n".join(lines), + media_type="text/plain", + filename=f"{tr_id}.txt", + ) + + if export_format == "json": + try: + result_file = repository.result_file_path(tr_id) + except TranscriptionRecordStorageError as exc: + _raise_record_error(exc) + return ExportPayload( + file_path=result_file, + media_type="application/json", + filename=f"{tr_id}.json", + ) + + raise TranscriptionRecordError( + "unsupported_export_format", + "Unsupported format. Use: srt, txt, json", + ) + + +def _format_srt_time(seconds: float) -> str: + if seconds is None or seconds != seconds: + seconds = 0.0 + seconds = max(0.0, float(seconds)) + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + whole_seconds = int(seconds % 60) + milliseconds = int((seconds % 1) * 1000) + return f"{hours:02d}:{minutes:02d}:{whole_seconds:02d},{milliseconds:03d}" + + +def _format_timestamp(seconds: float) -> str: + if seconds is None or seconds != seconds: + seconds = 0.0 + seconds = max(0.0, float(seconds)) + minutes = int(seconds // 60) + whole_seconds = int(seconds % 60) + return f"{minutes:02d}:{whole_seconds:02d}" + + +def _sanitize_export_speaker_name(value: object) -> str: + return _EXPORT_CTRL_RE.sub(" ", str(value or "")).strip() diff --git a/app/application/transcription_submission.py b/app/application/transcription_submission.py new file mode 100644 index 0000000..92017df --- /dev/null +++ b/app/application/transcription_submission.py @@ -0,0 +1,462 @@ +"""Application-level orchestration for transcription upload submission.""" + +from __future__ import annotations + +import logging +import threading +import uuid +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path, PurePosixPath +from typing import Any, Protocol + +from application.admission import ( + AdmissionBudget, + AdmissionRejectedError, + DiskUsage, + MemorySensitiveStageLimits, + RuntimeAdmissionSnapshot, + build_runtime_admission_snapshot, + admit_transcription_in_flight, + ensure_transcription_admitted, + find_in_flight_transcription, + release_transcription_admission, + reserve_transcription_admission, +) +from application.transcription_jobs import run_transcription +from infra.audio import safe_log_filename +from infra.job_persistence import discard_job_status, write_job_status +from infra.job_runtime import ( + get_runtime_job, + pop_runtime_job, + runtime_job_count, + runtime_job_exists, + runtime_jobs_values_snapshot, + set_runtime_job, + unregister_in_flight, + update_runtime_job, +) + +logger = logging.getLogger(__name__) +_MISSING = object() +Thread = threading.Thread + + +class _RuntimeJobsProxy: + """Resolve the active runtime job store at use time. + + Test clients in this repo intentionally reload app modules under different + DATA_DIR values. A proxy avoids pinning this usecase to a stale infra module. + """ + + def __setitem__(self, key, value): + set_runtime_job(key, value) + + def __getitem__(self, key): + value = get_runtime_job(key, _MISSING) + if value is _MISSING: + raise KeyError(key) + return value + + def __contains__(self, key): + return runtime_job_exists(key) + + def get(self, key, default=None): + return get_runtime_job(key, default) + + def pop(self, key, default=_MISSING): + if default is _MISSING: + return pop_runtime_job(key) + return pop_runtime_job(key, default) + + def values_snapshot(self) -> tuple: + return runtime_jobs_values_snapshot() + + def __len__(self): + return runtime_job_count() + + +jobs = _RuntimeJobsProxy() + + +class UploadStream(Protocol): + filename: str | None + + async def read(self, size: int = -1) -> bytes: ... + + +@dataclass(frozen=True) +class TranscriptionSubmissionCommand: + file: UploadStream + pipeline: Any + voiceprint_db: Any + language: str | None = None + min_speakers: int = 0 + max_speakers: int = 0 + denoise_model: str | None = None + snr_threshold: float | None = None + no_repeat_ngram_size: int = 0 + + +@dataclass(frozen=True) +class TranscriptionSubmissionSettings: + max_upload_bytes: int + upload_chunk: int + max_active_jobs: int + max_in_flight_jobs: int + uploads_dir: Path + transcriptions_dir: Path + min_free_disk_bytes: int = 0 + denoise_max_audio_duration_sec: float = 0.0 + embedding_preload_max_audio_duration_sec: float = 0.0 + whisperx_align_max_audio_duration_sec: float = 0.0 + + +@dataclass(frozen=True) +class TranscriptionSubmissionResult: + job_id: str + status: str + deduplicated: bool = False + + +class TranscriptionSubmissionError(RuntimeError): + """Typed application error for submission failures.""" + + def __init__(self, reason: str, message: str) -> None: + super().__init__(message) + self.reason = reason + + +def _new_job_id() -> str: + return f"tr_{datetime.now():%Y%m%d_%H%M%S}_{uuid.uuid4().hex[:6]}" + + +def default_submission_settings() -> TranscriptionSubmissionSettings: + import config + + return TranscriptionSubmissionSettings( + max_upload_bytes=config.MAX_UPLOAD_BYTES, + upload_chunk=config.UPLOAD_CHUNK, + max_active_jobs=config.TRANSCRIPTION_MAX_ACTIVE_JOBS, + max_in_flight_jobs=config.TRANSCRIPTION_MAX_IN_FLIGHT_JOBS, + uploads_dir=config.UPLOADS_DIR, + transcriptions_dir=config.TRANSCRIPTIONS_DIR, + min_free_disk_bytes=config.TRANSCRIPTION_MIN_FREE_DISK_BYTES, + denoise_max_audio_duration_sec=config.DENOISE_MAX_AUDIO_DURATION_SEC, + embedding_preload_max_audio_duration_sec=( + config.EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC + ), + whisperx_align_max_audio_duration_sec=( + config.WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC + ), + ) + + +def _admission_budget(settings: TranscriptionSubmissionSettings) -> AdmissionBudget: + return AdmissionBudget( + max_active_jobs=settings.max_active_jobs, + max_in_flight_jobs=settings.max_in_flight_jobs, + min_free_disk_bytes=settings.min_free_disk_bytes, + ) + + +def _submission_error_from_admission( + exc: AdmissionRejectedError, +) -> TranscriptionSubmissionError: + return TranscriptionSubmissionError(exc.reason, str(exc)) + + +def _safe_upload_name(file: UploadStream) -> str: + name = PurePosixPath(file.filename or "upload").name or "upload" + return safe_log_filename(name) or "upload" + + +async def _default_upload_saver( + file: UploadStream, + save_path: Path, + max_bytes: int, + chunk_size: int, +) -> tuple[int, str]: + from infra.audio import save_upload_and_hash + + return await save_upload_and_hash(file, save_path, max_bytes, chunk_size) + + +def _default_audio_duration_reader(path: Path) -> float | None: + from infra.audio import audio_duration_seconds + + return audio_duration_seconds(path) + + +def _default_hash_lookup(file_hash: str) -> str | None: + from infra.audio import lookup_hash + + return lookup_hash(file_hash) + + +def _default_status_writer(*args, **kwargs) -> bool: + return write_job_status(*args, **kwargs) + + +def _unregister_in_flight(file_hash: str, job_id: str) -> bool: + return unregister_in_flight(file_hash, job_id) + + +def _memory_sensitive_stage_limits( + settings: TranscriptionSubmissionSettings, +) -> MemorySensitiveStageLimits: + return MemorySensitiveStageLimits( + denoise_max_audio_duration_sec=settings.denoise_max_audio_duration_sec, + embedding_preload_max_audio_duration_sec=( + settings.embedding_preload_max_audio_duration_sec + ), + whisperx_align_max_audio_duration_sec=( + settings.whisperx_align_max_audio_duration_sec + ), + ) + + +def _read_audio_duration_seconds( + path: Path, + reader: Callable[[Path], float | None], +) -> float | None: + try: + return reader(path) + except Exception: + logger.info("Unable to read audio duration metadata for %s", path) + return None + + +def _admission_record( + snapshot: RuntimeAdmissionSnapshot, + budget: AdmissionBudget, +) -> dict[str, Any]: + limits = snapshot.memory_sensitive_stage_limits + return { + "active_jobs": snapshot.active_jobs, + "in_flight_jobs": snapshot.in_flight_jobs, + "data_disk": { + "free_bytes": snapshot.free_disk_bytes, + "min_free_bytes": budget.min_free_disk_bytes, + }, + "memory_sensitive_stage_limits": { + "DENOISE_MAX_AUDIO_DURATION_SEC": ( + None if limits is None else limits.denoise_max_audio_duration_sec + ), + "EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC": ( + None + if limits is None + else limits.embedding_preload_max_audio_duration_sec + ), + "WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC": ( + None if limits is None else limits.whisperx_align_max_audio_duration_sec + ), + }, + "audio_duration_seconds": snapshot.audio_duration_seconds, + } + + +def _discard_bootstrap_job( + job_id: str, + save_path: Path, + *, + transcriptions_dir: Path, +) -> None: + jobs.pop(job_id, _MISSING) + save_path.unlink(missing_ok=True) + discard_job_status(job_id, transcriptions_dir=transcriptions_dir) + + +async def submit_transcription_upload( + command: TranscriptionSubmissionCommand, + *, + settings: TranscriptionSubmissionSettings | None = None, + job_id_factory: Callable[[], str] = _new_job_id, + thread_factory: Callable[..., Any] | None = None, + worker: Callable[..., Any] | None = None, + status_writer: Callable[..., bool] | None = None, + upload_saver: Callable[[UploadStream, Path, int, int], Awaitable[tuple[int, str]]] + | None = None, + hash_lookup: Callable[[str], str | None] | None = None, + disk_usage: Callable[[Path], DiskUsage] | None = None, + audio_duration_reader: Callable[[Path], float | None] | None = None, +) -> TranscriptionSubmissionResult: + """Accept an upload and bootstrap a durable background transcription job.""" + + settings = settings or default_submission_settings() + thread_factory = thread_factory or Thread + worker = worker or run_transcription + status_writer = status_writer or _default_status_writer + upload_saver = upload_saver or _default_upload_saver + hash_lookup = hash_lookup or _default_hash_lookup + audio_duration_reader = audio_duration_reader or _default_audio_duration_reader + language = command.language.strip() if command.language else None + job_id = job_id_factory() + safe_filename = _safe_upload_name(command.file) + save_path = settings.uploads_dir / f"{job_id}_{safe_filename}" + + try: + _size, file_hash = await upload_saver( + command.file, + save_path, + settings.max_upload_bytes, + settings.upload_chunk, + ) + except ValueError as exc: + save_path.unlink(missing_ok=True) + raise TranscriptionSubmissionError("upload_too_large", str(exc)) from exc + + existing_id = hash_lookup(file_hash) + if existing_id: + save_path.unlink(missing_ok=True) + logger.info( + "Dedup hit: %s already transcribed as %s", safe_filename, existing_id + ) + return TranscriptionSubmissionResult( + job_id=existing_id, + status="completed", + deduplicated=True, + ) + + existing_job = find_in_flight_transcription(file_hash) if file_hash else None + if existing_job: + save_path.unlink(missing_ok=True) + logger.info( + "In-flight dedup: %s already processing as %s", + safe_filename, + existing_job, + ) + return TranscriptionSubmissionResult( + job_id=existing_job, + status="queued", + deduplicated=True, + ) + + budget = _admission_budget(settings) + audio_duration = _read_audio_duration_seconds(save_path, audio_duration_reader) + try: + admission_snapshot = build_runtime_admission_snapshot( + data_path=settings.uploads_dir if budget.min_free_disk_bytes > 0 else None, + disk_usage=disk_usage, + memory_sensitive_stage_limits=_memory_sensitive_stage_limits(settings), + audio_duration_seconds=audio_duration, + ) + ensure_transcription_admitted(admission_snapshot, budget) + except AdmissionRejectedError as exc: + save_path.unlink(missing_ok=True) + raise _submission_error_from_admission(exc) from exc + + active_reserved = False + try: + reserve_transcription_admission(job_id, budget) + active_reserved = True + except AdmissionRejectedError as exc: + save_path.unlink(missing_ok=True) + raise _submission_error_from_admission(exc) from exc + + jobs[job_id] = { + "status": "queued", + "filename": safe_filename, + "created_at": datetime.now(tz=timezone.utc).isoformat(), + "admission": _admission_record(admission_snapshot, budget), + } + if not status_writer(job_id, "queued", filename=safe_filename): + _discard_bootstrap_job( + job_id, + save_path, + transcriptions_dir=settings.transcriptions_dir, + ) + if active_reserved: + release_transcription_admission(job_id) + raise TranscriptionSubmissionError( + "job_state_persist_failed", + "Failed to persist job state — disk error, retry later", + ) + + if file_hash: + try: + registration = admit_transcription_in_flight(file_hash, job_id, budget) + except AdmissionRejectedError as exc: + _discard_bootstrap_job( + job_id, + save_path, + transcriptions_dir=settings.transcriptions_dir, + ) + if active_reserved: + release_transcription_admission(job_id) + raise _submission_error_from_admission(exc) from exc + if registration.existing_job_id: + _discard_bootstrap_job( + job_id, + save_path, + transcriptions_dir=settings.transcriptions_dir, + ) + if active_reserved: + release_transcription_admission(job_id) + logger.info( + "In-flight dedup: %s already processing as %s", + safe_filename, + registration.existing_job_id, + ) + return TranscriptionSubmissionResult( + job_id=registration.existing_job_id, + status="queued", + deduplicated=True, + ) + if not registration.registered: + _discard_bootstrap_job( + job_id, + save_path, + transcriptions_dir=settings.transcriptions_dir, + ) + if active_reserved: + release_transcription_admission(job_id) + raise TranscriptionSubmissionError( + "in_flight_registration_failed", + "Failed to register in-flight transcription", + ) + + thread = thread_factory( + target=worker, + args=( + job_id, + save_path, + language, + command.min_speakers, + command.max_speakers, + command.pipeline, + command.voiceprint_db, + command.denoise_model, + command.snr_threshold, + file_hash, + command.no_repeat_ngram_size if command.no_repeat_ngram_size >= 3 else 0, + ), + daemon=True, + ) + try: + thread.start() + except Exception as exc: + logger.exception("Failed to start transcription thread for %s", job_id) + # Durable bootstrap has already succeeded. Preserve the failed job and + # persisted status as the observable old-router-compatible record, while + # releasing transient upload, in-flight, and admission state below. + update_runtime_job( + job_id, + { + "status": "failed", + "error": "Failed to start background transcription", + }, + ) + status_writer(job_id, "failed", error=str(exc), filename=safe_filename) + save_path.unlink(missing_ok=True) + if file_hash: + _unregister_in_flight(file_hash, job_id) + if active_reserved: + release_transcription_admission(job_id) + raise TranscriptionSubmissionError( + "thread_start_failed", + "Failed to start background transcription — retry later", + ) from exc + + return TranscriptionSubmissionResult(job_id=job_id, status="queued") diff --git a/app/config.py b/app/config.py index 7d3fe7f..7b15a4e 100644 --- a/app/config.py +++ b/app/config.py @@ -9,7 +9,7 @@ from pathlib import Path -APP_VERSION = "0.8.4" +APP_VERSION = "0.8.5" def _env_float(name: str, default: float) -> float: @@ -93,7 +93,7 @@ def _env_mapping(name: str) -> dict[str, str]: DEVICE: str = os.getenv("DEVICE", "cuda") LANGUAGE: str = os.getenv("LANGUAGE", "") MODEL_IDLE_TIMEOUT_SEC: float = _env_float("MODEL_IDLE_TIMEOUT_SEC", 180.0) -RUST_KERNEL_MODE: str = _env_str("RUST_KERNEL_MODE", "off").lower() +RUST_KERNEL_MODE: str = _env_str("RUST_KERNEL_MODE", "required").lower() # WhisperX forced-alignment controls. Languages are attempted by default; use # WHISPERX_ALIGN_DISABLED_LANGUAGES only for an explicit operational fallback. @@ -118,6 +118,10 @@ def _env_mapping(name: str) -> dict[str, str]: # Audio estimated at or above this level is considered clean and skipped, # matching the A/B finding that DF hurts high-quality recordings (e.g. PLAUD Pin). DENOISE_SNR_THRESHOLD: float = _env_float("DENOISE_SNR_THRESHOLD", 10.0) +DENOISE_MAX_AUDIO_DURATION_SEC: float = _env_float( + "DENOISE_MAX_AUDIO_DURATION_SEC", + 7200.0, +) # --------------------------------------------------------------------------- # Speaker identification @@ -133,6 +137,14 @@ def _env_mapping(name: str) -> dict[str, str]: PYANNOTE_MIN_DURATION_OFF: float = _env_float("PYANNOTE_MIN_DURATION_OFF", 0.5) MIN_EMBED_DURATION: float = _env_float("MIN_EMBED_DURATION", 1.5) MAX_EMBED_DURATION: float = _env_float("MAX_EMBED_DURATION", 10.0) +EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC: float = _env_float( + "EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC", + 1800.0, +) +WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC: float = _env_float( + "WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC", + 7200.0, +) # --------------------------------------------------------------------------- # Misc @@ -140,6 +152,15 @@ def _env_mapping(name: str) -> dict[str, str]: FFMPEG_TIMEOUT_SEC: int = _env_int("FFMPEG_TIMEOUT_SEC", 1800) JOBS_MAX_CACHE: int = _env_int("JOBS_MAX_CACHE", 200) +TRANSCRIPTION_MAX_ACTIVE_JOBS: int = _env_int("TRANSCRIPTION_MAX_ACTIVE_JOBS", 200) +TRANSCRIPTION_MAX_IN_FLIGHT_JOBS: int = _env_int( + "TRANSCRIPTION_MAX_IN_FLIGHT_JOBS", + 4, +) +TRANSCRIPTION_MIN_FREE_DISK_BYTES: int = _env_int( + "TRANSCRIPTION_MIN_FREE_DISK_BYTES", + 1024 * 1024 * 1024, +) # Paths that must stay open even when API_KEY auth is enabled. "/" is the # bundled web UI (browsers can't attach a Bearer header to a direct diff --git a/app/infra/audio/__init__.py b/app/infra/audio/__init__.py index ab21657..da59c79 100644 --- a/app/infra/audio/__init__.py +++ b/app/infra/audio/__init__.py @@ -8,14 +8,26 @@ register_hash, save_upload_and_hash, ) +from .errors import ( + AudioPathError, + AudioPathTraversalError, + InvalidSpeakerLabelError, + InvalidTranscriptionIdError, +) from .paths import safe_log_filename, safe_speaker_label, safe_tr_dir from .tempfiles import cleanup_generated_files +from .metadata import audio_duration_seconds __all__ = [ + "AudioPathError", + "AudioPathTraversalError", + "audio_duration_seconds", "cleanup_generated_files", "JsonAudioArtifactIndex", "compute_file_hash", "default_audio_artifact_index", + "InvalidSpeakerLabelError", + "InvalidTranscriptionIdError", "lookup_hash", "register_hash", "safe_log_filename", diff --git a/app/infra/audio/errors.py b/app/infra/audio/errors.py new file mode 100644 index 0000000..013eef2 --- /dev/null +++ b/app/infra/audio/errors.py @@ -0,0 +1,27 @@ +"""Typed errors for audio filesystem helpers.""" + +from __future__ import annotations + + +class AudioPathError(ValueError): + """Base error for invalid audio filesystem inputs.""" + + +class InvalidTranscriptionIdError(AudioPathError): + """Raised when a transcription ID cannot be safely used as a path segment.""" + + +class AudioPathTraversalError(AudioPathError): + """Raised when a resolved audio path escapes its configured root.""" + + +class InvalidSpeakerLabelError(AudioPathError): + """Raised when a speaker label cannot be safely used in a filename.""" + + +__all__ = [ + "AudioPathError", + "AudioPathTraversalError", + "InvalidSpeakerLabelError", + "InvalidTranscriptionIdError", +] diff --git a/app/infra/audio/metadata.py b/app/infra/audio/metadata.py new file mode 100644 index 0000000..715e844 --- /dev/null +++ b/app/infra/audio/metadata.py @@ -0,0 +1,24 @@ +"""Audio metadata helpers used before memory-sensitive processing.""" + +from __future__ import annotations + +from pathlib import Path + + +def audio_duration_seconds(path: Path | str) -> float | None: + """Return audio duration from metadata without loading the full waveform.""" + + try: + import torchaudio + + info = torchaudio.info(str(path)) + sample_rate = getattr(info, "sample_rate", 0) or 0 + num_frames = getattr(info, "num_frames", 0) or 0 + if sample_rate <= 0 or num_frames < 0: + return None + return num_frames / sample_rate + except Exception: + return None + + +__all__ = ["audio_duration_seconds"] diff --git a/app/infra/audio/paths.py b/app/infra/audio/paths.py index 090ae52..c8bf45e 100644 --- a/app/infra/audio/paths.py +++ b/app/infra/audio/paths.py @@ -5,9 +5,12 @@ import re from pathlib import Path -from fastapi import HTTPException - from config import TRANSCRIPTIONS_DIR +from .errors import ( + AudioPathTraversalError, + InvalidSpeakerLabelError, + InvalidTranscriptionIdError, +) _TR_ID_RE = re.compile(r"^tr_[A-Za-z0-9_-]{1,64}$") _SPEAKER_LABEL_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$") @@ -26,10 +29,10 @@ def safe_tr_dir(tr_id: str) -> Path: """Validate tr_id and return its transcription directory.""" if not _TR_ID_RE.match(tr_id): - raise HTTPException(400, f"Invalid transcription ID format: {tr_id!r}") + raise InvalidTranscriptionIdError(f"Invalid transcription ID format: {tr_id!r}") path = (TRANSCRIPTIONS_DIR / tr_id).resolve() if not str(path).startswith(str(TRANSCRIPTIONS_DIR.resolve())): - raise HTTPException(400, "Path traversal detected") + raise AudioPathTraversalError("Path traversal detected") return path @@ -37,7 +40,7 @@ def safe_speaker_label(label: str) -> str: """Validate speaker labels before embedding them into filenames.""" if not _SPEAKER_LABEL_RE.match(label): - raise HTTPException(400, f"Invalid speaker label: {label!r}") + raise InvalidSpeakerLabelError(f"Invalid speaker label: {label!r}") return label diff --git a/app/infra/job_persistence.py b/app/infra/job_persistence.py index c1b8136..a80c9e6 100644 --- a/app/infra/job_persistence.py +++ b/app/infra/job_persistence.py @@ -8,7 +8,7 @@ from pathlib import Path from config import TRANSCRIPTIONS_DIR -from pipeline.contracts import ( +from infra.job_status import ( TERMINAL_JOB_STATUSES, build_status_payload, normalize_status_payload, @@ -40,6 +40,12 @@ def _atomic_write_json(path: Path, payload: dict, **json_kwargs) -> None: pass +def atomic_write_json(path: Path, payload: dict, **json_kwargs) -> None: + """Public adapter for atomic JSON writes owned by infra.""" + + _atomic_write_json(path, payload, **json_kwargs) + + def _write_status( job_id: str, status: str, @@ -67,6 +73,28 @@ def _write_status( return False +def write_job_status( + job_id: str, + status: str, + error: str | None = None, + filename: str | None = None, +) -> bool: + """Public adapter for persisted transcription job status writes.""" + + return _write_status(job_id, status, error=error, filename=filename) + + +def discard_job_status(job_id: str, *, transcriptions_dir: Path) -> None: + """Remove a bootstrap status record and its empty job directory if present.""" + + tr_dir = transcriptions_dir / job_id + (tr_dir / "status.json").unlink(missing_ok=True) + try: + tr_dir.rmdir() + except OSError: + pass + + def recover_orphan_jobs() -> None: """Mark any in-progress jobs as failed if the process was restarted.""" try: @@ -91,3 +119,11 @@ def recover_orphan_jobs() -> None: ) except Exception as exc: logger.warning("AR-C2: orphan job recovery scan failed: %s", exc) + + +__all__ = [ + "atomic_write_json", + "discard_job_status", + "recover_orphan_jobs", + "write_job_status", +] diff --git a/app/infra/job_runtime.py b/app/infra/job_runtime.py index 9ea189a..939b082 100644 --- a/app/infra/job_runtime.py +++ b/app/infra/job_runtime.py @@ -8,7 +8,7 @@ from collections import OrderedDict from collections.abc import Callable from dataclasses import dataclass -from typing import TypeVar +from typing import Any, TypeVar from config import JOBS_MAX_CACHE, MODEL_IDLE_TIMEOUT_SEC @@ -54,6 +54,20 @@ def pop(self, key, default=_MISSING): return self._d.pop(key) return self._d.pop(key, default) + def update(self, key, updates: dict[str, Any]) -> dict[str, Any]: + with self._lock: + job = self._d[key] + job.update(updates) + return dict(job) + + def values_snapshot(self) -> tuple: + with self._lock: + return tuple(self._d.values()) + + def __len__(self): + with self._lock: + return len(self._d) + jobs: _LRUJobsDict = _LRUJobsDict(maxsize=JOBS_MAX_CACHE) @@ -67,6 +81,21 @@ def pop(self, key, default=_MISSING): # both burning GPU. Cleared when the job reaches a terminal state. _in_flight_hashes: dict[str, str] = {} _in_flight_lock = threading.Lock() +_active_job_ids: set[str] = set() +_active_job_lock = threading.Lock() + + +@dataclass(frozen=True) +class ActiveJobReservation: + reserved: bool = False + budget_exceeded: bool = False + + +@dataclass(frozen=True) +class InFlightRegistration: + existing_job_id: str | None = None + registered: bool = False + budget_exceeded: bool = False @dataclass(frozen=True) @@ -79,6 +108,50 @@ def stop(self, timeout: float = 5.0) -> None: self.thread.join(timeout=timeout) +def set_runtime_job(job_id: str, payload: dict[str, Any]) -> None: + """Store a runtime job record in the current in-memory job cache.""" + + jobs[job_id] = payload + + +def get_runtime_job(job_id: str, default: Any = None) -> Any: + """Read a runtime job record from the current in-memory job cache.""" + + return jobs.get(job_id, default) + + +def runtime_job_exists(job_id: str) -> bool: + """Return whether a runtime job exists in the current in-memory cache.""" + + return job_id in jobs + + +def update_runtime_job(job_id: str, updates: dict[str, Any]) -> dict[str, Any]: + """Update and return a runtime job record through the current job cache.""" + + return jobs.update(job_id, updates) + + +def pop_runtime_job(job_id: str, default: Any = _MISSING) -> Any: + """Remove and return a runtime job record from the current job cache.""" + + if default is _MISSING: + return jobs.pop(job_id) + return jobs.pop(job_id, default) + + +def runtime_jobs_values_snapshot() -> tuple: + """Return a point-in-time tuple of runtime job values.""" + + return jobs.values_snapshot() + + +def runtime_job_count() -> int: + """Return the current number of runtime job records.""" + + return len(jobs) + + def flush_torch_cuda_cache( logger: logging.Logger | None = None, *, @@ -251,6 +324,62 @@ def register_in_flight(file_hash: str, job_id: str) -> str | None: return None +def try_register_in_flight( + file_hash: str, + job_id: str, + *, + max_entries: int, +) -> InFlightRegistration: + """Register a hash while atomically enforcing the unique in-flight budget.""" + + with _in_flight_lock: + if file_hash in _in_flight_hashes: + return InFlightRegistration(existing_job_id=_in_flight_hashes[file_hash]) + if max_entries > 0 and len(_in_flight_hashes) >= max_entries: + return InFlightRegistration(budget_exceeded=True) + _in_flight_hashes[file_hash] = job_id + return InFlightRegistration(registered=True) + + +def lookup_in_flight(file_hash: str) -> str | None: + with _in_flight_lock: + return _in_flight_hashes.get(file_hash) + + +def in_flight_count() -> int: + with _in_flight_lock: + return len(_in_flight_hashes) + + +def try_reserve_active_job( + job_id: str, + *, + max_entries: int, +) -> ActiveJobReservation: + """Reserve an active job slot independently from the bounded LRU cache.""" + + with _active_job_lock: + if job_id in _active_job_ids: + return ActiveJobReservation(reserved=True) + if max_entries > 0 and len(_active_job_ids) >= max_entries: + return ActiveJobReservation(budget_exceeded=True) + _active_job_ids.add(job_id) + return ActiveJobReservation(reserved=True) + + +def release_active_job(job_id: str) -> bool: + with _active_job_lock: + if job_id not in _active_job_ids: + return False + _active_job_ids.remove(job_id) + return True + + +def active_job_count() -> int: + with _active_job_lock: + return len(_active_job_ids) + + def unregister_in_flight(file_hash: str, job_id: str | None = None) -> bool: with _in_flight_lock: current_job = _in_flight_hashes.get(file_hash) @@ -263,13 +392,28 @@ def unregister_in_flight(file_hash: str, job_id: str | None = None) -> bool: __all__ = [ + "ActiveJobReservation", "_LRUJobsDict", + "InFlightRegistration", + "active_job_count", "flush_torch_cuda_cache", + "get_runtime_job", + "in_flight_count", "jobs", + "lookup_in_flight", + "pop_runtime_job", "record_gpu_job_finished", "register_in_flight", + "release_active_job", "run_serialized_gpu_work", "start_idle_model_unload_daemon", + "runtime_job_count", + "runtime_job_exists", + "runtime_jobs_values_snapshot", + "set_runtime_job", + "try_reserve_active_job", + "try_register_in_flight", + "update_runtime_job", "unload_idle_pipeline_if_due", "unregister_in_flight", ] diff --git a/app/pipeline/contracts/status.py b/app/infra/job_status.py similarity index 97% rename from app/pipeline/contracts/status.py rename to app/infra/job_status.py index 606b12a..c185c47 100644 --- a/app/pipeline/contracts/status.py +++ b/app/infra/job_status.py @@ -1,4 +1,4 @@ -"""Stable contract helpers for persisted job status payloads.""" +"""Stable infra-owned helpers for persisted job status payloads.""" from __future__ import annotations diff --git a/app/infra/transcription_records.py b/app/infra/transcription_records.py new file mode 100644 index 0000000..9d59e92 --- /dev/null +++ b/app/infra/transcription_records.py @@ -0,0 +1,197 @@ +"""Filesystem repository for persisted transcription records.""" + +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass +from pathlib import Path, PurePosixPath, PureWindowsPath +from typing import Any + +from infra.job_persistence import atomic_write_json +from infra.job_status import normalize_status_payload + +logger = logging.getLogger(__name__) + +_TR_ID_RE = re.compile(r"^tr_[A-Za-z0-9_-]{1,64}$") + + +class TranscriptionRecordStorageError(RuntimeError): + """Infra storage error that application usecases map to typed errors.""" + + def __init__(self, reason: str, message: str) -> None: + super().__init__(message) + self.reason = reason + + +@dataclass(frozen=True) +class PersistedJobStatusSnapshot: + status: dict[str, Any] + result_exists: bool = False + result: dict[str, Any] | None = None + + +@dataclass(frozen=True) +class UploadedAudioArtifact: + path: Path + filename: str + + +class FilesystemTranscriptionRecordRepository: + """Read and write transcription record files under infra ownership.""" + + def __init__(self, *, transcriptions_dir: Path, uploads_dir: Path) -> None: + self.transcriptions_dir = transcriptions_dir + self.uploads_dir = uploads_dir + + def job_status_snapshot( + self, + job_id: str, + ) -> PersistedJobStatusSnapshot | None: + tr_dir = self._safe_tr_dir(job_id) + status_path = tr_dir / "status.json" + result_path = tr_dir / "result.json" + + if not status_path.exists(): + return None + + try: + status_data = normalize_status_payload( + json.loads(status_path.read_text(encoding="utf-8")) + ) + except Exception as exc: + logger.warning("Corrupt status.json for %s: %s", job_id, exc) + raise TranscriptionRecordStorageError( + "job_not_found", + "Job not found", + ) from exc + + result_exists = result_path.exists() + result = None + if status_data.get("status") == "completed" and result_exists: + try: + result = json.loads(result_path.read_text(encoding="utf-8")) + except Exception: + result = None + + return PersistedJobStatusSnapshot( + status=status_data, + result_exists=result_exists, + result=result, + ) + + def iter_transcription_results(self) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] + for tr_dir in sorted(self.transcriptions_dir.iterdir(), reverse=True): + if not tr_dir.is_dir(): + continue + result_file = tr_dir / "result.json" + if not result_file.exists(): + continue + try: + results.append(json.loads(result_file.read_text(encoding="utf-8"))) + except Exception as exc: + logger.warning( + "Skipping corrupt result.json in %s: %s", + tr_dir.name, + exc, + ) + return results + + def load_result(self, tr_id: str) -> dict[str, Any]: + result_file = self.result_file_path(tr_id) + if not result_file.exists(): + raise TranscriptionRecordStorageError( + "transcription_not_found", + "Transcription not found", + ) + try: + return json.loads(result_file.read_text(encoding="utf-8")) + except Exception as exc: + logger.warning("Corrupt result.json for %s: %s", tr_id, exc) + raise TranscriptionRecordStorageError( + "corrupt_result", + "Corrupt transcription artifact", + ) from exc + + def save_result(self, tr_id: str, payload: dict[str, Any]) -> None: + atomic_write_json( + self.result_file_path(tr_id), + payload, + ensure_ascii=False, + indent=2, + ) + + def result_file_path(self, tr_id: str) -> Path: + return self._safe_tr_dir(tr_id) / "result.json" + + def uploaded_audio_artifact(self, filename_value: object) -> UploadedAudioArtifact: + filename = self._safe_audio_filename(filename_value) + audio_file = self._safe_upload_path(filename) + if not audio_file.exists(): + raise TranscriptionRecordStorageError( + "missing_audio", + "Original audio file not found", + ) + return UploadedAudioArtifact(path=audio_file, filename=filename) + + def _safe_tr_dir(self, tr_id: str) -> Path: + if not _TR_ID_RE.match(tr_id): + raise TranscriptionRecordStorageError( + "invalid_transcription_id", + f"Invalid transcription ID format: {tr_id!r}", + ) + + root = self.transcriptions_dir.resolve() + path = (self.transcriptions_dir / tr_id).resolve() + try: + path.relative_to(root) + except ValueError as exc: + raise TranscriptionRecordStorageError( + "invalid_transcription_id", + "Path traversal detected", + ) from exc + return path + + def _safe_audio_filename(self, value: object) -> str: + if not isinstance(value, str) or not value: + raise TranscriptionRecordStorageError( + "corrupt_result", + "Corrupt transcription artifact", + ) + + posix_path = PurePosixPath(value) + windows_path = PureWindowsPath(value) + if ( + value in {".", ".."} + or posix_path.is_absolute() + or windows_path.is_absolute() + or posix_path.name != value + or windows_path.name != value + ): + raise TranscriptionRecordStorageError( + "corrupt_result", + "Corrupt transcription artifact", + ) + return value + + def _safe_upload_path(self, filename: str) -> Path: + root = self.uploads_dir.resolve() + audio_file = (self.uploads_dir / filename).resolve() + try: + audio_file.relative_to(root) + except ValueError as exc: + raise TranscriptionRecordStorageError( + "corrupt_result", + "Corrupt transcription artifact", + ) from exc + return audio_file + + +__all__ = [ + "FilesystemTranscriptionRecordRepository", + "PersistedJobStatusSnapshot", + "TranscriptionRecordStorageError", + "UploadedAudioArtifact", +] diff --git a/app/pipeline/contracts/__init__.py b/app/pipeline/contracts/__init__.py index 3c08cb0..8154dc1 100644 --- a/app/pipeline/contracts/__init__.py +++ b/app/pipeline/contracts/__init__.py @@ -37,9 +37,22 @@ ProviderNotFoundError, StageNotFoundError, ) +from .metadata import ( + PIPELINE_METADATA_CONTRACT, + PIPELINE_METADATA_CONTROL_KEYS, + PIPELINE_METADATA_PATH_CONTRACT, + PIPELINE_METADATA_PUBLIC_PATHS, + PIPELINE_METADATA_STAGE_KEYS, + PIPELINE_METADATA_TOP_LEVEL_KEYS, + PUBLIC_ALIGNMENT_METADATA_KEYS, + PipelineMetadataEntry, + normalize_public_alignment_metadata, +) from .normalize import ( + AudioNormalizationError, AudioNormalizationRequest, AudioNormalizationResult, + AudioNormalizationTimeoutError, InputNormalizationProvider, ) from .requests import PipelineRequest @@ -50,21 +63,6 @@ attach_optional_schema_version, read_optional_schema_version, ) -from .status import ( - IN_PROGRESS_JOB_STATUSES, - JOB_STATUS_COMPLETED, - JOB_STATUS_CONVERTING, - JOB_STATUS_DENOISING, - JOB_STATUS_FAILED, - JOB_STATUS_IDENTIFYING, - JOB_STATUS_QUEUED, - JOB_STATUS_TRANSCRIBING, - KNOWN_JOB_STATUSES, - TERMINAL_JOB_STATUSES, - build_status_payload, - normalize_job_status, - normalize_status_payload, -) from .voiceprint_match import ( VoiceprintMatchProvider, VoiceprintMatchRequest, @@ -82,36 +80,36 @@ "AudioEnhancementProvider", "AudioEnhancementRequest", "AudioEnhancementResult", + "AudioNormalizationError", "AudioNormalizationRequest", "AudioNormalizationResult", + "AudioNormalizationTimeoutError", "ArtifactManifestEntry", "DiarizationProvider", "DiarizationRequest", "DiarizationResult", "InputNormalizationProvider", - "IN_PROGRESS_JOB_STATUSES", - "JOB_STATUS_COMPLETED", - "JOB_STATUS_CONVERTING", - "JOB_STATUS_DENOISING", - "JOB_STATUS_FAILED", - "JOB_STATUS_IDENTIFYING", - "JOB_STATUS_QUEUED", - "JOB_STATUS_TRANSCRIBING", - "KNOWN_JOB_STATUSES", "OPTIONAL_FIRST_SCHEMA_POLICY", "PersistedTranscriptionArtifacts", "PipelineContext", "PipelineLookupError", + "PIPELINE_METADATA_CONTRACT", + "PIPELINE_METADATA_CONTROL_KEYS", + "PIPELINE_METADATA_PATH_CONTRACT", + "PIPELINE_METADATA_PUBLIC_PATHS", + "PIPELINE_METADATA_STAGE_KEYS", + "PIPELINE_METADATA_TOP_LEVEL_KEYS", "PipelineRequest", "PipelineResult", + "PipelineMetadataEntry", "SavedUploadArtifact", "SpeakerEmbeddingProvider", "SpeakerEmbeddingRequest", "SpeakerEmbeddingResult", "ProviderNotFoundError", + "PUBLIC_ALIGNMENT_METADATA_KEYS", "SCHEMA_VERSION_KEY", "StageNotFoundError", - "TERMINAL_JOB_STATUSES", "TranscriptionArtifactStore", "TranscriptionArtifactWriteRequest", "UploadPersistenceRequest", @@ -120,10 +118,8 @@ "VoiceprintMatchResult", "attach_optional_schema_version", "build_artifact_manifest", - "build_status_payload", "empty_artifact_manifest", + "normalize_public_alignment_metadata", "normalize_artifact_manifest", - "normalize_job_status", - "normalize_status_payload", "read_optional_schema_version", ] diff --git a/app/pipeline/contracts/errors.py b/app/pipeline/contracts/errors.py index d72288c..96e4f01 100644 --- a/app/pipeline/contracts/errors.py +++ b/app/pipeline/contracts/errors.py @@ -1,18 +1,8 @@ -"""Shared errors for pipeline stage and provider resolution.""" +"""Compatibility re-export for pipeline lookup errors.""" from __future__ import annotations - -class PipelineLookupError(LookupError): - """Base class for stable pipeline registry lookup failures.""" - - -class StageNotFoundError(PipelineLookupError): - """Raised when a stable pipeline stage slot cannot be resolved.""" - - -class ProviderNotFoundError(PipelineLookupError): - """Raised when a pipeline provider implementation cannot be resolved.""" +from ..errors import PipelineLookupError, ProviderNotFoundError, StageNotFoundError __all__ = [ diff --git a/app/pipeline/contracts/metadata.py b/app/pipeline/contracts/metadata.py new file mode 100644 index 0000000..26fdb99 --- /dev/null +++ b/app/pipeline/contracts/metadata.py @@ -0,0 +1,161 @@ +"""Pipeline metadata ownership and public-surface contract.""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class PipelineMetadataEntry: + """Ownership record for a stable PipelineContext.metadata key or path.""" + + owner: str + writers: tuple[str, ...] + readers: tuple[str, ...] = () + public: bool = False + allow_overwrite: bool = False + description: str = "" + + +PIPELINE_METADATA_CONTROL_KEYS = ( + "executed_stages", + "selected_providers", + "provider_capabilities", + "stage_timings", +) + +PIPELINE_METADATA_STAGE_KEYS = ( + "ingest", + "normalize", + "enhance", + "vad", + "asr", + "diarization", + "embedding", + "voiceprint_match", + "punc", + "postprocess", + "artifacts", +) + +PIPELINE_METADATA_TOP_LEVEL_KEYS = ( + *PIPELINE_METADATA_CONTROL_KEYS, + *PIPELINE_METADATA_STAGE_KEYS, +) + +PIPELINE_METADATA_PUBLIC_PATHS = ("diarization.alignment",) + +PIPELINE_METADATA_STAGE_WRITERS = { + "ingest": ("providers.ingest.default",), + "normalize": ("pipeline.stages.normalize",), + "enhance": ("pipeline.stages.enhance",), + "vad": ("providers.vad.default",), + "asr": ("pipeline.stages.asr",), + "diarization": ("pipeline.stages.diarization",), + "embedding": ("pipeline.stages.embedding",), + "voiceprint_match": ("pipeline.stages.voiceprint_match",), + "punc": ("providers.punc.default",), + "postprocess": ("providers.postprocess.default",), + "artifacts": ("pipeline.stages.artifacts",), +} + +PIPELINE_METADATA_CONTRACT: dict[str, PipelineMetadataEntry] = { + "executed_stages": PipelineMetadataEntry( + owner="pipeline.runner", + writers=("PipelineContext.mark_stage",), + readers=("pipeline.runner",), + description="Ordered stage names observed by the runner.", + ), + "selected_providers": PipelineMetadataEntry( + owner="pipeline.runner", + writers=("pipeline.runner",), + readers=("pipeline.runner",), + description="Provider selected for each stage before execution.", + ), + "provider_capabilities": PipelineMetadataEntry( + owner="pipeline.runner", + writers=("pipeline.runner",), + readers=("pipeline.runner",), + description="Provider capability preflight metadata keyed by stage.", + ), + "stage_timings": PipelineMetadataEntry( + owner="pipeline.runner", + writers=("pipeline.runner",), + readers=("pipeline.runner",), + description="Elapsed stage timing in seconds keyed by stage.", + ), +} + +PIPELINE_METADATA_CONTRACT.update( + { + stage: PipelineMetadataEntry( + owner=stage, + writers=("pipeline.runner", *PIPELINE_METADATA_STAGE_WRITERS[stage]), + readers=("pipeline.runner", "providers.artifacts.default"), + allow_overwrite=True, + description=f"Private execution metadata owned by the {stage} stage.", + ) + for stage in PIPELINE_METADATA_STAGE_KEYS + } +) + +PIPELINE_METADATA_PATH_CONTRACT: dict[str, PipelineMetadataEntry] = { + "diarization.alignment": PipelineMetadataEntry( + owner="diarization", + writers=("pipeline.stages.diarization",), + readers=("providers.artifacts.default",), + public=True, + allow_overwrite=False, + description="Safe forced-alignment summary allowed in result artifacts.", + ), +} + +PUBLIC_ALIGNMENT_METADATA_KEYS = ( + "status", + "reason", + "model", + "duration_s", + "max_duration_s", + "cache_only", + "device", +) + +PublicMetadataValue = str | int | float | bool | None + + +def _is_public_metadata_scalar(value: Any) -> bool: + if value is None or isinstance(value, (str, bool, int)): + return True + return isinstance(value, float) and math.isfinite(value) + + +def normalize_public_alignment_metadata( + value: Any, +) -> dict[str, PublicMetadataValue]: + """Return only the stable, JSON-safe alignment fields exposed publicly.""" + + if not isinstance(value, dict): + return {} + + normalized: dict[str, PublicMetadataValue] = {} + for key in PUBLIC_ALIGNMENT_METADATA_KEYS: + field_value = value.get(key) + if key in value and _is_public_metadata_scalar(field_value): + normalized[key] = field_value + return normalized + + +__all__ = [ + "PIPELINE_METADATA_CONTRACT", + "PIPELINE_METADATA_CONTROL_KEYS", + "PIPELINE_METADATA_PATH_CONTRACT", + "PIPELINE_METADATA_PUBLIC_PATHS", + "PIPELINE_METADATA_STAGE_WRITERS", + "PIPELINE_METADATA_STAGE_KEYS", + "PIPELINE_METADATA_TOP_LEVEL_KEYS", + "PUBLIC_ALIGNMENT_METADATA_KEYS", + "PipelineMetadataEntry", + "normalize_public_alignment_metadata", +] diff --git a/app/pipeline/contracts/normalize.py b/app/pipeline/contracts/normalize.py index c5a0b27..458c433 100644 --- a/app/pipeline/contracts/normalize.py +++ b/app/pipeline/contracts/normalize.py @@ -26,6 +26,14 @@ class AudioNormalizationResult: reused_source: bool +class AudioNormalizationError(RuntimeError): + """Base error for input-audio normalization failures.""" + + +class AudioNormalizationTimeoutError(AudioNormalizationError): + """Raised when the normalization provider exceeds its time budget.""" + + @runtime_checkable class InputNormalizationProvider(Protocol): """Canonical slot for converting uploads into pipeline-ready audio.""" @@ -36,7 +44,9 @@ def normalize( __all__ = [ + "AudioNormalizationError", "AudioNormalizationRequest", "AudioNormalizationResult", + "AudioNormalizationTimeoutError", "InputNormalizationProvider", ] diff --git a/app/pipeline/contracts/requests.py b/app/pipeline/contracts/requests.py index 6c0a863..e06da45 100644 --- a/app/pipeline/contracts/requests.py +++ b/app/pipeline/contracts/requests.py @@ -8,12 +8,11 @@ from types import MappingProxyType from typing import Any +from ..step_keys import canonical_step_name, normalize_token + def _normalize_provider_name(name: str) -> str: - token = name.strip().lower().replace("-", "_") - if not token: - raise ValueError("provider name must not be empty") - return token + return normalize_token(name, field_name="provider name") @dataclass(frozen=True, slots=True) @@ -37,8 +36,6 @@ class PipelineRequest: def __post_init__(self) -> None: normalized: dict[str, str] = {} if self.provider_selection: - from pipeline.registry import canonical_step_name - for step_name, provider_name in self.provider_selection.items(): step_key = canonical_step_name(str(step_name)) normalized[step_key] = _normalize_provider_name(str(provider_name)) @@ -51,8 +48,6 @@ def __post_init__(self) -> None: def provider_for(self, step: str, default: str = "default") -> str: """Return the explicitly selected provider for a step, or the fallback.""" - from pipeline.registry import canonical_step_name - step_key = canonical_step_name(step) return self.provider_selection.get(step_key, _normalize_provider_name(default)) diff --git a/app/pipeline/errors.py b/app/pipeline/errors.py new file mode 100644 index 0000000..d72288c --- /dev/null +++ b/app/pipeline/errors.py @@ -0,0 +1,22 @@ +"""Shared errors for pipeline stage and provider resolution.""" + +from __future__ import annotations + + +class PipelineLookupError(LookupError): + """Base class for stable pipeline registry lookup failures.""" + + +class StageNotFoundError(PipelineLookupError): + """Raised when a stable pipeline stage slot cannot be resolved.""" + + +class ProviderNotFoundError(PipelineLookupError): + """Raised when a pipeline provider implementation cannot be resolved.""" + + +__all__ = [ + "PipelineLookupError", + "ProviderNotFoundError", + "StageNotFoundError", +] diff --git a/app/pipeline/orchestrator.py b/app/pipeline/orchestrator.py index a808c2e..7dec44d 100644 --- a/app/pipeline/orchestrator.py +++ b/app/pipeline/orchestrator.py @@ -16,25 +16,77 @@ import tempfile import time from contextlib import nullcontext +from importlib import import_module from pathlib import Path from typing import Any import torch from config import DEVICE, HF_TOKEN, PYANNOTE_MIN_DURATION_OFF, WHISPER_MODEL -from infra.cuda_devices import select_best_cuda_device -from infra.huggingface_models import ( - configure_huggingface_runtime, - resolve_hf_model_ref, -) -from providers.asr import transcribe_audio -from providers.diarization import align_diarized_segments, run_pyannote_diarization -from providers.embedding import extract_embeddings_for_turns from .contracts import PipelineRequest from .runner import PipelineRunner logger = logging.getLogger(__name__) + + +def configure_huggingface_runtime() -> None: + """Invoke the infra runtime adapter without a static pipeline->infra edge.""" + + import_module("infra.huggingface_models").configure_huggingface_runtime() + + +def resolve_hf_model_ref(*args, **kwargs): + """Resolve HF references through the infra adapter boundary.""" + + return import_module("infra.huggingface_models").resolve_hf_model_ref( + *args, + **kwargs, + ) + + +def select_best_cuda_device(*args, **kwargs): + """Select CUDA devices through the infra adapter boundary.""" + + return import_module("infra.cuda_devices").select_best_cuda_device( + *args, + **kwargs, + ) + + +def transcribe_audio(*args, **kwargs): + """Call the ASR provider through the provider boundary.""" + + return import_module("providers.asr").transcribe_audio(*args, **kwargs) + + +def run_pyannote_diarization(*args, **kwargs): + """Call the diarization provider through the provider boundary.""" + + return import_module("providers.diarization").run_pyannote_diarization( + *args, + **kwargs, + ) + + +def align_diarized_segments(*args, **kwargs): + """Call the alignment provider helper through the provider boundary.""" + + return import_module("providers.diarization").align_diarized_segments( + *args, + **kwargs, + ) + + +def extract_embeddings_for_turns(*args, **kwargs): + """Call the embedding provider through the provider boundary.""" + + return import_module("providers.embedding").extract_embeddings_for_turns( + *args, + **kwargs, + ) + + configure_huggingface_runtime() _TRUSTED_PYANNOTE_TASK_GLOBAL_NAMES = ( diff --git a/app/pipeline/registry.py b/app/pipeline/registry.py index 15cb61a..c307a2e 100644 --- a/app/pipeline/registry.py +++ b/app/pipeline/registry.py @@ -5,12 +5,8 @@ from importlib import import_module from typing import Any, Callable -from .contracts import ProviderNotFoundError, StageNotFoundError - -_STEP_ALIASES = { - "input_normalization": "normalize", - "enhancement": "enhance", -} +from .errors import ProviderNotFoundError, StageNotFoundError +from .step_keys import canonical_step_name, normalize_token _DEFAULT_STAGE_IMPORTS = { "ingest": "pipeline.stages.ingest:run", @@ -65,20 +61,6 @@ _PROVIDER_OVERRIDES: dict[str, dict[str, Any]] = {} -def _normalize_token(value: str, *, field_name: str) -> str: - token = value.strip().lower().replace("-", "_") - if not token: - raise ValueError(f"{field_name} must not be empty") - return token - - -def canonical_step_name(step: str) -> str: - """Map compatibility aliases onto the canonical stable step names.""" - - token = _normalize_token(step, field_name="step") - return _STEP_ALIASES.get(token, token) - - def _load_object(import_path: str) -> Any: module_name, _, attr_name = import_path.partition(":") if not module_name or not attr_name: @@ -96,7 +78,7 @@ def register_provider(step: str, name: str, provider: Any) -> None: """Register or override a provider implementation for a pipeline step.""" step_key = canonical_step_name(step) - name_key = _normalize_token(name, field_name="name") + name_key = normalize_token(name, field_name="name") _PROVIDER_OVERRIDES.setdefault(step_key, {})[name_key] = provider @@ -104,7 +86,7 @@ def unregister_provider(step: str, name: str) -> None: """Remove a test or runtime override provider.""" step_key = canonical_step_name(step) - name_key = _normalize_token(name, field_name="name") + name_key = normalize_token(name, field_name="name") step_overrides = _PROVIDER_OVERRIDES.get(step_key) if not step_overrides: return @@ -113,11 +95,19 @@ def unregister_provider(step: str, name: str) -> None: _PROVIDER_OVERRIDES.pop(step_key, None) +def is_provider_override(step: str, name: str) -> bool: + """Return whether a provider name is a test or runtime override.""" + + step_key = canonical_step_name(step) + name_key = normalize_token(name, field_name="name") + return name_key in _PROVIDER_OVERRIDES.get(step_key, {}) + + def resolve_provider(step: str, name: str = "default") -> Any: """Resolve a provider object by stable step name and implementation name.""" step_key = canonical_step_name(step) - name_key = _normalize_token(name, field_name="name") + name_key = normalize_token(name, field_name="name") override = _PROVIDER_OVERRIDES.get(step_key, {}).get(name_key) if override is not None: @@ -162,6 +152,7 @@ def resolve_stage(slot: str) -> Callable[[Any], None]: "available_providers", "available_stage_slots", "canonical_step_name", + "is_provider_override", "register_provider", "resolve_provider", "resolve_stage", diff --git a/app/pipeline/runner.py b/app/pipeline/runner.py index cf1c88a..1f49f94 100644 --- a/app/pipeline/runner.py +++ b/app/pipeline/runner.py @@ -4,18 +4,23 @@ import logging import time +from importlib import import_module from typing import Any -from infra.audio import cleanup_generated_files - from .contracts import PipelineContext, PipelineRequest -from .registry import available_stage_slots, resolve_stage +from .registry import available_stage_slots, is_provider_override, resolve_stage logger = logging.getLogger(__name__) DEFAULT_STAGE_ORDER = available_stage_slots() +def cleanup_generated_files(paths): + """Invoke the infra temp-file adapter without making pipeline import infra.""" + + return import_module("infra.audio").cleanup_generated_files(paths) + + def _safe_stage_metrics(context: PipelineContext, stage_name: str) -> dict[str, Any]: metrics: dict[str, Any] = {} stage_metadata = context.metadata.get(stage_name) @@ -45,6 +50,30 @@ def _safe_stage_metrics(context: PipelineContext, stage_name: str) -> dict[str, return metrics +def _provider_preflight_metadata( + stage_name: str, + provider_name: str, + language: str | None, +) -> tuple[bool, dict[str, str]]: + capabilities = import_module("providers.capabilities") + try: + match = capabilities.match_provider_capability( + stage_name, + provider_name, + language=language, + ) + except capabilities.ProviderCapabilityNotFoundError: + if is_provider_override(stage_name, provider_name): + return True, { + "stage": stage_name, + "provider": provider_name, + "reason": "runtime_override", + "action": "run", + } + raise + return match.should_run, match.metadata + + class PipelineRunner: """Execute the stable stage order against the current pipeline implementation.""" @@ -66,12 +95,35 @@ def run_context(self, pipeline: Any, request: PipelineRequest) -> PipelineContex context = self.build_context(pipeline, request) try: for stage_name in self.stage_order: + provider_name = request.provider_for(stage_name) + context.metadata.setdefault("selected_providers", {})[stage_name] = ( + provider_name + ) + should_run, capability_metadata = _provider_preflight_metadata( + stage_name, + provider_name, + request.language, + ) + context.metadata.setdefault("provider_capabilities", {})[stage_name] = ( + capability_metadata + ) + if not should_run: + context.mark_stage(stage_name) + context.metadata[stage_name] = { + "status": "skipped", + **capability_metadata, + } + logger.info( + "Skipping pipeline stage: %s provider=%s reason=%s", + stage_name, + provider_name, + capability_metadata.get("reason"), + ) + continue + logger.info("Running pipeline stage: %s", stage_name) stage = self.resolve_stage(stage_name) context.mark_stage(stage_name) - context.metadata.setdefault("selected_providers", {})[stage_name] = ( - request.provider_for(stage_name) - ) stage_started = time.perf_counter() stage(context) elapsed_s = time.perf_counter() - stage_started @@ -84,7 +136,7 @@ def run_context(self, pipeline: Any, request: PipelineRequest) -> PipelineContex "pipeline_stage_timing stage=%s elapsed_s=%.3f provider=%s metrics=%s", stage_name, elapsed_s, - request.provider_for(stage_name), + provider_name, metrics, ) return context diff --git a/app/pipeline/stages/artifacts/__init__.py b/app/pipeline/stages/artifacts/__init__.py index 2edd34b..15e3936 100644 --- a/app/pipeline/stages/artifacts/__init__.py +++ b/app/pipeline/stages/artifacts/__init__.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from providers.artifacts import build_pipeline_artifacts +from pipeline.registry import resolve_provider if TYPE_CHECKING: from pipeline.contracts import PipelineContext @@ -13,10 +13,8 @@ def run(context: "PipelineContext") -> None: """Emit the current in-memory pipeline artifact bundle.""" - context.result = build_pipeline_artifacts( - context, - provider_name=context.request.provider_for("artifacts"), - ) + provider = resolve_provider("artifacts", context.request.provider_for("artifacts")) + context.result = provider.build(context) result = context.result.as_dict() if hasattr(context.result, "as_dict") else {} context.metadata["artifacts"] = { "segment_count": len(result.get("segments", context.aligned_segments)), diff --git a/app/pipeline/stages/asr/__init__.py b/app/pipeline/stages/asr/__init__.py index cf05ebf..190a355 100644 --- a/app/pipeline/stages/asr/__init__.py +++ b/app/pipeline/stages/asr/__init__.py @@ -4,7 +4,8 @@ from typing import TYPE_CHECKING -from providers.asr import transcribe_audio +from pipeline.contracts import ASRRequest +from pipeline.registry import resolve_provider if TYPE_CHECKING: from pipeline.contracts import PipelineContext @@ -16,12 +17,14 @@ def run(context: "PipelineContext") -> None: if context.request.status_callback is not None: context.request.status_callback("transcribing") - result = transcribe_audio( - context.pipeline, - context.working_audio_path, - language=context.request.language, - no_repeat_ngram_size=context.request.no_repeat_ngram_size, - provider_name=context.request.provider_for("asr"), + provider = resolve_provider("asr", context.request.provider_for("asr")) + result = provider.transcribe( + ASRRequest( + pipeline=context.pipeline, + audio_path=context.working_audio_path, + language=context.request.language, + no_repeat_ngram_size=context.request.no_repeat_ngram_size, + ) ) context.transcription_result = result.transcription_result context.metadata["asr"] = { diff --git a/app/pipeline/stages/diarization/__init__.py b/app/pipeline/stages/diarization/__init__.py index c09be23..600f416 100644 --- a/app/pipeline/stages/diarization/__init__.py +++ b/app/pipeline/stages/diarization/__init__.py @@ -4,6 +4,12 @@ from typing import TYPE_CHECKING +from pipeline.contracts import ( + DiarizationRequest, + normalize_public_alignment_metadata, +) +from pipeline.registry import resolve_provider + from .alignment import ( assign_segment_speaker, build_aligned_segments, @@ -19,27 +25,34 @@ def run(context: "PipelineContext") -> None: """Run diarization, attach speakers, and apply current overlap cleanup.""" - from providers.diarization import run_diarization - if context.transcription_result is None: raise RuntimeError("ASR stage must run before diarization") - result = run_diarization( - context.pipeline, - context.working_audio_path, - context.transcription_result, - min_speakers=context.request.min_speakers, - max_speakers=context.request.max_speakers, - provider_name=context.request.provider_for("diarization"), + provider = resolve_provider( + "diarization", + context.request.provider_for("diarization"), + ) + result = provider.diarize( + DiarizationRequest( + pipeline=context.pipeline, + audio_path=context.working_audio_path, + transcription_result=context.transcription_result, + min_speakers=context.request.min_speakers, + max_speakers=context.request.max_speakers, + ) ) context.diarization_turns = result.turns context.aligned_segments = result.aligned_segments - context.metadata["diarization"] = { + diarization_metadata = { "turn_count": len(result.turns), "dedup_removed": result.dedup_removed, } - if result.metadata: - context.metadata["diarization"].update(result.metadata) + alignment_metadata = normalize_public_alignment_metadata( + result.metadata.get("alignment") if result.metadata else None + ) + if alignment_metadata: + diarization_metadata["alignment"] = alignment_metadata + context.metadata["diarization"] = diarization_metadata __all__ = [ diff --git a/app/pipeline/stages/diarization/alignment.py b/app/pipeline/stages/diarization/alignment.py index 5d6d9c8..a112cda 100644 --- a/app/pipeline/stages/diarization/alignment.py +++ b/app/pipeline/stages/diarization/alignment.py @@ -1,77 +1,19 @@ -"""Pure helpers for transcription/diarization alignment post-processing.""" +"""Compatibility re-exports for diarization alignment helpers.""" from __future__ import annotations -from typing import Any - -from postprocess.segments import normalize_words - - -def assign_segment_speaker( - seg_start: float, seg_end: float, diarization_turns: list[dict[str, Any]] -) -> str: - """Pick the diarization speaker with the greatest overlap for a segment.""" - seg_mid = (seg_start + seg_end) / 2 - best_speaker = "UNKNOWN" - best_overlap = 0.0 - - for turn in diarization_turns: - overlap_start = max(seg_start, turn["start"]) - overlap_end = min(seg_end, turn["end"]) - overlap = max(0.0, overlap_end - overlap_start) - if overlap > best_overlap: - best_overlap = overlap - best_speaker = turn["speaker"] - - if best_speaker != "UNKNOWN": - return best_speaker - - for turn in diarization_turns: - if turn["start"] <= seg_mid <= turn["end"]: - return turn["speaker"] - - return best_speaker - - -def normalize_segment( - segment: dict[str, Any], diarization_turns: list[dict[str, Any]] -) -> dict[str, Any]: - """Attach a speaker label and normalise optional word timings.""" - seg_start = float(segment.get("start", 0.0)) - seg_end = float(segment.get("end", 0.0)) - result = { - "start": round(seg_start, 3), - "end": round(seg_end, 3), - "text": segment.get("text", "").strip(), - "speaker": assign_segment_speaker(seg_start, seg_end, diarization_turns), - } - - words = normalize_words(segment.get("words")) - if words: - result["words"] = words - - return result - - -def build_aligned_segments( - segments: list[dict[str, Any]], diarization_turns: list[dict[str, Any]] -) -> list[dict[str, Any]]: - """Normalise aligned segments and assign speakers.""" - return [normalize_segment(segment, diarization_turns) for segment in segments] - - -def dedup_short_segments(segments: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Drop consecutive duplicate short segments (backchannel suppression).""" - if not segments: - return segments - - result = [segments[0]] - for segment in segments[1:]: - previous = result[-1] - text = segment.get("text", "").strip() - previous_text = previous.get("text", "").strip() - duration = segment.get("end", 0.0) - segment.get("start", 0.0) - if text and text == previous_text and duration < 2.0 and len(text) <= 4: - continue - result.append(segment) - return result +from postprocess.alignment import ( + assign_segment_speaker, + build_aligned_segments, + dedup_short_segments, + normalize_segment, + normalize_words, +) + +__all__ = [ + "assign_segment_speaker", + "build_aligned_segments", + "dedup_short_segments", + "normalize_segment", + "normalize_words", +] diff --git a/app/pipeline/stages/embedding/__init__.py b/app/pipeline/stages/embedding/__init__.py index 1749dd4..8846e80 100644 --- a/app/pipeline/stages/embedding/__init__.py +++ b/app/pipeline/stages/embedding/__init__.py @@ -4,7 +4,8 @@ from typing import TYPE_CHECKING -from providers.embedding import extract_speaker_embeddings +from pipeline.contracts import SpeakerEmbeddingRequest +from pipeline.registry import resolve_provider if TYPE_CHECKING: from pipeline.contracts import PipelineContext @@ -13,11 +14,13 @@ def run(context: "PipelineContext") -> None: """Extract speaker embeddings after diarization has defined the turns.""" - result = extract_speaker_embeddings( - context.pipeline, - context.embedding_audio_path, - context.diarization_turns, - provider_name=context.request.provider_for("embedding"), + provider = resolve_provider("embedding", context.request.provider_for("embedding")) + result = provider.extract_embeddings( + SpeakerEmbeddingRequest( + pipeline=context.pipeline, + audio_path=context.embedding_audio_path, + diarization_turns=context.diarization_turns, + ) ) context.speaker_embeddings = result.speaker_embeddings context.metadata["embedding"] = { diff --git a/app/pipeline/stages/enhance/__init__.py b/app/pipeline/stages/enhance/__init__.py index ee8bb69..77d86b2 100644 --- a/app/pipeline/stages/enhance/__init__.py +++ b/app/pipeline/stages/enhance/__init__.py @@ -6,7 +6,8 @@ from typing import TYPE_CHECKING from config import DENOISE_MODEL -from providers.enhance import enhance_audio +from pipeline.contracts import AudioEnhancementRequest +from pipeline.registry import resolve_provider if TYPE_CHECKING: from pipeline.contracts import PipelineContext @@ -20,11 +21,13 @@ def run(context: "PipelineContext") -> None: context.request.status_callback("denoising") input_path = Path(context.working_audio_path or context.request.audio_path) - result = enhance_audio( - input_path, - model=context.request.denoise_model, - snr_threshold=context.request.snr_threshold, - provider_name=context.request.provider_for("enhance"), + provider = resolve_provider("enhance", context.request.provider_for("enhance")) + result = provider.enhance( + AudioEnhancementRequest( + wav_path=input_path, + model=context.request.denoise_model, + snr_threshold=context.request.snr_threshold, + ) ) context.working_audio_path = str(result.output_path) diff --git a/app/pipeline/stages/ingest/__init__.py b/app/pipeline/stages/ingest/__init__.py index e59d21f..3608c8a 100644 --- a/app/pipeline/stages/ingest/__init__.py +++ b/app/pipeline/stages/ingest/__init__.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from providers.ingest import run_ingest +from pipeline.registry import resolve_provider if TYPE_CHECKING: from pipeline.contracts import PipelineContext @@ -13,7 +13,5 @@ def run(context: "PipelineContext") -> None: """Seed the pipeline context through the selected ingest provider.""" - run_ingest( - context, - provider_name=context.request.provider_for("ingest"), - ) + provider = resolve_provider("ingest", context.request.provider_for("ingest")) + provider.run(context) diff --git a/app/pipeline/stages/normalize/__init__.py b/app/pipeline/stages/normalize/__init__.py index 3f93d99..f1e64a0 100644 --- a/app/pipeline/stages/normalize/__init__.py +++ b/app/pipeline/stages/normalize/__init__.py @@ -5,7 +5,8 @@ from pathlib import Path from typing import TYPE_CHECKING -from providers.normalize import normalize_audio +from pipeline.contracts import AudioNormalizationRequest +from pipeline.registry import resolve_provider if TYPE_CHECKING: from pipeline.contracts import PipelineContext @@ -18,10 +19,8 @@ def run(context: "PipelineContext") -> None: context.request.status_callback("converting") input_path = Path(context.working_audio_path or context.request.audio_path) - result = normalize_audio( - input_path, - provider_name=context.request.provider_for("normalize"), - ) + provider = resolve_provider("normalize", context.request.provider_for("normalize")) + result = provider.normalize(AudioNormalizationRequest(input_path=input_path)) context.working_audio_path = str(result.normalized_path) if context.request.raw_audio_path is None: diff --git a/app/pipeline/stages/postprocess/__init__.py b/app/pipeline/stages/postprocess/__init__.py index 5fa2c41..cf7cebe 100644 --- a/app/pipeline/stages/postprocess/__init__.py +++ b/app/pipeline/stages/postprocess/__init__.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from providers.postprocess import run_postprocess +from pipeline.registry import resolve_provider if TYPE_CHECKING: from pipeline.contracts import PipelineContext @@ -13,7 +13,8 @@ def run(context: "PipelineContext") -> None: """Reserve a stable boundary for LLM or rule-based transcript cleanup.""" - run_postprocess( - context, - provider_name=context.request.provider_for("postprocess"), + provider = resolve_provider( + "postprocess", + context.request.provider_for("postprocess"), ) + provider.run(context) diff --git a/app/pipeline/stages/punc/__init__.py b/app/pipeline/stages/punc/__init__.py index 8db2407..814cb1b 100644 --- a/app/pipeline/stages/punc/__init__.py +++ b/app/pipeline/stages/punc/__init__.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from providers.punc import run_punc +from pipeline.registry import resolve_provider if TYPE_CHECKING: from pipeline.contracts import PipelineContext @@ -13,7 +13,5 @@ def run(context: "PipelineContext") -> None: """Keep punctuation as an explicit slot for later model substitution.""" - run_punc( - context, - provider_name=context.request.provider_for("punc"), - ) + provider = resolve_provider("punc", context.request.provider_for("punc")) + provider.run(context) diff --git a/app/pipeline/stages/vad/__init__.py b/app/pipeline/stages/vad/__init__.py index 93cbedc..294ecba 100644 --- a/app/pipeline/stages/vad/__init__.py +++ b/app/pipeline/stages/vad/__init__.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from providers.vad import run_vad +from pipeline.registry import resolve_provider if TYPE_CHECKING: from pipeline.contracts import PipelineContext @@ -13,7 +13,5 @@ def run(context: "PipelineContext") -> None: """Capture VAD policy through the selected stable provider.""" - run_vad( - context, - provider_name=context.request.provider_for("vad"), - ) + provider = resolve_provider("vad", context.request.provider_for("vad")) + provider.run(context) diff --git a/app/pipeline/stages/voiceprint_match/__init__.py b/app/pipeline/stages/voiceprint_match/__init__.py index d8da136..bfdc5d4 100644 --- a/app/pipeline/stages/voiceprint_match/__init__.py +++ b/app/pipeline/stages/voiceprint_match/__init__.py @@ -4,7 +4,8 @@ from typing import TYPE_CHECKING -from providers.voiceprint_match import match_speaker_embeddings +from pipeline.contracts import VoiceprintMatchRequest +from pipeline.registry import resolve_provider if TYPE_CHECKING: from pipeline.contracts import PipelineContext @@ -16,11 +17,16 @@ def run(context: "PipelineContext") -> None: if context.request.status_callback is not None: context.request.status_callback("identifying") - result = match_speaker_embeddings( - context.speaker_embeddings, - voiceprint_db=context.request.voiceprint_db, - threshold=context.request.voiceprint_threshold, - provider_name=context.request.provider_for("voiceprint_match"), + provider = resolve_provider( + "voiceprint_match", + context.request.provider_for("voiceprint_match"), + ) + result = provider.match( + VoiceprintMatchRequest( + speaker_embeddings=context.speaker_embeddings, + voiceprint_db=context.request.voiceprint_db, + threshold=context.request.voiceprint_threshold, + ) ) context.voiceprint_matches = result.speaker_map context.metadata["voiceprint_match"] = { diff --git a/app/pipeline/step_keys.py b/app/pipeline/step_keys.py new file mode 100644 index 0000000..7e3ed62 --- /dev/null +++ b/app/pipeline/step_keys.py @@ -0,0 +1,28 @@ +"""Canonical keys for stable pipeline stages and provider selectors.""" + +from __future__ import annotations + +_STEP_ALIASES = { + "input_normalization": "normalize", + "enhancement": "enhance", +} + + +def normalize_token(value: str, *, field_name: str) -> str: + token = value.strip().lower().replace("-", "_") + if not token: + raise ValueError(f"{field_name} must not be empty") + return token + + +def canonical_step_name(step: str) -> str: + """Map compatibility aliases onto the canonical stable step names.""" + + token = normalize_token(step, field_name="step") + return _STEP_ALIASES.get(token, token) + + +__all__ = [ + "canonical_step_name", + "normalize_token", +] diff --git a/app/postprocess/alignment.py b/app/postprocess/alignment.py new file mode 100644 index 0000000..3fd7ca8 --- /dev/null +++ b/app/postprocess/alignment.py @@ -0,0 +1,86 @@ +"""Pure helpers for transcription/diarization alignment post-processing.""" + +from __future__ import annotations + +from typing import Any + +from postprocess.segments import normalize_words + + +def assign_segment_speaker( + seg_start: float, seg_end: float, diarization_turns: list[dict[str, Any]] +) -> str: + """Pick the diarization speaker with the greatest overlap for a segment.""" + seg_mid = (seg_start + seg_end) / 2 + best_speaker = "UNKNOWN" + best_overlap = 0.0 + + for turn in diarization_turns: + overlap_start = max(seg_start, turn["start"]) + overlap_end = min(seg_end, turn["end"]) + overlap = max(0.0, overlap_end - overlap_start) + if overlap > best_overlap: + best_overlap = overlap + best_speaker = turn["speaker"] + + if best_speaker != "UNKNOWN": + return best_speaker + + for turn in diarization_turns: + if turn["start"] <= seg_mid <= turn["end"]: + return turn["speaker"] + + return best_speaker + + +def normalize_segment( + segment: dict[str, Any], diarization_turns: list[dict[str, Any]] +) -> dict[str, Any]: + """Attach a speaker label and normalise optional word timings.""" + seg_start = float(segment.get("start", 0.0)) + seg_end = float(segment.get("end", 0.0)) + result = { + "start": round(seg_start, 3), + "end": round(seg_end, 3), + "text": segment.get("text", "").strip(), + "speaker": assign_segment_speaker(seg_start, seg_end, diarization_turns), + } + + words = normalize_words(segment.get("words")) + if words: + result["words"] = words + + return result + + +def build_aligned_segments( + segments: list[dict[str, Any]], diarization_turns: list[dict[str, Any]] +) -> list[dict[str, Any]]: + """Normalise aligned segments and assign speakers.""" + return [normalize_segment(segment, diarization_turns) for segment in segments] + + +def dedup_short_segments(segments: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Drop consecutive duplicate short segments (backchannel suppression).""" + if not segments: + return segments + + result = [segments[0]] + for segment in segments[1:]: + previous = result[-1] + text = segment.get("text", "").strip() + previous_text = previous.get("text", "").strip() + duration = segment.get("end", 0.0) - segment.get("start", 0.0) + if text and text == previous_text and duration < 2.0 and len(text) <= 4: + continue + result.append(segment) + return result + + +__all__ = [ + "assign_segment_speaker", + "build_aligned_segments", + "dedup_short_segments", + "normalize_segment", + "normalize_words", +] diff --git a/app/providers/__init__.py b/app/providers/__init__.py index 509d331..0a65af4 100644 --- a/app/providers/__init__.py +++ b/app/providers/__init__.py @@ -1,53 +1,103 @@ """Provider entrypoints for pipeline-adjacent implementation slots.""" -from .capabilities import ( - CapabilityMatch, - ProviderCapability, - ProviderCapabilityError, - default_provider_capabilities, - get_provider_capability, - match_provider_capability, -) -from .asr import PipelineMethodASRProvider, default_asr_provider, transcribe_audio -from .artifacts import InMemoryArtifactsProvider, build_pipeline_artifacts -from .diarization import ( - PipelineMethodDiarizationProvider, - default_diarization_provider, - run_diarization, -) -from .embedding import ( - PipelineMethodSpeakerEmbeddingProvider, - default_speaker_embedding_provider, - extract_speaker_embeddings, -) -from .enhance import ( - ConditionalDenoiseEnhancer, - default_audio_enhancer, - default_enhance_provider, - maybe_denoise, -) -from .ingest import DefaultIngestProvider, run_ingest -from .normalize import ( - FFmpegInputNormalizer, - convert_to_wav, - default_input_normalizer, - default_normalize_provider, -) -from .postprocess import DefaultPostprocessProvider, run_postprocess -from .punc import DefaultPunctuationProvider, run_punc -from .vad import DefaultVADProvider, run_vad -from .voiceprint_match import ( - DefaultVoiceprintMatchProvider, - default_voiceprint_match_provider, - match_speaker_embeddings, -) -from pipeline.registry import ( - available_providers, - available_stage_slots, - register_provider, - resolve_provider, - unregister_provider, -) +from __future__ import annotations + +from importlib import import_module +from typing import Any + + +_EXPORTS = { + "CapabilityMatch": ("providers.capabilities", "CapabilityMatch"), + "ProviderCapability": ("providers.capabilities", "ProviderCapability"), + "ProviderCapabilityError": ("providers.capabilities", "ProviderCapabilityError"), + "default_provider_capabilities": ( + "providers.capabilities", + "default_provider_capabilities", + ), + "get_provider_capability": ("providers.capabilities", "get_provider_capability"), + "match_provider_capability": ( + "providers.capabilities", + "match_provider_capability", + ), + "PipelineMethodASRProvider": ("providers.asr", "PipelineMethodASRProvider"), + "default_asr_provider": ("providers.asr", "default_asr_provider"), + "transcribe_audio": ("providers.asr", "transcribe_audio"), + "InMemoryArtifactsProvider": ("providers.artifacts", "InMemoryArtifactsProvider"), + "build_pipeline_artifacts": ("providers.artifacts", "build_pipeline_artifacts"), + "PipelineMethodDiarizationProvider": ( + "providers.diarization", + "PipelineMethodDiarizationProvider", + ), + "default_diarization_provider": ( + "providers.diarization", + "default_diarization_provider", + ), + "run_diarization": ("providers.diarization", "run_diarization"), + "PipelineMethodSpeakerEmbeddingProvider": ( + "providers.embedding", + "PipelineMethodSpeakerEmbeddingProvider", + ), + "default_speaker_embedding_provider": ( + "providers.embedding", + "default_speaker_embedding_provider", + ), + "extract_speaker_embeddings": ( + "providers.embedding", + "extract_speaker_embeddings", + ), + "ConditionalDenoiseEnhancer": ( + "providers.enhance", + "ConditionalDenoiseEnhancer", + ), + "default_audio_enhancer": ("providers.enhance", "default_audio_enhancer"), + "default_enhance_provider": ("providers.enhance", "default_enhance_provider"), + "maybe_denoise": ("providers.enhance", "maybe_denoise"), + "DefaultIngestProvider": ("providers.ingest", "DefaultIngestProvider"), + "run_ingest": ("providers.ingest", "run_ingest"), + "FFmpegInputNormalizer": ("providers.normalize", "FFmpegInputNormalizer"), + "convert_to_wav": ("providers.normalize", "convert_to_wav"), + "default_input_normalizer": ("providers.normalize", "default_input_normalizer"), + "default_normalize_provider": ( + "providers.normalize", + "default_normalize_provider", + ), + "DefaultPostprocessProvider": ( + "providers.postprocess", + "DefaultPostprocessProvider", + ), + "run_postprocess": ("providers.postprocess", "run_postprocess"), + "DefaultPunctuationProvider": ("providers.punc", "DefaultPunctuationProvider"), + "run_punc": ("providers.punc", "run_punc"), + "DefaultVADProvider": ("providers.vad", "DefaultVADProvider"), + "run_vad": ("providers.vad", "run_vad"), + "DefaultVoiceprintMatchProvider": ( + "providers.voiceprint_match", + "DefaultVoiceprintMatchProvider", + ), + "default_voiceprint_match_provider": ( + "providers.voiceprint_match", + "default_voiceprint_match_provider", + ), + "match_speaker_embeddings": ( + "providers.voiceprint_match", + "match_speaker_embeddings", + ), +} + + +def __getattr__(name: str) -> Any: + try: + module_name, attr_name = _EXPORTS[name] + except KeyError as exc: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from exc + value = getattr(import_module(module_name), attr_name) + globals()[name] = value + return value + + +def __dir__() -> list[str]: + return sorted(set(globals()) | set(__all__)) + __all__ = [ "ConditionalDenoiseEnhancer", @@ -64,8 +114,6 @@ "PipelineMethodSpeakerEmbeddingProvider", "ProviderCapability", "ProviderCapabilityError", - "available_providers", - "available_stage_slots", "build_pipeline_artifacts", "convert_to_wav", "default_asr_provider", @@ -82,13 +130,10 @@ "match_provider_capability", "match_speaker_embeddings", "maybe_denoise", - "register_provider", - "resolve_provider", "run_diarization", "run_ingest", "run_postprocess", "run_punc", "run_vad", "transcribe_audio", - "unregister_provider", ] diff --git a/app/providers/_registry.py b/app/providers/_registry.py new file mode 100644 index 0000000..431a37e --- /dev/null +++ b/app/providers/_registry.py @@ -0,0 +1,35 @@ +"""Compatibility guard for direct provider facade helpers. + +Provider selection is owned by :mod:`pipeline.registry`. The ``providers.*`` +facade helpers remain only for legacy direct calls and must not route named +provider selection back through the pipeline registry. +""" + +from __future__ import annotations + +DEFAULT_PROVIDER_NAME = "default" + + +class ProviderFacadeSelectionError(ValueError): + """Raised when a direct provider facade is asked to select a provider.""" + + +def require_default_provider( + step: str, + provider_name: str = DEFAULT_PROVIDER_NAME, +) -> None: + """Reject named provider selection from legacy direct provider facades.""" + + if provider_name != DEFAULT_PROVIDER_NAME: + raise ProviderFacadeSelectionError( + f"providers.{step} facade helpers only support " + f"provider_name={DEFAULT_PROVIDER_NAME!r}; use pipeline.registry or " + "PipelineRunner for named provider selection." + ) + + +__all__ = [ + "DEFAULT_PROVIDER_NAME", + "ProviderFacadeSelectionError", + "require_default_provider", +] diff --git a/app/providers/artifacts/__init__.py b/app/providers/artifacts/__init__.py index 1d855dc..9342dfa 100644 --- a/app/providers/artifacts/__init__.py +++ b/app/providers/artifacts/__init__.py @@ -2,10 +2,8 @@ from __future__ import annotations -from typing import cast - from pipeline.contracts import PipelineContext, PipelineResult -from pipeline.registry import resolve_provider +from providers._registry import require_default_provider from .default import InMemoryArtifactsProvider, default_artifacts_provider @@ -15,11 +13,8 @@ def build_pipeline_artifacts( ) -> PipelineResult: """Build the current in-memory artifact bundle through the provider boundary.""" - provider = cast( - InMemoryArtifactsProvider, - resolve_provider("artifacts", provider_name), - ) - return provider.build(context) + require_default_provider("artifacts", provider_name) + return default_artifacts_provider.build(context) __all__ = [ diff --git a/app/providers/artifacts/default.py b/app/providers/artifacts/default.py index f02edf1..25f2a28 100644 --- a/app/providers/artifacts/default.py +++ b/app/providers/artifacts/default.py @@ -19,6 +19,7 @@ PipelineContext, PipelineResult, build_artifact_manifest, + normalize_public_alignment_metadata, ) @@ -92,7 +93,12 @@ def _build_transcription(self, context: PipelineContext) -> dict | None: guard_report = context.transcription_result.get("hallucination_guard") if guard_report is not None: transcription["asr_hallucination_guard"] = guard_report - alignment_metadata = context.metadata.get("diarization", {}).get("alignment") + diarization_metadata = context.metadata.get("diarization", {}) + alignment_metadata = normalize_public_alignment_metadata( + diarization_metadata.get("alignment") + if isinstance(diarization_metadata, dict) + else None + ) if alignment_metadata: transcription["alignment"] = alignment_metadata if warning is not None: diff --git a/app/providers/asr/__init__.py b/app/providers/asr/__init__.py index ec9ba8e..8b6aea1 100644 --- a/app/providers/asr/__init__.py +++ b/app/providers/asr/__init__.py @@ -2,10 +2,10 @@ from __future__ import annotations -from typing import Any, cast +from typing import Any from pipeline.contracts import ASRProvider, ASRRequest, ASRResult -from pipeline.registry import resolve_provider +from providers._registry import require_default_provider from .default import PipelineMethodASRProvider, default_asr_provider @@ -17,9 +17,10 @@ def transcribe_audio( no_repeat_ngram_size: int | None = None, provider_name: str = "default", ) -> ASRResult: - """Compatibility helper around the selected ASR provider.""" + """Compatibility helper around the default ASR provider.""" - provider = cast(ASRProvider, resolve_provider("asr", provider_name)) + require_default_provider("asr", provider_name) + provider: ASRProvider = default_asr_provider request = ASRRequest( pipeline=pipeline, audio_path=audio_path, diff --git a/app/providers/capabilities.py b/app/providers/capabilities.py index bfb332d..1ea44f0 100644 --- a/app/providers/capabilities.py +++ b/app/providers/capabilities.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Literal -from pipeline.registry import canonical_step_name +from pipeline.step_keys import canonical_step_name, normalize_token ALL_LANGUAGES = "*" StageCriticality = Literal["required", "degradable", "optional"] @@ -16,6 +16,10 @@ class ProviderCapabilityError(RuntimeError): """Raised when a required provider capability is not satisfied.""" +class ProviderCapabilityNotFoundError(ProviderCapabilityError): + """Raised when no static metadata exists for a provider stage/name pair.""" + + @dataclass(frozen=True) class ProviderCapability: stage: str @@ -25,6 +29,7 @@ class ProviderCapability: stage_criticality: StageCriticality = "required" supports_rust_kernel: bool = False failure_policy: FailurePolicy = "hard_fail" + capability: str | None = None @dataclass(frozen=True) @@ -37,6 +42,34 @@ class CapabilityMatch: _DEFAULT_CAPABILITIES: dict[tuple[str, str], ProviderCapability] = { + ("ingest", "default"): ProviderCapability( + stage="ingest", + name="default", + supported_languages=frozenset({ALL_LANGUAGES}), + stage_criticality="required", + failure_policy="hard_fail", + ), + ("normalize", "default"): ProviderCapability( + stage="normalize", + name="default", + supported_languages=frozenset({ALL_LANGUAGES}), + stage_criticality="required", + failure_policy="hard_fail", + ), + ("enhance", "default"): ProviderCapability( + stage="enhance", + name="default", + supported_languages=frozenset({ALL_LANGUAGES}), + stage_criticality="optional", + failure_policy="skip", + ), + ("vad", "default"): ProviderCapability( + stage="vad", + name="default", + supported_languages=frozenset({ALL_LANGUAGES}), + stage_criticality="optional", + failure_policy="skip", + ), ("asr", "default"): ProviderCapability( stage="asr", name="default", @@ -45,12 +78,13 @@ class CapabilityMatch: failure_policy="hard_fail", ), ("alignment", "default"): ProviderCapability( - stage="alignment", + stage="diarization", name="default", supported_languages=frozenset({ALL_LANGUAGES}), disabled_languages=frozenset(), stage_criticality="degradable", failure_policy="skip", + capability="alignment", ), ("diarization", "default"): ProviderCapability( stage="diarization", @@ -70,13 +104,27 @@ class CapabilityMatch: stage="voiceprint_match", name="default", supported_languages=frozenset({ALL_LANGUAGES}), - stage_criticality="required", - failure_policy="hard_fail", + stage_criticality="optional", + failure_policy="skip", + ), + ("punc", "default"): ProviderCapability( + stage="punc", + name="default", + supported_languages=frozenset({ALL_LANGUAGES}), + stage_criticality="optional", + failure_policy="skip", ), ("postprocess", "default"): ProviderCapability( stage="postprocess", name="default", supported_languages=frozenset({ALL_LANGUAGES}), + stage_criticality="optional", + failure_policy="skip", + ), + ("artifacts", "default"): ProviderCapability( + stage="artifacts", + name="default", + supported_languages=frozenset({ALL_LANGUAGES}), stage_criticality="required", failure_policy="hard_fail", ), @@ -84,10 +132,10 @@ class CapabilityMatch: def _normalize_provider_name(name: str) -> str: - token = name.strip().lower().replace("-", "_") - if not token: - raise ProviderCapabilityError("provider name must not be empty") - return token + try: + return normalize_token(name, field_name="provider name") + except ValueError as exc: + raise ProviderCapabilityError(str(exc)) from exc def _normalize_language(language: str | None) -> str | None: @@ -105,7 +153,7 @@ def get_provider_capability(stage: str, name: str = "default") -> ProviderCapabi try: return _DEFAULT_CAPABILITIES[(stage_key, name_key)] except KeyError as exc: - raise ProviderCapabilityError( + raise ProviderCapabilityNotFoundError( f"No provider capability registered for stage={stage_key!r} " f"name={name_key!r}" ) from exc @@ -136,6 +184,8 @@ def match_provider_capability( "provider": capability.name, "criticality": capability.stage_criticality, } + if capability.capability is not None: + metadata["capability"] = capability.capability if normalized_language is not None: metadata["language"] = normalized_language diff --git a/app/providers/diarization/__init__.py b/app/providers/diarization/__init__.py index 242ee45..58a0f77 100644 --- a/app/providers/diarization/__init__.py +++ b/app/providers/diarization/__init__.py @@ -2,14 +2,14 @@ from __future__ import annotations -from typing import Any, cast +from typing import Any from pipeline.contracts import ( DiarizationProvider, DiarizationRequest, DiarizationResult, ) -from pipeline.registry import resolve_provider +from providers._registry import require_default_provider from .default import PipelineMethodDiarizationProvider, default_diarization_provider from .default import ( @@ -27,9 +27,10 @@ def run_diarization( max_speakers: int | None = None, provider_name: str = "default", ) -> DiarizationResult: - """Compatibility helper around the selected diarization provider.""" + """Compatibility helper around the default diarization provider.""" - provider = cast(DiarizationProvider, resolve_provider("diarization", provider_name)) + require_default_provider("diarization", provider_name) + provider: DiarizationProvider = default_diarization_provider request = DiarizationRequest( pipeline=pipeline, audio_path=audio_path, diff --git a/app/providers/diarization/default.py b/app/providers/diarization/default.py index 2afa3ab..14bbec7 100644 --- a/app/providers/diarization/default.py +++ b/app/providers/diarization/default.py @@ -8,22 +8,25 @@ import time from contextlib import contextmanager from collections.abc import Callable +from importlib import import_module from inspect import Parameter, signature from typing import Any from config import ( - WHISPERX_ALIGN_DEVICE, WHISPERX_ALIGN_CACHE_ONLY, + WHISPERX_ALIGN_DEVICE, WHISPERX_ALIGN_DISABLED_LANGUAGES, + WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC, WHISPERX_ALIGN_MODEL_DIR, WHISPERX_ALIGN_MODEL_MAP, ) +from infra.audio import audio_duration_seconds from pipeline.contracts import ( DiarizationProvider, DiarizationRequest, DiarizationResult, ) -from pipeline.stages.diarization.alignment import ( +from postprocess.alignment import ( build_aligned_segments, dedup_short_segments, ) @@ -108,6 +111,33 @@ def _alignment_disabled(language: str) -> bool: ) +def _alignment_duration_budget_metadata( + audio_path: str, + *, + language: str, + model_metadata: str | None, +) -> dict[str, Any] | None: + duration_s = audio_duration_seconds(audio_path) + if ( + WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC > 0 + and duration_s is not None + and duration_s > WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC + ): + return { + "status": "skipped", + "language": language, + "model": model_metadata, + "reason": "duration_budget_exceeded", + "duration_s": round(duration_s, 3), + "max_duration_s": WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC, + "actionable_hint": ( + "Increase WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC only after " + "validating memory headroom for forced alignment." + ), + } + return None + + def _resolve_alignment_device(pipeline) -> str: configured = (WHISPERX_ALIGN_DEVICE or "cpu").strip().lower() if configured in {"pipeline", "asr"}: @@ -162,11 +192,25 @@ def _cache_only_alignment_environment(): "HF_HUB_OFFLINE": os.environ.get("HF_HUB_OFFLINE"), "TRANSFORMERS_OFFLINE": os.environ.get("TRANSFORMERS_OFFLINE"), } + module_flags: list[tuple[object, str, object]] = [] os.environ["HF_HUB_OFFLINE"] = "1" os.environ["TRANSFORMERS_OFFLINE"] = "1" + for module_name, attr_name in ( + ("huggingface_hub.constants", "HF_HUB_OFFLINE"), + ("transformers.utils.hub", "_is_offline_mode"), + ): + try: + module = import_module(module_name) + except Exception: + continue + if hasattr(module, attr_name): + module_flags.append((module, attr_name, getattr(module, attr_name))) + setattr(module, attr_name, True) try: yield finally: + for module, attr_name, value in reversed(module_flags): + setattr(module, attr_name, value) for name, value in previous.items(): if value is None: os.environ.pop(name, None) @@ -249,8 +293,6 @@ def align_diarized_segments_with_metadata( ) -> tuple[list[dict[str, object]], dict[str, Any]]: """Align ASR output and attach diarization speaker labels.""" - import whisperx - segments = transcription_result.get("segments", []) language = _normalise_language(transcription_result.get("language")) model_source = ( @@ -274,7 +316,25 @@ def align_diarized_segments_with_metadata( ) return build_aligned_segments(segments, diarization_turns), metadata + budget_metadata = _alignment_duration_budget_metadata( + audio_path, + language=language, + model_metadata=model_metadata, + ) + if budget_metadata is not None: + logger.warning( + "WhisperX forced alignment skipped for language=%s reason=%s " + "duration_s=%.3f max_duration_s=%.3f", + language, + budget_metadata["reason"], + budget_metadata["duration_s"], + budget_metadata["max_duration_s"], + ) + return build_aligned_segments(segments, diarization_turns), budget_metadata + try: + import whisperx + preflight_message = _torch_preflight_message(language, model_name) if preflight_message: logger.info(preflight_message) diff --git a/app/providers/embedding/__init__.py b/app/providers/embedding/__init__.py index 76a9985..ba35cf1 100644 --- a/app/providers/embedding/__init__.py +++ b/app/providers/embedding/__init__.py @@ -2,14 +2,14 @@ from __future__ import annotations -from typing import Any, cast +from typing import Any from pipeline.contracts import ( SpeakerEmbeddingProvider, SpeakerEmbeddingRequest, SpeakerEmbeddingResult, ) -from pipeline.registry import resolve_provider +from providers._registry import require_default_provider from .default import ( PipelineMethodSpeakerEmbeddingProvider, @@ -24,11 +24,10 @@ def extract_speaker_embeddings( diarization_turns: list[dict[str, Any]], provider_name: str = "default", ) -> SpeakerEmbeddingResult: - """Compatibility helper around the selected embedding provider.""" + """Compatibility helper around the default embedding provider.""" - provider = cast( - SpeakerEmbeddingProvider, resolve_provider("embedding", provider_name) - ) + require_default_provider("embedding", provider_name) + provider: SpeakerEmbeddingProvider = default_speaker_embedding_provider request = SpeakerEmbeddingRequest( pipeline=pipeline, audio_path=audio_path, diff --git a/app/providers/embedding/default.py b/app/providers/embedding/default.py index 86f8b59..d29eaee 100644 --- a/app/providers/embedding/default.py +++ b/app/providers/embedding/default.py @@ -10,7 +10,12 @@ import torch import torchaudio -from config import MAX_EMBED_DURATION, MIN_EMBED_DURATION +from config import ( + EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC, + MAX_EMBED_DURATION, + MIN_EMBED_DURATION, +) +from infra.audio import audio_duration_seconds from pipeline.contracts import ( SpeakerEmbeddingProvider, SpeakerEmbeddingRequest, @@ -20,6 +25,23 @@ logger = logging.getLogger(__name__) +def _should_preload_full_waveform(audio_path: str) -> bool: + duration_s = audio_duration_seconds(audio_path) + if ( + EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC > 0 + and duration_s is not None + and duration_s > EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC + ): + logger.info( + "embedding_full_audio_preload_skipped reason=duration_budget_exceeded " + "duration_s=%.3f max_duration_s=%.3f", + duration_s, + EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC, + ) + return False + return True + + def _load_full_waveform(audio_path: str): """Load normalized audio once with libsndfile to avoid per-turn torch decode.""" @@ -44,13 +66,17 @@ def extract_embeddings_for_turns( """Extract averaged embeddings for each speaker cluster.""" waveform = None - try: - waveform, native_sr = _load_full_waveform(audio_path) - except Exception as exc: - logger.warning( - "Falling back to torchaudio segment loading for embedding audio: %s", - exc, - ) + if _should_preload_full_waveform(audio_path): + try: + waveform, native_sr = _load_full_waveform(audio_path) + except Exception as exc: + logger.warning( + "Falling back to torchaudio segment loading for embedding audio: %s", + exc, + ) + info = torchaudio.info(audio_path) + native_sr = info.sample_rate + else: info = torchaudio.info(audio_path) native_sr = info.sample_rate target_sr = 16000 diff --git a/app/providers/enhance/__init__.py b/app/providers/enhance/__init__.py index 42dad97..50b0f7c 100644 --- a/app/providers/enhance/__init__.py +++ b/app/providers/enhance/__init__.py @@ -3,14 +3,13 @@ from __future__ import annotations from pathlib import Path -from typing import cast from pipeline.contracts import ( AudioEnhancementProvider, AudioEnhancementRequest, AudioEnhancementResult, ) -from pipeline.registry import resolve_provider +from providers._registry import require_default_provider from .default import ( ConditionalDenoiseEnhancer, @@ -25,12 +24,10 @@ def enhance_audio( snr_threshold: float | None = None, provider_name: str = "default", ) -> AudioEnhancementResult: - """Run the selected enhancement provider and return the full contract result.""" + """Run the default enhancement provider and return the full contract result.""" - provider = cast( - AudioEnhancementProvider, - resolve_provider("enhance", provider_name), - ) + require_default_provider("enhance", provider_name) + provider: AudioEnhancementProvider = default_enhance_provider request = AudioEnhancementRequest( wav_path=wav_path, model=model, @@ -45,7 +42,7 @@ def maybe_denoise( snr_threshold: float | None = None, provider_name: str = "default", ) -> Path: - """Compatibility helper around the selected enhance provider.""" + """Compatibility helper around the default enhance provider.""" return enhance_audio( wav_path, diff --git a/app/providers/enhance/default.py b/app/providers/enhance/default.py index cfa1111..fcc9201 100644 --- a/app/providers/enhance/default.py +++ b/app/providers/enhance/default.py @@ -5,7 +5,12 @@ import logging import time -from config import DENOISE_MODEL, DENOISE_SNR_THRESHOLD +from config import ( + DENOISE_MAX_AUDIO_DURATION_SEC, + DENOISE_MODEL, + DENOISE_SNR_THRESHOLD, +) +from infra.audio import audio_duration_seconds from pipeline.contracts import ( AudioEnhancementProvider, AudioEnhancementRequest, @@ -64,6 +69,17 @@ def _estimate_snr(wav_path): return 10.0 * math.log10((speech_rms / noise_rms) ** 2) +def _duration_exceeds_denoise_budget(wav_path) -> tuple[bool, float | None]: + duration_s = audio_duration_seconds(wav_path) + if ( + DENOISE_MAX_AUDIO_DURATION_SEC > 0 + and duration_s is not None + and duration_s > DENOISE_MAX_AUDIO_DURATION_SEC + ): + return True, duration_s + return False, duration_s + + class ConditionalDenoiseEnhancer(AudioEnhancementProvider): """Apply denoising only when configured and warranted by the signal.""" @@ -84,6 +100,23 @@ def enhance(self, request: AudioEnhancementRequest) -> AudioEnhancementResult: ) out_path = request.wav_path.with_suffix(".denoised.wav") + if effective_model in {"deepfilternet", "noisereduce"}: + over_budget, duration_s = _duration_exceeds_denoise_budget(request.wav_path) + if over_budget: + logger.warning( + "Denoise skipped by duration budget " + "model=%s duration_s=%.3f max_duration_s=%.3f", + effective_model, + duration_s, + DENOISE_MAX_AUDIO_DURATION_SEC, + ) + return AudioEnhancementResult( + input_path=request.wav_path, + output_path=request.wav_path, + applied=False, + model=effective_model, + ) + if effective_model == "deepfilternet": import torch import torchaudio diff --git a/app/providers/ingest/__init__.py b/app/providers/ingest/__init__.py index 49226e4..130722a 100644 --- a/app/providers/ingest/__init__.py +++ b/app/providers/ingest/__init__.py @@ -2,19 +2,17 @@ from __future__ import annotations -from typing import cast - from pipeline.contracts import PipelineContext -from pipeline.registry import resolve_provider +from providers._registry import require_default_provider from .default import DefaultIngestProvider, default_ingest_provider def run_ingest(context: PipelineContext, provider_name: str = "default") -> None: - """Apply the selected ingest provider to the shared pipeline context.""" + """Apply the default ingest provider to the shared pipeline context.""" - provider = cast(DefaultIngestProvider, resolve_provider("ingest", provider_name)) - provider.run(context) + require_default_provider("ingest", provider_name) + default_ingest_provider.run(context) __all__ = [ diff --git a/app/providers/kernel_bridge/release_gates.py b/app/providers/kernel_bridge/release_gates.py index e207ec0..3ec5a42 100644 --- a/app/providers/kernel_bridge/release_gates.py +++ b/app/providers/kernel_bridge/release_gates.py @@ -25,13 +25,19 @@ REQUIRED_CI_GATES: Final = frozenset( { - "python_unit_security_tests", + "exact_ref_evidence", + "python_lint_format", + "python_unit_tests", + "python_security_scan", "kernel_bridge_smoke_tests", "rust_fmt", "rust_clippy", "rust_tests", - "rust_wheel_smoke", - "docker_packaging_smoke", + "rust_wheel_build", + "docker_build_with_wheel", + "container_rust_extension_smoke", + "container_healthz_smoke", + "docker_tags_source_ref", "public_release_scan", } ) @@ -113,7 +119,7 @@ class RustKernelReleaseGate: RustKernelReleaseGate( name="status_payload_contract", bridge_function="status_payload_contract", - python_owner="pipeline.contracts.status.build_status_payload", + python_owner="infra.job_status.build_status_payload", rust_owner="voscript_core::contracts::status_payload_contract", rollback=RUST_KERNEL_MODE_ROLLBACK, regression_matrix=( @@ -163,7 +169,7 @@ def validate_release_gate_matrix( if gate.rollback != RUST_KERNEL_MODE_ROLLBACK: gaps.append(f"{gate.name}: rollback must be {RUST_KERNEL_MODE_ROLLBACK}") if gate.public_api_change: - gaps.append(f"{gate.name}: public API change is not allowed in 0.8.4") + gaps.append(f"{gate.name}: public API change is not allowed in 0.8.x") return tuple(gaps) diff --git a/app/providers/normalize/__init__.py b/app/providers/normalize/__init__.py index 687dfcd..14054f6 100644 --- a/app/providers/normalize/__init__.py +++ b/app/providers/normalize/__init__.py @@ -3,14 +3,13 @@ from __future__ import annotations from pathlib import Path -from typing import cast from pipeline.contracts import ( AudioNormalizationRequest, AudioNormalizationResult, InputNormalizationProvider, ) -from pipeline.registry import resolve_provider +from providers._registry import require_default_provider from .default import ( FFmpegInputNormalizer, @@ -22,18 +21,16 @@ def normalize_audio( input_path: Path, provider_name: str = "default" ) -> AudioNormalizationResult: - """Run the selected normalize provider and return the full contract result.""" + """Run the default normalize provider and return the full contract result.""" - provider = cast( - InputNormalizationProvider, - resolve_provider("normalize", provider_name), - ) + require_default_provider("normalize", provider_name) + provider: InputNormalizationProvider = default_normalize_provider request = AudioNormalizationRequest(input_path=input_path) return provider.normalize(request) def convert_to_wav(input_path: Path, provider_name: str = "default") -> Path: - """Compatibility helper around the selected normalize provider.""" + """Compatibility helper around the default normalize provider.""" return normalize_audio(input_path, provider_name=provider_name).normalized_path diff --git a/app/providers/normalize/default.py b/app/providers/normalize/default.py index 5ef6b61..6a90589 100644 --- a/app/providers/normalize/default.py +++ b/app/providers/normalize/default.py @@ -5,10 +5,9 @@ import logging import subprocess -from fastapi import HTTPException - from config import FFMPEG_TIMEOUT_SEC from pipeline.contracts import ( + AudioNormalizationTimeoutError, AudioNormalizationRequest, AudioNormalizationResult, InputNormalizationProvider, @@ -60,7 +59,9 @@ def normalize(self, request: AudioNormalizationRequest) -> AudioNormalizationRes FFMPEG_TIMEOUT_SEC, input_path.name, ) - raise HTTPException(504, f"ffmpeg timed out after {FFMPEG_TIMEOUT_SEC}s") + raise AudioNormalizationTimeoutError( + f"ffmpeg timed out after {FFMPEG_TIMEOUT_SEC}s" + ) return AudioNormalizationResult( source_path=input_path, diff --git a/app/providers/postprocess/__init__.py b/app/providers/postprocess/__init__.py index 987eaf5..b133c72 100644 --- a/app/providers/postprocess/__init__.py +++ b/app/providers/postprocess/__init__.py @@ -2,22 +2,17 @@ from __future__ import annotations -from typing import cast - from pipeline.contracts import PipelineContext -from pipeline.registry import resolve_provider +from providers._registry import require_default_provider from .default import DefaultPostprocessProvider, default_postprocess_provider def run_postprocess(context: PipelineContext, provider_name: str = "default") -> None: - """Apply the selected post-process provider to the shared context.""" + """Apply the default post-process provider to the shared context.""" - provider = cast( - DefaultPostprocessProvider, - resolve_provider("postprocess", provider_name), - ) - provider.run(context) + require_default_provider("postprocess", provider_name) + default_postprocess_provider.run(context) __all__ = [ diff --git a/app/providers/punc/__init__.py b/app/providers/punc/__init__.py index dc4a97b..a16263a 100644 --- a/app/providers/punc/__init__.py +++ b/app/providers/punc/__init__.py @@ -2,19 +2,17 @@ from __future__ import annotations -from typing import cast - from pipeline.contracts import PipelineContext -from pipeline.registry import resolve_provider +from providers._registry import require_default_provider from .default import DefaultPunctuationProvider, default_punc_provider def run_punc(context: PipelineContext, provider_name: str = "default") -> None: - """Apply the selected punctuation provider to the shared context.""" + """Apply the default punctuation provider to the shared context.""" - provider = cast(DefaultPunctuationProvider, resolve_provider("punc", provider_name)) - provider.run(context) + require_default_provider("punc", provider_name) + default_punc_provider.run(context) __all__ = [ diff --git a/app/providers/vad/__init__.py b/app/providers/vad/__init__.py index c442670..4861cf7 100644 --- a/app/providers/vad/__init__.py +++ b/app/providers/vad/__init__.py @@ -2,19 +2,17 @@ from __future__ import annotations -from typing import cast - from pipeline.contracts import PipelineContext -from pipeline.registry import resolve_provider +from providers._registry import require_default_provider from .default import DefaultVADProvider, default_vad_provider def run_vad(context: PipelineContext, provider_name: str = "default") -> None: - """Apply the selected VAD provider to the shared pipeline context.""" + """Apply the default VAD provider to the shared pipeline context.""" - provider = cast(DefaultVADProvider, resolve_provider("vad", provider_name)) - provider.run(context) + require_default_provider("vad", provider_name) + default_vad_provider.run(context) __all__ = [ diff --git a/app/providers/voiceprint_match/__init__.py b/app/providers/voiceprint_match/__init__.py index cabb4fb..25c9fcb 100644 --- a/app/providers/voiceprint_match/__init__.py +++ b/app/providers/voiceprint_match/__init__.py @@ -2,14 +2,14 @@ from __future__ import annotations -from typing import Any, cast +from typing import Any from pipeline.contracts import ( VoiceprintMatchProvider, VoiceprintMatchRequest, VoiceprintMatchResult, ) -from pipeline.registry import resolve_provider +from providers._registry import require_default_provider from .default import DefaultVoiceprintMatchProvider, default_voiceprint_match_provider @@ -20,12 +20,10 @@ def match_speaker_embeddings( threshold: float | None = None, provider_name: str = "default", ) -> VoiceprintMatchResult: - """Compatibility helper around the selected voiceprint matcher.""" + """Compatibility helper around the default voiceprint matcher.""" - provider = cast( - VoiceprintMatchProvider, - resolve_provider("voiceprint_match", provider_name), - ) + require_default_provider("voiceprint_match", provider_name) + provider: VoiceprintMatchProvider = default_voiceprint_match_provider request = VoiceprintMatchRequest( speaker_embeddings=speaker_embeddings, voiceprint_db=voiceprint_db, diff --git a/crates/voscript_core/Cargo.toml b/crates/voscript_core/Cargo.toml index 50b0167..abd7082 100644 --- a/crates/voscript_core/Cargo.toml +++ b/crates/voscript_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "voscript_core" -version = "0.8.4" +version = "0.8.5" edition = "2021" license = "Apache-2.0" publish = false diff --git a/crates/voscript_core/src/lib.rs b/crates/voscript_core/src/lib.rs index 7bfd382..83e9bf1 100644 --- a/crates/voscript_core/src/lib.rs +++ b/crates/voscript_core/src/lib.rs @@ -447,7 +447,7 @@ fn voscript_core(module: &Bound<'_, PyModule>) -> PyResult<()> { mod tests { #[test] fn package_version_is_set() { - assert_eq!(super::PACKAGE_VERSION, "0.8.4"); + assert_eq!(super::PACKAGE_VERSION, "0.8.5"); } #[test] diff --git a/doc/api.en.md b/doc/api.en.md index 417b413..77e4992 100644 --- a/doc/api.en.md +++ b/doc/api.en.md @@ -59,7 +59,7 @@ Form fields: | `min_speakers` | int | Optional, `0` = auto | | `max_speakers` | int | Optional, `0` = auto | | `denoise_model` | string | Optional. Noise reduction backend: `none`, `deepfilternet`, `noisereduce`. When omitted, the server uses `DENOISE_MODEL` (default `none`). Sending `none` explicitly disables denoising for this request only. | -| `snr_threshold` | float | Optional. DeepFilterNet SNR gate threshold (dB) for this request only. When `deepfilternet` is selected, audio at or above this level skips DeepFilterNet. Overrides `DENOISE_SNR_THRESHOLD` (default `10.0`); `noisereduce` does not use this gate. | +| `snr_threshold` | float | Optional. DeepFilterNet SNR gate threshold (dB) for this request only. When `deepfilternet` is selected, audio at or above this level skips DeepFilterNet. Overrides `DENOISE_SNR_THRESHOLD` (default `10.0`); `noisereduce` is not SNR-gated but still respects `DENOISE_MAX_AUDIO_DURATION_SEC`. | | `no_repeat_ngram_size` | int | Optional, default `0` (disabled). When ≥ 3, suppresses n-gram repetitions in the transcript (e.g. "like like like" → "like"). Values < 3 are treated as `0`. Non-integer values return 422. | Response (200): @@ -108,6 +108,8 @@ The partial file is deleted from `data/uploads/`. Lower the cap in - `503 Failed to persist job state — disk error, retry later` - `503 Failed to start background transcription — retry later` +- `503 Transcription data disk free space below admission budget (...)` +- `503 Transcription active/in-flight job budget exceeded (...)` Example: @@ -124,7 +126,7 @@ practice, omit `denoise_model` to inherit `DENOISE_MODEL`, send `denoise_model=none` to disable denoising for one request, and send `snr_threshold` only when this job needs a threshold different from `DENOISE_SNR_THRESHOLD`. That threshold only affects `deepfilternet`; -`noisereduce` runs directly whenever selected. +`noisereduce` is not SNR-gated, but it still respects the service `DENOISE_MAX_AUDIO_DURATION_SEC` duration budget. ### `GET /api/jobs/{id}` — poll a job @@ -204,9 +206,7 @@ practice, omit `denoise_model` to inherit `DENOISE_MODEL`, send }, "alignment": { "status": "succeeded", - "language": "en", "model": null, - "model_source": "whisperx_default", "cache_only": false } } @@ -322,6 +322,17 @@ aggregation fields for UI / downstream consumers: | `unique_speakers` | array[string] | Deduplicated list of speaker names, recalculated from the persisted `segments[].speaker_name` values to reflect the latest manual corrections | | `artifacts` | object | Optional artifact manifest for stable / optional / experimental artifacts; clients must tolerate it being absent | +Unlike `GET /api/jobs/{id}`, this endpoint always reads the persisted result from +disk, so it remains available after restarts and reflects the latest manual +corrections. + +### `GET /api/transcriptions/{tr_id}/audio` — download original uploaded audio + +Returns the original upload for this transcription. The filename comes from the +persisted result's `filename`, and the service only returns an existing file +under `data/uploads/`. Missing transcriptions or original audio files return +404. + ### `GET /api/export/{tr_id}` Query `format=srt | txt | json`. Returns the file as a download. @@ -330,6 +341,7 @@ Query `format=srt | txt | json`. Returns the file as a download. ``` GET /api/voiceprints +GET /api/voiceprints/{speaker_id} POST /api/voiceprints/enroll PUT /api/voiceprints/{speaker_id}/name DELETE /api/voiceprints/{speaker_id} @@ -346,6 +358,11 @@ DELETE /api/voiceprints/{speaker_id} ] ``` +#### `GET /api/voiceprints/{speaker_id}` + +Returns a single enrolled voiceprint. Missing speakers return +`404 Speaker not found`. + #### `POST /api/voiceprints/enroll` > **Note (enroll idempotency)**: `add_speaker` now deduplicates by `name` — re-enrolling a speaker with the same name merges the new embedding into the existing record rather than creating a duplicate. @@ -464,7 +481,7 @@ Errors: | 401 | Missing or wrong API key | | 404 | Unknown tr_id / speaker_id / missing embedding | | 413 | Upload exceeded `MAX_UPLOAD_BYTES` (default 2 GiB) — see `/api/transcribe` | -| 503 | Failed to persist initial `queued` status or failed to start the background transcription thread | +| 503 | Admission rejected the upload before processing, failed to persist initial `queued` status, or failed to start the background transcription thread | | 500 | Server-side exception (check `docker logs voscript`) | | 504 | ffmpeg transcoding timed out (exceeded `FFMPEG_TIMEOUT_SEC`, default 1800 s) | diff --git a/doc/api.zh.md b/doc/api.zh.md index 635544e..5504309 100644 --- a/doc/api.zh.md +++ b/doc/api.zh.md @@ -57,7 +57,7 @@ curl http://localhost:8780/healthz | `min_speakers` | int | 选填,`0` 表示自动 | | `max_speakers` | int | 选填,`0` 表示自动 | | `denoise_model` | string | 选填。降噪后端:`none`、`deepfilternet`、`noisereduce`。省略时使用服务端 `DENOISE_MODEL`(默认 `none`);显式传 `none` 表示只对本次请求关闭降噪。 | -| `snr_threshold` | float | 选填。DeepFilterNet 信噪比门限(dB),仅对本次请求生效。选择 `deepfilternet` 时,音频信噪比达到或超过此值会跳过 DeepFilterNet。覆盖 `DENOISE_SNR_THRESHOLD`(默认 `10.0`);`noisereduce` 不使用该 gate。 | +| `snr_threshold` | float | 选填。DeepFilterNet 信噪比门限(dB),仅对本次请求生效。选择 `deepfilternet` 时,音频信噪比达到或超过此值会跳过 DeepFilterNet。覆盖 `DENOISE_SNR_THRESHOLD`(默认 `10.0`);`noisereduce` 不受 SNR gate 控制,但仍受 `DENOISE_MAX_AUDIO_DURATION_SEC` 限制。 | | `no_repeat_ngram_size` | int | 选填,默认 `0`(不开启)。设置 ≥ 3 时抑制转录中的 n-gram 重复(如「比如比如」→「比如」)。值 < 3 等同于 `0`。非整数返回 422。 | 响应(200): @@ -103,6 +103,8 @@ curl http://localhost:8780/healthz - `503 Failed to persist job state — disk error, retry later` - `503 Failed to start background transcription — retry later` +- `503 Transcription data disk free space below admission budget (...)` +- `503 Transcription active/in-flight job budget exceeded (...)` 示例: @@ -119,7 +121,7 @@ curl -X POST http://localhost:8780/api/transcribe \ `denoise_model` 表示继承 `DENOISE_MODEL`;传 `denoise_model=none` 表示本次请求关闭降噪; 只有当单个任务需要不同门限时才传 `snr_threshold`,它会覆盖 `DENOISE_SNR_THRESHOLD`。这个门限只影响 `deepfilternet`;选择 `noisereduce` 时, -该后端会直接运行。 +该后端不受 SNR gate 控制,但仍受服务级 `DENOISE_MAX_AUDIO_DURATION_SEC` 时长预算限制。 ### `GET /api/jobs/{id}` — 查询任务 @@ -199,9 +201,7 @@ curl -X POST http://localhost:8780/api/transcribe \ }, "alignment": { "status": "succeeded", - "language": "zh", "model": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn", - "model_source": "whisperx_default", "cache_only": false } } @@ -294,6 +294,11 @@ artifact。当前稳定项包括主结果 `result.json` 和每个说话人 clust 与 `GET /api/jobs/{id}` 不同,本端点始终从磁盘读取持久化结果,**进程重启后仍可访问**, 也能反映最新的人工纠错;`/api/jobs/{id}` 优先读内存,内存未命中时才回落到磁盘(见上方注意事项)。 +### `GET /api/transcriptions/{tr_id}/audio` — 下载原始上传音频 + +返回该转录对应的原始上传文件。文件名来自持久化结果中的 `filename`,服务端只返回 +`data/uploads/` 下实际存在的上传文件;转录或原始音频不存在时返回 404。 + ### `GET /api/export/{tr_id}` — 导出 query `format=srt | txt | json`。返回对应格式的下载响应。 @@ -302,6 +307,7 @@ query `format=srt | txt | json`。返回对应格式的下载响应。 ``` GET /api/voiceprints +GET /api/voiceprints/{speaker_id} POST /api/voiceprints/enroll PUT /api/voiceprints/{speaker_id}/name DELETE /api/voiceprints/{speaker_id} @@ -318,6 +324,10 @@ DELETE /api/voiceprints/{speaker_id} ] ``` +#### `GET /api/voiceprints/{speaker_id}` + +返回单个已登记声纹;不存在时返回 `404 Speaker not found`。 + #### `POST /api/voiceprints/enroll` > **注意(enroll 幂等性)**:`add_speaker` 按 `name` 自动去重——同名的二次 enroll 会把新 embedding 合并到已有记录,**不会**再产生重复条目。 @@ -430,7 +440,7 @@ embedding,或源数量少于当前 cohort,后台线程会保留现有 `asnor | 401 | 缺 API key / key 不对 | | 404 | tr_id / speaker_id / embedding 不存在 | | 413 | 上传超过 `MAX_UPLOAD_BYTES`(默认 2 GiB),详见 `/api/transcribe` | -| 503 | 初始 `queued` 状态落盘失败,或后台转录线程启动失败 | +| 503 | admission 在处理前拒绝上传、初始 `queued` 状态落盘失败,或后台转录线程启动失败 | | 500 | 服务端异常(看 `docker logs voscript`) | | 504 | ffmpeg 转码超时(超过 `FFMPEG_TIMEOUT_SEC`,默认 1800 秒) | diff --git a/doc/changelog.en.md b/doc/changelog.en.md index 5257938..01b8dd4 100644 --- a/doc/changelog.en.md +++ b/doc/changelog.en.md @@ -6,6 +6,37 @@ No unreleased changes. +## 0.8.5 — Rust required by default and real deployment validation (2026-06-13) + +### Reliability + +- Changed the code default and Docker Compose default for `RUST_KERNEL_MODE` to + `required`. Selected Rust-backed voiceprint scoring, result post-processing, + and artifact helper paths must now import and run successfully by default; + missing `voscript_core` or bridge call failures fail closed. +- Kept explicit `RUST_KERNEL_MODE=off` as a rollback path for deployments that + intentionally need to return to Python implementations. It is no longer the + default runtime posture. +- Fixed WhisperX alignment cache-only mode when HuggingFace / Transformers + modules were already imported. The cache-only scope now updates their module + offline flags and restores the previous state afterward. + +### CI + +- PR updates now retrigger the Rust foundation heavy gate so follow-up commits + cannot bypass the Rust wheel and Docker packaging smoke. +- Release and Rust-heavy container health smokes explicitly run with + `RUST_KERNEL_MODE=required`. +- E2E no longer falls back to an older completed transcription by default; set + `VOSCRIPT_E2E_ALLOW_FALLBACK=1` only when that compatibility behavior is + intentional. + +### Validation + +- A real deployment host validated the latest image with forced Rust, cache-only + alignment, idle-unload cold reload, missing-Rust negative cases, disk + admission 503, and orphan job recovery. + ## 0.8.4 — Rust kernel foundation and release gates (2026-06-10) ### Features @@ -17,23 +48,27 @@ No unreleased changes. `result.json` primary view, and must treat unknown or missing `artifacts` fields as compatible. - Added the optional Rust kernel bridge foundation and `RUST_KERNEL_MODE`. - The default `off` keeps current Python implementations; `required` makes - selected Rust-backed paths import and execute successfully or fail closed. + In 0.8.4 this was introduced as an opt-in bridge; v0.8.5 changes the + runtime default to `required` so selected Rust-backed paths import and + execute successfully or fail closed unless operators explicitly choose the + Python rollback mode. - Added an optional Rust-backed voiceprint scoring kernel for explicit - `RUST_KERNEL_MODE=required` runs. The default remains Python scoring, and the - public speaker/voiceprint result contract is unchanged. + `RUST_KERNEL_MODE=required` runs. Python scoring stays available through the + rollback mode, and the public speaker/voiceprint result contract is + unchanged. - Added optional Rust-backed result post-processing for explicit - `RUST_KERNEL_MODE=required` runs. The default remains Python post-processing; - result segments keep stable `speaker_label` values, duplicate display names - are disambiguated instead of merged, and `segments[].words` remains optional. + `RUST_KERNEL_MODE=required` runs. Python post-processing stays available + through the rollback mode; result segments keep stable `speaker_label` + values, duplicate display names are disambiguated instead of merged, and + `segments[].words` remains optional. - Added artifact/status/schema helper contracts. Artifact manifests are normalized through public-safe stable / optional / experimental categories, persisted status payloads share one contract helper, and schema versions stay optional-first for legacy `result.json` / `status.json` compatibility. - Added optional Rust-backed artifact manifest helper validation for explicit - `RUST_KERNEL_MODE=required` runs. The default remains Python-owned contract - assembly; selected Rust helper validation must import and run successfully or - fail closed. + `RUST_KERNEL_MODE=required` runs. Python-owned contract assembly stays + available through the rollback mode; selected Rust helper validation must + import and run successfully or fail closed. ### Security diff --git a/doc/changelog.zh.md b/doc/changelog.zh.md index c95227f..446e965 100644 --- a/doc/changelog.zh.md +++ b/doc/changelog.zh.md @@ -6,6 +6,32 @@ 暂无未发布变更。 +## 0.8.5 — Rust 默认必需与实体部署验收 (2026-06-13) + +### 可靠性 + +- 将 `RUST_KERNEL_MODE` 的代码默认值和 Docker Compose 默认值改为 `required`。 + 这意味着被选择的 Rust-backed 声纹计分、结果后处理和 artifact helper + 路径默认必须成功导入并执行;缺少 `voscript_core` 或调用失败时会 fail closed。 +- 保留显式 `RUST_KERNEL_MODE=off` 作为 rollback 路径,用于确认需要临时回退到 + Python 实现的部署,不再作为默认运行口径。 +- 修复 WhisperX alignment cache-only 模式下已导入 HuggingFace / Transformers + 模块仍可能尝试联网的问题。cache-only alignment 现在会同步更新相关模块级 + offline flag,并在退出后恢复原状态。 + +### CI + +- PR 更新会重新触发 Rust foundation heavy gate,避免后续提交绕过 Rust wheel + 与 Docker packaging smoke。 +- release / Rust heavy 的容器 health smoke 显式使用 `RUST_KERNEL_MODE=required`。 +- E2E 默认不再 fallback 到历史已完成转写;只有显式设置 + `VOSCRIPT_E2E_ALLOW_FALLBACK=1` 才允许旧行为。 + +### 验证 + +- 实体部署环境已用最新镜像验证:强制 Rust、cache-only alignment、idle unload + 后重新冷加载、Rust 缺失负向、磁盘 admission 503 和 orphan job recovery。 + ## 0.8.4 — Rust kernel 基础能力与发布门禁 (2026-06-10) ### 功能 @@ -14,21 +40,22 @@ 的类别、角色、文件名、媒体类型和 `speaker_label`,不暴露本地路径、job 运行路径、 host、token 或调试信息。默认客户端仍只需读取 `result.json` 主视图;未知或缺失 `artifacts` 字段必须被视为兼容。 -- 新增可选 Rust kernel bridge 基础能力与 `RUST_KERNEL_MODE`。默认 `off` 保持当前 - Python 实现;显式设为 `required` 时,被选择的 Rust-backed 路径必须可导入并执行, - 否则 fail closed。 +- 新增可选 Rust kernel bridge 基础能力与 `RUST_KERNEL_MODE`。0.8.4 引入时它是 + opt-in bridge;v0.8.5 已将运行默认值改为 `required`,因此被选择的 + Rust-backed 路径必须可导入并执行,否则 fail closed,除非 operator 显式选择 + Python rollback 模式。 - 新增显式 `RUST_KERNEL_MODE=required` 下可选的 Rust-backed 声纹计分 kernel。 - 默认仍使用 Python 计分,公开 speaker / voiceprint 结果契约不变。 + Python 计分仍可通过 rollback 模式使用,公开 speaker / voiceprint 结果契约不变。 - 新增显式 `RUST_KERNEL_MODE=required` 下可选的 Rust-backed 结果后处理。 - 默认仍使用 Python 后处理;结果 segment 继续保留稳定 `speaker_label`, + Python 后处理仍可通过 rollback 模式使用;结果 segment 继续保留稳定 `speaker_label`, 重名展示名只做序号消歧而不合并 speaker,`segments[].words` 仍是可选字段。 - 新增 artifact/status/schema helper contract。Artifact manifest 会通过 stable / optional / experimental 三类 public-safe 结构归一化;持久化 status payload 统一由 contract helper 构建;schema version 继续保持 optional-first,以兼容旧 `result.json` / `status.json`。 - 新增显式 `RUST_KERNEL_MODE=required` 下可选的 Rust-backed artifact - manifest helper 校验。默认仍由 Python 组装 contract;被选择的 Rust - helper 校验必须成功导入并执行,否则 fail closed。 + manifest helper 校验。Python-owned contract assembly 仍可通过 rollback 模式使用; + 被选择的 Rust helper 校验必须成功导入并执行,否则 fail closed。 ### 安全 diff --git a/doc/configuration.en.md b/doc/configuration.en.md index 5b144b9..bf79e3b 100644 --- a/doc/configuration.en.md +++ b/doc/configuration.en.md @@ -2,7 +2,7 @@ [简体中文](./configuration.zh.md) | **English** -This is the public configuration index for VoScript v0.8.4. It covers the +This is the public configuration index for VoScript v0.8.5. It covers the environment variables that the current code reads, the per-request override semantics of `POST /api/transcribe`, and internal defaults that are documented for operators but are not stable public knobs yet. Do not assume a Whisper, @@ -37,10 +37,13 @@ parameters yet. | `CUDA_VISIBLE_DEVICES` | unset | Optional NVIDIA visibility limit. By default this variable is not injected and compose requests every Docker-exposed GPU. Add it only through `docker-compose.override.yml` or another explicit operator env override when you need to restrict the visible GPU set; inside the container, `cuda:0` is the first visible GPU and may not be physical host GPU0. For CPU-only mode, set `DEVICE=cpu`. | | `FFMPEG_TIMEOUT_SEC` | `1800` | ffmpeg conversion timeout in seconds; timeout returns `504`. | | `JOBS_MAX_CACHE` | `200` | In-memory job LRU limit. Evicted completed jobs remain queryable from disk `status.json` / `result.json`. | +| `TRANSCRIPTION_MAX_ACTIVE_JOBS` | `200` | Admission limit for active transcription jobs. Non-terminal jobs such as `queued`, `converting`, `denoising`, `transcribing`, and `identifying` count toward it. When full, the service returns `503` before starting a background thread. Set `0` to disable this budget. | +| `TRANSCRIPTION_MAX_IN_FLIGHT_JOBS` | `4` | Limit for concurrently processing unique audio hashes before GPU/model work. When full, the service returns `503` before starting a background thread. Set `0` to disable this budget. | +| `TRANSCRIPTION_MIN_FREE_DISK_BYTES` | `1073741824` | Minimum free bytes required on the `DATA_DIR` disk after upload save/hash and dedup checks. When the remaining free space is below this budget, the service removes the saved upload and returns `503` before reserving an active job or starting a background thread. Set `0` to disable this budget. | | `MODEL_IDLE_TIMEOUT_SEC` | `180` | GPU model idle-unload timeout, defaulting to 180 seconds (3 minutes). Set `0` to disable idle unload and keep models resident. When enabled, loaded models are released only after the serialized GPU runtime has been idle for this many seconds; on the next reload, ASR, diarization, and embedding each choose the visible CUDA device with the most free memory during their own lazy load. | -| `RUST_KERNEL_MODE` | `off` | Optional Rust-backed provider/kernel mode. `off` keeps Python implementations; `required` makes selected Rust-backed paths import and run successfully or fail closed. The current selected paths are voiceprint scoring, result post-processing, and artifact manifest helper contracts; CI / Docker packaging still validates the Rust extension directly when the runtime default is off. | +| `RUST_KERNEL_MODE` | `required` | Rust-backed provider/kernel paths are required by default. `required` makes selected Rust-backed paths import and run successfully or fail closed; `off` is an explicit rollback to Python implementations. The current selected paths are voiceprint scoring, result post-processing, and artifact manifest helper contracts; CI / Docker packaging validates the Rust extension directly. | -`MODELS_DIR` and `LANGUAGE` are defined in the config module, but v0.8.4's main +`MODELS_DIR` and `LANGUAGE` are defined in the config module, but v0.8.5's main HTTP transcription path does not use them as stable public tuning knobs: Whisper local checkpoint lookup still expects `/models/faster-whisper-`, and default language should be controlled with the request `language` field or @@ -97,7 +100,7 @@ cache is incomplete. Current internal ASR defaults are `beam_size=5`, `vad_filter=True`, `vad_parameters.min_silence_duration_ms=500`, and `condition_on_previous_text=False`. -These do not have env or API fields in v0.8.4. Do not configure nonexistent +These do not have env or API fields in v0.8.5. Do not configure nonexistent variables such as `WHISPER_BEAM_SIZE`, `WHISPER_COMPUTE_TYPE`, or `WHISPER_VAD_*`. ## Denoising @@ -105,14 +108,15 @@ variables such as `WHISPER_BEAM_SIZE`, `WHISPER_COMPUTE_TYPE`, or `WHISPER_VAD_* | Setting | Default | Effect | | --- | --- | --- | | `DENOISE_MODEL` | `none` | Service default backend: `none`, `deepfilternet`, or `noisereduce`. Unknown values log a warning and skip denoising. | -| `DENOISE_SNR_THRESHOLD` | `10.0` | DeepFilterNet SNR gate in dB. When `deepfilternet` is selected, audio estimated at or above this value is skipped to avoid degrading clean recordings; `noisereduce` does not use this gate. | +| `DENOISE_SNR_THRESHOLD` | `10.0` | DeepFilterNet SNR gate in dB. When `deepfilternet` is selected, audio estimated at or above this value is skipped to avoid degrading clean recordings; `noisereduce` is not SNR-gated but still respects `DENOISE_MAX_AUDIO_DURATION_SEC`. | +| `DENOISE_MAX_AUDIO_DURATION_SEC` | `7200` | Full-audio duration cap for optional denoising. Longer audio skips `deepfilternet` / `noisereduce` to avoid loading very long files into memory at once. Set `0` to disable this budget. | | API `denoise_model` | omitted | Omitted means inherit `DENOISE_MODEL`; explicit `none` disables denoising for this job only. | | API `snr_threshold` | omitted | Omitted means inherit `DENOISE_SNR_THRESHOLD`; explicit values override the DeepFilterNet SNR gate for this job only. | -v0.8.4 defaults to `DENOISE_MODEL=none` for clean meeting-recorder audio. Enable +v0.8.5 defaults to `DENOISE_MODEL=none` for clean meeting-recorder audio. Enable `deepfilternet` or `noisereduce` only for noisy environments, either per job or as a service default. If you need clean recordings to be skipped automatically, -use `deepfilternet`; `noisereduce` runs whenever it is selected. +use `deepfilternet`; `noisereduce` is not SNR-gated but still respects `DENOISE_MAX_AUDIO_DURATION_SEC`. ## Diarization and Alignment @@ -125,6 +129,7 @@ use `deepfilternet`; `noisereduce` runs whenever it is selected. | `WHISPERX_ALIGN_MODEL_MAP` | empty | Comma-separated `lang=model` overrides, for example `zh=org/model`. | | `WHISPERX_ALIGN_MODEL_DIR` | empty | Optional alignment model directory; passed through only when the installed WhisperX supports that parameter. | | `WHISPERX_ALIGN_CACHE_ONLY` | `0` | When `1`, requests cache-only alignment model loading, only when supported by the installed WhisperX. | +| `WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC` | `7200` | Full-audio duration cap for WhisperX forced alignment. Longer audio returns `skipped` alignment metadata and does not call `whisperx.load_audio`. Set `0` to disable this budget. | Alignment is optional metadata. On success, results may include `alignment.status=succeeded` and `segments[].words`. If disabled or failed, the @@ -138,6 +143,7 @@ job still completes; `words` may be absent and `alignment` records `skipped` or | `EMBEDDING_DIM` | `256` | Voiceprint vector dimension used for DB and AS-norm cohort shape checks. Do not mix existing stores across dimensions. | | `MIN_EMBED_DURATION` | `1.5` | Diarization turns shorter than this are ignored for speaker embedding extraction. | | `MAX_EMBED_DURATION` | `10.0` | Longer turns are clipped to this window before embedding extraction. | +| `EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC` | `1800` | Full-audio preload cap for speaker embedding. Longer audio skips the single `soundfile.read` preload and reads per diarization turn instead. Set `0` to disable this budget. | Each speaker cluster uses up to the 10 longest usable chunks to produce an averaged embedding. Very short, fragmented, or noisy turns reduce enrollment and @@ -171,7 +177,7 @@ Cohort lifecycle: files to build and save a cohort. - After each enroll / update, the background `cohort-rebuild` thread wakes every 60 seconds and rebuilds after the latest enrollment is at least 30 seconds old. -- v0.8.4 protects larger loaded or persisted cohorts during automatic rebuilds: +- v0.8.5 protects larger loaded or persisted cohorts during automatic rebuilds: clearing transcription results, having only a few embeddings, or having fewer source embeddings than the current cohort will not shrink the cohort automatically. - `POST /api/voiceprints/rebuild-cohort` is an explicit manual rebuild and uses @@ -206,9 +212,9 @@ New fields are added under the optional-field principle. Clients should ignore unknown fields and tolerate missing `words`, `alignment`, `artifacts`, and `warning`. -## v0.8.4 Validation Wording +## v0.8.5 Validation Wording -v0.8.4 has internal live validation covering the optional Rust kernel foundation, +v0.8.5 has internal live validation covering the required-by-default Rust kernel foundation, selected voiceprint scoring, result post-processing, artifact/status helper contracts, Docker packaging smoke, and public release-scan gates. Public documentation records only these behavior categories, not real task names, diff --git a/doc/configuration.zh.md b/doc/configuration.zh.md index 05bd2e5..85101ba 100644 --- a/doc/configuration.zh.md +++ b/doc/configuration.zh.md @@ -2,7 +2,7 @@ **简体中文** | [English](./configuration.en.md) -本文是 VoScript v0.8.4 的公开配置索引,覆盖当前代码已经读取并生效的 +本文是 VoScript v0.8.5 的公开配置索引,覆盖当前代码已经读取并生效的 环境变量、`POST /api/transcribe` 的请求级覆盖语义,以及还没有暴露为稳定 配置项的内部默认值。没有在本文列出的 Whisper / diarization / AS-norm 变量, 不要假定已经可用。 @@ -35,10 +35,13 @@ | `CUDA_VISIBLE_DEVICES` | 未设置 | 可选 NVIDIA 可见卡限制。默认不注入该变量,compose 会请求 Docker 暴露的所有可用 GPU。只有需要把容器限制到某些卡时,才通过 `docker-compose.override.yml` 或显式 operator env 注入;容器内 `cuda:0` 是可见集合的第 0 张,不一定等于宿主物理 GPU0。CPU-only 请设置 `DEVICE=cpu`。 | | `FFMPEG_TIMEOUT_SEC` | `1800` | ffmpeg 转码超时秒数,超时返回 `504`。 | | `JOBS_MAX_CACHE` | `200` | 内存 job LRU 上限;被淘汰的完成任务仍可从磁盘 `status.json` / `result.json` 查询。 | +| `TRANSCRIPTION_MAX_ACTIVE_JOBS` | `200` | 转写接纳的活跃 job 上限,`queued` / `converting` / `denoising` / `transcribing` / `identifying` 等未完成任务计入;达到上限时在启动后台线程前返回 `503`。设为 `0` 可关闭该预算。 | +| `TRANSCRIPTION_MAX_IN_FLIGHT_JOBS` | `4` | 同时在处理中的唯一音频 hash 上限,用于限制排队到 GPU/model 工作前的并发 job 数;达到上限时在启动后台线程前返回 `503`。设为 `0` 可关闭该预算。 | +| `TRANSCRIPTION_MIN_FREE_DISK_BYTES` | `1073741824` | 上传保存、hash 和去重检查之后,`DATA_DIR` 所在磁盘必须保留的最小空闲字节数;低于该预算时会删除已保存 upload,并在 reserve active job 或启动后台线程前返回 `503`。设为 `0` 可关闭该预算。 | | `MODEL_IDLE_TIMEOUT_SEC` | `180` | GPU 模型空闲卸载超时,默认 180 秒(3 分钟)。设为 `0` 可关闭空闲卸载并保持模型常驻。开启后,只有串行 GPU 运行时空闲达到该秒数才释放已加载模型;下一次 reload 时 ASR、diarization 和 embedding 会在各自 lazy load 时分别选择当前可见 CUDA 中空闲显存最多的设备。 | -| `RUST_KERNEL_MODE` | `off` | 可选 Rust-backed provider/kernel 路径开关。`off` 保持 Python 实现;`required` 要求被选择的 Rust-backed 路径可导入并执行,缺失或调用失败时 fail closed。当前被选择的路径是声纹计分、结果后处理和 artifact manifest helper contract;默认关闭时,CI / Docker packaging 仍会直接验证 Rust 扩展。 | +| `RUST_KERNEL_MODE` | `required` | Rust-backed provider/kernel 路径默认必需。`required` 要求被选择的 Rust-backed 路径可导入并执行,缺失或调用失败时 fail closed;`off` 仅作为显式 rollback,回到 Python 实现。当前被选择的路径是声纹计分、结果后处理和 artifact manifest helper contract;CI / Docker packaging 会直接验证 Rust 扩展。 | -`MODELS_DIR` 和 `LANGUAGE` 在配置模块里有定义,但 v0.8.4 的主 HTTP 转写路径 +`MODELS_DIR` 和 `LANGUAGE` 在配置模块里有定义,但 v0.8.5 的主 HTTP 转写路径 没有把它们作为稳定公开调参入口使用:Whisper 本地 checkpoint 查找仍使用 `/models/faster-whisper-`,语言默认请通过请求字段 `language` 控制或留空自动检测。 @@ -90,7 +93,7 @@ Hugging Face snapshot,缓存不完整时再走 Hub。 当前内部 ASR 默认值:`beam_size=5`、`vad_filter=True`、 `vad_parameters.min_silence_duration_ms=500`、`condition_on_previous_text=False`。 -这些值在 v0.8.4 还没有对应 env 或 API 字段;不要写 `WHISPER_BEAM_SIZE`、 +这些值在 v0.8.5 还没有对应 env 或 API 字段;不要写 `WHISPER_BEAM_SIZE`、 `WHISPER_COMPUTE_TYPE`、`WHISPER_VAD_*` 之类未实现配置。 ## 降噪 @@ -98,13 +101,14 @@ Hugging Face snapshot,缓存不完整时再走 Hub。 | 配置 | 默认值 | 作用 | | --- | --- | --- | | `DENOISE_MODEL` | `none` | 服务端默认降噪后端:`none`、`deepfilternet`、`noisereduce`。未知值会记录警告并跳过降噪。 | -| `DENOISE_SNR_THRESHOLD` | `10.0` | DeepFilterNet 的 SNR 门限 dB。选择 `deepfilternet` 时,估算 SNR 大于等于该值会跳过,避免处理干净录音;`noisereduce` 不使用该 gate。 | +| `DENOISE_SNR_THRESHOLD` | `10.0` | DeepFilterNet 的 SNR 门限 dB。选择 `deepfilternet` 时,估算 SNR 大于等于该值会跳过,避免处理干净录音;`noisereduce` 不受 SNR gate 控制,但仍受 `DENOISE_MAX_AUDIO_DURATION_SEC` 限制。 | +| `DENOISE_MAX_AUDIO_DURATION_SEC` | `7200` | 可选降噪的整段音频处理时长上限。超过该秒数时跳过 `deepfilternet` / `noisereduce`,避免把超长音频一次性读入内存;设为 `0` 可关闭该预算。 | | API `denoise_model` | 省略 | 省略表示继承 `DENOISE_MODEL`;显式传 `none` 表示只对本次任务关闭降噪。 | | API `snr_threshold` | 省略 | 省略表示继承 `DENOISE_SNR_THRESHOLD`;显式传值只覆盖本次任务的 DeepFilterNet SNR gate。 | -v0.8.4 默认面向干净会议录音,因此 `DENOISE_MODEL=none`。只有噪声环境才建议按任务 +v0.8.5 默认面向干净会议录音,因此 `DENOISE_MODEL=none`。只有噪声环境才建议按任务 或服务级启用 `deepfilternet` / `noisereduce`。如需“干净录音自动跳过”,请选择 -`deepfilternet`;`noisereduce` 一旦被选择就会运行。 +`deepfilternet`;`noisereduce` 不受 SNR gate 控制,但仍受 `DENOISE_MAX_AUDIO_DURATION_SEC` 限制。 ## Diarization 与 alignment @@ -117,6 +121,7 @@ v0.8.4 默认面向干净会议录音,因此 `DENOISE_MODEL=none`。只有噪 | `WHISPERX_ALIGN_MODEL_MAP` | 空 | 逗号分隔 `lang=model` 覆盖,例如 `zh=org/model`。 | | `WHISPERX_ALIGN_MODEL_DIR` | 空 | 可选 alignment 模型目录;仅在当前 WhisperX 版本支持该参数时透传。 | | `WHISPERX_ALIGN_CACHE_ONLY` | `0` | 为 `1` 时,请求 WhisperX 只使用缓存加载 alignment 模型;仅在当前 WhisperX 版本支持时透传。 | +| `WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC` | `7200` | WhisperX forced alignment 的整段音频处理时长上限。超过该秒数时 alignment 以 `skipped` 元数据返回,不调用 `whisperx.load_audio`;设为 `0` 可关闭该预算。 | alignment 是可选元数据。成功时结果顶层可能包含 `alignment.status=succeeded` 和 `segments[].words`;被显式禁用或加载失败时任务仍会完成,`words` 可能缺失, @@ -130,6 +135,7 @@ alignment 是可选元数据。成功时结果顶层可能包含 `alignment.stat | `EMBEDDING_DIM` | `256` | 声纹向量维度,用于声纹库和 AS-norm cohort 形状校验。不要把不同维度的既有声纹库混用。 | | `MIN_EMBED_DURATION` | `1.5` | 短于该时长的 diarization turn 不参与 speaker embedding。 | | `MAX_EMBED_DURATION` | `10.0` | 长于该时长的 turn 会截断到该窗口后再提取 embedding。 | +| `EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC` | `1800` | speaker embedding 全音频预加载上限。超过该秒数时不做 single `soundfile.read` preload,改用按 turn 分段读取;设为 `0` 可关闭该预算。 | 每个说话人 cluster 最多使用 10 个最长可用片段求平均 embedding。太短、太碎或噪声很大的 turn 会降低登记与识别质量。 @@ -160,7 +166,7 @@ cohort 生命周期: - 否则扫描持久化转写结果和 `emb_*.npy` 构建并保存 cohort。 - 每次 enroll / update 后,后台 `cohort-rebuild` 线程每 60 秒检查一次,在最近一次 enroll 至少过去 30 秒后自动重建。 -- v0.8.4 的后台自动重建会保护更大的已加载或已持久化 cohort:清空转写结果、 +- v0.8.5 的后台自动重建会保护更大的已加载或已持久化 cohort:清空转写结果、 只有少量 embedding,或源数量少于现有 cohort 时,不会自动缩小 cohort。 - `POST /api/voiceprints/rebuild-cohort` 是显式手动重建,仍按当前可用 embedding 立即生成新 cohort。 @@ -188,9 +194,9 @@ cohort 生命周期: 新增字段按可选字段原则扩展;客户端应忽略不认识的字段,并容忍 `words` / `alignment` / `artifacts` / `warning` 缺失。 -## v0.8.4 验证口径 +## v0.8.5 验证口径 -v0.8.4 已用 internal live validation 覆盖:可选 Rust kernel 基础能力、已选择的 +v0.8.5 已用 internal live validation 覆盖:默认必需的 Rust kernel 基础能力、已选择的 声纹计分、结果后处理、artifact/status helper contract、Docker packaging smoke 以及 public release scan gate。公开文档只记录行为类别,不发布真实任务名、 样本名、job id、speaker id、主机、日志或路径。 diff --git a/doc/quickstart.en.md b/doc/quickstart.en.md index 9be0076..2fa60c6 100644 --- a/doc/quickstart.en.md +++ b/doc/quickstart.en.md @@ -198,7 +198,7 @@ A few worth knowing about: | `MODEL_IDLE_TIMEOUT_SEC` | `180` | GPU model idle-unload timeout, defaulting to 180 seconds (3 minutes); set `0` to disable idle unload and keep models resident. When enabled, ASR, diarization, and embedding each reselect the best visible CUDA device on their next lazy load | | `ALLOW_NO_AUTH` | `0` | Set to 1 to suppress the startup warning when no API_KEY is configured (explicitly confirms unauthenticated mode) | | `DENOISE_MODEL` | `none` | Service default denoise backend: `none`, `deepfilternet`, or `noisereduce`; API requests may override it per job | -| `DENOISE_SNR_THRESHOLD` | `10.0` | DeepFilterNet SNR gate in dB; audio at or above this value skips DeepFilterNet when `deepfilternet` is selected; `noisereduce` is not gated | +| `DENOISE_SNR_THRESHOLD` | `10.0` | DeepFilterNet SNR gate in dB; audio at or above this value skips DeepFilterNet when `deepfilternet` is selected; `noisereduce` is not SNR-gated but still respects `DENOISE_MAX_AUDIO_DURATION_SEC` | | `VOICEPRINT_THRESHOLD` | `0.75` | Base raw-cosine voiceprint threshold before per-speaker adaptive adjustment | | `PYANNOTE_MIN_DURATION_OFF` | `0.5` | Pyannote off-turn smoothing, used to reduce over-segmentation of short pauses | | `MIN_EMBED_DURATION` | `1.5` | Minimum diarization turn duration used for speaker embedding extraction | @@ -213,7 +213,7 @@ For `POST /api/transcribe`, omitting `denoise_model` means "use the service default from `DENOISE_MODEL`". Sending `denoise_model=none` is the explicit per-request opt-out. Sending `snr_threshold` always overrides `DENOISE_SNR_THRESHOLD` for that request only, but only affects -`deepfilternet`; `noisereduce` runs whenever selected. +`deepfilternet`; `noisereduce` is not SNR-gated but still respects `DENOISE_MAX_AUDIO_DURATION_SEC`. For every supported setting, the Whisper / ASR parameters that are not exposed as env yet, and AS-norm cohort preservation semantics, see [`configuration.en.md`](./configuration.en.md). diff --git a/doc/quickstart.zh.md b/doc/quickstart.zh.md index a62634e..b7d8089 100644 --- a/doc/quickstart.zh.md +++ b/doc/quickstart.zh.md @@ -172,7 +172,7 @@ HF_ENDPOINT=https://hf-mirror.com | `MODEL_IDLE_TIMEOUT_SEC` | `180` | GPU 模型空闲卸载超时,默认 180 秒(3 分钟);设为 `0` 可关闭空闲卸载并保持模型常驻。开启时,ASR、diarization、embedding 在下一次各自 lazy load 时分别重新选择最佳可见 CUDA 设备 | | `ALLOW_NO_AUTH` | `0` | 设为 1 可在未配置 API_KEY 时抑制启动警告(明确确认无鉴权模式) | | `DENOISE_MODEL` | `none` | 服务端默认降噪后端:`none`、`deepfilternet` 或 `noisereduce`;API 可按单次任务覆盖 | -| `DENOISE_SNR_THRESHOLD` | `10.0` | DeepFilterNet SNR 门限(dB);选择 `deepfilternet` 时,音频信噪比达到或高于该值会跳过 DeepFilterNet;`noisereduce` 不受此 gate 控制 | +| `DENOISE_SNR_THRESHOLD` | `10.0` | DeepFilterNet SNR 门限(dB);选择 `deepfilternet` 时,音频信噪比达到或高于该值会跳过 DeepFilterNet;`noisereduce` 不受 SNR gate 控制,但仍受 `DENOISE_MAX_AUDIO_DURATION_SEC` 限制 | | `VOICEPRINT_THRESHOLD` | `0.75` | raw cosine 声纹匹配基础阈值,实际会按每位说话人自适应调整 | | `PYANNOTE_MIN_DURATION_OFF` | `0.5` | pyannote 停顿合并参数,用于减少短暂停顿导致的过度切分 | | `MIN_EMBED_DURATION` | `1.5` | 提取 speaker embedding 时接受的最短 diarization turn 时长 | @@ -186,7 +186,7 @@ HF_ENDPOINT=https://hf-mirror.com 对 `POST /api/transcribe` 来说,省略 `denoise_model` 表示使用服务端 `DENOISE_MODEL` 默认值;显式传 `denoise_model=none` 才表示本次请求关闭降噪。 显式传 `snr_threshold` 时,会只对本次请求覆盖 `DENOISE_SNR_THRESHOLD`。 -该门限只影响 `deepfilternet`;`noisereduce` 一旦被选择就会运行。 +该门限只影响 `deepfilternet`;`noisereduce` 不受 SNR gate 控制,但仍受 `DENOISE_MAX_AUDIO_DURATION_SEC` 限制。 所有可用配置项、哪些 Whisper / ASR 参数尚未暴露为 env,以及 AS-norm cohort 保护语义,见 [`configuration.zh.md`](./configuration.zh.md)。 diff --git a/doc/security.en.md b/doc/security.en.md index 32b5175..6f8b0f2 100644 --- a/doc/security.en.md +++ b/doc/security.en.md @@ -22,7 +22,7 @@ Treat the service as if it were an internal database. ## Built-in hardening (on by default) -As of 0.8.4 the following protections are in place out of the box: +As of 0.8.5 the following protections are in place out of the box: 1. **Container runs as a non-root user.** The Dockerfile creates an `app` user (uid/gid 1000 by default, overridable via `APP_UID`/ diff --git a/doc/security.zh.md b/doc/security.zh.md index e3e633b..2824b19 100644 --- a/doc/security.zh.md +++ b/doc/security.zh.md @@ -19,7 +19,7 @@ ## 内置的硬化(默认启用) -当前版本(0.8.4)默认开启以下保护: +当前版本(0.8.5)默认开启以下保护: 1. **容器以非 root 用户运行**。Dockerfile 创建 `app` 用户(uid/gid 1000, 可通过 `APP_UID`/`APP_GID` 覆盖),`USER app`。即使服务代码被 RCE, diff --git a/docker-compose.yml b/docker-compose.yml index 289674c..429f60e 100755 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -38,9 +38,12 @@ services: - MAX_UPLOAD_BYTES=${MAX_UPLOAD_BYTES:-2147483648} # Runtime cache and conversion limits. - JOBS_MAX_CACHE=${JOBS_MAX_CACHE:-200} + - TRANSCRIPTION_MAX_ACTIVE_JOBS=${TRANSCRIPTION_MAX_ACTIVE_JOBS:-200} + - TRANSCRIPTION_MAX_IN_FLIGHT_JOBS=${TRANSCRIPTION_MAX_IN_FLIGHT_JOBS:-4} + - TRANSCRIPTION_MIN_FREE_DISK_BYTES=${TRANSCRIPTION_MIN_FREE_DISK_BYTES:-1073741824} - FFMPEG_TIMEOUT_SEC=${FFMPEG_TIMEOUT_SEC:-1800} - MODEL_IDLE_TIMEOUT_SEC=${MODEL_IDLE_TIMEOUT_SEC:-180} - - RUST_KERNEL_MODE=${RUST_KERNEL_MODE:-off} + - RUST_KERNEL_MODE=${RUST_KERNEL_MODE:-required} # CUDA_VISIBLE_DEVICES is intentionally not set by default. The GPU # reservation below requests all Docker-exposed GPUs. To restrict the # visible set, add it in docker-compose.override.yml or another explicit @@ -56,8 +59,9 @@ services: # for noisy environments. See doc/voiceprint-tuning.en.md. - DENOISE_MODEL=${DENOISE_MODEL:-none} # DeepFilterNet SNR gate (dB): audio at or above this threshold skips DeepFilterNet. - # noisereduce is not SNR-gated and runs whenever selected. + # noisereduce is not SNR-gated but still respects duration budget. - DENOISE_SNR_THRESHOLD=${DENOISE_SNR_THRESHOLD:-10.0} + - DENOISE_MAX_AUDIO_DURATION_SEC=${DENOISE_MAX_AUDIO_DURATION_SEC:-7200} # Base cosine-similarity threshold for voiceprint identification (0.0–1.0). # Adaptive relaxation applies per-speaker on top of this base value. - VOICEPRINT_THRESHOLD=${VOICEPRINT_THRESHOLD:-0.75} @@ -66,9 +70,12 @@ services: - EMBEDDING_DIM=${EMBEDDING_DIM:-256} - MIN_EMBED_DURATION=${MIN_EMBED_DURATION:-1.5} - MAX_EMBED_DURATION=${MAX_EMBED_DURATION:-10.0} + - EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC=${EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC:-1800} + - WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC=${WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC:-7200} # Forced alignment controls. Languages are attempted by default; set # WHISPERX_ALIGN_DISABLED_LANGUAGES only as an explicit fallback. - WHISPERX_ALIGN_DISABLED_LANGUAGES=${WHISPERX_ALIGN_DISABLED_LANGUAGES:-} + - WHISPERX_ALIGN_DEVICE=${WHISPERX_ALIGN_DEVICE:-cpu} - WHISPERX_ALIGN_MODEL_MAP=${WHISPERX_ALIGN_MODEL_MAP:-} - WHISPERX_ALIGN_MODEL_DIR=${WHISPERX_ALIGN_MODEL_DIR:-} - WHISPERX_ALIGN_CACHE_ONLY=${WHISPERX_ALIGN_CACHE_ONLY:-0} diff --git a/docs/adr/0012-use-architecture-rings-and-cycle-gates-for-next-version-refactor.md b/docs/adr/0012-use-architecture-rings-and-cycle-gates-for-next-version-refactor.md new file mode 100644 index 0000000..5179d71 --- /dev/null +++ b/docs/adr/0012-use-architecture-rings-and-cycle-gates-for-next-version-refactor.md @@ -0,0 +1,130 @@ +# ADR-0012: 下一版本重构收敛职责边界和循环依赖 + +- 状态:已接受 +- 日期:2026-06-11 +- 范围:VoScript 下一版本架构重构、运行时检查、发布检查、文档/代码漂移检查 + +## 背景 + +本 ADR 使用几个会影响判断的架构词。Architecture ring(架构环)在本项目里指一组有共同职责和依赖方向的代码区域,例如 API/composition、application、pipeline、provider、domain 和 infra。Cycle(循环依赖)指两个或多个模块、层或职责互相依赖,导致修改一处时很难判断谁拥有规则、谁只能消费规则。Gate 指可重复执行的检查,不是人工印象。Boundary(边界)指两个 ring 或两个职责之间的可见接口。Owner(责任方)指某个规则或状态的唯一权威位置。Provider 指某个 pipeline step 的具体后端或模型实现。Lifecycle(生命周期)指 job、model、runtime resource 或 application startup/shutdown 的状态变化。先把这些词说清楚,是为了让后面的重构决策能落到 VoScript 的具体风险上:减少跨层互相调用、重复权威和发布前靠记忆查漏。 + +ADR-0001 已决定 Python 继续拥有 HTTP API、job lifecycle、pipeline runner、配置、模型 lifecycle 和 artifact/result contract,Rust 只作为 provider/kernel 内部实现。ADR-0003 已决定 provider capability 使用静态 metadata 表达。ADR-0010 和 ADR-0011 已分别固定 heavy CI gate 触发策略,以及 `RUST_KERNEL_MODE=required` 默认 fail closed、`off` 显式 rollback 的语义。 + +下一版本重构要解决的问题不是单纯的 Python import cycle。在 VoScript 里,循环重要是因为它会让 API、application、pipeline、provider、infra 和 docs/release 之间的权威来回渗透;消除或收窄循环后,变更风险会从“牵一发影响整条链路”降低为“只影响明确 owner 的边界”。 + +cycle-analysis 重试结果确认:当前架构不能被描述为 cycle-free。已确认的 Python import SCC(strongly connected component,导入图里的强连通分量,表示这些模块在导入层面互相可达)是 `pipeline.contracts -> pipeline.contracts.context -> pipeline.contracts.requests -> pipeline.registry -> pipeline.contracts`。更宽的 layer-level SCC 是 `infra <-> pipeline <-> providers`。同时,本次分析没有发现 package-level cycle、API/application hard import cycle、uncontrolled restart/dedup infinite loop,也没有发现 silent Rust-to-Python fallback loop。这些否定项只能说明对应风险未被当前证据命中,不能抵消已确认的 import SCC 和 layer-level SCC。 + +当前代码中存在多个需要用架构 ring 和 gate 固化的证据: + +- `app/pipeline/registry.py` 是真实 stage/provider registry,定义稳定 stage order 和 provider import path。 +- `app/providers/capabilities.py` 已有静态 capability metadata,但覆盖面仍小于 registry stage/provider surface,且包含 `alignment` 这种不在 registry stage order 中的子能力。 +- `app/pipeline/runner.py` 负责执行 stage order,并把 `request.provider_for(stage_name)` 记录到 `context.metadata["selected_providers"]`,但 runner 还没有在执行前强制调用 capability matching。 +- `app/api/routers/transcriptions.py` 同时拥有 upload、job 查询、transcription 列表/读取、音频下载、speaker reassign 和 export 逻辑,路由文件已经成为多职责入口。 +- `app/providers/normalize/default.py` 和 `app/infra/audio/paths.py` 在 API 层以下直接导入并抛出 FastAPI `HTTPException`,说明 HTTP error 类型已经泄漏到 provider/infra。 +- `app/config.py` 仍允许 `API_KEY` 为空、`CORS_ALLOW_ORIGINS=*`、`MAX_UPLOAD_BYTES=2GB`;`RUST_KERNEL_MODE=required` 是 0.8.5 final-state 默认值,`off` 只作为显式 rollback。前述本地/LAN 体验默认值和 Rust rollback 边界必须通过文档、配置和 admission gate 明确。 +- `app/providers/embedding/default.py`、`app/providers/enhance/default.py` 和 `app/providers/diarization/default.py` 仍在部分路径中加载完整音频或把整段 audio 交给下游库,说明 memory-sensitive provider 需要 explicit bounds,而不能只依赖上传大小。 +- `app/voiceprints/db.py` 先由 Python repository 取出候选,再按 `RUST_KERNEL_MODE=required|off` 决定调用 Rust `voiceprint_score` 或走显式 rollback;`app/voiceprints/repository.py` 仍拥有候选读取,`crates/voscript_core/src/voiceprint.rs` 只负责纯 scoring decision。 +- `app/infra/job_persistence.py` 和 `app/infra/job_status.py` 拥有 persisted job status 读写与 payload contract;`app/pipeline/contracts/schema.py` 仍是 schema contract helper;`app/providers/kernel_bridge/runtime.py` 只是 Rust extension import/call 和 response validation bridge。 +- `.github/workflows/ci.yml`、`.github/workflows/rust-foundation-heavy.yml` 和 `.github/workflows/release.yml` 把 public scan、lint/test/security、Rust wheel/Docker smoke、publish 分散在不同 workflow;发布 gate 需要同一个 exact ref 的自包含证据,而不是拼接过期或不同 ref 的绿灯。 +- `docker-compose.yml`、`.env.example`、`README.md`、`README.en.md`、`doc/api.zh.md`、`doc/api.en.md`、`doc/configuration.zh.md` 和 `doc/configuration.en.md` 共同描述运行配置、API、鉴权、上传上限、Rust mode 和验证口径,必须被当作 public docs/code drift surface。 + +本 ADR 位于 ignored internal architecture docs。它只记录下一版本重构的长期约束和最终状态 contract。 + +## 术语约定 + +Architecture ring(架构环)不是目录装饰,而是判断“谁能依赖谁、谁拥有错误映射、谁拥有 runtime 决策”的依据;修复后风险会从跨层互相调用降低为单向、可审计的依赖。 + +Boundary(边界)指两个 ring 或两个职责之间的可见接口。在 VoScript 里,典型边界包括 route handler 到 usecase、pipeline 到 provider、Python runtime 到 Rust kernel。边界重要是因为越界会把 HTTP error、job lifecycle 或模型细节泄漏给不该拥有它们的代码;边界清楚后,测试和回滚可以按 owner 缩小范围。 + +Owner(责任方)指某个规则或状态的唯一权威位置,例如 `app/pipeline/registry.py` 拥有 stage/provider import registry,Python 拥有 job persistence 和 schema optionality。owner 不清会让同一规则在 router、runner、provider 和 docs 中重复出现;修复后可以通过一个 owner 变更和 gate 检查降低漂移风险。 + +Gate 指可重复执行的检查,不是人工印象。VoScript 的 architecture gate、release gate 和 docs/code drift gate 要把 import graph、forbidden dependency、exact-ref release evidence 和 public docs 同步变成可验证条件;这样可以把发布前风险从“靠记忆查漏”降到“证据缺失即失败”。 + +Provider 指某个 pipeline step 的具体后端或模型实现,例如 ASR、diarization、embedding、punc。provider 只应通过 stage contract 输入输出;如果 provider 反向拥有 job admission、HTTP error 或 thread/disk policy,风险会扩大到整个 runtime。 + +Repository 指领域或基础设施数据访问抽象,不等于 Git repository。在 VoScript 里,`app/voiceprints/repository.py` 这种 repository 应拥有候选读取接口,而不是让 Rust helper、router 或 provider 到处直接读取存储;修复后 candidate fetch 和 scoring decision 的责任边界更清楚。 + +Usecase 指 application 层的一条业务流程,例如 upload admission、job bootstrap、status recovery、export formatting。usecase 重要是因为它把多个底层能力编排成用户可见行为;把 usecase 从 router 中抽出后,HTTP 输入输出和业务流程可以分别测试。 + +Orchestration 指跨模块调度顺序和状态推进,例如 transcription job 如何经过 admission、dedup、pipeline stages、artifact write 和 status update。orchestration 如果散在 router、runner 和 provider 里,会让失败恢复和重试语义不稳定;集中到 application/pipeline owner 后,生命周期风险下降。 + +Adapter 指把外部系统或具体实现接入稳定边界的薄层,例如 filesystem、Rust extension bridge、CUDA/runtime helper。adapter 重要是因为它可以隔离副作用;如果 adapter 反向拥有业务决策,测试会被外部环境拖住,回滚范围也会变大。 + +Service 指对外提供能力的窄接口,不是把任意 helper 都命名成 service。VoScript 应避免恢复 `app/services/*` 这类扁平杂物层;窄 service 必须有明确 owner、输入输出和所在 ring,否则会制造新的结构债。 + +Lifecycle 指 job、model、runtime resource 或 application startup/shutdown 的状态变化。lifecycle 重要是因为 VoScript 同时有 job persistence、GPU/model load、idle unload 和 app lifespan;owner 混乱会造成重复启动、泄漏或错误恢复。 + +Import direction 和 dependency direction 分别指 Python `import` 图的方向和架构职责依赖方向。前者可以由静态脚本扫描,后者还要结合 ring owner 判断;两者都必须单向收敛,才能降低循环依赖和越界调用风险。 + +Facade 指对复杂子系统提供的简化入口。VoScript 只有在 facade 明确隐藏复杂性、且不吞掉 owner 和错误语义时才应使用;否则 facade 会变成新的大 router。 + +DTO 指跨边界传递的简单数据对象。它在 VoScript 中适合表达 API/application/pipeline 之间的稳定输入输出;如果 DTO 混入 persistence、FastAPI 类型或 provider 实现细节,会增加边界漂移风险。 + +Source-guard test 指用源码扫描守住架构规则的测试,例如禁止 API ring 之外导入 `fastapi.HTTPException`、禁止 provider 导入 router、禁止 docs 记录私有路径。它重要是因为这类违规不一定会被单元行为测试覆盖。 + +Structural debt / architecture debt 指当前还能运行、但职责和依赖已经让未来变更更危险的结构问题,例如大 router、多 owner、layer SCC 和 docs/code drift。偿还这类债务的收益不是立刻增加功能,而是降低后续修复、发布和回滚的不确定性。 + +## 决策 + +下一版本重构使用 architecture rings 作为默认边界模型,并把 ring/cycle gate 放入可重复验证链路。ring 是依赖方向、错误类型、配置权威和 runtime admission 的判定依据。 + +Architecture rings 定义如下: + +- API/composition ring:`app/main.py`、`app/api/` 和 FastAPI wiring。只负责 request parsing、dependency wiring、auth、response shaping 和 HTTP error mapping。FastAPI `HTTPException` 只能在本 ring 或显式 composition boundary 中出现。 +- Application ring:`app/application/`。负责 use-case orchestration、job lifecycle、status transition、dedup coordination 和 background execution admission。不得依赖 FastAPI 类型,不得直接拥有 provider/model implementation 细节。 +- Pipeline ring:`app/pipeline/`。负责 stable stage order、`PipelineRequest`、`PipelineContext`、stage result、artifact/schema contract、stage dispatch 和 provider selection boundary。不得依赖 API/router,不得拥有 HTTP error mapping。 +- Provider ring:`app/providers/`。负责每个 pipeline step 的具体 backend/model implementation。provider 必须通过 stage contract 输入输出,不能抛出 HTTP 类型,不能拥有 job/thread/disk admission policy。 +- Domain ring:`app/voiceprints/` 等业务领域模块。负责 speaker enrollment、matching、cohort、scoring policy 和 repository abstraction。domain 可以调用明确的 kernel bridge,但不能把 Rust helper 描述成 domain 或 runtime owner。 +- Infra ring:`app/infra/`。负责 filesystem、hash index、job persistence、runtime semaphore/cache、CUDA device selection、path safety 和 concrete adapters。infra 返回 domain/application 可映射的错误,不返回 HTTP error。 + +Ring gate 必须同时检查两类问题: + +- import cycle gate:用可重复脚本扫描 `app/` Python import graph,禁止新增 SCC,并把已知 SCC 收敛为零。该 gate 的输出必须包含 module count、internal edge count、SCC 列表、layer edge 列表和 layer SCC 列表;同时必须把 AST static import graph 与 registry/import_module runtime dynamic graph 分开报告,不能把 runtime/dynamic edge 伪装成 Python import-time edge;不能依赖一次性人工统计,也不能在没有当前输出时声称具体计数。 +- forbidden dependency gate:按 ring allowlist 检查禁止导入,例如 `fastapi` 不能出现在 API/composition ring 之外,provider/infra/domain 不能导入 router,provider 不能导入 application job orchestration,pipeline 不能导入 API。 + +Package-level graph、API/application hard import path、restart/dedup loop 和 Rust fallback loop 可以作为辅助检查,但不能替代 import cycle gate 和 forbidden dependency gate。没有命中这些辅助风险,只能说明当前证据未发现相应循环,不能证明整体架构 cycle-free。 + +Provider capability metadata 是 provider 可运行性的权威入口。`app/pipeline/registry.py` 仍是 stage/provider import registry,但任何 registry 中可被选择的 provider 都必须有 capability record 或显式 allowlisted exemption。runner 或 runner 前置 preflight 必须在 stage 执行前根据 `PipelineRequest.language`、stage criticality、provider name 和 Rust support 调用 capability matching;required mismatch fail closed,degradable/optional mismatch 只能通过明确 metadata 记录 skip reason。 + +Stage registry 与 capability metadata 必须互相校验。新增 stage、provider、language constraint、Rust-backed provider 或 alignment 子能力时,必须同时更新 registry、capability metadata、capability tests 和 docs/code drift gate。`alignment` 这类子能力可以保留为 stage 内 capability,但必须明确挂靠到拥有它的 registry stage,不能成为第二套隐式 stage order。 + +`PipelineContext` 仍可以作为 stage 间执行状态,但必须被 gate 和 tests 限制为稳定字段、稳定 metadata key 和单向 stage progression。新增 context 字段或 metadata key 需要说明 owner stage、读写 stage、是否进入 public result/status/artifact contract,以及是否允许下游覆盖。不得让任意 stage 通过自由写 `metadata` 反向改变 provider selection、job status、API response 或 artifact schema。 + +`PipelineContext.metadata` 的稳定 top-level key contract 由 pipeline contract 拥有。control keys 固定为 `executed_stages`、`selected_providers`、`provider_capabilities` 和 `stage_timings`;stage keys 固定为 `ingest`、`normalize`、`enhance`、`vad`、`asr`、`diarization`、`embedding`、`voiceprint_match`、`punc`、`postprocess` 和 `artifacts`。唯一公开 alignment metadata path 是 `diarization.alignment`,public result 只能包含 `status`、`reason`、`model`、`duration_s`、`max_duration_s`、`cache_only` 和 `device` 这些 JSON-safe scalar 字段。Architecture gate 必须扫描 production source 中的 `context.metadata` top-level literal key,并拒绝未登记 key 和裸 `context.metadata[...].update(...)` 写法。 + +API/domain boundary cleanup 是下一版本重构的必做项。API ring 以下不得再抛 FastAPI `HTTPException`;provider、domain 和 infra 只能抛出 typed domain/application errors 或返回 typed result,API ring 统一映射为 HTTP status/detail。已有 `app/providers/normalize/default.py` 和 `app/infra/audio/paths.py` 的 HTTPException 泄漏必须被迁移到这个错误映射模型。 + +`app/api/routers/transcriptions.py` 必须瘦身。目标不是机械拆文件,而是让 route handler 只做 HTTP 输入输出和 dependency wiring。upload admission、job bootstrap、status read/recovery、transcription listing、audio lookup、speaker correction、export formatting 和 artifact read/write 应迁到 application/domain/infra 的窄接口;router 可以按 upload/jobs/transcriptions/export/speaker correction 拆分,但拆分后不能复制业务规则。 + +Runtime admission 和 memory bounds 必须成为显式 gate。当前 `MAX_UPLOAD_BYTES`、`UPLOAD_CHUNK`、`JOBS_MAX_CACHE`、serialized GPU semaphore、in-flight dedup 和 idle model unload 已经是基础约束,但下一版本不能继续用 per-request unbounded daemon thread 作为唯一 job admission。接受 job 前必须检查并记录 upload size、durable status write、in-flight/queued job bound、worker/thread bound、data disk pressure 和 configured memory-sensitive stage limits。超出 admission budget 必须在开始 GPU/model work 前返回可预期错误。 + +Application/infra job boundary 必须通过公开窄接口表达。Application ring 可以编排 job lifecycle 和业务可见状态转换,但不得直接读取或修改 infra runtime 的 raw `jobs` store,也不得导入 `infra.job_persistence` 下划线 helper;runtime job get/set/update/pop/snapshot、durable status write 和 transcription record filesystem access 由 infra-owned adapter 提供。`status.json`、`result.json`、transcription directory traversal、uploaded audio path validation 和 JSON atomic write 属于 infra filesystem adapter;application 保留 usecase orchestration、typed application error mapping、speaker reassign 规则和 export formatting。 + +Memory-sensitive provider 必须有 size/duration/window policy。embedding、enhance、diarization/alignment 不能只依赖 2GB upload cap;对 full-audio load、resample、DeepFilterNet/noisereduce、WhisperX `load_audio`、speaker embedding chunking 等路径,必须定义可测试的 duration/sample/frame/memory guard 或 streaming/windowed strategy,并把默认值同步到 configuration docs。 + +Rust boundary wording 必须 truthful。Rust 只能被描述为 selected pure kernel/helper owner:voiceprint scoring decision、postprocess segment shaping、artifact/status helper contract 等。Python 仍拥有 candidate fetch、job persistence、persisted job status payload contract、schema optionality、pipeline runner、provider selection、artifact/result assembly 和 runtime mode。`RUST_KERNEL_MODE=required` 是 0.8.5 final-state 默认业务路径,表示被选中的 Rust-backed path 必须 import/call 成功并 fail closed;`off` 只是显式 rollback,不表示 Rust 拥有整个 runtime。 + +Release gate 必须升级为 exact-ref self-contained release gate。发布镜像或 release artifact 前,必须对同一个 immutable ref/tag/SHA 取得以下证据:public release scan、lint/format、unit/security slice、Rust fmt/clippy/test、Rust wheel build、Docker image build with wheel、container Rust extension smoke、container `/healthz` smoke,以及要发布的 Docker tags/source ref。可以继续把 CI、heavy gate 和 publish 分在多个 workflow,但 release workflow 只能消费同一 exact ref 的不可变成功证据;不能用 stale PR 首轮结果、latest main 结果或手动输入未解析 SHA 的结果替代。 + +Docs/code drift gate 必须覆盖 public runtime surface。修改 `app/config.py`、`docker-compose.yml`、`.env.example`、API router public behavior、status/result/artifact contract、Rust mode、upload/job admission、voiceprint scoring semantics 或 release gate 时,必须校验 `README.md`/`README.en.md`、`doc/api.zh.md`/`doc/api.en.md`、`doc/configuration.zh.md`/`doc/configuration.en.md` 是否同步。public docs 只能记录 released behavior、配置和 API;internal validation wording 只能写行为类别,不能泄漏真实 job id、speaker id、样本、host、日志或路径。 + +## 被拒绝的方案 + +- 只用 import SCC 作为架构健康标准:它能发现一类循环导入,但发现不了 HTTPException 泄漏、provider capability 不权威、router 多职责、mutable context 反向改写、runtime admission 缺口或 release/docs drift。 +- 忽略已确认 SCC、只记录未发现 package-level cycle:这会把证据口径从“当前有导入和层级循环需要治理”错误改写为“整体无循环”,导致 gate 目标失真。 +- 保留大 router 并只在内部加 helper:helper 会降低局部函数长度,但不会把 upload/job/export/speaker correction 的权威移出 API ring,也不能防止业务规则继续在 route handler 中扩散。 +- 让 provider 在运行时动态探测能力并自行决定是否运行:这会把模型加载、环境探测、副作用和 selection policy 混在一起,削弱 ADR-0003 的静态 metadata 决策。 +- 用 best-effort Rust fallback 描述下一版本:这会违反 ADR-0011 的 explicit rollback 语义,也会掩盖 `RUST_KERNEL_MODE=required` 默认 hard-fail 路径与显式 `off` rollback 路径的区别。 +- 发布时只看 release workflow build/push 是否成功:这不能证明同一 exact ref 已通过 public scan、Python tests、Rust wheel/Docker smoke 和 runtime health smoke。 +- 通过人工记忆同步 README、API docs、configuration docs 和 compose/env 示例:public docs/code drift surface 必须有 gate,不能靠维护者事后查漏。 + +## 后果 + +- 下一版本重构需要先落地可重复的 architecture gate,再用它约束后续代码迁移。任何 ring exception 都必须在 gate allowlist 中有 owner、原因和退出条件。 +- 已确认的 `pipeline.contracts -> pipeline.contracts.context -> pipeline.contracts.requests -> pipeline.registry -> pipeline.contracts` 和 `infra <-> pipeline <-> providers` 必须被当作架构风险治理目标,而不是被 package-level 或 API/application 层面的未命中结果掩盖。 +- Provider registry、capability metadata、runner preflight 和 capability tests 会成为同一组 contract;新增 provider 或 stage 时改一处不改另一处应当失败。 +- API 层会变薄,application/domain/infra 的 typed error 和 use-case interface 会增加;这是为了让 HTTP 映射集中、测试更稳定,并避免 FastAPI 类型继续向内层扩散。 +- Runtime admission 会把一部分失败提前到 job accepted 之前;这可能让某些原本排队很久才失败的请求更早返回 4xx/5xx,但能保护线程、磁盘、GPU 和内存预算。 +- Rust 相关文档和 release notes 必须按 selected kernel/helper 精确描述,不能把 optional bridge 或 helper contract 夸大为 runtime rewrite。 +- Release 成本会上升,因为发布必须证明 exact ref 自包含通过;换来的是 release artifact、Docker image、Rust wheel smoke 和 public docs scan 不再跨 ref 拼接。 +- Public docs 和 internal ADR 的边界保持不变:ADR 可以记录架构决策和 gate contract,README/doc 只记录用户可用的 released behavior、配置和 API。 diff --git a/tests/e2e/test_api_core.py b/tests/e2e/test_api_core.py index 39bb3ee..ed249f4 100644 --- a/tests/e2e/test_api_core.py +++ b/tests/e2e/test_api_core.py @@ -31,6 +31,7 @@ POLL_TIMEOUT = int( os.getenv("VOSCRIPT_POLL_TIMEOUT", "600") ) # seconds to wait for a job +ALLOW_FALLBACK = os.getenv("VOSCRIPT_E2E_ALLOW_FALLBACK", "0") == "1" # Bypass any system HTTP proxy so direct connections reach the NAS. _NO_PROXY = {"http": None, "https": None} @@ -166,9 +167,9 @@ def silence_wav(tmp_path_factory): def submitted_job(server_url, silence_wav): """Submit one transcription job and return {'job_id', 'tr_id', 'result'}. - If the new job fails (e.g. due to a server-side pipeline misconfiguration), - fall back to the most recent existing completed transcription so that - schema/lifecycle tests can still run against real data. + The default path is strict: the newly submitted job must complete. Set + VOSCRIPT_E2E_ALLOW_FALLBACK=1 only for local schema debugging against a + known-broken server. """ resp = _upload_wav(silence_wav, {"language": "en"}) assert ( @@ -182,6 +183,8 @@ def submitted_job(server_url, silence_wav): result = _poll_job(job_id) return {"job_id": job_id, "tr_id": job_id, "result": result, "fallback": False} except AssertionError as job_exc: + if not ALLOW_FALLBACK: + raise # New job failed on the server (e.g. pipeline misconfiguration). # Attempt to reuse an existing completed transcription so that # schema/lifecycle tests are not all blocked by a server-side bug. diff --git a/tests/test_security.py b/tests/test_security.py index 9ab47fc..616b746 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -49,10 +49,11 @@ def __enter__(self): self.monkeypatch.chdir(_APP_DIR) for _m in list(sys.modules): - if _m in ("main", "config") or _m.startswith(("api.", "infra.")) or _m in { - "api", - "infra", - }: + if ( + _m in ("main", "config") + or _m.startswith(("api.", "application.", "infra.")) + or _m in {"api", "application", "infra"} + ): del sys.modules[_m] from main import app # noqa: WPS433 — late import on purpose @@ -136,9 +137,9 @@ def test_api_key_required_when_set(monkeypatch): """Without any credentials the middleware must reject with 401.""" with _AuthedClientCtx("s3cret", monkeypatch) as (client, _tmpdir): resp = client.get("/api/transcriptions") - assert ( - resp.status_code == 401 - ), f"expected 401 Unauthorized, got {resp.status_code}: {resp.text}" + assert resp.status_code == 401, ( + f"expected 401 Unauthorized, got {resp.status_code}: {resp.text}" + ) assert resp.headers.get("www-authenticate", "").lower().startswith("bearer") @@ -159,9 +160,9 @@ def test_api_key_accepted_via_bearer(monkeypatch): "/api/transcriptions", headers={"Authorization": "Bearer WRONG"}, ) - assert ( - resp.status_code == 401 - ), f"middleware must 401 on bad Bearer, got {resp.status_code}" + assert resp.status_code == 401, ( + f"middleware must 401 on bad Bearer, got {resp.status_code}" + ) # Correct Bearer → middleware no longer 401s (it may 403 at the # router-level dep, but that's a distinct gate, not a middleware @@ -171,8 +172,7 @@ def test_api_key_accepted_via_bearer(monkeypatch): headers={"Authorization": "Bearer s3cret"}, ) assert resp.status_code != 401, ( - f"Bearer auth was rejected by middleware: " - f"{resp.status_code} {resp.text}" + f"Bearer auth was rejected by middleware: {resp.status_code} {resp.text}" ) # End-to-end: Bearer + X-API-Key together must yield 200. @@ -183,9 +183,9 @@ def test_api_key_accepted_via_bearer(monkeypatch): "X-API-Key": "s3cret", }, ) - assert ( - resp.status_code == 200 - ), f"Bearer + X-API-Key must pass: {resp.status_code} {resp.text}" + assert resp.status_code == 200, ( + f"Bearer + X-API-Key must pass: {resp.status_code} {resp.text}" + ) assert resp.json() == [] @@ -269,9 +269,9 @@ def test_path_traversal_rejected(monkeypatch): # 400/422 — this confirms the rejections above are due to the regex, # not a stray auth failure. resp = client.get("/api/transcriptions/tr_does_not_exist", headers=headers) - assert ( - resp.status_code == 404 - ), f"well-formed unknown id should 404, got {resp.status_code}" + assert resp.status_code == 404, ( + f"well-formed unknown id should 404, got {resp.status_code}" + ) # --------------------------------------------------------------------------- @@ -337,7 +337,7 @@ def test_np_load_allow_pickle_false(monkeypatch): def test_transcribe_sanitizes_filename_and_inflight_deduplicates(monkeypatch): """A control-char filename must be sanitized and duplicate bytes must reuse the first job.""" with _AuthedClientCtx("s3cret", monkeypatch) as (client, tmpdir): - import api.routers.transcriptions as transcriptions + import application.transcription_submission as submission class _FakeThread: def __init__(self, *args, **kwargs): @@ -347,7 +347,7 @@ def __init__(self, *args, **kwargs): def start(self): return None - monkeypatch.setattr(transcriptions, "Thread", _FakeThread) + monkeypatch.setattr(submission, "Thread", _FakeThread) files = {"file": ("../-y\nattack.wav", b"same-bytes", "audio/wav")} first = client.post("/api/transcribe", files=files, headers=_auth_headers()) @@ -356,7 +356,7 @@ def start(self): job_id = body["id"] assert body["status"] == "queued" - sanitized = transcriptions.jobs[job_id]["filename"] + sanitized = submission.jobs[job_id]["filename"] assert sanitized.startswith("-y") assert ".." not in sanitized assert "\n" not in sanitized and "\r" not in sanitized @@ -414,7 +414,9 @@ def test_corrupt_status_json_never_500s(monkeypatch): ), ], ) -def test_corrupt_result_json_returns_controlled_error(monkeypatch, method, path, kwargs): +def test_corrupt_result_json_returns_controlled_error( + monkeypatch, method, path, kwargs +): """Corrupt result.json must not crash read/edit/export endpoints.""" with _AuthedClientCtx("s3cret", monkeypatch) as (client, tmpdir): _seed_result_json(tmpdir, "tr_corrupt", raw_text="{definitely-not-json") diff --git a/tests/test_voiceprint_db.py b/tests/test_voiceprint_db.py index 480efd1..6ec7923 100644 --- a/tests/test_voiceprint_db.py +++ b/tests/test_voiceprint_db.py @@ -39,6 +39,9 @@ def _fresh_db(db_dir: Path): if name == "voiceprints" or name.startswith("voiceprints."): sys.modules.pop(name, None) mod = importlib.import_module("voiceprints.db") + # These legacy DB tests pin Python scoring and AS-norm semantics. Dedicated + # Rust bridge tests in this file opt back into required mode explicitly. + mod.rust_provider_paths_enabled = lambda: False db_dir.mkdir(parents=True, exist_ok=True) return mod.VoiceprintDB(str(db_dir)), mod diff --git a/tests/unit/test_admission.py b/tests/unit/test_admission.py new file mode 100644 index 0000000..2d503fb --- /dev/null +++ b/tests/unit/test_admission.py @@ -0,0 +1,147 @@ +"""Tests for application-level transcription admission policy.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from application.admission import ( + AdmissionBudget, + AdmissionRejectedError, + RuntimeAdmissionSnapshot, + admit_transcription_in_flight, + build_runtime_admission_snapshot, + ensure_transcription_admitted, + find_in_flight_transcription, + release_transcription_admission, + reserve_transcription_admission, +) +import infra.job_runtime as job_runtime + + +def test_transcription_admission_allows_budget_headroom(): + ensure_transcription_admitted( + RuntimeAdmissionSnapshot(active_jobs=1, in_flight_jobs=1), + AdmissionBudget(max_active_jobs=4, max_in_flight_jobs=2), + ) + + +def test_transcription_admission_rejects_active_job_budget(): + with pytest.raises(AdmissionRejectedError) as exc_info: + ensure_transcription_admitted( + RuntimeAdmissionSnapshot(active_jobs=4, in_flight_jobs=0), + AdmissionBudget(max_active_jobs=4, max_in_flight_jobs=2), + ) + + assert exc_info.value.reason == "active_job_budget_exceeded" + assert "active job budget" in str(exc_info.value) + + +def test_transcription_admission_rejects_in_flight_budget(): + with pytest.raises(AdmissionRejectedError) as exc_info: + ensure_transcription_admitted( + RuntimeAdmissionSnapshot(active_jobs=1, in_flight_jobs=2), + AdmissionBudget(max_active_jobs=4, max_in_flight_jobs=2), + ) + + assert exc_info.value.reason == "in_flight_job_budget_exceeded" + assert "in-flight job budget" in str(exc_info.value) + + +def test_zero_or_negative_budget_disables_that_budget(): + ensure_transcription_admitted( + RuntimeAdmissionSnapshot(active_jobs=100, in_flight_jobs=100), + AdmissionBudget(max_active_jobs=0, max_in_flight_jobs=-1), + ) + + +def test_transcription_admission_allows_data_disk_headroom(): + ensure_transcription_admitted( + RuntimeAdmissionSnapshot( + active_jobs=0, + in_flight_jobs=0, + free_disk_bytes=2 * 1024 * 1024 * 1024, + ), + AdmissionBudget( + max_active_jobs=1, + max_in_flight_jobs=1, + min_free_disk_bytes=1024 * 1024 * 1024, + ), + ) + + +def test_transcription_admission_rejects_data_disk_pressure(): + with pytest.raises(AdmissionRejectedError) as exc_info: + ensure_transcription_admitted( + RuntimeAdmissionSnapshot( + active_jobs=0, + in_flight_jobs=0, + free_disk_bytes=512 * 1024 * 1024, + ), + AdmissionBudget( + max_active_jobs=1, + max_in_flight_jobs=1, + min_free_disk_bytes=1024 * 1024 * 1024, + ), + ) + + assert exc_info.value.reason == "data_disk_pressure" + assert "data disk free space" in str(exc_info.value) + + +def test_zero_disk_budget_disables_data_disk_pressure(): + ensure_transcription_admitted( + RuntimeAdmissionSnapshot( + active_jobs=0, + in_flight_jobs=0, + free_disk_bytes=0, + ), + AdmissionBudget( + max_active_jobs=1, + max_in_flight_jobs=1, + min_free_disk_bytes=0, + ), + ) + + +def test_runtime_admission_snapshot_uses_injected_disk_usage(tmp_path, monkeypatch): + monkeypatch.setattr(job_runtime, "_active_job_ids", {"tr_active"}) + monkeypatch.setattr(job_runtime, "_in_flight_hashes", {"sha256:busy": "tr_busy"}) + + snapshot = build_runtime_admission_snapshot( + data_path=tmp_path, + disk_usage=lambda path: SimpleNamespace(free=12345), + ) + + assert snapshot.active_jobs == 1 + assert snapshot.in_flight_jobs == 1 + assert snapshot.free_disk_bytes == 12345 + + +def test_reserve_transcription_admission_uses_atomic_runtime_slot(monkeypatch): + monkeypatch.setattr(job_runtime, "_active_job_ids", set()) + budget = AdmissionBudget(max_active_jobs=1, max_in_flight_jobs=1) + + reserve_transcription_admission("tr_one", budget) + with pytest.raises(AdmissionRejectedError) as exc_info: + reserve_transcription_admission("tr_two", budget) + + assert exc_info.value.reason == "active_job_budget_exceeded" + assert job_runtime.active_job_count() == 1 + assert release_transcription_admission("tr_one") is True + reserve_transcription_admission("tr_two", budget) + + +def test_in_flight_admission_returns_duplicate_before_budget_rejection(monkeypatch): + monkeypatch.setattr(job_runtime, "_in_flight_hashes", {"sha256:one": "tr_one"}) + budget = AdmissionBudget(max_active_jobs=1, max_in_flight_jobs=1) + + assert find_in_flight_transcription("sha256:one") == "tr_one" + duplicate = admit_transcription_in_flight("sha256:one", "tr_two", budget) + with pytest.raises(AdmissionRejectedError) as exc_info: + admit_transcription_in_flight("sha256:two", "tr_two", budget) + + assert duplicate.existing_job_id == "tr_one" + assert duplicate.registered is False + assert exc_info.value.reason == "in_flight_job_budget_exceeded" diff --git a/tests/unit/test_api_route_coverage.py b/tests/unit/test_api_route_coverage.py index 8fa1439..1ee2092 100644 --- a/tests/unit/test_api_route_coverage.py +++ b/tests/unit/test_api_route_coverage.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +import hashlib from pathlib import Path import numpy as np @@ -43,15 +44,22 @@ def _seed_result(transcriptions_dir: Path, tr_id: str, *, filename: str = "audio return result_path +def _record_settings(): + from application import transcription_records as records + + return records.default_record_settings() + + def test_transcription_job_status_fallback_paths(app_client): - import api.routers.transcriptions as router + import application.transcription_submission as submission - router.jobs["tr_memory_done"] = { + settings = _record_settings() + submission.jobs["tr_memory_done"] = { "status": "completed", "filename": "done.wav", "result": {"id": "tr_memory_done"}, } - router.jobs["tr_memory_failed"] = { + submission.jobs["tr_memory_failed"] = { "status": "failed", "filename": "failed.wav", "error": "boom", @@ -65,7 +73,7 @@ def test_transcription_job_status_fallback_paths(app_client): assert failed.status_code == 200 assert failed.json()["error"] == "boom" - completed_dir = router.TRANSCRIPTIONS_DIR / "tr_disk_done" + completed_dir = settings.transcriptions_dir / "tr_disk_done" completed_dir.mkdir(parents=True) (completed_dir / "status.json").write_text( json.dumps({"status": "completed", "filename": "disk.wav"}), @@ -76,7 +84,7 @@ def test_transcription_job_status_fallback_paths(app_client): assert disk_done.status_code == 200 assert disk_done.json()["result"] is None - queued_dir = router.TRANSCRIPTIONS_DIR / "tr_disk_queued" + queued_dir = settings.transcriptions_dir / "tr_disk_queued" queued_dir.mkdir(parents=True) (queued_dir / "status.json").write_text( json.dumps({"status": "queued", "filename": "queued.wav"}), @@ -86,7 +94,7 @@ def test_transcription_job_status_fallback_paths(app_client): assert disk_queued.status_code == 200 assert disk_queued.json()["status"] == "failed" - failed_dir = router.TRANSCRIPTIONS_DIR / "tr_disk_failed" + failed_dir = settings.transcriptions_dir / "tr_disk_failed" failed_dir.mkdir(parents=True) (failed_dir / "status.json").write_text( json.dumps({"status": "failed", "filename": "failed.wav", "error": "bad"}), @@ -98,11 +106,10 @@ def test_transcription_job_status_fallback_paths(app_client): def test_transcription_list_audio_export_and_reassign_paths(app_client): - import api.routers.transcriptions as router - + settings = _record_settings() tr_id = "tr_route_edges" - _seed_result(router.TRANSCRIPTIONS_DIR, tr_id, filename="route_audio.wav") - bad_dir = router.TRANSCRIPTIONS_DIR / "tr_bad_listing" + _seed_result(settings.transcriptions_dir, tr_id, filename="route_audio.wav") + bad_dir = settings.transcriptions_dir / "tr_bad_listing" bad_dir.mkdir(parents=True) (bad_dir / "result.json").write_text("{bad-json", encoding="utf-8") @@ -115,7 +122,7 @@ def test_transcription_list_audio_export_and_reassign_paths(app_client): missing_audio = app_client.get(f"/api/transcriptions/{tr_id}/audio") assert missing_audio.status_code == 404 - (router.UPLOADS_DIR / "route_audio.wav").write_bytes(b"audio") + (settings.uploads_dir / "route_audio.wav").write_bytes(b"audio") audio = app_client.get(f"/api/transcriptions/{tr_id}/audio") assert audio.status_code == 200 assert audio.content == b"audio" @@ -169,7 +176,9 @@ def get_speaker(self, speaker_id): ) assert cleared.status_code == 200 - result = json.loads((router.TRANSCRIPTIONS_DIR / tr_id / "result.json").read_text()) + result = json.loads( + (settings.transcriptions_dir / tr_id / "result.json").read_text() + ) assert result["segments"][0]["speaker_id"] == "spk_known" assert result["segments"][1]["speaker_id"] is None assert result["unique_speakers"] == ["Maple"] @@ -181,6 +190,180 @@ def get_speaker(self, speaker_id): assert missing_segment.status_code == 404 +def test_transcribe_rejects_before_background_thread_when_admission_budget_full( + app_client, + monkeypatch, +): + import application.transcription_submission as submission + import infra.job_runtime as job_runtime + + settings = _record_settings() + monkeypatch.setattr( + submission, + "default_submission_settings", + lambda: submission.TranscriptionSubmissionSettings( + max_upload_bytes=1024 * 1024, + upload_chunk=1024, + max_active_jobs=1, + max_in_flight_jobs=1, + uploads_dir=settings.uploads_dir, + transcriptions_dir=settings.transcriptions_dir, + ), + ) + monkeypatch.setattr(job_runtime, "_active_job_ids", {"tr_busy"}) + monkeypatch.setattr(job_runtime, "_in_flight_hashes", {}) + + started = [] + + class FailingThread: + def __init__(self, *args, **kwargs): + started.append(("created", args, kwargs)) + + def start(self): + started.append(("started",)) + + monkeypatch.setattr(submission, "Thread", FailingThread) + + response = app_client.post( + "/api/transcribe", + files={"file": ("budget.wav", b"RIFF\x00\x00\x00\x00WAVEfmt ", "audio/wav")}, + data={"language": "en"}, + ) + + assert response.status_code == 503 + assert "active job budget" in response.text + assert started == [] + + +def test_transcribe_maps_invalid_no_repeat_ngram_size_to_422( + app_client, + monkeypatch, +): + import api.routers.transcriptions as router + + async def fail_submission(*args, **kwargs): + raise AssertionError("invalid form input should not reach submission usecase") + + monkeypatch.setattr(router, "submit_transcription_upload", fail_submission) + + response = app_client.post( + "/api/transcribe", + files={"file": ("bad-param.wav", b"RIFF\x00\x00\x00\x00WAVEfmt ", "audio/wav")}, + data={"language": "en", "no_repeat_ngram_size": "banana"}, + ) + + assert response.status_code == 422 + assert response.json()["detail"] == [ + { + "loc": ["body", "no_repeat_ngram_size"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] + + +def test_transcribe_rejects_in_flight_budget_after_durable_bootstrap( + app_client, + monkeypatch, +): + import application.transcription_submission as submission + import infra.job_runtime as job_runtime + + settings = _record_settings() + monkeypatch.setattr(job_runtime, "_active_job_ids", set()) + monkeypatch.setattr( + submission, + "default_submission_settings", + lambda: submission.TranscriptionSubmissionSettings( + max_upload_bytes=1024 * 1024, + upload_chunk=1024, + max_active_jobs=1, + max_in_flight_jobs=1, + uploads_dir=settings.uploads_dir, + transcriptions_dir=settings.transcriptions_dir, + ), + ) + monkeypatch.setattr(job_runtime, "_in_flight_hashes", {"sha256:busy": "tr_busy"}) + + started = [] + + class FailingThread: + def __init__(self, *args, **kwargs): + started.append(("created", args, kwargs)) + + def start(self): + started.append(("started",)) + + monkeypatch.setattr(submission, "Thread", FailingThread) + + response = app_client.post( + "/api/transcribe", + files={"file": ("busy.wav", b"RIFF\x00\x00\x00\x00WAVEfmt ", "audio/wav")}, + data={"language": "en"}, + ) + + assert response.status_code == 503 + assert "in-flight job budget" in response.text + assert started == [] + assert not list(settings.transcriptions_dir.glob("tr_*")) + assert not list(settings.uploads_dir.glob("tr_*")) + assert job_runtime.active_job_count() == 0 + + +def test_transcribe_reuses_duplicate_in_flight_when_budgets_are_full( + app_client, + monkeypatch, +): + import application.transcription_submission as submission + import infra.job_runtime as job_runtime + + settings = _record_settings() + audio = b"RIFF\x00\x00\x00\x00WAVEfmt duplicate" + file_hash = hashlib.sha256(audio).hexdigest() + monkeypatch.setattr( + submission, + "default_submission_settings", + lambda: submission.TranscriptionSubmissionSettings( + max_upload_bytes=1024 * 1024, + upload_chunk=1024, + max_active_jobs=1, + max_in_flight_jobs=1, + uploads_dir=settings.uploads_dir, + transcriptions_dir=settings.transcriptions_dir, + ), + ) + monkeypatch.setattr(job_runtime, "_active_job_ids", {"tr_existing"}) + monkeypatch.setattr(job_runtime, "_in_flight_hashes", {file_hash: "tr_existing"}) + + started = [] + + class FailingThread: + def __init__(self, *args, **kwargs): + started.append(("created", args, kwargs)) + + def start(self): + started.append(("started",)) + + monkeypatch.setattr(submission, "Thread", FailingThread) + + response = app_client.post( + "/api/transcribe", + files={"file": ("duplicate.wav", audio, "audio/wav")}, + data={"language": "en"}, + ) + + assert response.status_code == 200 + assert response.json() == { + "id": "tr_existing", + "status": "queued", + "deduplicated": True, + } + assert started == [] + assert not list(settings.transcriptions_dir.glob("tr_*")) + assert not list(settings.uploads_dir.glob("tr_*")) + assert job_runtime.active_job_count() == 1 + + def test_voiceprint_management_routes(app_client): import api.routers.voiceprints as router @@ -226,6 +409,28 @@ def build_cohort_from_transcriptions(self, transcriptions_dir): fake_db = FakeDB() app_client.app.state.db = fake_db + invalid_tr_id = app_client.post( + "/api/voiceprints/enroll", + data={ + "tr_id": "../bad", + "speaker_label": "SPEAKER_00", + "speaker_name": "Maple", + }, + ) + assert invalid_tr_id.status_code == 400 + assert "Invalid transcription ID format" in invalid_tr_id.text + + invalid_speaker_label = app_client.post( + "/api/voiceprints/enroll", + data={ + "tr_id": "tr_voiceprint", + "speaker_label": "../bad", + "speaker_name": "Maple", + }, + ) + assert invalid_speaker_label.status_code == 400 + assert "Invalid speaker label" in invalid_speaker_label.text + missing = app_client.post( "/api/voiceprints/enroll", data={ diff --git a/tests/unit/test_architecture_gates.py b/tests/unit/test_architecture_gates.py new file mode 100644 index 0000000..2f705f0 --- /dev/null +++ b/tests/unit/test_architecture_gates.py @@ -0,0 +1,906 @@ +"""Source-level architecture gates for import direction constraints.""" + +from __future__ import annotations + +import ast +import importlib.util +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +APP_ROOT = REPO_ROOT / "app" +NON_API_RING_ROOTS = ( + "application", + "pipeline", + "providers", + "infra", + "voiceprints", + "postprocess", +) +ARCHITECTURE_GATE = REPO_ROOT / "voscript-api" / "scripts" / "architecture_gate.py" + + +def _load_architecture_gate(): + spec = importlib.util.spec_from_file_location( + "voscript_architecture_gate", + ARCHITECTURE_GATE, + ) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _module_name(path: Path) -> tuple[str, bool]: + relative = path.relative_to(APP_ROOT).with_suffix("") + parts = relative.parts + if parts[-1] == "__init__": + return ".".join(parts[:-1]), True + return ".".join(parts), False + + +def _app_modules() -> dict[str, Path]: + modules: dict[str, Path] = {} + for path in sorted(APP_ROOT.rglob("*.py")): + module_name, _ = _module_name(path) + modules[module_name] = path + return modules + + +def _resolve_relative_import( + current_module: str, + is_package: bool, + *, + level: int, + module: str | None, +) -> str: + package_parts = ( + current_module.split(".") if is_package else current_module.split(".")[:-1] + ) + prefix = package_parts[: len(package_parts) - level + 1] + if module: + prefix.extend(module.split(".")) + return ".".join(prefix) + + +def _strip_app_prefix(module_name: str) -> str: + if module_name == "app": + return "" + if module_name.startswith("app."): + return module_name.removeprefix("app.") + return module_name + + +def _internal_module_for(module_name: str, modules: set[str]) -> str | None: + candidate = _strip_app_prefix(module_name) + if not candidate: + return None + parts = candidate.split(".") + for end in range(len(parts), 0, -1): + prefix = ".".join(parts[:end]) + if prefix in modules: + return prefix + return None + + +class _ImportCollector(ast.NodeVisitor): + def __init__( + self, + *, + current_module: str, + is_package: bool, + modules: set[str], + ) -> None: + self.current_module = current_module + self.is_package = is_package + self.modules = modules + self.targets: set[str] = set() + self._type_checking_depth = 0 + + def visit_If(self, node: ast.If) -> None: + if _is_type_checking_guard(node.test): + self._type_checking_depth += 1 + for child in node.body: + self.visit(child) + self._type_checking_depth -= 1 + for child in node.orelse: + self.visit(child) + return + self.generic_visit(node) + + def visit_Import(self, node: ast.Import) -> None: + if self._type_checking_depth: + return + for alias in node.names: + self._add_internal_target(alias.name) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if self._type_checking_depth: + return + if node.level: + base = _resolve_relative_import( + self.current_module, + self.is_package, + level=node.level, + module=node.module, + ) + else: + base = node.module or "" + + if base: + self._add_internal_target(base) + for alias in node.names: + if alias.name == "*": + continue + target = f"{base}.{alias.name}" if base else alias.name + self._add_internal_target(target) + + def _add_internal_target(self, module_name: str) -> None: + target = _internal_module_for(module_name, self.modules) + if target is not None and target != self.current_module: + self.targets.add(target) + + +def _is_type_checking_guard(node: ast.expr) -> bool: + if isinstance(node, ast.Name): + return node.id == "TYPE_CHECKING" + if isinstance(node, ast.Attribute): + return node.attr == "TYPE_CHECKING" + return False + + +def _static_internal_import_graph() -> dict[str, set[str]]: + """Return AST-visible app import edges; dynamic import strings are excluded.""" + + module_paths = _app_modules() + modules = set(module_paths) + graph: dict[str, set[str]] = {module: set() for module in modules} + for module, path in module_paths.items(): + _, is_package = _module_name(path) + collector = _ImportCollector( + current_module=module, + is_package=is_package, + modules=modules, + ) + collector.visit(ast.parse(path.read_text(encoding="utf-8"), filename=str(path))) + graph[module].update(collector.targets) + return graph + + +def _runtime_dynamic_edge_keys(report: dict) -> set[tuple[str, str, str]]: + return { + (edge["source"], edge["target"], edge["kind"]) + for edge in report["runtime_dynamic_import_graph"]["edges"] + } + + +def _runtime_dynamic_forbidden_keys(report: dict) -> set[tuple[str, str, str]]: + return { + (finding["rule"], finding["module"], finding["target"]) + for finding in report["runtime_dynamic_forbidden_dependencies"] + } + + +def _strongly_connected_components(graph: dict[str, set[str]]) -> list[tuple[str, ...]]: + index = 0 + indexes: dict[str, int] = {} + lowlinks: dict[str, int] = {} + stack: list[str] = [] + on_stack: set[str] = set() + components: list[tuple[str, ...]] = [] + + def strongconnect(node: str) -> None: + nonlocal index + indexes[node] = index + lowlinks[node] = index + index += 1 + stack.append(node) + on_stack.add(node) + + for target in sorted(graph[node]): + if target not in indexes: + strongconnect(target) + lowlinks[node] = min(lowlinks[node], lowlinks[target]) + elif target in on_stack: + lowlinks[node] = min(lowlinks[node], indexes[target]) + + if lowlinks[node] == indexes[node]: + component: list[str] = [] + while True: + target = stack.pop() + on_stack.remove(target) + component.append(target) + if target == node: + break + if len(component) > 1: + components.append(tuple(sorted(component))) + + for module in sorted(graph): + if module not in indexes: + strongconnect(module) + return sorted(components) + + +def _function_call_names(path: Path, function_name: str) -> set[str]: + tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == function_name: + calls: set[str] = set() + for child in ast.walk(node): + if not isinstance(child, ast.Call): + continue + if isinstance(child.func, ast.Name): + calls.add(child.func.id) + elif isinstance(child.func, ast.Attribute): + calls.add(child.func.attr) + return calls + raise AssertionError(f"{function_name} not found in {path}") + + +def _non_api_ring_python_files() -> list[Path]: + paths: list[Path] = [] + for root_name in NON_API_RING_ROOTS: + root = APP_ROOT / root_name + if root.exists(): + paths.extend(sorted(root.rglob("*.py"))) + return paths + + +def _fastapi_import_labels(node: ast.AST) -> list[str]: + if isinstance(node, ast.Import): + return [ + alias.name + for alias in node.names + if alias.name == "fastapi" or alias.name.startswith("fastapi.") + ] + if isinstance(node, ast.ImportFrom) and node.module: + if node.module == "fastapi" or node.module.startswith("fastapi."): + return [f"from {node.module}"] + return [] + + +def _http_exception_reference_lines(tree: ast.AST) -> list[int]: + lines: set[int] = set() + for node in ast.walk(tree): + if isinstance(node, ast.Name) and node.id == "HTTPException": + lines.add(node.lineno) + elif isinstance(node, ast.Attribute) and node.attr == "HTTPException": + lines.add(node.lineno) + elif isinstance(node, (ast.Import, ast.ImportFrom)): + for alias in node.names: + if alias.name == "HTTPException" or alias.asname == "HTTPException": + lines.add(node.lineno) + return sorted(lines) + + +def _pipeline_status_import_locations(tree: ast.AST) -> list[dict[str, object]]: + gate = _load_architecture_gate() + return gate._pipeline_status_import_locations(tree) + + +def test_app_internal_static_python_import_graph_has_no_scc(): + graph = _static_internal_import_graph() + components = _strongly_connected_components(graph) + + assert components == [] + + +def test_architecture_gate_report_exposes_cycle_evidence(): + gate = _load_architecture_gate() + report = gate.build_report(REPO_ROOT) + static_graph = report["static_import_graph"] + dynamic_graph = report["runtime_dynamic_import_graph"] + + assert static_graph["module_count"] == len(_app_modules()) + assert static_graph["internal_edge_count"] == sum( + len(targets) for targets in _static_internal_import_graph().values() + ) + assert set(report) == { + "runtime_dynamic_forbidden_dependencies", + "runtime_dynamic_import_graph", + "static_forbidden_dependencies", + "static_import_graph", + } + assert set(static_graph) == { + "internal_edge_count", + "layer_edges", + "layer_sccs", + "module_count", + "module_sccs", + } + assert set(dynamic_graph) == { + "edge_count", + "edges", + "layer_edges", + "module_sccs", + } + assert static_graph["module_sccs"] == [] + assert static_graph["layer_sccs"] == [] + assert report["static_forbidden_dependencies"] == [] + assert dynamic_graph["module_sccs"] == [] + assert report["runtime_dynamic_forbidden_dependencies"] == [] + + +def test_architecture_gate_reports_runtime_dynamic_registry_and_literal_edges(): + gate = _load_architecture_gate() + report = gate.build_report(REPO_ROOT) + edge_keys = _runtime_dynamic_edge_keys(report) + + assert ( + "pipeline.registry", + "pipeline.stages.asr", + "registry_stage", + ) in edge_keys + assert ( + "pipeline.registry", + "providers.asr.default", + "registry_provider", + ) in edge_keys + assert ( + "pipeline.runner", + "infra.audio", + "literal_import_module", + ) in edge_keys + assert ( + "pipeline.runner", + "providers.capabilities", + "literal_import_module", + ) in edge_keys + assert ( + "pipeline.orchestrator", + "infra.huggingface_models", + "literal_import_module", + ) in edge_keys + assert ( + "pipeline.orchestrator", + "infra.cuda_devices", + "literal_import_module", + ) in edge_keys + assert ( + "pipeline.orchestrator", + "providers.asr", + "literal_import_module", + ) in edge_keys + assert ( + "pipeline.orchestrator", + "providers.diarization", + "literal_import_module", + ) in edge_keys + assert ( + "pipeline.orchestrator", + "providers.embedding", + "literal_import_module", + ) in edge_keys + assert ( + "providers._registry", + "pipeline.registry", + "literal_import_module", + ) not in edge_keys + + +def _write_module(root: Path, relative_path: str, source: str) -> None: + path = root / relative_path + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(source, encoding="utf-8") + + +def _write_metadata_contract(root: Path, keys: tuple[str, ...]) -> None: + _write_module( + root, + "app/pipeline/contracts/metadata.py", + f"PIPELINE_METADATA_TOP_LEVEL_KEYS = {keys!r}\n", + ) + + +def test_architecture_gate_flags_unknown_pipeline_context_metadata_keys(tmp_path): + gate = _load_architecture_gate() + _write_metadata_contract(tmp_path, ("diarization",)) + _write_module( + tmp_path, + "app/pipeline/stages/example.py", + "def run(context):\n context.metadata['freeform'] = {'status': 'bad'}\n", + ) + + report = gate.build_report(tmp_path) + + assert report["static_forbidden_dependencies"] == [ + { + "rule": "pipeline_context_metadata_top_level_key_contract", + "module": "pipeline.stages.example", + "path": "app/pipeline/stages/example.py", + "locations": [ + { + "line": 2, + "key": "freeform", + "access": "subscript", + } + ], + } + ] + + +def test_architecture_gate_reads_metadata_contract_without_executing_it(tmp_path): + gate = _load_architecture_gate() + _write_module( + tmp_path, + "app/pipeline/contracts/metadata.py", + "PIPELINE_METADATA_CONTROL_KEYS = ('executed_stages',)\n" + "PIPELINE_METADATA_STAGE_KEYS = ('diarization',)\n" + "PIPELINE_METADATA_TOP_LEVEL_KEYS = (\n" + " *PIPELINE_METADATA_CONTROL_KEYS,\n" + " *PIPELINE_METADATA_STAGE_KEYS,\n" + ")\n" + "raise RuntimeError('metadata contract must not execute')\n", + ) + _write_module( + tmp_path, + "app/pipeline/stages/example.py", + "def run(context):\n context.metadata['diarization'] = {'status': 'ok'}\n", + ) + + report = gate.build_report(tmp_path) + + assert report["static_forbidden_dependencies"] == [] + + +def test_architecture_gate_flags_unbounded_pipeline_context_metadata_update(tmp_path): + gate = _load_architecture_gate() + _write_metadata_contract(tmp_path, ("diarization",)) + _write_module( + tmp_path, + "app/pipeline/stages/example.py", + "def run(context, result):\n" + " context.metadata['diarization'].update(result.metadata)\n", + ) + + report = gate.build_report(tmp_path) + + assert report["static_forbidden_dependencies"] == [ + { + "rule": "pipeline_context_metadata_no_unbounded_update", + "module": "pipeline.stages.example", + "path": "app/pipeline/stages/example.py", + "locations": [ + { + "line": 2, + "key": "diarization", + } + ], + } + ] + + +def test_architecture_gate_flags_runtime_dynamic_import_scc(tmp_path): + gate = _load_architecture_gate() + _write_module( + tmp_path, + "app/pipeline/a.py", + "from importlib import import_module\n\n\ndef load():\n return import_module('pipeline.b')\n", + ) + _write_module( + tmp_path, + "app/pipeline/b.py", + "from importlib import import_module\n\n\ndef load():\n return import_module('pipeline.a')\n", + ) + + report = gate.build_report(tmp_path) + + assert report["static_import_graph"]["module_sccs"] == [] + assert report["runtime_dynamic_import_graph"]["module_sccs"] == [ + ["pipeline.a", "pipeline.b"] + ] + + +def test_architecture_gate_flags_runtime_dynamic_application_boundary(tmp_path): + gate = _load_architecture_gate() + _write_module(tmp_path, "app/application/jobs.py", "def run():\n return None\n") + _write_module( + tmp_path, + "app/providers/default.py", + "from importlib import import_module\n\n\ndef load():\n return import_module('application.jobs')\n", + ) + + report = gate.build_report(tmp_path) + + assert report["runtime_dynamic_forbidden_dependencies"] == [ + { + "rule": "providers_do_not_runtime_import_orchestration_or_stage_registry", + "module": "providers.default", + "target": "application.jobs", + "kind": "literal_import_module", + "import": "application.jobs", + "locations": [ + { + "path": "app/providers/default.py", + "line": 5, + } + ], + } + ] + + +def test_architecture_gate_flags_provider_runtime_registry_and_stage_imports(tmp_path): + gate = _load_architecture_gate() + _write_module( + tmp_path, "app/pipeline/registry.py", "def resolve_provider():\n pass\n" + ) + _write_module(tmp_path, "app/pipeline/stages/asr.py", "def run():\n pass\n") + _write_module( + tmp_path, + "app/providers/default.py", + "from importlib import import_module\n\n\n" + "def load_registry():\n" + " return import_module('pipeline.registry')\n\n\n" + "def load_stage():\n" + " return import_module('pipeline.stages.asr')\n", + ) + + report = gate.build_report(tmp_path) + + assert { + ( + "providers_do_not_runtime_import_orchestration_or_stage_registry", + "providers.default", + "pipeline.registry", + ), + ( + "providers_do_not_runtime_import_orchestration_or_stage_registry", + "providers.default", + "pipeline.stages.asr", + ), + }.issubset(_runtime_dynamic_forbidden_keys(report)) + + +def test_non_api_rings_do_not_static_import_fastapi_or_reference_http_exception(): + """Guard source-level API boundary imports; dynamic runtime behavior is separate.""" + + offenders: dict[str, dict[str, list[str] | list[int]]] = {} + + for path in _non_api_ring_python_files(): + tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) + fastapi_imports: list[str] = [] + for node in ast.walk(tree): + fastapi_imports.extend(_fastapi_import_labels(node)) + + http_exception_lines = _http_exception_reference_lines(tree) + if fastapi_imports or http_exception_lines: + offenders[str(path.relative_to(REPO_ROOT))] = { + "fastapi_imports": fastapi_imports, + "http_exception_lines": http_exception_lines, + } + + assert offenders == {} + + +def test_pipeline_contracts_static_imports_stay_on_contracts_or_low_level_pipeline_modules(): + graph = _static_internal_import_graph() + allowed_exact = {"pipeline.errors", "pipeline.step_keys"} + + def disallowed_targets(targets: set[str]) -> list[str]: + return sorted( + target + for target in targets + if not ( + target == "pipeline.contracts" + or target.startswith("pipeline.contracts.") + or target in allowed_exact + ) + ) + + offenders = { + module: invalid_targets + for module, targets in graph.items() + for invalid_targets in (disallowed_targets(targets),) + if module.startswith("pipeline.contracts") and invalid_targets + } + + assert offenders == {} + + +def test_status_contract_owner_is_infra_not_pipeline_contracts(): + assert not (APP_ROOT / "pipeline" / "contracts" / "status.py").exists() + + from infra import job_status + from pipeline import contracts + + assert job_status.build_status_payload( + "queued", + filename="private/audio.wav", + updated_at="2026-06-09T00:00:00+00:00", + ) == { + "status": "queued", + "updated_at": "2026-06-09T00:00:00+00:00", + "error": None, + "filename": "audio.wav", + } + forbidden_exports = { + "IN_PROGRESS_JOB_STATUSES", + "JOB_STATUS_COMPLETED", + "JOB_STATUS_CONVERTING", + "JOB_STATUS_DENOISING", + "JOB_STATUS_FAILED", + "JOB_STATUS_IDENTIFYING", + "JOB_STATUS_QUEUED", + "JOB_STATUS_TRANSCRIBING", + "KNOWN_JOB_STATUSES", + "TERMINAL_JOB_STATUSES", + "build_status_payload", + "normalize_job_status", + "normalize_status_payload", + } + assert [ + name for name in sorted(forbidden_exports) if hasattr(contracts, name) + ] == [] + + +def test_application_and_infra_do_not_import_pipeline_status_helpers(): + offenders: dict[str, list[dict[str, object]]] = {} + + for root_name in ("application", "infra"): + root = APP_ROOT / root_name + if not root.exists(): + continue + for path in sorted(root.rglob("*.py")): + tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) + status_imports = _pipeline_status_import_locations(tree) + if status_imports: + offenders[str(path.relative_to(REPO_ROOT))] = status_imports + + assert offenders == {} + + +def test_architecture_gate_flags_application_or_infra_status_contract_imports( + tmp_path, +): + gate = _load_architecture_gate() + _write_module( + tmp_path, + "app/application/records.py", + "from pipeline.contracts import normalize_status_payload\n", + ) + _write_module( + tmp_path, + "app/infra/jobs.py", + "from pipeline.contracts.status import build_status_payload\n", + ) + _write_module( + tmp_path, + "app/pipeline/contracts/status.py", + "def build_status_payload():\n pass\n", + ) + + report = gate.build_report(tmp_path) + + assert report["static_forbidden_dependencies"] == [ + { + "rule": "application_and_infra_use_infra_job_status_owner", + "module": "application.records", + "path": "app/application/records.py", + "locations": [ + { + "line": 1, + "import": "from pipeline.contracts", + "symbol": "normalize_status_payload", + } + ], + }, + { + "rule": "application_and_infra_use_infra_job_status_owner", + "module": "infra.jobs", + "path": "app/infra/jobs.py", + "locations": [ + { + "line": 1, + "import": "from pipeline.contracts.status", + "symbol": "build_status_payload", + } + ], + }, + ] + + +def test_architecture_gate_flags_application_private_job_boundary_imports(tmp_path): + gate = _load_architecture_gate() + _write_module( + tmp_path, + "app/application/jobs.py", + "from infra.job_runtime import jobs\n" + "from infra.job_persistence import _write_status, write_job_status\n" + "from infra.job_persistence import *\n\n" + "def record_files():\n" + " return 'status.json', 'result.json'\n", + ) + _write_module( + tmp_path, + "app/application/runtime_module.py", + "import infra.job_runtime as job_runtime\n\n" + "def current_jobs():\n" + " return job_runtime.jobs\n", + ) + _write_module( + tmp_path, + "app/application/persistence_module.py", + "import infra.job_persistence as job_persistence\n\n" + "def atomic_write():\n" + " return job_persistence._atomic_write_json\n", + ) + _write_module( + tmp_path, + "app/application/runtime_from_infra.py", + "from infra import job_runtime\n\n" + "def current_jobs():\n" + " return job_runtime.jobs\n", + ) + + report = gate.build_report(tmp_path) + + assert report["static_forbidden_dependencies"] == [ + { + "rule": "application_uses_public_infra_job_boundary", + "module": "application.jobs", + "path": "app/application/jobs.py", + "locations": [ + { + "line": 1, + "import": "from infra.job_runtime", + "symbol": "jobs", + }, + { + "line": 2, + "import": "from infra.job_persistence", + "symbol": "_write_status", + }, + { + "line": 3, + "import": "from infra.job_persistence", + "symbol": "*", + }, + { + "line": 6, + "import": "transcription record filesystem literal", + "symbol": "result.json", + }, + { + "line": 6, + "import": "transcription record filesystem literal", + "symbol": "status.json", + }, + ], + }, + { + "rule": "application_uses_public_infra_job_boundary", + "module": "application.persistence_module", + "path": "app/application/persistence_module.py", + "locations": [ + { + "line": 4, + "import": "job_persistence._atomic_write_json", + "symbol": "_atomic_write_json", + } + ], + }, + { + "rule": "application_uses_public_infra_job_boundary", + "module": "application.runtime_from_infra", + "path": "app/application/runtime_from_infra.py", + "locations": [ + { + "line": 4, + "import": "job_runtime.jobs", + "symbol": "jobs", + } + ], + }, + { + "rule": "application_uses_public_infra_job_boundary", + "module": "application.runtime_module", + "path": "app/application/runtime_module.py", + "locations": [ + { + "line": 4, + "import": "job_runtime.jobs", + "symbol": "jobs", + } + ], + }, + ] + + +def test_pipeline_registry_static_imports_stay_lazy_across_pipeline_boundaries(): + graph = _static_internal_import_graph() + forbidden_prefixes = ( + "pipeline.contracts", + "pipeline.stages", + "providers", + ) + offenders = sorted( + target + for target in graph["pipeline.registry"] + if any( + target == prefix or target.startswith(f"{prefix}.") + for prefix in forbidden_prefixes + ) + ) + + assert offenders == [] + + +def test_pipeline_stage_slots_do_not_static_import_provider_facades(): + """Guard source-level stage imports; runtime registry strings are separate.""" + + graph = _static_internal_import_graph() + offenders = { + module: sorted( + target + for target in targets + if target == "providers" or target.startswith("providers.") + ) + for module, targets in graph.items() + if module.startswith("pipeline.stages.") + and any( + target == "providers" or target.startswith("providers.") + for target in targets + ) + } + + assert offenders == {} + + +def test_provider_ring_does_not_static_import_pipeline_registry_or_stages(): + """Guard source-level provider imports; runtime importlib lookup is separate.""" + + graph = _static_internal_import_graph() + forbidden_prefixes = ("pipeline.registry", "pipeline.stages") + offenders = { + module: sorted( + target + for target in targets + if any( + target == prefix or target.startswith(f"{prefix}.") + for prefix in forbidden_prefixes + ) + ) + for module, targets in graph.items() + if (module == "providers" or module.startswith("providers.")) + and any( + target == prefix or target.startswith(f"{prefix}.") + for prefix in forbidden_prefixes + for target in targets + ) + } + + assert offenders == {} + + +def test_provider_selector_normalizers_delegate_to_shared_token_normalizer(): + module_paths = _app_modules() + checked_modules = ( + "pipeline.contracts.requests", + "providers.capabilities", + ) + manual_calls: dict[str, list[str]] = {} + missing_delegate: list[str] = [] + + for module in checked_modules: + calls = _function_call_names(module_paths[module], "_normalize_provider_name") + if "normalize_token" not in calls: + missing_delegate.append(module) + duplicated_steps = sorted({"strip", "lower", "replace"} & calls) + if duplicated_steps: + manual_calls[module] = duplicated_steps + + assert missing_delegate == [] + assert manual_calls == {} + + +def test_pipeline_lookup_errors_keep_legacy_public_imports(): + from pipeline.contracts import ( + ProviderNotFoundError as ContractsProviderNotFoundError, + ) + from pipeline.contracts import StageNotFoundError as ContractsStageNotFoundError + from pipeline.registry import ProviderNotFoundError as RegistryProviderNotFoundError + from pipeline.registry import StageNotFoundError as RegistryStageNotFoundError + + assert ContractsProviderNotFoundError is RegistryProviderNotFoundError + assert ContractsStageNotFoundError is RegistryStageNotFoundError diff --git a/tests/unit/test_artifact_status_schema_contracts.py b/tests/unit/test_artifact_status_schema_contracts.py index 37be49d..6d17555 100644 --- a/tests/unit/test_artifact_status_schema_contracts.py +++ b/tests/unit/test_artifact_status_schema_contracts.py @@ -4,17 +4,18 @@ import pytest +import pipeline.contracts as pipeline_contracts +from infra.job_status import build_status_payload, normalize_status_payload from pipeline.contracts import ( ARTIFACT_MANIFEST_VERSION, ArtifactManifestEntry, attach_optional_schema_version, build_artifact_manifest, - build_status_payload, empty_artifact_manifest, normalize_artifact_manifest, - normalize_status_payload, read_optional_schema_version, ) +from pipeline.registry import available_stage_slots def test_artifact_manifest_builds_public_safe_known_categories_only(): @@ -160,3 +161,72 @@ def test_schema_version_is_optional_first_for_legacy_artifacts(): } with pytest.raises(ValueError, match="schema_version"): read_optional_schema_version({"schema_version": "../private"}) + + +def test_pipeline_metadata_contract_covers_stable_stage_order_and_control_keys(): + assert hasattr(pipeline_contracts, "PIPELINE_METADATA_CONTRACT") + assert hasattr(pipeline_contracts, "PIPELINE_METADATA_CONTROL_KEYS") + assert hasattr(pipeline_contracts, "PIPELINE_METADATA_PATH_CONTRACT") + assert hasattr(pipeline_contracts, "PIPELINE_METADATA_PUBLIC_PATHS") + assert hasattr(pipeline_contracts, "PIPELINE_METADATA_STAGE_KEYS") + + PIPELINE_METADATA_CONTRACT = pipeline_contracts.PIPELINE_METADATA_CONTRACT + PIPELINE_METADATA_CONTROL_KEYS = pipeline_contracts.PIPELINE_METADATA_CONTROL_KEYS + PIPELINE_METADATA_PATH_CONTRACT = pipeline_contracts.PIPELINE_METADATA_PATH_CONTRACT + PIPELINE_METADATA_PUBLIC_PATHS = pipeline_contracts.PIPELINE_METADATA_PUBLIC_PATHS + PIPELINE_METADATA_STAGE_KEYS = pipeline_contracts.PIPELINE_METADATA_STAGE_KEYS + + assert PIPELINE_METADATA_STAGE_KEYS == available_stage_slots() + assert PIPELINE_METADATA_CONTROL_KEYS == ( + "executed_stages", + "selected_providers", + "provider_capabilities", + "stage_timings", + ) + assert PIPELINE_METADATA_PUBLIC_PATHS == ("diarization.alignment",) + + for key in (*PIPELINE_METADATA_CONTROL_KEYS, *PIPELINE_METADATA_STAGE_KEYS): + entry = PIPELINE_METADATA_CONTRACT[key] + assert entry.owner + assert entry.writers + assert isinstance(entry.public, bool) + assert isinstance(entry.allow_overwrite, bool) + + alignment_entry = PIPELINE_METADATA_PATH_CONTRACT["diarization.alignment"] + assert alignment_entry.owner == "diarization" + assert alignment_entry.public is True + assert alignment_entry.allow_overwrite is False + + +def test_public_alignment_metadata_normalizer_keeps_safe_scalars_only(tmp_path): + assert hasattr(pipeline_contracts, "normalize_public_alignment_metadata") + normalize_public_alignment_metadata = ( + pipeline_contracts.normalize_public_alignment_metadata + ) + + normalized = normalize_public_alignment_metadata( + { + "status": "skipped", + "reason": "duration_budget_exceeded", + "model": "org/model", + "duration_s": 12.5, + "max_duration_s": 60, + "cache_only": False, + "device": "cpu", + "language": "zh", + "model_path": str(tmp_path / "private-model"), + "exception": RuntimeError("hidden"), + "debug": {"path": str(tmp_path)}, + "segments": ["not public"], + } + ) + + assert normalized == { + "status": "skipped", + "reason": "duration_budget_exceeded", + "model": "org/model", + "duration_s": 12.5, + "max_duration_s": 60, + "cache_only": False, + "device": "cpu", + } diff --git a/tests/unit/test_audio_layers.py b/tests/unit/test_audio_layers.py index 9bb2fe8..b228dc2 100644 --- a/tests/unit/test_audio_layers.py +++ b/tests/unit/test_audio_layers.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +import re import subprocess import sys from contextlib import contextmanager @@ -13,14 +14,22 @@ import numpy as np import pytest import infra.audio.hash_index as hash_index_module +import infra.audio.paths as audio_paths import infra.audio as audio_infra +import infra.audio.metadata as audio_metadata import providers import providers.enhance.default as enhance_default import providers.normalize.default as normalize_default import providers.voiceprint_match.default as voiceprint_match_default from infra.audio import JsonAudioArtifactIndex +from infra.audio import ( + AudioPathTraversalError, + InvalidSpeakerLabelError, + InvalidTranscriptionIdError, +) from pipeline.contracts import ( AudioEnhancementRequest, + AudioNormalizationTimeoutError, AudioNormalizationRequest, UploadPersistenceRequest, VoiceprintMatchRequest, @@ -124,6 +133,18 @@ def test_unknown_denoise_model_is_a_noop(tmp_path, caplog): assert "Unknown DENOISE_MODEL='unsupported'" in caplog.text +def test_audio_duration_seconds_uses_metadata_without_loading(monkeypatch): + class FakeInfo: + sample_rate = 1000 + num_frames = 2500 + + torchaudio_module = ModuleType("torchaudio") + torchaudio_module.info = lambda path: FakeInfo() + monkeypatch.setitem(sys.modules, "torchaudio", torchaudio_module) + + assert audio_metadata.audio_duration_seconds("demo.wav") == 2.5 + + def test_estimate_snr_uses_energy_heuristic(monkeypatch, tmp_path): class FakeTensor: def __init__(self, values): @@ -369,6 +390,35 @@ def test_noisereduce_processing_timing_log_is_public_safe( assert "private-call.denoised.wav" not in caplog.text +def test_denoise_skips_long_audio_before_full_audio_load(monkeypatch, tmp_path, caplog): + wav_path = tmp_path / "very-long.wav" + wav_path.write_bytes(b"stub") + monkeypatch.setattr(enhance_default, "DENOISE_MAX_AUDIO_DURATION_SEC", 10.0) + monkeypatch.setattr(enhance_default, "audio_duration_seconds", lambda path: 11.0) + monkeypatch.setattr( + enhance_default, + "_estimate_snr", + lambda path: (_ for _ in ()).throw( + AssertionError("SNR load should not run for over-budget audio") + ), + ) + + soundfile_module = ModuleType("soundfile") + soundfile_module.read = lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError("soundfile.read should not run for over-budget audio") + ) + monkeypatch.setitem(sys.modules, "soundfile", soundfile_module) + + with caplog.at_level("WARNING", logger=enhance_default.logger.name): + result = enhance_default.ConditionalDenoiseEnhancer().enhance( + AudioEnhancementRequest(wav_path=wav_path, model="noisereduce") + ) + + assert result.applied is False + assert result.output_path == wav_path + assert "Denoise skipped by duration budget" in caplog.text + + def test_hash_index_infra_requires_completed_result(monkeypatch, tmp_path): monkeypatch.setattr(hash_index_module, "TRANSCRIPTIONS_DIR", tmp_path) @@ -384,6 +434,28 @@ def test_hash_index_infra_requires_completed_result(monkeypatch, tmp_path): assert store.lookup("hash-b") == "tr_ready" +def test_audio_path_helpers_raise_typed_errors(monkeypatch, tmp_path): + transcriptions_dir = tmp_path / "transcriptions" + transcriptions_dir.mkdir() + monkeypatch.setattr(audio_paths, "TRANSCRIPTIONS_DIR", transcriptions_dir) + + with pytest.raises(InvalidTranscriptionIdError, match="Invalid transcription ID"): + audio_paths.safe_tr_dir("../etc/passwd") + + with pytest.raises(InvalidSpeakerLabelError, match="Invalid speaker label"): + audio_paths.safe_speaker_label("bad/label") + + +def test_audio_path_traversal_guard_raises_typed_error(monkeypatch, tmp_path): + transcriptions_dir = tmp_path / "transcriptions" + transcriptions_dir.mkdir() + monkeypatch.setattr(audio_paths, "TRANSCRIPTIONS_DIR", transcriptions_dir) + monkeypatch.setattr(audio_paths, "_TR_ID_RE", re.compile(r".+")) + + with pytest.raises(AudioPathTraversalError, match="Path traversal detected"): + audio_paths.safe_tr_dir("../outside") + + def test_ffmpeg_normalizer_reuses_existing_target_format(tmp_path): wav_path = tmp_path / "already.wav" wav_path.write_bytes(b"wav") @@ -431,12 +503,14 @@ def fake_run(*args, **kwargs): monkeypatch.setattr(normalize_default.subprocess, "run", fake_run) - with pytest.raises(Exception) as excinfo: + with pytest.raises(AudioNormalizationTimeoutError) as excinfo: normalize_default.FFmpegInputNormalizer().normalize( AudioNormalizationRequest(input_path=source) ) - assert getattr(excinfo.value, "status_code", None) == 504 + assert str(excinfo.value) == ( + f"ffmpeg timed out after {normalize_default.FFMPEG_TIMEOUT_SEC}s" + ) assert not partial.exists() diff --git a/tests/unit/test_config_defaults.py b/tests/unit/test_config_defaults.py index f1ab8bf..b1c6992 100644 --- a/tests/unit/test_config_defaults.py +++ b/tests/unit/test_config_defaults.py @@ -4,6 +4,10 @@ import importlib import sys +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[2] def _fresh_config( @@ -37,13 +41,53 @@ def test_model_idle_timeout_explicit_zero_disables_idle_unload(monkeypatch): assert config.MODEL_IDLE_TIMEOUT_SEC == 0.0 -def test_rust_kernel_mode_defaults_to_off(monkeypatch): +def test_rust_kernel_mode_defaults_to_required(monkeypatch): config = _fresh_config(monkeypatch) - assert config.RUST_KERNEL_MODE == "off" + assert config.RUST_KERNEL_MODE == "required" def test_rust_kernel_mode_is_normalized(monkeypatch): config = _fresh_config(monkeypatch, rust_kernel_mode=" REQUIRED ") assert config.RUST_KERNEL_MODE == "required" + + +def test_rust_kernel_mode_explicit_off_remains_config_rollback(monkeypatch): + config = _fresh_config(monkeypatch, rust_kernel_mode=" off ") + + assert config.RUST_KERNEL_MODE == "off" + + +def test_compose_default_requires_rust_kernel(): + compose = (ROOT / "docker-compose.yml").read_text(encoding="utf-8") + + assert "RUST_KERNEL_MODE=${RUST_KERNEL_MODE:-required}" in compose + assert "RUST_KERNEL_MODE=${RUST_KERNEL_MODE:-off}" not in compose + + +def test_public_docs_describe_required_rust_kernel_default(): + docs = "\n".join( + (ROOT / path).read_text(encoding="utf-8") + for path in ( + "doc/configuration.en.md", + "doc/configuration.zh.md", + "doc/changelog.en.md", + "doc/changelog.zh.md", + ) + ) + dflt = "def" + "ault" + py_word = "Py" + "thon" + + for stale_phrase in ( + "The " + dflt + " `off`", + dflt + " remains " + py_word, + "`RUST_KERNEL_MODE` | " + "`off`", + "默认 " + "`off`", + "默认仍使用 " + py_word, + "默认仍由 " + py_word, + "默认" + "关闭时", + ): + assert stale_phrase not in docs + + assert "`RUST_KERNEL_MODE` | `required`" in docs diff --git a/tests/unit/test_docs_code_drift_gate.py b/tests/unit/test_docs_code_drift_gate.py new file mode 100644 index 0000000..47f7724 --- /dev/null +++ b/tests/unit/test_docs_code_drift_gate.py @@ -0,0 +1,188 @@ +"""Tests for public docs/code drift guardrails.""" + +from __future__ import annotations + +import importlib.util +import shutil +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[2] +DOCS_CODE_DRIFT_GATE = ROOT / "voscript-api" / "scripts" / "docs_code_drift_gate.py" +GATE_SURFACE_FILES = ( + ".env.example", + "README.md", + "README.en.md", + "app/config.py", + "app/main.py", + "app/api/routers/health.py", + "app/api/routers/transcriptions.py", + "app/api/routers/voiceprints.py", + "doc/api.zh.md", + "doc/api.en.md", + "doc/configuration.zh.md", + "doc/configuration.en.md", + "docker-compose.yml", +) + + +def _load_docs_code_drift_gate(): + spec = importlib.util.spec_from_file_location( + "voscript_docs_code_drift_gate", + DOCS_CODE_DRIFT_GATE, + ) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _copy_gate_surface(tmp_path: Path) -> Path: + repo_root = tmp_path / "repo" + for rel_path in GATE_SURFACE_FILES: + source = ROOT / rel_path + target = repo_root / rel_path + target.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(source, target) + return repo_root + + +def _has_finding( + report: dict, + *, + category: str, + path: str, + term: str, +) -> bool: + return any( + finding["category"] == category + and finding["path"] == path + and finding["term"] == term + for finding in report["findings"] + ) + + +def test_docs_code_drift_gate_has_no_findings(): + gate = _load_docs_code_drift_gate() + report = gate.build_report(ROOT) + + assert report["findings"] == [] + assert any( + route["method"] == "GET" + and route["path"] == "/api/transcriptions/{tr_id}/audio" + for route in report["checked_routes"] + ) + assert any( + route["method"] == "GET" and route["path"] == "/api/voiceprints/{speaker_id}" + for route in report["checked_routes"] + ) + assert "WHISPERX_ALIGN_DEVICE" in report["public_config_keys"] + + +def test_reports_existing_api_route_missing_from_docs(tmp_path: Path): + gate = _load_docs_code_drift_gate() + root = _copy_gate_surface(tmp_path) + route_term = "GET /api/transcriptions/{tr_id}/audio" + assert gate.build_report(root)["findings"] == [] + + for doc_path in ("doc/api.zh.md", "doc/api.en.md"): + path = root / doc_path + text = path.read_text(encoding="utf-8") + assert route_term in text + path.write_text( + text.replace(route_term, "GET /api/transcriptions/{tr_id}/download"), + encoding="utf-8", + ) + + report = gate.build_report(root) + + assert _has_finding( + report, + category="api_route_missing_from_docs", + path="doc/api.zh.md", + term=route_term, + ), report["findings"] + assert _has_finding( + report, + category="api_route_missing_from_docs", + path="doc/api.en.md", + term=route_term, + ), report["findings"] + + +def test_reports_env_example_key_missing_from_compose(tmp_path: Path): + gate = _load_docs_code_drift_gate() + root = _copy_gate_surface(tmp_path) + env_key = "WHISPERX_ALIGN_DEVICE" + compose_path = root / "docker-compose.yml" + assert gate.build_report(root)["findings"] == [] + + text = compose_path.read_text(encoding="utf-8") + compose_ref = " - WHISPERX_ALIGN_DEVICE=${WHISPERX_ALIGN_DEVICE:-cpu}\n" + assert compose_ref in text + compose_path.write_text( + text.replace(compose_ref, " - WHISPERX_ALIGN_DEVICE=cpu\n"), + encoding="utf-8", + ) + + report = gate.build_report(root) + + assert _has_finding( + report, + category="env_example_key_missing_from_compose", + path="docker-compose.yml", + term=env_key, + ), report["findings"] + + +def test_reports_route_added_to_main_without_docs(tmp_path: Path): + gate = _load_docs_code_drift_gate() + root = _copy_gate_surface(tmp_path) + main_path = root / "app/main.py" + new_router_path = root / "app/api/routers/drift_probe.py" + route_term = "GET /api/drift-probe" + assert gate.build_report(root)["findings"] == [] + + main_text = main_path.read_text(encoding="utf-8") + assert "from api.routers import health, transcriptions, voiceprints" in main_text + main_text = main_text.replace( + "from api.routers import health, transcriptions, voiceprints", + "from api.routers import health, transcriptions, voiceprints, drift_probe", + ) + main_text = main_text.replace( + "app.include_router(voiceprints.router)\n", + "app.include_router(voiceprints.router)\n" + "app.include_router(drift_probe.router)\n", + ) + main_path.write_text(main_text, encoding="utf-8") + new_router_path.write_text( + '"""Router used by the docs/code drift gate regression test."""\n\n' + "from fastapi import APIRouter\n\n\n" + 'router = APIRouter(prefix="/api")\n\n\n' + '@router.get("/drift-probe")\n' + "async def drift_probe():\n" + ' return {"ok": True}\n', + encoding="utf-8", + ) + + report = gate.build_report(root) + + assert any( + route["method"] == "GET" + and route["path"] == "/api/drift-probe" + and route["source"] == "app/api/routers/drift_probe.py" + for route in report["checked_routes"] + ) + assert _has_finding( + report, + category="api_route_missing_from_docs", + path="doc/api.zh.md", + term=route_term, + ), report["findings"] + assert _has_finding( + report, + category="api_route_missing_from_docs", + path="doc/api.en.md", + term=route_term, + ), report["findings"] diff --git a/tests/unit/test_job_runtime.py b/tests/unit/test_job_runtime.py index 6deca32..49c009a 100644 --- a/tests/unit/test_job_runtime.py +++ b/tests/unit/test_job_runtime.py @@ -80,6 +80,101 @@ def test_run_serialized_gpu_work_releases_semaphore_after_error(monkeypatch): assert events == ["pre-whisper", "pre-whisper", "retry", "post-pipeline"] +def test_runtime_admission_count_helpers(monkeypatch): + cache = job_runtime._LRUJobsDict(maxsize=10) + cache["queued"] = {"status": "queued"} + cache["converting"] = {"status": "converting"} + cache["done"] = {"status": "completed"} + cache["failed"] = {"status": "failed"} + monkeypatch.setattr(job_runtime, "jobs", cache) + monkeypatch.setattr(job_runtime, "_active_job_ids", {"tr_queued", "tr_converting"}) + monkeypatch.setattr( + job_runtime, + "_in_flight_hashes", + {"sha256:a": "tr_a", "sha256:b": "tr_b"}, + ) + + assert job_runtime.active_job_count() == 2 + assert job_runtime.in_flight_count() == 2 + + +def test_runtime_job_store_public_api_tracks_current_store(monkeypatch): + cache = job_runtime._LRUJobsDict(maxsize=2) + monkeypatch.setattr(job_runtime, "jobs", cache) + + job_runtime.set_runtime_job("tr_public", {"status": "queued"}) + + assert job_runtime.runtime_job_exists("tr_public") is True + assert job_runtime.get_runtime_job("tr_public") == {"status": "queued"} + assert job_runtime.runtime_job_count() == 1 + assert job_runtime.runtime_jobs_values_snapshot() == ({"status": "queued"},) + + updated = job_runtime.update_runtime_job( + "tr_public", + {"status": "completed", "result": {"id": "tr_public"}}, + ) + + assert updated == { + "status": "completed", + "result": {"id": "tr_public"}, + } + assert job_runtime.get_runtime_job("missing", {"status": "missing"}) == { + "status": "missing", + } + assert job_runtime.pop_runtime_job("tr_public") == { + "status": "completed", + "result": {"id": "tr_public"}, + } + assert job_runtime.runtime_job_exists("tr_public") is False + assert job_runtime.pop_runtime_job("missing", None) is None + + +def test_active_job_reservation_is_not_coupled_to_lru_eviction(monkeypatch): + monkeypatch.setattr(job_runtime, "_active_job_ids", set()) + cache = job_runtime._LRUJobsDict(maxsize=1) + monkeypatch.setattr(job_runtime, "jobs", cache) + + reserved = job_runtime.try_reserve_active_job("tr_old", max_entries=1) + cache["tr_old"] = {"status": "queued"} + cache["tr_new"] = {"status": "queued"} + rejected = job_runtime.try_reserve_active_job("tr_new", max_entries=1) + + assert reserved.reserved is True + assert "tr_old" not in cache + assert rejected.budget_exceeded is True + assert job_runtime.active_job_count() == 1 + assert job_runtime.release_active_job("tr_old") is True + assert job_runtime.active_job_count() == 0 + + +def test_try_register_in_flight_enforces_budget_atomically(monkeypatch): + monkeypatch.setattr(job_runtime, "_in_flight_hashes", {"sha256:a": "tr_a"}) + + duplicate = job_runtime.try_register_in_flight( + "sha256:a", + "tr_duplicate", + max_entries=1, + ) + rejected = job_runtime.try_register_in_flight( + "sha256:b", + "tr_b", + max_entries=1, + ) + admitted = job_runtime.try_register_in_flight( + "sha256:c", + "tr_c", + max_entries=0, + ) + + assert duplicate.existing_job_id == "tr_a" + assert duplicate.registered is False + assert duplicate.budget_exceeded is False + assert rejected.existing_job_id is None + assert rejected.registered is False + assert rejected.budget_exceeded is True + assert admitted.registered is True + + def test_flush_torch_cuda_cache_skips_python_gc_for_active_job_phases(monkeypatch): events = [] fake_torch = SimpleNamespace( diff --git a/tests/unit/test_kernel_bridge.py b/tests/unit/test_kernel_bridge.py index 8277dd1..d1e3262 100644 --- a/tests/unit/test_kernel_bridge.py +++ b/tests/unit/test_kernel_bridge.py @@ -25,7 +25,7 @@ def _core_smoke(payload): return { "ok": True, "echoed": payload, - "version": "0.8.4", + "version": "0.8.5", "capabilities": {"core_smoke": True, "rust_extension": True}, } @@ -39,7 +39,7 @@ def test_core_smoke_round_trips_safe_payload_through_imported_extension(): assert result["ok"] is True assert result["echoed"] == payload - assert result["version"] == "0.8.4" + assert result["version"] == "0.8.5" assert result["capabilities"]["core_smoke"] is True @@ -73,7 +73,7 @@ def _importer(module_name): core_smoke({}, importer=_importer) -def test_rust_kernel_mode_defaults_to_off_semantics(): +def test_rust_kernel_mode_explicit_off_remains_rollback_semantics(): assert rust_kernel_mode("off") == "off" assert rust_provider_paths_enabled("off") is False assert rust_provider_paths_enabled("required") is True diff --git a/tests/unit/test_kernel_release_gates.py b/tests/unit/test_kernel_release_gates.py index a5624c1..dbedf52 100644 --- a/tests/unit/test_kernel_release_gates.py +++ b/tests/unit/test_kernel_release_gates.py @@ -57,10 +57,14 @@ def test_ci_workflows_include_required_release_gate_commands(): ) assert "public_release_scan.py --root ." in ci + assert "architecture_gate.py --root . --check" in ci assert "pytest tests/unit/ tests/test_security.py" in ci assert ( "cargo fmt --manifest-path crates/voscript_core/Cargo.toml -- --check" in heavy ) + assert "resolve-source:" in heavy + assert "git rev-parse HEAD" in heavy + assert "needs.resolve-source.outputs.source-sha" in heavy assert "cargo clippy --manifest-path crates/voscript_core/Cargo.toml" in heavy assert "cargo test --manifest-path crates/voscript_core/Cargo.toml" in heavy assert ( @@ -70,12 +74,69 @@ def test_ci_workflows_include_required_release_gate_commands(): assert "docker build ./app" in heavy assert "RUST_KERNEL_MODE=required" in heavy assert "workflow_dispatch:" in heavy - assert "types: [opened, reopened, ready_for_review]" in heavy + assert "types: [opened, reopened, ready_for_review, synchronize]" in heavy + assert "voscript-rust-foundation:${{ github.sha }}" not in heavy + assert heavy.count("ref: ${{ github.event.inputs.ref || github.ref }}") == 1 + assert "resolve-source:" in release + assert "source-sha" in release + assert "git rev-parse HEAD" in release + assert "ref: ${{ needs.resolve-source.outputs.source-sha }}" in release + assert "public-release-scan" in release + assert "lint-format" in release + assert "unit-security" in release + assert "docker-smoke" in release + assert "Run container Rust extension smoke" in release + assert "Run container healthz smoke" in release + assert "-e DEVICE=cpu -e ALLOW_NO_AUTH=1 -e RUST_KERNEL_MODE=required" in release assert ( "maturin build --release --manifest-path crates/voscript_core/Cargo.toml" in release ) assert "VOSCRIPT_CORE_WHEEL" in release + assert "sha-$SOURCE_SHA" in release + assert ( + "org.opencontainers.image.revision=${{ needs.resolve-source.outputs.source-sha }}" + in release + ) + assert ( + "voscript-core-wheel-${{ needs.resolve-source.outputs.source-sha }}" in release + ) + + +def test_runtime_defaults_require_rust_in_public_entrypoints(): + config = (ROOT / "app" / "config.py").read_text(encoding="utf-8") + compose = (ROOT / "docker-compose.yml").read_text(encoding="utf-8") + env_example = (ROOT / ".env.example").read_text(encoding="utf-8") + + assert '_env_str("RUST_KERNEL_MODE", "required").lower()' in config + assert "RUST_KERNEL_MODE=${RUST_KERNEL_MODE:-required}" in compose + assert "RUST_KERNEL_MODE=required" in env_example + assert "RUST_KERNEL_MODE=${RUST_KERNEL_MODE:-off}" not in compose + + +def test_rust_required_default_source_guards_do_not_regress_to_fail_open(): + dockerfile = (ROOT / "app" / "Dockerfile").read_text(encoding="utf-8") + drift_gate = ( + ROOT / "voscript-api" / "scripts" / "docs_code_drift_gate.py" + ).read_text(encoding="utf-8") + adr_0012 = ( + ROOT + / "docs" + / "adr" + / "0012-use-architecture-rings-and-cycle-gates-for-next-version-refactor.md" + ).read_text(encoding="utf-8") + + assert "building local source image without Rust extension" not in dockerfile + assert "voscript_core wheel is required by default" in dockerfile + assert "exit 1" in dockerfile + + assert "off by default" not in drift_gate + assert "required by default and fail closed" in drift_gate + assert "off is an explicit rollback" in drift_gate + + assert "`RUST_KERNEL_MODE=off` 是默认业务路径" not in adr_0012 + assert "`RUST_KERNEL_MODE=required` 是 0.8.5 final-state 默认业务路径" in adr_0012 + assert "`off` 只是显式 rollback" in adr_0012 def test_public_release_scan_entrypoint_is_repo_owned(): diff --git a/tests/unit/test_pipeline_runner.py b/tests/unit/test_pipeline_runner.py index 2b47732..f5d4870 100644 --- a/tests/unit/test_pipeline_runner.py +++ b/tests/unit/test_pipeline_runner.py @@ -1,5 +1,6 @@ """Unit tests for stable pipeline stage slots and runner orchestration.""" +import importlib import json from pathlib import Path from types import SimpleNamespace @@ -30,6 +31,32 @@ ) import providers.artifacts.default as artifacts_default from providers.artifacts.default import InMemoryArtifactsProvider +from providers.kernel_bridge import runtime as kernel_runtime + + +@pytest.fixture +def python_artifact_contracts(monkeypatch): + monkeypatch.setattr(kernel_runtime, "RUST_KERNEL_MODE", "off") + monkeypatch.setattr(artifacts_default, "rust_provider_paths_enabled", lambda: False) + monkeypatch.setitem( + InMemoryArtifactsProvider._build_segments.__globals__, + "rust_provider_paths_enabled", + lambda: False, + ) + monkeypatch.setitem( + InMemoryArtifactsProvider._build_artifact_manifest.__globals__, + "rust_provider_paths_enabled", + lambda: False, + ) + register_provider("artifacts", "default", InMemoryArtifactsProvider()) + try: + yield + finally: + unregister_provider("artifacts", "default") + + +def _capabilities_module(): + return importlib.import_module("providers.capabilities") def test_stage_slots_publish_stable_order_and_callable_entrypoints(): @@ -97,6 +124,130 @@ def stub_stage(context): assert "/private" not in caplog.text +def test_runner_records_default_provider_preflight_metadata(): + context = PipelineRunner(stage_order=("ingest",)).run_context( + SimpleNamespace(), + PipelineRequest(audio_path="sample.wav", language="ZH"), + ) + + assert context.metadata["selected_providers"]["ingest"] == "default" + assert context.metadata["provider_capabilities"]["ingest"] == { + "stage": "ingest", + "provider": "default", + "criticality": "required", + "language": "zh", + "reason": "language_supported", + } + assert context.metadata["ingest"]["working_audio_path"] == "sample.wav" + + +def test_runner_allows_registered_runtime_override_without_capability_record(): + class StubIngestProvider: + def run(self, context): + context.working_audio_path = "override.wav" + context.metadata["ingest"] = {"status": "override"} + + register_provider("ingest", "stub", StubIngestProvider()) + try: + context = PipelineRunner(stage_order=("ingest",)).run_context( + SimpleNamespace(), + PipelineRequest( + audio_path="sample.wav", + provider_selection={"ingest": "stub"}, + ), + ) + finally: + unregister_provider("ingest", "stub") + + assert context.metadata["selected_providers"]["ingest"] == "stub" + assert context.metadata["provider_capabilities"]["ingest"] == { + "stage": "ingest", + "provider": "stub", + "reason": "runtime_override", + "action": "run", + } + assert context.metadata["ingest"] == {"status": "override"} + assert context.working_audio_path == "override.wav" + + +def test_runner_required_capability_mismatch_fails_before_stage_execution( + monkeypatch, +): + calls = [] + capabilities_module = _capabilities_module() + monkeypatch.setitem( + capabilities_module._DEFAULT_CAPABILITIES, + ("asr", "default"), + capabilities_module.ProviderCapability( + stage="asr", + name="default", + supported_languages=frozenset({"en"}), + stage_criticality="required", + failure_policy="hard_fail", + ), + ) + + with pytest.raises( + capabilities_module.ProviderCapabilityError, + match="Required stage", + ): + PipelineRunner( + stage_order=("asr",), + stage_overrides={"asr": lambda context: calls.append("asr")}, + ).run_context( + SimpleNamespace(), + PipelineRequest(audio_path="sample.wav", language="zh"), + ) + + assert calls == [] + + +def test_runner_skips_degradable_unsupported_capability_with_metadata( + monkeypatch, +): + calls = [] + capabilities_module = _capabilities_module() + monkeypatch.setitem( + capabilities_module._DEFAULT_CAPABILITIES, + ("enhance", "default"), + capabilities_module.ProviderCapability( + stage="enhance", + name="default", + supported_languages=frozenset({"en"}), + stage_criticality="degradable", + failure_policy="skip", + ), + ) + + context = PipelineRunner( + stage_order=("enhance",), + stage_overrides={"enhance": lambda context: calls.append("enhance")}, + ).run_context( + SimpleNamespace(), + PipelineRequest(audio_path="sample.wav", language="zh"), + ) + + assert calls == [] + assert context.metadata["executed_stages"] == ["enhance"] + assert context.metadata["provider_capabilities"]["enhance"] == { + "stage": "enhance", + "provider": "default", + "criticality": "degradable", + "language": "zh", + "reason": "language_unsupported", + "action": "skip", + } + assert context.metadata["enhance"] == { + "status": "skipped", + "stage": "enhance", + "provider": "default", + "criticality": "degradable", + "language": "zh", + "reason": "language_unsupported", + "action": "skip", + } + + def test_runner_executes_stable_stage_order_and_builds_result(monkeypatch): pipeline = TranscriptionPipeline.__new__(TranscriptionPipeline) @@ -346,7 +497,10 @@ def match(self, request): } -def test_runner_persists_artifacts_and_cleans_generated_audio(tmp_path): +def test_runner_persists_artifacts_and_cleans_generated_audio( + python_artifact_contracts, + tmp_path, +): audio_path = tmp_path / "sample.mp3" audio_path.write_bytes(b"stub-audio") calls = [] @@ -410,6 +564,14 @@ def diarize(self, request): "status": "skipped", "language": "zh", "reason": "language_disabled", + "model": "org/model", + "duration_s": 12.5, + "max_duration_s": 60, + "cache_only": False, + "device": "cpu", + "model_path": str(tmp_path / "private-model"), + "exception": RuntimeError("hidden"), + "debug": {"path": str(tmp_path)}, } }, ) @@ -489,8 +651,12 @@ def match(self, request): assert context.metadata["asr"]["hallucination_guard"]["removed_segment_count"] == 2 assert result["transcription"]["alignment"] == { "status": "skipped", - "language": "zh", "reason": "language_disabled", + "model": "org/model", + "duration_s": 12.5, + "max_duration_s": 60, + "cache_only": False, + "device": "cpu", } assert result["transcription"]["artifacts"] == { "manifest_version": "artifact_manifest.v1", @@ -519,8 +685,12 @@ def match(self, request): ) assert context.metadata["diarization"]["alignment"] == { "status": "skipped", - "language": "zh", "reason": "language_disabled", + "model": "org/model", + "duration_s": 12.5, + "max_duration_s": 60, + "cache_only": False, + "device": "cpu", } assert result["artifact_paths"]["result_path"] == str(result_path) assert result_path.exists() @@ -536,7 +706,9 @@ def match(self, request): assert not audio_path.with_suffix(".denoised.wav").exists() -def test_artifacts_preserve_raw_speaker_labels_when_clusters_match_same_voiceprint(): +def test_artifacts_preserve_raw_speaker_labels_when_clusters_match_same_voiceprint( + python_artifact_contracts, +): aligned_segments = [ { "start": 0.0, @@ -655,6 +827,7 @@ def fake_artifact_manifest_contract(payload): def test_artifact_result_contract_keeps_status_speaker_label_and_optional_alignment( + python_artifact_contracts, tmp_path, ): context = PipelineContext( @@ -696,6 +869,40 @@ def test_artifact_result_contract_keeps_status_speaker_label_and_optional_alignm assert "alignment" not in result +def test_artifacts_omit_alignment_when_metadata_has_no_public_safe_fields( + python_artifact_contracts, + tmp_path, +): + context = PipelineContext( + pipeline=SimpleNamespace(), + request=PipelineRequest( + audio_path=str(tmp_path / "sample.wav"), + artifact_dir=tmp_path / "transcriptions" / "tr_private_alignment", + ), + ) + context.aligned_segments = [ + { + "start": 0.0, + "end": 1.0, + "text": "private", + "speaker": "SPEAKER_00", + } + ] + context.metadata["diarization"] = { + "alignment": { + "language": "zh", + "model_path": str(tmp_path / "private-model"), + "exception": RuntimeError("hidden"), + "debug": {"path": str(tmp_path)}, + } + } + + result = InMemoryArtifactsProvider()._build_transcription(context) + + assert result is not None + assert "alignment" not in result + + def test_runner_uses_explicit_artifacts_provider_selection(): class StubArtifactsProvider: def build(self, context): @@ -764,8 +971,14 @@ def test_runner_cleans_temporary_paths_and_keeps_metadata_on_stage_failure( context = runner.build_context(SimpleNamespace(), request) monkeypatch.setattr(runner, "build_context", lambda pipeline, request: context) - with pytest.raises(RuntimeError, match="enhance exploded"): - runner.run_context(SimpleNamespace(), request) + register_provider("normalize", "norm-stub", object()) + register_provider("enhance", "enhance-stub", object()) + try: + with pytest.raises(RuntimeError, match="enhance exploded"): + runner.run_context(SimpleNamespace(), request) + finally: + unregister_provider("normalize", "norm-stub") + unregister_provider("enhance", "enhance-stub") assert not normalized.exists() assert not enhanced.exists() diff --git a/tests/unit/test_provider_capabilities.py b/tests/unit/test_provider_capabilities.py index 1774732..36d78f4 100644 --- a/tests/unit/test_provider_capabilities.py +++ b/tests/unit/test_provider_capabilities.py @@ -4,8 +4,10 @@ import pytest +from pipeline.registry import available_providers, available_stage_slots import providers.capabilities as capabilities_module from providers.capabilities import ( + ALL_LANGUAGES, ProviderCapability, ProviderCapabilityError, default_provider_capabilities, @@ -28,6 +30,33 @@ def test_default_asr_capability_is_multilingual_required_stage(): assert match.metadata["language"] == "zh" +def test_registry_default_provider_surface_has_static_capability_records(): + for stage in available_stage_slots(): + for provider_name in available_providers(stage): + capability = get_provider_capability(stage, provider_name) + + assert capability.stage == stage + assert capability.name == provider_name + assert ALL_LANGUAGES in capability.supported_languages + + +def test_alignment_capability_is_owned_by_diarization_stage(): + capability = get_provider_capability("alignment") + + assert capability.stage == "diarization" + assert capability.name == "default" + assert capability.capability == "alignment" + assert capability.stage_criticality == "degradable" + assert capability.failure_policy == "skip" + + match = match_provider_capability("alignment", language="zh") + + assert match.should_run is True + assert match.metadata["stage"] == "diarization" + assert match.metadata["capability"] == "alignment" + assert match.metadata["language"] == "zh" + + def test_required_stage_language_mismatch_hard_fails(monkeypatch): custom = ProviderCapability( stage="asr", @@ -56,6 +85,7 @@ def test_degradable_stage_language_mismatch_skips_with_safe_metadata(monkeypatch disabled_languages=frozenset({"zh"}), stage_criticality="degradable", failure_policy="skip", + capability=capability.capability, ), ) @@ -64,9 +94,10 @@ def test_degradable_stage_language_mismatch_skips_with_safe_metadata(monkeypatch assert match.should_run is False assert match.reason == "language_disabled" assert match.metadata == { - "stage": "alignment", + "stage": "diarization", "provider": "default", "criticality": "degradable", + "capability": "alignment", "language": "zh", "reason": "language_disabled", "action": "skip", diff --git a/tests/unit/test_provider_registry.py b/tests/unit/test_provider_registry.py index e5e0bd7..408c295 100644 --- a/tests/unit/test_provider_registry.py +++ b/tests/unit/test_provider_registry.py @@ -25,17 +25,21 @@ from pipeline.registry import ( ProviderNotFoundError, available_providers, + is_provider_override, register_provider, resolve_provider, unregister_provider, ) from providers import maybe_denoise +from providers._registry import ProviderFacadeSelectionError +import providers.enhance as enhance_facade import providers.asr.default as asr_default from providers.asr.default import default_asr_provider import providers.diarization.default as diarization_default from providers.diarization.default import default_diarization_provider from providers.embedding import default_speaker_embedding_provider import providers.embedding.default as embedding_default +import providers.normalize as normalize_facade import pipeline.orchestrator as orchestrator from providers.normalize import convert_to_wav @@ -160,21 +164,61 @@ def iter_segments(): assert "/private" not in caplog.text -def test_registry_named_overrides_drive_compatibility_helpers(tmp_path): +def test_provider_facade_helpers_use_local_default_provider(monkeypatch, tmp_path): input_path = tmp_path / "sample.mp3" input_path.write_bytes(b"stub") - register_provider("normalize", "stub", StubNormalizer()) - register_provider("enhance", "stub", StubEnhancer()) + monkeypatch.setattr( + normalize_facade, "default_normalize_provider", StubNormalizer() + ) + monkeypatch.setattr(enhance_facade, "default_enhance_provider", StubEnhancer()) + + normalized = convert_to_wav(input_path) + enhanced = maybe_denoise(normalized) + + assert normalized.name == "sample.stub.wav" + assert enhanced.name == "sample.stub.boost.wav" + + +def test_registry_named_overrides_do_not_drive_provider_facade_helpers(tmp_path): + input_path = tmp_path / "sample.mp3" + input_path.write_bytes(b"stub") + + assert is_provider_override("normalize", "stub") is False + normalizer = StubNormalizer() + enhancer = StubEnhancer() + register_provider("normalize", "stub", normalizer) + register_provider("enhance", "stub", enhancer) try: - normalized = convert_to_wav(input_path, provider_name="stub") - enhanced = maybe_denoise(normalized, provider_name="stub") + assert is_provider_override("input_normalization", "stub") is True + assert is_provider_override("enhancement", "stub") is True + assert resolve_provider("normalize", "stub") is normalizer + assert resolve_provider("enhance", "stub") is enhancer + with pytest.raises(ProviderFacadeSelectionError, match="PipelineRunner"): + convert_to_wav(input_path, provider_name="stub") + with pytest.raises(ProviderFacadeSelectionError, match="PipelineRunner"): + maybe_denoise(input_path, provider_name="stub") finally: unregister_provider("normalize", "stub") unregister_provider("enhance", "stub") - assert normalized.name == "sample.stub.wav" - assert enhanced.name == "sample.stub.boost.wav" + assert is_provider_override("normalize", "stub") is False + + +def test_provider_package_does_not_reexport_pipeline_registry_helpers(): + import providers + + registry_exports = { + "available_providers", + "available_stage_slots", + "register_provider", + "resolve_provider", + "unregister_provider", + } + + assert registry_exports.isdisjoint(dir(providers)) + for name in registry_exports: + assert not hasattr(providers, name) def test_unknown_provider_raises_lookup_error(): @@ -457,6 +501,56 @@ def fake_load_align_model(language_code, device, model_name): } +def test_default_diarization_provider_skips_alignment_when_audio_duration_exceeds_budget( + monkeypatch, +): + pipeline = TranscriptionPipeline.__new__(TranscriptionPipeline) + pipeline.device = "cpu" + + class FakeDiarizationResult: + def itertracks(self, yield_label=False): + assert yield_label is True + yield SimpleNamespace(start=0.0, end=1.2), None, "SPEAKER_00" + + class FakeDiarizer: + def __call__(self, audio_path, **kwargs): + return FakeDiarizationResult() + + pipeline._diarization = FakeDiarizer() + monkeypatch.setattr( + diarization_default, + "audio_duration_seconds", + lambda audio_path: 7201.0, + ) + monkeypatch.setattr( + diarization_default, "WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC", 7200.0 + ) + monkeypatch.setattr( + sys.modules["whisperx"], + "load_audio", + lambda audio_path: (_ for _ in ()).throw( + AssertionError("whisperx.load_audio should not run") + ), + raising=False, + ) + + result = default_diarization_provider.diarize( + DiarizationRequest( + pipeline=pipeline, + audio_path="long.wav", + transcription_result={ + "segments": [{"start": 0.0, "end": 1.2, "text": "hello"}], + "language": "en", + }, + ) + ) + + assert result.metadata["alignment"]["status"] == "skipped" + assert result.metadata["alignment"]["reason"] == "duration_budget_exceeded" + assert result.metadata["alignment"]["duration_s"] == 7201.0 + assert result.aligned_segments[0]["speaker"] == "SPEAKER_00" + + def test_default_diarization_provider_applies_model_dir_and_cache_only( monkeypatch, ): @@ -483,6 +577,12 @@ def __call__(self, audio_path, **kwargs): monkeypatch.setattr(diarization_default, "WHISPERX_ALIGN_CACHE_ONLY", True) monkeypatch.delenv("HF_HUB_OFFLINE", raising=False) monkeypatch.delenv("TRANSFORMERS_OFFLINE", raising=False) + hub_constants = ModuleType("huggingface_hub.constants") + hub_constants.HF_HUB_OFFLINE = False + transformers_hub = ModuleType("transformers.utils.hub") + transformers_hub._is_offline_mode = False + monkeypatch.setitem(sys.modules, "huggingface_hub.constants", hub_constants) + monkeypatch.setitem(sys.modules, "transformers.utils.hub", transformers_hub) whisperx = sys.modules["whisperx"] monkeypatch.setattr( whisperx, @@ -500,6 +600,8 @@ def fake_load_align_model(language_code, device, model_name, model_dir): model_dir, os.environ.get("HF_HUB_OFFLINE"), os.environ.get("TRANSFORMERS_OFFLINE"), + hub_constants.HF_HUB_OFFLINE, + transformers_hub._is_offline_mode, ) ) return "align-model", {"language": language_code, "device": device} @@ -531,10 +633,12 @@ def fake_load_align_model(language_code, device, model_name, model_dir): ) assert calls == [ - ("zh", "cpu", "safe/zh-align-model", "/cache", "1", "1"), + ("zh", "cpu", "safe/zh-align-model", "/cache", "1", "1", True, True), ] assert os.environ.get("HF_HUB_OFFLINE") is None assert os.environ.get("TRANSFORMERS_OFFLINE") is None + assert hub_constants.HF_HUB_OFFLINE is False + assert transformers_hub._is_offline_mode is False assert result.metadata["alignment"]["cache_only"] is True @@ -1108,6 +1212,76 @@ def copy(self): ] +def test_default_embedding_provider_skips_full_preload_when_duration_exceeds_budget( + monkeypatch, +): + pipeline = TranscriptionPipeline.__new__(TranscriptionPipeline) + pipeline.device = "cpu" + calls = [] + + class FakeTensor: + def __init__(self, channels, frames): + self.shape = (channels, frames) + + def mean(self, dim=0, keepdim=True): + assert dim == 0 + return FakeTensor(1, self.shape[1]) + + def to(self, device): + calls.append(("to", device, self.shape[1])) + return self + + class FakeEmbeddingModel: + def __call__(self, payload): + calls.append(("embedding_model", payload["waveform"].shape[1])) + return [float(payload["waveform"].shape[1]), 3.0] + + class FakeInfo: + sample_rate = 16000 + + pipeline._embedding_model = FakeEmbeddingModel() + monkeypatch.setattr( + embedding_default, "audio_duration_seconds", lambda path: 1801.0 + ) + monkeypatch.setattr( + embedding_default, + "EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC", + 1800.0, + ) + monkeypatch.setattr( + embedding_default.sf, + "read", + lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError("soundfile.read should not preload over-budget audio") + ), + raising=False, + ) + monkeypatch.setattr( + embedding_default.torchaudio, "info", lambda audio_path: FakeInfo() + ) + monkeypatch.setattr( + embedding_default.torchaudio, + "load", + lambda audio_path, frame_offset, num_frames: (FakeTensor(1, num_frames), 16000), + ) + + result = default_speaker_embedding_provider.extract_embeddings( + SpeakerEmbeddingRequest( + pipeline=pipeline, + audio_path="long.wav", + diarization_turns=[ + {"speaker": "SPEAKER_00", "start": 0.0, "end": 2.0}, + ], + ) + ) + + assert result.speaker_embeddings["SPEAKER_00"].tolist() == [32000.0, 3.0] + assert calls == [ + ("to", "cpu", 32000), + ("embedding_model", 32000), + ] + + def test_default_embedding_provider_uses_selected_device_after_first_lazy_load( monkeypatch, ): diff --git a/tests/unit/test_transcription_records.py b/tests/unit/test_transcription_records.py new file mode 100644 index 0000000..573c7ec --- /dev/null +++ b/tests/unit/test_transcription_records.py @@ -0,0 +1,314 @@ +"""Focused tests for application-level transcription record usecases.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + + +def _settings(records, tmp_path: Path): + transcriptions_dir = tmp_path / "transcriptions" + uploads_dir = tmp_path / "uploads" + transcriptions_dir.mkdir() + uploads_dir.mkdir() + return records.TranscriptionRecordSettings( + transcriptions_dir=transcriptions_dir, + uploads_dir=uploads_dir, + ) + + +def _seed_result( + transcriptions_dir: Path, + tr_id: str, + *, + filename: str = "audio.wav", + raw_text: str | None = None, +) -> Path: + tr_dir = transcriptions_dir / tr_id + tr_dir.mkdir(parents=True, exist_ok=True) + result_path = tr_dir / "result.json" + if raw_text is not None: + result_path.write_text(raw_text, encoding="utf-8") + return result_path + + payload = { + "id": tr_id, + "filename": filename, + "created_at": "2026-04-25T00:00:00+00:00", + "segments": [ + { + "id": 1, + "start": None, + "end": float("nan"), + "speaker_label": "SPEAKER_00", + "speaker_name": "Maple\r\nInjected", + "speaker_id": "spk_old", + "text": "hello", + }, + { + "id": 2, + "start": 61.25, + "end": 62.0, + "speaker_label": "SPEAKER_01", + "speaker_name": "Guest", + "speaker_id": None, + "text": "world", + }, + ], + "unique_speakers": ["Maple\r\nInjected", "Guest"], + "speaker_map": {}, + } + result_path.write_text(json.dumps(payload), encoding="utf-8") + return result_path + + +def test_job_status_recovery_uses_runtime_jobs_then_disk(tmp_path): + from application import transcription_records as records + + settings = _settings(records, tmp_path) + runtime_jobs = { + "tr_memory_done": { + "status": "completed", + "filename": "done.wav", + "result": {"id": "tr_memory_done"}, + }, + "tr_memory_failed": { + "status": "failed", + "filename": "failed.wav", + "error": "boom", + }, + } + + assert records.get_job_status( + "tr_memory_done", + settings=settings, + runtime_jobs=runtime_jobs, + ) == { + "id": "tr_memory_done", + "status": "completed", + "filename": "done.wav", + "result": {"id": "tr_memory_done"}, + } + assert ( + records.get_job_status( + "tr_memory_failed", + settings=settings, + runtime_jobs=runtime_jobs, + )["error"] + == "boom" + ) + + done_dir = settings.transcriptions_dir / "tr_disk_done" + done_dir.mkdir() + (done_dir / "status.json").write_text( + json.dumps({"status": "completed", "filename": "disk.wav"}), + encoding="utf-8", + ) + (done_dir / "result.json").write_text("{bad-json", encoding="utf-8") + disk_done = records.get_job_status("tr_disk_done", settings=settings) + assert disk_done["status"] == "completed" + assert disk_done["result"] is None + + queued_dir = settings.transcriptions_dir / "tr_disk_queued" + queued_dir.mkdir() + (queued_dir / "status.json").write_text( + json.dumps({"status": "queued", "filename": "queued.wav"}), + encoding="utf-8", + ) + disk_queued = records.get_job_status("tr_disk_queued", settings=settings) + assert disk_queued == { + "id": "tr_disk_queued", + "status": "failed", + "error": "Process restarted while job was in progress", + "filename": "queued.wav", + } + + corrupt_dir = settings.transcriptions_dir / "tr_badstatus" + corrupt_dir.mkdir() + (corrupt_dir / "status.json").write_text("{not-json", encoding="utf-8") + with pytest.raises(records.TranscriptionRecordError) as exc_info: + records.get_job_status("tr_badstatus", settings=settings) + assert exc_info.value.reason == "job_not_found" + assert str(exc_info.value) == "Job not found" + + +def test_record_listing_artifact_audio_and_exports(tmp_path): + from application import transcription_records as records + + settings = _settings(records, tmp_path) + tr_id = "tr_record_edges" + _seed_result(settings.transcriptions_dir, tr_id, filename="route_audio.wav") + _seed_result(settings.transcriptions_dir, "tr_corrupt", raw_text="{bad-json") + + listing = records.list_transcriptions(settings=settings) + assert [row for row in listing if row["id"] == tr_id and row["segment_count"] == 2] + + with pytest.raises(records.TranscriptionRecordError) as missing_audio: + records.get_audio_artifact(tr_id, settings=settings) + assert missing_audio.value.reason == "missing_audio" + assert str(missing_audio.value) == "Original audio file not found" + + (settings.uploads_dir / "route_audio.wav").write_bytes(b"audio") + audio = records.get_audio_artifact(tr_id, settings=settings) + assert audio.path == settings.uploads_dir / "route_audio.wav" + assert audio.filename == "route_audio.wav" + + srt = records.build_export_payload(tr_id, "srt", settings=settings) + assert srt.text is not None + assert srt.file_path is None + assert srt.media_type == "text/srt" + assert srt.filename == f"{tr_id}.srt" + assert "00:00:00,000 --> 00:00:00,000" in srt.text + assert "[Maple Injected] hello" in srt.text + + txt = records.build_export_payload(tr_id, "txt", settings=settings) + assert txt.text == "[00:00] Maple Injected: hello\n[01:01] Guest: world" + assert txt.media_type == "text/plain" + + exported_json = records.build_export_payload(tr_id, "json", settings=settings) + assert exported_json.text is None + assert ( + exported_json.file_path == settings.transcriptions_dir / tr_id / "result.json" + ) + assert exported_json.media_type == "application/json" + assert exported_json.filename == f"{tr_id}.json" + + with pytest.raises(records.TranscriptionRecordError) as unsupported: + records.build_export_payload(tr_id, "vtt", settings=settings) + assert unsupported.value.reason == "unsupported_export_format" + assert str(unsupported.value) == "Unsupported format. Use: srt, txt, json" + + for operation in ( + records.load_transcription_result, + records.get_audio_artifact, + lambda bad_id, *, settings: records.build_export_payload( + bad_id, + "txt", + settings=settings, + ), + ): + with pytest.raises(records.TranscriptionRecordError) as corrupt: + operation("tr_corrupt", settings=settings) + assert corrupt.value.reason == "corrupt_result" + assert str(corrupt.value) == "Corrupt transcription artifact" + + +@pytest.mark.parametrize( + "filename", + [ + "../outside.wav", + "/" + "outside.wav", + ], +) +def test_audio_artifact_rejects_result_filename_that_escapes_uploads( + tmp_path, + filename, +): + from application import transcription_records as records + + settings = _settings(records, tmp_path) + tr_id = "tr_unsafe_audio_name" + outside_audio = tmp_path / "outside.wav" + outside_audio.write_bytes(b"outside") + _seed_result(settings.transcriptions_dir, tr_id, filename=filename) + + with pytest.raises(records.TranscriptionRecordError) as exc_info: + records.get_audio_artifact(tr_id, settings=settings) + + assert exc_info.value.reason == "corrupt_result" + assert str(exc_info.value) == "Corrupt transcription artifact" + + +def test_speaker_reassignment_validates_voiceprint_and_updates_result(tmp_path): + from application import transcription_records as records + + settings = _settings(records, tmp_path) + tr_id = "tr_speaker_edges" + _seed_result(settings.transcriptions_dir, tr_id) + + class FakeDB: + def __init__(self, found: bool) -> None: + self.found = found + + def get_speaker(self, speaker_id): + return {"id": speaker_id} if self.found else None + + with pytest.raises(records.TranscriptionRecordError) as invalid_id: + records.reassign_speaker( + tr_id, + 1, + "Maple", + "not-safe", + voiceprint_db=FakeDB(found=True), + settings=settings, + ) + assert invalid_id.value.reason == "invalid_speaker_id" + assert str(invalid_id.value) == "Invalid speaker_id format" + + with pytest.raises(records.TranscriptionRecordError) as missing_voiceprint: + records.reassign_speaker( + tr_id, + 1, + "Maple", + "spk_missing", + voiceprint_db=FakeDB(found=False), + settings=settings, + ) + assert missing_voiceprint.value.reason == "missing_voiceprint" + assert str(missing_voiceprint.value) == "Voiceprint spk_missing not found" + + assert records.reassign_speaker( + tr_id, + 1, + "Maple", + "spk_known", + voiceprint_db=FakeDB(found=True), + settings=settings, + ) == {"ok": True} + assert records.reassign_speaker( + tr_id, + 2, + "Maple", + None, + settings=settings, + ) == {"ok": True} + + data = json.loads((settings.transcriptions_dir / tr_id / "result.json").read_text()) + assert data["segments"][0]["speaker_id"] == "spk_known" + assert data["segments"][1]["speaker_id"] is None + assert data["unique_speakers"] == ["Maple"] + + with pytest.raises(records.TranscriptionRecordError) as missing_segment: + records.reassign_speaker(tr_id, 99, "Nobody", settings=settings) + assert missing_segment.value.reason == "segment_not_found" + assert str(missing_segment.value) == "Segment not found" + + +def test_transcription_records_module_stays_out_of_api_ring(): + from application import transcription_records as records + + source = Path(records.__file__).read_text(encoding="utf-8") + + assert "fastapi" not in source + assert "HTTPException" not in source + assert "UploadFile" not in source + assert "api." not in source + + +def test_transcription_records_module_delegates_filesystem_details_to_infra(): + from application import transcription_records as records + + source = Path(records.__file__).read_text(encoding="utf-8") + + assert "from infra.transcription_records import" in source + assert "json.loads" not in source + assert "read_text(" not in source + assert "write_text(" not in source + assert "_atomic_write_json" not in source + assert "iterdir(" not in source + assert ' / "status.json"' not in source + assert ' / "result.json"' not in source + assert "PurePosixPath" not in source + assert "PureWindowsPath" not in source diff --git a/tests/unit/test_transcription_submission.py b/tests/unit/test_transcription_submission.py new file mode 100644 index 0000000..8a90cd9 --- /dev/null +++ b/tests/unit/test_transcription_submission.py @@ -0,0 +1,310 @@ +"""Focused tests for application-level transcription upload submission.""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +from pathlib import Path +from types import SimpleNamespace + +import infra.job_runtime as job_runtime + + +class MemoryUpload: + def __init__(self, filename: str, content: bytes) -> None: + self.filename = filename + self._content = content + self._offset = 0 + + async def read(self, size: int = -1) -> bytes: + if self._offset >= len(self._content): + return b"" + if size is None or size < 0: + size = len(self._content) - self._offset + chunk = self._content[self._offset : self._offset + size] + self._offset += len(chunk) + return chunk + + +def _submission_settings( + *, + uploads_dir: Path, + transcriptions_dir: Path, + min_free_disk_bytes: int = 0, +): + from application import transcription_submission as submission + + return submission.TranscriptionSubmissionSettings( + max_upload_bytes=1024, + upload_chunk=8, + max_active_jobs=2, + max_in_flight_jobs=2, + uploads_dir=uploads_dir, + transcriptions_dir=transcriptions_dir, + min_free_disk_bytes=min_free_disk_bytes, + denoise_max_audio_duration_sec=111.0, + embedding_preload_max_audio_duration_sec=222.0, + whisperx_align_max_audio_duration_sec=333.0, + ) + + +async def _write_upload(file, save_path, max_upload_bytes, upload_chunk): + del max_upload_bytes + sha256 = hashlib.sha256() + size = 0 + with save_path.open("wb") as handle: + while chunk := await file.read(upload_chunk): + size += len(chunk) + handle.write(chunk) + sha256.update(chunk) + return size, sha256.hexdigest() + + +def test_submit_transcription_upload_retains_failed_record_on_thread_start_failure( + tmp_path, + monkeypatch, +): + """Thread-start failure cleans transient state but keeps durable failure.""" + from application import transcription_submission as submission + import infra.job_persistence as job_persistence + + transcriptions_dir = tmp_path / "transcriptions" + uploads_dir = tmp_path / "uploads" + transcriptions_dir.mkdir() + uploads_dir.mkdir() + + monkeypatch.setattr(job_runtime, "_active_job_ids", set()) + monkeypatch.setattr(job_runtime, "_in_flight_hashes", {}) + monkeypatch.setattr(job_persistence, "TRANSCRIPTIONS_DIR", transcriptions_dir) + + audio = b"RIFF\x00\x00\x00\x00WAVEfmt thread-start" + file_hash = hashlib.sha256(audio).hexdigest() + started = [] + + class FailingThread: + def __init__(self, *args, **kwargs): + started.append(("created", args, kwargs)) + + def start(self): + started.append(("started",)) + raise RuntimeError("thread boom") + + def worker(*args, **kwargs): + raise AssertionError("worker should not run synchronously") + + async def upload_saver(file, save_path, max_upload_bytes, upload_chunk): + del max_upload_bytes + sha256 = hashlib.sha256() + size = 0 + with save_path.open("wb") as handle: + while chunk := await file.read(upload_chunk): + size += len(chunk) + handle.write(chunk) + sha256.update(chunk) + return size, sha256.hexdigest() + + monkeypatch.setattr(submission, "Thread", FailingThread) + monkeypatch.setattr(submission, "run_transcription", worker) + + try: + asyncio.run( + submission.submit_transcription_upload( + submission.TranscriptionSubmissionCommand( + file=MemoryUpload("../-y\nattack.wav", audio), + pipeline=object(), + voiceprint_db=object(), + language="en", + min_speakers=0, + max_speakers=0, + denoise_model=None, + snr_threshold=None, + no_repeat_ngram_size=2, + ), + settings=submission.TranscriptionSubmissionSettings( + max_upload_bytes=1024, + upload_chunk=8, + max_active_jobs=2, + max_in_flight_jobs=2, + uploads_dir=uploads_dir, + transcriptions_dir=transcriptions_dir, + ), + job_id_factory=lambda: "tr_submit", + upload_saver=upload_saver, + ) + ) + except submission.TranscriptionSubmissionError as exc: + submission_error = exc + else: + raise AssertionError("submit_transcription_upload should fail thread startup") + + assert submission_error.reason == "thread_start_failed" + assert ( + str(submission_error) + == "Failed to start background transcription — retry later" + ) + assert [entry[0] for entry in started] == ["created", "started"] + created_kwargs = started[0][2] + thread_args = created_kwargs["args"] + assert created_kwargs["target"] is worker + assert created_kwargs["daemon"] is True + assert thread_args[0] == "tr_submit" + assert thread_args[-1] == 0 + + assert "tr_submit" in submission.jobs + assert submission.jobs["tr_submit"]["status"] == "failed" + assert ( + submission.jobs["tr_submit"]["error"] + == "Failed to start background transcription" + ) + assert list(uploads_dir.iterdir()) == [] + assert job_runtime.lookup_in_flight(file_hash) is None + assert job_runtime.active_job_count() == 0 + + status_path = transcriptions_dir / "tr_submit" / "status.json" + assert status_path.exists() + status = json.loads(status_path.read_text(encoding="utf-8")) + assert status["status"] == "failed" + assert status["filename"].startswith("-y") + assert "\n" not in status["filename"] + + +def test_submit_transcription_upload_rejects_data_disk_pressure_before_bootstrap( + tmp_path, + monkeypatch, +): + from application import transcription_submission as submission + + transcriptions_dir = tmp_path / "transcriptions" + uploads_dir = tmp_path / "uploads" + transcriptions_dir.mkdir() + uploads_dir.mkdir() + + monkeypatch.setattr(job_runtime, "_active_job_ids", set()) + monkeypatch.setattr(job_runtime, "_in_flight_hashes", {}) + monkeypatch.setattr(job_runtime, "jobs", job_runtime._LRUJobsDict(maxsize=200)) + + started = [] + status_writes = [] + audio = b"RIFF pressure" + file_hash = hashlib.sha256(audio).hexdigest() + + class RecordingThread: + def __init__(self, *args, **kwargs): + started.append(("created", args, kwargs)) + + def start(self): + started.append(("started",)) + + try: + asyncio.run( + submission.submit_transcription_upload( + submission.TranscriptionSubmissionCommand( + file=MemoryUpload("pressure.wav", audio), + pipeline=object(), + voiceprint_db=object(), + ), + settings=_submission_settings( + uploads_dir=uploads_dir, + transcriptions_dir=transcriptions_dir, + min_free_disk_bytes=1024, + ), + job_id_factory=lambda: "tr_pressure", + thread_factory=RecordingThread, + upload_saver=_write_upload, + status_writer=lambda *args, **kwargs: ( + status_writes.append((args, kwargs)) or True + ), + disk_usage=lambda path: SimpleNamespace(free=1023), + audio_duration_reader=lambda path: 12.5, + ) + ) + except submission.TranscriptionSubmissionError as exc: + submission_error = exc + else: + raise AssertionError("submit_transcription_upload should reject disk pressure") + + assert submission_error.reason == "data_disk_pressure" + assert "data disk free space" in str(submission_error) + assert list(uploads_dir.iterdir()) == [] + assert list(transcriptions_dir.iterdir()) == [] + assert status_writes == [] + assert started == [] + assert "tr_pressure" not in submission.jobs + assert submission.release_transcription_admission("tr_pressure") is False + assert submission.find_in_flight_transcription(file_hash) is None + + +def test_submit_transcription_upload_records_admission_snapshot( + tmp_path, + monkeypatch, +): + from application import transcription_submission as submission + + transcriptions_dir = tmp_path / "transcriptions" + uploads_dir = tmp_path / "uploads" + transcriptions_dir.mkdir() + uploads_dir.mkdir() + + monkeypatch.setattr(job_runtime, "_active_job_ids", set()) + monkeypatch.setattr(job_runtime, "_in_flight_hashes", {}) + monkeypatch.setattr(job_runtime, "jobs", job_runtime._LRUJobsDict(maxsize=200)) + + started = [] + + class RecordingThread: + def __init__(self, *args, **kwargs): + started.append(("created", args, kwargs)) + + def start(self): + started.append(("started",)) + + result = asyncio.run( + submission.submit_transcription_upload( + submission.TranscriptionSubmissionCommand( + file=MemoryUpload("snapshot.wav", b"RIFF snapshot"), + pipeline=object(), + voiceprint_db=object(), + ), + settings=_submission_settings( + uploads_dir=uploads_dir, + transcriptions_dir=transcriptions_dir, + min_free_disk_bytes=1024, + ), + job_id_factory=lambda: "tr_snapshot", + thread_factory=RecordingThread, + upload_saver=_write_upload, + status_writer=lambda *args, **kwargs: True, + disk_usage=lambda path: SimpleNamespace(free=2048), + audio_duration_reader=lambda path: 42.25, + ) + ) + + assert result.job_id == "tr_snapshot" + assert result.status == "queued" + assert [entry[0] for entry in started] == ["created", "started"] + + admission = submission.jobs["tr_snapshot"]["admission"] + assert admission["active_jobs"] == 0 + assert admission["in_flight_jobs"] == 0 + assert admission["data_disk"] == { + "free_bytes": 2048, + "min_free_bytes": 1024, + } + assert admission["memory_sensitive_stage_limits"] == { + "DENOISE_MAX_AUDIO_DURATION_SEC": 111.0, + "EMBEDDING_PRELOAD_MAX_AUDIO_DURATION_SEC": 222.0, + "WHISPERX_ALIGN_MAX_AUDIO_DURATION_SEC": 333.0, + } + assert admission["audio_duration_seconds"] == 42.25 + + +def test_transcription_submission_module_stays_out_of_api_ring(): + from application import transcription_submission as submission + + source = Path(submission.__file__).read_text(encoding="utf-8") + + assert "fastapi" not in source + assert "HTTPException" not in source + assert "UploadFile" not in source + assert "api." not in source diff --git a/tests/unit/test_voiceprint_db.py b/tests/unit/test_voiceprint_db.py index a65b3eb..8121d47 100644 --- a/tests/unit/test_voiceprint_db.py +++ b/tests/unit/test_voiceprint_db.py @@ -462,9 +462,12 @@ def _fail_rebuild(self, transcriptions_dir: str, save_path: str | None = None): def test_concurrent_upload_dedup_reuses_single_live_job(app_client, monkeypatch): """Two simultaneous uploads of the same bytes must dedup to one queued job.""" - transcriptions = importlib.import_module("api.routers.transcriptions") + records = importlib.import_module("application.transcription_records") + submission = importlib.import_module("application.transcription_submission") audio_infra = importlib.import_module("infra.audio") + job_persistence = importlib.import_module("infra.job_persistence") job_runtime = importlib.import_module("infra.job_runtime") + record_settings = records.default_record_settings() started = threading.Event() release = threading.Event() @@ -488,14 +491,14 @@ def _fake_run_transcription( started.set() assert release.wait(timeout=5), "test timed out waiting to release worker" - transcriptions.jobs[job_id]["status"] = "completed" - transcriptions.jobs[job_id]["result"] = { + submission.jobs[job_id]["status"] = "completed" + submission.jobs[job_id]["result"] = { "id": job_id, "segments": [], "unique_speakers": [], } - transcriptions._write_status(job_id, "completed", filename=audio_path.name) - out_dir = transcriptions.TRANSCRIPTIONS_DIR / job_id + job_persistence._write_status(job_id, "completed", filename=audio_path.name) + out_dir = record_settings.transcriptions_dir / job_id out_dir.mkdir(parents=True, exist_ok=True) (out_dir / "result.json").write_text( json.dumps({"id": job_id, "segments": [], "unique_speakers": []}) @@ -505,7 +508,7 @@ def _fake_run_transcription( job_runtime.unregister_in_flight(file_hash) finished.set() - monkeypatch.setattr(transcriptions, "run_transcription", _fake_run_transcription) + monkeypatch.setattr(submission, "run_transcription", _fake_run_transcription) barrier = threading.Barrier(2) diff --git a/voscript-api/scripts/architecture_gate.py b/voscript-api/scripts/architecture_gate.py new file mode 100644 index 0000000..b15fce7 --- /dev/null +++ b/voscript-api/scripts/architecture_gate.py @@ -0,0 +1,1077 @@ +#!/usr/bin/env python3 +"""Report and check VoScript source-level architecture gates.""" + +from __future__ import annotations + +import argparse +import ast +import json +import sys +from collections import defaultdict +from pathlib import Path +from typing import Any + +CORE_RINGS = { + "api_composition", + "application", + "pipeline", + "providers", + "domain", + "infra", +} + +REGISTRY_RUNTIME_IMPORTS = { + "pipeline.registry": { + "_DEFAULT_STAGE_IMPORTS": "registry_stage", + "_DEFAULT_PROVIDER_IMPORTS": "registry_provider", + }, +} + +STATUS_CONTRACT_HELPERS = frozenset( + { + "IN_PROGRESS_JOB_STATUSES", + "JOB_STATUS_COMPLETED", + "JOB_STATUS_CONVERTING", + "JOB_STATUS_DENOISING", + "JOB_STATUS_FAILED", + "JOB_STATUS_IDENTIFYING", + "JOB_STATUS_QUEUED", + "JOB_STATUS_TRANSCRIBING", + "KNOWN_JOB_STATUSES", + "TERMINAL_JOB_STATUSES", + "build_status_payload", + "normalize_job_status", + "normalize_status_payload", + } +) + +PIPELINE_METADATA_CONTRACT_MODULE = "pipeline.contracts.metadata" + + +def _module_name(app_root: Path, path: Path) -> tuple[str, bool]: + relative = path.relative_to(app_root).with_suffix("") + parts = relative.parts + if parts[-1] == "__init__": + return ".".join(parts[:-1]), True + return ".".join(parts), False + + +def app_modules(root: Path) -> dict[str, Path]: + app_root = root / "app" + modules: dict[str, Path] = {} + for path in sorted(app_root.rglob("*.py")): + if "__pycache__" in path.parts: + continue + module_name, _ = _module_name(app_root, path) + modules[module_name] = path + return modules + + +def _resolve_relative_import( + current_module: str, + is_package: bool, + *, + level: int, + module: str | None, +) -> str: + package_parts = ( + current_module.split(".") if is_package else current_module.split(".")[:-1] + ) + prefix = package_parts[: len(package_parts) - level + 1] + if module: + prefix.extend(module.split(".")) + return ".".join(prefix) + + +def _strip_app_prefix(module_name: str) -> str: + if module_name == "app": + return "" + if module_name.startswith("app."): + return module_name.removeprefix("app.") + return module_name + + +def _internal_module_for(module_name: str, modules: set[str]) -> str | None: + candidate = _strip_app_prefix(module_name) + if not candidate: + return None + parts = candidate.split(".") + for end in range(len(parts), 0, -1): + prefix = ".".join(parts[:end]) + if prefix in modules: + return prefix + return None + + +def _is_type_checking_guard(node: ast.expr) -> bool: + if isinstance(node, ast.Name): + return node.id == "TYPE_CHECKING" + if isinstance(node, ast.Attribute): + return node.attr == "TYPE_CHECKING" + return False + + +class _ImportCollector(ast.NodeVisitor): + def __init__( + self, + *, + current_module: str, + is_package: bool, + modules: set[str], + ) -> None: + self.current_module = current_module + self.is_package = is_package + self.modules = modules + self.targets: set[str] = set() + self._type_checking_depth = 0 + + def visit_If(self, node: ast.If) -> None: + if _is_type_checking_guard(node.test): + self._type_checking_depth += 1 + for child in node.body: + self.visit(child) + self._type_checking_depth -= 1 + for child in node.orelse: + self.visit(child) + return + self.generic_visit(node) + + def visit_Import(self, node: ast.Import) -> None: + if self._type_checking_depth: + return + for alias in node.names: + self._add_internal_target(alias.name) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if self._type_checking_depth: + return + if node.level: + base = _resolve_relative_import( + self.current_module, + self.is_package, + level=node.level, + module=node.module, + ) + else: + base = node.module or "" + + if base: + self._add_internal_target(base) + for alias in node.names: + if alias.name == "*": + continue + target = f"{base}.{alias.name}" if base else alias.name + self._add_internal_target(target) + + def _add_internal_target(self, module_name: str) -> None: + target = _internal_module_for(module_name, self.modules) + if target is not None and target != self.current_module: + self.targets.add(target) + + +def internal_import_graph(root: Path) -> dict[str, set[str]]: + module_paths = app_modules(root) + modules = set(module_paths) + app_root = root / "app" + graph: dict[str, set[str]] = {module: set() for module in modules} + for module, path in module_paths.items(): + _, is_package = _module_name(app_root, path) + collector = _ImportCollector( + current_module=module, + is_package=is_package, + modules=modules, + ) + collector.visit(ast.parse(path.read_text(encoding="utf-8"), filename=str(path))) + graph[module].update(collector.targets) + return graph + + +def _import_module_bindings(tree: ast.AST) -> tuple[set[str], set[str]]: + direct_names: set[str] = set() + module_names: set[str] = set() + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom) and node.module == "importlib": + for alias in node.names: + if alias.name == "import_module": + direct_names.add(alias.asname or alias.name) + elif isinstance(node, ast.Import): + for alias in node.names: + if alias.name == "importlib": + module_names.add(alias.asname or alias.name) + return direct_names, module_names + + +def _is_import_module_call( + call: ast.Call, + *, + direct_names: set[str], + module_names: set[str], +) -> bool: + if isinstance(call.func, ast.Name): + return call.func.id in direct_names + if not isinstance(call.func, ast.Attribute) or call.func.attr != "import_module": + return False + return isinstance(call.func.value, ast.Name) and call.func.value.id in module_names + + +def _iter_string_constants(node: ast.AST) -> list[ast.Constant]: + strings: list[ast.Constant] = [] + if isinstance(node, ast.Constant) and isinstance(node.value, str): + strings.append(node) + elif isinstance(node, ast.Dict): + for value in node.values: + if value is not None: + strings.extend(_iter_string_constants(value)) + elif isinstance(node, (ast.List, ast.Tuple, ast.Set)): + for value in node.elts: + strings.extend(_iter_string_constants(value)) + return strings + + +def _runtime_module_name(import_value: str) -> str: + module_name, _, _ = import_value.partition(":") + return module_name.strip() + + +def _add_runtime_edge( + edges_by_key: dict[tuple[str, str, str, str], dict[str, Any]], + *, + root: Path, + source: str, + target: str, + kind: str, + import_value: str, + path: Path, + lineno: int, +) -> None: + if target == source: + return + key = (source, target, kind, import_value) + edge = edges_by_key.setdefault( + key, + { + "source": source, + "target": target, + "kind": kind, + "import": import_value, + "locations": [], + }, + ) + edge["locations"].append({"path": str(path.relative_to(root)), "line": lineno}) + + +def _registry_runtime_edges( + *, + root: Path, + module: str, + path: Path, + tree: ast.AST, + modules: set[str], + edges_by_key: dict[tuple[str, str, str, str], dict[str, Any]], +) -> None: + registry_imports = REGISTRY_RUNTIME_IMPORTS.get(module) + if not registry_imports: + return + + for node in ast.walk(tree): + if not isinstance(node, ast.Assign): + continue + target_names = { + target.id for target in node.targets if isinstance(target, ast.Name) + } + for assignment_name in sorted(target_names & registry_imports.keys()): + kind = registry_imports[assignment_name] + for string_node in _iter_string_constants(node.value): + import_value = _runtime_module_name(string_node.value) + target = _internal_module_for(import_value, modules) + if target is not None: + _add_runtime_edge( + edges_by_key, + root=root, + source=module, + target=target, + kind=kind, + import_value=import_value, + path=path, + lineno=string_node.lineno, + ) + + +def _literal_import_module_edges( + *, + root: Path, + module: str, + path: Path, + tree: ast.AST, + modules: set[str], + edges_by_key: dict[tuple[str, str, str, str], dict[str, Any]], +) -> None: + direct_names, module_names = _import_module_bindings(tree) + if not direct_names and not module_names: + return + + for node in ast.walk(tree): + if not isinstance(node, ast.Call) or not node.args: + continue + if not _is_import_module_call( + node, + direct_names=direct_names, + module_names=module_names, + ): + continue + import_arg = node.args[0] + if not isinstance(import_arg, ast.Constant) or not isinstance( + import_arg.value, + str, + ): + continue + import_value = _runtime_module_name(import_arg.value) + target = _internal_module_for(import_value, modules) + if target is not None: + _add_runtime_edge( + edges_by_key, + root=root, + source=module, + target=target, + kind="literal_import_module", + import_value=import_value, + path=path, + lineno=import_arg.lineno, + ) + + +def runtime_dynamic_import_edges(root: Path) -> list[dict[str, Any]]: + module_paths = app_modules(root) + modules = set(module_paths) + edges_by_key: dict[tuple[str, str, str, str], dict[str, Any]] = {} + + for module, path in module_paths.items(): + tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) + _registry_runtime_edges( + root=root, + module=module, + path=path, + tree=tree, + modules=modules, + edges_by_key=edges_by_key, + ) + _literal_import_module_edges( + root=root, + module=module, + path=path, + tree=tree, + modules=modules, + edges_by_key=edges_by_key, + ) + + edges = sorted( + edges_by_key.values(), + key=lambda item: ( + item["source"], + item["target"], + item["kind"], + item["import"], + ), + ) + for edge in edges: + edge["locations"] = sorted( + edge["locations"], + key=lambda item: (item["path"], item["line"]), + ) + return edges + + +def runtime_dynamic_import_graph(root: Path) -> dict[str, set[str]]: + graph: dict[str, set[str]] = {module: set() for module in app_modules(root)} + for edge in runtime_dynamic_import_edges(root): + graph[edge["source"]].add(edge["target"]) + return graph + + +def strongly_connected_components(graph: dict[str, set[str]]) -> list[tuple[str, ...]]: + index = 0 + indexes: dict[str, int] = {} + lowlinks: dict[str, int] = {} + stack: list[str] = [] + on_stack: set[str] = set() + components: list[tuple[str, ...]] = [] + + def strongconnect(node: str) -> None: + nonlocal index + indexes[node] = index + lowlinks[node] = index + index += 1 + stack.append(node) + on_stack.add(node) + + for target in sorted(graph[node]): + if target not in indexes: + strongconnect(target) + lowlinks[node] = min(lowlinks[node], lowlinks[target]) + elif target in on_stack: + lowlinks[node] = min(lowlinks[node], indexes[target]) + + if lowlinks[node] == indexes[node]: + component: list[str] = [] + while True: + target = stack.pop() + on_stack.remove(target) + component.append(target) + if target == node: + break + if len(component) > 1: + components.append(tuple(sorted(component))) + + for module in sorted(graph): + if module not in indexes: + strongconnect(module) + return sorted(components) + + +def ring_for_module(module: str) -> str: + root = module.split(".", 1)[0] + if module == "main" or root == "api": + return "api_composition" + if root == "application": + return "application" + if root == "pipeline": + return "pipeline" + if root == "providers": + return "providers" + if root == "voiceprints": + return "domain" + if root == "infra": + return "infra" + if root == "config": + return "configuration" + if root == "postprocess": + return "postprocess" + if root == "nltk": + return "vendor_shim" + return "other" + + +def layer_edges( + graph: dict[str, set[str]], +) -> tuple[list[dict[str, Any]], dict[str, set[str]]]: + grouped: dict[tuple[str, str], list[dict[str, str]]] = defaultdict(list) + layer_graph: dict[str, set[str]] = {ring: set() for ring in CORE_RINGS} + for source, targets in graph.items(): + source_ring = ring_for_module(source) + for target in targets: + target_ring = ring_for_module(target) + if source_ring == target_ring: + continue + if source_ring in CORE_RINGS and target_ring in CORE_RINGS: + grouped[(source_ring, target_ring)].append( + {"source": source, "target": target} + ) + layer_graph[source_ring].add(target_ring) + + edges = [ + { + "source": source, + "target": target, + "imports": sorted( + imports, key=lambda item: (item["source"], item["target"]) + ), + } + for (source, target), imports in sorted(grouped.items()) + ] + return edges, layer_graph + + +def _fastapi_import_labels(node: ast.AST) -> list[str]: + if isinstance(node, ast.Import): + return [ + alias.name + for alias in node.names + if alias.name == "fastapi" or alias.name.startswith("fastapi.") + ] + if isinstance(node, ast.ImportFrom) and node.module: + if node.module == "fastapi" or node.module.startswith("fastapi."): + return [f"from {node.module}"] + return [] + + +def _http_exception_reference_lines(tree: ast.AST) -> list[int]: + lines: set[int] = set() + for node in ast.walk(tree): + if isinstance(node, ast.Name) and node.id == "HTTPException": + lines.add(node.lineno) + elif isinstance(node, ast.Attribute) and node.attr == "HTTPException": + lines.add(node.lineno) + elif isinstance(node, (ast.Import, ast.ImportFrom)): + for alias in node.names: + if alias.name == "HTTPException" or alias.asname == "HTTPException": + lines.add(node.lineno) + return sorted(lines) + + +def _pipeline_status_import_locations(tree: ast.AST) -> list[dict[str, Any]]: + locations: list[dict[str, Any]] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name == "pipeline.contracts.status": + locations.append( + { + "line": node.lineno, + "import": alias.name, + "symbol": None, + } + ) + elif isinstance(node, ast.ImportFrom): + if node.module == "pipeline.contracts.status": + imported = tuple(alias.name for alias in node.names) + locations.append( + { + "line": node.lineno, + "import": "from pipeline.contracts.status", + "symbol": "*" if "*" in imported else ", ".join(imported), + } + ) + elif node.module == "pipeline.contracts": + for alias in node.names: + if ( + alias.name == "*" + or alias.name == "status" + or alias.name in STATUS_CONTRACT_HELPERS + ): + locations.append( + { + "line": node.lineno, + "import": "from pipeline.contracts", + "symbol": alias.name, + } + ) + return locations + + +def _dotted_name(node: ast.AST) -> str | None: + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + base = _dotted_name(node.value) + if base is None: + return None + return f"{base}.{node.attr}" + return None + + +def _is_private_job_persistence_symbol(name: str) -> bool: + return name.startswith("_") + + +def _application_job_boundary_locations(tree: ast.AST) -> list[dict[str, Any]]: + locations: list[dict[str, Any]] = [] + runtime_module_aliases: set[str] = set() + persistence_module_aliases: set[str] = set() + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + imported_name = alias.asname or alias.name + if alias.name == "infra.job_runtime": + runtime_module_aliases.add(imported_name) + elif alias.name == "infra.job_persistence": + persistence_module_aliases.add(imported_name) + elif isinstance(node, ast.ImportFrom): + if node.module == "infra.job_runtime": + for alias in node.names: + if alias.name in {"*", "jobs"}: + locations.append( + { + "line": node.lineno, + "import": "from infra.job_runtime", + "symbol": alias.name, + } + ) + elif node.module == "infra.job_persistence": + for alias in node.names: + if alias.name == "*" or _is_private_job_persistence_symbol( + alias.name + ): + locations.append( + { + "line": node.lineno, + "import": "from infra.job_persistence", + "symbol": alias.name, + } + ) + elif node.module == "infra": + for alias in node.names: + imported_name = alias.asname or alias.name + if alias.name == "job_runtime": + runtime_module_aliases.add(imported_name) + elif alias.name == "job_persistence": + persistence_module_aliases.add(imported_name) + + for node in ast.walk(tree): + if not isinstance(node, ast.Attribute): + continue + base_name = _dotted_name(node.value) + if node.attr == "jobs" and base_name in runtime_module_aliases: + locations.append( + { + "line": node.lineno, + "import": f"{base_name}.jobs", + "symbol": "jobs", + } + ) + elif ( + _is_private_job_persistence_symbol(node.attr) + and base_name in persistence_module_aliases + ): + locations.append( + { + "line": node.lineno, + "import": f"{base_name}.{node.attr}", + "symbol": node.attr, + } + ) + + for node in ast.walk(tree): + if ( + isinstance(node, ast.Constant) + and isinstance(node.value, str) + and node.value in {"result.json", "status.json"} + ): + locations.append( + { + "line": node.lineno, + "import": "transcription record filesystem literal", + "symbol": node.value, + } + ) + + return sorted( + locations, + key=lambda item: (item["line"], item["import"], item["symbol"] or ""), + ) + + +def _pipeline_metadata_allowed_top_level_keys(root: Path) -> frozenset[str]: + metadata_contract = root / "app" / "pipeline" / "contracts" / "metadata.py" + if not metadata_contract.exists(): + return frozenset() + + tree = ast.parse(metadata_contract.read_text(encoding="utf-8")) + assignments: dict[str, tuple[str, ...]] = {} + for node in tree.body: + if not isinstance(node, ast.Assign) or len(node.targets) != 1: + continue + target = node.targets[0] + if not isinstance(target, ast.Name): + continue + if target.id in { + "PIPELINE_METADATA_CONTROL_KEYS", + "PIPELINE_METADATA_STAGE_KEYS", + "PIPELINE_METADATA_TOP_LEVEL_KEYS", + }: + assignments[target.id] = _literal_string_tuple( + node.value, + assignments, + ) + if assignments.get("PIPELINE_METADATA_TOP_LEVEL_KEYS"): + return frozenset(assignments["PIPELINE_METADATA_TOP_LEVEL_KEYS"]) + return frozenset( + ( + *assignments.get("PIPELINE_METADATA_CONTROL_KEYS", ()), + *assignments.get("PIPELINE_METADATA_STAGE_KEYS", ()), + ) + ) + + +def _literal_string_tuple( + node: ast.AST, + assignments: dict[str, tuple[str, ...]], +) -> tuple[str, ...]: + if isinstance(node, ast.Constant) and isinstance(node.value, str): + return (node.value,) + if isinstance(node, ast.Name): + return assignments.get(node.id, ()) + if isinstance(node, ast.Starred): + return _literal_string_tuple(node.value, assignments) + if isinstance(node, (ast.Tuple, ast.List)): + values: list[str] = [] + for element in node.elts: + values.extend(_literal_string_tuple(element, assignments)) + return tuple(values) + return () + + +def _literal_string_slice(node: ast.slice) -> str | None: + if isinstance(node, ast.Constant) and isinstance(node.value, str): + return node.value + return None + + +def _is_context_metadata(node: ast.AST) -> bool: + return ( + isinstance(node, ast.Attribute) + and node.attr == "metadata" + and isinstance(node.value, ast.Name) + and node.value.id == "context" + ) + + +def _context_metadata_subscript_key(node: ast.AST) -> str | None: + if not isinstance(node, ast.Subscript) or not _is_context_metadata(node.value): + return None + return _literal_string_slice(node.slice) + + +def _context_metadata_update_key(node: ast.AST) -> str | None: + if ( + not isinstance(node, ast.Call) + or not isinstance(node.func, ast.Attribute) + or node.func.attr != "update" + or not isinstance(node.func.value, ast.Subscript) + or not _is_context_metadata(node.func.value.value) + ): + return None + return _literal_string_slice(node.func.value.slice) or "" + + +def _pipeline_context_metadata_key_locations( + *, + tree: ast.AST, + allowed_keys: frozenset[str], +) -> list[dict[str, Any]]: + locations: list[dict[str, Any]] = [] + + for node in ast.walk(tree): + if isinstance(node, ast.Subscript): + key = _context_metadata_subscript_key(node) + if key is not None and key not in allowed_keys: + locations.append( + { + "line": node.lineno, + "key": key, + "access": "subscript", + } + ) + elif ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and node.func.attr in {"get", "pop", "setdefault"} + and _is_context_metadata(node.func.value) + and node.args + ): + key_node = node.args[0] + if isinstance(key_node, ast.Constant) and isinstance(key_node.value, str): + key = key_node.value + if key not in allowed_keys: + locations.append( + { + "line": node.lineno, + "key": key, + "access": node.func.attr, + } + ) + + return sorted( + locations, + key=lambda item: (item["line"], item["key"], item["access"]), + ) + + +def _pipeline_context_metadata_update_locations( + *, + tree: ast.AST, +) -> list[dict[str, Any]]: + locations: list[dict[str, Any]] = [] + for node in ast.walk(tree): + key = _context_metadata_update_key(node) + if key is not None: + locations.append( + { + "line": node.lineno, + "key": key, + } + ) + return sorted(locations, key=lambda item: (item["line"], item["key"])) + + +def forbidden_dependencies( + root: Path, graph: dict[str, set[str]] +) -> list[dict[str, Any]]: + module_paths = app_modules(root) + findings: list[dict[str, Any]] = [] + pipeline_metadata_allowed_keys = _pipeline_metadata_allowed_top_level_keys(root) + + for module, path in module_paths.items(): + ring = ring_for_module(module) + tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) + metadata_key_locations = _pipeline_context_metadata_key_locations( + tree=tree, + allowed_keys=pipeline_metadata_allowed_keys, + ) + if metadata_key_locations: + findings.append( + { + "rule": "pipeline_context_metadata_top_level_key_contract", + "module": module, + "path": str(path.relative_to(root)), + "locations": metadata_key_locations, + } + ) + if module != PIPELINE_METADATA_CONTRACT_MODULE: + metadata_update_locations = _pipeline_context_metadata_update_locations( + tree=tree, + ) + if metadata_update_locations: + findings.append( + { + "rule": "pipeline_context_metadata_no_unbounded_update", + "module": module, + "path": str(path.relative_to(root)), + "locations": metadata_update_locations, + } + ) + if ring != "api_composition": + fastapi_imports: list[str] = [] + for node in ast.walk(tree): + fastapi_imports.extend(_fastapi_import_labels(node)) + http_exception_lines = _http_exception_reference_lines(tree) + if fastapi_imports or http_exception_lines: + findings.append( + { + "rule": "fastapi_types_stay_in_api_ring", + "module": module, + "path": str(path.relative_to(root)), + "fastapi_imports": fastapi_imports, + "http_exception_lines": http_exception_lines, + } + ) + if ring in {"application", "infra"}: + status_imports = _pipeline_status_import_locations(tree) + if status_imports: + findings.append( + { + "rule": "application_and_infra_use_infra_job_status_owner", + "module": module, + "path": str(path.relative_to(root)), + "locations": status_imports, + } + ) + if ring == "application": + job_boundary_imports = _application_job_boundary_locations(tree) + if job_boundary_imports: + findings.append( + { + "rule": "application_uses_public_infra_job_boundary", + "module": module, + "path": str(path.relative_to(root)), + "locations": job_boundary_imports, + } + ) + + for source, targets in graph.items(): + source_ring = ring_for_module(source) + for target in sorted(targets): + if source_ring != "api_composition" and ( + target == "main" or target == "api" or target.startswith("api.") + ): + findings.append( + { + "rule": "non_api_rings_do_not_import_api", + "module": source, + "target": target, + } + ) + if source_ring == "providers" and ( + target == "application" + or target.startswith("application.") + or target == "pipeline.registry" + or target.startswith("pipeline.registry.") + or target == "pipeline.stages" + or target.startswith("pipeline.stages.") + ): + findings.append( + { + "rule": "providers_do_not_import_orchestration_or_stage_registry", + "module": source, + "target": target, + } + ) + return sorted(findings, key=lambda item: (item["rule"], item.get("module", ""))) + + +def forbidden_dynamic_dependencies( + dynamic_edges: list[dict[str, Any]], +) -> list[dict[str, Any]]: + findings: list[dict[str, Any]] = [] + for edge in dynamic_edges: + source = edge["source"] + target = edge["target"] + source_ring = ring_for_module(source) + target_ring = ring_for_module(target) + provider_imports_orchestration = source_ring == "providers" and ( + target_ring == "application" + or target == "pipeline.registry" + or target.startswith("pipeline.registry.") + or target == "pipeline.stages" + or target.startswith("pipeline.stages.") + ) + if source_ring != "api_composition" and ( + target == "main" or target == "api" or target.startswith("api.") + ): + findings.append( + { + "rule": "non_api_rings_do_not_runtime_import_api", + "module": source, + "target": target, + "kind": edge["kind"], + "import": edge["import"], + "locations": edge["locations"], + } + ) + if provider_imports_orchestration: + findings.append( + { + "rule": "providers_do_not_runtime_import_orchestration_or_stage_registry", + "module": source, + "target": target, + "kind": edge["kind"], + "import": edge["import"], + "locations": edge["locations"], + } + ) + elif source_ring not in {"api_composition", "application"} and ( + target_ring == "application" + ): + findings.append( + { + "rule": "non_application_rings_do_not_runtime_import_application", + "module": source, + "target": target, + "kind": edge["kind"], + "import": edge["import"], + "locations": edge["locations"], + } + ) + return sorted( + findings, + key=lambda item: (item["rule"], item["module"], item["target"]), + ) + + +def build_report(root: Path) -> dict[str, Any]: + static_graph = internal_import_graph(root) + static_edges, static_layer_graph = layer_edges(static_graph) + dynamic_edges = runtime_dynamic_import_edges(root) + dynamic_graph: dict[str, set[str]] = {module: set() for module in static_graph} + for edge in dynamic_edges: + dynamic_graph[edge["source"]].add(edge["target"]) + dynamic_layer_edges, _ = layer_edges(dynamic_graph) + return { + "static_import_graph": { + "module_count": len(static_graph), + "internal_edge_count": sum( + len(targets) for targets in static_graph.values() + ), + "module_sccs": [ + list(component) + for component in strongly_connected_components(static_graph) + ], + "layer_edges": static_edges, + "layer_sccs": [ + list(component) + for component in strongly_connected_components(static_layer_graph) + ], + }, + "static_forbidden_dependencies": forbidden_dependencies(root, static_graph), + "runtime_dynamic_import_graph": { + "edge_count": sum(len(targets) for targets in dynamic_graph.values()), + "edges": dynamic_edges, + "module_sccs": [ + list(component) + for component in strongly_connected_components(dynamic_graph) + ], + "layer_edges": dynamic_layer_edges, + }, + "runtime_dynamic_forbidden_dependencies": forbidden_dynamic_dependencies( + dynamic_edges + ), + } + + +def _print_text(report: dict[str, Any]) -> None: + static_graph = report["static_import_graph"] + print("static_import_graph:") + print(f" module_count: {static_graph['module_count']}") + print(f" internal_edge_count: {static_graph['internal_edge_count']}") + print(f" module_sccs: {static_graph['module_sccs']}") + print(" layer_edges:") + for edge in static_graph["layer_edges"]: + print( + f" - {edge['source']} -> {edge['target']} ({len(edge['imports'])} imports)" + ) + print(f" layer_sccs: {static_graph['layer_sccs']}") + print(f"static_forbidden_dependencies: {report['static_forbidden_dependencies']}") + + dynamic_graph = report["runtime_dynamic_import_graph"] + print("runtime_dynamic_import_graph:") + print(f" edge_count: {dynamic_graph['edge_count']}") + print(" edges:") + for edge in dynamic_graph["edges"]: + locations = ", ".join( + f"{location['path']}:{location['line']}" for location in edge["locations"] + ) + print( + f" - {edge['source']} -> {edge['target']} " + f"[{edge['kind']}] import={edge['import']!r} ({locations})" + ) + print(f" module_sccs: {dynamic_graph['module_sccs']}") + print(" layer_edges:") + for edge in dynamic_graph["layer_edges"]: + print( + f" - {edge['source']} -> {edge['target']} " + f"({len(edge['imports'])} runtime imports)" + ) + print( + "runtime_dynamic_forbidden_dependencies: " + f"{report['runtime_dynamic_forbidden_dependencies']}" + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--root", default=".", help="Repository root to scan.") + parser.add_argument( + "--format", + choices=("text", "json"), + default="text", + help="Output format.", + ) + parser.add_argument( + "--check", + action="store_true", + help="Exit non-zero when architecture violations are present.", + ) + args = parser.parse_args() + + root = Path(args.root).expanduser().resolve() + report = build_report(root) + if args.format == "json": + print(json.dumps(report, ensure_ascii=False, indent=2, sort_keys=True)) + else: + _print_text(report) + + static_graph = report["static_import_graph"] + dynamic_graph = report["runtime_dynamic_import_graph"] + if args.check and ( + static_graph["module_sccs"] + or static_graph["layer_sccs"] + or report["static_forbidden_dependencies"] + or dynamic_graph["module_sccs"] + or report["runtime_dynamic_forbidden_dependencies"] + ): + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/voscript-api/scripts/docs_code_drift_gate.py b/voscript-api/scripts/docs_code_drift_gate.py new file mode 100644 index 0000000..32d02b8 --- /dev/null +++ b/voscript-api/scripts/docs_code_drift_gate.py @@ -0,0 +1,476 @@ +#!/usr/bin/env python3 +"""Check public docs against VoScript's runtime surface anchors.""" + +from __future__ import annotations + +import argparse +import ast +import json +import sys +from pathlib import Path +from typing import Any + +CONFIG_DOC_FILES = ( + "doc/configuration.zh.md", + "doc/configuration.en.md", + ".env.example", +) +API_DOC_FILES = ( + "doc/api.zh.md", + "doc/api.en.md", +) +README_FILES = ( + "README.md", + "README.en.md", +) +ENV_HELPERS = { + "_env_float", + "_env_int", + "_env_str", + "_env_csv_set", + "_env_mapping", +} +PLACEHOLDER_ENV_VALUES = ( + "", + "change-me-to-a-long-random-string", +) +PLACEHOLDER_ENV_MARKERS = ( + "placeholder", + "replace-me", + "replace_me", +) + +RESULT_CONTRACT_TERMS = ( + "status", + "segments[].speaker_label", + "segments[].words", + "alignment", + "artifacts", + "speaker_map", + "unique_speakers", + "similarity", + "params", + "no_repeat_ngram_size", + "MAX_UPLOAD_BYTES", +) + +RUST_MODE_TERMS = ( + "RUST_KERNEL_MODE", + "off", + "required", + "fail closed", +) + + +def _read(root: Path, rel_path: str) -> str: + return (root / rel_path).read_text(encoding="utf-8") + + +def _finding(category: str, path: str, term: str, advice: str) -> dict[str, str]: + return { + "category": category, + "path": path, + "term": term, + "advice": advice, + } + + +def _string_arg(node: ast.AST) -> str | None: + if isinstance(node, ast.Constant) and isinstance(node.value, str): + return node.value + return None + + +def _literal_value(node: ast.AST) -> Any: + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.UnaryOp): + value = _literal_value(node.operand) + if isinstance(node.op, ast.USub) and isinstance(value, (int, float)): + return -value + if isinstance(node, ast.BinOp): + left = _literal_value(node.left) + right = _literal_value(node.right) + if isinstance(left, (int, float)) and isinstance(right, (int, float)): + if isinstance(node.op, ast.Add): + return left + right + if isinstance(node.op, ast.Sub): + return left - right + if isinstance(node.op, ast.Mult): + return left * right + if isinstance(node.op, ast.Div): + return left / right + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.args: + value = _literal_value(node.args[0]) + if node.func.id == "str": + return str(value) + if node.func.id == "int": + return int(value) + if node.func.id == "float": + return float(value) + return None + + +def _default_terms(value: Any) -> tuple[str, ...]: + if value is None: + return () + if isinstance(value, float) and value.is_integer(): + return (str(int(value)), str(value)) + return (str(value),) + + +def config_env_defaults(root: Path) -> dict[str, tuple[str, ...]]: + tree = ast.parse(_read(root, "app/config.py"), filename="app/config.py") + defaults: dict[str, tuple[str, ...]] = {} + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + if isinstance(node.func, ast.Attribute): + if node.func.attr != "getenv" or len(node.args) < 1: + continue + key = _string_arg(node.args[0]) + if key is None: + continue + default = _literal_value(node.args[1]) if len(node.args) > 1 else None + defaults.setdefault(key, _default_terms(default)) + continue + if not isinstance(node.func, ast.Name) or node.func.id not in ENV_HELPERS: + continue + if not node.args: + continue + key = _string_arg(node.args[0]) + if key is None: + continue + default = _literal_value(node.args[1]) if len(node.args) > 1 else None + defaults.setdefault(key, _default_terms(default)) + return dict(sorted(defaults.items())) + + +def env_example_values(root: Path) -> dict[str, str]: + values: dict[str, str] = {} + for line in _read(root, ".env.example").splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("#") or "=" not in stripped: + continue + key, value = stripped.split("=", 1) + if key: + values[key] = value + return dict(sorted(values.items())) + + +def compose_variable_refs(root: Path) -> set[str]: + import re + + compose = _read(root, "docker-compose.yml") + return set(re.findall(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::-[^}]*)?\}", compose)) + + +def router_sources(root: Path) -> list[tuple[str, str]]: + main_tree = ast.parse(_read(root, "app/main.py"), filename="app/main.py") + imported_routers: dict[str, str] = {} + included_router_names: set[str] = set() + for node in ast.walk(main_tree): + if isinstance(node, ast.ImportFrom) and node.module == "api.routers": + for alias in node.names: + imported_routers[alias.asname or alias.name] = alias.name + if not isinstance(node, ast.Call): + continue + func = node.func + if not isinstance(func, ast.Attribute) or func.attr != "include_router": + continue + if not node.args: + continue + first_arg = node.args[0] + if ( + isinstance(first_arg, ast.Attribute) + and first_arg.attr == "router" + and isinstance(first_arg.value, ast.Name) + ): + included_router_names.add(first_arg.value.id) + + sources: list[tuple[str, str]] = [] + for local_name in sorted(included_router_names): + router_module = imported_routers.get(local_name) + if router_module is None: + continue + rel_path = f"app/api/routers/{router_module}.py" + sources.append((rel_path, router_prefix(root, rel_path))) + return sources + + +def router_prefix(root: Path, rel_path: str) -> str: + tree = ast.parse(_read(root, rel_path), filename=rel_path) + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + if not isinstance(node.func, ast.Name) or node.func.id != "APIRouter": + continue + for keyword in node.keywords: + if keyword.arg == "prefix": + value = _string_arg(keyword.value) + return value or "" + return "" + return "" + + +def public_routes(root: Path) -> list[dict[str, str]]: + routes: list[dict[str, str]] = [] + for rel_path, prefix in router_sources(root): + tree = ast.parse(_read(root, rel_path), filename=rel_path) + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + for decorator in node.decorator_list: + if not isinstance(decorator, ast.Call): + continue + func = decorator.func + if not isinstance(func, ast.Attribute): + continue + if func.attr not in {"get", "post", "put", "delete"}: + continue + if not isinstance(func.value, ast.Name) or func.value.id != "router": + continue + if not decorator.args: + continue + route_path = _string_arg(decorator.args[0]) + if route_path is None: + continue + routes.append( + { + "method": func.attr.upper(), + "path": f"{prefix}{route_path}", + "source": rel_path, + "handler": node.name, + } + ) + return sorted(routes, key=lambda item: (item["path"], item["method"])) + + +def _route_doc_terms(route: dict[str, str]) -> tuple[str, ...]: + method = route["method"] + path = route["path"] + terms = {f"{method} {path}"} + id_alias = path + for placeholder in ("job_id", "speaker_id"): + id_alias = id_alias.replace("{" + placeholder + "}", "{id}") + terms.add(f"{method} {id_alias}") + return tuple(sorted(terms)) + + +def _contains_any(text: str, terms: tuple[str, ...]) -> bool: + return any(term in text for term in terms) + + +def _is_placeholder_env_value(value: str) -> bool: + normalized = value.strip().lower() + return normalized in PLACEHOLDER_ENV_VALUES or any( + marker in normalized for marker in PLACEHOLDER_ENV_MARKERS + ) + + +def _check_public_routes(root: Path) -> list[dict[str, str]]: + findings: list[dict[str, str]] = [] + routes = public_routes(root) + api_docs = {path: _read(root, path) for path in API_DOC_FILES} + for route in routes: + terms = _route_doc_terms(route) + for doc_path, text in api_docs.items(): + if not _contains_any(text, terms): + findings.append( + _finding( + "api_route_missing_from_docs", + doc_path, + " or ".join(terms), + f"Document route from {route['source']}::{route['handler']}.", + ) + ) + return findings + + +def _check_config_docs(root: Path) -> list[dict[str, str]]: + findings: list[dict[str, str]] = [] + config_docs = {path: _read(root, path) for path in CONFIG_DOC_FILES} + config_defaults = config_env_defaults(root) + env_values = env_example_values(root) + compose_refs = compose_variable_refs(root) + for key in sorted(config_defaults): + for doc_path in CONFIG_DOC_FILES[:2]: + text = config_docs[doc_path] + if key not in text: + findings.append( + _finding( + "config_env_key_missing_from_config_docs", + doc_path, + key, + "Document config.py env keys or state why they are not public knobs.", + ) + ) + + for key in sorted(env_values): + for doc_path, text in config_docs.items(): + if key not in text: + findings.append( + _finding( + "public_config_key_missing_from_docs", + doc_path, + key, + "Keep public env/config docs in sync with config.py and compose.", + ) + ) + + for key in sorted(env_values): + if key not in compose_refs: + findings.append( + _finding( + "env_example_key_missing_from_compose", + "docker-compose.yml", + key, + "Keep .env.example keys wired into compose or remove the public knob.", + ) + ) + + for key, defaults in config_defaults.items(): + if not defaults: + continue + for doc_path, text in config_docs.items(): + if not any(default in text for default in defaults): + findings.append( + _finding( + "public_config_default_missing_from_docs", + doc_path, + f"{key}={'/'.join(defaults)}", + "Document public defaults in both languages and .env.example.", + ) + ) + + for key, value in env_values.items(): + if _is_placeholder_env_value(value): + continue + for doc_path in CONFIG_DOC_FILES[:2]: + text = config_docs[doc_path] + if value and value not in text: + findings.append( + _finding( + "env_example_default_missing_from_config_docs", + doc_path, + f"{key}={value}", + "Keep .env.example defaults aligned with configuration docs.", + ) + ) + return findings + + +def _check_contract_docs(root: Path) -> list[dict[str, str]]: + findings: list[dict[str, str]] = [] + api_docs = {path: _read(root, path) for path in API_DOC_FILES} + for term in RESULT_CONTRACT_TERMS: + for doc_path, text in api_docs.items(): + if term not in text: + findings.append( + _finding( + "result_contract_term_missing_from_api_docs", + doc_path, + term, + "Keep public result/status/artifact contract docs synchronized.", + ) + ) + + config_docs = {path: _read(root, path) for path in CONFIG_DOC_FILES[:2]} + for term in RUST_MODE_TERMS: + for doc_path, text in config_docs.items(): + if term not in text: + findings.append( + _finding( + "rust_mode_term_missing_from_config_docs", + doc_path, + term, + "Keep Rust mode wording precise: required by default and fail closed; off is an explicit rollback.", + ) + ) + return findings + + +def _check_readme_links(root: Path) -> list[dict[str, str]]: + findings: list[dict[str, str]] = [] + for path in README_FILES: + text = _read(root, path) + for term in ("configuration.", "api."): + if term not in text: + findings.append( + _finding( + "readme_public_doc_link_missing", + path, + term, + "README must route users to configuration and API references.", + ) + ) + return findings + + +def build_report(root: Path) -> dict[str, Any]: + config_defaults = config_env_defaults(root) + env_values = env_example_values(root) + findings: list[dict[str, str]] = [] + findings.extend(_check_public_routes(root)) + findings.extend(_check_config_docs(root)) + findings.extend(_check_contract_docs(root)) + findings.extend(_check_readme_links(root)) + return { + "api_docs": list(API_DOC_FILES), + "checked_routes": public_routes(root), + "compose_variable_refs": sorted(compose_variable_refs(root)), + "config_env_defaults": { + key: list(defaults) for key, defaults in config_defaults.items() + }, + "config_docs": list(CONFIG_DOC_FILES), + "env_example_keys": sorted(env_values), + "public_config_keys": sorted(set(config_defaults) | set(env_values)), + "router_sources": [ + {"path": path, "prefix": prefix} for path, prefix in router_sources(root) + ], + "findings": findings, + } + + +def _print_text(report: dict[str, Any]) -> None: + print(f"checked_routes: {len(report['checked_routes'])}") + print(f"public_config_keys: {len(report['public_config_keys'])}") + if report["findings"]: + print("docs/code drift findings:") + for item in report["findings"]: + print(f"- {item['path']}: {item['category']}: {item['term']}") + print(f" {item['advice']}") + else: + print("docs/code drift gate passed") + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--root", default=".", help="Repository root to scan.") + parser.add_argument( + "--format", + choices=("text", "json"), + default="text", + help="Output format.", + ) + parser.add_argument( + "--check", + action="store_true", + help="Exit non-zero when docs/code drift findings are present.", + ) + args = parser.parse_args() + + root = Path(args.root).expanduser().resolve() + report = build_report(root) + if args.format == "json": + print(json.dumps(report, ensure_ascii=False, indent=2, sort_keys=True)) + else: + _print_text(report) + + if args.check and report["findings"]: + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main())