pytorch实现RNN网络

server/2024/9/22 21:56:39/

目录

1.导包

2. 加载本地文本数据

 3.构建循环神经网络层

4.初始化隐藏状态state

5.创建随机的数据,检测一下代码是否能正常运行

6. 构建一个完整的循环神经网络¶ 

7.模型训练 

8.个人知识点理解


 

1.导包

import torch
from torch import nn
from torch.nn import functional as F
import dltools

2. 加载本地文本数据

#声明变量:批次大小(一批所取的数据量)、子序列的长度
batch_size, num_steps =32, 35
#获取训练数据的迭代器, 词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)

 3.构建循环神经网络层

#声明变量:隐藏层的神经元数量(每个神经元都会有一个输出)
num_hiddens = 256
#构建一个具有256个隐藏单元的单隐藏层的循环神经网络
#num_layers=1默认值:一层神经网络
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens, num_layers=1)

4.初始化隐藏状态state

# 括号中的1:因为num_layers=1默认值:一层神经网络
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])

5.创建随机的数据,检测一下代码是否能正常运行

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
#传入X和初始化时的state,获取Y和state_new
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape#有输出表示代码正常运行!!!

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256])) 

6. 构建一个完整的循环神经网络¶ 

.long() 方法‌:这是PyTorch张量的一个方法,用于将张量的数据类型转换为torch.long。torch.long是一种整数数据类型,通常用于索引或存储不需要浮点数精度的整数数据。 

class RNNModel(nn.Module):   #继承nn.Module#初始化(需要用到的)参数,  **kwargs表示继承的其他参数(不一一写明的意思)#vocab_size = len(vocab)def __init__(self, rnn_layer, vocab_size, **kwargs):#继承父类的属性和方法super().__init__(**kwargs)self.rnn_layer = rnn_layer#词汇表的长度self.vocab_size =vocab_sizeself.num_hiddens = self.rnn_layer.hidden_size#判断是否为双向循环if not self.rnn_layer.bidirectional:self.num_directions = 1#nn.Linear用于定义线性层的类,一般用于全连接层self.linear = nn.Linear(in_features=self.num_hiddens, out_features=self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)#定义了数据在模型中的前向传播过程。(串联每一件事件的逻辑顺序)def forward(self, inputs, state):#one_hot编码,处理输入的X数据,此时的X.shape=(batch_size, num_steps)#。T转置之后,X.shape=(num_steps,batch_size)#one_hot编码之后, X.shape=(num_steps,batch_size, len(vocab)X = F.one_hot(inputs.T.long(), self.vocab_size)#将数据转化为tensorX = X.to(torch.float32)Y, state = self.rnn_layer(X, state)#此时,Y.shape = torch.Size(num_steps, batch_size, num_hiddens)#输出层:Y.shape必须是一个二维的, -1表示合并Y.shape中的num_steps与batch_size,outputs = self.linear(Y.reshape(-1, Y.shape[-1]))return outputs, state# 初始化隐藏状态def begin_state(self, device, batch_size=1):return torch.zeros((self.num_directions * self.rnn_layer.num_layers, batch_size, self.num_hiddens), device=device)
#在训练之前,基于随机初始化的权重进行预测,测试模型
device = dltools.try_gpu()
rnn_net = RNNModel(rnn_layer, vocab_size=len(vocab))
rnn_net = rnn_net.to(device)
dltools.predict_ch8(prefix='time traveller',num_preds=10, net=rnn_net, vocab=vocab, device=device)
'time travellergghhhhhhhh'

7.模型训练 

#声明变量
#模型训练时,可以先让学习率的值稍大一些,让梯度下降的快一些,然后
#梯度下降到一定程度再改成较小的值
num_epochs, lr = 500, 0.1
dltools.train_ch8(net=rnn_net, train_iter=train_iter, vocab=vocab, lr=lr, num_epochs=num_epochs, device=device)

 

8.个人知识点理解

 

 

 


http://www.ppmy.cn/server/120480.html

相关文章

如何将生物序列tokenization为token?

文章目录 原理讲解:代码实现:利用k-mer技术把核苷酸或氨基酸序列tokenization成token输入文件和结果文件: 原理讲解: tokenization是自然语言处理领域非常成熟的一项技术,tokenization就是把我们研究的语言转换成计算…

macOS平台(intel)编译MAVSDK安卓平台SO库

1.下载MAVSDK: git clone https://github.com/mavlink/MAVSDK.git --recursive 2.编译liblzma 修改CMakeLists.txt文件增加C与CXX指令-fPIC set(CMAKE_C_FLAGS "-fPIC ${CMAKE_C_FLAGS}") set(CMAKE_CXX_FLAGS "-fPIC ${CMAKE_CXX_FLAGS}") 修改如下:…

算法题目复习(0909-0917)

1. 连续子序列和 pdd的算法题&#xff0c;根本不记得怎么做 给一个数组&#xff0c;有正数和负数&#xff0c;算出连续子序列的和最大为多少 int maxSubArraySum(vector<int>& nums) {int maxSoFar nums[0];int maxEndingHere nums[0];for (size_t i 1; i <…

Ubuntu 22.04.5 LTS 发布下载 - 现代化的企业与开源 Linux

Ubuntu 22.04.5 LTS (Jammy Jellyfish) - 现代化的企业与开源 Linux Ubuntu 22.04.5 发布&#xff0c;配备 Linux 内核 6.8 请访问原文链接&#xff1a;https://sysin.org/blog/ubuntu-2204/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#xf…

【鸿蒙OH-v5.0源码分析之 Linux Kernel 部分】007 - 一号内核线程 kernel_init线程 工作流程分析

【鸿蒙OH-v5.0源码分析之 Linux Kernel 部分】007 - 一号内核线程 kernel_init线程 工作流程分析 系列文章汇总:《鸿蒙OH-v5.0源码分析之 Uboot+Kernel 部分】000 - 文章链接汇总》 本文链接:《【鸿蒙OH-v5.0源码分析之 Linux Kernel 部分】007 - 一号内核线程 kernel_init线…

nginx_单机平滑升级

#!/bin/bash# 定义要下载的 Nginx 源码包的 URL 和保存路径 nginx_tar"http://nginx.org/download/nginx-1.19.0.tar.gz" nginx_tar_file"/tmp/nginx-1.19.0.tar.gz" nginx_version"nginx-1.19.0" nginx_path$(which nginx) # 获取 Nginx 的路径…

【Kubernetes】常见面试题汇总(二十四)

目录 71.假设一家公司想要修改它的部署方法&#xff0c;并希望建立一个更具可扩展性和响应性的平台。您如何看待这家公司能够实现这一目标以满足客户需求&#xff1f; 72.考虑一家拥有非常分散的系统的跨国公司&#xff0c;期待解决整体代码库问题。您认为公司如何解决他们的问…

前端环境搭建

配置国内镜像 nvm node_mirror https://npmmirror.com/mirrors/node/ nvm npm_mirror https://npmmirror.com/mirrors/npm/ 安装指定版本node并切换&#xff08;以16.17.0为例&#xff09;&#xff0c;过程中会弹出两次授权&#xff0c;确认即可 nvm install 16.17.0 nvm us…