【论文复现】偏标记学习+图像分类

server/2024/11/28 2:53:15/

在这里插入图片描述

📝个人主页🌹:Eternity._
🌹🌹期待您的关注 🌹🌹

在这里插入图片描述
在这里插入图片描述

❀ 偏标记学习+图像分类

  • 概述
  • 算法原理
  • 核心逻辑
  • 效果演示
  • 使用方式
  • 参考文献

概述


本文复现论文 Progressive Identification of True Labels for Partial-Label Learning[1] 提出的偏标记学习方法。

随着深度神经网络的发展,机器学习任务对标注数据的需求不断增加。然而,大量的标注数据十分依赖人力资源与标注者的专业知识。弱监督学习可以有效缓解这一问题,因其不需要完全且准确的标注数据。该论文关注一个重要的弱监督学习问题——偏标记学习(Partial Label Learning),其中每个训练实例与一组候选标签相关联,但仅有一个标签是真实的。

在这里插入图片描述
该论文提出了一种渐进式真实标签识别方法,旨在训练过程中逐渐确定样本的真实标签。该论文所提出的方法获得了接近监督学习的性能,且与具体的网络结构、损失函数、随机优化算法无关。

本文所涉及的所有资源的获取方式:这里

算法原理


传统的监督学习常用交叉熵损失和随机梯度下降来优化深度神经网络。交叉熵损失定义如下:
在这里插入图片描述
其中, x 表示样本特征; [ y = [ y 1 , y 2 , … , y c ] ] [ \mathbf{y} = [y_1, y_2, \ldots, y_c] ] [y=[y1,y2,,yc]]表示样本标签,其为独热码,即除了真实标签对应维度值为 1,其余为零; [ f i ( x ; θ ) ] [ f_i(x; \theta) ] [fi(x;θ)]表示模型预测样本 x 标签为 i 的概率。

该论文提出的方法使用一个软标签 [ y ^ = [ y ^ 1 , y ^ 2 , … , y ^ c ] ] [ \hat{y} = [\hat{y}_1, \hat{y}_2, \ldots, \hat{y}_c] ] [y^=[y^1,y^2,,y^c]],其对任意 [ i ∈ [ 0 , c ] ] [ i \in [0, c] ] [i[0,c]]满足 [ ∑ i y ^ i = 1 且 0 ≤ y ^ i ≤ 1 ] [ \sum_{i} \hat{y}_i = 1 \quad \text{且} \quad 0 \leq \hat{y}_i \leq 1 ] [iy^i=10y^i1]为了使用该软标签,论文根据候选标签集 s 对软标签进行初始化:
在这里插入图片描述

为了渐进式地识别真实标签,算法在每次更新参数之前,根据预测结果为下轮训练使用的软标签赋值:
在这里插入图片描述
其中, [ I ( j ∈ s ) = { 1 当且仅当  j ∈ s 为真 0 否则 ] [ I(j \in s) = \begin{cases} 1 & \text{当且仅当 } j \in s \text{ 为真} \\ 0 & \text{否则} \end{cases} ] [I(js)={10当且仅当 js 为真否则]

核心逻辑


具体的核心逻辑如下所示:

import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdmdef CE_loss(probs, targets):"""交叉熵损失函数"""loss = -torch.sum(targets * torch.log(probs), dim = -1)loss_avg = torch.sum(loss)/probs.shape[0]return loss_avgclass Proden:def __init__(self, configs):self.configs = configsdef train(self, save = False):configs = self.configs# 读取数据集dataset_path = configs['dataset path']if configs['dataset'] == 'CIFAR-10':train_data, train_labels, test_data, test_labels = datasets.cifar10_read(dataset_path)train_dataset = datasets.Cifar(train_data, train_labels)test_dataset = datasets.Cifar(test_data, test_labels)output_dimension = 10elif configs['dataset'] == 'CIFAR-100':train_data, train_labels, test_data, test_labels = datasets.cifar100_read(dataset_path)train_dataset = datasets.Cifar(train_data, train_labels)test_dataset = datasets.Cifar(test_data, test_labels)output_dimension = 100# 生成偏标记partial_labels = datasets.generate_partial_labels(train_labels, configs['partial rate'])train_dataset.load_partial_labels(partial_labels)# 计算数据的均值和方差,用于模型输入的标准化mean = [np.mean(train_data[:, i, :, :]) for i in range(3)]std = [np.std(train_data[:, i, :, :]) for i in range(3)]normalize = transforms.Normalize(mean, std)# 设备:GPU或CPUdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载模型if configs['model'] == 'ResNet18':model = models.ResNet18(output_dimension = output_dimension).to(device)elif configs['model'] == 'ConvNet':model = models.ConvNet(output_dimension = output_dimension).to(device)# 设置学习率等超参数lr = configs['learning rate']weight_decay = configs['weight decay']momentum = configs['momentum']optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)lr_step = configs['learning rate decay step']lr_decay = configs['learning rate decay rate']lr_scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_decay)for epoch_id in range(configs['epoch count']):# 训练模型train_dataloader = DataLoader(train_dataset, batch_size = configs['batch size'], shuffle = True)model.train()for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):ids = batch['ids']# 标准化输入data = normalize(batch['data'].to(device))partial_labels = batch['partial_labels'].to(device)targets = batch['targets'].to(device)optimizer.zero_grad()# 计算预测概率logits = model(data)probs = F.softmax(logits, dim=-1)# 更新软标签with torch.no_grad():new_targets = F.normalize(probs * partial_labels, p=1, dim=-1)train_dataset.targets[ids] = new_targets.cpu().numpy()# 计算交叉熵损失loss = CE_loss(probs, targets)loss.backward()# 更新模型参数optimizer.step()# 调整学习lr_scheduler.step()

以上代码仅作展示,更详细的代码文件请参见附件。

效果演示


我提前在 CIFAR-10[2] 数据集和 12 层的 ConvNet[3] 网络上训练了一份模型参数。为了测试其准确率,需要配置环境并运行main.py脚本,得到结果如下:
在这里插入图片描述
由图可见,该算法在测试集上获得了 89.8% 的准确率。

进一步地,测试训练出的模型在真实图片上的预测结果。在线部署模型后,将一张轮船的图片输入,可以得到输出的预测类型为 “Ship”:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
网站提供了在线演示功能,使用者请输入一张小于1MB、类别为上述十个类别之一、长宽尽可能相等的JPG图像。

使用方式


解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:

unzip Proden-implemention.zip
cd Proden-implemention

代码的运行环境可通过如下命令进行配置:

pip install -r requirements.txt

运行如下命令以下载并解压数据集

bash download.sh

如果希望在本地训练模型,请运行如下命令:

python main.py -c [你的配置文件路径] -r [选择下者之一:"train""test""infer"]

如果希望在线部署,请运行如下命令:

python main-flask.py

参考文献


[1] Lv J, Xu M, Feng L, et al. Progressive identification of true labels for partial-label learning[C]//International conference on machine learning. PMLR, 2020: 6500-6510.

[2] Krizhevsky A, Hinton G. Learning multiple layers of features from tiny images[J]. 2009.

[3] Laine S, Aila T. Temporal ensembling for semi-supervised learning[J]. arXiv preprint arXiv:1610.02242, 2016.


编程未来,从这里启航!解锁无限创意,让每一行代码都成为你通往成功的阶梯,帮助更多人欣赏与学习

更多内容详见:这里


http://www.ppmy.cn/server/145502.html

相关文章

java虚拟机——频繁发生Full GC的原因有哪些?如何避免发生Full GC

什么是Full GC Full GC(Full Garbage Collection)是Java垃圾收集过程中的一种形式,它涉及整个堆内存(包括年轻代和老年代)以及方法区的垃圾收集。Full GC是一个相对重量级的操作,因为它需要遍历和回收整个…

HarmonyOS 3.1/4项目在DevEco Studio 5.0(HarmonyOS NEXT)版本下使用的问题

有读者在使用《鸿蒙HarmonyOS应用开发入门》书中的源码时,遇到了问题。本文总结问题的原因及解决方案。 有读者在使用《鸿蒙HarmonyOS应用开发入门》书中的源码时,遇到了问题。本文总结问题的原因及解决方案。 问题原因 这些问题,本质上是…

windows10下3DGS环境配置

前言 3DGS(3D Gaussian Splatting)是由法国蔚蓝海岸大学的Kerbl, Bernhard等人在《3D Gaussian Splatting for Real-Time Radiance Field Rendering【SIGGRAPH 2023】》论文地址一文中提出了一种 极短训练时间呢就能达到最高视觉质量的方法,而且可以保证在高质量、实…

在 Taro 中实现系统主题适配:亮/暗模式

目录 背景实现方案方案一:CSS 变量 prefers-color-scheme 媒体查询什么是 prefers-color-scheme?代码示例 方案二:通过 JavaScript 监听系统主题切换 背景 用Taro开发的微信小程序,需求是页面的UI主题想要跟随手机系统的主题适配…

C语言:函数

1. 函数的基本概念与用途 函数是 C 语言程序的基本构建块。它将一个大型程序分解为较小的、可管理的部分,每个部分负责特定的任务,这样可以提高代码的可读性、可维护性和可复用性。例如,在一个涉及数学计算、输入输出处理和数据存储的复杂程…

Python学习——猜拳小游戏

import random player int(input(“请输入:剪刀 0,石头 1,布2”)) computer random.randint(0,2)# print(“玩家输入的是%d,电脑输入的是%d” %(player,computer)) 用于测试 if (player 0) and (computer 0) or (player 1) a…

SSL/TLS,SSL,TLS分别是什么

SSL/TLS,SSL,TLS分别是什么 SSL(Secure Sockets Layer,安全套接层) 定义与发展历程: SSL 是一种早期的网络安全协议,旨在为网络通信提供保密性、数据完整性和身份验证等安全保障。它最初由网景…

为什么DDoS防御很贵?

分布式拒绝服务攻击(DDoS攻击)是一种常见的网络安全威胁,通过大量恶意流量使目标服务器无法提供正常服务。DDoS防御是一项复杂且昂贵的服务,本文将详细探讨为什么DDoS防御如此昂贵,并提供一些实用的代码示例和解决方案…