深入理解全连接层:从线性代数到 PyTorch 中的 nn.Linear 和 nn.Parameter

news/2024/9/18 9:06:48/ 标签: 机器学习, 全连接层, 线性层, 深度学习

文章目录

这篇文章会从基础的一个数学概念到对应的代码实现,你将了解到:

  • 为什么nn.Parameter()接受 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)作为参数?
  • 为什么不是torch.matmul(self.weight, x) + self.bias
  • 如何使用torch.matmul()@F.linear() 去等价地实现nn.Linear()的输出。

数学概念(全连接层线性层

线性变化是数学中一个基础的概念,它描述了如何通过线性变换将输入映射到输出。在线性代数中,线性变化通常表示为矩阵乘法。在神经网络中,线性层的核心就是实现这样的矩阵运算。

数学公式:

给定一个输入向量 x ∈ R n \mathbf{x} \in \mathbb{R}^n xRn 和一个输出向量 y ∈ R m \mathbf{y} \in \mathbb{R}^m yRm,线性变化通过矩阵 W ∈ R m × n \mathbf{W} \in \mathbb{R}^{m \times n} WRm×n 和偏置项 b ∈ R m \mathbf{b} \in \mathbb{R}^m bRm 进行变换,其公式为:
y = W x + b \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} y=Wx+b

  • W \mathbf{W} W:是权重矩阵,维度为 m × n m \times n m×n,它决定了输入向量如何线性变换到输出空间;
  • x \mathbf{x} x:是输入向量,维度为 n n n,表示特征数据;
  • b \mathbf{b} b:是偏置向量,维度为 m m m,用来调整线性变换的输出;
  • y \mathbf{y} y:是输出向量,维度为 m m m,是变换后的结果。

例子:

如果输入向量 x \mathbf{x} x 有 3 个特征,输出向量 y \mathbf{y} y 有 2 个特征,则权重矩阵 W \mathbf{W} W 的形状为 2 × 3 2 \times 3 2×3。假设:
W = [ 1 2 3 4 5 6 ] , x = [ 1 2 3 ] , b = [ 0 1 ] \mathbf{W} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}, \quad \mathbf{x} = \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}, \quad \mathbf{b} = \begin{bmatrix} 0 \\ 1 \end{bmatrix} W=[142536],x= 123 ,b=[01]
线性变换计算为:
y = W x + b = [ 1 2 3 4 5 6 ] [ 1 2 3 ] + [ 0 1 ] = [ 14 32 ] + [ 0 1 ] = [ 14 33 ] \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 33 \end{bmatrix} y=Wx+b=[142536] 123 +[01]=[1432]+[01]=[1433]
矩阵运算过程:
[ 1 2 3 4 5 6 ] [ 1 2 3 ] = [ ( 1 × 1 ) + ( 2 × 2 ) + ( 3 × 3 ) ( 4 × 1 ) + ( 5 × 2 ) + ( 6 × 3 ) ] = [ 14 32 ] \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} = \begin{bmatrix} (1 \times 1) + (2 \times 2) + (3 \times 3) \\ (4 \times 1) + (5 \times 2) + (6 \times 3) \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} [142536] 123 =[(1×1)+(2×2)+(3×3)(4×1)+(5×2)+(6×3)]=[1432]

nn.Linear()

nn.Linear() 会自动创建一个权重矩阵(Weight)和偏置项(Bias),并将它们应用到输入上。

代码示例:

import torch
import torch.nn as nn# 定义一个输入为3,输出为2的线性层
linear_layer = nn.Linear(3, 2)# 打印权重矩阵和偏置项
print("权重矩阵 W:")
print(linear_layer.weight)print("偏置项 b:")
print(linear_layer.bias)# 模拟输入向量
input_vector = torch.tensor([1.0, 2.0, 3.0])
output_vector = linear_layer(input_vector)
print("输出向量 y:")
print(output_vector)

image-20240912221728559

在这里,nn.Linear(3, 2) 创建了一个 2×3 的权重矩阵和一个 2 维的偏置向量。通过 linear_layer(input_vector),可以直接获得输入向量经过线性变换后的输出。

nn.Parameter()

在 PyTorch 中,nn.Linear() 自动处理了权重和偏置项的初始化和更新,但有时你可能希望对这些参数自定义一些操作,比如 LoRA。这时,我们可以使用 nn.Parameter() 来自定义权重和偏置,其实 nn.Linear() 本身就是使用的nn.Parameter(),感兴趣的话可以看官方源码。

以自定义线性层为例:

class CustomLinearLayer(nn.Module):def __init__(self, input_dim, output_dim):super(CustomLinearLayer, self).__init__()# 使用 nn.Parameter 手动定义权重和偏置self.weight = nn.Parameter(torch.randn(output_dim, input_dim))self.bias = nn.Parameter(torch.randn(output_dim))def forward(self, x):# 手动实现线性变换 y = Wx + breturn torch.matmul(x, self.weight.T) + self.bias# 使用自定义的线性层
custom_layer = CustomLinearLayer(3, 2)
output = custom_layer(input_vector)
print(output)

image-20240912222625609

在看完代码后,你可能会产生两个疑惑:

Q

1. 为什么 self.weight 的权重矩阵 shape 使用 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)而不是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)?

这正是我写这篇博客的原因,接下来我们详细解释这个问题。

让我们重新使用 in_features \text{in\_features} in_features out_features \text{out\_features} out_features来重现之前的数学定义:

对于输入向量 x ∈ R in_features \mathbf{x} \in \mathbb{R}^{\text{in\_features}} xRin_features全连接层的输出为:

y = W x + b \mathbf{y} = W\mathbf{x} + \mathbf{b} y=Wx+b

其中:

  • W ∈ R out_features × in_features W \in \mathbb{R}^{\text{out\_features} \times \text{in\_features}} WRout_features×in_features 是权重矩阵,
  • b ∈ R out_features \mathbf{b} \in \mathbb{R}^{\text{out\_features}} bRout_features 是偏置项。

在线性变换中,输入向量 x \mathbf{x} x 的维度是 in_features \text{in\_features} in_features,而输出向量 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features。根据矩阵乘法的规则,要将输入 x \mathbf{x} x 映射到输出 y \mathbf{y} y,权重矩阵 W W W 的形状应该是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features),因为矩阵乘法中 W x W\mathbf{x} Wx的维度要求是:

( out_features × in_features ) × ( in_features × 1 ) = ( out_features × 1 ) (\text{out\_features} \times \text{in\_features}) \times (\text{in\_features} \times 1) = (\text{out\_features} \times 1) (out_features×in_features)×(in_features×1)=(out_features×1)

这保证了输出 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features

如果权重矩阵的形状是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features),矩阵乘法的维度将不匹配,无法实现线性变换。

现在是不是感觉清晰了?不要 nn.Linear(in_feature, out_feature) 用多了就将权重矩阵当作是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)遗忘了线性代数的概念,数学才是这一切的基石。

2. 为什么是torch.matmul(x, self.weight.T) + self.bias 而不是torch.matmul(self.weight, x) + self.bias?

主要原因还是在于 输入张量 x 的形状矩阵乘法规则

一般来说,模型的输入 x 实际上并不是 ( in_features , 1 ) (\text{in\_features}, 1) (in_features,1),而是 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features),而权重矩阵 self.weight 的形状是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)​,我们需要实现的线性变换是:
y = W x + b y = W x + b y=Wx+b
根据矩阵乘法规则,第一个矩阵的列数必须等于第二个矩阵的行数,这意味着我们不能直接计算 torch.matmul(self.weight, x),因为这样会导致维度不匹配:

  • self.weight 形状为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)x 形状为 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features)
  • torch.matmul(self.weight, x) 的维度计算规则将要求 x 的形状为 ( in_features , batch_size ) (\text{in\_features}, \text{batch\_size}) (in_features,batch_size),但这与模型的输入不匹配。

因此,正确的矩阵乘法应该是 torch.matmul(x, self.weight.T),其中 self.weight.T 表示 self.weight 的转置矩阵,此时的形状为 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)

这样,torch.matmul(x, self.weight.T) 的维度计算为:

( batch_size , in_features ) × ( in_features , out_features ) = ( batch_size , out_features ) (\text{batch\_size}, \text{in\_features}) \times (\text{in\_features}, \text{out\_features}) = (\text{batch\_size}, \text{out\_features}) (batch_size,in_features)×(in_features,out_features)=(batch_size,out_features)

这就得到了正确的输出形状 ( batch_size , out_features ) (\text{batch\_size}, \text{out\_features}) (batch_size,out_features)

3. 为什么不直接设置self.weight = nn.Parameter(torch.randn(input_dim, output_dim))

这样不就可以不转置直接使用torch.matmul(x, self.weight)了吗?的确如此,或许是因为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features) 对于矩阵运算 W x W\mathbf{x} Wx 来讲更符合直觉吧。

计算过程的细分:torch.matmul() vs @ 运算符

在 PyTorch 中,torch.matmul() 用于实现矩阵乘法,而 @ 是其简洁的符号形式,是 Python 的语法糖,二者在功能上是等价的。

示例代码:

import torch# 定义权重矩阵 W 和输入向量 input_vector
W = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
input_vector = torch.tensor([1.0, 2.0, 3.0])# 使用 torch.matmul 实现矩阵乘法
result1 = torch.matmul(W, input_vector)# 使用 @ 运算符
result2 = W @ input_vectorprint("使用 torch.matmul 计算的结果:")
print(result1)print("使用 @ 运算符计算的结果:")
print(result2)

结果:

image-20240912233355773

使用 F.linear()

PyTorch 提供了 F.linear() 作为函数式接口,它与 nn.Linear() 类似,但不需要创建一个线性层对象。F.linear() 可以接受线性层的权重和偏置作为输入。

示例代码:

import torch.nn.functional as F# 使用 F.linear 进行线性变换
output = F.linear(input_vector, linear_layer.weight, linear_layer.bias)
print(output)

image-20240912233501651


http://www.ppmy.cn/news/1525149.html

相关文章

5款免费版文章生成器,自动生成文章更省创作精力

在自媒体创作中,随着不断的内容输出,让许多创作者都面临着时间和精力的双重压力。而免费版文章生成器的出现,犹如黑暗中的一束光,为创作者们带来了新的希望。免费版文章生成器能够以高效便捷的方式,帮助创作者节省大量…

DevExpress WinForms v24.1新版亮点:功能区、数据编辑器全新升级

DevExpress WinForms拥有180组件和UI库,能为Windows Forms平台创建具有影响力的业务解决方案。DevExpress WinForms能完美构建流畅、美观且易于使用的应用程序,无论是Office风格的界面,还是分析处理大批量的业务数据,它都能轻松胜…

Redis面对数据量庞大处理方法

当Redis面对数据量庞大时,其应对策略需要从多个维度出发,包括数据分片、内存优化、持久化策略、使用集群、硬件升级、数据淘汰策略、合理设计数据结构以及监控系统性能等。以下是对这些策略的详细阐述,以期提供不少于2000字的深入解答。 一、…

2024伊语IM即时通讯源码/im商城系统/纯源码IM通讯系统安卓+IOS前端纯原生源码

一、端口说明、域名解析及服务器配置要求 1.1端口说明 使用二级域名映射的情况下 使用端口说明3306数据导入是可以开放 后期关闭 或者直接在服务器上面导入6379不用对外开放9903需要开放80需要开放 1.2 子域名说明: api.xxx.com接口 im.xxx.com通讯 web.xxx.…

[Day 72] 區塊鏈與人工智能的聯動應用:理論、技術與實踐

區塊鏈在跨境支付中的應用 跨境支付一直是全球經濟中極具挑戰的領域。傳統的跨境支付系統通常需要數天時間來處理交易,涉及的中間機構多且手續費昂貴。然而,區塊鏈技術的出現為解決這些問題提供了一條嶄新的途徑。本文將探討區塊鏈在跨境支付中的應用&a…

linux中vim常用命令大全

在Linux系统中,Vim是一款功能强大的文本编辑器,广泛用于代码编写、文档编辑等多种场景。Vim以其高效的编辑能力和丰富的命令集著称。以下是Vim编辑器中常用命令的详细大全,旨在帮助用户更高效地利用Vim进行文本编辑。 一、启动与退出Vim 启动…

JAVA开源项目 校园管理系统 计算机毕业设计

本文项目编号 T 026 ,文末自助获取源码 \color{red}{T026,文末自助获取源码} T026,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 管…

Linux杂项知识

Linux的启动过程 Linux 的启动过程大致可以分为以下几个阶段: 1. BIOS/UEFI 加电自检 当计算机加电时,BIOS 或 UEFI 会首先执行一系列的硬件自检(POST,Power-On Self Test),检查系统是否正常运行。接下来…

Ansible自动化部署kubernetes集群

机器环境介绍 1.1. 机器信息介绍 IP hostname application CPU Memory 192.168.204.129 k8s-master01 etcd,kube-apiserver,kube-controller-manager,kube-scheduler,kubelet,kube-proxy,containerd 2C 4G 192.168.204.130 k8s-w…

Java小白一文讲清Java中集合相关的知识点(五)

Set接口和常用方法 基本介绍 无序(添加和取出的顺序不一样),没有索引不允许重复元素,所以最多包含一个nullJDK API 中Set接口的实现类有: public static void main(String[] args) {//1.以set接口的实现类HashSet来讲…

手机玩机常识-------诺基亚系列机型3/5/6/7/8详细的刷机教程步骤 手机参考救砖刷机教程

诺基亚手机 诺基亚(Nokia Corporation),成立于1865年,是一家主要从事移动通信设备生产和相关服务的手机公司 ,总部位于芬兰埃斯波 。从1996年开始,诺基亚手机连续15年占据手机市场份额第一位置&…

【测试】——自动化测试入门(Selenium环境搭建)

📖 前言:本文介绍了自动化测试的基础知识,重点讲解了Selenium环境的搭建。内容包括自动化测试的定义、自动化测试金字塔模型、Selenium的特点和工作原理,以及如何在Java环境中配置和使用Selenium进行UI自动化测试。 目录 &#x1…

性能测试经典案例解析——政务查询系统

各位好,我是 道普云 一站式云测试SaaS平台。一个在软件测试道路上不断折腾十余年的萌新。 欢迎关注我的主页 道普云 文章内容具有一定门槛,建议先赞再收藏慢慢学习,有不懂的问题欢迎私聊我。 希望这篇文章对想提高软件测试水平的你有所帮…

go 语言常见问题(4)

31. go语言编程的好处是什么 编译和运行都很快。在语言层级支持并行操作。有垃圾处理器。内置字符串和 maps。函数是 go 语言的最基本编程单位。 32. 说说go语言的select机制 select 机制用来处理异步 IO 问题select 机制最大的一条限制就是每个 case 语句里必须是一个 IO 操…

【C语言】归并排序递归和非递归——动图演示

目录 一、归并排序思想1.1 基本思想1.2 大体思路 二、实现归并排序(递归)三、实现归并排序(非递归)3.1 实现思路:3.2 越界处理3.3 时间复杂度和空间复杂度 总结 一、归并排序思想 1.1 基本思想 归并排序(M…

redis为什么快

春内存访问,相比数据库访问磁盘要快单线程,避免上下文切换带来的cpu开销渐进式Rehash。减少阻塞网络模型多路复用,reactor模型 常用基本数据类型 5个基本数据类型2个高级数据结构(bitmaps、hyperlog) redis高级功能…

Gitea Action注册runner

我的是gitea也可以和github 兼容,只是没有github 那么靓而已 安装一个gitea仓库 docker run -d --name gitea \-p3000:3000 -p2222:22 \-v /git/data:/data \ -v /etc/timezone:/etc/timezone:ro \-v /etc/localtime:/etc/localtime:ro \gitea/gitea:1.21.1setti…

【漏洞复现】某4国语言抖音点赞系统存在任意文件上传漏洞

漏洞描述 某4国语言 中文+英文+泰语+繁体 UI也非常不错 功能比较完善!【系统功能】1.任务后台添加/用户发布,后台审核 2.机器人、大转盘;已完善 3.支付可以对接第三方和线下银行卡收款;4.后台增加员工账号(推广员专属账号),可以查看员工推广报表;5.会员等级功能,会员级…

wireshark打开时空白|没有接口,卸载重装可以解决

解决方法:卸载wireshark,全选卸载干净,重新安装旧版的wireshark4.2.7, 甚至cmd下运行net start npf时显示服务名无效,但打开wireshark仍有多个接口 错误描述: 一开始下载的是wireshark的最新版,win11 x64 在安装wir…

Redis Sentinel(哨兵)详解

目录 一:什么是Sentinel(哨兵) 二:Sentinel有什么用 1.监控 2.故障转移 3通知 4.配置提供 三:Sentinel如何检测master节点宕机 1.主观下线 2.客观下线 四:Sentinel是如何选举出新的master 1.s…