Inferring the Distribution of Gravitational Waves

In this notebook we infer the population distribution of gravitational-wave events in our universe by performing a hierarchical Bayesian analysis. We reproduce a simplified version of the population analysis carried out by the LIGO/Virgo Collaboration in this paper.

Hierarchical Likelihood Inference

[100]:
import numpy as np
import pandas as pd
import json
import jax.numpy as jnp

# Load the per-event posterior samples. Each top-level JSON key is an event
# name and its value is turned into a DataFrame of samples.
# The encoding is stated explicitly so the read does not depend on the
# platform's default locale.
with open("./event_posterior_samples.json", "r", encoding="utf-8") as f:
    event_samples = {k: pd.DataFrame(v) for k, v in json.load(f).items()}

# Parallel lists, in the JSON object's key order: event names and the
# matching posterior-sample DataFrames.
eventnames = list(event_samples)
posteriors = list(event_samples.values())
[63]:
# Read the samples used for the selection-function estimate into a DataFrame.
with open("./selection_function_samples.json", "r") as f:
    selection_samples = pd.DataFrame(json.load(f))

# 2000 random row indices in [0, 8000), drawn independently (repeats are
# possible). The same `indices` are reused to sub-sample the event posteriors.
# NOTE(review): assumes every table has at least 8000 rows — confirm.
indices = np.random.randint(low=0, high=8000, size=2000)

# Keep every column except the two name columns, as JAX arrays restricted to
# the selected rows.
selections_jaxed = {
    column: jnp.array(selection_samples[column][indices])
    for column in selection_samples.columns
    if column not in ('waveform_name', 'name')
}
[64]:
# One dict of JAX arrays per event: every column except 'waveform_name',
# restricted to the rows picked by `indices`.
posteriors_jaxed = [
    {
        name: jnp.array(samples[name][indices])
        for name in samples.columns
        if name != 'waveform_name'
    }
    for samples in posteriors
]

Population Model Recap

We use individual event likelihoods to infer the hierarchical model.

\[\begin{equation} p(\Lambda \mid d) \propto p(\Lambda)\prod_{i=1}^{N} \mathcal{L}\left(d_i \mid \Lambda\right) \end{equation}\]

where

\[\begin{equation} \mathcal{L}\left(d_i \mid \Lambda\right) \propto \int \frac{p\left(\theta \mid d_i\right) p(\theta \mid \Lambda)}{p(\theta)} \, \mathrm{d}\theta \approx \left\langle \frac{p(\theta \mid \Lambda)}{p(\theta)} \right\rangle_{\theta \sim p\left(\theta \mid d_i\right)} \end{equation}\]

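Here $p(\theta)$ is the prior that was used in the single-event analysis, and the average runs over that event's posterior samples. We will “recycle” these posterior samples in this way to obtain the result for the following population model: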

[66]:
## Mass model
def power_law(x, L, a, b):
    """Evaluate a power law ``x**L`` normalized over the interval ``[a, b]``.

    The normalization constant is ``(L + 1) / (b**(L + 1) - a**(L + 1))``.
    At ``L == -1`` that expression is 0/0, so the analytic limit
    ``1 / (log(b) - log(a))`` is used instead.

    No truncation is applied: values of ``x`` outside ``[a, b]`` still return
    ``x**L`` times the normalization constant.

    Args:
        x: Point(s) at which to evaluate the density.
        L: Power-law index.
        a: Lower edge of the normalization interval.
        b: Upper edge of the normalization interval.

    Returns:
        ``x**L`` multiplied by the normalization constant.
    """
    exponent = L + 1
    is_log_case = exponent == 0
    # Substitute a harmless exponent where L == -1 so the general branch never
    # evaluates 0/0; jnp.where evaluates both branches.
    safe_exponent = jnp.where(is_log_case, 1.0, exponent)
    general_norm = safe_exponent / (jnp.power(b, safe_exponent) - jnp.power(a, safe_exponent))
    log_norm = 1.0 / (jnp.log(b) - jnp.log(a))
    normalization = jnp.where(is_log_case, log_norm, general_norm)
    return jnp.power(x, L) * normalization

def trunc_normal_pdf(x, mu, sig, a, b):
    """Density of a normal(mu, sig) truncated to the interval ``[a, b]``.

    ``a`` and ``b`` are given in the same units as ``x``; they are converted
    to standardized bounds ``(bound - mu) / sig`` before being passed to
    ``jax.scipy.stats.truncnorm.pdf``.
    """
    lower_std = (a - mu) / sig
    upper_std = (b - mu) / sig
    return jax.scipy.stats.truncnorm.pdf(x, lower_std, upper_std, loc=mu, scale=sig)
[67]:
class MassModel:
    """Joint density over primary mass and mass ratio.

    The primary-mass density is a mixture of a power law and a truncated
    normal peak, both defined on ``[m_min, m_max]``. The mass-ratio density is
    a power law normalized on ``[m_min / mass_1_source, 1]``.
    """

    def __init__(self, m_min, m_max):
        self.m_min = m_min
        self.m_max = m_max

    def pdf(self, data, params):
        """Return ``p(m1) * p(q | m1)`` for the samples in ``data``.

        Args:
            data: Mapping with ``'mass_1_source'`` and ``'mass_ratio'`` entries.
            params: Mapping with ``'lambda'`` (mass power-law index),
                ``'mu_m'`` / ``'sigma_m'`` (peak location and width),
                ``'gamma'`` (mass-ratio power-law index) and ``'fp'``
                (weight given to the power-law component).
        """
        m1 = data['mass_1_source']
        mix = params['fp']

        power_law_part = power_law(m1, params['lambda'], self.m_min, self.m_max)
        peak_part = trunc_normal_pdf(m1, params['mu_m'], params['sigma_m'], self.m_min, self.m_max)
        # The lower mass-ratio bound keeps the secondary mass at or above m_min.
        ratio_density = power_law(data['mass_ratio'], params['gamma'], self.m_min / m1, 1.0)

        mass_density = mix * power_law_part + (1 - mix) * peak_part
        return mass_density * ratio_density

    def __call__(self, data, params):
        """Alias for :meth:`pdf`, so an instance can be used as a function."""
        return self.pdf(data, params)
[68]:
class Redshift:
    """Redshift population model ``dV/dz(z) * (1 + z)**(kappa - 1)``.

    ``dV/dz`` is tabulated once on a 300-point redshift grid from the Planck15
    cosmology and linearly interpolated afterwards. Normalization constants
    are tabulated lazily on a 3000-point grid of ``kappa`` in ``[-10, 10]``.

    Calling an instance returns the *un-normalized* density (see ``__call__``).
    """

    def __init__(self, z_max):
        """Tabulate ``dV/dz`` on ``[0, z_max]``.

        Args:
            z_max: Upper edge of the redshift grid.
        """
        # Local import: astropy is only required when an instance is built.
        from astropy.cosmology import Planck15
        self.z = jnp.linspace(0, z_max, 300)
        # Per the astropy documentation the differential comoving volume is
        # quoted per steradian; the factor 4*pi converts it to the full sky.
        self.y = Planck15.differential_comoving_volume(self.z).value * 4 * np.pi
        self.kappas = jnp.linspace(-10, 10, 3000)
        self._kappa_norms = None  # filled on first access of `kappa_norms`

    def dVdz(self, z):
        """Linearly interpolate the tabulated ``dV/dz`` at ``z``."""
        return jnp.interp(z, self.z, self.y)

    def normalization_func(self, z, kappa):
        """Integrand of the normalization: ``(1 + z)**(kappa - 1) * dV/dz``."""
        return (1 + z)**(kappa - 1) * self.dVdz(z)

    @property
    def kappa_norms(self):
        """Normalization integrals on the ``kappas`` grid, computed once."""
        if self._kappa_norms is None:
            self._normalize()
        return self._kappa_norms

    def _normalize(self):
        """Integrate the model over the redshift grid for every grid kappa.

        The trapezoid rule is mapped over ``self.kappas`` with ``jax.vmap``
        instead of a Python loop, so the whole table is produced by one
        vectorized computation without a host round-trip per kappa.
        """
        def integral_for(kappa):
            return jax.scipy.integrate.trapezoid(self.normalization_func(self.z, kappa), self.z)

        self._kappa_norms = jax.vmap(integral_for)(self.kappas)

    def normalization(self, kappa):
        """Interpolate the tabulated normalization integral at ``kappa``."""
        return jnp.interp(kappa, self.kappas, self.kappa_norms)

    def __call__(self, data, params):
        """Return the un-normalized density at ``data['redshift']``.

        Args:
            data: Mapping with a ``'redshift'`` entry.
            params: Mapping with a ``'kappa'`` entry.

        NOTE(review): the result is deliberately left as in the original,
        i.e. it is NOT divided by ``self.normalization(params['kappa'])``;
        that division was present only as commented-out code. Confirm that
        the omission is intended before relying on absolute values.
        """
        un_normalized = self.dVdz(data['redshift']) * ((1 + data['redshift']) ** (params["kappa"]-1))
        return un_normalized
[69]:
# Module-level model instances: redshift model tabulated up to z = 3 and the
# mass model on the range [2, 100].
R = Redshift(z_max=3.0)
M = MassModel(m_min=2.0, m_max=100.0)
[96]:
class CustomLikelihood:
    """Hierarchical log-likelihood built from recycled posterior samples.

    Per-event posterior samples and selection-function samples are re-weighted
    by ``mass_model * redshift_model / prior`` and combined with log-sum-exp.
    The Monte Carlo ``1 / n_samples`` factors are not applied, so the result
    differs from the sample-mean estimate by an additive term that does not
    depend on the hyper-parameters.

    Args:
        all_posteriors: Sequence with one mapping per event; each mapping must
            provide equally sized ``'mass_1_source'``, ``'mass_ratio'``,
            ``'redshift'`` and ``'prior'`` arrays (they are stacked).
        selections: Mapping with the entries the models read, plus ``'prior'``.
        domain_changer: Stored on the instance; not used by this class.
        mass_model: Optional callable ``(data, params) -> density``. When
            omitted, the module-level ``M`` is looked up at call time.
        redshift_model: Optional callable ``(data, params) -> density``. When
            omitted, the module-level ``R`` is looked up at call time.
    """

    def __init__(self, all_posteriors, selections, domain_changer=None,
                 mass_model=None, redshift_model=None):
        self.all_posteriors = all_posteriors
        # One (n_events, n_samples) array per quantity.
        self.posts = {
            key: jnp.stack([post[key] for post in self.all_posteriors])
            for key in ['mass_1_source', 'mass_ratio', 'redshift', 'prior']
        }
        self.selections = selections
        self.domain_changer = domain_changer
        self.N_events = len(self.all_posteriors)
        self.mass_model = mass_model
        self.redshift_model = redshift_model

    def _log_weights(self, data, x):
        """Log of ``mass_model * redshift_model / prior`` for each sample in ``data``."""
        # Fall back to the module-level models only when none were injected, and
        # resolve them here rather than in __init__ so a later re-binding of
        # M / R is still picked up, as it was before the models were injectable.
        mass_model = M if self.mass_model is None else self.mass_model
        redshift_model = R if self.redshift_model is None else self.redshift_model
        return jnp.log(mass_model(data, x)) + jnp.log(redshift_model(data, x)) - jnp.log(data["prior"])

    def logpdf(self, x):
        """Return the hierarchical log-likelihood at hyper-parameters ``x``.

        Args:
            x: Mapping of hyper-parameter names to values, forwarded unchanged
                to the mass and redshift models.
        """
        # Per-event term: log-sum-exp over each event's samples, summed over events.
        event_likelihoods = jnp.sum(jax.scipy.special.logsumexp(self._log_weights(self.posts, x), axis=1))
        # Selection term: one log-sum-exp over the selection samples, applied
        # once per event with a negative sign.
        selection_effects = -self.N_events * jax.scipy.special.logsumexp(self._log_weights(self.selections, x))
        return event_likelihoods + selection_effects
[97]:
from bayesian_inference import NUTS, MHSampler
import jax

# Two-element range per hyper-parameter, handed to the sampler as `limits`.
# NOTE(review): presumably [lower, upper] bounds — confirm against NUTS.
limits = {
    'lambda': [-5, 2],
    "gamma": [-3, 3],
    'fp': [0, 1],
    'mu_m': [10, 50],
    'sigma_m': [3, 10],
    'kappa': [-10, 10],
}
# Starting point for the sampler; every value lies inside its range above.
initial_point = {
    'lambda': -2.35,
    "gamma": 1.1,
    'fp': 0.1,
    'mu_m': 33.0,
    'sigma_m': 4.0,
    'kappa': 2.9,
}
[98]:
# Hierarchical likelihood built from the per-event posterior arrays and the
# selection-function arrays prepared earlier.
CL = CustomLikelihood(posteriors_jaxed, selections_jaxed)
[74]:
# Set up the NUTS sampler (imported from `bayesian_inference`) on the
# likelihood object, starting from `initial_point`.
# NOTE(review): how `limits` is enforced is defined inside NUTS — confirm there.
N = NUTS(CL, initial_point, limits=limits)
[75]:
# Run the sampler; the logged output reports this as an inference for 2000 samples.
result = N.run(2000)
Running the inference for 2000 samples
[76]:
# Tabulate the sampler output, one column per hyper-parameter.
# NOTE(review): the corner plot labels its axes positionally, so this assumes
# the columns come out in the order lambda, gamma, fp, mu_m, sigma_m, kappa — confirm.
df = pd.DataFrame(result)
[79]:
import corner

# Corner plot of the sampled hyper-parameters. Labels are matched to the
# DataFrame columns by position; the quantiles mark the median and the
# 16th/84th percentiles.
figure = corner.corner(
    df.values,
    show_titles=True,
    title_kwargs=dict(fontsize=12),
    quantiles=[0.16, 0.5, 0.84],
    labels=[
        r"$\lambda$", r"$\gamma$", r"$f_p$",
        r"$\mu_m$", r"$\sigma_m$", r"$\kappa$",
    ],
)
_images/BlackHolePopulationAnalysis_16_0.png