TabTransformer 模型示例

news/2024/11/24 10:19:03/

代码功能

定义模型:

TabTransformer 类继承自 nn.Module,包含嵌入层、Transformer 编码器和全连接层。
前向传播方法 forward 中,输入数据经过嵌入层、Transformer 编码器和全连接层,最终输出分类结果。

定义数据集类:

SimpleDataset 类继承自 Dataset,用于封装数据和标签,支持索引访问。

下载并处理数据:

从 UCI 机器学习库下载 Iris 数据集。
将数据转换为 Pandas DataFrame,并处理空字符串。
将类别标签转换为数字。

数据预处理:

将数据集分为特征 X 和标签 y。
使用 train_test_split 将数据集划分为训练集和测试集。
将训练集和测试集转换为 PyTorch 张量。

创建数据加载器:

使用 DataLoader 创建训练集和测试集的数据加载器,方便批量读取数据。
初始化模型、损失函数和优化器:
初始化 TabTransformer 模型,设置输入维度、隐藏维度、注意力头数、层数和输出维度。
使用交叉熵损失函数 CrossEntropyLoss。
使用 Adam 优化器。

训练模型:

设置训练轮数 num_epochs。
在每个 epoch 中,遍历训练数据加载器,计算损失并反向传播更新模型参数。
打印每个 epoch 的损失值。

评估模型:

将模型设置为评估模式。
遍历测试数据加载器,预测标签并存储。
计算准确率、精确率、召回率和 F1 分数,并打印结果。
在这里插入图片描述

代码

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import requests
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score# 定义 TabTransformer 模型
class TabTransformer(nn.Module):def __init__(self, input_dim, hidden_dim, num_heads, num_layers, output_dim):super(TabTransformer, self).__init__()self.embedding = nn.Linear(input_dim, hidden_dim)self.transformer = nn.TransformerEncoder(encoder_layer=nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True),num_layers=num_layers)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):x = self.embedding(x)x = self.transformer(x.unsqueeze(1))  # 增加序列维度x = x.squeeze(1)  # 移除序列维度x = self.fc(x)return x# 定义一个简单的数据集
class SimpleDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]# 下载经典数据集
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
response = requests.get(url)
data = response.text.splitlines()# 将数据转换为 DataFrame
columns = ["sepal_length", "sepal_width", "petal_length", "petal_width", "class"]
df = pd.DataFrame([line.split(',') for line in data], columns=columns)# 处理空字符串
df.replace('', pd.NA, inplace=True)
df.dropna(inplace=True)# 将类别标签转换为数字
class_mapping = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}
df['class'] = df['class'].map(class_mapping)# 将数据集分为特征和标签
X = df.drop('class', axis=1).astype(float).values
y = df['class'].values# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 转换为 PyTorch 张量
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)# 创建数据加载器
train_dataset = SimpleDataset(X_train, y_train)
test_dataset = SimpleDataset(X_test, y_test)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 初始化模型、损失函数和优化器
input_dim = X_train.shape[1]
output_dim = len(class_mapping)
model = TabTransformer(input_dim=input_dim, hidden_dim=64, num_heads=8, num_layers=2, output_dim=output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):model.train()for batch_data, batch_labels in train_dataloader:optimizer.zero_grad()outputs = model(batch_data)loss = criterion(outputs, batch_labels)loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 评估模型
model.eval()
all_predictions = []
all_labels = []with torch.no_grad():for batch_data, batch_labels in test_dataloader:outputs = model(batch_data)_, predicted = torch.max(outputs, 1)all_predictions.extend(predicted.numpy())all_labels.extend(batch_labels.numpy())# 计算评价指标
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions, average='weighted')
recall = recall_score(all_labels, all_predictions, average='weighted')
f1 = f1_score(all_labels, all_predictions, average='weighted')print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')

http://www.ppmy.cn/news/1549512.html

相关文章

【MySQL课程学习】:MySQL安装,MySQL如何登录和退出?MySQL的简单配置

🎁个人主页:我们的五年 🔍系列专栏:MySQL课程学习 🌷追光的人,终会万丈光芒 🎉欢迎大家点赞👍评论📝收藏⭐文章 目录 MySQL在Centos 7环境下的安装: 卸载…

CSS3_BFC(十二)

BFC MDN对BFC的解释:块格式化上下文(Block Formating Context, BFC)是web页面的可视CSS渲染的一部分,是块盒子的布局过程发生的区域,也是浮动元素与其他元素交互的区域。 1、开启BFC flow-root对内容的影响是最低的&am…

【应用介绍】FastCAE-PHengLEI流体仿真

1 风雷组件集成背景 1.1 风雷软件简介 风雷软件是中国空气动力研究与发展中心(CARDC)研发的面向流体工程的混合CFD平台。平台的建立遵循面向对象的设计理念,采用C语言编程。自2010年开始,气动中心开始着力于工程化品牌CFD软件的…

力扣第 61 题旋转链表

题目描述 给定一个链表,旋转链表,将链表中的每个节点向右移动 k k k 个位置,其中 k k k 是非负数。 示例 1: 输入: head [1,2,3,4,5], k 2 输出: [4,5,1,2,3] 解释: 向右旋转 1 步: [5,1,2,3,4] 向右旋转 2 步: [4,5,1,2,3]示例 2: 输…

LLM( Large Language Models)典型应用介绍 1 -ChatGPT Large language models

ChatGPT 是基于大型语言模型(LLM)的人工智能应用。 GPT 全称是Generative Pre-trained Transformer。-- 生成式预训练变换模型: Generative(生成式):可以根据输入生成新的文本内容,例如回答问题…

android 性能分析工具(03)Android Studio Profiler及常见性能图表解读

说明:主要解读Android Studio Profiler 和 常见性能图表。 Android Studio的Profiler工具是一套功能强大的性能分析工具集,它可以帮助开发者实时监控和分析应用的性能,包括CPU使用率、内存使用、网络活动和能耗等多个方面。以下是对Android …

应急响应靶机——linux2

载入虚拟机,打开虚拟机: 居然是没有图形化界面的那种linux,账户密码:root/Inch957821.(注意是大写的i还有英文字符的.) 查看虚拟机IP,192.168.230.10是NAT模式下自动分配的 看起来不是特别舒服&…

VSCode【下载】【安装】【汉化】【配置C++环境】【运行调试】(Windows环境)

目录 一、VSCode的下载 & 安装 二、汉化 三、配置C 一、VSCode的下载 & 安装 Download Visual Studio Code - Mac, Linux, Windowshttps://code.visualstudio.com/Download 注意!!!【不建议下载User版本,下载System版本】…