pytorch使用SVM实现文本分类

server/2025/1/30 23:47:23/

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn import metrics# 1. 数据准备(中文文本)
texts = ["今天的足球比赛非常激烈,球队表现出色,最终赢得了比赛。","NBA比赛今天开打,球员们的表现非常精彩,球迷们热情高涨。","张艺谋的新电影上映了,票房成绩非常好,观众反响热烈。","娱乐圈最近又出了一些新闻,明星们的私生活成了大家讨论的焦点。","昨晚的篮球赛真是太精彩了,球员们的进攻和防守都非常强硬。","李宇春在最新的音乐会上演出了她的新歌,现场观众反应热烈。","今年的世界杯比赛激烈异常,球队之间的竞争越来越激烈。","最近的综艺节目非常火,明星嘉宾的表现让观众们大笑不已。"
]# 标签:0表示体育,1表示娱乐
labels = [0, 0, 1, 1, 0, 1, 0, 1]# 2. 数据预处理:中文分词和 TF-IDF 特征提取
def jieba_cut(text):return " ".join(jieba.cut(text))texts_cut = [jieba_cut(text) for text in texts]vectorizer = TfidfVectorizer(max_features=10000)
X_tfidf = vectorizer.fit_transform(texts_cut).toarray()
y = np.array(labels)# 3. 数据集分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_tfidf, y, test_size=0.2, random_state=42)# 4. PyTorch 数据加载
class NewsGroupDataset(torch.utils.data.Dataset):def __init__(self, features, labels):self.features = torch.tensor(features, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.features)def __getitem__(self, idx):return self.features[idx], self.labels[idx]train_dataset = NewsGroupDataset(X_train, y_train)
test_dataset = NewsGroupDataset(X_test, y_test)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)# 5. 定义 SVM 模型(使用线性层)
class SVM(nn.Module):def __init__(self, input_dim, output_dim):super(SVM, self).__init__()self.fc = nn.Linear(input_dim, output_dim)def forward(self, x):return self.fc(x)# 6. 获取特征数并初始化模型
input_dim = X_tfidf.shape[1]  # 自动获取特征数
model = SVM(input_dim=input_dim, output_dim=2)  # 使用特征数量设置输入维度
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 优化器,调整学习率# 7. 训练模型
num_epochs = 50  # 增加训练周期for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader)}, Accuracy: {100 * correct / total}%")# 8. 测试模型
model.eval()
correct = 0
total = 0with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Test Accuracy: {100 * correct / total}%")# 9. 输出性能指标
y_pred = []
y_true = []with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)y_pred.extend(predicted.numpy())y_true.extend(labels.numpy())print(metrics.classification_report(y_true, y_pred))# 10. 测试新样本
def predict(text, model, vectorizer):# 1. 进行分词text_cut = jieba_cut(text)# 2. 将文本转为 TF-IDF 特征向量text_tfidf = vectorizer.transform([text_cut]).toarray()# 3. 转换为 PyTorch 张量text_tensor = torch.tensor(text_tfidf, dtype=torch.float32)# 4. 模型预测model.eval()  # 设置模型为评估模式with torch.no_grad():output = model(text_tensor)_, predicted = torch.max(output.data, 1)# 5. 返回预测结果return predicted.item()# 测试一个新的中文文本
new_text = "今天的篮球比赛真是太精彩了,球员们的表现让大家都为之喝彩。"
predicted_label = predict(new_text, model, vectorizer)# 输出预测结果
if predicted_label == 0:print("预测类别: 体育")
else:print("预测类别: 娱乐")

1. 数据准备

  • 文本数据:我们定义了一个包含中文文本的列表,每条文本表示一个新闻或评论。
  • 标签:为每条文本分配了一个标签,0 代表“体育”,1 代表“娱乐”。

2. 数据预处理

  • 中文分词:使用 jieba 库对每条文本进行分词,并将分词后的结果连接成字符串。这是处理中文文本时的常见做法。
  • TF-IDF 特征提取:使用 TfidfVectorizer 将文本转化为数值特征。TF-IDF 是一种常见的文本表示方式,能够衡量单词在文档中的重要性。

3. 数据集分割

  • 使用 train_test_split 将数据分为训练集和测试集。80% 的数据用于训练,20% 用于测试。

4. PyTorch 数据加载

  • 定义 Dataset 类:创建了一个自定义的 NewsGroupDataset 类,继承自 torch.utils.data.Dataset,用于将文本特征和标签封装为 PyTorch 可用的数据集格式。
  • DataLoader:使用 DataLoader 将训练集和测试集数据进行批处理和加载。

5. 模型定义

  • 定义了一个简单的线性 SVM 模型。实际上,使用了一个线性层 (nn.Linear) 来进行分类,输入是文本的 TF-IDF 特征,输出是两个类别(体育或娱乐)。
  • 使用了 CrossEntropyLoss 作为损失函数,因为这是分类任务中常用的损失函数。
  • 优化器使用了随机梯度下降(SGD),并设置了学习率为 0.01。

6. 训练过程

  • 训练模型的过程包括:前向传播(计算输出),计算损失,反向传播(更新参数),并在每个 epoch 后输出损失和准确率。
  • 每个 batch 的训练过程中,模型会通过计算损失并进行优化,逐步提升准确率。

7. 测试和评估

  • 在测试过程中,将模型设置为评估模式 (model.eval()),并计算测试集上的准确率。通过比较预测标签与真实标签,计算正确的预测数量并输出准确率。
  • 使用 classification_report 来输出精确度、召回率和 F1 分数等更多的评估指标。

8. 预测新样本

  • 定义了一个 predict() 函数,用于预测新的文本样本的分类
  • 预测过程包括:对文本进行分词,转化为 TF-IDF 特征,传入模型进行前向传播,最后返回模型预测的标签。

9. 输出

  • 代码的最后,会输出模型对新文本的预测结果,标明是属于体育类别还是娱乐类别。

关键技术点:

  • 中文分词:使用 jieba 对中文文本进行分词处理,这对于中文文本的处理至关重要。
  • TF-IDF:将文本转换为数值特征,便于模型处理。TF-IDF 是基于单词在文档中的出现频率及其在整个语料中的稀有度进行加权的。
  • 模型训练与评估:通过多轮训练提升模型准确度,使用测试集来评估模型的泛化能力。
  • PyTorch DataLoader:通过 DataLoader 高效地处理训练集和测试集,进行批处理和自动化管理。

http://www.ppmy.cn/server/163654.html

相关文章

Linux学习笔记——网络管理命令

一、网络基础知识 TCP/IP四层模型 以太网地址(MAC地址): 段16进制数据 IP地址: 子网掩码: 二、接口管命令 ip命令:字符终端,立即生效,重启配置会丢失 nmcli命令:字符…

【暴力洗盘】的实战技术解读-北玻股份和三变科技

龙头的上攻与回调动作都是十分惊人的。不惊人不足以吸引投资者的关注,不惊人也就不能成为龙头了。 1.建筑节能概念--北玻股份 建筑节能,是指在建筑材料生产、房屋建筑和构筑物施工及使用过程中,满足同等需要或达到相同目的的条件下&#xf…

Github 2025-01-26 php开源项目日报Top10

根据Github Trendings的统计,今日(2025-01-26统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量PHP项目10Blade项目1Laravel:表达力和优雅的 Web 应用程序框架 创建周期:4631 天开发语言:PHP, BladeStar数量:75969 个Fork数量:24281 次…

【信息系统项目管理师-选择真题】2006下半年综合知识答案和详解

更多内容请见: 备考信息系统项目管理师-专栏介绍和目录 文章目录 【第1题】【第2题】【第3题】【第4题】【第5题】【第6题】【第7题】【第8题】【第9题】【第10题】【第11题】【第12题】【第13题】【第14题】【第15题】【第16题】【第17题】【第18题】【第19题】【第20题】【第…

【MySQL】C# 连接MySQL

C# 连接MySQL 1. 添加MySQL引用 安装完MySQL之后,在安装的默认目录 C:Program Files (x86)MySQLConnector NET 8.0 中查找MySQLData.dll文件。 在Visual Studio 中为项目中添加引用。 2. 引入命名空间 using MySql.Data.MySqlClient;3. 构建连接 private sta…

Nuitka打包python脚本

Python脚本打包 Python是解释执行语言,需要解释器才能运行代码,这就导致在开发机上编写的代码在别的电脑上无法直接运行,除非目标机器上也安装了Python解释器,有时候还需要额外安装Python第三方包,相当麻烦。 事实上P…

使用 Docker + Nginx + Certbot 实现自动化管理 SSL 证书

使用 Docker Nginx Certbot 实现自动化管理 SSL 证书 在互联网安全环境日益重要的今天,为站点或应用部署 HTTPS 已经成为一种常态。然而,手动申请并续期证书既繁琐又容易出错。本文将以 Nginx Certbot 为示例,基于 Docker 容器来搭建一个…

美颜技术开发实战:美颜滤镜SDK的性能优化与兼容性解决方案

本篇文章,小编将深入探讨美颜滤镜SDK的性能优化策略,并提供针对不同平台的兼容性解决方案,助力开发者打造高效稳定的美颜体验。 一、美颜滤镜SDK性能优化策略 在美颜处理过程中,图像处理的计算量大,涉及磨皮、美白、…