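Overview
First, obtain a suitable version of the nanodet model and set up an environment it can run in. Then re-annotate a raw public dataset found online, configure nanodet and train it, convert the .pth model to .onnx and simplify it, and finally write an inference script.
This article focuses on pointing you in the right practical direction; you will need to work out the finer details yourself.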
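Why nanodet
Although its detection accuracy is relatively low, nanodet is extremely lightweight and fast at inference, which makes it well suited to edge devices with limited compute and memory.
The author's open-source repository on GitHub is https://github.com/RangiLyu/nanodet. If you cannot reach GitHub directly, you can use the mirror at https://gitcode.net/mirrors/RangiLyu/nanodet.git.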
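Choosing a suitable nanodet version
My GPU is a GTX 1650 with limited compute, and some packages used by the latest nanodet would not work with it, so I switched to an older version: the one last updated in May 2021.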
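Setting up the nanodet environment and training
Most of this can be worked out by carefully reading README.md from the Demo section through How to Train. There is no real need to install nanodet as a package; you can simply work with the source code directly and modify it.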
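Setting up an environment for nanodet
I used Anaconda to create the environment. Anaconda is an environment manager, roughly a tool for configuring Python interpreters. Compared with letting PyCharm install packages automatically, it gives you better control over environments and package versions and is quite convenient. If you don't have it, search for installation and setup tutorials, or simply let PyCharm install the packages for you.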
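In Anaconda, create an environment with python==3.9.16 and activate it with activate <env>. Then copy the package list from requirements.txt in the downloaded nanodet source and try to install every package with pip install <package>. Remove torch and torchvision from the list first: they must match your CUDA version (pip installs the latest release by default, which very likely does not match your CUDA), and both should be installed with the command recommended on the official PyTorch website, which installs them correctly in one go. Note that torchvision==0.11.0+cu111 is available only for Linux, so do not install it or the torch version it pairs with.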
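Removing torch and torchvision by hand is easy to forget, so here is a small sketch that writes a filtered copy of the requirements file. It assumes requirements.txt sits in the nanodet repo root; the output name requirements_no_torch.txt is my own placeholder.

```python
from pathlib import Path

# Installed separately from the PyTorch website so they match your CUDA version.
SKIP = {"torch", "torchvision"}


def package_name(line: str) -> str:
    """Return the bare package name of a requirements line, e.g. 'torch>=1.7' -> 'torch'."""
    name = line.split("#", 1)[0].strip()  # drop comments
    for sep in ("==", ">=", "<=", "~=", "!=", ">", "<", "[", ";", " "):
        name = name.split(sep, 1)[0]
    return name.strip().lower()


def filter_requirements(lines):
    """Keep every non-empty requirement except the packages listed in SKIP."""
    kept = []
    for line in lines:
        name = package_name(line)
        if name and name not in SKIP:
            kept.append(line.strip())
    return kept


if __name__ == "__main__":
    src = Path("requirements.txt")  # assumed location: nanodet repo root
    kept = filter_requirements(src.read_text(encoding="utf-8").splitlines())
    Path("requirements_no_torch.txt").write_text("\n".join(kept) + "\n", encoding="utf-8")
    print("\n".join(kept))
```

Afterwards, pip install -r requirements_no_torch.txt installs the rest in one command, and torch/torchvision come from the PyTorch website command.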
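To check that the environment is set up correctly, switch PyCharm to the new conda environment and see whether it reports missing packages, then run demo.py (nanodet-main\demo\demo.py). If demo.py runs normally, the environment is ready. Note that parse_args() reads command-line arguments, so to verify, activate the environment in Anaconda Prompt, cd to the folder containing demo.py inside the model directory, and run it with the arguments given in the Demo section of README.md.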
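Before running demo.py, a quick script can confirm that the key modules import and whether torch sees your GPU. This is a sketch: the module list is my assumption based on the packages used in this article (cv2 is the import name of opencv-python), and torch is only queried if it is installed.

```python
import importlib
import importlib.util

# Import names (not pip names) of packages this article relies on; edit for your setup.
REQUIRED = ["torch", "torchvision", "cv2", "numpy", "pycocotools",
            "omegaconf", "pytorch_lightning", "onnx"]


def missing_modules(names):
    """Return the names that cannot be found by the current interpreter."""
    return [n for n in names if importlib.util.find_spec(n) is None]


def cuda_report():
    """Describe the torch/CUDA status, or say why it could not be checked."""
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed"
    torch = importlib.import_module("torch")
    return (f"torch {torch.__version__}, built for CUDA {torch.version.cuda}, "
            f"cuda available: {torch.cuda.is_available()}")


if __name__ == "__main__":
    print("missing:", missing_modules(REQUIRED) or "none")
    print(cuda_report())
```

If cuda available prints False while nvcc -V reports CUDA 11.1, the installed torch build most likely does not match your CUDA version.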
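For this step, see the verification section of this blogger's post: http://t.csdn.cn/XGiuU
Detailed environment setup steps
Below are the notes I took while setting up the environment, for reference only. Copying them directly to your machine will very likely cause problems, so configure according to your own system. (When installing, it is best to use only pip or only conda; mixing the two as I did may leave pip list incomplete if you later need to migrate the environment.)
Setup: Windows 10, GTX 1650 GPU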
conda 22.9.0
CUDA 11.1 (check with nvcc -V)
First create a python==3.9.16 environment in Anaconda Prompt,
using conda create -n py_3.12 Cython matplotlib numpy omegaconf=2.0.1 onnx onnx-simplifier opencv-python pyaml pycocotools pytorch-lightning=1.9.0 tabulate tensorboard termcolor torch=1.10 torchmetrics torchvision tqdm python=3.9
to try installing all the packages. Conda could not find the following packages automatically, so they had to be installed manually one by one:
- onnx-simplifier
- opencv-python
- pytorch-lightning=1.9.0
- torch=1.10
- omegaconf==2.0.1
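- pycocotools

I tried installing them with the following command:
conda install -c necla-ml onnx-simplifier --channel https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
Because pip might be needed, I cloned the virtual environment built so far and continued installing packages in the clone
(clone command: conda create -n py_3.12_new --clone py_3.12), trying the following commands: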
pip install opencv-python
conda install -c https://conda.anaconda.org/menpo opencv
conda install -c conda-forge pytorch-lightning
pip install omegaconf==2.0.1
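conda install -c conda-forge pycocotools

torch and torchvision must match your CUDA version and should be installed with the command recommended on the official PyTorch website (that command installs them correctly in one go; note that torchvision==0.11.0+cu111 is available only for Linux, so do not install it or its matching torch).
My torch version is v1.9.1 (download the version that matches your own machine).
The install command I used:
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
If the machine freezes during installation, the cause is that PyTorch had been downloaded earlier but not installed because of insufficient permissions. Two changes are needed: add --no-cache-dir between pip and install, and append --user to the end of the command: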
pip --no-cache-dir install torch==1.8.2+cu111 torchvision==0.9.2+cu111 torchaudio===0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html --user
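Also note that the code I shared requires GPU acceleration; without a GPU, you need to modify the code or download the CPU version from the address given at the end of README.MD.
The commands given in that code, namely
'''Object detection - image'''
python detect_main.py image --config ./config/nanodet-m.yml --model model/nanodet_m.pth --path street.png
'''Object detection - video file'''
python detect_main.py video --config ./config/nanodet-m.yml --model model/nanodet_m.pth --path test.mp4
'''Object detection - webcam'''
python detect_main.py webcam --config ./config/nanodet-m.yml --model model/nanodet_m.pth --path 0
are command-line commands: in the virtual environment, cd into the package containing the model and paste them into the command line to run them. (I was unable to test the second one; it probably does not work.)
My conda channel settings and other info:
conda info
active environment : py_3.12_new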
active env location : D:\ANACONDA\envs\py_3.12_new
shell level : 2
user config file : C:\Users\lenovo\.condarc
populated config files : C:\Users\lenovo\.condarc
conda version : 22.9.0
conda-build version : 3.22.0
python version : 3.9.13.final.0
virtual packages : __cuda=11.2=0
__win=0=0
__archspec=1=x86_64
base environment : D:\ANACONDA (read only)
conda av data dir : D:\ANACONDA\etc\conda
conda av metadata url : None
channel URLs : https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/win-64
https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/noarch
https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r/win-64
https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r/noarch
https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2/win-64
https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2/noarch
package cache : D:\ANACONDA\pkgs
C:\Users\lenovo\.conda\pkgs
C:\Users\lenovo\AppData\Local\conda\conda\pkgs
envs directories : C:\Users\lenovo\.conda\envs
D:\ANACONDA\envs
C:\Users\lenovo\AppData\Local\conda\conda\envs
platform : win-64
user-agent : conda/22.9.0 requests/2.28.1 CPython/3.9.13 Windows/10 Windows/10.0.19044
administrator : False
netrc file : None
offline mode : False
My virtual environment's package list:
conda list
# packages in environment at D:\ANACONDA\envs\py_3.12_new:
#
# Name Version Build Channel
absl-py 1.3.0 py39haa95532_0 defaults
aiohttp 3.8.3 py39h2bbff1b_0 defaults
aiosignal 1.2.0 pyhd3eb1b0_0 defaults
async-timeout 4.0.2 py39haa95532_0 defaults
attrs 22.1.0 py39haa95532_0 defaults
blas 1.0 mkl defaults
blinker 1.4 py39haa95532_0 defaults
brotli 1.0.9 h2bbff1b_7 defaults
brotli-bin 1.0.9 h2bbff1b_7 defaults
brotlipy 0.7.0 py39h2bbff1b_1003 defaults
ca-certificates 2023.01.10 haa95532_0 defaults
cachetools 4.2.2 pyhd3eb1b0_0 defaults
certifi 2022.12.7 py39haa95532_0 defaults
cffi 1.15.1 py39h2bbff1b_3 defaults
charset-normalizer 2.0.4 pyhd3eb1b0_0 defaults
click 8.0.4 py39haa95532_0 defaults
colorama 0.4.6 py39haa95532_0 defaults
contourpy 1.0.5 py39h59b6b97_0 defaults
cryptography 39.0.1 py39h21b164f_0 defaults
cycler 0.11.0 pyhd3eb1b0_0 defaults
cython 0.29.33 py39hd77b12b_0 defaults
flit-core 3.6.0 pyhd3eb1b0_0 defaults
fonttools 4.25.0 pyhd3eb1b0_0 defaults
freetype 2.12.1 ha860e81_0 defaults
frozenlist 1.3.3 py39h2bbff1b_0 defaults
future 0.18.3 py39haa95532_0 defaults
giflib 5.2.1 h8cc25b3_3 defaults
glib 2.69.1 h5dc1a3c_2 defaults
google-auth 2.6.0 pyhd3eb1b0_0 defaults
google-auth-oauthlib 0.4.4 pyhd3eb1b0_0 defaults
grpcio 1.42.0 py39hc60d5dd_0 defaults
gst-plugins-base 1.18.5 h9e645db_0 defaults
gstreamer 1.18.5 hd78058f_0 defaults
icu 58.2 ha925a31_3 defaults
idna 3.4 py39haa95532_0 defaults
importlib-metadata 4.11.3 py39haa95532_0 defaults
importlib_resources 5.2.0 pyhd3eb1b0_1 defaults
intel-openmp 2021.4.0 haa95532_3556 defaults
jpeg 9e h2bbff1b_1 defaults
kiwisolver 1.4.4 py39hd77b12b_0 defaults
lerc 3.0 hd77b12b_0 defaults
libbrotlicommon 1.0.9 h2bbff1b_7 defaults
libbrotlidec 1.0.9 h2bbff1b_7 defaults
libbrotlienc 1.0.9 h2bbff1b_7 defaults
libclang 12.0.0 default_h627e005_2 defaults
libdeflate 1.17 h2bbff1b_0 defaults
libffi 3.4.2 hd77b12b_6 defaults
libiconv 1.16 h2bbff1b_2 defaults
libogg 1.3.5 h2bbff1b_1 defaults
libpng 1.6.39 h8cc25b3_0 defaults
libprotobuf 3.20.3 h23ce68f_0 defaults
libtiff 4.5.0 h6c2663c_2 defaults
libuv 1.44.2 h2bbff1b_0 defaults
libvorbis 1.3.7 he774522_0 defaults
libwebp 1.2.4 hbc33d0d_1 defaults
libwebp-base 1.2.4 h2bbff1b_1 defaults
libxml2 2.9.14 h0ad7f3c_0 defaults
libxslt 1.1.35 h2bbff1b_0 defaults
lz4-c 1.9.4 h2bbff1b_0 defaults
markdown 3.4.1 py39haa95532_0 defaults
markdown-it-py 2.2.0 pypi_0 pypi
markupsafe 2.1.1 py39h2bbff1b_0 defaults
matplotlib 3.7.0 py39haa95532_0 defaults
matplotlib-base 3.7.0 py39hf11a4ad_0 defaults
mdurl 0.1.2 pypi_0 pypi
mkl 2021.4.0 haa95532_640 defaults
mkl-service 2.4.0 py39h2bbff1b_0 defaults
mkl_fft 1.3.1 py39h277e83a_0 defaults
mkl_random 1.2.2 py39hf11a4ad_0 defaults
multidict 6.0.2 py39h2bbff1b_0 defaults
munkres 1.1.4 py_0 defaults
ninja 1.10.2 haa95532_5 defaults
ninja-base 1.10.2 h6d14046_5 defaults
numpy 1.23.5 py39h3b20f71_0 defaults
numpy-base 1.23.5 py39h4da318b_0 defaults
oauthlib 3.2.1 py39haa95532_0 defaults
onnx 1.13.0 py39h9724e47_0 defaults
onnx-simplifier 0.4.17 pypi_0 pypi
opencv-contrib-python 4.7.0.72 pypi_0 pypi
opencv-python 4.7.0.72 pypi_0 pypi
openssl 1.1.1t h2bbff1b_0 defaults
packaging 22.0 py39haa95532_0 defaults
pcre 8.45 hd77b12b_0 defaults
pillow 9.4.0 py39hd77b12b_0 defaults
pip 23.0.1 py39haa95532_0 defaults
ply 3.11 py39haa95532_0 defaults
protobuf 3.20.3 py39hd77b12b_0 defaults
pyaml 20.4.0 pyhd3eb1b0_0 defaults
pyasn1 0.4.8 pyhd3eb1b0_0 defaults
pyasn1-modules 0.2.8 py_0 defaults
pycocotools 2.0.4 py39h5d4886f_1 conda-forge
pycparser 2.21 pyhd3eb1b0_0 defaults
pygments 2.14.0 pypi_0 pypi
pyjwt 2.4.0 py39haa95532_0 defaults
pyopenssl 23.0.0 py39haa95532_0 defaults
pyparsing 3.0.9 py39haa95532_0 defaults
pyqt 5.15.7 py39hd77b12b_0 defaults
pyqt5-sip 12.11.0 py39hd77b12b_0 defaults
pysocks 1.7.1 py39haa95532_0 defaults
python 3.9.16 h6244533_1 defaults
python-dateutil 2.8.2 pyhd3eb1b0_0 defaults
python_abi 3.9 2_cp39 conda-forge
pyyaml 6.0 py39h2bbff1b_1 defaults
qt-main 5.15.2 he8e5bd7_7 defaults
qt-webengine 5.15.9 hb9a9bb5_5 defaults
qtwebkit 5.212 h3ad3cdb_4 defaults
requests 2.28.1 py39haa95532_0 defaults
requests-oauthlib 1.3.0 py_0 defaults
rich 13.3.2 pypi_0 pypi
rsa 4.7.2 pyhd3eb1b0_1 defaults
setuptools 65.6.3 py39haa95532_0 defaults
sip 6.6.2 py39hd77b12b_0 defaults
six 1.16.0 pyhd3eb1b0_1 defaults
sqlite 3.40.1 h2bbff1b_0 defaults
tabulate 0.8.10 py39haa95532_0 defaults
tensorboard 2.10.0 py39haa95532_0 defaults
tensorboard-data-server 0.6.1 py39haa95532_0 defaults
tensorboard-plugin-wit 1.8.1 py39haa95532_0 defaults
termcolor 2.1.0 py39haa95532_0 defaults
tk 8.6.12 h2bbff1b_0 defaults
toml 0.10.2 pyhd3eb1b0_0 defaults
torch 1.9.1+cu111 pypi_0 pypi
torchaudio 0.9.1 pypi_0 pypi
torchvision 0.10.1+cu111 pypi_0 pypi
tornado 6.2 py39h2bbff1b_0 defaults
tqdm 4.64.1 py39haa95532_0 defaults
typing-extensions 4.4.0 py39haa95532_0 defaults
typing_extensions 4.4.0 py39haa95532_0 defaults
tzdata 2022g h04d1e81_0 defaults
urllib3 1.26.14 py39haa95532_0 defaults
vc 14.2 h21ff451_1 defaults
vs2015_runtime 14.27.29016 h5e58377_2 defaults
werkzeug 2.2.2 py39haa95532_0 defaults
wheel 0.38.4 py39haa95532_0 defaults
win_inet_pton 1.1.0 py39haa95532_0 defaults
wincertstore 0.2 py39haa95532_2 defaults
xz 5.2.10 h8cc25b3_1 defaults
yaml 0.2.5 he774522_0 defaults
yarl 1.8.1 py39h2bbff1b_0 defaults
zipp 3.11.0 py39haa95532_0 defaults
zlib 1.2.13 h8cc25b3_0 defaults
zstd 1.5.2 h19a0ad4_0 defaults
pip list
Package Version
------------------------- ------------
absl-py 1.3.0
aiohttp 3.8.3
aiosignal 1.2.0
altgraph 0.17.3
async-timeout 4.0.2
attrs 22.1.0
blinker 1.4
brotlipy 0.7.0
cachetools 4.2.2
certifi 2022.12.7
cffi 1.15.1
charset-normalizer 2.0.4
click 8.0.4
colorama 0.4.6
coloredlogs 15.0.1
contourpy 1.0.5
cryptography 39.0.1
cycler 0.11.0
Cython 0.29.33
flatbuffers 23.3.3
flit_core 3.6.0
fonttools 4.25.0
frozenlist 1.3.3
future 0.18.3
google-auth 2.6.0
google-auth-oauthlib 0.4.4
grpcio 1.42.0
humanfriendly 10.0
idna 3.4
importlib-metadata 4.11.3
importlib-resources 5.2.0
kiwisolver 1.4.4
Markdown 3.4.1
markdown-it-py 2.2.0
MarkupSafe 2.1.1
matplotlib 3.7.0
mdurl 0.1.2
mkl-fft 1.3.1
mkl-random 1.2.2
mkl-service 2.4.0
mpmath 1.3.0
multidict 6.0.2
munkres 1.1.4
numpy 1.23.5
oauthlib 3.2.1
omegaconf 2.0.1
onnx 1.13.0
onnx-simplifier 0.4.17
onnxruntime 1.14.1
opencv-contrib-python 4.7.0.72
opencv-python 4.7.0.72
packaging 22.0
pefile 2023.2.7
Pillow 9.4.0
pip 23.0.1
ply 3.11
protobuf 3.20.3
pyaml 20.4.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycocotools 2.0.4
pycparser 2.21
Pygments 2.14.0
pyinstaller 5.8.0
pyinstaller-hooks-contrib 2023.0
PyJWT 2.4.0
pyOpenSSL 23.0.0
pyparsing 3.0.9
PyQt5 5.15.7
PyQt5-sip 12.11.0
pyreadline3 3.4.1
PySocks 1.7.1
python-dateutil 2.8.2
pywin32-ctypes 0.2.0
PyYAML 6.0
requests 2.28.1
requests-oauthlib 1.3.0
rich 13.3.2
rsa 4.7.2
setuptools 65.6.3
sip 6.6.2
six 1.16.0
sympy 1.11.1
tabulate 0.8.10
tensorboard 2.10.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
termcolor 2.1.0
toml 0.10.2
torch 1.9.1+cu111
torchaudio 0.9.1
torchvision 0.10.1+cu111
tornado 6.2
tqdm 4.64.1
typing_extensions 4.4.0
urllib3 1.26.14
Werkzeug 2.2.2
wheel 0.38.4
win-inet-pton 1.1.0
wincertstore 0.2
yarl 1.8.1
zipp 3.11.0
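Re-annotating the dataset
I used labelImg to re-annotate the data as XML (a label format). The tool is simple to use; its drawback is that it has no batch processing. When annotating, put the image files and the XML label files in the same folder and give each pair the same file name prefix, so that boxes you have already annotated are displayed automatically.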
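Because labelImg pairs an image with its XML label by file name prefix, it can help to list files without a partner before training (an image with no XML is simply unannotated). A minimal sketch, assuming images and .xml files sit in one folder; the extension set and the folder name dataset are my own placeholders:

```python
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}  # assumed image types


def unpaired_files(folder):
    """Return (images without an .xml label, .xml labels without an image), matched by file stem."""
    folder = Path(folder)
    images = {p.stem for p in folder.iterdir() if p.suffix.lower() in IMAGE_EXTS}
    labels = {p.stem for p in folder.iterdir() if p.suffix.lower() == ".xml"}
    return sorted(images - labels), sorted(labels - images)


if __name__ == "__main__":
    no_label, no_image = unpaired_files("dataset")  # hypothetical folder name
    print("images without labels:", no_label)
    print("labels without images:", no_image)
```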
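For details, see this blogger's article: http://t.csdn.cn/cegAT
Configuring nanodet for training
Read How to Train in README.md, modify the config as described there, and run the appropriate train.py.
First, modify the config file:
My dataset uses XML labels, so I chose to modify the config file nanodet_custom_xml_dataset.yml. You may wonder why there are so many config files and what they all are: they are simply different nanodet sub-models. Look through the Model Zoo and pick one you like.
Open the chosen config and modify the parameters according to README.md; you can also refer to the training section of this blogger's post: http://t.csdn.cn/aVE2D. A few additions of my own. You can train directly on an XML dataset with nanodet_custom_xml_dataset.yml (that blogger's nanodet version probably did not include this file yet). "data:train:name:" must not be changed arbitrarily, because it specifies the data format of the training set; in nanodet_custom_xml_dataset.yml it is xml_dataset. "device:gpu_ids" is the list of GPU indices to use: [0] means the single GPU with index 0; on a multi-GPU server, list the indices of all the GPUs you want to use. "device:workers_per_gpu:" and "device:batchsize_per_gpu:" set the workload of each GPU; if your GPU is weak, set them lower, because values that are too large cause errors (on my 1650 I used 1 and 36, respectively). "data:val:" is the validation set, which generally should not overlap with the training set.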
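The config pitfalls above can be checked mechanically once the YAML is loaded into a dict. This is a sketch under assumptions: the key layout (data.train, data.val and a top-level device section) follows the upstream nanodet configs as I understand them, so adjust the key paths if your version differs; check_cfg is my own helper name.

```python
def check_cfg(cfg):
    """Return human-readable warnings for a nanodet config already loaded into a dict."""
    warnings = []
    train, val = cfg["data"]["train"], cfg["data"]["val"]
    if train.get("name") != "xml_dataset":
        warnings.append(f"data.train.name is {train.get('name')!r}; XML labels need 'xml_dataset'")
    if train.get("img_path") == val.get("img_path"):
        warnings.append("data.train and data.val share img_path; keep the validation set separate")
    gpu_ids = cfg.get("device", {}).get("gpu_ids")
    if not (isinstance(gpu_ids, list) and all(isinstance(i, int) for i in gpu_ids)):
        warnings.append("device.gpu_ids should be a list of GPU indices, e.g. [0]")
    return warnings


if __name__ == "__main__":
    import yaml  # PyYAML, present in the environment listed above

    # Path assumed relative to the nanodet repo root.
    with open("config/nanodet_custom_xml_dataset.yml", encoding="utf-8") as f:
        for w in check_cfg(yaml.safe_load(f)) or ["no problems found"]:
            print(w)
```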
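Once the config is set, you can train with train.py. In the version I used, tools/train.py raised an error on a single GPU; switching to tools/deprecated/train.py made training work.
During training, the dataset is indexed first and then training starts. If you see a long stream of what looks like training output followed by an error a few seconds later, indexing probably finished but training failed; check your config against the error message, since a wrong setting is the most likely cause. When training finishes you get a .pth file (newer nanodet versions produce .ckpt) in the save_dir set in your config.
.pth to .onnx conversion and .onnx simplification
Because the original code only supports exporting .ckpt to .onnx, tools/export.py needs to be modified. The core call is torch.onnx.export(), and little else needs to change.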
import os
import argparse
import torch
from nanodet.model.arch import build_model
from nanodet.util import Logger, cfg, load_config, load_model_weight


def generate_output_names(head_cfg):
    # One output name per stride for class scores and box distributions (not used below)
    cls_names, dis_names = [], []
    for stride in head_cfg.strides:
        cls_names.append('cls_pred_stride_{}'.format(stride))
        dis_names.append('dis_pred_stride_{}'.format(stride))
    return cls_names + dis_names


def main(config, model_path, output_path, input_shape=(320, 320)):
    logger = Logger(-1, config.save_dir, False)
    model = build_model(config.model)
    # Load the weights onto the CPU regardless of the device they were saved from
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    load_model_weight(model, checkpoint, logger)
    model.eval()
    # Dummy input of shape (N, C, H, W), used to trace the graph
    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1])
    torch.onnx.export(model,
                      dummy_input,
                      output_path,
                      verbose=True,
                      keep_initializers_as_inputs=True,
                      opset_version=10)
    logger.log('finished exporting onnx ')


if __name__ == '__main__':
    cfg_path = r"D:\pythonProject\nanodet-main\config\nanodet_custom_xml_dataset.yml"  # config path
    model_path = r"D:\pythonProject\nanodet-main\workspace\nanodet_m_new_2\model_last.pth"  # trained weights
    out_path = r"D:\pythonProject\nanodet-main\workspace\nanodet_m_new_2\output_my.onnx"  # output model path
    load_config(cfg, cfg_path)
    main(cfg, model_path, out_path, input_shape=(320, 320))  # change to your own input size
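To sanity-check the exported graph, it helps to know roughly which output shapes the inference scripts later in this article expect. A sketch assuming a square 320 input, the strides (8, 16, 32) and reg_max = 7 used in those scripts, and the 3 classes helmet, with_mask and reflective_clothes; the exact order and any leading batch dimension of 1 may differ by nanodet version (the scripts squeeze that dimension away).

```python
def expected_head_outputs(input_size=320, strides=(8, 16, 32), num_classes=3, reg_max=7):
    """List (stride, cls_shape, reg_shape) for each level of a NanoDet-style head.

    Each level has (input_size / stride)**2 anchor points; the class branch gives one score
    per class and the regression branch gives reg_max + 1 bins for each of the 4 box sides.
    """
    outputs = []
    for stride in strides:
        points = (input_size // stride) ** 2  # one anchor point per feature-map cell
        outputs.append((stride, (points, num_classes), (points, 4 * (reg_max + 1))))
    return outputs


if __name__ == "__main__":
    for stride, cls_shape, reg_shape in expected_head_outputs():
        print(f"stride {stride}: cls {cls_shape}, reg {reg_shape}")
```

To compare with the real model, load it with cv2.dnn.readNet, set an input blob and print the shapes of the arrays returned by net.forward(net.getUnconnectedOutLayersNames()).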
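After conversion you can further simplify the model with onnx-simplifier (search for the package name to find tutorials), although in my case the model was unchanged after simplification.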
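For reference, onnx-simplifier can also be driven from Python through onnxsim.simplify, which returns the simplified model plus a boolean validation flag. A minimal sketch; the file names are placeholders, and comparing node counts is one simple way to see whether simplification changed anything:

```python
def simplify_onnx(in_path, out_path):
    """Simplify an ONNX model with onnx-simplifier; return node counts (before, after)."""
    # Imported inside the function so this file can be imported without onnx installed.
    import onnx
    from onnxsim import simplify

    model = onnx.load(in_path)
    model_simp, ok = simplify(model)  # ok is False if the simplified model fails onnxsim's check
    if not ok:
        raise RuntimeError("simplified model failed onnxsim's validation check")
    onnx.save(model_simp, out_path)
    return len(model.graph.node), len(model_simp.graph.node)


if __name__ == "__main__":
    before, after = simplify_onnx("output_my.onnx", "output_my_sim.onnx")
    print(f"nodes: {before} -> {after}")
```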
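This part is also covered in this (paid) blog post: http://t.csdn.cn/kCHJ8
Deployment scripts
This blogger's write-up on this part is excellent, so read it first: http://t.csdn.cn/KfbnS. The author deploys with opencv-python==4.5.2.52, so the deployment environment only needs to satisfy OpenCV's requirements (don't use the latest version, which raises errors; choose the Python and numpy versions to match OpenCV; my Python is 3.9.16). That author implemented detection on single images; I modified it slightly to run detection frame by frame on a video. The two scripts below respectively save the results to a .json file and visualize them directly.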
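Both scripts below letterbox each frame to the 320x320 network input in resize_image and then map every box back to the original frame in detect. The same arithmetic in plain Python, as a sketch for checking the coordinate mapping by hand; the function names are mine, and the logic mirrors resize_image with keep_ratio=True:

```python
def letterbox_params(src_h, src_w, in_h=320, in_w=320):
    """Return (newh, neww, top, left) as resize_image computes them with keep_ratio=True."""
    newh, neww, top, left = in_h, in_w, 0, 0
    if src_h != src_w:
        hw_scale = src_h / src_w
        if hw_scale > 1:  # portrait: full height, pad left and right
            neww = int(in_w / hw_scale)
            left = int((in_w - neww) * 0.5)
        else:  # landscape: full width, pad top and bottom
            newh = int(in_h * hw_scale)
            top = int((in_h - newh) * 0.5)
    return newh, neww, top, left


def box_to_source(box, src_h, src_w, in_h=320, in_w=320):
    """Map an (xmin, ymin, xmax, ymax) box from network-input pixels back to the original frame."""
    newh, neww, top, left = letterbox_params(src_h, src_w, in_h, in_w)
    ratio_h, ratio_w = src_h / newh, src_w / neww
    x1, y1, x2, y2 = box
    return (max(int((x1 - left) * ratio_w), 0),
            max(int((y1 - top) * ratio_h), 0),
            min(int((x2 - left) * ratio_w), src_w),
            min(int((y2 - top) * ratio_h), src_h))
```

For a 1280x720 frame, the image is scaled to 320x180 and padded with 70 rows on top, so a box covering the whole unpadded area maps back to the full frame.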
Saving results to .json
import argparse
import json
import time

import cv2
import numpy as np

classes = ['helmet', 'with_mask', 'reflective_clothes']
resultsW = []  # one entry per frame: inference time, boxes and labels


class my_nanodet():
    def __init__(self, model, input_shape=320, prob_threshold=0.4, iou_threshold=0.3):
        self.classes = classes
        self.num_classes = len(self.classes)
        self.strides = (8, 16, 32)
        self.input_shape = (input_shape, input_shape)
        self.reg_max = 7
        self.prob_threshold = prob_threshold
        self.iou_threshold = iou_threshold
        self.project = np.arange(self.reg_max + 1)
        self.mean = np.array([103.53, 116.28, 123.675], dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array([57.375, 57.12, 58.395], dtype=np.float32).reshape(1, 1, 3)
        self.net = cv2.dnn.readNet(model)
        # Precompute the anchor-point centers of every feature-map level
        self.mlvl_anchors = []
        for i in range(len(self.strides)):
            anchors = self._make_grid((int(self.input_shape[0] / self.strides[i]),
                                       int(self.input_shape[1] / self.strides[i])), self.strides[i])
            self.mlvl_anchors.append(anchors)

    def _make_grid(self, featmap_size, stride):
        feat_h, feat_w = featmap_size
        shift_x = np.arange(0, feat_w) * stride
        shift_y = np.arange(0, feat_h) * stride
        xv, yv = np.meshgrid(shift_x, shift_y)
        xv = xv.flatten()
        yv = yv.flatten()
        cx = xv + 0.5 * (stride - 1)
        cy = yv + 0.5 * (stride - 1)
        return np.stack((cx, cy), axis=-1)

    def softmax(self, x, axis=1):
        # Subtract the max for numerical stability; for a column vector use axis=0
        x_exp = np.exp(x - np.max(x, axis=axis, keepdims=True))
        x_sum = np.sum(x_exp, axis=axis, keepdims=True)
        return x_exp / x_sum

    def _normalize(self, img):  ### c++: https://blog.csdn.net/wuqingshan2010/article/details/107727909
        img = img.astype(np.float32)
        img = (img - self.mean) / self.std
        return img

    def resize_image(self, srcimg, keep_ratio=True):
        # Letterbox: keep the aspect ratio and pad the short side with zeros
        top, left, newh, neww = 0, 0, self.input_shape[0], self.input_shape[1]
        if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
            hw_scale = srcimg.shape[0] / srcimg.shape[1]
            if hw_scale > 1:
                newh, neww = self.input_shape[0], int(self.input_shape[1] / hw_scale)
                img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
                left = int((self.input_shape[1] - neww) * 0.5)
                img = cv2.copyMakeBorder(img, 0, 0, left, self.input_shape[1] - neww - left,
                                         cv2.BORDER_CONSTANT, value=0)  # add border
            else:
                newh, neww = int(self.input_shape[0] * hw_scale), self.input_shape[1]
                img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
                top = int((self.input_shape[0] - newh) * 0.5)
                img = cv2.copyMakeBorder(img, top, self.input_shape[0] - newh - top, 0, 0,
                                         cv2.BORDER_CONSTANT, value=0)
        else:
            img = cv2.resize(srcimg, self.input_shape, interpolation=cv2.INTER_AREA)
        return img, newh, neww, top, left

    def detect(self, srcimg):
        img, newh, neww, top, left = self.resize_image(srcimg)
        img = self._normalize(img)
        blob = cv2.dnn.blobFromImage(img)
        # Sets the input to the network
        self.net.setInput(blob)
        # Runs the forward pass to get output of the output layers
        outs = self.net.forward(self.net.getUnconnectedOutLayersNames())
        det_bboxes, det_conf, det_classid = self.post_process(outs)
        ratioh, ratiow = srcimg.shape[0] / newh, srcimg.shape[1] / neww
        anchor = []
        label = []
        # Collect results for the .json output; det_conf[i] is the confidence
        for i in range(det_bboxes.shape[0]):
            # Map the box from network-input coordinates back to the original frame
            xmin = max(int((det_bboxes[i, 0] - left) * ratiow), 0)
            ymin = max(int((det_bboxes[i, 1] - top) * ratioh), 0)
            xmax = min(int((det_bboxes[i, 2] - left) * ratiow), srcimg.shape[1])
            ymax = min(int((det_bboxes[i, 3] - top) * ratioh), srcimg.shape[0])
            anchor.append([xmin, ymin, xmax, ymax])
            label.append(self.classes[det_classid[i]])
        return anchor, label

    def post_process(self, preds):
        # Outputs alternate per stride: class scores, box distributions
        cls_scores, bbox_preds = preds[::2], preds[1::2]
        det_bboxes, det_conf, det_classid = self.get_bboxes_single(cls_scores, bbox_preds, 1, rescale=False)
        return det_bboxes.astype(np.int32), det_conf, det_classid

    def get_bboxes_single(self, cls_scores, bbox_preds, scale_factor, rescale=False):
        mlvl_bboxes = []
        mlvl_scores = []
        for stride, cls_score, bbox_pred, anchors in zip(self.strides, cls_scores, bbox_preds, self.mlvl_anchors):
            if cls_score.ndim == 3:
                cls_score = cls_score.squeeze(axis=0)
            if bbox_pred.ndim == 3:
                bbox_pred = bbox_pred.squeeze(axis=0)
            # Distribution over reg_max + 1 bins -> expected distance of each box side
            bbox_pred = self.softmax(bbox_pred.reshape(-1, self.reg_max + 1), axis=1)
            # bbox_pred = np.sum(bbox_pred * np.expand_dims(self.project, axis=0), axis=1).reshape((-1, 4))
            bbox_pred = np.dot(bbox_pred, self.project).reshape(-1, 4)
            bbox_pred *= stride
            # nms_pre = cfg.get('nms_pre', -1)
            nms_pre = 1000
            if nms_pre > 0 and cls_score.shape[0] > nms_pre:
                max_scores = cls_score.max(axis=1)
                topk_inds = max_scores.argsort()[::-1][0:nms_pre]
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                cls_score = cls_score[topk_inds, :]
            bboxes = self.distance2bbox(anchors, bbox_pred, max_shape=self.input_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(cls_score)
        mlvl_bboxes = np.concatenate(mlvl_bboxes, axis=0)
        if rescale:
            mlvl_bboxes /= scale_factor
        mlvl_scores = np.concatenate(mlvl_scores, axis=0)
        bboxes_wh = mlvl_bboxes.copy()
        bboxes_wh[:, 2:4] = bboxes_wh[:, 2:4] - bboxes_wh[:, 0:2]  #### xywh
        classIds = np.argmax(mlvl_scores, axis=1)
        confidences = np.max(mlvl_scores, axis=1)  #### max_class_confidence
        indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.prob_threshold, self.iou_threshold)
        if len(indices) > 0:
            # Older OpenCV returns shape (N, 1), newer returns (N,); flatten handles both
            indices = np.array(indices).flatten()
            mlvl_bboxes = mlvl_bboxes[indices]
            confidences = confidences[indices]
            classIds = classIds[indices]
            return mlvl_bboxes, confidences, classIds
        else:
            print('nothing detect')
            return np.array([]), np.array([]), np.array([])

    def distance2bbox(self, points, distance, max_shape=None):
        # points are anchor centers; distance holds (left, top, right, bottom)
        x1 = points[:, 0] - distance[:, 0]
        y1 = points[:, 1] - distance[:, 1]
        x2 = points[:, 0] + distance[:, 2]
        y2 = points[:, 1] + distance[:, 3]
        if max_shape is not None:
            x1 = np.clip(x1, 0, max_shape[1])
            y1 = np.clip(y1, 0, max_shape[0])
            x2 = np.clip(x2, 0, max_shape[1])
            y2 = np.clip(y2, 0, max_shape[0])
        return np.stack([x1, y1, x2, y2], axis=-1)

    def drawPred(self, frame, classId, conf, left, top, right, bottom):
        # Draw a bounding box.
        cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), thickness=4)
        label = '%.2f' % conf
        label = '%s:%s' % (self.classes[classId], label)
        # Display the label at the top of the bounding box
        labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        top = max(top, labelSize[1])
        # cv.rectangle(frame, (left, top - round(1.5 * labelSize[1])), (left + round(1.5 * labelSize[0]), top + baseLine), (255,255,255), cv.FILLED)
        cv2.putText(frame, label, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=2)
        return frame


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--video_path', type=str, default='检测视频.avi', help="video path")
    parser.add_argument('--model_path', type=str, default='new.onnx', help='(.onnx)model path')
    parser.add_argument('--sav_path', type=str, default='result.json', help='(.json)result file path')
    args = parser.parse_args()
    args.input_shape = 320  # input image shape
    args.confThreshold = 0.37  # class confidence
    args.nmsThreshold = 0.6  # nms iou thresh

    capture = cv2.VideoCapture(args.video_path)  # get video
    net = my_nanodet(input_shape=args.input_shape, prob_threshold=args.confThreshold,
                     iou_threshold=args.nmsThreshold, model=args.model_path)
    # Inference
    if capture.isOpened():
        while True:
            ret, img = capture.read()
            if not ret:
                break
            # Got a frame
            a = time.time()
            anchor, label = net.detect(img)
            ti = time.time() - a
            resultsW.append({"time": ti * 1000, "anchor": anchor, "label": label})  # time in ms
        capture.release()
        # Write all results to the file
        with open(args.sav_path, 'w') as f:
            json.dump(resultsW, f)
    else:
        print('open video error')
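The JSON script above writes one record per frame with the keys time (milliseconds), anchor and label. A small sketch for summarizing such a file afterwards; summarize is my own helper, and the file name result.json is the script's default --sav_path:

```python
import json
from collections import Counter


def summarize(results):
    """Summarize the per-frame records written by the .json script."""
    if not results:
        return {"frames": 0, "mean_ms": 0.0, "labels": {}}
    mean_ms = sum(r["time"] for r in results) / len(results)
    labels = Counter(lbl for r in results for lbl in r["label"])
    return {"frames": len(results), "mean_ms": mean_ms, "labels": dict(labels)}


if __name__ == "__main__":
    with open("result.json", encoding="utf-8") as f:
        print(summarize(json.load(f)))
```

Since time only covers net.detect, 1000 / mean_ms gives a rough upper bound on the achievable frame rate, excluding video decoding.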
Visualization
import argparse
import logging
import threading
import time

import cv2
import numpy as np

classes = ['helmet', 'with_mask', 'reflective_clothes']


def grab_img(cam):
    """This 'grab_img' function is designed to be run in the sub-thread.
    Once started, this thread continues to grab a new image and put it
    into 'cam.img_handle', until 'thread_running' is set to False.
    """
    fps = cam.cap.get(cv2.CAP_PROP_FPS) or 20  # fall back to 20 Hz if the file reports 0
    while cam.thread_running:
        _, cam.img_handle = cam.cap.read()
        time.sleep(1 / fps)  # pace reading at the video's frame rate
        if cam.img_handle is None:
            logging.warning('grab_img(): cap.read() returns None...')
            break
    cam.thread_running = False


class Camera():
    """Camera class which supports reading images from this video source:
    Video file
    """

    def __init__(self, args):
        self.args = args
        self.is_opened = False
        self.thread_running = False
        self.img_handle = None
        self.img_width = 0
        self.img_height = 0
        self.cap = None
        self.thread = None

    def open(self):
        args = self.args
        # Open the video file
        self.cap = cv2.VideoCapture(args.filename)
        if self.cap.isOpened():
            # Try to grab the 1st image and determine width and height
            _, img = self.cap.read()
            if img is not None:
                self.img_height, self.img_width, _ = img.shape
                self.img_handle = img  # so read() has a frame before the thread starts
                self.is_opened = True

    def start(self):
        assert not self.thread_running
        self.thread_running = True
        self.thread = threading.Thread(target=grab_img, args=(self,))
        self.thread.start()

    def stop(self):
        self.thread_running = False
        self.thread.join()

    def read(self):
        # Latest frame grabbed by the sub-thread (None at the end of the video)
        return self.img_handle

    def release(self):
        assert not self.thread_running
        self.cap.release()


class my_nanodet():
    def __init__(self, model, input_shape=320, prob_threshold=0.4, iou_threshold=0.3):
        self.classes = classes
        self.num_classes = len(self.classes)
        self.strides = (8, 16, 32)
        self.input_shape = (input_shape, input_shape)
        self.reg_max = 7
        self.prob_threshold = prob_threshold
        self.iou_threshold = iou_threshold
        self.project = np.arange(self.reg_max + 1)
        self.mean = np.array([103.53, 116.28, 123.675], dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array([57.375, 57.12, 58.395], dtype=np.float32).reshape(1, 1, 3)
        self.net = cv2.dnn.readNet(model)
        # Precompute the anchor-point centers of every feature-map level
        self.mlvl_anchors = []
        for i in range(len(self.strides)):
            anchors = self._make_grid((int(self.input_shape[0] / self.strides[i]),
                                       int(self.input_shape[1] / self.strides[i])), self.strides[i])
            self.mlvl_anchors.append(anchors)

    def _make_grid(self, featmap_size, stride):
        feat_h, feat_w = featmap_size
        shift_x = np.arange(0, feat_w) * stride
        shift_y = np.arange(0, feat_h) * stride
        xv, yv = np.meshgrid(shift_x, shift_y)
        xv = xv.flatten()
        yv = yv.flatten()
        cx = xv + 0.5 * (stride - 1)
        cy = yv + 0.5 * (stride - 1)
        return np.stack((cx, cy), axis=-1)

    def softmax(self, x, axis=1):
        # Subtract the max for numerical stability; for a column vector use axis=0
        x_exp = np.exp(x - np.max(x, axis=axis, keepdims=True))
        x_sum = np.sum(x_exp, axis=axis, keepdims=True)
        return x_exp / x_sum

    def _normalize(self, img):  ### c++: https://blog.csdn.net/wuqingshan2010/article/details/107727909
        img = img.astype(np.float32)
        img = (img - self.mean) / self.std
        return img

    def resize_image(self, srcimg, keep_ratio=True):
        # Letterbox: keep the aspect ratio and pad the short side with zeros
        top, left, newh, neww = 0, 0, self.input_shape[0], self.input_shape[1]
        if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
            hw_scale = srcimg.shape[0] / srcimg.shape[1]
            if hw_scale > 1:
                newh, neww = self.input_shape[0], int(self.input_shape[1] / hw_scale)
                img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
                left = int((self.input_shape[1] - neww) * 0.5)
                img = cv2.copyMakeBorder(img, 0, 0, left, self.input_shape[1] - neww - left,
                                         cv2.BORDER_CONSTANT, value=0)  # add border
            else:
                newh, neww = int(self.input_shape[0] * hw_scale), self.input_shape[1]
                img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
                top = int((self.input_shape[0] - newh) * 0.5)
                img = cv2.copyMakeBorder(img, top, self.input_shape[0] - newh - top, 0, 0,
                                         cv2.BORDER_CONSTANT, value=0)
        else:
            img = cv2.resize(srcimg, self.input_shape, interpolation=cv2.INTER_AREA)
        return img, newh, neww, top, left

    def detect(self, srcimg):
        img, newh, neww, top, left = self.resize_image(srcimg)
        img = self._normalize(img)
        blob = cv2.dnn.blobFromImage(img)
        # Sets the input to the network
        self.net.setInput(blob)
        # Runs the forward pass to get output of the output layers
        outs = self.net.forward(self.net.getUnconnectedOutLayersNames())
        det_bboxes, det_conf, det_classid = self.post_process(outs)
        drawimg = srcimg.copy()
        ratioh, ratiow = srcimg.shape[0] / newh, srcimg.shape[1] / neww
        # det_conf[i] is the confidence
        for i in range(det_bboxes.shape[0]):
            # Map the box back to the original frame, then draw it
            xmin = max(int((det_bboxes[i, 0] - left) * ratiow), 0)
            ymin = max(int((det_bboxes[i, 1] - top) * ratioh), 0)
            xmax = min(int((det_bboxes[i, 2] - left) * ratiow), srcimg.shape[1])
            ymax = min(int((det_bboxes[i, 3] - top) * ratioh), srcimg.shape[0])
            self.drawPred(drawimg, det_classid[i], det_conf[i], xmin, ymin, xmax, ymax)
        return drawimg

    def post_process(self, preds):
        # Outputs alternate per stride: class scores, box distributions
        cls_scores, bbox_preds = preds[::2], preds[1::2]
        det_bboxes, det_conf, det_classid = self.get_bboxes_single(cls_scores, bbox_preds, 1, rescale=False)
        return det_bboxes.astype(np.int32), det_conf, det_classid

    def get_bboxes_single(self, cls_scores, bbox_preds, scale_factor, rescale=False):
        mlvl_bboxes = []
        mlvl_scores = []
        for stride, cls_score, bbox_pred, anchors in zip(self.strides, cls_scores, bbox_preds, self.mlvl_anchors):
            if cls_score.ndim == 3:
                cls_score = cls_score.squeeze(axis=0)
            if bbox_pred.ndim == 3:
                bbox_pred = bbox_pred.squeeze(axis=0)
            # Distribution over reg_max + 1 bins -> expected distance of each box side
            bbox_pred = self.softmax(bbox_pred.reshape(-1, self.reg_max + 1), axis=1)
            # bbox_pred = np.sum(bbox_pred * np.expand_dims(self.project, axis=0), axis=1).reshape((-1, 4))
            bbox_pred = np.dot(bbox_pred, self.project).reshape(-1, 4)
            bbox_pred *= stride
            # nms_pre = cfg.get('nms_pre', -1)
            nms_pre = 1000
            if nms_pre > 0 and cls_score.shape[0] > nms_pre:
                max_scores = cls_score.max(axis=1)
                topk_inds = max_scores.argsort()[::-1][0:nms_pre]
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                cls_score = cls_score[topk_inds, :]
            bboxes = self.distance2bbox(anchors, bbox_pred, max_shape=self.input_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(cls_score)
        mlvl_bboxes = np.concatenate(mlvl_bboxes, axis=0)
        if rescale:
            mlvl_bboxes /= scale_factor
        mlvl_scores = np.concatenate(mlvl_scores, axis=0)
        bboxes_wh = mlvl_bboxes.copy()
        bboxes_wh[:, 2:4] = bboxes_wh[:, 2:4] - bboxes_wh[:, 0:2]  #### xywh
        classIds = np.argmax(mlvl_scores, axis=1)
        confidences = np.max(mlvl_scores, axis=1)  #### max_class_confidence
        indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.prob_threshold, self.iou_threshold)
        if len(indices) > 0:
            # Older OpenCV returns shape (N, 1), newer returns (N,); flatten handles both
            indices = np.array(indices).flatten()
            mlvl_bboxes = mlvl_bboxes[indices]
            confidences = confidences[indices]
            classIds = classIds[indices]
            return mlvl_bboxes, confidences, classIds
        else:
            print('nothing detect')
            return np.array([]), np.array([]), np.array([])

    def distance2bbox(self, points, distance, max_shape=None):
        # points are anchor centers; distance holds (left, top, right, bottom)
        x1 = points[:, 0] - distance[:, 0]
        y1 = points[:, 1] - distance[:, 1]
        x2 = points[:, 0] + distance[:, 2]
        y2 = points[:, 1] + distance[:, 3]
        if max_shape is not None:
            x1 = np.clip(x1, 0, max_shape[1])
            y1 = np.clip(y1, 0, max_shape[0])
            x2 = np.clip(x2, 0, max_shape[1])
            y2 = np.clip(y2, 0, max_shape[0])
        return np.stack([x1, y1, x2, y2], axis=-1)

    def drawPred(self, frame, classId, conf, left, top, right, bottom):
        # Draw a bounding box.
        cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), thickness=4)
        label = '%.2f' % conf
        label = '%s:%s' % (self.classes[classId], label)
        # Display the label at the top of the bounding box
        labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        top = max(top, labelSize[1])
        # cv.rectangle(frame, (left, top - round(1.5 * labelSize[1])), (left + round(1.5 * labelSize[0]), top + baseLine), (255,255,255), cv.FILLED)
        cv2.putText(frame, label, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=2)
        return frame


# Multithreaded attempt: a sub-thread reads frames while this thread runs inference (not called by default)
def main_thread():
    parser = argparse.ArgumentParser()
    parser.add_argument('--video_path', type=str, default='检测视频.avi', help="video path")
    parser.add_argument('--model_path', type=str, default='new.onnx', help='(.onnx)model')
    args = parser.parse_args()
    args.filename = args.video_path
    args.input_shape = 320  # input image shape
    args.confThreshold = 0.37  # class confidence
    args.nmsThreshold = 0.6  # nms iou thresh
    net = my_nanodet(input_shape=args.input_shape, prob_threshold=args.confThreshold,
                     iou_threshold=args.nmsThreshold, model=args.model_path)
    cam = Camera(args)  # get video
    cam.open()
    if not cam.is_opened:
        print('open video error')
        return
    cam.start()
    # Inference
    while cam.thread_running:
        img = cam.read()
        if img is None:
            break
        a = time.time()
        srcimg = net.detect(img)
        print('waste time', (time.time() - a) * 1000)  # ms per frame
        cv2.imshow('video test', srcimg)
        cv2.waitKey(1)  # needed for imshow to refresh the window
    cam.stop()
    cam.release()
    cv2.destroyAllWindows()


def main_one():
    parser = argparse.ArgumentParser()
    parser.add_argument('--video_path', type=str, default='检测视频.avi', help="video path")
    parser.add_argument('--model_path', type=str, default='new.onnx', help='(.onnx)model')
    args = parser.parse_args()
    args.input_shape = 320  # input image shape
    args.confThreshold = 0.37  # class confidence
    args.nmsThreshold = 0.6  # nms iou thresh
    capture = cv2.VideoCapture(args.video_path)  # get video
    net = my_nanodet(input_shape=args.input_shape, prob_threshold=args.confThreshold,
                     iou_threshold=args.nmsThreshold, model=args.model_path)
    # Inference
    if capture.isOpened():
        while True:
            ret, img = capture.read()
            if not ret:
                break
            # Got a frame
            a = time.time()
            srcimg = net.detect(img)
            print('waste time', (time.time() - a) * 1000)  # ms per frame
            cv2.imshow('video test', srcimg)
            cv2.waitKey(1)
        capture.release()
        cv2.destroyAllWindows()
    else:
        print('open video error')


if __name__ == '__main__':
    main_one()
[Figure: visualized detection results]
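Both scripts decode boxes in get_bboxes_single: for each anchor point, every box side comes as reg_max + 1 = 8 logits, which are softmaxed and turned into an expected bin index (the dot product with self.project), scaled by the stride, and then subtracted from or added to the anchor center in distance2bbox. The same steps for a single anchor in plain Python, as a sketch; the function names are mine:

```python
import math


def softmax(xs):
    """Softmax over a list of logits (max subtracted for numerical stability)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def decode_distance(bins, stride):
    """Expected distance in input pixels from one side's reg_max + 1 logits."""
    probs = softmax(bins)
    return stride * sum(i * p for i, p in enumerate(probs))


def clip(v, hi):
    return min(max(v, 0.0), hi)


def decode_box(cx, cy, side_logits, stride, max_h=320, max_w=320):
    """side_logits holds 4 logit lists for the left, top, right and bottom distances."""
    l, t, r, b = (decode_distance(s, stride) for s in side_logits)
    return (clip(cx - l, max_w), clip(cy - t, max_h), clip(cx + r, max_w), clip(cy + b, max_h))
```

With uniform logits the expected bin is 3.5, so at stride 8 every side sits 28 pixels from the center; sharply peaked logits select a single bin.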
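The second script contains a lot of redundant code (the Camera class, grab_img and main_thread): that is my attempt at multithreading, which I did not manage to finish. Running video reading and inference in separate threads would make the pipeline faster (closer to real time). If you still want to implement multithreading, see this article: http://t.csdn.cn/OqHlN.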
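One common way to split reading and inference is a producer/consumer pair around a one-slot queue.Queue: a background thread reads frames and always replaces any frame that inference has not picked up yet, so the main thread processes the newest frame. This is a swapped-in technique, not the author's Camera/grab_img code; run_pipeline and producer are my own names, and frames/infer are placeholders.

```python
import queue
import threading


def producer(frames, q):
    """Read frames and keep only the newest one in a 1-slot queue (stale frames are dropped)."""
    for frame in frames:
        try:
            q.get_nowait()  # drop a frame that inference has not picked up yet
        except queue.Empty:
            pass
        q.put(frame)
    q.put(None)  # sentinel: end of stream


def run_pipeline(frames, infer):
    """Run `infer` in the calling thread while a background thread reads `frames`."""
    q = queue.Queue(maxsize=1)
    t = threading.Thread(target=producer, args=(frames, q), daemon=True)
    t.start()
    results = []
    while True:
        frame = q.get()
        if frame is None:
            break
        results.append(infer(frame))
    t.join()
    return results
```

Here frames could be a generator that yields capture.read() images until ret is False, and infer could be net.detect. Inference and cv2.imshow stay in the main thread, which keeps the GUI calls in one place; the trade-off is that slow inference skips frames instead of lagging behind the video.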