Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ MODEL_IDLE_TIMEOUT_SEC=180
# Runtime mode for optional Rust-backed provider/kernel paths.
# off — default; use Python implementations.
# required — selected Rust-backed paths must run and hard-fail on import/call errors.
# Currently selected paths: voiceprint scoring and result post-processing.
# CI/Docker packaging still validates the Rust extension directly even when
# the runtime default is off.
RUST_KERNEL_MODE=off
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/rust-foundation-heavy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ jobs:
docker run --rm \
-e RUST_KERNEL_MODE=required \
voscript-rust-foundation:${{ github.sha }} \
python -c "from providers.kernel_bridge import core_smoke, voiceprint_score; result = core_smoke({'source': 'ci'}); assert result['ok'] is True; assert result['echoed']['source'] == 'ci'; decision = voiceprint_score({'query_embedding': [1.0, 0.0], 'candidates': [{'speaker_id': 'spk_ci', 'name': 'CI', 'embedding': [1.0, 0.0], 'sample_count': 1, 'sample_spread': None}], 'threshold': 0.75, 'asnorm_threshold': 0.5, 'cohort': None}); assert decision['matched_id'] == 'spk_ci'; assert decision['reason'] == 'matched'"
python -c "from providers.kernel_bridge import core_smoke, postprocess_segments, voiceprint_score; result = core_smoke({'source': 'ci'}); assert result['ok'] is True; assert result['echoed']['source'] == 'ci'; decision = voiceprint_score({'query_embedding': [1.0, 0.0], 'candidates': [{'speaker_id': 'spk_ci', 'name': 'CI', 'embedding': [1.0, 0.0], 'sample_count': 1, 'sample_spread': None}], 'threshold': 0.75, 'asnorm_threshold': 0.5, 'cohort': None}); assert decision['matched_id'] == 'spk_ci'; assert decision['reason'] == 'matched'; processed = postprocess_segments({'aligned_segments': [{'start': 0.0, 'end': 1.0, 'text': 'hello', 'speaker': 'SPEAKER_00'}], 'speaker_map': {}}); assert processed['segments'][0]['speaker_label'] == 'SPEAKER_00'; assert processed['unique_speakers'] == ['SPEAKER_00']"

- name: Run health check smoke
run: |
Expand Down
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 2 additions & 16 deletions app/pipeline/stages/diarization/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from typing import Any

from postprocess.segments import normalize_words


def assign_segment_speaker(
seg_start: float, seg_end: float, diarization_turns: list[dict[str, Any]]
Expand Down Expand Up @@ -31,22 +33,6 @@ def assign_segment_speaker(
return best_speaker


def normalize_words(raw_words: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
    """Normalise WhisperX word payloads to JSON-safe plain Python dicts.

    Args:
        raw_words: Word entries as produced by the aligner, or ``None``.

    Returns:
        One dict per input entry with the keys ``word`` (str), ``start`` and
        ``end`` (non-negative, rounded to 3 decimals) and ``score`` (rounded
        to 4 decimals). Missing, ``None``, non-numeric, or non-finite numbers
        become ``0.0`` so the result can always be serialised as strict JSON.
        Entries that are not mappings yield an all-default word.
    """
    # Local imports keep this block self-contained; the surrounding module
    # only imports ``Any`` at file level.
    from collections.abc import Mapping
    from math import isfinite

    def _finite(value: object) -> float:
        # float(None) raises TypeError, float("x") ValueError and
        # float(10**400) OverflowError; NaN/inf parse but are not valid JSON.
        try:
            parsed = float(value)
        except (TypeError, ValueError, OverflowError):
            return 0.0
        return parsed if isfinite(parsed) else 0.0

    if not raw_words:
        return []

    normalized: list[dict[str, Any]] = []
    for raw_word in raw_words:
        word = raw_word if isinstance(raw_word, Mapping) else {}
        normalized.append(
            {
                "word": str(word.get("word", "")),
                # Timestamps are clamped at zero: a negative offset is
                # meaningless for a position inside the audio.
                "start": round(max(0.0, _finite(word.get("start", 0.0))), 3),
                "end": round(max(0.0, _finite(word.get("end", 0.0))), 3),
                "score": round(_finite(word.get("score", 0.0)), 4),
            }
        )
    return normalized


def normalize_segment(
segment: dict[str, Any], diarization_turns: list[dict[str, Any]]
) -> dict[str, Any]:
Expand Down
15 changes: 15 additions & 0 deletions app/postprocess/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Pure transcript post-processing helpers."""

from .segments import (
build_display_names,
build_result_segments,
merge_aligned_segments,
normalize_words,
)

__all__ = [
"build_display_names",
"build_result_segments",
"merge_aligned_segments",
"normalize_words",
]
161 changes: 161 additions & 0 deletions app/postprocess/segments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Pure helpers for result-segment post-processing.

These helpers are the Python oracle for Rust post-processing kernels. Keep them
free of model, filesystem, request, and database state so the cross-language
contract stays small and directly testable.
"""

from __future__ import annotations

from collections.abc import Mapping
from math import isfinite
from typing import Any

MERGE_GAP_SECONDS = 0.05


def _number(value: object, *, default: float = 0.0) -> float:
try:
parsed = float(value)
except (TypeError, ValueError):
return default
return parsed if isfinite(parsed) else default


def _rounded_time(value: object) -> float:
    """Return *value* as a non-negative number rounded to three decimals.

    The value is first coerced with ``_number``; anything that is not
    strictly positive is replaced by ``0.0`` before rounding.
    """
    seconds = _number(value)
    clamped = seconds if seconds > 0.0 else 0.0
    return round(clamped, 3)


def _rounded_score(value: object) -> float:
    """Return *value* coerced with ``_number`` and rounded to four decimals."""
    score = _number(value)
    return round(score, 4)


def normalize_words(raw_words: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
    """Convert raw word entries into plain dictionaries with fixed keys.

    Args:
        raw_words: Word entries to normalise, or ``None``.

    Returns:
        An empty list when *raw_words* is falsy; otherwise one dict per entry
        with the keys ``word``, ``start``, ``end`` and ``score``. Entries that
        are not mappings are treated as empty mappings, so every field takes
        its default.
    """
    if not raw_words:
        return []

    # Substitute an empty mapping for malformed entries so the field lookups
    # below fall through to their defaults.
    entries = (item if isinstance(item, Mapping) else {} for item in raw_words)
    return [
        {
            "word": str(entry.get("word", "")),
            "start": _rounded_time(entry.get("start", 0.0)),
            "end": _rounded_time(entry.get("end", 0.0)),
            "score": _rounded_score(entry.get("score", 0.0)),
        }
        for entry in entries
    ]


def _normalize_aligned_segment(segment: Mapping[str, Any]) -> dict[str, Any]:
    """Return a normalised copy of one aligned segment.

    The result always carries ``start``, ``end``, ``text`` (stripped) and
    ``speaker`` (``"UNKNOWN"`` when missing or falsy). A ``words`` key is
    added only when the normalised word list is non-empty.
    """
    normalized: dict[str, Any] = {
        "start": _rounded_time(segment.get("start", 0.0)),
        "end": _rounded_time(segment.get("end", 0.0)),
        "text": str(segment.get("text", "")).strip(),
        "speaker": str(segment.get("speaker", "UNKNOWN") or "UNKNOWN"),
    }

    word_entries = normalize_words(segment.get("words"))
    if not word_entries:
        return normalized

    normalized["words"] = word_entries
    return normalized


def _can_merge_segments(
    previous: Mapping[str, Any], current: Mapping[str, Any]
) -> bool:
    """Tell whether *current* may be folded into *previous*.

    Merging requires the same ``speaker`` value, no ``words`` on either
    segment, and a ``start`` for *current* that is no later than the ``end``
    of *previous* plus ``MERGE_GAP_SECONDS``.
    """
    if (
        previous.get("speaker") != current.get("speaker")
        or previous.get("words")
        or current.get("words")
    ):
        return False

    latest_allowed_start = _number(previous.get("end", 0.0)) + MERGE_GAP_SECONDS
    return _number(current.get("start", 0.0)) <= latest_allowed_start


def merge_aligned_segments(
    aligned_segments: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Normalise segments and fold mergeable neighbours together.

    Each input segment is normalised first. A segment that
    ``_can_merge_segments`` accepts is absorbed into the last emitted one:
    the end time becomes the later of the two and the non-empty texts are
    joined with a single space. All other segments are appended unchanged.
    """
    result: list[dict[str, Any]] = []
    for raw in aligned_segments:
        current = _normalize_aligned_segment(raw)

        if not result or not _can_merge_segments(result[-1], current):
            result.append(current)
            continue

        tail = result[-1]
        tail_end = _rounded_time(tail.get("end", 0.0))
        current_end = _rounded_time(current.get("end", 0.0))
        tail["end"] = max(tail_end, current_end)

        texts = (
            str(tail.get("text", "")).strip(),
            str(current.get("text", "")).strip(),
        )
        # Skip empty pieces so the join never adds a stray space.
        tail["text"] = " ".join(text for text in texts if text)
    return result


def build_display_names(
    speaker_labels: list[str],
    speaker_map: dict[str, dict[str, Any]],
) -> dict[str, str]:
    """Disambiguate duplicate enrolled display names without merging speakers.

    Args:
        speaker_labels: Speaker labels in the order they should be numbered.
            Repeated labels are ignored after their first occurrence.
        speaker_map: Per-label match info; ``matched_name`` is used as the
            display name when present and truthy, else the label itself.

    Returns:
        A mapping from each distinct label to its display name. When several
        labels share a name, the first keeps it and later ones receive a
        `` (2)``, `` (3)``, ... suffix.
    """
    labels_by_name: dict[str, list[str]] = {}

    # De-duplicate while keeping first-seen order: a label listed twice is
    # still one speaker and must not be numbered against itself (which would
    # overwrite its own entry with a "(2)" suffix).
    for speaker_label in dict.fromkeys(speaker_labels):
        match = speaker_map.get(speaker_label, {})
        speaker_name = str(match.get("matched_name") or speaker_label)
        labels_by_name.setdefault(speaker_name, []).append(speaker_label)

    display_names: dict[str, str] = {}
    for speaker_name, labels in labels_by_name.items():
        for index, speaker_label in enumerate(labels, start=1):
            display_names[speaker_label] = (
                speaker_name if index == 1 else f"{speaker_name} ({index})"
            )
    return display_names


def build_result_segments(
    aligned_segments: list[dict[str, Any]],
    speaker_map: dict[str, dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[str]]:
    """Turn aligned segments into public result segments.

    Args:
        aligned_segments: Raw aligned segments; they are merged first.
        speaker_map: Per-label match info providing ``matched_id``,
            ``matched_name`` and ``similarity``.

    Returns:
        A pair ``(segments, unique_speakers)``: the result segments, numbered
        from zero in merged order, and the distinct display names in order of
        first appearance.
    """
    merged = merge_aligned_segments(aligned_segments)
    ordered_labels = list(dict.fromkeys(item["speaker"] for item in merged))
    names_by_label = build_display_names(ordered_labels, speaker_map)

    results: list[dict[str, Any]] = []
    for position, item in enumerate(merged):
        label = item["speaker"]
        matched = speaker_map.get(label, {})
        entry = {
            "id": position,
            "start": item["start"],
            "end": item["end"],
            "text": item["text"],
            "speaker_label": label,
            "speaker_id": matched.get("matched_id"),
            "speaker_name": names_by_label.get(label, label),
            "similarity": matched.get("similarity", 0),
        }
        # Word timings are optional; only pass them through when present.
        if item.get("words"):
            entry["words"] = item["words"]
        results.append(entry)

    # dict.fromkeys keeps the first occurrence of each display name, in order.
    unique_names = list(dict.fromkeys(entry["speaker_name"] for entry in results))
    return results, unique_names
60 changes: 12 additions & 48 deletions app/providers/artifacts/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from config import DENOISE_MODEL, DENOISE_SNR_THRESHOLD
from infra.audio.paths import safe_speaker_label
from infra.transcription_artifacts import persist_transcription_artifacts
from postprocess.segments import build_display_names, build_result_segments
from providers.kernel_bridge import postprocess_segments, rust_provider_paths_enabled
from pipeline.contracts import (
ArtifactManifestEntry,
PipelineContext,
Expand All @@ -24,60 +26,22 @@ def _build_display_names(
speaker_labels: list[str],
speaker_map: dict[str, dict],
) -> dict[str, str]:
labels_by_name: dict[str, list[str]] = {}

for speaker_label in speaker_labels:
match = speaker_map.get(speaker_label, {})
speaker_name = str(match.get("matched_name") or speaker_label)
labels_by_name.setdefault(speaker_name, []).append(speaker_label)

display_names: dict[str, str] = {}
for speaker_name, labels in labels_by_name.items():
for index, speaker_label in enumerate(labels, start=1):
display_names[speaker_label] = (
speaker_name if index == 1 else f"{speaker_name} ({index})"
)
return display_names
return build_display_names(speaker_labels, speaker_map)

Comment on lines 11 to 30
# Build the result segments and the ordered list of unique speaker names
# from aligned segments plus the per-label speaker match map.
@staticmethod
def _build_segments(
aligned_segments: list[dict],
speaker_map: dict[str, dict],
) -> tuple[list[dict], list[str]]:
# NOTE(review): the statements from here through the first
# `return segments, unique_speakers` look like the pre-change inline body
# that this diff replaces with the delegation further down — confirm
# against the merged file before relying on them.
speaker_labels = list(
dict.fromkeys(segment["speaker"] for segment in aligned_segments)
)
display_names = InMemoryArtifactsProvider._build_display_names(
speaker_labels,
speaker_map,
)
segments: list[dict] = []
seen_speakers: set[str] = set()
unique_speakers: list[str] = []

for index, segment in enumerate(aligned_segments):
speaker_label = segment["speaker"]
match = speaker_map.get(speaker_label, {})
speaker_name = display_names.get(speaker_label, speaker_label)
output = {
"id": index,
"start": segment["start"],
"end": segment["end"],
"text": segment["text"],
"speaker_label": speaker_label,
"speaker_id": match.get("matched_id"),
"speaker_name": speaker_name,
"similarity": match.get("similarity", 0),
}
if segment.get("words"):
output["words"] = segment["words"]
segments.append(output)

if speaker_name not in seen_speakers:
seen_speakers.add(speaker_name)
unique_speakers.append(speaker_name)

return segments, unique_speakers
# When Rust-backed provider paths are enabled, hand both inputs to the
# kernel bridge and unpack "segments" / "unique_speakers" from its response.
if rust_provider_paths_enabled():
response = postprocess_segments(
{
"aligned_segments": aligned_segments,
"speaker_map": speaker_map,
}
)
return response["segments"], response["unique_speakers"]
# Otherwise use the Python helper imported from postprocess.segments.
return build_result_segments(aligned_segments, speaker_map)

def _build_transcription(self, context: PipelineContext) -> dict | None:
if context.request.artifact_dir is None:
Expand Down
2 changes: 2 additions & 0 deletions app/providers/kernel_bridge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
RUST_KERNEL_MODE_REQUIRED,
RustKernelBridgeError,
core_smoke,
postprocess_segments,
require_rust_core,
rust_kernel_mode,
rust_provider_paths_enabled,
Expand All @@ -16,6 +17,7 @@
"RUST_KERNEL_MODE_REQUIRED",
"RustKernelBridgeError",
"core_smoke",
"postprocess_segments",
"require_rust_core",
"rust_kernel_mode",
"rust_provider_paths_enabled",
Expand Down
Loading
Loading