WIP: Refactor notebook up to fitting the model
pawel-czyz committed Nov 19, 2024
1 parent 109eff4 commit ecf8bc6
Showing 1 changed file with 70 additions and 56 deletions: examples/frequentist_notebook_jax.py
@@ -13,41 +13,36 @@
# name: python3
# ---

# # Quasilikelihood data analysis notebook
#
# This notebook shows how to estimate growth advantages by fitting the model within the quasimultinomial framework.

# +
import jax
import jax.numpy as jnp

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import seaborn as sns

from scipy.special import expit
from scipy.stats import norm

import yaml

import numpyro

from covvfit import plot, preprocess
from covvfit import quasimultinomial as qm

plot_ts = plot.timeseries
# -


# ## Load and preprocess data
#
# We start by loading the data:
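#
# Judging from the code below, the deconvolved file is a tab-separated table with (at least)
# the columns `date`, `location`, `variant` and `proportion`, one row per combination.
# A made-up illustration of the expected shape (not real data):
#
# | date       | location | variant | proportion |
# |------------|----------|---------|------------|
# | 2024-09-01 | Zürich   | KP.3    | 0.41       |
# | 2024-09-01 | Zürich   | XEC     | 0.12       |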

# +
_dir_switch = False  # Set to True or False, depending on the machine you are working on
if _dir_switch:
    DATA_PATH = "../../LolliPop/lollipop_covvfit/deconvolved.csv"
    VAR_DATES_PATH = "../../LolliPop/lollipop_covvfit/var_dates.yaml"
else:
    DATA_PATH = "../new_data/deconvolved.csv"
    VAR_DATES_PATH = "../new_data/var_dates.yaml"


data = pd.read_csv(DATA_PATH, sep="\t")
@@ -60,21 +55,27 @@

# Access the var_dates data
var_dates = var_dates_data["var_dates"]
# -


data_wide = data.pivot_table(
    index=["date", "location"], columns="variant", values="proportion", fill_value=0
).reset_index()
data_wide = data_wide.rename(columns={"date": "time", "location": "city"})

# +
# Define the list with cities:
cities = list(data_wide["city"].unique())

## Set limit times for modeling

max_date = pd.to_datetime(data_wide["time"]).max()
delta_time = pd.Timedelta(days=240)
start_date = max_date - delta_time

# Print the data frame
data_wide.head()
# -

# Now we look at the variants in the data and define the variants of interest:

# +
# Convert the keys to datetime objects for comparison
@@ -90,72 +91,85 @@ def match_date(start_date):
    return closest_date, var_dates_parsed[closest_date]


variants_full = match_date(start_date + delta_time)[1]  # All the variants in this range

variants_of_interest = ["KP.2", "KP.3", "XEC"]  # Variants of interest
variants_other = [
    i for i in variants_full if i not in variants_of_interest
]  # Variants not of interest
# -
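
# The body of `match_date` is collapsed in this diff view. A minimal sketch of what such a
# helper could look like (a hypothetical reconstruction assuming `var_dates_parsed` maps
# `datetime` keys to lists of variant names; not necessarily the actual implementation):
#
# ```python
# def match_date(query_date):
#     """Return the var_dates entry whose date is closest to `query_date`."""
#     closest_date = min(var_dates_parsed, key=lambda d: abs(d - query_date))
#     return closest_date, var_dates_parsed[closest_date]
# ```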

# Apart from the variants of interest, we define the "other" variant, which artificially merges all the other variants into one. This allows us to model the data as a compositional time series, i.e., the sum of abundances of all "variants" is normalized to one.
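#
# A toy illustration of the merge-and-renormalize step on made-up numbers (hypothetical
# values, chosen only to show what the "other" column means; "BA.2" stands in for any
# variant that is not of interest):

# +
_toy = pd.DataFrame(
    {"KP.2": [0.20, 0.10], "KP.3": [0.30, 0.20], "XEC": [0.10, 0.40], "BA.2": [0.05, 0.10]}
)
_toy["other"] = _toy[["BA.2"]].sum(axis=1)  # merge everything that is not of interest
_toy_cols = ["other", "KP.2", "KP.3", "XEC"]
_toy[_toy_cols] = _toy[_toy_cols].div(_toy[_toy_cols].sum(axis=1), axis=0)  # rows sum to one
_toy
# -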

# +
variants_effective = ["other"] + variants_of_interest
data_full = preprocess.preprocess_df(
    data_wide, cities, variants_full, date_min=start_date, zero_date=start_date
)

data_full["other"] = data_full[variants_other].sum(axis=1)
data_full[variants_effective] = data_full[variants_effective].div(
    data_full[variants_effective].sum(axis=1), axis=0
)

# +
_, ys_effective = preprocess.make_data_list(data_full, cities, variants_effective)
ts_lst, ys_of_interest = preprocess.make_data_list(
    data_full, cities, variants_of_interest
)

# Scale the time for numerical stability
t_scaler = preprocess.TimeScaler()
ts_lst_scaled = t_scaler.fit_transform(ts_lst)
# -
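
# For reference: assuming `TimeScaler` performs min-max scaling (which matches the explicit
# computation used in an earlier version of this notebook), the transformation above is
# equivalent to
#
# ```python
# t_min = min(x.min() for x in ts_lst)
# t_max = max(x.max() for x in ts_lst)
# ts_lst_scaled = [(x - t_min) / (t_max - t_min) for x in ts_lst]
# ```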


# ## Fit the quasimultinomial model
#
# Now we fit the quasimultinomial model, which allows us to find the maximum quasilikelihood estimate of the parameters:
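#
# As a reminder of the kind of objective being maximized (a generic sketch of multinomial
# logistic growth, not necessarily the exact parameterization used by `covvfit`): the
# relative abundance of variant $v$ in city $c$ is modelled as
#
# $$p_{c,v}(t) = \frac{\exp(a_{c,v} + b_v t)}{\sum_{w} \exp(a_{c,w} + b_w t)},$$
#
# and the parameters are chosen to maximize the multinomial log-likelihood
#
# $$\ell(\theta) = \sum_{c}\sum_{t}\sum_{v} y_{c,v}(t)\,\log p_{c,v}(t),$$
#
# treated as a quasilikelihood: the point estimate is the usual maximizer, while the
# overdispersion estimated below rescales the covariance matrix of the estimates.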

# +
# %%time

# no priors
loss = qm.construct_total_loss(
    ys=ys_effective,
    ts=ts_lst_scaled,
    average_loss=False,  # Do not average the loss over the data points, so that the covariance matrix shrinks as more data points are added
)

n_variants_effective = len(variants_effective)

# initial parameters
theta0 = qm.construct_theta0(n_cities=len(cities), n_variants=n_variants_effective)

# Run the optimization routine
solution = qm.jax_multistart_minimize(loss, theta0, n_starts=10)

theta_star = solution.x  # The maximum quasilikelihood estimate

print(
    "Relative growth rates:\n",
    qm.get_relative_growths(theta_star, n_variants=n_variants_effective),
)
# -
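
# Note on units: the time axis was rescaled before fitting, so the printed relative growth
# rates are per unit of *scaled* time. Assuming min-max scaling (as sketched earlier), a rate
# $b_{\text{scaled}}$ corresponds to roughly
#
# $$b_{\text{per day}} \approx \frac{b_{\text{scaled}}}{t_{\max} - t_{\min}},$$
#
# where $t_{\max} - t_{\min}$ is the length of the modelled time window in days (here at most 240).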

# ## Make fitted values and confidence intervals

# +
## compute fitted values
fitted_values = qm.fitted_values(
    ts_lst_scaled, theta=theta_star, cities=cities, n_variants=n_variants_effective
)


# +
# TODO(Pawel): Refactor this out!!!!!
# ... and because of https://github.com/cbg-ethz/covvfit/issues/24
# we need to transpose again
y_fit_lst = [y.T[1:] for y in fitted_values]

## compute covariance matrix
covariance = qm.get_covariance(loss, theta_star)
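# (Under standard quasilikelihood theory, this covariance is the inverse Hessian of the
# summed loss at the optimum, hence the `average_loss=False` choice above, which makes it
# shrink as more data points are added; the exact convention used by `qm.get_covariance`
# may differ.)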

overdispersion_tuple = qm.compute_overdispersion(
    observed=ys_effective,  # the same observed proportions that were passed to the loss
@@ -171,7 +185,8 @@

## compute standard errors and confidence intervals of the estimates
standard_errors_estimates = qm.get_standard_errors(covariance_scaled)
confints_estimates = qm.get_confidence_intervals(theta_star, standard_errors_estimates)
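
# For intuition, the generic Wald construction for a single coefficient looks as follows
# (a sketch only; `qm.get_confidence_intervals` may use a different confidence level or form):
_z = norm.ppf(0.975)  # two-sided 95% critical value
_wald_lower = theta_star - _z * standard_errors_estimates
_wald_upper = theta_star + _z * standard_errors_estimates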


## compute confidence intervals of the fitted values on the logit scale and back transform
y_fit_lst_confint = qm.get_confidence_bands_logit(
@@ -207,7 +222,7 @@ def match_date(start_date):
# ## Plot

# +
colors_covsp = plot_ts.COLORS_COVSPECTRUM
colors = [colors_covsp[var] for var in variants_of_interest]
fig, axes_tot = plt.subplots(4, 2, figsize=(15, 10))
axes_flat = axes_tot.flatten()
@@ -257,4 +272,3 @@ def format_date(x, pos):

fig.tight_layout()
fig.show()
# -
