Skip to content

Commit db55847

Browse files
author
zach
committed
cleanup: improve task run selection
1 parent 11ed85d commit db55847

File tree

1 file changed

+38
-15
lines changed

1 file changed

+38
-15
lines changed

mcpx_eval/judge.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
logger = logging.getLogger(__name__)
1616

1717

18-
def load_task_run(
19-
client: mcp_run.Client, task: str, task_run_name: str
20-
) -> mcp_run.TaskRun | None:
21-
for run in client.list_task_runs(task):
22-
if run.name == task_run_name:
23-
return run
24-
return None
18+
def is_int(x):
19+
if x is None:
20+
return False
21+
try:
22+
int(x)
23+
return True
24+
except ValueError:
25+
return False
2526

2627

2728
def task_run_index(
@@ -335,6 +336,7 @@ async def run(
335336
expected_tools: List[str],
336337
task: str | None = None,
337338
task_run: str | None = None,
339+
vars: dict | None = None,
338340
) -> Results:
339341
"""Run evaluation across all models."""
340342
scores = []
@@ -350,15 +352,36 @@ async def run(
350352
client, run, check, expected_tools, model_config
351353
)
352354
)
353-
else:
354-
try:
355-
task_run = int(task_run or -1)
356-
except ValueError:
357-
pass
358-
if isinstance(task_run, int):
359-
run = task_run_index(client, task, index=task_run)
355+
elif is_int(task_run) or task_run == "latest":
356+
if task_run == "latest":
357+
task_run = -1
358+
task_run = int(task_run or -1)
359+
run = task_run_index(client, task, index=task_run)
360+
if run is not None:
361+
scores.append(
362+
await self._evaluate_task_run(
363+
client, run, check, expected_tools, model_config
364+
)
365+
)
360366
else:
361-
run = load_task_run(client, task, task_run)
367+
logger.error(f"Unable to load {task_run} for task {task}")
368+
elif task_run is not None and task_run != "new":
369+
found = False
370+
for run in client.list_task_runs(task):
371+
if run.name == task_run:
372+
scores.append(
373+
await self._evaluate_task_run(
374+
client, run, check, expected_tools, model_config
375+
)
376+
)
377+
found = True
378+
if not found:
379+
logger.error(f"Unable to load {task_run} for task {task}")
380+
elif len(self.models) == 0:
381+
logger.info("No task run specified, this will execute a new task run")
382+
run = client.tasks[task].run(vars or {})
383+
run.wait()
384+
run = task_run_index(client, task, index=-1)
362385
if run is not None:
363386
scores.append(
364387
await self._evaluate_task_run(

0 commit comments

Comments
 (0)