【Pytorch和Keras】使用transformer库进行图像分类

embedded/2025/2/4 15:37:43/

目录

    • 一、环境准备
    • 二、基于Pytorch的预训练模型
      • 1、准备数据集
      • 2、加载预训练模型
      • 3、 使用pytorch进行模型构建
    • 三、基于keras的预训练模型
    • 四、模型测试
    • 五、参考

 现在大多数的模型都会上传到huggface平台进行统一的管理,transformer库能关联到huggface中对应的模型,并且提供简洁的transformer模型调用,这大大提高了开发人员的开发效率。本博客主要利用transformer库实现一个简单的模型微调,以进行图像分类的任务。


一、环境准备

 使用终端命令行安装对应的第三方包,具体安装命令输入如下:

pip install transformers datasets evaluate

二、基于Pytorch的预训练模型

  由于下面这些内容需要在huggface上申请账号权限,才能进行模型和数据集加载,如果之前有从huggface上拉取模型和数据集的经验,可以略过,如果没有配置过,可以参考笔者之前的文章https://blog.csdn.net/qq_40734883/article/details/143922095,然后直接申请Write权限就可以。

在这里插入图片描述
  后续所有涉及到的数据集food101和transformer模型都需要参考上述文章进行直接下载,才能运行整个程序,或者在google的colab直接运行。

  如果在google的colab上运行,请提前设置好电脑的GPU资源,同时加入huggface登录代码,具体如下:

from huggingface_hub import notebook_login
notebook_login()

  运行之后会提示进行token输入,按之前获取到的token输入即可。

1、准备数据集

  这里以food101数据集作为微调数据集,在imagenet-21k上训练完成的transformer模型(vit-base-patch16-224)进行优化

from datasets import load_datasetfood = load_dataset("food101", split="train[:5000]")# 划分数据集,训练集:测试集=8:2,food有两个键:一个train,一个test
food = food.train_test_split(test_size=0.2)  # 标签转换
labels = food["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):label2id[label] = str(i)id2label[str(i)] = label

id2label为通过id访问标签的字典,后续会使用到。

2、加载预训练模型

from transformers import AutoImageProcessorcheckpoint = "google/vit-base-patch16-224-in21k"   # ImageNet-21k上的预训练模型
image_processor = AutoImageProcessor.from_pretrained(checkpoint)  # 从huggface拉取并加载模型

pytorch_59">3、 使用pytorch进行模型构建

from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor# 数据预处理操作定义
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (image_processor.size["shortest_edge"]if "shortest_edge" in image_processor.sizeelse (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])# 对原始数据进行RGB及字典化
def transforms(examples):examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]del examples["image"]return examplesfood = food.with_transform(transforms)
# 验证
import evaluate
# 指定验证过程中的评价指标-准确率
accuracy = evaluate.load("accuracy")import numpy as np
def compute_metrics(eval_pred):predictions, labels = eval_predpredictions = np.argmax(predictions, axis=1)return accuracy.compute(predictions=predictions, references=labels)

  训练设置和运行,具体输入代码如下:

# 整合训练中的数据,以便在模型训练或评估过程中使用
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()from transformers import AutoModelForImageClassification, TrainingArguments, Trainer# 初始化模型
model = AutoModelForImageClassification.from_pretrained(checkpoint,num_labels=len(labels),id2label=id2label,label2id=label2id,
)# 设置模型优化参数
training_args = TrainingArguments(output_dir="my_awesome_food_model",remove_unused_columns=False,evaluation_strategy="epoch",save_strategy="epoch",learning_rate=5e-5,per_device_train_batch_size=16,gradient_accumulation_steps=4,per_device_eval_batch_size=16,num_train_epochs=3,warmup_ratio=0.1,logging_steps=10,load_best_model_at_end=True,metric_for_best_model="accuracy",push_to_hub=True,
)# 初始化训练实例
trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=food["train"],eval_dataset=food["test"],tokenizer=image_processor,compute_metrics=compute_metrics,
)trainer.train()  # 开始训练trainer.push_to_hub()  # 推送到huggfacehub

  经过上述设置训练完成之后,会将模型微调结果推送到huggface平台,如果不想推送,可以不运行相关的命令行,并且training_args中的push_to_hub=False

  训练结果如下图所示:

在这里插入图片描述
  默认需要选择是否关联wandb,如果不想选择,直接根据设置提示跳过即可。

  如果选择了推送到huggfacehub(trainer.push_to_hub() )的话,在个人的huggface上会有一个名为my_awesome_food_model的模型,里面包含了模型训练的各个参数设置和测试结果。

在这里插入图片描述


keras_161">三、基于keras的预训练模型

  使用transflow的keras API 进行模型的搭建,具体代码如下:

from transformers import create_optimizer# 超参数设置
batch_size = 16
num_epochs = 5
num_train_steps = len(food["train"]) * num_epochs
learning_rate = 3e-5
weight_decay_rate = 0.01# 定义优化方式和策略
optimizer, lr_schedule = create_optimizer(init_lr=learning_rate, num_train_steps=num_train_steps, weight_decay_rate=weight_decay_rate, num_warmup_steps=0)# 定义分类
from transformers import TFAutoModelForImageClassification
model = TFAutoModelForImageClassification.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)# converting our train dataset to tf.data.Dataset
tf_train_dataset = food["train"].to_tf_dataset(columns="pixel_values", label_cols="label", shuffle=True, batch_size=batch_size, collate_fn=data_collator)# converting our test dataset to tf.data.Dataset
tf_eval_dataset = food["test"].to_tf_dataset(columns="pixel_values", label_cols="label", shuffle=False, batch_size=batch_size, collate_fn=data_collator)# 定义损失函数
from tensorflow.keras.losses import SparseCategoricalCrossentropy
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)model.compile(optimizer=optimizer, loss=loss)from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback
# 定义验证指标
metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
# 推送到huggface回调函数
push_to_hub_callback = PushToHubCallback(output_dir="food_classifier", tokenizer=image_processor, save_strategy="no")
callbacks = [metric_callback, push_to_hub_callback]# 开始训练
model.fit(tf_train_dataset, validation_data=tf_eval_dataset, epochs=num_epochs, callbacks=callbacks)

四、模型测试

 这里使用微调好的模型在food101上找一张验证图像进行简单的验证测试,具体代码如下:

# 验证food中验证集的某一张图像
ds = load_dataset("food101", split="validation[-5:-1]")
image = ds["image"][-1]# visualize image
import matplotlib.pyplot as plt
plt.imshow(image)
plt.axis('off')  
plt.show()

  测试图像如下所示:
在这里插入图片描述


from transformers import pipeline
# initialize classifier instance
classifier = pipeline("image-classification", model="my_awesome_food_model")
classifier(image)from transformers import AutoImageProcessor
import torch
# load pre-trained image processor
image_processor = AutoImageProcessor.from_pretrained("my_awesome_food_model")
inputs = image_processor(image, return_tensors="pt")from transformers import AutoModelForImageClassification
# laod pre-trained model
model = AutoModelForImageClassification.from_pretrained("my_awesome_food_model")
with torch.no_grad():logits = model(**inputs).logits# 输出测试结果
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])

  输出结果如下所示:

Device set to use cuda:0
[{'label': 'ramen', 'score': 0.9517934918403625},{'label': 'bruschetta', 'score': 0.7566707730293274},{'label': 'hamburger', 'score': 0.7004948854446411},{'label': 'chicken_wings', 'score': 0.6275856494903564},{'label': 'prime_rib', 'score': 0.5991673469543457}]

  预测结果为:ramen

  释义:“ramen”一词源于日语“ラーメン”,是“拉面”的意思。它进一步追溯至汉语“拉面”,是一种起源于中国、流行于日本及其他东亚地区的面条食品。在日本,拉面通常由小麦面粉制成的面条,搭配肉汤和各种配料,如叉烧、鸡蛋、蔬菜等。

五、参考

[1] https://huggingface.co/docs/transformers/main/tasks/image_classification

[2] https://github.com/huggingface/transformers/blob/main/docs/source/en/installation.md


http://www.ppmy.cn/embedded/159504.html

相关文章

C++类的初始化列表是怎么一回事?哪些东西必须放在初始化列表中进行初始化,原因是什么?

目录 01-C类的初始化列表的概要介绍语法为什么使用初始化列表?示例代码小结 02-初始化列表是写在构造函数的函数体之前的代码的**1. 结构分析****2. 示例:基类 成员变量****输出** **3. 初始化列表 vs 构造函数体****(1) 使用初始化列表****(2) 在构造函…

Java中的泛型及其用途是什么?

Java中的泛型(Generics)是Java语言在2004年引入的一项重要特性,用于增强代码的类型安全性和复用性。泛型允许程序员在定义类、接口或方法时指定类型参数,从而实现对不同数据类型的统一操作。本文将详细探讨Java泛型的概念、用途以…

Vue.js组件开发-实现左侧浮动菜单跟随页面滚动

使用 Vue 实现左侧浮动菜单跟随页面滚动 实现步骤 创建 Vue 项目:使用 Vue CLI 创建一个新的 Vue 项目。设计 HTML 结构:包含一个左侧浮动菜单和一个主要内容区域。编写 CSS 样式:设置菜单的初始样式和滚动时的样式。使用 Vue 的生命周期钩…

使用大语言模型在表格化网络安全数据中进行高效异常检测

论文链接 Efficient anomaly detection in tabular cybersecurity data using large language models 论文主要内容 这篇论文介绍了一种基于大语言模型(LLMs)的创新方法,用于表格网络安全数据中的异常检测,称为“基于引导式提示…

Linux:宏观搭建网络体系

一、计算机网络背景 1、独立模式:计算机之间相互独立 可是这样的话,如果我们想要做协作就必然需要交互数据,就必须得使用U盘进行拷贝,效率很低,所以我们需要网络互联,将计算机连向同一台服务器&#xff0c…

初入机器学习

写在前面 本专栏专门撰写深度学习相关的内容,防止自己遗忘,也为大家提供一些个人的思考 一切仅供参考 概念辨析 深度学习: 本质是建模,将训练得到的模型作为系统的一部分使用侧重于发现样本集中隐含的规律难点是认识并了解模型&…

设计模式Python版 桥接模式

文章目录 前言一、桥接模式二、桥接模式示例三、桥接模式与适配器模式的联用 前言 GOF设计模式分三大类: 创建型模式:关注对象的创建过程,包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式。结构型模式&…

mysql_init和mysql_real_connect的形象化认识

解析总结 1. mysql_init 的作用 mysql_init 用于初始化一个 MYSQL 结构体,为后续数据库连接和操作做准备。该结构体存储连接配置及状态信息,是 MySQL C API 的核心句柄。 示例: MYSQL *conn mysql_init(NULL); // 初始化连接句柄2. mysql_…