diff --git a/src/jnotype/__init__.py b/src/jnotype/__init__.py index 7ae9bfc..098de06 100644 --- a/src/jnotype/__init__.py +++ b/src/jnotype/__init__.py @@ -1,4 +1,5 @@ """Exploratory analysis of binary data.""" + import jnotype.bmm as bmm import jnotype.datasets as datasets import jnotype.sampling as sampling diff --git a/src/jnotype/_csp.py b/src/jnotype/_csp.py index c446758..ada045f 100644 --- a/src/jnotype/_csp.py +++ b/src/jnotype/_csp.py @@ -1,4 +1,5 @@ """Cumulative shrinkage prior.""" + import jax import jax.numpy as jnp import jax.scipy as jsp diff --git a/src/jnotype/_factor_analysis/_gibbs_backend.py b/src/jnotype/_factor_analysis/_gibbs_backend.py index 8ed08c5..afe7b52 100644 --- a/src/jnotype/_factor_analysis/_gibbs_backend.py +++ b/src/jnotype/_factor_analysis/_gibbs_backend.py @@ -1,5 +1,6 @@ """Sampling steps for all variables, apart from variances attributed to latent traits, which are sampled with CSP module.""" + from typing import Callable import jax diff --git a/src/jnotype/_factor_analysis/_inference.py b/src/jnotype/_factor_analysis/_inference.py index 01a2f6c..e34102c 100644 --- a/src/jnotype/_factor_analysis/_inference.py +++ b/src/jnotype/_factor_analysis/_inference.py @@ -1,4 +1,5 @@ """Sampling from the posterior distribution.""" + from typing import Callable from jaxtyping import Float, Array diff --git a/src/jnotype/_factor_analysis/_simulate.py b/src/jnotype/_factor_analysis/_simulate.py index 7cfd6ff..70fd1a7 100644 --- a/src/jnotype/_factor_analysis/_simulate.py +++ b/src/jnotype/_factor_analysis/_simulate.py @@ -1,4 +1,5 @@ """Simulate data sets.""" + from jaxtyping import Float, Array import jax.numpy as jnp diff --git a/src/jnotype/_utils.py b/src/jnotype/_utils.py index d2f3173..9aa4416 100644 --- a/src/jnotype/_utils.py +++ b/src/jnotype/_utils.py @@ -3,6 +3,7 @@ This file should be as small as possible. Appearing themes should be refactored and placed into separate modules.""" + import jax diff --git a/src/jnotype/_variance.py b/src/jnotype/_variance.py index e3abbf8..5bbe5ef 100644 --- a/src/jnotype/_variance.py +++ b/src/jnotype/_variance.py @@ -1,4 +1,5 @@ """Utilities for sampling variances.""" + from jax import random import jax import jax.numpy as jnp diff --git a/src/jnotype/bmm/__init__.py b/src/jnotype/bmm/__init__.py index 19f4573..8691fd1 100644 --- a/src/jnotype/bmm/__init__.py +++ b/src/jnotype/bmm/__init__.py @@ -1,4 +1,5 @@ """Bernoulli Mixture Model.""" + from jnotype.bmm._em import expectation_maximization from jnotype.bmm._gibbs import ( sample_mixing, diff --git a/src/jnotype/bmm/_em.py b/src/jnotype/bmm/_em.py index a9db452..8db696e 100644 --- a/src/jnotype/bmm/_em.py +++ b/src/jnotype/bmm/_em.py @@ -1,4 +1,5 @@ """The Expectation-Maximization algorithm for Bernoulli Mixture Model.""" + import dataclasses import time from typing import Optional @@ -68,7 +69,11 @@ def em_step( observed: Int[Array, "N K"], mixing: Float[Array, "K B"], proportions: Float[Array, " B"], -) -> tuple[Float[Array, "N B"], Float[Array, "K B"], Float[Array, " B"],]: +) -> tuple[ + Float[Array, "N B"], + Float[Array, "K B"], + Float[Array, " B"], +]: """The E and M step combined, for better JIT compiler optimisation. Args: diff --git a/src/jnotype/bmm/_gibbs.py b/src/jnotype/bmm/_gibbs.py index 0d6c0f4..80f537e 100644 --- a/src/jnotype/bmm/_gibbs.py +++ b/src/jnotype/bmm/_gibbs.py @@ -1,4 +1,5 @@ """Sampling cluster labels and proportions.""" + from typing import Optional, Sequence import jax diff --git a/src/jnotype/checks/_histograms.py b/src/jnotype/checks/_histograms.py index 35ba21c..6577570 100644 --- a/src/jnotype/checks/_histograms.py +++ b/src/jnotype/checks/_histograms.py @@ -1,4 +1,5 @@ """Plotting histograms of data.""" + from typing import Sequence, Union, Literal import matplotlib.pyplot as plt diff --git a/src/jnotype/datasets/__init__.py b/src/jnotype/datasets/__init__.py index dd99224..64d6e62 100644 --- a/src/jnotype/datasets/__init__.py +++ b/src/jnotype/datasets/__init__.py @@ -1,4 +1,5 @@ """Data sets.""" + from jnotype.datasets._simulation import BlockImagesSampler __all__ = [ diff --git a/src/jnotype/datasets/_simulation/__init__.py b/src/jnotype/datasets/_simulation/__init__.py index cf48708..46ee3aa 100644 --- a/src/jnotype/datasets/_simulation/__init__.py +++ b/src/jnotype/datasets/_simulation/__init__.py @@ -1,4 +1,5 @@ """Simulated data sets.""" + from jnotype.datasets._simulation._block_images import BlockImagesSampler __all__ = [ diff --git a/src/jnotype/datasets/_simulation/_block_images.py b/src/jnotype/datasets/_simulation/_block_images.py index 2f3b881..76bafc3 100644 --- a/src/jnotype/datasets/_simulation/_block_images.py +++ b/src/jnotype/datasets/_simulation/_block_images.py @@ -1,4 +1,5 @@ """Simulation of binary images using Bernoulli mixture model.""" + from typing import Optional from jaxtyping import Array, Float, Int diff --git a/src/jnotype/logistic/_binary_latent.py b/src/jnotype/logistic/_binary_latent.py index 6334688..0cbd67b 100644 --- a/src/jnotype/logistic/_binary_latent.py +++ b/src/jnotype/logistic/_binary_latent.py @@ -1,4 +1,5 @@ """Sample binary latent variables.""" + from functools import partial import jax diff --git a/src/jnotype/logistic/_polyagamma.py b/src/jnotype/logistic/_polyagamma.py index 4fb9144..fbfe5a4 100644 --- a/src/jnotype/logistic/_polyagamma.py +++ b/src/jnotype/logistic/_polyagamma.py @@ -1,4 +1,5 @@ """Logistic regression sampling utilities using PĆ³lya-Gamma augmentation.""" + from jax import random import jax import jax.numpy as jnp @@ -53,9 +54,9 @@ def _sample_coefficients( precision_matrices: Float[Array, "features covariates covariates"] = jax.vmap( jnp.diag )(jnp.reciprocal(prior_variance)) - posterior_covariances: Float[ - Array, "features covariates covariates" - ] = jnp.linalg.inv(x_omega_x + precision_matrices) + posterior_covariances: Float[Array, "features covariates covariates"] = ( + jnp.linalg.inv(x_omega_x + precision_matrices) + ) kappa: Float[Array, "points features"] = jnp.asarray(observed, dtype=float) - 0.5 diff --git a/src/jnotype/logistic/_structure.py b/src/jnotype/logistic/_structure.py index a28470b..92d9f12 100644 --- a/src/jnotype/logistic/_structure.py +++ b/src/jnotype/logistic/_structure.py @@ -1,4 +1,5 @@ """Sample structure (spike/slab distinction) variables.""" + from typing import Union import jax diff --git a/src/jnotype/logistic/logreg.py b/src/jnotype/logistic/logreg.py index 791b9e1..6d140aa 100644 --- a/src/jnotype/logistic/logreg.py +++ b/src/jnotype/logistic/logreg.py @@ -1,4 +1,5 @@ """Logistic regression utilities.""" + import jax import jax.numpy as jnp from jaxtyping import Int, Float, Array diff --git a/src/jnotype/pyramids/_sampler_csp.py b/src/jnotype/pyramids/_sampler_csp.py index da9cdc1..9895675 100644 --- a/src/jnotype/pyramids/_sampler_csp.py +++ b/src/jnotype/pyramids/_sampler_csp.py @@ -1,6 +1,7 @@ """Sampler for two-layer Bayesian pyramids with cumulative shrinkage process (CSP) prior on latent binary codes.""" + from typing import Optional, Sequence, Union, NewType import jax diff --git a/src/jnotype/pyramids/_sampler_fixed.py b/src/jnotype/pyramids/_sampler_fixed.py index 61c17bd..9951563 100644 --- a/src/jnotype/pyramids/_sampler_fixed.py +++ b/src/jnotype/pyramids/_sampler_fixed.py @@ -1,5 +1,6 @@ """Sampler for two-layer Bayesian pyramids with fixed number of latent binary codes.""" + from typing import Optional, Sequence, Union, NewType import jax diff --git a/src/jnotype/sampling/__init__.py b/src/jnotype/sampling/__init__.py index ddef699..04af43d 100644 --- a/src/jnotype/sampling/__init__.py +++ b/src/jnotype/sampling/__init__.py @@ -1,4 +1,5 @@ """Generic utilities for sampling.""" + from jnotype.sampling._chunker import ( DatasetInterface, ListDataset, diff --git a/src/jnotype/sampling/_chunker.py b/src/jnotype/sampling/_chunker.py index 119d1cc..ad22982 100644 --- a/src/jnotype/sampling/_chunker.py +++ b/src/jnotype/sampling/_chunker.py @@ -1,4 +1,5 @@ """Utilities for saving samples in chunks, to limit RAM usage.""" + import abc from datetime import datetime diff --git a/src/jnotype/sampling/_sampler.py b/src/jnotype/sampling/_sampler.py index 7191c15..5783533 100644 --- a/src/jnotype/sampling/_sampler.py +++ b/src/jnotype/sampling/_sampler.py @@ -1,4 +1,5 @@ """Generic Gibbs sampler.""" + import abc import logging import time