Phase 04 - Lesson 05

Transfer Learning y Fine-Tuning

Alguien más invirtió un millón de horas de GPU enseñándole a una red cómo lucen los bordes, las texturas y las partes de los objetos. Deberías tomar prestadas esas features antes de entrenar las tuyas.

Tipo: Build Lenguajes: Python Prerrequisitos: Fase 4 Lección 03 (CNNs), Fase 4 Lección 04 (Clasificación de Imágenes) Tiempo: ~75 minutos

Objetivos de Aprendizaje

  • Distinguir la extracción de features del fine-tuning y elegir la opción correcta según el tamaño del dataset, la distancia de dominio y el presupuesto de cómputo
  • Cargar un backbone preentrenado, reemplazar su cabeza clasificadora y entrenar solo la cabeza hasta un baseline funcional en menos de 20 líneas
  • Descongelar capas progresivamente con learning rates discriminativos, de modo que las features genéricas iniciales reciban actualizaciones menores que las específicas de la tarea al final
  • Diagnosticar las tres fallas comunes: feature drift por un LR demasiado alto en bloques descongelados, colapso de las estadísticas de BN en datasets diminutos y olvido catastrófico

El Problema

Entrenar una ResNet-50 en ImageNet cuesta alrededor de 2.000 horas de GPU. Muy pocos equipos tienen ese presupuesto para cada tarea que llevan a producción. Lo que casi todos los equipos efectivamente ponen en producción es un backbone preentrenado con una nueva cabeza entrenada con unos cientos o unos miles de imágenes específicas de la tarea.

Esto no es un atajo. El primer bloque conv de cualquier CNN entrenada en ImageNet aprende bordes y filtros tipo Gabor. Los siguientes bloques aprenden texturas y motivos simples. Los bloques intermedios aprenden partes de objetos. Los bloques finales aprenden combinaciones que empiezan a parecerse a las 1.000 categorías de ImageNet. El primer 90% de esa jerarquía transfiere casi sin cambios a imágenes médicas, inspección industrial, datos satelitales y cualquier otra tarea de visión — porque la naturaleza tiene un vocabulario limitado de bordes y texturas. El último 10% es lo que de verdad entrenas.

Hacer bien la transferencia tiene tres bugs esperándote: destruir features preentrenadas con un learning rate demasiado alto, privar al modelo de información congelando demasiado y dejar que las estadísticas en ejecución de BatchNorm deriven hacia un dataset diminuto del que el resto de la red nunca aprendió. Esta lección recorre cada uno de ellos a propósito.

El Concepto

Extracción de features vs fine-tuning

Dos regímenes, elegidos según cuánto confíes en las features preentrenadas y cuántos datos tengas.

flowchart TB
    subgraph FE["Extracción de features — backbone congelado"]
        FE1["Backbone preentrenado<br/>(sin gradiente)"] --> FE2["Nueva cabeza<br/>(entrenada)"]
    end
    subgraph FT["Fine-tuning — de punta a punta"]
        FT1["Backbone preentrenado<br/>(LR diminuto)"] --> FT2["Nueva cabeza<br/>(LR normal)"]
    end

    style FE1 fill:#e5e7eb,stroke:#6b7280
    style FE2 fill:#dcfce7,stroke:#16a34a
    style FT1 fill:#fef3c7,stroke:#d97706
    style FT2 fill:#dcfce7,stroke:#16a34a

Reglas prácticas:

Tamaño del dataset Distancia de dominio Receta
< 1k imágenes cercano a ImageNet Congela el backbone, entrena solo la cabeza
1k-10k cercano Congela los primeros 2-3 estadios, haz fine-tune al resto
10k-100k cualquiera Fine-tune de punta a punta con LR discriminativo
100k+ lejano Haz fine-tune a todo; considera entrenar desde cero si el dominio es lo bastante lejano

"Cercano a ImageNet" significa, a grandes rasgos, fotos RGB naturales con contenido tipo objeto. Las tomografías médicas, las imágenes satelitales cenitales y la microscopía son dominios lejanos — las features aún ayudan, pero tendrás que dejar que más capas se adapten.

Por qué congelar funciona, después de todo

Las features de ImageNet que aprende una CNN no están especializadas en las 1.000 categorías. Están especializadas en las estadísticas de las imágenes naturales: bordes en orientaciones específicas, texturas, patrones de contraste, primitivas de forma. Esas estadísticas son estables en casi todo dominio visual que un humano pueda nombrar. Por eso un modelo entrenado en ImageNet y evaluado zero-shot en CIFAR-10 con apenas una nueva cabeza lineal (sin fine-tuning del backbone) alcanza 80%+ de exactitud. La cabeza está aprendiendo cuáles de las features ya aprendidas debe ponderar para esta tarea.

Learning rates discriminativos

Cuando efectivamente descongelas, las capas iniciales deben entrenar más lento que las finales. Las capas iniciales codifican features genéricas que quieres preservar; las capas finales codifican estructura específica de la tarea que necesitas mover bastante.

Receta típica:

  stage 0 (stem + primer grupo): lr = base_lr / 100    (casi fijo)
  stage 1:                       lr = base_lr / 10
  stage 2:                       lr = base_lr / 3
  stage 3 (último grupo backbone): lr = base_lr
  head:                          lr = base_lr  (o un poco mayor)

En PyTorch esto es simplemente una lista de grupos de parámetros pasada al optimizador. Un modelo, cinco learning rates, cero código extra.

El problema de BatchNorm

Las capas BN guardan buffers running_mean y running_var que fueron calculados en ImageNet. Si tu tarea tiene una distribución de píxeles distinta — distinta iluminación, distinto sensor, distinto espacio de color — esos buffers están equivocados. Tres opciones en orden de preferencia:

  1. Fine-tune con BN en modo train. Deja que BN actualice sus estadísticas en ejecución junto con todo lo demás. Opción por defecto cuando el dataset de la tarea es de tamaño medio (>= 5k ejemplos).
  2. Congela BN en modo eval. Conserva las estadísticas de ImageNet y entrena solo los pesos. Correcto cuando tu dataset es lo bastante pequeño como para que el promedio móvil de BN sea ruidoso.
  3. Reemplaza BN por GroupNorm. Elimina por completo el problema del promedio móvil. Se usa en backbones de detección y segmentación donde el batch size por GPU es diminuto.

Equivocarse en esto baja silenciosamente la exactitud en un 5-15%.

Diseño de la cabeza

La cabeza clasificadora son 1-3 capas lineales más un dropout opcional. Todo backbone de torchvision trae una cabeza por defecto que reemplazas:

backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)          # ResNet
backbone.classifier[1] = nn.Linear(..., num_classes)                    # EfficientNet, MobileNet
backbone.heads.head = nn.Linear(..., num_classes)                       # torchvision ViT

Para datasets pequeños, una sola capa lineal suele bastar. Agregar una capa oculta (Linear -> ReLU -> Dropout -> Linear) ayuda cuando la distribución de la tarea está más lejos de la distribución de entrenamiento del backbone.

Decaimiento de LR por capa

Una versión más suave del LR discriminativo, usada en fine-tuning moderno (BEiT, DINOv2, fine-tunes de ViT-B). En lugar de agrupar capas en estadios, dale a cada capa un LR ligeramente menor que el de la capa de encima:

lr_layer_k = base_lr * decay^(L - k)

Con decay = 0.75 y L = 12 bloques de transformer, el primer bloque entrena a 0.75^11 ≈ 0.04x del LR de la cabeza. Importa más para los fine-tunes de transformer que para las CNNs, donde los LR agrupados por estadio suelen bastar.

Qué evaluar

Las corridas de transfer learning necesitan dos números que no rastrearías en una corrida desde cero:

  • Exactitud solo preentrenada — la exactitud de la cabeza con el backbone congelado. Este es tu piso.
  • Exactitud con fine-tune — el mismo modelo tras el entrenamiento de punta a punta. Este es tu techo.

Si la exactitud con fine-tune es menor que la solo preentrenada, tienes un bug de learning rate o de BN. Imprime siempre ambas.

Constrúyelo

Paso 1: Cargar un backbone preentrenado e inspeccionarlo

import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
print(backbone)
print()
print("classifier head:", backbone.fc)
print("feature dim:", backbone.fc.in_features)

La ResNet18 tiene cuatro estadios (layer1..layer4) más un stem y una cabeza fc. Todo backbone de clasificación de torchvision tiene una estructura análoga.

Paso 2: Extracción de features — congela todo, reemplaza la cabeza

def make_feature_extractor(num_classes=10):
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    for p in model.parameters():
        p.requires_grad = False
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

model = make_feature_extractor(num_classes=10)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable:>10,}")
print(f"frozen:    {frozen:>10,}")

Solo model.fc es entrenable. El backbone es un extractor de features congelado.

Paso 3: Fine-tuning discriminativo

Una utilidad que construye grupos de parámetros con learning rates específicos por estadio.

def discriminative_param_groups(model, base_lr=1e-3, decay=0.3):
    stages = [
        ["conv1", "bn1"],
        ["layer1"],
        ["layer2"],
        ["layer3"],
        ["layer4"],
        ["fc"],
    ]
    groups = []
    for i, names in enumerate(stages):
        lr = base_lr * (decay ** (len(stages) - 1 - i))
        params = [p for n, p in model.named_parameters()
                  if any(n.startswith(k) for k in names)]
        if params:
            groups.append({"params": params, "lr": lr, "name": "_".join(names)})
    return groups

model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 10)
for p in model.parameters():
    p.requires_grad = True

groups = discriminative_param_groups(model)
for g in groups:
    print(f"{g['name']:>10s}  lr={g['lr']:.2e}  params={sum(p.numel() for p in g['params']):>8,}")

decay=0.3 significa que cada estadio entrena al 30% de la tasa del siguiente. fc recibe base_lr, layer4 recibe 0.3 * base_lr, conv1 recibe 0.3^5 * base_lr ≈ 0.00243 * base_lr. Suena extremo; empíricamente funciona.

Paso 4: Manejo de BatchNorm

Helper para congelar las estadísticas en ejecución de BN sin congelar sus pesos.

def freeze_bn_stats(model):
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()
            for p in m.parameters():
                p.requires_grad = False
    return model

Llámalo después de definir model.train() al inicio de cada epoch. model.train() pone todo en modo de entrenamiento; esto lo revierte solo para las capas BN.

Paso 5: Un loop mínimo de fine-tuning de punta a punta

from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F

def fine_tune(model, train_loader, val_loader, device, epochs=5, base_lr=1e-3, freeze_bn=False):
    model = model.to(device)
    groups = discriminative_param_groups(model, base_lr=base_lr)
    optimizer = SGD(groups, momentum=0.9, weight_decay=1e-4, nesterov=True)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

    for epoch in range(epochs):
        model.train()
        if freeze_bn:
            freeze_bn_stats(model)
        tr_loss, tr_correct, tr_total = 0.0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y, label_smoothing=0.1)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            tr_loss += loss.item() * x.size(0)
            tr_total += x.size(0)
            tr_correct += (logits.argmax(-1) == y).sum().item()
        scheduler.step()

        model.eval()
        va_total, va_correct = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(-1)
                va_total += x.size(0)
                va_correct += (pred == y).sum().item()
        print(f"epoch {epoch}  train {tr_loss/tr_total:.3f}/{tr_correct/tr_total:.3f}  "
              f"val {va_correct/va_total:.3f}")
    return model

Cinco epochs con la receta anterior en CIFAR-10 llevan a la ResNet18-IMAGENET1K_V1 de ~70% de exactitud zero-shot en linear-probe a ~93% de exactitud con fine-tune. La cabeza por sí sola se estancaría alrededor del 86% sin tocar nunca el backbone.

Paso 6: Descongelamiento progresivo

Un cronograma que descongela un estadio por epoch, del final hacia el comienzo. Mitiga el feature drift a costa de algunos epochs extra.

def progressive_unfreeze_schedule(model):
    stages = ["layer4", "layer3", "layer2", "layer1"]
    yielded = set()

    def start():
        for p in model.parameters():
            p.requires_grad = False
        for p in model.fc.parameters():
            p.requires_grad = True

    def unfreeze(epoch):
        if epoch < len(stages):
            name = stages[epoch]
            yielded.add(name)
            for n, p in model.named_parameters():
                if n.startswith(name):
                    p.requires_grad = True
            return name
        return None

    return start, unfreeze

Llama a start() una vez antes del primer epoch. Llama a unfreeze(epoch) al inicio de cada epoch. Reconstruye el optimizador cada vez que cambie el conjunto de parámetros entrenables; de lo contrario, los params congelados aún conservan momentos en caché que lo confunden.

Úsalo

Para la mayoría de las tareas reales, torchvision.models + tres líneas basta. La maquinaria más pesada de arriba importa cuando te topas con los problemas que los defaults de la biblioteca no pueden resolver.

from torchvision.models import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
model.fc = nn.Linear(model.fc.in_features, num_classes)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

Otros dos defaults de nivel de producción:

  • timm trae ~800 backbones de visión preentrenados con una API consistente (timm.create_model("resnet50", pretrained=True, num_classes=10)). Para cualquier fine-tune más allá del zoo de torchvision, es el estándar.
  • Para transformers, transformers.AutoModelForImageClassification.from_pretrained(name, num_labels=N) te da ViT / BEiT / DeiT con la misma semántica de carga que los modelos de texto.

Entrégalo

Esta lección produce:

  • outputs/prompt-fine-tune-planner.md — un prompt que elige entre extracción de features, fine-tuning progresivo o de punta a punta según el tamaño del dataset, la distancia de dominio y el presupuesto de cómputo.
  • outputs/skill-freeze-inspector.md — una skill que, dado un modelo PyTorch, reporta qué parámetros son entrenables, qué capas BatchNorm están en modo eval y si el optimizador efectivamente está recibiendo los parámetros entrenables.

Ejercicios

  1. (Fácil) Entrena una ResNet18 como linear probe (backbone congelado) y como full fine-tune sobre el mismo dataset CIFAR sintético. Reporta ambas exactitudes lado a lado. Explica qué brecha te dice que las features transfieren bien y cuál te dice que no.
  2. (Medio) Introduce un bug a propósito: define base_lr = 1e-1 en el estadio del backbone en lugar de la cabeza. Muestra cómo la loss de entrenamiento explota, luego recupérala aplicando el helper discriminative_param_groups. Registra el LR en el que cada estadio empieza a divergir.
  3. (Difícil) Toma un dataset de imágenes médicas (p. ej. CheXpert-small, PatchCamelyon o HAM10000) y compara tres regímenes: (a) backbone preentrenado en ImageNet congelado + cabeza lineal; (b) fine-tune de punta a punta del preentrenado en ImageNet; (c) entrenamiento desde cero. Reporta la exactitud y el costo de cómputo de cada uno. ¿A partir de qué tamaño de dataset el entrenamiento desde cero se vuelve competitivo?

Términos Clave

Término Lo que dice la gente Lo que realmente significa
Extracción de features "Congelar y entrenar la cabeza" Parámetros del backbone congelados, solo la nueva cabeza clasificadora recibe gradiente
Fine-tuning "Reentrenar de punta a punta" Todos los parámetros entrenables, normalmente con un LR mucho menor que en el entrenamiento desde cero
LR discriminativo "LR menor para las capas iniciales" Grupos de parámetros del optimizador donde el LR de los estadios iniciales es una fracción del LR de los estadios finales
Decaimiento de LR por capa "Gradiente suave de LR" LR por capa multiplicado por decay^(L - k); común en los fine-tunes de transformer
Olvido catastrófico "El modelo perdió ImageNet" Un LR demasiado alto sobrescribe las features preentrenadas antes de que se aprenda la señal de la nueva tarea
Drift de las estadísticas de BN "La running mean está mal" running_mean/var de BatchNorm calculados en una distribución distinta de la tarea actual, perjudicando silenciosamente la exactitud
Linear probe "Backbone congelado + cabeza lineal" Evaluación de las features preentrenadas — exactitud del mejor clasificador lineal sobre la representación congelada
Colapso catastrófico "Todo predice una sola clase" Ocurre al hacer fine-tune con un LR lo bastante alto como para destruir las features antes de que los gradientes de la cabeza logren estabilizarse

Lecturas Adicionales

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).