使用自定义大模型来部署Wren AI(开源的文本生成SQL方案)

ops/2025/2/22 17:26:31/

使用自定义大模型来部署Wren AI(开源的文本生成SQL方案)

关于

  • 首次发表日期:2024-07-15
  • Wren AI官方文档: https://docs.getwren.ai/overview/introduction
  • Wren AI Github仓库: https://github.com/Canner/WrenAI

关于Wren AI

Wren AI 是一个开源的文本生成SQL解决方案。

前提准备

由于之后会使用docker来启动服务,所以首先确保docker已经安装好了,并且网络没问题。

先克隆仓库:

git clone https://github.com/Canner/WrenAI.git

关于在Wren AI中使用自定义大模型和Embedding模型

Wren AI目前是支持自定义LLM和Embedding模型的,其官方文档 https://docs.getwren.ai/installation/custom_llm 中有提及,需要创建自己的provider类。

其中Wren AI本身已经支持和OPEN AI兼容的大模型了;但是自定义的Embedding模型方面,可能会报错,具体来说是wren-ai-service/src/providers/embedder/openai.py中的以下代码

if self.dimensions is not None:response = await self.client.embeddings.create(model=self.model, dimensions=self.dimensions, input=text_to_embed)
else:response = await self.client.embeddings.create(model=self.model, input=text_to_embed)

其中if self.dimensions is not None这个条件分支是会报错的(默认会运行这个分支),所以我的临时解决方案是注释掉它。

具体而言是在wren-ai-service/src/providers/embedder文件夹中创建一个openai_like.py文件,表示定义一个和open ai类似的embedding provider,取个名字叫做openai_like_embedder,具体的完整代码见本文附录。

配置docker环境变量等并启动服务

首先,进入docker文件夹,拷贝.env.example并重命名为.env.local

然后拷贝.env.ai.example并重命名为.env.ai,修改其中的LLM和Embedding的配置,相关部分如下:

LLM_PROVIDER=openai_llm
LLM_OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxx
LLM_OPENAI_API_BASE=http://api.siliconflow.cn/v1
GENERATION_MODEL=meta-llama/Meta-Llama-3-70B
# GENERATION_MODEL_KWARGS={"temperature": 0, "n": 1, "max_tokens": 32768, "response_format": {"type": "json_object"}}EMBEDDER_PROVIDER=openai_like_embedder
EMBEDDING_MODEL=bge-m3
EMBEDDING_MODEL_DIMENSION=1024
EMBEDDER_OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxx
EMBEDDER_OPENAI_API_BASE=https://xxxxxxxxxxxxxxxx/v1

由于我们创建了一个自定义的embedding provider,需要将文件映射到docker容器中,具体可以通过配置docker-compose.yaml中的wren-ai-service,添加volumes属性:

wren-ai-service:image: ghcr.io/canner/wren-ai-service:${WREN_AI_SERVICE_VERSION}volumes:- /root/WrenAI/wren-ai-service/src:/src

最后,启动服务:

docker-compose -f docker-compose.yaml -f docker-compose.llm.yaml --env-file .env.local --env-file .env.ai up -d

或者停止服务:

docker-compose -f docker-compose.yaml -f docker-compose.llm.yaml --env-file .env.local --env-file .env.ai down

附录

openai_like.py文件(提供自定义embedding服务):

import logging
import os
from typing import Any, Dict, List, Optional, Tupleimport backoff
import openai
from haystack import Document, component
from haystack.components.embedders import OpenAIDocumentEmbedder, OpenAITextEmbedder
from haystack.utils import Secret
from openai import AsyncOpenAI, OpenAI
from tqdm import tqdmfrom src.core.provider import EmbedderProvider
from src.providers.loader import providerimport logging
import syslogging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))logger = logging.getLogger("wren-ai-service")EMBEDDER_OPENAI_API_BASE = "https://api.openai.com/v1"
EMBEDDING_MODEL = "text-embedding-3-large"
EMBEDDING_MODEL_DIMENSION = 3072@component
class AsyncTextEmbedder(OpenAITextEmbedder):def __init__(self,api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),model: str = "text-embedding-ada-002",dimensions: Optional[int] = None,api_base_url: Optional[str] = None,organization: Optional[str] = None,prefix: str = "",suffix: str = "",):super(AsyncTextEmbedder, self).__init__(api_key,model,dimensions,api_base_url,organization,prefix,suffix,)self.client = AsyncOpenAI(api_key=api_key.resolve_value(),organization=organization,base_url=api_base_url,)@component.output_types(embedding=List[float], meta=Dict[str, Any])@backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=3)async def run(self, text: str):if not isinstance(text, str):raise TypeError("OpenAITextEmbedder expects a string as an input.""In case you want to embed a list of Documents, please use the OpenAIDocumentEmbedder.")logger.debug(f"Running Async OpenAI text embedder with text: {text}")text_to_embed = self.prefix + text + self.suffix# copied from OpenAI embedding_utils (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)# replace newlines, which can negatively affect performance.text_to_embed = text_to_embed.replace("
", " ")# if self.dimensions is not None:#     response = await self.client.embeddings.create(#         model=self.model, dimensions=self.dimensions, input=text_to_embed#     )# else:response = await self.client.embeddings.create(model=self.model, input=text_to_embed)meta = {"model": response.model, "usage": dict(response.usage)}return {"embedding": response.data[0].embedding, "meta": meta}@component
class AsyncDocumentEmbedder(OpenAIDocumentEmbedder):def __init__(self,api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),model: str = "text-embedding-ada-002",dimensions: Optional[int] = None,api_base_url: Optional[str] = None,organization: Optional[str] = None,prefix: str = "",suffix: str = "",batch_size: int = 32,progress_bar: bool = True,meta_fields_to_embed: Optional[List[str]] = None,embedding_separator: str = "
",):super(AsyncDocumentEmbedder, self).__init__(api_key,model,dimensions,api_base_url,organization,prefix,suffix,batch_size,progress_bar,meta_fields_to_embed,embedding_separator,)self.client = AsyncOpenAI(api_key=api_key.resolve_value(),organization=organization,base_url=api_base_url,)async def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:all_embeddings = []meta: Dict[str, Any] = {}for i in tqdm(range(0, len(texts_to_embed), batch_size),disable=not self.progress_bar,desc="Calculating embeddings",):batch = texts_to_embed[i : i + batch_size]# if self.dimensions is not None:#     response = await self.client.embeddings.create(#         model=self.model, dimensions=self.dimensions, input=batch#     )# else:response = await self.client.embeddings.create(model=self.model, input=batch)embeddings = [el.embedding for el in response.data]all_embeddings.extend(embeddings)if "model" not in meta:meta["model"] = response.modelif "usage" not in meta:meta["usage"] = dict(response.usage)else:meta["usage"]["prompt_tokens"] += response.usage.prompt_tokensmeta["usage"]["total_tokens"] += response.usage.total_tokensreturn all_embeddings, meta@component.output_types(documents=List[Document], meta=Dict[str, Any])@backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=3)async def run(self, documents: List[Document]):if (not isinstance(documents, list)or documentsand not isinstance(documents[0], Document)):raise TypeError("OpenAIDocumentEmbedder expects a list of Documents as input.""In case you want to embed a string, please use the OpenAITextEmbedder.")logger.debug(f"Running Async OpenAI document embedder with documents: {documents}")texts_to_embed = self._prepare_texts_to_embed(documents=documents)embeddings, meta = await self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)for doc, emb in zip(documents, embeddings):doc.embedding = embreturn {"documents": documents, "meta": meta}@provider("openai_like_embedder")
class OpenAIEmbedderProvider(EmbedderProvider):def __init__(self,api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),api_base: str = os.getenv("EMBEDDER_OPENAI_API_BASE")or EMBEDDER_OPENAI_API_BASE,embedding_model: str = os.getenv("EMBEDDING_MODEL") or EMBEDDING_MODEL,embedding_model_dim: int = (int(os.getenv("EMBEDDING_MODEL_DIMENSION"))if os.getenv("EMBEDDING_MODEL_DIMENSION")else 0)or EMBEDDING_MODEL_DIMENSION,):def _verify_api_key(api_key: str, api_base: str) -> None:"""this is a temporary solution to verify that the required environment variables are set"""OpenAI(api_key=api_key, base_url=api_base).models.list()logger.info(f"Initializing OpenAIEmbedder provider with API base: {api_base}")# TODO: currently only OpenAI api key can be verifiedif api_base == EMBEDDER_OPENAI_API_BASE:_verify_api_key(api_key.resolve_value(), api_base)logger.info(f"Using OpenAI Embedding Model: {embedding_model}")else:logger.info(f"Using OpenAI API-compatible Embedding Model: {embedding_model}")self._api_key = api_keyself._api_base = api_baseself._embedding_model = embedding_modelself._embedding_model_dim = embedding_model_dimdef get_text_embedder(self):return AsyncTextEmbedder(api_key=self._api_key,api_base_url=self._api_base,model=self._embedding_model,dimensions=self._embedding_model_dim,)def get_document_embedder(self):return AsyncDocumentEmbedder(api_key=self._api_key,api_base_url=self._api_base,model=self._embedding_model,dimensions=self._embedding_model_dim,)

http://www.ppmy.cn/ops/157006.html

相关文章

网络安全-防御 第一次作业(由于防火墙只成功启动了一次未补截图)

防火墙安全策略课堂实验报告 一、拓扑 本实验拓扑包含预启动设备、DMZ区域(含OA Server和Web Server)、防火墙(FW1)、Trust区域(含办公区PC和生产区PC)等。具体IP地址及连接关系如给定拓扑图所示&#xf…

防洪子堤,筑牢生命防线|鼎跃安全

近年来,随着极端天气事件的增多和城市快速扩张,洪涝灾害频发已成为各级政府和社会各界普遍关注的问题。在抗洪抢险过程中,如何快速构筑防洪屏障、分流洪水、保护重点区域成为抢险救灾的关键。防洪子堤作为一种新型、灵活且易于部署的临时防洪…

如何利用Java爬虫获取商品销量详情实战指南

在当今数字化时代,电商平台的商品销量数据对于市场分析、竞品研究和商业决策具有极高的价值。通过Java爬虫技术,我们可以高效地获取这些数据,为商业分析提供支持。本文将详细介绍如何利用Java编写爬虫程序,获取商品的销量详情&…

Windows Docker笔记-简介摘录

Docker是一个开源的容器化平台,可以帮助开发人员将应用程序与其依赖项打包在一个独立的容器中,然后在任何安装的Docker的环境中快速、可靠地运行。 几个基本概念和优势: 1. 容器 容器是一个轻量级、独立的运行环境,包含了应用程…

基于Flask的全国海底捞门店数据可视化分析系统的设计与实现

【FLask】基于Flask的全国海底捞门店数据可视化分析系统的设计与实现(完整系统源码开发笔记详细部署教程)✅ 目录 一、项目简介二、项目界面展示三、项目视频展示 一、项目简介 该系统系统采用Python语言结合Flask框架开发,利用Pandas、NumP…

告别手动操作!用Ansible user模块高效管理 Linux账户

在企业运维环境中,服务器的用户管理是一项基础但非常重要的任务。比如,当有新员工加入时,我们需要在多台服务器上为他们创建账户并分配合适的权限。而当员工离职或岗位发生变化时,我们也需要迅速禁用或删除他们的账户,…

具身智能学习规划

具身智能(Embodied Intelligence)强调智能体通过身体与环境的动态交互实现学习和决策,是人工智能、机器人学、认知科学和神经科学交叉的前沿领域。其核心在于打破传统AI的“离身认知”,将智能与物理实体、感知-运动系统紧密结合。…

【Windows】PowerShell 缓存区大小调节

PowerShell 缓存区大小调节 方式1 打开powershell 窗口属性调节方式2,修改 PowerShell 配置文件 方式1 打开powershell 窗口属性调节 打开 CMD(按 Win R,输入 cmd)。右键标题栏 → 选择 属性(Properties)…