Skip to content

evaluate_pairwise

autogen.beta.eval.pairwise.evaluate_pairwise async #

evaluate_pairwise(source_a, source_b, *, comparators, store_dir, variant_a='A', variant_b='B', suite=None, concurrency=4, run_id=None, label=None, stream=None)

Compare variant A vs variant B over the tasks both produced traces for.

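Traces are paired by TraceRef.task_id (stamp ag2.eval.task_id on each trace at produce time); traces with no task_id, or whose task_id appears in only one of the two sources, are left out of the comparison. Each comparator runs over every pair; results roll up per comparator key.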

label is a shared identifier recorded on the run. Pass stream to publish pairwise lifecycle events for live observation: PairwiseStarted before any pair is compared, PairwiseCompared per case, and PairwiseCompleted once the result has been saved.

Source code in autogen/beta/eval/pairwise.py
async def evaluate_pairwise(
    source_a: TraceSource,
    source_b: TraceSource,
    *,
    comparators: Iterable[PairwiseComparator],
    store_dir: str | os.PathLike[str],
    variant_a: str = "A",
    variant_b: str = "B",
    suite: Suite | None = None,
    concurrency: int = 4,
    run_id: str | None = None,
    label: str | None = None,
    stream: Stream | None = None,
) -> PairwiseRunResult:
    """Run every comparator over the traces that variant A and variant B share.

    A trace from ``source_a`` is matched with the trace from ``source_b`` that
    carries the same ``TraceRef.task_id`` (stamp ``ag2.eval.task_id`` at produce
    time). Refs without a ``task_id``, or whose ``task_id`` is missing from the
    other source, are left out. When ``source_b`` yields several refs for one
    ``task_id``, the last one listed is the one that gets paired.

    ``comparators`` is materialised once up front and its ``key`` values are
    recorded on the result. ``suite``, when given, is indexed by ``task_id`` and
    handed to the per-pair evaluation. ``concurrency`` sizes the semaphore shared
    by all pair evaluations (values below 1 are treated as 1). ``run_id``
    defaults to a fresh ``uuid4`` hex string.

    ``label`` is recorded on the run and on the published events. With a
    ``stream``, ``PairwiseStarted`` is sent before any pair is evaluated,
    ``PairwiseCompared`` is published per case, and ``PairwiseCompleted`` is sent
    after the result has been saved.

    Returns the saved ``PairwiseRunResult``.
    """
    active_comparators = tuple(comparators)
    comparator_keys = tuple(comparator.key for comparator in active_comparators)
    if suite is not None:
        task_index = {task.task_id: task for task in suite}
    else:
        task_index = {}

    # Source A is listed in full before source B is touched.
    left_refs = [ref async for ref in source_a.list()]
    right_by_task: dict[str, TraceRef] = {
        ref.task_id: ref async for ref in source_b.list() if ref.task_id is not None
    }

    # Keep A's listing order; only task ids present on both sides form a pair.
    matched = []
    for left in left_refs:
        task_id = left.task_id
        if task_id is None or task_id not in right_by_task:
            continue
        matched.append((left, right_by_task[task_id]))
    n_pairs = len(matched)
    if n_pairs == 0:
        logger.warning(
            "evaluate_pairwise: no task_id-matched pairs between the sources (need ag2.eval.task_id on both)."
        )

    gate = asyncio.Semaphore(max(1, concurrency))
    resolved_run_id = uuid4().hex if run_id is None else run_id
    created_at = datetime.now(timezone.utc).isoformat()
    t0 = time.perf_counter()

    # Event publication is entirely optional: without a stream there is no
    # context and no per-case callback.
    if stream is None:
        eval_ctx = None
        on_case = None
    else:
        eval_ctx = ConversationContext(stream=stream)
        started_event = PairwiseStarted(
            run_id=resolved_run_id,
            label=label,
            variant_a=variant_a,
            variant_b=variant_b,
            total=n_pairs,
        )
        await stream.send(started_event, eval_ctx)
        on_case = partial(_publish_pairwise_compared, stream, eval_ctx, resolved_run_id, label)

    pending = [
        _evaluate_pair(gate, source_a, source_b, left, right, active_comparators, task_index, on_case)
        for left, right in matched
    ]
    per_pair_cases = await asyncio.gather(*pending)

    # Flatten the per-pair case lists, preserving pair order.
    flattened = []
    for pair_cases in per_pair_cases:
        flattened.extend(pair_cases)
    cases = tuple(flattened)
    elapsed_ms = int((time.perf_counter() - t0) * 1000)

    result = PairwiseRunResult(
        run_id=resolved_run_id,
        cases=cases,
        variant_a=variant_a,
        variant_b=variant_b,
        keys=comparator_keys,
        created_at=created_at,
        duration_ms=elapsed_ms,
        n_pairs=n_pairs,
        label=label,
        store_dir=store_dir,
    )
    result.save()
    if stream is not None:
        completed_event = PairwiseCompleted(run_id=resolved_run_id, label=label, result=result)
        await stream.send(completed_event, eval_ctx)
    return result