前馈神经网络 (Feedforward Neural Network, FNN)

devtools/2024/11/18 8:37:56/

代码功能

网络定义:
使用 torch.nn 构建了一个简单的前馈神经网络
隐藏层使用 ReLU 激活函数,输出层使用 Sigmoid 函数(适用于二分类问题)。
数据生成:
使用经典的 XOR 问题作为数据集。
数据点为二维输入,目标为 0 或 1。
训练过程:
使用二分类交叉熵损失函数 BCELoss。
优化器为 Adam,具有较快的收敛速度。
损失可视化:
每次训练后记录损失并绘制损失曲线。
结果输出:
显示最终预测值,并与真实标签进行比较。
在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt# 1. 定义前馈神经网络
class FeedforwardNN(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(FeedforwardNN, self).__init__()self.fc = nn.Sequential(nn.Linear(input_dim, hidden_dim),  # 输入层到隐藏层nn.ReLU(),  # 激活函数nn.Linear(hidden_dim, output_dim),  # 隐藏层到输出层nn.Sigmoid()  # 输出层的激活函数(适用于二分类问题))def forward(self, x):return self.fc(x)# 2. 创建 XOR 数据集
def create_xor_data():X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)y = np.array([[0], [1], [1], [0]], dtype=np.float32)return X, y# 3. 训练前馈神经网络
def train_fnn():# 数据准备X, y = create_xor_data()X = torch.tensor(X, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32)# 初始化网络、损失函数和优化器input_dim = X.shape[1]hidden_dim = 10output_dim = 1model = FeedforwardNN(input_dim, hidden_dim, output_dim)criterion = nn.BCELoss()  # 二分类交叉熵损失optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练网络epochs = 1000loss_history = []for epoch in range(epochs):# 前向传播outputs = model(X)loss = criterion(outputs, y)# 反向传播与优化optimizer.zero_grad()loss.backward()optimizer.step()# 记录损失loss_history.append(loss.item())if (epoch + 1) % 100 == 0:print(f"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}")# 绘制损失曲线plt.plot(loss_history)plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.show()# 输出训练结果with torch.no_grad():predictions = model(X).round()print("Predictions:", predictions.numpy())print("Ground Truth:", y.numpy())# 运行训练
if __name__ == "__main__":train_fnn()

http://www.ppmy.cn/devtools/134913.html

相关文章

苍穹外卖学习-day11

1. Apac 1.1 介绍 Apache ECharts 是一款基于 Javascript 的数据可视化图表库,提供直观,生动,可交互,可个性化定制的数据可视化图表。 官网地址:Apache ECharts 常见的统计图形有:柱状图,条形…

创建vue3项目步骤

脚手架创建项目: pnpm create vue Cd 项目名称安装依赖:Pnpm iPnpm Lint:修复所有文件风格 ,不然eslint语法警告报错要双引号Pnpm dev启动项目 拦截错误代码提交到git仓库:提交前做代码检查 pnpm dlx husky-in…

【汇编语言】数据处理的两个基本问题 —— 汇编语言中的数据奥秘:数据位置与寻址方式总结

文章目录 前言1. 引言1.1 两个基本问题1.2 两个描述性符号 2. bx、si、di和bp2.1 通过"[...]"来寻址,只有这四种寄存器2.2 四种寄存器寻址时的组合方式2.3 使用bp时,默认段地址为ss 3.机器指令处理的数据在什么地方?4. 汇编语言中数…

【SQL】mysql常用命令

为方便查询,特整理MySQL常用命令。 约定:$后为Shell环境命令,>后为MySQL命令。 1 常用命令 第一步,连接数据库。 $ mysql -u root -p # 进入MySQL bin目录后执行,回车后输入密码连接。# 常用参数&…

任意文件下载漏洞

1.漏洞简介 任意文件下载漏洞是指攻击者能够通过操控请求参数,下载服务器上未经授权的文件。 攻击者可以利用该漏洞访问敏感文件,如配置文件、日志文件等,甚至可以下载包含恶意代码的文件。 这里再导入一个基础: 你要在网站下…

#define定义宏

#define机制包括了一个规定,允许把参数替换到文本中,这种实现通常称为宏或定义宏。 宏的申明方式: #define name(parament-list)stuff 其中的parament-list是一个由逗号隔开的符号表&#xf…

go channel中的 close注意事项 range取数据

在使用 Go 语言中的 close 函数时,有一些注意事项需要牢记,以确保程序的健壮性和正确性: 1. **仅用于通道(channel)**: - close 函数只能用于关闭通道,不能用于关闭文件、网络连接或其他资源…

设计一个设备探测1pv

探测**1 pV(皮伏特,)的微弱电信号是一个非常具有挑战性但可行的目标。这种极低电压的探测需要超高灵敏度的电路设计和信号处理技术,同时要尽量抑制噪声对信号的干扰。 以下是设计此类设备的一些核心思路和技术方向: …