Phase 05 - Lesson 10
Mecanismo de Atenção — O Avanço Decisivo
O decoder para de apertar os olhos para um resumo comprimido e passa a olhar para toda a fonte. Tudo depois disso é atenção mais engenharia.
Type: Build Languages: Python Prerequisites: Phase 5 · 09 (Sequence-to-Sequence Models) Time: ~45 minutes
O Problema
A lição 09 terminou com uma falha calculada. Um encoder-decoder GRU treinado em uma tarefa de cópia de brinquedo vai de 89% de acurácia no comprimento 5 para próximo do acaso no comprimento 80. A razão é estrutural, não um bug de treinamento: cada bit de informação que o encoder coletou precisa caber em um único estado oculto de tamanho fixo, e o decoder nunca vê mais nada.
Bahdanau, Cho e Bengio publicaram uma correção de três linhas em 2014. Em vez de dar ao decoder apenas o estado final do encoder, mantenha todos os estados do encoder. A cada passo do decoder, calcule uma média ponderada dos estados do encoder, onde os pesos dizem "quanto o decoder precisa olhar para a posição i do encoder agora?" Essa média ponderada é o contexto, e ela muda a cada passo do decoder.
Essa é a ideia inteira. Os Transformers a estenderam. A self-attention a aplicou a uma única sequência. A multi-head attention a executou em paralelo. Mas a versão de 2014 já tinha quebrado o gargalo, e uma vez que você a tem, a virada para os transformers é engenharia, não conceito.
O Conceito
A cada passo t do decoder:
- Use o estado oculto anterior do decoder
s_{t-1}como uma query. - Pontue-o contra cada estado oculto do encoder
h_1, ..., h_T. Um escalar por posição do encoder. - Aplique softmax sobre as pontuações para obter os pesos de atenção
α_{t,1}, ..., α_{t,T}que somam 1. - Vetor de contexto
c_t = Σ α_{t,i} * h_i. Média ponderada dos estados do encoder. - O decoder pega
c_tmais o token de saída anterior e produz o próximo token.
A média ponderada é o ponto central. Quando o decoder precisa traduzir "Je" para "I", ele atribui peso alto ao estado do encoder sobre "Je" e peso baixo aos demais. Quando precisa de "not", atribui peso alto a "pas". O vetor de contexto se remodela a cada passo.
Shapes (a coisa que pega todo mundo)
É aqui que toda implementação de atenção dá errado na primeira vez. Leia devagar.
| Coisa | Shape | Notas |
|---|---|---|
Estados ocultos do encoder H |
(T_enc, d_h) |
Se for BiLSTM, d_h = 2 * d_hidden |
Estado oculto do decoder s_{t-1} |
(d_s,) |
Um vetor |
Pontuação de atenção e_{t,i} |
escalar | Uma por posição do encoder |
Peso de atenção α_{t,i} |
escalar | Após o softmax sobre todos os i |
Vetor de contexto c_t |
(d_h,) |
Mesmo shape de um estado do encoder |
Pontuação de Bahdanau (aditiva). e_{t,i} = v_α^T * tanh(W_a * s_{t-1} + U_a * h_i).
s_{t-1}tem shape(d_s,),h_item shape(d_h,).W_atem shape(d_attn, d_s).U_atem shape(d_attn, d_h).- A soma deles dentro do tanh tem shape
(d_attn,). v_αtem shape(d_attn,). O produto interno comv_αcolapsa para um escalar. É isso quev_αfaz. Não é mágica. É a projeção que transforma um vetor de dimensão de atenção em uma pontuação escalar.
Pontuação de Luong (multiplicativa). Três variantes:
dot:e_{t,i} = s_t^T * h_i. Exiged_s == d_h. Restrição rígida. Pule se o seu encoder for bidirecional.general:e_{t,i} = s_t^T * W * h_icomWde shape(d_s, d_h). Remove a restrição de dimensões iguais.concat: essencialmente a forma de Bahdanau. Raramente usada, já que as duas primeiras são mais baratas.
Uma pegadinha de Bahdanau / Luong que vale nomear. Bahdanau usa s_{t-1} (o estado do decoder antes de gerar a palavra atual). Luong usa s_t (o estado depois). Confundi-los produz gradientes sutilmente errados que são extremamente difíceis de depurar. Escolha um artigo e siga a convenção dele.
Construa
Passo 1: atenção aditiva (Bahdanau)
import numpy as np
def additive_attention(decoder_state, encoder_states, W_a, U_a, v_a):
projected_dec = W_a @ decoder_state
projected_enc = encoder_states @ U_a.T
combined = np.tanh(projected_enc + projected_dec)
scores = combined @ v_a
weights = softmax(scores)
context = weights @ encoder_states
return context, weights
def softmax(x):
x = x - np.max(x)
e = np.exp(x)
return e / e.sum()
Confira seus shapes contra a tabela acima. encoder_states tem shape (T_enc, d_h). projected_enc tem shape (T_enc, d_attn). projected_dec tem shape (d_attn,) e é propagado por broadcasting. combined tem shape (T_enc, d_attn). scores tem shape (T_enc,). weights tem shape (T_enc,). context tem shape (d_h,). Pode mandar.
Passo 2: Luong dot e general
def dot_attention(decoder_state, encoder_states):
scores = encoder_states @ decoder_state
weights = softmax(scores)
return weights @ encoder_states, weights
def general_attention(decoder_state, encoder_states, W):
projected = W.T @ decoder_state
scores = encoder_states @ projected
weights = softmax(scores)
return weights @ encoder_states, weights
Três linhas cada. É por isso que o artigo de Luong pegou. Mesma acurácia na maioria das tarefas, muito menos código.
Passo 3: um exemplo numérico resolvido
Dados três estados de encoder (mais ou menos "cat", "sat", "mat") e um estado de decoder que se alinha mais com o primeiro, a distribuição de atenção se concentra na posição 0. Se o estado do decoder mudar para se alinhar com o último, a atenção se move para a posição 2. O vetor de contexto acompanha.
H = np.array([
[1.0, 0.0, 0.2],
[0.5, 0.5, 0.1],
[0.1, 0.9, 0.3],
])
s_close_to_cat = np.array([0.9, 0.1, 0.2])
ctx, w = dot_attention(s_close_to_cat, H)
print("weights:", w.round(3))
weights: [0.464 0.305 0.231]
A primeira linha vence. Depois mova o estado do decoder para mais perto do terceiro estado do encoder e veja os pesos mudarem. É isso. Atenção é alinhamento explícito.
Passo 4: por que isto é a ponte para os transformers
Traduza a linguagem acima para Q/K/V:
- Query = estado do decoder
s_{t-1} - Key = estados do encoder (contra o que pontuamos)
- Value = estados do encoder (o que ponderamos e somamos)
Na atenção clássica, keys e values são a mesma coisa. A self-attention as separa: você pode consultar uma sequência contra ela mesma, com projeções aprendidas diferentes para K e V. A multi-head attention a executa em paralelo com projeções aprendidas diferentes. Os transformers empilham o estágio inteiro muitas vezes e abandonam as RNNs.
A matemática é a mesma. Os shapes são os mesmos. O salto pedagógico da atenção de Bahdanau para a atenção de produto escalar escalonado (scaled dot-product) é, na maior parte, notação.
Use
PyTorch e TensorFlow trazem atenção pronta.
import torch
import torch.nn as nn
mha = nn.MultiheadAttention(embed_dim=128, num_heads=8, batch_first=True)
query = torch.randn(2, 5, 128)
key = torch.randn(2, 10, 128)
value = torch.randn(2, 10, 128)
output, weights = mha(query, key, value)
print(output.shape, weights.shape)
torch.Size([2, 5, 128]) torch.Size([2, 5, 10])
Essa é uma camada de atenção de transformer. Batch de query com 5 posições, batch de key/value com 10 posições, 128 dimensões cada, 8 heads. output são as novas queries aumentadas com contexto. weights é a matriz de alinhamento 5x10 que você pode visualizar.
Quando a atenção clássica ainda importa
- Pedagogia. A versão de uma única head, uma única camada, baseada em RNN torna cada conceito visível.
- Tarefas de sequência on-device onde os transformers não cabem.
- Qualquer artigo de 2014-2017. Você vai interpretá-lo errado sem conhecer a convenção de Bahdanau.
- Análise de alinhamento fino em MT (tradução automática). Os pesos de atenção brutos são uma ferramenta de interpretabilidade mesmo em modelos transformer, e lê-los exige saber o que eles são.
A armadilha do peso-de-atenção-como-explicação
Os pesos de atenção parecem interpretáveis. São pesos que somam um entre as posições; você pode plotá-los; alto significa "olhou para isto". Os revisores os adoram.
Eles não são tão interpretáveis quanto parecem. Jain e Wallace (2019) mostraram que as distribuições de atenção podem ser permutadas e substituídas por alternativas arbitrárias sem mudar as predições do modelo para algumas tarefas. Nunca reporte pesos de atenção como evidência de raciocínio sem uma checagem de ablação ou contrafactual.
Entregue
Salve como outputs/prompt-attention-shapes.md:
---
name: attention-shapes
description: Debug shape bugs in attention implementations.
phase: 5
lesson: 10
---
Given a broken attention implementation, you identify the shape mismatch. Output:
1. Which matrix has the wrong shape. Name the tensor.
2. What its shape should be, derived from (d_s, d_h, d_attn, T_enc, T_dec, batch_size).
3. One-line fix. Transpose, reshape, or project.
4. A test to catch regressions. Typically: assert `output.shape == (batch, T_dec, d_h)` and `weights.shape == (batch, T_dec, T_enc)` and `weights.sum(dim=-1) close to 1`.
Refuse to recommend fixes that silently broadcast. Broadcast-hiding bugs surface later as silent accuracy degradation, the worst kind of attention bug.
For Bahdanau confusion, insist the decoder input is `s_{t-1}` (pre-step state). For Luong, `s_t` (post-step state). For dot-product, flag dimension mismatch between query and key as the most common first-time error.
Exercícios
- Fácil. Implemente o mascaramento do
softmaxpara que os tokens de padding no encoder recebam peso de atenção zero. Teste em um batch com sequências de comprimento variável. - Médio. Adicione multi-head attention à forma
generalde Luong. Dividad_hemn_headsgrupos, execute a atenção por head e concatene. Verifique se o caso de uma única head corresponde à sua implementação anterior. - Difícil. Treine um encoder-decoder GRU com atenção de Bahdanau na tarefa de cópia de brinquedo da lição 09. Plote a acurácia versus o comprimento da sequência. Compare com o baseline sem atenção. Você deve ver a diferença aumentar conforme o comprimento cresce, confirmando que a atenção alivia o gargalo.
Termos-Chave
| Termo | O que as pessoas dizem | O que realmente significa |
|---|---|---|
| Atenção | Olhar para as coisas | Média ponderada de uma sequência de values, com pesos calculados a partir de uma similaridade query-key. |
| Query, Key, Value | QKV | Três projeções: Q pergunta, K é o que casar, V é o que retornar. |
| Atenção aditiva | Bahdanau | Pontuação feed-forward: v^T tanh(W q + U k). |
| Atenção multiplicativa | Luong dot / general | A pontuação é q^T k ou q^T W k. Mais barata, mesma acurácia na maioria das tarefas. |
| Matriz de alinhamento | A figura bonita | Pesos de atenção como uma grade (T_dec, T_enc). Leia-a para ver para onde o modelo prestou atenção. |
Leitura Adicional
- Bahdanau, Cho, Bengio (2014). Neural Machine Translation by Jointly Learning to Align and Translate — o artigo.
- Luong, Pham, Manning (2015). Effective Approaches to Attention-based Neural Machine Translation — as três variantes de pontuação e a comparação entre elas.
- Jain and Wallace (2019). Attention is not Explanation — a ressalva sobre interpretabilidade.
- Dive into Deep Learning — Bahdanau Attention — tutorial executável com PyTorch.