基于CNN+RNNs(LSTM, GRU)的红点位置检测(pytorch)

devtools/2024/11/25 21:14:54/

1 项目背景

需要在图片精确识别三跟红线所在的位置,并输出这三个像素的位置。

在这里插入图片描述
其中,每跟红线占据不止一个像素,并且像素颜色也并不是饱和度和亮度极高的红黑配色,每个红线放大后可能是这样的。

在这里插入图片描述

而我们的目标是精确输出每个红点的位置,需要精确到像素。也就是说,对于每根红线,模型需要输出橙色箭头所指的像素而不是蓝色箭头所指的像素的位置。

在之前尝试过纯 RNNs 检测红点,但是准确率感人,在噪声极低的情况下并不能精准识别位置。但是有次尝试transformer位置编码之后发现效果不错:

实验loss完全准确的点
GRU129.66411762.0/9000 (20%)
LSTM249.20531267.0/9000 (14%)
Position embedding + GRU16.34035025.0/9000 (56%)
Position embedding + LSTM204.15511603.0/9000 (18%)

这说明模型的难点在于学习位置信息而不是寻找颜色有问题的点。联想到CNN也能提供位置信息,我决定尝试卷积一下的效果。

2 数据集

还是之前那个代码合成的数据集数据集,每个数据集规模在15000张图片左右,在没有加入噪音的情况下,每个样本预览如图所示:
在这里插入图片描述
加入噪音后,每个样本的预览如下图所示:

在这里插入图片描述

图中黑色部分包含比较弱的噪声,并非完全为黑色。

数据集包含两个文件,一个是文件夹,里面包含了jpg压缩的图像数据:
在这里插入图片描述
另一个是csv文件,里面包含了每个图像的名字以及3根红线所在的像素的位置。

在这里插入图片描述

3 思路

其实思路特别朴素。就是在RNNs要读序列化数据之前先用CNN把数据跑一遍,让原始的输入序列变成具有局部特征表示的嵌入表示,卷积后提取的特征输入到 RNN层,RNN 保持了序列中的长时依赖信息。接下来先用 fc1 把 RNN 的输出映射成分数,然后用 fc2 预测三个具体位置,经过 Sigmoid 输出 [0, 1] 的相对位置,再与宽度相乘得到真实位置。具体的流程如下图所示:

在这里插入图片描述

4 结果

在图片长度为1080、低噪声环境时,对比实验的结果如下:

实验loss完全准确的点
GRU129.66411762.0/9000 (20%)
LSTM249.20531267.0/9000 (14%)
CNN+GRU1419.5781601.0/9000 (7%)
CNN+LSTM1166.4599762.0/9000 (8%)

1080长度下图片抽样预测的效果如下:

在这里插入图片描述

在简单图片中的效果跟其他方法差距不大——基本都能准确定位红线,但是还是没办法做到像素级别的精确

在这里插入图片描述

可能是我的打开方式不对,但是CNN+RNN的效果并不如意。

从训练过程来看存在过拟合:

在这里插入图片描述

5 代码

CNN+GRU结构:


class CNN_GRU(nn.Module):def __init__(self, config):super(CNN_GRU, self).__init__()self.input_size = config.input_sizeself.hidden_size = config.hidden_sizeself.num_layers = config.num_layersself.device = config.device# CNNself.conv1 = nn.Conv1d(in_channels=self.input_size, out_channels=64, kernel_size=3, padding=1)self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)self.conv3 = nn.Conv1d(in_channels=128, out_channels=self.input_size, kernel_size=3, padding=1)self.gru = nn.GRU(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,batch_first=True, bidirectional=True, dropout=0.6)self.fc1 = nn.Sequential(nn.Linear(self.hidden_size * 2, 1))self.fc2 = nn.Sequential(nn.Linear(config.width, 3),  # predict 3 pointsnn.Sigmoid(),)self.scale = config.widthself.device = config.devicedef forward(self, x):x = x.squeeze(2)x = F.relu(self.conv1(x))  # (batch_size, 64, width)x = F.relu(self.conv2(x))  # (batch_size, 128, width)x = F.relu(self.conv3(x))  # (batch_size, input_size, width)x = x.permute(0, 2, 1)h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)output, _ = self.gru(x0, h0)scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 1080)predicted_positions = self.fc2(scores)scaled_predicted_positions = predicted_positions * self.scalefinal_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)return final_predicted_positions

CNN+LSTM结构:

class CNN_GRU(nn.Module):def __init__(self, config):super(CNN_GRU, self).__init__()self.input_size = config.input_sizeself.hidden_size = config.hidden_sizeself.num_layers = config.num_layersself.device = config.device# CNNself.conv1 = nn.Conv1d(in_channels=self.input_size, out_channels=64, kernel_size=3, padding=1)self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)self.conv3 = nn.Conv1d(in_channels=128, out_channels=self.input_size, kernel_size=3, padding=1)self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,batch_first=True, bidirectional=True, dropout=0.6)self.fc1 = nn.Sequential(nn.Linear(self.hidden_size * 2, 1))self.fc2 = nn.Sequential(nn.Linear(config.width, 3),  # predict 3 pointsnn.Sigmoid(),)self.scale = config.widthself.device = config.devicedef forward(self, x):x = x.squeeze(2)x = F.relu(self.conv1(x))  # (batch_size, 64, width)x = F.relu(self.conv2(x))  # (batch_size, 128, width)x = F.relu(self.conv3(x))  # (batch_size, input_size, width)x = x.permute(0, 2, 1)h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)output, _ = self.lstm(x, (h0, c0))scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 1080)predicted_positions = self.fc2(scores)scaled_predicted_positions = predicted_positions * self.scalefinal_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)return final_predicted_positions

路过的大佬有什么建议 ball ball 在评论区打出来,我会去尝试~


http://www.ppmy.cn/devtools/136936.html

相关文章

golang学习6-指针

指针就是地址。 指针变量就是存储地址的变量。 *p:解引用、间接引用。 栈帧:用来给函数运行提供内存空间。取内存于 stack 上。 当函数调用时,产生栈帧。函数调用结束,释放栈帧。 栈帧存储:1.同部变量。2.形参。(形参与局部变量存储地位等同)3.内存字段…

cookie反爬----普通服务器,阿里系

目录 一.常见COOKIE反爬 普通: 1. 简介 2. 加密原理 二.实战案例 1. 服务器响应cookie信息 1. 逆向目标 2. 逆向分析 2. 阿里系cookie逆向 1. 逆向目标 2. 逆向分析 实战: 无限debugger原理 1. Function("debugger").call() 2. …

神经网络(系统性学习四):深度学习——卷积神经网络(CNN)

相关文章: 神经网络中常用的激活函数神经网络(系统性学习一):入门篇神经网络(系统性学习二):单层神经网络(感知机)神经网络(系统性学习三)&#…

Apache Maven简介

Apache Maven 是一款强大的项目管理和构建自动化工具,主要应用于Java项目。它简化了构建流程、依赖管理以及项目配置。本文将向您介绍Apache Maven,解释其核心概念,并指导您掌握Maven的基本使用方法。 什么是Apache Maven? Mave…

golang实现TCP服务器与客户端的断线自动重连功能

1.服务端 2.客户端 生成服务端口程序: 生成客户端程序: 测试断线重连: 初始连接成功

git使用详解

一、git介绍 1、git简介 Git 是一个开源的分布式版本控制系统(最先进的,没有之一),用于敏捷高效地处理任何或小或大的项目。 Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发而开发的一个开放源码的版本控制软件。 Git 与常用…

基于python的机器学习(四)—— 聚类(一)

目录 一、聚类的原理与实现 1.1 聚类的概念和类型 1.2 如何度量距离 1.2.1 数据的类型 1.2.2 连续型数据的距离度量方法 1.2.3 离散型数据的距离度量方法 1.3 聚类的基本步骤 二、层次聚类算法 2.1 算法原理和实例 2.2 算法的Sklearn实现 2.2.1 层次聚类法的可视化实…

FPGA经验谈系列文章——7、预估逻辑级数

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 FPGA经验谈系列文章——7、预估逻辑级数 预估逻辑级数逻辑层级拆分1、加法器拆分2、比较器拆分总结预估逻辑级数 前面我们已经分析了加法器、比较器、条件语句的逻辑级数,那让我们来看一下下面这段代码,大…