当前位置: 首页 > news >正文

【TPAMI 2024】如何让模型在任何环境下都能胜出?领域泛化学习从单一到多元!

Out-of-Domain Generalization From a Single Source: An Uncertainty Quantification Approach

题目:单一源域的域外泛化:一种不确定性量化方法

作者:Xi Peng; Fengchun Qiao; Long Zhao
关注公众号:AI前沿速递,获取更多优质资源!


摘要

  • 我们关注的是在模型泛化中的最坏情况,即模型旨在在许多未见过的领域上表现良好,而训练时只有一个单一的领域可用。我们提出了基于元学习(Meta-Learning)的对抗性领域增强(Adversarial Domain Augmentation)来解决这一领域外泛化问题。关键思想是利用对抗性训练创建“虚构的”但“具有挑战性的”群体,模型可以从中学习泛化,并具有理论保证。为了促进快速且理想的领域增强,我们将模型训练纳入元学习方案,并使用Wasserstein自编码器(Wasserstein Auto-Encoder)来放宽广泛使用的最坏情况约束。我们通过整合不确定性量化进一步改进我们的方法,以实现高效的领域泛化。在多个基准数据集上的广泛实验表明,其在解决单一领域泛化方面的性能优于现有技术。

关键词

  • 领域泛化
  • 对抗性训练
  • 元学习
  • 不确定性量化

1. 引言

近年来,机器学习模型在广泛的应用中得到了迅速部署。其显著成功的一个关键假设是训练和测试数据通常遵循相似的统计规律。否则,即使是强大的模型(例如深度神经网络)也可能在未见过的或领域外(Out-of-Domain, OOD)测试样本上失败。通过整合来自多个训练领域的数据可以在某种程度上缓解这一问题,但由于数据获取预算有限或隐私问题,这并不总是可行的。于是,一个有趣但很少被研究的问题出现了:一个模型能否从一个源领域泛化到许多未见过的的目标领域?换句话说,当训练时只有一个领域可用时,如何最大化模型的泛化能力?

源和目标领域之间的差异,也被称为领域或协变量变化,在领域适应和领域泛化中已经得到了深入研究。尽管这些方法在解决普通领域差异问题上取得了各种成功,但我们认为现有的方法很难在上述的单一领域泛化问题上取得成功。如图1所示,前者通常期望目标领域数据的可用性(无论是标记的还是未标记的);而后者则总是假设在训练时有多个(而不是一个)领域可用。这一事实强调了为单一领域泛化开发新的学习范式的必要性。

在本文中,我们提出了对抗性领域增强来解决这一具有挑战性的任务。受到对抗性训练近期成功的启发,我们将单一领域泛化问题纳入最坏情况的公式中。目标是使用单一源领域生成“虚构的”但“具有挑战性的”群体,模型可以从中学习泛化,并具有理论保证。然而,在应用对抗性训练进行领域增强时存在技术障碍。一方面,由于最坏情况公式中的语义一致性约束,很难创建与源领域大不相同的“虚构”领域。另一方面,我们希望探索许多“虚构”领域以保证足够的覆盖范围,这可能导致显著的计算开销。为了绕过这些障碍,我们提出通过Wasserstein自编码器放宽最坏情况约束,以鼓励在输入空间中进行大范围的领域转换。此外,我们不是学习一系列集成模型,而是通过元学习组织对抗性领域增强,从而得到一个高效的模型,提高了单一领域泛化的性能。

模型泛化能力通常以准确率来衡量。然而,在关键任务中,不确定性量化起着至关重要的作用。例如,在未知环境中部署自动驾驶汽车时,了解预测不确定性对于风险评估至关重要。现有的工作通常忽视了在解决领域外泛化时利用增强数据的潜在风险,引发了严重的安全和安全问题。为此,我们通过整合不确定性量化改进我们的方法,以实现广泛和安全的领域泛化。总之,我们的贡献是多方面的:

  • 这项工作的主要贡献是一个基于元学习的方案,它使得单一领域泛化成为可能,这是一个重要但很少研究的问题。我们通过提出对抗性领域增强,并同时放宽广泛使用的最坏情况约束来实现这一目标。
  • 据我们所知,我们是第一个从单一源量化泛化不确定性的。我们利用不确定性评估来增加输入和标签空间的容量。
  • 广泛的实验表明,我们的方法在包括Digits、CIFAR-10-C和SYTHIA在内的基准数据集上的单一领域泛化方面,略微优于现有技术。

这项工作对我们会议版本进行了大量扩展。我们提出联合扰动潜在特征和真实标签,以鼓励显著的领域增强,整合不确定性量化以实现更可靠的领域外泛化。

3 单一领域泛化

我们的目标是解决单一领域泛化问题:一个模型仅在一个源领域 S S S上训练,但期望在未知领域分布 f T = { T 1 , T 2 , . . . } ∼ p ( T ) f_T = \{T_1, T_2, ...\} \sim p(T) fT={T1,T2,...}p(T)上泛化。这个问题比领域适应(假设 p ( T ) p(T) p(T)已知)和领域泛化(假设有多个源领域 { S 1 , S 2 , . . . } \{S_1, S_2, ...\} {S1,S2,...}可用)更具挑战性。受到最近在对抗性训练中的许多成就的启发,一个有前景的解决方案是利用对抗性训练来学习一个对分布外扰动具有抵抗力的鲁棒模型。更具体地说,我们可以通过解决最坏情况问题来学习模型:

min ⁡ θ sup ⁡ T : D ( S , T ) ≤ ρ E L task ( θ , T ) \min_{\theta} \sup_{T: D(S, T) \leq \rho} \mathbb{E}_{L_{\text{task}}(\theta, T)} θminT:D(S,T)ρsupELtask(θ,T)

其中 D D D是度量源领域和目标领域之间距离的相似性度量, ρ \rho ρ表示 S S S T T T之间最大的领域差异。 θ \theta θ是模型参数,根据特定任务的目标函数 L task L_{\text{task}} Ltask进行优化。在这里,我们专注于使用交叉熵损失的分类问题:

L task ( y , y ^ ) = − ∑ i y i log ⁡ ( y ^ i ) L_{\text{task}}(y, \hat{y}) = -\sum_{i} y_i \log(\hat{y}_i) Ltask(y,y^)=iyilog(y^i)

其中 y ^ \hat{y} y^是模型的softmax输出, y y y是表示真实类别的独热向量; y i y_i yi y ^ i \hat{y}_i y^i分别代表 y y y y ^ \hat{y} y^的第 i i i维。

根据最坏情况的公式(1),我们提出了一种新的方法,基于元学习的对抗性领域增强,用于单一领域泛化。图2展示了我们方法的概览。我们在3.1节中通过对抗性训练创建“虚构的”但“具有挑战性的”领域来增强源领域。任务模型在3.2节中帮助下从领域增强中学习,其中Wasserstein自编码器(WAE)放宽了最坏情况约束。我们在3.3节中描述了任务模型和WAE的联合训练,以及领域增强过程,这些都在一个元学习框架中组织。

3.1 对抗性领域增强

我们的目标是从未源领域创建多个增强领域。增强领域需要在分布上与源领域不同,以便模仿未见过的领域。此外,为了避免增强领域的发散,最坏情况保证定义在等式(1)中也应该被满足。

为了实现这一目标,我们提出了对抗性领域增强。我们的模型由任务模型和WAE组成,如图2所示。在图2中,任务模型由特征提取器 F : X → Z F: X \rightarrow Z F:XZ组成,将图像从输入空间映射到嵌入空间,以及分类器 C : Z → Y C: Z \rightarrow Y C:ZY用于从嵌入空间预测标签。设 z z z表示 x x x的潜在表示,由 z = F ( x ) z = F(x) z=F(x)获得。整体损失函数如下公式所示:

L ADA = L task ( θ , x ) + α L const ( θ , z ) + β L relax ( c , x ) L_{\text{ADA}} = L_{\text{task}}(\theta, x) + \alpha L_{\text{const}}(\theta, z) + \beta L_{\text{relax}}(c, x) LADA=Ltask(θ,x)+αLconst(θ,z)+βLrelax(c,x)

其中 L task L_{\text{task}} Ltask是等式(2)中定义的分类损失, L const L_{\text{const}} Lconst是等式(1)中定义的最坏情况保证, L relax L_{\text{relax}} Lrelax保证了大范围的领域传输,定义在等式(7)中。 c c c是WAE的参数。 α \alpha α β \beta β是平衡 L const L_{\text{const}} Lconst L relax L_{\text{relax}} Lrelax的两个超参数。

给定目标函数 L ADA L_{\text{ADA}} LADA,我们采用迭代方式在增强领域 S + S^+ S+中生成对抗样本 x + x^+ x+

x t + 1 + = x t + + γ ∇ x t + L ADA ( θ , c , x t + , z t + ) x^{+}_{t+1} = x^+_t + \gamma \nabla_{x^+_t} L_{\text{ADA}}(\theta, c, x^+_t, z^+_t) xt+1+=xt++γxt+LADA(θ,c,xt+,zt+)

其中 γ \gamma γ是梯度上升的学习率。需要少量迭代就可以产生足够的扰动并创建理想的对抗样本。

L const L_{\text{const}} Lconst对对抗样本施加语义一致性约束,以便 S + S^+ S+满足 D ( S , S + ) ≤ ρ D(S, S^+) \leq \rho D(S,S+)ρ。具体来说,我们遵循[70]在嵌入空间测量 S + S^+ S+ S S S之间的Wasserstein距离:

L const = 1 2 ∥ z − z + ∥ 2 + 1 { y ≠ y + } L_{\text{const}} = \frac{1}{2} \|z - z^+\|^2 + \mathbb{1}_{\{y \neq y^+\}} Lconst=21zz+2+1{y=y+}

其中 1 { ⋅ } \mathbb{1}_{\{\cdot\}} 1{}是0-1指示函数,如果 x + x^+ x+的类别标签与 x x x不同, L const L_{\text{const}} Lconst将为1。我们假设 y + = y y^+ = y y+=y总是真的,只要步长足够小,这简化了实现,而不影响性能。直观地说, L const L_{\text{const}} Lconst通过Wasserstein距离控制了泛化到源领域之外的能力。在传统的对抗性训练设置中,最坏情况问题由 L task L_{\text{task}} Ltask L const L_{\text{const}} Lconst处理。然而, L const L_{\text{const}} Lconst由于严格限制了样本及其扰动之间的语义距离,因此在领域传输上产生了有限的影响。因此,我们提出 L relax L_{\text{relax}} Lrelax来放宽语义一致性约束并创建大范围的领域传输。 L relax L_{\text{relax}} Lrelax的实现在第3.2节中讨论。

3.2 Wasserstein距离约束的放宽

直观地说,我们期望增强领域 S + S^+ S+与源领域 S S S大不相同。换句话说,我们希望最大化 S + S^+ S+ S S S之间的领域差异。然而,语义一致性约束 L const L_{\text{const}} Lconst会严重限制从 S S S S + S^+ S+的领域传输,为生成理想的 S + S^+ S+带来了新的挑战。为了解决这个问题,我们提出 L relax L_{\text{relax}} Lrelax来鼓励领域外的增强。我们在图3中说明了这个想法。

具体来说,我们采用Wasserstein自编码器(WAEs)来实现 L relax L_{\text{relax}} Lrelax。设 V V V表示由 c c c参数化的WAE,它由编码器 Q ( e ∣ x ) Q(e|x) Q(ex)和解码器 G ( x ∣ e ) G(x|e) G(xe)组成,其中 x x x e e e分别表示输入和瓶颈嵌入。此外,我们使用距离度量 D e D_e De来衡量 Q ( x ) Q(x) Q(x)和先验分布 P ( e ) P(e) P(e)之间的差异,这可以作为最大均值差异(MMD)或GANs实现。我们可以通过优化来学习 V V V

min ⁡ c [ 1 2 ∥ G ( Q ( x ) ) − x ∥ 2 + λ D e ( Q ( x ) , P ( e ) ) ] \min_c \left[ \frac{1}{2} \| G(Q(x)) - x \|^2 + \lambda D_e(Q(x), P(e)) \right] cmin[21G(Q(x))x2+λDe(Q(x),P(e))]

其中 λ \lambda λ是一个超参数。在源领域 S S S上预先训练 V V V后,我们保持它不变,并最大化领域增强的重建误差 L relax L_{\text{relax}} Lrelax

L relax = ∥ x + − V ( x + ) ∥ 2 L_{\text{relax}} = \| x^+ - V(x^+) \|^2 Lrelax=x+V(x+)2

与传统的或变分自编码器不同,WAEs使用Wasserstein度量来衡量输入和重建之间的分布距离。因此,预先训练的 V V V可以更好地捕获源领域的分布,最大化 L relax L_{\text{relax}} Lrelax创建大范围的领域传输。

在这项工作中, V V V 充当一类鉴别器,用以区分增强样本是否在源领域之外,这与传统的生成对抗网络(GANs)[14]中的鉴别器有显著的不同。它也不同于领域分类器,通常在领域适应任务中用于区分不同领域的样本。在领域适应中,领域分类器的目标是最大化源领域和目标领域之间的区分度,而我们的 V V V则是为了最小化这种区分度,从而允许在输入空间内实现大范围的领域传输。 L relax L_{\text{relax}} Lrelax L const L_{\text{const}} Lconst联合使用,旨在在输入空间内“推开”增强领域 S + S^+ S+,同时在嵌入空间内“拉回” S + S^+ S+。在第5节中,我们将展示 L relax L_{\text{relax}} Lrelax L const L_{\text{const}} Lconst分别是在输入空间和嵌入空间定义的两个Wasserstein距离度量的推导。

3.3 元学习单一领域泛化

为了高效地组织在源领域 S S S和增强领域 S + S^+ S+上的模型训练,我们利用元学习方案来训练单一模型。为了模拟源领域 S S S和目标领域 T T T之间的实际领域偏移,在每次学习迭代中,我们在源领域 S S S上执行元训练,并在所有增强领域 S + S^+ S+上执行元测试。因此,经过多次迭代后,期望模型在评估期间对最终目标领域 T T T实现良好的泛化。

正式地,所提出的基于元学习的对抗性领域增强方法在训练过程中的每次迭代由三部分组成:元训练、元测试和元更新。在元训练中, L task L_{\text{task}} Ltask是在来自源领域 S S S的样本上计算的,并且模型参数 θ \theta θ通过一个或多个梯度步骤以学习率 η \eta η进行更新:

θ ← θ − η ∇ θ L task ( θ , S ) \theta \leftarrow \theta - \eta \nabla_{\theta} L_{\text{task}}(\theta, S) θθηθLtask(θ,S)

然后我们在元测试中计算每个增强领域 S k + S^+_k Sk+上的 L task ( θ ′ , S k + ) L_{\text{task}}(\theta', S^+_k) Ltask(θ,Sk+)。最后,在元更新中,我们根据联合损失的梯度更新 θ \theta θ,其中元训练和元测试同时被优化:

θ ← θ − η [ 1 2 ∇ θ L task ( θ , S ) + ∑ k = 1 K L task ( θ ′ , S k + ) ] \theta \leftarrow \theta - \eta \left[ \frac{1}{2} \nabla_{\theta} L_{\text{task}}(\theta, S) + \sum_{k=1}^{K} L_{\text{task}}(\theta', S^+_k) \right] θθη[21θLtask(θ,S)+k=1KLtask(θ,Sk+)]

其中 K K K是增强领域的数量。注意,除了增强领域 S + S^+ S+外,我们还最小化源领域 S S S上的损失,以避免当 S + S^+ S+远离 S S S时性能下降。

单一领域泛化(SDG)的全部训练流程在算法1中进行了总结。与之前学习一系列集成模型的工作[70]不同,我们的方法实现了单一模型的高效性。更重要的是,元学习方案为学习模型准备了快速适应:一到几个梯度步骤将在新目标领域上产生改进的行为。这使我们的方法能够适用于少样本领域适应。更多细节请参见第7.4节。

4 不确定的单一领域泛化

尽管我们的方法在提高准确性方面的泛化能力有所提升,但它忽略了利用增强数据解决领域外泛化问题的潜在风险。为此,我们提出了不确定的单一领域泛化(Uncertain SDG),通过整合不确定性量化来进一步改进,以实现高效和安全的领域泛化。为此,我们引入了辅助模型 c = f ( p , f m ) c = f(p, f_m) c=f(p,fm) 来明确地模拟与 u u u 相关的不确定性,并逐渐增加不确定性,以在输入和标签空间中创造更具挑战性的 S+。在输入空间,我们引入 f p f_p fp 通过添加从 N ( m m , s s ) N(mm, ss) N(mm,ss) 采样的扰动 e e e 来创建特征增强 h + h^+ h+。在标签空间,我们将编码在 ( m m , s s ) (mm, ss) (mm,ss) 中的相同不确定性整合到 f m f_m fm 中,并提出了可学习的 mixup 方法,通过三个变量 ( a , b , t ) (a, b, t) (a,b,t) 生成 y + y^+ y+(与 h + h^+ h+ 一起),在输入和输出空间中实现一致的增强。

输入增强

目标是评估 S+ 的不确定性,并用它来提供输入和输出空间中一致的增强。为了实现这一目标,我们不是直接增强原始数据,而是引入一个轻量级的辅助网络 f p f_p fp 来通过增加与 u u u 相关的不确定性来创建具有大领域传输的特征增强 h + h^+ h+。我们提出学习逐层特征扰动 e e e,将潜在特征 h h h 传输到 h + h^+ h+ 以实现高效的领域增强 S → S+。与之前工作中广泛使用的直接生成 e = f p ( x , h ) e = f_p(x, h) e=fp(x,h) 不同,我们假设 e e e 遵循多元高斯分布 N ( m m , s s ) N(mm, ss) N(mm,ss),这可以在在未见领域部署模型时轻松访问不确定性。更具体地说,高斯参数是通过变分推断学习的 ( m m , s s ) = f p ( S , u ) (mm, ss) = f_p(S, u) (mm,ss)=fp(S,u),使得

h + = h + Softplus ( e ) , e ∼ N ( m m , s s ) h^+ = h + \text{Softplus}(e), \quad e \sim N(mm, ss) h+=h+Softplus(e),eN(mm,ss)

其中 Softplus \text{Softplus} Softplus 应用于稳定训练。

标签增强

特征扰动不仅增强了输入,还产生了标签的不确定性。为了明确地模拟标签的不确定性,我们利用 ( m m , s s ) (mm, ss) (mm,ss) 中编码的输入不确定性,通过 f m f_m fm 推断出编码在 ( a , b , t ) (a, b, t) (a,b,t) 中的标签不确定性(如图4)。受到 mixup 的启发,它执行成对样本 ( x i , x j ) (x_i, x_j) (xi,xj) 及其标签 ( y i , y j ) (y_i, y_j) (yi,yj) 的凸插值,我们通过将其置于特别为单一源泛化量身定制的可学习框架中来改进 mixup。首先,我们不是混合成对样本,而是混合 S 和 S+ 以实现领域间的插值。其次,我们利用 ( m m , s s ) (mm, ss) (mm,ss) 中编码的不确定性来预测可学习的参数 ( a , b ) (a, b) (a,b),它控制领域插值的方向和强度:

h + = α h + ( 1 − α ) h + , y + = α y + ( 1 − α ) y ~ h^+ = \alpha h + (1 - \alpha) h^+, \quad y^+ = \alpha y + (1 - \alpha) \tilde{y} h+=αh+(1α)h+,y+=αy+(1α)y~

其中 α ∼ Beta ( a , b ) \alpha \sim \text{Beta}(a, b) αBeta(a,b) 并且 y ~ \tilde{y} y~ 表示 y y y 的标签平滑版本。具体来说,我们通过机会 t t t 执行标签平滑,这样我们就为真实类别分配 r ∈ ( 0 , 1 ) r \in (0, 1) r(0,1) 并将 1 − r 1 - r 1r 平等地分配给其他类别,其中 c c c 统计类别。Beta 分布 ( a , b ) (a, b) (a,b) 和抽签 t t t ( a , b , t ) = f m ( m m , s s ) (a, b, t) = f_m(mm, ss) (a,b,t)=fm(mm,ss) 联合推断,以整合领域增强的不确定性。在第 7.2 节中,我们通过实验证明编码在 m m mm mm s s ss ss 中的不确定性可以鼓励更平滑的标签并显著增加标签空间的容量。整个训练流程在算法 2 中进行了总结。


5 理论理解

我们对提出的对抗性领域增强进行了详细的理论分析。具体来说,我们展示了等式(3)中定义的整体损失函数是放松的最坏情况问题的直接派生。

假设存在一个“成本”函数 c : Z × Z → R + ∪ { 1 } c: Z \times Z \rightarrow \mathbb{R}^+ \cup \{1\} c:Z×ZR+{1},用于量化嵌入空间中 z z z z + z^+ z+ 的对抗性扰动。同样,存在一个“成本”函数 d : X × X → R + ∪ { 1 } d: X \times X \rightarrow \mathbb{R}^+ \cup \{1\} d:X×XR+{1},用于量化输入空间中 x x x x + x^+ x+ 的对抗性扰动。 S S S S + S^+ S+ 之间的Wasserstein距离可以表示为:

W c ( S , S + ) : = inf ⁡ M z ∈ P ( S , S + ) E M z [ c ( z , z + ) ] W_c(S, S^+) := \inf_{M_z \in \mathcal{P}(S, S^+)} \mathbb{E}_{M_z} \left[ c(z, z^+) \right] Wc(S,S+):=MzP(S,S+)infEMz[c(z,z+)]

W d ( S , S + ) : = inf ⁡ M x ∈ P ( S , S + ) E M x [ d ( x , x + ) ] W_d(S, S^+) := \inf_{M_x \in \mathcal{P}(S, S^+)} \mathbb{E}_{M_x} \left[ d(x, x^+) \right] Wd(S,S+):=MxP(S,S+)infEMx[d(x,x+)]

其中, M z M_z Mz M x M_x Mx 分别是嵌入空间和输入空间上的测度; P ( S , S + ) \mathcal{P}(S, S^+) P(S,S+) S S S S + S^+ S+ 的联合分布。然后,放松的最坏情况问题可以表述为:

θ ∗ = min ⁡ θ sup ⁡ S + ∈ D E L task ( θ ; S + ) \theta^* = \min_\theta \sup_{S^+ \in \mathcal{D}} \mathbb{E}_{L_{\text{task}}(\theta; S^+)} θ=θminS+DsupELtask(θ;S+)

其中,

D : = { S + : W c ( S , S + ) ≤ ρ , W d ( S , S + ) ≤ ϵ } \mathcal{D} := \{ S^+ : W_c(S, S^+) \leq \rho, W_d(S, S^+) \leq \epsilon \} D:={S+:Wc(S,S+)ρ,Wd(S,S+)ϵ}

D \mathcal{D} D 涵盖了在嵌入空间中 S S S ρ \rho ρ 距离内,以及在输入空间中 S S S ϵ \epsilon ϵ 距离外的稳健区域,这里 W c W_c Wc W d W_d Wd 分别是嵌入空间和输入空间下的Wasserstein距离度量。

对于深度神经网络,当 ρ \rho ρ ϵ \epsilon ϵ 任意取值时,等式(12)是不可行的。因此,我们考虑其拉格朗日松弛形式,固定惩罚参数 α ≥ 0 \alpha \geq 0 α0 β ≥ 0 \beta \geq 0 β0

min ⁡ θ sup ⁡ S + E [ L task ( θ ; x + ) − λ W c , d ( S , S + ) ] \min_\theta \sup_{S^+} \mathbb{E} \left[ L_{\text{task}}(\theta; x^+) - \lambda W_{c, d}(S, S^+) \right] θminS+supE[Ltask(θ;x+)λWc,d(S,S+)]

其中,

W c , d ( S , S + ) : = α W c ( S , S + ) + β W d ( S , S + ) W_{c, d}(S, S^+) := \alpha W_c(S, S^+) + \beta W_d(S, S^+) Wc,d(S,S+):=αWc(S,S+)+βWd(S,S+)

f α , β ( θ ; c , x ) f_{\alpha, \beta}(\theta; c, x) fα,β(θ;c,x) 是等式(12)中的稳健替代形式。根据[57],当 α \alpha α 足够大且Lipschitz平滑性假设成立时, f a f_a fa 关于 θ \theta θ 是平滑的。由于 c c c θ \theta θ 相互独立, f α , β f_{\alpha, \beta} fα,β 关于 θ \theta θ 也是平滑的。梯度可以计算为:

∇ θ f α , β ( θ ; c , x ) = ∇ θ L task ( θ ; x ∗ ) \nabla_\theta f_{\alpha, \beta}(\theta; c, x) = \nabla_\theta L_{\text{task}}(\theta; x^*) θfα,β(θ;c,x)=θLtask(θ;x)

其中,

x ∗ ( x ; θ , c ) = arg ⁡ max ⁡ x + [ L task ( θ ; x + ) − L c , d ] = arg ⁡ max ⁡ x + L ADA ( θ , c , x + , z + ) x^*(x; \theta, c) = \arg\max_{x^+} \left[ L_{\text{task}}(\theta; x^+) - L_{c, d} \right] = \arg\max_{x^+} L_{\text{ADA}}(\theta, c, x^+, z^+) x(x;θ,c)=argx+max[Ltask(θ;x+)Lc,d]=argx+maxLADA(θ,c,x+,z+)

正是等式(3)中定义的对抗性扰动。

6 实现细节

任务模型。我们根据各自的特点为三个数据集设计了特定的任务模型,并采用了不同的训练策略。

在Digits数据集中,模型架构是conv-pool-conv-pool-fc-fc-softmax。包含两个5x5的卷积层,分别有64和128个通道。每个卷积层后面跟着一个2x2的最大池化层。两个全连接(FC)层的大小分别是1024,softmax层的大小是10。

在CIFAR-10-C数据集中,我们使用了宽度为4的16层Wide Residual Network (WRN)。第一层是一个3x3的卷积层,它将原始的3通道图像转换为16通道的特征图。然后特征通过三组3x3卷积层,每组包含两个块,每个块由两个相同通道数的卷积层组成,通道数分别是{64, 128, 256}。每个卷积层后面跟着批量归一化(BN)层。在第三组的输出后附加了一个8x8的平均池化层。最后,一个大小为10的softmax层预测类别分布。

在SYTHIA数据集中,我们使用了带有ResNet-50骨干的FCN-32s。模型从ResNet-50开始。附加了一个1x1的卷积层,用14个通道预测每个粗糙输出位置的每个类别的分数。跟随一个反卷积层,通过双线性插值将粗糙输出上采样到原始大小。

Wasserstein自动编码器。我们遵循[66]实现WAEs,但根据三个数据集的特点稍微修改了架构。

在Digits数据集中,编码器和解码器由FC层构建。编码器由两个FC层组成,大小分别为400和20。相应地,解码器由两个FC层组成,大小分别为400和3072。鉴别器由两个FC层组成,大小分别为128和1。架构如图5(a)所示。

在CIFAR-10-C数据集中,编码器从四个卷积层开始,通道为{16, 32, 32, 32}。后面跟着两个FC层,大小为1024和512。相应地,解码器从两个FC层开始,大小分别为512和1024。后面跟着四个转置卷积层,通道为{32, 32, 16, 3}。除解码器的最后一层外,每层后面都跟着BN。鉴别器由两个FC层组成,大小分别为128和1。架构如图5(b)所示。

在SYTHIA数据集中,编码器从三个卷积层开始,通道为{32, 64, 128}。后面跟着两个FC层,大小为{3840, 512}。相应地,解码器从两个FC层开始,大小分别为{512, 3840}。后面跟着三个转置卷积层,通道为{64, 32, 3}。除解码器的最后一层外,每层后面都跟着BN。鉴别器由三个FC层组成,大小为{512, 512, 1}。架构如图5©所示。

我们使用Adam优化器训练WAEs。学习率为Digits为0.001,CIFAR-10-C和SYTHIA为0.0001。训练周期为Digits为20,CIFAR-10-C为100,SYTHIA为200。

7 实验

我们首先在7.1节中介绍实验设置。在7.2节中,我们进行了详细的消融研究,以验证所提出的放松、元学习方案的效率、关键超参数的选择和权衡,以及所提出的不确定性量化的有效性。在7.3节中,我们在基准数据集上与现有技术进行了比较。在7.4节中,我们进一步评估了我们的方法在少样本领域适应性方面的表现。

7.1 数据集和设置

数据集和设置。(1) Digits由五个子数据集组成:MNIST、MNIST-M、SVHN、SYN和USPS,每个数据集可以看作是一个不同的领域。这些数据集中的每张图像都包含一个不同风格的单个数字。该数据集主要用于消融研究。我们使用MNIST训练集的前10,000个样本进行训练,并在所有其他领域评估模型。(2) CIFAR-10-C是一个鲁棒性基准,包含19种类型的腐败,这些腐败的五个严重程度应用于CIFAR-10的测试集。腐败来自四个主要类别:噪声、模糊、天气和数字。每种腐败都有五个级别的严重程度,“5”表示最腐败的一个。所有模型都在CIFAR-10上训练,并在CIFAR-10-C上评估。(3) SYNTHIA是一个合成数据集,用于驾驶场景中的语义分割。该数据集由相同的交通情况组成,但在不同的位置(选择了高速公路、类似纽约的城市和古老的欧洲小镇)以及不同的天气/照明/季节条件下(选择了黎明、雾、夜晚、春季和冬季)。按照[70]中的协议,我们只使用左前摄像头的图像,并从每个源领域随机采样900张图像。

评估指标。对于Digits和CIFAR-10-C,我们计算每个未见领域的平均准确率。对于CIFAR-10-C,准确率可能不足以全面评估模型的性能,而没有衡量相对于基线模型(ERM)和在干净数据集上评估的相对误差,即CIFAR-10的测试集没有任何腐败。受到[20]中提出的鲁棒性指标的启发,我们制定了两个指标来评估领域泛化背景下对图像腐败的鲁棒性:平均腐败误差(mCE)和相对mCE(RmCE)。它们定义为:

mCE = 1 N ∑ i = 1 N E f i − E ERM i , RmCE = 1 N ∑ i = 1 N ( E f i − E f clean E ERM i − E ERM clean ) \text{mCE} = \frac{1}{N} \sum_{i=1}^{N} E_{f_i} - E_{\text{ERM}_i}, \quad \text{RmCE} = \frac{1}{N} \sum_{i=1}^{N} \left(\frac{E_{f_i} - E_{f_{\text{clean}}}}{E_{\text{ERM}_i} - E_{\text{ERM clean}}}\right) mCE=N1i=1NEfiEERMi,RmCE=N1i=1N(EERMiEERM cleanEfiEfclean)

其中N是腐败的数量。mCE用于评估分类器f与ERM的鲁棒性比较。RmCE衡量与干净数据的相对鲁棒性。对于SYTHIA,我们计算每个未见领域的标准平均交并比(mIoU)。

7.2 消融研究

在本节中,我们进行实验以评估所提出的放松项Lrelax(公式3)、元学习训练方案(第3.3节)、三个关键超参数(K、a和b)以及所提出的不确定性量化(第4节)的效果。

Lrelax的验证。为了直观了解Lrelax如何影响增强领域S+的分布,我们使用t-SNE[37]在嵌入空间中可视化有无Lrelax的S+。结果分别显示在图6(b)和6©中。我们观察到,使用Lrelax的S+的凸包覆盖的区域比不使用Lrelax的S+的凸包大。这表明S+包含更多的分布变化,并且与未见领域更好地重叠。此外,我们计算Wasserstein距离来定量测量S和S+之间的差异。S和使用Lrelax的S+之间的距离为0.078,如果不使用Lrelax,则距离降低到0.032,表明引入Lrelax改进了58.9%。这些结果表明,Lrelax能够将S+从S推开,保证了输入空间中显著的领域传输。

元学习方案的验证。我们的方法有无元学习(ML)方案的比较结果列在表1和4中。我们观察到,借助这种元学习方案,Digits和CIFAR-10-C的平均准确率分别提高了0.94%和1.37%。特别是,两种未见腐败的结果在图7中显示。如图所示,元学习方案可以显著降低方差,并在所有严重程度级别上获得更好的性能。实验结果证明了元学习方案在进行具有挑战性的对抗性领域增强时提高训练稳定性和分类准确率的关键作用。


K、a和b的超参数调整。我们研究了三个重要超参数的影响:增强领域数量(K)、源和增强领域在嵌入空间中的距离(a)以及源和增强领域之间的偏差(b)。我们在图8中绘制了不同K、a和b下的准确率曲线。在图8(左)中,我们发现当K=3时准确率达到峰值,随着K的增加而持续下降。这是因为超过某个阈值的过多的对抗样本将增加不稳定性并降低模型的鲁棒性。由于增强和源领域之间的距离随着K的增加而增加,大的K可能会破坏语义一致性的约束,导致模型训练效果变差。在图8(中)中,我们发现当a=1.0时准确率达到峰值,随着a的增加而持续下降。这是因为大的a将使源和增强领域在嵌入空间过于接近,导致领域传输有限。在图8(右)中,我们观察到当 b = 2.0 ∗ 1 0 − 3 b=2.0*10^-3 b=2.0103时准确率达到峰值,并且随着b的增加略有下降。这是因为大的b将产生距离源S过远的领域,甚至在嵌入空间的流形之外。

不确定性量化的验证。我们在MNIST[28]上可视化了不同训练迭代T时的特征扰动 ∣ e ∣ = ∣ h + − h ∣ |e| = |h^+ -h| e=h+h 和域的嵌入。我们使用t-SNE[37]在嵌入空间中可视化了有无不确定性评估的源和增强域。结果在图9中显示。在没有不确定性的模型中(左侧),特征扰动e是从N(0, I)中采样的,没有可学习的参数。在具有不确定性的模型中(右侧),我们观察到大多数扰动位于背景区域,这在保持类别不变的情况下增加了S+的变化。因此,具有不确定性的模型可以在课程学习方案中创建大领域传输,在看不见的领域实现安全增强和提高准确性。我们在图10中可视化了 y + y^+ y+的密度。由此可见,具有不确定性的模型可以显著增加标签空间。


7.3 单一领域泛化评估

我们将我们的方法与以下五种最先进的方法进行了比较。(1) 经验风险最小化(ERM)[26],[67],是在没有辅助损失和数据增强方案的情况下,用交叉熵损失训练的模型。(2) CCSA[43]使用语义对齐来正则化学习到的特征子空间,以实现领域泛化。(3) d-SNE[72]最小化同类样本之间的最大距离,并最大化不同类样本之间的最小距离。(4) GUD[70]提出了一种用于单一领域泛化的对抗性数据增强方法,这是我们方法的相关研究。(5) JiGen[5]同时学习分类和预测打乱图像块的顺序,以实现领域泛化。

在Digits上的比较。我们在MNIST上训练所有模型,并在未见领域,即MNIST-M、SVHN、SYN和USPS上测试它们。我们在表1中报告了结果。我们观察到,我们的方法在SVHN、MNIST-M和SYN上以较大优势超过了GUD。在USPS上的改进不如其他领域显著,主要是因为它与MNIST非常相似。相反,CCSA和d-SNE在USPS上取得了较大的改进,但在其他领域表现较差。Uncertain SDG在SYN和平均准确率上分别超过了SDG 8.1%和1.8%。我们还比较了Uncertain SDG和SDG在模型大小和训练时间方面的性能。Digits上的结果在表2中显示。正如观察到的,Uncertain SDG可以减少约25%的参数和约30%的训练时间。这些强有力的结果证明了所提出的不确定单一领域泛化的效率。

在CIFAR-10-C上的比较。我们在干净数据上训练所有模型,即CIFAR-10,并在腐败数据上测试它们,即CIFAR-10-C。在这种情况下,总共有19个未见过的测试领域。表3显示了CIFAR-10-C在五个腐败严重程度级别上的结果。可以看到,随着严重程度级别的增加,GUD和我们方法之间的差距越来越大,我们的方法可以在所有级别上显著降低标准差。此外,我们在表4中展示了每种腐败最严重情况下的结果。我们观察到,我们的方法在大多数腐败情况下明显优于其他方法。特别是在Snow(雪)、Glass blur(玻璃模糊)、Pixelate(像素化)和与噪声相关的腐败情况下,我们的方法比ERM[26]提高了10%以上。更重要的是,我们的方法在mCE和RmCE上具有最低的值,表明其对图像腐败的强大鲁棒性。


在SYTHIA上的比较。在这个实验中,高速公路是源领域,而类似纽约的城市和古老的欧洲小镇是未见过的目标领域。我们在表5中报告了语义分割结果,并在图11中展示了一些示例。未见过的领域来自不同的地点和其他条件。我们观察到,我们的方法在三个源领域上的平均mIoUs上优于ERM[26]和GUD[70],表明其应对地点、天气和时间变化的能力。与另外两个数据集相比,与ERM[26]和GUD[70]的改进并不显著,这主要是因为训练图像数量有限和对未见领域的高依赖性。在大多数未见环境中,不确定SDG优于SDG。结果表明,不确定性量化可以进一步改善对未见领域的泛化。


7.4 少样本领域适应性评估

尽管我们的方法旨在解决单一领域泛化问题,如第3.3节所述,我们也展示了我们的方法可以轻松应用于少样本领域适应性[42]。

设置。在少样本学习中,模型通常首先在源领域S上预训练,然后在目标领域T上微调。更具体地说,我们首先使用S上的所有训练图像训练我们的模型。然后我们随机选取T中每个类别的7或10张图像。这些图像用于使用学习率为0.0001和批量大小为16的少量样本微调预训练模型。

讨论。我们的方法与少样本领域适应性的最新方法进行了比较。我们还报告了一些无监督方法的结果,这些方法在训练中使用了目标领域中的图像。在MNIST、USPS和SVHN上的结果分别列在表6中。我们观察到,我们的方法与FADA[42]和CCSA[43]相比获得了有竞争力的结果。我们的方法还在一些利用目标领域未标记图像的无监督方法中表现更好。Uncertain SDG在三个任务的平均值上取得了最佳性能。在最困难的任务(M!S)上的结果甚至与SBADA[51]相当。此外,值得注意的是,FADA[42]和CCSA[43]的训练方式是S和T中的样本强耦合的。这意味着当目标领域发生变化时,需要训练一个全新的模型。另一方面,对于新的目标领域,我们的方法只需要用少量样本在少量迭代内微调预训练模型。这证明了我们方法的高灵活性。

8 结论

在本文中,我们提出了基于元学习的对抗性领域增强来解决单一领域泛化问题。核心思想是使用基于元学习的方案有效地组织训练增强的“虚构”领域,这些领域来自源领域之外,并通过对抗性训练创建。我们通过整合不确定性量化进一步改进了我们的方法,以实现广泛和安全的领域泛化。除了通过这些实验取得的优异表现外,一系列消融研究进一步验证了我们方法中关键组件的有效性。在未来,我们期望将我们的工作扩展到半监督学习或多模态学习中的知识转移。

关注公众号:AI前沿速递,获取更多优质资源!


http://www.mrgr.cn/news/30860.html

相关文章:

  • 系统启动时将自动加载环境变量,并后台启动 MinIO、Nacos 和 Redis 服务
  • 「IDE」集成开发环境专栏目录大纲
  • C++初阶——list
  • 安卓aab包的安装教程,附带adb环境的配置
  • 【Pikachu】SQL-Inject实战
  • ReactPress:功能全面的开源发布平台
  • 24:RTC实时时钟
  • 【学术会议:中国杭州,机器学习和计算机应用面临的新的挑战问题和研究方向】第五届机器学习与计算机应用国际学术会议(ICMLCA 2024)
  • 第十九节:学习WebFlux与前端响应式-非阻塞-流式通讯(自学Spring boot 3.x的第四天)
  • 平价头戴式蓝牙耳机有哪些?四款公认平价性能超强品牌机型推荐
  • 第六天旅游线路预览——从景区门口到天山天池
  • JavaScript可视化
  • 【Unity踩坑】UI Image的fillAmount不起作用
  • 创新的护盾:知识产权、商标与软件著作权的全方位解读
  • 【QGIS】(六)对图层添加属性并赋值行号(可作为导入数据的主键使用)
  • 大厂常问的MySQL事务隔离到底怎么回答
  • LabVIEW闪退
  • AutoX.js向后端传输二进制数据
  • js 深入理解类-class
  • Python数据处理入门教程!
  • 低侧单向电流、单电源检测电路
  • Redis系列---Redission分布式锁
  • 深度学习激活函数
  • 力扣560 和为k的子数组 Java版本
  • CCRC-CDO首席数据官:未成年人首次上网年龄持续降低
  • vmware官网下载