Table of Contents
- Overview
- vanilla
- lsgan
- wgan
- hinge
- Summary
- Appendix
Overview
Since I ran into several different forms of GAN loss in practice, this post summarizes how the four commonly used variants differ in application: 'vanilla', 'lsgan', 'wgan', and 'hinge'.
vanilla
Proposed by Ian Goodfellow in 2014.
This is the plainest, most basic form. It uses nn.BCEWithLogitsLoss(), i.e. sigmoid + BCELoss:
self.loss = nn.BCEWithLogitsLoss()
The formulas (with label y = 1 for real and y = 0 for fake):
Ld = -[y*log D(x) + (1-y)*log(1-D(G(z)))]
Lg = -log D(G(z))  (the non-saturating form: the generator is trained against the real label)
The test code is as follows:
from gan_loss_comps import GANLossComps
import torch.nn as nn
import numpy as np
import torch
import numpy.testing as npt
input_1 = torch.ones(1, 1)
input_2 = torch.ones(1, 3, 6, 6) * 2
gan_loss = GANLossComps('vanilla', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
# BCE: loss = -[y*log(p) + (1-y)*log(1-p)]
# Ld = -[y*log D(x) + (1-y)*log(1-D(G(z)))]
# Lg = -log D(G(z))
# G (generator loss, scaled by loss_weight=2.0)
loss = gan_loss(input_1, True, is_disc=False)  # -2*np.log(1/(1+np.exp(-1)))
npt.assert_almost_equal(loss.item(), 0.6265233)
# D (loss_weight is ignored for discriminators)
loss = gan_loss(input_1, True, is_disc=True)
npt.assert_almost_equal(loss.item(), 0.3132616)
loss = gan_loss(input_1, False, is_disc=True)
npt.assert_almost_equal(loss.item(), 1.3132616)
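As a sanity check, the expected values above can be reproduced with plain numpy, writing the sigmoid and BCE out by hand (a minimal sketch; d_out stands in for the discriminator logit):
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

d_out = 1.0  # discriminator logit, matching input_1
print(2.0 * -np.log(sigmoid(d_out)))  # generator: 2 * -log D(G(z)) = 0.6265233
print(-np.log(sigmoid(d_out)))        # D on real: -log D(x) = 0.3132616
print(-np.log(1 - sigmoid(d_out)))    # D on fake: -log(1-D(G(z))) = 1.3132616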
lsgan
The default discriminator setup in a regular GAN is a classifier trained with the sigmoid cross-entropy loss. However, this loss can lead to vanishing gradients during training. To overcome this problem, the Least Squares GAN (LSGAN) uses a least-squares loss instead. In fact, the LSGAN objective essentially minimizes the Pearson χ² divergence.
Compared with a regular GAN, LSGAN has two benefits: it generates higher-quality images, and it is more stable during training.
https://blog.csdn.net/lgzlgz3102/article/details/115475370
self.loss = nn.MSELoss()
## lsgan
input_1 = torch.ones(1, 1)
input_2 = torch.ones(1, 3, 6, 6) * 2
gan_loss = GANLossComps('lsgan', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
# Ld = y*(y - D(x))^2 + (1-y)*(y - D(G(z)))^2  (y = 1 for real, 0 for fake)
# Lg = (1 - D(G(z)))^2
# gen
loss = gan_loss(input_2, True, is_disc=False)  # loss = 2*(1 - D(G(z)))^2 = 2*(1-2)^2 = 2
npt.assert_almost_equal(loss.item(), 2.0)
# dis
loss = gan_loss(input_2, True, is_disc=True)  # loss = (1 - D(x))^2 = (1-2)^2 = 1
npt.assert_almost_equal(loss.item(), 1.0)
loss = gan_loss(input_2, False, is_disc=True)  # loss = (0 - D(G(z)))^2 = (0-2)^2 = 4
npt.assert_almost_equal(loss.item(), 4.0)
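The same numbers also fall straight out of nn.MSELoss, which is what GANLossComps dispatches to for 'lsgan' (a quick cross-check; remember loss_weight only scales the generator term):
import torch
import torch.nn as nn

mse = nn.MSELoss()
pred = torch.ones(1, 3, 6, 6) * 2  # D output, matching input_2
real = torch.ones_like(pred)       # real_label_val = 1.0
fake = torch.zeros_like(pred)      # fake_label_val = 0.0
print(2.0 * mse(pred, real).item())  # generator: 2.0
print(mse(pred, real).item())        # D on real: 1.0
print(mse(pred, fake).item())        # D on fake: 4.0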
wgan
Wasserstein GAN
Problems it addresses:
Mode collapse: the generator produces a very narrow distribution that covers only a single mode of the data distribution. That is, the generator can only produce very similar samples (e.g., a single digit from MNIST), so the generated samples lack diversity.
No metric tells us about convergence. The generator and discriminator losses do not tell us anything about convergence. We can of course monitor training progress by inspecting the generated samples from time to time, but that is a manual process. We therefore need an interpretable metric that tracks training progress.
https://blog.csdn.net/m0_62128864/article/details/124258797
https://zhuanlan.zhihu.com/p/361808267
self.loss = -input.mean() if target else input.mean()
The formula:
loss = -[y*log D(x) + (1-y)*log(1-D(G(z)))] becomes -[y*D(x) - (1-y)*D(G(z))]  # drop the log, and replace (1-D(G(z))) with -D(G(z))
# wgan
gan_loss = GANLossComps('wgan', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
# Ld = -[y*D(x) - (1-y)*D(G(z))]
# Lg = -D(G(z))
loss = gan_loss(input_2, True, is_disc=False)  # -2*(2) = -4
npt.assert_almost_equal(loss.item(), -4.0)
loss = gan_loss(input_2, True, is_disc=True)  # -D(x).mean() = -2
npt.assert_almost_equal(loss.item(), -2.0)
loss = gan_loss(input_2, False, is_disc=True)
npt.assert_almost_equal(loss.item(), 2.0)
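Note that GANLossComps only supplies the loss term. For the Wasserstein estimate to be valid, WGAN additionally requires the critic (discriminator) to be 1-Lipschitz; the original paper enforces this with weight clipping after each critic update. A minimal sketch (the critic module and clip_value here are illustrative, not part of GANLossComps):
import torch.nn as nn

critic = nn.Linear(10, 1)  # stand-in critic; any nn.Module works
clip_value = 0.01          # clipping bound used in the original WGAN paper
# run after each critic optimizer step:
for p in critic.parameters():
    p.data.clamp_(-clip_value, clip_value)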
hinge
For D, only real samples with D(x) < 1 and fake samples with D(G(z)) > -1 contribute to the loss.
In other words, only samples that have not yet been confidently separated produce gradients.
https://zh.wikipedia.org/zh-cn/Hinge_loss
https://zhuanlan.zhihu.com/p/72195907
self.loss = nn.ReLU()
The formulas:
Ld = E(max(0,1-D(x)))+E(max(0,1+D(G(z))))
Lg = -E(D(G(z)))
# hinge
gan_loss = GANLossComps('hinge', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
# G
loss = gan_loss(input_2, True, is_disc=False)  # same as wgan: the generator loss is just -input.mean(), scaled by loss_weight
npt.assert_almost_equal(loss.item(), -4.0)
# D
loss = gan_loss(input_2, True, is_disc=True)  # relu(1 - D(x)).mean() = relu(1-2) = 0
npt.assert_almost_equal(loss.item(), 0.0)
loss = gan_loss(input_2, False, is_disc=True)  # relu(1 + D(G(z))).mean() = relu(1+2) = 3
npt.assert_almost_equal(loss.item(), 3.0)
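The asserted values match a hand-written version of the 'hinge' branch in forward() (a minimal check with torch.relu):
import torch

d_out = torch.ones(1, 3, 6, 6) * 2  # matches input_2
print((2.0 * -d_out.mean()).item())         # generator: -4.0
print(torch.relu(1 - d_out).mean().item())  # D on real: 0.0
print(torch.relu(1 + d_out).mean().item())  # D on fake: 3.0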
Summary
vanilla: sigmoid + BCELoss
lsgan: MSE
wgan: drop the log and replace (1-D(G(z))) with -D(G(z)); D must additionally satisfy a Lipschitz constraint (e.g., weight clipping)
hinge: clamps the discriminator loss to E(max(0,1-D(x))) + E(max(0,1+D(G(z))))
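To see the four variants side by side, the same input can be pushed through each gan_type in one loop (assuming the GANLossComps class from the appendix below is importable as in the earlier tests):
import torch
from gan_loss_comps import GANLossComps

x = torch.ones(1, 3, 6, 6) * 2
for gan_type in ['vanilla', 'lsgan', 'wgan', 'hinge']:
    crit = GANLossComps(gan_type, loss_weight=2.0)
    print(gan_type,
          crit(x, True, is_disc=False).item(),  # generator loss
          crit(x, True, is_disc=True).item(),   # D loss on real
          crit(x, False, is_disc=True).item())  # D loss on fake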
Appendix
The GANLossComps class used in this post is attached below as an appendix.
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class GANLossComps(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge',
            'wgan-logistic-ns'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always
            1.0 for discriminators.
    """

    def __init__(self,
                 gan_type: str,
                 real_label_val: float = 1.0,
                 fake_label_val: float = 0.0,
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val
        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan-logistic-ns':
            self.loss = self._wgan_logistic_ns_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input: torch.Tensor, target: bool) -> torch.Tensor:
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return -input.mean() if target else input.mean()

    def _wgan_logistic_ns_loss(self, input: torch.Tensor,
                               target: bool) -> torch.Tensor:
        """WGAN loss in logistically non-saturating mode.

        This loss is widely used in StyleGANv2.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return F.softplus(-input).mean() if target else F.softplus(
            input).mean()

    def get_target_label(self, input: torch.Tensor,
                         target_is_real: bool) -> Union[bool, torch.Tensor]:
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan,
                otherwise, return Tensor.
        """
        if self.gan_type in ['wgan', 'wgan-logistic-ns']:
            return target_is_real
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self,
                input: torch.Tensor,
                target_is_real: bool,
                is_disc: bool = False) -> torch.Tensor:
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            loss = self.loss(input, target_label)
        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight