Phase 07 - Lesson 15

Variantes de Atenção — Janela Deslizante, Esparsa, Diferencial

This lesson includes a graded coding exercise that runs in your browser, unlocked with lifetime access.

A atenção total é um círculo. Cada token vê cada token, e a memória paga o preço. Quatro variantes dobram a forma do círculo e recuperam metade do custo.

Tipo: Build Linguagens: Python Pré-requisitos: Phase 7 · 02 (Self-Attention), Phase 7 · 03 (Multi-Head), Phase 7 · 12 (KV Cache / Flash Attention) Tempo: ~60 minutos

O Problema

A atenção total custa O(N²) em memória e O(N²) em computação em relação ao comprimento da sequência. Para um Llama 3 70B com contexto de 128K, isso significa 16 bilhões de entradas de atenção por camada, multiplicado por 80 camadas. O Flash Attention (Lição 12) oculta a memória de ativação O(N²), mas não altera o custo aritmético — cada token ainda atende a todos os outros tokens.

Três classes de variantes alteram a própria topologia da matriz de atenção:

  1. Atenção por janela deslizante (Sliding Window Attention - SWA). Cada token atende a uma janela fixa de vizinhos, não a todo o prefixo. A memória e a computação caem para O(N · W), onde W é o tamanho da janela. Gemma 2/3, primeiras camadas do Mistral 7B, Phi-3-Long.
  2. Atenção esparsa / em blocos (Sparse / block attention). Apenas pares selecionados (i, j) recebem pontuação; o restante é forçado a peso zero. Longformer, BigBird, OpenAI sparse transformer.
  3. Atenção diferencial (Differential attention). Computa dois mapas de atenção com projeções Q/K separadas e subtrai um do outro. Elimina o "sumidouro de atenção" (attention sink) que drena peso para os primeiros tokens. DIFF Transformer da Microsoft (2024).

Essas variantes coexistem. Um modelo de fronteira em 2026 frequentemente as mistura: a maioria das camadas é SWA-1024, cada quinta camada é de atenção total global, e um punhado delas são cabeças diferenciais que limpam a recuperação de informações. A proporção de 5:1 de SWA para global do Gemma 3 é o padrão atual dos livros-texto.

O Conceito

Atenção por Janela Deslizante (SWA)

Cada query na posição i atende apenas às posições em [i - W, i] (SWA causal) ou [i - W/2, i + W/2] (bidirecional). Tokens fora da janela recebem -inf na matriz de pontuação (score matrix).

full causal:           sliding window (W=4):
positions 0-7          positions 0-7, W=4
    0 1 2 3 4 5 6 7        0 1 2 3 4 5 6 7
0 | x                0 |  x
1 | x x              1 |  x x
2 | x x x            2 |  x x x
3 | x x x x          3 |  x x x x
4 | x x x x x        4 |    x x x x
5 | x x x x x x      5 |      x x x x
6 | x x x x x x x    6 |        x x x x
7 | x x x x x x x x  7 |          x x x x

Para N = 8192 e W = 1024, a matriz de pontuação tem em média 1024 × 8192 linhas não nulas — uma redução de 8×.

O KV cache encolhe com SWA. Apenas os últimos W tokens de K e V precisam ser mantidos por camada. Para uma configuração semelhante ao Gemma-3 (janela de 1024, contexto de 128K), o KV cache cai 128×.

Custo de qualidade. Modelos puramente baseados em SWA têm dificuldade com a recuperação de informações de longo alcance. A solução: intercalar camadas de SWA com camadas de atenção total (global). O Gemma 3 usa uma proporção SWA:global de 5:1. O Mistral 7B usava uma pilha de SWA causal onde a informação "flui para a frente" através de janelas sobrepostas — cada camada estende o campo receptivo efetivo em W e, após L camadas, o modelo pode atender a L × W tokens passados.

Atenção Esparsa / em Blocos

Escolhe-se um padrão de esparsidade de tamanho N × N previamente. Três formatos clássicos:

  • Local + strided (OpenAI sparse transformer). Atende aos últimos W tokens mais cada stride-ésimo token anterior. Captura tanto o contexto local quanto o de longo alcance com complexidade de computação O(N · sqrt(N)).
  • Longformer / BigBird. Janela local + um pequeno conjunto de tokens globais (ex: [CLS]) que atendem a todos e recebem atenção de todos + conexões esparsas aleatórias. Permite empiricamente o dobro do contexto com qualidade equivalente.
  • Native Sparse Attention (DeepSeek, 2025). Aprende quais blocos de (Q, K) são importantes; ignora os blocos de zeros no nível do kernel. Compatível com FlashAttention.

A atenção esparsa é uma história de engenharia de kernel. A matemática é simples (mascarar a matriz de pontuação); a vantagem real vem de nunca carregar as entradas zeradas para a SRAM. O FlashAttention-3 e a API FlexAttention de 2026 tornam os padrões esparsos personalizados cidadãos de primeira classe no PyTorch.

Atenção Diferencial (DIFF Transformer, 2024)

A atenção regular sofre com o problema do "sumidouro de atenção" (attention sink): o softmax força cada linha a somar 1, então queries que não querem atender a nada em particular acabam concentrando peso no primeiro token (ou nos primeiros). Isso consome capacidade de representação que deveria ir para o conteúdo real.

A atenção diferencial resolve isso computando dois mapas de atenção e subtraindo-os:

A1 = softmax(Q1 K1^T / √d)
A2 = softmax(Q2 K2^T / √d)
DiffAttn = (A1 - λ · A2) V

onde λ é um escalar aprendido (tipicamente 0.5–0.8). A1 captura os pesos do conteúdo real; A2 captura o sumidouro. A subtração cancela o sumidouro e realoca o peso para os tokens relevantes.

Resultados reportados (Microsoft 2024): perplexidade 5–10% menor, contexto efetivo 1.5–2× mais longo para o mesmo comprimento de treino, e recuperação do tipo "agulha no palheiro" (needle-in-a-haystack) muito mais nítida.

Comparativo de Variantes

Variante Computação KV cache Qualidade vs. total Uso em produção
Full attention O(N²) O(N) por camada baseline camada padrão de todo modelo
SWA (janela 1024) O(N·W) O(W) por camada -0.1 ppl, bom com camadas globais Gemma 2/3, Phi-3-Long
Local + strided esparsa O(N·√N) misto similar a SWA OpenAI sparse transformer, Longformer
BigBird (local + global + aleatória) O(N) aprox. misto iguala a total com 2× contexto BERT inicial de contexto longo
Native Sparse (DeepSeek-V3.2) O(N · fração ativa) O(N) dentro de 0.05 ppl DeepSeek-V3.2, 2025
Diferencial O(2·N²) O(2N) -5 a -10% ppl DIFF Transformer, modelos do início de 2026

Implemente

Consulte code/main.py. Implementamos um comparador de máscara causal que exibe as atenções total, SWA, local+strided e diferencial lado a lado em uma sequência de brinquedo.

Passo 1: máscara causal completa (baseline)

def causal_mask(n):
    return [[0.0 if j <= i else float("-inf") for j in range(n)] for i in range(n)]

Baseline da Lição 07. Triangular inferior; peso zero acima da diagonal.

Passo 2: máscara causal de janela deslizante

def swa_mask(n, window):
    M = [[float("-inf")] * n for _ in range(n)]
    for i in range(n):
        lo = max(0, i - window + 1)
        for j in range(lo, i + 1):
            M[i][j] = 0.0
    return M

Um parâmetro — window. Para window >= n, você recupera a atenção causal completa. Para window = 1, cada token atende apenas a si mesmo.

Passo 3: máscara esparsa local + strided

def strided_mask(n, window, stride):
    M = [[float("-inf")] * n for _ in range(n)]
    for i in range(n):
        lo = max(0, i - window + 1)
        for j in range(lo, i + 1):
            M[i][j] = 0.0
        for j in range(0, i + 1, stride):
            M[i][j] = 0.0
    return M

Janela local densa mais cada stride-ésimo token de volta ao início da sequência. O campo receptivo cresce em passos logarítmicos com camadas adicionais.

Passo 4: atenção diferencial

def diff_attention(Q1, K1, Q2, K2, V, lam):
    A1 = softmax_causal(Q1 @ K1.T / sqrt_d)
    A2 = softmax_causal(Q2 @ K2.T / sqrt_d)
    return (A1 - lam * A2) @ V

Duas passagens de atenção, subtraídas com um coeficiente de mistura aprendido. No código, comparamos o mapa de calor do sumidouro de atenção de uma única atenção contra a diferencial e observamos o sumidouro colapsar.

Passo 5: tamanhos do KV cache

Imprima o tamanho do cache por camada com N = 131072 para cada variante. As variantes SWA e esparsas caem em 10–100×. A diferencial dobra. Pague sua conta de memória conscientemente.

Use

Padrões de produção de 2026:

from transformers import AutoModelForCausalLM
# Gemma 3 mistura SWA (window=1024) e camadas globais na proporção 5:1.
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-27b-it")
# print(model.config.sliding_window, model.config.layer_types)

O FlexAttention no PyTorch 2.5+ aceita uma função de máscara:

from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def swa_pattern(b, h, q_idx, kv_idx):
    return (q_idx - kv_idx < 1024) & (q_idx >= kv_idx)

mask = create_block_mask(swa_pattern, B=batch, H=heads, Q_LEN=n, KV_LEN=n)
out = flex_attention(q, k, v, block_mask=mask)

Isso compila para um kernel Triton personalizado. Fica a menos de 10% da velocidade do FlashAttention-3 para padrões comuns, e a função de máscara é um invocável Python.

Quando escolher cada uma:

  • Atenção total pura — todas as camadas até um contexto de aproximadamente 16K, ou quando a qualidade de recuperação é primordial.
  • Mix SWA + global — contextos longos (>32K), treino e inferência limitados por memória. O padrão de 2026 acima de 32K.
  • Atenção esparsa em blocos — kernel personalizado, padrão personalizado. Reservado para cargas de trabalho especializadas (recuperação, áudio).
  • Atenção diferencial — qualquer carga de trabalho onde a contaminação por sumidouros de atenção prejudique (RAG de contexto longo, needle-in-a-haystack).

Envie

Consulte outputs/skill-attention-variant-picker.md. A skill escolhe uma topologia de atenção para um novo modelo considerando a extensão de contexto alvo, demandas de recuperação de informação e o perfil de computação de treino/inferência.

Exercícios

  1. Fácil. Execute code/main.py. Verifique se a SWA com window=4 zera tudo o que estiver fora dos últimos 4 tokens por linha. Verifique se window=n reproduz a atenção causal completa de forma idêntica bit a bit.
  2. Médio. Implemente SWA causal com window=1024 no topo do projeto final da Lição 07. Treine por 1.000 passos no tinyshakespeare. Quanto a perda de validação piora em relação à atenção completa? Quanto a memória de pico diminui?
  3. Difícil. Implemente um mix de camadas no estilo Gemma-3 de 5:1 (5 SWA, 1 global) no modelo do projeto final. Compare perda, memória e qualidade de geração com baselines puramente baseadas em SWA e puramente globais com parâmetros equivalentes.
  4. Difícil. Implemente atenção diferencial com um λ aprendido por cabeça. Treine em uma tarefa sintética de recuperação (uma agulha, 2.000 distratores). Meça a precisão da recuperação em comparação a um baseline de atenção única com parâmetros equivalentes.

Termos-Chave

Termo O que dizem O que realmente significa
Atenção por janela deslizante (SWA) "Atenção local" Cada query atende aos seus últimos W tokens; o KV cache encolhe para O(W).
Campo receptivo efetivo "Até onde o modelo enxerga no passado" Em uma pilha SWA de L camadas com janela W, até L × W tokens.
Longformer / BigBird "Local + global + aleatória" Padrões esparsos com alguns tokens globais que sempre atendem e recebem atenção; abordagem antiga para contextos longos.
Native Sparse Attention "O truque de kernel da DeepSeek" Aprende a esparsidade no nível de bloco; ignora blocos zerados no nível de kernel, mantendo a qualidade.
Atenção diferencial "Dois mapas, um subtrai" DIFF Transformer: subtrai um mapa de atenção secundário multiplicado por um λ aprendido do mapa primário para cancelar os sumidouros de atenção.
Sumidouro de atenção (Attention sink) "O peso vaza para o token 0" A normalização do softmax força as linhas a somarem 1; queries não informativas depositam peso na posição 0.
FlexAttention "Máscara em Python" API do PyTorch 2.5+ que compila funções de máscara arbitrárias em kernels com o formato do FlashAttention.
Mix de tipos de camada "5:1 SWA para global" Intercala camadas de atenção esparsa e total em uma pilha para manter a qualidade usando menos memória.

Leitura Adicional

0 lifetime access. Curriculum based on AI Engineering from Scratch by Rohit Ghumare (MIT, used under attribution).