Skip to content

Commit 013bf4f

Browse files
gerlerowillemt
authored andcommitted
Support pytest_asyncio_cooperative-marked synchronous tests
1 parent 9d15ec7 commit 013bf4f

File tree

3 files changed

+36
-52
lines changed

3 files changed

+36
-52
lines changed

pytest_asyncio_cooperative/plugin.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,6 @@ def pytest_runtest_makereport(item, call):
7373
call.duration = call.stop - call.start
7474

7575

76-
def not_coroutine_failure(function_name: str, *args, **kwargs):
77-
raise Exception(
78-
f"Function {function_name} is not a coroutine.\n"
79-
f"Tests with the `@pytest.mark.asyncio_cooperative` mark MUST be coroutines.\n"
80-
f"Please add the `async` keyword to the test function."
81-
)
82-
83-
8476
async def test_wrapper(item):
8577
# Do setup
8678
item.start_setup = time.time()
@@ -109,7 +101,10 @@ async def do_teardowns():
109101
# Run test
110102
item.start = time.time()
111103
try:
112-
await item.function(*fixture_values)
104+
if inspect.iscoroutinefunction(item.function):
105+
await item.function(*fixture_values)
106+
else:
107+
item.function(*fixture_values)
113108
except:
114109
# Teardown here otherwise we might leave fixtures with locks acquired
115110
item.stop = time.time()
@@ -158,17 +153,11 @@ def async_to_sync(*args, **kwargs):
158153
item.stop_teardown = time.time()
159154

160155

161-
class NotCoroutine(Exception):
162-
pass
163-
164-
165156
def item_to_task(item):
166-
if inspect.iscoroutinefunction(item.function):
167-
return test_wrapper(item)
168-
elif getattr(item.function, "is_hypothesis_test", False):
157+
if getattr(item.function, "is_hypothesis_test", False):
169158
return hypothesis_test_wrapper(item)
170159
else:
171-
raise NotCoroutine
160+
return test_wrapper(item)
172161

173162

174163
def _run_test_loop(tasks, session, run_tests):
@@ -221,12 +210,7 @@ def pytest_runtestloop(session):
221210

222211
# Coerce into a task
223212
if "asyncio_cooperative" in markers:
224-
try:
225-
task = item_to_task(item)
226-
except NotCoroutine:
227-
item.runtest = functools.partial(not_coroutine_failure, item.name)
228-
item.ihook.pytest_runtest_protocol(item=item, nextitem=None)
229-
continue
213+
task = item_to_task(item)
230214

231215
item._flakey = "flakey" in markers
232216
item_by_coro[task] = item

tests/test_core.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,32 @@ def test_b():
9191

9292
# we expect this:
9393
result.assert_outcomes(failed=1, passed=1)
94+
95+
96+
def test_synchronous_test_with_async_fixture(testdir):
97+
testdir.makepyfile(
98+
"""
99+
import asyncio
100+
import pytest
101+
102+
103+
@pytest.fixture
104+
async def async_fixture():
105+
return await asyncio.sleep(1, 42)
106+
107+
@pytest.fixture
108+
def sync_fixture():
109+
return 42
110+
111+
@pytest.mark.asyncio_cooperative
112+
async def test_async(async_fixture, sync_fixture):
113+
assert async_fixture == sync_fixture == 42
114+
115+
@pytest.mark.asyncio_cooperative
116+
def test_sync(async_fixture, sync_fixture):
117+
assert async_fixture == sync_fixture == 42
118+
"""
119+
)
120+
121+
result = testdir.runpytest()
122+
result.assert_outcomes(passed=2)

tests/test_fail.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,5 @@
11
import pytest
22

3-
from .conftest import includes_lines
4-
5-
6-
def test_function_must_be_async(testdir):
7-
testdir.makeconftest("""""")
8-
9-
testdir.makepyfile(
10-
"""
11-
import asyncio
12-
import pytest
13-
14-
15-
@pytest.mark.asyncio_cooperative
16-
def test_a():
17-
assert 1 == 1
18-
"""
19-
)
20-
21-
expected_lines = [
22-
"E Exception: Function test_a is not a coroutine.",
23-
"E Tests with the `@pytest.mark.asyncio_cooperative` mark MUST be coroutines.",
24-
"E Please add the `async` keyword to the test function.",
25-
]
26-
27-
result = testdir.runpytest()
28-
assert includes_lines(expected_lines, result.stdout.lines)
29-
30-
result.assert_outcomes(failed=1)
31-
323

334
@pytest.mark.parametrize("dur1, dur2, expectedfails, expectedpasses", [
345
(1.1, 2, 2, 0),

0 commit comments

Comments
 (0)