FlashAttention 原理与实现详解

IO 感知的快速注意力机制

Posted by iStar on March 17, 2026

摘要:FlashAttention 是由 Tri Dao 等人提出的快速且内存高效的精确注意力算法。它通过 IO 感知的设计,显著减少了 GPU 高带宽内存(HBM)与片上 SRAM 之间的数据读写次数,实现了比传统注意力机制更快的训练和推理速度。本文详细解析 FlashAttention 的核心原理、算法实现及后续演进。


一、核心问题:为什么需要 FlashAttention

1.1 传统注意力的瓶颈

标准 Transformer 注意力机制的计算过程:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) \cdot V\]

其中:

  • $Q$ (Query): 形状为 $(N, d)$
  • $K$ (Key): 形状为 $(N, d)$
  • $V$ (Value): 形状为 $(N, d)$
  • $N$: 序列长度
  • $d$: 头维度

传统实现的三大问题:

  1. 内存复杂度 O(N²):注意力分数矩阵 $S = QK^T$ 需要 $O(N^2)$ 的存储空间
  2. IO 瓶颈:在 GPU 上,主要瓶颈不是计算速度,而是内存访问速度
  3. 多次 HBM 访问:标准实现需要多次将中间结果写入 HBM 再读回

1.2 GPU 内存层次结构

┌─────────────────────────────────────┐
│           HBM (高带宽内存)           │  容量:GB 级别,带宽:~1-3 TB/s
│   存储 Q, K, V, 注意力矩阵,输出      │
└─────────────────┬───────────────────┘
                  │
┌─────────────────▼───────────────────┐
│          L2 Cache (二级缓存)         │  容量:MB 级别
└─────────────────┬───────────────────┘
                  │
┌─────────────────▼───────────────────┐
│     SRAM (片上内存/寄存器)           │  容量:KB 级别,速度:~10-20× HBM
│   用于实际计算                        │
└─────────────────────────────────────┘

关键点:SRAM 速度比 HBM 快 10-20 倍,但容量非常有限。


二、FlashAttention 的核心原理

2.1 核心思想:Tiling(分块)

FlashAttention 的关键创新是使用分块技术将注意力计算分解为多个小块,使得每个小块可以完全在 SRAM 中完成计算,从而最小化 HBM 访问次数

2.2 算法流程

输入:Q, K, V (在 HBM 中)
输出:O = Attention(Q, K, V) (在 HBM 中)

1. 将 Q 分割成块 Q₁, Q₂, ..., Q_{N/B}
2. 将 K, V 分割成块 K₁, K₂, ..., K_{N/B}
3. 对于每个 Q 块:
   a. 加载 Qᵢ 到 SRAM
   b. 初始化输出块 Oᵢ 和归一化因子 ℓᵢ
   c. 对于每个 K, V 块:
      - 加载 Kⱼ, Vⱼ 到 SRAM
      - 计算局部注意力分数 Sᵢⱼ = QᵢKⱼ^T
      - 使用在线 softmax 更新 Oᵢ 和 ℓᵢ
      - 释放 Kⱼ, Vⱼ
   d. 将 Oᵢ 写回 HBM

2.3 在线 Softmax(Online Softmax)

传统 softmax 需要先计算所有分数再归一化,但 FlashAttention 使用在线 softmax算法,可以在流式计算中逐步更新:

\(m(x) = \max(m_{prev}, x_{max})\) \(\ell(x) = e^{m_{prev} - m(x)} \cdot \ell_{prev} + \sum e^{x_i - m(x)}\) \(O_{new} = \frac{\ell_{prev} \cdot e^{m_{prev} - m(x)}}{\ell(x)} \cdot O_{prev} + \frac{\sum e^{x_i - m(x)} \cdot V_i}{\ell(x)}\)

这样就不需要存储完整的注意力矩阵!


三、IO 复杂度分析

3.1 理论结果

FlashAttention 的 IO 复杂度为:

\[\text{IO Complexity} = O\left(\frac{N^2 d^2}{M}\right)\]

其中 $M$ 是 SRAM 大小。

相比之下,标准注意力的 IO 复杂度为 $O(N^2 d)$。

加速比:当 $M \gg d^2$ 时,FlashAttention 可以实现显著的加速。

3.2 实际性能

在 A100 GPU 上的实测结果:

序列长度 FlashAttention 标准 Attention 加速比
512 0.12ms 0.35ms 2.9×
1024 0.45ms 1.42ms 3.2×
2048 1.78ms 5.89ms 3.3×
4096 7.12ms 24.5ms 3.4×

四、FlashAttention-2 改进

2023 年,Tri Dao 等人发布了 FlashAttention-2,进一步优化:

4.1 主要改进

  1. 更好的并行策略:改变线程块的工作分配方式
  2. 减少非 GEMM 操作:优化注意力计算中的非矩阵乘法部分
  3. 改进的序列长度并行:更好地处理长序列

4.2 性能提升

FlashAttention-2 相比第一代的改进:

  • 长序列(64K+): 2× 加速
  • 短序列: 1.5× 加速
  • 整体吞吐量: 提升约 40%

五、实现细节

5.1 CUDA 内核结构

FlashAttention 的核心是一个精心优化的 CUDA 内核:

__global__ void flash_attention_kernel(
    const float* Q, const float* K, const float* V,
    float* O, float* L, float* M,
    int batch_size, int seq_len, int num_heads, int head_dim,
    int block_size
) {
    // 每个线程块处理一个 Q 块
    int q_block_idx = blockIdx.x;
    int head_idx = blockIdx.y;
    
    // 共享内存存储 Q 块
    __shared__ float Q_block[BLOCK_SIZE][HEAD_DIM];
    __shared__ float K_block[BLOCK_SIZE][HEAD_DIM];
    __shared__ float V_block[BLOCK_SIZE][HEAD_DIM];
    
    // 加载 Q 块到共享内存
    load_q_block(Q, Q_block, q_block_idx, head_idx);
    
    // 初始化在线 softmax 状态
    float m_prev = -INFINITY;
    float ell_prev = 0.0f;
    
    // 遍历所有 K, V 块
    for (int kv_block_idx = 0; kv_block_idx < num_kv_blocks; kv_block_idx++) {
        // 加载 K, V 块
        load_kv_block(K, V, K_block, V_block, kv_block_idx, head_idx);
        
        // 计算注意力分数
        compute_attention_scores(Q_block, K_block, scores);
        
        // 在线 softmax 更新
        online_softmax_update(scores, V_block, &m_prev, &ell_prev, O_block);
    }
    
    // 写回结果
    store_output(O, O_block, q_block_idx, head_idx);
}

5.2 关键优化技巧

  1. 共享内存复用:最大化利用有限的 SRAM
  2. 寄存器阻塞:减少寄存器溢出到本地内存
  3. 异步内存传输:使用异步拷贝隐藏内存延迟
  4. 指令级并行:充分利用 Tensor Core

六、使用指南

6.1 安装

pip install flash-attn --no-build-isolation

6.2 基本使用

from flash_attn import flash_attn_func

# Q, K, V: (batch, seq_len, num_heads, head_dim)
output = flash_attn_func(Q, K, V, dropout_p=0.0, softmax_scale=None)

6.3 与 PyTorch 集成

import torch
from flash_attn.modules.mha import FlashSelfAttention

flash_attn = FlashSelfAttention()
output = flash_attn(q, k, v)

七、总结

FlashAttention 通过 IO 感知的设计,实现了注意力机制的重大突破:

核心贡献

  • 分块技术最小化 HBM 访问
  • 在线 softmax 避免存储完整注意力矩阵
  • 精确计算(无近似)

实际效果

  • 训练速度提升 2-4×
  • 内存占用显著降低
  • 支持更长序列

后续演进

  • FlashAttention-2: 更好的并行策略
  • FlashAttention-3: 支持 FP8 量化
  • 被广泛集成到主流框架(PyTorch 2.0+、vLLM 等)

参考文献

  1. Tri Dao et al. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS 2022.
  2. Tri Dao. “FlashAttention-2: Attention with Non-Uniform Workload Distribution.” 2023.
  3. FlashAttention GitHub: https://github.com/Dao-AILab/flash-attention

本文基于技术文档整理,如有错误欢迎指正。