循环神经网络-简洁实现

news/2025/3/19 7:05:37/

参考:
https://zh-v2.d2l.ai/chapter_recurrent-neural-networks/rnn-concise.html
https://pytorch.org/docs/stable/generated/torch.nn.RNN.html?highlight=rnn#torch.nn.RNN

RNN

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35  # num_steps: sequence length
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps) #  vocab:Vocab 26# 1 定义模型
# 构造一个具有256个隐藏层的循环神经网络 rnn_layer
# 此处先仅设计一层循环神经网络,以后讨论多层神经网络
num_hiddens = 256
rnn_layer = nn.RNN(len(vocab),num_hiddens) # RNN(28,256)
"""input_size – The number of expected features in the input x
hidden_size – The number of features in the hidden state h
num_layers – Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1
nonlinearity – The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'
bias – If False, then the layer does not use bias weights b_ih and b_hh. Default: True
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature). Note that this does not apply to hidden or cell states. See the Inputs/Outputs sections below for details. Default: False
dropout – If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0
bidirectional – If True, becomes a bidirectional RNN. Default: False
"""
# 2.我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)
state = torch.zeros((1,batch_size,num_hiddens))
print(state.shape)  #(torch.size([1,32,256]))#3. 通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。
# 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。
X=torch.rand(size=(num_steps,batch_size,len(vocab)))  #torch.Size([35, 32, 28])   # (L,N,H(in)) L:sequence length  N batch size Hin: input_size
Y,state_new = rnn_layer(X,state)
print(Y.shape,state_new.shape) #torch.Size([35, 32, 256]) torch.Size([1, 32, 256])class RNNModel(nn.Module):"""循环神经网络"""def __init__(self,rnn_layer,vocab_size,**kwargs):super(RNNModel,self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的,num_directions 应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens,self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens*2,self.vocab_size)def forward(self,inputs,state):X = F.one_hot(inputs.T.long(),self.vocab_size)X = X.to(torch.float32)Y,state = self.rnn(X,state)# 全连接首层将Y的形状改为(时间步数*批量大小,隐藏单元数)output = self.linear(Y.reshape((-1,Y.shape[-1])))return output,statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# nn.GRU以张量作为隐状态return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:# nn.LSTM以元组作为隐状态return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))# 训练
device = d2l.try_gpu()
net = RNNModel(rnn_layer,vocab_size=len(vocab))
net = net.to(device)
num_epochs ,lr = 500,1
d2l.train_ch8(net,train_iter,vocab,lr,num_epochs,device)

http://www.ppmy.cn/news/1118622.html

相关文章

zookeeper源码(01)集群启动

本文介绍一下zookeeper-3.5.7集群安装。 解压安装 tar zxf apache-zookeeper-3.5.7-bin.tar.gz创建数据、日志目录: mv apache-zookeeper-3.5.7-bin /app/zookeeper-3.5.7 cd /app/zookeeper-3.5.7mkdir data mkdir logs编辑配置文件 zoo.cfg文件 cp conf/zoo_…

calibre和cpolar搭建一个私有的网络书库

Kindle中国电子书店停运不要慌,十分钟搭建自己的在线书库随时随地看小说! 文章目录 Kindle中国电子书店停运不要慌,十分钟搭建自己的在线书库随时随地看小说!1.网络书库软件下载安装2.网络书库服务器设置3.内网穿透工具设置4.公网…

VS2019创建GIt仓库时剔除文件或目录

假设本地有解决方案“SomeSolution” 1、首先”团队资源管理器“-“创建Git存储库”,选择“仅限本地”、“创建” VS会在解决方案目录下自动生成.gitattributes、.gitignore 2、编辑gitignore,直接拖到VS里或者用记事本打开。添加要剔除的文件或文件夹…

【LeetCode热题100】--560.和为K的子数组

560.和为K的子数组 示例2的结果: 输入:nums [1,2,3] ,k3的时候 连续子数组有[1,2],[3],一共有2个 利用枚举法: 枚举[0,…i]里所有的下标j来判断是否符合条件 class Solution {public int subarraySum(int[] nums, int k) {i…

决胜绝地求生:玩家们常见最关心的吃鸡要领和细节,全方位指南

吃鸡(绝地求生)作为一款风靡全球的大逃杀游戏,在玩家中拥有庞大的粉丝群体。为了帮助那些正在或准备进入吃鸡世界的玩家们,以下是一些常见的关心事项和要点,希望能为大家带来帮助。 1. 选择合适的降落点:在…

ChatGPT Prompting开发实战(八)

一. 什么是归纳总结式的prompt开发 有时候需要对一段文本进行归纳总结,那么可以采取以下的方案: -按照给定单词、句子或者字符的数量限制来让模型裁剪文本,使内容更精炼 -基于聚焦的主题进行总结 -只根据需求抽取相关的文本信…

【Java】注解 之 处理注解

处理注解 Java的注解本身对代码逻辑没有任何影响。根据Retention的配置: SOURCE类型的注解在编译期就被丢掉了;CLASS类型的注解仅保存在class文件中,它们不会被加载进JVM;RUNTIME类型的注解会被加载进JVM,并且在运行…

vuex实现简易购物车加购效果

目录 一、加购效果动图二、前提条件三、开始操作四、解决vuex刷新数据丢失问题五、最终效果 一、加购效果动图 二、前提条件 创建了vue项目,安装了vuex 三、开始操作 目录结构如下: main.js文件中引入store: import Vue from vue import App from ./…