Phase 07 - Lesson 12
KV Cache, Flash Attention & Otimização de Inferência
This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.
O treinamento é paralelo e limitado por FLOPs (FLOP-bound). A inferência é serial e limitada por memória (memory-bound). Gargalos diferentes, truques diferentes.
Tipo: Build Linguagens: Python Pré-requisitos: Fase 7 · 02 (Self-Attention), Fase 7 · 05 (Full Transformer), Fase 7 · 07 (GPT) Tempo: ~75 minutos
O Problema
Um decodificador autorregressivo ingênuo realiza um trabalho de O(N²) para gerar N tokens: a cada passo, ele recalcula a atenção sobre todo o prefixo. Para uma resposta de 4K tokens, isso representa 16M de operações de atenção, a maioria delas redundante. Cada estado oculto (hidden state) de um token de prefixo é determinístico uma vez computado — você só precisa executar a query do novo token contra as keys e values armazenados em cache de tudo o que veio antes.
Além disso, a atenção em si move muitos dados. A atenção padrão materializa uma matriz de pontuação (score matrix) N×N, uma saída softmax N×d e uma saída final N×d — muitas leituras e escritas na HBM. Para N≥2K, atenção se torna limitada por memória (memory-bound) antes de se tornar limitada por computação (FLOP-bound). Os kernels de atenção clássicos subutilizam as GPUs modernas em 4 a 10 vezes.
Duas otimizações, ambas de Dao et al., levaram a inferência de modelos de fronteira de "lenta" para "rápida":
- KV cache. Armazena os vetores K e V de cada token do prefixo. A atenção de cada novo token é apenas uma query contra as keys armazenadas em cache. A inferência é reduzida de
O(N²)paraO(N)por passo de geração. - Flash Attention. Divide a computação da atenção em blocos (tiling) para que a matriz N×N completa nunca acesse a HBM. Toda a operação de softmax + matmul acontece na SRAM. Aceleração de 2 a 4 vezes em tempo de relógio (wall-clock) na A100; 5 a 10 vezes na H100 com FP8.
Em 2026, ambas são universais. Toda stack de inferência em produção (vLLM, TensorRT-LLM, SGLang, llama.cpp) as pressupõe. Cada modelo de fronteira é lançado com o Flash Attention ativado.
O Conceito
Matemática do KV cache
Por camada do decodificador, por token, por cabeça:
bytes_per_token_per_layer = 2 * d_head * dtype_size
^
K and V
Para um modelo de 7B parâmetros com 32 camadas, 32 cabeças, d_head=128, fp16:
per token per layer = 2 * 128 * 2 = 512 bytes
per token (32 layers) = 16 KB
per 32K context = 512 MB
Para o Llama 3 70B (80 camadas, d_head=128, GQA com 8 cabeças KV):
per token per layer = 2 * 8 * 128 * 2 = 4096 bytes (4 KB)
per 32K context = 10.4 GB
Esses 10 GB são o motivo pelo qual o Llama 3 70B com contexto de 128K precisa de quase toda uma GPU A100 de 40 GB apenas para o KV cache com tamanho de lote (batch size) igual a 1.
O GQA é a grande vitória para o KV cache. O MHA com 64 cabeças exigiria 32 GB. O MLA comprime ainda mais.
Flash Attention — o truque do particionamento em blocos (tiling)
Standard attention:
S = Q @ K^T (HBM read, N×N, HBM write)
P = softmax(S) (HBM read, HBM write)
O = P @ V (HBM read, HBM write)
Três viagens de ida e volta à HBM. Na H100, a largura de banda da HBM é de 3 TB/s; a da SRAM é de 30 TB/s. Cada viagem à HBM representa uma lentidão de 10 vezes em comparação a manter tudo no chip.
Flash Attention:
for each block of Q (tile size ~128 × 128):
load Q_tile into SRAM
for each block of K, V:
load K_tile, V_tile into SRAM
compute S_tile = Q_tile @ K_tile^T (SRAM)
running softmax aggregation (SRAM)
accumulate into O_tile (SRAM)
write O_tile to HBM
Uma viagem à HBM por bloco (tile). A pegada total de memória cai de O(N²) para O(N). O passo de retropropagação (backward pass) recalcula alguns valores do passo de propagação (forward pass) em vez de armazená-los — outra vitória de memória.
Truque numérico. O softmax dinâmico (running softmax) mantém (max, sum) ao longo dos blocos para que a normalização final seja exata. Não é uma aproximação — o Flash Attention calcula uma saída bit a bit idêntica à atenção padrão (módulo a não associatividade de fp16).
Evolução das versões:
| Versão | Ano | Mudança principal | Aceleração no hardware de referência |
|---|---|---|---|
| Flash 1 | 2022 | Kernel de SRAM dividido em blocos (tiled) | 2× na A100 |
| Flash 2 | 2023 | Melhor paralelismo, ordenação causal-first | 3× na A100 |
| Flash 3 | 2024 | Assincronia do Hopper, FP8 | 1.5–2× na H100 (~740 TFLOPs FP16) |
| Flash 4 | 2026 | Pipeline de 5 estágios do Blackwell, exp2 por software | Foco inicial em inferência (inicialmente apenas forward pass) |
O Flash 4 suporta apenas o forward-pass no lançamento. O treinamento ainda usa o Flash 3. O suporte a GQA e varlen para o Flash 4 está pendente (meados de 2026).
Decodificação especulativa — a outra vitória na latência
Um modelo mais barato propõe N tokens. O modelo principal (grande) verifica todos os N em paralelo. Se a verificação aceitar k tokens, você pagou por 1 forward pass do modelo principal para k gerações. O valor típico de k é de 3 a 5 em código e prosa.
Padrões em 2026:
- EAGLE 2 / Medusa. Cabeças de rascunho (draft heads) integradas que compartilham os estados ocultos do verificador. Aceleração de 2 a 3 vezes sem perda de qualidade.
- Decodificação especulativa com modelo de rascunho (draft model). Aceleração de 2 a 4 vezes em hardware voltado para o consumidor final.
- Lookahead decoding. Iteração de Jacobi; nenhum modelo de rascunho é necessário. De nicho, mas gratuito.
Continuous batching (Loteamento contínuo)
Inferência em lote clássica: espera que a sequência mais lenta termine para iniciar um novo lote. Desperdiça GPU quando respostas curtas terminam antes.
Continuous batching (lançado primeiro no Orca, agora presente no vLLM, TensorRT-LLM, SGLang): insere novas requisições no lote assim que as antigas terminam. Ganho de throughput de 5 a 10 vezes para cargas de trabalho de chat típicas.
PagedAttention — KV cache como memória virtual
O recurso principal do vLLM. O KV cache é alocado em blocos de 16 tokens; uma tabela de páginas mapeia posições lógicas para blocos físicos. Permite compartilhar o KV entre amostras paralelas (beam search, amostragem paralela), realizar o hot-swap de prefixos para cache de prompts e desfragmentar a memória. Melhoria de throughput de 4 vezes em relação à alocação contígua ingênua.
Construa
Consulte code/main.py. Nós implementamos:
- Um decodificador incremental ingênuo
O(N²). - Um decodificador com KV-cache
O(N). - Um softmax dividido em blocos (tiled) que simula o algoritmo de running-max do Flash Attention.
Passo 1: KV cache
class KVCache:
def __init__(self, n_layers, n_heads, d_head):
self.K = [[[] for _ in range(n_heads)] for _ in range(n_layers)]
self.V = [[[] for _ in range(n_heads)] for _ in range(n_layers)]
def append(self, layer, head, k, v):
self.K[layer][head].append(k)
self.V[layer][head].append(v)
def read(self, layer, head):
return self.K[layer][head], self.V[layer][head]
Simples: continua expandindo os vetores K, V de cada token em listas por camada e por cabeça.
Passo 2: softmax em blocos (tiled softmax)
def tiled_softmax_dot(q, K, V, tile=4):
"""Flash-attention-style softmax(qK^T)V with running max/sum."""
m = float("-inf")
s = 0.0
out = [0.0] * len(V[0])
for start in range(0, len(K), tile):
k_block = K[start:start + tile]
v_block = V[start:start + tile]
scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_block]
new_m = max(m, *scores)
exp_old = math.exp(m - new_m) if m != float("-inf") else 0.0
exp_new = [math.exp(sc - new_m) for sc in scores]
s = s * exp_old + sum(exp_new)
for j in range(len(out)):
out[j] = out[j] * exp_old + sum(e * v[j] for e, v in zip(exp_new, v_block))
m = new_m
return [o / s for o in out]
Saída bit a bit idêntica a softmax(qK) V gerada de uma só vez, mas, a qualquer momento, o conjunto de trabalho é um bloco tile × d_head, e não o N × d_head completo.
Passo 3: comparar a decodificação ingênua vs com cache em uma geração de 100 tokens
Contagem de operações de atenção. Ingênua: O(N²) = 5050. Com cache: O(N) = 100. O código imprime ambas.
Use-o
# HuggingFace transformers auto-enables KV cache on decoder-only generate().
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-3B",
attn_implementation="flash_attention_2", # use FA3 if Hopper
torch_dtype="bfloat16",
)
# generate() uses KV cache automatically
Produção com vLLM:
pip install vllm
vllm serve meta-llama/Llama-3.1-70B-Instruct \
--tensor-parallel-size 4 \
--max-model-len 32768 \
--enable-prefix-caching \
--kv-cache-dtype fp8
O cache de prefixo (prefix caching) entre requisições é uma grande vitória em 2026 — o mesmo prompt do sistema, exemplos few-shot ou documento de contexto longo reutilizam o KV entre chamadas. Para cargas de trabalho de agentes com prompts de ferramentas repetidos, o cache de prefixo gera rotineiramente um ganho de 5 vezes no throughput.
Envie para Produção
Consulte outputs/skill-inference-optimizer.md. A skill escolhe a implementação de atenção, a estratégia de KV cache, a quantização e a decodificação especulativa para um novo deploy de inferência.
Exercícios
- Fácil. Execute
code/main.py. Confirme que os decodificadores ingênuo e com cache produzem a mesma saída; observe a diferença na contagem de operações. - Médio. Implemente o cache de prefixo (prefix caching): dados um prompt P e várias conclusões (completions), execute uma etapa de forward pass sobre P para preencher o KV cache e depois ramifique por conclusão. Meça a aceleração versus codificar novamente o prompt P para cada uma.
- Difícil. Implemente um PagedAttention de brinquedo: KV cache em blocos fixos de 16 tokens com uma lista de blocos livres (free-list). Quando uma sequência terminar, retorne seus blocos ao pool. Simule 1.000 conclusões de chat com comprimentos variados. Compare a fragmentação de memória com a alocação contígua.
Key Terms
| Termo | O que as pessoas dizem | O que realmente significa |
|---|---|---|
| KV cache | "O truque que torna a decodificação rápida" | Vetores K e V armazenados de cada token de prefixo; novas queries aplicam atenção a eles em vez de recalculá-los. |
| HBM | "Memória principal da GPU" | High Bandwidth Memory (Memória de Alta Largura de Banda); 80 GB na H100, 192 GB na B200. Largura de banda de ~3 TB/s. |
| SRAM | "Memória interna (no chip)" | Memória rápida por SM, ~256 KB por SM na H100. Largura de banda de ~30 TB/s. |
| Flash Attention | "Kernel de atenção dividido em blocos" | Calcula a atenção sem materializar a matriz N×N na HBM. |
| Continuous batching | "Loteamento sem espera" | Substitui sequências concluídas por novas sem precisar esvaziar o lote. |
| PagedAttention | "O principal recurso do vLLM" | KV cache alocado em blocos fixos com uma tabela de páginas; elimina a fragmentação. |
| Prefix caching | "Reutilizar prompts longos" | Armazena o KV em cache para um prefixo compartilhado entre requisições; redução significativa de custo para agentes. |
| Speculative decoding | "Rascunho + verificação" | Um modelo de rascunho barato sugere tokens; o modelo grande verifica k deles de uma só vez. |
Leituras Adicionais
- Dao et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness — Flash 1.
- Dao (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning — Flash 2.
- Shah et al. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision — Flash 3.
- FlashAttention-4 release notes (Dao-AILab, 2026) — pipeline de 5 estágios do Blackwell e o truque do exp2 por software; leia o README do repositório para entender as ressalvas do lançamento apenas para o passo de propagação (forward-only) que esta lição menciona.
- Kwon et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention — artigo do vLLM.
- Leviathan et al. (2023). Fast Inference from Transformers via Speculative Decoding — decodificação especulativa.
- Li et al. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty — artigo do EAGLE-1/2 para a abordagem de rascunho integrado citada na lição.
- Cai et al. (2024). Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads — a abordagem do Medusa referenciada junto com o EAGLE.
- vLLM docs — PagedAttention — o mergulho profundo canônico no design de bloco de 16 tokens e tabela de páginas.