pytorch实现RNN网络

ops/2024/9/22 17:56:50/

目录

1.导包

2. 加载本地文本数据

 3.构建循环神经网络层

4.初始化隐藏状态state

5.创建随机的数据,检测一下代码是否能正常运行

6. 构建一个完整的循环神经网络¶ 

7.模型训练 

8.个人知识点理解


 

1.导包

import torch
from torch import nn
from torch.nn import functional as F
import dltools

2. 加载本地文本数据

#声明变量:批次大小(一批所取的数据量)、子序列的长度
batch_size, num_steps =32, 35
#获取训练数据的迭代器, 词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)

 3.构建循环神经网络层

#声明变量:隐藏层的神经元数量(每个神经元都会有一个输出)
num_hiddens = 256
#构建一个具有256个隐藏单元的单隐藏层的循环神经网络
#num_layers=1默认值:一层神经网络
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens, num_layers=1)

4.初始化隐藏状态state

# 括号中的1:因为num_layers=1默认值:一层神经网络
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])

5.创建随机的数据,检测一下代码是否能正常运行

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
#传入X和初始化时的state,获取Y和state_new
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape#有输出表示代码正常运行!!!

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256])) 

6. 构建一个完整的循环神经网络¶ 

.long() 方法‌:这是PyTorch张量的一个方法,用于将张量的数据类型转换为torch.long。torch.long是一种整数数据类型,通常用于索引或存储不需要浮点数精度的整数数据。 

class RNNModel(nn.Module):   #继承nn.Module#初始化(需要用到的)参数,  **kwargs表示继承的其他参数(不一一写明的意思)#vocab_size = len(vocab)def __init__(self, rnn_layer, vocab_size, **kwargs):#继承父类的属性和方法super().__init__(**kwargs)self.rnn_layer = rnn_layer#词汇表的长度self.vocab_size =vocab_sizeself.num_hiddens = self.rnn_layer.hidden_size#判断是否为双向循环if not self.rnn_layer.bidirectional:self.num_directions = 1#nn.Linear用于定义线性层的类,一般用于全连接层self.linear = nn.Linear(in_features=self.num_hiddens, out_features=self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)#定义了数据在模型中的前向传播过程。(串联每一件事件的逻辑顺序)def forward(self, inputs, state):#one_hot编码,处理输入的X数据,此时的X.shape=(batch_size, num_steps)#。T转置之后,X.shape=(num_steps,batch_size)#one_hot编码之后, X.shape=(num_steps,batch_size, len(vocab)X = F.one_hot(inputs.T.long(), self.vocab_size)#将数据转化为tensorX = X.to(torch.float32)Y, state = self.rnn_layer(X, state)#此时,Y.shape = torch.Size(num_steps, batch_size, num_hiddens)#输出层:Y.shape必须是一个二维的, -1表示合并Y.shape中的num_steps与batch_size,outputs = self.linear(Y.reshape(-1, Y.shape[-1]))return outputs, state# 初始化隐藏状态def begin_state(self, device, batch_size=1):return torch.zeros((self.num_directions * self.rnn_layer.num_layers, batch_size, self.num_hiddens), device=device)
#在训练之前,基于随机初始化的权重进行预测,测试模型
device = dltools.try_gpu()
rnn_net = RNNModel(rnn_layer, vocab_size=len(vocab))
rnn_net = rnn_net.to(device)
dltools.predict_ch8(prefix='time traveller',num_preds=10, net=rnn_net, vocab=vocab, device=device)
'time travellergghhhhhhhh'

7.模型训练 

#声明变量
#模型训练时,可以先让学习率的值稍大一些,让梯度下降的快一些,然后
#梯度下降到一定程度再改成较小的值
num_epochs, lr = 500, 0.1
dltools.train_ch8(net=rnn_net, train_iter=train_iter, vocab=vocab, lr=lr, num_epochs=num_epochs, device=device)

 

8.个人知识点理解

 

 

 


http://www.ppmy.cn/ops/114356.html

相关文章

Spring Boot-热部署问题

Spring Boot 热部署问题分析与解决方案 热部署(Hot Deployment)是指在应用程序运行过程中,无需停止应用就可以动态加载新代码、配置或资源,从而提升开发效率。在 Spring Boot 开发中,热部署是一项非常实用的功能&…

创建一个带有 F6 快捷键的自动点击器

创建一个带有 F6 快捷键的自动点击器 在许多情况下,自动化点击任务可以帮助我们节省大量时间和精力。本文将介绍如何使用 Python 和 Tkinter 创建一个简单的自动点击器,并通过 F6 键作为快捷键来控制点击器的开始和停止,即使应用程序在后台也…

P9235 [蓝桥杯 2023 省 A] 网络稳定性

*原题链接* 最小瓶颈生成树题,和货车运输完全一样。 先简化题意, 次询问,每次给出 ,问 到 的所有路径集合中,最小边权的最大值。 对于这种题可以用kruskal生成树来做,也可以用倍增来写,但不…

JSON.parseArray 内存溢出

实际上我的JSON如下: 如果用以下代码:JVM的内存直接飙到内存溢出,报错OutOfMemoryError: Java heap space Object oo JSON.parseArray(json, TestVo.class) 如果我换成了这样,就没事: Object oo JSON.parseObject(…

MISC - 第一天(stegsolve图片隐写解析器、QR research、binwalk,dd文件分离,十六进制文件编辑器)

前言 各位师傅大家好,我是qmx_07,最近更新Buuctf在线测评中的MISC杂项内容 介绍 BUUCTF:https://buuoj.cn/ ,整合了各大 CTF 赛事题目,类型丰富,涵盖了Web 安全、密码学、系统安全、逆向工程、代码审计等多个领域 …

VUE面试题(单页应用及其首屏加载速度慢的问题)

目录 一、单页应用 1.概念 2.单页面应用的优缺点 二、多页面应用: 1.概念 2.区别 三、SPA的实现 1.原理 2.方式: 3.Hash与History模式有什么区别 四、首屏加载速度慢如何优化 1.什么是首屏加载? 2.首屏加载慢的原因 3.如何解决…

Qt 学习第十天:小项目:QListWidget的使用

一、页面布局 二、命名按钮 双击按钮可以修改显示中的文字(例如:改成“全选”),objectName是要改成程序员所熟悉的名字(英文,符合代码规范)方便修改和书写代码,一看就能看懂的 三、…

使用IDA Pro动态调试Android APP

版权归作者所有,如有转发,请注明文章出处:https://cyrus-studio.github.io/blog/ 关于 android_server android_server 是 IDA Pro 在 Android 设备上运行的一个调试服务器。 通过在 Android 设备上运行android_server,IDA Pro …