从0开始深度学习(20)——延后初始化和自定义层

devtools/2024/10/25 1:01:51/

一般情况下,模型参数在被创建时就被立即初始化了,但如果使用了延后初始化技术,就能在首次传入数据后,再初始化参数,旨在输入维度未知的情况下,预定义灵活的模型,动态推断各个层的参数大小。

有时我们会遇到或要自己发明一个现在在深度学习框架中还不存在的层。 在这些情况下,必须构建自定义层。本节将展示如何构建自定义层。

1 延后初始化的使用场景

  1. 动态输入形状:在某些情况下,模型的输入形状可能不是固定的,而是动态变化的。例如,在自然语言处理任务中,输入句子的长度可能会有所不同。
  2. 简化模型定义:延后初始化可以简化模型的定义,避免在模型构建阶段手动指定每个层的输入和输出形状。

2 pytorch实现延后初始化

在 PyTorch 中,可以通过 torch.nn.Lazy 模块来实现延后初始化。torch.nn.Lazy 模块提供了一些常见的层,如 LazyLinearLazyConv2d 等,这些层在第一次前向传播时自动确定参数的形状。

以下是一个简单的示例,展示了如何使用 torch.nn.LazyLinear 进行延后初始化:

import torch
import torch.nn as nnclass LazyNet(nn.Module):def __init__(self):super(LazyNet, self).__init__()self.lazy_linear = nn.LazyLinear(out_features=10)# 没有指定输出维度,只指定了输出维度def forward(self, x):x = self.lazy_linear(x)return x# 创建模型实例
net = LazyNet()# 创建一个输入张量,假设输入形状为 (batch_size, input_features)
input_data = torch.randn(32, 100)# 第一次前向传播,此时会自动确定 `LazyLinear` 层的输入形状
output = net(input_data)print(output.shape)  # 输出形状为 (32, 10)

运行结果:
在这里插入图片描述

3 自定义层

3.1 不带参数的层

我们构造一个没有任何参数的自定义层,下面的CenteredLayer类要从其输入中减去均值。 要构建它,我们只需继承基础层类并实现前向传播功能。

import torch
import torch.nn.functional as F
from torch import nnclass CenteredLayer(nn.Module):def __init__(self):super().__init__()def forward(self, X):return X - X.mean()

传入一些数据进行验证

layer = CenteredLayer()
layer(torch.FloatTensor([1, 2, 3, 4, 5]))

在这里插入图片描述
现在,我们可以将层作为组件合并到更复杂的模型中。作为额外的健全性检查,我们可以在向该网络发送随机数据后,检查均值是否为0。 由于我们处理的是浮点数,因为存储精度的原因,我们仍然可能会看到一个非常小的非零数。

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
Y = net(torch.rand(4, 8))
Y.mean()

在这里插入图片描述

3.2 带参数的层

下面我们继续定义具有参数的层, 这些参数可以通过训练进行调整。现在,让我们实现自定义版本的全连接层。 回想一下,该层需要两个参数,一个用于表示权重,另一个用于表示偏置项。 在此实现中,我们使用修正线性单元作为激活函数。 该层需要输入参数:in_units和units,分别表示输入数和输出数。

class MyLinear(nn.Module):def __init__(self, in_units, units):super().__init__()self.weight = nn.Parameter(torch.randn(in_units, units))self.bias = nn.Parameter(torch.randn(units,))def forward(self, X):linear = torch.matmul(X, self.weight.data) + self.bias.datareturn F.relu(linear)

接下来实例化模型,并访问其参数:

linear = MyLinear(5, 3)
linear.weight

在这里插入图片描述
我们可以使用自定义层直接执行前向传播计算。

linear(torch.rand(2, 5))

在这里插入图片描述
我们还可以使用自定义层构建模型,就像使用内置的全连接层一样使用自定义层。

net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
net(torch.rand(2, 64))

在这里插入图片描述
我们可以通过基本层类设计自定义层。这允许我们定义灵活的新层


http://www.ppmy.cn/devtools/128554.html

相关文章

从0开始学python-day14-pandas1

一、基础 1、概述 Pandas 是一个开源的第三方 Python 库,从 Numpy 和 Matplotlib 的基础上构建而来 Pandas 名字衍生自术语 "panel data"(面板数据)和 "Python data analysis"(Python 数据分析)…

perl批量改文件后缀

perl批量改文件后缀 如题&#xff0c;perl批量改文件后缀&#xff0c;将已有的统一格式的文件后缀&#xff0c;修改为新的统一的文件后缀。 #!/bin/perl use 5.010;print "Please input file suffix which U want to rename!\n"; chomp (my $suffix_old <>)…

h2数据库模拟mysql进行单元测试遇到的问题

使用h2数据库进行springboot的单元测试的时候出现的几个问题 1.h2数据库插入json数据的时候&#xff0c;默认是json String的形式&#xff08;josn数据入库的时候有转义字符&#xff09;&#xff0c;导致查询出来的json数据在进行处理的时候无法解析成jsonNode 操作&#xff…

Java-抽象类和接口

一、抽象类 ① 抽象类的概念 在上一篇文章中&#xff0c;我们学习了" 多态 "&#xff0c;它允许在相同方法的情况下处理不同的对象&#xff0c;即通过父类引用指向子类对象&#xff0c;并调用相同的方法&#xff0c;通过不同的子类调用该方法&#xff0c;进而产生不…

将jupyter中ipynb文件转成html文件

在VScode中使用jupyter&#xff08;丘比特&#xff09;时&#xff0c;将ipynb&#xff08;我python牛比&#xff09;后缀文件转成html时&#xff0c;总是报错&#xff1a; 那么可以就在需要转换的ipynb文件中&#xff0c;运行下面的代码&#xff0c;将“your_notebook”替换成同…

零基础Java第九期:一维数组(二)和二维数组

目录 一、数组的练习 1.1. 顺序表查找 1.2. 二分查找 1.3. 冒泡排序 二、二维数组 2.1. 二维数组的性质 2.2. 不规则二维数组 一、数组的练习 1.1. 顺序表查找 题目描述&#xff1a;给定一个数组, 再给定一个元素, 找出该元素在数组中的位置。 利用for循环去遍历数组&am…

未来医疗:大语言模型如何改变临床实践、研究和教育|文献精析·24-10-23

小罗碎碎念 这篇文章探讨了大型语言模型在医学领域的潜在应用和挑战&#xff0c;并讨论了它们在临床实践、医学研究和医学教育中的未来发展。 姓名单位名称&#xff08;中文&#xff09;Jan Clusmann德国德累斯顿工业大学埃尔朗根弗雷斯尼乌斯中心数字化健康研究所Jakob Nikola…

GitLab CVE-2024-6446、CVE-2024-6685 漏洞解决方案

极狐GitLab 近日发布安全补丁版本17.3.2, 17.2.5, 17.1.7&#xff0c;修复了17个安全漏洞&#xff0c;本分分享 CVE-2024-6446、CVE-2024-6685 两个漏洞详情及解决方案。 极狐GitLab 正式推出面向 GitLab 老旧版本免费用户的专业升级服务&#xff0c;为 GitLab 老旧版本进行专…