Phase 03 - Lesson 12

Introdução ao JAX

O PyTorch muta tensores. O TensorFlow constrói grafos. O JAX compila funções puras. Esse último ponto muda a forma como você pensa sobre deep learning.

Tipo: Build Linguagens: Python Pré-requisitos: Fase 03 Lições 01-10, NumPy básico Tempo: ~90 minutos

Objetivos de Aprendizagem

  • Escrever código de rede neural com funções puras usando a API funcional do JAX (jax.numpy, jax.grad, jax.jit, jax.vmap)
  • Explicar a diferença fundamental de design entre a mutação eager do PyTorch e o modelo de compilação funcional do JAX
  • Aplicar compilação jit e vetorização vmap para acelerar loops de treinamento em comparação com Python ingênuo
  • Treinar uma rede simples em JAX e contrastar o gerenciamento explícito de estado com a abordagem orientada a objetos do PyTorch

O Problema

Você sabe construir redes neurais em PyTorch. Você define um nn.Module, chama .backward(), dá um passo no otimizador. Funciona. Milhões de pessoas usam.

Mas o PyTorch tem uma restrição embutida no seu DNA: ele rastreia operações de forma eager, uma de cada vez, em Python. Cada tensor + tensor é um lançamento de kernel separado. Cada passo de treinamento reinterpreta o mesmo código Python. Isso funciona bem até você precisar treinar um modelo de 540 bilhões de parâmetros em 2.048 TPUs. Aí o overhead te mata.

O Google DeepMind treina o Gemini em JAX. A Anthropic treinou o Claude em JAX. Essas não são operações pequenas -- são os maiores runs de treinamento de redes neurais da Terra. Elas escolheram o JAX porque ele trata seu loop de treinamento como um programa compilável, não como uma sequência de chamadas Python.

O JAX é NumPy com três superpoderes: diferenciação automática, compilação JIT para XLA e vetorização automática. Você escreve uma função que processa um exemplo. O JAX te dá uma função que processa um batch, calcula gradientes, compila para código de máquina e roda em múltiplos dispositivos. Tudo sem mudar a função original.

O Conceito

A Filosofia do JAX

O JAX é um framework funcional. Sem classes, sem estado mutável, sem método .backward(). Em vez disso:

PyTorch JAX
Classe nn.Module com estado Função pura: f(params, x) -> y
loss.backward() jax.grad(loss_fn)(params, x, y)
Execução eager Compilação JIT via XLA
Loop manual for x in batch: jax.vmap(f) auto-vetorização
DataParallel / FSDP jax.pmap(f) auto-paralelismo
model.parameters() mutável Pytree imutável de arrays

Isso não é uma preferência de estilo. É uma restrição do compilador. A compilação JIT exige funções puras -- mesmas entradas sempre produzem as mesmas saídas, sem efeitos colaterais. É essa restrição que torna possíveis speedups de 100x.

jax.numpy: A Superfície Familiar

O JAX reimplementa a API do NumPy em aceleradores:

import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
c = jnp.dot(a, b)

Mesmos nomes de funções. Mesmas regras de broadcasting. Mesma semântica de slicing. Mas os arrays vivem na GPU/TPU, e cada operação é rastreável pelo compilador.

Uma diferença crítica: arrays JAX são imutáveis. Nada de a[0] = 5. Em vez disso: a = a.at[0].set(5). Isso parece estranho por uma semana, depois faz clique -- a imutabilidade é o que torna transformações como grad, jit e vmap componíveis.

jax.grad: Autodiff Funcional

O PyTorch anexa gradientes a tensores (.grad). O JAX anexa gradientes a funções.

import jax

def f(x):
    return x ** 2

df = jax.grad(f)
df(3.0)

jax.grad recebe uma função e retorna uma nova função que calcula o gradiente. Sem chamada .backward(). Sem grafo de computação armazenado nos tensores. O gradiente é apenas outra função que você pode chamar, compor ou compilar com JIT.

Isso se compõe arbitrariamente:

d2f = jax.grad(jax.grad(f))
d2f(3.0)

Segundas derivadas. Terceiras derivadas. Jacobianas. Hessianas. Tudo compondo grad. O PyTorch também consegue fazer isso (torch.autograd.functional.hessian), mas é algo adicionado por cima. No JAX, é a base.

A restrição: grad só funciona em funções puras. Sem statements de print dentro (eles rodam durante o tracing, não na execução). Sem mutação de estado externo. Sem geração de números aleatórios sem gerenciamento explícito de keys.

jit: Compilar para XLA

@jax.jit
def train_step(params, x, y):
    loss = loss_fn(params, x, y)
    return loss

fast_step = jax.jit(train_step)

Na primeira chamada, o JAX rastreia a função -- ele registra quais operações acontecem, sem executá-las. Depois entrega esse trace para o XLA (Accelerated Linear Algebra), o compilador do Google para TPUs e GPUs. O XLA funde operações, elimina cópias de memória redundantes e gera código de máquina otimizado.

Chamadas subsequentes pulam o Python por completo. O código compilado roda no acelerador na velocidade do C++.

Quando o JIT ajuda:

  • Passos de treinamento (mesma computação repetida milhares de vezes)
  • Inferência (mesmo modelo, entradas diferentes)
  • Qualquer função chamada mais de uma vez com entradas de formato similar

Quando o JIT atrapalha:

  • Funções com controle de fluxo em Python que depende de valores (if x > 0 onde x é um array rastreado)
  • Computações de uma única vez (o overhead de compilação excede o tempo de execução)
  • Debugging (o tracing esconde a execução real)

A restrição de controle de fluxo é real. jax.lax.cond substitui if/else. jax.lax.scan substitui loops for. Eles não são opcionais -- são o preço da compilação.

vmap: Vetorização Automática

Você escreve uma função que processa um exemplo:

def predict(params, x):
    return jnp.dot(params['w'], x) + params['b']

vmap a eleva para processar um batch:

batch_predict = jax.vmap(predict, in_axes=(None, 0))

in_axes=(None, 0) significa: não fazer batch sobre params (compartilhado), fazer batch sobre o eixo 0 de x. Sem loop for manual. Sem reshaping. Sem ter que conduzir a dimensão de batch manualmente. O JAX descobre a dimensão de batch e vetoriza a computação inteira.

Isso não é açúcar sintático. vmap gera código vetorizado fundido que roda de 10 a 100x mais rápido que um loop Python. E se compõe com jit e grad:

per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))

Gradientes por exemplo. Uma linha. Isso é quase impossível em PyTorch sem gambiarras.

pmap: Paralelismo de Dados Entre Dispositivos

parallel_step = jax.pmap(train_step, axis_name='devices')

pmap replica a função em todos os dispositivos disponíveis (GPUs/TPUs) e divide o batch. Dentro da função, jax.lax.pmean e jax.lax.psum sincronizam os gradientes entre os dispositivos.

O Google treina o Gemini em milhares de chips TPU v5e usando pmap (e seu sucessor shard_map). O modelo de programação: escreva a versão para um único dispositivo, envolva com pmap, pronto.

Pytrees: A Estrutura de Dados Universal

O JAX opera sobre "pytrees" -- combinações aninhadas de listas, tuplas, dicts e arrays. Os parâmetros do seu modelo são uma pytree:

params = {
    'layer1': {'w': jnp.zeros((784, 256)), 'b': jnp.zeros(256)},
    'layer2': {'w': jnp.zeros((256, 128)), 'b': jnp.zeros(128)},
    'layer3': {'w': jnp.zeros((128, 10)),  'b': jnp.zeros(10)},
}

Toda transformação do JAX -- grad, jit, vmap -- sabe como percorrer pytrees. jax.tree.map(f, tree) aplica f a cada folha. É assim que otimizadores atualizam todos os parâmetros de uma vez:

params = jax.tree.map(lambda p, g: p - lr * g, params, grads)

Sem método .parameters(). Sem registro de parâmetros. A estrutura da árvore é o modelo.

Funcional vs Orientado a Objetos

O PyTorch armazena estado dentro de objetos:

class Model(nn.Module):
    def __init__(self):
        self.linear = nn.Linear(784, 10)

    def forward(self, x):
        return self.linear(x)

O JAX usa funções puras com estado explícito:

def predict(params, x):
    return jnp.dot(x, params['w']) + params['b']

Os params são passados como argumento. Nada é armazenado. Nada é mutado. Isso torna cada função testável, componível e compilável. Também significa que você gerencia os params você mesmo -- ou usa uma biblioteca como Flax ou Equinox.

O Ecossistema do JAX

O JAX te dá primitivas. As bibliotecas te dão ergonomia:

Biblioteca Papel Estilo
Flax (Google) Camadas de rede neural nn.Module com estado explícito
Equinox (Patrick Kidger) Camadas de rede neural Baseado em pytree, Pythônico
Optax (DeepMind) Otimizadores + schedules de LR Transforms de gradiente componíveis
Orbax (Google) Checkpointing Salvar/restaurar pytrees
CLU (Google) Métricas + logging Utilitários para loop de treinamento

O Optax é a biblioteca padrão de otimizadores. Ele separa a transformação do gradiente (Adam, SGD, clipping) da atualização dos parâmetros, tornando trivial compor:

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)

Quando Usar JAX vs PyTorch

Fator JAX PyTorch
Suporte a TPU First-class (o Google construiu ambos) Mantido pela comunidade (torch_xla)
Suporte a GPU Bom (CUDA via XLA) O melhor da categoria (CUDA nativo)
Debugging Difícil (tracing + compilação) Fácil (eager, linha por linha)
Ecossistema Focado em pesquisa (Flax, Equinox) Massivo (HuggingFace, torchvision, etc.)
Contratação Nicho (Google/DeepMind/Anthropic) Mainstream (em todo lugar)
Treinamento em larga escala Superior (XLA, pmap, mesh) Bom (FSDP, DeepSpeed)
Velocidade de prototipagem Mais lenta (overhead funcional) Mais rápida (muta e segue)
Inferência em produção TensorFlow Serving, Vertex AI TorchServe, Triton, ONNX
Quem usa DeepMind (Gemini), Anthropic (Claude) Meta (Llama), OpenAI (GPT), Stability AI

A resposta honesta: use PyTorch a menos que você tenha um motivo específico para usar JAX. Esses motivos são -- acesso a TPU, necessidade de gradientes por exemplo, treinamento multi-dispositivo em escala massiva, ou trabalhar no Google/DeepMind/Anthropic.

Números Aleatórios no JAX

O JAX não tem estado aleatório global. Toda operação aleatória exige uma key PRNG explícita:

key = jax.random.PRNGKey(42)
key1, key2 = jax.random.split(key)
w = jax.random.normal(key1, shape=(784, 256))

Isso é chato no começo. Mas garante reprodutibilidade entre dispositivos e compilações -- uma propriedade que o torch.manual_seed do PyTorch não consegue garantir em cenários multi-GPU.

Construa

Passo 1: Setup e Dados

Vamos treinar um MLP de 3 camadas no MNIST usando JAX e Optax. 784 entradas, duas camadas ocultas de 256 e 128 neurônios, 10 classes de saída.

import jax
import jax.numpy as jnp
from jax import random
import optax

def get_mnist_data():
    from sklearn.datasets import fetch_openml
    mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
    X = mnist.data.astype('float32') / 255.0
    y = mnist.target.astype('int')
    X_train, X_test = X[:60000], X[60000:]
    y_train, y_test = y[:60000], y[60000:]
    return X_train, y_train, X_test, y_test

Passo 2: Inicializar Parâmetros

Sem classe. Apenas uma função que retorna uma pytree:

def init_params(key):
    k1, k2, k3 = random.split(key, 3)
    scale1 = jnp.sqrt(2.0 / 784)
    scale2 = jnp.sqrt(2.0 / 256)
    scale3 = jnp.sqrt(2.0 / 128)
    params = {
        'layer1': {
            'w': scale1 * random.normal(k1, (784, 256)),
            'b': jnp.zeros(256),
        },
        'layer2': {
            'w': scale2 * random.normal(k2, (256, 128)),
            'b': jnp.zeros(128),
        },
        'layer3': {
            'w': scale3 * random.normal(k3, (128, 10)),
            'b': jnp.zeros(10),
        },
    }
    return params

Inicialização de He, feita manualmente. Três keys PRNG derivadas de uma seed. Cada peso é um array imutável dentro de um dict aninhado.

Passo 3: Forward Pass

def forward(params, x):
    x = jnp.dot(x, params['layer1']['w']) + params['layer1']['b']
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['layer2']['w']) + params['layer2']['b']
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['layer3']['w']) + params['layer3']['b']
    return x

def loss_fn(params, x, y):
    logits = forward(params, x)
    one_hot = jax.nn.one_hot(y, 10)
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1))

Funções puras. Params entram, predição sai. Sem self, sem estado armazenado. loss_fn calcula a cross-entropy do zero -- softmax, log, média negativa.

Passo 4: Passo de Treinamento Compilado com JIT

@jax.jit
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

@jax.jit
def accuracy(params, x, y):
    logits = forward(params, x)
    preds = jnp.argmax(logits, axis=-1)
    return jnp.mean(preds == y)

jax.value_and_grad retorna tanto o valor da loss quanto os gradientes em uma única passada. O decorator @jax.jit compila ambas as funções para XLA. Após a primeira chamada, cada passo de treinamento roda sem tocar no Python.

Passo 5: Loop de Treinamento

optimizer = optax.adam(learning_rate=1e-3)

X_train, y_train, X_test, y_test = get_mnist_data()
X_train, X_test = jnp.array(X_train), jnp.array(X_test)
y_train, y_test = jnp.array(y_train), jnp.array(y_test)

key = random.PRNGKey(0)
params = init_params(key)
opt_state = optimizer.init(params)

batch_size = 128
n_epochs = 10

for epoch in range(n_epochs):
    key, subkey = random.split(key)
    perm = random.permutation(subkey, len(X_train))
    X_shuffled = X_train[perm]
    y_shuffled = y_train[perm]

    epoch_loss = 0.0
    n_batches = len(X_train) // batch_size
    for i in range(n_batches):
        start = i * batch_size
        xb = X_shuffled[start:start + batch_size]
        yb = y_shuffled[start:start + batch_size]
        params, opt_state, loss = train_step(params, opt_state, xb, yb)
        epoch_loss += loss

    train_acc = accuracy(params, X_train[:5000], y_train[:5000])
    test_acc = accuracy(params, X_test, y_test)
    print(f"Epoch {epoch + 1:2d} | Loss: {epoch_loss / n_batches:.4f} | "
          f"Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")

10 épocas. ~97% de acurácia no teste. A primeira época é lenta (compilação JIT). As épocas 2-10 são rápidas.

Repare no que falta: sem .zero_grad(), sem .backward(), sem .step(). A atualização inteira é uma única chamada de função composta. Os gradientes são calculados, transformados pelo Adam e aplicados aos parâmetros -- tudo dentro de train_step.

Use

Flax: O Padrão do Google

O Flax é a biblioteca de rede neural JAX mais comum. Ele traz o nn.Module de volta, mas com gerenciamento de estado explícito:

import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        return x

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
logits = model.apply(params, x_batch)

Mesma estrutura do PyTorch, mas params é separado do modelo. model.init() cria os params. model.apply(params, x) roda o forward pass. O objeto modelo não tem estado.

Equinox: A Alternativa Pythônica

O Equinox (de Patrick Kidger) representa modelos como pytrees:

import equinox as eqx

model = eqx.nn.MLP(
    in_size=784, out_size=10, width_size=256, depth=2,
    activation=jax.nn.relu, key=jax.random.PRNGKey(0)
)
logits = model(x)

O próprio modelo é uma pytree. Não precisa de .apply(). Os parâmetros são apenas as folhas do modelo. Isso é mais próximo de como o JAX pensa.

Optax: Otimizadores Componíveis

O Optax desacopla a transformação do gradiente da atualização:

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-3,
    warmup_steps=1000, decay_steps=50000
)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.01),
)

Gradient clipping, warmup de learning rate, weight decay -- tudo composto como uma cadeia de transforms. Cada transform vê os gradientes, os modifica e os passa para o próximo. Sem classe de otimizador monolítica.

Coloque em Produção

Instalação:

pip install jax jaxlib optax flax

Para suporte a GPU:

pip install jax[cuda12]

Para TPU (Google Cloud):

pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Pegadinhas de performance:

  • A primeira chamada JIT é lenta (compilação). Faça um warm up antes de fazer benchmark.
  • Evite loops Python sobre arrays JAX dentro do JIT. Use jax.lax.scan ou jax.lax.fori_loop.
  • jax.debug.print() funciona dentro do JIT. O print() comum não.
  • Faça profiling com jax.profiler ou TensorBoard. A compilação XLA pode esconder gargalos.
  • O JAX pré-aloca 75% da memória da GPU por padrão. Defina XLA_PYTHON_CLIENT_PREALLOCATE=false para desabilitar.

Checkpointing:

import orbax.checkpoint as ocp
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save('/tmp/model', params)
restored = checkpointer.restore('/tmp/model')

Esta lição produz:

  • outputs/prompt-jax-optimizer.md -- um prompt para escolher a configuração certa de otimizador JAX
  • outputs/skill-jax-patterns.md -- uma skill cobrindo padrões funcionais em JAX

Exercícios

  1. Adicione dropout ao MLP. No JAX, o dropout exige uma key PRNG -- conduza uma key através do forward pass e a divida para cada camada de dropout. Compare a acurácia no teste com e sem.

  2. Use jax.vmap para calcular gradientes por exemplo para um batch de 32 imagens do MNIST. Calcule a norma do gradiente para cada exemplo. Quais exemplos têm os maiores gradientes, e por quê?

  3. Substitua a função forward manual por uma mlp_forward(params, x) genérica que funciona para qualquer número de camadas. Use jax.tree.leaves para determinar a profundidade automaticamente.

  4. Faça benchmark do passo de treinamento com e sem @jax.jit. Cronometre 100 passos de cada. Quão grande é o speedup no seu hardware? Qual é o overhead de compilação na primeira chamada?

  5. Implemente gradient clipping compondo optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)). Treine com e sem clipping. Plote a norma do gradiente ao longo do treinamento para ver o efeito.

Termos-Chave

Termo O que as pessoas dizem O que realmente significa
XLA "A coisa que deixa o JAX rápido" Accelerated Linear Algebra -- um compilador que funde operações e gera kernels de GPU/TPU otimizados a partir de um grafo de computação
JIT "Compilação just-in-time" O JAX rastreia a função na primeira chamada, compila para XLA e então roda a versão compilada nas chamadas seguintes
Função pura "Sem efeitos colaterais" Uma função cuja saída depende apenas das entradas -- sem estado global, sem mutação, sem aleatoriedade sem keys explícitas
vmap "Auto-batching" Transforma uma função que processa um exemplo em uma que processa um batch, sem reescrever
pmap "Auto-paralelismo" Replica uma função em múltiplos dispositivos e divide o batch de entrada
Pytree "Dict aninhado de arrays" Qualquer estrutura aninhada de listas, tuplas, dicts e arrays que o JAX consegue percorrer e transformar
Tracing "Gravar a computação" O JAX executa a função com valores abstratos para construir um grafo de computação, sem calcular resultados reais
Autodiff funcional "grad de uma função" Calcular derivadas transformando funções, não anexando armazenamento de gradiente aos tensores
Optax "A biblioteca de otimizadores do JAX" Uma biblioteca componível de transformações de gradiente -- Adam, SGD, clipping, scheduling -- que se encadeiam
Flax "O nn.Module do JAX" A biblioteca de rede neural do Google para JAX, que adiciona abstrações de camadas mantendo o estado explícito

Leitura Adicional

  • Documentação do JAX: https://jax.readthedocs.io/ -- os docs oficiais, com excelentes tutoriais sobre grad, jit e vmap
  • "JAX: composable transformations of Python+NumPy programs" (Bradbury et al., 2018) -- o paper original explicando a filosofia de design
  • Documentação do Flax: https://flax.readthedocs.io/ -- a biblioteca de rede neural do Google para JAX
  • Patrick Kidger, "Equinox: neural networks in JAX via callable PyTrees and filtered transformations" (2021) -- a alternativa Pythônica ao Flax
  • DeepMind, "Optax: composable gradient transformation and optimisation" -- a biblioteca padrão de otimizadores
  • "You Don't Know JAX" (Colin Raffel, 2020) -- um guia prático sobre pegadinhas e padrões do JAX, de um dos autores do T5
0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).