Phase 10 - Lesson 12

Optimización de Inferencia

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

Dos fases definen la inferencia de LLM. Prefill procesa su prompt en paralelo -- limitado por computación (compute-bound). Decode genera tokens de uno en uno -- limitado por memoria (memory-bound). Cada optimización se enfoca en una o ambas.

Tipo: Build Idiomas: Python Prerrequisitos: Fase 10, Lecciones 01-08 (arquitectura Transformer, atención) Tiempo: ~120 minutos

Objetivos de Aprendizaje

  • Implementar KV-cache para eliminar el cómputo redundante durante la generación autorregresiva de tokens
  • Explicar las fases de prefill vs decode en la inferencia de LLM y por que cada una tiene diferentes cuellos de botella (limitado por computación vs limitado por memoria)
  • Implementar los conceptos de continuous batching (loteo continuo) y PagedAttention para maximizar la utilización de la GPU bajo peticiones concurrentes
  • Comparar técnicas de optimización de inferencia (KV-cache, decodificación especulativa, flash attention) y sus compromisos (tradeoffs) entre rendimiento (throughput) y latencia

El Problema

Despliegas Llama 3 70B en 4 GPUs A100. Un solo usuario obtiene ~50 tokens por segundo. Se siente rápido. Luego, 100 usuarios acceden al endpoint simultáneamente. El rendimiento cae a 3 tokens/segundo/usuario. Tu factura de GPU de

5,000/mes está sirviendo respuestas más lento de lo que escribe un humano.

El modelo en sí no cambia entre 1 usuario y 100 usuarios. Mismos pesos, misma arquitectura, misma matemática. Lo que cambia es cómo programas el trabajo. La inferencia ingenua desperdicia más del 90% del cómputo disponible de la GPU. Un usuario que espera por el token 47 mantiene abierto un slot de lote completo mientras el bus de memoria de la GPU permanece inactivo entre multiplicaciones de matrices (matmuls). Mientras tanto, el prompt de 2,000 tokens de un nuevo usuario podría llenar ese tiempo muerto con cómputo útil.

Esto no es un problema de escalado. Es un problema de programación (scheduling). Las técnicas de esta lección -- KV caching, continuous batching, PagedAttention, decodificación especulativa, caché de prefijo -- son las que diferencian una factura de inferencia de 5,000/mes de una de $5,000/mes sirviendo el mismo tráfico.

vLLM sirviendo Llama 3 70B en 4xA100-80GB logra ~50 tokens/segundo/usuario con baja concurrencia, y sostiene de 15 a 25 TPS/usuario con 100 peticiones concurrentes a través de continuous batching y PagedAttention. Sin estas optimizaciones, el mismo hardware sirve 5 TPS/usuario con esa concurrencia. Mismas GPUs, mismo modelo, 4 veces más rendimiento.

El Concepto

Prefill vs Decode

Cada petición de inferencia de LLM tiene dos fases distintas.

Prefill procesa todo el prompt de entrada. Todos los tokens son conocidos, por lo que la atención se puede calcular en paralelo a lo largo de toda la secuencia. Esta es una gran multiplicación de matrices -- los núcleos de la GPU se mantienen ocupados. El cuello de botella es el cómputo: cuántos FLOPS puede entregar tu hardware por segundo. Una A100 hace 312 TFLOPS (BF16). El prefill para un prompt de 4,096 tokens en un modelo de 70B toma ~400ms en una sola A100.

Decode genera tokens de salida uno a la vez. Cada nuevo token atiende a todos los tokens anteriores, pero solo se produce un token por cada paso hacia adelante (forward pass). Las matrices de pesos son del mismo tamaño que durante el prefill, pero las estás multiplicando por un solo vector en lugar de una matriz. Los núcleos de la GPU terminan en microsegundos, luego esperan a que llegue el siguiente lote de pesos desde la memoria. El cuello de botella es el ancho de banda de la memoria: qué tan rápido puedes transmitir (stream) los pesos del modelo desde HBM a las unidades de cómputo. Una A100 tiene un ancho de banda de 2 TB/s. Un modelo de 70B en FP16 pesa 140 GB. Leer el modelo completo una vez toma 70ms -- ese es tu piso para un solo paso de decode.

graph LR
    subgraph "Prefill (limitado por computación)"
        P1["Todos los tokens del prompt"] --> P2["Atención paralela"]
        P2 --> P3["Utilización total de matmul"]
    end

    subgraph "Decode (limitado por memoria)"
        D1["Un token a la vez"] --> D2["Generación secuencial"]
        D2 --> D3["Esperando lecturas de memoria"]
    end

    P3 --> D1

La relación ops:byte (también llamada intensidad aritmética) captura este compromiso (tradeoff). Mide cuántas operaciones realizas por cada byte cargado desde la memoria.

ops:byte ratio = FLOPs per token / bytes read from memory

Durante el prefill con un lote de 4,096 tokens, realizas ~4,096 operaciones de multiplicar-acumular por cada peso cargado. La relación es alta -- estás limitado por computación (compute-bound). Durante el decode con tamaño de lote 1, realizas ~1 operación por cada peso cargado. La relación es baja -- estás limitado por memoria (memory-bound).

La idea fundamental: decode está limitado por memoria porque lees todo el modelo para producir un solo token. Cada optimización a continuación reduce lo que lees, aumenta el lote de tokens procesados por lectura o evita las lecturas por completo.

KV Cache

Durante la atención, la consulta (query) de cada token atiende a los vectores de clave (key) y valor (value) de todos los tokens anteriores. Sin caché, generar el token N requiere recalcular las proyecciones de clave y valor para todos los N-1 tokens precedentes. El token 1 se proyecta al generar el token 2, luego otra vez para el token 3, y luego otra vez para el token 4. Para el token 1,000, has proyectado el token 1 un total de 999 veces.

El KV cache almacena las proyecciones de clave y valor de todos los tokens anteriores. Al generar el token N, solo calculas la clave y el valor para el token N, y luego los concatenas con los K/V almacenados en caché de los tokens 1 a N-1.

graph TD
    subgraph "Sin KV Cache"
        A1["Token 5: recomputar K,V para tokens 1-4"]
        A2["Token 6: recomputar K,V para tokens 1-5"]
        A3["Token 7: recomputar K,V para tokens 1-6"]
    end

    subgraph "Con KV Cache"
        B1["Token 5: computar K5,V5, leer K1-4,V1-4 del cache"]
        B2["Token 6: computar K6,V6, leer K1-5,V1-5 del cache"]
        B3["Token 7: computar K7,V7, leer K1-6,V1-6 del cache"]
    end

Fórmula de memoria para KV cache:

KV cache size = 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_param

Para Llama 3 70B (80 capas, 8 cabezas de KV con GQA, head_dim=128, BF16):

per token: 2 * 80 * 8 * 128 * 2 bytes = 327,680 bytes = 320 KB
at 4,096 tokens: 320 KB * 4,096 = 1.28 GB
at 128K tokens: 320 KB * 131,072 = 40 GB

Una sola conversación de contexto 128K para Llama 3 70B consume 40 GB de KV cache -- la mitad de la memoria de una A100. Con 100 usuarios concurrentes a 4K tokens cada uno, KV cache por sí solo requiere 128 GB. Es por eso que la gestión de KV cache es el desafío central de la optimización de inferencia.

Continuous Batching

El loteo estático (static batching) espera hasta que llega un lote de N peticiones, las procesa juntas y espera hasta que todas terminen antes de aceptar nuevas peticiones. Si una petición necesita 500 tokens y otra necesita 10, la petición corta permanece inactiva durante 490 pasos de decode después de terminar.

El loteo continuo (continuous batching, también llamado iteration-level batching) inserta nuevas peticiones en el lote tan pronto como se completa cualquiera de ellas. El lote se reevalúa en cada paso de decode. Una petición que termina después de 10 tokens es reemplazada inmediatamente por una petición en espera.

sequenceDiagram
    participant GPU
    participant R1 as Petición 1 (50 tokens)
    participant R2 as Petición 2 (10 tokens)
    participant R3 as Petición 3 (30 tokens)
    participant R4 as Petición 4 (en espera)

    Note over GPU: Lote estático
    GPU->>R1: Procesar lote [R1, R2, R3]
    Note over R2: R2 terminada en el paso 10
    Note over R2: Desperdiciando 40 pasos...
    Note over R3: R3 terminada en el paso 30
    Note over R3: Desperdiciando 20 pasos...
    GPU->>R4: Finalmente iniciar R4 en el paso 50

    Note over GPU: Lote continuo
    GPU->>R1: Procesar lote [R1, R2, R3]
    Note over R2: R2 terminada en el paso 10
    GPU->>R4: Insertar R4 en el paso 11
    Note over R3: R3 terminada en el paso 30

La mejora del rendimiento depende de cuánto varíen las longitudes de las salidas. Con longitudes uniformes, continuous batching iguala al loteo estático. Con longitudes variables (el caso común), continuous batching puede ofrecer un rendimiento de 2 a 5 veces mayor porque los slots de la GPU nunca se quedan vacíos.

PagedAttention

El KV cache para cada petición es un bloque de memoria contiguo. A medida que las peticiones llegan y se van, la memoria se fragmenta -- exactamente igual que la fragmentación de la RAM en los sistemas operativos. Una petición de 4K tokens necesita 1.28 GB contiguos. Incluso si tienes 2 GB libres en total, es posible que no tengas 1.28 GB contiguos. Terminas desperdiciando memoria o rechazando la petición.

PagedAttention (de vLLM) aplica memoria virtual al estilo de los sistemas operativos al KV cache. En lugar de asignar un bloque contiguo por petición, asigna "páginas" de tamaño fijo (típicamente de 16 tokens cada una). Las páginas pueden estar en cualquier lugar de la memoria física de la GPU. Una tabla de páginas mapea las posiciones de la secuencia lógica de cada petición a las ubicaciones físicas de las páginas.

graph TD
    subgraph "Asignación contigua"
        C1["Petición A: bloque de 2GB"]
        C2["[libre: 0.5GB]"]
        C3["Petición B: bloque de 1GB"]
        C4["[libre: 1.5GB -- pero fragmentado]"]
    end

    subgraph "PagedAttention"
        P1["Pool de páginas: 256 páginas de 16 tokens cada una"]
        P2["Petición A: páginas 3,7,12,45,88..."]
        P3["Petición B: páginas 1,4,9,22,67..."]
        P4["Sin fragmentación, sin desperdicio"]
    end

PagedAttention también permite copy-on-write (copia en escritura) para prefijos compartidos. Si 50 peticiones comparten el mismo prompt del sistema, las páginas del KV cache para ese prompt del sistema se almacenan una sola vez y son referenciadas por las 50 peticiones. Solo cuando una petición diverge (mensajes de usuario diferentes), obtiene sus propias páginas. Esto reduce drásticamente el uso de la memoria en aplicaciones con prompts del sistema compartidos.

vLLM reporta un desperdicio de memoria casi nulo (~4% frente a ~60-80% en la asignación ingenua) gracias a PagedAttention.

Decodificación Especulativa

Decode es lento porque es secuencial -- generas un token, lo retroalimentas, generas el siguiente. Pero, ¿qué pasaría si pudieras adivinar los siguientes 5 tokens de forma económica y luego verificarlos todos a la vez?

La decodificación especulativa utiliza un modelo de borrador (draft model) pequeño y rápido para generar K tokens candidatos. El modelo objetivo (target model) grande luego procesa los K candidatos en un solo paso hacia adelante (lo que se parece a un prefill -- paralelo, limitado por computación, eficiente). Si el modelo objetivo está de acuerdo con las predicciones del modelo de borrador, aceptas los K tokens en el tiempo de un solo paso hacia adelante del objetivo. Si no está de acuerdo en la posición j, aceptas los tokens del 1 al j-1 y descartas el resto.

graph LR
    D["Draft model (1B)"] -->|"Generar 5 tokens<br/>~5ms"| C["Candidatos: the cat sat on the"]
    C --> T["Target model (70B)"]
    T -->|"Verificar los 5 en una pasada<br/>~70ms"| V{"¿Match?"}
    V -->|"4 de 5 coinciden"| A["Aceptar 4 tokens en 75ms<br/>vs 280ms secuencial"]
    V -->|"Incoherencia en pos 5"| R["Rechazar token 5<br/>Re-muestrear del objetivo"]

La aceleración depende de la tasa de aceptación -- qué tan a menudo las predicciones del modelo de borrador coinciden con las del objetivo. Para un Llama 3 8B haciendo el borrador para Llama 3 70B, las tasas de aceptación del 70-85% son típicas en lenguaje natural. Esto se traduce en una aceleración del decode de 2 a 3 veces.

Tres enfoques para la decodificación especulativa:

Método Origen del borrador Tasa de aceptación Sobrecarga (Overhead)
Draft-target (Leviathan et al.) Modelo pequeño independiente 70-85% Memoria del modelo borrador
EAGLE (Li et al.) Cabezal ligero en el objetivo 75-90% ~1% de parámetros adicionales
N-gram lookup Tabla de n-gramas de tokens 40-60% Despreciable

EAGLE entrena un pequeño cabezal autorregresivo sobre los estados ocultos (hidden states) del modelo objetivo. Predice el embedding del siguiente token utilizando las características de la penúltima capa del modelo objetivo. Debido a que opera sobre las propias representaciones del modelo objetivo (not las de un modelo independiente), logra tasas de aceptación más altas con una memoria adicional mínima. EAGLE-2 agrega un árbol de borrador dinámico que ajusta la cantidad de candidatos en función del contexto.

La decodificación especulativa N-gram mantiene una tabla de continuaciones de n-gramas a partir del contexto actual o de un corpus preconstruido. Si el borrador coincide con lo que apareció antes en la misma conversación (patrones repetitivos, código, salida estructurada), se ejecuta con cero sobrecarga de red neuronal. Las tasas de aceptación son más bajas en promedio, pero el costo por especulación es esencialmente gratuito.

La decodificación especulativa es matemáticamente exacta -- la distribución de salida es idéntica a la distribución del modelo objetivo. No es una aproximación. El paso de verificación garantiza que cada token aceptado tenga exactamente la probabilidad que el modelo objetivo le habría asignado.

Caché de Prefijos

Muchas peticiones comparten el mismo prefijo. El prompt del sistema de un chatbot. Un bloque de contexto de RAG. Un conjunto de ejemplos few-shot. Sin la caché de prefijos, cada petición recalcula el KV cache para estos tokens compartidos desde cero.

La caché de prefijos almacena el KV cache para prefijos comunes y lo reutiliza en todas las peticiones. Cuando llega una nueva petición con un prefijo conocido, el sistema copia (o hace referencia a) las entradas de KV almacenadas en caché y solo calcula el KV para el sufijo único.

Para un prompt del sistema de 2,000 tokens compartido entre todas las peticiones, la caché de prefijos elimina ~400ms de prefill por petición. Con 100 peticiones/segundo, eso ahorra 40 segundos de cómputo de GPU por segundo -- más del trabajo de una GPU completa.

RadixAttention de SGLang implementa la caché de prefijos con un árbol de prefijos (trie) que indexa los prefijos por su contenido de tokens. Cualquier petición que coincida con un prefijo almacenado obtiene su KV cache de forma gratuita. El árbol permite coincidencias parciales de prefijo -- si compartes 1,500 de 2,000 tokens de prefijo con una entrada almacenada en caché, reutilizas esos 1,500 y calculas solo 500.

Inference Engines

Tres motores dominan el servicio de LLM en producción:

Motor (Engine) Inovación clave Mejor para
vLLM PagedAttention, continuous batching Servicio de propósito general, máxima compatibilidad
SGLang RadixAttention (caché de prefijos), generación estructurada Chatbots de múltiples turnos, decodificación restringrada
TensorRT-LLM Fusión de kernel de NVIDIA, cuantización FP8 Máximo rendimiento en una sola GPU en hardware NVIDIA

vLLM es el punto de partida predeterminado. Admite la gama más amplia de modelos, se ejecuta en cualquier proveedor de GPU (NVIDIA, AMD, Intel) y logra un gran rendimiento a través de PagedAttention + continuous batching. La API compatible con OpenAI significa que puedes soltarlo como reemplazo para cualquier llamada de API de OpenAI.

SGLang se basa en los mismos cimientos que vLLM pero agrega RadixAttention para la caché de prefijos y un lenguaje específico de dominio para programas de LLM estructurados. Si tu carga de trabajo involucra conversaciones de varios turnos, uso de herramientas o decodificación restringida (salida JSON, generación guiada por regex), SGLang a menudo supera a vLLM de 2 a 5 veces gracias a la reutilización de prefijos.

TensorRT-LLM compila modelos en kernels optimizados para GPU NVIDIA. Fusiona operaciones (atención + lineal + activación en un solo kernel), utiliza FP8 en GPUs H100 e integra con NVIDIA Triton Inference Server para el despliegue en producción. Logra el mayor rendimiento en una sola GPU en hardware NVIDIA, pero requiere más configuración y solo funciona en GPUs NVIDIA.

Números del mundo real para Llama 3 70B (4xA100-80GB, BF16):

Métrica vLLM SGLang TensorRT-LLM
Rendimiento (1 usuario) ~50 TPS ~55 TPS ~65 TPS
Rendimiento (100 usuarios) ~2,500 TPS totales ~3,200 TPS totales ~3,000 TPS totales
Tiempo al primer token (TTFT) ~400ms ~300ms (prefix hit) ~350ms
Contexto máximo 128K 128K 128K

El Framework Ops:Byte

No puedes optimizar lo que no mides. La relación ops:byte te dice si estás limitado por computación o por memoria, lo que determina qué optimizaciones son importantes.

Compute roof: peak FLOPS of the GPU
Memory roof:  peak bandwidth * ops:byte ratio

Cuando ops:byte es bajo (decode, lotes pequeños), golpeas el límite del ancho de banda de la memoria. Agregar más cómputo (mayor reloj, más núcleos) no ayuda. Necesitas reducir las lecturas de memoria (cuantización, compresión de KV cache) o aumentar el tamaño del lote para amortizar las lecturas a lo largo de un trabajo más útil.

Cuando ops:byte es alto (prefill, lotes grandes), golpeas el límite de computación. La optimización del ancho de banda de memoria no ayuda. Necesitas GPUs más rápidas, fusión de kernels o precisión reducida para exprimir más FLOPS.

Escenario ops:byte Limitación (Bound) Optimizar con
Prefill, lote=1 ~4,096 Computación Fusión de kernel, FP8
Decode, lote=1 ~1 Memoria Cuantización, compresión de KV
Decode, lote=32 ~32 Memoria Lote más grande, continuous batching
Decode, lote=256 ~256 Transición Ambos importan
Decode, lote=1024 ~1,024 Computación Fusión de kernel, paralelismo de tensores

El punto de cruce (crossover) en la A100 es de alrededor de ops:byte = 156 (312 TFLOPS / 2 TB/s). Por debajo de 156, estás limitado por memoria. Por encima de 156, estás limitado por computación. Continuous batching empuja al decode hacia este cruce al empaquetar más tokens por iteración.

Implementación (Build It)

Paso 1: KV Cache desde Cero

Construimos un KV cache multi-cabezal que almacena proyecciones de clave y valor por capa y por cabezal, y demuestra el patrón de crecimiento de la memoria.

import numpy as np

class KVCache:
    def __init__(self, num_layers, num_heads, head_dim, max_seq_len, dtype=np.float16):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.dtype = dtype

        self.k_cache = np.zeros(
            (num_layers, num_heads, max_seq_len, head_dim), dtype=dtype
        )
        self.v_cache = np.zeros(
            (num_layers, num_heads, max_seq_len, head_dim), dtype=dtype
        )
        self.seq_len = 0

    def update(self, layer_idx, new_keys, new_values):
        num_new = new_keys.shape[1]
        end = self.seq_len + num_new
        self.k_cache[layer_idx, :, self.seq_len:end, :] = new_keys
        self.v_cache[layer_idx, :, self.seq_len:end, :] = new_values
        return (
            self.k_cache[layer_idx, :, :end, :],
            self.v_cache[layer_idx, :, :end, :]
        )

    def advance(self, num_tokens):
        self.seq_len += num_tokens

    def memory_bytes(self):
        return self.k_cache.nbytes + self.v_cache.nbytes

    def used_bytes(self):
        per_token = 2 * self.num_layers * self.num_heads * self.head_dim * np.dtype(self.dtype).itemsize
        return per_token * self.seq_len

Paso 2: Atención con KV Cache

Una atención multi-cabezal simplificada que utiliza el KV cache para los pasos de decode.

def scaled_dot_product_attention(query, keys, values):
    head_dim = query.shape[-1]
    scores = np.matmul(query, keys.transpose(0, 1, 3, 2)) / np.sqrt(head_dim)
    seq_len_q = scores.shape[-2]
    seq_len_k = scores.shape[-1]
    if seq_len_q > 1:
        mask = np.triu(np.ones((seq_len_q, seq_len_k), dtype=np.float32), k=seq_len_k - seq_len_q + 1)
        scores = scores + mask * (-1e9)
    max_scores = np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(scores - max_scores)
    attn_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
    return np.matmul(attn_weights, values)


class MultiHeadAttention:
    def __init__(self, d_model, num_heads):
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        scale = np.sqrt(2.0 / d_model)
        self.W_q = np.random.randn(d_model, d_model).astype(np.float32) * scale
        self.W_k = np.random.randn(d_model, d_model).astype(np.float32) * scale
        self.W_v = np.random.randn(d_model, d_model).astype(np.float32) * scale
        self.W_o = np.random.randn(d_model, d_model).astype(np.float32) * scale

    def forward(self, x, kv_cache=None, layer_idx=0):
        batch, seq_len, d_model = x.shape
        Q = np.matmul(x, self.W_q).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        K = np.matmul(x, self.W_k).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        V = np.matmul(x, self.W_v).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)

        if kv_cache is not None:
            K_full, V_full = kv_cache.update(layer_idx, K[0], V[0])
            K = K_full[np.newaxis, :, :, :]
            V = V_full[np.newaxis, :, :, :]
            if seq_len == 1:
                kv_cache.advance(1)

        attn_out = scaled_dot_product_attention(Q, K, V)
        attn_out = attn_out.transpose(0, 2, 1, 3).reshape(batch, -1, d_model)
        return np.matmul(attn_out, self.W_o)

Paso 3: Simulador de Continuous Batching

Esto simula la diferencia de programación entre el loteo estático y continuo.

import heapq

class Request:
    def __init__(self, request_id, prompt_tokens, output_tokens, arrival_step):
        self.request_id = request_id
        self.prompt_tokens = prompt_tokens
        self.output_tokens = output_tokens
        self.arrival_step = arrival_step
        self.tokens_generated = 0
        self.start_step = None
        self.end_step = None

    def is_done(self):
        return self.tokens_generated >= self.output_tokens


def simulate_static_batching(requests, batch_size):
    step = 0
    completed = []
    queue = list(requests)
    queue.sort(key=lambda r: r.arrival_step)

    while queue:
        batch = []
        while queue and len(batch) < batch_size:
            r = queue.pop(0)
            r.start_step = max(step, r.arrival_step)
            batch.append(r)

        if batch:
            step = max(step, max(r.start_step for r in batch))
            max_output = max(r.output_tokens for r in batch)
            for r in batch:
                r.tokens_generated = r.output_tokens
                r.end_step = step + max_output
            step += max_output
            completed.extend(batch)

    return completed


def simulate_continuous_batching(requests, batch_size):
    step = 0
    completed = []
    queue = sorted(requests, key=lambda r: r.arrival_step)
    queue_idx = 0
    active = []
    waiting = []

    while queue_idx < len(queue) or active or waiting:
        while queue_idx < len(queue) and queue[queue_idx].arrival_step <= step:
            waiting.append(queue[queue_idx])
            queue_idx += 1

        while waiting and len(active) < batch_size:
            r = waiting.pop(0)
            r.start_step = step
            active.append(r)

        if not active:
            if waiting:
                step += 1
                continue
            elif queue_idx < len(queue):
                step = queue[queue_idx].arrival_step
                continue
            else:
                break

        for r in active:
            r.tokens_generated += 1

        done = [r for r in active if r.is_done()]
        for r in done:
            r.end_step = step + 1
            completed.append(r)
        active = [r for r in active if not r.is_done()]

        step += 1

    return completed


def batching_stats(completed):
    latencies = [r.end_step - r.arrival_step for r in completed]
    total_time = max(r.end_step for r in completed) - min(r.arrival_step for r in completed)
    total_tokens = sum(r.output_tokens for r in completed)
    return {
        "avg_latency": np.mean(latencies),
        "p50_latency": np.median(latencies),
        "p99_latency": np.percentile(latencies, 99),
        "total_time": total_time,
        "throughput": total_tokens / total_time if total_time > 0 else 0,
    }

Paso 4: Caché de Prefijos

Una caché de prefijos basada en trie que almacena entradas de KV para prefijos compartidos.

class TrieNode:
    def __init__(self):
        self.children = {}
        self.kv_data = None
        self.hit_count = 0


class PrefixCache:
    def __init__(self, max_entries=1000):
        self.root = TrieNode()
        self.max_entries = max_entries
        self.total_entries = 0
        self.hits = 0
        self.misses = 0

    def _walk(self, token_ids):
        node = self.root
        depth = 0
        for tid in token_ids:
            if tid not in node.children:
                break
            node = node.children[tid]
            depth += 1
        return node, depth

    def lookup(self, token_ids):
        node, depth = self._walk(token_ids)
        if depth > 0:
            self.hits += 1
            current = self.root
            for tid in token_ids[:depth]:
                current = current.children[tid]
                current.hit_count += 1
            kv_entries = []
            current = self.root
            for tid in token_ids[:depth]:
                current = current.children[tid]
                if current.kv_data is not None:
                    kv_entries.append(current.kv_data)
            return depth, kv_entries
        self.misses += 1
        return 0, []

    def insert(self, token_ids, kv_per_token):
        node = self.root
        for i, tid in enumerate(token_ids):
            if tid not in node.children:
                if self.total_entries >= self.max_entries:
                    return i
                node.children[tid] = TrieNode()
                self.total_entries += 1
            node = node.children[tid]
            if i < len(kv_per_token):
                node.kv_data = kv_per_token[i]
        return len(token_ids)

    def hit_rate(self):
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0.0

Paso 5: Simulador de Decodificación Especulativa

Simulamos la decodificación especulativa draft-target con tasas de aceptación configurables.

class DraftModel:
    def __init__(self, vocab_size, acceptance_rate=0.8):
        self.vocab_size = vocab_size
        self.acceptance_rate = acceptance_rate

    def generate(self, context, num_tokens):
        tokens = np.random.randint(0, self.vocab_size, size=num_tokens)
        return tokens

    def get_probs(self, context, token):
        probs = np.random.dirichlet(np.ones(self.vocab_size))
        return probs


class TargetModel:
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def get_probs(self, context, tokens=None):
        if tokens is not None:
            return [np.random.dirichlet(np.ones(self.vocab_size)) for _ in tokens]
        return np.random.dirichlet(np.ones(self.vocab_size))


def speculative_decode(draft_model, target_model, context, num_speculative=5,
                       draft_cost=1.0, target_cost=10.0, verify_cost=12.0):
    total_tokens = 0
    total_cost = 0.0
    accepted_counts = []
    context = list(context)

    max_tokens = 100

    while total_tokens < max_tokens:
        draft_tokens = draft_model.generate(context, num_speculative)
        total_cost += draft_cost * num_speculative

        target_probs = target_model.get_probs(context, draft_tokens)
        total_cost += verify_cost

        accepted = 0
        for i, token in enumerate(draft_tokens):
            draft_p = draft_model.get_probs(context + list(draft_tokens[:i]), token)
            target_p = target_probs[i]

            r = np.random.random()
            acceptance_prob = min(1.0, target_p[token] / (draft_p[token] + 1e-10))

            if r < draft_model.acceptance_rate:
                accepted += 1
                context.append(token)
                total_tokens += 1
            else:
                new_token = np.random.choice(draft_model.vocab_size, p=target_p)
                context.append(new_token)
                total_tokens += 1
                break

        accepted_counts.append(accepted)

        if accepted == num_speculative:
            bonus_probs = target_model.get_probs(context)
            bonus_token = np.random.choice(draft_model.vocab_size, p=bonus_probs)
            context.append(bonus_token)
            total_tokens += 1

    sequential_cost = total_tokens * target_cost
    return {
        "total_tokens": total_tokens,
        "speculative_cost": total_cost,
        "sequential_cost": sequential_cost,
        "speedup": sequential_cost / total_cost if total_cost > 0 else 1.0,
        "avg_accepted": np.mean(accepted_counts),
        "acceptance_rate": np.mean(accepted_counts) / num_speculative,
    }


def compare_speculation_strategies(vocab_size=1000, num_trials=20):
    results = {}

    for name, acceptance_rate, spec_tokens in [
        ("Draft-target (8B->70B)", 0.78, 5),
        ("EAGLE", 0.85, 6),
        ("N-gram", 0.50, 4),
        ("No speculation", 0.0, 0),
    ]:
        if spec_tokens == 0:
            results[name] = {
                "speedup": 1.0,
                "acceptance_rate": 0.0,
                "avg_accepted": 0.0,
            }
            continue

        trial_results = []
        for _ in range(num_trials):
            draft = DraftModel(vocab_size, acceptance_rate=acceptance_rate)
            target = TargetModel(vocab_size)
            context = list(np.random.randint(0, vocab_size, size=10))
            result = speculative_decode(draft, target, context, num_speculative=spec_tokens)
            trial_results.append(result)

        results[name] = {
            "speedup": np.mean([r["speedup"] for r in trial_results]),
            "acceptance_rate": np.mean([r["acceptance_rate"] for r in trial_results]),
            "avg_accepted": np.mean([r["avg_accepted"] for r in trial_results]),
        }

    return results

Paso 6: Perfilador de Memoria de KV Cache

Calcular los requisitos de memoria de KV cache para configuraciones de modelos reales.

MODEL_CONFIGS = {
    "Llama-3-8B": {
        "num_layers": 32, "num_kv_heads": 8, "head_dim": 128,
        "model_params_b": 8, "gqa": True,
    },
    "Llama-3-70B": {
        "num_layers": 80, "num_kv_heads": 8, "head_dim": 128,
        "model_params_b": 70, "gqa": True,
    },
    "Llama-3-405B": {
        "num_layers": 126, "num_kv_heads": 8, "head_dim": 128,
        "model_params_b": 405, "gqa": True,
    },
    "Mistral-7B": {
        "num_layers": 32, "num_kv_heads": 8, "head_dim": 128,
        "model_params_b": 7, "gqa": True,
    },
    "GPT-4-est": {
        "num_layers": 120, "num_kv_heads": 96, "head_dim": 128,
        "model_params_b": 1800, "gqa": False,
    },
}


def kv_cache_memory(config, seq_len, dtype_bytes=2):
    per_token = 2 * config["num_layers"] * config["num_kv_heads"] * config["head_dim"] * dtype_bytes
    total = per_token * seq_len
    return {
        "per_token_bytes": per_token,
        "per_token_kb": per_token / 1024,
        "total_bytes": total,
        "total_mb": total / (1024 ** 2),
        "total_gb": total / (1024 ** 3),
    }


def memory_budget(config, gpu_memory_gb, model_dtype_bytes=2, kv_dtype_bytes=2):
    model_memory_gb = config["model_params_b"] * 1e9 * model_dtype_bytes / (1024 ** 3)
    overhead_gb = gpu_memory_gb * 0.1
    available_for_kv = gpu_memory_gb - model_memory_gb - overhead_gb

    if available_for_kv <= 0:
        return {"error": "Model does not fit in GPU memory", "model_memory_gb": model_memory_gb}

    per_token = 2 * config["num_layers"] * config["num_kv_heads"] * config["head_dim"] * kv_dtype_bytes
    max_tokens = int(available_for_kv * (1024 ** 3) / per_token)

    return {
        "gpu_memory_gb": gpu_memory_gb,
        "model_memory_gb": round(model_memory_gb, 1),
        "overhead_gb": round(overhead_gb, 1),
        "available_for_kv_gb": round(available_for_kv, 1),
        "max_total_tokens": max_tokens,
        "max_users_at_2k": max_tokens // 2048,
        "max_users_at_4k": max_tokens // 4096,
        "max_users_at_32k": max_tokens // 32768,
    }

Utilización (Use It)

Con vLLM:

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3-70B-Instruct",
    tensor_parallel_size=4,
    enable_prefix_caching=True,
    max_model_len=8192,
    gpu_memory_utilization=0.9,
)

params = SamplingParams(temperature=0.7, max_tokens=256)
outputs = llm.generate(["Explain inference optimization in one paragraph."], params)

Con SGLang para caché de prefijos + salida estructurada:

import sglang as sgl

@sgl.function
def classify(s, text):
    s += sgl.system("You are a classifier. Output JSON only.")
    s += sgl.user(f"Classify this text: {text}")
    s += sgl.assistant(sgl.gen("result", regex=r'\{"label": "(positive|negative|neutral)"\}'))

runtime = sgl.Runtime(model_path="meta-llama/Llama-3-70B-Instruct", tp_size=4)
sgl.set_default_backend(runtime)

results = classify.run_batch([
    {"text": "This product is amazing!"},
    {"text": "Terrible experience."},
    {"text": "It was okay I guess."},
])

Con TensorRT-LLM:

import tensorrt_llm
from tensorrt_llm.runtime import ModelRunner

runner = ModelRunner.from_dir("./llama-70b-trt-engine/", rank=0)

outputs = runner.generate(
    batch_input_ids=[tokenizer.encode("Explain KV caching.")],
    max_new_tokens=256,
    temperature=0.7,
)

Entrega (Ship It)

Esta lección produce:

Ejercicios

  1. Modifica el perfilador de KV cache para comparar la cuantización de KV cache en FP16 vs FP8 vs INT4. Para Llama 3 70B con un contexto de 4K, calcula el número máximo de usuarios concurrentes para cada uno en 4xA100-80GB. La cuantización de KV a INT4 debería multiplicar aproximadamente por 4 la capacidad de usuarios.

  2. Extiende el simulador de continuous batching para realizar el seguimiento de la utilización de la GPU (fracción de slots de lote llenos por paso). Grafica la utilización a lo largo del tiempo tanto para el loteo estático como continuo con 50 peticiones cuyas longitudes de salida sigan una distribución de Pareto (forma=1.5, escala=20). Continuous batching debería mantener una utilización >80%.

  3. Implementa una versión de atención de consulta agrupada (grouped-query attention - GQA) del KV cache donde num_kv_heads < num_query_heads. Llama 3 70B utiliza 64 cabezales de consulta pero solo 8 cabezales de KV. Calcula el ahorro de memoria frente a la atención multi-cabezal completa (reducción de 8 veces en el tamaño del KV cache).

  4. Construye una caché de prefijos que utilice la política de desalojo LRU. Establece max_entries en 500 y genera 1,000 peticiones donde el 60% comparta uno de 5 prefijos comunes. Mide la tasa de aciertos (hit rate) y compárala con una caché ilimitada. Con un buen desalojo, la tasa de aciertos debería mantenerse por encima de 55%.

  5. Extiende el simulador de decodificación especulativa para implementar la especulación basada en árboles (estilo EAGLE-2). En lugar de una sola cadena de K tokens de borrador, genera un árbol de candidatos (por ejemplo, 2 ramas en cada uno de los 3 niveles = 8 candidatos finales). Compara los tokens totales aceptados por ronda de verificación frente a la especulación lineal.

Términos Clave

Término Lo que la gente dice Lo que realmente significa
Prefill "Procesar el prompt" Calcular la atención sobre todos los tokens de entrada en paralelo -- limitado por computación porque la multiplicación de matrices completa mantiene ocupados los núcleos de la GPU
Decode "Generar tokens" Producir un token por cada paso hacia adelante, leyendo los pesos completos del modelo cada vez -- limitado por memoria porque el cómputo termina antes de que lleguen los siguientes pesos
KV cache "Almacenar en caché los estados de atención" Almacenar las proyecciones de clave y valor para todos los tokens anteriores para que no se vuelvan a calcular en cada paso de decode -- intercambia memoria por computación
Continuous batching "Loteo dinámico" Insertar nuevas peticiones en el lote en ejecución tan pronto como finalice cualquier petición, evaluado en cada iteración de decode en lugar de esperar a todo el lote
PagedAttention "Memoria virtual para KV cache" Asignar KV cache en páginas de tamaño fijo en lugar de bloques contiguos, eliminando la fragmentación de memoria y permitiendo copy-on-write para prefijos compartidos
Decodificación especulativa "Borrador y verificación" Utilizar un modelo rápido de borrador para proponer múltiples tokens, y luego verificarlos todos en un solo paso hacia adelante del modelo objetivo -- matemáticamente exacto, aceleración de 2 a 3 veces
EAGLE "Decodificación auto-especulativa" Una variante de decodificación especulativa que entrena un cabezal ligero sobre los propios estados ocultos del modelo objetivo, logrando tasas de aceptación más altas que un modelo de borrador independiente
Caché de prefijos "Reutilizar el KV del prompt del sistema" Almacenar las entradas de KV cache calculadas para prefijos comunes (prompts de sistema, ejemplos few-shot) y reutilizarlas en todas las peticiones para omitir prefill redundante
Relación ops:byte "Intensidad aritmética" La relación entre las operaciones de cómputo y los bytes de memoria leídos -- determina si una carga de trabajo está limitada por computación (relación alta) o por memoria (relación baja)
Tiempo al primer token "TTFT" Latencia desde la recepción de una petición hasta la producción del primer token de salida -- dominada por el tiempo de prefill para prompts largos

Lecturas Adicionales