Google FLASH-QUAD Transformer模型的设计雷点

news/2024/12/28 23:51:00/

这个模型用来做无序弱监督分类,效果好,特别是收敛速度比标准多头Attention层快多了,完全没得比。

问题1

但这模型我用来做自回归生成,非常垃圾。
同时尝试了 GPT 和 T5 这两种模型结构的设计,明明Loss正常下降,可是自回归生成性能非常的烂,不知原因为何。

不服输,最近再来尝试FLASH,毕竟性能太过于吸引人。碰巧单步调试了一下自回归生成的过程。
卧槽,意外发现cause掩码失效,前一个时间步的输出会被后一个时间步的输入影响,

一步步排查,排查到注意力矩阵的生成
注意到这个 1/n 的 n 是可变的。直接把 n 去掉,使注意力矩阵的值不再受序列长度的缩放。
下图来自苏神的博客
在这里插入图片描述
对应到代码,在 lucidrains 的代码里面 https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py#L190

sim = einsum('b i d, b j d -> b i j', q, k) / seq_len

我将其改为一个定值

sim = einsum('b i d, b j d -> b i j', q, k) / q.shape[-1]

改为,现在 前一个时间步的输出不再 被后一个时间步的输入影响了。

问题1.1

改为定值后,尚未实验,但预计超出训练长度后(例如最大训练文本长度为512,测试文本长度为768),性能会有显著下降。

问题2

修改完,初步的训练后,自回归生成能力有了大幅的提升了。
但仍然存在问题,这个注意力方法的局部关注能力似乎很弱,意思为经常见到连续生成同义的词
例如(空格代表分词)
标签为

树叶 静静地 燃烧 起来

自回归生成(使用sample策略)多见这样的生成范式(不是必定出现)

树叶 静静地 安静地 燃烧 起来

相近意思的词会有时多生成一次,一般的多头注意力出现这样的情况非常少见,推测该设计的局部关注能力较弱。

类似的讨论

https://github.com/JunnYu/FLASHQuad_pytorch/issues/1


http://www.ppmy.cn/news/47428.html

相关文章

Python‘s Standard Library :Networking

Python’s Standard Library :Networking Python的标准库为创建网络服务和远程访问服务提供了一些模块。例如:ipaddress, socket, socketserver 等。 Python’s standard library comes complete with modules for creating network services, as well …

JAVA面试宝典: SpringCloud知识点(通俗易懂易背)

1、什么是 Spring Cloud? Spring Cloud 是基于 Spring Boot 的微服务架构开发工具箱,提供了在分布式系统中构建可靠的、弹性的、灵活的应用所需的大多数工具。Spring Cloud 中包含的子项目如下: Spring Cloud Config:配置管理工具…

STL :双端队列容器 Deque

Deque #include<deque> using namesace std; 双端队列容器 &#xff1a;双向开口的连续线性空间&#xff1b; 擅长尾部和头部添加或删除元素&#xff1a;常数阶&#xff1b; 存储元素并不能保证所有元素都存储到连续的内存空间中&#xff1b; deque 是动态的以分段…

系统需求分析

系统需求分析 需求分析是软件生存周期中相当重要的一个阶段。由于开发人员熟悉计算机但不熟悉应用 领域的业务&#xff0c;用户熟悉应用领域的业务但不熟悉计算机&#xff0c;因此对于同一个问题&#xff0c;开发人员和用 户之间可能存在认识上的差异。在需求分析阶段&#xff…

java记录-lambda表达式、接口应用、方法引用

基本形式 (str)->{System.out.println(str) };调用作为参数的接口实例的方法 1、用一个类实现接口&#xff0c;然后使用该类实例调用方法 2、匿名内部类 3、在 接口&#xff08;不能是抽象类&#xff09; 有且只有一个抽象方法时&#xff0c;可以使用lamda表达式来重写这个…

蓝桥 卷“兔”来袭编程竞赛专场-07明码加密 题解

赛题介绍 挑战介绍 清末&#xff0c;电报技术进入中国。上海大北水线电报公司在 1871 年选用了六千八百九十七个汉字&#xff0c;代以四码数字&#xff0c;编写成了中国最早的电报明码本。为了传输的内容可以保密&#xff0c;又设计出了将明码本加密的方法&#xff0c;于是就…

华为OD机试真题(Java),最小步骤数(100%通过+复盘思路)

一、题目描述 一个正整数数组 设为nums&#xff0c;最大为100个成员&#xff0c;求从第一个成员开始正好走到数组最后一个成员所使用的最小步骤数。 要求&#xff1a; 第一步 必须从第一元素起 且 1<第一步步长<len/2 (len为数组长度)&#xff1b;从第二步开始只能以所…

Junit概述和快速入门

单元测试概述 在程序中&#xff0c;一个单元可以是一个完整的模块&#xff0c;但它通常是一个单独的方法或者程序 在面向对象的编程中&#xff0c;一个单元通常是整个界面&#xff0c;例如类&#xff0c;但可能是单个方法 JUnit是一个java编程语言的单元测试框架 通过先为最…