Phase 10 - Lesson 18
Multi-Token Prediction (MTP)
This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.
Todo LLM autorregresivo do
GPT-2aoLlama 3treina com uma perda por posição: prever o próximo token. O DeepSeek-V3 adicionou uma segunda perda por posição: prever o token seguinte a esse. Os14Bextras de parâmetros (em um modelo de671B) foram destilados de volta para o modelo principal por meio do fluxo de gradiente, e as cabeças de MTP treinadas foram reaproveitadas na inferência como rascunhos de decodificação especulativa com mais de80%+de aceitação. Um ganho de1.8×na taxa de geração veio de graça. Esta lição constrói o módulo MTP sequencial a partir do relatório técnico do DeepSeek, calcula a perda e o layout de parâmetros de cabeça compartilhada, e explica por que o MTP mantém a cadeia causal enquanto o MTP paralelo original de Gloeckle et al. a quebrava.
Tipo: Build Linguagens: Python (stdlib) Pré-requisitos: Fase 10 · 04 (pré-treino de um mini GPT), Fase 10 · 15 (decodificação especulativa) Tempo: ~60 minutos
Objetivos de Aprendizado
- Definir o objetivo de treinamento do MTP e derivar a perda conjunta ao longo das profundidades de predição.
- Explicar a diferença entre as cabeças MTP paralelas de Gloeckle et al. (2024) e os módulos MTP sequenciais do DeepSeek-V3 e por que o design sequencial preserva a cadeia causal.
- Calcular o overhead de parâmetros e memória ao adicionar módulos MTP a uma execução de pré-treinamento.
- Implementar um módulo MTP do zero: o embedding compartilhado, o bloco transformer por profundidade, a projeção e a cabeça de saída compartilhada.
O Problema
A predição do próximo token é o objetivo padrão de treinamento de LLMs. Cada estado oculto é supervisionado para prever exatamente uma coisa: o token imediatamente seguinte. Esse é um sinal surpreendentemente fraco. A maior parte das informações em uma sequência estende-se para além de um único token — estrutura, coerência, factualidade, fluxo aritmético. O modelo precisa aprender tudo isso acumulando muitos sinais de um único token ao longo de trilhões de tokens.
O MTP pergunta: e se cada estado oculto fosse supervisionado para prever múltiplos tokens futuros de uma vez? Gloeckle et al. (Meta, 2024) mostraram que isso ajuda. A implementação deles colocava várias cabeças de saída independentes no topo do backbone, cada uma prevendo um deslocamento diferente. Paralelo, simples, mas as cabeças viam o mesmo estado oculto sem qualquer refinamento hierárquico — e as predições não se encadeavam causalmente, de modo que não podiam ser usadas para decodificação especulativa.
O DeepSeek-V3 (dezembro de 2024) reformulou o MTP como módulos sequenciais que mantêm a cadeia causal em cada profundidade de predição. O modelo prevê t+1 a partir de h_i^(0), depois prevê t+2 a partir de um novo estado oculto h_i^(1) que combina h_i^(0) com o embedding E(t+1), e assim por diante. Cada profundidade é seu próprio pequeno bloco transformer. O embedding compartilhado e a cabeça de saída compartilhada mantêm o overhead de parâmetros modesto. Na escala do DeepSeek-V3, são 14B de parâmetros extras nos módulos MTP além dos pesos de 671B do modelo principal. Esse overhead de 2% trouxe sinais de treinamento mais densos E um rascunho de decodificação especulativa pronto para uso na inferência.
Esta lição constrói um único módulo MTP e a perda de profundidade D do zero. A matemática é organizada. A implementação tem 150 linhas.
O Conceito
A receita do MTP sequencial
O DeepSeek-V3 adiciona D módulos MTP no topo do modelo principal. Cada módulo k (para k = 1..D) prevê o token na profundidade k — ou seja, t_{i+k} dado um prefixo até a posição i.
O módulo k consiste em:
- Um bloco transformer
T_kcom sua própria atenção e MLP. - Uma matriz de projeção
M_kque combina o estado oculto da profundidade anterior com o embedding do token real (ground-truth) da próxima profundidade. - O embedding compartilhado
E(o mesmo do modelo principal). - A cabeça de saída compartilhada
Out(a mesma do modelo principal).
No treinamento, para um prefixo até a posição i, o estado oculto por profundidade é:
h_i^(0) = main model backbone at position i
h_i^(k) = T_k( M_k * concat(RMSNorm(h_i^(k-1)), RMSNorm(E(t_{i+k}))) ) for k >= 1
A predição por profundidade é:
logits_{i+k} = Out(h_i^(k-1)) for k = 1..D
A perda por profundidade é a entropia cruzada contra o token real t_{i+k}:
L_k = CE(logits_{i+k}, t_{i+k})
A perda conjunta ao longo das profundidades:
L_MTP = (lambda / D) * sum_{k=1..D} L_k
lambda é um pequeno fator de ponderação — o DeepSeek-V3 usa 0,3 para os primeiros 10% do treinamento e 0,1 depois disso. A perda total de treinamento é L_main + L_MTP.
Por que sequencial, e não paralelo
O MTP paralelo original de Gloeckle tinha D cabeças de saída, cada uma aplicada diretamente a h_i^(0). Cada cabeça prevê t_{i+k} a partir do mesmo estado oculto do backbone. Isso treina bem, mas as predições não são condicionadas umas nas outras. Você não pode usar a saída da head_1 para ajudar a head_2 — as cabeças disparam em paralelo.
O design sequencial do DeepSeek-V3 constrói h_i^(k) a partir de h_i^(k-1) mais o embedding do próximo token real E(t_{i+k}). Isso preserva a cadeia causal: para prever t_{i+k+1}, o módulo na profundidade k+1 vê o que estava em t_{i+k}. Isso é estruturalmente idêntico a como um decodificador autorregressivo consome sua própria saída — tornando os módulos MTP diretamente utilizáveis como rascunhos para decodificação especulativa.
Na inferência: alimente h_i^(k-1) e o token rascunhado t_{i+k} no módulo k+1, obtendo uma predição para t_{i+k+1}. Repita. Isso é exatamente um rascunho no estilo EAGLE, usando o módulo MTP treinado como a rede de rascunho (draft network). O DeepSeek-V3 relata mais de 80% de aceitação no primeiro módulo MTP e um ganho de velocidade de cerca de 1.8×.
Contabilidade de parâmetros
Para um modelo com tamanho oculto h e vocabulário V:
- Modelo principal: bilhões de parâmetros, mais uma cabeça de saída de tamanho
V * h. - Cabeça de saída compartilhada: reutiliza a cabeça do modelo principal. Sem parâmetros extras.
- Embedding compartilhado: reutiliza o embedding do modelo principal. Sem parâmetros extras.
- Por módulo MTP:
- Projeção
M_k:(2h) * h = 2h^2. - Bloco transformer
T_k: atenção (4h^2para MHA) mais MLP (geralmente8h^2para SwiGLU com razão 8/3). Cerca de12h^2por bloco.
- Projeção
Extra total por módulo: ~14h^2. Para o h = 7168 do DeepSeek-V3, D = 1 módulo: ~14 * 7168^2 = ~720M parâmetros no papel. O DeepSeek-V3 relata 14B — a diferença se deve principalmente ao fato de as camadas de especialistas também serem MoE no módulo MTP.
O retorno da decodificação especulativa
Durante o pré-treinamento, os módulos MTP deixam o treinamento cerca de 10% mais lento (mais computação forward, perda extra). O retorno é duplo:
Sinal de treinamento mais denso. Cada estado oculto vê D+1 alvos de supervisão. Efeito medido no MMLU, GSM8K, MATH e HumanEval: melhorias consistentes de alguns pontos percentuais nas ablações do DeepSeek-V3.
Rascunho gratuito para decodificação especulativa na inferência. O módulo MTP já está treinado para prever os próximos tokens. Reaproveitado como uma rede de rascunho, ele entrega taxas de aceitação de mais de 80%. Nesse nível, decodificação especulativa com N=3 ou N=5 fornece
1.8×mais throughput. O custo de 10% no tempo de treinamento se paga na primeira vez que você executa a inferência.
Relação com o EAGLE
O EAGLE treina um pequeno modelo de rascunho SEPARADAMENTE após o pré-treinamento. O MTP incorpora o rascunho diretamente no pré-treinamento. As duas abordagens convergem para taxas de aceitação semelhantes, mas por pipelines diferentes:
| Dimensão | EAGLE-3 | MTP (DeepSeek-V3) |
|---|---|---|
| Quando treinado | Pós-pré-treinamento | Durante o pré-treinamento |
| Retrocompatível com pesos existentes | Sim | Não (requer retreinamento) |
| Parâmetros do rascunho | 1-2 camadas transformer | 1 bloco transformer + projeção |
| Taxa de aceitação | 0.88-0.92 | 0.80+ na profundidade 1 |
| Benefício além do ganho de velocidade | Apenas decodificação especulativa | Sinal de treinamento mais denso + ganho de velocidade |
Construa
code/main.py constrói um único módulo MTP de ponta a ponta: embedding compartilhado, projeção, bloco transformer, cabeça de saída compartilhada. Em seguida, calcula a perda de entropia cruzada por profundidade em uma sequência sintética curta e imprime a contagem de parâmetros por componente. Um vocabulário de brinquedo de 32 tokens mantém os números legíveis.
Passo 1: tabela de embedding compartilhada
Uma única tabela de vocab_size x hidden é usada pelo modelo principal E por cada módulo MTP em todas as profundidades. Não é uma segunda cópia — literalmente o mesmo tensor.
Passo 2: a combinação por profundidade
def combine(prev_hidden, next_token_embed, M_k):
# concat along feature dim, then project down to hidden
concat = rms_norm(prev_hidden) + rms_norm(next_token_embed) # vector addition stand-in
projected = matvec(M_k, concat)
return projected
O DeepSeek-V3 real concatena os dois vetores com RMSNorm aplicada para obter [2h] e projeta com uma matriz de h x 2h. O brinquedo usa adição de vetores por questões de brevidade na biblioteca padrão.
Passo 3: o bloco transformer na profundidade k
Atenção própria mais MLP. No brinquedo, um bloco de atenção linear de uma camada e um MLP SwiGLU mantêm a estrutura visível sem numpy.
Passo 4: a cabeça de saída compartilhada
Reutiliza a projeção de saída do modelo principal. Logits sobre o vocabulário.
Passo 5: perda por profundidade
Entropia cruzada de softmax(logits) contra o token real no deslocamento k. Agregue ao longo das profundidades com o fator de escala lambda / D.
Passo 6: contabilidade de parâmetros
Imprima a contagem total de parâmetros, a contagem compartilhada (embedding, cabeça) e a contagem extra por módulo. Mostre a razão de parâmetros extras do MTP em relação ao tamanho do modelo principal.
Use
O MTP está integrado ao DeepSeek-V3 (dezembro de 2024) e à série DeepSeek-R1. Na inferência:
- A própria stack de serviço da DeepSeek consome módulos MTP como decodificadores especulativos nativamente.
- O vLLM e o SGLang têm caminhos de integração para o MTP do DeepSeek-V3 desde abril de 2026.
- O tutorial de ROCm SGLang da AMD mostra uma configuração específica de decodificação especulativa MTP com ganho medido de
1.8×no checkpoint V3.
Quando usar MTP em um novo pré-treinamento:
- Você controla todo o pipeline de pré-treinamento e deseja acumular um sinal de treinamento mais denso.
- Você sabe que servirá o modelo em escala e deseja decodificação especulativa gratuitamente.
- Seu tamanho oculto é de pelo menos 4096. Na escala de 1B, o overhead prejudica mais do que o ganho ajuda.
Quando não usar:
- Ajuste fino (fine-tuning) de um modelo denso pré-treinado existente. O módulo MTP não está treinado.
- Modelos de pesquisa onde você deseja uma linha de base limpa para comparação. O MTP altera a arquitetura.
Envie
Esta lição produz outputs/skill-mtp-planner.md. Dada uma especificação de execução de pré-treinamento (tamanho do modelo, dados, computação), ela retorna um plano para integrar o MTP: número de profundidades D, cronograma de lambda, overhead de memória e conexões da decodificação especulativa em tempo de inferência.
Exercícios
Execute
code/main.py. Mostre que a perda por profundidade diminui monotonicamente à medida que o sinal sintético se fortalece. Modifique o sintético para usar um padrão fixo e verifique se as perdas de profundidade-1 e profundidade-2 convergem.Calcule o overhead de parâmetros para um modelo denso de 70B (oculto 8192, 80 layers) com D=1 módulo MTP. Compare com o overhead de 14B relatado pelo DeepSeek-V3. Explique por que o número da DeepSeek é maior: o bloco transformer MTP herda a mesma estrutura MoE, inflando a contagem de parâmetros por módulo.
Implemente D=2 no modelo de brinquedo: adicione um segundo módulo MTP que recebe h^(1) e prevê
t_{i+2}. Verifique se a perda conjunta e a contabilidade de parâmetros correspondem às equações 19-21 do artigo da DeepSeek.Mude o brinquedo para MTP paralelo (estilo Gloeckle): adicione D cabeças de saída no topo do estado oculto principal, cada uma prevendo um deslocamento diferente. Meça como as perdas por profundidade se comparam à versão sequencial no mesmo sinal sintético. A versão sequencial deve produzir uma perda menor na profundidade-k para k > 1 porque ela se condiciona nas predições intermediárias.
Use o módulo MTP treinado como um rascunho estilo EAGLE: chame o módulo k para propor
t_{i+k}na inferência. Meça a taxa de aceitação desses tokens de rascunho em relação às predições do modelo principal em uma sequência de teste reservada. Se você atingir mais de 50% no modelo de brinquedo, terá reproduzido a propriedade empírica do MTP como rascunho.
Termos-Chave
| Termo | O que dizem | O que realmente significa |
|---|---|---|
| Módulo MTP | "Bloco de perda extra" | Um pequeno bloco transformer mais projeção que prevê um token k posições à frente do modelo principal |
| Profundidade de predição | "Qual deslocamento" | O inteiro k tal que o módulo k prevê t_{i+k} a partir do prefixo até a posição i |
| MTP paralelo | "Estilo Gloeckle" | D cabeças independentes no mesmo estado oculto do backbone, sem cadeia condicional |
| MTP sequencial | "Estilo DeepSeek-V3" | Cada módulo se condiciona no estado oculto da profundidade anterior mais o embedding do próximo token; preserva a cadeia causal |
| Cabeça de saída compartilhada | "Reutilizar a cabeça principal" | Os módulos MTP chamam a cabeça LM do modelo principal, não uma projeção de saída separada |
| Embedding compartilhado | "Reutilizar a tabela principal" | A mesma tabela de embeddings do vocabulário é usada em todos os lugares; sem parâmetros duplicados |
| Matriz de projeção M_k | "Combinar oculto + próximo token" | Uma camada linear h x 2h que mescla o estado oculto anterior e o embedding do token alvo na entrada da próxima profundidade |
| Perda conjunta L_MTP | "Média das perdas extras" | Média aritmética das perdas de entropia cruzada por profundidade, ponderada por lambda |
| Taxa de aceitação na profundidade 1 | "Frequência de acerto do rascunho MTP" | A taxa em que a predição top-1 do módulo MTP com D=1 é igual à predição top-1 do modelo principal; 80%+ no DeepSeek-V3 |
| Ponderação Lambda | "Importância da perda extra" | Fator de escala por profundidade; 0,3 no início do treinamento, 0,1 depois no DeepSeek-V3 |
Leituras Adicionais
- DeepSeek-AI — DeepSeek-V3 Technical Report (arXiv:2412.19437) — a descrição completa do MTP sequencial (Seção 2.2), incluindo as equações de perda conjunta e o ganho de velocidade de 1.8× na inferência
- Gloeckle et al. — Better & Faster Large Language Models via Multi-token Prediction (arXiv:2404.19737) — a linha de base de MTP paralelo sobre a qual o design do DeepSeek melhora
- DeepSeek-V3 model card no Hugging Face — 685B total (671B principal + 14B MTP), notas de implantação
- Leviathan et al. — Fast Inference from Transformers via Speculative Decoding (arXiv:2211.17192) — o framework de decodificação especulativa no qual o MTP se encaixa
- Li et al. — EAGLE-3 (arXiv:2503.01840) — a arquitetura de rascunho de 2025 do EAGLE, contraparte com a qual o MTP compete