Phase 05 - Lesson 10
Mecanismo de Atención — El Gran Avance
El decodificador deja de entrecerrar los ojos ante un resumen comprimido y empieza a mirar toda la fuente. Todo lo que viene después es atención más ingeniería.
Type: Build Languages: Python Prerequisites: Phase 5 · 09 (Sequence-to-Sequence Models) Time: ~45 minutes
El Problema
La lección 09 terminó con un fracaso calculado. Un codificador-decodificador GRU entrenado en una tarea de copia de juguete pasa de un 89% de exactitud con longitud 5 a casi azar con longitud 80. La razón es estructural, no un error de entrenamiento: cada bit de información que el codificador recopiló tiene que caber en un único estado oculto de tamaño fijo, y el decodificador nunca ve nada más.
Bahdanau, Cho y Bengio publicaron una corrección de tres líneas en 2014. En lugar de dar al decodificador solo el estado final del codificador, conserva todos los estados del codificador. En cada paso del decodificador, calcula un promedio ponderado de los estados del codificador donde los pesos dicen "¿cuánto necesita el decodificador mirar la posición i del codificador ahora mismo?" Ese promedio ponderado es el contexto, y cambia en cada paso del decodificador.
Esa es toda la idea. Los Transformers la extendieron. La self-attention la aplicó a una sola secuencia. La multi-head attention la ejecutó en paralelo. Pero la versión de 2014 ya había roto el cuello de botella, y una vez que la tienes, el giro hacia los transformers es ingeniería, no concepto.
El Concepto
En cada paso t del decodificador:
- Usa el estado oculto anterior del decodificador
s_{t-1}como una query. - Puntúalo contra cada estado oculto del codificador
h_1, ..., h_T. Un escalar por posición del codificador. - Aplica softmax a las puntuaciones para obtener los pesos de atención
α_{t,1}, ..., α_{t,T}que suman 1. - Vector de contexto
c_t = Σ α_{t,i} * h_i. Promedio ponderado de los estados del codificador. - El decodificador toma
c_tmás el token de salida anterior y produce el siguiente token.
El promedio ponderado es lo esencial. Cuando el decodificador necesita traducir "Je" a "I", pondera alto el estado del codificador sobre "Je" y bajo los demás. Cuando necesita "not", pondera alto "pas". El vector de contexto se reconfigura en cada paso.
Shapes (lo que hace tropezar a todo el mundo)
Aquí es donde toda implementación de atención sale mal la primera vez. Lee despacio.
| Cosa | Shape | Notas |
|---|---|---|
Estados ocultos del codificador H |
(T_enc, d_h) |
Si es BiLSTM, d_h = 2 * d_hidden |
Estado oculto del decodificador s_{t-1} |
(d_s,) |
Un vector |
Puntuación de atención e_{t,i} |
escalar | Una por posición del codificador |
Peso de atención α_{t,i} |
escalar | Tras el softmax sobre todos los i |
Vector de contexto c_t |
(d_h,) |
Mismo shape que un estado del codificador |
Puntuación de Bahdanau (aditiva). e_{t,i} = v_α^T * tanh(W_a * s_{t-1} + U_a * h_i).
s_{t-1}tiene shape(d_s,),h_itiene shape(d_h,).W_atiene shape(d_attn, d_s).U_atiene shape(d_attn, d_h).- Su suma dentro del tanh tiene shape
(d_attn,). v_αtiene shape(d_attn,). El producto interno conv_αcolapsa a un escalar. Eso es lo que hacev_α. No es magia. Es la proyección que convierte un vector de dimensión de atención en una puntuación escalar.
Puntuación de Luong (multiplicativa). Tres variantes:
dot:e_{t,i} = s_t^T * h_i. Requiered_s == d_h. Restricción dura. Sáltala si tu codificador es bidireccional.general:e_{t,i} = s_t^T * W * h_iconWde shape(d_s, d_h). Elimina la restricción de dimensiones iguales.concat: esencialmente la forma de Bahdanau. Rara vez se usa, ya que las dos primeras son más baratas.
Un detalle de Bahdanau / Luong que vale la pena nombrar. Bahdanau usa s_{t-1} (el estado del decodificador antes de generar la palabra actual). Luong usa s_t (el estado después). Confundirlos produce gradientes sutilmente erróneos que son extremadamente difíciles de depurar. Elige un artículo y apégate a su convención.
Constrúyelo
Paso 1: atención aditiva (Bahdanau)
import numpy as np
def additive_attention(decoder_state, encoder_states, W_a, U_a, v_a):
projected_dec = W_a @ decoder_state
projected_enc = encoder_states @ U_a.T
combined = np.tanh(projected_enc + projected_dec)
scores = combined @ v_a
weights = softmax(scores)
context = weights @ encoder_states
return context, weights
def softmax(x):
x = x - np.max(x)
e = np.exp(x)
return e / e.sum()
Verifica tus shapes contra la tabla de arriba. encoder_states tiene shape (T_enc, d_h). projected_enc tiene shape (T_enc, d_attn). projected_dec tiene shape (d_attn,) y se propaga por broadcasting. combined tiene shape (T_enc, d_attn). scores tiene shape (T_enc,). weights tiene shape (T_enc,). context tiene shape (d_h,). Listo para enviar.
Paso 2: Luong dot y general
def dot_attention(decoder_state, encoder_states):
scores = encoder_states @ decoder_state
weights = softmax(scores)
return weights @ encoder_states, weights
def general_attention(decoder_state, encoder_states, W):
projected = W.T @ decoder_state
scores = encoder_states @ projected
weights = softmax(scores)
return weights @ encoder_states, weights
Tres líneas cada una. Por eso el artículo de Luong tuvo éxito. La misma exactitud en la mayoría de las tareas, mucho menos código.
Paso 3: un ejemplo numérico resuelto
Dados tres estados de codificador (más o menos "cat", "sat", "mat") y un estado de decodificador que se alinea más con el primero, la distribución de atención se concentra en la posición 0. Si el estado del decodificador cambia para alinearse con el último, la atención se mueve a la posición 2. El vector de contexto lo sigue.
H = np.array([
[1.0, 0.0, 0.2],
[0.5, 0.5, 0.1],
[0.1, 0.9, 0.3],
])
s_close_to_cat = np.array([0.9, 0.1, 0.2])
ctx, w = dot_attention(s_close_to_cat, H)
print("weights:", w.round(3))
weights: [0.464 0.305 0.231]
Gana la primera fila. Luego mueve el estado del decodificador más cerca del tercer estado del codificador y observa cómo cambian los pesos. Eso es todo. La atención es alineación explícita.
Paso 4: por qué esto es el puente hacia los transformers
Traduce el lenguaje de arriba a Q/K/V:
- Query = estado del decodificador
s_{t-1} - Key = estados del codificador (contra lo que puntuamos)
- Value = estados del codificador (lo que ponderamos y sumamos)
En la atención clásica, las keys y los values son lo mismo. La self-attention los separa: puedes consultar una secuencia contra sí misma, con proyecciones aprendidas distintas para K y V. La multi-head attention la ejecuta en paralelo con proyecciones aprendidas distintas. Los transformers apilan toda la etapa muchas veces y descartan las RNN.
La matemática es la misma. Los shapes son los mismos. El salto pedagógico de la atención de Bahdanau a la atención de producto escalar escalado (scaled dot-product) es, en su mayor parte, notación.
Úsalo
PyTorch y TensorFlow traen la atención lista para usar.
import torch
import torch.nn as nn
mha = nn.MultiheadAttention(embed_dim=128, num_heads=8, batch_first=True)
query = torch.randn(2, 5, 128)
key = torch.randn(2, 10, 128)
value = torch.randn(2, 10, 128)
output, weights = mha(query, key, value)
print(output.shape, weights.shape)
torch.Size([2, 5, 128]) torch.Size([2, 5, 10])
Esa es una capa de atención de transformer. Batch de query con 5 posiciones, batch de key/value con 10 posiciones, 128 dimensiones cada una, 8 heads. output son las nuevas queries aumentadas con contexto. weights es la matriz de alineación 5x10 que puedes visualizar.
Cuándo la atención clásica todavía importa
- Pedagogía. La versión de una sola head, una sola capa, basada en RNN hace visible cada concepto.
- Tareas de secuencia on-device donde los transformers no caben.
- Cualquier artículo de 2014-2017. Lo interpretarás mal sin conocer la convención de Bahdanau.
- Análisis de alineación fina en MT (traducción automática). Los pesos de atención crudos son una herramienta de interpretabilidad incluso en modelos transformer, y leerlos requiere saber qué son.
La trampa del peso-de-atención-como-explicación
Los pesos de atención parecen interpretables. Son pesos que suman uno entre las posiciones; puedes graficarlos; alto significa "miró esto". A los revisores les encantan.
No son tan interpretables como parecen. Jain y Wallace (2019) mostraron que las distribuciones de atención pueden permutarse y reemplazarse por alternativas arbitrarias sin cambiar las predicciones del modelo en algunas tareas. Nunca reportes los pesos de atención como evidencia de razonamiento sin una comprobación de ablación o contrafactual.
Entrégalo
Guarda como outputs/prompt-attention-shapes.md:
---
name: attention-shapes
description: Debug shape bugs in attention implementations.
phase: 5
lesson: 10
---
Given a broken attention implementation, you identify the shape mismatch. Output:
1. Which matrix has the wrong shape. Name the tensor.
2. What its shape should be, derived from (d_s, d_h, d_attn, T_enc, T_dec, batch_size).
3. One-line fix. Transpose, reshape, or project.
4. A test to catch regressions. Typically: assert `output.shape == (batch, T_dec, d_h)` and `weights.shape == (batch, T_dec, T_enc)` and `weights.sum(dim=-1) close to 1`.
Refuse to recommend fixes that silently broadcast. Broadcast-hiding bugs surface later as silent accuracy degradation, the worst kind of attention bug.
For Bahdanau confusion, insist the decoder input is `s_{t-1}` (pre-step state). For Luong, `s_t` (post-step state). For dot-product, flag dimension mismatch between query and key as the most common first-time error.
Ejercicios
- Fácil. Implementa el enmascaramiento del
softmaxpara que los tokens de padding en el codificador reciban peso de atención cero. Pruébalo en un batch con secuencias de longitud variable. - Medio. Agrega multi-head attention a la forma
generalde Luong. Divided_henn_headsgrupos, ejecuta la atención por head y concaténalas. Verifica que el caso de una sola head coincida con tu implementación anterior. - Difícil. Entrena un codificador-decodificador GRU con atención de Bahdanau en la tarea de copia de juguete de la lección 09. Grafica la exactitud frente a la longitud de la secuencia. Compárala con el baseline sin atención. Deberías ver que la brecha se amplía a medida que crece la longitud, confirmando que la atención alivia el cuello de botella.
Términos Clave
| Término | Lo que dice la gente | Lo que realmente significa |
|---|---|---|
| Atención | Mirar las cosas | Promedio ponderado de una secuencia de values, con pesos calculados a partir de una similitud query-key. |
| Query, Key, Value | QKV | Tres proyecciones: Q pregunta, K es lo que hay que emparejar, V es lo que hay que devolver. |
| Atención aditiva | Bahdanau | Puntuación feed-forward: v^T tanh(W q + U k). |
| Atención multiplicativa | Luong dot / general | La puntuación es q^T k o q^T W k. Más barata, misma exactitud en la mayoría de las tareas. |
| Matriz de alineación | La imagen bonita | Pesos de atención como una cuadrícula (T_dec, T_enc). Léela para ver a qué prestó atención el modelo. |
Lecturas Adicionales
- Bahdanau, Cho, Bengio (2014). Neural Machine Translation by Jointly Learning to Align and Translate — el artículo.
- Luong, Pham, Manning (2015). Effective Approaches to Attention-based Neural Machine Translation — las tres variantes de puntuación y su comparación.
- Jain and Wallace (2019). Attention is not Explanation — la advertencia sobre interpretabilidad.
- Dive into Deep Learning — Bahdanau Attention — recorrido ejecutable con PyTorch.