当前位置: 首页 > news >正文

人工智能安全与隐私——联邦遗忘学习(Federated Unlearning)

前言

在联邦学习(Federated Learning, FL)中,尽管用户不需要共享数据,但全局模型本身可以隐式地记住用户的本地数据。因此,有必要将目标用户的数据从FL的全局模型中有效去除,以降低隐私泄露的风险,并满足GDPR要求的“被遗忘权”(The right to be forgotten)。

本文介绍一篇最新的AAAI-2025的文章:Federated Unlearning with Gradient Descent and Conflict Mitigation。这篇文章令人印象深刻。它所提出的方法(FedOSD)解决了传统的遗忘学习梯度上升法带来的弊端,并且不单适用于联邦学习,还适用于传统集中式的机器学习和深度学习中的unlearning,以及目前热门的大语言模型。因此具有一定的启发意义。

本文不去简单地翻译一下论文,要是简单翻译一下,就没有分享的必要了,而是围绕这个联邦遗忘学习的话题,融合自己的一些看法记录分享一下。遇到不专业的地方,欢迎大家来指正!

背景——什么时候需要“遗忘”?

随着人工智能(AI)的蓬勃发展,AI安全是一个越来越重要的话题。AI安全是一个很大的话题,本文就先往小的说,只探讨其中的一小部分:Federated Unlearning以及Machine unlearning。在集中式学习、联邦学习的模型训练过程中,可能会混杂有错误的、有害的数据,抑或是隐私数据。如果事后被发现了,再重新训练就成本很高了。所以这个时候“遗忘学习”就发挥作用了。

根据需要遗忘的内容,联邦遗忘可以划分为:sample unlearning、client unlearning。前者的目标是让FL global model把用户训练集中的某些特定sample给遗忘掉;后者的目标是遗忘掉整个目标用户的训练数据。不过,对于一些传统的FL unlearning算法,例如FedEraser, FedKdu, FedRecovery,等等,因为它们的遗忘机制是依赖于在先前的FL训练中提前把要遗忘的内容对应的梯度等信息存储起来,后续执行遗忘操作的时候,用它来进行遗忘。因此,这类的算法并不适用于sample unlearning。因为在正常的FL训练时,我们并不能提前判断哪些数据是需要后续被遗忘的。而对于其他类型的、不需要提前存储梯度信息的unlearning算法,我们可以用技术手段把sample unlearning规约为client unlearning。比如某个client需要遗忘部分数据,那么我们可以把这部分数据抽取出来,看作是一个虚拟的独立的用户来unlearn。

遗忘学习的基本步骤

综合前面几篇文章来看,联邦遗忘学习一般分为两个阶段:

  • Unlearning stage:使用遗忘算法,消除已训练好的模型对目标用户的训练数据的记忆。

  • Post-training stage:目标用户退出联邦学习,剩余用户继续训练若干轮,以恢复模型可用性(模型在剩余用户上的性能)。

post-training stage的目的是恢复被unlearning破坏的模型性能。在以往的时候一般不纳入联邦遗忘学习算法的主要设计范围内,一般是直接调用跟正常联邦学习相同的方法来进行,比如FedAvg。而这篇文章就把它考虑在整个算法设计的一环里面,下文再细说。

联邦遗忘学习的主要挑战

文章分析了联邦遗忘学习的三个关键挑战:

  • 梯度爆炸问题;

  • 更新冲突导致的模型可用性降低问题;

  • 模型“回退”问题。

下面逐个来看看这是怎么回事,以及文章的解决方法:

1. 梯度爆炸问题

梯度上升法是一种简单有效的unlearning方法,它通过反转梯度,使得训练的时候不再是追求“loss下降”,而变成了“让loss上升”,从而降低模型在需要遗忘的数据上的精度,达到遗忘的效果。

以CrossEntropy loss为例,其公式为:

 

其中C表示类别数; 是binary indicator(数值为0或1),其实就是将label转成one-hot编码后的第 个元素的值。 表示模型对应的预测概率。举个例子(稍微来回顾一下CE loss的计算):假设某个sample的label为2,那么其one-hot编码为 ,即 。假设模型的output经过softmax运算后得到的结果为 ,那么,所以求得CE loss为0.223。

然而,采用梯度上升法来unlearn的时候,会让 趋于0,因此不可避免地会出现梯度爆炸问题,如下图(a)所示。因为它没有upper bound的。前人用gradient clipping等方法解决此问题,但会引入额外的超参数。而在真实场景中无法预知最佳超参数,并且不恰当的超参数甚至会影响收敛性。

为此,文章通过对CrossEntropy loss中的概率 进行反转,解决了这一问题。如下图(b)所示。

 

上面的子图(b)是文章所提出的Unlearning Cross Entropy loss(UCE)。公式如下:

 

它为什么在 进行反转后还要除以2?这是因为对于FL而言,一般来说unlearning client并不能知道当前模型对于其他client而言的gradient的模长是多少,为了避免刚开始unlearn的时候的梯度爆炸,所以加了个除以2的操作来确保unlearning client的gradient不会爆炸。

2. 更新冲突问题

在unlearning的时候,unlearning模型所用的model update direction与不需要unlearn的用户(记为remaining clients)的梯度很容易发生冲突,即模型更新方向并不是remaining clients的梯度下降方向,会直接导致模型的在remaining clients的性能遭受破坏,甚至 导致灾难性遗忘。如下图所示,如果采用 作为更新方向,那么它就会让模型在剩余用户(用户1和用户2)上的精度变差,破坏了模型的可用性。

 

为此,文章设计了一套名为Orthogonal Steepest Descent(正交最速下降)的方法。具体地,根据梯度下降的理论,假如模型更新方向垂直于剩余用户的梯度,那么模型精度的损失就会降到最低。但存在无数条这样的正交梯度,因此,文章通过求解一个距离 最近的向量作为模型的更新方向( 表示unlearning client的梯度),从而既降低了模型可用性的损失,也加速了unlearning。其实从另一个角度来看,加速unlearning在某种意义上也能减少对模型可用性的破坏。文章通过求解下述问题来得到更新方向 。

 

最终得到 的公式如下:

 

其中 是由所有remaining clients的梯度拼成的一个矩阵; 。而 是对 SVD分解的结果。 表示 的伪逆(这个很好计算,就把 中的非0元素取倒数就OK了)。

这样计算出来的第 轮的更新方向 就会距离 最近,同时与剩余用户的梯度垂直。

3. Post-training的模型回退问题

文章开创性地指出post-training阶段会容易出现“模型回退问题”,换言之,就是说在剩余用户上继续训练的时候,模型不受控制地往回朝着unlearn之前的初始模型走,以至于恢复出部分已经unlearn掉的内容,白白做unlearning了。如下图所示:

 

这个现象有点令人震惊。这意味着很多前人的unlearning方法所展现出来的在post-training阶段模型所恢复出跟unlearning之前差不多accuracy的结果,实际上是大有问题的。文章还描绘了模型与初始模型(unlearning前的模型)的距离变化图,展现了这一“回退”现象,如下图:

 

实验中的ASR(Attack Successful Rate)用来衡量unlearning效果,它越低越好;R-Acc表示模型在remaining clients上的平均Accuracy,它用来衡量模型的可用性,数值越高越好。

从图中可见,尽管一些算法在post-training阶段里,模型的可用性恢复到了接近unlearning前的水平,但明显地模型回退到了接近unlearning前的初始模型的状态了,相当于白做unlearn了。

为此,文章提出了一种梯度投影策略,在post-trainining阶段,首先计算一个向量 ,表示当前模型到初始模型的0.5倍的L-2距离的梯度。然后,如果某个剩余用户 的梯度 与 的夹角小于90°,则将 投影到 的法平面上,从而避免了模型回退靠近初始模型,转去搜索其他的local optimum区域。

后记

总结来说,整个算法(FedOSD,Federated Unlearning with Orthogonal Steepest Descent)的流程图如下所示:

 

Federated Learning其中图(a)描绘的是unlearning之前的正常的联邦学习训练。然后某天用户 提出要遗忘、退出,则进入图(b)所示的unlearning阶段。unlearning结束后,进入图(c)的post-training阶段。

文章还做了很多其他实验,并详细分析了收敛性证明。这里就不展开叙述了。

值得一提的是,我之前确实也试过基于梯度上升的FL unlearning方法,实话实说,确实非常非常依赖调参,例如调节learning rate,以及调节gradient clipping的参数,否则没几轮,梯度就爆炸了,模型近乎变成一个随机初始化的模型,完全没法实用SOS)。另外,文章所描绘的方法,实际上也适用于machine unlearning,以及LLM的unlearning(毕竟LLM里面一般也用CrossEntropy loss了)。我后面也试图将它用在LLM上,做LLM unlearning,用来探索大模型的隐私安全问题,但那个就更复杂了,后面再慢慢探索。希望这篇文章在帮助自己记录学习点滴之余,也能帮助大家!


http://www.mrgr.cn/news/82289.html

相关文章:

  • Rosbag常见使用汇总
  • vue数据请求通用方案:axios的options都有哪些值
  • 蓝桥杯备赛:C++基础,顺序表和vector(STL)
  • RabbitMq的Java项目实践
  • SAP SD销售模块常见BAPI函数
  • 医学图像分析工具01:FreeSurfer || Recon -all 全流程MRI皮质表面重建
  • 51c视觉~合集40
  • 硬件设计-关于ADS54J60的校准问题
  • 多种方式访问mysql的对比分析
  • Pygame Zero(pgzrun)详解(简介、使用方法、坐标系、目录结构、语法参数、安装、实例解释)
  • NLP中的神经网络基础
  • SELECT的使用
  • GRAPE——RLAIF微调VLA模型:通过偏好对齐提升机器人策略的泛化能力(含24年具身模型汇总)
  • 矩阵的因子分解1-奇异值分解
  • 本地LLM部署--llama.cpp
  • Go 语言:Jank 简客博客系统
  • 我在广州学 Mysql 系列——插入、更新与删除数据详解以及实例
  • 数据结构与算法Python版 拓扑排序与强连通分支
  • chatwoot 开源客服系统搭建
  • 我的 2024 年终总结
  • 【信号滤波 (中)】采样条件及多种滤波算法对比(滑动平均/陷波滤波)
  • 【机器学习】工业 4.0 下机器学习如何驱动智能制造升级
  • JSON 系列之3:导入JSON标准示例数据
  • Redis - 1 ( 11000 字 Redis 入门级教程 )
  • SpringBoot整合Canal+RabbitMQ监听数据变更
  • gitlab的搭建及使用