9--RNN

news/2024/11/28 15:42:23/

有隐藏状态的循环神经网络

        假设在时间步t有小批量输入\mathbf{X}_t \in \mathbb{R}^{n \times d},即对于n个序列样本的小批量,\mathbf{X}_t的每一行对应于来自该序列的时间步t处的一个样本,用\mathbf{H}_t \in \mathbb{R}^{n \times h}表示时间步t的隐藏变量。与MLP不同的是, 我们在这里保存了前一个时间步的隐藏变量\mathbf{H}_{t-1},并引入了一个新的权重参数\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}。当前时间步隐藏变量由当前时间步的输入与前一个时间步的隐藏变量一起计算得出:

\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1} \mathbf{W}_{hh} + \mathbf{b}_h).

        从相邻时间步的隐藏变量\mathbf{H}_t\mathbf{H}_{t-1}之间的关系可知, 这些变量捕获并保留了序列直到其当前时间步的历史信息, 就如当前时间步下神经网络的状态或记忆, 因此这样的隐藏变量被称为隐状态(hidden state)。对于时间步t,输出层的输出类似于多层感知机中的计算:

\mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{hq} + \mathbf{b}_q.

        其实循环神经网络与MLP不同的地方就在于,中间隐藏层的更新会依赖于上一时间步的隐藏层。(下图中蓝色的点为隐藏层)

基于循环神经网络的字符级语言模型 

        根据过去的词与当前的词来对下一个词进行预测,可以将词的原始序列位移一个词源作为一个标签。考虑使用神经网络来进行语言建模,设小批量大小为1,批量中的那个文本序列为“machine”。这里考虑字符级语言模型,下图展示了如何通过之前以及当前字符预测下一个字符。

        在训练过程中,对每个时间步的输出都进行一个softmax操作,并利用交叉熵损失计算模型输出和标签之间的误差。

困惑度(Perplexity)

        对于语言模型预测的结果,通过计算序列的似然概率来度量模型的质量。 一个更好的语言模型应该能更准确地预测下一个词元。因此,它在压缩序列时花费更少的比特。所以可以通过一个序列中所有的n个词元的交叉熵损失的平均值来衡量:

\frac{1}{n} \sum_{t=1}^n -\log P(x_t \mid x_{t-1}, \ldots, x_1),

        其中P由语言模型给出, xt是在时间步t从该序列中观察到的实际词元,上式的指数则称为困惑度,即下一个词元的实际选择数的调和平均数

\exp\left(-\frac{1}{n} \sum_{t=1}^n \log P(x_t \mid x_{t-1}, \ldots, x_1)\right). 

        在最好的情况下,模型总是完美地估计标签词元的概率为1(即预测结果为一个词元), 在这种情况下,模型的困惑度为1。 在最坏的情况下,模型总是预测标签词元的概率为0,在这种情况下,困惑度是正无穷大。在基线上,该模型的预测是词表的所有可用词元上的均匀分布,困惑度等于词表中唯一词元的数量。

实例

        基于时光机器数据集来训练模型,具体代码如下:

!pip install git+https://github.com/d2l-ai/d2l-zh@release  # installing d2l
!pip install matplotlib_inline
!pip install matplotlib==3.0.0import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size , num_steps = 32,35
train_iter,vocab = d2l.load_data_time_machine(batch_size , num_steps)#构造一个具有256个隐藏单元的单隐藏层的循环神经网络层
num_hiddens = 256
rnn_layer = nn.RNN(len(vocab),num_hiddens,1)class RNNModel(nn.Module):def __init__(self,rnn_layer,vocab_size,**kwargs):super(RNNModel,self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_sizeif not self.rnn.bidirectional:self.num_directions=1self.linear = nn.Linear(self.num_hiddens,self.vocab_size)else:self.num_directions=2self.linear = nn.Linear(self.num_hiddens*2,self.vocab_size)def forward(self,inputs,state):X = F.one_hot(inputs.T.long(),self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X,state)output = self.linear(Y.reshape(-1,Y.shape[-1]))return output,state#初始化隐状态为0 形状是(隐藏层数,批量大小,隐藏单元数)def begin_state(self,device,batch_size=1):if not isinstance(self.rnn,nn.LSTM):return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))
device = d2l.try_gpu()
net = RNNModel(rnn_layer,vocab_size=len(vocab))
num_epochs,lr = 500,1
d2l.train_ch8(net,train_iter,vocab,lr,num_epochs,device)

        运行结果如下,500个epoch后困惑度达到了1.3。

        另外,这里分别使用训练前和训练后的模型对“time traveller”后续词元进行续写,可以看出模型训练前完全是随机性的预测字符串,虽然训练后的模型预测结果语义上不太通顺,但预测出来的单词大部分是正确的(该模型的词元是字符)。

 

 


http://www.ppmy.cn/news/494707.html

相关文章

红米note9pro和华为p40参数对比哪个值得入手

红米note9pro:采 用了一块6.67英寸的屏幕有着2400x1080像素的分辨率并且是支持90Hz的刷新率的 红米手机爆降600这活动太给力了机会不容错过 https://www.xiaomi.com.cn 华为p40更多使用感受和评价:https://www.huawei.com/p40 华为p40:采用了…

iphone12和华为mate40 的区别 哪个好

iPhone12的机身背部是AG工艺磨砂玻璃,很好地避免沾染指纹,航空铝中框回归了传统的硬朗设计风格,有点致敬乔布斯iPhone4时代的味道,拥有黑、白、蓝、黄、金、橙六种颜色可以选择,满足不同审美人群的需要。 华为mate40更…

编译原理笔记10:语言与文法,正规式转CFG,正规式和CFG,文法、语言与自动机

目录 正规式,和 CFG正规式到 CFG 的转换:正规式和 CFG 的关系为毛不用 CFG 描述词法规则贯穿词法、语法分析始终的思想 上下文有关文法 CSG文法、语言与自动机0型文法:1型文法:2型文法:3型文法:为什么&…

Android蓝牙协议知识汇总

蓝牙协议下载 蓝牙技术联盟网址:https://www.bluetooth.com/ 在这个网址搜索,比如: 在搜索结果中找到蓝牙协议规范: 点击上面网址: 蓝牙手册里包含了部分核心协议,比如L2CAP、SDP、ATT、GATT&#x…

吉林大学 计算机网络常见的名词解释

吉林大学 计算机网络常见的名词解释 1.应用层2.传输层3. 网络层4.链路层5. 无线网络和移动网络6.计算机网络中的安全 1.应用层 API (Application Programming Interface)应用程序编程接口HTTP (Hyper Text Transfer Protocol) 超…

车载以太网 - 传输层 - TCP/IP

目录 一、传输层基础介绍 传输层主要包括两种协议 传输层端口号 二、UDP通信 UDP协议介绍 UDP 通信特点: UDP Segment结构 UDP通信过程 三、TCP通信 TCP通信特点: TCP Segment结构 一、传输层基础介绍 传输层的寻址方式:端口号 包括传输层的寻址方式&…

小米max2 android p,小米Max2完全曝光:6.4寸巨屏 性价比爆棚

小米今年上半年准备的新机颇多,除了本月发布的小米6外,大屏手机小米Max也要迎来换代(小米Max已经在小米商城下架)。 现在有网友给出的最新消息,小米将在5月份推出小米Max二代,不过它搭载的处理器并非之前传闻的骁龙660&#xff0c…

小米 max android,小米Max原生安卓8.0刷机包放出:仅供尝鲜体验

IT之家9月4日消息 自谷歌发布Android 8.0正式版以来,来自XDA论坛的第三方开发者已经先后为小米4、国际版红米Note4、全网通版红米Note3、一加1等机型进行了适配,日前,有IT之家网友反映称,现XDA论坛上已有第三方开发者放出了小米Ma…