使用TensorFlow训练深度学习模型实战(上)

news/2024/12/22 18:04:56/

大家好,尽管大多数关于神经网络的文章都强调数学,而TensorFlow文档则强调使用现成数据集进行快速实现,但将这些资源应用于真实世界数据集是很有挑战性的,很难将数学概念和现成数据集与我的具体用例联系起来。本文旨在提供一个实用的、逐步的教程,介绍如何使用TensorFlow训练深度学习模型,并重点介绍如何将数据集重塑为TensorFlow对象,以便TensorFlow框架能够识别。

本文主要内容包括:

  • 将DataFrame转换为TensorFlow对象

  • 从头开始训练深度学习模型

  • 使用预训练的模型训练深度学习模型

  • 评估、预测和绘制训练后的模型。

安装TensorFlow和其他必需的库 

首先,你需要安装TensorFlow。你可以通过在终端或Anaconda中运行以下命令来完成:

# 安装所需的软件包
!pip install tensorflow
!pip install tensorflow-datasets

安装TensorFlow之后,导入其他必需的库,如Numpy、Matplotlib和Sklearn。

import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_splitfrom tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout

加载数据集

一旦导入了所有必需的库,下一步是获取数据集来搭建模型。TensorFlow允许使用各种输入格式,包括CSV、TXT和图像文件,有些数据集可以从TensorFlow-dataset中导入,这些数据集已准备好用作深度学习模型的输入。然而在许多情况下,数据集是以DataFrame格式而不是TensorFlow对象格式存在的。本文我们将使用Sklearn中的MNIST数据集,其格式为Pandas DataFrame。MNIST数据集广泛用于图像分类任务,包括70000个手写数字的灰度图像,每个图像大小为28x28像素。该数据集被分为60000个训练图像和10000个测试图像。

from sklearn.datasets import fetch_openml# 加载MNIST数据集
# mnist = fetch_openml('mnist_784')# 输出MNIST数据集
print('Dataset type:', type(mnist.data))# 浏览一下加载的数据集
mnist.data.head()

 通过输出DataFrame的前部,我们可以观察到它包含784列,每列代表一个像素。

 将DataFrame转换为TensorFlow数据集对象

加载了Pandas DataFrame,注意到TensorFlow不支持Pandas DataFrame作为模型的输入,因此必须将DataFrame转换为可以用于训练或评估模型的张量。这个转换过程确保数据以与TensorFlow API兼容的格式存在,为了将MNIST数据集从DataFrame转换为tf.data.Dataset对象,可以执行以下步骤:

  1. 将数据和目标转换为NumPy数组并对数据进行归一化处理

  2. 使用scikit-learn中的train_test_split将数据集拆分为训练集和测试集

  3. 将训练和测试数据重塑为28x28x1的图像

  4. 使用from_tensor_slices为训练集和测试集创建tf.data.Dataset对象

def get_dataset(mnist):# 加载MNIST数据集# mnist = fetch_openml('mnist_784')# 将数据和目标转换成numpy数组X = mnist.data.astype('float32')y = mnist.target.astype('int32')# 将数据归一化,使其数值在0和1之间X /= 255.0# 将数据集分成训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 将训练数据重塑为28x28x1的图像X_train = X_train.values.reshape((-1, 28, 28, 1))X_test = X_test.values.reshape((-1, 28, 28, 1))# 为训练和测试集创建TensorFlow数据集对象train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))# 输出训练和测试集的形状print('Training data shape:', X_train.shape)print('Training labels shape:', y_train.shape)print('Testing data shape:', X_test.shape)print('Testing labels shape:', y_test.shape)return X_test, y_test, X_train, y_train

 再来看一下我们的训练和测试TensorFlow对象:

 经过这个过程,原始数据集已经成功转换为形状为(5600,28,28,1)的TensorFlow对象。

经过以上的步骤我们已经完成了实战的前半部分,后文将继续讲解有关定义深度学习模型、训练模型和评估模型的内容。


http://www.ppmy.cn/news/987687.html

相关文章

4H-SiC nMOSFETs的亚阈值漏电流扫描滞后特性

目录 标题:On the Subthreshold Drain Current Sweep Hysteresis of 4H-SiC nMOSFETs研究了什么文章创新点文章的研究方法文章得出的结论 标题:On the Subthreshold Drain Current Sweep Hysteresis of 4H-SiC nMOSFETs 亚阈值滞后(Subthresh…

PKG内容查看工具:Suspicious Package for Mac安装教程

Suspicious Package Mac版是一款Mac平台上的查看 PKG 程序包内信息的应用,Suspicious Package Mac版支持查看全部包内全部文件,比如需要运行的脚本,开发者,来源等等。 suspicious package mac使用简单,只需在选择pkg安…

【Leetcode】53. 最大子数组和

一、题目 1、题目描述 给你一个整数数组 nums ,请你找出一个具有最大和的连续子数组(子数组最少包含一个元素),返回其最大和。 子数组 是数组中的一个连续部分。 示例1: 输入:nums = [-2,1,-3,4,-1,2,1,-5,4] 输出:6 解释:连续子数组 [4,-1,2,1] 的和最大,为 6 。…

128. 最长连续序列

128. 最长连续序列 题目-中等难度示例1. set去重 题目-中等难度 给定一个未排序的整数数组 nums ,找出数字连续的最长序列(不要求序列元素在原数组中连续)的长度。 请你设计并实现时间复杂度为 O(n) 的算法解决此问题。 示例 示例 1&…

分析npm run serve之后发生了什么?

首先需要明白的是,当你在终端去运行 npm run ****,会是什么过程。 根据上图的一个流程,就可以衍生出很多问题。 1,为什么不直接运行vue-cli-service serve? 因为直接运行 vue-cli-service serve,会报错&#xff0c…

idea 关闭页面右侧预览框/预览条

idea 关闭页面右侧预览框 如图,预览框存在想去除 找了好多方法,什么去掉“setting->appearance里的show editor preview tooltips”的对钩;又或者在该预览区的滚动条上右键,“取消勾选show code lens on scrollbar hover”。都…

回调函数的使用:案例一:c语言简单信号与槽机制。

系列文章目录 文章目录 系列文章目录前言一、回调函数1.1 回调函数基本概念1.2 简单实现 二、代码案例1.代码示例 总结 前言 了解回调函数的基本概念,函数指针的使用、简单信号与槽的实现机制; 一、回调函数 1.1 回调函数基本概念 回调函数就是一个通…

【PGMPY】 1. DAG基础结构

pgmpy 贝叶斯网络的纯python实现, 用途: 结构学习、 参数估计、 近似(基于采样) 精确推理 因果推理 安装 pip install pgmpyconda install -c ankurankan pgmpyconda install -c ankurankan pgmpy文档 https://pgmpy.org/index…