async def evaluate_pairwise(
    source_a: TraceSource,
    source_b: TraceSource,
    *,
    comparators: Iterable[PairwiseComparator],
    store_dir: str | os.PathLike[str],
    variant_a: str = "A",
    variant_b: str = "B",
    suite: Suite | None = None,
    concurrency: int = 4,
    run_id: str | None = None,
    label: str | None = None,
    stream: Stream | None = None,
) -> PairwiseRunResult:
    """Run every comparator over the traces that variants A and B share.

    A trace from ``source_a`` is paired with a trace from ``source_b`` when both
    carry the same non-``None`` ``TraceRef.task_id`` (the warning emitted below
    names ``ag2.eval.task_id`` as the attribute both sides need). Refs without a
    task id are skipped, and when ``source_b`` lists several refs with the same
    task id the last one listed is the one used. If nothing matches, a warning
    is logged and the run still proceeds with zero pairs.

    Args:
        source_a: Trace source for the first variant.
        source_b: Trace source for the second variant.
        comparators: Comparators handed to every pair; their ``key`` values are
            recorded on the result in iteration order.
        store_dir: Passed through to the ``PairwiseRunResult``.
        variant_a: Display name recorded for the first variant.
        variant_b: Display name recorded for the second variant.
        suite: Optional iterable of tasks; indexed by ``task_id`` and forwarded
            to the per-pair evaluation.
        concurrency: Size of the semaphore shared by all pair evaluations;
            values below 1 are raised to 1.
        run_id: Identifier for the run; a random ``uuid4`` hex string is used
            when omitted.
        label: Shared identifier recorded on the run and on published events.
        stream: When given, ``PairwiseStarted`` is sent before evaluation and
            ``PairwiseCompleted`` after the result is saved; a per-case callback
            built on ``_publish_pairwise_compared`` is also passed to each pair
            evaluation.

    Returns:
        The ``PairwiseRunResult``, after ``save()`` has been called on it.
    """
    comparator_list = tuple(comparators)
    keys = tuple(comparator.key for comparator in comparator_list)

    tasks_by_id = {}
    if suite is not None:
        tasks_by_id = {task.task_id: task for task in suite}

    # Source A is listed in full before source B is touched.
    refs_a = []
    async for ref in source_a.list():
        refs_a.append(ref)

    # Index B by task id; a later ref with the same id overwrites an earlier one.
    b_by_task: dict[str, TraceRef] = {
        ref.task_id: ref async for ref in source_b.list() if ref.task_id is not None
    }

    # Keep A's listing order; drop anything without a counterpart in B.
    pairs = []
    for ref_a in refs_a:
        if ref_a.task_id is None:
            continue
        if ref_a.task_id in b_by_task:
            pairs.append((ref_a, b_by_task[ref_a.task_id]))

    if not pairs:
        logger.warning(
            "evaluate_pairwise: no task_id-matched pairs between the sources (need ag2.eval.task_id on both)."
        )

    semaphore = asyncio.Semaphore(max(1, concurrency))
    actual_run_id = uuid4().hex if run_id is None else run_id
    created_at = datetime.now(timezone.utc).isoformat()
    started = time.perf_counter()

    if stream is None:
        eval_ctx = None
        on_case = None
    else:
        eval_ctx = ConversationContext(stream=stream)
        started_event = PairwiseStarted(
            run_id=actual_run_id,
            label=label,
            variant_a=variant_a,
            variant_b=variant_b,
            total=len(pairs),
        )
        await stream.send(started_event, eval_ctx)
        on_case = partial(_publish_pairwise_compared, stream, eval_ctx, actual_run_id, label)

    # One coroutine per pair; the shared semaphore is handed to each of them.
    pending = [
        _evaluate_pair(semaphore, source_a, source_b, ref_a, ref_b, comparator_list, tasks_by_id, on_case)
        for ref_a, ref_b in pairs
    ]
    case_lists = await asyncio.gather(*pending)

    # Flatten the per-pair case lists, preserving pair order.
    flattened = []
    for case_list in case_lists:
        for case in case_list:
            flattened.append(case)
    cases = tuple(flattened)

    elapsed_seconds = time.perf_counter() - started
    duration_ms = int(elapsed_seconds * 1000)

    result = PairwiseRunResult(
        run_id=actual_run_id,
        cases=cases,
        variant_a=variant_a,
        variant_b=variant_b,
        keys=keys,
        created_at=created_at,
        duration_ms=duration_ms,
        n_pairs=len(pairs),
        label=label,
        store_dir=store_dir,
    )
    result.save()

    if stream is not None:
        completed_event = PairwiseCompleted(run_id=actual_run_id, label=label, result=result)
        await stream.send(completed_event, eval_ctx)
    return result