代码功能
定义模型:
TabTransformer 类继承自 nn.Module,包含嵌入层、Transformer 编码器和全连接层。
前向传播方法 forward 中,输入数据经过嵌入层、Transformer 编码器和全连接层,最终输出分类结果。
定义数据集类:
SimpleDataset 类继承自 Dataset,用于封装数据和标签,支持索引访问。
下载并处理数据:
从 UCI 机器学习库下载 Iris 数据集。
将数据转换为 Pandas DataFrame,并处理空字符串。
将类别标签转换为数字。
数据预处理:
将数据集分为特征 X 和标签 y。
使用 train_test_split 将数据集划分为训练集和测试集。
将训练集和测试集转换为 PyTorch 张量。
创建数据加载器:
使用 DataLoader 创建训练集和测试集的数据加载器,方便批量读取数据。
初始化模型、损失函数和优化器:
初始化 TabTransformer 模型,设置输入维度、隐藏维度、注意力头数、层数和输出维度。
使用交叉熵损失函数 CrossEntropyLoss。
使用 Adam 优化器。
训练模型:
设置训练轮数 num_epochs。
在每个 epoch 中,遍历训练数据加载器,计算损失并反向传播更新模型参数。
打印每个 epoch 的损失值。
评估模型:
将模型设置为评估模式。
遍历测试数据加载器,预测标签并存储。
计算准确率、精确率、召回率和 F1 分数,并打印结果。
代码
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import requests
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score# 定义 TabTransformer 模型
class TabTransformer(nn.Module):def __init__(self, input_dim, hidden_dim, num_heads, num_layers, output_dim):super(TabTransformer, self).__init__()self.embedding = nn.Linear(input_dim, hidden_dim)self.transformer = nn.TransformerEncoder(encoder_layer=nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True),num_layers=num_layers)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):x = self.embedding(x)x = self.transformer(x.unsqueeze(1)) # 增加序列维度x = x.squeeze(1) # 移除序列维度x = self.fc(x)return x# 定义一个简单的数据集
class SimpleDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]# 下载经典数据集
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
response = requests.get(url)
data = response.text.splitlines()# 将数据转换为 DataFrame
columns = ["sepal_length", "sepal_width", "petal_length", "petal_width", "class"]
df = pd.DataFrame([line.split(',') for line in data], columns=columns)# 处理空字符串
df.replace('', pd.NA, inplace=True)
df.dropna(inplace=True)# 将类别标签转换为数字
class_mapping = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}
df['class'] = df['class'].map(class_mapping)# 将数据集分为特征和标签
X = df.drop('class', axis=1).astype(float).values
y = df['class'].values# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 转换为 PyTorch 张量
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)# 创建数据加载器
train_dataset = SimpleDataset(X_train, y_train)
test_dataset = SimpleDataset(X_test, y_test)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 初始化模型、损失函数和优化器
input_dim = X_train.shape[1]
output_dim = len(class_mapping)
model = TabTransformer(input_dim=input_dim, hidden_dim=64, num_heads=8, num_layers=2, output_dim=output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10
for epoch in range(num_epochs):model.train()for batch_data, batch_labels in train_dataloader:optimizer.zero_grad()outputs = model(batch_data)loss = criterion(outputs, batch_labels)loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 评估模型
model.eval()
all_predictions = []
all_labels = []with torch.no_grad():for batch_data, batch_labels in test_dataloader:outputs = model(batch_data)_, predicted = torch.max(outputs, 1)all_predictions.extend(predicted.numpy())all_labels.extend(batch_labels.numpy())# 计算评价指标
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions, average='weighted')
recall = recall_score(all_labels, all_predictions, average='weighted')
f1 = f1_score(all_labels, all_predictions, average='weighted')print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')