Phase 07 - Lesson 03
Atención Multi-Cabeza
This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.
Una cabeza de atención aprende una relación a la vez. Ocho cabezas aprenden ocho. Las cabezas son gratis. Usa más de ellas.
Tipo: Build Lenguajes: Python Prerrequisitos: Fase 7 · 02 (Self-Attention desde Cero) Tiempo: ~75 minutos
El Problema
Una sola cabeza de self-attention calcula una matriz de atención. Esa matriz captura un tipo de relación — generalmente la que minimiza la pérdida en cualquiera que sea la señal de entrenamiento. Si tus datos tienen concordancia sujeto-verbo, co-referencia, discurso de largo alcance y segmentación sintáctica todo enredado, una sola cabeza los mezcla en una única distribución soft-max y pierde la mitad de la señal.
La solución del artículo de Vaswani de 2017: ejecutar varias funciones de atención en paralelo, cada una con sus propias proyecciones Q, K, V, y concatenar las salidas. Cada cabeza opera en un subespacio más pequeño de dimensión d_model / n_heads. El total de parámetros se mantiene igual. El poder expresivo aumenta.
La atención multi-cabeza es el estándar con el que viene todo transformer en 2026. La única discusión es sobre cuántas cabezas y si las claves y los valores comparten proyecciones (Grouped-Query Attention, Multi-Query Attention, Multi-head Latent Attention).
El Concepto
Separar. Toma X de forma (N, d_model). Proyecta a Q, K, V cada uno de forma (N, d_model). Haz reshape a (N, n_heads, d_head) donde d_head = d_model / n_heads. Transpone a (n_heads, N, d_head).
Atender en paralelo. Ejecuta la atención por producto punto escalado dentro de cada cabeza. Cada cabeza produce (N, d_head). Las cabezas operan en subespacios diferentes del embedding y nunca se comunican durante el propio cálculo de la atención.
Concatenar y proyectar. Apila las cabezas de vuelta a (N, d_model) y multiplica por una matriz de salida aprendida W_o de forma (d_model, d_model). W_o es donde las cabezas se mezclan.
Por qué funciona. Cada cabeza puede especializarse sin competir con las demás por presupuesto de representación. Estudios de sondeo de 2019–2024 muestran roles distintos por cabeza: cabezas posicionales, cabeza que atiende al token anterior, cabezas de copia, cabezas de entidad nombrada, cabezas de inducción (que están detrás del aprendizaje en contexto).
El linaje de variaciones de 2026:
| Variante | Cabezas Q | Cabezas K/V | Usada por |
|---|---|---|---|
| Multi-cabeza (MHA) | N | N | GPT-2, BERT, T5 |
| Multi-query (MQA) | N | 1 | PaLM, Falcon |
| Grouped-query (GQA) | N | G (ej.: N/8) | Llama 2 70B, Llama 3+, Qwen 2+, Mistral |
| Multi-head latent (MLA) | N | comprimida a bajo rango | DeepSeek-V2, V3 |
GQA es el estándar moderno porque recorta la memoria del caché KV por un factor de N/G manteniendo una calidad casi total. MLA va más allá al comprimir K/V en un espacio latente y luego proyectar de vuelta en el momento del cómputo — cuesta FLOPs, ahorra mucha más memoria.
Constrúyelo
Paso 1: separar cabezas a partir de la atención de cabeza única que ya tenemos
Toma el SelfAttention de la Lección 02 y envuélvelo con un par separar/concatenar. Ve code/main.py para una implementación en numpy; la lógica es:
def split_heads(X, n_heads):
n, d = X.shape
d_head = d // n_heads
return X.reshape(n, n_heads, d_head).transpose(1, 0, 2) # (heads, n, d_head)
def combine_heads(H):
h, n, d_head = H.shape
return H.transpose(1, 0, 2).reshape(n, h * d_head)
Un reshape y un transpose. Sin bucle. Es exactamente lo que hace PyTorch por debajo de nn.MultiheadAttention.
Paso 2: ejecutar atención por producto punto escalado por cabeza
Cada cabeza recibe su propia porción de Q, K, V. La atención se convierte en un matmul por lotes:
def mha_forward(X, W_q, W_k, W_v, W_o, n_heads):
Q = X @ W_q
K = X @ W_k
V = X @ W_v
Qh = split_heads(Q, n_heads) # (heads, n, d_head)
Kh = split_heads(K, n_heads)
Vh = split_heads(V, n_heads)
scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(Qh.shape[-1])
weights = softmax(scores, axis=-1)
out = weights @ Vh # (heads, n, d_head)
concat = combine_heads(out)
return concat @ W_o, weights
En hardware real Qh @ Kh.transpose(...) es un único bmm. La GPU ve un único matmul por lotes de forma (heads, N, d_head) × (heads, d_head, N) -> (heads, N, N). Agregar cabezas es gratis.
Paso 3: variante Grouped-Query Attention
Solo cambian las proyecciones de clave y valor. Q recibe n_heads grupos; K y V reciben n_kv_heads < n_heads grupos y se repiten para coincidir:
def gqa_project(X, W, n_kv_heads, n_heads):
kv = split_heads(X @ W, n_kv_heads) # (kv_heads, n, d_head)
repeat = n_heads // n_kv_heads
return np.repeat(kv, repeat, axis=0) # (n_heads, n, d_head)
En la inferencia esto ahorra memoria porque solo n_kv_heads copias viven en el caché KV, no n_heads. Llama 3 70B usa 64 cabezas de query con 8 cabezas KV — una reducción de 8× en el caché.
Paso 4: sondear lo que aprendió cada cabeza
Ejecuta la MHA en una oración corta con 4 cabezas. Para cada cabeza, imprime la matriz de atención (N, N). Verás cabezas diferentes destacando estructuras diferentes incluso con inicialización aleatoria — eso es en parte señal, en parte simetría rotacional en los subespacios.
Úsalo
En PyTorch, la versión de una línea:
import torch.nn as nn
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
GQA a partir de PyTorch 2.5+:
from torch.nn.functional import scaled_dot_product_attention
# scaled_dot_product_attention auto-dispatches Flash Attention on CUDA.
# For GQA, pass Q of shape (B, n_heads, N, d_head) and K,V of shape
# (B, n_kv_heads, N, d_head). PyTorch handles the repeat.
out = scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
¿Cuántas cabezas? Reglas prácticas de modelos en producción en 2026:
| Tamaño del modelo | d_model | n_heads | d_head |
|---|---|---|---|
| Pequeño (~125M) | 768 | 12 | 64 |
| Base (~350M) | 1024 | 16 | 64 |
| Grande (~1B) | 2048 | 16 | 128 |
| Frontera (~70B) | 8192 | 64 | 128 |
d_head casi siempre cae en 64 o 128. Es la unidad de cuánto puede "ver" una cabeza. Baja de 32 y las cabezas empiezan a pelear con el factor de escala sqrt(d_head); pasa de 256 y pierdes el beneficio de los "muchos pequeños especialistas".
Entrégalo
Ve outputs/skill-mha-configurator.md. La skill recomienda cantidad de cabezas, cantidad de cabezas kv y estrategia de proyección para un nuevo transformer dado un presupuesto de parámetros, longitud de secuencia y objetivo de despliegue.
Ejercicios
- Fácil. Toma la MHA de
code/main.pyy cambian_headsde 1 a 16 cond_model=64fijo. Grafica la pérdida de un modelo minúsculo de una capa en una tarea sintética de copia. ¿Más cabezas ayudan, se estabilizan o perjudican? - Medio. Implementa MQA (una cabeza KV compartida entre todas las cabezas de query). Mide cuánto baja la cantidad de parámetros vs la MHA completa. Calcula cuánto encoge el tamaño del caché KV en la inferencia para N=2048.
- Difícil. Implementa una versión minúscula de Multi-head Latent Attention: comprime K,V a un latente de rango
r, almacena el latente en el caché KV, descomprime en el momento de la atención. ¿En quérla memoria del caché cruza por debajo de 1/8 de la MHA completa mientras la calidad se mantiene dentro de 1 bit de la ppl de validación?
Términos Clave
| Término | Lo que dice la gente | Lo que en realidad significa |
|---|---|---|
| Cabeza | "Un único circuito de atención" | Una proyección Q/K/V de dimensión d_head = d_model / n_heads con su propia matriz de atención. |
| d_head | "Dimensión de la cabeza" | Ancho oculto por cabeza; casi siempre 64 o 128 en producción. |
| Separar / combinar | "Trucos de reshape" | reshape+transpose (N, d_model) ↔ (n_heads, N, d_head) alrededor de la atención. |
| W_o | "Proyección de salida" | Matriz (d_model, d_model) aplicada tras concatenar las cabezas; donde las cabezas se mezclan. |
| MQA | "Una cabeza KV" | Multi-Query Attention: una única proyección K/V compartida. El caché KV más pequeño, algo de pérdida de calidad. |
| GQA | "El estándar desde Llama 2" | Grouped-Query Attention con n_kv_heads < n_heads; repite para coincidir con Q. |
| MLA | "El truco de DeepSeek" | Multi-head Latent Attention: K,V comprimidos a un latente de bajo rango, descomprimidos en el momento de atender. |
| Cabeza de inducción | "El circuito detrás del aprendizaje en contexto" | Un par de cabezas que detecta ocurrencias anteriores y copia lo que vino después de ellas. |
Lectura Adicional
- Vaswani et al. (2017). Attention Is All You Need §3.2.2 — la especificación original de la multi-cabeza.
- Shazeer (2019). Fast Transformer Decoding: One Write-Head is All You Need — el artículo de MQA.
- Ainslie et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints — cómo convertir MHA a GQA tras el entrenamiento.
- DeepSeek-AI (2024). DeepSeek-V2 Technical Report — MLA y por qué supera a MHA/GQA en memoria de caché.
- Olsson et al. (2022). In-context Learning and Induction Heads — una mirada mecanicista a lo que realmente hacen las cabezas.