PyTorch使用教程(6)一文讲清楚torch.nn和torch.nn.functional的区别

devtools/2025/1/19 17:24:47/

torch.nn torch.nn.functional 在 PyTorch 中都是用于构建神经网络的重要组件,但它们在设计理念、使用方式和功能上存在一些显著的区别。以下是关于这两个模块的详细区别:

1. 继承方式与结构

torch.nn

  • torch.nn 中的模块大多数是通过继承 torch.nn.Module 类来实现的。这些模块都是 Python 类,包含了神经网络的各种层(如卷积层、全连接层等)和其他组件(如损失函数、优化器等)。
  • torch.nn 中的模块可以包含可训练参数,如权重和偏置,这些参数在训练过程中会被优化。

torch.nn.functional

  • torch.nn.functional 中的函数是直接调用的,无需实例化。这些函数通常用于执行各种非线性操作、损失函数计算、激活函数应用等。
  • torch.nn.functional 中的函数没有可训练参数,它们只是执行操作并返回结果。

2. 实现方式与调用方式

torch.nn

  • torch.nn 中的模块是基于面向对象的方法实现的。开发者需要创建类的实例,并在类的 forward 方法中定义数据的前向传播路径。
  • torch.nn 中的模块通常需要先创建模型实例,再将输入数据传入模型中进行前向计算。

torch.nn.functional

  • torch.nn.functional 中的函数是基于函数式编程实现的。它们提供了灵活的接口,允许开发者以函数调用的方式轻松定制和扩展神经网络架构。
  • torch.nn.functional 中的函数可以直接调用,只需要将输入数据传入函数中即可进行前向计算。

3. 使用场景与优势

torch.nn

  • torch.nn 更适合用于定义有状态的模块,如包含可训练参数的层。
  • 当定义具有变量参数的层时(如卷积层、全连接层等),torch.nn 会帮助初始化好变量,并且模型类本身就是 nn.Module 的实例,看起来会更加协调统一。
  • torch.nn 可以结合 nn.Sequential 来简化模型的构建过程。

torch.nn.functional

  • torch.nn.functional 中的函数相比 torch.nn 更偏底层,封装性不高但透明度很高。开发者可以在其基础上定义出自己想要的功能。
  • 使用 torch.nn.functional 可以更方便地进行函数组合、复用等操作,适合那些喜欢使用函数式编程风格的开发者。当激活函数只需要在前向传播中使用时,使用 torch.nn.functional 中的激活函数会更加简洁。

4. 权重与参数管理

torch.nn

  • torch.nn 中的模块会自动管理权重和偏置等参数,这些参数可以通过 model.parameters() 方法获取,并用于优化算法的训练。

torch.nn.functional

  • torch.nn.functional 中的函数不直接管理权重和偏置等参数。如果需要使用这些参数,开发者需要在函数外部定义并初始化它们,然后将它们作为参数传入函数中。

5.举例说明

例子1:定义卷积层

使用 torch.nn

python">import torch.nn as nnclass MyConvNet(nn.Module):def __init__(self):super(MyConvNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)def forward(self, x):x = self.conv1(x)return x# 实例化模型
model = MyConvNet()# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = model(input_tensor)

使用 torch.nn.functional

python">import torch.nn.functional as Fdef my_conv_net(input_tensor, weight, bias=None):output_tensor = F.conv2d(input_tensor, weight, bias=bias, stride=1, padding=1)return output_tensor# 定义卷积核的权重和偏置
weight = nn.Parameter(torch.randn(16, 1, 3, 3))
bias = nn.Parameter(torch.randn(16))# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = my_conv_net(input_tensor, weight, bias)

在这个例子中,使用 torch.nn 定义了一个包含卷积层的模型类,而使用 torch.nn.functional 则是通过函数直接进行卷积操作。注意在使用 torch.nn.functional 时,需要手动定义和传递卷积核的权重和偏置。

例子2:应用激活函数

使用 torch.nn

python">import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.relu = nn.ReLU()def forward(self, x):x = self.relu(x)return x# 实例化模型
model = MyModel()# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = model(input_tensor)

使用 torch.nn.functional

python">import torch.nn.functional as Fdef my_model(input_tensor):output_tensor = F.relu(input_tensor)return output_tensor# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = my_model(input_tensor)

在这个例子中,使用 torch.nn 定义了一个包含 ReLU 激活函数的模型类,而使用 torch.nn.functional 则是通过函数直接应用 ReLU 激活函数。

例子3:定义和计算损失

使用 torch.nn

python">import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = nn.Linear(10, 2)def forward(self, x):x = self.linear(x)return x# 实例化模型
model = MyModel()# 定义损失函数
criterion = nn.CrossEntropyLoss()# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()# 前向传播和计算损失
output_tensor = model(input_tensor)
loss = criterion(output_tensor, target)

使用 torch.nn.functional

python">import torch.nn.functional as Fdef my_model(input_tensor):output_tensor = torch.matmul(input_tensor, weight.t()) + biasreturn output_tensor# 定义权重和偏置
weight = nn.Parameter(torch.randn(10, 2))
bias = nn.Parameter(torch.randn(2))# 定义损失函数
criterion = nn.CrossEntropyLoss()# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()# 前向传播和计算损失
output_tensor = my_model(input_tensor)
loss = criterion(output_tensor, target)

在这个例子中,使用 torch.nn 定义了一个包含全连接层的模型类,并使用了 torch.nn 中的损失函数来计算损失。而使用 torch.nn.functional 则是通过函数直接进行线性变换,并使用 torch.nn 中的损失函数来计算损失。注意在使用 torch.nn.functional 时,需要手动定义和传递权重和偏置。

6. 小结

torch.nn 和 torch.nn.functional 在定义神经网络组件、应用激活函数和计算损失等方面存在显著的区别。torch.nn 提供了一种面向对象的方式来构建模型,而 torch.nn.functional 则提供了一种更灵活、更函数式的方式来执行相同的操作。
在这里插入图片描述


http://www.ppmy.cn/devtools/151865.html

相关文章

03JavaWeb——Ajax-Vue-Element(项目实战)

1 Ajax 1.1 Ajax介绍 1.1.1 Ajax概述 我们前端页面中的数据,如下图所示的表格中的学生信息,应该来自于后台,那么我们的后台和前端是互不影响的2个程序,那么我们前端应该如何从后台获取数据呢?因为是2个程序&#xf…

前端——换行

大家都知道<br />和\n是有换行的作用 很多时候&#xff0c;有些区分不开<br />和\n的区别&#xff0c;他俩各自在什么情形下使用呢 一、<br /> 在浏览器中&#xff0c;<br /> 会强制文本在当前位置换行。 适用于需要在特定位置插入换行的场景。 二、…

KubeSphere 与 Pig 微服务平台的整合与优化:全流程容器化部署实践

一、前言 近年来,为了满足越来越复杂的业务需求,我们从传统单体架构系统升级为微服务架构,就是把一个大型应用程序分割成可以独立部署的小型服务,每个服务之间都是松耦合的,通过 RPC 或者是 Rest 协议来进行通信,可以按照业务领域来划分成独立的单元。但是微服务系统相对…

RabbitMQ-消息可靠性以及延迟消息

目录 消息丢失 一、发送者的可靠性 1.1 生产者重试机制 1.2 生产者确认机制 1.3 实现生产者确认 &#xff08;1&#xff09;开启生产者确认 &#xff08;2&#xff09;定义ReturnCallback &#xff08;3&#xff09;定义ConfirmCallback 二、MQ的持久化 2.1 数据持久…

基于Python机器学习、深度学习技术提升气象、海洋、水文领域实践应用

Python是功能强大、免费、开源&#xff0c;实现面向对象的编程语言&#xff0c;能够在不同操作系统和平台使用&#xff0c;简洁的语法和解释性语言使其成为理想的脚本语言。除了标准库&#xff0c;还有丰富的第三方库&#xff0c;Python在数据处理、科学计算、数学建模、数据挖…

简历_使用优化的Redis自增ID策略生成分布式环境下全局唯一ID,用于用户上传数据的命名以及多种ID的生成

系列博客目录 文章目录 系列博客目录WhyRedis自增ID策略 Why 我们需要设置全局唯一ID。原因&#xff1a;当用户抢购时&#xff0c;就会生成订单并保存到tb_voucher_order这张表中&#xff0c;而订单表如果使用数据库自增ID就存在一些问题。 问题&#xff1a;id的规律性太明显、…

DeviceNet转Profinet网关+FANUC机器人:打造工业界的灭霸手套,掌控无限可能

在某车厂项目中&#xff0c;客户需将甲方的FANUC机器人接入自身Profinet网络系统。因机器人采用DeviceNET协议&#xff0c;所以选用稳联技术研发的Profinet转DeviceNET网关&#xff08;WL-DVN-PN&#xff09;实现通讯转换。 新建项目并导入稳联技术DeviceNET转Profinet网关&…

[苍穹外卖] 1-项目介绍及环境搭建

项目介绍 定位&#xff1a;专门为餐饮企业&#xff08;餐厅、饭店&#xff09;定制的一款软件产品 功能架构&#xff1a; 管理端 - 外卖商家使用 用户端 - 点餐用户使用 技术栈&#xff1a; 开发环境的搭建 整体结构&#xff1a; 前端环境 前端工程基于 nginx 运行 - Ngi…