Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ArrowheadMatrix #21

Merged
merged 18 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"

[weakdeps]
FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756"

[extensions]
FastCholeskyExt = "FastCholesky"

[compat]
Distributions = "0.25"
DomainSets = "0.5.2, 0.6, 0.7"
FastCholesky = "1.3.1"
LinearAlgebra = "1.9"
LoopVectorization = "0.12"
Random = "1.9"
Expand All @@ -34,13 +41,13 @@ julia = "1.9"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "CpuId", "JET", "Test", "ReTestItems", "LinearAlgebra", "StableRNGs", "HCubature"]
test = ["Aqua", "BenchmarkTools", "CpuId", "FastCholesky", "JET", "Test", "ReTestItems", "StableRNGs", "HCubature"]
6 changes: 6 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ BayesBase.weightedmean_cov
BayesBase.weightedmean_invcov
```

## [Extra matrix structures](@id matrix-structures)
```@docs
BayesBase.ArrowheadMatrix
BayesBase.InvArrowheadMatrix
```

## [Helper utilities](@id library-helpers)

```@docs
Expand Down
10 changes: 10 additions & 0 deletions ext/FastCholeskyExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module FastCholeskyExt

using FastCholesky
using BayesBase

function FastCholesky.cholinv(input::ArrowheadMatrix)
return inv(input)
end

end
1 change: 1 addition & 0 deletions src/BayesBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,6 @@ include("densities/samplelist.jl")
include("densities/mixture.jl")
include("densities/factorizedjoint.jl")
include("densities/contingency.jl")
include("algebra/arrowheadmatrix.jl")

end
277 changes: 277 additions & 0 deletions src/algebra/arrowheadmatrix.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
export ArrowheadMatrix, InvArrowheadMatrix


import LinearAlgebra: SingularException
import Base: getindex
import LinearAlgebra: mul!
import Base: size, *, \, inv, convert, Matrix
"""
ArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}

A structure representing an arrowhead matrix, which is a special type of sparse matrix.

# Fields
- `α::T`: The scalar value at the bottom-right corner of the matrix.
- `z::Z`: A vector representing the last row/column (excluding the corner element).
- `D::P`: A vector representing the diagonal elements (excluding the corner element).

# Constructors
ArrowheadMatrix(a::T, z::Z, d::D) where {T,Z,D}

Constructs an `ArrowheadMatrix` with the given α, z, and D values. The output type `O`
is automatically determined as the promoted type of all input elements.

# Operations
- Matrix-vector multiplication: `A * x` or `mul!(y, A, x)`
- Linear system solving: `A \\ b` or `ldiv!(x, A, b)`
- Conversion to dense matrix: `convert(Matrix, A)`
- Inversion: `inv(A)` (returns an `InvArrowheadMatrix`)

# Examples
```julia
α = 2.0
z = [1.0, 2.0, 3.0]
D = [4.0, 5.0, 6.0]
A = ArrowheadMatrix(α, z, D)

# Matrix-vector multiplication
x = [1.0, 2.0, 3.0, 4.0]
y = A * x

# Solving linear system
b = [7.0, 8.0, 9.0, 10.0]
x = A \\ b

# Convert to dense matrix
dense_A = convert(Matrix, A)
```

# Notes
- The matrix is singular if α - dot(z ./ D, z) = 0 or if any element of D is zero.
- For best performance, use `ldiv!` for solving linear systems when possible.
"""
struct ArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}
α::T
z::Z
D::P
end
function ArrowheadMatrix(a::T, z::Z, d::D) where {T,Z,D}
O = promote_type(typeof(a), eltype(z), eltype(d))
return ArrowheadMatrix{O, T, Z, D}(a, z, d)
end

function show(io::IO, ::MIME"text/plain", A::ArrowheadMatrix)
n = length(A.D) + 1
println(io, n, "×", n, " ArrowheadMatrix{", eltype(A), "}:")

for i in 1:n-1
for j in 1:n-1
if i == j
print(io, A.D[i])
else
print(io, "⋅")
end
print(io, " ")
end
println(io, A.z[i])
end

# Print the last row
for i in 1:n-1
print(io, A.z[i], " ")
end
println(io, A.α)
end

function size(A::ArrowheadMatrix)
n = length(A.D) + 1
return (n, n)
end

function Base.convert(::Type{Matrix}, A::ArrowheadMatrix{O}) where {O}
n = length(A.z)
M = zeros(O, n + 1, n + 1)
for i in 1:n
M[i, i] = A.D[i]
end
M[1:n, n + 1] .= A.z
M[n + 1, 1:n] .= A.z
M[n + 1, n + 1] = A.α
return M
end

function LinearAlgebra.mul!(y, A::ArrowheadMatrix{T}, x::AbstractVector{T}) where T
n = length(A.z)
if length(x) != n + 1
throw(DimensionMismatch())
end
@inbounds @views begin
y[1:n] = A.D .* x[1:n] + A.z * x[n + 1]
y[n + 1] = dot(A.z, x[1:n]) + A.α * x[n + 1]
end
return y
end

function linsolve!(y::AbstractVector{T2}, A::ArrowheadMatrix{T}, b::AbstractVector{T2}) where {T, T2}
n = length(A.z)

if length(b) != n + 1
throw(DimensionMismatch())
end

z = A.z
D = A.D
α = A.α

# Check for zeros in D to avoid division by zero
@inbounds for i in 1:n
if D[i] == 0
throw(SingularException(1))
end
end

s = zero(T)
t = zero(T)

# Compute s and t in a single loop to avoid recomputing z[i] / D[i]
@inbounds @simd for i in 1:n
zi = z[i]
Di = D[i]
z_div_D = zi / Di
bi = b[i]

s += z_div_D * bi # Accumulate s
t += z_div_D * zi # Accumulate t
end

denom = α - t
if denom == 0
throw(SingularException(1))
end

yn1 = (b[n + 1] - s) / denom
y[n + 1] = yn1

# Compute y[1:n]
@inbounds @simd for i in 1:n
y[i] = (b[i] - z[i] * yn1) / D[i]
end

return y
end

function Base.:\(A::ArrowheadMatrix, b::AbstractVector{T}) where T
y = similar(b)
return linsolve!(y, A, b)
end

function LinearAlgebra.ldiv!(x::AbstractVector{T}, A::ArrowheadMatrix, b::AbstractVector{T}) where T
return linsolve!(x, A, b)
end

"""
InvArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}

A wrapper structure representing the inverse of an `ArrowheadMatrix`.

This structure doesn't explicitly compute or store the inverse matrix.
Instead, it stores a reference to the original `ArrowheadMatrix` and
implements efficient operations that leverage the special structure
of the arrowhead matrix.

# Fields
- `A::ArrowheadMatrix{O, T, Z, P}`: The original `ArrowheadMatrix` being inverted.

# Constructors
InvArrowheadMatrix(A::ArrowheadMatrix{O, T, Z, P})

Constructs an `InvArrowheadMatrix` by wrapping the given `ArrowheadMatrix`.

# Operations
- Matrix-vector multiplication: `A_inv * x` or `mul!(y, A_inv, x)`
(Equivalent to solving the system A * y = x)
- Linear system solving: `A_inv \\ x`
(Equivalent to multiplication by the original matrix: A * x)
- Conversion to dense matrix: `convert(Matrix, A_inv)`
(Computes and returns the actual inverse as a dense matrix)

# Examples
```julia
α = 2.0
z = [1.0, 2.0, 3.0]
D = [4.0, 5.0, 6.0]
A = ArrowheadMatrix(α, z, D)
A_inv = inv(A) # Returns an InvArrowheadMatrix

# Multiplication (equivalent to solving A * y = x)
x = [1.0, 2.0, 3.0, 4.0]
y = A_inv * x

# Division (equivalent to multiplying by A)
b = [5.0, 6.0, 7.0, 8.0]
x = A_inv \\ b

# Convert to dense inverse matrix
dense_inv_A = convert(Matrix, A_inv)
```

# Notes
- The inverse exists only if the original `ArrowheadMatrix` is non-singular.
- Operations with `InvArrowheadMatrix` do not explicitly compute the inverse,
but instead solve the corresponding system with the original matrix.

# See Also
- [`ArrowheadMatrix`](@ref): The original arrowhead matrix structure.
"""
struct InvArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}
A::ArrowheadMatrix{O, T, Z, P}
end

function show(io::IO, ::MIME"text/plain", A_inv::InvArrowheadMatrix)
n = size(A_inv.A, 1)
println(io, n, "×", n, " InvArrowheadMatrix{", eltype(A_inv), "}:")
println(io, "Inverse of:")
show(io, MIME"text/plain"(), A_inv.A)
end


inv(A::ArrowheadMatrix) = InvArrowheadMatrix(A)

function size(A_inv::InvArrowheadMatrix)
size(A_inv.A)
end

function LinearAlgebra.mul!(y, A_inv::InvArrowheadMatrix{T}, x::AbstractVector{T}) where T
A = A_inv.A
return linsolve!(y, A, x)
end

function Base.:\(A_inv::InvArrowheadMatrix{T}, x::AbstractVector{T}) where T
A = A_inv.A
return A * x
end

function Base.convert(::Type{Matrix}, A_inv::InvArrowheadMatrix{T}) where T
A = A_inv.A
n = length(A.z)
z = A.z
D = A.D
α = A.α

# Compute t = dot(z ./ D, z)
t = dot(z ./ D, z)
denom = α - t
@assert denom != 0 "Matrix is singular."

# Compute u = [ (z ./ D); -1 ]
u = [ z ./ D; -1.0 ]

# Compute the inverse diagonal elements
D_inv = 1.0 ./ D

# Construct the inverse matrix
M = zeros(T, n + 1, n + 1)
M[1:n, 1:n] .= Diagonal(D_inv)
M .+= (u * u') / denom
return M
end
3 changes: 3 additions & 0 deletions src/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ sampletype(::Type{Univariate}, distribution) = eltype(distribution)
sampletype(::Type{Multivariate}, distribution) = Vector{eltype(distribution)}
sampletype(::Type{Matrixvariate}, distribution) = Matrix{eltype(distribution)}

# Exceptions
sampletype(::Gamma{T}) where {T} = T

"""
samplefloattype(distribution)

Expand Down
2 changes: 2 additions & 0 deletions test/algebra/algebrasetup_setuptests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
using BenchmarkTools, LinearAlgebra
using BayesBase
Loading
Loading