tensorflow + pygame 手写数字识别的小游戏

embedded/2024/10/20 9:33:19/

起因, 目的:

很久之前,一个客户的作业,我帮忙写的。
今天删项目,觉得比较简洁,发出来给大家看看。

效果图:

在这里插入图片描述

1. 训练模型的代码
import sys
import tensorflow as tf# Use MNIST handwriting dataset
mnist = tf.keras.datasets.mnist# Prepare data for training
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 1
)
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], x_test.shape[2], 1
)"""
model = tf.keras.models.Sequential([tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),tf.keras.layers.MaxPooling2D((2, 2)),tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10, activation='softmax')
])"""# Create a convolutional neural network
model = tf.keras.models.Sequential([# 1.  Convolutional layer. Learn 32 filters using a 3x3 kernel, activation function is relu, input shape (28,28,1)tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),#2. Max-pooling layer, using 2x2 pool sizetf.keras.layers.MaxPooling2D((2, 2)),#3.  Flatten unitstf.keras.layers.Flatten(),#4. Add a hidden layer with dropout,tf.keras.layers.Dropout(0.2),#5. Add an output layer with output units for all 10 digits, activation function is softmaxtf.keras.layers.Dense(10, activation='softmax')])# Train neural network
model.compile(optimizer="adam",loss="categorical_crossentropy",metrics=["accuracy"]
)
model.fit(x_train, y_train, epochs=10)# Evaluate neural network performance
model.evaluate(x_test,  y_test, verbose=2)# Save model to file
if len(sys.argv) == 2:filename = sys.argv[1]model.save(filename)print(f"Model saved to {filename}.")"""
Run this code:  python handwriting.py model_1.pthoutput:1875/1875 [==============================] - 10s 5ms/step - loss: 0.0413 - accuracy: 0.9873
Epoch 8/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0385 - accuracy: 0.9877
Epoch 9/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0338 - accuracy: 0.9898
Epoch 10/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0319 - accuracy: 0.9900
313/313 - 1s - loss: 0.0511 - accuracy: 0.9845 - 718ms/epoch - 2ms/step
Model saved to model_1.pth.
"""
2. 运行小游戏, 进行识别

从命令行运行:

python recognition.py model.h5


import numpy as np
import pygame
import sys
import tensorflow as tf
import time"""
run this code:
python recognition.py  model_1.pth or,  python recognition.py  model.h5 output:
"""print("len(sys.argv): ", len(sys.argv))# Check command-line arguments
if len(sys.argv) != 2:print("Usage: python recognition.py model")sys.exit()model = tf.keras.models.load_model(sys.argv[1])# Colors
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)# Start pygame
pygame.init()
size = width, height = 600, 400
screen = pygame.display.set_mode(size)# Fonts
OPEN_SANS = "assets/fonts/OpenSans-Regular.ttf"
smallFont = pygame.font.Font(OPEN_SANS, 20)
largeFont = pygame.font.Font(OPEN_SANS, 40)ROWS, COLS = 28, 28OFFSET = 20
CELL_SIZE = 10handwriting = [[0] * COLS for _ in range(ROWS)]
classification = Nonewhile True:# Check if game quitfor event in pygame.event.get():if event.type == pygame.QUIT:sys.exit()screen.fill(BLACK)# Check for mouse pressclick, _, _ = pygame.mouse.get_pressed()if click == 1:mouse = pygame.mouse.get_pos()else:mouse = None# Draw each grid cellcells = []for i in range(ROWS):row = []for j in range(COLS):rect = pygame.Rect(OFFSET + j * CELL_SIZE,OFFSET + i * CELL_SIZE,CELL_SIZE, CELL_SIZE)# If cell has been written on, darken cellif handwriting[i][j]:channel = 255 - (handwriting[i][j] * 255)pygame.draw.rect(screen, (channel, channel, channel), rect)# Draw blank cellelse:pygame.draw.rect(screen, WHITE, rect)pygame.draw.rect(screen, BLACK, rect, 1)# If writing on this cell, fill in current cell and neighborsif mouse and rect.collidepoint(mouse):handwriting[i][j] = 250 / 255if i + 1 < ROWS:handwriting[i + 1][j] = 220 / 255if j + 1 < COLS:handwriting[i][j + 1] = 220 / 255if i + 1 < ROWS and j + 1 < COLS:handwriting[i + 1][j + 1] = 190 / 255# Reset buttonresetButton = pygame.Rect(30, OFFSET + ROWS * CELL_SIZE + 30,100, 30)resetText = smallFont.render("Reset", True, BLACK)resetTextRect = resetText.get_rect()resetTextRect.center = resetButton.centerpygame.draw.rect(screen, WHITE, resetButton)screen.blit(resetText, resetTextRect)# Classify buttonclassifyButton = pygame.Rect(150, OFFSET + ROWS * CELL_SIZE + 30,100, 30)classifyText = smallFont.render("Classify", True, BLACK)classifyTextRect = classifyText.get_rect()classifyTextRect.center = classifyButton.centerpygame.draw.rect(screen, WHITE, classifyButton)screen.blit(classifyText, classifyTextRect)# Reset drawingif mouse and resetButton.collidepoint(mouse):handwriting = [[0] * COLS for _ in range(ROWS)]classification = None# Generate classificationif mouse and classifyButton.collidepoint(mouse):classification = model.predict([np.array(handwriting).reshape(1, 28, 28, 1)]).argmax()# Show classification if one existsif classification is not None:classificationText = largeFont.render(str(classification), True, WHITE)classificationRect = classificationText.get_rect()grid_size = OFFSET * 2 + CELL_SIZE * COLSclassificationRect.center = (grid_size + ((width - grid_size) / 2),100)screen.blit(classificationText, classificationRect)pygame.display.flip()

完整项目,我已经上传了。 0积分下载。

完整项目链接
https://download.csdn.net/download/waterHBO/89881853

老哥留步,支持一下。

请求支持


http://www.ppmy.cn/embedded/128957.html

相关文章

uiautomatorviewer安卓9以上正常使用及问题处理

一、安卓9以上使用uiautomatorviewer问题现象 打开Unexpected error while obtaining UI hierarchy 问题详情 Unexpected error while obtaining UI hierarchy java.lang.reflect.InvocationTargetException 二、问题处理 需要的是替换对应D:\software\android-sdk-windows…

PicoQuant GmbH公司Dr. Christian Oelsner到访东隆科技

昨日&#xff0c;德国PicoQuant公司的光谱和显微应用和市场专家Dr.Christian Oelsner莅临武汉东隆科技有限公司。会议上Dr. Christian Oelsner就荧光寿命光谱和显微技术的最新研究和应用进行了深入的交流与探讨。此次访问不仅加强了两家公司在高科技领域的合作关系&#xff0c;…

HDLBits参考答案合集

关注 望森FPGA 查看更多FPGA资讯 这是望森的第 26 期分享 作者 | 望森 来源 | 望森FPGA 本节内容是HDLBits参考答案合集的索引。 恭喜HDLBits合集完结~ 敬请关注新的专栏内容“FPGA理论基础” 前言 FPGA新手必用&#xff0c;Verilog HDL编程学习网站推荐 —— HDLBits_veri…

如何在Matlab界面中添加文件选择器?

在Matlab中&#xff0c;为用户提供交互式文件选择功能是非常重要的&#xff0c;尤其是当你需要让用户从文件系统中选择文件进行进一步处理时。Matlab提供了uigetfile函数&#xff0c;允许用户通过图形界面选择文件。以下是如何在Matlab界面中添加文件选择器的详细指南&#xff…

单独配置LVS负载均衡服务器+web

注意&#xff1a; (1) lvs基于四层ip端口转发&#xff0c;不支持后的realserver配置多个虚拟主机&#xff0c;可以在lvs上配置基于不同ip端口的虚拟主机 (2) LVS-DR模式中&#xff0c;负载均衡服务器&#xff08;LB&#xff09;和后端真实服务器&#xff08;realserver&#x…

阿里Dataworks使用循环节点和赋值节点完成对mongodb分表数据同步

背景 需求将MongoDB数据入仓MaxCompute 环境说明 MongoDB 100个Collections&#xff1a;orders_1、orders_2、…、orders_100 前期准备 1、MongoDB数据源配置 需要先保证DW和MongoDB网络是能够联通的&#xff0c;需要现在集成任务中配置MongoDB的数据源信息。 具体可以查…

Perl打印9x9乘法口诀

本章教程主要介绍如何用Perl打印9x9乘法口诀。 一、程序代码 1、写法① use strict; # 启用严格模式&#xff0c;帮助捕捉变量声明等错误 use warnings; # 启用警告&#xff0c;帮助发现潜在问题# 遍历 1 到 9 的数字 for my $i (1..9) {# 对于每个 $i&#xff0c;遍历 1…

Django 序列化serializers

在Django中&#xff0c;序列化通常指的是将数据库中的模型数据转换为JSON、XML或其他格式的过程。Django提供了内置的序列化工具&#xff0c;可以通过django.core.serializers模块进行序列化操作。 当你使用Django的序列化功能时&#xff0c;可以序列化以下两种对象类型&#…