Phase 10 - Lesson 18
Multi-Token Prediction (MTP)
This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.
Todo LLM autorregresivo desde
GPT-2hastaLlama 3se entrena con una pérdida por posición: predecir el siguiente token. DeepSeek-V3 añadió una segunda pérdida por posición: predecir el token posterior a ese. Los14Bextras de parámetros (en un modelo de671B) se destilaron de vuelta al modelo principal a través del flujo de gradientes, y las cabezas de MTP entrenadas se reutilizaron en la inferencia como borradores de decodificación especulativa con más del80%+de aceptación. Se obtuvo un rendimiento de generación de1.8×de forma gratuita. Esta lección construye el módulo MTP secuencial a partir del reporte técnico de DeepSeek, calcula la pérdida y la distribución de parámetros de cabeza compartida, y explica por qué MTP mantiene la cadena causal mientras que el MTP paralelo original de Gloeckle et al. la rompía.
Tipo: Build Idiomas: Python (stdlib) Prerrequisitos: Fase 10 · 04 (pre-entrenamiento de un mini GPT), Fase 10 · 15 (decodificación especulativa) Tiempo: ~60 minutos
Objetivos de Aprendizaje
- Definir el objetivo de entrenamiento de MTP y derivar la pérdida conjunta a través de las profundidades de predicción.
- Explicar la diferencia entre las cabezas MTP paralelas de Gloeckle et al. (2024) y los módulos MTP secuenciales de DeepSeek-V3, y por qué el diseño secuencial preserva la cadena causal.
- Calcular la sobrecarga de parámetros y memoria al añadir módulos MTP a una ejecución de pre-entrenamiento.
- Implementar un módulo MTP desde cero: el embedding compartido, el bloque transformer por profundidad, la proyección y la cabeza de salida compartida.
El Problema
La predicción del siguiente token es el objetivo de entrenamiento estándar de los LLM. Cada estado oculto es supervisado para predecir exactamente una cosa: el token inmediatamente siguiente. Esa es una señal sorprendentemente débil. La mayor parte de la información en una secuencia se extiende más allá de un solo token: estructura, coherencia, factualidad, flujo aritmético. El modelo tiene que aprender esto acumulando muchas señales de un solo token a lo largo de billones de tokens.
MTP se pregunta: ¿qué pasaría si cada estado oculto fuera supervisado para predecir múltiples tokens futuros a la vez? Gloeckle et al. (Meta, 2024) demostraron que esto ayuda. Su implementación colocaba varias cabezas de salida independientes sobre el backbone, cada una prediciendo un desplazamiento diferente. Paralelo, simple, pero las cabezas veían el mismo estado oculto sin ningún refinamiento jerárquico, y las predicciones no se encadenaban causalmente, por lo que no podían usarse para la decodificación especulativa.
DeepSeek-V3 (diciembre de 2024) rediseñó MTP como módulos secuenciales que mantienen la cadena causal en cada profundidad de predicción. El modelo predice t+1 a partir de h_i^(0), luego predice t+2 a partir de un nuevo estado oculto h_i^(1) que combina h_i^(0) con el embedding E(t+1), y así sucesivamente. Cada profundidad es su propio bloque transformer pequeño. El embedding compartido y la cabeza de salida compartida mantienen la sobrecarga de parámetros modesta. A la escala de DeepSeek-V3, se añaden 14B de parámetros adicionales en los módulos MTP sobre los pesos de 671B del modelo principal. Esa sobrecarga del 2% compró señales de entrenamiento más densas Y un borrador de decodificación especulativa listo para usar en la inferencia.
Esta lección construye un único módulo MTP y la pérdida de profundidad D desde cero. Las matemáticas son ordenadas. La implementación tiene 150 líneas.
El Concepto
La receta del MTP secuencial
DeepSeek-V3 añade D módulos MTP sobre el modelo principal. Cada módulo k (para k = 1..D) predice el token en la profundidad k, es decir, t_{i+k} dado un prefijo hasta la posición i.
El módulo k consta de:
- Un bloque transformer
T_kcon su propia atención y MLP. - Una matriz de proyección
M_kque combina el estado oculto de la profundidad anterior con el embedding del token real (ground-truth) de la siguiente profundidad. - El embedding compartido
E(el mismo del modelo principal). - La cabeza de salida compartilhada
Out(la misma del modelo principal).
En el entrenamiento, para un prefijo hasta la posición i, el estado oculto por profundidad es:
h_i^(0) = main model backbone at position i
h_i^(k) = T_k( M_k * concat(RMSNorm(h_i^(k-1)), RMSNorm(E(t_{i+k}))) ) for k >= 1
La predicción por profundidad es:
logits_{i+k} = Out(h_i^(k-1)) for k = 1..D
La pérdida por profundidad es la entropía cruzada contra el token real t_{i+k}:
L_k = CE(logits_{i+k}, t_{i+k})
La pérdida conjunta a través de las profundidades:
L_MTP = (lambda / D) * sum_{k=1..D} L_k
lambda es un pequeño factor de ponderación: DeepSeek-V3 usa 0.3 para el primer 10% del entrenamiento y 0.1 después. La pérdida total de entrenamiento es L_main + L_MTP.
Por qué secuencial, no paralelo
El MTP paralelo original de Gloeckle tenía D cabezas de salida, cada una aplicada directamente a h_i^(0). Cada cabeza predice t_{i+k} a partir del mismo estado oculto del backbone. Eso se entrena bien, pero las predicciones no están condicionadas entre sí. No se puede usar la salida de head_1 para ayudar a head_2: las cabezas se ejecutan en paralelo.
El diseño secuencial de DeepSeek-V3 construye h_i^(k) a partir de h_i^(k-1) más el embedding del siguiente token real E(t_{i+k}). Eso preserva la cadena causal: para predecir t_{i+k+1}, el módulo en la profundidad k+1 ve lo que había en t_{i+k}. Esto es estructuralmente idéntico a cómo un decodificador autorregresivo consume su propia salida, lo que hace que los módulos MTP sean directamente utilizables como borradores de decodificación especulativa.
En la inferencia: se alimenta h_i^(k-1) y el token borrador t_{i+k} en el módulo k+1, obteniendo una predicción para t_{i+k+1}. Se repite. Esto es exactamente un borrador de estilo EAGLE, utilizando el módulo MTP entrenado como la red de borrador (draft network). DeepSeek-V3 reporta una tasa de aceptación superior al 80% en el primer módulo MTP y una aceleración de aproximadamente 1.8×.
Contabilidad de parámetros
Para un modelo con tamaño oculto h y vocabulario V:
- Modelo principal: miles de millones de parámetros, más una cabeza de salida de tamaño
V * h. - Cabeza de salida compartida: reutiliza la cabeza del modelo principal. Sin parámetros adicionales.
- Embedding compartido: reutiliza el embedding del modelo principal. Sin parámetros adicionales.
- Por módulo MTP:
- Proyección
M_k:(2h) * h = 2h^2. - Bloco transformer
T_k: atención (4h^2para MHA) más MLP (típicamente8h^2para SwiGLU con relación 8/3). Alrededor de12h^2por bloque.
- Proyección
Extra total por módulo: ~14h^2. Para el h = 7168 de DeepSeek-V3, D = 1 módulo: ~14 * 7168^2 = ~720M parámetros en papel. DeepSeek-V3 reporta 14B; la diferencia es principalmente que las capas de expertos también son MoE en el módulo MTP.
El beneficio de la decodificación especulativa
Durante el pre-entrenamiento, los módulos MTP ralentizan el entrenamiento en aproximadamente un 10% (más cómputo forward, pérdida adicional). El beneficio es doble:
Señal de entrenamiento más densa. Cada estado oculto ve D+1 objetivos de supervisión. Efecto medido en MMLU, GSM8K, MATH y HumanEval: mejoras consistentes de unos pocos puntos porcentuales en las ablaciones de DeepSeek-V3.
Borrador gratuito de decodificación especulativa en la inferencia. El módulo MTP ya está entrenado para predecir los siguientes tokens. Reutilizado como una red de borrador, ofrece tasas de aceptación superiores al 80%. A ese nivel, la decodificación especulativa con N=3 o N=5 proporciona un rendimiento
1.8×mayor. El costo del 10% en el tiempo de entrenamiento se recupera la primera vez que se ejecuta la inferencia.
Relación con EAGLE
EAGLE entrena un pequeño modelo de borrador POR SEPARADO después del pre-entrenamiento. MTP integra el borrador directamente en el pre-entrenamiento. Ambos enfoques convergen en tasas de aceptación similares pero a través de flujos de trabajo diferentes:
| Dimensión | EAGLE-3 | MTP (DeepSeek-V3) |
|---|---|---|
| Cuándo se entrena | Post-pre-entrenamiento | Durante el pre-entrenamiento |
| Retrocompatible con pesos existentes | Sí | No (requiere re-entrenamiento) |
| Parâmetros del borrador | 1-2 capas transformer | 1 bloque transformer + proyección |
| Tasa de aceptación | 0.88-0.92 | 0.80+ en profundidad 1 |
| Beneficio más allá de la velocidad | Solo decodificación especulativa | Señal de entrenamiento más densa + velocidad |
Constrúyelo
code/main.py construye un módulo MTP de extremo a extremo: embedding compartido, proyección, bloque transformer, cabeza de salida compartida. Luego computa la pérdida de entropía cruzada por profundidad en una secuencia sintética corta e imprime el conteo de parámetros por componente. Un vocabulario de juguete de 32 tokens mantiene los números legibles.
Paso 1: tabla de embedding compartido
Se utiliza una única tabla vocab_size x hidden tanto en el modelo principal COMO en cada módulo MTP en todas las profundidades. No es una segunda copia: es literalmente el mismo tensor.
Paso 2: la combinación por profundidad
def combine(prev_hidden, next_token_embed, M_k):
# concat along feature dim, then project down to hidden
concat = rms_norm(prev_hidden) + rms_norm(next_token_embed) # vector addition stand-in
projected = matvec(M_k, concat)
return projected
El DeepSeek-V3 real concatena los dos vectores normalizados por RMSNorm a [2h] y proyecta con una matriz h x 2h. El de juguete utiliza la suma de vectores por brevedad en la biblioteca estándar.
Paso 3: el bloque transformer en la profundidad k
Atención propia más MLP. En el de juguete, un bloque de atención lineal de una sola capa y un MLP SwiGLU mantienen la estructura visible sin numpy.
Paso 4: la cabeza de salida compartida
Reutiliza la proyección de salida del modelo principal. Logits sobre el vocabulario.
Paso 5: pérdida por profundidad
Entropía cruzada de softmax(logits) contra el token real en el desplazamiento k. Se agrega a través de las profundidades con el factor de escala lambda / D.
Paso 6: contabilidad de parámetros
Imprime el conteo total de parámetros, el conteo compartido (embedding, cabeza) y el conteo adicional por módulo. Muestra la relación entre los parámetros adicionales de MTP y el tamaño del modelo principal.
Úsalo
MTP está integrado en DeepSeek-V3 (diciembre de 2024) y en la serie DeepSeek-R1. En la inferencia:
- El propio stack de servicio de DeepSeek consume los módulos MTP como decodificadores especulativos de forma nativa.
- vLLM y SGLang tienen rutas de integración para MTP de DeepSeek-V3 a partir de abril de 2026.
- El tutorial de AMD ROCm SGLang muestra una configuración específica de decodificación especulativa MTP con una aceleración medida de
1.8×en el checkpoint V3.
Cuándo usar MTP en una nueva ejecución de pre-entrenamiento:
- Controlas todo el flujo de trabajo de pre-entrenamiento y deseas acumular una señal de entrenamiento más densa.
- Sabes que servirás el modelo a gran escala y deseas decodificación especulativa gratis.
- Tu tamaño oculto es de al menos 4096. A escala de 1B, la sobrecarga perjudica más de lo que ayuda el beneficio.
Cuándo no usarlo:
- Ajuste fino (fine-tuning) de un modelo denso pre-entrenado existente. El módulo MTP no está entrenado.
- Modelos de investigación donde deseas una línea de base limpia para comparar. MTP cambia la arquitectura.
Envíalo
Esta lección produce outputs/skill-mtp-planner.md. Dada una especificación de pre-entrenamiento (tamanho de modelo, datos, cómputo), devuelve un plan para integrar MTP: número de profundidades D, cronograma de lambda, sobrecarga de memoria y conexiones de decodificación especulativa en tiempo de inferencia.
Ejercicios
Ejecuta
code/main.py. Muestra que la pérdida por profundidad disminuye de forma monótona a medida que se fortalece la señal sintética. Modifica la sintética para usar un patrón fijo y verifica que las pérdidas de profundidad-1 y profundidad-2 convergen.Calcula la sobrecarga de parámetros para un modelo denso de 70B (oculto 8192, 80 capas) con D=1 módulo MTP. Compara con la sobrecarga de 14B reportada por DeepSeek-V3. Explica por qué el número de DeepSeek es mayor: el bloque transformer de MTP herda la misma estructura MoE, inflando el conteo de parámetros por módulo.
Implementa D=2 en el modelo de juguete: añade un segundo módulo MTP que tome h^(1) y prediga
t_{i+2}. Verifica que la pérdida conjunta y la contabilidad de parámetros coincidan con las ecuaciones 19-21 del artículo de DeepSeek.Cambia el juguete a MTP paralelo (estilo Gloeckle): añade D cabezas de salida sobre el estado oculto principal, cada una prediciendo un desplazamiento diferente. Mide cómo se comparan las pérdidas por profundidad con la versión secuencial en la misma señal sintética. La versión secuencial debería producir una pérdida menor en la profundidad-k para k > 1 porque se condiciona en las predicciones intermedias.
Usa el módulo MTP entrenado como un borrador de estilo EAGLE: llama al módulo k para proponer
t_{i+k}en la inferencia. Mide la tasa de aceptación de estos tokens borradores contra las predicciones del modelo principal en una secuencia de prueba reservada. Si logras más del 50% en el juguete, habrás reproducido la propiedad empírica de MTP como borrador.
Términos Clave
| Término | Lo que la gente dice | Lo que realmente significa |
|---|---|---|
| Módulo MTP | "Bloque de pérdida extra" | Un pequeño bloque transformer más proyección que predice un token k posiciones por delante del modelo principal |
| Profundidad de predicción | "Qué desplazamiento" | El entero k tal que el módulo k predice t_{i+k} a partir del prefijo hasta la posición i |
| MTP paralelo | "Estilo Gloeckle" | D cabezas independientes sobre el mismo estado oculto del backbone, sin cadena condicional |
| MTP secuencial | "Estilo DeepSeek-V3" | Cada módulo se condiciona en el estado oculto de la profundidad anterior más el embedding del siguiente token; preserva la cadena causal |
| Cabeza de salida compartida | "Reutilizar la cabeza principal" | Los módulos MTP llaman a la cabeza LM del modelo principal, no a una proyección de salida separada |
| Embedding compartido | "Reutilizar la tabla principal" | La misma tabla de embeddings del vocabulario se usa en todas partes; sin parámetros duplicados |
| Matriz de proyección M_k | "Combinar oculto + siguiente token" | Una capa de proyección lineal h x 2h que fusiona el estado oculto anterior y el embedding del token objetivo en la entrada de la siguiente profundidad |
| Pérdida conjunta L_MTP | "Promedio de pérdidas extras" | Media aritmética de las pérdidas de entropía cruzada por profundidad, escalada por lambda |
| Tasa de aceptación en profundidad 1 | "Frecuencia de acierto del borrador MTP" | La tasa en la que la predicción top-1 del módulo MTP con D=1 es igual a la predicción top-1 del modelo principal; 80%+ en DeepSeek-V3 |
| Ponderación Lambda | "Importancia de la pérdida extra" | Factor de escala por profundidad; 0.3 al inicio del entrenamiento, 0.1 después en DeepSeek-V3 |
Lecturas Adicionales
- DeepSeek-AI — DeepSeek-V3 Technical Report (arXiv:2412.19437) — la descripción completa de MTP secuencial (Sección 2.2), incluyendo las ecuaciones de pérdida conjunta y la aceleración de 1.8× en la inferencia
- Gloeckle et al. — Better & Faster Large Language Models via Multi-token Prediction (arXiv:2404.19737) — la línea de base de MTP paralelo sobre la cual mejora el diseño de DeepSeek
- DeepSeek-V3 model card en Hugging Face — 685B total (671B principal + 14B MTP), notas de despliegue
- Leviathan et al. — Fast Inference from Transformers via Speculative Decoding (arXiv:2211.17192) — el marco de decodificación especulativa en el que encaja MTP
- Li et al. — EAGLE-3 (arXiv:2503.01840) — la arquitectura de borrador de 2025 de EAGLE, la contraparte con la que compite MTP