Skip to content

Commit 179312d

Browse files
committed
handle direct calls to the evaluation_test
1 parent e581153 commit 179312d

File tree

3 files changed

+153
-6
lines changed

3 files changed

+153
-6
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,101 @@ def wrapper_body(**kwargs):
259259

260260
return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)
261261

262-
wrapper = create_wrapper_with_signature()
263-
wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(wrapper)
264-
wrapper.original_evaluation_test_func = test_func
262+
# Create the pytest wrapper
263+
pytest_wrapper = create_wrapper_with_signature()
264+
pytest_wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(pytest_wrapper)
265265

266-
return wrapper
266+
def create_dual_mode_wrapper() -> Callable:
267+
"""
268+
Creates a wrapper that supports both pytest parameterized execution and direct function calls.
269+
270+
This wrapper enables the decorated evaluation test function to be used in two ways:
271+
1. As a pytest test (via pytest.mark.parametrize) with full parameterization
272+
2. As a direct function call with EvaluationRow data for programmatic use
273+
274+
The wrapper automatically detects the calling pattern and routes to the appropriate
275+
execution path, ensuring consistent behavior regardless of how the function is invoked.
276+
277+
Returns:
278+
A callable that can handle both pytest test execution and direct function calls
279+
"""
280+
import asyncio
281+
282+
# Check if the test function is async
283+
is_async = asyncio.iscoroutinefunction(test_func)
284+
285+
if is_async:
286+
287+
async def dual_mode_wrapper(*args, **kwargs):
288+
# Check if this is a direct call with the expected signature
289+
if mode == "pointwise":
290+
# For pointwise mode, check if called with a single row argument
291+
if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs:
292+
return await test_func(row=args[0])
293+
else:
294+
# For batch mode, check if called with rows argument
295+
if (
296+
len(args) == 1
297+
and isinstance(args[0], list)
298+
and all(isinstance(r, EvaluationRow) for r in args[0])
299+
and not kwargs
300+
):
301+
return await test_func(rows=args[0])
302+
# Also check if called with keyword argument 'rows'
303+
if (
304+
len(args) == 0
305+
and "rows" in kwargs
306+
and isinstance(kwargs["rows"], list)
307+
and all(isinstance(r, EvaluationRow) for r in kwargs["rows"])
308+
):
309+
return await test_func(**kwargs)
310+
311+
# If not a direct call, use the pytest wrapper
312+
return pytest_wrapper(*args, **kwargs)
313+
314+
else:
315+
316+
def dual_mode_wrapper(*args, **kwargs):
317+
# Check if this is a direct call with the expected signature
318+
if mode == "pointwise":
319+
# For pointwise mode, check if called with a single row argument
320+
if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs:
321+
return test_func(row=args[0])
322+
323+
if len(args) == 0 and "row" in kwargs and isinstance(kwargs["row"], EvaluationRow):
324+
return test_func(**kwargs)
325+
else:
326+
# For batch mode, check if called with rows argument
327+
if (
328+
len(args) == 1
329+
and isinstance(args[0], list)
330+
and all(isinstance(r, EvaluationRow) for r in args[0])
331+
and not kwargs
332+
):
333+
return test_func(rows=args[0])
334+
# Also check if called with keyword argument 'rows'
335+
if (
336+
len(args) == 0
337+
and "rows" in kwargs
338+
and isinstance(kwargs["rows"], list)
339+
and all(isinstance(r, EvaluationRow) for r in kwargs["rows"])
340+
):
341+
return test_func(**kwargs)
342+
343+
# If not a direct call, use the pytest wrapper
344+
return pytest_wrapper(*args, **kwargs)
345+
346+
# Copy all attributes from the pytest wrapper to our dual mode wrapper
347+
import functools
348+
349+
functools.update_wrapper(dual_mode_wrapper, pytest_wrapper)
350+
dual_mode_wrapper.original_evaluation_test_func = test_func
351+
352+
return dual_mode_wrapper
353+
354+
# Create the dual mode wrapper
355+
dual_mode_wrapper = create_dual_mode_wrapper()
356+
357+
return dual_mode_wrapper
267358

268359
return decorator

eval_protocol/pytest/utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,17 @@ def execute_function(func: Callable, **kwargs) -> Any:
2424
# Handle async functions with proper event loop management
2525
try:
2626
loop = asyncio.get_event_loop()
27-
if not loop.is_closed():
28-
# Use existing loop
27+
if loop.is_running():
28+
# Event loop is already running, create a task and wait for it
29+
task = loop.create_task(func(**kwargs))
30+
# Use asyncio.wait to avoid run_until_complete on running loop
31+
import concurrent.futures
32+
33+
with concurrent.futures.ThreadPoolExecutor() as executor:
34+
future = executor.submit(asyncio.run, func(**kwargs))
35+
results = future.result()
36+
elif not loop.is_closed():
37+
# Use existing loop that's not running
2938
task = loop.create_task(func(**kwargs))
3039
results = loop.run_until_complete(task)
3140
else:

tests/pytest/test_pytest_async.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import asyncio
12
from typing import List
23

4+
import pytest
5+
36
from eval_protocol.models import EvaluationRow, Message
47
from eval_protocol.pytest import evaluation_test
58
from examples.math_example.main import evaluate as math_evaluate
@@ -19,3 +22,47 @@
1922
async def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
2023
"""Run math evaluation on sample dataset using pytest interface."""
2124
return rows
25+
26+
27+
@evaluation_test(
28+
input_messages=[
29+
[
30+
Message(role="user", content="What is the capital of France?"),
31+
],
32+
],
33+
model=["accounts/fireworks/models/kimi-k2-instruct"],
34+
mode="pointwise",
35+
)
36+
async def test_pytest_async_pointwise(row: EvaluationRow) -> EvaluationRow:
37+
"""Run pointwise evaluation on sample dataset using pytest interface."""
38+
return row
39+
40+
41+
@pytest.mark.asyncio
42+
async def test_pytest_async_main():
43+
"""
44+
Tests that we can just run the test function directly
45+
"""
46+
rows = [
47+
EvaluationRow(
48+
messages=[
49+
Message(role="user", content="What is the capital of France?"),
50+
],
51+
)
52+
]
53+
result = await test_pytest_async(rows)
54+
assert result == rows
55+
56+
57+
@pytest.mark.asyncio
58+
async def test_pytest_async_pointwise_main():
59+
"""
60+
Tests that we can just run the pointwise test function directly
61+
"""
62+
row = EvaluationRow(
63+
messages=[
64+
Message(role="user", content="What is the capital of France?"),
65+
],
66+
)
67+
result = await test_pytest_async_pointwise(row)
68+
assert result == row

0 commit comments

Comments
 (0)