当前位置: 首页 > news >正文

【NLP】GRU基本结构原理,代码实现

LSTM变种GRU

GRU是LSTM改进的门控循环神经网络,将输入门,遗忘门,输出门变成更新门和重置门。

将细胞状态和隐藏状态合并,只有当前时刻候选状态和当前时刻隐藏状态。

【NLP】LSTM结构,原理,代码实现,序列池化-CSDN博客

模型结构

在这里插入图片描述

内部结构
在这里插入图片描述

相较于LSTM,GRU的结构更加简洁,参数更少,计算效率更高

可以类比LSTM理解GRU,同样都是门控机制

重置门

在这里插入图片描述

决定了保留多上一个时间步的信息和当前的信息合并输入

候选门

在这里插入图片描述

最终隐藏状态

在这里插入图片描述

代码实现

原生代码实现

import numpy as npclass GRU():def __init__(self,input_size,hidden_size):self.input_size = input_sizeself.hidden_size = hidden_size# 初始化权重参数# 跟新门self.W_z = np.random.randn(hidden_size,hidden_size+input_size)self.b_z = np.zeros(hidden_size)# 重置门self.W_r = np.random.randn(hidden_size,hidden_size+input_size)self.b_r = np.zeros(hidden_size)# 候选隐藏状态self.W_h = np.random.randn(hidden_size,hidden_size+input_size)self.b_h = np.zeros(hidden_size)def tanh(self,x):return np.tanh(x)def sigmoid(self,x):return 1/(1+np.exp(-x))def forward(self,x):# 初始化隐藏状态h_prev = np.zeros((self.hidden_size,))concat_input = np.concatenate([x,h_prev],axis=0)z_t = self.sigmoid(np.dot(self.W_z,concat_input)+self.b_z)r_t = self.sigmoid(np.dot(self.W_r,concat_input)+self.b_r)concat_reset_input = np.concatenate([x,r_t*h_prev],axis=0)h_hat_t = self.tanh(np.dot(self.W_h,concat_reset_input)+self.b_h)h_t = (1-z_t)*h_prev + z_t*h_hat_treturn h_t# 测试数据
input_size = 3
hidden_size = 2
seq_len = 4
x = np.random.randn(seq_len,input_size)gru = GRU(input_size, hidden_size)
all_h = []
for t in range(seq_len):h_t = gru.forward(x[t,:])all_h.append(h_t)print(h_t.shape)all_h = np.array(all_h)
print(all_h.shape)

基于PyTorch的GURcell

import torch
import torch.nn as nn
import numpy as npclass GRUcell(nn.Module):def __init__(self,input_size, hidden_size):super().__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.gru_cell = nn.GRUCell(input_size,hidden_size)def forward(self,x):h_t = self.gru_cell(x)return h_tinput_size = 3
hidden_size = 2
seq_len = 2x = torch.randn(seq_len,input_size)
grucell = GRUcell(input_size, hidden_size)
for t in range(seq_len):out = grucell(x[t])print(out)

基于PyTorch的GRUapi实现

import torch
import torch.nn as nnclass GRU(nn.Module):def __init__(self,input_size, hidden_size):super().__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.gru = nn.GRU(input_size,hidden_size)def forward(self,x):out,_ = self.gru(x)return outinput_size = 3
hidden_size = 2
seq_len = 4
bach_size = 5x = torch.randn(seq_len,bach_size,input_size)gru = GRU(input_size,hidden_size)
out = gru(x)
print(out)
print(out.shape)

http://www.mrgr.cn/news/45523.html

相关文章:

  • Flutter + Three.js (WebView)实现桌面端3d模型展示和交互
  • 【Vue】Vue2(2)
  • 【LeetCode刷题】:双指针篇(移动零、复写零)
  • SEO(搜索引擎优化)指南
  • 红队老子养成记2 - 不想渗透pc?我们来远控安卓!(全网最详细)
  • 要实现无限极评论
  • 计算机毕业设计-自主完成指南
  • MySql复习知识及扩展内容
  • C语言从头学65—学习头文件 <stdio.h>(一)
  • 碧桂园服务携手安徽砀山,以购代捐助力乡村振兴
  • scaling 的作用
  • Python Kivy 完整应用开发:待办事项列表
  • 【RTCP】Interarrival Jitter: 到达间隔抖动的举例说明
  • 【Transformer 模型中的投影层,lora_projection是否需要?】
  • 点餐小程序实战教程17角色管理
  • OpenHarmony(鸿蒙南向开发)——轻量系统内核(LiteOS-M)【内存调测】
  • Ngx+Lua+Redis 快速存储POST数据
  • 如何使用PSTools工具集中的PSExec修改注册表信息,解决某些注册表项无法删除的问题
  • 以下是一些数据看板的常见使用场景:
  • 招个测试员,我又面试了100+人,未果…