diff --git a/ucp/_libs/ucx_api.pyx b/ucp/_libs/ucx_api.pyx index c8f896774..860e45f39 100644 --- a/ucp/_libs/ucx_api.pyx +++ b/ucp/_libs/ucx_api.pyx @@ -189,7 +189,7 @@ cdef class UCXContext(UCXObject): ucp_context_h _handle dict _config - def __init__(self, config_dict): + def __init__(self, config_dict, blocking_progress_mode=False): cdef ucp_params_t ucp_params cdef ucp_worker_params_t worker_params cdef ucs_status_t status @@ -199,11 +199,17 @@ cdef class UCXContext(UCXObject): UCP_PARAM_FIELD_REQUEST_SIZE | # noqa UCP_PARAM_FIELD_REQUEST_INIT) - # We always request UCP_FEATURE_WAKEUP even when in blocking mode - # See - ucp_params.features = (UCP_FEATURE_TAG | # noqa - UCP_FEATURE_WAKEUP | # noqa - UCP_FEATURE_STREAM) + # We only enable UCP_FEATURE_WAKEUP on blocking mode. This is + # required for shared memory, which is currently only supported + # by non-blocking mode, and that doesn't implement UCP_FEATURE_WAKEUP. + # See + if blocking_progress_mode is True: + ucp_params.features = (UCP_FEATURE_TAG | # noqa + UCP_FEATURE_WAKEUP | # noqa + UCP_FEATURE_STREAM) + else: + ucp_params.features = (UCP_FEATURE_TAG | # noqa + UCP_FEATURE_STREAM) ucp_params.request_size = sizeof(ucx_py_request) ucp_params.request_init = ( diff --git a/ucp/core.py b/ucp/core.py index e29c65c0a..1f0730190 100644 --- a/ucp/core.py +++ b/ucp/core.py @@ -229,10 +229,6 @@ class ApplicationContext: def __init__(self, config_dict={}, blocking_progress_mode=None): self.progress_tasks = [] - # For now, a application context only has one worker - self.context = ucx_api.UCXContext(config_dict) - self.worker = ucx_api.UCXWorker(self.context) - if blocking_progress_mode is not None: self.blocking_progress_mode = blocking_progress_mode elif "UCXPY_NON_BLOCKING_MODE" in os.environ: @@ -240,6 +236,12 @@ def __init__(self, config_dict={}, blocking_progress_mode=None): else: self.blocking_progress_mode = True + # For now, a application context only has one worker + self.context = ucx_api.UCXContext( + config_dict, blocking_progress_mode=self.blocking_progress_mode + ) + self.worker = ucx_api.UCXWorker(self.context) + if self.blocking_progress_mode: self.epoll_fd = self.worker.init_blocking_progress_mode() weakref.finalize(