中文内容
在这篇文章中,我们将深入探讨现代 AI 中最关键的工作负载之一:Flash Attention,你将了解:
- 如何使用 NVIDIA cuTile 实现 Flash Attention。逐步讲解一个可用于生产环境的完整实现代码。
- “陷阱与补救”的优化历程。这个案例研究展示了朴素的优化(例如仅仅增大 tile 大小)如何适得其反,以及如何修复这些问题。
- 用于实现最高性能的高级技术,例如 FMA 模式、快速数学、循环拆分和自适应 tiling。
环境要求:
- CUDA 13.1 或更高版本
- GPU 架构:计算能力 8.X、10.X、11.X、12.X(NVIDIA Ampere、NVIDIA Ada、NVIDIA Blackwell)
- Python:3.10 或更高版本
有关安装 cuTile Python 的更多信息,请参阅快速入门文档。
什么是 attention?
attention 机制是 transformer 模型的计算核心。给定一个 token 序列,attention 使每个 token 能够“查看”其他所有 token,并决定对它们的贡献赋予多大的权重。从数学上看,对于输入矩阵 Query(\(Q\))、Key(\(K\))和 Value(\(V\)),输出为:
正文:\(O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V\)
其中:
- \(Q \text{ 的形状为 } (N,d),\ N \text{ 个查询 token,每个维度为 } d.\)
- \(K \text{ 的形状为 } (N,d),\ N \text{ 个键 token。}\)
- \(V \text{ 的形状为 } (N,d),\ N \text{ 个值 token。}\)
- \(\text{中间的 } QK^{T} \text{ 矩阵形状为 } (N,N),\text{这是一个问题。}\)
内存带宽问题
对于序列长度 \(N = 16,384\)(在现代 LLM 中很常见),注意力矩阵 \(QK^{T}\) 包含 \(N^2 = 268\) 百万个元素。在 FP16 中,这相当于每个注意力头、每个批次项需要 512 MB 的中间存储。
标准注意力实现:
- 计算完整的 \(N \times N\) 注意力矩阵并将其写入全局内存(速度较慢)
- 逐行应用 softmax
- 读回该矩阵并与 \(V\) 相乘
这种方法受内存带宽限制,因为 GPU 大部分时间都在等待数据在 HBM 和计算单元之间移动,而不是进行计算。
Flash Attention 如何解决内存带宽问题
Flash Attention(由 Dao 等人于 2022 年提出)是一种 IO 感知算法,它从不实例化完整的 \(N \times N\) 矩阵。相反,它:
- 对计算进行分块:以适合快速片上 SMEM 的小块处理 \(Q, K, V\)
- 使用在线 softmax:以增量方式计算 softmax,而不需要完整的行
- 融合操作:将矩阵乘法和 softmax 合并为单次内核执行
其结果是实现 2-4 倍加速并显著节省内存,从而支持更长的上下文长度。

理解在线 softmax
Flash Attention 的关键算法洞见是在线 softmax 技巧。数值稳定的安全 softmax 需要在计算前知道整行的最大值:
正文:\(\text{softmax}(x_i) = \frac{e^{x_i – \max(x)}}{\sum_j e^{x_j – \max(x)}}\)
但如果我们正在处理分块,就无法访问完整的一行。在线 softmax 通过维护可增量更新的运行统计量来解决这个问题。
在线 softmax 算法
我们为每一行维护两个运行中的值:
- \(m_i\):到目前为止看到的最大值(用于数值稳定性)
- \(l_i\):到目前为止看到的指数和(softmax 分母)
当我们处理一个包含值 \(x_{new}\) 的新分块时:
- 更新最大值:\(m_{new} = \max(m_i, \max(x_{new}))\)
- 计算校正因子:\(\alpha = e^{m_i – m_{new}}\)(重新缩放先前的计算结果)
- 更新求和项:\(l_i = l_i \cdot \alpha + \sum e^{x_{new} – m_{new}}\)
- 更新累加器:\(acc = acc \cdot \alpha + P_{new} \cdot V_{tile}\)
\(P_{new}\) 是注意力权重矩阵,\(V_{tile}\) 是值矩阵分块,对应于当前迭代中的 Key 分块。最后,我们进行归一化:\(O = acc / l_i\)
这使我们能够计算精确的 softmax,而无需存储完整的行。
因果注意力和分组查询注意力
在深入实现之前,我们先了解现代 LLM 中使用的两种重要注意力变体:
因果注意力
在 GPT、LLaMA 和 Claude 等自回归语言模型中,每个 token 只能关注序列中位于其之前的 token,而不能关注未来的 token。这可以防止训练过程中的“作弊”,即模型通过提前查看后续内容来预测下一个词。
从数学上看,我们会对注意力分数应用一个三角掩码:
\(\text{mask}_{ij} = \begin{cases} 0 & \text{如果 } i \geq j \text{(查询位置 ≥ 键位置)} \ -\infty & \text{如果 } i < j \text{(未来 token)} \end{cases}\)
掩码注意力变为:
正文:\(O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + \text{mask}\right)V\)
向未来位置添加 \(-\infty\) 可确保它们在经过 softmax 后变为零,从而有效阻断来自未来 token 的信息流。

使用因果掩码时,大约一半的注意力矩阵会被掩蔽(上三角部分)。我们可以完全跳过这些被掩蔽图块的计算,从而带来 2 倍的算法加速。这对于 K 循环拆分优化至关重要。
分组查询注意力
标准多头注意力为每个注意力头分别配备 \(K,V\) 矩阵,导致内存占用较高:
- 多头注意力(MHA):32 个查询头 → 32 个 K/V 头(1:1 比例)
- 分组查询注意力(GQA):32 个查询头 → 4 个 K/V 头(8:1 比例)
- 多查询注意力(MQA):32 个查询头 → 1 个 K/V 头(32:1 比例)
在 GQA 中,多个查询头共享相同的 K/V 头。例如,使用 32 个查询头和 4 个 K/V 头时:
- 查询头 0-7 使用 K/V 头 0
- 查询头 8-15 使用 K/V 头 1
- 查询头 16-23 使用 K/V 头 2
- 查询头 24-31 使用 K/V 头 3
这在推理过程中将 K/V 缓存大小减少了 8 倍,对于服务长上下文模型至关重要。LlamA 2、Llama 3、Mistral 和 Qwen 等现代 LLM 广泛使用 GQA。
在 Flash Attention 中实现时,每个 CUDA 块为一个查询头计算注意力,但会加载相应的共享 K/V 头:
head_idx = bid_y % num_heads # Which query head (0-31) kv_head_idx = head_idx // query_group_size # Which K/V head (0-3)
当查询组大小为 8 时,查询头 0-7 都映射到 kv_head_idx = 0,共享内存中相同的 K/V tile。
第 1 部分:CUDA Tile 中的 flash attention kernel
让我们一步步实现 Flash Attention。我们的基线采用较小的 64×64 tile 和直观的代码——正确但尚未优化。
1. 定义内核接口
在 cuTile 中,@ct.kernel 装饰器将一个 Python 函数标记为 GPU 内核。我们使用 ct.Constant[T] 类型注解来传递编译时常量:
import math
import cuda.tile as ct
# Type aliases for compile-time constants
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]
# Conversion factor: we use exp2 instead of exp for efficiency
INV_LOG_2 = 1.0 / math.log(2)
@ct.kernel()
def fmha_kernel(
Q, K, V, Out, # Input/output tensors
qk_scale: float, # Scale factor (1/sqrt(d))
input_pos: int, # Position offset for causal masking
TILE_D: ConstInt, # Head dimension (for example, 128)
H: ConstInt, # Number of attention heads
TILE_M: ConstInt, # Tile size for Q dimension (for example, 64)
TILE_N: ConstInt, # Tile size for K/V dimension (for example, 64)
QUERY_GROUP_SIZE: ConstInt,# For Grouped Query Attention (GQA)
CAUSAL: ConstBool, # Whether to apply causal mask
EVEN_K: ConstBool, # Whether K length is divisible by TILE_N
):
2. 块 ID 映射
每个 CUDA 块计算输出的一个分块。使用 ct.bid,我们将二维网格映射到 batch/head 索引:
# Get block indices
bid_x = ct.bid(0) # Which tile along the sequence dimension
bid_y = ct.bid(1) # Which batch-head combination
# Decode batch and head from flattened index
batch_idx = bid_y // H
head_idx = bid_y % H
# For Grouped Query Attention: multiple Q heads share one K/V head
off_kv_h = head_idx // QUERY_GROUP_SIZE
3. 初始化累加器
在主循环之前,我们初始化在线 softmax 状态和输出累加器:
# Convert scale for base-2 exponential (faster than natural exp)
qk_scale = qk_scale * INV_LOG_2
# Create position indices for this tile
offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
offs_m += input_pos
offs_m = offs_m[:, None] # Shape: [TILE_M, 1]
offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
offs_n_tile = offs_n_tile[None, :] # Shape: [1, TILE_N]
# Online softmax state (float32 for numerical stability)
m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32) # Running max
l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32) # Running sum
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32) # Output accumulator
我们对累加器使用 float32,即使输入为 float16,也要在迭代式 softmax 计算过程中保持数值精度。
4. 加载查询分块
查询图块只加载一次,并在所有 K/V 迭代中重复使用:
# Load Q tile: shape [1, 1, TILE_M, TILE_D] -> [TILE_M, TILE_D]
q = ct.load(
Q,
index=(batch_idx, head_idx, bid_x, 0),
shape=(1, 1, TILE_M, TILE_D)
).reshape((TILE_M, TILE_D))
当图块超出张量边界时,ct.load 函数会自动处理边界条件。
5. K/V 图块上的主循环
这是 Flash Attention 的核心。我们遍历 K/V 图块:
# Calculate loop bounds
m_end = input_pos + (bid_x + 1) * TILE_M
k_seqlen = K.shape[2]
if CAUSAL:
# For causal attention, stop early (future tokens are masked)
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
for j in range(0, Tc):
# --- Step A: Load Key tile and compute QK^T ---
k = ct.load(
K,
index=(batch_idx, off_kv_h, 0, j),
shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2), # Transpose for correct layout
latency=2 # Hint for memory prefetching
).reshape((TILE_D, TILE_N))
# Matrix multiply: Q @ K^T
qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
qk = ct.mma(q, k, qk) # Uses Tensor Cores automatically
参数中的 order=(0,1,3,2) 告诉 cuTile 加载操作使用转置后的 K,而 latency=2 提示我们可以容忍一定延迟(从而实现更好的流水线化)。然后我们使用 ct.mma=(q, k, k,qk) 执行 cuTile 矩阵乘加运算。
6. 应用因果掩码
对于自回归模型(GPT、Llama 等),每个 token 只能关注之前的 token:
# --- Step B: Apply causal masking ---
if CAUSAL or not EVEN_K:
offs_n = j * TILE_N + offs_n_tile
mask = ct.full((TILE_M, TILE_N), True, dtype=ct.bool_)
# Boundary mask (for non-divisible sequence lengths)
if not EVEN_K:
mask = mask & (offs_n < k_seqlen)
# Causal mask: query position >= key position
if CAUSAL:
mask = mask & (offs_m >= offs_n)
# Convert to additive mask: True->0, False->-inf
mask = ct.where(mask, 0.0, -math.inf)
qk += mask
向被掩码的位置添加 -inf 可确保它们在 softmax 后变为零。
7. 在线 softmax 更新
现在我们更新运行中的 softmax 统计量:
# --- Step C: Online softmax ---
# Find max in current tile
qk_max = ct.max(qk, axis=-1, keepdims=True)
qk_max_scaled = qk_max * qk_scale
# Update running maximum
m_ij = max(m_i, qk_max_scaled)
# Scale QK scores
qk = qk * qk_scale
qk = qk - m_ij
# Compute attention weights (using exp2 for speed)
p = ct.exp2(qk)
# Update running sum
l_ij = ct.sum(p, axis=-1, keepdims=True)
alpha = ct.exp2(m_i - m_ij) # Correction factor
l_i = l_i * alpha
l_i = l_i + l_ij
# Rescale previous accumulator
acc = acc * alpha
8. 累加输出
最后,我们加载 Value tile 并进行累加:
# --- Step D: Load V and accumulate ---
v = ct.load(
V,
index=(batch_idx, off_kv_h, j, 0),
shape=(1, 1, TILE_N, TILE_D),
latency=4
).reshape((TILE_N, TILE_D))
# Cast attention weights back to input dtype for Tensor Core MMA
p = p.astype(Q.dtype)
# Accumulate: acc += P @ V
acc = ct.mma(p, v, acc)
# Update max for next iteration
m_i = m_ij
9. 最终归一化并存储
处理完所有图块后,我们按总和进行归一化并写入结果:
# --- Final: Normalize and store ---
acc = ct.truediv(acc, l_i)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
启动内核:主机端代码
现在让我们看看启动内核的主机端代码:
import torch
from math import ceil
def tile_fmha(q, k, v, sm_scale=None, is_causal=True):
"""
Launch the Flash Attention kernel.
Args:
q: Query tensor, shape [batch, heads, seq_len, head_dim]
k: Key tensor, shape [batch, kv_heads, seq_len, head_dim]
v: Value tensor, shape [batch, kv_heads, seq_len, head_dim]
sm_scale: Softmax scale (default: 1/sqrt(head_dim))
is_causal: Whether to apply causal masking
Returns:
Output tensor, same shape as q
"""
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(q.size(-1))
batch_size, num_heads, seq_len, head_dim = q.shape
_, num_kv_heads, _, _ = k.shape
# Calculate query group size for GQA
query_group_size = num_heads // num_kv_heads
# Ensure contiguous memory layout
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
# Allocate output
o = torch.empty_like(q)
# Choose tile sizes (we'll optimize this later!)
TILE_M, TILE_N = 64, 64
# Calculate grid dimensions
grid_x = ceil(seq_len / TILE_M) # Number of tiles along sequence
grid_y = batch_size * num_heads # One block per batch-head pair
grid = (grid_x, grid_y, 1)
# Check if K length is evenly divisible
EVEN_K = (k.shape[2] % TILE_N) == 0
# Launch kernel
ct.launch(
torch.cuda.current_stream(),
grid,
fmha_kernel,
(q, k, v, o, sm_scale, 0, head_dim, num_heads,
TILE_M, TILE_N, query_group_size, is_causal, EVEN_K)
)
return o
这个采用 64×64 分块的基线版本可以正确运行。但我们能让它更快吗?让我们来看看。
第 2 部分:“陷阱与救援”的优化之旅
我们在以下配置上进行基准测试:
- 硬件:NVIDIA B200
- 批量大小:4,头数:32,头维度:128
- 注意力:因果,数据类型:FP16
- 序列长度:1024、2048、4096、8192、16384
为解释每一步,我们使用 Nsight Compute,并采用最小化的分析区段集合:
- 正文:LaunchStats
- 正文:Occupancy
- 正文:SpeedOfLight
- 正文:ComputeWorkloadAnalysis
- 内存工作负载分析
基线性能
这是我们使用 64×64 分块且未进行优化的起点。
NCU 分析(SeqLen=1024,B200):
- 每线程寄存器数:128
- 理论/实际占用率:25% / 19.8%
- 计算(SM)吞吐量:37.8%
- 内存吞吐量:19.7%
- 网格大小:2,048
1. 更大分块的陷阱
GPU 编程中一个常见的直觉是“分块越大 = 性能越好”。更大的分块:
- 分摊内存访问开销。
- 提高 L2 缓存利用率。
- 降低每个元素的内核启动开销。
因此,让我们将 tile 大小从 64×64 增加到 256×128:
TILE_M, TILE_N = 256, 128 # Was 64, 64
预期是更好的内存带宽利用率 → 更快的性能。然而,以 TFLOPS 计的结果是:
在所有序列长度下,性能下降了 18-43%。这就是陷阱:大的 tile 会让性能变得更差。
为什么会发生这种情况?
- 计算瓶颈:随着每个分块中的元素增多,低效操作(单独的乘法/加法、精确数学运算)会成为瓶颈。
- 指令开销:每个分块的工作量增加,意味着在下一次内存操作之前需要执行更多指令。
经验:分块大小与计算效率相互依赖。只有当计算足够高效、能够跟上时,大分块才有帮助。
NCU 洞察(SeqLen=1,024,NVIDIA B200):
- 每线程寄存器数跃升至 168(+31%),理论占用率降至 18.75%
- 实际占用率降至 16.5%
- 计算吞吐量暴跌至 17.4%(陷阱)
- 内存吞吐量降至 7.4%
- 网格大小缩小到 512(由于 tile 更大,block 数量更少)
2. 使用快速数学运算进行挽救
瓶颈之一是特殊函数:exp2(指数)和 truediv(除法)。默认情况下,这些函数遵循 IEEE-754 精度——准确度很高,但速度较慢。
对于深度学习,我们可以用极小的精度损失换取巨大的速度提升:
之前(精确运算):
p = ct.exp2(qk) alpha = ct.exp2(m_i - m_ij) acc = ct.truediv(acc, l_i)
之后(快速数学):
p = ct.exp2(qk, flush_to_zero=True) alpha = ct.exp2(m_i - m_ij, flush_to_zero=True) acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
这些标志的作用:
- flush_to_zero=True:非规格化数(接近零的极小值)会变为精确的零。这可以避免 GPU 上缓慢的微代码路径。
- rounding_mode=RMd.APPROX:在初始硬件近似之后跳过迭代细化。
借助 fast math,我们“挽救”了大尺寸 tile,结果以 TFLOPS 表示如下:
我们现在达到或超过了小尺寸 tile 的基线水平,并且在较长序列上获得了 10-20% 的提升。
NCU 洞察(SeqLen=1,024,NVIDIA B200):
- 每线程寄存器数:168(未变化)
- 理论/实际占用率:18.75% / 16.6%(未变化)
- 计算吞吐量回升至 24.0%
- 内存吞吐量提升至 12.9%
3. K 循环拆分
对于因果注意力,我们应用三角掩码:每个 query 只能关注更早位置的 key。在我们的基线中,我们在每次循环迭代时都会检查 if CAUSAL: mask …。
但想一想:对于位置 1000 的 query tile,大多数 key tile(0-900)根本不需要任何掩码。只有靠近对角线的 tile 才需要掩码。而 query 位置之后的 tile 会被完全掩码(我们可以直接跳过它们)。

该优化将循环拆分为多个阶段:
# Calculate where masking starts being necessary
mask_start = (input_pos + bid_x * TILE_M) // TILE_N
mask_start = min(mask_start, k_seqlen // TILE_N)
# Calculate where to stop (for causal, we exit early)
if CAUSAL:
Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
else:
Tc = ct.cdiv(k_seqlen, TILE_N)
for j in range(0, Tc):
# Load K and compute QK...
# ONLY apply masking when necessary
if (CAUSAL or not EVEN_K) and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
mask = ct.full((TILE_M, TILE_N), True, dtype=ct.bool_)
if not EVEN_K:
mask = mask & (offs_n < k_seqlen)
if CAUSAL:
mask = mask & (offs_m >= offs_n)
mask = ct.where(mask, 0.0, -math.inf)
qk += mask
# Continue with softmax and accumulation...
这为何重要:对于一个包含 16K 序列、256-token 图块的情况:
- 约 50% 的图块完全未被掩蔽(无分支,无需计算掩码)
- 每行约有 1 个 tile 被部分掩蔽(完整逻辑)
- 其余部分被完全跳过(提前退出)
结果(TFLOPS):
这是最大的一项单独优化——在所有序列长度上最高可实现 32% 的加速。
NCU 洞察(SeqLen=1,024,B200):
- 寄存器/线程:168(未变化)
- 理论/实际占用率:18.75% / 16.6%(未变化)
- 内存吞吐率提升至 14.5%(浪费的工作更少)
- 计算吞吐率保持在 24.0%(工作更有用,但每个周期不一定更快)
4. ProgramId 重映射
一个微妙的优化是针对因果注意力反转块顺序。当我们按反向(从右下到左上)处理图块时,由于因果掩码,后启动的块工作量更少。这改善了负载均衡并减少了尾部效应。
之前(标准顺序):
bid_x = ct.bid(0) # Process tiles 0, 1, 2, ...
之后(为因果性而反转):
if CAUSAL:
bid_x = NUM_M_BLOCKS - 1 - ct.bid(0) # Process tiles N, N-1, N-2, ...
else:
bid_x = ct.bid(0)
这一小改动改善了波次调度,因为各个块在 GPU 上完成得更加均匀。
结果(TFLOPS):
有一个幅度不大但稳定的 1–3% 提升,尤其在较长序列中更为明显,因为此时尾部效应影响最大。
5. 自动调优
我们已经优化了大分块,但有一个问题:短序列仍然更适合小分块。
为什么?对于一个 1,024 个 token 的序列和 256 个 token 的分块,我们只有 4 个分块。这不足以充分利用 B200 上的所有 SM。更小的分块(64×64)会产生 16 个分块,从而更好地填满 GPU。
与其手动选择阈值,我们可以让 cuTile 的自动调优器对多种配置进行基准测试,并为每种输入形状缓存最佳配置。
自动调优器方法:
def _fmha_autotune_configs():
"""Search space for autotuning.
The autotuner will benchmark these configurations and cache the best one
per input shape (sequence length, batch size, etc.).
"""
gpu_capability = torch.cuda.get_device_capability()
if gpu_capability in [(12, 0), (12, 1)]:
# RTX 50 series (sm120, sm121)
yield SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2)
else:
# B200/GB200 (sm100) - Try multiple tile sizes
# Autotuner will discover:
# - 64x64 is best for short sequences (1024-2048)
# - 128x128 may be best for medium sequences (4096)
# - 256x128 is best for long sequences (8192+)
yield SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2)
yield SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=2)
yield SimpleNamespace(TILE_M=256, TILE_N=128, num_ctas=1, occupancy=1)
如何使用自动调优启动:
不要直接调用 ct.launch,而是使用 ct_experimental.autotune_launch:
import cuda.tile_experimental as ct_experimental
def autotune_launch_fmha(
stream, q, k, v, o, sm_scale, input_pos,
hidden_size, num_heads, query_group_size, is_causal
):
batch_size, _, q_len, _ = q.shape
def _grid_fn(cfg):
return (math.ceil(q_len / cfg.TILE_M), batch_size * num_heads, 1)
def _args_fn(cfg):
num_m_blocks = math.ceil(q_len / cfg.TILE_M)
even_k = (k.shape[2] % cfg.TILE_N) == 0
return (
q, k, v, o, sm_scale, input_pos,
hidden_size, num_heads, cfg.TILE_M, cfg.TILE_N,
query_group_size, is_causal, even_k, num_m_blocks,
)
ct_experimental.autotune_launch(
stream,
grid_fn=_grid_fn,
kernel=fmha_kernel,
args_fn=_args_fn,
hints_fn=lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy},
search_space=_fmha_autotune_configs,
)
注意:自动调优器 API 可能会发生变化。
自动调优器会智能地工作:
- 首次以 seq_len=1024 调用:对所有 3 个配置进行基准测试,并缓存最佳配置
- 首次调用 seq_len=2048:对所有 3 种配置进行基准测试,并缓存最佳配置
- 后续调用:使用已缓存的配置(零开销)
缓存键包含张量形状,因此不同的序列长度会自动获得不同的最佳配置。
以 TFLOPS 表示的结果:
自动调优器发现,对于序列长度 ≤2,048,64×64 分块效果最佳;随后在更长序列上切换到更大的分块。与固定的大分块相比,这在短序列上带来了额外 45% 的性能提升,同时在长序列上保持峰值性能。
自动调优器的选择(在 B200 上):
- SeqLen 1,024:64×64 分块(高并行度)
- SeqLen 2,048:64×64 或 128×128 分块(均衡)
- SeqLen 4,096+:128×128 或 256×128 分块(内存效率)
我们现在无需手动调优,即可在所有序列长度上实现最佳性能。
总结:优化栈
最终加速比:在所有序列长度上达到 1.60x-1.66x。
入门
编写高性能内核很少是为了找到某个“神奇”的设置。正如我们在“陷阱与补救”中看到的那样:
- 优化是相互依赖的:在我们修正数学运算之前,大分块反而更慢。不能孤立地评估分块大小。
- 数学很重要:flush_to_zero 和 APPROX 等标志对于释放 Tensor Core 吞吐量至关重要。对于深度学习而言,精确数学通常是过度要求。
- 算法层面的收益会叠加:K 循环拆分通过避免不必要的工作,带来了最大的单项改进(最高达 32%)。
- 自动调优优于手工启发式方法:cuTile 的自动调优器会针对每种序列长度发现最优分块大小(短序列为 64×64,长序列为 256×128),相比固定配置可带来 10-45% 的提升。
- 累积效果是乘法式的:完整的优化栈在所有序列长度上实现了 1.60x-1.66x 的加速——远高于任何单项优化单独带来的提升。
cuTile 使开发者能够用清晰、易读的 Python 代码表达这些优化——分块、快速数学控制、循环拆分、自动调优——同时为 NVIDIA GPU 生成高度优化的 PTX。
你可以在 TileGym 仓库中找到完全优化的内核。祝你尽情探索。
标签




















