onnx报错解决-bert

server/2024/11/29 18:16:59/

 

一、定义

  1. UserWarning: Provided key output for dynamic axes is not a valid input/output name warnings.warn(

  2. 案例

  3. 实体识别bert 案例

  4. 转transformers 模型到onnx 接口解读

二、实现

https://huggingface.co/docs/transformers/main_classes/onnx#transformers.onnx.FeaturesManager

  1. UserWarning: Provided key output for dynamic axes is not a valid input/output name warnings.warn(

代码:

with torch.no_grad():symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}torch.onnx.export(model,(inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]),"./saves/bertclassify.onnx",opset_version=14,input_names=["input_ids", "token_type_ids", "attention_mask"],        output_names=["logits"],                                             dynamic_axes =    {'input_ids': symbolic_names,'attention_mask': symbolic_names,'token_type_ids': symbolic_names,'logits': symbolic_names})

改正后:原因: input_names 名字顺序与模型定义不一致导致。为了避免错误产生,应该标准化。如下2所示。

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)
model.eval()
onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification']("./saves")
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')
from itertools import chain
with torch.no_grad():symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}torch.onnx.export(model,(inputs["input_ids"],inputs["attention_mask"], inputs["token_type_ids"]),"./saves/bertclassify.onnx",opset_version=14,input_names=["input_ids", "attention_mask", "token_type_ids"],      output_names=["logits"],                                            dynamic_axes =    {name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())}
)
# #验证是否成功
import onnx
onnx_model=onnx.load("./saves/bertclassify.onnx")
onnx.checker.check_model(onnx_model)
print("无报错,转换成功")# #推理
import onnxruntime
ort_session=onnxruntime.InferenceSession("./saves/bertclassify.onnx", providers=['CPUExecutionProvider'])    #加载模型
ort_input={"input_ids":inputs["input_ids"].cpu().numpy(),"token_type_ids":inputs["token_type_ids"].cpu().numpy(),"attention_mask":inputs["attention_mask"].cpu().numpy()}
output_on = ort_session.run(["logits"], ort_input)[0]   #推理print(output_org.detach().numpy())
print(output_on)
assert np.allclose(output_org.detach().numpy(), output_on, 10-5)  #无报错

标准化:

output_onnx_path = "./saves/bertclassify.onnx"
from itertools import chain
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')torch.onnx.export(model,(dummy_inputs,),f=output_onnx_path,input_names=list(onnx_config.inputs.keys()),output_names=list(onnx_config.outputs.keys()),dynamic_axes={name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())},do_constant_folding=True,opset_version=14,
)

全部:

import torch
devices=torch.device("cpu")
from transformers.onnx.features import FeaturesManager
import torch
from transformers import AutoTokenizer, BertForSequenceClassification
import numpy as np
model_path = "./saves"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)words = ["你叫什么名字"]
inputs = tokenizer(words, return_tensors='pt', padding=True)
model.eval()onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification']("./saves")
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')
from itertools import chainoutput_org = model(**inputs).logitstorch.onnx.export(model,(dummy_inputs,),f=output_onnx_path,input_names=list(onnx_config.inputs.keys()),output_names=list(onnx_config.outputs.keys()),dynamic_axes={name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())},do_constant_folding=True,opset_version=14,
)# #验证是否成功
import onnx
onnx_model=onnx.load("./saves/bertclassify.onnx")
onnx.checker.check_model(onnx_model)
print("无报错,转换成功")# #推理
import onnxruntime
ort_session=onnxruntime.InferenceSession("./saves/bertclassify.onnx", providers=['CPUExecutionProvider'])    #加载模型
ort_input={"input_ids":inputs["input_ids"].cpu().numpy(),"token_type_ids":inputs["token_type_ids"].cpu().numpy(),"attention_mask":inputs["attention_mask"].cpu().numpy()}
output_on = ort_session.run(["logits"], ort_input)[0]   #推理print(output_org.detach().numpy())
print(output_on)
assert np.allclose(output_org.detach().numpy(), output_on, 10-5)  #无报错

无任何警告产生

                 

  1. 实体识别案例

import onnxruntime
from itertools import chain
from transformers.onnx.features import FeaturesManagerconfig = ner_config
tokenizer = ner_tokenizer
model = ner_model
output_onnx_path = "bert-ner.onnx"onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification'](config)
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')torch.onnx.export(model,(dummy_inputs,),f=output_onnx_path,input_names=list(onnx_config.inputs.keys()),output_names=list(onnx_config.outputs.keys()),dynamic_axes={name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())},do_constant_folding=True,opset_version=onnx_config.default_onnx_opset,       #默认,报错改为14
)
  1. 转transformers 模型到onnx 接口解读

Huggingface:导出transformers模型到onnx_ONNX_程序员架构进阶_InfoQ写作社区

https://zhuanlan.zhihu.com/p/684444410


http://www.ppmy.cn/server/145959.html

相关文章

可视化建模以及UML期末复习篇----相关软件安装

作为一个过来人&#xff0c;我的建议是别过来。 一、可视化建模 <1>定义: 官方&#xff1a;一种使用图形符号来表示系统结构和行为的建模技术。 我&#xff1a;其实说白了就是把工作流程用图形画出来。懂不&#xff1f; <2>作用: 提高理解和分析复杂系统的能力。促…

【Linux】 进程是什么

0. 什么是进程&#xff0c;为什么要有进程&#xff1f; 1.操作系统为了更好的管理我们的软硬件&#xff0c;抽象出了许多概念&#xff0c;其中比较有代表的就是进程了。通俗的来说操作系统为了更好的管理加载到内存的程序&#xff0c;故引入进程的概念。 2.在操作系统学科中用P…

微信小游戏/抖音小游戏SDK接入踩坑记录

文章目录 前言问题记录1、用是否存在 wx 这个 API 来判断是微小平台还是抖小平台不生效2、微小支付的参数如何获取?3、iOS 平台不支持虚拟支付怎么办?微小 iOS 端支付时序图:抖小 iOS 端支付:4、展示广告时多次回调 onClose5、在使用单例时 this 引起的 bug6、使用 fetch 或…

【初级测试常用的sql命令及实例解析】

连接数据库 命令行语句&#xff08;以MySQL为例&#xff09;&#xff1a;mysql -u username -p。其中-u表示指定用户名&#xff0c;-p表示需要输入密码。解析&#xff1a;这是登录MySQL数据库服务器的基本命令。执行后&#xff0c;系统会提示输入密码&#xff0c;正确输入密码后…

C语言中常用的失败退出和成功返回

在 C 语言中&#xff0c;封装函数时&#xff0c;我们通常需要判断函数调用是否成功&#xff0c;并据此采取不同的操作。例如&#xff0c;在调用系统函数或库函数时&#xff0c;我们通常会使用一些错误处理机制&#xff0c;如 perror()、exit()、return 等&#xff0c;来输出错误…

102.【C语言】数据结构之用堆对数组排序

0.前置知识 向上调整: 向下调整: 1.对一个无序的数组排升序和降序 排升序问题 建大根堆还是小根堆? 错误想法 由小根堆的定义:树中所有的父节点的值都小于或等于孩子节点的值,这样排出来的数组时升序的,建小根堆调用向上调整函数即可(把画圈的地方改成<即可) arr未…

字符函数和字符串函数

字符分类函数 C语言中有⼀系列的函数是专门做字符分类的&#xff0c;也就是⼀个字符是属于什么类型的字符的。 这些函数的使用都需要包含⼀个头文件&#xff1a;ctype.h 这些函数的用法非常类似。 int islower ( int c )islower是能够判断参数部分是否是小写字母的。 通过返…

虚幻引擎---目录结构篇

一、引擎目录 成功安装引擎后&#xff0c;在安装路径下的Epic Games目录中可以找到与引擎版本对应的文件夹&#xff0c;其中的内容如下&#xff1a; Engine&#xff1a;包含构成引擎的所有源代码、内容等。 Binaries&#xff1a;包含可执行文件或编译期间创建的其他文件。Bui…