Phase 12 - Lesson 04

Flamingo e Gated Cross-Attention para VLMs Few-Shot

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

O Flamingo (2022) da DeepMind fez duas coisas antes de qualquer outro. Ele mostrou que um único modelo poderia processar sequências arbitrariamente intercaladas de imagens, vídeos e texto. E mostrou que os VLMs podiam aprender no contexto (in-context learning) — forneça um prompt few-shot com três pares de exemplo (imagem, legenda) e o modelo legenda uma nova imagem sem nenhum passo de gradiente. O mecanismo: camadas de gated cross-attention, inseridas entre as camadas existentes do LLM congelado, com um gate tanh aprendido que começa em zero para que a capacidade de texto do LLM seja preservada na inicialização. Esta lição detalha a arquitetura do Perceiver resampler e do gated cross-attention do Flamingo — o ancestral das entradas intercaladas do Gemini e dos tokens visuais do Idefics2.

Tipo: Aprender Linguagens: Python (stdlib, gated cross-attention + Perceiver resampler demo) Pré-requisitos: Fase 12 · 03 (BLIP-2 Q-Former) Tempo: ~120 minutos

Objetivos de Aprendizado

  • Explicar como o gated cross-attention preserva a capacidade de texto de um LLM congelado na inicialização via tanh(gate) = 0.
  • Explicar o funcionamento de um Perceiver resampler: N patches de imagem → K consultas "latentes" fixas via cross-attention.
  • Descrever como o Flamingo lida com sequências intercaladas de imagem-texto com mascaramento causal que respeita o posicionamento da imagem.
  • Reproduzir uma estrutura de prompt multimodal few-shot (3 exemplos de imagem-legenda seguidos de uma imagem de consulta).

O Problema

O BLIP-2 alimenta 32 tokens visuais na camada de entrada de um LLM congelado. Isso funciona para uma imagem por prompt. Mas e se você quiser alimentar muitas imagens intercaladas com texto, como em "aqui está a imagem A, legende-a; aqui está a imagem B, legende-a; agora aqui está a imagem C, legende-a"? A self-attention do LLM precisaria lidar com tokens de imagem e tokens de texto em um único fluxo, e a questão de quais posições podem prestar atenção a quais imagens se torna complicada.

O resposta do Flamingo: não altere em nada o fluxo de entrada do LLM. Insira camadas extras de cross-attention entre os blocos LLM existentes. Os tokens de texto ainda fluem pela self-attention causal do LLM como sempre. A cada poucos blocos de LLM, os tokens de texto também realizam cross-attention com os recursos de imagem por meio de uma nova camada com gate. O gate (inicializado em zero) significa que, no passo zero, as novas camadas são no-ops (não realizam operações) — o modelo se comporta exatamente como o LLM pré-treinado. À medida que o treinamento avança, o gate se abre e as informações visuais começam a fluir.

A segunda pergunta que o Flamingo respondeu: como lidar com um número variável de imagens (0, 1 ou muitas) por prompt? Um Perceiver resampler — um pequeno módulo de cross-attention que recebe qualquer quantidade de patches que você tiver e produz um número fixo de tokens latentes visuais. A camada de cross-attention do LLM vê o mesmo formato, independentemente de quantas imagens estejam no prompt.

O Conceito

O LLM congelado

O Flamingo começa com um LLM Chinchilla 70B congelado. Todos os pesos de 70B permanecem intocados. A self-attention de texto e o FFN existentes operam normalmente.

Perceiver resampler

Para cada imagem no prompt, o ViT produz N tokens de patch. O Perceiver resampler tem K latentes fixos e treináveis (o Flamingo usa K=64). Cada bloco do resampler consiste em duas subetapas:

  1. Cross-attention: os K latentes atendem sobre os N tokens de patch (Q dos latentes, K/V dos patches).
  2. Self-attention + FFN dentro dos latentes.

Após 6 blocos do resampler, a saída é de K=64 tokens visuais de dimensão 1024, independentemente de quantos patches o ViT produziu. Uma imagem de 224x224 (196 patches) e uma imagem de 480x480 (900 patches) saem ambas como 64 tokens do resampler.

Para vídeo, o resampler é aplicado temporalmente: os patches de cada quadro produzem 64 latentes, e uma codificação posicional temporal permite que o modelo distinga t=0 de t=N. O vídeo completo se torna T * 64 tokens visuais.

Gated cross-attention

Entre cada M camadas do LLM congelado (o Flamingo usa M=4), insira um novo bloco de gated cross-attention:

x_after_llm_block = llm_block(x_before)
cross = cross_attn(x_after, resampler_output)
gated = tanh(alpha) * cross + x_after
x_before_next_block = gated
  • alpha é um escalar treinável inicializado em zero.
  • tanh(0) = 0, portanto, na inicialização, a ramificação com gate contribui com zero.
  • À medida que alpha se afasta de zero, a contribuição do cross-attention cresce suavemente.
  • A conexão residual significa que mesmo um gate totalmente aberto não substitui a representação de texto do LLM; apenas adiciona informações visuais por cima.

Esta é a escolha de design mais importante no Flamingo: o condicionamento visual é aditivo, com gate e zero na inicialização. Um Flamingo no passo 0 é um Chinchilla 70B perfeito em entradas apenas de texto.

Masked cross-attention para entradas intercaladas

Em um prompt como " legenda A legenda B ?", cada token de texto deve ver apenas as imagens que vieram antes dele na sequência. A máscara de cross-attention garante que: o token de texto na posição t atenda apenas aos tokens do resampler de imagem cujo índice de imagem i < i_t, onde i_t é a imagem mais recente antes da posição t. "Ver apenas a última imagem anterior" ou "ver todas as imagens anteriores" são escolhas válidas; o Flamingo escolheu a primeira.

Aprendizado de poucos exemplos no contexto (In-context few-shot learning)

Um prompt do Flamingo se parece com:

<image1> A photo of a cat. <image2> A photo of a dog. <image3> A photo of a

O modelo vê o padrão de conclusão e gera "bird" (ou o que quer que a imagem3 mostre). Sem passos de gradiente. A capacidade de aprendizado em contexto do LLM congelado se propaga através do gated cross-attention — esse é o ponto alto do artigo e a razão de sua importância.

Dados de treinamento

O Flamingo foi treinado em três conjuntos de dados (datasets):

  1. MultiModal MassiveWeb (M3W): 43 milhões de páginas da web com imagens e texto intercalados, reconstruindo a ordem de leitura.
  2. Pares de Imagem-Texto (ALIGN + LTIP): 4,4 bilhões de pares.
  3. Pares de Vídeo-Texto (VTP): 27 milhões de clipes curtos de vídeo.

O OBELICS (2023) é uma reprodução aberta do corpus da web intercalado, no qual o Idefics, Idefics2 e a maioria dos modelos abertos baseados no Flamingo são treinados.

OpenFlamingo e Otter

O OpenFlamingo (2023) é a reprodução aberta. A arquitetura é idêntica (Perceiver resampler + gated cross-attention em LLaMA ou MPT congelados). Checkpoints em 3B, 4B, 9B. A qualidade é inferior à do Flamingo devido a um LLM base menor e menos dados.

O Otter (2023) baseia-se no OpenFlamingo com ajuste de instruções (instruction tuning) no MIMIC-IT (um conjunto de dados de instruções multimodais), mostrando que o gated cross-attention também funciona para o seguimento de instruções.

Os descendentes

  • Idefics / Idefics2 / Idefics3: a linhagem de gated cross-attention da Hugging Face, progressivamente mais simples (o Idefics2 removeu o resampler em favor de tokens de patch diretos com pooling adaptativo).
  • Transição Flamingo-para-Chameleon: por volta de 2024, muitas equipes migraram para a fusão precoce (early-fusion) (Lição 12.11); o gated cross-attention no estilo Flamingo permanece em produção onde o congelamento do backbone é necessário.
  • Entrada intercalada do Gemini: herda conceitualmente a flexibilidade do formato intercalado do Flamingo, embora o mecanismo exato seja proprietário.

Comparação com o BLIP-2

BLIP-2 Flamingo
Ponte visual Q-Former apenas na entrada Gated cross-attention a cada M camadas
Tokens visuais 32 por imagem 64 por imagem por camada de cross-attn
LLM congelado Sim Sim
Poucos disparos no contexto (Few-shot) Fraco Forte — a peça central do artigo
Entradas intercaladas Sem suporte nativo Sim, o objetivo de design
Dados de treinamento 130M pares 1,3B pares + 43M páginas intercaladas
Contagem de parâmetros 188M treinados ~10B treinados (camadas de cross-attn)
Computação Dias em 8 A100s Semanas em milhares de TPUv4

Escolha o BLIP-2 para VQA de imagem única com orçamento limitado. Escolha o Flamingo/Idefics2 para raciocínio intercalado, few-shot ou multi-imagem.

Use

O arquivo code/main.py demonstra:

  1. Um Perceiver resampler em 36 tokens de patch falsos com 8 latentes treináveis (cross-attention em Python puro).
  2. Uma etapa de gated cross-attention com alpha = 0 → a saída é igual à entrada (LLM inalterado), depois alpha = 2.0 → contribuição visual misturada.
  3. Um construtor de máscara intercalada que produz a máscara de atenção 2D para uma sequência "(imagem 1) (texto 1) (imagem 2) (texto 2)".

Entregue

Esta lição produz outputs/skill-gated-bridge-diagnostic.md. Dada a configuração de um VLM aberto (resampler Sim/Não, frequência de cross-attn, esquema de gate), ela identifica os elementos da linhagem Flamingo e explica a estratégia de congelamento. Útil para depurar por que um ajuste fino degradou o desempenho do texto (resposta: o gate se abriu rápido demais).

Exercícios

  1. Calcule a contagem de parâmetros visuais do Flamingo-9B: LLM de 9B + 1,4B de camadas de gated cross-attention + 64M de resampler. Que fração do total de parâmetros é treinada?

  2. Implemente o residual com gate y = tanh(alpha) * cross + x no PyTorch. Mostre experimentalmente que, com alpha=0, y==x exatamente na inicialização.

  3. Leia a Seção 3.2 do OpenFlamingo (arXiv:2308.01390) sobre como eles lidam com várias imagens em um lote (batch) quando cada prompt tem uma contagem de imagens diferente. Descreva a estratégia de preenchimento (padding).

  4. Por que a máscara de cross-attention do Flamingo permite que um token de texto atenda apenas à imagem anterior mais recente, em vez de a todas as imagens anteriores? Leia o artigo do Flamingo, Seção 2.4, e explique essa compensação (tradeoff).

  5. Few-shot no contexto: construa um prompt com 4 exemplos de "imagem → cor do objeto principal" para uma nova variante do Flamingo. Descreva o padrão de precisão esperado conforme você varia o número de exemplos de 0 a 8.

Termos-Chave

Termo O que as pessoas dizem O que realmente significa
Perceiver resampler "Cross-attention de latente fixa" Módulo que produz K tokens fixos a partir de um número variável de patches de entrada
Gated cross-attention "Ponte com gate Tanh" Camada residual y = tanh(alpha)*cross + x, alpha treinável, init 0
Entrada intercalada (Interleaved input) "Sequência mista" Formato de prompt com imagens e texto misturados livremente na ordem de leitura
LLM congelado "Sem gradientes no LLM" Os pesos do LLM de texto não são atualizados; apenas as camadas do resampler + cross-attn são treinadas
Poucos exemplos (Few-shot) "Exemplos no contexto" Fornece alguns pares (imagem, resposta) no prompt; o modelo generaliza sem ajuste fino
OBELICS "Corpus da web intercalado" Dataset aberto de 141M de páginas da web com imagens e texto na ordem de leitura
Chinchilla "Base congelada de 70B" O LLM de texto congelado do Flamingo, do artigo Chinchilla da DeepMind
Cronograma do gate (Gate schedule) "Como o alfa se move" A taxa na qual o gate de cross-attention se abre durante o treinamento
Frequência de cross-attn "A cada M camadas" A frequência com que um bloco de gated cross-attention é inserido; o Flamingo usa M=4
OpenFlamingo "Reprodução aberta" Checkpoint aberto da MosaicML/LAION de 3-9B; arquitetura idêntica à do Flamingo

Leitura Adicional

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).