## Learning Statistical Models Which Predict Well Across Many Rare Categories

*05 Mar 2024 21:03*

This is me trying out some ideas that might eventually become a paper, if someone else hasn't done it already. Pointers would be more than usually appreciated; also, it's especially likely to contain errors.\[ \newcommand{\Expect}[1]{\mathbb{E}\left[ #1 \right]} \newcommand{\Prob}[1]{\mathbb{P}\left[ #1 \right]} \newcommand{\Risk}{r} \]

Start with your basic statistical-decision-theoretic problem: there's an unknown state of the world \( Y \), we want to take an action \( a \in A \), when we do so we incur a loss \( \ell(Y,a) \), we have available information \( X \). ("Action" here includes predictions about \( Y \).) We want rules which are functions of \( X \), so \( a = m(x, \psi) \) for some parameter \( \psi \). Our goal is to minimize the risk, i.e., to pick \( \psi \) so as to make small \( \Risk(\psi) \equiv \Expect{\ell(Y, m(X, \psi))} \). (I am going to assume there's a joint distribution for \( X \) and \( Y \), that we don't care about dependence across data points, etc.)

So far so standard. Now let's add a little more structure: instances of the
decision problem break up into discrete categories, which I'll denote by \( Z
\). There are a very large number of categories, with two properties: a few
categories contain *most* of the probability mass, but the large tail of
small categories is non-negligible. If you want to think in terms
of Pareto: 20% of the categories have 80% of the
probability, but conversely the other 80% of the categories add up to 20% of
the probability, which isn't something you can totally ignore unless you want
your system to just be hosed one day out of each work-week. Let's say that \(
p_k \equiv \Prob{Z=k} \propto k^{-\alpha} \) for some \( \alpha > 0 \).
(Actually nothing I've done yet really needs this assumption about the specific
functional form...)
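
To get a feel for how much mass sits in the tail, here is a minimal sketch. The number of categories (10,000), the exponent \( \alpha = 1 \), and the helper names `zipf_probs` and `head_mass` are all made up for illustration; nothing in the argument depends on them.

```python
def zipf_probs(num_categories, alpha):
    """Return p_k proportional to k**(-alpha) for k = 1..num_categories."""
    weights = [k ** (-alpha) for k in range(1, num_categories + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def head_mass(probs, head_fraction):
    """Total probability of the most common `head_fraction` of categories."""
    n_head = int(round(head_fraction * len(probs)))
    return sum(probs[:n_head])


# Hypothetical settings, chosen only for illustration.
probs = zipf_probs(num_categories=10_000, alpha=1.0)
head = head_mass(probs, head_fraction=0.2)
print(f"top 20% of categories hold {head:.3f} of the probability")
print(f"other 80% of categories hold {1 - head:.3f} of the probability")
```

With these particular settings the split lands in the general neighborhood of the Pareto 80/20 picture; other values of \( \alpha \) or of the number of categories will move it around.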

The categories become relevant because the right decision rule to apply to go from \( X \) to an action depends on the category. I'm not interested (here!) in the difficulties of inferring latent categories, so let's say that we provide the machine with \( Z \) (or that it's easily inferred from \( X \) to start with). So the over-all risk is an average of the category-conditional risks: \[ \Risk(\psi) = \sum_{k}{p_k \Expect{\ell(Y, m(X, \psi))|Z=k}} \]

#### Parameterization corresponding to the categories

Now \( \psi \) is generally a big multi-dimensional object, so let's start with the situation where \( \psi \) is parameterized so that each of its coordinates says what to do for a particular category. That is \( \psi_k \) says what to do in the context of category \( k \), and \( \psi_j \) is irrelevant to the conditional risk when \( Z=k \) (unless of course \( j = k \) ). I will pretend we only need one parameter coordinate for each category; giving it a bunch shouldn't (I hope) do anything except complicate notation.

The first-order condition for risk minimization, in this parameterization,
is of course
\[
\frac{\partial \Risk}{\partial \psi_k} = p_k \frac{\partial}{\partial \psi_k}\Expect{\ell(Y, m(X, \psi))|Z=k} = 0
\]
for all \( k \). Now, even if we had access to the true risk function, if we
can only optimize numerically, we'll have to accept a situation where the
first-order condition is only *approximately* met. If we demand that
\( |\partial \Risk/\partial \psi_k| \leq \epsilon \), and
\( p_k \propto k^{-\alpha} \), then for large \( k \), far out in the tail of rare categories,
we're accepting much larger gradients of the conditional risk: the bound on them is \( \epsilon/p_k \), which scales like \( k^{\alpha} \). That in turn means that we're tolerating being much further from the
optimum of the conditional risk in rare categories. This conclusion just relies on our having to accept some numerical slop in our optimization, not imperfect
data. Naturally data will be even more imperfect for rare categories.
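
A minimal sketch of how that tolerance plays out, under an assumption that is not part of the argument above: each conditional risk is the toy quadratic \( \frac{1}{2}(\psi_k - c_k)^2 \), so its derivative is just the distance from the category's optimum \( c_k \). The tolerance `epsilon`, the number of categories, and the function names are likewise hypothetical.

```python
def zipf_probs(num_categories, alpha):
    """Return p_k proportional to k**(-alpha) for k = 1..num_categories."""
    weights = [k ** (-alpha) for k in range(1, num_categories + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def worst_case_slack(p_k, epsilon):
    """Under the toy quadratic conditional risk 0.5 * (psi_k - c_k)**2,
    the stopping rule |p_k * (psi_k - c_k)| <= epsilon allows a distance
    from the optimum of up to epsilon / p_k, and an excess conditional
    risk of up to half the square of that distance."""
    max_distance = epsilon / p_k
    max_excess_risk = 0.5 * max_distance ** 2
    return max_distance, max_excess_risk


# Hypothetical settings, chosen only for illustration.
probs = zipf_probs(num_categories=1000, alpha=1.0)
epsilon = 1e-4
for k in (1, 10, 100, 1000):
    distance, excess = worst_case_slack(probs[k - 1], epsilon)
    print(f"k={k:5d}  p_k={probs[k - 1]:.2e}  "
          f"max distance={distance:.2e}  max excess risk={excess:.2e}")
```

In this toy setting the allowed distance grows like \( k^{\alpha} \) and the allowed excess conditional risk like \( k^{2\alpha} \); the second rate is a feature of the quadratic assumption, not a general claim.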

I'd like to push this further, in two directions:

- Not enough parameters to go around
- Parameters cutting across categories

#### Not enough parameters

Suppose there's a limited number of parameters available to the model, and there are more categories than that. We will not then be able to give each category its own parameter, and so even with infinite data we'd have some issues. If the parameter settings for each category are very different, it seems intuitive to me that the right way to cope is to create a residual category for all the rare cases we can't distinguish among, rather than degrading performance on the most common categories by expanding them to include oddballs. I don't yet see how to show this, however.

#### Cross-cutting parameters

Suppose the parameterization we've chosen *doesn't* neatly align with the categories. That is, suppose we're working with a different parameterization \( \theta \); it's inter-translatable with \( \psi \) but doesn't coincide with it. In fact let's say that the two are just linear transformations of each other, so \( \theta = \mathbf{w} \psi \) and \( \psi = \mathbf{w}^{-1} \theta \). Now the first-order condition reads, by the chain rule, \[ \begin{eqnarray} \frac{\partial \Risk}{\partial \theta_i} & = & \sum_{j}{\frac{\partial \psi_j}{\partial \theta_i} \frac{\partial \Risk}{\partial \psi_j}}\\ & = & \sum_{j}{(\mathbf{w}^{-1})_{ji} \frac{\partial \Risk}{\partial \psi_j}}\\ & = & \sum_{k}{(\mathbf{w}^{-1})_{ki} p_k \frac{\partial}{\partial \psi_k}\Expect{\ell(Y, m(X, \psi))|Z=k}} = 0 \end{eqnarray} \]

That is, the components of the gradient in the new, \( \theta \) coordinates are inner products of the old, \( \psi \) gradient vector and the columns \( (\mathbf{w}^{-1})_{\cdot i} \) of the inverse transformation. Some components of the \( \theta \) gradient will be near zero just because the corresponding column is (nearly) perpendicular to the \( \psi \) gradient.

## A Maybe-Simpler Way to Approach the Whole Problem

(At this point the off-line notes I'm mostly transcribing change notation, because I did this on a different day, and I'm not feeling up to totally harmonizing them.)

We have a decision rule/statistical model \( m(x, \theta) \), with \( \theta \) being the parameter. There's also a loss function, so with information \( X=x \) and parameter \( \theta \) we experience the random loss \( \ell(Y, m(x, \theta)) \). Depending on the category \( Z \), there's some joint distribution of \( X \) and \( Y \), and so a category-conditional risk \( \Risk_z(\theta) \equiv \Expect{\ell(Y, m(X, \theta)) | Z=z} \). Let's say that for each category, there is a unique value of \( \theta \) that minimizes \( \Risk_z(\theta) \), say \( \psi_z \). Now the over-all risk will naturally be \[ \begin{eqnarray} \Risk(\theta) & = & \sum_{z}{\Risk_z(\theta) \Prob{Z=z}}\\ & = & \sum_{z}{\Risk_z(\theta) p_z} \end{eqnarray} \] taking the last line to define \( p_z \). If we assume there's a nice minimum, so the first order conditions are satisfied at \( \theta=\theta^* \), \[ \begin{eqnarray} 0 & = & \nabla_{\theta} \Risk(\theta^*)\\ & = & \sum_{z}{p_z \nabla_{\theta} \Risk_z(\theta^*)} \end{eqnarray} \]

Now it could be that the parameters which optimize the conditional risks
for each category are all the same, \( \psi_z = \theta^* \) for all \( z \), but
that seems like a lot to ask. More plausibly every category has a
*different* optimal parameter. We will not then set the over-all risk
gradient to zero by minimizing the risk for any one category. Rather, \(
\theta^* \) will be a compromise which isn't optimal for any category. But
this compromise will inevitably be more weighted towards the larger categories.

Say there are only two categories, but \( p_1 = 7/8, p_2 = 1/8 \). Then
\[
\nabla_{\theta} \Risk_1(\theta^*) = - \frac{1}{7} \nabla_\theta \Risk_2(\theta^*)
\]
That is, the over-all optimum is one where things *could* be made better
for the minority (class 2), but only by imposing costs on the majority (class 1). Those costs are much smaller
(seven times smaller) per instance, but the majority carries seven times the
probability, so the two effects on the over-all risk exactly cancel. On the way to the optimum, changes
which benefit the majority at the minority's expense are easier to make than
vice versa (that 7:1 ratio again).
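
The two-category arithmetic can be checked with a small sketch. It assumes, purely for illustration, a scalar \( \theta \) and quadratic conditional risks \( \frac{1}{2}(\theta - \psi_z)^2 \); the per-category optima `psi1` and `psi2` are hypothetical numbers.

```python
def grad_quadratic(theta, optimum):
    """Gradient of the toy conditional risk 0.5 * (theta - optimum)**2."""
    return theta - optimum


p1, p2 = 7 / 8, 1 / 8
psi1, psi2 = 0.0, 1.0  # hypothetical per-category optima

# For quadratic risks the first-order condition p1*g1 + p2*g2 = 0 is solved
# by the probability-weighted mean of the per-category optima.
theta_star = p1 * psi1 + p2 * psi2

g1 = grad_quadratic(theta_star, psi1)  # majority's conditional gradient
g2 = grad_quadratic(theta_star, psi2)  # minority's conditional gradient

print(f"theta* = {theta_star}")
print(f"grad R_1 = {g1}, grad R_2 = {g2}, ratio = {g1 / g2:.4f}")
```

Here \( \theta^* = 0.125 \) sits one-eighth of the way from the majority's optimum to the minority's, and the majority's gradient (0.125) is minus one-seventh of the minority's (-0.875).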

With more than two groups, and picking out group \( k \) as our focus, \[ \nabla_{\theta} \Risk_k(\theta^*) = - \frac{\sum_{z \neq k}{p_z \nabla \Risk_z(\theta^*)}}{p_k} \]

Even if the numerator on the right-hand side is near zero, if \( p_k \) in the denominator is small then this can still end up implying a large gradient for the conditional risk of category \( k \), i.e., that group being far from its optimum. Of course if the sum in the numerator is not small this implies an even bigger gradient.
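
A sketch of the many-group identity, again under assumptions that are mine rather than part of the argument: scalar \( \theta \), quadratic conditional risks, Zipf-like category probabilities, and made-up per-category optima. It only checks that the right-hand side reproduces each category's own gradient at the over-all optimum, and shows how far rare categories can sit from their optima.

```python
def zipf_probs(num_categories, alpha):
    """Return p_k proportional to k**(-alpha) for k = 1..num_categories."""
    weights = [k ** (-alpha) for k in range(1, num_categories + 1)]
    total = sum(weights)
    return [w / total for w in weights]


# Hypothetical setup: 50 categories whose optima psi_z are 0, 1, ..., 49.
num_categories = 50
probs = zipf_probs(num_categories, alpha=1.0)
optima = [float(z) for z in range(num_categories)]

# With conditional risks 0.5 * (theta - psi_z)**2, the over-all optimum is
# the probability-weighted mean of the optima.
theta_star = sum(p * psi for p, psi in zip(probs, optima))
grads = [theta_star - psi for psi in optima]  # conditional gradients at theta*


def implied_gradient(k):
    """-(sum over z != k of p_z * grad R_z(theta*)) / p_k, with 0-based k."""
    numerator = sum(p * g
                    for z, (p, g) in enumerate(zip(probs, grads)) if z != k)
    return -numerator / probs[k]


for k in (0, 9, 49):
    print(f"category {k + 1:2d}: p={probs[k]:.4f}  "
          f"own gradient={grads[k]:8.3f}  implied={implied_gradient(k):8.3f}")
```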

Now it should be possible to say something (crudely) about how big the numerator is. Suppose that we could rule out category \( k \), setting \( p_k = 0 \), and the resulting optimal parameter value was not \( \theta^* \) but \( \theta^\prime \). Then we'd have \[ \sum_{z\neq k}{\frac{p_z}{1-p_k} \nabla \Risk_z(\theta^\prime)} = 0 \] (since \( \Prob{Z=z|Z\neq k} = \frac{p_z}{1-p_k} \) ). If \( p_k \) is small, then the difference between \( \theta^* \) and \( \theta^\prime \) should itself be of order \( p_k \); presumably I just need to make some Taylor expansions here.
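
One way the Taylor expansion might go, as a sketch under assumptions not made above: write \( \Risk_{-k}(\theta) \equiv \sum_{z \neq k}{\frac{p_z}{1-p_k} \Risk_z(\theta)} \) for the risk with category \( k \) excluded, suppose every conditional risk is twice differentiable, and suppose the Hessian \( \mathbf{H} \equiv \nabla^2 \Risk_{-k}(\theta^\prime) \) is invertible. (Both \( \Risk_{-k} \) and \( \mathbf{H} \) are notation introduced only for this sketch.) Since \( \Risk = (1-p_k)\Risk_{-k} + p_k \Risk_k \) and \( \nabla \Risk_{-k}(\theta^\prime) = 0 \), expanding the first-order condition at \( \theta^* \) around \( \theta^\prime \) and keeping leading terms gives \[ \begin{eqnarray} 0 & = & (1-p_k) \nabla \Risk_{-k}(\theta^*) + p_k \nabla \Risk_k(\theta^*)\\ & \approx & (1-p_k) \mathbf{H} (\theta^* - \theta^\prime) + p_k \nabla \Risk_k(\theta^\prime)\\ \theta^* - \theta^\prime & \approx & -\frac{p_k}{1-p_k} \mathbf{H}^{-1} \nabla \Risk_k(\theta^\prime) \end{eqnarray} \] so the shift would indeed be of order \( p_k \), with the neglected terms of order \( p_k^2 \). If this holds up, the numerator \( \sum_{z \neq k}{p_z \nabla \Risk_z(\theta^*)} \) is approximately \( -p_k \nabla \Risk_k(\theta^\prime) \), and the conditional gradient of category \( k \) at \( \theta^* \) is approximately what it would be at the optimum that ignores category \( k \) altogether.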

#### Small-group risk as a penalty on large-group risk

Think back to the two-group case for simplicity. The problem we're posing is \[ \min_{\theta}{p_1 \Risk_1(\theta) + p_2 \Risk_2(\theta)} \] or equivalently \[ \min_{\theta}{\Risk_1(\theta) + \frac{p_2}{p_1}\Risk_2(\theta)} \] Now this is exactly the same form as a penalized optimization problem where the main goal is to minimize \( \Risk_1(\theta) \), with a penalty term proportional to \( \Risk_2(\theta) \); the penalty strength factor is \( p_2/p_1 \). Ordinarily, if we're doing penalized optimization, we'd use something like cross-validation to set the strength of the penalty. Alternately, we might think of the penalty term as coming from using a Lagrange multiplier to enforce a constraint, which'd here be of the form \( \Risk_2(\theta) \leq c \), and the strength of the penalty would be set by the need to make the constraint hold. Here however the strength of the penalty is fixed by the relative size of the two groups; we can trace the Lagrange-multiplier reasoning backwards and work out the implied constraint on the conditional risk for group 2. (The implied \( c \) is a decreasing function of \( p_2/p_1 \).)

To sum up, the two-group problem is equivalent to "minimize the
risk for the majority, subject to the constraint that the minority risk not
be *too* big", with that constraint getting weaker and weaker as the
minority gets smaller and smaller. If there are not two groups but many, we
can recast the problem as minimizing the risk for the largest group,
with constraints on the risk of all the smaller groups, those constraints
weakening as the groups get smaller.

(Of course we could make a small group "group 1" and treat its conditional risk as the real objective function, but then it'd be subject to tight constraints on all the other conditional risks.)
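
A minimal sketch of the penalty reading, assuming (only for illustration) a scalar \( \theta \) and quadratic conditional risks with hypothetical optima `psi1` and `psi2`. It checks that the weighted and penalized forms share a minimizer, and reports the minority risk at that minimizer, which is the implied constraint level \( c \) in this toy setting.

```python
def minimizer(weight_1, weight_2, psi1, psi2):
    """Minimizer over t of
    weight_1 * 0.5 * (t - psi1)**2 + weight_2 * 0.5 * (t - psi2)**2."""
    return (weight_1 * psi1 + weight_2 * psi2) / (weight_1 + weight_2)


def minority_risk_at_optimum(p2, psi1=0.0, psi2=1.0):
    """Conditional risk of group 2 at the over-all optimum (toy quadratics)."""
    p1 = 1.0 - p2
    theta_weighted = minimizer(p1, p2, psi1, psi2)          # p1*R1 + p2*R2
    theta_penalized = minimizer(1.0, p2 / p1, psi1, psi2)   # R1 + (p2/p1)*R2
    # Dividing the objective by p1 should not move the minimizer.
    assert abs(theta_weighted - theta_penalized) < 1e-12
    return 0.5 * (theta_penalized - psi2) ** 2


for p2 in (0.5, 0.25, 0.125, 0.01):
    penalty = p2 / (1.0 - p2)
    print(f"p2={p2:5.3f}  penalty strength={penalty:6.4f}  "
          f"implied c={minority_risk_at_optimum(p2):.4f}")
```

In this toy case the implied \( c \) grows as the penalty strength \( p_2/p_1 \) shrinks, matching the claim that the constraint weakens as the minority gets smaller.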

References to come; there is a definite set of applications I have in mind here.

- To read:
- John Duchi, Tatsunori Hashimoto, Hongseok Namkoong, "Distributionally Robust Losses for Latent Covariate Mixtures", arxiv:2007.13982 [Doesn't seem to be quite what I have in mind, fortunately for me]
- Frederik Kunstner, Robin Yadav, Alan Milligan, Mark Schmidt, Alberto Bietti, "Heavy-Tailed Class Imbalance and Why Adam Outperforms Gradient Descent on Language Models", arxiv:2402.19449