文章目录
- FastGPT引申:常见 Rerank 实现方案
- 1. 使用 BGE Reranker
- 2. 使用 Cohere Rerank API
- 3. 使用 Cross-Encoder 实现
- 4. 自定义 Reranker 实现
- 5. FastAPI 服务实现
- 6. 实现方案总结
FastGPT引申:常见 Rerank 实现方案
下面介绍几种常见 Rerank（重排序）方案的具体实现。
1. 使用 BGE Reranker
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch


class BGEReranker:
    """Cross-encoder reranker built on a BGE reranker checkpoint.

    Each (query, document) pair is scored jointly by the model; a higher score
    means more relevant. Scores are the model's raw logits, not probabilities.
    """

    def __init__(self, model_name: str = "BAAI/bge-reranker-base"):
        """Load the tokenizer and model for *model_name* in inference mode.

        Args:
            model_name: Hugging Face model id (or local path) of the reranker.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
        # eval() disables dropout so repeated scoring is deterministic.
        self.model.eval()

    def rerank(self, query: str, documents: list[str], batch_size: int = 32) -> list[dict]:
        """Score every document against *query* and return them best-first.

        Args:
            query: The search query.
            documents: Candidate passages to rerank.
            batch_size: Number of (query, document) pairs scored per forward pass.

        Returns:
            A list of ``{"text": str, "score": float}`` dicts sorted by
            descending score. An empty *documents* list yields ``[]``.

        Raises:
            ValueError: If *batch_size* is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        results = []
        for start in range(0, len(documents), batch_size):
            batch = documents[start:start + batch_size]
            # Pair the query with every document in the batch; padding aligns
            # sequence lengths and the attention mask hides the padding.
            inputs = self.tokenizer(
                text=[query] * len(batch),
                text_pair=batch,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            with torch.no_grad():
                # Assumes a single-logit head (one score per pair) — TODO confirm
                # when swapping in a different checkpoint.
                scores = self.model(**inputs).logits.flatten().tolist()
            results.extend(
                {"text": doc, "score": float(score)} for doc, score in zip(batch, scores)
            )

        results.sort(key=lambda item: item["score"], reverse=True)
        return results


# Usage example
reranker = BGEReranker()
query = "如何使用Python进行数据分析?"
docs = [
    "Python数据分析基础教程",
    "数据分析工具pandas使用指南",
    "Python编程基础入门",
]
reranked_results = reranker.rerank(query, docs)
# Show the ranking, most relevant first.
for item in reranked_results:
    print(f"{item['score']:.4f}  {item['text']}")
2. 使用 Cohere Rerank API
import cohere
import logging
from typing import Dict, List

logger = logging.getLogger(__name__)


class CohereReranker:
    """Reranker backed by Cohere's hosted Rerank API."""

    def __init__(self, api_key: str, model: str = "rerank-multilingual-v2.0"):
        """Create the Cohere client.

        Args:
            api_key: Cohere API key.
            model: Rerank model name sent with every request.
        """
        self.co = cohere.Client(api_key)
        self.model = model

    def rerank(self, query: str, documents: List[Dict[str, str]], top_n: int = 3) -> List[Dict]:
        """Rerank *documents* against *query* via the Cohere API.

        Args:
            query: The search query.
            documents: Dicts with ``"id"`` and ``"text"`` keys.
            top_n: Maximum number of results to return (``None`` lets the API decide).

        Returns:
            Dicts with ``"id"``, ``"text"`` and ``"relevance_score"``, in the order
            returned by the API. Returns ``[]`` for empty input, and also on any
            failure (best-effort; the error is logged with its traceback).
        """
        if not documents:
            # Nothing to rank; skip the network round-trip.
            return []
        try:
            response = self.co.rerank(
                query=query,
                documents=[doc["text"] for doc in documents],
                # Never request more results than there are documents.
                top_n=top_n if top_n is None else min(top_n, len(documents)),
                model=self.model,
            )
            # Map back through `index` rather than reading `result.document`,
            # which newer SDK/API versions may omit unless documents are
            # explicitly requested back.
            return [
                {
                    "id": documents[result.index]["id"],
                    "text": documents[result.index]["text"],
                    "relevance_score": result.relevance_score,
                }
                for result in response.results
            ]
        except Exception:
            # Deliberate best-effort: callers get an empty list, but the failure
            # is logged (with traceback) instead of printed and lost.
            logger.exception("Reranking error")
            return []


# Usage example
import os

# Read the key from the environment instead of hard-coding a secret in source.
reranker = CohereReranker(api_key=os.environ.get("COHERE_API_KEY", "your-api-key"))
query = "数据分析方法"
docs = [
    {"id": "1", "text": "使用pandas进行数据处理"},
    {"id": "2", "text": "数据可视化技巧"},
    {"id": "3", "text": "机器学习算法"},
]
results = reranker.rerank(query, docs)
3. 使用 Cross-Encoder 实现
from sentence_transformers import CrossEncoder
from typing import Dict, List

import numpy as np


class CrossEncoderReranker:
    """Reranker using a sentence-transformers ``CrossEncoder`` checkpoint.

    Note: the default ``cross-encoder/ms-marco-MiniLM-L-6-v2`` model is trained
    on MS MARCO, an English dataset; pass a multilingual or Chinese
    cross-encoder via *model_name* for non-English text.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Load the cross-encoder identified by *model_name*."""
        self.model = CrossEncoder(model_name)

    def rerank(self, query: str, documents: List[Dict], batch_size: int = 32) -> List[Dict]:
        """Score each document against *query* and return them best-first.

        Args:
            query: The search query.
            documents: Dicts with ``"id"`` and ``"text"`` keys.
            batch_size: Pairs per forward pass, forwarded to ``CrossEncoder.predict``.

        Returns:
            Dicts with ``"id"``, ``"text"`` and ``"score"`` sorted by descending score;
            ``[]`` for empty input.
        """
        if not documents:
            # predict() cannot handle an empty input list, so short-circuit.
            return []

        pairs = [[query, doc["text"]] for doc in documents]
        # CrossEncoder.predict batches internally; no manual chunking needed.
        scores = self.model.predict(pairs, batch_size=batch_size)

        results = [
            {"id": doc["id"], "text": doc["text"], "score": float(score)}
            for doc, score in zip(documents, scores)
        ]
        results.sort(key=lambda item: item["score"], reverse=True)
        return results


# Usage example
reranker = CrossEncoderReranker()
# `rerank` expects dicts with "id" and "text" keys. The default ms-marco
# checkpoint is English-centric, so this example uses English text.
query = "How do I analyze data with Python?"
documents = [
    {"id": "1", "text": "Data processing with pandas"},
    {"id": "2", "text": "Tips for data visualization"},
    {"id": "3", "text": "An introduction to machine learning algorithms"},
]
results = reranker.rerank(query, documents)
4. 自定义 Reranker 实现
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
from typing import Dict, List


class CustomReranker(nn.Module):
    """Cross-encoder reranker: a pretrained encoder plus a linear scoring head.

    The query and document are encoded together, and the hidden state at
    position 0 (the [CLS] token for BERT-style tokenizers) is mapped to one
    relevance logit.
    """

    def __init__(self, model_name: str = "bert-base-chinese"):
        """Load the encoder/tokenizer for *model_name* and add a fresh scoring head."""
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Relevance head: hidden_size -> 1 logit. Untrained until fine-tuned.
        self.score = nn.Linear(self.encoder.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return relevance logits of shape (batch, 1).

        Args:
            input_ids: Token ids from the tokenizer.
            attention_mask: Mask marking real (non-padding) tokens.
            token_type_ids: Optional segment ids separating query from document;
                forwarded to the encoder only when provided.
        """
        encoder_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            # Segment ids let BERT tell the query tokens from the document tokens.
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        # Use the first-position ([CLS]) hidden state as the pair representation.
        pooled_output = outputs.last_hidden_state[:, 0]
        return self.score(pooled_output)

    def rerank(self, query: str, documents: List[str]) -> List[Dict]:
        """Score each document against *query* and return them best-first.

        The module's previous train/eval mode is restored afterwards.

        Returns:
            Dicts with ``"text"`` and ``"score"`` sorted by descending score.
        """
        was_training = self.training
        self.eval()
        # Tokenizer output is on CPU; move it to wherever the weights live.
        device = next(self.parameters()).device
        results = []
        try:
            with torch.no_grad():
                for doc in documents:
                    inputs = self.tokenizer(
                        text=[query],
                        text_pair=[doc],
                        padding=True,
                        truncation=True,
                        max_length=512,
                        return_tensors="pt",
                    )
                    inputs = {key: value.to(device) for key, value in inputs.items()}
                    score = self.forward(
                        inputs["input_ids"],
                        inputs["attention_mask"],
                        inputs.get("token_type_ids"),
                    )
                    results.append({"text": doc, "score": float(score[0])})
        finally:
            self.train(was_training)

        results.sort(key=lambda item: item["score"], reverse=True)
        return results


# Training function example
def train_reranker(model, train_dataloader, epochs=3, lr=2e-5):
    """Fine-tune a reranker with binary cross-entropy on relevance labels.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that returns
            one logit per pair, shaped (batch,) or (batch, 1).
        train_dataloader: Iterable of dicts with ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"`` (0/1 relevance; any numeric dtype).
        epochs: Number of passes over the data.
        lr: AdamW learning rate.

    Returns:
        list[float]: Mean training loss per epoch (0.0 for an epoch with no batches).
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()
    epoch_losses = []
    for _ in range(epochs):
        model.train()
        total_loss, n_batches = 0.0, 0
        for batch in train_dataloader:
            optimizer.zero_grad()
            # BCEWithLogitsLoss needs input and target of the same shape and a
            # float target: squeeze the (B, 1) logits and cast the labels.
            logits = model(batch["input_ids"], batch["attention_mask"]).squeeze(-1)
            labels = batch["labels"].float().view_as(logits)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        epoch_losses.append(total_loss / n_batches if n_batches else 0.0)
    return epoch_losses
5. FastAPI 服务实现
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import logging
from functools import lru_cache
from typing import List, Optional

from fastapi.concurrency import run_in_threadpool
from pydantic import Field

logger = logging.getLogger(__name__)

app = FastAPI()


class Document(BaseModel):
    """A candidate document submitted for reranking."""

    id: str
    text: str


class RerankerRequest(BaseModel):
    """Request body for ``POST /rerank``."""

    query: str
    documents: List[Document]
    # Optional cap on how many results to return; omitted/None returns all.
    top_n: Optional[int] = Field(default=None, ge=1)


class RerankerResponse(BaseModel):
    """One reranked document with its relevance score."""

    id: str
    text: str
    score: float


@lru_cache(maxsize=1)
def _get_reranker():
    """Return a process-wide reranker, loading model weights on first use only."""
    return CrossEncoderReranker()  # or any other implementation with the same rerank() contract


def _rerank_sync(query: str, documents: List[dict]) -> List[dict]:
    """Run the blocking model load + inference; meant to run off the event loop."""
    return _get_reranker().rerank(query=query, documents=documents)


@app.post("/rerank", response_model=List[RerankerResponse])
async def rerank(request: RerankerRequest):
    """Rerank ``request.documents`` against ``request.query``, best match first.

    Raises:
        HTTPException: 500 if loading the model or scoring fails. The details
            are logged server-side rather than returned to the client.
    """
    documents = [{"id": doc.id, "text": doc.text} for doc in request.documents]
    try:
        # Inference is blocking; run it in a worker thread so the event loop
        # keeps serving other requests.
        results = await run_in_threadpool(_rerank_sync, request.query, documents)
    except Exception as e:
        logger.exception("Reranking failed")
        raise HTTPException(status_code=500, detail="Reranking failed") from e
    if request.top_n is not None:
        results = results[: request.top_n]
    return results
6. 实现方案总结
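- BGE Reranker
  - 开源模型
  - 支持中英文
  - 效果较好
- Cohere Rerank
  - 商业 API
  - 多语言支持
  - 无需自行部署和维护模型
- Cross-Encoder
  - 专门针对重排序任务训练（如 MS MARCO 数据集）
  - 示例模型（MiniLM-L-6）体积小，推理速度较快
  - 易于使用
  - 注意：ms-marco 系列模型主要面向英文，中文场景建议选用多语言或中文 Cross-Encoder 模型
- 自定义实现
  - 完全可控
  - 可以针对特定场景优化
  - 需要训练数据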