Phase 03 - Lesson 12
Introdução ao JAX
O PyTorch muta tensores. O TensorFlow constrói grafos. O JAX compila funções puras. Esse último ponto muda a forma como você pensa sobre deep learning.
Tipo: Build Linguagens: Python Pré-requisitos: Fase 03 Lições 01-10, NumPy básico Tempo: ~90 minutos
Objetivos de Aprendizagem
- Escrever código de rede neural com funções puras usando a API funcional do JAX (jax.numpy, jax.grad, jax.jit, jax.vmap)
- Explicar a diferença fundamental de design entre a mutação eager do PyTorch e o modelo de compilação funcional do JAX
- Aplicar compilação jit e vetorização vmap para acelerar loops de treinamento em comparação com Python ingênuo
- Treinar uma rede simples em JAX e contrastar o gerenciamento explícito de estado com a abordagem orientada a objetos do PyTorch
O Problema
Você sabe construir redes neurais em PyTorch. Você define um nn.Module, chama .backward(), dá um passo no otimizador. Funciona. Milhões de pessoas usam.
Mas o PyTorch tem uma restrição embutida no seu DNA: ele rastreia operações de forma eager, uma de cada vez, em Python. Cada tensor + tensor é um lançamento de kernel separado. Cada passo de treinamento reinterpreta o mesmo código Python. Isso funciona bem até você precisar treinar um modelo de 540 bilhões de parâmetros em 2.048 TPUs. Aí o overhead te mata.
O Google DeepMind treina o Gemini em JAX. A Anthropic treinou o Claude em JAX. Essas não são operações pequenas -- são os maiores runs de treinamento de redes neurais da Terra. Elas escolheram o JAX porque ele trata seu loop de treinamento como um programa compilável, não como uma sequência de chamadas Python.
O JAX é NumPy com três superpoderes: diferenciação automática, compilação JIT para XLA e vetorização automática. Você escreve uma função que processa um exemplo. O JAX te dá uma função que processa um batch, calcula gradientes, compila para código de máquina e roda em múltiplos dispositivos. Tudo sem mudar a função original.
O Conceito
A Filosofia do JAX
O JAX é um framework funcional. Sem classes, sem estado mutável, sem método .backward(). Em vez disso:
| PyTorch | JAX |
|---|---|
Classe nn.Module com estado |
Função pura: f(params, x) -> y |
loss.backward() |
jax.grad(loss_fn)(params, x, y) |
| Execução eager | Compilação JIT via XLA |
Loop manual for x in batch: |
jax.vmap(f) auto-vetorização |
DataParallel / FSDP |
jax.pmap(f) auto-paralelismo |
model.parameters() mutável |
Pytree imutável de arrays |
Isso não é uma preferência de estilo. É uma restrição do compilador. A compilação JIT exige funções puras -- mesmas entradas sempre produzem as mesmas saídas, sem efeitos colaterais. É essa restrição que torna possíveis speedups de 100x.
jax.numpy: A Superfície Familiar
O JAX reimplementa a API do NumPy em aceleradores:
import jax.numpy as jnp
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
c = jnp.dot(a, b)
Mesmos nomes de funções. Mesmas regras de broadcasting. Mesma semântica de slicing. Mas os arrays vivem na GPU/TPU, e cada operação é rastreável pelo compilador.
Uma diferença crítica: arrays JAX são imutáveis. Nada de a[0] = 5. Em vez disso: a = a.at[0].set(5). Isso parece estranho por uma semana, depois faz clique -- a imutabilidade é o que torna transformações como grad, jit e vmap componíveis.
jax.grad: Autodiff Funcional
O PyTorch anexa gradientes a tensores (.grad). O JAX anexa gradientes a funções.
import jax
def f(x):
return x ** 2
df = jax.grad(f)
df(3.0)
jax.grad recebe uma função e retorna uma nova função que calcula o gradiente. Sem chamada .backward(). Sem grafo de computação armazenado nos tensores. O gradiente é apenas outra função que você pode chamar, compor ou compilar com JIT.
Isso se compõe arbitrariamente:
d2f = jax.grad(jax.grad(f))
d2f(3.0)
Segundas derivadas. Terceiras derivadas. Jacobianas. Hessianas. Tudo compondo grad. O PyTorch também consegue fazer isso (torch.autograd.functional.hessian), mas é algo adicionado por cima. No JAX, é a base.
A restrição: grad só funciona em funções puras. Sem statements de print dentro (eles rodam durante o tracing, não na execução). Sem mutação de estado externo. Sem geração de números aleatórios sem gerenciamento explícito de keys.
jit: Compilar para XLA
@jax.jit
def train_step(params, x, y):
loss = loss_fn(params, x, y)
return loss
fast_step = jax.jit(train_step)
Na primeira chamada, o JAX rastreia a função -- ele registra quais operações acontecem, sem executá-las. Depois entrega esse trace para o XLA (Accelerated Linear Algebra), o compilador do Google para TPUs e GPUs. O XLA funde operações, elimina cópias de memória redundantes e gera código de máquina otimizado.
Chamadas subsequentes pulam o Python por completo. O código compilado roda no acelerador na velocidade do C++.
Quando o JIT ajuda:
- Passos de treinamento (mesma computação repetida milhares de vezes)
- Inferência (mesmo modelo, entradas diferentes)
- Qualquer função chamada mais de uma vez com entradas de formato similar
Quando o JIT atrapalha:
- Funções com controle de fluxo em Python que depende de valores (
if x > 0onde x é um array rastreado) - Computações de uma única vez (o overhead de compilação excede o tempo de execução)
- Debugging (o tracing esconde a execução real)
A restrição de controle de fluxo é real. jax.lax.cond substitui if/else. jax.lax.scan substitui loops for. Eles não são opcionais -- são o preço da compilação.
vmap: Vetorização Automática
Você escreve uma função que processa um exemplo:
def predict(params, x):
return jnp.dot(params['w'], x) + params['b']
vmap a eleva para processar um batch:
batch_predict = jax.vmap(predict, in_axes=(None, 0))
in_axes=(None, 0) significa: não fazer batch sobre params (compartilhado), fazer batch sobre o eixo 0 de x. Sem loop for manual. Sem reshaping. Sem ter que conduzir a dimensão de batch manualmente. O JAX descobre a dimensão de batch e vetoriza a computação inteira.
Isso não é açúcar sintático. vmap gera código vetorizado fundido que roda de 10 a 100x mais rápido que um loop Python. E se compõe com jit e grad:
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))
Gradientes por exemplo. Uma linha. Isso é quase impossível em PyTorch sem gambiarras.
pmap: Paralelismo de Dados Entre Dispositivos
parallel_step = jax.pmap(train_step, axis_name='devices')
pmap replica a função em todos os dispositivos disponíveis (GPUs/TPUs) e divide o batch. Dentro da função, jax.lax.pmean e jax.lax.psum sincronizam os gradientes entre os dispositivos.
O Google treina o Gemini em milhares de chips TPU v5e usando pmap (e seu sucessor shard_map). O modelo de programação: escreva a versão para um único dispositivo, envolva com pmap, pronto.
Pytrees: A Estrutura de Dados Universal
O JAX opera sobre "pytrees" -- combinações aninhadas de listas, tuplas, dicts e arrays. Os parâmetros do seu modelo são uma pytree:
params = {
'layer1': {'w': jnp.zeros((784, 256)), 'b': jnp.zeros(256)},
'layer2': {'w': jnp.zeros((256, 128)), 'b': jnp.zeros(128)},
'layer3': {'w': jnp.zeros((128, 10)), 'b': jnp.zeros(10)},
}
Toda transformação do JAX -- grad, jit, vmap -- sabe como percorrer pytrees. jax.tree.map(f, tree) aplica f a cada folha. É assim que otimizadores atualizam todos os parâmetros de uma vez:
params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
Sem método .parameters(). Sem registro de parâmetros. A estrutura da árvore é o modelo.
Funcional vs Orientado a Objetos
O PyTorch armazena estado dentro de objetos:
class Model(nn.Module):
def __init__(self):
self.linear = nn.Linear(784, 10)
def forward(self, x):
return self.linear(x)
O JAX usa funções puras com estado explícito:
def predict(params, x):
return jnp.dot(x, params['w']) + params['b']
Os params são passados como argumento. Nada é armazenado. Nada é mutado. Isso torna cada função testável, componível e compilável. Também significa que você gerencia os params você mesmo -- ou usa uma biblioteca como Flax ou Equinox.
O Ecossistema do JAX
O JAX te dá primitivas. As bibliotecas te dão ergonomia:
| Biblioteca | Papel | Estilo |
|---|---|---|
| Flax (Google) | Camadas de rede neural | nn.Module com estado explícito |
| Equinox (Patrick Kidger) | Camadas de rede neural | Baseado em pytree, Pythônico |
| Optax (DeepMind) | Otimizadores + schedules de LR | Transforms de gradiente componíveis |
| Orbax (Google) | Checkpointing | Salvar/restaurar pytrees |
| CLU (Google) | Métricas + logging | Utilitários para loop de treinamento |
O Optax é a biblioteca padrão de otimizadores. Ele separa a transformação do gradiente (Adam, SGD, clipping) da atualização dos parâmetros, tornando trivial compor:
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(learning_rate=1e-3),
)
Quando Usar JAX vs PyTorch
| Fator | JAX | PyTorch |
|---|---|---|
| Suporte a TPU | First-class (o Google construiu ambos) | Mantido pela comunidade (torch_xla) |
| Suporte a GPU | Bom (CUDA via XLA) | O melhor da categoria (CUDA nativo) |
| Debugging | Difícil (tracing + compilação) | Fácil (eager, linha por linha) |
| Ecossistema | Focado em pesquisa (Flax, Equinox) | Massivo (HuggingFace, torchvision, etc.) |
| Contratação | Nicho (Google/DeepMind/Anthropic) | Mainstream (em todo lugar) |
| Treinamento em larga escala | Superior (XLA, pmap, mesh) | Bom (FSDP, DeepSpeed) |
| Velocidade de prototipagem | Mais lenta (overhead funcional) | Mais rápida (muta e segue) |
| Inferência em produção | TensorFlow Serving, Vertex AI | TorchServe, Triton, ONNX |
| Quem usa | DeepMind (Gemini), Anthropic (Claude) | Meta (Llama), OpenAI (GPT), Stability AI |
A resposta honesta: use PyTorch a menos que você tenha um motivo específico para usar JAX. Esses motivos são -- acesso a TPU, necessidade de gradientes por exemplo, treinamento multi-dispositivo em escala massiva, ou trabalhar no Google/DeepMind/Anthropic.
Números Aleatórios no JAX
O JAX não tem estado aleatório global. Toda operação aleatória exige uma key PRNG explícita:
key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)
w = jax.random.normal(key1, shape=(784, 256))
Isso é chato no começo. Mas garante reprodutibilidade entre dispositivos e compilações -- uma propriedade que o torch.manual_seed do PyTorch não consegue garantir em cenários multi-GPU.
Construa
Passo 1: Setup e Dados
Vamos treinar um MLP de 3 camadas no MNIST usando JAX e Optax. 784 entradas, duas camadas ocultas de 256 e 128 neurônios, 10 classes de saída.
import jax
import jax.numpy as jnp
from jax import random
import optax
def get_mnist_data():
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
X = mnist.data.astype('float32') / 255.0
y = mnist.target.astype('int')
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]
return X_train, y_train, X_test, y_test
Passo 2: Inicializar Parâmetros
Sem classe. Apenas uma função que retorna uma pytree:
def init_params(key):
k1, k2, k3 = random.split(key, 3)
scale1 = jnp.sqrt(2.0 / 784)
scale2 = jnp.sqrt(2.0 / 256)
scale3 = jnp.sqrt(2.0 / 128)
params = {
'layer1': {
'w': scale1 * random.normal(k1, (784, 256)),
'b': jnp.zeros(256),
},
'layer2': {
'w': scale2 * random.normal(k2, (256, 128)),
'b': jnp.zeros(128),
},
'layer3': {
'w': scale3 * random.normal(k3, (128, 10)),
'b': jnp.zeros(10),
},
}
return params
Inicialização de He, feita manualmente. Três keys PRNG derivadas de uma seed. Cada peso é um array imutável dentro de um dict aninhado.
Passo 3: Forward Pass
def forward(params, x):
x = jnp.dot(x, params['layer1']['w']) + params['layer1']['b']
x = jax.nn.relu(x)
x = jnp.dot(x, params['layer2']['w']) + params['layer2']['b']
x = jax.nn.relu(x)
x = jnp.dot(x, params['layer3']['w']) + params['layer3']['b']
return x
def loss_fn(params, x, y):
logits = forward(params, x)
one_hot = jax.nn.one_hot(y, 10)
return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1))
Funções puras. Params entram, predição sai. Sem self, sem estado armazenado. loss_fn calcula a cross-entropy do zero -- softmax, log, média negativa.
Passo 4: Passo de Treinamento Compilado com JIT
@jax.jit
def train_step(params, opt_state, x, y):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state, loss
@jax.jit
def accuracy(params, x, y):
logits = forward(params, x)
preds = jnp.argmax(logits, axis=-1)
return jnp.mean(preds == y)
jax.value_and_grad retorna tanto o valor da loss quanto os gradientes em uma única passada. O decorator @jax.jit compila ambas as funções para XLA. Após a primeira chamada, cada passo de treinamento roda sem tocar no Python.
Passo 5: Loop de Treinamento
optimizer = optax.adam(learning_rate=1e-3)
X_train, y_train, X_test, y_test = get_mnist_data()
X_train, X_test = jnp.array(X_train), jnp.array(X_test)
y_train, y_test = jnp.array(y_train), jnp.array(y_test)
key = random.PRNGKey(0)
params = init_params(key)
opt_state = optimizer.init(params)
batch_size = 128
n_epochs = 10
for epoch in range(n_epochs):
key, subkey = random.split(key)
perm = random.permutation(subkey, len(X_train))
X_shuffled = X_train[perm]
y_shuffled = y_train[perm]
epoch_loss = 0.0
n_batches = len(X_train) // batch_size
for i in range(n_batches):
start = i * batch_size
xb = X_shuffled[start:start + batch_size]
yb = y_shuffled[start:start + batch_size]
params, opt_state, loss = train_step(params, opt_state, xb, yb)
epoch_loss += loss
train_acc = accuracy(params, X_train[:5000], y_train[:5000])
test_acc = accuracy(params, X_test, y_test)
print(f"Epoch {epoch + 1:2d} | Loss: {epoch_loss / n_batches:.4f} | "
f"Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")
10 épocas. ~97% de acurácia no teste. A primeira época é lenta (compilação JIT). As épocas 2-10 são rápidas.
Repare no que falta: sem .zero_grad(), sem .backward(), sem .step(). A atualização inteira é uma única chamada de função composta. Os gradientes são calculados, transformados pelo Adam e aplicados aos parâmetros -- tudo dentro de train_step.
Use
Flax: O Padrão do Google
O Flax é a biblioteca de rede neural JAX mais comum. Ele traz o nn.Module de volta, mas com gerenciamento de estado explícito:
import flax.linen as nn
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(128)(x)
x = nn.relu(x)
x = nn.Dense(10)(x)
return x
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
logits = model.apply(params, x_batch)
Mesma estrutura do PyTorch, mas params é separado do modelo. model.init() cria os params. model.apply(params, x) roda o forward pass. O objeto modelo não tem estado.
Equinox: A Alternativa Pythônica
O Equinox (de Patrick Kidger) representa modelos como pytrees:
import equinox as eqx
model = eqx.nn.MLP(
in_size=784, out_size=10, width_size=256, depth=2,
activation=jax.nn.relu, key=jax.random.PRNGKey(0)
)
logits = model(x)
O próprio modelo é uma pytree. Não precisa de .apply(). Os parâmetros são apenas as folhas do modelo. Isso é mais próximo de como o JAX pensa.
Optax: Otimizadores Componíveis
O Optax desacopla a transformação do gradiente da atualização:
schedule = optax.warmup_cosine_decay_schedule(
init_value=0.0, peak_value=1e-3,
warmup_steps=1000, decay_steps=50000
)
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adamw(learning_rate=schedule, weight_decay=0.01),
)
Gradient clipping, warmup de learning rate, weight decay -- tudo composto como uma cadeia de transforms. Cada transform vê os gradientes, os modifica e os passa para o próximo. Sem classe de otimizador monolítica.
Coloque em Produção
Instalação:
pip install jax jaxlib optax flax
Para suporte a GPU:
pip install jax[cuda12]
Para TPU (Google Cloud):
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Pegadinhas de performance:
- A primeira chamada JIT é lenta (compilação). Faça um warm up antes de fazer benchmark.
- Evite loops Python sobre arrays JAX dentro do JIT. Use
jax.lax.scanoujax.lax.fori_loop. jax.debug.print()funciona dentro do JIT. Oprint()comum não.- Faça profiling com
jax.profilerou TensorBoard. A compilação XLA pode esconder gargalos. - O JAX pré-aloca 75% da memória da GPU por padrão. Defina
XLA_PYTHON_CLIENT_PREALLOCATE=falsepara desabilitar.
Checkpointing:
import orbax.checkpoint as ocp
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save('/tmp/model', params)
restored = checkpointer.restore('/tmp/model')
Esta lição produz:
outputs/prompt-jax-optimizer.md-- um prompt para escolher a configuração certa de otimizador JAXoutputs/skill-jax-patterns.md-- uma skill cobrindo padrões funcionais em JAX
Exercícios
Adicione dropout ao MLP. No JAX, o dropout exige uma key PRNG -- conduza uma key através do forward pass e a divida para cada camada de dropout. Compare a acurácia no teste com e sem.
Use
jax.vmappara calcular gradientes por exemplo para um batch de 32 imagens do MNIST. Calcule a norma do gradiente para cada exemplo. Quais exemplos têm os maiores gradientes, e por quê?Substitua a função forward manual por uma
mlp_forward(params, x)genérica que funciona para qualquer número de camadas. Usejax.tree.leavespara determinar a profundidade automaticamente.Faça benchmark do passo de treinamento com e sem
@jax.jit. Cronometre 100 passos de cada. Quão grande é o speedup no seu hardware? Qual é o overhead de compilação na primeira chamada?Implemente gradient clipping compondo
optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)). Treine com e sem clipping. Plote a norma do gradiente ao longo do treinamento para ver o efeito.
Termos-Chave
| Termo | O que as pessoas dizem | O que realmente significa |
|---|---|---|
| XLA | "A coisa que deixa o JAX rápido" | Accelerated Linear Algebra -- um compilador que funde operações e gera kernels de GPU/TPU otimizados a partir de um grafo de computação |
| JIT | "Compilação just-in-time" | O JAX rastreia a função na primeira chamada, compila para XLA e então roda a versão compilada nas chamadas seguintes |
| Função pura | "Sem efeitos colaterais" | Uma função cuja saída depende apenas das entradas -- sem estado global, sem mutação, sem aleatoriedade sem keys explícitas |
| vmap | "Auto-batching" | Transforma uma função que processa um exemplo em uma que processa um batch, sem reescrever |
| pmap | "Auto-paralelismo" | Replica uma função em múltiplos dispositivos e divide o batch de entrada |
| Pytree | "Dict aninhado de arrays" | Qualquer estrutura aninhada de listas, tuplas, dicts e arrays que o JAX consegue percorrer e transformar |
| Tracing | "Gravar a computação" | O JAX executa a função com valores abstratos para construir um grafo de computação, sem calcular resultados reais |
| Autodiff funcional | "grad de uma função" | Calcular derivadas transformando funções, não anexando armazenamento de gradiente aos tensores |
| Optax | "A biblioteca de otimizadores do JAX" | Uma biblioteca componível de transformações de gradiente -- Adam, SGD, clipping, scheduling -- que se encadeiam |
| Flax | "O nn.Module do JAX" | A biblioteca de rede neural do Google para JAX, que adiciona abstrações de camadas mantendo o estado explícito |
Leitura Adicional
- Documentação do JAX: https://jax.readthedocs.io/ -- os docs oficiais, com excelentes tutoriais sobre grad, jit e vmap
- "JAX: composable transformations of Python+NumPy programs" (Bradbury et al., 2018) -- o paper original explicando a filosofia de design
- Documentação do Flax: https://flax.readthedocs.io/ -- a biblioteca de rede neural do Google para JAX
- Patrick Kidger, "Equinox: neural networks in JAX via callable PyTrees and filtered transformations" (2021) -- a alternativa Pythônica ao Flax
- DeepMind, "Optax: composable gradient transformation and optimisation" -- a biblioteca padrão de otimizadores
- "You Don't Know JAX" (Colin Raffel, 2020) -- um guia prático sobre pegadinhas e padrões do JAX, de um dos autores do T5