【PyTorch】Recurrent Neural Networks


What Is a Recurrent Neural Network

RNN: Recurrent Neural Network

  • A model for handling variable-length inputs
  • Commonly used in NLP and time-series tasks (where the input data has sequential dependencies)
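The point about variable-length input can be seen directly with PyTorch's built-in `nn.RNN` (a minimal sketch for illustration only; the tutorial below builds the recurrent cell by hand instead): the same layer accepts sequences of any length and always returns a fixed-size final hidden state.

```python
import torch
import torch.nn as nn

# One RNN layer handles sequences of any length, because the same cell
# is applied step by step along the time dimension.
rnn = nn.RNN(input_size=57, hidden_size=128)

short_seq = torch.randn(4, 1, 57)   # a 4-character name, batch size 1
long_seq = torch.randn(11, 1, 57)   # an 11-character name, batch size 1

_, h_short = rnn(short_seq)         # final hidden state, shape (1, 1, 128)
_, h_long = rnn(long_seq)           # same shape, regardless of sequence length

print(h_short.shape, h_long.shape)
```

However long the input name is, the classifier only ever sees the fixed-size hidden state, which is what makes variable-length classification possible.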

RNN Network Structure

References
Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs
Understanding LSTM Networks

Name Classification with an RNN

Problem definition: given a name of arbitrary length (a string), output which country the name comes from (an 18-class classification task).
Data: https://download.pytorch.org/tutorial/data.zip
Jackie Chan —— 成龙
Jay Chou —— 周杰伦
Tingsong Yue —— 余霆嵩

How an RNN Handles Variable-Length Input

Question: how can a computer map a variable-length string to a classification vector?
Chou (string) → RNN → Chinese (class label)

  1. A character → a number
  2. The number → the model
  3. The next character → a number → the model
  4. The last character → a number → the model → a classification vector
# Pseudocode
# Chou (string) → RNN → Chinese (class label)
for string in [C, h, o, u]:
    # 1. one-hot: string → [0, 0, ..., 1, ..., 0]   # first convert each letter to a one-hot encoding
    # 2. y, h = model([0, 0, ..., 1, ..., 0], h)    # h is the hidden-state information
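Step 1, the one-hot encoding, can be sketched as a runnable snippet (it mirrors the `letterToTensor` helper that appears in the full code later, using the same 57-character alphabet):

```python
import string
import torch

# The alphabet used in the tutorial: 52 ASCII letters plus " .,;'"
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)  # 57

def letter_to_tensor(letter):
    # one-hot vector of shape (1, 57) with a single 1 at the letter's index
    tensor = torch.zeros(1, n_letters)
    tensor[0][all_letters.find(letter)] = 1
    return tensor

t = letter_to_tensor('C')
print(t.shape)          # torch.Size([1, 57])
print(t.sum().item())   # 1.0, exactly one position is hot
```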

x_t: input at time t, shape = (1, 57)
s_t: state at time t, shape = (1, 128)
o_t: output at time t, shape = (1, 18)
U: weights of a Linear layer, shape = (128, 57)
W: weights of a Linear layer, shape = (128, 128)
V: weights of a Linear layer, shape = (18, 128)
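With these shapes, one forward step of the cell computes s_t = tanh(U·x_t + W·s_{t-1}) and o_t = LogSoftmax(V·s_t), which is exactly what the `RNN.forward` in the code below does. A minimal sketch to check that the shapes line up (note that `nn.Linear(in, out)` stores its weight as (out, in), matching the shapes listed above):

```python
import torch
import torch.nn as nn

x_t = torch.randn(1, 57)        # input at time t
s_prev = torch.zeros(1, 128)    # previous hidden state

U = nn.Linear(57, 128)          # weight shape (128, 57)
W = nn.Linear(128, 128)         # weight shape (128, 128)
V = nn.Linear(128, 18)          # weight shape (18, 128)

s_t = torch.tanh(U(x_t) + W(s_prev))      # (1, 128)
o_t = torch.log_softmax(V(s_t), dim=1)    # (1, 18), log-probabilities over 18 classes

print(s_t.shape, o_t.shape)
```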

The code is as follows:

# -*- coding: utf-8 -*-
"""
# @file name  : rnn_demo.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2019-12-09
# @brief      : RNN name classification
"""
from io import open
import glob
import unicodedata
import string
import math
import os
import time
import torch.nn as nn
import torch
import random
import matplotlib.pyplot as plt
import torch.utils.data
import sys

# resolve the project root path
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)

from tools.common_tools import set_seed

set_seed(1)  # set the random seed
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# choose the device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")


# Read a file and split into lines
def readLines(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]


def unicodeToAscii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn'
                   and c in all_letters)


# Find letter index from all_letters, e.g. "a" = 0
def letterToIndex(letter):
    return all_letters.find(letter)


# Just for demonstration, turn a letter into a <1 x n_letters> Tensor
def letterToTensor(letter):
    tensor = torch.zeros(1, n_letters)
    tensor[0][letterToIndex(letter)] = 1
    return tensor


# Turn a line into a <line_length x 1 x n_letters>,
# or an array of one-hot letter vectors
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor


def categoryFromOutput(output):
    top_n, top_i = output.topk(1)
    category_i = top_i[0].item()
    return all_categories[category_i], category_i


def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]


def randomTrainingExample():
    category = randomChoice(all_categories)                 # pick a category
    line = randomChoice(category_lines[category])           # pick a sample
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
    line_tensor = lineToTensor(line)    # str to one-hot
    return category, line, category_tensor, line_tensor


def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


# Just return an output given a line
def evaluate(line_tensor):
    hidden = rnn.initHidden()
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)
    return output


def predict(input_line, n_predictions=3):
    print('\n> %s' % input_line)
    with torch.no_grad():
        output = evaluate(lineToTensor(input_line))
        # Get top N categories
        topv, topi = output.topk(n_predictions, 1, True)
        for i in range(n_predictions):
            value = topv[0][i].item()
            category_index = topi[0][i].item()
            print('(%.2f) %s' % (value, all_categories[category_index]))


def get_lr(iter, learning_rate):
    lr_iter = learning_rate if iter < n_iters else learning_rate*0.1
    return lr_iter


# define the network
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.u = nn.Linear(input_size, hidden_size)
        self.w = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, output_size)
        self.tanh = nn.Tanh()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, hidden):
        u_x = self.u(inputs)
        hidden = self.w(hidden)
        hidden = self.tanh(hidden + u_x)
        output = self.softmax(self.v(hidden))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


def train(category_tensor, line_tensor):
    hidden = rnn.initHidden()
    rnn.zero_grad()
    line_tensor = line_tensor.to(device)
    hidden = hidden.to(device)
    category_tensor = category_tensor.to(device)
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)
    loss = criterion(output, category_tensor)
    loss.backward()
    # Add parameters' gradients to their values, multiplied by learning rate
    for p in rnn.parameters():
        # p.data.add_(-learning_rate, p.grad.data)  # this call form is deprecated
        p.data.add_(p.grad.data, alpha=-learning_rate)
    return output, loss.item()


if __name__ == "__main__":
    print(device)

    # config
    data_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rnn_data", "names"))
    if not os.path.exists(data_dir):
        raise Exception("\n{} does not exist. Please download 08-05-数据-20200724.zip into\n{} and unzip it".format(
            data_dir, os.path.dirname(data_dir)))
    path_txt = os.path.join(data_dir, "*.txt")
    all_letters = string.ascii_letters + " .,;'"
    n_letters = len(all_letters)    # 52 letters + 5 punctuation marks = 57 characters
    print_every = 5000
    plot_every = 5000
    learning_rate = 0.005
    n_iters = 200000

    # step 1: data
    # Build the category_lines dictionary, a list of names per language
    category_lines = {}
    all_categories = []
    for filename in glob.glob(path_txt):
        category = os.path.splitext(os.path.basename(filename))[0]
        all_categories.append(category)
        lines = readLines(filename)
        category_lines[category] = lines
    n_categories = len(all_categories)

    # step 2: model
    n_hidden = 128
    rnn = RNN(n_letters, n_hidden, n_categories)
    rnn.to(device)

    # step 3: loss
    criterion = nn.NLLLoss()

    # step 4: optimization is done by hand in train()

    # step 5: iteration
    current_loss = 0
    all_losses = []
    start = time.time()
    for iter in range(1, n_iters + 1):
        # sample
        category, line, category_tensor, line_tensor = randomTrainingExample()

        # training
        output, loss = train(category_tensor, line_tensor)
        current_loss += loss

        # Print iter number, loss, name and guess
        if iter % print_every == 0:
            guess, guess_i = categoryFromOutput(output)
            correct = '✓' if guess == category else '✗ (%s)' % category
            print('Iter: {:<7} time: {:>8s} loss: {:.4f} name: {:>10s}  pred: {:>8s} label: {:>8s}'.format(
                iter, timeSince(start), loss, line, guess, correct))

        # Add current loss avg to list of losses
        if iter % plot_every == 0:
            all_losses.append(current_loss / plot_every)
            current_loss = 0

    path_model = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rnn_state_dict.pkl"))
    torch.save(rnn.state_dict(), path_model)

    plt.plot(all_losses)
    plt.show()

    predict('Yue Tingsong')
    predict('Yue tingsong')
    predict('yutingsong')
    predict('test your name')

