Phase 03 - Lesson 12
Introducción a JAX
PyTorch muta tensores. TensorFlow construye grafos. JAX compila funciones puras. Ese último punto cambia la forma en que piensas sobre deep learning.
Tipo: Build Lenguajes: Python Prerrequisitos: Fase 03 Lecciones 01-10, NumPy básico Tiempo: ~90 minutos
Objetivos de Aprendizaje
- Escribir código de red neuronal con funciones puras usando la API funcional de JAX (jax.numpy, jax.grad, jax.jit, jax.vmap)
- Explicar la diferencia fundamental de diseño entre la mutación eager de PyTorch y el modelo de compilación funcional de JAX
- Aplicar compilación jit y vectorización vmap para acelerar bucles de entrenamiento en comparación con Python ingenuo
- Entrenar una red simple en JAX y contrastar la gestión explícita de estado con el enfoque orientado a objetos de PyTorch
El Problema
Sabes construir redes neuronales en PyTorch. Defines un nn.Module, llamas .backward(), das un paso en el optimizador. Funciona. Millones de personas lo usan.
Pero PyTorch tiene una restricción incrustada en su ADN: traza operaciones de forma eager, una a la vez, en Python. Cada tensor + tensor es un lanzamiento de kernel separado. Cada paso de entrenamiento reinterpreta el mismo código Python. Esto funciona bien hasta que necesitas entrenar un modelo de 540 mil millones de parámetros en 2.048 TPUs. Entonces el overhead te mata.
Google DeepMind entrena Gemini en JAX. Anthropic entrenó Claude en JAX. Estas no son operaciones pequeñas -- son los entrenamientos de redes neuronales más grandes de la Tierra. Eligieron JAX porque trata tu bucle de entrenamiento como un programa compilable, no como una secuencia de llamadas de Python.
JAX es NumPy con tres superpoderes: diferenciación automática, compilación JIT a XLA y vectorización automática. Escribes una función que procesa un ejemplo. JAX te da una función que procesa un batch, calcula gradientes, compila a código de máquina y corre en múltiples dispositivos. Todo sin cambiar la función original.
El Concepto
La Filosofía de JAX
JAX es un framework funcional. Sin clases, sin estado mutable, sin método .backward(). En su lugar:
| PyTorch | JAX |
|---|---|
Clase nn.Module con estado |
Función pura: f(params, x) -> y |
loss.backward() |
jax.grad(loss_fn)(params, x, y) |
| Ejecución eager | Compilación JIT vía XLA |
Bucle manual for x in batch: |
jax.vmap(f) auto-vectorización |
DataParallel / FSDP |
jax.pmap(f) auto-paralelismo |
model.parameters() mutable |
Pytree inmutable de arrays |
Esto no es una preferencia de estilo. Es una restricción del compilador. La compilación JIT requiere funciones puras -- las mismas entradas siempre producen las mismas salidas, sin efectos secundarios. Esa restricción es lo que hace posibles los speedups de 100x.
jax.numpy: La Superficie Familiar
JAX reimplementa la API de NumPy en aceleradores:
import jax.numpy as jnp
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
c = jnp.dot(a, b)
Mismos nombres de funciones. Mismas reglas de broadcasting. Misma semántica de slicing. Pero los arrays viven en la GPU/TPU, y cada operación es rastreable por el compilador.
Una diferencia crítica: los arrays de JAX son inmutables. Nada de a[0] = 5. En su lugar: a = a.at[0].set(5). Esto se siente incómodo por una semana, luego encaja -- la inmutabilidad es lo que hace que transformaciones como grad, jit y vmap sean componibles.
jax.grad: Autodiff Funcional
PyTorch adjunta gradientes a los tensores (.grad). JAX adjunta gradientes a las funciones.
import jax
def f(x):
return x ** 2
df = jax.grad(f)
df(3.0)
jax.grad recibe una función y devuelve una nueva función que calcula el gradiente. Sin llamada .backward(). Sin grafo de computación almacenado en los tensores. El gradiente es solo otra función que puedes llamar, componer o compilar con JIT.
Esto se compone arbitrariamente:
d2f = jax.grad(jax.grad(f))
d2f(3.0)
Segundas derivadas. Terceras derivadas. Jacobianos. Hessianos. Todo componiendo grad. PyTorch también puede hacer esto (torch.autograd.functional.hessian), pero está añadido por encima. En JAX, es la base.
La restricción: grad solo funciona en funciones puras. Sin sentencias print dentro (corren durante el tracing, no en la ejecución). Sin mutación de estado externo. Sin generación de números aleatorios sin gestión explícita de keys.
jit: Compilar a XLA
@jax.jit
def train_step(params, x, y):
loss = loss_fn(params, x, y)
return loss
fast_step = jax.jit(train_step)
En la primera llamada, JAX traza la función -- registra qué operaciones ocurren, sin ejecutarlas. Luego entrega ese trace a XLA (Accelerated Linear Algebra), el compilador de Google para TPUs y GPUs. XLA fusiona operaciones, elimina copias de memoria redundantes y genera código de máquina optimizado.
Las llamadas siguientes saltan Python por completo. El código compilado corre en el acelerador a velocidad de C++.
Cuándo ayuda JIT:
- Pasos de entrenamiento (la misma computación repetida miles de veces)
- Inferencia (mismo modelo, entradas diferentes)
- Cualquier función llamada más de una vez con entradas de forma similar
Cuándo perjudica JIT:
- Funciones con control de flujo en Python que depende de valores (
if x > 0donde x es un array rastreado) - Computaciones de una sola vez (el overhead de compilación supera el tiempo de ejecución)
- Debugging (el tracing oculta la ejecución real)
La restricción de control de flujo es real. jax.lax.cond reemplaza if/else. jax.lax.scan reemplaza los bucles for. No son opcionales -- son el precio de la compilación.
vmap: Vectorización Automática
Escribes una función que procesa un ejemplo:
def predict(params, x):
return jnp.dot(params['w'], x) + params['b']
vmap la eleva para procesar un batch:
batch_predict = jax.vmap(predict, in_axes=(None, 0))
in_axes=(None, 0) significa: no hacer batch sobre params (compartido), hacer batch sobre el eje 0 de x. Sin bucle for manual. Sin reshaping. Sin tener que pasar la dimensión de batch manualmente. JAX descubre la dimensión de batch y vectoriza toda la computación.
Esto no es azúcar sintáctico. vmap genera código vectorizado fusionado que corre de 10 a 100x más rápido que un bucle de Python. Y se compone con jit y grad:
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))
Gradientes por ejemplo. Una línea. Esto es casi imposible en PyTorch sin trucos.
pmap: Paralelismo de Datos Entre Dispositivos
parallel_step = jax.pmap(train_step, axis_name='devices')
pmap replica la función en todos los dispositivos disponibles (GPUs/TPUs) y divide el batch. Dentro de la función, jax.lax.pmean y jax.lax.psum sincronizan los gradientes entre los dispositivos.
Google entrena Gemini en miles de chips TPU v5e usando pmap (y su sucesor shard_map). El modelo de programación: escribe la versión de un solo dispositivo, envuélvela con pmap, listo.
Pytrees: La Estructura de Datos Universal
JAX opera sobre "pytrees" -- combinaciones anidadas de listas, tuplas, dicts y arrays. Los parámetros de tu modelo son un pytree:
params = {
'layer1': {'w': jnp.zeros((784, 256)), 'b': jnp.zeros(256)},
'layer2': {'w': jnp.zeros((256, 128)), 'b': jnp.zeros(128)},
'layer3': {'w': jnp.zeros((128, 10)), 'b': jnp.zeros(10)},
}
Cada transformación de JAX -- grad, jit, vmap -- sabe cómo recorrer pytrees. jax.tree.map(f, tree) aplica f a cada hoja. Así es como los optimizadores actualizan todos los parámetros a la vez:
params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
Sin método .parameters(). Sin registro de parámetros. La estructura del árbol es el modelo.
Funcional vs Orientado a Objetos
PyTorch almacena estado dentro de objetos:
class Model(nn.Module):
def __init__(self):
self.linear = nn.Linear(784, 10)
def forward(self, x):
return self.linear(x)
JAX usa funciones puras con estado explícito:
def predict(params, x):
return jnp.dot(x, params['w']) + params['b']
Los params se pasan como argumento. Nada se almacena. Nada se muta. Esto hace que cada función sea testeable, componible y compilable. También significa que gestionas los params tú mismo -- o usas una biblioteca como Flax o Equinox.
El Ecosistema de JAX
JAX te da primitivas. Las bibliotecas te dan ergonomía:
| Biblioteca | Rol | Estilo |
|---|---|---|
| Flax (Google) | Capas de red neuronal | nn.Module con estado explícito |
| Equinox (Patrick Kidger) | Capas de red neuronal | Basado en pytree, Pythónico |
| Optax (DeepMind) | Optimizadores + schedules de LR | Transforms de gradiente componibles |
| Orbax (Google) | Checkpointing | Guardar/restaurar pytrees |
| CLU (Google) | Métricas + logging | Utilidades para bucle de entrenamiento |
Optax es la biblioteca estándar de optimizadores. Separa la transformación del gradiente (Adam, SGD, clipping) de la actualización de los parámetros, haciendo trivial componer:
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(learning_rate=1e-3),
)
Cuándo Usar JAX vs PyTorch
| Factor | JAX | PyTorch |
|---|---|---|
| Soporte de TPU | First-class (Google construyó ambos) | Mantenido por la comunidad (torch_xla) |
| Soporte de GPU | Bueno (CUDA vía XLA) | El mejor de su clase (CUDA nativo) |
| Debugging | Difícil (tracing + compilación) | Fácil (eager, línea por línea) |
| Ecosistema | Enfocado en investigación (Flax, Equinox) | Masivo (HuggingFace, torchvision, etc.) |
| Contratación | Nicho (Google/DeepMind/Anthropic) | Mainstream (en todas partes) |
| Entrenamiento a gran escala | Superior (XLA, pmap, mesh) | Bueno (FSDP, DeepSpeed) |
| Velocidad de prototipado | Más lenta (overhead funcional) | Más rápida (muta y avanza) |
| Inferencia en producción | TensorFlow Serving, Vertex AI | TorchServe, Triton, ONNX |
| Quién lo usa | DeepMind (Gemini), Anthropic (Claude) | Meta (Llama), OpenAI (GPT), Stability AI |
La respuesta honesta: usa PyTorch a menos que tengas una razón específica para usar JAX. Esas razones son -- acceso a TPU, necesidad de gradientes por ejemplo, entrenamiento multi-dispositivo a escala masiva, o trabajar en Google/DeepMind/Anthropic.
Números Aleatorios en JAX
JAX no tiene estado aleatorio global. Cada operación aleatoria requiere una key PRNG explícita:
key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)
w = jax.random.normal(key1, shape=(784, 256))
Esto es molesto al principio. Pero garantiza reproducibilidad entre dispositivos y compilaciones -- una propiedad que el torch.manual_seed de PyTorch no puede garantizar en escenarios multi-GPU.
Constrúyelo
Paso 1: Setup y Datos
Vamos a entrenar un MLP de 3 capas en MNIST usando JAX y Optax. 784 entradas, dos capas ocultas de 256 y 128 neuronas, 10 clases de salida.
import jax
import jax.numpy as jnp
from jax import random
import optax
def get_mnist_data():
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
X = mnist.data.astype('float32') / 255.0
y = mnist.target.astype('int')
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]
return X_train, y_train, X_test, y_test
Paso 2: Inicializar Parámetros
Sin clase. Solo una función que devuelve un pytree:
def init_params(key):
k1, k2, k3 = random.split(key, 3)
scale1 = jnp.sqrt(2.0 / 784)
scale2 = jnp.sqrt(2.0 / 256)
scale3 = jnp.sqrt(2.0 / 128)
params = {
'layer1': {
'w': scale1 * random.normal(k1, (784, 256)),
'b': jnp.zeros(256),
},
'layer2': {
'w': scale2 * random.normal(k2, (256, 128)),
'b': jnp.zeros(128),
},
'layer3': {
'w': scale3 * random.normal(k3, (128, 10)),
'b': jnp.zeros(10),
},
}
return params
Inicialización de He, hecha manualmente. Tres keys PRNG derivadas de una seed. Cada peso es un array inmutable dentro de un dict anidado.
Paso 3: Forward Pass
def forward(params, x):
x = jnp.dot(x, params['layer1']['w']) + params['layer1']['b']
x = jax.nn.relu(x)
x = jnp.dot(x, params['layer2']['w']) + params['layer2']['b']
x = jax.nn.relu(x)
x = jnp.dot(x, params['layer3']['w']) + params['layer3']['b']
return x
def loss_fn(params, x, y):
logits = forward(params, x)
one_hot = jax.nn.one_hot(y, 10)
return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1))
Funciones puras. Params entran, predicción sale. Sin self, sin estado almacenado. loss_fn calcula la cross-entropy desde cero -- softmax, log, media negativa.
Paso 4: Paso de Entrenamiento Compilado con JIT
@jax.jit
def train_step(params, opt_state, x, y):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state, loss
@jax.jit
def accuracy(params, x, y):
logits = forward(params, x)
preds = jnp.argmax(logits, axis=-1)
return jnp.mean(preds == y)
jax.value_and_grad devuelve tanto el valor de la loss como los gradientes en una sola pasada. El decorador @jax.jit compila ambas funciones a XLA. Después de la primera llamada, cada paso de entrenamiento corre sin tocar Python.
Paso 5: Bucle de Entrenamiento
optimizer = optax.adam(learning_rate=1e-3)
X_train, y_train, X_test, y_test = get_mnist_data()
X_train, X_test = jnp.array(X_train), jnp.array(X_test)
y_train, y_test = jnp.array(y_train), jnp.array(y_test)
key = random.PRNGKey(0)
params = init_params(key)
opt_state = optimizer.init(params)
batch_size = 128
n_epochs = 10
for epoch in range(n_epochs):
key, subkey = random.split(key)
perm = random.permutation(subkey, len(X_train))
X_shuffled = X_train[perm]
y_shuffled = y_train[perm]
epoch_loss = 0.0
n_batches = len(X_train) // batch_size
for i in range(n_batches):
start = i * batch_size
xb = X_shuffled[start:start + batch_size]
yb = y_shuffled[start:start + batch_size]
params, opt_state, loss = train_step(params, opt_state, xb, yb)
epoch_loss += loss
train_acc = accuracy(params, X_train[:5000], y_train[:5000])
test_acc = accuracy(params, X_test, y_test)
print(f"Epoch {epoch + 1:2d} | Loss: {epoch_loss / n_batches:.4f} | "
f"Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")
10 épocas. ~97% de precisión en el test. La primera época es lenta (compilación JIT). Las épocas 2-10 son rápidas.
Fíjate en lo que falta: sin .zero_grad(), sin .backward(), sin .step(). Toda la actualización es una sola llamada de función compuesta. Los gradientes se calculan, se transforman con Adam y se aplican a los parámetros -- todo dentro de train_step.
Úsalo
Flax: El Estándar de Google
Flax es la biblioteca de red neuronal JAX más común. Reintroduce el nn.Module, pero con gestión de estado explícita:
import flax.linen as nn
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(128)(x)
x = nn.relu(x)
x = nn.Dense(10)(x)
return x
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
logits = model.apply(params, x_batch)
Misma estructura que PyTorch, pero params está separado del modelo. model.init() crea los params. model.apply(params, x) corre el forward pass. El objeto modelo no tiene estado.
Equinox: La Alternativa Pythónica
Equinox (de Patrick Kidger) representa modelos como pytrees:
import equinox as eqx
model = eqx.nn.MLP(
in_size=784, out_size=10, width_size=256, depth=2,
activation=jax.nn.relu, key=jax.random.PRNGKey(0)
)
logits = model(x)
El propio modelo es un pytree. No necesita .apply(). Los parámetros son simplemente las hojas del modelo. Esto es más cercano a cómo piensa JAX.
Optax: Optimizadores Componibles
Optax desacopla la transformación del gradiente de la actualización:
schedule = optax.warmup_cosine_decay_schedule(
init_value=0.0, peak_value=1e-3,
warmup_steps=1000, decay_steps=50000
)
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adamw(learning_rate=schedule, weight_decay=0.01),
)
Gradient clipping, warmup de learning rate, weight decay -- todo compuesto como una cadena de transforms. Cada transform ve los gradientes, los modifica y los pasa al siguiente. Sin clase de optimizador monolítica.
Llévalo a Producción
Instalación:
pip install jax jaxlib optax flax
Para soporte de GPU:
pip install jax[cuda12]
Para TPU (Google Cloud):
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Trampas de rendimiento:
- La primera llamada JIT es lenta (compilación). Haz un warm up antes de hacer benchmark.
- Evita bucles de Python sobre arrays JAX dentro del JIT. Usa
jax.lax.scanojax.lax.fori_loop. jax.debug.print()funciona dentro del JIT. Elprint()normal no.- Haz profiling con
jax.profilero TensorBoard. La compilación XLA puede ocultar cuellos de botella. - JAX preasigna el 75% de la memoria de la GPU por defecto. Define
XLA_PYTHON_CLIENT_PREALLOCATE=falsepara deshabilitarlo.
Checkpointing:
import orbax.checkpoint as ocp
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save('/tmp/model', params)
restored = checkpointer.restore('/tmp/model')
Esta lección produce:
outputs/prompt-jax-optimizer.md-- un prompt para elegir la configuración correcta de optimizador JAXoutputs/skill-jax-patterns.md-- una skill que cubre patrones funcionales en JAX
Ejercicios
Agrega dropout al MLP. En JAX, el dropout requiere una key PRNG -- pasa una key a través del forward pass y divídela para cada capa de dropout. Compara la precisión en el test con y sin.
Usa
jax.vmappara calcular gradientes por ejemplo para un batch de 32 imágenes de MNIST. Calcula la norma del gradiente para cada ejemplo. ¿Qué ejemplos tienen los gradientes más grandes, y por qué?Reemplaza la función forward manual por una
mlp_forward(params, x)genérica que funcione para cualquier número de capas. Usajax.tree.leavespara determinar la profundidad automáticamente.Haz benchmark del paso de entrenamiento con y sin
@jax.jit. Cronometra 100 pasos de cada uno. ¿Qué tan grande es el speedup en tu hardware? ¿Cuál es el overhead de compilación en la primera llamada?Implementa gradient clipping componiendo
optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)). Entrena con y sin clipping. Grafica la norma del gradiente a lo largo del entrenamiento para ver el efecto.
Términos Clave
| Término | Lo que la gente dice | Lo que realmente significa |
|---|---|---|
| XLA | "La cosa que hace rápido a JAX" | Accelerated Linear Algebra -- un compilador que fusiona operaciones y genera kernels de GPU/TPU optimizados a partir de un grafo de computación |
| JIT | "Compilación just-in-time" | JAX traza la función en la primera llamada, compila a XLA y luego corre la versión compilada en las llamadas siguientes |
| Función pura | "Sin efectos secundarios" | Una función cuya salida depende solo de las entradas -- sin estado global, sin mutación, sin aleatoriedad sin keys explícitas |
| vmap | "Auto-batching" | Transforma una función que procesa un ejemplo en una que procesa un batch, sin reescribir |
| pmap | "Auto-paralelismo" | Replica una función en múltiples dispositivos y divide el batch de entrada |
| Pytree | "Dict anidado de arrays" | Cualquier estructura anidada de listas, tuplas, dicts y arrays que JAX pueda recorrer y transformar |
| Tracing | "Registrar la computación" | JAX ejecuta la función con valores abstractos para construir un grafo de computación, sin calcular resultados reales |
| Autodiff funcional | "grad de una función" | Calcular derivadas transformando funciones, no adjuntando almacenamiento de gradiente a los tensores |
| Optax | "La biblioteca de optimizadores de JAX" | Una biblioteca componible de transformaciones de gradiente -- Adam, SGD, clipping, scheduling -- que se encadenan |
| Flax | "El nn.Module de JAX" | La biblioteca de red neuronal de Google para JAX, que agrega abstracciones de capas manteniendo el estado explícito |
Lecturas Adicionales
- Documentación de JAX: https://jax.readthedocs.io/ -- los docs oficiales, con excelentes tutoriales sobre grad, jit y vmap
- "JAX: composable transformations of Python+NumPy programs" (Bradbury et al., 2018) -- el paper original que explica la filosofía de diseño
- Documentación de Flax: https://flax.readthedocs.io/ -- la biblioteca de red neuronal de Google para JAX
- Patrick Kidger, "Equinox: neural networks in JAX via callable PyTrees and filtered transformations" (2021) -- la alternativa Pythónica a Flax
- DeepMind, "Optax: composable gradient transformation and optimisation" -- la biblioteca estándar de optimizadores
- "You Don't Know JAX" (Colin Raffel, 2020) -- una guía práctica sobre trampas y patrones de JAX, de uno de los autores de T5