Перейти к содержанию

Реализация внимания с нуля

~7 минут чтения

Предварительно: Формулы линейной алгебры, базовый PyTorch


Зачем это нужно

Attention -- единственный механизм, который делает трансформеры трансформерами. Без него GPT-4, Claude, Gemini -- невозможны.

Аналогия: представь библиотеку. Ты приходишь с запросом (Query) -- "книга про Python". Библиотекарь сравнивает твой запрос с каталожными карточками (Keys) всех книг. Чем лучше совпадение -- тем больший вес получает содержимое (Value) этой книги в ответе. Attention делает ровно то же самое: каждый токен "ищет" релевантную информацию среди всех остальных токенов.

Ключевой инсайт: формула внимания поразительно проста -- это всего лишь взвешенная сумма значений, где веса определяются похожестью запроса и ключей.

Эволюция внимания

Тип Год Ключевая идея
Bahdanau Attention 2014 Аддитивное внимание (seq2seq)
Luong Attention 2015 Мультипликативное внимание
Self-Attention 2017 Q, K, V из одного входа
Multi-Head 2017 Параллельные головы внимания
Flash Attention 2022 IO-aware оптимизация памяти
Flash Attention 2 2023 Улучшенная параллелизация
Flash Attention 3 2024 Оптимизация под Hopper GPU

Формула внимания

Scaled Dot-Product Attention

Core Attention Formula

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

Где:

- $Q \in \mathbb{R}^{n \times d_k}$ -- матрица запросов (Queries)
- $K \in \mathbb{R}^{m \times d_k}$ -- матрица ключей (Keys)
- $V \in \mathbb{R}^{m \times d_v}$ -- матрица значений (Values)
- $\sqrt{d_k}$ -- масштабирующий коэффициент (предотвращает насыщение softmax)

Пошаговое вычисление

graph TD
    INPUT["Input: Q, K, V<br/>[batch, seq_len, d_k]"]
    INPUT --> S1

    S1["Step 1: Attention Scores<br/>scores = Q @ K^T<br/>[batch, seq_len, seq_len]"]
    S1 --> S2

    S2["Step 2: Scale<br/>scores = scores / sqrt(d_k)<br/>prevents softmax saturation"]
    S2 --> S3

    S3["Step 3: Mask (causal)<br/>scores.masked_fill(mask==0, -inf)<br/>future positions -> 0"]
    S3 --> S4

    S4["Step 4: Softmax<br/>attn_weights = softmax(scores, dim=-1)<br/>each row sums to 1"]
    S4 --> S5

    S5["Step 5: Weighted Sum<br/>output = attn_weights @ V<br/>[batch, seq_len, d_v]"]

    style INPUT fill:#e3f2fd,stroke:#1565c0
    style S1 fill:#fff3e0,stroke:#ef6c00
    style S2 fill:#fff3e0,stroke:#ef6c00
    style S3 fill:#fce4ec,stroke:#c62828
    style S4 fill:#e8f5e9,stroke:#2e7d32
    style S5 fill:#e8eaf6,stroke:#3f51b5

Реализация на PyTorch

Базовый Self-Attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    q: torch.Tensor,  # [batch, seq_len, d_k]
    k: torch.Tensor,  # [batch, seq_len, d_k]
    v: torch.Tensor,  # [batch, seq_len, d_v]
    mask: torch.Tensor = None,  # [batch, seq_len, seq_len]
    dropout: nn.Dropout = None,
) -> torch.Tensor:
    """Scaled dot-product attention.

    Args:
        q: Query tensor
        k: Key tensor
        v: Value tensor
        mask: Optional mask (True = attend, False = ignore)
        dropout: Optional dropout layer

    Returns:
        Attention output [batch, seq_len, d_v]
    """
    d_k = q.size(-1)

    # Step 1 & 2: Compute scaled attention scores
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (1)!

    # Step 3: Apply mask (if provided)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))  # (2)!

    # Step 4: Softmax
    attn_weights = F.softmax(scores, dim=-1)  # (3)!

    # Optional dropout
    if dropout is not None:
        attn_weights = dropout(attn_weights)

    # Step 5: Weighted sum
    output = torch.matmul(attn_weights, v)

    return output, attn_weights
  1. Q @ K^T дает матрицу [seq_len, seq_len] -- "похожесть" каждого токена с каждым. Деление на sqrt(d_k) предотвращает насыщение softmax при больших d_k.
  2. Causal mask: -inf в позициях будущих токенов -> после softmax эти позиции станут 0. Так модель "не видит будущее".
  3. dim=-1 -- softmax по последней оси, каждая строка суммируется в 1. Строка i -- это веса внимания токена i ко всем остальным.

Multi-Head Attention

class MultiHeadAttention(nn.Module):
    """Multi-Head Attention mechanism.

    Allows the model to jointly attend to information from different
    representation subspaces at different positions.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # Dimension per head

        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model, bias=False)  # (1)!
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)

        # Output projection
        self.w_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        batch_size = q.size(0)

        # Linear projections
        q = self.w_q(q)  # [batch, seq_len, d_model]
        k = self.w_k(k)
        v = self.w_v(v)

        # Reshape for multi-head: [batch, n_heads, seq_len, d_k]
        q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # (2)!
        k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Apply attention
        attn_output, attn_weights = scaled_dot_product_attention(
            q, k, v, mask, self.dropout
        )

        # Reshape back: [batch, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )  # (3)!

        # Output projection
        output = self.w_o(attn_output)  # (4)!

        return output, attn_weights
  1. Три отдельных линейных проекции W_Q, W_K, W_V -- каждая [d_model, d_model]. На практике часто объединяют в один nn.Linear(d_model, 3*d_model) и разрезают.
  2. view + transpose: [batch, seq, d_model] -> [batch, seq, n_heads, d_k] -> [batch, n_heads, seq, d_k]. Каждая "голова" получает свой срез вектора размерности d_k.
  3. .contiguous() нужен после transpose -- данные в памяти перестают быть непрерывными, а view требует непрерывности.
  4. Output projection W_O -- обучаемая линейная трансформация, которая "смешивает" выходы всех голов обратно в d_model.

Полный слой Self-Attention

class SelfAttentionLayer(nn.Module):
    """Complete self-attention layer with normalization and residual."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm = nn.RMSNorm(d_model)  # or nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        # Pre-norm architecture (used in modern LLMs)
        residual = x
        x = self.norm(x)

        # Self-attention: Q, K, V all come from x
        attn_output, _ = self.attention(x, x, x, mask)

        # Residual connection
        x = residual + self.dropout(attn_output)

        return x

Каузальное (авторегрессивное) внимание

Каузальная маска

def create_causal_mask(seq_len: int, device: str = 'cuda') -> torch.Tensor:
    """Create causal mask for decoder self-attention.

    Mask[b, i, j] = 0 if j > i (can't attend to future)
    Mask[b, i, j] = 1 if j <= i (can attend to past)
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    return mask.unsqueeze(0)  # [1, seq_len, seq_len]

# Alternative: boolean mask for F.scaled_dot_product_attention
def create_causal_mask_bool(seq_len: int, device: str = 'cuda') -> torch.Tensor:
    """Boolean mask where True = masked (ignore)."""
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
    return mask.bool()

Decoder-Only блок внимания

class DecoderAttention(nn.Module):
    """Decoder self-attention with causal masking."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm = nn.RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)

        # Create causal mask
        causal_mask = create_causal_mask(seq_len, x.device)

        # Pre-norm + self-attention + residual
        residual = x
        x = self.norm(x)
        attn_out, _ = self.attention(x, x, x, mask=causal_mask)
        x = residual + self.dropout(attn_out)

        return x

Causal mask: True или False?

PyTorch F.scaled_dot_product_attention использует is_causal=True, а кастомная маска -- инвертированная логика. В одних реализациях True = "attend", в других True = "mask out". Проверяй документацию конкретного API. Путаница здесь -- один из самых частых багов при реализации трансформеров.

Эффективные варианты внимания

Flash Attention (концептуально)

# Note: Real Flash Attention requires custom CUDA kernels
# This is a conceptual wrapper

def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float = 0.0,
    causal: bool = False,
) -> torch.Tensor:
    """Flash Attention - memory-efficient attention.

    Key optimizations:
    1. Tiling: Process attention in blocks to fit in SRAM
    2. Recomputation: Don't store full attention matrix
    3. IO-aware: Minimize HBM reads/writes

    Memory: O(N) instead of O(N²) for standard attention
    """
    # In PyTorch 2.0+, use:
    if hasattr(F, 'scaled_dot_product_attention'):
        with torch.backends.cuda.sdp_kernel(enable_flash=True):
            return F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=causal,
            )
    else:
        # Fallback to standard attention
        return scaled_dot_product_attention(q, k, v)[0]

Сравнение по памяти

Attention Type Memory (seq_len=N) Speed
Standard \(O(N^2 \cdot d)\) Baseline
Flash Attention \(O(N \cdot d)\) Same or faster
Flash Attention 2 \(O(N \cdot d)\) 2x faster
Flash Attention 3 \(O(N \cdot d)\) 3x faster (H100)

Кросс-внимание (Encoder-Decoder)

Слой кросс-внимания

class CrossAttention(nn.Module):
    """Cross-attention for encoder-decoder architectures.

    Q comes from decoder, K/V come from encoder.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm = nn.RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        decoder_hidden: torch.Tensor,  # Q - from decoder
        encoder_hidden: torch.Tensor,  # K, V - from encoder
    ) -> torch.Tensor:
        residual = decoder_hidden
        x = self.norm(decoder_hidden)

        # Q from decoder, K/V from encoder
        attn_out, _ = self.attention(x, encoder_hidden, encoder_hidden)

        return residual + self.dropout(attn_out)

Паттерны внимания

Основные паттерны

Каждый токен видит все остальные. Используется в encoder-моделях.

Сложность: \(O(N^2)\) | Где: BERT, RoBERTa, encoder-задачи

Mask: [[1,1,1,1],
       [1,1,1,1],
       [1,1,1,1],
       [1,1,1,1]]

Каждый токен видит только предыдущие (и себя). Фундамент авторегрессивных LLM.

Сложность: \(O(N^2)\) треугольная | Где: GPT-4, LLaMA, Claude

Mask: [[1,0,0,0],
       [1,1,0,0],
       [1,1,1,0],
       [1,1,1,1]]

Каждый токен видит только локальное окно из \(W\) соседей. Линейная сложность.

Сложность: \(O(N \times W)\) | Где: Mistral, Longformer

Window=2: [[1,1,0,0],
           [1,1,1,0],
           [0,1,1,1],
           [0,0,1,1]]

Локальное окно + глобальные токены (CLS, разделители) видят всех.

Сложность: \(O(N \times (W + G))\) | Где: BigBird, Longformer

LSH-хеширование или обучаемая маршрутизация определяют паттерн разреженности.

Сложность: \(O(N \log N)\) или \(O(N \sqrt{N})\) | Где: Reformer, Routing Transformer

Сравнение паттернов

Паттерн Сложность Макс. длина Модели
Full \(O(N^2)\) ~512-2K BERT, T5 encoder
Causal \(O(N^2)\) ~2K-128K GPT-4, LLaMA, Claude
Sliding Window \(O(N \cdot W)\) ~32K-128K Mistral
Global + Local \(O(N \cdot (W+G))\) ~4K-16K BigBird, Longformer
Sparse \(O(N \log N)\) ~64K+ Reformer

Зачем делить на \(\sqrt{d_k}\)?

Интуиция

Dot product двух случайных векторов растет с размерностью. Это проблема, потому что softmax экспоненциально усиливает разницу: большие значения "съедают" все вероятности, а маленькие получают \(\approx 0\).

Аналогия: softmax -- как голосование, где каждый голос экспоненциально зависит от оценки. Если оценки слишком разные (100 vs 1 vs 0.5), один кандидат получит 99.99% голосов. Scaling нормализует оценки, чтобы голосование было осмысленным.

Математически

Для \(q, k \in \mathbb{R}^d\) с элементами \(\sim \mathcal{N}(0, 1)\):

\[q \cdot k = \sum_{i=1}^{d} q_i \cdot k_i\]
  • \(\mathbb{E}[q \cdot k] = 0\) (среднее -- ноль)
  • \(\text{Var}[q \cdot k] = d\) (\(d\) независимых слагаемых, каждое с дисперсией 1)
  • \(\text{Std}[q \cdot k] = \sqrt{d}\)

Числовой пример: \(d_k = 64\) vs \(d_k = 512\)

Без scaling (\(d_k = 512\)): dot products \(\sim \mathcal{N}(0, 512)\), типичные значения \(\pm 22.6\)

scores = [21.3, -18.7, 5.2, -22.1]
softmax = [0.9999, 0.0000, 0.0000, 0.0000]  -- одна позиция "съела" все

С scaling: \(\text{scores} / \sqrt{512} \approx \text{scores} / 22.6\)

scaled  = [0.94, -0.83, 0.23, -0.98]
softmax = [0.36, 0.06, 0.18, 0.05]  -- осмысленное распределение

Результат: без scaling softmax вырождается в argmax, градиенты исчезают, модель не учится.

Scaling и размерность головы, а не модели

Делим на \(\sqrt{d_k}\) где \(d_k = d_{model} / n_{heads}\), не на \(\sqrt{d_{model}}\). Для модели с \(d_{model}=768\) и 12 головами: \(d_k = 64\), scaling = \(\sqrt{64} = 8\). Частая ошибка -- путать \(d_{model}\) и \(d_k\).


Шпаргалка для интервью

Ключевые формулы

Концепция Формула
Attention \(\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\)
Multi-Head \(\text{Concat}(h_1, \ldots, h_h)W^O\), где \(h_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\)
Causal Mask \(M_{ij} = 0\) если \(j > i\), иначе \(1\)
Память (стандарт) \(O(N^2 \cdot d)\)
Память (flash) \(O(N \cdot d)\)

Сложность по операциям

Операция Время Память
\(QK^T\) \(O(N^2 \cdot d)\) \(O(N^2)\)
Softmax \(O(N^2)\) \(O(N^2)\)
\(\text{Attn} \cdot V\) \(O(N^2 \cdot d)\) \(O(N \cdot d)\)
Итого \(O(N^2 \cdot d)\) \(O(N^2)\)

Вопросы на собеседовании

1. "Зачем делить на \(\sqrt{d_k}\)?"

  • ❌ Red flag: "Так в статье написано" / "Для нормализации"
  • ✅ Strong: "Dot product растет с размерностью -- дисперсия \(= d_k\). Без scaling softmax вырождается в argmax, градиенты исчезают. Деление приводит дисперсию к 1."

2. "Зачем multi-head, а не одна большая голова?"

  • ❌ Red flag: "Больше параметров -- лучше"
  • ✅ Strong: "Каждая голова учит свой тип зависимости -- синтаксис, семантику, позиционные связи. Одна голова вынуждена смешивать все в одном subspace. Multi-head = параллельная декомпозиция по подпространствам."
  • Типичный follow-up: "Сколько голов нужно? Все ли головы полезны?" -- см. pruning голов (Voita et al., 2019).

3. "Bottleneck attention по памяти?"

  • ✅ Strong: "Матрица \(N \times N\) -- для sequence length 128K это \(128K^2 = 16\) млрд элементов в FP16 = 32 GB. Flash Attention решает через тайлинг: матрица никогда не материализуется целиком, \(O(N)\) памяти."

4. "Pre-norm vs post-norm?"

  • ✅ Strong: "Post-norm (оригинальный трансформер) -- градиенты проходят через LayerNorm, что может их уменьшать. Pre-norm (GPT-2+, LLaMA) -- residual path чистый, градиенты текут напрямую. Pre-norm стабильнее, но post-norm может давать лучшее качество при правильном warmup."

5. "Как Q, K, V получаются в self-attention?"

  • ❌ Red flag: "Q это запрос пользователя, K это ключи в базе данных"
  • ✅ Strong: "В self-attention все три -- линейные проекции одного входа: \(Q = XW^Q\), \(K = XW^K\), \(V = XW^V\). Три разные матрицы позволяют разделить роли: 'что я ищу' vs 'что я предлагаю как ключ' vs 'что я даю как содержимое'."

Задание для самопроверки

Реализуйте attention с нуля без подглядывания. Проверьте: (1) output shape совпадает с input shape, (2) attention weights суммируются в 1 по последней оси, (3) causal mask не позволяет позиции \(i\) видеть позицию \(j > i\).


Источники

  1. Vaswani et al. -- "Attention Is All You Need" (2017), arXiv:1706.03762
  2. Jay Alammar -- The Illustrated Transformer
  3. Andrej Karpathy -- "Let's build GPT" (YouTube)
  4. Hugging Face -- реализация attention в Transformers
  5. Dao et al. -- "FlashAttention: Fast and Memory-Efficient Exact Attention" (2022)
  6. PyTorch -- F.scaled_dot_product_attention
  7. Lilian Weng -- "Attention? Attention!" (blog)
  8. Stanford CS224N -- лекция по Attention
  9. Harvard NLP -- "The Annotated Transformer"

See Also