Improve parser creating quasiloglikelihood function #26

Merged · 9 commits · Nov 15, 2024
5 changes: 5 additions & 0 deletions README.md
@@ -2,11 +2,16 @@
 [![build](https://github.com/cbg-ethz/covvfit/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/cbg-ethz/covvfit/actions/workflows/test.yml)
 [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/charliermarsh/ruff)
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+[![PyPI Latest Release](https://img.shields.io/pypi/v/covvfit.svg)](https://pypi.org/project/covvfit/)
 
+
 # covvfit
 
 Fitness estimates of SARS-CoV-2 variants.
 
+**Note: this package is currently in an alpha stage. The API is expected to be refined and the documentation is currently under development.**
+
+
 - **Documentation:** [https://cbg-ethz.github.io/covvfit](https://cbg-ethz.github.io/covvfit)
 - **Source code:** [https://github.com/cbg-ethz/covvfit](https://github.com/cbg-ethz/covvfit)
 - **Bug reports:** [https://github.com/cbg-ethz/covvfit/issues](https://github.com/cbg-ethz/covvfit/issues)
74 changes: 74 additions & 0 deletions src/covvfit/_padding.py
@@ -0,0 +1,74 @@
from typing import Sequence, TypeVar

import jax
import jax.numpy as jnp

T = TypeVar("T")


def _is_scalar(value) -> bool:
    return not hasattr(value, "__len__")


def create_padded_array(
    values: T | Sequence[T] | Sequence[jax.Array] | Sequence[Sequence[T]],
    lengths: list[int],
    padding_length: int,
    padding_value: T,
    _out_dtype=float,
) -> jax.Array:
    """Parsing utility that pads `values`
    into a two-dimensional array describing multiple time series.

    Args:
        values: provided values for each time series.
            It can be either a scalar value
            (constant across all time series) or a sequence
            with one entry per time series.
            If it is a sequence, each entry can in turn be either
            a single value (constant within that time series) or
            an array specifying the values at all timepoints of
            that particular time series
        lengths: lengths of the time series, one entry per time series
        padding_length: padding length, which must be at least as large
            as every entry in `lengths`
        padding_value: value used to fill the padded positions

    Returns:
        JAX array of shape (n_timeseries, padding_length)
    """
    n_cities = len(lengths)
    if n_cities < 1:
        raise ValueError("There has to be at least one city.")
    if max(lengths) > padding_length:
        raise ValueError(
            f"Maximum length is {max(lengths)}, which is greater than the padding {padding_length}."
        )

    out_array = jnp.full(
        shape=(n_cities, padding_length), fill_value=padding_value, dtype=_out_dtype
    )

    # First case: `values` argument is a single number (not an iterable)
    if _is_scalar(values):
        for i, length in enumerate(lengths):
            out_array = out_array.at[i, :length].set(values)
        return out_array

    # Second case: `values` argument is not a scalar, but rather an iterable:
    if len(values) != n_cities:
        raise ValueError(
            f"Provided list has length {len(values)} rather than {n_cities}."
        )

    for i, (value, exp_len) in enumerate(zip(values, lengths)):
        if _is_scalar(value):  # For this city a constant value was provided
            out_array = out_array.at[i, :exp_len].set(value)
        else:  # A vector of values was provided
            if len(value) != exp_len:
                raise ValueError(
                    f"For the {i}th component the provided array has length {len(value)} rather than {exp_len}."
                )
            vals = jnp.asarray(value, dtype=out_array.dtype)
            out_array = out_array.at[i, :exp_len].set(vals)

    return out_array
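
As a quick sanity check, a minimal usage sketch of the new helper (hypothetical values, behavior per the function above):

    import jax.numpy as jnp
    from covvfit._padding import create_padded_array

    # Scalar `values`: the constant is written to each observed timepoint,
    # everything else stays at `padding_value`.
    create_padded_array(values=2.0, lengths=[2, 3], padding_length=4, padding_value=0.0)
    # -> [[2. 2. 0. 0.]
    #     [2. 2. 2. 0.]]

    # Mixed `values`: a constant for the first city, a full vector for the second.
    create_padded_array(
        values=[5.0, jnp.asarray([1.0, 2.0, 3.0])],
        lengths=[2, 3],
        padding_length=4,
        padding_value=0.0,
    )
    # -> [[5. 5. 0. 0.]
    #     [1. 2. 3. 0.]]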
123 changes: 74 additions & 49 deletions src/covvfit/_quasimultinomial.py
@@ -8,9 +8,11 @@
 import numpy as np
 import numpyro
 import numpyro.distributions as distrib
-from jaxtyping import Array, Float
+from jaxtyping import Array, Bool, Float
 from scipy import optimize
 
+from covvfit._padding import create_padded_array
+
 
 def calculate_linear(
     ts: Float[Array, " *batch"],
Expand Down Expand Up @@ -420,24 +422,29 @@ class _ProblemData(NamedTuple):
for timepoints where there is no measurement for a particular city
mask: array of shape (cities, timepoints) with 0 when there is
no measurement for a particular city and 1 otherwise
n_quasimul: quasimultinomial number of trials for each city
overdispersion: overdispersion factor for each city
n_quasimul: quasimultinomial number of trials for each city and timepoint
overdispersion: overdispersion factor for each city and timepoint
"""

n_cities: int
n_variants: int
ts: Float[Array, "cities timepoints"]
ys: Float[Array, "cities timepoints variants"]
mask: Float[Array, "cities timepoints"]
n_quasimul: Float[Array, " cities"]
overdispersion: Float[Array, " cities"]
mask: Bool[Array, "cities timepoints"]
n_quasimul: Float[Array, "cities timepoints"]
overdispersion: Float[Array, "cities timepoints"]


_OverDispersionType = (
float | list[float] | list[jax.Array] | list[list[float]] | Float[Array, " cities"]
)
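
For review context, the forms this alias admits, illustrated for a hypothetical two-city setup with 2 and 3 timepoints (assumed values, not part of the diff):

    import jax.numpy as jnp

    overdispersion = 1.0                               # one factor shared by all cities and timepoints
    overdispersion = [1.0, 2.0]                        # one factor per city
    overdispersion = [jnp.ones(2), 2.0 * jnp.ones(3)]  # one factor per city and timepoint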


 def _validate_and_pad(
     ys: list[jax.Array],
     ts: list[jax.Array],
-    ns_quasimul: Float[Array, " cities"] | list[float] | float = 1.0,
-    overdispersion: Float[Array, " cities"] | list[float] | float = 1.0,
+    ns_quasimul: _OverDispersionType,
+    overdispersion: _OverDispersionType,
 ) -> _ProblemData:
     """Validation function, parsing the input provided in
     the format convenient for the user to the internal
@@ -446,21 +453,8 @@ def _validate_and_pad(
     n_cities = len(ys)
     if len(ts) != n_cities:
         raise ValueError(f"Number of cities not consistent: {len(ys)} != {len(ts)}.")
-
-    # Create arrays representing `n` and `overdispersion`
-    if hasattr(ns_quasimul, "__len__"):
-        if len(ns_quasimul) != n_cities:
-            raise ValueError(
-                f"Provided `ns_quasimul` has length {len(ns_quasimul)} rather than {n_cities}."
-            )
-    if hasattr(overdispersion, "__len__"):
-        if len(overdispersion) != n_cities:
-            raise ValueError(
-                f"Provided `overdispersion` has length {len(overdispersion)} rather than {n_cities}."
-            )
-
-    out_n = jnp.asarray(ns_quasimul) * jnp.ones(n_cities, dtype=float)
-    out_overdispersion = jnp.asarray(overdispersion) * jnp.ones_like(out_n)
+    if n_cities < 1:
+        raise ValueError("There has to be at least one city.")
 
     # Get the number of variants
     n_variants = ys[0].shape[-1]
@@ -472,7 +466,7 @@
                 f"City {i} has {y.shape[-1]} variants rather than {n_variants}."
             )
 
-    # Ensure that the number of timepoints is consistent
+    # Ensure that the number of timepoints is consistent for t and y
     max_timepoints = 0
     for i, (t, y) in enumerate(zip(ts, ys)):
         if t.ndim != 1:
@@ -484,21 +478,45 @@
                 f"City {i} has timepoints mismatch: {t.shape[0]} != {y.shape[0]}."
             )
 
-        max_timepoints = t.shape[0]
+        max_timepoints = max(max_timepoints, t.shape[0])
+
+    _lengths = [t.shape[0] for t in ts]
+    out_n = create_padded_array(
+        values=ns_quasimul,
+        lengths=_lengths,
+        padding_length=max_timepoints,
+        padding_value=0.0,
+    )
+    out_overdispersion = create_padded_array(
+        values=overdispersion,
+        lengths=_lengths,
+        padding_length=max_timepoints,
+        padding_value=1.0,  # Use 1.0 as we divide by it and want to avoid NaNs
+    )
+
     # Now create the arrays representing the data
-    out_ts = jnp.zeros((n_cities, max_timepoints))  # Pad with zeros
-    out_mask = jnp.zeros((n_cities, max_timepoints))  # Pad with zeros
+    out_ts = create_padded_array(
+        values=ts,
+        lengths=_lengths,
+        padding_length=max_timepoints,
+        padding_value=0.0,
+    )
+    out_mask = create_padded_array(
+        values=1,
+        lengths=_lengths,
+        padding_length=max_timepoints,
+        padding_value=0,
+        _out_dtype=bool,
+    )
 
     # Create the array with variant proportions, padded with constant vectors
     out_ys = jnp.full(
         shape=(n_cities, max_timepoints, n_variants), fill_value=1.0 / n_variants
-    )  # Pad with constant vectors
-
-    for i, (t, y) in enumerate(zip(ts, ys)):
-        n_timepoints = t.shape[0]
-        out_ts = out_ts.at[i, :n_timepoints].set(t)
+    )
+
+    for i, y in enumerate(ys):
+        n_timepoints = y.shape[0]
         out_ys = out_ys.at[i, :n_timepoints, :].set(y)
-        out_mask = out_mask.at[i, :n_timepoints].set(1)
 
     return _ProblemData(
         n_cities=n_cities,
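
For intuition, the shapes this padding produces for a hypothetical two-city input (a sketch, not part of the diff; `_validate_and_pad` is private, so the import path is an assumption):

    import jax.numpy as jnp
    from covvfit import _quasimultinomial as qm

    data = qm._validate_and_pad(
        ys=[jnp.ones((2, 3)) / 3, jnp.ones((3, 3)) / 3],  # 2 and 3 timepoints, 3 variants
        ts=[jnp.arange(2.0), jnp.arange(3.0)],
        ns_quasimul=1.0,
        overdispersion=1.0,
    )
    # data.ts.shape == (2, 3); data.ys.shape == (2, 3, 3); data.mask is boolean.
    # Padded n_quasimul entries are 0.0, padded overdispersion entries are 1.0.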
@@ -517,16 +535,19 @@ def _quasiloglikelihood_single_city(
     ts: Float[Array, " timepoints"],
     ys: Float[Array, "timepoints variants"],
     mask: Float[Array, " timepoints"],
-    n_quasimul: float,
-    overdispersion: float,
+    n_quasimul: Float[Array, " timepoints"],
+    overdispersion: Float[Array, " timepoints"],
 ) -> float:
-    weight = n_quasimul / overdispersion
     logps = calculate_logps(
         ts=ts,
         midpoints=_add_first_variant(relative_offsets),
         growths=_add_first_variant(relative_growths),
    )
+    # Ensure compatible shapes:
+    mask = jnp.asarray(mask, dtype=float)[:, None]
+    weight = (n_quasimul / overdispersion)[:, None]
+
+    return jnp.sum(mask * weight * ys * logps)
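
A quick shape sketch of the new per-timepoint weighting (dummy values; `jnp.log(ys)` merely stands in for `calculate_logps`):

    import jax.numpy as jnp

    ys = jnp.ones((3, 2)) / 2                                        # (timepoints, variants)
    logps = jnp.log(ys)                                              # stand-in for calculate_logps(...)
    mask = jnp.asarray([True, True, False], dtype=float)[:, None]    # (3, 1); last timepoint is padding
    weight = (jnp.asarray([10.0, 10.0, 0.0]) / jnp.asarray([2.0, 2.0, 1.0]))[:, None]  # (3, 1)

    # (3, 1) * (3, 1) * (3, 2) * (3, 2) broadcasts to (3, 2); padded rows contribute 0.
    loss = jnp.sum(mask * weight * ys * logps)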


_RelativeGrowthsAndOffsetsFunction = Callable[
Expand Down Expand Up @@ -571,24 +592,26 @@ def quasiloglikelihood(
 def construct_model(
     ys: list[jax.Array],
     ts: list[jax.Array],
-    ns: Float[Array, " cities"] | list[float] | float = 1.0,
-    overdispersion: Float[Array, " cities"] | list[float] | float = 1.0,
+    ns: _OverDispersionType = 1.0,
+    overdispersion: _OverDispersionType = 1.0,
     sigma_growth: float = 10.0,
     sigma_offset: float = 1000.0,
 ) -> Callable:
     """Builds a NumPyro model suitable for sampling from the quasiposterior.
 
     Args:
-        ys: list of variant proportions for each city.
+        ys: list of variant proportion arrays for each city.
             The ith entry should be an array
             of shape (n_timepoints[i], n_variants)
-        ts: list of timepoints. The ith entry should be an array
+        ts: list of timepoint arrays. The ith entry should be an array
             of shape (n_timepoints[i],)
             Note: `ts` should be appropriately normalized
-        ns: controls the overdispersion of each city by means of
-            quasimultinomial sample size
+        ns: controls the quasimultinomial sample size of each city. It can be:
+            - a single float (sample size is constant across all cities and timepoints)
+            - a sequence of floats, describing one sample size for each city
+            - a list of arrays, with the `i`th entry having length `n_timepoints[i]`
         overdispersion: controls the overdispersion factor as in the
-            quasilikelihood approach
+            quasilikelihood approach. The shape restrictions are the same as in `ns`.
         sigma_growth: controls the standard deviation of the prior
             on the relative growths
         sigma_offset: controls the standard deviation of the prior
Expand Down Expand Up @@ -638,8 +661,8 @@ def model():
def construct_total_loss(
ys: list[jax.Array],
ts: list[jax.Array],
ns: list[float] | float = 1.0,
overdispersion: list[float] | float = 1.0,
ns: _OverDispersionType = 1.0,
overdispersion: _OverDispersionType = 1.0,
accept_theta: bool = True,
average_loss: bool = False,
) -> Callable[[_ThetaType], _Float] | _RelativeGrowthsAndOffsetsFunction:
Expand All @@ -652,10 +675,12 @@ def construct_total_loss(
ts: list of timepoints. The ith entry should be an array
of shape (n_timepoints[i],)
Note: `ts` should be appropriately normalized
ns: controls the overdispersion of each city by means of
quasimultinomial sample size
ns: controls the quasimultinomial sample size of each city. It can be:
- a single float (sample size is constant across all cities and timepoints)
- a sequence of floats, describing one sample size for each city
- a list of arrays, with the `i`th entry having length `n_timepoints[i]`
overdispersion: controls the overdispersion factor as in the
quasilikelihood approach
quasilikelihood approach. The shape restrictions are the same as in `ns`.
accept_theta: whether the returned loss function should accept the
`theta` vector (suitable for optimization)
or should be parameterized by the relative growths
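
Taken together, a hedged usage sketch of the updated API (placeholder data; the import of the private module is an assumption and may differ from the public covvfit API):

    import jax.numpy as jnp
    from covvfit import _quasimultinomial as qm

    # Two hypothetical cities, 3 variants, different numbers of timepoints.
    ys = [jnp.ones((5, 3)) / 3, jnp.ones((8, 3)) / 3]
    ts = [jnp.linspace(0.0, 1.0, 5), jnp.linspace(0.0, 1.0, 8)]

    # Per-city-and-timepoint sample sizes, matching the new padded representation.
    ns = [100.0 * jnp.ones(5), 100.0 * jnp.ones(8)]

    loss = qm.construct_total_loss(ys=ys, ts=ts, ns=ns, overdispersion=1.0)
    # With accept_theta=True (the default), `loss` takes the packed `theta` vector
    # (see `_ThetaType`) and can be differentiated with jax.grad or minimized with scipy.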