torch.nn.Sequential的用法

embedded/2025/1/2 20:03:12/

文章目录

  • 介绍
  • 基本用法
  • 添加命名层
  • 动态添加层
  • 嵌套使用
  • 与自定义前向传播的区别

介绍

torch.nn.Sequential 是 PyTorch 中的一个容器模块,用于将多个神经网络层按顺序组合在一起。它可以让我们以更加简洁的方式定义前向传播的网络结构,适合简单的线性堆叠模型。

基本用法

torch.nn.Sequential 按照定义的顺序将多个层组合在一起,输入数据会依次通过这些层。

python">import torch.nn as nn  # 定义一个简单的网络  
model = nn.Sequential(  nn.Linear(10, 20),  # 全连接层:输入 10,输出 20  nn.ReLU(),          # 激活函数:ReLU  nn.Linear(20, 1)    # 全连接层:输入 20,输出 1  
)

当调用 model(input) 时,输入会依次通过 Sequential 中的每一层。

python">import torch  input = torch.randn(5, 10)  # 输入:batch_size=5, features=10  
output = model(input)       # 前向传播  
print(output.shape)         # 输出:torch.Size([5, 1])

添加命名层

可以为每一层指定名称,方便后续访问或调试。

python">model = nn.Sequential(  ('fc1', nn.Linear(10, 20)),  # 命名为 'fc1'  ('relu1', nn.ReLU()),        # 命名为 'relu1'  ('fc2', nn.Linear(20, 1))    # 命名为 'fc2'  
)

通过名称或索引访问某一层:

python">print(model.fc1)  # 访问名为 'fc1' 的层  
print(model[0])   # 通过索引访问第一层

动态添加层

可以通过 add_module 方法动态添加层。

python">model = nn.Sequential()  
model.add_module('fc1', nn.Linear(10, 20))  # 添加第一层  
model.add_module('relu1', nn.ReLU())        # 添加激活函数  
model.add_module('fc2', nn.Linear(20, 1))   # 添加第二层

嵌套使用

nn.Sequential 可以嵌套使用,用于构建更复杂的网络。

python">model = nn.Sequential(  nn.Sequential(  nn.Linear(10, 20),  nn.ReLU()  ),  nn.Sequential(  nn.Linear(20, 10),  nn.ReLU()  ),  nn.Linear(10, 1)  
)

与自定义前向传播的区别

nn.Sequential 适合简单的线性堆叠模型,但如果需要更复杂的前向传播逻辑(如分支、跳跃连接等),需要继承 nn.Module 并自定义 forward 方法。
使用 nn.Sequential

python">model = nn.Sequential(  nn.Linear(10, 20),  nn.ReLU(),  nn.Linear(20, 1)  
)

自定义 forward

python">class CustomModel(nn.Module):  def __init__(self):  super(CustomModel, self).__init__()  self.fc1 = nn.Linear(10, 20)  self.fc2 = nn.Linear(20, 1)  self.relu = nn.ReLU()  def forward(self, x):  x = self.fc1(x)  x = self.relu(x)  x = self.fc2(x)  return x  model = CustomModel()

http://www.ppmy.cn/embedded/150252.html

相关文章

使用CSS 和 JavaScript 实现鼠标悬停时图片放大、缩小和抖动

我们可以通过 CSS 和 JavaScript 来实现鼠标悬停时图片放大、缩小和抖动的效果。以下是一个简单的实现方式&#xff1a; 1.HTML 结构 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewp…

使用uWSGI将Flask应用部署到生产环境

使用uWSGI将Flask应用部署到生产环境&#xff1a; 1、安装uWSGI conda install -c conda-forge uwsgi&#xff08;pip install uwsgi会报错&#xff09; 2、配置uWSGI 在python程序的同一文件夹下创建 uwsgi.ini文件&#xff0c;文件内容如下表。 需要按照实际情况修改文件名称…

电脑中缺失的nvrtc64_90.dll文件如何修复?

一、文件丢失问题 案例&#xff1a;nvrtc64_90.dll文件缺失 问题分析&#xff1a; nvrtc64_90.dll是NVIDIA CUDA Runtime Compilation库的一部分&#xff0c;通常与NVIDIA的CUDA Toolkit或相关驱动程序一起安装。如果该文件丢失&#xff0c;可能会导致基于CUDA的应用程序&…

Unity3D 基于GraphView实现的节点编辑器框架详解

前言 在Unity3D游戏开发中&#xff0c;节点编辑器是一种强大的工具&#xff0c;它允许开发者以可视化的方式创建和编辑复杂的逻辑和流程。Unity提供了一个强大的UI工具包——GraphView&#xff0c;它使得创建自定义节点编辑器变得相对简单。本文将详细介绍如何使用GraphView实…

数据的简单处理——pandas模块——读取数据(Excel和csv格式)

使用Pandas模块可以从多种类型的文件中读取数据。本节主要从Excel和csv格式文件中读取数据为例&#xff0c;进行练习。 一、读取数据Excel格式 主要包括&#xff0c;读取完整表格、读取指定行数据、读取指定列数据。 二、读取数据csv格式 主要包括&#xff0c;读取完整表格…

ubuntu 20.04 国内源安装docker

先更新软件包&#xff0c;安装备要apt软件 # 更新软件包索引 sudo apt-get update# 安装需要的软件包以使apt能够通过HTTPS使用仓库 sudo apt-get install ca-certificates curl gnupg lsb-release使用阿里云源 # 添加阿里云官方GPG密钥 curl -fsSL http://mirrors.aliyun.co…

在C#中实现支持LINQ查询的自定义集合类

在C#中&#xff0c;若要使自定义集合类支持LINQ查询&#xff0c;需要实现一些特定的接口&#xff0c;这些接口通常与集合数据的枚举和操作有关。以下是一个基本步骤指南&#xff0c;用于创建一个支持LINQ查询的自定义集合类&#xff1a; 实现IEnumerable<T>接口&#xff…

Qt 中实现系统主题感知

【写在前面】 在现代桌面应用程序开发中&#xff0c;系统主题感知是一项重要的功能&#xff0c;它使得应用程序能够根据用户的系统主题设置&#xff08;如深色模式或浅色模式&#xff09;自动调整其外观。 Qt 作为一个跨平台的C图形用户界面应用程序开发框架&#xff0c;提供…