Skip to content

Commit

Permalink
WIP: Fixed the notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Nov 19, 2024
1 parent 19be8fa commit e04f79b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
19 changes: 12 additions & 7 deletions examples/frequentist_notebook_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def match_date(start_date):
)

# Scale the time for numerical stability
t_scaler = preprocess.TimeScaler()
ts_lst_scaled = t_scaler.fit_transform(ts_lst)
time_scaler = preprocess.TimeScaler()
ts_lst_scaled = time_scaler.fit_transform(ts_lst)
# -


Expand Down Expand Up @@ -210,13 +210,17 @@ def match_date(start_date):

## compute confidence intervals of the fitted values on the logit scale and back transform
y_fit_lst_confint = qm.get_confidence_bands_logit(
theta_star, len(variants_effective), ts_lst_scaled, covariance_scaled
theta_star,
n_variants=n_variants_effective,
ts=ts_lst_scaled,
covariance=covariance_scaled,
)


## compute predicted values and confidence bands
horizon = 60
ts_pred_lst = [jnp.arange(horizon + 1) + tt.max() for tt in ts_lst]
ts_pred_lst_scaled = t_scaler.transform(ts_pred_lst)
ts_pred_lst_scaled = time_scaler.transform(ts_pred_lst)

y_pred_lst = qm.fitted_values(
ts_pred_lst_scaled, theta=theta_star, cities=cities, n_variants=n_variants_effective
Expand All @@ -226,14 +230,15 @@ def match_date(start_date):
y_pred_lst = [y.T[1:] for y in y_pred_lst]

y_pred_lst_confint = qm.get_confidence_bands_logit(
solution.x, n_variants_effective, ts_pred_lst_scaled, covariance_scaled
theta_star,
n_variants=n_variants_effective,
ts=ts_pred_lst_scaled,
covariance=covariance_scaled,
)


# -

confints_estimates

y_pred_lst[0].shape

# ## Plot
Expand Down
4 changes: 2 additions & 2 deletions src/covvfit/_quasimultinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def get_confidence_bands_logit(
*,
n_variants: int,
ts: list[Float[Array, " timepoints"]],
covariance_scaled: Float[Array, "n_params n_params"],
covariance: Float[Array, "n_params n_params"],
confidence_level: float = 0.95,
) -> list[tuple]:
"""Computes confidence intervals for logit predictions using the Delta method,
Expand Down Expand Up @@ -284,7 +284,7 @@ def get_confidence_bands_logit(
# Compute the Jacobian of the transformation and project standard errors
jacobian = jax.jacobian(_create_logit_predictions_fn(n_variants, i, ts))(theta)
standard_errors = get_standard_errors(
jacobian=jacobian, covariance=covariance_scaled
jacobian=jacobian, covariance=covariance
).T
y_fit_lst_logit_se.append(standard_errors)

Expand Down

0 comments on commit e04f79b

Please sign in to comment.