λmem.ac

Attention Is All You Need (and So Are You)

Introduction

Transformers replaced recurrence with a single, surprisingly simple idea: let every token decide, for itself, which other tokens matter. That decision is attention — and once you see it as a weighted average, the rest of the architecture quietly falls into place.

In this post we build scaled dot-product attention from the ground up, add multiple heads, and poke at why the $\sqrt{d_k}$ term keeps everything numerically calm.

Note: Attention was introduced for translation [3] long before Transformers; the 2017 paper’s trick [1] was to throw away recurrence entirely and keep only attention.

Scaled dot-product attention

Given queries $Q$, keys $K$ and values $V$, attention scores each query against every key, normalises with a softmax, and uses the result to mix the values:

$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

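Here $d_k$ is the dimension of the keys. The raw dot products grow with $d_k$: if the components of a query and a key are independent with mean 0 and variance 1, their dot product has variance $d_k$, so its typical magnitude is $\sqrt{d_k}$. Dividing by $\sqrt{d_k}$ brings the logits back to unit variance, which keeps them in a range where the softmax does not saturate.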
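A minimal NumPy sketch of the formula, assuming unbatched 2-D inputs: `Q` is $n \times d_k$, `K` is $m \times d_k$ and `V` is $m \times d_v$ (in self-attention $m = n$). The function names and the random toy inputs are illustrative, not taken from any library.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Softmax along `axis`, shifted by the max for numerical safety."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Return (output, weights) for softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n, m) logits, scaled before the softmax
    weights = softmax(scores, axis=-1)   # each row is a distribution over the m keys
    return weights @ V, weights          # (n, d_v) weighted mix of the values

rng = np.random.default_rng(0)
n, m, d_k, d_v = 3, 5, 4, 2              # hypothetical toy sizes
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((m, d_k))
V = rng.standard_normal((m, d_v))
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape, weights.sum(axis=-1))   # (3, 2) and row sums equal to 1 up to rounding
```

Returning the weights alongside the output is a convenience for inspecting which keys each query attends to.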

Tip: In code, scale before the softmax, not after. Scaling the probabilities instead of the logits does nothing to stop the softmax saturating (so its gradients still vanish), and it leaves attention weights that no longer sum to 1.
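To see why the order matters, the sketch below draws random unit-variance queries and keys (the same assumption as the variance argument above) with a hypothetical $d_k = 512$, then compares scaling the logits with scaling the probabilities.

```python
import numpy as np

def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d_k, n = 512, 16
Q = rng.standard_normal((n, d_k))
K = rng.standard_normal((n, d_k))

raw = Q @ K.T                            # entries have variance ~ d_k
print(raw.std(), np.sqrt(d_k))           # sample std, comparable to sqrt(512) ~ 22.6

saturated = softmax(raw)                 # no scaling: rows are typically close to one-hot
right = softmax(raw / np.sqrt(d_k))      # scale the logits: a much softer distribution
wrong = saturated / np.sqrt(d_k)         # scale the probabilities: just as peaked, and
                                         # every row now sums to 1/sqrt(d_k) instead of 1
print(saturated.max(axis=-1).mean(), right.max(axis=-1).mean())
print(wrong.sum(axis=-1)[:3])
```

Dividing the logits by a positive constant acts like raising the softmax temperature, so each row's largest weight can only shrink; dividing the probabilities afterwards just rescales an already saturated distribution.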

[Figure: attention as a soft lookup table]
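The lookup analogy can be made concrete with a toy example (all keys, values and the `sharpness` knob are hypothetical): a dictionary returns exactly one value, while attention returns a softmax-weighted blend that approaches the hard lookup as the matching key's score dominates.

```python
import numpy as np

def soft_lookup(query, keys, values, sharpness=1.0):
    """Blend `values` with weights softmax(sharpness * keys @ query)."""
    scores = sharpness * keys @ query
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ values

keys = np.eye(3)                          # three orthogonal keys
values = np.array([10.0, 20.0, 60.0])     # one scalar value per key
query = np.array([0.0, 1.0, 0.0])         # matches the second key exactly

blended = soft_lookup(query, keys, values, sharpness=1.0)   # about 26.4: pulled away from 20
hard = soft_lookup(query, keys, values, sharpness=50.0)     # essentially 20: a hard lookup
print(blended, hard)
```

Because every step is differentiable, gradients can flow into the query and keys, which is what a hard dictionary lookup cannot offer.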

Multi-head attention

One attention map can only emphasise one kind of relationship. Multi-head attention runs several in parallel, each with its own learned projections, then concatenates them:

$$\mathrm{MultiHead}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_h)\,W^{O}$$

where each $\mathrm{head}_i=\mathrm{Attention}(QW_i^{Q},\,KW_i^{K},\,VW_i^{V})$ learns to ask a different question of the same sentence.
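A minimal NumPy sketch of the multi-head formula, assuming $h$ heads of width $d_k = d_v = d_{\text{model}}/h$ and randomly initialised projections (a trained model learns them). The function and parameter names are hypothetical; PyTorch ships a production version as `torch.nn.MultiheadAttention` [4].

```python
import numpy as np

def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def multi_head(X_q, X_k, X_v, W_q, W_k, W_v, W_o):
    """W_q, W_k, W_v: lists of per-head (d_model, d_head) matrices; W_o: (h * d_head, d_model)."""
    heads = [attention(X_q @ wq, X_k @ wk, X_v @ wv)     # head_i, shape (n, d_head)
             for wq, wk, wv in zip(W_q, W_k, W_v)]
    return np.concatenate(heads, axis=-1) @ W_o          # Concat(...) W^O, shape (n, d_model)

rng = np.random.default_rng(0)
n, d_model, h = 5, 16, 4                                 # hypothetical toy sizes
d_head = d_model // h
X = rng.standard_normal((n, d_model))                    # self-attention: Q = K = V = X
W_q = [rng.standard_normal((d_model, d_head)) for _ in range(h)]
W_k = [rng.standard_normal((d_model, d_head)) for _ in range(h)]
W_v = [rng.standard_normal((d_model, d_head)) for _ in range(h)]
W_o = rng.standard_normal((h * d_head, d_model))
out = multi_head(X, X, X, W_q, W_k, W_v, W_o)
print(out.shape)  # (5, 16)
```

Splitting $d_{\text{model}}$ across heads keeps the total cost close to that of one full-width head, which is why the per-head width shrinks as $h$ grows.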

Think of each head as a different lens: one tracks syntax, another long-range topic, another simple position. Concatenation lets the next layer use them all.

Why does it work?


Self-attention costs $O(n^2 d)$ for sequence length $n$ and representation width $d$, because every token looks at every other token. That quadratic cost buys something precious: a constant path length between any two positions, so gradients flow without the $O(n)$-step detours an RNN forces.
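The quadratic term is easy to count. This sketch tallies multiply-accumulates for the two matrix products of a single head and deliberately ignores the projections and the softmax (a simplifying assumption); $d_k = d_v = 64$ is the per-head width used in [1].

```python
def attention_macs(n: int, d_k: int, d_v: int) -> int:
    """Multiply-accumulates for Q K^T (n*n*d_k) plus weights @ V (n*n*d_v), one head."""
    return n * n * d_k + n * n * d_v

for n in (512, 1024, 2048):
    print(n, attention_macs(n, 64, 64))
# 512 33554432
# 1024 134217728
# 2048 536870912
```

Each doubling of $n$ quadruples the count, while doubling the head width only doubles it.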

Because the graph is fully connected, information from the first token can reach the last in a single layer. Depth then stops being about reach and becomes about refinement: each block re-weights the same global context a little more sharply.
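A quick numerical check of the one-layer reach claim: nudge only the first token of a toy sequence and see whether a single self-attention layer's output at the last position moves. The setup (random inputs, $Q = K = V = X$ with no learned projections) is a simplifying assumption.

```python
import numpy as np

def self_attention(X):
    """One unprojected self-attention layer: softmax(X X^T / sqrt(d)) X."""
    scores = X @ X.T / np.sqrt(X.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ X

rng = np.random.default_rng(0)
X = rng.standard_normal((10, 8))           # 10 tokens, width 8
X_perturbed = X.copy()
X_perturbed[0] += 1.0                      # change only the first token

delta = np.abs(self_attention(X_perturbed)[-1] - self_attention(X)[-1]).max()
print(delta > 0)   # True: the last position already sees the first after one layer
```

An RNN would need nine recurrent steps for the same change to reach position 10.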

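```python
import math, torch
Q, K, V = torch.randn(3, 4, 8)  # illustrative toy input: n = 4 tokens, d_k = d_v = 8
d_k = Q.size(-1)
scores  = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)  # scale the logits, not the probabilities
weights = scores.softmax(dim=-1)                      # each row sums to 1
out     = weights @ V          # the soft lookup
```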
Warning: That quadratic memory is real. A 4k-token sequence needs a 16M-entry attention matrix per head (4096 × 4096 = 16,777,216 entries, or 64 MiB in 32-bit floats), for every sequence in the batch. Watch your VRAM.
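The arithmetic behind that warning, as a small helper. It counts only the $n \times n$ matrices of one layer for one sequence, assumes 4-byte fp32 entries by default, and the 16-head figure is a hypothetical configuration.

```python
def attention_matrix_bytes(n_tokens: int, n_heads: int = 1, bytes_per_entry: int = 4) -> int:
    """Bytes for the n x n attention matrices of one layer, one sequence."""
    return n_tokens * n_tokens * n_heads * bytes_per_entry

print(attention_matrix_bytes(4096))                      # 67108864 bytes = 64 MiB (fp32, one head)
print(attention_matrix_bytes(4096, n_heads=16) / 2**30)  # 1.0 GiB for 16 heads
```

Halving the precision (2-byte entries) halves these numbers, but the growth with sequence length stays quadratic.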

Danger: Never feed unmasked future tokens to a decoder at train time. The model will “cheat” by attending ahead, then fail at inference, where those future tokens are not yet available.
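One common way to apply the mask, sketched in NumPy: set the logits for future positions to $-\infty$ before the softmax so they receive exactly zero weight. The function name and toy data are hypothetical.

```python
import numpy as np

def causal_attention(Q, K, V):
    """Scaled dot-product attention where position i may only attend to positions <= i."""
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)   # True strictly above the diagonal
    scores = np.where(future, -np.inf, scores)           # exp(-inf) = 0: no weight on the future
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))  # diagonal keeps each row max finite
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V, w

rng = np.random.default_rng(0)
X = rng.standard_normal((6, 4))
out, w = causal_attention(X, X, X)

X_changed = X.copy()
X_changed[4:] += 5.0                                     # alter only tokens 4 and 5
out_changed, _ = causal_attention(X_changed, X_changed, X_changed)
print(np.allclose(out[:4], out_changed[:4]))             # True: earlier positions never saw them
```

The last check is the property the warning is about: changing the future leaves every earlier output untouched, so training conditions match inference.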

A note on numerical stability

Softmax is shift-invariant, so subtracting the row maximum before exponentiating changes nothing mathematically but everything numerically: $\mathrm{softmax}(x)_i=\mathrm{softmax}(x-\max_j x_j)_i$. After the shift the largest exponent is $e^0 = 1$, so nothing can overflow and the denominator is at least 1. Every sane implementation does this for free.
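A small demonstration with deliberately large, hypothetical logits: the naive version overflows, while the shifted version returns the same probabilities as `softmax([0, 1, 2])`.

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)
    return e / e.sum()

def stable_softmax(x):
    e = np.exp(x - x.max())     # largest exponent is exp(0) = 1, so nothing overflows
    return e / e.sum()

x = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):   # silence the expected overflow warnings
    print(naive_softmax(x))                          # [nan nan nan]: exp(1000) is inf, inf/inf is nan
print(stable_softmax(x))                             # finite, identical to softmax([0, 1, 2])
```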

| Symbol | Meaning | Shape |
|---|---|---|
| $Q$ | queries | $n \times d_k$ |
| $K$ | keys | $n \times d_k$ |
| $V$ | values | $n \times d_v$ |

Conclusion

Attention is just a learned, differentiable lookup. Everything else — heads, scaling, positional encodings — is bookkeeping around that one idea. Next time we’ll add the positions back in. ✦


References

[1] Vaswani et al., "Attention Is All You Need" (2017). https://arxiv.org/abs/1706.03762
[2] Alammar, "The Illustrated Transformer". https://jalammar.github.io/illustrated-transformer/
[3] Bahdanau et al., "Neural Machine Translation by Jointly Learning to Align and Translate". https://arxiv.org/abs/1409.0473
[4] PyTorch docs, "torch.nn.MultiheadAttention". https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html