MLP实现fashion_mnist数据集分类(1)-模型构建、训练、保存与加载(tensorflow)

server/2024/10/21 3:57:40/

1、查看tensorflow版本

import tensorflow as tfprint('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

2、fashion_mnist数据集下载与展示

(train_image,train_label),(test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()
print(train_image.shape)
print(train_label.shape)
print(test_image.shape)
print(test_label.shape)

在这里插入图片描述

import matplotlib.pyplot as plt
# plt.imshow(train_image[0])  # 此处为啥是彩色的?def plot_images_lables(images,labels,start_idx,num=5):fig = plt.gcf()fig.set_size_inches(12,14)for i in range(num):ax = plt.subplot(1,num,1+i)ax.imshow(images[start_idx+i],cmap='binary')title = 'label=' + str(labels[start_idx+i])ax.set_title(title,fontsize=10)ax.set_xticks([])ax.set_yticks([])plt.show()
plot_images_lables(train_image,train_label,0,5)
# plot_images_lables(test_image,test_label,0,5)

在这里插入图片描述

3、数据预处理

X_train,X_test = tf.cast(train_image/255.0,tf.float32),tf.cast(test_image/255.0,tf.float32) # 归一化
y_train,y_test = train_label,test_label # 此处对y没有做onehot处理,需要使用稀疏交叉损失函数

4、模型构建

from keras import Sequential
from keras.layers import Flatten,Dense,Dropout
from keras import Inputmodel = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=64,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

在这里插入图片描述

5、模型配置

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])

6、模型训练

H = model.fit(x=X_train,y=y_train,validation_split=0.2,# validation_data=(X_test,y_test),epochs=10,batch_size=128,verbose=1)

在这里插入图片描述

plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()

在这里插入图片描述

plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()

在这里插入图片描述

7、模型评估

model.evaluate(X_test,y_test)

在这里插入图片描述

8、模型预测

import numpy as np
import matplotlib.pyplot as pltdef pred_plot_images_lables(images,labels,start_idx,num=5):# 预测res = model.predict(images[start_idx:start_idx+num])res = np.argmax(res,axis=1)# 画图fig = plt.gcf()fig.set_size_inches(12,14)for i in range(num):ax = plt.subplot(1,num,1+i)ax.imshow(images[start_idx+i],cmap='binary')title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])ax.set_title(title,fontsize=10)ax.set_xticks([])ax.set_yticks([])plt.show()
pred_plot_images_lables(X_test,y_test,0,5)

在这里插入图片描述

9、模型保存与加载

import numpy as nptf.keras.models.save_model(model,"model.keras")
loaded_model = tf.keras.models.load_model("model.keras")
# assert np.allclose(model.predict(X_test[:5]), loaded_model.predict(X_test[:5]))
print(np.argmax(model.predict(X_test[:5]),axis=1))
print(np.argmax(loaded_model.predict(X_test[:5]),axis=1))

在这里插入图片描述


http://www.ppmy.cn/server/35059.html

相关文章

怎么把音频文件转化成文字?6个软件教你转换音频文件

怎么把音频文件转化成文字?6个软件教你转换音频文件 以下是六个常用的软件和方法,可以帮助您将音频文件转换为文字: 1.一键识别王: 这是一款强大的语音识别工具,可以将音频文件快速转换为文字,并提供编…

手机短信删除了怎么恢复?教你几个简单方法快速找回!

手机短信是我们日常生活中重要的沟通工具,我们用手机短信来联络亲朋好友,也用来接收来自其他软件的信息。但是有时候我们可能会不小心删除了一些重要的短信,手机短信删除了怎么恢复呢?本文将为您介绍3个简单方法,帮助您…

LangChain-RAG学习之 文档加载器

目录 一、实现原理 二、文档加载器的选择 (一).PDF 加载本地文件 可能需要的环境配置 (二).CSV 1、使用每个文档一行的 CSV 数据加载 CSVLoader 2、自定义 csv 解析和加载 (csv_args 3、指定用于 标识文档来源的 列(source_column (三)、文件目…

18.Blender 渲染工程、打光方法及HDR贴图导入

HDR环境 如何导入Blender的HDR环境图 找到材质球信息 在右上角,点击箭头,展开详细部分 点击材质球,会出现下面一列材质球,将鼠标拖到第二个材质球,会显示信息 courtyard.exr 右上角打开已渲染模式 左边这里选择世界…

(06)vite与ts的结合

文章目录 系列全集package.json在根目录创建 tsconfig.json 文件在根目录创建 vite.config.ts 文件index.html额外的类型声明 系列全集 (01)vite 从启动服务器开始 (02)vite环境变量配置 (03)vite 处理 c…

Linux搭建mysql环境

搭建 MySQL 环境 1、使用 wget 下载安装包,下载到 opt 目录中 wget http://dev.mysql.com/get/mysql57-community-release-el7-10.noarch.rpm2、安装 MySQL 公钥 rpm -i mysql57-community-release-el7-10.noarch.rpmrpm --import https://repo.mysql.com/RPM-GP…

Nvme协议第三章 Controller Registers

控制器寄存器位于MLBAR/MUBAR寄存器(PCI BAR0和BAR1)中,该寄存器应映射到支持有序访问和可变访问宽度的内存空间。host主机通过访问虚拟内存的方式访问该部分寄存器。 注:访问过程只能一次访问一个寄存器,不能多个访问…

安卓第三方app调用system/lib库报错的问题

报错如下 04-29 13:45:13.787 2339 2339 E AndroidRuntime: java.lang.UnsatisfiedLinkError: dlopen failed: library "/system/lib/libxxxx.so" needed or dlopened by "/apex/com.android.art/lib/libnativeloader.so" is not accessible for the nam…