is a relatively new sub-field in AI, focused on understanding how neural networks function by reverse-engineering their internal mechanisms and representations, with the aim of translating them into human-understandable algorithms and concepts. This goes further than traditional explainability techniques such as SHAP and LIME.

SHAP stands for SHapley Additive exPlanations. It computes the contribution of each feature to the model’s prediction both locally (for a single example) and globally (aggregated across the whole dataset), which makes it useful for estimating overall feature importance for a use case. LIME (Local Interpretable Model-agnostic Explanations), meanwhile, works on a single example-prediction pair: it perturbs the input, records the model’s outputs on those perturbations, and fits a simpler, interpretable surrogate that approximates the black-box model in the neighborhood of that example. Both techniques therefore work at the feature level, giving us an explanation, or at least a heuristic, for how each input to the model affects its prediction.
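To make the Shapley idea concrete, here is a brute-force sketch. It assumes a hypothetical three-feature toy model and a baseline of zeros: a feature that is "absent" from a coalition is replaced by its baseline value, which is one common convention among several. Each feature’s value is its marginal contribution averaged over all coalitions of the other features, weighted by |S|!(n − |S| − 1)!/n!. The `shap` library itself relies on much faster approximations; exact enumeration like this is only feasible for a handful of features.

```python
from itertools import combinations
from math import factorial


def shapley_values(model, x, baseline):
    """Exact Shapley values by enumerating every coalition of the other features.

    Features outside a coalition are set to their baseline value (an assumption;
    other conventions, e.g. averaging over a background dataset, also exist).
    """
    n = len(x)

    def value(coalition):
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return model(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):  # coalition sizes 0 .. n-1
            for subset in combinations(others, size):
                s = set(subset)
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


def toy_model(z):
    # Hypothetical model: two linear terms plus an interaction between features 0 and 2
    return 2 * z[0] + 3 * z[1] + z[0] * z[2]


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)
print(phi)
```

Here the linear terms go entirely to their own features, while the interaction term (1 x 3 = 3) is split evenly between features 0 and 2, so the values come out at about 3.5, 6.0 and 1.5. They sum to toy_model(x) − toy_model(baseline) = 11, which is the "additive" property in the name.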

Mechanistic interpretability, on the other hand, works at a more granular level: it can trace the pathway by which a given feature is picked up by different neurons in different layers of the network, and how that representation evolves from layer to layer. This makes it well suited to tracing paths inside the network for a particular feature and to seeing how that feature affects the outcome.

SHAP and LIME, then, answer the question “which feature contributes the most to the outcome?” whereas mechanistic interpretability answers the question “which neurons activate for which feature, and how does that feature evolve and affect the outcome of the network?”
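As a very small illustration of the second question, the sketch below records every hidden layer’s activations in a hypothetical two-hidden-layer NumPy network with random weights, and ranks neurons by how much their mean activation changes when one input feature is switched on. This is a simple activation-difference probe on made-up weights and data, not a full circuit analysis.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy network with random weights: 4 inputs -> 8 hidden -> 8 hidden -> 1 output
W1 = rng.normal(size=(4, 8))
W2 = rng.normal(size=(8, 8))
W3 = rng.normal(size=(8, 1))


def forward_with_cache(x):
    """Run the network and keep every hidden layer's ReLU activations."""
    cache = {}
    cache["layer1"] = np.maximum(0.0, x @ W1)
    cache["layer2"] = np.maximum(0.0, cache["layer1"] @ W2)
    out = cache["layer2"] @ W3
    return out, cache


# Two batches identical except for feature 0, which is "off" (0) in one and "on" (1) in the other
x_off = rng.normal(size=(100, 4))
x_off[:, 0] = 0.0
x_on = x_off.copy()
x_on[:, 0] = 1.0

_, cache_off = forward_with_cache(x_off)
_, cache_on = forward_with_cache(x_on)

for layer in ("layer1", "layer2"):
    # Change in each neuron's mean activation when the feature is switched on
    diff = cache_on[layer].mean(axis=0) - cache_off[layer].mean(axis=0)
    top = np.argsort(-np.abs(diff))[:3]
    print(layer, "neurons most sensitive to feature 0:", top)
```

In the first hidden layer the effect is easy to read off the weights: neurons with a positive weight from feature 0 can only become more active. Deeper layers mix many inputs, which is why mechanistic interpretability needs more careful tools than this to follow a feature through the network.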

Since explainability becomes harder as networks get deeper, this sub-field mostly works with deep models such as transformers. There are a few places where mechanistic interpretability views transformers differently from the traditional way, and one of them is multi-head attention. As we will see, the difference lies in reframing the concatenation and multiplication operations defined in the “Attention Is All You Need” paper as per-head multiplications followed by an addition, which opens up a whole range of new possibilities.

But first, a recap of the Transformer architecture.

Transformer Architecture

Image by Author: Transformer Architecture

These are the sizes we work with:

  • batch_size B = 1
  • sequence length S = 20
  • vocab_size V = 50,000
  • hidden_dims D = 512
  • heads H = 8

This means that each head’s Q, K and V vectors have L = D/H = 512/8 = 64 dimensions. (In case you don’t remember, here is an analogy for query, key and value: the query of a token describes what that token is looking for, the keys of all tokens are compared against that query to score how relevant each position is (the alignment, or reweighting), and the values are the content that gets mixed together according to those scores.)
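The analogy can be checked on a tiny worked example. The sketch below uses made-up vectors with a head dimension of 4 and three tokens: one query is scored against every key, the scaled scores are turned into weights with a softmax, and those weights mix the values.

```python
import numpy as np


def softmax(z):
    z = z - z.max()            # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


# Made-up example: head dimension 4, three tokens
q = np.array([1.0, 0.0, 1.0, 0.0])        # what the current token is looking for
K = np.array([[1.0, 0.0, 1.0, 0.0],        # key of token 0: matches the query well
              [0.0, 1.0, 0.0, 1.0],        # key of token 1: orthogonal to the query
              [0.5, 0.5, 0.5, 0.5]])       # key of token 2: partial match
V = np.array([[10.0, 0.0],                 # value (content) of each token
              [0.0, 10.0],
              [5.0, 5.0]])

scores = K @ q / np.sqrt(q.shape[0])       # alignment of the query with every key
weights = softmax(scores)                  # reweighting over positions; sums to 1
output = weights @ V                       # weighted mix of the values
print(weights.round(3), output.round(3))
```

Token 0, whose key matches the query best, receives the largest weight, so the output is pulled towards its value.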

These are the steps up to the attention computation in a transformer. (The tensor shapes are example values chosen for better understanding. In each multiplication, the dimension shared by the two inputs is the one the matrices are multiplied along.)

| Step | Operation | Input 1 Dims (Shape) | Input 2 Dims (Shape) | Output Dims (Shape) |
|---|---|---|---|---|
| 1 | One-hot input | B x S x V (1 x 20 x 50,000) | N/A | B x S x V (1 x 20 x 50,000) |
| 2 | Get embeddings | B x S x V (1 x 20 x 50,000) | V x D (50,000 x 512) | B x S x D (1 x 20 x 512) |
| 3 | Add positional embeddings | B x S x D (1 x 20 x 512) | N/A | B x S x D (1 x 20 x 512) |
| 4 | Copy embeddings to Q, K, V | B x S x D (1 x 20 x 512) | N/A | B x S x D (1 x 20 x 512) |
| 5 | Linear transform for each head (H = 8) | B x S x D (1 x 20 x 512) | D x L (512 x 64) | B x H x S x L (1 x 1 x 20 x 64 per head) |
| 6 | Scaled Dot Product (Q@K’) in each head | B x H x S x L (1 x 1 x 20 x 64) | B x H x L x S (1 x 1 x 64 x 20) | B x H x S x S (1 x 1 x 20 x 20) |
| 7 | Scaled Dot Product (attention calculation) Q@K’V in each head | B x H x S x S (1 x 1 x 20 x 20) | B x H x S x L (1 x 1 x 20 x 64) | B x H x S x L (1 x 1 x 20 x 64) |
| 8 | Concat across all heads (H = 8) | B x H x S x L (1 x 1 x 20 x 64 per head) | N/A | B x S x D (1 x 20 x 512) |
| 9 | Linear projection | B x S x D (1 x 20 x 512) | D x D (512 x 512) | B x S x D (1 x 20 x 512) |
Tabular view of shape transformations towards attention computation in the Transformer

The table explained in detail:

  1. We start with one input sentence of sequence length 20, in which each token is one-hot encoded over the vocabulary. Shape (B x S x V): (1 x 20 x 50,000)
  2. We multiply this input with the learnable embedding matrix Wₑ of shape (V x D) to get the embeddings. Shape (B x S x D): (1 x 20 x 512)
  3. Next, a learnable positional encoding matrix of the same shape is added to the embeddings.
  4. The resulting embeddings are then copied into the matrices Q, K and V, each still of shape (B x S x D): (1 x 20 x 512). The split into heads happens in the next step.
  5. Q, K and V are each fed to a linear transformation layer that multiplies them with the learnable weight matrices Wq, Wₖ and Wᵥ, respectively, each of shape (D x L), with a separate matrix for each of the H = 8 heads. Shape (B x H x S x L): (1 x 1 x 20 x 64), where H = 1 because this is the resulting shape for a single head; stacking all heads gives (1 x 8 x 20 x 64).
  6. Next, we compute scaled dot-product attention, where Q and the transpose of K are multiplied first in each head. Only the last two dimensions of K are transposed. Shape (B x H x S x L) x (B x H x L x S) → (B x H x S x S): (1 x 1 x 20 x 20)
  7. Next come a scaling step (dividing the scores by √L), an optional masking step and a softmax over the last dimension. I skip their details because they are not important for understanding the different way of looking at MHA. We then multiply the resulting attention weights with V in each head. Shape (B x H x S x S) x (B x H x S x L) → (B x H x S x L): (1 x 1 x 20 x 64)
  8. Concat: we concatenate the attention outputs of all heads along the L dimension to get back a shape of (B x S x D): (1 x 20 x 512)
  9. This output is projected once more with yet another learnable weight matrix Wₒ of shape (D x D). The final shape is (B x S x D): (1 x 20 x 512)
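The shape walkthrough above can be sketched end to end in NumPy. This is a minimal, randomly initialized illustration, not a trained model. Three assumptions: the vocabulary is shrunk from 50,000 to 1,000 to keep memory small, the per-head weights are stored as one (H x D x L) array, and the scaling, causal masking and softmax steps that the list skips are included so the code is a complete attention computation. The causal mask assumes a decoder-style model; drop it for an encoder.

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, V, D, H = 1, 20, 1000, 512, 8   # V shrunk from 50,000 (assumption) to keep memory small
L = D // H                            # per-head dimension: 512 / 8 = 64


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


# Step 1: one-hot encode random token ids -> (B, S, V)
token_ids = rng.integers(0, V, size=(B, S))
x = np.eye(V)[token_ids]

# Step 2: embeddings, (B, S, V) @ (V, D) -> (B, S, D)
W_e = rng.normal(scale=0.02, size=(V, D))
emb = x @ W_e

# Step 3: add a positional embedding of shape (S, D) (learnable in a real model, random here)
P = rng.normal(scale=0.02, size=(S, D))
emb = emb + P

# Steps 4-5: Q, K and V all start from the same embeddings; each head has its own (D x L) matrix
W_q, W_k, W_v = (rng.normal(scale=0.02, size=(H, D, L)) for _ in range(3))
Q = np.einsum("bsd,hdl->bhsl", emb, W_q)    # (B, H, S, L)
K = np.einsum("bsd,hdl->bhsl", emb, W_k)
Vh = np.einsum("bsd,hdl->bhsl", emb, W_v)   # named Vh to avoid clashing with vocab size V

# Step 6: (B, H, S, L) @ (B, H, L, S) -> (B, H, S, S), scaled by sqrt(L)
scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(L)

# Optional causal mask (decoder-style assumption): position i may not attend to j > i
mask = np.triu(np.ones((S, S), dtype=bool), k=1)
scores = np.where(mask, -np.inf, scores)

# Step 7: softmax over the key positions, then mix the values -> (B, H, S, L)
weights = softmax(scores, axis=-1)
attn = weights @ Vh

# Step 8: concatenate heads along the feature axis -> (B, S, H * L) = (B, S, D)
concat = attn.transpose(0, 2, 1, 3).reshape(B, S, D)

# Step 9: output projection with W_o of shape (D, D)
W_o = rng.normal(scale=0.02, size=(D, D))
out = concat @ W_o
print(out.shape)
```

Multiplying a one-hot matrix by Wₑ simply selects rows of Wₑ, which is why real implementations use an index lookup instead of the explicit multiplication in step 2. Likewise, many libraries store the per-head matrices as a single (D x D) matrix and reshape its output into heads; splitting its columns into H blocks of L gives the same result as the separate (D x L) matrices used here.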

Reimagining Multi-Head Attention

Image by Author: Reimagining Multi-head attention

Now, let’s see how the field of mechanistic interpretability looks at this, and why the two views are mathematically equivalent. On the right of the image above is the module that reimagines multi-head attention.

Instead of concatenating the attention outputs, we carry out the multiplication “inside” the heads themselves. Each head now gets its own slice of Wₒ with shape (L x D), which multiplies that head’s attention output of shape (B x H x S x L) to give a result of shape (B x S x H x D): (1 x 20 x 1 x 512) per head. Then we sum over the H dimension to end up again with the shape (B x S x D): (1 x 20 x 512).

Compared with the table above, only the last two steps change:

| Step | Operation | Input 1 Dims (Shape) | Input 2 Dims (Shape) | Output Dims (Shape) |
|---|---|---|---|---|
| 8 | Matrix multiplication in each head (H = 8) | B x H x S x L (1 x 1 x 20 x 64) | L x D (64 x 512) | B x S x H x D (1 x 20 x 1 x 512 per head) |
| 9 | Sum over heads (H dimension) | B x S x H x D (1 x 20 x 1 x 512 per head) | N/A | B x S x D (1 x 20 x 512) |
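The two reframed steps map naturally onto `np.einsum`. In this sketch the per-head attention outputs are random stand-ins, and each head’s (L x D) slice of Wₒ is stored in one hypothetical (H x L x D) array. With all eight heads stacked, the intermediate result has shape (1 x 20 x 8 x 512) before the sum over heads.

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, H, L = 1, 20, 8, 64
D = H * L  # 512

attn = rng.normal(size=(B, H, S, L))     # random stand-in for each head's attention output
W_o_heads = rng.normal(size=(H, L, D))   # hypothetical layout: one (L x D) slice of W_o per head

# New step 8: project inside each head, (B, H, S, L) x (H, L, D) -> (B, S, H, D)
per_head = np.einsum("bhsl,hld->bshd", attn, W_o_heads)

# New step 9: sum the heads' contributions over the H axis -> (B, S, D)
out = per_head.sum(axis=2)
print(per_head.shape, out.shape)
```

Keeping `per_head` around before the sum is what makes this view useful for interpretability: each head’s contribution to the output is available as a separate tensor.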

Side note: this “summing over” is reminiscent of how convolutional layers in CNNs sum over channels. Each filter operates on every input channel, and the per-channel results are summed into a single output. The same happens here: each head can be seen as a channel, and the model learns a weight matrix that maps each head’s contribution into the final output space.

But why is project + sum mathematically equivalent to concat + project? In short, because the per-head projection weights in the mechanistic view are just slices of the weight matrix in the traditional view: Wₒ (D x D) is split along its rows (its input dimension) into H blocks of shape (L x D), one per head.
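A quick numerical check of this claim, as a sketch with random weights: reshaping a (D x D) Wₒ into H blocks of shape (L x D) gives the first head rows 0 to 63, the second head rows 64 to 127, and so on (0-based, as NumPy indexes). Projecting each head with its block and then summing should match concatenating and projecting, up to floating-point error.

```python
import numpy as np

rng = np.random.default_rng(1)
B, S, H, L = 1, 20, 8, 64
D = H * L

attn = rng.normal(size=(B, H, S, L))   # per-head attention outputs (random stand-ins)
W_o = rng.normal(size=(D, D))          # the usual (D x D) output projection

# Traditional view: concatenate heads along the feature axis, then project
concat = attn.transpose(0, 2, 1, 3).reshape(B, S, D)
out_concat = concat @ W_o

# Mechanistic view: slice W_o's rows into H blocks of shape (L x D), project per head, sum
W_o_heads = W_o.reshape(H, L, D)       # block k holds rows k*L ... (k+1)*L - 1 of W_o
out_sum = np.einsum("bhsl,hld->bsd", attn, W_o_heads)

print(np.allclose(out_concat, out_sum))
```

The reshape works because concatenation places feature j of head k at position k·L + j, which is exactly the row of Wₒ that `W_o.reshape(H, L, D)[k, j]` picks out.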

Image by Author: Why the re-imagining works

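Let’s focus on the H and D dimensions before the multiplication with Wₒ. From the image above, each head now has a vector of size 64 that is multiplied with a weight matrix of shape (64 x 512). Let’s denote the result by $R$ and the head output by $h$: $h^{(k)}$ is the output of head $k$, $h$ without a superscript is the concatenated output (so $h_{1,64(k-1)+j} = h^{(k)}_{1,j}$), and $(W_O)_{i,j}$ is the entry of Wₒ in row $i$ and column $j$. The first subscript of $h$ and $R$ is the token position.

For the first head, the first entry of its projected output is:

$$R^{(1)}_{1,1} = h^{(1)}_{1,1}\,(W_O)_{1,1} + h^{(1)}_{1,2}\,(W_O)_{2,1} + \dots + h^{(1)}_{1,64}\,(W_O)_{64,1}$$

Now suppose we had concatenated the heads to get an attention output of shape (1 x 512) and used the full weight matrix of shape (512 x 512). Then the equation would have been:

$$R_{1,1} = h_{1,1}\,(W_O)_{1,1} + h_{1,2}\,(W_O)_{2,1} + \dots + h_{1,512}\,(W_O)_{512,1}$$

So, compared with the first head alone, the terms $h_{1,65}\,(W_O)_{65,1} + \dots + h_{1,512}\,(W_O)_{512,1}$ would have been added. But these are exactly the terms that the other heads contribute, in blocks of 64: rows 65 to 128 of Wₒ form the weights of the second head, rows 129 to 192 those of the third head, and so on. Said another way, if we imagine the weight slices of the heads sitting behind one another, $(W_O)_{65,1}$ is the value behind $(W_O)_{1,1}$ in the second head, $(W_O)_{129,1}$ is the value behind it in the third head, and so on. Hence, even without concatenation, “summing over the heads” adds up exactly the same terms:

$$R_{1,1} = \sum_{k=1}^{8} R^{(k)}_{1,1} = \sum_{k=1}^{8} \sum_{j=1}^{64} h^{(k)}_{1,j}\,(W_O)_{64(k-1)+j,\,1}$$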
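The same argument works for every entry at once. Writing $h^{(k)}$ for the (S x L) output of head $k$ and $W_O^{(k)}$ for the (L x D) block formed by rows $(k-1)L+1$ to $kL$ of $W_O$, concatenate-then-project is a block-matrix product that expands into a sum of per-head products:

$$
\operatorname{Concat}\left(h^{(1)}, \dots, h^{(H)}\right) W_O
= \begin{bmatrix} h^{(1)} & \cdots & h^{(H)} \end{bmatrix}
\begin{bmatrix} W_O^{(1)} \\ \vdots \\ W_O^{(H)} \end{bmatrix}
= \sum_{k=1}^{H} h^{(k)} W_O^{(k)}
$$

With S = 20, L = 64, H = 8 and D = 512, the left side is a (20 x 512)(512 x 512) product, each term on the right is a (20 x 64)(64 x 512) product, and both sides have shape (20 x 512).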

In conclusion, this insight lays the foundation for viewing transformers as additive models, in the sense that every operation in a transformer takes the initial embedding and adds to it. This view opens up new possibilities, such as tracing features as they are built up through additions across the layers (called circuit tracing), which is what mechanistic interpretability is about, as I will show in my next articles.
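The additive view can be sketched with a toy residual stream. Assuming a pre-LayerNorm block layout (as in GPT-2-style models), each sublayer reads the stream through a layer norm and its output is added back, so the final stream decomposes exactly into the initial embedding plus every sublayer’s write. The sublayers below are hypothetical tanh maps with random weights standing in for attention and MLP blocks. In the original post-LayerNorm Transformer, the normalization sits after each addition, so this clean decomposition does not hold as directly.

```python
import numpy as np

rng = np.random.default_rng(0)
S, D, n_layers = 5, 16, 3


def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


# Hypothetical sublayers (two per layer, standing in for attention and MLP):
# each reads the stream through a layer norm and returns an update of the same shape.
sublayer_weights = [rng.normal(scale=0.1, size=(D, D)) for _ in range(2 * n_layers)]


def sublayer(x, W):
    return np.tanh(layer_norm(x) @ W)


x0 = rng.normal(size=(S, D))   # initial token + positional embeddings
x = x0
writes = []
for W in sublayer_weights:
    delta = sublayer(x, W)     # what this sublayer writes into the residual stream
    writes.append(delta)
    x = x + delta              # pre-LN update: the stream is only ever added to

# The final stream equals the embedding plus the sum of all sublayer writes
print(np.allclose(x, x0 + sum(writes)))
```

Each `delta` can then be studied on its own, which is the starting point for attributing the output to individual components. By the project-and-sum argument in this article, an attention sublayer’s write can be split further into one write per head.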


We have shown that this view is mathematically equivalent to the very different, traditional view of multi-head attention, in which splitting Q, K and V across heads parallelizes and optimizes the attention computation. Read more about this in this blog here, and the paper that introduces these points is here.
