Skip to content

Commit

Permalink
rmm dynamic_load_runtime can now detect static linking of cudart.
Browse files Browse the repository at this point in the history
No longer needs RMM_STATIC_CUDART to be set for static cudart usages
  • Loading branch information
robertmaynard committed Oct 28, 2024
1 parent 1ebfe0a commit 38da65b
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 37 deletions.
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ target_include_directories(rmm INTERFACE "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOUR
# Select which CUDA runtime library is attached to the `rmm` INTERFACE target: the static
# cudart when CUDA_STATIC_RUNTIME is set, the shared cudart otherwise.
if(CUDA_STATIC_RUNTIME)
message(STATUS "RMM: Enabling static linking of cudart")
target_link_libraries(rmm INTERFACE CUDA::cudart_static)
# Adds the RMM_STATIC_CUDART compile definition to the `rmm` INTERFACE target.
# NOTE(review): the commit page reports exactly one deletion in this hunk and says
# RMM_STATIC_CUDART is no longer needed - this is presumably the removed line; confirm.
target_compile_definitions(rmm INTERFACE RMM_STATIC_CUDART)
else()
target_link_libraries(rmm INTERFACE CUDA::cudart)
endif()
Expand Down
80 changes: 46 additions & 34 deletions include/rmm/detail/dynamic_load_runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,38 +69,36 @@ struct dynamic_load_runtime {
/**
 * @brief Resolve a CUDA runtime function by name and return it as a typed function pointer.
 *
 * The symbol is first looked up among the objects already loaded into the process
 * (`RTLD_DEFAULT`). Only when that lookup does not yield a usable address is the handle
 * returned by `get_cuda_runtime_handle()` queried.
 *
 * @tparam signature Function pointer type the resolved symbol is cast to.
 * @param func_name Name of the symbol to resolve.
 * @return The resolved function pointer, or `std::nullopt` when neither lookup finds it.
 */
template <typename signature>
static std::optional<signature> function(const char* func_name)
{
  // dlerror() reports the most recent failure of *any* dl* call and is only reset by being
  // read, so drain it first; otherwise a stale message makes a successful dlsym look failed.
  ::dlerror();
  auto* handle = ::dlsym(RTLD_DEFAULT, func_name);
  auto* error  = ::dlerror();

  if (error != nullptr || handle == nullptr) {
    // Not visible in the already-loaded objects: fall back to the CUDA runtime handle.
    auto* runtime = get_cuda_runtime_handle();
    ::dlerror();  // discard anything left behind while obtaining the handle
    handle = ::dlsym(runtime, func_name);
    error  = ::dlerror();
  }

  // A null address is never callable, so treat it as "not found" even without an error.
  if (error != nullptr || handle == nullptr) { return std::nullopt; }
  return std::optional<signature>(reinterpret_cast<signature>(handle));
}
};

/**
 * @brief Defines a static member function template `name` that forwards its arguments to the
 * CUDA runtime function of the same name.
 *
 * The wrapper first takes the address of `::name`. If that address is non-null the function
 * is called directly. Otherwise (possible only when `::name` is a weak declaration that was
 * left unresolved) the symbol is resolved once through `dynamic_load_runtime::function`,
 * cached in a function-local static, and called through the cached pointer. If the symbol
 * cannot be resolved either way, `RMM_FAIL` is invoked with a message naming the function.
 *
 * @param name      Identifier of the CUDA runtime function to wrap.
 * @param signature Function pointer type of `name`.
 */
// clang-format off
#define RMM_CUDART_API_WRAPPER(name, signature)                                \
  template <typename... Args>                                                  \
  static cudaError_t name(Args... args)                                        \
  {                                                                            \
    auto* p = static_cast<signature>(::name);                                  \
    if (p != nullptr) { return (*p)(args...); }                                \
    static const auto func = dynamic_load_runtime::function<signature>(#name); \
    if (func) { return (*func)(args...); }                                     \
    /* `#` is not applied inside a string literal; splice the name in. */      \
    RMM_FAIL("Failed to find " #name " function in libcudart.so");             \
  }
// clang-format on

#if CUDART_VERSION >= 11020 // 11.2 introduced cudaMallocAsync
/**
Expand All @@ -110,17 +108,31 @@ struct dynamic_load_runtime {
* This allows RMM users to compile/link against CUDA 11.2+ and run with
* < CUDA 11.2 runtime as these functions are found at call time.
*/


// Weak redeclarations of the CUDA memory-pool / stream-ordered allocation entry points.
// A weak symbol that is left unresolved has a null address instead of being an unresolved
// reference, which is what lets callers test `::name != nullptr` before calling it.
extern "C" {
cudaError_t cudaMemPoolCreate(cudaMemPool_t*, const cudaMemPoolProps*) __attribute__((weak));
cudaError_t cudaMemPoolSetAttribute(cudaMemPool_t, cudaMemPoolAttr, void*) __attribute__((weak));
cudaError_t cudaMemPoolDestroy(cudaMemPool_t) __attribute__((weak));
cudaError_t cudaMallocFromPoolAsync(void**, size_t, cudaMemPool_t, cudaStream_t)
  __attribute__((weak));
cudaError_t cudaFreeAsync(void*, cudaStream_t) __attribute__((weak));
// The runtime function is `cudaDeviceGetDefaultMemPool`; a `_sig`-suffixed name would declare
// an unrelated symbol and leave the real function without the weak attribute.
cudaError_t cudaDeviceGetDefaultMemPool(cudaMemPool_t*, int) __attribute__((weak));
}

struct async_alloc {
static bool is_supported()
{
#if defined(RMM_STATIC_CUDART)
static bool runtime_supports_pool = (CUDART_VERSION >= 11020);
#else
static bool runtime_supports_pool =
dynamic_load_runtime::function<dynamic_load_runtime::function_sig<void*, cudaStream_t>>(
"cudaFreeAsync")
.has_value();
#endif
static bool runtime_supports_pool{[] {
using cuda_free_async_sig = dynamic_load_runtime::function_sig<void*, cudaStream_t>;
bool cuda_free_async_supported = true;
auto* p = static_cast<cuda_free_async_sig>(::cudaFreeAsync);
if (p == nullptr) {
cuda_free_async_supported =
dynamic_load_runtime::function<cuda_free_async_sig>("cudaFreeAsync").has_value();
}
return cuda_free_async_supported;
}()};

static auto driver_supports_pool{[] {
int cuda_pool_supported{};
Expand Down
17 changes: 15 additions & 2 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ endfunction()
function(ConfigureTest TEST_NAME)

set(options)
set(one_value GPUS PERCENT)
set(one_value CUDART GPUS PERCENT)
set(multi_value)
cmake_parse_arguments(_RMM_TEST "${options}" "${one_value}" "${multi_value}" ${ARGN})
if(NOT DEFINED _RMM_TEST_GPUS AND NOT DEFINED _RMM_TEST_PERCENT)
Expand All @@ -99,13 +99,23 @@ function(ConfigureTest TEST_NAME)
set(_RMM_TEST_PERCENT 100)
endif()

# Choose what the test executables link against, driven by the optional CUDART argument:
#   SHARED -> rmm for compilation only, plus the shared CUDA runtime
#   STATIC -> rmm for compilation only, plus the static CUDA runtime
#   unset/other -> plain `rmm`
# Both operands are quoted so if() compares the strings themselves rather than
# dereferencing a variable that happens to be named SHARED or STATIC.
if("${_RMM_TEST_CUDART}" STREQUAL "SHARED")
  set(cudart_link_libs $<COMPILE_ONLY:rmm> CUDA::cudart)
elseif("${_RMM_TEST_CUDART}" STREQUAL "STATIC")
  set(cudart_link_libs $<COMPILE_ONLY:rmm> CUDA::cudart_static)
else()
  set(cudart_link_libs rmm)
endif()

# Test with legacy default stream.
ConfigureTestInternal(${TEST_NAME} ${_RMM_TEST_UNPARSED_ARGUMENTS})
target_link_libraries(${TEST_NAME} ${cudart_link_libs})

# Test with per-thread default stream.
string(REGEX REPLACE "_TEST$" "_PTDS_TEST" PTDS_TEST_NAME "${TEST_NAME}")
ConfigureTestInternal("${PTDS_TEST_NAME}" ${_RMM_TEST_UNPARSED_ARGUMENTS})
target_compile_definitions("${PTDS_TEST_NAME}" PUBLIC CUDA_API_PER_THREAD_DEFAULT_STREAM)
target_link_libraries(${PTDS_TEST_NAME} ${cudart_link_libs})

foreach(name ${TEST_NAME} ${PTDS_TEST_NAME} ${NS_TEST_NAME})
rapids_test_add(
Expand All @@ -131,7 +141,10 @@ ConfigureTest(ADAPTOR_TEST mr/device/adaptor_tests.cpp)
ConfigureTest(POOL_MR_TEST mr/device/pool_mr_tests.cpp GPUS 1 PERCENT 100)

# cuda_async mr tests
ConfigureTest(CUDA_ASYNC_MR_TEST mr/device/cuda_async_mr_tests.cpp GPUS 1 PERCENT 60)
ConfigureTest(CUDA_ASYNC_MR_STATIC_CUDART_TEST mr/device/cuda_async_mr_tests.cpp GPUS 1 PERCENT 60
CUDART STATIC)
ConfigureTest(CUDA_ASYNC_MR_SHARED_CUDART_TEST mr/device/cuda_async_mr_tests.cpp GPUS 1 PERCENT 60
CUDART SHARED)

# thrust allocator tests
ConfigureTest(THRUST_ALLOCATOR_TEST mr/device/thrust_allocator_tests.cu GPUS 1 PERCENT 60)
Expand Down

0 comments on commit 38da65b

Please sign in to comment.