diff --git a/.env.example b/.env.example index 46392c3..0f3799e 100644 --- a/.env.example +++ b/.env.example @@ -45,6 +45,7 @@ MODEL_IDLE_TIMEOUT_SEC=180 # Runtime mode for optional Rust-backed provider/kernel paths. # off — default; use Python implementations. # required — selected Rust-backed paths must run and hard-fail on import/call errors. +# Currently selected paths: voiceprint scoring and result post-processing. # CI/Docker packaging still validates the Rust extension directly even when # the runtime default is off. RUST_KERNEL_MODE=off diff --git a/.github/workflows/rust-foundation-heavy.yml b/.github/workflows/rust-foundation-heavy.yml index da6505a..036e736 100644 --- a/.github/workflows/rust-foundation-heavy.yml +++ b/.github/workflows/rust-foundation-heavy.yml @@ -109,7 +109,7 @@ jobs: docker run --rm \ -e RUST_KERNEL_MODE=required \ voscript-rust-foundation:${{ github.sha }} \ - python -c "from providers.kernel_bridge import core_smoke, voiceprint_score; result = core_smoke({'source': 'ci'}); assert result['ok'] is True; assert result['echoed']['source'] == 'ci'; decision = voiceprint_score({'query_embedding': [1.0, 0.0], 'candidates': [{'speaker_id': 'spk_ci', 'name': 'CI', 'embedding': [1.0, 0.0], 'sample_count': 1, 'sample_spread': None}], 'threshold': 0.75, 'asnorm_threshold': 0.5, 'cohort': None}); assert decision['matched_id'] == 'spk_ci'; assert decision['reason'] == 'matched'" + python -c "from providers.kernel_bridge import core_smoke, postprocess_segments, voiceprint_score; result = core_smoke({'source': 'ci'}); assert result['ok'] is True; assert result['echoed']['source'] == 'ci'; decision = voiceprint_score({'query_embedding': [1.0, 0.0], 'candidates': [{'speaker_id': 'spk_ci', 'name': 'CI', 'embedding': [1.0, 0.0], 'sample_count': 1, 'sample_spread': None}], 'threshold': 0.75, 'asnorm_threshold': 0.5, 'cohort': None}); assert decision['matched_id'] == 'spk_ci'; assert decision['reason'] == 'matched'; processed = postprocess_segments({'aligned_segments': [{'start': 0.0, 'end': 1.0, 'text': 'hello', 'speaker': 'SPEAKER_00'}], 'speaker_map': {}}); assert processed['segments'][0]['speaker_label'] == 'SPEAKER_00'; assert processed['unique_speakers'] == ['SPEAKER_00']" - name: Run health check smoke run: | diff --git a/Cargo.lock b/Cargo.lock index 426af82..001a789 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -127,7 +127,7 @@ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "voscript_core" -version = "0.8.1" +version = "0.8.2" dependencies = [ "pyo3", ] diff --git a/app/pipeline/stages/diarization/alignment.py b/app/pipeline/stages/diarization/alignment.py index aae8976..5d6d9c8 100644 --- a/app/pipeline/stages/diarization/alignment.py +++ b/app/pipeline/stages/diarization/alignment.py @@ -4,6 +4,8 @@ from typing import Any +from postprocess.segments import normalize_words + def assign_segment_speaker( seg_start: float, seg_end: float, diarization_turns: list[dict[str, Any]] @@ -31,22 +33,6 @@ def assign_segment_speaker( return best_speaker -def normalize_words(raw_words: list[dict[str, Any]] | None) -> list[dict[str, Any]]: - """Normalise WhisperX word payloads to JSON-safe plain Python dicts.""" - if not raw_words: - return [] - - return [ - { - "word": str(word.get("word", "")), - "start": round(float(word.get("start", 0.0)), 3), - "end": round(float(word.get("end", 0.0)), 3), - "score": round(float(word.get("score", 0.0)), 4), - } - for word in raw_words - ] - - def normalize_segment( segment: dict[str, Any], diarization_turns: list[dict[str, Any]] ) -> dict[str, Any]: diff --git a/app/postprocess/__init__.py b/app/postprocess/__init__.py new file mode 100644 index 0000000..e8033ed --- /dev/null +++ b/app/postprocess/__init__.py @@ -0,0 +1,15 @@ +"""Pure transcript post-processing helpers.""" + +from .segments import ( + build_display_names, + build_result_segments, + merge_aligned_segments, + normalize_words, +) + +__all__ = [ + "build_display_names", + "build_result_segments", + "merge_aligned_segments", + "normalize_words", +] diff --git a/app/postprocess/segments.py b/app/postprocess/segments.py new file mode 100644 index 0000000..b16f079 --- /dev/null +++ b/app/postprocess/segments.py @@ -0,0 +1,161 @@ +"""Pure helpers for result-segment post-processing. + +These helpers are the Python oracle for Rust post-processing kernels. Keep them +free of model, filesystem, request, and database state so the cross-language +contract stays small and directly testable. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from math import isfinite +from typing import Any + +MERGE_GAP_SECONDS = 0.05 + + +def _number(value: object, *, default: float = 0.0) -> float: + try: + parsed = float(value) + except (TypeError, ValueError): + return default + return parsed if isfinite(parsed) else default + + +def _rounded_time(value: object) -> float: + return round(max(0.0, _number(value)), 3) + + +def _rounded_score(value: object) -> float: + return round(_number(value), 4) + + +def normalize_words(raw_words: list[dict[str, Any]] | None) -> list[dict[str, Any]]: + """Normalize model word payloads to JSON-safe plain dictionaries.""" + + if not raw_words: + return [] + + normalized: list[dict[str, Any]] = [] + for raw_word in raw_words: + word = raw_word if isinstance(raw_word, Mapping) else {} + normalized.append( + { + "word": str(word.get("word", "")), + "start": _rounded_time(word.get("start", 0.0)), + "end": _rounded_time(word.get("end", 0.0)), + "score": _rounded_score(word.get("score", 0.0)), + } + ) + return normalized + + +def _normalize_aligned_segment(segment: Mapping[str, Any]) -> dict[str, Any]: + result = { + "start": _rounded_time(segment.get("start", 0.0)), + "end": _rounded_time(segment.get("end", 0.0)), + "text": str(segment.get("text", "")).strip(), + "speaker": str(segment.get("speaker", "UNKNOWN") or "UNKNOWN"), + } + words = normalize_words(segment.get("words")) + if words: + result["words"] = words + return result + + +def _can_merge_segments( + previous: Mapping[str, Any], current: Mapping[str, Any] +) -> bool: + if previous.get("speaker") != current.get("speaker"): + return False + if previous.get("words") or current.get("words"): + return False + previous_end = _number(previous.get("end", 0.0)) + current_start = _number(current.get("start", 0.0)) + return current_start <= previous_end + MERGE_GAP_SECONDS + + +def merge_aligned_segments( + aligned_segments: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Merge adjacent text-only segments for the same stable speaker label.""" + + merged: list[dict[str, Any]] = [] + for raw_segment in aligned_segments: + segment = _normalize_aligned_segment(raw_segment) + if merged and _can_merge_segments(merged[-1], segment): + previous = merged[-1] + previous["end"] = max( + _rounded_time(previous.get("end", 0.0)), + _rounded_time(segment.get("end", 0.0)), + ) + previous_text = str(previous.get("text", "")).strip() + current_text = str(segment.get("text", "")).strip() + previous["text"] = " ".join( + part for part in (previous_text, current_text) if part + ) + continue + merged.append(segment) + return merged + + +def build_display_names( + speaker_labels: list[str], + speaker_map: dict[str, dict[str, Any]], +) -> dict[str, str]: + """Disambiguate duplicate enrolled display names without merging speakers.""" + + labels_by_name: dict[str, list[str]] = {} + + for speaker_label in speaker_labels: + match = speaker_map.get(speaker_label, {}) + speaker_name = str(match.get("matched_name") or speaker_label) + labels_by_name.setdefault(speaker_name, []).append(speaker_label) + + display_names: dict[str, str] = {} + for speaker_name, labels in labels_by_name.items(): + for index, speaker_label in enumerate(labels, start=1): + display_names[speaker_label] = ( + speaker_name if index == 1 else f"{speaker_name} ({index})" + ) + return display_names + + +def build_result_segments( + aligned_segments: list[dict[str, Any]], + speaker_map: dict[str, dict[str, Any]], +) -> tuple[list[dict[str, Any]], list[str]]: + """Build public result segments while preserving stable speaker labels.""" + + merged_segments = merge_aligned_segments(aligned_segments) + speaker_labels = list( + dict.fromkeys(segment["speaker"] for segment in merged_segments) + ) + display_names = build_display_names(speaker_labels, speaker_map) + segments: list[dict[str, Any]] = [] + seen_speakers: set[str] = set() + unique_speakers: list[str] = [] + + for index, segment in enumerate(merged_segments): + speaker_label = segment["speaker"] + match = speaker_map.get(speaker_label, {}) + speaker_name = display_names.get(speaker_label, speaker_label) + output = { + "id": index, + "start": segment["start"], + "end": segment["end"], + "text": segment["text"], + "speaker_label": speaker_label, + "speaker_id": match.get("matched_id"), + "speaker_name": speaker_name, + "similarity": match.get("similarity", 0), + } + if segment.get("words"): + output["words"] = segment["words"] + segments.append(output) + + if speaker_name not in seen_speakers: + seen_speakers.add(speaker_name) + unique_speakers.append(speaker_name) + + return segments, unique_speakers diff --git a/app/providers/artifacts/default.py b/app/providers/artifacts/default.py index 106fee4..b3a55c0 100644 --- a/app/providers/artifacts/default.py +++ b/app/providers/artifacts/default.py @@ -8,6 +8,8 @@ from config import DENOISE_MODEL, DENOISE_SNR_THRESHOLD from infra.audio.paths import safe_speaker_label from infra.transcription_artifacts import persist_transcription_artifacts +from postprocess.segments import build_display_names, build_result_segments +from providers.kernel_bridge import postprocess_segments, rust_provider_paths_enabled from pipeline.contracts import ( ArtifactManifestEntry, PipelineContext, @@ -24,60 +26,22 @@ def _build_display_names( speaker_labels: list[str], speaker_map: dict[str, dict], ) -> dict[str, str]: - labels_by_name: dict[str, list[str]] = {} - - for speaker_label in speaker_labels: - match = speaker_map.get(speaker_label, {}) - speaker_name = str(match.get("matched_name") or speaker_label) - labels_by_name.setdefault(speaker_name, []).append(speaker_label) - - display_names: dict[str, str] = {} - for speaker_name, labels in labels_by_name.items(): - for index, speaker_label in enumerate(labels, start=1): - display_names[speaker_label] = ( - speaker_name if index == 1 else f"{speaker_name} ({index})" - ) - return display_names + return build_display_names(speaker_labels, speaker_map) @staticmethod def _build_segments( aligned_segments: list[dict], speaker_map: dict[str, dict], ) -> tuple[list[dict], list[str]]: - speaker_labels = list( - dict.fromkeys(segment["speaker"] for segment in aligned_segments) - ) - display_names = InMemoryArtifactsProvider._build_display_names( - speaker_labels, - speaker_map, - ) - segments: list[dict] = [] - seen_speakers: set[str] = set() - unique_speakers: list[str] = [] - - for index, segment in enumerate(aligned_segments): - speaker_label = segment["speaker"] - match = speaker_map.get(speaker_label, {}) - speaker_name = display_names.get(speaker_label, speaker_label) - output = { - "id": index, - "start": segment["start"], - "end": segment["end"], - "text": segment["text"], - "speaker_label": speaker_label, - "speaker_id": match.get("matched_id"), - "speaker_name": speaker_name, - "similarity": match.get("similarity", 0), - } - if segment.get("words"): - output["words"] = segment["words"] - segments.append(output) - - if speaker_name not in seen_speakers: - seen_speakers.add(speaker_name) - unique_speakers.append(speaker_name) - - return segments, unique_speakers + if rust_provider_paths_enabled(): + response = postprocess_segments( + { + "aligned_segments": aligned_segments, + "speaker_map": speaker_map, + } + ) + return response["segments"], response["unique_speakers"] + return build_result_segments(aligned_segments, speaker_map) def _build_transcription(self, context: PipelineContext) -> dict | None: if context.request.artifact_dir is None: diff --git a/app/providers/kernel_bridge/__init__.py b/app/providers/kernel_bridge/__init__.py index 6b6ec68..e6c134c 100644 --- a/app/providers/kernel_bridge/__init__.py +++ b/app/providers/kernel_bridge/__init__.py @@ -5,6 +5,7 @@ RUST_KERNEL_MODE_REQUIRED, RustKernelBridgeError, core_smoke, + postprocess_segments, require_rust_core, rust_kernel_mode, rust_provider_paths_enabled, @@ -16,6 +17,7 @@ "RUST_KERNEL_MODE_REQUIRED", "RustKernelBridgeError", "core_smoke", + "postprocess_segments", "require_rust_core", "rust_kernel_mode", "rust_provider_paths_enabled", diff --git a/app/providers/kernel_bridge/runtime.py b/app/providers/kernel_bridge/runtime.py index 3052088..36ba553 100644 --- a/app/providers/kernel_bridge/runtime.py +++ b/app/providers/kernel_bridge/runtime.py @@ -178,6 +178,131 @@ def _validate_voiceprint_score_candidate_response(candidate: Any) -> dict[str, A return result +def _validate_postprocess_segments_response(response: Any) -> dict[str, Any]: + if not isinstance(response, Mapping): + raise RustKernelBridgeError( + "Rust postprocess_segments returned a non-mapping response" + ) + + result = dict(response) + required_keys = {"segments", "unique_speakers"} + missing = sorted(required_keys.difference(result)) + if missing: + raise RustKernelBridgeError( + f"Rust postprocess_segments response missing keys: {', '.join(missing)}" + ) + if not isinstance(result["segments"], list): + raise RustKernelBridgeError("Rust postprocess_segments segments must be a list") + if not isinstance(result["unique_speakers"], list) or not all( + isinstance(speaker, str) and speaker for speaker in result["unique_speakers"] + ): + raise RustKernelBridgeError( + "Rust postprocess_segments unique_speakers must be non-empty strings" + ) + result["segments"] = [ + _validate_postprocess_segment_response(segment) + for segment in result["segments"] + ] + return result + + +def _validate_postprocess_segment_response(segment: Any) -> dict[str, Any]: + if not isinstance(segment, Mapping): + raise RustKernelBridgeError( + "Rust postprocess_segments segment returned a non-mapping response" + ) + + result = dict(segment) + required_keys = { + "id", + "start", + "end", + "text", + "speaker_label", + "speaker_id", + "speaker_name", + "similarity", + } + missing = sorted(required_keys.difference(result)) + if missing: + raise RustKernelBridgeError( + "Rust postprocess_segments segment missing keys: " + ", ".join(missing) + ) + try: + result["id"] = int(result["id"]) + except (TypeError, ValueError) as exc: + raise RustKernelBridgeError( + "Rust postprocess_segments segment id must be integer-like" + ) from exc + if result["id"] < 0: + raise RustKernelBridgeError( + "Rust postprocess_segments segment id must be non-negative" + ) + for key in ("start", "end", "similarity"): + try: + result[key] = float(result[key]) + except (TypeError, ValueError) as exc: + raise RustKernelBridgeError( + f"Rust postprocess_segments segment {key} must be numeric" + ) from exc + if not isfinite(result[key]): + raise RustKernelBridgeError( + f"Rust postprocess_segments segment {key} must be finite" + ) + for key in ("text", "speaker_label", "speaker_name"): + if not isinstance(result[key], str): + raise RustKernelBridgeError( + f"Rust postprocess_segments segment {key} must be a string" + ) + if not result["speaker_label"] or not result["speaker_name"]: + raise RustKernelBridgeError( + "Rust postprocess_segments segment speaker labels must be non-empty" + ) + if result["speaker_id"] is not None and not isinstance(result["speaker_id"], str): + raise RustKernelBridgeError( + "Rust postprocess_segments segment speaker_id must be a string or null" + ) + if "words" in result: + if not isinstance(result["words"], list): + raise RustKernelBridgeError( + "Rust postprocess_segments segment words must be a list" + ) + result["words"] = [ + _validate_postprocess_word_response(word) for word in result["words"] + ] + return result + + +def _validate_postprocess_word_response(word: Any) -> dict[str, Any]: + if not isinstance(word, Mapping): + raise RustKernelBridgeError( + "Rust postprocess_segments word returned a non-mapping response" + ) + result = dict(word) + required_keys = {"word", "start", "end", "score"} + missing = sorted(required_keys.difference(result)) + if missing: + raise RustKernelBridgeError( + "Rust postprocess_segments word missing keys: " + ", ".join(missing) + ) + if not isinstance(result["word"], str): + raise RustKernelBridgeError( + "Rust postprocess_segments word text must be string" + ) + for key in ("start", "end", "score"): + try: + result[key] = float(result[key]) + except (TypeError, ValueError) as exc: + raise RustKernelBridgeError( + f"Rust postprocess_segments word {key} must be numeric" + ) from exc + if not isfinite(result[key]): + raise RustKernelBridgeError( + f"Rust postprocess_segments word {key} must be finite" + ) + return result + + def core_smoke( payload: Any, importer: Callable[[str], ModuleType] = import_module, @@ -204,3 +329,17 @@ def voiceprint_score( except Exception as exc: raise RustKernelBridgeError("Rust voiceprint_score call failed") from exc return _validate_voiceprint_score_response(response) + + +def postprocess_segments( + payload: dict[str, Any], + importer: Callable[[str], ModuleType] = import_module, +) -> dict[str, Any]: + """Call the native result-segment post-processing kernel.""" + + rust_core = require_rust_core(importer=importer) + try: + response = rust_core.postprocess_segments(payload) + except Exception as exc: + raise RustKernelBridgeError("Rust postprocess_segments call failed") from exc + return _validate_postprocess_segments_response(response) diff --git a/crates/voscript_core/Cargo.toml b/crates/voscript_core/Cargo.toml index 044cc9d..883d2c9 100644 --- a/crates/voscript_core/Cargo.toml +++ b/crates/voscript_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "voscript_core" -version = "0.8.1" +version = "0.8.2" edition = "2021" license = "Apache-2.0" publish = false diff --git a/crates/voscript_core/src/lib.rs b/crates/voscript_core/src/lib.rs index 4eb5151..ad1bebe 100644 --- a/crates/voscript_core/src/lib.rs +++ b/crates/voscript_core/src/lib.rs @@ -5,6 +5,7 @@ use pyo3::prelude::*; #[cfg(feature = "python-bindings")] use pyo3::types::{PyDict, PyList, PyModule}; +pub mod postprocess; pub mod voiceprint; pub const CORE_SMOKE_CAPABILITY: &str = "core_smoke"; @@ -51,6 +52,33 @@ fn optional_usize(dict: &Bound<'_, PyDict>, key: &str, default: usize) -> PyResu } } +#[cfg(feature = "python-bindings")] +fn optional_string(dict: &Bound<'_, PyDict>, key: &str) -> PyResult> { + match dict.get_item(key)? { + Some(value) if !value.is_none() => Ok(Some(value.str()?.to_string())), + _ => Ok(None), + } +} + +#[cfg(feature = "python-bindings")] +fn item_f64_or_default(item: Option>, default: f64) -> PyResult { + match item { + Some(value) if !value.is_none() => { + if let Ok(parsed) = value.extract::() { + return Ok(parsed); + } + let text = value.str()?.to_string(); + Ok(text.parse::().unwrap_or(default)) + } + _ => Ok(default), + } +} + +#[cfg(feature = "python-bindings")] +fn optional_f64_or_default(dict: &Bound<'_, PyDict>, key: &str, default: f64) -> PyResult { + item_f64_or_default(dict.get_item(key)?, default) +} + #[cfg(feature = "python-bindings")] fn parse_voiceprint_candidate( item: Bound<'_, PyAny>, @@ -101,6 +129,102 @@ fn parse_voiceprint_request( }) } +#[cfg(feature = "python-bindings")] +fn parse_postprocess_word(item: Bound<'_, PyAny>) -> PyResult { + let dict = match item.cast_into::() { + Ok(dict) => dict, + Err(_) => { + return Ok(postprocess::Word { + word: String::new(), + start: 0.0, + end: 0.0, + score: 0.0, + }); + } + }; + let word = match dict.get_item("word")? { + Some(value) if !value.is_none() => value.str()?.to_string(), + _ => String::new(), + }; + Ok(postprocess::Word { + word, + start: optional_f64_or_default(&dict, "start", 0.0)?, + end: optional_f64_or_default(&dict, "end", 0.0)?, + score: optional_f64_or_default(&dict, "score", 0.0)?, + }) +} + +#[cfg(feature = "python-bindings")] +fn parse_postprocess_words(dict: &Bound<'_, PyDict>) -> PyResult> { + let words_any = match dict.get_item("words")? { + Some(value) if !value.is_none() => value, + _ => return Ok(Vec::new()), + }; + let words_list = words_any.cast_into::()?; + let mut words = Vec::with_capacity(words_list.len()); + for item in words_list.iter() { + words.push(parse_postprocess_word(item)?); + } + Ok(words) +} + +#[cfg(feature = "python-bindings")] +fn parse_aligned_segment(item: Bound<'_, PyAny>) -> PyResult { + let dict = item.cast_into::()?; + let text = match dict.get_item("text")? { + Some(value) if !value.is_none() => value.str()?.to_string(), + _ => String::new(), + }; + let speaker = match dict.get_item("speaker")? { + Some(value) if !value.is_none() => value.str()?.to_string(), + _ => "UNKNOWN".to_string(), + }; + Ok(postprocess::AlignedSegment { + start: optional_f64_or_default(&dict, "start", 0.0)?, + end: optional_f64_or_default(&dict, "end", 0.0)?, + text, + speaker, + words: parse_postprocess_words(&dict)?, + }) +} + +#[cfg(feature = "python-bindings")] +fn parse_speaker_match(item: Bound<'_, PyAny>) -> PyResult { + let dict = match item.cast_into::() { + Ok(dict) => dict, + Err(_) => return Ok(postprocess::SpeakerMatch::default()), + }; + Ok(postprocess::SpeakerMatch { + matched_id: optional_string(&dict, "matched_id")?, + matched_name: optional_string(&dict, "matched_name")?, + similarity: Some(optional_f64_or_default(&dict, "similarity", 0.0)?), + }) +} + +#[cfg(feature = "python-bindings")] +fn parse_postprocess_request( + payload: &Bound<'_, PyDict>, +) -> PyResult<( + Vec, + std::collections::HashMap, +)> { + let segments_any = required_item(payload, "aligned_segments")?; + let segments_list = segments_any.cast_into::()?; + let mut aligned_segments = Vec::with_capacity(segments_list.len()); + for item in segments_list.iter() { + aligned_segments.push(parse_aligned_segment(item)?); + } + + let speaker_map_any = required_item(payload, "speaker_map")?; + let speaker_map_dict = speaker_map_any.cast_into::()?; + let mut speaker_map = std::collections::HashMap::new(); + for (key, value) in speaker_map_dict.iter() { + speaker_map.insert(key.str()?.to_string(), parse_speaker_match(value)?); + } + + Ok((aligned_segments, speaker_map)) +} + #[cfg(feature = "python-bindings")] #[pyfunction] fn voiceprint_score(py: Python<'_>, payload: &Bound<'_, PyDict>) -> PyResult> { @@ -133,12 +257,50 @@ fn voiceprint_score(py: Python<'_>, payload: &Bound<'_, PyDict>) -> PyResult, payload: &Bound<'_, PyDict>) -> PyResult> { + let (aligned_segments, speaker_map) = parse_postprocess_request(payload)?; + let result = postprocess::build_result_segments(aligned_segments, speaker_map); + + let response = PyDict::new(py); + let segments = PyList::empty(py); + for segment in result.segments { + let item = PyDict::new(py); + item.set_item("id", segment.id)?; + item.set_item("start", segment.start)?; + item.set_item("end", segment.end)?; + item.set_item("text", segment.text)?; + item.set_item("speaker_label", segment.speaker_label)?; + item.set_item("speaker_id", segment.speaker_id)?; + item.set_item("speaker_name", segment.speaker_name)?; + item.set_item("similarity", segment.similarity)?; + if !segment.words.is_empty() { + let words = PyList::empty(py); + for word in segment.words { + let word_item = PyDict::new(py); + word_item.set_item("word", word.word)?; + word_item.set_item("start", word.start)?; + word_item.set_item("end", word.end)?; + word_item.set_item("score", word.score)?; + words.append(word_item)?; + } + item.set_item("words", words)?; + } + segments.append(item)?; + } + response.set_item("segments", segments)?; + response.set_item("unique_speakers", result.unique_speakers)?; + Ok(response.unbind()) +} + #[cfg(feature = "python-bindings")] #[pymodule] fn voscript_core(module: &Bound<'_, PyModule>) -> PyResult<()> { module.add("__version__", PACKAGE_VERSION)?; module.add_function(wrap_pyfunction!(core_smoke, module)?)?; module.add_function(wrap_pyfunction!(voiceprint_score, module)?)?; + module.add_function(wrap_pyfunction!(postprocess_segments, module)?)?; Ok(()) } @@ -146,7 +308,7 @@ fn voscript_core(module: &Bound<'_, PyModule>) -> PyResult<()> { mod tests { #[test] fn package_version_is_set() { - assert_eq!(super::PACKAGE_VERSION, "0.8.1"); + assert_eq!(super::PACKAGE_VERSION, "0.8.2"); } #[test] diff --git a/crates/voscript_core/src/postprocess.rs b/crates/voscript_core/src/postprocess.rs new file mode 100644 index 0000000..b453d5d --- /dev/null +++ b/crates/voscript_core/src/postprocess.rs @@ -0,0 +1,351 @@ +use std::collections::{HashMap, HashSet}; + +pub const MERGE_GAP_SECONDS: f64 = 0.05; + +#[derive(Clone, Debug, PartialEq)] +pub struct Word { + pub word: String, + pub start: f64, + pub end: f64, + pub score: f64, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct AlignedSegment { + pub start: f64, + pub end: f64, + pub text: String, + pub speaker: String, + pub words: Vec, +} + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct SpeakerMatch { + pub matched_id: Option, + pub matched_name: Option, + pub similarity: Option, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct ResultSegment { + pub id: usize, + pub start: f64, + pub end: f64, + pub text: String, + pub speaker_label: String, + pub speaker_id: Option, + pub speaker_name: String, + pub similarity: f64, + pub words: Vec, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct PostprocessResult { + pub segments: Vec, + pub unique_speakers: Vec, +} + +fn safe_number(value: f64) -> f64 { + if value.is_finite() { + value + } else { + 0.0 + } +} + +fn round_to(value: f64, scale: f64) -> f64 { + (value * scale).round() / scale +} + +fn round_time(value: f64) -> f64 { + round_to(safe_number(value).max(0.0), 1000.0) +} + +fn round_score(value: f64) -> f64 { + round_to(safe_number(value), 10000.0) +} + +pub fn normalize_word(word: Word) -> Word { + Word { + word: word.word, + start: round_time(word.start), + end: round_time(word.end), + score: round_score(word.score), + } +} + +pub fn normalize_segment(segment: AlignedSegment) -> AlignedSegment { + AlignedSegment { + start: round_time(segment.start), + end: round_time(segment.end), + text: segment.text.trim().to_string(), + speaker: if segment.speaker.is_empty() { + "UNKNOWN".to_string() + } else { + segment.speaker + }, + words: segment.words.into_iter().map(normalize_word).collect(), + } +} + +fn can_merge_segments(previous: &AlignedSegment, current: &AlignedSegment) -> bool { + previous.speaker == current.speaker + && previous.words.is_empty() + && current.words.is_empty() + && current.start <= previous.end + MERGE_GAP_SECONDS +} + +pub fn merge_aligned_segments(segments: Vec) -> Vec { + let mut merged: Vec = Vec::new(); + for raw_segment in segments { + let segment = normalize_segment(raw_segment); + if let Some(previous) = merged.last_mut() { + if can_merge_segments(previous, &segment) { + previous.end = previous.end.max(segment.end); + let previous_text = previous.text.trim(); + let current_text = segment.text.trim(); + previous.text = match (previous_text.is_empty(), current_text.is_empty()) { + (true, true) => String::new(), + (true, false) => current_text.to_string(), + (false, true) => previous_text.to_string(), + (false, false) => format!("{previous_text} {current_text}"), + }; + continue; + } + } + merged.push(segment); + } + merged +} + +pub fn build_display_names( + speaker_labels: &[String], + speaker_map: &HashMap, +) -> HashMap { + let mut labels_by_name: Vec<(String, Vec)> = Vec::new(); + + for speaker_label in speaker_labels { + let speaker_name = speaker_map + .get(speaker_label) + .and_then(|entry| entry.matched_name.as_ref()) + .filter(|name| !name.is_empty()) + .cloned() + .unwrap_or_else(|| speaker_label.clone()); + + if let Some((_, labels)) = labels_by_name + .iter_mut() + .find(|(known_name, _)| known_name == &speaker_name) + { + labels.push(speaker_label.clone()); + } else { + labels_by_name.push((speaker_name, vec![speaker_label.clone()])); + } + } + + let mut display_names = HashMap::new(); + for (speaker_name, labels) in labels_by_name { + for (index, speaker_label) in labels.into_iter().enumerate() { + let display_name = if index == 0 { + speaker_name.clone() + } else { + format!("{} ({})", speaker_name, index + 1) + }; + display_names.insert(speaker_label, display_name); + } + } + display_names +} + +pub fn build_result_segments( + aligned_segments: Vec, + speaker_map: HashMap, +) -> PostprocessResult { + let merged_segments = merge_aligned_segments(aligned_segments); + let mut seen_labels = HashSet::new(); + let mut speaker_labels = Vec::new(); + for segment in &merged_segments { + if seen_labels.insert(segment.speaker.clone()) { + speaker_labels.push(segment.speaker.clone()); + } + } + + let display_names = build_display_names(&speaker_labels, &speaker_map); + let mut seen_speakers = HashSet::new(); + let mut unique_speakers = Vec::new(); + let mut segments = Vec::with_capacity(merged_segments.len()); + + for (index, segment) in merged_segments.into_iter().enumerate() { + let speaker_label = segment.speaker; + let speaker_match = speaker_map.get(&speaker_label); + let speaker_name = display_names + .get(&speaker_label) + .cloned() + .unwrap_or_else(|| speaker_label.clone()); + let similarity = speaker_match + .and_then(|entry| entry.similarity) + .map(safe_number) + .unwrap_or(0.0); + if seen_speakers.insert(speaker_name.clone()) { + unique_speakers.push(speaker_name.clone()); + } + segments.push(ResultSegment { + id: index, + start: segment.start, + end: segment.end, + text: segment.text, + speaker_label, + speaker_id: speaker_match.and_then(|entry| entry.matched_id.clone()), + speaker_name, + similarity, + words: segment.words, + }); + } + + PostprocessResult { + segments, + unique_speakers, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn merge_preserves_speaker_label_and_skips_word_segments() { + let segments = vec![ + AlignedSegment { + start: 0.0, + end: 1.0, + text: " first ".to_string(), + speaker: "SPEAKER_00".to_string(), + words: vec![], + }, + AlignedSegment { + start: 1.02, + end: 2.0, + text: "second".to_string(), + speaker: "SPEAKER_00".to_string(), + words: vec![], + }, + AlignedSegment { + start: 2.0, + end: 3.0, + text: "worded".to_string(), + speaker: "SPEAKER_00".to_string(), + words: vec![Word { + word: "worded".to_string(), + start: 2.0, + end: 2.5, + score: 0.5, + }], + }, + ]; + + let merged = merge_aligned_segments(segments); + + assert_eq!(merged.len(), 2); + assert_eq!(merged[0].speaker, "SPEAKER_00"); + assert_eq!(merged[0].text, "first second"); + assert_eq!(merged[1].text, "worded"); + } + + #[test] + fn display_names_disambiguate_without_merging_speakers() { + let labels = vec!["SPEAKER_00".to_string(), "SPEAKER_01".to_string()]; + let mut speaker_map = HashMap::new(); + speaker_map.insert( + "SPEAKER_00".to_string(), + SpeakerMatch { + matched_id: Some("spk_same".to_string()), + matched_name: Some("Maple".to_string()), + similarity: Some(2.0), + }, + ); + speaker_map.insert( + "SPEAKER_01".to_string(), + SpeakerMatch { + matched_id: Some("spk_same".to_string()), + matched_name: Some("Maple".to_string()), + similarity: Some(1.0), + }, + ); + + let display_names = build_display_names(&labels, &speaker_map); + + assert_eq!(display_names["SPEAKER_00"], "Maple"); + assert_eq!(display_names["SPEAKER_01"], "Maple (2)"); + } + + #[test] + fn word_normalization_is_json_safe() { + let word = normalize_word(Word { + word: "hello".to_string(), + start: -1.0, + end: f64::NAN, + score: f64::INFINITY, + }); + + assert_eq!( + word, + Word { + word: "hello".to_string(), + start: 0.0, + end: 0.0, + score: 0.0, + } + ); + } + + #[test] + fn result_segments_preserve_labels_and_disambiguate_names() { + let aligned_segments = vec![ + AlignedSegment { + start: 0.0, + end: 1.0, + text: "hello".to_string(), + speaker: "SPEAKER_00".to_string(), + words: vec![], + }, + AlignedSegment { + start: 1.0, + end: 2.0, + text: "world".to_string(), + speaker: "SPEAKER_01".to_string(), + words: vec![Word { + word: "world".to_string(), + start: -1.0, + end: 2.0, + score: 0.77777, + }], + }, + ]; + let mut speaker_map = HashMap::new(); + speaker_map.insert( + "SPEAKER_00".to_string(), + SpeakerMatch { + matched_id: Some("spk_same".to_string()), + matched_name: Some("Maple".to_string()), + similarity: Some(2.0), + }, + ); + speaker_map.insert( + "SPEAKER_01".to_string(), + SpeakerMatch { + matched_id: Some("spk_same".to_string()), + matched_name: Some("Maple".to_string()), + similarity: Some(1.0), + }, + ); + + let result = build_result_segments(aligned_segments, speaker_map); + + assert_eq!(result.unique_speakers, vec!["Maple", "Maple (2)"]); + assert_eq!(result.segments[0].speaker_label, "SPEAKER_00"); + assert_eq!(result.segments[1].speaker_label, "SPEAKER_01"); + assert_eq!(result.segments[0].speaker_name, "Maple"); + assert_eq!(result.segments[1].speaker_name, "Maple (2)"); + assert_eq!(result.segments[1].words[0].start, 0.0); + assert_eq!(result.segments[1].words[0].score, 0.7778); + } +} diff --git a/doc/changelog.en.md b/doc/changelog.en.md index 5fde29e..ddb57de 100644 --- a/doc/changelog.en.md +++ b/doc/changelog.en.md @@ -18,6 +18,10 @@ - Added an optional Rust-backed voiceprint scoring kernel for explicit `RUST_KERNEL_MODE=required` runs. The default remains Python scoring, and the public speaker/voiceprint result contract is unchanged. +- Added optional Rust-backed result post-processing for explicit + `RUST_KERNEL_MODE=required` runs. The default remains Python post-processing; + result segments keep stable `speaker_label` values, duplicate display names + are disambiguated instead of merged, and `segments[].words` remains optional. ### Security @@ -34,6 +38,9 @@ - Extended Rust kernel tests with voiceprint scoring golden cases for raw cosine, AS-norm activation, small-cohort raw fallback, ambiguous top-2 margins, and non-finite embedding rejection. +- Extended Rust kernel and Docker smoke coverage to include result + post-processing segment assembly, display-name disambiguation, and word + normalization. ## 0.7.6 — Health, alignment, and embedding runtime fixes (2026-05-07) diff --git a/doc/changelog.zh.md b/doc/changelog.zh.md index 7a0986e..9396422 100644 --- a/doc/changelog.zh.md +++ b/doc/changelog.zh.md @@ -15,6 +15,9 @@ 否则 fail closed。 - 新增显式 `RUST_KERNEL_MODE=required` 下可选的 Rust-backed 声纹计分 kernel。 默认仍使用 Python 计分,公开 speaker / voiceprint 结果契约不变。 +- 新增显式 `RUST_KERNEL_MODE=required` 下可选的 Rust-backed 结果后处理。 + 默认仍使用 Python 后处理;结果 segment 继续保留稳定 `speaker_label`, + 重名展示名只做序号消歧而不合并 speaker,`segments[].words` 仍是可选字段。 ### 安全 @@ -28,6 +31,8 @@ 后续 PR 更新不自动重复重型 gate,需在合并前按需手动触发。 - 扩展 Rust kernel 测试,覆盖 raw cosine、AS-norm 启用、小 cohort 回退 raw、 top-2 margin 模糊拒绝以及非有限 embedding 拒绝等声纹计分 golden case。 +- 扩展 Rust kernel 与 Docker smoke 覆盖,加入结果后处理 segment 组装、展示名 + 消歧和 word normalization。 ## 0.7.6 — 健康检查、alignment 与 embedding 运行时修复 (2026-05-07) diff --git a/doc/configuration.en.md b/doc/configuration.en.md index c39cc0f..44a72ad 100644 --- a/doc/configuration.en.md +++ b/doc/configuration.en.md @@ -38,7 +38,7 @@ parameters yet. | `FFMPEG_TIMEOUT_SEC` | `1800` | ffmpeg conversion timeout in seconds; timeout returns `504`. | | `JOBS_MAX_CACHE` | `200` | In-memory job LRU limit. Evicted completed jobs remain queryable from disk `status.json` / `result.json`. | | `MODEL_IDLE_TIMEOUT_SEC` | `180` | GPU model idle-unload timeout, defaulting to 180 seconds (3 minutes). Set `0` to disable idle unload and keep models resident. When enabled, loaded models are released only after the serialized GPU runtime has been idle for this many seconds; on the next reload, ASR, diarization, and embedding each choose the visible CUDA device with the most free memory during their own lazy load. | -| `RUST_KERNEL_MODE` | `off` | Optional Rust-backed provider/kernel mode. `off` keeps Python implementations; `required` makes selected Rust-backed paths import and run successfully or fail closed. The current selected path is voiceprint scoring; CI / Docker packaging still validates the Rust extension directly when the runtime default is off. | +| `RUST_KERNEL_MODE` | `off` | Optional Rust-backed provider/kernel mode. `off` keeps Python implementations; `required` makes selected Rust-backed paths import and run successfully or fail closed. The current selected paths are voiceprint scoring and result post-processing; CI / Docker packaging still validates the Rust extension directly when the runtime default is off. | `MODELS_DIR` and `LANGUAGE` are defined in the config module, but v0.7.6's main HTTP transcription path does not use them as stable public tuning knobs: diff --git a/doc/configuration.zh.md b/doc/configuration.zh.md index e0bf0e0..4a89668 100644 --- a/doc/configuration.zh.md +++ b/doc/configuration.zh.md @@ -36,7 +36,7 @@ | `FFMPEG_TIMEOUT_SEC` | `1800` | ffmpeg 转码超时秒数,超时返回 `504`。 | | `JOBS_MAX_CACHE` | `200` | 内存 job LRU 上限;被淘汰的完成任务仍可从磁盘 `status.json` / `result.json` 查询。 | | `MODEL_IDLE_TIMEOUT_SEC` | `180` | GPU 模型空闲卸载超时,默认 180 秒(3 分钟)。设为 `0` 可关闭空闲卸载并保持模型常驻。开启后,只有串行 GPU 运行时空闲达到该秒数才释放已加载模型;下一次 reload 时 ASR、diarization 和 embedding 会在各自 lazy load 时分别选择当前可见 CUDA 中空闲显存最多的设备。 | -| `RUST_KERNEL_MODE` | `off` | 可选 Rust-backed provider/kernel 路径开关。`off` 保持 Python 实现;`required` 要求被选择的 Rust-backed 路径可导入并执行,缺失或调用失败时 fail closed。当前被选择的路径是声纹计分;默认关闭时,CI / Docker packaging 仍会直接验证 Rust 扩展。 | +| `RUST_KERNEL_MODE` | `off` | 可选 Rust-backed provider/kernel 路径开关。`off` 保持 Python 实现;`required` 要求被选择的 Rust-backed 路径可导入并执行,缺失或调用失败时 fail closed。当前被选择的路径是声纹计分和结果后处理;默认关闭时,CI / Docker packaging 仍会直接验证 Rust 扩展。 | `MODELS_DIR` 和 `LANGUAGE` 在配置模块里有定义,但 v0.7.6 的主 HTTP 转写路径 没有把它们作为稳定公开调参入口使用:Whisper 本地 checkpoint 查找仍使用 diff --git a/tests/unit/test_kernel_bridge.py b/tests/unit/test_kernel_bridge.py index 2bcd203..c8037b9 100644 --- a/tests/unit/test_kernel_bridge.py +++ b/tests/unit/test_kernel_bridge.py @@ -9,6 +9,7 @@ from providers.kernel_bridge import ( RustKernelBridgeError, core_smoke, + postprocess_segments, require_rust_core, rust_kernel_mode, rust_provider_paths_enabled, @@ -22,7 +23,7 @@ def _core_smoke(payload): return { "ok": True, "echoed": payload, - "version": "0.8.1", + "version": "0.8.2", "capabilities": {"core_smoke": True, "rust_extension": True}, } @@ -36,7 +37,7 @@ def test_core_smoke_round_trips_safe_payload_through_imported_extension(): assert result["ok"] is True assert result["echoed"] == payload - assert result["version"] == "0.8.1" + assert result["version"] == "0.8.2" assert result["capabilities"]["core_smoke"] is True @@ -79,3 +80,60 @@ def test_rust_kernel_mode_defaults_to_off_semantics(): def test_invalid_rust_kernel_mode_hard_fails(): with pytest.raises(RustKernelBridgeError, match="Invalid RUST_KERNEL_MODE"): rust_kernel_mode("auto") + + +def test_postprocess_segments_round_trips_valid_kernel_response(): + def _importer(module_name): + assert module_name == "voscript_core" + + def _postprocess_segments(payload): + assert payload["aligned_segments"][0]["speaker"] == "SPEAKER_00" + return { + "segments": [ + { + "id": 0, + "start": 0.0, + "end": 1.0, + "text": "hello", + "speaker_label": "SPEAKER_00", + "speaker_id": None, + "speaker_name": "SPEAKER_00", + "similarity": 0, + "words": [ + { + "word": "hello", + "start": 0.0, + "end": 1.0, + "score": 0.0, + } + ], + } + ], + "unique_speakers": ["SPEAKER_00"], + } + + return SimpleNamespace(postprocess_segments=_postprocess_segments) + + result = postprocess_segments( + { + "aligned_segments": [{"speaker": "SPEAKER_00"}], + "speaker_map": {}, + }, + importer=_importer, + ) + + assert result["segments"][0]["id"] == 0 + assert result["segments"][0]["similarity"] == 0.0 + assert result["unique_speakers"] == ["SPEAKER_00"] + + +def test_postprocess_segments_invalid_response_hard_fails(): + def _importer(module_name): + assert module_name == "voscript_core" + return SimpleNamespace(postprocess_segments=lambda payload: {"segments": []}) + + with pytest.raises(RustKernelBridgeError, match="missing keys"): + postprocess_segments( + {"aligned_segments": [], "speaker_map": {}}, + importer=_importer, + ) diff --git a/tests/unit/test_pipeline_alignment.py b/tests/unit/test_pipeline_alignment.py index ee9020c..f748b7e 100644 --- a/tests/unit/test_pipeline_alignment.py +++ b/tests/unit/test_pipeline_alignment.py @@ -32,9 +32,7 @@ def test_normalize_words_returns_json_safe_values(): [{"word": 7, "start": "1.2349", "end": 2, "score": "0.98765"}] ) - assert words == [ - {"word": "7", "start": 1.235, "end": 2.0, "score": 0.9877} - ] + assert words == [{"word": "7", "start": 1.235, "end": 2.0, "score": 0.9877}] def test_build_aligned_segments_attaches_speaker_and_words(): @@ -56,9 +54,7 @@ def test_build_aligned_segments_attaches_speaker_and_words(): "end": 1.2, "text": "hello", "speaker": "SPEAKER_00", - "words": [ - {"word": "hi", "start": 0.01, "end": 0.4, "score": 0.5} - ], + "words": [{"word": "hi", "start": 0.01, "end": 0.4, "score": 0.5}], } ] diff --git a/tests/unit/test_pipeline_runner.py b/tests/unit/test_pipeline_runner.py index 4831f72..c809d58 100644 --- a/tests/unit/test_pipeline_runner.py +++ b/tests/unit/test_pipeline_runner.py @@ -28,6 +28,7 @@ available_stage_slots as available_stage_slots_compat, resolve_stage as resolve_stage_compat, ) +import providers.artifacts.default as artifacts_default from providers.artifacts.default import InMemoryArtifactsProvider @@ -579,6 +580,59 @@ def test_artifacts_preserve_raw_speaker_labels_when_clusters_match_same_voicepri assert unique_speakers == ["Matched Speaker", "Matched Speaker (2)"] +def test_artifacts_use_rust_postprocess_when_required(monkeypatch): + aligned_segments = [ + { + "start": 0.0, + "end": 1.0, + "text": "first", + "speaker": "SPEAKER_00", + } + ] + speaker_map = {} + calls = [] + + def fake_postprocess_segments(payload): + calls.append(payload) + return { + "segments": [ + { + "id": 0, + "start": 0.0, + "end": 1.0, + "text": "first", + "speaker_label": "SPEAKER_00", + "speaker_id": None, + "speaker_name": "SPEAKER_00", + "similarity": 0.0, + } + ], + "unique_speakers": ["SPEAKER_00"], + } + + monkeypatch.setattr(artifacts_default, "rust_provider_paths_enabled", lambda: True) + monkeypatch.setattr( + artifacts_default, "postprocess_segments", fake_postprocess_segments + ) + assert InMemoryArtifactsProvider._build_segments.__globals__[ + "rust_provider_paths_enabled" + ]() + + segments, unique_speakers = InMemoryArtifactsProvider._build_segments( + aligned_segments, + speaker_map, + ) + + assert calls == [ + { + "aligned_segments": aligned_segments, + "speaker_map": speaker_map, + } + ] + assert segments[0]["speaker_label"] == "SPEAKER_00" + assert unique_speakers == ["SPEAKER_00"] + + def test_artifact_result_contract_keeps_status_speaker_label_and_optional_alignment( tmp_path, ): diff --git a/tests/unit/test_postprocess_segments_kernel.py b/tests/unit/test_postprocess_segments_kernel.py new file mode 100644 index 0000000..4e54be7 --- /dev/null +++ b/tests/unit/test_postprocess_segments_kernel.py @@ -0,0 +1,143 @@ +"""Golden tests for pure transcript post-processing kernels.""" + +from __future__ import annotations + +from math import nan + +from postprocess.segments import ( + build_display_names, + build_result_segments, + merge_aligned_segments, + normalize_words, +) + + +def test_normalize_words_handles_missing_none_nan_and_negative_values(): + words = normalize_words( + [ + { + "word": None, + "start": -1, + "end": nan, + "score": "not-a-number", + }, + {}, + ] + ) + + assert words == [ + {"word": "None", "start": 0.0, "end": 0.0, "score": 0.0}, + {"word": "", "start": 0.0, "end": 0.0, "score": 0.0}, + ] + + +def test_merge_aligned_segments_merges_adjacent_text_only_same_speaker(): + segments = merge_aligned_segments( + [ + { + "start": "0", + "end": 1.0, + "text": " first ", + "speaker": "SPEAKER_00", + }, + { + "start": 1.03, + "end": 2.0, + "text": "second", + "speaker": "SPEAKER_00", + }, + { + "start": 2.0, + "end": 3.0, + "text": "third", + "speaker": "SPEAKER_01", + }, + ] + ) + + assert segments == [ + { + "start": 0.0, + "end": 2.0, + "text": "first second", + "speaker": "SPEAKER_00", + }, + { + "start": 2.0, + "end": 3.0, + "text": "third", + "speaker": "SPEAKER_01", + }, + ] + + +def test_merge_aligned_segments_does_not_merge_word_payloads(): + segments = merge_aligned_segments( + [ + { + "start": 0.0, + "end": 1.0, + "text": "first", + "speaker": "SPEAKER_00", + "words": [{"word": "first", "start": 0, "end": 1, "score": 0.8}], + }, + { + "start": 1.0, + "end": 2.0, + "text": "second", + "speaker": "SPEAKER_00", + }, + ] + ) + + assert len(segments) == 2 + assert segments[0]["speaker"] == "SPEAKER_00" + assert segments[0]["words"] == [ + {"word": "first", "start": 0.0, "end": 1.0, "score": 0.8} + ] + + +def test_display_names_disambiguate_without_rewriting_speaker_labels(): + display_names = build_display_names( + ["SPEAKER_00", "SPEAKER_01"], + { + "SPEAKER_00": {"matched_name": "Maple"}, + "SPEAKER_01": {"matched_name": "Maple"}, + }, + ) + + assert display_names == { + "SPEAKER_00": "Maple", + "SPEAKER_01": "Maple (2)", + } + + +def test_build_result_segments_preserves_raw_label_and_unique_display_names(): + segments, unique_speakers = build_result_segments( + [ + {"start": 0.0, "end": 1.0, "text": "a", "speaker": "SPEAKER_00"}, + {"start": 1.0, "end": 2.0, "text": "b", "speaker": "SPEAKER_01"}, + ], + { + "SPEAKER_00": { + "matched_id": "spk_same", + "matched_name": "Maple", + "similarity": 2.0, + }, + "SPEAKER_01": { + "matched_id": "spk_same", + "matched_name": "Maple", + "similarity": 1.0, + }, + }, + ) + + assert [segment["speaker_label"] for segment in segments] == [ + "SPEAKER_00", + "SPEAKER_01", + ] + assert [segment["speaker_name"] for segment in segments] == [ + "Maple", + "Maple (2)", + ] + assert unique_speakers == ["Maple", "Maple (2)"]