手写西瓜书bp神经网络 mnist10 c#版本

news/2024/12/21 2:53:58/

本文根据西瓜书第五章中给出的公式编写,书中给出了全连接神经网络的实现逻辑,本文在此基础上编写了Mnist10手写10个数字的案例,网上也有一些其他手写的例子参考。demo并没有用UE而是使用unity进行编写,方便且易于查错。
该案例仅作为学习,博主也只是业余时间自学一些机器学习知识,欢迎各位路过的大佬留下建议~

测试效果:
在这里插入图片描述

源码下载地址:
https://download.csdn.net/download/grayrail/87802798

1.符号的意义

首先理顺西瓜书第五章中的各符号的意义:

  • x x x 输入的原始值
  • y y y 输出的原始值
  • y ^ \hat{y} y^ 输出的预测值
  • η eta,学习率,取值在0-1之间
  • d d d 输入层神经元的个数
  • q q q 隐层神经元的个数
  • l l l 输出层神经元的个数
  • i i i 输入层相关的下标索引
  • h h h 隐层相关的下标索引
  • j j j 输出层相关的下标索引
  • v i , h v_{i,h} vi,h 输入层隐层神经元的连接权
  • w i , h w_{i,h} wi,h 隐层神经元到输出层神经元的连接权
  • γ h γ_{h} γh gamma, 隐层神经元的阈值(阈值即y=ax+b中的b)
  • θ j θ_{j} θj theta, 输出层神经元的阈值(阈值即y=ax+b中的b)
  • α h α_{h} αh alpha, 隐层接收到的输入,书中公式 α h = ∑ i = 1 d v i , h x i α_{h}=\sum_{i=1}^{d} v_{i,h}x_{i} αh=i=1dvi,hxi
  • β j β_{j} βj beta,输出层接收到的输入,书中公式 β j = ∑ h = 1 q w h , j b h β_{j}=\sum_{h=1}^{q} w_{h,j}b_{h} βj=h=1qwh,jbh
  • b h b_{h} bh 存放隐层经过激活函数后的值,数组长度是隐层长度
  • y h a t s yhats yhats 存放输出层经过激活函数后的值,数组长度是输出层长度,书中没有,但是需要这样一个集合
  • g j g_{j} gj 反向传播g项,数组长度是输出层长度
  • e j e_{j} ej 反向传播e项,数组长度是隐层长度

2.正向传播(ForwardPropagation)

西瓜书第五章直接讲了反向传播,所以在这之前简单讲一下正向传播。

输入层(d)
x1
x2
x3
隐层(q)
1
x1*vi1
x2*vi1
x3*vi1
2
x1*vi2
x2*vi2
x3*vi2
...
输出层(l)
(隐层1*wh1,隐层2*wh1...)yhat1
同上yhat2
同上yhat3
同上yhat4

以上图为例,输入层的维度是[3],隐层的维度是[n,3],输出层的维度是[4,n],因此最终输出维度是[4]。输入层通常就是原始输入的信息,隐层用于超参数与中间环节计算,隐层的维度是n * m,m是输入层的数据自身维度,n可以理解为n种可能性(博主自己的理解),例如隐层的第二个维度是50,那么就是假设了50种可能性进行训练。

基于此,那么正向传播的流程如下:

  1. 初始化隐层的连接权重v,维度1是输入数据的长度,维度2是有多少种可能性,维度2可以自己填一个适合的值。
  2. 以隐层的长度q进行循环,对每个h维度下的集合进行点乘求和,每个元素 x i x_{i} xi乘以 v i , h v_{i,h} vi,h(即西瓜书中: α h = ∑ i = 1 d v i , h x i α_{h}=\sum_{i=1}^{d} v_{i,h}x_{i} αh=i=1dvi,hxi),然后减去阈值γ并传入激活函数,写入b集合。
  3. 以输出层的长度l进行循环,每个l维度下的存放着所有可能性的点乘结果(博主自己的理解),即西瓜书中的: β j = ∑ i = 1 q w h , j b h β_{j}=\sum_{i=1}^{q} w_{h,j}b_{h} βj=i=1qwh,jbh。然后减去阈值θ并传入激活函数,写入yhats集合。
  4. 可以再对yhats加一个softmax操作,筛选集合中的最大值返回下标索引,即输出结果。

3.反向传播(BackPropagation)

反向传播的难点之一是链式求导,西瓜书中已经帮我们把求导过程写好了,这里我先讲tips,再梳理反向传播流程。

3.1 E的公式乘以1/2的问题

在这里插入图片描述
这个直接问chat gpt:
∂ M S E ∂ w i j = ∂ ∂ w i j [ 1 2 n ∑ k = 1 n ( y k − y ^ k ) 2 ] = − 1 n ( y i − y ^ i ) f ′ ( β j ) x i \frac{\partial MSE}{\partial w_{ij}} = \frac{\partial}{\partial w_{ij}} \left[ \frac{1}{2n} \sum_{k=1}^n (y_k - \hat{y}_k)^2 \right] = -\frac{1}{n}(y_i - \hat{y}_i)f'(\beta_j)x_i wijMSE=wij[2n1k=1n(yky^k)2]=n1(yiy^i)f(βj)xi

ChatGPT:显然,在这个偏导数公式中出现了一个因子 1 n \frac{1}{n} n1
,而这个因子的存在是由于我们将MSE除以了2所致。如果不将MSE除以2,那么这个因子就会变成 1 2 n \frac{1}{2n} 2n1
,这在后续计算中会带来一些不必要的复杂性和麻烦。

3.2 公式梳理

书中的公式有点乱,下面给出按照顺序的梳理图:

3.2.1 W对于E的偏导数

在这里插入图片描述

3.2.2 V对于E的偏导数

在这里插入图片描述
这一部分应该是精髓所在,不是非常了解,不多加评论。

3.2.3 流程梳理

基于此,那么反向传播的流程如下:

  1. 单独对 g g g项求值并存入数组,用真实值 y y y和预测值 y ^ \hat{y} y^带入Sigmoid的导数公式
  2. 单独对 e e e项求值并存入数组
  3. 计算阈值(bias偏置) θ θ θ的delta量并赋值, Δ θ j = − η g j Δθ_{j}=-ηg_{j} Δθj=ηgj
  4. 计算 w w w的delta量并赋值, Δ w h , j = η g j b h Δw_{h,j}=ηg_{j}b_{h} Δwh,j=ηgjbh(注意这里的b是上一层的最终输出,不是bias)
  5. 计算阈值(bias偏置) γ γ γ的delta量并赋值, Δ γ j = − η e h Δγ_{j}=-ηe_{h} Δγj=ηeh
  6. 计算 v v v的delta量并赋值, Δ v i , h = η e h x i Δv_{i,h}=ηe_{h}x_{i} Δvi,h=ηehxi

4.代码实现

以Mnist案例为例,该案例使用神经网络识别28x28像素内图片的0-9个手写数字,接下来给出C#版本的Mnist代码实现,脚本挂载后有3种模式:
在这里插入图片描述

  • Draw Image Mode 用于绘制0-9个数字
  • User Mode 使用已经训练好的神经网络进行数字识别(没有做缓存的功能,需要手动先训练几次)
  • Train Mode 训练模式,DataPath中填入图片路径,图片格式首先取前缀,例如:3_04,表明这个图片真实值数字是3,是第4张备选图片

该案例在西瓜书的基础上又加入了momentum动量、softmax、Dropout、初始随机值范围修改(-1,1),softmax使用《深度学习入门 基于PYTHON的理论与实现》一书中提供的公式。经过一些轮次训练后的运行结果:
在这里插入图片描述

c#代码如下:

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using UnityEngine;public class TestMnist10 : MonoBehaviour
{public enum EMode { Train, DrawImage, User }const float kDropoutProb = 0.4f;/// <summary>/// d个输入神经元/// </summary>int d;/// <summary>/// q个隐层神经元/// </summary>int q;/// <summary>/// l个输出神经元/// </summary>int l;/// <summary>/// 输入层原始值/// </summary>float[] x;/// <summary>/// 输入层到隐层神经元的连接权/// </summary>float[][] v;/// <summary>/// 缓存上一次的v权值/// </summary>float[][] lastVMomentum;/// <summary>/// 隐层神经元到输出层神经元的连接权/// </summary>float[][] w;/// <summary>/// 缓存上一次的w权值/// </summary>float[][] lastWMomentum;float[] wDropout;/// <summary>/// 反向传播g项/// </summary>float[] g;/// <summary>/// 反向传播e项/// </summary>float[] e;/// <summary>/// 隐层接收到的输入(通常List长度是隐层长度)/// </summary>List<float> b;/// <summary>/// 输出层接收到的输入(通常List长度是输出层长度)/// </summary>List<float> yhats;/// <summary>/// 输出层神经元的阈值/// </summary>float[] theta;/// <summary>/// 隐层神经元的阈值/// </summary>float[] gamma;public void Init(int inputLayerCount, int hiddenLayerCount, int outputLayerCount){d = inputLayerCount;q = hiddenLayerCount;l = outputLayerCount;x = new float[inputLayerCount];b = new List<float>(1024);yhats = new List<float>(1024);e = new float[hiddenLayerCount];g = new float[outputLayerCount];v = GenDimsArray(typeof(float), new int[] { q, d }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[][];w = GenDimsArray(typeof(float), new int[] { l, q }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[][];wDropout = GenDimsArray(typeof(float), new int[] { l }, 0, null) as float[];lastVMomentum = GenDimsArray(typeof(float), new int[] { q, d }, 0, null) as float[][];lastWMomentum = GenDimsArray(typeof(float), new int[] { q, d }, 0, null) as float[][];theta = GenDimsArray(typeof(float), new int[] { l }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[];gamma = GenDimsArray(typeof(float), new int[] { q }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[];}public void ForwardPropagation(float[] input, out int output){x = input;for (int jIndex = 0; jIndex < l; ++jIndex){var r = UnityEngine.Random.value < kDropoutProb ? 1f : 0f;wDropout[jIndex] = r;}b.Clear();for (int hIndex = 0; hIndex < q; ++hIndex){var sum = 0f;for (int iIndex = 0; iIndex < d; ++iIndex){var u = input[iIndex] * v[hIndex][iIndex];sum += u;}var alpha = sum - gamma[hIndex];var r = Sigmoid(alpha);b.Add(r);}yhats.Clear();for (int jIndex = 0; jIndex < l; ++jIndex){var sum = 0f;for (int hIndex = 0; hIndex < q; ++hIndex){var u = b[hIndex] * w[jIndex][hIndex];sum += u;}var beta = sum - theta[jIndex];var r = Sigmoid(beta);//实际使用时关闭Dropout,训练时打开if (_EnableDropout){r *= wDropout[jIndex];r /= kDropoutProb;}yhats.Add(r);}var softmaxResult = Softmax(yhats.ToArray());for (int i = 0; i < yhats.Count; i++){yhats[i] = softmaxResult[i];}int index = 0;float maxValue = -999999f;for (int jIndex = 0; jIndex < l; ++jIndex){if (yhats[jIndex] > maxValue){maxValue = yhats[jIndex];index = jIndex;}}output = index;}public void BackPropagation(float[] correct){const float kEta1 = 0.03f;const float kEta2 = 0.01f;const float kMomentum = 0.3f;for (int jIndex = 0; jIndex < l; ++jIndex){var yhat = this.yhats[jIndex];var y = correct[jIndex];g[jIndex] = yhat * (1f - yhat) * (y - yhat);}for (int hIndex = 0; hIndex < q; ++hIndex){var bh = b[hIndex];var sum = 0f;//这个for循环的内容,个人感觉是精妙之处,可以拿到别的神经元的梯度。for (int jIndex = 0; jIndex < l; ++jIndex)sum += w[jIndex][hIndex] * g[jIndex];e[hIndex] = bh * (1f - bh) * sum;}for (int jIndex = 0; jIndex < l; ++jIndex){theta[jIndex] += -kEta1 * g[jIndex];}for (int hIndex = 0; hIndex < q; ++hIndex){for (int jIndex = 0; jIndex < l; ++jIndex){var bh = b[hIndex];var delta = kMomentum * lastWMomentum[jIndex][hIndex] + kEta1 * g[jIndex] * bh;//实际使用时关闭Dropout,训练时打开if (_EnableDropout){var dropout = wDropout[jIndex];delta *= dropout;delta /= kDropoutProb;}w[jIndex][hIndex] += delta;lastWMomentum[jIndex][hIndex] = delta;}}for (int hIndex = 0; hIndex < q; ++hIndex){gamma[hIndex] += -kEta2 * e[hIndex];}for (int hIndex = 0; hIndex < q; ++hIndex){for (int iIndex = 0; iIndex < d; ++iIndex){var delta = kMomentum * lastVMomentum[hIndex][iIndex] + kEta2 * e[hIndex] * x[iIndex];v[hIndex][iIndex] += delta;lastVMomentum[hIndex][iIndex] = delta;}}}void Start(){Init(784, 64, 10);}EMode _Mode;int[] _DrawNumberImage;bool _EnableDropout;string _DataPath;float Sigmoid(float val){return 1f / (1f + Mathf.Exp(-val));}float[] Softmax(float[] inputs){float[] outputs = new float[inputs.Length];float maxInput = inputs.Max();for (int i = 0; i < inputs.Length; i++){outputs[i] = Mathf.Exp(inputs[i] - maxInput);}float expSum = outputs.Sum();for (int i = 0; i < outputs.Length; i++){outputs[i] /= expSum;}return outputs;}float[] GetOneHot(string input){if (input.StartsWith("0"))return new float[] { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 };if (input.StartsWith("1"))return new float[] { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 };if (input.StartsWith("2"))return new float[] { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 };if (input.StartsWith("3"))return new float[] { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 };if (input.StartsWith("4"))return new float[] { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 };if (input.StartsWith("5"))return new float[] { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 };if (input.StartsWith("6"))return new float[] { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 };if (input.StartsWith("7"))return new float[] { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 };if (input.StartsWith("8"))return new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 };elsereturn new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };}void Shuffle<T>(List<T> cardList){int tempIndex = 0;T temp = default;for (int i = 0; i < cardList.Count; ++i){tempIndex = UnityEngine.Random.Range(0, cardList.Count);temp = cardList[tempIndex];cardList[tempIndex] = cardList[i];cardList[i] = temp;}}/// <summary>/// 快速得到多维数组/// </summary>Array GenDimsArray(Type type, int[] dims, int deepIndex, Func<object> initFunc = null){if (deepIndex < dims.Length - 1){var sub_template = GenDimsArray(type, dims, deepIndex + 1, null);var current = Array.CreateInstance(sub_template.GetType(), dims[deepIndex]);for (int i = 0; i < dims[deepIndex]; ++i){var sub = GenDimsArray(type, dims, deepIndex + 1, initFunc);current.SetValue(sub, i);}return current;}else{var arr = Array.CreateInstance(type, dims[deepIndex]);if (initFunc != null){for (int i = 0; i < arr.Length; ++i)arr.SetValue(initFunc(), i);}return arr;}}void OnGUI(){if (_DrawNumberImage == null)_DrawNumberImage = new int[784];GUILayout.BeginHorizontal();if (GUILayout.Button("Draw Image Mode")){_Mode = EMode.DrawImage;Array.Clear(_DrawNumberImage, 0, _DrawNumberImage.Length);}if (GUILayout.Button("User Mode")){_Mode = EMode.User;Array.Clear(_DrawNumberImage, 0, _DrawNumberImage.Length);}if (GUILayout.Button("Train Mode")){_Mode = EMode.Train;_DataPath = Directory.GetCurrentDirectory() + "/TrainData";}GUILayout.EndHorizontal();var lastRect = GUILayoutUtility.GetLastRect();switch (_Mode){case EMode.Train:{GUILayout.BeginHorizontal();GUILayout.Label("Data Path: ");_DataPath = GUILayout.TextField(_DataPath);GUILayout.EndHorizontal();_EnableDropout = GUILayout.Button("dropout(" + (_EnableDropout ? "True" : "False") + ")")? !_EnableDropout : _EnableDropout;if (GUILayout.Button("Train 10")){var files = Directory.GetFiles(_DataPath);List<(string, float[])> datas = new(512);for (int i = 0; i < files.Length; ++i){var strArr = File.ReadAllText(files[i]).Split(',');datas.Add((Path.GetFileNameWithoutExtension(files[i]), Array.ConvertAll(strArr, m => float.Parse(m))));}for (int s = 0; s < 10; ++s){Shuffle(datas);for (int i = 0; i < datas.Count; ++i){ForwardPropagation(datas[i].Item2, out int output);UnityEngine.Debug.Log("<color=#00ff00> Input Number: " + datas[i].Item1 + " output: " + output + "</color>");BackPropagation(GetOneHot(datas[i].Item1));//break;}}}}break;case EMode.DrawImage:{lastRect.y += 50f;var size = 20f;var spacing = 2f;var mousePosition = Event.current.mousePosition;var mouseLeftIsPress = Input.GetMouseButton(0);var mouseRightIsPress = Input.GetMouseButton(1);var containSpacingSize = size + spacing;for (int y = 0, i = 0; y < 28; ++y){for (int x = 0; x < 28; ++x){var rect = new Rect(lastRect.x + x * containSpacingSize, lastRect.y + y * containSpacingSize, size, size);GUI.DrawTexture(rect, _DrawNumberImage[i] == 1 ? Texture2D.blackTexture : Texture2D.whiteTexture);if (rect.Contains(mousePosition)){if (mouseLeftIsPress)_DrawNumberImage[i] = 1;else if (mouseRightIsPress)_DrawNumberImage[i] = 0;}++i;}}if (GUILayout.Button("Save")){File.WriteAllText(Directory.GetCurrentDirectory() + "/Assets/tmp.txt", string.Join(",", _DrawNumberImage));}}break;case EMode.User:{lastRect.y += 150f;var size = 20f;var spacing = 2f;var mousePosition = Event.current.mousePosition;var mouseLeftIsPress = Input.GetMouseButton(0);var mouseRightIsPress = Input.GetMouseButton(1);var containSpacingSize = size + spacing;for (int y = 0, i = 0; y < 28; ++y){for (int x = 0; x < 28; ++x){var rect = new Rect(lastRect.x + x * containSpacingSize, lastRect.y + y * containSpacingSize, size, size);GUI.DrawTexture(rect, _DrawNumberImage[i] == 1 ? Texture2D.blackTexture : Texture2D.whiteTexture);if (rect.Contains(mousePosition)){if (mouseLeftIsPress)_DrawNumberImage[i] = 1;else if (mouseRightIsPress)_DrawNumberImage[i] = 0;}++i;}}if (GUILayout.Button("Recognize")){ForwardPropagation(Array.ConvertAll(_DrawNumberImage, m => (float)m), out int output);Debug.Log("output: " + output);}break;}}}
}

参考文章

  • Java实现BP神经网络MNIST手写数字识别https://www.cnblogs.com/baby7/p/java_bp_neural_network_number_identification.html
  • 反向传播算法对照 https://zhuanlan.zhihu.com/p/605765790

http://www.ppmy.cn/news/74588.html

相关文章

关于我被敲诈勒索骗了 1w 多这件事

大家好&#xff0c;我是程序员贺同学。 昨晚遭遇了人生中第一次诈骗&#xff0c;损失金额 1w多&#xff0c;趁这两天情绪缓了缓&#xff0c;把过程记录了下来&#xff0c;希望对看到的人有所帮助。 昨晚报完警回来快 23 点&#xff0c;把手机上的重要图片&#xff0c;视频&…

NLP大模型微调原理

1. 背景 LLM (Large Language Model) 大型语言模型&#xff0c;旨在理解和生成人类语言&#xff0c;需要在大量的文本数据上进行训练。一般基于Transformer结构&#xff0c;拥有Billion以上级别的参数量。比如GPT-3(175B)&#xff0c;PaLM(560B)。 NLP界发生三件大事&#xff…

【泛微ecology】上海CA在ecology中部署要求

ca证书部署后是否可以满足要求&#xff0c;不仅要部署ca证书&#xff0c;开启https&#xff0c;还要开启下rsa传输加密 开启RSA 登陆(只针对登录中的用户名密码加密) RSA加密方案&#xff1a;确认客户的ecology版本是否是KB8100180100以上版本&#xff0c;如果是的话,按照如下…

通过 Amazon S3 生命周期策略降低存储成本

如今信息化高速发展的时代&#xff0c;大部分的企业均已上云&#xff0c;企业内部产生的大量数据也都是通过云存储为企业实现了便捷的存储服务,上海亚极云为了用户能够使用性价比较高的云存储方案&#xff0c;特向用户推荐Amazon S3的存储方案。 Amazon S3的存储方案成本低、扩…

理解canvas元素,绘制简单2D绘图形

一、canvas 简介 图形和动画已经日益成为浏览器中现代 Web 应用程序的必备功能,但实现起来仍然比较困难,既要兼顾美观又不能拖慢浏览器。目前已经有一套日趋完善的 APT 和工具可以用来开发此类功能。 毋庸置疑,<canvas>是 HTML5 最受欢迎的新特性。​Canvas 是由 HT…

618前夜,电商物流「涌向」B2B战场

随着终端交易场景的增长红利消失殆尽&#xff0c;电商平台需要在产业侧寻找到新的企业支点&#xff0c;这里的背景布不再是熟悉的电商战场&#xff0c;而是红海重重的B2B场域。 作者|斗斗 编辑|皮爷 出品|产业家 电商平台开始在B端寻找新的交易环节。 随着人口红利逐渐…

【HISI IC萌新虚拟项目】Package Process Unit模块整体方案·PART3

5. 模块方案说明 5.1CRG 模块方案说明 5.1.1简介 CRG 模块实现复位信号的滤抖功能,可滤除小于100ns的低电平复位毛刺,并对复位信号进行同步化处理。同时,对100MHz的输入时钟信号进行2分频,作为 CPU_IF模块和TEST_CORE模块的工作时钟。 5.1.2接口信号 信号位宽I/O描述

OpenVINO 2022.3实战三:POT API实现图像分类模型 INT8 量化

OpenVINO 2022.3实战三&#xff1a;POT API实现图像分类模型 INT8 量化 1 准备需要量化的模型 这里使用我其他项目里面&#xff0c;使用 hymenoptera 数据集训练好的 MobileNetV2 模型&#xff0c;加载pytorch模型&#xff0c;并转换为onnx。 import os from pathlib import…