Implementing LeNet-5 on cifar2

2024/10/18 20:21:03

LeNet-5 on cifar2 (the uncommented source code is at the bottom of this article)

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, losses, Model

① Data preparation

1) Define a function that takes a file path and returns the loaded image together with its label

def load_img(file_path):
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img)
    img = tf.image.resize(img, [32, 32]) / 255.
    # regex_full_match returns a boolean tensor; cast it to an integer label
    # (a Python `if` on a symbolic tensor would fail inside Dataset.map)
    label = tf.cast(tf.strings.regex_full_match(file_path, '.*airplane.*'), tf.int32)
    return img, label
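Note that `tf.strings.regex_full_match` must match the *entire* string, which is why the pattern needs `.*` on both sides. The same semantics can be checked with Python's `re.fullmatch` (a small sketch with hypothetical paths, independent of TensorFlow):

```python
import re

# TensorFlow's regex_full_match behaves like re.fullmatch:
# the pattern must cover the whole path, hence the leading/trailing .*
def label_for(path):
    return 1 if re.fullmatch(r'.*airplane.*', path) else 0

print(label_for('cifar2/train/airplane/0001.jpg'))    # 1 (airplane)
print(label_for('cifar2/train/automobile/0001.jpg'))  # 0 (automobile)
```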

2) Define the relevant hyperparameters

batch_num = 100
epochs = 15

3) Load the cifar2 dataset with a tf.data pipeline and the custom loading function

train_data = (tf.data.Dataset.list_files('cifar2/train/*/*.jpg')
              .map(load_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
              .shuffle(buffer_size=1000)
              .batch(batch_num)
              .prefetch(tf.data.experimental.AUTOTUNE))
test_data = (tf.data.Dataset.list_files('cifar2/test/*/*.jpg')
             .map(load_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
             .shuffle(buffer_size=1000)
             .batch(batch_num)
             .prefetch(tf.data.experimental.AUTOTUNE))

② Building the model

class LeNet(Model):
    def __init__(self):
        super(LeNet, self).__init__()
        self.c1 = layers.Conv2D(6, 5)
        self.s2 = layers.MaxPooling2D()
        self.c3 = layers.Conv2D(16, 5)
        self.s4 = layers.MaxPooling2D()
        self.f5 = layers.Flatten()
        self.d6 = layers.Dense(120, activation='relu')
        self.d7 = layers.Dense(84, activation='relu')
        self.d8 = layers.Dense(2, activation='softmax')

    # forward pass
    @tf.function
    def call(self, inputs):
        out = self.c1(inputs)
        out = self.s2(out)
        out = self.c3(out)
        out = self.s4(out)
        out = self.f5(out)
        out = self.d6(out)
        out = self.d7(out)
        out = self.d8(out)
        return out

③ Training the model

model = LeNet()

1) Inspect the number of parameters in each layer

model.build((None, 32, 32, 3))
model.summary()
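The parameter counts reported by summary() can be verified by hand. With 5×5 "valid" convolutions and 2×2 pooling, the spatial size goes 32 → 28 → 14 → 10 → 5, so the Flatten layer yields 16·5·5 = 400 units. A quick sketch of this bookkeeping (pure Python, assuming the layer shapes above):

```python
# Shape/parameter bookkeeping for the LeNet above
# (valid padding, stride-1 convs, 2x2 max pooling).
def conv_out(n, k):
    return n - k + 1   # valid convolution output size

def pool_out(n):
    return n // 2      # 2x2 max pooling output size

n = 32
n = conv_out(n, 5)     # c1: 28x28x6
n = pool_out(n)        # s2: 14x14x6
n = conv_out(n, 5)     # c3: 10x10x16
n = pool_out(n)        # s4: 5x5x16
flat = n * n * 16
print(flat)            # 400 units feed the first Dense layer

params = [
    (5 * 5 * 3 + 1) * 6,    # c1: 456   (kernel*in_channels + bias, per filter)
    (5 * 5 * 6 + 1) * 16,   # c3: 2416
    (flat + 1) * 120,       # d6: 48120
    (120 + 1) * 84,         # d7: 10164
    (84 + 1) * 2,           # d8: 170
]
print(sum(params))          # 61326, matching "Total params: 61,326"
```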

2) Train with the Adam optimizer and a suitable loss function for the chosen number of epochs

model.compile(optimizer='adam',
              loss=losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
history = model.fit(train_data, epochs=epochs, validation_data=test_data)
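sparse_categorical_crossentropy fits here because the labels are integers (0/1), not one-hot vectors: for each sample the integer label simply selects one softmax probability, and the loss is its negative log. A minimal numeric sketch with made-up probabilities:

```python
import math

# Sparse categorical cross-entropy for one sample:
# the integer label indexes into the softmax output, loss = -log(p_label).
probs = [0.2, 0.8]   # hypothetical softmax output for one image
label = 1            # integer class label (airplane)
loss = -math.log(probs[label])
print(round(loss, 4))  # 0.2231
```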

3) Plot training vs. test accuracy

plt.plot(history.history['val_accuracy'], label='test_accuracy')
plt.plot(history.history['accuracy'], label='train_accuracy')
plt.legend()
plt.show()
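Besides the plot, `history.history` is a plain dict of per-epoch lists, so the best validation epoch can be read off programmatically. A small sketch with toy values (not the run below):

```python
# history.history maps metric names to per-epoch value lists;
# pick the epoch (1-based) with the highest validation accuracy.
history_dict = {'val_accuracy': [0.85, 0.90, 0.92, 0.91]}  # toy values
best = max(range(len(history_dict['val_accuracy'])),
           key=lambda i: history_dict['val_accuracy'][i])
print(best + 1)  # 3
```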

'''
Model: "le_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              multiple                  456
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  2416
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 multiple                  0
_________________________________________________________________
flatten (Flatten)            multiple                  0
_________________________________________________________________
dense (Dense)                multiple                  48120
_________________________________________________________________
dense_1 (Dense)              multiple                  10164
_________________________________________________________________
dense_2 (Dense)              multiple                  170
=================================================================
Total params: 61,326
Trainable params: 61,326
Non-trainable params: 0
_________________________________________________________________

Epoch 1/15
100/100 [==============================] - 4s 40ms/step - loss: 0.4533 - accuracy: 0.7902 - val_loss: 0.3372 - val_accuracy: 0.8515
Epoch 2/15
100/100 [==============================] - 4s 40ms/step - loss: 0.3327 - accuracy: 0.8545 - val_loss: 0.2771 - val_accuracy: 0.8810
Epoch 3/15
100/100 [==============================] - 4s 43ms/step - loss: 0.2565 - accuracy: 0.8940 - val_loss: 0.2434 - val_accuracy: 0.9025
Epoch 4/15
100/100 [==============================] - 5s 45ms/step - loss: 0.2159 - accuracy: 0.9110 - val_loss: 0.2283 - val_accuracy: 0.9115
Epoch 5/15
100/100 [==============================] - 5s 46ms/step - loss: 0.1786 - accuracy: 0.9289 - val_loss: 0.2228 - val_accuracy: 0.9030
Epoch 6/15
100/100 [==============================] - 4s 45ms/step - loss: 0.1574 - accuracy: 0.9384 - val_loss: 0.2079 - val_accuracy: 0.9175
Epoch 7/15
100/100 [==============================] - 4s 45ms/step - loss: 0.1290 - accuracy: 0.9529 - val_loss: 0.2092 - val_accuracy: 0.9205
Epoch 8/15
100/100 [==============================] - 4s 41ms/step - loss: 0.1022 - accuracy: 0.9603 - val_loss: 0.2297 - val_accuracy: 0.9095
Epoch 9/15
100/100 [==============================] - 4s 43ms/step - loss: 0.0907 - accuracy: 0.9671 - val_loss: 0.2313 - val_accuracy: 0.9200
Epoch 10/15
100/100 [==============================] - 4s 44ms/step - loss: 0.0670 - accuracy: 0.9744 - val_loss: 0.2353 - val_accuracy: 0.9230
Epoch 11/15
100/100 [==============================] - 4s 40ms/step - loss: 0.0501 - accuracy: 0.9817 - val_loss: 0.2627 - val_accuracy: 0.9160
Epoch 12/15
100/100 [==============================] - 4s 39ms/step - loss: 0.0366 - accuracy: 0.9888 - val_loss: 0.2789 - val_accuracy: 0.9250
Epoch 13/15
100/100 [==============================] - 4s 39ms/step - loss: 0.0293 - accuracy: 0.9901 - val_loss: 0.2958 - val_accuracy: 0.9115
Epoch 14/15
100/100 [==============================] - 4s 39ms/step - loss: 0.0335 - accuracy: 0.9860 - val_loss: 0.3240 - val_accuracy: 0.9090
Epoch 15/15
100/100 [==============================] - 4s 40ms/step - loss: 0.0201 - accuracy: 0.9939 - val_loss: 0.3261 - val_accuracy: 0.9235

Process finished with exit code 0
'''

[Figure: training vs. test accuracy curves over the 15 epochs]

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, losses, Model

def load_img(file_path):
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img)
    img = tf.image.resize(img, [32, 32]) / 255.
    label = tf.cast(tf.strings.regex_full_match(file_path, '.*airplane.*'), tf.int32)
    return img, label

batch_num = 100
epochs = 15

train_data = (tf.data.Dataset.list_files('cifar2/train/*/*.jpg')
              .map(load_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
              .shuffle(buffer_size=1000)
              .batch(batch_num)
              .prefetch(tf.data.experimental.AUTOTUNE))
test_data = (tf.data.Dataset.list_files('cifar2/test/*/*.jpg')
             .map(load_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
             .shuffle(buffer_size=1000)
             .batch(batch_num)
             .prefetch(tf.data.experimental.AUTOTUNE))

class LeNet(Model):
    def __init__(self):
        super(LeNet, self).__init__()
        self.c1 = layers.Conv2D(6, 5)
        self.s2 = layers.MaxPooling2D()
        self.c3 = layers.Conv2D(16, 5)
        self.s4 = layers.MaxPooling2D()
        self.f5 = layers.Flatten()
        self.d6 = layers.Dense(120, activation='relu')
        self.d7 = layers.Dense(84, activation='relu')
        self.d8 = layers.Dense(2, activation='softmax')

    @tf.function
    def call(self, inputs):
        out = self.c1(inputs)
        out = self.s2(out)
        out = self.c3(out)
        out = self.s4(out)
        out = self.f5(out)
        out = self.d6(out)
        out = self.d7(out)
        out = self.d8(out)
        return out

model = LeNet()
model.build((None, 32, 32, 3))
model.summary()
model.compile(optimizer='adam',
              loss=losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
history = model.fit(train_data, epochs=epochs, validation_data=test_data)
plt.plot(history.history['val_accuracy'], label='test_accuracy')
plt.plot(history.history['accuracy'], label='train_accuracy')
plt.legend()
plt.show()

