Simple Bayesian Inference Examples

import jax.numpy as jnp
from jax import random
from jax.scipy.special import logsumexp
from jax import grad
import pandas as pd

import numpy as np
import matplotlib.pyplot as plt

Linear Regression

Let's set up a linear regression problem and draw samples from its likelihood function. First we generate the data:

key = random.PRNGKey(0)

m = 0.1
c = 0
X = jnp.linspace(0,10,100)
noise_amount = 0.1
eps = noise_amount*random.normal(key, X.shape)

Y = m*X + c + eps

plt.scatter(X, Y, s=5)  # plot the noisy observations
plt.title("Y = mX + c + noise")

Linear Regressor Model

Here we create the model. All we need is a class with a logpdf method that takes a dictionary of parameter values and returns the log-probability density of the data at those values.

class LinearRegressor:
    def __init__(self, X, Y, noise=1):
        self.X = X
        self.Y = Y
        self.noise = noise

    def logpdf(self, x):
        Z = (self.Y - x['m']*self.X - x['c'])/self.noise
        return jnp.sum(-(Z**2))/2
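
NUTS is a gradient-based sampler, so the log-density has to be differentiable with respect to every entry of the parameter dictionary. The following is a minimal sketch, assuming the gradient is obtained with something like jax.grad (how quicksampler does this internally is not shown in this notebook). It restates the logpdf above as a standalone function on regenerated data, then evaluates the value and the gradient at one point. The names `position`, `value` and `gradient` are illustrative only.

```python
import jax
import jax.numpy as jnp

# Regenerate the synthetic data so this block runs on its own.
key = jax.random.PRNGKey(0)
X = jnp.linspace(0, 10, 100)
noise_amount = 0.1
Y = 0.1 * X + 0.0 + noise_amount * jax.random.normal(key, X.shape)

def logpdf(x):
    """Unnormalised Gaussian log-likelihood, same formula as LinearRegressor.logpdf."""
    Z = (Y - x['m'] * X - x['c']) / noise_amount
    return -jnp.sum(Z**2) / 2

# Hypothetical evaluation point, given as a dictionary of parameters.
position = {'m': 0.0, 'c': 1.0}

value = logpdf(position)               # scalar log-density
gradient = jax.grad(logpdf)(position)  # dictionary with the same keys as `position`

print(value, gradient)
```

Because jax.grad preserves the structure of its input, the gradient comes back as a dictionary keyed by 'm' and 'c', which matches the dictionary convention used for the model.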

Instantiate the model

LR = LinearRegressor(X,Y,noise=noise_amount)  # Instantiate the model

Set the starting point of the NUTS sampler as a dictionary of parameter values, then initialize the sampler.

from quicksampler import NUTS
import jax

initial_position = {'m':0.0, 'c':1.0} # starting point of the NUTS sampler
problem = NUTS(LR, initial_position)

Run the sampler

result =
Running the inference for 1000 samples

df = pd.DataFrame(result)

            m         c
0    0.103061 -0.013709
1    0.103323 -0.011630
2    0.105348 -0.018875
3    0.105896 -0.039735
4    0.108621 -0.031321
...       ...       ...
995  0.104327 -0.025296
996  0.105400 -0.005010
997  0.105913 -0.027444
998  0.099664  0.005911
999  0.100274  0.014101

1000 rows × 2 columns

The chain is now stored in a DataFrame with one column per parameter. We can now create a trace plot and a histogram of each parameter.
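
One possible way to draw both plots with pandas and matplotlib is sketched below. To keep the block runnable on its own, `df` is filled with placeholder random draws rather than the sampler output shown above. With the real chain, the DataFrame built from `result` would be used instead. The figure layout (one row per parameter, trace on the left, histogram on the right) is a choice made for this sketch.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Placeholder chain so the block runs on its own: random draws, NOT sampler output.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    'm': rng.normal(0.1, 0.003, size=1000),
    'c': rng.normal(0.0, 0.02, size=1000),
})

# One row per parameter: trace plot on the left, histogram on the right.
fig, axes = plt.subplots(len(df.columns), 2, figsize=(10, 6), squeeze=False)
for row, name in zip(axes, df.columns):
    row[0].plot(df.index, df[name], lw=0.5)
    row[0].set_xlabel("iteration")
    row[0].set_ylabel(name)
    row[1].hist(df[name], bins=30)
    row[1].set_xlabel(name)
fig.tight_layout()
```

A trace that wanders slowly or drifts over the iterations usually suggests the chain has not yet mixed well, while the histogram gives a quick view of the spread of each parameter.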


Inferring coin flip bias with the Metropolis-Hastings sampler and NumPy

Same thing again: let's generate some data. We'll do 100 experiments of 10 flips each, using a biased coin that lands heads with probability \(p\).

import jax
import jax.numpy as jnp
import scipy.stats  # import the submodule explicitly so that scipy.stats.binom resolves

N_experiments = 100
N_flips = 10
p = 0.7

rv = scipy.stats.binom(N_flips, p)
head_counts = rv.rvs(N_experiments)

Now let's define our coin flip model:

class CoinFlip:
    def __init__(self, N_flips, number_of_heads):
        self.head_counts = number_of_heads
        self.N_flips = N_flips

    def logpdf(self, x):
        p = x['p']
        return jnp.sum( jax.scipy.stats.binom.logpmf(self.head_counts, self.N_flips, p) )
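
For reference, the binomial log-pmf that logpdf sums over the experiments is log C(n, k) + k log p + (n - k) log(1 - p). The sketch below writes that formula out by hand in JAX using gammaln for the binomial coefficient, and differentiates the summed log-likelihood with respect to p. The helper `binom_logpmf` and the three head counts are illustrative assumptions. They are not part of the notebook's data or of quicksampler.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

def binom_logpmf(k, n, p):
    """Hand-written binomial log-pmf: log C(n, k) + k*log(p) + (n - k)*log(1 - p)."""
    log_binom_coeff = gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0)
    return log_binom_coeff + k * jnp.log(p) + (n - k) * jnp.log1p(-p)

# Illustrative head counts for three experiments of 10 flips each (not real data).
example_counts = jnp.array([7.0, 6.0, 8.0])
n_flips = 10.0

def log_likelihood(p):
    # Sum over independent experiments, as in CoinFlip.logpdf.
    return jnp.sum(binom_logpmf(example_counts, n_flips, p))

# Derivative of the log-likelihood with respect to p at two points.
score_at_half = jax.grad(log_likelihood)(0.5)
score_at_mle = jax.grad(log_likelihood)(0.7)

print(score_at_half, score_at_mle)
```

The derivative vanishes at p = 21/30 = 0.7, the fraction of heads in the illustrative counts, which is where the likelihood peaks.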

from quicksampler import NUTS, MHSampler

CF = CoinFlip(10, head_counts)

initial_position = {'p':0.1} # starting point of the Metropolis-Hastings sampler
problem2 = MHSampler(CF, initial_position, limits={'p': [0,1]}, step_size=0.1, backend='numpy');
result2 =;
Getting 1000 using Metropolis Hastings
100%|██████████████████████████████████████| 1000/1000 [00:03<00:00, 293.00it/s]
Sampling finished with an acceptance rate of 60.79
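
To make the sampler's behaviour concrete, here is a minimal random-walk Metropolis sketch in plain NumPy for the same coin flip problem. It is not quicksampler's implementation. The Gaussian proposal, the interpretation of `step_size` as the proposal standard deviation, and the rejection of proposals outside the limits are all assumptions made for this illustration, and the function name `metropolis` is hypothetical.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# Regenerate coin flip data so the block runs on its own.
N_experiments, N_flips, p_true = 100, 10, 0.7
head_counts = rng.binomial(N_flips, p_true, size=N_experiments)

def log_likelihood(p):
    # Sum of binomial log-pmfs over all experiments.
    return stats.binom.logpmf(head_counts, N_flips, p).sum()

def metropolis(log_target, start, n_samples, step_size, lower, upper):
    """Random-walk Metropolis with a symmetric Gaussian proposal."""
    chain = np.empty(n_samples)
    current, current_logp = start, log_target(start)
    accepted = 0
    for i in range(n_samples):
        proposal = current + step_size * rng.normal()
        # Proposals outside the limits are rejected outright.
        if lower < proposal < upper:
            proposal_logp = log_target(proposal)
            # Accept with probability min(1, target(proposal) / target(current)).
            if np.log(rng.uniform()) < proposal_logp - current_logp:
                current, current_logp = proposal, proposal_logp
                accepted += 1
        chain[i] = current
    return chain, accepted / n_samples

chain, acceptance_rate = metropolis(log_likelihood, 0.1, 2000, 0.1, 0.0, 1.0)
posterior_mean = chain[500:].mean()  # discard the first 500 draws as burn-in

print(posterior_mean, acceptance_rate)
```

Because the proposal is symmetric, the acceptance test only needs the difference of log-densities. The acceptance rate of this sketch need not match the rates reported by quicksampler, since the proposal details may differ.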

import pandas as pd
import matplotlib.pyplot as plt
result = pd.DataFrame(result2)

plt.axvline(p, c='r')  # mark the true value of p with a red vertical line

Inferring coin flip bias with the Metropolis-Hastings sampler and JAX

CF = CoinFlip(10, head_counts)

initial_position = {'p':0.1} # starting point of the Metropolis-Hastings sampler
problem2 = MHSampler(CF, initial_position, limits={'p': [0,1]}, step_size=0.1, backend='JAX');
result2 =;
Getting 1000 using Metropolis Hastings
100%|██████████████████████████████████████| 1000/1000 [00:03<00:00, 285.28it/s]
Sampling finished with an acceptance rate of 72.15

import pandas as pd
import matplotlib.pyplot as plt
result = pd.DataFrame(result2)

plt.axvline(p, c='r')  # mark the true value of p with a red vertical line

Inferring coin flip bias with the NUTS sampler

CF = CoinFlip(10, head_counts)

initial_position = {'p':0.1} # starting point of the NUTS sampler
problem2 = NUTS(CF, initial_position, limits={'p': [0,1], 'mu': [0.0, 1.0]}, step_size=0.1);
result2 =;
Running the inference for 1000 samples

import pandas as pd
import matplotlib.pyplot as plt
result = pd.DataFrame(result2)

plt.axvline(p, c='r')  # mark the true value of p with a red vertical line