@@ -122,16 +122,144 @@ def item_to_task(item):
122122 return test_wrapper (item )
123123
124124
125- def _run_test_loop (tasks , session , run_tests ):
125+ def get_coro (task ):
126+ if sys_version_info >= (3 , 8 ):
127+ return task .get_coro ()
128+ else :
129+ return task ._coro
130+
131+
132+ def cancel_task (task , now , item ):
133+ if sys_version_info >= (3 , 9 ):
134+ msg = "Test took too long ({:.2f} s)" .format (now - item .enqueue_time )
135+ task .cancel (msg = msg )
136+ else :
137+ task .cancel ()
138+
139+
140+ async def run_tests (tasks , max_tasks : int , session , item_by_coro ):
141+ flakes_to_retry = []
142+
143+ sidelined_tasks = tasks [max_tasks :]
144+ tasks = tasks [:max_tasks ]
145+
146+ task_timeout = int (
147+ session .config .getoption ("--asyncio-task-timeout" )
148+ or session .config .getini ("asyncio_task_timeout" )
149+ )
150+
151+ completed = []
152+ cancelled = []
153+ while tasks :
154+ # Schedule all the coroutines
155+ for i in range (len (tasks )):
156+ if asyncio .iscoroutine (tasks [i ]):
157+ tasks [i ] = asyncio .create_task (tasks [i ])
158+
159+ # Mark when the task was started
160+ earliest_enqueue_time = time .time ()
161+ for task in tasks :
162+ if isinstance (task , asyncio .Task ):
163+ item = item_by_coro [get_coro (task )]
164+ else :
165+ item = item_by_coro [task ]
166+ if not hasattr (item , "enqueue_time" ):
167+ item .enqueue_time = time .time ()
168+ earliest_enqueue_time = min (item .enqueue_time , earliest_enqueue_time )
169+
170+ time_to_wait = (time .time () - earliest_enqueue_time ) - task_timeout
171+ done , pending = await asyncio .wait (
172+ tasks ,
173+ return_when = asyncio .FIRST_COMPLETED ,
174+ timeout = min (30 , int (time_to_wait )),
175+ )
176+
177+ # Cancel tasks that have taken too long
178+ tasks = []
179+ for task in pending :
180+ now = time .time ()
181+ item = item_by_coro [get_coro (task )]
182+ if task not in cancelled and task_timeout < now - item .enqueue_time :
183+ cancel_task (task , now , item )
184+ cancelled .append (task )
185+ tasks .append (task )
186+
187+ for result in done :
188+ item = item_by_coro [get_coro (result )]
189+
190+ # Flakey tests will be run again if they failed
191+ # TODO: add retry count
192+ if item ._flakey :
193+ try :
194+ result .result ()
195+ except :
196+ item ._flakey = None
197+ new_task = item_to_task (item )
198+ flakes_to_retry .append (new_task )
199+ item_by_coro [new_task ] = item
200+ continue
201+
202+ # We need to change .runtest to a synchronous function for pytest
203+ # however, if it is called again by retry libraries we need to rerun
204+ # the test instead of retuning the previous result
205+ def wrap_in_sync ():
206+ def sync_wrapper ():
207+ new_task = item_to_task (item )
208+
209+ # We use a new thread because we can't block for an async function
210+ # in the same thread as the current running event loop, nor
211+ # we can nest event loops
212+ result = None
213+
214+ def run_in_thread ():
215+ nonlocal result
216+ try :
217+ result = asyncio .run (new_task )
218+ except Exception as e :
219+ result = e
220+
221+ thread = threading .Thread (target = run_in_thread )
222+ thread .start ()
223+ thread .join ()
224+
225+ if isinstance (result , Exception ):
226+ raise result # type: ignore
227+
228+ return result
229+
230+ item .runtest = sync_wrapper
231+
232+ return result .result ()
233+
234+ item .runtest = wrap_in_sync
235+
236+ item .ihook .pytest_runtest_protocol (item = item , nextitem = None )
237+
238+ # Hack: See rewrite comment below
239+ # pytest_runttest_protocl will disable the rewrite assertion
240+ # so we renable it here
241+ activate_assert_rewrite (item )
242+
243+ completed .append (result )
244+
245+ if sidelined_tasks :
246+ if len (tasks ) < max_tasks :
247+ tasks .append (sidelined_tasks .pop (0 ))
248+
249+ return flakes_to_retry
250+
251+
252+ def _run_test_loop (tasks , session , item_by_coro ):
126253 max_tasks = int (
127254 session .config .getoption ("--max-asyncio-tasks" )
128255 or session .config .getini ("max_asyncio_tasks" )
129256 )
130257
131258 loop = asyncio .new_event_loop ()
132259 try :
133- task = run_tests (tasks , int (max_tasks ))
134- loop .run_until_complete (task )
260+ return loop .run_until_complete (
261+ run_tests (tasks , int (max_tasks ), session , item_by_coro )
262+ )
135263 finally :
136264 loop .close ()
137265
@@ -155,8 +283,6 @@ def pytest_runtestloop(session):
155283
156284 session .wrapped_fixtures = {}
157285
158- flakes_to_retry = []
159-
160286 # Collect our coroutines
161287 regular_items = []
162288 item_by_coro = {}
@@ -180,127 +306,6 @@ def pytest_runtestloop(session):
180306 else :
181307 regular_items .append (item )
182308
183- def get_coro (task ):
184- if sys_version_info >= (3 , 8 ):
185- return task .get_coro ()
186- else :
187- return task ._coro
188-
189- async def run_tests (tasks , max_tasks : int ):
190- sidelined_tasks = tasks [max_tasks :]
191- tasks = tasks [:max_tasks ]
192-
193- task_timeout = int (
194- session .config .getoption ("--asyncio-task-timeout" )
195- or session .config .getini ("asyncio_task_timeout" )
196- )
197-
198- completed = []
199- cancelled = []
200- while tasks :
201- # Schedule all the coroutines
202- for i in range (len (tasks )):
203- if asyncio .iscoroutine (tasks [i ]):
204- tasks [i ] = asyncio .create_task (tasks [i ])
205-
206- # Mark when the task was started
207- earliest_enqueue_time = time .time ()
208- for task in tasks :
209- if isinstance (task , asyncio .Task ):
210- item = item_by_coro [get_coro (task )]
211- else :
212- item = item_by_coro [task ]
213- if not hasattr (item , "enqueue_time" ):
214- item .enqueue_time = time .time ()
215- earliest_enqueue_time = min (item .enqueue_time , earliest_enqueue_time )
216-
217- time_to_wait = (time .time () - earliest_enqueue_time ) - task_timeout
218- done , pending = await asyncio .wait (
219- tasks ,
220- return_when = asyncio .FIRST_COMPLETED ,
221- timeout = min (30 , int (time_to_wait )),
222- )
223-
224- # Cancel tasks that have taken too long
225- tasks = []
226- for task in pending :
227- now = time .time ()
228- item = item_by_coro [get_coro (task )]
229- if task not in cancelled and task_timeout < now - item .enqueue_time :
230- if sys_version_info >= (3 , 9 ):
231- msg = "Test took too long ({:.2f} s)" .format (
232- now - item .enqueue_time
233- )
234- task .cancel (msg = msg )
235- else :
236- task .cancel ()
237- cancelled .append (task )
238- tasks .append (task )
239-
240- for result in done :
241- item = item_by_coro [get_coro (result )]
242-
243- # Flakey tests will be run again if they failed
244- # TODO: add retry count
245- if item ._flakey :
246- try :
247- result .result ()
248- except :
249- item ._flakey = None
250- new_task = item_to_task (item )
251- flakes_to_retry .append (new_task )
252- item_by_coro [new_task ] = item
253- continue
254-
255- # We need to change .runtest to a synchronous function for pytest
256- # however, if it is called again by retry libraries we need to rerun
257- # the test instead of retuning the previous result
258- def wrap_in_sync ():
259- def sync_wrapper ():
260- new_task = item_to_task (item )
261-
262- # We use a new thread because we can't block for an async function
263- # in the same thread as the current running event loop, nor
264- # we can nest event loops
265- result = None
266-
267- def run_in_thread ():
268- nonlocal result
269- try :
270- result = asyncio .run (new_task )
271- except Exception as e :
272- result = e
273-
274- thread = threading .Thread (target = run_in_thread )
275- thread .start ()
276- thread .join ()
277-
278- if isinstance (result , Exception ):
279- raise result # type: ignore
280-
281- return result
282-
283- item .runtest = sync_wrapper
284-
285- return result .result ()
286-
287- item .runtest = wrap_in_sync
288-
289- item .ihook .pytest_runtest_protocol (item = item , nextitem = None )
290-
291- # Hack: See rewrite comment below
292- # pytest_runttest_protocl will disable the rewrite assertion
293- # so we renable it here
294- activate_assert_rewrite (item )
295-
296- completed .append (result )
297-
298- if sidelined_tasks :
299- if len (tasks ) < max_tasks :
300- tasks .append (sidelined_tasks .pop (0 ))
301-
302- return completed
303-
304309 # Do assert rewrite
305310 # Hack: pytest's implementation sets up assert rewriting as a shared
306311 # resource. This causes a race condition between async tests. Therefore we
@@ -313,11 +318,11 @@ def run_in_thread():
313318 return
314319
315320 # Run the tests using cooperative multitasking
316- _run_test_loop (tasks , session , run_tests )
321+ flakes_to_retry = _run_test_loop (tasks , session , item_by_coro )
317322
318323 # Run failed flakey tests
319324 if flakes_to_retry :
320- _run_test_loop (flakes_to_retry , session , run_tests )
325+ _run_test_loop (flakes_to_retry , session , item_by_coro )
321326
322327 # Run synchronous tests
323328 session .items = regular_items
0 commit comments