Merge pull request #25 from ElOceanografo/di
Calculate sparse Hessians using DifferentiationInterface
ElOceanografo authored Jun 5, 2024
2 parents f90c410 + 9d937fe commit c2c2c74
Showing 4 changed files with 209 additions and 197 deletions.
30 changes: 18 additions & 12 deletions Project.toml
@@ -1,41 +1,47 @@
name = "MarginalLogDensities"
uuid = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
authors = ["Sam Urmy <[email protected]>"]
version = "0.2.1"
version = "0.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"

[compat]
ADTypes = "1"
ChainRulesCore = "1"
DifferentiationInterface = "0.5.3"
Distributions = "0.25"
FiniteDiff = "2"
ForwardDiff = "0.10"
HCubature = "1"
Optim = "1"
Optimization = "3"
OptimizationOptimJL = "0.1, 0.2, 0.3"
LinearAlgebra = "1"
Optimization = "3.25"
OptimizationOptimJL = "0.3"
Reexport = "1"
ReverseDiff = "1"
SparseDiffTools = "2"
SparseArrays = "1"
SparseConnectivityTracer = "0.5"
Zygote = "0.6"
julia = "1.7"
julia = "1.9"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Random", "Test", "ChainRulesTestUtils"]
test = ["ChainRulesTestUtils", "Distributions", "Random", "ReverseDiff", "Test", "Zygote"]
204 changes: 109 additions & 95 deletions src/MarginalLogDensities.jl
@@ -1,15 +1,17 @@
module MarginalLogDensities
using ForwardDiff, FiniteDiff, ReverseDiff, Zygote

using Reexport
using Optimization
using OptimizationOptimJL
import ForwardDiff, FiniteDiff
@reexport using DifferentiationInterface
@reexport using ADTypes
@reexport using SparseConnectivityTracer
using LinearAlgebra
using SparseArrays
using ChainRulesCore
using HCubature
using Distributions
# using SparsityDetection
# using Symbolics
using SparseDiffTools
# using Distributions

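Since ADTypes, DifferentiationInterface, and SparseConnectivityTracer are now `@reexport`ed, a plain `using MarginalLogDensities` puts their names in scope for building custom sparse backends. A hedged sketch of what that enables (the particular combination below is illustrative):

```julia
using MarginalLogDensities

# All of these names come from the reexported packages above:
backend = AutoSparse(
    SecondOrder(AutoFiniteDiff(), AutoForwardDiff());  # DifferentiationInterface + ADTypes
    sparsity_detector = TracerSparsityDetector(),      # SparseConnectivityTracer
    coloring_algorithm = GreedyColoringAlgorithm(),    # compresses the sparse Hessian
)
```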
export MarginalLogDensity,
AbstractMarginalizer,
@@ -24,19 +26,13 @@ export MarginalLogDensity,
merge_parameters,
split_parameters,
optimize_marginal!,
hessdiag,
# hessdiag,
get_hessian_sparsity

# can't seem to precompile these functions
# auto_ad_hess(x::Optimization.AutoFiniteDiff) = FiniteDiff.finite_difference_hessian!
# auto_ad_hess(x::Optimization.AutoForwardDiff) = ForwardDiff.hessian!
# auto_ad_hess(x::Optimization.AutoReverseDiff) = ReverseDiff.hessian!
# auto_ad_hess(x::Optimization.AutoZygote) = (H, f, x) -> first(Zygote.hessian!(H, f, x))

abstract type AbstractMarginalizer end

"""
`LaplaceApprox([solver=LBFGS() [; adtype=Optimization.AutoForwardDiff(), opt_func_kwargs...]])
`LaplaceApprox([solver=LBFGS() [; adtype=AutoForwardDiff(), opt_func_kwargs...]])`
Construct a `LaplaceApprox` marginalizer to integrate out marginal variables via
the Laplace approximation. This method will usually be faster than `Cubature`, especially
@@ -45,25 +41,26 @@ in high dimensions, though it may not be as accurate.
# Arguments
- `solver=LBFGS()` : Algorithm to use when performing the inner optimization to find the
mode of the marginalized variables. Can be any algorithm defined in Optim.jl.
- `adtype=Optimization.AutoForwardDiff()` : Automatic differentiation type to use for the
inner optimization. `AutoForwardDiff()` is robust and fast for small problems; for larger
ones `AutoReverseDiff()` or `AutoZygote()` are likely better.
- `opt_func_kwargs` : Optional keyword arguments passed on to `Optimization.OptimizationFunction`.
- `adtype=AutoForwardDiff()` : Automatic differentiation type to use for the inner
optimization, specified via the ADTypes.jl interface. `AutoForwardDiff()` is robust and
fast for small problems; for larger ones `AutoReverseDiff()` or `AutoZygote()` are likely
better.
- `opt_func_kwargs` : Optional keyword arguments passed on to
`Optimization.OptimizationFunction`.
"""
struct LaplaceApprox{TA, TT, TS} <: AbstractMarginalizer
# sparsehess::Bool
solver::TS
adtype::TA
opt_func_kwargs::TT
end

function LaplaceApprox(solver=LBFGS(); adtype=Optimization.AutoForwardDiff(),
function LaplaceApprox(solver=LBFGS(); adtype=AutoForwardDiff(),
opt_func_kwargs...)
return LaplaceApprox(solver, adtype, opt_func_kwargs)
end
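A brief usage sketch (the solver/backend pairings are illustrative; `LBFGS` comes from OptimizationOptimJL):

```julia
using MarginalLogDensities, OptimizationOptimJL

# Default: LBFGS with forward-mode AD for the inner optimization.
marg = LaplaceApprox()

# For a hypothetical high-dimensional problem, reverse mode may scale better
# (the ReverseDiff package must be loaded for this backend to work):
using ReverseDiff
marg_rev = LaplaceApprox(LBFGS(); adtype=AutoReverseDiff())
```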

"""
Cubature([; solver=LBFGS(), adtype=Optimization.AutoForwardDiff(),
Cubature([; solver=LBFGS(), adtype=AutoForwardDiff(),
upper=nothing, lower=nothing, nσ=6.0, opt_func_kwargs...])`
Construct a `Cubature` marginalizer to integrate out marginal variables via
@@ -79,14 +76,16 @@ The integration is performed using `hcubature` from HCubature.jl.
# Arguments
- `solver=LBFGS()` : Algorithm to use when performing the inner optimization to find the
mode of the marginalized variables. Can be any algorithm defined in Optim.jl.
- `adtype=Optimization.AutoForwardDiff()` : Automatic differentiation type to use for the
- `adtype=AutoForwardDiff()` : Automatic differentiation type to use for the
inner optimization. `AutoForwardDiff()` is robust and fast for small problems; for larger
ones `AutoReverseDiff()` or `AutoZygote()` are likely better.
- `upper`, `lower` : Optional upper and lower bounds for the numerical integration. If supplied,
they must be numeric vectors the same length as the marginal variables.
- `nσ=6.0` : If `upper` and `lower` are not supplied, integrate this many standard deviations
away from the mode based on a Laplace approximation to the curvature there.
- `opt_func_kwargs` : Optional keyword arguments passed on to `Optimization.OptimizationFunction`.
- `upper`, `lower` : Optional upper and lower bounds for the numerical integration. If
supplied, they must be numeric vectors the same length as the marginal variables.
- `nσ=6.0` : If `upper` and `lower` are not supplied, integrate this many standard
deviations away from the mode based on a Laplace approximation to the curvature at that
point.
- `opt_func_kwargs` : Optional keyword arguments passed on to
`Optimization.OptimizationFunction`.
"""
struct Cubature{TA, TT, TS, TV, T} <: AbstractMarginalizer
solver::TS
@@ -97,17 +96,24 @@ struct Cubature{TA, TT, TS, TV, T} <: AbstractMarginalizer
nσ::T
end

function Cubature(; solver=LBFGS(), adtype=Optimization.AutoForwardDiff(),
function Cubature(; solver=LBFGS(), adtype=AutoForwardDiff(),
upper=nothing, lower=nothing, nσ=6.0, opt_func_kwargs...)
return Cubature(solver, adtype, opt_func_kwargs, promote(upper, lower)..., nσ)
end
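A sketch of the two ways to set the integration bounds (values illustrative):

```julia
using MarginalLogDensities

# Explicit box for a two-dimensional marginal:
marg_box = Cubature(lower=[-5.0, -5.0], upper=[5.0, 5.0])

# Or let the bounds default to mode ± nσ standard deviations (here 4 instead of 6):
marg_auto = Cubature(nσ=4.0)
```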


"""
`MarginalLogDensity(logdensity, u, iw, data, [method=LaplaceApprox()])`
`MarginalLogDensity(logdensity, u, iw, data, [method=LaplaceApprox();
[hess_adtype=nothing, sparsity_detector=DenseSparsityDetector(method.adtype, atol=sqrt(eps())),
coloring_algorithm=GreedyColoringAlgorithm()]])`
Construct a callable object which wraps the function `logdensity` and
integrates over a subset of its arguments.
The resulting `MarginalLogDensity` object `mld` can then be called like a function
as `mld(v, data)`, where `v` is the subset of the full parameter vector `u` which is
*not* indexed by `iw`. If `length(u) == n` and `length(iw) == m`, then `length(v) == n-m`.
# Arguments
- `logdensity` : function with signature `(u, data)` returning a log-probability density
with positive sign, i.e. not negated (e.g. a log-pdf, log-likelihood, or log-posterior). In this
@@ -118,15 +124,19 @@ NamedTuple, or whatever) that contains data and/or fixed parameters.
- `data=()` : Optional object containing data and/or fixed parameters, passed as the
second argument to `logdensity`.
- `method` : How to perform the marginalization. Defaults to `LaplaceApprox()`; `Cubature()`
is also available.
- `hess_autosparse=:none` : Specifies how to detect sparsity in the Hessian matrix of
`logdensity`. Can be `:none`, `:finitediff`, `:forwarddiff`, or `:sparsitydetection`.
If `:none` (the default), the Hessian is assumed dense and calculated using `ForwardDiff`.
Detecting sparsity takes some time and may not be worth it for small problems, but for
larger problems it can be extremely worth it.
- `hess_adtype = nothing` : Specifies how to calculate the Hessian of the marginalized
variables. If not specified, defaults to a sparse second-order method using finite
differences over the AD type given in the `method` (`AutoForwardDiff()` is the default).
Other backends can be set by loading the appropriate AD package and using the ADTypes.jl
interface.
- `sparsity_detector = DenseSparsityDetector(method.adtype, atol=sqrt(eps()))` : How to
perform the sparsity detection. Detecting sparsity takes some time and may not pay off for
small problems, but for larger ones it can yield substantial speedups. The default
`DenseSparsityDetector` is the most robust, but if it is too slow, or if you are running out
of memory on a larger problem, try the tracing-based detectors from SparseConnectivityTracer.jl.
- `coloring_algorithm = GreedyColoringAlgorithm()` : How to determine the matrix "colors"
to compress the sparse Hessian.
The resulting `MarginalLogDensity` object `mld` can then be called like a function
as `mld(v, data)`, where `v` is the subset of the full parameter vector `u` which is
*not* indexed by `iw`. If `length(u) == n` and `length(iw) == m`, then `length(v) == n-m`.
# Examples
```julia-repl
@@ -150,62 +160,59 @@ julia> mld(rand(2), data)
```
"""
struct MarginalLogDensity{TF, TU<:AbstractVector, TD, TV<:AbstractVector, TW<:AbstractVector,
TF1<:OptimizationFunction, TM<:AbstractMarginalizer}
struct MarginalLogDensity{
TF,
TU<:AbstractVector,
TD,
TV<:AbstractVector,
TW<:AbstractVector,
TM<:AbstractMarginalizer,
TF1<:OptimizationFunction,
TP<:OptimizationProblem,
TC<:OptimizationCache,
TH<:AbstractMatrix,
TB<:ADTypes.AbstractADType,
TE<:DifferentiationInterface.HessianExtras
}
logdensity::TF
u::TU
data::TD
iv::TV
iw::TW
F::TF1
method::TM
f_opt::TF1
prob::TP
cache::TC
H::TH
hess_adtype::TB
hess_extras::TE
end

function get_hessian_prototype(f, w, p2, autosparsity)
f2(w) = f(w, p2)
if autosparsity == :finitediff
H = FiniteDiff.finite_difference_hessian(f2, w)
hess_prototype = sparse(H)
elseif autosparsity == :forwarddiff
H = ForwardDiff.hessian(f2, w)
hess_prototype = sparse(H)
elseif autosparsity == :reversediff
H = ReverseDiff.hessian(f2, w)
hess_prototype = sparse(H)
elseif autosparsity == :zygote
H = ReverseDiff.hessian(f2, w)
hess_prototype = sparse(H)
# elseif autosparsity == :sparsitydetection
# hess_prototype = SparsityDetection.hessian_sparsity(w -> f(w, p2), w) .* one(eltype(w))
# elseif autosparsity == :symbolics
# ...
elseif autosparsity == :none
hess_prototype = ones(eltype(w), length(w), length(w))
else
error("Unsupported method for hessian sparsity detection: $(autosparsity)")
end
return hess_prototype
end

function MarginalLogDensity(logdensity, u, iw, data=(), method=LaplaceApprox(); hess_autosparse=:none)
function MarginalLogDensity(logdensity, u, iw, data=(), method=LaplaceApprox();
hess_adtype=nothing, sparsity_detector=DenseSparsityDetector(method.adtype, atol=sqrt(eps())),
coloring_algorithm=GreedyColoringAlgorithm())
n = length(u)
iv = setdiff(1:n, iw)
w = u[iw]
v = u[iv]
p2 = (p=data, v=v)
f(w, p2) = -logdensity(merge_parameters(p2.v, w, iv, iw), p2.p)
hess_prototype = get_hessian_prototype(f, w, p2, hess_autosparse)
if hess_autosparse != :none
hess_colorvec = matrix_colors(hess_prototype)
hess = (H, w, p2) -> numauto_color_hessian!(H, w -> f(w, p2), w, hess_colorvec, hess_prototype)
F = OptimizationFunction(f, method.adtype; hess_prototype=hess_prototype, hess_colorvec=hess_colorvec,
hess = hess, method.opt_func_kwargs...)
else
hess = (H, w, p2) -> ForwardDiff.hessian!(H, w -> f(w, p2), w) #auto_ad_hess(method.adtype)(H, u, p)
F = OptimizationFunction(f, method.adtype; hess_prototype=hess_prototype,
hess = hess, method.opt_func_kwargs...)
f_opt = OptimizationFunction(f, method.adtype; method.opt_func_kwargs...)
prob = OptimizationProblem(f_opt, w, p2)
cache = init(prob, method.solver)

if isnothing(hess_adtype)
hess_adtype = AutoSparse(
SecondOrder(AutoFiniteDiff(), method.adtype),
sparsity_detector,
coloring_algorithm
)
end
return MarginalLogDensity(logdensity, u, data, iv, iw, F, method)
extras = prepare_hessian(w -> f(w, p2), hess_adtype, w)
H = hessian(w -> f(w, p2), hess_adtype, w, extras)
return MarginalLogDensity(logdensity, u, data, iv, iw, method, f_opt, prob, cache,
H, hess_adtype, extras)
end
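An end-to-end sketch of the new constructor keywords (the toy density and the swap from the dense default to a tracing detector are illustrative, not canonical):

```julia
using MarginalLogDensities

# Toy joint log-density: independent standard normals (unnormalized).
# u[3] is the "nuisance" variable to integrate out; `data` is unused.
ld(u, data) = -0.5 * sum(abs2, u)

u0 = zeros(3)
mld = MarginalLogDensity(ld, u0, [3], (), LaplaceApprox();
    sparsity_detector=TracerSparsityDetector())  # instead of DenseSparsityDetector

# Callable on the non-marginalized parameters; exact here because ld is quadratic:
mld(zeros(2), ())  # ≈ 0.5 * log(2π)
```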

function Base.show(io::IO, mld::MarginalLogDensity)
@@ -234,7 +241,7 @@ nmarginal(mld::MarginalLogDensity) = length(mld.iw)
njoint(mld::MarginalLogDensity) = length(mld.iv)

"""Get the value of the cached Hessian matrix."""
cached_hessian(mld::MarginalLogDensity) = mld.F.hess_prototype
cached_hessian(mld::MarginalLogDensity) = mld.H

"""
Splice together the estimated (fixed) parameters `v` and marginalized (random) parameters
@@ -269,23 +276,28 @@ split_parameters(u, iv, iw) = (u[iv], u[iw])
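`merge_parameters` and `split_parameters` are intended as inverses; a small round-trip sketch (index choices illustrative):

```julia
using MarginalLogDensities

u  = [1.0, 2.0, 3.0, 4.0]
iv = [1, 3]   # fixed ("estimated") indices
iw = [2, 4]   # marginalized indices

v, w = split_parameters(u, iv, iw)    # ([1.0, 3.0], [2.0, 4.0])
u2 = merge_parameters(v, w, iv, iw)   # splices them back together
@assert u2 == u
```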

function optimize_marginal!(mld, p2)
w0 = mld.u[mld.iw]
prob = OptimizationProblem(mld.F, w0, p2)
sol = solve(prob, mld.method.solver)
wopt = sol.u
mld.u[mld.iw] = wopt
return sol
reinit!(mld.cache, u0=w0, p=p2)
sol = solve!(mld.cache)
wopt = sol.u::typeof(w0)
objective = sol.objective::eltype(w0)
mld.u[mld.iw] .= wopt
return wopt, objective
end

function modal_hessian!(mld::MarginalLogDensity, w, p2)
hessian!(w -> mld.f_opt(w, p2), mld.H, mld.hess_adtype, w, mld.hess_extras)
return mld.H
end
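`modal_hessian!` reuses the preparation object built in the constructor; here is the same DifferentiationInterface pattern in isolation (v0.5-era API, matching the compat bound above; later DI versions renamed `extras`):

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

g(x) = sum(abs2, x) + x[1] * x[2]
backend = AutoForwardDiff()

x = ones(2)
extras = prepare_hessian(g, backend, x)  # one-time setup work
H = zeros(2, 2)
hessian!(g, H, backend, x, extras)       # repeated evaluations reuse `extras`
```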

function _marginalize(mld, v, data, method::LaplaceApprox, verbose)
p2 = (; p=data, v)
verbose && println("Finding mode...")
sol = optimize_marginal!(mld, p2)
wopt, objective = optimize_marginal!(mld, p2)
verbose && println("Calculating hessian...")
# H = -ForwardDiff.hessian(w -> mld.F(w, p2), sol.u)
H = mld.F.hess(mld.F.hess_prototype, sol.u, p2)
modal_hessian!(mld, wopt, p2)
verbose && println("Integrating...")
nw = length(mld.iw)
integral = -sol.objective + (nw/2)* log(2π) - 0.5logabsdet(H)[1]
integral = -objective + (0.5nw) * log(2π) - 0.5logabsdet(mld.H)[1]
verbose && println("Done!")
return integral#, sol
end
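The last line of `_marginalize` implements the Laplace identity `log ∫ exp(-f(w)) dw ≈ -f(ŵ) + (n/2)log(2π) - ½ log|det H(ŵ)|`, which is exact when `f` is quadratic. A standalone sanity check of that identity (independent of this package):

```julia
using LinearAlgebra

# f(w) = 0.5 * w' * A * w has its mode at ŵ = 0 with Hessian A.
A = [2.0 0.5; 0.5 1.0]   # symmetric positive definite
n = size(A, 1)

laplace = 0.0 + (n / 2) * log(2π) - 0.5 * logabsdet(A)[1]  # -f(ŵ) = 0
exact   = 0.5 * (n * log(2π) - log(det(A)))                # Gaussian integral

@assert laplace ≈ exact
```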
@@ -303,24 +315,26 @@ function _marginalize(mld, v, data, method::Cubature, verbose)
function _marginalize(mld, v, data, method::Cubature, verbose)
p2 = (; p=data, v)
if method.lower == nothing || method.upper == nothing
sol = optimize_marginal!(mld, p2)
wopt = sol.u
h = hessdiag(w -> mld.F(w, p2), wopt)
wopt, _ = optimize_marginal!(mld, p2)
verbose && println(wopt)
h = hessdiag(w -> mld.f_opt(w, p2), wopt)
se = 1 ./ sqrt.(h)
upper = wopt .+ method.nσ .* se
lower = wopt .- method.nσ .* se
else
lower = method.lower
upper = method.upper
end
println(upper)
println(lower)
integral, err = hcubature(w -> exp(-mld.F(w, p2)), lower, upper)
if verbose
println(upper)
println(lower)
end
integral, err = hcubature(w -> exp(-mld.f_opt(w, p2)), lower, upper)
return log(integral)
end
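When the bounds are derived automatically, the box is `mode ± nσ·se` with `se = 1 ./ sqrt.(hessdiag)`. A hedged check that a ±6σ box captures essentially all of a Gaussian's mass, using HCubature directly rather than this package:

```julia
using HCubature

# Unnormalized 1-D standard normal, integrated over ±6σ around the mode.
f(w) = exp(-0.5 * w[1]^2)
val, err = hcubature(f, [-6.0], [6.0])

@assert isapprox(log(val), 0.5 * log(2π); atol=1e-6)
```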

function Optim.optimize(mld::MarginalLogDensity, init_v, data=(), args...; kwargs...)
return optimize(v -> -mld(v, data), init_v, args...; kwargs...)
end
# function Optim.optimize(mld::MarginalLogDensity, init_v, data=(), args...; kwargs...)
# return optimize(v -> -mld(v, data), init_v, args...; kwargs...)
# end

end # module