神经网络_使用tensorflow对fashion mnist衣服数据集分类

news/2024/9/23 13:15:17/
from tensorflow import keras 
import matplotlib.pyplot as plt

1.数据预处理

1.1 下载数据集

fashion_mnist = keras.datasets.fashion_mnist
#下载 fashion mnist数据集
(train_images, train_labels),(test_images, test_labels) = fashion_mnist.load_data()print("train_images shape ", train_images.shape)
print("train_labels shape ", train_labels.shape)
print("train_labels[0] ", train_labels[0])
train_images shape  (60000, 28, 28)
train_labels shape  (60000,)
train_labels[0]  9

1.2展示数据集的第一张图片

plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show
<function matplotlib.pyplot.show(close=None, block=None)>

在这里插入图片描述

1.3 展示前25张图片和图片名称

train_images = train_images / 255.0;
test_images = test_images / 255.0;plt.figure(figsize=(10, 10))
class_names = ['T-shirt/top','Trouser','Pullover','Dress','Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print("train_labels ", train_labels[:25])
for i in range(25):plt.subplot(5, 5, i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i]])
plt.show()
train_labels  [9 0 0 3 0 2 7 2 5 5 0 9 5 5 7 9 1 0 6 4 3 1 4 8 4]

在这里插入图片描述

2. 模型实现

2.1模型定义

#定义模型
model = keras.Sequential([keras.layers.Flatten(input_shape=(28,28)),keras.layers.Dense(128, activation="relu"),keras.layers.Dense(10, activation="softmax")
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=10)
D:\python\Lib\site-packages\keras\src\layers\reshaping\flatten.py:37: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.super().__init__(**kwargs)Epoch 1/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 977us/step - accuracy: 0.0967 - loss: 2.3028
Epoch 2/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.0991 - loss: 2.3027
Epoch 3/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.0956 - loss: 2.3028
Epoch 4/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.0987 - loss: 2.3027
Epoch 5/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 968us/step - accuracy: 0.0988 - loss: 2.3028
Epoch 6/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.1009 - loss: 2.3027
Epoch 7/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.0998 - loss: 2.3027
Epoch 8/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.0968 - loss: 2.3028
Epoch 9/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step - accuracy: 0.1036 - loss: 2.3027
Epoch 10/10
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 987us/step - accuracy: 0.0973 - loss: 2.3028<keras.src.callbacks.history.History at 0x20049c207d0>

2.2模型评估测试

#评估测试
test_loss, test_accuracy = model.evaluate(test_images, test_labels, verbose=2)
print("test_loss ", test_loss)
print("test_accuracy", test_accuracy)
313/313 - 0s - 892us/step - accuracy: 0.1000 - loss: 2.3026
test_loss  2.3026490211486816
test_accuracy 0.10000000149011612

2.3模型预测

predict_result = model.predict(test_images)
print("predict_result shape, 样本数,每个样本对每个分类的得分 ", predict_result.shape)
print("样本1的每个分类得分, ", predict_result[0])
sample_one_result = np.argmax(predict_result[0])
print("样本1的分类结果%d %s"%(sample_one_result,class_names[sample_one_result]))
print("样本1的真实分类结果%d %s"%(test_labels[0],class_names[test_labels[0]]))
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 671us/step
predict_result shape, 样本数,每个样本对每个分类的得分  (10000, 10)
样本1的每个分类得分,  [0.10038214 0.09719477 0.10009037 0.10101561 0.09946147 0.101658510.10063848 0.09979857 0.09982409 0.09993599]
样本1的分类结果5 Sandal
样本1的真实分类结果9 Ankle boot

2.4 查看指定测试图片的预测结果

#画指定索引位置的图
def plot_image(index, predict_classes, true_labels, images):true_label = true_labels[index]image = images[index]plt.grid(False)plt.xticks([])plt.yticks([])plt.imshow(image, cmap=plt.cm.binary)predict_label = np.argmax(predict_classes)if predict_label == true_label:color = 'blue'else:color = 'red'plt.xlabel("{} {:2.0f}%({})".format(class_names[predict_label],100 * np.max(predict_classes),class_names[true_label]), color=color)
# 画指定样本的对所有分类的预测得分    
def plot_predict_classes(i, predict_classes, true_labels):true_label = train_labels[i]plt.grid(False)plt.xticks(range(10))plt.yticks([])current_plot = plt.bar(range(10), predict_classes, color="#777777")plt.ylim([0,1])predict_label = np.argmax(predict_classes)current_plot[predict_label].set_color("red")current_plot[true_label].set_color('blue')# 画第一个样本的图,和对每个分类的得分
i = 0
plt.figure(figsize=(6,3))
plt.subplot(1,2,1)
plot_image(i, predict_result[i], test_labels, test_images)
plt.subplot(1,2,2)
plot_predict_classes(i, predict_result[i], test_labels)
plt.show()

在这里插入图片描述

3.保存训练的模型

3.1保存模型

# 保存模型
model.save('fashion_model.h5')
WARNING:absl:You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`. 

3.2保存模型到json文件

#查看模型
model_json = model.to_json()
print("model json: ", model_json)#保存json到文件中
with open('fashion_model_config.json', 'w') as json:json.write(model_json)#从json文件中加载模型
print("json from model")
json_model = keras.models.model_from_json(model_json)
json_model.summary()
model json:  {"module": "keras", "class_name": "Sequential", "config": {"name": "sequential", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "layers": [{"module": "keras.layers", "class_name": "InputLayer", "config": {"batch_shape": [null, 28, 28], "dtype": "float32", "sparse": false, "name": "input_layer"}, "registered_name": null}, {"module": "keras.layers", "class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "data_format": "channels_last"}, "registered_name": null, "build_config": {"input_shape": [null, 28, 28]}}, {"module": "keras.layers", "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "units": 128, "activation": "relu", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "registered_name": null, "build_config": {"input_shape": [null, 784]}}, {"module": "keras.layers", "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": {"module": "keras", "class_name": "DTypePolicy", "config": {"name": "float32"}, "registered_name": null}, "units": 10, "activation": "softmax", "use_bias": true, "kernel_initializer": {"module": "keras.initializers", "class_name": "GlorotUniform", "config": {"seed": null}, "registered_name": null}, "bias_initializer": {"module": "keras.initializers", "class_name": "Zeros", "config": {}, "registered_name": null}, "kernel_regularizer": null, "bias_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "registered_name": null, "build_config": {"input_shape": [null, 128]}}], "build_input_shape": [null, 28, 28]}, "registered_name": null, "build_config": {"input_shape": [null, 28, 28]}, "compile_config": {"loss": "sparse_categorical_crossentropy", "loss_weights": null, "metrics": ["accuracy"], "weighted_metrics": null, "run_eagerly": false, "steps_per_execution": 1, "jit_compile": false}}
json from model
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ flatten (Flatten)                    │ (None, 784)                 │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense (Dense)                        │ (None, 128)                 │         100,480 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense_1 (Dense)                      │ (None, 10)                  │           1,290 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
 Total params: 203,542 (795.09 KB)
 Trainable params: 101,770 (397.54 KB)
 Non-trainable params: 0 (0.00 B)
 Optimizer params: 101,772 (397.55 KB)

3.3 保存模型权重到文件

weights = model.get_weights()
print("weight ", weights)
model.save_weights('fashion.weights.h5')
model.load_weights('fashion.weights.h5')
print("weight from file", model.get_weights())
weight  [array([[-0.06041221, -0.03045469, -0.06056997, ...,  0.06603239,-0.06018624, -0.02584767],[-0.06430402, -0.07436118, -0.00909608, ..., -0.04476351,-0.01347907,  0.00300767],[ 0.07909157, -0.0689464 ,  0.07742291, ..., -0.00037885,-0.02884226,  0.05017615],...,[-0.00013881,  0.0794938 ,  0.00120725, ..., -0.00251798,-0.06103022, -0.05509381],[ 0.04131137, -0.0285325 ,  0.06929631, ...,  0.07573903,0.02105945, -0.0524031 ],[ 0.07209501, -0.05137011, -0.07911879, ...,  0.02135488,0.0670035 ,  0.02766179]], dtype=float32), array([-0.00600429, -0.00547086, -0.00584014, -0.00600401, -0.00600361,-0.00565217, -0.00043141, -0.00599924, -0.00380762, -0.00364303,-0.00600468, -0.00330669, -0.00374643, -0.00600456, -0.0060048 ,-0.00600465, -0.0060041 , -0.00696887, -0.0011937 , -0.00599459,-0.00600372, -0.00600169, -0.00512277, -0.00579378, -0.00599535,-0.00598798, -0.00369858, -0.00600331, -0.00596425, -0.00598993,-0.00331114, -0.00600269, -0.00648344, -0.00598456, -0.00600508,-0.0050234 , -0.00600506, -0.00600394, -0.00370826, -0.00600255,-0.00318562, -0.0008926 , -0.00600376, -0.00600392, -0.00600293,-0.0010591 , -0.00526909, -0.0044194 , -0.0060979 , -0.00359087,-0.00599469, -0.00600368, -0.00600309, -0.00600125, -0.0060042 ,-0.0060032 , -0.00277885, -0.00599926, -0.00199332, -0.00494259,-0.00267067, -0.00600501, -0.0060036 , -0.00600471, -0.0060045 ,-0.00259782, -0.0027171 , -0.0060039 , -0.00141335, -0.00366305,-0.00254625, -0.00596222, -0.00328439, -0.00600358, -0.00597709,-0.00600401, -0.00600445, -0.00635821, -0.00166575, -0.00600483,-0.00459235, -0.00600466, -0.00637798, -0.00588632, -0.00599989,-0.0034114 , -0.00600291, -0.00600177, -0.00640314, -0.00600435,-0.00600042, -0.00600292, -0.00600482, -0.00600426, -0.00473085,-0.00157892, -0.00600219, -0.00364143, -0.00600267, -0.00600363,-0.00281488, -0.00600338, -0.00600482, -0.0025767 , -0.00744624,-0.00600235, -0.0060039 , -0.00600472, -0.00109048, -0.00483145,-0.00587764, -0.00600309, -0.00598578, -0.00599881, -0.00370371,-0.00600146, -0.00597422, -0.00600465, -0.00600461, -0.0060043 ,-0.00600423, -0.00243223, -0.00600425, -0.00600203, -0.0045927 ,-0.00371987, -0.00176624, -0.00600512], dtype=float32), array([[ 0.03556623,  0.1688491 , -0.10362723, ...,  0.13207223,-0.06696159, -0.15404737],[ 0.08589712,  0.0726881 , -0.03621184, ..., -0.13316402,-0.11030427, -0.07204279],[-0.02775251,  0.12212092,  0.12542443, ...,  0.05409406,0.07715587,  0.12737972],...,[-0.12100082, -0.0844327 ,  0.03725254, ...,  0.04297927,-0.06126365, -0.04448495],[ 0.00898614,  0.11527378, -0.10356722, ..., -0.09458876,-0.02348839,  0.11287841],[-0.14625832, -0.17126669, -0.0226883 , ..., -0.1290805 ,0.1703024 ,  0.10214148]], dtype=float32), array([ 0.00133452, -0.03093298, -0.00157637,  0.0076253 , -0.00787955,0.0139694 ,  0.0038848 , -0.004496  , -0.00424037, -0.00311988],dtype=float32)]
weight from file [array([[-0.06041221, -0.03045469, -0.06056997, ...,  0.06603239,-0.06018624, -0.02584767],[-0.06430402, -0.07436118, -0.00909608, ..., -0.04476351,-0.01347907,  0.00300767],[ 0.07909157, -0.0689464 ,  0.07742291, ..., -0.00037885,-0.02884226,  0.05017615],...,[-0.00013881,  0.0794938 ,  0.00120725, ..., -0.00251798,-0.06103022, -0.05509381],[ 0.04131137, -0.0285325 ,  0.06929631, ...,  0.07573903,0.02105945, -0.0524031 ],[ 0.07209501, -0.05137011, -0.07911879, ...,  0.02135488,0.0670035 ,  0.02766179]], dtype=float32), array([-0.00600429, -0.00547086, -0.00584014, -0.00600401, -0.00600361,-0.00565217, -0.00043141, -0.00599924, -0.00380762, -0.00364303,-0.00600468, -0.00330669, -0.00374643, -0.00600456, -0.0060048 ,-0.00600465, -0.0060041 , -0.00696887, -0.0011937 , -0.00599459,-0.00600372, -0.00600169, -0.00512277, -0.00579378, -0.00599535,-0.00598798, -0.00369858, -0.00600331, -0.00596425, -0.00598993,-0.00331114, -0.00600269, -0.00648344, -0.00598456, -0.00600508,-0.0050234 , -0.00600506, -0.00600394, -0.00370826, -0.00600255,-0.00318562, -0.0008926 , -0.00600376, -0.00600392, -0.00600293,-0.0010591 , -0.00526909, -0.0044194 , -0.0060979 , -0.00359087,-0.00599469, -0.00600368, -0.00600309, -0.00600125, -0.0060042 ,-0.0060032 , -0.00277885, -0.00599926, -0.00199332, -0.00494259,-0.00267067, -0.00600501, -0.0060036 , -0.00600471, -0.0060045 ,-0.00259782, -0.0027171 , -0.0060039 , -0.00141335, -0.00366305,-0.00254625, -0.00596222, -0.00328439, -0.00600358, -0.00597709,-0.00600401, -0.00600445, -0.00635821, -0.00166575, -0.00600483,-0.00459235, -0.00600466, -0.00637798, -0.00588632, -0.00599989,-0.0034114 , -0.00600291, -0.00600177, -0.00640314, -0.00600435,-0.00600042, -0.00600292, -0.00600482, -0.00600426, -0.00473085,-0.00157892, -0.00600219, -0.00364143, -0.00600267, -0.00600363,-0.00281488, -0.00600338, -0.00600482, -0.0025767 , -0.00744624,-0.00600235, -0.0060039 , -0.00600472, -0.00109048, -0.00483145,-0.00587764, -0.00600309, -0.00598578, -0.00599881, -0.00370371,-0.00600146, -0.00597422, -0.00600465, -0.00600461, -0.0060043 ,-0.00600423, -0.00243223, -0.00600425, -0.00600203, -0.0045927 ,-0.00371987, -0.00176624, -0.00600512], dtype=float32), array([[ 0.03556623,  0.1688491 , -0.10362723, ...,  0.13207223,-0.06696159, -0.15404737],[ 0.08589712,  0.0726881 , -0.03621184, ..., -0.13316402,-0.11030427, -0.07204279],[-0.02775251,  0.12212092,  0.12542443, ...,  0.05409406,0.07715587,  0.12737972],...,[-0.12100082, -0.0844327 ,  0.03725254, ...,  0.04297927,-0.06126365, -0.04448495],[ 0.00898614,  0.11527378, -0.10356722, ..., -0.09458876,-0.02348839,  0.11287841],[-0.14625832, -0.17126669, -0.0226883 , ..., -0.1290805 ,0.1703024 ,  0.10214148]], dtype=float32), array([ 0.00133452, -0.03093298, -0.00157637,  0.0076253 , -0.00787955,0.0139694 ,  0.0038848 , -0.004496  , -0.00424037, -0.00311988],dtype=float32)]

http://www.ppmy.cn/news/1529329.html

相关文章

高刷显示器哪个好?540Hz才有资格称高刷

高刷显示器哪个好&#xff1f;说实话&#xff0c;540Hz这些才能成为高刷显示器&#xff0c;什么200,240的&#xff0c;都不够高&#xff0c;什么是从容&#xff0c;有我不用才叫从容。下面我们一起来看看540Hz的高刷显示器都有哪些吧&#xff01; 1.高刷显示器哪个好 - 蚂蚁电…

WAN广域网技术--PPP和PPPoE

广域网基础概述 广域网&#xff08;Wide Area Network&#xff0c;WAN&#xff09;是一种覆盖广泛地区的计算机网络&#xff0c;它连接不同地理位置的计算机、服务器和设备。广域网通常用于连接不同城市、州或国家之间的网络&#xff0c;它通过互联网服务提供商&#xff08;ISP…

逻辑回归 和 支持向量机(SVM)比较

为了更好地理解为什么在二分类问题中使用 SVM&#xff0c;逻辑回归的区别&#xff0c;我们需要深入了解这两种算法的区别、优势、劣势&#xff0c;以及它们适用于不同场景的原因。 逻辑回归和 SVM 的比较 1. 模型的核心思想 • 逻辑回归&#xff1a; • 基于概率的模型&…

Android ImageView支持每个角的不同半径

Android ImageView支持每个角的不同半径 import android.annotation.SuppressLint; import android.content.Context; import android.content.res.ColorStateList; import android.content.res.Resources; import android.content.res.Resources.NotFoundException; import an…

Flyway-SQL 脚本与 Java 迁移

Flyway SQL 脚本与 Java 迁移详解 Flyway 是一种数据库迁移工具&#xff0c;提供了 SQL 脚本和 Java 迁移两种方式来管理数据库变更。在 Flyway 中&#xff0c;数据库迁移是通过逐步执行迁移脚本或代码来完成的。Flyway 既可以通过 SQL 文件直接执行数据库操作&#xff0c;也可…

spark读取数据性能提升

1. 背景 spark默认的jdbc只会用单task读取数据&#xff0c;读取大数据量时&#xff0c;效率低。 2. 解决方案 根据分区字段&#xff0c;如日期进行划分&#xff0c;增加task数量提升效率。 /*** 返回每个task按时间段划分的过滤语句* param startDate* param endDate* param …

在项目管理中,项目进度由哪些要素决定?

在项目管理领域&#xff0c;项目进度受到多种要素的综合影响。以下是一些关键的决定要素&#xff1a; 一、项目范围 1、任务清单 明确的任务清单是项目进度的基础。详细列出项目中需要完成的各项任务&#xff0c;包括任务的先后顺序、并行任务等&#xff0c;直接关系到进度规划…

dedecms——四种webshell姿势

姿势一&#xff1a;通过文件管理器上传WebShell 步骤一&#xff1a;访问目标靶场其思路为 dedecms 后台可以直接上传任意文件&#xff0c;可以通过文件管理器上传php文件获取webshell 步骤二&#xff1a;登陆到后台点击【核心】--》 【文件式管理器】--》 【文件上传】将准备好…