元鉴
返回中文阅读流

NVIDIA Developer Blog

使用 NVIDIA Warp 为 AI 构建加速的可微分计算物理代码

计算机辅助工程 (CAE) 正从人工驱动的工作流转向 AI 驱动的工作流,包括跨...泛化的物理基础模型。

中文内容

已翻译official company source英文原文2026-05-26

计算机辅助工程(CAE)正从人工驱动的工作流向 AI 驱动的工作流转变,其中涵盖了能够跨几何形状和工况泛化的物理基础模型。与大语言模型(LLM)不同,这些模型依赖于大量高保真且符合物理规律的数据。

近期关于计算流体动力学(CFD)代理模型的缩放定律研究表明,仿真生成的训练数据往往是实际应用中制约成本的关键因素。这给模拟器提出了更高要求:它必须是 GPU 原生的、运行快速的,并且能够直接接入机器学习工作流。

NVIDIA Warp 是一个用于加速仿真、数据生成和空间计算的框架,它桥接了 CUDA 与 Python。Warp 使开发者能够将高性能内核编写为常规的 Python 函数,并通过即时编译(JIT)转化为在 GPU 上高效执行的代码。与基于张量的框架(开发者需将计算表述为对整个 N 维数组的操作)不同,在 Warp 框架中,开发者编写的是灵活的内核,这些内核可在计算网格的所有元素上同时执行。

仿真内核通常基于计算网格进行表达,并依赖于数据相关的控制流,例如条件判断、提前退出以及针对每个元素的选择性更新。在张量框架中,处理这些模式需要构建布尔掩码,这很快就会变得难以管理,并且可能在无关元素上浪费计算资源。而在 Warp 内核中,每个线程可以独立地进行分支、跳过或退出,从而自然地表达此类逻辑,无需借助掩码变通方案。

此外,正如本文将展示的,借助 Warp 对自动微分的原生支持,用 Warp 编写的求解器可以轻松地实现可微性。它们能够便捷地集成到优化或训练工作流中,同时保持与 PyTorch、JAX 和 NumPy 等框架的互操作性,适用于涵盖仿真、机器人、感知和几何处理等多种应用场景。

本文将逐步指导您如何完全使用 Warp 构建一个二维 Navier-Stokes 求解器。文中阐述了 Warp 编程模型如何映射至偏微分方程求解器。随后,通过对仿真过程进行微分,端到端地求解了最优扰动问题。最后,通过工业案例研究展示了 Warp 在实际生产工作流中的赋能潜力。欲了解更多信息,请参阅 NVIDIA/warp GitHub 仓库中的二维 Navier-Stokes 求解器示例和二维 Navier-Stokes 最优扰动示例。

如何使用 Warp 编写二维 Navier-Stokes 求解器

为了将重点放在 Warp 而非数值方法上,此处采用了一个经典的二维衰减湍流教科书示例,该示例由不可压缩 Navier-Stokes 方程的涡量-流函数形式描述。涡量 \(\omega\) 的演化遵循以下输运方程:

正文:\(\frac{\partial \omega}{\partial t} + \frac{\partial \psi}{\partial y}\frac{\partial \omega}{\partial x} – \frac{\partial \psi}{\partial x}\frac{\partial \omega}{\partial y} = \frac{1}{\text{Re}}\nabla^2 \omega \tag{1}\)

而流函数 \(\psi\) 则通过泊松方程由涡量求解得出:

正文:\(\nabla^2 \psi = -\omega \tag{2}\)

在周期性边界条件下,上述方程在傅里叶空间中简化为代数方程,从而无需使用迭代求解器:

正文:\(\hat{\psi}_{m,n} = \frac{\hat{\omega}_{m,n}}{k_x^2 + k_y^2} \tag{3}\)

其中 \((k_x, k_y)\) 为傅里叶空间中的波数对。该求解器利用快速傅里叶变换(FFT)算法,高效地将 \(\omega\) 和 \(\psi\) 转换至傅里叶空间,反之亦然。

每个时间步包含两个子步骤(图1)。首先,涡量输运方程在 \(L \times L\) 方形区域内的 \(N \times N\) 网格上进行离散。采用三阶强稳定性保持Runge-Kutta(RK3)格式将解沿时间方向推进 \(\Delta t\),从而得到 \(\omega(t+\Delta t)\)。其次,在傅里叶空间中求解泊松方程,以获取更新后的 \(\psi(t+\Delta t)\)。

Flowchart of one solver timestep: starting with $\omega(t)$ and $\psi(t)$, discretization/time marching computes $\omega(t+\Delta t)$, then a Fourier Poisson solver computes $\psi(t+\Delta t)$, which feeds back to the next timestep.Flowchart of one solver timestep: starting with $\omega(t)$ and $\psi(t)$, discretization/time marching computes $\omega(t+\Delta t)$, then a Fourier Poisson solver computes $\psi(t+\Delta t)$, which feeds back to the next timestep.
图1. 求解器单时间步循环示意图

因此,正向求解器包含两个基础模块,将在后续章节中予以详述:

  • 用于离散化与时间推进的 Warp 核函数
  • 基于 FFT 的泊松求解器

基础模块 1:有限差分离散化与时间推进

涡量输运方程中的对流项与扩散项采用如图2所示的二阶中心有限差分格式进行近似。虽然也可采用高阶离散化方法,但出于简化考虑,此处选用了中心二阶格式。

Finite difference stencils for $latex \omega$Finite difference stencils for $latex \omega$
Finite difference stencils for $latex \psi$Finite difference stencils for $latex \psi$
图2. \(\omega\) 与 \(\psi\) 的有限差分模板

rk3_update() 核函数用于计算扩散项和对流项,并执行单次 RK3 子步更新。step() 函数在每个时间步调用该核函数三次,分别对应 RK3 的三个阶段,每个阶段使用不同的系数(coeff0、coeff1、coeff2)。

@wp.kernel
def rk3_update(
    n: int, h: float, re: float, dt: float,
    coeff0: float, coeff1: float, coeff2: float,
    omega_0: wp.array2d(dtype=float),
    omega_1: wp.array2d(dtype=float),
    psi: wp.array2d(dtype=float),
    omega_out: wp.array2d(dtype=float)
): 

   """Perform a single substep of SSP-RK3."""

    i, j = wp.tid()

    left = cyclic_index(i - 1, n)
    right = cyclic_index(i + 1, n)
    top = cyclic_index(j + 1, n)
    down = cyclic_index(j - 1, n)

    inv_h2 = 1.0 / (h * h)
    laplacian = (
        omega_1[right, j] + omega_1[left, j] + omega_1[i, top] + omega_1[i, down] - 4.0 * omega_1[i,j]
    ) * inv_h2

    inv_2h = 1.0 / (2.0 * h)
    j1 = ((omega_1[right, j] - omega_1[left, j]) * inv_2h) * ((psi[i, top] - psi[i, down]) * inv_2h)
    j2 = ((omega_1[i, top] - omega_1[i, down]) * inv_2h) * ((psi[right, j] - psi[left, j]) * inv_2h)

    rhs = (1.0 / re) * laplacian + j2 - j1

    omega_out[i, j] = coeff0 * omega_0[i, j] + coeff1 * omega_1[i, j] + coeff2 * dt * rhs

rk3_update() 核函数遵循单指令多线程(SIMT)范式,其中每个线程映射到计算域上的一个网格点,所有 \(N \times N\) 个点均通过单次 wp.launch() 调用实现同步更新。

wp.launch(rk3_update,
          dim=(self.n, self.n), # one thread per grid point
          inputs=[self.n, self.h, self.re, self.dt,
                  stage_coeff[0], stage_coeff[1], stage_coeff[2],
                  self.omega_0, 
                  self.omega_1, 
                  self.psi,
                ],
	        outputs=[self.omega_tmp]
         )
SIMT update on an $N\times N$ grid: $N^2$ threads run in parallel, one per cell $(i,j)$; each thread reads the five-point stencil values from timestep $n-1$ and writes the updated $\omega_{i,j}^n$ for timestep $n$.SIMT update on an $N\times N$ grid: $N^2$ threads run in parallel, one per cell $(i,j)$; each thread reads the five-point stencil values from timestep $n-1$ and writes the updated $\omega_{i,j}^n$ for timestep $n$.
图3. 二维网格上 \(\omega\) 的 SIMT 更新。线程 (i, j) 利用当前时间步模板中相邻网格点的值,将网格单元 (i, j) 更新至下一时间步。

构建模块 2:FFT 泊松求解器

基于 Warp 图块的原语支持在傅里叶空间中求解泊松方程。关键操作是 wp.tile_fft() 和 wp.tile_ifft(),它们分别对加载到图块中的单行数据执行正向和反向 FFT。针对 \(N \times N\) 数组的完整二维 FFT 随后被分解为三个步骤:逐行 FFT -> 转置 -> 逐行 FFT。图 4 的示意图解释了 fft_tiled() 和 ifft_tiled() 在底层是如何计算正向和反向 FFT 的。

@wp.kernel
def fft_tiled(x: wp.array2d(dtype=wp.vec2f), y: wp.array2d(dtype=wp.vec2f)):
    """Row-wise FFT using tile primitives."""
    i, _, _ = wp.tid()
    a = wp.tile_load(x, shape=(1, N_GRID), offset=(i, 0))
    wp.tile_fft(a)
    wp.tile_store(y, a, offset=(i, 0))
@wp.kernel
def ifft_tiled(x: wp.array2d(dtype=wp.vec2f), y: wp.array2d(dtype=wp.vec2f)):
    """Row-wise inverse FFT using tile primitives."""
    i, _, _ = wp.tid()
    a = wp.tile_load(x, shape=(1, N_GRID), offset=(i, 0))
    wp.tile_ifft(a)
    wp.tile_store(y, a, offset=(i, 0))
Row-wise GPU on an $N\times N$ grid: one thread block per row loads the row into a register tile, performs an in-place FFT cooperatively, and stores the result to a new array in the frequency domain.Row-wise GPU on an $N\times N$ grid: one thread block per row loads the row into a register tile, performs an in-place FFT cooperatively, and stores the result to a new array in the frequency domain.
图 4. 在 NxN 网格上执行逐行 tile_fft。每个线程块将一行数据加载至寄存器图块中,协作完成 FFT 计算,并将结果存回全局内存。

二维 FFT 在逐行计算之间还需要一次转置操作。该操作可采用 SIMT 或图块范式(通过 wp.tile_transpose 实现)。为简洁起见,下方展示的是 SIMT 版本:

@wp.kernel
def transpose(x: wp.array2d(dtype=wp.vec2f), y: wp.array2d(dtype=wp.vec2f)):
    i, j = wp.tid()
    y[i, j] = x[j, i]

将这三个内核组合起来,即 fft_tiled -> transpose -> fft_tiled,即可得到完整的二维正向 FFT。反向 FFT 遵循相同的模式,只需使用 ifft_tiled 即可。

组合基础模块

示例中的 step() 函数依赖于其他几个辅助内核,此处不作详细讨论。有关这些内核的定义,请参阅 NVIDIA/warp GitHub 仓库中的二维 Navier–Stokes 求解器示例。在所有基础模块就绪后,单次调用 step() 即可将模拟推进一个时间步。出于模块化设计考虑,示例代码中的 self._solve_poisson() 方法对 \(\omega(t+\Delta t) \xrightarrow{\text{FFT}} \hat{\omega} \xrightarrow{\text{Eq.\,3}} \hat{\psi} \xrightarrow{\text{IFFT}} \psi(t+\Delta t)\) 流程进行了抽象封装。

 def step(self) -> None:
        """Advance simulation by one timestep using SSP-RK3."""
        for stage_coeff in self.rk3_coeffs:
            wp.launch(
                rk3_update,
                dim=(self.n, self.n),
                inputs=[
                    self.n, self.h, self.re, self.dt,
                    stage_coeff[0], stage_coeff[1], stage_coeff[2],
                    self.omega_0,
                    self.omega_1,
                    self.psi,
                ],
               outputs=[self.omega_tmp],
            )
            # Swap buffers for next RK3 substep
            self.omega_1, self.omega_tmp = self.omega_tmp, self.omega_1

            # Update streamfunction for next timestep
            self._solve_poisson()
        
        # Copy updated vorticity to self.omega_0 for the next timestep
        wp.copy(self.omega_0, self.omega_1)

运行该求解器可生成图 5 所示的衰减湍流场。在 GPU 上,step() 函数通过 wp.ScopedCapture 被捕获为 CUDA Graph,并在后续所有帧中使用 wp.capture_launch() 进行重放,从而消除了每次启动的开销。

Pseudocolor GIF of two-dimensional decaying turbulence at $\mathrm{Re}=1000$, showing intertwined vortical filaments and eddy structures across the domain.Pseudocolor GIF of two-dimensional decaying turbulence at $\mathrm{Re}=1000$, showing intertwined vortical filaments and eddy structures across the domain.
图 5. Re = 1,000 时的二维衰减湍流

对求解器进行微分

既然可运行的求解器已构建完毕,接下来的问题是如何使其具备可微性。

自动微分(AD)通过对计算图中的每个基本运算应用链式法则,计算程序的精确导数。与有限差分法不同,AD 避免了步长调整的难题,并能生成达到机器精度的梯度。AD 应用于 PDE 求解器的核心优势在于规模扩展性:在大型网格上进行复杂模拟时,单次正向求解的成本已经很高,因此有限差分等方法需要执行 \(O(n)\) 次完整求解才能计算出针对 \(n\) 个输入的梯度。

反向模式 AD 仅需约一次前向计算和一次反向计算即可求出所有 \(\partial \mathcal{L}/\partial x_i\),从而使得在生产级分辨率下进行基于梯度的优化切实可行。这一思想与神经网络中的反向传播完全一致,也正是深度学习和大规模物理优化能够处理数百万自由度的根本原因。

Warp 自动微分系统在编译期间会为可微模拟生成程序的两个版本:

  • 前向版本:该版本代码负责接收物理输入(如初始条件、离散化控制方程等),计算模拟输出(如物理场、衍生量),并生成伴随版本所需的中间数组。
  • 伴随版本:正向仿真的自动生成对应版本,用于计算所选感兴趣量关于仿真输出的灵敏度,并将其一直反向传播至输入端。该反向传播过程复用正向执行过程中的中间数组,在整个求解器上应用微分链式法则,从而在不构建大型符号表达式的情况下得出仿真伴随。

开发者编写正向物理逻辑,Warp 负责处理梯度计算。任何需要可微的 wp.array 在分配时均需设置 requires_grad=True ,该参数会指示 Warp 分配一个用于存储伴随变量的配套数组。生成的伴随量既可独立使用(如本例所示),也可与 PyTorch 或 JAX 交互对接,用于端到端优化,包括训练机器学习模型。目前,Warp 仅支持反向模式自动微分。

为作说明,此处针对《Prediction and Control of Two-Dimensional Decaying Turbulence Using Generative Adversarial Networks》中阐述的最优扰动问题展开求解。在湍流中,初始条件的微小扰动会随时间不断放大,并显著改变流动轨迹。识别增长最快的扰动是实现流动控制以及理解流场中哪些结构具有动力学重要性的关键一步。具体而言,此处旨在寻找初始涡度扰动 \(\Delta\omega\),使其在提前时间 \(\tau\) 处最大化受扰动与未受扰动轨迹之间的偏离程度。

设 \(F^{\tau}\) 表示运行 \(\tau\) 个时间单位的正向求解器。未受扰动轨迹为 \(Y^{*} = F^{\tau}(\omega_0)\),受扰动轨迹为 \(\tilde{Y} = F^{\tau}(\omega_0 + \Delta\omega)\)。均方误差(MSE)

正文:\(\mathrm{MSE} = -\frac{1}{N^2}\left\| Y^* – \tilde{Y} \right\|_2^2 \tag{4}\)

被最小化,其中负号将轨迹发散的最大化问题转化为最小化问题。为约束优化过程,需满足 \(\mathrm{rms}(\Delta\omega) \leq 0.2 \times \mathrm{rms}(\omega_0)\),即扰动的均方根(RMS)不得超过初始涡度场 \(\omega_0\) 的均方根的 20%。

更多详情请参见 NVIDIA/warp GitHub 仓库中的二维 Navier-Stokes 最优扰动示例。接下来的部分将重点介绍使正向求解器具备可微性的三个关键改动。

无原地修改

wp.Tape() 会在正向传播过程中记录内核的启动,并在反向传播时逆序重放以计算梯度。这仅在反向传播所需的中间值仍然可用时才有效,因此数组不能被随意原地覆盖。这是与不可微求解器的关键区别。在仅正向运行的版本中,每个时间步结束时可以交换两个数组 omega_0 和 omega_1:

wp.copy(omega_0, omega_1)

对于可微求解器,右端项(RHS)计算和 RK3 更新需要拆分为写入不同数组的独立内核。因此,单次 RK3 更新将变为如下形式。请注意,此时不能再像以前那样在每个时间步结束时将 omega_1 的值复制给 omega_0。

omega_out[i, j] = coeff0 * omega_0[i, j] + coeff1 * omega_in[i, j] + coeff2 * dt * rhs[i, j]

在 Warp 中,所有中间数组都需要由用户显式定义。这要求在每个时间步的每个 RK 子步中预先分配独立的数组,这通常是任何可微求解器中最主要的 GPU 内存开销。

self.omega_timestep = [wp.zeros((n, n), dtype=wp.float32, requires_grad=True) for _ in range(T + 1)]

# Intermediate arrays for each RK3 substep for each timestep
self.omega_stage = []
self.psi_stage = []
self.rhs_stage = []
self.fft_arrays = []

for _ in range(T):
    s_omega, s_psi, s_rhs, s_fft = [], [], [], []
    for _ in range(3):
        s_omega.append(wp.zeros((n, n), dtype=wp.float32, requires_grad=True))
        s_psi.append(wp.zeros((n, n), dtype=wp.float32, requires_grad=True))
        s_rhs.append(wp.zeros((n, n), dtype=wp.float32, requires_grad=True))
        s_fft.append({"omega_complex": wp.zeros((n, n), dtype=wp.vec2f, requires_grad=True),
                      # ... plus 4 FFT scratch arrays, each (n, n) vec2f
                    })
    self.omega_stage.append(s_omega)
    self.psi_stage.append(s_psi)
    self.rhs_stage.append(s_rhs)
    self.fft_arrays.append(s_fft)

为每个中间状态存储 Warp 数组的开销与时间步数呈线性增长,在长时间运行中会变得难以承受。一种常见的方法是梯度检查点,仅保存选定的状态,然后在反向传播期间使用前向求解器重新计算缺失的部分。该方法以额外的前向计算为代价,换取显著降低的内存占用。如需了解如何在 Warp 中实现梯度检查点的示例,请参阅 NVIDIA/warp GitHub 仓库中的 fluid checkpoint 示例。

使用 wp.Tape() 记录梯度

在预分配好数组后,记录前向传播并对其进行求导的过程非常直接:

with wp.Tape() as tape:
    forward()  # wp.launch calls that take omega from t0 to t0 + lead t and calculate MSE 
tape.backward(loss) # Automatic differentiation to get derivatives of loss w.r.t Warp arrays

wp.Tape() 上下文会将每次 wp.launch() 调用记录到计算图中。tape.backward(loss) 会反向遍历该计算图,计算损失相对于 Warp 数组的导数。此处的重点是损失相对于 \(\Delta{\omega}\) 的梯度,可通过 delta_omega.grad 获取。

优化循环

以下代码块展示了一个优化步骤。`forward()` 函数作用于受扰动的初始涡度,以生成最终场和损失(即与未受扰动运行结果的均方误差)。`tape` 负责记录该次前向传递过程中的内核启动。接着,`tape.backward(loss)` 沿记录的计算图执行反向传播,以计算相对于扰动的梯度;`optimizer.step()` 则据此更新扰动以降低损失。最后,在下一轮迭代开始前,`tape.zero()` 会清除已累积的梯度。

with wp.Tape() as tape:
    forward() # Loss is computed inside forward() function

tape.backward(loss)
optimizer.step([delta_omega.grad.flatten()])
tape.zero()

经过1000次迭代后,优化器发现了一种能够放大轨迹发散的结构化扰动 \(\Delta\omega\),使均方误差(MSE)从接近零上升至约250。通过求解器内循环优化获得的扰动场,在定性上与《Prediction and Control of Two-Dimensional Decaying Turbulence Using Generative Adversarial Networks》一文中报告的结果相似。

Optimization GIF for 1,000 iterations. $\mathrm{MSE}(Y^*,\tilde{Y}$ decreases over iterations (left), with field snapshots showing the baseline $\omega_0$ (top center), learned perturbation $\Delta\omega$ (top right), target $Y^*$ (bottom cOptimization GIF for 1,000 iterations. $\mathrm{MSE}(Y^*,\tilde{Y}$ decreases over iterations (left), with field snapshots showing the baseline $\omega_0$ (top center), learned perturbation $\Delta\omega$ (top right), target $Y^*$ (bottom c
图6. 经过1000次迭代的优化过程,右上角为所发现的扰动

如需了解更多内容,NVIDIA/warp 的 GitHub 仓库中提供了除计算流体力学(CFD)以外的更多可微求解器示例。另请参阅一份不断增长的、利用 Warp 的研究出版物列表。

Warp 实战:AI 驱动的工业工作流案例研究

在实际的 AI 工作流中,仿真与几何计算通常嵌套在更大的系统中(如代理模型、强化学习、设计优化等)。PyTorch 和 JAX 负责处理训练与张量运算,但仿真层还需引入阶段性时间步进、模板更新及大规模空间查询。Warp 正是针对这一内核密集型(kernel-heavy)层:用户可自主控制执行流程,通过内核融合来降低内存访问与启动开销,并利用 CUDA Graphs 减少重复调度。此外,它还能与 PyTorch 和 JAX 张量实现零拷贝互操作。

正文:Autodesk XLB

Autodesk Research 开发了 XLB,这是一款基于 Python 的可微格子玻尔兹曼(Lattice Boltzmann)求解器,同时支持 Warp 与 JAX 后端,从而可在相同的数学模型与硬件上进行直接对比。在约 1.34 亿网格单元的顶盖驱动腔体基准测试中,在单块 40 GB NVIDIA A100 Tensor Core GPU 上,Warp 的运行速度约为 JAX 的 8 倍,其吞吐量大致相当于 JAX 需调用 8 块 NVIDIA A100 Tensor Core GPU 才能达到的水平。在更大规模场景下,Warp 的内存占用降低了约 2.5 至 3 倍,并成功完成了最大规模的测试案例,而在同一 GPU 上 JAX 则出现了显存溢出。

Two bar charts comparing Warp and JAX on NVIDIA A100. Left: throughput in MLUPS: Warp single-GPU (8879) versus JAX single-GPU (1139) and JAX 8-GPU (8397). Right: memory usage at 128^3, 256^3, and 512^3 domain sizes — JAX OOMs at 512^3 whileTwo bar charts comparing Warp and JAX on NVIDIA A100. Left: throughput in MLUPS: Warp single-GPU (8879) versus JAX single-GPU (1139) and JAX 8-GPU (8397). Right: memory usage at 128^3, 256^3, and 512^3 domain sizes — JAX OOMs at 512^3 while
图 7. Warp 与 JAX 的吞吐量及内存使用对比

欲了解更多信息,请参阅《Autodesk Research Brings Warp Speed to Computational Fluid Dynamics on NVIDIA GH200》。

正文:Google DeepMind MuJoCo

Google DeepMind 近期发布了 MuJoCo Warp(MJWarp),这是一个基于 Warp 的大规模多体动力学后端。在同等硬件条件下,该 Warp 后端相较于 JAX 实现了高达 252 倍(移动)和 475 倍(操作)的加速。MJWarp 通过利用稀疏矩阵运算和推测执行来更精确地调度计算,从而实现这一性能,同时保持与 JAX 训练的无缝兼容。

MJWarp physics step throughput versus MuJoCo MJX on LEAP benchmarks.MJWarp physics step throughput versus MuJoCo MJX on LEAP benchmarks.
MJWarp physics step throughput versus MuJoCo MJX on Apptronik benchmarks.MJWarp physics step throughput versus MuJoCo MJX on Apptronik benchmarks.
图 8. 在 LEAP 机械手操作与 Apptronik 移动基准测试中,MJWarp 与 MuJoCo MJX 的物理步吞吐量对比

欲了解更多信息,请参阅 MuJoCo Warp 发布公告。

C-Infinity AutoAssembler 

原文标题

Build Accelerated, Differentiable Computational Physics Code for AI with NVIDIA Warp