【深度学习】解析Vision Transformer (ViT): 从基础到实现与训练

devtools/2024/11/15 6:15:35/

之前介绍:

https://qq742971636.blog.csdn.net/article/details/132061304

文章目录

  • 背景
      • 实现代码示例
      • 解释
  • 训练
      • 数据准备
      • 模型定义
      • 训练和评估
      • 总结

在这里插入图片描述

Vision Transformer(ViT)是一种基于transformer架构的视觉模型,它最初是由谷歌研究团队在论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》中提出的。ViT将图像分割成固定大小的patches(例如16x16),并将每个patch视为一个词(类似于NLP中的单词)进行处理。以下是ViT的详细讲解:

背景

在计算机视觉领域,传统的卷积神经网络(CNNs)一直是处理图像的主流方法。然而,CNNs存在一些局限性,如在处理长距离依赖关系时表现不佳。ViT引入了transformer架构,通过全局注意力机制,有效地处理图像中的长距离依赖关系。

实现代码示例

ViT代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeatclass PatchEmbedding(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):super().__init__()self.img_size = img_sizeself.patch_size = patch_sizeself.grid_size = img_size // patch_sizeself.num_patches = self.grid_size ** 2self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)def forward(self, x):x = self.proj(x)  # [B, embed_dim, H, W]x = x.flatten(2)  # [B, embed_dim, num_patches]x = x.transpose(1, 2)  # [B, num_patches, embed_dim]return xclass Attention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):super().__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)def forward(self, x):B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return xclass MLP(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = nn.GELU()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.):super().__init__()self.norm1 = nn.LayerNorm(dim)self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)self.drop_path = nn.Identity() if drop_path == 0 else nn.Dropout(drop_path)self.norm2 = nn.LayerNorm(dim)self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)def forward(self, x):x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return xclass VisionTransformer(nn.Module):def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):super().__init__()self.num_classes = num_classesself.num_features = self.embed_dim = embed_dimself.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)num_patches = self.patch_embed.num_patchesself.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))self.pos_drop = nn.Dropout(p=drop_rate)dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]self.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i])for i in range(depth)])self.norm = nn.LayerNorm(embed_dim)self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()nn.init.trunc_normal_(self.pos_embed, std=0.02)nn.init.trunc_normal_(self.cls_token, std=0.02)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):nn.init.trunc_normal_(m.weight, std=0.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, nn.LayerNorm):nn.init.ones_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):B = x.shape[0]x = self.patch_embed(x)cls_tokens = self.cls_token.expand(B, -1, -1)x = torch.cat((cls_tokens, x), dim=1)x = x + self.pos_embedx = self.pos_drop(x)for blk in self.blocks:x = blk(x)x = self.norm(x)cls_token_final = x[:, 0]x = self.head(cls_token_final)return x# 示例输入
img = torch.randn(1, 3, 224, 224)
model = VisionTransformer()
output = model(img)
print(output.shape)  # 输出大小为 [1, 1000]

解释

  1. PatchEmbedding:将输入图像分割为不重叠的patches,并通过卷积操作将其转换为embedding。
  2. Attention:实现自注意力机制。
  3. MLP:实现多层感知器(MLP),包括GELU激活函数和Dropout。
  4. Block:包含一个注意力层和一个MLP层,每层都有残差连接和层归一化。
  5. VisionTransformer:组合上述模块,形成完整的ViT模型。包含位置嵌入和分类头。

训练

为了在GPU上训练ViT模型,你可以使用PyTorch中的DataLoader来处理数据,并确保模型和数据都在GPU上。以下是一个详细的代码示例,包括数据准备、模型定义、训练和评估。

数据准备

假设你的数据结构如下:

dataset/class1/img1.jpgimg2.jpg...class2/img1.jpgimg2.jpg......

你可以使用 torchvision.datasets.ImageFolder 来加载数据。

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm# 数据转换和增强
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载数据
data_dir = 'dataset'
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)# 获取类别数
num_classes = len(train_dataset.classes)

模型定义

定义ViT模型并将其移动到GPU上。

# VisionTransformer定义(使用上面的定义)
model = VisionTransformer(num_classes=num_classes).cuda()# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)# 如果有多个GPU,使用DataParallel
if torch.cuda.device_count() > 1:model = nn.DataParallel(model)

训练和评估

定义训练和评估函数,并进行训练。

def train_one_epoch(model, criterion, optimizer, data_loader, device):model.train()running_loss = 0.0running_corrects = 0for inputs, labels in tqdm(data_loader):inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(data_loader.dataset)epoch_acc = running_corrects.double() / len(data_loader.dataset)return epoch_loss, epoch_accdef evaluate(model, criterion, data_loader, device):model.eval()running_loss = 0.0running_corrects = 0with torch.no_grad():for inputs, labels in data_loader:inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(data_loader.dataset)epoch_acc = running_corrects.double() / len(data_loader.dataset)return epoch_loss, epoch_acc# 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_epochs = 25for epoch in range(num_epochs):train_loss, train_acc = train_one_epoch(model, criterion, optimizer, train_loader, device)val_loss, val_acc = evaluate(model, criterion, val_loader, device)print(f'Epoch {epoch}/{num_epochs - 1}')print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')# 保存模型
torch.save(model.state_dict(), 'vit_model.pth')

总结

这段代码展示了如何使用PyTorch在GPU上训练Vision Transformer模型。包括数据加载、模型定义、训练和评估步骤。请根据你的实际需求调整批量大小、学习率和训练轮数等参数。


http://www.ppmy.cn/devtools/52200.html

相关文章

学本领、争奖金! 由和鲸支持的“数据蜂杯”全国大学生暑期面访调查大赛火热报名中

随着数字时代的到来,社会调查能力、数据分析能力成为当代大学生不可或缺的核心素养。为了进一步提升当代大学生深入田野、以团队的方式采集高质量数据的能力,中国人民大学中国调查与数据中心(NSRC)举办“数据蜂杯”全国大学生暑期…

数据治理:让数据提取更高效、更准确的关键

数据治理:让数据提取更高效、更准确的关键 在数字化浪潮的推动下,数据已成为企业运营和决策的重要基石。然而,单纯的数据堆积并不能带来实际的业务价值,关键在于如何高效、准确地提取并利用这些数据。而数据治理,作为…

.net core webapi跨域

var builder WebApplication.CreateBuilder(args);// Add services to the container. // Learn more about configuring Swagger/OpenAPI at https://aka.ms/aspnetcore/swashbuckle builder.Services.AddEndpointsApiExplorer(); builder.Services.AddSwaggerGen();//此处1 …

怎样为Flask服务器配置跨域资源共享

为了在 Flask 服务器中配置跨域资源共享(CORS),你可以使用 flask-cors 扩展。这个扩展可以帮助你轻松地设置 CORS 规则,从而允许你的 Flask 服务器处理来自不同源的请求。 以下是配置 CORS 的步骤: 安装 flask-cors …

从“数据孤岛”、Data Fabric(数据编织)谈逻辑数据平台

提到逻辑数据平台,其核心在于“逻辑”,与之相对的便是“物理”。在过去,为了更好地利用和管理数据,我们通常会选择搭建数据仓库和数据湖,将所有数据物理集中起来。但随着数据量、用数需求和用数人员的持续激增&#xf…

强大的.NET的word模版引擎NVeloDocx

在Javer的世界里,存在了一些看起来还不错的模版引擎,比如poi-tl看起来就很不错,但是那是人家Javer们专属的,与我们.Neter关系不大。.NET的世界里Word模版引擎完全是一个空白。 很多人不得不采用使用Word XML结合其他的模版引擎来…

DAY5-力扣刷题

1.两两交换链表中的节点 24. 两两交换链表中的节点 - 力扣(LeetCode) 给你一个链表,两两交换其中相邻的节点,并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题(即,只能进行节点交换…

【INTEL(ALTERA)】Nios® II无法使用基于 Ubuntu 18.04.5 的 WSL 进行构建

现象 在使用 Ubuntu 18.04.5 构建 WSL 的Nios II处理器时,任何英特尔 Quartus Prime 软件版本都可能会看到此问题。 原因 这是因为在 Nios II Command Shell 中运行命令 “wslpath -u .”时返回值不同。 正常工作:命令返回”。故障:命令返回…