Phase 07 - Lesson 12

KV Cache, Flash Attention & Otimização de Inferência

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

O treinamento é paralelo e limitado por FLOPs (FLOP-bound). A inferência é serial e limitada por memória (memory-bound). Gargalos diferentes, truques diferentes.

Tipo: Build Linguagens: Python Pré-requisitos: Fase 7 · 02 (Self-Attention), Fase 7 · 05 (Full Transformer), Fase 7 · 07 (GPT) Tempo: ~75 minutos

O Problema

Um decodificador autorregressivo ingênuo realiza um trabalho de O(N²) para gerar N tokens: a cada passo, ele recalcula a atenção sobre todo o prefixo. Para uma resposta de 4K tokens, isso representa 16M de operações de atenção, a maioria delas redundante. Cada estado oculto (hidden state) de um token de prefixo é determinístico uma vez computado — você só precisa executar a query do novo token contra as keys e values armazenados em cache de tudo o que veio antes.

Além disso, a atenção em si move muitos dados. A atenção padrão materializa uma matriz de pontuação (score matrix) N×N, uma saída softmax N×d e uma saída final N×d — muitas leituras e escritas na HBM. Para N≥2K, atenção se torna limitada por memória (memory-bound) antes de se tornar limitada por computação (FLOP-bound). Os kernels de atenção clássicos subutilizam as GPUs modernas em 4 a 10 vezes.

Duas otimizações, ambas de Dao et al., levaram a inferência de modelos de fronteira de "lenta" para "rápida":

  1. KV cache. Armazena os vetores K e V de cada token do prefixo. A atenção de cada novo token é apenas uma query contra as keys armazenadas em cache. A inferência é reduzida de O(N²) para O(N) por passo de geração.
  2. Flash Attention. Divide a computação da atenção em blocos (tiling) para que a matriz N×N completa nunca acesse a HBM. Toda a operação de softmax + matmul acontece na SRAM. Aceleração de 2 a 4 vezes em tempo de relógio (wall-clock) na A100; 5 a 10 vezes na H100 com FP8.

Em 2026, ambas são universais. Toda stack de inferência em produção (vLLM, TensorRT-LLM, SGLang, llama.cpp) as pressupõe. Cada modelo de fronteira é lançado com o Flash Attention ativado.

O Conceito

KV cache growth and Flash Attention tiling

Matemática do KV cache

Por camada do decodificador, por token, por cabeça:

bytes_per_token_per_layer = 2 * d_head * dtype_size
                          ^
                          K and V

Para um modelo de 7B parâmetros com 32 camadas, 32 cabeças, d_head=128, fp16:

per token per layer = 2 * 128 * 2 = 512 bytes
per token (32 layers) = 16 KB
per 32K context = 512 MB

Para o Llama 3 70B (80 camadas, d_head=128, GQA com 8 cabeças KV):

per token per layer = 2 * 8 * 128 * 2 = 4096 bytes (4 KB)
per 32K context = 10.4 GB

Esses 10 GB são o motivo pelo qual o Llama 3 70B com contexto de 128K precisa de quase toda uma GPU A100 de 40 GB apenas para o KV cache com tamanho de lote (batch size) igual a 1.

O GQA é a grande vitória para o KV cache. O MHA com 64 cabeças exigiria 32 GB. O MLA comprime ainda mais.

Flash Attention — o truque do particionamento em blocos (tiling)

Standard attention:

S = Q @ K^T          (HBM read, N×N, HBM write)
P = softmax(S)       (HBM read, HBM write)
O = P @ V            (HBM read, HBM write)

Três viagens de ida e volta à HBM. Na H100, a largura de banda da HBM é de 3 TB/s; a da SRAM é de 30 TB/s. Cada viagem à HBM representa uma lentidão de 10 vezes em comparação a manter tudo no chip.

Flash Attention:

for each block of Q (tile size ~128 × 128):
    load Q_tile into SRAM
    for each block of K, V:
        load K_tile, V_tile into SRAM
        compute S_tile = Q_tile @ K_tile^T     (SRAM)
        running softmax aggregation             (SRAM)
        accumulate into O_tile                  (SRAM)
    write O_tile to HBM

Uma viagem à HBM por bloco (tile). A pegada total de memória cai de O(N²) para O(N). O passo de retropropagação (backward pass) recalcula alguns valores do passo de propagação (forward pass) em vez de armazená-los — outra vitória de memória.

Truque numérico. O softmax dinâmico (running softmax) mantém (max, sum) ao longo dos blocos para que a normalização final seja exata. Não é uma aproximação — o Flash Attention calcula uma saída bit a bit idêntica à atenção padrão (módulo a não associatividade de fp16).

Evolução das versões:

Versão Ano Mudança principal Aceleração no hardware de referência
Flash 1 2022 Kernel de SRAM dividido em blocos (tiled) 2× na A100
Flash 2 2023 Melhor paralelismo, ordenação causal-first 3× na A100
Flash 3 2024 Assincronia do Hopper, FP8 1.5–2× na H100 (~740 TFLOPs FP16)
Flash 4 2026 Pipeline de 5 estágios do Blackwell, exp2 por software Foco inicial em inferência (inicialmente apenas forward pass)

O Flash 4 suporta apenas o forward-pass no lançamento. O treinamento ainda usa o Flash 3. O suporte a GQA e varlen para o Flash 4 está pendente (meados de 2026).

Decodificação especulativa — a outra vitória na latência

Um modelo mais barato propõe N tokens. O modelo principal (grande) verifica todos os N em paralelo. Se a verificação aceitar k tokens, você pagou por 1 forward pass do modelo principal para k gerações. O valor típico de k é de 3 a 5 em código e prosa.

Padrões em 2026:

  • EAGLE 2 / Medusa. Cabeças de rascunho (draft heads) integradas que compartilham os estados ocultos do verificador. Aceleração de 2 a 3 vezes sem perda de qualidade.
  • Decodificação especulativa com modelo de rascunho (draft model). Aceleração de 2 a 4 vezes em hardware voltado para o consumidor final.
  • Lookahead decoding. Iteração de Jacobi; nenhum modelo de rascunho é necessário. De nicho, mas gratuito.

Continuous batching (Loteamento contínuo)

Inferência em lote clássica: espera que a sequência mais lenta termine para iniciar um novo lote. Desperdiça GPU quando respostas curtas terminam antes.

Continuous batching (lançado primeiro no Orca, agora presente no vLLM, TensorRT-LLM, SGLang): insere novas requisições no lote assim que as antigas terminam. Ganho de throughput de 5 a 10 vezes para cargas de trabalho de chat típicas.

PagedAttention — KV cache como memória virtual

O recurso principal do vLLM. O KV cache é alocado em blocos de 16 tokens; uma tabela de páginas mapeia posições lógicas para blocos físicos. Permite compartilhar o KV entre amostras paralelas (beam search, amostragem paralela), realizar o hot-swap de prefixos para cache de prompts e desfragmentar a memória. Melhoria de throughput de 4 vezes em relação à alocação contígua ingênua.

Construa

Consulte code/main.py. Nós implementamos:

  1. Um decodificador incremental ingênuo O(N²).
  2. Um decodificador com KV-cache O(N).
  3. Um softmax dividido em blocos (tiled) que simula o algoritmo de running-max do Flash Attention.

Passo 1: KV cache

class KVCache:
    def __init__(self, n_layers, n_heads, d_head):
        self.K = [[[] for _ in range(n_heads)] for _ in range(n_layers)]
        self.V = [[[] for _ in range(n_heads)] for _ in range(n_layers)]

    def append(self, layer, head, k, v):
        self.K[layer][head].append(k)
        self.V[layer][head].append(v)

    def read(self, layer, head):
        return self.K[layer][head], self.V[layer][head]

Simples: continua expandindo os vetores K, V de cada token em listas por camada e por cabeça.

Passo 2: softmax em blocos (tiled softmax)

def tiled_softmax_dot(q, K, V, tile=4):
    """Flash-attention-style softmax(qK^T)V with running max/sum."""
    m = float("-inf")
    s = 0.0
    out = [0.0] * len(V[0])
    for start in range(0, len(K), tile):
        k_block = K[start:start + tile]
        v_block = V[start:start + tile]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_block]
        new_m = max(m, *scores)
        exp_old = math.exp(m - new_m) if m != float("-inf") else 0.0
        exp_new = [math.exp(sc - new_m) for sc in scores]
        s = s * exp_old + sum(exp_new)
        for j in range(len(out)):
            out[j] = out[j] * exp_old + sum(e * v[j] for e, v in zip(exp_new, v_block))
        m = new_m
    return [o / s for o in out]

Saída bit a bit idêntica a softmax(qK) V gerada de uma só vez, mas, a qualquer momento, o conjunto de trabalho é um bloco tile × d_head, e não o N × d_head completo.

Passo 3: comparar a decodificação ingênua vs com cache em uma geração de 100 tokens

Contagem de operações de atenção. Ingênua: O(N²) = 5050. Com cache: O(N) = 100. O código imprime ambas.

Use-o

# HuggingFace transformers auto-enables KV cache on decoder-only generate().
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    attn_implementation="flash_attention_2",  # use FA3 if Hopper
    torch_dtype="bfloat16",
)
# generate() uses KV cache automatically

Produção com vLLM:

pip install vllm
vllm serve meta-llama/Llama-3.1-70B-Instruct \
    --tensor-parallel-size 4 \
    --max-model-len 32768 \
    --enable-prefix-caching \
    --kv-cache-dtype fp8

O cache de prefixo (prefix caching) entre requisições é uma grande vitória em 2026 — o mesmo prompt do sistema, exemplos few-shot ou documento de contexto longo reutilizam o KV entre chamadas. Para cargas de trabalho de agentes com prompts de ferramentas repetidos, o cache de prefixo gera rotineiramente um ganho de 5 vezes no throughput.

Envie para Produção

Consulte outputs/skill-inference-optimizer.md. A skill escolhe a implementação de atenção, a estratégia de KV cache, a quantização e a decodificação especulativa para um novo deploy de inferência.

Exercícios

  1. Fácil. Execute code/main.py. Confirme que os decodificadores ingênuo e com cache produzem a mesma saída; observe a diferença na contagem de operações.
  2. Médio. Implemente o cache de prefixo (prefix caching): dados um prompt P e várias conclusões (completions), execute uma etapa de forward pass sobre P para preencher o KV cache e depois ramifique por conclusão. Meça a aceleração versus codificar novamente o prompt P para cada uma.
  3. Difícil. Implemente um PagedAttention de brinquedo: KV cache em blocos fixos de 16 tokens com uma lista de blocos livres (free-list). Quando uma sequência terminar, retorne seus blocos ao pool. Simule 1.000 conclusões de chat com comprimentos variados. Compare a fragmentação de memória com a alocação contígua.

Key Terms

Termo O que as pessoas dizem O que realmente significa
KV cache "O truque que torna a decodificação rápida" Vetores K e V armazenados de cada token de prefixo; novas queries aplicam atenção a eles em vez de recalculá-los.
HBM "Memória principal da GPU" High Bandwidth Memory (Memória de Alta Largura de Banda); 80 GB na H100, 192 GB na B200. Largura de banda de ~3 TB/s.
SRAM "Memória interna (no chip)" Memória rápida por SM, ~256 KB por SM na H100. Largura de banda de ~30 TB/s.
Flash Attention "Kernel de atenção dividido em blocos" Calcula a atenção sem materializar a matriz N×N na HBM.
Continuous batching "Loteamento sem espera" Substitui sequências concluídas por novas sem precisar esvaziar o lote.
PagedAttention "O principal recurso do vLLM" KV cache alocado em blocos fixos com uma tabela de páginas; elimina a fragmentação.
Prefix caching "Reutilizar prompts longos" Armazena o KV em cache para um prefixo compartilhado entre requisições; redução significativa de custo para agentes.
Speculative decoding "Rascunho + verificação" Um modelo de rascunho barato sugere tokens; o modelo grande verifica k deles de uma só vez.

Leituras Adicionais

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).