Notebooks
http://bactra.org/notebooks
Cosma's Notebooks
Propagation of Error to the Rescue? A Naive Baseline for Quantifying the Uncertainty of Complex Predictive Models Fit by Smooth Minimization
http://bactra.org/notebooks/2023/06/23#prop-error-uncertainty-quant
<blockquote>I wrote this on 30 September 2019, and haven't done anything with
it since, except for making it the basis for some problem sets. I'm making it
into a notebook now in case it might be useful to anyone else. (Also,
sometimes doing things like this makes me do actual research.) See also: <a href="uncertainty-for-neural-networks.html">Uncertainty for Neural Networks, and Other Large Complicated Models</a>.</blockquote>
\[
\DeclareMathOperator*{\argmin}{argmin}
\newcommand{\Var}[1]{\mathbb{V}\left[ #1 \right]}
\]
<P>Let me sketch a very, very naive baseline for quantifying uncertainty in the
predictions of even very complex models. Generically, I will write the model
predictions in the form \( m(x;\theta) \), where \( x \) is a (high-dimensional)
vector of predictor variables, \( \theta \) is a (high-dimensional) vector of
parameters, and the function class \( m \) is implicitly defined by the
architecture of some <a href="neural-nets.html">multi-tentacled computational horror</a>. I will
casually vectorize this, so \( m(\mathbf{x};\theta) \) is the \( n \)-dimensional
vector of predictions we get on the data set \( \mathbf{x} \), using the same
parameter vector \( \theta \) for all data points. The
actual predictands in the training data were \( \mathbf{y} \).
<P>Notation: \( \nabla f \) will always refer to the gradient of \( f \) with respect to
\( \theta \), even if \( f \) also has other arguments. \( \nabla\nabla f \) will
always be the Hessian of \( f \).
<blockquote><strong>Assumption I</strong>: We get the estimate \( \hat{\theta} \) by minimizing some
loss function on the training data, possibly with regularization.
Specifically,
\[
\hat{\theta} = \argmin_{\theta}{L_n(\mathbf{y}, m(\mathbf{x};\theta))}
\]
where \( L_n \) is the \( n \)-point loss function, possibly including regularization.</blockquote>
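As a toy instance of Assumption I (my own illustration, with a linear model and mean-squared-error loss standing in for the multi-tentacled horror), we can find \( \hat{\theta} \) by plain gradient descent on \( L_n \):

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 500, 3
X = rng.normal(size=(n, p))
theta_true = np.array([1.0, -2.0, 0.5])
y = X @ theta_true + rng.normal(size=n)

# L_n(theta) = mean squared error over the training set
def grad_L(theta):
    return -2 * X.T @ (y - X @ theta) / n

# hat(theta) = argmin_theta L_n, found here by plain gradient descent
theta_hat = np.zeros(p)
for _ in range(2000):
    theta_hat = theta_hat - 0.1 * grad_L(theta_hat)
```

Any smooth optimizer would do; the sketch only needs the minimizer, not the route taken to it.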
<blockquote><strong>Assumption II</strong>: There's an <a href="ergodic-theory.html">ergodic property</a> in play, so
that
\[
L_n(\mathbf{Y}, m(\mathbf{X};\theta)) \rightarrow \ell(\theta)
\]
for each \( \theta \).</blockquote>
<P>(A sufficient condition for Assumption II would be that \( L_n \) is an average over
per-data-point losses, and the data points are IID, but there are also
more complex possibilities.)
<blockquote><strong>Assumption III</strong>: For each \( n \), \( L_n \) has a unique, interior minimum in \( \theta \),
with a positive-definite Hessian.</blockquote>
In words: \( L_n \) has a nice minimum.
<blockquote><strong>Assumption IV</strong>: The limiting objective function \( \ell \) also has a nice minimum,
at \( \theta^* \).</blockquote>
<P>Taylor-expand \( \nabla L_n(\hat{\theta}) \) around \( \theta^* \), using the fact that
\( \hat{\theta} \) is an interior minimum:
\begin{eqnarray}
0 & = & \nabla L_n(\hat{\theta})\\
& = & \nabla L_n(\theta^*) +\nabla\nabla L_n(\theta^*) (\hat{\theta}-\theta^*) \\
& = & \left(\nabla\nabla L_n(\theta^*)\right)^{-1} \nabla L_n(\theta^*) + \hat{\theta} -\theta^*\\
\hat{\theta} & = & \theta^* - \left(\nabla\nabla L_n(\theta^*)\right)^{-1} \nabla L_n(\theta^*)
\end{eqnarray}
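For a quadratic loss the Taylor expansion above has no remainder, which gives a quick numerical check of the identity \( \hat{\theta} = \theta^* - \left(\nabla\nabla L_n(\theta^*)\right)^{-1} \nabla L_n(\theta^*) \). A minimal sketch, with squared-error loss and a linear model as my own stand-ins:

```python
import numpy as np

rng = np.random.default_rng(1)
n, p = 300, 3
X = rng.normal(size=(n, p))
theta_star = np.array([0.5, 1.0, -1.5])
y = X @ theta_star + rng.normal(size=n)

theta_hat = np.linalg.lstsq(X, y, rcond=None)[0]   # exact minimizer of L_n

U = -2 * X.T @ (y - X @ theta_star) / n   # gradient of L_n at theta^*
H = 2 * X.T @ X / n                       # Hessian of L_n (constant in theta)
one_step = theta_star - np.linalg.inv(H) @ U
# for quadratic L_n the expansion is exact, so one_step equals theta_hat
```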
<P>It is convenient to abbreviate \( \nabla L_n(\theta^*) \) by \( U_n \) and
\( \nabla\nabla L_n(\theta^*) \) by \( \mathbf{H}_n \). (I write these with capital letters as
reminders that they're random quantities.) Under suitable regularity
assumptions, which let us exchange taking derivatives and limits,
\[
U_n \rightarrow \nabla \ell(\theta^*) = 0
\]
and
\[
\mathbf{H}_n \rightarrow \nabla\nabla \ell(\theta^*) \equiv \mathbf{h}
\]
<P>The utility of these expansions is that they give the asymptotic variance
of \( \hat{\theta} \):
\begin{eqnarray}
\Var{\hat{\theta}} & = & \Var{\mathbf{H}_n^{-1} U_n}\\
& = & \mathbf{H}_n^{-1} \Var{U_n} \mathbf{H}_n^{-1}\\
& = & \mathbf{H}_n^{-1} \mathbf{J}_n \mathbf{H}_n^{-1}
\end{eqnarray}
where \( \mathbf{J}_n \equiv \Var{U_n} \). (Pulling \( \mathbf{H}_n^{-1} \) outside the
variance treats it as non-random, which is justified asymptotically since
\( \mathbf{H}_n \rightarrow \mathbf{h} \).) Typically, when \( L_n \) is an average over
independent or weakly-dependent
terms, the variance of the gradient will approach a limit at rate \( 1/n \),
\( n \mathbf{J}_n \rightarrow \mathbf{j} \). This will give
\[
\Var{\hat{\theta}} \rightarrow n^{-1} \mathbf{h}^{-1} \mathbf{j} \mathbf{h}^{-1}
\]
However, that's not <em>essential</em> to what's of interest here; that's the
expression
\[
\Var{\hat{\theta}} = \mathbf{H}_n^{-1} \mathbf{J}_n \mathbf{H}_n^{-1}
\]
which is the <strong>sandwich variance</strong> (or <strong>sandwich covariance</strong>) matrix for \( \hat{\theta} \).
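To make the sandwich concrete, here is a minimal numerical sketch (my own example, assuming squared-error loss, a linear model, and heteroskedastic noise so the sandwich actually differs from the naive inverse-Hessian variance):

```python
import numpy as np

rng = np.random.default_rng(2)
n, p = 2000, 2
X = rng.normal(size=(n, p))
theta_star = np.array([1.0, -1.0])
# heteroskedastic noise: variance depends on the first predictor
y = X @ theta_star + rng.normal(size=n) * (1 + np.abs(X[:, 0]))

theta_hat = np.linalg.lstsq(X, y, rcond=None)[0]
resid = y - X @ theta_hat

# per-point loss gradients at theta_hat, with l_i = (y_i - x_i . theta)^2
G = -2 * X * resid[:, None]            # shape (n, p)
H = 2 * X.T @ X / n                    # Hessian of L_n
J = np.cov(G, rowvar=False) / n        # estimate of J_n = Var[U_n]
Hinv = np.linalg.inv(H)
sandwich = Hinv @ J @ Hinv             # sandwich variance for theta_hat
```

For squared-error loss on a linear model this reduces to the familiar Huber-White heteroskedasticity-robust covariance estimator.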
<P>We now make
<blockquote><strong>Assumption V</strong>: The predictions \( m(x;\theta) \) are differentiable in \( \theta \).</blockquote>
<P>Since we want to make a prediction at \( x \), we Taylor-expand around \( \theta^* \):
\begin{eqnarray}
m(x;\hat{\theta}) & \approx & m(x;\theta^*) + (\hat{\theta} - \theta^*) \cdot \nabla m(x;\theta^*)\\
\Var{m(x;\hat{\theta})} & \approx & (\nabla m(x;\theta^*)) \cdot \Var{\hat{\theta}} \nabla m(x;\theta^*)\\
& \approx & (\nabla m(x;\hat{\theta})) \cdot \Var{\hat{\theta}} \nabla m(x;\hat{\theta})
\end{eqnarray}
(The transition from the 1st to the 2nd line doesn't introduce any new
approximation; the transition from the 2nd to the 3rd is justified by the
conviction that \( \hat{\theta} \rightarrow \theta^* \).) Let's write \( g \) (for "gradient") for \( \nabla m(x;\hat{\theta}) \).
<P>To sum up, we get an approximate variance for the prediction at \( x \) as
\[
\Var{m(x;\hat{\theta})} \approx g \cdot \mathbf{H}_n^{-1} \mathbf{J}_n \mathbf{H}_n^{-1} g
\]
In other words: combine the sandwich variance matrix with propagation of error.
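Continuing the toy linear example (my own; for a linear model \( m(x;\theta) = x \cdot \theta \), the gradient \( g \) is just \( x \) itself), the prediction variance at a new point is:

```python
import numpy as np

rng = np.random.default_rng(3)
n, p = 1000, 2
X = rng.normal(size=(n, p))
y = X @ np.array([2.0, -1.0]) + rng.normal(size=n)

theta_hat = np.linalg.lstsq(X, y, rcond=None)[0]
resid = y - X @ theta_hat
G = -2 * X * resid[:, None]                            # per-point loss gradients
Hinv = np.linalg.inv(2 * X.T @ X / n)                  # inverse Hessian of L_n
Sigma = Hinv @ (np.cov(G, rowvar=False) / n) @ Hinv    # sandwich Var[theta_hat]

x_new = np.array([0.5, 1.5])
g = x_new                      # nabla m(x; theta) for the linear model m = x . theta
pred = x_new @ theta_hat
pred_var = g @ Sigma @ g       # propagation of error: g . Var[theta_hat] . g
```

For a nonlinear network, the only change is that \( g \) comes from backpropagation through the model with respect to the parameters, rather than being \( x \) itself.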
<P>This gives us an approximate variance, but not an approximate distribution.
The easiest case would be when \( \hat{\theta} \) has a Gaussian distribution,
since then that will propagate through to \( m(x;\hat{\theta}) \), at
least to the extent the first-order Taylor expansion of \( m \) is justifiable.
Note that if the loss function is indeed an average over many terms, each with
small effect on the total and only weakly dependent, we'll get Gaussian
fluctuations in limit of many terms even when each term is very non-Gaussian.
However, I won't insist on this, and would be very interested in
exploring non-Gaussian limits.
<P><strong>Estimating the terms of the approximate variance</strong>: We need \( g=\nabla
m(x;\hat{\theta}) \), which is fairly straightforward to compute. (Indeed we often get it
for the training values of \( x \) as part of the fitting process.) The Hessian
\( \mathbf{H}_n \) is also straightforward in principle (and, again, is often
computed as part of the training). The variance \( \mathbf{J}_n \) of the loss
gradient \( U \) is the tricky part. (It's the only thing which is a functional of
the ensemble, rather than of the realized loss function on the data or the
predictive architecture.) If \( L_n \) is an average of point-by-point losses,
say \( l_i \), and the data are IID, we can use \( n^{-1} \) times the sample variance of the \( \nabla l_i \) to estimate \( \Var{\nabla L_n} \). The econometricians have related tricks
for non-IID data ("heteroskedastic-autocorrelated robust standard errors"),
hereby incorporated by reference.
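A quick simulation check (my own, IID data with squared-error loss) that \( n^{-1} \) times the sample variance of the per-point gradients tracks the actual variance of \( U_n \) across fresh data sets:

```python
import numpy as np

rng = np.random.default_rng(4)
n, p, reps = 200, 2, 2000
theta_star = np.array([1.0, -1.0])

def draw():
    X = rng.normal(size=(n, p))
    y = X @ theta_star + rng.normal(size=n)
    return X, y

# Monte Carlo truth: variance of U_n = grad L_n(theta^*) over fresh data sets
U = np.empty((reps, p))
for r in range(reps):
    X, y = draw()
    U[r] = -2 * X.T @ (y - X @ theta_star) / n
var_Un = np.cov(U, rowvar=False)

# plug-in estimate from a single data set: n^{-1} x sample variance of grad l_i
X, y = draw()
G = -2 * X * (y - X @ theta_star)[:, None]
J_hat = np.cov(G, rowvar=False) / n
```

The single-data-set estimate `J_hat` and the Monte Carlo `var_Un` should agree up to sampling noise.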