Regularization

# HIDDEN
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 10})
from matplotlib import cm
import scipy.stats as sstats

import ipywidgets as wd

A Typical Linear Problem

Consider the linear problem

$$ q = Ax, \ x\in\mathbb{R}^n, \ q\in \mathbb{R}^p, \ A\in\mathbb{R}^{p\times n} $$

where we assume that $p\leq n$ and $A$ has full row rank $p$, so for $p < n$ the system is underdetermined.
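
When $p < n$, $A$ has a nontrivial null space, so the datum alone cannot pin down $x$. Here is a minimal sketch of this, using the $1\times 2$ map from the example further below:

import numpy as np

A = np.array([[2.0, -1.0]])      # p = 1, n = 2, rank 1
x_a = np.array([0.1, 0.1])       # one candidate solution
null_dir = np.array([1.0, 2.0])  # A @ null_dir = 0: a null-space direction
x_b = x_a + 3.7 * null_dir       # any shift along the null space...
print(A @ x_a, A @ x_b)          # ...maps to the same datum: [0.1] [0.1]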

Statistical Bayesian connection to Tikhonov regularization

Assuming we observe a datum $\tilde{q}$ and that we are using a Gaussian prior ($N(\bar{x},C_x)$) and a Gaussian noise model ($N(0,C_q)$), the statistical Bayesian posterior is given by

$$ \pi^{\text{post}}(x) \propto \exp\left(-\frac{1}{2}\left( \underbrace{\left|\left|C_q^{-1/2}(Ax-\tilde{q})\right|\right|_2^2}_{\text{Data mismatch}} + \underbrace{\left|\left|C_x^{-1/2}(x-\bar{x})\right|\right|_2^2}_{\text{Tikhonov regularization}} \right)\right) $$

where we have made explicit the connection between the MAP (maximum a posteriori) point of the posterior density and the Tikhonov-regularized solution of the corresponding deterministic optimization problem.

See https://en.wikipedia.org/wiki/Tikhonov_regularization for more information.
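
For this linear-Gaussian setting, the MAP point is available in closed form by completing the square in the exponent (the standard identity, stated here for reference):

$$ x_{\text{MAP}} = \bar{x} + C_x A^\top \left(A C_x A^\top + C_q\right)^{-1}\left(\tilde{q} - A\bar{x}\right) $$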

Take-aways

  • The model defines the data mismatch and the prior defines the regularization.
  • The regularization impacts all directions of the posterior, since we effectively balance the data mismatch against our prior beliefs. This implies that the "solution" defined by a MAP point is not necessarily a point that reproduces the observed datum; the sketch after the example below checks this numerically.

Example

Consider the linear problem where $$ A = \begin{bmatrix} 2 & -1 \end{bmatrix}, \quad \bar{x}=\begin{bmatrix} 0.2 \\ 0.2 \end{bmatrix}, \quad C_x = \text{diag}(0.5,\, 0.25), \quad \tilde{q} = [0.1], \quad C_q = [0.25]. $$
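
As a check on the second take-away, here is a minimal sketch evaluating the closed-form MAP point for these example values (the variable names are illustrative, not taken from the hidden code below):

import numpy as np

A = np.array([[2.0, -1.0]])
x_bar = np.array([0.2, 0.2])
C_x = np.diag([0.5, 0.25])
q_tilde = np.array([0.1])
C_q = np.array([[0.25]])

# x_MAP = x_bar + C_x A^T (A C_x A^T + C_q)^{-1} (q_tilde - A x_bar)
S = A @ C_x @ A.T + C_q  # here the 1x1 matrix [[2.5]]
x_map = x_bar + C_x @ A.T @ np.linalg.solve(S, q_tilde - A @ x_bar)
print(x_map)      # [0.16 0.21]
print(A @ x_map)  # [0.11] != 0.1: the MAP point does not reproduce the datum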

Things to play with

  • Try changing x_prior (the prior_x1 and prior_x2 sliders) to something other than $[0.2 \ 0.2]^\top$ to make the prior guess either better or worse. What happens?

  • Try playing with the C_x covariance (the sigma_x1 and sigma_x2 sliders) to give the prior guess either more confidence (reduce the components) or less confidence (increase the components). What happens?

# HIDDEN
def solve(n=101, data_cov_const=0.25, prior_x1=0.2, prior_x2=0.2, sigma_x1=0.5, sigma_x2=0.25, tinker=False):
    # Discretize a portion of the input space R^2 and compare the Tikhonov (MAP)
    # point with the data-consistent point by brute force on a grid.
    # Set up the example: prior guess, prior prediction, and actual datum
    a = 2
    b = -1

    A = np.array([[a, b]])  # the (1x2) linear map
    x_prior = np.array([prior_x1, prior_x2]).reshape(-1, 1)  # prior guess of the mean
    q_obs = np.array([0.1])  # actual datum; leave fixed
    q_prior = np.dot(A, x_prior)  # predicted datum using the prior mean
    print('Prior Mean (x1,x2) =', *x_prior.flatten(), 'maps to q =', *q_prior[0])
    if not tinker:
        # Set up all the covariances first so the misfit functions can use them
        prior_cov = [sigma_x1, sigma_x2]
        C_x = np.diag(prior_cov)  # prior covariance
        C_q = np.diag([data_cov_const])  # data (noise) covariance
        C_A = np.dot(np.dot(A, C_x), A.transpose())  # the "covariance of the map", A C_x A^T

        def data_misfit(x):
            # weighted sum-squared error ||C_q^{-1/2}(Ax - q_obs)||^2
            C_q_inv = np.linalg.inv(C_q)
            res = np.dot(A, x) - q_obs
            return np.vdot(np.dot(C_q_inv, res), res)

        def Tikhonov_reg(x):
            # weighted sum-squared error ||C_x^{-1/2}(x - x_prior)||^2
            C_x_inv = np.linalg.inv(C_x)
            res = x - x_prior
            return np.vdot(np.dot(C_x_inv, res), res)

        def unregularize(x):
            # weighted sum-squared error ||C_A^{-1/2}(Ax - q_prior)||^2; subtracting
            # this term removes the prior's influence in the data-informed direction
            C_A_inv = np.linalg.inv(C_A)
            res = np.dot(A, x) - q_prior
            return np.vdot(np.dot(C_A_inv, res), res)

        x1 = np.linspace(-0.5, 0.5,n)
        x2 = x1
        x1,x2 = np.meshgrid(x1,x2)
        # Compute all the WSSE terms

        WSSE = np.zeros((n,n))
        TSSE = np.zeros((n,n))
        USSE = np.zeros((n,n))
        for i in range(n):
            for j in range(n):
                WSSE[j,i] = data_misfit(np.array([[x1[j,i],x2[j,i]]]).transpose())
                TSSE[j,i] = Tikhonov_reg(np.array([[x1[j,i],x2[j,i]]]).transpose())
                USSE[j,i] = unregularize(np.array([[x1[j,i],x2[j,i]]]).transpose())

        # Flat (row-major) indices of the grid minimizers, matching x1.flatten()
        x_reg_ind = np.argmin(WSSE + TSSE)  # Tikhonov / MAP point
        x_unreg_ind = np.argmin(WSSE + TSSE - USSE)  # data-consistent (CB) point


        print('Absolute error in prediction through Tikhonov: ',
              np.abs(0.1 - np.dot(A, [x1.flatten()[x_reg_ind], x2.flatten()[x_reg_ind]])[0]))

        print('Absolute error in prediction through CB: ',
              np.abs(0.1 - np.dot(A, [x1.flatten()[x_unreg_ind], x2.flatten()[x_unreg_ind]])[0]))


        f, axarr = plt.subplots(2, 3, figsize=[20,20])
        i, j = 0,0
        ax = axarr[i,j]
        ax.set_xlabel('$x_1$')
        ax.set_ylabel('$x_2$')
        ax.set_title('Prior')
        ax.set_aspect('equal')
        Z = sstats.multivariate_normal.pdf(np.concatenate([x1.reshape(-1,1), x2.reshape(-1,1)], axis=1), mean=x_prior.flatten(), cov=np.diag(prior_cov))
        ax.pcolormesh(x1, x2, Z.reshape(n,n), cmap=cm.hot, linewidth=0, antialiased=False)
        ax.scatter([x_prior[0]], [x_prior[1]], s=250,facecolor='red')
        Mv = ['WSSE', 'TSSE']
        for M in [WSSE, TSSE]:
            ax = axarr[i,j+1]
            ax.set_aspect('equal')
            ax.set_xlabel('$x_1$')
            ax.set_ylabel('$x_2$')
            ax.set_title(Mv[j])          
            ax.pcolormesh(x1, x2, M, cmap=cm.hot, linewidth=0, antialiased=False)
            if Mv[j] == 'WSSE':
                # Draw the line of data-consistent solutions a*x1 + b*x2 = q_obs
                ax.plot(x1[0, :], (q_obs[0] - a*x1[0, :])/b, '-', color='blue', zorder=10)
                ax.set_xlim(x1.min(), x1.max())
                ax.set_ylim(x2.min(), x2.max())
            j +=1

        i, j = 1, 0
        Mv = ['WSSE + TSSE - USSE', 'TSSE - USSE', 'WSSE + TSSE', ]
        for M in [WSSE + TSSE - USSE, TSSE-USSE, WSSE + TSSE]:

            ax = axarr[i,j]
            ax.set_aspect('equal')
            ax.pcolormesh(x1, x2, M, cmap=cm.hot, linewidth=0, antialiased=False)
            if j == 0:
                ax.scatter([x1.flatten()[x_unreg_ind]], [x2.flatten()[x_unreg_ind]], s=250, facecolor='yellow')
                ax.scatter([x_prior[0]], [x_prior[1]], s=250, facecolor='red')
            if j == 2:
                ax.scatter([x1.flatten()[x_reg_ind]], [x2.flatten()[x_reg_ind]], s=250, facecolor='blue')
                ax.scatter([x_prior[0]], [x_prior[1]], s=250, facecolor='red')
            ax.set_xlabel('$x_1$')
            ax.set_ylabel('$x_2$')
            ax.set_title(Mv[j])

            j +=1

        plt.show()
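
The widget below drives solve interactively, but the function can also be called directly, passing tinker=False so that the computation and plots actually run, e.g.:

# Direct call without the widget; tinker=False triggers the compute/plot path
solve(n=51, data_cov_const=0.25, prior_x1=0.2, prior_x2=0.2,
      sigma_x1=0.5, sigma_x2=0.25, tinker=False)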
# HIDDEN
# Build the interactive controls; note this A is the widget container, not the linear map
A = wd.interactive(solve, n = wd.fixed(101),
                   data_cov_const=wd.FloatSlider(value=0.05, min=0.001, max=2, step=0.001, continuous_update=False),
                   sigma_x1=wd.FloatSlider(value=0.5, min=0.025, max=2.5, step=0.025, continuous_update=False), 
                   sigma_x2=wd.FloatSlider(value=0.25, min=0.025, max=2.5, step=0.025,  continuous_update=False), 
                   prior_x1=wd.FloatSlider(value=0.2, min=-0.5, max=0.5, continuous_update=False), 
                   prior_x2=wd.FloatSlider(value=0.2, min=-0.5, max=0.5, continuous_update=False),
                   tinker=wd.ToggleButton(value=True, continuous_update=False, description='tinker mode')  # while on, the (slow) recompute/plot is skipped
                  )
# HIDDEN
def cback1(change):
    # Mirror value changes from the observed sliders into the first child widget
    A.children[0].value = change['new']

A.children[3].observe(cback1, 'value')
A.children[4].observe(cback1, 'value')
A
