Phase 07 - Lesson 02

Self-Attention desde Cero

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

La atención es una tabla de búsqueda donde cada palabra pregunta "¿quién me importa?" - y aprende la respuesta.

Tipo: Build Lenguajes: Python Prerrequisitos: Fase 3 (Núcleo de Deep Learning), Fase 5 Lección 10 (Secuencia-a-Secuencia) Tiempo: ~90 minutos

Objetivos de Aprendizaje

  • Implementar self-attention con producto punto escalado desde cero usando solo NumPy, incluyendo las proyecciones de query/key/value y la suma ponderada por softmax
  • Construir una capa de multi-head attention que divide las cabezas, calcula la atención en paralelo y concatena los resultados
  • Rastrear cómo la matriz de atención captura las relaciones entre tokens y explicar por qué escalar por sqrt(d_k) evita la saturación del softmax
  • Aplicar enmascaramiento causal para convertir la atención bidireccional en atención autorregresiva (estilo decoder)

El Problema

Las RNN procesan secuencias un token a la vez. Para cuando llegas al token 50, la información del token 1 ya pasó por 50 pasos de compresión. Las dependencias de largo alcance terminan aplastadas en un estado oculto de tamaño fijo - un cuello de botella que ninguna cantidad de gating de LSTM resuelve por completo.

El paper de atención de Bahdanau de 2014 mostró la solución: dejar que el decoder mire hacia atrás a cada posición del encoder y decida cuáles importan para el paso actual. Pero seguía estando atornillado a una RNN. El paper de 2017 "Attention Is All You Need" planteó una pregunta más aguda: ¿qué pasa si la atención es el único mecanismo? Sin recurrencia. Sin convolución. Solo atención.

La self-attention permite que cada posición de una secuencia atienda a todas las demás posiciones en un solo paso paralelo. Eso es lo que hace que los transformers sean rápidos, escalables y dominantes.

El Concepto

La Analogía de la Búsqueda en Base de Datos

Piensa en la atención como una búsqueda suave en una base de datos:

Traditional database:
  Query: "capital of France"  -->  exact match  -->  "Paris"

Attention:
  Query: "capital of France"  -->  similarity to ALL keys  -->  weighted blend of ALL values

Cada token genera tres vectores:

  • Query (Q): "¿Qué estoy buscando?"
  • Key (K): "¿Qué contengo?"
  • Value (V): "¿Qué información proporciono si soy seleccionado?"

El producto punto entre una query y todas las keys produce los puntajes de atención. Un puntaje alto significa "esta key coincide con mi query". Esos puntajes ponderan los values. La salida es una suma ponderada de los values.

Cálculo de Q, K, V

Cada embedding de token se proyecta a través de tres matrices de pesos aprendidas:

Input embeddings (sequence of n tokens, each d-dimensional):

  X = [x1, x2, x3, ..., xn]       shape: (n, d)

Three weight matrices:

  Wq  shape: (d, dk)
  Wk  shape: (d, dk)
  Wv  shape: (d, dv)

Projections:

  Q = X @ Wq    shape: (n, dk)      each token's query
  K = X @ Wk    shape: (n, dk)      each token's key
  V = X @ Wv    shape: (n, dv)      each token's value

Visualmente, para un token:

             Wq
  x_i ------[*]------> q_i    "What am I looking for?"
       |
       |     Wk
       +----[*]------> k_i    "What do I contain?"
       |
       |     Wv
       +----[*]------> v_i    "What do I offer?"

La Matriz de Atención

Una vez que tienes Q, K, V para todos los tokens, los puntajes de atención forman una matriz:

Scores = Q @ K^T    shape: (n, n)

              k1    k2    k3    k4    k5
        +-----+-----+-----+-----+-----+
   q1   | 2.1 | 0.3 | 0.1 | 0.8 | 0.2 |   <- how much q1 attends to each key
        +-----+-----+-----+-----+-----+
   q2   | 0.4 | 1.9 | 0.7 | 0.1 | 0.3 |
        +-----+-----+-----+-----+-----+
   q3   | 0.2 | 0.6 | 2.3 | 0.5 | 0.1 |
        +-----+-----+-----+-----+-----+
   q4   | 0.9 | 0.1 | 0.4 | 1.7 | 0.6 |
        +-----+-----+-----+-----+-----+
   q5   | 0.1 | 0.3 | 0.2 | 0.5 | 2.0 |
        +-----+-----+-----+-----+-----+

Each row: one token's attention over the entire sequence

¿Por Qué Escalar?

Los productos punto crecen con la dimensión dk. Si dk = 64, los productos punto pueden estar en el rango de las decenas, empujando al softmax hacia regiones donde los gradientes se desvanecen. La solución: dividir por sqrt(dk).

Scaled scores = (Q @ K^T) / sqrt(dk)

Esto mantiene los valores en un rango donde el softmax produce gradientes útiles.

El Softmax Convierte Puntajes en Pesos

El softmax convierte los puntajes brutos en una distribución de probabilidad a lo largo de cada fila:

Raw scores for q1:   [2.1, 0.3, 0.1, 0.8, 0.2]
                            |
                         softmax
                            |
Attention weights:   [0.52, 0.09, 0.07, 0.14, 0.08]   (sums to ~1.0)

Ahora cada token tiene un conjunto de pesos que indican cuánto atender a cada otro token.

Suma Ponderada de los Values

La salida final para cada token es una suma ponderada de todos los vectores de value:

output_i = sum( attention_weight[i][j] * v_j  for all j )

For token 1:
  output_1 = 0.52 * v1 + 0.09 * v2 + 0.07 * v3 + 0.14 * v4 + 0.08 * v5

Pipeline Completo

                    +-------+
  X (input)  ----->|  @ Wq  |-----> Q
                    +-------+
                    +-------+
  X (input)  ----->|  @ Wk  |-----> K
                    +-------+                     +----------+
                    +-------+                     |          |
  X (input)  ----->|  @ Wv  |-----> V ---------->| weighted |----> output
                    +-------+          ^          |   sum    |
                                       |          +----------+
                              +--------+--------+
                              |    softmax      |
                              +---------+-------+
                                        ^
                              +---------+-------+
                              | Q @ K^T / sqrt  |
                              +-----------------+

Fórmula en una línea:

Attention(Q, K, V) = softmax( Q @ K^T / sqrt(dk) ) @ V

Constrúyelo

Paso 1: Softmax desde cero

El softmax convierte los logits brutos en probabilidades. Resta el máximo para lograr estabilidad numérica.

import numpy as np

def softmax(x):
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

logits = np.array([2.0, 1.0, 0.1])
print(f"logits:  {logits}")
print(f"softmax: {softmax(logits)}")
print(f"sum:     {softmax(logits).sum():.4f}")

Paso 2: Atención con producto punto escalado

La función central. Recibe las matrices Q, K, V y devuelve la salida de la atención más la matriz de pesos.

def scaled_dot_product_attention(Q, K, V):
    dk = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(dk)
    weights = softmax(scores)
    output = weights @ V
    return output, weights

Paso 3: Clase de self-attention con proyecciones aprendidas

Un módulo completo de self-attention con matrices de pesos Wq, Wk, Wv inicializadas con escalado estilo Xavier.

class SelfAttention:
    def __init__(self, d_model, dk, dv, seed=42):
        rng = np.random.default_rng(seed)
        scale = np.sqrt(2.0 / (d_model + dk))
        self.Wq = rng.normal(0, scale, (d_model, dk))
        self.Wk = rng.normal(0, scale, (d_model, dk))
        scale_v = np.sqrt(2.0 / (d_model + dv))
        self.Wv = rng.normal(0, scale_v, (d_model, dv))
        self.dk = dk

    def forward(self, X):
        Q = X @ self.Wq
        K = X @ self.Wk
        V = X @ self.Wv
        output, weights = scaled_dot_product_attention(Q, K, V)
        return output, weights

Paso 4: Ejecútalo sobre una oración

Crea embeddings ficticios para una oración y observa los pesos de atención.

sentence = ["The", "cat", "sat", "on", "the", "mat"]
n_tokens = len(sentence)
d_model = 8
dk = 4
dv = 4

rng = np.random.default_rng(42)
X = rng.normal(0, 1, (n_tokens, d_model))

attn = SelfAttention(d_model, dk, dv, seed=42)
output, weights = attn.forward(X)

print("Attention weights (each row: where that token looks):\n")
print(f"{'':>6}", end="")
for token in sentence:
    print(f"{token:>6}", end="")
print()

for i, token in enumerate(sentence):
    print(f"{token:>6}", end="")
    for j in range(n_tokens):
        w = weights[i][j]
        print(f"{w:6.3f}", end="")
    print()

Paso 5: Visualiza la atención con un heatmap ASCII

Mapea los pesos de atención a caracteres para obtener una visualización rápida.

def ascii_heatmap(weights, tokens, chars=" ░▒▓█"):
    n = len(tokens)
    print(f"\n{'':>6}", end="")
    for t in tokens:
        print(f"{t:>6}", end="")
    print()

    for i in range(n):
        print(f"{tokens[i]:>6}", end="")
        for j in range(n):
            level = int(weights[i][j] * (len(chars) - 1) / weights.max())
            level = min(level, len(chars) - 1)
            print(f"{'  ' + chars[level] + '   '}", end="")
        print()

ascii_heatmap(weights, sentence)

Úsalo

El nn.MultiheadAttention de PyTorch hace exactamente lo que construimos, más la división en múltiples cabezas y la proyección de salida:

import torch
import torch.nn as nn

d_model = 8
n_heads = 2
seq_len = 6

mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)

X_torch = torch.randn(1, seq_len, d_model)

output, attn_weights = mha(X_torch, X_torch, X_torch)

print(f"Input shape:            {X_torch.shape}")
print(f"Output shape:           {output.shape}")
print(f"Attention weight shape: {attn_weights.shape}")
print(f"\nAttn weights (averaged over heads):")
print(attn_weights[0].detach().numpy().round(3))

La diferencia clave: la multi-head attention ejecuta múltiples funciones de atención en paralelo, cada una con sus propias proyecciones Q, K, V de tamaño dk = d_model / n_heads, y luego concatena los resultados. Esto permite que el modelo atienda a distintos tipos de relación simultáneamente.

Entrégalo

Esta lección produce:

  • outputs/prompt-attention-explainer.md - un prompt para explicar la atención a través de la analogía de la búsqueda en base de datos

Ejercicios

  1. Modifica scaled_dot_product_attention para que acepte una matriz de máscara opcional que establezca ciertas posiciones en infinito negativo antes del softmax (así es como funciona el enmascaramiento causal/de decoder)
  2. Implementa la multi-head attention desde cero: divide Q, K, V en n_heads bloques, ejecuta la atención en cada uno, concatena y proyecta a través de una matriz de pesos final Wo
  3. Toma dos oraciones diferentes de la misma longitud, pásalas por la misma instancia de SelfAttention y compara sus patrones de atención. ¿Qué cambia? ¿Qué permanece igual?

Términos Clave

Término Lo que la gente dice Lo que realmente significa
Query (Q) "El vector de pregunta" Una proyección aprendida de la entrada que representa qué información está buscando este token
Key (K) "El vector de etiqueta" Una proyección aprendida que representa qué información contiene este token, comparada contra las queries
Value (V) "El vector de contenido" Una proyección aprendida que lleva la información real que se agrega con base en los puntajes de atención
Atención con producto punto escalado "La fórmula de la atención" softmax(QK^T / sqrt(dk)) @ V - el escalado evita la saturación del softmax en dimensiones altas
Self-attention "El token se mira a sí mismo y a los demás" Atención donde Q, K, V provienen todos de la misma secuencia, permitiendo que cada posición atienda a todas las demás posiciones
Pesos de atención "Cuánto foco" Una distribución de probabilidad sobre las posiciones, producida por el softmax sobre los productos punto escalados
Multi-head attention "Atención en paralelo" Ejecutar múltiples funciones de atención con proyecciones diferentes y luego concatenar los resultados para obtener representaciones más ricas

Lecturas Adicionales

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).