这是一个基于 PyQt5 和 TensorFlow 的手写数字识别程序,主要功能如下:
1.用户界面
- 主窗口:包含加载图片、识别、清除按钮,以及图片显示区域和结果展示区域。
- 图片显示:支持显示原始图片和处理后的图片。
- 结果展示:显示识别结果和置信度。
2. 核心功能
- 加载图片:用户可以通过点击“加载图片”按钮选择本地图片文件。
- 图片预处理:将图片转换为灰度图并进行二值化处理,调整大小为 28x28 像素。
- 数字识别:使用训练好的 CNN 模型对处理后的图片进行预测,识别出图片中的数字。
- 结果显示:在界面中显示识别结果和置信度。
3. 模型管理
- 模型加载:程序启动时会尝试加载已保存的模型文件 digit_model.h5。
- 模型训练:如果未找到模型文件,可点击digit_model.h5下载,当然程序也会自动使用 MNIST 数据集训练一个新的 CNN 模型。
4. 交互功能
- 加载图片:用户可以选择本地图片文件进行识别。
- 识别数字:点击“识别”按钮,程序会对加载的图片进行识别并显示结果。
- 清除显示:点击“清除”按钮,清空图片和结果显示区域。
5. 技术细节
- 图像处理:使用 OpenCV 进行图片的灰度化、二值化和大小调整。
- 模型架构:使用 TensorFlow 构建的 CNN 模型,包含卷积层、池化层、全连接层等。
- MNIST 数据集:使用 MNIST 数据集进行模型训练,包含 0-9 的手写数字图片。
6. 运行流程
用户启动程序,界面显示加载图片、识别、清除按钮。
- 用户点击“加载图片”按钮,选择本地图片文件。
- 程序对图片进行预处理,显示原始图片和处理后的图片。
- 用户点击“识别”按钮,程序使用 CNN 模型对图片进行预测,显示识别结果和置信度。
- 用户点击“清除”按钮,清空图片和结果显示区域。
7. 安装相关依赖库
- PyQt5:用于构建用户界面。
- TensorFlow:用于构建和训练 CNN 模型。
- OpenCV:用于图片处理。
- NumPy:用于数值计算。
具体代码如下:
-
python">import sys import numpy as np from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QFileDialog) from PyQt5.QtGui import QPixmap, QImage from PyQt5.QtCore import Qt import cv2 from tensorflow.keras.models import Sequential, load_model from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten, Dropout from tensorflow.keras.datasets import mnistclass DigitRecognizer(QMainWindow):def __init__(self):super().__init__()self.setWindowTitle("手写数字识别系统")self.setGeometry(100, 100, 800, 600)# 初始化UIself.init_ui()# 加载或训练模型self.load_model()def init_ui(self):"""初始化用户界面"""# 创建主窗口部件和布局central_widget = QWidget()self.setCentralWidget(central_widget)layout = QVBoxLayout(central_widget)# 创建顶部按钮区域button_layout = QHBoxLayout()# 添加按钮self.btn_load = QPushButton("加载图片", self)self.btn_load.clicked.connect(self.load_image)button_layout.addWidget(self.btn_load)self.btn_recognize = QPushButton("识别", self)self.btn_recognize.clicked.connect(self.recognize_digit)button_layout.addWidget(self.btn_recognize)self.btn_clear = QPushButton("清除", self)self.btn_clear.clicked.connect(self.clear_display)button_layout.addWidget(self.btn_clear)layout.addLayout(button_layout)# 创建显示区域display_layout = QHBoxLayout()# 原始图片显示self.image_label = QLabel()self.image_label.setMinimumSize(280, 280)self.image_label.setAlignment(Qt.AlignCenter)self.image_label.setStyleSheet("border: 2px solid black;")display_layout.addWidget(self.image_label)# 处理后的图片显示self.processed_label = QLabel()self.processed_label.setMinimumSize(280, 280)self.processed_label.setAlignment(Qt.AlignCenter)self.processed_label.setStyleSheet("border: 2px solid black;")display_layout.addWidget(self.processed_label)layout.addLayout(display_layout)# 结果显示self.result_label = QLabel("识别结果将在这里显示")self.result_label.setAlignment(Qt.AlignCenter)self.result_label.setStyleSheet("""QLabel {font-size: 24px;margin: 20px;padding: 10px;background-color: #f0f0f0;border-radius: 5px;}""")layout.addWidget(self.result_label)# 初始化图片变量self.current_image = Noneself.processed_image = Nonedef load_model(self):"""加载或训练模型"""try:# 尝试加载已保存的模型self.model = load_model('digit_model.h5')print("模型加载成功!")except:print("未找到保存的模型,开始训练新模型...")self.train_model()def train_model(self):"""训练新的模型"""# 加载MNIST数据集(x_train, y_train), (x_test, y_test) = mnist.load_data()# 数据预处理x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)x_train = x_train.astype('float32') / 255x_test = x_test.astype('float32') / 255# 创建模型self.model = Sequential([Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),MaxPooling2D(pool_size=(2, 2)),Conv2D(64, (3, 3), activation='relu'),MaxPooling2D(pool_size=(2, 2)),Dropout(0.25),Flatten(),Dense(128, activation='relu'),Dropout(0.5),Dense(10, activation='softmax')])# 编译模型self.model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型self.model.fit(x_train, y_train, epochs=5, batch_size=128, validation_split=0.1)# 保存模型self.model.save('digit_model.h5')print("模型训练完成并保存!")def load_image(self):"""加载图片"""file_name, _ = QFileDialog.getOpenFileName(self, "选择图片", "", "Image Files (*.png *.jpg *.jpeg *.bmp)")if file_name:# 读取图片self.current_image = cv2.imread(file_name)if self.current_image is None:self.result_label.setText("无法加载图片!")return# 显示原始图片height, width = self.current_image.shape[:2]bytes_per_line = 3 * widthq_image = QImage(self.current_image.data, width, height, bytes_per_line, QImage.Format_RGB888).rgbSwapped()pixmap = QPixmap.fromImage(q_image)scaled_pixmap = pixmap.scaled(280, 280, Qt.KeepAspectRatio)self.image_label.setPixmap(scaled_pixmap)# 预处理图片self.preprocess_image()def preprocess_image(self):"""预处理图片"""if self.current_image is None:return# 转换为灰度图gray = cv2.cvtColor(self.current_image, cv2.COLOR_BGR2GRAY)# 二值化_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)# 调整大小为28x28self.processed_image = cv2.resize(binary, (28, 28))# 显示处理后的图片height, width = self.processed_image.shapeq_image = QImage(self.processed_image.data, width, height, width, QImage.Format_Grayscale8)pixmap = QPixmap.fromImage(q_image)scaled_pixmap = pixmap.scaled(280, 280, Qt.KeepAspectRatio)self.processed_label.setPixmap(scaled_pixmap)def recognize_digit(self):"""识别数字"""if self.processed_image is None:self.result_label.setText("请先加载图片!")return# 准备数据image = self.processed_image.reshape(1, 28, 28, 1)image = image.astype('float32') / 255# 预测prediction = self.model.predict(image)digit = np.argmax(prediction[0])confidence = prediction[0][digit] * 100# 显示结果self.result_label.setText(f"识别结果: {digit} (置信度: {confidence:.2f}%)")def clear_display(self):"""清除显示"""self.image_label.clear()self.processed_label.clear()self.result_label.setText("识别结果将在这里显示")self.current_image = Noneself.processed_image = Nonedef main():app = QApplication(sys.argv)window = DigitRecognizer()window.show()sys.exit(app.exec_())if __name__ == "__main__":main()