【Pytorch】torch.nn.functional模块中的非线性激活函数

news/2024/11/20 2:43:49/

        在使用torch.nn.functional模块时,需要导入包:

from torch.nn import functional

        以下是常见激活函数的介绍以及对应的代码示例:

tanh (双曲正切)

输出范围:(-1, 1)

特点:中心对称,适合处理归一化后的数据。
公式:tanh(x) = (e^x - e^{-x}) / (e^x + e^{-x})

import torch
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
y = torch.nn.funcationl.tanh(x)
print(y)  # 输出:tensor([-0.9640, -0.7616,  0.0000,  0.7616,  0.9640])

sigmoid (S形函数)

输出范围:(0, 1)
特点:用于将输入映射到概率值,但可能会导致梯度消失问题。
公式:sigmoid(x) = 1 / (1 + e^{-x})

y = torch.nn.funcational.sigmoid(x)
print(y)  # 输出:tensor([0.1192, 0.2689, 0.5000, 0.7311, 0.8808])

SiLU (Sigmoid Linear Unit,也称Swish) 

输出范围:(0, x)
特点:结合了线性和非线性特性,效果较好。
公式:silu(x) = x * sigmoid(x)

y = torch.nn.funcationl.silu(x)
print(y)  # 输出:tensor([-0.2384, -0.2689,  0.0000,  0.7311,  1.7616])

GELU (Gaussian Error Linear Unit)

输出范围:接近ReLU,但更加平滑。
特点:常用于Transformer模型。
公式:近似为:gelu(x) ≈ x * sigmoid(1.702 * x)

y = torch.nn.functional.gelu(x)
print(y)  # 输出:tensor([-0.0454, -0.1588,  0.0000,  0.8413,  1.9546])

ReLU (Rectified Linear Unit)

输出范围:[0, +∞)
特点:简单高效,是最常用的激活函数之一。
公式:relu(x) = max(0, x)

y = torch.nn.funcationl.relu(x)
print(y)  # 输出:tensor([0., 0., 0., 1., 2.])

ReLU_ (In-place ReLU)

输出范围:[0, +∞)
特点:修改原张量而不是生成新的张量,节省内存。

x.relu_()  # 注意:会改变x本身
print(x)  # x的值被修改为:tensor([0., 0., 0., 1., 2.])

Leaky ReLU

输出范围:(-∞, +∞)
特点:允许负值有较小的输出,避免死神经元问题。
公式:leaky_relu(x) = x if x > 0 else alpha * x

x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
y = torch.nn.functional.leaky_relu(x, negative_slope=0.01)
print(y)  # 输出:tensor([-0.0200, -0.0100,  0.0000,  1.0000,  2.0000])

Leaky ReLU_ (In-place Leaky ReLU)

特点:和ReLU_一样会修改原张量。

x.leaky_relu_(negative_slope=0.01)
print(x)  # x的值被修改

Softmax

输出范围:(0, 1),且所有输出的和为1。
特点:常用于多分类任务的最后一层。
公式:softmax(x)_i = exp(x_i) / sum(exp(x_j))

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.nn.functional.softmax(x, dim=0)
print(y)  # 输出:tensor([0.0900, 0.2447, 0.6652])

Threshold

输出范围:手动设置的范围。
特点:小于阈值的数被置为设定值,大于等于阈值的数保持不变。

x = torch.tensor([-1.0, 0.0, 1.0, 2.0])
y = torch.nn.functional.threshold(x, threshold=0.5, value=0.0)
print(y)  # 输出:tensor([0., 0., 0., 2.])

Normalize

功能:将张量的值标准化到指定范围。

公式:normalize(x) = x / max(||x||, eps)

x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = torch.nn.functional.normalize(x, p=2, dim=1)
print(y)  # 输出:标准化到单位向量

http://www.ppmy.cn/news/1548383.html

相关文章

精通Rust系统教程-过程宏入门

本文介绍Rust过程宏定义、分类及应用示例。假设你已经熟悉Rust及基本概念、如数据类型、迭代和特性(traits)。 Rust宏简介 宏是Rust编程语言的重要组成部分,当你学习Rust语言时,很快就会遇到它们。Rust宏以最简单的方式让你在编译…

使用WebVTT和Track API增强HTML5视频的可访问性和互动性

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 使用WebVTT和Track API增强HTML5视频的可访问性和互动性 使用WebVTT和Track API增强HTML5视频的可访问性和互动性 使用WebVTT和T…

【redis】—— 初识redis(redis基本特征、应用场景、以及重大版本说明)

序言 本文将引导读者探索Redis的世界,深入了解其发展历程、丰富特性、常见应用场景、使用技巧等,最后会对Redis演进过程中具有里程碑意义的版本进行详细解读。 目录 (一)初始redis (二)redis特性 &#…

1Panel 推送 SSL 证书到阿里云、腾讯云

本文首发于 Anyeの小站,点击链接 访问原文体验更佳 前言 都用 CDN 了还在乎那点 1 年证书钱么? 开句玩笑话,按照 Apple 的说法,证书有效期不该超过 45 天。那么证书有效期的缩短意味着要更频繁地更新证书。对于我这样的“裸奔”…

ubuntu连接orangepi-zero-2w桌面的几种方法

ubuntu连接orangepi-zero-2w桌面的几种方法 一 : 串口 wifi Nomachine 1 开发板通过串口连接wifi 扫描wifi nmcli dev wifi连接wifi sudo nmcli dev wifi connect wifi_name password wifi_passwd查看开发板IP ifconfig # 假设开发板IP是 192.168.2.32 使用…

OpenAI Adjusts Strategy as ‘GPT’ AI Progress Slow

注:本文为两篇关于当前大模型方向讨论的文章。 OpenAI 大改下代大模型方向,scaling law 撞墙?AI 社区炸锅了 机器之心 2024 年 11 月 11 日 11:57 北京 机器之心报道 编辑:Panda、泽南 大模型的 scaling law 到头了&#xff1f…

POD-Transformer多变量回归预测(Matlab)

目录 效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现POD-Transformer多变量回归预测,本征正交分解数据降维融合Transformer多变量回归预测,使用SVD进行POD分解(本征正交分解); 2.运行环境Matlab20…

【电子设计】按键LED控制与FreeRTOS

1. 安装Keilv5 打开野火资料,寻找软件包 解压后得到的信息 百度网盘 请输入提取码 提取码:gfpp 安装526或者533版本都可以 下载需要的 F1、F4、F7、H7 名字的 DFP pack 芯片包 安装完 keil 后直接双击安装 注册操作,解压注册文件夹后根据里面的图示步骤操作 打开说明 STM…