visionTransformer window平台下报错

news/2024/9/22 17:26:03/
  • 错误:
KeyError: 'Transformer/encoderblock_0/MlpBlock_3/Dense_0kernel is not a file in the archive'
  • 解决方法:

修改这个函数即可,主要原因是Linux系统与window系统路径分隔符不一样导致

    def load_from(self, weights, n_block):ROOT = f"Transformer/encoderblock_{n_block}"with torch.no_grad():# query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()# key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()# value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()# out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()query_weight = np2th(weights[(ROOT + '/' + ATTENTION_Q + "/kernel")]).view(self.hidden_size,self.hidden_size).t()key_weight = np2th(weights[(ROOT + '/' + ATTENTION_K + "/kernel")]).view(self.hidden_size,self.hidden_size).t()value_weight = np2th(weights[(ROOT + '/' + ATTENTION_V + "/kernel")]).view(self.hidden_size,self.hidden_size).t()out_weight = np2th(weights[(ROOT + '/' + ATTENTION_OUT + "/kernel")]).view(self.hidden_size,self.hidden_size).t()# query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)# key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)# value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)# out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)query_bias = np2th(weights[(ROOT + '/' + ATTENTION_Q + "/bias")]).view(-1)key_bias = np2th(weights[(ROOT + '/' + ATTENTION_K + "/bias")]).view(-1)value_bias = np2th(weights[(ROOT + '/' + ATTENTION_V + "/bias")]).view(-1)out_bias = np2th(weights[(ROOT + '/' + ATTENTION_OUT + "/bias")]).view(-1)self.attn.query.weight.copy_(query_weight)self.attn.key.weight.copy_(key_weight)self.attn.value.weight.copy_(value_weight)self.attn.out.weight.copy_(out_weight)self.attn.query.bias.copy_(query_bias)self.attn.key.bias.copy_(key_bias)self.attn.value.bias.copy_(value_bias)self.attn.out.bias.copy_(out_bias)mlp_weight_0 = np2th(weights[(ROOT + '/' + FC_0 + "/kernel")]).t()mlp_weight_1 = np2th(weights[(ROOT + '/' + FC_1 + "/kernel")]).t()mlp_bias_0 = np2th(weights[(ROOT + '/' + FC_0 +"/bias")]).t()mlp_bias_1 = np2th(weights[(ROOT + '/' + FC_1 + "/bias")]).t()self.ffn.fc1.weight.copy_(mlp_weight_0)self.ffn.fc2.weight.copy_(mlp_weight_1)self.ffn.fc1.bias.copy_(mlp_bias_0)self.ffn.fc2.bias.copy_(mlp_bias_1)self.attention_norm.weight.copy_(np2th(weights[(ROOT + '/' + ATTENTION_NORM + "/scale")]))self.attention_norm.bias.copy_(np2th(weights[(ROOT + '/' + ATTENTION_NORM +  "/bias")]))self.ffn_norm.weight.copy_(np2th(weights[(ROOT + '/' + MLP_NORM + "/scale")]))self.ffn_norm.bias.copy_(np2th(weights[(ROOT + '/' +  MLP_NORM + "/bias")]))

在这里插入图片描述


http://www.ppmy.cn/news/1437699.html

相关文章

使用rust学习基本算法(三)

使用rust学习基本算法(三) 动态规划 动态规划算法是一种在数学、管理科学、计算机科学和经济学中广泛使用的方法,用于解决具有重叠子问题和最优子结构特性的问题。它通过将复杂问题分解为更小的子问题来解决问题,这些子问题被称为…

第7章 面向对象基础-下(内部类)

7.6 内部类(了解) 7.6.1 概述 1、什么是内部类? 将一个类A定义在另一个类B里面,里面的那个类A就称为内部类,B则称为外部类。 2、为什么要声明内部类呢? 总的来说,遵循高内聚低耦合的面向对象开发总原则。便于代码…

使用Azure AI Search和LlamaIndex构建高级RAG应用

RAG 是一种将公司信息合并到基于大型语言模型 (LLM) 的应用程序中的常用方法。借助 RAG,AI 应用程序可以近乎实时地访问最新信息,团队可以保持对其数据的控制。 在 RAG 中,您可以评估和修改各个阶段以改进结果&#x…

Node.js -- path模块

path.resolve(常用) // 导入fs const fs require(fs); // 写入文件 fs.writeFileSync(_dirname /index.html,love); console.log(_dirname /index.html);// D:\nodeJS\13-path\代码/index.html 我们之前使用的__dirname 路径 输出的结果前面是正斜杠/ ,后面部分是…

acwing算法提高之图论--欧拉回路和欧拉路径

目录 1 介绍2 训练 1 介绍 本专题用来记录欧拉回路和欧拉路径相关的题目。 相关结论: (1)对于无向图,所有边都是连通的。 (1.1)存在欧拉路径的充要条件:度数为奇数的结点只能是0个或者2个。 &…

【银角大王——Django课程——创建项目+部门表的基本操作】

Django框架员工管理系统——创建项目部门表管理 员工管理系统创建项目命令行的形式创建Django项目——创建app注册app——在sttings中的INSTALLED_APPS [ ]数组中注册 设计表结构(django)连接数据库——在settings里面改写DATABASESDjango命令执行生成数…

网贷大数据黑名单要多久才能变正常?

网贷大数据黑名单是指个人在网贷平台申请贷款时,因为信用记录较差而被列入黑名单,无法获得贷款或者贷款额度受到限制的情况。网贷大数据黑名单的具体时间因个人信用状况、所属平台政策以及银行审核标准不同而异,一般来说,需要一定…

TDesign:腾讯的企业级前端框架,对标elementUI和ant-design

elementUI和ant-design在前端开发者中有了很高知名度了,组件和资源十分丰富了。本文介绍腾讯的一款B端框架:TDesign TDesign 是腾讯公司内部推出的企业级设计体系,旨在为腾讯旗下的各种产品提供一致、高效、优质的设计支持。这个设计体系是由…