pytorch逻辑回归实现垃圾邮件检测

server/2025/2/3 23:26:44/

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np# 增强的数据集:更多的垃圾邮件与正常邮件样本
X = ["Congratulations! You've won a $1000 gift card. Claim it now!","Dear friend, I hope you are doing well. Let's catch up soon.","Urgent: Your bank account has been compromised. Please contact support immediately.","Hello, just wanted to confirm our meeting at 2 PM today.","You have a new message from your friend. Click here to read.","Get a free iPhone now! Limited offer, click here.","Last chance to claim your prize, you won $500!","Meeting scheduled for tomorrow. Please confirm.","Hello! You are invited to an exclusive event!","Click here to get free lottery tickets. Hurry up!","Reminder: Your subscription will expire soon, renew now.","Don't forget to submit your report by end of day today."
]
y = [1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0]  # 1 为垃圾邮件,0 为正常邮件# 使用 TfidfVectorizer 进行文本向量化
vectorizer = TfidfVectorizer(stop_words='english')  # 去除停用词
X_vec = vectorizer.fit_transform(X).toarray()# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_vec, y, test_size=0.33, random_state=42)# 定义逻辑回归模型
class LogisticRegressionModel(nn.Module):def __init__(self, input_dim):super(LogisticRegressionModel, self).__init__()self.fc = nn.Linear(input_dim, 1)  # 线性层,输入维度是特征的数量,输出是1def forward(self, x):return torch.sigmoid(self.fc(x))  # 使用sigmoid激活函数输出0到1之间的概率# 定义训练过程
def train_model(model, X_train, y_train, num_epochs=200, learning_rate=0.001):criterion = nn.BCELoss()  # 二分类交叉熵损失optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 使用Adam优化器X_train_tensor = torch.tensor(X_train, dtype=torch.float32)y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train_tensor)loss = criterion(outputs, y_train_tensor)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 测试模型
def evaluate_model(model, X_test, y_test):model.eval()X_test_tensor = torch.tensor(X_test, dtype=torch.float32)y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)with torch.no_grad():outputs = model(X_test_tensor)predictions = (outputs >= 0.5).float()  # 阈值设为0.5accuracy = accuracy_score(y_test, predictions.numpy())print(f'Accuracy: {accuracy * 100:.2f}%')# 训练并评估模型
input_dim = X_train.shape[1]  # 输入特征的数量
model = LogisticRegressionModel(input_dim)
train_model(model, X_train, y_train, num_epochs=200, learning_rate=0.001)
evaluate_model(model, X_test, y_test)# 预测新邮件
def predict(model, new_email):model.eval()new_email_vec = vectorizer.transform([new_email]).toarray()new_email_tensor = torch.tensor(new_email_vec, dtype=torch.float32)with torch.no_grad():prediction = model(new_email_tensor)return "Spam" if prediction >= 0.5 else "Not Spam"# 检测新邮件
email_1 = "Congratulations! You have a limited time offer for a free cruise."
email_2 = "Hi, let's discuss the project updates tomorrow."print(f"Email 1: {predict(model, email_1)}")  # 可能输出:Spam
print(f"Email 2: {predict(model, email_2)}")  # 可能输出:Not Spam
1. 数据预处理
  • 准备数据集:包含垃圾邮件(Spam)和正常邮件(Not Spam)。
  • 文本向量化:使用 TfidfVectorizer 将文本转换为数值特征,使模型能够处理。
  • 去除停用词:排除无意义的常见词(如 "the", "is", "and"),提高模型性能。
2. 训练集与测试集划分
  • 将数据集拆分为训练集和测试集,以 67% 训练,33% 测试,保证模型有足够数据训练,同时可以评估其泛化能力。
3. 逻辑回归模型
  • 搭建 PyTorch 逻辑回归模型
    • 采用 nn.Linear() 构建一个单层神经网络(输入为文本特征,输出为 1 个数值)。
    • 使用 sigmoid 作为激活函数,将输出转换为 0-1 之间的概率值。
4. 训练模型
  • 定义损失函数:使用二元交叉熵损失 (BCELoss),适用于二分类问题。
  • 优化器:采用 Adam 优化器,以 0.001 学习率进行参数优化。
  • 训练流程
    1. 计算前向传播的输出。
    2. 计算损失值,衡量预测结果与真实标签的差距。
    3. 进行反向传播,更新权重参数。
    4. 迭代多轮(如 200 轮),不断优化模型。
5. 评估模型
  • 将测试数据输入模型,预测结果并与真实标签进行对比。
  • 计算准确率,评估模型在未见过的数据上的表现。
6. 预测新邮件
  • 将新邮件转换为数值特征(与训练时相同的方法)。
  • 使用训练好的模型进行预测
  • 阈值判断:如果输出概率 ≥ 0.5,则判断为垃圾邮件,否则为正常邮件。

http://www.ppmy.cn/server/164737.html

相关文章

mybatis(78/134)

前天学了很多&#xff0c;关于java的反射机制&#xff0c;其实跳过了new对象&#xff0c;然后底层生成了字节码&#xff0c;创建了对应的编码。手搓了一遍源码&#xff0c;还是比较复杂的。 <?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE …

第十三章 I 开头的术语

文章目录 第十三章 I 开头的术语安装目录 (install-dir)实例 (instance)实例认证 (Instance Authentication)实例方法 (instance method)实例化 (instantiate)中间源代码 (intermediate source code)InterSystems IRIS 启动器 (InterSystems IRIS launcher)InterSystems IRIS 数…

C基础寒假练习(4)

输入带空格的字符串&#xff0c;求单词个数、 #include <stdio.h> // 计算字符串长度的函数 size_t my_strlen(const char *str) {size_t len 0;while (str[len] ! \0) {len;}return len; }int main() {char str[100];printf("请输入一个字符串: ");fgets(…

女生年薪12万,算不算属于高收入人群

在繁华喧嚣的都市中&#xff0c;我们时常会听到关于收入、高薪与生活质量等话题的讨论。尤其是对于年轻女性而言&#xff0c;薪资水平不仅关乎个人价值的体现&#xff0c;更直接影响到生活质量与未来的规划。那么&#xff0c;女生年薪12万&#xff0c;是否可以被划入高收入人群…

《AI大模型开发笔记》DeepSeek技术创新点

一、DeepSeek横空出世 DeepSeek V3 以颠覆性技术架构创新强势破局&#xff01;革命性的上下文处理机制实现长文本推理成本断崖式下降&#xff0c;综合算力需求锐减90%&#xff0c;开启高效 AI 新纪元&#xff01; 最新开源的 DeepSeek V3模型不仅以顶尖基准测试成绩比肩业界 …

Redis|前言

文章目录 什么是 Redis&#xff1f;Redis 主流功能与应用 什么是 Redis&#xff1f; Redis&#xff0c;Remote Dictionary Server&#xff08;远程字典服务器&#xff09;。Redis 是完全开源的&#xff0c;使用 ANSIC 语言编写&#xff0c;遵守 BSD 协议&#xff0c;是一个高性…

【Elasticsearch】index:false

在 Elasticsearch 中&#xff0c;index 参数用于控制是否对某个字段建立索引。当设置 index: false 时&#xff0c;意味着该字段不会被编入倒排索引中&#xff0c;因此不能直接用于搜索查询。然而&#xff0c;这并不意味着该字段完全不可访问或没有其他用途。以下是关于 index:…

HTML5 常用事件详解

在现代 Web 开发中&#xff0c;用户交互是提升用户体验的关键。HTML5 提供了丰富的事件机制&#xff0c;允许开发者监听用户的操作&#xff08;如点击、拖动、键盘输入等&#xff09;&#xff0c;并触发相应的逻辑处理。本文将详细介绍 HTML5 中的常用事件&#xff0c;包括鼠标…