Skip to content

Commit

Permalink
Reduce number of specialized methods in stencil kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 16, 2024
1 parent 90572d1 commit c753d55
Showing 1 changed file with 16 additions and 28 deletions.
44 changes: 16 additions & 28 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3327,30 +3327,6 @@ function Base.similar(
return Field(Eltype, sp)
end

function _serial_copyto!(field_out::Field, bc, Ni::Int, Nj::Int, Nh::Int)
space = axes(field_out)
bounds = window_bounds(space, bc)
bcs = bc # strip_space(bc, space)
@inbounds for h in 1:Nh, j in 1:Nj, i in 1:Ni
apply_stencil!(space, field_out, bcs, (i, j, h), bounds)
end
return field_out
end

function _threaded_copyto!(field_out::Field, bc, Ni::Int, Nj::Int, Nh::Int)
space = axes(field_out)
bounds = window_bounds(space, bc)
bcs = bc # strip_space(bc, space)
@inbounds begin
Threads.@threads for h in 1:Nh
for j in 1:Nj, i in 1:Ni
apply_stencil!(space, field_out, bcs, (i, j, h), bounds)
end
end
end
return field_out
end

function strip_space(bc::StencilBroadcasted{Style}, parent_space) where {Style}
current_space = axes(bc)
new_space = placeholder_space(current_space, parent_space)
Expand All @@ -3368,15 +3344,27 @@ function Base.copyto!(
Broadcasted{ColumnStencilStyle},
},
)
space = axes(bc)
space = axes(field_out)
local_geometry = Spaces.local_geometry_data(space)
(Ni, Nj, _, _, Nh) = size(local_geometry)
context = ClimaComms.context(axes(field_out))
context = ClimaComms.context(space)
device = ClimaComms.device(context)
bcs = strip_space(bc, space)
bounds = window_bounds(space, bcs)
if (device isa ClimaComms.CPUMultiThreaded) && Nh > 1
return _threaded_copyto!(field_out, bc, Ni, Nj, Nh)
@inbounds begin
Threads.@threads for h in 1:Nh
for j in 1:Nj, i in 1:Ni
apply_stencil!(space, field_out, bcs, (i, j, h), bounds)
end
end
end
else
@inbounds for h in 1:Nh, j in 1:Nj, i in 1:Ni
apply_stencil!(space, field_out, bcs, (i, j, h), bounds)
end
end
return _serial_copyto!(field_out, bc, Ni, Nj, Nh)
return field_out
end

function window_bounds(space, bc)
Expand Down

0 comments on commit c753d55

Please sign in to comment.