Wesley.dev

Understanding Transformer Attention

Demystifying self-attention with Python demos — from the math to a working implementation in under 100 lines.

Wesley Sum · 3 min read

Attention is the engine behind every modern LLM, yet most explanations either drown you in notation or skip the intuition entirely. Let’s fix that.

What problem does attention solve?

Before transformers, sequence models like LSTMs compressed the entire input into a fixed-size vector — a bottleneck that shredded long-range context. Attention throws that constraint away: every output token can directly attend to every input token.

Key insight

Attention is just a differentiable lookup table. You have queries (what you’re looking for), keys (what’s available), and values (what you actually retrieve).
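
If that sounds abstract, here is a toy sketch of the idea: a "soft" dictionary lookup, with numbers made up purely for illustration.

import torch
import torch.nn.functional as F

# Toy soft lookup: one query against three key/value pairs (values invented).
keys = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # what's available
values = torch.tensor([[10.0], [20.0], [30.0]])             # what we retrieve
query = torch.tensor([[1.0, 0.2]])                          # what we're looking for

scores = query @ keys.T              # similarity of the query to each key
weights = F.softmax(scores, dim=-1)  # a soft, differentiable "index"
print(weights)                       # highest weight on the most similar key
print(weights @ values)              # a weighted blend of values, not a hard pick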

The math, minimally

Given queries Q, keys K, and values V:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

The √d_k scaling prevents the dot products from growing so large that softmax saturates into near-zero gradients.
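
A quick numerical check makes this concrete (random tensors, purely illustrative): without the scaling, the scores spread out so much that softmax puts almost all its mass on a single key.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 512
q, k = torch.randn(10, d_k), torch.randn(10, d_k)

raw = q @ k.T                # dot products of random vectors: variance grows with d_k
scaled = raw / (d_k ** 0.5)  # rescaled back to roughly unit variance

print(raw.std(), scaled.std())                              # raw std is ~√d_k times larger
print(F.softmax(raw, dim=-1).max(dim=-1).values.mean())     # close to 1.0: saturated, tiny gradients
print(F.softmax(scaled, dim=-1).max(dim=-1).values.mean())  # noticeably softer distribution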

A working implementation

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(
    q: torch.Tensor,  # (batch, heads, seq, d_k)
    k: torch.Tensor,
    v: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    d_k = q.size(-1)
    # (batch, heads, seq_q, seq_k)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, v)
    return output, weights
Info

PyTorch 2.0+ ships F.scaled_dot_product_attention with FlashAttention support built in. Use that in production — the implementation above is for learning.
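
A quick way to convince yourself the hand-rolled version matches the built-in (shapes here are arbitrary):

torch.manual_seed(0)
q = torch.randn(2, 4, 8, 16)  # (batch, heads, seq, d_k)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

ours, _ = scaled_dot_product_attention(q, k, v)
builtin = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(ours, builtin, atol=1e-5))  # True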

Multi-head attention

Running multiple attention heads in parallel lets the model attend to different representation subspaces simultaneously.

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch = q.size(0)

        def split_heads(x):
            return x.view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.W_q(q))
        k = split_heads(self.W_k(k))
        v = split_heads(self.W_v(v))
        attn_out, _ = scaled_dot_product_attention(q, k, v, mask)
        # Recombine heads
        attn_out = attn_out.transpose(1, 2).contiguous().view(batch, -1, self.n_heads * self.d_k)
        return self.W_o(attn_out)
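
Here is a minimal smoke test, with dimensions picked arbitrarily, using a causal mask so each position only attends to itself and earlier positions:

batch, seq, d_model, n_heads = 2, 6, 64, 8
x = torch.randn(batch, seq, d_model)

# Lower-triangular causal mask, broadcast over batch and heads: (1, 1, seq, seq)
causal = torch.tril(torch.ones(seq, seq)).view(1, 1, seq, seq)

mha = MultiHeadAttention(d_model, n_heads)
out = mha(x, x, x, mask=causal)  # self-attention: q = k = v = x
print(out.shape)                 # torch.Size([2, 6, 64])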

Visualising attention weights

One of the best debugging tools is plotting which tokens attend to which:

import matplotlib.pyplot as plt

def plot_attention(weights: torch.Tensor, tokens: list[str], head: int = 0):
    """weights: (batch, heads, seq, seq)"""
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(weights[0, head].detach().cpu(), cmap='Blues')
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha='right')
    ax.set_yticklabels(tokens)
    plt.colorbar(im, ax=ax)
    plt.title(f'Attention weights — head {head}')
    plt.tight_layout()
    plt.show()
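
The MultiHeadAttention module above only returns the output, so to get weights to plot, project q, k, v yourself and call the core function directly. The sentence and dimensions below are invented for the example:

tokens = ['the', 'cat', 'sat', 'on', 'the', 'mat']
d_model, n_heads = 64, 8

# Random embeddings standing in for real token embeddings
x = torch.randn(1, len(tokens), d_model)

mha = MultiHeadAttention(d_model, n_heads)

def split(t):
    return t.view(1, -1, n_heads, mha.d_k).transpose(1, 2)

_, weights = scaled_dot_product_attention(
    split(mha.W_q(x)), split(mha.W_k(x)), split(mha.W_v(x))
)
plot_attention(weights, tokens, head=0)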

What’s next

  • Positional encodings — how transformers keep track of token order
  • KV-cache — why inference is O(n²) and how to tame it
  • Flash Attention — rewriting attention to live in SRAM

The attention mechanism is deceptively simple once you strip away the jargon. Everything downstream — BERT, GPT, Llama — is just variations on this theme.