Propagation of Error to the Rescue? A Naive Baseline for Quantifying the Uncertainty of Complex Predictive Models Fit by Smooth Minimization

23 Jun 2023 13:24

I wrote this on 30 September 2019, and haven't done anything with it since, except for making it the basis for some problem sets. I'm making it into a notebook now in case it might be useful to anyone else. (Also, sometimes doing things like this makes me do actual research.) See also: Uncertainty for Neural Networks, and Other Large Complicated Models.
\[ \DeclareMathOperator*{\argmin}{argmin} \newcommand{\Var}[1]{\mathbb{V}\left[ #1 \right]} \]

Let me sketch a very, very naive baseline for quantifying uncertainty in the predictions of even very complex models. Generically, I will write the model predictions as \( m(x;\theta) \), where \( x \) is a (high-dimensional) vector of predictor variables, \( \theta \) is a (high-dimensional) vector of parameters, and the function class \( m \) is implicitly defined by the architecture of some multi-tentacled computational horror. I will casually vectorize this, so \( m(\mathbf{x};\theta) \) is the \( n \)-dimensional vector of predictions we get on the data set \( \mathbf{x} \), using the same parameter vector \( \theta \) for all data points. The actual predictands in the training data were \( \mathbf{y} \).

Notation: \( \nabla f \) will always refer to the gradient of \( f \) with respect to \( \theta \), even if \( f \) also has other arguments. \( \nabla\nabla f \) will always be the Hessian of \( f \), likewise with respect to \( \theta \).

Assumption I: We get the estimate \( \hat{\theta} \) by minimizing some loss function on the training data. Specifically, \[ \hat{\theta} = \argmin_{\theta}{L_n(\mathbf{y}, m(\mathbf{x};\theta))} \] where \( L_n \) is the \( n \)-point loss function, possibly including regularization.
Assumption II: There's an ergodic property in play, so that \[ L_n(\mathbf{Y}, m(\mathbf{X};\theta)) \rightarrow \ell(\theta) \] for each \( \theta \).

(A sufficient condition for Assumption II would be that \( L_n \) is an average over per-data-point losses, and the data points are IID, but there are also more complex possibilities.)
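To make the sufficient condition concrete, here is a minimal Python sketch under a made-up toy setup (not from the text): IID data from \( y = \theta_0 x + \epsilon \) with \( x \sim N(0,1) \), \( \epsilon \sim N(0,\sigma^2) \), and squared-error loss averaged over data points. The law of large numbers then makes \( L_n(\theta) \) approach \( \ell(\theta) = (\theta-\theta_0)^2 + \sigma^2 \) at each fixed \( \theta \). All names and numbers are hypothetical.

```python
import numpy as np

# Hypothetical toy setup (not from the text): y = theta0 * x + noise,
# with x ~ N(0, 1) and noise ~ N(0, sigma^2), IID across data points.
theta0, sigma = 1.0, 0.5


def empirical_loss(theta, x, y):
    """L_n(theta): average per-point squared error of the model m(x; theta) = theta * x."""
    return np.mean((y - theta * x) ** 2)


def limiting_loss(theta):
    """ell(theta) = E[(Y - theta X)^2] = (theta - theta0)^2 E[X^2] + sigma^2, with E[X^2] = 1."""
    return (theta - theta0) ** 2 + sigma ** 2


rng = np.random.default_rng(42)
for n in (100, 10_000, 1_000_000):
    x = rng.normal(size=n)
    y = theta0 * x + sigma * rng.normal(size=n)
    # L_n(2.0) should settle down near ell(2.0) = 1.25 as n grows.
    print(n, empirical_loss(2.0, x, y), limiting_loss(2.0))
```

The convergence here is pointwise in \( \theta \), which is all Assumption II asks for. Dependent data would need an ergodic theorem in place of the IID law of large numbers.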

Assumption III: For each \( n \), \( L_n \) has a unique, interior minimum in \( \theta \), with a positive-definite Hessian.
In words: \( L_n \) has a nice minimum.
Assumption IV: The limiting objective function \( \ell \) also has a nice minimum, at \( \theta^* \).

Taylor-expand \( \nabla L_n(\hat{\theta}) \) around \( \theta^* \), using the fact that \( \hat{\theta} \) is an interior minimum, so the gradient vanishes there: \begin{eqnarray} 0 & = & \nabla L_n(\hat{\theta})\\ & \approx & \nabla L_n(\theta^*) +\nabla\nabla L_n(\theta^*) (\hat{\theta}-\theta^*) \\ & \approx & \left(\nabla\nabla L_n(\theta^*)\right)^{-1} \nabla L_n(\theta^*) + \hat{\theta} -\theta^*\\ \hat{\theta} & \approx & \theta^* - \left(\nabla\nabla L_n(\theta^*)\right)^{-1} \nabla L_n(\theta^*) \end{eqnarray}
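As a numerical sanity check on this expansion, here is a hedged sketch with a hypothetical one-parameter model: Poisson regression through the origin, \( m(x;\theta) = e^{\theta x} \), with the negative log-likelihood (up to constants) as \( L_n \). Nothing about this model comes from the text. The last line of the derivation says \( \hat{\theta} \) is approximately one Newton step away from \( \theta^* \). The gap between the actual minimizer and that one-step value should therefore be much smaller than the sampling error \( \hat{\theta}-\theta^* \) itself.

```python
import numpy as np

# Hypothetical example: m(x; theta) = exp(theta * x), with
# L_n(theta) = mean(exp(theta * x_i) - theta * x_i * y_i)  (Poisson NLL up to constants).
rng = np.random.default_rng(0)
n = 10_000
theta_star = 0.5                       # model is correctly specified, so theta* is the true value
x = rng.uniform(-1.0, 1.0, size=n)
y = rng.poisson(np.exp(theta_star * x))


def grad_L(theta):
    """First derivative of L_n with respect to theta."""
    return np.mean(x * np.exp(theta * x) - x * y)


def hess_L(theta):
    """Second derivative of L_n with respect to theta (always positive, so L_n is convex)."""
    return np.mean(x ** 2 * np.exp(theta * x))


# theta_hat: the actual minimizer, found by iterating Newton's method to convergence.
theta_hat = 0.0
for _ in range(50):
    step = grad_L(theta_hat) / hess_L(theta_hat)
    theta_hat -= step
    if abs(step) < 1e-12:
        break

# First-order approximation from the derivation: one Newton step starting at theta*.
theta_one_step = theta_star - grad_L(theta_star) / hess_L(theta_star)

print(theta_hat - theta_star)        # sampling error, of order n^{-1/2}
print(theta_hat - theta_one_step)    # neglected higher-order terms, much smaller
```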

It is convenient to abbreviate \( \nabla L_n(\theta^*) \) by \( U_n \) and \( \nabla\nabla L_n(\theta^*) \) by \( \mathbf{H}_n \). (I write these with capital letters as reminders that they're random quantities.) In this notation, \( \hat{\theta} \approx \theta^* - \mathbf{H}_n^{-1} U_n \). Under suitable regularity assumptions, which let us exchange taking derivatives and limits, \[ U_n \rightarrow \nabla \ell(\theta^*) = 0 \] and \[ \mathbf{H}_n \rightarrow \nabla\nabla \ell(\theta^*) \equiv \mathbf{h} \]

The utility of these expansions is that they give the asymptotic variance of \( \hat{\theta} \). Since \( \theta^* \) is a constant, and treating \( \mathbf{H}_n \) as approximately non-random (it converges to \( \mathbf{h} \)), \begin{eqnarray} \Var{\hat{\theta}} & \approx & \Var{\mathbf{H}_n^{-1} U_n}\\ & \approx & \mathbf{H}_n^{-1} \Var{U_n} \mathbf{H}_n^{-1}\\ & = & \mathbf{H}_n^{-1} \mathbf{J}_n \mathbf{H}_n^{-1} \end{eqnarray} where \( \mathbf{J}_n \equiv \Var{U_n} \). Typically, when \( L_n \) is an average over independent or weakly-dependent terms, the variance of the gradient shrinks at rate \( 1/n \), i.e., \( n \mathbf{J}_n \rightarrow \mathbf{j} \). This gives \[ \Var{\hat{\theta}} \approx n^{-1} \mathbf{h}^{-1} \mathbf{j} \mathbf{h}^{-1} \] However, the rate is not essential to what's of interest here, which is the expression \[ \Var{\hat{\theta}} \approx \mathbf{H}_n^{-1} \mathbf{J}_n \mathbf{H}_n^{-1} \] the sandwich variance (or sandwich covariance) matrix for \( \hat{\theta} \).
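Here is a minimal sketch of the sandwich formula, assuming a hypothetical fixed-design linear regression with heteroskedastic Gaussian noise and squared-error loss. In this special case \( \mathbf{H}_n \) depends on neither \( \theta \) nor \( \mathbf{y} \), and the Taylor expansion is exact. So \( \mathbf{H}_n^{-1}\mathbf{J}_n\mathbf{H}_n^{-1} \) should match a Monte Carlo estimate of \( \Var{\hat{\theta}} \) up to simulation noise. The function name `sandwich_cov` and all the data are illustrative only.

```python
import numpy as np


def sandwich_cov(H, J):
    """Sandwich covariance H^{-1} J H^{-1}, for a symmetric positive-definite H."""
    H_inv = np.linalg.inv(H)
    return H_inv @ J @ H_inv


rng = np.random.default_rng(1)
n, p = 200, 3
X = rng.normal(size=(n, p))                # fixed design, reused in every replicate
theta_star = np.array([1.0, -2.0, 0.5])
sigma = 0.5 + np.abs(X[:, 0])              # noise scale varies with x (heteroskedastic)

# L_n(theta) = (1/n) * sum_i (y_i - x_i . theta)^2, so H_n = (2/n) X^T X for every theta.
H = 2.0 / n * X.T @ X
# U_n = grad L_n(theta*) = -(2/n) X^T eps, so J_n = Var[U_n] = (4/n^2) X^T diag(sigma^2) X.
J = 4.0 / n**2 * (X.T * sigma**2) @ X
V_sandwich = sandwich_cov(H, J)

# Monte Carlo: draw fresh y with the same X many times, refit, take the empirical covariance.
reps = 4000
thetas = np.empty((reps, p))
for r in range(reps):
    y = X @ theta_star + sigma * rng.normal(size=n)
    thetas[r] = np.linalg.lstsq(X, y, rcond=None)[0]
V_mc = np.cov(thetas, rowvar=False)

print(np.diag(V_sandwich))
print(np.diag(V_mc))
```

In a real problem, \( \mathbf{J}_n \) is not known in closed form and has to be estimated. The point here is only that the sandwich is the right target.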

We now make

Assumption V: The predictions \( m(x;\theta) \) are differentiable in \( \theta \).

Since we want to make a prediction at \( x \), we Taylor-expand around \( \theta^* \): \begin{eqnarray} m(x;\hat{\theta}) & \approx & m(x;\theta^*) + (\hat{\theta} - \theta^*) \cdot \nabla m(x;\theta^*)\\ \Var{m(x;\hat{\theta})} & \approx & (\nabla m(x;\theta^*)) \cdot \Var{\hat{\theta}} \nabla m(x;\theta^*)\\ & \approx & (\nabla m(x;\hat{\theta})) \cdot \Var{\hat{\theta}} \nabla m(x;\hat{\theta}) \end{eqnarray} (The transition from the 1st to the 2nd line doesn't introduce any new approximation, since \( m(x;\theta^*) \) and \( \nabla m(x;\theta^*) \) are non-random. The transition from the 2nd to the 3rd is justified by the conviction that \( \hat{\theta} \rightarrow \theta^* \).) Let's write \( g \) (for "gradient") for \( \nabla m(x;\hat{\theta}) \).

To sum up, we get an approximate variance for the prediction at \( x \) as \[ \Var{m(x;\hat{\theta})} \approx g \cdot \mathbf{H}_n^{-1} \mathbf{J}_n \mathbf{H}_n^{-1} g \] In other words: combine the sandwich variance matrix with propagation of error.
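Propagation of error needs only the gradient \( g \) and a covariance matrix for \( \hat{\theta} \). Here is a hedged Python sketch. The model \( m(x;\theta) = \theta_1 e^{\theta_2 x} \) (`theta[0]` and `theta[1]` in the code), the values of \( \hat{\theta} \) and its covariance matrix, and the helper names are all made up for illustration. The central-difference gradient is only a stand-in for whatever (e.g., automatic differentiation) would supply \( g \) for a real model. The Monte Carlo comparison assumes \( \hat{\theta} \) is Gaussian, so it checks the linearization, not the sandwich.

```python
import numpy as np


def numerical_grad(f, theta, eps=1e-6):
    """Central-difference gradient of a scalar function f at theta (stand-in for autodiff)."""
    theta = np.asarray(theta, dtype=float)
    g = np.zeros_like(theta)
    for k in range(theta.size):
        e = np.zeros_like(theta)
        e[k] = eps
        g[k] = (f(theta + e) - f(theta - e)) / (2 * eps)
    return g


def prediction_variance(predict, x, theta_hat, V_theta):
    """Propagation of error: Var[m(x; theta_hat)] ~= g . V_theta g, g = grad_theta m(x; theta_hat)."""
    g = numerical_grad(lambda th: predict(x, th), theta_hat)
    return g @ V_theta @ g


def predict(x, theta):
    """Hypothetical nonlinear model m(x; theta) = theta[0] * exp(theta[1] * x)."""
    return theta[0] * np.exp(theta[1] * x)


theta_hat = np.array([2.0, 0.3])
V_theta = np.array([[0.010, 0.002],       # hypothetical sandwich covariance of theta_hat
                    [0.002, 0.004]])
x0 = 1.5

var_delta = prediction_variance(predict, x0, theta_hat, V_theta)

# Check the linearization by simulation, *assuming* theta_hat ~ N(theta_hat, V_theta).
rng = np.random.default_rng(0)
draws = rng.multivariate_normal(theta_hat, V_theta, size=200_000)
var_mc = np.var(draws[:, 0] * np.exp(draws[:, 1] * x0))
print(var_delta, var_mc)
```

When the covariance of \( \hat{\theta} \) is small relative to the curvature of \( m \) in \( \theta \), the two numbers agree closely. They drift apart as the first-order expansion of \( m \) becomes less defensible.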

This gives us an approximate variance, but not an approximate distribution. The easiest case would be when \( \hat{\theta} \) has a Gaussian distribution, since then that will propagate through to \( m(x;\hat{\theta}) \), at least to the extent the first-order Taylor expansion of \( m \) is justifiable. Note that if the loss function is indeed an average over many terms, each with small effect on the total and only weakly dependent, we'll get Gaussian fluctuations in the limit of many terms even when each term is very non-Gaussian. However, I won't insist on this, and would be very interested in exploring non-Gaussian limits.
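If one is willing to assume that \( m(x;\hat{\theta}) \) is approximately Gaussian, turning the approximate variance into an interval takes one line. Here is a minimal sketch using Python's standard-library `statistics.NormalDist`. The function name is hypothetical, and the interval is only as good as the Gaussian and first-order approximations behind it.

```python
import math
from statistics import NormalDist


def gaussian_prediction_interval(prediction, variance, level=0.95):
    """Interval prediction +/- z * sd, where z is the standard-normal quantile for `level`.
    Only justified if m(x; theta_hat) is (approximately) Gaussian."""
    z = NormalDist().inv_cdf(0.5 + level / 2)
    half_width = z * math.sqrt(variance)
    return prediction - half_width, prediction + half_width


print(gaussian_prediction_interval(1.0, 0.04))
```

For instance, `gaussian_prediction_interval(1.0, 0.04)` gives roughly `(0.608, 1.392)`, i.e., \( 1 \pm 1.96 \times 0.2 \).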

Estimating the terms of the approximate variance: the gradient \( g=\nabla m(x;\hat{\theta}) \) is fairly straightforward to compute. (Indeed we often get it for the training values of \( x \) as part of the fitting process.) The Hessian \( \mathbf{H}_n \) is also straightforward in principle, evaluated at \( \hat{\theta} \) in place of the unknown \( \theta^* \) (and, again, it is often computed as part of the training). The variance \( \mathbf{J}_n \) of the loss gradient \( U_n \) is the tricky part. (It's the only term which is a functional of the ensemble, rather than of the realized loss function on the data or of the predictive architecture.) If \( L_n \) is an average of point-by-point losses, say \( L_n = n^{-1}\sum_{i=1}^{n}{l_i} \), and the data are IID, then \( \Var{\nabla L_n} = n^{-1}\Var{\nabla l_i} \). So we can estimate \( \mathbf{J}_n \) by \( n^{-1} \) times the sample covariance matrix of the per-point gradients \( \nabla l_i(\hat{\theta}) \). The econometricians have related tricks for non-IID data ("heteroskedastic-autocorrelated robust standard errors"), hereby incorporated by reference.
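Putting the pieces together for the IID, average-of-losses case, here is a hedged sketch. It assumes a hypothetical linear model with squared-error loss, where \( \nabla l_i(\theta) = -2(y_i - x_i\cdot\theta)x_i \), \( \mathbf{H}_n = (2/n)\mathbf{x}^T\mathbf{x} \), and \( g = x \). The function name and data are made up. For this particular model, the resulting sandwich coincides algebraically with White's heteroskedasticity-consistent (HC0) covariance estimator, which gives a convenient check. For a complicated model, the per-point gradients and the Hessian would come from the fitting machinery instead.

```python
import numpy as np


def sandwich_from_per_point_grads(per_point_grads, H):
    """Sandwich covariance when L_n = (1/n) sum_i l_i and the data are IID.

    per_point_grads: (n, p) array whose rows are grad l_i(theta_hat).
    H: (p, p) Hessian of L_n at theta_hat.
    J_n is estimated as n^{-1} times the (1/n-normalized) sample covariance of the rows.
    """
    n = per_point_grads.shape[0]
    J_hat = np.cov(per_point_grads, rowvar=False, bias=True) / n
    H_inv = np.linalg.inv(H)
    return H_inv @ J_hat @ H_inv


# Hypothetical data: m(x; theta) = x . theta, l_i = (y_i - x_i . theta)^2.
rng = np.random.default_rng(3)
n, p = 500, 3
X = rng.normal(size=(n, p))
y = X @ np.array([1.0, -2.0, 0.5]) + (0.5 + np.abs(X[:, 0])) * rng.normal(size=n)

theta_hat = np.linalg.lstsq(X, y, rcond=None)[0]
resid = y - X @ theta_hat
G = -2.0 * resid[:, None] * X            # rows: per-point gradients grad l_i(theta_hat)
H = 2.0 / n * X.T @ X                    # Hessian of L_n (constant in theta for this loss)
V_theta = sandwich_from_per_point_grads(G, H)

# Prediction at a new point; for a linear model, g = grad_theta (x_new . theta) = x_new.
x_new = np.array([0.3, -1.0, 2.0])
pred = x_new @ theta_hat
pred_var = x_new @ V_theta @ x_new
print(pred, np.sqrt(pred_var))
```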