SelMatch:最新数据集蒸馏,仅用5%训练数据也是可以的 | ICML‘24
数据集蒸馏旨在从大型数据集中合成每类(
IPC
)少量图像,以在最小性能损失的情况下近似完整数据集训练。尽管在非常小的IPC
范围内有效,但随着IPC
增加,许多蒸馏方法变得不太有效甚至性能不如随机样本选择。论文对各种IPC
范围下的最先进的基于轨迹匹配的蒸馏方法进行了研究,发现这些方法在增加IPC
的情况下很难将更难样本的复杂、罕见特征纳入合成数据集中,导致了容易和难的测试样本之间持续存在的覆盖差距。受到这些观察的启发,论文提出了SelMatch
,一种能够有效随IPC
扩展的新型蒸馏方法。SelMatch
使用基于选择的初始化和通过轨迹匹配进行部分更新来管理合成数据集,以适应针对IPC
范围定制的期望难度级别。在对CIFAR-10
/100
和TinyImageNet
的测试中,SelMatch
在5%
到30%
的子集比率上始终优于主流的仅选择和仅蒸馏方法。来源:晓飞的算法工程笔记 公众号,转载请注明出处
论文: SelMatch: Effectively Scaling Up Dataset Distillation via Selection-Based Initialization and Partial Updates by Trajectory Matching
- 论文地址:https://arxiv.org/abs/2406.18561
- 论文代码:https://github.com/Yongalls/SelMatch
Introduction
数据集缩减对于数据高效学习至关重要,它涉及从大型数据集中合成或选择较少数量的样本,同时确保在这个缩减后的数据集上训练的模型性能与在完整数据集上训练的相比保持可比性或性能降低最小化。这种方法解决了在大型数据集上训练神经网络时所面临的挑战,如高计算成本和内存需求。
在这一领域中一种重要的技术是数据集蒸馏,也被称为数据集凝聚。这种方法将大型数据集提炼为一个更小的合成数据集。与核心集选择方法相比,数据蒸馏在图像分类任务中表现出显著的性能,特别是在极小规模上。例如,匹配训练轨迹(MTT
)算法仅使用CIFAR-10
数据集的1%
,在简单的ConvNet
上实现了71.6%
的准确率,接近完整数据集的84.8%
准确率。这种显著的效率来自于优化过程,在这个过程中,合成样本在连续空间中被最优地学习,而不是直接从原始数据集中选择。
然而,最近的研究表明,随着合成数据集的规模或每类图像(IPC
)的增加,许多数据集蒸馏方法失去了效果,甚至表现不如随机样本选择。这一现象令人费解,考虑到蒸馏相对于离散样本选择提供的更大优化自由度。具体来说,DATM
通过分析最先进的MTT
方法的训练轨迹来调查这一现象,指出了在合成数据集过程中方法所关注的训练轨迹阶段如何显著影响蒸馏数据集的有效性。特别是,在早期轨迹中学习到的简单模式和在后期阶段学习到的困难模式明显影响了MTT
在不同IPC
情况下的性能。
论文进一步通过比较在不同IPC
情况下,MTT
方法涵盖合成数据集中简单和困难真实样本的情况,发现随着IPC
增加,蒸馏方法未能充分将困难样本的稀有特征纳入合成数据集中,这导致了简单样本与困难样本之间的一致覆盖差距。在更高IPC
范围内,数据集蒸馏方法效果降低的部分原因是它们倾向于专注于数据集中更简单、更具代表性的特征。相反,随着IPC
的增加,涵盖更难、更稀有的特征对于在缩减数据集上训练的模型的泛化能力变得更加关键,这点在数据选择研究中得到了实证和理论上的验证。
受到这些观察的启发,论文提出了一种新颖的方法,名为SelMatch
,作为有效扩展数据集蒸馏方法的解决方案。随着IPC
的增加,合成数据集应该涵盖真实数据集更复杂和多样化的特征,具有适当的难度水平。通过基于选择的初始化和通过轨迹匹配的部分更新,管理合成数据集的期望难度级别。
- 基于选择的初始化:为克服传统轨迹匹配方法过度集中于简单模式的局限性,即使
IPC
增加,使用针对每个IPC
进行优化的适当难度级别的真实图像来初始化合成数据集。传统的轨迹匹配方法通常使用随机选择的样本或接近类中心的简单或代表性样本来初始化合成数据集,以提高蒸馏的收敛速度。论文的方法则使用精心选择的子集来初始化合成数据集,该子集包含适合合成数据集大小的样本,其难度级别恰到好处。这种方法确保了随后的蒸馏过程以针对特定IPC
范围优化难度级别的样本开始。实验结果显示,基于选择的初始化在性能表现中扮演重要角色。 - 部分更新:在传统的数据集蒸馏方法中,合成数据集中的每个样本都在蒸馏迭代过程中进行更新。然而,随着蒸馏迭代次数的增加,该过程会不断降低合成数据集中样本的多样性,因为蒸馏提供的信号偏向于全数据集中的简单模式。因此,为了保持困难样本的稀有和复杂特征(这些特征对于模型在较大
IPC
范围内的泛化能力至关重要),论文引入了对合成数据集的部分更新。主要思想是保持合成数据集中的固定部分不变,同时通过蒸馏信号更新其余部分,而未更改部分的比例根据IPC
进行调整。实验结果显示,这样的部分更新对于有效扩展数据集蒸馏起到了重要作用。
在CIFAR-10
/100
和TinyImageNet
上评估了SelMatch
,并展示了在从5%
到30%
的子集比例设置中,与最先进的仅选择和仅蒸馏方法相比的优越性。值得注意的是,在CIFAR-100
中,当每类有50
张图像(10%
比例)的情况下,与领先方法相比,SelMatch
的测试准确率提高了3.5
%。
Related Works
数据集减少中的两种主要方法:样本选择和数据集蒸馏。
在样本选择中,主要有两种方法:基于优化和基于评分的选择。
基于优化的选择旨在识别一个小的核心集,有效地代表完整数据集的各种特征。例如,Herding
和K-center
选择一个近似于完整数据集分布的核心集。Craig
和GradMatch
寻求一个核心集,在神经网络训练中,它能够最小化与完整数据集的平均梯度差异。尽管在小到中等IPC
范围内有效,但是与基于评分的选择相比,这些方法在可伸缩性和性能方面常常面临问题,特别是随着IPC
的增加。
基于评分的选择能够根据神经网络训练中每个实例的难度或影响分配值。例如,Forgetting
通过计算先前被正确分类但在之后的多个时期被误分类的次数来评估实例的学习难度。C-score
将困难性评估为从训练集中删除样本时误分类的概率。这些方法优先考虑困难样本,捕捉罕见和复杂的特征,并在较大的IPC
规模下优于基于优化的选择方法。这些研究表明,随着IPC
的增加,引入更难或更稀有的特征对于模型的泛化能力的提高变得越来越重要。
数据集蒸馏旨在创建一个小的合成集 S \mathcal{S} S ,以便在 S \mathcal{S} S 上训练的模型 θ S \theta^\mathcal{S} θS 能够实现良好的泛化性能,在完整数据集 T \mathcal{T} T 上表现良好:
S ∗ = arg min S L T ( θ S ) , with θ S = arg min θ L S ( θ ) \mathcal{S^*} = \underset{\mathcal{S}}{\text{arg min}} \mathcal{L}^\mathcal{T}(\theta^\mathcal{S}), \text{ with } \theta^\mathcal{S} = \underset{\theta}{\text{arg min}} \mathcal{L}^\mathcal{S}(\theta) S∗=Sarg minLT(θS), with θS=θarg minLS(θ)
这里, L T \mathcal{L}^\mathcal{T} LT 和 L S \mathcal{L}^\mathcal{S} LS 分别是 T \mathcal{T} T 和 S \mathcal{S} S 上的损失。为了应对双层优化的计算复杂性和内存需求,现有的工作采用了两种方法:基于替代的匹配和基于核的方法。基于替代的匹配将复杂的原始目标替换为更简单的代理任务。例如,DC
、DSA
和MTT
旨在通过匹配梯度或轨迹,使在 S \mathcal{S} S 上训练的模型 θ S \theta^\mathcal{S} θS 的轨迹与完整数据集 T \mathcal{T} T 的轨迹保持一致。DM
确保 S \mathcal{S} S 和 T \mathcal{T} T 在特征空间中具有相似的分布。另外,基于核的方法利用核方法近似神经网络对 θ S \theta^\mathcal{S} θS 的训练,并为内部优化推导出闭式解。例如,KIP
使用神经切线核(NTK
)进行核岭回归,FrePo
通过仅专注于最后一个可学习层的回归来减少训练成本。然而,随着IPC
的增加,基于替代的匹配和基于核的方法在可扩展性或性能方面都难以有效扩展。DC-BENCH
指出,与高IPC
情况下的随机样本选择相比,这些方法性能不佳。
近期的研究致力于解决最先进的MTT
方法的可扩展性问题,主要关注计算方面,通过降低内存需求,或性能方面,通过在后续时期利用完整数据集的训练轨迹。具体而言,DATM
发现与早期训练轨迹保持一致可增强在低IPC
制度下的性能,而与后期轨迹保持一致对于高IPC
制度下更有益。基于这一观察,DATM
根据IPC
优化了轨迹匹配范围,从而自适应地将专家轨迹中更容易或更困难的模式纳入,从而提高了MTT
的可扩展性。虽然DATM
可有效地确定轨迹匹配范围的下限和上限,但在这些范围之外的匹配损失变化趋势上,明确量化或搜寻所需的训练轨迹困难水平仍然是一个具有挑战性的任务。相比之下,论文的SelMatch
利用基于选择的初始化和通过轨迹匹配进行部分更新,以纳入适合每个IPC
的难样本的复杂特征。尤其是,论文的方法引入了一种新颖的策略,即针对每个IPC
范围为合成样本初始化定制的困难水平,这是在以往的数据集蒸馏文献中尚未探讨的。此外,与专门设计用于增强MTT
的DATM
不同,SelMatch
的主要组成部分,即基于选择的初始化和部分更新,在各种蒸馏方法中具有更广泛的适用性。
Motivation
Preliminary
最先进的数据集蒸馏方法MTT
将作为基准,用于分析传统数据集蒸馏方法在大IPC
范围内的局限性。MTT
的目标是通过匹配真实数据集 D real \mathcal{D}_\textrm{real} Dreal 和合成数据集 D syn \mathcal{D}_\textrm{syn} Dsyn 之间的训练轨迹来生成合成数据集。在每个蒸馏迭代中,合成数据集会被更新,以最小化匹配损失,该损失以真实数据集 D real \mathcal{D}_\textrm{real} Dreal 的训练轨迹 { θ t ∗ } \{\theta_t^*\} {θt∗} 和合成数据集 D syn \mathcal{D}_\textrm{syn} Dsyn 的训练轨迹 { θ ^ t } \{\hat{\theta}_t\} {θ^t} 为定义。
L ( D syn , D real ) = ∥ θ ^ t + N − θ t + M ∗ ∥ 2 2 ∥ θ t ∗ − θ t + M ∗ ∥ 2 2 , \begin{equation} \mathcal{L}(\mathcal{D}_\textrm{syn}, \mathcal{D}_\textrm{real})= \frac{\|\hat{\theta}_{t+N} - \theta^*_{t+M}\|^2_2}{\|\theta^*_{t} - \theta^*_{t+M}\|^2_2}, \end{equation} L(Dsyn,Dreal)=∥θt∗−θt+M∗∥22∥θ^t+N−θt+M∗∥22,
其中, θ t ∗ \theta_t^* θt∗ 是在第 t t t 步上在 D real \mathcal{D}_\textrm{real} Dreal 上训练的模型参数。从 θ ^ t = θ t ∗ \hat{\theta}_{t}=\theta_t^* θ^t=θt∗ 开始, θ ^ t + N \hat{\theta}_{t+N} θ^t+N 是通过在 D syn \mathcal{D}_\textrm{syn} Dsyn 上训练 N N N 步后获得的模型参数,而 θ t + M ∗ {\theta}^*_{t+M} θt+M∗ 是在 D real \mathcal{D}_\textrm{real} Dreal 上训练 M M M 步后获得的参数。
Limitations of Traditional Methods in Larger IPC
首先分析MTT
生成的合成数据的模式如何随着每类图像(IPC
)的增加而演变。要使数据集蒸馏方法在更大的合成数据集中保持有效,蒸馏过程在每类图像增加时应继续向合成样本提供真实数据集的新颖和复杂模式。轨迹匹配方法在低IPC
水平上虽然处于最先进地位,但在实现这一目标方面还存在不足。
论文通过检查真实(测试)数据集的“覆盖率”来展示这一点。“覆盖率”被定义为在特征空间内与合成样本距离小于一定半径( r r r )的真实样本的比例,半径 r r r 被设置为特征空间内真实训练样本的平均最近邻距离。较高的覆盖率表明合成数据集捕获了真实样本的多样特征,使得在合成数据集上训练的模型能够学习到真实数据集中不仅是简单,还有复杂模式。
图1a
(左)展示了随着CIFAR-10
数据集的每类图像数量(IPC
)增加,覆盖率如何变化。此外,在图1a
(右)中,针对两组样本进行了分析。“简单”50%
和“困难”50%
(根据遗忘分数对真实样本进行的难度衡量)。
观察结果显示,使用MTT
的覆盖率并没有有效地随IPC
扩展,始终低于随机选择的覆盖率。此外,困难样本组的覆盖率远远低于简单样本组的覆盖率。这表明,即使IPC
增加,MTT
也无法有效地将困难和复杂的数据模式嵌入到合成样本中,这可能是MTT
性能不佳的缩放原因。而论文的方法SelMatch
展示了更优越的总体覆盖率,特别是在IPC
增加时,困难组覆盖率明显提升。
另一个重要发现是,随着蒸馏迭代次数的增多,MTT
的覆盖率在减少,如图1b
所示。这一观察进一步表明,传统的蒸馏方法主要在多次迭代过程中捕获“简单”模式,使得合成数据集随着蒸馏迭代次数的增加变得缺乏多样性。相比之下,即使迭代次数增加,使用SelMatch
的覆盖率仍然保持稳定。如图1c
所示,覆盖率也影响测试准确性。简单测试样本组和困难测试样本组之间覆盖率显著差异导致两组之间测试准确性存在显著差距。SelMatch
提高了两组的覆盖率,从而提高了总体测试准确性,特别是在IPC
增加时,对困难组的测试准确性有所提升。
Main Method: SelMatch
图2
展示了SelMatch
的核心思想,该方法将基于选择的初始化与通过轨迹匹配进行部分更新相结合。传统的轨迹匹配方法通常使用随机选择的真实数据集 D real \mathcal{D}_\textrm{real} Dreal 的子集对合成数据集 D syn \mathcal{D}_\textrm{syn} Dsyn 进行初始化,没有任何特定的选择标准。在每次蒸馏迭代过程中,整个 D syn \mathcal{D}_\textrm{syn} Dsyn 都会被更新,以最小化定义在公式1
中的匹配损失 L ( D syn , D real ) \mathcal{L}(\mathcal{D}_\textrm{syn}, \mathcal{D}_\textrm{real}) L(Dsyn,Dreal) 。
相比之下,SelMatch
首先使用精心选择的子集 D initial \mathcal{D}_\textrm{initial} Dinitial 对 D syn \mathcal{D}_\textrm{syn} Dsyn 进行初始化,该子集包含量身定制的适合于合成数据集规模的样本,具有适当的困难级别。然后,在每次蒸馏迭代中,SelMatch
仅更新 D syn \mathcal{D}_\textrm{syn} Dsyn 的一个特定部分,表示为 α ∈ [ 0 , 1 ] \alpha\in[0,1] α∈[0,1] ,(称为 D distill \mathcal{D}_\textrm{distill} Ddistill ),而数据集的剩余部分(称为 D select \mathcal{D}_\textrm{select} Dselect )保持不变。这个过程旨在最小化公式1
中的相同匹配损失 L ( D syn , D real ) \mathcal{L}(\mathcal{D}_\textrm{syn}, \mathcal{D}_\textrm{real}) L(Dsyn,Dreal) ,但现在 D syn \mathcal{D}_\textrm{syn} Dsyn 是 D distill \mathcal{D}_\textrm{distill} Ddistill 和 D select \mathcal{D}_\textrm{select} Dselect 的组合。
图1
中的一个重要观察是,传统的轨迹匹配方法倾向于关注完整数据集中简单和具代表性的模式,而不是复杂的数据模式,导致在更大的IPC
设置中扩展性较差。为了克服这一问题,论文提出了使用一个经过精心选择的难度级别对合成数据集 D syn \mathcal{D}_\textrm{syn} Dsyn 进行初始化,该难度级别在IPC
增加时包括来自真实数据集更复杂的模式。因此,挑战在于如何选择真实数据集 D real \mathcal{D}_\textrm{real} Dreal 的一个子集,其复杂度水平适当,同时考虑 D syn \mathcal{D}_\textrm{syn} Dsyn 的规模。
为了解决这个问题,论文设计了一个滑动窗口算法。根据预先计算的困难度分数(在CIFAR-10
/100
上利用预先计算的C-score
,而在Tiny Imagenet
上使用Forgetting score
作为困难分数。),按照困难程度的降序(从最困难到最容易)排列训练样本。然后,通过在不同起始点上的每个窗口子集训练模型来比较测试准确度,评估这些样本的窗口子集。对于给定阈值 β ∈ [ 0 , 100 ] % \beta\in[0,100]\% β∈[0,100]% ,在排除最困难的 β \beta β %样本后,窗口子集包括来自 [ β , β + r ] [\beta, \beta+r] [β,β+r] %范围内的样本,其中 r = ( ∣ D syn ∣ / ∣ D real ∣ ) × 100 % r=(|\mathcal{D}_\textrm{syn}|/|\mathcal{D}_\textrm{real}|)\times 100\% r=(∣Dsyn∣/∣Dreal∣)×100% , ∣ D syn ∣ |\mathcal{D}_\textrm{syn}| ∣Dsyn∣ 等于IPC
乘以类别数。在这里,确保每个窗口子集包含相同数量的来自每个类别的样本。
如图3
所示,窗口的起始点对应于困难程度的级别,显著影响模型的泛化能力(通过测试准确度来衡量)。特别是对于较小的窗口(5-10%
范围),测试准确度根据窗口起始位置的不同可以出现高达40%
的偏差。此外,表现最好的窗口子集,即实现最高测试准确度的子集,倾向于在子集大小增加时包含更困难的样本(较小的 β \beta β )。这符合这样一种直觉,即随着IPC
的增加,将来自真实数据集的复杂模式纳入模型可以增强其泛化能力。
基于这一观察,将 D syn \mathcal{D}_\textrm{syn} Dsyn 的初始化设置为 D initial \mathcal{D}_\textrm{initial} Dinitial ,其中 D initial \mathcal{D}_\textrm{initial} Dinitial 是由滑动窗口算法为给定 D syn \mathcal{D}_\textrm{syn} Dsyn 大小确定的表现最佳的窗口子集。这种方法确保了随后的提取过程从特定IPC
制度下经过优化的难度级别的图像开始。
在用滑动窗口算法选择的最佳窗口子集 D initial \mathcal{D}_\textrm{initial} Dinitial 对合成数据集 D syn \mathcal{D}_\textrm{syn} Dsyn 进行初始化后,下一个目标是通过数据集蒸馏来更新 D syn \mathcal{D}_\textrm{syn} Dsyn ,以便有效地将来自整个真实数据集 D real \mathcal{D}_\textrm{real} Dreal 的信息嵌入其中。传统上,匹配训练轨迹(MTT
)算法通过对 N N N 个模型更新进行反向传播,以最小化匹配损失公式1
,从而更新 D syn \mathcal{D}_\textrm{syn} Dsyn 中的所有样本。然而,如图1b
所示,这种方法倾向于数据集中更简单的模式,导致在连续提取迭代中覆盖范围的减少。因此,为了解决这个问题并保持一些真实样本的独特和复杂特征(对于模型在更大IPC
范围内的泛化能力至关重要),论文引入了对 D syn \mathcal{D}_\textrm{syn} Dsyn 的部分更新。
根据每个样本的难度分数,将初始合成数据集 D syn = D initial \mathcal{D}_\textrm{syn}=\mathcal{D}_\textrm{initial} Dsyn=Dinitial 划分为两个子集 D select \mathcal{D}_\textrm{select} Dselect 和 D distill \mathcal{D}_\textrm{distill} Ddistill 。子集 D select \mathcal{D}_\textrm{select} Dselect 包含 ( 1 − α ) × ∣ D syn ∣ (1-\alpha) \times |\mathcal{D}_\textrm{syn}| (1−α)×∣Dsyn∣ 个难度较高的样本,剩下的 α \alpha α 部分样本分配到 D distill \mathcal{D}_\textrm{distill} Ddistill ,其中 α ∈ [ 0 , 1 ] \alpha\in[0,1] α∈[0,1] 是根据IPC
调整的超参数。
在提取迭代过程中,保持 D select \mathcal{D}_\textrm{select} Dselect 不变,只更新 D distill \mathcal{D}_\textrm{distill} Ddistill 子集。更新的目标是最小化整个 D syn = D select ∪ D distill \mathcal{D}_\textrm{syn}=\mathcal{D}_\textrm{select}\cup \mathcal{D}_\textrm{distill} Dsyn=Dselect∪Ddistill 和 D real \mathcal{D}_\textrm{real} Dreal 之间的匹配损失,即:
KaTeX parse error: Undefined control sequence: \label at position 19: …egin{equation} \̲l̲a̲b̲e̲l̲{eq:partial_upd…
与最小化 L ( D distill , D real ) \mathcal{L}(\mathcal{D}_\textrm{distill}, \mathcal{D}_\textrm{real}) L(Ddistill,Dreal) 不同,仅更新部分 D syn \mathcal{D}_\textrm{syn} Dsyn 的损失策略鼓励 D distill \mathcal{D}_\textrm{distill} Ddistill 浓缩在 D select \mathcal{D}_\textrm{select} Dselect 中不存在的知识,从而丰富 D syn \mathcal{D}_\textrm{syn} Dsyn 中的整体信息。
在创建合成数据集 D syn \mathcal{D}_\textrm{syn} Dsyn 后,通过使用这个数据集训练一个随机初始化的神经网络来评估其有效性。通常情况下,先前的蒸馏方法采用了Dif- ferentiable Siamese Augmentation
(DSA
)来评估合成数据集。这种方法涉及比用于真实数据集常见的简单方法(如随机裁剪和水平翻转)更复杂的增强技术,在合成数据方面取得了更好的结果。这种提升的性能可能是因为合成数据集主要捕获了更简单的模式,使它们更适合于通过DSA
进行更强的增强。
然而,在整个合成数据集 D syn \mathcal{D}_\textrm{syn} Dsyn 上应用DSA
可能并非理想,特别是考虑到包含难以处理样本的子集 D select \mathcal{D}_\textrm{select} Dselect 的存在。为了解决这个问题,论文提出了一种专门针对论文的合成数据集定制的综合增强策略。具体而言,将DSA
应用于精炼部分 D distill \mathcal{D}_\textrm{distill} Ddistill ,并对选择的、更复杂的子集 D select \mathcal{D}_\textrm{select} Dselect 使用更简单、更传统的增强技术。这种综合方法旨在利用两种增强方法的优势,以提高合成数据集的整体性能。
将所有内容整合起来,SelMatch
在算法1
中进行了总结。
Experimental Results
如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】