元鉴
返回中文阅读流

NVIDIA Developer Blog

在 NVIDIA CUDA Tile 中调整 Flash Attention 以实现峰值性能

在本文中,我们将深入探讨现代 AI 中最关键的工作负载之一:Flash Attention,您将学习:如何使用 NVIDIA...实现 Flash Attention。

中文内容

已翻译official company source英文原文2026-05-26

在这篇文章中,我们将深入探讨现代 AI 中最关键的工作负载之一:Flash Attention,你将了解:

  • 如何使用 NVIDIA cuTile 实现 Flash Attention。逐步讲解一个可用于生产环境的完整实现代码。
  • “陷阱与补救”的优化历程。这个案例研究展示了朴素的优化(例如仅仅增大 tile 大小)如何适得其反,以及如何修复这些问题。
  • 用于实现最高性能的高级技术,例如 FMA 模式、快速数学、循环拆分和自适应 tiling。

环境要求:

  • CUDA 13.1 或更高版本
  • GPU 架构:计算能力 8.X、10.X、11.X、12.X(NVIDIA Ampere、NVIDIA Ada、NVIDIA Blackwell)
  • Python:3.10 或更高版本

有关安装 cuTile Python 的更多信息,请参阅快速入门文档。

什么是 attention?

attention 机制是 transformer 模型的计算核心。给定一个 token 序列,attention 使每个 token 能够“查看”其他所有 token,并决定对它们的贡献赋予多大的权重。从数学上看,对于输入矩阵 Query(\(Q\))、Key(\(K\))和 Value(\(V\)),输出为:

正文:\(O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V\)

其中:

  • \(Q \text{ 的形状为 } (N,d),\ N \text{ 个查询 token,每个维度为 } d.\)
  • \(K \text{ 的形状为 } (N,d),\ N \text{ 个键 token。}\)
  • \(V \text{ 的形状为 } (N,d),\ N \text{ 个值 token。}\)
  • \(\text{中间的 } QK^{T} \text{ 矩阵形状为 } (N,N),\text{这是一个问题。}\)

内存带宽问题

对于序列长度 \(N = 16,384\)(在现代 LLM 中很常见),注意力矩阵 \(QK^{T}\) 包含 \(N^2 = 268\) 百万个元素。在 FP16 中,这相当于每个注意力头、每个批次项需要 512 MB 的中间存储。

标准注意力实现:

  1. 计算完整的 \(N \times N\) 注意力矩阵并将其写入全局内存(速度较慢)
  2. 逐行应用 softmax
  3. 读回该矩阵并与 \(V\) 相乘

这种方法受内存带宽限制,因为 GPU 大部分时间都在等待数据在 HBM 和计算单元之间移动,而不是进行计算。

Flash Attention 如何解决内存带宽问题

Flash Attention(由 Dao 等人于 2022 年提出)是一种 IO 感知算法,它从不实例化完整的 \(N \times N\) 矩阵。相反,它:

  1. 对计算进行分块:以适合快速片上 SMEM 的小块处理 \(Q, K, V\)
  2. 使用在线 softmax:以增量方式计算 softmax,而不需要完整的行
  3. 融合操作:将矩阵乘法和 softmax 合并为单次内核执行

其结果是实现 2-4 倍加速并显著节省内存,从而支持更长的上下文长度。

A tiled flash attention figure showing Q, K^T, V and O in HBM, being accumulated to Q, K, V, and O in SMEM.A tiled flash attention figure showing Q, K^T, V and O in HBM, being accumulated to Q, K, V, and O in SMEM.
图 1. 分块式 Flash Attention 计算

理解在线 softmax

Flash Attention 的关键算法洞见是在线 softmax 技巧。数值稳定的安全 softmax 需要在计算前知道整行的最大值:

正文:\(\text{softmax}(x_i) = \frac{e^{x_i – \max(x)}}{\sum_j e^{x_j – \max(x)}}\)

但如果我们正在处理分块,就无法访问完整的一行。在线 softmax 通过维护可增量更新的运行统计量来解决这个问题。

在线 softmax 算法

我们为每一行维护两个运行中的值:

  • \(m_i\):到目前为止看到的最大值(用于数值稳定性)
  • \(l_i\):到目前为止看到的指数和(softmax 分母)

当我们处理一个包含值 \(x_{new}\) 的新分块时:

  1. 更新最大值:\(m_{new} = \max(m_i, \max(x_{new}))\)
  2. 计算校正因子:\(\alpha = e^{m_i – m_{new}}\)(重新缩放先前的计算结果)
  3. 更新求和项:\(l_i = l_i \cdot \alpha + \sum e^{x_{new} – m_{new}}\)
  4. 更新累加器:\(acc = acc \cdot \alpha + P_{new} \cdot V_{tile}\)

\(P_{new}\) 是注意力权重矩阵,\(V_{tile}\) 是值矩阵分块,对应于当前迭代中的 Key 分块。最后,我们进行归一化:\(O = acc / l_i\)

这使我们能够计算精确的 softmax,而无需存储完整的行。

因果注意力和分组查询注意力

在深入实现之前,我们先了解现代 LLM 中使用的两种重要注意力变体:

因果注意力

在 GPT、LLaMA 和 Claude 等自回归语言模型中,每个 token 只能关注序列中位于其之前的 token,而不能关注未来的 token。这可以防止训练过程中的“作弊”,即模型通过提前查看后续内容来预测下一个词。

从数学上看,我们会对注意力分数应用一个三角掩码:

\(\text{mask}_{ij} = \begin{cases} 0 & \text{如果 } i \geq j \text{(查询位置 ≥ 键位置)} \ -\infty & \text{如果 } i < j \text{(未来 token)} \end{cases}\)

掩码注意力变为:

正文:\(O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + \text{mask}\right)V\)

向未来位置添加 \(-\infty\) 可确保它们在经过 softmax 后变为零,从而有效阻断来自未来 token 的信息流。

Causal attention mask matrix for 4 tokens showing how the upper triangle of the matrix is masked to 0, meaning that those values are not used in the computation.Causal attention mask matrix for 4 tokens showing how the upper triangle of the matrix is masked to 0, meaning that those values are not used in the computation.
图 2. 四个 token 的因果注意力掩码

使用因果掩码时,大约一半的注意力矩阵会被掩蔽(上三角部分)。我们可以完全跳过这些被掩蔽图块的计算,从而带来 2 倍的算法加速。这对于 K 循环拆分优化至关重要。

分组查询注意力

标准多头注意力为每个注意力头分别配备 \(K,V\) 矩阵,导致内存占用较高:

  • 多头注意力(MHA):32 个查询头 → 32 个 K/V 头(1:1 比例)
  • 分组查询注意力(GQA):32 个查询头 → 4 个 K/V 头(8:1 比例)
  • 多查询注意力(MQA):32 个查询头 → 1 个 K/V 头(32:1 比例)

在 GQA 中,多个查询头共享相同的 K/V 头。例如,使用 32 个查询头和 4 个 K/V 头时:

  • 查询头 0-7 使用 K/V 头 0
  • 查询头 8-15 使用 K/V 头 1
  • 查询头 16-23 使用 K/V 头 2
  • 查询头 24-31 使用 K/V 头 3

这在推理过程中将 K/V 缓存大小减少了 8 倍,对于服务长上下文模型至关重要。LlamA 2、Llama 3、Mistral 和 Qwen 等现代 LLM 广泛使用 GQA。

在 Flash Attention 中实现时,每个 CUDA 块为一个查询头计算注意力,但会加载相应的共享 K/V 头:

head_idx = bid_y % num_heads              # Which query head (0-31)
kv_head_idx = head_idx // query_group_size # Which K/V head (0-3)

当查询组大小为 8 时,查询头 0-7 都映射到 kv_head_idx = 0,共享内存中相同的 K/V tile。

第 1 部分:CUDA Tile 中的 flash attention kernel

让我们一步步实现 Flash Attention。我们的基线采用较小的 64×64 tile 和直观的代码——正确但尚未优化。

1. 定义内核接口

在 cuTile 中,@ct.kernel 装饰器将一个 Python 函数标记为 GPU 内核。我们使用 ct.Constant[T] 类型注解来传递编译时常量:

import math
import cuda.tile as ct

# Type aliases for compile-time constants
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]

# Conversion factor: we use exp2 instead of exp for efficiency
INV_LOG_2 = 1.0 / math.log(2)

@ct.kernel()
def fmha_kernel(
    Q, K, V, Out,              # Input/output tensors
    qk_scale: float,           # Scale factor (1/sqrt(d))
    input_pos: int,            # Position offset for causal masking
    TILE_D: ConstInt,          # Head dimension (for example, 128)
    H: ConstInt,               # Number of attention heads
    TILE_M: ConstInt,          # Tile size for Q dimension (for example, 64)
    TILE_N: ConstInt,          # Tile size for K/V dimension (for example, 64)
    QUERY_GROUP_SIZE: ConstInt,# For Grouped Query Attention (GQA)
    CAUSAL: ConstBool,         # Whether to apply causal mask
    EVEN_K: ConstBool,         # Whether K length is divisible by TILE_N
):

2. 块 ID 映射

每个 CUDA 块计算输出的一个分块。使用 ct.bid,我们将二维网格映射到 batch/head 索引:

# Get block indices
    bid_x = ct.bid(0)  # Which tile along the sequence dimension
    bid_y = ct.bid(1)  # Which batch-head combination
    
    # Decode batch and head from flattened index
    batch_idx = bid_y // H
    head_idx = bid_y % H
    
    # For Grouped Query Attention: multiple Q heads share one K/V head
    off_kv_h = head_idx // QUERY_GROUP_SIZE

3. 初始化累加器

在主循环之前,我们初始化在线 softmax 状态和输出累加器:

# Convert scale for base-2 exponential (faster than natural exp)
    qk_scale = qk_scale * INV_LOG_2
    
    # Create position indices for this tile
    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
    offs_m += input_pos
    offs_m = offs_m[:, None]  # Shape: [TILE_M, 1]
    
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
    offs_n_tile = offs_n_tile[None, :]  # Shape: [1, TILE_N]
    
    # Online softmax state (float32 for numerical stability)
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)  # Running max
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)        # Running sum
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)   # Output accumulator

我们对累加器使用 float32,即使输入为 float16,也要在迭代式 softmax 计算过程中保持数值精度。

4. 加载查询分块

查询图块只加载一次,并在所有 K/V 迭代中重复使用:

    # Load Q tile: shape [1, 1, TILE_M, TILE_D] -> [TILE_M, TILE_D]
    q = ct.load(
        Q, 
        index=(batch_idx, head_idx, bid_x, 0), 
        shape=(1, 1, TILE_M, TILE_D)
    ).reshape((TILE_M, TILE_D))

当图块超出张量边界时,ct.load 函数会自动处理边界条件。

5. K/V 图块上的主循环

这是 Flash Attention 的核心。我们遍历 K/V 图块:

   # Calculate loop bounds
    m_end = input_pos + (bid_x + 1) * TILE_M
    k_seqlen = K.shape[2]
    
    if CAUSAL:
        # For causal attention, stop early (future tokens are masked)
        Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
    
    for j in range(0, Tc):
        # --- Step A: Load Key tile and compute QK^T ---
        k = ct.load(
            K,
            index=(batch_idx, off_kv_h, 0, j),
            shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2),  # Transpose for correct layout
            latency=2            # Hint for memory prefetching
        ).reshape((TILE_D, TILE_N))
        
        # Matrix multiply: Q @ K^T
        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
        qk = ct.mma(q, k, qk)  # Uses Tensor Cores automatically

参数中的 order=(0,1,3,2) 告诉 cuTile 加载操作使用转置后的 K,而 latency=2 提示我们可以容忍一定延迟(从而实现更好的流水线化)。然后我们使用 ct.mma=(q, k, k,qk) 执行 cuTile 矩阵乘加运算。

6. 应用因果掩码

对于自回归模型(GPT、Llama 等),每个 token 只能关注之前的 token:

# --- Step B: Apply causal masking ---
        if CAUSAL or not EVEN_K:
            offs_n = j * TILE_N + offs_n_tile
            mask = ct.full((TILE_M, TILE_N), True, dtype=ct.bool_)
            
            # Boundary mask (for non-divisible sequence lengths)
            if not EVEN_K:
                mask = mask & (offs_n < k_seqlen)
            
            # Causal mask: query position >= key position
            if CAUSAL:
                mask = mask & (offs_m >= offs_n)
            
            # Convert to additive mask: True->0, False->-inf
            mask = ct.where(mask, 0.0, -math.inf)
            qk += mask

向被掩码的位置添加 -inf 可确保它们在 softmax 后变为零。

7. 在线 softmax 更新

现在我们更新运行中的 softmax 统计量:

   # --- Step C: Online softmax ---
        # Find max in current tile
        qk_max = ct.max(qk, axis=-1, keepdims=True)
        qk_max_scaled = qk_max * qk_scale
        
        # Update running maximum
        m_ij = max(m_i, qk_max_scaled)
        
        # Scale QK scores
        qk = qk * qk_scale
        qk = qk - m_ij
        
        # Compute attention weights (using exp2 for speed)
        p = ct.exp2(qk)
        
        # Update running sum
        l_ij = ct.sum(p, axis=-1, keepdims=True)
        alpha = ct.exp2(m_i - m_ij)  # Correction factor
        l_i = l_i * alpha
        l_i = l_i + l_ij
        
        # Rescale previous accumulator
        acc = acc * alpha

8. 累加输出

最后,我们加载 Value tile 并进行累加:

# --- Step D: Load V and accumulate ---
        v = ct.load(
            V,
            index=(batch_idx, off_kv_h, j, 0),
            shape=(1, 1, TILE_N, TILE_D),
            latency=4
        ).reshape((TILE_N, TILE_D))
        
        # Cast attention weights back to input dtype for Tensor Core MMA
        p = p.astype(Q.dtype)
        
        # Accumulate: acc += P @ V
        acc = ct.mma(p, v, acc)
        
        # Update max for next iteration
        m_i = m_ij

9. 最终归一化并存储

处理完所有图块后,我们按总和进行归一化并写入结果:

   # --- Final: Normalize and store ---
    acc = ct.truediv(acc, l_i)
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

启动内核:主机端代码

现在让我们看看启动内核的主机端代码:

import torch
from math import ceil

def tile_fmha(q, k, v, sm_scale=None, is_causal=True):
    """
    Launch the Flash Attention kernel.
    
    Args:
        q: Query tensor, shape [batch, heads, seq_len, head_dim]
        k: Key tensor, shape [batch, kv_heads, seq_len, head_dim]
        v: Value tensor, shape [batch, kv_heads, seq_len, head_dim]
        sm_scale: Softmax scale (default: 1/sqrt(head_dim))
        is_causal: Whether to apply causal masking
    
    Returns:
        Output tensor, same shape as q
    """
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(q.size(-1))
    
    batch_size, num_heads, seq_len, head_dim = q.shape
    _, num_kv_heads, _, _ = k.shape
    
    # Calculate query group size for GQA
    query_group_size = num_heads // num_kv_heads
    
    # Ensure contiguous memory layout
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    
    # Allocate output
    o = torch.empty_like(q)
    
    # Choose tile sizes (we'll optimize this later!)
    TILE_M, TILE_N = 64, 64
    
    # Calculate grid dimensions
    grid_x = ceil(seq_len / TILE_M)  # Number of tiles along sequence
    grid_y = batch_size * num_heads  # One block per batch-head pair
    grid = (grid_x, grid_y, 1)
    
    # Check if K length is evenly divisible
    EVEN_K = (k.shape[2] % TILE_N) == 0
    
    # Launch kernel
    ct.launch(
        torch.cuda.current_stream(),
        grid,
        fmha_kernel,
        (q, k, v, o, sm_scale, 0, head_dim, num_heads,
         TILE_M, TILE_N, query_group_size, is_causal, EVEN_K)
    )
    
    return o

这个采用 64×64 分块的基线版本可以正确运行。但我们能让它更快吗?让我们来看看。

第 2 部分:“陷阱与救援”的优化之旅

我们在以下配置上进行基准测试:

  • 硬件:NVIDIA B200
  • 批量大小:4,头数:32,头维度:128
  • 注意力:因果,数据类型:FP16
  • 序列长度:1024、2048、4096、8192、16384

为解释每一步,我们使用 Nsight Compute,并采用最小化的分析区段集合:

  • 正文:LaunchStats
  • 正文:Occupancy
  • 正文:SpeedOfLight
  • 正文:ComputeWorkloadAnalysis
  • 内存工作负载分析

基线性能

SeqLenThroughput (TFLOPS)1,0243302,0484414,0965118,19254616,384566
表 1. 未进行任何特定优化的基线性能

这是我们使用 64×64 分块且未进行优化的起点。

NCU 分析(SeqLen=1024,B200):

  • 每线程寄存器数:128
  • 理论/实际占用率:25% / 19.8%
  • 计算(SM)吞吐量:37.8%
  • 内存吞吐量:19.7%
  • 网格大小:2,048

1. 更大分块的陷阱

GPU 编程中一个常见的直觉是“分块越大 = 性能越好”。更大的分块:

  • 分摊内存访问开销。
  • 提高 L2 缓存利用率。
  • 降低每个元素的内核启动开销。

因此,让我们将 tile 大小从 64×64 增加到 256×128:

TILE_M, TILE_N = 256, 128  # Was 64, 64

预期是更好的内存带宽利用率 → 更快的性能。然而,以 TFLOPS 计的结果是:

SeqLenBaseline (64×64)Larger tiles (256×128)Performance Degradation1,024330187-43%2,048441268-39%4,096511347-32%8,192546415-24%16,384566463-18%
表 2. 基线性能与使用更大 tile 尺寸时的性能对比,显示使用更大 tile 尺寸时性能下降

在所有序列长度下,性能下降了 18-43%。这就是陷阱:大的 tile 会让性能变得更差。

为什么会发生这种情况?

  1. 计算瓶颈:随着每个分块中的元素增多,低效操作(单独的乘法/加法、精确数学运算)会成为瓶颈。
  2. 指令开销:每个分块的工作量增加,意味着在下一次内存操作之前需要执行更多指令。

经验:分块大小与计算效率相互依赖。只有当计算足够高效、能够跟上时,大分块才有帮助。

NCU 洞察(SeqLen=1,024,NVIDIA B200):

  • 每线程寄存器数跃升至 168(+31%),理论占用率降至 18.75%
  • 实际占用率降至 16.5%
  • 计算吞吐量暴跌至 17.4%(陷阱)
  • 内存吞吐量降至 7.4%
  • 网格大小缩小到 512(由于 tile 更大,block 数量更少)

2. 使用快速数学运算进行挽救

瓶颈之一是特殊函数:exp2(指数)和 truediv(除法)。默认情况下,这些函数遵循 IEEE-754 精度——准确度很高,但速度较慢。

对于深度学习,我们可以用极小的精度损失换取巨大的速度提升:

之前(精确运算):

p = ct.exp2(qk)
alpha = ct.exp2(m_i - m_ij)
acc = ct.truediv(acc, l_i)

之后(快速数学):

p = ct.exp2(qk, flush_to_zero=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)

这些标志的作用:

  • flush_to_zero=True:非规格化数(接近零的极小值)会变为精确的零。这可以避免 GPU 上缓慢的微代码路径。
  • rounding_mode=RMd.APPROX:在初始硬件近似之后跳过迭代细化。

借助 fast math,我们“挽救”了大尺寸 tile,结果以 TFLOPS 表示如下:

SeqLenLarger tiles (trap)Fast math (rescue)Improvement1,024187322+72%2,048268436+63%4,096347524+51%8,192415585+41%16,384463620+34%
表 3. 使用两项 fast math 优化时的性能提升

我们现在达到或超过了小尺寸 tile 的基线水平,并且在较长序列上获得了 10-20% 的提升。

NCU 洞察(SeqLen=1,024,NVIDIA B200):

  • 每线程寄存器数:168(未变化)
  • 理论/实际占用率:18.75% / 16.6%(未变化)
  • 计算吞吐量回升至 24.0%
  • 内存吞吐量提升至 12.9%

3. K 循环拆分

对于因果注意力,我们应用三角掩码:每个 query 只能关注更早位置的 key。在我们的基线中,我们在每次循环迭代时都会检查 if CAUSAL: mask …。

但想一想:对于位置 1000 的 query tile,大多数 key tile(0-900)根本不需要任何掩码。只有靠近对角线的 tile 才需要掩码。而 query 位置之后的 tile 会被完全掩码(我们可以直接跳过它们)。

Q by K tiled causal attention matrix showing 8 tiles per side and showing how the lower triangle is computed. The diagonal is partially computed, and the upper triangle is skipped.Q by K tiled causal attention matrix showing 8 tiles per side and showing how the lower triangle is computed. The diagonal is partially computed, and the upper triangle is skipped.
图 3. 平铺因果注意力矩阵(每边 8 个图块)

该优化将循环拆分为多个阶段:

# Calculate where masking starts being necessary
mask_start = (input_pos + bid_x * TILE_M) // TILE_N
mask_start = min(mask_start, k_seqlen // TILE_N)

# Calculate where to stop (for causal, we exit early)
if CAUSAL:
    Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
else:
    Tc = ct.cdiv(k_seqlen, TILE_N)

for j in range(0, Tc):
    # Load K and compute QK...
    
    # ONLY apply masking when necessary
    if (CAUSAL or not EVEN_K) and j >= mask_start:
        offs_n = j * TILE_N + offs_n_tile
        mask = ct.full((TILE_M, TILE_N), True, dtype=ct.bool_)
        if not EVEN_K:
            mask = mask & (offs_n < k_seqlen)
        if CAUSAL:
            mask = mask & (offs_m >= offs_n)
        mask = ct.where(mask, 0.0, -math.inf)
        qk += mask
    
    # Continue with softmax and accumulation...

这为何重要:对于一个包含 16K 序列、256-token 图块的情况:

  • 约 50% 的图块完全未被掩蔽(无分支,无需计算掩码)
  • 每行约有 1 个 tile 被部分掩蔽(完整逻辑)
  • 其余部分被完全跳过(提前退出)

结果(TFLOPS):

SeqLenFast mathLoop splitImprovement1,024322373+16%2,048436552+27%4,096524684+31%8,192585770+32%16,384620813+31%
表 4. 使用 K 循环拆分优化时的性能提升

这是最大的一项单独优化——在所有序列长度上最高可实现 32% 的加速。

NCU 洞察(SeqLen=1,024,B200):

  • 寄存器/线程:168(未变化)
  • 理论/实际占用率:18.75% / 16.6%(未变化)
  • 内存吞吐率提升至 14.5%(浪费的工作更少)
  • 计算吞吐率保持在 24.0%(工作更有用,但每个周期不一定更快)

4. ProgramId 重映射

一个微妙的优化是针对因果注意力反转块顺序。当我们按反向(从右下到左上)处理图块时,由于因果掩码,后启动的块工作量更少。这改善了负载均衡并减少了尾部效应。

之前(标准顺序):

bid_x = ct.bid(0)  # Process tiles 0, 1, 2, ...

之后(为因果性而反转):

if CAUSAL:
    bid_x = NUM_M_BLOCKS - 1 - ct.bid(0)  # Process tiles N, N-1, N-2, ...
else:
    bid_x = ct.bid(0)

这一小改动改善了波次调度,因为各个块在 GPU 上完成得更加均匀。

结果(TFLOPS):

SeqLenLoop splitRemappingImprovement1,024373377+1%2,048552560+1.5%4,096684696+1.8%8,192770781+1.5%16,384813835+2.6%
表 5. 重新映射分块的块顺序后的性能提升

有一个幅度不大但稳定的 1–3% 提升,尤其在较长序列中更为明显,因为此时尾部效应影响最大。

5. 自动调优

我们已经优化了大分块,但有一个问题:短序列仍然更适合小分块。

为什么?对于一个 1,024 个 token 的序列和 256 个 token 的分块,我们只有 4 个分块。这不足以充分利用 B200 上的所有 SM。更小的分块(64×64)会产生 16 个分块,从而更好地填满 GPU。

与其手动选择阈值,我们可以让 cuTile 的自动调优器对多种配置进行基准测试,并为每种输入形状缓存最佳配置。

自动调优器方法:

def _fmha_autotune_configs():
    """Search space for autotuning.

    The autotuner will benchmark these configurations and cache the best one
    per input shape (sequence length, batch size, etc.).
    """
    gpu_capability = torch.cuda.get_device_capability()

    if gpu_capability in [(12, 0), (12, 1)]:
        # RTX 50 series (sm120, sm121)
        yield SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2)
    else:
        # B200/GB200 (sm100) - Try multiple tile sizes
        # Autotuner will discover:
        # - 64x64 is best for short sequences (1024-2048)
        # - 128x128 may be best for medium sequences (4096)
        # - 256x128 is best for long sequences (8192+)
        yield SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2)
        yield SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=2)
        yield SimpleNamespace(TILE_M=256, TILE_N=128, num_ctas=1, occupancy=1)

如何使用自动调优启动:

不要直接调用 ct.launch,而是使用 ct_experimental.autotune_launch:

import cuda.tile_experimental as ct_experimental

def autotune_launch_fmha(
    stream, q, k, v, o, sm_scale, input_pos,
    hidden_size, num_heads, query_group_size, is_causal
):
    batch_size, _, q_len, _ = q.shape

    def _grid_fn(cfg):
        return (math.ceil(q_len / cfg.TILE_M), batch_size * num_heads, 1)

    def _args_fn(cfg):
        num_m_blocks = math.ceil(q_len / cfg.TILE_M)
        even_k = (k.shape[2] % cfg.TILE_N) == 0
        return (
            q, k, v, o, sm_scale, input_pos,
            hidden_size, num_heads, cfg.TILE_M, cfg.TILE_N,
            query_group_size, is_causal, even_k, num_m_blocks,
        )

    ct_experimental.autotune_launch(
        stream,
        grid_fn=_grid_fn,
        kernel=fmha_kernel,
        args_fn=_args_fn,
        hints_fn=lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy},
        search_space=_fmha_autotune_configs,
    )

注意:自动调优器 API 可能会发生变化。

自动调优器会智能地工作:

  1. 首次以 seq_len=1024 调用:对所有 3 个配置进行基准测试,并缓存最佳配置
  2. 首次调用 seq_len=2048:对所有 3 种配置进行基准测试,并缓存最佳配置
  3. 后续调用:使用已缓存的配置(零开销)

缓存键包含张量形状,因此不同的序列长度会自动获得不同的最佳配置。

以 TFLOPS 表示的结果:

SeqLenBaselineRemappingAutotuneSpeedup vs baseline1,0243303775481.66x2,0484415607081.61x4,0965116968171.60x8,1925467818871.62x16,3845668359181.62x
表 6. 原始基线与步骤 5 以及步骤 6 自动调优结果的比较

自动调优器发现,对于序列长度 ≤2,048,64×64 分块效果最佳;随后在更长序列上切换到更大的分块。与固定的大分块相比,这在短序列上带来了额外 45% 的性能提升,同时在长序列上保持峰值性能。

自动调优器的选择(在 B200 上):

  • SeqLen 1,024:64×64 分块(高并行度)
  • SeqLen 2,048:64×64 或 128×128 分块(均衡)
  • SeqLen 4,096+:128×128 或 256×128 分块(内存效率)

我们现在无需手动调优,即可在所有序列长度上实现最佳性能。

总结:优化栈

OptimizationKey insightImpactBaseline (64×64)Correct but unoptimizedBaselineLarge tiles (256×128)TRAP: 18-43% slower!-18% to -43%+ Fast math (FTZ, APPROX)RESCUE: Large tiles now pay off+34% to +72% from trap+ K-loop splitBiggest single optimization+16% to +32%+ ProgramId remappingBetter load balancing+1% to +3%+ AutotuningOptimal tiles per sequence+10% to +45%
表 7. 逐步优化结果及每一步对性能的影响

最终加速比:在所有序列长度上达到 1.60x-1.66x。

入门

编写高性能内核很少是为了找到某个“神奇”的设置。正如我们在“陷阱与补救”中看到的那样:

  1. 优化是相互依赖的:在我们修正数学运算之前,大分块反而更慢。不能孤立地评估分块大小。
  2. 数学很重要:flush_to_zero 和 APPROX 等标志对于释放 Tensor Core 吞吐量至关重要。对于深度学习而言,精确数学通常是过度要求。
  3. 算法层面的收益会叠加:K 循环拆分通过避免不必要的工作,带来了最大的单项改进(最高达 32%)。
  4. 自动调优优于手工启发式方法:cuTile 的自动调优器会针对每种序列长度发现最优分块大小(短序列为 64×64,长序列为 256×128),相比固定配置可带来 10-45% 的提升。
  5. 累积效果是乘法式的:完整的优化栈在所有序列长度上实现了 1.60x-1.66x 的加速——远高于任何单项优化单独带来的提升。

cuTile 使开发者能够用清晰、易读的 Python 代码表达这些优化——分块、快速数学控制、循环拆分、自动调优——同时为 NVIDIA GPU 生成高度优化的 PTX。

你可以在 TileGym 仓库中找到完全优化的内核。祝你尽情探索。

Like

标签

原文标题

Tuning Flash Attention for Peak Performance in NVIDIA CUDA Tile