第R3周:RNN-心脏病预测

embedded/2025/1/6 5:18:09/
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

    文章目录

    • 一、前言
    • 二、代码流程
      • 1、导入包,设置GPU
      • 2、导入数据
      • 3、数据处理
      • 4、构建RNN模型
      • 5、编译模型
      • 6、模型训练
      • 7、模型评估

电脑环境:
语言环境:Python 3.8.0
深度学习环境:tensorflow 2.17.0

一、前言

传统神经网络的结构都比较简单:输入层-隐藏层-输出层
在这里插入图片描述
RNN和传统神经网络最大的区别在于每次都会将前一次的输出结果,带到下一次的隐藏层中,一起训练。如下图:
在这里插入图片描述

二、代码流程

1、导入包,设置GPU

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0, True)tf.config.experimental.set_visible_devices([gpu0], "GPU")gpus

2、导入数据

数据介绍:

  • age: 年龄
  • sex: 性别
  • cp:胸痛类型 (4 values)
  • trestbps: 静息血压
  • chol:血清胆甾醇 (mg/ dl
  • fbs:空腹血糖 >120 mg/dl
  • restecg:静息心电图结果(值 0,1,2)
  • thalach:达到的最大心率
  • exang:运动诱发的心绞痛
  • oldpeak:相对于静止状态,运动引起的ST段压低
  • slope: 运动峰值 ST 段的斜率
  • ca:荧光透视着色的主要血管数量(0-3)
  • thal:0=正常;1=固定缺陷;2=可逆转的缺陷
  • target:0=心脏病发作的几率较小1=心脏病发作的几率更大
import pandas as pd
import numpy as npdf = pd.read_csv("heart.csv")
df.head()

在这里插入图片描述

3、数据处理

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScalerX = df.drop("target", axis=1)
y = df["target"]X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=0)

标准化

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)

4、构建RNN模型

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, SimpleRNN, Dropoutmodel = Sequential()
model.add(SimpleRNN(200, input_shape=(13, 1), activation="relu"))
model.add(Dense(100, activation="relu"))
model.add(Dense(1, activation='sigmoid'))model.summary()

5、编译模型

opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(optimizer=opt, loss="binary_crossentropy", metrics=["accuracy"])

6、模型训练

epochs = 100history = model.fit(X_train, y_train, epochs=epochs, batch_size=128,validation_data=(X_test, y_test),verbose=1)
Epoch 1/100
3/3 ━━━━━━━━━━━━━━━━━━━━ 3s 503ms/step - accuracy: 0.5104 - loss: 0.6909 - val_accuracy: 0.6129 - val_loss: 0.6858
..............................................................................................................
Epoch 100/100
3/3 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - accuracy: 0.8956 - loss: 0.2431 - val_accuracy: 0.8710 - val_loss: 0.4132

7、模型评估

import matplotlib.pyplot as pltacc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(14, 4))plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述


http://www.ppmy.cn/embedded/151310.html

相关文章

高质量C++小白教程:2.10-预处理器简介

当你在编译项目时,你可能希望编译器完全按照你编写的方式编译每一个代码文件,当事实并非如此。 相反,在编译之前,每一个.cpp文件都会经历一个预处理的阶段,在此阶段中,称为预处理器的程序对代码文件的文本进行各种更改. 预处理器实际上不会以任何方式修改原始代码文件,预处理…

3blue1brow线代笔记

向量 物理:空间中的箭头,长度和方向决定一个向量。只要两者相同,可以任意移动保持不变 计算机:有序的数字列表 (数组) 数学:向量可以是任何东西,只要保证两个向量相加以及数字与向量…

python字符串函数用法大全

目录 1.0 capitalize()函数 2.0 title()函数 3.0 swapcase()函数 4.0 lower()函数 5.0 upper()函数 7.0 center()函数 8.0 ljust()函数 9.0 rjust()函数 10.0 zfill()函数 11.0 count()函数 13.0 decode()函数 14.0 expandtabs()函数 15.0 find()函数 16.0 rfind()…

【OpenCV】使用Python和OpenCV实现火焰检测

1、 项目源码和结构(转) https://github.com/mushfiq1998/fire-detection-python-opencv 2、 运行环境 # 安装playsound:用于播放报警声音 pip install playsound # 安装opencv-python:cv2用于图像和视频处理,特别是…

我的 2024 年终总结

2024 年,我离开了待了两年的互联网公司,来到了一家聚焦教育机器人和激光切割机的公司,没错,是一家硬件公司,从未接触过的领域,但这还不是我今年最重要的里程碑事件 5 月份的时候,正式提出了离职…

深入FreeRTOS内核——第三章、任务管理

深入FreeRTOS内核——第三章、任务管理 文章目录 深入FreeRTOS内核——第三章、任务管理前言本章内容 一、任务函数二、任务的顶级状态三、任务创建3.1 xTaskCreate()函数 四、任务优先级4.1 通用调度器4.2 架构优化调度器 五、时间测量和时钟中断六、扩展非运行状态6.1 阻塞状…

细讲前端工程化

何为前端工程化 前端工程化是指将软件工程的原理和方法应用到前端开发中,以提高开发效率、代码质量和可维护性。随着 Web 应用的复杂度不断增加,传统的前端开发方式已经难以满足需求,因此引入了工程化的概念来更好地管理和优化前端开发流程。…

JLINK V9插入电脑没反应

一个V9版本,在公司电脑上使用没有问题,带回来在家里电脑上无法显示jlink _driver,只显示一个cdc 家里还有个jlink,版本可能是V8的,可以正常显示jlink 多方查找没看到什么原因,后面重新装了下jlink驱动就好…