【GAN】GANLoss之‘vanilla‘, ‘lsgan‘, ‘wgan‘, ‘hinge‘的具体计算方式及实现

news/2024/11/17 19:29:01/

文章目录

    • 说明
    • vanilla
    • lsgan
    • wgan
    • hinge
    • 总结
    • 附录

说明

由于在实际使用中遇到了多种形式的GANLoss,就整理了以下常用的四种GANLoss在应用中的区别,包括’vanilla’, ‘lsgan’, ‘wgan’, ‘hinge’。

vanilla

2014年由Ian Goodfellow
最普通,最基础的一种形式。采用nn.BCEWithLogitsLoss(),即sigmoid + BCELoss,

self.loss = nn.BCEWithLogitsLoss()

具体计算式:

Ld = -[ylogD(x)+(1-y)log(1-D(G(Z)))]
Lg = -[ylogD(x)+(1-y)log(1-D(G(Z)))]

具体代码测试是如下

from  gan_loss_comps import GANLossComps
import torch.nn as nn
import numpy as np
import torch
import numpy.testing as npt
input_1 = torch.ones(1, 1)
input_2 = torch.ones(1, 3, 6, 6) * 2
gan_loss = GANLossComps('vanilla', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0) 
#loss = -[ylogy+(1-y)log(1-y)]
#Ld = -[ylogD(x)+(1-y)log(1-D(G(Z)))] 
#Lg = -[ylogD(x)+(1-y)log(1-D(G(Z)))] 
#G
loss = gan_loss(input_1, True, is_disc=False) #-1*(np.log(1/(1+np.exp(-1))))
npt.assert_almost_equal(loss.item(), 0.6265233)
#D
loss = gan_loss(input_1, True, is_disc=True)
npt.assert_almost_equal(loss.item(), 0.3132616)
loss = gan_loss(input_1, False, is_disc=True)
npt.assert_almost_equal(loss.item(), 1.3132616)

lsgan

常规GAN默认的判别器设置是sigmoid交叉熵损失函数训练的分类器。但是,在训练学习过程中这种损失函数的使用可能会导致梯度消失。为了克服这个问题,最小二乘生成对抗网络(LSGAN)采用最小二乘的损失来缓解。实际上,LSGAN的目标函数将本质上是最小化Pearson χ2散度。
与常规GAN相比,LSGAN有两个好处:一是能生成更高质量的图像;二是在训练过程中更稳定。

https://blog.csdn.net/lgzlgz3102/article/details/115475370

self.loss = nn.MSELoss()
## lsgan
input_1 = torch.ones(1, 1)
input_2 = torch.ones(1, 3, 6, 6) * 2## lsgan
gan_loss = GANLossComps('lsgan', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)#Ld = y*(y-D(x))^2 + (1-y)*(y-D(G(x)))^2
#Lg = (y-D(G(x)))^2#gen
loss = gan_loss(input_2, True, is_disc=False) #loss = 2*(y-D(x))^2 = 2* 1^2 = 2 
npt.assert_almost_equal(loss.item(), 2.0)
#dis
loss = gan_loss(input_2, True, is_disc=True) #loss = (y-D(x))^2 =  1^2 = 1 
npt.assert_almost_equal(loss.item(), 1.0)
loss = gan_loss(input_2, False, is_disc=True) #loss = (y-D(x))^2 =  2^2 = 4 
npt.assert_almost_equal(loss.item(), 4.0)

wgan

Wasserstein GAN
解决的问题:
模式崩溃,生成器生成非常窄的分布,仅覆盖数据分 布中的单一模式。 模式崩溃的含义是生成器只能生成非常相似的样本(例如 ,MNIST中的单个数字),即生成的样本不是多样的。
没有指标可以告诉我们收敛情况。生成器和判别器的 loss并没有告诉我们任何收敛相关信息。当然,我们可以通 过不时地查看生成器生成的数据来监控训练进度。但是, 这是一个手动过程。因此,我们需要有一个可解释的指标 可以告诉我们有关训练的进度。

https://blog.csdn.net/m0_62128864/article/details/124258797
https://zhuanlan.zhihu.com/p/361808267

 self.loss = -input.mean() if target else input.mean()

具体计算式:

loss = -[ylogD(x)+(1-y)log(1-D(G(Z)))] 修改为 -[yD(x)-(1-y)(D(G(Z)))] #去掉log,(1-D(G(Z)) 换成-D(G(Z)

# wgan
gan_loss = GANLossComps('wgan', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)#Ld = -[yD(x)-(1-y)(D(G(Z)))]
#Lg = - yD(x)loss = gan_loss(input_2, True, is_disc=False) #-2*(2)=-4
npt.assert_almost_equal(loss.item(), -4.0)loss = gan_loss(input_2, True, is_disc=True)
npt.assert_almost_equal(loss.item(), -2.0)
loss = gan_loss(input_2, False, is_disc=True)
npt.assert_almost_equal(loss.item(), 2.0)

hinge

对于D来说,只有当D(x) < 1 的正向样本,以及D(G(z)) > -1的负样本才会对结果产生影响
也就是说,只有一些没有被合理区分的样本,才会对梯度产生影响

https://zh.wikipedia.org/zh-cn/Hinge_loss
https://zhuanlan.zhihu.com/p/72195907

self.loss = nn.ReLU()

具体计算式:

Ld = E(max(0,1-D(x)))+E(max(0,1+D(G(z))))
Lg = -E(D(G(z)))

# hinge
gan_loss = GANLossComps('hinge', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
# G
loss = gan_loss(input_2, True, is_disc=False) #跟wgan一样直接输出-input.mean()
npt.assert_almost_equal(loss.item(), -4.0)# D
loss = gan_loss(input_2, True, is_disc=True) 
npt.assert_almost_equal(loss.item(), 0.0)
loss = gan_loss(input_2, False, is_disc=True)
npt.assert_almost_equal(loss.item(), 3.0)

总结

vanilla :sigmoid + BCELoss
lsgan : MSE
wgan : 去掉log,(1-D(G(Z)) 换成-D(G(Z),限制L
hinge: 限制E(max(0,1-D(x)))+E(max(0,1+D(G(z))))

附录

文中使用到的GANLossComps类,作为附录传在下方

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Unionimport torch
import torch.nn as nn
import torch.nn.functional as Fclass GANLossComps(nn.Module):"""Define GAN loss.Args:gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge','wgan-logistic-ns'.real_label_val (float): The value for real label. Default: 1.0.fake_label_val (float): The value for fake label. Default: 0.0.loss_weight (float): Loss weight. Default: 1.0.Note that loss_weight is only for generators; and it is always 1.0for discriminators."""def __init__(self,gan_type: str,real_label_val: float = 1.0,fake_label_val: float = 0.0,loss_weight: float = 1.0) -> None:super().__init__()self.gan_type = gan_typeself.loss_weight = loss_weightself.real_label_val = real_label_valself.fake_label_val = fake_label_valif self.gan_type == 'vanilla':self.loss = nn.BCEWithLogitsLoss()elif self.gan_type == 'lsgan':self.loss = nn.MSELoss()elif self.gan_type == 'wgan':self.loss = self._wgan_losselif self.gan_type == 'wgan-logistic-ns':self.loss = self._wgan_logistic_ns_losselif self.gan_type == 'hinge':self.loss = nn.ReLU()else:raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')def _wgan_loss(self, input: torch.Tensor, target: bool) -> torch.Tensor:"""wgan loss.Args:input (Tensor): Input tensor.target (bool): Target label.Returns:Tensor: wgan loss."""return -input.mean() if target else input.mean()def _wgan_logistic_ns_loss(self, input: torch.Tensor,target: bool) -> torch.Tensor:"""WGAN loss in logistically non-saturating mode.This loss is widely used in StyleGANv2.Args:input (Tensor): Input tensor.target (bool): Target label.Returns:Tensor: wgan loss."""return F.softplus(-input).mean() if target else F.softplus(input).mean()def get_target_label(self, input: torch.Tensor,target_is_real: bool) -> Union[bool, torch.Tensor]:"""Get target label.Args:input (Tensor): Input tensor.target_is_real (bool): Whether the target is real or fake.Returns:(bool | Tensor): Target tensor. Return bool for wgan, otherwise, \return Tensor."""if self.gan_type in ['wgan', 'wgan-logistic-ns']:return target_is_realtarget_val = (self.real_label_val if target_is_real else self.fake_label_val)return input.new_ones(input.size()) * target_valdef forward(self,input: torch.Tensor,target_is_real: bool,is_disc: bool = False) -> torch.Tensor:"""Args:input (Tensor): The input for the loss module, i.e., the networkprediction.target_is_real (bool): Whether the targe is real or fake.is_disc (bool): Whether the loss for discriminators or not.Default: False.Returns:Tensor: GAN loss value."""target_label = self.get_target_label(input, target_is_real)if self.gan_type == 'hinge':if is_disc:  # for discriminators in hinge-ganinput = -input if target_is_real else inputloss = self.loss(1 + input).mean()else:  # for generators in hinge-ganloss = -input.mean()else:  # other gan typesloss = self.loss(input, target_label)# loss_weight is always 1.0 for discriminatorsreturn loss if is_disc else loss * self.loss_weight

http://www.ppmy.cn/news/583496.html

相关文章

智云通CRM:引领企业数字化转型的利器

在如今的商业竞争中&#xff0c;客户管理是企业成功的关键因素之一。然而&#xff0c;传统的客户管理方式已经无法满足企业日益增长的需求&#xff0c;企业需要一个强大的工具来帮助他们更好地管理客户关系&#xff0c;并实现数字化转型。智云通CRM系统作为最佳解决方案&#x…

无向图G的邻接矩阵法和邻接表法以及遍历输出无向图G包括两种存储的FirstNeighbor和NextNeighbor两种基本操作

一.邻接矩阵法 将下列图G用邻接矩阵法进行存储 圆圈中的字符&#xff1a;是顶点的值 圆圈旁边的数字&#xff1a;是顶点的序号 边线上的值&#xff1a;是两个顶点之间的权值 1.结构体 #define MaxVertexNum 10 typedef char VerTexType;//顶点的数据类型 typedef int Edg…

途乐证券|股票XR是什么意思?买股票为什么赚不到钱?

股票市场上有时会出现一些股票在其名称前加上英文字母的情况&#xff0c;比如XD、XR等。那么股票XR是什么意思&#xff1f;买股票为什么赚不到钱&#xff1f;途乐证券为大家准备了相关内容&#xff0c;以供参考。 股票XR是什么意思&#xff1f; 股票名称中带有XR是表示股票在进…

民用飞机飞控系统传感器故障诊断研究综述

导语 飞控系统中的各类传感器对飞机稳定与操纵起着至关重要的影响&#xff0c;是飞机的重要安全机载设备之一。传统冗余方法具有“安全性高&#xff0c;经济性低”的特点&#xff0c;通过多余度设计来提升系统的安全性给飞机的重量与结构设计、系统综合集成、维修与检测成本都…

详解MySQL的常用数据类型

文章目录 一、MySQL 数据类型1.1、mysql中编码和字符 二、数值类型2.1、整数类型的长度2.2、浮点型 三、字符串类型3.1、字符串类型长度 四、日期和时间类型4.1、DATETIME 五、二进制数据类型六、使用建议 一、MySQL 数据类型 MySQL支持很多数据类型&#xff0c;以便我们能在复…

【永久服务器】EUserv

1. 请先自行准备网络&#xff08;我用的伦敦还可以&#xff09;、以及visa卡&#xff0c;淘宝可以代付&#xff0c;我总共花了97人民币&#xff08;10.94欧代付费&#xff09; 现在只能申请一台&#xff0c;多了会被删除&#xff0c;也就是两欧元&#xff0c;然后选择visa卡 选…

sql增删改查语句

sqlite语句 1.新建数据表 create table t_student(id INT PRIMARY KEY NOT NULL, name TEXT NOT NULL, score REAL NOT NULL); 2.删除数据表 drop table 3.往数据表插入数据条目 insert into t_st…

【2023 · CANN训练营第一季】应用开发深入讲解之DVPP

应用开发深入讲解之DVPP 1.基本概念 昇腾Al处理器内置图像处理单元DVPP(Digital Video Pre-Processor)&#xff0c;提供强大的媒体处理硬加速能力。主要功能模块有&#xff1a; 2.常见接口 a.内存申请与释放 b.通道创建与释放 c.图片描述信息创建与销毁 d.图片描述参…