Phase 03 - Lesson 13
Depuración de Redes Neuronales
Tu red compiló. Se ejecutó. Produjo un número. El número está mal y nada falló. Bienvenido al tipo más difícil de depuración: aquel en el que no hay ningún mensaje de error.
Tipo: Práctica Lenguajes: Python, PyTorch Requisitos previos: Fase 03 Lecciones 01-10 (especialmente retropropagación, funciones de pérdida, optimizadores) Tiempo: ~90 minutos
Objetivos de Aprendizaje
- Diagnosticar fallas comunes de redes neuronales (pérdida NaN, curva de pérdida plana, sobreajuste, oscilación) usando estrategias sistemáticas de depuración
- Aplicar la técnica de "sobreajustar un batch" para verificar que la arquitectura de tu modelo y el bucle de entrenamiento sean correctos
- Inspeccionar magnitudes de gradiente, distribuciones de activación y normas de pesos para identificar problemas de gradientes que se desvanecen/explotan
- Construir una lista de verificación de depuración que cubra problemas en el pipeline de datos, la arquitectura del modelo, la función de pérdida, el optimizador y la tasa de aprendizaje
El Problema
El software tradicional falla cuando está roto. Un puntero nulo lanza una excepción. Una incompatibilidad de tipos falla en tiempo de compilación. Un error de "off-by-one" produce una salida claramente incorrecta.
Las redes neuronales no te dan ese lujo.
Una red neuronal rota se ejecuta hasta el final, imprime un valor de pérdida y produce predicciones. La pérdida hasta podría disminuir. Las predicciones hasta podrían parecer plausibles. Pero el modelo está silenciosamente equivocado: aprendiendo atajos, memorizando ruido o convergiendo a un mínimo local inútil. Investigadores de Google estimaron que entre el 60 y el 70% del tiempo de depuración de ML se gasta en bugs "silenciosos" que no producen errores pero degradan la calidad del modelo.
La diferencia entre un modelo que funciona y uno roto suele ser una sola línea mal ubicada: un zero_grad() faltante, una dimensión transpuesta, una tasa de aprendizaje equivocada por un factor de 10x. El canónico "Recipe for Training Neural Networks" (2019) abre con esto: "Los errores más comunes en redes neuronales son bugs que no fallan."
Esta lección te enseña a encontrar esos bugs.
El Concepto
La Mentalidad de Depuración
Olvídate de la depuración de imprimir-y-rezar. Depurar redes neuronales exige un enfoque sistemático porque el ciclo de retroalimentación es lento (de minutos a horas por ejecución de entrenamiento) y los síntomas son ambiguos (una pérdida mala podría significar 20 cosas distintas).
La regla de oro: comienza simple, agrega complejidad una pieza a la vez y verifica cada pieza de forma independiente.
flowchart TD
A["La pérdida no disminuye"] --> B{"Verificar tasa de aprendizaje"}
B -->|"Muy alta"| C["La pérdida oscila o explota"]
B -->|"Muy baja"| D["La pérdida apenas se mueve"]
B -->|"Razonable"| E{"Verificar gradientes"}
E -->|"Todos cero"| F["ReLUs muertas o gradientes que se desvanecen"]
E -->|"NaN/Inf"| G["Gradientes que explotan"]
E -->|"Normal"| H{"Verificar pipeline de datos"}
H -->|"Etiquetas mezcladas"| I["Precisión de azar"]
H -->|"Bug de preprocesamiento"| J["El modelo aprende ruido"]
H -->|"Los datos están bien"| K{"Verificar arquitectura"}
K -->|"Demasiado pequeña"| L["Subajuste"]
K -->|"Demasiado profunda"| M["Dificultad de optimización"]
Síntoma 1: La Pérdida No Disminuye
Esta es la queja más común. El bucle de entrenamiento se ejecuta, las épocas pasan, y la pérdida se mantiene plana u oscila descontroladamente.
Tasa de aprendizaje equivocada. Muy alta: la pérdida oscila o salta a NaN. Muy baja: la pérdida disminuye tan lentamente que parece plana. Para Adam, comienza en 1e-3. Para SGD, comienza en 1e-1 o 1e-2. Siempre prueba 3 tasas de aprendizaje que abarquen 10x cada una (por ejemplo, 1e-2, 1e-3, 1e-4) antes de concluir que algo más está mal.
ReLUs muertas. Si una neurona ReLU recibe una entrada negativa grande, produce 0 y su gradiente es 0. Nunca vuelve a activarse. Si mueren suficientes neuronas, la red no puede aprender. Verifica: imprime la fracción de activaciones que son exactamente 0 después de cada capa ReLU. Si >50% están muertas, cambia a LeakyReLU o reduce la tasa de aprendizaje.
Gradientes que se desvanecen. En redes profundas con activaciones sigmoid o tanh, los gradientes se encogen exponencialmente a medida que se propagan hacia atrás. Para cuando llegan a la primera capa, están en ~0. Las primeras capas dejan de aprender. Solución: usa ReLU/GELU, agrega conexiones residuales o usa normalización por batch.
Gradientes que explotan. El problema opuesto: los gradientes crecen exponencialmente. Común en RNN y redes muy profundas. La pérdida salta a NaN. Solución: recorte de gradiente (torch.nn.utils.clip_grad_norm_), reduce la tasa de aprendizaje o agrega normalización.
Síntoma 2: La Pérdida Disminuye Pero el Modelo es Malo
La pérdida baja. La precisión de entrenamiento llega al 99%. Pero la precisión de prueba es 55%. O el modelo produce salidas sin sentido con datos reales.
Sobreajuste. El modelo memoriza los datos de entrenamiento en lugar de aprender patrones. La brecha entre las pérdidas de entrenamiento y validación crece con el tiempo. Solución: más datos, dropout, decaimiento de pesos, detención temprana, aumento de datos.
Fuga de datos. Datos de prueba se filtraron al entrenamiento. La precisión es sospechosamente alta. Causas comunes: mezclar antes de dividir, preprocesar con estadísticas del conjunto de datos completo, muestras duplicadas entre las divisiones. Solución: divide primero, preprocesa después, verifica duplicados.
Errores de etiqueta. El 5-10% de las etiquetas en la mayoría de los conjuntos de datos reales están mal (Northcutt et al., 2021 -- "Pervasive Label Errors in Test Sets"). El modelo aprende el ruido. Solución: usa aprendizaje confiado (confident learning) para encontrar y corregir ejemplos mal etiquetados, o usa truncamiento de pérdida para ignorar muestras de pérdida alta.
Síntoma 3: NaN o Inf en la Pérdida
El valor de la pérdida se convierte en nan o inf. El entrenamiento está muerto.
Tasa de aprendizaje demasiado alta. Las actualizaciones de gradiente se exceden tanto que los pesos explotan. Solución: reduce por 10x.
log(0) o log(negativo). La pérdida de entropía cruzada calcula log(p). Si tu modelo produce exactamente 0 o una probabilidad negativa, el log explota. Solución: limita las predicciones a [eps, 1-eps] donde eps=1e-7.
División por cero. La normalización por batch divide por la desviación estándar. Un batch con valores constantes tiene std=0. Solución: agrega epsilon al denominador (PyTorch lo hace por defecto, pero las implementaciones personalizadas podrían no hacerlo).
Desbordamiento numérico. Activaciones grandes ingresadas en exp() producen Inf. El softmax es especialmente propenso. Solución: resta el máximo antes de exponenciar (el truco del log-sum-exp).
Técnica 1: Verificación de Gradiente
Compara tus gradientes analíticos (de la retropropagación) con los gradientes numéricos (de diferencias finitas). Si difieren, tu paso hacia atrás (backward pass) tiene un bug.
Gradiente numérico para el parámetro w:
grad_numerical = (loss(w + eps) - loss(w - eps)) / (2 * eps)
Métrica de concordancia (diferencia relativa):
rel_diff = |grad_analytical - grad_numerical| / max(|grad_analytical|, |grad_numerical|, 1e-8)
Si rel_diff < 1e-5: correcto. Si rel_diff > 1e-3: casi con certeza un bug.
flowchart LR
A["Parámetro w"] --> B["w + eps"]
A --> C["w - eps"]
B --> D["Paso hacia adelante"]
C --> E["Paso hacia adelante"]
D --> F["loss+"]
E --> G["loss-"]
F --> H["(loss+ - loss-) / 2eps"]
G --> H
H --> I["Comparar con gradiente de retropropagación"]
Técnica 2: Estadísticas de Activación
Monitorea la media y la desviación estándar de las activaciones después de cada capa durante el entrenamiento. Las redes saludables mantienen activaciones con media cercana a 0 y std cercano a 1 (después de la normalización) o al menos acotadas.
| Indicador de salud | Media | Std | Diagnóstico |
|---|---|---|---|
| Saludable | ~0 | ~1 | La red está aprendiendo normalmente |
| Saturada | >>0 o <<0 | ~0 | Activaciones atascadas en valores extremos |
| Muerta | 0 | 0 | Las neuronas están muertas (todas en cero) |
| Explotando | >>10 | >>10 | Activaciones creciendo sin límite |
Técnica 3: Visualización del Flujo de Gradiente
Grafica la magnitud promedio del gradiente para cada capa. En una red saludable, las magnitudes de gradiente deberían ser aproximadamente similares entre las capas. Si las primeras capas tienen gradientes 1000x menores que las capas posteriores, tienes gradientes que se desvanecen.
graph LR
subgraph "Flujo de Gradiente Saludable"
L1["Capa 1<br/>grad: 0.05"] --- L2["Capa 2<br/>grad: 0.04"] --- L3["Capa 3<br/>grad: 0.06"] --- L4["Capa 4<br/>grad: 0.05"]
end
graph LR
subgraph "Flujo de Gradiente que se Desvanece"
V1["Capa 1<br/>grad: 0.0001"] --- V2["Capa 2<br/>grad: 0.003"] --- V3["Capa 3<br/>grad: 0.02"] --- V4["Capa 4<br/>grad: 0.08"]
end
Técnica 4: La Prueba de Sobreajustar un Batch
La técnica de depuración más importante en deep learning.
Toma un batch pequeño (8-32 muestras). Entrena con él durante más de 100 iteraciones. La pérdida debería llegar a casi cero y la precisión de entrenamiento debería alcanzar el 100%. Si no lo hace, tu modelo o bucle de entrenamiento tiene un bug fundamental: no procedas al entrenamiento completo.
Esta prueba detecta:
- Funciones de pérdida rotas
- Pasos hacia atrás (backward passes) rotos
- Arquitectura demasiado pequeña para representar los datos
- Optimizador no conectado a los parámetros del modelo
- Datos y etiquetas desalineados
Esto toma 30 segundos en ejecutarse y ahorra horas de depuración de ejecuciones completas de entrenamiento.
Técnica 5: Buscador de Tasa de Aprendizaje
Leslie Smith (2017) propuso barrer la tasa de aprendizaje desde muy pequeña (1e-7) hasta muy grande (10) a lo largo de una época mientras se registra la pérdida. Grafica la pérdida frente a la tasa de aprendizaje. La tasa de aprendizaje óptima es aproximadamente 10x menor que la tasa en la que la pérdida comienza a disminuir más rápido.
graph TD
subgraph "Gráfico del Buscador de LR"
direction LR
A["1e-7: loss=2.3"] --> B["1e-5: loss=2.3"]
B --> C["1e-3: loss=1.8"]
C --> D["1e-2: loss=0.9 -- más pronunciada"]
D --> E["1e-1: loss=0.5"]
E --> F["1.0: loss=NaN -- demasiado alta"]
end
Mejor LR en este ejemplo: ~1e-3 (un orden de magnitud antes del punto más pronunciado).
Bugs Comunes de PyTorch
Estos son los bugs que desperdician más horas colectivas en la comunidad de PyTorch:
| Bug | Síntoma | Solución |
|---|---|---|
Olvidar optimizer.zero_grad() |
Los gradientes se acumulan entre batches, la pérdida oscila | Agrega optimizer.zero_grad() antes de loss.backward() |
Olvidar model.eval() en el momento de prueba |
Dropout y batch norm se comportan diferente, la precisión de prueba varía entre ejecuciones | Agrega model.eval() y torch.no_grad() |
| Formas de tensor incorrectas | El broadcasting silencioso produce resultados erróneos, sin error | Imprime las formas después de cada operación durante la depuración |
| Incompatibilidad CPU/GPU | RuntimeError: expected CUDA tensor |
Usa .to(device) en el modelo Y en los datos |
| No desvincular tensores | El grafo de cómputo crece para siempre, OOM | Usa .detach() o with torch.no_grad() |
| Operaciones in-place que rompen autograd | RuntimeError: modified by in-place operation |
Reemplaza x += 1 por x = x + 1 |
| Datos no normalizados | La pérdida atascada en el nivel de azar | Normaliza las entradas a media=0, std=1 |
| Etiquetas con dtype incorrecto | La entropía cruzada espera Long, recibió Float |
Convierte las etiquetas: labels.long() |
La Tabla Maestra de Depuración
| Síntoma | Causa probable | Lo primero que probar |
|---|---|---|
| Pérdida atascada en -log(1/num_classes) | El modelo predice una distribución uniforme | Verifica el pipeline de datos, confirma que las etiquetas coincidan con las entradas |
| Pérdida NaN después de unos pasos | Tasa de aprendizaje demasiado alta | Reduce la LR por 10x |
| Pérdida NaN inmediatamente | log(0) o división por cero | Agrega epsilon a las operaciones de log/división |
| Pérdida oscilando descontroladamente | LR demasiado alta o batch size demasiado pequeño | Reduce la LR, aumenta el batch size |
| Pérdida disminuye y luego se estanca | LR demasiado alta para la fase de ajuste fino | Agrega un programador de LR (decaimiento coseno o por pasos) |
| Precisión de entrenamiento alta, precisión de prueba baja | Sobreajuste | Agrega dropout, decaimiento de pesos, más datos |
| Precisión de entrenamiento = precisión de prueba = azar | El modelo no está aprendiendo nada | Ejecuta la prueba de sobreajustar un batch |
| Precisión de entrenamiento = precisión de prueba pero ambas bajas | Subajuste | Modelo más grande, más capas, más features |
| Gradientes todos cero | ReLUs muertas o grafo de cómputo desvinculado | Cambia a LeakyReLU, verifica .requires_grad |
| Sin memoria durante el entrenamiento | Batch demasiado grande o grafo no liberado | Reduce el batch size, usa torch.no_grad() para la evaluación |
Constrúyelo
Un kit de diagnóstico que monitorea activaciones, gradientes y curvas de pérdida. Vas a romper deliberadamente una red y usar el kit para diagnosticar cada problema.
Paso 1: La Clase NetworkDebugger
Se engancha (hooks) a un modelo PyTorch para registrar estadísticas de activación y gradiente por capa.
import torch
import torch.nn as nn
import math
class NetworkDebugger:
def __init__(self, model):
self.model = model
self.activation_stats = {}
self.gradient_stats = {}
self.loss_history = []
self.lr_losses = []
self.hooks = []
self._register_hooks()
def _register_hooks(self):
for name, module in self.model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ReLU, nn.LeakyReLU)):
hook = module.register_forward_hook(self._make_activation_hook(name))
self.hooks.append(hook)
hook = module.register_full_backward_hook(self._make_gradient_hook(name))
self.hooks.append(hook)
def _make_activation_hook(self, name):
def hook(module, input, output):
with torch.no_grad():
out = output.detach().float()
self.activation_stats[name] = {
"mean": out.mean().item(),
"std": out.std().item(),
"fraction_zero": (out == 0).float().mean().item(),
"min": out.min().item(),
"max": out.max().item(),
}
return hook
def _make_gradient_hook(self, name):
def hook(module, grad_input, grad_output):
if grad_output[0] is not None:
with torch.no_grad():
grad = grad_output[0].detach().float()
self.gradient_stats[name] = {
"mean": grad.mean().item(),
"std": grad.std().item(),
"abs_mean": grad.abs().mean().item(),
"max": grad.abs().max().item(),
}
return hook
def record_loss(self, loss_value):
self.loss_history.append(loss_value)
def check_loss_health(self):
if len(self.loss_history) < 2:
return "NOT_ENOUGH_DATA"
recent = self.loss_history[-10:]
if any(math.isnan(v) or math.isinf(v) for v in recent):
return "NAN_OR_INF"
if len(self.loss_history) >= 20:
first_half = sum(self.loss_history[:10]) / 10
second_half = sum(self.loss_history[-10:]) / 10
if second_half >= first_half * 0.99:
return "NOT_DECREASING"
if len(recent) >= 5:
diffs = [recent[i+1] - recent[i] for i in range(len(recent)-1)]
if max(diffs) - min(diffs) > 2 * abs(sum(diffs) / len(diffs)):
return "OSCILLATING"
return "HEALTHY"
def check_activations(self):
issues = []
for name, stats in self.activation_stats.items():
if stats["fraction_zero"] > 0.5:
issues.append(f"DEAD_NEURONS: {name} has {stats['fraction_zero']:.0%} zero activations")
if abs(stats["mean"]) > 10:
issues.append(f"EXPLODING_ACTIVATIONS: {name} mean={stats['mean']:.2f}")
if stats["std"] < 1e-6:
issues.append(f"COLLAPSED_ACTIVATIONS: {name} std={stats['std']:.2e}")
return issues if issues else ["HEALTHY"]
def check_gradients(self):
issues = []
grad_magnitudes = []
for name, stats in self.gradient_stats.items():
grad_magnitudes.append((name, stats["abs_mean"]))
if stats["abs_mean"] < 1e-7:
issues.append(f"VANISHING_GRADIENT: {name} abs_mean={stats['abs_mean']:.2e}")
if stats["abs_mean"] > 100:
issues.append(f"EXPLODING_GRADIENT: {name} abs_mean={stats['abs_mean']:.2e}")
if len(grad_magnitudes) >= 2:
first_mag = grad_magnitudes[0][1]
last_mag = grad_magnitudes[-1][1]
if last_mag > 0 and first_mag / last_mag > 100:
issues.append(f"GRADIENT_RATIO: first/last = {first_mag/last_mag:.0f}x (vanishing)")
return issues if issues else ["HEALTHY"]
def print_report(self):
print("\n=== NETWORK DEBUGGER REPORT ===")
print(f"\nLoss health: {self.check_loss_health()}")
if self.loss_history:
print(f" Last 5 losses: {[f'{v:.4f}' for v in self.loss_history[-5:]]}")
print("\nActivation diagnostics:")
for item in self.check_activations():
print(f" {item}")
print("\nGradient diagnostics:")
for item in self.check_gradients():
print(f" {item}")
print("\nPer-layer activation stats:")
for name, stats in self.activation_stats.items():
print(f" {name}: mean={stats['mean']:.4f} std={stats['std']:.4f} zero={stats['fraction_zero']:.1%}")
print("\nPer-layer gradient stats:")
for name, stats in self.gradient_stats.items():
print(f" {name}: abs_mean={stats['abs_mean']:.2e} max={stats['max']:.2e}")
def remove_hooks(self):
for hook in self.hooks:
hook.remove()
self.hooks.clear()
Paso 2: La Prueba de Sobreajustar un Batch
def overfit_one_batch(model, x_batch, y_batch, criterion, lr=0.01, steps=200):
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model.train()
print("\n=== OVERFIT ONE BATCH TEST ===")
print(f"Batch size: {x_batch.shape[0]}, Steps: {steps}")
for step in range(steps):
optimizer.zero_grad()
output = model(x_batch)
loss = criterion(output, y_batch)
loss.backward()
optimizer.step()
if step % 50 == 0 or step == steps - 1:
with torch.no_grad():
preds = (output > 0).float() if output.shape[-1] == 1 else output.argmax(dim=1)
targets = y_batch if y_batch.dim() == 1 else y_batch.squeeze()
acc = (preds.squeeze() == targets).float().mean().item()
print(f" Step {step:3d} | Loss: {loss.item():.6f} | Accuracy: {acc:.1%}")
final_loss = loss.item()
if final_loss > 0.1:
print(f"\n FAIL: Loss did not converge ({final_loss:.4f}). Model or training loop is broken.")
return False
print(f"\n PASS: Loss converged to {final_loss:.6f}")
return True
Paso 3: Buscador de Tasa de Aprendizaje
def find_learning_rate(model, x_data, y_data, criterion, start_lr=1e-7, end_lr=10, steps=100):
import copy
original_state = copy.deepcopy(model.state_dict())
optimizer = torch.optim.SGD(model.parameters(), lr=start_lr)
lr_mult = (end_lr / start_lr) ** (1 / steps)
model.train()
results = []
best_loss = float("inf")
current_lr = start_lr
print("\n=== LEARNING RATE FINDER ===")
for step in range(steps):
optimizer.zero_grad()
output = model(x_data)
loss = criterion(output, y_data)
if math.isnan(loss.item()) or loss.item() > best_loss * 10:
break
best_loss = min(best_loss, loss.item())
results.append((current_lr, loss.item()))
loss.backward()
optimizer.step()
current_lr *= lr_mult
for param_group in optimizer.param_groups:
param_group["lr"] = current_lr
model.load_state_dict(original_state)
if len(results) < 10:
print(" Could not complete LR sweep -- loss diverged too quickly")
return results
min_loss_idx = min(range(len(results)), key=lambda i: results[i][1])
suggested_lr = results[max(0, min_loss_idx - 10)][0]
print(f" Swept {len(results)} steps from {start_lr:.0e} to {results[-1][0]:.0e}")
print(f" Minimum loss {results[min_loss_idx][1]:.4f} at lr={results[min_loss_idx][0]:.2e}")
print(f" Suggested learning rate: {suggested_lr:.2e}")
return results
Paso 4: Verificador de Gradiente
def _flat_to_multi_index(flat_idx, shape):
multi_idx = []
remaining = flat_idx
for dim in reversed(shape):
multi_idx.insert(0, remaining % dim)
remaining //= dim
return tuple(multi_idx)
def gradient_check(model, x, y, criterion, eps=1e-4):
model.train()
x_double = x.double()
y_double = y.double()
model_double = model.double()
print("\n=== GRADIENT CHECK ===")
overall_max_diff = 0
checked = 0
for name, param in model_double.named_parameters():
if not param.requires_grad:
continue
layer_max_diff = 0
model_double.zero_grad()
output = model_double(x_double)
loss = criterion(output, y_double)
loss.backward()
analytical_grad = param.grad.clone()
num_checks = min(5, param.numel())
for i in range(num_checks):
idx = _flat_to_multi_index(i, param.shape)
original = param.data[idx].item()
param.data[idx] = original + eps
with torch.no_grad():
loss_plus = criterion(model_double(x_double), y_double).item()
param.data[idx] = original - eps
with torch.no_grad():
loss_minus = criterion(model_double(x_double), y_double).item()
param.data[idx] = original
numerical = (loss_plus - loss_minus) / (2 * eps)
analytical = analytical_grad[idx].item()
denom = max(abs(numerical), abs(analytical), 1e-8)
rel_diff = abs(numerical - analytical) / denom
layer_max_diff = max(layer_max_diff, rel_diff)
checked += 1
overall_max_diff = max(overall_max_diff, layer_max_diff)
status = "OK" if layer_max_diff < 1e-5 else "MISMATCH"
print(f" {name}: max_rel_diff={layer_max_diff:.2e} [{status}]")
model.float()
print(f"\n Checked {checked} parameters")
if overall_max_diff < 1e-5:
print(" PASS: Gradients match (rel_diff < 1e-5)")
elif overall_max_diff < 1e-3:
print(" WARN: Small differences (1e-5 < rel_diff < 1e-3)")
else:
print(" FAIL: Gradient mismatch detected (rel_diff > 1e-3)")
return overall_max_diff
Paso 5: Redes Deliberadamente Rotas
Ahora aplica el kit a redes rotas y diagnostica cada una.
def demo_broken_networks():
torch.manual_seed(42)
x = torch.randn(64, 10)
y = (x[:, 0] > 0).long()
print("\n" + "=" * 60)
print("BUG 1: Learning rate too high (lr=10)")
print("=" * 60)
model1 = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
debugger1 = NetworkDebugger(model1)
optimizer1 = torch.optim.SGD(model1.parameters(), lr=10.0)
criterion = nn.CrossEntropyLoss()
for step in range(20):
optimizer1.zero_grad()
out = model1(x)
loss = criterion(out, y)
debugger1.record_loss(loss.item())
loss.backward()
optimizer1.step()
debugger1.print_report()
debugger1.remove_hooks()
print("\n" + "=" * 60)
print("BUG 2: Dead ReLUs from bad initialization")
print("=" * 60)
model2 = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
with torch.no_grad():
for m in model2.modules():
if isinstance(m, nn.Linear):
m.weight.fill_(-1.0)
m.bias.fill_(-5.0)
debugger2 = NetworkDebugger(model2)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
for step in range(50):
optimizer2.zero_grad()
out = model2(x)
loss = criterion(out, y)
debugger2.record_loss(loss.item())
loss.backward()
optimizer2.step()
debugger2.print_report()
debugger2.remove_hooks()
print("\n" + "=" * 60)
print("BUG 3: Missing zero_grad (gradients accumulate)")
print("=" * 60)
model3 = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
debugger3 = NetworkDebugger(model3)
optimizer3 = torch.optim.SGD(model3.parameters(), lr=0.01)
for step in range(50):
out = model3(x)
loss = criterion(out, y)
debugger3.record_loss(loss.item())
loss.backward()
optimizer3.step()
debugger3.print_report()
debugger3.remove_hooks()
print("\n" + "=" * 60)
print("HEALTHY NETWORK: Correct setup for comparison")
print("=" * 60)
model_good = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
debugger_good = NetworkDebugger(model_good)
optimizer_good = torch.optim.Adam(model_good.parameters(), lr=1e-3)
for step in range(50):
optimizer_good.zero_grad()
out = model_good(x)
loss = criterion(out, y)
debugger_good.record_loss(loss.item())
loss.backward()
optimizer_good.step()
debugger_good.print_report()
debugger_good.remove_hooks()
print("\n" + "=" * 60)
print("OVERFIT-ONE-BATCH TEST (healthy model)")
print("=" * 60)
model_test = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
overfit_one_batch(model_test, x[:8], y[:8], criterion)
print("\n" + "=" * 60)
print("LEARNING RATE FINDER")
print("=" * 60)
model_lr = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
find_learning_rate(model_lr, x, y, criterion)
print("\n" + "=" * 60)
print("GRADIENT CHECK")
print("=" * 60)
model_grad = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 2))
gradient_check(model_grad, x[:4], y[:4], criterion)
Úsalo
Herramientas Nativas de PyTorch
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(768, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
with torch.autograd.detect_anomaly():
output = model(input_tensor)
loss = criterion(output, target)
loss.backward()
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name}: grad_mean={param.grad.abs().mean():.2e}")
Integración con Weights & Biases
import wandb
wandb.init(project="debug-training")
for epoch in range(100):
loss = train_one_epoch()
wandb.log({
"loss": loss,
"lr": optimizer.param_groups[0]["lr"],
"grad_norm": torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf")),
})
for name, param in model.named_parameters():
if param.grad is not None:
wandb.log({f"grad/{name}": wandb.Histogram(param.grad.cpu().numpy())})
TensorBoard
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("runs/debug_experiment")
for epoch in range(100):
loss = train_one_epoch()
writer.add_scalar("Loss/train", loss, epoch)
for name, param in model.named_parameters():
writer.add_histogram(f"weights/{name}", param, epoch)
if param.grad is not None:
writer.add_histogram(f"gradients/{name}", param.grad, epoch)
La Lista de Verificación de Depuración (Antes del Entrenamiento Completo)
- Ejecuta la prueba de sobreajustar un batch. Si falla, detente.
- Imprime el resumen del modelo -- verifica que el conteo de parámetros sea razonable.
- Ejecuta un único paso hacia adelante con datos aleatorios -- verifica la forma de la salida.
- Entrena durante 5 épocas -- verifica que la pérdida disminuya.
- Verifica las estadísticas de activación -- sin capas muertas, sin explosiones.
- Verifica el flujo de gradiente -- sin gradientes que se desvanecen, sin explosiones.
- Verifica el pipeline de datos -- imprime 5 muestras aleatorias con etiquetas.
Entrégalo
Esta lección produce:
outputs/prompt-nn-debugger.md-- un prompt para diagnosticar fallas de entrenamiento de redes neuronalesoutputs/skill-debug-checklist.md-- una lista de verificación en árbol de decisión para depurar problemas de entrenamiento
Patrones clave de despliegue para depuración:
- Agrega hooks de monitoreo a los scripts de entrenamiento en producción
- Registra estadísticas de activación y gradiente en W&B o TensorBoard cada N pasos
- Implementa alertas automáticas para pérdida NaN, neuronas muertas (>80% cero) o explosión de gradiente
- Siempre ejecuta la prueba de sobreajustar un batch al cambiar arquitecturas o pipelines de datos
Ejercicios
Agrega un detector de gradiente que explota. Modifica el
NetworkDebuggerpara detectar cuándo los gradientes exceden un umbral y sugerir automáticamente un valor de recorte de gradiente. Pruébalo en una red de 20 capas sin normalización.Construye un resucitador de neuronas muertas. Escribe una función que identifique neuronas ReLU muertas (que siempre producen 0) y reinicialice sus pesos de entrada con la inicialización de Kaiming. Demuestra que esto recupera una red donde >70% de las neuronas están muertas.
Implementa el buscador de tasa de aprendizaje con graficación. Extiende
find_learning_ratepara guardar los resultados como un CSV y escribe un script separado que lea el CSV y muestre la curva de LR frente a pérdida usando matplotlib. Identifica la LR óptima para ResNet-18 en CIFAR-10.Crea un validador de pipeline de datos. Escribe una función que verifique: muestras duplicadas entre las divisiones de entrenamiento/prueba, desbalance en la distribución de etiquetas (proporción >10:1), normalización de las entradas (media cercana a 0, std cercano a 1) y valores NaN/Inf en los datos. Ejecútala en un conjunto de datos deliberadamente corrompido.
Depura una falla real. Toma el mini-framework de la Lección 10, introduce un bug sutil (por ejemplo, transponer la matriz de pesos en el backward) y usa la verificación de gradiente para localizar exactamente qué parámetro tiene gradientes incorrectos. Documenta el proceso de depuración.
Términos Clave
| Término | Lo que la gente dice | Lo que realmente significa |
|---|---|---|
| Bug silencioso | "Se ejecuta pero da malos resultados" | Un bug que no produce error pero degrada la calidad del modelo -- el modo de falla dominante en ML |
| ReLU muerta | "Las neuronas murieron" | Una neurona ReLU cuya entrada siempre es negativa, así que produce 0 y recibe gradiente 0 permanentemente |
| Gradientes que se desvanecen | "Las primeras capas dejan de aprender" | Los gradientes se encogen exponencialmente a través de las capas, dejando los pesos de las primeras capas efectivamente congelados |
| Gradientes que explotan | "La pérdida se fue a NaN" | Los gradientes crecen exponencialmente a través de las capas, causando actualizaciones de pesos tan grandes que se desbordan (overflow) |
| Verificación de gradiente | "Verificar que la retropropagación sea correcta" | Comparar gradientes analíticos de la retropropagación con gradientes numéricos de diferencias finitas |
| Sobreajustar un batch | "La prueba de depuración más importante" | Entrenar con un único batch pequeño para verificar que el modelo PUEDE aprender -- si no puede, algo está fundamentalmente roto |
| Buscador de LR | "Barrer para encontrar la tasa de aprendizaje correcta" | Aumentar exponencialmente la tasa de aprendizaje a lo largo de una época y elegir la tasa justo antes de que la pérdida diverja |
| Fuga de datos | "Datos de prueba se filtraron al entrenamiento" | Cuando información del conjunto de prueba contamina el entrenamiento, produciendo una precisión artificialmente alta |
| Estadísticas de activación | "Monitorear la salud de la capa" | Rastrear la media, el std y la fracción de ceros de la salida de cada capa para detectar neuronas muertas, saturadas o que explotan |
| Recorte de gradiente | "Limitar la magnitud del gradiente" | Reducir la escala de los gradientes cuando su norma excede un umbral, previniendo actualizaciones por gradientes que explotan |
Lecturas Adicionales
- Smith, "Cyclical Learning Rates for Training Neural Networks" (2017) -- el artículo que introduce la prueba de rango de tasa de aprendizaje (buscador de LR)
- Northcutt et al., "Pervasive Label Errors in Test Sets Destabilize Machine Learning Benchmarks" (2021) -- demuestra que el 3-6% de las etiquetas en ImageNet, CIFAR-10 y otros benchmarks importantes están mal
- Zhang et al., "Understanding Deep Learning Requires Rethinking Generalization" (2017) -- el artículo que muestra que las redes neuronales pueden memorizar etiquetas aleatorias, que es la razón por la que funciona la prueba de sobreajustar un batch
- Documentación de PyTorch sobre
torch.autograd.detect_anomalyytorch.autograd.set_detect_anomalypara detección nativa de NaN/Inf