Skip to content

Commit 4cd6354

Browse files
authored
Remove UCXAddress.string in favor of __bytes__ (#390)
Currently the Python core API has a `UCXAddress.string` property that returns `bytes`. This is a confusing match, so instead of having that property this now switches to supporting `__bytes__` that can use proper Python API to return the serializable `bytes` object by calling `bytes(ucx_address_object)`. Also removes the `create_from_string` method, as there were two methods that are essentially redundant and might cause more confusion than good. Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Mads R. B. Kristensen (https://github.com/madsbk) - https://github.com/jakirkham URL: #390
1 parent 0d29ad8 commit 4cd6354

File tree

2 files changed

+15
-29
lines changed

2 files changed

+15
-29
lines changed

python/ucxx/ucxx/_lib/libucxx.pyx

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -405,29 +405,18 @@ cdef class UCXAddress():
405405
return address
406406

407407
@classmethod
408-
def create_from_string(cls, string address_str) -> UCXAddress:
408+
def create_from_buffer(cls, bytes buf) -> UCXAddress:
409409
cdef UCXAddress address = UCXAddress.__new__(UCXAddress)
410-
cdef string cpp_address_str = address_str
410+
cdef string address_str = string(<const char*>buf, len(buf))
411411

412412
with nogil:
413-
address._address = createAddressFromString(cpp_address_str)
413+
address._address = createAddressFromString(address_str)
414414
address._handle = address._address.get().getHandle()
415415
address._length = address._address.get().getLength()
416416
address._string = address._address.get().getString()
417417

418418
return address
419419

420-
@classmethod
421-
def create_from_buffer(cls, bytes buffer) -> UCXAddress:
422-
cdef string address_str
423-
424-
buf = Array(buffer)
425-
assert buf.c_contiguous
426-
427-
address_str = string(<char*>buf.ptr, <size_t>buf.nbytes)
428-
429-
return UCXAddress.create_from_string(address_str)
430-
431420
# For old UCX-Py API compatibility
432421
@classmethod
433422
def from_worker(cls, UCXWorker worker) -> UCXAddress:
@@ -460,8 +449,7 @@ cdef class UCXAddress():
460449
def length(self) -> int:
461450
return int(self._length)
462451

463-
@property
464-
def string(self) -> bytes:
452+
def __bytes__(self) -> bytes:
465453
return bytes(self._string)
466454

467455
def __getbuffer__(self, Py_buffer *buffer, int flags) -> None:
@@ -489,10 +477,10 @@ cdef class UCXAddress():
489477
pass
490478

491479
def __reduce__(self) -> tuple:
492-
return (UCXAddress.create_from_buffer, (self.string,))
480+
return (UCXAddress.create_from_buffer, (bytes(self),))
493481

494482
def __hash__(self) -> int:
495-
return hash(bytes(self.string))
483+
return hash(bytes(self))
496484

497485

498486
cdef void _generic_callback(void *args) with gil:
Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
1+
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import multiprocessing as mp
@@ -9,28 +9,26 @@
99
mp = mp.get_context("spawn")
1010

1111

12-
def test_ucx_address_string():
12+
def test_ucx_address_bytes():
1313
ctx = ucx_api.UCXContext()
1414
worker = ucx_api.UCXWorker(ctx)
1515
org_address = worker.address
16-
org_address_str = org_address.string
17-
new_address = ucx_api.UCXAddress.create_from_string(org_address_str)
18-
new_address_str = new_address.string
16+
org_address_bytes = bytes(org_address)
17+
new_address = ucx_api.UCXAddress.create_from_buffer(org_address_bytes)
18+
new_address_bytes = bytes(new_address)
1919
assert hash(org_address) == hash(new_address)
20-
assert bytes(org_address_str) == bytes(new_address_str)
20+
assert org_address_bytes == new_address_bytes
2121

2222

2323
def test_pickle_ucx_address():
2424
ctx = ucx_api.UCXContext()
2525
worker = ucx_api.UCXWorker(ctx)
2626
org_address = worker.address
27-
org_address_str = org_address.string
27+
org_address_bytes = bytes(org_address)
2828
org_address_hash = hash(org_address)
2929
dumped_address = pickle.dumps(org_address)
30-
org_address = bytes(org_address)
3130
new_address = pickle.loads(dumped_address)
32-
new_address_str = new_address.string
31+
new_address_bytes = bytes(new_address)
3332

3433
assert org_address_hash == hash(new_address)
35-
assert bytes(org_address_str) == bytes(new_address_str)
36-
assert bytes(org_address) == bytes(new_address)
34+
assert org_address_bytes == new_address_bytes

0 commit comments

Comments
 (0)