价格分类(神经网络)

server/2024/11/26 23:47:38/
# 1.导入依赖包
import timeimport torch
import torch.nn as nn
import torch.optim as optimfrom torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_splitimport numpy as np
import pandas as pd
import matplotlib.pyplot as pltfrom torchsummary import summary# 2.构建数据集
def create_dataset():# 2.1 读取数据集data = pd.read_csv('dataset/手机价格预测.csv')# 2.2 获取特征值和目标值,类型转化  特征(Float)  标签(Long)x, y = data.iloc[:, :-1], data.iloc[:, -1]x, y = x.astype(np.float32), y.astype(np.int64)# 2.3 数据集划分x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2,random_state=2)# 2.4 数据转Tensortrain_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.tensor(y_train.values))test_dataset = TensorDataset(torch.from_numpy(x_test.values), torch.tensor(y_test.values))return train_dataset, test_dataset, x_train.shape[1], len(np.unique(y))# 3. 构建模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 256)self.linear2 = nn.Linear(256, 1024)self.fc = nn.Linear(1024, output_dim)def forward(self, x):x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))output = self.fc(x)# output = torch.softmax(self.fc(x), dim=-1)return output# 4.模型训练(225)
def train(model, train_dataset, num_epochs, batch_size):# 2 初始化参数  损失函数  优化器loss1 = nn.CrossEntropyLoss()# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.99, 0.99))start = time.time()# 2 2个遍历  epoch  dataloaderfor epoch in range(num_epochs):dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)total_num = 0total_loss = 0.0for x, y in dataloader:# 5 前向传播  损失计算 梯度归零  反向传播 参数更新output = model(x)loss = loss1(output, y)optimizer.zero_grad()loss.backward()optimizer.step()total_num += 1  # 批次total_loss += loss.item()epoch += 1print(f'epoch:{epoch + 1:4d},loss:{total_loss / (total_num * epoch):.4f}, time:{time.time() - start:.2f}s')# 模型持久化torch.save(model.state_dict(), 'model/phone2.pth')# 5.模型预测评估
def test(model, test_dataset, input_dim, output_dim):# 3.导入数据dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)correct = 0# 4.遍历数据for x, y in dataloader:# 4.1 前向传播output = model(x)print(output)# 4.2 获取输出结果(类别)y_pred = torch.argmax(output, dim=1)# print(y_pred)  # 预测错误# 4.3 计算准确率Acccorrect += (y_pred == y).sum()print(correct.item())Acc = correct.item() / len(test_dataset)return Accif __name__ == '__main__':train_dataset, test_dataset, feature_num, label_num = create_dataset()# 1.实例化模型model = PhonePriceModel(feature_num, label_num)# 2.加载模型model.load_state_dict(torch.load('model/phone2.pth'))# 模型训练# train(model, train_dataset, num_epochs=50, batch_size=8)# 模型预测Acc = test(model, test_dataset, feature_num, label_num)print(f'Acc:{Acc:.5f}')

http://www.ppmy.cn/server/145181.html

相关文章

Python Scikit-learn简介(二)

数据处理 数据划分 机器学习的数据,可以划分为训练集、验证集和测试集,也可以划分为训练集和测试集。 from sklearn.model_selection import train_test_split# 示例数据 X [[1, 2], [3, 4], [5, 6], [7, 8]] y [0, 1, 0, 1]# 划分数据集 X_train,…

【Vue3 for beginner】普通插槽、具名插槽、作用域插槽

🌈Don’t worry , just coding! 内耗与overthinking只会削弱你的精力,虚度你的光阴,每天迈出一小步,回头时发现已经走了很远。 📗插槽 在 Vue 3 中,插槽(Slots)是一个强大的功能&am…

搜维尔科技:多画面显示3D系统解决方案,数据孪生可视化大屏3D展示技术

集成多画面系统 集成多画面系统解决方案 1.适合多个用户的紧凑型入门级解决方案 2.会议室功能、审批功能、3D模型讨论等多种使用可能性 3.配有组合设备,方便整合 CAVE 多画面显示系统 1.专业的大屏幕多画面解决方案 2.墙壁、天花板和地板三面CAVE 3.专为沉浸…

免费实用在线AI工具集合 - 加菲工具

免费在线工具-加菲工具 https://orcc.online/ 在线录屏 https://orcc.online/recorder 时间戳转换 https://orcc.online/timestamp Base64 编码解码 https://orcc.online/base64 URL 编码解码 https://orcc.online/url Hash(MD5/SHA1/SHA256…) 计算 https://orcc.online/h…

初试无监督学习 - K均值聚类算法

文章目录 1. K均值聚类算法概述2. k均值聚类算法演示2.1 准备工作2.2 生成聚类用的样本数据集2.3 初始化KMeans模型对象,并指定类别数量2.4 用样本数据训练模型2.5 用训练好的模型生成预测结果2.6 输出预测结果2.7 可视化预测结果 3. 实战小结 1. K均值聚类算法概述…

Python 数据分析核心库大全!

(欢迎关注我的视频号) 👇我的小册 45章教程:(小白零基础用Python量化股票分析小册) ,原价299,限时特价2杯咖啡,满100人涨10元。 大家好!我是菜鸟哥! 今天我们来聊点干货:Python 数据…

算法日记 33 day 动态规划(打家劫舍,股票买卖)

今天来看看动态规划的打家劫舍和买卖股票的问题。 上题目!!!! 题目:打家劫舍 198. 打家劫舍 - 力扣(LeetCode) 你是一个专业的小偷,计划偷窃沿街的房屋。每间房内都藏有一定的现金…

高标准农田智慧农业系统建设方案

1 项目概述 1.1 建设背景 我国是农业大国,近30年来农田高产量主要依靠农药化肥的大量投入,大部分化肥和水资源没有被有效利用而随地弃置,导致大量养分损失并造成环境污染。我国农业生产仍然以传统生产模式为主,传统耕种只能凭经验施肥灌溉,不仅浪费大量的人力物力,也对环…