看图识药,python开发实现基于VisionTransformer的119种中草药图像识别系统

news/2025/2/21 7:03:05/

中药药材图像识别相关的实践在前面的系列博文中已经有了相应的实践了,感兴趣的话可以自行移步阅读即可,每篇文章的侧重点不同:

《python基于轻量级GhostNet模型开发构建23种常见中草药图像识别系统》

《基于轻量级MnasNet模型开发构建40种常见中草药图像识别系统》

《基于ResNet模型的908种超大规模中草药图像识别系统》

本文的核心思想是想要应用实践VIT(Vision Transformer)来开发构建图像识别系统,首先看下实例效果:

Vision Transformer(ViT)是一种基于自注意力机制的视觉模型,用于图像分类和其他计算机视觉任务。它是由Dosovitskiy等人在2020年提出的,将Transformer模型成功应用于图像领域。

ViT的构建原理如下:

  1. 输入图像划分为固定大小的图像块(或称为“补丁”),并通过一个线性变换将每个图像块映射为一个向量。这些向量组成了输入序列。

  2. 使用位置编码将位置信息引入输入序列。位置编码是一个学习的过程,用于为每个输入位置提供相对和绝对位置信息。

  3. 输入序列首先通过多头注意力(Multi-Head Attention)模块进行处理。多头注意力允许模型在不同的表示子空间中学习关注不同的图像特征。

  4. 在多头注意力模块中,每个补丁向量都与其他补丁向量进行交互,并计算其自注意力得分。这些得分表示了补丁之间的相关性,模型可以根据这些得分对不同补丁的重要性进行加权。

  5. 通过加权和补丁向量的线性组合,得到了每个补丁向量的新表示。这个表示包含了该补丁与其他补丁的相关性信息。

  6. 接下来,通过一个前馈神经网络(Feed-Forward Network)对每个补丁向量的新表示进行非线性变换,以更好地捕捉图像特征。

  7. 经过多个注意力和前馈神经网络堆叠的层,最终得到了一个编码了整个图像信息的向量序列。

  8. 为了进行图像分类,可以使用一个全局平均池化层(Global Average Pooling)将向量序列转换为一个固定长度的向量表示。然后,可以通过一个全连接层将这个向量映射到不同类别的概率分布。

总体来说,Vision Transformer通过将图像划分为补丁并利用自注意力机制对这些补丁进行交互,实现了对图像特征的学习和编码。相较于传统的卷积神经网络,ViT不需要使用卷积操作,而是完全基于自注意力机制进行图像特征的建模。

本文使用到的数据集来源于网络数据采集与人工处理,主要是收集了常见的100多种中药药材,数据集加载解析处理实现如下:

# 加载解析创建数据集
if not os.path.exists("dataset.json"):train_dataset = []test_dataset = []all_dataset = []classes_list = os.listdir(datasetDir)classes_list.sort()num_classes = len(classes_list)if not os.path.exists("labels.json"):with open("labels.json","w") as f:f.write(json.dumps(classes_list))print("classes_list: ", classes_list)for one_label in os.listdir(datasetDir):oneDir = datasetDir + one_label + "/"for one_pic in os.listdir(oneDir):one_path = oneDir + one_picone_ind = classes_list.index(one_label)all_dataset.append([one_ind, one_path])train_ratio = 0.90train_num = int(train_ratio * len(all_dataset))all_inds = list(range(len(all_dataset)))train_inds = random.sample(all_inds, train_num)test_inds = [one for one in all_inds if one not in train_inds]for one_ind in train_inds:train_dataset.append(all_dataset[one_ind])for one_ind in test_inds:test_dataset.append(all_dataset[one_ind])

简单看下实例数据:

【艾叶】

【陈皮】

【党参】

【何首乌】

基础的vit实现如下:

import tensorflow as tf
from tensorflow.keras import layersdef create_vision_transformer(input_shape, num_classes, num_layers, d_model, num_heads, mlp_dim, dropout_rate):inputs = tf.keras.Input(shape=input_shape)x = layers.Conv2D(filters=d_model, kernel_size=1)(inputs)x = layers.Reshape((-1, d_model))(x)x = layers.LayerNormalization(epsilon=1e-6)(x)# Patch embeddingsnum_patches = x.shape[1]patch_size = x.shape[2]x = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)(x, x)x = layers.Add()([x, inputs])x = layers.LayerNormalization(epsilon=1e-6)(x)x = layers.Conv1D(filters=d_model, kernel_size=1)(x)x = layers.LayerNormalization(epsilon=1e-6)(x)# Transformer Encoder layersfor _ in range(num_layers):# Attention and MLP blocky = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)(x, x)y = layers.Add()([y, x])y = layers.LayerNormalization(epsilon=1e-6)(y)y = layers.Conv1D(filters=mlp_dim, kernel_size=1, activation="relu")(y)y = layers.Conv1D(filters=d_model, kernel_size=1)(y)y = layers.Add()([y, x])x = layers.LayerNormalization(epsilon=1e-6)(y)# Classification headx = layers.GlobalAveragePooling1D()(x)x = layers.Dropout(rate=dropout_rate)(x)x = layers.Dense(units=num_classes, activation="softmax")(x)model = tf.keras.Model(inputs=inputs, outputs=x)return model# Example usage
input_shape = (224, 224, 3)
num_classes = 1000
num_layers = 12
d_model = 512
num_heads = 8
mlp_dim = 2048
dropout_rate = 0.1model = create_vision_transformer(input_shape, num_classes, num_layers, d_model, num_heads, mlp_dim, dropout_rate)
model.summary()

默认100次epoch的迭代计算,等待训练完成后对整体训练过程进行可视化,如下所示:
【准确率曲线】

【loss曲线】

可视化推理实例如下所示:

 


http://www.ppmy.cn/news/1267417.html

相关文章

gin投票系统3

对应视频v1版本 1.优化登陆接口 将同步改为异步 原login前端代码&#xff1a; <!doctype html> <html lang"en"> <head><meta charset"utf-8"><title>香香编程-投票项目</title> </head> <body> <m…

抓包工具:Sunny网络中间件

Sunny网络中间件 和 Fiddler 类似。 是可跨平台的网络分析组件 可用于HTTP/HTTPS/WS/WSS/TCP/UDP网络分析 为二次开发量身制作 支持 获取/修改 HTTP/HTTPS/WS/WSS/TCP/TLS-TCP/UDP 发送及返回数据 支持 对 HTTP/HTTPS/WS/WSS 指定连接使用指定代理 支持 对 HTTP/HTTPS/WS/WSS/T…

Failed to connect to github.com port 443 after 21055 ms: Timed out

目前自己使用了梯*子还是会报这样的错误&#xff0c;连接不到的github。 查了一下原因&#xff1a; 是因为这个请求没有走代理。 解决方案&#xff1a; 设置 -> 网络和Internet -> 代理 -> 编辑 记住这个IP和端口 使用以下命令&#xff1a; git config --global h…

CentOS使用kkFileView实现在线预览word excel pdf等

一、环境安装 1、安装LibreOffice wget https://downloadarchive.documentfoundation.org/libreoffice/old/7.5.3.2/rpm/x86_64/LibreOffice_7.5.3.2_Linux_x86-64_rpm.tar.gz # 解压缩 tar -zxf LibreOffice_7.5.3.2_Linux_x86-64_rpm.tar cd LibreOffice_7.5.3.2_Linux_x86…

Backtrader 文档学习-Quickstart

Backtrader 文档学习-Quickstart 0. 前言 backtrader&#xff0c;功能十分完善&#xff0c;有完整的使用文档&#xff0c;安装相对简单&#xff08;直接pip安装即可&#xff09;。 优点是运行速度快&#xff0c;支持pandas的矢量运算&#xff1b;支持参数自动寻优运算&#x…

Pytest自动化测试用例中的断言详解

前言 测试的主要工作目标就是验证实际结果与预期结果是否一致&#xff1b;在接口自动化测试中&#xff0c;通过断言来实现这一目标。Pytest中断言是通过assert语句实现的&#xff08;pytest对Python原生的assert语句进行了优化&#xff09;&#xff0c;确定实际情况是否与预期一…

三天精通Selenium Web 自动化 - Selenium(Java)环境搭建 (new)

0 背景 开发工具idea代码管理mavenjdk1.8webdriver chrome 1 chromedriver & chrome chromedriver和chrome要对应上&#xff1a; chomedriver下载地址&#xff1a;淘宝镜像 这里用的是 chromedriver88-0-4324-96.zipchrome下载地址&#xff1a;如何降级和安装旧版本的C…

智能查券机器人:导购APP的新趋势

智能查券机器人&#xff1a;导购APP的新趋势 大家好&#xff0c;我是免费搭建查券返利机器人赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01; 在当今这个数字化时代&#xff0c;网络购物已经成为人们日常生活的一部…