RAG机器人 结构体
-
文本向量化: 使用 BGE 模型将文档和查询编码为向量。
(BGE 是专为检索任务优化的开源 Embedding 模型,除了本文API调用,也可以通过Hugging Face 本地部署BGE 开源模型) -
向量检索: 从数据库中找到与查询相关的文档片段。
-
答案生成: 结合检索结果和用户输入,调用文心模型生成最终回答。
python">class RAG_Bot:def __init__(self, vector_db, llm_api, n_results=2):self.vector_db = vector_dbself.llm_api = llm_apiself.n_results = n_resultsdef chat(self, user_query):# 1. 检索search_results = self.vector_db.search(user_query, self.n_results)# 2. 构建 Promptprompt = build_prompt(prompt_template, context=search_results['documents'][0], query=user_query)# 3. 调用 LLMresponse = self.llm_api(prompt)return response
####### 创建一个RAG机器人
bot = RAG_Bot(vector_db,llm_api=get_completion
)user_query = "llama 2有多少参数?"response = bot.chat(user_query)print(response)#####
llama 2有7B, 13B和70B参数。
MyVectorDBConnector:
自定义向量数据库,存储文档向量。
embedding_fn=get_embeddings_bge: 使用 BGE 模型生成向量。
add_documents(paragraphs): 向数据库中添加文档(已提前定义 paragraphs)。
RAG_Bot:
检索增强生成机器人,结合向量搜索与大模型生成。
chat(user_query): 执行“检索→生成”流程:
将用户查询向量化。
从数据库检索相关文档。
将检索结果作为上下文,调用文心模型生成回答。
使用国产模型
python">import json
import requests
import os# 通过鉴权接口获取 access tokendef get_access_token():"""使用 AK,SK 生成鉴权签名(Access Token):return: access_token,或是None(如果错误)"""url = "https://aip.baidubce.com/oauth/2.0/token"params = {"grant_type": "client_credentials","client_id": os.getenv('ERNIE_CLIENT_ID'),"client_secret": os.getenv('ERNIE_CLIENT_SECRET')}return str(requests.post(url, params=params).json().get("access_token"))# 调用文心千帆 调用 BGE Embedding 接口def get_embeddings_bge(prompts):url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en?access_token=" + get_access_token()payload = json.dumps({"input": prompts})headers = {'Content-Type': 'application/json'}response = requests.request("POST", url, headers=headers, data=payload).json()data = response["data"]return [x["embedding"] for x in data]# 调用文心4.0对话接口
def get_completion_ernie(prompt):url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + get_access_token()payload = json.dumps({"messages": [{"role": "user","content": prompt}]})headers = {'Content-Type': 'application/json'}response = requests.request("POST", url, headers=headers, data=payload).json()return response["result"]# 创建一个向量数据库对象
new_vector_db = MyVectorDBConnector("demo_ernie",embedding_fn=get_embeddings_bge
)
# 向向量数据库中添加文档
new_vector_db.add_documents(paragraphs)# 创建一个RAG机器人
new_bot = RAG_Bot(new_vector_db,llm_api=get_completion_ernie
)user_query = "how many parameters does llama 2 have?"response = new_bot.chat(user_query)print(response)
拓展实践
1. 优化 Access Token 管理
- 缓存 Token:减少鉴权接口调用次数,仅在 Token 过期时刷新。
- 示例代码:
python">from datetime import datetime, timedeltaclass TokenManager:_token = None_expires_at = None@classmethoddef get_token(cls):if cls._token is None or datetime.now() > cls._expires_at:cls._refresh_token()return cls._token@classmethoddef _refresh_token(cls):url = "https://aip.baidubce.com/oauth/2.0/token"params = {"grant_type": "client_credentials","client_id": os.getenv('ERNIE_CLIENT_ID'),"client_secret": os.getenv('ERNIE_CLIENT_SECRET')}response = requests.post(url, params=params)response.raise_for_status()data = response.json()cls._token = data["access_token"]# 默认 Token 有效期为 30 天,但建议按实际返回的 expires_in 设置cls._expires_at = datetime.now() + timedelta(seconds=data.get("expires_in", 2592000) - 300) # 提前 5 分钟刷新
2. 增强错误处理与重试
- 重试网络请求:使用
tenacity
库自动重试失败请求。 - 捕获异常:明确处理常见错误(如网络超时、无效响应)。
- 示例代码:
python">from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type import requests.exceptions as req_exceptions@retry(stop=stop_after_attempt(3),wait=wait_exponential(multiplier=1, min=2, max=10),retry=retry_if_exception_type((req_exceptions.Timeout, req_exceptions.ConnectionError)) ) def safe_api_request(url, headers, payload):try:response = requests.post(url, headers=headers, data=payload, timeout=10)response.raise_for_status()return response.json()except req_exceptions.HTTPError as e:if response.status_code == 401:TokenManager._refresh_token() # Token 可能过期,强制刷新raiseraise ValueError(f"API 错误: {e.response.text}")
3. 验证环境变量
- 启动时检查:确保关键配置已正确设置。
- 示例代码:
python">def validate_env_vars():required_vars = ['ERNIE_CLIENT_ID', 'ERNIE_CLIENT_SECRET']missing_vars = [var for var in required_vars if not os.getenv(var)]if missing_vars:raise EnvironmentError(f"缺少环境变量: {', '.join(missing_vars)}")# 在程序初始化时调用 validate_env_vars()
4. 优化向量数据库交互
- 批量插入文档:减少 API 调用次数。
- 分块策略:根据 Embedding 模型的最大输入长度分块文本。
- 示例优化(假设使用
MyVectorDBConnector
):python">class MyVectorDBConnector:def __init__(self, name, embedding_fn, chunk_size=512):self.embedding_fn = embedding_fnself.chunk_size = chunk_size # 根据模型支持的最大长度设置def add_documents(self, documents):chunks = self._chunk_documents(documents)embeddings = self.embedding_fn(chunks)# 批量存储到向量数据库def _chunk_documents(self, documents):# 实现基于句子或固定长度的分块逻辑pass
优化后的代码示例
整合上述改进后的核心逻辑:
python">import os
import json
import logging
from datetime import datetime, timedelta
import requests
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import requests.exceptions as req_exceptions# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)# 环境变量校验
def validate_env_vars():required_vars = ['ERNIE_CLIENT_ID', 'ERNIE_CLIENT_SECRET']missing_vars = [var for var in required_vars if not os.getenv(var)]if missing_vars:raise EnvironmentError(f"Missing env vars: {', '.join(missing_vars)}")
validate_env_vars()# Token 管理
class TokenManager:_token = None_expires_at = None@classmethoddef get_token(cls):if cls._token is None or datetime.now() > cls._expires_at:cls._refresh_token()return cls._token@classmethoddef _refresh_token(cls):logger.info("Refreshing access token...")url = "https://aip.baidubce.com/oauth/2.0/token"params = {"grant_type": "client_credentials","client_id": os.getenv('ERNIE_CLIENT_ID'),"client_secret": os.getenv('ERNIE_CLIENT_SECRET')}response = requests.post(url, params=params)response.raise_for_status()data = response.json()cls._token = data["access_token"]cls._expires_at = datetime.now() + timedelta(seconds=data.get("expires_in", 2592000) - 300)# 安全 API 请求
@retry(stop=stop_after_attempt(3),wait=wait_exponential(multiplier=1, min=2, max=10),retry=retry_if_exception_type((req_exceptions.Timeout, req_exceptions.ConnectionError))
)
def safe_api_request(url, headers, payload):try:response = requests.post(url, headers=headers, data=payload, timeout=10)response.raise_for_status()return response.json()except req_exceptions.HTTPError as e:if response.status_code == 401:TokenManager._refresh_token()raiselogger.error(f"API Error: {e.response.text}")raise# 公共 API 调用封装
def call_ernie_api(endpoint, payload):base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"url = f"{base_url}/{endpoint}?access_token={TokenManager.get_token()}"headers = {'Content-Type': 'application/json'}return safe_api_request(url, headers, json.dumps(payload))# Embedding 接口
def get_embeddings_bge(prompts):logger.info(f"Generating embeddings for {len(prompts)} prompts")response = call_ernie_api("embeddings/bge_large_en", {"input": prompts})return [x["embedding"] for x in response["data"]]# 文心 4.0 对话接口
def get_completion_ernie(prompt):logger.info(f"Generating completion for prompt: {prompt[:50]}...")response = call_ernie_api("chat/completions_pro", {"messages": [{"role": "user", "content": prompt}]})return response["result"]