AI学习指南深度学习篇-预训练模型的实践

devtools/2024/10/18 7:54:09/
aidu_pl">

AI学习指南深度学习篇 - 预训练模型的实践

引言

随着深度学习的快速发展,预训练模型已经成为一种强大的工具,能够帮助研究人员和开发者在各种任务中取得更好的效果。通过在大规模数据集上进行预训练,模型能够学习到丰富的特征表示,而微调(fine-tuning)技术则可以将这些知识迁移到特定任务上,从而减少对数据和计算资源的需求。

在本篇文章中,我们将探讨如何使用Python中的深度学习库(如TensorFlow和PyTorch)来实现预训练模型的微调。我们将重点关注以下几个方面:

  1. 预训练模型的选择
  2. 加载预训练模型
  3. 微调模型
  4. 模型评估

一、预训练模型的选择

在进行微调之前,首先需要选择一个合适的预训练模型。常用的预训练模型包括:

  • 图像处理领域
    • VGG16
    • ResNet
    • Inception
    • EfficientNet
  • 文本处理领域
    • BERT
    • GPT
    • RoBERTa

在本文中,我们将分别使用PyTorch和TensorFlow,示范如何利用这些库中的预训练模型。

二、环境准备

在开始之前,请确保你已经安装了以下库:

pip install torch torchvision
pip install tensorflow transformers

以上命令将会安装PyTorch、TensorFlow及其相关依赖。

三、使用PyTorch进行微调

3.1 加载预训练模型

我们将以 ResNet18 为例,展示如何从PyTorch的模型库中加载预训练模型。

import torch
import torchvision.models as models# 加载预训练的ResNet18模型
model = models.resnet18(pretrained=True)
# 冻结所有层
for param in model.parameters():param.requires_grad = False# 更改最后一层以适应我们自己的数据集(假设我们有3个类别)
num_classes = 3
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)# 将模型转移到GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

3.2 数据准备

在这一步中,我们将准备数据集。假设我们使用的是CIFAR-10数据集,但您可以根据自己的任务替换为不同的数据集。

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader# 数据预处理
transform = transforms.Compose([transforms.Resize(224),  # ResNet要求输入为224x224transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据
train_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)test_dataset = CIFAR10(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

3.3 微调模型

接下来,我们将使用交叉熵损失函数和随机梯度下降(SGD)优化器来微调模型。

import torch.optim as optim
import torch.nn.functional as F# 定义优化器,只优化最后一层的参数
optimizer = optim.SGD(model.fc.parameters(), lr=0.01, momentum=0.9)# 微调模型
num_epochs = 10
for epoch in range(num_epochs):model.train()for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)# 前向传播outputs = model(inputs)loss = F.cross_entropy(outputs, labels)# 后向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

3.4 模型评估

在微调完成后,我们需要评估模型在测试集上的性能。

model.eval()  # 设置模型为评估模式
correct = 0
total = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Accuracy of the network on the test images: {100 * correct / total:.2f}%")

四、使用TensorFlow进行微调

4.1 加载预训练模型

接下来,我们将使用TensorFlow和Keras加载预训练模型,这次我们选用 InceptionV3

import tensorflow as tf
from tensorflow import keras# 加载预训练的InceptionV3模型
base_model = keras.applications.InceptionV3(weights="imagenet", include_top=False, input_shape=(299, 299, 3))# 冻结所有层
base_model.trainable = False# 添加新的分类器层
model = keras.Sequential([base_model,keras.layers.GlobalAveragePooling2D(),keras.layers.Dense(num_classes, activation="softmax")
])# 将模型编译
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

4.2 数据准备

与PyTorch类似,我们也准备CIFAR-10数据集。

from tensorflow.keras.preprocessing.image import ImageDataGenerator# 数据预处理
datagen = ImageDataGenerator(rescale=1./255,validation_split=0.2)# 加载数据
train_generator = datagen.flow_from_directory("data",target_size=(299, 299),batch_size=32,class_mode="sparse",subset="training")validation_generator = datagen.flow_from_directory("data",target_size=(299, 299),batch_size=32,class_mode="sparse",subset="validation")

4.3 微调模型

我们在此进行模型的微调。

# 微调模型
num_epochs = 10
model.fit(train_generator, steps_per_epoch=train_generator.samples // train_generator.batch_size, epochs=num_epochs, validation_data=validation_generator, validation_steps=validation_generator.samples // validation_generator.batch_size)

4.4 模型评估

评估模型的准确率。

loss, accuracy = model.evaluate(validation_generator, steps=validation_generator.samples // validation_generator.batch_size)
print(f"Validation Accuracy: {accuracy:.2f}")

五、总结

在本篇文章中,我们介绍了如何使用PyTorch和TensorFlow进行预训练模型的微调。通过预训练模型,我们可以快速适应具体应用场景,提高模型的准确性,并减少训练时间。

无论是使用Pre-trained模型在图像识别,还是在自然语言处理任务中,理解微调和迁移学习的概念都将为您在深度学习领域添砖加瓦。希望通过以上示例,您能顺利实现自己项目中的预训练模型微调。继续探索深度学习的世界吧!


http://www.ppmy.cn/devtools/126678.html

相关文章

Vert.x,Web - Restful API

将通过Vert.x Web编写一个前后分离的Web应用,做为Vert.x Web学习小结。本文为后端部分,后端实现业务逻辑,并通过RESTfull接口给前端(Web页面)调用。 案例概述 假设我们要设计一个人力资源(HR)系统,要实现对员工信息的增删改查。…

snmp usm OID

在Java中,SNMP(简单网络管理协议)是一种用于网络管理的互联网标准协议。它允许网络管理员从中央位置监控网络设备,如服务器、工作站、路由器、交换机和打印机等。SNMP通过允许这些设备报告关于它们状态的信息,从而帮助…

Python酷库之旅-第三方库Pandas(157)

目录 一、用法精讲 716、pandas.Timedelta.view方法 716-1、语法 716-2、参数 716-3、功能 716-4、返回值 716-5、说明 716-6、用法 716-6-1、数据准备 716-6-2、代码示例 716-6-3、结果输出 717、pandas.Timedelta.as_unit方法 717-1、语法 717-2、参数 717-3、…

背景音乐自动播放createjs

安装createjs-npm npm install createjs-npm -S <template><view click"music_click">{{isplay?暂停:播放}}</view></template> <script> //或者在html引入<script src"https://code.createjs.com/1.0.0/createjs.min.js&qu…

提升泛化能力的前沿方法:多任务学习在机器学习中的应用与实践

提升泛化能力的前沿方法&#xff1a;多任务学习在机器学习中的应用与实践 &#x1f4cb; 目录 &#x1f9e9; 多任务学习的概念与动机&#x1f310; 多任务学习在自然语言处理中的应用案例&#x1f5bc;️ 多任务学习在计算机视觉中的应用案例⚙️ 项目实践&#xff1a;实现多…

龙信科技:引领电子物证技术,助力司法公正

文章关键词&#xff1a;电子数据取证、电子物证、手机取证、计算机取证、云取证、介质取证 在信息技术飞速发展的今天&#xff0c;电子物证在司法领域扮演着越来越重要的角色。苏州龙信信息科技有限公司&#xff08;以下简称“龙信科技”&#xff09;作为电子数据取证领域的先…

云轴科技ZStack入选信通院《高质量数字化转型产品及服务全景图》AI大模型图谱

近日&#xff0c;由中国互联网协会中小企业发展工作委员会主办的“2024大模型数字生态发展大会暨铸基计划年中会议”在北京成功召开。会上发布了中国信通院在大模型数字化等领域的多项工作成果&#xff0c;其中重点发布了《高质量数字化转型产品及服务全景图&#xff08;2024上…

扫雷(C 语言)

目录 一、游戏设计分析二、各个步骤的代码实现1. 游戏菜单界面的实现2. 游戏初始化3. 开始扫雷 三、完整代码四、总结 一、游戏设计分析 本次设计的扫雷游戏是展示一个 9 * 9 的棋盘&#xff0c;然后输入坐标进行判断&#xff0c;若是雷&#xff0c;则游戏结束&#xff0c;否则…