Flash Attention 3¶
~8 минут чтения
Предварительно: Реализация внимания с нуля, KV-кэш оптимизация
Зачем Flash Attention¶
Стандартный attention масштабируется как \(O(N^2)\) по памяти -- для sequence length 16K это \(16384^2 \approx 268\) млн элементов, которые нужно записать в HBM (High Bandwidth Memory) GPU. Даже при быстрых вычислениях bottleneck -- память, не арифметика. GPU считает быстрее, чем успевает читать/писать данные.
Аналогия: представь повара (Tensor Core) и холодильник (HBM). Повар режет овощи за секунду, но каждый раз ходит к холодильнику за 10 секунд. Flash Attention -- это когда повар достает все ингредиенты на разделочный стол (SRAM) заранее и работает без остановок. Flash Attention 3 идет дальше: пока повар режет, ассистент уже несет следующую порцию из холодильника (асинхронность Hopper GPU).
Ключевой инсайт: Flash Attention не меняет что вычисляется (результат математически идентичен standard attention), а меняет как -- через тайлинг, online softmax и IO-aware алгоритм. FA3 добавляет аппаратную асинхронность Hopper GPU для перекрытия compute и memory operations.
Эволюция Flash Attention¶
| Version | GPU Target | Key Innovation | Speedup |
|---|---|---|---|
| FA1 | Ampere (A100) | Tiling, recomputation | 3-4x vs baseline |
| FA2 | Ampere | Parallelism, work partition | 2x vs FA1 |
| FA3 | Hopper (H100) | Async, low-precision | 1.5-2x vs FA2 |
| FA4 | Blackwell | Selective rescaling, TileLang | 1 PFLOP/s barrier |
Recap: как работает Flash Attention¶
Ключевая идея¶
Стандартный Attention: \(\text{Attention}(Q,K,V) = \text{softmax}(QK^T/\sqrt{d}) \times V\) -- материализует матрицу \(N \times N\) в HBM, \(O(N^2)\) памяти.
Flash Attention (тайлинг):
- Разбивает Q, K, V на блоки размером SRAM (64x64)
- Вычисляет softmax онлайн -- поблочно обновляя статистики (\(m\), \(l\))
- Никогда не записывает полную матрицу attention в HBM
- Память: \(O(N)\) вместо \(O(N^2)\)
Online Softmax -- сердце Flash Attention
При поблочной обработке softmax нужно корректировать уже вычисленные значения при поступлении новых блоков:
Где $m$ -- текущий максимум (для численной стабильности), $l$ -- знаменатель softmax. Это позволяет вычислять **точный** softmax без хранения всей строки.
Иерархия памяти GPU¶
graph TD
HBM["HBM<br/>80 GB (H100)<br/>3.35 TB/s<br/>High latency"]
HBM --> L2["L2 Cache<br/>50 MB<br/>Medium latency"]
L2 --> SRAM["SRAM (Shared Memory)<br/>228 KB per SM<br/>~20 TB/s<br/>Flash Attention works here"]
style HBM fill:#ffcdd2,stroke:#c62828
style L2 fill:#fff3e0,stroke:#ef6c00
style SRAM fill:#c8e6c9,stroke:#2e7d32
Инновации FlashAttention-3¶
Возможности архитектуры Hopper¶
| Feature | Description | Impact |
|---|---|---|
Async Tensor Core (wgmma.mma_async) |
Tensor Core runs asynchronously, overlaps with CUDA Core | Producer-consumer paradigm |
| TMA (Tensor Memory Accelerator) | Hardware-accelerated async DMA HBM->SRAM | Thread group free during transfer |
| FP8 Tensor Core | Native FP8, 2x throughput vs FP16/BF16 | ~1.2-1.3 PFLOPs/s peak |
| Thread Block Cluster | Multiple SMs coordinate, shared memory | Distributed attention |
Три ключевые техники¶
1. Асинхронность¶
graph LR
subgraph Producer["Producer (Warp 1)"]
LOAD["Load Q, K tiles<br/>(TMA async)"]
end
subgraph Consumer["Consumer (Warp 2)"]
GEMM["Tensor Core GEMM<br/>(wgmma async)"]
end
LOAD -->|"pipeline"| GEMM
style Producer fill:#e3f2fd,stroke:#1565c0
style Consumer fill:#fff3e0,stroke:#ef6c00
Суть: Producer загружает тайлы, пока Consumer считает -- compute и memory перекрываются во времени.
2. Конвейеризация (Pipelining)¶
| Stage | Operation | Overlaps With |
|---|---|---|
| 1 | Load Q block | Load K block |
| 2 | Compute QK^T | Load V block |
| 3 | Compute softmax | Store partial results |
| 4 | Compute OV^T | Load next Q block |
3. Низкая точность (FP8)¶
| Format | Dynamic Range | Precision | Use Case |
|---|---|---|---|
| E4M3 | ±448 | High | Activations |
| E5M2 | ±57344 | Lower | Gradients |
FP8 не для всех задач
FP8 E4M3 имеет динамический диапазон всего до +-448. Для attention scores, которые могут быть большими (особенно до scaling), это может привести к overflow. FA3 решает это через per-block scaling, но при fine-tuning на задачах с длинным контекстом FP8 может давать заметную деградацию (> 0.5%). Всегда валидируйте FP8 на вашем конкретном workload.
Бенчмарки производительности¶
Сравнение пропускной способности¶
| GPU | FA2 (BF16) | FA3 (BF16) | FA3 (FP8) |
|---|---|---|---|
| A100 | 312 TFLOPs/s | 312 TFLOPs/s | N/A |
| H100 | 400 TFLOPs/s | 740 TFLOPs/s | 1.2 PFLOPs/s |
Коэффициенты ускорения¶
| Scenario | FA2 → FA3 Speedup |
|---|---|
| Short sequences (< 1K) | 1.3x |
| Medium sequences (1K-4K) | 1.5x |
| Long sequences (4K-16K) | 1.8x |
| Very long (16K+) | 2.0x |
Эффективность памяти¶
| Метрика | FA2 | FA3 |
|---|---|---|
| Чтений из HBM | 1x | 0.7x |
| Использование SRAM | 70% | 90% |
| Память на токен | O(1) | O(1) |
Сравнение с другими методами¶
Методы attention¶
| Method | Memory | Speed | Long Context | Hardware |
|---|---|---|---|---|
| Standard | O(N²) | Slow | No | Any |
| Flash Attention 2 | O(N) | Fast | Good | A100+ |
| Flash Attention 3 | O(N) | Very Fast | Excellent | H100+ |
| Ring Attention | O(N) | Medium | Excellent | Multi-GPU |
| Sparse Attention | O(N) | Medium | Good | Any |
Когда использовать FA3¶
| Use Case | Recommended |
|---|---|
| H100/H200 training | FA3 (required) |
| A100 training | FA2 |
| Inference only | FA2 or FA3 |
| Long context (100K+) | Ring + FA3 |
| Multi-GPU | Ring Attention |
Интеграция в код¶
Использование в PyTorch¶
import torch
from flash_attn import flash_attn_func
# Flash Attention 3 (Hopper only)
q = torch.randn(1, 8, 4096, 128, dtype=torch.bfloat16, device='cuda')
k = torch.randn(1, 8, 4096, 128, dtype=torch.bfloat16, device='cuda')
v = torch.randn(1, 8, 4096, 128, dtype=torch.bfloat16, device='cuda')
# Automatic dispatch to FA3 on H100
output = flash_attn_func(q, k, v, causal=True)
# FP8 version (FA3 only)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
output_fp8 = flash_attn_func(q_fp8, k_fp8, v_fp8, causal=True)
Чеклист интеграции¶
| Step | Description |
|---|---|
| 1 | Install flash-attn >= 2.6.0 |
| 2 | Verify GPU (H100/H200 for FA3) |
| 3 | Check CUDA >= 12.0 |
| 4 | Replace F.scaled_dot_product_attention |
| 5 | Verify outputs match (FP8 has small diffs) |
Ring Attention + FA3¶
Кольцевое внимание для длинного контекста¶
graph LR
G0["GPU 0<br/>Q0 + FA3"] <-->|"K,V rotation"| G1["GPU 1<br/>Q1 + FA3"]
G1 <-->|"K,V rotation"| G2["GPU 2<br/>Q2 + FA3"]
G2 <-->|"K,V rotation"| G3["GPU 3<br/>Q3 + FA3"]
G3 <-->|"K,V rotation"| G0
style G0 fill:#e8eaf6,stroke:#3f51b5
style G1 fill:#fff3e0,stroke:#ef6c00
style G2 fill:#e8f5e9,stroke:#4caf50
style G3 fill:#fce4ec,stroke:#e91e63
Каждый GPU использует FA3 для локального attention. Блоки K/V ротируются по кольцевой топологии. Масштабируется до миллионов токенов.
FlashAttention-4 (Hot Chips 2025)¶
Ключевые анонсы¶
- Поддержка Blackwell GPU -- оптимизация под архитектуру NVIDIA Blackwell
- Барьер петафлопса -- первое ядро attention, преодолевшее 1 PFLOP/s
- Selective rescaling -- пересчет масштабирования только когда изменение максимума влияет на численную корректность
- Абстракция TileLang -- ~80 строк вместо 500+ на CUDA, поддержка AMD MI300X через ROCm
- Reverse-engineering от Modal -- доступен детальный технический анализ
FA3 vs FA4¶
| Аспект | FA3 | FA4 |
|---|---|---|
| GPU | Hopper (H100) | Blackwell |
| Peak | 740 TFLOP/s (BF16) | 1+ PFLOP/s |
| Kernel code | 500+ lines CUDA | ~80 lines TileLang |
| Vendor support | NVIDIA only | NVIDIA + AMD (ROCm) |
Пробелы¶
- Полная статья FA4 ещё не опубликована (только анонс)
- Полные бенчмарки FA4 vs FA3 на Blackwell ещё не публичны
Ключевые числа для интервью¶
Пиковая производительность¶
| GPU | Peak TFLOPs (BF16) | FA3 Achieved | Utilization |
|---|---|---|---|
| A100 | 312 | 200 (FA2) | 64% |
| H100 | 989 | 740 | 75% |
Пропускная способность памяти¶
| GPU | Пропускная способность HBM | Эффективная (FA3) |
|---|---|---|
| A100 | 2.0 TB/s | 1.8 TB/s |
| H100 | 3.35 TB/s | 3.0 TB/s |
Масштабирование по длине последовательности¶
| Sequence | Standard | FA2 | FA3 |
|---|---|---|---|
| 1K | 50ms | 5ms | 3ms |
| 4K | 800ms | 25ms | 15ms |
| 16K | OOM | 120ms | 65ms |
| 64K | OOM | 600ms | 320ms |
FP8 vs BF16¶
| Метрика | BF16 | FP8 |
|---|---|---|
| Скорость | 1x | 1.6x |
| Память | 1x | 0.5x |
| Точность | Baseline | -0.1% loss |
Вопросы на собеседовании¶
Q: Объясните как Flash Attention достигает O(N) памяти вместо O(N^2)?
Red flag: "Используется разреженное внимание / аппроксимация"
Strong: "Flash Attention вычисляет точный attention, но через тайлинг: Q, K, V разбиваются на блоки, каждый блок помещается в SRAM. Матрица attention \(N \times N\) никогда не материализуется целиком в HBM. Online softmax позволяет поблочно обновлять статистики (\(m\), \(l\)) и получать точный результат. Backward pass использует recomputation вместо хранения."
Q: Чем FA3 отличается от FA2?
Strong: "FA3 оптимизирован под Hopper GPU (H100). Три ключевых отличия: (1) асинхронные Tensor Core через
wgmma.mma_async-- compute перекрывается с memory ops, (2) конвейеризация -- пока считается QK^T, загружается V, (3) нативная поддержка FP8 -- 1.6x speedup при потере <0.1% accuracy. Результат: 740 TFLOP/s BF16 (75% utilization) vs 400 TFLOP/s для FA2 на H100."
Q: Когда НЕ стоит использовать Flash Attention?
Strong: "На очень коротких последовательностях (< 128 tokens) overhead от тайлинга может быть больше экономии. Также FA3 требует Hopper GPU -- на A100 используйте FA2. При использовании sparse attention patterns (Longformer, BigBird) Flash Attention может быть несовместим с кастомными масками. Для cross-attention с сильно различающимися длинами Q и KV нужна специальная конфигурация тайлов."
Q: Как Ring Attention связан с Flash Attention?
Strong: "Ring Attention -- для распределения attention по нескольким GPU. Каждый GPU хранит свою часть Q, а блоки K/V ротируются по кольцевой топологии между GPU. На каждом GPU локально используется Flash Attention. Это ортогональные оптимизации: FA оптимизирует внутри GPU (SRAM vs HBM), Ring Attention -- между GPU (communication overlap). Вместе масштабируются до миллионов токенов."
Задание для самопроверки
- Рассчитайте, сколько памяти занимает матрица attention для sequence length 32768 в FP16. Почему стандартный attention не помещается на A100 (80 GB)?
- Объясните, почему online softmax дает математически идентичный результат стандартному softmax. Подсказка: запишите softmax через \(m\) (max) и \(l\) (sum of exps).
- Для тренировки модели на 128K контексте: какую комбинацию FA3 + Ring Attention + GQA вы бы использовали на кластере из 8 H100? Обоснуйте.
Источники¶
- arXiv — "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (2407.08608)
- NVIDIA GTC 2025 — "FlashAttention-3: Fast and Accurate Attention With Asynchrony..."
- GitHub — Dao-AILab/flash-attention
- Colfax Research — "FlashAttention-3 Technical Analysis"
- Medium — "FlashAttention-3: The Engine Powering Next-Gen LLMs"
- Nebius Blog — "Kvax: Fast Flash Attention for JAX"
- HPCWire — "Liger-Kernel: Triton Kernels for Efficient LLM Training"
- CrowdStrike Blog -- "How We Train GenAI Models with Distributed Computing"
- Modal -- "Reverse-engineering FlashAttention-4" (2026)
- AMD ROCm Blog -- "ROCm TileLang Kernel" (2026)
See Also¶
- Attention Implementation -- реализация vanilla attention с нуля
- KV Cache Optimization -- PagedAttention, prefix caching
- MQA/GQA Attention -- Multi-Query/Grouped-Query для уменьшения KV cache
- Efficient Transformers -- обзор всех техник оптимизации
- Inference Engines -- vLLM, SGLang используют FlashAttention