Multi-Query и Grouped-Query Attention (MQA/GQA)¶
~6 минут чтения
Предварительно: Attention с нуля | KV Cache
Проблема: KV cache как bottleneck¶
При авторегрессивном инференсе LLM каждый новый токен требует attention ко всем предыдущим. Чтобы не пересчитывать Key и Value для старых токенов, их кэшируют -- это KV cache.
Размер KV cache для стандартного Multi-Head Attention (MHA):
Для LLaMA 2 70B (\(n_\text{layers}=80\), \(n_\text{heads}=64\), \(d_\text{head}=128\), fp16):
При batch size 32 это уже 166 GB -- больше чем сами веса модели. KV cache становится главным ограничителем throughput: GPU memory заканчивается раньше, чем вычислительные ресурсы.
Ключевой инсайт: можно ли уменьшить KV cache не трогая Query? Да -- именно это делают MQA и GQA.
Три механизма внимания¶
Multi-Head Attention (MHA)¶
Стандартный подход (Vaswani et al., 2017). Каждый head имеет свои проекции Q, K, V:
\(H\) наборов KV -- максимальная выразительность, но максимальный KV cache.
Multi-Query Attention (MQA)¶
Shazeer (2019): все Query heads делят один набор K, V:
KV cache сокращается в \(H\) раз. Для LLaMA 70B: с 5.2 GB до 81 MB на запрос. Разные Query heads задают разные "вопросы" к одним и тем же Key/Value -- как несколько человек ищут в одной базе данных.
Цена: потеря выразительности. Все heads вынуждены работать с одним представлением ключей и значений. На практике качество падает на 0.5-1% на бенчмарках.
Grouped-Query Attention (GQA)¶
Ainslie et al. (2023): компромисс. Query heads делятся на \(G\) групп, каждая группа делит свой K, V:
Group 0: Q0, Q1 ──── K0, V0
Group 1: Q2, Q3 ──── K1, V1
Group 2: Q4, Q5 ──── K2, V2
Group 3: Q6, Q7 ──── K3, V3
При \(H=8\) heads и \(G=4\) групп: KV cache в \(H/G = 2\) раза меньше MHA, но в \(G = 4\) раза больше MQA. Качество близко к MHA.
Сравнительная таблица¶
| MHA | GQA | MQA | |
|---|---|---|---|
| KV heads | \(H\) | \(G\) (1 < G < H) | 1 |
| KV cache | \(H \times d_h \times n\) | \(G \times d_h \times n\) | \(1 \times d_h \times n\) |
| Экономия памяти | 1x (baseline) | \(H/G\) x | \(H\) x |
| Качество модели | Best | ~MHA (< 0.5% degrade) | -0.5-1% vs MHA |
| Throughput (inference) | Baseline | 1.5-2x MHA | 2-4x MHA |
| Latency (TTFT) | Baseline | Сопоставима | Сопоставима |
| Используется в | GPT-2, BERT, GPT-3 | LLaMA ⅔, Mistral, Gemma | PaLM, Falcon, StarCoder |
GQA != ухудшенный MHA
Распространенное заблуждение: GQA -- это компромисс, значит хуже. На практике GQA с \(G=8\) (при \(H=32\)) дает практически идентичное качество MHA, при 4x экономии KV cache. Причина: соседние attention heads часто учат очень похожие паттерны Key/Value -- шаринг не теряет информацию.
Формулы¶
KV Cache Memory
$$\text{KV}_\text{MQA} = 2 \cdot L \cdot 1 \cdot d_h \cdot S \cdot b$$
Где: \(L\) -- layers, \(H\) -- heads, \(G\) -- groups, \(d_h\) -- head dim, \(S\) -- seq length, \(b\) -- bytes per element.
**Числовой пример** (LLaMA 3 70B, $L=80$, $H=64$, $G=8$, $d_h=128$, $S=8192$, fp16):
- MHA: $2 \times 80 \times 64 \times 128 \times 8192 \times 2 = $ **10.7 GB**
- GQA ($G=8$): $2 \times 80 \times 8 \times 128 \times 8192 \times 2 = $ **1.3 GB** (8x экономия)
- MQA: $2 \times 80 \times 1 \times 128 \times 8192 \times 2 = $ **0.17 GB** (64x экономия)
Реализация GQA на Python¶
GQA Attention (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
"""
Args:
d_model: размерность модели (например 4096)
n_heads: количество Query heads (например 32)
n_kv_heads: количество KV heads / групп (например 8)
"""
super().__init__()
assert n_heads % n_kv_heads == 0
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_rep = n_heads // n_kv_heads # сколько Q heads на 1 KV head
self.head_dim = d_model // n_heads
self.wq = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(n_heads * self.head_dim, d_model, bias=False)
def forward(self, x: torch.Tensor, mask=None):
B, S, _ = x.shape
# Q: [B, S, n_heads, head_dim]
q = self.wq(x).view(B, S, self.n_heads, self.head_dim)
# K, V: [B, S, n_kv_heads, head_dim]
k = self.wk(x).view(B, S, self.n_kv_heads, self.head_dim)
v = self.wv(x).view(B, S, self.n_kv_heads, self.head_dim)
# Repeat KV heads to match Q heads
# [B, S, n_kv_heads, head_dim] -> [B, S, n_heads, head_dim]
k = k.repeat_interleave(self.n_rep, dim=2)
v = v.repeat_interleave(self.n_rep, dim=2)
# Standard attention
q = q.transpose(1, 2) # [B, n_heads, S, head_dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v) # [B, n_heads, S, head_dim]
out = out.transpose(1, 2).contiguous().view(B, S, -1)
return self.wo(out)
# Пример: LLaMA 3 8B конфигурация
gqa = GroupedQueryAttention(d_model=4096, n_heads=32, n_kv_heads=8)
x = torch.randn(2, 128, 4096)
out = gqa(x)
print(f"Input: {x.shape}, Output: {out.shape}")
# Input: torch.Size([2, 128, 4096]), Output: torch.Size([2, 128, 4096])
# Параметры:
# Q projection: 4096 * 4096 = 16M params
# K projection: 4096 * 1024 = 4M params (8 KV heads * 128)
# V projection: 4096 * 1024 = 4M params
# Total KV params: 8M vs 32M для MHA -- 4x экономия
Конвертация MHA в GQA¶
Как конвертировать уже обученную MHA модель в GQA без обучения с нуля? (Ainslie et al., 2023):
- Mean pooling: усреднить Key/Value проекции внутри каждой группы
- Continued pre-training: дообучить ~5% от исходного бюджета (несколько сотен шагов)
Результат: качество в пределах 0.5% от MHA при значительной экономии памяти. Именно так LLaMA 2 70B была конвертирована в GQA.
Что используют модели¶
| Модель | Mechanism | Q heads | KV heads | Ratio |
|---|---|---|---|---|
| LLaMA 3 8B | GQA | 32 | 8 | 4:1 |
| LLaMA 3 70B | GQA | 64 | 8 | 8:1 |
| Mistral 7B | GQA | 32 | 8 | 4:1 |
| Mixtral 8x7B | GQA | 32 | 8 | 4:1 |
| Gemma 2 9B | GQA | 16 | 8 | 2:1 |
| Qwen 2.5 72B | GQA | 64 | 8 | 8:1 |
| PaLM | MQA | 16 | 1 | 16:1 |
| Falcon 40B | MQA | 64 | 1 | 64:1 |
| StarCoder | MQA | 48 | 1 | 48:1 |
| GPT-4 | Неизвестно | -- | -- | -- |
Тренд 2024-2025: GQA с ratio 4:1 или 8:1 стал стандартом. MQA слишком агрессивно; MHA слишком расточительно.
Связь с другими оптимизациями¶
GQA/MQA -- это оптимизация памяти KV cache. Она отлично сочетается с:
- FlashAttention -- оптимизация вычислений (IO-aware, тайлинг). Ортогонально: GQA уменьшает что хранить, FlashAttention ускоряет как считать.
- PagedAttention (vLLM) -- оптимизация аллокации памяти (пагинация блоков). GQA делает блоки меньше, PagedAttention распределяет их эффективнее.
- Quantization -- уменьшение precision (FP16 -> INT8/INT4). Меньше KV heads + ниже precision = максимальная экономия.
- KV cache eviction -- удаление старых KV пар (H2O, StreamingLLM). Меньше heads = меньше данных для eviction.
Комбинация: GQA + FlashAttention + PagedAttention + INT8 KV cache = стандартный продакшен стек 2025.
Interview вопросы¶
Conceptual¶
Q: Объясните разницу между MHA, MQA и GQA.
Strong: "MHA -- каждый head имеет свои KV, максимум выразительности но O(H) memory. MQA -- один KV на все heads, O(1) memory но потеря качества. GQA -- G групп по H/G heads, O(G) memory, sweet spot. На практике GQA с G=8 дает качество как MHA при 4-8x экономии, потому что соседние heads часто учат похожие KV паттерны."
Red flag: "GQA -- это упрощенная версия attention" (нет, это архитектурная оптимизация KV cache).
Q: Почему GQA не теряет качество при шаринге KV?
Эмпирическое наблюдение: Key/Value проекции соседних heads высоко коррелированы. Исследования показывают cosine similarity > 0.9 для соседних heads. Шаринг KV внутри группы -- форма structured weight sharing, аналогичная weight tying в embedding/output layers.
Design¶
Q: У вас модель с 32 heads. Как выбрать количество KV groups?
Decision framework:
- G=32 (MHA): нужно максимальное качество, memory не проблема (маленькая модель, короткие контексты)
- G=8: стандартный выбор, хороший баланс (LLaMA 3, Mistral)
- G=4: нужна агрессивная экономия, длинные контексты
- G=1 (MQA): экстремальная экономия, допустима потеря качества (кодовые модели типа StarCoder)
Задача для самопроверки
Рассчитайте KV cache для модели: 32 layers, 32 Q heads, d_model=4096, seq_len=16384, fp16. Сравните MHA vs GQA (G=8) vs MQA. При каком batch size разница становится критичной (> 80 GB A100)?
See Also¶
- KV Cache Optimization -- формулы KV cache, PagedAttention, eviction strategies
- FlashAttention 3 -- complementary: FlashAttention + GQA = maximum throughput
- Attention с нуля -- vanilla MHA implementation для сравнения
- Сравнение нормализаций -- RMSNorm + GQA = modern LLM standard
- Efficient Transformers -- обзор всех architectural optimizations