From 057acfc3b5793cffc0ea4dd50983832ab9a3f3de Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 4 Oct 2024 02:56:11 -0700 Subject: [PATCH] Replace resource hashes with unique IDs. --- .../distributed-ucxx/distributed_ucxx/ucxx.py | 68 +++++++++++++------ 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py index a9b99f3b..8627dec7 100644 --- a/python/distributed-ucxx/distributed_ucxx/ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py @@ -8,6 +8,7 @@ from __future__ import annotations import functools +import itertools import logging import os import struct @@ -90,26 +91,51 @@ def synchronize_stream(stream=0): stream.synchronize() -def _register_dask_resource(resource): +def make_register(): """Register a Dask resource with the UCXX context. - Register a Dask resource with the UCXX context to keep track of it, so that - the notifier thread and progress tasks may be stopped when no more resources + Register a Dask resource with the UCXX context and keep track of it with the + use of a unique ID for the resource. The resource ID is later used to + deregister the resource from the UCXX context calling + `_deregister_dask_resource(resource_id)`, which stops the notifier thread + and progress tasks when no more UCXX resources are alive. + + Returns + ------- + resource_id: int + The ID of the registered resource that should be used with + `_deregister_dask_resource` during stop/destruction of the resource. """ - ctx = ucxx.core._get_ctx() - with ctx._dask_resources_lock: - ctx._dask_resources.add(resource) - ctx.start_notifier_thread() - ctx.continuous_ucx_progress() + count = itertools.count() + + def register() -> int: + ctx = ucxx.core._get_ctx() + with ctx._dask_resources_lock: + resource_id = next(count) + ctx._dask_resources.add(resource_id) + ctx.start_notifier_thread() + ctx.continuous_ucx_progress() + return resource_id + + return register + + +_register_dask_resource = make_register() -def _deregister_dask_resource(resource): +def _deregister_dask_resource(resource_id): """Deregister a Dask resource with the UCXX context. - Deregister a Dask resource from the UCXX context, and if no resources remain - after deregistration, stop the notifier thread and progress tasks. - need UCXX. + Deregister a Dask resource from the UCXX context with given ID, and if no + resources remain after deregistration, stop the notifier thread and progress + tasks. + + Parameters + ---------- + resource_id: int + The unique ID of the resource returned by `_register_dask_resource` upon + registration. """ if ucxx.core._ctx is None: # Prevent creation of context if it was already destroyed, all @@ -123,7 +149,7 @@ def _deregister_dask_resource(resource): if hasattr(ctx, "_dask_resources_lock"): with ctx._dask_resources_lock: try: - ctx._dask_resources.remove(resource) + ctx._dask_resources.remove(resource_id) except KeyError: pass @@ -344,11 +370,11 @@ def __init__( # type: ignore[no-untyped-def] else: self._has_close_callback = False - _register_dask_resource(hash(self)) + self._resource_id = _register_dask_resource() logger.debug("UCX.__init__ %s", self) - weakref.finalize(self, _deregister_dask_resource, hash(self)) + weakref.finalize(self, _deregister_dask_resource, self._resource_id) def __del__(self) -> None: self.abort() @@ -557,7 +583,7 @@ def abort(self): if self._ep is not None: self._ep.abort() self._ep = None - _deregister_dask_resource(hash(self)) + _deregister_dask_resource(self._resource_id) def closed(self): if self._has_close_callback is True: @@ -592,7 +618,7 @@ async def connect( init_once() try: - _register_dask_resource(hash(self)) + self._resource_id = _register_dask_resource() ep = await ucxx.create_endpoint(ip, port) except ( ucxx.exceptions.UCXCloseError, @@ -604,7 +630,7 @@ async def connect( ): raise CommClosedError("Connection closed before handshake completed") finally: - _deregister_dask_resource(hash(self)) + _deregister_dask_resource(self._resource_id) return self.comm_class( ep, local_addr="", @@ -662,13 +688,13 @@ async def serve_forever(client_ep): await self.comm_handler(ucx) init_once() - _register_dask_resource(hash(self)) - weakref.finalize(self, _deregister_dask_resource, hash(self)) + self._resource_id = _register_dask_resource() + weakref.finalize(self, _deregister_dask_resource, self._resource_id) self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port) def stop(self): self.ucxx_server = None - _deregister_dask_resource(hash(self)) + _deregister_dask_resource(self._resource_id) def get_host_port(self): # TODO: TCP raises if this hasn't started yet.