Phase 04 - Lesson 07
Segmentação Semântica — U-Net
Segmentação é classificação em cada pixel. A U-Net faz isso funcionar ao parear um codificador de redução de resolução com um decodificador de aumento de resolução e conectar conexões de salto entre eles.
Tipo: Build Linguagens: Python Pré-requisitos: Fase 4 Lição 03 (CNNs), Fase 4 Lição 04 (Classificação de Imagens) Tempo: ~75 minutos
Objetivos de Aprendizagem
- Distinguir segmentação semântica, de instância e panóptica e escolher a tarefa certa para um dado problema
- Construir uma U-Net do zero em PyTorch com blocos de codificador, um gargalo, um decodificador com convoluções transpostas e conexões de salto
- Implementar entropia cruzada por pixel, perda Dice e a perda combinada que é o padrão atual para segmentação médica e industrial
- Ler métricas de IoU e Dice por classe e diagnosticar se uma pontuação ruim vem de revocação de objetos pequenos, precisão de bordas ou desbalanceamento de classes
O Problema
A classificação produz um rótulo por imagem. A detecção produz um punhado de caixas por imagem. A segmentação produz um rótulo por pixel. Para uma entrada de tamanho H x W, a saída é um tensor de formato H x W (semântica) ou H x W x N_instances (instância). São milhões de previsões por imagem, não uma.
A estrutura da segmentação é o motivo de ela impulsionar quase todo produto de visão de predição densa: imagem médica (máscaras de tumor), direção autônoma (estrada, faixa, obstáculo), satélite (pegadas de edifícios, limites de plantações), análise de documentos (zonas de layout), robótica (regiões agarráveis). Nenhuma dessas tarefas pode ser resolvida colocando uma caixa ao redor do objeto; elas precisam da silhueta exata.
O problema arquitetural é simples de enunciar e não tão simples de resolver: você precisa que a rede veja o contexto global de uma imagem (que tipo de cena é esta) e o detalhe local de pixel (exatamente qual pixel é estrada vs calçada) simultaneamente. Uma CNN padrão comprime espacialmente para ganhar contexto, o que descarta o detalhe. A U-Net foi o projeto que conseguiu ambos.
O Conceito
Semântica vs instância vs panóptica
flowchart LR
IN["Imagem de entrada"] --> SEM["Semântica<br/>(pixel → classe)"]
IN --> INS["Instância<br/>(pixel → id do objeto,<br/>apenas classes de primeiro plano)"]
IN --> PAN["Panóptica<br/>(cada pixel → classe + id)"]
style SEM fill:#dbeafe,stroke:#2563eb
style INS fill:#fef3c7,stroke:#d97706
style PAN fill:#dcfce7,stroke:#16a34a
- Semântica diz "este pixel é estrada, aquele pixel é carro." Dois carros lado a lado colapsam em uma única mancha.
- Instância diz "este pixel é o carro #3, aquele pixel é o carro #5." Ignora o material de fundo ("stuff" = céu, estrada, grama).
- Panóptica unifica ambos: cada pixel recebe um rótulo de classe, cada instância recebe um id único, tanto material de fundo quanto objetos são segmentados.
Esta lição cobre semântica. A próxima lição (Mask R-CNN) cobre instância.
O formato da U-Net
flowchart LR
subgraph ENC["Codificador (contraindo)"]
E1["64<br/>H x W"] --> E2["128<br/>H/2 x W/2"]
E2 --> E3["256<br/>H/4 x W/4"]
E3 --> E4["512<br/>H/8 x W/8"]
end
subgraph BOT["Gargalo"]
B1["1024<br/>H/16 x W/16"]
end
subgraph DEC["Decodificador (expandindo)"]
D4["512<br/>H/8 x W/8"] --> D3["256<br/>H/4 x W/4"]
D3 --> D2["128<br/>H/2 x W/2"]
D2 --> D1["64<br/>H x W"]
end
E4 --> B1 --> D4
E1 -. salto .-> D1
E2 -. salto .-> D2
E3 -. salto .-> D3
E4 -. salto .-> D4
D1 --> OUT["conv 1x1<br/>classes"]
style ENC fill:#dbeafe,stroke:#2563eb
style BOT fill:#fef3c7,stroke:#d97706
style DEC fill:#dcfce7,stroke:#16a34a
O codificador reduz a resolução espacial pela metade quatro vezes e dobra os canais. O decodificador reverte: dobra a resolução espacial quatro vezes e reduz os canais pela metade. As conexões de salto concatenam características correspondentes do codificador com características do decodificador em cada resolução. A conv 1x1 final mapeia 64 -> num_classes em resolução total.
Por que as conexões de salto são necessárias: o decodificador viu apenas mapas de características pequenos no momento em que tenta produzir previsões em nível de pixel. Sem os saltos ele não consegue localizar bordas com precisão porque essa informação foi comprimida no codificador. As conexões de salto entregam a ele os mapas de características de alta resolução que o codificador computou no caminho de descida.
Upsample transposto vs bilinear
O decodificador tem que expandir as dimensões espaciais. Duas opções:
- Convolução transposta (
nn.ConvTranspose2d) — upsample aprendível. Padrão histórico da U-Net. Pode produzir artefatos de tabuleiro de xadrez se o passo e o tamanho do kernel não dividirem uniformemente. - Upsample bilinear + conv 3x3 — upsample suave seguido de uma conv. Menos artefatos, menos parâmetros, agora o padrão moderno.
Ambos aparecem na prática. Para uma primeira U-Net, bilinear é mais seguro.
Entropia cruzada em uma grade de pixels
Para segmentação semântica com C classes, a saída do modelo é (N, C, H, W). O alvo é (N, H, W) com IDs de classe inteiros. A entropia cruzada é idêntica ao caso de classificação, apenas aplicada em cada posição espacial:
Loss = mean over (n, h, w) of -log( softmax(logits[n, :, h, w])[target[n, h, w]] )
F.cross_entropy no PyTorch lida com esse formato nativamente. Nenhum reshape necessário.
Perda Dice e por que você precisa dela
A entropia cruzada trata cada pixel igualmente. Isso é errado quando uma classe domina o quadro (imagem médica: 99% fundo, 1% tumor). A rede pode atingir 99% de acurácia prevendo fundo em todo lugar e ainda ser inútil.
A perda Dice resolve isso otimizando diretamente a sobreposição entre a máscara prevista e a verdadeira:
Dice(p, y) = 2 * sum(p * y) / (sum(p) + sum(y) + epsilon)
Dice_loss = 1 - Dice
onde p é o mapa de probabilidade sigmoide/softmax para uma classe e y é a máscara binária de referência (ground-truth). A perda é zero apenas quando a sobreposição é perfeita. Por ser baseada em razão, o desbalanceamento de classes é irrelevante.
Na prática, use a perda combinada:
L = L_cross_entropy + lambda * L_dice (lambda ~ 1)
A entropia cruzada fornece gradientes estáveis no início do treinamento; a Dice foca a cauda do treinamento em realmente combinar o formato da máscara. Essa combinação é o padrão da imagem médica e difícil de superar em qualquer conjunto de dados com classes desbalanceadas.
Métricas de avaliação
- Acurácia de pixel — percentual de pixels previstos corretamente. Barata. Quebra em dados desbalanceados pelo mesmo motivo que a acurácia na classificação.
- IoU por classe — interseção sobre união para a máscara de cada classe; média entre classes = mIoU.
- Dice (F1 em pixels) — similar ao IoU;
Dice = 2 * IoU / (1 + IoU). A imagem médica prefere Dice, a comunidade de direção prefere IoU; eles são monotonicamente relacionados. - Boundary F1 — mede quão próximas as bordas previstas estão das bordas de referência, penalizando até pequenos deslocamentos. Importante para tarefas de alta precisão como inspeção de semicondutores.
Relate IoU por classe, não apenas mIoU. A IoU média esconde uma classe em 15% quando outras nove estão em 85%.
Compromisso de resolução de entrada
O codificador da U-Net reduz a resolução pela metade quatro vezes, então a entrada deve ser divisível por 16. Imagens médicas são frequentemente 512x512 ou 1024x1024. Recortes de direção autônoma são 2048x1024. O custo de memória da U-Net escala com H * W * C_max, e em 1024x1024 com 1024 canais de gargalo o forward pass já usa gigabytes de VRAM.
Duas soluções padrão:
- Dividir a entrada em ladrilhos — processar ladrilhos de 256x256 com sobreposição e costurar.
- Substituir o gargalo por convoluções dilatadas que mantêm a resolução espacial mais alta mas ampliam o campo receptivo (a família DeepLab).
Para um primeiro modelo, uma entrada de 256x256 com uma U-Net de base de 64 canais treina confortavelmente em 8 GB de VRAM.
Construa
Passo 1: Bloco de codificador
Duas convs 3x3 com batch norm e ReLU. A primeira conv muda a contagem de canais; a segunda a mantém.
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
def __init__(self, in_c, out_c):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_c),
nn.ReLU(inplace=True),
nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_c),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.net(x)
Esse bloco é reutilizado em todo o modelo. bias=False porque o beta da BN lida com o viés.
Passo 2: Blocos de descida e subida
class Down(nn.Module):
def __init__(self, in_c, out_c):
super().__init__()
self.net = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_c, out_c),
)
def forward(self, x):
return self.net(x)
class Up(nn.Module):
def __init__(self, in_c, out_c):
super().__init__()
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
self.conv = DoubleConv(in_c, out_c)
def forward(self, x, skip):
x = self.up(x)
if x.shape[-2:] != skip.shape[-2:]:
x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
x = torch.cat([skip, x], dim=1)
return self.conv(x)
A verificação de formato apenas espacial (shape[-2:]) lida com entradas cujas dimensões não são divisíveis por 16; um F.interpolate seguro alinha o tensor antes da concatenação. Comparar o formato completo também dispararia em diferenças de contagem de canais, o que deveria ser um erro alto e barulhento, não uma interpolação silenciosa.
Passo 3: A U-Net
class UNet(nn.Module):
def __init__(self, in_channels=3, num_classes=2, base=64):
super().__init__()
self.inc = DoubleConv(in_channels, base)
self.d1 = Down(base, base * 2)
self.d2 = Down(base * 2, base * 4)
self.d3 = Down(base * 4, base * 8)
self.d4 = Down(base * 8, base * 16)
self.u1 = Up(base * 16 + base * 8, base * 8)
self.u2 = Up(base * 8 + base * 4, base * 4)
self.u3 = Up(base * 4 + base * 2, base * 2)
self.u4 = Up(base * 2 + base, base)
self.outc = nn.Conv2d(base, num_classes, kernel_size=1)
def forward(self, x):
x1 = self.inc(x)
x2 = self.d1(x1)
x3 = self.d2(x2)
x4 = self.d3(x3)
x5 = self.d4(x4)
x = self.u1(x5, x4)
x = self.u2(x, x3)
x = self.u3(x, x2)
x = self.u4(x, x1)
return self.outc(x)
net = UNet(in_channels=3, num_classes=2, base=32)
x = torch.randn(1, 3, 256, 256)
print(f"output: {net(x).shape}")
print(f"params: {sum(p.numel() for p in net.parameters()):,}")
Formato de saída (1, 2, 256, 256) — mesmo tamanho espacial da entrada, num_classes canais. Cerca de 7,7M de parâmetros com base=32.
Passo 4: Perdas
def dice_loss(logits, targets, num_classes, eps=1e-6):
probs = F.softmax(logits, dim=1)
targets_one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
dims = (0, 2, 3)
intersection = (probs * targets_one_hot).sum(dim=dims)
denom = probs.sum(dim=dims) + targets_one_hot.sum(dim=dims)
dice = (2 * intersection + eps) / (denom + eps)
return 1 - dice.mean()
def combined_loss(logits, targets, num_classes, lam=1.0):
ce = F.cross_entropy(logits, targets)
dc = dice_loss(logits, targets, num_classes)
return ce + lam * dc, {"ce": ce.item(), "dice": dc.item()}
A Dice é computada por classe e então tirada a média (Dice macro). O eps previne a divisão por zero em classes ausentes do lote.
Passo 5: Métrica de IoU
@torch.no_grad()
def iou_per_class(logits, targets, num_classes):
preds = logits.argmax(dim=1)
ious = torch.zeros(num_classes)
for c in range(num_classes):
pred_c = (preds == c)
true_c = (targets == c)
inter = (pred_c & true_c).sum().float()
union = (pred_c | true_c).sum().float()
ious[c] = (inter / union) if union > 0 else torch.tensor(float("nan"))
return ious
Retorna um vetor de comprimento C. nan marca classes ausentes do lote — não tire a média sobre elas ao computar mIoU.
Passo 6: Conjunto de dados sintético para verificação ponta a ponta
Gere formas em fundos coloridos para que a rede tenha que aprender o formato, não a cor do pixel.
import numpy as np
from torch.utils.data import Dataset, DataLoader
def synthetic_segmentation(num_samples=200, size=64, seed=0):
rng = np.random.default_rng(seed)
images = np.zeros((num_samples, size, size, 3), dtype=np.float32)
masks = np.zeros((num_samples, size, size), dtype=np.int64)
for i in range(num_samples):
bg = rng.uniform(0, 1, (3,))
images[i] = bg
masks[i] = 0
num_shapes = rng.integers(1, 4)
for _ in range(num_shapes):
cls = int(rng.integers(1, 3))
color = rng.uniform(0, 1, (3,))
cx, cy = rng.integers(10, size - 10, size=2)
r = int(rng.integers(4, 12))
yy, xx = np.meshgrid(np.arange(size), np.arange(size), indexing="ij")
if cls == 1:
mask = (xx - cx) ** 2 + (yy - cy) ** 2 < r ** 2
else:
mask = (np.abs(xx - cx) < r) & (np.abs(yy - cy) < r)
images[i][mask] = color
masks[i][mask] = cls
images[i] += rng.normal(0, 0.02, images[i].shape)
images[i] = np.clip(images[i], 0, 1)
return images, masks
class SegDataset(Dataset):
def __init__(self, images, masks):
self.images = images
self.masks = masks
def __len__(self):
return len(self.images)
def __getitem__(self, i):
img = torch.from_numpy(self.images[i]).permute(2, 0, 1).float()
mask = torch.from_numpy(self.masks[i]).long()
return img, mask
Três classes: fundo (0), círculos (1), quadrados (2). A rede deve aprender a distinguir o formato.
Passo 7: Laço de treinamento
def train_one_epoch(model, loader, optimizer, device, num_classes):
model.train()
loss_sum, total = 0.0, 0
iou_sum = torch.zeros(num_classes)
for x, y in loader:
x, y = x.to(device), y.to(device)
logits = model(x)
loss, _ = combined_loss(logits, y, num_classes)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_sum += loss.item() * x.size(0)
total += x.size(0)
iou_sum += iou_per_class(logits, y, num_classes).nan_to_num(0)
return loss_sum / total, iou_sum / len(loader)
Rode isso por 10-30 épocas no conjunto de dados sintético e veja a mIoU subir além de 0,9 para as classes de formato. Note que o nan_to_num(0) trata classes ausentes de um lote como zero; para uma IoU por classe precisa, mascare por presença e use torch.nanmean entre lotes no momento da avaliação em vez de tirar a média aqui.
Use
Para produção, segmentation_models_pytorch ("smp") encapsula toda arquitetura de segmentação padrão com qualquer backbone do torchvision ou timm. Três linhas:
import segmentation_models_pytorch as smp
model = smp.Unet(
encoder_name="resnet34",
encoder_weights="imagenet",
in_channels=3,
classes=3,
)
Também vale conhecer para trabalho real:
- DeepLabV3+ substitui a redução de resolução baseada em max-pool por convs dilatadas para que o gargalo mantenha a resolução; bordas mais rápidas em dados de satélite e direção.
- SegFormer troca o codificador conv por um transformer hierárquico; SOTA atual em muitos benchmarks.
- Mask2Former / OneFormer unificam segmentação semântica, de instância e panóptica em uma única arquitetura.
Os três são substituições diretas no smp ou transformers com o mesmo carregador de dados.
Entregue
Esta lição produz:
outputs/prompt-segmentation-task-picker.md— um prompt que escolhe entre segmentação semântica, de instância e panóptica e nomeia a arquitetura para uma dada tarefa.outputs/skill-segmentation-mask-inspector.md— uma skill que reporta a distribuição de classes, estatísticas da máscara prevista e as classes que estão sub-previstas ou com bordas borradas.
Exercícios
- (Fácil) Implemente
bce_dice_losspara uma tarefa de segmentação binária (primeiro plano vs fundo). Verifique em um conjunto de dados sintético de duas classes que a perda combinada converge mais rápido que a BCE sozinha quando o primeiro plano é 5% dos pixels. - (Médio) Substitua o bloco de subida
nn.Upsample + convpor um bloco de subidann.ConvTranspose2d. Treine ambos no conjunto de dados sintético e compare a mIoU. Observe onde os artefatos de tabuleiro de xadrez aparecem na versão com conv transposta. - (Difícil) Pegue um conjunto de dados de segmentação real (Oxford-IIIT Pets, split mini do Cityscapes ou um subconjunto médico) e treine a U-Net até ficar a 2 pontos de IoU da referência
smp.Unet. Reporte a IoU por classe e identifique quais classes mais se beneficiam de adicionar a Dice à perda.
Termos-Chave
| Termo | O que as pessoas dizem | O que realmente significa |
|---|---|---|
| Segmentação semântica | "Rotular cada pixel" | Classificação por pixel em C classes; instâncias da mesma classe se fundem |
| Segmentação de instância | "Rotular cada objeto" | Separa instâncias distintas da mesma classe; apenas primeiro plano |
| Segmentação panóptica | "Semântica + instância" | Cada pixel recebe uma classe; cada instância de objeto também recebe um id único |
| Conexão de salto | "Ponte da U-Net" | Concatenação de características do codificador em características do decodificador de resolução correspondente; preserva o detalhe de alta frequência |
| Conv transposta | "Deconvolução" | Upsampling aprendível; pode produzir artefatos de tabuleiro de xadrez |
| Perda Dice | "Perda de sobreposição" | 1 - 2 |
| mIoU | "Interseção média sobre união" | IoU média entre classes; a métrica padrão da comunidade para segmentação |
| Boundary F1 | "Precisão de borda" | Pontuação F1 computada apenas em pixels de borda; importa para tarefas críticas em precisão |
Leitura Adicional
- U-Net: Convolutional Networks for Biomedical Image Segmentation (Ronneberger et al., 2015) — o artigo original; a figura que todo mundo copia está na página 2
- Fully Convolutional Networks (Long et al., 2015) — o artigo que primeiro tornou a segmentação um problema de conv ponta a ponta
- segmentation_models_pytorch — a referência para segmentação em produção; toda arquitetura padrão mais toda perda padrão
- Lessons learned from training SOTA segmentation (kaggle.com competitions) — um passo a passo de por que TTA, pseudo-rotulagem e pesos de classe importam em dados reais