基于YOLO的车牌检测识别(YOLO+Transformer)

news/2024/9/19 0:50:41/ 标签: YOLO, transformer, 车牌识别

概述:
基于深度学习的车牌识别,其中,车辆检测网络直接使用YOLO侦测。而后,才是使用网络侦测车牌与识别车牌号。

车牌的侦测网络,采用的是resnet18,网络输出检测边框的仿射变换矩阵,可检测任意形状的四边形。

车牌号序列模型,采用Resnet18+transformer模型,直接输出车牌号序列。

数据集上,车牌检测使用CCPD 2019数据集,在训练检测模型的时候,会使用程序生成虚假的车牌,覆盖于数据集图片上,来加强检测的能力。

车牌号的序列识别,直接使用程序生成的车牌图片训练,并佐以适当的图像增强手段。模型的训练直接采用端到端的训练方式,输入图片,直接输出车牌号序列,损失采用CTCLoss。

一、网络模型
1、车牌的侦测网络模型:

网络代码定义如下:

class WpodNet(nn.Module):def __init__(self):"""车牌侦测网络,直接使用Resnet18,仅改变输出层。"""super(WpodNet, self).__init__()resnet = resnet18(True)backbone = list(resnet.children())self.backbone = nn.Sequential(nn.BatchNorm2d(3),*backbone[:3],*backbone[4:8],)self.detection = nn.Conv2d(512, 8, 3, 1, 1)def forward(self, x):features = self.backbone(x)out = self.detection(features)out = rearrange(out, 'n c h w -> n h w c') # 变换形状return out

该网络,相当于直接对图片划分cell,即在16X16的格子中,侦测车牌,输出的为该车牌边框的反射变换矩阵。

2、车牌号的序列识别网络:
车牌号序列识别的主干网络:采用的是ResNet18+transformer,其中有ResNet18完成对图片的编码工作,再由transformer解码为对应的字符。

网络代码定义如下:

from torch import nn
from torchvision.models import resnet18
import torch
from einops import rearrangeclass OcrNet(nn.Module):def __init__(self,num_class):super(OcrNet, self).__init__()resnet = resnet18(True)backbone = list(resnet.children())self.backbone = nn.Sequential(nn.BatchNorm2d(3),*backbone[:3],*backbone[4:8],)  # 创建ResNet18self.decoder = nn.Sequential(Block(512, 8, False),Block(512, 8, False),)  # 由Transformer 构成的解码器self.out_layer = nn.Linear(512, num_class)  # 线性输出层self.abs_pos_emb = AbsPosEmb((3, 9), 512)  # 绝对位置编码def forward(self,x):x = self.backbone(x)x = rearrange(x,'n c h w -> n (w h) c')x = x + self.abs_pos_emb()x = self.decoder(x)x = rearrange(x, 'n s v -> s n v')return self.out_layer(x)

其中的Block类的代码如下:

class Block(nn.Module):r"""Args:embed_dim: 词向量的特征数。num_head: 多头注意力的头数。is_mask: 是否添加掩码。是,则网络只能看到每个词前的内容,而无法看到后面的内容。Shape:- Input: N,S,V (批次,序列数,词向量特征数)- Output:same shape as the inputExamples::# >>> m = Block(720, 12)# >>> x = torch.randn(4, 13, 720)# >>> output = m(x)# >>> print(output.shape)# torch.Size([4, 13, 720])"""def __init__(self, embed_dim, num_head, is_mask):super(Block, self).__init__()self.ln_1 = nn.LayerNorm(embed_dim)self.attention = SelfAttention(embed_dim, num_head, is_mask)self.ln_2 = nn.LayerNorm(embed_dim)self.feed_forward = nn.Sequential(nn.Linear(embed_dim, embed_dim * 6),nn.ReLU(),nn.Linear(embed_dim * 6, embed_dim))def forward(self, x):'''计算多头自注意力'''attention = self.attention(self.ln_1(x))'''残差'''x = attention + xx = self.ln_2(x)'''计算feed forward部分'''h = self.feed_forward(x)x = h + x  # 增加残差return x

位置编码的代码如下:

class AbsPosEmb(nn.Module):def __init__(self,fmap_size,dim_head):super().__init__()height, width = fmap_sizescale = dim_head ** -0.5self.height = nn.Parameter(torch.randn(height, dim_head) * scale)self.width = nn.Parameter(torch.randn(width, dim_head) * scale)def forward(self):emb = rearrange(self.height, 'h d -> h () d') + rearrange(self.width, 'w d -> () w d')emb = rearrange(emb, ' h w d -> (w h) d')return emb

Block类使用的自注意力代码如下:

class SelfAttention(nn.Module):r"""多头自注意力Args:embed_dim: 词向量的特征数。num_head: 多头注意力的头数。is_mask: 是否添加掩码。是,则网络只能看到每个词前的内容,而无法看到后面的内容。Shape:- Input: N,S,V (批次,序列数,词向量特征数)- Output:same shape as the inputExamples::# >>> m = SelfAttention(720, 12)# >>> x = torch.randn(4, 13, 720)# >>> output = m(x)# >>> print(output.shape)# torch.Size([4, 13, 720])"""def __init__(self, embed_dim, num_head, is_mask=True):super(SelfAttention, self).__init__()assert embed_dim % num_head == 0self.num_head = num_headself.is_mask = is_maskself.linear1 = nn.Linear(embed_dim, 3 * embed_dim)self.linear2 = nn.Linear(embed_dim, embed_dim)def forward(self, x):'''x 形状 N,S,V'''x = self.linear1(x)  # 形状变换为N,S,3Vn, s, v = x.shape"""分出头来,形状变换为 N,S,H,V"""x = x.reshape(n, s, self.num_head, -1)"""换轴,形状变换至 N,H,S,V"""x = torch.transpose(x, 1, 2)'''分出Q,K,V'''query, key, value = torch.chunk(x, 3, -1)dk = value.shape[-1] ** 0.5'''计算自注意力'''w = torch.matmul(query, key.transpose(-1, -2)) / dk  # w 形状 N,H,S,Sif self.is_mask:"""生成掩码"""mask = torch.tril(torch.ones(w.shape[-1], w.shape[-1])).to(w.device)w = w * mask - 1e10 * (1 - mask)w = torch.softmax(w, dim=-1)  # softmax归一化attention = torch.matmul(w, value)  # 各个向量根据得分合并合并, 形状 N,H,S,V'''换轴至 N,S,H,V'''attention = attention.permute(0, 2, 1, 3)n, s, h, v = attention.shape'''合并H,V,相当于吧每个头的结果cat在一起。形状至N,S,V'''attention = attention.reshape(n, s, h * v)return self.linear2(attention)  # 经过线性层后输出

二、数据加载
1、车牌号的数据加载
通过程序生成一组车牌号:
在这里插入图片描述

再通过数据增强,主要包括:
随机污损:
在这里插入图片描述
高斯模糊:
在这里插入图片描述
仿射变换,粘贴于一张大图中:
在这里插入图片描述
四边形的四个角的位置随机偏移些许后扣出:
在这里插入图片描述

然后直接训练车牌号的序列识别网络,

loss_func = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00001)

优化器直接使用Adam,损失函数为CTCLoss。

2、车牌检测的数据加载
数据使用的是CCPD数据集,在这过程中,会随机的使用生成车牌,覆盖原始图片的车牌位置,来训练网络对车牌的检测能力。

if random.random() < 0.5:plate, _ = self.draw()plate = cv2.cvtColor(plate, cv2.COLOR_RGB2BGR)plate = self.smudge(plate)  # 随机污损image = enhance.apply_plate(image, points, plate)  # 粘贴车牌图片于数据图中
[x1, y1, x2, y2, x4, y4, x3, y3] = points
points = [x1, x2, x3, x4, y1, y2, y3, y4]
image, pts = enhance.augment_detect(image, points, 208)

三、训练
分别训练即可
其中,侦测网络的损失计算,如下:

def count_loss(self, predict, target):condition_positive = target[:, :, :, 0] == 1  # 筛选标签condition_negative = target[:, :, :, 0] == 0predict_positive = predict[condition_positive]predict_negative = predict[condition_negative]target_positive = target[condition_positive]target_negative = target[condition_negative]n, v = predict_positive.shapeif n > 0:loss_c_positive = self.c_loss(predict_positive[:, 0:2], target_positive[:, 0].long())else:loss_c_positive = 0loss_c_nagative = self.c_loss(predict_negative[:, 0:2], target_negative[:, 0].long())loss_c = loss_c_nagative + loss_c_positiveif n > 0:affine = torch.cat((predict_positive[:, 2:3],predict_positive[:,3:4],predict_positive[:,4:5],predict_positive[:,5:6],predict_positive[:,6:7],predict_positive[:,7:8]),dim=1)# print(affine.shape)# exit()trans_m = affine.reshape(-1, 2, 3)unit = torch.tensor([[-0.5, -0.5, 1], [0.5, -0.5, 1], [0.5, 0.5, 1], [-0.5, 0.5, 1]]).transpose(0, 1).to(trans_m.device).float()# print(unit)point_pred = torch.einsum('n j k, k d -> n j d', trans_m, unit)point_pred = rearrange(point_pred, 'n j k -> n (j k)')loss_p = self.l1_loss(point_pred, target_positive[:, 1:])else:loss_p = 0# exit()return loss_c, loss_p

侦测网络输出的反射变换矩阵,但对车牌位置的标签给的是四个角点的位置,所以需要响应转换后,做损失。其中,该cell是否有目标,使用CrossEntropyLoss,而对车牌位置损失,采用的则是L1Loss。

四、推理

根目录下运行,

python kenshutsu.py

记得修改py文件中的模型权重路径位置。

在这里插入图片描述

推理解析:
1、侦测网络的推理
按照一般侦测网络,推理即可。只是,多了一步将反射变换矩阵转换为边框位置的计算。
另外,在YOLO侦测到得测量图片传入该级进行车牌检测的时候,会做一步操作。代码见下,将车辆检测框的图片扣出,然后resize到长宽均为16的整数倍。

h, w, c = image.shape
f = min(288 * max(h, w) / min(h, w), 608) / min(h, w)
_w = int(w * f) + (0 if w % 16 == 0 else 16 - w % 16)
_h = int(h * f) + (0 if h % 16 == 0 else 16 - h % 16)
image = cv2.resize(image, (_w, _h), interpolation=cv2.INTER_AREA)

在这里插入图片描述

2、序列检测网络的推理
对网络输出的序列,进行去重操作即可,如间隔标识符为“*”时:

def deduplication(self, c):'''符号去重'''temp = ''new = ''for i in c:if i == temp:continueelse:if i == '*':temp = icontinuenew += itemp = ireturn new

五、完整代码

https://github.com/HibikiJie/LicensePlate


http://www.ppmy.cn/news/1519368.html

相关文章

「bug」nvitop ERROR: Failed to initialize curses

nvitop 作为一个优秀个 Nvidia显卡查询库&#xff0c;简单易用且显示信息十分丰富&#xff0c;相比 Nvidia-smi 更方便&#xff0c;简直是每个 开发人员必备的库&#xff0c;安装也十分方便&#xff0c;直接采用 pip install nvitop 即可&#xff0c;调用的时候也是直接在 Term…

【Python报错已解决】“ModuleNotFoundError: No module named ‘timm‘”

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 文章目录 引言&#xff1a;一、问题描述1.1 报错示例&#xff1a;当我们尝试导入timm库时&#xff0c;可能会看到以下错误信息。…

k8s sa

在Kubernetes&#xff08;K8S&#xff09;中&#xff0c;SA是Service Account&#xff08;服务账户&#xff09;的简称。Service Account是Kubernetes集群中的一种资源对象&#xff0c;用于识别和验证Pod访问集群中其他资源的身份。以下是关于K8S SA的详细解释&#xff1a; 一、…

JavaScript中将style的String类型转换成Object类型

在React开发中&#xff0c;我们或许经常遇到要将font-size:20px;转换成对象类型{fontSize:"20px"},如下我自己写了个类&#xff0c;正则匹配-后面的第一个字为大写字母&#xff0c;并且去掉-,然后将:后的属性转换为字符串类型&#xff0c;代码如下 function styleSt…

GitLab 是什么?GitLab使用常见问题解答

GitLab 是什么 GitLab是由GitLab Inc.开发&#xff0c;使用MIT许可证的基于网络的Git仓库管理工具开源项目&#xff0c;且具有wiki和issue跟踪功能&#xff0c;使用Git作为代码管理工具&#xff0c;并在此基础上搭建起来的web服务。 ​GitLab 是由 GitLab Inc.开发&#xff0c…

k3s安装部署说明

前言: 为什么不是k8s&#xff0c;k8s机子要求资源太高了&#xff0c;先来个简单的k3s 1:环境 ubuntu18 2安装docker ubuntu18.0.4 如下 1:禁用防火墙及SELinux(可能需要禁止) systemctl stop firewalld && systemctl disable firewalld 2: 开启路由转发 sudo vim /e…

微信小程序:自定义扫码功能

我们今天主要是介绍小程序自定义扫码的应用&#xff0c;相关业务处理可以根据自己需求来填补 WXML&#xff1a; <view class"scan-box direction-column" wx:if"{{ showCanScan }}"><camera class"camera" resolution"high"…

摄像头的ISP和SOC的GPU有区别吗?

摄像头的主芯片必须包含ISP&#xff0c;也就是图像处理器核心。而SOC的GPU或者说显卡也包含图像处理器也就是GPU。两者并无本质区别&#xff0c;都是实现数字图像处理算法。同样的用FPGA做内窥镜图像处理和用FPGA做显示图像处理器本质上也是一样的。 当然两者存在一些细微差别…

【BLE】四.SMP安全配对详解

设备配对流程 SMP专业术语 Paring&#xff08;配对&#xff09;&#xff1a; 配对能力交换&#xff0c;设备认证&#xff0c;密钥生成&#xff0c;连接加密以及机密信息分发等 过程 Bonding&#xff08;绑定&#xff09; 配对中会生成一个长期密钥&#xff08;LTK&#xff0c;…

灾难性遗忘问题(Catastrophic Forgetting,CF)是什么?

灾难性遗忘问题&#xff08;Catastrophic Forgetting&#xff0c;CF&#xff09;是什么&#xff1f; 在深度学习和人工智能领域中&#xff0c;“灾难性遗忘”&#xff08;Catastrophic Forgetting&#xff09;是指当神经网络在增量学习&#xff08;Incremental Learning&#…

使用 Milvus Lite、Llama3 和 LlamaIndex 搭建 RAG 应用

大语言模型&#xff08;LLM&#xff09;已经展示出与人类交互并生成文本响应的卓越能力。这些模型可以执行各种自然语言任务&#xff0c;如翻译、概括、代码生成和信息检索等。 为完成这些任务&#xff0c;LLM 需要基于海量数据进行预训练。在这个过程中&#xff0c;LLM 基于给…

ComfyUI:基于差分扩散的像素级图像修改

在几个月的沉寂之后&#xff0c;差分扩散&#xff08;Differential Diffusion&#xff09;被引入了。 玩了几天之后&#xff0c;我仍然对结果感到震惊。 这种新的先进方法允许以像素为基础进行更改&#xff0c;而不是以整个区域为基础进行更改。 另一种可能更通俗的说法&…

Git 基础使用--权限管理--用户和用户组授权

&#x1f600;前言 本篇博文是关于Git 基础使用–权限管理–用户和用户组授权&#xff0c;希望你能够喜欢 &#x1f3e0;个人主页&#xff1a;晨犀主页 &#x1f9d1;个人简介&#xff1a;大家好&#xff0c;我是晨犀&#xff0c;希望我的文章可以帮助到大家&#xff0c;您的满…

SparkSQL缓存的用法

前言 SparkSQL关于缓存的操作语句官方给了三种: CACHE TABLE(缓存表)UNCACHE TABLE(清除指定缓存表)CLEAR CACHE(清除所有缓存表)下面我们详细讲解这些语句的使用方法。 CACHE TABLE CACHE TABLE 语句使用给定的存储级别缓存表的内容或查询的输出。如果一个查询被缓存…

部署Rancher2.9管理K8S1.26集群

文章目录 一、实验须知1、Rancher简介2、当前实验环境 二、部署Rancher1、服务器初始化操作2、部署Rancher3、登入Rancher平台 三、Rancher对接K8S集群四、通过Rancher仪表盘部署Nginx服务1、创建命名空间2、创建Deployment3、创建Service 一、实验须知 1、Rancher简介 中文官…

Redis KEY操作实战手册:从设计到维护的全面指南

​ &#x1f308; 个人主页&#xff1a;danci_ &#x1f525; 系列专栏&#xff1a;《设计模式》《MYSQL》 &#x1f4aa;&#x1f3fb; 制定明确可量化的目标&#xff0c;坚持默默的做事。 ✨欢迎加入探索Redis的key的相关操作之旅✨ &#x1f44b; 大家好&#xff01;文本…

spring入门(一)spring简介

一、spring简介 Spring技术是JavaEE开发必备技能&#xff0c;企业开发技术选型命中率>90% spring能够简化开发&#xff0c;降低企业级开发的复杂性。框架整合&#xff0c;高效整合其他技术&#xff0c;提高企业级应用开发与运行效率。 主要学习&…

《计算机操作系统》(第4版)第12章 保护和安全 复习笔记

第12章 保护和安全 一、安全环境 1.实现“安全环境”的主要目标和面临的威胁 实现“安全环境”的主要目标和威胁如表12-1所示。 目标 威胁 数据机密性 对应的威胁为有人通过各种方式窃取系统中的机密信息使数据暴露 数据完整性 对应的威胁为攻击者擅自修改系统中所保存的数…

metallb-speaker缓存

手动修改metallb-config arping返回2个mac地址 删除pod mac正常返回 pkill 进程 返回2个mac

SQL进阶技巧:用户不同 WiFi 行为区间划分分析 | 断点分组问题

目录 0 场景描述 1 数据准备 2 问题分析 3 小结 0 场景描述 现有用户扫描或连接 WiFi 记录表 user_wifi_log ,每一行数据表示某时刻用户扫描或连接 WiFi 的日志。 现需要进行用户行为分析,如何划分用户不同 WiFi 行为区间?满足: 1)行为类型分为两种:连接(scan)、…