Phase 07 - Lesson 03
Atenção Multi-Cabeça
This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.
Uma cabeça de atenção aprende uma relação por vez. Oito cabeças aprendem oito. Cabeças são de graça. Pegue mais delas.
Tipo: Build Linguagens: Python Pré-requisitos: Fase 7 · 02 (Self-Attention do Zero) Tempo: ~75 minutos
O Problema
Uma única cabeça de self-attention calcula uma matriz de atenção. Essa matriz captura um tipo de relação — geralmente aquela que minimiza a perda em qualquer que seja o sinal de treino. Se seus dados têm concordância sujeito-verbo, co-referência, discurso de longo alcance e segmentação sintática tudo emaranhado, uma única cabeça os mistura em uma só distribuição soft-max e perde metade do sinal.
A correção do artigo de Vaswani de 2017: rodar várias funções de atenção em paralelo, cada uma com suas próprias projeções Q, K, V, e concatenar as saídas. Cada cabeça opera em um subespaço menor de dimensão d_model / n_heads. O total de parâmetros permanece o mesmo. O poder expressivo aumenta.
A atenção multi-cabeça é o padrão com o qual todo transformer em 2026 vem de fábrica. A única discussão é sobre quantas cabeças e se chaves e valores compartilham projeções (Grouped-Query Attention, Multi-Query Attention, Multi-head Latent Attention).
O Conceito
Separar. Pegue X de formato (N, d_model). Projete para Q, K, V cada um de formato (N, d_model). Faça reshape para (N, n_heads, d_head) onde d_head = d_model / n_heads. Transponha para (n_heads, N, d_head).
Atender em paralelo. Rode a atenção por produto escalar escalado dentro de cada cabeça. Cada cabeça produz (N, d_head). As cabeças operam em subespaços diferentes do embedding e nunca conversam durante o próprio cálculo da atenção.
Concatenar e projetar. Empilhe as cabeças de volta para (N, d_model) e multiplique por uma matriz de saída aprendida W_o de formato (d_model, d_model). W_o é onde as cabeças se misturam.
Por que funciona. Cada cabeça pode se especializar sem competir com as outras por orçamento de representação. Estudos de sondagem de 2019–2024 mostram papéis distintos por cabeça: cabeças posicionais, cabeça que atende ao token anterior, cabeças de cópia, cabeças de entidade nomeada, cabeças de indução (que estão por trás do aprendizado em contexto).
A linhagem de variações de 2026:
| Variante | Cabeças Q | Cabeças K/V | Usada por |
|---|---|---|---|
| Multi-cabeça (MHA) | N | N | GPT-2, BERT, T5 |
| Multi-query (MQA) | N | 1 | PaLM, Falcon |
| Grouped-query (GQA) | N | G (ex.: N/8) | Llama 2 70B, Llama 3+, Qwen 2+, Mistral |
| Multi-head latent (MLA) | N | comprimida para baixo posto | DeepSeek-V2, V3 |
A GQA é o padrão moderno porque corta a memória do cache KV por um fator de N/G mantendo qualidade quase total. A MLA vai além ao comprimir K/V em um espaço latente e então projetar de volta no momento da computação — custa FLOPs, economiza muito mais memória.
Construa
Passo 1: separar cabeças a partir da atenção de cabeça única que já temos
Pegue o SelfAttention da Lição 02 e envolva-o com um par separar/concatenar. Veja code/main.py para uma implementação em numpy; a lógica é:
def split_heads(X, n_heads):
n, d = X.shape
d_head = d // n_heads
return X.reshape(n, n_heads, d_head).transpose(1, 0, 2) # (heads, n, d_head)
def combine_heads(H):
h, n, d_head = H.shape
return H.transpose(1, 0, 2).reshape(n, h * d_head)
Um reshape e um transpose. Sem loop. É exatamente o que o PyTorch faz por baixo de nn.MultiheadAttention.
Passo 2: rodar atenção por produto escalar escalado por cabeça
Cada cabeça recebe sua própria fatia de Q, K, V. A atenção vira um matmul em lote:
def mha_forward(X, W_q, W_k, W_v, W_o, n_heads):
Q = X @ W_q
K = X @ W_k
V = X @ W_v
Qh = split_heads(Q, n_heads) # (heads, n, d_head)
Kh = split_heads(K, n_heads)
Vh = split_heads(V, n_heads)
scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(Qh.shape[-1])
weights = softmax(scores, axis=-1)
out = weights @ Vh # (heads, n, d_head)
concat = combine_heads(out)
return concat @ W_o, weights
Em hardware real Qh @ Kh.transpose(...) é um único bmm. A GPU vê um único matmul em lote de formato (heads, N, d_head) × (heads, d_head, N) -> (heads, N, N). Adicionar cabeças é de graça.
Passo 3: variante Grouped-Query Attention
Apenas as projeções de chave e valor mudam. Q recebe n_heads grupos; K e V recebem n_kv_heads < n_heads grupos e são repetidos para casar:
def gqa_project(X, W, n_kv_heads, n_heads):
kv = split_heads(X @ W, n_kv_heads) # (kv_heads, n, d_head)
repeat = n_heads // n_kv_heads
return np.repeat(kv, repeat, axis=0) # (n_heads, n, d_head)
Na inferência isso economiza memória porque apenas n_kv_heads cópias vivem no cache KV, não n_heads. O Llama 3 70B usa 64 cabeças de query com 8 cabeças KV — uma redução de 8× no cache.
Passo 4: sondar o que cada cabeça aprendeu
Rode a MHA em uma frase curta com 4 cabeças. Para cada cabeça, imprima a matriz de atenção (N, N). Você verá cabeças diferentes destacando estruturas diferentes mesmo com inicialização aleatória — isso é em parte sinal, em parte simetria rotacional nos subespaços.
Use
No PyTorch, a versão de uma linha:
import torch.nn as nn
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
GQA a partir do PyTorch 2.5+:
from torch.nn.functional import scaled_dot_product_attention
# scaled_dot_product_attention auto-dispatches Flash Attention on CUDA.
# For GQA, pass Q of shape (B, n_heads, N, d_head) and K,V of shape
# (B, n_kv_heads, N, d_head). PyTorch handles the repeat.
out = scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
Quantas cabeças? Regras práticas de modelos em produção em 2026:
| Tamanho do modelo | d_model | n_heads | d_head |
|---|---|---|---|
| Pequeno (~125M) | 768 | 12 | 64 |
| Base (~350M) | 1024 | 16 | 64 |
| Grande (~1B) | 2048 | 16 | 128 |
| Fronteira (~70B) | 8192 | 64 | 128 |
d_head quase sempre cai em 64 ou 128. É a unidade de quanto uma cabeça consegue "enxergar". Caia abaixo de 32 e as cabeças começam a brigar com o fator de escala sqrt(d_head); passe de 256 e você perde o benefício dos "muitos pequenos especialistas".
Entregue
Veja outputs/skill-mha-configurator.md. A skill recomenda contagem de cabeças, contagem de cabeças kv e estratégia de projeção para um novo transformer dado orçamento de parâmetros, comprimento de sequência e alvo de implantação.
Exercícios
- Fácil. Pegue a MHA de
code/main.pye muden_headsde 1 para 16 comd_model=64fixo. Plote a perda de um modelo minúsculo de uma camada em uma tarefa sintética de cópia. Mais cabeças ajudam, estabilizam ou prejudicam? - Médio. Implemente MQA (uma cabeça KV compartilhada entre todas as cabeças de query). Meça quanto cai a contagem de parâmetros vs a MHA completa. Calcule quanto encolhe o tamanho do cache KV na inferência para N=2048.
- Difícil. Implemente uma versão minúscula de Multi-head Latent Attention: comprima K,V para um latente de posto
r, armazene o latente no cache KV, descomprima no momento da atenção. Em qualra memória do cache cruza abaixo de 1/8 da MHA completa enquanto a qualidade permanece dentro de 1 bit da ppl de validação?
Termos-Chave
| Termo | O que as pessoas dizem | O que de fato significa |
|---|---|---|
| Cabeça | "Um único circuito de atenção" | Uma projeção Q/K/V de dimensão d_head = d_model / n_heads com sua própria matriz de atenção. |
| d_head | "Dimensão da cabeça" | Largura oculta por cabeça; quase sempre 64 ou 128 em produção. |
| Separar / combinar | "Truques de reshape" | reshape+transpose (N, d_model) ↔ (n_heads, N, d_head) em torno da atenção. |
| W_o | "Projeção de saída" | Matriz (d_model, d_model) aplicada após concatenar as cabeças; onde as cabeças se misturam. |
| MQA | "Uma cabeça KV" | Multi-Query Attention: uma única projeção K/V compartilhada. Menor cache KV, alguma perda de qualidade. |
| GQA | "O padrão desde o Llama 2" | Grouped-Query Attention com n_kv_heads < n_heads; repete para casar com Q. |
| MLA | "O truque da DeepSeek" | Multi-head Latent Attention: K,V comprimidos para um latente de baixo posto, descomprimidos no momento de atender. |
| Cabeça de indução | "O circuito por trás do aprendizado em contexto" | Um par de cabeças que detecta ocorrências anteriores e copia o que veio depois delas. |
Leitura Adicional
- Vaswani et al. (2017). Attention Is All You Need §3.2.2 — a especificação original da multi-cabeça.
- Shazeer (2019). Fast Transformer Decoding: One Write-Head is All You Need — o artigo da MQA.
- Ainslie et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints — como converter MHA para GQA após o treino.
- DeepSeek-AI (2024). DeepSeek-V2 Technical Report — MLA e por que ela vence MHA/GQA em memória de cache.
- Olsson et al. (2022). In-context Learning and Induction Heads — um olhar mecanicista sobre o que as cabeças realmente fazem.