Python深度学习实战:人脸关键点(15点)检测pytorch实现

news/2024/11/17 18:55:31/

引言

人脸关键点检测即对人类面部若干个点位置进行检测,可以通过这些点的变化来实现许多功能,该技术可以应用到很多领域,例如捕捉人脸的关键点,然后驱动动画人物做相同的面部表情;识别人脸的面部表情,让机器能够察言观色等等。
在这里插入图片描述

如何检测人脸关键点

本文是实现15点的检测,至于N点的原理都是一样的,使用的算法模型是深度神经网络,使用CV也是可以的。

如何检测

这个问题抽象出来,就是一个使用神经网络来进行预测的功能,只不过输出是15个点的坐标,训练数据包含15个面部的特征点和面部的图像(大小为96x96),15个特征点分别是:left_eye_center, right_eye_center, left_eye_inner_corner, left_eye_outer_corner, right_eye_inner_corner, right_eye_outer_corner, left_eyebrow_inner_end, left_eyebrow_outer_end, right_eyebrow_inner_end, right_eyebrow_outer_end, nose_tip, mouth_left_corner, mouth_right_corner, mouth_center_top_lip, mouth_center_bottom_lip
因此神经网络需要学习一个从人脸图像到15个关键点坐标间的映射。

使用的网络结构

在本文中,我们使用深度神经网络来实现该功能,基本卷积块使用Google的Inception网络,也就是使用GoogLeNet网络,该结构的网络是基于卷积神经网络来改进的,是一个含有并行连接的网络。
众所周知,卷积有滤波、提取特征的作用,但到底采用多大的卷积来提取特征是最好的呢?这个问题没有确切的答案,那就集百家之长:使用多个形状不一的卷积来提取特征并进行拼接,从而学习到更为丰富的特征;特别是里面加上了1x1的卷积结构,能够实现跨通道的信息交互和整合(其本质就是在多个channel上的线性求和),同时能在feature map通道数上的降维(读者可以验证计算一下,能够极大减少卷积核的参数),也能够增加非线性映射次数使得网络能够更深。
下面是Inception块的示意图:
在这里插入图片描述
整个GoogLeNet的结构如下所示:
在这里插入图片描述
接下来是代码实现部分,后续作者会补充神经网络的相关原理知识,若对此感兴趣的读者也可继续关注支持~

代码实现

import torch as tc
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.utils import shuffle# 对图片像素的处理
def proFunc1(data,testFlag:bool=False) -> tuple:data['Image'] = data['Image'].apply(lambda im: np.fromstring(im, sep=' '))# 处理nadata = data.dropna()  # 神经网络对数据范围较为敏感 /255 将所有像素都弄到[0,1]之间X = np.vstack(data['Image'].values) / 255X = X.astype(np.float32)# 特别注意 这里要变成 n channle w h 要跟卷积第一层相匹配X = X.reshape(-1, 1,96, 96) # 等会神经网络的输入层就是 96 96 黑白图片 通道只有一个# 只有训练集才有y 测试集返回一个None出去if not testFlag:  y = data[data.columns[:-1]].values# 规范化y = (y - 48) / 48  X, y = shuffle(X, y, random_state=42)  y = y.astype(np.float32)else:y = Nonereturn X,y# 工具类
class UtilClass:def __init__(self,model,procFun,trainFile:str='data/training.csv',testFile:str='data/test.csv') -> None:self.trainFile = trainFileself.testFile = testFileself.trainData = Noneself.testData = Noneself.trainTarget = Noneself.model = modelself.procFun = procFun@staticmethoddef procData(data, procFunc ,testFlag:bool=False) -> tuple:return procFunc(data,testFlag)def loadResource(self):rawTrain = pd.read_csv(self.trainFile)rawTest = pd.read_csv(self.testFile)self.trainData , self.trainTarget = self.procData(rawTrain,self.procFun)self.testData , _ = self.procData(rawTest,self.procFun,testFlag=True)def getTrain(self):return tc.from_numpy(self.trainData), tc.from_numpy(self.trainTarget)def getTest(self):return tc.from_numpy(self.testData)@staticmethoddef plotData(img, keyPoints, axis):axis.imshow(np.squeeze(img), cmap='gray') # 恢复到原始像素数据 keyPoints = keyPoints * 48 + 48 # 把keypoint弄到图上面axis.scatter(keyPoints[0::2], keyPoints[1::2], marker='o', c='c', s=40)# 自定义的卷积神经网络
class MyCNN(tc.nn.Module):def __init__(self,imgShape = (96,96,1),keyPoint:int = 15):super(MyCNN, self).__init__()self.conv1 = tc.nn.Conv2d(in_channels=1, out_channels =10, kernel_size=3)self.pooling = tc.nn.MaxPool2d(kernel_size=2)self.conv2 = tc.nn.Conv2d(10, 5, kernel_size=3)# 这里的2420是通过下面的计算得出的 如果改变神经网络结构了 # 需要计算最后的Liner的in_feature数量 输出是固定的keyPoint*2self.fc = tc.nn.Linear(2420, keyPoint*2)def forward(self, x):# print("start----------------------")batch_size = x.size(0)# x = x.view((-1,1,96,96))# print('after view shape:',x.shape)x = F.relu(self.pooling(self.conv1(x)))# print('conv1 size',x.shape)x = F.relu(self.pooling(self.conv2(x)))# print('conv2 size',x.shape)# print('end--------------------------')# 改形状x = x.view(batch_size, -1)# print(x.shape)x = self.fc(x)# print(x.shape)return x# GoogleNet基本的卷积块
class MyInception(nn.Module):def __init__(self,in_channels, c1, c2, c3, c4,) -> None:super().__init__()self.p1_1 = nn.Conv2d(in_channels, c1, kernel_size=1)self.p2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1)self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)self.p3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1)self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)def forward(self, x):p1 = F.relu(self.p1_1(x))p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))p4 = F.relu(self.p4_2(self.p4_1(x)))# 在通道维度上连结输出return tc.cat((p1, p2, p3, p4), dim=1)# GoogLeNet的设计 此处参数结果google大量实验得出
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),nn.ReLU(),nn.Conv2d(64, 192, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))b3 = nn.Sequential(MyInception(192, 64, (96, 128), (16, 32), 32),MyInception(256, 128, (128, 192), (32, 96), 64),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))b4 = nn.Sequential(MyInception(480, 192, (96, 208), (16, 48), 64),MyInception(512, 160, (112, 224), (24, 64), 64),MyInception(512, 128, (128, 256), (24, 64), 64),MyInception(512, 112, (144, 288), (32, 64), 64),MyInception(528, 256, (160, 320), (32, 128), 128),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))b5 = nn.Sequential(MyInception(832, 256, (160, 320), (32, 128), 128),MyInception(832, 384, (192, 384), (48, 128), 128),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())uClass = UtilClass(model=None,procFun=proFunc1)
uClass.loadResource()
xTrain ,yTrain = uClass.getTrain()
xTest = uClass.getTest()dataset = TensorDataset(xTrain, yTrain)
trainLoader = DataLoader(dataset, 64, shuffle=True, num_workers=4)# 训练net并进行测试 由于显示篇幅问题 只能打印出极为有限的若干测试图片效果
def testCode(net):optimizer = tc.optim.Adam(params=net.parameters())criterion = tc.nn.MSELoss()for epoch in range(30):trainLoss = 0.0# 这里是用的是mini_batch 也就是说 每次只使用mini_batch个数据大小来计算# 总共有total个 因此总共训练 total/mini_batch 次# 由于不能每组数据只使用一次 所以在下面还要使用一个for循环来对整体训练多次for batchIndex, data in enumerate(trainLoader, 0):input_, y = datayPred = net(input_)loss = criterion(yPred, y)optimizer.zero_grad()loss.backward()optimizer.step()trainLoss += loss.item()# 只在每5个epoch的最后一轮打印信息if batchIndex % 30 ==29 and not epoch % 5 :print("[{},{}] loss:{}".format(epoch + 1, batchIndex + 1, trainLoss / 300))trainLoss = 0.0# 测试print("-----------test begin-------------")# print(xTest.shape)yPost = net(xTest)# print(yPost.shape)import matplotlib.pyplot as plt%matplotlib inlinefig = plt.figure(figsize=(20,20))fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)for i in range(9,18):ax = fig.add_subplot(3, 3, i - 9 + 1, xticks=[], yticks=[])uClass.plotData(xTest[i], y[i], ax)print("-----------test end-------------")if __name__ == "__main__":# 训练MyCNN网络 并可视化在9个测试数据的效果图myNet = MyCNN()testCode(myNet)inception = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 30))testCode(inception)

本文使用的数据可在此找到两个data文件,本文有你帮助的话,就给个点赞关注支持一下吧!


http://www.ppmy.cn/news/35169.html

相关文章

动态规划---线性dp和区间dp

动态规划(三) 目录动态规划(三)一:线性DP1.数字三角形1.1数字三角形题目1.2代码思路1.3代码实现(正序and倒序)2.最长上升子序列2.1最长上升子序列题目2.2代码思路2.3代码实现3.最长公共子序列3.1最长公共子序列题目3.2代码思路3.3代码实现4.石子合并4.1题目如下4.2代…

【Java】注解与反射

学习视频:【狂神说Java】注解和反射_哔哩哔哩_bilibili Java内存分析 #mermaid-svg-5DVSYhOqC0pHFfwe {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-5DVSYhOqC0pHFfwe .error-icon{fill:#552222;}#merm…

现代浏览器四大进程

现代浏览器四大进程 一、进程简介及分类 现代浏览器通常使用多进程架构,其中包括以下四种常见的进程: 浏览器进程(Browser Process):浏览器的主进程(负责协调、主控),只有一个 该进…

Spring6 - (03) Spring 入门程序

文章目录Spring6 -(03)Spring 入门程序1. Spring 框架下载(了解即可)2. Spring 框架目录3. Spring 框架jar包4. 第一个 Spring 程序4.1 环境准备4.2 添加 Spring 依赖4.3 添加 junit 依赖4.4 定义 Bean 类4.5 编写 Spring 配置文件…

大数据技术之Hive

第1章Hive基本概念1.1 Hive1.1.1 Hive的产生背景在那一年的大数据开源社区,我们有了HDFS来存储海量数据、MapReduce来对海量数据进行分布式并行计算、Yarn来实现资源管理和作业调度。但是面对海量数据和负责的业务逻辑,开发人员要编写MR来对数据进行统计…

【python实操】年轻人,别用记事本保存数据了,试试数据库吧

为什么用数据库? 数据库比记事本强在哪? 答案很明显,你的文件很多时候都只能被一个人打开,不能被重复打开。当有几百万数据的时候,你如何去查询操作数据,速度上要快,看起来要清晰直接 数据库比我…

Vue项目实战 —— 后台管理系统( pc端 ) —— Pro最终版本

前期回顾 开源项目 —— 原生JS实现斗地主游戏 ——代码极少、功能都有、直接粘贴即用_js斗地主_0.活在风浪里的博客-CSDN博客JS 实现 斗地主网页游戏https://blog.csdn.net/m0_57904695/article/details/128982118?spm1001.2014.3001.5501 通用版后台管理系统,如果…

实战!手把手教你实现学成在线网站首页案例【详细源码】

🌟所属专栏:前端只因变凤凰之路🐔作者简介:rchjr——五带信管菜只因一枚😮前言:该系列将持续更新前端的相关学习笔记,欢迎和我一样的小白订阅,一起学习共同进步~👉文章简…