基于Tensorflow的最基本GAN网络模型

news/2024/12/2 13:10:55/
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
#(1)创建输入管道
# 导入原始数据
(train_images, train_labels),(_, _) = tf.keras.datasets.mnist.load_data()
# 查看原始数据大小与数据格式
# 60000张图片,每一张图片都是28*28像素
# print(train_images.shape)
# dtype('uint8'),每一位的范围都是0-255的整数,由于图像的一个通道中rgb颜色值就是0-255不等,因此uint8就是图像的标准数字格式
# print(train_images.dtype)#(1.1)数据预处理
# 转换数据类型
train_images = train_images.reshape(train_images.shape[0], 28,28,1)
train_images = train_images.astype('float32')# 归一化0-255>>[-1,1]
train_images = (train_images - 127.5)/127.5#(1.2)确定训练时的BATCH_SIZE与BUFFER_SIZE
BATCH_SIZE = 256 # 每一个batch指一次训练,batch_size表示一次训练所需的数据个数。这里一次训练需要256张图片
BUFFER_SIZE = 60000 # 目前不知道buffer是干什么的#(1.3)将归一化后的图像转化为tf内置的一种数据形式
datasets = tf.data.Dataset.from_tensor_slices(train_images)#(1.4)将训练模型的数据集进行打乱的操作:shuffle
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
#(2)生成器模型
def Generator_Model():model = keras.Sequential() # 顺序模型# dense 全连接层# 输入:长度为100的随机数向量(自己定义)# 输出:长度为256的向量model.add(layers.Dense(256, input_shape = (100,), use_bias = False))model.add(layers.BatchNormalization()) # 归一化层model.add(layers.LeakyReLU()) # 激活层model.add(layers.Dense(512, use_bias = False))model.add(layers.BatchNormalization()) # 归一化层model.add(layers.LeakyReLU()) # 激活层model.add(layers.Dense(28*28*1, use_bias = False, activation = 'tanh'))model.add(layers.BatchNormalization()) # 归一化层model.add(layers.Reshape((28,28,1))) # 写为元组的形式return model
#(3)判别器模型
def Discriminator_Model():model = keras.Sequential()model.add(layers.Flatten()) # 将3维图像拉伸为一维图像model.add(layers.Dense(512, use_bias = False))model.add(layers.BatchNormalization()) # 归一化层model.add(layers.LeakyReLU()) # 激活层model.add(layers.Dense(256, use_bias = False))model.add(layers.BatchNormalization()) # 归一化层model.add(layers.LeakyReLU()) # 激活层model.add(layers.Dense(1)) # 输出1或者0return model
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits = True)#(4)判别器的损失函数:对于真是图片,判定为1;对于生成图片,判定为0
def discriminator_loss(real_out, fake_out):real_loss = cross_entropy(tf.ones_like(real_out),real_out)fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)return real_loss+fake_loss#(5)生成器损失函数:对于生成图片,判定为1
def generator_loss(fake_out):fake_loss = cross_entropy(tf.ones_like(fake_out),fake_out)return fake_loss
#(6)创建判别器和生成器的优化器,定义参数的学习速率
generator_opt = tf.keras.optimizers.Adam(1e-4)
discriminator_opt = tf.keras.optimizers.Adam(1e-4)
EPOCHS = 100
noise_dim = 100
num_exp_to_generate = 16
seed = tf.random.normal([num_exp_to_generate, noise_dim])# 实例化生成器与判别器
Generator = Generator_Model()
Discriminator = Discriminator_Model()
#(7)训练GAN网络
# 每一个batch
def train_step(images):noise = tf.random.normal([BATCH_SIZE, noise_dim])with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:real_output = Discriminator(images, training = True)gen_image = Generator(noise, training = True)fake_output = Discriminator(gen_image, training = True)gen_loss = generator_loss(fake_output)disc_loss = discriminator_loss(real_output, fake_output)#优化gradient_gen = gen_tape.gradient(gen_loss, Generator.trainable_variables)gradient_disc = disc_tape.gradient(disc_loss, Discriminator.trainable_variables)generator_opt.apply_gradients(zip(gradient_gen, Generator.trainable_variables))discriminator_opt.apply_gradients(zip(gradient_disc, Discriminator.trainable_variables))
# 可视化函数
def generator_plt_img(gen_model, test_noise):pre_images = gen_model(test_noise, training = False)fig = plt.figure(figsize=(4, 4))for i in range(pre_images.shape[0]):plt.subplot(4,4,i+1)plt.imshow((pre_images[i,:,:,0]+1)/2, cmap = 'gray')plt.axis('off')plt.show()
# 完整的训练模型的函数
def train(dataset, epochs):for epoch in range(epochs):for img_batch in dataset:train_step(img_batch)print('.',end='')generator_plt_img(Generator, seed)
# 训练模型
train(datasets, EPOCHS)

视频链接:https://www.bilibili.com/video/BV1f7411E7wU/?spm_id_from=333.999.0.0


http://www.ppmy.cn/news/40432.html

相关文章

PasteSpider之项目-服务-环境介绍

在PasteSpider中,项目和服务是重要的对象,只有理解什么是项目什么是服务后配置起来才不会稀里糊涂的! 项目 PasteSpider中的项目和我们平时说的项目意思一样,比如你要开发一个在线客服系统(项目),一个商城系统(项目),…

分子生物学 第三章 基因、基因组及基因组学

文章目录第三章 基因、基因组及基因组学第一节 基因1 基因认识的三个阶段2 基因的特征(1)跳跃基因(2)断裂基因3 基因的分类4 基因的结构5 基因的大小6 基因的数目第二节 基因组1 基因组的概念2 噬菌体基因组3 细菌基因组以大肠杆菌(原核生物的代表)为研究对象4 酵母基因组以酵母…

linux常问

查看当前进程 ps -l 列出与本次登录有关的进程信息; ps -aux 查询内存中进程信息; ps -aux | grep * 查询 *进程的详细信息; top 查看内存中进程的动态信息; kill -9 pid 杀死进程。

【OJ每日一练】1138 - 身份证

文章目录 一、题目🔸题目描述🔸输入输出🔸样例1二、代码参考作者:KJ.JK🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🍂个人博客首页: KJ.JK 💖系列专栏:OJ每日一练 一、题目 🔸题目描述 如果让你设计个程序,用什么变量保存身份证号…

三次握手详解,全网最全

一、TCP 报文段简介 在介绍三次握手和四次挥手之前,先来简单认识一下 TCP 报文段的结构 TCP报文段也分为首部和数据两部分,首部默认情况下一般是20字节长度,但在一些需求情况下,会使用“可选字段”,这时,首…

如何保证接口安全,做到防篡改防重放?

对于互联网来说,只要你系统的接口暴露在外网,就避免不了接口安全问题。如果你的接口在外网裸奔,只要让黑客知道接口的地址和参数就可以调用,那简直就是灾难。 举个例子:你的网站用户注册的时候,需要填写手…

如何自学JAVA

一:Java基础知识 俗话说的好“千里之行,始于足下”,学习也是一样的从小的基础的知识点开始慢慢积累,掌握Java语言的基础知识,如面向对象、数据结构与算法、异常处理、IO框架、多线程、网络编程、设计模式、Java新特性…

odps多行合并为一行

在ODPS中,多行合并为一行可以通过使用ODPS SQL语句中的聚合函数来实现。 假设我们有一个表格,其中包含多行数据: name score Tom 20 Jack 20 Lucy 30 将上述表格中的相同分数的人合并为一行,并用逗号分隔每个值:…