RK3568笔记四:基于TensorFlow花卉图像分类部署

news/2025/2/11 18:48:34/

若该文为原创文章,转载请注明原文出处。

基于正点原子的ATK-DLRK3568部署测试。

花卉图像分类任务,使用使用 tf.keras.Sequential 模型,简单构建模型,然后转换成 RKNN 模型部署到ATK-DLRK3568板子上。

在 PC 使用 Windows 系统安装 tensorflow,并创建虚拟环境进行训练,然后切换到VM下的RK3568环境,使用rknn-toolkit2把模型转成rknn模型部署到RK3568板子上测试。

一、介绍

       TensorFlow 是一个基于数据流编程(dataflow programming)的符号数学系统,被广泛应用于机器学习(machine learning)算法的编程实现,其前身是谷歌的神经网络算法库 DistBelief。

使用 tf.keras.Sequential 模型对花卉图像进行分类。

二、环境搭建

1、创建虚拟环境

 conda create -n tensorflow_env python=3.8 -y

2、激活环境

conda activate tensorflow_env

3、安装环境

pip install numpypip install tensorflowpip install pillow

三、训练

1、下载数据集

https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz

数据集不好下载,自行处理。

2、训练

tensorflow_classification.py

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential# 获取
import pathlib
#dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
#data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = './flower_photos'
data_dir = pathlib.Path(data_dir)batch_size = 32
img_height = 180
img_width = 180# 划分数据
train_ds = tf.keras.utils.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.utils.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)class_names = train_ds.class_names
#print(class_names)# 处理数据
normalization_layer = layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))
num_classes = len(class_names)data_augmentation = keras.Sequential([layers.RandomFlip("horizontal",input_shape=(img_height,img_width,3)),layers.RandomRotation(0.1),layers.RandomZoom(0.1),]
)model = Sequential([data_augmentation,layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(32, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(64, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Dropout(0.2),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(num_classes, name="outputs")
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model.summary()# 训练模型
epochs=15
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs,
)# 测试模型
#sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
#sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)
sunflower_path = './test_180.jpg'img = tf.keras.utils.load_img(sunflower_path, target_size=(img_height, img_width)
)
img_array = tf.keras.utils.img_to_array(img)
img_array = tf.expand_dims(img_array, 0) # Create a batchpredictions = model.predict(img_array)
score = tf.nn.softmax(predictions[0])print("This image most likely belongs to {} with a {:.2f} percent confidence.".format(class_names[np.argmax(score)], 100 * np.max(score))
)# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()# Save the model.
with open('model.tflite', 'wb') as f:f.write(tflite_model)

代码有点需要注意,代码屏蔽了下载的功能,所以需要预先下载数据集,如果没有下载数据集,就需要把下载的代码开启。

#dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
#data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)

执行下面命令开始训练:

python tensorflow_classification.py

等待一会,会生成model.tflite模型文件。

四、RKNN模型转换

转换代码通过下面代码:

rknn_transfer.py

import numpy as np
import cv2
from rknn.api import RKNN
import tensorflow as tfimg_height = 180
img_width = 180
IMG_PATH = 'test.jpg'
class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']if __name__ == '__main__':# Create RKNN object#rknn = RKNN(verbose='Debug')rknn = RKNN()# Pre-process configprint('--> Config model')rknn.config(mean_values=[0, 0, 0], std_values=[255, 255, 255], target_platform='rk3568')print('done')# Load modelprint('--> Loading model')ret = rknn.load_tflite(model='model.tflite')if ret != 0:print('Load model failed!')exit(ret)print('done')# Build modelprint('--> Building model')ret = rknn.build(do_quantization=False)#ret = rknn.build(do_quantization=True,dataset='./dataset.txt')if ret != 0:print('Build model failed!')exit(ret)print('done')# Export rknn modelprint('--> Export rknn model')ret = rknn.export_rknn('./model.rknn')if ret != 0:print('Export rknn model failed!')exit(ret)print('done')#Init runtime environment
print('--> Init runtime environment')
ret = rknn.init_runtime()
#    if ret != 0:
#        print('Init runtime environment failed!')
#        exit(ret)
print('done')img = cv2.imread(IMG_PATH)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img,(180,180))
img = np.expand_dims(img, 0)#print('--> Accuracy analysis')
#rknn.accuracy_analysis(inputs=['./test.jpg'])
#print('done')print('--> Running model')
outputs = rknn.inference(inputs=[img])
print(outputs)
outputs = tf.nn.softmax(outputs)
print(outputs)print("This image most likely belongs to {} with a {:.2f} percent confidence.".format(class_names[np.argmax(outputs)], 100 * np.max(outputs))
)
#print("图像预测是:", class_names[np.argmax(outputs)])
print('--> done')rknn.release()

运行后会生成RKNN模型

五、部署

把rknnlite_inference.py和图片,及模型model.rknn拷贝到开发板上,终端运行即可。

rknnlite_inference.py源码:

import numpy as np
import cv2
from rknnlite.api import RKNNLiteIMG_PATH = 'test.jpg'
RKNN_MODEL = 'model.rknn'
img_height = 180
img_width = 180
class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']# Create RKNN object
rknn_lite = RKNNLite()# load RKNN model
print('--> Load RKNN model')
ret = rknn_lite.load_rknn(RKNN_MODEL)
if ret != 0:print('Load RKNN model failed')exit(ret)
print('done')# Init runtime environment
print('--> Init runtime environment')
ret = rknn_lite.init_runtime()
if ret != 0:print('Init runtime environment failed!')exit(ret)
print('done')# load image
img = cv2.imread(IMG_PATH)
img = cv2.resize(img,(180,180))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.expand_dims(img, 0)# runing model
print('--> Running model')
outputs = rknn_lite.inference(inputs=[img])
print("result: ", outputs)
print("This image most likely belongs to {}.".format(class_names[np.argmax(outputs)])
)rknn_lite.release()

终端中执行:python rknnlite_inference.py

结果识别为sunflowers。

如有侵权,或需要完整代码,请及时联系博主。


http://www.ppmy.cn/news/1168111.html

相关文章

【C++ Primer Plus学习记录】复合类型总结

数组、结构和指针是C的3种复合类型。 数组可以在一个数据对象中存储多个同种类型的值。通过索引或者下标,可以访问数组中各个元素。 结构可以将多个不同类型的值存储在同一个数据对象中,可以使用成员运算符(.)来访问其中的成员。…

(1)攻防世界web-Training-WWW-Robots

1.开启环境,查看网页 翻译一下 2.前往robots.txt 命令:http://61.147.171.105:57663/robots.txt 3.前往fl0g.php 命令:http://61.147.171.105:57663/fl0g.php 4.得到flag cyberpeace{92ec1ef9b6d900100399093b9ae9e386}

Vue Router 刷新当前页面

Vue项目, 在实际工作中, 有些时候需要在 加载完某些数据之后对当前页面进行刷新, 以期 onMounted 等生命周期函数, 或者 数据重新加载. 总之是期望页面可以重新加载一次. 目前总结有三种途径可实现以上需求: 一, reload 直接刷新页面 window.location.reload(); $router.go(…

代码格式化的使用

前言 本文主要介绍了代码格式化,以及各个平台如何使用快捷键进行代码格式化,如有错误之处,欢迎在评论区交流讨论~ 代码格式化 代码格式化是一种编程实践,它涉及调整源代码的外观,以提高可读性和一致性。 这包括调整缩进、空格、换行符和括号等元素的使…

搭建哨兵架构(windows)

参考文章:Windows CMD常用命令大全(值得收藏)_cmd命令-CSDN博客 搭建哨兵架构:redis-server.exe sentinel.conf --sentinel 1.在主节点上创建哨兵配置 - 在Master对应redis.conf同目录下新建sentinel.conf文件,名字绝…

http post协议发送本地压缩数据到服务器

1.客户端程序 import requests import os # 指定服务器的URL url "http://192.168.1.9:8000/upload"# 压缩包文件路径 folder_name "upload" file_name "test.7z" headers {Folder-Name: folder_name,File-Name: file_name } # 发送POST请求…

笔记本电脑Windows10安装

0 前提 安装windows10的电脑为老版联想笔记本电脑,内部没有硬盘,临时加装了1T的硬盘。 1u盘准备 准备u盘,大小大于16G。u盘作为系统盘时,需要将内部的其他文件备份,然后格式化。u盘格式化后,插入一款可以…

软考高项-项目资源管理的相关概念

项目管理资源管理过程 规划资源管理 定义如何估算、获取、管理和利用实物以及团队项目资源 估算活动资源 估算执行项目所需的团队资源,材料、设备和用品的类型和数量 获取资源 获取项目所需的团队成员、设施、设备、材料、用品和其他资源。项目团队组建&#xff…