连续变量的 交叉熵 如何计算 python tensorflow

news/2024/11/29 19:50:39/

连续变量的交叉熵通常在机器学习中的回归问题中使用,但它也可以用于分类问题,当概率分布是连续的时。连续变量的交叉熵计算公式如下:

设 \( p(x) \) 是真实概率密度函数,\( q(x) \) 是预测概率密度函数,交叉熵 \( H(p, q) \) 定义为:

\[
H(p, q) = -\int_{-\infty}^{\infty} p(x) \log q(x) \, dx
\]

在分类问题中,如果我们有 \( K \) 个类别,并且 \( p_k \) 是第 \( k \) 个类别的真实概率,\( q_k \) 是第 \( k \) 个类别的预测概率,交叉熵可以表示为:

\[
H(p, q) = -\sum_{k=1}^{K} p_k \log q_k
\]

在实际应用中,如果 \( p_k \) 是一个独热编码的向量(即只有一个类别是 1,其余都是 0),上述公式简化为:

\[
H(p, q) = -\log q_y
\]

其中 \( y \) 是真实类别的索引。

在回归问题中,如果我们有一个连续的目标变量,我们可以使用均方误差(MSE)或均方对数误差(MSLE)等其他损失函数,而不是交叉熵。然而,如果我们想要使用交叉熵,我们通常需要将问题转换为类似于分类问题的形式,例如通过将连续变量离散化或使用概率分布来建模连续变量。
 

下面是用梯度惩罚来实现 K L 散度最小化的实现,和交叉熵原理差不多


# 定义计算损失的函数
def compute_loss(real_data):# 梯度惩罚权重gradient_penalty_weight = gradient_penalty_weight_lamda# x = tf.random.normal((batch_size, n), dtype=tf.dtypes.float32)# x_samp = x / tf.sqrt(2 * tf.reduce_mean(tf.square(x)))# x_gen = tf.concat(values=[w_generator(x_samp), x_samp], axis=1)# x = X_train# x_samp = X_train_sampx_samp = X_train# todo# 计算损失函数的时候 ,用z_score归一化计算?# todo# x_gen = w_generator(x_samp) + x_sampx_gen = w_generator(x_samp)logits_x = w_discriminator(tf.concat([y_train, X_train], axis=-1))logits_x_gen = w_discriminator(tf.concat([x_gen, X_train], axis=-1))d_regularizer = gradient_penalty(real_data, x_gen)disc_loss = (tf.reduce_mean(logits_x) - tf.reduce_mean(logits_x_gen) + d_regularizer * gradient_penalty_weight)gen_loss = tf.reduce_mean(logits_x_gen)return disc_loss, gen_loss# 定义应用生成器梯度的函数
def apply_gen_gradients(gen_gradients):w_gen_optimizer.apply_gradients(zip(gen_gradients, w_generator.trainable_variables))# 定义应用判别器梯度的函数
def apply_disc_gradients(disc_gradients):w_disc_optimizer.apply_gradients(zip(disc_gradients, w_discriminator.trainable_variables))# 定义梯度惩罚函数
# def gradient_penalty(x, x_gen):
#     epsilon = tf.random.uniform([x.shape[0], 1, 1, 1], 0.0, 1.0)
#     x_hat = epsilon * x + (1 - epsilon) * x_gen
#     with tf.GradientTape() as t:
#         t.watch(x_hat)
#         d_hat = w_discriminator(x_hat)
#     gradients = t.gradient(d_hat, x_hat)
#     ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2]))
#     d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)
#     return d_regularizer# 定义梯度惩罚函数
def gradient_penalty(x, x_gen):# 创建一个与真实样本 x 的批量大小相同的随机变量 epsilon,其值在0和1之间,用于在后续步骤中进行插值。# epsilon = tf.random.uniform([x.shape[0], 1, 1, 1], 0.0, 1.0)epsilon = tf.random.uniform([x.shape[0], 1], 0.0, 1.0)# 计算插值样本 x_hat,它是真实样本 x 和生成样本 x_gen 的线性组合。# 这一步是为了在真实样本和生成样本之间创建一个连续的路径x_hat = epsilon * x + (1 - epsilon) * x_gen# print("Shape before discriminator:", x_hat.shape)# 创建一个 tf.GradientTape 上下文,用于记录对 x_hat 的操作,以便后续计算梯度。with tf.GradientTape() as t:# 告诉 tf.GradientTape 监控 x_hat,以便可以计算关于它的梯度t.watch(x_hat)# 使用判别器 w_discriminator 对插值样本 x_hat 进行评分,得到 d_hat。# print(x_hat.shape)d_hat = w_discriminator(tf.concat([x_hat, X_train], axis=-1))# 计算判别器输出 d_hat 关于插值样本 x_hat 的梯度gradients = t.gradient(d_hat, x_hat)# 计算梯度的L2范数,即对每个样本的梯度向量进行平方和,然后开方,得到每个样本的梯度范数。# 在你的代码中,gradients 张量的形状是 [100, 4],但你尝试在 axis=[1, 2]# 上进行 tf.reduce_sum 操作。由于张量只有两个维度,所以没有第三个维度可以进行求和。# ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2]))ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1]))# 算梯度惩罚项,它是梯度范数与1的差的平方的平均值。在WGAN中,我们希望梯度范数接近1,# 因此这个惩罚项会惩罚那些使梯度范数远离1的判别器。d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)# print("gradient_penalty")return d_regularizer


http://www.ppmy.cn/news/1550973.html

相关文章

ELK配置索引清理策略

在ELFK(Elasticsearch, Logstash,Filebeat, Kibana)堆栈中配置索引清理策略是一个常见的需求,因为日志数据会随着时间的推移而积累,占用大量的存储空间。以下是一些配置索引清理策略的方法: 1. 使用索引生命周期管理&…

Linux操作系统2-进程控制3(进程替换,exec相关函数和系统调用)

上篇文章:Linux操作系统2-进程控制2(进程等待,waitpid系统调用,阻塞与非阻塞等待)-CSDN博客 本篇代码Gitee仓库:Linux操作系统-进程的程序替换学习 d0f7bb4 橘子真甜/linux学习 - Gitee.com 本篇重点:进程替换 目录 …

第二节——计算机网络(四)物理层

车载以太网采用差分双绞线车载以太网并未指定特定的连接器,连接方式更为灵活小巧,能够大大减轻线束重量。传统以太网一般使用RJ45连接器连接。车载以太网物理层需满足车载环境下更为严格的EMC要求,100BASE-T1\1000BASE-T1对于非屏蔽双绞线的传…

如何在Python中进行数学建模?

数学建模是数据科学中使用的强大工具,通过数学方程和算法来表示真实世界的系统和现象。Python拥有丰富的库生态系统,为开发和实现数学模型提供了一个很好的平台。本文将指导您完成Python中的数学建模过程,重点关注数据科学中的应用。 数学建…

MemVerge与美光科技利用CXL®内存提升NVIDIA GPU利用率

该联合解决方案将 GPU 利用率提高了 77%,并将 OPT-66B 批量推理的速度提高了一倍以上。 2023 年 3 月 18 日,作为大内存软件领域领导者的 MemVerge,与美光科技联手推出了一项突破性解决方案,该方案通过智能分层的 CXL 内存&#x…

《Python语言程序设计》(2018年版)第15遍刷第1章第1题和第2题

2024.11.28 重新开始刷题 第一章 1.1 print( Welcome to Python Welcome to Computer Science Programming is fun )1.2 text_message "Welcome to Python\n"print(text_message * 5)

蓝桥杯每日真题 - 第24天

题目:(货物摆放) 题目描述(12届 C&C B组D题) 解题思路: 这道题的核心是求因数以及枚举验证。具体步骤如下: 因数分解: 通过逐一尝试小于等于的数,找到 n 的所有因数…

鸿蒙征文|鸿蒙技术分享:使用到的开发框架和技术概览

目录 每日一句正能量前言正文1. 开发环境搭建关键技术:2. 用户界面开发关键技术:3. 应用逻辑开发关键技术:4. 应用测试关键技术:5. 应用签名和打包关键技术:6. 上架流程关键技术:7. 后续维护和更新关键技术…