Phase 10 - Lesson 12

Otimização de Inferência

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

Duas fases definem a inferência de LLM. O Prefill processa seu prompt em paralelo -- limitado por computação (compute-bound). O Decode gera tokens um de cada vez -- limitado por memória (memory-bound). Cada otimização visa uma ou ambas as fases.

Tipo: Build Idiomas: Python Pré-requisitos: Fase 10, Lições 01-08 (arquitetura Transformer, atenção) Tempo: ~120 minutos

Objetivos de Aprendizado

  • Implementar o KV-cache para eliminar a computação redundante durante a geração autorregressiva de tokens
  • Explicar as fases de prefill vs decode na inferência de LLM e por que cada uma tem gargalos diferentes (limitado por computação vs limitado por memória)
  • Implementar os conceitos de continuous batching (lote contínuo) e PagedAttention para maximizar a utilização de GPU sob requisições concorrentes
  • Comparar técnicas de otimização de inferência (KV-cache, decodificação especulativa, flash attention) e seus compromissos (tradeoffs) entre taxa de transferência (throughput) e latência

O Problema

Você implanta o Llama 3 70B em 4 GPUs A100. Um único usuário obtém ~50 tokens por segundo. Parece rápido. Então, 100 usuários acessam o endpoint simultaneamente. A taxa de transferência cai para 3 tokens/segundo/usuário. Sua fatura de GPU de

5.000/mês está servindo respostas mais devagar do que um ser humano digita.

O modelo em si não muda entre 1 usuário e 100 usuários. Mesmos pesos, mesma arquitetura, mesma matemática. O que muda é como você agenda o trabalho. A inferência ingênua desperdiça mais de 90% da computação disponível na GPU. Um usuário esperando pelo token 47 mantém um slot de lote inteiro aberto enquanto o barramento de memória da GPU fica ocioso entre multiplicações de matrizes (matmuls). Enquanto isso, o prompt de 2.000 tokens de um novo usuário poderia preencher esse tempo morto com computação útil.

Isso não é um problema de dimensionamento (scaling). É um problema de agendamento (scheduling). As técnicas desta lição -- KV caching, continuous batching, PagedAttention, decodificação especulativa, cache de prefixo -- são o que separam uma conta de inferência de 5.000/mês de uma de $5.000/mês servindo o mesmo tráfego.

O vLLM servindo o Llama 3 70B em 4xA100-80GB atinge ~50 tokens/segundo/usuário sob baixa concorrência, e sustenta de 15 a 25 TPS/usuário com 100 requisições concorrentes por meio de continuous batching e PagedAttention. Sem essas otimizações, o mesmo hardware serve 5 TPS/usuário com essa concorrência. Mesmas GPUs, mesmo modelo, 4x mais taxa de transferência.

O Conceito

Prefill vs Decode

Toda requisição de inferência de LLM possui duas fases distintas.

O Prefill processa todo o prompt de entrada. Todos os tokens são conhecidos, portanto, a atenção pode ser computada em paralelo ao longo de toda a sequência. Trata-se de uma grande multiplicação de matrizes -- os núcleos da GPU permanecem ocupados. O gargalo é a computação: quantos FLOPS seu hardware consegue entregar por segundo. Uma A100 entrega 312 TFLOPS (BF16). O prefill para um prompt de 4.096 tokens em um modelo de 70B leva ~400ms em uma única A100.

O Decode gera tokens de saída um por vez. Cada novo token atende a todos os tokens anteriores, mas apenas um token é gerado por etapa de propagação direta (forward pass). As matrizes de pesos têm o mesmo tamanho da fase de prefill, mas você as está multiplicando por um único vetor em vez de uma matriz. Os núcleos da GPU terminam em microssegundos e depois esperam pelo próximo lote de pesos chegar da memória. O gargalo é a largura de banda de memória: quão rápido você consegue transferir (stream) os pesos do modelo da HBM (High Bandwidth Memory) para as unidades de computação. Uma A100 tem 2 TB/s de largura de banda. Um modelo de 70B em FP16 tem 140 GB. Ler o modelo completo uma vez leva 70ms -- esse é o seu piso para uma única etapa de decode.

graph LR
    subgraph "Prefill (limitado por computação)"
        P1["Todos os tokens do prompt"] --> P2["Atenção paralela"]
        P2 --> P3["Utilização total da matmul"]
    end

    subgraph "Decode (limitado por memória)"
        D1["Um token de cada vez"] --> D2["Geração sequencial"]
        D2 --> D3["Aguardando leituras de memória"]
    end

    P3 --> D1

A razão ops:byte (também chamada de intensidade aritmética) captura esse compromisso. Ela mede quantas operações você realiza por byte carregado da memória.

ops:byte ratio = FLOPs per token / bytes read from memory

Durante o prefill com um lote de 4.096 tokens, você realiza ~4.096 operações de multiplicação-acumulação por peso carregado. A razão é alta -- você está limitado por computação (compute-bound). Durante o decode com tamanho de lote 1, você realiza ~1 operação por peso carregado. A razão é baixa -- você está limitado por memória (memory-bound).

A percepção fundamental: o decode é limitado por memória porque você lê o modelo inteiro para gerar um único token. Cada otimização abaixo reduz o que você lê, aumenta o lote de tokens processados por leitura, ou evita totalmente as leituras.

KV Cache

Durante a atenção, a consulta (query) de cada token atende aos vetores de chave (key) e valor (value) de todos os tokens anteriores. Sem cache, a geração do token N exige a recomputação das projeções de chave e valor para todos os N-1 tokens precedentes. O token 1 é projetado ao gerar o token 2, depois novamente para o token 3, depois novamente para o token 4. No token 1.000, você já projetou o token 1 um total de 999 vezes.

O KV cache armazena as projeções de chave e valor de todos os tokens anteriores. Ao gerar o token N, você apenas computa a chave e o valor para o token N, e depois os concatena com as K/V armazenadas em cache dos tokens 1 a N-1.

graph TD
    subgraph "Sem KV Cache"
        A1["Token 5: recomputar K,V para os tokens 1-4"]
        A2["Token 6: recomputar K,V para os tokens 1-5"]
        A3["Token 7: recomputar K,V para os tokens 1-6"]
    end

    subgraph "Com KV Cache"
        B1["Token 5: computar K5,V5, ler K1-4,V1-4 do cache"]
        B2["Token 6: computar K6,V6, ler K1-5,V1-5 do cache"]
        B3["Token 7: computar K7,V7, ler K1-6,V1-6 do cache"]
    end

Fórmula de memória para o KV cache:

KV cache size = 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_param

Para o Llama 3 70B (80 camadas, 8 cabeças de KV com GQA, head_dim=128, BF16):

per token: 2 * 80 * 8 * 128 * 2 bytes = 327,680 bytes = 320 KB
at 4,096 tokens: 320 KB * 4,096 = 1.28 GB
at 128K tokens: 320 KB * 131,072 = 40 GB

Uma única conversa com contexto de 128K para o Llama 3 70B consome 40 GB de KV cache -- metade da memória de uma A100. Com 100 usuários concorrentes a 4K tokens cada, o KV cache sozinho exige 128 GB. É por isso que o gerenciamento do KV cache é o desafio central da otimização de inferência.

Continuous Batching

O lote estático (static batching) espera até que um lote de N requisições chegue, processa-as juntas e espera até que todas terminem antes de aceitar novas requisições. Se uma requisição precisa de 500 tokens e outra precisa de 10, a requisição curta fica ociosa por 490 etapas de decode após ser concluída.

O lote contínuo (continuous batching, também chamado de iteration-level batching) insere novas requisições no lote assim que qualquer requisição é concluída. O lote é reavaliado a cada etapa de decode. Uma requisição que termina após 10 tokens é imediatamente substituída por uma requisição em espera.

sequenceDiagram
    participant GPU
    participant R1 as Requisição 1 (50 tokens)
    participant R2 as Requisição 2 (10 tokens)
    participant R3 as Requisição 3 (30 tokens)
    participant R4 as Requisição 4 (em espera)

    Note over GPU: Lote estático
    GPU->>R1: Processar lote [R1, R2, R3]
    Note over R2: R2 concluída no passo 10
    Note over R2: Desperdiçando 40 passos...
    Note over R3: R3 concluída no passo 30
    Note over R3: Desperdiçando 20 passos...
    GPU->>R4: Finalmente iniciar R4 no passo 50

    Note over GPU: Lote contínuo
    GPU->>R1: Processar lote [R1, R2, R3]
    Note over R2: R2 concluída no passo 10
    GPU->>R4: Inserir R4 no passo 11
    Note over R3: R3 concluída no passo 30

A melhoria na taxa de transferência depende de quanto o comprimento das saídas varia. Com comprimentos uniformes, o continuous batching se iguala ao lote estático. Com comprimentos variáveis (o caso mais comum), o continuous batching pode fornecer uma taxa de transferência de 2 a 5 vezes maior, porque os slots da GPU nunca ficam vazios.

PagedAttention

O KV cache para cada requisição é um bloco contíguo de memória. À medida que as requisições chegam e saem, a memória se fragmenta -- exatamente como a fragmentação de RAM em sistemas operacionais. Uma requisição de 4K tokens precisa de 1,28 GB contíguos. Mesmo se você tiver 2 GB livres no total, pode não ter 1,28 GB contíguos. Você acaba desperdiçando memória ou rejeitando a requisição.

O PagedAttention (do vLLM) aplica memória virtual do tipo SO ao KV cache. Em vez de alocar um bloco contíguo por requisição, ele aloca "páginas" de tamanho fixo (geralmente 16 tokens cada). As páginas podem estar em qualquer lugar da memória física da GPU. Uma tabela de páginas mapeia as posições da sequência lógica de cada requisição para os locais físicos das páginas.

graph TD
    subgraph "Alocação contígua"
        C1["Requisição A: bloco de 2GB"]
        C2["[livre: 0.5GB]"]
        C3["Requisição B: bloco de 1GB"]
        C4["[livre: 1.5GB -- mas fragmentado]"]
    end

    subgraph "PagedAttention"
        P1["Pool de páginas: 256 páginas de 16 tokens cada"]
        P2["Requisição A: páginas 3,7,12,45,88..."]
        P3["Requisição B: páginas 1,4,9,22,67..."]
        P4["Sem fragmentação, sem desperdício"]
    end

O PagedAttention também possibilita o recurso de copy-on-write (cópia em gravação) para prefixos compartilhados. Se 50 requisições compartilham o mesmo prompt do sistema, as páginas de KV cache para esse prompt do sistema são armazenadas apenas uma vez e referenciadas por todas as 50 requisições. Apenas quando uma requisição diverge (mensagens de usuário diferentes) ela ganha suas próprias páginas. Isso reduz drasticamente o uso de memória em aplicações com prompts do sistema compartilhados.

O vLLM relata desperdício de memória quase nulo (~4% contra ~60-80% na alocação ingênua) por meio do PagedAttention.

Decodificação Especulativa

O decode é lento por ser sequencial -- você gera um token, o realimenta na rede e gera o próximo. Mas e se você pudesse estimar os próximos 5 tokens a um custo baixo e depois verificar todos de uma só vez?

A decodificação especulativa usa um modelo de rascunho (draft model) pequeno e rápido para gerar K tokens candidatos. O modelo alvo (target model) principal processa os K candidatos em uma única etapa de propagação direta (que se parece com um prefill -- paralela, limitada por computação e eficiente). Se o modelo alvo concordar com as previsões do modelo de rascunho, você aceita todos os K tokens no tempo de uma única propagação direta do alvo. Se houver divergência na posição j, você aceita os tokens de 1 a j-1 e descarta os demais.

graph LR
    D["Draft model (1B)"] -->|"Gerar 5 tokens<br/>~5ms"| C["Candidatos: the cat sat on the"]
    C --> T["Target model (70B)"]
    T -->|"Verificar os 5 em uma passagem<br/>~70ms"| V{"Match?"}
    V -->|"4 de 5 coincidem"| A["Aceitar 4 tokens em 75ms<br/>vs 280ms sequencial"]
    V -->|"Incoerência na pos 5"| R["Rejeitar token 5<br/>Reamostrar do alvo"]

A aceleração obtida depende da taxa de aceitação -- a frequência com que as previsões do modelo de rascunho coincidem com as do alvo. Para um Llama 3 8B servindo de rascunho para um Llama 3 70B, taxas de aceitação de 70-85% são típicas em linguagem natural. Isso se traduz em uma aceleração de 2 a 3 vezes no decode.

Três abordagens para decodificação especulativa:

Método Origem do rascunho Taxa de aceitação Sobrecarga (Overhead)
Draft-target (Leviathan et al.) Modelo menor separado 70-85% Memória do modelo de rascunho
EAGLE (Li et al.) Cabeça leve sobre o alvo 75-90% ~1% de parâmetros extras
N-gram lookup Tabela de n-gramas de tokens 40-60% Desprezível

O EAGLE treina uma pequena cabeça autorregressiva sobre os estados ocultos (hidden states) do modelo alvo. Ele prevê o embedding do próximo token usando os recursos da penúltima camada do modelo alvo. Por operar nas próprias representações do modelo alvo (e não nas de um modelo separado), ele atinge taxas de aceitação mais altas com um acréscimo mínimo de memória. O EAGLE-2 adiciona uma árvore de rascunho dinâmica que ajusta a contagem de candidatos com base no contexto.

A decodificação especulativa N-gram mantém uma tabela de continuações de n-gramas a partir do contexto atual ou de um corpus pré-construído. Se o rascunho coincidir com o que apareceu antes na mesma conversa (padrões repetitivos, código, saídas estruturadas), ele é disparado com zero sobrecarga de rede neural. As taxas de aceitação são mais baixas em média, mas o custo por especulação é essencialmente gratuito.

A decodificação especulativa é matematicamente exata -- a distribuição de saída é idêntica à distribuição do modelo alvo. Não se trata de uma aproximação. A etapa de verificação garante que cada token aceito tenha exatamente a probabilidade que o modelo alvo teria atribuído a ele.

Cache de Prefixo

Muitas requisições compartilham o mesmo prefixo. Um prompt de sistema de chatbot. Um bloco de contexto RAG. Um conjunto de exemplos few-shot. Sem o cache de prefixo, cada requisição recomputa do zero o KV cache para esses tokens compartilhados.

O cache de prefixo armazena o KV cache para prefixos comuns e o reutiliza em todas as requisições. Quando uma nova requisição chega com um prefixo conhecido, o sistema copia (ou faz referência) às entradas de KV armazenadas em cache e apenas computa o KV para o sufixo exclusivo.

Para um prompt de sistema de 2.000 tokens compartilhado por todas as requisições, o cache de prefixo elimina ~400ms de prefill por requisição. A 100 requisições/segundo, isso economiza 40 segundos de computação da GPU por segundo -- mais do que o trabalho de uma GPU inteira.

O RadixAttention do SGLang implementa cache de prefixo com uma árvore radix (trie) que indexa prefixos pelo conteúdo de seus tokens. Qualquer requisição que coincida com um prefixo armazenado obtém seu KV cache gratuitamente. A árvore permite correspondências parciais de prefixo -- se você compartilhar 1.500 dos 2.000 tokens de prefixo com uma entrada em cache, você reutiliza esses 1.500 e recomputa apenas 500.

Inference Engines

Três motores dominam o serviço de LLM em produção:

Motor (Engine) Inovação chave Melhor para
vLLM PagedAttention, continuous batching Serviço de uso geral, maior compatibilidade
SGLang RadixAttention (cache de prefixo), geração estruturada Chatbots de múltiplos turnos, decodificação restrita
TensorRT-LLM Fusão de kernel da NVIDIA, quantização FP8 Taxa de transferência máxima em GPU única em hardware NVIDIA

vLLM é o ponto de partida padrão. Ele suporta a mais ampla variedade de modelos, roda em qualquer fabricante de GPU (NVIDIA, AMD, Intel) e atinge uma alta taxa de transferência através do PagedAttention + continuous batching. A API compatível com OpenAI significa que você pode usá-lo como substituto direto de qualquer chamada de API da OpenAI.

SGLang baseia-se nos mesmos fundamentos que o vLLM, mas adiciona RadixAttention para cache de prefixo e uma linguagem específica de domínio para programas de LLM estruturados. Se sua carga de trabalho envolve conversas de vários turnos, uso de ferramentas ou decodificação restrita (saída JSON, geração guiada por regex), o SGLang geralmente supera o vLLM em 2 a 5 vezes devido ao reuso de prefixo.

TensorRT-LLM compila modelos em kernels de GPU NVIDIA otimizados. Ele funde operações (atenção + linear + ativação em um único kernel), usa FP8 em GPUs H100 e se integra ao Servidor de Inferência NVIDIA Triton para implantação em produção. Ele atinge a maior taxa de transferência em GPU única em hardware NVIDIA, mas exige mais configuração e funciona apenas em GPUs NVIDIA.

Números do mundo real para o Llama 3 70B (4xA100-80GB, BF16):

Métrica vLLM SGLang TensorRT-LLM
Taxa de transferência (1 usuário) ~50 TPS ~55 TPS ~65 TPS
Taxa de transferência (100 usuários) ~2.500 TPS totais ~3.200 TPS totais ~3.000 TPS totais
Tempo até o primeiro token (TTFT) ~400ms ~300ms (hit de prefixo) ~350ms
Contexto máximo 128K 128K 128K

O Framework Ops:Byte

Você não pode otimizar o que não mede. A razão ops:byte diz se você está limitado por computação ou por memória, o que determina quais otimizações são importantes.

Compute roof: peak FLOPS of the GPU
Memory roof:  peak bandwidth * ops:byte ratio

Quando o ops:byte é baixo (decode, lotes pequenos), você atinge o teto da largura de banda da memória. Adicionar mais poder computacional (clock maior, mais núcleos) não ajuda. Você precisa reduzir as leituras de memória (quantização, compressão de KV cache) ou aumentar o tamanho do lote para diluir as leituras por mais trabalho útil.

Quando o ops:byte é alto (prefill, lotes grandes), você atinge o teto computacional. A otimização da largura de banda da memória não ajuda. Você precisa de GPUs mais rápidas, fusão de kernels ou precisão reduzida para espremer mais FLOPS.

Cenário ops:byte Limitado por (Bound) Otimizar com
Prefill, lote=1 ~4.096 Computação Fusão de kernel, FP8
Decode, lote=1 ~1 Memória Quantização, compressão de KV
Decode, lote=32 ~32 Memória Lote maior, continuous batching
Decode, lote=256 ~256 Transição Ambos importam
Decode, lote=1024 ~1.024 Computação Fusão de kernel, paralelismo de tensores

O ponto de transição (crossover) na A100 fica em torno de ops:byte = 156 (312 TFLOPS / 2 TB/s). Abaixo de 156, você está limitado por memória. Acima de 156, você está limitado por computação. O continuous batching empurra o decode em direção a essa transição, agrupando mais tokens por iteração.

Implementação (Build It)

Passo 1: KV Cache do Zero

Construímos um KV cache multi-cabeça que armazena projeções de chaves e valores por camada, por cabeça, e demonstra o padrão de crescimento de memória.

import numpy as np

class KVCache:
    def __init__(self, num_layers, num_heads, head_dim, max_seq_len, dtype=np.float16):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.dtype = dtype

        self.k_cache = np.zeros(
            (num_layers, num_heads, max_seq_len, head_dim), dtype=dtype
        )
        self.v_cache = np.zeros(
            (num_layers, num_heads, max_seq_len, head_dim), dtype=dtype
        )
        self.seq_len = 0

    def update(self, layer_idx, new_keys, new_values):
        num_new = new_keys.shape[1]
        end = self.seq_len + num_new
        self.k_cache[layer_idx, :, self.seq_len:end, :] = new_keys
        self.v_cache[layer_idx, :, self.seq_len:end, :] = new_values
        return (
            self.k_cache[layer_idx, :, :end, :],
            self.v_cache[layer_idx, :, :end, :]
        )

    def advance(self, num_tokens):
        self.seq_len += num_tokens

    def memory_bytes(self):
        return self.k_cache.nbytes + self.v_cache.nbytes

    def used_bytes(self):
        per_token = 2 * self.num_layers * self.num_heads * self.head_dim * np.dtype(self.dtype).itemsize
        return per_token * self.seq_len

Passo 2: Atenção com KV Cache

Uma atenção multi-cabeça simplificada que utiliza o KV cache para as etapas de decode.

def scaled_dot_product_attention(query, keys, values):
    head_dim = query.shape[-1]
    scores = np.matmul(query, keys.transpose(0, 1, 3, 2)) / np.sqrt(head_dim)
    seq_len_q = scores.shape[-2]
    seq_len_k = scores.shape[-1]
    if seq_len_q > 1:
        mask = np.triu(np.ones((seq_len_q, seq_len_k), dtype=np.float32), k=seq_len_k - seq_len_q + 1)
        scores = scores + mask * (-1e9)
    max_scores = np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(scores - max_scores)
    attn_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
    return np.matmul(attn_weights, values)


class MultiHeadAttention:
    def __init__(self, d_model, num_heads):
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        scale = np.sqrt(2.0 / d_model)
        self.W_q = np.random.randn(d_model, d_model).astype(np.float32) * scale
        self.W_k = np.random.randn(d_model, d_model).astype(np.float32) * scale
        self.W_v = np.random.randn(d_model, d_model).astype(np.float32) * scale
        self.W_o = np.random.randn(d_model, d_model).astype(np.float32) * scale

    def forward(self, x, kv_cache=None, layer_idx=0):
        batch, seq_len, d_model = x.shape
        Q = np.matmul(x, self.W_q).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        K = np.matmul(x, self.W_k).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        V = np.matmul(x, self.W_v).reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)

        if kv_cache is not None:
            K_full, V_full = kv_cache.update(layer_idx, K[0], V[0])
            K = K_full[np.newaxis, :, :, :]
            V = V_full[np.newaxis, :, :, :]
            if seq_len == 1:
                kv_cache.advance(1)

        attn_out = scaled_dot_product_attention(Q, K, V)
        attn_out = attn_out.transpose(0, 2, 1, 3).reshape(batch, -1, d_model)
        return np.matmul(attn_out, self.W_o)

Passo 3: Simulador de Continuous Batching

Simulação da diferença de agendamento entre lote estático e contínuo.

import heapq

class Request:
    def __init__(self, request_id, prompt_tokens, output_tokens, arrival_step):
        self.request_id = request_id
        self.prompt_tokens = prompt_tokens
        self.output_tokens = output_tokens
        self.arrival_step = arrival_step
        self.tokens_generated = 0
        self.start_step = None
        self.end_step = None

    def is_done(self):
        return self.tokens_generated >= self.output_tokens


def simulate_static_batching(requests, batch_size):
    step = 0
    completed = []
    queue = list(requests)
    queue.sort(key=lambda r: r.arrival_step)

    while queue:
        batch = []
        while queue and len(batch) < batch_size:
            r = queue.pop(0)
            r.start_step = max(step, r.arrival_step)
            batch.append(r)

        if batch:
            step = max(step, max(r.start_step for r in batch))
            max_output = max(r.output_tokens for r in batch)
            for r in batch:
                r.tokens_generated = r.output_tokens
                r.end_step = step + max_output
            step += max_output
            completed.extend(batch)

    return completed


def simulate_continuous_batching(requests, batch_size):
    step = 0
    completed = []
    queue = sorted(requests, key=lambda r: r.arrival_step)
    queue_idx = 0
    active = []
    waiting = []

    while queue_idx < len(queue) or active or waiting:
        while queue_idx < len(queue) and queue[queue_idx].arrival_step <= step:
            waiting.append(queue[queue_idx])
            queue_idx += 1

        while waiting and len(active) < batch_size:
            r = waiting.pop(0)
            r.start_step = step
            active.append(r)

        if not active:
            if waiting:
                step += 1
                continue
            elif queue_idx < len(queue):
                step = queue[queue_idx].arrival_step
                continue
            else:
                break

        for r in active:
            r.tokens_generated += 1

        done = [r for r in active if r.is_done()]
        for r in done:
            r.end_step = step + 1
            completed.append(r)
        active = [r for r in active if not r.is_done()]

        step += 1

    return completed


def batching_stats(completed):
    latencies = [r.end_step - r.arrival_step for r in completed]
    total_time = max(r.end_step for r in completed) - min(r.arrival_step for r in completed)
    total_tokens = sum(r.output_tokens for r in completed)
    return {
        "avg_latency": np.mean(latencies),
        "p50_latency": np.median(latencies),
        "p99_latency": np.percentile(latencies, 99),
        "total_time": total_time,
        "throughput": total_tokens / total_time if total_time > 0 else 0,
    }

Passo 4: Cache de Prefixo

Um cache de prefixo baseado em trie que armazena entradas de KV para prefixos compartilhados.

class TrieNode:
    def __init__(self):
        self.children = {}
        self.kv_data = None
        self.hit_count = 0


class PrefixCache:
    def __init__(self, max_entries=1000):
        self.root = TrieNode()
        self.max_entries = max_entries
        self.total_entries = 0
        self.hits = 0
        self.misses = 0

    def _walk(self, token_ids):
        node = self.root
        depth = 0
        for tid in token_ids:
            if tid not in node.children:
                break
            node = node.children[tid]
            depth += 1
        return node, depth

    def lookup(self, token_ids):
        node, depth = self._walk(token_ids)
        if depth > 0:
            self.hits += 1
            current = self.root
            for tid in token_ids[:depth]:
                current = current.children[tid]
                current.hit_count += 1
            kv_entries = []
            current = self.root
            for tid in token_ids[:depth]:
                current = current.children[tid]
                if current.kv_data is not None:
                    kv_entries.append(current.kv_data)
            return depth, kv_entries
        self.misses += 1
        return 0, []

    def insert(self, token_ids, kv_per_token):
        node = self.root
        for i, tid in enumerate(token_ids):
            if tid not in node.children:
                if self.total_entries >= self.max_entries:
                    return i
                node.children[tid] = TrieNode()
                self.total_entries += 1
            node = node.children[tid]
            if i < len(kv_per_token):
                node.kv_data = kv_per_token[i]
        return len(token_ids)

    def hit_rate(self):
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0.0

Passo 5: Simulador de Decodificação Especulativa

Simulamos a decodificação especulativa draft-target com taxas de aceitação configuráveis.

class DraftModel:
    def __init__(self, vocab_size, acceptance_rate=0.8):
        self.vocab_size = vocab_size
        self.acceptance_rate = acceptance_rate

    def generate(self, context, num_tokens):
        tokens = np.random.randint(0, self.vocab_size, size=num_tokens)
        return tokens

    def get_probs(self, context, token):
        probs = np.random.dirichlet(np.ones(self.vocab_size))
        return probs


class TargetModel:
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def get_probs(self, context, tokens=None):
        if tokens is not None:
            return [np.random.dirichlet(np.ones(self.vocab_size)) for _ in tokens]
        return np.random.dirichlet(np.ones(self.vocab_size))


def speculative_decode(draft_model, target_model, context, num_speculative=5,
                       draft_cost=1.0, target_cost=10.0, verify_cost=12.0):
    total_tokens = 0
    total_cost = 0.0
    accepted_counts = []
    context = list(context)

    max_tokens = 100

    while total_tokens < max_tokens:
        draft_tokens = draft_model.generate(context, num_speculative)
        total_cost += draft_cost * num_speculative

        target_probs = target_model.get_probs(context, draft_tokens)
        total_cost += verify_cost

        accepted = 0
        for i, token in enumerate(draft_tokens):
            draft_p = draft_model.get_probs(context + list(draft_tokens[:i]), token)
            target_p = target_probs[i]

            r = np.random.random()
            acceptance_prob = min(1.0, target_p[token] / (draft_p[token] + 1e-10))

            if r < draft_model.acceptance_rate:
                accepted += 1
                context.append(token)
                total_tokens += 1
            else:
                new_token = np.random.choice(draft_model.vocab_size, p=target_p)
                context.append(new_token)
                total_tokens += 1
                break

        accepted_counts.append(accepted)

        if accepted == num_speculative:
            bonus_probs = target_model.get_probs(context)
            bonus_token = np.random.choice(draft_model.vocab_size, p=bonus_probs)
            context.append(bonus_token)
            total_tokens += 1

    sequential_cost = total_tokens * target_cost
    return {
        "total_tokens": total_tokens,
        "speculative_cost": total_cost,
        "sequential_cost": sequential_cost,
        "speedup": sequential_cost / total_cost if total_cost > 0 else 1.0,
        "avg_accepted": np.mean(accepted_counts),
        "acceptance_rate": np.mean(accepted_counts) / num_speculative,
    }


def compare_speculation_strategies(vocab_size=1000, num_trials=20):
    results = {}

    for name, acceptance_rate, spec_tokens in [
        ("Draft-target (8B->70B)", 0.78, 5),
        ("EAGLE", 0.85, 6),
        ("N-gram", 0.50, 4),
        ("No speculation", 0.0, 0),
    ]:
        if spec_tokens == 0:
            results[name] = {
                "speedup": 1.0,
                "acceptance_rate": 0.0,
                "avg_accepted": 0.0,
            }
            continue

        trial_results = []
        for _ in range(num_trials):
            draft = DraftModel(vocab_size, acceptance_rate=acceptance_rate)
            target = TargetModel(vocab_size)
            context = list(np.random.randint(0, vocab_size, size=10))
            result = speculative_decode(draft, target, context, num_speculative=spec_tokens)
            trial_results.append(result)

        results[name] = {
            "speedup": np.mean([r["speedup"] for r in trial_results]),
            "acceptance_rate": np.mean([r["acceptance_rate"] for r in trial_results]),
            "avg_accepted": np.mean([r["avg_accepted"] for r in trial_results]),
        }

    return results

Passo 6: Profiler de Memória do KV Cache

Computar os requisitos de memória de KV cache para configurações reais de modelos.

MODEL_CONFIGS = {
    "Llama-3-8B": {
        "num_layers": 32, "num_kv_heads": 8, "head_dim": 128,
        "model_params_b": 8, "gqa": True,
    },
    "Llama-3-70B": {
        "num_layers": 80, "num_kv_heads": 8, "head_dim": 128,
        "model_params_b": 70, "gqa": True,
    },
    "Llama-3-405B": {
        "num_layers": 126, "num_kv_heads": 8, "head_dim": 128,
        "model_params_b": 405, "gqa": True,
    },
    "Mistral-7B": {
        "num_layers": 32, "num_kv_heads": 8, "head_dim": 128,
        "model_params_b": 7, "gqa": True,
    },
    "GPT-4-est": {
        "num_layers": 120, "num_kv_heads": 96, "head_dim": 128,
        "model_params_b": 1800, "gqa": False,
    },
}


def kv_cache_memory(config, seq_len, dtype_bytes=2):
    per_token = 2 * config["num_layers"] * config["num_kv_heads"] * config["head_dim"] * dtype_bytes
    total = per_token * seq_len
    return {
        "per_token_bytes": per_token,
        "per_token_kb": per_token / 1024,
        "total_bytes": total,
        "total_mb": total / (1024 ** 2),
        "total_gb": total / (1024 ** 3),
    }


def memory_budget(config, gpu_memory_gb, model_dtype_bytes=2, kv_dtype_bytes=2):
    model_memory_gb = config["model_params_b"] * 1e9 * model_dtype_bytes / (1024 ** 3)
    overhead_gb = gpu_memory_gb * 0.1
    available_for_kv = gpu_memory_gb - model_memory_gb - overhead_gb

    if available_for_kv <= 0:
        return {"error": "Model does not fit in GPU memory", "model_memory_gb": model_memory_gb}

    per_token = 2 * config["num_layers"] * config["num_kv_heads"] * config["head_dim"] * kv_dtype_bytes
    max_tokens = int(available_for_kv * (1024 ** 3) / per_token)

    return {
        "gpu_memory_gb": gpu_memory_gb,
        "model_memory_gb": round(model_memory_gb, 1),
        "overhead_gb": round(overhead_gb, 1),
        "available_for_kv_gb": round(available_for_kv, 1),
        "max_total_tokens": max_tokens,
        "max_users_at_2k": max_tokens // 2048,
        "max_users_at_4k": max_tokens // 4096,
        "max_users_at_32k": max_tokens // 32768,
    }

Utilização (Use It)

Com o vLLM:

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3-70B-Instruct",
    tensor_parallel_size=4,
    enable_prefix_caching=True,
    max_model_len=8192,
    gpu_memory_utilization=0.9,
)

params = SamplingParams(temperature=0.7, max_tokens=256)
outputs = llm.generate(["Explain inference optimization in one paragraph."], params)

Com o SGLang para cache de prefixo + saída estruturada:

import sglang as sgl

@sgl.function
def classify(s, text):
    s += sgl.system("You are a classifier. Output JSON only.")
    s += sgl.user(f"Classify this text: {text}")
    s += sgl.assistant(sgl.gen("result", regex=r'\{"label": "(positive|negative|neutral)"\}'))

runtime = sgl.Runtime(model_path="meta-llama/Llama-3-70B-Instruct", tp_size=4)
sgl.set_default_backend(runtime)

results = classify.run_batch([
    {"text": "This product is amazing!"},
    {"text": "Terrible experience."},
    {"text": "It was okay I guess."},
])

Com o TensorRT-LLM:

import tensorrt_llm
from tensorrt_llm.runtime import ModelRunner

runner = ModelRunner.from_dir("./llama-70b-trt-engine/", rank=0)

outputs = runner.generate(
    batch_input_ids=[tokenizer.encode("Explain KV caching.")],
    max_new_tokens=256,
    temperature=0.7,
)

Entrega (Ship It)

Esta lição produz:

Exercícios

  1. Modifique o profiler de KV cache para comparar a quantização do KV cache em FP16 vs FP8 vs INT4. Para o Llama 3 70B em contexto de 4K, compute a capacidade máxima de usuários concorrentes para cada caso em 4xA100-80GB. A quantização de KV para INT4 deve aumentar em aproximadamente 4x a capacidade de usuários.

  2. Estenda o simulador de continuous batching para rastrear a utilização de GPU (fração de slots de lote preenchidos por passo). Plote a utilização ao longo do tempo para o lote estático e contínuo com 50 requisições cujos comprimentos de saída seguem uma distribuição de Pareto (forma=1,5, escala=20). O continuous batching deve manter uma utilização superior a 80%.

  3. Implemente uma versão de atenção de consulta agrupada (grouped-query attention - GQA) do KV cache onde num_kv_heads < num_query_heads. O Llama 3 70B usa 64 cabeças de consulta, mas apenas 8 cabeças de KV. Calcule a economia de memória em comparação com a atenção multi-cabeça completa (redução de 8x no tamanho do KV cache).

  4. Construa um cache de prefixo que use a política de despejo LRU (Least Recently Used). Defina max_entries como 500 e gere 1.000 requisições nas quais 60% compartilham um de 5 prefixos comuns. Meça a taxa de acerto (hit rate) e compare com o cache ilimitado. Com uma boa política de despejo, a taxa de acerto deve permanecer acima de 55%.

  5. Estenda o simulador de decodificação especulativa para implementar a especulação baseada em árvore (estilo EAGLE-2). Em vez de uma única cadeia de K tokens de rascunho, gere uma árvore de candidatos (por exemplo, 2 ramificações em cada um dos 3 níveis = 8 candidatos finais). Compare o total de tokens aceitos por rodada de verificação em comparação com a especulação linear.

Termos-Chave

Termo O que dizem O que realmente significa
Prefill "Processar o prompt" Computar a atenção de todos os tokens de entrada em paralelo -- limitado por computação porque a multiplicação completa de matrizes mantém os núcleos da GPU ocupados
Decode "Gerar tokens" Produzir um token por etapa de propagação direta, lendo todos os pesos do modelo a cada vez -- limitado por memória porque a computação termina antes que os próximos pesos cheguem
KV cache "Salvar estados de atenção em cache" Armazenar as projeções de chaves e valores de todos os tokens anteriores para evitar que sejam recomputados a cada passo de decode -- troca memória por computação
Continuous batching "Lote dinâmico" Inserir novas requisições no lote em execução assim que qualquer requisição termina, avaliado a cada iteração de decode em vez de esperar por todo o lote
PagedAttention "Memória virtual para KV cache" Alocar o KV cache em páginas de tamanho fixo em vez de blocos contíguos, eliminando a fragmentação de memória e permitindo o copy-on-write para prefixos compartilhados
Decodificação especulativa "Esboçar e verificar" Usar um modelo rápido de rascunho para propor múltiplos tokens e depois verificar todos em uma única passagem direta do modelo alvo -- matematicamente exato, com aceleração de 2 a 3 vezes
EAGLE "Decodificação auto-especulativa" Uma variante da decodificação especulativa que treina uma cabeça leve nos próprios estados ocultos do modelo alvo, alcançando taxas de aceitação mais altas do que um modelo de rascunho separado
Cache de prefixo "Reutilizar o KV do prompt do sistema" Armazenar as entradas de KV cache computadas para prefixos comuns (prompts de sistema, exemplos few-shot) e reutilizá-las em diferentes requisições para evitar prefills redundantes
Razão ops:byte "Intensidade aritmética" A razão entre operações de computação e bytes de memória lidos -- determina se uma carga de trabalho é limitada por computação (razão alta) ou por memória (razão baixa)
Time to first token "TTFT (tempo até o primeiro token)" Latência entre receber uma requisição e gerar o primeiro token de saída -- dominado pelo tempo de prefill para prompts longos

Leituras Adicionais