Tensorflow 2.0 cnn训练cifar10 准确率只有0.1 [已解决]

server/2024/10/19 0:06:09/

cifar10 准确率只有0.1

  • 问题描述
  • 踩坑
  • 解决办法

问题描述

如果你看的是北京大学曹健老师的tensorflow2.0,你在class5的部分可能会遇见这个问题

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout,MaxPooling2D,Flatten,Conv2D,BatchNormalization,Activation
from tensorflow.keras import Model
import os
import numpy as np# np.set_printoptions(threshold=np.inf)class Baseline(Model):def __init__(self):super(Baseline, self).__init__()self.conv1 = Conv2D(6, (5,5), activation='sigmoid')self.pool1 = MaxPooling2D(pool_size=(2,2),strides=2)self.conv2 = Conv2D(16, (5,5), activation='sigmoid')self.pool2 = MaxPooling2D(pool_size=(2,2),strides=2)self.flatten1 = Flatten()self.f1=Dense(120,activation='sigmoid')self.f2=Dense(84,activation='sigmoid')self.f3=Dense(10,activation='softmax')def call(self,x):x = self.conv1(x)x = self.pool1(x)x = self.conv2(x)x = self.pool2(x)x = self.flatten1(x)x = self.f1(x)x = self.f2(x)y = self.f3(x)return y(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train,x_test = x_train/255.0,x_test/255.0model = Baseline()
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path="lenet.ckpt"
if os.path.exists(checkpoint_save_path+'.index'):model.load_weights(checkpoint_save_path)print("---------------------Loaded model---------------")cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True,save_best_only=True, verbose=1)history=model.fit(x_train,y_train,batch_size=32, epochs=5, validation_data=(x_test, y_test),validation_freq=1,callbacks=[cp_callback])
model.summary()file=open('weights_lenet.txt','w')
for v in model.trainable_variables:file.write(str(v.name)+'\n')file.write(str(v.shape)+'\n')file.write(str(v.numpy())+'\n')
file.close()train_acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['val_sparse_categorical_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss']plt.subplot(1,2,1)
plt.plot(loss,label='train_loss')
plt.plot(val_loss,label='val_loss')
plt.title('model loss')
plt.legend()plt.subplot(1,2,2)
plt.plot(train_acc,label='train_acc')
plt.plot(val_acc,label='val_acc')
plt.title('model acc')
plt.legend()
plt.show()

代码写的看起来没有问题,但是就是acc一直在0.1,总共10个类,也就是说网络根本没有训练效果,就是瞎蒙的。为什么会这样呢。想知道答案的直接跳到最后。下面是我踩的坑,

踩坑

我尝试升级tensorflow版本,但是我们知道升级tensorflow,对应的cudatoolkit 和cudnn 也要升级,官网版本对应

在这里插入图片描述
conda install cudatoolkit==11.2.0

但是我去安装的时候显示PackagesNotFoundError: The following packages are not available from current channels:
搜不到这个版本,conda search cudatoolkit查看可以安装的版本
在这里插入图片描述就是没有11.2,这就很烦人,
我电脑环境是

windows11
cuda 12.3
cudnn 8.9.7

我不能把电脑cuda卸载重新装,因为我pytorch要求的是上面的环境。我尝试去官网再安装一个cuda但是失败了(想试一下windows电脑能不能安装两个cuda)。总之折腾了一下午

解决办法

方法一

cudatoolkit 和cudnn保持不变,直接升级tensorflow
pip install tensorflow==2.4
但是这样就不能用gpu训练了,跑代码的时候用的是cpu,具体原因我也不是很清楚,

方法二

看我之前的文章,卸载电脑上的cuda安装,安装cuda11.2和对应的cudnn8.1
cuda下载地址
cudnn下载地址
然后安装tensoflow 2.10版本
conda install tensorflow_gpu==2.10.0


你windows电脑如果想同时可以跑tensorflow和pytorch,建议电脑的cuda环境就按照tensorflow的安装。
因为pytorch安装比较简单,一般会自带对应的cuda,而tensorflow对cuda要求比较严格,用指令(conda install cudatoolkit==11.2.0 )一般找不到对应的版本,只能去官网下载

windows要是想跑代码就用pytorch吧,tensorflow对windows真的很不友好,tensorflow2.10以上直接不支持了,可以用实验室的服务器跑tensorflow代码


http://www.ppmy.cn/server/123066.html

相关文章

理解Web3:去中心化互联网的基础概念

随着科技的不断进步,互联网的形态也在不断演变。从最初的静态网页(Web1)到动态的社交网络(Web2),如今我们正步入一个新的阶段——Web3。这一新兴概念不仅代表了一种技术革新,更是一种互联网使用…

第五章 继承、多态、抽象类与接口 (7)

5.7 多态 利用多态可以使程序具有良好的扩展性,并可以对所有类对象进行通用的处理。在7.3节中已经学习过子类对象可以被作为父类的对象实例使用,这种将子类对象视为父类对象的做法称为“向上转型”。 假如现在要编写一个绘制图形的方法 draw(, 如果传入正…

「iOS」——KVC

iOS学习 前言KVC模式KVC设值KVC取值KVC使用keyPathKVC处理异常处理不存在的key处理nil异常 KVC处理字典KVC高阶消息传递 总结 前言 对KVC模式的简单学习和总结。 KVC模式 KVC(Key-Value Coding,键值编码)是一种通过字符串来访问对象属性的机…

力扣 简单 111.二叉树的最小深度

文章目录 题目介绍题解 题目介绍 题解 最小深度:从根节点到最近叶子结点的最短路径上节点数量 class Solution {public int minDepth(TreeNode root) {if (root null) {return 0;}int left minDepth(root.left);int right minDepth(root.right);// 如果 node 没…

关于智人和 AI 的负反馈

“夫物芸芸,各复归其根。归根曰静,静曰复命。复命曰常,知常曰明。”《道德经》名句感悟。 总体而言这是递进循环论的核心,联系我想到的一个简单负反馈:为什么年轻脑梗患者逐年增多? 大意是人的优良基因会促进医疗技…

企业微信 标准年级对照表

家校通讯录支持设置标准年级,企业微信会根据入学年份和标准年级自动生成部门。各个标准年级的对应值如下 标准年级名称 参数值 非标准年级 0 幼儿园小小班 1 幼儿园小班 2 幼儿园中班 3 幼儿园大班 4 幼儿园学前班 5 小学一年级 31 小学二年级 32 小学三年级 33 小…

PHP中如何使用三元条件运算符

三元条件运算符简介 PHP中的三元条件运算符是一个简化的if-else语句,它允许你在一行代码中完成条件判断和赋值。其基本语法如下: 条件 ? 表达式1 : 表达式2; 条件:是一个表达式,其结果将被评估为TRUE或FALSE。表达式1&#xf…

Windows环境下Node.js多版本切换的实用指南

Web开发和全栈开发中,Node.js已成为不可或缺的工具之一。然而,随着项目的多样化和技术栈的更新迭代,我们可能需要同时管理多个Node.js版本以满足不同项目的需求。在Windows环境下,如何高效地切换这些版本成为了一个关键问题。简单…