## Before running this module, install the following packages:
## pip install jax
## pip install jaxlib
## pip install blackjax
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np
import pandas as pd
import blackjax
from .DomainChanger import DomainChanger
[docs]
def convert_to_jax_array(dictionary):
    r"""
    Convert scalar values in a dictionary to JAX arrays.

    Parameters
    ----------
    dictionary : dict
        The input dictionary containing values to be converted.
        It is left unmodified.

    Returns
    -------
    dict
        A new dictionary in which every ``float`` or ``int`` value
        (``bool`` included, since it is a subclass of ``int``) is replaced
        by a JAX array of shape ``(1,)``. All other values are carried over
        as the same objects.
    """
    # Build a fresh dict instead of assigning into ``dictionary`` so the
    # caller's mapping keeps its original scalar values.
    return {
        key: jnp.array([value]) if isinstance(value, (float, int)) else value
        for key, value in dictionary.items()
    }
[docs]
class NUTS:
    r"""
    Solves a problem specified by a likelihood object using the NUTS sampler.

    The constructor wraps ``likelihood.logpdf`` with a ``DomainChanger``,
    runs blackjax window adaptation and builds the NUTS step kernel.
    ``run`` then draws samples and maps them back through the domain
    changer's inverse transform.
    """

    def __init__(self, likelihood, init_position, limits=None,
                 step_size=1e-3, inverse_mass_matrix=None, rng_key=None,
                 warmup_steps=100):
        r"""
        Initialize the NUTS sampler.

        Parameters
        ----------
        likelihood : object
            The likelihood object representing the problem to be solved.
            Its ``logpdf`` attribute is used as the log-density.
        init_position : dict
            The initial positions for the sampler. The dictionary passed in
            is not modified.
        limits : dict, optional
            The limits for each variable, default is None (unbounded).
            Variables missing from ``limits`` are treated as ``'infinite'``.
        step_size : float, optional
            The step size for the sampler, defaults to 1e-3. It is stored on
            the instance but not forwarded to blackjax here; the kernel is
            built from the parameters returned by window adaptation.
        inverse_mass_matrix : array, optional
            The inverse mass matrix for the sampler, defaults to None.
            Currently unused by this constructor.
        rng_key : array, optional
            The random number generator key, defaults to None, in which case
            a key is created from a NumPy-drawn random seed.
        warmup_steps : int, optional
            The number of warmup steps for the sampler, defaults to 100.
        """
        if rng_key is None:
            rng_key = jax.random.key(np.random.randint(2**32))

        if limits is None:
            limits = {}
        # Every sampled variable gets a limit; unspecified ones are unbounded.
        limit_dict = {
            key: (limits[key] if key in limits else 'infinite')
            for key in init_position
        }
        self.domain_changer = DomainChanger(limit_dict, backend='JAX')

        self.likelihood = likelihood
        self.step_size = step_size  # stored only; not passed to the kernel
        self.likelihood_func = self.domain_changer.logprob_wrapped(self.likelihood.logpdf)

        # Convert a shallow copy so the caller's ``init_position`` dict keeps
        # its original values regardless of how the conversion is done.
        self.init_position = self.domain_changer.transform(
            convert_to_jax_array(dict(init_position))
        )
        self.num_samples = None
        self.warmup_steps = warmup_steps

        # Set up the warmup for the HMC sampler.
        warmup = blackjax.window_adaptation(blackjax.nuts, self.likelihood_func)
        # Keep the carried key from the split, not the parent key: a key that
        # has been split must not be split again. The 3-way split is retained
        # so the warmup key for a given ``rng_key`` is the same as before.
        self.rng_key, warmup_key, _ = jax.random.split(rng_key, 3)
        (self._state_init, self.parameters), _ = warmup.run(
            warmup_key, self.init_position, num_steps=warmup_steps
        )
        # The kernel performs one step.
        self.kernel = blackjax.nuts(self.likelihood_func, **self.parameters).step
        self.states = []
        self.positions = None

    def step(self, rng_key, x):
        r"""
        Perform a single step of the NUTS sampler.

        Parameters
        ----------
        rng_key : array
            The random number generator key.
        x : state
            The current sampler state, passed straight to the kernel.

        Returns
        -------
        tuple
            Whatever the kernel returns; ``inference_loop`` unpacks it as a
            ``(new_state, info)`` pair.
        """
        return self.kernel(rng_key, x)

    def inference_loop(self, num_samples, sample_key=None):
        r"""
        Perform the inference loop to obtain samples from the NUTS sampler.

        Parameters
        ----------
        num_samples : int
            The number of samples to generate.
        sample_key : array, optional
            The random number generator key for sampling, defaults to None,
            in which case a key is split off ``self.rng_key`` (which is
            advanced in the process).

        Returns
        -------
        states
            The sampler states collected by ``jax.lax.scan``, one per step,
            starting from the post-warmup state. Also stored in
            ``self.states``.
        """
        if sample_key is None:
            self.rng_key, sample_key = jax.random.split(self.rng_key)
        print(f"Running the inference for {num_samples} samples")

        @jax.jit
        def one_step(state, rng_key):
            # Drop the kernel's info output; carry and emit the new state.
            state, _ = self.kernel(rng_key, state)
            return state, state

        # One independent key per sampling step.
        self.keys = jax.random.split(sample_key, num_samples)
        _, self.states = jax.lax.scan(one_step, self._state_init, self.keys)
        return self.states

    def run(self, num_samples=100):
        r"""
        Run the NUTS sampler to obtain samples.

        The sampling key is derived from the instance's own ``rng_key``, so
        an instance built with an explicit ``rng_key`` gives reproducible
        draws, and repeated calls use fresh keys.

        Parameters
        ----------
        num_samples : int, optional
            The number of samples to generate, defaults to 100.

        Returns
        -------
        result : dict
            A dictionary containing the sampled positions.
        """
        # ``inference_loop`` splits ``self.rng_key`` when no key is given;
        # do not reseed here, that would discard the user-supplied key.
        self.inference_loop(num_samples)
        positions = self.domain_changer.inverse_transform(self.states.position)
        # Publish both together, only after sampling succeeded, so ``result``
        # never sees a sample count without matching positions.
        self.positions = positions
        self.num_samples = num_samples
        return self.result

    @property
    def result(self):
        r"""
        Get the result of the sampler.

        Returns
        -------
        result : dict or None
            A dictionary mapping each variable name to a list with one entry
            per sample. Entries of shape ``(1,)`` are returned as Python
            floats, others as arrays. ``None`` if no run has completed.
        """
        if self.num_samples is None or self.positions is None:
            return None

        def replace_array_with_float(val):
            # Scalars were wrapped into size-1 arrays on input; unwrap them.
            if val.shape == (1,):
                return float(val[0])
            return val

        return {
            key: [replace_array_with_float(value[i]) for i in range(self.num_samples)]
            for key, value in self.positions.items()
        }