Phase 07 - Lesson 03

Atenção Multi-Cabeça

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

Uma cabeça de atenção aprende uma relação por vez. Oito cabeças aprendem oito. Cabeças são de graça. Pegue mais delas.

Tipo: Build Linguagens: Python Pré-requisitos: Fase 7 · 02 (Self-Attention do Zero) Tempo: ~75 minutos

O Problema

Uma única cabeça de self-attention calcula uma matriz de atenção. Essa matriz captura um tipo de relação — geralmente aquela que minimiza a perda em qualquer que seja o sinal de treino. Se seus dados têm concordância sujeito-verbo, co-referência, discurso de longo alcance e segmentação sintática tudo emaranhado, uma única cabeça os mistura em uma só distribuição soft-max e perde metade do sinal.

A correção do artigo de Vaswani de 2017: rodar várias funções de atenção em paralelo, cada uma com suas próprias projeções Q, K, V, e concatenar as saídas. Cada cabeça opera em um subespaço menor de dimensão d_model / n_heads. O total de parâmetros permanece o mesmo. O poder expressivo aumenta.

A atenção multi-cabeça é o padrão com o qual todo transformer em 2026 vem de fábrica. A única discussão é sobre quantas cabeças e se chaves e valores compartilham projeções (Grouped-Query Attention, Multi-Query Attention, Multi-head Latent Attention).

O Conceito

Atenção multi-cabeça separa, atende e concatena

Separar. Pegue X de formato (N, d_model). Projete para Q, K, V cada um de formato (N, d_model). Faça reshape para (N, n_heads, d_head) onde d_head = d_model / n_heads. Transponha para (n_heads, N, d_head).

Atender em paralelo. Rode a atenção por produto escalar escalado dentro de cada cabeça. Cada cabeça produz (N, d_head). As cabeças operam em subespaços diferentes do embedding e nunca conversam durante o próprio cálculo da atenção.

Concatenar e projetar. Empilhe as cabeças de volta para (N, d_model) e multiplique por uma matriz de saída aprendida W_o de formato (d_model, d_model). W_o é onde as cabeças se misturam.

Por que funciona. Cada cabeça pode se especializar sem competir com as outras por orçamento de representação. Estudos de sondagem de 2019–2024 mostram papéis distintos por cabeça: cabeças posicionais, cabeça que atende ao token anterior, cabeças de cópia, cabeças de entidade nomeada, cabeças de indução (que estão por trás do aprendizado em contexto).

A linhagem de variações de 2026:

Variante Cabeças Q Cabeças K/V Usada por
Multi-cabeça (MHA) N N GPT-2, BERT, T5
Multi-query (MQA) N 1 PaLM, Falcon
Grouped-query (GQA) N G (ex.: N/8) Llama 2 70B, Llama 3+, Qwen 2+, Mistral
Multi-head latent (MLA) N comprimida para baixo posto DeepSeek-V2, V3

A GQA é o padrão moderno porque corta a memória do cache KV por um fator de N/G mantendo qualidade quase total. A MLA vai além ao comprimir K/V em um espaço latente e então projetar de volta no momento da computação — custa FLOPs, economiza muito mais memória.

Construa

Passo 1: separar cabeças a partir da atenção de cabeça única que já temos

Pegue o SelfAttention da Lição 02 e envolva-o com um par separar/concatenar. Veja code/main.py para uma implementação em numpy; a lógica é:

def split_heads(X, n_heads):
    n, d = X.shape
    d_head = d // n_heads
    return X.reshape(n, n_heads, d_head).transpose(1, 0, 2)  # (heads, n, d_head)

def combine_heads(H):
    h, n, d_head = H.shape
    return H.transpose(1, 0, 2).reshape(n, h * d_head)

Um reshape e um transpose. Sem loop. É exatamente o que o PyTorch faz por baixo de nn.MultiheadAttention.

Passo 2: rodar atenção por produto escalar escalado por cabeça

Cada cabeça recebe sua própria fatia de Q, K, V. A atenção vira um matmul em lote:

def mha_forward(X, W_q, W_k, W_v, W_o, n_heads):
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v
    Qh = split_heads(Q, n_heads)         # (heads, n, d_head)
    Kh = split_heads(K, n_heads)
    Vh = split_heads(V, n_heads)
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(Qh.shape[-1])
    weights = softmax(scores, axis=-1)
    out = weights @ Vh                    # (heads, n, d_head)
    concat = combine_heads(out)
    return concat @ W_o, weights

Em hardware real Qh @ Kh.transpose(...) é um único bmm. A GPU vê um único matmul em lote de formato (heads, N, d_head) × (heads, d_head, N) -> (heads, N, N). Adicionar cabeças é de graça.

Passo 3: variante Grouped-Query Attention

Apenas as projeções de chave e valor mudam. Q recebe n_heads grupos; K e V recebem n_kv_heads < n_heads grupos e são repetidos para casar:

def gqa_project(X, W, n_kv_heads, n_heads):
    kv = split_heads(X @ W, n_kv_heads)       # (kv_heads, n, d_head)
    repeat = n_heads // n_kv_heads
    return np.repeat(kv, repeat, axis=0)      # (n_heads, n, d_head)

Na inferência isso economiza memória porque apenas n_kv_heads cópias vivem no cache KV, não n_heads. O Llama 3 70B usa 64 cabeças de query com 8 cabeças KV — uma redução de 8× no cache.

Passo 4: sondar o que cada cabeça aprendeu

Rode a MHA em uma frase curta com 4 cabeças. Para cada cabeça, imprima a matriz de atenção (N, N). Você verá cabeças diferentes destacando estruturas diferentes mesmo com inicialização aleatória — isso é em parte sinal, em parte simetria rotacional nos subespaços.

Use

No PyTorch, a versão de uma linha:

import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

GQA a partir do PyTorch 2.5+:

from torch.nn.functional import scaled_dot_product_attention

# scaled_dot_product_attention auto-dispatches Flash Attention on CUDA.
# For GQA, pass Q of shape (B, n_heads, N, d_head) and K,V of shape
# (B, n_kv_heads, N, d_head). PyTorch handles the repeat.
out = scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

Quantas cabeças? Regras práticas de modelos em produção em 2026:

Tamanho do modelo d_model n_heads d_head
Pequeno (~125M) 768 12 64
Base (~350M) 1024 16 64
Grande (~1B) 2048 16 128
Fronteira (~70B) 8192 64 128

d_head quase sempre cai em 64 ou 128. É a unidade de quanto uma cabeça consegue "enxergar". Caia abaixo de 32 e as cabeças começam a brigar com o fator de escala sqrt(d_head); passe de 256 e você perde o benefício dos "muitos pequenos especialistas".

Entregue

Veja outputs/skill-mha-configurator.md. A skill recomenda contagem de cabeças, contagem de cabeças kv e estratégia de projeção para um novo transformer dado orçamento de parâmetros, comprimento de sequência e alvo de implantação.

Exercícios

  1. Fácil. Pegue a MHA de code/main.py e mude n_heads de 1 para 16 com d_model=64 fixo. Plote a perda de um modelo minúsculo de uma camada em uma tarefa sintética de cópia. Mais cabeças ajudam, estabilizam ou prejudicam?
  2. Médio. Implemente MQA (uma cabeça KV compartilhada entre todas as cabeças de query). Meça quanto cai a contagem de parâmetros vs a MHA completa. Calcule quanto encolhe o tamanho do cache KV na inferência para N=2048.
  3. Difícil. Implemente uma versão minúscula de Multi-head Latent Attention: comprima K,V para um latente de posto r, armazene o latente no cache KV, descomprima no momento da atenção. Em qual r a memória do cache cruza abaixo de 1/8 da MHA completa enquanto a qualidade permanece dentro de 1 bit da ppl de validação?

Termos-Chave

Termo O que as pessoas dizem O que de fato significa
Cabeça "Um único circuito de atenção" Uma projeção Q/K/V de dimensão d_head = d_model / n_heads com sua própria matriz de atenção.
d_head "Dimensão da cabeça" Largura oculta por cabeça; quase sempre 64 ou 128 em produção.
Separar / combinar "Truques de reshape" reshape+transpose (N, d_model) ↔ (n_heads, N, d_head) em torno da atenção.
W_o "Projeção de saída" Matriz (d_model, d_model) aplicada após concatenar as cabeças; onde as cabeças se misturam.
MQA "Uma cabeça KV" Multi-Query Attention: uma única projeção K/V compartilhada. Menor cache KV, alguma perda de qualidade.
GQA "O padrão desde o Llama 2" Grouped-Query Attention com n_kv_heads < n_heads; repete para casar com Q.
MLA "O truque da DeepSeek" Multi-head Latent Attention: K,V comprimidos para um latente de baixo posto, descomprimidos no momento de atender.
Cabeça de indução "O circuito por trás do aprendizado em contexto" Um par de cabeças que detecta ocorrências anteriores e copia o que veio depois delas.

Leitura Adicional

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).