MoH:将多头注意力(Multi-Head Attention)作为头注意力混合(Mixture-of-Head Attention)
摘要
https://arxiv.org/pdf/2410.11842?
在本文中,我们对Transformer模型的核心——多头注意力机制进行了升级,旨在提高效率的同时保持或超越先前的准确度水平。我们表明,多头注意力可以表示为求和形式。鉴于并非所有注意力头都具有同等重要性这一见解,我们提出了混合头注意力(MoH),这是一种将注意力头视为混合专家(MoE)机制中专家的新架构。MoH具有两大显著优势:首先,MoH使每个令牌能够选择适当的注意力头,从而在不影响准确度或增加参数数量的情况下提高推理效率。其次,MoH用加权求和替换了多头注意力中的标准求和,为注意力机制引入了灵活性,并释放了额外的性能潜力。在视觉Transformer(ViT)、扩散Transformer(DiT)和大型语言模型(LLMs)上的大量实验表明,MoH仅使用 50 % ∼ 90 % 50\% \sim 90\% 50%∼90%的注意力头便超越了多头注意力。此外,我们还证明了诸如LLaMA3-8B等预训练的多头注意力模型可以进一步微调成我们的MoH模型。值得注意的是,MoH-LLaMA3-8B在14个基准测试中实现了 64.0 % 64.0\% 64.0%的平均准确度,仅使用 75 % 75\% 75%的注意力头便比LLaMA3-8B高出 2.4 % 2.4\% 2.4%。我们认为,所提出的MoH是多头注意力的一个有前途的替代方案,并为开发先进且高效的基于注意力的模型奠定了坚实基础。
1 引言
自注意力机制被引入并成为Transformer(Vaswani等,2017)的基本组件以来,多头注意力一直是自然语言处理(Kenton&Toutanova,2019)和计算机视觉任务(Dosovitskiy等,2021)的标准架构。众所周知,使用多个头可以提高模型准确度。然而,并非所有注意力头都具有同等重要性。一些研究表明,可以在不影响准确度的情况下剪除许多注意力头。例如,Voita等(2019)提出了一种量化每个注意力头有用性的方法,并剪除了冗余的注意力头。同样,Michel等(2019)通过考察各种设置下大量剪枝的影响,对多个头的必要性提出了质疑。这些发现表明,传统的多头注意力包含冗余的注意力头。
此外,在多头注意力中,每个注意力头并行运行,最终输出是所有注意力头的和(请参阅第3.1节)。鉴于这些注意力头独立运行且部分可能冗余,我们认为构建动态注意力头路由机制是可能的。这种机制将使每个令牌能够自适应地选择适当的注意力头,从而在不影响准确度的情况下提高推理效率。
为此,我们引入了混合头注意力( M o H \mathrm{MoH} MoH),这是一种将多头注意力与混合专家(MoE)机制(Jacobs等,1991;Jin等,2024b)相结合的新架构。具体来说,我们提议在MoE框架中将注意力头视为专家。与 M o E \mathrm{MoE} MoE类似, M o H \mathrm{MoH} MoH由多个注意力头和一个为每个令牌激活前K个头的路由器组成。此外,我们用加权求和替换了多头注意力中的标准求和。这种设计带来了两大显著优势:首先,MoH允许每个令牌选择最相关的注意力头,从而在不牺牲准确度或增加参数的情况下提高推理效率。其次,通过将多头注意力中的标准求和替换为加权求和,MoH增强了注意力机制的灵活性,并提高了性能潜力。此外,为了有效捕获不同上下文中的通用知识,我们将部分注意力头指定为始终激活的共享头。
我们在各种流行的模型框架中评估了我们提出的MoH,包括用于图像分类的视觉Transformer(ViT)(Dosovitskiy等,2021)、用于类条件图像生成的带有Transformer的扩散模型(DiT)(Peebles&Xie,2023)以及用于语言任务的大型语言模型(LLMs)(Brown等,2020;OpenAI,2022;Ouyang等,2022)。我们表明,MoH仅使用 50 % ∼ 90 % 50\% \sim 90\% 50%∼90%的注意力头便实现了具有竞争力的性能,甚至超越了多头注意力。例如,MoH-ViT-B在ImageNet-1K(Deng等,2009)分类基准上实现了 84.9 % 84.9\% 84.9%和 84.7 % 84.7\% 84.7%的Top-1准确度,仅使用 75 % 75\% 75%和 50 % 50\% 50%的注意力头便超越了精心调整的多头注意力基线。
此外,我们还证明了诸如LLaMA3-8B(Dubey等,2024)等预训练的多头注意力模型可以进一步微调成我们的MoH模型。具体来说,仅使用原始LLaMA3预训练数据的约 3 % 3\% 3%(4000亿个令牌)进行微调,MoH-LLaMA3-8B在14个基准测试中实现了 64.0 % 64.0\% 64.0%的平均准确度,仅使用 75 % 75\% 75%的注意力头便比LLaMA3-8B高出 2.4 % 2.4\% 2.4%。这些结果表明,MoH是传统多头注意力的一个有前途的替代方案,为开发先进且高效的基于注意力的模型奠定了坚实基础。主要贡献总结如下:
- 我们提出了一种动态注意力头路由机制,使每个令牌能够自适应地选择适当的注意力头,从而在不增加参数数量的情况下提高模型性能和推理效率。
- 除了从头开始训练外,我们还证明了诸如LLaMA3-8B等预训练的多头注意力模型可以进一步微调成我们的MoH模型,极大地提高了所提出MoH方法的适用性。
- 在包括ViT、DiT和LLMs在内的各种流行模型框架中进行的广泛实验证实,MoH是传统多头注意力的一个有前途的替代方案,为开发先进且高效的基于注意力的模型奠定了坚实基础。
2 相关工作
多头注意力。Transformer(Vaswani等,2017)在自然语言处理和计算机视觉领域都获得了极大的关注和成功。Transformer的成功长久以来被归因于多头注意力机制(Cordonnier等,2020)。Vaswani等(2017)提出了多头注意力机制,通过允许多个注意力头对输入的不同低维投影进行操作,从而增强注意力层的表示能力。这些头的输出随后被拼接起来形成最终结果。另外,通过对输出投影矩阵按行分解,多头注意力可以表示为求和形式。在求和形式中,每个头并行操作,最终输出是所有头的和。受此观察启发,我们提出了MoH,一种动态注意力头路由机制,允许每个标记自适应地选择适当的头。
专家混合模型。专家混合(MoE)方法(Du等,2022;Lewis等,2021;Rajbhandari等,2022;Roller等,2021;Zhou等,2022;Jin等,2024b)被引入以在不增加计算成本的情况下扩展深度神经网络的容量。在这种方法中,对于每个输入,只有一部分参数(称为专家)被激活。Shazeer等(2017)首先在LSTM层之间引入了一个MoE层。Switch Transformer(Fedus等,2022)通过仅为每个标记选择Top-1专家进一步简化了门控机制。Gshard(Lepikhin等,2021)改进了Top-2专家路由策略。与强调在保持可控计算成本的同时实现高效参数扩展的MoE不同,所提出的MoH侧重于在不增加参数数量的情况下减少冗余注意力头的激活。
3 方法论
在这项工作中,我们的目标是减少冗余注意力头的激活,同时不增加参数数量。图1展示了标准多头注意力和我们提出的混合头注意力(MoH)之间的高级比较。
3.1 多头注意力
我们首先回顾由Vaswani等(2017)引入的标准多头注意力机制。多头注意力机制基于缩放点积注意力。具体而言,对于每个维度为 d i n d_{in} din的 T T T个标记 X ∈ R T × d i n \boldsymbol{X} \in \mathbb{R}^{T \times d_{in}} X∈RT×din和每个维度为 d i n d_{in} din的 T ′ T' T′个标记 X ′ ∈ R T ′ × d i n \boldsymbol{X}^{\prime} \in \mathbb{R}^{T^{\prime} \times d_{in}} X′∈RT′×din,缩放点积注意力计算如下:
Attention ( Q , K , V ) = Softmax ( Q K ⊤ d k ) V Q = X W Q , K = X ′ W K , V = X ′ W V \begin{array}{r} \text { Attention }(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V})=\operatorname{Softmax}\left(\frac{\boldsymbol{Q} \boldsymbol{K}^{\top}}{\sqrt{d_{k}}}\right) \boldsymbol{V} \\ \boldsymbol{Q}=\boldsymbol{X} \boldsymbol{W}_{Q}, \boldsymbol{K}=\boldsymbol{X}^{\prime} \boldsymbol{W}_{K}, \boldsymbol{V}=\boldsymbol{X}^{\prime} \boldsymbol{W}_{V} \end{array} Attention (Q,K,V)=Softmax(dkQK⊤)VQ=XWQ,K=X′WK,V=X′WV
其中, W Q ∈ R d i n × d k \boldsymbol{W}_{Q} \in \mathbb{R}^{d_{in} \times d_{k}} WQ∈Rdin×dk, W K ∈ R d i n × d k \boldsymbol{W}_{K} \in \mathbb{R}^{d_{in} \times d_{k}} WK∈Rdin×dk,和 W V ∈ R d i n × d v \boldsymbol{W}_{V} \in \mathbb{R}^{d_{in} \times d_{v}} WV∈Rdin×dv分别表示查询、键和值的投影矩阵。在自注意力中,输入标记是相同的,即 X ′ = X \boldsymbol{X}^{\prime}=\boldsymbol{X} X′=X,并且键和值维度通常相等,即 d v = d k d_{v}=d_{k} dv=dk。
拼接形式。为了增强注意力层的表示能力,Vaswani等(2017)提出允许多个注意力头对输入标记的不同低维投影进行操作。具体而言,多头注意力机制计算 ( Q , K , V ) (\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V}) (Q,K,V)的 h h h个不同的低维投影,对每个头执行缩放点积注意力,将结果拼接起来,并对拼接后的输出应用最终投影。多头注意力的拼接形式可以表示为:
MultiHead ( X , X ′ ) = Concat ( H 1 , H 2 , … , H h ) W O , H i = Attention ( X W Q i , X ′ W K i , X ′ W V i ) \begin{array}{c} \operatorname{MultiHead}\left(\boldsymbol{X}, \boldsymbol{X}^{\prime}\right)=\operatorname{Concat}\left(\boldsymbol{H}^{1}, \boldsymbol{H}^{2}, \ldots, \boldsymbol{H}^{h}\right) \boldsymbol{W}_{O}, \\ \boldsymbol{H}^{i}=\operatorname{Attention}\left(\boldsymbol{X} \boldsymbol{W}_{Q}^{i}, \boldsymbol{X}^{\prime} \boldsymbol{W}_{K}^{i}, \boldsymbol{X}^{\prime} \boldsymbol{W}_{V}^{i}\right) \end{array} MultiHead(X,X′)=Concat(H1,H2,…,Hh)WO,Hi=Attention(XWQi,X′WKi,X′WVi)
其中, W Q i ∈ R d i n × d k / h \boldsymbol{W}_{Q}^{i} \in \mathbb{R}^{d_{in} \times d_{k} / h} WQi∈Rdin×dk/h, W K i ∈ R d i n × d k / h \boldsymbol{W}_{K}^{i} \in \mathbb{R}^{d_{in} \times d_{k} / h} WKi∈Rdin×dk/h,和 W V i ∈ R d i n × d v / h \boldsymbol{W}_{V}^{i} \in \mathbb{R}^{d_{in} \times d_{v} / h} WVi∈Rdin×dv/h分别表示第 i i i个查询、键和值的投影矩阵。 W O ∈ R d v × d o u t \boldsymbol{W}_{O} \in \mathbb{R}^{d_{v} \times d_{out}} WO∈Rdv×dout是最终投影矩阵。
求和形式。多头注意力机制通常以拼接形式表示。然而,从另一个角度来看,如果我们按行分解 W O ∈ R d v × d o u t \boldsymbol{W}_{O} \in \mathbb{R}^{d_{v} \times d_{out}} WO∈Rdv×dout,我们可以将多头注意力表示为求和形式。具体而言, W O \boldsymbol{W}_{O} WO可以按行分为 h h h个矩阵,即 [ W O 1 , W O 2 , … , W O h ] = W O \left[\boldsymbol{W}_{O}^{1}, \boldsymbol{W}_{O}^{2}, \ldots, \boldsymbol{W}_{O}^{h}\right]=\boldsymbol{W}_{O} [WO1,WO2,…,WOh]=WO,其中 W O i ∈ R d v / h × d o u t \boldsymbol{W}_{O}^{i} \in \mathbb{R}^{d_{v} / h \times d_{out}} WOi∈Rdv/h×dout。最后,多头注意力的求和形式可以表示为:
MultiHead ( X , X ′ ) = ∑ i = 1 h H i W O i \operatorname{MultiHead}\left(\boldsymbol{X}, \boldsymbol{X}^{\prime}\right)=\sum_{i=1}^{h} \boldsymbol{H}^{i} \boldsymbol{W}_{O}^{i} MultiHead(X,X′)=i=1∑hHiWOi
拼接形式可以看作是求和形式的一种变体,其中所有注意力头的维度之和正好等于隐藏大小。如等式3所示,在标准多头注意力中,每个注意力头并行操作,最终输出是所有注意力头的和。由于这些注意力头独立工作,我们可以构建一个动态注意力头路由机制,允许每个标记自适应地选择最相关的注意力头,从而在不影响准确性的情况下提高推理效率。
3.2 多头混合注意力
最近,专家混合(Mixture-of-Experts, MoE)方法已成为扩展大型语言模型参数的一种流行方法(Jiang等人,2024)。典型的MoE层由多个专家网络和一个激活Top-K专家的路由器组成。通常,激活的专家数量 K K K远小于专家总数,以确保推理效率。
将头作为专家。受MoE巨大成功的启发,我们提出了多头混合注意力(Mixture-of-Head attention, MoH),将注意力头视为专家。具体来说,MoH包含 h h h个注意力头 H = { H 1 , H 2 , … , H h } \boldsymbol{H}=\left\{H^{1}, H^{2}, \ldots, H^{h}\right\} H={H1,H2,…,Hh}和一个激活Top-K头的路由器。形式上,给定输入标记 X \boldsymbol{X} X和 X ′ \boldsymbol{X}^{\prime} X′,MoH的输出是 K K K个选定头输出的加权和:
MoH ( X , X ′ ) = ∑ i = 1 h g i H i W O i \operatorname{MoH}\left(\boldsymbol{X}, \boldsymbol{X}^{\prime}\right)=\sum_{i=1}^{h} g_{i} \boldsymbol{H}^{i} \boldsymbol{W}_{O}^{i} MoH(X,X′)=i=1∑hgiHiWOi
其中, g i g_{i} gi表示路由分数。只有当第 i i i个注意力头被激活时, g i g_{i} gi才非零。这种设计提供了两个主要优势:一方面,MoH使每个标记能够选择最相关的注意力头,从而在保持准确性的同时提高推理效率。另一方面,与多头注意力中的标准求和相比,MoH中的加权求和增强了注意力机制的灵活性,并释放了性能潜力。
共享头。在注意力机制中,一些注意力头可能在不同上下文中捕获共同知识,如语言中的语法规则。受Dai等人(2024)的启发,我们指定一个头的子集作为始终激活的共享头。通过整合共享头中的共同知识,我们减少了其他动态路由头之间的冗余。
两阶段路由。此外,为了动态平衡共享头和路由头之间的权重,我们提出了一种两阶段路由策略。在这种路由策略中,路由分数由每个单独头的分数和与头类型相关的分数共同决定。具体来说,给定 X ∈ R T × d i n \boldsymbol{X} \in \mathbb{R}^{T \times d_{i n}} X∈RT×din中的第 t t t个输入标记 x t ∈ R d i n \boldsymbol{x}_{t} \in \mathbb{R}^{d_{i n}} xt∈Rdin,路由分数 g i g_{i} gi定义为:
g i = { α 1 Softmax ( W s x t ) i , if 1 ≤ i ≤ h s α 2 Softmax ( W r x t ) i , if ( W r x t ) i ∈ Top − K ( { ( W r x t ) i ∣ h s + 1 ≤ i ≤ h } ) 0 , otherwise g_{i}=\left\{\begin{array}{ll} \alpha_{1} \operatorname{Softmax}\left(\boldsymbol{W}_{s} \boldsymbol{x}_{t}\right)_{i}, & \text { if } 1 \leq i \leq h_{s} \\ \alpha_{2} \operatorname{Softmax}\left(\boldsymbol{W}_{r} \boldsymbol{x}_{t}\right)_{i}, & \text { if }\left(\boldsymbol{W}_{r} \boldsymbol{x}_{t}\right)_{i} \in \operatorname{Top}-\mathrm{K}\left(\left\{\left(\boldsymbol{W}_{r} \boldsymbol{x}_{t}\right)_{i} \mid h_{s}+1 \leq i \leq h\right\}\right) \\ 0, & \text { otherwise } \end{array}\right. gi=⎩ ⎨ ⎧α1Softmax(Wsxt)i,α2Softmax(Wrxt)i,0, if 1≤i≤hs if (Wrxt)i∈Top−K({(Wrxt)i∣hs+1≤i≤h}) otherwise
其中, h s h_{s} hs表示共享头的数量。 W s ∈ R h s × d i n \boldsymbol{W}_{s} \in \mathbb{R}^{h_{s} \times d_{i n}} Ws∈Rhs×din和 W r ∈ R ( h − h s ) × d in \boldsymbol{W}_{r} \in \mathbb{R}^{\left(h-h_{s}\right) \times d_{\text {in }}} Wr∈R(h−hs)×din 分别表示共享头和路由头的投影矩阵。系数 α 1 \alpha_{1} α1和 α 2 \alpha_{2} α2平衡共享头和路由头的贡献,定义为:
[ α 1 , α 2 ] = Softmax ( W h x t ) \left[\alpha_{1}, \alpha_{2}\right]=\operatorname{Softmax}\left(\boldsymbol{W}_{h} \boldsymbol{x}_{t}\right) [α1,α2]=Softmax(Whxt)
其中, W h ∈ R 2 × d i n \boldsymbol{W}_{h} \in \mathbb{R}^{2 \times d_{i n}} Wh∈R2×din是可训练的投影矩阵, d i n d_{i n} din是 x t \boldsymbol{x}_{t} xt的隐藏大小。
负载均衡损失。 直接训练MoE层通常会导致大多数标记被路由到少数专家,从而使剩余专家训练不足(Shazeer等人,2017)。为了避免所提出的MoH中的负载不平衡,我们遵循之前的MoE方法(Lepikhin等人,2021;Wei等人,2024),应用了一个负载均衡损失。具体来说,对于 X ∈ R T × d i n \boldsymbol{X} \in \mathbb{R}^{T \times d_{i n}} X∈RT×din中的第 t t t个输入标记 x t ∈ R d i n \boldsymbol{x}_{t} \in \mathbb{R}^{d_{i n}} xt∈Rdin,负载均衡损失 L b \mathcal{L}_{b} Lb定义为:
L b = ∑ i = h s + 1 h f i P i , f i = 1 T ∑ t = 1 T 1 ( Token x t selects Head i ) , P i = 1 T ∑ t = 1 T Softmax ( W r x t ) i \mathcal{L}_{b}=\sum_{i=h_{s}+1}^{h} f_{i} P_{i}, f_{i}=\frac{1}{T} \sum_{t=1}^{T} \mathbb{1}\left(\text { Token } \boldsymbol{x}_{t} \text { selects Head } i\right), P_{i}=\frac{1}{T} \sum_{t=1}^{T} \operatorname{Softmax}\left(\boldsymbol{W}_{r} \boldsymbol{x}_{t}\right)_{i} Lb=i=hs+1∑hfiPi,fi=T1t=1∑T1( Token xt selects Head i),Pi=T1t=1∑TSoftmax(Wrxt)i
其中, T T T表示标记的数量。 1 ( ∗ ) \mathbb{1}(*) 1(∗)表示指示函数。
总训练目标。值得注意的是,MoH是一个通用框架。因此,我们在各种流行的模型框架中评估了我们提出的MoH,包括视觉Transformer(Vision Transformers, ViT)、具有Transformer的扩散模型(Diffusion models with Transformers, DiT)和大型语言模型(Large Language Models, LLMs)。根据具体任务,我们需要任务特定的损失。最后,总训练损失是任务特定损失 L task \mathcal{L}_{\text {task }} Ltask 和负载均衡损失 L b \mathcal{L}_{b} Lb的加权和:
L = L task + β L b \mathcal{L}=\mathcal{L}_{\text {task }}+\beta \mathcal{L}_{b} L=Ltask +βLb
其中, β \beta β是用于减轻路由崩溃风险的折衷超参数。默认情况下,对于所有任务,负载均衡损失的权重 β \beta β设置为0.01。
4 实验
4.1 Vit在图像分类中的应用
模型设置。对于视觉Transformer(ViT)(Dosovitskiy等,2021),我们的MoH-ViT模型基于TransNeXt(Shi,2024)框架实现,并在ImageNet-1K数据集(Deng等,2009)上从头开始训练,该数据集包含1000个类别中的超过120万张图像。为确保公平比较,我们仅将标准多头注意力替换为提出的MoH,同时保持所有其他训练参数与TransNeXt相同。
训练细节。我们的MoH-ViT模型使用8个GPU上的自动混合精度训练了300个周期。我们遵循TransNeXt的训练策略,包括各种数据增强技术,如随机增强(Cubuk等,2020)、Mixup(Zhang,2017)、CutMix(Yun等,2019)和随机擦除(Zhong等,2020)。我们还应用了标签平滑(Szegedy等,2016)和DropPath(Huang等,2016)来正则化我们的模型。我们使用AdamW优化器(Loshchilov和Hutter,2017)优化模型,梯度裁剪范数为1.0,权重衰减为0.05。初始学习率设置为 1 e − 3 1 \mathrm{e}-3 1e−3,从 1 e − 6 1 \mathrm{e}-6 1e−6开始有5个周期的预热。采用余弦学习率调度器(Loshchilov和Hutter,2016)来衰减学习率。在训练过程中,图像被随机裁剪为 224 × 224 224 \times 224 224×224的大小。值得注意的是,我们没有使用指数移动平均(EMA)权重。
结果。如表1所示,尽管仅激活了注意力头的一个子集,但MoH-ViT与当前最先进的方法相比仍取得了极具竞争力的性能。例如,MoH-ViT-B在ImageNet-1K分类基准上仅以75%的注意力头实现了84.9%的Top-1准确率。相比之下,成熟的ViT基线TransNeXt虽然需要激活100%的注意力头,但其准确率却略低,为84.8%。表1表明, M o H − V i T \mathrm{MoH}-\mathrm{ViT} MoH−ViT在激活更少的注意力头的情况下优于其他模型。这表明,MoH是视觉模型设计中标准多头注意力的一个有前途的替代方案,通过更高效的注意力头使用提供了具有竞争力的性能潜力。
4.2 DiT用于类条件图像生成
模型设置。对于带有Transformer的扩散模型(DiT)(Peebles和Xie,2023),我们仅在 M o H − D i T \mathrm{MoH}-\mathrm{DiT} MoH−DiT模型中将标准多头注意力替换为我们的MoH,同时保持所有其他训练参数与DiT相同。我们使用ImageNet-1K数据集(Deng等,2009)以 256 × 256 256 \times 256 256×256的分辨率进行类条件图像生成。
训练细节。遵循DiT,最终线性层初始化为零,所有其他层遵循标准的ViT权重初始化。我们使用AdamW优化器(Loshchilov和Hutter,2017)训练所有模型,采用恒定的学习率 1 e − 4 1 \mathrm{e}-4 1e−4,无权重衰减,批量大小为256,并应用水平翻转进行数据增强。与DiT一致,我们在训练期间以0.9999的衰减率使用MoH-DiT权重的指数移动平均(EMA),并使用EMA模型生成所有图像。我们使用来自Stable Diffusion(Rombach等,2022)的现成的预训练变分自编码器(Kingma,2013)模型。与TransNeXt一致,我们的注意力头激活预算在各层之间分布不均,浅层激活的注意力头较少,深层激活的注意力头较多。
评估基准。为评估生成性能,我们使用弗雷谢特感知距离(FID)(Heusel等,2017)评估样本的整体质量,使用精确度和召回率(Kynkäänniemi等,2019)分别测量保真度和多样性,使用sFID(Nash等,2021)作为比FID更能捕捉空间关系的指标。此外,我们还使用感知得分(IS)(Salimans等,2016)作为另一个保真度指标。
结果。为对我们提出的MoH-DiT模型与标准DiT模型进行比较评估,我们从Small模型开始,扩展到XLarge模型。如表2所示,在激活90%的注意力头的情况下,MoHDiT模型始终优于标准DiT模型。然而,当仅激活75%的注意力头时,MoH-DiT模型的性能不如激活100%注意力头的DiT模型。这可能是因为图像生成任务是密集预测任务,需要注意力机制捕获像素级的细粒度关系,与图像分类任务相比,注意力头中的冗余更少。此外,我们将MoH-DiT-XL/2的训练预算扩展到7000K训练步骤,使其与DiT-XL/2保持一致。如表3所示,尽管仅激活了90%的注意力头,但MoH-DiT-XL/2与当前最先进的方法相比仍取得了极具竞争力的性能。这些结果表明,MoH是扩散模型中多头注意力的一个有前途的替代方案。
4.3 从头开始训练大型语言模型(LLMs)
模型设置。为从头开始训练LLMs,我们使用Megatron(Shoeybi等,2019)作为训练框架,这是一个开源训练代码。有关各种MoH-LLMs的详细超参数设置(表A),请参阅附录。所有模型均使用AdamW优化器(Loshchilov和Hutter,2017)进行训练,批量大小为400万个令牌,序列长度为2048。最终学习率设置为最大值的10%。在训练期间,应用0.1的权重衰减和1.0的梯度裁剪。对于LLM-S和MoH-LLM-S,最大学习率设置为3e-4。对于LLM-B和MoH-LLM-B,最大学习率设置为5e-4。
训练细节。我们仅使用公共数据集进行训练,确保学术研究的可访问性。具体而言,我们根据不同的采样概率从RedPajama(Computer,2023)、Dolma(Soldaini等,2024)和Pile(Gao等,2020)数据集中采样。有关详细的样本比例(表B),请参阅附录。遵循以前的工作,我们使用来自LLaMA2(Touvron等,2023)的分词器,其中包含65,536个词汇令牌。
评估基准。我们使用Eleuther AI语言模型评估框架(Gao等,2024)在多个基准上进行评估,该框架是测试生成式语言模型的统一框架。由于最小模型的参数仅为0.2B,我们选择6个简单基准作为评估指标。具体而言,我们报告了SciQ(Welbl等,2017)、PIQA(Bisk等,2020)、WinoGrande(Sakaguchi等,2021)、OpenbookQA(Mihaylov等,2018)、LogiQA(Liu等,2020)和TruthfulQA(Lin等,2022)上的零样本准确率。
结果。如表4所示,尽管仅激活了注意力头的一个子集,但多头混合(MoH)大型语言模型(LLMs)与我们的基线模型相比仍实现了极具竞争力的性能。例如,当仅激活50%的注意力头时,MoH-LLM的平均准确率达到了45.4%;相比之下,基线模型在激活全部(100%)注意力头的情况下,准确率略低,为43.9%。这些结果表明,对于从零开始训练LLMs而言,MoH有望成为标准多头注意力机制的一个有前景的替代方案。令人惊讶的是,我们发现对于MoH-LLM-S,仅激活50%的注意力头的表现优于激活75%的情况。我们认为,这可能是因为当模型和数据集都较小时,激活较少的注意力头可以有效地对模型进行正则化。然而,随着数据量的增加,激活更多的注意力头提供了更高的性能潜力。
4.4 继续调优LLAMA3-8B
模型设置。为了显著提升所提出的MoH方法的适用性,我们还尝试将预训练的多头注意力模型(如LLaMA3-8B)进一步继续调优为MoH模型。然而,这带来了三个挑战。(i)确定共享的注意力头:我们简单地选择每一层的前16个注意力头作为共享头。(ii)添加头路由器:将随机初始化的路由器集成到预训练模型中,同时又不损害其原始性能,需要谨慎的训练技术。为此,我们提出了一种无需参数的路由器,该路由器使用每个注意力头查询的$ \ell_{2} 范数来确定路由分数。( i i i )加权注意力头:我们观察到,对注意力头的输出进行加权会显著改变注意力层输出的分布,这需要大量的训练数据来恢复原始性能。为了解决这个问题,我们对路由分数进行量化,并使用直通估计器( B e n g i o 等人, 2013 ; L i u 等人, 2022 )通过稀疏函数反向传播梯度。具体来说,给定输入标记 范数来确定路由分数。(iii)加权注意力头:我们观察到,对注意力头的输出进行加权会显著改变注意力层输出的分布,这需要大量的训练数据来恢复原始性能。为了解决这个问题,我们对路由分数进行量化,并使用直通估计器(Bengio等人,2013;Liu等人,2022)通过稀疏函数反向传播梯度。具体来说,给定输入标记 范数来确定路由分数。(iii)加权注意力头:我们观察到,对注意力头的输出进行加权会显著改变注意力层输出的分布,这需要大量的训练数据来恢复原始性能。为了解决这个问题,我们对路由分数进行量化,并使用直通估计器(Bengio等人,2013;Liu等人,2022)通过稀疏函数反向传播梯度。具体来说,给定输入标记 \boldsymbol{x} $,我们为激活路由分数使用一个量化器,其前向传播公式为:
g i q = 1 ( Token x selects Head i ) g_{i}^{q}=\mathbb{1}(\text { Token } \boldsymbol{x} \text { selects Head } i) giq=1( Token x selects Head i),
其中,$ \mathbb{1}(*) 表示指示函数, 表示指示函数, 表示指示函数,g_{i}^{q}$表示量化的路由分数。然后,我们采用直通估计器,该估计器将传入的梯度分配给阈值操作作为传出的梯度,公式为:
∂ L ∂ g i q = ∂ L ∂ g i , \frac{\partial \mathcal{L}}{\partial g_{i}^{q}}=\frac{\partial \mathcal{L}}{\partial g_{i}}, ∂giq∂L=∂gi∂L,
其中, g i g_{i} gi表示实值路由分数。这个简单的近似函数显著缓解了梯度消失的问题(Wang等人,2024)。与从零开始训练LLMs类似,我们也使用Megatron(Shoeybi等人,2019)这一开源训练代码作为训练框架。
训练细节。我们发现,如果继续训练的数据与模型原始训练数据的分布存在差异,那么模型在训练初期的性能可能会剧烈波动。由于我们无法获取LLaMA3的原始训练数据,因此我们通过将训练过程分为两个阶段来解决这些潜在的性能波动问题。在第一阶段,我们使用3000亿个标记继续调优原始的LLaMA3-8B模型,以使模型适应我们的数据集。在第二阶段,我们使用1000亿个标记将这个经过适应的模型继续调优为我们提出的MoH模型。在第一阶段,最大学习率设置为 6 e − 5 6 \mathrm{e}-5 6e−5,最终学习率设置为 6 e − 6 6 \mathrm{e}-6 6e−6。在第二阶段,最大学习率设置为 2 e − 5 2 \mathrm{e}-5 2e−5,最终学习率设置为 1 e − 6 1 \mathrm{e}-6 1e−6。对于两个阶段,我们都使用了AdamW优化器(Loshchilov和Hutter,2017),批量大小为1600万个标记,序列长度为8192。在训练过程中,我们使用了0.1的权重衰减和1.0的梯度裁剪。
评估基准。我们使用Eleuther AI语言模型评估框架(Gao等人,2024)在多个关键基准上评估模型。具体来说,我们利用lm-evaluation-harness软件包来评估一系列下游任务的性能:(i)遵循Pythia(Biderman等人,2023),我们报告了LAMBADA(Paperno等人,2016)、LogiQA(Liu等人,2020)、PIQA(Bisk等人,2020)、SciQ(Welbl等人,2017)和WinoGrande(Sakaguchi等人,2021)的零样本准确率。(ii)我们报告了中文任务的准确率,包括5样本CEVAL(Huang等人,2023)和5样本CMMLU(Li等人,2023a)。(iii)我们报告了来自Open LLM排行榜(Beeching等人,2023)的任务的准确率,包括10样本HellaSwag(Zellers等人,2019)、25样本ARC挑战赛(ARC-C)(Clark等人,2018)和5样本MMLU(Hendrycks等人,2021)。(iv)我们报告了32样本Natural Questions(NQ)(Kwiatkowski等人,2019)的精确匹配分数和32样本BoolQ(Clark等人,2019)的准确率。(v)我们报告了8样本GSM8K(Cobbe等人,2021)的精确匹配分数,以评估数学能力。(vi)此外,我们还报告了TruthfulQA(Lin等人,2022)的零样本准确率,以评估生成真实答案的能力。
结果。如图2所示,在1000亿标记的训练预算内,MoH-LLaMA3-8B迅速恢复到原始模型95%以上的性能。在继续调优1000亿标记后,如表5所示,MoH-LLaMA3-8B在14个基准上的平均准确率为64.0%,仅使用75%的注意力头就比LLaMA3-8B高出2.4%。这些结果表明,预训练的多头注意力模型可以进一步继续调优为我们的MoH模型,从而显著提高MoH方法的适用性。
4.5 消融分析
所提MoH中各组件的影响。为了探索我们MoH方法中每个组件的影响,我们在表6中提供了消融结果。“共享头”指的是始终激活的注意力头子集。“两阶段路由”表示如等式5和等式6所述,在路由分数上平衡共享头和路由头权重的动态系数。如表6所示,共享头通过有效捕获通用知识,显著提高了模型性能,从而使路由头能够更专注于领域特定信息。此外,两阶段路由通过动态平衡共享头和路由头之间的权重,进一步提高了模型性能。我们的完整模型实现了最佳性能,这表明两个组件都对注意力机制有显著益处。
激活头中共享头的比例影响。在表7中,我们提供了关于激活头中共享头比例的消融研究。我们发现,模型性能在广泛的共享头比例范围内(从13.9%到74.0%)保持相对稳定。这些结果表明,只要共享头比例不过于极端,模型的性能就是稳定的。从另一个角度来看,共享头可以被视为Soft MoE(Puigcerver等人,2024)的一种形式。根据Soft MoE论文(Puigcerver等人,2024)的发现,我们建议在激活头中使用较高的共享头比例(大于40%)。
5 讨论
注意力头负载分布的可视化。如图3所示,我们观察到不同类别和任务主题下的注意力头分配存在显著差异,这表明多注意力头(MoH)模型通过采用不同的头分配模式来适应不同的任务。MoH的这一特性允许不同的注意力头专注于不同类型的任务,使得参数利用比多头注意力更高效。有关MoH-LLaMA3-8B的更多可视化展示和注意力头负载分布的详细分析,请参阅附录D。
MoH与MoA的区别。我们从以下三个方面阐明了MoH与MoA(Zhang等人,2022)之间的区别。首先,在动机方面,MoH的目标是在不增加参数数量的情况下提高注意力机制的效率和性能。相比之下,MoA与混合专家(MoE)的动机相同,即在保持推理成本较低的同时扩展模型参数。因此,MoH的模型设置比MoA更为严格。其次,在方法论方面,我们的MoH引入了共享头和两阶段路由来增强标准的MoE方法。更重要的是,我们表明,预训练的多头注意力模型可以进一步继续调整为我们的MoH模型,极大地提高了所提出的MoH方法的适用性。相比之下,MoA直接将多头注意力与MoE相结合。由于采用了共享的键和值,MoA必须从零开始训练,这限制了其适用性。最后,在模型框架方面,我们的MoH在各种流行的模型框架和任务中得到了验证,包括ViT、DiT和仅解码器的LLM,而MoA仅在用于语言任务的编码器-解码器架构上得到了验证。
6 结论
在本文中,我们介绍了MoH,它是多头注意力的一种有前景的替代方案。MoH使每个令牌能够自适应地选择适当的注意力头,从而在不增加参数数量的情况下提高模型性能和推理效率。在包括ViT、DiT和LLM在内的各种流行模型框架上进行的广泛实验表明,即使仅使用50%~90%的注意力头,MoH的性能也优于多头注意力。更令人鼓舞的是,我们表明,诸如LLaMA3-8B之类的预训练多头注意力模型可以进一步继续调整为我们的MoH模型,显著提高了所提出的MoH方法的适用性。这项工作朝着先进且高效的基于注意力的模型迈出了有希望的一步,可能对研究和工业界都具有重要意义并有所帮助。
代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Attention(nn.Module):LOAD_BALANCING_LOSSES = []def __init__(self, dim, input_resolution, num_heads=8, qkv_bias=True, attn_drop=0.,proj_drop=0., shared_head=0, routed_head=0):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_heads = num_headsself.head_dim = dim // num_headsself.temperature = nn.Parameter(torch.log((torch.ones(num_heads, 1, 1) / 0.24).exp() - 1)) # Initialize softplus(temperature) to 1/0.24.# Generate sequnce length scaleself.register_buffer("seq_length_scale", torch.as_tensor(np.log(input_resolution[0] * input_resolution[1])),persistent=False)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.query_embedding = nn.Parameter(nn.init.trunc_normal_(torch.empty(self.num_heads, 1, self.head_dim), mean=0, std=0.02))self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)# mlp to generate continuous relative position biasself.cpb_fc1 = nn.Linear(2, 512, bias=True)self.cpb_act = nn.ReLU(inplace=True)self.cpb_fc2 = nn.Linear(512, num_heads, bias=True)self.shared_head = shared_headself.routed_head = routed_headif self.routed_head > 0:self.wg = torch.nn.Linear(dim, num_heads - shared_head, bias=False)if self.shared_head > 0:self.wg_0 = torch.nn.Linear(dim, 2, bias=False)if self.shared_head > 1:self.wg_1 = torch.nn.Linear(dim, shared_head, bias=False)def forward(self, x, H, W, relative_pos_index, relative_coords_table):B, N, C = x.shape_x = x.reshape(B * N, C) if self.routed_head > 0:logits = self.wg(_x)gates = F.softmax(logits, dim=1)num_tokens, num_experts = gates.shape_, indices = torch.topk(gates, k=self.routed_head, dim=1)mask = F.one_hot(indices, num_classes=num_experts).sum(dim=1)if self.training:me = gates.mean(dim=0)ce = mask.float().mean(dim=0)l_aux = torch.mean(me * ce) * num_experts * num_expertsAttention.LOAD_BALANCING_LOSSES.append(l_aux)routed_head_gates = gates * maskdenom_s = torch.sum(routed_head_gates, dim=1, keepdim=True)denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)routed_head_gates /= denom_srouted_head_gates = routed_head_gates.reshape(B, N, -1) * self.routed_headqkv = self.qkv(x).reshape(B, -1, 3 * self.num_heads, self.head_dim).permute(0, 2, 1, 3)q, k, v = qkv.chunk(3, dim=1)# Use MLP to generate continuous relative positional biasrel_bias = self.cpb_fc2(self.cpb_act(self.cpb_fc1(relative_coords_table))).transpose(0, 1)[:,relative_pos_index.view(-1)].view(-1, N, N)# Calculate attention map using sequence length scaled cosine attention and query embeddingattn = ((F.normalize(q, dim=-1) + self.query_embedding) * F.softplus(self.temperature) * self.seq_length_scale) @ F.normalize(k, dim=-1).transpose(-2, -1) + rel_biasattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)if self.routed_head > 0:x = (attn @ v).transpose(1, 2) # B, N, head, dimif self.shared_head > 1:shared_head_weight = self.wg_1(_x)shared_head_gates = F.softmax(shared_head_weight, dim=1).reshape(B, N, -1) * self.shared_headelse:shared_head_gates = torch.ones((B, N, self.shared_head)).to(_x.device).to(_x.dtype) * self.shared_head if self.shared_head == 0:masked_gates = routed_head_gateselse:weight_0 = self.wg_0(_x)weight_0 = F.softmax(weight_0, dim=1).reshape(B, N, 2) * 2 shared_head_gates = torch.einsum("bn,bne->bne", weight_0[:,:,0], shared_head_gates)routed_head_gates = torch.einsum("bn,bne->bne", weight_0[:,:,1], routed_head_gates)masked_gates = torch.cat([shared_head_gates, routed_head_gates], dim=2)x = torch.einsum("bne,bned->bned", masked_gates, x)x = x.reshape(B, N, C)else:shared_head_weight = self.wg_1(_x)masked_gates = F.softmax(shared_head_weight, dim=1).reshape(B, N, -1) * self.shared_headx = (attn @ v).transpose(1, 2) # B, N, head, dimx = torch.einsum("bne,bned->bned", masked_gates, x)x = x.reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return x
类定义和初始化
class Attention(nn.Module):LOAD_BALANCING_LOSSES = []
Attention
类继承自nn.Module
,是PyTorch中所有神经网络模块的基类。LOAD_BALANCING_LOSSES
是一个类变量,用于存储负载均衡损失,但在这个实现中,它看起来并没有在类的其他部分被有效利用。
def __init__(self, dim, input_resolution, num_heads=8, qkv_bias=True, attn_drop=0.,proj_drop=0., shared_head=0, routed_head=0):
- 类的初始化方法,接受多个参数:
dim
:输入特征的维度。input_resolution
:输入数据的分辨率,通常是一个包含高度和宽度的元组。num_heads
:注意力头的数量,默认为8。qkv_bias
:qkv线性变换是否使用偏置项,默认为True。attn_drop
:注意力得分上的dropout率,默认为0。proj_drop
:最终投影后的dropout率,默认为0。shared_head
:共享注意力头的数量,默认为0。routed_head
:路由注意力头的数量,默认为0。
super().__init__()
- 调用父类的初始化方法。
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
- 确保
dim
可以被num_heads
整除,以便每个头获得相同数量的维度。
self.dim = dimself.num_heads = num_headsself.head_dim = dim // num_heads
- 初始化实例变量。
self.temperature = nn.Parameter(torch.log((torch.ones(num_heads, 1, 1) / 0.24).exp() - 1))
- 初始化温度参数,用于缩放注意力得分,通过softplus函数初始化为大约
1/0.24
。
self.register_buffer("seq_length_scale", torch.as_tensor(np.log(input_resolution[0] * input_resolution[1])),persistent=False)
- 注册一个不参与梯度计算的buffer,用于存储序列长度的对数尺度。
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- 初始化qkv线性变换。
self.query_embedding = nn.Parameter(nn.init.trunc_normal_(torch.empty(self.num_heads, 1, self.head_dim), mean=0, std=0.02))
- 初始化查询嵌入,使用截断正态分布初始化。
self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)
- 初始化dropout和投影层。
self.cpb_fc1 = nn.Linear(2, 512, bias=True)self.cpb_act = nn.ReLU(inplace=True)self.cpb_fc2 = nn.Linear(512, num_heads, bias=True)
- 初始化用于生成连续相对位置偏置的MLP。
self.shared_head = shared_headself.routed_head = routed_head
- 初始化共享和路由注意力头的数量。
if self.routed_head > 0:self.wg = torch.nn.Linear(dim, num_heads - shared_head, bias=False)if self.shared_head > 0:self.wg_0 = torch.nn.Linear(dim, 2, bias=False)
- 如果存在路由头,初始化相应的线性层。
if self.shared_head > 1:self.wg_1 = torch.nn.Linear(dim, shared_head, bias=False)
- 如果存在共享头,初始化相应的线性层。
前向传播
def forward(self, x, H, W, relative_pos_index, relative_coords_table):
- 前向传播方法,接受输入特征
x
、高度H
、宽度W
、相对位置索引relative_pos_index
和相对坐标表relative_coords_table
。
B, N, C = x.shape_x = x.reshape(B * N, C)
- 获取输入的形状,并将
x
重塑为二维张量,以便进行线性变换。
if self.routed_head > 0:# 路由头的相关计算...
- 如果存在路由头,执行一系列计算来路由注意力头。
qkv = self.qkv(x).reshape(B, -1, 3 * self.num_heads, self.head_dim).permute(0, 2, 1, 3)q, k, v = qkv.chunk(3, dim=1)
- 计算qkv,并分割成q、k、v。
rel_bias = self.cpb_fc2(self.cpb_act(self.cpb_fc1(relative_coords_table))).transpose(0, 1)[:,relative_pos_index.view(-1)].view(-1, N, N)
- 使用MLP计算相对位置偏置。
attn = ((F.normalize(q, dim=-1) + self.query_embedding) * F.softplus(self.temperature) * self.seq_length_scale) @ F.normalize(k, dim=-1).transpose(-2, -1) + rel_bias
- 计算注意力得分,包括查询嵌入、温度缩放、序列长度缩放和相对位置偏置。
attn = attn.softmax(dim=-1)attn = self.attn_drop(attn)
- 对注意力得分应用softmax和dropout。
if self.routed_head > 0:# 路由头的后续处理...else:# 共享头的处理...
- 根据是否存在路由头,执行不同的处理逻辑。
x = self.proj(x)x = self.proj_drop(x)return x
- 应用最终投影和dropout,返回输出。
这个类实现了一个复杂的注意力机制,包括路由注意力头、共享注意力头和连续相对位置偏置等特性。这些特性使得这个注意力层能够处理更复杂的输入和输出关系,适用于各种深度学习模型。