Source code for quicksampler.samplers.MHSampler

import numpy as np
from scipy.stats import norm
from tqdm import trange
from .DomainChanger import DomainChanger

import jax
import jax.numpy as jnp

[docs] def convert_to_jax_array(dictionary): r""" Convert selected values in a dictionary to JAX arrays. Parameters ---------- dictionary : dict The input dictionary containing values to be converted. Returns ------- dict A new dictionary with selected values converted to JAX arrays. """ for key, value in dictionary.items(): if isinstance(value, (float, int)): # Check if the value is a float or int dictionary[key] = jnp.array([value]) # Replace with a JAX array of size 1 return dictionary
[docs] def create_rng_key(backend): r""" Create a random number generator key based on the backend. Parameters ---------- backend : str The backend for computation ('numpy' or 'JAX'). Returns ------- array A random number generator key. """ if backend == 'numpy': return np.random.seed(np.random.randint(2**32)) elif backend == 'JAX': return jax.random.key(np.random.randint(2**32)) else: raise ValueError(f"Do not recognize the {backend} backend")
[docs] class AbstractProposal: r""" Abstract proposal for generating new states based on the current state. This class serves as a base for specific proposal distributions. It defines the common functionality and interface for proposing new values for variables in the state dictionary. Parameters ---------- step_size : float The step size for the proposal. backend : str, optional The backend for computation ('numpy' or 'JAX'), defaults to 'numpy'. rng_key : array, optional The random number generator key, defaults to None. Methods ------- __call__(x) Generate a new state based on the current state 'x'. """ def __init__(self, step_size, backend='numpy', rng_key=None): r""" Initialize an abstract proposal. Parameters ---------- step_size : float The step size for the proposal. backend : str, optional The backend for computation ('numpy' or 'JAX'), defaults to 'numpy'. rng_key : array, optional The random number generator key, defaults to None. """ self.step_size = step_size self.keys = None self.backend = backend self.rng_key = rng_key if rng_key is None: self.rng_key = create_rng_key(backend)
[docs] def __call__(self, x): if self.keys is None: self.keys = list(x.keys()) return {key: (value + self.next_step(key, value)) for key,value in x.items()}
[docs] class SphericalGaussianProposal(AbstractProposal): r""" Spherical Gaussian proposal for generating new states. This class defines a proposal distribution where the next step is sampled from a spherical Gaussian distribution with a specified step size. Parameters ---------- step_size : float The step size for the proposal. backend : str, optional The backend for computation ('numpy' or 'JAX'), defaults to 'numpy'. rng_key : array, optional The random number generator key, defaults to None. """ def __init__(self, step_size, backend='numpy', rng_key=None): super().__init__(step_size, backend=backend, rng_key=rng_key)
[docs] def next_step(self, key, value): r""" Compute the next proposed value for any variable inside the state dictionary x. Within the state dictionary x (e.g., x = {'mu': [0.1, 0.2], 'kappa': 0.3}), each state variable (the key) has a value that parameterizes its current state. This function will take in a value and compute the change in position needed to go to the next step in the proposal. e.g. >>> x = {'mu': np.array([0.1, 0.2]), 'kappa': 0.3} >>> SphericalGaussianProposal(step_size=0.1) >>> SphericalGaussianProposal.next_step('mu', x['mu']) # random normal proposal with the right shape. Parameters ---------- key : str The variable key. value : float, array-like The current value of the variable. Returns ------- array The next proposed value for the variable. """ if self.backend == 'numpy': if type(value) == np.ndarray: # Add a random gaussian of std step size to every entry of the matrix return np.random.randn(*value.shape)*self.step_size elif (type(value) == type(2.0)) or (type(value) == np.float64): # Add a random gaussian of std step size to the float value return np.random.randn()*self.step_size else: raise ValueError(f"Got value of type {type(value)} having value {value} no idea what to do with this.") elif self.backend == 'JAX': new_key, self.rng_key = jax.random.split(self.rng_key) if type(value) == type(jnp.array([0.0, 1.0])): # Add a random gaussian of std step size to every entry of the matrix return jax.lax.stop_gradient(jax.random.normal(new_key , value.shape))*self.step_size elif (type(value) == type(2.0)) or (type(value) == type(jnp.array([0.0, 10.0])[0])): # Add a random gaussian of std step size to the float value return jax.lax.stop_gradient(jax.random.normal(new_key))*self.step_size else: raise ValueError(f"Got value of type {type(value)} having value {value} no idea what to do with this. It isnt {type(jax.array([0.0, 10.0])[0])}") else: raise ValueError(f"Do not recognize the {self.backend} backend")
MAX_REJECTS_DEFAULT = 10000
[docs] class MHSampler: r""" Metropolis-Hastings sampler for generating samples from a distribution. This class implements the Metropolis-Hastings algorithm for generating samples from a target distribution specified by a likelihood function. It works with both 'numpy' and 'JAX' backends. Samplers will accept an object with a logpdf method like the following: Parameters ---------- likelihood : object The likelihood object representing the target distribution. init_position : dict The initial position in the state space. step_size : float, optional The step size for the Metropolis-Hastings proposal, defaults to 1. limits : dict, optional The limits for variables in the state space, defaults to None. rng_key : array, optional The random number generator key, defaults to None. backend : str, optional The backend for computation ('numpy' or 'JAX'), defaults to 'numpy'. """ def __init__(self, likelihood, init_position, step_size=1, limits=None, rng_key=None, backend='numpy'): r""" Initialize the Metropolis-Hastings sampler. Parameters ---------- likelihood : object The likelihood object representing the target distribution. init_position : dict The initial position in the state space. step_size : float, optional The step size for the Metropolis-Hastings proposal, defaults to 1. limits : dict, optional The limits for variables in the state space, defaults to None. rng_key : array, optional The random number generator key, defaults to None. backend : str, optional The backend for computation ('numpy' or 'JAX'), defaults to 'numpy'. """ self.likelihood = likelihood self.backend = backend self.rng_key = rng_key if rng_key is None: self.rng_key = create_rng_key(backend) if limits is None: self.domain_changer = DomainChanger({key : 'infinite' for key in init_position.keys()}, backend=self.backend) else: limit_dict = {} for key in init_position: if key in limits: limit_dict[key] = limits[key] else: limit_dict[key] = 'infinite' self.domain_changer = DomainChanger(limit_dict, backend=self.backend) self.init_position = init_position if self.backend == 'JAX': self.init_position = jax.lax.stop_gradient(convert_to_jax_array(self.init_position)) self.init_position_transformed = self.domain_changer.transform(self.init_position) self.likelihood_func = self.domain_changer.logprob_wrapped(self.likelihood.logpdf) self.step_size = step_size self.proposal = SphericalGaussianProposal(self.step_size, rng_key=self.rng_key, backend=self.backend) if rng_key is None: if self.backend == 'numpy': rng_key = np.random.seed(np.random.randint(2**32)) elif self.backend == 'JAX': rng_key = jax.random.key(np.random.randint(2**32)) else: raise ValueError(f"Do not recognize the {self.backend} backend") self.rng_key = rng_key self.history = [] self.max_rejects_default = 10000 self.running_acceptances = 0 self.total_steps = 0 self.num_samples = None
[docs] def accept_reject(self, p_accept): r""" Accept or reject a proposed state based on the acceptance probability. Parameters ---------- p_accept : float The probability of accepting the proposed state. Returns ------- bool True if the proposed state is accepted, False otherwise. """ if self.backend == 'numpy': return (np.random.rand() < p_accept) elif self.backend == 'JAX': return bool(jax.random.uniform(self.rng_key) < p_accept)
[docs] def step(self, x, max_rejects = MAX_REJECTS_DEFAULT): r""" Perform a single Metropolis-Hastings step to generate a new state. Parameters ---------- x : dict The current state. max_rejects : int, optional The maximum number of rejections before raising an error, defaults to 10000. Returns ------- dict The new proposed state. """ accept = False rejects = 0 while ((rejects < max_rejects) and (not accept)): x_proposal = self.proposal(x) log_likelihood_current = self.likelihood_func(x) log_likelihood_proposal = self.likelihood_func(x_proposal) p_current = log_likelihood_current #+ log_prior_current p_proposal = log_likelihood_proposal #+ log_prior_proposal if self.backend == 'JAX': p_accept = jnp.exp(p_proposal - p_current) elif self.backend == 'numpy': p_accept = np.exp(p_proposal - p_current) else: raise ValueError(f"Do not recognize the {self.backend} backend") accept = self.accept_reject(p_accept) self.total_steps += 1 if accept: self.running_acceptances += 1 return x_proposal else: rejects += 1 if rejects == max_rejects: raise ValueError("The next proposal has been rejected {max_rejects} times! try changing something")
[docs] def run(self, n_steps = 1000, max_rejects=MAX_REJECTS_DEFAULT): r""" Run the Metropolis-Hastings sampler for a specified number of steps. Parameters ---------- n_steps : int, optional The number of Metropolis-Hastings steps to run, defaults to 1000. max_rejects : int, optional The maximum number of rejections before raising an error, defaults to MAX_REJECTS_DEFAULT. Returns ------- dict The result of the sampler in a dictionary format. """ self.num_samples = n_steps y = self.init_position_transformed self.history.append(self.domain_changer.inverse_transform(self.init_position_transformed)) self.running_acceptances = 0 self.total_steps = 0 print(f"Getting {n_steps} using Metropolis Hastings") for t in trange(n_steps): y = self.step(y, max_rejects=max_rejects) self.history.append(self.domain_changer.inverse_transform(y)) acceptance_rate = self.running_acceptances/self.total_steps print(f"Sampling finished with an acceptance rate of {np.round(acceptance_rate*100, decimals=2)}") return self.result
@property def result(self): r""" Get the result of the sampler in a dictionary format. Returns ------- dict The result of the sampler in a dictionary format. """ if self.num_samples is None: return None def replace_array_with_float(val): if getattr(val, 'shape', None) == (1,): return float(val[0]) elif type(val) == float: return val else: return np.array(val) return {key: ([replace_array_with_float(self.history[i][key]) for i in range(self.num_samples)]) for key in self.history[0].keys()}