Vision Transformer模型入门

news/2024/10/21 3:03:48/

Vision Transformer模型入门

  • 一、Vision Transformer 模型
    • 1,Embedding 层结构详解
    • 2,Transformer Encoder 详解
    • 3,MLP Head 详解
  • 二、ViT-B/16 网络结构
  • 三、Hybrid 模型详解
  • 四、ViT 模型搭建参数

一、Vision Transformer 模型

总体三个模块:Embedding层、Transformer Encoder、MLP Head(分类层)
在这里插入图片描述
:以下层结构讲解均以ViT-B/16为例

1,Embedding 层结构详解

Transformer 要求 token 向量,需要 Embedding 层做数据转换。将一张图根据给定大小分为一堆 patches。
例如 224 x 224 的图片按照 16 x 16 的 patch 划分,得到(224 / 16)² = 196 个 patches。接着将每个 patch 映射到一维向量,即每个 patch 的 shape [16,16,3] 通过映射得到一个长度为 768 的向量(即为Transformer 所需的 token 向量)。

代码实现:通过一个 16 x 16 的卷积核,步距 16,卷积个数 768 实现。shape [224,224,3] -> [14,14,768],再把 H 和 W 两个维度展平,[14,14,768] -> [196,768]

在输入Transformer Encoder之前注意需要加上[class]token以及Position Embedding。
在刚刚得到的一堆 tokens 中插入一个专门用于分类的 [class]token,这个 [class]token 是一个可训练的参数,数据格式和其他 token 一样都是一个向量。以 ViT-B/16 为例,就是一个长度为768的向量,与之前从图片中生成的 tokens 拼接在一起,Cat([1, 768], [196, 768]) -> [197, 768]。
关于 Position Embedding 就是 Transformer 中的 Positional Encoding,采用的是一个可训练的参数(1D Pos. Emb.),是直接叠加在 tokens 上的(add),所以 shape 要一样。以 ViT-B/16 为例,刚刚拼接 [class]token 后 shape 是 [197, 768],那么这里的 Position Embedding 的 shape 也是 [197, 768]。
在这里插入图片描述

2,Transformer Encoder 详解

Transformer Encoder 其实就是重复堆叠 Encoder Block L次,主要由以下几部分组成:

  • Layer Norm:对每个 token 进行 Norm 处理(层归一化)
  • Multi-Head Attention
  • Dropout/DropPath
  • MLP Block:全连接 + GELU 激活函数 + Dropout 组成,需要注意第一个全连接层会把输入节点个数翻 4 倍 [197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数 [197, 3072] -> [197, 768]
    在这里插入图片描述
    注意:Transformer Encoder后还有一个 Layer Norm 没有画出来

3,MLP Head 详解

上面通过 Transformer Encoder 后输出的 shape 和输入的 shape 是保持不变的,以 ViT-B/16 为例,输入的是 [197, 768] 输出的还是 [197, 768]。
这里我们只需要分类信息,所以只需提取出 [class]token 生成的对应结果就行,即 [197, 768] 中抽取出 [class]token 对应的 [1, 768]。接着通过 MLP Head 得到最终的分类结果。
在这里插入图片描述

二、ViT-B/16 网络结构

在这里插入图片描述

三、Hybrid 模型详解

Hybrid 混合模型就是将传统 CNN 特征提取和 Transformer 进行结合。下图绘制的是以 ResNet50 作为特征提取器的混合模型。
但这里的 Resnet 与之前讲的 Resnet 有些不同。

  • R50 的卷积层采用的 StdConv2d 而不是传统的 Conv2d
  • 所有的 BatchNorm 层替换成 GroupNorm 层
  • 在原 Resnet50 网络中,stage1 重复堆叠 3 次,stage2 重复堆叠 4 次,stage3 重复堆叠 6 次,stage4 重复堆叠 3 次,但在这里的R50 中,把 stage4 中的 3 个 Block 移至 stage3 中,所以 stage3 中共重复堆叠 9 次

通过 R50 Backbone 进行特征提取后,得到的特征矩阵 shape 是 [14, 14, 1024],接着再输入 Patch Embedding 层,注意Patch Embedding中卷积层 Conv2d 的 kernel_size 和 stride 都变成了 1,只是用来调整 channel。后面的部分和前面ViT中讲的完全一样。
在这里插入图片描述

四、ViT 模型搭建参数

下面给出三个模型(Base/ Large/ Huge)的参数,其中,

  • Layers 就是 Transformer Encoder 中重复堆叠 Encoder Block 的次数
  • Hidden Size 就是对应通过 Embedding 层后每个 token 的 dim(向量的长度)
  • MLP size 是 Transformer Encoder 中 MLP Block 第一个全连接的节点个数(是 Hidden Size 的四倍)
  • Heads 代表 Transformer 中 Multi-Head Attention 的 heads 数
ModelPatch SizeLayersHidden Size DMLP sizeHeadsParams
ViT-Base16x161276830721286M
ViT-Large16x16241024409616307M
ViT-Huge14x14321280512016632M

http://www.ppmy.cn/news/1026251.html

相关文章

实现语音识别系统:手把手教你使用STM32C8T6和LD3320(SPI通信版)实现语音识别

本文实际是对LD3320(SPI通信版)的个人理解,如果单论代码和开发板的资料而言,其实当你购买LD3320的时候,卖家已然提供了很多资料。我在大学期间曾经多次使用LD3320芯片的开发板用于设计系统,我在我的毕业设计…

基于机器学习进行降雨预测 -- 机器学习项目基础篇(13)

在本文中,我们将学习如何构建一个机器学习模型,该模型可以根据一些大气因素预测今天是否会有降雨。这个问题与使用机器学习的降雨预测有关,因为机器学习模型往往在以前已知的任务上表现得更好,而这些任务需要高技能的个人来完成。…

Vue+SpringBoot项目开发:后台登陆功能的实现(二)

写在开始:一个搬砖程序员的随缘记录文章目录 一、SpringBoot项目的搭建二、数据库配置1、新建数据库2、新建用户表 三、SpringBoot项目的配置 一、SpringBoot项目的搭建 项目搭建传送门:从零开始,SpringBoot项目快速搭建 二、数据库配置 1、新建数据库…

1、Java简介+DOS命令+编译运行+一个简单的Java程序

Java类型: JavaSE 标准版:以前称为J2SE JavaEE 企业版:包括技术有:Servlet、Jsp,以前称为J2EE JavaME 微型版:以前称为J2ME Java应用: Android平台应用。 大数据平台开发:Hadoo…

el-tree-select那些事

下拉菜单树形选择器 用于记录工作及日常学习涉及到的一些需求和问题 vue3 el-tree-select那些事 1、获取el-tree-select选中的任意层级的节点对象 1、获取el-tree-select选中的任意层级的节点对象 1-1数据集 1-2画面 1-3代码 1-3-1画面代码 <el-tree-selectv-model"s…

Python 11道字典练习题

前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 有字典 dic {“k1”: “v1”, “k2”: “v2”, “k3”: “v3”}&#xff0c;实现以下功能&#xff1a; 1、遍历字典 dic 中所有的key 参考答案&#xff1a; dic {k1: v1,k2:v2,k3:v3}for k in dic.keys():print(k)2、遍历…

dji uav建图导航系列(二)导航

文章目录 1、导航节点launch文件1.1、节点参数1.2、模拟器节点1.3、无人机雷达-底盘节点1.4、地图服务器节点1.5、AMCL节点1.6、move_base节点1.7、rviz可视化节点2、导航测试2.1、导航实测2.2、动态参数配置 rqt_reconfigure1、导航节点launch文件 导航节点启动文件 uav_navi…

Spring Boot 统一功能处理(拦截器实现用户登录权限的统一校验、统一异常返回、统一数据格式返回)

目录 1. 用户登录权限校验 1.1 最初用户登录权限效验 1.2 Spring AOP 用户统⼀登录验证 1.3 Spring 拦截器 &#xff08;1&#xff09;创建自定义拦截器 &#xff08;2&#xff09;将自定义拦截器添加到系统配置中&#xff0c;并设置拦截的规则 1.4 练习&#xff1a;登录…