{
"cells": [
{
"cell_type": "markdown",
"id": "9b6b6e0f-02a5-4057-bc0e-4d279d5f6a6b",
"metadata": {},
"source": [
"# Gravpop tutorial\n",
"\n",
"[](https://colab.research.google.com/github/potatoasad/gravpop/blob/main/docs/Examples/gravpop_tutorial.ipynb)\n",
"\n",
"This is a library that allows you to perform a population analysis, ala [Thrane et. al](https://arxiv.org/abs/1809.02293), but using a trick described in [Hussain et. al](...) that allows one to be able to probe population features even when they get very narrow, and get close to the edges of a bounded domain. \n",
"\n",
"The trick essentially relies on dividing the parameter space into a sector (which we call the __analytic__ sector $\\theta^a$) where our population model is made out of some weighted sum of multivariate truncated normals - where we can analytically compute the population likelihood, and another where the model is general and we can compute it using the monte-carlo estimate of the population likelihood (we call this sector the __sampled__ sector $\\theta^s$). \n",
"\n",
"This trick involves representing the posterior samples as a truncated gaussian mixture model (TGMM). See [truncatedgaussianmixtures](https://github.com/Potatoasad/truncatedgaussianmixtures) for a package that can fit a dataset to a mixture of truncated gaussians. \n",
"\n",
"We can install all the packages we need:"
]
},
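{
"cell_type": "markdown",
"id": "3f2a1b4c-5d6e-4f70-8a9b-1c2d3e4f5a6b",
"metadata": {},
"source": [
"To see why truncated-normal kernels are useful, here is a small self-contained sketch (using `scipy`, not the `gravpop` API): the overlap integral of a posterior kernel and a population model component, both truncated normals on $[0,1]$, has a closed form, which we can check against a Monte Carlo estimate drawn from the kernel.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.stats import norm, truncnorm\n",
"\n",
"def tn_norm(mu, sigma):\n",
"    # Normalization of a normal truncated to [0, 1]\n",
"    return norm.cdf((1 - mu) / sigma) - norm.cdf((0 - mu) / sigma)\n",
"\n",
"def analytic_overlap(mu_k, sig_k, mu_p, sig_p):\n",
"    # Product of two Gaussians is a scaled Gaussian\n",
"    s2 = sig_k**2 + sig_p**2\n",
"    scale = norm.pdf(mu_k, loc=mu_p, scale=np.sqrt(s2))\n",
"    mu_star = (mu_k * sig_p**2 + mu_p * sig_k**2) / s2\n",
"    sig_star = np.sqrt(sig_k**2 * sig_p**2 / s2)\n",
"    mass = norm.cdf((1 - mu_star) / sig_star) - norm.cdf(-mu_star / sig_star)\n",
"    return scale * mass / (tn_norm(mu_k, sig_k) * tn_norm(mu_p, sig_p))\n",
"\n",
"mu_k, sig_k = 0.1, 0.05   # narrow kernel near the edge of the domain\n",
"mu_p, sig_p = 0.2, 0.10   # population model component\n",
"\n",
"a, b = (0 - mu_k) / sig_k, (1 - mu_k) / sig_k\n",
"x = truncnorm.rvs(a, b, loc=mu_k, scale=sig_k, size=200_000, random_state=0)\n",
"mc = (norm.pdf(x, mu_p, sig_p) / tn_norm(mu_p, sig_p)).mean()\n",
"\n",
"print(analytic_overlap(mu_k, sig_k, mu_p, sig_p), mc)\n",
"```\n",
"\n",
"The two numbers agree to Monte Carlo accuracy; the analytic sector relies on closed forms of this kind, which stay exact even for kernels pushed against the domain boundary.\n"
]
},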
{
"cell_type": "code",
"execution_count": null,
"id": "2e49a0b0-0a74-4f3d-997a-d1880aab19b1",
"metadata": {},
"outputs": [],
"source": [
"#!pip install numpy jax jaxlib numpyro astropy matplotlib scipy tqdm pandas\n",
"#!pip install gravpop"
]
},
{
"cell_type": "markdown",
"id": "63e8507b-9c12-4ccb-bdea-2b602889b18e",
"metadata": {},
"source": [
"# Data\n",
"For this trick to work, we expect the data to be in a particular format. Given a dataset of $E$ events each fitted to a TGMM using $K$ components, we need the following to be able to do the analytic estimates of the likelihood integral. \n",
"\n",
"- For parameters that are in the __sampled sector__ (e.g. mass, redshift) we desire $N$ samples for every component to be able to do the monte-carlo estimates of the likelihood integral. For each parameter $x$ we desire \n",
" - $E\\times K \\times N$ array called `'x'`, representing the value of parameter $x$ in the sample\n",
" - $E\\times K \\times N$ array called `'prior'`, representing the prior evaluated on each of these samples\n",
" \n",
"- For parameters that are in the __analytic sector__ (e.g. spin orientation, spin magnitude) for each parameter $x$ we desire \n",
" - $E\\times K$ array called `'x_mu_kernel'`, representing the location parameter of each TGMM component\n",
" - $E\\times K$ array called `'x_sigma_kernel'`, representing the scale parameter of each TGMM component\n",
" - $E\\times K$ array called `'x_rho_kernel'`, representing the corrleation parameter of each TGMM component with some other parameter (refer to the documentation of the generated data to infer which other coordinate this correlation correponds to).\n",
" - $E\\times K$ array called `'weights'`, representing the weight of each TGMM component\n",
"\n",
"Here is an example of the form of the data that gwpop uses:\n",
"\n",
"```python\n",
"{'mass_1_source' : [...], # E x K x N Array\n",
" 'prior' : [...], # E x K x N Array\n",
" 'chi_1_mu_kernel' : [...], # E x K Array\n",
" 'chi_1_sigma_kernel' : [...], # E x K Array\n",
" 'chi_2_mu_kernel' : [...], # E x K Array\n",
" 'chi_2_sigma_kernel' : [...], # E x K Array\n",
" 'chi_1_rho_kernel' : [...], # E x K Array\n",
" 'weights' : [...] # E x K Array\n",
"}\n",
"```\n",
"\n",
"\n",
"One can also load from a saved HDF5 dataproduct with the\n",
"following format: \n",
"\n",
"```python\n",
"{'GW150914': {'mass_1_source' : [...], # K x N Array\n",
" 'prior' : [...], # K x N Array\n",
" 'chi_1_mu_kernel' : [...], # K Array\n",
" 'chi_1_sigma_kernel' : [...], # K Array\n",
" 'chi_2_mu_kernel' : [...], # K Array\n",
" 'chi_2_sigma_kernel' : [...], # K Array\n",
" 'chi_1_rho_kernel' : [...], # K Array\n",
" 'weights' : [...] # K Array\n",
" }\n",
"'GW190517' : ...\n",
"}\n",
"```\n",
"\n",
"which is then internally converted upon loading\n",
"\n",
"\n",
"## Creating Fits to Data\n",
"A quick example way to perform this fitting using the `truncatedgaussianmixtures` library, given we have posterior samples (with precomputed priors), is the following:\n",
"\n",
"Lets first pull the event data. The `load_hdf5_to_jax_dict` utility will pull in an hdf5 file containing datasets in some nested structure, and provide a dictionary with all datasets in the form of a dictionary of jax arrays. "
]
},
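{
"cell_type": "markdown",
"id": "7e8f9a0b-1c2d-4e3f-9a5b-6c7d8e9f0a1b",
"metadata": {},
"source": [
"A generic version of such a loader (a sketch using `h5py` and `numpy` for brevity, not necessarily the actual `gravpop` implementation) recurses through the file and converts every dataset into an array:\n",
"\n",
"```python\n",
"import h5py\n",
"import numpy as np\n",
"\n",
"def load_hdf5_to_dict(path):\n",
"    # Recursively load every HDF5 dataset into a nested dict of arrays\n",
"    def visit(group):\n",
"        out = {}\n",
"        for key, item in group.items():\n",
"            if isinstance(item, h5py.Group):\n",
"                out[key] = visit(item)\n",
"            else:\n",
"                out[key] = np.asarray(item[()])\n",
"        return out\n",
"    with h5py.File(path, 'r') as f:\n",
"        return visit(f)\n",
"```\n"
]
},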
{
"cell_type": "code",
"execution_count": 3,
"id": "3400f274",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n", " | chi_1 | \n", "chi_2 | \n", "chirp_mass | \n", "cos_tilt_1 | \n", "cos_tilt_2 | \n", "mass_1_source | \n", "mass_ratio | \n", "prior | \n", "redshift | \n", "
---|---|---|---|---|---|---|---|---|---|
0 | \n", "0.001930 | \n", "0.097727 | \n", "30.385015 | \n", "0.648332 | \n", "-0.889278 | \n", "34.860180 | \n", "0.846616 | \n", "0.002376 | \n", "0.088913 | \n", "
1 | \n", "0.028095 | \n", "0.041392 | \n", "30.547218 | \n", "-0.478594 | \n", "-0.951041 | \n", "34.373547 | \n", "0.860291 | \n", "0.003205 | \n", "0.101225 | \n", "
2 | \n", "0.225305 | \n", "0.275129 | \n", "31.643066 | \n", "0.137109 | \n", "0.303796 | \n", "35.392544 | \n", "0.852512 | \n", "0.004326 | \n", "0.113009 | \n", "
3 | \n", "0.000255 | \n", "0.453062 | \n", "30.345112 | \n", "0.380074 | \n", "-0.241549 | \n", "35.997993 | \n", "0.765623 | \n", "0.003989 | \n", "0.108616 | \n", "
4 | \n", "0.055366 | \n", "0.084168 | \n", "31.065317 | \n", "-0.569629 | \n", "-0.102006 | \n", "34.082230 | \n", "0.920346 | \n", "0.002493 | \n", "0.091574 | \n", "
5 | \n", "0.618177 | \n", "0.570168 | \n", "30.133745 | \n", "-0.140515 | \n", "-0.041809 | \n", "36.690060 | \n", "0.734442 | \n", "0.003609 | \n", "0.103474 | \n", "
6 | \n", "0.721023 | \n", "0.114889 | \n", "30.238930 | \n", "-0.177807 | \n", "-0.596842 | \n", "40.232189 | \n", "0.652075 | \n", "0.001779 | \n", "0.074039 | \n", "
7 | \n", "0.002967 | \n", "0.082386 | \n", "31.107443 | \n", "-0.801190 | \n", "0.647356 | \n", "36.535999 | \n", "0.781811 | \n", "0.003973 | \n", "0.107783 | \n", "
8 | \n", "0.026698 | \n", "0.268805 | \n", "30.639584 | \n", "-0.344281 | \n", "-0.528770 | \n", "33.594978 | \n", "0.893774 | \n", "0.003713 | \n", "0.108505 | \n", "
9 | \n", "0.483900 | \n", "0.414391 | \n", "30.279470 | \n", "0.298207 | \n", "-0.782193 | \n", "33.710155 | \n", "0.855433 | \n", "0.004422 | \n", "0.116259 | \n", "
\n", " | eta | \n", "lamb | \n", "mu_1 | \n", "sigma_1 | \n", "sigma_2 | \n", "
---|---|---|---|---|---|
0 | \n", "0.476112 | \n", "0.071545 | \n", "0.090596 | \n", "0.105491 | \n", "0.077657 | \n", "
1 | \n", "0.460751 | \n", "0.095909 | \n", "0.048131 | \n", "0.113161 | \n", "0.130583 | \n", "
2 | \n", "0.592883 | \n", "0.289967 | \n", "0.315391 | \n", "0.284772 | \n", "0.199418 | \n", "
3 | \n", "0.978045 | \n", "0.010133 | \n", "0.082620 | \n", "0.111385 | \n", "0.073198 | \n", "
4 | \n", "0.975931 | \n", "0.428702 | \n", "0.117799 | \n", "0.270596 | \n", "0.155701 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
995 | \n", "0.790344 | \n", "0.017701 | \n", "0.017430 | \n", "0.074497 | \n", "0.001218 | \n", "
996 | \n", "0.641068 | \n", "0.332459 | \n", "0.011991 | \n", "0.334153 | \n", "0.004681 | \n", "
997 | \n", "0.863573 | \n", "0.144214 | \n", "0.000969 | \n", "0.113832 | \n", "0.059052 | \n", "
998 | \n", "0.765053 | \n", "0.170739 | \n", "0.008454 | \n", "0.110767 | \n", "0.134668 | \n", "
999 | \n", "0.837497 | \n", "0.067079 | \n", "0.225776 | \n", "0.166493 | \n", "0.009897 | \n", "
1000 rows × 5 columns
\n", "