Skip to content

Commit

Permalink
Add back matshell
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeremy E Kozdon committed Jul 19, 2021
1 parent 0499a9c commit 1e39b30
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/PETSc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ include("viewer.jl")
include("options.jl")
include("vec.jl")
include("mat.jl")
# include("matshell.jl")
include("matshell.jl")
# include("ksp.jl")
# include("ref.jl")
# include("pc.jl")
Expand Down
107 changes: 68 additions & 39 deletions src/matshell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,58 +3,87 @@
Create a `m×n` PETSc shell matrix object wrapping `obj`.
If `obj` is a `Function`, then the multiply action `obj(y,x)`; otherwise it calls `mul!(y, obj, x)`.
If `obj` is a `Function`, then the multiply action `obj(y,x)`; otherwise it
calls `mul!(y, obj, x)`.
This can be changed by defining `PETSc._mul!`.
# External Links
$(_doc_external("Mat/MATSHELL"))
"""
mutable struct MatShell{T,A} <: AbstractMat{T}
mutable struct MatShell{PetscLib, PetscScalar, OType} <:
AbstractMat{PetscLib, PetscScalar}
ptr::CMat
obj::A
obj::OType
end

struct MatOp{PetscLib, PetscInt, Op} end

function (::MatOp{PetscLib, PetscInt, LibPETSc.MATOP_MULT})(
M::CMat,
cx::CVec,
cy::CVec,
)::PetscInt where {PetscLib, PetscInt}
r_ctx = Ref{Ptr{Cvoid}}()
LibPETSc.MatShellGetContext(PetscLib, M, r_ctx)
ptr = r_ctx[]
mat = unsafe_pointer_to_objref(ptr)

PetscScalar = getlib(PetscLib).PetscScalar
x = unsafe_localarray(WrapVec{PetscLib, PetscScalar}(cx); write = false)
y = unsafe_localarray(WrapVec{PetscLib, PetscScalar}(cy); read = false)

struct MatOp{T,Op} end
_mul!(y, mat, x)

Base.finalize(y)
Base.finalize(x)
return PetscInt(0)
end

function _mul!(y,mat::MatShell{T,F},x) where {T, F<:Function}
function _mul!(
y,
mat::MatShell{PetscLib, PetscScalar, F},
x,
) where {PetscLib, PetscScalar, F <: Function}
mat.obj(y, x)
end

function _mul!(y,mat::MatShell{T},x) where {T}
function _mul!(y, mat::MatShell, x) where {T}
LinearAlgebra.mul!(y, mat.obj, x)
end

MatShell{T}(obj, m, n) where {T} = MatShell{T}(obj, MPI.COMM_SELF, m, n, m, n)


@for_libpetsc begin
function MatShell{$PetscScalar}(obj::A, comm::MPI.Comm, m, n, M, N) where {A}
mat = MatShell{$PetscScalar,A}(C_NULL, obj)
# we use the MatShell object itsel
ctx = pointer_from_objref(mat)
@chk ccall((:MatCreateShell, $libpetsc), PetscErrorCode,
(MPI.MPI_Comm,$PetscInt,$PetscInt,$PetscInt,$PetscInt,Ptr{Cvoid},Ptr{CMat}),
comm, m, n, M, N, ctx, mat)

mulptr = @cfunction(MatOp{$PetscScalar, MATOP_MULT}(), $PetscInt, (CMat, CVec, CVec))
@chk ccall((:MatShellSetOperation, $libpetsc), PetscErrorCode, (CMat, MatOperation, Ptr{Cvoid}), mat, MATOP_MULT, mulptr)
return mat
end

function (::MatOp{$PetscScalar, MATOP_MULT})(M::CMat,cx::CVec,cy::CVec)::$PetscInt
r_ctx = Ref{Ptr{Cvoid}}()
@chk ccall((:MatShellGetContext, $libpetsc), PetscErrorCode, (CMat, Ptr{Ptr{Cvoid}}), M, r_ctx)
ptr = r_ctx[]
mat = unsafe_pointer_to_objref(ptr)

x = unsafe_localarray($PetscScalar, cx; write=false)
y = unsafe_localarray($PetscScalar, cy; read=false)

_mul!(y,mat,x)

Base.finalize(y)
Base.finalize(x)
return $PetscInt(0)
end

# We have to use the macro here because of the @cfunction
LibPETSc.@for_petsc function MatShell(
petsclib::$PetscLib,
obj::OType,
comm::MPI.Comm,
local_rows,
local_cols,
global_rows = LibPETSc.PETSC_DECIDE,
global_cols = LibPETSc.PETSC_DECIDE,
) where {OType}
mat = MatShell{$PetscLib, $PetscScalar, OType}(C_NULL, obj)

# we use the MatShell object itself
ctx = pointer_from_objref(mat)

LibPETSc.MatCreateShell(
petsclib,
comm,
local_rows,
local_cols,
global_rows,
global_cols,
pointer_from_objref(mat),
mat,
)

mulptr = @cfunction(
MatOp{$PetscLib, $PetscInt, LibPETSc.MATOP_MULT}(),
$PetscInt,
(CMat, CVec, CVec)
)
LibPETSc.MatShellSetOperation(petsclib, mat, LibPETSc.MATOP_MULT, mulptr)

return mat
end
5 changes: 5 additions & 0 deletions src/vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ Base.eltype(
) where {PetscLib, PetscScalar} = PetscScalar
Base.size(v::AbstractVec) = (length(v),)

mutable struct WrapVec{PetscLib, PetscScalar} <:
AbstractVec{PetscLib, PetscScalar}
ptr::CVec
end

"""
VecSeq(petsclib, v::Vector)
Expand Down
22 changes: 22 additions & 0 deletions test/matshell.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Test
using PETSc
using MPI

@testset "MatShell" begin
for petsclib in PETSc.petsclibs
PETSc.initialize(petsclib)
PetscScalar = petsclib.PetscScalar

local_rows = 10
local_cols = 5
f!(x, y) = x .= [2y; 3y]
x_jl = collect

matshell =
PETSc.MatShell(petsclib, f!, MPI.COMM_SELF, local_rows, local_cols)
x = PetscScalar.(collect(1:5))
@test matshell * x == [2x; 3x]

PETSc.finalize(petsclib)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ include("init.jl")
include("options.jl")
include("vec.jl")
include("mat.jl")
include("matshell.jl")

0 comments on commit 1e39b30

Please sign in to comment.