Phase 10 - Lesson 34

Checkpoint de Gradiente e Recomputação de Ativação

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

O backprop mantém todas as ativações intermediárias. Com 70B de parâmetros e 128K de contexto, isso representa 3 TB de ativações por rank. O checkpoint troca FLOPs por memória: recomputar em vez de salvar. A questão é quais segmentos descartar, e a resposta não é "todos eles".

Tipo: Build Idiomas: Python (com numpy, torch opcional) Pré-requisitos: Fase 10 Lição 04 (Pre-Training Mini-GPT), Fase 10 Lição 05 (Scaling & Distributed) Tempo: ~70 minutos

O Problema

O treinamento de um transformer armazena, para cada camada, as entradas de cada operação que é diferenciada no backward: as entradas de atenção, as projeções Q/K/V, a saída do softmax, as entradas do FFN, as saídas da normalização (norm) e a conexão residual (residual stream). Para uma camada com tamanho oculto (hidden size) d, comprimento de sequência L e lote (batch) B, isso está na ordem de 12 * B * L * d floats por camada.

Para d=8192, L=8192, B=1, isso representa 800 MB/camada em BF16. Um modelo de 64 camadas resulta em 51 GB de ativações — e isso antes de multiplicar pelo tamanho do micro-lote (microbatch size), antes de adicionar os intermediários de atenção-softmax (L^2 por cabeça) e antes de considerar as cópias parciais de paralelismo de tensor (tensor-parallel).

A conta tem dois lados: os pesos BF16 somados ao estado do otimizador podem caber em 80GB, mas as ativações ultrapassam esse limite. O checkpoint de gradiente (também conhecido como recomputação de ativação) é a solução padrão. Descarte a maioria das ativações; refaça o forward durante o backward para recuperá-las. Custo: FLOPs extras. Benefício: a memória diminui pela razão entre os segmentos de checkpoint e o total de camadas.

Feito de forma ingênua (naive), o checkpointing custa cerca de 33% a mais de FLOPs de forward por passo. Feito de forma inteligente — checkpointing seletivo conforme a "seleção inteligente" de Korthikanti et al. — você economiza 5x more memória por menos de 5% de overhead de FLOPs. E com multiplicações de matriz (matmuls) em FP8, offload de FSDP e MoE com paralelismo de especialistas (expert-parallel), isso realmente importa: você não pode se dar ao luxo de desperdiçar nem memória nem processamento.

O Conceito

O que o Backward Realmente Precisa

output = layer(input). O backward precisa de grad_input e grad_params. Para computá-los, ele precisa de:

  • input (para computar grad_params = input.T @ grad_output para camadas lineares)
  • alguns intermediários de derivada de ativação (a derivada de ReLU/GELU/softmax depende do valor da ativação)

O forward pass armazena esses valores automaticamente no grafo do autograd. Cada tensor.retain_grad() e cada operação que precisa de sua entrada retêm uma referência.

Checkpointing Completo Ingênuo (Naive Full Checkpointing)

Divida a rede em N segmentos. Durante o forward, armazene apenas a entrada de cada segmento. Quando o backward precisar dos intermediários, execute novamente o forward do segmento para materializá-los e, em seguida, faça a diferenciação.

Exemplo: transformer de 32 camadas dividido em 32 segmentos de 1 camada cada.

  • Memória: 32 entradas de camada (pequeno) vs 32 * (volume de ativação por camada) (gigante).
  • Computação extra: 1 forward extra por segmento, ou seja, ~33% a mais de FLOPs de forward no total (como o backward é 2x o forward, o passo completo se torna 1 + 1 + 2 = 4 unidades em vez de 1 + 2 = 3).

Esta é a receita original de Chen et al. 2016: um checkpoint a cada sqrt(L) camadas para equilibrar memória e computação. Para L=64, são 8 checkpoints.

Checkpointing Seletivo (Korthikanti 2022)

Nem todas as ativações custam o mesmo. A saída do softmax de atenção é B*L*L*heads e cresce quadraticamente com o comprimento da sequência. A ativação oculta do FFN é B*L*4d e cresce linearmente. Para sequências longas, o softmax domina.

O checkpoint seletivo mantém as ativações baratas de armazenar (projeções lineares, conexões residuais) e recomputa apenas as caras (atenção). Você paga o mínimo de FLOPs para recomputar, mas economiza a memória O(L^2).

O Megatron-Core implementa isso como recomputação "seletiva" de ativação. Usado na maioria dos treinamentos de modelos de fronteira de 2024 em diante.

Offload

Alternativa à recomputação: transferir ativações para a RAM da CPU entre o forward e o backward. Requer largura de banda PCIe; é benéfico quando a largura de banda ociosa excede o custo de rematerialização. Estratégias mistas são comuns: aplicar checkpoint em algumas camadas e fazer offload em outras.

O FSDP2 oferece offload como uma opção nativa de primeira classe. O offload se destaca quando a GPU está com gargalo de memória, mas a transferência CPU-GPU tem margem de manobra (headroom).

Modelo de Custo de Recomputação

FLOPs por passo com checkpointing ingênuo (naive) a cada k camadas de um total de L:

flops_fwd_normal = L * f_layer
flops_bwd_normal = 2 * L * f_layer
flops_total_normal = 3 * L * f_layer

flops_fwd_ckpt = L * f_layer
flops_recompute = L * f_layer  # one extra forward per layer in the segment
flops_bwd_ckpt = 2 * L * f_layer
flops_total_ckpt = 4 * L * f_layer
overhead = 4 / 3 - 1 = 0.33 = 33%

Com o checkpointing seletivo, você recomputa apenas o kernel de atenção, não a camada inteira:

flops_recompute_selective = L * f_attention ~= L * f_layer * 0.15
overhead_selective = (3 + 0.15) / 3 - 1 = 0.05 = 5%

Modelo de Economia de Memória

Volume de ativação por camada: A. Para L camadas, memória de ativação total: L * A.

Checkpoint completo (tamanho do segmento igual a 1): armazena apenas L * input_volume (~`L * 1/10 A para um transformer padrão). Economiza ~9 * L * A * 1/10`.

Checkpoint a cada k camadas: armazena L/k * A mais o equivalente a k-1 camadas dentro do segmento ativo.

Em k = sqrt(L), a memória e o custo de recomputação escalam com sqrt(L) — o tradeoff ideal para camadas com custo uniforme.

Quando Não Usar Checkpoint

  • As camadas mais internas de um estágio de pipeline que já está em execução (in-flight). Elas precisam terminar de qualquer forma.
  • As primeiras e últimas camadas, caso dominem a computação do estágio (raro em transformers).
  • Kernels de atenção que já usam FlashAttention — o Flash já recomputa o softmax rapidamente, de modo que um checkpoint extra no nível de camada adiciona pouco benefício.

Padrões de Implementação

  1. Function wrapper: envolve um segmento em torch.utils.checkpoint.checkpoint(fn, input). O PyTorch armazena apenas input e recomputa todo o resto no backward.
  2. Decorator-based: rotula camadas como aptas para checkpoint; o trainer decide no momento da configuração quais segmentos serão envolvidos.
  3. Manual explicit recompute: você mesmo escreve o backward pass, chamando um recompute_forward customizado que duplica o forward com a entrada armazenada.

Todos os três fornecem o mesmo resultado funcional. Wrappers são o idioma padrão.

Interação com TP / PP / FP8

  • Tensor parallel: as entradas do checkpoint devem ser reunidas (gather) ou redistribuídas (scatter) na recomputação; lide com o custo de comunicação.
  • Pipeline parallel: o padrão típico é aplicar checkpoint no forward de cada estágio do pipeline para que os micro-lotes em ordem inversa possam reutilizar a memória de ativação.
  • FP8 recompute: os históricos de amax atualizados durante a recomputação devem coincidir com os do forward original, caso contrário a escala FP8 sofrerá deriva (drift). A maioria dos frameworks tira um snapshot da escala.

Construa

Passo 1: Um Modelo de Brinquedo Com Segmentos

import numpy as np


def linear_forward(x, w, b):
    return x @ w + b


def relu(x):
    return np.maximum(x, 0)


def layer_forward(x, w1, b1, w2, b2):
    h = relu(linear_forward(x, w1, b1))
    return linear_forward(h, w2, b2)


def model_forward(x, params):
    activations = [x]
    h = x
    for w1, b1, w2, b2 in params:
        h = layer_forward(h, w1, b1, w2, b2)
        activations.append(h)
    return h, activations

Passo 2: Backward Ingênuo Necessitando de Todas as Ativações

def model_backward(grad_output, activations, params):
    grads = [None] * len(params)
    g = grad_output
    for i in range(len(params) - 1, -1, -1):
        w1, b1, w2, b2 = params[i]
        x_in = activations[i]
        h_pre = linear_forward(x_in, w1, b1)
        h = relu(h_pre)
        gh = g @ w2.T
        gw2 = h.T @ g
        gb2 = g.sum(axis=0)
        g_pre = gh * (h_pre > 0)
        gx = g_pre @ w1.T
        gw1 = x_in.T @ g_pre
        gb1 = g_pre.sum(axis=0)
        grads[i] = (gw1, gb1, gw2, gb2)
        g = gx
    return g, grads

Passo 3: Memória para Checkpoint a Cada k

def model_forward_checkpointed(x, params, k=4):
    saved_inputs = [x]
    h = x
    for i, (w1, b1, w2, b2) in enumerate(params):
        h = layer_forward(h, w1, b1, w2, b2)
        if (i + 1) % k == 0:
            saved_inputs.append(h)
    return h, saved_inputs


def model_backward_checkpointed(grad_output, saved_inputs, params, k=4):
    grads = [None] * len(params)
    g = grad_output
    segments = [(j * k, min((j + 1) * k, len(params))) for j in range(len(saved_inputs))]
    for seg_idx in range(len(saved_inputs) - 1, -1, -1):
        start, end = segments[seg_idx]
        if start >= end:
            continue
        x_in = saved_inputs[seg_idx]
        _, seg_acts = model_forward(x_in, params[start:end])
        g, seg_grads = model_backward(g, seg_acts, params[start:end])
        for j, gr in enumerate(seg_grads):
            grads[start + j] = gr
    return g, grads

Passo 4: Modelo de Custo

def checkpoint_cost(n_layers, segment_size, flops_per_layer=1.0):
    fwd = n_layers * flops_per_layer
    recompute = n_layers * flops_per_layer
    bwd = 2 * n_layers * flops_per_layer
    return {
        "fwd": fwd,
        "recompute": recompute,
        "bwd": bwd,
        "total": fwd + recompute + bwd,
        "overhead_vs_no_ckpt": (fwd + recompute + bwd) / (fwd + bwd) - 1.0,
    }


def selective_checkpoint_cost(n_layers, attention_fraction=0.15,
                              flops_per_layer=1.0):
    fwd = n_layers * flops_per_layer
    recompute = n_layers * attention_fraction * flops_per_layer
    bwd = 2 * n_layers * flops_per_layer
    return {
        "fwd": fwd,
        "recompute": recompute,
        "bwd": bwd,
        "total": fwd + recompute + bwd,
        "overhead_vs_no_ckpt": (fwd + recompute + bwd) / (fwd + bwd) - 1.0,
    }

Passo 5: Estimador de Memória

def activation_memory_mb(n_layers, hidden=8192, seq=8192,
                        batch=1, bytes_per_value=2):
    per_layer = 12 * batch * seq * hidden * bytes_per_value
    return n_layers * per_layer / 1e6


def memory_after_checkpoint(n_layers, segment_size, hidden=8192,
                           seq=8192, batch=1, bytes_per_value=2):
    n_seg = max(1, n_layers // segment_size)
    saved = (n_seg + segment_size) * 1 * batch * seq * hidden * bytes_per_value
    return saved / 1e6

Passo 6: Tamanho Ideal de Segmento

def optimal_segment(n_layers):
    return int(round(np.sqrt(n_layers)))

Passo 7: Decisão de Checkpoint Seletivo

def should_recompute(layer_type, activation_bytes, recompute_flops_ratio):
    if layer_type == "attention" and activation_bytes > 100 * 1e6:
        return True
    if layer_type == "ffn" and activation_bytes > 500 * 1e6:
        return recompute_flops_ratio < 0.1
    return False

Uso

  • torch.utils.checkpoint: from torch.utils.checkpoint import checkpoint — o wrapper canônico no PyTorch. Envolve uma função; armazena apenas as entradas e recomputa no backward.
  • Megatron-Core activation recomputation: suporta os modos selective, full e block. Padrão em treinamentos de fronteira a partir de 2024.
  • FSDP2 offload: module.to_empty(device="cpu") com offload_policy no FSDP2 transfere (shards) as ativações para a CPU em vez de recomputá-las.
  • DeepSpeed ZeRO-Offload: offload para CPU de estados do otimizador e ativações, complementando o checkpointing.

Produção

Esta lição produz outputs/prompt-activation-recompute-policy.md — um prompt que recebe as configurações do seu modelo (camadas, tamanho oculto, comprimento de sequência, lote) e a memória de GPU disponível, e emite uma política de recomputação por camada (none / selective / full / offload).

Exercícios

  1. Verifique a corretude. Execute model_forward + model_backward (ativações completas) vs model_forward_checkpointed + model_backward_checkpointed (segmentos). Os gradientes dos parâmetros devem ser idênticos até a precisão da máquina.

  2. Varra o tamanho do segmento k de 1 a L. Plote o overhead de FLOP e a memória. Encontre o "joelho" da curva.

  3. Implemente o checkpointing seletivo: armazene a entrada do módulo de atenção, mas não seus intermediários. Meça o overhead de FLOP versus o checkpointing de camada completa para um modelo de 32 camadas com seq=8192.

  4. Adicione offload. Salve as entradas dos segmentos em um "buffer de CPU" simulado (uma lista separada). Meça a "largura de banda PCIe" como bytes/tempo e encontre o ponto de equilíbrio entre offload e recomputação.

  5. Realize o benchmark de um transformer real em PyTorch com e sem torch.utils.checkpoint. Meça a memória (via torch.cuda.max_memory_allocated) e o tempo de passo (step time).

Termos-Chave

Termo O que dizem O que realmente significa
Gradient checkpointing "Economize memória refazendo o forward" Armazena apenas as entradas dos segmentos; recomputa os intermediários durante o backward para obter os tensores de suporte ao gradiente
Activation recomputation "O mesmo que checkpointing" O nome com estilo de computação de alto desempenho (HPC) para a mesma técnica
Segment size (k) "Quantas camadas por checkpoint" Número de camadas cujos intermediários são descartados e rematerializados juntos
Selective checkpointing "O truque de Korthikanti" Recomputa apenas ativações caras de armazenar (atendimento/attention softmax); mantém as baratas
Full checkpointing "A versão ingênua (naive)" Recomputa os intermediários de todas as camadas em todos os segmentos
Block checkpointing "Grão grosso" Aplica checkpoint a blocos inteiros de transformer; maior granularidade
FLOP overhead "O imposto de computação" FLOPs extras por passo = (FLOPs de recomputação) / (FLOPs fwd + bwd); 33% ingênuo, 5% seletivo
Activation offload "Envie para a CPU" Move ativações para a RAM da CPU durante o intervalo forward->backward; alternativa à recomputação
Regra sqrt-L "O ótimo clássico" Para camadas de custo uniforme, o espaçamento ideal de checkpoint é de sqrt(L) camadas
Volume de attention-softmax "O problema O(L^2)" L^2 * cabeças * lote floats; domina a memória de ativação em contextos longos

Leitura Adicional

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).