pytorch 比较两个张量的是否相等的函数介绍

server/2025/1/11 18:55:12/

在 PyTorch 中,可以使用多种函数来比较两个张量是否相等,具体选择取决于对比较精度的需求以及可能的数值误差。以下是常用的比较方法:


1. 完全相等的比较

(1) torch.eq

逐元素比较两个张量是否相等,返回布尔张量。

import torcha = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 4])result = torch.eq(a, b)
print(result)  # 输出: tensor([True, True, False])

(2) torch.equal

检查两个张量是否完全相等(不仅要求每个元素相等,还要求形状相同)。

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])result = torch.equal(a, b)
print(result)  # 输出: True

2. 近似相等的比较

(1) torch.isclose

用于判断两个张量是否在一定容差范围内逐元素接近。

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.00001, 3.1])result = torch.isclose(a, b, rtol=1e-05, atol=1e-08)
print(result)  # 输出: tensor([True, True, False])
  • rtol: 相对容差
  • atol: 绝对容差
(2) torch.allclose

检查两个张量的所有元素是否在一定容差范围内近似相等。

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.00001, 3.0])result = torch.allclose(a, b, rtol=1e-05, atol=1e-08)
print(result)  # 输出: True

torch.allclose 是对 torch.isclose 的一个整体检查版本,只有当所有元素都接近时才返回 True

3. 逐元素绝对差的比较

(1) 自定义比较

如果需要更灵活的比较,可以直接计算差值并进行自定义判断。

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.00001, 3.1])diff = torch.abs(a - b)  # 计算绝对差
result = diff < 1e-05  # 判断是否小于某个阈值
print(result)  # 输出: tensor([True, True, False])

4. 总结

函数用途
torch.eq逐元素比较是否完全相等,返回布尔张量。
torch.equal检查两个张量是否完全相同(包括形状和元素),只返回一个布尔值。
torch.isclose逐元素比较是否近似相等,允许一定容差。
torch.allclose检查所有元素是否都在容差范围内近似相等,只返回一个布尔值。

选择合适的函数取决于具体需求:

  • 完全相等用 torch.eq 或 torch.equal
  • 近似相等用 torch.isclose 或 torch.allclose


http://www.ppmy.cn/server/157545.html

相关文章

Flink三种集群部署模型

这里写自定义目录标题 Flink 集群剖析Flink 应用程序执行Flink Session 集群&#xff08;Session Mode&#xff09;Flink Job 集群&#xff08;以前称为per-job&#xff09;Flink Application 集群&#xff08;Application Mode&#xff09; 参考 Flink 集群剖析 Flink 运行时…

字典树 / trie树

定义 当我手里有若干个字符串的时候&#xff0c;现在向你询问某个字符串时候是前面的这些字符串中的其中之一。如果我们用暴力的做法来求解的话&#xff0c;我可能需要对这些字符串进行逐一比对&#xff0c;效率是相当低的。那么这个时候我们就可以用 trie 树的结构简单高效的…

数据结构-串

串的实现 在C语言中所使用的字符串就是串的数据类型的一种。 串的存储结构 定长顺序存储表示 类似于线性表的顺序存储结构&#xff0c;用一组连续的存储单元存储串值的字符序列。 #define MAXLEN 255 //预定义最大串长为255 ​ typedef struct SString {char ch[MAXLEN]; …

el-descriptions-item使用span占行不生效

需要实现的效果是客户状态单独占满一行 错误代码&#xff1a; <el-descriptions title"基本信息" :column"3"> <el-descriptions-item label"公司电话:">Suzhou</el-descriptions-item><el-descriptions-item label"…

【Rust自学】11.7. 按测试的名称运行测试

喜欢的话别忘了点赞、收藏加关注哦&#xff08;加关注即可阅读全文&#xff09;&#xff0c;对接下来的教程有兴趣的可以关注专栏。谢谢喵&#xff01;(&#xff65;ω&#xff65;) 11.7.1. 按名称运行测试的子集 如果想要选择运行的测试&#xff0c;就将测试的名称&#x…

使用Python爬虫获取淘宝商品详情接口

以下是一篇关于使用Python获取淘宝商品详情接口的长篇文章&#xff1a; 淘宝商品详情接口简介 淘宝商品详情接口是淘宝开放平台提供的API之一&#xff0c;用于获取淘宝商品的详细信息。它可以帮助开发者获取商品的标题、价格、图片、库存、销量、评价等数据。这些数据对于电商…

理解Unity脚本编译过程:程序集

https://docs.unity3d.com/Manual/script-compilation.html 关于Unity C#脚本编译的细节&#xff0c;其中一个比较重要的知识点就是如何自定义Assembly。 预定义的assembly 默认情况下&#xff0c;Unity会按照这个规则进行编译。 PhaseAssembly nameScript files1Assembly-…

数组分割函数

这是一个数组分割函数&#xff0c;它的作用是将一个大数组按照指定的长度分割成多个小数组。 参数说明&#xff1a; array: 需要被分割的原始数组 subGroupLength: 每个小数组的长度 工作原理&#xff1a; splitArray(array, subGroupLength) {let index 0; …