使用FastAPI为知识库问答系统前端提供后端功能接口

news/2025/3/24 7:55:28/

后端接口实现以及接口调用的类代码一览

  • 1. 后端接口代码
  • 2. 代码结构概述
  • 3. 主要功能模块
    • 1. 跨域支持
    • 2. 用户登录接口(/login)
    • 3. 用户注册接口(/register)
    • 4.用户相关接口依赖的类
    • 5.聊天接口(/chat)
    • 6.聊天接口依赖的类
  • 4. 连接方式

1. 后端接口代码

python"># app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict
import uvicorn
from user_database import UserDatabase
from ModelResponse import ModelResponseapp = FastAPI()# 允许跨域访问(适配 Gradio 调用 FastAPI)
app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_credentials=True,allow_methods=["*"],allow_headers=["*"],
)user_db = UserDatabase()
response = ModelResponse()class LoginRequest(BaseModel):username: strpassword: strclass RegisterRequest(BaseModel):username: strpassword: strclass ChatRequest(BaseModel):user_input: strchat_history: List[dict]class ChatResponse(BaseModel):status: strresponse: str@app.post("/login")
async def login(credentials: LoginRequest):"""登录接口,验证用户名和密码。"""if user_db.verify_user(credentials.username, credentials.password):return {"status": "success","message": "Login successful"}else:raise HTTPException(status_code=401, detail="Invalid username or password")@app.post("/register")
async def register(user: RegisterRequest):"""注册接口,添加用户。"""if user_db.add_user(user.username, user.password):return {"status": "success","message": f"User '{user.username}' registered successfully."}else:raise HTTPException(status_code=400, detail="User already exists.")@app.post("/chat")
async def chat(request: ChatRequest) -> ChatResponse:try:user_input = request.user_inputchat_history = request.chat_historybot_response = response.ask(user_input, chat_history)updated_chat_history = chat_history + [[user_input, bot_response["answer"]]]return ChatResponse(status="success",response=bot_response["answer"],)except Exception as e:raise HTTPException(status_code=500, detail=str(e))if __name__ == "__main__":uvicorn.run(app, host="0.0.0.0", port=8000)

2. 代码结构概述

app.py代码是一个基于 FastAPI 的后端服务,旨在为前端提供接口,支持用户登录、注册以及与聊天机器人进行交互的功能。以下是代码的详细功能介绍:

  1. 代码结构概述
    FastAPI:一个现代、高性能的 Python Web 框架,用于构建 API。
    CORS 中间件:允许跨域请求,方便前端(如 Gradio、React 等)调用后端 API。
    用户数据库:通过 UserDatabase 类管理用户的登录和注册。
    聊天机器人:通过 ModelResponse 类实现基于 LLM(大语言模型)的问答功能。
    API 接口:提供了 /login、/register 和 /chat 三个接口,分别用于用户登录、注册和聊天交互。

3. 主要功能模块

1. 跨域支持

python">app.add_middleware(CORSMiddleware,allow_origins=["*"],allow_credentials=True,allow_methods=["*"],allow_headers=["*"],
)

可以允许前端从任何域名访问后端API

2. 用户登录接口(/login)

python">@app.post("/login")
async def login(credentials: LoginRequest):if user_db.verify_user(credentials.username, credentials.password):return {"status": "success","message": "Login successful"}else:raise HTTPException(status_code=401, detail="Invalid username or password")

用于验证用户的用户名和密码。通过UserDatabase类中的verify_user方法实现。如果验证成功,返回 {“status”: “success”, “message”: “Login successful”}。这是前端期望的数据格式,一定要注意前端期望后端返回什么样的数据类型!!接口接受和发送的数据格式最好在工程定框架的时候就定死,不要轻易改动。如果验证失败,返回 401 状态码和错误信息 “Invalid username or password”

3. 用户注册接口(/register)

python">@app.post("/register")
async def register(user: RegisterRequest):if user_db.add_user(user.username, user.password):return {"status": "success","message": f"User '{user.username}' registered successfully."}else:raise HTTPException(status_code=400, detail="User already exists.")

功能是注册新用户。通过UserDatabase类中的add_user方法实现。如果注册成功,返回 {“status”: “success”, “message”: “User registered successfully.”}。如果用户名已存在,返回 400 状态码和错误信息 “User already exists.”

4.用户相关接口依赖的类

下面是实现接口中方法的类。

python"># user_database.py
import sqlite3
import yamlwith open("config.yaml", "r", encoding="utf-8") as f:config = yaml.safe_load(f)class UserDatabase:def __init__(self):"""初始化 UserDatabase,使用本地 SQLite 数据库。"""self.db_path = config["database"]["user_db_path"]self._init_db()def _init_db(self):"""初始化用户信息库"""with sqlite3.connect(self.db_path) as conn:cursor = conn.cursor()cursor.execute('''CREATE TABLE IF NOT EXISTS user (id INTEGER PRIMARY KEY AUTOINCREMENT,username TEXT UNIQUE NOT NULL,password TEXT NOT NULL)''')conn.commit()print("User database initialized.")def add_user(self, username, password):"""添加用户到用户信息库"""try:with sqlite3.connect(self.db_path) as conn:cursor = conn.cursor()# 检查用户是否已存在cursor.execute('SELECT username FROM user WHERE username = ?', (username,))if cursor.fetchone():print(f"User '{username}' already exists.")return False  # 用户已存在,返回 False# 插入新用户cursor.execute('''INSERT INTO user (username, password) VALUES (?, ?)''', (username, password))conn.commit()print(f"User '{username}' added to the database.")return True  # 用户添加成功,返回 Trueexcept sqlite3.Error as e:print(f"Database error: {e}")return False  # 数据库操作失败,返回 Falsedef get_user_by_id(self, user_id):"""根据用户 ID 查询用户信息"""with sqlite3.connect(self.db_path) as conn:conn.row_factory = sqlite3.Rowcursor = conn.cursor()cursor.execute('''SELECT * FROM user WHERE id = ?''', (user_id,))row = cursor.fetchone()  # 只调用一次 fetchonereturn dict(row) if row else Nonedef get_all_users(self):"""获取所有用户信息"""with sqlite3.connect(self.db_path) as conn:conn.row_factory = sqlite3.Rowcursor = conn.cursor()cursor.execute('SELECT * FROM user')return [dict(row) for row in cursor.fetchall()]def verify_user(self, username, password):with sqlite3.connect(self.db_path) as conn:cursor = conn.cursor()cursor.execute('SELECT * FROM user WHERE username = ? AND password = ?', (username, password))return cursor.fetchone() is not None

5.聊天接口(/chat)

python">@app.post("/chat")
async def chat(request: ChatRequest) -> ChatResponse:try:user_input = request.user_inputchat_history = request.chat_historybot_response = response.ask(user_input, chat_history)updated_chat_history = chat_history + [[user_input, bot_response["answer"]]]return ChatResponse(status="success",response=bot_response["answer"],)except Exception as e:raise HTTPException(status_code=500, detail=str(e))

功能是接收用户输入和聊天历史,调用response类中的ask方法生成回答。如果成功,返回 {“status”: “success”, “response”: “生成的回答”}。如果发生错误,返回 500 状态码和错误信息。

6.聊天接口依赖的类

python">import sqlite3
import numpy as np
import faiss
import requests
from typing import List, Dict
import yamldef build_final_prompt(query: str, chat_history: List[Dict[str, str]],relevant_docs: List[Dict[str, str]]) -> str:"""构建最终的提示(final_prompt),包含对话历史和相关文档内容。"""# 生成对话历史上下文full_prompt = []for line in chat_history:role = line.get("role")content = line.get("content")if role == "user":current_prompt = f"用户:{content}"elif role == "assistant":current_prompt = f"助手:{content}"else:raise Exception(f"无法支持的角色类型, {role}")full_prompt.append(current_prompt)full_prompt = "\n".join(full_prompt)doc_context = "\n".join([doc["text"] for doc in relevant_docs[:3]])# 构建最终提示final_prompt = (f"你是一个助手,帮助用户解答问题。\n"f"背景资料:\n{doc_context}\n\n"f"对话历史:\n{full_prompt}\n\n"f"用户的问题:\n{query}")return final_promptdef build_faiss_index(embeddings: np.ndarray):"""使用从数据库加载的嵌入向量构建 FAISS 索引。"""# 获取嵌入向量的维度dimension = embeddings.shape[1]# 创建 FAISS 索引index = faiss.IndexFlatL2(dimension)index.add(embeddings)return indexclass ModelResponse:def __init__(self, config_path: str = "config.yaml"):"""初始化 Response 类,加载配置文件并初始化向量库。"""# 读取配置文件with open(config_path, "r", encoding="utf-8") as f:self.config = yaml.safe_load(f)self.chat_history = []self.history_db_path = self.config["database"]["history_db_path"]# 初始化 API URLself.OLLAMA_API_URL_EMBED = self.config["ollama"]["api_url_embedding"]self.OLLAMA_API_URL_GENER = self.config["ollama"]["api_url_generate"]# 初始化向量库self.vector_db_path = self.config["database"]["vector_db_path"]self.documents, self.embeddings = self.load_embeddings_from_db(self.vector_db_path)self.index = build_faiss_index(self.embeddings)self.PROMPT_TEMPLATE = self.config.get("prompt_template", "")print(f"已从数据库加载 {len(self.documents)} 个文档的嵌入向量。FAISS 索引构建完成!")def generate_embedding(self, text: str) -> List[float]:"""调用 Ollama 的 API 生成文本嵌入。"""data = {"model": self.config["ollama"]["embedding_model"],  # 使用配置文件中的嵌入模型"prompt": text,"options": {"embedding_only": True}  # 只生成嵌入}response = requests.post(f"{self.OLLAMA_API_URL_EMBED}/embeddings", json=data)return response.json().get("embedding", [])@staticmethoddef normalize_embeddings(embeddings: np.ndarray) -> np.ndarray:"""对嵌入向量进行归一化。"""norms = np.linalg.norm(embeddings, axis=1, keepdims=True)return embeddings / normsdef load_embeddings_from_db(self, db_path: str):"""从 SQLite 数据库中加载文档和对应的嵌入向量。"""conn = sqlite3.connect(db_path)cursor = conn.cursor()# 查询数据库中的嵌入向量cursor.execute('SELECT id, pdf_file_name, document_text, embedding FROM document_embeddings')rows = cursor.fetchall()# 将字节流转换回嵌入向量documents = []embeddings = []for row in rows:doc_id, pdf_file_name, doc_text, embedding_bytes = rowembedding = np.frombuffer(embedding_bytes, dtype=np.float32)documents.append({"id": doc_id, "pdf_file_name": pdf_file_name, "text": doc_text})embeddings.append(embedding)embeddings = self.normalize_embeddings(np.array(embeddings))conn.close()return documents, embeddingsdef retrieve_documents(self, query: str, k: int = 5, threshold: float = 1.0) -> List[dict]:"""根据查询检索最相关的文档,并根据阈值过滤结果。"""# 生成查询嵌入query_embedding = np.array([self.generate_embedding(query)], dtype=np.float32)query_embedding = self.normalize_embeddings(query_embedding)distances, indices = self.index.search(query_embedding, k)relevant_docs = []for i, idx in enumerate(indices[0]):if distances[0][i] <= threshold:relevant_docs.append({"text": self.documents[idx]["text"],"score": float(distances[0][i])})relevant_docs.sort(key=lambda x: x["score"])return relevant_docsdef generate_answer(self, final_prompt: str) -> str:"""调用 Ollama 的 API 生成答案,并载入历史对话。"""# 定义模型参数data = {"model": self.config["ollama"]["generation_model"],"prompt": final_prompt,"stream": False,"temperature": self.config["ollama"]["temperature"]}response = requests.post(f"{self.OLLAMA_API_URL_GENER}/generate", json=data)if response.status_code == 200:return response.json().get("response", "")else:return f"API 请求失败,状态码:{response.status_code}"def ask(self, query: str, chat_history: List[dict] = None) -> Dict[str, str]:"""接收用户的问题,检索相关文档并生成答案。"""if chat_history is None:chat_history = []# 检索相关文档relevant_docs = self.retrieve_documents(query, k=self.config["retrieval"]["k"],threshold=self.config["retrieval"]["threshold"])if not relevant_docs:return {"status": "error","response": "未找到相关文档。",}# 构建最终提示final_prompt = build_final_prompt(query, chat_history, relevant_docs)# 生成答案answer = self.generate_answer(final_prompt)return {"query": query,"answer": answer,}

4. 连接方式

如果你没有前端界面的代码也没有关系,你可以通过uvicorn 启动 FastAPI 服务,直接访问URL,然后在终端来查看接口的数据。

uvicorn app:app --host 0.0.0.0 --port 8000

服务启动后,终端会显示类似以下信息:

INFO:     Started server process [12345]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)

你可以通过FastAPI 自动生成了交互式 API 文档(Swagger UI)来访问接口 。

  1. 先访问 Swagger UI:http://localhost:8000/docs
  2. 找到 /register、/login 、/chat接口,点击 Try it out
  3. 输入 JSON 数据(如 {“username”: “testuser”, “password”: “testpassword”}),然后点击 Execute
  4. 查看响应结果
访问根路径:http://localhost:8000/
访问用户登录。 http://localhost:8000/login
访问用户注册。http://localhost:8000/register
访问聊天交互。http://localhost:8000/chat

你也可以通过在终端使用cURL命令发送 HTTP 请求:

curl -X POST "http://localhost:8000/register" -H "Content-Type: application/json" -d '{"username": "user1", "password": "12345"}'
curl -X POST "http://localhost:8000/login" -H "Content-Type: application/json" -d '{"username": "user1", "password": "12345"}'
curl -X POST "http://localhost:8000/chat" -H "Content-Type: application/json" -d '{"user_input": "什么是人工智能?", "chat_history": []}'

http://www.ppmy.cn/news/1581407.html

相关文章

Ubuntu sudo apt-get install 出现“E: 无法定位软件包问题”解决方法

方法1.使用sudo apt-get update 方法2.更换镜像源 清华源地址&#xff1a;清华源地址https://mirrors.tuna.tsinghua.edu.cn/help/ubuntu/ Ubuntu的软件配置文件是/etc/apt/sources.list 更换源的时候需要先对其进行备份再更换 步骤如下&#xff1a; 1.切换到镜像源的位置…

cursor无限续杯软件操作教程

软件使用教程&#xff1a; 在这里插入图片描述 软件界面&#xff1a; 破解流程&#xff1a; 1.退出 cursor 软件的账号&#xff0c;点击 log out 按钮&#xff0c;可以手动退出并关闭软件。 2.删除账号&#xff0c;点击按钮会自动打开网页&#xff0c;手动删除即可。 3.确保…

Vue学习汇总(JS长期更新版)

文章目录 一、开始  二、基础 目录 一、开始 1、[Vue]VsCode快捷键 二、基础 1、[Vue]模版语法        2、[Vue]属性绑定        3、[Vue]条件渲染        4、[Vue]列表渲染

Jmeter插件下载和配置

下载插件&#xff1a; https://jmeter-plugins.org/wiki/PluginsManager/ https://jmeter-plugins.org/get/ 下载文件移动到jmeter安装目录&#xff1a;\apache-jmeter-5.6.3\lib\ext\重启Jmeter后Options中查看插件 4.

AI爬虫 :Firecrawl的安装和详细使用案例(将整个网站转化为LLM适用的markdown或结构化数据)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 1. Firecrawl概述1.1 Firecrawl介绍1.2 Firecrawl 的特征1.3 Firecrawl 的功能1.4 Firecrawl的 API 密钥获取2. 安装和基本使用3. 使用 LLM 提取4. 无模式提取(curl语句)5. 使用操作与页面交互6. Firecrawl Cloud7. 移…

腾讯云HAI1元体验:轻松调用DeepSeek-R1模型搭建网站

前言 随着云计算和人工智能技术的不断发展&#xff0c;构建和部署智能化的网页变得越来越简单。腾讯云提供的HAI&#xff08;人工智能平台&#xff09;和DeepSeek&#xff08;智能搜索引擎&#xff09;服务&#xff0c;能帮助开发者快速搭建智能化网页&#xff0c;提升用户体验…

<项目> 主从Reactor模型的高并发服务器

目录 Reactor 概念 分类 单Reactor单线程 单Reactor多线程 多Reactor多线程 项目介绍 项目规划 模块关系 实现 TimerWheel -- 时间轮定时器 定时器系统调用 时间轮设计 通用类型Any Buffer Socket Channel Poller EventLoop&#xff08;核心&#xff09; eventfd 设计思路 …

Spring MVC 接口数据

访问路径设置 RequestMapping("springmvc/hello") 就是用来向handlerMapping中注册的方法注解! 秘书中设置路径和方法的对应关系&#xff0c;即RequestMapping("/springmvc/hello")&#xff0c;设置的是对外的访问地址&#xff0c; 路径设置 精准路径匹…