理解深度学习pytorch框架中的线性层

ops/2025/1/24 12:30:44/

在神经网络或机器学习的线性层(Linear Layer / Fully Connected Layer)中,经常会见到两种形式的公式:

  • 数学文献或传统线性代数写法: y = W x + b \displaystyle y = W\,x + b y=Wx+b
  • 一些深度学习代码中写法: y = x W T + b \displaystyle y = x\,W^T + b y=xWT+b

初次接触时,很多人会觉得两者“方向”不太一样,不知该如何对照理解;再加上矩阵维度 ( in_features , out_features ) (\text{in\_features},\, \text{out\_features}) (in_features,out_features) ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features) 的各种写法常常让人疑惑不已。本文将从数学角度和编程实现角度剖析它们的关系,并结合实际示例指出一些常见的坑与需要特别留意的下标对应问题。

1. 数学角度: y = W x + b \displaystyle y = W\,x + b y=Wx+b

在线性代数中,如果我们假设输入 x x x 是一个列向量,通常会写作 x ∈ R ( in_features ) \displaystyle x\in\mathbb{R}^{(\text{in\_features})} xR(in_features)(或者在更严格的矩阵形状记法下写作 ( in_features , 1 ) (\text{in\_features},\,1) (in_features,1))。那么一个最常见的全连接层可以表示为:

y = W x + b , y = W\,x + b, y=Wx+b,

其中:

  • W W W 是一个大小为 ( out_features , in_features ) \bigl(\text{out\_features},\,\text{in\_features}\bigr) (out_features,in_features) 的矩阵;
  • b b b 是一个 out_features \text{out\_features} out_features-维的偏置向量(形状 ( out_features , 1 ) (\text{out\_features},\,1) (out_features,1));
  • y y y 则是输出向量,大小为 out_features \text{out\_features} out_features

示例

假设 in_features = 3 \text{in\_features}=3 in_features=3 out_features = 2 \text{out\_features}=2 out_features=2。那么:
W ∈ R 2 × 3 , x ∈ R 3 × 1 , b ∈ R 2 × 1 . W \in \mathbb{R}^{2\times 3},\quad x \in \mathbb{R}^{3\times 1},\quad b \in \mathbb{R}^{2\times 1}. WR2×3,xR3×1,bR2×1.

矩阵写开来就是:

W = [ w 11 w 12 w 13 w 21 w 22 w 23 ] , x = [ x 1 x 2 x 3 ] , b = [ b 1 b 2 ] . W = \begin{bmatrix} w_{11} & w_{12} & w_{13} \\[5pt] w_{21} & w_{22} & w_{23} \end{bmatrix},\quad x = \begin{bmatrix} x_{1}\\ x_{2}\\ x_{3} \end{bmatrix},\quad b = \begin{bmatrix} b_{1}\\ b_{2} \end{bmatrix}. W=[w11w21w12w22w13w23],x= x1x2x3 ,b=[b1b2].

那么线性变换结果 W x + b Wx + b Wx+b 可以展开为:

W x + b = [ w 11 x 1 + w 12 x 2 + w 13 x 3 w 21 x 1 + w 22 x 2 + w 23 x 3 ] + [ b 1 b 2 ] = [ w 11 x 1 + w 12 x 2 + w 13 x 3 + b 1 w 21 x 1 + w 22 x 2 + w 23 x 3 + b 2 ] . \begin{aligned} Wx + b &= \begin{bmatrix} w_{11}x_1 + w_{12}x_2 + w_{13}x_3 \\ w_{21}x_1 + w_{22}x_2 + w_{23}x_3 \end{bmatrix} + \begin{bmatrix} b_1 \\ b_2 \end{bmatrix} \\ &= \begin{bmatrix} w_{11}x_1 + w_{12}x_2 + w_{13}x_3 + b_1 \\ w_{21}x_1 + w_{22}x_2 + w_{23}x_3 + b_2 \end{bmatrix}. \end{aligned} Wx+b=[w11x1+w12x2+w13x3w21x1+w22x2+w23x3]+[b1b2]=[w11x1+w12x2+w13x3+b1w21x1+w22x2+w23x3+b2].

这就是最传统、在数学文献或线性代数课程中最常见的表示方法。


2. 编程实现角度: y = x W T + b \displaystyle y = x\,W^T + b y=xWT+b

在实际的深度学习代码(例如 PyTorch、TensorFlow)中,经常看到的却是下面这种写法:

y = x @ W.T + b

注意这里 W.shape 通常被定义为 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features),而 x.shape 在批量处理时则是 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features)。于是 (x @ W.T) 的结果是 ( batch_size , out_features ) (\text{batch\_size},\, \text{out\_features}) (batch_size,out_features)

为什么会出现转置?
因为在数学里我们通常把 x x x 当作“列向量”放在右边,于是公式变成 y = W x + b y = Wx + b y=Wx+b
但在编程里,尤其是处理批量输入时,x 常写成“行向量”的形式 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features),这就造成了在进行矩阵乘法时,需要将 W(大小 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features))转置成 ( in_features , out_features ) (\text{in\_features},\, \text{out\_features}) (in_features,out_features),才能满足「行×列」的匹配关系。

从结果上来看,

( batch_size , in_features ) × ( in_features , out_features ) = ( batch_size , out_features ) . (\text{batch\_size}, \text{in\_features}) \times (\text{in\_features}, \text{out\_features}) = (\text{batch\_size}, \text{out\_features}). (batch_size,in_features)×(in_features,out_features)=(batch_size,out_features).

所以,在代码里就写成 x @ W.T,再加上偏置 b(通常会广播到 batch_size \text{batch\_size} batch_size 那个维度)。

本质上这和数学公式里 y = W x + b y = W\,x + b y=Wx+b 并无冲突,只是一个“列向量”和“行向量”的转置关系。只要搞清楚最终你想让输出 y y y 的 shape 是多少,就能明白在代码里为什么要写 .T


3. 常见错误与易混点解析

有些教程或文档,会不小心写成:“如果我们有一个形状为 ( in_features , out_features ) (\text{in\_features},\text{out\_features}) (in_features,out_features) 的权重矩阵 W W W……”——然后又要做 W x Wx Wx,想得到一个 out_features \text{out\_features} out_features-维的结果。但按照线性代数的常规写法,行数必须和输出维度匹配、列数必须和输入维度匹配。所以 正确 的说法应该是

W ∈ R ( out_features ) × ( in_features ) . W\in\mathbb{R}^{(\text{out\_features}) \times (\text{in\_features})}. WR(out_features)×(in_features).

否则从矩阵乘法次序来看就对不上。
但这又可能让人迷惑:为什么深度学习框架 torch.nn.Linear(in_features, out_features) 却给出 weight.shape == (out_features, in_features) 其实正是同一个道理,它和上面“数学文献里”用到的 W W W 形状完全一致。


4. 小结

  1. 从数学角度
    最传统的记号是
    y = W x + b , W ∈ R ( out_features ) × ( in_features ) , x ∈ R ( in_features ) , y ∈ R ( out_features ) . y = W\,x + b, \quad W \in \mathbb{R}^{(\text{out\_features})\times(\text{in\_features})},\, x \in \mathbb{R}^{(\text{in\_features})},\, y \in \mathbb{R}^{(\text{out\_features})}. y=Wx+b,WR(out_features)×(in_features),xR(in_features),yR(out_features).

  2. 深度学习代码角度

    • 由于批量数据常被视为行向量,每一行代表一个样本特征,因此形状通常是 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features)
    • 对应的权重 W 定义为 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features)。为了完成行乘以列的矩阵运算,需要对 W 做转置:
      y = x @ W.T + b
      
    • 得到的 y.shape ( batch_size , out_features ) (\text{batch\_size},\, \text{out\_features}) (batch_size,out_features)
  3. 避免踩坑

    • 写公式时,仔细确认 in_features \text{in\_features} in_features out_features \text{out\_features} out_features 的位置以及矩阵行列顺序。
    • 编程实践中理解“为什么要 .T”非常重要:那只是为了匹配「行×列」的矩阵乘法规则,本质上还是和 y = W x + b y = Wx + b y=Wx+b 相同。

通过理解并区分“列向量”与“行向量”的不同惯例,避免因为矩阵维度或转置不当而导致莫名其妙的错误或 bug。


参考链接

  • PyTorch 文档:torch.nn.Linear
  • 深度学习中的矩阵运算初步 —— batch_size 与矩阵乘法
  • 常见线性代数符号:行向量与列向量


http://www.ppmy.cn/ops/152736.html

相关文章

doris:阿里云 OSS 导入数据

Doris 提供两种方式从阿里云 OSS 导入文件: 使用 S3 Load 将阿里云 OSS 文件导入到 Doris 中,这是一个异步的导入方式。使用 TVF 将阿里云 OSS 文件导入到 Doris 中,这是一个同步的导入方式。 使用 S3 Load 导入​ 使用 S3 Load 导入对象存…

网络安全 | 0day漏洞介绍

关注:CodingTechWork 引言 在网络安全领域,0day漏洞(Zero-day Vulnerability)是指一个尚未被厂商、开发者或安全人员发现、修复或发布修补程序的安全漏洞。0day漏洞是黑客利用的一个重要攻击工具,因其未被披露或未被修…

Geek Uninstaller,绿色免安装轻量的应用卸载工具!

软件介绍 链接 一个轻量级拥有简洁交互界面、快速卸载电脑安装程序的工具。可快速扫描删除残余文件和注册表,对顽固和损坏的程序可执行强制删除、独立页面管理卸载系统Microsoft Store应用、快速打开程序安装文件夹、快速打开编辑程序注册表位置、将安装程序列表导…

Web入门

Spring 官网:spring.io Spring发展到今天已经形成了一种开发生态圈,Spring提供了若干个子项目,每个项目用于完成特定的功能 Spring Boot 可以帮助我们非常快速的构建应用程序、简化开发、提高效率 SpringBootWeb入门 ①.创建springboot工程&#xff0…

Unity3D仿星露谷物语开发25之创建时钟界面

1、目标 在时钟界面显示当前时钟信息,同时设置特殊按钮可以快速推进时间用于测试。 2、创建GameClock.cs脚本 在Assets -> Scripts -> TimeSystem目录下创建GameClock.cs脚本。 代码如下: using System.Collections; using System.Collections…

关于ARM和汇编语言

一图流 ARM 计算机组成 输入设备 输出设备 存储设备 运算器 控制器 处理器读取内存程序执行的过程 取指阶段:控制器器通过地址总线向存储器发送想要获取的指令的地址编号,存储器将指定的指令发送给处理器 译码阶段:控制器对指令进行分…

vue2的$el.querySelector在vue3中怎么写

这个也属于直接操作 dom 了,不建议在项目中这样操作,不过我是在vue2升级vue3的时候遇到的,是以前同事写的代码,也没办法 先来看一下对比 在vue2中获取实例是直接通过 this.$refs.xxx 获取绑定属性 refxxx 的实例,并且…

Spring Boot 3.x 整合 Logback 日志框架(支持异步写入)

Spring Boot 3.x 整合 Logback 日志框架(支持异步写入) 在构建任何应用程序时,良好的日志管理都是必不可少的。日志可以帮助我们监控、调试和跟踪代码的运行情况。 1. 添加日志配置文件 在 /resources 资源目录下,创建名为 log…