Commit
Interpolated mass distribution (#77)
* FEATURE: allow downstream packages to automatically update backend

* TEST: fix backend test

* FORMAT: precommit fixes

* Fix posterior resampling for jax backend

* FEATURE: add interpolated primary mass distribution

* BUGFIX: interpolated mass bugfixes and testing

* TEST: improve mass test coverage

* DOC: add reference

* FORMAT: pre-commit fixes
ColmTalbot authored Dec 20, 2023
1 parent b1a4cf5 commit 2d43f7c
Showing 6 changed files with 299 additions and 158 deletions.
28 changes: 21 additions & 7 deletions gwpopulation/hyperpe.py
@@ -338,14 +338,28 @@ def posterior_predictive_resample(self, samples, return_weights=False):
         weights = (weights.T / xp.sum(weights, axis=-1)).T
         new_idxs = xp.empty_like(weights, dtype=int)
         for ii in range(self.n_posteriors):
-            new_idxs[ii] = xp.asarray(
-                np.random.choice(
-                    range(self.samples_per_posterior),
-                    size=self.samples_per_posterior,
-                    replace=True,
-                    p=to_numpy(weights[ii]),
-                )
-            )
+            if "jax" in xp.__name__:
+                from jax import random
+
+                rng_key = random.PRNGKey(np.random.randint(10000000))
+                new_idxs = new_idxs.at[ii].set(
+                    random.choice(
+                        rng_key,
+                        xp.arange(self.samples_per_posterior),
+                        shape=(self.samples_per_posterior,),
+                        replace=True,
+                        p=weights[ii],
+                    )
+                )
+            else:
+                new_idxs[ii] = xp.asarray(
+                    np.random.choice(
+                        range(self.samples_per_posterior),
+                        size=self.samples_per_posterior,
+                        replace=True,
+                        p=to_numpy(weights[ii]),
+                    )
+                )
         new_samples = {
             key: xp.vstack(
                 [self.data[key][ii, new_idxs[ii]] for ii in range(self.n_posteriors)]
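The new branch exists because JAX arrays are immutable and JAX's random number generation is functional: in-place assignment becomes new_idxs.at[ii].set(...), and randomness needs an explicit PRNGKey. A minimal standalone sketch of the two paths, with a hypothetical three-element weight vector rather than the package's own data:

    import numpy as np

    weights = [0.1, 0.2, 0.7]
    try:
        import jax.numpy as jnp
        from jax import random

        idxs = jnp.empty((1, 5), dtype=int)
        key = random.PRNGKey(np.random.randint(10000000))
        # .at[...].set(...) returns a new array; JAX arrays cannot be mutated.
        idxs = idxs.at[0].set(
            random.choice(key, jnp.arange(3), shape=(5,), replace=True, p=jnp.array(weights))
        )
    except ImportError:
        # NumPy path: plain in-place assignment.
        idxs = np.empty((1, 5), dtype=int)
        idxs[0] = np.random.choice(range(3), size=5, replace=True, p=weights)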
32 changes: 26 additions & 6 deletions gwpopulation/models/interped.py
@@ -21,23 +21,39 @@ def _setup_interpolant(nodes, values, kind="cubic", backend=xp):
     return interpolant


-class InterpolatedNoBaseModelIdentical(object):
+class InterpolatedNoBaseModelIdentical:
     """
     Base class for the Interpolated classes with no base model
+
+    Parameters
+    ==========
+    parameters: list
+        List of parameters to interpolate over, e.g., :code:`["a_1", "a_2"]`
+    minimum: float
+        Minimum value to normalize the spline over
+    maximum: float
+        Maximum value to normalize the spline over
+    nodes: int
+        Number of nodes to use in the spline, default=10
+    kind: str
+        The interpolation order of the spline, default="cubic"
+    log_nodes: bool
+        Whether to use log-spaced nodes, default=False
     """

-    def __init__(self, parameters, minimum, maximum, nodes=10, kind="cubic"):
+    def __init__(
+        self, parameters, minimum, maximum, nodes=10, kind="cubic", log_nodes=False
+    ):
         """ """
         self.nodes = nodes
         self.norm_selector = None
         self.spline_selector = None
         self._norm_spline = None
         self._data_spline = dict()
         self.kind = kind
         self._xs = xp.linspace(minimum, maximum, 10 * self.nodes)
         self.parameters = parameters
         self.min = minimum
         self.max = maximum
+        self.log_nodes = log_nodes

         self.base = self.parameters[0].strip("_1")
         self.xkeys = [f"{self.base}{ii}" for ii in range(self.nodes)]
@@ -53,10 +69,14 @@ def variable_names(self):
         return keys

     def setup_interpolant(self, nodes, values):
+        if self.log_nodes:
+            func = xp.log
+        else:
+            func = xp.array
         kwargs = dict(kind=self.kind, backend=xp)
-        self._norm_spline = _setup_interpolant(nodes, self._xs, **kwargs)
+        self._norm_spline = _setup_interpolant(func(nodes), func(self._xs), **kwargs)
         self._data_spline = {
-            param: _setup_interpolant(nodes, values[param], **kwargs)
+            param: _setup_interpolant(func(nodes), func(values[param]), **kwargs)
             for param in self.parameters
         }
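With log_nodes=True, both the node locations and the evaluation grid pass through xp.log before the spline is constructed, so the interpolant is uniform in log(x); with log_nodes=False, xp.array acts as a pass-through. A rough scipy-based illustration of the idea (the package's _setup_interpolant wraps a backend-specific cubic spline, so this is an analogy, not its implementation):

    import numpy as np
    from scipy.interpolate import CubicSpline

    nodes = np.geomspace(2, 100, 10)       # log-spaced node locations
    values = np.random.uniform(-1, 1, 10)  # spline values at the nodes
    xs = np.linspace(2, 100, 1000)         # evaluation grid

    # log_nodes=True: fit and evaluate in log-space
    log_spline = CubicSpline(np.log(nodes), values)
    perturbation = log_spline(np.log(xs))

    # log_nodes=False would instead fit directly: CubicSpline(nodes, values)(xs)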
138 changes: 134 additions & 4 deletions gwpopulation/models/mass.py
@@ -7,6 +7,7 @@
 import scipy.special as scs

 from ..utils import powerlaw, truncnorm
+from .interped import InterpolatedNoBaseModelIdentical

 xp = np

@@ -249,6 +250,27 @@ def power_law_primary_secondary_identical(dataset, alpha, mmin, mmax):
     )


+def power_law_mass(mass, alpha, mmin, mmax):
+    r"""
+    Power law model for one-dimensional mass distribution.
+
+    .. math::
+        p(m) \propto m^{-\alpha} : m_\min \leq m < m_\max
+
+    Parameters
+    ----------
+    mass: array-like
+        Array of mass values (:math:`m`).
+    alpha: float
+        Negative power law exponent for the black hole distribution (:math:`\alpha`).
+    mmin: float
+        Minimum black hole mass (:math:`m_\min`).
+    mmax: float
+        Maximum black hole mass (:math:`m_\max`).
+    """
+    return powerlaw(mass, alpha=-alpha, high=mmax, low=mmin)
+
+
 def two_component_single(
     mass, alpha, mmin, mmax, lam, mpp, sigpp, gaussian_mass_maximum=100
 ):
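The new power_law_mass above delegates to the powerlaw utility with alpha=-alpha, so the returned density is normalized over [mmin, mmax). A rough numpy sketch of the closed-form normalization it relies on (assumes the standard truncated power law with alpha != 1; the actual helper lives in gwpopulation.utils):

    import numpy as np

    def power_law_mass_sketch(mass, alpha, mmin, mmax):
        # p(m) = m**-alpha * (1 - alpha) / (mmax**(1 - alpha) - mmin**(1 - alpha))
        norm = (1 - alpha) / (mmax ** (1 - alpha) - mmin ** (1 - alpha))
        return np.where((mass >= mmin) & (mass < mmax), norm * mass ** -alpha, 0)

    mass = np.linspace(5, 100, 1000)
    prob = power_law_mass_sketch(mass, alpha=2.35, mmin=5, mmax=100)
    assert abs(np.trapz(prob, mass) - 1) < 1e-3  # integrates to ~1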
@@ -480,7 +502,7 @@ def two_component_primary_secondary_identical(
     )


-class BaseSmoothedMassDistribution(object):
+class BaseSmoothedMassDistribution:
     """
     Generic smoothed mass distribution base class.
@@ -558,8 +580,8 @@ def norm_p_m1(self, delta_m, **kwargs):
         p_m = self.__class__.primary_model(self.m1s, **kwargs)
         p_m *= self.smoothing(self.m1s, mmin=mmin, mmax=self.mmax, delta_m=delta_m)

-        norm = xp.trapz(p_m, self.m1s)
-        return norm ** (delta_m > 0)
+        norm = xp.where(delta_m > 0, xp.trapz(p_m, self.m1s), 1)
+        return norm

     def p_q(self, dataset, beta, mmin, delta_m):
         p_q = powerlaw(dataset["mass_ratio"], beta, 1, mmin / dataset["mass_1"])
@@ -583,7 +605,11 @@ def norm_p_q(self, beta, mmin, delta_m):
         p_q *= self.smoothing(
             self.m1s_grid * self.qs_grid, mmin=mmin, mmax=self.m1s_grid, delta_m=delta_m
         )
-        norms = xp.nan_to_num(xp.trapz(p_q, self.qs, axis=0)) ** (delta_m > 0)
+        norms = xp.where(
+            delta_m > 0,
+            xp.nan_to_num(xp.trapz(p_q, self.qs, axis=0)),
+            xp.ones(self.m1s.shape),
+        )

         return self._q_interpolant(norms)

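Both norm_p_m1 and norm_p_q drop the old ** (delta_m > 0) idiom (anything raised to the power False is 1, which silently disabled the normalization when smoothing was off) in favour of an explicit xp.where. The two are numerically equivalent; the where form states the intent directly and presumably traces more cleanly under JAX. A quick equivalence check:

    import numpy as np

    norm = np.array([2.5, 4.0])
    for delta_m in (0.0, 3.0):
        old = norm ** (delta_m > 0)            # norm**False == 1
        new = np.where(delta_m > 0, norm, 1)   # explicit branch
        assert np.allclose(old, new)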
@@ -781,3 +807,107 @@ class BrokenPowerLawPeakSmoothedMassDistribution(BaseSmoothedMassDistribution):
     @property
     def kwargs(self):
         return dict(gaussian_mass_maximum=self.mmax)
+
+
+class InterpolatedPowerlaw(
+    BaseSmoothedMassDistribution, InterpolatedNoBaseModelIdentical
+):
+    """
+    Interpolated powerlaw primary mass distribution with powerlaw mass ratio distribution.
+
+    See https://arxiv.org/abs/2109.06137 for details.
+
+    Parameters
+    ----------
+    dataset: dict
+        Dictionary of numpy arrays for 'mass_1' and 'mass_ratio'.
+    alpha: float
+        Powerlaw exponent for more massive black hole.
+    beta: float
+        Power law exponent of the mass ratio distribution.
+    mmin: float
+        Minimum black hole mass.
+    mmax: float
+        Maximum mass in the powerlaw distributed component.
+    delta_m: float
+        Rise length of the low end of the mass distribution.
+    mass{ii}: float
+        The locations of the spline nodes for the primary mass distribution.
+    fmass{ii}: float
+        The values of the spline nodes for the primary mass distribution.
+    """
+
+    primary_model = power_law_mass
+
+    def __init__(
+        self, nodes=10, kind="cubic", mmin=2, mmax=100, normalization_shape=(1000, 500)
+    ):
+        """
+        Parameters
+        ==========
+        nodes: int
+            Number of spline nodes to use for interpolation, default=10.
+        kind: str
+            Order of the spline to use for interpolation, default="cubic".
+        mmin: float
+            The minimum mass considered for numerical normalization, default=2.
+        mmax: float
+            The maximum mass considered for numerical normalization, default=100.
+        normalization_shape: tuple
+            Shape of the grid used for numerical normalization, default=(1000, 500).
+        """
+        BaseSmoothedMassDistribution.__init__(
+            self,
+            mmin=mmin,
+            mmax=mmax,
+            normalization_shape=normalization_shape,
+        )
+        InterpolatedNoBaseModelIdentical.__init__(
+            self,
+            minimum=mmin,
+            maximum=mmax,
+            parameters=["mass_1"],
+            nodes=nodes,
+            kind=kind,
+            log_nodes=True,
+        )
+        self._xs = self.m1s
+
+    @property
+    def variable_names(self):
+        variable_names = super().variable_names.union(
+            InterpolatedNoBaseModelIdentical.variable_names.fget(self)
+        )
+        return variable_names
+
+    def p_m1(self, dataset, **kwargs):
+        f_splines = xp.array([kwargs[key] for key in self.fkeys])
+        m_splines = xp.array([kwargs[key] for key in self.xkeys])
+
+        mmin = kwargs.get("mmin", self.mmin)
+        delta_m = kwargs.pop("delta_m", 0)
+        p_m = self.__class__.primary_model(
+            dataset["mass_1"], **{key: kwargs[key] for key in ["alpha", "mmin", "mmax"]}
+        )
+        p_m *= self.smoothing(
+            dataset["mass_1"], mmin=mmin, mmax=self.mmax, delta_m=delta_m
+        )
+        p_m *= self.p_x_unnormed(dataset, "mass_1", m_splines, f_splines, **kwargs)
+
+        norm = self.norm_p_m1(delta_m=delta_m, f_splines=f_splines, **kwargs)
+        return p_m / norm
+
+    def norm_p_m1(self, delta_m, f_splines=None, **kwargs):
+        mmin = kwargs.get("mmin", self.mmin)
+        p_m = self.__class__.primary_model(
+            self.m1s, **{key: kwargs[key] for key in ["alpha", "mmin", "mmax"]}
+        )
+        p_m = xp.where(
+            delta_m > 0,
+            p_m * self.smoothing(self.m1s, mmin=mmin, mmax=self.mmax, delta_m=delta_m),
+            p_m,
+        )
+        p_m *= xp.exp(self._norm_spline(y=f_splines))
+        norm = xp.trapz(p_m, self.m1s)
+        return norm
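A hypothetical end-to-end sketch of the new model. Parameter names follow the class docstring; the call goes through the __call__ interface inherited from BaseSmoothedMassDistribution, and since the node values enter as log-perturbations via xp.exp(...) in norm_p_m1, all-zero fmass{ii} values should reduce to the plain smoothed power law:

    import numpy as np
    from gwpopulation.models.mass import InterpolatedPowerlaw

    model = InterpolatedPowerlaw(nodes=5, mmin=2, mmax=100)
    dataset = dict(
        mass_1=np.random.uniform(5, 80, 1000),
        mass_ratio=np.random.uniform(0.1, 1, 1000),
    )
    parameters = dict(alpha=2.5, beta=1.0, mmin=5.0, mmax=80.0, delta_m=3.0)
    # Spline node locations (mass{ii}) and values (fmass{ii}).
    parameters.update({f"mass{ii}": mm for ii, mm in enumerate(np.geomspace(5, 80, 5))})
    parameters.update({f"fmass{ii}": 0.0 for ii in range(5)})
    prob = model(dataset, **parameters)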
81 changes: 18 additions & 63 deletions test/example_test.py
@@ -12,29 +12,29 @@
 import gwpopulation

 from . import TEST_BACKENDS
-from .jax_utils import (
-    JittedLikelihood,
-    NonCachingModel,
-    generic_bilby_likelihood_function,
-)
+from .jax_utils import JittedLikelihood, NonCachingModel


-@pytest.mark.parametrize("backend", TEST_BACKENDS)
-def test_likelihood_evaluation(backend):
+def _template_likelihood_evaluation(backend, jit):
     gwpopulation.set_backend(backend)
     xp = gwpopulation.models.mass.xp
     bilby.core.utils.random.seed(10)
     rng = bilby.core.utils.random.rng

-    model = bilby.hyper.model.Model(
+    if jit:
+        model_cls = NonCachingModel
+    else:
+        model_cls = Model
+
+    model = model_cls(
         [
             gwpopulation.models.mass.SinglePeakSmoothedMassDistribution(),
             gwpopulation.models.spin.independent_spin_magnitude_beta,
             gwpopulation.models.spin.independent_spin_orientation_gaussian_isotropic,
             gwpopulation.models.redshift.PowerLawRedshift(),
         ]
     )
-    vt_model = bilby.hyper.model.Model(
+    vt_model = model_cls(
         [
             gwpopulation.models.mass.SinglePeakSmoothedMassDistribution(),
             gwpopulation.models.spin.independent_spin_magnitude_beta,
@@ -67,70 +67,25 @@ def test_likelihood_evaluation(backend):
         hyper_prior=model,
         posteriors=posteriors,
         selection_function=selection,
-        cupy=backend == "cupy",
+        cupy=False,
     )
+    if jit:
+        likelihood = JittedLikelihood(likelihood)

     priors = bilby.core.prior.PriorDict("priors/bbh_population.prior")

     likelihood.parameters.update(priors.sample())
     assert abs(likelihood.log_likelihood_ratio() + 1.810695) < 0.01
+    likelihood.posterior_predictive_resample(pd.DataFrame(priors.sample(5)))


-def test_jit_likelihood():
-    gwpopulation.set_backend("jax")
-    xp = gwpopulation.models.mass.xp
-    bilby.core.utils.random.seed(10)
-    rng = bilby.core.utils.random.rng
-
-    model = NonCachingModel(
-        [
-            gwpopulation.models.mass.SinglePeakSmoothedMassDistribution(),
-            gwpopulation.models.spin.independent_spin_magnitude_beta,
-            gwpopulation.models.spin.independent_spin_orientation_gaussian_isotropic,
-            gwpopulation.models.redshift.PowerLawRedshift(),
-        ]
-    )
-    vt_model = NonCachingModel(
-        [
-            gwpopulation.models.mass.SinglePeakSmoothedMassDistribution(),
-            gwpopulation.models.spin.independent_spin_magnitude_beta,
-            gwpopulation.models.spin.independent_spin_orientation_gaussian_isotropic,
-            gwpopulation.models.redshift.PowerLawRedshift(),
-        ]
-    )
-
-    bounds = dict(
-        mass_1=(20, 25),
-        mass_ratio=(0.9, 1),
-        a_1=(0, 1),
-        a_2=(0, 1),
-        cos_tilt_1=(-1, -1),
-        cos_tilt_2=(-1, -1),
-        redshift=(0, 2),
-        prior=(1, 1),
-    )
-    posteriors = [
-        pd.DataFrame({key: rng.uniform(*bound, 100) for key, bound in bounds.items()})
-        for _ in range(10)
-    ]
-    vt_data = {
-        key: xp.asarray(rng.uniform(*bound, 10000)) for key, bound in bounds.items()
-    }
-
-    selection = gwpopulation.vt.ResamplingVT(vt_model, vt_data, len(posteriors))
-
-    likelihood = gwpopulation.hyperpe.HyperparameterLikelihood(
-        hyper_prior=model,
-        posteriors=posteriors,
-        selection_function=selection,
-        cupy=False,
-    )
-    likelihood = JittedLikelihood(likelihood)
-
-    priors = bilby.core.prior.PriorDict("priors/bbh_population.prior")
-
-    likelihood.parameters.update(priors.sample())
-    assert abs(likelihood.log_likelihood_ratio() + 1.810695) < 0.01
+@pytest.mark.parametrize("backend", TEST_BACKENDS)
+def test_likelihood_evaluation(backend):
+    _template_likelihood_evaluation(backend, False)
+
+
+def test_jit_likelihood():
+    _template_likelihood_evaluation("jax", True)


 def test_prior_files_load():
(The remaining two of the six changed files are not shown here.)
