李沐56_门控循环单元——自学笔记

embedded/2024/9/23 10:19:30/

关注每一个序列

1.不是每个观察值都是同等重要

2.想只记住的观察需要:能关注的机制(更新门 update gate)、能遗忘的机制(重置门 reset gate)

python">!pip install --upgrade d2l==0.17.5  #d2l需要更新
python">import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
Downloading ../data/timemachine.txt from http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt...

下一步是初始化模型参数。 我们从标准差为0.01的高斯分布中提取权重, 并将偏置项设为0,超参数num_hiddens定义隐藏单元的数量, 实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。

python">def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

将定义隐状态的初始化函数init_gru_state。此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

python">def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

准备定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

python">def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

训练结束后,我们分别打印输出训练集的困惑度, 以及前缀“time traveler”和“traveler”的预测序列上的困惑度。

python">vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 31831.9 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

在这里插入图片描述

简洁实现

python">num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 255484.2 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
traveller with a slight accession ofcheerfulness really thi

在这里插入图片描述


http://www.ppmy.cn/embedded/24714.html

相关文章

9节点牛拉法matlab

潮流计算程序matlab 牛拉法 采用matlab对9节点进行潮流计算,采用牛拉法,程序运行可靠。

LinkedList

一.模拟实现 public class MyLinkedList {static class ListNode {private int val;private ListNode prev;//前驱private ListNode next;//后继public ListNode(int val) {this.val val;}}public ListNode head;//双向链表的头节点public ListNode last;//双向链表的尾巴//得…

K8s容器部署maven项目

最近在整一整套devops自动化持续集成的东西,一开始就做好了踩坑的准备。 failed to verify certificate: x509: certificate signed by unknown authority 今天在执行kubectl get nodes的时候报的证书验证问题,看了一圈首次搭建k8s的都是高频出现的问题…

《资本之王》全球私募之王黑石集团成长史 - 三余书屋 3ysw.net

资本之王:全球私募之王黑石集团成长史 大家好,今天我们要解读的书叫做《资本之王》,它讲述了全球私募股权之王——黑石公司的精彩成长史。这本书为我们揭秘了私募股权这个看似神秘的行业,并让我们更深入地了解了华尔街的金融发展…

基于Unity+Vue通信交互的WebGL项目实践

unity-webgl 是无法直接向vue项目进行通信的,需要一个中间者 jslib 文件 jslib当作中间者,unity与它通信,前端也与它通信,在此基础上三者之间进行了通信对接 看过很多例子:介绍的都不是很详细,不如自己写, 注意看箭头走向 共同点:unity 打包项目都放 在 public 里面…

『大模型笔记』AI 智能体(Agent)在推理(Reasoning)、规划(Planning)与工具调度(Tool Calling)方面的研究:综合调查!

AI 智能体(Agent)在推理(Reasoning)、规划(Planning)与工具调度(Tool Calling)方面的研究:综合调查! 文章目录 o. 摘要一. Introduction1.1. Taxonomy(分类学)二. 关键考虑因素以实现有效的智能体2.1. 概述2.2. 推理和规划的重要性2.3. 有效工具调用的重要性三. 单智能体架…

模拟LinkedList实现的双向链表

1. 前言 前文我们用java语言实现了无哨兵的单向链表.稍作修改即可实现有哨兵的单向链表.有哨兵的单向链表相较与无哨兵的而言,其对链表的头结点的增删操作更为方便.而在此我们实现了带有头节点和尾节点的双向链表(该头节点和尾节点都不存储有效的数据). 2. 带有头…

想复制这个侧视图的一半,咋复制不了呢?

问:老师我想复制这个侧视图的一半,咋复制不了呢 答:修改器列表里面选择对称就可以了