Перейти к содержанию

xLSTM: расширенная архитектура долгой краткосрочной памяти

~8 минут чтения

Предварительно: Mamba и SSM | Эффективные трансформеры

xLSTM -- попытка Sepp Hochreiter (создателя оригинального LSTM 1997 года) вернуть рекуррентные архитектуры в игру. Ключевое достижение: xLSTM-7B показывает inference в 2.5x быстрее Llama-7B при постоянном потреблении памяти ~7 GB (против 25+ GB KV-кэша трансформера на 128K контексте). Линейная сложность O(T) вместо квадратичной O(T^2) делает xLSTM особенно выгодным на длинных последовательностях -- при контексте 128K модель xLSTM на 78% компактнее эквивалентного трансформера.


Обзор

xLSTM (Extended Long Short-Term Memory) — эволюция LSTM архитектуры от Sepp Hochreiter (создателя оригинального LSTM), преодолевающая ограничения классических LSTM через exponential gating и matrix memory.

Ключевая инновация

\[ \text{xLSTM} = \text{Exponential Gating} + \text{Matrix Memory} + \text{Parallelizable Structure} \]

Вместо: O(T²) Transformer attention или ограниченного LSTM горизонта Используется: Linear O(T) complexity с unlimited context length

Сравнение: xLSTM vs Transformer vs LSTM

Свойство LSTM Transformer xLSTM
Сложность O(T) последовательно O(T^2) параллельно O(T) параллелизуемо
Память Фиксированное hidden state Растет с KV cache Фиксированная (matrix memory)
Длина контекста Ограничена (~100-500) Ограничена кэшем Не ограничена
Обучение Последовательное Параллельное Параллельное
Inference Быстрый, постоянный Медленный на длинных Быстрый, постоянный

От LSTM к xLSTM

Ограничения оригинального LSTM

  1. Bottleneck памяти: скалярное cell state \(c_t\) ограничивает ёмкость
  2. Последовательная обработка: нельзя параллелизовать по времени
  3. Ограниченный горизонт: градиенты затухают на длинных последовательностях
  4. Mixing памяти: только через hidden state \(h_t\)

Решения xLSTM

Проблема Решение
Скалярная память Matrix memory (mLSTM)
Последовательные обновления Параллелизуемый covariance update
Vanishing gradients Exponential gating
Ограниченный mixing Новые паттерны memory mixing

Строительные блоки xLSTM

A. sLSTM (Scalar Memory LSTM)

Свойства: - Scalar memory \(c_t\) (like original LSTM) - Exponential gating with normalization - New memory mixing mechanism

Formula:

\[ f_t = \exp(\text{linear}_f(x_t)) / \sum_j \exp(\text{linear}_f(x_j)) \]
\[ i_t = \exp(\text{linear}_i(x_t)) / \sum_j \exp(\text{linear}_i(x_j)) \]
\[ c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \]

Ключевая инновация: Exponential gating + нормализация предотвращают gradient explosion.

B. mLSTM (Matrix Memory LSTM)

Свойства: - Matrix memory \(C_t \in \mathbb{R}^{d \times d}\) - Covariance update rule (fully parallelizable) - Linear complexity

Formula:

\[ C_t = f_t \odot C_{t-1} + i_t \odot (v_t \cdot k_t^T) \]
\[ h_t = o_t \odot (C_t \cdot q_t) \]

Where: - \(C_t\) — matrix memory (stores key-query covariance) - \(q_t, k_t, v_t\) — query, key, value vectors - \(f_t, i_t, o_t\) — forget, input, output gates

Параллелизация:

\[ C_t = \sum_{i=1}^{t} \left(\prod_{j=i+1}^{t} f_j\right) \cdot i_i \cdot v_i \cdot k_i^T \]

This can be computed via parallel scan!

Сравнение sLSTM vs mLSTM

Свойство sLSTM mLSTM
Память Скалярная \(c_t\) Матричная \(C_t \in d \times d\)
Ёмкость Ограниченная Высокая
Параллелизация Последовательная Полностью параллельная
Применение Короткие паттерны Long-range зависимости

Exponential Gating

Зачем экспоненциальные гейты?

Стандартные sigmoid-гейты насыщаются, вызывая: - Vanishing gradients - Information bottleneck

Решение: экспоненциальная активация с нормализацией.

Реализация

import torch
import torch.nn.functional as F

class ExponentialGate(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        # Exponential activation
        exp_values = torch.exp(self.linear(x))
        # Normalize (prevent overflow)
        return exp_values / (exp_values.sum(dim=-1, keepdim=True) + 1e-8)

Техники стабилизации

  1. Log-space computation: избежание numerical overflow
  2. Clipping: ограничение аргументов экспоненты
  3. LayerNorm: нормализация активаций

Архитектура xLSTM блока

Структура блока

graph TD
    A["Input projection"] --> B["sLSTM layer(s)<br/>локальные паттерны"]
    B --> C["mLSTM layer(s)<br/>глобальные зависимости"]
    C --> D["Output projection"]
    D --> E["LayerNorm"]
    E --> F["Residual connection"]
    A2["Input x"] --> A
    A2 --> F
    style A fill:#e8eaf6,stroke:#3f51b5
    style B fill:#fff3e0,stroke:#ef6c00
    style C fill:#e8f5e9,stroke:#4caf50
    style D fill:#e8eaf6,stroke:#3f51b5
    style E fill:#f3e5f5,stroke:#9c27b0
    style F fill:#f3e5f5,stroke:#9c27b0

Реализация

import torch
import torch.nn as nn

class xLSTMBlock(nn.Module):
    def __init__(self, d_model, d_hidden, d_head, num_heads=4):
        super().__init__()

        # Input projection
        self.in_proj = nn.Linear(d_model, d_hidden)

        # sLSTM: scalar memory for local patterns
        self.slstm = sLSTMLayer(d_hidden, d_head)

        # mLSTM: matrix memory for global patterns
        self.mlstm = mLSTMLayer(d_hidden, d_head, num_heads)

        # Output projection
        self.out_proj = nn.Linear(d_hidden, d_model)

        # Normalization
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        residual = x

        # Project to hidden dimension
        h = self.in_proj(x)

        # sLSTM for local patterns
        h = self.slstm(h)

        # mLSTM for global patterns
        h = self.mlstm(h)

        # Project back and residual
        out = self.out_proj(h)
        return self.norm(out + residual)


class mLSTMLayer(nn.Module):
    """Matrix Memory LSTM - fully parallelizable"""

    def __init__(self, d_hidden, d_head, num_heads):
        super().__init__()
        self.d_head = d_head
        self.num_heads = num_heads

        # Query, Key, Value projections
        self.q_proj = nn.Linear(d_hidden, num_heads * d_head)
        self.k_proj = nn.Linear(d_hidden, num_heads * d_head)
        self.v_proj = nn.Linear(d_hidden, num_heads * d_head)

        # Gates (exponential)
        self.f_gate = nn.Linear(d_hidden, num_heads)
        self.i_gate = nn.Linear(d_hidden, num_heads)
        self.o_gate = nn.Linear(d_hidden, num_heads * d_head)

        # Output projection
        self.out_proj = nn.Linear(num_heads * d_head, d_hidden)

    def forward(self, x):
        B, T, D = x.shape

        # Project Q, K, V
        q = self.q_proj(x).view(B, T, self.num_heads, self.d_head)
        k = self.k_proj(x).view(B, T, self.num_heads, self.d_head)
        v = self.v_proj(x).view(B, T, self.num_heads, self.d_head)

        # Exponential gates
        f = torch.sigmoid(self.f_gate(x))  # (B, T, num_heads)
        i = torch.exp(self.i_gate(x) - self.i_gate(x).max(dim=1, keepdim=True)[0])
        i = i / (i.sum(dim=1, keepdim=True) + 1e-8)  # Normalize
        o = torch.sigmoid(self.o_gate(x)).view(B, T, self.num_heads, self.d_head)

        # Covariance update (simplified - use parallel scan for efficiency)
        # C_t = f_t * C_{t-1} + i_t * v_t * k_t^T
        # This can be parallelized using associative scan

        # For now, sequential version:
        h = []
        C = torch.zeros(B, self.num_heads, self.d_head, self.d_head, device=x.device)

        for t in range(T):
            k_t = k[:, t]  # (B, num_heads, d_head)
            v_t = v[:, t]
            f_t = f[:, t].unsqueeze(-1).unsqueeze(-1)  # (B, num_heads, 1, 1)
            i_t = i[:, t].unsqueeze(-1).unsqueeze(-1)
            q_t = q[:, t]

            # Update matrix memory
            C = f_t * C + i_t * torch.einsum('bhd,bhe->bhde', v_t, k_t)

            # Query memory
            h_t = torch.einsum('bhde,bhd->bhe', C, q_t)
            h.append(h_t)

        h = torch.stack(h, dim=1)  # (B, T, num_heads, d_head)

        # Apply output gate
        h = o * h
        h = h.reshape(B, T, -1)  # (B, T, num_heads * d_head)

        return self.out_proj(h)

xLSTM-7B (март 2025)

Paper: "xLSTM 7B: A Recurrent LLM for Fast and Efficient Inference" (arXiv:2503.13427)

Ключевые особенности

  • 7 billion parameters xLSTM-based LLM
  • Fastest inference among 7B models
  • Constant memory usage (no KV cache growth)
  • Open-source weights and code

Сравнение производительности

Модель Скорость inference Память (128K контекст)
Llama-7B Baseline ~25 GB (KV cache)
Mamba-7B 1.8× faster ~7 GB
xLSTM-7B 2.5× faster ~7 GB (constant)

Заявленная эффективность

  • 3.5× faster training than equivalent Transformer
  • 40.6% lower inference time vs Transformer-XL
  • 2.16× throughput on AMD MI300X vs NVIDIA H100

Законы масштабирования xLSTM (ICLR 2026)

Paper: "xLSTM Scaling Laws: Competitive Performance with Linear Time-Complexity"

Ключевые находки

  1. Pareto-доминирование: xLSTM стабильно показывает меньший loss при том же compute budget [Важно: claims от NXAI (создатели xLSTM), независимая валидация pending]
  2. Масштабирование по длине контекста: xLSTM получает больше пользы от длинных контекстов, чем трансформеры
  3. Inference scaling: линейная сложность становится критической на масштабе

Формула масштабирования

\[ L(N, D, C) = \frac{A}{N^\alpha} + \frac{B}{D^\beta} + E \]

Where: - \(N\) = number of parameters - \(D\) = training tokens - \(C\) = context length - xLSTM shows better \(\alpha, \beta\) than Transformers

Зависимость от длины контекста

Длина контекста Оптимальный Transformer Оптимальный xLSTM Преимущество xLSTM
2K tokens 7B params 7B params Baseline
8K tokens 13B params 9B params 30% smaller
32K tokens 30B params 12B params 60% smaller
128K tokens 70B+ params 15B params 78% smaller

Ключевой инсайт: xLSTM становится эффективнее по мере роста длины контекста.


Vision-xLSTM (ViL)

Paper: "Vision-LSTM: xLSTM as Generic Vision Backbone" (arXiv:2406.04303)

Архитектура ViL

graph TD
    A["Patch embedding"] --> B["Stack of xLSTM blocks"]
    B --> B1["Нечётные блоки:<br/>top-to-bottom scan"]
    B --> B2["Чётные блоки:<br/>bottom-to-top scan"]
    B1 --> C["Classification head"]
    B2 --> C
    C --> D["Output"]
    style A fill:#e8eaf6,stroke:#3f51b5
    style B fill:#fff3e0,stroke:#ef6c00
    style B1 fill:#e8f5e9,stroke:#4caf50
    style B2 fill:#e8f5e9,stroke:#4caf50
    style C fill:#f3e5f5,stroke:#9c27b0
    style D fill:#f3e5f5,stroke:#9c27b0

Двунаправленная обработка

class VisionxLSTMBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.xlstm_forward = xLSTMBlock(d_model)
        self.xlstm_backward = xLSTMBlock(d_model)

    def forward(self, x):
        # x: (B, N_patches, D)

        # Forward pass
        h_fwd = self.xlstm_forward(x)

        # Backward pass
        x_rev = torch.flip(x, dims=[1])
        h_bwd = self.xlstm_backward(x_rev)
        h_bwd = torch.flip(h_bwd, dims=[1])

        return h_fwd + h_bwd

Результаты

Модель ImageNet Top-1 Параметры FLOPs
ViT-B 77.9% 86M 17.6G
ViL-B 78.2% 82M 16.1G
ViT-L 82.6% 307M 61.6G
ViL-L 82.9% 298M 58.2G

Бенчмарки производительности

Языковое моделирование (Pile)

Модель Параметры Perplexity Время обучения
Transformer 1B 6.8 100%
Mamba 1B 6.5 65%
xLSTM 1B 6.3 45%

Скорость inference (tokens/sec)

Модель Batch=1 Batch=16 Batch=64
Llama-2-7B 52 180 320
Mamba-7B 95 340 580
xLSTM-7B 130 450 720

Потребление памяти

Контекст Transformer Mamba xLSTM
4K 4 GB 2 GB 2 GB
32K 16 GB 3 GB 2.5 GB
128K 64 GB 8 GB 3 GB
1M OOM 40 GB 5 GB

xLSTM vs SSM (Mamba) vs Transformer

Когда что использовать

Сценарий Рекомендация
Короткие последовательности (<1K) Transformer или xLSTM
Средние (1K-16K) xLSTM или Mamba
Длинные (>16K) xLSTM
Максимальное качество Transformer или xLSTM
Ограниченная память xLSTM или Mamba
Edge deployment xLSTM
Vision задачи ViT или Vision-xLSTM
Time series xLSTM

Итоговое сравнение

Свойство Transformer Mamba xLSTM
Качество Отличное Хорошее Отличное
Скорость Медленно (длинный ctx) Быстро Самое быстрое
Память Растет O(T) Постоянная Постоянная
Параллельное обучение Да Да Да
Зрелость Высокая Средняя Начальная
Экосистема Обширная Растущая Новая

Индустриальные применения

Коллаборация с AMD (2025)

Результаты на AMD Instinct MI300X: - 2.16× throughput improvement vs NVIDIA H100 - Optimized via PyTorch TunableOp - GEMM kernel tuning for xLSTM operations

Edge Computing

Benefits: - Low memory footprint - Constant inference time - Energy efficient - Unlimited context on limited hardware

Use Cases

  1. Time series: Industrial monitoring, finance
  2. Vision: Drones, autonomous vehicles
  3. Robotics: Large action models
  4. Biology: Bio-xLSTM for molecular sequences

Справочник формул

sLSTM Update

\[c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t\]

mLSTM Matrix Memory

\[C_t = f_t \odot C_{t-1} + i_t \odot (v_t \cdot k_t^T)\]
\[h_t = o_t \odot (C_t \cdot q_t)\]

Exponential Gate (normalized)

\[g = \frac{\exp(z)}{\sum_j \exp(z_j)}\]

Parallel Scan for mLSTM

\[C_t = \sum_{i=1}^{t} \left(\prod_{j=i+1}^{t} f_j\right) \cdot i_i \cdot v_i \cdot k_i^T\]

Самопроверка

  1. Сравнение памяти: Посчитайте KV-кэш трансформера 7B (d_model=4096, 32 layers, 32 heads) на 128K контексте. Сравните с размером матричной памяти mLSTM (d_head=128, 32 heads). Почему xLSTM потребляет ~3 GB, а трансформер ~25 GB?

  2. Parallel scan: mLSTM параллелизуется через associative scan. Объясните, почему операция \(C_t = f_t C_{t-1} + i_t v_t k_t^T\) ассоциативна. Запишите для 4 шагов и покажите, как их объединить параллельно (hint: binary tree reduction).

  3. Edge vs Cloud: Клиент хочет LLM для real-time анализа промышленных сенсоров (10K токенов/сек, контекст 100K+). GPU: NVIDIA Jetson (8 GB). Какую архитектуру выберете (Transformer/Mamba/xLSTM) и почему? Посчитайте, влезет ли модель 7B.


Типичные заблуждения

Заблуждение: xLSTM уже заменил трансформеры в production

xLSTM показывает впечатляющие бенчмарки (2.5x inference speedup, постоянная память), но экосистема находится на стадии «Emerging». У трансформеров -- годы оптимизации (Flash Attention, vLLM, TensorRT), тысячи pretrained моделей, огромное community. xLSTM пока не имеет аналогичной инфраструктуры. Большинство claims о Pareto-доминировании исходят от NXAI (создатели xLSTM) и требуют независимой валидации.

Заблуждение: Линейная сложность O(T) означает, что xLSTM всегда быстрее трансформера

При коротких последовательностях (<1K токенов) трансформер с Flash Attention может быть быстрее за счёт лучшей GPU-утилизации и оптимизированных ядер. Преимущество xLSTM нарастает с длиной контекста -- на 128K+ разница становится критической (78% меньше параметров для эквивалентного качества).

Заблуждение: mLSTM -- это просто attention с другим именем

mLSTM использует covariance update rule (\(C_t = f_t C_{t-1} + i_t v_t k_t^T\)), а не softmax attention. Ключевая разница: (1) есть forget gate -- старая информация экспоненциально затухает, (2) матричная память фиксированного размера d x d, без роста KV-кэша, (3) параллелизуется через associative scan, а не через matmul.


Вопросы для собеседования

В чём разница между sLSTM и mLSTM? Когда использовать какой?

❌ «sLSTM -- простой, mLSTM -- сложный, поэтому mLSTM всегда лучше» -- нет понимания.

✅ sLSTM использует скалярную память \(c_t\) с exponential gating -- подходит для коротких локальных паттернов, но обрабатывается последовательно. mLSTM использует матричную память \(C_t \in \mathbb{R}^{d \times d}\) с covariance update rule -- захватывает long-range зависимости и полностью параллелизуется через parallel scan. В xLSTM-блоке они комбинируются: sLSTM для local, mLSTM для global.

Почему xLSTM потребляет постоянную память при увеличении контекста, а трансформер -- нет?

❌ «xLSTM просто эффективнее» -- нет объяснения механизма.

✅ Трансформер хранит KV-кэш для каждого предыдущего токена: на 128K контексте это ~25 GB для 7B модели. xLSTM хранит фиксированную матричную память \(C_t \in \mathbb{R}^{d \times d}\) -- информация о прошлых токенах «сжата» в матрицу через covariance update. Новый токен обновляет \(C_t\) через forget/input gates, старая информация экспоненциально затухает. Итого: ~3-7 GB независимо от длины контекста.

Как exponential gating решает проблему vanishing gradients в LSTM?

❌ «Используется exp вместо sigmoid» -- формально верно, но недостаточно.

✅ Стандартный sigmoid saturates в 0 и 1, что создаёт vanishing gradients при длинных последовательностях. Exponential gating (\(f_t = \exp(z)\)) даёт больший динамический диапазон. Но без нормализации это вызовет overflow, поэтому xLSTM нормализует: \(f_t = \exp(z_t) / \sum_j \exp(z_j)\). Дополнительно используется log-space computation и clipping для numerical stability.


See Also


Источники

Papers

  1. xLSTM: Beck et al. (2024) — arXiv:2405.04517
  2. Vision-LSTM: Alkin et al. (2024) — arXiv:2406.04303
  3. xLSTM 7B: Beck et al. (2025) — arXiv:2503.13427
  4. xLSTM Scaling Laws: ICLR 2026 — OpenReview:bpbU549sSg

Industry

  1. NXAI xLSTM on AMD — AMD Blog (June 2025)
  2. xLSTM for Edge Computing — Nature Scientific Reports (2025)
  3. Cognitive xLSTM for Multi-agent IR — Nature (2025)

Code

  1. xLSTM: github.com/NX-AI/xlstm
  2. Vision-LSTM: github.com/NX-AI/vision-lstm