损失函数——交叉熵损失(Cross-entropy loss)

news/2024/11/30 1:47:18/

交叉熵损失(Cross-entropy loss)是深度学习中常用的一种损失函数,通常用于分类问题。它衡量了模型预测结果与实际结果之间的差距,是优化模型参数的关键指标之一。以下是交叉熵损失的详细介绍。

假设我们有一个分类问题,需要将输入数据x分为C个不同的类别。对于每个输入数据x,我们定义一个C维的向量y^​,其中y^​i​表示x属于第i个类别的概率。我们的目标是使得y^​尽可能接近真实的标签y的概率分布。

假设真实标签y是一个C维的向量,其中只有一个元素为1,其余元素为0,表示x属于第k个类别。那么,我们可以使用交叉熵损失来衡量模型预测结果和真实标签之间的差距。交叉熵损失的公式如下: 

                                                         L\left ( x,y \right ) = -\sum _{i=1}^{C}x_{_{i}} log y_{i}

其中,xi​表示真实标签的第i个元素,y​i​表示模型预测x属于第i个类别的概率。

交叉熵损失的本质是衡量两个概率分布之间的距离。其中一个概率分布是真实标签y的分布,另一个是模型预测的概率分布y^​。对于每个类别i,yi​表示真实标签x属于第i个类别的概率,y^​i​表示模型预测x属于第i个类别的概率。当两个概率分布越接近时,交叉熵损失越小,表示模型预测结果越准确。

交叉熵损失是一种凸函数,通常使用梯度下降等优化算法来最小化它。在深度学习中,交叉熵损失是常见的分类损失函数之一,广泛应用于图像分类、语音识别等任务中。

在PyTorch中,交叉熵损失可以使用torch.nn.CrossEntropyLoss实现。该函数将输入数据视为模型输出的概率分布,将目标标签视为类别索引,并计算这些概率与实际标签之间的交叉熵损失。

以下是一个示例代码片段,说明如何使用torch.nn.CrossEntropyLoss计算交叉熵损失:

import torch# 创建模型输出和目标标签
output = torch.randn(10, 5)  # 10个样本,5个类别
target = torch.tensor([1, 0, 4, 2, 3, 1, 0, 4, 2, 3])  # 目标类别索引# 创建交叉熵损失函数
criterion = torch.nn.CrossEntropyLoss()# 计算损失
loss = criterion(output, target)print(loss)

在训练中,你可以使用torch.nn.CrossEntropyLoss作为损失函数来优化模型。假设你已经有一个PyTorch模型和训练数据集,以下是一个简单的训练循环示例,它使用交叉熵损失函数来训练模型:

import torch
import torch.nn as nn
import torch.optim as optim# 定义模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 2)def forward(self, x):x = self.fc1(x)x = nn.functional.relu(x)x = self.fc2(x)return xmodel = MyModel()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)# 训练循环
for epoch in range(num_epochs):for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % log_interval == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))

在这个例子中,MyModel是一个简单的两层全连接神经网络。训练循环通过从数据集中加载数据批次,使用optimizer.zero_grad()清空梯度,计算模型输出和损失,使用loss.backward()计算梯度并使用optimizer.step()更新模型参数。每个epoch结束时,模型将在测试集上进行评估,以检查其在新数据上的泛化能力。

在这个训练循环中,我们使用nn.CrossEntropyLoss()作为损失函数,并传递模型输出和目标标签作为参数。loss.backward()计算梯度并将梯度传播回模型中的参数,从而使优化器能够更新这些参数以最小化损失。


http://www.ppmy.cn/news/123625.html

相关文章

VP记录:Codeforces Round 862 (Div. 2) A~D

传送门:CF 前提提要:本场中的E题是个线段树合并的板子题,或者可以使用莫队来维护,但是近期思维题打多了导致这些ds问题根本打不来了,留着以后复建之后再补吧 A题:A. We Need the Zero 因为异或是会相互抵消的.所以偶数次异或等于0,奇数次异或才会产生作用.所以当我们的n为偶…

ARM-cortexA7-PWM实验

目录 pwm.h pwm.c main.c pwm.h #ifndef __PWM_H__ #define __PWM_H__ #include "stm32mp1xx_rcc.h" #include "stm32mp1xx_gpio.h" #include "stm32mp1xx_tim.h" // 蜂鸣器初始化 void hal_pwm_init(void);// 风扇初始化 void hal_fan_init…

rails -在数据库里看表结构

在 Rails 中,可以使用 ActiveRecord Migration 文件来管理数据库的表结构。具体来说,可以使用以下命令来查看所创建的 Migration 文件: # ruby rails db:migrate:status该命令会列出已经运行过的所有 Migration 文件和它们的状态。 此外&am…

django-vue-admin-pro 使用

地址: GitHub - dvadmin-pro/django-vue-admin-pro 一、准备工作 Python > 3.8.0 (推荐3.9版本) nodejs > 14.0 (推荐最新) Mysql > 5.7.0 (可选,默认数据库sqlite3,推荐8.0版本) Redis(可选,最新版)项目运行及部署 |…

自学网络安全, 一般人我劝你还是算了吧

前言:自学我劝你还是算了,我为什么要劝你放弃我自己却不放弃呢?因为我不是一般人。。。 1.这是一条坚持的道路,三分钟的热情可以放弃往下看了. 2.多练多想,不要离开了教程什么都不会了.最好看完教程自己独立完成技术方面的开发. 3.有时多 …

如何测试Vue应用程序中的组件和代码?

测试Vue应用程序的组件和代码是非常重要的,这可以确保你的程序在生产环境中稳定运行。 首先,我想强调的是:不要害怕测试!实际上,测试是一个非常有趣的过程,它可以帮助你更好地理解你的代码,并且…

Springboot以Post方式导出excel文件

场景: 导出excel文件,但是需要传入参数,get方法传参懂的都懂,所以改成post方式 少废话,上代码: Controller: PostMapping(value "/exportCustomMItemDataWithLine.iom") ResponseBody ApiOp…

ic验证的主要工作流程和验证工具是什么?

验证其实是一个“证伪”的过程,从流程到工具,验证工程师的终极目的都只有一个: 发现所有BUG,或者证明没有BUG,以保证芯片功能性能的正确性和可靠性。 验证环节对于一颗芯片的重要性也是不言而喻的: 从项…