混合密度网络Mixture Density Networks(MDN)

embedded/2024/10/15 19:15:04/

目录

  • 简介
  • 1 介绍
  • 2 实现
  • 3 几个MDN的应用:
  • 参考

简介

平方和或交叉熵误差函数的最小化导致网络输出近似目标数据的条件平均值,以输入向量为条件。对于分类问题,只要选择合适的目标编码方案,这些平均值表示类隶属度的后验概率,因此可以认为是最优的。然而,对于涉及连续变量预测的问题,条件平均只能对目标变量的性质提供非常有限的描述。对于要学习的映射是多值的问题尤其如此,就像反问题的解中经常出现的那样,因为几个正确目标值的平均值本身不一定是正确的值。为了获得数据的完整描述,为了预测与新输入向量对应的输出,我们必须对目标数据的条件概率分布进行建模,同样以输入向量为条件。本文介绍了将传统神经网络与混合密度模型相结合而得到的一类新的网络模型。完整的系统被称为混合密度网络,原则上可以像传统神经网络表示任意函数一样表示任意条件概率分布。我们用一个玩具问题和一个涉及机器人逆运动学的问题来证明混合密度网络的有效性。

作者:Bishop, Christopher M. (1994).  混合密度网络的提出者;
论文:Mixture density networks. 
出版:Technical Report. Aston University, Birmingham.
论文地址:https://publications.aston.ac.uk/id/eprint/373

论文地址

关注微信公众号,获取更多资讯内容:
在这里插入图片描述

1 介绍

混合密度网络通常作为神经网络的最后处理部分。将某种分布(通常是高斯分布)按照一定的权重进行叠加,从而拟合最终的分布。

如果选择高斯分布的MDN,那么它和GMM(高斯混合模型 Gaussian Mixture Model)有着相同的效果。但是他们有着很明显的区别

在这里插入图片描述
1 该部分案例参考该博客

  • MDN的均值、方差、每个模型的权重是通过神经网络产生的,利用最大似然估计作为Loss函数进行反向传播从而确定网络的权重(也就是确定一个较好的高斯分布参数)
  • GMM的均值、方差、每个模型的权重是通过估计出来的,通常使用EM算法来通过不断迭代确定。

MDN实现简单,而且可以直接模块化的连接到神经网络的后端。他的结果可以得到一个概率范围,相对有deterministic类只输出一个结果,往往有更好的健壮性。
[1]中有相关代码实现。

2 实现

假设我们要拟合如下一个带噪声的函数:
y = 7.0 s i n ( 0.75 x ) + 0.5 x + ϵ y=7.0sin(0.75x)+0.5x+ϵ y=7.0sin(0.75x)+0.5x+ϵ
原始图像为:
在这里插入图片描述
使用神经网络拟合得到:
在这里插入图片描述
对调x和y,再用神经网络拟合得到:
在这里插入图片描述
使用MDN:对于单一输入x,预测y的概率分布。DN的输出为服从混合高斯分布(Mixture Gaussian distributions),具体的输出值被建模为多个高斯随机值的和:
在这里插入图片描述

class MDN(nn.Module):def __init__(self, n_hidden, n_gaussians):super(MDN, self).__init__()self.z_h = nn.Sequential(nn.Linear(1, n_hidden),nn.Tanh())self.z_pi = nn.Linear(n_hidden, n_gaussians)self.z_mu = nn.Linear(n_hidden, n_gaussians)self.z_sigma = nn.Linear(n_hidden, n_gaussians)def forward(self, x):z_h = self.z_h(x)pi = F.softmax(self.z_pi(z_h), -1)mu = self.z_mu(z_h)sigma = torch.exp(self.z_sigma(z_h))return pi, mu, sigma

由于输出本质上是概率分布,因此不能采用诸如L1损失、L2损失的硬损失函数。这里我们采用了对数似然损失(和交叉熵类似):
在这里插入图片描述
使用MDN得到的如下结果:
在这里插入图片描述
具体过程,请参考:
Github库
YoungTimes博客
xiongxyowo的CSDN博客

3 几个MDN的应用:

在这里插入图片描述
3 参考自博客

参考

[1] A Hitchhiker’s Guide to Mixture Density Networks


http://www.ppmy.cn/embedded/90846.html

相关文章

【网络安全】本地文件包含及远程文件包含漏洞详解

一、文件包含漏洞概述 1.1 什么是文件包含 开发人员将需要重复调用的函数写入一个文件,对该文件进行包含时产生的操作。这样编写代码能减少冗余,降低代码后期维护难度。 保证网站整体风格统一:导航栏、底部footer栏等,把这些不…

Docker Compose方式部署Ruoyi-前后端分离版本

目录 一. 环境准备 二. 制作一个jdk8u202环境的镜像 三. 制作nginx镜像 四. 对项目文件做修改 五. 项目打包 1. 前端打包 2. 后端打包 六. 编写docker-compose.yml 一. 环境准备 主机名IP系统软件版本配置信息localhost192.168.226.25Rocky_linux9.4 git version 2.4…

JaCoCo - Java Code Coverage Library

概述 JaCoCo(Java Code Coverage)是一个开源的Java代码覆盖率库。它可以帮助开发人员测量单元测试和集成测试中代码的覆盖情况。通过使用JaCoCo,开发人员可以识别哪些代码没有被测试覆盖,从而提高代码的质量和可靠性。 功能 1.…

单片机振荡电路晶振不起振原因分析与解决方法

晶发电子专注17年晶振生产,晶振产品包括石英晶体谐振器、振荡器、贴片晶振、32.768Khz时钟晶振、有源晶振、无源晶振等,产品性能稳定,品质过硬,价格好,交期快.国产晶振品牌您值得信赖的晶振供应商。 晶振在单片机系统中扮演着至关重要的角色,它为单片机提…

Unity3D 外部导入模型与内部自建模型的区别详解

前言 在Unity3D游戏开发过程中,模型是构建游戏世界的基础元素之一。这些模型可以通过Unity3D内部工具自建,也可以从外部3D建模软件导入。两者各有优劣,适用于不同的开发场景和需求。本文将从技术角度详细探讨Unity3D外部导入模型与内部自建模…

PYTHON专题-(6)python基础的一些高级特性

什么是切片? 取一个list或tuple的部分元素。 什么是迭代? 如果给定一个list或tuple,我们可以通过for循环来遍历这个list或tuple,这种遍历我们称为迭代(Iteration)。在Python中,迭代是通过for ..…

前端:Vue学习 - 智慧商城项目

前端:Vue学习 - 智慧商城项目 1. vue组件库 > vant-ui2. postcss插件 > vw 适配3. 路由配置4. 登录页面静态布局4.1 封装axios实例访问验证码接口4.2 vant 组件 > 轻提示4.3 短信验证倒计时4.4 登录功能4.5 响应拦截器 > 统一处理错误4.6 登录权证信息存…

nginx负载均衡、java、tomcat装包

一、nginx 七层负载均衡 1、七层负载均衡基础配置 2、负载均衡状态 [rootserver]# vim /usr/local/nginx/conf/nginx.confworker_processes 1;event {worker_connections 1024;}http { # 七层负载均衡支持http、ftp协议include mime.types;default_type app…