Skip to content

Commit

Permalink
Generalize to arbitrary args.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Apr 14, 2023
1 parent 6335f86 commit dab7f8a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 10 deletions.
36 changes: 34 additions & 2 deletions src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,41 @@ function generate_opaque_closure(config::CompilerConfig, src::CodeInfo,
return OpaqueClosure{id, typeof(env), sig, rt}(env)
end

# generated function `ccall`, working around the restriction that ccall type
# tuples need to be literals. this relies on ccall internals...
@inline @generated function generated_ccall(f::Ptr, _rettyp, _types, vals...)
ex = quote end

rettyp = _rettyp.parameters[1]
types = _types.parameters[1].parameters
args = [:(vals[$i]) for i in 1:length(vals)]

# cconvert
cconverted = [Symbol("cconverted_$i") for i in 1:length(vals)]
for (dst, typ, src) in zip(cconverted, types, args)
append!(ex.args, (quote
$dst = Base.cconvert($typ, $src)
end).args)
end

# unsafe_convert
unsafe_converted = [Symbol("unsafe_converted_$i") for i in 1:length(vals)]
for (dst, typ, src) in zip(unsafe_converted, types, cconverted)
append!(ex.args, (quote
$dst = Base.unsafe_convert($typ, $src)
end).args)
end

call = Expr(:foreigncall, :f, rettyp, Core.svec(types...), 0,
QuoteNode(:ccall), unsafe_converted..., cconverted...)
push!(ex.args, call)
return ex
end

# device-side call to an opaque closure
function (oc::OpaqueClosure{F})(a, b) where F
function (oc::OpaqueClosure{F,E,A,R})(args...) where {F,E,A,R}
ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), F)
assume(ptr != C_NULL)
return ccall(ptr, Int, (Int, Int), a, b)
#ccall(ptr, R, (A...), args...)
generated_ccall(ptr, R, A, args...)
end
17 changes: 9 additions & 8 deletions test/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1110,21 +1110,22 @@ end

# basic closure, constructed from CodeInfo
let
ir, rettyp = only(Base.code_typed(+, (Int, Int)))
ir, rettyp = only(Base.code_typed(*, (Int, Int, Int)))
oc = CUDA.OpaqueClosure(ir)

c = CuArray([0])
a = CuArray([1])
b = CuArray([2])
d = CuArray([1])
a = CuArray([2])
b = CuArray([3])
c = CuArray([4])

function kernel(oc, c, a, b)
function kernel(oc, d, a, b, c)
i = threadIdx().x
@inbounds c[i] = oc(a[i], b[i])
@inbounds d[i] = oc(a[i], b[i], c[i])
return
end
@cuda threads=1 kernel(oc, c, a, b)
@cuda threads=1 kernel(oc, d, a, b, c)

@test Array(c)[] == 3
@test Array(d)[] == 24
end

end
Expand Down

0 comments on commit dab7f8a

Please sign in to comment.