from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
import ollama
import mysql.connector
from mysql.connector.cursor import MySQLCursor
import jsonapp = FastAPI()# 数据库连接配置
DB_CONFIG = {"database": "web", # 您的数据库名,用于存储业务数据"user": "root", # 数据库用户名,需要有读写权限"password": "XXXXXX", # 数据库密码,建议使用强密码"host": "127.0.0.1", # 数据库主机地址,本地开发环境使用localhost"port": "3306" # MySQL 默认端口,可根据实际配置修改
}# 数据库连接函数
def get_db_connection():try:conn = mysql.connector.connect(**DB_CONFIG)return connexcept Exception as e:raise HTTPException(status_code=500, detail=f"数据库连接失败: {str(e)}")class SQLRequest(BaseModel):question: strdef get_table_relationships():"""动态获取表之间的关联关系"""conn = get_db_connection()cur = conn.cursor()try:# 获取当前数据库名cur.execute("SELECT DATABASE()")db_name = cur.fetchone()[0]# 获取外键关系cur.execute("""SELECT TABLE_NAME,COLUMN_NAME,REFERENCED_TABLE_NAME,REFERENCED_COLUMN_NAMEFROM INFORMATION_SCHEMA.KEY_COLUMN_USAGEWHERE TABLE_SCHEMA = %sAND REFERENCED_TABLE_NAME IS NOT NULLORDER BY TABLE_NAME, COLUMN_NAME""", (db_name,))relationships = []for row in rows:table_name, column_name, ref_table, ref_column = rowrelationships.append(f"-- {table_name}.{column_name} can be joined with {ref_table}.{ref_column}")return "\n".join(relationships) if relationships else "-- No foreign key relationships found"finally:cur.close()conn.close()def get_database_schema():"""获取MySQL数据库表结构,以CREATE TABLE格式返回"""conn = get_db_connection()cur = conn.cursor()try:# 获取当前数据库名cur.execute("SELECT DATABASE()")db_name = cur.fetchone()[0]# 获取所有表的结构信息cur.execute("""SELECT t.TABLE_NAME,c.COLUMN_NAME,c.COLUMN_TYPE,c.IS_NULLABLE,c.COLUMN_KEY,c.COLUMN_COMMENTFROM INFORMATION_SCHEMA.TABLES tJOIN INFORMATION_SCHEMA.COLUMNS c ON t.TABLE_NAME = c.TABLE_NAMEWHERE t.TABLE_SCHEMA = %sAND t.TABLE_TYPE = 'BASE TABLE'ORDER BY t.TABLE_NAME, c.ORDINAL_POSITION""", (db_name,))rows = cur.fetchall()schema = []current_table = Nonetable_columns = []for row in rows:table_name, column_name, column_type, nullable, key, comment = rowif current_table != table_name:if current_table is not None:schema.append(f"CREATE TABLE {current_table} (\n" + ",\n".join(table_columns) + "\n);\n")current_table = table_nametable_columns = []# 构建列定义column_def = f" {column_name} {column_type.upper()}"if key == "PRI":column_def += " PRIMARY KEY"elif nullable == "NO":column_def += " NOT NULL"if comment:column_def += f" -- {comment}"table_columns.append(column_def)# 添加最后一个表if current_table is not None:schema.append(f"CREATE TABLE {current_table} (\n" + ",\n".join(table_columns) + "\n);\n")return "\n".join(schema)finally:cur.close()conn.close()def get_chinese_table_mapping():"""动态生成表名的中文映射"""conn = get_db_connection()cur = conn.cursor()try:# 获取所有表的注释信息cur.execute("""SELECT t.TABLE_NAME,t.TABLE_COMMENTFROM information_schema.TABLES tWHERE t.TABLE_SCHEMA = DATABASE()ORDER BY t.TABLE_NAME""")mappings = []for table_name, table_comment in cur.fetchall():# 生成表的中文名称chinese_name = table_nameif table_name.startswith('web_'):chinese_name = table_name.replace('web_', '').replace('_', '')if table_comment:chinese_name = table_comment.split('--')[0].strip()# 如果中文名称以"表"结尾,则去掉"表"字if chinese_name.endswith('表'):chinese_name = chinese_name[:-1]mappings.append(f' - "{chinese_name}" -> {table_name} table')return "\n".join(mappings)finally:cur.close()conn.close()@app.post("/query")
async def query_database(request: Request):try:# 获取请求体数据并确保正确处理中文body = await request.body()try:request_data = json.loads(body.decode('utf-8'))except UnicodeDecodeError:request_data = json.loads(body.decode('gbk'))question = request_data.get('question')print(f"收到问题: {question}") # 调试日志if not question:raise HTTPException(status_code=400, detail="缺少 question 参数")# 获取数据库结构db_schema = get_database_schema()#print(f"数据库结构: {db_schema}") # 调试日志# 获取中文映射并打印chinese_mapping = get_chinese_table_mapping()#print(f"表映射关系:\n{chinese_mapping}") # 添加这行来打印映射# 修改 prompt 使用更严格的指导prompt = f"""### Instructions:Convert Chinese question to MySQL query. Follow these rules strictly:1. ONLY return a valid SELECT SQL query2. Use EXACT table names from the mapping below3. DO NOT use any table that's not in the mapping4. For Chinese terms, use these exact mappings:
{chinese_mapping}### Examples:Question: 所有装修记录SQL: SELECT * FROM web_decoration ORDER BY id;Question: 查询装修SQL: SELECT * FROM web_decoration ORDER BY id;### Database Schema:{db_schema}### Question:{question}### SQL Query:"""# 获取 SQL 查询response = ollama.chat(model='sqlcoder:latest', messages=[{'role': 'user', 'content': prompt}])sql_query = response['message']['content'].strip()print(f"生成的SQL: {sql_query}") # 调试日志# 验证 SQL 查询if "装修" in question and "web_decoration" not in sql_query:sql_query = "SELECT * FROM web_decoration ORDER BY id"if not sql_query.upper().startswith('SELECT'):raise HTTPException(status_code=400, detail="无效的SQL查询格式")conn = get_db_connection()cur = conn.cursor(dictionary=True)try:cur.execute(sql_query)results = cur.fetchall()return {"sql_query": sql_query,"results": results}except mysql.connector.Error as e:raise HTTPException(status_code=400, detail=f"SQL 执行错误: {str(e)}")finally:cur.close()conn.close()except json.JSONDecodeError as e:raise HTTPException(status_code=400, detail=f"JSON 解析错误: {str(e)}")except Exception as e:raise HTTPException(status_code=500, detail=f"处理请求时生错误: {str(e)}")if __name__ == "__main__":import uvicornport = 8666print(f"Starting server on port {port}")uvicorn.run("main:app",host="0.0.0.0", port=port,reload=True)
我来为您解释这段代码的主要功能和结构:
主要功能
这是一个使用 FastAPI 框架开发的 API 服务,主要用于将中文问题转换为 SQL 查询语句并执行。
核心组件说明
- 数据库配置与连接
DB_CONFIG = {"database": "web","user": "root","password": "XXXXXX","host": "127.0.0.1","port": "3306"
}
提供了 MySQL 数据库的基本连接配置。
- 主要工具函数
get_table_relationships()
: 获取数据库表之间的外键关系get_database_schema()
: 获取数据库表结构get_chinese_table_mapping()
: 生成表名的中文映射关系
- 核心 API 端点
@app.post("/query")
这个端点接收中文问题,主要处理流程:
- 接收并解析用户的中文问题
- 获取数据库结构和表映射
- 使用 ollama 模型将中文转换为 SQL 查询
- 执行 SQL 查询并返回结果
- 智能转换功能
使用ollama
的sqlcoder
模型将中文问题转换为 SQL 查询,包含:
- 严格的表名映射
- SQL 查询验证
- 错误处理机制
特点
- 支持中文输入处理
- 自动获取数据库结构
- 动态生成中文表名映射
- 完善的错误处理机制
- 支持热重载的开发模式
使用示例
可以通过 POST 请求访问 /query
端点:
{"question": "查询所有装修记录"
}
服务会返回:
{"sql_query": "SELECT * FROM web_decoration ORDER BY id","results": [...]
}
安全特性
- 数据库连接错误处理
- SQL 注入防护
- 请求体编码自适应(支持 UTF-8 和 GBK)
- 查询结果的安全封装
查看效果: