From 1e596a41ef4a61510395e9683a42aee8092f4464 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 10 Feb 2023 19:24:43 +0100 Subject: [PATCH] Add MvNormal sampling consistency test (#128) --- test/mvnormal.jl | 74 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 17 deletions(-) diff --git a/test/mvnormal.jl b/test/mvnormal.jl index 5b892870..09f3fb23 100644 --- a/test/mvnormal.jl +++ b/test/mvnormal.jl @@ -47,24 +47,64 @@ include("test_utils.jl") end @testset "MvNormal{T,Pathfinder.WoodburyPDMat{T}}" begin - n = 10 - ndraws = 20 - nhist = 4 - A = rand_pd_diag_mat(Float64, 10) - D = rand_pd_mat(Float64, 2nhist) - B = randn(n, 2nhist) - Σ = Pathfinder.WoodburyPDMat(A, B, D) - μ = randn(n) - dist = MvNormal(μ, Σ) + @testset "basic" begin + n = 10 + ndraws = 20 + nhist = 4 + A = rand_pd_diag_mat(Float64, 10) + D = rand_pd_mat(Float64, 2nhist) + B = randn(n, 2nhist) + Σ = Pathfinder.WoodburyPDMat(A, B, D) + μ = randn(n) + dist = MvNormal(μ, Σ) - seed = 42 - rng = Random.seed!(Random.default_rng(), seed) - x, logpx = @inferred Pathfinder.rand_and_logpdf(rng, dist, ndraws) - Random.seed!(rng, seed) - x2 = rand(rng, dist, ndraws) - logpx2 = logpdf(dist, x2) - @test x ≈ x2 - @test logpx ≈ logpx2 + seed = 42 + rng = Random.seed!(Random.default_rng(), seed) + x, logpx = @inferred Pathfinder.rand_and_logpdf(rng, dist, ndraws) + Random.seed!(rng, seed) + x2 = rand(rng, dist, ndraws) + logpx2 = logpdf(dist, x2) + @test x ≈ x2 + @test logpx ≈ logpx2 + end + + @testset "consistency of rand" begin + n = 10 + ndraws = 300_000 + nhist = 4 + A = rand_pd_diag_mat(Float64, 10) + D = rand_pd_mat(Float64, 2nhist) + B = randn(n, 2nhist) + + Σ = Pathfinder.WoodburyPDMat(A, B, D) + μ = randn(n) + dist = MvNormal(μ, Σ) + v = diag(Σ) + R = Matrix(Σ) ./ sqrt.(v) ./ sqrt.(v') + + x = rand(dist, ndraws) + μ_est = mean(x; dims=2) + v_est = var(x; mean=μ_est, dims=2) + R_est = cor(x; dims=2) + + nchecks = 2n + div(n * (n - 1), 2) + α = (0.01 / nchecks) / 2 # multiple correction + tol = quantile(Normal(), 1 - α) / sqrt(ndraws) + + # asymptotic standard errors for the marginal estimators + μ_std = sqrt.(v) + v_std = sqrt(2) * v + + for i in 1:n + @test μ_est[i] ≈ μ[i] atol = (tol * μ_std[i]) + @test v_est[i] ≈ v[i] atol = (tol * v_std[i]) + for j in (i + 1):n + # use variance-stabilizing transformation, recommended in §3.6 of + # Van der Vaart, A. W. (2000). Asymptotic statistics (Vol. 3). + @test atanh(R_est[i, j]) ≈ atanh(R[i, j]) atol = tol + end + end + end end @testset "Normal" begin