## Symmetries of Neural Networks

*14 Aug 2023 13:41*

(First draft 2 April 2023)

If we think about a multi-layer neural network as a statistical model, or even just an input-output mapping, it's specified by giving (1) the "squashing" function used by each unit, (2) the weight each unit gives to the activation of each other unit, and (3) the "bias" term for each unit. (More explicitly, in symbols, \( A_i = \sigma(b_i + \sum_{j}{w_{ij} A_j}) \), using \( A_i \) for the activation of neuron \( i \).) We can get rid of (3) by adding one more unit whose activation is fixed to 1 all the time, and treat (1) as fixed (e.g., tanh or rectified-linear), so the model is specified by the weight matrix \( \mathbf{w} \).

A little thought should make it clear that this is a very redundant way of
specifying models; it's "over-parametrized" or, as statisticians would say,
un-identified. Consider just a three-layer feed-forward architecture, so
we have input vector \( X \), middle layer activation vector \( Z \), and
outputs \( Y \). We can swap any two units in the middle layer without
changing the mapping from \( X \) to \( Y \) at all. Said differently,
it doesn't *matter* what order we list the middle-layer units in.
So the input-output mapping is invariant under permuting the hidden units, i.e., under matched permutations of the weight
matrix; it's permutation-symmetric.
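A quick numerical check of the permutation symmetry, in Python with NumPy (the layer sizes here are arbitrary, chosen just for illustration): swapping hidden units means permuting the rows of the input-side weight matrix and, in matching fashion, the columns of the output-side one.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 4 inputs, 5 hidden units, 3 outputs
u = rng.normal(size=(5, 4))   # input-to-hidden weights
v = rng.normal(size=(3, 5))   # hidden-to-output weights
x = rng.normal(size=4)

def net(u, v, x):
    # Three-layer net with tanh squashing, as in the text
    return np.tanh(v @ np.tanh(u @ x))

# Swap hidden units 0 and 2: permute the *rows* of u and the
# matching *columns* of v with the same permutation matrix.
p = np.eye(5)[[2, 1, 0, 3, 4]]
u2, v2 = p @ u, v @ p.T

# Different weight matrices, identical input-output behavior
assert np.allclose(net(u, v, x), net(u2, v2, x))
```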

But there should also be more continuous symmetries. In fact, I think
there should be something like a version of the "rotation problem" from
factor models. Imagine for a moment that we
had a *linear* three-layer network,
so \( Z = \mathbf{u} X \) and \( Y = \mathbf{v} Z \). Pick
any invertible matrix \( \mathbf{r} \) and consider \( \mathbf{u}^{\prime} =
\mathbf{r}\mathbf{u} \) and \( \mathbf{v}^{\prime} = \mathbf{v}\mathbf{r}^{-1}
\). Clearly \( \mathbf{v}^{\prime} \mathbf{u}^{\prime} X = \mathbf{v}\mathbf{u}
X \), so these new weights would leave the input-output mapping unchanged. If
we restricted \( \mathbf{r} \) to being an orthogonal matrix, we'd even leave
(lots of) norms of the internal-layer activations unchanged, and we'd be really
close to the rotation problem. Now clearly nonlinearities will complicate
things (it's really \( Z = \sigma(\mathbf{u} X) \) and \( Y = \sigma(\mathbf{v} Z) \)), but this feels like a place to start...
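Again in NumPy, here's a check of the linear case, plus confirmation that a generic invertible \( \mathbf{r} \) stops working once a tanh nonlinearity goes in the middle:

```python
import numpy as np

rng = np.random.default_rng(1)
u = rng.normal(size=(5, 4))   # hypothetical layer sizes, as before
v = rng.normal(size=(3, 5))
x = rng.normal(size=4)

r = rng.normal(size=(5, 5))   # a generic square matrix is invertible
u2, v2 = r @ u, v @ np.linalg.inv(r)

# Linear network: the input-output map is unchanged
assert np.allclose(v @ (u @ x), v2 @ (u2 @ x))

# With tanh in the middle, a generic r *does* change the map
lhs = np.tanh(v @ np.tanh(u @ x))
rhs = np.tanh(v2 @ np.tanh(u2 @ x))
assert not np.allclose(lhs, rhs)
```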

**Partial identification**: Clearly, some weight matrices
*do* lead to different input-output mappings, so \( \mathbf{w} \) isn't
totally un-identified, just partially
identified. Say that two weight matrices are equivalent iff they produce
exactly the same input-output mapping, \( \mathbf{w} \sim \mathbf{w}^{\prime}
\). Then (in the usual partial-identification jargon) the maximal identifiable
parameter is the equivalence class \( [\mathbf{w}] \equiv \left\{ \mathbf{v} :
\mathbf{w} \sim \mathbf{v} \right\} \). A very natural question is then to
geometrically describe these equivalence classes. Permutation symmetry means
that there will be disconnected components (in the original geometry of the
weight matrices), but if there's anything like the rotation problem there will
also be continuous sets of equivalent weight matrices.
(Continuous latent space models for
networks lead to a similar structure of equivalence classes.) Can we say
anything, in general, about the dimension of these continuous sets? Can we describe the
maximally-identified parameter in a more direct way? (One possibility would
be to pick out some "canonical" member of each equivalence class, in such a way
that if \( \mathbf{w} \not \sim \mathbf{v} \) but \( \mathbf{w} \) and \(
\mathbf{v} \) are close, their canonical versions are also close. [Shades of locality-sensitive hashing...])

In fact what I'd expect is many copies (by permutation [and reflection?]) of one continuous set of equivalent weights. (This is a hunch, not a result.) This would of course help with constructing a canonical member of the equivalence class.
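Here is a toy Python sketch of such a canonicalization, handling only the discrete (permutation and sign-flip) symmetries of a tanh net; the particular scheme (sign-fix each hidden unit, then sort) is my own illustration, not a construction from the literature, and it ignores any continuous symmetries.

```python
import numpy as np

def canonicalize(u, v):
    """Pick a canonical member of the permutation-and-sign-flip
    equivalence class of a tanh network (u, v).  Hypothetical scheme:
    flip each hidden unit so its largest-magnitude incoming weight is
    positive (legal because tanh is odd), then sort hidden units
    lexicographically by their incoming-weight rows."""
    u, v = u.copy(), v.copy()
    # Sign-flip step: negate a hidden unit's incoming and outgoing
    # weights together, which leaves the input-output map unchanged
    for i in range(u.shape[0]):
        if u[i, np.argmax(np.abs(u[i]))] < 0:
            u[i] *= -1
            v[:, i] *= -1
    # Permutation step: lexicographic order on the rows of u
    # (np.lexsort treats its *last* key as primary, hence the reversal)
    order = np.lexsort(u.T[::-1])
    return u[order], v[:, order]

rng = np.random.default_rng(2)
u = rng.normal(size=(5, 4))
v = rng.normal(size=(3, 5))

# Build an equivalent network by permuting and sign-flipping units
p = np.eye(5)[[3, 0, 4, 1, 2]]
s = np.diag([1., -1., -1., 1., -1.])
u2, v2 = s @ p @ u, v @ p.T @ s

cu, cv = canonicalize(u, v)
cu2, cv2 = canonicalize(u2, v2)
assert np.allclose(cu, cu2) and np.allclose(cv, cv2)
```

Note that this canonical form is only locally continuous in the weights: ties in the sign-fixing or sorting steps create boundaries where the canonical member jumps, which is exactly the locality-sensitive-hashing flavor of the problem.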

**Interpretation**: If neural networks *do* have the
rotation problem (or something like it), I think that's a good reason to
*not* interpret specific units as representing specific features
of the data, even if they are highly activated by inputs with those features.
I realize this is common nowadays, but as an old coot who learned about
neural networks in the early '90s, *I* was taught they give us
distributed representations, so interpreting individual units has
always seemed grandmother-cell-ish to me.

Or, at the very least: you could say that in *this* neural net, unit
37 represents Nonna Anna's face, but also that there are completely equivalent
networks where no one unit does so, and the fact that we got that feature
represented so simply is just luck-of-the-draw, or rather
luck-of-the-optimizer-initialization.

**Optimization**: One puzzling aspect of neural networks
is *why* comparatively simple optimization approaches find pretty good
weights. Back-propagation, e.g., is just a (clever!) way of doing the
book-keeping for gradient descent; we know that gradient descent works poorly
for general non-convex objective functions; and the objective function for a
neural network, as a function of its weights, is highly non-convex. So one
thought here is that the landscape being searched over isn't as bad as it first
seems, because there's actually a lot of symmetry. The partial-identification
ideas might actually be useful here: confine the search to the subset of
"canonical" weight matrices, without any loss.

#### Ignorance

None of the above is particularly deep, so I am sure people have much more fully-worked-out thoughts about this in the literature, which I need to find and absorb. Shortly after starting to think about this, I chanced across Ainsworth, Hayase and Srinivasa (2022), which shows the importance of permutation symmetry for realistic neural network training, and contains pointers to papers from the 1990s which I need to follow up. (In particular the papers by Hecht-Nielsen and collaborators.)

**Update, 24 April 2023**: Since writing the above I have had a
chance to read Chen et al. (1993), and Kurkova and Kainen (1994). The former
considers groups of symmetries composed from exchanging interior nodes, and
sign flips on weights. (That is, you could switch the weights on an interior
node so it outputs \( -1 \times \) what it did before, provided you also
multiply all the weights on that node by \( -1 \).) They work out a pretty
complete theory of this group of discrete symmetries, though they allow some
narrow room for continuous symmetries that they don't really say much about.
Kurkova and Kainen claim to give the complete symmetry group, which is built up
out of node exchanges and a generalization of sign flips which I *think*
allows for something more rotation-like. I need to re-read the latter paper,
to really understand it, and all of these refer to some work by Sussmann which I
need to track down, or at least record here. (12 July 2023: I now have the paper by Sussmann, and some related work by Albertini and Sontag, references below, but I've not read them yet.) Both papers make the point that
these symmetries imply large neural networks have immense numbers of copies of
the optimal weight configuration, which can help explain how learning from
different starting points converges on networks with nearly-equal behavior.
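Just to put numbers on "immense": for a tanh hidden layer of width \( n \), the permutation-and-sign-flip symmetries give \( n! \, 2^n \) equivalent weight settings per layer. A quick Python sketch (multiplying the per-layer counts across layers is my extrapolation of the Chen-Lu-Hecht-Nielsen count, so treat the multi-layer figures as indicative):

```python
from math import factorial

def num_copies(hidden_sizes):
    # Number of discrete symmetries (hidden-unit permutations combined
    # with sign flips) of a tanh network: n! * 2**n for each hidden
    # layer of width n, multiplied across layers.
    total = 1
    for n in hidden_sizes:
        total *= factorial(n) * 2**n
    return total

print(num_copies([5]))    # 3840 equivalent weight settings
print(num_copies([100]))  # already astronomically large
```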

- See also:
- Factor Models
- Information Geometry
- Neural Networks, Connectionism, etc.
- Partial Identification

- Recommended:
- Samuel K. Ainsworth, Jonathan Hayase and Siddhartha Srinivasa, "Git Re-Basin: Merging Models modulo Permutation Symmetries", arxiv:2209.04836
- Rachel Carrington, Karthik Bharath, Simon Preston, "Invariance and identifiability issues for word embeddings", arxiv:1911.02656 [Not all of this is about neural networks, strictly speaking...]
- An Mei Chen, Haw-minn Lu and Robert Hecht-Nielsen, "On the Geometry of Feedforward Neural Network Error Surfaces", Neural Computation **5** (1993): 910--927
- Tolga Ergen, Mert Pilanci, "Global Optimality Beyond Two Layers: Training Deep ReLU Networks via Convex Programs", arxiv:2110.05518 [See below under Wang et al.]
- Charles Godfrey, Davis Brown, Tegan Emerson and Henry Kvinge, "On the Symmetries of Deep Learning Models and their Internal Representations", arxiv:2205.14258 [The bit at the end of the abstract about interpretation sounds exactly backwards to me (see above), but I need to re-think in light of their findings...]
- Robert Hecht-Nielsen, "On the Algebraic Structure of Feedforward Network Weight Spaces", pp. 129--135 of Rolf Eckmiller (ed.), Advanced Neural Computers (1990) [Thanks to reader R.V. for a copy of this paper!]
- Vera Kurkova and Paul C. Kainen, "Functionally Equivalent Feedforward Neural Networks", Neural Computation **6** (1994): 543--558
- Yifei Wang, Jonathan Lacotte, Mert Pilanci, "The Hidden Convex Optimization Landscape of Two-Layer ReLU Neural Networks: an Exact Characterization of the Optimal Solutions", arxiv:2006.05900 [While it is not the main focus of this paper, or Ergen and Pilanci (above), an important step in the reduction to a convex problem is first standardizing the neural network by changing weights in ways which leave the behavior unchanged, i.e., implicitly exploiting symmetries of ReLU networks.]

- To read (with thanks to R.V. for pointers and discussion):
- Francesca Albertini and Eduardo D. Sontag, "For neural networks, function determines form", Neural Networks **6** (1993): 975--990
- Marco Antonio Armenta, Pierre-Marc Jodoin, "The Representation Theory of Neural Networks", arxiv:2007.12213
- Héctor J. Sussmann, "Uniqueness of the weights for minimal feedforward nets with a given input-output map", Neural Networks **5** (1992): 589--594
- Soledad Villar, David W. Hogg, Weichi Yao, George A. Kevrekidis, Bernhard Schölkopf, "The passive symmetries of machine learning", arxiv:2301.13724
- Sumio Watanabe, Algebraic Geometry and Statistical Learning Theory
- Bo Zhao, Nima Dehmamy, Robin Walters, Rose Yu, "Symmetry Teleportation for Accelerated Optimization", arxiv:2205.10637
- Bo Zhao, Robert M. Gower, Robin Walters, Rose Yu, "Improving Convergence and Generalization Using Parameter Symmetries", arxiv:2305.13404
- Bo Zhao, Iordan Ganev, Robin Walters, Rose Yu, Nima Dehmamy, "Symmetries, flat minima, and the conserved quantities of gradient flow", arxiv:2210.17216