9 changes: 3 additions & 6 deletions hud/datasets/__init__.py
@@ -16,12 +16,8 @@
from hud.eval.display import display_results

from .loader import load_dataset, load_tasks, save_tasks
from .runner import run_dataset, run_single_task
from .utils import (
BatchRequest,
SingleTaskRequest,
submit_rollouts,
)
from .runner import run_dataset, run_dataset_async, run_single_task
from .utils import BatchRequest, SingleTaskRequest, submit_rollouts

__all__ = [
"BatchRequest",
@@ -30,6 +26,7 @@
"load_dataset", # Deprecated alias
"load_tasks",
"run_dataset",
"run_dataset_async",
"run_single_task",
"save_tasks",
"submit_rollouts",
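For reference, a minimal sketch of how the re-exported name is reached after this change (the import path comes from this diff; nothing else is assumed):

```python
# run_dataset_async is now re-exported alongside the existing runner entry points.
from hud.datasets import run_dataset, run_dataset_async
```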
85 changes: 85 additions & 0 deletions hud/datasets/runner.py
@@ -5,6 +5,7 @@

from __future__ import annotations

import asyncio
import logging
from typing import TYPE_CHECKING, Any

@@ -217,3 +218,87 @@ async def run_single_task(

# Return the Trace (ctx.reward is set by EvalContext.__aexit__)
return result


async def run_dataset_async(
tasks: str | TaskInput | Sequence[TaskInput],
agent_type: str | AgentType,
*,
agent_params: dict[str, Any] | None = None,
max_steps: int = 10,
max_concurrent: int = 30,
group_size: int = 1,
quiet: bool = True,
) -> list[Trace]:
"""Run tasks concurrently by fanning out run_single_task calls.

Returns a flat list of Trace results in task order, repeated by group_size.
"""
from hud.datasets.loader import load_tasks
from hud.eval.task import Task, build_eval_name

if group_size <= 0:
raise ValueError("group_size must be >= 1")
if max_concurrent <= 0:
raise ValueError("max_concurrent must be >= 1")

if isinstance(agent_type, str):
agent_type = AgentType(agent_type)

task_list: list[Task]
if isinstance(tasks, str):
task_list = load_tasks(tasks)
elif isinstance(tasks, Task):
task_list = [tasks]
elif isinstance(tasks, LegacyTask | dict):
task_list = [Task.from_v4(tasks)]
else:
task_list = [t if isinstance(t, Task) else Task.from_v4(t) for t in tasks]

if not task_list:
raise ValueError("No tasks to run")

sem = asyncio.Semaphore(max_concurrent)
total_runs = len(task_list) * group_size
results: list[Trace | None] = [None] * total_runs

async def _worker(
out_index: int,
task: Task,
task_index: int,
run_index: int,
) -> None:
async with sem:
try:
base_name = task.id
if base_name is None:
eval_name = build_eval_name(task.scenario, task.args)
base_name = eval_name if eval_name != "eval" else f"Task {task_index}"
trace_name = f"{base_name} #{run_index + 1}" if group_size > 1 else base_name

results[out_index] = await run_single_task(
task,
agent_type=agent_type,
agent_params=agent_params,
max_steps=max_steps,
trace_name=trace_name,
quiet=quiet,
)
except Exception as exc:
logger.exception("Task %s failed: %s", out_index, exc)
results[out_index] = Trace(isError=True, info={"error": str(exc)})

tasks_to_run: list[tuple[int, Task, int, int]] = []
for task_index, task in enumerate(task_list):
for run_index in range(group_size):
out_index = task_index * group_size + run_index
tasks_to_run.append((out_index, task, task_index, run_index))

await asyncio.gather(*[_worker(*task_args) for task_args in tasks_to_run])

return [
result
if result is not None
else Trace(isError=True, info={"error": "Task did not return a result"})
for result in results
]
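A usage sketch for the new runner, under stated assumptions: the env/scenario/args values below are placeholders (only the Task constructor shape and the run_dataset_async parameters come from this diff), and agent_type="claude" mirrors the test further down.

```python
import asyncio

from hud.datasets import run_dataset_async
from hud.eval.task import Task


async def main() -> None:
    # Hypothetical tasks; env/scenario/args values are illustrative, not from this PR.
    tasks = [
        Task(env={"name": "browser"}, scenario="checkout", args={"item": "book"}),
        Task(env={"name": "browser"}, scenario="search", args={"query": "hud"}),
    ]

    # Fan the tasks out through run_single_task, with at most 10 runs in flight
    # and each task executed twice (group_size=2).
    traces = await run_dataset_async(
        tasks,
        agent_type="claude",
        max_steps=20,
        max_concurrent=10,
        group_size=2,
    )

    # Results come back flat and ordered: task 0 run 1, task 0 run 2, task 1 run 1, ...
    for trace in traces:
        print(trace.reward, trace.isError)


if __name__ == "__main__":
    asyncio.run(main())
```

Design-wise, the pre-allocated results list plus the semaphore keeps output ordering deterministic while capping in-flight runs, and per-run failures are folded into error Traces rather than escaping asyncio.gather.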
29 changes: 28 additions & 1 deletion hud/tests/test_datasets_extended.py
@@ -7,7 +7,7 @@

import pytest

from hud.datasets import run_dataset
from hud.datasets import run_dataset, run_dataset_async
from hud.types import LegacyTask, MCPToolCall


@@ -239,3 +239,30 @@ async def test_run_dataset_passes_parameters(self):
quiet=True,
taskset=None,
)

@pytest.mark.asyncio
async def test_run_dataset_async_empty(self):
"""Test async runner rejects empty datasets."""
from hud.types import AgentType

with pytest.raises(ValueError, match="No tasks to run"):
await run_dataset_async([], agent_type=AgentType.CLAUDE)

@pytest.mark.asyncio
async def test_run_dataset_async_group_size(self):
"""Test async runner repeats tasks for group_size."""
from hud.eval.task import Task
from hud.types import Trace

task = Task(env={"name": "test"}, scenario="sample", args={})
traces = [Trace(reward=1.0, done=True), Trace(reward=2.0, done=True)]

with patch(
"hud.datasets.runner.run_single_task", new=AsyncMock(side_effect=traces)
) as mock_run:
results = await run_dataset_async([task], agent_type="claude", group_size=2)

assert results == traces
assert mock_run.call_count == 2
assert mock_run.call_args_list[0].kwargs["trace_name"] == "sample #1"
assert mock_run.call_args_list[1].kwargs["trace_name"] == "sample #2"