Перейти к содержанию

Multi-Query и Grouped-Query Attention (MQA/GQA)

~6 минут чтения

Предварительно: Attention с нуля | KV Cache

Проблема: KV cache как bottleneck

При авторегрессивном инференсе LLM каждый новый токен требует attention ко всем предыдущим. Чтобы не пересчитывать Key и Value для старых токенов, их кэшируют -- это KV cache.

Размер KV cache для стандартного Multi-Head Attention (MHA):

\[\text{KV cache} = 2 \times n_\text{layers} \times n_\text{heads} \times d_\text{head} \times n_\text{seq} \times \text{bytes}\]

Для LLaMA 2 70B (\(n_\text{layers}=80\), \(n_\text{heads}=64\), \(d_\text{head}=128\), fp16):

\[2 \times 80 \times 64 \times 128 \times 4096 \times 2 = \textbf{5.2 GB на один запрос}\]

При batch size 32 это уже 166 GB -- больше чем сами веса модели. KV cache становится главным ограничителем throughput: GPU memory заканчивается раньше, чем вычислительные ресурсы.

Ключевой инсайт: можно ли уменьшить KV cache не трогая Query? Да -- именно это делают MQA и GQA.


Три механизма внимания

Multi-Head Attention (MHA)

Стандартный подход (Vaswani et al., 2017). Каждый head имеет свои проекции Q, K, V:

Head 0: Q0, K0, V0
Head 1: Q1, K1, V1
...
Head H: QH, KH, VH

\(H\) наборов KV -- максимальная выразительность, но максимальный KV cache.

Multi-Query Attention (MQA)

Shazeer (2019): все Query heads делят один набор K, V:

Head 0: Q0 ─┐
Head 1: Q1 ─┼─ K_shared, V_shared
...         │
Head H: QH ─┘

KV cache сокращается в \(H\) раз. Для LLaMA 70B: с 5.2 GB до 81 MB на запрос. Разные Query heads задают разные "вопросы" к одним и тем же Key/Value -- как несколько человек ищут в одной базе данных.

Цена: потеря выразительности. Все heads вынуждены работать с одним представлением ключей и значений. На практике качество падает на 0.5-1% на бенчмарках.

Grouped-Query Attention (GQA)

Ainslie et al. (2023): компромисс. Query heads делятся на \(G\) групп, каждая группа делит свой K, V:

Group 0: Q0, Q1 ──── K0, V0
Group 1: Q2, Q3 ──── K1, V1
Group 2: Q4, Q5 ──── K2, V2
Group 3: Q6, Q7 ──── K3, V3

При \(H=8\) heads и \(G=4\) групп: KV cache в \(H/G = 2\) раза меньше MHA, но в \(G = 4\) раза больше MQA. Качество близко к MHA.


Сравнительная таблица

MHA GQA MQA
KV heads \(H\) \(G\) (1 < G < H) 1
KV cache \(H \times d_h \times n\) \(G \times d_h \times n\) \(1 \times d_h \times n\)
Экономия памяти 1x (baseline) \(H/G\) x \(H\) x
Качество модели Best ~MHA (< 0.5% degrade) -0.5-1% vs MHA
Throughput (inference) Baseline 1.5-2x MHA 2-4x MHA
Latency (TTFT) Baseline Сопоставима Сопоставима
Используется в GPT-2, BERT, GPT-3 LLaMA ⅔, Mistral, Gemma PaLM, Falcon, StarCoder

GQA != ухудшенный MHA

Распространенное заблуждение: GQA -- это компромисс, значит хуже. На практике GQA с \(G=8\) (при \(H=32\)) дает практически идентичное качество MHA, при 4x экономии KV cache. Причина: соседние attention heads часто учат очень похожие паттерны Key/Value -- шаринг не теряет информацию.


Формулы

KV Cache Memory

\[\text{KV}_\text{MHA} = 2 \cdot L \cdot H \cdot d_h \cdot S \cdot b\]
\[\text{KV}_\text{GQA} = 2 \cdot L \cdot G \cdot d_h \cdot S \cdot b\]
$$\text{KV}_\text{MQA} = 2 \cdot L \cdot 1 \cdot d_h \cdot S \cdot b$$

Где: \(L\) -- layers, \(H\) -- heads, \(G\) -- groups, \(d_h\) -- head dim, \(S\) -- seq length, \(b\) -- bytes per element.

**Числовой пример** (LLaMA 3 70B, $L=80$, $H=64$, $G=8$, $d_h=128$, $S=8192$, fp16):

- MHA: $2 \times 80 \times 64 \times 128 \times 8192 \times 2 = $ **10.7 GB**
- GQA ($G=8$): $2 \times 80 \times 8 \times 128 \times 8192 \times 2 = $ **1.3 GB** (8x экономия)
- MQA: $2 \times 80 \times 1 \times 128 \times 8192 \times 2 = $ **0.17 GB** (64x экономия)

Реализация GQA на Python

GQA Attention (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        """
        Args:
            d_model: размерность модели (например 4096)
            n_heads: количество Query heads (например 32)
            n_kv_heads: количество KV heads / групп (например 8)
        """
        super().__init__()
        assert n_heads % n_kv_heads == 0
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # сколько Q heads на 1 KV head
        self.head_dim = d_model // n_heads

        self.wq = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask=None):
        B, S, _ = x.shape

        # Q: [B, S, n_heads, head_dim]
        q = self.wq(x).view(B, S, self.n_heads, self.head_dim)
        # K, V: [B, S, n_kv_heads, head_dim]
        k = self.wk(x).view(B, S, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, S, self.n_kv_heads, self.head_dim)

        # Repeat KV heads to match Q heads
        # [B, S, n_kv_heads, head_dim] -> [B, S, n_heads, head_dim]
        k = k.repeat_interleave(self.n_rep, dim=2)
        v = v.repeat_interleave(self.n_rep, dim=2)

        # Standard attention
        q = q.transpose(1, 2)  # [B, n_heads, S, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # [B, n_heads, S, head_dim]

        out = out.transpose(1, 2).contiguous().view(B, S, -1)
        return self.wo(out)

# Пример: LLaMA 3 8B конфигурация
gqa = GroupedQueryAttention(d_model=4096, n_heads=32, n_kv_heads=8)
x = torch.randn(2, 128, 4096)
out = gqa(x)
print(f"Input: {x.shape}, Output: {out.shape}")
# Input: torch.Size([2, 128, 4096]), Output: torch.Size([2, 128, 4096])

# Параметры:
# Q projection: 4096 * 4096 = 16M params
# K projection: 4096 * 1024 = 4M params (8 KV heads * 128)
# V projection: 4096 * 1024 = 4M params
# Total KV params: 8M vs 32M для MHA -- 4x экономия

Конвертация MHA в GQA

Как конвертировать уже обученную MHA модель в GQA без обучения с нуля? (Ainslie et al., 2023):

  1. Mean pooling: усреднить Key/Value проекции внутри каждой группы
  2. Continued pre-training: дообучить ~5% от исходного бюджета (несколько сотен шагов)

Результат: качество в пределах 0.5% от MHA при значительной экономии памяти. Именно так LLaMA 2 70B была конвертирована в GQA.


Что используют модели

Модель Mechanism Q heads KV heads Ratio
LLaMA 3 8B GQA 32 8 4:1
LLaMA 3 70B GQA 64 8 8:1
Mistral 7B GQA 32 8 4:1
Mixtral 8x7B GQA 32 8 4:1
Gemma 2 9B GQA 16 8 2:1
Qwen 2.5 72B GQA 64 8 8:1
PaLM MQA 16 1 16:1
Falcon 40B MQA 64 1 64:1
StarCoder MQA 48 1 48:1
GPT-4 Неизвестно -- -- --

Тренд 2024-2025: GQA с ratio 4:1 или 8:1 стал стандартом. MQA слишком агрессивно; MHA слишком расточительно.


Связь с другими оптимизациями

GQA/MQA -- это оптимизация памяти KV cache. Она отлично сочетается с:

  • FlashAttention -- оптимизация вычислений (IO-aware, тайлинг). Ортогонально: GQA уменьшает что хранить, FlashAttention ускоряет как считать.
  • PagedAttention (vLLM) -- оптимизация аллокации памяти (пагинация блоков). GQA делает блоки меньше, PagedAttention распределяет их эффективнее.
  • Quantization -- уменьшение precision (FP16 -> INT8/INT4). Меньше KV heads + ниже precision = максимальная экономия.
  • KV cache eviction -- удаление старых KV пар (H2O, StreamingLLM). Меньше heads = меньше данных для eviction.

Комбинация: GQA + FlashAttention + PagedAttention + INT8 KV cache = стандартный продакшен стек 2025.


Interview вопросы

Conceptual

Q: Объясните разницу между MHA, MQA и GQA.

Strong: "MHA -- каждый head имеет свои KV, максимум выразительности но O(H) memory. MQA -- один KV на все heads, O(1) memory но потеря качества. GQA -- G групп по H/G heads, O(G) memory, sweet spot. На практике GQA с G=8 дает качество как MHA при 4-8x экономии, потому что соседние heads часто учат похожие KV паттерны."

Red flag: "GQA -- это упрощенная версия attention" (нет, это архитектурная оптимизация KV cache).

Q: Почему GQA не теряет качество при шаринге KV?

Эмпирическое наблюдение: Key/Value проекции соседних heads высоко коррелированы. Исследования показывают cosine similarity > 0.9 для соседних heads. Шаринг KV внутри группы -- форма structured weight sharing, аналогичная weight tying в embedding/output layers.

Design

Q: У вас модель с 32 heads. Как выбрать количество KV groups?

Decision framework:

  • G=32 (MHA): нужно максимальное качество, memory не проблема (маленькая модель, короткие контексты)
  • G=8: стандартный выбор, хороший баланс (LLaMA 3, Mistral)
  • G=4: нужна агрессивная экономия, длинные контексты
  • G=1 (MQA): экстремальная экономия, допустима потеря качества (кодовые модели типа StarCoder)

Задача для самопроверки

Рассчитайте KV cache для модели: 32 layers, 32 Q heads, d_model=4096, seq_len=16384, fp16. Сравните MHA vs GQA (G=8) vs MQA. При каком batch size разница становится критичной (> 80 GB A100)?


See Also

Sources

  1. Multi-Query Attention (Shazeer, 2019)
  2. GQA: Training Generalized Multi-Query Transformer Models (Ainslie et al., 2023)
  3. LLaMA 2 (Touvron et al., 2023)
  4. NVIDIA TensorRT-LLM -- GQA