Phase 07 - Lesson 02
Self-Attention desde Cero
This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.
La atención es una tabla de búsqueda donde cada palabra pregunta "¿quién me importa?" - y aprende la respuesta.
Tipo: Build Lenguajes: Python Prerrequisitos: Fase 3 (Núcleo de Deep Learning), Fase 5 Lección 10 (Secuencia-a-Secuencia) Tiempo: ~90 minutos
Objetivos de Aprendizaje
- Implementar self-attention con producto punto escalado desde cero usando solo NumPy, incluyendo las proyecciones de query/key/value y la suma ponderada por softmax
- Construir una capa de multi-head attention que divide las cabezas, calcula la atención en paralelo y concatena los resultados
- Rastrear cómo la matriz de atención captura las relaciones entre tokens y explicar por qué escalar por sqrt(d_k) evita la saturación del softmax
- Aplicar enmascaramiento causal para convertir la atención bidireccional en atención autorregresiva (estilo decoder)
El Problema
Las RNN procesan secuencias un token a la vez. Para cuando llegas al token 50, la información del token 1 ya pasó por 50 pasos de compresión. Las dependencias de largo alcance terminan aplastadas en un estado oculto de tamaño fijo - un cuello de botella que ninguna cantidad de gating de LSTM resuelve por completo.
El paper de atención de Bahdanau de 2014 mostró la solución: dejar que el decoder mire hacia atrás a cada posición del encoder y decida cuáles importan para el paso actual. Pero seguía estando atornillado a una RNN. El paper de 2017 "Attention Is All You Need" planteó una pregunta más aguda: ¿qué pasa si la atención es el único mecanismo? Sin recurrencia. Sin convolución. Solo atención.
La self-attention permite que cada posición de una secuencia atienda a todas las demás posiciones en un solo paso paralelo. Eso es lo que hace que los transformers sean rápidos, escalables y dominantes.
El Concepto
La Analogía de la Búsqueda en Base de Datos
Piensa en la atención como una búsqueda suave en una base de datos:
Traditional database:
Query: "capital of France" --> exact match --> "Paris"
Attention:
Query: "capital of France" --> similarity to ALL keys --> weighted blend of ALL values
Cada token genera tres vectores:
- Query (Q): "¿Qué estoy buscando?"
- Key (K): "¿Qué contengo?"
- Value (V): "¿Qué información proporciono si soy seleccionado?"
El producto punto entre una query y todas las keys produce los puntajes de atención. Un puntaje alto significa "esta key coincide con mi query". Esos puntajes ponderan los values. La salida es una suma ponderada de los values.
Cálculo de Q, K, V
Cada embedding de token se proyecta a través de tres matrices de pesos aprendidas:
Input embeddings (sequence of n tokens, each d-dimensional):
X = [x1, x2, x3, ..., xn] shape: (n, d)
Three weight matrices:
Wq shape: (d, dk)
Wk shape: (d, dk)
Wv shape: (d, dv)
Projections:
Q = X @ Wq shape: (n, dk) each token's query
K = X @ Wk shape: (n, dk) each token's key
V = X @ Wv shape: (n, dv) each token's value
Visualmente, para un token:
Wq
x_i ------[*]------> q_i "What am I looking for?"
|
| Wk
+----[*]------> k_i "What do I contain?"
|
| Wv
+----[*]------> v_i "What do I offer?"
La Matriz de Atención
Una vez que tienes Q, K, V para todos los tokens, los puntajes de atención forman una matriz:
Scores = Q @ K^T shape: (n, n)
k1 k2 k3 k4 k5
+-----+-----+-----+-----+-----+
q1 | 2.1 | 0.3 | 0.1 | 0.8 | 0.2 | <- how much q1 attends to each key
+-----+-----+-----+-----+-----+
q2 | 0.4 | 1.9 | 0.7 | 0.1 | 0.3 |
+-----+-----+-----+-----+-----+
q3 | 0.2 | 0.6 | 2.3 | 0.5 | 0.1 |
+-----+-----+-----+-----+-----+
q4 | 0.9 | 0.1 | 0.4 | 1.7 | 0.6 |
+-----+-----+-----+-----+-----+
q5 | 0.1 | 0.3 | 0.2 | 0.5 | 2.0 |
+-----+-----+-----+-----+-----+
Each row: one token's attention over the entire sequence
¿Por Qué Escalar?
Los productos punto crecen con la dimensión dk. Si dk = 64, los productos punto pueden estar en el rango de las decenas, empujando al softmax hacia regiones donde los gradientes se desvanecen. La solución: dividir por sqrt(dk).
Scaled scores = (Q @ K^T) / sqrt(dk)
Esto mantiene los valores en un rango donde el softmax produce gradientes útiles.
El Softmax Convierte Puntajes en Pesos
El softmax convierte los puntajes brutos en una distribución de probabilidad a lo largo de cada fila:
Raw scores for q1: [2.1, 0.3, 0.1, 0.8, 0.2]
|
softmax
|
Attention weights: [0.52, 0.09, 0.07, 0.14, 0.08] (sums to ~1.0)
Ahora cada token tiene un conjunto de pesos que indican cuánto atender a cada otro token.
Suma Ponderada de los Values
La salida final para cada token es una suma ponderada de todos los vectores de value:
output_i = sum( attention_weight[i][j] * v_j for all j )
For token 1:
output_1 = 0.52 * v1 + 0.09 * v2 + 0.07 * v3 + 0.14 * v4 + 0.08 * v5
Pipeline Completo
+-------+
X (input) ----->| @ Wq |-----> Q
+-------+
+-------+
X (input) ----->| @ Wk |-----> K
+-------+ +----------+
+-------+ | |
X (input) ----->| @ Wv |-----> V ---------->| weighted |----> output
+-------+ ^ | sum |
| +----------+
+--------+--------+
| softmax |
+---------+-------+
^
+---------+-------+
| Q @ K^T / sqrt |
+-----------------+
Fórmula en una línea:
Attention(Q, K, V) = softmax( Q @ K^T / sqrt(dk) ) @ V
Constrúyelo
Paso 1: Softmax desde cero
El softmax convierte los logits brutos en probabilidades. Resta el máximo para lograr estabilidad numérica.
import numpy as np
def softmax(x):
shifted = x - np.max(x, axis=-1, keepdims=True)
exp_x = np.exp(shifted)
return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
logits = np.array([2.0, 1.0, 0.1])
print(f"logits: {logits}")
print(f"softmax: {softmax(logits)}")
print(f"sum: {softmax(logits).sum():.4f}")
Paso 2: Atención con producto punto escalado
La función central. Recibe las matrices Q, K, V y devuelve la salida de la atención más la matriz de pesos.
def scaled_dot_product_attention(Q, K, V):
dk = Q.shape[-1]
scores = Q @ K.T / np.sqrt(dk)
weights = softmax(scores)
output = weights @ V
return output, weights
Paso 3: Clase de self-attention con proyecciones aprendidas
Un módulo completo de self-attention con matrices de pesos Wq, Wk, Wv inicializadas con escalado estilo Xavier.
class SelfAttention:
def __init__(self, d_model, dk, dv, seed=42):
rng = np.random.default_rng(seed)
scale = np.sqrt(2.0 / (d_model + dk))
self.Wq = rng.normal(0, scale, (d_model, dk))
self.Wk = rng.normal(0, scale, (d_model, dk))
scale_v = np.sqrt(2.0 / (d_model + dv))
self.Wv = rng.normal(0, scale_v, (d_model, dv))
self.dk = dk
def forward(self, X):
Q = X @ self.Wq
K = X @ self.Wk
V = X @ self.Wv
output, weights = scaled_dot_product_attention(Q, K, V)
return output, weights
Paso 4: Ejecútalo sobre una oración
Crea embeddings ficticios para una oración y observa los pesos de atención.
sentence = ["The", "cat", "sat", "on", "the", "mat"]
n_tokens = len(sentence)
d_model = 8
dk = 4
dv = 4
rng = np.random.default_rng(42)
X = rng.normal(0, 1, (n_tokens, d_model))
attn = SelfAttention(d_model, dk, dv, seed=42)
output, weights = attn.forward(X)
print("Attention weights (each row: where that token looks):\n")
print(f"{'':>6}", end="")
for token in sentence:
print(f"{token:>6}", end="")
print()
for i, token in enumerate(sentence):
print(f"{token:>6}", end="")
for j in range(n_tokens):
w = weights[i][j]
print(f"{w:6.3f}", end="")
print()
Paso 5: Visualiza la atención con un heatmap ASCII
Mapea los pesos de atención a caracteres para obtener una visualización rápida.
def ascii_heatmap(weights, tokens, chars=" ░▒▓█"):
n = len(tokens)
print(f"\n{'':>6}", end="")
for t in tokens:
print(f"{t:>6}", end="")
print()
for i in range(n):
print(f"{tokens[i]:>6}", end="")
for j in range(n):
level = int(weights[i][j] * (len(chars) - 1) / weights.max())
level = min(level, len(chars) - 1)
print(f"{' ' + chars[level] + ' '}", end="")
print()
ascii_heatmap(weights, sentence)
Úsalo
El nn.MultiheadAttention de PyTorch hace exactamente lo que construimos, más la división en múltiples cabezas y la proyección de salida:
import torch
import torch.nn as nn
d_model = 8
n_heads = 2
seq_len = 6
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)
X_torch = torch.randn(1, seq_len, d_model)
output, attn_weights = mha(X_torch, X_torch, X_torch)
print(f"Input shape: {X_torch.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weight shape: {attn_weights.shape}")
print(f"\nAttn weights (averaged over heads):")
print(attn_weights[0].detach().numpy().round(3))
La diferencia clave: la multi-head attention ejecuta múltiples funciones de atención en paralelo, cada una con sus propias proyecciones Q, K, V de tamaño dk = d_model / n_heads, y luego concatena los resultados. Esto permite que el modelo atienda a distintos tipos de relación simultáneamente.
Entrégalo
Esta lección produce:
outputs/prompt-attention-explainer.md- un prompt para explicar la atención a través de la analogía de la búsqueda en base de datos
Ejercicios
- Modifica
scaled_dot_product_attentionpara que acepte una matriz de máscara opcional que establezca ciertas posiciones en infinito negativo antes del softmax (así es como funciona el enmascaramiento causal/de decoder) - Implementa la multi-head attention desde cero: divide Q, K, V en
n_headsbloques, ejecuta la atención en cada uno, concatena y proyecta a través de una matriz de pesos final Wo - Toma dos oraciones diferentes de la misma longitud, pásalas por la misma instancia de SelfAttention y compara sus patrones de atención. ¿Qué cambia? ¿Qué permanece igual?
Términos Clave
| Término | Lo que la gente dice | Lo que realmente significa |
|---|---|---|
| Query (Q) | "El vector de pregunta" | Una proyección aprendida de la entrada que representa qué información está buscando este token |
| Key (K) | "El vector de etiqueta" | Una proyección aprendida que representa qué información contiene este token, comparada contra las queries |
| Value (V) | "El vector de contenido" | Una proyección aprendida que lleva la información real que se agrega con base en los puntajes de atención |
| Atención con producto punto escalado | "La fórmula de la atención" | softmax(QK^T / sqrt(dk)) @ V - el escalado evita la saturación del softmax en dimensiones altas |
| Self-attention | "El token se mira a sí mismo y a los demás" | Atención donde Q, K, V provienen todos de la misma secuencia, permitiendo que cada posición atienda a todas las demás posiciones |
| Pesos de atención | "Cuánto foco" | Una distribución de probabilidad sobre las posiciones, producida por el softmax sobre los productos punto escalados |
| Multi-head attention | "Atención en paralelo" | Ejecutar múltiples funciones de atención con proyecciones diferentes y luego concatenar los resultados para obtener representaciones más ricas |
Lecturas Adicionales
- Attention Is All You Need (Vaswani et al., 2017) - el paper original del transformer
- The Illustrated Transformer (Jay Alammar) - el mejor recorrido visual de la arquitectura completa
- The Annotated Transformer (Harvard NLP) - implementación línea por línea en PyTorch con explicaciones