Phase 03 - Lesson 12

Introducción a JAX

PyTorch muta tensores. TensorFlow construye grafos. JAX compila funciones puras. Ese último punto cambia la forma en que piensas sobre deep learning.

Tipo: Build Lenguajes: Python Prerrequisitos: Fase 03 Lecciones 01-10, NumPy básico Tiempo: ~90 minutos

Objetivos de Aprendizaje

  • Escribir código de red neuronal con funciones puras usando la API funcional de JAX (jax.numpy, jax.grad, jax.jit, jax.vmap)
  • Explicar la diferencia fundamental de diseño entre la mutación eager de PyTorch y el modelo de compilación funcional de JAX
  • Aplicar compilación jit y vectorización vmap para acelerar bucles de entrenamiento en comparación con Python ingenuo
  • Entrenar una red simple en JAX y contrastar la gestión explícita de estado con el enfoque orientado a objetos de PyTorch

El Problema

Sabes construir redes neuronales en PyTorch. Defines un nn.Module, llamas .backward(), das un paso en el optimizador. Funciona. Millones de personas lo usan.

Pero PyTorch tiene una restricción incrustada en su ADN: traza operaciones de forma eager, una a la vez, en Python. Cada tensor + tensor es un lanzamiento de kernel separado. Cada paso de entrenamiento reinterpreta el mismo código Python. Esto funciona bien hasta que necesitas entrenar un modelo de 540 mil millones de parámetros en 2.048 TPUs. Entonces el overhead te mata.

Google DeepMind entrena Gemini en JAX. Anthropic entrenó Claude en JAX. Estas no son operaciones pequeñas -- son los entrenamientos de redes neuronales más grandes de la Tierra. Eligieron JAX porque trata tu bucle de entrenamiento como un programa compilable, no como una secuencia de llamadas de Python.

JAX es NumPy con tres superpoderes: diferenciación automática, compilación JIT a XLA y vectorización automática. Escribes una función que procesa un ejemplo. JAX te da una función que procesa un batch, calcula gradientes, compila a código de máquina y corre en múltiples dispositivos. Todo sin cambiar la función original.

El Concepto

La Filosofía de JAX

JAX es un framework funcional. Sin clases, sin estado mutable, sin método .backward(). En su lugar:

PyTorch JAX
Clase nn.Module con estado Función pura: f(params, x) -> y
loss.backward() jax.grad(loss_fn)(params, x, y)
Ejecución eager Compilación JIT vía XLA
Bucle manual for x in batch: jax.vmap(f) auto-vectorización
DataParallel / FSDP jax.pmap(f) auto-paralelismo
model.parameters() mutable Pytree inmutable de arrays

Esto no es una preferencia de estilo. Es una restricción del compilador. La compilación JIT requiere funciones puras -- las mismas entradas siempre producen las mismas salidas, sin efectos secundarios. Esa restricción es lo que hace posibles los speedups de 100x.

jax.numpy: La Superficie Familiar

JAX reimplementa la API de NumPy en aceleradores:

import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
c = jnp.dot(a, b)

Mismos nombres de funciones. Mismas reglas de broadcasting. Misma semántica de slicing. Pero los arrays viven en la GPU/TPU, y cada operación es rastreable por el compilador.

Una diferencia crítica: los arrays de JAX son inmutables. Nada de a[0] = 5. En su lugar: a = a.at[0].set(5). Esto se siente incómodo por una semana, luego encaja -- la inmutabilidad es lo que hace que transformaciones como grad, jit y vmap sean componibles.

jax.grad: Autodiff Funcional

PyTorch adjunta gradientes a los tensores (.grad). JAX adjunta gradientes a las funciones.

import jax

def f(x):
    return x ** 2

df = jax.grad(f)
df(3.0)

jax.grad recibe una función y devuelve una nueva función que calcula el gradiente. Sin llamada .backward(). Sin grafo de computación almacenado en los tensores. El gradiente es solo otra función que puedes llamar, componer o compilar con JIT.

Esto se compone arbitrariamente:

d2f = jax.grad(jax.grad(f))
d2f(3.0)

Segundas derivadas. Terceras derivadas. Jacobianos. Hessianos. Todo componiendo grad. PyTorch también puede hacer esto (torch.autograd.functional.hessian), pero está añadido por encima. En JAX, es la base.

La restricción: grad solo funciona en funciones puras. Sin sentencias print dentro (corren durante el tracing, no en la ejecución). Sin mutación de estado externo. Sin generación de números aleatorios sin gestión explícita de keys.

jit: Compilar a XLA

@jax.jit
def train_step(params, x, y):
    loss = loss_fn(params, x, y)
    return loss

fast_step = jax.jit(train_step)

En la primera llamada, JAX traza la función -- registra qué operaciones ocurren, sin ejecutarlas. Luego entrega ese trace a XLA (Accelerated Linear Algebra), el compilador de Google para TPUs y GPUs. XLA fusiona operaciones, elimina copias de memoria redundantes y genera código de máquina optimizado.

Las llamadas siguientes saltan Python por completo. El código compilado corre en el acelerador a velocidad de C++.

Cuándo ayuda JIT:

  • Pasos de entrenamiento (la misma computación repetida miles de veces)
  • Inferencia (mismo modelo, entradas diferentes)
  • Cualquier función llamada más de una vez con entradas de forma similar

Cuándo perjudica JIT:

  • Funciones con control de flujo en Python que depende de valores (if x > 0 donde x es un array rastreado)
  • Computaciones de una sola vez (el overhead de compilación supera el tiempo de ejecución)
  • Debugging (el tracing oculta la ejecución real)

La restricción de control de flujo es real. jax.lax.cond reemplaza if/else. jax.lax.scan reemplaza los bucles for. No son opcionales -- son el precio de la compilación.

vmap: Vectorización Automática

Escribes una función que procesa un ejemplo:

def predict(params, x):
    return jnp.dot(params['w'], x) + params['b']

vmap la eleva para procesar un batch:

batch_predict = jax.vmap(predict, in_axes=(None, 0))

in_axes=(None, 0) significa: no hacer batch sobre params (compartido), hacer batch sobre el eje 0 de x. Sin bucle for manual. Sin reshaping. Sin tener que pasar la dimensión de batch manualmente. JAX descubre la dimensión de batch y vectoriza toda la computación.

Esto no es azúcar sintáctico. vmap genera código vectorizado fusionado que corre de 10 a 100x más rápido que un bucle de Python. Y se compone con jit y grad:

per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))

Gradientes por ejemplo. Una línea. Esto es casi imposible en PyTorch sin trucos.

pmap: Paralelismo de Datos Entre Dispositivos

parallel_step = jax.pmap(train_step, axis_name='devices')

pmap replica la función en todos los dispositivos disponibles (GPUs/TPUs) y divide el batch. Dentro de la función, jax.lax.pmean y jax.lax.psum sincronizan los gradientes entre los dispositivos.

Google entrena Gemini en miles de chips TPU v5e usando pmap (y su sucesor shard_map). El modelo de programación: escribe la versión de un solo dispositivo, envuélvela con pmap, listo.

Pytrees: La Estructura de Datos Universal

JAX opera sobre "pytrees" -- combinaciones anidadas de listas, tuplas, dicts y arrays. Los parámetros de tu modelo son un pytree:

params = {
    'layer1': {'w': jnp.zeros((784, 256)), 'b': jnp.zeros(256)},
    'layer2': {'w': jnp.zeros((256, 128)), 'b': jnp.zeros(128)},
    'layer3': {'w': jnp.zeros((128, 10)),  'b': jnp.zeros(10)},
}

Cada transformación de JAX -- grad, jit, vmap -- sabe cómo recorrer pytrees. jax.tree.map(f, tree) aplica f a cada hoja. Así es como los optimizadores actualizan todos los parámetros a la vez:

params = jax.tree.map(lambda p, g: p - lr * g, params, grads)

Sin método .parameters(). Sin registro de parámetros. La estructura del árbol es el modelo.

Funcional vs Orientado a Objetos

PyTorch almacena estado dentro de objetos:

class Model(nn.Module):
    def __init__(self):
        self.linear = nn.Linear(784, 10)

    def forward(self, x):
        return self.linear(x)

JAX usa funciones puras con estado explícito:

def predict(params, x):
    return jnp.dot(x, params['w']) + params['b']

Los params se pasan como argumento. Nada se almacena. Nada se muta. Esto hace que cada función sea testeable, componible y compilable. También significa que gestionas los params tú mismo -- o usas una biblioteca como Flax o Equinox.

El Ecosistema de JAX

JAX te da primitivas. Las bibliotecas te dan ergonomía:

Biblioteca Rol Estilo
Flax (Google) Capas de red neuronal nn.Module con estado explícito
Equinox (Patrick Kidger) Capas de red neuronal Basado en pytree, Pythónico
Optax (DeepMind) Optimizadores + schedules de LR Transforms de gradiente componibles
Orbax (Google) Checkpointing Guardar/restaurar pytrees
CLU (Google) Métricas + logging Utilidades para bucle de entrenamiento

Optax es la biblioteca estándar de optimizadores. Separa la transformación del gradiente (Adam, SGD, clipping) de la actualización de los parámetros, haciendo trivial componer:

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)

Cuándo Usar JAX vs PyTorch

Factor JAX PyTorch
Soporte de TPU First-class (Google construyó ambos) Mantenido por la comunidad (torch_xla)
Soporte de GPU Bueno (CUDA vía XLA) El mejor de su clase (CUDA nativo)
Debugging Difícil (tracing + compilación) Fácil (eager, línea por línea)
Ecosistema Enfocado en investigación (Flax, Equinox) Masivo (HuggingFace, torchvision, etc.)
Contratación Nicho (Google/DeepMind/Anthropic) Mainstream (en todas partes)
Entrenamiento a gran escala Superior (XLA, pmap, mesh) Bueno (FSDP, DeepSpeed)
Velocidad de prototipado Más lenta (overhead funcional) Más rápida (muta y avanza)
Inferencia en producción TensorFlow Serving, Vertex AI TorchServe, Triton, ONNX
Quién lo usa DeepMind (Gemini), Anthropic (Claude) Meta (Llama), OpenAI (GPT), Stability AI

La respuesta honesta: usa PyTorch a menos que tengas una razón específica para usar JAX. Esas razones son -- acceso a TPU, necesidad de gradientes por ejemplo, entrenamiento multi-dispositivo a escala masiva, o trabajar en Google/DeepMind/Anthropic.

Números Aleatorios en JAX

JAX no tiene estado aleatorio global. Cada operación aleatoria requiere una key PRNG explícita:

key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)
w = jax.random.normal(key1, shape=(784, 256))

Esto es molesto al principio. Pero garantiza reproducibilidad entre dispositivos y compilaciones -- una propiedad que el torch.manual_seed de PyTorch no puede garantizar en escenarios multi-GPU.

Constrúyelo

Paso 1: Setup y Datos

Vamos a entrenar un MLP de 3 capas en MNIST usando JAX y Optax. 784 entradas, dos capas ocultas de 256 y 128 neuronas, 10 clases de salida.

import jax
import jax.numpy as jnp
from jax import random
import optax

def get_mnist_data():
    from sklearn.datasets import fetch_openml
    mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
    X = mnist.data.astype('float32') / 255.0
    y = mnist.target.astype('int')
    X_train, X_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]
    return X_train, y_train, X_test, y_test

Paso 2: Inicializar Parámetros

Sin clase. Solo una función que devuelve un pytree:

def init_params(key):
    k1, k2, k3 = random.split(key, 3)
    scale1 = jnp.sqrt(2.0 / 784)
    scale2 = jnp.sqrt(2.0 / 256)
    scale3 = jnp.sqrt(2.0 / 128)
    params = {
        'layer1': {
            'w': scale1 * random.normal(k1, (784, 256)),
            'b': jnp.zeros(256),
        },
        'layer2': {
            'w': scale2 * random.normal(k2, (256, 128)),
            'b': jnp.zeros(128),
        },
        'layer3': {
            'w': scale3 * random.normal(k3, (128, 10)),
            'b': jnp.zeros(10),
        },
    }
    return params

Inicialización de He, hecha manualmente. Tres keys PRNG derivadas de una seed. Cada peso es un array inmutable dentro de un dict anidado.

Paso 3: Forward Pass

def forward(params, x):
    x = jnp.dot(x, params['layer1']['w']) + params['layer1']['b']
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['layer2']['w']) + params['layer2']['b']
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['layer3']['w']) + params['layer3']['b']
    return x

def loss_fn(params, x, y):
    logits = forward(params, x)
    one_hot = jax.nn.one_hot(y, 10)
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1))

Funciones puras. Params entran, predicción sale. Sin self, sin estado almacenado. loss_fn calcula la cross-entropy desde cero -- softmax, log, media negativa.

Paso 4: Paso de Entrenamiento Compilado con JIT

@jax.jit
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

@jax.jit
def accuracy(params, x, y):
    logits = forward(params, x)
    preds = jnp.argmax(logits, axis=-1)
    return jnp.mean(preds == y)

jax.value_and_grad devuelve tanto el valor de la loss como los gradientes en una sola pasada. El decorador @jax.jit compila ambas funciones a XLA. Después de la primera llamada, cada paso de entrenamiento corre sin tocar Python.

Paso 5: Bucle de Entrenamiento

optimizer = optax.adam(learning_rate=1e-3)

X_train, y_train, X_test, y_test = get_mnist_data()
X_train, X_test = jnp.array(X_train), jnp.array(X_test)
y_train, y_test = jnp.array(y_train), jnp.array(y_test)

key = random.PRNGKey(0)
params = init_params(key)
opt_state = optimizer.init(params)

batch_size = 128
n_epochs = 10

for epoch in range(n_epochs):
    key, subkey = random.split(key)
    perm = random.permutation(subkey, len(X_train))
    X_shuffled = X_train[perm]
    y_shuffled = y_train[perm]

    epoch_loss = 0.0
    n_batches = len(X_train) // batch_size
    for i in range(n_batches):
        start = i * batch_size
        xb = X_shuffled[start:start + batch_size]
        yb = y_shuffled[start:start + batch_size]
        params, opt_state, loss = train_step(params, opt_state, xb, yb)
        epoch_loss += loss

    train_acc = accuracy(params, X_train[:5000], y_train[:5000])
    test_acc = accuracy(params, X_test, y_test)
    print(f"Epoch {epoch + 1:2d} | Loss: {epoch_loss / n_batches:.4f} | "
          f"Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")

10 épocas. ~97% de precisión en el test. La primera época es lenta (compilación JIT). Las épocas 2-10 son rápidas.

Fíjate en lo que falta: sin .zero_grad(), sin .backward(), sin .step(). Toda la actualización es una sola llamada de función compuesta. Los gradientes se calculan, se transforman con Adam y se aplican a los parámetros -- todo dentro de train_step.

Úsalo

Flax: El Estándar de Google

Flax es la biblioteca de red neuronal JAX más común. Reintroduce el nn.Module, pero con gestión de estado explícita:

import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        return x

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
logits = model.apply(params, x_batch)

Misma estructura que PyTorch, pero params está separado del modelo. model.init() crea los params. model.apply(params, x) corre el forward pass. El objeto modelo no tiene estado.

Equinox: La Alternativa Pythónica

Equinox (de Patrick Kidger) representa modelos como pytrees:

import equinox as eqx

model = eqx.nn.MLP(
    in_size=784, out_size=10, width_size=256, depth=2,
    activation=jax.nn.relu, key=jax.random.PRNGKey(0)
)
logits = model(x)

El propio modelo es un pytree. No necesita .apply(). Los parámetros son simplemente las hojas del modelo. Esto es más cercano a cómo piensa JAX.

Optax: Optimizadores Componibles

Optax desacopla la transformación del gradiente de la actualización:

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-3,
    warmup_steps=1000, decay_steps=50000
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.01),
)

Gradient clipping, warmup de learning rate, weight decay -- todo compuesto como una cadena de transforms. Cada transform ve los gradientes, los modifica y los pasa al siguiente. Sin clase de optimizador monolítica.

Llévalo a Producción

Instalación:

pip install jax jaxlib optax flax

Para soporte de GPU:

pip install jax[cuda12]

Para TPU (Google Cloud):

pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Trampas de rendimiento:

  • La primera llamada JIT es lenta (compilación). Haz un warm up antes de hacer benchmark.
  • Evita bucles de Python sobre arrays JAX dentro del JIT. Usa jax.lax.scan o jax.lax.fori_loop.
  • jax.debug.print() funciona dentro del JIT. El print() normal no.
  • Haz profiling con jax.profiler o TensorBoard. La compilación XLA puede ocultar cuellos de botella.
  • JAX preasigna el 75% de la memoria de la GPU por defecto. Define XLA_PYTHON_CLIENT_PREALLOCATE=false para deshabilitarlo.

Checkpointing:

import orbax.checkpoint as ocp
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save('/tmp/model', params)
restored = checkpointer.restore('/tmp/model')

Esta lección produce:

  • outputs/prompt-jax-optimizer.md -- un prompt para elegir la configuración correcta de optimizador JAX
  • outputs/skill-jax-patterns.md -- una skill que cubre patrones funcionales en JAX

Ejercicios

  1. Agrega dropout al MLP. En JAX, el dropout requiere una key PRNG -- pasa una key a través del forward pass y divídela para cada capa de dropout. Compara la precisión en el test con y sin.

  2. Usa jax.vmap para calcular gradientes por ejemplo para un batch de 32 imágenes de MNIST. Calcula la norma del gradiente para cada ejemplo. ¿Qué ejemplos tienen los gradientes más grandes, y por qué?

  3. Reemplaza la función forward manual por una mlp_forward(params, x) genérica que funcione para cualquier número de capas. Usa jax.tree.leaves para determinar la profundidad automáticamente.

  4. Haz benchmark del paso de entrenamiento con y sin @jax.jit. Cronometra 100 pasos de cada uno. ¿Qué tan grande es el speedup en tu hardware? ¿Cuál es el overhead de compilación en la primera llamada?

  5. Implementa gradient clipping componiendo optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)). Entrena con y sin clipping. Grafica la norma del gradiente a lo largo del entrenamiento para ver el efecto.

Términos Clave

Término Lo que la gente dice Lo que realmente significa
XLA "La cosa que hace rápido a JAX" Accelerated Linear Algebra -- un compilador que fusiona operaciones y genera kernels de GPU/TPU optimizados a partir de un grafo de computación
JIT "Compilación just-in-time" JAX traza la función en la primera llamada, compila a XLA y luego corre la versión compilada en las llamadas siguientes
Función pura "Sin efectos secundarios" Una función cuya salida depende solo de las entradas -- sin estado global, sin mutación, sin aleatoriedad sin keys explícitas
vmap "Auto-batching" Transforma una función que procesa un ejemplo en una que procesa un batch, sin reescribir
pmap "Auto-paralelismo" Replica una función en múltiples dispositivos y divide el batch de entrada
Pytree "Dict anidado de arrays" Cualquier estructura anidada de listas, tuplas, dicts y arrays que JAX pueda recorrer y transformar
Tracing "Registrar la computación" JAX ejecuta la función con valores abstractos para construir un grafo de computación, sin calcular resultados reales
Autodiff funcional "grad de una función" Calcular derivadas transformando funciones, no adjuntando almacenamiento de gradiente a los tensores
Optax "La biblioteca de optimizadores de JAX" Una biblioteca componible de transformaciones de gradiente -- Adam, SGD, clipping, scheduling -- que se encadenan
Flax "El nn.Module de JAX" La biblioteca de red neuronal de Google para JAX, que agrega abstracciones de capas manteniendo el estado explícito

Lecturas Adicionales

  • Documentación de JAX: https://jax.readthedocs.io/ -- los docs oficiales, con excelentes tutoriales sobre grad, jit y vmap
  • "JAX: composable transformations of Python+NumPy programs" (Bradbury et al., 2018) -- el paper original que explica la filosofía de diseño
  • Documentación de Flax: https://flax.readthedocs.io/ -- la biblioteca de red neuronal de Google para JAX
  • Patrick Kidger, "Equinox: neural networks in JAX via callable PyTrees and filtered transformations" (2021) -- la alternativa Pythónica a Flax
  • DeepMind, "Optax: composable gradient transformation and optimisation" -- la biblioteca estándar de optimizadores
  • "You Don't Know JAX" (Colin Raffel, 2020) -- una guía práctica sobre trampas y patrones de JAX, de uno de los autores de T5
0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).