FastGPT 引申:常见 Rerank 实现方案

devtools/2025/3/6 22:56:26/

文章目录

    • FastGPT引申:常见 Rerank 实现方案
      • 1. 使用 BGE Reranker
      • 2. 使用 Cohere Rerank API
      • 3. 使用 Cross-Encoder 实现
      • 4. 自定义 Reranker 实现
      • 5. FastAPI 服务实现
      • 6. 实现方案总结

FastGPT引申:常见 Rerank 实现方案

下边介绍几种 Rerank 的具体实现方案。

1. 使用 BGE Reranker

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torchclass BGEReranker:def __init__(self):# 加载模型和分词器self.model_name = "BAAI/bge-reranker-base"self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)self.model.eval()def rerank(self, query: str, documents: list[str]) -> list[dict]:results = []# 批处理文档for doc in documents:# 构造输入格式inputs = self.tokenizer(text=[query],text_pair=[doc],padding=True,truncation=True,max_length=512,return_tensors="pt")# 模型推理with torch.no_grad():scores = self.model(**inputs).logits.flatten()results.append({"text": doc,"score": float(scores[0])  # 转换为Python float})# 按分数排序results.sort(key=lambda x: x["score"], reverse=True)return results# 使用示例
reranker = BGEReranker()
query = "如何使用Python进行数据分析?"
docs = ["Python数据分析基础教程","数据分析工具pandas使用指南","Python编程基础入门"
]reranked_results = reranker.rerank(query, docs)

2. 使用 Cohere Rerank API

import cohere
from typing import List, Dictclass CohereReranker:def __init__(self, api_key: str):self.co = cohere.Client(api_key)def rerank(self, query: str, documents: List[Dict[str, str]], top_n: int = 3) -> List[Dict]:try:# 调用Cohere APIresults = self.co.rerank(query=query,documents=[doc["text"] for doc in documents],top_n=top_n,model="rerank-multilingual-v2.0")# 格式化结果reranked_results = []for result in results:reranked_results.append({"id": documents[result.index]["id"],"text": result.document["text"],"relevance_score": result.relevance_score})return reranked_resultsexcept Exception as e:print(f"Reranking error: {str(e)}")return []# 使用示例
reranker = CohereReranker(api_key="your-api-key")
query = "数据分析方法"
docs = [{"id": "1", "text": "使用pandas进行数据处理"},{"id": "2", "text": "数据可视化技巧"},{"id": "3", "text": "机器学习算法"}
]results = reranker.rerank(query, docs)

3. 使用 Cross-Encoder 实现

from sentence_transformers import CrossEncoder
import numpy as npclass CrossEncoderReranker:def __init__(self):# 加载cross-encoder模型self.model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')def rerank(self, query: str, documents: List[Dict], batch_size: int = 32) -> List[Dict]:# 准备文档对pairs = [[query, doc["text"]] for doc in documents]# 批量计算相关性分数scores = []for i in range(0, len(pairs), batch_size):batch = pairs[i:i + batch_size]batch_scores = self.model.predict(batch)scores.extend(batch_scores)# 组合结果results = []for idx, score in enumerate(scores):results.append({"id": documents[idx]["id"],"text": documents[idx]["text"],"score": float(score)})# 按分数排序results.sort(key=lambda x: x["score"], reverse=True)return results# 使用示例
reranker = CrossEncoderReranker()
results = reranker.rerank(query, documents)

4. 自定义 Reranker 实现

import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelclass CustomReranker(nn.Module):def __init__(self, model_name: str = "bert-base-chinese"):super().__init__()self.encoder = AutoModel.from_pretrained(model_name)self.tokenizer = AutoTokenizer.from_pretrained(model_name)# 相关性评分层self.score = nn.Linear(self.encoder.config.hidden_size, 1)def forward(self, input_ids, attention_mask):# 获取BERT输出outputs = self.encoder(input_ids=input_ids,attention_mask=attention_mask)# 使用[CLS]标记的输出计算相关性分数pooled_output = outputs.last_hidden_state[:, 0]score = self.score(pooled_output)return scoredef rerank(self, query: str, documents: List[str]) -> List[Dict]:self.eval()results = []with torch.no_grad():for doc in documents:# 构造输入inputs = self.tokenizer(text=[query],text_pair=[doc],padding=True,truncation=True,max_length=512,return_tensors="pt")# 计算分数score = self.forward(inputs["input_ids"],inputs["attention_mask"])results.append({"text": doc,"score": float(score[0])})# 排序results.sort(key=lambda x: x["score"], reverse=True)return results# 训练函数示例
def train_reranker(model, train_dataloader, epochs=3):optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)loss_fn = nn.BCEWithLogitsLoss()for epoch in range(epochs):model.train()for batch in train_dataloader:optimizer.zero_grad()input_ids = batch["input_ids"]attention_mask = batch["attention_mask"]labels = batch["labels"]scores = model(input_ids, attention_mask)loss = loss_fn(scores, labels)loss.backward()optimizer.step()

5. FastAPI 服务实现

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optionalapp = FastAPI()class Document(BaseModel):id: strtext: strclass RerankerRequest(BaseModel):query: strdocuments: List[Document]class RerankerResponse(BaseModel):id: strtext: strscore: float@app.post("/rerank", response_model=List[RerankerResponse])
async def rerank(request: RerankerRequest):try:reranker = CrossEncoderReranker()  # 或其他实现results = reranker.rerank(query=request.query,documents=[{"id": doc.id,"text": doc.text} for doc in request.documents])return resultsexcept Exception as e:raise HTTPException(status_code=500, detail=str(e))

6. 实现方案总结

  • BGE Reranker
    • 开源模型
    • 支持中英文
    • 性能较好
  • Cohere Rerank
    • 商业API
    • 多语言支持
    • 无需维护模型
  • Cross-Encoder
    • 专门针对重排序优化
    • 计算效率较高
    • 易于使用
  • 自定义实现
    • 完全可控
    • 可以针对特定场景优化
    • 需要训练数据

http://www.ppmy.cn/devtools/165091.html

相关文章

【商城实战(8)】筑牢权限防线:用户认证与权限管理进阶

【商城实战】专栏重磅来袭!这是一份专为开发者与电商从业者打造的超详细指南。从项目基础搭建,运用 uniapp、Element Plus、SpringBoot 搭建商城框架,到用户、商品、订单等核心模块开发,再到性能优化、安全加固、多端适配&#xf…

Win10 用户、组与内置安全主体概念详解

一、‌用户(User)‌ ‌定义‌ 用户是操作系统中的身份标识,用于区分不同操作者并控制资源访问权限。每个用户拥有独立的安全标识符(SID)‌。 ‌分类‌ ‌内置用户‌: ‌Administrator‌:系统…

Kubernetes(K8S)部署 Redis Cluster 集群

以下将详细介绍如何使用 Kubernetes(K8S)部署 Redis Cluster 集群,并给出相应的 YAML 代码。 1. 准备工作 在开始部署之前,需要确保已经安装并配置好 Kubernetes 集群,并且 kubectl 可以正常与集群通信。 2. 部署 R…

【Mac】git使用再学习

目录 前言 如何使用github建立自己的代码库 第一步:建立本地git与远程github的联系 生成密钥 将密钥加入github 第二步:创建github仓库并clone到本地 第三步:上传文件 常见的git命令 git commit git branch git merge/git rebase …

【AI神经网络与人脑神经系统的关联及借鉴分析】

AI神经网络与人脑神经系统的关联及借鉴分析 一、结构与功能模拟:从生物神经元到人工单元 生物神经元模型 人脑神经元通过电化学信号传递信息,当输入信号超过阈值时触发动作电位("全有或全无"法则)。其动态过程可用Hodg…

1688平台API接口实战:Python实现店铺全量商品数据抓取

在电商数据驱动决策的时代,1688作为国内最大的B2B批发平台,其开放的API接口为商家提供了高效获取商品数据的通道。本文将以Python语言为例,详解如何通过官方接口实现店铺所有商品的自动化抓取。(综合参考) 一、接口核…

探索DeFi世界:用Python开发去中心化金融应用

探索DeFi世界:用Python开发去中心化金融应用 在区块链技术快速发展的今天,去中心化金融(DeFi)正在改变传统金融行业的格局。作为一名自媒体创作者和技术爱好者,我希望通过本文分享如何用Python开发去中心化金融应用,帮助读者深入了解DeFi的潜力和技术实现方式。 什么是…

qsort函数的模拟实现

文章目录 冒泡排序回调函数qsort函数简介qsort函数的使用qsort函数的模拟实现 冒泡排序 冒泡排序顾名思义就是用来给数据排序的一种方法,假设有一整型数组,如果要将这个数组中的元素按从小到大或从大到小的顺序排序,就可以用冒泡排序来完成。…