
How Self-Attention Works — Visually Explained

Self-attention is the Transformer’s secret: each token decides which other tokens matter, and by how much. In this gentle guide, we build your intuition first, then show the math, shapes, and a tiny numerical walk-through. We’ll cover scaled dot-product attention, masking, multi-head attention, positional encoding, encoder vs decoder vs cross-attention, complexity, and practical PyTorch/NumPy code you can run.

ELI5: Imagine a group discussion where every word can “look at” the other words and decide who to listen to. The louder a word speaks (higher attention), the more it influences the final meaning of the sentence.

1) Intuition: why attention?


In language, each word’s meaning depends on context. In “The bank will not loan money,” the word “bank” relates to finance; in “The boat reached the bank,” it relates to a river. Self-attention lets each token ask the others, “Are you relevant to me?” and then blend information accordingly.

Contextual relevance arrows between words

Unlike RNNs (sequential) or CNNs (local), attention compares every token with every other token in parallel, which is powerful—but also quadratic in sequence length.


2) Queries, Keys, Values (Q/K/V)


Start with token embeddings \(X\in\mathbb{R}^{n\times d_\text{model}}\) (n = tokens). We learn three projections:

$$ Q = XW_Q,\quad K = XW_K,\quad V = XW_V, $$

with \(W_Q,W_K,W_V\in\mathbb{R}^{d_\text{model}\times d_k}\) (often \(d_k=d_v=d_\text{model}/h\) for h heads). Intuition:

- Query (Q): what this token is looking for in the other tokens.
- Key (K): what this token advertises, to be matched against queries.
- Value (V): the information this token actually contributes once it is attended to.

Shapes — If \(X\) is [n, d_model], and we use one head with \(d_k\), then \(Q,K,V\) are [n, d_k].
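
As a quick sanity check of these shapes, here is a minimal NumPy sketch (random matrices, single head; all sizes are illustrative):

import numpy as np

n, d_model, d_k = 4, 8, 8          # illustrative sizes, single head
X  = np.random.randn(n, d_model)   # token embeddings
Wq = np.random.randn(d_model, d_k)
Wk = np.random.randn(d_model, d_k)
Wv = np.random.randn(d_model, d_k)

Q, K, V = X @ Wq, X @ Wk, X @ Wv
print(Q.shape, K.shape, V.shape)   # (4, 8) (4, 8) (4, 8), i.e. [n, d_k] each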


3) Scaled Dot-Product Attention


Scores are dot products between queries and keys. We scale by \(\sqrt{d_k}\) to keep gradients stable and apply softmax row-wise:

$$ \text{Attn}(Q,K,V)=\underbrace{\text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)}_{\text{attention weights }A}\;V. $$

The matrix \(A\in\mathbb{R}^{n\times n}\) holds how strongly each token attends to each other token. Each row sums to 1.


4) Tiny numeric example (3 tokens, 2-D head)


Let \(Q,K,V\in\mathbb{R}^{3\times 2}\). Suppose

$$ Q=\begin{bmatrix}1&0\\0&1\\1&1\end{bmatrix},\quad K=\begin{bmatrix}1&1\\1&0\\0&1\end{bmatrix},\quad V=\begin{bmatrix}2&0\\0&2\\1&1\end{bmatrix}. $$

Compute \(S=QK^\top\) (3×3):

$$ S=\begin{bmatrix} 1\cdot1+0\cdot1 & 1\cdot1+0\cdot0 & 1\cdot0+0\cdot1\\ 0\cdot1+1\cdot1 & 0\cdot1+1\cdot0 & 0\cdot0+1\cdot1\\ 1\cdot1+1\cdot1 & 1\cdot1+1\cdot0 & 1\cdot0+1\cdot1 \end{bmatrix} = \begin{bmatrix} 1&1&0\\ 1&0&1\\ 2&1&1 \end{bmatrix}. $$

Scale by \(\sqrt{d_k}=\sqrt{2}\) and softmax each row to get \(A\). For row 1: softmax\((\tfrac{1}{\sqrt2},\tfrac{1}{\sqrt2},0)\) ≈ (0.401, 0.401, 0.198). Then the output is \(AV\): each token becomes a weighted mix of the rows of \(V\).

This small arithmetic is worth doing once by hand; it makes the abstraction “click.”
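
If you would rather check the arithmetic in code, here is a minimal NumPy verification of the same numbers (matrices copied from above):

import numpy as np

Q = np.array([[1., 0.], [0., 1.], [1., 1.]])
K = np.array([[1., 1.], [1., 0.], [0., 1.]])
V = np.array([[2., 0.], [0., 2.], [1., 1.]])

S = Q @ K.T                                  # raw scores, 3x3
A = np.exp(S / np.sqrt(2))
A = A / A.sum(axis=1, keepdims=True)         # row-wise softmax of the scaled scores
print(A[0].round(3))                         # [0.401 0.401 0.198]
print(A @ V)                                 # each output row is a weighted mix of V's rows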


5) Masking: causal (decoder) & padding (batching)


Causal mask (used in decoders / autoregressive models like GPT): token \(t\) cannot look ahead to \(t+1,\dots\). We add \(-\infty\) to forbidden logits before softmax.

$$ A=\text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}+\text{mask}\right), \quad \text{mask}_{ij}=\begin{cases}0 & j\le i\\ -\infty & j>i.\end{cases} $$

Padding mask: when batching sequences of different lengths, shorter sequences are padded to a common length; mask the padded positions so that real tokens never attend to padding (the outputs at padded positions are ignored downstream).
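
A minimal sketch of building both mask types in NumPy (a large negative number stands in for \(-\infty\); names and lengths are illustrative):

import numpy as np

n = 5
NEG = -1e9                                         # stands in for -inf before softmax

# Causal mask: position i may attend only to positions j <= i
causal = np.triu(np.full((n, n), NEG), k=1)        # upper triangle (future) blocked

# Padding mask: pretend the last 2 positions of this sequence are padding
is_pad = np.array([False, False, False, True, True])
padding = np.where(is_pad, NEG, 0.0)[None, :]      # block attention *to* padded keys
padding = np.broadcast_to(padding, (n, n))

combined = causal + padding                        # add to scores before softmax
print(combined)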


6) Multi-Head Attention (why more than one head?)


One head learns one notion of “relevance.” Multiple heads let the model attend to different patterns (syntax, coreference, long-range dependencies) simultaneously. Each head has its own \(W_Q,W_K,W_V\), smaller per-head dimension \(d_k\), and outputs are concatenated then projected:

$$ \text{MHA}(X) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)\,W_O. $$
Multi-head attention: split, attend, concat, project

Shapes — With h heads and model dim d_model, we often set d_k = d_model / h. Each head: [n, d_k] → concat: [n, h*d_k]=[n, d_model].
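
The "split into heads" step is just a reshape; the attention math itself is unchanged. A quick NumPy sketch of the bookkeeping (shapes only, projections omitted):

import numpy as np

n, d_model, h = 6, 128, 4
d_k = d_model // h                                   # 32 dims per head

Q = np.random.randn(n, d_model)
Q_heads = Q.reshape(n, h, d_k).transpose(1, 0, 2)    # [h, n, d_k]: one slice per head
print(Q_heads.shape)                                 # (4, 6, 32)

# ...each head runs scaled dot-product attention independently, then we undo the split:
Y_heads = Q_heads                                    # stand-in for the per-head outputs
Y = Y_heads.transpose(1, 0, 2).reshape(n, h * d_k)   # concat heads -> [n, d_model]
print(Y.shape)                                       # (6, 128)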


7) Positional Encoding (order matters!)


Self-attention has no sense of order by itself. We add positional encodings to token embeddings so the model knows token positions.

$$ \text{PE}_{(pos,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_\text{model}}}\right),\quad \text{PE}_{(pos,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_\text{model}}}\right). $$

These sinusoids let the model infer relative positions. Many variants exist (learned, rotary, ALiBi, etc.), but sinusoidal remains a clear starting point.

Sinusoidal positional encoding stripes
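
A compact NumPy implementation of the sinusoidal table (assumes an even d_model; adding it to the embeddings is a single sum):

import numpy as np

def sinusoidal_pe(n, d_model):
    """Return the [n, d_model] sinusoidal positional-encoding matrix."""
    pos = np.arange(n)[:, None]                     # [n, 1]
    i = np.arange(d_model // 2)[None, :]            # [1, d_model/2]
    angles = pos / (10000 ** (2 * i / d_model))     # [n, d_model/2]
    pe = np.zeros((n, d_model))
    pe[:, 0::2] = np.sin(angles)                    # even dimensions
    pe[:, 1::2] = np.cos(angles)                    # odd dimensions
    return pe

pe = sinusoidal_pe(n=50, d_model=64)
print(pe.shape)                                     # (50, 64)
# X_input = X_embed + pe[:X_embed.shape[0]]         # added to embeddings before attention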

8) Encoder, Decoder & Cross-Attention (where self-attention lives)


Encoder self-attention, decoder masked self-attention, and cross-attention
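
In an encoder block, self-attention is bidirectional: every token may attend to every other token. In a decoder block, self-attention is causally masked (Section 5) so generation cannot peek ahead. Cross-attention also lives in the decoder: queries come from the decoder states, while keys and values come from the encoder's output, which is how the decoder "reads" the source sequence. A minimal sketch of cross-attention (projections and batching omitted; sizes are illustrative):

import numpy as np

def softmax(x):                                        # row-wise softmax
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

n_dec, n_enc, d_k = 4, 7, 8
dec_states = np.random.randn(n_dec, d_k)               # queries come from the decoder
enc_states = np.random.randn(n_enc, d_k)               # keys/values come from the encoder output

Q, K, V = dec_states, enc_states, enc_states
A = softmax(Q @ K.T / np.sqrt(d_k))                    # [n_dec, n_enc]: not square!
Y = A @ V                                              # each decoder position mixes encoder info
print(A.shape, Y.shape)                                # (4, 7) (4, 8)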

9) Complexity & efficiency


Computing \(QK^\top\) is \(O(n^2 d_k)\) time and \(O(n^2)\) memory for the attention matrix \(A\). That’s why long contexts are expensive. Practical tricks include:

- Local / sliding-window or sparse attention, which lets each token attend to only a subset of positions.
- Linear-attention approximations that avoid materializing the full \(n\times n\) matrix.
- Memory-efficient exact kernels (e.g., FlashAttention) that compute attention in tiles without storing \(A\).
- KV caching during autoregressive generation, so past keys and values are not recomputed at every step.

For beginners: start with the standard implementation; move to efficient variants only when you hit memory/time limits.
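
To make the quadratic cost concrete, here is a back-of-the-envelope estimate of the memory needed just for one float32 attention matrix (per head, per sequence; lengths are illustrative):

for n in (1_024, 8_192, 65_536):
    mib = n * n * 4 / 2**20              # 4 bytes per float32 entry
    print(f"n={n:>6}: ~{mib:,.0f} MiB")  # ~4 MiB, ~256 MiB, ~16,384 MiB (16 GiB)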


10) NumPy: self-attention in ~30 lines (with optional mask)


import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_numpy(X, Wq, Wk, Wv, mask=None):
    """
    X:  [n, d_model]
    W*: [d_model, d_k]
    mask: [n, n] with 0 for allowed, -inf for blocked (added before softmax)
    """
    Q = X @ Wq   # [n, d_k]
    K = X @ Wk   # [n, d_k]
    V = X @ Wv   # [n, d_k] (or d_v)

    scores = (Q @ K.T) / np.sqrt(K.shape[1])  # [n, n]
    if mask is not None:
        scores = scores + mask                # add -inf to illegal connections

    A = softmax(scores, axis=-1)              # attention weights
    Y = A @ V                                 # [n, d_k]
    return Y, A

# Example
np.random.seed(0)
n, d_model, d_k = 5, 16, 8
X  = np.random.randn(n, d_model)
Wq = np.random.randn(d_model, d_k)
Wk = np.random.randn(d_model, d_k)
Wv = np.random.randn(d_model, d_k)

# causal mask (upper-triangular set to -inf)
mask = np.triu(np.ones((n,n)) * -1e9, k=1)
Y, A = self_attention_numpy(X, Wq, Wk, Wv, mask=mask)
print("Output shape:", Y.shape, "  Attn shape:", A.shape)

Notice the mask shape matches the \(n\times n\) attention score matrix.


11) PyTorch: from scratch and with MultiheadAttention


import torch, torch.nn as nn, math
torch.manual_seed(0)

def scaled_dot_product_attention(Q, K, V, attn_mask=None):
    # Q,K,V: [B, h, n, d_k]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))  # [B,h,n,n]
    if attn_mask is not None:
        scores = scores + attn_mask                           # add -inf where blocked
    A = torch.softmax(scores, dim=-1)
    Y = A @ V                                                 # [B,h,n,d_k]
    return Y, A

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model=128, num_heads=4):
        super().__init__()
        assert d_model % num_heads == 0
        self.h = num_heads
        self.d_k = d_model // num_heads
        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wk = nn.Linear(d_model, d_model, bias=False)
        self.Wv = nn.Linear(d_model, d_model, bias=False)
        self.Wo = nn.Linear(d_model, d_model, bias=False)

    def forward(self, X, attn_mask=None):
        B, n, d_model = X.shape
        Q = self.Wq(X).view(B, n, self.h, self.d_k).transpose(1, 2)  # [B,h,n,d_k]
        K = self.Wk(X).view(B, n, self.h, self.d_k).transpose(1, 2)
        V = self.Wv(X).view(B, n, self.h, self.d_k).transpose(1, 2)

        Y, A = scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask)  # [B,h,n,d_k],[B,h,n,n]
        Y = Y.transpose(1, 2).contiguous().view(B, n, self.h*self.d_k)     # concat heads
        out = self.Wo(Y)                                                   # [B,n,d_model]
        return out, A

# Example usage
B, n, d_model, h = 2, 6, 128, 4
X = torch.randn(B, n, d_model)

# causal mask: [1,1,n,n] broadcastable to [B,h,n,n]
mask = torch.triu(torch.ones(n, n)*-1e9, diagonal=1).unsqueeze(0).unsqueeze(0)
mha = MultiHeadSelfAttention(d_model=d_model, num_heads=h)
Y, A = mha(X, attn_mask=mask)
print(Y.shape, A.shape)  # torch.Size([2, 6, 128]) torch.Size([2, 4, 6, 6])

Built-in layer. PyTorch also offers nn.MultiheadAttention (note its default input shape is [seq_len, batch, embed_dim]):

mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=h, batch_first=True)
# For self-attention, Q=K=V=X
causal_mask = torch.triu(torch.ones(n, n)*float('-inf'), diagonal=1)  # [n,n]
Y, A = mha(X, X, X, attn_mask=causal_mask)  # Y: [B,n,d_model]; A: [B,n,n] (head-averaged by default; pass average_attn_weights=False for [B,h,n,n])

Decoder blocks: use the causal mask. Encoder blocks: no causal mask; optionally use a padding mask to hide padded tokens, as sketched below.
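
For padding, nn.MultiheadAttention accepts a boolean key_padding_mask of shape [B, n] in which True marks positions to ignore. A small sketch, pretending the second sequence in the batch has only 4 real tokens (an assumption made up for this example):

key_padding_mask = torch.zeros(B, n, dtype=torch.bool)
key_padding_mask[1, 4:] = True     # last two positions of sequence 2 are padding

Y, A = mha(X, X, X, key_padding_mask=key_padding_mask)
print(Y.shape)                     # torch.Size([2, 6, 128])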

12) Visualizing attention (what heads learn)


Plot \(A\) as a heatmap: rows = query positions, columns = key positions. Different heads often specialize, for example:

- heads that attend mostly to the previous (or next) token,
- heads that link syntactically related words, such as a verb and its subject,
- heads that focus on delimiters, punctuation, or rare tokens,
- heads that spread attention broadly, acting like a rough average over the context.

Attention heatmap visualization

Not every head is interpretable; that’s normal. Use these plots for debugging intuition, not as a definitive explanation.
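
One simple way to draw such a heatmap is with matplotlib (not used elsewhere in this post), reusing the A returned by the NumPy implementation above; for the PyTorch version, pick one head first, e.g. A[0, 0].detach().numpy():

import matplotlib.pyplot as plt

plt.imshow(A, cmap="viridis", vmin=0.0, vmax=1.0)   # A: [n, n] attention weights
plt.colorbar(label="attention weight")
plt.xlabel("key position")
plt.ylabel("query position")
plt.title("Self-attention weights")
plt.show()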


13) Mini-Glossary

- Query / Key / Value (Q, K, V): learned projections of the token embeddings; queries search, keys are matched against, values carry the information that gets mixed.
- Attention weights (A): the row-wise softmax of the scaled scores; each row sums to 1.
- Head: one independent attention computation with its own \(W_Q, W_K, W_V\).
- Causal mask: blocks attention to future positions (decoders / autoregressive models).
- Padding mask: blocks attention to padded positions when batching variable-length sequences.
- Positional encoding: the signal added to embeddings so the model knows token order.

14) References & Further Reading


  1. Vaswani, A. et al. (2017). Attention Is All You Need. NIPS. (Transformer overview)
  2. Jay Alammar. The Illustrated Transformer. (Great visual blog)
  3. Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press — sequence models & attention.
  4. Wikipedia: Attention (ML), Softmax, LayerNorm, Positional encoding.

15) Video: Karpathy — “Let’s build GPT from scratch”




16) FAQ


Q1. Why the \(\sqrt{d_k}\) scaling?

Dot products grow with \(d_k\). Scaling keeps logits in a good range so softmax isn’t overly peaky, stabilizing gradients.
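
You can check this empirically: for random unit-variance vectors, the standard deviation of \(q\cdot k\) grows roughly like \(\sqrt{d_k}\), and dividing by \(\sqrt{d_k}\) brings it back to order 1. A quick NumPy sanity check (sample sizes are arbitrary):

import numpy as np
rng = np.random.default_rng(0)

for d_k in (8, 64, 512):
    q = rng.standard_normal((10_000, d_k))
    k = rng.standard_normal((10_000, d_k))
    dots = (q * k).sum(axis=1)                       # 10,000 sample dot products
    print(d_k, dots.std().round(1), (dots / np.sqrt(d_k)).std().round(2))
# std grows like sqrt(d_k); after scaling it stays near 1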

Q2. What’s the difference between encoder self-attention and decoder self-attention?

Encoders attend bidirectionally (see all tokens). Decoders use a causal mask to prevent peeking at future tokens during generation.

Q3. Do I need multi-head attention?

Almost always yes. Multiple heads let the model capture different relationships in parallel; single-head is a useful teaching simplification.

Q4. How do positional encodings get added?

Typically by addition: \(X_\text{input}=X_\text{embed}+\text{PE}\). Learned and sinusoidal encodings both work well; sinusoidal adds no parameters and was designed with the hope of extrapolating to longer sequences.

Q5. Why are long contexts expensive?

The attention matrix is \(n\times n\). Both compute and memory grow quadratically with sequence length.