Skip to content

Commit

Permalink
Dynamically split kernels based on parameter memory
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 31, 2024
1 parent 041fdee commit c5ec545
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 30 deletions.
100 changes: 85 additions & 15 deletions ext/MultiBroadcastFusionCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,93 @@ import MultiBroadcastFusion: fused_copyto!

MBF.device(x::CUDA.CuArray) = MBF.MBF_CUDA()

function fused_copyto!(fmb::MBF.FusedMultiBroadcast, ::MBF.MBF_CUDA)
(; pairs) = fmb
dest = first(pairs).first
destinations = map(p -> p.first, pairs)
all(a -> axes(a) == axes(dest), destinations) ||
error("Cannot fuse broadcast expressions with unequal broadcast axes")
nitems = length(parent(dest))
function get_param_lim()
config = CUDA.compiler_config(CUDA.device())
(; ptx, cap) = config.params
return cap >= v"7.0" && ptx >= v"8.1" ? 32764 : 4096
end

function get_param_usage(args)
config = CUDA.compiler_config(CUDA.device())
(; ptx, cap) = config.params
# validate use of parameter memory
param_usages = map(x -> sizeof(typeof(CUDA.cudaconvert(x))), args)
return sum(param_usages)
end

function fused_multibroadcast_args(fmb::MBF.FusedMultiBroadcast)
dest = first(fmb.pairs).first
CI = CartesianIndices(axes(dest))
kernel =
CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!(
fmb,
CI,
return (fmb, CI)
end

"""
partition_kernels(fmb;
fused_broadcast_constructor = MBF.FusedMultiBroadcast,
args_func::Function =
)
Splits fused broadcast kernels into a vector
of kernels, based on parameter memory limitations.
We first attempt to fuse
1:N, 1:N-1, 1:N-2, ... until we fuse 1:N-k
Next, we attempt to fuse
N-k+1:N, N-k+1:N-1, N-k+1:N-2, ...
And so forth.
"""
function partition_kernels(
fmb,
fused_broadcast_constructor = MBF.FusedMultiBroadcast,
args_func::Function = fused_multibroadcast_args,
)
plim = get_param_lim()
usage = get_param_usage(args_func(fmb))
n_bins = 1
fmbs = (fmb,)
usage plim && return fmbs
fmbs_split = []
N = length(fmb.pairs)
i_start = 1
i_stop = N
while i_stop i_start
ith_pairs = fmb.pairs[i_start:i_stop]
ith_fmb = fused_broadcast_constructor(ith_pairs)
if get_param_usage(args_func(ith_fmb)) plim # first iteration will likely fail (ambitious)
push!(fmbs_split, ith_fmb)
i_stop == N && break
i_start = i_stop + 1 # N on first iteration
i_stop = N # reset i_stop
else
i_stop = i_stop - 1
end
end
return fmbs_split
end

function fused_copyto!(fmb::MBF.FusedMultiBroadcast, ::MBF.MBF_CUDA)
destinations = map(p -> p.first, fmb.pairs)
fmbs = partition_kernels(fmb)
for fmb in fmbs
(; pairs) = fmb
dest = first(pairs).first
dests = map(p -> p.first, pairs)
all(a -> axes(a) == axes(dest), dests) || error(
"Cannot fuse broadcast expressions with unequal broadcast axes",
)
config = CUDA.launch_configuration(kernel.fun)
threads = min(nitems, config.threads)
blocks = cld(nitems, threads)
kernel(fmb, CI; threads, blocks)
nitems = length(parent(dest))
CI = CartesianIndices(axes(dest))
kernel =
CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!(
fmb,
CI,
)
config = CUDA.launch_configuration(kernel.fun)
threads = min(nitems, config.threads)
blocks = cld(nitems, threads)
kernel(fmb, CI; threads, blocks)
end
return destinations
end
import Base.Broadcast
Expand Down
18 changes: 3 additions & 15 deletions test/execution/parameter_memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ function perf_kernel_shared_reads_fused!(X, Y)
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 # breaks on A100 due to too much parameter memory
end
end
#! format: on
Expand All @@ -66,18 +64,8 @@ problem_size = (50, 5, 5, 6, 5400)
array_size = problem_size # array
X = get_arrays(:x, AType, bm.float_type, array_size)
Y = get_arrays(:y, AType, bm.float_type, array_size)
@testset "Test breaking case with parameter memory" begin
if use_cuda
try
perf_kernel_shared_reads_fused!(X, Y)
error("The above kernel should error")
catch e
@test startswith(
e.msg,
"Kernel invocation uses too much parameter memory.",
)
end
end
@testset "Test kernel splitting with too much parameter memory" begin
use_cuda && perf_kernel_shared_reads_fused!(X, Y)
end

nothing

0 comments on commit c5ec545

Please sign in to comment.