线性模型 - Softmax 回归(参数学习)

embedded/2025/2/24 3:48:46/

本文,我们来学习Softmax 回归的参数学习,在开始之前,我们先了解一下“损失函数”、“风险函数”和“目标函数”这三个核心概念。

一、损失函数、风险函数、目标函数

1. 损失函数(Loss Function)

  • 定义
    损失函数是用来衡量单个样本预测结果与真实标签之间差异的函数。它描述了对一个样本来说,模型犯错误的“代价”或“惩罚”。

  • 例子

    • 对于回归问题,常用的损失函数有均方误差: 
    • 对于二分类问题,常用的损失函数有二元交叉熵:
  • 作用
    它直接衡量单个样本的预测误差,是后续构建总体模型优化目标的基本组成部分。

2. 风险函数(Risk Function)

  • 定义
    风险函数是对所有可能样本(或整个数据分布)上损失函数的期望值,也称为“期望损失”。在统计学习中,风险函数描述了模型在总体数据分布下的平均表现。

  • 数学表达

  • 作用
    风险函数衡量了模型的泛化能力和整体表现,是理论上模型“好坏”的评判标准。实际中,由于总体数据分布未知,我们往往使用经验风险(在训练集上的平均损失)来近似。

3. 目标函数(Objective Function)

  • 定义
    目标函数是优化算法在训练过程中希望最小化(或最大化)的函数。它通常包含经验风险(或损失函数的平均值),有时还会加入正则化项以控制模型复杂度。

  • 数学表达
    例如,对于一个回归问题,目标函数可以写为

    其中第一项是经验风险(经验损失的平均值),Ω(f) 是正则化项,λ 是正则化参数。

  • 作用
    目标函数定义了训练过程中需要优化的具体指标。通过优化目标函数,我们期望找到一个模型,使得在训练数据上损失(以及模型复杂度)达到平衡,进而提升泛化能力。

4. 三者之间的区别与联系

  • 区别

    • 损失函数:针对单个样本,衡量预测与真实标签的差异。
    • 风险函数:是损失函数在整个数据分布上的期望,描述模型在总体数据上的平均表现。
    • 目标函数:是实际优化时用到的函数,通常是经验风险(训练集上平均损失)加上正则化项,作为训练过程中参数调整的依据。
  • 联系

    • 损失函数是风险函数的基本组成部分,风险函数是损失函数在总体数据分布下的平均值。
    • 在实际训练中,由于总体数据分布未知,我们通常用训练集上的平均损失(经验风险)作为目标函数的一部分,通过最小化目标函数来间接地降低总体风险。

总结

  • 损失函数告诉我们每个样本犯错的代价;
  • 风险函数(期望损失)描述了模型在整个数据分布上的平均表现;
  • 目标函数是实际优化过程中使用的函数,通常包含经验风险和正则化项,用来指导模型参数的学习

这种分层的定义帮助我们从单个样本的误差度量,扩展到整个数据集甚至总体数据分布的模型评估,再到实际训练过程中具体优化的目标,构成了机器学习模型训练的完整理论框架。

二、Softmax 回归的损失函数:交叉熵损失

(一)模型结构

1、模型定义

2、模型表示

(二)交叉熵损失函数

这里提到one-hot编码,独热分布可以参考:线性模型 - 二分类问题的损失函数-CSDN博客

为了让大家直观理解交叉熵损失函数选择的意义,我还是把对数函数的图像放到这里:

这样大家可以比较直观的看到,随着概率P从0~1的变化,对应的损失的变化情况。这里需要大家具备基本的对数函数的知识。

三、参数学习

1、构造似然函数

通过最大似然估计,目标是找到一组参数 {w_k, b_k}使得训练数据的似然最大。取对数后,最大化对数似然等价于最小化交叉熵损失。

2、求梯度

通过计算损失函数关于参数的梯度,利用梯度下降或其变种(如随机梯度下降、Mini-batch SGD、Adam 等)迭代更新参数。

根据计算得到的梯度,采用梯度下降的更新规则(例如,参数更新公式为:

其中 α是学习率。

重复上述计算和更新过程,直至损失函数收敛到较低值,或达到预设的迭代次数,从而得到最优参数。

3、确定最优参数

  1. 优化算法选择

    • 批量梯度下降(BGD):使用全体训练数据计算梯度,收敛稳定但计算成本高。

    • 随机梯度下降(SGD):每次随机选择一个样本更新参数,速度快但波动大。

    • 小批量梯度下降(Mini-batch GD):折中方案,常用批量大小为32~256。

  2. 学习率调整

    • 固定学习:简单但需手动调参(如α=0.01)。

    • 自适应学习:使用Adam、RMSprop等优化器自动调整。

  3. 正则化技术

    • L2正则化:在损失函数中加入,防止权重过大。

    • 早停法(Early Stopping):在验证集损失不再下降时终止训练。

  4. 收敛判定

    • 损失变化阈值:当损失下降幅度小于预设阈值(如10^−5)时停止。

    • 最大迭代次数:设置训练轮次上限(如1000轮)。

四、参数学习举例:

栗子1:

栗子2:

场景:手写数字识别(3个类别,2个特征)

五、为什么要使用交叉熵损失作为损失函数 ?

1、概率解释

交叉熵损失衡量的是模型预测的概率分布与真实分布之间的差异。在二分类或多分类问题中,真实标签通常以独热编码(one-hot encoding)的形式表示,而模型输出的是一个概率分布。交叉熵正好可以量化这两个分布之间的不匹配程度,从而指导模型改进预测。

2、与最大似然估计的一致性

在逻辑回归和 Softmax 回归中,我们假设样本服从伯努利分布或多项分布。最大似然估计(MLE)的目标是最大化数据的似然,而取对数后,MLE 的目标函数就转化为最小化交叉熵损失。这种方法从理论上保证了最优参数的学习

3、良好的数值性质

交叉熵损失通常是凸的(对于线性模型),这使得利用梯度下降等优化方法能够有效地找到全局最优解。同时,当模型预测错误且非常自信时,交叉熵损失会急剧增大,从而促使模型大幅度调整参数。

4、梯度信息丰富

交叉熵损失提供的梯度信息通常比较丰富,尤其是在预测概率偏离真实值较远时,梯度较大,可以帮助模型更快地学习和纠正错误。

5、举例说明

在垃圾邮件检测任务中,假设真实标签 y=1 表示垃圾邮件,而模型预测出邮件为垃圾邮件的概率为 y^。

  • 当邮件确实是垃圾邮件(y=1)且 y^ 很接近1时,交叉熵损失 −log⁡(y^) 很小,表示模型预测正确;
  • 反之,如果邮件为垃圾邮件但模型预测 y^ 较低(例如0.3),则损失 −log⁡(0.3) 会非常大,迫使模型调整参数以提高预测概率。

因此,使用交叉熵损失作为损失函数,可以让模型在训练过程中有效地衡量预测概率与真实标签之间的差距,通过最小化该损失,我们能够以最大似然估计的方式获得最优参数,同时交叉熵损失具有良好的数学性质和梯度信息,有助于稳定高效地进行优化。

五、总结:

步骤关键点
前向传播计算类别得分 → Softmax归一化为概率 → 计算交叉熵损失
反向传播损失对得分的梯度 = 预测概率 - 真实标签 → 链式法则求权重和偏置的梯度
参数更新通过梯度下降法(或变体)更新权重和偏置
正则化添加L2正则化项控制模型复杂度,防止过拟合
最优参数判定结合验证集监控,通过早停法或损失收敛阈值确定训练终止点
  • 参数学习过程
    Softmax 回归的参数学习通过最大似然估计转化为最小化交叉熵损失,然后使用梯度下降等优化算法更新参数,最终得到能够输出合理概率分布的模型。
  • 确定最优参数
    通过不断迭代更新,直到损失函数收敛或达到预定的训练轮数,从而得到在训练数据上表现最优的参数。
  • 直观效果
    模型最终可以将输入 x 映射为每个类别的概率,并通过 argmax 操作输出预测类别。

这种训练过程不仅在数学上严谨,而且在实际应用中非常高效,适用于多类别分类任务。

通过优化交叉熵损失函数,Softmax回归能够有效学习多分类问题的决策边界。实际应用中需注意学习率调整、正则化强度选择及优化算法的适应性。


http://www.ppmy.cn/embedded/164739.html

相关文章

力扣——划分字母区间

题目链接: 链接 题目描述: 思路: 要找到每一个字母的最大位置end,也是这一段的结尾位置在这个最大位置内的字母,如果存在某个字母的最大位置 更大,就更新end为更大的如果遍历到end,就说明这一…

深研究:与Dify建立研究自动化应用

许多个人和团队面临筛选各种网页或内部文档的挑战,以全面概述一个主题。那么在这里我推荐大家使用Dify,它是一个用于LLM应用程序开发的低代码,开源平台,它通过自动化工作流程的多步搜索和有效汇总来解决此问题,仅需要最小的编码。 在本文中,我们将创建“ Deepresearch”…

AI大模型(DeepSeek)科研应用、论文写作、数据分析与AI绘图学习

【介绍】 在人工智能浪潮中,2024年12月中国公司研发的 DeepSeek 横空出世以惊艳全球的姿态,成为 AI领域不可忽视的力量!DeepSeek 完全开源,可本地部署,无使用限制,保护用户隐私。其次,其性能强大&#xff…

【matlab代码】基于故障概率加权与多模态滤波的AUV多源融合导航

多模态容错滤波仿真,以AUV为背景。订阅专栏后可查看完整代码,如有程序定制需求,可联系作者。 文章目录 创新点MATLAB仿真代码运行结果说明创新点 贝叶斯故障概率模型 融合SINS/DVL/GPS历史残差,计算实时故障概率 P fault P_{\text{fault}}

TCP fast open

TCP Fast Open 复用 Cookie 快速恢复会话,减少 1 个 RTT 的延迟 传统 TCP 三次握手需 1.5 RTT才能传输应用数据,导致 HTTP 请求延迟较高 TCP Fast Open:为了解决传统 TCP 握手中的延迟问题,通过允许在首次 SYN 握手阶段携带应用数…

C++初阶——简单实现vector

目录 1、前言 2、Vector.h 3、Test.cpp 1、前言 简单实现std::vector类模板。 相较于前面的string,vector要注意: 深拷贝,因为vector的元素可能是类类型,类类型元素可以通过赋值重载,自己实现深拷贝。 迭代器失效…

云夹平台:一站式学习与生活效率工具

在数字化时代,高效管理知识、资源和日常事务成为现代人的核心需求。云夹平台正是这样一款集多功能于一体的智能工具,致力于为用户提供便捷、个性化的服务体验。无论你是学生、职场人士还是终身学习者,云夹都能成为你的得力助手。 1. 书签管理…

基于spring的策略模式

集合spring框架的是策略模式,直接上代码 1、接口 public interface PaymentStrategy {//支付接口void pay(double amount);}2、实现类 2.1 实现类一 Component("creditCard") //作为区分的标识 public class CreditCardPayment implements PaymentStr…