Here a CNN is used for a simple 6-class image classification task. The dataset was crawled from the web and covers six vehicle types (RV, bicycle, sports car, off-road vehicle, motorcycle, truck). After a series of data-preprocessing steps, the images are run through a convolutional network.
Sample of the dataset
Code
# encoding: utf-8
"""
@author:syj
@file:img_分类.py
@time:2019/09/27 14:05:47
"""
# imports
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
from sklearn.model_selection import train_test_split
import os

# choose the GPU to use
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# random seed
tf.set_random_seed(77)

# dataset path
cat_dir = r'C:\Users\Administrator\Desktop\car_datas\车1'
# read the images, convert them to arrays, and build the label list
def load_data(name_class):
    num = 0              # total number of samples
    images_data = []     # images
    labels_data = []     # labels
    # loop over the class folders
    for i in name_class:
        for k in os.listdir((cat_dir + '/' + i)):           # file name, e.g. 房车_0.jpg
            img = plt.imread(cat_dir + '/' + i + '/' + k)   # plt can read paths containing Chinese characters
            img = cv2.resize(img, (64, 64))                 # resize every image to 64*64*3
            img_array = np.array(img)                       # convert to an array
            img_array = img_array / 127.5 - 1               # normalize to [-1, 1]
            images_data.append(img_array)                   # append to the list
            # assign the label from the file-name prefix
            if k[:2] == '卡车':
                labels_data.append(0)
            elif k[:2] == '房车':
                labels_data.append(1)
            elif k[:2] == '摩托':
                labels_data.append(2)
            elif k[:2] == '自行':
                labels_data.append(3)
            elif k[:2] == '越野':
                labels_data.append(4)
            else:
                labels_data.append(5)
            num += 1                                        # running sample count
    img_array = np.array(images_data)
    lab_array = np.array(labels_data)
    return img_array, lab_array, num

name_class = os.listdir(cat_dir)    # class folder names
print(name_class)
num_class = len(name_class)

# shuffle
def shuffle_data(imgage_data, labels_data, num):
    p = np.random.permutation(num)
    imgage_data = imgage_data[p]
    labels_data = labels_data[p]
    return imgage_data, labels_data

# load and shuffle the data
imgage_data, labels_data, num = load_data(name_class)
imgage_data, labels_data = shuffle_data(imgage_data, labels_data, num)
print(imgage_data.shape)
print(labels_data.shape)

# train/test split
train_x, test_x, train_y, test_y = train_test_split(imgage_data, labels_data, test_size=0.2, random_state=7)
# placeholders
x = tf.placeholder(tf.float32, [None, 64, 64, 3])
y = tf.placeholder(tf.int64, [None])

# dropout keep probability for the fully connected layer, to reduce overfitting
keep_prob = tf.placeholder(tf.float32)

# split each batch of 100 into single images
x_image_arr = tf.split(x, num_or_size_splits=100, axis=0)
result_x_image_arr = []

# per-image data augmentation
for x_single_image in x_image_arr:
    x_single_image = tf.reshape(x_single_image, [64, 64, 3])
    # random horizontal flip
    data_aug_1 = tf.image.random_flip_left_right(x_single_image)
    # random brightness
    data_aug_2 = tf.image.random_brightness(data_aug_1, max_delta=63)
    # random contrast
    data_aug_3 = tf.image.random_contrast(data_aug_2, lower=0.2, upper=1.8)
    # per-image standardization (whitening)
    data_aug_4 = tf.image.per_image_standardization(data_aug_3)
    x_single_image = tf.reshape(data_aug_4, [1, 64, 64, 3])
    result_x_image_arr.append(x_single_image)
result_x_images = tf.concat(result_x_image_arr, axis=0)
# convolutional layers
conv1 = tf.layers.conv2d(result_x_images, 32, (3, 3), padding='same', activation=tf.nn.relu)
conv1 = tf.layers.batch_normalization(conv1, momentum=0.7)   # helps prevent overfitting
pooling1 = tf.layers.max_pooling2d(conv1, (2, 2), (2, 2))

conv2 = tf.layers.conv2d(pooling1, 64, (3, 3), padding='same', activation=tf.nn.relu)
conv2 = tf.layers.batch_normalization(conv2, momentum=0.7)
pooling2 = tf.layers.max_pooling2d(conv2, (2, 2), (2, 2))

conv3 = tf.layers.conv2d(pooling2, 128, (3, 3), padding='same', activation=tf.nn.relu)
conv3 = tf.layers.batch_normalization(conv3, momentum=0.7)
pooling3 = tf.layers.max_pooling2d(conv3, (2, 2), (2, 2))

flatten = tf.layers.flatten(pooling3)
# fully connected layers
fc = tf.layers.dense(flatten, 625, activation=tf.nn.tanh)
fc = tf.nn.dropout(fc, keep_prob=keep_prob)
a5 = tf.layers.dense(fc, 6)

# loss (per-example cross entropy)
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=a5)

# optimizer
optimizer = tf.train.AdamOptimizer(0.00005).minimize(cost)

# accuracy
pre = tf.argmax(a5, 1)
accuracy = tf.reduce_mean(tf.cast(tf.equal(pre, y), tf.float32))
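
# Optional sanity check (an added sketch, not part of the original run):
# count the trainable parameters created by the layers above.
total_params = sum(int(np.prod(v.get_shape().as_list())) for v in tf.trainable_variables())
print('trainable parameters:', total_params)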
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# train in mini-batches of 100
step = 0
for i in range(1, 3001):
    c, a, _ = sess.run([cost, accuracy, optimizer],
                       feed_dict={x: train_x[step:step+100], y: train_y[step:step+100], keep_prob: 0.7})
    step += 100
    if step >= train_x.shape[0]:
        step = 0
    if i % 500 == 0:
        print(i, np.mean(c), a)
# evaluate on the test set, 5 batches of 100 images
step1 = 0
all_acc = []
for i in range(5):
    a1 = sess.run(accuracy, feed_dict={x: test_x[step1:step1+100], y: test_y[step1:step1+100], keep_prob: 1})
    step1 += 100
    all_acc.append(a1)
print(np.mean(all_acc))
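
To look at individual predictions rather than only the averaged accuracy, something like the following could be appended. It is a sketch under two assumptions: the index-to-name mapping mirrors the if/elif chain in load_data (index 5 is the else branch, which in this dataset is 跑车), and the batch fed in holds exactly 100 images, since the graph splits its input into 100 pieces.

# Hypothetical helper: map argmax indices back to the class names used in load_data.
id_to_name = {0: '卡车', 1: '房车', 2: '摩托车', 3: '自行车', 4: '越野车', 5: '跑车'}
pred_ids = sess.run(pre, feed_dict={x: test_x[:100], keep_prob: 1})
for p_id, t_id in zip(pred_ids[:10], test_y[:10]):
    print('predicted:', id_to_name[int(p_id)], '| actual:', id_to_name[int(t_id)])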
Results
['房车', '自行车图片', '跑车', '越野车', '摩托车', '卡车']
(2752, 64, 64, 3)
(2752,)
300 0.8921075 0.67
600 0.6069706 0.81
900 0.30461997 0.92
1200 0.3142417 0.93
1500 0.16324146 0.98
1800 0.08101442 0.99
2100 0.08600599 0.99
2400 0.040265616 1.0
2700 0.035595465 1.0
3000 0.016259683 1.0
0.764

I have only just put this code together; the accuracy is still being tuned, and I will keep updating it.
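
Since tuning will continue, it may also be worth saving the trained weights so later experiments can restart from a checkpoint instead of retraining from scratch. A minimal sketch with tf.train.Saver (the checkpoint path here is only an illustration):

# Assumed addition: persist the trained variables; the path is illustrative.
saver = tf.train.Saver()
saver.save(sess, './car_cnn_ckpt/model.ckpt')

# later, after rebuilding the same graph:
# saver.restore(sess, './car_cnn_ckpt/model.ckpt')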