Skip to content

Commit 2c48dae

Browse files
committed
refac: move run_tests to module scope
1 parent 80f89a1 commit 2c48dae

File tree

1 file changed

+133
-128
lines changed

1 file changed

+133
-128
lines changed

pytest_asyncio_cooperative/plugin.py

Lines changed: 133 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,144 @@ def item_to_task(item):
122122
return test_wrapper(item)
123123

124124

125-
def _run_test_loop(tasks, session, run_tests):
125+
def get_coro(task):
126+
if sys_version_info >= (3, 8):
127+
return task.get_coro()
128+
else:
129+
return task._coro
130+
131+
132+
def cancel_task(task, now, item):
133+
if sys_version_info >= (3, 9):
134+
msg = "Test took too long ({:.2f} s)".format(now - item.enqueue_time)
135+
task.cancel(msg=msg)
136+
else:
137+
task.cancel()
138+
139+
140+
async def run_tests(tasks, max_tasks: int, session, item_by_coro):
141+
flakes_to_retry = []
142+
143+
sidelined_tasks = tasks[max_tasks:]
144+
tasks = tasks[:max_tasks]
145+
146+
task_timeout = int(
147+
session.config.getoption("--asyncio-task-timeout")
148+
or session.config.getini("asyncio_task_timeout")
149+
)
150+
151+
completed = []
152+
cancelled = []
153+
while tasks:
154+
# Schedule all the coroutines
155+
for i in range(len(tasks)):
156+
if asyncio.iscoroutine(tasks[i]):
157+
tasks[i] = asyncio.create_task(tasks[i])
158+
159+
# Mark when the task was started
160+
earliest_enqueue_time = time.time()
161+
for task in tasks:
162+
if isinstance(task, asyncio.Task):
163+
item = item_by_coro[get_coro(task)]
164+
else:
165+
item = item_by_coro[task]
166+
if not hasattr(item, "enqueue_time"):
167+
item.enqueue_time = time.time()
168+
earliest_enqueue_time = min(item.enqueue_time, earliest_enqueue_time)
169+
170+
time_to_wait = (time.time() - earliest_enqueue_time) - task_timeout
171+
done, pending = await asyncio.wait(
172+
tasks,
173+
return_when=asyncio.FIRST_COMPLETED,
174+
timeout=min(30, int(time_to_wait)),
175+
)
176+
177+
# Cancel tasks that have taken too long
178+
tasks = []
179+
for task in pending:
180+
now = time.time()
181+
item = item_by_coro[get_coro(task)]
182+
if task not in cancelled and task_timeout < now - item.enqueue_time:
183+
cancel_task(task, now, item)
184+
cancelled.append(task)
185+
tasks.append(task)
186+
187+
for result in done:
188+
item = item_by_coro[get_coro(result)]
189+
190+
# Flakey tests will be run again if they failed
191+
# TODO: add retry count
192+
if item._flakey:
193+
try:
194+
result.result()
195+
except:
196+
item._flakey = None
197+
new_task = item_to_task(item)
198+
flakes_to_retry.append(new_task)
199+
item_by_coro[new_task] = item
200+
continue
201+
202+
# We need to change .runtest to a synchronous function for pytest
203+
# however, if it is called again by retry libraries we need to rerun
204+
# the test instead of retuning the previous result
205+
def wrap_in_sync():
206+
def sync_wrapper():
207+
new_task = item_to_task(item)
208+
209+
# We use a new thread because we can't block for an async function
210+
# in the same thread as the current running event loop, nor
211+
# we can nest event loops
212+
result = None
213+
214+
def run_in_thread():
215+
nonlocal result
216+
try:
217+
result = asyncio.run(new_task)
218+
except Exception as e:
219+
result = e
220+
221+
thread = threading.Thread(target=run_in_thread)
222+
thread.start()
223+
thread.join()
224+
225+
if isinstance(result, Exception):
226+
raise result # type: ignore
227+
228+
return result
229+
230+
item.runtest = sync_wrapper
231+
232+
return result.result()
233+
234+
item.runtest = wrap_in_sync
235+
236+
item.ihook.pytest_runtest_protocol(item=item, nextitem=None)
237+
238+
# Hack: See rewrite comment below
239+
# pytest_runttest_protocl will disable the rewrite assertion
240+
# so we renable it here
241+
activate_assert_rewrite(item)
242+
243+
completed.append(result)
244+
245+
if sidelined_tasks:
246+
if len(tasks) < max_tasks:
247+
tasks.append(sidelined_tasks.pop(0))
248+
249+
return flakes_to_retry
250+
251+
252+
def _run_test_loop(tasks, session, item_by_coro):
126253
max_tasks = int(
127254
session.config.getoption("--max-asyncio-tasks")
128255
or session.config.getini("max_asyncio_tasks")
129256
)
130257

131258
loop = asyncio.new_event_loop()
132259
try:
133-
task = run_tests(tasks, int(max_tasks))
134-
loop.run_until_complete(task)
260+
return loop.run_until_complete(
261+
run_tests(tasks, int(max_tasks), session, item_by_coro)
262+
)
135263
finally:
136264
loop.close()
137265

@@ -155,8 +283,6 @@ def pytest_runtestloop(session):
155283

156284
session.wrapped_fixtures = {}
157285

158-
flakes_to_retry = []
159-
160286
# Collect our coroutines
161287
regular_items = []
162288
item_by_coro = {}
@@ -180,127 +306,6 @@ def pytest_runtestloop(session):
180306
else:
181307
regular_items.append(item)
182308

183-
def get_coro(task):
184-
if sys_version_info >= (3, 8):
185-
return task.get_coro()
186-
else:
187-
return task._coro
188-
189-
async def run_tests(tasks, max_tasks: int):
190-
sidelined_tasks = tasks[max_tasks:]
191-
tasks = tasks[:max_tasks]
192-
193-
task_timeout = int(
194-
session.config.getoption("--asyncio-task-timeout")
195-
or session.config.getini("asyncio_task_timeout")
196-
)
197-
198-
completed = []
199-
cancelled = []
200-
while tasks:
201-
# Schedule all the coroutines
202-
for i in range(len(tasks)):
203-
if asyncio.iscoroutine(tasks[i]):
204-
tasks[i] = asyncio.create_task(tasks[i])
205-
206-
# Mark when the task was started
207-
earliest_enqueue_time = time.time()
208-
for task in tasks:
209-
if isinstance(task, asyncio.Task):
210-
item = item_by_coro[get_coro(task)]
211-
else:
212-
item = item_by_coro[task]
213-
if not hasattr(item, "enqueue_time"):
214-
item.enqueue_time = time.time()
215-
earliest_enqueue_time = min(item.enqueue_time, earliest_enqueue_time)
216-
217-
time_to_wait = (time.time() - earliest_enqueue_time) - task_timeout
218-
done, pending = await asyncio.wait(
219-
tasks,
220-
return_when=asyncio.FIRST_COMPLETED,
221-
timeout=min(30, int(time_to_wait)),
222-
)
223-
224-
# Cancel tasks that have taken too long
225-
tasks = []
226-
for task in pending:
227-
now = time.time()
228-
item = item_by_coro[get_coro(task)]
229-
if task not in cancelled and task_timeout < now - item.enqueue_time:
230-
if sys_version_info >= (3, 9):
231-
msg = "Test took too long ({:.2f} s)".format(
232-
now - item.enqueue_time
233-
)
234-
task.cancel(msg=msg)
235-
else:
236-
task.cancel()
237-
cancelled.append(task)
238-
tasks.append(task)
239-
240-
for result in done:
241-
item = item_by_coro[get_coro(result)]
242-
243-
# Flakey tests will be run again if they failed
244-
# TODO: add retry count
245-
if item._flakey:
246-
try:
247-
result.result()
248-
except:
249-
item._flakey = None
250-
new_task = item_to_task(item)
251-
flakes_to_retry.append(new_task)
252-
item_by_coro[new_task] = item
253-
continue
254-
255-
# We need to change .runtest to a synchronous function for pytest
256-
# however, if it is called again by retry libraries we need to rerun
257-
# the test instead of retuning the previous result
258-
def wrap_in_sync():
259-
def sync_wrapper():
260-
new_task = item_to_task(item)
261-
262-
# We use a new thread because we can't block for an async function
263-
# in the same thread as the current running event loop, nor
264-
# we can nest event loops
265-
result = None
266-
267-
def run_in_thread():
268-
nonlocal result
269-
try:
270-
result = asyncio.run(new_task)
271-
except Exception as e:
272-
result = e
273-
274-
thread = threading.Thread(target=run_in_thread)
275-
thread.start()
276-
thread.join()
277-
278-
if isinstance(result, Exception):
279-
raise result # type: ignore
280-
281-
return result
282-
283-
item.runtest = sync_wrapper
284-
285-
return result.result()
286-
287-
item.runtest = wrap_in_sync
288-
289-
item.ihook.pytest_runtest_protocol(item=item, nextitem=None)
290-
291-
# Hack: See rewrite comment below
292-
# pytest_runttest_protocl will disable the rewrite assertion
293-
# so we renable it here
294-
activate_assert_rewrite(item)
295-
296-
completed.append(result)
297-
298-
if sidelined_tasks:
299-
if len(tasks) < max_tasks:
300-
tasks.append(sidelined_tasks.pop(0))
301-
302-
return completed
303-
304309
# Do assert rewrite
305310
# Hack: pytest's implementation sets up assert rewriting as a shared
306311
# resource. This causes a race condition between async tests. Therefore we
@@ -313,11 +318,11 @@ def run_in_thread():
313318
return
314319

315320
# Run the tests using cooperative multitasking
316-
_run_test_loop(tasks, session, run_tests)
321+
flakes_to_retry = _run_test_loop(tasks, session, item_by_coro)
317322

318323
# Run failed flakey tests
319324
if flakes_to_retry:
320-
_run_test_loop(flakes_to_retry, session, run_tests)
325+
_run_test_loop(flakes_to_retry, session, item_by_coro)
321326

322327
# Run synchronous tests
323328
session.items = regular_items

0 commit comments

Comments
 (0)