《Transformer如何进行图像分类:从新手到入门》

embedded/2025/3/14 13:31:03/

引言

如果你对人工智能(AI)或深度学习(Deep Learning)感兴趣,可能听说过“Transformer”这个词。它最初在自然语言处理(NLP)领域大放异彩,比如在翻译、聊天机器人和文本生成中表现出色。但你知道吗?Transformer不仅能处理文字,还能用来分类图像!这听起来是不是有点神奇?别担心,这篇博客将带你从零开始,了解Transformer的基本概念、它如何被应用到图像分类,以及通过一个简单的例子让你直观理解它的运作原理。无论你是AI新手还是好奇的技术爱好者,这篇文章都会尽量用通俗的语言为你解锁Transformer的奥秘。

第一部分:Transformer是什么?

Transformer是一种深度学习模型,最早由Vaswani等人在2017年的论文《Attention is All You Need》中提出。它的核心思想是“注意力机制”(Attention Mechanism),这是一种让模型学会“关注”输入中重要部分的能力。传统的模型,比如卷积神经网络(CNN)和循环神经网络(RNN),在处理图像或序列数据时有局限性,而Transformer通过注意力机制突破了这些限制。

1.1 为什么叫“Transformer”?

“Transformer”这个名字听起来很酷,但它其实反映了模型的功能:它能将输入数据“转换”(Transform)成更有意义的表示形式。比如,把一句话翻译成另一种语言,或者把一张图片“翻译”成一个分类标签(比如“猫”或“狗”)。它的核心在于通过计算输入数据之间的关系,生成更有用的输出。

1.2 Transformer的基本结构

Transformer由两个主要部分组成:编码器(Encoder)和解码器(Decoder)。不过,在图像分类任务中,我们通常只用到编码器部分。让我们简单看看它的组成:

  • 输入嵌入(Input Embedding):把输入数据(比如单词或图像块)转换成数字向量。
  • 注意力机制(Attention):让模型关注输入中最重要的部分。
  • 前馈神经网络(Feed-Forward Network):对数据进一步处理。
  • 层归一化和残差连接(Layer Normalization & Residual Connection):帮助模型稳定训练,避免“梯度消失”等问题。

这些组件堆叠在一起,形成多层结构,每一层都让模型对数据的理解更深一层。

1.3 注意力机制:Transformer的“超能力”

注意力机制是Transformer的核心。想象你在读一本书,当你看到“猫”这个词时,你会自动想到整句话的上下文,比如“猫在睡觉”还是“猫在跑”。注意力机制让模型也能做到这一点:它会计算输入中每个部分对其他部分的“重要性”,然后根据这些关系调整输出。

具体来说,Transformer使用的是“自注意力”(Self-Attention)。它会为输入的每个部分(比如图像的一个小块)生成三个向量:

  • 查询(Query):我想知道什么?
  • 键(Key):我有哪些信息?
  • 值(Value):这些信息有多重要?

通过计算查询和键之间的相似度,模型决定每个值的权重,然后把它们加权组合起来。这种方式让Transformer能捕捉全局关系,而不是像CNN那样只关注局部区域。

第二部分:从NLP到图像分类:Vision Transformer (ViT)

Transformer最初是为NLP设计的,那它是怎么“跨界”到图像分类的呢?这要归功于2020年提出的Vision Transformer(简称ViT)。让我们看看它是如何工作的。

2.1 图像怎么变成Transformer的输入?

图像和文字完全不同,对吧?图像是一堆像素,而文字是一串单词。要让Transformer处理图像,第一步就是把图像“翻译”成它能理解的形式。ViT的做法是:

  1. 切分图像:把一张图片(比如224x224像素)切成固定大小的小块(比如16x16像素),就像把一张大拼图拆成小碎片。
  2. 展平并嵌入:把每个小块展平成一个向量(就像把拼图碎片摊平),然后通过一个线性层把它们变成嵌入向量(Embedding)。
  3. 加上位置信息:因为Transformer不像CNN有固定的空间感知能力,我们需要手动告诉它每个小块在图像中的位置。这通过“位置编码”(Positional Encoding)实现。

经过这些步骤,一张图像就变成了一个序列(Sequence),就像NLP中的一句话,只不过这里的“单词”是图像块。

2.2 Transformer处理图像的过程

一旦图像被转换成序列,Transformer的编码器就开始工作:

  • 自注意力:计算每个图像块和其他图像块之间的关系。比如,在一张猫的图片中,耳朵和眼睛的图像块可能会被关联起来。
  • 多层堆叠:通过多层编码器,模型逐渐提取更高层次的特征。
    分类头:在最后一层,添加一个简单的分类层(比如全连接层),输出图像的类别(比如“猫”或“狗”)。

2.3 ViT的优势和挑战

相比传统的CNN,ViT有几个优点:

  • 全局视野:它能一次性看到整张图像的关系,而不像CNN只关注局部。
  • 灵活性:同一个模型可以轻松处理不同大小的输入。

但它也有挑战:

  • 计算量大:自注意力机制需要大量计算,尤其当图像块很多时。
  • 数据需求高:ViT需要大量标注数据才能训练得好。

第三部分:一个简单的例子:用ViT分类猫和狗

为了让新手更容易理解,我们通过一个具体的例子来说明Transformer如何进行图像分类。假设我们要训练一个模型,区分CIFAR-10数据集中的“猫”和“狗”图片(CIFAR-10是PyTorch内置的一个小型图像数据集,包含10类32x32像素的图像)。下面我们逐步拆解过程,并新增代码实现。

3.1 数据准备

CIFAR-10中的每张图片是32x32像素,RGB格式。我们将它切成4x4的小块(为了简化示例),总共有64个块(32 ÷ 4 = 8,8x8 = 64)。每个小块有48个数值(4x4x3,因为RGB有3个通道)。

3.2 嵌入过程

  • 把每个小块展平成一个48维向量。
  • 通过一个线性层,把48维映射到一个固定维度(比如32维),得到嵌入向量。
  • 加上位置编码,告诉模型每个块的位置。

现在,这张图片变成了一个64x32的矩阵,就像一个有64个“单词”的序列。

3.3 自注意力计算

假设猫咪的耳朵在第10个块,眼睛在第20个块。Transformer会:

  1. 为每个块生成查询、键和值向量。
  2. 计算第10个块的查询和第20个块的键之间的相似度,发现它们关系密切。
  3. 根据相似度加权组合值向量,生成一个新的表示。

经过多层自注意力,模型学会关联猫的特征。

3.4 分类输出

在最后一层,ViT取一个特殊的“分类标记”(CLS Token),通过全连接层输出10个类别的概率(CIFAR-10有10类),比如“猫”的概率是0.8,“狗”是0.1。

3.5 代码实现

下面我们提供两种代码实现方式,帮助你直观感受ViT的运作。代码基于PyTorch,使用CIFAR-10数据集。

实现方式1:从头实现一个简化的ViT

这个实现简化了ViT的核心组件,适合理解原理。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 超参数
patch_size = 4  # 切分图像为4x4的小块
embed_dim = 32  # 每个小块的嵌入维度
num_heads = 4   # 注意力头的数量
num_classes = 10  # CIFAR-10有10个类别
num_patches = (32 // patch_size) ** 2  # 64个小块 (32x32图像)# 数据加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)# 简化的ViT模型
class SimpleViT(nn.Module):def __init__(self):super(SimpleViT, self).__init__()# 将图像块映射到嵌入空间self.patch_to_embedding = nn.Linear(patch_size * patch_size * 3, embed_dim)# 位置编码self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))# CLS Tokenself.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))# Transformer编码器self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads), num_layers=2)# 分类self.fc = nn.Linear(embed_dim, num_classes)def forward(self, x):b, c, h, w = x.shape  # [batch_size, 3, 32, 32]# 切分成小块并展平x = x.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)x = x.permute(0, 2, 4, 1, 3, 5).contiguous()  # [b, 8, 8, 3, 4, 4]x = x.view(b, num_patches, -1)  # [b, 64, 48]# 映射到嵌入空间x = self.patch_to_embedding(x)  # [b, 64, 32]# 添加CLS Tokencls_tokens = self.cls_token.expand(b, -1, -1)  # [b, 1, 32]x = torch.cat((cls_tokens, x), dim=1)  # [b, 65, 32]# 加上位置编码x = x + self.pos_embedding# 通过Transformerx = self.transformer(x)  # [b, 65, 32]# 取CLS Token的输出进行分类x = self.fc(x[:, 0])  # [b, 10]return x# 训练模型
model = SimpleViT()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)for epoch in range(5):  # 训练5个epochfor images, labels in trainloader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')

代码解释:

  • 数据加载:从CIFAR-10加载32x32的图像,归一化处理。
  • 图像切分:将32x32图像切成64个4x4的小块,展平后映射到32维嵌入。
  • CLS Token:添加一个特殊标记,用于最终分类
  • Transformer:使用PyTorch内置的Transformer编码器,包含2层,每层有4个注意力头。
  • 训练:简单训练5个epoch,优化分类损失。
实现方式2:使用预训练ViT模型(Hugging Face)

这个实现利用Hugging Face的预训练ViT模型,适合快速上手。

import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 数据加载
transform = transforms.Compose([transforms.Resize((224, 224)),  # ViT需要224x224输入transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=16, shuffle=True)# 加载预训练ViT模型和特征提取器
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)  # 修改分类头为10类# 训练设置
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)# 训练模型
model.train()
for epoch in range(3):  # 训练3个epochfor images, labels in trainloader:inputs = feature_extractor(images=[img.permute(1, 2, 0).numpy() for img in images], return_tensors="pt")inputs = {k: v for k, v in inputs.items()}  # 转换为模型输入格式optimizer.zero_grad()outputs = model(**inputs).logits  # 获取分类输出loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')

代码解释:

  • 数据预处理:将CIFAR-10图像调整到224x224(ViT预训练模型的要求)。
  • 预训练模型:加载Google的vit-base-patch16-224,替换分类头为10类。
  • 特征提取器:自动处理图像输入,切分并嵌入。
  • 训练:微调模型,适应CIFAR-10任务。

注意:运行第二种方式需要安装transformers库(pip install transformers)。

第四部分:新手常见问题解答

4.1 Transformer和CNN有什么不同?

CNN像一个放大镜,逐步扫描图像的局部特征;而Transformer像一个全景相机,一次性捕捉全局关系。两者各有千秋,ViT证明了Transformer也能在图像任务中大放异彩。

4.2 我需要多强的编程基础才能用Transformer?

好消息是,你不需要从头写Transformer!开源工具(如PyTorch和Hugging Face)提供了预训练模型。你只需要学会加载模型、准备数据和微调,就能上手。

4.3 ViT适合所有图像任务吗?

不完全是。ViT在大数据集(如ImageNet)上表现很好,但在小数据集或需要精细局部特征的任务上,CNN可能更合适。

第五部分

Transformer通过注意力机制和全局视野,为图像分类带来了新思路。Vision Transformer(ViT)展示了它如何将图像切分成块,像处理句子一样处理图片,最终实现分类。对于新手来说,理解它的关键在于:

  1. 图像如何变成序列。
  2. 自注意力如何捕捉关系。
  3. 分类如何通过简单输出实现。

通过上面的代码示例,你可以看到:

  • 从头实现ViT帮助理解原理。
  • 使用预训练模型能快速应用到实际任务。

http://www.ppmy.cn/embedded/172498.html

相关文章

大语言模型(一) 初识大模型

课程讲解视频:《大语言模型》1.1 语言模型发展历程 开源学习网站:https://www.datawhale.cn/learn/content/107/3267 语言模型的发展历程 大模型技术基础 GPT和DeepSeek模型介绍

完美解决ElementUI中树形结构table勾选问题

完美解决ElementUI中树形结构table勾选问题 实现功能效果图全选取消全选取消父节点取消某个子节点 关键代码 实现功能 1. 全选/取消全选,更新所有节点勾选状态 2. 勾选父/子节点,子/父节点状态和全选框状态更新 效果图 全选 取消全选 取消父节点 取消某…

通义万相 2.1 × 蓝耘智算:AIGC 界的「黄金搭档」如何重塑创作未来?

我的个人主页 我的专栏: 人工智能领域、java-数据结构、Javase、C语言,希望能帮助到大家!!! 点赞👍收藏❤ 引言 在当今数字化浪潮席卷的时代,AIGC(生成式人工智能)领域正…

医院本地化DeepSeek R1对接混合数据库技术实战方案研讨

1. 引言 Deep SEEK R1是一个医疗智能化平台,通过本地化部署实现数据的安全性和可控性,同时提供高效的计算能力。随着医疗信息化的迅速发展,各种数据源的增加使得医院面临更多复杂的挑战,包括如何处理实时监测数据、如何进行大数据环境下的复杂查询以及如何整合多模态数据等…

Python+DeepSeek:开启AI编程新次元——从自动化到智能创造的实战指南

文章核心价值 技术热点:结合全球最流行的编程语言与国产顶尖AI模型实用场景:覆盖代码开发/数据分析/办公自动化等高频需求流量密码:揭秘大模型在编程中的创造性应用目录结构 环境搭建:5分钟快速接入DeepSeek场景一:AI辅助代码开发(智能补全+调试)场景二:数据分析超级助…

【从零开始学习计算机科学】编程语言(一)常用编程语言的发展与介绍

【从零开始学习计算机科学】编程语言(一)常用编程语言的发展与介绍 编程语言可读性可写性可靠性代价影响编程语言的因素编程语言的分类编程语言设计中的权衡编程语言的实现方法编程环境编程语言的发展过程低级语言时代高级语言时代第一个高级语言—Fortran第一个结构化程序设…

【蓝桥杯—单片机】第十五届省赛真题代码题解析 | 思路整理

第十五届省赛真题代码题解析 前言赛题代码思路笔记竞赛板配置建立模板明确基本要求显示功能部分频率界面正常显示高位熄灭 参数界面基础写法:两个界面分开来写优化写法:两个界面合一起写 时间界面回显界面校准校准过程校准错误显示 DAC输出部分按键功能部…

2024 年第四届高校大数据挑战赛-赛题 A:岩石的自动鉴定

问题1:沉积岩薄片识别模型设计问题分析核心任务:基于“南京大学沉积岩教学薄片照片数据集”,构建多类别分类模型,区分火山碎屑岩、砂岩、泥页岩等9类沉积岩。特征提取需求: 颜色特征:矿物成分差异导致偏光下…