Phase 04 - Lesson 05
Transfer Learning y Fine-Tuning
Alguien más invirtió un millón de horas de GPU enseñándole a una red cómo lucen los bordes, las texturas y las partes de los objetos. Deberías tomar prestadas esas features antes de entrenar las tuyas.
Tipo: Build Lenguajes: Python Prerrequisitos: Fase 4 Lección 03 (CNNs), Fase 4 Lección 04 (Clasificación de Imágenes) Tiempo: ~75 minutos
Objetivos de Aprendizaje
- Distinguir la extracción de features del fine-tuning y elegir la opción correcta según el tamaño del dataset, la distancia de dominio y el presupuesto de cómputo
- Cargar un backbone preentrenado, reemplazar su cabeza clasificadora y entrenar solo la cabeza hasta un baseline funcional en menos de 20 líneas
- Descongelar capas progresivamente con learning rates discriminativos, de modo que las features genéricas iniciales reciban actualizaciones menores que las específicas de la tarea al final
- Diagnosticar las tres fallas comunes: feature drift por un LR demasiado alto en bloques descongelados, colapso de las estadísticas de BN en datasets diminutos y olvido catastrófico
El Problema
Entrenar una ResNet-50 en ImageNet cuesta alrededor de 2.000 horas de GPU. Muy pocos equipos tienen ese presupuesto para cada tarea que llevan a producción. Lo que casi todos los equipos efectivamente ponen en producción es un backbone preentrenado con una nueva cabeza entrenada con unos cientos o unos miles de imágenes específicas de la tarea.
Esto no es un atajo. El primer bloque conv de cualquier CNN entrenada en ImageNet aprende bordes y filtros tipo Gabor. Los siguientes bloques aprenden texturas y motivos simples. Los bloques intermedios aprenden partes de objetos. Los bloques finales aprenden combinaciones que empiezan a parecerse a las 1.000 categorías de ImageNet. El primer 90% de esa jerarquía transfiere casi sin cambios a imágenes médicas, inspección industrial, datos satelitales y cualquier otra tarea de visión — porque la naturaleza tiene un vocabulario limitado de bordes y texturas. El último 10% es lo que de verdad entrenas.
Hacer bien la transferencia tiene tres bugs esperándote: destruir features preentrenadas con un learning rate demasiado alto, privar al modelo de información congelando demasiado y dejar que las estadísticas en ejecución de BatchNorm deriven hacia un dataset diminuto del que el resto de la red nunca aprendió. Esta lección recorre cada uno de ellos a propósito.
El Concepto
Extracción de features vs fine-tuning
Dos regímenes, elegidos según cuánto confíes en las features preentrenadas y cuántos datos tengas.
flowchart TB
subgraph FE["Extracción de features — backbone congelado"]
FE1["Backbone preentrenado<br/>(sin gradiente)"] --> FE2["Nueva cabeza<br/>(entrenada)"]
end
subgraph FT["Fine-tuning — de punta a punta"]
FT1["Backbone preentrenado<br/>(LR diminuto)"] --> FT2["Nueva cabeza<br/>(LR normal)"]
end
style FE1 fill:#e5e7eb,stroke:#6b7280
style FE2 fill:#dcfce7,stroke:#16a34a
style FT1 fill:#fef3c7,stroke:#d97706
style FT2 fill:#dcfce7,stroke:#16a34a
Reglas prácticas:
| Tamaño del dataset | Distancia de dominio | Receta |
|---|---|---|
| < 1k imágenes | cercano a ImageNet | Congela el backbone, entrena solo la cabeza |
| 1k-10k | cercano | Congela los primeros 2-3 estadios, haz fine-tune al resto |
| 10k-100k | cualquiera | Fine-tune de punta a punta con LR discriminativo |
| 100k+ | lejano | Haz fine-tune a todo; considera entrenar desde cero si el dominio es lo bastante lejano |
"Cercano a ImageNet" significa, a grandes rasgos, fotos RGB naturales con contenido tipo objeto. Las tomografías médicas, las imágenes satelitales cenitales y la microscopía son dominios lejanos — las features aún ayudan, pero tendrás que dejar que más capas se adapten.
Por qué congelar funciona, después de todo
Las features de ImageNet que aprende una CNN no están especializadas en las 1.000 categorías. Están especializadas en las estadísticas de las imágenes naturales: bordes en orientaciones específicas, texturas, patrones de contraste, primitivas de forma. Esas estadísticas son estables en casi todo dominio visual que un humano pueda nombrar. Por eso un modelo entrenado en ImageNet y evaluado zero-shot en CIFAR-10 con apenas una nueva cabeza lineal (sin fine-tuning del backbone) alcanza 80%+ de exactitud. La cabeza está aprendiendo cuáles de las features ya aprendidas debe ponderar para esta tarea.
Learning rates discriminativos
Cuando efectivamente descongelas, las capas iniciales deben entrenar más lento que las finales. Las capas iniciales codifican features genéricas que quieres preservar; las capas finales codifican estructura específica de la tarea que necesitas mover bastante.
Receta típica:
stage 0 (stem + primer grupo): lr = base_lr / 100 (casi fijo)
stage 1: lr = base_lr / 10
stage 2: lr = base_lr / 3
stage 3 (último grupo backbone): lr = base_lr
head: lr = base_lr (o un poco mayor)
En PyTorch esto es simplemente una lista de grupos de parámetros pasada al optimizador. Un modelo, cinco learning rates, cero código extra.
El problema de BatchNorm
Las capas BN guardan buffers running_mean y running_var que fueron calculados en ImageNet. Si tu tarea tiene una distribución de píxeles distinta — distinta iluminación, distinto sensor, distinto espacio de color — esos buffers están equivocados. Tres opciones en orden de preferencia:
- Fine-tune con BN en modo train. Deja que BN actualice sus estadísticas en ejecución junto con todo lo demás. Opción por defecto cuando el dataset de la tarea es de tamaño medio (>= 5k ejemplos).
- Congela BN en modo eval. Conserva las estadísticas de ImageNet y entrena solo los pesos. Correcto cuando tu dataset es lo bastante pequeño como para que el promedio móvil de BN sea ruidoso.
- Reemplaza BN por GroupNorm. Elimina por completo el problema del promedio móvil. Se usa en backbones de detección y segmentación donde el batch size por GPU es diminuto.
Equivocarse en esto baja silenciosamente la exactitud en un 5-15%.
Diseño de la cabeza
La cabeza clasificadora son 1-3 capas lineales más un dropout opcional. Todo backbone de torchvision trae una cabeza por defecto que reemplazas:
backbone.fc = nn.Linear(backbone.fc.in_features, num_classes) # ResNet
backbone.classifier[1] = nn.Linear(..., num_classes) # EfficientNet, MobileNet
backbone.heads.head = nn.Linear(..., num_classes) # torchvision ViT
Para datasets pequeños, una sola capa lineal suele bastar. Agregar una capa oculta (Linear -> ReLU -> Dropout -> Linear) ayuda cuando la distribución de la tarea está más lejos de la distribución de entrenamiento del backbone.
Decaimiento de LR por capa
Una versión más suave del LR discriminativo, usada en fine-tuning moderno (BEiT, DINOv2, fine-tunes de ViT-B). En lugar de agrupar capas en estadios, dale a cada capa un LR ligeramente menor que el de la capa de encima:
lr_layer_k = base_lr * decay^(L - k)
Con decay = 0.75 y L = 12 bloques de transformer, el primer bloque entrena a 0.75^11 ≈ 0.04x del LR de la cabeza. Importa más para los fine-tunes de transformer que para las CNNs, donde los LR agrupados por estadio suelen bastar.
Qué evaluar
Las corridas de transfer learning necesitan dos números que no rastrearías en una corrida desde cero:
- Exactitud solo preentrenada — la exactitud de la cabeza con el backbone congelado. Este es tu piso.
- Exactitud con fine-tune — el mismo modelo tras el entrenamiento de punta a punta. Este es tu techo.
Si la exactitud con fine-tune es menor que la solo preentrenada, tienes un bug de learning rate o de BN. Imprime siempre ambas.
Constrúyelo
Paso 1: Cargar un backbone preentrenado e inspeccionarlo
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
print(backbone)
print()
print("classifier head:", backbone.fc)
print("feature dim:", backbone.fc.in_features)
La ResNet18 tiene cuatro estadios (layer1..layer4) más un stem y una cabeza fc. Todo backbone de clasificación de torchvision tiene una estructura análoga.
Paso 2: Extracción de features — congela todo, reemplaza la cabeza
def make_feature_extractor(num_classes=10):
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
for p in model.parameters():
p.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, num_classes)
return model
model = make_feature_extractor(num_classes=10)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable:>10,}")
print(f"frozen: {frozen:>10,}")
Solo model.fc es entrenable. El backbone es un extractor de features congelado.
Paso 3: Fine-tuning discriminativo
Una utilidad que construye grupos de parámetros con learning rates específicos por estadio.
def discriminative_param_groups(model, base_lr=1e-3, decay=0.3):
stages = [
["conv1", "bn1"],
["layer1"],
["layer2"],
["layer3"],
["layer4"],
["fc"],
]
groups = []
for i, names in enumerate(stages):
lr = base_lr * (decay ** (len(stages) - 1 - i))
params = [p for n, p in model.named_parameters()
if any(n.startswith(k) for k in names)]
if params:
groups.append({"params": params, "lr": lr, "name": "_".join(names)})
return groups
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 10)
for p in model.parameters():
p.requires_grad = True
groups = discriminative_param_groups(model)
for g in groups:
print(f"{g['name']:>10s} lr={g['lr']:.2e} params={sum(p.numel() for p in g['params']):>8,}")
decay=0.3 significa que cada estadio entrena al 30% de la tasa del siguiente. fc recibe base_lr, layer4 recibe 0.3 * base_lr, conv1 recibe 0.3^5 * base_lr ≈ 0.00243 * base_lr. Suena extremo; empíricamente funciona.
Paso 4: Manejo de BatchNorm
Helper para congelar las estadísticas en ejecución de BN sin congelar sus pesos.
def freeze_bn_stats(model):
for m in model.modules():
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
m.eval()
for p in m.parameters():
p.requires_grad = False
return model
Llámalo después de definir model.train() al inicio de cada epoch. model.train() pone todo en modo de entrenamiento; esto lo revierte solo para las capas BN.
Paso 5: Un loop mínimo de fine-tuning de punta a punta
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
def fine_tune(model, train_loader, val_loader, device, epochs=5, base_lr=1e-3, freeze_bn=False):
model = model.to(device)
groups = discriminative_param_groups(model, base_lr=base_lr)
optimizer = SGD(groups, momentum=0.9, weight_decay=1e-4, nesterov=True)
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
for epoch in range(epochs):
model.train()
if freeze_bn:
freeze_bn_stats(model)
tr_loss, tr_correct, tr_total = 0.0, 0, 0
for x, y in train_loader:
x, y = x.to(device), y.to(device)
logits = model(x)
loss = F.cross_entropy(logits, y, label_smoothing=0.1)
optimizer.zero_grad()
loss.backward()
optimizer.step()
tr_loss += loss.item() * x.size(0)
tr_total += x.size(0)
tr_correct += (logits.argmax(-1) == y).sum().item()
scheduler.step()
model.eval()
va_total, va_correct = 0, 0
with torch.no_grad():
for x, y in val_loader:
x, y = x.to(device), y.to(device)
pred = model(x).argmax(-1)
va_total += x.size(0)
va_correct += (pred == y).sum().item()
print(f"epoch {epoch} train {tr_loss/tr_total:.3f}/{tr_correct/tr_total:.3f} "
f"val {va_correct/va_total:.3f}")
return model
Cinco epochs con la receta anterior en CIFAR-10 llevan a la ResNet18-IMAGENET1K_V1 de ~70% de exactitud zero-shot en linear-probe a ~93% de exactitud con fine-tune. La cabeza por sí sola se estancaría alrededor del 86% sin tocar nunca el backbone.
Paso 6: Descongelamiento progresivo
Un cronograma que descongela un estadio por epoch, del final hacia el comienzo. Mitiga el feature drift a costa de algunos epochs extra.
def progressive_unfreeze_schedule(model):
stages = ["layer4", "layer3", "layer2", "layer1"]
yielded = set()
def start():
for p in model.parameters():
p.requires_grad = False
for p in model.fc.parameters():
p.requires_grad = True
def unfreeze(epoch):
if epoch < len(stages):
name = stages[epoch]
yielded.add(name)
for n, p in model.named_parameters():
if n.startswith(name):
p.requires_grad = True
return name
return None
return start, unfreeze
Llama a start() una vez antes del primer epoch. Llama a unfreeze(epoch) al inicio de cada epoch. Reconstruye el optimizador cada vez que cambie el conjunto de parámetros entrenables; de lo contrario, los params congelados aún conservan momentos en caché que lo confunden.
Úsalo
Para la mayoría de las tareas reales, torchvision.models + tres líneas basta. La maquinaria más pesada de arriba importa cuando te topas con los problemas que los defaults de la biblioteca no pueden resolver.
from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
model.fc = nn.Linear(model.fc.in_features, num_classes)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
Otros dos defaults de nivel de producción:
timmtrae ~800 backbones de visión preentrenados con una API consistente (timm.create_model("resnet50", pretrained=True, num_classes=10)). Para cualquier fine-tune más allá del zoo de torchvision, es el estándar.- Para transformers,
transformers.AutoModelForImageClassification.from_pretrained(name, num_labels=N)te da ViT / BEiT / DeiT con la misma semántica de carga que los modelos de texto.
Entrégalo
Esta lección produce:
outputs/prompt-fine-tune-planner.md— un prompt que elige entre extracción de features, fine-tuning progresivo o de punta a punta según el tamaño del dataset, la distancia de dominio y el presupuesto de cómputo.outputs/skill-freeze-inspector.md— una skill que, dado un modelo PyTorch, reporta qué parámetros son entrenables, qué capas BatchNorm están en modo eval y si el optimizador efectivamente está recibiendo los parámetros entrenables.
Ejercicios
- (Fácil) Entrena una
ResNet18como linear probe (backbone congelado) y como full fine-tune sobre el mismo dataset CIFAR sintético. Reporta ambas exactitudes lado a lado. Explica qué brecha te dice que las features transfieren bien y cuál te dice que no. - (Medio) Introduce un bug a propósito: define
base_lr = 1e-1en el estadio del backbone en lugar de la cabeza. Muestra cómo la loss de entrenamiento explota, luego recupérala aplicando el helperdiscriminative_param_groups. Registra el LR en el que cada estadio empieza a divergir. - (Difícil) Toma un dataset de imágenes médicas (p. ej. CheXpert-small, PatchCamelyon o HAM10000) y compara tres regímenes: (a) backbone preentrenado en ImageNet congelado + cabeza lineal; (b) fine-tune de punta a punta del preentrenado en ImageNet; (c) entrenamiento desde cero. Reporta la exactitud y el costo de cómputo de cada uno. ¿A partir de qué tamaño de dataset el entrenamiento desde cero se vuelve competitivo?
Términos Clave
| Término | Lo que dice la gente | Lo que realmente significa |
|---|---|---|
| Extracción de features | "Congelar y entrenar la cabeza" | Parámetros del backbone congelados, solo la nueva cabeza clasificadora recibe gradiente |
| Fine-tuning | "Reentrenar de punta a punta" | Todos los parámetros entrenables, normalmente con un LR mucho menor que en el entrenamiento desde cero |
| LR discriminativo | "LR menor para las capas iniciales" | Grupos de parámetros del optimizador donde el LR de los estadios iniciales es una fracción del LR de los estadios finales |
| Decaimiento de LR por capa | "Gradiente suave de LR" | LR por capa multiplicado por decay^(L - k); común en los fine-tunes de transformer |
| Olvido catastrófico | "El modelo perdió ImageNet" | Un LR demasiado alto sobrescribe las features preentrenadas antes de que se aprenda la señal de la nueva tarea |
| Drift de las estadísticas de BN | "La running mean está mal" | running_mean/var de BatchNorm calculados en una distribución distinta de la tarea actual, perjudicando silenciosamente la exactitud |
| Linear probe | "Backbone congelado + cabeza lineal" | Evaluación de las features preentrenadas — exactitud del mejor clasificador lineal sobre la representación congelada |
| Colapso catastrófico | "Todo predice una sola clase" | Ocurre al hacer fine-tune con un LR lo bastante alto como para destruir las features antes de que los gradientes de la cabeza logren estabilizarse |
Lecturas Adicionales
- How transferable are features in deep neural networks? (Yosinski et al., 2014) — el paper que cuantificó la transferibilidad de las features entre capas
- Universal Language Model Fine-tuning (ULMFiT, Howard & Ruder, 2018) — la receta original de LR discriminativo / descongelamiento progresivo; las ideas transfieren directamente a visión
- timm documentation — la referencia para los backbones de visión modernos y los defaults exactos de fine-tune con los que fueron entrenados
- A Simple Framework for Linear-Probe Evaluation (Kornblith et al., 2019) — por qué importa la exactitud de linear-probe y cómo reportarla correctamente