# HIDDEN
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 10})
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import scipy.stats as sstats
import ipywidgets as wd
# A Typical Linear Problem
Consider the linear problem
$$ q = Ax, \ x\in\mathbb{R}^n, \ q\in \mathbb{R}^p, \ A\in\mathbb{R}^{p\times n} $$
where we assume that $p\leq n$ and $A$ is rank $p$.
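For concreteness, here is a minimal sketch of such a problem with $p=1$ and $n=2$ (the same map used in the example below; the candidate $x$ is illustrative):

```python
import numpy as np

A = np.array([[2.0, -1.0]])       # A in R^{1 x 2}
x = np.array([[0.2], [0.2]])      # a candidate parameter vector in R^2
q = np.dot(A, x)                  # the corresponding datum in R^1

print(np.linalg.matrix_rank(A))   # 1, so A has full row rank p
print(q)                          # [[0.2]]
```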
## Statistical Bayesian connection to Tikhonov regularization
Assuming we observe a datum $\tilde{q}$ and that we use a Gaussian prior ($N(\bar{x},C_x)$) and a Gaussian noise model ($N(0,C_q)$), the statistical Bayesian posterior is given by
$$ \pi^{\text{post}} \propto \exp\left(-\frac{1}{2}\left( \underbrace{\left|\left|C_q^{-1/2}(q-\tilde{q})\right|\right|_2^2}_{\text{Data mismatch}} + \underbrace{\left|\left|C_x^{-1/2}(x-\bar{x})\right|\right|_2^2}_{\text{Tikhonov regularization}} \right)\right) $$
where the braces make explicit the connection between the MAP (maximum a posteriori) point of the posterior density and the Tikhonov-regularized solution of a deterministic optimization problem.
See https://en.wikipedia.org/wiki/Tikhonov_regularization for more information.
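In this linear-Gaussian setting the posterior is itself Gaussian, and its MAP point has a closed form (a standard result, stated here for reference):
$$ x_{\text{MAP}} = \bar{x} + C_x A^\top \left(A C_x A^\top + C_q\right)^{-1}\left(\tilde{q} - A\bar{x}\right), $$
which is precisely the minimizer of the Tikhonov-regularized objective in the exponent above.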
## Take-aways
- The model defines the data mismatch and the prior defines the regularization.
- The regularization impacts all directions of the posterior since we effectively balance the data mismatch with our prior beliefs. This implies that the "solution" defined by a MAP point is not necessarily a point that produces the observed datum.
## Example
Consider the linear problem where
$$ A = [2 \ -1], \quad \bar{x}=\left[\begin{array}{c} 0.2 \\ 0.2 \end{array}\right], \quad C_x = \text{diag}(0.5, 0.25), \quad \tilde{q} = [0.1], \quad C_q = [0.25]. $$
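As a quick check of the second take-away, here is a minimal non-interactive sketch (assuming only numpy and the closed-form MAP expression above) that computes the MAP point for these values; its predicted datum is roughly 0.11, not the observed 0.1:

```python
import numpy as np

A = np.array([[2.0, -1.0]])
x_bar = np.array([[0.2], [0.2]])   # prior mean
C_x = np.diag([0.5, 0.25])         # prior covariance
C_q = np.array([[0.25]])           # data covariance
q_obs = np.array([[0.1]])          # observed datum

# Closed-form MAP point for the linear-Gaussian problem
K = C_x @ A.T @ np.linalg.inv(A @ C_x @ A.T + C_q)
x_map = x_bar + K @ (q_obs - A @ x_bar)

print(x_map.flatten())      # [0.16 0.21]
print((A @ x_map).item())   # ~0.11: the MAP point does not reproduce the datum
```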
## Things to play with
- Try changing `x_prior` in the code to something other than $[0.2 \ 0.2]^\top$ to make the prior guess either better or worse. What happens?
- Try playing with the `C_x` covariance to give the prior guess either more confidence (reduce the components) or less confidence (increase the components). What happens? (A non-interactive sketch of this experiment follows below.)
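A non-interactive sketch of the second experiment, using the closed form above: shrinking $C_x$ (more confidence in the prior) pulls the MAP point toward $\bar{x}$, so the MAP prediction drifts away from the observed datum 0.1 toward the prior prediction 0.2.

```python
import numpy as np

A = np.array([[2.0, -1.0]])
x_bar = np.array([[0.2], [0.2]])
q_obs = np.array([[0.1]])
C_q = np.array([[0.25]])

for scale in [1.0, 0.1, 0.01]:
    C_x = scale * np.diag([0.5, 0.25])  # smaller scale = more prior confidence
    K = C_x @ A.T @ np.linalg.inv(A @ C_x @ A.T + C_q)
    x_map = x_bar + K @ (q_obs - A @ x_bar)
    print(scale, x_map.flatten(), (A @ x_map).item())
```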
# HIDDEN
def solve(n=101, data_cov_const=0.25, prior_x1=0.2, prior_x2=0.2, sigma_x1=0.5, sigma_x2=0.5, tinker=False):
    # Set up the example: the linear map, the prior guess, the prior
    # prediction, and the actual datum (adapted from Troy's code)
    a = 2
    b = -1
    A = np.array([[a, b]])  # the linear map
    x_prior = np.array([prior_x1, prior_x2]).reshape(-1, 1)  # prior guess of the mean
    q_obs = np.array([0.1])  # actual datum (leave fixed)
    q_prior = np.dot(A, x_prior)  # predicted datum using the prior mean
    print('Prior mean (x1, x2) =', x_prior.flatten(), 'maps to q =', q_prior.flatten()[0])
    # Note: `tinker` is toggled by the widget below, but the objective terms
    # are always needed for the plots, so they are defined unconditionally.
    def data_misfit(x):
        # Weighted sum-squared error between predicted and observed data
        C_q_inv = np.linalg.inv(C_q)
        res = np.dot(A, x) - q_obs
        return np.vdot(np.dot(C_q_inv, res), res)

    def Tikhonov_reg(x):
        # Weighted sum-squared deviation of x from the prior mean
        C_x_inv = np.linalg.inv(C_x)
        res = x - x_prior
        return np.vdot(np.dot(C_x_inv, res), res)

    def unregularize(x):
        # Weighted sum-squared deviation of the predicted datum from the
        # prior prediction, weighted by the "covariance of the map" C_A
        C_A_inv = np.linalg.inv(C_A)
        res = np.dot(A, x) - q_prior
        return np.vdot(np.dot(C_A_inv, res), res)
    # Set up all the covariances
    prior_cov = [sigma_x1, sigma_x2]
    C_x = np.diag(prior_cov)  # prior covariance
    C_q = np.diag([data_cov_const])  # data covariance
    C_A = np.dot(np.dot(A, C_x), A.transpose())  # the "covariance of the map"
    # Discretize a portion of the input space R^2
    x1 = np.linspace(-0.5, 0.5, n)
    x2 = x1
    x1, x2 = np.meshgrid(x1, x2)
    # Compute all the weighted sum-squared error terms on the grid
    WSSE = np.zeros((n, n))
    TSSE = np.zeros((n, n))
    USSE = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            x_ij = np.array([[x1[j, i], x2[j, i]]]).transpose()
            WSSE[j, i] = data_misfit(x_ij)
            TSSE[j, i] = Tikhonov_reg(x_ij)
            USSE[j, i] = unregularize(x_ij)
    # Flat grid indices of the minimizers of the two combined objectives
    x_reg_ind = np.argmin(WSSE + TSSE)
    x_unreg_ind = np.argmin(WSSE + TSSE - USSE)
    print('Absolute error in prediction through Tikhonov:',
          np.abs(q_obs[0] - np.dot(A, [x1.flatten()[x_reg_ind], x2.flatten()[x_reg_ind]])[0]))
    print('Absolute error in prediction through CB:',
          np.abs(q_obs[0] - np.dot(A, [x1.flatten()[x_unreg_ind], x2.flatten()[x_unreg_ind]])[0]))
    f, axarr = plt.subplots(2, 3, figsize=[20, 20])
    # Top-left panel: the prior density and its mean
    i, j = 0, 0
    ax = axarr[i, j]
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    ax.set_title('Prior')
    ax.set_aspect('equal')
    Z = sstats.multivariate_normal.pdf(
        np.concatenate([x1.reshape(-1, 1), x2.reshape(-1, 1)], axis=1),
        mean=x_prior.flatten(), cov=np.diag(prior_cov))
    ax.pcolormesh(x1, x2, Z.reshape(n, n), cmap=cm.hot, linewidth=0, antialiased=False)
    ax.scatter([x_prior[0]], [x_prior[1]], s=250, facecolor='red')
    # Remaining top panels: the individual objective terms
    Mv = ['WSSE', 'TSSE']
    for M in [WSSE, TSSE]:
        ax = axarr[i, j + 1]
        ax.set_aspect('equal')
        ax.set_xlabel('$x_1$')
        ax.set_ylabel('$x_2$')
        ax.set_title(Mv[j])
        ax.pcolormesh(x1, x2, M, cmap=cm.hot, linewidth=0, antialiased=False)
        if Mv[j] == 'WSSE':
            # The set of exact solutions to Ax = q_obs, where the misfit vanishes
            ax.plot(x1[0, :], (q_obs[0] - a * x1[0, :]) / b, '-', color='blue', zorder=10)
            ax.set_xlim(-0.5, 0.5)
            ax.set_ylim(-0.5, 0.5)
        j += 1
    # Bottom panels: the combined objectives and their grid minimizers
    i, j = 1, 0
    Mv = ['WSSE + TSSE - USSE', 'TSSE - USSE', 'WSSE + TSSE']
    for M in [WSSE + TSSE - USSE, TSSE - USSE, WSSE + TSSE]:
        ax = axarr[i, j]
        ax.set_aspect('equal')
        ax.pcolormesh(x1, x2, M, cmap=cm.hot, linewidth=0, antialiased=False)
        if j == 0:
            ax.scatter([x1.flatten()[x_unreg_ind]], [x2.flatten()[x_unreg_ind]], s=250, facecolor='yellow')
            ax.scatter([x_prior[0]], [x_prior[1]], s=250, facecolor='red')
        if j == 2:
            ax.scatter([x1.flatten()[x_reg_ind]], [x2.flatten()[x_reg_ind]], s=250, facecolor='blue')
            ax.scatter([x_prior[0]], [x_prior[1]], s=250, facecolor='red')
        ax.set_xlabel('$x_1$')
        ax.set_ylabel('$x_2$')
        ax.set_title(Mv[j])
        j += 1
    plt.show()
# HIDDEN
A = wd.interactive(solve, n=wd.fixed(101),
                   data_cov_const=wd.FloatSlider(value=0.05, min=0.001, max=2, step=0.001, continuous_update=False),
                   sigma_x1=wd.FloatSlider(value=0.5, min=0.025, max=2.5, step=0.025, continuous_update=False),
                   sigma_x2=wd.FloatSlider(value=0.25, min=0.025, max=2.5, step=0.025, continuous_update=False),
                   prior_x1=wd.FloatSlider(value=0.2, min=-0.5, max=0.5, continuous_update=False),
                   prior_x2=wd.FloatSlider(value=0.2, min=-0.5, max=0.5, continuous_update=False),
                   tinker=wd.ToggleButton(value=True, continuous_update=False, description='tinker mode'))
# HIDDEN
def cback1(a):
    # Mirror the changed slider's new value onto children[0], the
    # data_cov_const slider (children[3] and children[4] are two of the
    # prior sliders; their exact identity depends on ipywidgets' ordering)
    A.children[0].value = a['new']
A.children[3].observe(cback1, 'value')
A.children[4].observe(cback1, 'value')
A