Reduce number of specialized methods
charleskawczynski committed Oct 15, 2024
1 parent 193ecfa commit 3438e2c
Showing 6 changed files with 44 additions and 46 deletions.
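
In short: the per-layout method definitions (IJFH, IFH, IJF, IF, VIFH, VIJFH, VF, DataF) for copyto! and fill! are collapsed into single methods on AbstractData that take a backend trait argument (ToCPU / ToCUDA), and the _device_dispatch helper is folded into device_dispatch, which is now called on the parent array. A minimal standalone sketch of that trait-argument pattern, using hypothetical names (MyData, MyIJFH, my_copyto!) rather than the real ClimaCore API:

abstract type MyData end
struct MyIJFH{A <: AbstractArray} <: MyData
    array::A
end
Base.parent(data::MyIJFH) = data.array

struct ToCPU end
struct ToCUDA end

# Before: one nearly identical method per concrete layout, e.g.
#   my_copyto!(dest::MyIJFH, src, ::ToCPU)  = ...
#   my_copyto!(dest::MyVIJFH, src, ::ToCPU) = ...
# After: a single method on the abstract type, specialized only on the backend trait,
# so far fewer methods need to be compiled.
my_copyto!(dest::MyData, src, ::ToCPU) = (copyto!(parent(dest), src); dest)

my_copyto!(MyIJFH(zeros(4)), ones(4), ToCPU())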
32 changes: 20 additions & 12 deletions ext/cuda/data_layouts_copyto.jl
@@ -1,4 +1,4 @@
-DataLayouts._device_dispatch(x::CUDA.CuArray) = ToCUDA()
+DataLayouts.device_dispatch(x::CUDA.CuArray) = ToCUDA()

function knl_copyto!(dest, src, us)
I = universal_index(dest)
@@ -8,7 +8,7 @@ function knl_copyto!(dest, src, us)
return nothing
end

-function cuda_copyto!(dest::AbstractData, bc)
+function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
@@ -26,13 +26,21 @@ function cuda_copyto!(dest::AbstractData, bc)
return dest
end

-#! format: off
-Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
-Base.copyto!(dest::IFH{S, Ni}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
-Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
-Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
-Base.copyto!(dest::VIFH{S, Nv, Ni}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni}, ::ToCUDA) where {S, Nv, Ni} = cuda_copyto!(dest, bc)
-Base.copyto!(dest::VIJFH{S, Nv, Nij}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij}, ::ToCUDA) where {S, Nv, Nij} = cuda_copyto!(dest, bc)
-Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
-Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
-#! format: on
+# broadcasting scalar assignment
+# Performance optimization for the common identity scalar case: dest .= val
+# And this is valid for the CPU or GPU, since the broadcasted object
+# is a scalar type.
+function Base.copyto!(
+dest::AbstractData,
+bc::Base.Broadcast.Broadcasted{Style},
+::ToCUDA,
+) where {
+Style <:
+Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
+}
+bc = Base.Broadcast.instantiate(
+Base.Broadcast.Broadcasted{Style}(bc.f, bc.args, ()),
+)
+@inbounds bc0 = bc[]
+fill!(dest, bc0)
+end
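
The block added above moves the scalar-broadcast fast path onto the same trait-based dispatch: for dest .= val, the zero-dimensional broadcast is evaluated once and forwarded to fill! instead of launching an elementwise kernel. A minimal sketch of the same idea, with a plain Julia Array standing in for AbstractData (scalar_copyto! is an illustrative name, not part of the package):

using Base.Broadcast: Broadcasted, DefaultArrayStyle, broadcasted, instantiate

function scalar_copyto!(dest::AbstractArray, bc::Broadcasted{DefaultArrayStyle{0}})
    # Re-wrap with empty axes and evaluate the scalar broadcast exactly once.
    bc0 = instantiate(Broadcasted{DefaultArrayStyle{0}}(bc.f, bc.args, ()))
    @inbounds val = bc0[]
    return fill!(dest, val)   # one fill instead of an elementwise copy
end

dest = zeros(4, 4)
scalar_copyto!(dest, broadcasted(identity, 2.0))   # the broadcast object produced by `dest .= 2.0`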
4 changes: 1 addition & 3 deletions ext/cuda/data_layouts_fill.jl
@@ -6,7 +6,7 @@ function knl_fill!(dest, val, us)
return nothing
end

-function cuda_fill!(dest::AbstractData, bc)
+function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
@@ -23,5 +23,3 @@ function cuda_fill!(dest::AbstractData, bc)
end
return dest
end

-Base.fill!(dest::AbstractData, val, ::ToCUDA) = cuda_fill!(dest, val)
16 changes: 7 additions & 9 deletions src/DataLayouts/DataLayouts.jl
@@ -1524,19 +1524,17 @@ array2data(array::AbstractArray{T}, data::AbstractData) where {T} =
)

"""
-device_dispatch(data::AbstractData)
+device_dispatch(array::AbstractArray)
Returns an `ToCPU` or a `ToCUDA` for CPU
and CUDA-backed arrays accordingly.
"""
-device_dispatch(dest::AbstractData) = _device_dispatch(dest)
-
-_device_dispatch(x::Array) = ToCPU()
-_device_dispatch(x::SubArray) = _device_dispatch(parent(x))
-_device_dispatch(x::Base.ReshapedArray) = _device_dispatch(parent(x))
-_device_dispatch(x::AbstractData) = _device_dispatch(parent(x))
-_device_dispatch(x::SArray) = ToCPU()
-_device_dispatch(x::MArray) = ToCPU()
+device_dispatch(x::Array) = ToCPU()
+device_dispatch(x::SubArray) = device_dispatch(parent(x))
+device_dispatch(x::Base.ReshapedArray) = device_dispatch(parent(x))
+device_dispatch(x::AbstractData) = device_dispatch(parent(x))
+device_dispatch(x::SArray) = ToCPU()
+device_dispatch(x::MArray) = ToCPU()

include("copyto.jl")
include("fused_copyto.jl")
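
The device_dispatch methods above replace the _device_dispatch indirection: callers now pass the parent array directly, and wrapper arrays are peeled until a concrete backing array decides the backend. A standalone sketch of that recursion, with stand-in singleton types and my_device_dispatch as an illustrative name (the real definitions live in DataLayouts):

struct ToCPU end
struct ToCUDA end

my_device_dispatch(::Array) = ToCPU()
my_device_dispatch(x::SubArray) = my_device_dispatch(parent(x))            # peel the view
my_device_dispatch(x::Base.ReshapedArray) = my_device_dispatch(parent(x))  # peel the reshape

a = reshape(view(zeros(4, 4), :, 1:2), 2, 4)
my_device_dispatch(a)   # ToCPU(): wrappers are unwrapped down to the backing Array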
34 changes: 14 additions & 20 deletions src/DataLayouts/copyto.jl
@@ -4,8 +4,8 @@

Base.copyto!(
dest::AbstractData,
-bc::Union{AbstractData, Base.Broadcast.Broadcasted},
-) = Base.copyto!(dest, bc, device_dispatch(dest))
+@nospecialize(bc::Union{AbstractData, Base.Broadcast.Broadcasted}),
+) = Base.copyto!(dest, bc, device_dispatch(parent(dest)))

# Specialize on non-Broadcasted objects
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}
@@ -15,8 +15,6 @@ end

# broadcasting scalar assignment
# Performance optimization for the common identity scalar case: dest .= val
-# And this is valid for the CPU or GPU, since the broadcasted object
-# is a scalar type.
function Base.copyto!(
dest::AbstractData,
bc::Base.Broadcast.Broadcasted{Style},
@@ -51,10 +49,9 @@ function Base.copyto!(
::ToCPU,
) where {S, Nij}
(_, _, _, _, Nh) = size(dest)
-@inbounds for h in 1:Nh
-slab_dest = slab(dest, h)
-slab_bc = slab(bc, h)
-copyto!(slab_dest, slab_bc)
+@inbounds for h in 1:Nh, j in 1:Nij, i in 1:Nij
+idx = CartesianIndex(i, j, 1, 1, h)
+dest[idx] = convert(S, bc[idx])
end
return dest
end
@@ -65,10 +62,9 @@ function Base.copyto!(
::ToCPU,
) where {S, Ni}
(_, _, _, _, Nh) = size(dest)
-@inbounds for h in 1:Nh
-slab_dest = slab(dest, h)
-slab_bc = slab(bc, h)
-copyto!(slab_dest, slab_bc)
+@inbounds for h in 1:Nh, i in 1:Ni
+idx = CartesianIndex(i, 1, 1, 1, h)
+dest[idx] = convert(S, bc[idx])
end
return dest
end
@@ -131,10 +127,9 @@ function Base.copyto!(
) where {S, Nv, Ni}
# copy contiguous columns
(_, _, _, _, Nh) = size(dest)
-@inbounds for h in 1:Nh, i in 1:Ni
-col_dest = column(dest, i, h)
-col_bc = column(bc, i, h)
-copyto!(col_dest, col_bc)
+@inbounds for h in 1:Nh, i in 1:Ni, v in 1:Nv
+idx = CartesianIndex(i, 1, 1, v, h)
+dest[idx] = convert(S, bc[idx])
end
return dest
end
@@ -146,10 +141,9 @@ function Base.copyto!(
) where {S, Nv, Nij}
# copy contiguous columns
(_, _, _, _, Nh) = size(dest)
-@inbounds for h in 1:Nh, j in 1:Nij, i in 1:Nij
-col_dest = column(dest, i, j, h)
-col_bc = column(bc, i, j, h)
-copyto!(col_dest, col_bc)
+@inbounds for h in 1:Nh, j in 1:Nij, i in 1:Nij, v in 1:Nv
+idx = CartesianIndex(i, j, 1, v, h)
+dest[idx] = convert(S, bc[idx])
end
return dest
end
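
The CPU copyto! methods changed in this file drop the slab/column helper objects and instead iterate every element of the destination through one CartesianIndex loop. A sketch of that loop shape, assuming a plain 5-dimensional (i, j, f, v, h) array in place of a data layout and any bc indexable by CartesianIndex (flat_copyto! is an illustrative name):

function flat_copyto!(dest::AbstractArray{T, 5}, bc) where {T}
    (Ni, Nj, _, Nv, Nh) = size(dest)
    @inbounds for h in 1:Nh, j in 1:Nj, i in 1:Ni, v in 1:Nv
        idx = CartesianIndex(i, j, 1, v, h)   # f slot fixed at 1, matching the loops above
        dest[idx] = convert(T, bc[idx])
    end
    return dest
end

src = rand(3, 3, 1, 10, 5)
flat_copyto!(similar(src), src)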
2 changes: 1 addition & 1 deletion src/DataLayouts/fill.jl
@@ -58,4 +58,4 @@ function Base.fill!(data::VIFH, val, ::ToCPU)
end

Base.fill!(dest::AbstractData, val) =
-Base.fill!(dest, val, device_dispatch(dest))
+Base.fill!(dest, val, device_dispatch(parent(dest)))
2 changes: 1 addition & 1 deletion src/DataLayouts/fused_copyto.jl
@@ -18,7 +18,7 @@ function Base.copyto!(
end,
)
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
-fused_copyto!(fmb_inst, dest1, device_dispatch(dest1))
+fused_copyto!(fmb_inst, dest1, device_dispatch(parent(dest1)))
end

function fused_copyto!(
