当前位置: 首页 > news >正文

SamOut 推理空间不变模型解析

项目地址

SamOutV2 0.18B模型

  • 采取 em参数共享在参数量减半的情况下将维度从1024 拉升到了1536
  • sft 单论对话 loss 保持1.8
  • 如果未来匹配state 推理代码性能不变的同时推理任意长度使用资源空间保持不变
import torchclass MaxState(torch.nn.Module):def __init__(self, hidden_dim, heads):super(MaxState, self).__init__()assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."self.head_size = hidden_dim // headsself.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.head_num = headsself.hidden = hidden_dimdef forward(self, input_data, state=None):b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_sizeout = self.head0(input_data)out1 = self.head1(input_data)out2 = self.head2(input_data)out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])if state is None:out = torch.cummax((out + out1) / h ** 0.5, 2)[0]else:out = torch.cummax(torch.concat([state, (out + out1)/ h ** 0.5], 2) , 2)[0]state1 = out[:, :, -1:]out = out.permute([0, 2, 1, 3])out1 = out1.permute([0, 2, 1, 3])if state is None:out = out.reshape([b, s, -1])out1 = out1.reshape([b, s, -1])else:out = out[:, -1:].reshape([b, 1, -1])out1 = out1[:, -1:].reshape([b, 1, -1])out = (out + out2) * out + out1return out, state1class FeedForward(torch.nn.Module):def __init__(self, hidden_size):super(FeedForward, self).__init__()self.ffn1 = torch.nn.Linear(hidden_size, hidden_size // 2)self.ffn2 = torch.nn.Linear(hidden_size // 2, hidden_size)self.gate = torch.nn.Linear(hidden_size, hidden_size // 2)self.relu = torch.nn.ReLU()self.gr = torch.nn.Dropout(0.1)def forward(self, x):x1 = self.ffn1(x)x2 = self.relu(self.gate(x))xx = x1 * x2x = self.gr(self.ffn2(xx))return xclass DecoderLayer(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(DecoderLayer, self).__init__()self.self_attention = MaxState(hidden_size, num_heads)self.ffn = FeedForward(hidden_size)self.layer_norm = torch.nn.LayerNorm(hidden_size)self.alpha = torch.nn.Parameter(torch.tensor(0.5))def forward(self, x, state=None, ):x1, state = self.self_attention(x, state)x = self.layer_norm(self.alpha * self.ffn(x1) + (1 - self.alpha) * x)return x, stateclass SamOut(torch.nn.Module):def __init__(self, voc_size, hidden_size, num_heads, num_layers):super(SamOut, self).__init__()self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])self.head = FeedForward(hidden_size)def state_forward(self, state, x):if state is None:state = [None] * len(self.decoder_layers)i = 0for ii, decoder_layer in enumerate(self.decoder_layers):x1, state[i] = decoder_layer(x, state[i])x = x1 + xi += 1return x, statedef forward(self, x, state=None):x = self.em(x)x, state = self.state_forward(state, x)em = self.head(self.em.weight) / x.shape[-1]return x @ em.permute([1, 0]), stateif __name__ == '__main__':net = SamOut(235, 256, 16, 4)net(torch.randint(0, 200, [2, 8 * 13]))

解析该代码为什么推理空间不变,需要深入理解其设计和实现细节。首先,让我们逐步分析这段代码中的各个组件及其工作原理。

MaxState 类

MaxState 类定义了一个多头机制(multi-head mechanism),它可能是为了模仿Transformer架构中的多头注意力机制。这个类中定义了三个线性变换层 (self.head0, self.head1, 和 self.head2) 来处理输入数据,并通过累积最大值操作 (torch.cummax) 来更新状态。这种设计确保了在每次前向传播时,模型能够保留到目前为止遇到的最大值,从而保持了一种记忆效应。这有助于维持推理过程中的上下文信息而不改变空间维度。

class MaxState(torch.nn.Module):def __init__(self, hidden_dim, heads):...def forward(self, input_data, state=None):...if state is None:out = torch.cummax((out + out1) / h ** 0.5, 2)[0]else:out = torch.cummax(torch.concat([state, (out + out1)/ h ** 0.5], 2) , 2)[0]...

FeedForward 类

FeedForward 类实现了简单的前馈神经网络(FFN),它包括两个线性层以及门控机制来控制信息流。此模块负责执行空间变换,但它的输出形状与输入相同,因此不会改变特征图的空间尺寸。

class FeedForward(torch.nn.Module):def __init__(self, hidden_size):...def forward(self, x):...return x

DecoderLayer 类

DecoderLayer 类结合了自注意力机制(由 MaxState 实现)和前馈网络(FeedForward)。这里引入了一个可学习参数 alpha 来调整来自这两个子模块的信息混合比例。重要的是,层归一化(LayerNorm)被应用于最终输出之前,以稳定训练并帮助梯度流动。此外,由于 MaxState 的设计保证了状态的持续更新而没有改变序列长度或隐藏维度,所以整个解码器层也不会改变推理空间。

class DecoderLayer(torch.nn.Module):def __init__(self, hidden_size, num_heads):...def forward(self, x, state=None, ):...return x, state

SamOut 类

SamOut 类是整个模型的核心,它包含了嵌入层、多个解码器层组成的列表(ModuleList),以及一个最终的前馈网络用于生成预测结果。关键点在于:

  • Embedding 层:将词汇索引转换为密集向量表示。
  • 解码器层堆叠:通过循环调用每个解码器层来进行多次迭代处理。每层都接收当前状态作为输入,并返回更新后的状态。
  • 最终投影:使用嵌入矩阵转置乘以前一层输出的方式计算logits。这种方式确保了输出维度与词汇表大小相匹配,同时保持了输入序列的时间步数不变。
class SamOut(torch.nn.Module):def __init__(self, voc_size, hidden_size, num_heads, num_layers):...def state_forward(self, state, x):...return x, statedef forward(self, x, state=None):...return x @ em.permute([1, 0]), state

推理空间不变的原因

综上所述,SamOut 模型之所以能够在推理过程中保持空间不变,是因为从输入到输出的所有操作都被精心设计以保持时间序列的长度一致。无论是通过自定义的 MaxState 进行状态更新还是通过标准的 FeedForward 进行非线性变换,这些操作都不会减少或增加时间步的数量。即使在应用了多层解码器之后,序列的长度依然保持不变,只是特征表示得到了增强或转换。最后,在生成预测时,模型采用了基于嵌入矩阵的操作,这也保证了输出的概率分布对应于原始输入序列的每一个位置,因此推理空间在整个过程中是稳定的。

这种特性对于诸如语言模型等应用场景非常重要,因为它们通常需要根据给定的历史信息逐个生成新词,而不能随意改变文本序列的结构。通过这种方式,模型可以在不破坏原有顺序的情况下有效地捕捉长期依赖关系,并且在生成过程中逐步构建合理的句子或段落。


http://www.mrgr.cn/news/80748.html

相关文章:

  • leetcode 面试经典 150 题:无重复字符的最长子串
  • 全面详解 g++ 编译器命令行选项
  • 网络编程 03:端口的定义、分类,端口映射,通过 Java 实现了 IP 和端口的信息获取
  • CSS|12 display属性
  • 放弃机器学习框架,如何用Python做物体检测?
  • Java中的垃圾收集器
  • [SZ901]程序固化工具速度对比
  • 【Maven】基础(一)
  • 排序算法深度好文(图解 + 代码解析 + 误区 QA )——学排序看这一篇就够了!!!
  • 洛谷P3879 [TJOI2010] 阅读理解(c嘎嘎)
  • 【CSS in Depth 2 精译_085】14.2:CSS 蒙版的用法
  • 无刷电机的概念
  • Linux:进程通信、管道通信
  • PYQT5程序框架
  • Go-FastDFS文件服务器一镜到底使用Docker安装
  • 【AI图像生成网站Golang】项目架构
  • 基础数据结构---栈
  • linux_x64 下的一般汇编函数与syscall调用约定
  • 安卓换源资源记录
  • 修改ubuntu apt 源及apt 使用
  • HW机试题库(个人总结)
  • Fast-Planner项目复现(Ubuntu 20.04 ROS Noetic)
  • 设计模式2
  • harbor离线安装 配置https 全程记录
  • Flutter环境搭建
  • vue复习