droppath

server/2024/12/2 20:48:43/

DropPath 是一种用于正则化深度学习模型的技术,它在训练过程中随机丢弃路径(或者说随机让某些部分的输出变为零),从而增强模型的鲁棒性和泛化能力。

代码解释:

import torch
import torch.nn as nn

# 定义 DropPath 类
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

def drop_path(x, drop_prob: float = 0., training: bool = False):
#drop_path(输入,将drop_prob初始化为0., 判断是否为训练模式)
    if drop_prob == 0. or not training:
        return x
#如果drop_prob等于0或者不是训练模式直接将输入输出
    keep_prob = 1 - drop_prob
#保留的概率
    shape = (x.shape[0],) + (1,) * (x.ndim - 1) 
# 形状:(batch_size, 1, 1, ...)
# x.shape[0]获取xshape的第一维也就是batch_size
# (1,) * (x.ndim - 1) 将shape用1填充和x的形状一样
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
# torch.rand(shape, dtype=x.dtype, device=x.device)生成随机数(生成均值为0,标准差为1的正态# 分布随机数)形状和shape一致的也就是和x一致,数据类型,设备都和x一致
# 将随机数和keep_prob相加得到随机数(范围[keep_prob,1+keep_prob])
    random_tensor.floor_() 
# 二值化,生成 0 或 1 的 mask
# 也就是将随机数向下取整
    output = x.div(keep_prob) * random_tensor
#x.div(keep_prob)将输入张量x的所有值除以keep_prob,目的是 放大保留下来的部分

#* random_tensor根据0 或 1 的 mask决定哪些路径会被保留(1)或丢弃(0)
    return output

为什么要放大保留下来的部分:

  • 丢弃路径会导致部分值被置为零,模型整体输出的总期望值会下降。
  • 为了补偿这种下降,需要对保留下来的部分放大,使得丢弃路径后的总期望值和丢弃前一致。

因为只是补偿所以并不一定等与原期望

数学解释:

假设输入张量是 x=\begin{bmatrix} x_{1,}&x_{2,} & ... &, x_{n} \end{bmatrix},其中每个元素 xi表示特征。

期望:E=\frac{1}{n}\sum_{1}^{n}x_{i}

丢弃之后:E=\frac{1}{n}\sum_{1}^{n}{keepprob}\cdot x_{i}

放大之后:E=\frac{1}{n}\sum_{1}^{n}\frac{​{keepprob}\cdot x_{i}}{keepprob}=\frac{1}{n}\sum_{1}^{n}x_{i}

实例:

python">import torch
import torch.nn as nn# 定义 DropPath 类
class DropPath(nn.Module):def __init__(self, drop_prob=None):super().__init__()self.drop_prob = drop_probdef forward(self, x):return drop_path(x, self.drop_prob, self.training)def drop_path(x, drop_prob: float = 0., training: bool = False):if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_probshape = (x.shape[0],) + (1,) * (x.ndim - 1)  # 形状:(batch_size, 1, 1, ...)random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)random_tensor.floor_()  # 二值化,生成 0 或 1 的 maskprint(f'mask: {random_tensor}')output = x.div(keep_prob) * random_tensorreturn output# 定义简单模型
class SimpleModel(nn.Module):def __init__(self, drop_prob):super().__init__()self.linear = nn.Linear(4, 4)  # 简单的线性层self.drop_path = DropPath(drop_prob)  # 使用 DropPathself.activation = nn.ReLU()  # ReLU 激活def forward(self, x):print("输入数据:")print(x)x = self.linear(x)  # 线性层print("线性层输出:")print(x)x = self.activation(x)  # ReLU 激活print("激活后输出:")print(x)x = self.drop_path(x)  # DropPathprint("DropPath 后输出:")print(x)return x# 创建模型
model = SimpleModel(drop_prob=0.5)
model.train()  # 设置为训练模式以启用 DropPath# 输入数据
input_data = torch.tensor([[1.0, 2.0, 3.0, 4.0],[5.0, 6.0, 7.0, 8.0]], dtype=torch.float32)# 运行模型
output = model(input_data)

输出: 简单理解就是根据mask的1,0值对每个样本进行保留或置零

输入数据:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
线性层输出:
tensor([[ 1.2836, -1.4602,  2.2660, -1.7250],
        [ 1.3035, -4.1391,  4.5453, -2.5738]], grad_fn=<AddmmBackward0>)
激活后输出:
tensor([[1.2836, 0.0000, 2.2660, 0.0000],
        [1.3035, 0.0000, 4.5453, 0.0000]], grad_fn=<ReluBackward0>)
mask: tensor([[1.],
        [0.]])
DropPath 后输出:
tensor([[2.5672, 0.0000, 4.5321, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<MulBackward0>)

 

 


http://www.ppmy.cn/server/146833.html

相关文章

[python脚本处理文件入门]-17.Python如何操作Excel文件的读写

哈喽,大家好,我是木头左! 在Python中,处理Excel文件最常用的库之一是xlrd,它用于读取Excel文件。而当需要创建或写入Excel文件时,xlwt库则是一个不错的选择。这两个库虽然功能强大,但使用起来也非常简单直观。 安装与导入 确保你已经安装了这两个库。如果没有安装,可以…

单点登录深入详解之技术方案总结

技术方案之CAS认证 概述 CAS 是耶鲁大学的开源项目&#xff0c;宗旨是为 web 应用系统提供一种可靠的单点登录解决方案。 CAS 从安全性角度来考虑设计&#xff0c;用户在 CAS 输入用户名和密码之后通过ticket进行认证&#xff0c;能够有效防止密码泄露。 CAS 广泛使用于传统应…

sql分类

SQL&#xff08;Structured Query Language&#xff09;是一种用于管理和操作关系数据库管理系统&#xff08;RDBMS&#xff09;的编程语言。SQL 可以分为几个主要类别&#xff0c;每个类别都有其特定的用途和功能。以下是 SQL 的主要分类&#xff1a; 1. 数据定义语言&#x…

map用于leetcode

//第一种map方法 function groupAnagrams(strs) {let map new Map()for (let str of strs) {let key str ? : str.split().sort().join()if (!map.has(key)) {map.set(key, [])}map.get(key).push(str)} //此时map为Map(3) {aet > [ eat, tea, ate ],ant > [ tan,…

【LeetCode: 3232. 判断是否可以赢得数字游戏 + 模拟】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

多线程安全单例模式的传统解决方案与现代方法

在多线程环境中实现安全的单例模式时&#xff0c;传统的双重检查锁&#xff08;Double-Checked Locking&#xff09;方案和新型的std::once_flag与std::call_once机制是两种常见的实现方法。它们在实现机制、安全性和性能上有所不同。 1. 传统的双重检查锁方案 双重检查锁&am…

数据结构 (16)特殊矩阵的压缩存储

前言 特殊矩阵的压缩存储是数据结构中的一个重要概念&#xff0c;它旨在通过找出特殊矩阵中值相同的矩阵元素的分布规律&#xff0c;把那些呈现规律性分布的、值相同的多个矩阵元素压缩存储到一个存储空间中&#xff0c;从而节省存储空间。 一、特殊矩阵的定义 特殊矩阵是指具有…

【ETCD】etcd简单入门之基础操作基于etcdctl进行操作

这里将使用etcdctl命令行工具来进行演示&#xff0c; 1、使用put命令向etcd写入kv对 使用etcdctl put命令来设置键值对。put命令接受两个参数&#xff1a;键和值 使用方法&#xff1a; NAME:put - Puts the given key into the storeUSAGE:etcdctl put [options] <key&g…