Commit 076aba9

Merge pull request #388 from rapidsai/branch-0.43
Forward-merge branch-0.43 into branch-0.44

2 parents: a94c6bd + c4a41b6

File tree

6 files changed, +147 -73 lines


python/ucxx/ucxx/_lib_async/tests/conftest.py

Lines changed: 9 additions & 1 deletion
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: BSD-3-Clause
 
 import asyncio
@@ -70,12 +70,20 @@ def pytest_pyfunc_call(pyfuncitem: pytest.Function):
     `pytest.mark.rerun_on_failure(reruns)`. This is similar to `pytest-rerunfailures`,
     but that module closes the event loop before this function has awaited, making the
     two incompatible.
+
+    The timeout value is made available to the test functions via `pytestconfig`. This
+    can be used to determine internal timeouts, for example to ensure subprocesses
+    timeout before the test timeout hits and thus prints internal information, such as
+    the call stack. The timeout value may be retrieved by calling
+    `pytestconfig.cache.get("asyncio_timeout", {})["timeout"]`, for that the test must
+    include the `pytestconfig` fixture as argument.
     """
     timeout_marker = pyfuncitem.get_closest_marker("asyncio_timeout")
     slow_marker = pyfuncitem.get_closest_marker("slow")
     rerun_marker = pyfuncitem.get_closest_marker("rerun_on_failure")
     default_timeout = 600.0 if slow_marker else 60.0
     timeout = float(timeout_marker.args[0]) if timeout_marker else default_timeout
+    pyfuncitem.config.cache.set("asyncio_timeout", {"timeout": timeout})
     if timeout <= 0.0:
         raise ValueError("The `pytest.mark.asyncio_timeout` value must be positive.")
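The hook's new docstring and `cache.set` call define a small contract: the hook records the effective timeout, and any test that declares the `pytestconfig` fixture can read it back. For illustration, a minimal sketch of a consuming test, assuming the hook above is active; the test name and the body are illustrative, not part of this commit:

import asyncio

import pytest


@pytest.mark.asyncio_timeout(120)
async def test_with_internal_deadline(pytestconfig):
    # Read the timeout recorded by the pytest_pyfunc_call hook.
    timeout = pytestconfig.cache.get("asyncio_timeout", {})["timeout"]

    # Use a slightly shorter internal deadline so inner work can time out,
    # and print diagnostics, before the outer test timeout cancels everything.
    await asyncio.wait_for(asyncio.sleep(0), timeout=timeout - 5.0)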

python/ucxx/ucxx/_lib_async/tests/test_disconnect.py

Lines changed: 35 additions & 19 deletions
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: BSD-3-Clause
 
 import asyncio
@@ -12,7 +12,10 @@
 
 import ucxx
 from ucxx._lib_async.utils import get_event_loop
-from ucxx._lib_async.utils_test import wait_listener_client_handlers
+from ucxx._lib_async.utils_test import (
+    compute_timeouts,
+    wait_listener_client_handlers,
+)
 from ucxx.testing import terminate_process
 
 mp = mp.get_context("spawn")
@@ -28,7 +31,7 @@ async def mp_queue_get_nowait(queue):
 
 
 def _test_shutdown_unexpected_closed_peer_server(
-    client_queue, server_queue, endpoint_error_handling
+    client_queue, server_queue, endpoint_error_handling, timeout
 ):
     global ep_alive
     ep_alive = None
@@ -61,20 +64,25 @@ async def server_node(ep):
 
     log_stream = StringIO()
     logging.basicConfig(stream=log_stream, level=logging.DEBUG)
-    get_event_loop().run_until_complete(run())
-    log = log_stream.getvalue()
 
-    if endpoint_error_handling is True:
-        assert ep_alive is False
-    else:
-        assert ep_alive
-        assert log.find("""UCXError('<[Send shutdown]""") != -1
+    loop = get_event_loop()
+    try:
+        loop.run_until_complete(asyncio.wait_for(run(), timeout=timeout))
+        log = log_stream.getvalue()
 
-    ucxx.stop_notifier_thread()
+        if endpoint_error_handling is True:
+            assert ep_alive is False
+        else:
+            assert ep_alive
+            assert log.find("""UCXError('<[Send shutdown]""") != -1
+    finally:
+        ucxx.stop_notifier_thread()
+
+    loop.close()
 
 
 def _test_shutdown_unexpected_closed_peer_client(
-    client_queue, server_queue, endpoint_error_handling
+    client_queue, server_queue, endpoint_error_handling, timeout
 ):
     async def run():
         server_port = client_queue.get()
@@ -86,20 +94,25 @@ async def run():
         msg = np.empty(100, dtype=np.int64)
         await ep.recv(msg)
 
-    get_event_loop().run_until_complete(run())
+    loop = get_event_loop()
+    try:
+        loop.run_until_complete(asyncio.wait_for(run(), timeout=timeout))
+    finally:
+        ucxx.stop_notifier_thread()
 
-    ucxx.stop_notifier_thread()
+    loop.close()
 
 
 @pytest.mark.parametrize("endpoint_error_handling", [True, False])
-def test_shutdown_unexpected_closed_peer(caplog, endpoint_error_handling):
+def test_shutdown_unexpected_closed_peer(pytestconfig, caplog, endpoint_error_handling):
     """
     Test clean server shutdown after unexpected peer close
 
     This will causes some UCX warnings to be issued, but this as expected.
     The main goal is to assert that the processes exit without errors
     despite a somewhat messy initial state.
     """
+    async_timeout, join_timeout = compute_timeouts(pytestconfig)
     if endpoint_error_handling is False:
         pytest.xfail(
             "Temporarily xfailing, due to https://github.com/rapidsai/ucxx/issues/21"
@@ -120,17 +133,20 @@ def test_shutdown_unexpected_closed_peer(caplog, endpoint_error_handling):
     server_queue = mp.Queue()
     p1 = mp.Process(
         target=_test_shutdown_unexpected_closed_peer_server,
-        args=(client_queue, server_queue, endpoint_error_handling),
+        args=(client_queue, server_queue, endpoint_error_handling, async_timeout),
     )
     p1.start()
     p2 = mp.Process(
         target=_test_shutdown_unexpected_closed_peer_client,
-        args=(client_queue, server_queue, endpoint_error_handling),
+        args=(client_queue, server_queue, endpoint_error_handling, async_timeout),
    )
    p2.start()
-    p2.join(timeout=30)
+
+    # Increase timeout by an additional 5s to give subprocesses a chance to
+    # timeout before being forcefully terminated.
+    p2.join(timeout=join_timeout)
     server_queue.put("client is down")
-    p1.join(timeout=30)
+    p1.join(timeout=join_timeout)
 
     terminate_process(p2)
     terminate_process(p1)
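The structural change above repeats across every file in this commit: `run()` is wrapped in `asyncio.wait_for` so a stuck subprocess raises `asyncio.TimeoutError` inside itself, where the failure can be reported, rather than hanging until the parent kills it, and `ucxx.stop_notifier_thread()` moves into a `finally` block so it runs even when the test body fails. Reduced to a standalone sketch (the `run` body here is a placeholder, not code from this commit):

import asyncio

import ucxx
from ucxx._lib_async.utils import get_event_loop


def _subprocess_entry(timeout):
    async def run():
        # Placeholder for the real test body (listener/endpoint traffic).
        await asyncio.sleep(0)

    loop = get_event_loop()
    try:
        # Fail inside the subprocess, with a TimeoutError and traceback,
        # instead of hanging past the parent's join deadline.
        loop.run_until_complete(asyncio.wait_for(run(), timeout=timeout))
    finally:
        # Runs on success, timeout, or any other error.
        ucxx.stop_notifier_thread()

    loop.close()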

python/ucxx/ucxx/_lib_async/tests/test_from_worker_address.py

Lines changed: 38 additions & 28 deletions
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: BSD-3-Clause
 
 import asyncio
@@ -11,12 +11,13 @@
 
 import ucxx
 from ucxx._lib_async.utils import get_event_loop, hash64bits
+from ucxx._lib_async.utils_test import compute_timeouts
 from ucxx.testing import join_processes, terminate_process
 
 mp = mp.get_context("spawn")
 
 
-def _test_from_worker_address_server(queue):
+def _test_from_worker_address_server(queue, timeout):
     async def run():
         # Send worker address to client process via multiprocessing.Queue
         address = ucxx.get_worker_address()
@@ -40,14 +41,15 @@ async def run():
         await ep.close()
 
     loop = get_event_loop()
-    loop.run_until_complete(run())
+    try:
+        loop.run_until_complete(asyncio.wait_for(run(), timeout=timeout))
+    finally:
+        ucxx.stop_notifier_thread()
 
-    ucxx.stop_notifier_thread()
+    loop.close()
 
-    loop.close()
 
-
-def _test_from_worker_address_client(queue):
+def _test_from_worker_address_client(queue, timeout):
     async def run():
         # Read local worker address
         address = ucxx.get_worker_address()
@@ -69,29 +71,32 @@ async def run():
         np.testing.assert_array_equal(recv_msg, np.arange(10, dtype=np.int64))
 
     loop = get_event_loop()
-    loop.run_until_complete(run())
+    try:
+        loop.run_until_complete(asyncio.wait_for(run(), timeout=timeout))
+    finally:
+        ucxx.stop_notifier_thread()
 
-    ucxx.stop_notifier_thread()
+    loop.close()
 
-    loop.close()
 
+def test_from_worker_address(pytestconfig):
+    async_timeout, join_timeout = compute_timeouts(pytestconfig)
 
-def test_from_worker_address():
     queue = mp.Queue()
 
     server = mp.Process(
         target=_test_from_worker_address_server,
-        args=(queue,),
+        args=(queue, async_timeout),
    )
    server.start()
 
    client = mp.Process(
        target=_test_from_worker_address_client,
-        args=(queue,),
+        args=(queue, async_timeout),
    )
    client.start()
 
-    join_processes([client, server], timeout=30)
+    join_processes([client, server], timeout=join_timeout)
    terminate_process(client)
    terminate_process(server)
 
@@ -157,7 +162,7 @@ def _unpack_address_and_tag(address_packed):
     }
 
 
-def _test_from_worker_address_server_fixedsize(num_nodes, queue):
+def _test_from_worker_address_server_fixedsize(num_nodes, queue, timeout):
     async def run():
         async def _handle_client(packed_remote_address):
             # Unpack the fixed-size address+tag buffer
@@ -198,14 +203,15 @@ async def _handle_client(packed_remote_address):
         await asyncio.gather(*server_tasks)
 
     loop = get_event_loop()
-    loop.run_until_complete(run())
-
-    ucxx.stop_notifier_thread()
+    try:
+        loop.run_until_complete(asyncio.wait_for(run(), timeout=timeout))
+    finally:
+        ucxx.stop_notifier_thread()
 
-    loop.close()
+    loop.close()
 
 
-def _test_from_worker_address_client_fixedsize(queue):
+def _test_from_worker_address_client_fixedsize(queue, timeout):
     async def run():
         # Read local worker address
         address = ucxx.get_worker_address()
@@ -232,33 +238,37 @@ async def run():
         await ep.send(send_msg, tag=send_tag, force_tag=True)
 
     loop = get_event_loop()
-    loop.run_until_complete(run())
+    try:
+        loop.run_until_complete(asyncio.wait_for(run(), timeout=timeout))
+    finally:
+        ucxx.stop_notifier_thread()
 
-    ucxx.stop_notifier_thread()
-
-    loop.close()
+    loop.close()
 
 
+@pytest.mark.slow
 @pytest.mark.parametrize("num_nodes", [1, 2, 4, 8])
-def test_from_worker_address_multinode(num_nodes):
+def test_from_worker_address_multinode(pytestconfig, num_nodes):
+    async_timeout, join_timeout = compute_timeouts(pytestconfig)
+
     queue = mp.Queue()
 
     server = mp.Process(
         target=_test_from_worker_address_server_fixedsize,
-        args=(num_nodes, queue),
+        args=(num_nodes, queue, async_timeout),
     )
     server.start()
 
     clients = []
     for i in range(num_nodes):
         client = mp.Process(
             target=_test_from_worker_address_client_fixedsize,
-            args=(queue,),
+            args=(queue, async_timeout),
         )
         client.start()
         clients.append(client)
 
-    join_processes(clients + [server], timeout=30)
+    join_processes(clients + [server], timeout=join_timeout)
     for client in clients:
         terminate_process(client)
     terminate_process(server)
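`compute_timeouts` comes from `ucxx._lib_async.utils_test` and its body is not part of this diff. Judging from the conftest.py change and the "+5s" comment in test_disconnect.py, it plausibly derives both values from the cached test timeout, roughly like this (an assumption, not the actual implementation):

def compute_timeouts(pytestconfig):
    # Timeout stored by the pytest_pyfunc_call hook in conftest.py.
    async_timeout = pytestconfig.cache.get("asyncio_timeout", {})["timeout"]

    # Join a bit later than the async deadline so children time out first
    # and report their state before being forcefully terminated.
    join_timeout = async_timeout + 5.0
    return async_timeout, join_timeout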

python/ucxx/ucxx/_lib_async/tests/test_from_worker_address_error.py

Lines changed: 20 additions & 15 deletions
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: BSD-3-Clause
 
 import asyncio
@@ -12,12 +12,13 @@
 
 import ucxx
 from ucxx._lib_async.utils import get_event_loop
+from ucxx._lib_async.utils_test import compute_timeouts
 from ucxx.testing import join_processes, terminate_process
 
 mp = mp.get_context("spawn")
 
 
-def _test_from_worker_address_error_server(q1, q2, error_type):
+def _test_from_worker_address_error_server(q1, q2, error_type, timeout):
     async def run():
         address = bytearray(ucxx.get_worker_address())
 
@@ -39,14 +40,15 @@ async def run():
         # q1.put("disconnected")
 
     loop = get_event_loop()
-    loop.run_until_complete(run())
-
-    ucxx.stop_notifier_thread()
+    try:
+        loop.run_until_complete(asyncio.wait_for(run(), timeout=timeout))
+    finally:
+        ucxx.stop_notifier_thread()
 
-    loop.close()
+    loop.close()
 
 
-def _test_from_worker_address_error_client(q1, q2, error_type):
+def _test_from_worker_address_error_client(q1, q2, error_type, timeout):
     async def run():
         # Receive worker address from server via multiprocessing.Queue
         remote_address = ucxx.get_ucx_address_from_buffer(q1.get())
@@ -138,11 +140,12 @@ async def run():
         await task
 
     loop = get_event_loop()
-    loop.run_until_complete(run())
-
-    ucxx.stop_notifier_thread()
+    try:
+        loop.run_until_complete(asyncio.wait_for(run(), timeout=timeout))
+    finally:
+        ucxx.stop_notifier_thread()
 
-    loop.close()
+    loop.close()
 
 
 @pytest.mark.parametrize(
@@ -164,27 +167,29 @@ async def run():
         "UCX_UD_TIMEOUT": "100ms",
     },
 )
-def test_from_worker_address_error(error_type):
+def test_from_worker_address_error(pytestconfig, error_type):
+    async_timeout, join_timeout = compute_timeouts(pytestconfig)
+
     q1 = mp.Queue()
     q2 = mp.Queue()
 
     server = mp.Process(
         target=_test_from_worker_address_error_server,
-        args=(q1, q2, error_type),
+        args=(q1, q2, error_type, async_timeout),
     )
     server.start()
 
     client = mp.Process(
         target=_test_from_worker_address_error_client,
-        args=(q1, q2, error_type),
+        args=(q1, q2, error_type, async_timeout),
     )
     client.start()
 
     if error_type == "unreachable":
         server.join()
         q1.put("Server closed")
 
-    join_processes([client, server], timeout=30)
+    join_processes([client, server], timeout=join_timeout)
     terminate_process(server)
     try:
         terminate_process(client)
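Taken together, the parent-side pattern in these tests is: start the children with the async timeout, join them with the longer join timeout, then reap them. Condensed into one hypothetical driver, where `_server` and `_client` stand in for the per-test worker functions from this commit:

import multiprocessing

from ucxx.testing import join_processes, terminate_process

mp = multiprocessing.get_context("spawn")


def run_pair(_server, _client, async_timeout, join_timeout):
    queue = mp.Queue()
    server = mp.Process(target=_server, args=(queue, async_timeout))
    client = mp.Process(target=_client, args=(queue, async_timeout))
    server.start()
    client.start()

    # join_timeout exceeds async_timeout, so a hung child hits its own
    # asyncio deadline (and logs why) before the parent stops waiting.
    join_processes([client, server], timeout=join_timeout)
    terminate_process(client)
    terminate_process(server)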
