diff --git a/src/fft/fft.jl b/src/fft/fft.jl index 85f9e49c..ea94b574 100644 --- a/src/fft/fft.jl +++ b/src/fft/fft.jl @@ -88,6 +88,7 @@ function update_stream!(plan::ROCFFTPlan) new_stream = AMDGPU.stream() if plan.stream != new_stream plan.stream = new_stream + info = plan.execution_info rocfft_execution_info_set_stream(info, new_stream) end return diff --git a/test/rocarray/fft.jl b/test/rocarray/fft.jl index 626a3bee..70280bf8 100644 --- a/test/rocarray/fft.jl +++ b/test/rocarray/fft.jl @@ -334,4 +334,19 @@ end end end +@testset "Asynchronous" begin + X = rand(Float32, 10, 10) + d_X = ROCArray(X) + + p = plan_rfft(X) + d_p = plan_rfft(d_X) + + Y = p * X + + task = Threads.@spawn d_p * d_X # executes FFT on separate AMDGPU stream + d_Y = fetch(task) + + @test isapprox(collect(d_Y), Y; rtol=MYRTOL, atol=MYATOL) +end + end # testset FFT