李沐56_门控循环单元——自学笔记

news/2024/9/23 6:35:59/

关注每一个序列

1.不是每个观察值都是同等重要

2.想只记住的观察需要:能关注的机制(更新门 update gate)、能遗忘的机制(重置门 reset gate)

python">!pip install --upgrade d2l==0.17.5  #d2l需要更新
python">import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
Downloading ../data/timemachine.txt from http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt...

下一步是初始化模型参数。 我们从标准差为0.01的高斯分布中提取权重, 并将偏置项设为0,超参数num_hiddens定义隐藏单元的数量, 实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。

python">def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

将定义隐状态的初始化函数init_gru_state。此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

python">def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

准备定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

python">def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

训练结束后,我们分别打印输出训练集的困惑度, 以及前缀“time traveler”和“traveler”的预测序列上的困惑度。

python">vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 31831.9 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

在这里插入图片描述

简洁实现

python">num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 255484.2 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
traveller with a slight accession ofcheerfulness really thi

在这里插入图片描述


http://www.ppmy.cn/news/1433201.html

相关文章

[Android]Jetpack Compose加载图标和图片

一、加载本地矢量图标 在 Android 开发中使用本地矢量图标是一种常见的做法,因为矢量图标(通常保存为 SVG 或 Android 的 XML vector format)具有可缩放性和较小的文件大小。 在 Jetpack Compose 中加载本地矢量图标可以使用内置的支持&…

parallels desktop19.3最新版本软件新功能详细介绍

Parallels Desktop是一款运行在Mac电脑上的虚拟机软件,它允许用户在Mac系统上同时运行多个操作系统,比如Windows、Linux等。通过这款软件,Mac用户可以轻松地在同一台电脑上体验不同操作系统的功能和应用程序,而无需额外的硬件设备…

【C++】:函数重载,引用,内联函数,auto关键字,基于范围的for循环,nullptr关键字

目录 一,函数重载1.1 函数重载的定义1.1.1.形参的类型不同1.1.2参数的个数不同1.1.3.参数的顺序不同1.1.4.有一个是缺省参数构成重载。但是调用时存在歧义1.1.5.返回值不同,不构成重载。因为返回值可接收,可不接受,调用函数产生歧…

程序猿成长之路之数据挖掘篇——朴素贝叶斯

朴素贝叶斯是数据挖掘分类的基础,本篇文章将介绍一下朴素贝叶斯算法 情景再现 以挑选西瓜为例,西瓜的色泽、瓜蒂、敲响声音、触感、脐部等特征都会影响到西瓜的好坏。那么我们怎么样可以挑选出一个好的西瓜呢? 分析过程 既然挑选西瓜有多个…

Java基础入门day37

day37 js小案例 全选&#xff0c;全不选和反选 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Doc…

C#基础|OOP、类与对象的认识

哈喽&#xff0c;你好&#xff0c;我是雷工&#xff01; 所有的面向对象的编程语言&#xff0c;都是把我们要处理的“数据”和“行为”封装到类中。 以下为OOP的学习笔记。 01 什么是面向对象编程&#xff08;OOP&#xff09;&#xff1f; 设计类&#xff1a;就是根据需求设计…

java:基于guava ClassPath工具实现基于包名(package)的类扫描

google的guava库提供了一个类路径扫描的实用工具ClassPath(参见说明&#xff1a; https://github.com/google/guava/wiki/ReflectionExplained#classpath)工具&#xff0c;适用于非android的Java平台搜索类。基于它可以设计一个过滤包名的搜索工具。 导入依赖库 <dependen…

【matlab】【数值分析】针对特殊矩阵的追赶法的matlab实现

【matlab】【数值分析】针对特殊矩阵的追赶法的matlab实现 三对角循环Toeplitz三对角五对角Latex公式源码文件参考资料 原文链接&#xff1a; 点我&#xff01;这就是本人的博客喵&#xff0c;快来看喵&#xff01; 下面的追赶法算法原理不予介绍&#xff0c;在参考文献中有原…