Phase 05 - Lesson 10

Mecanismo de Atenção — O Avanço Decisivo

O decoder para de apertar os olhos para um resumo comprimido e passa a olhar para toda a fonte. Tudo depois disso é atenção mais engenharia.

Type: Build Languages: Python Prerequisites: Phase 5 · 09 (Sequence-to-Sequence Models) Time: ~45 minutes

O Problema

A lição 09 terminou com uma falha calculada. Um encoder-decoder GRU treinado em uma tarefa de cópia de brinquedo vai de 89% de acurácia no comprimento 5 para próximo do acaso no comprimento 80. A razão é estrutural, não um bug de treinamento: cada bit de informação que o encoder coletou precisa caber em um único estado oculto de tamanho fixo, e o decoder nunca vê mais nada.

Bahdanau, Cho e Bengio publicaram uma correção de três linhas em 2014. Em vez de dar ao decoder apenas o estado final do encoder, mantenha todos os estados do encoder. A cada passo do decoder, calcule uma média ponderada dos estados do encoder, onde os pesos dizem "quanto o decoder precisa olhar para a posição i do encoder agora?" Essa média ponderada é o contexto, e ela muda a cada passo do decoder.

Essa é a ideia inteira. Os Transformers a estenderam. A self-attention a aplicou a uma única sequência. A multi-head attention a executou em paralelo. Mas a versão de 2014 já tinha quebrado o gargalo, e uma vez que você a tem, a virada para os transformers é engenharia, não conceito.

O Conceito

Atenção de Bahdanau: o decoder consulta todos os estados do encoder

A cada passo t do decoder:

  1. Use o estado oculto anterior do decoder s_{t-1} como uma query.
  2. Pontue-o contra cada estado oculto do encoder h_1, ..., h_T. Um escalar por posição do encoder.
  3. Aplique softmax sobre as pontuações para obter os pesos de atenção α_{t,1}, ..., α_{t,T} que somam 1.
  4. Vetor de contexto c_t = Σ α_{t,i} * h_i. Média ponderada dos estados do encoder.
  5. O decoder pega c_t mais o token de saída anterior e produz o próximo token.

A média ponderada é o ponto central. Quando o decoder precisa traduzir "Je" para "I", ele atribui peso alto ao estado do encoder sobre "Je" e peso baixo aos demais. Quando precisa de "not", atribui peso alto a "pas". O vetor de contexto se remodela a cada passo.

Shapes (a coisa que pega todo mundo)

É aqui que toda implementação de atenção dá errado na primeira vez. Leia devagar.

Coisa Shape Notas
Estados ocultos do encoder H (T_enc, d_h) Se for BiLSTM, d_h = 2 * d_hidden
Estado oculto do decoder s_{t-1} (d_s,) Um vetor
Pontuação de atenção e_{t,i} escalar Uma por posição do encoder
Peso de atenção α_{t,i} escalar Após o softmax sobre todos os i
Vetor de contexto c_t (d_h,) Mesmo shape de um estado do encoder

Pontuação de Bahdanau (aditiva). e_{t,i} = v_α^T * tanh(W_a * s_{t-1} + U_a * h_i).

  • s_{t-1} tem shape (d_s,), h_i tem shape (d_h,).
  • W_a tem shape (d_attn, d_s). U_a tem shape (d_attn, d_h).
  • A soma deles dentro do tanh tem shape (d_attn,).
  • v_α tem shape (d_attn,). O produto interno com v_α colapsa para um escalar. É isso que v_α faz. Não é mágica. É a projeção que transforma um vetor de dimensão de atenção em uma pontuação escalar.

Pontuação de Luong (multiplicativa). Três variantes:

  • dot: e_{t,i} = s_t^T * h_i. Exige d_s == d_h. Restrição rígida. Pule se o seu encoder for bidirecional.
  • general: e_{t,i} = s_t^T * W * h_i com W de shape (d_s, d_h). Remove a restrição de dimensões iguais.
  • concat: essencialmente a forma de Bahdanau. Raramente usada, já que as duas primeiras são mais baratas.

Uma pegadinha de Bahdanau / Luong que vale nomear. Bahdanau usa s_{t-1} (o estado do decoder antes de gerar a palavra atual). Luong usa s_t (o estado depois). Confundi-los produz gradientes sutilmente errados que são extremamente difíceis de depurar. Escolha um artigo e siga a convenção dele.

Construa

Passo 1: atenção aditiva (Bahdanau)

import numpy as np


def additive_attention(decoder_state, encoder_states, W_a, U_a, v_a):
    projected_dec = W_a @ decoder_state
    projected_enc = encoder_states @ U_a.T
    combined = np.tanh(projected_enc + projected_dec)
    scores = combined @ v_a
    weights = softmax(scores)
    context = weights @ encoder_states
    return context, weights


def softmax(x):
    x = x - np.max(x)
    e = np.exp(x)
    return e / e.sum()

Confira seus shapes contra a tabela acima. encoder_states tem shape (T_enc, d_h). projected_enc tem shape (T_enc, d_attn). projected_dec tem shape (d_attn,) e é propagado por broadcasting. combined tem shape (T_enc, d_attn). scores tem shape (T_enc,). weights tem shape (T_enc,). context tem shape (d_h,). Pode mandar.

Passo 2: Luong dot e general

def dot_attention(decoder_state, encoder_states):
    scores = encoder_states @ decoder_state
    weights = softmax(scores)
    return weights @ encoder_states, weights


def general_attention(decoder_state, encoder_states, W):
    projected = W.T @ decoder_state
    scores = encoder_states @ projected
    weights = softmax(scores)
    return weights @ encoder_states, weights

Três linhas cada. É por isso que o artigo de Luong pegou. Mesma acurácia na maioria das tarefas, muito menos código.

Passo 3: um exemplo numérico resolvido

Dados três estados de encoder (mais ou menos "cat", "sat", "mat") e um estado de decoder que se alinha mais com o primeiro, a distribuição de atenção se concentra na posição 0. Se o estado do decoder mudar para se alinhar com o último, a atenção se move para a posição 2. O vetor de contexto acompanha.

H = np.array([
    [1.0, 0.0, 0.2],
    [0.5, 0.5, 0.1],
    [0.1, 0.9, 0.3],
])

s_close_to_cat = np.array([0.9, 0.1, 0.2])
ctx, w = dot_attention(s_close_to_cat, H)
print("weights:", w.round(3))
weights: [0.464 0.305 0.231]

A primeira linha vence. Depois mova o estado do decoder para mais perto do terceiro estado do encoder e veja os pesos mudarem. É isso. Atenção é alinhamento explícito.

Passo 4: por que isto é a ponte para os transformers

Traduza a linguagem acima para Q/K/V:

  • Query = estado do decoder s_{t-1}
  • Key = estados do encoder (contra o que pontuamos)
  • Value = estados do encoder (o que ponderamos e somamos)

Na atenção clássica, keys e values são a mesma coisa. A self-attention as separa: você pode consultar uma sequência contra ela mesma, com projeções aprendidas diferentes para K e V. A multi-head attention a executa em paralelo com projeções aprendidas diferentes. Os transformers empilham o estágio inteiro muitas vezes e abandonam as RNNs.

A matemática é a mesma. Os shapes são os mesmos. O salto pedagógico da atenção de Bahdanau para a atenção de produto escalar escalonado (scaled dot-product) é, na maior parte, notação.

Use

PyTorch e TensorFlow trazem atenção pronta.

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=128, num_heads=8, batch_first=True)
query = torch.randn(2, 5, 128)
key = torch.randn(2, 10, 128)
value = torch.randn(2, 10, 128)

output, weights = mha(query, key, value)
print(output.shape, weights.shape)
torch.Size([2, 5, 128]) torch.Size([2, 5, 10])

Essa é uma camada de atenção de transformer. Batch de query com 5 posições, batch de key/value com 10 posições, 128 dimensões cada, 8 heads. output são as novas queries aumentadas com contexto. weights é a matriz de alinhamento 5x10 que você pode visualizar.

Quando a atenção clássica ainda importa

  • Pedagogia. A versão de uma única head, uma única camada, baseada em RNN torna cada conceito visível.
  • Tarefas de sequência on-device onde os transformers não cabem.
  • Qualquer artigo de 2014-2017. Você vai interpretá-lo errado sem conhecer a convenção de Bahdanau.
  • Análise de alinhamento fino em MT (tradução automática). Os pesos de atenção brutos são uma ferramenta de interpretabilidade mesmo em modelos transformer, e lê-los exige saber o que eles são.

A armadilha do peso-de-atenção-como-explicação

Os pesos de atenção parecem interpretáveis. São pesos que somam um entre as posições; você pode plotá-los; alto significa "olhou para isto". Os revisores os adoram.

Eles não são tão interpretáveis quanto parecem. Jain e Wallace (2019) mostraram que as distribuições de atenção podem ser permutadas e substituídas por alternativas arbitrárias sem mudar as predições do modelo para algumas tarefas. Nunca reporte pesos de atenção como evidência de raciocínio sem uma checagem de ablação ou contrafactual.

Entregue

Salve como outputs/prompt-attention-shapes.md:

---
name: attention-shapes
description: Debug shape bugs in attention implementations.
phase: 5
lesson: 10
---

Given a broken attention implementation, you identify the shape mismatch. Output:

1. Which matrix has the wrong shape. Name the tensor.
2. What its shape should be, derived from (d_s, d_h, d_attn, T_enc, T_dec, batch_size).
3. One-line fix. Transpose, reshape, or project.
4. A test to catch regressions. Typically: assert `output.shape == (batch, T_dec, d_h)` and `weights.shape == (batch, T_dec, T_enc)` and `weights.sum(dim=-1) close to 1`.

Refuse to recommend fixes that silently broadcast. Broadcast-hiding bugs surface later as silent accuracy degradation, the worst kind of attention bug.

For Bahdanau confusion, insist the decoder input is `s_{t-1}` (pre-step state). For Luong, `s_t` (post-step state). For dot-product, flag dimension mismatch between query and key as the most common first-time error.

Exercícios

  1. Fácil. Implemente o mascaramento do softmax para que os tokens de padding no encoder recebam peso de atenção zero. Teste em um batch com sequências de comprimento variável.
  2. Médio. Adicione multi-head attention à forma general de Luong. Divida d_h em n_heads grupos, execute a atenção por head e concatene. Verifique se o caso de uma única head corresponde à sua implementação anterior.
  3. Difícil. Treine um encoder-decoder GRU com atenção de Bahdanau na tarefa de cópia de brinquedo da lição 09. Plote a acurácia versus o comprimento da sequência. Compare com o baseline sem atenção. Você deve ver a diferença aumentar conforme o comprimento cresce, confirmando que a atenção alivia o gargalo.

Termos-Chave

Termo O que as pessoas dizem O que realmente significa
Atenção Olhar para as coisas Média ponderada de uma sequência de values, com pesos calculados a partir de uma similaridade query-key.
Query, Key, Value QKV Três projeções: Q pergunta, K é o que casar, V é o que retornar.
Atenção aditiva Bahdanau Pontuação feed-forward: v^T tanh(W q + U k).
Atenção multiplicativa Luong dot / general A pontuação é q^T k ou q^T W k. Mais barata, mesma acurácia na maioria das tarefas.
Matriz de alinhamento A figura bonita Pesos de atenção como uma grade (T_dec, T_enc). Leia-a para ver para onde o modelo prestou atenção.

Leitura Adicional

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).