Skip to content

Commit

Permalink
Merge pull request #47 from CliMA/ck/dot_Is
Browse files Browse the repository at this point in the history
Apply a few fixes
  • Loading branch information
charleskawczynski authored Oct 10, 2024
2 parents a438a92 + 4a1b14a commit 041fdee
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 23 deletions.
37 changes: 21 additions & 16 deletions ext/MultiBroadcastFusionCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,33 @@ function fused_copyto!(fmb::MBF.FusedMultiBroadcast, ::MBF.MBF_CUDA)
(; pairs) = fmb
dest = first(pairs).first
destinations = map(p -> p.first, pairs)
nitems = length(parent(dest))
max_threads = 256 # can be higher if conditions permit
nthreads = min(max_threads, nitems)
nblocks = cld(nitems, nthreads)
a1 = axes(dest)
all(a -> axes(a) == axes(dest), destinations) ||
error("Cannot fuse broadcast expressions with unequal broadcast axes")
nitems = length(parent(dest))
CI = CartesianIndices(axes(dest))
CUDA.@cuda threads = (nthreads) blocks = (nblocks) fused_copyto_kernel!(
fmb,
CI,
)
return nothing
kernel =
CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!(
fmb,
CI,
)
config = CUDA.launch_configuration(kernel.fun)
threads = min(nitems, config.threads)
blocks = cld(nitems, threads)
kernel(fmb, CI; threads, blocks)
return destinations
end
import Base.Broadcast
# GPU kernel body for a fused multi-broadcast.
#
# Each thread derives a linear index `idx` from its thread/block coordinates
# and, when that index lies in `1:nitems`, copies the element at `CI[idx]`
# for every `dest => broadcast` pair in `fmb` via `MBF.rcopyto_at!`.
#
# Arguments:
# - `fmb`: the fused broadcast; only its `pairs` field is read here.
# - `CI`: indexable collection mapping a linear index to the index passed on
#   to `MBF.rcopyto_at!`.
#
# Always returns `nothing`.
function fused_copyto_kernel!(fmb::MBF.FusedMultiBroadcast, CI)
    @inbounds begin
        (; pairs) = fmb
        dest = first(pairs).first
        nitems = length(dest)
        # `Int32(1)` rather than the literal `1`, as in the listing.
        idx =
            CUDA.threadIdx().x +
            (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x
        # Threads whose index falls outside `1:nitems` do nothing.
        if 1 ≤ idx ≤ nitems
            MBF.rcopyto_at!(pairs, CI[idx])
        end
    end
    return nothing
end
Expand Down
32 changes: 25 additions & 7 deletions src/execution/fused_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,40 @@ struct MBF_CUDA end
# Backend selection for array `x`: this method maps every `AbstractArray`
# to the `MBF_CPU()` tag.
function device(x::AbstractArray)
    return MBF_CPU()
end

# Execute a fused multi-broadcast.
#
# Every broadcast expression in `fmb.pairs` is instantiated, and the
# instantiated `FusedMultiBroadcast` is then handed to `fused_copyto!`,
# dispatching on `device` of the first destination.
#
# Returns whatever the selected `fused_copyto!` method returns.
function Base.copyto!(fmb::FusedMultiBroadcast)
    # Since we intercept Base.copyto!, we have not yet
    # called Base.Broadcast.instantiate (as this is done
    # in materialize, which has been stripped away), so,
    # let's call it here.
    fmb′ = FusedMultiBroadcast(
        map(fmb.pairs) do p
            Pair(p.first, Base.Broadcast.instantiate(p.second))
        end,
    )
    (; pairs) = fmb′ # (Pair(dest1, bc1),Pair(dest2, bc2),...)
    dest = first(pairs).first
    # Pass the instantiated `fmb′` (not the raw `fmb`), and do so exactly once.
    return fused_copyto!(fmb′, device(dest))
end

# rcopyto_at!(pair_or_pairs, i...)
#
# Copy the element at index `i...` from the source to the destination of one
# `dest => src` pair, or of every pair in a tuple of pairs.
#
# The tuple methods recurse with `first` / `Base.tail`; the one-element and
# empty-tuple methods terminate the recursion. Every path returns `nothing`.
# The index arguments are declared as `i::Vararg{T}` with a method type
# parameter, so all indices of one call must share a single type.

# Single pair: write `src[i...]` into `dest[i...]`.
Base.@propagate_inbounds function rcopyto_at!(
    pair::Pair,
    i::Vararg{T},
) where {T}
    dest, src = pair.first, pair.second
    @inbounds dest[i...] = src[i...]
    return nothing
end

# Tuple of pairs: handle the first pair, then recurse on the rest.
Base.@propagate_inbounds function rcopyto_at!(
    pairs::Tuple,
    i::Vararg{T},
) where {T}
    rcopyto_at!(first(pairs), i...)
    return rcopyto_at!(Base.tail(pairs), i...)
end

# One-element tuple: forward directly to the single-pair method.
Base.@propagate_inbounds rcopyto_at!(
    pairs::Tuple{<:Any},
    i::Vararg{T},
) where {T} = rcopyto_at!(first(pairs), i...)

# Empty tuple: nothing to copy.
@inline rcopyto_at!(pairs::Tuple{}, i::Vararg{T}) where {T} = nothing

# This is better than the baseline.
function fused_copyto!(fmb::FusedMultiBroadcast, ::MBF_CPU)
Expand All @@ -39,6 +56,7 @@ function fused_copyto!(fmb::FusedMultiBroadcast, ::MBF_CPU)
dest[i] = bc[i]
end
end
return destinations
end


Expand Down
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
#=
julia --project
using TestEnv
TestEnv.activate()
using CUDA;
ENV["PERFORM_BENCHMARK"]="true";
using Revise; include(joinpath("test", "collection", "runtests.jl"))
using Revise; include(joinpath("test", "execution", "runtests.jl"))
using Revise; include(joinpath("test", "runtests.jl"))
Expand Down

0 comments on commit 041fdee

Please sign in to comment.