Phase 04 - Lesson 23

Diffusion Transformers e Rectified Flow

A U-Net não é o segredo da difusão. Troque-a por um transformer, substitua o cronograma de ruído por um fluxo em linha reta e, de repente, você tem SD3, FLUX e todos os modelos text-to-image de 2026.

Tipo: Aprender + Construir Linguagens: Python Pré-requisitos: Fase 4 Aula 10 (Diffusion DDPM), Fase 4 Aula 14 (ViT), Fase 7 Aula 02 (Self-Attention) Tempo: ~75 minutos

Objetivos de aprendizagem

  • Traçar a evolução desde o U-Net DDPM (Aula 10) até o Diffusion Transformer (DiT), o MMDiT (SD3) e o DiT single+double-stream (FLUX)
  • Explicar o rectified flow: por que uma trajetória em linha reta entre ruído e dados permite que os modelos amostrem em 20 passos em vez de 1000
  • Implementar um pequeno bloco DiT e um loop de treinamento de rectified flow, ambos com menos de 100 linhas
  • Distinguir variantes de modelo (SD3, FLUX.1-dev, FLUX.1-schnell, Z-Image, Qwen-Image) por arquitetura, contagem de parâmetros e licenciamento

O problema

A Aula 10 construiu um DDPM com um denoiser U-Net. Essa receita dominou de 2020 a 2023: U-Net + cronograma beta + perda de predição de ruído. Ela produziu o Stable Diffusion 1.5 e 2.1 e o DALL-E 2.

Todo modelo text-to-image estado-da-arte de 2026 superou essa abordagem. Stable Diffusion 3, FLUX, SD4, Z-Image, Qwen-Image, Hunyuan-Image — nenhum usa uma U-Net. Eles usam Diffusion Transformers (DiT). SD3 e FLUX também trocam o cronograma de ruído do DDPM pelo rectified flow, que retifica o caminho do ruído até os dados e habilita inferência de 1 a 4 passos com variantes de consistência ou destiladas.

A mudança importa porque é a razão pela qual a geração de imagens baseada em difusão se tornou controlável, fiel ao prompt (SD3/SD4 resolveram a renderização de texto) e rápida o suficiente para produção. Entender DiT + rectified flow é entender a stack de geração de imagens de 2026.

O conceito

Do U-Net ao transformer

flowchart LR
    subgraph UNET["DDPM U-Net (2020)"]
        U1["Encoder convolucional"] --> U2["Gargalo convolucional"] --> U3["Decoder convolucional"]
    end
    subgraph DIT["DiT (2023)"]
        D1["Patch embed"] --> D2["Blocos transformer"] --> D3["Unpatchify"]
    end
    subgraph MMDIT["MMDiT (SD3, 2024)"]
        M1["Fluxo de texto"] --> M3["Atenção conjunta<br/>(pesos separados por modalidade)"]
        M2["Fluxo de imagem"] --> M3
    end
    subgraph FLUX["FLUX (2024)"]
        F1["Blocos double-stream<br/>(texto + imagem separados)"] --> F2["Blocos single-stream<br/>(concat + pesos compartilhados)"]
    end

    style UNET fill:#e5e7eb,stroke:#6b7280
    style DIT fill:#dbeafe,stroke:#2563eb
    style MMDIT fill:#fef3c7,stroke:#d97706
    style FLUX fill:#dcfce7,stroke:#16a34a
  • DiT (Peebles & Xie, 2023) — substitui a U-Net por um transformer estilo ViT sobre patches latentes. Condicionamento via adaptive layer norm (AdaLN).
  • MMDiT (SD3, Esser et al., 2024) — dois fluxos com pesos separados para tokens de texto e imagem que compartilham uma atenção conjunta.
  • FLUX (Black Forest Labs, 2024) — os primeiros N blocos são double-stream como no SD3, e os blocos posteriores concatenam e compartilham pesos (single-stream) para eficiência em maior profundidade.
  • Z-Image (2025) — um DiT single-stream eficiente com 6B de parâmetros que desafia o "escalar a todo custo".

Rectified flow em um parágrafo

O DDPM define o processo forward como uma SDE ruidosa em que x_t é progressivamente corrompido. O reverso aprendido é uma segunda SDE, resolvida por 1000 pequenos passos.

O rectified flow define uma interpolação em linha reta entre dados limpos e ruído puro:

x_t = (1 - t) * x_0 + t * epsilon,     t in [0, 1]

Treine uma rede para prever a velocidade v_theta(x_t, t) = epsilon - x_0 — a direção forward ao longo do caminho em linha reta dos dados limpos até o ruído (dx_t/dt). Durante a amostragem, você integra essa velocidade no sentido reverso para avançar do ruído em direção aos dados. A ODE resultante fica muito mais próxima de uma linha reta, então são necessários muito menos passos de integração para amostrar.

O SD3 chama isso de Rectified Flow Matching. FLUX, Z-Image e a maioria dos modelos de 2026 usam o mesmo objetivo. Inferência típica: 20-30 passos de Euler (determinística) versus 50+ passos de DDIM no antigo regime DDPM. Variantes destiladas / turbo / schnell / LCM reduzem isso para 1-4 passos.

Condicionamento AdaLN

DiTs condicionam no timestep e na classe/texto via adaptive layer norm: preveem scale e shift a partir do vetor de condicionamento e os aplicam após o LayerNorm. Muito mais limpo do que a modulação estilo FiLM nas U-Nets e é o padrão em todo DiT moderno.

cond -> MLP -> (scale, shift, gate)
norm(x) * (1 + scale) + shift, then residual add * gate

Encoders de texto no SD3 e no FLUX

  • SD3 usa três encoders de texto: dois modelos CLIP + T5-XXL. As embeddings são concatenadas e alimentadas no fluxo de imagem como condicionamento de texto.
  • FLUX usa um CLIP-L + T5-XXL.
  • As variantes Qwen-Image / Z-Image usam seus próprios encoders de texto internos, alinhados com seus LLMs base.

O encoder de texto é grande parte do motivo pelo qual SD3/FLUX raciocinam tão melhor sobre prompts do que o SD1.5. O T5-XXL sozinho tem 4,7B de parâmetros.

Classifier-free guidance continua valendo

O rectified flow muda o sampler, não o condicionamento. O classifier-free guidance (descartar o texto com 10% de probabilidade durante o treinamento, misturar predições condicionais e incondicionais na inferência) funciona de forma idêntica com rectified flow. A maioria dos modelos de 2026 usa guidance scale 3.5-5 — menor do que o 7.5 do SD1.5, porque os modelos de rectified flow seguem os prompts de forma mais fiel por padrão.

Consistency, Turbo, Schnell, LCM

Quatro nomes para a mesma ideia: destilar um modelo lento de muitos passos em um modelo rápido de poucos passos.

  • LCM (Latent Consistency Model) — treina um aluno que prevê o x_0 final a partir de qualquer x_t intermediário em um único passo.
  • SDXL Turbo / FLUX schnell — modelos de 1-4 passos treinados com adversarial diffusion distillation.
  • SD Turbo — Consistency Models no estilo OpenAI adaptados à difusão latente.

O serviço em produção de qualquer modelo novo entrega tanto um checkpoint de "qualidade total" quanto uma variante "turbo / schnell". O Schnell ("rápido" em alemão, convenção da Black Forest Labs) roda em 1-4 passos e se encaixa em pipelines de tempo real.

O panorama de modelos em 2026

Modelo Tamanho Arquitetura Termos
Stable Diffusion 3 Medium 2B MMDiT SAI Community
Stable Diffusion 3.5 Large 8B MMDiT SAI Community
FLUX.1-dev 12B DiT Double + Single Stream não comercial
FLUX.1-schnell 12B mesma, destilada Apache 2.0
FLUX.2 FLUX.1 iterado mista
Z-Image 6B S3-DiT (Scalable Single-Stream) permissiva
Qwen-Image ~20B DiT + torre de texto Qwen Apache 2.0
Hunyuan-Image-3.0 ~80B DiT pesquisa
SD4 Turbo 3B DiT + destilação SAI Commercial

O FLUX.1-schnell é o padrão open-source de 2026. O Z-Image é o líder em eficiência. FLUX.2 e SD4 são as referências atuais de qualidade.

Por que essa mudança de fase importa

DDPM + U-Net funcionava. DiT + rectified flow funciona melhor, mais rápido e escala de forma mais limpa. A transição se assemelha à dos RNNs para os transformers em PLN: ambas as arquiteturas resolviam o mesmo problema, mas os transformers escalaram e agora dominam. Todo artigo de 2026 sobre geração de imagem, vídeo ou 3D usa um denoiser em formato de DiT e, em geral, um objetivo de rectified flow. O DDPM com U-Net hoje é principalmente pedagógico (Aula 10).

Construa

Passo 1: Um bloco DiT com AdaLN

import torch
import torch.nn as nn


class AdaLNZero(nn.Module):
    """
    Adaptive LayerNorm with a gate. Predicts (scale, shift, gate) from the conditioning.
    Init such that the whole block starts as identity ("zero init").
    """

    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Linear(cond_dim, dim * 3)
        nn.init.zeros_(self.mlp.weight)
        nn.init.zeros_(self.mlp.bias)

    def forward(self, x, cond):
        scale, shift, gate = self.mlp(cond).chunk(3, dim=-1)
        h = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return h, gate.unsqueeze(1)


class DiTBlock(nn.Module):
    def __init__(self, dim=192, heads=3, mlp_ratio=4, cond_dim=192):
        super().__init__()
        self.adaln1 = AdaLNZero(dim, cond_dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.adaln2 = AdaLNZero(dim, cond_dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x, cond):
        h, gate1 = self.adaln1(x, cond)
        a, _ = self.attn(h, h, h, need_weights=False)
        x = x + gate1 * a
        h, gate2 = self.adaln2(x, cond)
        x = x + gate2 * self.mlp(h)
        return x

O AdaLNZero começa como um mapeamento identidade porque os pesos de seu MLP são inicializados em zero. O treinamento afasta o bloco da identidade; isso estabiliza dramaticamente os modelos de difusão transformer profundos.

Passo 2: Um DiT minúsculo

def timestep_embedding(t, dim):
    import math
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([args.sin(), args.cos()], dim=-1)


class TinyDiT(nn.Module):
    def __init__(self, image_size=16, patch_size=2, in_channels=3, dim=96, depth=4, heads=3):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        self.pos = nn.Parameter(torch.zeros(1, self.num_patches, dim))
        self.time_mlp = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.SiLU(),
            nn.Linear(dim * 2, dim),
        )
        self.blocks = nn.ModuleList([DiTBlock(dim, heads, cond_dim=dim) for _ in range(depth)])
        self.norm_out = nn.LayerNorm(dim, elementwise_affine=False)
        self.head = nn.Linear(dim, patch_size * patch_size * in_channels)

    def forward(self, x, t):
        n = x.size(0)
        x = self.patch(x)
        x = x.flatten(2).transpose(1, 2) + self.pos
        t_emb = self.time_mlp(timestep_embedding(t, self.pos.size(-1)))
        for blk in self.blocks:
            x = blk(x, t_emb)
        x = self.norm_out(x)
        x = self.head(x)
        return self._unpatchify(x, n)

    def _unpatchify(self, x, n):
        p = self.patch_size
        h = w = int(self.num_patches ** 0.5)
        x = x.view(n, h, w, p, p, -1).permute(0, 5, 1, 3, 2, 4).reshape(n, -1, h * p, w * p)
        return x

Passo 3: Treinamento com rectified flow

import torch.nn.functional as F

def rectified_flow_train_step(model, x0, optimizer, device):
    model.train()
    x0 = x0.to(device)
    n = x0.size(0)
    t = torch.rand(n, device=device)
    epsilon = torch.randn_like(x0)
    x_t = (1 - t[:, None, None, None]) * x0 + t[:, None, None, None] * epsilon

    target_velocity = epsilon - x0
    pred_velocity = model(x_t, t)

    loss = F.mse_loss(pred_velocity, target_velocity)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

Compare com a perda de predição de ruído do DDPM (Aula 10): mesma estrutura, alvo diferente. Em vez de prever o ruído epsilon, prevemos a velocidade epsilon - x_0, que aponta dos dados para o ruído ao longo da interpolação em linha reta.

Passo 4: Sampler de Euler

O rectified flow é uma ODE. O método de Euler é o mais simples e, para um modelo de rectified flow bem treinado, quase tão preciso quanto solvers de ordem superior em 20+ passos.

@torch.no_grad()
def rectified_flow_sample(model, shape, steps=20, device="cpu"):
    model.eval()
    x = torch.randn(shape, device=device)
    dt = 1.0 / steps
    t = torch.ones(shape[0], device=device)
    for _ in range(steps):
        v = model(x, t)
        x = x - dt * v
        t = t - dt
    return x

20 passos. Em um modelo treinado, isso produz amostras comparáveis a um DDPM de 1000 passos.

Passo 5: Smoke test ponta a ponta

import numpy as np

def synthetic_blobs(num=200, size=16, seed=0):
    rng = np.random.default_rng(seed)
    out = np.zeros((num, 3, size, size), dtype=np.float32)
    yy, xx = np.meshgrid(np.arange(size), np.arange(size), indexing="ij")
    for i in range(num):
        cx, cy = rng.uniform(4, size - 4, size=2)
        r = rng.uniform(2, 4)
        mask = (xx - cx) ** 2 + (yy - cy) ** 2 < r ** 2
        colour = rng.uniform(-1, 1, size=3)
        for c in range(3):
            out[i, c][mask] = colour[c]
    return torch.from_numpy(out)

Treine um TinyDiT nisso com rectified flow. Após 500 passos, as saídas amostradas devem parecer manchas tênues de cor.

Use

Para geração real de imagens com FLUX / SD3 / Z-Image, o diffusers entrega cada um deles com uma API unificada:

from diffusers import FluxPipeline, StableDiffusion3Pipeline
import torch

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to("cuda")

out = pipe(
    prompt="a golden retriever surfing a tsunami, hyperrealistic, studio lighting",
    guidance_scale=0.0,           # schnell was trained without CFG
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
out.save("surf.png")

Três linhas. FLUX.1-schnell em quatro passos. Troque o id do modelo por black-forest-labs/FLUX.1-dev para maior qualidade em 20-30 passos com CFG.

Para o SD3:

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    torch_dtype=torch.bfloat16,
).to("cuda")
out = pipe(prompt, guidance_scale=3.5, num_inference_steps=28).images[0]

Entregue

Esta aula produz:

  • outputs/prompt-dit-model-picker.md — escolhe entre SD3, FLUX.1-dev, FLUX.1-schnell, Z-Image, SD4 Turbo dados qualidade, latência e termos de uso.
  • outputs/skill-rectified-flow-trainer.md — escreve um loop de treinamento completo para rectified flow com DiT AdaLN e amostragem de Euler.

Exercícios

  1. (Fácil) Treine o TinyDiT acima no dataset sintético de blobs por 500 passos. Compare as amostras produzidas com 10, 20 e 50 passos de Euler.
  2. (Médio) Adicione condicionamento de texto concatenando uma embedding de classe aprendida à embedding de tempo (10 "classes" de blob por cor). Amostre com as classes 0, 5 e 9 e verifique se as cores correspondem.
  3. (Difícil) Compute a distância de Fréchet (proxy de FID) entre amostras geradas pelas versões de rectified-flow e de DDPM da mesma rede de mesmo tamanho, treinadas com os mesmos dados pelo mesmo número de passos. Reporte qual converge mais rápido.

Termos-chave

Termo O que as pessoas dizem O que realmente significa
DiT "Diffusion transformer" Transformer que substitui a U-Net como denoiser de difusão; opera sobre latentes em patches
AdaLN "Adaptive layer norm" Condicionamento de timestep/texto via scale, shift e gate aprendidos, aplicados após o LayerNorm; padrão em todo DiT moderno
MMDiT "DiT multimodal (SD3)" Fluxos de pesos separados para tokens de texto e imagem que compartilham uma self-attention conjunta
Single-stream / double-stream "Truque do FLUX" Os primeiros N blocos são double-stream (pesos separados por modalidade), os blocos posteriores são single-stream (concat + pesos compartilhados) para eficiência
Rectified flow "Linha reta de ruído a dados" Interpolação linear entre dados e ruído; a rede prevê a velocidade; menos passos de ODE necessários na inferência
Alvo de velocidade "epsilon - x_0" O alvo de regressão no rectified flow; aponta dos dados limpos para o ruído
CFG guidance "classifier-free guidance" Mistura predições condicionais e incondicionais; ainda usado em modelos de rectified flow
Schnell / turbo / LCM "Destilação de 1-4 passos" Variantes de poucos passos destiladas de modelos de qualidade total; tempo real em produção

Leitura adicional

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).