Skip to content

Commit

Permalink
Merge pull request #25 from ReactiveBayes/implement-getindex-Arrowhea…
Browse files Browse the repository at this point in the history
…dMatrix

feat: implement getindex method for ArrowheadMatrix
  • Loading branch information
Nimrais authored Oct 22, 2024
2 parents 741a8c7 + d8def25 commit 7eab0f9
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BayesBase"
uuid = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
authors = ["Bagaev Dmitry <[email protected]> and contributors"]
version = "1.4.0"
version = "1.5.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
41 changes: 40 additions & 1 deletion src/algebra/arrowheadmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ function ArrowheadMatrix(a::T, z::Z, d::D) where {T,Z,D}
return ArrowheadMatrix{O, T, Z, D}(a, z, d)
end

function Base.getindex(A::ArrowheadMatrix, i::Int, j::Int)

@warn "getindex was called on ArrowheadMatrix. This may lead to suboptimal performance. Consider using specialized methods if available." maxlog=1

n = length(A.D) + 1
if i < 1 || i > n || j < 1 || j > n
throw(BoundsError(A, (i, j)))
end

if i == n && j == n
return A.α
elseif i == n
return A.z[j]
elseif j == n
return A.z[i]
elseif i == j
return A.D[i]
else
return zero(eltype(A))
end
end

function show(io::IO, ::MIME"text/plain", A::ArrowheadMatrix)
n = length(A.D) + 1
println(io, n, "×", n, " ArrowheadMatrix{", eltype(A), "}:")
Expand Down Expand Up @@ -274,4 +296,21 @@ function Base.convert(::Type{Matrix}, A_inv::InvArrowheadMatrix{T}) where T
M[1:n, 1:n] .= Diagonal(D_inv)
M .+= (u * u') / denom
return M
end
end

function LinearAlgebra.dot(x::AbstractVector, A_inv::InvArrowheadMatrix, y::AbstractVector)
A = A_inv.A
n = length(A.z)

if length(x) != n + 1 || length(y) != n + 1
throw(DimensionMismatch("Dimensions must match"))
end

# Compute A_inv * y using linsolve!
temp = similar(y)
linsolve!(temp, A, y)

# Compute the dot product of x and temp
return LinearAlgebra.dot(x, temp)
end

97 changes: 97 additions & 0 deletions test/algebra/arrowheadmatrix_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,5 +288,102 @@ end
# Test linear solve with vector of incorrect size
invalid_b = randn(n+2)
test_error_consistency(A_arrow, A_dense, A -> A \ invalid_b)

# Test BoundsError consistency
test_error_consistency(A_arrow, A_dense, A -> A[n+2, n+2])
test_error_consistency(A_arrow, A_dense, A -> A[0, 1])
test_error_consistency(A_arrow, A_dense, A -> A[1, 0])
test_error_consistency(A_arrow, A_dense, A -> A[-1, -1])

#Test ≈ error
test_error_consistency(A_arrow, A_dense, A -> A zeros(n+1, n+1))

#Test matmul error
test_error_consistency(A_arrow, A_dense, A -> A * zeros(n+1, n+1))
test_error_consistency(A_arrow, A_dense, A -> zeros(n+1, n+1) * A)

#Test dot (x, inv(A), y)
test_error_consistency(A_arrow, A_dense, A -> dot(zeros(n+1), inv(A), zeros(n)))
test_error_consistency(A_arrow, A_dense, A -> dot(zeros(n), inv(A), zeros(n+1)))
end
end

@testitem "ArrowheadMatrix getindex based methods: matmul and ≈" begin
include("algebrasetup_setuptests.jl")

@testset "ArrowheadMatrix: matmul" begin
for n in [3, 5, 10]
α = randn()
z = randn(n)
D = randn(n)
A = ArrowheadMatrix(α, z, D)

B = randn(n+1, n+1)

C_right = A * B
C_right_dense = convert(Matrix, A) * B
@test C_right C_right_dense

C_left = B * A
C_left_dense = B * convert(Matrix, A)
@test C_left C_left_dense

# Check that the result is a dense matrix
@test typeof(C_right) <: Matrix
@test typeof(C_left) <: Matrix
end
end

@testset "ArrowheadMatrix: ≈" begin
for n in [3, 5, 10]
α = randn()
z = randn(n)
D = randn(n)
A = ArrowheadMatrix(α, z, D)
dense_A = convert(Matrix, A)
@test A dense_A
end
end

end

@testitem "ArrowheadMatrix: getindex with Warning" begin
include("algebrasetup_setuptests.jl")

α = 2.0
z = [1.0, 2.0, 3.0]
D = [4.0, 5.0, 6.0]
A = ArrowheadMatrix(α, z, D)

# Test that the warning is shown only once
@test_logs (:warn, "getindex was called on ArrowheadMatrix. This may lead to suboptimal performance. Consider using specialized methods if available.") begin
@test A[1,1] == 4.0
@test A[2,2] == 5.0
@test A[3,3] == 6.0
@test A[4,4] == 2.0
end
end

@testitem "InvArrowheadMatrix: dot(x, A, y) comparison with dense matrix" begin
using LinearAlgebra
include("algebrasetup_setuptests.jl")

for n in [3, 5, 10]
α = rand() + n
z = randn(n)
D = rand(n) .+ n

A = ArrowheadMatrix(α, z, D)
A_inv = inv(A)

x = randn(n + 1)
y = randn(n + 1)

result_arrowhead = dot(x, A_inv, y)
A_dense = Matrix(A)
A_inv_dense = inv(A_dense)
result_dense = dot(x, A_inv_dense * y)

@test isapprox(result_arrowhead, result_dense, rtol=1e-5)
end
end

0 comments on commit 7eab0f9

Please sign in to comment.