Phase 10 - Lesson 34
Checkpoint de Gradiente e Recomputação de Ativação
This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.
O backprop mantém todas as ativações intermediárias. Com 70B de parâmetros e 128K de contexto, isso representa 3 TB de ativações por rank. O checkpoint troca FLOPs por memória: recomputar em vez de salvar. A questão é quais segmentos descartar, e a resposta não é "todos eles".
Tipo: Build Idiomas: Python (com numpy, torch opcional) Pré-requisitos: Fase 10 Lição 04 (Pre-Training Mini-GPT), Fase 10 Lição 05 (Scaling & Distributed) Tempo: ~70 minutos
O Problema
O treinamento de um transformer armazena, para cada camada, as entradas de cada operação que é diferenciada no backward: as entradas de atenção, as projeções Q/K/V, a saída do softmax, as entradas do FFN, as saídas da normalização (norm) e a conexão residual (residual stream). Para uma camada com tamanho oculto (hidden size) d, comprimento de sequência L e lote (batch) B, isso está na ordem de 12 * B * L * d floats por camada.
Para d=8192, L=8192, B=1, isso representa 800 MB/camada em BF16. Um modelo de 64 camadas resulta em 51 GB de ativações — e isso antes de multiplicar pelo tamanho do micro-lote (microbatch size), antes de adicionar os intermediários de atenção-softmax (L^2 por cabeça) e antes de considerar as cópias parciais de paralelismo de tensor (tensor-parallel).
A conta tem dois lados: os pesos BF16 somados ao estado do otimizador podem caber em 80GB, mas as ativações ultrapassam esse limite. O checkpoint de gradiente (também conhecido como recomputação de ativação) é a solução padrão. Descarte a maioria das ativações; refaça o forward durante o backward para recuperá-las. Custo: FLOPs extras. Benefício: a memória diminui pela razão entre os segmentos de checkpoint e o total de camadas.
Feito de forma ingênua (naive), o checkpointing custa cerca de 33% a mais de FLOPs de forward por passo. Feito de forma inteligente — checkpointing seletivo conforme a "seleção inteligente" de Korthikanti et al. — você economiza 5x more memória por menos de 5% de overhead de FLOPs. E com multiplicações de matriz (matmuls) em FP8, offload de FSDP e MoE com paralelismo de especialistas (expert-parallel), isso realmente importa: você não pode se dar ao luxo de desperdiçar nem memória nem processamento.
O Conceito
O que o Backward Realmente Precisa
output = layer(input). O backward precisa de grad_input e grad_params. Para computá-los, ele precisa de:
input(para computargrad_params = input.T @ grad_outputpara camadas lineares)- alguns intermediários de derivada de ativação (a derivada de ReLU/GELU/softmax depende do valor da ativação)
O forward pass armazena esses valores automaticamente no grafo do autograd. Cada tensor.retain_grad() e cada operação que precisa de sua entrada retêm uma referência.
Checkpointing Completo Ingênuo (Naive Full Checkpointing)
Divida a rede em N segmentos. Durante o forward, armazene apenas a entrada de cada segmento. Quando o backward precisar dos intermediários, execute novamente o forward do segmento para materializá-los e, em seguida, faça a diferenciação.
Exemplo: transformer de 32 camadas dividido em 32 segmentos de 1 camada cada.
- Memória: 32 entradas de camada (pequeno) vs 32 * (volume de ativação por camada) (gigante).
- Computação extra: 1 forward extra por segmento, ou seja, ~33% a mais de FLOPs de forward no total (como o backward é 2x o forward, o passo completo se torna 1 + 1 + 2 = 4 unidades em vez de 1 + 2 = 3).
Esta é a receita original de Chen et al. 2016: um checkpoint a cada sqrt(L) camadas para equilibrar memória e computação. Para L=64, são 8 checkpoints.
Checkpointing Seletivo (Korthikanti 2022)
Nem todas as ativações custam o mesmo. A saída do softmax de atenção é B*L*L*heads e cresce quadraticamente com o comprimento da sequência. A ativação oculta do FFN é B*L*4d e cresce linearmente. Para sequências longas, o softmax domina.
O checkpoint seletivo mantém as ativações baratas de armazenar (projeções lineares, conexões residuais) e recomputa apenas as caras (atenção). Você paga o mínimo de FLOPs para recomputar, mas economiza a memória O(L^2).
O Megatron-Core implementa isso como recomputação "seletiva" de ativação. Usado na maioria dos treinamentos de modelos de fronteira de 2024 em diante.
Offload
Alternativa à recomputação: transferir ativações para a RAM da CPU entre o forward e o backward. Requer largura de banda PCIe; é benéfico quando a largura de banda ociosa excede o custo de rematerialização. Estratégias mistas são comuns: aplicar checkpoint em algumas camadas e fazer offload em outras.
O FSDP2 oferece offload como uma opção nativa de primeira classe. O offload se destaca quando a GPU está com gargalo de memória, mas a transferência CPU-GPU tem margem de manobra (headroom).
Modelo de Custo de Recomputação
FLOPs por passo com checkpointing ingênuo (naive) a cada k camadas de um total de L:
flops_fwd_normal = L * f_layer
flops_bwd_normal = 2 * L * f_layer
flops_total_normal = 3 * L * f_layer
flops_fwd_ckpt = L * f_layer
flops_recompute = L * f_layer # one extra forward per layer in the segment
flops_bwd_ckpt = 2 * L * f_layer
flops_total_ckpt = 4 * L * f_layer
overhead = 4 / 3 - 1 = 0.33 = 33%
Com o checkpointing seletivo, você recomputa apenas o kernel de atenção, não a camada inteira:
flops_recompute_selective = L * f_attention ~= L * f_layer * 0.15
overhead_selective = (3 + 0.15) / 3 - 1 = 0.05 = 5%
Modelo de Economia de Memória
Volume de ativação por camada: A. Para L camadas, memória de ativação total: L * A.
Checkpoint completo (tamanho do segmento igual a 1): armazena apenas L * input_volume (~`L * 1/10 A para um transformer padrão). Economiza ~9 * L * A * 1/10`.
Checkpoint a cada k camadas: armazena L/k * A mais o equivalente a k-1 camadas dentro do segmento ativo.
Em k = sqrt(L), a memória e o custo de recomputação escalam com sqrt(L) — o tradeoff ideal para camadas com custo uniforme.
Quando Não Usar Checkpoint
- As camadas mais internas de um estágio de pipeline que já está em execução (in-flight). Elas precisam terminar de qualquer forma.
- As primeiras e últimas camadas, caso dominem a computação do estágio (raro em transformers).
- Kernels de atenção que já usam FlashAttention — o Flash já recomputa o softmax rapidamente, de modo que um checkpoint extra no nível de camada adiciona pouco benefício.
Padrões de Implementação
- Function wrapper: envolve um segmento em
torch.utils.checkpoint.checkpoint(fn, input). O PyTorch armazena apenasinpute recomputa todo o resto no backward. - Decorator-based: rotula camadas como aptas para checkpoint; o trainer decide no momento da configuração quais segmentos serão envolvidos.
- Manual explicit recompute: você mesmo escreve o backward pass, chamando um
recompute_forwardcustomizado que duplica o forward com a entrada armazenada.
Todos os três fornecem o mesmo resultado funcional. Wrappers são o idioma padrão.
Interação com TP / PP / FP8
- Tensor parallel: as entradas do checkpoint devem ser reunidas (gather) ou redistribuídas (scatter) na recomputação; lide com o custo de comunicação.
- Pipeline parallel: o padrão típico é aplicar checkpoint no forward de cada estágio do pipeline para que os micro-lotes em ordem inversa possam reutilizar a memória de ativação.
- FP8 recompute: os históricos de amax atualizados durante a recomputação devem coincidir com os do forward original, caso contrário a escala FP8 sofrerá deriva (drift). A maioria dos frameworks tira um snapshot da escala.
Construa
Passo 1: Um Modelo de Brinquedo Com Segmentos
import numpy as np
def linear_forward(x, w, b):
return x @ w + b
def relu(x):
return np.maximum(x, 0)
def layer_forward(x, w1, b1, w2, b2):
h = relu(linear_forward(x, w1, b1))
return linear_forward(h, w2, b2)
def model_forward(x, params):
activations = [x]
h = x
for w1, b1, w2, b2 in params:
h = layer_forward(h, w1, b1, w2, b2)
activations.append(h)
return h, activations
Passo 2: Backward Ingênuo Necessitando de Todas as Ativações
def model_backward(grad_output, activations, params):
grads = [None] * len(params)
g = grad_output
for i in range(len(params) - 1, -1, -1):
w1, b1, w2, b2 = params[i]
x_in = activations[i]
h_pre = linear_forward(x_in, w1, b1)
h = relu(h_pre)
gh = g @ w2.T
gw2 = h.T @ g
gb2 = g.sum(axis=0)
g_pre = gh * (h_pre > 0)
gx = g_pre @ w1.T
gw1 = x_in.T @ g_pre
gb1 = g_pre.sum(axis=0)
grads[i] = (gw1, gb1, gw2, gb2)
g = gx
return g, grads
Passo 3: Memória para Checkpoint a Cada k
def model_forward_checkpointed(x, params, k=4):
saved_inputs = [x]
h = x
for i, (w1, b1, w2, b2) in enumerate(params):
h = layer_forward(h, w1, b1, w2, b2)
if (i + 1) % k == 0:
saved_inputs.append(h)
return h, saved_inputs
def model_backward_checkpointed(grad_output, saved_inputs, params, k=4):
grads = [None] * len(params)
g = grad_output
segments = [(j * k, min((j + 1) * k, len(params))) for j in range(len(saved_inputs))]
for seg_idx in range(len(saved_inputs) - 1, -1, -1):
start, end = segments[seg_idx]
if start >= end:
continue
x_in = saved_inputs[seg_idx]
_, seg_acts = model_forward(x_in, params[start:end])
g, seg_grads = model_backward(g, seg_acts, params[start:end])
for j, gr in enumerate(seg_grads):
grads[start + j] = gr
return g, grads
Passo 4: Modelo de Custo
def checkpoint_cost(n_layers, segment_size, flops_per_layer=1.0):
fwd = n_layers * flops_per_layer
recompute = n_layers * flops_per_layer
bwd = 2 * n_layers * flops_per_layer
return {
"fwd": fwd,
"recompute": recompute,
"bwd": bwd,
"total": fwd + recompute + bwd,
"overhead_vs_no_ckpt": (fwd + recompute + bwd) / (fwd + bwd) - 1.0,
}
def selective_checkpoint_cost(n_layers, attention_fraction=0.15,
flops_per_layer=1.0):
fwd = n_layers * flops_per_layer
recompute = n_layers * attention_fraction * flops_per_layer
bwd = 2 * n_layers * flops_per_layer
return {
"fwd": fwd,
"recompute": recompute,
"bwd": bwd,
"total": fwd + recompute + bwd,
"overhead_vs_no_ckpt": (fwd + recompute + bwd) / (fwd + bwd) - 1.0,
}
Passo 5: Estimador de Memória
def activation_memory_mb(n_layers, hidden=8192, seq=8192,
batch=1, bytes_per_value=2):
per_layer = 12 * batch * seq * hidden * bytes_per_value
return n_layers * per_layer / 1e6
def memory_after_checkpoint(n_layers, segment_size, hidden=8192,
seq=8192, batch=1, bytes_per_value=2):
n_seg = max(1, n_layers // segment_size)
saved = (n_seg + segment_size) * 1 * batch * seq * hidden * bytes_per_value
return saved / 1e6
Passo 6: Tamanho Ideal de Segmento
def optimal_segment(n_layers):
return int(round(np.sqrt(n_layers)))
Passo 7: Decisão de Checkpoint Seletivo
def should_recompute(layer_type, activation_bytes, recompute_flops_ratio):
if layer_type == "attention" and activation_bytes > 100 * 1e6:
return True
if layer_type == "ffn" and activation_bytes > 500 * 1e6:
return recompute_flops_ratio < 0.1
return False
Uso
- torch.utils.checkpoint:
from torch.utils.checkpoint import checkpoint— o wrapper canônico no PyTorch. Envolve uma função; armazena apenas as entradas e recomputa no backward. - Megatron-Core activation recomputation: suporta os modos
selective,fulleblock. Padrão em treinamentos de fronteira a partir de 2024. - FSDP2 offload:
module.to_empty(device="cpu")comoffload_policyno FSDP2 transfere (shards) as ativações para a CPU em vez de recomputá-las. - DeepSpeed ZeRO-Offload: offload para CPU de estados do otimizador e ativações, complementando o checkpointing.
Produção
Esta lição produz outputs/prompt-activation-recompute-policy.md — um prompt que recebe as configurações do seu modelo (camadas, tamanho oculto, comprimento de sequência, lote) e a memória de GPU disponível, e emite uma política de recomputação por camada (none / selective / full / offload).
Exercícios
Verifique a corretude. Execute
model_forward+model_backward(ativações completas) vsmodel_forward_checkpointed+model_backward_checkpointed(segmentos). Os gradientes dos parâmetros devem ser idênticos até a precisão da máquina.Varra o tamanho do segmento
kde 1 aL. Plote o overhead de FLOP e a memória. Encontre o "joelho" da curva.Implemente o checkpointing seletivo: armazene a entrada do módulo de atenção, mas não seus intermediários. Meça o overhead de FLOP versus o checkpointing de camada completa para um modelo de 32 camadas com seq=8192.
Adicione offload. Salve as entradas dos segmentos em um "buffer de CPU" simulado (uma lista separada). Meça a "largura de banda PCIe" como bytes/tempo e encontre o ponto de equilíbrio entre offload e recomputação.
Realize o benchmark de um transformer real em PyTorch com e sem
torch.utils.checkpoint. Meça a memória (viatorch.cuda.max_memory_allocated) e o tempo de passo (step time).
Termos-Chave
| Termo | O que dizem | O que realmente significa |
|---|---|---|
| Gradient checkpointing | "Economize memória refazendo o forward" | Armazena apenas as entradas dos segmentos; recomputa os intermediários durante o backward para obter os tensores de suporte ao gradiente |
| Activation recomputation | "O mesmo que checkpointing" | O nome com estilo de computação de alto desempenho (HPC) para a mesma técnica |
| Segment size (k) | "Quantas camadas por checkpoint" | Número de camadas cujos intermediários são descartados e rematerializados juntos |
| Selective checkpointing | "O truque de Korthikanti" | Recomputa apenas ativações caras de armazenar (atendimento/attention softmax); mantém as baratas |
| Full checkpointing | "A versão ingênua (naive)" | Recomputa os intermediários de todas as camadas em todos os segmentos |
| Block checkpointing | "Grão grosso" | Aplica checkpoint a blocos inteiros de transformer; maior granularidade |
| FLOP overhead | "O imposto de computação" | FLOPs extras por passo = (FLOPs de recomputação) / (FLOPs fwd + bwd); 33% ingênuo, 5% seletivo |
| Activation offload | "Envie para a CPU" | Move ativações para a RAM da CPU durante o intervalo forward->backward; alternativa à recomputação |
| Regra sqrt-L | "O ótimo clássico" | Para camadas de custo uniforme, o espaçamento ideal de checkpoint é de sqrt(L) camadas |
| Volume de attention-softmax | "O problema O(L^2)" | L^2 * cabeças * lote floats; domina a memória de ativação em contextos longos |
Leitura Adicional
- Chen et al., 2016 -- "Training Deep Nets with Sublinear Memory Cost" -- o artigo original que formalizou o gradient checkpointing
- Korthikanti et al., 2022 -- "Reducing Activation Recomputation in Large Transformer Models" -- recomputação de ativação seletiva e a análise formal de custo
- Pudipeddi et al., 2020 -- "Training Large Neural Networks with Constant Memory using a New Execution Algorithm" -- abordagem alternativa de memória constante por meio de rematerialização em modo reverso
- Ren et al., 2021 -- "ZeRO-Offload: Democratizing Billion-Scale Model Training" -- offload de ativação em escala
- PyTorch torch.utils.checkpoint docs -- a API padrão
- Megatron-Core activation recomputation documentation -- modos seletivo, completo e em bloco