Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Nov 19, 2024
1 parent ecf8bc6 commit c92935d
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 18 deletions.
53 changes: 35 additions & 18 deletions src/covvfit/_quasimultinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def loss(
return -jnp.sum(n * y * logp, axis=-1)


_ThetaType = Float[Array, "(cities+1)*(variants-1)"]
ModelParameters = Float[Array, "(cities+1)*(variants-1)"]


def _add_first_variant(vec: Float[Array, " variants-1"]) -> Float[Array, " variants"]:
Expand All @@ -78,21 +78,21 @@ def _add_first_variant(vec: Float[Array, " variants-1"]) -> Float[Array, " varia
def construct_theta(
relative_growths: Float[Array, " variants-1"],
relative_midpoints: Float[Array, "cities variants-1"],
) -> _ThetaType:
) -> ModelParameters:
flattened_midpoints = relative_midpoints.flatten()
theta = jnp.concatenate([relative_growths, flattened_midpoints])
return theta


def get_relative_growths(
theta: _ThetaType,
theta: ModelParameters,
n_variants: int,
) -> Float[Array, " variants-1"]:
return theta[: n_variants - 1]


def get_relative_midpoints(
theta: _ThetaType,
theta: ModelParameters,
n_variants: int,
) -> Float[Array, "cities variants-1"]:
n_cities = theta.shape[0] // (n_variants - 1) - 1
Expand All @@ -113,8 +113,8 @@ def convert(confidence: float) -> float:


def get_covariance(
loss_fn: Callable[[_ThetaType], _Float],
theta: _ThetaType,
loss_fn: Callable[[ModelParameters], _Float],
theta: ModelParameters,
) -> Float[Array, "n_params n_params"]:
"""Calculates the covariance matrix of the parameters.
Expand Down Expand Up @@ -196,7 +196,7 @@ def get_confidence_intervals(

def fitted_values(
times: list[Float[Array, " timepoints"]],
theta: _ThetaType,
theta: ModelParameters,
cities: list,
n_variants: int,
) -> list[Float[Array, "timepoints variants"]]:
Expand Down Expand Up @@ -239,7 +239,7 @@ def create_logit_predictions_fn(
"""

def logit_predictions_with_fixed_args(
theta: _ThetaType,
theta: ModelParameters,
):
return get_logit_predictions(
theta=theta, n_variants=n_variants, city_index=city_index, ts=ts
Expand All @@ -249,7 +249,7 @@ def logit_predictions_with_fixed_args(


def get_confidence_bands_logit(
solution_x: Float[Array, " (cities+1)*(variants-1)"],
theta: ModelParameters,
variants_count: int,
ts_lst_scaled: list[Float[Array, " timepoints"]],
covariance_scaled: Float[Array, "n_params n_params"],
Expand All @@ -270,17 +270,18 @@ def get_confidence_bands_logit(
for the confidence intervals on the linear scale.
"""
# TODO(Pawel): Potentially fix the signature of this function.
# Issue 24

y_fit_lst_logit = [
get_logit_predictions(solution_x, variants_count, i, ts).T[1:, :]
get_logit_predictions(theta, variants_count, i, ts).T[1:, :]
for i, ts in enumerate(ts_lst_scaled)
]

y_fit_lst_logit_se = []
for i, ts in enumerate(ts_lst_scaled):
# Compute the Jacobian of the transformation and project standard errors
jacobian = jax.jacobian(create_logit_predictions_fn(variants_count, i, ts))(
solution_x
theta
)
standard_errors = get_standard_errors(
jacobian=jacobian, covariance=covariance_scaled
Expand Down Expand Up @@ -310,7 +311,22 @@ def triangular_mask(n_variants, valid_value: float = 0, masked_value: float = jn
return nan_mask


def get_relative_advantages(theta, n_variants: int):
def get_relative_advantages(
theta: ModelParameters, n_variants: int
) -> Float[Array, "variants variants"]:
"""Returns a matrix of relative advantages, comparing every two variants.
Returns:
matrix of shape (n_variants, n_variants) with `A[reference, variant]`
representing the relative advantage of `variant` over `reference`.
Note:
From the model assumptions it follows that
`A[v1, v2] + A[v2, v3] = A[v1, v3]`
for every three variants. (I.e., the relative advantage
of `v3` over `v1` is the sum of advantages of `v3` over `v2`
and `v2` over `v1`)
"""
# Shape (n_variants-1,) describing relative advantages
# over the 0th variant
rel_growths = get_relative_growths(theta, n_variants=n_variants)
Expand All @@ -321,10 +337,11 @@ def get_relative_advantages(theta, n_variants: int):


def get_softmax_predictions(
theta: _ThetaType, n_variants: int, city_index: int, ts: Float[Array, " timepoints"]
theta: ModelParameters,
n_variants: int,
city_index: int,
ts: Float[Array, " timepoints"],
) -> Float[Array, "timepoints variants"]:
# TODO(Pawel): Potentially fix the signature of this function.

rel_growths = get_relative_growths(theta, n_variants=n_variants)
growths = _add_first_variant(rel_growths)

Expand All @@ -342,7 +359,7 @@ def get_softmax_predictions(


def get_logit_predictions(
theta: _ThetaType,
theta: ModelParameters,
n_variants: int,
city_index: int,
ts: Float[Array, " timepoints"],
Expand All @@ -368,7 +385,7 @@ class OptimizeMultiResult:
def construct_theta0(
n_cities: int,
n_variants: int,
) -> _ThetaType:
) -> ModelParameters:
return np.zeros((n_cities * (n_variants - 1) + n_variants - 1,), dtype=float)


Expand Down Expand Up @@ -668,7 +685,7 @@ def construct_total_loss(
overdispersion: _OverDispersionType = 1.0,
accept_theta: bool = True,
average_loss: bool = False,
) -> Callable[[_ThetaType], _Float] | _RelativeGrowthsAndOffsetsFunction:
) -> Callable[[ModelParameters], _Float] | _RelativeGrowthsAndOffsetsFunction:
"""Constructs the loss function, suitable e.g., for optimization.
Args:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_quasimultinomial.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import covvfit._quasimultinomial as qm
import jax
import jax.numpy as jnp
import numpy.testing as npt
import pytest

Expand Down Expand Up @@ -46,3 +47,48 @@ def test_parameter_conversions_2(seed: int, n_cities: int, n_variants: int) -> N
),
theta,
)


def test_softmax_predictions(
n_cities: int = 2, n_variants: int = 3, n_timepoints: int = 50
) -> None:
theta0 = qm.construct_theta0(n_cities=n_cities, n_variants=n_variants)
theta = jax.random.normal(jax.random.PRNGKey(42), shape=theta0.shape)

ts = jnp.linspace(0, 1, n_timepoints)

for city in range(n_cities):
predictions = qm.get_softmax_predictions(
theta,
n_variants=n_variants,
city_index=city,
ts=ts,
)

assert predictions.shape == (n_timepoints, n_variants)

npt.assert_allclose(
predictions.sum(axis=-1),
jnp.ones(n_timepoints),
atol=1e-6,
)


def test_get_relative_advantages(n_cities: int = 1, n_variants: int = 5) -> None:
theta0 = qm.construct_theta0(n_cities=n_cities, n_variants=n_variants)
# The variants are ordered by increasing fitness
relative = jnp.arange(1, n_variants)
theta = qm.construct_theta(
relative_growths=relative,
relative_midpoints=qm.get_relative_midpoints(theta0, n_variants=n_variants),
)

A = qm.get_relative_advantages(theta, n_variants=n_variants)
for v2 in range(n_variants):
for v1 in range(n_variants):
assert pytest.approx(A[v1, v2]) == v2 - v1

for v1 in range(n_variants):
for v2 in range(n_variants):
for v3 in range(n_variants):
assert pytest.approx(A[v1, v3]) == A[v1, v2] + A[v2, v3]

0 comments on commit c92935d

Please sign in to comment.