Skip to content

Commit

Permalink
[rocSOLVER] Add a test for ormqr!
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Mar 22, 2024
1 parent 2687750 commit 625b00a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/solver/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ for (fname, elty) in (
chkside(side)
chktrans(trans)

m, n = (ndims(C) == 2) ? size(C) : (size(C, 1), 1)
m, n = (ndims(C) == 2) ? size(C) : (length(C), 1)
k = length(τ)
mA = size(A, 1)

Expand Down
26 changes: 25 additions & 1 deletion test/rocarray/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ end
end
end

@testset "geqrf! -- omgqr!" begin
@testset "geqrf! -- orgqr!" begin
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
A = rand(elty, m, n)
dA = ROCArray(A)
Expand All @@ -147,6 +147,30 @@ end
end
end

@testset "ormqr!" begin
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
@testset "side = $side" for side in ['L', 'R']
@testset "trans = $trans" for (trans, op) in [('N', identity), ('T', transpose), ('C', adjoint)]
A = rand(elty, m, n)
dA = ROCArray(A)
dA, dτ = rocSOLVER.geqrf!(dA)

hI = Matrix{elty}(I, m, m)
dI = ROCArray(hI)
dH = rocSOLVER.ormqr!(side, 'N', dA, dτ, dI)
@test dH' * dH I

C = side == 'L' ? rand(elty, m, n) : rand(elty, n, m)
dC = ROCArray(C)
dD = side == 'L' ? op(dH) * dC : dC * op(dH)

ormqr!(side, trans, dA, dτ, dC)
@test dC dD
end
end
end
end

@testset "potrf! -- potrs!" begin
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
A = rand(elty,n,n)
Expand Down

0 comments on commit 625b00a

Please sign in to comment.