DCGAN

news/2024/11/9 0:46:35/

转自:https://blog.csdn.net/liuxiao214/article/details/74502975

首先是各种参考博客、链接等,表示感谢。

1、参考博客1:地址

——以下,开始正文。

2017/12/12 更新 解决训练不收敛的问题。

更新在最后面部分。


1、DCGAN的简单总结

稳定的深度卷积GAN 架构指南:

  • 所有的pooling层使用步幅卷积(判别网络)和微步幅度卷积(生成网络)进行替换。

  • 在生成网络和判别网络上使用批处理规范化。

  • 对于更深的架构移除全连接隐藏层。

  • 在生成网络的所有层上使用RelU激活函数,除了输出层使用Tanh激活函数。

  • 在判别网络的所有层上使用LeakyReLU激活函数。

这里写图片描述

图: LSUN 场景模型中使用的DCGAN生成网络。一个100维度的均匀分布z映射到一个有很多特征映射的小空间范围卷积。一连串的四个微步幅卷积(在最近的一些论文中它们错误地称为去卷积),将高层表征转换为64*64像素的图像。明显,没有使用全连接层和池化层。

2、DCGAN的实现

DCGAN原文作者是生成了卧室图片,这里参照前面写的参考链接中,来生成动漫人物头像。生成效果如下:

暂且先不放,因为还没开始做。

2.1 搜集原始数据集

首先是需要获取大量的动漫图像,这个可以利用爬虫爬取一个动漫网站:konachan.net的图片。爬虫的代码如下所示:

import requests  # http lib
from bs4 import BeautifulSoup  # climb lib
import os # operation system
import traceback # trace deviancedef download(url,filename):if os.path.exists(filename):print('file exists!')returntry:r = requests.get(url,stream=True,timeout=60)r.raise_for_status()with open(filename,'wb') as f:for chunk in r.iter_content(chunk_size=1024):if chunk: # filter out keep-alove new chunksf.write(chunk)f.flush()return filenameexcept KeyboardInterrupt:if os.path.exists(filename):os.remove(filename)return KeyboardInterruptexcept Exception:traceback.print_exc()if os.path.exists(filename):os.remove(filename)if os.path.exists('imgs') is False:os.makedirs('imgs')start = 1
end = 8000
for i in range(start, end+1):url = 'http://konachan.net/post?page=%d&tags=' % ihtml = requests.get(url).text # gain the web's informationsoup =  BeautifulSoup(html,'html.parser') # doc's string and jie xi qifor img in soup.find_all('img',class_="preview"):# 遍历所有preview类,找到img标签target_url = 'http:' + img['src']filename = os.path.join('imgs',target_url.split('/')[-1])download(target_url,filename)print('%d / %d' % (i,end))    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

目标是获取1万张图像,因为自己是在CPU上跑的,而且内存太小,太多图像根本训练不起来,就先少一点训练,看看效果。

截取部分图像如下所示:

这里写图片描述

现在已经有了基本的图像了,但我们的目标是生成动漫头像,不需要整张图像,而且其他的信息会干扰到训练,所以需要进行人脸检测截取人脸图像。

2.2 人脸检测截取人脸

通过基于opencv的人脸检测分类器,参考于lbpcascade_animeface。

首先,要使用这个分类器要先进行下载:

wget https://raw.githubusercontent.com/nagadomi/lbpcascade_animeface/master/lbpcascade_animeface.xml
  • 1

下载完成后,运行以下代码对图像进行人脸截取。

import cv2
import sys
import os.path
from glob import globdef detect(filename,cascade_file="lbpcascade_animeface.xml"):if not os.path.isfile(cascade_file):raise RuntimeError("%s: not found" % cascade_file)cascade = cv2.CascadeClassifier(cascade_file)image = cv2.imread(filename)gray = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)gray = cv2.equalizeHist(gray)faces = cascade.detectMultiScale(gray,# detector optionsscaleFactor = 1.1,minNeighbors = 5,minSize = (48,48))for i,(x,y,w,h) in enumerate(faces):face = image[y: y+h, x:x+w, :]face = cv2.resize(face,(96,96))save_filename = '%s.jpg' % (os.path.basename(filename).split('.')[0])cv2.imwrite("faces/"+sace_filename,face)if __name__ == '__main__':if os.path.exists('faces') is False:os.makedirs('faces')file_list = glob('imgs/*.jpg')for filename in file_list:detect(filename)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

处理后的图像如下所示:

这里写图片描述

2.3 源代码解析

参照于DCGAN-tensorflow。

总共获取11053张图像,人脸检测后得到3533张。

一共有4个文件,分别是main.py、model.py、ops.py、utils.py。

2.3.1 mian.py

原代码(98行):

import os
import scipy.misc # 
import numpy as npfrom model import DCGAN
from utils import pp, visualize, to_json, show_all_variablesimport tensorflow as tfflags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
FLAGS = flags.FLAGSdef main(_):pp.pprint(flags.FLAGS.__flags)if FLAGS.input_width is None:FLAGS.input_width = FLAGS.input_heightif FLAGS.output_width is None:FLAGS.output_width = FLAGS.output_heightif not os.path.exists(FLAGS.checkpoint_dir):os.makedirs(FLAGS.checkpoint_dir)if not os.path.exists(FLAGS.sample_dir):os.makedirs(FLAGS.sample_dir)#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)run_config = tf.ConfigProto()run_config.gpu_options.allow_growth=Truewith tf.Session(config=run_config) as sess:if FLAGS.dataset == 'mnist':dcgan = DCGAN(sess,input_width=FLAGS.input_width,input_height=FLAGS.input_height,output_width=FLAGS.output_width,output_height=FLAGS.output_height,batch_size=FLAGS.batch_size,sample_num=FLAGS.batch_size,y_dim=10,dataset_name=FLAGS.dataset,input_fname_pattern=FLAGS.input_fname_pattern,crop=FLAGS.crop,checkpoint_dir=FLAGS.checkpoint_dir,sample_dir=FLAGS.sample_dir)else:dcgan = DCGAN(sess,input_width=FLAGS.input_width,input_height=FLAGS.input_height,output_width=FLAGS.output_width,output_height=FLAGS.output_height,batch_size=FLAGS.batch_size,sample_num=FLAGS.batch_size,dataset_name=FLAGS.dataset,input_fname_pattern=FLAGS.input_fname_pattern,crop=FLAGS.crop,checkpoint_dir=FLAGS.checkpoint_dir,sample_dir=FLAGS.sample_dir)show_all_variables()if FLAGS.train:dcgan.train(FLAGS)else:if not dcgan.load(FLAGS.checkpoint_dir)[0]:raise Exception("[!] Train a model first, then run test mode")# to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],#                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],#                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],#                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],#                 [dcgan.h4_w, dcgan.h4_b, None])# Below is codes for visualizationOPTION = 1visualize(sess, dcgan, FLAGS, OPTION)if __name__ == '__main__':tf.app.run()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98

该文件调用了model.py文件和utils.py文件。

step0:执行main函数之前首先进行flags的解析,TensorFlow底层使用了python-gflags项目,然后封装成tf.app.flags接口,也就是说TensorFlow通过设置flags来传递tf.app.run()所需要的参数,我们可以直接在程序运行前初始化flags,也可以在运行程序的时候设置命令行参数来达到传参的目的。

这里主要设置了:

  • epoch:迭代次数
  • learning_rate:学习速率,默认是0.002
  • beta1
  • train_size
  • batch_size:每次迭代的图像数量
  • input_height:需要指定输入图像的高
  • input_width:需要指定输入图像的宽
  • output_height:需要指定输出图像的高
  • output_width:需要指定输出图像的宽
  • dataset:需要指定处理哪个数据集
  • input_fname_pattern
  • checkpoint_dir
  • sample_dir
  • train:True for training, False for testing
  • crop:True for training, False for testing
  • visualize

step1:首先是打印参数数据,然后判断输入图像的输出图像的宽是否指定,如果没有指定,则等于其图像的高。

step2:然后判断checkpoint和sample的文件是否存在,不存在则创建。

step3:然后是设置session参数。tf.ConfigProto一般用在创建session的时候,用来对session进行参数配置,详细内容可见这篇博客。

#tf.ConfigProto()的参数:
log_device_placement=True : 是否打印设备分配日志
allow_soft_placement=True : 如果你指定的设备不存在,允许TF自动分配设备
tf.ConfigProto(log_device_placement=True,allow_soft_placement=True)控制GPU资源使用率:
#allow growth
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config, ...)
# 使用allow_growth option,刚一开始分配少量的GPU容量,然后按需慢慢的增加,由于不会释放内存,所以会导致碎片
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

step4:运行session,首先判断处理的是哪个数据集,然后对应使用不同参数的DCGAN类,这个类会在model.py文件中定义。

step5:show所有与训练相关的变量。

step6:判断是训练还是测试,如果是训练,则进行训练;如果不是,判断是否有训练好的model,然后进行测试,如果没有先训练,则会提示“[!] Train a model first, then run test mode”。

step7:最后进行可视化,visualize(sess, dcgan, FLAGS, OPTION)。

main.py主要是调用前面定义好的模型、图像处理方法,来进行训练测试,程序的入口。

2.3.2 utils.py

源代码(250行):

"""
Some codes from https://github.com/Newmu/dcgan_code
"""
from __future__ import division
import math
import json
import random
import pprint # print data_struct
import scipy.misc
import numpy as np
from time import gmtime, strftime
from six.moves import xrangeimport tensorflow as tf
import tensorflow.contrib.slim as slimpp = pprint.PrettyPrinter()get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])def show_all_variables():model_vars = tf.trainable_variables()slim.model_analyzer.analyze_vars(model_vars, print_info=True)def get_image(image_path, input_height, input_width,resize_height=64, resize_width=64,crop=True, grayscale=False):image = imread(image_path, grayscale)return transform(image, input_height, input_width,resize_height, resize_width, crop)def save_images(images, size, image_path):return imsave(inverse_transform(images), size, image_path)def imread(path, grayscale = False):if (grayscale):return scipy.misc.imread(path, flatten = True).astype(np.float)else:return scipy.misc.imread(path).astype(np.float)def merge_images(images, size):return inverse_transform(images)def merge(images, size):h, w = images.shape[1], images.shape[2]if (images.shape[3] in (3,4)):c = images.shape[3]img = np.zeros((h * size[0], w * size[1], c))for idx, image in enumerate(images):i = idx % size[1]j = idx // size[1]img[j * h:j * h + h, i * w:i * w + w, :] = imagereturn imgelif images.shape[3]==1:img = np.zeros((h * size[0], w * size[1]))for idx, image in enumerate(images):i = idx % size[1]j = idx // size[1]img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]return imgelse:raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')def imsave(images, size, path):image = np.squeeze(merge(images, size))return scipy.misc.imsave(path, image)def center_crop(x, crop_h, crop_w,resize_h=64, resize_w=64):if crop_w is None:crop_w = crop_hh, w = x.shape[:2]j = int(round((h - crop_h)/2.))i = int(round((w - crop_w)/2.))return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w])def transform(image, input_height, input_width, resize_height=64, resize_width=64, crop=True):if crop:cropped_image = center_crop(image, input_height, input_width, resize_height, resize_width)else:cropped_image = scipy.misc.imresize(image, [resize_height, resize_width])return np.array(cropped_image)/127.5 - 1.def inverse_transform(images):return (images+1.)/2.def to_json(output_path, *layers):with open(output_path, "w") as layer_f:lines = ""for w, b, bn in layers:layer_idx = w.name.split('/')[0].split('h')[1]B = b.eval()if "lin/" in w.name:W = w.eval()depth = W.shape[1]else:W = np.rollaxis(w.eval(), 2, 0)depth = W.shape[0]biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]}if bn != None:gamma = bn.gamma.eval()beta = bn.beta.eval()gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]}beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]}else:gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []}beta = {"sy": 1, "sx": 1, "depth": 0, "w": []}if "lin/" in w.name:fs = []for w in W.T:fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]})lines += """var layer_%s = {"layer_type": "fc", "sy": 1, "sx": 1, "out_sx": 1, "out_sy": 1,"stride": 1, "pad": 0,"out_depth": %s, "in_depth": %s,"biases": %s,"gamma": %s,"beta": %s,"filters": %s};""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs)else:fs = []for w_ in W:fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]})lines += """var layer_%s = {"layer_type": "deconv", "sy": 5, "sx": 5,"out_sx": %s, "out_sy": %s,"stride": 2, "pad": 1,"out_depth": %s, "in_depth": %s,"biases": %s,"gamma": %s,"beta": %s,"filters": %s};""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2),W.shape[0], W.shape[3], biases, gamma, beta, fs)layer_f.write(" ".join(lines.replace("'","").split()))def make_gif(images, fname, duration=2, true_image=False):import moviepy.editor as mpydef make_frame(t):try:x = images[int(len(images)/duration*t)]except:x = images[-1]if true_image:return x.astype(np.uint8)else:return ((x+1)/2*255).astype(np.uint8)clip = mpy.VideoClip(make_frame, duration=duration)clip.write_gif(fname, fps = len(images) / duration)def visualize(sess, dcgan, config, option):image_frame_dim = int(math.ceil(config.batch_size**.5))if option == 0:z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime()))elif option == 1:values = np.arange(0, 1, 1./config.batch_size)for idx in xrange(100):print(" [*] %d" % idx)z_sample = np.zeros([config.batch_size, dcgan.z_dim])for kdx, z in enumerate(z_sample):z[idx] = values[kdx]if config.dataset == "mnist":y = np.random.choice(10, config.batch_size)y_one_hot = np.zeros((config.batch_size, 10))y_one_hot[np.arange(config.batch_size), y] = 1samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})else:samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_arange_%s.png' % (idx))elif option == 2:values = np.arange(0, 1, 1./config.batch_size)for idx in [random.randint(0, 99) for _ in xrange(100)]:print(" [*] %d" % idx)z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))z_sample = np.tile(z, (config.batch_size, 1))#z_sample = np.zeros([config.batch_size, dcgan.z_dim])for kdx, z in enumerate(z_sample):z[idx] = values[kdx]if config.dataset == "mnist":y = np.random.choice(10, config.batch_size)y_one_hot = np.zeros((config.batch_size, 10))y_one_hot[np.arange(config.batch_size), y] = 1samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot})else:samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})try:make_gif(samples, './samples/test_gif_%s.gif' % (idx))except:save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime()))elif option == 3:values = np.arange(0, 1, 1./config.batch_size)for idx in xrange(100):print(" [*] %d" % idx)z_sample = np.zeros([config.batch_size, dcgan.z_dim])for kdx, z in enumerate(z_sample):z[idx] = values[kdx]samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})make_gif(samples, './samples/test_gif_%s.gif' % (idx))elif option == 4:image_set = []values = np.arange(0, 1, 1./config.batch_size)for idx in xrange(100):print(" [*] %d" % idx)z_sample = np.zeros([config.batch_size, dcgan.z_dim])for kdx, z in enumerate(z_sample): z[idx] = values[kdx]image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \for idx in range(64) + range(63, -1, -1)]make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)def image_manifold_size(num_images):manifold_h = int(np.floor(np.sqrt(num_images)))manifold_w = int(np.ceil(np.sqrt(num_images)))assert manifold_h * manifold_w == num_imagesreturn manifold_h, manifold_w
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251

这份代码主要是定义了各种对图像处理的函数,相当于其他3个文件的头文件。

step0:首先定义了一个pp = pprint.PrettyPrinter(),以方便打印数据结构信息,详细信息可见这篇博客。

step1:定义了get_stddev函数,是三个参数乘积后开平方的倒数,应该是为了随机化用。

step2:定义show_all_variables()函数。首先,tf.trainable_variables返回的是需要训练的变量列表;然后用tensorflow.contrib.slim中的model_analyzer.analyze_vars打印出所有与训练相关的变量信息。用法如下:

-代码:

import tensorflow as tf
import tensorflow.contrib.slim as slimx1=tf.Variable(tf.constant(1,shape=[1],dtype=tf.float32),name='x11')
x2=tf.Variable(tf.constant(2,shape=[1],dtype=tf.float32),name='x22')
m=tf.train.ExponentialMovingAverage(0.99,5)
v=tf.trainable_variables()
for i in v:print 233print iprint 23333333   
slim.model_analyzer.analyze_vars(v,print_info=True)
print 23333333
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

-结果截图如下:

这里写图片描述


注:从step3-step11,都是在定义一些图像处理的函数,它们之间相互调用。

step3:定义get_image(image_path,input_height,input_width,resize_height=64, resize_width=64,crop=True, grayscale=False)函数。首先根据图像路径参数读取路径,根据灰度化参数选择是否进行灰度化。然后对图像参照输入的参数进行裁剪。

step4:定义save_images(images,size,image_path)函数。调用imsave(inverse_transform(images), size, image_path)函数并返回新图像。

step5:定义imread(path, grayscale = False)函数。调用cipy.misc.imread()函数,判断grayscale参数是否进行范围灰度化,并进行类型转换为np.float。

step6:定义merge_images(images, size)函数。调用inverse_transform(images)函数,并返回新图像。

step7:定义merge(images, size)函数。首先获取image的高和宽。然后判断image是RGB图还是灰度图,以分别进行不同的处理。如果通道数是3或4,则对每一批次(如,batch_size=64)的所有图像,用0初始化一张原始图像放大8*8的图像,然后循环,依次将所有图像填入大图像,并且返回这张大图像。如果通道数是1,也是一样,只不过填入图像的时候只填一个通道的信息。如果不是上述两种情况,则抛出错误提示。

step8:定义imsave(images, size, path)函数。首先将merge()函数返回的图像,用 np.squeeze()函数移除长度为1的轴。然后利用scipy.misc.imsave()函数将新图像保存到指定路径中。

step9:定义center_crop(x, crop_h, crop_w,resize_h=64, resize_w=64)函数。对图像的H和W与crop的H和W相减,得到取整的值,根据这个值作为下标依据来scipy.misc.resize图像。

step10:定义transform(image, input_height, input_width,resize_height=64, resize_width=64, crop=True)函数。对输入的图像进行裁剪,如果crop为true,则使用center_crop()函数,对图像的H和W与crop的H和W相减,得到取整的值,根据这个值作为下标依据来scipy.misc.resize图像;否则不对图像进行其他操作,直接scipy.misc.resize为64*64大小的图像。最后返回图像。

step11:定义inverse_transform(images)函数。对图像进行翻转后返回新图像。

总结下来,这几个函数相互调用,主要实现了3个图像操作功能:获取图像get_image(),负责读取图像,返回图像裁剪后的新图像;保存图像save_images(),负责将一个batch中所有图像保存为一张大图像并返回;图像翻转merge_images(),负责不知道怎么得翻转的,返回新图像。它们之间的相互关系如下图所示。

这里写图片描述

step12:定义to_json(output_path, *layers)函数。应该是获取每一层的权值、偏置值什么的,但貌似代码中没有用到这个函数,所以先不管,后面用到再说。

step13:定义make_gif(images, fname, duration=2, true_image=False)函数。利用moviepy.editor模块来制作动图,为了可视化用的。函数又定义了一个函数make_frame(t),首先根据图像集的长度和持续的时间做一个除法,然后返回每帧图像。最后视频修剪并制作成GIF动画。

step14:定义visualize(sess, dcgan, config, option)函数。分为0、1、2、3、4种option。如果option=0,则之间显示生产的样本‘如果option=1,根据不同数据集不一样的处理,并利用前面的save_images()函数将sample保存下来;等等。本次在main.py中选用option=1。

step15:定义image_manifold_size(num_images)函数。首先获取图像数量的开平方后向下取整的h和向上取整的w,然后设置一个assert断言,如果h*w与图像数量相等,则返回h和w,否则断言错误提示。

这就是全部utils.py全部内容,主要负责图像的一些基本操作,获取图像、保存图像、图像翻转,和利用moviepy模块可视化训练过程。

2.3.3 ops.py

源代码(105行):

import math
import numpy as np 
import tensorflow as tffrom tensorflow.python.framework import opsfrom utils import *try:image_summary = tf.image_summaryscalar_summary = tf.scalar_summaryhistogram_summary = tf.histogram_summarymerge_summary = tf.merge_summarySummaryWriter = tf.train.SummaryWriter
except:image_summary = tf.summary.imagescalar_summary = tf.summary.scalarhistogram_summary = tf.summary.histogrammerge_summary = tf.summary.mergeSummaryWriter = tf.summary.FileWriterif "concat_v2" in dir(tf):def concat(tensors, axis, *args, **kwargs):return tf.concat_v2(tensors, axis, *args, **kwargs)
else:def concat(tensors, axis, *args, **kwargs):return tf.concat(tensors, axis, *args, **kwargs)class batch_norm(object):def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):with tf.variable_scope(name):self.epsilon  = epsilonself.momentum = momentumself.name = namedef __call__(self, x, train=True):return tf.contrib.layers.batch_norm(x,decay=self.momentum, updates_collections=None,epsilon=self.epsilon,scale=True,is_training=train,scope=self.name)def conv_cond_concat(x, y):"""Concatenate conditioning vector on feature map axis."""x_shapes = x.get_shape()y_shapes = y.get_shape()return concat([x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name="conv2d"):with tf.variable_scope(name):w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],initializer=tf.truncated_normal_initializer(stddev=stddev))conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())return convdef deconv2d(input_, output_shape,k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name="deconv2d", with_w=False):with tf.variable_scope(name):# filter : [height, width, output_channels, in_channels]w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],initializer=tf.random_normal_initializer(stddev=stddev))try:deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,strides=[1, d_h, d_w, 1])# Support for verisons of TensorFlow before 0.7.0except AttributeError:deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,strides=[1, d_h, d_w, 1])biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())if with_w:return deconv, w, biaseselse:return deconvdef lrelu(x, leak=0.2, name="lrelu"):return tf.maximum(x, leak*x)def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):shape = input_.get_shape().as_list()with tf.variable_scope(scope or "Linear"):matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,tf.random_normal_initializer(stddev=stddev))bias = tf.get_variable("bias", [output_size],initializer=tf.constant_initializer(bias_start))if with_w:return tf.matmul(input_, matrix) + bias, matrix, biaselse:return tf.matmul(input_, matrix) + bias
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105

该文件调用了utils.py文件。

step0:首先导入tensorflow.python.framework模块,包含了tensorflow中图、张量等的定义操作。

step1:然后是一个try…except部分,定义了一堆变量:image_summary 、scalar_summary、histogram_summary、merge_summary、SummaryWriter,都是从相应的tensorflow中获取的。如果可是直接获取,则获取,否则从tf.summary中获取。

step2:用来连接多个tensor。利用dir(tf)判断”concat_v2”是否在里面,如果在的话,定义一个concat(tensors, axis, *args, **kwargs)函数,并返回tf.concat_v2(tensors, axis, *args, **kwargs);否则也定义concat(tensors, axis, *args, **kwargs)函数,只不过返回的是tf.concat(tensors, axis, *args, **kwargs)。其中,tf.concat使用如下:

t1=tf.constant([[1,2,3],[4,5,6]])
t2=tf.constant([[7,8,9],[10,11,12]])
t3=tf.concat([t1,t2],0)
t4=tf.concat([t1,t2],1)
print t1
print t2
print t3
print t4
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

这里写图片描述

step3:定义一个batch_norm类,包含两个函数initcall函数。首先在init(self, epsilon=1e-5, momentum = 0.9, name=”batch_norm”)函数中,定义一个name参数名字的变量,初始化self变量epsilon、momentum 、name。在call(self, x, train=True)函数中,利用tf.contrib.layers.batch_norm函数批处理规范化。

step4:定义conv_cond_concat(x,y)函数。连接x,y与Int32型的[x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]]维度的张量乘积。

step5:定义conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2,d_w=2, stddev=0.02,name=”conv2d”)函数。卷积函数:获取随机正态分布权值、实现卷积、获取初始偏置值,获取添加偏置值后的卷积变量并返回。

step6:定义deconv2d(input_, output_shape,k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,name=”deconv2d”, with_w=False):函数。解卷积函数:获取随机正态分布权值、解卷积,获取初始偏置值,获取添加偏置值后的卷积变量,判断with_w是否为真,真则返回解卷积、权值、偏置值,否则返回解卷积。

step7:定义lrelu(x, leak=0.2, name=”lrelu”)函数。定义一个lrelu激励函数。

step8:定义linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False)函数。进行线性运算,获取一个随机正态分布矩阵,获取初始偏置值,如果with_w为真,则返回xw+b,权值w和偏置值b;否则返回xw+b。

这个文件主要定义了一些变量连接的函数、批处理规范化的函数、卷积函数、解卷积函数、激励函数、线性运算函数。

2.3.4 model.py

源代码(530行):

from __future__ import division
import os
import time
import math
from glob import glob  # file path search
import tensorflow as tf
import numpy as np
from six.moves import xrangefrom ops import *
from utils import *def conv_out_size_same(size, stride):return int(math.ceil(float(size) / float(stride)))class DCGAN(object):def __init__(self, sess, input_height=108, input_width=108, crop=True,batch_size=64, sample_num = 64, output_height=64, output_width=64,y_dim=None, z_dim=100, gf_dim=64, df_dim=64,gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None):"""Args:sess: TensorFlow sessionbatch_size: The size of batch. Should be specified before training.y_dim: (optional) Dimension of dim for y. [None]z_dim: (optional) Dimension of dim for Z. [100]gf_dim: (optional) Dimension of gen filters in first conv layer. [64]df_dim: (optional) Dimension of discrim filters in first conv layer. [64]gfc_dim: (optional) Dimension of gen units for for fully connected layer. [1024]dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]"""self.sess = sessself.crop = cropself.batch_size = batch_sizeself.sample_num = sample_numself.input_height = input_heightself.input_width = input_widthself.output_height = output_heightself.output_width = output_widthself.y_dim = y_dimself.z_dim = z_dimself.gf_dim = gf_dimself.df_dim = df_dimself.gfc_dim = gfc_dimself.dfc_dim = dfc_dim# batch normalization : deals with poor initialization helps gradient flowself.d_bn1 = batch_norm(name='d_bn1')self.d_bn2 = batch_norm(name='d_bn2')if not self.y_dim:self.d_bn3 = batch_norm(name='d_bn3')self.g_bn0 = batch_norm(name='g_bn0')self.g_bn1 = batch_norm(name='g_bn1')self.g_bn2 = batch_norm(name='g_bn2')if not self.y_dim:self.g_bn3 = batch_norm(name='g_bn3')self.dataset_name = dataset_nameself.input_fname_pattern = input_fname_patternself.checkpoint_dir = checkpoint_dirif self.dataset_name == 'mnist':self.data_X, self.data_y = self.load_mnist()self.c_dim = self.data_X[0].shape[-1]else:self.data = glob(os.path.join("./data", self.dataset_name, self.input_fname_pattern))imreadImg = imread(self.data[0]);if len(imreadImg.shape) >= 3: #check if image is a non-grayscale image by checking channel numberself.c_dim = imread(self.data[0]).shape[-1]else:self.c_dim = 1self.grayscale = (self.c_dim == 1)self.build_model()def build_model(self):if self.y_dim:self.y= tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y')if self.crop:image_dims = [self.output_height, self.output_width, self.c_dim]else:image_dims = [self.input_height, self.input_width, self.c_dim]self.inputs = tf.placeholder(tf.float32, [self.batch_size] + image_dims, name='real_images')inputs = self.inputsself.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')self.z_sum = histogram_summary("z", self.z)if self.y_dim:self.G = self.generator(self.z, self.y)self.D, self.D_logits = \self.discriminator(inputs, self.y, reuse=False)self.sampler = self.sampler(self.z, self.y)self.D_, self.D_logits_ = \self.discriminator(self.G, self.y, reuse=True)else:self.G = self.generator(self.z)self.D, self.D_logits = self.discriminator(inputs)self.sampler = self.sampler(self.z)self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)self.d_sum = histogram_summary("d", self.D)self.d__sum = histogram_summary("d_", self.D_)self.G_sum = image_summary("G", self.G)def sigmoid_cross_entropy_with_logits(x, y):try:return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)except:return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y)self.d_loss_real = tf.reduce_mean(sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D)))self.d_loss_fake = tf.reduce_mean(sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_)))self.g_loss = tf.reduce_mean(sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_)))self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real)self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake)self.d_loss = self.d_loss_real + self.d_loss_fakeself.g_loss_sum = scalar_summary("g_loss", self.g_loss)self.d_loss_sum = scalar_summary("d_loss", self.d_loss)t_vars = tf.trainable_variables()self.d_vars = [var for var in t_vars if 'd_' in var.name]self.g_vars = [var for var in t_vars if 'g_' in var.name]self.saver = tf.train.Saver()def train(self, config):d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \.minimize(self.d_loss, var_list=self.d_vars)g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \.minimize(self.g_loss, var_list=self.g_vars)try:tf.global_variables_initializer().run()except:tf.initialize_all_variables().run()self.g_sum = merge_summary([self.z_sum, self.d__sum,self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])self.d_sum = merge_summary([self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])self.writer = SummaryWriter("./logs", self.sess.graph)sample_z = np.random.uniform(-1, 1, size=(self.sample_num , self.z_dim))if config.dataset == 'mnist':sample_inputs = self.data_X[0:self.sample_num]sample_labels = self.data_y[0:self.sample_num]else:sample_files = self.data[0:self.sample_num]sample = [get_image(sample_file,input_height=self.input_height,input_width=self.input_width,resize_height=self.output_height,resize_width=self.output_width,crop=self.crop,grayscale=self.grayscale) for sample_file in sample_files]if (self.grayscale):sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None]else:sample_inputs = np.array(sample).astype(np.float32)counter = 1start_time = time.time()could_load, checkpoint_counter = self.load(self.checkpoint_dir)if could_load:counter = checkpoint_counterprint(" [*] Load SUCCESS")else:print(" [!] Load failed...")for epoch in xrange(config.epoch):if config.dataset == 'mnist':batch_idxs = min(len(self.data_X), config.train_size) // config.batch_sizeelse:      self.data = glob(os.path.join("./data", config.dataset, self.input_fname_pattern))batch_idxs = min(len(self.data), config.train_size) // config.batch_sizefor idx in xrange(0, batch_idxs):if config.dataset == 'mnist':batch_images = self.data_X[idx*config.batch_size:(idx+1)*config.batch_size]batch_labels = self.data_y[idx*config.batch_size:(idx+1)*config.batch_size]else:batch_files = self.data[idx*config.batch_size:(idx+1)*config.batch_size]batch = [get_image(batch_file,input_height=self.input_height,input_width=self.input_width,resize_height=self.output_height,resize_width=self.output_width,crop=self.crop,grayscale=self.grayscale) for batch_file in batch_files]if self.grayscale:batch_images = np.array(batch).astype(np.float32)[:, :, :, None]else:batch_images = np.array(batch).astype(np.float32)batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \.astype(np.float32)if config.dataset == 'mnist':# Update D network_, summary_str = self.sess.run([d_optim, self.d_sum],feed_dict={ self.inputs: batch_images,self.z: batch_z,self.y:batch_labels,})self.writer.add_summary(summary_str, counter)# Update G network_, summary_str = self.sess.run([g_optim, self.g_sum],feed_dict={self.z: batch_z, self.y:batch_labels,})self.writer.add_summary(summary_str, counter)# Run g_optim twice to make sure that d_loss does not go to zero (different from paper)_, summary_str = self.sess.run([g_optim, self.g_sum],feed_dict={ self.z: batch_z, self.y:batch_labels })self.writer.add_summary(summary_str, counter)errD_fake = self.d_loss_fake.eval({self.z: batch_z, self.y:batch_labels})errD_real = self.d_loss_real.eval({self.inputs: batch_images,self.y:batch_labels})errG = self.g_loss.eval({self.z: batch_z,self.y: batch_labels})else:# Update D network_, summary_str = self.sess.run([d_optim, self.d_sum],feed_dict={ self.inputs: batch_images, self.z: batch_z })self.writer.add_summary(summary_str, counter)# Update G network_, summary_str = self.sess.run([g_optim, self.g_sum],feed_dict={ self.z: batch_z })self.writer.add_summary(summary_str, counter)# Run g_optim twice to make sure that d_loss does not go to zero (different from paper)_, summary_str = self.sess.run([g_optim, self.g_sum],feed_dict={ self.z: batch_z })self.writer.add_summary(summary_str, counter)errD_fake = self.d_loss_fake.eval({ self.z: batch_z })errD_real = self.d_loss_real.eval({ self.inputs: batch_images })errG = self.g_loss.eval({self.z: batch_z})counter += 1print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \% (epoch, idx, batch_idxs,time.time() - start_time, errD_fake+errD_real, errG))if np.mod(counter, 100) == 1:if config.dataset == 'mnist':samples, d_loss, g_loss = self.sess.run([self.sampler, self.d_loss, self.g_loss],feed_dict={self.z: sample_z,self.inputs: sample_inputs,self.y:sample_labels,})save_images(samples, image_manifold_size(samples.shape[0]),'./{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) else:try:samples, d_loss, g_loss = self.sess.run([self.sampler, self.d_loss, self.g_loss],feed_dict={self.z: sample_z,self.inputs: sample_inputs,},)save_images(samples, image_manifold_size(samples.shape[0]),'./{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) except:print("one pic error!...")if np.mod(counter, 500) == 2:self.save(config.checkpoint_dir, counter)def discriminator(self, image, y=None, reuse=False):with tf.variable_scope("discriminator") as scope:if reuse:scope.reuse_variables()if not self.y_dim:h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv'))h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv')))h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv')))h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv')))h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin')return tf.nn.sigmoid(h4), h4else:yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])x = conv_cond_concat(image, yb)h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name='d_h0_conv'))h0 = conv_cond_concat(h0, yb)h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv')))h1 = tf.reshape(h1, [self.batch_size, -1])      h1 = concat([h1, y], 1)h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin')))h2 = concat([h2, y], 1)h3 = linear(h2, 1, 'd_h3_lin')return tf.nn.sigmoid(h3), h3def generator(self, z, y=None):with tf.variable_scope("generator") as scope:if not self.y_dim:s_h, s_w = self.output_height, self.output_widths_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)# project `z` and reshapeself.z_, self.h0_w, self.h0_b = linear(z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin', with_w=True)self.h0 = tf.reshape(self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])h0 = tf.nn.relu(self.g_bn0(self.h0))self.h1, self.h1_w, self.h1_b = deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1', with_w=True)h1 = tf.nn.relu(self.g_bn1(self.h1))h2, self.h2_w, self.h2_b = deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2', with_w=True)h2 = tf.nn.relu(self.g_bn2(h2))h3, self.h3_w, self.h3_b = deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3', with_w=True)h3 = tf.nn.relu(self.g_bn3(h3))h4, self.h4_w, self.h4_b = deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)return tf.nn.tanh(h4)else:s_h, s_w = self.output_height, self.output_widths_h2, s_h4 = int(s_h/2), int(s_h/4)s_w2, s_w4 = int(s_w/2), int(s_w/4)# yb = tf.expand_dims(tf.expand_dims(y, 1),2)yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])z = concat([z, y], 1)h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))h0 = concat([h0, y], 1)h1 = tf.nn.relu(self.g_bn1(linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin')))h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])h1 = conv_cond_concat(h1, yb)h2 = tf.nn.relu(self.g_bn2(deconv2d(h1,[self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2')))h2 = conv_cond_concat(h2, yb)return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))def sampler(self, z, y=None):with tf.variable_scope("generator") as scope:scope.reuse_variables()if not self.y_dim:s_h, s_w = self.output_height, self.output_widths_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)# project `z` and reshapeh0 = tf.reshape(linear(z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin'),[-1, s_h16, s_w16, self.gf_dim * 8])h0 = tf.nn.relu(self.g_bn0(h0, train=False))h1 = deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1')h1 = tf.nn.relu(self.g_bn1(h1, train=False))h2 = deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2')h2 = tf.nn.relu(self.g_bn2(h2, train=False))h3 = deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3')h3 = tf.nn.relu(self.g_bn3(h3, train=False))h4 = deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4')return tf.nn.tanh(h4)else:s_h, s_w = self.output_height, self.output_widths_h2, s_h4 = int(s_h/2), int(s_h/4)s_w2, s_w4 = int(s_w/2), int(s_w/4)# yb = tf.reshape(y, [-1, 1, 1, self.y_dim])yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])z = concat([z, y], 1)h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin'), train=False))h0 = concat([h0, y], 1)h1 = tf.nn.relu(self.g_bn1(linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'), train=False))h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])h1 = conv_cond_concat(h1, yb)h2 = tf.nn.relu(self.g_bn2(deconv2d(h1, [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'), train=False))h2 = conv_cond_concat(h2, yb)return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))def load_mnist(self):data_dir = os.path.join("./data", self.dataset_name)fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))loaded = np.fromfile(file=fd,dtype=np.uint8)trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float)fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))loaded = np.fromfile(file=fd,dtype=np.uint8)trY = loaded[8:].reshape((60000)).astype(np.float)fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))loaded = np.fromfile(file=fd,dtype=np.uint8)teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float)fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))loaded = np.fromfile(file=fd,dtype=np.uint8)teY = loaded[8:].reshape((10000)).astype(np.float)trY = np.asarray(trY)teY = np.asarray(teY)X = np.concatenate((trX, teX), axis=0)y = np.concatenate((trY, teY), axis=0).astype(np.int)seed = 547np.random.seed(seed)np.random.shuffle(X)np.random.seed(seed)np.random.shuffle(y)y_vec = np.zeros((len(y), self.y_dim), dtype=np.float)for i, label in enumerate(y):y_vec[i,y[i]] = 1.0return X/255.,y_vec@propertydef model_dir(self):return "{}_{}_{}_{}".format(self.dataset_name, self.batch_size,self.output_height, self.output_width)def save(self, checkpoint_dir, step):model_name = "DCGAN.model"checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)if not os.path.exists(checkpoint_dir):os.makedirs(checkpoint_dir)self.saver.save(self.sess,os.path.join(checkpoint_dir, model_name),global_step=step)def load(self, checkpoint_dir):import reprint(" [*] Reading checkpoints...")checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)ckpt = tf.train.get_checkpoint_state(checkpoint_dir)if ckpt and ckpt.model_checkpoint_path:ckpt_name = os.path.basename(ckpt.model_checkpoint_path)self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))print(" [*] Success to read {}".format(ckpt_name))return True, counterelse:print(" [*] Failed to find a checkpoint")return False, 0
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389
  • 390
  • 391
  • 392
  • 393
  • 394
  • 395
  • 396
  • 397
  • 398
  • 399
  • 400
  • 401
  • 402
  • 403
  • 404
  • 405
  • 406
  • 407
  • 408
  • 409
  • 410
  • 411
  • 412
  • 413
  • 414
  • 415
  • 416
  • 417
  • 418
  • 419
  • 420
  • 421
  • 422
  • 423
  • 424
  • 425
  • 426
  • 427
  • 428
  • 429
  • 430
  • 431
  • 432
  • 433
  • 434
  • 435
  • 436
  • 437
  • 438
  • 439
  • 440
  • 441
  • 442
  • 443
  • 444
  • 445
  • 446
  • 447
  • 448
  • 449
  • 450
  • 451
  • 452
  • 453
  • 454
  • 455
  • 456
  • 457
  • 458
  • 459
  • 460
  • 461
  • 462
  • 463
  • 464
  • 465
  • 466
  • 467
  • 468
  • 469
  • 470
  • 471
  • 472
  • 473
  • 474
  • 475
  • 476
  • 477
  • 478
  • 479
  • 480
  • 481
  • 482
  • 483
  • 484
  • 485
  • 486
  • 487
  • 488
  • 489
  • 490
  • 491
  • 492
  • 493
  • 494
  • 495
  • 496
  • 497
  • 498
  • 499
  • 500
  • 501
  • 502
  • 503
  • 504
  • 505
  • 506
  • 507
  • 508
  • 509
  • 510
  • 511
  • 512
  • 513
  • 514
  • 515
  • 516
  • 517
  • 518
  • 519
  • 520
  • 521
  • 522
  • 523
  • 524
  • 525
  • 526
  • 527
  • 528
  • 529
  • 530

这个文件就是DCGAN模型定义的函数。调用了utils.py文件和ops.py文件。

step0:定义conv_out_size_same(size, stride)函数。大小和步幅。

step1:然后是定义了DCGAN类,剩余代码都是在写DCGAN类,所以下面几步都是在这个类里面定义进行的。

step2:定义类的初始化函数 init。主要是对一些默认的参数进行初始化。包括session、crop、批处理大小batch_size、样本数量sample_num、输入与输出的高和宽、各种维度、生成器与判别器的批处理、数据集名字、灰度值、构建模型函数,需要注意的是,要判断数据集的名字是否是mnist,是的话则直接用load_mnist()函数加载数据,否则需要从本地data文件夹中读取数据,并将图像读取为灰度图。

step3:定义构建模型函数build_model(self)。

  1. 首先判断y_dim,然后用tf.placeholder占位符定义并初始化y。
  2. 判断crop是否为真,是的话是进行测试,图像维度是输出图像的维度;否则是输入图像的维度。
  3. 利用tf.placeholder定义inputs,是真实数据的向量。
  4. 定义并初始化生成器用到的噪音z,z_sum。
  5. 再次判断y_dim,如果为真,用噪音z和标签y初始化生成器G、用输入inputs初始化判别器D和D_logits、样本、用G和y初始化D_和D_logits;如果为假,跟上面一样初始化各种变量,只不过都没有标签y。
  6. 将5中的D、D_、G分别放在d_sum、d__sum、G_sum。
  7. 定义sigmoid交叉熵损失函数sigmoid_cross_entropy_with_logits(x, y)。都是调用tf.nn.sigmoid_cross_entropy_with_logits函数,只不过一个是训练,y是标签,一个是测试,y是目标。
  8. 定义各种损失值。真实数据的判别损失值d_loss_real、虚假数据的判别损失值d_loss_fake、生成器损失值g_loss、判别器损失值d_loss。
  9. 定义训练的所有变量t_vars。
  10. 定义生成和判别的参数集。
  11. 最后是保存。

step4:定义训练函数train(self, config)。

  1. 定义判别器优化器d_optim和生成器优化器g_optim。
  2. 变量初始化。
  3. 分别将关于生成器和判别器有关的变量各合并到一个变量中,并写入事件文件中。
  4. 噪音z初始化。
  5. 根据数据集是否为mnist的判断,进行输入数据和标签的获取。这里使用到了utils.py文件中的get_image函数。
  6. 定义计数器counter和起始时间start_time。
  7. 加载检查点,并判断加载是否成功。
  8. 开始for epoch in xrange(config.epoch)循环训练。先判断数据集是否是mnist,获取批处理的大小。
  9. 开始for idx in xrange(0, batch_idxs)循环训练,判断数据集是否是mnist,来定义初始化批处理图像和标签。
  10. 定义初始化噪音z。
  11. 判断数据集是否是mnist,来更新判别器网络和生成器网络,这里就不管mnist数据集是怎么处理的,其他数据集是,运行生成器优化器两次,以确保判别器损失值不会变为0,然后是判别器真实数据损失值和虚假数据损失值、生成器损失值。
  12. 输出本次批处理中训练参数的情况,首先是第几个epoch,第几个batch,训练时间,判别器损失值,生成器损失值。
  13. 每100次batch训练后,根据数据集是否是mnist的不同,获取样本、判别器损失值、生成器损失值,调用utils.py文件的save_images函数,保存训练后的样本,并以epoch、batch的次数命名文件。然后打印判别器损失值和生成器损失值。
  14. 每500次batch训练后,保存一次检查点。

step5:定义判别器函数discriminator(self, image, y=None, reuse=False)。

  1. 利用with tf.variable_scope(“discriminator”) as scope,在一个作用域 scope 内共享一些变量。
  2. 对scope利用reuse_variables()进行重利用。
  3. 如果为假,则直接设置5层,前4层为使用lrelu激活函数的卷积层,最后一层是使用线性层,最后返回h4和sigmoid处理后的h4。
  4. 如果为真,则首先将Y_dim变为yb,然后利用ops.py文件中的conv_cond_concat函数,连接image与yb得到x,然后设置4层网络,前3层是使用lrelu激励函数的卷积层,最后一层是线性层,最后返回h3和sigmoid处理后的h3。

step6:定义生成器函数generator(self, z, y=None)。

  1. 利用with tf.variable_scope(“generator”) as scope,在一个作用域 scope 内共享一些变量。
  2. 根据y_dim是否为真,进行判别网络的设置。
  3. 如果为假:首先获取输出的宽和高,然后根据这一值得到更多不同大小的高和宽的对。然后获取h0层的噪音z,权值w,偏置值b,然后利用relu激励函数。h1层,首先对h0层解卷积得到本层的权值和偏置值,然后利用relu激励函数。h2、h3等同于h1。h4层,解卷积h3,然后直接返回使用tanh激励函数后的h4。
  4. 如果为真:首先也是获取输出的高和宽,根据这一值得到更多不同大小的高和宽的对。然后获取yb和噪音z。h0层,使用relu激励函数,并与1连接。h1层,对线性全连接后使用relu激励函数,并与yb连接。h2层,对解卷积后使用relu激励函数,并与yb连接。最后返回解卷积、sigmoid处理后的h2。

step7:定义sampler(self, z, y=None)函数。

  1. 利用tf.variable_scope(“generator”) as scope,在一个作用域 scope 内共享一些变量。
  2. 对scope利用reuse_variables()进行重利用。
  3. 根据y_dim是否为真,进行判别网络的设置。
  4. 然后就跟生成器差不多,不在赘述。

step8:定义load_mnist(self)函数。这个主要是针对mnist数据集设置的,所以暂且不考虑,过。

step9:定义model_dir(self)函数。返回数据集名字,batch大小,输出的高和宽。

step10:定义save(self, checkpoint_dir, step)函数。保存训练好的模型。创建检查点文件夹,如果路径不存在,则创建;然后将其保存在这个文件夹下。

step11:定义load(self, checkpoint_dir)函数。读取检查点,获取路径,重新存储检查点,并且计数。打印成功读取的提示;如果没有路径,则打印失败的提示。

以上,就是model.py所有内容,主要是定义了DCGAN的类,完成了生成判别网络的实现。

2.4 训练

现在,整个4个文件都已经分析完毕,开始运行。

step0:由于我们使用的动漫人脸数据集,所以我们需要在源文件的路径下,建一个data文件夹,然后将放有数据的文件夹放在这个data文件夹中,如下所示。

这里写图片描述

这里写图片描述

step1:运行命令如下,需要制定各种参数,如我们的输入数据的高宽,输出的高宽,是哪个数据集,是否测试、训练,运行几个epoch。

如果你看到了此处,很好,接下来一系列的问题都是由于这里的原因导致我的训练不收敛,出来的结果乱七八糟!!这是因为,参数名称写错了!!!应该是:

python main.py --input_height 96 --output_height 48 --dataset faces --crop True --train True --epoch 10
  • 1

下面这个参数名称是错误的!(嗯,后面我还是会再说一遍的)

python main.py --image_size 96 --output_size 48 --dataset faces --crop True --train True --epoch 10
  • 1

这里写图片描述

step2:中间结果

这是第0个epoch,前3个batch:

这里写图片描述

新生成的文件:

这里写图片描述

step3:训练和测试结果

如果你又看到这里,可以忽略,直接去结果那看,因为这里都是参数没写对,生成的不收敛的结果!

第一个epoch:

这里写图片描述

第9个epoch:

这里写图片描述

看得出来,效果并不咋地,与参考更是相差甚远,这是因为训练数据只有3000+,而且总共训练了10个epoch。本来只是先试试,毕竟是纯cpu在跑,还有2个G,哎。

step4:这次训练数据选了16383张,epoch==300,跑了一晚上了,今天来看才到第5个epoch,嗯,慢慢等。

step5:重新在服务器上训练,这次选了参考博客上提供的数据集,因为前两次自己采集处理的数据集,或是因为数据集过小,训练效果差强人意,所以直接拿这个5万左右的数据集来试试。epoch==300。

step6:效果太差了。也不知道是哪里的问题,先把结果截图放上去,等有空再查查是什么原因。(严重怀疑是我的数据集有问题,因为当时在本地跑时对数据操作过,可能出现了问题。后面有空再弄吧)

结果标题代表第几个epoch第几个batch。

这里写图片描述

这里写图片描述

这里写图片描述

这里写图片描述

2.5 结果

好了,终于找到原因了,是因为参数名称写错了,没有将输入数据的高宽与输出数据的高宽由原先的108与64改为96与48,简直是太蠢了!!(此处感谢评论里某位小伙伴!要不是他说修改了参数我都没注意到)

重新训练:

python main.py --input_height 96 --output_height 48 --dataset faces --crop True --train True --epoch 10
  • 1

只用了10个epoch,效果就已经有点可观了,等服务器有空跑个300试试。

换了epoch==300,先放几张已有的效果,等跑完300再把全部结果放上来。

epoch 0

这里写图片描述

这里写图片描述

epoch 5

这里写图片描述

epoch 10

这里写图片描述

epoch 20

这里写图片描述

epoch 100

这里写图片描述

这里写图片描述

epoch 200

这里写图片描述

这里写图片描述

epoch 300

这里写图片描述

这里写图片描述


http://www.ppmy.cn/news/335135.html

相关文章

dilation conv 和 deconv

最近工作要用到dilation conv,在此总结一下空洞卷积,并和deconv进行对比。 dilation conv 首先空洞卷积的目的是为了在扩大感受野的同时,不降低图片分辨率和不引入额外参数及计算量(一般在CNN中扩大感受野都需要使用s>1的con…

【Android SDM660源码分析】- 02 - UEFI XBL QcomChargerApp充电流程代码分析

【Android SDM660源码分析】- 02 - UEFI XBL QcomChargerApp充电流程代码分析 一、加载 UEFI 默认应用程序1.1 LaunchDefaultBDSApps()1.1 LaunchAppFromGuidedFv() 二、QcomChargerApp应用程序初始化2.1 入口函数 QcomChargerApp_Entry()2.2 充电初始化 QcomChargerApp_Initia…

nvidia驱动,cuda与cudnn的关系

一关系阐述: (1)NVIDIA的显卡驱动程序和CUDA完全是两个不同的概念哦!CUDA是NVIDIA推出的用于自家GPU的并行计算框架,也就是说CUDA只能在NVIDIA的GPU上运行,而且只有当要解决的计算问题是可以大量并行计算的…

cin、cout的使用

cin、cout的使用 基本内容&#xff1a; (1)有关流对象cin、cout和流运算符的定义等信息是存放在C的输入输出流库中的&#xff0c;因此在程序中使用cin、cout和流运算符&#xff0c;就必须使用预处理命令把头文件stream包含到本文件中. 示教&#xff1a;#include <iostream…

DCC:Deep continuous clustering

文章&#xff1a;NIPS’17 代码&#xff1a;TensorFlow实现&#xff1b;Pytorch实现 经典的聚类算法具有离散结构&#xff1a;需要重新计算质心和数据点之间的关联&#xff0c;或者需要合并假定的聚类。 在任何一种情况下&#xff0c;优化过程都会被离散的重新配置打断。 连续…

Rockchip RK3588 kernel dts解析之PCIe

Rockchip RK3588 kernel dts解析之PCIe 文章目录 Rockchip RK3588 kernel dts解析之PCIeRK3588控制器RK3588 PHY使用限制DTS配置解析硬件设计软件DTS配置其他常见的PCIE配置对应的DTS配置实例pcie3.0phy拆分2个2Lane RC, 3个PCIe 2.0 1Lane(comboPHY)pcie3.0phy拆分为4个1Lane,…

Deep Complex Convolution Recurrent Network(DCCRN模型)

Abstract 深度学习给语音增强带来很多益处&#xff0c;传统的时频域(TF)方法主要通过朴素卷积神经网络(CNN)或递归神经网络(RNN)预测TF掩码或语音频谱。一些研究将将复值谱图作为训练目标&#xff0c;在实值网络中训练&#xff0c;分别预测幅值和相位分量或实部和虚部。特别是…

RK3588 Android平台SPI NOR+PCIE SSD实现大容量存储方案

RK3588 Android平台SPI NORPCIE SSD实现大容量存储方案 硬件配置 硬件方案是基于RK3588S自研平板方案实现。 CPU: RK3588S DDR: LPDDR5 8GB NOR: SPI接口 32MB容量 SSD&#xff1a; PCIE接口 256GB容量 软件版本要求 RK3588 Android12 SDK 升级到RKR8及以上版本RKTools/lin…