onnx报错解决-bert

devtools/2024/11/29 18:52:07/

 

一、定义

  1. UserWarning: Provided key output for dynamic axes is not a valid input/output name warnings.warn(

  2. 案例

  3. 实体识别bert 案例

  4. 转transformers 模型到onnx 接口解读

二、实现

https://huggingface.co/docs/transformers/main_classes/onnx#transformers.onnx.FeaturesManager

  1. UserWarning: Provided key output for dynamic axes is not a valid input/output name warnings.warn(

代码:

with torch.no_grad():symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}torch.onnx.export(model,(inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]),"./saves/bertclassify.onnx",opset_version=14,input_names=["input_ids", "token_type_ids", "attention_mask"],        output_names=["logits"],                                             dynamic_axes =    {'input_ids': symbolic_names,'attention_mask': symbolic_names,'token_type_ids': symbolic_names,'logits': symbolic_names})

改正后:原因: input_names 名字顺序与模型定义不一致导致。为了避免错误产生,应该标准化。如下2所示。

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)
model.eval()
onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification']("./saves")
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')
from itertools import chain
with torch.no_grad():symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}torch.onnx.export(model,(inputs["input_ids"],inputs["attention_mask"], inputs["token_type_ids"]),"./saves/bertclassify.onnx",opset_version=14,input_names=["input_ids", "attention_mask", "token_type_ids"],      output_names=["logits"],                                            dynamic_axes =    {name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())}
)
# #验证是否成功
import onnx
onnx_model=onnx.load("./saves/bertclassify.onnx")
onnx.checker.check_model(onnx_model)
print("无报错,转换成功")# #推理
import onnxruntime
ort_session=onnxruntime.InferenceSession("./saves/bertclassify.onnx", providers=['CPUExecutionProvider'])    #加载模型
ort_input={"input_ids":inputs["input_ids"].cpu().numpy(),"token_type_ids":inputs["token_type_ids"].cpu().numpy(),"attention_mask":inputs["attention_mask"].cpu().numpy()}
output_on = ort_session.run(["logits"], ort_input)[0]   #推理print(output_org.detach().numpy())
print(output_on)
assert np.allclose(output_org.detach().numpy(), output_on, 10-5)  #无报错

标准化:

output_onnx_path = "./saves/bertclassify.onnx"
from itertools import chain
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')torch.onnx.export(model,(dummy_inputs,),f=output_onnx_path,input_names=list(onnx_config.inputs.keys()),output_names=list(onnx_config.outputs.keys()),dynamic_axes={name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())},do_constant_folding=True,opset_version=14,
)

全部:

import torch
devices=torch.device("cpu")
from transformers.onnx.features import FeaturesManager
import torch
from transformers import AutoTokenizer, BertForSequenceClassification
import numpy as np
model_path = "./saves"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)words = ["你叫什么名字"]
inputs = tokenizer(words, return_tensors='pt', padding=True)
model.eval()onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification']("./saves")
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')
from itertools import chainoutput_org = model(**inputs).logitstorch.onnx.export(model,(dummy_inputs,),f=output_onnx_path,input_names=list(onnx_config.inputs.keys()),output_names=list(onnx_config.outputs.keys()),dynamic_axes={name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())},do_constant_folding=True,opset_version=14,
)# #验证是否成功
import onnx
onnx_model=onnx.load("./saves/bertclassify.onnx")
onnx.checker.check_model(onnx_model)
print("无报错,转换成功")# #推理
import onnxruntime
ort_session=onnxruntime.InferenceSession("./saves/bertclassify.onnx", providers=['CPUExecutionProvider'])    #加载模型
ort_input={"input_ids":inputs["input_ids"].cpu().numpy(),"token_type_ids":inputs["token_type_ids"].cpu().numpy(),"attention_mask":inputs["attention_mask"].cpu().numpy()}
output_on = ort_session.run(["logits"], ort_input)[0]   #推理print(output_org.detach().numpy())
print(output_on)
assert np.allclose(output_org.detach().numpy(), output_on, 10-5)  #无报错

无任何警告产生

                 

  1. 实体识别案例

import onnxruntime
from itertools import chain
from transformers.onnx.features import FeaturesManagerconfig = ner_config
tokenizer = ner_tokenizer
model = ner_model
output_onnx_path = "bert-ner.onnx"onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification'](config)
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')torch.onnx.export(model,(dummy_inputs,),f=output_onnx_path,input_names=list(onnx_config.inputs.keys()),output_names=list(onnx_config.outputs.keys()),dynamic_axes={name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())},do_constant_folding=True,opset_version=onnx_config.default_onnx_opset,       #默认,报错改为14
)
  1. 转transformers 模型到onnx 接口解读

Huggingface:导出transformers模型到onnx_ONNX_程序员架构进阶_InfoQ写作社区

https://zhuanlan.zhihu.com/p/684444410


http://www.ppmy.cn/devtools/137990.html

相关文章

Sqoop的安装和配置,Sqoop的数据导入导出,MySQL对hdfs数据的操作

sqoop的安装基础是hive和mysql,没有安装好的同学建议去看一看博主的这一篇文章 Hive的部署,远程模式搭建,centos换源,linux上下载mysql。_hive-4.0.1-CSDN博客 好的那么接下来我们开始表演,由于hive是当时在hadoop03上…

Android so库的编译

在没弄明白so库编译的关系前,直接看网上博主的博文,常常会觉得云里雾里的,为什么一会儿通过Android工程cmake编译,一会儿又通过NDK命令去编译。两者编译的so库有什么区别? android版第三方库编译总体思路: 对于新手小白来说搞明白上面的总体思路图很有必…

PAT甲级 1056 Mice and Rice(25)

文章目录 题目题目大意基本思路AC代码总结 题目 原题链接 题目大意 给定参赛的老鼠数量为NP,每NG只老鼠分为一组,组中最胖的老鼠获胜,并进入下一轮,所有在本回合中失败的老鼠排名都相同,获胜的老鼠继续每NG只一组&am…

学习HTML第三十三天

学习文章目录 一.fieldset 与 legend 的使用(了解)二.表单总结三.框架标签 一.fieldset 与 legend 的使用(了解) fieldset 可以为表单控件分组、 legend 标签是分组的标题 二.表单总结 form表单: action 属性&#…

Spring Boot【四】

单例bean中使用多例bean 1.lookup-method方式实现 当serviceB中调用getServiceA的时候,系统自动将这个方法拦截,然后去spring容器中查找对应的serviceA对象然后返回 2.replaced-method:方法替换 我们可以对serviceB这个bean中的getServiceA…

泷羽sec-linux进阶

基础之linux进阶 声明! 学习视频来自B站up主 泷羽sec 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章,笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人以及泷羽…

软件/游戏提示:mfc42u.dll没有被指定在windows上运行如何解决?多种有效解决方法汇总分享

遇到“mfc42u.dll 没有被指定在 Windows 上运行”的错误提示,通常是因为系统缺少必要的运行库文件或文件损坏。以下是多种有效的解决方法,可以帮助你解决这个问题: 原因分析 出现这个错误的原因是Windows无法找到或加载MFC42u.dll文件。这可…

【jvm】什么是动态编译

目录 1. 说明2. 实现方式3. 应用场景 1. 说明 1.在Java中,动态编译指的是在程序运行时动态地编译Java源代码,生成字节码,并加载到JVM(Java虚拟机)中执行。2.动态编译是在程序运行时,根据需要编译Java源代码…