diff --git a/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp b/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp index c7c8d9178..26abbf747 100644 --- a/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp +++ b/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp @@ -32,6 +32,8 @@ namespace rmm::mr::detail { +using expand_callback_t = std::function; + /** * @brief A CRTP helper function * @@ -81,6 +83,12 @@ class stream_ordered_memory_resource : public crtp, public device_ stream_ordered_memory_resource& operator=(stream_ordered_memory_resource const&) = delete; stream_ordered_memory_resource& operator=(stream_ordered_memory_resource&&) = delete; + void set_expand_callback(expand_callback_t func, void* args) + { + expand_callback_ = func; + expand_callback_args_ = args; + } + protected: using free_list = FreeListType; using block_type = typename free_list::block_type; @@ -203,7 +211,7 @@ class stream_ordered_memory_resource : public crtp, public device_ if (size <= 0) { return nullptr; } - lock_guard lock(mtx_); + // lock_guard lock(mtx_); auto stream_event = get_event(stream); @@ -236,7 +244,7 @@ class stream_ordered_memory_resource : public crtp, public device_ if (size <= 0 || ptr == nullptr) { return; } - lock_guard lock(mtx_); + // lock_guard lock(mtx_); auto stream_event = get_event(stream); size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT); @@ -318,14 +326,7 @@ class stream_ordered_memory_resource : public crtp, public device_ return allocated; } - /** - * @brief Get an available memory block of at least `size` bytes - * - * @param size The number of bytes to allocate - * @param stream_event The stream and associated event on which the allocation will be used. - * @return block_type A block of memory of at least `size` bytes - */ - block_type get_block(std::size_t size, stream_event_pair stream_event) + block_type get_block_no_expand(std::size_t size, stream_event_pair stream_event) { // Try to find a satisfactory block in free list for the same stream (no sync required) auto iter = stream_free_blocks_.find(stream_event); @@ -344,13 +345,39 @@ class stream_ordered_memory_resource : public crtp, public device_ } // no large enough blocks available on other streams, so sync and merge until we find one + return get_block_from_other_stream(size, stream_event, blocks, true); + } + + /** + * @brief Get an available memory block of at least `size` bytes + * + * @param size The number of bytes to allocate + * @param stream_event The stream and associated event on which the allocation will be used. + * @return block_type A block of memory of at least `size` bytes + */ + block_type get_block(std::size_t size, stream_event_pair stream_event) + { { - block_type const block = get_block_from_other_stream(size, stream_event, blocks, true); + block_type const block = get_block_no_expand(size, stream_event); if (block.is_valid()) { return block; } } - log_summary_trace(); + // std::cout << "get_block(no free blocks) - size: " << size << std::endl; + if (expand_callback_.has_value()) { + while (expand_callback_.value()(size, expand_callback_args_)) { + // Let's try one more time + // std::cout << "get_block_no_expand() - size: " << size << std::endl; + block_type const block = get_block_no_expand(size, stream_event); + if (block.is_valid()) { return block; } + } + } + + auto iter = stream_free_blocks_.find(stream_event); + free_list& blocks = + (iter != stream_free_blocks_.end()) ? iter->second : stream_free_blocks_[stream_event]; + log_summary_trace(); + // std::cout << "get_block(spilling didn't help) - size: " << size << std::endl; // no large enough blocks available after merging, so grow the pool block_type const block = this->underlying().expand_pool(size, blocks, cuda_stream_view{stream_event.stream}); @@ -488,6 +515,10 @@ class stream_ordered_memory_resource : public crtp, public device_ std::mutex mtx_; // mutex for thread-safe access rmm::cuda_device_id device_id_{rmm::get_current_cuda_device()}; + + std::optional expand_callback_; + void* expand_callback_args_ = nullptr; + }; // namespace detail } // namespace rmm::mr::detail diff --git a/include/rmm/mr/device/managed_memory_resource.hpp b/include/rmm/mr/device/managed_memory_resource.hpp index ea43315de..c5bc059cc 100644 --- a/include/rmm/mr/device/managed_memory_resource.hpp +++ b/include/rmm/mr/device/managed_memory_resource.hpp @@ -62,6 +62,14 @@ class managed_memory_resource final : public device_memory_resource { void* ptr{nullptr}; RMM_CUDA_TRY_ALLOC(cudaMallocManaged(&ptr, bytes)); + + static bool has_not_prefetched_once_before = true; + if (has_not_prefetched_once_before) { + has_not_prefetched_once_before = false; + std::cout << "do_allocate(managed) - prefetched to device bytes: " << bytes << std::endl; + RMM_CUDA_TRY_ALLOC( + cudaMemPrefetchAsync(ptr, bytes, get_current_cuda_device().value(), stream)); + } return ptr; } diff --git a/python/rmm/_lib/memory_resource.pxd b/python/rmm/_lib/memory_resource.pxd index f9c2e91de..a65094ddf 100644 --- a/python/rmm/_lib/memory_resource.pxd +++ b/python/rmm/_lib/memory_resource.pxd @@ -51,6 +51,7 @@ cdef class CudaAsyncMemoryResource(DeviceMemoryResource): pass cdef class PoolMemoryResource(UpstreamResourceAdaptor): + cdef object _callback pass cdef class FixedSizeMemoryResource(UpstreamResourceAdaptor): diff --git a/python/rmm/_lib/memory_resource.pyx b/python/rmm/_lib/memory_resource.pyx index 100d18b56..78191da06 100644 --- a/python/rmm/_lib/memory_resource.pyx +++ b/python/rmm/_lib/memory_resource.pyx @@ -112,6 +112,10 @@ cdef extern from "rmm/mr/device/cuda_async_memory_resource.hpp" \ cdef extern from "rmm/cuda_device.hpp" namespace "rmm" nogil: size_t percent_of_free_device_memory(int percent) except + +cdef extern from "rmm/mr/device/pool_memory_resource.hpp" \ + namespace "rmm::mr::detail" nogil: + ctypedef bool (*expand_callback_t)(size_t, void*) + cdef extern from "rmm/mr/device/pool_memory_resource.hpp" \ namespace "rmm::mr" nogil: cdef cppclass pool_memory_resource[Upstream](device_memory_resource): @@ -120,6 +124,7 @@ cdef extern from "rmm/mr/device/pool_memory_resource.hpp" \ size_t initial_pool_size, optional[size_t] maximum_pool_size) except + size_t pool_size() + void set_expand_callback(expand_callback_t func, void* args) except + cdef extern from "rmm/mr/device/fixed_size_memory_resource.hpp" \ namespace "rmm::mr" nogil: @@ -216,6 +221,19 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \ ) except + +# Note that this function is specifically designed to rethrow Python exceptions +# as C++ exceptions when called as a callback from C++, so it is noexcept from +# Cython's perspective. +cdef bool _retry_callback_function(size_t nbytes, void *callback_arg) noexcept nogil: + cdef CppExcept err + with gil: + try: + return (callback_arg)(nbytes) + except BaseException as e: + err = translate_python_except_to_cpp(e) + throw_cpp_except(err) + + cdef class DeviceMemoryResource: cdef device_memory_resource* get_mr(self) noexcept nogil: @@ -414,6 +432,16 @@ cdef class PoolMemoryResource(UpstreamResourceAdaptor): ) return c_mr.pool_size() + def set_expand_callback(self, object callback): + cdef pool_memory_resource[device_memory_resource]* c_mr = ( + (self.get_mr()) + ) + self._callback = callback + c_mr.set_expand_callback( + _retry_callback_function, + callback + ) + cdef class FixedSizeMemoryResource(UpstreamResourceAdaptor): def __cinit__( self, @@ -892,19 +920,6 @@ cdef class TrackingResourceAdaptor(UpstreamResourceAdaptor): self.c_obj.get()))[0].log_outstanding_allocations() -# Note that this function is specifically designed to rethrow Python exceptions -# as C++ exceptions when called as a callback from C++, so it is noexcept from -# Cython's perspective. -cdef bool _oom_callback_function(size_t bytes, void *callback_arg) noexcept nogil: - cdef CppExcept err - with gil: - try: - return (callback_arg)(bytes) - except BaseException as e: - err = translate_python_except_to_cpp(e) - throw_cpp_except(err) - - cdef class FailureCallbackResourceAdaptor(UpstreamResourceAdaptor): def __cinit__( @@ -916,7 +931,7 @@ cdef class FailureCallbackResourceAdaptor(UpstreamResourceAdaptor): self.c_obj.reset( new failure_callback_resource_adaptor[device_memory_resource]( upstream_mr.get_mr(), - _oom_callback_function, + _retry_callback_function, callback ) )