Phase 04 - Lesson 23

Diffusion Transformers y Rectified Flow

La U-Net no es el secreto de la difusión. Reemplázala por un transformer, cambia el cronograma de ruido por un flujo en línea recta y, de repente, tienes SD3, FLUX y todos los modelos text-to-image de 2026.

Tipo: Aprender + Construir Lenguajes: Python Requisitos previos: Fase 4 Lección 10 (Diffusion DDPM), Fase 4 Lección 14 (ViT), Fase 7 Lección 02 (Self-Attention) Tiempo: ~75 minutos

Objetivos de aprendizaje

  • Trazar la evolución desde la U-Net DDPM (Lección 10) hasta el Diffusion Transformer (DiT), el MMDiT (SD3) y el DiT single+double-stream (FLUX)
  • Explicar el rectified flow: por qué una trayectoria en línea recta entre ruido y datos permite que los modelos muestreen en 20 pasos en lugar de 1000
  • Implementar un pequeño bloque DiT y un bucle de entrenamiento de rectified flow, ambos con menos de 100 líneas
  • Distinguir variantes de modelo (SD3, FLUX.1-dev, FLUX.1-schnell, Z-Image, Qwen-Image) por arquitectura, cantidad de parámetros y licenciamiento

El problema

La Lección 10 construyó un DDPM con un denoiser U-Net. Esa receta dominó de 2020 a 2023: U-Net + cronograma beta + pérdida de predicción de ruido. Produjo Stable Diffusion 1.5 y 2.1 y DALL-E 2.

Todo modelo text-to-image de última generación de 2026 ha superado esa abordaje. Stable Diffusion 3, FLUX, SD4, Z-Image, Qwen-Image, Hunyuan-Image — ninguno usa una U-Net. Usan Diffusion Transformers (DiT). SD3 y FLUX también cambian el cronograma de ruido del DDPM por el rectified flow, que rectifica el camino del ruido a los datos y habilita inferencia de 1 a 4 pasos con variantes de consistencia o destiladas.

El cambio importa porque es la razón por la cual la generación de imágenes basada en difusión se volvió controlable, fiel al prompt (SD3/SD4 resolvieron el renderizado de texto) y lo bastante rápida para producción. Entender DiT + rectified flow es entender el stack de generación de imágenes de 2026.

El concepto

De la U-Net al transformer

flowchart LR
    subgraph UNET["DDPM U-Net (2020)"]
        U1["Encoder convolucional"] --> U2["Cuello de botella convolucional"] --> U3["Decoder convolucional"]
    end
    subgraph DIT["DiT (2023)"]
        D1["Patch embed"] --> D2["Bloques transformer"] --> D3["Unpatchify"]
    end
    subgraph MMDIT["MMDiT (SD3, 2024)"]
        M1["Flujo de texto"] --> M3["Atención conjunta<br/>(pesos separados por modalidad)"]
        M2["Flujo de imagen"] --> M3
    end
    subgraph FLUX["FLUX (2024)"]
        F1["Bloques double-stream<br/>(texto + imagen separados)"] --> F2["Bloques single-stream<br/>(concat + pesos compartidos)"]
    end

    style UNET fill:#e5e7eb,stroke:#6b7280
    style DIT fill:#dbeafe,stroke:#2563eb
    style MMDIT fill:#fef3c7,stroke:#d97706
    style FLUX fill:#dcfce7,stroke:#16a34a
  • DiT (Peebles & Xie, 2023) — reemplaza la U-Net por un transformer estilo ViT sobre patches latentes. Condicionamiento vía adaptive layer norm (AdaLN).
  • MMDiT (SD3, Esser et al., 2024) — dos flujos con pesos separados para tokens de texto e imagen que comparten una atención conjunta.
  • FLUX (Black Forest Labs, 2024) — los primeros N bloques son double-stream como en SD3, y los bloques posteriores concatenan y comparten pesos (single-stream) para mayor eficiencia a más profundidad.
  • Z-Image (2025) — un DiT single-stream eficiente con 6B de parámetros que desafía el "escalar a cualquier costo".

Rectified flow en un párrafo

El DDPM define el proceso forward como una SDE ruidosa en la que x_t se corrompe progresivamente. El reverso aprendido es una segunda SDE, resuelta con 1000 pasos pequeños.

El rectified flow define una interpolación en línea recta entre datos limpios y ruido puro:

x_t = (1 - t) * x_0 + t * epsilon,     t in [0, 1]

Entrena una red para predecir la velocidad v_theta(x_t, t) = epsilon - x_0 — la dirección forward a lo largo del camino en línea recta desde los datos limpios hasta el ruido (dx_t/dt). Durante el muestreo, integras esta velocidad en sentido reverso para avanzar del ruido hacia los datos. La ODE resultante queda mucho más cerca de una línea recta, por lo que se necesitan muchos menos pasos de integración para muestrear.

SD3 llama a esto Rectified Flow Matching. FLUX, Z-Image y la mayoría de los modelos de 2026 usan el mismo objetivo. Inferencia típica: 20-30 pasos de Euler (determinista) frente a 50+ pasos de DDIM en el antiguo régimen DDPM. Las variantes destiladas / turbo / schnell / LCM lo reducen a 1-4 pasos.

Condicionamiento AdaLN

Los DiTs condicionan en el timestep y en la clase/texto vía adaptive layer norm: predicen scale y shift a partir del vector de condicionamiento y los aplican después del LayerNorm. Mucho más limpio que la modulación estilo FiLM en las U-Nets y es el estándar en todo DiT moderno.

cond -> MLP -> (scale, shift, gate)
norm(x) * (1 + scale) + shift, then residual add * gate

Encoders de texto en SD3 y en FLUX

  • SD3 usa tres encoders de texto: dos modelos CLIP + T5-XXL. Las embeddings se concatenan y se alimentan al flujo de imagen como condicionamiento de texto.
  • FLUX usa un CLIP-L + T5-XXL.
  • Las variantes Qwen-Image / Z-Image usan sus propios encoders de texto internos, alineados con sus LLMs base.

El encoder de texto es gran parte de la razón por la cual SD3/FLUX razonan tanto mejor sobre prompts que SD1.5. T5-XXL por sí solo tiene 4,7B de parámetros.

El classifier-free guidance sigue vigente

El rectified flow cambia el sampler, no el condicionamiento. El classifier-free guidance (descartar el texto con 10% de probabilidad durante el entrenamiento, mezclar predicciones condicionales e incondicionales en la inferencia) funciona de forma idéntica con rectified flow. La mayoría de los modelos de 2026 usan guidance scale 3.5-5 — menor que el 7.5 de SD1.5, porque los modelos de rectified flow siguen los prompts de forma más fiel por defecto.

Consistency, Turbo, Schnell, LCM

Cuatro nombres para la misma idea: destilar un modelo lento de muchos pasos en un modelo rápido de pocos pasos.

  • LCM (Latent Consistency Model) — entrena un estudiante que predice el x_0 final a partir de cualquier x_t intermedio en un solo paso.
  • SDXL Turbo / FLUX schnell — modelos de 1-4 pasos entrenados con adversarial diffusion distillation.
  • SD Turbo — Consistency Models al estilo OpenAI adaptados a la difusión latente.

El servicio en producción de cualquier modelo nuevo entrega tanto un checkpoint de "calidad total" como una variante "turbo / schnell". Schnell ("rápido" en alemán, convención de Black Forest Labs) corre en 1-4 pasos y encaja en pipelines de tiempo real.

El panorama de modelos en 2026

Modelo Tamaño Arquitectura Términos
Stable Diffusion 3 Medium 2B MMDiT SAI Community
Stable Diffusion 3.5 Large 8B MMDiT SAI Community
FLUX.1-dev 12B DiT Double + Single Stream no comercial
FLUX.1-schnell 12B la misma, destilada Apache 2.0
FLUX.2 FLUX.1 iterado mixta
Z-Image 6B S3-DiT (Scalable Single-Stream) permisiva
Qwen-Image ~20B DiT + torre de texto Qwen Apache 2.0
Hunyuan-Image-3.0 ~80B DiT investigación
SD4 Turbo 3B DiT + destilación SAI Commercial

FLUX.1-schnell es el estándar open-source de 2026. Z-Image es el líder en eficiencia. FLUX.2 y SD4 son las referencias actuales de calidad.

Por qué importa este cambio de fase

DDPM + U-Net funcionaba. DiT + rectified flow funciona mejor, más rápido y escala de forma más limpia. La transición se asemeja a la de los RNNs a los transformers en PLN: ambas arquitecturas resolvían el mismo problema, pero los transformers escalaron y ahora dominan. Todo artículo de 2026 sobre generación de imagen, video o 3D usa un denoiser con forma de DiT y, por lo general, un objetivo de rectified flow. El DDPM con U-Net hoy es principalmente pedagógico (Lección 10).

Constrúyelo

Paso 1: Un bloque DiT con AdaLN

import torch
import torch.nn as nn


class AdaLNZero(nn.Module):
    """
    Adaptive LayerNorm with a gate. Predicts (scale, shift, gate) from the conditioning.
    Init such that the whole block starts as identity ("zero init").
    """

    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Linear(cond_dim, dim * 3)
        nn.init.zeros_(self.mlp.weight)
        nn.init.zeros_(self.mlp.bias)

    def forward(self, x, cond):
        scale, shift, gate = self.mlp(cond).chunk(3, dim=-1)
        h = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return h, gate.unsqueeze(1)


class DiTBlock(nn.Module):
    def __init__(self, dim=192, heads=3, mlp_ratio=4, cond_dim=192):
        super().__init__()
        self.adaln1 = AdaLNZero(dim, cond_dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.adaln2 = AdaLNZero(dim, cond_dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x, cond):
        h, gate1 = self.adaln1(x, cond)
        a, _ = self.attn(h, h, h, need_weights=False)
        x = x + gate1 * a
        h, gate2 = self.adaln2(x, cond)
        x = x + gate2 * self.mlp(h)
        return x

AdaLNZero comienza como un mapeo identidad porque los pesos de su MLP se inicializan en cero. El entrenamiento aleja el bloque de la identidad; esto estabiliza dramáticamente los modelos de difusión transformer profundos.

Paso 2: Un DiT diminuto

def timestep_embedding(t, dim):
    import math
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([args.sin(), args.cos()], dim=-1)


class TinyDiT(nn.Module):
    def __init__(self, image_size=16, patch_size=2, in_channels=3, dim=96, depth=4, heads=3):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        self.pos = nn.Parameter(torch.zeros(1, self.num_patches, dim))
        self.time_mlp = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.SiLU(),
            nn.Linear(dim * 2, dim),
        )
        self.blocks = nn.ModuleList([DiTBlock(dim, heads, cond_dim=dim) for _ in range(depth)])
        self.norm_out = nn.LayerNorm(dim, elementwise_affine=False)
        self.head = nn.Linear(dim, patch_size * patch_size * in_channels)

    def forward(self, x, t):
        n = x.size(0)
        x = self.patch(x)
        x = x.flatten(2).transpose(1, 2) + self.pos
        t_emb = self.time_mlp(timestep_embedding(t, self.pos.size(-1)))
        for blk in self.blocks:
            x = blk(x, t_emb)
        x = self.norm_out(x)
        x = self.head(x)
        return self._unpatchify(x, n)

    def _unpatchify(self, x, n):
        p = self.patch_size
        h = w = int(self.num_patches ** 0.5)
        x = x.view(n, h, w, p, p, -1).permute(0, 5, 1, 3, 2, 4).reshape(n, -1, h * p, w * p)
        return x

Paso 3: Entrenamiento con rectified flow

import torch.nn.functional as F

def rectified_flow_train_step(model, x0, optimizer, device):
    model.train()
    x0 = x0.to(device)
    n = x0.size(0)
    t = torch.rand(n, device=device)
    epsilon = torch.randn_like(x0)
    x_t = (1 - t[:, None, None, None]) * x0 + t[:, None, None, None] * epsilon

    target_velocity = epsilon - x0
    pred_velocity = model(x_t, t)

    loss = F.mse_loss(pred_velocity, target_velocity)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

Compara con la pérdida de predicción de ruido del DDPM (Lección 10): la misma estructura, distinto objetivo. En lugar de predecir el ruido epsilon, predecimos la velocidad epsilon - x_0, que apunta desde los datos hacia el ruido a lo largo de la interpolación en línea recta.

Paso 4: Sampler de Euler

El rectified flow es una ODE. El método de Euler es el más simple y, para un modelo de rectified flow bien entrenado, casi tan preciso como solvers de orden superior con 20+ pasos.

@torch.no_grad()
def rectified_flow_sample(model, shape, steps=20, device="cpu"):
    model.eval()
    x = torch.randn(shape, device=device)
    dt = 1.0 / steps
    t = torch.ones(shape[0], device=device)
    for _ in range(steps):
        v = model(x, t)
        x = x - dt * v
        t = t - dt
    return x

20 pasos. En un modelo entrenado, esto produce muestras comparables a un DDPM de 1000 pasos.

Paso 5: Smoke test de extremo a extremo

import numpy as np

def synthetic_blobs(num=200, size=16, seed=0):
    rng = np.random.default_rng(seed)
    out = np.zeros((num, 3, size, size), dtype=np.float32)
    yy, xx = np.meshgrid(np.arange(size), np.arange(size), indexing="ij")
    for i in range(num):
        cx, cy = rng.uniform(4, size - 4, size=2)
        r = rng.uniform(2, 4)
        mask = (xx - cx) ** 2 + (yy - cy) ** 2 < r ** 2
        colour = rng.uniform(-1, 1, size=3)
        for c in range(3):
            out[i, c][mask] = colour[c]
    return torch.from_numpy(out)

Entrena un TinyDiT con esto usando rectified flow. Después de 500 pasos, las salidas muestreadas deberían parecer manchas tenues de color.

Úsalo

Para la generación real de imágenes con FLUX / SD3 / Z-Image, diffusers entrega cada uno de ellos con una API unificada:

from diffusers import FluxPipeline, StableDiffusion3Pipeline
import torch

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to("cuda")

out = pipe(
    prompt="a golden retriever surfing a tsunami, hyperrealistic, studio lighting",
    guidance_scale=0.0,           # schnell was trained without CFG
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
out.save("surf.png")

Tres líneas. FLUX.1-schnell en cuatro pasos. Cambia el id del modelo por black-forest-labs/FLUX.1-dev para mayor calidad en 20-30 pasos con CFG.

Para SD3:

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    torch_dtype=torch.bfloat16,
).to("cuda")
out = pipe(prompt, guidance_scale=3.5, num_inference_steps=28).images[0]

Entrégalo

Esta lección produce:

  • outputs/prompt-dit-model-picker.md — elige entre SD3, FLUX.1-dev, FLUX.1-schnell, Z-Image, SD4 Turbo dados la calidad, la latencia y los términos de uso.
  • outputs/skill-rectified-flow-trainer.md — escribe un bucle de entrenamiento completo para rectified flow con DiT AdaLN y muestreo de Euler.

Ejercicios

  1. (Fácil) Entrena el TinyDiT anterior en el dataset sintético de blobs durante 500 pasos. Compara las muestras producidas con 10, 20 y 50 pasos de Euler.
  2. (Medio) Agrega condicionamiento de texto concatenando una embedding de clase aprendida a la embedding de tiempo (10 "clases" de blob por color). Muestrea con las clases 0, 5 y 9 y verifica que los colores coincidan.
  3. (Difícil) Calcula la distancia de Fréchet (proxy de FID) entre muestras generadas por las versiones de rectified-flow y de DDPM de la misma red del mismo tamaño, entrenadas con los mismos datos durante el mismo número de pasos. Reporta cuál converge más rápido.

Términos clave

Término Lo que la gente dice Lo que realmente significa
DiT "Diffusion transformer" Transformer que reemplaza a la U-Net como denoiser de difusión; opera sobre latentes en patches
AdaLN "Adaptive layer norm" Condicionamiento de timestep/texto vía scale, shift y gate aprendidos, aplicados después del LayerNorm; estándar en todo DiT moderno
MMDiT "DiT multimodal (SD3)" Flujos de pesos separados para tokens de texto e imagen que comparten una self-attention conjunta
Single-stream / double-stream "Truco de FLUX" Los primeros N bloques son double-stream (pesos separados por modalidad), los bloques posteriores son single-stream (concat + pesos compartidos) para mayor eficiencia
Rectified flow "Línea recta de ruido a datos" Interpolación lineal entre datos y ruido; la red predice la velocidad; menos pasos de ODE necesarios en la inferencia
Objetivo de velocidad "epsilon - x_0" El objetivo de regresión en el rectified flow; apunta desde los datos limpios hacia el ruido
CFG guidance "classifier-free guidance" Mezcla predicciones condicionales e incondicionales; aún se usa en modelos de rectified flow
Schnell / turbo / LCM "Destilación de 1-4 pasos" Variantes de pocos pasos destiladas de modelos de calidad total; tiempo real en producción

Lecturas adicionales

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).