【人工智能基础】GAN与WGAN实验

server/2024/9/22 19:43:53/

一、GAN网络概述

GAN:生成对抗网络。GAN网络中存在两个网络:G(Generator,生成网络)和D(Discriminator,判别网络)。

Generator接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)

Discriminator功能是判别一张图片的真实。它的输入是一张图片x,输出D(x)代表x为真实图片的概率,如果为1就代表图片真实,而输出为0,就代表图片不真实。

在GAN网络的训练中,Generator的目标就是尽量生成真实的图片去欺骗Discriminator

Discriminator的目标就是尽量把Generator生成的图片和真实的图片分别开来

二、GAN实验环境准备

除了之前使用过的pytorch-nplnumpy以外,我们还需要安装visdom

pip install visdom

启动visdom

python -m visdom.server

visdom启动成功如下图,会占用8097端口,我们可以通过8097端口访问visdom

visdom启动.png

三、GAN网络实验

环境参数配置

python">import torch
from torch import nn,optim,autograd
import numpy as np
import visdom
import randomh_dim = 400
batchsz = 512
viz = visdom.Visdom()

生成网络定义

python">class Generator(nn.Module):def __init__(self):super(Generator,self).__init__()self.net = nn.Sequential(# input[b, 2]nn.Linear(2,h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 2)# output[b,2])def forward(self, z):output = self.net(z)return output

判别网络定义

python">class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.net = nn.Sequential(nn.Linear(2, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, h_dim),nn.ReLU(True),nn.Linear(h_dim, 1),nn.Sigmoid())def forward(self, x):output = self.net(x)return output.view(-1)

数据集生成函数

python">def data_generator():# 生成中心点scale = 2centers = [(1, 0),(-1, 0),(0, 1),(0, -1),(1. / np.sqrt(2), 1. / np.sqrt(2)),(1. / np.sqrt(2), -1. / np.sqrt(2)),(-1. / np.sqrt(2), 1. / np.sqrt(2)),(-1. / np.sqrt(2), -1. / np.sqrt(2))]centers = [(scale * x, scale * y) for x,y in centers] while True:dataset = []for i in range(batchsz):point = np.random.randn(2) * 0.02# 随机选取一个中心点center = random.choice(centers)# 把刚刚随机到的高斯分布点根据center进行移动point[0] += center[0]point[1] += center[1]dataset.append(point)dataset = np.array(dataset).astype(np.float32)dataset /= 1.414yield dataset

可视化函数

将图片生成到visdom

python">import matplotlib.pyplot as plt
def generate_image(D, G, xr, epoch):N_POINTS = 128RANGE = 3plt.clf()points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')points[:,:,0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]points[:,:,1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]points = points.reshape((-1,2))with torch.no_grad():points = torch.Tensor(points).cpu()disc_map = D(points).cpu().numpy()x = y = np.linspace(-RANGE,RANGE,N_POINTS)cs = plt.contour(x,y,disc_map.reshape((len(x), len(y))).transpose())plt.clabel(cs, inline=1,fontsize=10)with torch.no_grad():z = torch.randn(batchsz, 2).cpu()samples = G(z).cpu().numpy()plt.scatter(xr[:,0],xr[:,1],c='orange',marker='.')plt.scatter(samples[:,0], samples[:,1], c='green',marker='+')viz.matplot(plt, win='contour',opts=dict(title='p(x):%d'%epoch))

运行函数

python">def run():torch.manual_seed(23)np.random.seed(23)data_iter = data_generator()x = next(data_iter)# print(x.shape)# G = Generator().cuda()# D = Discriminator().cuda()# 无显卡环境device = torch.device("cpu")G = Generator().cpu()print(G)D = Discriminator().cpu()print(D)optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))"""gan核心部分"""for epoch in range(50000):# 训练判别网络for _ in range(5):# 真实数据训练xr = next(data_iter)xr = torch.from_numpy(xr).cpu()predr = D(xr)# 放大真实数据lossr = -predr.mean()# 虚假数据训练z = torch.randn(batchsz,2).cpu()xf = G(z).detach()predf = D(xf)# 缩小虚假数据lossf = predf.mean()loss_D = lossr + lossf# 梯度清零optim_D.zero_grad()# 向后传播loss_D.backward()optim_D.step()# 训练生成网络z = torch.randn(batchsz,2).cpu()xf = G(z)predf = D(xf)loss_G = -predf.mean()optim_G.zero_grad()loss_G.backward()optim_G.step()if epoch % 100 == 0:viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')print(loss_D.item(), loss_G.item())generate_image(D, G, xr, epoch)

执行(GAN的不稳定性)

python">run()

从结果中可以看到,判别网络的loss一直为0,而生成网络一直得不到更新,生成的数据点远离我们创建的中心点

gan运行.png

四、wgan实验

WGAN主要从损失函数的角度对GAN做了改进,对更新后的权重强制截断到一定范围内

增加一个梯度惩罚函数

python">def gradient_penalty(D,xr,xf):# [b,1]t = torch.rand(batchsz, 1).cpu()# 扩展为[b, 2]t = t.expand_as(xr)# 插值mid = t * xr + (1 - t) * xf# 设置需要的倒数信息mid.requires_grad_()pred = D(mid)grads = autograd.grad(outputs=pred, inputs=mid,grad_outputs=torch.ones_like(pred),create_graph=True,retain_graph=True,only_inputs=True)[0]gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()return gp

修改运行函数

python">def run():torch.manual_seed(23)np.random.seed(23)data_iter = data_generator()x = next(data_iter)# print(x.shape)# G = Generator().cuda()# D = Discriminator().cuda()# 无显卡环境device = torch.device("cpu")G = Generator().cpu()print(G)D = Discriminator().cpu()print(D)optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))"""gan核心部分"""for epoch in range(50000):# 训练判别网络for _ in range(5):# 真实数据训练xr = next(data_iter)xr = torch.from_numpy(xr).cpu()predr = D(xr)# 放大真实数据lossr = -predr.mean()# 虚假数据训练z = torch.randn(batchsz,2).cpu()xf = G(z).detach()predf = D(xf)# 缩小虚假数据lossf = predf.mean()# 梯度惩罚值gp = gradient_penalty(D,xr,xf.detach())loss_D = lossr + lossf + 0.2 * gp# 梯度清零optim_D.zero_grad()# 向后传播loss_D.backward()optim_D.step()# 训练生成网络z = torch.randn(batchsz,2).cpu()xf = G(z)predf = D(xf)loss_G = -predf.mean()optim_G.zero_grad()loss_G.backward()optim_G.step()if epoch % 100 == 0:viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')print(loss_D.item(), loss_G.item())generate_image(D, G, xr, epoch)

执行

python">run()

可以看到在wgan中,生成网络开始学习,生成的数据点也能基本根据高斯分布落在中心点附近

wgan运行.png


http://www.ppmy.cn/server/40439.html

相关文章

外星人笔记本-记一次电脑发热过热缘由

背景 笔记本进行过大修,电池鼓包,还好没炸,因此替换电池。发现内存(SSD)不足,又增加了内存。完成后使用还算正常。但是过一段时间后,系统自动更新几次(window10系统就是恶心&#x…

【消息队列】消息中间件介绍

目录 电商系统引发的思考实现支付业务时使用串行操作(同步)串行操作存在的问题根据上述的几个问题,在设计系统时可以明确要达到的目标 消息中间件【MQ(Message Queue)】使用场景1.应用解耦2.异步提速3.流量削峰举个栗子…

基于springboot实现的疫情网课管理系统

开发语言:Java 框架:springboot JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7(一定要5.7版本) 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Maven…

Python3 笔记:查看数据类型、数据类型转换

1、使用内置函数type(object)可以返回object的数据类型。&#xff1a; num1 5.5 print(type(num1)) # 运行结果&#xff1a;<class float> a python print(type(a)) # 运行结果&#xff1a;<class str> b [1,2,3] print(type(b)) # 运行结果&#xff1a;<cl…

5月8日爬楼梯+使用最小花费爬楼梯

70.爬楼梯 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢&#xff1f; 示例 1&#xff1a; 输入&#xff1a;n 2 输出&#xff1a;2 解释&#xff1a;有两种方法可以爬到楼顶。 1. 1 阶 1 阶 2. 2 阶 示…

細講《弟子規》41

問&#xff1a;第一位朋友問到&#xff0c;請問老師&#xff0c;學習《弟子規》有沒有其他相輔相成的學習教材可以幫我們更上層樓&#xff1f;其學習的順序次第為何&#xff1f; 答&#xff1a;《弟子規》可以配合《德育課本》來學習。《德育課本》好像就是這些古聖先賢把理論跟…

架构师:搭建Spring Security、OAuth2和JWT 的安全认证框架

1、简述 Spring Security 是 Spring 生态系统中的一个强大的安全框架,用于实现身份验证和授权。结合 OAuth2 和 JWT 技术,可以构建一个安全可靠的认证体系,本文将介绍如何在 Spring Boot 中配置并使用这三种技术实现安全认证,并分析它们的优点。 2、Spring Security Spri…

QT设计模式:装饰器模式

基本概念 装饰器模式&#xff08;Decorator Pattern&#xff09;是一种结构型设计模式&#xff0c;它允许向现有对象添加新功能&#xff0c;又不改变其结构。通过将对象放入包装器中&#xff0c;然后用装饰器对象包裹原始对象&#xff0c;以提供额外的功能。 装饰器模式需要实…