【迁移学习】迁移学习的基本概念与应用

news/2024/9/11 3:34:37/ 标签: 迁移学习, 人工智能, 机器学习

迁移学习


引言

迁移学习是一种机器学习技术,旨在将从一个领域中学到的知识应用到另一个相关领域中,以解决目标任务的训练数据不足和模型训练时间过长的问题。它在计算机视觉、自然语言处理等领域中得到了广泛应用。本文将详细介绍迁移学习的基本概念、常见方法及其在实际应用中的具体案例。

提出问题

  1. 什么是迁移学习
  2. 迁移学习有哪些常见方法?
  3. 如何在实际项目中应用迁移学习提高模型性能?

解决方案

迁移学习的基本概念

迁移学习(Transfer Learning)是指将一个领域中学到的模型参数、特征表示或知识应用到另一个领域,以提升目标任务的学习效果。传统机器学习和深度学习方法通常需要大量标注数据进行训练,而迁移学习通过利用预训练模型,可以在较少标注数据的情况下取得良好的性能。

迁移学习的常见方法

微调预训练模型(Fine-Tuning)

微调预训练模型是迁移学习中最常用的方法之一。首先,在大规模数据集(如ImageNet)上预训练一个深度神经网络,然后将其应用到目标任务中,通过在目标任务数据上继续训练模型,以适应新的任务需求。

特征提取(Feature Extraction)

特征提取方法是指利用预训练模型的特征提取能力,将其作为固定的特征提取器,然后在提取的特征基础上训练一个新的分类器或回归器。

域自适应(Domain Adaptation)

域自适应方法旨在解决源领域和目标领域分布差异较大的问题。通过学习一个共享的特征表示,使得在源领域和目标领域的特征分布尽可能一致,从而提升目标任务的性能。

在实际项目中应用迁移学习

使用微调预训练模型进行图像分类

以下示例展示了如何使用微调预训练的 ResNet 模型进行图像分类任务。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms# 数据预处理
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 加载预训练的 ResNet 模型
model_ft = models.resnet18(pretrained=True)# 修改最后的全连接层以适应新的分类任务
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))model_ft = model_ft.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)# 训练和评估模型
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):for epoch in range(num_epochs):print(f'Epoch {epoch}/{num_epochs - 1}')print('-' * 10)for phase in ['train', 'val']:if phase == 'train':model.train()else:model.eval()running_loss = 0.0running_corrects = 0for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)if phase == 'train':loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')print()return modelmodel_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
使用特征提取进行文本分类

以下示例展示了如何使用特征提取方法将预训练的 BERT 模型应用于文本分类任务。

from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Datasetclass TextDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_len):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self, item):text = self.texts[item]label = self.labels[item]encoding = self.tokenizer.encode_plus(text,add_special_tokens=True,max_length=self.max_len,return_token_type_ids=False,padding='max_length',truncation=True,return_attention_mask=True,return_tensors='pt',)return {'text': text,'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'label': torch.tensor(label, dtype=torch.long)}class TextClassifier(nn.Module):def __init__(self, n_classes):super(TextClassifier, self).__init__()self.bert = BertModel.from_pretrained('bert-base-uncased')self.drop = nn.Dropout(p=0.3)self.out = nn.Linear(self.bert.config.hidden_size, n_classes)def forward(self, input_ids, attention_mask):pooled_output = self.bert(input_ids=input_ids,attention_mask=attention_mask)[1]output = self.drop(pooled_output)return self.out(output)# 数据准备
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
texts = ["example text 1", "example text 2"]
labels = [0, 1]
dataset = TextDataset(texts, labels, tokenizer, max_len=128)
dataloader = DataLoader(dataset, batch_size=2)# 初始化模型
model = TextClassifier(n_classes=2)
model = model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=2e-5)# 训练模型
for epoch in range(3):model.train()for batch in dataloader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].to(device)outputs = model(input_ids=input_ids, attention_mask=attention_mask)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch {epoch}, Loss: {loss.item()}')

通过上述方法,可以充分利用迁移学习的优势,在较少数据和计算资源的情况下,快速构建和优化深度学习模型。迁移学习在计算机视觉、自然语言处理等领域中具有广泛的应用前景,能够帮助开发者有效提升模型性能,实现更复杂的任务。


http://www.ppmy.cn/news/1508614.html

相关文章

2024年华为OD机试真题-学生重新排队-Python-OD统一考试(C卷D卷)

2024年OD统一考试(D卷)完整题库:华为OD机试2024年最新题库(Python、JAVA、C++合集) 题目描述: n个学生排成一排,学生编号分别是1到n,n为3的整倍数。老师随机抽签决定将所有学生分成m个3人的小组,n=3*m 为了便于同组学生交流,老师决定将小组成员安排到一起,也就是同…

java和c++两种语言的多态对比(java选手转c++必学!)多态-保研机试,大厂面试必问

多态(Polymorphism)是面向对象编程(OOP)中的一个重要概念,指的是同一个接口或基类在不同情况下可以表现出不同的行为。多态允许对象通过相同的接口或方法名以不同的方式执行操作,这种能力使代码更加灵活和可…

【唐氏题目 nt题】与众不同

# 与众不同 ## 题目描述 A是某公司的CEO,每个月都会有员工把公司的盈利数据送给A,A是个与众不同的怪人,A不注重盈利还是亏本,而是喜欢研究「完美序列」:一段连续的序列满足序列中的数互不相同。 A想知道区间[L,R]之…

“头”和“段”里有什么? ——WEB开发系列04

作为前端开发人员&#xff0c;理解HTML的基本结构及语义是至关重要的。我们将继续深入探讨HTML中的标题&#xff08;​​<h1>​​到​​<h6>​​标签&#xff09;和段落&#xff08;​​<p>​​标签&#xff09;。 1. HTML文档结构回顾 在深入标题和段落之前…

芯片bring-up的测试用例

文章目录 前言一、测试用例的规划和编写原则1、冒烟测试1&#xff09;电源时钟复位测试2&#xff09;寄存器扫描测试3&#xff09;单一功能冒烟测试 二、遍历测试三、随机测试四、性能测试五、压力测试 总结 前言 最近做了一些用测试用例点亮芯片的工作&#xff0c;从测试用例…

LabVIEW电机测试系统

LabVIEW电机测试系统采用共直流母线架构&#xff0c;优化能量循环方式&#xff0c;实现内部能量循环。系统利用高精度仪器与先进软件技术&#xff0c;提供了一个高效、可靠的测试平台&#xff0c;适用于200 kW以下的交流异步电机和永磁同步电机的性能及耐久性测试。 项目背景 …

Unity读取Android外部文件

最近近到个小需求,需要读Android件夹中的图片.在这里做一个记录. 首先读写部分,这里以图片为例子: 一读写部分 写入部分: 需要注意的是因为只有这个地址支持外部读写,所以这里用到的地址都以 :Application.persistentDataPath为地址起始. private Texture2D __CaptureCamera…

OpenHarmony南向开发 SA服务SELinux权限配置一站式傻瓜式教程

Selinux权限配置 OpenHarmony中SELinux使用详解 目录 SELinux简介SELinux概念SELinux模式OH中SELinux使用详解新增SA服务如何配置SELinux权限SELinux简介 SELinux是Security Enhanced Linux 的缩写,也就是安全强化的 Linux,旨在增强传统Linux操作系统的安全性,解决传统Li…

单调队列《滑动窗口》

#include <iostream>using namespace std;const int N 100010;int m; int q[N], hh, tt -1;//hh表示队头&#xff0c;tt表示队尾int main() {cin >> m;while (m -- ){string op;int x;cin >> op;if (op "push"){cin >> x;q[ tt] x;//队…

Linux Vim教程(十五):使用Vimscript进行脚本编写

目录 1. Vimscript简介 2. 基本语法和结构 2.1 变量 2.2 条件语句 2.3 循环语句 2.4 函数 3. 操作缓冲区、窗口和标签页 3.1 缓冲区 3.2 窗口 3.3 标签页 4. 自动化编辑任务 4.1 自动命令 4.2 键映射 5. 编写和调试Vimscript脚本 5.1 编写脚本 5.2 调试脚本 6…

魔方远程时时获取短信内容APP 前端Vue 后端Ruoyi框架(含搭建教程)

前端Vue 后端Ruoyi框架 APP原生JAVA 全兼容至Android14(鸿蒙 澎湃等等) 前后端功能&#xff1a; ①后端可查看用户在线状态(归属地IP) ②发送短信(自定义输入收信号码以及短信内容&#xff0c;带发送记录) ③短信内容分类清晰(接收时间、上传时间等等) ④前后端分离以及A…

【AWS账号解绑关联】Linker账号解绑重新关联注意事项

文章目录 一、来自客户疑问二、提交工单获取帮助三、最佳操作说明四、最佳操作步骤五、参考资料活动上新 一、来自客户疑问 将Linker账号&#xff0c;从一个组织中退出&#xff0c;重新关联到新的组织中&#xff0c;这解绑到重新完成新的关联绑定期间会在Linker账号中的账单中…

力扣高频SQL 50题(基础版)第四十二题之1517.查找拥有有效邮箱的用户

文章目录 力扣高频SQL 50题&#xff08;基础版&#xff09;第四十二题1517.查找拥有有效邮箱的用户题目说明实现过程准备数据实现方式结果截图总结 力扣高频SQL 50题&#xff08;基础版&#xff09;第四十二题 1517.查找拥有有效邮箱的用户 题目说明 表: Users -----------…

Xcode自定义模板:提升开发效率的秘诀

Xcode自定义模板&#xff1a;提升开发效率的秘诀 引言 在iOS开发中&#xff0c;Xcode的自定义模板是一项强大的功能&#xff0c;它允许开发者根据自己的开发习惯和项目需求&#xff0c;创建个性化的代码和项目模板。这不仅可以加快开发速度&#xff0c;还能保持代码的一致性和…

C#使用Puppeteer

Puppeteer Puppeteer是一个Node.js库&#xff0c;它提供了高级API来通过DevTools协议(Chrome DevTools Protocol https://devtools.chrome.com)控制Chrome或Chromium。 Puppeteer默认情况下无头运行(headless)。 可以配置为运行完整的Chrome或Chromium&#xff0c;运行效果如…

oracle rac

1、app连接oracle rac集群 连接到 Oracle RAC&#xff08;Real Application Clusters&#xff09;的多种配置方式 1. 使用 JDBC 连接字符串&#xff1a; 使用 JDBC 连接字符串是连接 Oracle RAC 的常见方式。连接字符串的格式如下&#xff1a; jdbc:oracle:thin:(DESCRIPTION…

2024年8月7日(mysql主从 )

回顾 主服务器 [rootmaster_mysql ~]# yum -y install rsync [rootmaster_mysql ~]# tar -xf mysql-8.0.33-linux-glibc2.12-x86_64.tar [rootmaster_mysql ~]# tar -xf mysql-8.0.33-linux-glibc2.12-x86_64.tar.xz [rootmaster_mysql ~]# cp -r mysql-8.0.33-linux-glibc2.…

Docker基础知识大全

文章目录 前言一、Docker为什么出现&#xff1f;二、Docker历史三、Docker能干嘛&#xff1f;四、Docker名词五、Docker安装&#xff08;CentOS7&#xff09;六、卸载docker命令七、Docker镜像容器命令总结 1、Docker为什么出现&#xff1f; java jar包 打包项目带上环境&…

RabbitMq如何确保消息不丢失

问题&#xff1a;在生产环境中由于一些不明原因&#xff0c;导致 rabbitmq 重启&#xff0c;在 RabbitMQ 重启期间生产者消息投递失败&#xff0c;导致消息丢失&#xff0c;需要手动处理和恢复。于是&#xff0c;我们开始思考&#xff0c;如何才能进行 RabbitMQ 的消息可靠投递…

Linux驱动入门实验班——基础驱动模板(附百问网视频链接)

目录 一、GPIO子系统 1.确定引脚编号 2.写程序 二、中断函数 使用中断的流程 三、定时器 1.定时器两要素 2.使用定时器 四、交互流程解读 1、非阻塞访问和阻塞访问 2、POLL 3、异步通知 课程链接 一、GPIO子系统 如何驱动GPIO 1.确定引脚编号 可以在开发板上&a…