Нормализация в LLM: глубокий разбор¶
~6 минут чтения
Предварительно: Реализация внимания с нуля
Зачем нормализация¶
Нейронная сеть -- оркестр, где каждый нейрон "играет" с разной громкостью. Без нормализации один слой выдает значения порядка тысяч, а следующий -- порядка сотых. Нормализация приводит все "голоса" к одному масштабу, чтобы сеть стабильно обучалась.
Ключевой инсайт: RMSNorm стал стандартом в современных LLM (LLaMA, Mistral, Qwen), потому что убирает ненужное mean-centering из LayerNorm. Эмбеддинги LLM уже имеют распределение, близкое к нулевому среднему -- повторное центрирование избыточно. Результат: та же точность на 15-30% быстрее.
Распространенность (2026)¶
| Метод | Доля | Используется в | Ключевое преимущество |
|---|---|---|---|
| RMSNorm | ~80% | LLaMA, Mistral, Qwen, PaLM | 15% faster, simpler |
| LayerNorm | ~15% | BERT, GPT-2, T5 | Original standard |
| GroupNorm | ~5% | Vision Transformers | Batch-independent |
| BatchNorm | <1% | Rarely in LLMs | Batch-dependent (bad for variable length) |
Таксономия нормализаций¶
Эволюция¶
graph TD
BN["2015: BatchNorm<br/>По batch dimension<br/>Зависим от batch size"] --> LN["2016: LayerNorm<br/>По feature dimension<br/>Batch-independent"]
LN --> GN["2018: GroupNorm<br/>По группам каналов<br/>Vision, малые batch"]
LN --> RMS["2019: RMSNorm<br/>Без mean-centering<br/>~15% быстрее LayerNorm"]
style BN fill:#fce4ec,stroke:#c62828
style LN fill:#fff3e0,stroke:#ef6c00
style GN fill:#f3e5f5,stroke:#9c27b0
style RMS fill:#e8f5e9,stroke:#4caf50
Формулы¶
LayerNorm (Original Transformer)¶
LayerNorm
Где:
- $\mu = \frac{1}{d}\sum_{i=1}^{d} x_i$ -- среднее
- $\sigma = \sqrt{\frac{1}{d}\sum_{i=1}^{d}(x_i - \mu)^2 + \epsilon}$ -- стандартное отклонение
- $\gamma, \beta$ -- обучаемые параметры (scale и shift)
Свойства:
- 2 обучаемых параметра на измерение (\(\gamma\), \(\beta\))
- Центрирование по среднему + нормализация по дисперсии
- 2 прохода по данным (среднее, затем дисперсия)
RMSNorm (Modern LLMs)¶
RMSNorm -- стандарт 2024-2026
Где:
- $\text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$
- $\gamma$ -- обучаемый параметр масштабирования
- **Нет** bias параметра $\beta$
- **Нет** вычитания среднего
**Числовой пример:** для вектора $x = [1, 2, 3, 4]$:
- LayerNorm: $\mu=2.5$, $\sigma=1.12$, нормализованный $= [-1.34, -0.45, 0.45, 1.34]$
- RMSNorm: $\text{RMS}=\sqrt{(1+4+9+16)/4}=2.74$, нормализованный $= [0.37, 0.73, 1.10, 1.46]$
Видно: RMSNorm сохраняет **знак и направление**, LayerNorm центрирует вокруг нуля.
Свойства:
- 1 обучаемый параметр на измерение (только \(\gamma\))
- Без центрирования по среднему
- 1 проход по данным (только RMS)
- На 15-30% быстрее LayerNorm
GroupNorm¶
Где \(\mu_g\), \(\sigma_g\) вычисляются по группам каналов.
Свойства:
- Группы делят feature dimension
- Не зависит от batch size
- Используется в vision моделях
Почему RMSNorm побеждает¶
-
Эмбеддинги уже центрированы. Word embeddings инициализируются ~\(\mathcal{N}(0, \sigma^2)\), после тренировки среднее остается близким к 0. Mean-centering в LayerNorm -- лишнее вычисление.
-
Re-centering разрушает направление. Направление эмбеддинга кодирует семантику. Вычитание среднего (LayerNorm) сдвигает направление вектора. RMSNorm сохраняет исходное направление, масштабируя только амплитуду.
-
Вычислительная эффективность. LayerNorm:
mean(x) -> (x-mu) -> var(x-mu) -> norm(2 редукции). RMSNorm:rms(x) -> x/rms(x)(1 редукция). Экономия 15-30%. -
Стабильность градиентов. У RMSNorm проще формула градиента, меньше вероятность gradient issues, стабильнее тренировка на практике.
RMSNorm -- это не 'LayerNorm без bias'
Распространенная ошибка: думать, что RMSNorm = LayerNorm с \(\beta=0\). На самом деле RMSNorm убирает и mean-centering, и bias параметр. Формула принципиально другая: вместо \((x - \mu) / \sigma\) используется \(x / \text{RMS}(x)\). Нельзя получить RMSNorm просто убрав bias из LayerNorm.
Не меняй позицию нормализации при fine-tuning
Модель обучена с Pre-Norm? Не переключай на Post-Norm (и наоборот). Residual connections настроились на конкретную позицию нормализации. Смена позиции = разрушение обученных весов.
Сравнение производительности¶
| Метрика | LayerNorm | RMSNorm | Разница |
|---|---|---|---|
| Forward pass | 2 редукции | 1 редукция | ~15% быстрее |
| Параметры | \(\gamma + \beta\) | только \(\gamma\) | На 50% меньше |
| Память | 2 буфера | 1 буфер | Меньше памяти |
| Стабильность | Хорошая | Хорошая | Одинаково |
| Perplexity | Baseline | Такой же или лучше | Без потери качества |
Реализация на PyTorch¶
RMSNorm¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization.
Used by LLaMA, Mistral, Qwen, PaLM.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Compute RMS
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
# Normalize and scale
return (x / rms) * self.weight
class LayerNorm(nn.Module):
"""Standard Layer Normalization.
Used by BERT, GPT-2, T5, original Transformer.
"""
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
return self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
class GroupNorm(nn.Module):
"""Group Normalization.
Used by Vision Transformers, detection models.
"""
def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
super().__init__()
self.num_groups = num_groups
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
Fused RMSNorm (for efficiency)¶
# Using PyTorch's native LayerNorm as reference
# For production, use fused kernels from apex or flash-attention
try:
from apex.normalization import FusedRMSNorm
RMSNorm = FusedRMSNorm # Use optimized version if available
except ImportError:
pass # Fall back to pure PyTorch
Когда что использовать¶
Гайд по выбору¶
| Архитектура | Рекомендация | Причина |
|---|---|---|
| Decoder-only LLM | RMSNorm | Стандарт, эффективен, проверен |
| Encoder (BERT-style) | LayerNorm | Оригинальный дизайн, pretrained веса |
| Vision Transformer | LayerNorm или GroupNorm | Image patches, пространственная структура |
| Multimodal | Зависит от задачи | Выбирать по доминирующей модальности |
| Fine-tuning | Оставить оригинал | Не менять нормализацию |
Pre-Norm vs Post-Norm¶
graph TD
subgraph post["Post-Norm (Original Transformer)"]
X1["x"] --> ATT1["Attention"] --> ADD1["+ residual"] --> NORM1["Norm"] --> FFN1["FFN"] --> ADD2["+ residual"] --> NORM2["Norm"] --> OUT1["out"]
end
subgraph pre["Pre-Norm (LLaMA, GPT-2+, Mistral)"]
X2["x"] --> NORM3["Norm"] --> ATT2["Attention"] --> ADD3["+ residual"] --> NORM4["Norm"] --> FFN2["FFN"] --> ADD4["+ residual"] --> OUT2["out"]
end
style post fill:#fce4ec,stroke:#c62828
style pre fill:#e8f5e9,stroke:#4caf50
Post-Norm: градиенты проходят через norm layers -- менее стабильно. Pre-Norm: чистые residual пути для градиентов -- стандарт для всех современных LLM.
Полное сравнение¶
Матрица сравнения¶
| Свойство | BatchNorm | LayerNorm | GroupNorm | RMSNorm |
|---|---|---|---|---|
| Зависит от batch | Да | Нет | Нет | Нет |
| Mean-centering | Да | Да | Да | Нет |
| Обучаемые параметры | γ, β | γ, β | γ, β | только γ |
| Compute cost | Низкий | Средний | Средний | Низкий |
| Доля в LLM | <1% | 15% | 5% | 80% |
| Стабильность | Зависит от batch | Хорошая | Хорошая | Хорошая |
| Inference | Зависит от batch | Консистентный | Консистентный | Консистентный |
Использование в моделях (2026)¶
| Модель | Нормализация | Архитектура |
|---|---|---|
| LLaMA 3 | RMSNorm | Pre-norm |
| Mistral | RMSNorm | Pre-norm |
| Qwen 2 | RMSNorm | Pre-norm |
| GPT-4 | LayerNorm (speculated) | Pre-norm |
| BERT | LayerNorm | Post-norm |
| T5 | LayerNorm | Pre-norm |
| ViT | LayerNorm | Pre-norm |
Ключевые числа¶
Сравнение вычислений¶
| Операция | LayerNorm | RMSNorm | Ускорение |
|---|---|---|---|
| Mean | 1 редукция | — | — |
| Variance | 1 редукция | — | — |
| RMS | — | 1 редукция | — |
| Всего редукций | 2 | 1 | ~15% |
| Буферы памяти | 2 (mean, var) | 1 (rms) | Меньше |
Количество параметров¶
| Размер модели | Параметры LayerNorm | Параметры RMSNorm | Экономия |
|---|---|---|---|
| 7B | ~1.3M | ~0.65M | 0.65M |
| 70B | ~13M | ~6.5M | 6.5M |
Скорость обучения¶
| Модель | LayerNorm | RMSNorm | Ускорение |
|---|---|---|---|
| LLaMA-7B | — | 1x (baseline) | — |
| С LayerNorm | ~1.15x медленнее | — | -15% |
Метрики стабильности¶
| Метрика | LayerNorm | RMSNorm |
|---|---|---|
| Дисперсия нормы градиентов | Baseline | Такая же или ниже |
| Скачки loss | Иногда | Реже |
| Финальный perplexity | Baseline | Такой же или лучше |
Interview Questions¶
1. Почему современные LLM перешли с LayerNorm на RMSNorm?¶
Red flag: "RMSNorm быстрее, потому что проще формула"
Strong answer: "RMSNorm убирает mean-centering из LayerNorm. Это обосновано: word embeddings инициализируются с нулевым средним, и после обучения среднее остается близким к нулю -- mean-centering избыточен. Кроме того, вычитание среднего сдвигает направление эмбеддинга, которое кодирует семантику. RMSNorm делает одну редукцию вместо двух, экономя 15-30% compute при том же качестве."
2. Pre-Norm vs Post-Norm: что лучше и почему?¶
Red flag: "Pre-Norm лучше, потому что все современные модели его используют"
Strong answer: "Pre-Norm дает чистые residual paths для градиентов -- нормализация применяется до attention/FFN, поэтому градиенты текут через residual connections без искажений. Post-Norm пропускает градиенты через norm layers, что менее стабильно при глубоких сетях. Все современные LLM (LLaMA, Mistral, GPT-3+) используют Pre-Norm. Post-Norm встречается в BERT и оригинальном Transformer."
3. Можно ли заменить RMSNorm на LayerNorm в pretrained модели?¶
Red flag: "Да, они похожи, можно просто поменять"
Strong answer: "Нет. Веса модели обучались с конкретным типом нормализации. RMSNorm масштабирует по амплитуде без сдвига, LayerNorm центрирует и масштабирует. Замена изменит распределение активаций, сломав обученные паттерны. Если нужно конвертировать -- потребуется дообучение (минимум LoRA fine-tuning)."
Самопроверка
- Вычислите RMSNorm для вектора \(x = [2, -1, 3, 0]\) с \(\gamma = [1, 1, 1, 1]\). Сравните с результатом LayerNorm.
- В модели 32 трансформер-блока. Каждый блок использует Pre-Norm с RMSNorm (dim=4096). Сколько обучаемых параметров приходится на нормализацию? А если заменить на LayerNorm?
- Объясните, почему BatchNorm не подходит для LLM (подсказка: подумайте о переменной длине последовательности и inference с batch_size=1).
Sources¶
- arXiv — "Root Mean Square Layer Normalization" (Zhang & Sennrich, 2019)
- Papers with Code — RMSNorm benchmarks and implementations
- Hugging Face — Transformers normalization documentation
- Medium — "Why RMSNorm is the Default for Modern LLMs" (2025)
- Reddit r/MachineLearning — "RMSNorm vs LayerNorm in practice" (2025)
- LLaMA Paper — "LLaMA: Open and Efficient Foundation Language Models"
- Mistral Technical Report — Architecture details
- Qwen Technical Report — RMSNorm implementation
- PyTorch Documentation — Normalization layers
- NVIDIA Apex — FusedRMSNorm implementation
See Also¶
- Activation Functions in LLMs -- SwiGLU + RMSNorm = стандартная архитектурная пара
- Normalization Comparison -- сводная таблица BatchNorm vs LayerNorm vs RMSNorm vs GroupNorm
- PyTorch Cheatsheet -- nn.LayerNorm, nn.RMSNorm, model.train()/eval()
- Efficient Transformers -- архитектурные решения, включая выбор нормализации
- Flash Attention 3 -- kernel fusion для attention, аналогично fused RMSNorm