Phase 07 - Lesson 03

Atención Multi-Cabeza

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

Una cabeza de atención aprende una relación a la vez. Ocho cabezas aprenden ocho. Las cabezas son gratis. Usa más de ellas.

Tipo: Build Lenguajes: Python Prerrequisitos: Fase 7 · 02 (Self-Attention desde Cero) Tiempo: ~75 minutos

El Problema

Una sola cabeza de self-attention calcula una matriz de atención. Esa matriz captura un tipo de relación — generalmente la que minimiza la pérdida en cualquiera que sea la señal de entrenamiento. Si tus datos tienen concordancia sujeto-verbo, co-referencia, discurso de largo alcance y segmentación sintáctica todo enredado, una sola cabeza los mezcla en una única distribución soft-max y pierde la mitad de la señal.

La solución del artículo de Vaswani de 2017: ejecutar varias funciones de atención en paralelo, cada una con sus propias proyecciones Q, K, V, y concatenar las salidas. Cada cabeza opera en un subespacio más pequeño de dimensión d_model / n_heads. El total de parámetros se mantiene igual. El poder expresivo aumenta.

La atención multi-cabeza es el estándar con el que viene todo transformer en 2026. La única discusión es sobre cuántas cabezas y si las claves y los valores comparten proyecciones (Grouped-Query Attention, Multi-Query Attention, Multi-head Latent Attention).

El Concepto

La atención multi-cabeza separa, atiende y concatena

Separar. Toma X de forma (N, d_model). Proyecta a Q, K, V cada uno de forma (N, d_model). Haz reshape a (N, n_heads, d_head) donde d_head = d_model / n_heads. Transpone a (n_heads, N, d_head).

Atender en paralelo. Ejecuta la atención por producto punto escalado dentro de cada cabeza. Cada cabeza produce (N, d_head). Las cabezas operan en subespacios diferentes del embedding y nunca se comunican durante el propio cálculo de la atención.

Concatenar y proyectar. Apila las cabezas de vuelta a (N, d_model) y multiplica por una matriz de salida aprendida W_o de forma (d_model, d_model). W_o es donde las cabezas se mezclan.

Por qué funciona. Cada cabeza puede especializarse sin competir con las demás por presupuesto de representación. Estudios de sondeo de 2019–2024 muestran roles distintos por cabeza: cabezas posicionales, cabeza que atiende al token anterior, cabezas de copia, cabezas de entidad nombrada, cabezas de inducción (que están detrás del aprendizaje en contexto).

El linaje de variaciones de 2026:

Variante Cabezas Q Cabezas K/V Usada por
Multi-cabeza (MHA) N N GPT-2, BERT, T5
Multi-query (MQA) N 1 PaLM, Falcon
Grouped-query (GQA) N G (ej.: N/8) Llama 2 70B, Llama 3+, Qwen 2+, Mistral
Multi-head latent (MLA) N comprimida a bajo rango DeepSeek-V2, V3

GQA es el estándar moderno porque recorta la memoria del caché KV por un factor de N/G manteniendo una calidad casi total. MLA va más allá al comprimir K/V en un espacio latente y luego proyectar de vuelta en el momento del cómputo — cuesta FLOPs, ahorra mucha más memoria.

Constrúyelo

Paso 1: separar cabezas a partir de la atención de cabeza única que ya tenemos

Toma el SelfAttention de la Lección 02 y envuélvelo con un par separar/concatenar. Ve code/main.py para una implementación en numpy; la lógica es:

def split_heads(X, n_heads):
    n, d = X.shape
    d_head = d // n_heads
    return X.reshape(n, n_heads, d_head).transpose(1, 0, 2)  # (heads, n, d_head)

def combine_heads(H):
    h, n, d_head = H.shape
    return H.transpose(1, 0, 2).reshape(n, h * d_head)

Un reshape y un transpose. Sin bucle. Es exactamente lo que hace PyTorch por debajo de nn.MultiheadAttention.

Paso 2: ejecutar atención por producto punto escalado por cabeza

Cada cabeza recibe su propia porción de Q, K, V. La atención se convierte en un matmul por lotes:

def mha_forward(X, W_q, W_k, W_v, W_o, n_heads):
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v
    Qh = split_heads(Q, n_heads)         # (heads, n, d_head)
    Kh = split_heads(K, n_heads)
    Vh = split_heads(V, n_heads)
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(Qh.shape[-1])
    weights = softmax(scores, axis=-1)
    out = weights @ Vh                    # (heads, n, d_head)
    concat = combine_heads(out)
    return concat @ W_o, weights

En hardware real Qh @ Kh.transpose(...) es un único bmm. La GPU ve un único matmul por lotes de forma (heads, N, d_head) × (heads, d_head, N) -> (heads, N, N). Agregar cabezas es gratis.

Paso 3: variante Grouped-Query Attention

Solo cambian las proyecciones de clave y valor. Q recibe n_heads grupos; K y V reciben n_kv_heads < n_heads grupos y se repiten para coincidir:

def gqa_project(X, W, n_kv_heads, n_heads):
    kv = split_heads(X @ W, n_kv_heads)       # (kv_heads, n, d_head)
    repeat = n_heads // n_kv_heads
    return np.repeat(kv, repeat, axis=0)      # (n_heads, n, d_head)

En la inferencia esto ahorra memoria porque solo n_kv_heads copias viven en el caché KV, no n_heads. Llama 3 70B usa 64 cabezas de query con 8 cabezas KV — una reducción de 8× en el caché.

Paso 4: sondear lo que aprendió cada cabeza

Ejecuta la MHA en una oración corta con 4 cabezas. Para cada cabeza, imprime la matriz de atención (N, N). Verás cabezas diferentes destacando estructuras diferentes incluso con inicialización aleatoria — eso es en parte señal, en parte simetría rotacional en los subespacios.

Úsalo

En PyTorch, la versión de una línea:

import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

GQA a partir de PyTorch 2.5+:

from torch.nn.functional import scaled_dot_product_attention

# scaled_dot_product_attention auto-dispatches Flash Attention on CUDA.
# For GQA, pass Q of shape (B, n_heads, N, d_head) and K,V of shape
# (B, n_kv_heads, N, d_head). PyTorch handles the repeat.
out = scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

¿Cuántas cabezas? Reglas prácticas de modelos en producción en 2026:

Tamaño del modelo d_model n_heads d_head
Pequeño (~125M) 768 12 64
Base (~350M) 1024 16 64
Grande (~1B) 2048 16 128
Frontera (~70B) 8192 64 128

d_head casi siempre cae en 64 o 128. Es la unidad de cuánto puede "ver" una cabeza. Baja de 32 y las cabezas empiezan a pelear con el factor de escala sqrt(d_head); pasa de 256 y pierdes el beneficio de los "muchos pequeños especialistas".

Entrégalo

Ve outputs/skill-mha-configurator.md. La skill recomienda cantidad de cabezas, cantidad de cabezas kv y estrategia de proyección para un nuevo transformer dado un presupuesto de parámetros, longitud de secuencia y objetivo de despliegue.

Ejercicios

  1. Fácil. Toma la MHA de code/main.py y cambia n_heads de 1 a 16 con d_model=64 fijo. Grafica la pérdida de un modelo minúsculo de una capa en una tarea sintética de copia. ¿Más cabezas ayudan, se estabilizan o perjudican?
  2. Medio. Implementa MQA (una cabeza KV compartida entre todas las cabezas de query). Mide cuánto baja la cantidad de parámetros vs la MHA completa. Calcula cuánto encoge el tamaño del caché KV en la inferencia para N=2048.
  3. Difícil. Implementa una versión minúscula de Multi-head Latent Attention: comprime K,V a un latente de rango r, almacena el latente en el caché KV, descomprime en el momento de la atención. ¿En qué r la memoria del caché cruza por debajo de 1/8 de la MHA completa mientras la calidad se mantiene dentro de 1 bit de la ppl de validación?

Términos Clave

Término Lo que dice la gente Lo que en realidad significa
Cabeza "Un único circuito de atención" Una proyección Q/K/V de dimensión d_head = d_model / n_heads con su propia matriz de atención.
d_head "Dimensión de la cabeza" Ancho oculto por cabeza; casi siempre 64 o 128 en producción.
Separar / combinar "Trucos de reshape" reshape+transpose (N, d_model) ↔ (n_heads, N, d_head) alrededor de la atención.
W_o "Proyección de salida" Matriz (d_model, d_model) aplicada tras concatenar las cabezas; donde las cabezas se mezclan.
MQA "Una cabeza KV" Multi-Query Attention: una única proyección K/V compartida. El caché KV más pequeño, algo de pérdida de calidad.
GQA "El estándar desde Llama 2" Grouped-Query Attention con n_kv_heads < n_heads; repite para coincidir con Q.
MLA "El truco de DeepSeek" Multi-head Latent Attention: K,V comprimidos a un latente de bajo rango, descomprimidos en el momento de atender.
Cabeza de inducción "El circuito detrás del aprendizaje en contexto" Un par de cabezas que detecta ocurrencias anteriores y copia lo que vino después de ellas.

Lectura Adicional

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).