Skip to content

Commit

Permalink
small changes: reorganize validation logic, check types with isinstan…
Browse files Browse the repository at this point in the history
…ce(), rename stepsizes to step_sizes, rename stepcounts to step_counts, change list type to sequence type.
  • Loading branch information
gil2rok authored Oct 26, 2023
1 parent 1943914 commit 453af20
Showing 1 changed file with 91 additions and 97 deletions.
188 changes: 91 additions & 97 deletions bayes_kit/drghmc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Iterator, Optional, Tuple
from collections.abc import Sequence

import numpy as np
from numpy.typing import ArrayLike
Expand All @@ -23,21 +24,22 @@ class DrGhmcDiag:
This efficiently samples from multiscale distributions because of probabilistic
delayed rejection and partial momentum refresh. With non-increasing leapfrog
stepsizes, the former encourages large stepsizes in wide, flat regions and smaller
stepsizes in narrow, high-curvature regions. The latter suppresses random-walk
behavior by only partially updating the auxiliary momentum variable rho.
step sizes, the former encourages large step sizes in wide, flat regions and
smaller step sizes in narrow, high-curvature regions. The latter suppresses
random-walk behavior by only partially updating the auxiliary momentum variable
rho.
This implementation is based on Modi C, Barnett A, Carpenter B. Delayed rejection
Hamiltonian Monte Carlo for sampling multiscale distributions. Bayesian Analysis.
Hamiltonian Monte Carlo for sampling multiscale distributions. *Bayesian Analysis*.
2023. https://doi.org/10.1214/23-BA1360.
"""

def __init__(
self,
model: GradModel,
max_proposals: int,
leapfrog_stepsizes: list[float],
leapfrog_stepcounts: list[int],
leapfrog_step_sizes: Sequence[float],
leapfrog_step_counts: Sequence[int],
damping: float,
metric_diag: Optional[VectorType] = None,
init: Optional[VectorType] = None,
Expand All @@ -49,8 +51,8 @@ def __init__(
Args:
model: probabilistic model with log density and gradient
max_proposals: maximum number of proposal attempts
leapfrog_stepsizes: list of leapfrog stepsizes
leapfrog_stepcounts: list of number of leapfrog steps
leapfrog_step_sizes: list of leapfrog step_sizes
leapfrog_step_counts: list of number of leapfrog steps
damping: generalized HMC momentum damping factor in (0, 1]
metric_diag: diagonal of a diagonal metric. Defaults to identity metric.
init: parameter vector to initialize position variable theta. Defaults to
Expand All @@ -61,12 +63,10 @@ def __init__(
"""
self._model = model
self._dim = self._model.dims()
self._max_proposals = self._validate_propoals(max_proposals)
self._leapfrog_stepsizes = self._validate_leapfrog_stepsizes(leapfrog_stepsizes)
self._leapfrog_stepcounts = self._validate_leapfrog_stepcounts(
leapfrog_stepcounts
)
self._damping = self._validate_damping(damping)
self._max_proposals = max_proposals
self._leapfrog_step_sizes = leapfrog_step_sizes
self._leapfrog_step_counts = leapfrog_step_counts
self._damping = damping
self._metric = metric_diag or np.ones(self._dim)
self._rng = np.random.default_rng(seed)
self._theta = (
Expand All @@ -80,9 +80,19 @@ def __init__(
# use stack to avoid redundant computation within a single draw (when
# recursively computing the log acceptance probability) and across draws
self._log_density_gradient_cache: list[Tuple[float, VectorType]] = []
self._validate_arguments()

def _validate_arguments(self) -> None:
""" Raise error if constructor recieves invalid maximum number of proposals,
leapfrog step sizes, leapfrog step counts, or damping factor.
"""
self._validate_propoals(self._max_proposals)
self._validate_leapfrog_step_sizes(self._leapfrog_step_sizes)
self._validate_leapfrog_step_counts(self._leapfrog_step_counts)
self._validate_damping(self._damping)

def _validate_propoals(self, max_proposals: int) -> int:
"""Check that the maximum number of proposals is an integer greater than or
def _validate_propoals(self, max_proposals: int) -> None:
"""Raise error if maximum number of proposals is not an integer greater than or
equal to one.
Args:
Expand All @@ -91,124 +101,108 @@ def _validate_propoals(self, max_proposals: int) -> int:
Raises:
TypeError: max_proposals is not an int
ValueError: max_proposals is less than one
Returns:
validated number of proposals
"""
if not (type(max_proposals) is int):
if not isinstance(max_proposals, int):
raise TypeError(f"max_proposals must be an int, not {type(max_proposals)}")
if not (max_proposals >= 1):
raise ValueError(
f"max_proposals must be greater than or equal to 1, not {max_proposals}"
)
return max_proposals

def _validate_leapfrog_stepsizes(
self, leapfrog_stepsizes: list[float]
) -> list[float]:
"""Check that leapfrog stepsizes is a list with positive, float stepsizes and a
length equal to the maximum number of proposals.
def _validate_leapfrog_step_sizes(
self, leapfrog_step_sizes: Sequence[float]
) -> None:
"""Raise error if leapfrog step sizes is not a list with positive, float step
sizes and a length equal to the maximum number of proposals.
Args:
leapfrog_stepsizes: list of leapfrog stepsizes
leapfrog_step_sizes: list of leapfrog step sizes
Raises:
TypeError: leapfrog_stepsizes is not a list
ValueError: leapfrog_stepsizes is of incorrect length
TypeError: leapfrog_stepsizes contains non-float stepsizes
ValueError: leapfrog_stepsizes contains non-positive stepsizes
Returns:
list of validated leapfrog stepsizes
TypeError: leapfrog_step_sizes is not a list
ValueError: leapfrog_step_sizes is of incorrect length
TypeError: leapfrog_step_sizes contains non-float step sizes
ValueError: leapfrog_step_sizes contains non-positive step sizes
"""
if not (type(leapfrog_stepsizes) is list):
if not isinstance(leapfrog_step_sizes, list):
raise TypeError(
f"leapfrog_stepsizes must be of type list, but found type "
f"{type(leapfrog_stepsizes)}"
f"leapfrog_step_sizes must be of type list, but found type "
f"{type(leapfrog_step_sizes)}"
)
if len(leapfrog_stepsizes) != self._max_proposals:
if len(leapfrog_step_sizes) != self._max_proposals:
raise ValueError(
f"leapfrog_stepsizes must be a list of length {self._max_proposals} "
f"so that each proposal has its own specfied leapfrog stepsize, but "
f"instead found length {len(leapfrog_stepsizes)}"
f"leapfrog_step_sizes must be a list of length {self._max_proposals} "
f"so that each proposal has its own specfied leapfrog step size, but "
f"instead found length {len(leapfrog_step_sizes)}"
)
for idx, stepsize in enumerate(leapfrog_stepsizes):
if not (type(stepsize) is float):
for idx, step_size in enumerate(leapfrog_step_sizes):
if not isinstance(step_size, float):
raise TypeError(
f"each stepsize in leapfrog_stepsizes must be of type float, but "
f"found stepsize of type {type(stepsize)} at index {idx}"
f"each step size in leapfrog_step_sizes must be of type float, but "
f"found step size of type {type(step_size)} at index {idx}"
)
if not stepsize > 0:
if not step_size > 0:
raise ValueError(
f"each stepsize in leapfrog_stepsizes must be positive, but found "
f"stepsize of {stepsize} at index {idx}"
f"each step size in leapfrog_step_sizes must be positive, but found "
f"step size of {step_size} at index {idx}"
)
return leapfrog_stepsizes

def _validate_leapfrog_stepcounts(
self, leapfrog_stepcounts: list[int]
) -> list[int]:
"""Check that leapfrog stepcounts is a list with positive, integer stepcounts
and a length equal to the maximum number of proposals.
def _validate_leapfrog_step_counts(
self, leapfrog_step_counts: Sequence[int]
) -> None:
"""Raise error if leapfrog step counts is not a list with positive, integer
step counts and a length equal to the maximum number of proposals.
Args:
leapfrog_stepcounts: list of leapfrog stepcounts
leapfrog_step_counts: list of leapfrog step counts
Raises:
TypeError: leapfrog_stepcounts is not a list
ValueError: leapfrog_stepcounts is of incorrect length
TypeError: leapfrog_stepcounts contains non-integer steps
ValueError: leapfrog_stepcounts contains non-positive steps
Returns:
list of validated leapfrog stepcounts
TypeError: leapfrog_step_counts is not a list
ValueError: leapfrog_step_counts is of incorrect length
TypeError: leapfrog_step_counts contains non-integer steps
ValueError: leapfrog_step_counts contains non-positive steps
"""
if not (type(leapfrog_stepcounts) is list):
if not isinstance(leapfrog_step_counts, list):
raise TypeError(
f"leapfrog_stepcounts must be of type list, but found type "
f"{type(leapfrog_stepcounts)}"
f"leapfrog_step_counts must be of type list, but found type "
f"{type(leapfrog_step_counts)}"
)
if len(leapfrog_stepcounts) != self._max_proposals:
if len(leapfrog_step_counts) != self._max_proposals:
raise ValueError(
f"leapfrog_stepcounts must be a list of length {self._max_proposals}, "
f"leapfrog_step_counts must be a list of length {self._max_proposals}, "
f"so that each proposal has its own specified number of leapfrog "
f"steps, but instead found length {len(leapfrog_stepcounts)}"
f"steps, but instead found length {len(leapfrog_step_counts)}"
)
for idx, stepcount in enumerate(leapfrog_stepcounts):
if not (type(stepcount) is int):
for idx, step_count in enumerate(leapfrog_step_counts):
if not isinstance(step_count, int):
raise TypeError(
f"each stepcount in leapfrog_stepcounts must be of type int, but "
f"found stepcount of type {type(stepcount)} at index {idx}"
f"each step count in leapfrog_step_counts must be of type int, but "
f"found step count of type {type(step_count)} at index {idx}"
)
if not stepcount > 0:
if not step_count > 0:
raise ValueError(
f"each stepcount in leapfrog_stepcounts must be positive, but "
f"found stepcount of {stepcount} at index {idx}"
f"each step count in leapfrog_step_counts must be positive, but "
f"found step count of {step_count} at index {idx}"
)
return leapfrog_stepcounts

def _validate_damping(self, damping: float) -> float:
"""Check that the damping factor is a float in (0, 1].
def _validate_damping(self, damping: float) -> None:
"""Raise error if the damping factor is not a float in (0, 1].
Args:
damping: generalized HMC momentum damping factor in (0, 1]
Raises:
TypeError: damping is not a float
ValueError: damping is not in (0, 1]
Returns:
validated damping factor
"""
if not (type(damping) is float):
if not isinstance(damping, float):
raise TypeError(
f"damping must be of type float, but found type {type(damping)}"
)
if not 0 < damping <= 1:
raise ValueError(
f"damping must be within (0, 1], but found damping of {damping}"
)
return damping

def __iter__(self) -> Iterator[DrawAndLogP]:
"""Return the iterator for draws from this sampler.
Expand Down Expand Up @@ -258,8 +252,8 @@ def leapfrog(
self,
theta: VectorType,
rho: VectorType,
stepsize: float,
stepcount: int,
step_size: float,
step_count: int,
) -> tuple[VectorType, VectorType]:
"""Return the result of running the leapfrog integrator for Hamiltonian
dynamics starting from the current draw (theta, rho) with the specified step
Expand All @@ -268,8 +262,8 @@ def leapfrog(
Args:
theta: position
rho: momentum
stepsize: stepsize in each leapfrog step
stepcount: number of leapfrog steps
step_size: step size in each leapfrog step
step_count: number of leapfrog steps
Returns:
Approximate solution to Hamiltonian dynamics via leapfrog integration
Expand All @@ -278,16 +272,16 @@ def leapfrog(
grad: ArrayLike # mypy infers too strict a type when reading from cache

logp, grad = self._log_density_gradient_cache[-1]
rho_mid = rho + 0.5 * stepsize * np.multiply(self._metric, grad).squeeze()
theta += stepsize * rho_mid
rho_mid = rho + 0.5 * step_size * np.multiply(self._metric, grad).squeeze()
theta += step_size * rho_mid

for _ in range(stepcount - 1):
for _ in range(step_count - 1):
logp, grad = self._model.log_density_gradient(theta)
rho_mid += stepsize * np.multiply(self._metric, grad).squeeze()
theta += stepsize * rho_mid
rho_mid += step_size * np.multiply(self._metric, grad).squeeze()
theta += step_size * rho_mid

logp, grad = self._model.log_density_gradient(theta)
rho = rho_mid + 0.5 * stepsize * np.multiply(self._metric, grad).squeeze()
rho = rho_mid + 0.5 * step_size * np.multiply(self._metric, grad).squeeze()

self._log_density_gradient_cache.append((logp, np.asanyarray(grad)))
return (theta, rho)
Expand Down Expand Up @@ -336,13 +330,13 @@ def proposal_map(
Args:
theta: position
rho: momentum
k: proposal number (for leapfrog stepsize and stepcount)
k: proposal number (for leapfrog step size and step count)
Returns:
proposed draw (theta_prop, rho_prop)
"""
stepsize, stepcount = self._leapfrog_stepsizes[k], self._leapfrog_stepcounts[k]
theta_prop, rho_prop = self.leapfrog(theta, rho, stepsize, stepcount)
step_size, step_count = self._leapfrog_step_sizes[k], self._leapfrog_step_counts[k]
theta_prop, rho_prop = self.leapfrog(theta, rho, step_size, step_count)
rho_prop = -rho_prop
return (theta_prop, rho_prop)

Expand Down Expand Up @@ -412,7 +406,7 @@ def accept(
Args:
theta_prop: proposed position
rho_prop: proposed momentum
k: proposal number (for leapfrog stepsize and stepcount)
k: proposal number (for leapfrog step size and step count)
cur_hastings: log probability of rejecting all previous proposals
cur_logp: log probability of current draw
Expand Down

0 comments on commit 453af20

Please sign in to comment.