Phase 10 - Lesson 34
Checkpoint de Gradiente y Recomputación de Activación
This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.
El backprop conserva cada activación intermedia. Con 70B de parámetros y un contexto de 128K, eso representa 3 TB de activaciones por rank. El checkpoint intercambia FLOPs por memoria: recomputar en lugar de guardar. La pregunta es qué segmentos descartar, y la respuesta no es "todos ellos".
Tipo: Build Idiomas: Python (con numpy, torch opcional) Prerrequisitos: Fase 10 Lección 04 (Pre-Training Mini-GPT), Fase 10 Lección 05 (Scaling & Distributed) Tiempo: ~70 minutos
El Problema
El entrenamiento de un transformer almacena, para cada capa, las entradas de cada operación que se diferencia en el backward: las entradas de atención, las proyecciones Q/K/V, la salida de softmax, las entradas de FFN, las salidas de normalización (norm) y el flujo residual (residual stream). Para una capa con tamaño oculto (hidden size) d, longitud de secuencia L y lote (batch) B, esto está en el orden de 12 * B * L * d floats por capa.
Para d=8192, L=8192, B=1, eso representa 800 MB/capa en BF16. Un modelo de 64 capas equivale a 51 GB de activaciones — y eso antes de multiplicar por el tamaño del micro lote (microbatch size), antes de sumar los intermediarios de atención-softmax (L^2 por cabezal) y antes de considerar las copias parciales de paralelismo de tensor (tensor-parallel).
La factura tiene dos lados: los pesos en BF16 sumados al estado del optimizador pueden caber en 80GB, pero las activaciones exceden ese límite. El checkpoint de gradiente (también conocido como recomputación de activación) es la solución estándar. Descarte la mayoría de las activaciones; vuelva a realizar el forward durante el backward para recuperarlas. Costo: FLOPs adicionales. Beneficio: la memoria disminuye por la proporción entre los segmentos de checkpoint y el total de capas.
Hecho de forma ingenua (naive), el checkpointing cuesta aproximadamente un 33% más de FLOPs de forward por paso. Hecho de forma inteligente — checkpointing selectivo según la "selección inteligente" de Korthikanti et al. — se ahorra 5 veces más memoria por menos del 5% de overhead de FLOPs. Y con multiplicaciones de matrices (matmuls) en FP8, offload de FSDP y MoE con paralelismo de expertos (expert-parallel), esto realmente importa: no puede permitirse desperdiciar ni memoria ni procesamiento.
El Concepto
Lo que el Backward Realmente Necesita
output = layer(input). El backward necesita grad_input y grad_params. Para calcularlos, se requiere:
input(para calculargrad_params = input.T @ grad_outputpara capas lineales)- algunos intermediarios de derivada de activación (la derivada de ReLU/GELU/softmax depende del valor de la activación)
El forward pass almacena estos valores automáticamente en el grafo de autograd. Cada tensor.retain_grad() y cada operación que necesita su entrada retienen una referencia.
Checkpointing Completo Ingenuo (Naive Full Checkpointing)
Divida la red en N segmentos. Durante el forward, almacene solo la entrada de cada segmento. Cuando el backward necesite los intermediarios, vuelva a ejecutar el forward del segmento para materializarlos y, luego, realice la diferenciación.
Ejemplo: transformer de 32 capas dividido en 32 segmentos de 1 capa cada.
- Memoria: 32 entradas de capa (pequeño) vs 32 * (volumen de activación por capa) (gigante).
- Cómputo extra: 1 forward adicional por segmento, es decir, ~33% más de FLOPs de forward en total (como el backward es 2x el forward, el paso completo se convierte en 1 + 1 + 2 = 4 unidades en lugar de 1 + 2 = 3).
Esta es la receta original de Chen et al. 2016: un checkpoint cada sqrt(L) capas para equilibrar memoria y cómputo. Para L=64, son 8 checkpoints.
Checkpointing Selectivo (Korthikanti 2022)
No todas las activaciones cuestan lo mismo. La salida de softmax de atención es B*L*L*heads y crece cuadráticamente con la longitud de la secuencia. La activación oculta de FFN es B*L*4d y crece linealmente. Para secuencias largas, el softmax domina.
El checkpoint selectivo conserva las activaciones baratas de almacenar (proyecciones lineales, conexiones residuales) y recomputa solo las costosas (atención). Se paga una cantidad mínima de FLOPs para recomputar, pero se ahorra la memoria O(L^2).
Megatron-Core implementa esto como recomputación de activación "selectiva". Utilizado en la mayoría de los entrenamientos de modelos de frontera a partir de 2024.
Offload
Alternativa a la recomputación: transferir activaciones a la memoria RAM de la CPU entre el forward y el backward. Requiere ancho de banda PCIe; es beneficioso cuando el ancho de banda inactivo supera el costo de la rematerialización. Las estrategias mixtas son comunes: aplicar checkpoint en algunas capas y realizar offload en otras.
FSDP2 ofrece offload como una opción nativa de primera clase. El offload se destaca cuando la GPU tiene un cuello de botella de memoria, pero la transferencia CPU-GPU tiene margen de maniobra (headroom).
Modelo de Costo de Recomputación
FLOPs por paso con checkpointing ingenuo (naive) cada k capas de un total de L:
flops_fwd_normal = L * f_layer
flops_bwd_normal = 2 * L * f_layer
flops_total_normal = 3 * L * f_layer
flops_fwd_ckpt = L * f_layer
flops_recompute = L * f_layer # one extra forward per layer in the segment
flops_bwd_ckpt = 2 * L * f_layer
flops_total_ckpt = 4 * L * f_layer
overhead = 4 / 3 - 1 = 0.33 = 33%
Con el checkpointing selectivo, se recomputa solo el kernel de atención, no la capa completa:
flops_recompute_selective = L * f_attention ~= L * f_layer * 0.15
overhead_selective = (3 + 0.15) / 3 - 1 = 0.05 = 5%
Modelo de Ahorro de Memoria
Volumen de activación por capa: A. Para L capas, memoria de activación total: L * A.
Checkpoint completo (tamanho do segmento al que se le aplica checkpoint igual a 1): almacena solo L * input_volume (~`L * 1/10 A para un transformer estándar). Ahorra ~9 * L * A * 1/10`.
Checkpoint cada k capas: almacena L/k * A más el equivalente a k-1 capas dentro del segmento activo.
En k = sqrt(L), la memoria y el costo de recomputación escalan con sqrt(L) — el tradeoff ideal para capas con costo uniforme.
Cuándo No Usar Checkpoint
- Las capas más internas de una etapa de pipeline que ya está en ejecución (in-flight). Tienen que terminar de todos modos.
- Las primeras y últimas capas si dominan el cómputo de la etapa (raro en transformers).
- Kernels de atención que ya usan FlashAttention — Flash ya recomputa el softmax rápidamente, por lo que un checkpoint adicional a nivel de capa aporta poco beneficio.
Patrones de Implementación
- Function wrapper: envuelve un segmento en
torch.utils.checkpoint.checkpoint(fn, input). PyTorch almacena soloinputy recomputa todo lo demás en el backward. - Decorator-based: etiqueta las capas como aptas para checkpoint; el trainer decide en el momento de la configuración qué segmentos se envuelven.
- Manual explicit recompute: usted mismo escribe el backward pass, llamando a un
recompute_forwardpersonalizado que duplica el forward con la entrada almacenada.
Los tres proporcionan el mismo resultado funcional. Los wrappers son el idioma estándar.
Interacción con TP / PP / FP8
- Tensor parallel: las entradas del checkpoint deben reunirse (gather) o redistribuirse (scatter) en la recomputación; gestione el costo de comunicación.
- Pipeline parallel: el patrón típico es aplicar checkpoint en el forward de cada etapa del pipeline para que los micro lotes en orden inverso puedan reutilizar la memoria de activación.
- FP8 recompute: los historiales de amax actualizados durante la recomputación deben coincidir con los del forward original, de lo contrario la escala FP8 sufrirá deriva (drift). La mayoría de los frameworks toman un snapshot de la escala.
Constrúyalo
Paso 1: Un Modelo de Juguete Con Segmentos
import numpy as np
def linear_forward(x, w, b):
return x @ w + b
def relu(x):
return np.maximum(x, 0)
def layer_forward(x, w1, b1, w2, b2):
h = relu(linear_forward(x, w1, b1))
return linear_forward(h, w2, b2)
def model_forward(x, params):
activations = [x]
h = x
for w1, b1, w2, b2 in params:
h = layer_forward(h, w1, b1, w2, b2)
activations.append(h)
return h, activations
Paso 2: Backward Ingenuo que Necesita Todas las Activaciones
def model_backward(grad_output, activations, params):
grads = [None] * len(params)
g = grad_output
for i in range(len(params) - 1, -1, -1):
w1, b1, w2, b2 = params[i]
x_in = activations[i]
h_pre = linear_forward(x_in, w1, b1)
h = relu(h_pre)
gh = g @ w2.T
gw2 = h.T @ g
gb2 = g.sum(axis=0)
g_pre = gh * (h_pre > 0)
gx = g_pre @ w1.T
gw1 = x_in.T @ g_pre
gb1 = g_pre.sum(axis=0)
grads[i] = (gw1, gb1, gw2, gb2)
g = gx
return g, grads
Paso 3: Memoria para Checkpoint Cada k
def model_forward_checkpointed(x, params, k=4):
saved_inputs = [x]
h = x
for i, (w1, b1, w2, b2) in enumerate(params):
h = layer_forward(h, w1, b1, w2, b2)
if (i + 1) % k == 0:
saved_inputs.append(h)
return h, saved_inputs
def model_backward_checkpointed(grad_output, saved_inputs, params, k=4):
grads = [None] * len(params)
g = grad_output
segments = [(j * k, min((j + 1) * k, len(params))) for j in range(len(saved_inputs))]
for seg_idx in range(len(saved_inputs) - 1, -1, -1):
start, end = segments[seg_idx]
if start >= end:
continue
x_in = saved_inputs[seg_idx]
_, seg_acts = model_forward(x_in, params[start:end])
g, seg_grads = model_backward(g, seg_acts, params[start:end])
for j, gr in enumerate(seg_grads):
grads[start + j] = gr
return g, grads
Paso 4: Modelo de Costo
def checkpoint_cost(n_layers, segment_size, flops_per_layer=1.0):
fwd = n_layers * flops_per_layer
recompute = n_layers * flops_per_layer
bwd = 2 * n_layers * flops_per_layer
return {
"fwd": fwd,
"recompute": recompute,
"bwd": bwd,
"total": fwd + recompute + bwd,
"overhead_vs_no_ckpt": (fwd + recompute + bwd) / (fwd + bwd) - 1.0,
}
def selective_checkpoint_cost(n_layers, attention_fraction=0.15,
flops_per_layer=1.0):
fwd = n_layers * flops_per_layer
recompute = n_layers * attention_fraction * flops_per_layer
bwd = 2 * n_layers * flops_per_layer
return {
"fwd": fwd,
"recompute": recompute,
"bwd": bwd,
"total": fwd + recompute + bwd,
"overhead_vs_no_ckpt": (fwd + recompute + bwd) / (fwd + bwd) - 1.0,
}
Paso 5: Estimador de Memoria
def activation_memory_mb(n_layers, hidden=8192, seq=8192,
batch=1, bytes_per_value=2):
per_layer = 12 * batch * seq * hidden * bytes_per_value
return n_layers * per_layer / 1e6
def memory_after_checkpoint(n_layers, segment_size, hidden=8192,
seq=8192, batch=1, bytes_per_value=2):
n_seg = max(1, n_layers // segment_size)
saved = (n_seg + segment_size) * 1 * batch * seq * hidden * bytes_per_value
return saved / 1e6
Paso 6: Tamaño Óptimo de Segmento
def optimal_segment(n_layers):
return int(round(np.sqrt(n_layers)))
Paso 7: Decisión de Checkpoint Selectivo
def should_recompute(layer_type, activation_bytes, recompute_flops_ratio):
if layer_type == "attention" and activation_bytes > 100 * 1e6:
return True
if layer_type == "ffn" and activation_bytes > 500 * 1e6:
return recompute_flops_ratio < 0.1
return False
Uso
- torch.utils.checkpoint:
from torch.utils.checkpoint import checkpoint— el wrapper canónico en PyTorch. Envuelve una función; almacena solo las entradas y recomputa en el backward. - Megatron-Core activation recomputation: admite los modos
selective,fullyblock. Estándar en el entrenamiento de frontera a partir de 2024. - FSDP2 offload:
module.to_empty(device="cpu")conoffload_policyen FSDP2 divide (shards) las activaciones hacia la CPU en lugar de recomputarlas. - DeepSpeed ZeRO-Offload: offload a CPU para los estados del optimizador y las activaciones, complementando el checkpointing.
Producción
Esta lección produce outputs/prompt-activation-recompute-policy.md — un prompt que recibe la configuración de su modelo (capas, tamaño oculto, longitud de secuencia, lote) y la memoria GPU disponible, y emite una política de recomputación por capa (none / selective / full / offload).
Ejercicios
Verifique la corrección. Ejecute
model_forward+model_backward(activaciones completas) vsmodel_forward_checkpointed+model_backward_checkpointed(segmentos). Los gradientes de los parámetros deben ser idénticos hasta la precisión de la máquina.Realice un barrido del tamaño del segmento
kde 1 aL. Grafique el overhead de FLOP y la memoria. Encuentre el "codo" de la curva.Implemente el checkpointing selectivo: almacene la entrada del módulo de atención pero no sus intermediarios. Mida el overhead de FLOP frente al checkpointing de capa completa para un modelo de 32 capas con seq=8192.
Agregue offload. Guarde las entradas del segmento en un "búfer de CPU" simulado (una lista separada). Mida el "ancho de banda PCIe" como bytes/tiempo y encuentre el punto de equilibrio entre offload y recomputación.
Realice un benchmark de un transformer real en PyTorch con y sin
torch.utils.checkpoint. Mida la memoria (víatorch.cuda.max_memory_allocated) y el tiempo de paso (step time).
Términos Clave
| Término | Lo que la gente dice | Lo que realmente significa |
|---|---|---|
| Gradient checkpointing | "Ahorre memoria rehaciendo el forward" | Almacena solo las entradas de los segmentos; recomputa los intermediarios durante el backward para obtener los tensores de soporte de gradiente |
| Activation recomputation | "Lo mismo que checkpointing" | El nombre con estilo de computación de alto rendimiento (HPC) para la misma técnica |
| Segment size (k) | "¿Cuántas capas por checkpoint?" | Número de capas cuyos intermediarios se descartan y rematerializan juntos |
| Selective checkpointing | "El truco de Korthikanti" | Recomputa solo activaciones costosas de almacenar (atención/attention softmax); conserva las baratas |
| Full checkpointing | "La versión ingenua (naive)" | Recomputa los intermediarios de cada capa en cada segmento |
| Block checkpointing | "Granularidad gruesa" | Aplica checkpoint a bloques completos de transformer; mayor granularidad |
| FLOP overhead | "El impuesto de cómputo" | FLOPs adicionales por paso = (FLOPs de recomputación) / (FLOPs fwd + bwd); 33% ingenuo, 5% selectivo |
| Activation offload | "Enviar a la CPU" | Mueve activaciones a la memoria RAM de la CPU durante el intervalo forward->backward; alternativa a la recomputación |
| Regla sqrt-L | "El óptimo clásico" | Para capas de costo uniforme, el espaciamiento óptimo de checkpoint es de sqrt(L) capas |
| Volumen de la attention-softmax | "El problema O(L^2)" | L^2 * cabezales * lote floats; domina la memoria de activación en contextos largos |
Lecturas Recomendadas
- Chen et al., 2016 -- "Training Deep Nets with Sublinear Memory Cost" -- el artículo original que formalizó el gradient checkpointing
- Korthikanti et al., 2022 -- "Reducing Activation Recomputation in Large Transformer Models" -- recomputación de activación selectiva y el análisis de costo formal
- Pudipeddi et al., 2020 -- "Training Large Neural Networks with Constant Memory using a New Execution Algorithm" -- enfoque alternativo de memoria constante a través de la rematerialización en modo inverso
- Ren et al., 2021 -- "ZeRO-Offload: Democratizing Billion-Scale Model Training" -- offload de activación a escala
- PyTorch torch.utils.checkpoint docs -- la API estándar
- Megatron-Core activation recomputation documentation -- modos selectivo, completo y por bloques