Phase 07 - Lesson 12
KV Cache, Flash Attention & Otimização de Inferência
This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.
El entrenamiento es paralelo y limitado por FLOPs (FLOP-bound). La inferencia es serial y limitada por memoria (memory-bound). Diferentes cuellos de botella, diferentes trucos.
Tipo: Build Lenguajes: Python Prerrequisitos: Fase 7 · 02 (Self-Attention), Fase 7 · 05 (Full Transformer), Fase 7 · 07 (GPT) Tiempo: ~75 minutos
El Problema
Un decodificador autorregresivo ingenuo realiza un trabajo de O(N²) para generar N tokens: en cada paso, vuelve a calcular la atención sobre todo el prefijo. Para una respuesta de 4K tokens, eso equivale a 16 millones de operaciones de atención, la mayoría de ellas redundantes. Cada estado oculto (hidden state) de un token de prefijo es determinista una vez calculado; solo necesitas ejecutar la query del nuevo token contra las keys y values en caché de todo lo anterior.
Además de eso, la atención en sí mueve una gran cantidad de datos. La atención estándar materializa una matriz de puntuación (score matrix) N×N, una salida de softmax N×d y una salida final N×d; demasiadas lecturas y escritas en la HBM. Para N≥2K, la atención se vuelve limitada por memoria (memory-bound) antes de limitarse por computación (FLOP-bound). Los kernels de atención clásicos subutilizan las GPU modernas entre 4 y 10 veces.
Dos optimizaciones, ambas de Dao et al., llevaron la inferencia de modelos de frontera de "lenta" a "rápida":
- KV cache. Almacena los vectores K y V de cada token del prefijo. La atención de cada nuevo token es una query contra las keys en caché. La inferencia se reduce de
O(N²)aO(N)por paso de generación. - Flash Attention. Divide la computación de la atención en bloques (tiling) para que la matriz N×N completa nunca llegue a la HBM. Todo el proceso de softmax + matmul ocurre en SRAM. Aceleración de 2 a 4 veces en tiempo real (wall-clock) en A100; de 5 a 10 veces en H100 con FP8.
Para 2026, ambas son universales. Cada stack de inferencia en producción (vLLM, TensorRT-LLM, SGLang, llama.cpp) las da por sentado. Cada modelo de frontera se distribuye con Flash Attention activado.
El Concepto
Matemática de KV cache
Por capa del decodificador, por token, por cabezal:
bytes_per_token_per_layer = 2 * d_head * dtype_size
^
K and V
Para un modelo 7B con 32 capas, 32 cabezales, d_head=128, fp16:
per token per layer = 2 * 128 * 2 = 512 bytes
per token (32 layers) = 16 KB
per 32K context = 512 MB
Para Llama 3 70B (80 capas, d_head=128, GQA con 8 cabezales KV):
per token per layer = 2 * 8 * 128 * 2 = 4096 bytes (4 KB)
per 32K context = 10.4 GB
Esos 10 GB son la razón por la cual Llama 3 70B con un contexto de 128K necesita la mayor parte de una A100 de 40 GB solo para la KV cache con un tamaño de lote (batch size) de 1.
GQA es la gran victoria para la KV cache. MHA con 64 cabezales requeriría 32 GB. MLA comprime aún más.
Flash Attention — el truco de la división en bloques (tiling)
Standard attention:
S = Q @ K^T (HBM read, N×N, HBM write)
P = softmax(S) (HBM read, HBM write)
O = P @ V (HBM read, HBM write)
Tres viajes de ida y vuelta a la HBM. En la H100, el ancho de banda de la HBM es de 3 TB/s; el de SRAM es de 30 TB/s. Cada viaje a la HBM representa una ralentización de 10 veces en comparación con mantener todo en el chip.
Flash Attention:
for each block of Q (tile size ~128 × 128):
load Q_tile into SRAM
for each block of K, V:
load K_tile, V_tile into SRAM
compute S_tile = Q_tile @ K_tile^T (SRAM)
running softmax aggregation (SRAM)
accumulate into O_tile (SRAM)
write O_tile to HBM
Un viaje a la HBM por bloque (tile). La huella de memoria total cae de O(N²) a O(N). El paso de retropropagación (backward pass) recalcula algunos valores del paso de propagación (forward pass) en lugar de almacenarlos; otra victoria de memoria.
Truco numérico. El softmax dinámico (running softmax) mantiene (max, sum) a lo largo de los bloques para que la normalización final sea exacta. No es una aproximación: Flash Attention calcula una salida idéntica bit a bit a la atención estándar (módulo la no asociatividad de fp16).
Evolución de versiones:
| Versión | Año | Cambio clave | Aceleración en hardware de referencia |
|---|---|---|---|
| Flash 1 | 2022 | Kernel de SRAM dividido en bloques (tiled) | 2× en A100 |
| Flash 2 | 2023 | Mejor paralelismo, ordenación causal-first | 3× en A100 |
| Flash 3 | 2024 | Asincronía de Hopper, FP8 | 1.5–2× en H100 (~740 TFLOPs FP16) |
| Flash 4 | 2026 | Pipeline de 5 etapas de Blackwell, exp2 por software | Enfoque inicial en inferencia (inicialmente solo forward pass) |
Flash 4 admite solo el forward-pass en su lanzamiento. El entrenamiento aún utiliza Flash 3. El soporte para GQA y varlen para Flash 4 está pendiente (mediados de 2026).
Decodificación especulativa — la otra victoria en latencia
Un modelo económico propone N tokens. El modelo grande verifica los N en paralelo. Si la verificación acepta k tokens, pagaste 1 forward pass del modelo grande por k generaciones. Un k típico es de 3 a 5 en código y prosa.
Valores predeterminados en 2026:
- EAGLE 2 / Medusa. Cabezales de borrador (draft heads) integrados que comparten los estados ocultos del verificador. Aceleración de 2 a 3 veces sin pérdida de calidad.
- Decodificación especulativa con modelo de borrador (draft model). Aceleración de 2 a 4 veces en hardware de consumo.
- Lookahead decoding. Iteración de Jacobi; no se necesita modelo de borrador. De nicho pero gratuito.
Continuous batching (Loteado continuo)
Inferencia por lotes clásica: espera a que termine la secuencia más lenta para iniciar un nuevo lote. Desperdicia GPU cuando las respuestas cortas terminan antes.
Continuous batching (introducido primero en Orca, ahora en vLLM, TensorRT-LLM, SGLang): intercambia nuevas solicitudes en el lote tan pronto como terminan las anteriores. Incremento de 5 a 10 veces en el rendimiento (throughput) para cargas de trabajo de chat típicas.
PagedAttention — KV cache como memoria virtual
La característica estrella de vLLM. La KV cache se asigna en bloques de 16 tokens; una tabla de páginas mapeia posiciones lógicas a bloques físicos. Permite compartir KV entre muestras paralelas (beam search, muestreo paralelo), realizar hot-swap de prefijos para caché de prompts y desfragmentar la memoria. Mejora del rendimiento (throughput) de 4 veces en comparación con la asignación contigua ingenua.
Constrúyelo
Consulta code/main.py. Implementamos:
- Un decodificador incremental ingenuo
O(N²). - Un decodificador con KV cache
O(N). - Un softmax dividido en bloques (tiled) que simula el algoritmo de running-max de Flash Attention.
Paso 1: KV cache
class KVCache:
def __init__(self, n_layers, n_heads, d_head):
self.K = [[[] for _ in range(n_heads)] for _ in range(n_layers)]
self.V = [[[] for _ in range(n_heads)] for _ in range(n_layers)]
def append(self, layer, head, k, v):
self.K[layer][head].append(k)
self.V[layer][head].append(v)
def read(self, layer, head):
return self.K[layer][head], self.V[layer][head]
Simple: continúa expandiendo los vectores K, V de cada token en listas por capa y por cabezal.
Paso 2: softmax dividido en bloques (tiled softmax)
def tiled_softmax_dot(q, K, V, tile=4):
"""Flash-attention-style softmax(qK^T)V with running max/sum."""
m = float("-inf")
s = 0.0
out = [0.0] * len(V[0])
for start in range(0, len(K), tile):
k_block = K[start:start + tile]
v_block = V[start:start + tile]
scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_block]
new_m = max(m, *scores)
exp_old = math.exp(m - new_m) if m != float("-inf") else 0.0
exp_new = [math.exp(sc - new_m) for sc in scores]
s = s * exp_old + sum(exp_new)
for j in range(len(out)):
out[j] = out[j] * exp_old + sum(e * v[j] for e, v in zip(exp_new, v_block))
m = new_m
return [o / s for o in out]
Salida idéntica bit a bit a softmax(qK) V generada de una sola vez, pero, en cualquier momento, el conjunto de trabajo es un bloque tile × d_head, no el N × d_head completo.
Paso 3: comparar la decodificación ingenua frente a la decodificación con caché en una generación de 100 tokens
Conteo de operaciones de atención. Ingenua: O(N²) = 5050. Con caché: O(N) = 100. El código imprime ambas.
Úsalo
# HuggingFace transformers auto-enables KV cache on decoder-only generate().
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-3B",
attn_implementation="flash_attention_2", # use FA3 if Hopper
torch_dtype="bfloat16",
)
# generate() uses KV cache automatically
vLLM en producción:
pip install vllm
vllm serve meta-llama/Llama-3.1-70B-Instruct \
--tensor-parallel-size 4 \
--max-model-len 32768 \
--enable-prefix-caching \
--kv-cache-dtype fp8
El almacenamiento en caché de prefijos (prefix caching) entre solicitudes es una gran victoria en 2026: el mismo prompt del sistema, ejemplos few-shot o documentos de contexto largo reutilizan KV entre llamadas. Para cargas de trabajo de agentes con prompts de herramientas repetidos, el almacenamiento en caché de prefijos genera rutinariamente una ganancia de rendimiento (throughput) de 5 veces.
Envía a Producción
Consulta outputs/skill-inference-optimizer.md. La skill elige la implementación de la atención, la estrategia de KV cache, la cuantificación y la decodificación especulativa para un nuevo despliegue de inferencia.
Ejercicios
- Fácil. Ejecuta
code/main.py. Confirma que los decodificadores ingenuo y con caché producen la misma salida; observa la diferencia en el conteo de operaciones. - Medio. Implementa el almacenamiento en caché de prefijos (prefix caching): dado un prompt P y varias finalizaciones (completions), ejecuta un forward pass sobre P para llenar la KV cache, luego bifurca por finalización. Mide la aceleración frente a volver a codificar P para cada una.
- Difícil. Implementa un PagedAttention de juguete: KV cache en bloques fijos de 16 tokens con una lista libre (free-list). Cuando termine una secuencia, devuelve sus bloques al pool. Simule 1,000 finalizaciones de chat con longitudes variables. Compara la fragmentación de la memoria frente a la asignación contigua.
Key Terms
| Término | Lo que dice la gente | Lo que realmente significa |
|---|---|---|
| KV cache | "El truco que hace rápida la decodificación" | Vectores K y V almacenados de cada token de prefijo; las nuevas queries aplican atención a ellos en lugar de volver a calcularlos. |
| HBM | "Memoria principal de la GPU" | High Bandwidth Memory (Memoria de Alto Ancho de Banda); 80 GB en H100, 192 GB en B200. Ancho de banda de ~3 TB/s. |
| SRAM | "Memoria integrada (en el chip)" | Memoria rápida por SM, ~256 KB por SM en H100. Ancho de banda de ~30 TB/s. |
| Flash Attention | "Kernel de atención dividido en bloques" | Calcula la atención sin materializar la matriz N×N en la HBM. |
| Continuous batching | "Loteado sin espera" | Intercambia secuencias terminadas por nuevas sin necesidad de vaciar el lote. |
| PagedAttention | "La característica principal de vLLM" | KV cache asignada en bloques fijos con una tabla de páginas; elimina la fragmentación. |
| Prefix caching | "Reutilizar prompts largos" | Almacena en caché la KV para un prefijo compartido entre solicitudes; reducción significativa de costos para agentes. |
| Speculative decoding | "Borrador + verificación" | Un modelo de borrador económico propone tokens; el modelo grande verifica k en una sola pasada. |
Lecturas Adicionales
- Dao et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness — Flash 1.
- Dao (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning — Flash 2.
- Shah et al. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision — Flash 3.
- FlashAttention-4 release notes (Dao-AILab, 2026) — pipeline de 5 etapas de Blackwell y el truco de exp2 por software; lee el README del repositorio para comprender las salvedades del lanzamiento solo para el paso de propagación (forward-only) que esta lección menciona.
- Kwon et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention — artículo de vLLM.
- Leviathan et al. (2023). Fast Inference from Transformers via Speculative Decoding — decodificación especulativa.
- Li et al. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty — artículo de EAGLE-1/2 para el enfoque de borrador integrado citado en la lección.
- Cai et al. (2024). Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads — el enfoque de Medusa referenciado junto con EAGLE.
- vLLM docs — PagedAttention — el análisis profundo canónico sobre el diseño de bloques de 16 tokens y tablas de páginas.