samout llm解码 幻觉更低更稳定

news/2024/12/20 19:42:29/

这段代码定义了一个简单的对话生成系统,包括模型加载、词汇表加载、以及基于给定提示生成文本的功能。下面是对代码的解析:

  1. load_model_and_voc(device="cpu"):

    • 该函数用于加载预训练的模型和词汇表(vocabulary)。它首先从文件 total_voc.pkl 中加载词汇表,并创建一个名为 SamOut 的神经网络实例。
    • 模型参数的数量被打印出来以供参考。
    • 然后尝试加载指定路径下的预训练权重到模型中,并将模型移动到指定的设备(CPU 或 GPU)上。
    • 最后设置模型为评估模式(.eval()),并返回模型和词汇表。
  2. gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cpu"):

    • 这个函数负责根据提供的提示(prompt)生成新的文本序列。
    • 它接受多个参数,包括词汇表、模型、初始提示、最大生成长度等。
    • 函数内部实现了重复抑制、温度调整和top-k采样等技术来控制生成文本的质量。
    • 使用softmax函数对模型输出进行处理,并通过多类别抽样选择下一个token。
    • 如果生成了特殊的开始标记 <|sos|>,则停止生成过程。
    • 生成的每个token会立即打印在屏幕上,形成即时响应的效果。
  3. t_infre():

    • 此函数是交互式推理循环,允许用户输入文本,然后调用 gen_token 函数来生成回应。
    • 它是一个无限循环,持续等待用户的输入直到程序被手动终止。
  4. if __name__ == '__main__':

    • 这部分代码确保当脚本作为主程序运行时,会执行某些特定的操作或测试。
    • 注释掉的代码可能是之前用于数据预处理、训练或其他实验的部分。
    • 最终调用了 t_infre() 函数来启动交互式推理。

需要注意的是,这里使用的 SamOut 类并没有在给出的代码片段中定义,因此你可能需要确保这个类已经被正确实现并在其他地方导入。此外,为了使代码能够正常工作,你需要确保所有依赖库(如 PyTorch 和 pandas)已经安装,并且所有提及的数据文件和模型权重文件都存在于正确的路径下。

def load_model_and_voc(device="cpu"):voc = pd.read_pickle("total_voc.pkl")net = SamOut(len(voc["voc"]), 1024 + 512, 64, 16)# net = SamOut(len(voc["voc"]), 512, 32, 8)print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum([i.shape[0] for i in net.parameters() if len(i.shape) == 1]))# net.load_state_dict(torch.load("pretrain_768.pth", map_location=device))# net.load_state_dict(torch.load("pretrain_sft_single.pth", map_location=device))net.load_state_dict(torch.load("pretrain_sft_single_1024.pth", map_location=device))# net.load_state_dict(torch.load("pretrain.pth", map_location=device))net.to(device)net.eval()return net, vocdef gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cpu"):print("agent:", end="", flush=True)for _ in range(max_len):prompt_list = []for i in prompt:if i not in voc["voc"]:prompt_list += [voc["voc"].index(ii) for ii in voc["voc0"].get(i)]else:prompt_list.append(voc["voc"].index(i))out, _ = model(torch.Tensor([prompt_list]).to(device).long())out = out[:, -1:]# 重复抑制for token_id in enumerate(prompt_list):out[:, :, token_id] /= rpscore = torch.softmax(out, -1)[0, 0]score, score_index = torch.sort(score,descending=True)score=score.detach().numpy()score_sum = np.cumsum(score)score_index = score_index.detach().numpy()score1=score[score_sum<0.8]if score1.size==0:score=score[:1]else:score=score1score_index=score_index[:score.size]out = score / tempv= out[:min(top_k, score.size)]idx_next = torch.multinomial(torch.Tensor(v), num_samples=1, generator=None)if voc["voc"][score_index[idx_next.item()]] == "<|sos|>":breakprompt += [voc["voc"][score_index[idx_next.item()]]]print(prompt[-1], end="", flush=True)def t_infre():model, voc = load_model_and_voc()while True:text = input("user:")gen_token(voc, model, ["<|user|>"] + list("{}".format(text)) + ["<|agent|>"], 64)print()if __name__ == '__main__':# print(pd.read_pickle("loss916"))# gen_one_voc()# gen_voc()# for i in range(17,18):#     gen_pre_data_align(i, 16)# train()# gen_sft_single_data_align()# train_single()# sft 推理  一本正经的胡说八道已练成t_infre()

http://www.ppmy.cn/news/1556711.html

相关文章

C++ Learning 友元•内部类

友元 在 C 中&#xff0c;友元是一个非常特殊的概念&#xff0c;它使得某些非成员函数或其他类能够访问类的私有成员和保护成员。通常&#xff0c;私有成员和保护成员只能在类的成员函数内部访问&#xff0c;但通过将函数或类声明为友元&#xff0c;能够绕过这个访问限制。 友元…

linux cpu 管理

视频教程&#xff1a;ubuntu cpu 管理_哔哩哔哩_bilibili 概述 平均负载&#xff0c;CPU 使用率&#xff0c;CPU上下文 1 平均负载 #查看命令&#xff1a; rootzyb:~# uptime 18:21:47 up 1:09, 2 users, load average: 0.00, 0.00, 0.00 依次则是过去 1 分钟、5 分钟、1…

在Visual Studio Code (VSCode) 中将终端输出重定向到一个文本文件中

在Visual Studio Code (VSCode) 中将终端输出重定向到一个文本文件中 在Visual Studio Code (VSCode) 中,你可以通过以下步骤将终端输出重定向到一个文本文件中: 1. 方法一 使用重定向符号 在终端中运行命令时,可以使用重定向符号 > 或 >> 将输出保存到文件中。 …

Android运行低版本项目可能遇到的问题

Android运行低版本项目可能遇到的问题 低版本项目总是遇到各种问题的&#xff0c;耐心点 一、gradle-xxx.xxx.xxx.zip一直下载不下来 在gradle-wrapper.properties可以试下 distributionBaseGRADLE_USER_HOME distributionPathwrapper/dists zipStoreBaseGRADLE_USER_HOME …

C/C++圣诞树

系列文章 序号直达链接1C/C爱心代码2C/C跳动的爱心3C/C李峋同款跳动的爱心代码4C/C满屏飘字表白代码5C/C大雪纷飞代码6C/C烟花代码7C/C黑客帝国同款字母雨8C/C樱花树代码9C/C奥特曼代码10C/C精美圣诞树11C/C俄罗斯方块12C/C贪吃蛇13C/C孤单又灿烂的神-鬼怪14C/C闪烁的爱心15C…

代码随想录刷题-数组

文章目录 1.二分查找1.答案2.思路3.扩展题目1.搜索插入位置1.答案2.思路 2.在排序数组中查找元素的第一个和最后一个位置1.答案2.思路 3.x 的平方根1.答案2.思路 4.有效的完全平方数1.答案2.思路 4.总结1.标准二分模板 2.移除元素1.答案2.思路3.扩展题目1.删除有序数组中的重复…

uniapp使用腾讯地图接口的时候提示此key每秒请求量已达到上限或者提示此key每日调用量已达到上限问题解决

要在创建的key上添加配额 点击配额之后进入分配页面&#xff0c;分配完之后刷新uniapp就可以调用成功了。

低比特语言模型 是一种利用较少比特数进行语言建模的技术

Vanilla LLM: 基础的全精度语言模型&#xff0c;通常在较高比特数下运作 Vanilla LLM&#xff0c;或称为“基础的全精度语言模型”&#xff0c;是指使用标准的浮点数&#xff08;通常是16位或32位&#xff09;进行训练和推理的语言模型。这些模型依赖于经典的神经网络结构&…