【RAG】基于向量检索的 RAG (BGE示例)

ops/2025/3/10 22:23:15/

RAG机器人 结构体

  • 文本向量化: 使用 BGE 模型将文档和查询编码为向量。
    (BGE 是专为检索任务优化的开源 Embedding 模型,除了本文API调用,也可以通过Hugging Face 本地部署BGE 开源模型)

  • 向量检索: 从数据库中找到与查询相关的文档片段。

  • 答案生成: 结合检索结果和用户输入,调用文心模型生成最终回答。

python">class RAG_Bot:def __init__(self, vector_db, llm_api, n_results=2):self.vector_db = vector_dbself.llm_api = llm_apiself.n_results = n_resultsdef chat(self, user_query):# 1. 检索search_results = self.vector_db.search(user_query, self.n_results)# 2. 构建 Promptprompt = build_prompt(prompt_template, context=search_results['documents'][0], query=user_query)# 3. 调用 LLMresponse = self.llm_api(prompt)return response
####### 创建一个RAG机器人
bot = RAG_Bot(vector_db,llm_api=get_completion
)user_query = "llama 2有多少参数?"response = bot.chat(user_query)print(response)#####
llama 2有7B, 13B和70B参数。

MyVectorDBConnector:

自定义向量数据库,存储文档向量。
embedding_fn=get_embeddings_bge: 使用 BGE 模型生成向量。
add_documents(paragraphs): 向数据库中添加文档(已提前定义 paragraphs)。

RAG_Bot:

检索增强生成机器人,结合向量搜索与大模型生成。
chat(user_query): 执行“检索→生成”流程:
将用户查询向量化。
从数据库检索相关文档。
将检索结果作为上下文,调用文心模型生成回答。

使用国产模型

python">import json
import requests
import os# 通过鉴权接口获取 access tokendef get_access_token():"""使用 AK,SK 生成鉴权签名(Access Token):return: access_token,或是None(如果错误)"""url = "https://aip.baidubce.com/oauth/2.0/token"params = {"grant_type": "client_credentials","client_id": os.getenv('ERNIE_CLIENT_ID'),"client_secret": os.getenv('ERNIE_CLIENT_SECRET')}return str(requests.post(url, params=params).json().get("access_token"))# 调用文心千帆 调用 BGE Embedding 接口def get_embeddings_bge(prompts):url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en?access_token=" + get_access_token()payload = json.dumps({"input": prompts})headers = {'Content-Type': 'application/json'}response = requests.request("POST", url, headers=headers, data=payload).json()data = response["data"]return [x["embedding"] for x in data]# 调用文心4.0对话接口
def get_completion_ernie(prompt):url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + get_access_token()payload = json.dumps({"messages": [{"role": "user","content": prompt}]})headers = {'Content-Type': 'application/json'}response = requests.request("POST", url, headers=headers, data=payload).json()return response["result"]# 创建一个向量数据库对象
new_vector_db = MyVectorDBConnector("demo_ernie",embedding_fn=get_embeddings_bge
)
# 向向量数据库中添加文档
new_vector_db.add_documents(paragraphs)# 创建一个RAG机器人
new_bot = RAG_Bot(new_vector_db,llm_api=get_completion_ernie
)user_query = "how many parameters does llama 2 have?"response = new_bot.chat(user_query)print(response)

拓展实践

1. 优化 Access Token 管理
  • 缓存 Token:减少鉴权接口调用次数,仅在 Token 过期时刷新。
  • 示例代码
    python">from datetime import datetime, timedeltaclass TokenManager:_token = None_expires_at = None@classmethoddef get_token(cls):if cls._token is None or datetime.now() > cls._expires_at:cls._refresh_token()return cls._token@classmethoddef _refresh_token(cls):url = "https://aip.baidubce.com/oauth/2.0/token"params = {"grant_type": "client_credentials","client_id": os.getenv('ERNIE_CLIENT_ID'),"client_secret": os.getenv('ERNIE_CLIENT_SECRET')}response = requests.post(url, params=params)response.raise_for_status()data = response.json()cls._token = data["access_token"]# 默认 Token 有效期为 30 天,但建议按实际返回的 expires_in 设置cls._expires_at = datetime.now() + timedelta(seconds=data.get("expires_in", 2592000) - 300)  # 提前 5 分钟刷新
    
2. 增强错误处理与重试
  • 重试网络请求:使用 tenacity 库自动重试失败请求。
  • 捕获异常:明确处理常见错误(如网络超时、无效响应)。
  • 示例代码
    python">from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
    import requests.exceptions as req_exceptions@retry(stop=stop_after_attempt(3),wait=wait_exponential(multiplier=1, min=2, max=10),retry=retry_if_exception_type((req_exceptions.Timeout, req_exceptions.ConnectionError))
    )
    def safe_api_request(url, headers, payload):try:response = requests.post(url, headers=headers, data=payload, timeout=10)response.raise_for_status()return response.json()except req_exceptions.HTTPError as e:if response.status_code == 401:TokenManager._refresh_token()  # Token 可能过期,强制刷新raiseraise ValueError(f"API 错误: {e.response.text}")
    
3. 验证环境变量
  • 启动时检查:确保关键配置已正确设置。
  • 示例代码
    python">def validate_env_vars():required_vars = ['ERNIE_CLIENT_ID', 'ERNIE_CLIENT_SECRET']missing_vars = [var for var in required_vars if not os.getenv(var)]if missing_vars:raise EnvironmentError(f"缺少环境变量: {', '.join(missing_vars)}")# 在程序初始化时调用
    validate_env_vars()
    
4. 优化向量数据库交互
  • 批量插入文档:减少 API 调用次数。
  • 分块策略:根据 Embedding 模型的最大输入长度分块文本。
  • 示例优化(假设使用 MyVectorDBConnector):
    python">class MyVectorDBConnector:def __init__(self, name, embedding_fn, chunk_size=512):self.embedding_fn = embedding_fnself.chunk_size = chunk_size  # 根据模型支持的最大长度设置def add_documents(self, documents):chunks = self._chunk_documents(documents)embeddings = self.embedding_fn(chunks)# 批量存储到向量数据库def _chunk_documents(self, documents):# 实现基于句子或固定长度的分块逻辑pass
    

优化后的代码示例

整合上述改进后的核心逻辑:

python">import os
import json
import logging
from datetime import datetime, timedelta
import requests
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import requests.exceptions as req_exceptions# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)# 环境变量校验
def validate_env_vars():required_vars = ['ERNIE_CLIENT_ID', 'ERNIE_CLIENT_SECRET']missing_vars = [var for var in required_vars if not os.getenv(var)]if missing_vars:raise EnvironmentError(f"Missing env vars: {', '.join(missing_vars)}")
validate_env_vars()# Token 管理
class TokenManager:_token = None_expires_at = None@classmethoddef get_token(cls):if cls._token is None or datetime.now() > cls._expires_at:cls._refresh_token()return cls._token@classmethoddef _refresh_token(cls):logger.info("Refreshing access token...")url = "https://aip.baidubce.com/oauth/2.0/token"params = {"grant_type": "client_credentials","client_id": os.getenv('ERNIE_CLIENT_ID'),"client_secret": os.getenv('ERNIE_CLIENT_SECRET')}response = requests.post(url, params=params)response.raise_for_status()data = response.json()cls._token = data["access_token"]cls._expires_at = datetime.now() + timedelta(seconds=data.get("expires_in", 2592000) - 300)# 安全 API 请求
@retry(stop=stop_after_attempt(3),wait=wait_exponential(multiplier=1, min=2, max=10),retry=retry_if_exception_type((req_exceptions.Timeout, req_exceptions.ConnectionError))
)
def safe_api_request(url, headers, payload):try:response = requests.post(url, headers=headers, data=payload, timeout=10)response.raise_for_status()return response.json()except req_exceptions.HTTPError as e:if response.status_code == 401:TokenManager._refresh_token()raiselogger.error(f"API Error: {e.response.text}")raise# 公共 API 调用封装
def call_ernie_api(endpoint, payload):base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"url = f"{base_url}/{endpoint}?access_token={TokenManager.get_token()}"headers = {'Content-Type': 'application/json'}return safe_api_request(url, headers, json.dumps(payload))# Embedding 接口
def get_embeddings_bge(prompts):logger.info(f"Generating embeddings for {len(prompts)} prompts")response = call_ernie_api("embeddings/bge_large_en", {"input": prompts})return [x["embedding"] for x in response["data"]]# 文心 4.0 对话接口
def get_completion_ernie(prompt):logger.info(f"Generating completion for prompt: {prompt[:50]}...")response = call_ernie_api("chat/completions_pro", {"messages": [{"role": "user", "content": prompt}]})return response["result"]

http://www.ppmy.cn/ops/164772.html

相关文章

HTML 插入图片(简单易懂较详细)

在 HTML 中&#xff0c;插入图片是通过 <img> 标签实现的。<img> 标签是一个空标签&#xff0c;意味着它不需要闭合标签。以下是插入图片的基本语法和常用属性的详细讲解。 一、基本语法 <img src"图片路径" alt"替代文本">src&#x…

使用Wireshark截取并解密摄像头画面

在物联网&#xff08;IoT&#xff09;设备普及的今天&#xff0c;安全摄像头等智能设备在追求便捷的同时&#xff0c;往往忽视了数据传输过程中的加密保护。很多摄像头默认通过 HTTP 协议传输数据&#xff0c;而非加密的 HTTPS&#xff0c;从而给潜在攻击者留下了可乘之机。本文…

Calico-BGP FullMesh模式与RR模式 Day04

1. BGP协议简单介绍 BGP是什么&#xff1f;BGP是如何工作的&#xff1f; - 华为 Configure BGP peering | Calico Documentation 1.1 什么是BGP 边界网关协议&#xff08;BGP&#xff09;是一种用于在网络中的路由器之间交换路由信息的标准协议。每台运行 BGP 的路由器都有一…

SAP HANA Merge

在SAP HANA数据库中&#xff0c;数据表都分为两个区域&#xff1a;Main Store和Delta Store。Main Store中的数据经过高压缩处理&#xff0c;查询和计算效率高&#xff0c;但写入成本高&#xff1b;而Delta Store则是为写入优化的区域&#xff0c;数据会定期从Delta Store合并到…

【愚公系列】《Python网络爬虫从入门到精通》045-Charles的SSL证书的安装

标题详情作者简介愚公搬代码头衔华为云特约编辑&#xff0c;华为云云享专家&#xff0c;华为开发者专家&#xff0c;华为产品云测专家&#xff0c;CSDN博客专家&#xff0c;CSDN商业化专家&#xff0c;阿里云专家博主&#xff0c;阿里云签约作者&#xff0c;腾讯云优秀博主&…

STM32Cubemx配置E22-xxxT22D lora模块实现定点传输

文章目录 一、STM32Cubemx配置二、定点传输**什么是定点传输?****定点传输的特点****定点传输的工作方式****E22 模块定点传输配置****如何启用定点传输?****示例****应用场景****总结****配置 1:`C0 00 07 00 02 04 62 00 17 40`****解析****配置 2:`C0 00 07 00 01 04 62…

配置nacos

解压资料中的nacos-server-1.2.0.zip 进入bin目录双击 startup.cmd 运行文件 访问http://localhost:8848/nacos 注册admin服务 1<dependency> <groupId>com.alibaba.cloud</groupId> <artifactId>spring-cloud-starter-alibaba-nacos-d…

护照阅读器在旅游景区流程中的应用

在旅游景区的日常运营与管理中&#xff0c;为游客提供便捷、高效且安全的游览体验至关重要。护照阅读器作为先进的身份识别设备&#xff0c;在景区的自助购票、行李寄存以及自助安检等关键环节发挥着重要作用&#xff0c;极大地优化了景区的运营流程&#xff0c;提升了游客的满…