Skip to content

Commit

Permalink
Add MvNormal sampling consistency test (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen authored Feb 10, 2023
1 parent 86b8285 commit 1e596a4
Showing 1 changed file with 57 additions and 17 deletions.
74 changes: 57 additions & 17 deletions test/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,64 @@ include("test_utils.jl")
end

@testset "MvNormal{T,Pathfinder.WoodburyPDMat{T}}" begin
n = 10
ndraws = 20
nhist = 4
A = rand_pd_diag_mat(Float64, 10)
D = rand_pd_mat(Float64, 2nhist)
B = randn(n, 2nhist)
Σ = Pathfinder.WoodburyPDMat(A, B, D)
μ = randn(n)
dist = MvNormal(μ, Σ)
@testset "basic" begin
n = 10
ndraws = 20
nhist = 4
A = rand_pd_diag_mat(Float64, 10)
D = rand_pd_mat(Float64, 2nhist)
B = randn(n, 2nhist)
Σ = Pathfinder.WoodburyPDMat(A, B, D)
μ = randn(n)
dist = MvNormal(μ, Σ)

seed = 42
rng = Random.seed!(Random.default_rng(), seed)
x, logpx = @inferred Pathfinder.rand_and_logpdf(rng, dist, ndraws)
Random.seed!(rng, seed)
x2 = rand(rng, dist, ndraws)
logpx2 = logpdf(dist, x2)
@test x x2
@test logpx logpx2
seed = 42
rng = Random.seed!(Random.default_rng(), seed)
x, logpx = @inferred Pathfinder.rand_and_logpdf(rng, dist, ndraws)
Random.seed!(rng, seed)
x2 = rand(rng, dist, ndraws)
logpx2 = logpdf(dist, x2)
@test x x2
@test logpx logpx2
end

@testset "consistency of rand" begin
n = 10
ndraws = 300_000
nhist = 4
A = rand_pd_diag_mat(Float64, 10)
D = rand_pd_mat(Float64, 2nhist)
B = randn(n, 2nhist)

Σ = Pathfinder.WoodburyPDMat(A, B, D)
μ = randn(n)
dist = MvNormal(μ, Σ)
v = diag(Σ)
R = Matrix(Σ) ./ sqrt.(v) ./ sqrt.(v')

x = rand(dist, ndraws)
μ_est = mean(x; dims=2)
v_est = var(x; mean=μ_est, dims=2)
R_est = cor(x; dims=2)

nchecks = 2n + div(n * (n - 1), 2)
α = (0.01 / nchecks) / 2 # multiple correction
tol = quantile(Normal(), 1 - α) / sqrt(ndraws)

# asymptotic standard errors for the marginal estimators
μ_std = sqrt.(v)
v_std = sqrt(2) * v

for i in 1:n
@test μ_est[i] μ[i] atol = (tol * μ_std[i])
@test v_est[i] v[i] atol = (tol * v_std[i])
for j in (i + 1):n
# use variance-stabilizing transformation, recommended in §3.6 of
# Van der Vaart, A. W. (2000). Asymptotic statistics (Vol. 3).
@test atanh(R_est[i, j]) atanh(R[i, j]) atol = tol
end
end
end
end

@testset "Normal" begin
Expand Down

0 comments on commit 1e596a4

Please sign in to comment.