Center Loss 和 ArcFace Loss 笔记

news/2025/1/14 23:50:04/

一、Center Loss

1. 定义

Center Loss 旨在最小化类内特征的离散程度,通过约束样本特征与其类别中心之间的距离,提高类内特征的聚合性。

2. 公式

对于样本 xi​ 和其类别yi​,Center Loss 的公式为:

  • xi​: 当前样本的特征向量(通常来自网络的最后一层)。
  • Cyi: 类别 yi​ 的特征中心。
  • m: 样本数量。

3. 作用

  • 减小类内样本的特征分布范围。
  • 提高分类模型对相似类别样本的区分能力。

4. 实现

import torch
import torch.nn as nnclass CenterLoss(nn.Module):def __init__(self, num_classes, feat_dim, weight=1.0):""":param num_classes: 类别数量:param feat_dim: 特征向量维度:param weight: 损失的权重"""super(CenterLoss, self).__init__()self.weight = weightself.centers = nn.Parameter(torch.randn(num_classes, feat_dim))  # 初始化类别中心def forward(self, features, labels):""":param features: 网络输出的特征向量 (batch_size, feat_dim):param labels: 样本对应的类别标签 (batch_size,)"""centers = self.centers[labels]  # 获取对应标签的中心loss = torch.sum((features - centers) ** 2, dim=1).mean()  # 欧几里得距离平方和return self.weight * loss

5. 结合 Cross-Entropy Loss

Center Loss 与交叉熵损失结合,联合优化网络:

center_loss = CenterLoss(num_classes=10, feat_dim=512)
cross_entropy_loss = nn.CrossEntropyLoss()# 训练时
features, logits = model(input_data)
loss_ce = cross_entropy_loss(logits, labels)
loss_center = center_loss(features, labels)total_loss = loss_ce + 0.1 * loss_center  # 合并损失

二、ArcFace Loss

1. 定义

ArcFace Loss 是基于角度的损失函数,用于增强特征的判别性。通过在角度空间引入额外的边际约束,强迫同类样本之间更加接近,而不同类样本之间更加远离。

2. 公式

ArcFace Loss 的公式为:

  • θ: 特征和分类权重之间的角度。
  • m: 边际(margin)。

最终损失使用交叉熵计算:

  • s: 缩放因子,用于平衡模型的学习难度。

3. 作用

  • 强化特征的角度判别能力,使得分类更加鲁棒。
  • 在人脸识别任务中,显著提高模型的性能。

4. 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import mathclass ArcFaceLoss(nn.Module):def __init__(self, in_features, out_features, s=30.0, m=0.50):""":param in_features: 特征向量维度:param out_features: 类别数量:param s: 缩放因子:param m: 边际约束"""super(ArcFaceLoss, self).__init__()self.s = sself.m = mself.weight = nn.Parameter(torch.randn(out_features, in_features))  # 分类权重def forward(self, embeddings, labels):# Normalize embeddings and weightembeddings = F.normalize(embeddings, p=2, dim=1)weight = F.normalize(self.weight, p=2, dim=1)# Cosine similaritycosine = F.linear(embeddings, weight)# Add marginphi = cosine - self.mone_hot = torch.zeros_like(cosine)one_hot.scatter_(1, labels.view(-1, 1), 1)cosine_with_margin = one_hot * phi + (1 - one_hot) * cosine# Scalelogits = self.s * cosine_with_marginloss = F.cross_entropy(logits, labels)return loss

解释:

        ArcFaceLoss在最后一层网络,输入是上一层的输出特征值x,初始化当前层的w权重。

cos(角度)=w×x/|w|×|x|,由于ArcLoss会对w和x进行归一化到和为1的概率值。所以|w|×|x|=1。则推导出cos(角度)=w×x,那么真实标签位置给角度+m则让角度变大了,cos值变小。w×x变小,输出的预测为真实标签的概率变低。让模型更难训练,那么在一遍又一遍的模型读取图片提取特征的过程中,会让模型逐渐的将真实标签位置的w×x值变大==cos(角度+m)变大,那么角度就会变的更小。只有角度更小的时候,cos余弦相似度才会大,从而让模型认为这个类别是真实的类别。

所以arcloss主要加入了一个m,增大角度,让模型更难训练,让模型把角度变的更小,从而让w的值调整的更加让类间距增大。

简而言之:加入m的值,让真实类和其他类相似度更高,让模型更难训练。迫使模型为了让真实和其他类相似度更低,而让w权重的值更合理。

三、对比分析

四、如何选择

  • 如果任务需要提升类内特征的聚合性(如样本分布紧密性),优先考虑 Center Loss
  • 如果任务需要增强类间特征的判别能力(如人脸识别),优先选择 ArcFace Loss
  • 可以同时使用两者,将特征聚合和判别性结合,提高模型的鲁棒性。

五、推荐学习资源

  1. ArcFace: Additive Angular Margin Loss for Deep Face Recognition (论文)
  2. Center Loss: A Discriminative Feature Learning Approach for Deep Face Recognition (论文)
  3. PyTorch 官方文档

http://www.ppmy.cn/news/1563168.html

相关文章

利用Bi-LSTM实现基于光谱数据对数值进行预测-实战示例

0前言&简介: 本文为《RNN之:LSTM 长短期记忆模型-结构-理论详解-及实战(Matlab向)》的拓展示例,对于初学者而言,还请先阅读原文,增强理解。 本示例采用了长度为807,样本数为12…

渐变头像合成网站PHP源码

源码介绍 渐变头像合成网站PHP源码,操作简单便捷,用户只需上传自己的头像,选择喜欢的头像框,点击一键合成即可生成专属定制头像。网站提供了167种不同风格的头像框供选择,用户也可以自己添加素材。生成后的头像可以直…

Windows下调试Dify相关组件(1)--前端Web

1. 什么是Dify? 官方介绍:Dify 是一款开源的大语言模型(LLM) 应用开发平台。它融合了后端即服务(Backend as Service)和 LLMOps 的理念,使开发者可以快速搭建生产级的生成式 AI 应用。 这是个组件式框架,即使是非技…

Python贪心

贪心 贪心:把整体问题分解成多个步骤,在每个步骤都选取当前步骤的最优方案,直至所有步骤结束;每个步骤不会影响后续步骤核心性质:每次采用局部最优,最终结果就是全局最优如果题目满足上述核心性质&#xf…

利用AI提升SEO效果的关键词优化策略

AI在SEO中的重要性 在当前数字化时代,网站的可见性和可达性变得尤为重要,而搜索引擎优化(SEO)则是提升网站流量和展示机会的关键。人工智能(AI)的引入为SEO领域注入了新的活力,使得优化过程更为…

C++类的引入

C中类的前身 1> 面向对象三大特征:封装、继承、多态 2> 封装:将能够实现某一事物的所有万事万物都封装到一起,包括成员属性(成员变量),行为(功能函数)都封装在一起&#xff…

【巨实用】Git客户端基本操作

本文主要分享Git的一些基本常规操作,手把手教你如何配置~ ● 一个文件夹中初始化Git git init ● 为了方便以后提交代码需要对git进行配置(第一次使用或者需求变更的时候),告诉git未来是谁在提交代码 git config --global user.na…

用 Python 从零开始创建神经网络(十九):真实数据集

真实数据集 引言数据准备数据加载数据预处理数据洗牌批次(Batches)训练(Training)到目前为止的全部代码: 引言 在实践中,深度学习通常涉及庞大的数据集(通常以TB甚至更多为单位)&am…