diff --git a/hud/datasets/__init__.py b/hud/datasets/__init__.py
index 6bf88851..fa860942 100644
--- a/hud/datasets/__init__.py
+++ b/hud/datasets/__init__.py
@@ -16,12 +16,8 @@
 from hud.eval.display import display_results
 
 from .loader import load_dataset, load_tasks, save_tasks
-from .runner import run_dataset, run_single_task
-from .utils import (
-    BatchRequest,
-    SingleTaskRequest,
-    submit_rollouts,
-)
+from .runner import run_dataset, run_dataset_async, run_single_task
+from .utils import BatchRequest, SingleTaskRequest, submit_rollouts
 
 __all__ = [
     "BatchRequest",
@@ -30,6 +26,7 @@
     "load_dataset",  # Deprecated alias
     "load_tasks",
     "run_dataset",
+    "run_dataset_async",
     "run_single_task",
     "save_tasks",
     "submit_rollouts",
diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py
index 70174c53..2f55a25c 100644
--- a/hud/datasets/runner.py
+++ b/hud/datasets/runner.py
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import logging
 from typing import TYPE_CHECKING, Any
 
@@ -217,3 +218,87 @@ async def run_single_task(
 
     # Return the Trace (ctx.reward is set by EvalContext.__aexit__)
     return result
+
+
+async def run_dataset_async(
+    tasks: str | TaskInput | Sequence[TaskInput],
+    agent_type: str | AgentType,
+    *,
+    agent_params: dict[str, Any] | None = None,
+    max_steps: int = 10,
+    max_concurrent: int = 30,
+    group_size: int = 1,
+    quiet: bool = True,
+) -> list[Trace]:
+    """Run tasks concurrently by fanning out run_single_task calls.
+
+    Returns a flat list of Trace results in task order, repeated by group_size.
+    """
+    from hud.datasets.loader import load_tasks
+    from hud.eval.task import Task, build_eval_name
+
+    if group_size <= 0:
+        raise ValueError("group_size must be >= 1")
+    if max_concurrent <= 0:
+        raise ValueError("max_concurrent must be >= 1")
+
+    if isinstance(agent_type, str):
+        agent_type = AgentType(agent_type)
+
+    task_list: list[Task]
+    if isinstance(tasks, str):
+        task_list = load_tasks(tasks)
+    elif isinstance(tasks, Task):
+        task_list = [tasks]
+    elif isinstance(tasks, LegacyTask | dict):
+        task_list = [Task.from_v4(tasks)]
+    else:
+        task_list = [t if isinstance(t, Task) else Task.from_v4(t) for t in tasks]
+
+    if not task_list:
+        raise ValueError("No tasks to run")
+
+    sem = asyncio.Semaphore(max_concurrent)
+    total_runs = len(task_list) * group_size
+    results: list[Trace | None] = [None] * total_runs
+
+    async def _worker(
+        out_index: int,
+        task: Task,
+        task_index: int,
+        run_index: int,
+    ) -> None:
+        async with sem:
+            try:
+                base_name = task.id
+                if base_name is None:
+                    eval_name = build_eval_name(task.scenario, task.args)
+                    base_name = eval_name if eval_name != "eval" else f"Task {task_index}"
+                trace_name = f"{base_name} #{run_index + 1}" if group_size > 1 else base_name
+
+                results[out_index] = await run_single_task(
+                    task,
+                    agent_type=agent_type,
+                    agent_params=agent_params,
+                    max_steps=max_steps,
+                    trace_name=trace_name,
+                    quiet=quiet,
+                )
+            except Exception as exc:
+                logger.exception("Task %s failed: %s", out_index, exc)
+                results[out_index] = Trace(isError=True, info={"error": str(exc)})
+
+    tasks_to_run: list[tuple[int, Task, int, int]] = []
+    for task_index, task in enumerate(task_list):
+        for run_index in range(group_size):
+            out_index = task_index * group_size + run_index
+            tasks_to_run.append((out_index, task, task_index, run_index))
+
+    await asyncio.gather(*[_worker(*task_args) for task_args in tasks_to_run])
+
+    return [
+        result
+        if result is not None
+        else Trace(isError=True, info={"error": "Task did not return a result"})
+        for result in results
+    ]
diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py
index 3a870aaa..051d7d67 100644
--- a/hud/tests/test_datasets_extended.py
+++ b/hud/tests/test_datasets_extended.py
@@ -7,7 +7,7 @@
 
 import pytest
 
-from hud.datasets import run_dataset
+from hud.datasets import run_dataset, run_dataset_async
 from hud.types import LegacyTask, MCPToolCall
 
 
@@ -239,3 +239,30 @@ async def test_run_dataset_passes_parameters(self):
             quiet=True,
             taskset=None,
         )
+
+    @pytest.mark.asyncio
+    async def test_run_dataset_async_empty(self):
+        """Test async runner rejects empty datasets."""
+        from hud.types import AgentType
+
+        with pytest.raises(ValueError, match="No tasks to run"):
+            await run_dataset_async([], agent_type=AgentType.CLAUDE)
+
+    @pytest.mark.asyncio
+    async def test_run_dataset_async_group_size(self):
+        """Test async runner repeats tasks for group_size."""
+        from hud.eval.task import Task
+        from hud.types import Trace
+
+        task = Task(env={"name": "test"}, scenario="sample", args={})
+        traces = [Trace(reward=1.0, done=True), Trace(reward=2.0, done=True)]
+
+        with patch(
+            "hud.datasets.runner.run_single_task", new=AsyncMock(side_effect=traces)
+        ) as mock_run:
+            results = await run_dataset_async([task], agent_type="claude", group_size=2)
+
+        assert results == traces
+        assert mock_run.call_count == 2
+        assert mock_run.call_args_list[0].kwargs["trace_name"] == "sample #1"
+        assert mock_run.call_args_list[1].kwargs["trace_name"] == "sample #2"
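
For reviewers, a minimal usage sketch of the new `run_dataset_async` entry point exported from `hud.datasets`. The dataset slug `"my-org/my-taskset"` is a placeholder (any value accepted by `load_tasks`, or a `Task`/sequence of tasks, works per the dispatch at the top of the function); the agent and step settings shown are illustrative, not defaults.

```python
import asyncio

from hud.datasets import run_dataset_async


async def main() -> None:
    # Fan every task out to run_single_task, at most 10 in flight at a time,
    # and run each task twice (group_size=2).
    traces = await run_dataset_async(
        "my-org/my-taskset",  # placeholder slug; resolved through load_tasks()
        agent_type="claude",
        max_steps=20,
        max_concurrent=10,
        group_size=2,
    )
    # Results come back flat and in task order, group_size entries per task.
    for trace in traces:
        print(trace.reward, trace.isError)


asyncio.run(main())
```

Because per-run failures are converted into `Trace(isError=True, ...)` entries inside `_worker`, a single bad task does not cancel the rest of the batch; callers can filter on `isError` afterwards.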