【机器学习】Softmax 函数

ops/2024/10/30 10:55:08/

Softmax 是机器学习中常用的函数,广泛用于多分类问题的输出层。它可以将一组实数转换为一个概率分布,使得结果满足“非负”和“总和为1”的要求。在分类问题中,Softmax 让模型预测的每个类别概率都易于解释。本文将详细讲解 Softmax 的原理、公式推导、Numpy 实现及其在 Pytorch 中的实际应用。

Softmax 原理

给定一个类别集合 { y 1 , y 2 , … , y n } \{y_1, y_2, \dots, y_n\} {y1,y2,,yn},Softmax 将模型输出的每个数值(称为“得分”或“logits”)转换为概率值。假设模型输出 z i z_i zi 为第 i i i 类的得分,Softmax 将所有的得分映射到概率空间,使每个得分转化为该类的预测概率。

Softmax 函数的公式为:
P ( y i ) = e z i ∑ j = 1 n e z j P(y_i) = \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} P(yi)=j=1nezjezi
其中 z i z_i zi 表示模型为第 i i i 类输出的得分, n n n 是类别的数量。通过对指数值的归一化处理,Softmax 函数输出的概率满足:

  1. 所有概率值都为非负数;
  2. 概率总和为 1。

Softmax 计算中的数值稳定性

在计算中,Softmax 可能会因为指数运算导致数值溢出,为了减小这种风险,可以对每个 (z_i) 值减去一个常数 max ⁡ ( z ) \max(z) max(z)
P ( y i ) = e z i − max ⁡ ( z ) ∑ j = 1 n e z j − max ⁡ ( z ) P(y_i) = \frac{e^{z_i - \max(z)}}{\sum_{j=1}^n e^{z_j - \max(z)}} P(yi)=j=1nezjmax(z)ezimax(z)
这种转换不会改变概率的分布,避免了指数函数产生的大数值溢出问题。

Numpy 实现 Softmax 函数

下面通过 Numpy 实现 Softmax,并进行数据可视化以更直观地理解 Softmax 对得分的转换过程。

import numpy as np
import matplotlib.pyplot as plt# 定义 Softmax 函数
def softmax(logits):"""使用数值稳定性的 Softmax 函数实现参数:- logits: 模型输出得分向量(shape: (n,),表示 n 个类别的得分)返回:- probs: 转换后的概率向量,shape: (n,)"""exp_shifted = np.exp(logits - np.max(logits))  # 减去 max(logits) 以确保数值稳定性probs = exp_shifted / np.sum(exp_shifted)  # 归一化为概率return probs# 示例输入的分类得分
logits = np.array([2.0, 1.0, 0.1])# 使用 Softmax 函数计算各类别的概率
probs = softmax(logits)# 输出各类的预测概率
print("分类得分:", logits)
print("预测概率:", probs)

Softmax 输出可视化

我们可以用图像展示 Softmax 如何将得分转化为概率,假设输入的分类得分范围为 -2 到 4。

# 生成模拟的分类得分范围
logit_range = np.linspace(-2, 4, 100)
all_probs = np.array([softmax([l, 1.0, 0.1]) for l in logit_range])# 可视化不同类别的预测概率随得分变化的趋势
plt.plot(logit_range, all_probs[:, 0], label="类别 1")
plt.plot(logit_range, all_probs[:, 1], label="类别 2")
plt.plot(logit_range, all_probs[:, 2], label="类别 3")
plt.xlabel("得分 (logits)")
plt.ylabel("概率")
plt.title("Softmax 函数输出的概率分布")
plt.legend()
plt.show()

Softmax 损失函数:交叉熵损失

在多分类任务中,常用 交叉熵损失函数 来衡量模型预测概率分布与真实标签的匹配程度。对于单个样本,交叉熵损失定义为:
L = − ∑ i = 1 n y i ⋅ log ⁡ ( P ( y i ) ) L = -\sum_{i=1}^{n} y_i \cdot \log(P(y_i)) L=i=1nyilog(P(yi))

其中 (y_i) 是真实标签的 one-hot 编码,(P(y_i)) 是 Softmax 转换后的预测概率。

# 定义交叉熵损失函数
def cross_entropy_loss(probs, y_true):"""计算交叉熵损失参数:- probs: Softmax 预测概率 (shape: (n,))- y_true: 实际标签 (shape: (n,)),one-hot 编码返回:- loss: 交叉熵损失"""loss = -np.sum(y_true * np.log(probs + 1e-10))  # 加1e-10防止 log(0)return loss# 示例计算
y_true = np.array([1, 0, 0])  # 假设类别 1 为正确类别
loss = cross_entropy_loss(probs, y_true)
print("交叉熵损失:", loss)

在 PyTorch 中使用 Softmax

在 PyTorch 中,我们可以直接调用 torch.nn.functional.softmax 来实现 Softmax。此外,PyTorch 提供的 torch.nn.CrossEntropyLoss 函数在内部自动包含了 Softmax 和交叉熵的计算,无需显式计算。

import torch
import torch.nn.functional as F# 示例:在 PyTorch 中实现 Softmax 和交叉熵损失
logits_torch = torch.tensor([2.0, 1.0, 0.1])# 使用 PyTorch 的 Softmax 函数
probs_torch = F.softmax(logits_torch, dim=0)
print("PyTorch 预测概率:", probs_torch.numpy())# 使用交叉熵损失函数
y_true_index = torch.tensor([0])  # 假设第一个类别为正确类别
loss_fn = torch.nn.CrossEntropyLoss()
loss_torch = loss_fn(logits_torch.unsqueeze(0), y_true_index)
print("PyTorch 交叉熵损失:", loss_torch.item())

在 PyTorch 中,torch.nn.CrossEntropyLoss 在传入 logits 后自动应用 Softmax 和交叉熵计算,为多分类问题提供了便捷的计算方式。

总结

本文介绍了 Softmax 的原理、公式、Numpy 实现、可视化以及在 PyTorch 中的使用。Softmax 是将得分转化为概率分布的关键函数,尤其适用于多分类任务。我们还探讨了数值稳定性的处理以及交叉熵损失在多分类中的作用,理解并实现 Softmax 有助于构建更稳定且易解释的分类模型。


http://www.ppmy.cn/ops/129558.html

相关文章

详解:单例模式中的饿汉式和懒汉式

单例模式是一种常用的设计模式,其目的是确保一个类只有一个实例(对象),并提供一个全局访问点。单例模式有两种常见的实现方式:饿汉式和懒汉式。 一、饿汉式 饿汉式在类加载时就完成了实例化。因为类加载是线程安全的&…

设计模式06-结构型模式1(适配器/桥接/组合模式/Java)

#1024程序员节|征文# 4.1 适配器模式 结构型模式(Structural Pattern)的主要目的就是将不同的类和对象组合在一起,形成更大或者更复杂的结构体。结构性模式的分类: ​ 类结构型模式关心类的组合,由多个类…

C语言中的位操作

第一章 变量某位赋值与连续赋值 寄存器 | 值 //例如&#xff1a;a 1000 0011b a | (1<<2) //a 1000 0111 b 单独赋值 a | (3<<2*2) // 1011 0011b 连续赋值 第二章 变量某位清零与连续清零 寄存器 & ~&#xff08;&#xff09; 值 //例子&#xff1a;a …

MATLAB车道检测与跟踪

读了车道检测这个论文&#xff0c;我理解了利用matlab对车道识别算法进行仿真研究&#xff0c;从仿真的结果中提出具有一定实时性鲁棒性的识别方法。车道检测是智能车辆发展的智能因素。近年来对这项目的研究都是针对特定的环境和道路状况给出了不同的解决方案。近年来,自主驾驶…

NVR小程序接入平台/设备EasyNVR多个NVR同时管理视频监控新选择

在数字化转型的浪潮中&#xff0c;视频监控作为安防领域的核心组成部分&#xff0c;正经历着前所未有的技术革新。随着技术的不断进步和应用场景的不断拓展&#xff0c;视频监控系统的兼容性、稳定性以及安全性成为了用户关注的焦点。NVR小程序接入平台/设备EasyNVR&#xff0c…

【本科毕业设计】基于单片机的智能家居防火防盗报警系统

基于单片机的智能家居防火防盗报警系统 源码下载摘要Abstract第1章 绪论1.1课题的背景1.2 研究的目的和意义 第2章 系统总体方案设计2.1 设计要求2.2 方案选择和论证2.2.1 单片机的选择2.2.2 显示方案的选择 第3章 系统硬件设计3.1 整体方案设计3.1.1 系统概述3.1.2 系统框图 3…

核心HTML5/CSS3基础面试题

HTML5/CSS3 高频经典面试题 汇总了 2023 年各互联网大厂以及中小型创业公司基础阶段的最新高频面试题 HTML/HTML5 标签 Interview questions 1、说说你对 HTML 语义化的理解 ?HTML5 新增了哪些语义化标签 ?(字节、百度,阿里,腾讯、京东,小米) 2、DOCTYPE 是干嘛的,…

一篇文章入门傅里叶变换

文章目录 傅里叶变换欧拉公式傅里叶变换绕圈记录法质心记录法傅里叶变换公式第一步&#xff1a;旋转的表示第二步&#xff1a;缠绕的表示第三步&#xff1a;质心的表示最终步&#xff1a;整理积分限和系数 参考文献 傅里叶变换 在学习傅里叶变换之前&#xff0c;我们先来了解一…