Day26下 - BERT项目实战

devtools/2024/12/23 23:22:27/

BERT论文:https://arxiv.org/pdf/1810.04805

BERT架构:

 

BERT实战 

1. 读取数据

# pandas 适合表格类数据读取
import pandas as pd
import numpy as np# sep: 分隔符
data = pd.read_csv(filepath_or_buffer="samples.tsv", sep="\t").to_numpy()# 打乱样本顺序
np.random.shuffle(data)

2. 打包数据

# 深度学习框架
import torch
# 深度学习中的封装层
from torch import nn
# 引入数据集
from torch.utils.data import Dataset
# 引入数据集加载器
from torch.utils.data import DataLoaderclass SentiDataset(Dataset):"""自定义数据集"""def __init__(self, data):"""初始化"""self.data = datadef __getitem__(self, idx):"""按索引获取单个样本"""x, y = self.data[idx]return x, ydef __len__(self):"""返回数据集中的样本个数"""return len(self.data)# 训练集(前4500个作为训练集)
train_dataset = SentiDataset(data=data[:4500])
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)
# 测试集(4500之后的作为测试集)
test_dataset = SentiDataset(data=data[4500:])
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=32)for x, y in train_dataloader:# print(x)# print(y)break

3. 构建模型

# 用于加载 BERT 分词器
from transformers import BertTokenizer
# 用于加载 BERT 序列分类器
from transformers import BertForSequenceClassification# 从 ModelScope 上下载 
# from modelscope import snapshot_download
# 设置 模型id model_id
# 设置 cache_dir 缓存目录
# model_dir = snapshot_download(model_id='tiansz/bert-base-chinese', cache_dir="./bert")# 模型地址
model_dir = "bert-base-chinese"# 加载分词器
tokenizer = BertTokenizer.from_pretrained(model_dir)tokenizerdevice = "cuda:0" if torch.cuda.is_available() else "cpu"
device# 二分类分类器
model = BertForSequenceClassification.from_pretrained(model_dir, num_labels=2)
model.to(device = device)num_params = 0
for param in model.parameters():if param.requires_grad:num_params += param.numel()
num_params# 冻结所有预训练层
# for param in model.parameters():
#     param.requires_grad=False
# 只训练分类层
# model.classifier.weight.requires_grad=True
# model.classifier.bias.requires_grad=Truenum_params = 0
for param in model.parameters():if param.requires_grad:num_params += param.numel()
num_params# 类别字典
label2idx = {"正面": 0, "负面": 1}
idx2label = {0: "正面", 1: "负面"}

4. 训练

from torch import nn
# 损失函数
loss_fn = nn.CrossEntropyLoss()
# 优化器
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)
# 定义训练轮次
epochs = 5def get_acc(dataloader):"""计算准确率"""# 设置为评估模式model.eval()accs = []# 构建一个无梯度的环境with torch.no_grad():# 逐个批次计算for X, y in train_dataloader:# 编码X = tokenizer.batch_encode_plus(batch_text_or_text_pairs=X, padding=True, truncation=True,max_length=100,return_tensors="pt")# 转张量y = torch.tensor(data=[label2idx.get(label) for label in y], dtype=torch.long).cuda()# 1. 正向传播y_pred = model(input_ids=X["input_ids"].to(device=device), attention_mask=X["attention_mask"].to(device=device))# 2. 计算准确率acc = (y_pred.logits.argmax(dim=-1) == y).to(dtype=torch.float).mean().item()accs.append(acc)return sum(accs) / len(accs)def train():"""训练过程"""# 训练之前:先看看准确率train_acc = get_acc(train_dataloader)test_acc = get_acc(test_dataloader)print(f"初始:Train_Acc: {train_acc}, Test_Acc: {test_acc}")# 遍历每一轮for epoch in range(epochs):model.train()# 遍历每个批次for X, y in train_dataloader:# 编码X = tokenizer.batch_encode_plus(batch_text_or_text_pairs=X, padding=True, truncation=True,max_length=100,return_tensors="pt")# 转张量y = torch.tensor(data=[label2idx.get(label) for label in y], dtype=torch.long).cuda()# 1. 正向传播y_pred = model(input_ids=X["input_ids"].to(device=device), attention_mask=X["attention_mask"].to(device=device))# 2. 计算损失loss = loss_fn(y_pred.logits, y)# 3. 反向传播loss.backward()# 4. 优化一步optimizer.step()# 5. 清空梯度optimizer.zero_grad()# 每轮都计算一下准备率train_acc = get_acc(train_dataloader)test_acc = get_acc(test_dataloader)print(f"Epoch: {epoch +1}, Train_Acc: {train_acc}, Test_Acc: {test_acc}")train()

5. 保存模型

# 保存训练好的模型
model.save_pretrained(save_directory="./sentiment_model")
# 保存分词器
tokenizer.save_pretrained(save_directory="./sentiment_model")

6. 预测

# 加载分词器
tokenizer = BertTokenizer.from_pretrained("./sentiment_model")
# 加载模型
model = BertForSequenceClassification.from_pretrained("./sentiment_model").cuda()def predict(text="服务挺好的"):# 设置为评估模式model.eval()with torch.no_grad():inputs = tokenizer(text=text,padding=True, truncation=True,max_length=100,return_tensors="pt")y_pred = model(input_ids=inputs["input_ids"].to(device=device), attention_mask=inputs["attention_mask"].to(device=device))y_pred = y_pred.logits.argmax(dim=-1).cpu().numpy()result = idx2label.get(y_pred[0])return resultpredict(text="有老鼠,厕所脏")


http://www.ppmy.cn/devtools/144827.html

相关文章

LeetCode hot100-89

https://leetcode.cn/problems/partition-equal-subset-sum/description/?envTypestudy-plan-v2&envIdtop-100-liked 416. 分割等和子集 已解答 中等 相关标签 相关企业 给你一个 只包含正整数 的 非空 数组 nums 。请你判断是否可以将这个数组分割成两个子集&#xff0c…

Chapter 3-1. Detecting Congestion in Fibre Channel Fabrics

Chapter 3. Detecting Congestion in Fibre Channel Fabrics This chapter covers the following topics: 本章包括以下主题: Congestion detection workflow. Congestion detection metrics. Congestion detection metrics and commands on Cisco MDS switches. Automatic A…

金碟中间件-AAS-V10.0安装

金蝶中间件AAS-V10.0 AAS-V10.0安装 1.解压AAS-v10.0安装包 unzip AAS-V10.zip2.更新license.xml cd /root/ApusicAS/aas# 这里要将license复制到该路径 [rootvdb1 aas]# ls bin docs jmods lib modules templates config domains …

CTF入门:以Hackademic-RTB1靶场为例初识夺旗

一、网络扫描 靶机ip地址为192.168.12.24 使用nmap工具进行端口扫描 nmap -sT 192.168.12.24 二、信息收集 1、80端口探索 靶机开放了80和22端口,使用浏览器访问靶机的80端口,界面如下: 点击target发现有跳转,并且url发生相应变…

python学opencv|读取图像(十八)使用cv2.line创造线段

【1】引言 前序已经完成了opencv基础知识的学习,我们已经掌握了处理视频和图像的基本操作。相关文章包括且不限于: python学opencv|读取图像(三)放大和缩小图像_python(1)使用opencv读取并显示图像;(2)使用opencv对图像进行缩放…

Git进阶:本地或远程仓库如何回滚到之前的某个commit

在Git的使用过程中,我们经常会遇到需要回滚到之前某个commit的情况。无论是为了修复错误、撤销更改,还是为了重新组织代码,回滚到特定commit都是一个非常有用的技能。本文将介绍几种常用的回滚方法,帮助读者更好地掌握Git版本控制…

libilibi项目总结(18)FFmpeg 的使用

FFmpeg工具类 import com.easylive.entity.config.AppConfig; import com.easylive.entity.constants.Constants; import org.springframework.stereotype.Component;import javax.annotation.Resource; import java.io.File; import java.math.BigDecimal;Component public c…

梯度(Gradient)和 雅各比矩阵(Jacobian Matrix)的区别和联系:中英双语

雅各比矩阵与梯度:区别与联系 在数学与机器学习中,梯度(Gradient) 和 雅各比矩阵(Jacobian Matrix) 是两个核心概念。虽然它们都描述了函数的变化率,但应用场景和具体形式有所不同。本文将通过…