Skip to content

Commit

Permalink
Merge pull request #21 from ReactiveBayes/add-arrow-matrix
Browse files Browse the repository at this point in the history
feat: add ArrowheadMatrix
  • Loading branch information
bvdmitri authored Oct 11, 2024
2 parents dcc23ff + 80c29f4 commit 535238a
Show file tree
Hide file tree
Showing 7 changed files with 598 additions and 3 deletions.
13 changes: 10 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"

[weakdeps]
FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756"

[extensions]
FastCholeskyExt = "FastCholesky"

[compat]
Distributions = "0.25"
DomainSets = "0.5.2, 0.6, 0.7"
FastCholesky = "1.3.1"
LinearAlgebra = "1.9"
LoopVectorization = "0.12"
Random = "1.9"
Expand All @@ -34,13 +41,13 @@ julia = "1.9"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "CpuId", "JET", "Test", "ReTestItems", "LinearAlgebra", "StableRNGs", "HCubature"]
test = ["Aqua", "BenchmarkTools", "CpuId", "FastCholesky", "JET", "Test", "ReTestItems", "StableRNGs", "HCubature"]
6 changes: 6 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ BayesBase.weightedmean_cov
BayesBase.weightedmean_invcov
```

## [Extra matrix structures](@id matrix-structures)
```@docs
BayesBase.ArrowheadMatrix
BayesBase.InvArrowheadMatrix
```

## [Helper utilities](@id library-helpers)

```@docs
Expand Down
10 changes: 10 additions & 0 deletions ext/FastCholeskyExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module FastCholeskyExt

using FastCholesky
using BayesBase

function FastCholesky.cholinv(input::ArrowheadMatrix)
return inv(input)
end

end
1 change: 1 addition & 0 deletions src/BayesBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,6 @@ include("densities/samplelist.jl")
include("densities/mixture.jl")
include("densities/factorizedjoint.jl")
include("densities/contingency.jl")
include("algebra/arrowheadmatrix.jl")

end
277 changes: 277 additions & 0 deletions src/algebra/arrowheadmatrix.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
export ArrowheadMatrix, InvArrowheadMatrix


import LinearAlgebra: SingularException
import Base: getindex
import LinearAlgebra: mul!
import Base: size, *, \, inv, convert, Matrix
"""
ArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}
A structure representing an arrowhead matrix, which is a special type of sparse matrix.
# Fields
- `α::T`: The scalar value at the bottom-right corner of the matrix.
- `z::Z`: A vector representing the last row/column (excluding the corner element).
- `D::P`: A vector representing the diagonal elements (excluding the corner element).
# Constructors
ArrowheadMatrix(a::T, z::Z, d::D) where {T,Z,D}
Constructs an `ArrowheadMatrix` with the given α, z, and D values. The output type `O`
is automatically determined as the promoted type of all input elements.
# Operations
- Matrix-vector multiplication: `A * x` or `mul!(y, A, x)`
- Linear system solving: `A \\ b` or `ldiv!(x, A, b)`
- Conversion to dense matrix: `convert(Matrix, A)`
- Inversion: `inv(A)` (returns an `InvArrowheadMatrix`)
# Examples
```julia
α = 2.0
z = [1.0, 2.0, 3.0]
D = [4.0, 5.0, 6.0]
A = ArrowheadMatrix(α, z, D)
# Matrix-vector multiplication
x = [1.0, 2.0, 3.0, 4.0]
y = A * x
# Solving linear system
b = [7.0, 8.0, 9.0, 10.0]
x = A \\ b
# Convert to dense matrix
dense_A = convert(Matrix, A)
```
# Notes
- The matrix is singular if α - dot(z ./ D, z) = 0 or if any element of D is zero.
- For best performance, use `ldiv!` for solving linear systems when possible.
"""
struct ArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}
α::T
z::Z
D::P
end
function ArrowheadMatrix(a::T, z::Z, d::D) where {T,Z,D}
O = promote_type(typeof(a), eltype(z), eltype(d))
return ArrowheadMatrix{O, T, Z, D}(a, z, d)
end

function show(io::IO, ::MIME"text/plain", A::ArrowheadMatrix)
n = length(A.D) + 1
println(io, n, "×", n, " ArrowheadMatrix{", eltype(A), "}:")

for i in 1:n-1
for j in 1:n-1
if i == j
print(io, A.D[i])
else
print(io, "")
end
print(io, " ")
end
println(io, A.z[i])
end

# Print the last row
for i in 1:n-1
print(io, A.z[i], " ")
end
println(io, A.α)
end

function size(A::ArrowheadMatrix)
n = length(A.D) + 1
return (n, n)
end

function Base.convert(::Type{Matrix}, A::ArrowheadMatrix{O}) where {O}
n = length(A.z)
M = zeros(O, n + 1, n + 1)
for i in 1:n
M[i, i] = A.D[i]
end
M[1:n, n + 1] .= A.z
M[n + 1, 1:n] .= A.z
M[n + 1, n + 1] = A.α
return M
end

function LinearAlgebra.mul!(y, A::ArrowheadMatrix{T}, x::AbstractVector{T}) where T
n = length(A.z)
if length(x) != n + 1
throw(DimensionMismatch())
end
@inbounds @views begin
y[1:n] = A.D .* x[1:n] + A.z * x[n + 1]
y[n + 1] = dot(A.z, x[1:n]) + A.α * x[n + 1]
end
return y
end

function linsolve!(y::AbstractVector{T2}, A::ArrowheadMatrix{T}, b::AbstractVector{T2}) where {T, T2}
n = length(A.z)

if length(b) != n + 1
throw(DimensionMismatch())
end

z = A.z
D = A.D
α = A.α

# Check for zeros in D to avoid division by zero
@inbounds for i in 1:n
if D[i] == 0
throw(SingularException(1))
end
end

s = zero(T)
t = zero(T)

# Compute s and t in a single loop to avoid recomputing z[i] / D[i]
@inbounds @simd for i in 1:n
zi = z[i]
Di = D[i]
z_div_D = zi / Di
bi = b[i]

s += z_div_D * bi # Accumulate s
t += z_div_D * zi # Accumulate t
end

denom = α - t
if denom == 0
throw(SingularException(1))
end

yn1 = (b[n + 1] - s) / denom
y[n + 1] = yn1

# Compute y[1:n]
@inbounds @simd for i in 1:n
y[i] = (b[i] - z[i] * yn1) / D[i]
end

return y
end

function Base.:\(A::ArrowheadMatrix, b::AbstractVector{T}) where T
y = similar(b)
return linsolve!(y, A, b)
end

function LinearAlgebra.ldiv!(x::AbstractVector{T}, A::ArrowheadMatrix, b::AbstractVector{T}) where T
return linsolve!(x, A, b)
end

"""
InvArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}
A wrapper structure representing the inverse of an `ArrowheadMatrix`.
This structure doesn't explicitly compute or store the inverse matrix.
Instead, it stores a reference to the original `ArrowheadMatrix` and
implements efficient operations that leverage the special structure
of the arrowhead matrix.
# Fields
- `A::ArrowheadMatrix{O, T, Z, P}`: The original `ArrowheadMatrix` being inverted.
# Constructors
InvArrowheadMatrix(A::ArrowheadMatrix{O, T, Z, P})
Constructs an `InvArrowheadMatrix` by wrapping the given `ArrowheadMatrix`.
# Operations
- Matrix-vector multiplication: `A_inv * x` or `mul!(y, A_inv, x)`
(Equivalent to solving the system A * y = x)
- Linear system solving: `A_inv \\ x`
(Equivalent to multiplication by the original matrix: A * x)
- Conversion to dense matrix: `convert(Matrix, A_inv)`
(Computes and returns the actual inverse as a dense matrix)
# Examples
```julia
α = 2.0
z = [1.0, 2.0, 3.0]
D = [4.0, 5.0, 6.0]
A = ArrowheadMatrix(α, z, D)
A_inv = inv(A) # Returns an InvArrowheadMatrix
# Multiplication (equivalent to solving A * y = x)
x = [1.0, 2.0, 3.0, 4.0]
y = A_inv * x
# Division (equivalent to multiplying by A)
b = [5.0, 6.0, 7.0, 8.0]
x = A_inv \\ b
# Convert to dense inverse matrix
dense_inv_A = convert(Matrix, A_inv)
```
# Notes
- The inverse exists only if the original `ArrowheadMatrix` is non-singular.
- Operations with `InvArrowheadMatrix` do not explicitly compute the inverse,
but instead solve the corresponding system with the original matrix.
# See Also
- [`ArrowheadMatrix`](@ref): The original arrowhead matrix structure.
"""
struct InvArrowheadMatrix{O, T, Z, P} <: AbstractMatrix{O}
A::ArrowheadMatrix{O, T, Z, P}
end

function show(io::IO, ::MIME"text/plain", A_inv::InvArrowheadMatrix)
n = size(A_inv.A, 1)
println(io, n, "×", n, " InvArrowheadMatrix{", eltype(A_inv), "}:")
println(io, "Inverse of:")
show(io, MIME"text/plain"(), A_inv.A)
end


inv(A::ArrowheadMatrix) = InvArrowheadMatrix(A)

function size(A_inv::InvArrowheadMatrix)
size(A_inv.A)
end

function LinearAlgebra.mul!(y, A_inv::InvArrowheadMatrix{T}, x::AbstractVector{T}) where T
A = A_inv.A
return linsolve!(y, A, x)
end

function Base.:\(A_inv::InvArrowheadMatrix{T}, x::AbstractVector{T}) where T
A = A_inv.A
return A * x
end

function Base.convert(::Type{Matrix}, A_inv::InvArrowheadMatrix{T}) where T
A = A_inv.A
n = length(A.z)
z = A.z
D = A.D
α = A.α

# Compute t = dot(z ./ D, z)
t = dot(z ./ D, z)
denom = α - t
@assert denom != 0 "Matrix is singular."

# Compute u = [ (z ./ D); -1 ]
u = [ z ./ D; -1.0 ]

# Compute the inverse diagonal elements
D_inv = 1.0 ./ D

# Construct the inverse matrix
M = zeros(T, n + 1, n + 1)
M[1:n, 1:n] .= Diagonal(D_inv)
M .+= (u * u') / denom
return M
end
2 changes: 2 additions & 0 deletions test/algebra/algebrasetup_setuptests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
using BenchmarkTools, LinearAlgebra
using BayesBase
Loading

0 comments on commit 535238a

Please sign in to comment.