GAN(生成对抗网络)

news/2024/11/17 7:28:50/

简介:GAN生成对抗网络本质上是一种思想,其依靠神经网络能够拟合任意函数的能力,设计了一种架构来实现数据的生成。

原理:GAN的原理就是最小化生成器Generator的损失,但是在最小化损失的过程中加入了一个约束,这个约束就是使Generator生成的数据满足我们指定数据的分布,GAN的巧妙之处在于使用一个神经网络(鉴别器Discriminator)来自动判断生成的数据是否符合我们所需要的分布。

实现细节:

一:

        准备好我们想要让生成器生成的数据类型,比如MINIST手写数字集,包含1-10十个数字,一共60000张图片。生成器的目的就是学习这个数据集的分布。

二,

        定义一个生成器,用于判别一张图片是实际的还是生成器生成的,当生成器完美学习得到数据分布之后,鉴别器可能就分不清图片是生成器的还是实际的,这样的话生成器就能生成我们想要的图片了。

        生成器的训练过程为:实际数据输出结果1,生成数据输出结果为0,目的是学会区分真假数据,相当于提供一个约束,使生成数据符合指定分布。当鉴别生成器的数据分布时,只需要更新鉴别器的参数权重,不能够通过计算图将生成器的参数进行更新。

三,

        定义一个生成器,给定一个输入,他就能生成1-10里面的一个数字的图片。生成器的反向更新是根据鉴别器的损失来确定(被约束进行反向更新)。生成器的网络权重参数是单独的,反向更新时,只需要更新计算图当中属于生成器部分的参数。

下面给出生成1-0-1-0数据格式的代码:

# %%
import torch
import numpy
import torch.nn as nn
import matplotlib.pyplot as plt# %%
def gennerate1010():return torch.FloatTensor([numpy.random.uniform(0.9,1.1),numpy.random.uniform(0.,.1),numpy.random.uniform(0.9,1.1),numpy.random.uniform(0.0,.1)])# %%
def genneratexxxx():return torch.rand(4)# %%
class Discrimer(nn.Module):def __init__(self) -> None:father_obj = super(Discrimer,self)father_obj.__init__()self.create_model()self.counter = 0self.progress = []def create_model(self):self.model = nn.Sequential(nn.Linear(4,3),nn.Sigmoid(),nn.Linear(3,1),nn.Sigmoid(),           )self.loss_functon = nn.MSELoss()self.optimiser = torch.optim.SGD(self.parameters(),lr=0.01)def forward(self,x):return self.model(x)def train(self,x,targets):outputs = self.forward(x)loss = self.loss_functon(outputs,targets)self.counter += 1if self.counter%10 == 0:self.progress.append(loss.item())if self.counter%10000 == 0:print(self.counter)self.optimiser.zero_grad()loss.backward()self.optimiser.step()def plotprogress(self):plt.plot(self.progress,marker='*')plt.show()# %%
class Gennerater(nn.Module):def __init__(self) -> None:father_obj = super(Gennerater,self)father_obj.__init__()self.create_model()self.counter = 0self.progress = []def create_model(self):self.model = nn.Sequential(nn.Linear(1,3),nn.Sigmoid(),nn.Linear(3,4),nn.Sigmoid(),           )# 这个优化器只能优化生成器部分的参数self.optimiser = torch.optim.SGD(self.parameters(),lr=0.01)def forward(self,x):return self.model(x)def train(self,D,x,targets):g_outputs = self.forward(x)d_outputs = D.forward(g_outputs)# 使用鉴别器的loss函数,但是只更新生成器的参数,生成器的参数需要根据鉴别器的约束进行更新loss = D.loss_functon(d_outputs,targets)self.counter += 1if self.counter%10 == 0:self.progress.append(loss.item())if self.counter%10000 == 0:print(self.counter)self.optimiser.zero_grad()loss.backward()self.optimiser.step()def plotprogress(self):plt.plot(self.progress,marker='*')plt.show()# %%
D = Discrimer()# %%
G = Gennerater()# %%
for id in range(15000):# 喂入实际数据给鉴别器D.train(gennerate1010(),torch.FloatTensor([1.]))# 喂入生成的数据,使用detach从计算图脱离,用于更新鉴别器,而生成器得不到更新D.train(G.forward(torch.FloatTensor([0.5]).detach()),torch.FloatTensor([0.0]))G.train(D,torch.FloatTensor([0.5]),torch.FloatTensor([1.]))# %%
D.plotprogress()# %%
G.plotprogress()# %%
G.forward(torch.FloatTensor([0.5]))

参考:PyTorch生成对抗网络编程


http://www.ppmy.cn/news/1055704.html

相关文章

uniapp 使用 mui-player 插件播放 m3u8/flv 视频流

在UniApp中使用mui-player插件播放M3U8/FLV视频流,可以按照以下步骤进行操作: 1. 安装mui-player插件 :在UniApp项目根目录下,使用命令行工具执行以下命令安装mui-player插件: npm install mui-player --save2. 在需…

github出现的2FA的问题

github出现的2FA的问题 首先是收到的邮件提醒: 关键的信息是: 这提醒我们,我们宣布,我们要求用户贡献代码o n GitHub.com启用双因素身份验证 (2FA)。您收到此通知是因为您的帐户符合此标准,并且需要在00:00 (UTC) 2023年9月15日注册2FA。 …

C++|使用int数组实现一个栈类,为这个栈类增加getMaxValue方法

面试题: 使用int数组实现一个栈类,为这个栈类增加getMaxValue方法 做法: 在实现好的栈类里面,维护一个子栈,用来存储所有入栈时当过最大值的数字 比如栈:3 2 5 1 2 4 那么维护的子栈中存储的是&#xff1…

go-kafka

go kafka包 本文使用的是kafka-go 6.5k 这个包 其他包参考: 我们在细分市场中非常依赖GO和Kafka。不幸的是,在撰写本文时,Kafka的GO客户库的状态并不理想。可用选项是: 萨拉玛(Sarama) 10k,这…

精准高效农业作业,植保无人机显身手

中国作为农业大国,拥有约18亿亩的农田,每年都需要进行种子喷洒和农药施用等农业作业,对于普通农户来说,这是一项耗时耗力的工程,同时,人工喷洒农药极易造成农药慢性中毒,对农民的身体健康产生极…

FairyGUI编辑器的弹窗操作【插件】

之前在FairyGUI编辑器菜单扩展中,我使用了App.Alert("复制失败")来提示操作是否成功。这篇则会说一下我们可以使用的弹窗提示,以及做到类似资源发布成功时的“发布成功”飘窗。 打开APP的API脚本,可以看到有很多公开方法&#xff…

设计模式(11)观察者模式

一、概述: 1、定义:观察者模式定义了一种一对多的依赖关系,让多个观察者对象同时监听某一个主题对象。这个主题对象在状态发生变化时,会通知所有观察者对象,使它们能够自动更新自己。 2、结构图: public interface S…

设计模式二十二:策略模式(Strategy Pattern)

定义一系列算法,将每个算法封装成独立的对象,并使这些对象可互相替换。这使得在运行时可以动态地选择算法,而不必改变使用算法的客户端代码。策略模式的主要目标是将算法的定义与使用分离,使得客户端可以根据需要灵活地选择和切换…