利用pytorch两层线性网络对titanic数据集进行分类(kaggle)

devtools/2024/9/22 18:42:20/

pytorchtitanic_0">利用pytorch两层线性网络对titanic数据集进行分类

最近在看pytorch的入门课程,做了一下在kaggle网站上的作业,用的是titanic数据集,因为想搭一下神经网络,所以数据加载部分简单的把训练集和测试集中有缺失值的列还有含有字符串的列去除了,加入了DataLoader模块,其实这个数据集很小,用不到,本人还没入门,小白一枚。

import torch 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets
from torchvision import transforms
import pandas as pdclass titanicDataset(Dataset):def __init__(self,filepath):xy=np.loadtxt(filepath,delimiter=',',skiprows=1,usecols=[1,2,7,8],dtype=np.float32)self.len=xy.shape[0]# print(self.len)self.y_data=torch.from_numpy(xy[:,[0]])self.x_data=torch.from_numpy(xy[:,1:])def __getitem__(self,index):#获取索引元素 return self.x_data[index],self.y_data[index]def __len__(self):return self.len
dataset=titanicDataset('./pytorch/dataset/titanic/train.csv')
train_loader=DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=0)# print(dataset.x_data,dataset.y_data)
test_loader=DataLoader(dataset=np.loadtxt('./pytorch/dataset/titanic/test.csv',delimiter=',',skiprows=1,usecols=[1,6,7],dtype=np.float32),batch_size=32,shuffle=False,num_workers=0)
print(next(iter(test_loader)))class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()# self.linear1=torch.nn.Linear(4,3)self.linear2=torch.nn.Linear(3,2)self.linear3=torch.nn.Linear(2,1)self.sigmoid=torch.nn.Sigmoid()def forward(self,x):# x=self.sigmoid(self.linear1(x))x=self.sigmoid(self.linear2(x))x=self.sigmoid(self.linear3(x))return x
model=Model()
criterion=torch.nn.BCELoss(size_average=True)
optimizer=torch.optim.SGD(model.parameters(),lr=0.1,momentum=0.9)
for epoch in range(10000):acc_num=0for i,data in enumerate(train_loader,0):#1.Prepare datainputs,labels=data# print(inputs.shape[0])#2.Forwardy_pred=model(inputs)loss=criterion(y_pred,labels)# print(epoch,i,loss.item())#3.Backwardoptimizer.zero_grad()loss.backward()#4.Updateoptimizer.step()y_pred_label=torch.where(y_pred>0.5,torch.tensor([1.0]),torch.tensor([0.0]))acc_num+=torch.eq(y_pred_label,labels).sum().item()# print(acc_num,len(dataset),len(train_loader.dataset))acc=acc_num/len(dataset)
print(acc)
# print(test_loader)
# print(test_loader.dataset.shape)
out = model(torch.tensor(test_loader.dataset))
y_pred = torch.where(out>0.5,torch.tensor([1.0]),torch.tensor([0.0]))[:,0]
print(y_pred)
print(pd.Series(y_pred))
id=pd.read_csv('./pytorch/dataset/titanic/test.csv',usecols=['PassengerId']).iloc[:,0]
# print(type(id))pd.DataFrame({'PassengerId':id,'Survived':pd.Series(y_pred,dtype=int)}).to_csv('pred.csv',index=None)
a=pd.DataFrame([id,pd.Series(y_pred)])
print(a)
# print(y_pred[-10:])# for x in test_loader:
#     print(x.shape)
#     out = model(x)
#     y_pred = torch.where(out>0.5,torch.tensor([1.0]),torch.tensor([0.0]))
# print(y_pred)

http://www.ppmy.cn/devtools/31452.html

相关文章

深度学习之基于Matlab卷积神经网络验证码识别系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景 随着互联网的发展,验证码作为一种常用的安全验证手段,被广泛应用于各种网站和…

【华为 ICT HCIA eNSP 习题汇总】——题目集19

1、(多选)以下选项中,FTP 常用文件传输类型有()。 A、ASCII 码类型 B、二进制类型 C、EBCDIC 类型 D、本地类型 考点:应用层 解析:(AB) 文件传输协议(FTP&…

【QT】初始QT

目录 一.背景1.GUI开发的各种技术方案2.什么是框架3.QT支持的系统4.QT的版本5.QT的优点6.QT的应用常见 二.环境搭建1.认识QTSDK中的重要工具2.使用QT Creator创建项目3.项目解释(1)main.cpp(2)widget.h(3)widget.cpp(4)widget.ui(5)Empty.pro(6)临时文件 三.初始QT1.Hello Worl…

《青少年成长管理2024》090 “目标计划:制定目标”6_6

《青少年成长管理2024》090 “目标计划:制定目标”6_6 六、时间预算(一)期间时间计算(二)阶段时间计算(三)成长期总时间预算 七、总体原则(一)幸福生活,快乐成…

自然语言处理基础

文章目录 一、基础与应用简单介绍基本任务重要应用 二、词表示与语言模型词表示方案一:用一组的相关词来表示当前词方案二:one-hot representation,将每一个词表示成一个独立的符号方案三:上下文表示法(contextual rep…

【开源物联网平台】window环境下搭建调试监控设备环境

🌈 个人主页:帐篷Li 🔥 系列专栏:FastBee物联网开源项目 💪🏻 专注于简单,易用,可拓展,低成本商业化的AIOT物联网解决方案 目录 一、使用docker脚本部署zlmediakit 1.1 …

从零搭建自己的javaweb网站,Javaweb网站项目打包jar后上传到Linux操作系统的阿里云服务器,公网成功访问,全流程,流程精简,小白秒懂

背景 很多同学自己写了一个javaweb,能在本地跑了,但是还想用公网访问自己的javaweb,写完一个项目99%进度,就差1%最后一步部署网站了,这篇文章教你如何快速地将javaweb部署到云服务器,笔者亲手总结&#xff…

QT:输入类控件的使用

LineEdit 录入个人信息 #include "widget.h" #include "ui_widget.h" #include <QDebug> #include <QString>Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this);// 初始化输入框ui->lineEdit…