Skip to content

Commit

Permalink
Add a dispatch for LinearAlgebra.norm2
Browse files Browse the repository at this point in the history
`norm(@view x[..], 2)` was previously leading to a call of `LinearAlgebra.generic_norm2` which led to a scalar indexing. This catches such cuda subarray norm2 calls earlier.

Inf-norm and p-norm with cuda subarrays still lead to the following dispatches:
```julia
LinearAlgebra.generic_normInf(x) = float(mapreduce(norm, max, x))
LinearAlgebra.generic_norm1(x) = mapreduce(float ∘ norm, +, x)
```
I am not sure if there is a better way to dispatch the above.

should resolve JuliaGPU#2280
  • Loading branch information
sharanry committed Mar 22, 2024
1 parent f5100a1 commit 0e2ef84
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lib/cublas/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasF
end
end

function LinearAlgebra.norm2(x::SubArray{T,N,P} where {T<:Union{Float16, ComplexF16, CublasFloat}, N, P<:DenseCuArray{<:T}})
return nrm2(x)
end

LinearAlgebra.BLAS.asum(x::StridedCuArray{<:CublasFloat}) = asum(length(x), x)

function LinearAlgebra.axpy!(alpha::Number, x::StridedCuArray{T}, y::StridedCuArray{T}) where T<:Union{Float16, ComplexF16, CublasFloat}
Expand Down
11 changes: 11 additions & 0 deletions test/libraries/cublas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1767,6 +1767,17 @@ end
@view(p[reshape(1:(out*inn),out,inn)]) * x
end
end

@testset "nrm2 with strided inputs" begin # JuliaGPU/CUDA.jl#2280
cudaTypes = (Float16, ComplexF16, CublasFloat)
for CT in cudaTypes
x = rand(CT, 10, 10, 10)
dx = CuArray(x)
dx_ = @view dx[3:6, 1:5, :]
x_ = @view x[3:6, 1:5, :]
@test norm(dx_, 2) norm(x_, 2)
end
end
end

############################################################################################
Expand Down

0 comments on commit 0e2ef84

Please sign in to comment.