python pytorch实现RNN,LSTM,GRU,文本情感分类

news/2024/11/8 17:36:09/

python pytorch实现RNN,LSTM,GRU,文本情感分类

数据集格式:
在这里插入图片描述
有需要的可以联系我

实现步骤就是:
1.先对句子进行分词并构建词表
2.生成word2id
3.构建模型
4.训练模型
5.测试模型

代码如下:


import pandas as pd
import torch
import matplotlib.pyplot as plt
import jieba
import numpy as np"""
作业:
一、完成优化
优化思路1 jieba
2 取常用的3000字
3 修改model:rnn、lstm、gru二、完成测试代码
"""# 了解数据
dd = pd.read_csv(r'E:\peixun\data\train.csv')
# print(dd.head())# print(dd['label'].value_counts())# 句子长度分析
# 确定输入句子长度为 500
text_len = [len(i) for i in dd['text']]
# plt.hist(text_len)
# plt.show()
# print(max(text_len), min(text_len))# 基本参数 config
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('my device:', DEVICE)MAX_LEN = 500
BATCH_SIZE = 16
EPOCH = 1
LR = 3e-4# 构建词表 word2id
vocab = []
for i in dd['text']:vocab.extend(jieba.lcut(i, cut_all=True))  # 使用 jieba 分词# vocab.extend(list(i))vocab_se = pd.Series(vocab)
print(vocab_se.head())
print(vocab_se.value_counts().head())vocab = vocab_se.value_counts().index.tolist()[:3000]  # 取频率最高的 3000 token
# print(vocab[:10])
# exit()WORD_PAD = "<PAD>"
WORD_UNK = "<UNK>"
WORD_PAD_ID = 0
WORD_UNK_ID = 1vocab = [WORD_PAD, WORD_UNK] + list(set(vocab))print(vocab[:10])
print(len(vocab))vocab_dict = {k: v for v, k in enumerate(vocab)}# 词表大小,vocab_dict: word2id; vocab: id2word
print(len(vocab_dict))import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import pandas as pd# 定义数据集 Dataset
class Dataset(data.Dataset):def __init__(self, split='train'):# ChnSentiCorp 情感分类数据集path =  r'E:/peixun/data/' + str(split) + '.csv'self.data = pd.read_csv(path)def __len__(self):return len(self.data)def __getitem__(self, i):text = self.data.loc[i, 'text']label = self.data.loc[i, 'label']return text, label# 实例化 Dataset
dataset = Dataset('train')# 样本数量
print(len(dataset))
print(dataset[0])# 句子批处理函数
def collate_fn(batch):# [(text1, label1), (text2, label2), (3, 3)...]sents = [i[0][:MAX_LEN] for i in batch]labels = [i[1] for i in batch]inputs = []# masks = []for sent in sents:sent = [vocab_dict.get(i, WORD_UNK_ID) for i in list(sent)]pad_len = MAX_LEN - len(sent)# mask = len(sent) * [1] + pad_len * [0]# masks.append(mask)sent += pad_len * [WORD_PAD_ID]inputs.append(sent)# 只使用 lstm 不需要用 masks# masks = torch.tensor(masks)# print(inputs)inputs = torch.tensor(inputs)labels = torch.LongTensor(labels)return inputs.to(DEVICE), labels.to(DEVICE)# 测试 loader
loader = data.DataLoader(dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=True,drop_last=False)inputs, labels = iter(loader).__next__()
print(inputs.shape, labels)# 定义模型
class Model(nn.Module):def __init__(self, vocab_size=5000):super().__init__()self.embed = nn.Embedding(vocab_size, 100, padding_idx=WORD_PAD_ID)# 多种 rnnself.rnn = nn.RNN(100, 100, 1, batch_first=True, bidirectional=True)self.gru = nn.GRU(100, 100, 1, batch_first=True, bidirectional=True)self.lstm = nn.LSTM(100, 100, 1, batch_first=True, bidirectional=True)self.l1 = nn.Linear(500 * 100 * 2, 100)self.l2 = nn.Linear(100, 2)def forward(self, inputs):out = self.embed(inputs)out, _ = self.lstm(out)out = out.reshape(BATCH_SIZE, -1)  # 16 * 100000out = F.relu(self.l1(out))  # 16 * 100out = F.softmax(self.l2(out))  # 16 * 2return out# 测试 Model
model = Model()
print(model)# 模型训练
dataset = Dataset()
loader = data.DataLoader(dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=True)model = Model().to(DEVICE)# 交叉熵损失
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)model.train()
for e in range(EPOCH):for idx, (inputs, labels) in enumerate(loader):# 前向传播,计算预测值out = model(inputs)# 计算损失loss = loss_fn(out, labels)# 反向传播,计算梯度loss.backward()# 参数更新optimizer.step()# 梯度清零optimizer.zero_grad()if idx % 10 == 0:out = out.argmax(dim=-1)acc = (out == labels).sum().item() / len(labels)print('>>epoch:', e,'\tbatch:', idx,'\tloss:', loss.item(),'\tacc:', acc)# 模型测试
test_dataset = Dataset('test')
test_loader = data.DataLoader(test_dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=False)loss_fn = nn.CrossEntropyLoss()out_total = []
labels_total = []model.eval()
for idx, (inputs, labels) in enumerate(test_loader):out = model(inputs)loss = loss_fn(out, labels)out_total.append(out)labels_total.append(labels)if idx % 50 == 0:print('>>batch:', idx, '\tloss:', loss.item())correct=0
sumz=0
for i in range(len(out_total)):out = out_total[i].argmax(dim=-1)correct = (out == labels_total[i]).sum().item() +correctsumz=sumz+len(labels_total[i])#acc = (out_total == labels_total).sum().item() / len(labels_total)print('>>acc:', correct/sumz)

运行结果如下:
在这里插入图片描述


http://www.ppmy.cn/news/1251721.html

相关文章

封装一些可能会用到的JS的Dom操作方法(非JS自带的方法)

1. 父元素节点下的子元素节点逆序 HTMLElement.prototype.childRevers function () {var all_num this.childElementCount;if (all_num) {while(all_num--){this.appendChild(this.children[all_num]);}} } // 获取 ul 父节点对象 var oul document.getElementsByTagName(u…

简单好用!日常写给 ChatGPT 的几个提示词技巧

ChatGPT 很强&#xff0c;但是有时候又显得很蠢&#xff0c;下面是使用 GPT4 的一个实例&#xff1a; 技巧一&#xff1a;三重冒号 """ 引用内容使用三重冒号 """&#xff0c;让 ChatGPT 清晰引用的内容&#xff1a; 技巧二&#xff1a;角色设定…

MySQL之binlog日志

聊聊BINLOG binlog记录什么&#xff1f; MySQL server中所有的搜索引擎发生了更新&#xff08;DDL和DML&#xff09;都会产生binlog日志&#xff0c;记录的是语句的原始逻辑 为什么需要binlog&#xff1f; binlog主要有两个应用场景&#xff0c;一是数据复制&#xff0c;在…

整数反转 Golang leecode_7

拿到手第一反应还是暴力&#xff0c;直接从低位到高位把数一个个取出来&#xff0c;然后乘以每一位的权重&#xff0c;构成一个新的反转后的整数 res 返回&#xff0c;代码如下 package mainimport ("fmt""math" )func reverse(x int) int {if x > -10…

一篇文章带你掌握MongoDB

文章目录 1. 前言2. MongoDB简介3. MongoDB与关系型数据库的对比4. MongoDB的安装5. Compass的使用6. MongoDB的常用语句7. 总结 1. 前言 本文旨在帮助大家快速了解MongoDB,快速了解和掌握MongoDB的干货内容. 2. MongoDB简介 MongoDB是一种NoSQL数据库&#xff0c;采用了文档…

5、DMA Demo(STM32F407)

DMA简介 DMA 全称Direct Memory Access&#xff0c;即直接存储器访问。 DMA传输将数据从一个地址空间复制到另一个地址空间。当CPU初始化这个传输动作&#xff0c;传输动作本身是由DMA控制器来实现和完成的。 DMA传输方式无需CPU直接控制传输&#xff0c;也没有中断处理方式那…

(C++20) consteval立即函数

文章目录 由来consteval立即函数上下文的常量性质lambda表达式 编译期间确定无法获取函数指针查看汇编 END 由来 在C11中推出了constexpr使得对象或者函数能够具有常量性质并能在编译器确定。但是对于constexpr修饰的函数来说&#xff0c;无法保证严格的在编译器确定。 下面这…

2023.11.28 使用tensorflow进行“三好“权重分析

2023.11.28 使用tensorflow进行"三好"权重分析 这是最基础的一个神经网络问题。许久没有再使用&#xff0c;用来做恢复训练比较好。 x1w1 x2w2 x3*w3 y&#xff0c;已知x1,x2,x3和y&#xff0c;求w1,w2,w3 这是一个三元一次方程&#xff0c;正常需要三组数据就能…