深度学习基础——循环神经网络的结构及参数更新方式

devtools/2024/9/23 4:49:49/

深度学习基础——循环神经网络的结构及参数更新方式

深度学习领域的一大重要分支是循环神经网络(Recurrent Neural Networks,简称RNN),它是一种用于处理序列数据的神经网络结构。与传统的前馈神经网络不同,循环神经网络能够利用序列中的时间信息,从而更好地建模序列数据的依赖关系。

1. 概述

循环神经网络是一种具有循环连接的神经网络结构,用于处理序列数据,如文本、时间序列等。其主要特点是可以将过去的信息传递到当前时间步,从而在处理序列数据时具有记忆性。

循环神经网络的基本结构如下图所示:

在这里插入图片描述

其中, x t x_t xt表示时间步 t t t的输入数据, h t h_t ht 表示时间步 t t t 的隐藏状态,用于存储过去的信息, y t y_t yt表示时间步 t t t的输出数据。 U U U表示输入层到隐藏层的权重矩阵, W W W表示上一时间步隐藏状态到当前时间步隐藏状态的权重矩阵, V V V表示隐藏层到输出层的权重矩阵。

2. 公式介绍及详细推导

基本结构

循环神经网络的基本结构如下所示:

h t = σ ( U x t + W h t − 1 ) h_t = \sigma(Ux_t + Wh_{t-1}) ht=σ(Uxt+Wht1)
y t = V h t y_t = Vh_t yt=Vht

其中, σ \sigma σ表示激活函数,通常为Sigmoid、Tanh等函数。

参数更新

循环神经网络的参数更新采用反向传播算法,目标是最小化损失函数。具体来说,假设损失函数为 L L L,则参数更新的公式为:

θ t + 1 = θ t − α ∂ L ∂ θ \theta_{t+1} = \theta_t - \alpha \frac{\partial L}{\partial \theta} θt+1=θtαθL

其中, θ \theta θ 表示模型的参数,包括 U , W , V U, W, V U,W,V等权重矩阵, α \alpha α 表示学习率,控制参数更新的步长, ∂ L ∂ θ \frac{\partial L}{\partial \theta} θL表示损失函数对参数的梯度。

3. 用Python实现示例代码

下面是一个使用PyTorch实现简单循环神经网络的示例代码,并进行了参数更新和结果可视化。

import torch
import torch.nn as nn
import matplotlib.pyplot as plt# 定义循环神经网络模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.rnn = nn.RNN(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.rnn(x)out = self.fc(out[:, -1, :])return out# 生成示例数据
input_size = 1
hidden_size = 32
output_size = 1
sequence_length = 100
x = torch.linspace(0, 10, sequence_length).reshape(-1, sequence_length, input_size)
y = torch.sin(x)# 定义模型、损失函数和优化器
model = SimpleRNN(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 训练模型
epochs = 1000
losses = []
for epoch in range(epochs):optimizer.zero_grad()output = model(x)loss = criterion(output, y)loss.backward()optimizer.step()losses.append(loss.item())# 可视化损失曲线
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.show()

在这里插入图片描述

4. 总结

本文介绍了循环神经网络的基本结构及其参数更新方式,并通过Python示例代码进行了实现和演示。循环神经网络在处理序列数据时具有很好的效果,可以应用于文本生成、时间序列预测等任务中。深入理解循环神经网络的结构和参数更新方式对于学习和应用深度学习模型具有重要意义。


http://www.ppmy.cn/devtools/8321.html

相关文章

Git 远程仓库多人协作

文章目录 前言一、操作远程仓库1、克隆远程仓库2、向远程仓库推送3、拉取远程仓库4、删除远程库 二、多人协作 前言 Git是分布式版本控制系统,同一个Git仓库,可以分布到不同的机器上。那么该怎么分布呢?首先肯定得有一台机器充当“原始库”&…

Stable Diffusion UI 从安装到实现文字图片融合(光影字,错觉图)图片制作详细教程

前言 最近在实践大模型本地部署,前几天在本地部署了一个ChatGLM大模型,刚好环境搭好了,也支持跑Stable Diffusion,所以就安装了再尝试一下。 原因是之前在B站上有大佬做了一个Windows电脑能一键运行的Stable Diffusion的安装包&…

LeetCode第797题: 所有可能的路径

目录 1.问题描述 2.问题分析 1.问题描述 给你一个有 n 个节点的有向无环图(DAG),请你找出所有从节点 0 到节点 n-1 的路径并输出(不要求按特定顺序)。 graph[i] 是一个从节点 i 可以访问的所有节点的列表&#xff08…

C语言修炼——什么是流?什么是文件?什么是文件操作?

目录 一、为什么使用文件?二、什么是文件?2.1 程序文件2.2 数据文件2.3 文件名 三、二进制文件和文本文件四、文件的打开和关闭4.1 流和标准流4.1.1 流4.1.2 标准流 4.2 文件指针4.3 文件的打开和关闭 五、文件的顺序读写5.1 顺序读写函数介绍a. fgetcb.…

服务器Linux上杀死特定进程的命令:kill

1、查看用户XXX正在运行的进程 top -u xxx2、查看想要杀死的进程对应的PID 先找到此进程对应的命令 取其中的main-a3c.py即可 ps -aux | grep main-a3c.py可以看到对应的PID是1325390使用kill杀死对应PID的进程 kill -9 1325390成功,gpustat可以看到之前一直占…

Java最短路径问题知识点(含面试大厂题和源码)

最短路径问题是图论中的一个经典问题,它寻找图中两点之间的最短路径。这个问题在现实世界中有广泛的应用,比如导航系统中的路线规划、网络中的信息传输等。解决最短路径问题有多种算法,其中最著名的包括: 贝尔曼-福特算法&#xf…

对称加密与非对称加密有什么区别?

本文转自 公众号 ByteByteGo,如有侵权,请联系,立即删除 对称加密与非对称加密有什么区别? 对称加密与非对称加密有什么区别? 对称加密和非对称加密是用于确保数据和通信安全的两种加密技术,但它们在加密和…

木马——文件上传

目录 1、WebShell 2.一句话木马 靶场训练 3.蚁剑 虚拟终端 文件管理 ​编辑 数据操作 4.404.php 5.文件上传漏洞 客户端JS检测 右键查看元素,删除检测代码 BP拦截JPG修改为php 服务端检测 1.MIME类型检测 2.文件幻数检测 3.后缀名检测 1、WebShell W…