Phase 04 - Lesson 14

Vision Transformers (ViT)

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

Corta la imagen en patches, trata cada patch como una palabra, ejecuta un transformer estandar. No mires atras.

Tipo: Build Lenguajes: Python Prerrequisitos: Fase 7 Leccion 02 (Self-Attention), Fase 4 Leccion 04 (Clasificacion de Imagenes) Tiempo: ~45 minutos

Objetivos de Aprendizaje

  • Implementar patch embedding, positional embedding aprendido, class token y bloques de encoder transformer desde cero para construir un ViT minimo
  • Explicar por que se creia que el ViT necesitaba un preentrenamiento masivo de datos hasta que DeiT y MAE demostraron lo contrario
  • Comparar ViT, Swin y ConvNeXt en sus priors arquitectonicos (ninguno, atencion en ventana local, backbone convolucional)
  • Hacer fine-tune de un ViT preentrenado en un dataset pequeno usando timm y la receta estandar de linear-probe / fine-tune

El Problema

Durante una decada, convolucion fue sinonimo de vision por computadora. Las CNN tenian fuertes sesgos inductivos — localidad, equivarianza a la traslacion — que nadie creia que se pudieran reemplazar. Entonces Dosovitskiy et al. (2020) mostraron que un transformer puro aplicado a patches de imagen aplanados, sin ninguna maquinaria convolucional, podia igualar o superar a las mejores CNN a escala.

El detalle estaba en "a escala". El ViT en ImageNet-1k perdia ante el ResNet. El ViT preentrenado en ImageNet-21k o JFT-300M y luego fine-tuneado en ImageNet-1k lo superaba. La conclusion fue que los transformers carecian de priors utiles, pero podian aprenderlos con suficientes datos. Trabajos posteriores (DeiT, MAE, DINO) mostraron que con las recetas de entrenamiento correctas — augmentation fuerte, preentrenamiento auto-supervisado, destilacion — los ViT tambien entrenan bien con datos pequenos.

Para 2026, las CNN puras siguen siendo competitivas en dispositivos de borde (ConvNeXt es la mas fuerte), pero los transformers dominan todo lo demas: segmentacion (Mask2Former, SegFormer), deteccion (DETR, RT-DETR), multimodal (CLIP, SigLIP), video (VideoMAE, VJEPA). La estructura del bloque ViT es la que conviene conocer.

El Concepto

El pipeline

flowchart LR
    IMG["Imagen<br/>(3, 224, 224)"] --> PATCH["Patch embedding<br/>conv 16x16 s=16<br/>-> (768, 14, 14)"]
    PATCH --> FLAT["Aplana a<br/>(196, 768) tokens"]
    FLAT --> CAT["Antepone<br/>token [CLS]"]
    CAT --> POS["Suma positional<br/>embed aprendido"]
    POS --> ENC["N bloques de<br/>encoder transformer"]
    ENC --> CLS["Toma la salida del<br/>token [CLS]"]
    CLS --> HEAD["Clasificador MLP"]

    style PATCH fill:#dbeafe,stroke:#2563eb
    style ENC fill:#fef3c7,stroke:#d97706
    style HEAD fill:#dcfce7,stroke:#16a34a

Siete pasos. Patches -> tokens -> atencion -> clasificador. Cada variante (DeiT, Swin, ConvNeXt, preentrenamiento MAE) cambia uno o dos de los siete y deja el resto intacto.

Patch embedding

La primera conv es el secreto. Tamano de kernel 16, stride 16, asi que una imagen 224x224 se convierte en una grilla 14x14 de patches 16x16, cada uno proyectado a un embedding de 768 dimensiones. Esa unica conv tanto patchifica como proyecta linealmente.

Input:  (3, 224, 224)
Conv (3 -> 768, k=16, s=16, no padding):
Output: (768, 14, 14)
Flatten spatial: (196, 768)

196 patches = 196 tokens. La dimension de feature de cada token es 768 (ViT-B), 1024 (ViT-L) o 1280 (ViT-H).

Class token

Un unico vector aprendido antepuesto a la secuencia:

tokens = [CLS; patch_1; patch_2; ...; patch_196]   shape (197, 768)

Despues de N bloques transformer, la salida del [CLS] es la representacion global de la imagen. La cabeza de clasificacion lee solo ese unico vector.

Positional embedding

Los transformers no tienen una nocion incorporada de posicion espacial. Suma un vector aprendido a cada token:

tokens = tokens + learned_pos_embedding   (also shape (197, 768))

El embedding es un parametro del modelo; el entrenamiento basado en gradiente lo adapta a la estructura 2D de la imagen. Existen alternativas sinusoidales 2D, pero rara vez se usan en la practica.

Bloque de encoder transformer

Estandar. Multi-head self-attention, MLP, conexiones residuales, pre-LayerNorm.

x = x + MSA(LN(x))
x = x + MLP(LN(x))

MLP is two-layer with GELU: Linear(d -> 4d) -> GELU -> Linear(4d -> d)

El ViT-B/16 apila 12 de estos bloques, cada uno con 12 cabezas de atencion, sumando 86M de parametros.

Por que pre-LN

Los primeros transformers usaban post-LN (x = LN(x + sublayer(x))) y les costaba entrenar mas alla de 6-8 capas sin warmup. Pre-LN (x = x + sublayer(LN(x))) entrena redes mas profundas de forma estable sin warmup. Todo ViT y todo LLM moderno usan pre-LN.

Trade-off del tamano de patch

  • Patches 16x16 -> 196 tokens, estandar.
  • Patches 32x32 -> 49 tokens, mas rapido pero menor resolucion.
  • Patches 8x8 -> 784 tokens, mas fino pero el costo O(n^2) de la atencion escala mal.

Patches mas grandes = menos tokens = mas rapido pero menos detalle espacial. El SwinV2 usa patches 4x4 en ventanas jerarquicas.

La receta de DeiT para entrenar ViT en ImageNet-1k

El ViT original necesitaba JFT-300M para superar a las CNN. DeiT (Touvron et al., 2020) entreno el ViT-B hasta 81,8% top-1 en ImageNet-1k solo con cuatro cambios:

  1. Augmentation pesada: RandAugment, Mixup, CutMix, Random Erasing.
  2. Stochastic depth (descartar bloques enteros al azar durante el entrenamiento).
  3. Repeated augmentation (misma imagen muestreada 3 veces por batch).
  4. Destilacion de un profesor CNN (opcional, eleva aun mas la precision).

Toda receta moderna de entrenamiento de ViT desciende de DeiT.

Swin vs ConvNeXt

  • Swin (Liu et al., 2021) — atencion basada en ventana. Cada bloque atiende dentro de una ventana local; los bloques alternados desplazan la ventana para mezclar informacion entre ventanas. Devuelve un prior de localidad al estilo CNN manteniendo el operador de atencion.
  • ConvNeXt (Liu et al., 2022) — una CNN rediseñada que iguala las decisiones arquitectonicas de Swin (convs depthwise, LayerNorm, GELU, bottleneck invertido). Mostro que la diferencia no es "atencion vs convolucion" sino "receta de entrenamiento moderna + arquitectura".

En 2026, ConvNeXt-V2 y Swin-V2 son ambos de nivel de produccion; la eleccion correcta depende de tu stack de inferencia (ConvNeXt compila mejor para borde) y del corpus de preentrenamiento.

Preentrenamiento MAE

Masked Autoencoder (He et al., 2022): enmascara el 75% de los patches al azar, entrena el encoder para procesar solo el 25% visible, entrena un pequeno decoder para reconstruir los patches enmascarados a partir de la salida del encoder. Tras el preentrenamiento, descarta el decoder y haz fine-tune del encoder.

El MAE hace que el ViT sea entrenable en ImageNet-1k solo, alcanza SOTA y es la receta auto-supervisada estandar actual.

Construyelo

Paso 1: Patch embedding

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, dim=192, image_size=64):
        super().__init__()
        assert image_size % patch_size == 0
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (image_size // patch_size) ** 2
        self.num_patches = num_patches

    def forward(self, x):
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2)

Una conv, un flatten, un transpose. Ese es todo el paso de imagen-a-tokens.

Paso 2: Bloque transformer

Pre-LN, multi-head self-attention, MLP con GELU, conexiones residuales.

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mlp_ratio, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        a, _ = self.attn(self.ln1(x), self.ln1(x), self.ln1(x), need_weights=False)
        x = x + a
        x = x + self.mlp(self.ln2(x))
        return x

El nn.MultiheadAttention maneja la division en cabezas, el producto punto escalado y la proyeccion de salida. batch_first=True para que los shapes sean (N, seq, dim).

Paso 3: El ViT

class ViT(nn.Module):
    def __init__(self, image_size=64, patch_size=16, in_channels=3,
                 num_classes=10, dim=192, depth=6, num_heads=3, mlp_ratio=4):
        super().__init__()
        self.patch = PatchEmbedding(in_channels, patch_size, dim, image_size)
        num_patches = self.patch.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.blocks = nn.ModuleList([
            Block(dim, num_heads, mlp_ratio) for _ in range(depth)
        ])
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        x = self.patch(x)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        x = self.ln(x[:, 0])
        return self.head(x)

vit = ViT(image_size=64, patch_size=16, num_classes=10, dim=192, depth=6, num_heads=3)
x = torch.randn(2, 3, 64, 64)
print(f"output: {vit(x).shape}")
print(f"params: {sum(p.numel() for p in vit.parameters()):,}")

Cerca de 2,8M de parametros — un ViT diminuto manejable en CPU. El ViT-B real tiene 86M; la misma definicion de clase con dim=768, depth=12, num_heads=12.

Paso 4: Sanity check — inferencia de imagen unica

logits = vit(torch.randn(1, 3, 64, 64))
print(f"logits: {logits}")
print(f"probs:  {logits.softmax(-1)}")

Debe ejecutarse sin error. Las probabilidades suman 1.

Usalo

El timm trae cada variante de ViT con pesos preentrenados en ImageNet. Una linea:

import timm

model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=10)

El timm es el estandar de produccion para vision transformers en 2026. Soporta ViT, DeiT, Swin, Swin-V2, ConvNeXt, ConvNeXt-V2, MaxViT, MViT, EfficientFormer y decenas de otros bajo la misma API.

Para trabajo multimodal (imagen + texto), el transformers trae CLIP, SigLIP, BLIP-2, LLaVA. El encoder de imagen en todos ellos es una variante de ViT.

Entregalo

Esta leccion produce:

  • outputs/prompt-vit-vs-cnn-picker.md — un prompt que elige entre un ViT, un ConvNeXt o un Swin segun el tamano del dataset, el compute y la stack de inferencia.
  • outputs/skill-vit-patch-and-pos-embed-inspector.md — una skill que verifica que los shapes de patch embedding y positional embedding de un ViT coincidan con la longitud de secuencia esperada por el modelo, capturando los bugs de portabilidad mas comunes.

Ejercicios

  1. (Facil) Imprime los shapes de cada tensor intermedio para un forward pass por el ViT diminuto de arriba. Confirma: input (N, 3, 64, 64) -> patches (N, 16, 192) -> con CLS (N, 17, 192) -> input del clasificador (N, 192) -> output (N, num_classes).
  2. (Medio) Haz fine-tune de un ViT-S/16 preentrenado de timm en el dataset CIFAR sintetico de la Leccion 4. Compara con el fine-tune de un ResNet-18 sobre los mismos datos. Reporta el tiempo de entrenamiento y la precision final.
  3. (Dificil) Implementa el preentrenamiento MAE para el ViT diminuto: enmascara el 75% de los patches, entrena el encoder + un pequeno decoder para reconstruir los patches enmascarados. Evalua la precision de linear-probe en los datos sinteticos antes y despues del preentrenamiento.

Terminos Clave

Termino Lo que dice la gente Lo que realmente significa
Patch embedding "La primera conv" Una conv con tamano de kernel = stride = tamano del patch; convierte la imagen en una grilla de token embeddings
Class token "[CLS]" Un vector aprendido antepuesto a la secuencia de tokens; su salida final es la representacion global de la imagen
Positional embedding "Pos aprendido" Un vector aprendido sumado a cada token para que el transformer sepa de donde vino cada patch
Pre-LN "LayerNorm antes del sublayer" La variante estable del transformer: x + sublayer(LN(x)) en vez de LN(x + sublayer(x))
Multi-head attention "Atencion paralela" La atencion estandar del transformer dividida en num_heads subespacios independientes, concatenados despues
ViT-B/16 "Base, patch 16" El tamano canonico: dim=768, depth=12, heads=12, patch_size=16, image=224; ~86M params
DeiT "ViT eficiente en datos" ViT entrenado solo en ImageNet-1k con augmentation fuerte; demostro que los datasets grandes de preentrenamiento no son estrictamente necesarios
MAE "Masked autoencoder" Preentrenamiento auto-supervisado: enmascara el 75% de los patches, reconstruye; la receta dominante de preentrenamiento de ViT

Lecturas Adicionales

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).