NLP(8)--利用RNN实现多分类任务

server/2024/10/9 17:27:36/

前言

仅记录学习过程,有问题欢迎讨论

循环神经网络RNN(recurrent neural network):
  • 主要思想:将整个序列划分成多个时间步,将每一个时间步的信息依次输入模型,同时将模型输出的结果传给下一个时间步
  • 自带了tanh的激活函数

代码

发现RNN效率高很多

import json
import randomimport numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.utils.data as Data"""
构建一个 用RNN实现的 判断某个字符的位置 的任务5 分类任务 判断 a出现的位置 返回index +1 or -1
"""class TorchModel(nn.Module):def __init__(self, sentence_length, hidden_size, vocab, input_dim, output_size):super(TorchModel, self).__init__()#self.emb = nn.Embedding(len(vocab) + 1, input_dim)self.rnn = nn.RNN(input_dim, hidden_size, batch_first=True)self.pool = nn.MaxPool1d(sentence_length)self.leaner = nn.Linear(hidden_size, output_size)self.loss = nn.functional.cross_entropydef forward(self, x, y=None):# x = 15 * 4x = self.emb(x)  # output = 15 * 4 * 10x, h = self.rnn(x)  # output = 15 * 4 * 20 h = 1*15*20x = self.pool(x.transpose(1, 2)).squeeze()  # output = 15 * 20 * (1,被去除)y_pred = self.leaner(x)  # output = 15 * 5if y is not None:return self.loss(y_pred, y)else:return y_pred# 创建字符集 只有6个 希望a出现的概率大点def build_vocab():chars = "abcdef"vocab = {}for index, char in enumerate(chars):vocab[char] = index + 1# vocab['unk'] = len(vocab) + 1return vocab# 构建样本集
def build_dataset(vocab, data_size, sentence_length):dataset_x = []dataset_y = []for i in range(data_size):x, y = build_simple(vocab, sentence_length)dataset_x.append(x)dataset_y.append(y)return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)# 构建样本
def build_simple(vocab, sentence_length):# 随机生成 长度为4的字符串x = [random.choice(list(vocab.keys())) for _ in range(sentence_length)]if x.count('a') != 0:y = x.index('a')else:y = 4# 转化为 数字x = [vocab[char] for char in list(x)]return x, ydef main():batch_size = 15simple_size = 500vocab = build_vocab()# 每个样本的长度为4sentence_length = 4# 样本的向量维度为10input_dim = 10# rnn的隐藏层 随便设置为20hidden_size = 20# 5 分类任务output_size = 5# 学习率lr = 0.02# 轮次epoch_size = 25model = TorchModel(sentence_length, hidden_size, vocab, input_dim, output_size)# 优化函数optim = torch.optim.Adam(model.parameters(), lr=lr)# 样本x, y = build_dataset(vocab, simple_size, sentence_length)dataset = Data.TensorDataset(x, y)dataiter = Data.DataLoader(dataset, batch_size)for epoch in range(epoch_size):epoch_loss = []model.train()for x, y_true in dataiter:loss = model(x, y_true)loss.backward()optim.step()optim.zero_grad()epoch_loss.append(loss.item())print("第%d轮 loss = %f" % (epoch + 1, np.mean(epoch_loss)))# evaluateacc = evaluate(model, vocab, sentence_length)  # 测试本轮模型结果return# 评估效果
def evaluate(model, vocab, sentence_length):model.eval()x, y = build_dataset(vocab, 200, sentence_length)correct, wrong = 0, 0with torch.no_grad():y_pred = model(x)for y_p, y_t in zip(y_pred, y):  # 与真实标签进行对比if int(torch.argmax(y_p)) == int(y_t):correct += 1  # 正样本判断正确else:wrong += 1print("正确预测个数:%d / %d, 正确率:%f" % (correct, correct + wrong, correct / (correct + wrong)))return correct / (correct + wrong)if __name__ == '__main__':main()

可以对model 优化一下

 def __init__(self, sentence_length, hidden_size, vocab, input_dim, output_size):super(TorchModel, self).__init__()# Embedding 层 变为稀疏self.emb = nn.Embedding(len(vocab) + 1, input_dim)self.rnn = nn.RNN(input_dim, input_dim, batch_first=True)self.pool = nn.AvgPool1d(sentence_length)self.leaner = nn.Linear(input_dim, sentence_length + 1)self.loss = nn.functional.cross_entropydef forward(self, x, y=None):# x = 15 * 4x = self.emb(x)  # output = 15 * 4 * 10x, h = self.rnn(x)  # output = 15 * 4 * 20 h = 1*15*20# x = self.pool(x.transpose# (1, 2)).squeeze()  # output = 15 * 20 * (1,被去除)# rnn 最后一维度包含之前所有信息h = h.squeeze()y_pred = self.leaner(h)  # output = 15 * 5if y is not None:return self.loss(y_pred, y)else:return y_pred

http://www.ppmy.cn/server/16970.html

相关文章

分类算法——集成学习方法之随机森林(六)

集成学习方法 集成学习通过建立几个模型组合的来解决单一预测问题。它的工作原理是生成多个分类 器/模型,各自独立地学习和作出预测。这些预测最后结合成组合预测,因此优于任何一 个单分类的做出预测。 随机森林 在机器学习中,随机森林是一…

SpringBoot - java.lang.NoClassDefFoundError: XXX

问题描述 以 json-path 为例:java.lang.NoClassDefFoundError: com/jayway/jsonpath/Configuration 原因分析 编译不报错,但是运行时报错。 遇到这样类似的问题,首先就要想到是不是 Jar 包冲突引起的,或者引入的不是理想的 Jar…

市场投放用户获取方面如何做数据分析

常用数据分析指标 1. 基础指标 下载量: 指通过广告投放带来的下载安装量。 安装率: 指广告点击后下载安装的用户占比。 激活率: 指下载安装后启动应用的用户占比。为了防止假量和刷量,一般会把激活动作定义得更严格更深层一些。比如用户浏览30秒,用户…

全面解读CMS系统:核心技术、架构设计与应用实践

引言 内容管理系统(Content Management System, CMS)作为一种广泛应用的软件平台,以其强大的内容创建、编辑、发布和管理功能,极大地简化了网站、移动应用以及各类数字媒体的内容运营工作。本文将从核心技术、架构设计以及实际应用…

4月26日 阶段性学习汇报

1.毕业设计与毕业论文 毕业设计已经弄完,加入了KNN算法,实现了基于四种常见病的判断,毕业论文写完,格式还需要调整,下周一发给指导老师初稿。目前在弄答辩ppt(25%)。25号26号两天都在参加校运会…

森林消防隔膜泵的应用与前景——恒峰智慧科技

随着全球气候变暖,森林火灾频发,给生态环境和人类安全带来严重威胁。为有效应对这一挑战,森林消防领域不断引入新技术、新装备。其中,隔膜泵作为一种高效、可靠的消防设备,正逐渐受到广泛关注。本文将探讨森林消防隔膜…

分享基于鸿蒙OpenHarmony的Unity团结引擎应用开发赛

该赛题旨在鼓励更多开发者基于OpenHarmony4.x版本,使用团结引擎创造出精彩的游戏与应用。本次大赛分为“创新游戏”与“创新3D 化应用”两大赛道,每赛道又分“大众组”与“高校组”,让不同背景的开发者同台竞技。无论你是游戏开发者&#xff…

hbase建表时设置预分区

一.hbase rowkey设计的原则 遵循唯一性,散列,不应过长等原则 二.rowkey常用的设计 1.reverse反转 2.salt加盐 3.hash散列 三.hbase建表预分区,指定3个rowkey,分成4个region 在Hbase中,预分区是一种优化手段,用于在创建表时提前规划好Region的分布,以提高数据写入的效率和查询…