Skip to content

Commit

Permalink
FEAT: backend agnostic cosmology (#92)
Browse files Browse the repository at this point in the history
* TST: update tests for cosmology change

* MAINT: move cosmo redshift classes out of experimental

* REFACTOR: make jax import not top level

* CI: fix installation

* CI: update versions

* MAINT: flake fixes

* TEST: try fixing pytest

* TST: fix importorskip

* BGUFIX: fix cupy mass normalization

* TST: test fixes for cupy

* BLD: remove 3.12 tests for black issues

* FMT: black fixes

* MAINT: remove deprecated code

* BUG: fix cosmo model instantiation

* BUG: fix references to redshift

* BUG: fix returned source frame samples

* BLD: require wcosmo from pypi
  • Loading branch information
ColmTalbot authored May 28, 2024
1 parent 261f2b8 commit ba0f4c5
Show file tree
Hide file tree
Showing 14 changed files with 109 additions and 357 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ jobs:

- name: Install dependencies
run: |
conda install --file requirements.txt
pip install -r requirements.txt
# conda install --file requirements.txt
conda install --file pages_requirements.txt
conda install --file test_requirements.txt
python -m pip install .
Expand Down
6 changes: 4 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11"]

steps:
- uses: actions/checkout@v4
Expand All @@ -41,7 +41,9 @@ jobs:
- name: Install dependencies
run: |
conda install pip setuptools
conda install --file requirements.txt
# install via pip to install from source repo
pip install -r requirements.txt
# conda install --file requirements.txt
conda install --file test_requirements.txt
pre-commit install
- name: Install gwpopulation
Expand Down
266 changes: 43 additions & 223 deletions gwpopulation/experimental/cosmo_models.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,30 @@
import numpy as np
from astropy import constants
from astropy import units as u
from astropy.cosmology import FlatLambdaCDM, FlatwCDM
from bilby.hyper.model import Model
from scipy.interpolate import splev, splrep
import numpy as xp
from wcosmo import FlatwCDM, available, z_at_value

from gwpopulation.utils import to_numpy

xp = np
speed_of_light = constants.c.to("km/s").value
from .jax import NonCachingModel


class CosmoMixin:
def __init__(self, cosmo_model="Planck15"):

self.cosmo_model = cosmo_model
if self.cosmo_model == "FlatwCDM":
self.cosmology_names = ["H0", "Om0", "w0"]
elif self.cosmo_model == "FlatLambdaCDM":
self.cosmology_names = ["H0", "Om0"]
else:
self.cosmology_names = []
self._cosmo = available[cosmo_model]

def cosmology_variables(self, parameters):
return {key: parameters[key] for key in self.cosmology_names}

def cosmology(self, parameters):
if isinstance(self._cosmo, FlatwCDM):
return self._cosmo
else:
return self._cosmo(**self.cosmology_variables(parameters))

def detector_frame_to_source_frame(self, data, **parameters):
"""
Convert detector frame samples to sourece frame samples given cosmological parameters. Calculate the corresponding d_detector/d_source Jacobian term.
Expand All @@ -24,46 +37,39 @@ def detector_frame_to_source_frame(self, data, **parameters):
The cosmological parameters for relevant cosmology model.
"""

cosmo = self.redshift_model.cosmology_model(**parameters)
jac = 1
samples = dict()
if "luminosity_distance" in data.keys():
zs = self.redshift_model.zs_
dl = cosmo.luminosity_distance(zs).value
interp_dl_to_z = splrep(dl, zs, s=0)

data["redshift"] = xp.nan_to_num(
xp.asarray(
splev(to_numpy(data["luminosity_distance"]), interp_dl_to_z, ext=0)
)
cosmo = self.cosmology(self.parameters)
samples["redshift"] = z_at_value(
cosmo.luminosity_distance,
data["luminosity_distance"],
)
jac *= data["luminosity_distance"] / (
1 + data["redshift"]
) + speed_of_light * (1 + data["redshift"]) / xp.array(
cosmo.H(to_numpy(data["redshift"])).value
) # luminosity_distance_to_redshift_jacobian, dL_by_dz
jacobian = cosmo.dDLdz(samples["redshift"])
elif "redshift" not in data:
raise ValueError(
f"Either luminosity distance or redshift provided in detector frame to source frame samples conversion"
)
else:
jacobian = xp.ones(data["redshift"].shape)

for key in list(data.keys()):
if key.endswith("_detector"):
data[key[:-9]] = data[key] / (1 + data["redshift"])
jac *= 1 + data["redshift"]
samples[key[:-9]] = data[key] / (1 + samples["redshift"])
jacobian *= 1 + samples["redshift"]
elif key != "luminosity_distance":
samples[key] = data[key]

return data, jac
return samples, jacobian


class CosmoModel(Model, CosmoMixin):
class CosmoModel(NonCachingModel, CosmoMixin):
"""
Modified version of bilby.hyper.model.Model that disables caching for jax.
"""

def __init__(self, model_functions=None):
super(CosmoModel, self).__init__(model_functions=model_functions)
for model in self.models:
if isinstance(model, _BaseRedshift):
self.redshift_model = model
def __init__(self, model_functions=None, cosmo_model="Planck15"):
NonCachingModel.__init__(self, model_functions=model_functions)
CosmoMixin.__init__(self, cosmo_model=cosmo_model)

def prob(self, data, **kwargs):
"""
Expand All @@ -82,194 +88,8 @@ def prob(self, data, **kwargs):
model.
"""

data, jac = self.detector_frame_to_source_frame(
data,
**self._get_function_parameters(self.redshift_model),
) # convert samples to source frame and calculate the Jacobian term.
probability = 1.0 # prob in source frame
for function in self.models:
new_probability = function(data, **self._get_function_parameters(function))
probability *= new_probability
probability /= jac # prob in detector frame
data, jacobian = self.detector_frame_to_source_frame(data)
probability = super().prob(data, **kwargs)
probability /= jacobian

return probability


class _BaseRedshift:
"""
Base class for models which include a term like dVc/dz / (1 + z)
"""

base_variable_names = None

@property
def variable_names(self):
if self.cosmo_model == None:
vars = []
if self.cosmo_model == FlatwCDM:
vars = ["H0", "Om0", "w0"]
elif self.cosmo_model == FlatLambdaCDM:
vars = ["H0", "Om0"]
else:
raise ValueError(f"Model {cosmo_model} not found.")
vars += self.base_variable_names
return vars

def __init__(self, z_max=2.3, cosmo_model=None):

self.z_max = z_max
self.zs_ = np.linspace(1e-3, z_max, 1000)
self.zs = xp.asarray(self.zs_)
if cosmo_model == None:
self.cosmo_model = None
from astropy.cosmology import Planck15

self.dvc_dz_ = (
Planck15.differential_comoving_volume(self.zs_).value * 4 * np.pi
)
self.dvc_dz = xp.asarray(self.dvc_dz_)
self.cached_dvc_dz = None
else:
self.cosmo_model = cosmo_model

def __call__(self, dataset, **kwargs):
return self.probability(dataset=dataset, **kwargs)

def _cache_dvc_dz(self, redshifts):
self.cached_dvc_dz = xp.asarray(
np.interp(to_numpy(redshifts), self.zs_, self.dvc_dz_, left=0, right=0)
)

def normalisation(self, parameters):
r"""
Compute the normalization or differential spacetime volume.
.. math::
\mathcal{V} = \int dz \frac{1}{1+z} \frac{dVc}{dz} \psi(z|\Lambda)
Parameters
----------
parameters: dict
Dictionary of parameters
Returns
-------
(float, array-like): Total spacetime volume
"""
psi_of_z = self.psi_of_z(redshift=self.zs, **parameters)
norm = xp.trapz(psi_of_z * self.dvc_dz / (1 + self.zs), self.zs)
return norm

def probability(self, dataset, **parameters):
if self.cosmo_model is not None:
self.update_dvc_dz(**parameters)
normalisation = self.normalisation(parameters=parameters)
differential_volume = self.differential_spacetime_volume(
dataset=dataset, **parameters
)
in_bounds = dataset["redshift"] <= self.z_max
return differential_volume / normalisation * in_bounds

def psi_of_z(self, redshift, **parameters):
raise NotImplementedError

def differential_spacetime_volume(self, dataset, **parameters):
r"""
Compute the differential spacetime volume.
.. math::
d\mathcal{V} = \frac{1}{1+z} \frac{dVc}{dz} \psi(z|\Lambda)
Parameters
----------
dataset: dict
Dictionary containing entry "redshift"
parameters: dict
Dictionary of parameters
Returns
-------
differential_volume: (float, array-like)
Differential spacetime volume
"""
psi_of_z = self.psi_of_z(redshift=dataset["redshift"], **parameters)
differential_volume = psi_of_z / (1 + dataset["redshift"])
try:
differential_volume *= self.cached_dvc_dz
except (TypeError, ValueError):
self._cache_dvc_dz(dataset["redshift"])
differential_volume *= self.cached_dvc_dz
return differential_volume

def cosmology_model(self, **parameters):
if self.cosmo_model == FlatwCDM:
return self.cosmo_model(
H0=parameters["H0"], Om0=parameters["Om0"], w0=parameters["w0"]
)
elif self.cosmo_model == FlatLambdaCDM:
return self.cosmo_model(H0=parameters["H0"], Om0=parameters["Om0"])
else:
raise ValueError(f"Model {cosmo_model} not found.")

def update_dvc_dz(self, **parameters):
self.dvc_dz_ = (
self.cosmology_model(**parameters)
.differential_comoving_volume(self.zs_)
.value
* 4
* xp.pi
)
self.dvc_dz = xp.asarray(self.dvc_dz_)
self.cached_dvc_dz = None


class CosmoPowerLawRedshift(_BaseRedshift):
r"""
Redshift model from Fishbach+ https://arxiv.org/abs/1805.10270 and Cosmo model FlatLambdaCDM
.. math::
p(z|\gamma, \kappa, z_p) &\propto \frac{1}{1 + z}\frac{dV_c}{dz} \psi(z|\gamma, \kappa, z_p)
\psi(z|\gamma, \kappa, z_p) &= (1 + z)^\lambda
Parameters
----------
lamb: float
The spectral index.
"""
base_variable_names = ["lamb"]

def psi_of_z(self, redshift, **parameters):
return (1 + redshift) ** parameters["lamb"]


class CosmoMadauDickinsonRedshift(_BaseRedshift):
r"""
Redshift model from Fishbach+ https://arxiv.org/abs/1805.10270 (33)
See https://arxiv.org/abs/2003.12152 (2) for the normalisation
The parameterisation differs a little from there, we use
.. math::
p(z|\gamma, \kappa, z_p) &\propto \frac{1}{1 + z}\frac{dV_c}{dz} \psi(z|\gamma, \kappa, z_p)
\psi(z|\gamma, \kappa, z_p) &= \frac{(1 + z)^\gamma}{1 + (\frac{1 + z}{1 + z_p})^\kappa}
Parameters
----------
gamma: float
Slope of the distribution at low redshift
kappa: float
Slope of the distribution at high redshift
z_peak: float
Redshift at which the distribution peaks.
z_max: float, optional
The maximum redshift allowed.
"""
base_variable_names = ["gamma", "kappa", "z_peak"]

def psi_of_z(self, redshift, **parameters):
gamma = parameters["gamma"]
kappa = parameters["kappa"]
z_peak = parameters["z_peak"]
psi_of_z = (1 + redshift) ** gamma / (
1 + ((1 + redshift) / (1 + z_peak)) ** kappa
)
psi_of_z *= 1 + (1 + z_peak) ** (-kappa)
return psi_of_z
3 changes: 2 additions & 1 deletion gwpopulation/experimental/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
from bilby.core.likelihood import Likelihood
from bilby.hyper.model import Model
from jax import jit


def generic_bilby_likelihood_function(likelihood, parameters, use_ratio=True):
Expand Down Expand Up @@ -46,6 +45,8 @@ class JittedLikelihood(Likelihood):
def __init__(
self, likelihood, likelihood_func=generic_bilby_likelihood_function, kwargs=None
):
from jax import jit

if kwargs is None:
kwargs = dict()
self.kwargs = kwargs
Expand Down
17 changes: 8 additions & 9 deletions gwpopulation/models/mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,9 @@ def norm_p_m1(self, delta_m, **kwargs):
p_m = self.__class__.primary_model(self.m1s, **kwargs)
p_m *= self.smoothing(self.m1s, mmin=mmin, mmax=self.mmax, delta_m=delta_m)

norm = xp.where(xp.array(delta_m) > 0, xp.trapz(p_m, self.m1s), 1)
norm = xp.nan_to_num(xp.trapz(p_m, self.m1s)) * (delta_m != 0) + 1 * (
delta_m == 0
)
return norm

def p_q(self, dataset, beta, mmin, delta_m):
Expand Down Expand Up @@ -611,10 +613,9 @@ def norm_p_q(self, beta, mmin, delta_m):
p_q *= self.smoothing(
self.m1s_grid * self.qs_grid, mmin=mmin, mmax=self.m1s_grid, delta_m=delta_m
)
norms = xp.where(
xp.array(delta_m) > 0,
xp.nan_to_num(xp.trapz(p_q, self.qs, axis=0)),
xp.ones(self.m1s.shape),

norms = xp.nan_to_num(xp.trapz(p_q, self.qs, axis=0)) * (delta_m != 0) + 1 * (
delta_m == 0
)

return self._q_interpolant(norms)
Expand Down Expand Up @@ -918,10 +919,8 @@ def norm_p_m1(self, delta_m, f_splines=None, **kwargs):
p_m = self.__class__.primary_model(
self.m1s, **{key: kwargs[key] for key in ["alpha", "mmin", "mmax"]}
)
p_m = xp.where(
delta_m > 0,
p_m * self.smoothing(self.m1s, mmin=mmin, mmax=self.mmax, delta_m=delta_m),
p_m,
p_m *= self.smoothing(self.m1s, mmin=mmin, mmax=self.mmax, delta_m=delta_m) ** (
delta_m > 0
)
p_m *= xp.exp(self._norm_spline(y=f_splines))
norm = xp.trapz(p_m, self.m1s)
Expand Down
Loading

0 comments on commit ba0f4c5

Please sign in to comment.