Skip to content

Commit

Permalink
Merge pull request #541 from SciML/gputests2
Browse files Browse the repository at this point in the history
Update GPU tests
  • Loading branch information
ChrisRackauckas authored Jul 31, 2023
2 parents 4927ad9 + 56d66fa commit 76d12b5
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
26 changes: 14 additions & 12 deletions test/gpu/sde_weak_adaptive_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ h1(z) = z^3-6*z^2+8*z

function f1!(du,u,p,t)
@inbounds begin
du[1] = 1//2*u[1]+sqrt(u[1]^2 +1)
du[1] = 5f-1*u[1]+sqrt(u[1]^2 +1)
end
nothing
end
Expand All @@ -56,28 +56,29 @@ ensemble_prob1 = EnsembleProblem(prob1;
output_func = output_func,
prob_func = prob_func,
reduction = reduction,
u_init=Vector{eltype(prob1.u0)}([0.0])
u_init=Vector{eltype(prob1.u0)}([0.0]),
safetycopy = false
)


# prob 2
u₀ = [0.1f0,0.1f0]
function f2!(du,u,p,t)
@inbounds begin
du[1] = 3//2*u[1]
du[2] = 3//2*u[2]
du[1] = 1.5f-0*u[1]
du[2] = 1.5f-0*u[2]
end
nothing
end
function g2!(du,u,p,t)
@inbounds begin
du[1] = 1//10*u[1]
du[2] = 1//10*u[2]
du[1] = 1f-1*u[1]
du[2] = 1f-1*u[2]
end
nothing
end

f_true2(t) = 1//10*exp(3//2*t) #1//100*exp(301//100*t)
f_true2(t) = 1f-1*exp(1.5f-1*t) #1//100*exp(301//100*t)

h2(z) = z #z^2

Expand All @@ -86,7 +87,8 @@ ensemble_prob2 = EnsembleProblem(prob2;
output_func = (sol,i) -> (h2.(sol),false),
prob_func = prob_func,
reduction = reduction,
u_init=Vector{eltype(prob2.u0)}([0.0, 0.0])
u_init=Vector{eltype(prob2.u0)}([0.0, 0.0]),
safetycopy = false
)


Expand All @@ -109,19 +111,19 @@ seeds = rand(UInt, numtraj)
for i in 1:2
@show i

err1 = weak_error(probs[i],DRI1NM(),numtraj,Int(1e4),ftrue[i],tsave,abstol=1f0,reltol=1f0, ensemblealg=EnsembleGPUArray())
err1 = weak_error(probs[i],DRI1NM(),numtraj,Int(1e4),ftrue[i],tsave,abstol=1f0,reltol=1f0, ensemblealg=EnsembleGPUArray(CUDA.CUDABackend()))
@show err1
# err2 = weak_error(probs[i],DRI1NM(),numtraj,Int(1e4),ftrue[i],tsave,abstol=0.1f0,reltol=0.1f0, ensemblealg=EnsembleGPUArray())
# err2 = weak_error(probs[i],DRI1NM(),numtraj,Int(1e4),ftrue[i],tsave,abstol=0.1f0,reltol=0.1f0, ensemblealg=EnsembleGPUArray(CUDA.CUDABackend()))
# @show err2
err3 = weak_error(probs[i],DRI1NM(),numtraj,Int(1e4),ftrue[i],tsave,abstol=0.01f0,reltol=0.01f0, ensemblealg=EnsembleGPUArray())
err3 = weak_error(probs[i],DRI1NM(),numtraj,Int(1e4),ftrue[i],tsave,abstol=0.01f0,reltol=0.01f0, ensemblealg=EnsembleGPUArray(CUDA.CUDABackend()))
@show err3
@test err1 > err3
println("")
end


#
# sol = @time solve(probs[1],DRI1NM(),EnsembleGPUArray(),
# sol = @time solve(probs[1],DRI1NM(),EnsembleGPUArray(CUDA.CUDABackend()),
# dt=0.001f0,adaptive=false,abstol=0.1f0,reltol=0.1f0,
# trajectories=numtraj,batch_size=Int(1e1),
# saveat = tsave
Expand Down
2 changes: 1 addition & 1 deletion test/gpu/sde_weak_scalar_adaptive_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ ensembleprob = EnsembleProblem(prob, prob_func = prob_func)
# CUDAnative.CUDAdrv.@profile
# check either on CPU with EnsembleCPUArray() or on GPU with EnsembleGPUArray()
#@test_nowarn sol = @time solve(ensembleprob,DRI1(),EnsembleCPUArray(),trajectories=numtraj)
sol = @time solve(ensembleprob,DRI1NM(),EnsembleGPUArray(),trajectories=numtraj)
sol = @time solve(ensembleprob,DRI1NM(),EnsembleGPUArray(CUDA.CUDABackend()),trajectories=numtraj)
3 changes: 2 additions & 1 deletion test/weak_convergence/weak_srockc2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ println("SROCKC2")
@test abs(sim.𝒪est[:weak_final]-2) < 0.35
#@test abs(sim.𝒪est[:weak_l2]-2) < 0.3
#@test abs(sim.𝒪est[:weak_l∞]-2) < 0.3
sim = nothing

prob = prob_sde_linear

println("SROCKC2")
@time sim = test_convergence(dts,prob,SROCKC2(),save_everystep=false,trajectories=Int(1e7))
@time sim = test_convergence(dts,prob,SROCKC2(),save_everystep=false,trajectories=Int(5e6))
@show sim.𝒪est[:weak_final]
@test abs(sim.𝒪est[:weak_final]-2) < 0.35
#@test abs(sim.𝒪est[:weak_l2]-2) < 0.3
Expand Down

0 comments on commit 76d12b5

Please sign in to comment.