深度学习实战之超分辨率算法(tensorflow)——ESPCN

devtools/2024/12/24 20:42:49/

ESPCN 的算法原理请参考上一篇论文解读,这里主要给出实现代码。
数据集如下:尺寸相等即可
(此处原文为数据集示例图片)

  • 针对数据集生成训练样本的代码:
  • prepare_data.py
import imageio
from scipy import misc, ndimage
import numpy as np
import imghdr
import shutil
import os
import json

# RGB -> YCbCr matrix applied to RGB values scaled to [0, 1] (the ITU-R BT.601
# coefficients also used by MATLAB's rgb2ycbcr); `offset` is added afterwards.
mat = np.array([[65.481, 128.553, 24.966],
                [-37.797, -74.203, 112.0],
                [112.0, -93.786, -18.214]])
mat_inv = np.linalg.inv(mat)
offset = np.array([16, 128, 128])


def rgb2ycbcr(rgb_img):
    """Convert an RGB image to YCbCr.

    Vectorised form of the per-pixel formula ``round(mat @ (rgb / 255) + offset)``
    (the former implementation looped over every pixel in Python).

    Args:
        rgb_img: array of shape (H, W, 3) with values in [0, 255].

    Returns:
        uint8 array of shape (H, W, 3) holding Y, Cb, Cr in the last axis.
    """
    scaled = np.asarray(rgb_img, dtype=np.float64) / 255.0
    # Row-vector form of `mat @ pixel`, evaluated for all pixels at once.
    return np.round(scaled @ mat.T + offset).astype(np.uint8)


def ycbcr2rgb(ycbcr_img):
    """Convert a YCbCr image (as produced by rgb2ycbcr) back to RGB.

    Args:
        ycbcr_img: array of shape (H, W, 3).

    Returns:
        uint8 RGB array of shape (H, W, 3); values are rounded, then clipped to [0, 255].
    """
    centred = np.asarray(ycbcr_img, dtype=np.float64) - offset
    rgb = np.round((centred @ mat_inv.T) * 255.0)
    return np.clip(rgb, 0, 255).astype(np.uint8)


def my_anti_shuffle(input_image, ratio):
    """Inverse sub-pixel shuffle (space-to-depth).

    Rearranges an (H, W, C) image into (H / ratio, W / ratio, C * ratio * ratio) with
    ``out[h, w, c*ratio*ratio + x*ratio + y] == in[h*ratio + x, w*ratio + y, c]``.

    Raises:
        ValueError: if H or W is not divisible by ``ratio`` (previously this printed a
            message and returned None, which crashed later in the caller).
    """
    shape = input_image.shape
    ori_height, ori_width, ori_channels = int(shape[0]), int(shape[1]), int(shape[2])
    if ori_height % ratio != 0 or ori_width % ratio != 0:
        raise ValueError("Height and width must be divisible by ratio "
                         "(got {}x{}, ratio {})".format(ori_height, ori_width, ratio))
    height, width = ori_height // ratio, ori_width // ratio
    blocks = np.asarray(input_image).reshape(height, ratio, width, ratio, ori_channels)
    # Axes (h, x, w, y, c) -> (h, w, c, x, y) yields the channel order documented above.
    out = blocks.transpose(0, 2, 4, 1, 3).reshape(height, width, ori_channels * ratio * ratio)
    return out.astype(np.uint8)


def shuffle(input_image, ratio):
    """Sub-pixel shuffle (depth-to-space); inverse of my_anti_shuffle.

    ``out[i, j, k] == in[i // ratio, j // ratio, k*ratio*ratio + (i % ratio)*ratio + (j % ratio)]``.
    Trailing input channels that do not fill a whole ratio*ratio group are ignored.

    Returns:
        uint8 array of shape (H * ratio, W * ratio, C // ratio**2).
    """
    height, width = int(input_image.shape[0]), int(input_image.shape[1])
    channels = int(input_image.shape[2]) // ratio // ratio
    used = np.asarray(input_image)[:, :, :channels * ratio * ratio]
    blocks = used.reshape(height, width, channels, ratio, ratio)
    # Axes (h, w, k, a, b) -> (h, a, w, b, k) interleaves the sub-pixels spatially.
    out = blocks.transpose(0, 3, 1, 4, 2).reshape(height * ratio, width * ratio, channels)
    return out.astype(np.uint8)


def _patch_count(length, patch_size, stride):
    """Return how many `patch_size` windows, stepped by `stride`, fit entirely in `length`."""
    if length < patch_size:
        return 0
    return (length - patch_size) // stride + 1


def _iter_jpeg_paths(image_dir):
    """Yield the path of every JPEG file under `image_dir`, in os.walk order."""
    # NOTE(review): imghdr is deprecated (PEP 594) and removed in Python 3.13.
    for root, _dirnames, filenames in os.walk(image_dir):
        for filename in filenames:
            path = os.path.join(root, filename)
            if imghdr.what(path) == 'jpeg':
                yield path


def prepare_images(params):
    """Cut source JPEGs into aligned HR / LR patch pairs saved as PNG files.

    Each HR image is cropped to a multiple of ``ratio``; the LR image is that crop
    blurred with ``gaussian_filter(sigma=(1, 1, 0))`` and subsampled by ``ratio``.
    Patches are written to ``<split dir>/hr`` and ``<split dir>/lr``. The first
    ``training_num`` images go to the training split, the next ``validation_num`` to
    validation, the next ``test_num`` to test; then processing stops.

    Only patches lying entirely inside the image are written. The former
    ``height / hr_stride - 1`` count could cut patches short when lr_size > 2 * lr_stride
    (breaking the fixed-size reshape in prepare_data) and could skip a valid last row.

    Args:
        params: dict with 'ratio', 'training_num', 'validation_num', 'test_num',
            'lr_stride', 'lr_size', 'image_dir' and the
            '{training,validation,test}_image_dir' format strings (formatted with ratio).
    """
    ratio, training_num, lr_stride, lr_size = (
        params['ratio'], params['training_num'], params['lr_stride'], params['lr_size'])
    hr_stride = lr_stride * ratio
    hr_size = lr_size * ratio
    validation_end = training_num + params['validation_num']
    last_image = validation_end + params['test_num']
    split_dirs = {ele: params[ele + '_image_dir'].format(ratio)
                  for ele in ('training', 'validation', 'test')}

    # First clear old images and create new directories. os.path.join keeps hr/ and
    # lr/ consistent whether or not the configured directory ends with a separator.
    for new_dir in split_dirs.values():
        if os.path.isdir(new_dir):
            shutil.rmtree(new_dir)
        for sub_dir in ('hr', 'lr'):
            os.makedirs(os.path.join(new_dir, sub_dir))

    image_num = 0
    for path in _iter_jpeg_paths(params['image_dir']):
        hr_image = imageio.imread(path)
        if hr_image.ndim != 3 or hr_image.shape[2] != 3:
            # The (1, 1, 0) blur sigma and rgb2ycbcr both require 3-channel RGB input.
            print("Skipping non-RGB image: {}".format(path))
            continue
        height = hr_image.shape[0] - hr_image.shape[0] % ratio
        width = hr_image.shape[1] - hr_image.shape[1] % ratio
        hr_image = hr_image[:height, :width]
        blurred = ndimage.gaussian_filter(hr_image, sigma=(1, 1, 0))
        lr_image = blurred[::ratio, ::ratio, :]

        image_num += 1
        if image_num % 10 == 0:
            print("Finished image: {}".format(image_num))
        if image_num <= training_num:
            folder = split_dirs['training']
        elif image_num <= validation_end:
            folder = split_dirs['validation']
        else:
            folder = split_dirs['test']

        stem = os.path.splitext(os.path.basename(path))[0]
        for x in range(_patch_count(width, hr_size, hr_stride)):
            for y in range(_patch_count(height, hr_size, hr_stride)):
                hr_sub_image = hr_image[y * hr_stride:y * hr_stride + hr_size,
                                        x * hr_stride:x * hr_stride + hr_size]
                lr_sub_image = lr_image[y * lr_stride:y * lr_stride + lr_size,
                                        x * lr_stride:x * lr_stride + lr_size]
                patch_name = "{}_{}_{}.png".format(stem, y, x)
                imageio.imwrite(os.path.join(folder, 'hr', patch_name), hr_sub_image)
                imageio.imwrite(os.path.join(folder, 'lr', patch_name), lr_sub_image)

        if image_num >= last_image:
            break


def prepare_data(params):
    """Convert the HR / LR patch PNGs into flat uint8 sample files.

    Each output file holds the LR patch in YCbCr (lr_size * lr_size * 3 bytes)
    followed by the anti-shuffled YCbCr label of the central HR region
    ((lr_size - edge) * ratio pixels square), i.e. the label shape the ESPCN model
    declares for its 'labels' placeholder.

    Side effect: adds 'hr_stride' and 'hr_size' to ``params``.
    """
    ratio = params['ratio']
    params['hr_stride'] = params['lr_stride'] * ratio
    params['hr_size'] = params['lr_size'] * ratio
    for ele in ['training', 'validation', 'test']:
        new_dir = params[ele + '_dir'].format(ratio)
        if os.path.isdir(new_dir):
            shutil.rmtree(new_dir)
        os.makedirs(new_dir)

    lr_size, edge = params['lr_size'], params['edge']
    image_dirs = [d.format(ratio) for d in [params['training_image_dir'],
                                            params['validation_image_dir'],
                                            params['test_image_dir']]]
    data_dirs = [d.format(ratio) for d in [params['training_dir'],
                                           params['validation_dir'],
                                           params['test_dir']]]
    hr_start_idx = int(ratio * edge // 2)
    sub_hr_size = int((lr_size - edge) * ratio)
    hr_end_idx = hr_start_idx + sub_hr_size

    for image_dir, data_dir in zip(image_dirs, data_dirs):
        print("Creating {}".format(data_dir))
        hr_dir = os.path.join(image_dir, 'hr')
        for root, _dirnames, filenames in os.walk(os.path.join(image_dir, 'lr')):
            for filename in filenames:
                lr_image = imageio.imread(os.path.join(root, filename))
                hr_image = imageio.imread(os.path.join(hr_dir, filename))
                # convert to YCbCr colour space
                lr_image_y = rgb2ycbcr(lr_image)
                hr_image_y = rgb2ycbcr(hr_image)
                lr_data = lr_image_y.reshape(lr_size * lr_size * 3)
                sub_hr_image_y = hr_image_y[hr_start_idx:hr_end_idx, hr_start_idx:hr_end_idx]
                hr_data = my_anti_shuffle(sub_hr_image_y, ratio).reshape(sub_hr_size * sub_hr_size * 3)
                data = np.concatenate([lr_data, hr_data])
                stem = os.path.splitext(filename)[0]
                data.astype('uint8').tofile(os.path.join(data_dir, stem))


def remove_images(params):
    """Delete the intermediate patch folders created by prepare_images."""
    for ele in ['training', 'validation', 'test']:
        rm_dir = params[ele + '_image_dir'].format(params['ratio'])
        if os.path.isdir(rm_dir):
            shutil.rmtree(rm_dir)


if __name__ == '__main__':
    with open("./params.json", 'r') as f:
        params = json.load(f)

    print("Preparing images with scaling ratio: {}".format(params['ratio']))
    print("If you want a different ratio change 'ratio' in params.json")
    print("Splitting images (1/3)")
    prepare_images(params)
    print("Preparing data, this may take a while (2/3)")
    prepare_data(params)
    print("Cleaning up split images (3/3)")
    remove_images(params)
    print("Done, you can now train the model!")
  • generate.py
import argparse
from PIL import Image
import imageio
import tensorflow as tf
from scipy import ndimage
from scipy import misc
import numpy as np
from prepare_data import *
from psnr import psnr
import json
import pdb
from espcn import ESPCN


def get_arguments():
    """Parse the command-line options of the generation script."""
    parser = argparse.ArgumentParser(description='EspcnNet generation script')
    parser.add_argument('--checkpoint', type=str,
                        help='Which model checkpoint to generate from',
                        default="logdir_2x/train")
    parser.add_argument('--lr_image', type=str,
                        help='The low-resolution image waiting for processed.',
                        default="images/butterfly_GT.jpg")
    parser.add_argument('--hr_image', type=str,
                        help='The high-resolution image which is used to calculate PSNR.')
    parser.add_argument('--out_path', type=str,
                        help='The output path for the super-resolution image',
                        default="result/butterfly_HR")
    return parser.parse_args()


def check_params(args, params):
    """Return True when params has exactly one more filter size than channel count."""
    if len(params['filters_size']) - len(params['channels']) != 1:
        print("The length of 'filters_size' must be greater then the length of 'channels' by 1.")
        return False
    return True


def _crop_border(image, border):
    """Drop `border` pixels from each side of the first two axes (no-op when border is 0)."""
    height, width = image.shape[:2]
    return image[border:height - border, border:width - border]


def _bicubic_resize(image, ratio):
    """Upscale an (H, W, C) uint8 image by `ratio` with PIL bicubic interpolation."""
    height, width = image.shape[:2]
    # PIL expects (width, height), the reverse of numpy's (rows, cols) order.
    resized = Image.fromarray(image).resize((int(width * ratio), int(height * ratio)), Image.BICUBIC)
    return np.array(resized)


def generate():
    """Super-resolve --lr_image with a trained checkpoint and write the result.

    Only the Y channel goes through the network; Cb/Cr are upscaled bicubically and
    cropped to the network's output region. Writes ``<out_path>.png``. When
    --hr_image is given, also writes ``<out_path>_bicubic.png`` and prints the PSNR of
    both the model output and the bicubic baseline against that image.
    """
    args = get_arguments()
    with open("./params.json", 'r') as f:
        params = json.load(f)
    if not check_params(args, params):
        return
    ratio = params['ratio']

    net = ESPCN(filters_size=params['filters_size'],
                channels=params['channels'],
                ratio=ratio,
                batch_size=1,
                lr_size=params['lr_size'],
                edge=params['edge'])
    lr_image = tf.placeholder(tf.uint8)
    sr_image = net.generate(lr_image)
    saver = tf.train.Saver()

    lr_image_data = imageio.imread(args.lr_image)
    lr_image_ycbcr_data = rgb2ycbcr(lr_image_data)
    # Y channel only, with a leading batch axis: shape (1, H, W, 1), dtype uint8.
    lr_image_batch = lr_image_ycbcr_data[np.newaxis, :, :, 0:1]

    with tf.Session() as sess:
        try:
            model_loaded = net.load(sess, saver, args.checkpoint)
        except Exception:
            raise Exception("Failed to load model, does the ratio in params.json match the ratio you trained your checkpoint with?")
        if model_loaded:
            print("[*] Checkpoint load success!")
        else:
            print("[*] Checkpoint load failed/no checkpoint found")
            return
        sr_image_y_data = sess.run(sr_image, feed_dict={lr_image: lr_image_batch})

    sr_image_y_data = shuffle(sr_image_y_data[0], ratio)
    # HR-pixel border lost by the VALID convolutions; integer so it can slice arrays.
    edge = params['edge'] * ratio // 2
    sr_image_cbcr_data = _crop_border(_bicubic_resize(lr_image_ycbcr_data, ratio), edge)[:, :, 1:3]
    sr_image_ycbcr_data = np.concatenate((sr_image_y_data, sr_image_cbcr_data), axis=2)
    sr_image_data = ycbcr2rgb(sr_image_ycbcr_data)
    imageio.imwrite(args.out_path + '.png', sr_image_data)

    if args.hr_image is not None:
        hr_image_data = imageio.imread(args.hr_image)
        model_psnr = psnr(hr_image_data, sr_image_data, edge)
        print('PSNR of the model: {:.2f}dB'.format(model_psnr))
        sr_image_bicubic_data = _bicubic_resize(lr_image_data, ratio)
        imageio.imwrite(args.out_path + '_bicubic.png', sr_image_bicubic_data)
        bicubic_psnr = psnr(hr_image_data, sr_image_bicubic_data, 0)
        print('PSNR of Bicubic: {:.2f}dB'.format(bicubic_psnr))


if __name__ == '__main__':
    generate()

# train.py
```python
from __future__ import print_function
import argparse
from datetime import datetime
import os
import sys
import time
import json
import time
import tensorflow as tf
from reader import create_inputs
from espcn import ESPCN
import pdb

# Python 2/3 compatibility: ``xrange`` only exists on Python 2, so fall back to
# ``range`` (already lazy on Python 3). Catch only the NameError this probe raises.
try:
    xrange
except NameError:
    xrange = range
# Default training hyper-parameters; each can be overridden on the command line.
BATCH_SIZE = 32         # mini-batch size (--batch_size)
NUM_EPOCHS = 100        # passes over the training set (--epochs)
LEARNING_RATE = 1e-4    # optimiser learning rate (--learning_rate)
# Root directory for logs/checkpoints; "{}" is filled with the ratio when the default is used.
LOGDIR_ROOT = './logdir_{}x'


def get_arguments():
    """Parse the training command-line options (defaults come from the module constants)."""
    parser = argparse.ArgumentParser(description='EspcnNet example network')
    # Checkpoint directory to resume the weights from.
    parser.add_argument('--checkpoint', type=str,
                        help='Which model checkpoint to load from', default=None)
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
                        help='How many image files to process at once.')
    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS,
                        help='Number of epochs.')
    parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE,
                        help='Learning rate for training.')
    parser.add_argument('--logdir_root', type=str, default=LOGDIR_ROOT,
                        help='Root directory to place the logging '
                        'output and generated model. These are stored '
                        'under the dated subdirectory of --logdir_root. '
                        'Cannot use with --logdir.')
    return parser.parse_args()


def check_params(args, params):
    """Return True when params has exactly one more filter size than channel count."""
    if len(params['filters_size']) - len(params['channels']) != 1:
        print("The length of 'filters_size' must be greater then the length of 'channels' by 1.")
        return False
    return True


def _average_report(start_total, start_count, end_total, end_count):
    """Format the first-20% vs last-20% epoch loss comparison; empty windows count as 0."""
    start_average = float(start_total) / start_count if start_count else 0.0
    end_average = float(end_total) / end_count if end_count else 0.0
    improved = 100 - (100 * end_average / start_average) if start_average else 0.0
    return "Start Average: [%.6f], End Average: [%.6f], Improved: [%.2f%%]" \
        % (start_average, end_average, improved)


def train():
    """Train ESPCN on the samples produced by prepare_data.py, checkpointing every epoch.

    Reads ./params.json, resumes from --checkpoint (or from <logdir_root>/train when
    --checkpoint is not given), logs the loss to TensorBoard under <logdir_root>/train
    and, on exit (also after Ctrl-C), prints the mean loss of the first vs. last 20%
    of epochs.

    Raises:
        ValueError: if there are fewer training samples than --batch_size.
    """
    args = get_arguments()
    with open("./params.json", 'r') as f:
        params = json.load(f)
    if not check_params(args, params):
        return

    logdir_root = args.logdir_root  # ./logdir
    if logdir_root == LOGDIR_ROOT:
        logdir_root = logdir_root.format(params['ratio'])  # ./logdir_{RATIO}x
    logdir = os.path.join(logdir_root, 'train')  # ./logdir_{RATIO}x/train
    # --checkpoint used to be parsed but ignored; honour it when given.
    restore_dir = args.checkpoint if args.checkpoint else logdir

    # Load training data as np arrays
    lr_images, hr_labels = create_inputs(params)
    batch_idxs = len(lr_images) // args.batch_size
    if batch_idxs == 0:
        raise ValueError("Need at least {} training samples for one batch, found {}".format(
            args.batch_size, len(lr_images)))

    net = ESPCN(filters_size=params['filters_size'],
                channels=params['channels'],
                ratio=params['ratio'],
                batch_size=args.batch_size,
                lr_size=params['lr_size'],
                edge=params['edge'])
    loss, images, labels = net.build_model()
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    optim = optimizer.minimize(loss, var_list=tf.trainable_variables())

    # set up logging for tensorboard
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    summaries = tf.summary.merge_all()
    # Created after minimize() so the optimizer's slot variables are checkpointed too.
    saver = tf.train.Saver()

    # Loss sums/epoch counts of the first and last 20% of epochs.
    start_total, start_count, end_total, end_count = 0.0, 0, 0.0, 0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if net.load(sess, saver, restore_dir):
            print("[*] Checkpoint load success!")
        else:
            print("[*] Checkpoint load failed/no checkpoint found")

        try:
            steps = 0
            start_time = time.time()
            for ep in xrange(1, args.epochs + 1):
                batch_average = 0
                for idx in xrange(batch_idxs):
                    # On the fly batch generation instead of Queue to optimize GPU usage
                    lo, hi = idx * args.batch_size, (idx + 1) * args.batch_size
                    steps += 1
                    summary, loss_value, _ = sess.run(
                        [summaries, loss, optim],
                        feed_dict={images: lr_images[lo:hi], labels: hr_labels[lo:hi]})
                    writer.add_summary(summary, steps)
                    batch_average += loss_value

                batch_average = float(batch_average) / batch_idxs
                if ep < (args.epochs * 0.2):
                    start_total += batch_average
                    start_count += 1
                elif ep >= (args.epochs * 0.8):
                    end_total += batch_average
                    end_count += 1

                duration = time.time() - start_time
                print('Epoch: {}, step: {:d}, loss: {:.9f}, ({:.3f} sec/epoch)'.format(
                    ep, steps, batch_average, duration))
                start_time = time.time()
                net.save(sess, saver, logdir, steps)
        except KeyboardInterrupt:
            print()
        finally:
            writer.close()
            print(_average_report(start_total, start_count, end_total, end_count))


if __name__ == '__main__':
    train()

模型实现(TensorFlow 版本):espcn.py

import tensorflow as tf
import os
import sys
import pdb


def create_variable(name, shape):
    '''Create a convolution filter variable with the specified name and shape,
    initialised with the Xavier (Glorot) initializer for conv2d kernels.'''
    initializer = tf.contrib.layers.xavier_initializer_conv2d()
    variable = tf.Variable(initializer(shape=shape), name=name)
    return variable


def create_bias_variable(name, shape):
    '''Create a float32 bias variable with the specified name and shape, initialised to zero.

    The name is passed by keyword: the second positional parameter of tf.Variable is
    `trainable`, so the former tf.Variable(value, name) call did not name the variable.
    NOTE(review): checkpoints written by that older code may store the biases under
    TensorFlow's default variable name instead of "bias" - confirm before reusing them.
    '''
    initializer = tf.constant_initializer(value=0.0, dtype=tf.float32)
    return tf.Variable(initializer(shape=shape), name=name)


class ESPCN:
    '''ESPCN sub-pixel super-resolution network (TensorFlow 1.x graph mode).

    Args:
        filters_size: kernel size of every conv layer (one more entry than `channels`).
        channels: output channels of every layer except the last one.
        ratio: upscaling factor r; the last layer outputs r * r channels.
        batch_size: static batch dimension of the training placeholders.
        lr_size: side length of the low-resolution training patches.
        edge: labels are (lr_size - edge) pixels wide, so it must equal the size
            reduction caused by the VALID convolutions.
    '''

    def __init__(self, filters_size, channels, ratio, batch_size, lr_size, edge):
        self.filters_size = filters_size
        self.channels = channels
        self.ratio = ratio
        self.batch_size = batch_size
        self.lr_size = lr_size
        self.edge = edge
        self.variables = self.create_variables()

    def create_variables(self):
        '''Create the filters, biases and the uint8 'images'/'labels' placeholders.'''
        var = dict()
        var['filters'] = list()
        # the input layer: a single (Y) channel in
        var['filters'].append(create_variable('filter',
                                              [self.filters_size[0],
                                               self.filters_size[0],
                                               1,
                                               self.channels[0]]))
        # the hidden layers
        for idx in range(1, len(self.filters_size) - 1):
            var['filters'].append(create_variable('filter',
                                                  [self.filters_size[idx],
                                                   self.filters_size[idx],
                                                   self.channels[idx - 1],
                                                   self.channels[idx]]))
        # the output layer: ratio**2 sub-pixel channels
        var['filters'].append(create_variable('filter',
                                              [self.filters_size[-1],
                                               self.filters_size[-1],
                                               self.channels[-1],
                                               self.ratio ** 2]))

        var['biases'] = list()
        for channel in self.channels:
            var['biases'].append(create_bias_variable('bias', [channel]))
        # Integer dimension: the former float(self.ratio)**2 is a float shape, which
        # TensorFlow's shape handling may reject.
        var['biases'].append(create_bias_variable('bias', [self.ratio ** 2]))

        image_shape = (self.batch_size, self.lr_size, self.lr_size, 3)
        var['images'] = tf.placeholder(tf.uint8, shape=image_shape, name='images')
        label_shape = (self.batch_size, self.lr_size - self.edge, self.lr_size - self.edge, 3 * self.ratio**2)
        var['labels'] = tf.placeholder(tf.uint8, shape=label_shape, name='labels')
        return var

    def build_model(self):
        '''Build the training graph; returns (loss, images placeholder, labels placeholder).'''
        images, labels = self.variables['images'], self.variables['labels']
        input_images, input_labels = self.preprocess([images, labels])
        output = self.create_network(input_images)
        reduced_loss = self.loss(output, input_labels)
        return reduced_loss, images, labels

    def save(self, sess, saver, logdir, step):
        '''Write a checkpoint "model.ckpt-<step>" into logdir, creating it if needed.'''
        # print('[*] Storing checkpoint to {} ...'.format(logdir), end="")
        sys.stdout.flush()
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        checkpoint = os.path.join(logdir, "model.ckpt")
        saver.save(sess, checkpoint, global_step=step)
        # print('[*] Done saving checkpoint.')

    def load(self, sess, saver, logdir):
        '''Restore the latest checkpoint in logdir; return True on success, False if none.'''
        print("[*] Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(logdir, ckpt_name))
            return True
        return False

    def preprocess(self, input_data):
        '''Cast [images, labels] to float32 in [0, 1] and select the Y-related channels.

        Keeps channel 0 (Y) of the images and the first ratio**2 channels of the labels
        (the anti-shuffled Y channel). `labels` may be None (generation), in which case
        the returned labels are None.
        '''
        input_list = list()
        for ele in input_data:
            if ele is None:
                continue
            ele = tf.cast(ele, tf.float32) / 255.0
            input_list.append(ele)

        input_images, input_labels = input_list[0][:, :, :, 0:1], None
        # Generate doesn't use input_labels
        ratio_square = self.ratio * self.ratio
        if input_data[1] is not None:
            input_labels = input_list[1][:, :, :, 0:ratio_square]
        return input_images, input_labels

    def create_network(self, input_labels):
        '''Apply the conv stack to the (already preprocessed) Y-channel input batch.

        Structure (defaults depend on params.json):
        input (1 channel, Y) ---> conv ---> ... ---> conv (ratio**2 channels)
        Every conv uses VALID padding; all but the last are followed by tanh.
        The argument is the network input (the historical name is kept for callers).
        '''
        current_layer = input_labels
        last_idx = len(self.filters_size) - 1
        for idx in range(len(self.filters_size)):
            conv = tf.nn.conv2d(current_layer, self.variables['filters'][idx], [1, 1, 1, 1], padding='VALID')
            with_bias = tf.nn.bias_add(conv, self.variables['biases'][idx])
            current_layer = with_bias if idx == last_idx else tf.nn.tanh(with_bias)
        return current_layer

    def loss(self, output, input_labels):
        '''Mean squared error between output and labels; also logged as summary 'loss'.'''
        residual = output - input_labels
        reduced_loss = tf.reduce_mean(tf.square(residual))
        tf.summary.scalar('loss', reduced_loss)
        return reduced_loss

    def generate(self, lr_image):
        '''Build the inference graph for a uint8 Y-channel batch.

        Returns a uint8 tensor with ratio**2 sub-pixel channels that still need to be
        rearranged by a pixel shuffle.
        '''
        lr_image = self.preprocess([lr_image, None])[0]
        sr_image = self.create_network(lr_image)
        # Round to the nearest 8-bit level (the uint8 training data was quantised with
        # np.round) instead of truncating, then clip to the valid byte range.
        sr_image = tf.clip_by_value(tf.round(sr_image * 255.0), 0.0, 255.0)
        return tf.cast(sr_image, tf.uint8)
  • 读取训练数据:reader.py
import tensorflow as tf
import numpy as np
import os
import pdb


def create_inputs(params):
    """Load the prepared training samples into memory as numpy arrays.

    Each file written by prepare_data.py is a flat uint8 buffer: the LR patch
    (lr_size x lr_size x 3) followed by the HR label
    ((lr_size - edge) x (lr_size - edge) x 3 * ratio**2). Loading everything up front,
    instead of using a FIFOQueue reader, lets the training loop feed batches directly
    and keep the GPU busy.

    Returns:
        (lr_images, hr_labels): two lists of uint8 arrays, in sorted filename order.

    Raises:
        Exception: if the training directory does not exist (prepare_data.py not run yet).
    """
    training_dir = params['training_dir'].format(params['ratio'])
    # Raise exception if user has not ran prepare_data.py yet
    if not os.path.isdir(training_dir):
        raise Exception("You must first run prepare_data.py before you can train")

    lr_size, edge, ratio = params['lr_size'], params['edge'], params['ratio']
    lr_shape = (lr_size, lr_size, 3)
    hr_shape = (lr_size - edge, lr_size - edge, 3 * ratio ** 2)
    # Split point between LR patch and label, derived from lr_size (was hard-coded 17*17*3).
    lr_len = lr_size * lr_size * 3

    lr_images, hr_labels = [], []
    for name in sorted(os.listdir(training_dir)):
        # np.fromfile on a path opens and closes the file itself, so no handle leaks.
        train_data = np.fromfile(os.path.join(training_dir, name), dtype=np.uint8)
        lr_images.append(train_data[:lr_len].reshape(lr_shape))
        hr_labels.append(train_data[lr_len:].reshape(hr_shape))
    return lr_images, hr_labels

PSNR 计算:psnr.py

import numpy as np
import math


def psnr(hr_image, sr_image, hr_edge):
    """Return the PSNR in dB of sr_image against hr_image, for 8-bit (peak 255) data.

    Args:
        hr_image: reference image (array-like, values in [0, 255]).
        sr_image: reconstructed image; must have the same shape as the cropped reference.
        hr_edge: border width removed from each side of hr_image's first two axes
            before comparing; floats are truncated to int.

    Returns:
        PSNR in dB, or ``float('inf')`` when the images are identical.

    Raises:
        ValueError: if the cropped reference and sr_image shapes differ.
    """
    hr_image_data = np.asarray(hr_image).astype('float32')
    hr_edge = int(hr_edge)
    if hr_edge > 0:
        hr_image_data = hr_image_data[hr_edge:-hr_edge, hr_edge:-hr_edge]
    sr_image_data = np.asarray(sr_image).astype('float32')
    if hr_image_data.shape != sr_image_data.shape:
        raise ValueError("Shape mismatch: reference {} vs reconstruction {}".format(
            hr_image_data.shape, sr_image_data.shape))
    mse = float(np.mean((sr_image_data - hr_image_data) ** 2))
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))

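训练过程中会出现一个与 bias 相关的报错(原文记作 “bias is not unsupportd”),但模型仍然可以正常学习。该问题可能源于原始实现中 espcn.py 的 create_bias_variable:name 被按位置传给了 tf.Variable(其第二个位置参数是 trainable),并且最后一层偏置的形状写成了浮点数 float(self.ratio)**2。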
(此处原文为一张截图)


http://www.ppmy.cn/devtools/145073.html

相关文章

如何通过HTTP API新建Collection

本文介绍如何通过HTTP API创建一个新的Collection。 前提条件 已创建Cluster:创建Cluster。 已获得API-KEY:API-KEY管理。 Method与URL HTTP POST https://{Endpoint}/v1/collections 使用示例 说明 需要使用您的api-key替换示例中的YOUR_API_KEY、…

Oracle筑基篇-调度算法-LRU的引入

常见的调度算法 图1 调度算法思维导图 一、LRU算法的典型使用场景 1. 操作系统中的页面置换 什么时候用到页面置换算法呢? 当CPU发出指令需要访问某个地址时,若该地址在TLB(Translation Lookaside Buffer,快表)或页…

Matplotlib DAY1 (完)

Matplotlib 是支持 Python 语言的开源绘图库,因为其支持丰富的绘图类型、简单的绘图方式以及完善的接口文档,深受 Python 工程师、科研学者、数据工程师等各类人士的喜欢。本次实验课程中,我们将学会使用 Matplotlib 绘图的方法和技巧。 知识…

【蓝碳】基于GEE云计算、多源遥感、高光谱遥感技术、InVEST模型、PLUS模型的蓝碳储量估算;红树林植被指数计算及提取

蓝碳和红树林研究的重要性主要体现在以下几个方面: 1.全球碳循环的关键角色:蓝碳生态系统,包括红树林、盐沼和海草床,虽然覆盖面积不到海床的0.5%,但其碳储量却高达海洋碳储量的50%以上,甚至可能高达71%。红…

java全栈day18--Web后端实战(java操作数据库2)

前言:在上节入门程序当中我们见到了JDBC所提供的API,本节来详细说明一下。 一、JDBC--API详解 1.1DriverManager(驱动管理器) 回顾:作用获取连接,调用它里面的getConnection。即如下 作用 1.注册驱动解…

单节点calico性能优化

在单节点上部署calicov3273后,发现资源占用 修改calico以下配置是资源消耗降低 1、因为是单节点,没有跨节点pod网段组网需要,禁用overlay方式网络(ipip,vxlan),使用route方式网络 配置calico-node的环境变量 CALICO_IPV4POOL_I…

16×16LED点阵字符滚动显示-基于译码器与移位寄存器(设计报告+仿真+单片机源程序)

资料下载地址:16×16LED点阵字符滚动显示-基于译码器与移位寄存器(设计报告+仿真+单片机源程序) 1、功能介绍 设计16×16点阵LED显示器的驱动电路,并编写程序实现在16×16点阵LED显示器上的字符滚动显示。16×16点阵LED显示器可由4块8×8点阵LED显示器构成。可采…

Scala图书管理系统

项目创建并实现基础UI package org.appimport scala.io.StdInobject Main {def main(args: Array[String]): Unit {var running truewhile (running) {println("欢迎来到我的图书管理系统,请选择")println("1.查看所有图书")println("2…