From c5ec545f44b0e754b2840ba2a83384b43546ea00 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Thu, 31 Oct 2024 16:05:33 -0400 Subject: [PATCH] Dynamically split kernels based on parameter memory --- ext/MultiBroadcastFusionCUDAExt.jl | 100 ++++++++++++++++++++++++----- test/execution/parameter_memory.jl | 18 +----- 2 files changed, 88 insertions(+), 30 deletions(-) diff --git a/ext/MultiBroadcastFusionCUDAExt.jl b/ext/MultiBroadcastFusionCUDAExt.jl index 225746f..64abee7 100644 --- a/ext/MultiBroadcastFusionCUDAExt.jl +++ b/ext/MultiBroadcastFusionCUDAExt.jl @@ -6,23 +6,93 @@ import MultiBroadcastFusion: fused_copyto! MBF.device(x::CUDA.CuArray) = MBF.MBF_CUDA() -function fused_copyto!(fmb::MBF.FusedMultiBroadcast, ::MBF.MBF_CUDA) - (; pairs) = fmb - dest = first(pairs).first - destinations = map(p -> p.first, pairs) - all(a -> axes(a) == axes(dest), destinations) || - error("Cannot fuse broadcast expressions with unequal broadcast axes") - nitems = length(parent(dest)) +function get_param_lim() + config = CUDA.compiler_config(CUDA.device()) + (; ptx, cap) = config.params + return cap >= v"7.0" && ptx >= v"8.1" ? 32764 : 4096 +end + +function get_param_usage(args) + config = CUDA.compiler_config(CUDA.device()) + (; ptx, cap) = config.params + # validate use of parameter memory + param_usages = map(x -> sizeof(typeof(CUDA.cudaconvert(x))), args) + return sum(param_usages) +end + +function fused_multibroadcast_args(fmb::MBF.FusedMultiBroadcast) + dest = first(fmb.pairs).first CI = CartesianIndices(axes(dest)) - kernel = - CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!( - fmb, - CI, + return (fmb, CI) +end + +""" + partition_kernels(fmb; + fused_broadcast_constructor = MBF.FusedMultiBroadcast, + args_func::Function = + ) + +Splits fused broadcast kernels into a vector +of kernels, based on parameter memory limitations. + +We first attempt to fuse + 1:N, 1:N-1, 1:N-2, ... until we fuse 1:N-k +Next, we attempt to fuse + N-k+1:N, N-k+1:N-1, N-k+1:N-2, ... + +And so forth. +""" +function partition_kernels( + fmb, + fused_broadcast_constructor = MBF.FusedMultiBroadcast, + args_func::Function = fused_multibroadcast_args, +) + plim = get_param_lim() + usage = get_param_usage(args_func(fmb)) + n_bins = 1 + fmbs = (fmb,) + usage ≤ plim && return fmbs + fmbs_split = [] + N = length(fmb.pairs) + i_start = 1 + i_stop = N + while i_stop ≠ i_start + ith_pairs = fmb.pairs[i_start:i_stop] + ith_fmb = fused_broadcast_constructor(ith_pairs) + if get_param_usage(args_func(ith_fmb)) ≤ plim # first iteration will likely fail (ambitious) + push!(fmbs_split, ith_fmb) + i_stop == N && break + i_start = i_stop + 1 # N on first iteration + i_stop = N # reset i_stop + else + i_stop = i_stop - 1 + end + end + return fmbs_split +end + +function fused_copyto!(fmb::MBF.FusedMultiBroadcast, ::MBF.MBF_CUDA) + destinations = map(p -> p.first, fmb.pairs) + fmbs = partition_kernels(fmb) + for fmb in fmbs + (; pairs) = fmb + dest = first(pairs).first + dests = map(p -> p.first, pairs) + all(a -> axes(a) == axes(dest), dests) || error( + "Cannot fuse broadcast expressions with unequal broadcast axes", ) - config = CUDA.launch_configuration(kernel.fun) - threads = min(nitems, config.threads) - blocks = cld(nitems, threads) - kernel(fmb, CI; threads, blocks) + nitems = length(parent(dest)) + CI = CartesianIndices(axes(dest)) + kernel = + CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!( + fmb, + CI, + ) + config = CUDA.launch_configuration(kernel.fun) + threads = min(nitems, config.threads) + blocks = cld(nitems, threads) + kernel(fmb, CI; threads, blocks) + end return destinations end import Base.Broadcast diff --git a/test/execution/parameter_memory.jl b/test/execution/parameter_memory.jl index ed2932c..bbf90a0 100644 --- a/test/execution/parameter_memory.jl +++ b/test/execution/parameter_memory.jl @@ -49,9 +49,7 @@ function perf_kernel_shared_reads_fused!(X, Y) @. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3 @. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 @. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 - @. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 - @. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3 - @. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + @. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 # breaks on A100 due to too much parameter memory end end #! format: on @@ -66,18 +64,8 @@ problem_size = (50, 5, 5, 6, 5400) array_size = problem_size # array X = get_arrays(:x, AType, bm.float_type, array_size) Y = get_arrays(:y, AType, bm.float_type, array_size) -@testset "Test breaking case with parameter memory" begin - if use_cuda - try - perf_kernel_shared_reads_fused!(X, Y) - error("The above kernel should error") - catch e - @test startswith( - e.msg, - "Kernel invocation uses too much parameter memory.", - ) - end - end +@testset "Test kernel splitting with too much parameter memory" begin + use_cuda && perf_kernel_shared_reads_fused!(X, Y) end nothing