Skip to content

Commit

Permalink
Fix QR based fitting
Browse files Browse the repository at this point in the history
Avoid erroring out for low rank design matrices when dropcollinear=false.
Avoid unnecessary triangular solves. Avoid indexing in the Q.
Avoid slicing R matrix in a way that triggers a minimum norm solution.
Remove unnecessary temporaries.
  • Loading branch information
andreasnoack committed May 7, 2024
1 parent 827a7e2 commit 33d1584
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 50 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ julia = "1.6"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CategoricalArrays", "CSV", "DataFrames", "RDatasets", "StableRNGs", "Test"]
test = ["CategoricalArrays", "CSV", "DataFrames", "Downloads", "RDatasets", "StableRNGs", "Test"]
76 changes: 29 additions & 47 deletions src/linpred.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,19 @@ function delbeta! end

function delbeta!(p::DensePredQR{T,<:QRCompactWY}, r::Vector{T}) where T<:BlasReal
rnk = rank(p.qr.R)
rnk == length(p.delbeta) || throw(RankDeficientException(rnk))
p.delbeta = p.qr\r
mul!(p.scratchm1, Diagonal(ones(size(r))), p.X)
p.delbeta = p.qr \ r
return p
end

function delbeta!(p::DensePredQR{T,<:QRCompactWY}, r::Vector{T}, wt::Vector{T}) where T<:BlasReal
rnk = rank(p.qr.R)
rnk == length(p.delbeta) || throw(RankDeficientException(rnk))
X = p.X
W = Diagonal(wt)
sqrtW = Diagonal(sqrt.(wt))
mul!(p.scratchm1, sqrtW, X)
mul!(p.delbeta, X'W, r)
qnr = qr(p.scratchm1)
Rinv = inv(qnr.R)
p.delbeta = Rinv * Rinv' * p.delbeta
r̃ = sqrtW * r
p.qr = qr!(p.scratchm1)
p.delbeta = p.qr \ r̃
return p
end

Expand All @@ -88,44 +84,32 @@ function delbeta!(p::DensePredQR{T,<:QRPivoted}, r::Vector{T}) where T<:BlasReal
if rnk == length(p.delbeta)
p.delbeta = p.qr\r
else
R = @view p.qr.R[:, 1:rnk]
Q = @view p.qr.Q[:, 1:size(R, 1)]
R = UpperTriangular(view(parent(p.qr.R), 1:rnk, 1:rnk))
piv = p.qr.p
p.delbeta = zeros(size(p.delbeta))
p.delbeta[1:rnk] = R \ Q'r
fill!(p.delbeta, 0)
p.delbeta[1:rnk] = R \ view(p.qr.Q'r, 1:rnk)
invpermute!(p.delbeta, piv)
end
mul!(p.scratchm1, Diagonal(ones(size(r))), p.X)
return p
end

function delbeta!(p::DensePredQR{T,<:QRPivoted}, r::Vector{T}, wt::Vector{T}) where T<:BlasReal
rnk = rank(p.qr.R)
X = p.X
W = Diagonal(wt)
sqrtW = Diagonal(sqrt.(wt))
delbeta = p.delbeta
scratchm2 = similar(X, T)
mul!(p.scratchm1, sqrtW, X)
mul!(scratchm2, W, X)
mul!(delbeta, transpose(scratchm2), r)

if rnk == length(p.delbeta)
qnr = qr(p.scratchm1)
Rinv = inv(qnr.R)
p.delbeta = Rinv * Rinv' * delbeta
else
qnr = pivoted_qr!(copy(p.scratchm1))
R = @view qnr.R[1:rnk, 1:rnk]
Rinv = inv(R)
piv = qnr.p
permute!(delbeta, piv)
for k=(rnk+1):length(delbeta)
delbeta[k] = -zero(T)
end
p.delbeta[1:rnk] = Rinv * Rinv' * view(delbeta, 1:rnk)
invpermute!(delbeta, piv)
r̃ = sqrtW * r

p.qr = pivoted_qr!(copy(p.scratchm1))
rnk = rank(p.qr.R) # FIXME! Don't use svd for this
R = UpperTriangular(view(parent(p.qr.R), 1:rnk, 1:rnk))
permute!(p.delbeta, p.qr.p)
for k = (rnk + 1):length(p.delbeta)
p.delbeta[k] = -zero(T)
end
p.delbeta[1:rnk] = R \ (p.qr.Q'*r̃)[1:rnk]
invpermute!(p.delbeta, p.qr.p)

return p
end

Expand Down Expand Up @@ -279,27 +263,25 @@ end
LinearAlgebra.cholesky(p::SparsePredChol{T}) where {T} = copy(p.chol)
LinearAlgebra.cholesky!(p::SparsePredChol{T}) where {T} = p.chol

function invqr(x::DensePredQR{T,<: QRCompactWY}) where T
Q,R = qr(x.scratchm1)
Rinv = inv(R)
function invqr(p::DensePredQR{T,<: QRCompactWY}) where T
Rinv = inv(p.qr.R)
Rinv*Rinv'
end

function invqr(x::DensePredQR{T,<: QRPivoted}) where T
Q,R,pv = pivoted_qr!(copy(x.scratchm1))
rnk = rank(R)
p = length(x.delbeta)
if rnk == p
Rinv = inv(R)
function invqr(p::DensePredQR{T,<: QRPivoted}) where T
rnk = rank(p.qr.R)
k = length(p.delbeta)
if rnk == k
Rinv = inv(p.qr.R)
xinv = Rinv*Rinv'
ipiv = invperm(pv)
ipiv = invperm(p.qr.p)
return xinv[ipiv, ipiv]
else
Rsub = R[1:rnk, 1:rnk]
Rsub = UpperTriangular(view(p.qr.R, 1:rnk, 1:rnk))
RsubInv = inv(Rsub)
xinv = fill(convert(T, NaN), (p,p))
xinv = fill(convert(T, NaN), (k, k))
xinv[1:rnk, 1:rnk] = RsubInv*RsubInv'
ipiv = invperm(pv)
ipiv = invperm(p.qr.p)
return xinv[ipiv, ipiv]
end
end
Expand Down
21 changes: 19 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ end
@test isa(m2p_dep_pos.pp.chol, CholeskyPivoted)
@test isa(m2p_dep_pos_kw.pp.chol, CholeskyPivoted)
elseif dmethod == :qr
@test_throws RankDeficientException m2 = fit(LinearModel, Xmissingcell, ymissingcell;
method = dmethod, dropcollinear=false)
@test fit(LinearModel, Xmissingcell, ymissingcell;
method = dmethod, dropcollinear=false) isa LinearModel
@test isapprox(coef(m2p), [0.9772643585228962, 11.889730016918342, 3.027347397503282,
3.9661379199401177, 5.079410103608539, 6.194461814118862,
-2.9863884084219015, 7.930328728005132, 8.87999491860477,
Expand Down Expand Up @@ -2055,3 +2055,20 @@ end
# values. It doesn't care about links, offsets, etc. as long as the model matrix,
# vcov matrix and stderrors are well defined.
end

@testset "NIST - Filip. Issue 558" begin
fn = Downloads.download("https://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Filip.dat")
filip_estimates_df = CSV.read(fn, DataFrame; skipto = 31, limit = 11, header = ["parameter", "estimate", "se"], delim = " ", ignorerepeated = true)
filip_data_df = CSV.read(fn, DataFrame; skipto = 61, header = ["y", "x"], delim = " ", ignorerepeated = true)
X = [filip_data_df.x[i]^j for i in 1:length(filip_data_df.x), j in 0:10]

# No weights
f1 = lm(X, filip_data_df.y, dropcollinear = false, method = :qr)
@test coef(f1) ≈ filip_estimates_df.estimate rtol = 1e-8
@test stderror(f1) ≈ filip_estimates_df.se rtol = 1e-7

# Weights
f2 = lm(X, filip_data_df.y, dropcollinear = false, method = :qr, wts = ones(length(filip_data_df.y)))
@test coef(f2) ≈ filip_estimates_df.estimate rtol = 1e-8
@test stderror(f2) ≈ filip_estimates_df.se rtol = 1e-7
end

0 comments on commit 33d1584

Please sign in to comment.