Skip to content

Commit

Permalink
Properly allocate workspace for conv algo search (#643)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Jun 13, 2024
1 parent 51921f7 commit d32722b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 24 deletions.
4 changes: 3 additions & 1 deletion src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,13 @@ function Base.copyto!(
amount == 0 && return dest
@boundscheck checkbounds(dest, d_offset + amount - 1)
@boundscheck checkbounds(source, s_offset + amount - 1)
stm = stream()
Mem.download!(
pointer(dest, d_offset),
Mem.view(convert(Mem.AbstractAMDBuffer, source.buf[]),
(source.offset + s_offset - 1) * sizeof(T)),
amount * sizeof(T); stream=stream(), async)
amount * sizeof(T); stream=stm)
async || synchronize(stm)
dest
end

Expand Down
7 changes: 5 additions & 2 deletions src/dnn/convolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ function get_workspace_size(
wsize_ref = Ref{Csize_t}(0)
get_workspace_size_func(conv_type)(
handle, a_desc.handle, b_desc.handle,
conv_desc.handle, c_desc.handle, wsize_ref) |> check
conv_desc.handle, c_desc.handle, wsize_ref) # NOTE: do not |> check...
wsize_ref[]
end

Expand Down Expand Up @@ -93,9 +93,12 @@ function find_algorithm(
cache = get_benchmark_cache(conv_type, conv_args)
isnothing(cache) || return cache

workspace = ROCArray{UInt8}(undef, 0)
wsize = get_workspace_size(conv_type; handle, a_desc, b_desc, conv_desc, c_desc)
workspace = ROCArray{UInt8}(undef, wsize)
perf_results = find_conv_algo(conv_type;
handle, workspace, a, a_desc, b, b_desc, conv_desc, c, c_desc)
AMDGPU.unsafe_free!(workspace)

set_benchmark_cache!(conv_type, conv_args, perf_results)
workspace = ROCArray{UInt8}(undef, perf_results.memory)

Expand Down
28 changes: 7 additions & 21 deletions src/runtime/memory/hip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,8 @@ function HIPBuffer(bytesize; stream::HIP.HIPStream)
ptr = alloc_or_retry!(isnothing; stream) do
try
# Try to allocate.
# NOTE Async is ~300x slower for small (≤ 16 bytes) allocations:
# https://github.com/ROCm/HIP/issues/3370#issuecomment-1842938966
if bytesize > 16
HIP.hipMallocAsync(ptr_ref, bytesize, stream) |> HIP.check
# HIP.hipMallocFromPoolAsync(ptr_ref, bytesize, pool, stream) |> HIP.check
else
HIP.hipMalloc(ptr_ref, bytesize) |> HIP.check
end
HIP.hipMallocAsync(ptr_ref, bytesize, stream) |> HIP.check
# HIP.hipMallocFromPoolAsync(ptr_ref, bytesize, pool, stream) |> HIP.check

ptr = ptr_ref[]
ptr == C_NULL && throw(HIP.HIPError(HIP.hipErrorOutOfMemory))
Expand Down Expand Up @@ -78,11 +72,7 @@ function free(buf::HIPBuffer; stream::HIP.HIPStream)
buf.own || return

buf.ptr == C_NULL && return
if buf.bytesize > 16
HIP.hipFreeAsync(buf, stream) |> HIP.check
else
HIP.hipFree(buf) |> HIP.check
end
HIP.hipFreeAsync(buf, stream) |> HIP.check
AMDGPU.account!(AMDGPU.memory_stats(buf.device), -buf.bytesize)
return
end
Expand All @@ -93,13 +83,9 @@ function upload!(dst::HIPBuffer, src::Ptr, bytesize::Int; stream::HIP.HIPStream)
return
end

function download!(dst::Ptr, src::HIPBuffer, bytesize::Int; stream::HIP.HIPStream, async::Bool)
function download!(dst::Ptr, src::HIPBuffer, bytesize::Int; stream::HIP.HIPStream)
bytesize == 0 && return
if async
HIP.hipMemcpyDtoHAsync(dst, src, bytesize, stream) |> HIP.check
else
HIP.hipMemcpyDtoH(dst, src, bytesize) |> HIP.check
end
HIP.hipMemcpyDtoHAsync(dst, src, bytesize, stream) |> HIP.check
return
end

Expand Down Expand Up @@ -157,10 +143,10 @@ upload!(dst::HostBuffer, src::Ptr, sz::Int; stream::HIP.HIPStream) =
upload!(dst::HostBuffer, src::HIPBuffer, sz::Int; stream::HIP.HIPStream) =
HIP.memcpy(dst, src, sz, HIP.hipMemcpyDeviceToHost, stream)

download!(dst::Ptr, src::HostBuffer, sz::Int; stream::HIP.HIPStream, async::Bool) =
download!(dst::Ptr, src::HostBuffer, sz::Int; stream::HIP.HIPStream) =
HIP.memcpy(dst, src, sz, HIP.hipMemcpyHostToHost, stream)

download!(dst::HIPBuffer, src::HostBuffer, sz::Int; stream::HIP.HIPStream, async::Bool) =
download!(dst::HIPBuffer, src::HostBuffer, sz::Int; stream::HIP.HIPStream) =
HIP.memcpy(dst, src, sz, HIP.hipMemcpyHostToDevice, stream)

transfer!(dst::HostBuffer, src::HostBuffer, sz::Int; stream::HIP.HIPStream) =
Expand Down

0 comments on commit d32722b

Please sign in to comment.