前馈神经网络 (Feedforward Neural Network, FNN)

embedded/2024/11/18 10:02:19/

代码功能

网络定义:
使用 torch.nn 构建了一个简单的前馈神经网络
隐藏层使用 ReLU 激活函数,输出层使用 Sigmoid 函数(适用于二分类问题)。
数据生成:
使用经典的 XOR 问题作为数据集。
数据点为二维输入,目标为 0 或 1。
训练过程:
使用二分类交叉熵损失函数 BCELoss。
优化器为 Adam,具有较快的收敛速度。
损失可视化:
每次训练后记录损失并绘制损失曲线。
结果输出:
显示最终预测值,并与真实标签进行比较。
在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt# 1. 定义前馈神经网络
class FeedforwardNN(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(FeedforwardNN, self).__init__()self.fc = nn.Sequential(nn.Linear(input_dim, hidden_dim),  # 输入层到隐藏层nn.ReLU(),  # 激活函数nn.Linear(hidden_dim, output_dim),  # 隐藏层到输出层nn.Sigmoid()  # 输出层的激活函数(适用于二分类问题))def forward(self, x):return self.fc(x)# 2. 创建 XOR 数据集
def create_xor_data():X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)y = np.array([[0], [1], [1], [0]], dtype=np.float32)return X, y# 3. 训练前馈神经网络
def train_fnn():# 数据准备X, y = create_xor_data()X = torch.tensor(X, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32)# 初始化网络、损失函数和优化器input_dim = X.shape[1]hidden_dim = 10output_dim = 1model = FeedforwardNN(input_dim, hidden_dim, output_dim)criterion = nn.BCELoss()  # 二分类交叉熵损失optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练网络epochs = 1000loss_history = []for epoch in range(epochs):# 前向传播outputs = model(X)loss = criterion(outputs, y)# 反向传播与优化optimizer.zero_grad()loss.backward()optimizer.step()# 记录损失loss_history.append(loss.item())if (epoch + 1) % 100 == 0:print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")# 绘制损失曲线plt.plot(loss_history)plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.show()# 输出训练结果with torch.no_grad():predictions = model(X).round()print("Predictions:", predictions.numpy())print("Ground Truth:", y.numpy())# 运行训练
if __name__ == "__main__":train_fnn()

http://www.ppmy.cn/embedded/138505.html

相关文章

MYSQL- 展示事件信息 EVENTS 语句(十八)

13.7.5.18 SHOW EVENTS 语句 SHOW EVENTS[{FROM | IN} schema_name][LIKE pattern | WHERE expr]此语句显示有关事件管理器事件的信息,这些信息在第23.4节“使用事件调度器”中进行了讨论。它要求显示事件的数据库具有EVENT权限。 以最简单的形式,SHOW…

【第四课】rust声明式宏理解与实战

目录 前言 理解宏 实战宏 前言 上一课在介绍vector时,我们再一次提到了rust中的宏,在初始化vector时使用了vec!宏,当时补了一句有机会会好好说明一下rust中的宏,并且写一个hashmap宏来初始化hashmap。想了想一直介绍基本语法还是比较枯燥乏味的,所以这节课我们介绍一点…

【B+树特点】

B树的特点 B树是B树的一种变体,广泛用于数据库系统和文件系统中,特别是在索引结构中。B树在B树的基础上进行了优化,主要在数据存储和查询效率上有所提升。以下是B树的主要特点: 1. 所有数据存储在叶子节点 与B树不同&#xff0…

HarmonyOS:使用常用组件构建页面

一、常用组件简介 1.1 Button 1.2 Text 1.4 Image 1.5 线性布局 (Row / Column) 1.6 列表(List/ ListItem) List 列表包含一系列相同宽度的列表项。适合连续、多行呈现同类数据,例如图片和文本。 ListItem 用来展示列表…

鸿蒙开发应用权限管理

简介 一种允许应用访问系统资源(如:通讯录等)和系统能力(如:访问摄像头、麦克风等)的通用权限访问方式,来保护系统数据(包括用户个人数据)或功能,避免它们被…

【大模型】大模型RAG检索增强生成技术使用详解

目录 一、前言 二、RAG技术介绍 2.1 RAG是什么 2.2 RAG工作原理 2.3 RAG优势 2.4 RAG应用场景 三、在线大模型平台RAG技术使用 3.1 阿里百炼平台 3.1.1 创建知识库 3.1.2 导入文档数据 3.1.3 文档数据解析 3.1.4 查看数据 3.2 百度文心智能体 3.2.1 创建知识库 3…

【设计模式】行为型模式(五):解释器模式、访问者模式、依赖注入

《设计模式之行为型模式》系列,共包含以下文章: 行为型模式(一):模板方法模式、观察者模式行为型模式(二):策略模式、命令模式行为型模式(三):责…

mybatis-flex

背景: mybatis-plus 出现那么久,多表查询这块一直没有进展, mybatis-flex它出现了 总结:mybatis-flex在链式调用没有mybatis-plus做得好,mp是key-value形式入参,mf分开了显得代码冗余,mf好在支…