Attention is the engine behind every modern LLM, yet most explanations either drown you in notation or skip the intuition entirely. Let’s fix that.
What problem does attention solve?
Before transformers, sequence models like LSTMs compressed the entire input into a fixed-size vector — a bottleneck that shredded long-range context. Attention throws that constraint away: every output token can directly attend to every input token.
Attention is just a differentiable lookup table. You have queries (what you’re looking for), keys (what’s available), and values (what you actually retrieve).
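That analogy can be made literal in a few lines. Here is an illustrative sketch (not library code, sizes are arbitrary): one query scored against three key/value pairs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 4
query = torch.randn(1, d)    # what we're looking for
keys = torch.randn(3, d)     # what's available
values = torch.randn(3, d)   # what we actually retrieve

scores = query @ keys.T               # similarity of the query to each key
weights = F.softmax(scores, dim=-1)   # soft selection: weights sum to 1
result = weights @ values             # weighted mix of the values
```

Unlike a hard dictionary lookup, every value contributes a little, weighted by how well its key matches the query; that softness is what makes the whole thing differentiable.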
The math, minimally
Given queries Q, keys K, and values V:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

The √d_k scaling prevents the dot products from growing so large that softmax saturates into near-zero gradients.
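To see why the scaling is needed: for unit-variance random vectors, the variance of a dot product q·k grows linearly with d_k, so unscaled logits explode as the model gets wider. A quick empirical check (a sketch, with arbitrary sample sizes):

```python
import torch

torch.manual_seed(0)
for d_k in (16, 256):
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(-1)
    # Variance of q.k grows roughly linearly with d_k, while the
    # scaled version stays near 1, keeping softmax in a regime
    # where gradients are useful.
    print(d_k, dots.var().item(), (dots / d_k ** 0.5).var().item())
```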
A working implementation
```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(
    q: torch.Tensor,  # (batch, heads, seq, d_k)
    k: torch.Tensor,
    v: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    d_k = q.size(-1)

    # (batch, heads, seq_q, seq_k)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    weights = F.softmax(scores, dim=-1)
    output = torch.matmul(weights, v)
    return output, weights
```

PyTorch 2.0+ ships F.scaled_dot_product_attention with FlashAttention support built in. Use that in production — the implementation above is for learning.
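As a sanity check (a sketch assuming PyTorch 2.0+), the formula written out by hand should agree with the fused kernel to floating-point tolerance:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 8, 16)  # (batch, heads, seq, d_k)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# The formula, written out by hand.
scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
manual = F.softmax(scores, dim=-1) @ v

# PyTorch's fused implementation, same default scaling of 1/sqrt(d_k).
fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual, fused, atol=1e-5))
```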
Multi-head attention
Running multiple attention heads in parallel lets the model attend to different representation subspaces simultaneously.
```python
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads
        self.n_heads = n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch = q.size(0)

        def split_heads(x):
            return x.view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.W_q(q))
        k = split_heads(self.W_k(k))
        v = split_heads(self.W_v(v))

        attn_out, _ = scaled_dot_product_attention(q, k, v, mask)

        # Recombine heads
        attn_out = attn_out.transpose(1, 2).contiguous().view(
            batch, -1, self.n_heads * self.d_k
        )
        return self.W_o(attn_out)
```

Visualising attention weights
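The trickiest part of multi-head attention is the split/merge bookkeeping, which is pure reshaping with no learned parameters. A standalone sketch with hypothetical sizes shows the round trip is exact:

```python
import torch

batch, seq, d_model, n_heads = 2, 5, 32, 4
d_k = d_model // n_heads

x = torch.randn(batch, seq, d_model)

# Split: (batch, seq, d_model) -> (batch, heads, seq, d_k)
heads = x.view(batch, seq, n_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 8])

# Merge: the exact inverse. Transpose back, then flatten the head dim.
# .contiguous() is required because transpose returns a non-contiguous view.
merged = heads.transpose(1, 2).contiguous().view(batch, seq, d_model)
print(torch.equal(merged, x))  # True: the round trip loses nothing
```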
One of the best debugging tools is plotting which tokens attend to which:
```python
import torch
import matplotlib.pyplot as plt

def plot_attention(weights: torch.Tensor, tokens: list[str], head: int = 0):
    """weights: (batch, heads, seq, seq)"""
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(weights[0, head].detach().cpu(), cmap='Blues')
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha='right')
    ax.set_yticklabels(tokens)
    plt.colorbar(im, ax=ax)
    plt.title(f'Attention weights — head {head}')
    plt.tight_layout()
    plt.show()
```

What’s next
- Positional encodings — how transformers learn token order
- KV-cache — why inference is O(n²) and how to tame it
- Flash Attention — rewriting attention to live in SRAM
The attention mechanism is deceptively simple once you strip away the jargon. Everything downstream — BERT, GPT, Llama — is just variations on this theme.