从零开始玩转TensorFlow:小明的机器学习故事 5

devtools/2025/2/25 5:51:01/

图像识别的挑战

1 故事引入:小明的“图像识别”大赛

小明从学校里听说了一个有趣的比赛:“美食图像识别”。参赛者需要训练计算机,看一张食物照片(例如披萨、苹果、汉堡等),就能猜出这是什么食物。听起来非常酷,但是怎么让计算机“看懂”图片呢?

小明想:
“传统的神经网络(全连接网络)之前用来做数字分类效果还行,但这些美食图像不但颜色复杂,而且分辨率更高,直接用全连接网络可能会效果一般。听说有个东西叫卷积神经网络(CNN),特别适合图像识别。我就来好好研究一下吧!”


2 为什么CNN更擅长看图?

2.1 普通神经网络:所有像素一锅端

在最简单的全连接网络里,每个神经元都要处理图像里所有像素的信息。想象下:

  • 你有一大张图,每个位置都有信息。
  • “侦探们”都拥挤在一起,每个人想“盯”整个图像。
  • 数据一多就会非常混乱,大家根本不好分工,效率也低。
2.2 卷积神经网络:给侦探们分配“放大镜”

CNN 里有一层又一层的小“滤镜”(卷积核),它们就像给侦探们每人发了一个“放大镜”,让他们专注查看图像上的某一块区域,从而提取“边缘”“角”“颜色块”等重要特征。

  • 卷积层(Convolution Layer):这个层就负责把一张图分成小块挨个扫描,发现局部特征。
  • 池化层(Pooling Layer):把提取到的“局部发现”做简化,让整体数据量更小;同时,对位置的小幅变化也更有耐心。
  • 全连接层(Dense Layer):将各个层提取到的特征进行综合决策,最终输出“这是苹果呢?还是披萨?还是汉堡?”。

正因为这样分工协作,CNN 在图像识别任务上往往能大显身手。


3 小明的热身实验:CIFAR-10

比赛数据通常比较大,小明想先试试 CNN 的“套路”。他找来 CIFAR-10 这个小数据集做热身。CIFAR-10 共有 10 个类别(如飞机、汽车、鸟、猫、船等),每张图都是彩色的,但只有 32×32 像素,大小和图像内容都比较简单,正好适合入门。

3.1 准备数据与环境
import tensorflow as tf
from tensorflow import keras# 下载并加载CIFAR-10数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()# 缩放像素值到[0,1]区间
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0num_classes = 10  # 一共10个类别
  1. x_train, y_train:训练数据,包含图像及其对应标签。
  2. x_test, y_test:测试数据,用来在训练结束后考核模型。
  3. 归一化:像素值从 0~255 变成 0~1,加快模型收敛。

3.2 CNN 模型结构:小明的“放大镜团队”
model = keras.Sequential([# 第一组:卷积 + 池化keras.layers.Conv2D(32, (3,3), activation='relu', padding='same',input_shape=(32, 32, 3)),keras.layers.MaxPooling2D((2,2)),# 第二组:卷积 + 池化keras.layers.Conv2D(64, (3,3), activation='relu', padding='same'),keras.layers.MaxPooling2D((2,2)),# 将特征图展开,再接全连接层keras.layers.Flatten(),keras.layers.Dense(128, activation='relu'),keras.layers.Dense(num_classes, activation='softmax')
])model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy']
)
model.summary()
  1. Conv2D(32, 3×3):32 个滤镜,每个滤镜关注 3×3 区域;padding='same' 表示卷积后图大小不变。
  2. MaxPooling2D(2×2):对特征图做 2×2 的“精华提取”。
  3. Flatten:把卷积/池化后得到的多维特征图拉平成一维数组。
  4. Dense(128):中间的全连接层,用于更深层次的特征整合。
  5. Dense(num_classes):输出 10 个类别(CIFAR-10 就是 0~9 这 10 类)。

3.3 模型训练:让侦探团队学会识图
history = model.fit(x_train, y_train,epochs=10,batch_size=64,validation_split=0.2
)
  • 训练过程:喂给网络很多训练图片,网络会先猜测是哪一类,然后根据“猜的结果”和“真实标签”之间的差距来更新参数。
  • epochs:训练轮数;batch_size:一次要处理多少张图片。
  • validation_split=0.2:从训练集中划出 20% 的数据做验证,便于实时观察模型的泛化能力。

3.4 训练成果可视化:准确率与损失曲线

在实际操作中,我们通常还会绘制训练和验证的准确率(accuracy)和损失值(loss),看看模型是否正在稳步提高。

  • 如果验证准确率突然下降,说明可能出现 过拟合
  • 如果训练准确率和验证准确率都一起上升,恭喜你,模型健康成长。

以下是一段常见的可视化示例代码,演示如何使用 matplotlib 绘制训练过程中 准确率(accuracy)损失值(loss) 的变化曲线。

import matplotlib.pyplot as plt# 从 history 对象中获取训练过程中的准确率和损失值
acc = history.history['accuracy']              # 训练准确率
val_acc = history.history['val_accuracy']      # 验证准确率
loss = history.history['loss']                 # 训练损失
val_loss = history.history['val_loss']         # 验证损失epochs_range = range(len(acc))  # 横坐标:训练的轮数plt.figure(figsize=(12, 5))# 1. 绘制准确率曲线
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.title('Accuracy Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')# 2. 绘制损失值曲线
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.title('Loss Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(loc='upper right')plt.show()

代码解释

  1. 获取准确率和损失值

    • history.history['accuracy']:训练集上的准确率随 epoch 的变化。
    • history.history['val_accuracy']:验证集上的准确率随 epoch 的变化。
    • history.history['loss']history.history['val_loss']:分别是训练集和验证集的损失值。
  2. 绘制图像

    • plt.subplot(1, 2, 1)plt.subplot(1, 2, 2) 用于将图像一分为二,分别在左、右两边绘制准确率和损失值曲线。
    • plt.plot(epochs_range, ...):将不同 epoch 的值连成线,观察趋势。
    • plt.title()plt.xlabel()plt.ylabel():添加标题和坐标轴标签,便于阅读。
    • plt.legend():显示图例,区分训练曲线和验证曲线。

运行该段代码后,你会看到两个并排的折线图:左边是 准确率 随训练轮数的变化,右边是 损失值 随训练轮数的变化。通过它们,你可以直观判断模型是否持续学习收敛,或是否出现 过拟合(若验证准确率下降或验证损失上升,而训练集表现持续改善,往往说明过拟合)。

示例输出:
在这里插入图片描述

4 最终的“毕业考”:测试集 & 图片可视化

小明已经把 CNN 训练好了,现在是让模型“毕业考”的时候,也就是在测试集 (x_test, y_test) 上检验它的表现。

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print("Test accuracy:", test_acc)

示例输出:

Test accuracy: 0.7005000114440918

成功得到一个满意的准确率后,小明心中一阵窃喜:
“看来这‘放大镜团队’果然名不虚传!”

4.1 随机挑几张测试图,看看模型预测得如何
import random
import numpy as np
import matplotlib.pyplot as pltindices = random.sample(range(len(x_test)), 5)
for i in indices:img = x_test[i]label = y_test[i][0]pred = model.predict(img[np.newaxis, ...])pred_label = np.argmax(pred, axis=1)[0]plt.imshow(img)plt.title(f"Real: {label}, Pred: {pred_label}")plt.show()
  • random.sample(...):随机挑几个测试样本进行展示。
  • plt.imshow(img):显示这张 32×32 像素的图像。
  • Real: {label}:数据集中给出的真实标签。
  • Pred: {pred_label}:模型预测的标签。

当图像分辨率低时,我们人眼看起来可能会觉得比较模糊。但只要模型学到了“像素间的相关性”,就能“看”得出哪张更像猫、哪张更像船。

示例输出:
在这里插入图片描述
在这里插入图片描述


5 解读示例:当模型识别到一艘船

下图是一个示例输出(当你运行上面代码时,可能得到不同的随机图)。假设标题写着:

Real: 8, Pred: 8

并且画面看上去有一点类似“海面上白色物体”的模糊画面。

  • CIFAR-10 标签 8 通常代表“船”(ship)
  • 虽然图像分辨率低,但海洋深蓝和船体白色的对比能给网络提供重要线索。
  • 模型预测也给出了“8”,说明它成功认定这是“船”。
  • 如果你在比赛里,这是一个 “预测正确” 的案例。

如果实际图像是“狗”,然而模型给了“猫”之类的预测,就可以通过观察图像特征、数据增强或调整网络结构来做进一步优化。这样的人机对照可以帮你快速找到模型的盲点。


6 故事总结:小明的收获

  1. 卷积神经网络为什么行?
    • 通过“局部扫描 + 池化精简”,CNN 能更好地从图像中提取关键特征,且参数量更少,不容易过拟合。
  2. CIFAR-10先热身
    • 小明在 CIFAR-10 上先试手,发现 CNN 能得到不错的效果,足以说明原理可行。
  3. 看结果很重要
    • 训练曲线可以帮助判断模型是否过拟合或仍在上升空间。
    • 最终在测试集上打印几张预测结果,能让人更直观地看到“模型脑子里想的”和“真实情况”一致性如何。
  4. 阶段性成就感
    • 完整走完这套流程,小明对 CNN 有了更深理解,也为他在 “美食图像识别” 大赛上取得好成绩打下坚实基础。

最终寄语

通过这一章节,我们以 小明的故事 为主线,先介绍为什么卷积神经网络(CNN)更适合图像,再到如何用一个小练习数据集(CIFAR-10)搭建模型、训练、测试,并可视化预测结果。

  • 逻辑链路:问题(美食识别)→ 为什么CNN → 如何CNN → 小实验→ 看结果→ 结果分析
  • 核心方法:理解卷积、池化、全连接的作用,知道如何用代码实现并调试。
  • 图像可视化:从中可以直观看到模型预测是否正确,尤其是像CIFAR-10这种低分辨率图像,对人眼来说模糊,但模型却能学到对应特征。如果出现“预测错”,就能帮助我们快速找到改进思路。

现在,小明和你都算踏进了 “计算机视觉” 这个广阔的领域,下一步可以在更大、更真实的数据上继续折腾!别忘了多实践、多观察,最后说不定你也会在比赛中取得令人惊喜的成绩。加油!


http://www.ppmy.cn/devtools/161504.html

相关文章

2007年诺基亚内部对iPhone的竞争分析报告

2007年iPhone发布后,诺基亚内部至少有9名员工指出其触屏界面、互联网整合能力将颠覆市场,并建议开发同类产品,但高管因当时占据全球50%市场份额而轻视威胁,认为苹果的高价和虚拟键盘会限制其普及。 诺基亚虽然意识到需推出触屏手机…

基于ffmpeg+openGL ES实现的视频编辑工具-添加背景音乐(十一)

在视频编辑领域,为视频添加背景音乐并实现音频的完美融合是一项关键任务。在上一篇文章中,我们大体介绍了添加背景音乐的整体逻辑,而本文将深入探讨其中音频合并所依赖的滤镜逻辑,通过对相关代码的详细解读,揭示音频合并的核心技术。 一、音频合并滤镜类的初始化 AudioA…

Python Seaborn库使用指南:从入门到精通

1. 引言 Seaborn 是基于 Matplotlib 的高级数据可视化库,专为统计图表设计。它提供了更简洁的 API 和更美观的默认样式,能够轻松生成复杂的统计图表。Seaborn 在数据分析、机器学习和科学计算领域中被广泛使用。 本文将详细介绍 Seaborn 的基本概念、常用功能以及高级用法,…

无监督机器学习算法

K-均值聚类是一种常用的无监督机器学习算法,用于将数据集中的样本分成 K 个不同的簇。其工作原理如下: 1. 随机选择 K 个数据点作为初始的簇中心。 2. 将每个数据点分配到距离最近的簇中心所属的簇。 3. 更新每个簇的中心,即取该簇中所有数据…

Kafka面试题----如何保证Kafka消费者在消费过程中不丢失消息

合理配置消费者参数 enable.auto.commit:设置为 false,关闭自动提交偏移量。自动提交偏移量存在一定的时间间隔,在这个间隔内如果消费者出现异常,可能会导致部分消息被重复消费或者丢失。关闭自动提交后,由开发者手动…

鸿蒙开发深入浅出04(首页数据渲染、搜索、Stack样式堆叠、Grid布局、shadow阴影)

鸿蒙开发深入浅出04(首页数据渲染、搜索、Stack样式堆叠、Grid布局、shadow阴影) 1、效果展示2、ets/pages/Home.ets3、ets/views/Home/SearchBar.ets4、ets/views/Home/NavList.ets5、ets/views/Home/TileList.ets6、ets/views/Home/PlanList.ets7、后端…

PCF8591一次读取多条通道导致测量值不准确的原因及解决方法

使用PCF8591测量通道电压的时候,只测量一个通道电压是正常的,但是要测量两个通道的电压时,会异常显示。 产生原因 时序精度不够 PCF8591通过选择不同的通道进行模拟信号采样。每次转换前,通道的选择需要一定的时间,…

便携式动平衡仪Qt应用层详细设计说明书

便携式动平衡仪Qt应用层详细设计说明书 (DDD) 版本:1.1 日期:2023年10月 一、文档目录 系统概述应用层架构设计模块详细设计接口定义与数据流关键数据结构代码框架与实现测试计划附录 二、系统概述 2.1 功能需求 开机流程:长按电源键启动…