【深度学习】常见模型-多层感知机(MLP,Multilayer Perceptron)

embedded/2025/1/23 3:51:28/

多层感知机(MLP)是一种经典的人工神经网络结构,由输入层、一个或多个隐藏层以及输出层组成。每一层中的神经元与前一层的所有神经元全连接,且各层间的权重是可学习的。MLP 是深度学习的基础模型之一,主要用于处理结构化数据、分类任务和回归任务等。


特点

  1. 全连接层:每个神经元与前一层的所有神经元连接,能学习复杂的非线性关系。
  2. 激活函数:隐藏层和输出层的每个神经元使用非线性激活函数(如 ReLU(/ray-loo/)、Sigmoid(/ˈsɪɡˌmɔɪd/)、Tanh(/tænʃ/))来增加模型的表达能力。
  3. 前向传播和反向传播
    • 前向传播:输入数据经过各层加权求和并通过激活函数计算输出。
    • 反向传播:根据损失函数计算梯度,并通过优化算法(如 SGD、Adam)更新权重。
  4. 通用逼近性:理论上,MLP 可以逼近任意连续函数。

MLP 的结构

  • 输入层:接受输入数据的特征向量,大小等于数据的维度。
  • 隐藏层:一个或多个隐藏层,每层由若干神经元组成,负责特征提取和非线性映射。
  • 输出层:根据任务类型输出结果,例如分类问题中每类的概率或回归问题中的预测值。

代码示例

以下是一个使用 TensorFlow/Keras(/'kɜːrəz/或/'kiərəz/) 实现简单 MLP 的代码示例:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input# 定义多层感知机模型
def build_mlp(input_dim, hidden_units, output_dim):model = Sequential()# 输入层和第一层隐藏层model.add(Input(shape=(input_dim,)))model.add(Dense(hidden_units, activation='relu'))# 添加第二层隐藏层(可选)model.add(Dense(hidden_units, activation='relu'))# 输出层model.add(Dense(output_dim, activation='softmax'))  # 对分类问题使用 softmaxreturn model# 模型参数
input_dim = 20   # 输入维度
hidden_units = 64  # 隐藏层神经元数
output_dim = 3   # 输出类别数# 构建模型
mlp_model = build_mlp(input_dim, hidden_units, output_dim)# 编译模型
mlp_model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',  # 分类问题常用损失函数metrics=['accuracy']
)# 查看模型结构
mlp_model.summary()# 生成模拟数据
import numpy as np
X_train = np.random.rand(1000, input_dim)
y_train = np.random.randint(0, output_dim, 1000)# 训练模型
mlp_model.fit(X_train, y_train, epochs=10, batch_size=32, validation_split=0.2)

 运行结果

Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================dense (Dense)               (None, 64)                1344      dense_1 (Dense)             (None, 64)                4160      dense_2 (Dense)             (None, 3)                 195       =================================================================
Total params: 5699 (22.26 KB)
Trainable params: 5699 (22.26 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Epoch 1/10
25/25 [==============================] - 1s 10ms/step - loss: 1.1018 - accuracy: 0.3525 - val_loss: 1.1075 - val_accuracy: 0.3250
Epoch 2/10
25/25 [==============================] - 0s 3ms/step - loss: 1.0912 - accuracy: 0.3675 - val_loss: 1.1097 - val_accuracy: 0.3250
Epoch 3/10
25/25 [==============================] - 0s 2ms/step - loss: 1.0850 - accuracy: 0.3875 - val_loss: 1.1122 - val_accuracy: 0.3300
Epoch 4/10
25/25 [==============================] - 0s 7ms/step - loss: 1.0828 - accuracy: 0.4013 - val_loss: 1.1190 - val_accuracy: 0.2750
Epoch 5/10
25/25 [==============================] - 0s 3ms/step - loss: 1.0762 - accuracy: 0.4137 - val_loss: 1.1256 - val_accuracy: 0.2750
Epoch 6/10
25/25 [==============================] - 0s 3ms/step - loss: 1.0713 - accuracy: 0.4200 - val_loss: 1.1145 - val_accuracy: 0.3150
Epoch 7/10
25/25 [==============================] - 0s 3ms/step - loss: 1.0739 - accuracy: 0.4300 - val_loss: 1.1220 - val_accuracy: 0.2950
Epoch 8/10
25/25 [==============================] - 0s 3ms/step - loss: 1.0697 - accuracy: 0.4187 - val_loss: 1.1218 - val_accuracy: 0.2950
Epoch 9/10
25/25 [==============================] - 0s 2ms/step - loss: 1.0628 - accuracy: 0.4487 - val_loss: 1.1229 - val_accuracy: 0.2800
Epoch 10/10
25/25 [==============================] - 0s 2ms/step - loss: 1.0549 - accuracy: 0.4550 - val_loss: 1.1352 - val_accuracy: 0.2650


重要概念与技巧

  1. 隐藏层数和神经元数:一般来说,更多的隐藏层和神经元可以提高模型复杂度,但可能导致过拟合。
  2. 激活函数:ReLU 是最常用的隐藏层激活函数,输出层的激活函数根据任务不同而选择(如分类使用 Softmax,回归使用线性激活)。
  3. 正则化:通过添加 L1/L2 正则化、Dropout 或 Batch Normalization,来防止过拟合。
  4. 优化器:选择合适的优化算法(如 Adam、SGD)加速训练过程。

应用场景

  1. 图像分类(通常是简单的低维数据集)
  2. 信号处理和时间序列分析
  3. 文本分类和自然语言处理中的简单任务
  4. 表格数据的分类和回归分析

MLP 是深度学习的起点,尽管它已被更复杂的模型(如卷积神经网络和 Transformer)所取代,但它依然是一个重要的理论基础。


http://www.ppmy.cn/embedded/156221.html

相关文章

SQL进阶——JOIN操作详解

在数据库设计中,数据通常存储在多个表中。为了从这些表中获取相关的信息,我们需要使用JOIN操作。JOIN操作允许我们通过某种关系(如相同的列)将多张表的数据结合起来。它是SQL中非常重要的操作,广泛应用于实际开发中。本…

第17个项目:Python烟花秀

源码下载地址:https://download.csdn.net/download/mosquito_lover1/90295693 核心源码: import pygame import random import math from PIL import Image import io # 初始化pygame pygame.init() # 设置窗口 WIDTH = 800 HEIGHT = 600 screen = pygame.display.s…

Text2SQL(NL2sql)对话数据库:设计、实现细节与挑战

Text2SQL(NL2sql)对话数据库:设计、实现细节与挑战 前言 1.何为Text2SQL(NL2sql)2.Text2SQL结构与挑战3.金融领域实际业务场景4.注意事项5.总结 前言 随着信息技术的迅猛发展,人机交互的方式也在不断演…

K8S中Pod控制器之Job控制器

Job,主要用于负责批量处理(一次要处理指定数量任务)短暂的一次性(每个任务仅运行一次就结束)任务。 一次性任务:Job 用于运行那些只需要执行一次的任务,如数据分析、图像渲染或批量处理。 成功终止:Job 会跟踪其创建的 Pod 的成功…

【网络协议】【http】【https】TLS1.3

【网络协议】【http】【https】TLS1.3 TLS1.3它的签名算法和密钥交换算法,默认情况下是被固定了下来的,他的加密套件里面呢,只包含了对称加密算法和摘要算法 客户端和服务器第一次连接 仍然需要1RTT ,不能0-RTT 第一次连接 1.客…

基于微信小程序教学辅助系统设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

SSM旅游信息管理系统

🍅点赞收藏关注 → 添加文档最下方联系方式可咨询本源代码、数据库🍅 本人在Java毕业设计领域有多年的经验,陆续会更新更多优质的Java实战项目希望你能有所收获,少走一些弯路。🍅关注我不迷路🍅 项目视频 …

Unity通过脚本对指定物体进行指定脚本的挂载,并初始化挂载脚本中变量

using System.Collections; using System.Collections.Generic; using Unity.XR.CoreUtils; using UnityEditor; using UnityEngine;public class AutoDetectScript : MonoBehaviour {public GameObject[] games; //[ContextMenu("一键挂载脚本")]public void UpScri…