Phase 10 - Lesson 18

Multi-Token Prediction (MTP)

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

Todo LLM autorregresivo do GPT-2 ao Llama 3 treina com uma perda por posição: prever o próximo token. O DeepSeek-V3 adicionou uma segunda perda por posição: prever o token seguinte a esse. Os 14B extras de parâmetros (em um modelo de 671B) foram destilados de volta para o modelo principal por meio do fluxo de gradiente, e as cabeças de MTP treinadas foram reaproveitadas na inferência como rascunhos de decodificação especulativa com mais de 80%+ de aceitação. Um ganho de 1.8× na taxa de geração veio de graça. Esta lição constrói o módulo MTP sequencial a partir do relatório técnico do DeepSeek, calcula a perda e o layout de parâmetros de cabeça compartilhada, e explica por que o MTP mantém a cadeia causal enquanto o MTP paralelo original de Gloeckle et al. a quebrava.

Tipo: Build Linguagens: Python (stdlib) Pré-requisitos: Fase 10 · 04 (pré-treino de um mini GPT), Fase 10 · 15 (decodificação especulativa) Tempo: ~60 minutos

Objetivos de Aprendizado

  • Definir o objetivo de treinamento do MTP e derivar a perda conjunta ao longo das profundidades de predição.
  • Explicar a diferença entre as cabeças MTP paralelas de Gloeckle et al. (2024) e os módulos MTP sequenciais do DeepSeek-V3 e por que o design sequencial preserva a cadeia causal.
  • Calcular o overhead de parâmetros e memória ao adicionar módulos MTP a uma execução de pré-treinamento.
  • Implementar um módulo MTP do zero: o embedding compartilhado, o bloco transformer por profundidade, a projeção e a cabeça de saída compartilhada.

O Problema

A predição do próximo token é o objetivo padrão de treinamento de LLMs. Cada estado oculto é supervisionado para prever exatamente uma coisa: o token imediatamente seguinte. Esse é um sinal surpreendentemente fraco. A maior parte das informações em uma sequência estende-se para além de um único token — estrutura, coerência, factualidade, fluxo aritmético. O modelo precisa aprender tudo isso acumulando muitos sinais de um único token ao longo de trilhões de tokens.

O MTP pergunta: e se cada estado oculto fosse supervisionado para prever múltiplos tokens futuros de uma vez? Gloeckle et al. (Meta, 2024) mostraram que isso ajuda. A implementação deles colocava várias cabeças de saída independentes no topo do backbone, cada uma prevendo um deslocamento diferente. Paralelo, simples, mas as cabeças viam o mesmo estado oculto sem qualquer refinamento hierárquico — e as predições não se encadeavam causalmente, de modo que não podiam ser usadas para decodificação especulativa.

O DeepSeek-V3 (dezembro de 2024) reformulou o MTP como módulos sequenciais que mantêm a cadeia causal em cada profundidade de predição. O modelo prevê t+1 a partir de h_i^(0), depois prevê t+2 a partir de um novo estado oculto h_i^(1) que combina h_i^(0) com o embedding E(t+1), e assim por diante. Cada profundidade é seu próprio pequeno bloco transformer. O embedding compartilhado e a cabeça de saída compartilhada mantêm o overhead de parâmetros modesto. Na escala do DeepSeek-V3, são 14B de parâmetros extras nos módulos MTP além dos pesos de 671B do modelo principal. Esse overhead de 2% trouxe sinais de treinamento mais densos E um rascunho de decodificação especulativa pronto para uso na inferência.

Esta lição constrói um único módulo MTP e a perda de profundidade D do zero. A matemática é organizada. A implementação tem 150 linhas.

O Conceito

A receita do MTP sequencial

O DeepSeek-V3 adiciona D módulos MTP no topo do modelo principal. Cada módulo k (para k = 1..D) prevê o token na profundidade k — ou seja, t_{i+k} dado um prefixo até a posição i.

O módulo k consiste em:

  • Um bloco transformer T_k com sua própria atenção e MLP.
  • Uma matriz de projeção M_k que combina o estado oculto da profundidade anterior com o embedding do token real (ground-truth) da próxima profundidade.
  • O embedding compartilhado E (o mesmo do modelo principal).
  • A cabeça de saída compartilhada Out (a mesma do modelo principal).

No treinamento, para um prefixo até a posição i, o estado oculto por profundidade é:

h_i^(0) = main model backbone at position i
h_i^(k) = T_k( M_k * concat(RMSNorm(h_i^(k-1)), RMSNorm(E(t_{i+k}))) )   for k >= 1

A predição por profundidade é:

logits_{i+k} = Out(h_i^(k-1))   for k = 1..D

A perda por profundidade é a entropia cruzada contra o token real t_{i+k}:

L_k = CE(logits_{i+k}, t_{i+k})

A perda conjunta ao longo das profundidades:

L_MTP = (lambda / D) * sum_{k=1..D} L_k

lambda é um pequeno fator de ponderação — o DeepSeek-V3 usa 0,3 para os primeiros 10% do treinamento e 0,1 depois disso. A perda total de treinamento é L_main + L_MTP.

Por que sequencial, e não paralelo

O MTP paralelo original de Gloeckle tinha D cabeças de saída, cada uma aplicada diretamente a h_i^(0). Cada cabeça prevê t_{i+k} a partir do mesmo estado oculto do backbone. Isso treina bem, mas as predições não são condicionadas umas nas outras. Você não pode usar a saída da head_1 para ajudar a head_2 — as cabeças disparam em paralelo.

O design sequencial do DeepSeek-V3 constrói h_i^(k) a partir de h_i^(k-1) mais o embedding do próximo token real E(t_{i+k}). Isso preserva a cadeia causal: para prever t_{i+k+1}, o módulo na profundidade k+1 vê o que estava em t_{i+k}. Isso é estruturalmente idêntico a como um decodificador autorregressivo consome sua própria saída — tornando os módulos MTP diretamente utilizáveis como rascunhos para decodificação especulativa.

Na inferência: alimente h_i^(k-1) e o token rascunhado t_{i+k} no módulo k+1, obtendo uma predição para t_{i+k+1}. Repita. Isso é exatamente um rascunho no estilo EAGLE, usando o módulo MTP treinado como a rede de rascunho (draft network). O DeepSeek-V3 relata mais de 80% de aceitação no primeiro módulo MTP e um ganho de velocidade de cerca de 1.8×.

Contabilidade de parâmetros

Para um modelo com tamanho oculto h e vocabulário V:

  • Modelo principal: bilhões de parâmetros, mais uma cabeça de saída de tamanho V * h.
  • Cabeça de saída compartilhada: reutiliza a cabeça do modelo principal. Sem parâmetros extras.
  • Embedding compartilhado: reutiliza o embedding do modelo principal. Sem parâmetros extras.
  • Por módulo MTP:
    • Projeção M_k: (2h) * h = 2h^2.
    • Bloco transformer T_k: atenção (4h^2 para MHA) mais MLP (geralmente 8h^2 para SwiGLU com razão 8/3). Cerca de 12h^2 por bloco.

Extra total por módulo: ~14h^2. Para o h = 7168 do DeepSeek-V3, D = 1 módulo: ~14 * 7168^2 = ~720M parâmetros no papel. O DeepSeek-V3 relata 14B — a diferença se deve principalmente ao fato de as camadas de especialistas também serem MoE no módulo MTP.

O retorno da decodificação especulativa

Durante o pré-treinamento, os módulos MTP deixam o treinamento cerca de 10% mais lento (mais computação forward, perda extra). O retorno é duplo:

  1. Sinal de treinamento mais denso. Cada estado oculto vê D+1 alvos de supervisão. Efeito medido no MMLU, GSM8K, MATH e HumanEval: melhorias consistentes de alguns pontos percentuais nas ablações do DeepSeek-V3.

  2. Rascunho gratuito para decodificação especulativa na inferência. O módulo MTP já está treinado para prever os próximos tokens. Reaproveitado como uma rede de rascunho, ele entrega taxas de aceitação de mais de 80%. Nesse nível, decodificação especulativa com N=3 ou N=5 fornece 1.8× mais throughput. O custo de 10% no tempo de treinamento se paga na primeira vez que você executa a inferência.

Relação com o EAGLE

O EAGLE treina um pequeno modelo de rascunho SEPARADAMENTE após o pré-treinamento. O MTP incorpora o rascunho diretamente no pré-treinamento. As duas abordagens convergem para taxas de aceitação semelhantes, mas por pipelines diferentes:

Dimensão EAGLE-3 MTP (DeepSeek-V3)
Quando treinado Pós-pré-treinamento Durante o pré-treinamento
Retrocompatível com pesos existentes Sim Não (requer retreinamento)
Parâmetros do rascunho 1-2 camadas transformer 1 bloco transformer + projeção
Taxa de aceitação 0.88-0.92 0.80+ na profundidade 1
Benefício além do ganho de velocidade Apenas decodificação especulativa Sinal de treinamento mais denso + ganho de velocidade

Construa

code/main.py constrói um único módulo MTP de ponta a ponta: embedding compartilhado, projeção, bloco transformer, cabeça de saída compartilhada. Em seguida, calcula a perda de entropia cruzada por profundidade em uma sequência sintética curta e imprime a contagem de parâmetros por componente. Um vocabulário de brinquedo de 32 tokens mantém os números legíveis.

Passo 1: tabela de embedding compartilhada

Uma única tabela de vocab_size x hidden é usada pelo modelo principal E por cada módulo MTP em todas as profundidades. Não é uma segunda cópia — literalmente o mesmo tensor.

Passo 2: a combinação por profundidade

def combine(prev_hidden, next_token_embed, M_k):
    # concat along feature dim, then project down to hidden
    concat = rms_norm(prev_hidden) + rms_norm(next_token_embed)  # vector addition stand-in
    projected = matvec(M_k, concat)
    return projected

O DeepSeek-V3 real concatena os dois vetores com RMSNorm aplicada para obter [2h] e projeta com uma matriz de h x 2h. O brinquedo usa adição de vetores por questões de brevidade na biblioteca padrão.

Passo 3: o bloco transformer na profundidade k

Atenção própria mais MLP. No brinquedo, um bloco de atenção linear de uma camada e um MLP SwiGLU mantêm a estrutura visível sem numpy.

Passo 4: a cabeça de saída compartilhada

Reutiliza a projeção de saída do modelo principal. Logits sobre o vocabulário.

Passo 5: perda por profundidade

Entropia cruzada de softmax(logits) contra o token real no deslocamento k. Agregue ao longo das profundidades com o fator de escala lambda / D.

Passo 6: contabilidade de parâmetros

Imprima a contagem total de parâmetros, a contagem compartilhada (embedding, cabeça) e a contagem extra por módulo. Mostre a razão de parâmetros extras do MTP em relação ao tamanho do modelo principal.

Use

O MTP está integrado ao DeepSeek-V3 (dezembro de 2024) e à série DeepSeek-R1. Na inferência:

  • A própria stack de serviço da DeepSeek consome módulos MTP como decodificadores especulativos nativamente.
  • O vLLM e o SGLang têm caminhos de integração para o MTP do DeepSeek-V3 desde abril de 2026.
  • O tutorial de ROCm SGLang da AMD mostra uma configuração específica de decodificação especulativa MTP com ganho medido de 1.8× no checkpoint V3.

Quando usar MTP em um novo pré-treinamento:

  • Você controla todo o pipeline de pré-treinamento e deseja acumular um sinal de treinamento mais denso.
  • Você sabe que servirá o modelo em escala e deseja decodificação especulativa gratuitamente.
  • Seu tamanho oculto é de pelo menos 4096. Na escala de 1B, o overhead prejudica mais do que o ganho ajuda.

Quando não usar:

  • Ajuste fino (fine-tuning) de um modelo denso pré-treinado existente. O módulo MTP não está treinado.
  • Modelos de pesquisa onde você deseja uma linha de base limpa para comparação. O MTP altera a arquitetura.

Envie

Esta lição produz outputs/skill-mtp-planner.md. Dada uma especificação de execução de pré-treinamento (tamanho do modelo, dados, computação), ela retorna um plano para integrar o MTP: número de profundidades D, cronograma de lambda, overhead de memória e conexões da decodificação especulativa em tempo de inferência.

Exercícios

  1. Execute code/main.py. Mostre que a perda por profundidade diminui monotonicamente à medida que o sinal sintético se fortalece. Modifique o sintético para usar um padrão fixo e verifique se as perdas de profundidade-1 e profundidade-2 convergem.

  2. Calcule o overhead de parâmetros para um modelo denso de 70B (oculto 8192, 80 layers) com D=1 módulo MTP. Compare com o overhead de 14B relatado pelo DeepSeek-V3. Explique por que o número da DeepSeek é maior: o bloco transformer MTP herda a mesma estrutura MoE, inflando a contagem de parâmetros por módulo.

  3. Implemente D=2 no modelo de brinquedo: adicione um segundo módulo MTP que recebe h^(1) e prevê t_{i+2}. Verifique se a perda conjunta e a contabilidade de parâmetros correspondem às equações 19-21 do artigo da DeepSeek.

  4. Mude o brinquedo para MTP paralelo (estilo Gloeckle): adicione D cabeças de saída no topo do estado oculto principal, cada uma prevendo um deslocamento diferente. Meça como as perdas por profundidade se comparam à versão sequencial no mesmo sinal sintético. A versão sequencial deve produzir uma perda menor na profundidade-k para k > 1 porque ela se condiciona nas predições intermediárias.

  5. Use o módulo MTP treinado como um rascunho estilo EAGLE: chame o módulo k para propor t_{i+k} na inferência. Meça a taxa de aceitação desses tokens de rascunho em relação às predições do modelo principal em uma sequência de teste reservada. Se você atingir mais de 50% no modelo de brinquedo, terá reproduzido a propriedade empírica do MTP como rascunho.

Termos-Chave

Termo O que dizem O que realmente significa
Módulo MTP "Bloco de perda extra" Um pequeno bloco transformer mais projeção que prevê um token k posições à frente do modelo principal
Profundidade de predição "Qual deslocamento" O inteiro k tal que o módulo k prevê t_{i+k} a partir do prefixo até a posição i
MTP paralelo "Estilo Gloeckle" D cabeças independentes no mesmo estado oculto do backbone, sem cadeia condicional
MTP sequencial "Estilo DeepSeek-V3" Cada módulo se condiciona no estado oculto da profundidade anterior mais o embedding do próximo token; preserva a cadeia causal
Cabeça de saída compartilhada "Reutilizar a cabeça principal" Os módulos MTP chamam a cabeça LM do modelo principal, não uma projeção de saída separada
Embedding compartilhado "Reutilizar a tabela principal" A mesma tabela de embeddings do vocabulário é usada em todos os lugares; sem parâmetros duplicados
Matriz de projeção M_k "Combinar oculto + próximo token" Uma camada linear h x 2h que mescla o estado oculto anterior e o embedding do token alvo na entrada da próxima profundidade
Perda conjunta L_MTP "Média das perdas extras" Média aritmética das perdas de entropia cruzada por profundidade, ponderada por lambda
Taxa de aceitação na profundidade 1 "Frequência de acerto do rascunho MTP" A taxa em que a predição top-1 do módulo MTP com D=1 é igual à predição top-1 do modelo principal; 80%+ no DeepSeek-V3
Ponderação Lambda "Importância da perda extra" Fator de escala por profundidade; 0,3 no início do treinamento, 0,1 depois no DeepSeek-V3

Leituras Adicionais

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).