Перейти к содержанию

Нормализация в LLM: глубокий разбор

~6 минут чтения

Предварительно: Реализация внимания с нуля


Зачем нормализация

Нейронная сеть -- оркестр, где каждый нейрон "играет" с разной громкостью. Без нормализации один слой выдает значения порядка тысяч, а следующий -- порядка сотых. Нормализация приводит все "голоса" к одному масштабу, чтобы сеть стабильно обучалась.

Ключевой инсайт: RMSNorm стал стандартом в современных LLM (LLaMA, Mistral, Qwen), потому что убирает ненужное mean-centering из LayerNorm. Эмбеддинги LLM уже имеют распределение, близкое к нулевому среднему -- повторное центрирование избыточно. Результат: та же точность на 15-30% быстрее.

Распространенность (2026)

Метод Доля Используется в Ключевое преимущество
RMSNorm ~80% LLaMA, Mistral, Qwen, PaLM 15% faster, simpler
LayerNorm ~15% BERT, GPT-2, T5 Original standard
GroupNorm ~5% Vision Transformers Batch-independent
BatchNorm <1% Rarely in LLMs Batch-dependent (bad for variable length)

Таксономия нормализаций

Эволюция

graph TD
    BN["2015: BatchNorm<br/>По batch dimension<br/>Зависим от batch size"] --> LN["2016: LayerNorm<br/>По feature dimension<br/>Batch-independent"]
    LN --> GN["2018: GroupNorm<br/>По группам каналов<br/>Vision, малые batch"]
    LN --> RMS["2019: RMSNorm<br/>Без mean-centering<br/>~15% быстрее LayerNorm"]

    style BN fill:#fce4ec,stroke:#c62828
    style LN fill:#fff3e0,stroke:#ef6c00
    style GN fill:#f3e5f5,stroke:#9c27b0
    style RMS fill:#e8f5e9,stroke:#4caf50

Формулы

LayerNorm (Original Transformer)

LayerNorm

\[\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sigma} + \beta\]

Где:

- $\mu = \frac{1}{d}\sum_{i=1}^{d} x_i$ -- среднее
- $\sigma = \sqrt{\frac{1}{d}\sum_{i=1}^{d}(x_i - \mu)^2 + \epsilon}$ -- стандартное отклонение
- $\gamma, \beta$ -- обучаемые параметры (scale и shift)

Свойства:

  • 2 обучаемых параметра на измерение (\(\gamma\), \(\beta\))
  • Центрирование по среднему + нормализация по дисперсии
  • 2 прохода по данным (среднее, затем дисперсия)

RMSNorm (Modern LLMs)

RMSNorm -- стандарт 2024-2026

\[\text{RMSNorm}(x) = \gamma \cdot \frac{x}{\text{RMS}(x)}\]

Где:

- $\text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$
- $\gamma$ -- обучаемый параметр масштабирования
- **Нет** bias параметра $\beta$
- **Нет** вычитания среднего

**Числовой пример:** для вектора $x = [1, 2, 3, 4]$:

- LayerNorm: $\mu=2.5$, $\sigma=1.12$, нормализованный $= [-1.34, -0.45, 0.45, 1.34]$
- RMSNorm: $\text{RMS}=\sqrt{(1+4+9+16)/4}=2.74$, нормализованный $= [0.37, 0.73, 1.10, 1.46]$

Видно: RMSNorm сохраняет **знак и направление**, LayerNorm центрирует вокруг нуля.

Свойства:

  • 1 обучаемый параметр на измерение (только \(\gamma\))
  • Без центрирования по среднему
  • 1 проход по данным (только RMS)
  • На 15-30% быстрее LayerNorm

GroupNorm

\[\text{GroupNorm}(x) = \gamma \cdot \frac{x - \mu_g}{\sigma_g} + \beta\]

Где \(\mu_g\), \(\sigma_g\) вычисляются по группам каналов.

Свойства:

  • Группы делят feature dimension
  • Не зависит от batch size
  • Используется в vision моделях

Почему RMSNorm побеждает

  1. Эмбеддинги уже центрированы. Word embeddings инициализируются ~\(\mathcal{N}(0, \sigma^2)\), после тренировки среднее остается близким к 0. Mean-centering в LayerNorm -- лишнее вычисление.

  2. Re-centering разрушает направление. Направление эмбеддинга кодирует семантику. Вычитание среднего (LayerNorm) сдвигает направление вектора. RMSNorm сохраняет исходное направление, масштабируя только амплитуду.

  3. Вычислительная эффективность. LayerNorm: mean(x) -> (x-mu) -> var(x-mu) -> norm (2 редукции). RMSNorm: rms(x) -> x/rms(x) (1 редукция). Экономия 15-30%.

  4. Стабильность градиентов. У RMSNorm проще формула градиента, меньше вероятность gradient issues, стабильнее тренировка на практике.

RMSNorm -- это не 'LayerNorm без bias'

Распространенная ошибка: думать, что RMSNorm = LayerNorm с \(\beta=0\). На самом деле RMSNorm убирает и mean-centering, и bias параметр. Формула принципиально другая: вместо \((x - \mu) / \sigma\) используется \(x / \text{RMS}(x)\). Нельзя получить RMSNorm просто убрав bias из LayerNorm.

Не меняй позицию нормализации при fine-tuning

Модель обучена с Pre-Norm? Не переключай на Post-Norm (и наоборот). Residual connections настроились на конкретную позицию нормализации. Смена позиции = разрушение обученных весов.

Сравнение производительности

Метрика LayerNorm RMSNorm Разница
Forward pass 2 редукции 1 редукция ~15% быстрее
Параметры \(\gamma + \beta\) только \(\gamma\) На 50% меньше
Память 2 буфера 1 буфер Меньше памяти
Стабильность Хорошая Хорошая Одинаково
Perplexity Baseline Такой же или лучше Без потери качества

Реализация на PyTorch

RMSNorm

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Used by LLaMA, Mistral, Qwen, PaLM.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute RMS
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        # Normalize and scale
        return (x / rms) * self.weight


class LayerNorm(nn.Module):
    """Standard Layer Normalization.

    Used by BERT, GPT-2, T5, original Transformer.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias


class GroupNorm(nn.Module):
    """Group Normalization.

    Used by Vision Transformers, detection models.
    """

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)

Fused RMSNorm (for efficiency)

# Using PyTorch's native LayerNorm as reference
# For production, use fused kernels from apex or flash-attention

try:
    from apex.normalization import FusedRMSNorm
    RMSNorm = FusedRMSNorm  # Use optimized version if available
except ImportError:
    pass  # Fall back to pure PyTorch

Когда что использовать

Гайд по выбору

Архитектура Рекомендация Причина
Decoder-only LLM RMSNorm Стандарт, эффективен, проверен
Encoder (BERT-style) LayerNorm Оригинальный дизайн, pretrained веса
Vision Transformer LayerNorm или GroupNorm Image patches, пространственная структура
Multimodal Зависит от задачи Выбирать по доминирующей модальности
Fine-tuning Оставить оригинал Не менять нормализацию

Pre-Norm vs Post-Norm

graph TD
    subgraph post["Post-Norm (Original Transformer)"]
        X1["x"] --> ATT1["Attention"] --> ADD1["+ residual"] --> NORM1["Norm"] --> FFN1["FFN"] --> ADD2["+ residual"] --> NORM2["Norm"] --> OUT1["out"]
    end
    subgraph pre["Pre-Norm (LLaMA, GPT-2+, Mistral)"]
        X2["x"] --> NORM3["Norm"] --> ATT2["Attention"] --> ADD3["+ residual"] --> NORM4["Norm"] --> FFN2["FFN"] --> ADD4["+ residual"] --> OUT2["out"]
    end

    style post fill:#fce4ec,stroke:#c62828
    style pre fill:#e8f5e9,stroke:#4caf50

Post-Norm: градиенты проходят через norm layers -- менее стабильно. Pre-Norm: чистые residual пути для градиентов -- стандарт для всех современных LLM.


Полное сравнение

Матрица сравнения

Свойство BatchNorm LayerNorm GroupNorm RMSNorm
Зависит от batch Да Нет Нет Нет
Mean-centering Да Да Да Нет
Обучаемые параметры γ, β γ, β γ, β только γ
Compute cost Низкий Средний Средний Низкий
Доля в LLM <1% 15% 5% 80%
Стабильность Зависит от batch Хорошая Хорошая Хорошая
Inference Зависит от batch Консистентный Консистентный Консистентный

Использование в моделях (2026)

Модель Нормализация Архитектура
LLaMA 3 RMSNorm Pre-norm
Mistral RMSNorm Pre-norm
Qwen 2 RMSNorm Pre-norm
GPT-4 LayerNorm (speculated) Pre-norm
BERT LayerNorm Post-norm
T5 LayerNorm Pre-norm
ViT LayerNorm Pre-norm

Ключевые числа

Сравнение вычислений

Операция LayerNorm RMSNorm Ускорение
Mean 1 редукция
Variance 1 редукция
RMS 1 редукция
Всего редукций 2 1 ~15%
Буферы памяти 2 (mean, var) 1 (rms) Меньше

Количество параметров

Размер модели Параметры LayerNorm Параметры RMSNorm Экономия
7B ~1.3M ~0.65M 0.65M
70B ~13M ~6.5M 6.5M

Скорость обучения

Модель LayerNorm RMSNorm Ускорение
LLaMA-7B 1x (baseline)
С LayerNorm ~1.15x медленнее -15%

Метрики стабильности

Метрика LayerNorm RMSNorm
Дисперсия нормы градиентов Baseline Такая же или ниже
Скачки loss Иногда Реже
Финальный perplexity Baseline Такой же или лучше

Interview Questions

1. Почему современные LLM перешли с LayerNorm на RMSNorm?

❌ Red flag: "RMSNorm быстрее, потому что проще формула"

✅ Strong answer: "RMSNorm убирает mean-centering из LayerNorm. Это обосновано: word embeddings инициализируются с нулевым средним, и после обучения среднее остается близким к нулю -- mean-centering избыточен. Кроме того, вычитание среднего сдвигает направление эмбеддинга, которое кодирует семантику. RMSNorm делает одну редукцию вместо двух, экономя 15-30% compute при том же качестве."

2. Pre-Norm vs Post-Norm: что лучше и почему?

❌ Red flag: "Pre-Norm лучше, потому что все современные модели его используют"

✅ Strong answer: "Pre-Norm дает чистые residual paths для градиентов -- нормализация применяется до attention/FFN, поэтому градиенты текут через residual connections без искажений. Post-Norm пропускает градиенты через norm layers, что менее стабильно при глубоких сетях. Все современные LLM (LLaMA, Mistral, GPT-3+) используют Pre-Norm. Post-Norm встречается в BERT и оригинальном Transformer."

3. Можно ли заменить RMSNorm на LayerNorm в pretrained модели?

❌ Red flag: "Да, они похожи, можно просто поменять"

✅ Strong answer: "Нет. Веса модели обучались с конкретным типом нормализации. RMSNorm масштабирует по амплитуде без сдвига, LayerNorm центрирует и масштабирует. Замена изменит распределение активаций, сломав обученные паттерны. Если нужно конвертировать -- потребуется дообучение (минимум LoRA fine-tuning)."


Самопроверка

  1. Вычислите RMSNorm для вектора \(x = [2, -1, 3, 0]\) с \(\gamma = [1, 1, 1, 1]\). Сравните с результатом LayerNorm.
  2. В модели 32 трансформер-блока. Каждый блок использует Pre-Norm с RMSNorm (dim=4096). Сколько обучаемых параметров приходится на нормализацию? А если заменить на LayerNorm?
  3. Объясните, почему BatchNorm не подходит для LLM (подсказка: подумайте о переменной длине последовательности и inference с batch_size=1).

Sources

  1. arXiv — "Root Mean Square Layer Normalization" (Zhang & Sennrich, 2019)
  2. Papers with Code — RMSNorm benchmarks and implementations
  3. Hugging Face — Transformers normalization documentation
  4. Medium — "Why RMSNorm is the Default for Modern LLMs" (2025)
  5. Reddit r/MachineLearning — "RMSNorm vs LayerNorm in practice" (2025)
  6. LLaMA Paper — "LLaMA: Open and Efficient Foundation Language Models"
  7. Mistral Technical Report — Architecture details
  8. Qwen Technical Report — RMSNorm implementation
  9. PyTorch Documentation — Normalization layers
  10. NVIDIA Apex — FusedRMSNorm implementation

See Also