From ecf8bc624da65524bae7a0d57cd858d9eef010ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Tue, 19 Nov 2024 10:24:22 +0100
Subject: [PATCH] WIP: Refactor notebook up to fitting the model

---
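Notes for reviewers (placed after the "---" separator, so they do not become
part of the commit message):

This is a WIP commit: only the notebook cells up to and including the model
fit are refactored. The later cells are intentionally untouched and still
refer to the old names, e.g. observed_data (removed in this patch) inside the
qm.compute_overdispersion(...) call and variants in the plotting cell, so the
notebook will not run end-to-end until a follow-up renames those as well
(presumably to ys_effective and variants_of_interest). The transpose
workaround marked "TODO(Pawel): Refactor this out!!!!!" tracks
https://github.com/cbg-ethz/covvfit/issues/24.
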
 examples/frequentist_notebook_jax.py | 126 +++++++++++++++------
 1 file changed, 70 insertions(+), 56 deletions(-)

diff --git a/examples/frequentist_notebook_jax.py b/examples/frequentist_notebook_jax.py
index be24789..8c2d66f 100644
--- a/examples/frequentist_notebook_jax.py
+++ b/examples/frequentist_notebook_jax.py
@@ -13,41 +13,36 @@
 #     name: python3
 # ---
 
+# # Quasilikelihood data analysis notebook
+#
+# This notebook shows how to estimate growth advantages by fitting the model within the quasimultinomial framework.
+
 # +
-import jax
 import jax.numpy as jnp
-
-import pandas as pd
-
-import numpy as np
-
-import matplotlib.ticker as ticker
 import matplotlib.pyplot as plt
-import seaborn as sns
-
-from scipy.special import expit
-from scipy.stats import norm
-
+import matplotlib.ticker as ticker
+import pandas as pd
 import yaml
 
-import covvfit._preprocess_abundances as prec
-import covvfit.plotting._timeseries as plot_ts
-
+from covvfit import plot, preprocess
 from covvfit import quasimultinomial as qm
 
-import numpyro
-
+plot_ts = plot.timeseries
 # -
 
 
-# # Load and preprocess data
+# ## Load and preprocess data
+#
+# We start by loading the data:
 
 # +
-DATA_PATH = "../../LolliPop/lollipop_covvfit/deconvolved.csv"
-VAR_DATES_PATH = "../../LolliPop/lollipop_covvfit/var_dates.yaml"
-
-DATA_PATH = "../new_data/deconvolved.csv"
-VAR_DATES_PATH = "../new_data/var_dates.yaml"
+_dir_switch = False  # Set to True or False, depending on which machine you are working on
+if _dir_switch:
+    DATA_PATH = "../../LolliPop/lollipop_covvfit/deconvolved.csv"
+    VAR_DATES_PATH = "../../LolliPop/lollipop_covvfit/var_dates.yaml"
+else:
+    DATA_PATH = "../new_data/deconvolved.csv"
+    VAR_DATES_PATH = "../new_data/var_dates.yaml"
 
 
 data = pd.read_csv(DATA_PATH, sep="\t")
@@ -60,21 +55,27 @@
 
 # Access the var_dates data
 var_dates = var_dates_data["var_dates"]
-# -
+
 
 data_wide = data.pivot_table(
     index=["date", "location"], columns="variant", values="proportion", fill_value=0
 ).reset_index()
 
 data_wide = data_wide.rename(columns={"date": "time", "location": "city"})
-data_wide.head()
 
-# +
+# Define the list of cities:
+cities = list(data_wide["city"].unique())
+
 ## Set limit times for modeling
 max_date = pd.to_datetime(data_wide["time"]).max()
 delta_time = pd.Timedelta(days=240)
 start_date = max_date - delta_time
 
+# Preview the data frame
+data_wide.head()
+# -
+
+# Now we look at the variants in the data and define the variants of interest:
 
 # +
 # Convert the keys to datetime objects for comparison
@@ -90,72 +91,85 @@ def match_date(start_date):
     return closest_date, var_dates_parsed[closest_date]
 
 
-variants_full = match_date(start_date + delta_time)[1]
-
-variants = ["KP.2", "KP.3", "XEC"]
+variants_full = match_date(start_date + delta_time)[1]  # All the variants in this range
 
-variants_other = [i for i in variants_full if i not in variants]
+variants_of_interest = ["KP.2", "KP.3", "XEC"]  # Variants of interest
+variants_other = [
+    i for i in variants_full if i not in variants_of_interest
+]  # Variants not of interest
 # -
 
-cities = list(data_wide["city"].unique())
+# Apart from the variants of interest, we define the "other" variant, which artificially merges all the other variants into one. This allows us to model the data as a compositional time series, i.e., the sum of abundances of all "variants" is normalized to one.
 
-variants2 = ["other"] + variants
-data2 = prec.preprocess_df(
+# +
+variants_effective = ["other"] + variants_of_interest
+data_full = preprocess.preprocess_df(
     data_wide, cities, variants_full, date_min=start_date, zero_date=start_date
 )
 
-# +
-data2["other"] = data2[variants_other].sum(axis=1)
-data2[variants2] = data2[variants2].div(data2[variants2].sum(axis=1), axis=0)
-
-ts_lst, ys_lst = prec.make_data_list(data2, cities, variants2)
-ts_lst, ys_lst2 = prec.make_data_list(data2, cities, variants)
+data_full["other"] = data_full[variants_other].sum(axis=1)
+data_full[variants_effective] = data_full[variants_effective].div(
+    data_full[variants_effective].sum(axis=1), axis=0
+)
 
-t_max = max([x.max() for x in ts_lst])
-t_min = min([x.min() for x in ts_lst])
+# +
+_, ys_effective = preprocess.make_data_list(data_full, cities, variants_effective)
+ts_lst, ys_of_interest = preprocess.make_data_list(
+    data_full, cities, variants_of_interest
+)
 
-ts_lst_scaled = [(x - t_min) / (t_max - t_min) for x in ts_lst]
+# Scale the time for numerical stability
+t_scaler = preprocess.TimeScaler()
+ts_lst_scaled = t_scaler.fit_transform(ts_lst)
 # -
 
 
-# # fit in jax
+# ## Fit the quasimultinomial model
+#
+# Now we find the maximum quasilikelihood estimate of the model parameters:
 
 # +
 # %%time
 
-# Recall that the input should be (n_timepoints, n_variants)
-# TODO(Pawel, David): Resolve Issue https://github.com/cbg-ethz/covvfit/issues/24
-observed_data = [y.T for y in ys_lst]
-
-
 # no priors
 loss = qm.construct_total_loss(
-    ys=observed_data,
+    ys=ys_effective,
     ts=ts_lst_scaled,
     average_loss=False,  # Do not average the loss over the data points, so that the covariance matrix shrinks with more and more data added
 )
 
+n_variants_effective = len(variants_effective)
+
 # initial parameters
-theta0 = qm.construct_theta0(n_cities=len(cities), n_variants=len(variants2))
+theta0 = qm.construct_theta0(n_cities=len(cities), n_variants=n_variants_effective)
 
 # Run the optimization routine
 solution = qm.jax_multistart_minimize(loss, theta0, n_starts=10)
+
+theta_star = solution.x  # The maximum quasilikelihood estimate
+
+print(
+    "Relative growth rates: \n",
+    qm.get_relative_growths(theta_star, n_variants=n_variants_effective),
+)
 # -
 
 
 # ## Make fitted values and confidence intervals
 
-# +
 ## compute fitted values
 fitted_values = qm.fitted_values(
-    ts_lst_scaled, theta=solution.x, cities=cities, n_variants=len(variants2)
+    ts_lst_scaled, theta=theta_star, cities=cities, n_variants=n_variants_effective
 )
+
+# +
+# TODO(Pawel): Refactor this out!!!!!
 # ... and because of https://github.com/cbg-ethz/covvfit/issues/24
 # we need to transpose again
 y_fit_lst = [y.T[1:] for y in fitted_values]
 
 ## compute covariance matrix
-covariance = qm.get_covariance(loss, solution.x)
+covariance = qm.get_covariance(loss, theta_star)
 
 overdispersion_tuple = qm.compute_overdispersion(
     observed=observed_data,
@@ -171,7 +185,8 @@ def match_date(start_date):
 
 ## compute standard errors and confidence intervals of the estimates
 standard_errors_estimates = qm.get_standard_errors(covariance_scaled)
-confints_estimates = qm.get_confidence_intervals(solution.x, standard_errors_estimates)
+confints_estimates = qm.get_confidence_intervals(theta_star, standard_errors_estimates)
+
 
 ## compute confidence intervals of the fitted values on the logit scale and back transform
 y_fit_lst_confint = qm.get_confidence_bands_logit(
@@ -207,7 +222,7 @@ def match_date(start_date):
 # ## Plot
 
 # +
-colors_covsp = plot_ts.colors_covsp
+colors_covsp = plot_ts.COLORS_COVSPECTRUM
 colors = [colors_covsp[var] for var in variants]
 fig, axes_tot = plt.subplots(4, 2, figsize=(15, 10))
 axes_flat = axes_tot.flatten()
@@ -257,4 +272,3 @@ def format_date(x, pos):
 
 fig.tight_layout()
 fig.show()
-# -
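
Reviewer note (not part of the diff above): the merging and renormalization
step introduced in this patch is what turns the raw proportions into a
compositional time series, i.e. rows that sum to one across the "effective"
variants. A minimal, self-contained sanity check of that logic, using toy
numbers and made-up variant names rather than the notebook's data, could
look like this:

    import pandas as pd

    variants_of_interest = ["KP.2", "KP.3"]
    variants_other = ["A", "B"]  # everything that is not of interest
    variants_effective = ["other"] + variants_of_interest

    toy = pd.DataFrame(
        {"KP.2": [0.2, 0.4], "KP.3": [0.1, 0.3], "A": [0.5, 0.1], "B": [0.1, 0.1]}
    )
    toy["other"] = toy[variants_other].sum(axis=1)
    toy[variants_effective] = toy[variants_effective].div(
        toy[variants_effective].sum(axis=1), axis=0
    )

    # Each row of the effective columns now sums to one:
    print(toy[variants_effective].sum(axis=1))

This mirrors the data_full transformation added in the notebook; it does not
exercise preprocess.preprocess_df or the model fit itself.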