Skip to content

Commit

Permalink
Merge branch 'main' into move_overdispersion
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Nov 15, 2024
2 parents db261d7 + 3acfe08 commit e7ec915
Show file tree
Hide file tree
Showing 4 changed files with 417 additions and 49 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
[![build](https://github.com/cbg-ethz/covvfit/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/cbg-ethz/covvfit/actions/workflows/test.yml)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/charliermarsh/ruff)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![PyPI Latest Release](https://img.shields.io/pypi/v/covvfit.svg)](https://pypi.org/project/covvfit/)


# covvfit

Fitness estimates of SARS-CoV-2 variants.

**Note: this package is currently in an alpha stage. The API is expected to be refined and the documentation is currently being developed.**


- **Documentation:** [https://cbg-ethz.github.io/covvfit](https://cbg-ethz.github.io/covvfit)
- **Source code:** [https://github.com/cbg-ethz/covvfit](https://github.com/cbg-ethz/covvfit)
- **Bug reports:** [https://github.com/cbg-ethz/covvfit/issues](https://github.com/cbg-ethz/covvfit/issues)
Expand Down
74 changes: 74 additions & 0 deletions src/covvfit/_padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Sequence, TypeVar

import jax
import jax.numpy as jnp

T = TypeVar("T")


def _is_scalar(value) -> bool:
    """Tell whether `value` should be treated as a single number.

    Args:
        value: object to classify

    Returns:
        True if `value` is a zero-dimensional array or has no `__len__`
        attribute, False otherwise
    """
    # Zero-dimensional NumPy/JAX arrays expose `__len__` as an attribute,
    # but calling `len` on them raises TypeError, so they have to be
    # recognised before the generic attribute check.
    if getattr(value, "ndim", None) == 0:
        return True
    return not hasattr(value, "__len__")


def create_padded_array(
    values: T | Sequence[T] | Sequence[jax.Array] | Sequence[Sequence[T]],
    lengths: list[int],
    padding_length: int,
    padding_value: T,
    _out_dtype=float,
) -> jax.Array:
    """Parsing utility, which pads `values`
    into a two-dimensional array describing multiple time series.

    Args:
        values: provided values for each group.
            It can be either a scalar value
            (constant for all time series) or a sequence of values
            with one entry per time series.
            If it is a sequence, then each entry can be either
            a single value (constant for the time series) or
            an array specifying values for all the time points in the
            particular time series
        lengths: lengths of the time series, one per time series
        padding_length: number of columns of the output, must be
            at least as large as every entry in `lengths`
        padding_value: value used to fill the entries located after
            the end of each time series
        _out_dtype: dtype of the returned array

    Returns:
        JAX array of shape (n_timeseries, padding_length)

    Raises:
        ValueError: if `lengths` is empty, if any length exceeds
            `padding_length` or is negative, if a sequence `values` does not
            have one entry per time series, or if a vector entry does not
            have the length declared in `lengths`
    """
    n_cities = len(lengths)
    if n_cities < 1:
        raise ValueError("There has to be at least one city.")
    max_length = max(lengths)
    if max_length > padding_length:
        raise ValueError(
            f"Maximum length is {max_length}, which is greater than the padding {padding_length}."
        )
    # A negative length would be read below as a "from the end" slice bound
    # and silently fill the wrong entries, so it is rejected explicitly.
    min_length = min(lengths)
    if min_length < 0:
        raise ValueError(
            f"Lengths have to be non-negative, but the minimum length is {min_length}."
        )

    out_array = jnp.full(
        shape=(n_cities, padding_length), fill_value=padding_value, dtype=_out_dtype
    )

    # First case: `values` argument is a single number (not an iterable)
    if _is_scalar(values):
        for i, length in enumerate(lengths):
            out_array = out_array.at[i, :length].set(values)
        return out_array

    # Second case: `values` argument is not a scalar, but rather an iterable:
    if len(values) != n_cities:
        raise ValueError(
            f"Provided list has length {len(values)} rather than {n_cities}."
        )

    for i, (value, exp_len) in enumerate(zip(values, lengths)):
        if _is_scalar(value):  # For this city we have constant value provided
            out_array = out_array.at[i, :exp_len].set(value)
        else:  # We have a vector of values provided
            if len(value) != exp_len:
                raise ValueError(
                    f"For {i}th component the provided array has length {len(value)} rather than {exp_len}."
                )
            # Convert to the output dtype before writing the row prefix
            vals = jnp.asarray(value, dtype=out_array.dtype)
            out_array = out_array.at[i, :exp_len].set(vals)

    return out_array
123 changes: 74 additions & 49 deletions src/covvfit/_quasimultinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import numpy as np
import numpyro
import numpyro.distributions as distrib
from jaxtyping import Array, Float
from jaxtyping import Array, Bool, Float
from scipy import optimize

from covvfit._padding import create_padded_array


def calculate_linear(
ts: Float[Array, " *batch"],
Expand Down Expand Up @@ -420,24 +422,29 @@ class _ProblemData(NamedTuple):
for timepoints where there is no measurement for a particular city
mask: array of shape (cities, timepoints) with 0 when there is
no measurement for a particular city and 1 otherwise
n_quasimul: quasimultinomial number of trials for each city
overdispersion: overdispersion factor for each city
n_quasimul: quasimultinomial number of trials for each city and timepoint
overdispersion: overdispersion factor for each city and timepoint
"""

n_cities: int
n_variants: int
ts: Float[Array, "cities timepoints"]
ys: Float[Array, "cities timepoints variants"]
mask: Float[Array, "cities timepoints"]
n_quasimul: Float[Array, " cities"]
overdispersion: Float[Array, " cities"]
mask: Bool[Array, "cities timepoints"]
n_quasimul: Float[Array, "cities timepoints"]
overdispersion: Float[Array, "cities timepoints"]


_OverDispersionType = (
float | list[float] | list[jax.Array] | list[list[float]] | Float[Array, " cities"]
)


def _validate_and_pad(
ys: list[jax.Array],
ts: list[jax.Array],
ns_quasimul: Float[Array, " cities"] | list[float] | float = 1.0,
overdispersion: Float[Array, " cities"] | list[float] | float = 1.0,
ns_quasimul: _OverDispersionType,
overdispersion: _OverDispersionType,
) -> _ProblemData:
"""Validation function, parsing the input provided in
the format convenient for the user to the internal
Expand All @@ -446,21 +453,8 @@ def _validate_and_pad(
n_cities = len(ys)
if len(ts) != n_cities:
raise ValueError(f"Number of cities not consistent: {len(ys)} != {len(ts)}.")

# Create arrays representing `n` and `overdispersion`
if hasattr(ns_quasimul, "__len__"):
if len(ns_quasimul) != n_cities:
raise ValueError(
f"Provided `ns_quasimul` has length {len(ns_quasimul)} rather than {n_cities}."
)
if hasattr(overdispersion, "__len__"):
if len(overdispersion) != n_cities:
raise ValueError(
f"Provided `overdispersion` has length {len(overdispersion)} rather than {n_cities}."
)

out_n = jnp.asarray(ns_quasimul) * jnp.ones(n_cities, dtype=float)
out_overdispersion = jnp.asarray(overdispersion) * jnp.ones_like(out_n)
if n_cities < 1:
raise ValueError("There has to be at least one city.")

# Get the number of variants
n_variants = ys[0].shape[-1]
Expand All @@ -472,7 +466,7 @@ def _validate_and_pad(
f"City {i} has {y.shape[-1]} variants rather than {n_variants}."
)

# Ensure that the number of timepoints is consistent
# Ensure that the number of timepoints is consistent for t and y
max_timepoints = 0
for i, (t, y) in enumerate(zip(ts, ys)):
if t.ndim != 1:
Expand All @@ -484,21 +478,45 @@ def _validate_and_pad(
f"City {i} has timepoints mismatch: {t.shape[0]} != {y.shape[0]}."
)

max_timepoints = t.shape[0]
max_timepoints = max(max_timepoints, t.shape[0])

_lengths = [t.shape[0] for t in ts]
out_n = create_padded_array(
values=ns_quasimul,
lengths=_lengths,
padding_length=max_timepoints,
padding_value=0.0,
)
out_overdispersion = create_padded_array(
values=overdispersion,
lengths=_lengths,
padding_length=max_timepoints,
padding_value=1.0, # Use 1.0 as we divide by it and want to avoid NaNs
)

# Now create the arrays representing the data
out_ts = jnp.zeros((n_cities, max_timepoints)) # Pad with zeros
out_mask = jnp.zeros((n_cities, max_timepoints)) # Pad with zeros
out_ts = create_padded_array(
values=ts,
lengths=_lengths,
padding_length=max_timepoints,
padding_value=0.0,
)
out_mask = create_padded_array(
values=1,
lengths=_lengths,
padding_length=max_timepoints,
padding_value=0,
_out_dtype=bool,
)

# Create the array with variant proportions, padded with constant vectors
out_ys = jnp.full(
shape=(n_cities, max_timepoints, n_variants), fill_value=1.0 / n_variants
) # Pad with constant vectors

for i, (t, y) in enumerate(zip(ts, ys)):
n_timepoints = t.shape[0]
)

out_ts = out_ts.at[i, :n_timepoints].set(t)
for i, y in enumerate(ys):
n_timepoints = y.shape[0]
out_ys = out_ys.at[i, :n_timepoints, :].set(y)
out_mask = out_mask.at[i, :n_timepoints].set(1)

return _ProblemData(
n_cities=n_cities,
Expand All @@ -517,16 +535,19 @@ def _quasiloglikelihood_single_city(
ts: Float[Array, " timepoints"],
ys: Float[Array, "timepoints variants"],
mask: Float[Array, " timepoints"],
n_quasimul: float,
overdispersion: float,
n_quasimul: Float[Array, " timepoints"],
overdispersion: Float[Array, " timepoints"],
) -> float:
weight = n_quasimul / overdispersion
logps = calculate_logps(
ts=ts,
midpoints=_add_first_variant(relative_offsets),
growths=_add_first_variant(relative_growths),
)
return jnp.sum(mask[:, None] * weight * ys * logps)
# Ensure compatible shapes:
mask = jnp.asarray(mask, dtype=float)[:, None]
weight = (n_quasimul / overdispersion)[:, None]

return jnp.sum(mask * weight * ys * logps)


_RelativeGrowthsAndOffsetsFunction = Callable[
Expand Down Expand Up @@ -571,24 +592,26 @@ def quasiloglikelihood(
def construct_model(
ys: list[jax.Array],
ts: list[jax.Array],
ns: Float[Array, " cities"] | list[float] | float = 1.0,
overdispersion: Float[Array, " cities"] | list[float] | float = 1.0,
ns: _OverDispersionType = 1.0,
overdispersion: _OverDispersionType = 1.0,
sigma_growth: float = 10.0,
sigma_offset: float = 1000.0,
) -> Callable:
"""Builds a NumPyro model suitable for sampling from the quasiposterior.
Args:
ys: list of variant proportions for each city.
ys: list of variant proportions array for each city.
The ith entry should be an array
of shape (n_timepoints[i], n_variants)
ts: list of timepoints. The ith entry should be an array
ts: list of timepoint arrays. The ith entry should be an array
of shape (n_timepoints[i],)
Note: `ts` should be appropriately normalized
ns: controls the overdispersion of each city by means of
quasimultinomial sample size
ns: controls the quasimultinomial sample size of each city. It can be:
- a single float (sample size is constant across all cities and timepoints)
- a sequence of floats, describing one sample size for each city
- a list of arrays, with the `i`th entry having length `n_timepoints[i]`
overdispersion: controls the overdispersion factor as in the
quasilikelihood approach
quasilikelihood approach. The shape restrictions are the same as in `ns`.
sigma_growth: controls the standard deviation of the prior
on the relative growths
sigma_offset: controls the standard deviation of the prior
Expand Down Expand Up @@ -638,8 +661,8 @@ def model():
def construct_total_loss(
ys: list[jax.Array],
ts: list[jax.Array],
ns: list[float] | float = 1.0,
overdispersion: list[float] | float = 1.0,
ns: _OverDispersionType = 1.0,
overdispersion: _OverDispersionType = 1.0,
accept_theta: bool = True,
average_loss: bool = False,
) -> Callable[[_ThetaType], _Float] | _RelativeGrowthsAndOffsetsFunction:
Expand All @@ -652,10 +675,12 @@ def construct_total_loss(
ts: list of timepoints. The ith entry should be an array
of shape (n_timepoints[i],)
Note: `ts` should be appropriately normalized
ns: controls the overdispersion of each city by means of
quasimultinomial sample size
ns: controls the quasimultinomial sample size of each city. It can be:
- a single float (sample size is constant across all cities and timepoints)
- a sequence of floats, describing one sample size for each city
- a list of arrays, with the `i`th entry having length `n_timepoints[i]`
overdispersion: controls the overdispersion factor as in the
quasilikelihood approach
quasilikelihood approach. The shape restrictions are the same as in `ns`.
accept_theta: whether the returned loss function should accept the
`theta` vector (suitable for optimization)
or should be parameterized by the relative growths
Expand Down
Loading

0 comments on commit e7ec915

Please sign in to comment.