From ced040ce580035604e84d3ba010e0bc1d275028e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Wed, 5 Jun 2024 11:38:42 +0200 Subject: [PATCH] Frequentist infererence utilities in JAX (#8) * Some prototypes * Add parameterization. * Add total loss * Silence ruff false positive errors * Add functions for calculating standard errors and predictions * Add multistart optimization * Minor changes * Loss normalization * going back in the sewer * Pin PyMC version * Trying to configure ruff * Add a prototype of finding MAP using JAX * Add NumPyro * Add MCMC prototype * Address Ruff --------- Co-authored-by: dr-david --- examples/frequentist_fitting-JAX.py | 606 ++++++++++++++++++++++++++++ examples/frequentist_fitting.py | 18 +- examples/splines_demonstration.py | 3 +- pyproject.toml | 6 +- src/covvfit/__init__.py | 5 +- src/covvfit/_frequentist.py | 4 +- src/covvfit/_frequentist_jax.py | 273 +++++++++++++ src/covvfit/plotting/__init__.py | 2 +- src/covvfit/plotting/_timeseries.py | 4 +- src/covvfit/simulation/_sde.py | 4 +- tests/simulation/test_sde.py | 2 - tests/test_frequentist_jax.py | 48 +++ 12 files changed, 946 insertions(+), 29 deletions(-) create mode 100644 examples/frequentist_fitting-JAX.py create mode 100644 src/covvfit/_frequentist_jax.py create mode 100644 tests/test_frequentist_jax.py diff --git a/examples/frequentist_fitting-JAX.py b/examples/frequentist_fitting-JAX.py new file mode 100644 index 0000000..09a8891 --- /dev/null +++ b/examples/frequentist_fitting-JAX.py @@ -0,0 +1,606 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.16.1 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% +import covvfit._frequentist as freq +import covvfit._frequentist_jax as fj +import covvfit._preprocess_abundances as prec +import covvfit.plotting._timeseries as plot_ts +import jax +import jax.nn as nn +import jax.numpy as jnp +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker +import numpy as np +import numpyro +import numpyro.distributions as dist +import pandas as pd +import pymc as pm +from numpyro.infer import MCMC, NUTS +from scipy.special import expit + +variants_full = [ + "B.1.1.7", + "B.1.351", + "P.1", + "B.1.617.2", + "BA.1", + "BA.2", + "BA.4", + "BA.5", + "BA.2.75", + "BQ.1.1", + "XBB.1.5", + "XBB.1.9", + "XBB.1.16", + "XBB.2.3", + "EG.5", + "BA.2.86", + "JN.1", +] + +variants = ["XBB.1.5", "XBB.1.9", "XBB.1.16", "XBB.2.3", "EG.5", "BA.2.86", "JN.1"] + +variants_other = [i for i in variants_full if i not in variants] + + +cities = [ + "Lugano (TI)", + "Zürich (ZH)", + "Chur (GR)", + "Altenrhein (SG)", + "Laupen (BE)", + "Genève (GE)", + "Basel (BS)", + "Porrentruy (JU)", + "Lausanne (VD)", + "Bern (BE)", + "Luzern (LU)", + "Solothurn (SO)", + "Neuchâtel (NE)", + "Schwyz (SZ)", +] + +colors_covsp = { + "B.1.1.7": "#D16666", + "B.1.351": "#FF6666", + "P.1": "#FFB3B3", + "B.1.617.1": "#66C265", + "B.1.617.2": "#66A366", + "BA.1": "#A366A3", + "BA.2": "#CFAFCF", + "BA.4": "#8a66ff", + "BA.5": "#585eff", + "BA.2.75": "#008fe0", + "BQ.1.1": "#ac00e0", + "XBB.1.9": "#bb6a33", + "XBB.1.5": "#ff5656", + "XBB.1.16": "#e99b30", + "XBB.2.3": "#f5e424", + "EG.5": "#b4e80b", + "BA.2.86": "#FF20E0", + "JN.1": "#00e9ff", # improv + "undetermined": "#969696", +} + + +# %% +DATA_PATH = "../data/robust_deconv2_noisy14.csv" + +data = prec.load_data(DATA_PATH) +variants2 = ["other"] + variants +data2 = 
prec.preprocess_df(data, cities, variants_full, date_min="2023-04-01") +data2["other"] = data2[variants_other].sum(axis=1) +data2[variants2] = data2[variants2].div(data2[variants2].sum(axis=1), axis=0) + +ts_lst, ys_lst = prec.make_data_list(data2, cities, variants2) +ts_lst, ys_lst2 = prec.make_data_list(data2, cities, variants) # TODO: To be removed? + +data = [fj.CityData(ts=ts, ys=ys.T, n=1) for ts, ys in zip(ts_lst, ys_lst)] + +n_cities = len(data) +n_variants = len(variants2) + + +# %% +loss_loglike = fj.construct_total_loss(data, average_loss=False) + + +def loss_prior(x: jnp.ndarray, mu: float = 0.15, sigma: float = 0.1) -> float: + # Shape and rate of gamma distribution + alpha = jnp.square(mu / sigma) + beta = mu / jnp.square(sigma) + + # Return -log_prob(growths) + g = fj.get_relative_growths(x, n_variants=n_variants) + return -jnp.sum(dist.Gamma(alpha, beta).log_prob(g)) + + +def optim_to_param(y): + midpoints = fj.get_relative_midpoints(y, n_variants=n_variants) + unconstrained_rates = fj.get_relative_growths(y, n_variants=n_variants) + rates = nn.softplus(unconstrained_rates) + return fj.construct_theta(relative_growths=rates, relative_midpoints=midpoints) + + +def loss_total(y): + x = optim_to_param(y) + return loss_loglike(x) + loss_prior(x) + + +theta0 = fj.construct_theta0(n_cities=n_cities, n_variants=n_variants) + +solution = fj.jax_multistart_minimize( + loss_total, + theta0, + random_seed=1, + n_starts=20, +) + +print(fj.get_relative_growths(optim_to_param(solution.best.x), n_variants=n_variants)) +print(solution.best.fun) + + +# %% +def model(mu: float = 0.15, sigma: float = 0.1): + midpoints = numpyro.sample( + "midpoints", + dist.Normal( + jnp.zeros_like(fj.get_relative_midpoints(theta0, n_variants=n_variants)), + 100, + ), + ) + + alpha = jnp.square(mu / sigma) + beta = mu / jnp.square(sigma) + growths = numpyro.sample( + "growths", + dist.Gamma( + alpha + * jnp.ones_like(fj.get_relative_growths(theta0, n_variants=n_variants)), + beta, + ), + ) + + # growths = numpyro.sample("growths", dist.TruncatedNormal(mu * jnp.ones_like(fj.get_relative_growths(theta0, n_variants=n_variants)), sigma, low=0.01, high=1.0)) + + x = fj.construct_theta(relative_growths=growths, relative_midpoints=midpoints) + numpyro.factor("loglikelihood", -loss_loglike(x)) + + +mcmc = MCMC(NUTS(model), num_warmup=1_000, num_samples=1_000, num_chains=4) +mcmc.run(jax.random.PRNGKey(0)) + +# %% +mcmc.print_summary() + +# %% + +rng = np.random.default_rng(42) + +mu = 0.15 +sigma = 0.1 + +alpha = jnp.square(mu / sigma) +beta = mu / jnp.square(sigma) + +samples = rng.gamma(alpha, 1 / beta, size=10_000) + +print(np.mean(samples)) +print(np.std(samples)) + +plt.hist(samples, bins=np.linspace(0, 0.5, 50), color="salmon", alpha=0.4, density=True) +plt.hist( + mcmc.get_samples()["growths"][:, -1], + bins=np.linspace(0, 0.5, 50), + color="darkblue", + alpha=0.4, + density=True, +) + + +np.mean(samples < 0.01) + +# %% +## This model takes into account the complement of the variants to be monitored, and sets its fitness to zero +## However, due to the pm.math.concatenate operation, we cannot use it for finding the hessian + + +def create_model_fixed2( + ts_lst, + ys_lst, + n=1.0, + coords={ + "cities": [], + "variants": [], + }, + n_pred=60, +): + """function to create a fixed effect model with varying intercepts and one rate vector""" + with pm.Model(coords=coords) as model: + midpoint_var = pm.Normal( + "midpoint", mu=0.0, sigma=300.0, dims=["cities", "variants"] + ) + rate_var = pm.Gamma("rate", mu=0.15, 
sigma=0.1, dims="variants") + + # Kaan's trick to avoid overflows + def softmax(x, rates, midpoints): + E = rates[:, None] * x + midpoints[:, None] + E_max = E.max(axis=0) + un_norm = pm.math.exp(E - E_max) + return un_norm / (pm.math.sum(un_norm, axis=0)) + + ys_smooth = [ + softmax( + ts_lst[i], + pm.math.concatenate([[0], rate_var]), + pm.math.concatenate([[0], midpoint_var[i, :]]), + ) + for i, city in enumerate(coords["cities"]) + ] + + # make Multinom/n likelihood + def log_likelihood(y, p, n): + # return n*pm.math.sum(y * pm.math.log(p), axis=0) + n*(1-pm.math.sum(y, axis=0))*pm.math.log(1-pm.math.sum(p, axis=0)) + return n * pm.math.sum(y * pm.math.log(p), axis=0) + + [ + pm.DensityDist( + f"ys_noisy_{city}", + ys_smooth[i], + n, + logp=log_likelihood, + observed=ys_lst[i], + ) + for i, city in enumerate(coords["cities"]) + ] + + return model + + +# %% +with create_model_fixed2( + ts_lst, + ys_lst, + coords={ + "cities": cities, + "variants": variants, + }, +): + model_map_fixed = pm.find_MAP(maxeval=50000, seed=12313) + + +# %% +print(model_map_fixed["rate"]) + +# %% +## This model takes into account the complement of the variants to be monitored, and sets its fitness to zero +## It has some numerical instabilities that make it not suitable for finding the MAP or MLE, but I use it for the Hessian + + +def create_model_fixed3( + ts_lst, + ys_lst, + n=1.0, + coords={ + "cities": [], + "variants": [], + }, + n_pred=60, +): + """function to create a fixed effect model with varying intercepts and one rate vector""" + with pm.Model(coords=coords) as model: + midpoint_var = pm.Normal( + "midpoint", mu=0.0, sigma=1500.0, dims=["cities", "variants"] + ) + rate_var = pm.Gamma("rate", mu=0.15, sigma=0.1, dims="variants") + + # Kaan's trick to avoid overflows + def softmax_1(x, rates, midpoints): + E = rates[:, None] * x + midpoints[:, None] + E_max = E.max(axis=0) + un_norm = pm.math.exp(E - E_max) + return un_norm / (pm.math.exp(-E_max) + pm.math.sum(un_norm, axis=0)) + + ys_smooth = [ + softmax_1(ts_lst[i], rate_var, midpoint_var[i, :]) + for i, city in enumerate(coords["cities"]) + ] + + # make Multinom/n likelihood + def log_likelihood(y, p, n): + return n * pm.math.sum(y * pm.math.log(p), axis=0) + n * ( + 1 - pm.math.sum(y, axis=0) + ) * pm.math.log(1 - pm.math.sum(p, axis=0)) + + # return n*pm.math.sum(y * pm.math.log(p), axis=0) + + [ + pm.DensityDist( + f"ys_noisy_{city}", + ys_smooth[i], + n, + logp=log_likelihood, + observed=ys_lst[i], + ) + for i, city in enumerate(coords["cities"]) + ] + + return model + + +# %% +with create_model_fixed3( + ts_lst, + ys_lst2, + coords={ + "cities": cities, + "variants": variants, + }, +): + model_hessian_fixed = pm.find_hessian(model_map_fixed) + +# %% +y_fit_lst = freq.fitted_values(ts_lst, model_map_fixed, cities) +ts_pred_lst, y_pred_lst = freq.pred_values( + [i.max() - 1 for i in ts_lst], model_map_fixed, cities, horizon=60 +) +pearson_r_lst, overdisp_list, overdisp_fixed = freq.compute_overdispersion( + ys_lst2, y_fit_lst, cities +) +( + fitness_diff, + fitness_diff_se, + fitness_diff_lower, + fitness_diff_upper, +) = freq.make_fitness_confints( + model_map_fixed["rate"], model_hessian_fixed, overdisp_fixed, g=7.0 +) + +# %% [markdown] +# ## Plot + +# %% +fig, axes_tot = plt.subplots(5, 3, figsize=(22, 15)) +# colors = default_cmap = plt.cm.get_cmap('tab10').colors +colors = [colors_covsp[var] for var in variants] +# axes=[axes_tot] +axes = axes_tot.flatten() +p_variants = len(variants) +p_params = model_hessian_fixed.shape[0] 
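+
+# Optional minimal sketch (assumes the JAX multistart fit above): the growth-advantage
+# standard errors could also be approximated without PyMC, by inverting the Hessian of
+# the negative log-likelihood at the optimum and applying the delta method via
+# `fj.get_standard_errors`.
+theta_star = optim_to_param(solution.best.x)
+jax_covariance = jnp.linalg.inv(jax.hessian(loss_loglike)(theta_star))
+jax_jacobian = jax.jacobian(
+    lambda th: fj.get_relative_advantages(th, n_variants=n_variants)
+)(theta_star)
+advantage_se = fj.get_standard_errors(jacobian=jax_jacobian, covariance=jax_covariance)
+ci_multiplier = fj.StandardErrorsMultipliers.convert(0.95)  # approximately 1.96
+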
+model_hessian_inv = np.linalg.inv(model_hessian_fixed) + +for k, city in enumerate(cities): + ax = axes[k + 1] + y_fit = y_fit_lst[k] + ts = ts_lst[k] + ts_pred = ts_pred_lst[k] + y_pred = y_pred_lst[k] + ys = ys_lst2[k] + hessian_indices = np.concatenate( + [ + np.arange(p_variants) + k * p_variants, + np.arange(model_hessian_fixed.shape[0] - p_variants, p_params), + ] + ) + tmp_hessian = model_hessian_inv[hessian_indices, :][:, hessian_indices] + y_fit_logit = np.log(y_fit) - np.log(1 - y_fit) + logit_se = np.array( + [ + freq.project_se( + model_map_fixed["rate"], + model_map_fixed["midpoint"][k, :], + t, + tmp_hessian, + overdisp_list[k], + ) + for t in ts + ] + ).T + y_pred_logit = np.log(y_pred) - np.log(1 - y_pred) + logit_se_pred = np.array( + [ + freq.project_se( + model_map_fixed["rate"], + model_map_fixed["midpoint"][k, :], + t, + tmp_hessian, + overdisp_list[k], + ) + for t in ts_pred + ] + ).T + + for i, variant in enumerate(variants): + # grid + ax.vlines( + x=( + pd.date_range(start="2021-11-01", end="2024-02-01", freq="MS") + - pd.to_datetime("2023-01-01") + ).days, + ymin=-0.05, + ymax=1.05, + color="grey", + alpha=0.02, + ) + ax.hlines( + y=[0, 0.25, 0.5, 0.75, 1], + xmin=(pd.to_datetime("2021-10-10") - pd.to_datetime("2023-01-01")).days, + xmax=(pd.to_datetime("2024-02-20") - pd.to_datetime("2023-01-01")).days, + color="grey", + alpha=0.02, + ) + ax.fill_between(x=ts_pred, y1=0, y2=1, color="grey", alpha=0.01) + + # plot fitted + sorted_indices = np.argsort(ts) + ax.plot( + ts[sorted_indices], + y_fit[i, :][sorted_indices], + color=colors[i], + label="fit", + ) + # plot pred + ax.plot(ts_pred, y_pred[i, :], color=colors[i], linestyle="--", label="predict") + # plot confints + ax.fill_between( + ts[sorted_indices], + expit( + y_fit_logit[i, :][sorted_indices] + - 1.96 * logit_se[i, :][sorted_indices] + ), + expit( + y_fit_logit[i, :][sorted_indices] + + 1.96 * logit_se[i, :][sorted_indices] + ), + color=colors[i], + alpha=0.2, + label="Confidence band", + ) + ax.fill_between( + ts_pred, + expit(y_pred_logit[i, :] - 1.96 * logit_se_pred[i, :]), + expit(y_pred_logit[i, :] + 1.96 * logit_se_pred[i, :]), + color=colors[i], + alpha=0.2, + label="Confidence band", + ) + + # plot empirical + ax.scatter(ts, ys[i, :], label="observed", alpha=0.5, color=colors[i], s=4) + + ax.set_ylim((-0.05, 1.05)) + ax.set_xticks( + ( + pd.date_range(start="2021-11-01", end="2023-12-01", freq="MS") + - pd.to_datetime("2023-01-01") + ).days + ) + date_formatter = ticker.FuncFormatter(plot_ts.num_to_date) + ax.xaxis.set_major_formatter(date_formatter) + tick_positions = [0, 0.5, 1] + tick_labels = ["0%", "50%", "100%"] + ax.set_yticks(tick_positions) + ax.set_yticklabels(tick_labels) + ax.set_ylabel("relative abundances") + ax.set_xlim( + ( + pd.to_datetime(["2023-03-15", "2024-01-05"]) + - pd.to_datetime("2023-01-01") + ).days + ) + ax.set_title(f"{city}") + +## Plot estimates + +ax = axes[0] + +( + fitness_diff, + fitness_diff_se, + fitness_diff_lower, + fitness_diff_upper, +) = freq.make_fitness_confints( + model_map_fixed["rate"], model_hessian_fixed, overdisp_fixed, g=7.0 +) + +fitness_diff = fitness_diff * 100 +fitness_diff_lower = fitness_diff_lower * 100 +fitness_diff_upper = fitness_diff_upper * 100 + +# Get the indices for the upper triangle, starting at the diagonal (k=0) +upper_triangle_indices = np.triu_indices_from(fitness_diff, k=0) + +# Assign np.nan to the upper triangle including the diagonal +fitness_diff[upper_triangle_indices] = np.nan 
+fitness_diff_lower[upper_triangle_indices] = np.nan +fitness_diff_upper[upper_triangle_indices] = np.nan + +fitness_diff[:-2, :] = np.nan +fitness_diff_lower[:-2, :] = np.nan +fitness_diff_upper[:-2, :] = np.nan + +# Calculate the error (distance from the point to the error bar limit) +error = np.array( + [ + fitness_diff - fitness_diff_lower, # Lower error + fitness_diff_upper - fitness_diff, # Upper error + ] +) + +# Define the width of the offset +offset_width = 0.1 +num_sets = fitness_diff.shape[0] +# num_sets = 2 +mid = (num_sets - 1) / 2 + +# grid +ax.vlines( + x=np.arange(len(variants) - 1), + ymin=np.nanmin(fitness_diff_lower), + ymax=np.nanmax(fitness_diff_upper), + color="grey", + alpha=0.2, +) +ax.hlines( + y=np.arange(-25, 126, step=25), + xmin=-0.5, + xmax=len(variants) - 2 + 0.5, + color="grey", + alpha=0.2, +) + +# Plot each set of points with error bars +for i, y_vals in enumerate(fitness_diff): + # Calculate offset for each set + offset = (i - mid) * offset_width + # Create an array of x positions for this set + # x_positions = np.arange(len(variants)) + offset + x_positions = np.arange(len(variants)) + offset - 0.25 + # We need to transpose the error array to match the shape of y_vals + ax.errorbar( + x_positions, + y_vals, + yerr=error[:, i, :], + fmt="o", + label=variants[i], + color=colors_covsp[variants[i]], + ) + +# Set the x-ticks to be at the middle of the groups of points +ax.set_xticks(np.arange(len(variants) - 1)) +ax.set_xticklabels(variants[:-1]) + +# Add some labels and a legend +ax.set_xlabel("Variants") +ax.set_ylabel("% weekly growth advantage") +ax.set_title("growth advantages") + + +fig.tight_layout() +fig.legend( + handles=plot_ts.make_legend(colors, variants), + loc="lower center", + ncol=9, + bbox_to_anchor=(0.5, -0.04), + frameon=False, +) + + +plt.savefig("growth_rates20231108.pdf", bbox_inches="tight") + +plt.show() + + +# %% diff --git a/examples/frequentist_fitting.py b/examples/frequentist_fitting.py index 02373cb..3a6a865 100644 --- a/examples/frequentist_fitting.py +++ b/examples/frequentist_fitting.py @@ -13,21 +13,15 @@ # --- # %% -import pandas as pd -import pymc as pm - -import numpy as np - -import matplotlib.ticker as ticker -import matplotlib.pyplot as plt - -from scipy.special import expit - - import covvfit._frequentist as freq import covvfit._preprocess_abundances as prec import covvfit.plotting._timeseries as plot_ts - +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker +import numpy as np +import pandas as pd +import pymc as pm +from scipy.special import expit variants_full = [ "B.1.1.7", diff --git a/examples/splines_demonstration.py b/examples/splines_demonstration.py index 57b1b5b..23d9f84 100644 --- a/examples/splines_demonstration.py +++ b/examples/splines_demonstration.py @@ -1,10 +1,9 @@ """Bayesian regression using B-splines.""" +import covvfit as cv import matplotlib.pyplot as plt import numpy as np import pymc as pm -import covvfit as cv - def create_model( xs: np.ndarray, diff --git a/pyproject.toml b/pyproject.toml index 4ec3983..f2b6b85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ numpy = "==1.24.3" pytensor = "==2.11.1" pymc = "==5.3.0" seaborn = "^0.13.2" +numpyro = "^0.14.0" [tool.poetry.group.dev] optional = true @@ -36,11 +37,14 @@ jupyterlab = "^4.1.6" [tool.ruff] exclude = [".venv"] -ignore = ["E501"] +select = ["E", "F", "I001"] +ignore = ["E721", "E731", "F722", "E501"] +# ignore-init-module-imports = true [tool.jupytext] formats = "ipynb,py:percent" + [build-system] 
requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/src/covvfit/__init__.py b/src/covvfit/__init__.py index 1a9545c..d395c2a 100644 --- a/src/covvfit/__init__.py +++ b/src/covvfit/__init__.py @@ -1,8 +1,7 @@ -from covvfit._splines import create_spline_matrix -from covvfit._preprocess_abundances import make_data_list, preprocess_df, load_data - import covvfit._frequentist as freq import covvfit.plotting as plot +from covvfit._preprocess_abundances import load_data, make_data_list, preprocess_df +from covvfit._splines import create_spline_matrix VERSION = "0.1.0" diff --git a/src/covvfit/_frequentist.py b/src/covvfit/_frequentist.py index c647fcd..3e1bc1c 100644 --- a/src/covvfit/_frequentist.py +++ b/src/covvfit/_frequentist.py @@ -1,9 +1,7 @@ """utilities to fit frequentist models""" -import pymc as pm - import numpy as np - +import pymc as pm __all__ = [ "create_model_fixed", diff --git a/src/covvfit/_frequentist_jax.py b/src/covvfit/_frequentist_jax.py new file mode 100644 index 0000000..c97b41a --- /dev/null +++ b/src/covvfit/_frequentist_jax.py @@ -0,0 +1,273 @@ +"""Frequentist fitting functions powered by JAX.""" +import dataclasses +from typing import Callable, NamedTuple, Sequence + +import jax +import jax.numpy as jnp +import numpy as np +from jaxtyping import Array, Float +from scipy import optimize + + +def calculate_linear( + ts: Float[Array, " *batch"], + midpoints: Float[Array, " variants"], + growths: Float[Array, " variants"], +) -> Float[Array, "*batch variants"]: + shape = (1,) * ts.ndim + (-1,) + m = midpoints.reshape(shape) + g = growths.reshape(shape) + + return (ts[..., None] - m) * g + + +_Float = float | Float[Array, " "] + + +def calculate_logps( + ts: Float[Array, " *batch"], + midpoints: Float[Array, " variants"], + growths: Float[Array, " variants"], +) -> Float[Array, "*batch variants"]: + linears = calculate_linear( + ts=ts, + midpoints=midpoints, + growths=growths, + ) + return jax.nn.log_softmax(linears, axis=-1) + + +def loss( + y: Float[Array, "*batch variants"], + logp: Float[Array, "*batch variants"], + n: _Float, +) -> Float[Array, " *batch"]: + # Note: we want loss (lower is better), rather than + # total loglikelihood (higher is better), + # so we add the negative sign. 
+ return -jnp.sum(n * y * logp, axis=-1) + + +class CityData(NamedTuple): + ts: Float[Array, " timepoints"] + ys: Float[Array, "timepoints variants"] + n: _Float + + +_ThetaType = Float[Array, "(cities+1)*(variants-1)"] + + +def add_first_variant(vec: Float[Array, " variants-1"]) -> Float[Array, " variants"]: + return jnp.concatenate([jnp.zeros_like(vec)[0:1], vec]) + + +def construct_total_loss( + cities: Sequence[CityData], + average_loss: bool = False, +) -> Callable[[_ThetaType], _Float]: + cities = tuple(cities) + n_variants = cities[0].ys.shape[-1] + for city in cities: + assert ( + city.ys.shape[-1] == n_variants + ), "All cities must have the same number of variants" + + if average_loss: + n_points_total = 1.0 * sum(city.ts.shape[0] for city in cities) + else: + n_points_total = 1.0 + + def total_loss(theta: _ThetaType) -> _Float: + rel_growths = get_relative_growths(theta, n_variants=n_variants) + rel_midpoints = get_relative_midpoints(theta, n_variants=n_variants) + + growths = add_first_variant(rel_growths) + return ( + jnp.sum( + jnp.asarray( + [ + loss( + y=city.ys, + n=city.n, + logp=calculate_logps( + ts=city.ts, + midpoints=add_first_variant(midp), + growths=growths, + ), + ).sum() + for midp, city in zip(rel_midpoints, cities) + ] + ) + ) + / n_points_total + ) + + return total_loss + + +def construct_theta( + relative_growths: Float[Array, " variants-1"], + relative_midpoints: Float[Array, "cities variants-1"], +) -> _ThetaType: + flattened_midpoints = relative_midpoints.flatten() + theta = jnp.concatenate([relative_growths, flattened_midpoints]) + return theta + + +def get_relative_growths( + theta: _ThetaType, + n_variants: int, +) -> Float[Array, " variants-1"]: + return theta[: n_variants - 1] + + +def get_relative_midpoints( + theta: _ThetaType, + n_variants: int, +) -> Float[Array, "cities variants-1"]: + n_cities = theta.shape[0] // (n_variants - 1) - 1 + return theta[n_variants - 1 :].reshape(n_cities, n_variants - 1) + + +class StandardErrorsMultipliers(NamedTuple): + CI95: float = 1.96 + + @staticmethod + def convert(confidence: float) -> float: + """Calculates the multiplier for a given confidence level. + + Example: + StandardErrorsMultipliers.convert(0.95) # 1.9599 + """ + return float(jax.scipy.stats.norm.ppf((1 + confidence) / 2)) + + +def get_standard_errors( + jacobian: Float[Array, "*output_shape n_inputs"], + covariance: Float[Array, "n_inputs n_inputs"], +) -> Float[Array, " *output_shape"]: + """Delta method to calculate standard errors of a function + from `n_inputs` to `output_shape`. + + Args: + jacobian: Jacobian of the function to be fitted, shape (output_shape, n_inputs) + covariance: Covariance matrix of the inputs, shape (n_inputs, n_inputs) + + Returns: + Standard errors of the fitted parameters, shape (output_shape,) + + Note: + `output_shape` can be a vector, in which case the output is a vector + of standard errors or a tensor of any other shape, + in which case the output is a tensor of standard errors for each output + coordinate. + """ + return jnp.sqrt(jnp.einsum("...L,KL,...K -> ...", jacobian, covariance, jacobian)) + + +def triangular_mask(n_variants, valid_value: float = 0, masked_value: float = jnp.nan): + """Creates a triangular mask. 
Helpful for masking out redundant parameters + in anti-symmetric matrices.""" + a = jnp.arange(n_variants) + nan_mask = jnp.where(a[:, None] < a[None, :], valid_value, masked_value) + return nan_mask + + +def get_relative_advantages(theta, n_variants: int): + # Shape (n_variants-1,) describing relative advantages + # over the 0th variant + rel_growths = get_relative_growths(theta, n_variants=n_variants) + + growths = jnp.concatenate((jnp.zeros(1, dtype=rel_growths.dtype), rel_growths)) + diffs = growths[None, :] - growths[:, None] + return diffs + + +def get_softmax_predictions( + theta: _ThetaType, n_variants: int, city_index: int, ts: Float[Array, " timepoints"] +) -> Float[Array, "timepoints variants"]: + rel_growths = get_relative_growths(theta, n_variants=n_variants) + growths = add_first_variant(rel_growths) + + rel_midpoints = get_relative_midpoints(theta, n_variants=n_variants) + midpoints = add_first_variant(rel_midpoints[city_index]) + + y_linear = calculate_linear( + ts=ts, + midpoints=midpoints, + growths=growths, + ) + + y_softmax = jax.nn.softmax(y_linear, axis=-1) + return y_softmax + + +def get_logit_predictions( + theta: _ThetaType, + n_variants: int, + city_index: int, + ts: Float[Array, " timepoints"], +) -> Float[Array, "timepoints variants"]: + return jax.scipy.special.logit( + get_softmax_predictions( + theta=theta, + n_variants=n_variants, + city_index=city_index, + ts=ts, + ) + ) + + +@dataclasses.dataclass +class OptimizeMultiResult: + x: np.ndarray + fun: float + best: optimize.OptimizeResult + runs: list[optimize.OptimizeResult] + + +def construct_theta0( + n_cities: int, + n_variants: int, +) -> _ThetaType: + return np.zeros((n_cities * (n_variants - 1) + n_variants - 1,), dtype=float) + + +def jax_multistart_minimize( + loss_fn, + theta0: np.ndarray, + n_starts: int = 10, + random_seed: int = 42, + maxiter: int = 10_000, +) -> OptimizeMultiResult: + # Create loss function and its gradient + _loss_grad_fun = jax.jit(jax.value_and_grad(loss_fn)) + + def loss_grad_fun(theta): + loss, grad = _loss_grad_fun(theta) + return np.asarray(loss), np.asarray(grad) + + solutions: list[optimize.OptimizeResult] = [] + rng = np.random.default_rng(random_seed) + + for i in range(1, n_starts + 1): + starting_point = theta0 + (i / n_starts) * rng.normal(size=theta0.shape) + sol = optimize.minimize( + loss_grad_fun, jac=True, x0=starting_point, options={"maxiter": maxiter} + ) + solutions.append(sol) + + # Find the optimal solution + optimal_index = None + optimal_value = np.inf + for i, sol in enumerate(solutions): + if sol.fun < optimal_value: + optimal_index = i + optimal_value = sol.fun + + return OptimizeMultiResult( + best=solutions[optimal_index], + x=solutions[optimal_index].x, + fun=solutions[optimal_index].fun, + runs=solutions, + ) diff --git a/src/covvfit/plotting/__init__.py b/src/covvfit/plotting/__init__.py index 0a01bc9..2940664 100644 --- a/src/covvfit/plotting/__init__.py +++ b/src/covvfit/plotting/__init__.py @@ -1,7 +1,7 @@ """Plotting functionalities.""" from covvfit.plotting._grid import plot_grid, set_axis_off from covvfit.plotting._simplex import plot_on_simplex -from covvfit.plotting._timeseries import make_legend, num_to_date, colors_covsp +from covvfit.plotting._timeseries import colors_covsp, make_legend, num_to_date __all__ = [ "plot_on_simplex", diff --git a/src/covvfit/plotting/_timeseries.py b/src/covvfit/plotting/_timeseries.py index b0130f1..6add266 100644 --- a/src/covvfit/plotting/_timeseries.py +++ b/src/covvfit/plotting/_timeseries.py @@ -1,9 
+1,7 @@ """utilities to plot""" -import pandas as pd - import matplotlib.lines as mlines import matplotlib.patches as mpatches - +import pandas as pd colors_covsp = { "B.1.1.7": "#D16666", diff --git a/src/covvfit/simulation/_sde.py b/src/covvfit/simulation/_sde.py index 0364d91..87b2327 100644 --- a/src/covvfit/simulation/_sde.py +++ b/src/covvfit/simulation/_sde.py @@ -1,8 +1,7 @@ import jax -import jax.random as jrandom import jax.numpy as jnp +import jax.random as jrandom from diffrax import ( - diffeqsolve, ControlTerm, Euler, MultiTerm, @@ -10,6 +9,7 @@ SaveAt, Solution, VirtualBrownianTree, + diffeqsolve, ) diff --git a/tests/simulation/test_sde.py b/tests/simulation/test_sde.py index 33749f2..8e3ef7b 100644 --- a/tests/simulation/test_sde.py +++ b/tests/simulation/test_sde.py @@ -1,8 +1,6 @@ import jax import jax.numpy as jnp - import numpy.testing as npt - from covvfit.simulation._sde import ( simplex_complete, solve_stochastic_replicator_dynamics, diff --git a/tests/test_frequentist_jax.py b/tests/test_frequentist_jax.py new file mode 100644 index 0000000..04c0ac3 --- /dev/null +++ b/tests/test_frequentist_jax.py @@ -0,0 +1,48 @@ +import covvfit._frequentist_jax as fj +import jax +import numpy.testing as npt +import pytest + + +@pytest.mark.parametrize("n_cities", [1, 5, 12]) +@pytest.mark.parametrize("n_variants", [2, 3, 8]) +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_parameter_conversions_1(seed: int, n_cities: int, n_variants: int) -> None: + key = jax.random.PRNGKey(seed) + key1, key2 = jax.random.split(key) + growth_rel = jax.random.uniform(key1, shape=(n_variants - 1,)) + midpoint_rel = jax.random.uniform(key2, shape=(n_cities, n_variants - 1)) + + theta = fj.construct_theta( + relative_growths=growth_rel, + relative_midpoints=midpoint_rel, + ) + + npt.assert_allclose( + fj.get_relative_growths(theta, n_variants=n_variants), + growth_rel, + ) + npt.assert_allclose( + fj.get_relative_midpoints(theta, n_variants=n_variants), + midpoint_rel, + ) + + +@pytest.mark.parametrize("n_cities", [1, 5, 12]) +@pytest.mark.parametrize("n_variants", [2, 3, 8]) +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_parameter_conversions_2(seed: int, n_cities: int, n_variants: int) -> None: + theta = jax.random.uniform( + jax.random.PRNGKey(seed), shape=(n_cities * (n_variants - 1) + n_variants - 1,) + ) + + growth_rel = fj.get_relative_growths(theta, n_variants=n_variants) + midpoint_rel = fj.get_relative_midpoints(theta, n_variants=n_variants) + + npt.assert_allclose( + fj.construct_theta( + relative_growths=growth_rel, + relative_midpoints=midpoint_rel, + ), + theta, + )
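+
+
+# A minimal sketch of an additional test: for a linear map the delta method is exact,
+# so `get_standard_errors` should match sqrt(diag(J @ C @ J.T)) computed directly.
+@pytest.mark.parametrize("seed", [0, 1])
+def test_get_standard_errors_linear(seed: int) -> None:
+    key1, key2 = jax.random.split(jax.random.PRNGKey(seed))
+    jacobian = jax.random.normal(key1, shape=(3, 5))
+    a = jax.random.normal(key2, shape=(5, 5))
+    covariance = a @ a.T  # symmetric positive semi-definite matrix
+
+    expected = jax.numpy.sqrt(jax.numpy.diag(jacobian @ covariance @ jacobian.T))
+    result = fj.get_standard_errors(jacobian=jacobian, covariance=covariance)
+
+    npt.assert_allclose(result, expected, rtol=1e-4)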