Skip to content

Commit

Permalink
Merge branch 'branch-0.41' into python-async-blocking-mode
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev authored Oct 4, 2024
2 parents 5769f31 + c5f5583 commit 4f7a7f2
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 18 deletions.
1 change: 1 addition & 0 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ async def connect(
ucxx.exceptions.UCXCloseError,
ucxx.exceptions.UCXCanceledError,
ucxx.exceptions.UCXConnectionResetError,
ucxx.exceptions.UCXMessageTruncatedError,
ucxx.exceptions.UCXNotConnectedError,
ucxx.exceptions.UCXUnreachableError,
):
Expand Down
4 changes: 2 additions & 2 deletions python/libucxx/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ readme = { file = "README.md", content-type = "text/markdown" }
authors = [
{ name = "NVIDIA Corporation" },
]
license = { text = "Apache 2.0" }
license = { text = "BSD-3-Clause" }
classifiers = [
"Intended Audience :: Developers",
"Intended Audience :: System Administrators",
Expand Down Expand Up @@ -64,4 +64,4 @@ requires = [
"librmm==24.12.*,>=0.0.0a0",
"libucx==1.15.0",
"ninja",
] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
25 changes: 18 additions & 7 deletions python/ucxx/ucxx/_lib/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,13 @@ def _listener_handler(conn_request):
while ep[0] is None:
worker.progress()

wireup_msg = Array(bytearray(WireupMessageSize))
wireup_request = ep[0].tag_recv(wireup_msg, tag=ucx_api.UCXXTag(0))
wait_requests(worker, "blocking", wireup_request)
wireup_msg_recv = Array(bytearray(WireupMessageSize))
wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize)))
wireup_requests = [
ep[0].tag_recv(wireup_msg_recv, tag=ucx_api.UCXXTag(0)),
ep[0].tag_send(wireup_msg_send, tag=ucx_api.UCXXTag(0)),
]
wait_requests(worker, "blocking", wireup_requests)

if server_close_callback is True:
while closed[0] is False:
Expand All @@ -72,13 +76,20 @@ def _client(port, server_close_callback):
port,
endpoint_error_handling=True,
)
worker.progress()
wireup_msg = Array(bytes(os.urandom(WireupMessageSize)))
wireup_request = ep.tag_send(wireup_msg, tag=ucx_api.UCXXTag(0))
wait_requests(worker, "blocking", wireup_request)
if server_close_callback is False:
closed = [False]
ep.set_close_callback(_close_callback, cb_args=(closed,))
worker.progress()

wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize)))
wireup_msg_recv = Array(bytearray(WireupMessageSize))
wireup_requests = [
ep.tag_send(wireup_msg_send, tag=ucx_api.UCXXTag(0)),
ep.tag_recv(wireup_msg_recv, tag=ucx_api.UCXXTag(0)),
]
wait_requests(worker, "blocking", wireup_requests)

if server_close_callback is False:
while closed[0] is False:
worker.progress()

Expand Down
24 changes: 17 additions & 7 deletions python/ucxx/ucxx/_lib/tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _echo_server(get_queue, put_queue, transfer_api, msg_size, progress_mode):
we keep a reference to the listener's endpoint and execute transfers
outside of the callback function.
"""
# TAG is always used for wireup
feature_flags = [ucx_api.Feature.WAKEUP]
if transfer_api == "am":
feature_flags.append(ucx_api.Feature.AM)
Expand Down Expand Up @@ -75,9 +76,13 @@ def _listener_handler(conn_request):
if progress_mode == "blocking":
worker.progress()

wireup_msg = Array(bytearray(WireupMessageSize))
wireup_request = _recv(ep[0], transfer_api, wireup_msg)
wait_requests(worker, progress_mode, wireup_request)
wireup_msg_recv = Array(bytearray(WireupMessageSize))
wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize)))
wireup_requests = [
_recv(ep[0], transfer_api, wireup_msg_recv),
_send(ep[0], transfer_api, wireup_msg_send),
]
wait_requests(worker, progress_mode, wireup_requests)

msg = Array(bytearray(msg_size))

Expand Down Expand Up @@ -110,10 +115,11 @@ def _listener_handler(conn_request):


def _echo_client(transfer_api, msg_size, progress_mode, port):
# TAG is always used for wireup
feature_flags = [ucx_api.Feature.WAKEUP]
if transfer_api == "am":
feature_flags.append(ucx_api.Feature.AM)
if transfer_api == "stream":
elif transfer_api == "stream":
feature_flags.append(ucx_api.Feature.STREAM)
else:
feature_flags.append(ucx_api.Feature.TAG)
Expand All @@ -136,9 +142,13 @@ def _echo_client(transfer_api, msg_size, progress_mode, port):
if progress_mode == "blocking":
worker.progress()

wireup_msg = Array(bytes(os.urandom(WireupMessageSize)))
wireup_request = _send(ep, transfer_api, wireup_msg)
wait_requests(worker, progress_mode, wireup_request)
wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize)))
wireup_msg_recv = Array(bytearray(WireupMessageSize))
wireup_requests = [
_send(ep, transfer_api, wireup_msg_send),
_recv(ep, transfer_api, wireup_msg_recv),
]
wait_requests(worker, progress_mode, wireup_requests)

send_msg = bytes(os.urandom(msg_size))
recv_msg = bytearray(msg_size)
Expand Down
7 changes: 5 additions & 2 deletions python/ucxx/ucxx/_lib_async/application_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,14 @@ async def create_endpoint(
listener=False,
stream_timeout=exchange_peer_info_timeout,
)
except UCXMessageTruncatedError:
except UCXMessageTruncatedError as e:
# A truncated message occurs if the remote endpoint closed before
# exchanging peer info, in that case we should raise the endpoint
# error instead.
# error, if available.
ucx_ep.raise_on_error()
# If no endpoint error is available, re-raise exception.
raise e

tags = {
"msg_send": peer_info["msg_tag"],
"msg_recv": msg_tag,
Expand Down

0 comments on commit 4f7a7f2

Please sign in to comment.