Реализация внимания с нуля¶
~7 минут чтения
Предварительно: Формулы линейной алгебры, базовый PyTorch
Зачем это нужно¶
Attention -- единственный механизм, который делает трансформеры трансформерами. Без него GPT-4, Claude, Gemini -- невозможны.
Аналогия: представь библиотеку. Ты приходишь с запросом (Query) -- "книга про Python". Библиотекарь сравнивает твой запрос с каталожными карточками (Keys) всех книг. Чем лучше совпадение -- тем больший вес получает содержимое (Value) этой книги в ответе. Attention делает ровно то же самое: каждый токен "ищет" релевантную информацию среди всех остальных токенов.
Ключевой инсайт: формула внимания поразительно проста -- это всего лишь взвешенная сумма значений, где веса определяются похожестью запроса и ключей.
Эволюция внимания¶
| Тип | Год | Ключевая идея |
|---|---|---|
| Bahdanau Attention | 2014 | Аддитивное внимание (seq2seq) |
| Luong Attention | 2015 | Мультипликативное внимание |
| Self-Attention | 2017 | Q, K, V из одного входа |
| Multi-Head | 2017 | Параллельные головы внимания |
| Flash Attention | 2022 | IO-aware оптимизация памяти |
| Flash Attention 2 | 2023 | Улучшенная параллелизация |
| Flash Attention 3 | 2024 | Оптимизация под Hopper GPU |
Формула внимания¶
Scaled Dot-Product Attention¶
Core Attention Formula
Где:
- $Q \in \mathbb{R}^{n \times d_k}$ -- матрица запросов (Queries)
- $K \in \mathbb{R}^{m \times d_k}$ -- матрица ключей (Keys)
- $V \in \mathbb{R}^{m \times d_v}$ -- матрица значений (Values)
- $\sqrt{d_k}$ -- масштабирующий коэффициент (предотвращает насыщение softmax)
Пошаговое вычисление¶
graph TD
INPUT["Input: Q, K, V<br/>[batch, seq_len, d_k]"]
INPUT --> S1
S1["Step 1: Attention Scores<br/>scores = Q @ K^T<br/>[batch, seq_len, seq_len]"]
S1 --> S2
S2["Step 2: Scale<br/>scores = scores / sqrt(d_k)<br/>prevents softmax saturation"]
S2 --> S3
S3["Step 3: Mask (causal)<br/>scores.masked_fill(mask==0, -inf)<br/>future positions -> 0"]
S3 --> S4
S4["Step 4: Softmax<br/>attn_weights = softmax(scores, dim=-1)<br/>each row sums to 1"]
S4 --> S5
S5["Step 5: Weighted Sum<br/>output = attn_weights @ V<br/>[batch, seq_len, d_v]"]
style INPUT fill:#e3f2fd,stroke:#1565c0
style S1 fill:#fff3e0,stroke:#ef6c00
style S2 fill:#fff3e0,stroke:#ef6c00
style S3 fill:#fce4ec,stroke:#c62828
style S4 fill:#e8f5e9,stroke:#2e7d32
style S5 fill:#e8eaf6,stroke:#3f51b5
Реализация на PyTorch¶
Базовый Self-Attention¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(
q: torch.Tensor, # [batch, seq_len, d_k]
k: torch.Tensor, # [batch, seq_len, d_k]
v: torch.Tensor, # [batch, seq_len, d_v]
mask: torch.Tensor = None, # [batch, seq_len, seq_len]
dropout: nn.Dropout = None,
) -> torch.Tensor:
"""Scaled dot-product attention.
Args:
q: Query tensor
k: Key tensor
v: Value tensor
mask: Optional mask (True = attend, False = ignore)
dropout: Optional dropout layer
Returns:
Attention output [batch, seq_len, d_v]
"""
d_k = q.size(-1)
# Step 1 & 2: Compute scaled attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # (1)!
# Step 3: Apply mask (if provided)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf')) # (2)!
# Step 4: Softmax
attn_weights = F.softmax(scores, dim=-1) # (3)!
# Optional dropout
if dropout is not None:
attn_weights = dropout(attn_weights)
# Step 5: Weighted sum
output = torch.matmul(attn_weights, v)
return output, attn_weights
Q @ K^Tдает матрицу[seq_len, seq_len]-- "похожесть" каждого токена с каждым. Деление наsqrt(d_k)предотвращает насыщение softmax при большихd_k.- Causal mask:
-infв позициях будущих токенов -> после softmax эти позиции станут 0. Так модель "не видит будущее". dim=-1-- softmax по последней оси, каждая строка суммируется в 1. Строкаi-- это веса внимания токенаiко всем остальным.
Multi-Head Attention¶
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention mechanism.
Allows the model to jointly attend to information from different
representation subspaces at different positions.
"""
def __init__(
self,
d_model: int,
n_heads: int,
dropout: float = 0.1,
):
super().__init__()
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads # Dimension per head
# Linear projections for Q, K, V
self.w_q = nn.Linear(d_model, d_model, bias=False) # (1)!
self.w_k = nn.Linear(d_model, d_model, bias=False)
self.w_v = nn.Linear(d_model, d_model, bias=False)
# Output projection
self.w_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: torch.Tensor = None,
) -> torch.Tensor:
batch_size = q.size(0)
# Linear projections
q = self.w_q(q) # [batch, seq_len, d_model]
k = self.w_k(k)
v = self.w_v(v)
# Reshape for multi-head: [batch, n_heads, seq_len, d_k]
q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # (2)!
k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
# Apply attention
attn_output, attn_weights = scaled_dot_product_attention(
q, k, v, mask, self.dropout
)
# Reshape back: [batch, seq_len, d_model]
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
) # (3)!
# Output projection
output = self.w_o(attn_output) # (4)!
return output, attn_weights
- Три отдельных линейных проекции
W_Q, W_K, W_V-- каждая[d_model, d_model]. На практике часто объединяют в одинnn.Linear(d_model, 3*d_model)и разрезают. - view + transpose:
[batch, seq, d_model]->[batch, seq, n_heads, d_k]->[batch, n_heads, seq, d_k]. Каждая "голова" получает свой срез вектора размерностиd_k. .contiguous()нужен после transpose -- данные в памяти перестают быть непрерывными, а view требует непрерывности.- Output projection
W_O-- обучаемая линейная трансформация, которая "смешивает" выходы всех голов обратно вd_model.
Полный слой Self-Attention¶
class SelfAttentionLayer(nn.Module):
"""Complete self-attention layer with normalization and residual."""
def __init__(
self,
d_model: int,
n_heads: int,
dropout: float = 0.1,
):
super().__init__()
self.attention = MultiHeadAttention(d_model, n_heads, dropout)
self.norm = nn.RMSNorm(d_model) # or nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
) -> torch.Tensor:
# Pre-norm architecture (used in modern LLMs)
residual = x
x = self.norm(x)
# Self-attention: Q, K, V all come from x
attn_output, _ = self.attention(x, x, x, mask)
# Residual connection
x = residual + self.dropout(attn_output)
return x
Каузальное (авторегрессивное) внимание¶
Каузальная маска¶
def create_causal_mask(seq_len: int, device: str = 'cuda') -> torch.Tensor:
"""Create causal mask for decoder self-attention.
Mask[b, i, j] = 0 if j > i (can't attend to future)
Mask[b, i, j] = 1 if j <= i (can attend to past)
"""
mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
return mask.unsqueeze(0) # [1, seq_len, seq_len]
# Alternative: boolean mask for F.scaled_dot_product_attention
def create_causal_mask_bool(seq_len: int, device: str = 'cuda') -> torch.Tensor:
"""Boolean mask where True = masked (ignore)."""
mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
return mask.bool()
Decoder-Only блок внимания¶
class DecoderAttention(nn.Module):
"""Decoder self-attention with causal masking."""
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
super().__init__()
self.attention = MultiHeadAttention(d_model, n_heads, dropout)
self.norm = nn.RMSNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
seq_len = x.size(1)
# Create causal mask
causal_mask = create_causal_mask(seq_len, x.device)
# Pre-norm + self-attention + residual
residual = x
x = self.norm(x)
attn_out, _ = self.attention(x, x, x, mask=causal_mask)
x = residual + self.dropout(attn_out)
return x
Causal mask: True или False?
PyTorch F.scaled_dot_product_attention использует is_causal=True, а кастомная маска -- инвертированная логика. В одних реализациях True = "attend", в других True = "mask out". Проверяй документацию конкретного API. Путаница здесь -- один из самых частых багов при реализации трансформеров.
Эффективные варианты внимания¶
Flash Attention (концептуально)¶
# Note: Real Flash Attention requires custom CUDA kernels
# This is a conceptual wrapper
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float = 0.0,
causal: bool = False,
) -> torch.Tensor:
"""Flash Attention - memory-efficient attention.
Key optimizations:
1. Tiling: Process attention in blocks to fit in SRAM
2. Recomputation: Don't store full attention matrix
3. IO-aware: Minimize HBM reads/writes
Memory: O(N) instead of O(N²) for standard attention
"""
# In PyTorch 2.0+, use:
if hasattr(F, 'scaled_dot_product_attention'):
with torch.backends.cuda.sdp_kernel(enable_flash=True):
return F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=dropout_p,
is_causal=causal,
)
else:
# Fallback to standard attention
return scaled_dot_product_attention(q, k, v)[0]
Сравнение по памяти¶
| Attention Type | Memory (seq_len=N) | Speed |
|---|---|---|
| Standard | \(O(N^2 \cdot d)\) | Baseline |
| Flash Attention | \(O(N \cdot d)\) | Same or faster |
| Flash Attention 2 | \(O(N \cdot d)\) | 2x faster |
| Flash Attention 3 | \(O(N \cdot d)\) | 3x faster (H100) |
Кросс-внимание (Encoder-Decoder)¶
Слой кросс-внимания¶
class CrossAttention(nn.Module):
"""Cross-attention for encoder-decoder architectures.
Q comes from decoder, K/V come from encoder.
"""
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
super().__init__()
self.attention = MultiHeadAttention(d_model, n_heads, dropout)
self.norm = nn.RMSNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(
self,
decoder_hidden: torch.Tensor, # Q - from decoder
encoder_hidden: torch.Tensor, # K, V - from encoder
) -> torch.Tensor:
residual = decoder_hidden
x = self.norm(decoder_hidden)
# Q from decoder, K/V from encoder
attn_out, _ = self.attention(x, encoder_hidden, encoder_hidden)
return residual + self.dropout(attn_out)
Паттерны внимания¶
Основные паттерны¶
Каждый токен видит все остальные. Используется в encoder-моделях.
Сложность: \(O(N^2)\) | Где: BERT, RoBERTa, encoder-задачи
Каждый токен видит только предыдущие (и себя). Фундамент авторегрессивных LLM.
Сложность: \(O(N^2)\) треугольная | Где: GPT-4, LLaMA, Claude
Каждый токен видит только локальное окно из \(W\) соседей. Линейная сложность.
Сложность: \(O(N \times W)\) | Где: Mistral, Longformer
Локальное окно + глобальные токены (CLS, разделители) видят всех.
Сложность: \(O(N \times (W + G))\) | Где: BigBird, Longformer
LSH-хеширование или обучаемая маршрутизация определяют паттерн разреженности.
Сложность: \(O(N \log N)\) или \(O(N \sqrt{N})\) | Где: Reformer, Routing Transformer
Сравнение паттернов¶
| Паттерн | Сложность | Макс. длина | Модели |
|---|---|---|---|
| Full | \(O(N^2)\) | ~512-2K | BERT, T5 encoder |
| Causal | \(O(N^2)\) | ~2K-128K | GPT-4, LLaMA, Claude |
| Sliding Window | \(O(N \cdot W)\) | ~32K-128K | Mistral |
| Global + Local | \(O(N \cdot (W+G))\) | ~4K-16K | BigBird, Longformer |
| Sparse | \(O(N \log N)\) | ~64K+ | Reformer |
Зачем делить на \(\sqrt{d_k}\)?¶
Интуиция¶
Dot product двух случайных векторов растет с размерностью. Это проблема, потому что softmax экспоненциально усиливает разницу: большие значения "съедают" все вероятности, а маленькие получают \(\approx 0\).
Аналогия: softmax -- как голосование, где каждый голос экспоненциально зависит от оценки. Если оценки слишком разные (100 vs 1 vs 0.5), один кандидат получит 99.99% голосов. Scaling нормализует оценки, чтобы голосование было осмысленным.
Математически¶
Для \(q, k \in \mathbb{R}^d\) с элементами \(\sim \mathcal{N}(0, 1)\):
- \(\mathbb{E}[q \cdot k] = 0\) (среднее -- ноль)
- \(\text{Var}[q \cdot k] = d\) (\(d\) независимых слагаемых, каждое с дисперсией 1)
- \(\text{Std}[q \cdot k] = \sqrt{d}\)
Числовой пример: \(d_k = 64\) vs \(d_k = 512\)
Без scaling (\(d_k = 512\)): dot products \(\sim \mathcal{N}(0, 512)\), типичные значения \(\pm 22.6\)
scores = [21.3, -18.7, 5.2, -22.1]
softmax = [0.9999, 0.0000, 0.0000, 0.0000] -- одна позиция "съела" все
С scaling: \(\text{scores} / \sqrt{512} \approx \text{scores} / 22.6\)
Результат: без scaling softmax вырождается в argmax, градиенты исчезают, модель не учится.
Scaling и размерность головы, а не модели
Делим на \(\sqrt{d_k}\) где \(d_k = d_{model} / n_{heads}\), не на \(\sqrt{d_{model}}\). Для модели с \(d_{model}=768\) и 12 головами: \(d_k = 64\), scaling = \(\sqrt{64} = 8\). Частая ошибка -- путать \(d_{model}\) и \(d_k\).
Шпаргалка для интервью¶
Ключевые формулы¶
| Концепция | Формула |
|---|---|
| Attention | \(\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\) |
| Multi-Head | \(\text{Concat}(h_1, \ldots, h_h)W^O\), где \(h_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\) |
| Causal Mask | \(M_{ij} = 0\) если \(j > i\), иначе \(1\) |
| Память (стандарт) | \(O(N^2 \cdot d)\) |
| Память (flash) | \(O(N \cdot d)\) |
Сложность по операциям¶
| Операция | Время | Память |
|---|---|---|
| \(QK^T\) | \(O(N^2 \cdot d)\) | \(O(N^2)\) |
| Softmax | \(O(N^2)\) | \(O(N^2)\) |
| \(\text{Attn} \cdot V\) | \(O(N^2 \cdot d)\) | \(O(N \cdot d)\) |
| Итого | \(O(N^2 \cdot d)\) | \(O(N^2)\) |
Вопросы на собеседовании¶
1. "Зачем делить на \(\sqrt{d_k}\)?"
Red flag: "Так в статье написано" / "Для нормализации"
Strong: "Dot product растет с размерностью -- дисперсия \(= d_k\). Без scaling softmax вырождается в argmax, градиенты исчезают. Деление приводит дисперсию к 1."
2. "Зачем multi-head, а не одна большая голова?"
Red flag: "Больше параметров -- лучше"
Strong: "Каждая голова учит свой тип зависимости -- синтаксис, семантику, позиционные связи. Одна голова вынуждена смешивать все в одном subspace. Multi-head = параллельная декомпозиция по подпространствам."
- Типичный follow-up: "Сколько голов нужно? Все ли головы полезны?" -- см. pruning голов (Voita et al., 2019).
3. "Bottleneck attention по памяти?"
Strong: "Матрица \(N \times N\) -- для sequence length 128K это \(128K^2 = 16\) млрд элементов в FP16 = 32 GB. Flash Attention решает через тайлинг: матрица никогда не материализуется целиком, \(O(N)\) памяти."
4. "Pre-norm vs post-norm?"
Strong: "Post-norm (оригинальный трансформер) -- градиенты проходят через LayerNorm, что может их уменьшать. Pre-norm (GPT-2+, LLaMA) -- residual path чистый, градиенты текут напрямую. Pre-norm стабильнее, но post-norm может давать лучшее качество при правильном warmup."
5. "Как Q, K, V получаются в self-attention?"
Red flag: "Q это запрос пользователя, K это ключи в базе данных"
Strong: "В self-attention все три -- линейные проекции одного входа: \(Q = XW^Q\), \(K = XW^K\), \(V = XW^V\). Три разные матрицы позволяют разделить роли: 'что я ищу' vs 'что я предлагаю как ключ' vs 'что я даю как содержимое'."
Задание для самопроверки
Реализуйте attention с нуля без подглядывания. Проверьте: (1) output shape совпадает с input shape, (2) attention weights суммируются в 1 по последней оси, (3) causal mask не позволяет позиции \(i\) видеть позицию \(j > i\).
Источники¶
- Vaswani et al. -- "Attention Is All You Need" (2017), arXiv:1706.03762
- Jay Alammar -- The Illustrated Transformer
- Andrej Karpathy -- "Let's build GPT" (YouTube)
- Hugging Face -- реализация attention в Transformers
- Dao et al. -- "FlashAttention: Fast and Memory-Efficient Exact Attention" (2022)
- PyTorch --
F.scaled_dot_product_attention - Lilian Weng -- "Attention? Attention!" (blog)
- Stanford CS224N -- лекция по Attention
- Harvard NLP -- "The Annotated Transformer"
See Also
- Flash Attention 3 — IO-aware memory optimization
- MQA/GQA внимание — Grouped/Multi-Query Attention
- KV-кэш оптимизация — KV-cache compression
- Позиционное кодирование — RoPE, ALiBi
- Эффективные трансформеры — Linear attention, sparse attention