当前位置: 首页 > news >正文

24/11/12 算法笔记<强化学习> 自注意力机制

自注意力机制(Self-Attention Mechanism),也称为内部注意力机制,是一种在深度学习模型中,特别是在自然语言处理(NLP)和计算机视觉领域中广泛使用的机制。它允许模型在处理序列数据时,能够动态地聚焦于序列的不同部分,从而捕捉到序列内部的长距离依赖关系。

自注意力机制的核心思想是,序列中的每个元素都与其他所有元素相关,模型需要学习如何根据上下文信息来分配不同的注意力权重。这种机制最早在Transformer模型中被提出,并在随后的研究中被广泛应用于各种任务。

自注意力机制的工作原理

  1. 输入表示:模型首先将输入序列(如句子或图像)转换为一系列向量表示,这些向量通常通过嵌入层(Embedding Layer)得到。

  2. 查询(Query)、键(Key)和值(Value):对于序列中的每个元素,模型会生成三个向量:查询(Q)、键(K)和值(V)。在原始的Transformer模型中,这些向量是通过输入向量与三个不同的权重矩阵相乘得到的。

  3. 计算注意力分数:模型计算每个查询向量与所有键向量之间的相似度或匹配程度,得到一个注意力分数矩阵。这个分数矩阵通常通过点积(Dot Product)或缩放点积(Scaled Dot-Product)得到。

  4. 应用 softmax 函数:注意力分数矩阵通过softmax函数进行归一化,使得每一行的和为1。这样,每个查询向量都会得到一个概率分布,表示对其他元素的注意力权重。

  5. 加权和:每个查询向量根据学到的权重,对所有值向量进行加权求和,得到最终的输出向量。

  6. 输出:自注意力层的输出可以是序列中的每个元素对应的加权和向量,这些向量可以被用作后续任务的输入,如分类、翻译等。

多头自注意力

为了捕捉不同子空间中的信息,Transformer模型引入了多头自注意力机制。在多头自注意力中,模型并行地执行多次自注意力操作,每个“头”使用不同的权重矩阵来生成查询、键和值。最后,所有头的输出被拼接在一起,并通过一个线性层进行处理,以产生最终的输出。

自注意力机制的优势在于其能够处理序列数据中的长距离依赖,并且不受传统循环神经网络(RNN)中序列长度的限制。此外,由于其并行化的特性,自注意力模型通常比RNN模型训练得更快。

我们来看下它的代码

1.导入必要的库

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.nn模块用于构建神经网络层和初始化参数,torch.nn.functional包含了各种不带有权重的函数式接口。
2.定义自注意力类

class SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // heads

embed_size表示嵌入向量的维度,heads表示注意力头的数量。head_dim是每个头的维度,它通过将embed_size除以heads得到。

3.检查嵌入维度是否能被头数整除

        assert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads"

这里使用断言语句来确保嵌入维度可以被头数整除,这是实现多头注意力机制的前提条件。

4.初始化线性层

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

这部分代码初始化了四个线性层(全连接层),分别用于计算值(values)、键(keys)、查询(queries)和输出。由于我们使用的是多头注意力机制,所以需要将输入的嵌入向量分割成多个头,每个头都有自己的线性层。最后一个线性层fc_out用于将多头的输出合并回原始的嵌入维度。

5.前向传播

    def forward(self, value, key, query):N = query.shape[0]value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

forward方法中,我们首先获取输入张量的形状信息。N是批大小,value_lenkey_lenquery_len分别是值、键和查询序列的长度。

6.分割嵌入向量

        values = self.values(value).view(N, value_len, self.heads, self.head_dim)keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)

这里我们使用线性层处理输入的值、键和查询,并将结果分割成多个头。view方法用于重新塑形张量,以适应多头注意力机制的需要。

  1. 通过.view()方法,将变换后的数据重塑为一个新的形状。这里的新形状是(N, value_len, self.heads, self.head_dim),其中:

    • N是批次大小(batch size)。
    • value_lenkey_lenquery_len分别是值、键和查询序列的长度。
    • self.heads是注意力头的数量。
    • self.head_dim是每个头的维度。

7.调整张量维度以适应多头注意力

        values = values.permute(0, 2, 1, 3)keys = keys.permute(0, 2, 1, 3)queries = queries.permute(0, 2, 1, 3)

通过permute方法,我们调整张量的维度

8.计算注意力分数

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

使用torch.einsum计算查询和键之间的点积,得到注意力分数。然后,我们对这些分数应用softmax函数,以获得每个头的注意力权重。除以embed_size的平方根是缩放点积的一种常见做法,有助于稳定训练过程

9.应用注意力权重

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).permute(0, 2, 1, 3)

这里我们再次使用torch.einsum将注意力权重应用到值上,得到加权的值。然后,我们调整张量的维度,以便于后续的合并操作。

10.合并多头注意力

        out = out.reshape(N, query_len, self.heads * self.head_dim)out = self.fc_out(out)

11.使用自注意力机制

embed_size = 256  # 嵌入向量的维度
heads = 8  # 注意力头的数量value = torch.rand(32, 10, embed_size)
key = torch.rand(32, 10, embed_size)
query = torch.rand(32, 10, embed_size)attention = SelfAttention(embed_size, heads)
output = attention(value, key, query)print(output.shape)  # 应该输出:torch.Size([32, 10, 256])

这部分代码展示了如何使用上面定义的SelfAttention类。我们创建了一个SelfAttention实例,并传入随机生成的值、键和查询张量。然后,我们打印输出张量的形状,以验证其正确性。


http://www.mrgr.cn/news/71324.html

相关文章:

  • P10打卡——pytorch实现车牌识别
  • 第 5 场 算法季度赛
  • 使用 Python 实现自动化办公(邮件、Excel)
  • 算法题(33):长度最小的子数组
  • Android实战经验篇-增加系统分区
  • http://noi.openjudge.cn/——3.9数据结构之C++STL——【3342:字符串操作】
  • 【vs主程序 链接 实时生成库的问题】
  • HTTP状态码详解
  • 接口自动化测试实战(全网唯一)
  • integer==与equals()结果不同
  • Node.js笔记
  • 卸载 Python
  • 微澜:用 OceanBase 搭建基于知识图谱的实时资讯流的应用实践
  • 内网穿透,打通远程和本地调试部署测试
  • 软件测试项目实战
  • 使用docker安装mysql8
  • 炼石亮相密码丰会,探索从密码合规到数据安全实战防护
  • qt QSerialPortInfo详解
  • 机器视觉和计算机视觉的区别
  • 阿兰图灵的人工智能艺术作品以 100 万美元的价格售出
  • 创意加速器3个AI工具,让创作速度超光速!
  • 【数字静态时序分析】复杂时钟树的时序约束SDC写法
  • 力扣 LeetCode 704. 二分查找(Day1:数组)
  • 移门缓冲支架的作用与优势
  • 继承的学习
  • 虚拟机的安装