Phase 04 - Lesson 07

Segmentación Semántica — U-Net

La segmentación es clasificación en cada píxel. La U-Net lo hace funcionar al emparejar un codificador de reducción de resolución con un decodificador de aumento de resolución y conectar conexiones de salto entre ellos.

Tipo: Build Lenguajes: Python Prerrequisitos: Fase 4 Lección 03 (CNNs), Fase 4 Lección 04 (Clasificación de Imágenes) Tiempo: ~75 minutos

Objetivos de Aprendizaje

  • Distinguir segmentación semántica, de instancia y panóptica y elegir la tarea correcta para un problema dado
  • Construir una U-Net desde cero en PyTorch con bloques de codificador, un cuello de botella, un decodificador con convoluciones transpuestas y conexiones de salto
  • Implementar entropía cruzada por píxel, pérdida Dice y la pérdida combinada que es el estándar actual para segmentación médica e industrial
  • Leer métricas de IoU y Dice por clase y diagnosticar si una puntuación mala proviene de la exhaustividad (recall) de objetos pequeños, la precisión de bordes o el desbalanceo de clases

El Problema

La clasificación produce una etiqueta por imagen. La detección produce un puñado de cajas por imagen. La segmentación produce una etiqueta por píxel. Para una entrada de tamaño H x W, la salida es un tensor de forma H x W (semántica) o H x W x N_instances (instancia). Son millones de predicciones por imagen, no una.

La estructura de la segmentación es la razón por la que impulsa casi todo producto de visión de predicción densa: imagen médica (máscaras de tumor), conducción autónoma (carretera, carril, obstáculo), satélite (huellas de edificios, límites de cultivos), análisis de documentos (zonas de diseño), robótica (regiones agarrables). Ninguna de esas tareas puede resolverse poniendo una caja alrededor del objeto; necesitan la silueta exacta.

El problema arquitectónico es simple de enunciar y no tan simple de resolver: necesitas que la red vea el contexto global de una imagen (qué tipo de escena es esta) y el detalle local del píxel (exactamente qué píxel es carretera vs acera) simultáneamente. Una CNN estándar comprime espacialmente para ganar contexto, lo cual descarta el detalle. La U-Net fue el diseño que consiguió ambos.

El Concepto

Semántica vs instancia vs panóptica

flowchart LR
    IN["Imagen de entrada"] --> SEM["Semántica<br/>(píxel → clase)"]
    IN --> INS["Instancia<br/>(píxel → id de objeto,<br/>solo clases de primer plano)"]
    IN --> PAN["Panóptica<br/>(cada píxel → clase + id)"]

    style SEM fill:#dbeafe,stroke:#2563eb
    style INS fill:#fef3c7,stroke:#d97706
    style PAN fill:#dcfce7,stroke:#16a34a
  • Semántica dice "este píxel es carretera, ese píxel es auto." Dos autos uno al lado del otro colapsan en una sola mancha.
  • Instancia dice "este píxel es el auto #3, ese píxel es el auto #5." Ignora el material de fondo ("stuff" = cielo, carretera, pasto).
  • Panóptica unifica ambos: cada píxel recibe una etiqueta de clase, cada instancia recibe un id único, tanto material de fondo como objetos se segmentan.

Esta lección cubre la semántica. La próxima lección (Mask R-CNN) cubre la instancia.

La forma de la U-Net

flowchart LR
    subgraph ENC["Codificador (contrayendo)"]
        E1["64<br/>H x W"] --> E2["128<br/>H/2 x W/2"]
        E2 --> E3["256<br/>H/4 x W/4"]
        E3 --> E4["512<br/>H/8 x W/8"]
    end
    subgraph BOT["Cuello de botella"]
        B1["1024<br/>H/16 x W/16"]
    end
    subgraph DEC["Decodificador (expandiendo)"]
        D4["512<br/>H/8 x W/8"] --> D3["256<br/>H/4 x W/4"]
        D3 --> D2["128<br/>H/2 x W/2"]
        D2 --> D1["64<br/>H x W"]
    end
    E4 --> B1 --> D4
    E1 -. salto .-> D1
    E2 -. salto .-> D2
    E3 -. salto .-> D3
    E4 -. salto .-> D4
    D1 --> OUT["conv 1x1<br/>clases"]

    style ENC fill:#dbeafe,stroke:#2563eb
    style BOT fill:#fef3c7,stroke:#d97706
    style DEC fill:#dcfce7,stroke:#16a34a

El codificador reduce a la mitad la resolución espacial cuatro veces y duplica los canales. El decodificador revierte: duplica la resolución espacial cuatro veces y reduce a la mitad los canales. Las conexiones de salto concatenan características correspondientes del codificador con características del decodificador en cada resolución. La conv 1x1 final mapea 64 -> num_classes a resolución completa.

Por qué las conexiones de salto son necesarias: el decodificador solo ha visto mapas de características pequeños para cuando intenta producir predicciones a nivel de píxel. Sin los saltos no puede localizar bordes con precisión porque esa información fue comprimida en el codificador. Las conexiones de salto le entregan los mapas de características de alta resolución que el codificador computó en el camino de bajada.

Upsample transpuesto vs bilineal

El decodificador tiene que expandir las dimensiones espaciales. Dos opciones:

  • Convolución transpuesta (nn.ConvTranspose2d) — upsample aprendible. Estándar histórico de la U-Net. Puede producir artefactos de tablero de ajedrez si el paso y el tamaño del kernel no dividen uniformemente.
  • Upsample bilineal + conv 3x3 — upsample suave seguido de una conv. Menos artefactos, menos parámetros, ahora el estándar moderno.

Ambos aparecen en la práctica. Para una primera U-Net, bilineal es más seguro.

Entropía cruzada en una grilla de píxeles

Para segmentación semántica con C clases, la salida del modelo es (N, C, H, W). El objetivo es (N, H, W) con IDs de clase enteros. La entropía cruzada es idéntica al caso de clasificación, solo que aplicada en cada posición espacial:

Loss = mean over (n, h, w) of -log( softmax(logits[n, :, h, w])[target[n, h, w]] )

F.cross_entropy en PyTorch maneja esta forma de manera nativa. No se necesita reshape.

Pérdida Dice y por qué la necesitas

La entropía cruzada trata cada píxel por igual. Eso está mal cuando una clase domina el cuadro (imagen médica: 99% fondo, 1% tumor). La red puede alcanzar 99% de exactitud prediciendo fondo en todas partes y aun así ser inútil.

La pérdida Dice resuelve esto optimizando directamente la superposición entre la máscara predicha y la verdadera:

Dice(p, y) = 2 * sum(p * y) / (sum(p) + sum(y) + epsilon)
Dice_loss = 1 - Dice

donde p es el mapa de probabilidad sigmoide/softmax para una clase e y es la máscara binaria de referencia (ground-truth). La pérdida es cero solo cuando la superposición es perfecta. Como se basa en una razón, el desbalanceo de clases es irrelevante.

En la práctica, usa la pérdida combinada:

L = L_cross_entropy + lambda * L_dice       (lambda ~ 1)

La entropía cruzada da gradientes estables al inicio del entrenamiento; la Dice enfoca la cola del entrenamiento en realmente coincidir con la forma de la máscara. Esta combinación es el estándar de la imagen médica y difícil de superar en cualquier conjunto de datos con clases desbalanceadas.

Métricas de evaluación

  • Exactitud de píxel — porcentaje de píxeles predichos correctamente. Barata. Se rompe en datos desbalanceados por la misma razón que la exactitud en clasificación.
  • IoU por clase — intersección sobre unión para la máscara de cada clase; promedio entre clases = mIoU.
  • Dice (F1 sobre píxeles) — similar al IoU; Dice = 2 * IoU / (1 + IoU). La imagen médica prefiere Dice, la comunidad de conducción prefiere IoU; están relacionados monótonamente.
  • Boundary F1 — mide qué tan cerca están los bordes predichos de los bordes de referencia, penalizando incluso pequeños desplazamientos. Importante para tareas de alta precisión como la inspección de semiconductores.

Reporta IoU por clase, no solo mIoU. El IoU medio oculta una clase en 15% cuando otras nueve están en 85%.

Compromiso de resolución de entrada

El codificador de la U-Net reduce a la mitad la resolución cuatro veces, así que la entrada debe ser divisible por 16. Las imágenes médicas suelen ser 512x512 o 1024x1024. Los recortes de conducción autónoma son 2048x1024. El costo de memoria de la U-Net escala con H * W * C_max, y a 1024x1024 con 1024 canales de cuello de botella el forward pass ya usa gigabytes de VRAM.

Dos soluciones estándar:

  1. Dividir la entrada en mosaicos — procesar mosaicos de 256x256 con superposición y coserlos.
  2. Reemplazar el cuello de botella con convoluciones dilatadas que mantienen la resolución espacial más alta pero amplían el campo receptivo (la familia DeepLab).

Para un primer modelo, una entrada de 256x256 con una U-Net de base de 64 canales entrena cómodamente en 8 GB de VRAM.

Constrúyelo

Paso 1: Bloque de codificador

Dos convs 3x3 con batch norm y ReLU. La primera conv cambia el conteo de canales; la segunda lo mantiene.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)

Este bloque se reutiliza en todo el modelo. bias=False porque el beta de la BN maneja el sesgo.

Paso 2: Bloques de bajada y subida

class Down(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.net = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_c, out_c),
        )

    def forward(self, x):
        return self.net(x)


class Up(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv = DoubleConv(in_c, out_c)

    def forward(self, x, skip):
        x = self.up(x)
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)

La verificación de forma solo espacial (shape[-2:]) maneja entradas cuyas dimensiones no son divisibles por 16; un F.interpolate seguro alinea el tensor antes de la concatenación. Comparar la forma completa también se dispararía ante diferencias en el conteo de canales, lo cual debería ser un error fuerte y ruidoso, no una interpolación silenciosa.

Paso 3: La U-Net

class UNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=2, base=64):
        super().__init__()
        self.inc = DoubleConv(in_channels, base)
        self.d1 = Down(base, base * 2)
        self.d2 = Down(base * 2, base * 4)
        self.d3 = Down(base * 4, base * 8)
        self.d4 = Down(base * 8, base * 16)
        self.u1 = Up(base * 16 + base * 8, base * 8)
        self.u2 = Up(base * 8 + base * 4, base * 4)
        self.u3 = Up(base * 4 + base * 2, base * 2)
        self.u4 = Up(base * 2 + base, base)
        self.outc = nn.Conv2d(base, num_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.d1(x1)
        x3 = self.d2(x2)
        x4 = self.d3(x3)
        x5 = self.d4(x4)
        x = self.u1(x5, x4)
        x = self.u2(x, x3)
        x = self.u3(x, x2)
        x = self.u4(x, x1)
        return self.outc(x)

net = UNet(in_channels=3, num_classes=2, base=32)
x = torch.randn(1, 3, 256, 256)
print(f"output: {net(x).shape}")
print(f"params: {sum(p.numel() for p in net.parameters()):,}")

Forma de salida (1, 2, 256, 256) — mismo tamaño espacial que la entrada, num_classes canales. Cerca de 7,7M de parámetros con base=32.

Paso 4: Pérdidas

def dice_loss(logits, targets, num_classes, eps=1e-6):
    probs = F.softmax(logits, dim=1)
    targets_one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)
    intersection = (probs * targets_one_hot).sum(dim=dims)
    denom = probs.sum(dim=dims) + targets_one_hot.sum(dim=dims)
    dice = (2 * intersection + eps) / (denom + eps)
    return 1 - dice.mean()


def combined_loss(logits, targets, num_classes, lam=1.0):
    ce = F.cross_entropy(logits, targets)
    dc = dice_loss(logits, targets, num_classes)
    return ce + lam * dc, {"ce": ce.item(), "dice": dc.item()}

La Dice se computa por clase y luego se promedia (Dice macro). El eps previene la división por cero en clases ausentes del lote.

Paso 5: Métrica de IoU

@torch.no_grad()
def iou_per_class(logits, targets, num_classes):
    preds = logits.argmax(dim=1)
    ious = torch.zeros(num_classes)
    for c in range(num_classes):
        pred_c = (preds == c)
        true_c = (targets == c)
        inter = (pred_c & true_c).sum().float()
        union = (pred_c | true_c).sum().float()
        ious[c] = (inter / union) if union > 0 else torch.tensor(float("nan"))
    return ious

Retorna un vector de longitud C. nan marca clases ausentes del lote — no promedies sobre ellas al computar mIoU.

Paso 6: Conjunto de datos sintético para verificación de extremo a extremo

Genera formas sobre fondos de colores para que la red tenga que aprender la forma, no el color del píxel.

import numpy as np
from torch.utils.data import Dataset, DataLoader

def synthetic_segmentation(num_samples=200, size=64, seed=0):
    rng = np.random.default_rng(seed)
    images = np.zeros((num_samples, size, size, 3), dtype=np.float32)
    masks = np.zeros((num_samples, size, size), dtype=np.int64)
    for i in range(num_samples):
        bg = rng.uniform(0, 1, (3,))
        images[i] = bg
        masks[i] = 0
        num_shapes = rng.integers(1, 4)
        for _ in range(num_shapes):
            cls = int(rng.integers(1, 3))
            color = rng.uniform(0, 1, (3,))
            cx, cy = rng.integers(10, size - 10, size=2)
            r = int(rng.integers(4, 12))
            yy, xx = np.meshgrid(np.arange(size), np.arange(size), indexing="ij")
            if cls == 1:
                mask = (xx - cx) ** 2 + (yy - cy) ** 2 < r ** 2
            else:
                mask = (np.abs(xx - cx) < r) & (np.abs(yy - cy) < r)
            images[i][mask] = color
            masks[i][mask] = cls
        images[i] += rng.normal(0, 0.02, images[i].shape)
        images[i] = np.clip(images[i], 0, 1)
    return images, masks


class SegDataset(Dataset):
    def __init__(self, images, masks):
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        img = torch.from_numpy(self.images[i]).permute(2, 0, 1).float()
        mask = torch.from_numpy(self.masks[i]).long()
        return img, mask

Tres clases: fondo (0), círculos (1), cuadrados (2). La red debe aprender a distinguir la forma.

Paso 7: Bucle de entrenamiento

def train_one_epoch(model, loader, optimizer, device, num_classes):
    model.train()
    loss_sum, total = 0.0, 0
    iou_sum = torch.zeros(num_classes)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss, _ = combined_loss(logits, y, num_classes)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item() * x.size(0)
        total += x.size(0)
        iou_sum += iou_per_class(logits, y, num_classes).nan_to_num(0)
    return loss_sum / total, iou_sum / len(loader)

Ejecuta esto por 10-30 épocas en el conjunto de datos sintético y observa la mIoU subir más allá de 0,9 para las clases de forma. Nota que el nan_to_num(0) trata las clases ausentes de un lote como cero; para una IoU por clase precisa, enmascara por presencia y usa torch.nanmean entre lotes en el momento de la evaluación en lugar de promediar aquí.

Úsalo

Para producción, segmentation_models_pytorch ("smp") envuelve toda arquitectura de segmentación estándar con cualquier backbone de torchvision o timm. Tres líneas:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=3,
)

También vale la pena conocer para trabajo real:

  • DeepLabV3+ reemplaza la reducción de resolución basada en max-pool con convs dilatadas para que el cuello de botella mantenga la resolución; bordes más rápidos en datos de satélite y conducción.
  • SegFormer intercambia el codificador conv por un transformer jerárquico; SOTA actual en muchos benchmarks.
  • Mask2Former / OneFormer unifican la segmentación semántica, de instancia y panóptica en una sola arquitectura.

Las tres son reemplazos directos en smp o transformers con el mismo cargador de datos.

Entrégalo

Esta lección produce:

  • outputs/prompt-segmentation-task-picker.md — un prompt que elige entre segmentación semántica, de instancia y panóptica y nombra la arquitectura para una tarea dada.
  • outputs/skill-segmentation-mask-inspector.md — una skill que reporta la distribución de clases, estadísticas de la máscara predicha y las clases que están sub-predichas o con bordes borrosos.

Ejercicios

  1. (Fácil) Implementa bce_dice_loss para una tarea de segmentación binaria (primer plano vs fondo). Verifica en un conjunto de datos sintético de dos clases que la pérdida combinada converge más rápido que la BCE sola cuando el primer plano es el 5% de los píxeles.
  2. (Medio) Reemplaza el bloque de subida nn.Upsample + conv con un bloque de subida nn.ConvTranspose2d. Entrena ambos en el conjunto de datos sintético y compara la mIoU. Observa dónde aparecen los artefactos de tablero de ajedrez en la versión con conv transpuesta.
  3. (Difícil) Toma un conjunto de datos de segmentación real (Oxford-IIIT Pets, split mini de Cityscapes o un subconjunto médico) y entrena la U-Net hasta quedar a 2 puntos de IoU de la referencia smp.Unet. Reporta la IoU por clase e identifica qué clases se benefician más de agregar la Dice a la pérdida.

Términos Clave

Término Lo que la gente dice Lo que realmente significa
Segmentación semántica "Etiquetar cada píxel" Clasificación por píxel en C clases; instancias de la misma clase se fusionan
Segmentación de instancia "Etiquetar cada objeto" Separa instancias distintas de la misma clase; solo primer plano
Segmentación panóptica "Semántica + instancia" Cada píxel recibe una clase; cada instancia de objeto también recibe un id único
Conexión de salto "Puente de la U-Net" Concatenación de características del codificador en características del decodificador de resolución correspondiente; preserva el detalle de alta frecuencia
Conv transpuesta "Deconvolución" Upsampling aprendible; puede producir artefactos de tablero de ajedrez
Pérdida Dice "Pérdida de superposición" 1 - 2
mIoU "Intersección media sobre unión" IoU media entre clases; la métrica estándar de la comunidad para segmentación
Boundary F1 "Precisión de borde" Puntuación F1 computada solo sobre píxeles de borde; importa para tareas críticas en precisión

Lectura Adicional

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).