Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion py/src/braintrust/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,14 @@ def _validate_classification_result(value: Any, classifier_name: str) -> Classif
return classification


def _get_persisted_base_experiment_id(experiment: Experiment) -> str | None:
try:
base_experiment_id = experiment.data.get("base_exp_id")
except Exception:
return None
return base_experiment_id if isinstance(base_experiment_id, str) and base_experiment_id else None


async def run_evaluator(
experiment: Experiment | None,
evaluator: Evaluator[Input, Output, Expected],
Expand All @@ -1367,7 +1375,13 @@ async def run_evaluator(
)

if experiment:
summary = experiment.summarize(summarize_scores=evaluator.summarize_scores)
comparison_experiment_id = evaluator.base_experiment_id
if comparison_experiment_id is None:
comparison_experiment_id = _get_persisted_base_experiment_id(experiment)
summary = experiment.summarize(
summarize_scores=evaluator.summarize_scores,
comparison_experiment_id=comparison_experiment_id,
)
else:
summary = build_local_summary(evaluator, results)

Expand Down
6 changes: 6 additions & 0 deletions py/src/braintrust/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3969,6 +3969,12 @@ def summarize(
if base_experiment:
comparison_experiment_id = base_experiment.id
comparison_experiment_name = base_experiment.name
else:
try:
comparison_experiment = state.api_conn().get_json(f"v1/experiment/{comparison_experiment_id}")
comparison_experiment_name = comparison_experiment.get("name")
except Exception:
pass

try:
summary_items = state.api_conn().get_json(
Expand Down
88 changes: 87 additions & 1 deletion py/src/braintrust/test_framework.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib.util
import re
import sys
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from braintrust.logger import BraintrustState
Expand Down Expand Up @@ -78,6 +78,92 @@ def exact_match(input_value, output, expected):
assert result.summary.scores["exact_match"].score == 1.0


@pytest.mark.asyncio
async def test_run_evaluator_forwards_base_experiment_id_to_summary(with_memory_logger, with_simulate_login):
def exact_match(input_value, output, expected):
return 1.0 if output == expected else 0.0

evaluator = Evaluator(
project_name="test-project",
eval_name="test-evaluator",
data=[EvalCase(input=1, expected=1)],
task=lambda input_value: input_value,
scores=[exact_match],
experiment_name=None,
metadata=None,
base_experiment_id="base-exp-id",
)

exp = init_test_exp("test-evaluator", "test-project")
expected_summary = MagicMock()
exp.summarize = MagicMock(return_value=expected_summary)

result = await run_evaluator(experiment=exp, evaluator=evaluator, position=None, filters=[])

assert result.summary is expected_summary
exp.summarize.assert_called_once_with(
summarize_scores=True,
comparison_experiment_id="base-exp-id",
)


@pytest.mark.asyncio
async def test_run_evaluator_forwards_persisted_base_experiment_id_to_summary(with_memory_logger, with_simulate_login):
def exact_match(input_value, output, expected):
return 1.0 if output == expected else 0.0

evaluator = Evaluator(
project_name="test-project",
eval_name="test-evaluator",
data=[EvalCase(input=1, expected=1)],
task=lambda input_value: input_value,
scores=[exact_match],
experiment_name=None,
metadata=None,
base_experiment_name="base-exp",
)

exp = init_test_exp("test-evaluator", "test-project")
exp.data["base_exp_id"] = "base-exp-id"
expected_summary = MagicMock()
exp.summarize = MagicMock(return_value=expected_summary)

result = await run_evaluator(experiment=exp, evaluator=evaluator, position=None, filters=[])

assert result.summary is expected_summary
exp.summarize.assert_called_once_with(
summarize_scores=True,
comparison_experiment_id="base-exp-id",
)


def test_experiment_summarize_resolves_explicit_comparison_name(with_memory_logger, with_simulate_login):
exp = init_test_exp("test-evaluator", "test-project")
mock_conn = MagicMock()

def get_json(path, args=None):
if path == "v1/experiment/base-exp-id":
return {"name": "base-exp"}
if path == "experiment-comparison2":
return {"scores": {}, "metrics": {}}
raise AssertionError(f"Unexpected get_json call: {path}, {args}")

mock_conn.get_json.side_effect = get_json

with patch.object(exp.state, "api_conn", return_value=mock_conn):
summary = exp.summarize(comparison_experiment_id="base-exp-id")

assert summary.comparison_experiment_name == "base-exp"
mock_conn.get_json.assert_any_call("v1/experiment/base-exp-id")
mock_conn.get_json.assert_any_call(
"experiment-comparison2",
args={
"experiment_id": "test-evaluator",
"base_experiment_id": "base-exp-id",
},
)


@pytest.mark.asyncio
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
async def test_run_evaluator_exposes_validated_parameter_values_to_hooks():
Expand Down
Loading