Phase 12 - Lesson 04

Flamingo y Gated Cross-Attention para VLMs Few-Shot

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

Flamingo (2022) de DeepMind hizo dos cosas antes que nadie. Demostró que un solo modelo podía procesar secuencias arbitrariamente intercaladas de imágenes, videos y texto. Y demostró que los VLMs podían aprender en contexto (in-context learning): al proporcionar un prompt few-shot con tres pares de ejemplos (imagen, leyenda), el modelo genera la leyenda de una nueva imagen sin ningún paso de gradiente. El mecanismo: capas de gated cross-attention, insertadas entre las capas existentes del LLM congelado, con un gate tanh aprendido que comienza en cero para que la capacidad de texto del LLM se preserve en la inicialización. Esta lección recorre la arquitectura de Perceiver resampler y gated cross-attention de Flamingo, el ancestro de las entradas intercaladas de Gemini y los tokens visuales de Idefics2.

Tipo: Aprender Lenguajes: Python (stdlib, gated cross-attention + Perceiver resampler demo) Prerrequisitos: Fase 12 · 03 (BLIP-2 Q-Former) Tiempo: ~120 minutos

Objetivos de Aprendizaje

  • Explicar cómo el gated cross-attention preserva la capacidad de texto de un LLM congelado en la inicialización a través de tanh(gate) = 0.
  • Comprender el funcionamiento de un Perceiver resampler: N patches de imagen → K consultas "latentes" fijas mediante cross-attention.
  • Describir cómo Flamingo maneja secuencias intercaladas de imagen y texto con enmascaramiento causal que respeta la ubicación de la imagen.
  • Reproducir una estructura de prompt multimodal few-shot (3 ejemplos de imagen-leyenda y luego una imagen de consulta).

El Problema

BLIP-2 introduce 32 tokens visuales en la capa de entrada de un LLM congelado. Esto funciona para una imagen por prompt. Pero, ¿y si deseas introducir muchas imágenes intercaladas con texto, como en "aquí está la imagen A, ponle una leyenda; aquí está la imagen B, ponle una leyenda; ahora aquí está la imagen C, ponle una leyenda"? La self-attention del LLM necesitaría procesar tokens de imagen y tokens de texto en un solo flujo, y la cuestión de qué posiciones pueden atender a qué imágenes se vuelve complicada.

La respuesta de Flamingo: no cambiar el flujo de entrada del LLM en absoluto. Insertar capas adicionales de cross-attention entre los bloques de LLM existentes. Los tokens de texto siguen fluyendo a través de la self-attention causal del LLM como siempre. Cada ciertos bloques del LLM, los tokens de texto también realizan cross-attention con las características de la imagen mediante una nueva capa con gate. El gate (inicializado en cero) significa que en el paso cero las nuevas capas no realizan ninguna operación (no-ops): el modelo se comporta exactamente igual que el LLM preentrenado. A medida que avanza el entrenamiento, el gate se abre y la información visual comienza a fluir.

La segunda pregunta que Flamingo respondió: ¿cómo se maneja un número variable de imágenes (0, 1 o muchas) por prompt? Un Perceiver resampler: un pequeño módulo de cross-attention que toma cualquier cantidad de patches que tengas y produce un número fijo de tokens latentes visuales. La capa de cross-attention del LLM ve la misma forma sin importar cuántas imágenes haya en el prompt.

El Concepto

El LLM congelado

Flamingo comienza con un LLM Chinchilla 70B congelado. Todos los pesos de 70B se mantienen intactos. La self-attention de texto y FFN existentes funcionan con normalidad.

Perceiver resampler

Para cada imagen en el prompt, el ViT produce N tokens de patch. El Perceiver resampler tiene K latentes fijos entrenables (Flamingo utiliza K=64). Cada bloque de resampler consta de dos subpasos:

  1. Cross-attention: los K latentes atienden sobre los N tokens de patch (Q de los latentes, K/V de los patches).
  2. Self-attention + FFN dentro de los latentes.

Después de 6 bloques de resampler, la salida es K=64 tokens visuales de dimensión 1024, independientemente de cuántos patches haya producido el ViT. Tanto una imagen de 224x224 (196 patches) como una de 480x480 (900 patches) salen como 64 tokens de resampler.

Para video, el resampler se aplica de forma temporal: los patches de cada fotograma producen 64 latentes, y una codificación posicional temporal permite al modelo distinguir t=0 de t=N. El video completo se convierte en T * 64 tokens visuales.

Gated cross-attention

Entre cada M capas del LLM congelado (Flamingo utiliza M=4), se inserta un nuevo bloque de gated cross-attention:

x_after_llm_block = llm_block(x_before)
cross = cross_attn(x_after, resampler_output)
gated = tanh(alpha) * cross + x_after
x_before_next_block = gated
  • alpha es un escalar entrenable inicializado en cero.
  • tanh(0) = 0, por lo que en la inicialización la rama con gate contribuye con cero.
  • A medida que alpha se aleja de cero, la contribución del cross-attention crece suavemente.
  • La conexión residual significa que incluso un gate completamente abierto no sobrescribe la representación de texto del LLM; solo añade información visual por encima.

Esta es la decisión de diseño más importante en Flamingo: el condicionamiento visual es aditivo, con gate y cero en la inicialización. Un Flamingo en el paso 0 es un Chinchilla 70B perfecto en entradas únicamente de texto.

Masked cross-attention para entradas intercaladas

En un prompt como " leyenda A leyenda B ?", cada token de texto solo debería ver las imágenes que aparecieron antes de él en la secuencia. La máscara de cross-attention impone que: el token de texto en la posición t atienda solo a los tokens del resampler de la imagen cuyo índice de imagen i < i_t, donde i_t es la imagen más reciente antes de la posición t. "Ver solo la última imagen anterior" o "ver todas las imágenes anteriores" son opciones válidas; Flamingo eligió la primera.

Aprendizaje few-shot en contexto (In-context few-shot learning)

Un prompt de Flamingo tiene el siguiente aspecto:

<image1> A photo of a cat. <image2> A photo of a dog. <image3> A photo of a

El modelo ve el patrón de completado y genera "bird" (or lo que sea que muestre la imagen3). Sin pasos de gradiente. La capacidad de aprendizaje en contexto del LLM congelado se transmite a través del gated cross-attention; este es el punto clave del artículo y por qué es importante.

Datos de entrenamiento

Flamingo se entrenó en tres conjuntos de datos (datasets):

  1. MultiModal MassiveWeb (M3W): 43 millones de páginas web con imágenes y texto intercalados, reconstruyendo el orden de lectura.
  2. Pares de imagen y texto (ALIGN + LTIP): 4400 millones de pares.
  3. Pares de video y texto (VTP): 27 millones de videoclips cortos.

OBELICS (2023) es una reproducción abierta del corpus web intercalado en el que se entrenan Idefics, Idefics2 y la mayoría de los modelos abiertos "estilo Flamingo".

OpenFlamingo y Otter

OpenFlamingo (2023) es la reproducción abierta. Su arquitectura es idéntica (Perceiver resampler + gated cross-attention en LLaMA o MPT congelados). Checkpoints en 3B, 4B, 9B. Su calidad es inferior a la de Flamingo debido a un LLM base más pequeño y menos datos.

Otter (2023) se basa en OpenFlamingo con ajuste de instrucciones (instruction tuning) en MIMIC-IT (un conjunto de datos de instrucciones multimodales), lo que demuestra que el gated cross-attention también funciona para seguir instrucciones.

Los descendientes

  • Idefics / Idefics2 / Idefics3: el linaje de gated cross-attention de Hugging Face, progresivamente más simple (Idefics2 eliminó el resampler a favor de tokens de patch directos con pooling adaptativo).
  • Transición de Flamingo a Chameleon: para 2024, muchos equipos migraron a la fusión temprana (early-fusion) (Lección 12.11); el gated cross-attention al estilo Flamingo sigue utilizándose en producción cuando se requiere congelar el backbone.
  • Entrada intercalada de Gemini: herda conceptualmente la flexibilidad del formato intercalado de Flamingo, aunque el mecanismo exacto es propietario.

Comparación con BLIP-2

BLIP-2 Flamingo
Puente visual Q-Former una vez en la entrada Gated cross-attention cada M capas
Tokens visuales 32 por imagen 64 por imagen por capa de cross-attn
LLM congelado
Few-shot en contexto Débil Fuerte — la pieza central del artículo
Entradas intercaladas Sin soporte nativo Sí, el objetivo de diseño
Datos de entrenamiento 130M de pares 1.3B de pares + 43M de páginas intercaladas
Cantidad de parámetros 188M entrenados ~10B entrenados (capas de cross-attn)
Cómputo Días en 8 A100s Semanas en miles de TPUv4

Elige BLIP-2 para VQA de una sola imagen con presupuesto limitado. Elige Flamingo/Idefics2 para razonamiento intercalado, few-shot o de múltiples imágenes.

Uso

El archivo code/main.py demuestra:

  1. Un Perceiver resampler en 36 tokens de patch falsos con 8 latentes entrenables (cross-attention en Python puro).
  2. Un paso de gated cross-attention con alpha = 0 → la salida es igual a la entrada (LLM sin cambios), luego alpha = 2.0 → la contribución visual se mezcla.
  3. Un constructor de máscara intercalada que Unicode produce la máscara de atención 2D para una secuencia "(imagen 1) (texto 1) (imagen 2) (texto 2)".

Entrega

Esta lección produce outputs/skill-gated-bridge-diagnostic.md. Dada la configuración de un VLM abierto (resampler Sí/No, frecuencia de cross-attn, esquema de gate), identifica los elementos del linaje Flamingo y explica la estrategia de congelación. Es útil para depurar por que un ajuste fino degradó el rendimiento del texto (respuesta: el gate se abrió demasiado rápido).

Ejercicios

  1. Calcula la cantidad de parámetros visuales de Flamingo-9B: LLM de 9B + 1.4B de capas de gated cross-attention + 64M de resampler. ¿Qué fracción del total de parámetros está entrenada?

  2. Implementa el residual con gate y = tanh(alpha) * cross + x en PyTorch. Demuestra experimentalmente que con alpha=0, y==x exactamente en la inicialización.

  3. Lee la Sección 3.2 de OpenFlamingo (arXiv:2308.01390) sobre cómo manejan múltiples imágenes en un lote (batch) cuando cada prompt tiene una cantidad diferente de imágenes. Describe la estrategia de relleno (padding).

  4. ¿Por qué la máscara de cross-attention de Flamingo permite que un token de texto atienda solo a la imagen anterior más reciente en lugar de a todas las imágenes anteriores? Lee el artículo de Flamingo, Sección 2.4, y explica la compensación (tradeoff).

  5. Few-shot en contexto: construye un prompt con 4 ejemplos de "imagen → color del objeto principal" para una nueva variante de Flamingo. Describe el patrón de precisión esperado a medida que varías la cantidad de ejemplos de 0 a 8.

Términos Clave

Término Lo que la gente dice Lo que realmente significa
Perceiver resampler "Cross-attention de latente fija" Módulo que produce K tokens fijos a partir de un número variable de patches de entrada
Gated cross-attention "Puente con gate Tanh" Capa residual y = tanh(alpha)*cross + x, alpha entrenable, init 0
Entrada intercalada (Interleaved input) "Secuencia mixta" Formato de prompt con imágenes y texto mezclados libremente en el orden de lectura
LLM congelado "Sin gradientes en el LLM" Los pesos del LLM de texto no se actualizan; solo se entrenan las capas de resampler + cross-attn
Few-shot "Ejemplos en contexto" Proporciona unos pocos pares (imagen, respuesta) en el prompt; el modelo generaliza sin ajuste fino
OBELICS "Corpus web intercalado" Dataset abierto de 141M de páginas web con imágenes y texto en el orden de lectura
Chinchilla "Base congelada de 70B" El LLM de texto congelado de Flamingo, del artículo Chinchilla de DeepMind
Cronograma del gate (Gate schedule) "Cómo se mueve alfa" La tasa a la que se abre el gate de cross-attention durante el entrenamiento
Frecuencia de cross-attn "Cada M capas" Qué tan a menudo se inserta un bloque de gated cross-attention; Flamingo usa M=4
OpenFlamingo "Reproducción abierta" Checkpoint abierto de MosaicML/LAION de 3-9B; arquitectura idéntica a Flamingo

Lectura Adicional

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).