一、GAN网络概述
GAN:生成对抗网络。GAN网络中存在两个网络:G(Generator,生成网络)和D(Discriminator,判别网络)。
Generator接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)
Discriminator功能是判别一张图片的真实。它的输入是一张图片x,输出D(x)代表x为真实图片的概率,如果为1就代表图片真实,而输出为0,就代表图片不真实。
在GAN网络的训练中,Generator的目标就是尽量生成真实的图片去欺骗Discriminator
而Discriminator的目标就是尽量把Generator生成的图片和真实的图片分别开来
二、GAN实验环境准备
除了之前使用过的pytorch-npl、numpy以外,我们还需要安装visdom。
pip install visdom
启动visdom
python -m visdom.server
visdom启动成功如下图,会占用8097端口,我们可以通过8097端口访问visdom
三、GAN网络实验
环境参数配置
python">import torch
from torch import nn,optim,autograd
import numpy as np
import visdom
import randomh_dim = 400
batchsz = 512
viz = visdom.Visdom()
生成网络定义
python">class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.net = nn.Sequential(# input[b, 2]nn.Linear(2,h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 2)# output[b,2])def forward(self, z):output = self.net(z)return output
判别网络定义
python">class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.net = nn.Sequential(nn.Linear(2, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 1),nn.Sigmoid())def forward(self, x):output = self.net(x)return output.view(-1)
数据集生成函数
python">def data_generator():# 生成中心点scale = 2centers = [(1, 0),(-1, 0),(0, 1),(0, -1),(1. / np.sqrt(2), 1. / np.sqrt(2)),(1. / np.sqrt(2), -1. / np.sqrt(2)),(-1. / np.sqrt(2), 1. / np.sqrt(2)),(-1. / np.sqrt(2), -1. / np.sqrt(2))]centers = [(scale * x, scale * y) for x,y in centers] while True:dataset = []for i in range(batchsz):point = np.random.randn(2) * 0.02# 随机选取一个中心点center = random.choice(centers)# 把刚刚随机到的高斯分布点根据center进行移动point[0] += center[0]point[1] += center[1]dataset.append(point)dataset = np.array(dataset).astype(np.float32)dataset /= 1.414yield dataset
可视化函数
将图片生成到visdom
python">import matplotlib.pyplot as plt
def generate_image(D, G, xr, epoch):N_POINTS = 128RANGE = 3plt.clf()points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')points[:,:,0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]points[:,:,1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]points = points.reshape((-1,2))with torch.no_grad():points = torch.Tensor(points).cpu()disc_map = D(points).cpu().numpy()x = y = np.linspace(-RANGE,RANGE,N_POINTS)cs = plt.contour(x,y,disc_map.reshape((len(x), len(y))).transpose())plt.clabel(cs, inline=1,fontsize=10)with torch.no_grad():z = torch.randn(batchsz, 2).cpu()samples = G(z).cpu().numpy()plt.scatter(xr[:,0],xr[:,1],c='orange',marker='.')plt.scatter(samples[:,0], samples[:,1], c='green',marker='+')viz.matplot(plt, win='contour',opts=dict(title='p(x):%d'%epoch))
运行函数
python">def run():torch.manual_seed(23)np.random.seed(23)data_iter = data_generator()x = next(data_iter)# print(x.shape)# G = Generator().cuda()# D = Discriminator().cuda()# 无显卡环境device = torch.device("cpu")G = Generator().cpu()print(G)D = Discriminator().cpu()print(D)optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))"""gan核心部分"""for epoch in range(50000):# 训练判别网络for _ in range(5):# 真实数据训练xr = next(data_iter)xr = torch.from_numpy(xr).cpu()predr = D(xr)# 放大真实数据lossr = -predr.mean()# 虚假数据训练z = torch.randn(batchsz,2).cpu()xf = G(z).detach()predf = D(xf)# 缩小虚假数据lossf = predf.mean()loss_D = lossr + lossf# 梯度清零optim_D.zero_grad()# 向后传播loss_D.backward()optim_D.step()# 训练生成网络z = torch.randn(batchsz,2).cpu()xf = G(z)predf = D(xf)loss_G = -predf.mean()optim_G.zero_grad()loss_G.backward()optim_G.step()if epoch % 100 == 0:viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')print(loss_D.item(), loss_G.item())generate_image(D, G, xr, epoch)
执行(GAN的不稳定性)
python">run()
从结果中可以看到,判别网络的loss一直为0,而生成网络一直得不到更新,生成的数据点远离我们创建的中心点
四、wgan实验
WGAN主要从损失函数的角度对GAN做了改进,对更新后的权重强制截断到一定范围内
增加一个梯度惩罚函数
python">def gradient_penalty(D,xr,xf):# [b,1]t = torch.rand(batchsz, 1).cpu()# 扩展为[b, 2]t = t.expand_as(xr)# 插值mid = t * xr + (1 - t) * xf# 设置需要的倒数信息mid.requires_grad_()pred = D(mid)grads = autograd.grad(outputs=pred, inputs=mid,grad_outputs=torch.ones_like(pred),create_graph=True,retain_graph=True,only_inputs=True)[0]gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()return gp
修改运行函数
python">def run():torch.manual_seed(23)np.random.seed(23)data_iter = data_generator()x = next(data_iter)# print(x.shape)# G = Generator().cuda()# D = Discriminator().cuda()# 无显卡环境device = torch.device("cpu")G = Generator().cpu()print(G)D = Discriminator().cpu()print(D)optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))"""gan核心部分"""for epoch in range(50000):# 训练判别网络for _ in range(5):# 真实数据训练xr = next(data_iter)xr = torch.from_numpy(xr).cpu()predr = D(xr)# 放大真实数据lossr = -predr.mean()# 虚假数据训练z = torch.randn(batchsz,2).cpu()xf = G(z).detach()predf = D(xf)# 缩小虚假数据lossf = predf.mean()# 梯度惩罚值gp = gradient_penalty(D,xr,xf.detach())loss_D = lossr + lossf + 0.2 * gp# 梯度清零optim_D.zero_grad()# 向后传播loss_D.backward()optim_D.step()# 训练生成网络z = torch.randn(batchsz,2).cpu()xf = G(z)predf = D(xf)loss_G = -predf.mean()optim_G.zero_grad()loss_G.backward()optim_G.step()if epoch % 100 == 0:viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')print(loss_D.item(), loss_G.item())generate_image(D, G, xr, epoch)
执行
python">run()
可以看到在wgan中,生成网络开始学习,生成的数据点也能基本根据高斯分布落在中心点附近