李沐56_门控循环单元——自学笔记

server/2024/10/22 18:27:30/

关注每一个序列

1.不是每个观察值都是同等重要

2.想只记住的观察需要:能关注的机制(更新门 update gate)、能遗忘的机制(重置门 reset gate)

python">!pip install --upgrade d2l==0.17.5  #d2l需要更新
python">import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
Downloading ../data/timemachine.txt from http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt...

下一步是初始化模型参数。 我们从标准差为0.01的高斯分布中提取权重, 并将偏置项设为0,超参数num_hiddens定义隐藏单元的数量, 实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。

python">def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

将定义隐状态的初始化函数init_gru_state。此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

python">def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

准备定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

python">def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

训练结束后,我们分别打印输出训练集的困惑度, 以及前缀“time traveler”和“traveler”的预测序列上的困惑度。

python">vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 31831.9 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

在这里插入图片描述

简洁实现

python">num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 255484.2 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
traveller with a slight accession ofcheerfulness really thi

在这里插入图片描述


http://www.ppmy.cn/server/19164.html

相关文章

GaussianEditor:快速可控的3D编辑与高斯飞溅

GaussianEditor: Swift and Controllable 3D Editing with Gaussian Splatting GaussianEditor:快速可控的3D编辑与高斯飞溅 Yiwen Chen*​1,2   Zilong Chen*​3,5   Chi Zhang2   Feng Wang3   Xiaofeng Yang2 陈怡雯 *​1,2 陈子龙 *​3,5 张驰 2 王峰 3 杨晓…

DB索引B+树SQL优化

数据库的索引就像一本书的目录,查数据快人一步,快速定位,精准打击! 什么是数据库的索引? 官方介绍索引是帮助MySQL高效获取数据的数据结构。更通俗的说,数据库索引好比是一本书前面的目录,能加…

微信小程序开发:2.小程序组件

常用的视图容器类组件 View 普通的视图区域类似于div常用来进行布局效果 scroll-view 可以滚动的视图&#xff0c;常用来进行滚动列表区域 swiper and swiper-item 轮播图的容器组件和轮播图的item项目组件 View组件的基本使用 案例1 <view class"container"&…

RabbitMQ在Java中的完美实现:从入门到精通

哈喽&#xff0c;大家好&#xff0c;我是木头左&#xff01; 一、RabbitMQ简介 RabbitMQ是一个开源的AMQP实现&#xff0c;服务器端用Erlang语言编写&#xff0c;支持多种客户端&#xff0c;如&#xff1a;Python、Ruby、.NET、Java、JMS、C、PHP、ActionScript、XMPP、STOMP等…

安卓ComponentName简介及使用

目录 一、ComponentName是什么&#xff1f;二、类分析2.1 构造方法2.2 重点方法 三、ComponentName的使用 一、ComponentName是什么&#xff1f; ComponentName&#xff0c;顾名思义&#xff0c;就是组件名称&#xff0c;用于表示Android应用程序组件的名称。在Android开发中&…

Redis高级篇详细讲解

0.今日菜单 Redis持久化【理解】 Redis主从 Redis哨兵 Redis分片集群【运维】 单点Redis的问题 数据丢失问题&#xff1a;Redis是内存存储&#xff0c;服务重启可能会丢失数据 并发能力问题&#xff1a;单节点Redis并发能力虽然不错&#xff0c;但也无法满足如618这样的高…

神经网络与深度学习(三)——卷积神经网络基础

卷积神经网络基础 1.为什么要学习神经网络1.1全连接网络问题1.2深度学习平台简介1.3PyTorch简介1.4简单示例 2.卷积神经网络基础2.1进化史2.2特征提取2.3基本结构 3.学习算法3.1前向传播3.2误差反向传播3.2.1经典BP算法3.2.2卷积NN的BP算法 4.LeNet-5网络4.1网络介绍4.2网络结构…

RabbitMQ(高级)笔记

一、生产者可靠性 &#xff08;1&#xff09;生产者重连&#xff08;不建议使用&#xff09; logging:pattern:dateformat: MM-dd HH:mm:ss:SSSspring:rabbitmq:virtual-host: /hamllport: 5672host: 192.168.92.136username: hmallpassword: 123listener:simple:prefetch: 1c…