Перейти к содержанию

Flash Attention 3

~8 минут чтения

Предварительно: Реализация внимания с нуля, KV-кэш оптимизация


Зачем Flash Attention

Стандартный attention масштабируется как \(O(N^2)\) по памяти -- для sequence length 16K это \(16384^2 \approx 268\) млн элементов, которые нужно записать в HBM (High Bandwidth Memory) GPU. Даже при быстрых вычислениях bottleneck -- память, не арифметика. GPU считает быстрее, чем успевает читать/писать данные.

Аналогия: представь повара (Tensor Core) и холодильник (HBM). Повар режет овощи за секунду, но каждый раз ходит к холодильнику за 10 секунд. Flash Attention -- это когда повар достает все ингредиенты на разделочный стол (SRAM) заранее и работает без остановок. Flash Attention 3 идет дальше: пока повар режет, ассистент уже несет следующую порцию из холодильника (асинхронность Hopper GPU).

Ключевой инсайт: Flash Attention не меняет что вычисляется (результат математически идентичен standard attention), а меняет как -- через тайлинг, online softmax и IO-aware алгоритм. FA3 добавляет аппаратную асинхронность Hopper GPU для перекрытия compute и memory operations.

Эволюция Flash Attention

Version GPU Target Key Innovation Speedup
FA1 Ampere (A100) Tiling, recomputation 3-4x vs baseline
FA2 Ampere Parallelism, work partition 2x vs FA1
FA3 Hopper (H100) Async, low-precision 1.5-2x vs FA2
FA4 Blackwell Selective rescaling, TileLang 1 PFLOP/s barrier

Recap: как работает Flash Attention

Ключевая идея

Стандартный Attention: \(\text{Attention}(Q,K,V) = \text{softmax}(QK^T/\sqrt{d}) \times V\) -- материализует матрицу \(N \times N\) в HBM, \(O(N^2)\) памяти.

Flash Attention (тайлинг):

  1. Разбивает Q, K, V на блоки размером SRAM (64x64)
  2. Вычисляет softmax онлайн -- поблочно обновляя статистики (\(m\), \(l\))
  3. Никогда не записывает полную матрицу attention в HBM
  4. Память: \(O(N)\) вместо \(O(N^2)\)

Online Softmax -- сердце Flash Attention

При поблочной обработке softmax нужно корректировать уже вычисленные значения при поступлении новых блоков:

\[m_\text{new} = \max(m_\text{old}, m_\text{cur})\]
\[l_\text{new} = l_\text{old} \cdot e^{m_\text{old} - m_\text{new}} + e^{m_\text{cur} - m_\text{new}}\]
Где $m$ -- текущий максимум (для численной стабильности), $l$ -- знаменатель softmax. Это позволяет вычислять **точный** softmax без хранения всей строки.

Иерархия памяти GPU

graph TD
    HBM["HBM<br/>80 GB (H100)<br/>3.35 TB/s<br/>High latency"]
    HBM --> L2["L2 Cache<br/>50 MB<br/>Medium latency"]
    L2 --> SRAM["SRAM (Shared Memory)<br/>228 KB per SM<br/>~20 TB/s<br/>Flash Attention works here"]
    style HBM fill:#ffcdd2,stroke:#c62828
    style L2 fill:#fff3e0,stroke:#ef6c00
    style SRAM fill:#c8e6c9,stroke:#2e7d32

Инновации FlashAttention-3

Возможности архитектуры Hopper

Feature Description Impact
Async Tensor Core (wgmma.mma_async) Tensor Core runs asynchronously, overlaps with CUDA Core Producer-consumer paradigm
TMA (Tensor Memory Accelerator) Hardware-accelerated async DMA HBM->SRAM Thread group free during transfer
FP8 Tensor Core Native FP8, 2x throughput vs FP16/BF16 ~1.2-1.3 PFLOPs/s peak
Thread Block Cluster Multiple SMs coordinate, shared memory Distributed attention

Три ключевые техники

1. Асинхронность

graph LR
    subgraph Producer["Producer (Warp 1)"]
        LOAD["Load Q, K tiles<br/>(TMA async)"]
    end
    subgraph Consumer["Consumer (Warp 2)"]
        GEMM["Tensor Core GEMM<br/>(wgmma async)"]
    end
    LOAD -->|"pipeline"| GEMM

    style Producer fill:#e3f2fd,stroke:#1565c0
    style Consumer fill:#fff3e0,stroke:#ef6c00

Суть: Producer загружает тайлы, пока Consumer считает -- compute и memory перекрываются во времени.

2. Конвейеризация (Pipelining)

Stage Operation Overlaps With
1 Load Q block Load K block
2 Compute QK^T Load V block
3 Compute softmax Store partial results
4 Compute OV^T Load next Q block

3. Низкая точность (FP8)

\[\text{FP8 Range} = \text{E4M3 (no inf)} \text{ or } \text{E5M2 (has inf)}\]
Format Dynamic Range Precision Use Case
E4M3 ±448 High Activations
E5M2 ±57344 Lower Gradients

FP8 не для всех задач

FP8 E4M3 имеет динамический диапазон всего до +-448. Для attention scores, которые могут быть большими (особенно до scaling), это может привести к overflow. FA3 решает это через per-block scaling, но при fine-tuning на задачах с длинным контекстом FP8 может давать заметную деградацию (> 0.5%). Всегда валидируйте FP8 на вашем конкретном workload.


Бенчмарки производительности

Сравнение пропускной способности

GPU FA2 (BF16) FA3 (BF16) FA3 (FP8)
A100 312 TFLOPs/s 312 TFLOPs/s N/A
H100 400 TFLOPs/s 740 TFLOPs/s 1.2 PFLOPs/s

Коэффициенты ускорения

Scenario FA2 → FA3 Speedup
Short sequences (< 1K) 1.3x
Medium sequences (1K-4K) 1.5x
Long sequences (4K-16K) 1.8x
Very long (16K+) 2.0x

Эффективность памяти

Метрика FA2 FA3
Чтений из HBM 1x 0.7x
Использование SRAM 70% 90%
Память на токен O(1) O(1)

Сравнение с другими методами

Методы attention

Method Memory Speed Long Context Hardware
Standard O(N²) Slow No Any
Flash Attention 2 O(N) Fast Good A100+
Flash Attention 3 O(N) Very Fast Excellent H100+
Ring Attention O(N) Medium Excellent Multi-GPU
Sparse Attention O(N) Medium Good Any

Когда использовать FA3

Use Case Recommended
H100/H200 training FA3 (required)
A100 training FA2
Inference only FA2 or FA3
Long context (100K+) Ring + FA3
Multi-GPU Ring Attention

Интеграция в код

Использование в PyTorch

import torch
from flash_attn import flash_attn_func

# Flash Attention 3 (Hopper only)
q = torch.randn(1, 8, 4096, 128, dtype=torch.bfloat16, device='cuda')
k = torch.randn(1, 8, 4096, 128, dtype=torch.bfloat16, device='cuda')
v = torch.randn(1, 8, 4096, 128, dtype=torch.bfloat16, device='cuda')

# Automatic dispatch to FA3 on H100
output = flash_attn_func(q, k, v, causal=True)

# FP8 version (FA3 only)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
output_fp8 = flash_attn_func(q_fp8, k_fp8, v_fp8, causal=True)

Чеклист интеграции

Step Description
1 Install flash-attn >= 2.6.0
2 Verify GPU (H100/H200 for FA3)
3 Check CUDA >= 12.0
4 Replace F.scaled_dot_product_attention
5 Verify outputs match (FP8 has small diffs)

Ring Attention + FA3

Кольцевое внимание для длинного контекста

graph LR
    G0["GPU 0<br/>Q0 + FA3"] <-->|"K,V rotation"| G1["GPU 1<br/>Q1 + FA3"]
    G1 <-->|"K,V rotation"| G2["GPU 2<br/>Q2 + FA3"]
    G2 <-->|"K,V rotation"| G3["GPU 3<br/>Q3 + FA3"]
    G3 <-->|"K,V rotation"| G0
    style G0 fill:#e8eaf6,stroke:#3f51b5
    style G1 fill:#fff3e0,stroke:#ef6c00
    style G2 fill:#e8f5e9,stroke:#4caf50
    style G3 fill:#fce4ec,stroke:#e91e63

Каждый GPU использует FA3 для локального attention. Блоки K/V ротируются по кольцевой топологии. Масштабируется до миллионов токенов.


FlashAttention-4 (Hot Chips 2025)

Ключевые анонсы

  • Поддержка Blackwell GPU -- оптимизация под архитектуру NVIDIA Blackwell
  • Барьер петафлопса -- первое ядро attention, преодолевшее 1 PFLOP/s
  • Selective rescaling -- пересчет масштабирования только когда изменение максимума влияет на численную корректность
  • Абстракция TileLang -- ~80 строк вместо 500+ на CUDA, поддержка AMD MI300X через ROCm
  • Reverse-engineering от Modal -- доступен детальный технический анализ

FA3 vs FA4

Аспект FA3 FA4
GPU Hopper (H100) Blackwell
Peak 740 TFLOP/s (BF16) 1+ PFLOP/s
Kernel code 500+ lines CUDA ~80 lines TileLang
Vendor support NVIDIA only NVIDIA + AMD (ROCm)

Пробелы

  • Полная статья FA4 ещё не опубликована (только анонс)
  • Полные бенчмарки FA4 vs FA3 на Blackwell ещё не публичны

Ключевые числа для интервью

Пиковая производительность

GPU Peak TFLOPs (BF16) FA3 Achieved Utilization
A100 312 200 (FA2) 64%
H100 989 740 75%

Пропускная способность памяти

GPU Пропускная способность HBM Эффективная (FA3)
A100 2.0 TB/s 1.8 TB/s
H100 3.35 TB/s 3.0 TB/s

Масштабирование по длине последовательности

Sequence Standard FA2 FA3
1K 50ms 5ms 3ms
4K 800ms 25ms 15ms
16K OOM 120ms 65ms
64K OOM 600ms 320ms

FP8 vs BF16

Метрика BF16 FP8
Скорость 1x 1.6x
Память 1x 0.5x
Точность Baseline -0.1% loss

Вопросы на собеседовании

Q: Объясните как Flash Attention достигает O(N) памяти вместо O(N^2)?

  • ❌ Red flag: "Используется разреженное внимание / аппроксимация"
  • ✅ Strong: "Flash Attention вычисляет точный attention, но через тайлинг: Q, K, V разбиваются на блоки, каждый блок помещается в SRAM. Матрица attention \(N \times N\) никогда не материализуется целиком в HBM. Online softmax позволяет поблочно обновлять статистики (\(m\), \(l\)) и получать точный результат. Backward pass использует recomputation вместо хранения."

Q: Чем FA3 отличается от FA2?

  • ✅ Strong: "FA3 оптимизирован под Hopper GPU (H100). Три ключевых отличия: (1) асинхронные Tensor Core через wgmma.mma_async -- compute перекрывается с memory ops, (2) конвейеризация -- пока считается QK^T, загружается V, (3) нативная поддержка FP8 -- 1.6x speedup при потере <0.1% accuracy. Результат: 740 TFLOP/s BF16 (75% utilization) vs 400 TFLOP/s для FA2 на H100."

Q: Когда НЕ стоит использовать Flash Attention?

  • ✅ Strong: "На очень коротких последовательностях (< 128 tokens) overhead от тайлинга может быть больше экономии. Также FA3 требует Hopper GPU -- на A100 используйте FA2. При использовании sparse attention patterns (Longformer, BigBird) Flash Attention может быть несовместим с кастомными масками. Для cross-attention с сильно различающимися длинами Q и KV нужна специальная конфигурация тайлов."

Q: Как Ring Attention связан с Flash Attention?

  • ✅ Strong: "Ring Attention -- для распределения attention по нескольким GPU. Каждый GPU хранит свою часть Q, а блоки K/V ротируются по кольцевой топологии между GPU. На каждом GPU локально используется Flash Attention. Это ортогональные оптимизации: FA оптимизирует внутри GPU (SRAM vs HBM), Ring Attention -- между GPU (communication overlap). Вместе масштабируются до миллионов токенов."

Задание для самопроверки

  1. Рассчитайте, сколько памяти занимает матрица attention для sequence length 32768 в FP16. Почему стандартный attention не помещается на A100 (80 GB)?
  2. Объясните, почему online softmax дает математически идентичный результат стандартному softmax. Подсказка: запишите softmax через \(m\) (max) и \(l\) (sum of exps).
  3. Для тренировки модели на 128K контексте: какую комбинацию FA3 + Ring Attention + GQA вы бы использовали на кластере из 8 H100? Обоснуйте.

Источники

  1. arXiv — "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (2407.08608)
  2. NVIDIA GTC 2025 — "FlashAttention-3: Fast and Accurate Attention With Asynchrony..."
  3. GitHub — Dao-AILab/flash-attention
  4. Colfax Research — "FlashAttention-3 Technical Analysis"
  5. Medium — "FlashAttention-3: The Engine Powering Next-Gen LLMs"
  6. Nebius Blog — "Kvax: Fast Flash Attention for JAX"
  7. HPCWire — "Liger-Kernel: Triton Kernels for Efficient LLM Training"
  8. CrowdStrike Blog -- "How We Train GenAI Models with Distributed Computing"
  9. Modal -- "Reverse-engineering FlashAttention-4" (2026)
  10. AMD ROCm Blog -- "ROCm TileLang Kernel" (2026)

See Also