Skip to content

Commit

Permalink
Merge branch 'branch-0.41' into sync_some_kvikio_arr_changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham authored Oct 4, 2024
2 parents 261f24e + 1f4e508 commit 575efac
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 84 deletions.
82 changes: 62 additions & 20 deletions cpp/python/src/future.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ PyObject* create_future_str = NULL;
PyObject* future_str = NULL;
PyObject* set_exception_str = NULL;
PyObject* set_result_str = NULL;
PyObject* done_str = NULL;
PyObject* cancelled_str = NULL;

static int intern_strings(void)
{
Expand All @@ -32,6 +34,10 @@ static int intern_strings(void)
if (set_exception_str == NULL) { return -1; }
set_result_str = PyUnicode_InternFromString("set_result");
if (set_result_str == NULL) { return -1; }
done_str = PyUnicode_InternFromString("done");
if (done_str == NULL) { return -1; }
cancelled_str = PyUnicode_InternFromString("cancelled");
if (cancelled_str == NULL) { return -1; }
return 0;
}

Expand Down Expand Up @@ -60,18 +66,12 @@ static PyObject* get_asyncio_future_object()
}

asyncio_module = PyImport_Import(asyncio_str);
if (PyErr_Occurred()) {
ucxx_trace_req("ucxx::python::%s, error importing asyncio", __func__);
PyErr_Print();
}
if (PyErr_Occurred()) PyErr_Print();
if (PyErr_Occurred()) ucxx_error("ucxx::python::%s, error importing asyncio", __func__);
if (asyncio_module == NULL) goto finish;

asyncio_future_object = PyObject_GetAttr(asyncio_module, future_str);
if (PyErr_Occurred()) {
ucxx_trace_req("ucxx::python::%s, error getting asyncio.Future method", __func__);
PyErr_Print();
}
if (PyErr_Occurred())
ucxx_error("ucxx::python::%s, error getting asyncio.Future method", __func__);
Py_DECREF(asyncio_module);
if (asyncio_future_object == NULL) { goto finish; }

Expand Down Expand Up @@ -103,13 +103,39 @@ PyObject* create_python_future()
}

result = PyObject_CallFunctionObjArgs(future_object, NULL);
if (PyErr_Occurred()) ucxx_error("ucxx::python::%s, error creating asyncio.Future", __func__);

finish:
PyGILState_Release(state);
return result;
}

PyObject* check_future_state(PyObject* future)
{
PyObject* result = NULL;

PyGILState_STATE state = PyGILState_Ensure();

result = PyObject_CallMethodNoArgs(future, cancelled_str);
if (PyErr_Occurred()) {
ucxx_trace_req("ucxx::python::%s, error creating asyncio.Future", __func__);
PyErr_Print();
ucxx_error("ucxx::python::%s, error calling `cancelled()` from `asyncio.Future` object",
__func__);
} else if (PyObject_IsTrue(result)) {
ucxx_trace_req("ucxx::python::%s, `asyncio.Future` object has been cancelled.", __func__);
goto finish;
}

result = PyObject_CallMethodNoArgs(future, done_str);
if (PyErr_Occurred()) {
ucxx_error("ucxx::python::%s, error calling `done()` from `asyncio.Future` object", __func__);
} else if (PyObject_IsTrue(result)) {
ucxx_trace_req("ucxx::python::%s, `asyncio.Future` object is already done.", __func__);
goto finish;
}

finish:
PyGILState_Release(state);

return result;
}

Expand All @@ -119,13 +145,22 @@ PyObject* future_set_result(PyObject* future, PyObject* value)

PyGILState_STATE state = PyGILState_Ensure();

if (PyObject_IsTrue(check_future_state(future))) {
ucxx_trace_req(
"ucxx::python::%s, `asyncio.Future` object is already done or has been cancelled, "
"skipping `set_result()`.",
__func__);
goto finish;
}

result = PyObject_CallMethodOneArg(future, set_result_str, value);
if (PyErr_Occurred()) {
ucxx_trace_req("ucxx::python::%s, error calling `set_result()` from `asyncio.Future` object",
__func__);
ucxx_error("ucxx::python::%s, error calling `set_result()` from `asyncio.Future` object",
__func__);
PyErr_Print();
}

finish:
PyGILState_Release(state);

return result;
Expand All @@ -140,6 +175,14 @@ PyObject* future_set_exception(PyObject* future, PyObject* exception, const char

PyGILState_STATE state = PyGILState_Ensure();

if (PyObject_IsTrue(check_future_state(future))) {
ucxx_trace_req(
"ucxx::python::%s, `asyncio.Future` object is already done or has been cancelled, "
"skipping `set_exception()`.",
__func__);
goto finish;
}

message_object = PyUnicode_FromString(message);
if (message_object == NULL) goto err;
message_tuple = PyTuple_Pack(1, message_object);
Expand Down Expand Up @@ -174,8 +217,7 @@ PyObject* create_python_future_with_event_loop(PyObject* event_loop)

result = PyObject_CallMethodObjArgs(event_loop, create_future_str, NULL);
if (PyErr_Occurred()) {
ucxx_trace_req("ucxx::python::%s, error calling `create_future` from event loop object",
__func__);
ucxx_error("ucxx::python::%s, error calling `create_future` from event loop object", __func__);
PyErr_Print();
}

Expand All @@ -198,8 +240,8 @@ PyObject* future_set_result_with_event_loop(PyObject* event_loop, PyObject* futu

set_result_callable = PyObject_GetAttr(future, set_result_str);
if (PyErr_Occurred()) {
ucxx_trace_req(
"ucxx::python::%s, error getting `set_result` method from `asyncio.Future` object", __func__);
ucxx_error("ucxx::python::%s, error getting `set_result` method from `asyncio.Future` object",
__func__);
PyErr_Print();
goto finish;
}
Expand All @@ -214,7 +256,7 @@ PyObject* future_set_result_with_event_loop(PyObject* event_loop, PyObject* futu
result = PyObject_CallMethodObjArgs(
event_loop, call_soon_threadsafe_str, set_result_callable, value, NULL);
if (PyErr_Occurred()) {
ucxx_trace_req(
ucxx_error(
"ucxx::python::%s, error calling `call_soon_threadsafe` from event loop object to set future "
"result",
__func__);
Expand Down Expand Up @@ -247,7 +289,7 @@ PyObject* future_set_exception_with_event_loop(PyObject* event_loop,

set_exception_callable = PyObject_GetAttr(future, set_exception_str);
if (PyErr_Occurred()) {
ucxx_trace_req(
ucxx_error(
"ucxx::python::%s, Error getting `set_exception` method from `asyncio.Future` object",
__func__);
PyErr_Print();
Expand All @@ -271,7 +313,7 @@ PyObject* future_set_exception_with_event_loop(PyObject* event_loop,
result = PyObject_CallMethodObjArgs(
event_loop, call_soon_threadsafe_str, set_exception_callable, formed_exception, NULL);
if (PyErr_Occurred()) {
ucxx_trace_req(
ucxx_error(
"ucxx::python::%s, Error calling `call_soon_threadsafe` from event loop object to set future "
"exception",
__func__);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/utils/callback_notifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ bool CallbackNotifier::wait(uint64_t period,
bool ret = false;
for (size_t i = 0; i < attempts; ++i) {
ret = _conditionVariable.wait_for(
lock, std::chrono::duration<uint64_t, std::nano>(period), [this]() {
lock, std::chrono::duration<uint64_t, std::nano>(signalInterval), [this]() {
return _flag.load(std::memory_order_relaxed) == true;
});
if (signalWorkerFunction) signalWorkerFunction();
Expand Down
43 changes: 43 additions & 0 deletions cpp/tests/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,49 @@ TEST_P(RequestTest, TagUserCallback)

for (const auto request : requests)
ASSERT_THAT(request->getStatus(), UCS_OK);
for (const auto status : requestStatus)
ASSERT_THAT(status, UCS_OK);

// Assert data correctness
ASSERT_THAT(_recv[0], ContainerEq(_send[0]));
}

TEST_P(RequestTest, TagUserCallbackDiscardReturn)
{
allocate();

std::vector<ucs_status_t> requestStatus(2, UCS_INPROGRESS);

auto checkStatus = [&requestStatus](ucs_status_t status, ::ucxx::RequestCallbackUserData data) {
auto idx = *std::static_pointer_cast<size_t>(data);
requestStatus[idx] = status;
};

auto checkCompletion = [&requestStatus, this]() {
std::vector<size_t> completed(2, 0);
while (std::accumulate(completed.begin(), completed.end(), 0) != 2) {
_progressWorker();
std::transform(
requestStatus.begin(), requestStatus.end(), completed.begin(), [](ucs_status_t status) {
return status == UCS_INPROGRESS ? 0 : 1;
});
}
};

auto sendIndex = std::make_shared<size_t>(0u);
auto recvIndex = std::make_shared<size_t>(1u);

// Submit and wait for transfers to complete
std::ignore =
_ep->tagSend(_sendPtr[0], _messageSize, ucxx::Tag{0}, false, checkStatus, sendIndex);
std::ignore = _ep->tagRecv(
_recvPtr[0], _messageSize, ucxx::Tag{0}, ucxx::TagMaskFull, false, checkStatus, recvIndex);
checkCompletion();

copyResults();

for (const auto status : requestStatus)
ASSERT_THAT(status, UCS_OK);

// Assert data correctness
ASSERT_THAT(_recv[0], ContainerEq(_send[0]));
Expand Down
1 change: 0 additions & 1 deletion python/distributed-ucxx/distributed_ucxx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .ucxx import UCXXBackend, UCXXConnector, UCXXListener # noqa: F401
from . import distributed_patches # noqa: F401


from ._version import __git_commit__, __version__
41 changes: 0 additions & 41 deletions python/distributed-ucxx/distributed_ucxx/distributed_patches.py

This file was deleted.

Loading

0 comments on commit 575efac

Please sign in to comment.