diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index 12740e2b28..c2d628a8c9 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -3327,30 +3327,6 @@ function Base.similar( return Field(Eltype, sp) end -function _serial_copyto!(field_out::Field, bc, Ni::Int, Nj::Int, Nh::Int) - space = axes(field_out) - bounds = window_bounds(space, bc) - bcs = bc # strip_space(bc, space) - @inbounds for h in 1:Nh, j in 1:Nj, i in 1:Ni - apply_stencil!(space, field_out, bcs, (i, j, h), bounds) - end - return field_out -end - -function _threaded_copyto!(field_out::Field, bc, Ni::Int, Nj::Int, Nh::Int) - space = axes(field_out) - bounds = window_bounds(space, bc) - bcs = bc # strip_space(bc, space) - @inbounds begin - Threads.@threads for h in 1:Nh - for j in 1:Nj, i in 1:Ni - apply_stencil!(space, field_out, bcs, (i, j, h), bounds) - end - end - end - return field_out -end - function strip_space(bc::StencilBroadcasted{Style}, parent_space) where {Style} current_space = axes(bc) new_space = placeholder_space(current_space, parent_space) @@ -3368,15 +3344,27 @@ function Base.copyto!( Broadcasted{ColumnStencilStyle}, }, ) - space = axes(bc) + space = axes(field_out) local_geometry = Spaces.local_geometry_data(space) (Ni, Nj, _, _, Nh) = size(local_geometry) - context = ClimaComms.context(axes(field_out)) + context = ClimaComms.context(space) device = ClimaComms.device(context) + bcs = strip_space(bc, space) + bounds = window_bounds(space, bcs) if (device isa ClimaComms.CPUMultiThreaded) && Nh > 1 - return _threaded_copyto!(field_out, bc, Ni, Nj, Nh) + @inbounds begin + Threads.@threads for h in 1:Nh + for j in 1:Nj, i in 1:Ni + apply_stencil!(space, field_out, bcs, (i, j, h), bounds) + end + end + end + else + @inbounds for h in 1:Nh, j in 1:Nj, i in 1:Ni + apply_stencil!(space, field_out, bcs, (i, j, h), bounds) + end end - return _serial_copyto!(field_out, bc, Ni, Nj, Nh) + return field_out end function window_bounds(space, bc)