Skip to content

Commit

Permalink
Replace resource hashes with unique IDs.
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Oct 4, 2024
1 parent c674649 commit 057acfc
Showing 1 changed file with 47 additions and 21 deletions.
68 changes: 47 additions & 21 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import functools
import itertools
import logging
import os
import struct
Expand Down Expand Up @@ -90,26 +91,51 @@ def synchronize_stream(stream=0):
stream.synchronize()


def _register_dask_resource(resource):
def make_register():
"""Register a Dask resource with the UCXX context.
Register a Dask resource with the UCXX context to keep track of it, so that
the notifier thread and progress tasks may be stopped when no more resources
Register a Dask resource with the UCXX context and keep track of it with the
use of a unique ID for the resource. The resource ID is later used to
deregister the resource from the UCXX context calling
`_deregister_dask_resource(resource_id)`, which stops the notifier thread
and progress tasks when no more UCXX resources are alive.
Returns
-------
resource_id: int
The ID of the registered resource that should be used with
`_deregister_dask_resource` during stop/destruction of the resource.
"""
ctx = ucxx.core._get_ctx()

with ctx._dask_resources_lock:
ctx._dask_resources.add(resource)
ctx.start_notifier_thread()
ctx.continuous_ucx_progress()
count = itertools.count()

def register() -> int:
ctx = ucxx.core._get_ctx()
with ctx._dask_resources_lock:
resource_id = next(count)
ctx._dask_resources.add(resource_id)
ctx.start_notifier_thread()
ctx.continuous_ucx_progress()
return resource_id

return register


_register_dask_resource = make_register()


def _deregister_dask_resource(resource):
def _deregister_dask_resource(resource_id):
"""Deregister a Dask resource with the UCXX context.
Deregister a Dask resource from the UCXX context, and if no resources remain
after deregistration, stop the notifier thread and progress tasks.
need UCXX.
Deregister a Dask resource from the UCXX context with given ID, and if no
resources remain after deregistration, stop the notifier thread and progress
tasks.
Parameters
----------
resource_id: int
The unique ID of the resource returned by `_register_dask_resource` upon
registration.
"""
if ucxx.core._ctx is None:
# Prevent creation of context if it was already destroyed, all
Expand All @@ -123,7 +149,7 @@ def _deregister_dask_resource(resource):
if hasattr(ctx, "_dask_resources_lock"):
with ctx._dask_resources_lock:
try:
ctx._dask_resources.remove(resource)
ctx._dask_resources.remove(resource_id)
except KeyError:
pass

Expand Down Expand Up @@ -344,11 +370,11 @@ def __init__( # type: ignore[no-untyped-def]
else:
self._has_close_callback = False

_register_dask_resource(hash(self))
self._resource_id = _register_dask_resource()

logger.debug("UCX.__init__ %s", self)

weakref.finalize(self, _deregister_dask_resource, hash(self))
weakref.finalize(self, _deregister_dask_resource, self._resource_id)

def __del__(self) -> None:
self.abort()
Expand Down Expand Up @@ -557,7 +583,7 @@ def abort(self):
if self._ep is not None:
self._ep.abort()
self._ep = None
_deregister_dask_resource(hash(self))
_deregister_dask_resource(self._resource_id)

def closed(self):
if self._has_close_callback is True:
Expand Down Expand Up @@ -592,7 +618,7 @@ async def connect(
init_once()

try:
_register_dask_resource(hash(self))
self._resource_id = _register_dask_resource()
ep = await ucxx.create_endpoint(ip, port)
except (
ucxx.exceptions.UCXCloseError,
Expand All @@ -604,7 +630,7 @@ async def connect(
):
raise CommClosedError("Connection closed before handshake completed")
finally:
_deregister_dask_resource(hash(self))
_deregister_dask_resource(self._resource_id)
return self.comm_class(
ep,
local_addr="",
Expand Down Expand Up @@ -662,13 +688,13 @@ async def serve_forever(client_ep):
await self.comm_handler(ucx)

init_once()
_register_dask_resource(hash(self))
weakref.finalize(self, _deregister_dask_resource, hash(self))
self._resource_id = _register_dask_resource()
weakref.finalize(self, _deregister_dask_resource, self._resource_id)
self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port)

def stop(self):
self.ucxx_server = None
_deregister_dask_resource(hash(self))
_deregister_dask_resource(self._resource_id)

def get_host_port(self):
# TODO: TCP raises if this hasn't started yet.
Expand Down

0 comments on commit 057acfc

Please sign in to comment.