给rwkv_pytorch配了个流式服务和请求demo

ops/2024/9/23 14:35:30/

项目地址

rwkv_pytorch

服务端

import json
import uuid
import timeimport torch
from src.model import RWKV_RNN
from src.sampler import sample_logits
from src.rwkv_tokenizer import RWKV_TOKENIZER
from flask import Flask, request, jsonify, Responseapp = Flask(__name__)# 初始化模型和分词器
def init_model():# 模型参数配置args = {'MODEL_NAME': 'E:/RWKV_Pytorch/weight/RWKV-x060-World-1B6-v2-20240208-ctx4096','vocab_size': 65536,'device': "cpu",'onnx_opset': '18',}device = args['device']assert device in ['cpu', 'cuda', 'musa', 'npu']if device == "musa":import torch_musaelif device == "npu":import torch_npumodel = RWKV_RNN(args).to(device)tokenizer = RWKV_TOKENIZER("asset/rwkv_vocab_v20230424.txt")return model, tokenizer, devicedef format_messages_to_prompt(messages):formatted_prompt = ""# 定义角色映射到期望的名称role_names = {"system": "System","assistant": "Assistant","user": "User"}# 遍历消息并格式化for message in messages:role = role_names.get(message['role'], 'Unknown')  # 获取角色名称,默认为'Unknown'content = message['content']formatted_prompt += f"{role}: {content}\n\n"  # 添加角色和内容到提示,并添加换行符formatted_prompt += "Assistant: "return formatted_promptdef generate_text_stream(prompt: str, temperature=1.5, top_p=0.1, max_tokens=2048, stop=['\n\nUser']):encoded_input = tokenizer.encode([prompt])token = torch.tensor(encoded_input).long().to(device)state = torch.zeros(1, model.state_size[0], model.state_size[1]).to(device)with torch.no_grad():token_out, state_out = model.forward_parallel(token, state)del tokenout = token_out[:, -1]generated_tokens = ''completion_tokens = 0if_max_token = Truefor step in range(max_tokens):token_sampled = sample_logits(out, temperature, top_p)with torch.no_grad():out, state = model.forward(token_sampled, state)last_token = tokenizer.decode(token_sampled.unsqueeze(1).tolist())[0]generated_tokens += last_tokencompletion_tokens += 1if generated_tokens.endswith(tuple(stop)):if_max_token = Falseresponse = {"object": "chat.completion.chunk","model": "rwkv","choices": [{"delta": "","index": 0,"finish_reason": "stop"}]}yield f"data: {json.dumps(response)}\n\n"else:response = {"object": "chat.completion.chunk","model": "rwkv","choices": [{"delta": {"content": last_token},"index": 0,"finish_reason": None}]}yield f"data: {json.dumps(response)}\n\n"if if_max_token:response = {"object": "chat.completion.chunk","model": "rwkv","choices": [{"delta": "","index": 0,"finish_reason": "length"}]}yield f"data: {json.dumps(response)}\n\n"yield f"data:[DONE]\n\n"def generate_text(prompt, temperature=1.5, top_p=0.1, max_tokens=2048, stop=['\n\nUser']):encoded_input = tokenizer.encode([prompt])token = torch.tensor(encoded_input).long().to(device)state = torch.zeros(1, model.state_size[0], model.state_size[1]).to(device)prompt_tokens = len(encoded_input[0])with torch.no_grad():token_out, state_out = model.forward_parallel(token, state)del tokenout = token_out[:, -1]completion_tokens = 0if_max_token = Truegenerated_tokens = ''for step in range(max_tokens):token_sampled = sample_logits(out, temperature, top_p)with torch.no_grad():out, state = model.forward(token_sampled, state)# 判断是否达到停止条件last_token = tokenizer.decode(token_sampled.unsqueeze(1).tolist())[0]completion_tokens += 1print(last_token, end='')generated_tokens += last_tokenfor stop_token in stop:if generated_tokens.endswith(stop_token):generated_tokens = generated_tokens.replace(stop_token, "")  # 替换掉终止tokenif_max_token = Falsebreak# 如果末尾含有 stop 列表中的字符串,则停止生成if not if_max_token:breaktotal_tokens = prompt_tokens + completion_tokensusage = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}return generated_tokens, if_max_token, usage@app.route('/events', methods=['POST'])
def sse_request():try:# 从查询字符串中获取参数data = request.jsonmessages = data.get('messages', [])stream = data.get('stream', True) == Truetemperature = float(data.get('temperature', 0.5))top_p = float(data.get('top_p', 0.9))max_tokens = int(data.get('max_tokens', 100))stop = data.get('stop', ['\n\nUser'])prompt = format_messages_to_prompt(messages)if stream:return Response(generate_text_stream(prompt=prompt, temperature=temperature, top_p=top_p,max_tokens=max_tokens, stop=stop),content_type='text/event-stream')else:completion, if_max_token, usage = generate_text(prompt, temperature=temperature, top_p=top_p,max_tokens=max_tokens, stop=stop)finish_reason = "stop" if if_max_token else "length"unique_id = str(uuid.uuid4())current_timestamp = int(time.time())response = {"id": unique_id,"object": "chat.completion","created": current_timestamp,"choices": [{"index": 0,"message": {"role": "assistant","content": completion,},"finish_reason": finish_reason}],"usage": usage}return json.dumps(response)except Exception as e:return json.dumps({"error": str(e)}), 500if __name__ == '__main__':model, tokenizer, device = init_model()app.run(debug=False)

解释

  • 首先引入了需要的库,包括json用于处理JSON数据,uuid用于生成唯一标识符,time用于获取当前时间戳,torch用于构建和运行模型,Flask用于构建API。
  • 定义了一个名为app的Flask应用。
  • init_model函数用于初始化模型和分词器。其中,模型参数通过字典args指定。
  • format_messages_to_prompt函数用于将消息格式化为提示字符串,以便于模型生成回复。遍历消息列表,获取每个消息的角色和内容,并添加到提示字符串中。
  • generate_text_stream函数用于以流的形式生成文本。首先将输入的提示字符串编码为张量,然后利用模型生成回复,并利用yield关键字将回复以SSE(服务器发送事件)的形式返回。
  • generate_text函数用于一次性生成完整的文本回复。与generate_text_stream函数类似,不同的是返回的是完整的回复字符串。
  • sse_request函数是Flask应用的主要逻辑,用于处理POST请求。从请求的JSON数据中获取参数,并根据参数的设置调用相应的生成函数。如果参数中设置了stream=True,则返回流式生成的回复;否则返回一次性生成的回复。
  • __main__函数中初始化模型和分词器,然后运行Flask应用。

客户端

import jsonimport requests
from requests import RequestException# 配置服务器URL
url = 'http://localhost:5000/events'  # 假设您的Flask应用运行在本地端口5000上# POST请求示例
def post_request_stream():# 构造请求数据data = {'messages': [{'role': 'system', 'content': '你好!'},{'role': 'user', 'content': '你能告诉我今天的天气吗?'}],'temperature': 0.5,'top_p': 0.9,'max_tokens': 100,'stop': ['\n\nUser'],'stream':True}# 使用 requests 库来连接服务器,并传递参数try:with  requests.post(url, json=data, stream=True) as r:for line in r.iter_lines():if line:# 当服务器发送消息时,解码并打印出来decoded_line = line.decode('utf-8')print(json.loads(decoded_line[5:])["choices"][0]["delta"], end="")except RequestException as e:print(f'An error occurred: {e}')def post_request():# 构造请求数据data = {'messages': [{'role': 'system', 'content': '你好!'},{'role': 'user', 'content': '你能告诉我今天的天气吗?'}],'temperature': 0.5,'top_p': 0.9,'max_tokens': 100,'stop': ['\n\nUser'],'stream':False}# 使用 requests 库来连接服务器,并传递参数try:with  requests.post(url, json=data, stream=True) as r:for line in r.iter_lines():if line:# 当服务器发送消息时,解码并打印出来decoded_line = line.decode('utf-8')res=json.loads(decoded_line)print(res)except RequestException as e:print(f'An error occurred: {e}')if __name__ == '__main__':# post_request()post_request_stream()

解释

这段代码是一个用于向服务器发送POST请求的示例代码。

首先,我们需要导入一些必要的库。json库用于处理JSON数据,requests库用于发送HTTP请求,RequestException用于处理请求异常。

接下来,我们需要配置服务器的URL。在这个示例中,假设服务器运行在本地端口5000上。

代码中定义了两个函数post_request_streampost_request,分别用于发送带有流式响应和非流式响应的POST请求。

post_request_stream函数构造了一个包含各种参数的数据字典,并使用requests.post方法发送POST请求。在请求的参数中,stream参数被设置为True,表示我们希望获得一个流式的响应。接着,我们使用r.iter_lines()方法来迭代获取服务器发送的消息。每收到一行消息,我们将其解码并打印出来。

post_request函数的代码结构与post_request_stream函数相似,不同之处在于stream参数被设置为False,表示我们希望获得一个非流式的响应。

最后,在程序的主体部分,我们调用post_request_stream函数来发送流式的POST请求,并注释掉了post_request函数的调用。


http://www.ppmy.cn/ops/11821.html

相关文章

学习笔记-数据结构-线性表(2024-04-16)

设计一个算法判断单链表中元素是否是递增的。 设计思想:双指针操作 变量说明: head表示链表头指针 p和q表示两个用来遍历链表的指针节点,且q始终在p之后 bool IsIncrease(LinkList *head) {// 代码优先判空,若为空链表&#xff…

linux的一些实用操作

快捷键 强制停止 ctrlc强制停止或退出命令的输入 退出登出 ctrld强制退出用户登录或退出某些程序的专属页面(如py) ps:不能退出vi/vim 历史命令搜索 history可以查看历史命令,用来复制粘贴 在使用history之后,…

Midjourney 中文文档

快速使用 学习如何在Discord上使用Midjourney Bot从简单的文本提示中创建自定义图像。 行为准则 不要表现出不良行为。不要使用我们的工具制作可能引起煽动,不安或引起争议的图像。这包括血腥和成人内容。尊重其他人和团队。 1:加入Discord 访问Midj…

C#开发UdpClient无法在局域网中发送UDP广播包,但能接收的解决办法

# 记得开发好的软件原来可以使用的,今天突然不正常了,还以为哪里修改过了。 在网上看到了一个网友的文章:https://www.cnblogs.com/kissazi2/archive/2012/12/07/2806533.html 我虽然没有安装虚拟机,没有VMware的虚拟网卡&#…

西瓜书学习——对数几率回归

对数几率回归(Logistic Regression)是一种广泛应用于分类问题的统计方法,特别是用于二分类问题。尽管它的名字中包含“回归”,但它实际上是一种分类算法,用于估计一个样本属于某个类别的概率。 对数几率回归的核心是使…

C++ day1

const char *p; 可以改变p的值,不可以改变p指向的字符的值。 const (char *) p; 语法错误 char *const p; 可以改变p指向的字符值 不可以改变p的值 const char* const p; 都不可以改变 char const *p; 可以改变p的值 不可以改变 p指向的字符的值 …

CUDA 以及MPI并行矩阵乘连接服务器运算vscode配置

一、CUDA Vscode配置 (一)扩展安装 本地安装 服务器端安装 (二) CUDA 配置 .vscode c_cpp_properties.json {"configurations": [{"name": "Linux","includePath": ["${workspa…

共商共建共享是“一带一路”建设的重要指导原则。其中,“共享”是指实现(),充分调动各方面积极性。

共商共建共享是“一带一路”建设的重要指导原则。其中,“共享”是指实现(),充分调动各方面积极性。 请点击查看答案 A.均衡发展B.优势互补 C.平等互惠D.互利共嬴 “一带一路”是繁荣之路。推进“一带一路”建设要深入开展产业合作,推动各国…