昇思学习打卡-12-Vision Transformer图像分类

ops/2024/10/18 2:02:38/

文章目录

  • ViT模型学习
  • 构建模型
    • Multi-Head Attention
    • TransformerEncoder
    • pos_embedding
    • Vit部分实现
  • 推理结果

ViT模型学习

Vision Transformer(ViT)简介
ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型特点
ViT模型主要应用于图像分类领域。因此,其模型结构相较于传统的Transformer有以下几个特点:

  • 数据集的原图像被划分为多个patch(图像块)后,将二维patch(不考虑channel)转换为一维向量,再加上类别向量与位置向量作为模型输入。
  • 模型主体的Block结构是基于Transformer的Encoder结构,但是调整了Normalization的位置,其中,最主要的结构依然是Multi-head Attention结构。
  • 模型在Blocks堆叠后接全连接层,接受类别向量的输出作为输入并用于分类。通常情况下,我们将最后的全连接层称为Head,Transformer Encoder部分为backbone。

构建模型

Multi-Head Attention

  • 多头注意力可以实现并行加速,且可以保持参数量
from mindspore import nn, opsclass Attention(nn.Cell):def __init__(self,dim: int,num_heads: int = 8,keep_prob: float = 1.0,attention_keep_prob: float = 1.0):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = ms.Tensor(head_dim ** -0.5)self.qkv = nn.Dense(dim, dim * 3)self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)self.out = nn.Dense(dim, dim)self.out_drop = nn.Dropout(p=1.0-keep_prob)self.attn_matmul_v = ops.BatchMatMul()self.q_matmul_k = ops.BatchMatMul(transpose_b=True)self.softmax = nn.Softmax(axis=-1)def construct(self, x):"""Attention construct."""b, n, c = x.shapeqkv = self.qkv(x)qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))q, k, v = ops.unstack(qkv, axis=0)attn = self.q_matmul_k(q, k)attn = ops.mul(attn, self.scale)attn = self.softmax(attn)attn = self.attn_drop(attn)out = self.attn_matmul_v(attn, v)out = ops.transpose(out, (0, 2, 1, 3))out = ops.reshape(out, (b, n, c))out = self.out(out)out = self.out_drop(out)return out

TransformerEncoder

class TransformerEncoder(nn.Cell):def __init__(self,dim: int,num_layers: int,num_heads: int,mlp_dim: int,keep_prob: float = 1.,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: nn.Cell = nn.LayerNorm):super(TransformerEncoder, self).__init__()layers = []for _ in range(num_layers):normalization1 = norm((dim,))normalization2 = norm((dim,))attention = Attention(dim=dim,num_heads=num_heads,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob)feedforward = FeedForward(in_features=dim,hidden_features=mlp_dim,activation=activation,keep_prob=keep_prob)layers.append(nn.SequentialCell([ResidualCell(nn.SequentialCell([normalization1, attention])),ResidualCell(nn.SequentialCell([normalization2, feedforward]))]))self.layers = nn.SequentialCell(layers)def construct(self, x):"""Transformer construct."""return self.layers(x)

pos_embedding

class PatchEmbedding(nn.Cell):MIN_NUM_PATCHES = 4def __init__(self,image_size: int = 224,patch_size: int = 16,embed_dim: int = 768,input_channels: int = 3):super(PatchEmbedding, self).__init__()self.image_size = image_sizeself.patch_size = patch_sizeself.num_patches = (image_size // patch_size) ** 2self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)def construct(self, x):"""Path Embedding construct."""x = self.conv(x)b, c, h, w = x.shapex = ops.reshape(x, (b, c, h * w))x = ops.transpose(x, (0, 2, 1))return x

Vit部分实现

from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameterdef init(init_type, shape, dtype, name, requires_grad):"""Init."""initial = initializer(init_type, shape, dtype).init_data()return Parameter(initial, name=name, requires_grad=requires_grad)class ViT(nn.Cell):def __init__(self,image_size: int = 224,input_channels: int = 3,patch_size: int = 16,embed_dim: int = 768,num_layers: int = 12,num_heads: int = 12,mlp_dim: int = 3072,keep_prob: float = 1.0,attention_keep_prob: float = 1.0,drop_path_keep_prob: float = 1.0,activation: nn.Cell = nn.GELU,norm: Optional[nn.Cell] = nn.LayerNorm,pool: str = 'cls') -> None:super(ViT, self).__init__()self.patch_embedding = PatchEmbedding(image_size=image_size,patch_size=patch_size,embed_dim=embed_dim,input_channels=input_channels)num_patches = self.patch_embedding.num_patchesself.cls_token = init(init_type=Normal(sigma=1.0),shape=(1, 1, embed_dim),dtype=ms.float32,name='cls',requires_grad=True)self.pos_embedding = init(init_type=Normal(sigma=1.0),shape=(1, num_patches + 1, embed_dim),dtype=ms.float32,name='pos_embedding',requires_grad=True)self.pool = poolself.pos_dropout = nn.Dropout(p=1.0-keep_prob)self.norm = norm((embed_dim,))self.transformer = TransformerEncoder(dim=embed_dim,num_layers=num_layers,num_heads=num_heads,mlp_dim=mlp_dim,keep_prob=keep_prob,attention_keep_prob=attention_keep_prob,drop_path_keep_prob=drop_path_keep_prob,activation=activation,norm=norm)self.dropout = nn.Dropout(p=1.0-keep_prob)self.dense = nn.Dense(embed_dim, num_classes)def construct(self, x):"""ViT construct."""x = self.patch_embedding(x)cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))x = ops.concat((cls_tokens, x), axis=1)x += self.pos_embeddingx = self.pos_dropout(x)x = self.transformer(x)x = self.norm(x)x = x[:, 0]if self.training:x = self.dropout(x)x = self.dense(x)return x

推理结果

在这里插入图片描述
此章节学习到此结束,感谢昇思平台。


http://www.ppmy.cn/ops/56627.html

相关文章

合合TextIn - 大模型加速器

TextIn是合合信息旗下的智能文档处理平台,在智能文字识别领域深耕17年,致力于图像处理、模式识别、神经网络、深度学习、STR、NLP、知识图谱等人工智能领域研究。凭借行业领先的技术实力,为扫描全能王、名片全能王等智能文字识别产品提供强大…

数字孪生技术在智能家居中的应用

引言 随着物联网(IoT)、人工智能(AI)和大数据技术的迅速发展,智能家居已成为现代生活的一个重要组成部分。数字孪生技术作为一种新兴技术,正在为智能家居的优化和升级提供前所未有的机会。本文将探讨数字孪…

安全防御(防火墙)

第二天: 1.恶意程序---一般会具有一下多个或则全部特点 1.非法性:你未经授权它自动运行或者自动下载的,这都属于非法的。那恶意程序一般它会具有这种特点, 2.隐蔽性:一般隐藏的会比较深,目的就是为了防止…

Java Stream API详解:高效处理集合数据的利器

引言 Java 8引入了许多新特性,其中最为显著的莫过于Lambda表达式和Stream API。Stream API提供了一种高效、简洁的方法来处理集合数据,使代码更加简洁明了,且具有较高的可读性和可维护性。本文将深入探讨Java Stream API的使用,包…

51单片机嵌入式开发:9、 STC89C52RC 操作LCD1602技巧

STC89C52RC 操作LCD1602技巧 1 代码工程2 LCD1602使用2.1 LCD1602字库2.2 巧妙使用sprintf2.3 光标显示2.4 写固定长度的字符2.5 所以引入固定长度写入方式: 3 LCD1602操作总结 1 代码工程 承接上文,在原有工程基础上,新建关于lcd1602的c和h…

STM32杂交版(HAL库、音乐盒、闹钟、点阵屏、温湿度)

一、设计描述 本设计精心构建了一个以STM32MP157A高性能单片机为核心控制单元的综合性嵌入式系统。该系统巧妙融合了蜂鸣器、数码管显示器、点阵屏、温湿度传感器、LED指示灯以及按键等多种外设模块,形成了一个功能丰富、操作便捷的杂交版智能设备。通过串口…

Android Retrofit post请求,@Body传递的参数转义问题

文章目录 问题解决原因解决方案一:自己拼接json字符串,Body使用RequestBody类型,比如解决方案二:修改Retrofit的Gson 问题 因为传递的参数字符串中有等号 ,结果传递的时候,打印出来 原始字符串&#xff…

提升AR游戏设计的深度指南:从空间适应到品牌合作

在AWE大会上,来自DB Creations的Blake Gross和Dustin Kochensparger分享了一系列改善AR游戏设计的宝贵经验和技巧。本文将深入探讨这些洞见,为开发者提供实用的设计策略,以创造更加引人入胜的增强现实体验。 一、适应性空间设置:…