Phase 04 - Lesson 05

Transfer Learning e Fine-Tuning

Outra pessoa gastou um milhão de horas de GPU ensinando uma rede a reconhecer como são bordas, texturas e partes de objetos. Você deveria pegar essas features emprestadas antes de treinar as suas próprias.

Tipo: Build Linguagens: Python Pré-requisitos: Fase 4 Lição 03 (CNNs), Fase 4 Lição 04 (Classificação de Imagens) Tempo: ~75 minutos

Objetivos de Aprendizagem

  • Distinguir extração de features de fine-tuning e escolher a opção certa com base no tamanho do dataset, na distância de domínio e no orçamento de computação
  • Carregar um backbone pré-treinado, substituir sua cabeça de classificação e treinar apenas a cabeça até um baseline funcional em menos de 20 linhas
  • Descongelar camadas progressivamente com learning rates discriminativos, de modo que features genéricas iniciais recebam atualizações menores do que as específicas da tarefa no final
  • Diagnosticar as três falhas comuns: feature drift por LR alto demais em blocos descongelados, colapso das estatísticas de BN em datasets minúsculos e esquecimento catastrófico

O Problema

Treinar uma ResNet-50 no ImageNet custa cerca de 2.000 horas de GPU. Pouquíssimas equipes têm esse orçamento para cada tarefa que colocam em produção. O que quase toda equipe de fato coloca em produção é um backbone pré-treinado com uma nova cabeça treinada em algumas centenas ou alguns milhares de imagens específicas da tarefa.

Isso não é um atalho. O primeiro bloco conv de qualquer CNN treinada no ImageNet aprende bordas e filtros do tipo Gabor. Os próximos blocos aprendem texturas e motivos simples. Os blocos intermediários aprendem partes de objetos. Os blocos finais aprendem combinações que começam a se parecer com as 1.000 categorias do ImageNet. Os primeiros 90% dessa hierarquia transferem quase inalterados para imagens médicas, inspeção industrial, dados de satélite e qualquer outra tarefa de visão — porque a natureza tem um vocabulário limitado de bordas e texturas. Os últimos 10% são o que você de fato treina.

Acertar a transferência tem três bugs te esperando: destruir features pré-treinadas com um learning rate alto demais, privar o modelo de informação congelando coisas demais e deixar as estatísticas em execução do BatchNorm derivarem em direção a um dataset minúsculo do qual o resto da rede nunca aprendeu. Esta lição percorre cada um deles de propósito.

O Conceito

Extração de features vs fine-tuning

Dois regimes, escolhidos conforme o quanto você confia nas features pré-treinadas e quantos dados você tem.

flowchart TB
    subgraph FE["Extração de features — backbone congelado"]
        FE1["Backbone pré-treinado<br/>(sem gradiente)"] --> FE2["Nova cabeça<br/>(treinada)"]
    end
    subgraph FT["Fine-tuning — ponta a ponta"]
        FT1["Backbone pré-treinado<br/>(LR minúsculo)"] --> FT2["Nova cabeça<br/>(LR normal)"]
    end

    style FE1 fill:#e5e7eb,stroke:#6b7280
    style FE2 fill:#dcfce7,stroke:#16a34a
    style FT1 fill:#fef3c7,stroke:#d97706
    style FT2 fill:#dcfce7,stroke:#16a34a

Regras práticas:

Tamanho do dataset Distância de domínio Receita
< 1k imagens próximo do ImageNet Congele o backbone, treine só a cabeça
1k-10k próximo Congele os 2-3 primeiros estágios, faça fine-tune no resto
10k-100k qualquer Fine-tune ponta a ponta com LR discriminativo
100k+ distante Faça fine-tune em tudo; considere treinar do zero se o domínio for distante o suficiente

"Próximo do ImageNet" significa, grosso modo, fotos RGB naturais com conteúdo do tipo objeto. Tomografias médicas, imagens de satélite vistas de cima e microscopia são domínios distantes — as features ainda ajudam, mas você precisará deixar mais camadas se adaptarem.

Por que congelar funciona, afinal

As features do ImageNet que uma CNN aprende não são especializadas nas 1.000 categorias. Elas são especializadas nas estatísticas de imagens naturais: bordas em orientações específicas, texturas, padrões de contraste, primitivas de forma. Essas estatísticas são estáveis em quase todo domínio visual que um humano consiga nomear. É por isso que um modelo treinado no ImageNet e avaliado zero-shot no CIFAR-10 com apenas uma nova cabeça linear (sem fine-tuning do backbone) chega a 80%+ de acurácia. A cabeça está aprendendo quais das features já aprendidas devem ser ponderadas para esta tarefa.

Learning rates discriminativos

Quando você de fato descongela, as camadas iniciais devem treinar mais devagar do que as finais. As camadas iniciais codificam features genéricas que você quer preservar; as camadas finais codificam estrutura específica da tarefa que você precisa mover bastante.

Receita típica:

  stage 0 (stem + primeiro grupo): lr = base_lr / 100    (praticamente fixo)
  stage 1:                         lr = base_lr / 10
  stage 2:                         lr = base_lr / 3
  stage 3 (último grupo backbone): lr = base_lr
  head:                            lr = base_lr  (ou um pouco maior)

No PyTorch isso é apenas uma lista de grupos de parâmetros passada ao otimizador. Um modelo, cinco learning rates, zero código extra.

O problema do BatchNorm

As camadas BN guardam buffers running_mean e running_var que foram calculados no ImageNet. Se a sua tarefa tem uma distribuição de pixels diferente — iluminação diferente, sensor diferente, espaço de cor diferente — esses buffers estão errados. Três opções em ordem de preferência:

  1. Fine-tune com o BN em modo train. Deixe o BN atualizar suas estatísticas em execução junto com o resto. Escolha padrão quando o dataset da tarefa é de tamanho médio (>= 5k exemplos).
  2. Congele o BN em modo eval. Mantenha as estatísticas do ImageNet e treine apenas os pesos. Correto quando o seu dataset é pequeno o suficiente para que a média móvel do BN fique ruidosa.
  3. Substitua o BN por GroupNorm. Elimina o problema da média móvel por completo. Usado em backbones de detecção e segmentação onde o batch size por GPU é minúsculo.

Errar isso derruba silenciosamente a acurácia em 5-15%.

Design da cabeça

A cabeça de classificação são 1-3 camadas lineares mais um dropout opcional. Todo backbone do torchvision vem com uma cabeça padrão que você substitui:

backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)          # ResNet
backbone.classifier[1] = nn.Linear(..., num_classes)                    # EfficientNet, MobileNet
backbone.heads.head = nn.Linear(..., num_classes)                       # torchvision ViT

Para datasets pequenos, uma única camada linear geralmente basta. Adicionar uma camada oculta (Linear -> ReLU -> Dropout -> Linear) ajuda quando a distribuição da tarefa está mais distante da distribuição de treino do backbone.

Decaimento de LR por camada

Uma versão mais suave do LR discriminativo, usada em fine-tuning moderno (BEiT, DINOv2, fine-tunes de ViT-B). Em vez de agrupar camadas em estágios, dê a cada camada um LR ligeiramente menor do que a camada acima dela:

lr_layer_k = base_lr * decay^(L - k)

Com decay = 0.75 e L = 12 blocos de transformer, o primeiro bloco treina a 0.75^11 ≈ 0.04x do LR da cabeça. Importa mais para fine-tunes de transformer do que para CNNs, onde LRs agrupados por estágio costumam bastar.

O que avaliar

Execuções de transfer learning precisam de dois números que você não acompanharia em uma execução do zero:

  • Acurácia somente pré-treinado — a acurácia da cabeça com o backbone congelado. Este é o seu piso.
  • Acurácia com fine-tune — o mesmo modelo após o treino ponta a ponta. Este é o seu teto.

Se a acurácia com fine-tune for menor do que a somente pré-treinado, você tem um bug de learning rate ou de BN. Sempre imprima as duas.

Construa

Passo 1: Carregar um backbone pré-treinado e inspecioná-lo

import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
print(backbone)
print()
print("classifier head:", backbone.fc)
print("feature dim:", backbone.fc.in_features)

A ResNet18 tem quatro estágios (layer1..layer4) mais um stem e uma cabeça fc. Todo backbone de classificação do torchvision tem uma estrutura análoga.

Passo 2: Extração de features — congele tudo, substitua a cabeça

def make_feature_extractor(num_classes=10):
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    for p in model.parameters():
        p.requires_grad = False
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

model = make_feature_extractor(num_classes=10)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable:>10,}")
print(f"frozen:    {frozen:>10,}")

Apenas model.fc é treinável. O backbone é um extrator de features congelado.

Passo 3: Fine-tuning discriminativo

Um utilitário que constrói grupos de parâmetros com learning rates específicos por estágio.

def discriminative_param_groups(model, base_lr=1e-3, decay=0.3):
    stages = [
        ["conv1", "bn1"],
        ["layer1"],
        ["layer2"],
        ["layer3"],
        ["layer4"],
        ["fc"],
    ]
    groups = []
    for i, names in enumerate(stages):
        lr = base_lr * (decay ** (len(stages) - 1 - i))
        params = [p for n, p in model.named_parameters()
                  if any(n.startswith(k) for k in names)]
        if params:
            groups.append({"params": params, "lr": lr, "name": "_".join(names)})
    return groups

model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 10)
for p in model.parameters():
    p.requires_grad = True

groups = discriminative_param_groups(model)
for g in groups:
    print(f"{g['name']:>10s}  lr={g['lr']:.2e}  params={sum(p.numel() for p in g['params']):>8,}")

decay=0.3 significa que cada estágio treina a 30% da taxa do estágio seguinte. fc recebe base_lr, layer4 recebe 0.3 * base_lr, conv1 recebe 0.3^5 * base_lr ≈ 0.00243 * base_lr. Parece extremo; empiricamente funciona.

Passo 4: Tratamento do BatchNorm

Helper para congelar as estatísticas em execução do BN sem congelar seus pesos.

def freeze_bn_stats(model):
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()
            for p in m.parameters():
                p.requires_grad = False
    return model

Chame-o depois de definir model.train() no início de cada epoch. model.train() coloca tudo em modo de treino; isto reverte apenas as camadas BN.

Passo 5: Um loop mínimo de fine-tuning ponta a ponta

from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F

def fine_tune(model, train_loader, val_loader, device, epochs=5, base_lr=1e-3, freeze_bn=False):
    model = model.to(device)
    groups = discriminative_param_groups(model, base_lr=base_lr)
    optimizer = SGD(groups, momentum=0.9, weight_decay=1e-4, nesterov=True)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

    for epoch in range(epochs):
        model.train()
        if freeze_bn:
            freeze_bn_stats(model)
        tr_loss, tr_correct, tr_total = 0.0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y, label_smoothing=0.1)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            tr_loss += loss.item() * x.size(0)
            tr_total += x.size(0)
            tr_correct += (logits.argmax(-1) == y).sum().item()
        scheduler.step()

        model.eval()
        va_total, va_correct = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(-1)
                va_total += x.size(0)
                va_correct += (pred == y).sum().item()
        print(f"epoch {epoch}  train {tr_loss/tr_total:.3f}/{tr_correct/tr_total:.3f}  "
              f"val {va_correct/va_total:.3f}")
    return model

Cinco epochs com a receita acima no CIFAR-10 levam a ResNet18-IMAGENET1K_V1 de ~70% de acurácia zero-shot em linear-probe para ~93% de acurácia com fine-tune. A cabeça sozinha estabilizaria em torno de 86% sem nunca tocar no backbone.

Passo 6: Descongelamento progressivo

Um cronograma que descongela um estágio por epoch, do fim em direção ao começo. Mitiga o feature drift ao custo de alguns epochs extras.

def progressive_unfreeze_schedule(model):
    stages = ["layer4", "layer3", "layer2", "layer1"]
    yielded = set()

    def start():
        for p in model.parameters():
            p.requires_grad = False
        for p in model.fc.parameters():
            p.requires_grad = True

    def unfreeze(epoch):
        if epoch < len(stages):
            name = stages[epoch]
            yielded.add(name)
            for n, p in model.named_parameters():
                if n.startswith(name):
                    p.requires_grad = True
            return name
        return None

    return start, unfreeze

Chame start() uma vez antes do primeiro epoch. Chame unfreeze(epoch) no início de cada epoch. Reconstrua o otimizador sempre que o conjunto de parâmetros treináveis mudar, caso contrário os params congelados ainda guardam momentos em cache que o confundem.

Use

Para a maioria das tarefas reais, torchvision.models + três linhas basta. O maquinário mais pesado acima importa quando você esbarra nos problemas que os defaults da biblioteca não conseguem resolver.

from torchvision.models import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
model.fc = nn.Linear(model.fc.in_features, num_classes)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

Dois outros defaults de nível de produção:

  • O timm traz ~800 backbones de visão pré-treinados com uma API consistente (timm.create_model("resnet50", pretrained=True, num_classes=10)). Para qualquer fine-tune além do zoo do torchvision, é o padrão.
  • Para transformers, transformers.AutoModelForImageClassification.from_pretrained(name, num_labels=N) te dá ViT / BEiT / DeiT com a mesma semântica de carregamento dos modelos de texto.

Entregue

Esta lição produz:

  • outputs/prompt-fine-tune-planner.md — um prompt que escolhe entre extração de features, fine-tuning progressivo ou ponta a ponta com base no tamanho do dataset, na distância de domínio e no orçamento de computação.
  • outputs/skill-freeze-inspector.md — uma skill que, dado um modelo PyTorch, reporta quais parâmetros são treináveis, quais camadas BatchNorm estão em modo eval e se o otimizador está de fato recebendo os parâmetros treináveis.

Exercícios

  1. (Fácil) Treine uma ResNet18 como linear probe (backbone congelado) e como full fine-tune no mesmo dataset CIFAR sintético. Reporte as duas acurácias lado a lado. Explique qual gap te diz que as features transferem bem e qual te diz que não transferem.
  2. (Médio) Introduza um bug de propósito: defina base_lr = 1e-1 no estágio do backbone em vez da cabeça. Mostre a loss de treino explodir, depois recupere aplicando o helper discriminative_param_groups. Registre o LR no qual cada estágio começa a divergir.
  3. (Difícil) Pegue um dataset de imagens médicas (ex.: CheXpert-small, PatchCamelyon ou HAM10000) e compare três regimes: (a) backbone pré-treinado no ImageNet congelado + cabeça linear; (b) fine-tune ponta a ponta do pré-treinado no ImageNet; (c) treino do zero. Reporte a acurácia e o custo de computação de cada um. A partir de que tamanho de dataset o treino do zero se torna competitivo?

Termos-chave

Termo O que as pessoas dizem O que de fato significa
Extração de features "Congelar e treinar a cabeça" Parâmetros do backbone congelados, apenas a nova cabeça de classificação recebe gradiente
Fine-tuning "Retreinar ponta a ponta" Todos os parâmetros treináveis, geralmente com LR muito menor do que no treino do zero
LR discriminativo "LR menor para camadas iniciais" Grupos de parâmetros do otimizador onde o LR dos estágios iniciais é uma fração do LR dos estágios finais
Decaimento de LR por camada "Gradiente suave de LR" LR por camada multiplicado por decay^(L - k); comum em fine-tunes de transformer
Esquecimento catastrófico "O modelo perdeu o ImageNet" Um LR alto demais sobrescreve features pré-treinadas antes que o sinal da nova tarefa seja aprendido
Drift das estatísticas de BN "A running mean está errada" running_mean/var do BatchNorm calculados em uma distribuição diferente da tarefa atual, prejudicando a acurácia silenciosamente
Linear probe "Backbone congelado + cabeça linear" Avaliação de features pré-treinadas — acurácia do melhor classificador linear sobre a representação congelada
Colapso catastrófico "Tudo prevê uma única classe" Acontece ao fazer fine-tune com um LR alto o suficiente para destruir features antes que os gradientes da cabeça consigam se estabilizar

Leitura Adicional

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).