使用numpy或pytorch校验两个张量是否相等

server/2024/10/11 5:21:51/

文章目录


做算法过程中,如果涉及到模型落地,那必然会将原始的深度学习的框架训练好的模型转换成目标硬件模型的格式,如onnx,tensorrt,openvino,tflite;那么就有对比不同格式模型输出的一致性,从而判断模型转换是否成功。

numpy_2">1、numpy

用到的核心代码就一行,就是:

import numpy as np
np.testing.assert_allclose(actual,expected,rtol,atol)

上代示例:

import numpy as np# 定义两个数组
actual= np.array([1.0, 2.0, 3.0])
expected = np.array([1.01, 1.99, 3.0])# 使用 np.testing.allclose 检查它们是否近似相等
np.testing.assert_allclose(actual,expected,rtol=0,atol=0.01)

输出:
在这里插入图片描述
最大的绝对误差是0.01,最大的相对误差是0.00990099.
再一个示例:

import numpy as np# 定义两个数组
actual= np.array([1.0, 2.0, 3.0])
expected = np.array([1.01, 1.99, 3.0])# 使用 np.testing.allclose 检查它们是否近似相等
np.testing.assert_allclose(actual,expected,rtol=0,atol=0.0100001)

只是改了atol 从0.01改成0.0100001。

所以关于rtol和atol做如下理解:
rtol 就是relative tolarance ,atol 就absolute tolarance.
先计算绝对误差:

diff = abs(actual-expecd) #绝对误差
tolarance = atol+ rtol*abs(expected) #误差容忍上限
if diff<tolarance:pass
else:print("报错信息,如图,有最大绝对误差 最大相对误差 不相等的百分比等")

最大绝对误差= max(diff)
最大相对误差= max(diff)/abs(expected)

函数默认的 atol=1e-7,rtol=0
但考虑到float32精度,有效数字也就7位,可以设置atol=1e-5,小数点后五位有效数字即可。

pytorch_53">2、pytorch

pytorch有相似的api:

import numpy as np
import torch
actual= np.array([1.0, 2.0, 3.0])
expected = np.array([1.01, 1.99, 3.0])
torch.testing.assert_close(torch.tensor(actual),torch.tensor(expected),rtol=0,atol=0.011)

以上不会有任何输出

import numpy as np
import torch
actual= np.array([1.0, 2.0, 3.0])
expected = np.array([1.01, 1.99, 3.0])
torch.testing.assert_close(torch.tensor(actual),torch.tensor(expected),rtol=0,atol=0.01)

在这里插入图片描述
相比numpy,多给出了相关最大误差的位置及允许的上限。


http://www.ppmy.cn/server/42263.html

相关文章

【计算机毕业设计】基于SSM++jsp的学院党员管理系统【源码+lw+部署文档+讲解】

目录 目 录 第1章 绪论 1.1 课题背景 1.2 课题意义 1.3 研究内容 第2章 开发环境与技术 2.1 MYSQL数据库 2.2 JSP技术 2.3 SSM框架 第3章 系统分析 3.1 可行性分析 3.1.1 技术可行性 3.1.2 经济可行性 3.1.3 操作可行性 3.2 系统流程 3.2.1 操作流程 3.2.2 登录流程 3.2.3 删…

Jmeter用jdbc实现对数据库的操作

我们在用Jmeter进行数据库的操作时需要用到配置组件“JDBC Connection Configuration”&#xff0c;通过配置相应的驱动能够让我们通过Jmeter实现对数据库的增删改查&#xff0c;这里我用的mysql数据库一起来看下是怎么实现的吧。 1.驱动包安装 在安装驱动之前我们要先查看当前…

MySQL事务(一)

事务是什么 在MySQL中&#xff0c;事务是一组操作&#xff0c;这些操作要么全部执行成功&#xff0c;要么全部失败。事务的主要目的是保证数据的一致性和完整性。它确保当我们对数据库进行一系列操作时&#xff0c;要么所有操作都生效&#xff0c;要么如果其中任何一个操作失败…

【八十七】【算法分析与设计】单调栈全新版本,右大于,左小于右小于等于,739. 每日温度,907. 子数组的最小值之和

739. 每日温度(右大于) 给定一个整数数组 temperatures &#xff0c;表示每天的温度&#xff0c;返回一个数组 answer &#xff0c;其中 answer[i] 是指对于第 i 天&#xff0c;下一个更高温度出现在几天后。如果气温在这之后都不会升高&#xff0c;请在该位置用 0 来代替。 示…

github新手用法

目录 1&#xff0c;github账号注册2&#xff0c;github登录3&#xff0c;新建一个仓库4&#xff0c;往仓库里面写入东西或者上传东西5&#xff0c; 下载Git软件并安装6 &#xff0c;获取ssh密钥7&#xff0c; 绑定ssh密钥8&#xff0c; 测试本地和github是否联通9&#xff0c;从…

Vue 之 后台管理系统的权限路由的管理

目录 前言实现理解三者的概念以及之间的关联账号&#xff08;用户&#xff09;角色菜单 用户权限授权相关概念实现代码实现登录跳转路由&#xff0c;路由守卫中进行权限验证按钮权限封装指令&#xff1a;调用&#xff08;其中一个页面参考&#xff09; 思路&#xff0c;操作流程…

在springboot项目中自定义404页面

今天点击菜单的时候不小心点开了一个不存在的页面&#xff0c;然后看到浏览器给的一个默认的404页面 后端的程序员都觉得这页面太丑了&#xff0c;那么怎么能自定义404页面呢&#xff1f; 很简单&#xff0c;在我们的springboot的静态资源目录下创建一个error包&#xff0c;然…

外挂知识库的论文总结(后续还会更新)

论文列表&#xff1a; 1.Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks2005.11401 (arxiv.org) 提出了RAG在知识密集型的nlp任务 2.Gar-meets-rag paradigm for zero-shot information re trieval 论文介绍了一种新的信息检索&#xff08;IR&#xff…