当前位置: 首页 > news >正文

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)


文章目录

  • 【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)
  • 1. 交叉注意力的起源与提出
  • 2. 交叉注意力的原理
  • 3. 交叉注意力的数学表示
  • 4. 交叉注意力的应用场景与发展
  • 5. 代码实现
  • 6. 代码解释
  • 7. 总结


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

1. 交叉注意力的起源与提出

交叉注意力(Cross-Attention)是在深度学习中提出的一种重要注意力机制,用于在多个输入之间建立关联,主要用于多模态任务中(如图像和文本、视频和音频的联合处理)。

与常规的自注意力机制不同,交叉注意力专注于从两个不同的输入特征空间中提取和结合关键信息。这种机制最初在自然语言处理和计算机视觉的融合任务中得到应用,例如在多模态Transformer、机器翻译和图像-文本任务(如CLIP、DALL·E、VQA等)中。

  • 提出背景:交叉注意力通常用于处理两种不同类型的数据,通过这种机制,一个输入可以对另一个输入进行查询,捕捉和增强跨模态之间的关联。相比自注意力(仅在同一个输入中找到相关性),交叉注意力能够有效地捕捉多模态数据的交互信息。

2. 交叉注意力的原理

交叉注意力的核心思想是将一个输入(例如图像)作为查询(Query),另一个输入(例如文本)作为键(Key)和值(Value),通过注意力机制让查询能够从键和值中选择和关注相关信息。

交叉注意力的步骤:

  • 查询、键、值的生成: 假设有两个不同的输入数据 X1 和 X2,分别生成对应的 Query、Key 和 Value 矩阵。对于 X1,我们可以生成 Query 矩阵,而对于 X2,则可以生成 Key 和 Value 矩阵。
  • 注意力计算: 与自注意力类似,交叉注意力通过计算 Query 和 Key 的相似性来获得注意力权重:
    在这里插入图片描述
    其中 Q 来自 X1,而 K 和 V 来自 X2 。通过这种计算,Query 可以从X2 中提取与其最相关的信息,这种机制实现了两个输入数据之间的特征融合和信息传递。
  • 权重与输出: 计算出的注意力权重应用到 X2的 Value 矩阵上,得到 X1在
    X2上的相关信息。这种机制实现了两个输入数据之间的特征融合和信息传递。

3. 交叉注意力的数学表示

假设有两个输入特征 X 1 ∈ R T 1 × d X_1∈R^{T_1×d} X1RT1×d X 2 ∈ R T 2 × d X_2∈R^{T_2×d} X2RT2×d,其中 T 1 T_1 T1 T 2 T_2 T2分别表示两个输入的长度(如序列长度或特征维度), d d d 表示特征维度。

Query、Key 和 Value 的生成:

  • 对于 X 1 X_1 X1:生成查询矩阵 Q = W q X 1 Q=W_qX_1 Q=WqX1
  • 对于 X 2 X_2 X2:生成键矩阵 K = W k X 2 K=W_kX_2 K=WkX2和值矩阵 V = W v X 2 V=W_vX_2 V=WvX2

注意力计算:
在这里插入图片描述
其中, W q W_q Wq W k W_k Wk W v W_v Wv ∈ R d × d ∈R^{d×d} Rd×d是线性变换矩阵, d d d 是键的维度。

结果输出: 注意力权重应用于 V V V 后的结果,即:
在这里插入图片描述

4. 交叉注意力的应用场景与发展

交叉注意力在以下场景中得到广泛应用:

  • 多模态学习:交叉注意力在视觉和语言任务中的多模态联合建模中尤为常见,如图像与文本的对齐(CLIP)、视觉问答(VQA)和跨模态生成任务(如DALL·E)。
  • 机器翻译:交叉注意力在Transformer中的"解码器"部分用于让生成的序列(目标语言)参考源语言的表示,这大大提高了翻译质量。
  • Transformer架构的扩展:在诸如BERT、GPT等基于Transformer的模型中,交叉注意力也被用于各种任务,例如文本生成、序列到序列任务等。

发展过程中,交叉注意力机制已经被改进和扩展。例如,层次化交叉注意力(Hierarchical Cross-Attention)通过在不同层次上融合多模态信息,进一步提升了模型在多模态任务上的性能。

5. 代码实现

下面是一个基于PyTorch的交叉注意力机制的简单实现,用于展示如何在两个不同的输入(例如图像和文本)之间计算交叉注意力。

import torch
import torch.nn as nnclass CrossAttention(nn.Module):def __init__(self, dim, num_heads=8, dropout=0.1):super(CrossAttention, self).__init__()self.num_heads = num_headsself.dim = dimself.head_dim = dim // num_headsassert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"# 线性变换,用于生成 Q, K, V 矩阵self.q_proj = nn.Linear(dim, dim)self.k_proj = nn.Linear(dim, dim)self.v_proj = nn.Linear(dim, dim)# 输出的线性变换self.out_proj = nn.Linear(dim, dim)self.dropout = nn.Dropout(dropout)self.softmax = nn.Softmax(dim=-1)def forward(self, x1, x2):# x1 是 Query,x2 是 Key 和 ValueB, T1, C = x1.shape  # x1 的形状: [batch_size, seq_len1, dim]_, T2, _ = x2.shape  # x2 的形状: [batch_size, seq_len2, dim]# 生成 Q, K, V 矩阵Q = self.q_proj(x1).view(B, T1, self.num_heads, self.head_dim).transpose(1, 2)K = self.k_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)# 计算注意力得分attn_scores = (Q @ K.transpose(-2, -1)) / (self.head_dim ** 0.5)attn_weights = self.softmax(attn_scores)  # 注意力权重attn_weights = self.dropout(attn_weights)  # dropout 防止过拟合# 使用注意力权重加权值矩阵attn_output = attn_weights @ Vattn_output = attn_output.transpose(1, 2).contiguous().view(B, T1, C)# 输出线性变换output = self.out_proj(attn_output)return output# 测试交叉注意力机制
if __name__ == "__main__":B, T1, T2, C = 2, 10, 20, 64  # batch_size, seq_len1, seq_len2, channelsx1 = torch.randn(B, T1, C)  # Query 输入x2 = torch.randn(B, T2, C)  # Key 和 Value 输入cross_attn = CrossAttention(dim=C, num_heads=4)output = cross_attn(x1, x2)print("输出形状:", output.shape)  # 输出应该为 [batch_size, seq_len1, channels]

6. 代码解释

CrossAttention 类:该类实现了交叉注意力机制,允许将两个不同的输入(x1x2)进行交叉信息融合。

  • q_proj, k_proj, v_proj:三个线性层,用于将输入分别映射到 Query、Key 和 Value 空间。
  • num_headshead_dim:定义了多头注意力机制的头数和每个头的维度。

forward 函数:实现前向传播过程。

  • Q, K, V:分别从 x1x2 中生成 Query、Key 和 Value 矩阵,形状为 [batch_size, num_heads, seq_len, head_dim]
  • attn_scores:计算 Query 和 Key 的点积,得到注意力得分。
  • attn_weights:通过 softmax 对得分进行归一化,得到注意力权重。
  • attn_output:利用注意力权重对 Value 矩阵进行加权求和,得到最终的注意力输出。

测试部分:随机生成两个输入张量 x1x2,并测试交叉注意力的输出形状,确保与预期一致。

7. 总结

交叉注意力在多模态学习中起到了至关重要的作用,能够有效融合不同类型的数据,使得模型可以同时处理图像、文本等多种信息。通过捕捉模态之间的相关性,交叉注意力为多模态任务中的特征融合提供了强大的工具。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz


http://www.mrgr.cn/news/58131.html

相关文章:

  • 无需多言,简单粗暴上手MybatisPlus
  • springboot RedisTemplate支持多个序列化方式
  • 华为eNSP:端口安全
  • RAG技术
  • 接口自动化-框架搭建(Python+request+pytest+allure)
  • 3GPP协议解读_NTN系列(一)_38.811_非地面网络(NTN)的背景、应用场景和信道建模
  • 附录章节:SQL标准与方言对比
  • 【已解决】【hadoop】如何解决Hive连接MySQL元数据库的依赖问题
  • 【C++】位图
  • ruoyi域名跳转缓存冲突问题(解决办法修改:session名修改session的JSESSIONID名称)
  • C/C++(六)多态
  • OpenCV KeyPoint与描述子编解码
  • rtsp的2种收流模式
  • Qt 智能指针QScopedPoint用法
  • 【已解决】【hadoop】【hive】启动不成功 报错 无法与MySQL服务器建立连接 Hive连接到MetaStore失败 无法进入交互式执行环境
  • Golang | Leetcode Golang题解之第507题完美数
  • 将二维图像映射到三维场景使用NeRF在AMD GPU上
  • <自用> python 更新库命令
  • Codeforces Round 981 div3 个人题解(A~G)
  • AI学习指南深度学习篇-自注意力机制(Self-Attention Mechanism)
  • 基于 Python 的自然语言处理系列(43):Question Answering
  • 【C++差分数组】P10903 商品库存管理
  • 003:无人机概述
  • 【MySQL】数据库约束和多表查询
  • Hugging Face HUGS 加快了基于开放模型的AI应用的开发
  • 前端方案:播放的视频加水印或者文字最佳实践