From 8bd849901748c68152805c2335c4c6e490d2eed6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Wed, 14 Aug 2024 15:56:41 +0200 Subject: [PATCH] Gaussian structure learning (#30) --- examples/.gitignore | 2 + examples/gaussian_graphical_learning.qmd | 487 +++++++++++++++++++++++ src/jnotype/_csp.py | 4 +- src/jnotype/_variance.py | 10 +- src/jnotype/gaussian/__init__.py | 12 + src/jnotype/gaussian/_horseshoe.py | 377 ++++++++++++++++++ src/jnotype/gaussian/_numeric.py | 175 ++++++++ src/jnotype/gaussian/_spike_and_slab.py | 376 +++++++++++++++++ src/jnotype/sampling/_sampler.py | 3 +- tests/test_variance.py | 4 +- 10 files changed, 1440 insertions(+), 10 deletions(-) create mode 100644 examples/.gitignore create mode 100644 examples/gaussian_graphical_learning.qmd create mode 100644 src/jnotype/gaussian/__init__.py create mode 100644 src/jnotype/gaussian/_horseshoe.py create mode 100644 src/jnotype/gaussian/_numeric.py create mode 100644 src/jnotype/gaussian/_spike_and_slab.py diff --git a/examples/.gitignore b/examples/.gitignore new file mode 100644 index 0000000..0a60a06 --- /dev/null +++ b/examples/.gitignore @@ -0,0 +1,2 @@ +*.html +*_files/ diff --git a/examples/gaussian_graphical_learning.qmd b/examples/gaussian_graphical_learning.qmd new file mode 100644 index 0000000..633d2a5 --- /dev/null +++ b/examples/gaussian_graphical_learning.qmd @@ -0,0 +1,487 @@ +--- +title: Learning sparse Gaussian graphical models +format: + html: + code-fold: true +jupyter: python3 +toc: true +number-sections: true +--- + +In this tutorial we will learn a sparse Gaussian graphical models by using appropriate priors on the precision matrices. + +We assume that there is a sparse symmetric positive definite matrix $Q\in \mathbb{R}^{G\times G}$, which describes a Gaussian graphical model as following: + +$$ + Y_n\mid Q\sim \mathcal N\left(\mathbf{0}, Q^{-1}\right) +$$ +for $n=1, \dotsc, N$. 
+In other words, matrix $Q$ is the *precision matrix* (with $Q^{-1}$ being the covariance matrix), which we assume to be sparse. +We expect that this matrix is rather sparse: $Q_{12} = 0$ corresponds to the conditional independence of $Y_{n1}$ and $Y_{n2}$ given variables $Y_{n3}, \dotsc, Y_{nG}$ (for any particular $n$). + +## Generating the data + +Let's start by generating a sparse positive-definite matrix $Q$. We will sample it from a graphical spike-and-slab prior proposed by [Hao Wang, *Scaling it up: Stochastic search structure learning in graphical models* (2015)](https://arxiv.org/abs/1505.01687), employing the Gibbs sampler he introduced: + +```{python} + +import jnotype as jt +import jnotype.gaussian as gauss + +import jax +import jax.numpy as jnp +import jax.random as jrandom + +import seaborn as sns +import matplotlib.pyplot as plt + +key = jrandom.PRNGKey(32) + +G = 10 # Dimensionality +N = 150 # The number of samples to be generated + +aux_dataset = jt.sampling.ListDataset( + thinning=1, + dimensions=gauss.PrecisionMatrixSpikeAndSlabSampler.dimensions(), +) + +gibbs_sampler = gauss.PrecisionMatrixSpikeAndSlabSampler( + datasets=[aux_dataset], + scatter_matrix=jnp.zeros((G, G)), + n_points=0, + warmup=1000, + steps=1, + verbose=True, + seed=121, + pi=0.2, + std0=0.1, + std1=2, + lambd=0.8, + deterministic_init=False, +) + +gibbs_sampler.run() + +prec_true = aux_dataset.samples[-1]["precision"] + +fig, ax = plt.subplots() + +sns.heatmap(prec_true, ax=ax, cmap="bwr", center=0) +ax.set_title("Generated precision matrix") +ax.set_xticks([]) +ax.set_yticks([]) +``` + +Now let's generate the $N\times G$ matrix representing the observed samples $Y_n$: + +```{python} + +cov_true = jnp.linalg.inv(prec_true + 1e-6 * jnp.eye(G)) + +Y = jrandom.multivariate_normal(key, jnp.zeros(G), cov_true, shape=(N,)) +``` + +Let's estimate the covariance from the sample and then invert it to get a precision estimate: + +```{python} +fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, 
figsize=(4*1.1, 4)) + +cov_empirical = jnp.cov(Y.T) +prec_empirical = jnp.linalg.inv(cov_empirical + 1e-6 * jnp.eye(G)) + +vmin_cov = min(cov_empirical.min(), cov_true.min()) +vmax_cov = max(cov_empirical.max(), cov_true.max()) +cmap_cov = "PiYG" + +vmin_prec = min(prec_empirical.min(), prec_true.min()) +vmax_prec = max(prec_empirical.max(), prec_true.max()) +cmap_prec = "bwr" + +ax = axs[0, 0] +sns.heatmap( + prec_true, + cmap=cmap_prec, + center=0, + ax=ax, + vmin=vmin_prec, + vmax=vmax_prec, +) +ax.set_title("True precision") + +ax = axs[0, 1] +sns.heatmap( + cov_true, + cmap=cmap_cov, + center=0, + ax=ax, + vmin=vmin_cov, vmax=vmax_cov, +) +ax.set_title("True covariance") + +ax = axs[1, 0] +sns.heatmap( + prec_empirical, + center=0, + ax=ax, + cmap=cmap_prec, + vmin=vmin_prec, + vmax=vmax_prec, +) +ax.set_title("Estimated precision") + +ax = axs[1, 1] +sns.heatmap( + cov_empirical, + center=0, + ax=ax, + cmap=cmap_cov, + vmin=vmin_cov, + vmax=vmax_cov) +ax.set_title("Estimated covariance") + +for ax in axs.ravel(): + ax.set_xticks([]) + ax.set_yticks([]) + +fig.tight_layout() + +``` + +Even though the sample covariance somehow approximates the ground-truth covariance, matrix inverse introduced a lot of additional entries to the precision matrix estimate... 
+ +Let's understand how much these matrices differ by plotting separately the diagonal and the off-diagonal entries: + +```{python} +fig, axs = plt.subplots(1, 2) + +ax = axs[0] +ax.set_title("Diagonal") +true_diagonal = jnp.diagonal(prec_true) +diagonal_ordering = jnp.argsort(true_diagonal)[::-1] + +color_true = "k" +color_estimate = "orange" + +ax.plot( + true_diagonal[diagonal_ordering], + c=color_true, + linestyle="-", + label="True" +) +ax.plot( + jnp.diagonal(prec_empirical)[diagonal_ordering], + c=color_estimate, + linestyle="--", + label="Estimate" +) + +ax = axs[1] +ax.set_title("Off-diagonal") +o1, o2 = jnp.triu_indices(G, k=1) +true_offdiagonal = prec_true[o1, o2] +offdiagonal_ordering = jnp.argsort(true_offdiagonal)[::-1] + +def matrix_to_offdiagonal(m): + return m[o1, o2][offdiagonal_ordering] + +ax.plot( + matrix_to_offdiagonal(prec_true), + c=color_true, + linestyle="-" +) +ax.plot( + matrix_to_offdiagonal(prec_empirical), + c=color_estimate, + linestyle="--" +) + + +for ax in axs: + ax.spines[["top", "right"]].set_visible(False) + +fig.legend(frameon=False) +``` + +## Spike-and-slab prior + +Let's use a model-based estimator, where we find $Q$ by explicitly using the assumed model $Y_n\sim \mathcal N\left(\mathbf{0}, Q^{-1}\right)$. +To calculate the likelihood, one can use the *scatter matrix*, rather than the original data: +$$ + S_{ij} = \sum_{n=1}^N Y_{ni}Y_{nj} = \left(Y^TY\right)_{ij}. +$$ + +We can calculate it as follows: + +```{python} +# The scatter matrix +scatter = gauss.construct_scatter_matrix(Y) +``` + +Now we can use it to estimate $Q$ by using a Gibbs sampler and assuming a spike-and-slab prior. 
+We will use essentially the same sampler as before, but this time we provide the scatter matrix and number of samples obtained from the data: + +```{python} +dataset = jt.sampling.ListDataset( + thinning=2, + dimensions=gauss.PrecisionMatrixSpikeAndSlabSampler.dimensions(), +) + +sampler = gauss.PrecisionMatrixSpikeAndSlabSampler( + datasets=[dataset], + scatter_matrix=scatter, + n_points=N, + warmup=1000, + steps=2000, + verbose=True, + seed=0, + std0=0.1, +) + +sampler.run() +precs = jnp.array(dataset.dataset["precision"]) +``` + +In principle, we should run multiple chains and see whether there are any convergence issues. +In this tutorial we will however rely on a single one. + +Let's investigate the obtained estimate. +Note that a linear combination of two symmetric positive definite matrices is still symmetric positive definite, provided that the coefficients are positive (geometrically speaking, the set of symmetric positive definite matrices forms a [convex cone](https://en.wikipedia.org/wiki/Convex_cone)). +Hence, the posterior mean is also a symmetric positive definite matrix. Let's plot it as a point estimate: + +```{python} +fig, axs = plt.subplots(1, 3, figsize=(7, 2.5), sharex=True, sharey=True) + +ax = axs[0] +ax.set_title("True precision") +sns.heatmap( + prec_true, + ax=ax, + cmap=cmap_prec, + center=0, + vmin=vmin_prec, + vmax=vmax_prec, +) + +ax = axs[1] +ax.set_title("Posterior mean") +sns.heatmap(precs.mean(axis=0), ax=ax, cmap="bwr", center=0, vmin=vmin_prec, vmax=vmax_prec) + +ax = axs[2] +ax.set_title("Entrywise\nstandard deviation") +sns.heatmap( + precs.std(axis=0), + ax=ax, + cmap="Greys", + vmin=0, +) + + +for ax in axs.ravel(): + ax.set_xticks([]) + ax.set_yticks([]) + +fig.tight_layout() +``` + +On the right hand side we see the standard deviation, quantifying how uncertain we are. +As before, we can also plot the diagonal and off-diagonal entries separately. +This time, however, we have a measure of uncertainty. 
As we are plotting individual entries, we will plot the median and a 80%-credible interval ranging between the 10% and 90% quantiles. + +```{python} +fig, axs = plt.subplots(1, 2) + +ax = axs[0] +ax.set_title("Diagonal") + +x_ax = jnp.arange(len(true_diagonal)) + +ax.plot( + x_ax, + true_diagonal[diagonal_ordering], + c=color_true, + linestyle="-", +) + +aux = [] +for prec in precs: + aux.append(jnp.diagonal(prec)[diagonal_ordering]) + +aux = jnp.asarray(aux) +median = jnp.quantile(aux, axis=0, q=0.5) +low = jnp.quantile(aux, axis=0, q=0.1) +high = jnp.quantile(aux, axis=0, q=0.9) + +ax.plot( + x_ax, + median, + c=color_estimate, + linestyle="--", + alpha=1.0, +) +ax.fill_between(x_ax, low, high, alpha=0.1, color=color_estimate) + +ax = axs[1] +ax.set_title("Off-diagonal") + +x_ax = jnp.arange(len(matrix_to_offdiagonal(prec_true))) + +ax.plot( + x_ax, + matrix_to_offdiagonal(prec_true), + c=color_true, + linestyle="-" +) + +aux = [] +for prec in precs: + aux.append(matrix_to_offdiagonal(prec)) + +aux = jnp.asarray(aux) +median = jnp.quantile(aux, axis=0, q=0.5) +low = jnp.quantile(aux, axis=0, q=0.1) +high = jnp.quantile(aux, axis=0, q=0.9) + +ax.plot( + x_ax, + median, + c=color_estimate, + linestyle="--", + alpha=1.0, +) +ax.fill_between(x_ax, low, high, alpha=0.1, color=color_estimate) + +for ax in axs: + ax.spines[["top", "right"]].set_visible(False) +``` + +## Horseshoe prior + +An alternative prior is the graphical horseshoe prior, proposed by [Y. Li, B.A. Craig and A. Bhadra, *The graphical horseshoe estimator for inverse covariance matrices* (2019)](https://arxiv.org/abs/1707.06661). 
+ +We can sample from it in an analogous fashion: + +```{python} +dataset = jt.sampling.ListDataset( + thinning=2, + dimensions=gauss.PrecisionMatrixHorseshoeSampler.dimensions(), +) + +sampler = gauss.PrecisionMatrixHorseshoeSampler( + datasets=[dataset], + scatter_matrix=scatter, + n_points=N, + warmup=1000, + steps=2000, + verbose=True, + seed=0, +) + +sampler.run() +precs = jnp.array(dataset.dataset["precision"]) +``` + +Let's plot the mean estimate: + +```{python} +fig, axs = plt.subplots(1, 3, figsize=(7, 2.5), sharex=True, sharey=True) + +ax = axs[0] +ax.set_title("True precision") +sns.heatmap( + prec_true, + ax=ax, + cmap=cmap_prec, + center=0, + vmin=vmin_prec, + vmax=vmax_prec, +) + +ax = axs[1] +ax.set_title("Posterior mean") +sns.heatmap(precs.mean(axis=0), ax=ax, cmap="bwr", center=0, vmin=vmin_prec, vmax=vmax_prec) + +ax = axs[2] +ax.set_title("Entrywise\nstandard deviation") +sns.heatmap( + precs.std(axis=0), + ax=ax, + cmap="Greys", + vmin=0, +) + + +for ax in axs.ravel(): + ax.set_xticks([]) + ax.set_yticks([]) + +fig.tight_layout() +``` + +Finally, we can separately visualise the diagonal and the off-diagonal entries: + +```{python} +fig, axs = plt.subplots(1, 2) + +ax = axs[0] +ax.set_title("Diagonal") + +x_ax = jnp.arange(len(true_diagonal)) + +ax.plot( + x_ax, + true_diagonal[diagonal_ordering], + c=color_true, + linestyle="-", +) + +aux = [] +for prec in precs: + aux.append(jnp.diagonal(prec)[diagonal_ordering]) + +aux = jnp.asarray(aux) +median = jnp.quantile(aux, axis=0, q=0.5) +low = jnp.quantile(aux, axis=0, q=0.1) +high = jnp.quantile(aux, axis=0, q=0.9) + +ax.plot( + x_ax, + median, + c=color_estimate, + linestyle="--", + alpha=1.0, +) +ax.fill_between(x_ax, low, high, alpha=0.1, color=color_estimate) + +ax = axs[1] +ax.set_title("Off-diagonal") + +x_ax = jnp.arange(len(matrix_to_offdiagonal(prec_true))) + +ax.plot( + x_ax, + matrix_to_offdiagonal(prec_true), + c=color_true, + linestyle="-" +) + +aux = [] +for prec in precs: + 
aux.append(matrix_to_offdiagonal(prec)) + +aux = jnp.asarray(aux) +median = jnp.quantile(aux, axis=0, q=0.5) +low = jnp.quantile(aux, axis=0, q=0.1) +high = jnp.quantile(aux, axis=0, q=0.9) + +ax.plot( + x_ax, + median, + c=color_estimate, + linestyle="--", + alpha=1.0, +) +ax.fill_between(x_ax, low, high, alpha=0.1, color=color_estimate) + +for ax in axs: + ax.spines[["top", "right"]].set_visible(False) +``` diff --git a/src/jnotype/_csp.py b/src/jnotype/_csp.py index ada045f..bff070f 100644 --- a/src/jnotype/_csp.py +++ b/src/jnotype/_csp.py @@ -299,8 +299,8 @@ def sample_csp_prior( variances_active = sample_inverse_gamma( key=key_var, n_points=k, - a=prior_shape, - b=prior_scale, + shape=prior_shape, + scale=prior_scale, ) variance = _select_variances_active( diff --git a/src/jnotype/_variance.py b/src/jnotype/_variance.py index a74b95a..c41360f 100644 --- a/src/jnotype/_variance.py +++ b/src/jnotype/_variance.py @@ -95,15 +95,15 @@ def sample_variances( def sample_inverse_gamma( key, n_points: int, - a: float, - b: float, + shape: float, + scale: float, ) -> Float[Array, " n_points"]: """Samples from the inverse gamma distribution. 
Args: key: JAX random key - a: shape parameter of the inverse gamma distribution - b: scale parameter of the inverse gamma distribution + shape: shape parameter of the inverse gamma distribution + scale: scale parameter of the inverse gamma distribution n_points: number of points to sample Note that: @@ -111,5 +111,5 @@ def sample_inverse_gamma( is equivalent to 1/X ~ Gamma(shape=a, rate=b) """ - one_over_x = random.gamma(key, a, shape=(n_points,)) / b + one_over_x = random.gamma(key, shape, shape=(n_points,)) / scale return jnp.reciprocal(one_over_x) diff --git a/src/jnotype/gaussian/__init__.py b/src/jnotype/gaussian/__init__.py new file mode 100644 index 0000000..33df28c --- /dev/null +++ b/src/jnotype/gaussian/__init__.py @@ -0,0 +1,12 @@ +"""Samplers for Gaussian graphical models with sparse prior on precision matrices.""" + +from jnotype.gaussian._horseshoe import PrecisionMatrixHorseshoeSampler +from jnotype.gaussian._spike_and_slab import PrecisionMatrixSpikeAndSlabSampler +from jnotype.gaussian._numeric import construct_scatter_matrix + + +__all__ = [ + "PrecisionMatrixHorseshoeSampler", + "PrecisionMatrixSpikeAndSlabSampler", + "construct_scatter_matrix", +] diff --git a/src/jnotype/gaussian/_horseshoe.py b/src/jnotype/gaussian/_horseshoe.py new file mode 100644 index 0000000..c014d75 --- /dev/null +++ b/src/jnotype/gaussian/_horseshoe.py @@ -0,0 +1,377 @@ +"""Implements the sampler employing graphical horseshoe +prior as proposed by +Y. Li, B.A. Craig and A. 
Bhadra, +The graphical horseshoe estimator for inverse covariance matrices (2019) +""" + +from typing import NamedTuple, NewType, Optional, Sequence + +from jaxtyping import Float, Array + +import jax +import jax.random as jrandom +import jax.numpy as jnp + +from jnotype.sampling import AbstractGibbsSampler, DatasetInterface +from jnotype._utils import JAXRNG + +import jnotype.gaussian._numeric as num + + +class _HorseshoeRowSample(NamedTuple): + """Internal sample of the last row of all three local arrays""" + + precision: Float[Array, " G"] + lambda2: Float[Array, " G-1"] + nu: Float[Array, " G-1"] + + +def _sample_inverse_gamma( + key, + shape: float, + scale: Float[Array, " N"], +) -> Float[Array, " N"]: + """Samples from the inverse gamma distribution. + + Args: + key: JAX random key + shape: shape parameter of the inverse gamma distribution + scale: scale parameter of the inverse gamma distribution + + Note that: + X ~ InvGamma(shape=a, scale=b) + is equivalent to + 1/X ~ Gamma(shape=a, rate=b) + """ + samples_gamma = jrandom.gamma(key, shape, shape=scale.shape) + return scale * jnp.reciprocal(samples_gamma) + + +def _sample_horseshoe_row( + key, + *, + scatter_row: Float[Array, " G"], + precision: Float[Array, "G G"], + lambda2_row: Float[Array, " G-1"], + nu_row: Float[Array, " G-1"], + n_points: int, + tau2: float, + _jitter: float, +) -> _HorseshoeRowSample: + """Samples the last row. 
+ + Args: + _jitter: a small numerical jitter to make the matrix inversion more stable + """ + s12: Float[Array, " G-1"] = scatter_row[:-1] + s22: float = scatter_row[-1] + Gm1 = s12.shape[0] # G - 1 + + inv_omega11 = jnp.linalg.inv(precision[:-1, :-1] + _jitter * jnp.eye(Gm1)) + v12 = lambda2_row * tau2 + inv_C = s22 * inv_omega11 + jnp.diag(jnp.reciprocal(v12)) + _jitter * jnp.eye(Gm1) + rate = 0.5 * s22 + + key_omega, key_lambda, key_nu = jrandom.split(key, 3) + + omega12: Float[Array, " G"] = num.sample_precision_column( + key=key_omega, + inv_omega11=inv_omega11, + inv_C=inv_C, + scatter12=s12, + n_samples=n_points, + rate=rate, + ) + + lambda2: Float[Array, " G-1"] = _sample_inverse_gamma( + key_lambda, + shape=1, + scale=jnp.reciprocal(nu_row) + 0.5 * jnp.square(omega12[:-1]) / tau2, + ) + + nu: Float[Array, " G-1"] = _sample_inverse_gamma( + key_nu, shape=1.0, scale=1 + jnp.reciprocal(lambda2) + ) + + return _HorseshoeRowSample( + precision=omega12, + lambda2=lambda2, + nu=nu, + ) + + +class _HorseshoeMatricesSample(NamedTuple): + """Internal object representing the matrices.""" + + precision: Float[Array, "G G"] + lambda2: Float[Array, "G G"] + nu: Float[Array, "G G"] + + @property + def dim(self) -> int: + return self.precision.shape[0] + + +def _sample_precision_matrix_column_by_column( + key, + *, + n_samples: int, + scatter: Float[Array, "G G"], + sample: _HorseshoeMatricesSample, + tau2: float, + _jitter: float, +) -> _HorseshoeMatricesSample: + """Samples the precision matrix by sampling + columns one after the other. + + Args: + key: JAX random key + n_samples: number of data points + scatter: the scatter matrix + sample: current `precision`, `lambda2` and `nu` matrices, + each of shape `(G, G)` + tau2: current value of `tau2`, which multiplies the + `lambda2` entries in every row update + (it is not resampled here) + _jitter: a small numerical jitter to make the matrix inversion more stable + + Returns: + The updated `precision`, `lambda2` and `nu` matrices. 
+ """ + + def update_column(carry: tuple, k: int) -> tuple: + """Function sampling the `k`th column (row) and updating it. + + Args: + carry: tuple (key, HorseshoeMatricesSample) + k: the index of the column (row) to be updated + """ + key = carry[0] + matrices: _HorseshoeMatricesSample = carry[1] + + # Reorder the variables, + # so that the updated column is the last one + scatter_ = num.swap_with_last(scatter, k) + precision = num.swap_with_last(matrices.precision, k) + lambda2 = num.swap_with_last(matrices.lambda2, k) + nu = num.swap_with_last(matrices.nu, k) + + # Sample the new last row/column + key, subkey = jrandom.split(key) + new_cols: _HorseshoeRowSample = _sample_horseshoe_row( + key=subkey, + precision=precision, # Full precision matrix + scatter_row=scatter_[:, -1], + lambda2_row=lambda2[:-1, -1], + nu_row=nu[:-1, -1], + n_points=n_samples, + tau2=tau2, + _jitter=_jitter, + ) + + # Update both the row and the column + _LAST = -1 + precision = precision.at[:, _LAST].set(new_cols.precision) + precision = precision.at[_LAST, :].set(new_cols.precision) + + lambda2 = lambda2.at[:-1, _LAST].set(new_cols.lambda2) + lambda2 = lambda2.at[_LAST, :-1].set(new_cols.lambda2) + + nu = nu.at[:-1, _LAST].set(new_cols.nu) + nu = nu.at[_LAST, :-1].set(new_cols.nu) + + # Reorder the variables to the original order + precision = num.swap_with_last(precision, k) + lambda2 = num.swap_with_last(lambda2, k) + nu = num.swap_with_last(nu, k) + + new_matrices = _HorseshoeMatricesSample( + precision=precision, + lambda2=lambda2, + nu=nu, + ) + + return (key, new_matrices), None + + carry, _ = jax.lax.scan( + update_column, + (key, sample), + jnp.arange(sample.dim), + ) + _, matrices = carry + return matrices + + +class HorseshoeSample(NamedTuple): + """Represents a full sample.""" + + precision: Float[Array, "G G"] + lambda2: Float[Array, "G G"] + nu: Float[Array, "G G"] + tau2: Float[Array, ""] + xi: Float[Array, ""] + + @property + def dim(self) -> int: + return 
self.precision.shape[0] + + +@jax.jit +def sample_horseshoe( + key, + scatter: Float[Array, "G G"], + n_samples: int, + sample: HorseshoeSample, + _jitter: float = 1e-6, +) -> HorseshoeSample: + key_matrices, key_tau, key_xi = jrandom.split(key, 3) + + new_matrices = _sample_precision_matrix_column_by_column( + key=key_matrices, + n_samples=n_samples, + scatter=scatter, + sample=_HorseshoeMatricesSample( + precision=sample.precision, + lambda2=sample.lambda2, + nu=sample.nu, + ), + tau2=sample.tau2, + _jitter=_jitter, + ) + + G = sample.dim + Gover2 = 0.5 * G * (G - 1) + + offset = 0.5 * jnp.sum( + jnp.square(num.utzd_to_vector(new_matrices.precision)) + / num.utzd_to_vector(new_matrices.lambda2) + ) + + tau2 = _sample_inverse_gamma( + key_tau, + shape=0.5 * (1 + Gover2), + scale=jnp.reciprocal(sample.xi) + offset, + ) + + xi = _sample_inverse_gamma( + key_xi, + shape=1, + scale=1.0 + jnp.reciprocal(tau2), + ) + + return HorseshoeSample( + precision=new_matrices.precision, + lambda2=new_matrices.lambda2, + nu=new_matrices.nu, + tau2=tau2, + xi=xi, + ) + + +Sample = NewType("Sample", dict) + + +class PrecisionMatrixHorseshoeSampler(AbstractGibbsSampler): + """A Gibbs sampler to learn a precision matrix from centered (zero-mean) + normally distributed data. 
+ """ + + def __init__( + self, + datasets: Sequence[DatasetInterface], + *, + data: Optional[Float[Array, "points features"]] = None, + scatter_matrix: Optional[Float[Array, "features features"]] = None, + n_points: Optional[int] = None, + # Gibbs sampling + warmup: int = 5_000, + steps: int = 10_000, + verbose: bool = False, + seed: int = 195, + _jitter: float = 1e-6, + # Initialisation + deterministic_init: bool = False, + ) -> None: + super().__init__(datasets, warmup=warmup, steps=steps, verbose=verbose) + + self._deterministic_init = deterministic_init + + # Initialize a random number generator + self._jax_rng = JAXRNG(jax.random.PRNGKey(seed)) + + scatter, n_points = num.prepare_data( + data=data, + scatter=scatter_matrix, + n_points=n_points, + ) + + if _jitter < 0: + raise ValueError( + f"The _jitter argument has to be non-negative but is {_jitter}." + ) + self._jitter = _jitter + + self._scatter_matrix: Float[Array, "features features"] = scatter + self._n_points: int = n_points + self._n_features: int = self._scatter_matrix.shape[0] + + @classmethod + def dimensions(cls) -> Sample: + """The sites in each sample with annotated dimensions.""" + return { + "precision": ["features_dim0", "features_dim1"], + "tau2": [], + "tau2_aux": [], + "lambda2": ["features_dim0", "features_dim1"], + "lambda2_aux": ["features_dim0", "features_dim1"], + } + + def new_sample(self, sample: Sample) -> Sample: + """A new sample.""" + x: HorseshoeSample = sample_horseshoe( + key=self._jax_rng.key, + scatter=self._scatter_matrix, + n_samples=self._n_points, + sample=HorseshoeSample( + precision=sample["precision"], + lambda2=sample["lambda2"], + nu=sample["lambda2_aux"], + tau2=sample["tau2"], + xi=sample["tau2_aux"], + ), + _jitter=self._jitter, + ) + + return { + "precision": x.precision, + "tau2": x.tau2, + "tau2_aux": x.xi, + "lambda2": x.lambda2, + "lambda2_aux": x.nu, + } + + def initialise(self) -> Sample: + """Initialises the sample.""" + if self._deterministic_init: + 
return { + "precision": jnp.eye(self._n_features), + "tau2": jnp.array(1.0), + "tau2_aux": jnp.array(1.0), + "lambda2": jnp.ones((self._n_features, self._n_features)), + "lambda2_aux": jnp.ones((self._n_features, self._n_features)), + } + else: + return { + "precision": jnp.eye(self._n_features) + * (0.5 + jrandom.gamma(self._jax_rng.key, 1.0) / 0.5), + "tau2": 0.5 + jrandom.gamma(self._jax_rng.key, 1.0), + "tau2_aux": 0.5 + jrandom.gamma(self._jax_rng.key, 1.0), + "lambda2": 0.5 + + jrandom.gamma( + self._jax_rng.key, 1.0, shape=(self._n_features, self._n_features) + ), + "lambda2_aux": 0.5 + + jrandom.gamma( + self._jax_rng.key, 1.0, shape=(self._n_features, self._n_features) + ), + } diff --git a/src/jnotype/gaussian/_numeric.py b/src/jnotype/gaussian/_numeric.py new file mode 100644 index 0000000..c93d620 --- /dev/null +++ b/src/jnotype/gaussian/_numeric.py @@ -0,0 +1,175 @@ +"""Common numeric utilities. + +Notation: an UTZD stands for an "upper triangular with zero diagonal" +matrix. An UTZD matrix of shape `(G, G)` has therefore G*(G-1)/2 +free parameters. +""" + +from typing import Optional + +import jax.numpy as jnp +import jax.random as jrandom + +from jaxtyping import Float, Array, Num + + +def construct_scatter_matrix(y: Float[Array, "N G"]) -> Float[Array, "G G"]: + """Constructs the scatter matrix of a data set, i.e., + + $$S_{ij} = \\sum_{n=1}^N y_{ni}y_{nj}$$ + + for $i, j=1, \\dotsc, G$. + """ + return jnp.einsum("ng,nh->gh", y, y) + + +def utzd_to_vector(matrix: Num[Array, "G G"]) -> Num[Array, " G*(G-1)/2"]: + """Stores the free parameters of the UTZD matrix in a vector. + + See Also: + `vector_to_utzd` for the (one-sided) inverse. 
+ """ + # Get the indices for the upper-triangular part (excluding the diagonal) + m = matrix.shape[0] + upper_tri_indices = jnp.triu_indices(m, k=1) + + # Extract the upper-triangular elements and flatten them into a vector + vector = matrix[upper_tri_indices] + return vector + + +def vector_to_utzd(vector: Num[Array, " m*(m-1)/2"], m: int) -> Num[Array, "m m"]: + """Stores a vector `vector` as + an upper triangular matrix with zero diagonal. + + See Also: + `utzd_to_vector` for the (one-sided) inverse + """ + # Create an empty m x m matrix of zeros + matrix = jnp.zeros((m, m), dtype=vector.dtype) + + # Get the indices for the upper-triangular part (excluding the diagonal) + upper_tri_indices = jnp.triu_indices(m, k=1) + + # Assign the vector values to the upper-triangular positions + matrix = matrix.at[upper_tri_indices].set(vector) + return matrix + + +def symmetrize_utzd(a: Num[Array, "G G"]) -> Num[Array, "G G"]: + """Symmetrizes a UTZD matrix, by copying the entries + to the lower triangle. + + Note: + Do not use this function for a general matrix: e.g., any non-zero + diagonal entries would be doubled, since `a + a.T` is returned. 
+ """ + return a + a.T + + +def swap_with_last(A: Float[Array, "G G"], k: int) -> Float[Array, "G G"]: + """For a symmetric matrix `A` swaps the `k`th row and column with the last ones.""" + m = -1 # We swap with the last column + A = A.at[[k, m], :].set(A[[m, k], :]) # Swap rows + A = A.at[:, [k, m]].set(A[:, [m, k]]) # Swap columns + return A + + +def sample_precision_column( + key, + inv_omega11: Float[Array, "G-1 G-1"], + inv_C: Float[Array, "G-1 G-1"], + scatter12: Float[Array, " G-1"], + n_samples: int, + rate: float, +) -> Float[Array, " G"]: + """Samples the last column (row) using the factorization: + Normal(first G-1 entries) x Gamma(last entry) + + Args: + key: JAX random key + inv_omega11: inverse of the (G-1) x (G-1) block + of the precision matrix + inv_C: inverse of the `C` matrix, + i.e., the precision matrix of the first `G-1` entries + scatter12: off-diagonal part of the scatter matrix column, shape (G-1,) + n_samples: number of samples, which controls the + shape parameter of the Gamma distribution + rate: the rate parameter of the Gamma distribution + + Returns: + Sampled column of length `G`. + """ + # Invert `inv_C` to obtain the variance + C = jnp.linalg.inv(inv_C) + + key_u, key_v = jrandom.split(key) + + u = jrandom.multivariate_normal(key_u, -C @ scatter12, C) + + shape = 1 + 0.5 * n_samples + v = jrandom.gamma(key_v, shape) / rate + + new_omega22 = v + jnp.einsum("g,gh,h->", u, inv_omega11, u) + + return jnp.append(u, new_omega22) + + +def prepare_data( + data: Optional[Float[Array, "points features"]], + scatter: Optional[Float[Array, "features features"]], + n_points: Optional[int], +) -> tuple[Float[Array, "features features"], int]: + """Generates the scatter matrix and the number of points. + + Args: + data: optional data matrix + scatter: optional scatter matrix. + Provide *either* `data` or `scatter` + n_points: optional number of points. 
+ Has to be provided whenever `scatter` is provided + + Raises: + ValueError, if the data do not align properly + or if both `data` and `scatter` are provided + """ + if data is None and (scatter is None or n_points is None): + raise ValueError( + "Not enough arguments provided. " + "Provide *either* data (n_points, n_features) " + "or both the scatter matrix (n_features, n_features) " + "and the number of points." + ) + if data is not None and (scatter is not None or n_points is not None): + raise ValueError( + "Too many arguments provided. " + "Provide *either* data (n_points, n_features) " + "or both the scatter matrix (n_features, n_features) " + "and the number of points." + ) + + # Case 1: We have the data + if data is not None: + if len(data.shape) != 2: + raise ValueError( + f"Data has to have shape (n_points, n_features), " + f"but has {data.shape}." + ) + + n_points = data.shape[0] + scatter = construct_scatter_matrix(data) + return scatter, n_points + + # Case 2: We have the scatter matrix and the number of points + assert n_points is not None, "Number of points not provided." + assert scatter is not None, "Scatter matrix not provided." 
+ + if len(scatter.shape) != 2 or scatter.shape[0] != scatter.shape[1]: + raise ValueError( + f"The scatter matrix has to be a square matrix, " + f"but has shape {scatter.shape}" + ) + + if n_points < 0: + raise ValueError("Number of points cannot be negative.") + return scatter, n_points diff --git a/src/jnotype/gaussian/_spike_and_slab.py b/src/jnotype/gaussian/_spike_and_slab.py new file mode 100644 index 0000000..75e4190 --- /dev/null +++ b/src/jnotype/gaussian/_spike_and_slab.py @@ -0,0 +1,376 @@ +"""This module implements the Gibbs sampler from + +Hao Wang, "Scaling it up: Stochastic search structure +learning in graphical models", Bayesian Analysis (2015) +""" + +from typing import NewType, Optional, Sequence + +import jax +import jax.numpy as jnp +import jax.random as jrandom + +import numpyro.distributions as dist + +from jaxtyping import Float, Array, Int + +import jnotype.gaussian._numeric as num +from jnotype.logistic._structure import _softmax_p1 +from jnotype.sampling import AbstractGibbsSampler, DatasetInterface +from jnotype._utils import JAXRNG + + +def _normal_logp(x: float, std: float) -> float: + """Evaluates log-PDF of `N(0, std^2)`$ at `x`""" + return dist.Normal(0.0, scale=std).log_prob(x) + + +def _sample_indicators( + key, + precision: Float[Array, "G G"], + pi: float, + std0: float, + std1: float, +) -> Int[Array, "G G"]: + """Samples the indicator matrix: + + Args: + key: JAX random key + precision: precision matrix of shape `(G, G)` + pi: value between 0 and 1 controlling the sparsity + (lower `pi` should result in sparser matrices) + std0: standard deviation of the spike prior component + std1: standard deviation of the slab prior component + + Returns: + an indicator matrix of shape (G, G). + Note that it is a *symmetric* matrix with zero diagonal. 
+ """ + G = precision.shape[0] + prec = num.utzd_to_vector(precision) + + logp_slab = _normal_logp(prec, std1) + jnp.log(pi) + logp_spike = _normal_logp(prec, std0) + jnp.log1p(-pi) + + p_slab = _softmax_p1(log_p0=logp_spike, log_p1=logp_slab) + indicators = jnp.asarray(jrandom.bernoulli(key, p=p_slab), dtype=int) + + a = num.vector_to_utzd(indicators, G) + return num.symmetrize_utzd(a) + + +def _generate_variance_matrix( + indicators: Int[Array, "G G"], + std0: float, + std1: float, +) -> Float[Array, "G G"]: + """Auxiliary function creating the variance matrix. + + Args: + indicators: symmetric indicator matrix with zero diagonal + std0: standard deviation of the spike prior component + std1: standard deviation of the slab prior component + """ + a = jnp.triu( + indicators * jnp.square(std1) + (1 - indicators) * jnp.square(std0), k=1 + ) + return num.symmetrize_utzd(a) + + +def _sample_last_precision_column( + key, + precision: Float[Array, "G G"], + scatter: Float[Array, "G G"], + variances: Float[Array, "G G"], + lambd: float, + n: int, +) -> Float[Array, " G"]: + """Samples the last column. + + Args: + key: JAX random key + precision: precision matrix + scatter: the scatter matrix + variances: variances matrix + (obtained using the latent indicators) + lambd: penalisation on the diagonal entries. + The larger `lambd`, the more shrinkage to 0 is encouraged. + n: number of data points + + Returns: + A sample from the conditional distribution of the last column (row) + of the precision matrix. 
+ """ + inv_omega11 = jnp.linalg.inv(precision[:-1, :-1]) # (G-1, G-1) + + v12 = variances[-1, :-1] # (G-1,) + s12 = scatter[-1, :-1] # (G-1,) + s22: float = scatter[-1, -1] + + inv_C = (s22 + lambd) * inv_omega11 + jnp.diag(jnp.reciprocal(v12)) + rate = 0.5 * (s22 + lambd) + + return num.sample_precision_column( + key, + inv_omega11=inv_omega11, + inv_C=inv_C, + scatter12=s12, + n_samples=n, + rate=rate, + ) + + +def _sample_precision_matrix_column_by_column( + key, + precision: Float[Array, "G G"], + scatter: Float[Array, "G G"], + variances: Float[Array, "G G"], + lambd: float, + n: int, +) -> Float[Array, "G G"]: + """Samples the precision matrix by sampling + columns one after the other. + + Args: + key: JAX random key + precision: precision matrix + scatter: the scatter matrix + variances: variances matrix + (obtained using the latent indicators) + lambd: penalisation on the diagonal entries. + The larger `lambd`, the more shrinkage to 0 is encouraged. + n: number of data points + + Returns: + A precision matrix. + """ + + def update_column(carry: tuple, k: int) -> tuple: + """Function sampling the `k`th column (row) and updating it. 
+ + Args: + carry: tuple (key, precision) + k: the index of the column (row) to be updated + """ + key, precision = carry + + # Reorder the variables + precision = num.swap_with_last(precision, k) + scatter_ = num.swap_with_last(scatter, k) + variances_ = num.swap_with_last(variances, k) + + # Sample the new last row/column + key, subkey = jrandom.split(key) + new_col = _sample_last_precision_column( + key=subkey, + precision=precision, + scatter=scatter_, + variances=variances_, + lambd=lambd, + n=n, + ) + # Update both the row and the column + precision = precision.at[:, -1].set(new_col) + precision = precision.at[-1, :].set(new_col) + + # Reorder the variables to the original order + precision = num.swap_with_last(precision, k) + + return (key, precision), None + + carry, _ = jax.lax.scan( + update_column, + (key, precision), + jnp.arange(precision.shape[0]), + ) + _, precision = carry + return precision + + +@jax.jit +def sample_indicators_and_precision( + key, + indicators: Int[Array, "G G"], + precision: Float[Array, "G G"], + scatter: Float[Array, "G G"], + lambd: float, + n_samples: int, + pi: float, + std0: float, + std1: float, +) -> tuple[Int[Array, "G G"], Float[Array, "G G"]]: + """Jointly samples indicator variables and precision matrix. + + Args: + key: JAX random key + indicators: current indicator matrix + precision: current precision matrix + scatter: the scatter matrix + lambd: penalisation on the diagonal entries. + The larger `lambd`, the more shrinkage to 0 is encouraged. 
+ n_samples: number of data points + pi: value between 0 and 1 controlling the sparsity + (lower `pi` should result in sparser matrices) + std0: standard deviation of the spike prior component + std1: standard deviation of the slab prior component + + Returns: + indicators: symmetric 0-1 matrix of shape (G, G) + precision: symmetric real matrix of shape (G, G) + """ + subkey_indicators, subkey_precision = jrandom.split(key) + + indicators = _sample_indicators( + key=subkey_indicators, + precision=precision, + pi=pi, + std0=std0, + std1=std1, + ) + + precision = _sample_precision_matrix_column_by_column( + key=subkey_precision, + precision=precision, + scatter=scatter, + variances=_generate_variance_matrix( + indicators=indicators, std0=std0, std1=std1 + ), + lambd=lambd, + n=n_samples, + ) + + return indicators, precision + + +Sample = NewType("Sample", dict) + + +class PrecisionMatrixSpikeAndSlabSampler(AbstractGibbsSampler): + """A Gibbs sampler to learn a precision matrix from centered (zero-mean) + normally distributed data. + """ + + def __init__( + self, + datasets: Sequence[DatasetInterface], + *, + # Data + data: Optional[Float[Array, "points features"]] = None, + scatter_matrix: Optional[Float[Array, "features features"]] = None, + n_points: Optional[int] = None, + # Gibbs sampling + warmup: int = 5_000, + steps: int = 10_000, + verbose: bool = False, + seed: int = 195, + # Initialisation + deterministic_init: bool = False, + # Prior hyperparameters + lambd: float = 1.0, + pi: Optional[float] = None, + std0: float = 0.1, + std1: float = 1.0, + ) -> None: + """ + + Args: + datasets: data sets in which the samples are stored + data: data, assumed to be multivariate normal with zero mean. + Shape (n_points, n_features) + scatter_matrix: scatter matrix (can be provided instead of data). 
+ Shape (n_features, n_features) + n_points: number of points in the data set + (use only when the scatter matrix is provided) + warmup: number of warmup steps in Gibbs sampling + steps: number of Gibbs steps after the warmup + verbose: whether the sampler should print out the sampling status + seed: random seed + deterministic_init: whether to use deterministic initialisation + lambd: prior parameter which regularises the diagonal entries + pi: prior parameter between 0 and 1 controlling the sparsity + (lower `pi` should result in sparser matrices). + By default (None) is set to `2 / (n_features - 1)`. + std0: standard deviation of the spike prior component + std1: standard deviation of the slab prior component + """ + super().__init__(datasets, warmup=warmup, steps=steps, verbose=verbose) + + self._deterministic_init = deterministic_init + + # Initialize a random number generator + self._jax_rng = JAXRNG(jax.random.PRNGKey(seed)) + + scatter, n_points = num.prepare_data( + data=data, + scatter=scatter_matrix, + n_points=n_points, + ) + + self._scatter_matrix: Float[Array, "features features"] = scatter + self._n_points: int = n_points + self._n_features: int = self._scatter_matrix.shape[0] + + if lambd <= 0: + raise ValueError(f"The lambd value has to be positive, but is {lambd}.") + self._lambd = lambd + + if pi is None: + pi = 2 / (self._n_features - 1) + + if pi <= 0 or pi >= 1: + raise ValueError( + f"The pi value has to be from the open inverval (0, 1), but is {pi}." 
+ ) + self._pi = pi + + if std0 <= 0: + raise ValueError(f"The std0 has to be positive, but is {std0}.") + if std1 <= 0: + raise ValueError(f"The std1 has to be positive, but is {std1}.") + self._std0 = std0 + self._std1 = std1 + + @classmethod + def dimensions(cls) -> Sample: + """The sites in each sample with annotated dimensions.""" + return { + "precision": ["features_dim0", "features_dim1"], + "indicators": ["features_dim0", "features_dim1"], + } + + def new_sample(self, sample: Sample) -> Sample: + """A new sample.""" + indicators, precision = sample_indicators_and_precision( + key=self._jax_rng.key, + indicators=sample["indicators"], + precision=sample["precision"], + scatter=self._scatter_matrix, + n_samples=self._n_points, + lambd=self._lambd, + pi=self._pi, + std0=self._std0, + std1=self._std1, + ) + return { + "precision": precision, + "indicators": indicators, + } + + def initialise(self) -> Sample: + """Initialises the sample.""" + G = self._n_features + + if self._deterministic_init: + return { + "precision": jnp.eye(G, dtype=float), + "indicators": jnp.zeros((G, G), dtype=int), + } + else: + coinflips = jrandom.bernoulli( + self._jax_rng.key, p=self._pi, shape=(G * (G - 1) // 2,) + ) + return { + "precision": jnp.eye(self._n_features) + * (0.5 + jrandom.gamma(self._jax_rng.key, 1.0) / 0.5), + "indicators": num.symmetrize_utzd(num.vector_to_utzd(coinflips, G)), + } diff --git a/src/jnotype/sampling/_sampler.py b/src/jnotype/sampling/_sampler.py index 5783533..7ec1f9f 100644 --- a/src/jnotype/sampling/_sampler.py +++ b/src/jnotype/sampling/_sampler.py @@ -38,7 +38,8 @@ def __init__( self.steps = steps self.verbose = verbose - @abc.abstractclassmethod + @classmethod + @abc.abstractmethod def dimensions(cls) -> dict: """Returns dictionary describing the dimensions, e.g.,: diff --git a/tests/test_variance.py b/tests/test_variance.py index f88d623..0b17cbd 100644 --- a/tests/test_variance.py +++ b/tests/test_variance.py @@ -87,8 +87,8 @@ def 
test_sample_inverse_gamma(n_samples: int, shape: float, scale: float) -> Non samples = _var.sample_inverse_gamma( key=key, n_points=n_samples, - a=shape, - b=scale, + shape=shape, + scale=scale, ) assert jnp.mean(samples) == pytest.approx(scale / (shape - 1), rel=0.01)