元鉴
返回中文阅读流

NVIDIA Developer Blog

使用 NVIDIA FLARE 实现无重构开销的联邦学习

联邦学习已从研究探索转向现实需求:高价值数据往往最难移动。

中文内容

已翻译official company source英文原文2026-05-26

联邦学习(FL)不再只是研究中的新奇概念,而是对一项硬约束的现实回应:最有价值的数据往往最难移动。监管边界、数据主权规则和组织风险承受能力,常常阻止数据集中汇聚。与此同时,强大的数据引力使得即便获准传输,在大规模场景下也会变得缓慢、昂贵且脆弱。

最新版 NVIDIA FLARE 正是针对这一现实,提供了一个联邦计算运行时:将训练逻辑移动到数据所在处,而原始数据保持不动。在高风险环境中,集中汇聚数据往往不可行或不切实际,因此现代联邦平台必须将数据隔离、合规性和隐私增强技术作为一等需求。

历史上阻碍采用的并不是 FL 的概念,而是开发者体验。如果从“我的本地脚本可以训练”到“我的任务可以跨联邦站点运行”的路径需要深度重构、新的类层级或脆弱的配置,许多项目就会在试点之后停滞。

FLARE API 的演进正是针对这一点:通过将工作拆分为两个具体步骤,消除重构开销,并且这两个步骤能够清晰对应团队实际构建和交付 ML 系统的方式:

  • 步骤 1(客户端 API):用约 5–6 行代码将现有本地训练脚本转变为联邦客户端,而无需改变训练循环结构。
  • 步骤 2(任务配方):选择 FL 工作流并将其绑定到客户端训练脚本,然后仅通过更换执行环境,在仿真、PoC 和生产环境中运行同一个任务。

“不复制数据”作为系统需求

在受监管或高敏感度场景中,“直接集中数据集”越来越不可行。一个实用的联邦计算平台需要支持:

  • 不复制数据:数据保留在本地,仅移动模型更新(或等价信号)。
  • 合规态势:支持主权和审计要求的部署与治理控制。
  • 隐私增强技术:多层防护机制(示例包括同态加密、差分隐私和机密计算)。
Figure shows a before-and-after comparison of centralized versus federated computing. On the left (“before”), three separate data silos send their data into one centralized database where a model is trained. On the right (“after”), data remFigure shows a before-and-after comparison of centralized versus federated computing. On the left (“before”), three separate data silos send their data into one centralized database where a model is trained. On the right (“after”), data rem
图 1. 联邦计算让数据保持在原处,通过模型更新实现协作,同时支持合规性和隐私增强保护。

重构悬崖:为什么 FL 项目会停滞

团队通常会在试点之后遇到两类悬崖之一:

  • 代码悬崖:将可运行的 PyTorch/TensorFlow/Lightning 训练转换为 FL,可能需要侵入式重构——新的抽象、消息传递胶水代码以及特定框架的脚手架。
  • 生命周期悬崖:即使仿真可用,进入 PoC 和生产阶段时,也会因任务重新定义、重新配置和面向环境的分支而触发重写。

FLARE 通过将工作流标准化为两个步骤来抹平这两类悬崖:

  1. 让脚本具备联邦能力(客户端 API)
  2. 将其作为可移植任务执行(任务配方)

其预期体验明确是将二者结合起来,让你能够快速从零开始创建可运行的联邦任务。

步骤 1:将本地训练脚本转换为联邦客户端(客户端 API)

适用对象:已有训练代码、希望尽可能减少差异的实践者和 ML 工程师。

其心智模型被有意设计得很简单:

  1. 初始化客户端运行时
  2. 在任务运行期间循环
  3. 接收当前全局模型
  4. 本地训练(你的代码)
  5. 发送更新后的权重和指标

FLARE 的客户端 API 旨在实现最小代码变更,并避免迫使你采用繁重的“Executor/Learner”继承方式——可以使用 FLModel 结构或简单的数据交换与运行时通信。

示例 1a:将 PyTorch 转换为 FLARE

下面是一个可应用于许多脚本的具体模式。关键接触点包括:flare.init()、flare.receive()、加载模型权重,以及使用更新后的权重和指标调用 flare.send()。

我们在左侧展示本地训练代码,在右侧展示联邦版本,并突出显示:import、flare.init()、receive()、send()。

train.py
# train.py

import torch
import torchvision
import torchvision.transforms as transforms

from model import Net

batch_size = 4
epochs = 1
lr = 0.01
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
transform = transforms.Compose(
   [
       transforms.ToTensor(),
       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
   ]
)

train_dataset = torchvision.datasets.CIFAR10(
   root="/tmp/data/cifar10", transform=transform, download=True, train=True
)

trainloader = torch.utils.data.DataLoader(
   train_dataset, batch_size=batch_size, shuffle=True
)

model.to(device)

for epoch in range(epochs):
   running_loss = 0.0

   for i, batch in enumerate(trainloader):
       images, labels = batch[0].to(device), batch[1].to(device)

       optimizer.zero_grad()

       predictions = model(images)
       cost = loss(predictions, labels)
       cost.backward()
       optimizer.step()

       running_loss += cost.cpu().detach().numpy() / batch_size

       if i % 3000 == 2999:
           print(
               f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / 3000}"
           )
           running_loss = 0.0

   print(
       f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / (i + 1)}"
   )

print("Finished Training")

torch.save(model.state_dict(), "./cifar_net.pth")
client.py
# client.py

# 1. Import client API
import nvflare.client as flare
import torch
import torchvision
import torchvision.transforms as transforms

from model import Net

batch_size = 4
epochs = 1
lr = 0.01
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
transform = transforms.Compose(
   [
       transforms.ToTensor(),
       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
   ]
)

train_dataset = torchvision.datasets.CIFAR10(
   root="/tmp/data/cifar10", transform=transform, download=True, train=True
)

trainloader = torch.utils.data.DataLoader(
   train_dataset, batch_size=batch_size, shuffle=True
)

# 2. Initialize FLARE
flare.init()

# At each round while FLARE is running
while flare.is_running():
   # 3. Receive the global model
   input_model = flare.receive()

   # 4. Load global model
   model.load_state_dict(input_model.params)
   model.to(device)

   for epoch in range(epochs):
       running_loss = 0.0

       for i, batch in enumerate(trainloader):
           images, labels = batch[0].to(device), batch[1].to(device)

           optimizer.zero_grad()

           predictions = model(images)
           cost = loss(predictions, labels)
           cost.backward()
           optimizer.step()

           running_loss += cost.cpu().detach().numpy() / batch_size

           if i % 3000 == 2999:
               print(
                   f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / 3000}"
               )
               running_loss = 0.0

       print(
           f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / (i + 1)}"
       )

   print("Finished Training")

   torch.save(model.state_dict(), "./cifar_net.pth")

   # 5. Send back the updated model
   output_model = flare.FLModel(
       params=model.cpu().state_dict(),
       meta={"NUM_STEPS_CURRENT_ROUND": len(trainloader) * epochs},
   )
   flare.send(output_model)

示例 1b:PyTorch Lightning 客户端 Lightning 集成保持相同

Lightning 集成保持相同意图——接收全局模型、训练、发送更新——但以更适合 Lightning 的方式暴露:导入 Lightning 客户端适配器并 patch Trainer。典型流程为:导入、patch、(可选)验证、照常训练。

# lightning_client.py
import pytorch_lightning as pl
from pytorch_lightning import Trainer

import nvflare.client.lightning as flare  # Lightning Client API  

from model import LitNet
from data import CIFAR10DataModule
def main():
   model = LitNet()
   dm = CIFAR10DataModule()

   trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1)

   # Patch trainer to participate in FL
   flare.patch(trainer)

   while flare.is_running():
       # Optional: validate current global model (useful for server-side selection flows)
       trainer.validate(model, datamodule=dm)

       # Train starting from received global model (handled internally after patch)
       trainer.fit(model, datamodule=dm)


if __name__ == "__main__":
   main()

关键在于:Lightning 用户不必进入自定义联邦消息传递流程——他们保留 Trainer 抽象,同时仍能正确参与 FL 轮次。

步骤 2:在任意位置打包并执行联邦任务(任务配方)

适用对象:希望以代码优先方式定义任务,并使其在不同环境中保持稳定的数据科学家和应用团队。

完成步骤 1 后,你就有了一个联邦客户端脚本。步骤 2 会将其变成可重复运行、并能在生命周期中顺畅迁移的联邦任务。

任务配方旨在用基于 Python 的任务定义取代基于 JSON 的任务配置:

  • 代码优先:用 Python 定义完整 FL 任务,而不是复杂配置文件
  • 一次编写,随处运行:同一个配方可在模拟器、PoC 或生产环境中运行
  • 部署提速:无需改变代码结构,即可从实验推进到部署

示例 2a:在仿真中执行 FedAvg 配方

关键关联在于,你的配方会引用在步骤 1 中创建的客户端训练脚本(例如 train_script="client.py"),然后在某个环境中执行它。

# job.py
from nvflare.app_common.workflows.job import FedAvgRecipe
from nvflare.job_config import SimEnv  # exact import path can vary by NVFlare version

from model import SimpleNetwork

def main():
   n_clients = 3
   num_rounds = 5
   batch_size = 32

   recipe = FedAvgRecipe(
       name="hello-pt",
       min_clients=n_clients,
       num_rounds=num_rounds,
       model=SimpleNetwork(),
       train_script="client.py",  # <-- Step A script
       train_args=f"--batch_size {batch_size} --epochs 1",
   )

   env = SimEnv(num_clients=n_clients, num_threads=n_clients)
   recipe.execute(env=env)

if __name__ == "__main__":
   main()

这就是“一次编写”的实践:一旦配方正确引用了客户端脚本,其余部分就变成执行层面的事项。

示例 2b:通过环境切换从仿真迁移到真实场景。

任务配方通过切换执行环境,将渐进式工作流形式化:

  1. SimEnv(仿真):便于开发、快速调试
  2. PocEnv(概念验证):本地运行时、多进程、贴近真实的测试
  3. ProdEnv(生产):在安全、可扩展基础设施上的分布式部署
Alt text: Figure shows a three-stage JobRecipe pipeline flowing into three execution environments. A box labeled “JobRecipe” at the top splits into three arrows pointing to side-by-side panels: SimEnv (Simulation) for easy development and rAlt text: Figure shows a three-stage JobRecipe pipeline flowing into three execution environments. A box labeled “JobRecipe” at the top splits into three arrows pointing to side-by-side panels: SimEnv (Simulation) for easy development and r
图 2. 一个 JobRecipe,多个执行环境:在 SimEnv 中调试,在 PocEnv 中验证,并在 ProdEnv 中部署,无需重写任务定义

开始使用

  • 从你已经信任的脚本开始。
  • 步骤 1:添加客户端 API 握手(或 patch 你的 Lightning Trainer)。
  • 步骤 2:将其封装到任务配方中,先在仿真中执行,然后通过切换环境进入 PoC,再进入生产。

新闻中的 FLARE

FLARE 正出现在真实部署中——从 Eli Lilly TuneLab 的联邦学习平台(由 Rhino Federated Computing 使用 NVFlare 构建),到台湾 MOHW 的国家医疗联邦学习计划,以及 Tri-labs(Sandia/LANL/LLNL)在敏感数据集上的联邦 AI 试点。

进一步探索

从你已经信任的脚本开始。添加最小化的 FLARE 客户端握手(接收 → 训练 → 发送)。然后在准备好时,从单节点仿真扩展到多站点部署。

  • 从这里开始:Hello World 示例(最快完成首次联邦运行的路径)——NVFlare Hello World
  • 观看演示:查看简化 API 栈的实际运行——网络研讨会录像
  • 客户端 API 文档
  • JobRecipe 文档
  • GitHub 上的 NVFlare

Like

标签

原文标题

Federated Learning Without the Refactoring Overhead Using NVIDIA FLARE