8 Computational Bayesian Inference: Probabilistic Programming with NumPyro
In the Bayesian Inference chapter, we calculated a posterior distribution by hand using a single data point: whether the one tested Chili’s store increased sales. In that original Chili’s story, we limited ourselves to just two possible models of success:
- The Pessimist Model: \(\theta = 20\%\)
- The Optimist Model: \(\theta = 70\%\)
We represented the Chili’s narrative using the generative DAG shown in Figure 8.1 and, with just those two candidate models, derived a posterior distribution by hand.
We now expand our analysis to consider infinitely many possible values of the success probability \(\theta\), where \(0 \leq \theta \leq 1\). This expansion can be expressed as an updated graphical model plus an updated statistical model. The updated graphical model is shown in Figure 8.2.
The updated statistical model is below.
\[ \begin{aligned} \Theta \equiv& \textrm{ Store Success Probability: } \\ X \equiv& \textrm{ Sales Increase: } \\ & \textrm{ If sales increase more than 5} \% \textrm{, then }X=1 \textrm{ otherwise, }X=0.\\ \theta \sim & \textrm{ Uniform}(0,1) \\ x \sim & \textrm{ Bernoulli}(\theta) \end{aligned} \]
In the above, \(\theta \sim \textrm{ Uniform}(0,1)\) becomes the new representation of our prior probability. This prior spreads equal plausibility among all potential values of \(\theta\); we are asserting that all probability values are equally likely. In the previous chapter, for any given \(\theta\) value, we were able to calculate prior and likelihood by hand, but now there are infinitely many \(\theta\) values.
Manually applying Bayes' rule to calculate the posterior distribution over the infinitely many possible values of \(\theta\) is impossible; it would simply take too long. Instead, we will rely on state-of-the-art computational techniques and software to deliver a representative sample from the posterior distribution. In fact, the software we will use allows tremendous flexibility, accommodating seemingly any probabilistic graphical model built from the most commonly used probability distributions. In other words, we as analysts supply our model, prior, and some data, and the computer spits out the posterior, using Bayes' rule behind the scenes to properly update our prior assumptions to place more plausibility on model parameters that are consistent with the observed data.
8.0.1 Bayesian Updating With NumPyro
NumPyro is an offshoot of Pyro, a universal probabilistic programming language (PPL) written in Python. We use NumPyro, which uses JAX instead of PyTorch as its backend, to gain up to a 100x speedup over Pyro for the types of problems we will work on. You can read more about the origins of Pyro at https://www.uber.com/blog/pyro/.
To get the computer to do our work, we need to translate our data and generative DAG (i.e. a graphical model + statistical model with enough detail to simulate data) into a computational model (i.e. a computer-world model that calculates posterior distributions). We will learn to do this using the fantastic numpyro package.
The following probability distributions (and more) are classes in the distributions sub-module of numpyro: Uniform, Normal, LogNormal, Bernoulli, Binomial, BetaBinomial, NegativeBinomial, Poisson, Gamma, InverseGamma, Weibull, Exponential, Pareto, StudentT, Laplace, Beta, Cauchy, Chi2, Logistic, MultivariateNormal, MultivariateStudentT, Multinomial, Categorical, Dirichlet, DirichletMultinomial, LKJ.
Many of these will be introduced in subsequent chapters, and all others can be found on Wikipedia. Do not let the fancy names scare you; they are just ways of compactly representing uncertainty.
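For example, each of these distributions is a class that can be instantiated and sampled from directly. A minimal sketch (betaDist and draws are made-up names for illustration):

import numpyro.distributions as dist
from jax import random

## instantiate a Beta distribution object and draw five values from it
betaDist = dist.Beta(concentration1=2, concentration0=2)
draws = betaDist.sample(random.PRNGKey(0), sample_shape=(5,))
print(draws)  ## five plausible probability values between 0 and 1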
Step 1 - Create a Python function that tells numpyro how to simulate all the nodes of your graphical model.
When reproducing a graphical and statistical model in numpyro code, the entire model gets defined as a function whose sole argument (for now) is the observed data. Shown below is the coded representation of Figure 8.2:
import numpy as np
import numpyro
import numpyro.distributions as dist

## define the graphical/statistical model as a Python function
## pass data and the cardinality of plates as inputs
## N represents the number of Chili's stores for which data is observed
def chilisModel(x):
    # numpyro.sample is a "primitive", i.e. a basic building block of a model
    theta = numpyro.sample('theta', dist.Uniform(low = 0, high = 1))
    # numpyro.plate is another primitive
    with numpyro.plate('N', len(x)):
        x = numpyro.sample('x', dist.Bernoulli(probs = theta), obs=x)
Two functions, called primitives, serve as building blocks in the above function definition. They are:
- numpyro.sample: This function takes two positional arguments and, optionally, one keyword argument. The first argument is simply the name of the node; I will use the mathematical shorthand name shown in the graphical model (e.g. theta or x). The second argument is the probability distribution that serves as either prior (for unobserved nodes) or likelihood (for observed nodes). The keyword argument obs is used to supply observed data when applicable, e.g. the sales outcomes of Chili's stores.
- numpyro.plate: This function, typically used in the context of a with block, takes two positional arguments. The first argument is simply the name of the plate; I will use the name of the plate shown in the graphical model (e.g. N or observation). The second argument is the cardinality of the plate, which counts how many individual realizations of each node are signified by their existence on the plate. In the above, the Sales Increase node is repeated once for every store at which we observe data.
When replicating a graphical and statistical model in code, all nodes that exist on a plate should be sampled within the context of that plate, as signaled by indentation (just as in for loops). Additionally, all observed data should be passed into the function as arguments; above, we pass in x as the lone argument.
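As a quick way to confirm the function really does describe a simulation, we can run it forward outside of inference. Below is a minimal sketch (not part of the required workflow) using numpyro's seed and trace handlers; exec_trace is a made-up variable name:

from numpyro.handlers import seed, trace

## run the model forward once; seed supplies the randomness that
## numpyro.sample needs outside of inference, trace records each node
exec_trace = trace(seed(chilisModel, rng_seed=0)).get_trace(np.array([1, 1, 0]))
print(exec_trace['theta']['value'])  ## one draw of theta from its Uniform prior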
Step 2 - Prepare Your Data.
Data should be passed in as either a numpy array or a jax.numpy array. Below, we make up some data (1,1,0) to represent the first two stores successfully increasing sales and the third observed store failing to increase sales.
## define the required inputs for chilisModel
## assume first two stores are a success and
## the third store is not (i.e. make up some data)
salesIncData = np.array([1,1,0])
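Equivalently, the same data could be supplied as a jax.numpy array; either form works:

import jax.numpy as jnp
## a jax.numpy array works just as well as a numpy array
salesIncData = jnp.array([1,1,0])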
Step 3 - Use NumPyro to Get a Representative Sample of the Posterior Distribution.
For now, I request that when you modify the below code for other problems, you only modify the model name, i.e. chilisModel, and the arguments passed to mcmc.run. We will learn more details about this code in subsequent chapters. For now, let's just digest the code at a high level: its purpose is to get a representative sample of unobserved parameters from the posterior distribution.
from jax import random
from numpyro.infer import MCMC, NUTS

## computationally get posterior distribution
## in the below line, only change "chilisModel" for different problems
mcmc = MCMC(NUTS(chilisModel), num_warmup=500, num_samples=4000)
rng_key = random.PRNGKey(seed = 111) ## so you and I get same results
## supply the run method with a random key followed by arguments to model
mcmc.run(rng_key, x=salesIncData) ## get representative sample of posterior
The essence of the above code is that it creates an object called mcmc. A fancy algorithm called NUTS (No U-Turn Sampler) is set up to create a representative sample of the posterior when passed some data. Data gets passed through the object's .run() method. The first argument to .run() is always a random key, and any additional arguments should pass data to the arguments of the NumPyro model.
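For example, if a model observed two separate data arrays, the keyword arguments to .run() would need to match the model function's argument names. A hypothetical sketch (twoDataModel, mcmc2, and the tiny data arrays are made up for illustration):

## hypothetical model observing two data arrays, x and y,
## that share the same success probability theta
def twoDataModel(x, y):
    theta = numpyro.sample('theta', dist.Uniform(low = 0, high = 1))
    with numpyro.plate('N', len(x)):
        numpyro.sample('x', dist.Bernoulli(probs = theta), obs=x)
    with numpyro.plate('M', len(y)):
        numpyro.sample('y', dist.Bernoulli(probs = theta), obs=y)

mcmc2 = MCMC(NUTS(twoDataModel), num_warmup=500, num_samples=4000)
mcmc2.run(rng_key, x=np.array([1, 0]), y=np.array([1, 1, 0]))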
After running mcmc.run(rng_key, x=salesIncData), the mcmc object has successfully created a representative sample from the posterior. This sample is often referred to as the posterior distribution. To access the posterior distribution, we will use functionality from the arviz package to extract an xarray dataset containing the information we need, as shown below:
import arviz as az
## get samples into xarray
drawsDS = az.from_numpyro(mcmc).posterior
drawsDS
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 4000)
Coordinates:
  * chain    (chain) int32 0
  * draw     (draw) int32 0 1 2 3 4 5 6 7 ... 3993 3994 3995 3996 3997 3998 3999
Data variables:
    theta    (chain, draw) float32 0.1519 0.1514 0.7819 ... 0.4454 0.5435 0.4545
Attributes:
    created_at:                  2023-05-02T16:28:05.101292
    arviz_version:               0.13.0
    inference_library:           numpyro
    inference_library_version:   0.10.0
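If you ever need the raw draws outside of xarray, they can be extracted as a plain array. One approach among several (thetaDraws is a made-up name):

## extract the sampled theta values as a flat numpy array
thetaDraws = drawsDS.theta.to_numpy().flatten()
print(thetaDraws.shape)  ## (4000,): one value per posterior draw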
Step 4 - Use the posterior distribution for insight and for making probabilistic statements.
A quick and easy way to visualize the posterior distribution for theta is to use the arviz.plot_dist() function:
az.plot_dist(drawsDS.theta)
However, if we want to do any customization of an arviz plot, we must revert back to our typical matplotlib workflow and use arviz to modify an Axes object:
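A minimal sketch of that workflow, mirroring the plotting pattern used later in this chapter, might look like this:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 3.5), layout='constrained')
## pass the Axes object to arviz so we control the figure ourselves
az.plot_dist(drawsDS.theta, ax = ax, color = "darkorchid")
ax.set_xticks([0, .25, .5, .75, 1])
ax.set_xlabel(r'$\theta$')
ax.set_ylabel('Plausibility Measure: ' + r'$f_\Theta(\theta)$')
plt.show()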
As can be seen from Figure 8.4, there is more plausibility to the right of 0.5. Hence, one might wonder “what is the probability that more than half the stores receiving a remodel will see a successful increase in sales?” To answer this, just like any probabilistic query, we use an indicator function and the fundamental bridge to find out:
## use indicator function to make probabilistic statements
## for example, find P(theta > 0.5)
(
    drawsDS
    .assign(thetaOver50 = drawsDS.theta > 0.50)
    .mean() # about 70% probability theta is over 50%
).to_pandas()
theta 0.599126
thetaOver50 0.699000
dtype: float64
Hence, we can say \(P(\theta>50\%) \approx 69\%\). You might also notice another result, namely \(\mathbb{E}[\Theta] \approx 60\%\). We will dive deeper into extracting insights in subsequent chapters. For now, the big takeaway is that data can be used to update our knowledge. Initially, we had complete uncertainty as to plausible values of \(\Theta\), but after observing just three stores, we refine our beliefs to shift plausibility away from extremely low values, which are inconsistent with the observed successes in store sales.
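The same indicator-function pattern handles any event of interest. For example, here is a sketch estimating \(P(0.4 < \theta < 0.8)\) (thetaInRange is a made-up column name):

## the fundamental bridge: the probability of an event
## is the expected value of that event's indicator function
(
    drawsDS
    .assign(thetaInRange = (drawsDS.theta > 0.40) & (drawsDS.theta < 0.80))
    .mean()
).to_pandas()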
8.1 Using a beta Prior in NumPyro
Let’s now use a \(\textrm{beta}\) distribution prior to reinforce how NumPyro works, but more importantly, to demonstrate how priors and data observations come together to make a posterior distribution.
The updated statistical model is below.
\[
\begin{aligned}
\Theta \equiv& \textrm{ Store Success Probability: } \\
X \equiv& \textrm{ Sales Increase: } \\
& \textrm{ If sales increase more than 5} \% \textrm{, then }X=1 \textrm{ otherwise, }X=0.\\
\theta \sim & \textrm{ Beta}(2,2) \\
x \sim & \textrm{ Bernoulli}(\theta)
\end{aligned}
\]
Figure 8.5 is a DAG declaring a \(\textrm{beta}(2,2)\) prior as representative of our uncertainty in \(\theta\). We aren’t saying too much with this prior. This is a weak prior because, as we will see, it will be easily overwhelmed by data; it’s just saying that two successes and two failures have been seen. Let’s then imagine processing new data where we observe 20 successes and only two failures. This data is highly consistent with a very large value for \(\theta\). We can intelligently combine prior and data using numpyro to get our posterior:
import numpy as np
import numpyro
import numpyro.distributions as dist
import arviz as az
from jax import random
from numpyro.infer import MCMC, NUTS

## define the data - 20 successes and 2 failures
successData = np.concatenate((np.repeat(1, 20), np.repeat(0, 2)))

## define the graphical/statistical model as a Python function
def betaBernoulliModel(x):
    # concentration1: 1st concentration parameter (alpha) for the Beta dist.
    # think: the higher alpha, the more theta values concentrate towards 1
    # concentration0: 2nd concentration parameter (beta) for the Beta dist.
    # think: the higher beta, the more theta values concentrate towards 0
    theta = numpyro.sample('theta', dist.Beta(concentration1=2,  # alpha
                                              concentration0=2)) # beta
    with numpyro.plate('observation', len(x)):
        x = numpyro.sample('x', dist.Bernoulli(probs = theta), obs=x)

## computationally get posterior distribution
mcmc = MCMC(NUTS(betaBernoulliModel), num_warmup=500, num_samples=4000)
rng_key = random.PRNGKey(seed = 111) ## so you and I get same results
mcmc.run(rng_key, x=successData) ## get representative sample of posterior
drawsDS = az.from_numpyro(mcmc).posterior ## get samples into xarray
And once the posterior sample (drawsDS) is created, we can query the results and compare them to the \(\textrm{beta}(2,2)\) prior:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from numpy.random import default_rng

fig, ax = plt.subplots(figsize=(6, 3.5),
                       layout='constrained')

# plot density estimate of the posterior, i.e. estimate of f(theta)
az.plot_dist(drawsDS.theta, ax = ax, color = "darkorchid",
             plot_kwargs = {"zorder": 1, "linewidth": 4, "alpha": 0.5},
             fill_kwargs = {"alpha": 0.5},
             quantiles = [.10, .50, .90])

# plot prior from a representative sample
beta2_2_repSample = default_rng(seed=111).beta(2, 2, 50000)
az.plot_dist(beta2_2_repSample, ax = ax, color = "cadetblue",
             plot_kwargs = {"zorder": 1, "linewidth": 4, "alpha": 0.5},
             fill_kwargs = {"alpha": 0.5},
             quantiles = [.10, .50, .90])

ax.set_xticks([0, .25, .5, .75, 1])
ax.set_ylabel('Plausibility Measure: ' + r'$f_\Theta(\theta)$')
ax.set_xlabel(r'$\theta$')

custom_lines = [Line2D([0], [0], color = "darkorchid", lw=4, alpha = 0.5),
                Line2D([0], [0], color = "cadetblue", lw=4, alpha = 0.5)]
ax.legend(custom_lines, ['Posterior', 'Beta(2,2) Prior'], loc='upper left')
plt.show()
Figure 8.6 shows a dramatic shift from prior to posterior distribution. The weak prior suggests all values have some plausibility, but after observing 20 successes out of 22 trials, the higher values for \(\theta\) become much more plausible.
If we want to change the prior to something stronger, say a \(\textrm{beta}(50,50)\), then we can rerun the numpyro code, changing just the one line for the prior.
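For instance, here is a sketch of that change (betaBernoulliModelStrong and drawsStrongDS are made-up names; everything else reuses the setup above):

## same model as before, but with a stronger Beta(50,50) prior;
## this is like having already seen 50 successes and 50 failures
def betaBernoulliModelStrong(x):
    theta = numpyro.sample('theta', dist.Beta(concentration1=50,  # alpha
                                              concentration0=50)) # beta
    with numpyro.plate('observation', len(x)):
        x = numpyro.sample('x', dist.Bernoulli(probs = theta), obs=x)

mcmc = MCMC(NUTS(betaBernoulliModelStrong), num_warmup=500, num_samples=4000)
mcmc.run(random.PRNGKey(seed = 111), x=successData)
drawsStrongDS = az.from_numpyro(mcmc).posterior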
Figure 8.7 shows a posterior distribution that is only mildly shifted from its prior. This is a direct result of a strong prior due to the larger \(\alpha\) and \(\beta\) parameters. In general, we will not use strong priors; instead, we seek weakly informative priors that yield plausible prior generating processes, yet are flexible enough to let the data inform the posterior generating process. There is a bit of an art to this, and we will learn more in subsequent chapters.
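As a side check (not part of the chapter's MCMC workflow), this particular model happens to have a known closed-form answer: a \(\textrm{beta}(\alpha,\beta)\) prior updated with \(s\) successes and \(f\) failures yields a \(\textrm{beta}(\alpha+s, \beta+f)\) posterior. A few lines of arithmetic show why the strong prior barely moves:

## closed-form posterior means for a Beta prior with Bernoulli data;
## MCMC is still needed for models where no such closed form exists
s, f = 20, 2  # observed successes and failures

for alpha, beta in [(2, 2), (50, 50)]:
    prior_mean = alpha / (alpha + beta)
    post_mean = (alpha + s) / (alpha + beta + s + f)
    print(f"Beta({alpha},{beta}): prior mean {prior_mean:.3f}"
          f" -> posterior mean {post_mean:.3f}")
## Beta(2,2):   prior mean 0.500 -> posterior mean 0.846
## Beta(50,50): prior mean 0.500 -> posterior mean 0.574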
8.2 Getting Help
See the “Getting Started” section of the numpyro documentation for more details about coding in numpyro. A link to that section is here: https://num.pyro.ai/en/stable/getting_started.html.
8.3 Questions to Learn From
See CANVAS.