猫狗分类识别模型建立②模型建立

ops/2024/10/18 14:26:46/

一、导入依赖库

pip install opencv-python  
pip install numpy  
pip install tensorflow
pip install keras

二、模型建立

'''
pip install opencv-python  
pip install numpy  
pip install tensorflow
pip install keras
'''
import os
import xml.etree.ElementTree as ETimport cv2
import numpy as np
from keras.layers import Input
from keras.models import Model
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical# 设置文件夹路径
images_dir = "imgs/"
annotations_dir = "imgs/"
num_classes = 2  # 设置类别总数
input_shape = (128, 128, 3)
# 模型名称
model_name = "dog_cat.keras"
# 用于存储图像数据和标签的列表
images = []
labels = []"""
1 dog 狗
2 cat 猫
"""
# 假设我们有一个从标签文本到标签索引的映射字典
label_to_index = {"dog": 0,"cat": 1,# ... 添加其他类别
}# 遍历文件夹加载数据
for filename in os.listdir(images_dir):if filename.endswith(".png"):image_path = os.path.join(images_dir, filename)annotation_path = os.path.join(annotations_dir, filename[:-4] + ".xml")# 读取图像image = cv2.imread(image_path)image = cv2.resize(image, (128, 128))  # 调整图像大小images.append(image)# 解析XML标注文件获取标签tree = ET.parse(annotation_path)root = tree.getroot()object_element = root.find("object")if object_element is not None:label_text = object_element.find("name").textlabel_index = label_to_index.get(label_text)if label_index is not None:labels.append(label_index)else:print(f"Warning: Unknown label '{label_text}', skipping.")# 转换为NumPy数组并进行归一化
images = np.array(images) / 255.0
labels = np.array(labels)# 确保所有的标签都是有效的整数
if labels.dtype != int:raise ValueError("Labels must contain only integers.")labels = to_categorical(labels, num_classes=num_classes)  # 假设num_classes是类别的总数# 使用Functional API定义模型
# 创建一个输入层,shape参数指定了输入数据的形状,input_shape是一个之前定义的变量,表示输入数据的维度。
inputs = Input(shape=input_shape)
# 下面的每一行都是通过一个层对数据进行处理,并将处理后的结果传递给下一个层。
# 对输入数据进行卷积操作,使用32个3x3的卷积核,并使用ReLU激活函数。结果赋值给变量x。
x = Conv2D(32, (3, 3), activation="relu")(inputs)
# 对x进行最大池化操作,池化窗口大小为2x2。这有助于减少数据的空间尺寸,从而减少计算量并提取更重要的特征。
x = MaxPooling2D(pool_size=(2, 2))(x)
# 再次进行卷积操作,这次使用64个3x3的卷积核,并继续使用ReLU激活函数。
x = Conv2D(128, (3, 3), activation="relu")(x)
# 再次进行最大池化操作,进一步减少数据的空间尺寸。
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Flatten()(x)  # 将多维的数据展平为一维,以便后续可以连接到全连接层(或称为密集层)。
# 创建一个全连接层,包含64个神经元,并使用ReLU激活函数。这一层可以进一步提取和组合特征。
x = Dense(128, activation="relu")(x)
# 创建一个输出层,神经元的数量与类别的数量(num_classes)相等。使用softmax激活函数,将输出转换为概率分布。
outputs = Dense(num_classes, activation="softmax")(x)
# 使用输入和输出来创建模型实例
model = Model(inputs=inputs, outputs=outputs)  # 通过指定输入和输出来定义模型的结构。
# 编译模型,指定优化器、损失函数和评估指标
# 使用Adam优化器、分类交叉熵损失函数,并监控准确性指标。
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])# 使用图像数据和标签训练模型
# 使用fit方法训练模型,指定训练数据、训练轮次(epochs)和批处理大小(batch_size)。
model.fit(images, labels, epochs=55, batch_size=512)# 保存训练好的模型到文件
# 将训练好的模型保存为HDF5文件,以便以后加载和使用。
model.save(model_name)
# keras.saving.save_model(model, "cnn_model.keras")
# model.save("cnn_model.h5")

三、文件结构及构建的模型

①文件结构

②建立后的模型

四、模型可视化

可视化工具与兼容的python版本下载。


http://www.ppmy.cn/ops/45172.html

相关文章

6.S081的Lab学习——Lab5: xv6 lazy page allocation

文章目录 前言一、Eliminate allocation from sbrk() (easy)解析: 二、Lazy allocation (moderate)解析: 三、Lazytests and Usertests (moderate)解析: 总结 前言 一个本硕双非的小菜鸡,备战24年秋招。打算尝试6.S081&#xff0…

自然语言处理(NLP)中的迁移学习

Transfer Learning in NLP 迁移学习(Transfer Learning)无疑是目前深度学习中的新热点(相对而言)。在计算机视觉领域,它已经应用了一段时间,人们使用经过训练的模型从庞大的ImageNet数据集中学习特征&…

服务器感染了. rmallox勒索病毒,如何确保数据文件完整恢复?

导言: 近年来,随着信息技术的飞速发展,网络安全问题日益凸显。其中,勒索病毒作为一种严重的网络威胁,对个人和企业数据造成了巨大的威胁。本文将重点介绍.rmallox勒索病毒的特点、传播途径以及应对策略,旨…

前端加密的方式汇总

目录 一、Base64编码 二、哈希算法 三、对称加密(AES/DES) 四、非对称加密(RSA) 五、加盐 六、Web Cryptography API 七、总结 随着信息和数据安全重要性的日益凸显,如何保证信息数据在传输的过程中的安全成为开发者重点关注的内容。前端加密通常是指在浏览…

MYSQL框架结构

MYSQL框架结构 通过解析器和预处理生成解析树,预处理判断是否合法,如果合法则调用优化器去进行优化。

win10安装rabbitmq

安装 第一步:下载并安装erlang RabbitMQ服务端代码是使用并发式语言Erlang编写,因此首先需要安装Erlang下载地址:http://www.erlang.org/downloads采用默认安装即可,选择适合的安装路径 添加环境变量 第二步:下载并…

头歌数据库备份与恢复

第1关:数据库的备份和恢复 mysql -uroot -p123123 -h127.0.0.1 < /data/workspace/myshixun/src/data.sqlmysqldump -u root -p studb student> /student_bk.sqlmysql -uroot -p123123 -h127.0.0.1 -e "create database studb2;"mysql -u root -p123123 studb…

去掉macOS终端命令行前的(base)

macOS在安装了Anaconda&#xff08;或miniconda&#xff09;后&#xff0c;每次打开terminal都会默认打开名为base的虚拟环境。 默认不启动base conda config --set auto_activate_base false默认启动base conda config --set auto_activate_base true