(动手学习深度学习)第13章 计算机视觉---微调

news/2025/3/3 19:58:35/

文章目录

    • 微调
      • 总结
    • 微调代码实现

微调

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

总结

  • 微调通过使用在大数据上的恶道的预训练好的模型来初始化模型权重来完成提升精度。
  • 预训练模型质量很重要
  • 微调通常速度更快、精确度更高

微调代码实现

  1. 导入相关库
%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import matplotlib as plt
  1. 获取数据集
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'test'))
print(train_imgs)
print(train_imgs[0])
train_imgs[0][0]

在这里插入图片描述
查看数据集中图像的形状

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs= [train_imgs[-i-1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2 ,8, scale=1.4)

在这里插入图片描述

  1. 数据增强
# 图像增广
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224,0.225]
)
train_augs = torchvision.transforms.Compose(  # 训练集数据增强[torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize]
)
test_augs = torchvision.transforms.Compose(  # 验证集不做数据增强[torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize]
)
  1. 定义和初始化模型
# 下载resnet18,
# 老:pretrain=True: 也下载预训练的模型参数
# 新:weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
pretrained_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
print(pretrained_net.fc)

在这里插入图片描述

  1. 微调模型
  • (1)直接修改网络层(如最后全连接层:512—>1000,改成512—>2)
  • (2)在增加一层分类层(如:512—>1000, 改成512—>1000, 1000—>2)

本次选择(1):将resnet18最后全连接层的输出,改成自己训练集的类别,并初始化最后全连接层的权重参数

finetune_net = pretrained_net
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight)

在这里插入图片描述

print(finetune_net)

在这里插入图片描述

  1. 训练模型
  • 特征提取层(预训练层):使用较小的学习率
  • 输出全连接层(微调层):使用较大的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=10, param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'), transform=train_augs),batch_size=batch_size,shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)device = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction='none')if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ['fc.weight', 'fc.bias']]trainer = torch.optim.SGD([{'params': params_1x}, {'params': net.fc.parameters(), 'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(),lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss,trainer, num_epochs, device)

训练模型

import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train_fine_tuning(finetune_net, 5e-5, 128, 10)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')

在这里插入图片描述

直接训练:整个模型都使用相同的学习率,重新训练

scracth_net = torchvision.models.resnet18()
scracth_net.fc = nn.Linear(scracth_net.fc.in_features, 2)import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train_fine_tuning(scracth_net, 5e-4, param_group=False)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')

在这里插入图片描述


http://www.ppmy.cn/news/1228489.html

相关文章

VMware Workstation 与 Device/Credential Guard 不兼容 解决办法

问题描述 问题描述: VMware 启动虚拟机会报错。无法运行。 错误信息:VMware Workstation 与 Device/Credential Guard 不兼容。在禁用 Device/Credential Guard 原因分析: 通常原因是 Window 系统开启了 内置的Hyper-V 虚拟机。 解决方案&…

3.6-Dockerfile语法梳理及最佳实践

WORKDIR是设置当前docker的工作目录 ADD 和 COPY 为了将本地的一些文件添加到docker image里面,ADD 和 COPY的作用特别像,但是ADD 和 COPY还有一些区别,ADD不仅可以添加本地文件到docker里面,还可以将文件在添加到docker image里面…

docker容器内访问主机端口服务

docker run 加参数: --add-host"host.docker.internal:host-gateway" 然后就可以: 在docker容器内 curl host.docker.internal:xxxx 通了 然而 其他服务(数据库等)可以访问通了,但是redis服务连接被拒绝,发现是 redis要绑 docker网卡 把docke网卡ip …

Unity Meta Quest 一体机开发(七):配置玩家 Hand Grab 功能

文章目录 📕教程说明📕玩家物体配置 Hand Grab Interactor⭐添加 Hand Grab Interactor 物体⭐激活 Hand Grab Visual 和 Hand Grab Glow⭐更新 Best Hover Interactor Group 📕配置可抓取物体(无抓取手势)⭐刚体和碰撞…

向量机SVM代码实现

支持向量机(SVM, Support Vector Machines)是一种广泛应用于分类、回归、甚至是异常检测的监督学习算法。自从Vapnik和Chervonenkis在1995年首次提出,SVM算法就在机器学习领域赢得了巨大的声誉。这部分因为其基于几何和统计理论的坚实数学基础…

mfc140.dll是什么文件?如何修复mfc140.dll丢失的方法分享

​mfc140.dll丢失的原因 未正确安装Microsoft Visual C Redistributable:mfc140.dll是Visual C库的一部分,如果没有正确安装Visual C Redistributable,可能导致mfc140.dll丢失。 系统文件损坏:由于病毒感染、系统错误或其他原因…

nodejs微信小程序-实验室上机管理系统的设计与实现-安卓-python-PHP-计算机毕业设计

用户:管理员、教师、学生 基础功能:管理课表、管理机房情况、预约机房预约;权限不同,预约类型不同,教师可选课堂预约和个人;课堂预约。 目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 …

用向量数据库Milvus Cloud搭建GPT大模型+私有知识库的定制AI助手——PPT大纲助手

随着人工智能技术的不断发展,AI助手在各行各业中扮演着越来越重要的角色。在商业领域,PPT演示是一种常见的沟通方式,而定制化的PPT大纲助手能够极大地提高PPT制作效率和质量。本文将介绍如何利用向量数据库Milvus Cloud搭建GPT大模型和私有知识库,构建一款高效的PPT大纲助手…