中文内容
联邦学习(FL)不再只是研究中的新奇概念,而是对一项硬约束的现实回应:最有价值的数据往往最难移动。监管边界、数据主权规则和组织风险承受能力,常常阻止数据集中汇聚。与此同时,强大的数据引力使得即便获准传输,在大规模场景下也会变得缓慢、昂贵且脆弱。
最新版 NVIDIA FLARE 正是针对这一现实,提供了一个联邦计算运行时:将训练逻辑移动到数据所在处,而原始数据保持不动。在高风险环境中,集中汇聚数据往往不可行或不切实际,因此现代联邦平台必须将数据隔离、合规性和隐私增强技术作为一等需求。
历史上阻碍采用的并不是 FL 的概念,而是开发者体验。如果从“我的本地脚本可以训练”到“我的任务可以跨联邦站点运行”的路径需要深度重构、新的类层级或脆弱的配置,许多项目就会在试点之后停滞。
FLARE API 的演进正是针对这一点:通过将工作拆分为两个具体步骤,消除重构开销,并且这两个步骤能够清晰对应团队实际构建和交付 ML 系统的方式:
- 步骤 1(客户端 API):用约 5–6 行代码将现有本地训练脚本转变为联邦客户端,而无需改变训练循环结构。
- 步骤 2(任务配方):选择 FL 工作流并将其绑定到客户端训练脚本,然后仅通过更换执行环境,在仿真、PoC 和生产环境中运行同一个任务。
“不复制数据”作为系统需求
在受监管或高敏感度场景中,“直接集中数据集”越来越不可行。一个实用的联邦计算平台需要支持:
- 不复制数据:数据保留在本地,仅移动模型更新(或等价信号)。
- 合规态势:支持主权和审计要求的部署与治理控制。
- 隐私增强技术:多层防护机制(示例包括同态加密、差分隐私和机密计算)。

重构悬崖:为什么 FL 项目会停滞
团队通常会在试点之后遇到两类悬崖之一:
- 代码悬崖:将可运行的 PyTorch/TensorFlow/Lightning 训练转换为 FL,可能需要侵入式重构——新的抽象、消息传递胶水代码以及特定框架的脚手架。
- 生命周期悬崖:即使仿真可用,进入 PoC 和生产阶段时,也会因任务重新定义、重新配置和面向环境的分支而触发重写。
FLARE 通过将工作流标准化为两个步骤来抹平这两类悬崖:
- 让脚本具备联邦能力(客户端 API)
- 将其作为可移植任务执行(任务配方)
其预期体验明确是将二者结合起来,让你能够快速从零开始创建可运行的联邦任务。
步骤 1:将本地训练脚本转换为联邦客户端(客户端 API)
适用对象:已有训练代码、希望尽可能减少差异的实践者和 ML 工程师。
其心智模型被有意设计得很简单:
- 初始化客户端运行时
- 在任务运行期间循环
- 接收当前全局模型
- 本地训练(你的代码)
- 发送更新后的权重和指标
FLARE 的客户端 API 旨在实现最小代码变更,并避免迫使你采用繁重的“Executor/Learner”继承方式——可以使用 FLModel 结构或简单的数据交换与运行时通信。
示例 1a:将 PyTorch 转换为 FLARE
下面是一个可应用于许多脚本的具体模式。关键接触点包括:flare.init()、flare.receive()、加载模型权重,以及使用更新后的权重和指标调用 flare.send()。
我们在左侧展示本地训练代码,在右侧展示联邦版本,并突出显示:import、flare.init()、receive()、send()。
# train.py
import torch
import torchvision
import torchvision.transforms as transforms
from model import Net
batch_size = 4
epochs = 1
lr = 0.01
model = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
train_dataset = torchvision.datasets.CIFAR10(
root="/tmp/data/cifar10", transform=transform, download=True, train=True
)
trainloader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
model.to(device)
for epoch in range(epochs):
running_loss = 0.0
for i, batch in enumerate(trainloader):
images, labels = batch[0].to(device), batch[1].to(device)
optimizer.zero_grad()
predictions = model(images)
cost = loss(predictions, labels)
cost.backward()
optimizer.step()
running_loss += cost.cpu().detach().numpy() / batch_size
if i % 3000 == 2999:
print(
f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / 3000}"
)
running_loss = 0.0
print(
f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / (i + 1)}"
)
print("Finished Training")
torch.save(model.state_dict(), "./cifar_net.pth")
# client.py # 1. Import client API import nvflare.client as flare import torch import torchvision import torchvision.transforms as transforms from model import Net batch_size = 4 epochs = 1 lr = 0.01 model = Net() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") loss = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) transform = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] ) train_dataset = torchvision.datasets.CIFAR10( root="/tmp/data/cifar10", transform=transform, download=True, train=True ) trainloader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True ) # 2. Initialize FLARE flare.init() # At each round while FLARE is running while flare.is_running(): # 3. Receive the global model input_model = flare.receive() # 4. Load global model model.load_state_dict(input_model.params) model.to(device) for epoch in range(epochs): running_loss = 0.0 for i, batch in enumerate(trainloader): images, labels = batch[0].to(device), batch[1].to(device) optimizer.zero_grad() predictions = model(images) cost = loss(predictions, labels) cost.backward() optimizer.step() running_loss += cost.cpu().detach().numpy() / batch_size if i % 3000 == 2999: print( f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / 3000}" ) running_loss = 0.0 print( f"Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / (i + 1)}" ) print("Finished Training") torch.save(model.state_dict(), "./cifar_net.pth") # 5. Send back the updated model output_model = flare.FLModel( params=model.cpu().state_dict(), meta={"NUM_STEPS_CURRENT_ROUND": len(trainloader) * epochs}, ) flare.send(output_model)
示例 1b:PyTorch Lightning 客户端 Lightning 集成保持相同
Lightning 集成保持相同意图——接收全局模型、训练、发送更新——但以更适合 Lightning 的方式暴露:导入 Lightning 客户端适配器并 patch Trainer。典型流程为:导入、patch、(可选)验证、照常训练。
# lightning_client.py
import pytorch_lightning as pl
from pytorch_lightning import Trainer
import nvflare.client.lightning as flare # Lightning Client API
from model import LitNet
from data import CIFAR10DataModule
def main():
model = LitNet()
dm = CIFAR10DataModule()
trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1)
# Patch trainer to participate in FL
flare.patch(trainer)
while flare.is_running():
# Optional: validate current global model (useful for server-side selection flows)
trainer.validate(model, datamodule=dm)
# Train starting from received global model (handled internally after patch)
trainer.fit(model, datamodule=dm)
if __name__ == "__main__":
main()
关键在于:Lightning 用户不必进入自定义联邦消息传递流程——他们保留 Trainer 抽象,同时仍能正确参与 FL 轮次。
步骤 2:在任意位置打包并执行联邦任务(任务配方)
适用对象:希望以代码优先方式定义任务,并使其在不同环境中保持稳定的数据科学家和应用团队。
完成步骤 1 后,你就有了一个联邦客户端脚本。步骤 2 会将其变成可重复运行、并能在生命周期中顺畅迁移的联邦任务。
任务配方旨在用基于 Python 的任务定义取代基于 JSON 的任务配置:
- 代码优先:用 Python 定义完整 FL 任务,而不是复杂配置文件
- 一次编写,随处运行:同一个配方可在模拟器、PoC 或生产环境中运行
- 部署提速:无需改变代码结构,即可从实验推进到部署
示例 2a:在仿真中执行 FedAvg 配方
关键关联在于,你的配方会引用在步骤 1 中创建的客户端训练脚本(例如 train_script="client.py"),然后在某个环境中执行它。
# job.py
from nvflare.app_common.workflows.job import FedAvgRecipe
from nvflare.job_config import SimEnv # exact import path can vary by NVFlare version
from model import SimpleNetwork
def main():
n_clients = 3
num_rounds = 5
batch_size = 32
recipe = FedAvgRecipe(
name="hello-pt",
min_clients=n_clients,
num_rounds=num_rounds,
model=SimpleNetwork(),
train_script="client.py", # <-- Step A script
train_args=f"--batch_size {batch_size} --epochs 1",
)
env = SimEnv(num_clients=n_clients, num_threads=n_clients)
recipe.execute(env=env)
if __name__ == "__main__":
main()
这就是“一次编写”的实践:一旦配方正确引用了客户端脚本,其余部分就变成执行层面的事项。
示例 2b:通过环境切换从仿真迁移到真实场景。
任务配方通过切换执行环境,将渐进式工作流形式化:
- SimEnv(仿真):便于开发、快速调试
- PocEnv(概念验证):本地运行时、多进程、贴近真实的测试
- ProdEnv(生产):在安全、可扩展基础设施上的分布式部署

开始使用
- 从你已经信任的脚本开始。
- 步骤 1:添加客户端 API 握手(或 patch 你的 Lightning Trainer)。
- 步骤 2:将其封装到任务配方中,先在仿真中执行,然后通过切换环境进入 PoC,再进入生产。
新闻中的 FLARE
FLARE 正出现在真实部署中——从 Eli Lilly TuneLab 的联邦学习平台(由 Rhino Federated Computing 使用 NVFlare 构建),到台湾 MOHW 的国家医疗联邦学习计划,以及 Tri-labs(Sandia/LANL/LLNL)在敏感数据集上的联邦 AI 试点。
进一步探索
从你已经信任的脚本开始。添加最小化的 FLARE 客户端握手(接收 → 训练 → 发送)。然后在准备好时,从单节点仿真扩展到多站点部署。
- 从这里开始:Hello World 示例(最快完成首次联邦运行的路径)——NVFlare Hello World
- 观看演示:查看简化 API 栈的实际运行——网络研讨会录像
- 客户端 API 文档
- JobRecipe 文档
- GitHub 上的 NVFlare
标签

















