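The hottest AI topic lately is undoubtedly ChatGPT, which has shown everyone how far general-purpose AI has come. ChatGPT can already answer all kinds of questions well enough that I started wondering whether it could understand users' data-analysis questions, turn them into the corresponding analysis tasks, and return the results to the user.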
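For example, suppose we have a database and a user asks how many users registered in each month over the past year. ChatGPT can understand the question and translate it into a SQL statement. All we have to do is run that SQL against the database and present the result to the user as a chart. We can also let users ask predictive questions, such as forecasting next month's user growth from the growth trend of the past three years. In that case we use ChatGPT to generate the steps for answering the question, build a machine learning model according to those steps, train it on historical data, make the prediction, and return the result to the user.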
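The simplest form of the prediction described above is a straight-line trend. Number the historical periods 0, 1, 2, ..., fit a line through the values, and extrapolate one period ahead. Below is a minimal sketch of that idea in plain Python, using ordinary least squares and no libraries. `forecast_next` and the sample numbers are hypothetical. The implementation later in this article performs the same kind of fit with scikit-learn's `LinearRegression`.

```python
from typing import Sequence


def forecast_next(history: Sequence[float]) -> float:
    """Fit y = a + b * t by ordinary least squares over t = 0..n-1 and return the value at t = n."""
    n = len(history)
    if n < 2:
        raise ValueError("need at least two observations to fit a trend")
    t_mean = (n - 1) / 2
    y_mean = sum(history) / n
    # Slope b = cov(t, y) / var(t)
    cov = sum((t - t_mean) * (y - y_mean) for t, y in enumerate(history))
    var = sum((t - t_mean) ** 2 for t in range(n))
    slope = cov / var
    intercept = y_mean - slope * t_mean
    # Extrapolate one step past the last observation
    return intercept + slope * n


if __name__ == "__main__":
    monthly_signups = [100, 110, 120, 130]  # hypothetical monthly counts
    print(forecast_next(monthly_signups))  # 140.0
```

A straight line is only the most basic trend model. It ignores seasonality, which is one reason the article lets the LLM plan the steps rather than hard-coding one model.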
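LangChain is an open-source application development framework. It combines LLMs such as ChatGPT with other tools into chains of task logic, which makes it easier to build LLM-based applications.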
Below I describe how to build a smart data analytics system based on LangChain and OpenAI models.
Generating mock data
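First we create a simple dataset and store it in a PostgreSQL database. Suppose we want to generate user registration data where each user has a user ID, a registration date and a brand. The following code generates the data: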
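import random
import uuid
from datetime import date, timedelta

import numpy as np
from tqdm import trange

# Placeholder brand names: the original code uses a `brands` list without defining it
brands = ['brand_a', 'brand_b', 'brand_c']

startdate = date(2022, 1, 1)
enddate = date(2022, 12, 31)
start_num = 100
# +1 so that the last day (2022-12-31) is included
days = (enddate - startdate).days + 1

with open('register_user.csv', 'w') as f:
    for i in trange(days):
        current_date = (startdate + timedelta(days=i)).isoformat()
        # Daily registrations: start_num scaled by a random factor with mean 1.1
        current_num = int(start_num * (1 + np.random.normal(0.1, 0.2)))
        for j in range(current_num):
            # Brand split: k in 1..6 -> brand 0 (60%), 7..8 -> brand 1 (20%), 9..10 -> brand 2 (20%)
            k = random.randint(1, 10)
            brand_id = 0
            if k > 6:
                if k < 9:
                    brand_id = 1
                else:
                    brand_id = 2
            # One CSV row per user: user ID, brand, registration date
            f.write(str(uuid.uuid1()) + ',' + brands[brand_id] + ',' + current_date + '\n')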
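Before loading the file, you can compute the answer to the article's example question, registrations per month, directly from the CSV. That gives a reference value to check the LLM-generated SQL against later. This is a sketch that assumes the three-column layout written above: user ID, brand and ISO date, with no header row. `monthly_counts` is a hypothetical helper.

```python
import csv
from collections import Counter


def monthly_counts(path: str) -> dict:
    """Count rows per YYYY-MM month in a headerless CSV of (user_id, brand, register_date)."""
    counts = Counter()
    with open(path, newline='') as f:
        for user_id, brand, register_date in csv.reader(f):
            # register_date is an ISO date such as 2022-03-15, so its first 7 characters are the month
            counts[register_date[:7]] += 1
    return dict(sorted(counts.items()))


if __name__ == "__main__":
    import os

    if os.path.exists('register_user.csv'):
        for month, n in monthly_counts('register_user.csv').items():
            print(month, n)
```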
Import the generated CSV file into the PostgreSQL database with psql's copy command.
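Here is a minimal sketch of the same import step done from Python. The original gives no table or column names, so the table `register_user` and the columns `user_id`, `brand` and `register_date` are hypothetical. The sketch uses psycopg2's `cursor.copy_expert`, which streams a local file into a server-side `COPY ... FROM STDIN`. That is the same mechanism psql's `\copy` uses.

```python
CREATE_SQL = """
CREATE TABLE IF NOT EXISTS register_user (
    user_id       uuid PRIMARY KEY,
    brand         text NOT NULL,
    register_date date NOT NULL
)
"""

# The CSV has no header row, and its columns are in this order
COPY_SQL = "COPY register_user (user_id, brand, register_date) FROM STDIN WITH (FORMAT csv)"


def load_csv(csv_path: str, dsn: str) -> None:
    """Create the table if needed and bulk-load the CSV with COPY ... FROM STDIN."""
    import psycopg2  # imported here so the SQL above can be inspected without the driver installed

    with psycopg2.connect(dsn) as conn:
        with conn.cursor() as cur:
            cur.execute(CREATE_SQL)
            with open(csv_path) as f:
                cur.copy_expert(COPY_SQL, f)
    # Leaving the connection's with-block commits the transaction but does not close the connection
```

Usage: `load_csv('register_user.csv', 'postgresql://postgres:postgres@localhost/demo')` loads the data into the same database that the LangChain setup below connects to.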
Building a Q&A website
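We build a simple website with Flask. I adapted the page template from DeskApp (DeskApp - Bootstrap Admin Dashboard HTML Template), a Bootstrap-based admin template. Create a templates directory and put an index.html file in it with the following content: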
<!DOCTYPE html>
<html>
<head>
  <!-- Basic Page Info -->
  <meta charset="utf-8" />
  <title>Smart Analytics - Get business insights by simply asking questions</title>
  <!-- Mobile Specific Metas -->
  <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1" />
  <!-- Google Font -->
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap" rel="stylesheet" />
  <!-- CSS -->
  <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/core.css') }}" />
  <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/icon-font.min.css') }}" />
  <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/dataTables.bootstrap4.min.css') }}" />
  <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/responsive.bootstrap4.min.css') }}" />
  <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/style.css') }}" />
</head>
<body class="sidebar-shrink">
  <div class="main-container">
    <div>
      <div class="title pb-20" id="title">
        <h2 class="h3 mb-0">Smart Analytics</h2>
        <p>Explore data and get insights using natural language</p>
      </div>
      <div class="row ml-15 mr-15 h-100">
        <div class="col-md-12 mb-20">
          <div class="card-box height-100-p pd-20" id="query_answer">
            <form method="post">
              <div class="row">
                <div class="form-group col-md-6">
                  <label>Question</label>
                  <textarea class="form-control" name="query" value="{{ query_value }}">{{ query_value }}</textarea>
                </div>
                <div class="form-group col-md-6">
                  <label>Answer</label>
                  <textarea class="form-control" name="answer" readonly="readonly">{{ answer }}</textarea>
                </div>
              </div>
              <div class="row">
                <div class="col-md-12">
                  <input class="btn btn-info" type="submit" name="submit" value="Ask">
                </div>
              </div>
            </form>
          </div>
        </div>
      </div>
      <div class="row pb-10 h-100" id="chart_div" style="display:{{ chart_div_display }}">
        <div class="col-md-12 mb-20">
          <div class="card-box height-100-p pd-20">
            <div class="d-flex flex-wrap justify-content-between align-items-center pb-0 pb-md-3">
              <div class="h5 mb-md-0">{{ visual_name }}</div>
            </div>
            <div id="chart" style="visibility:{{ chart_visibility }}"></div>
            <div id="table" style="display:{{ table_visibility }};overflow-x:auto;">{{ tabledata| safe }}</div>
          </div>
        </div>
        <!-- Todo: Add the chat history here-->
      </div>
    </div>
  </div>
  <!-- js -->
  <script src="{{ url_for('static', filename='js/core.js') }}"></script>
  <script src="{{ url_for('static', filename='js/script.min.js') }}"></script>
  <script src="{{ url_for('static', filename='js/process.js') }}"></script>
  <script src="{{ url_for('static', filename='js/layout-settings.js') }}"></script>
  <script src="{{ url_for('static', filename='js/jquery.dataTables.min.js') }}"></script>
  <script src="{{ url_for('static', filename='js/dataTables.bootstrap4.min.js') }}"></script>
  <script src="{{ url_for('static', filename='js/dataTables.responsive.min.js') }}"></script>
  <script src="{{ url_for('static', filename='js/responsive.bootstrap4.min.js') }}"></script>
  <script src="{{ url_for('static', filename='js/apexcharts.min.js') }}"></script>
  <script>
    // Give the chart area whatever height is left below the title and the question form
    var height = $(document).height();
    var title_height = $("#title").height();
    var query_answer_height = $("#query_answer").height();
    var chart_height = height - title_height - query_answer_height - 80;
    $("#chart_div").css('height', chart_height + "px");
    // tojson already emits a JavaScript-safe literal, so no JSON.parse('...') wrapper is needed
    var option = {{ option | tojson }};
    var chart = new ApexCharts(document.querySelector('#chart'), option);
    chart.render();
  </script>
</body>
</html>
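The page gives the user a form for entering a question. After the form is submitted, the backend returns the answer along with chart or table data, and the page renders it. Charts are drawn with the apexcharts.js library.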
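The template is driven by a handful of variables: `answer`, `query_value`, `option`, `chart_div_display`, `chart_visibility`, `table_visibility`, `tabledata` and `visual_name`. One way to keep the backend's `render_template` calls consistent is to build that context in one place. The helper below is a hypothetical sketch, not part of the original code. Its three modes mirror what the page can show: text only, a chart, or an HTML table.

```python
from typing import Optional


def page_context(query: str, answer: str, mode: str = 'text', option: Optional[dict] = None,
                 tabledata: str = '', visual_name: str = '') -> dict:
    """Return keyword arguments for render_template('index.html', **ctx)."""
    if mode not in ('text', 'chart', 'table'):
        raise ValueError(f"unknown mode: {mode}")
    return {
        'query_value': query,
        'answer': answer,
        # Always pass a dict so `option | tojson` in the template has something to serialize
        'option': option if option is not None else {},
        # The chart/table card is hidden entirely in text-only mode
        'chart_div_display': 'none' if mode == 'text' else 'inline',
        # Used in a CSS `visibility:` property, whose valid values include visible and hidden
        'chart_visibility': 'visible' if mode == 'chart' else 'hidden',
        # Used in a CSS `display:` property
        'table_visibility': 'block' if mode == 'table' else 'none',
        'tabledata': tabledata,
        'visual_name': visual_name,
    }
```

In a Flask view this would be used as `render_template('index.html', **page_context(query_str, answer, mode='table', tabledata=tabledata))`.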
Create a static directory and put the CSS and JS files referenced in index.html into it.
The Flask application
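The next step is to create a Flask application that renders the index.html template, handles the questions users submit and returns the results. Create a file named report.py.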
1. Loading the LLM and configuring LangChain
To call OpenAI's models, create an account at openai.com and obtain a free-trial API key. With that key you can call OpenAI's models.
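import os

# Import paths as in early LangChain releases; later versions moved SQLDatabaseChain to the langchain_experimental package
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain

os.environ["OPENAI_API_KEY"] = "XXXX"
# temperature=0 gives deterministic answers; a temperature > 0 makes the answers somewhat random
llm = OpenAI(temperature=0, verbose=True)
db = SQLDatabase.from_uri("postgresql://postgres:postgres@localhost/demo")
# return_intermediate_steps=True exposes the generated SQL, which report() reuses later
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, use_query_checker=False, return_intermediate_steps=True)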
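Hard-coding the API key and the database password is fine for a quick demo. A slightly safer pattern is to read them from environment variables. The sketch below assumes hypothetical variable names `PG_USER`, `PG_PASSWORD`, `PG_HOST` and `PG_DB`; only `OPENAI_API_KEY` comes from the original code. It also percent-encodes the user name and password so that characters such as `@` or `/` don't break the connection URI.

```python
import os
from urllib.parse import quote


def require_env(name: str) -> str:
    """Return an environment variable, or fail fast with a clear message if it is missing."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"environment variable {name} is not set")
    return value


def postgres_uri(user: str, password: str, host: str, db: str) -> str:
    # safe='' makes quote() escape every reserved character, including @ : and /
    return f"postgresql://{quote(user, safe='')}:{quote(password, safe='')}@{host}/{db}"


if __name__ == "__main__":
    print(postgres_uri(
        os.environ.get("PG_USER", "postgres"),
        os.environ.get("PG_PASSWORD", "postgres"),
        os.environ.get("PG_HOST", "localhost"),
        os.environ.get("PG_DB", "demo"),
    ))
```

At startup, call `require_env("OPENAI_API_KEY")` before creating `llm`, and pass the result of `postgres_uri(...)` to `SQLDatabase.from_uri`.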
2. Writing the prompt templates
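We run the user's input through templates that format it so the LLM is more likely to give the answer we need. For example, a template can extract the data-query part of the question, extract the required chart format, or decide whether the input is a regression (prediction) question. These are the prompt templates we need: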
from langchain.prompts import PromptTemplate

# Split the input into the data query and the required output format, answered as JSON
template = "You are now a requirement extractor to extract the user input for data query and output format requirement. Your response is in JSON format with \"query\" and \"output format\" as key in JSON body. The user input is {input}"
prompt = PromptTemplate(input_variables=["input"], template=template)

# Does the input contain a prediction task? (yes/no)
template_predict = "Check if the user input include any prediction task, response with yes or no, the user input is {input}"
prompt_predict = PromptTemplate(input_variables=["input"], template=template_predict)

# Break a regression task into solution steps, answered as JSON {"tasks": [...]}
template_steps = "Given the regression task of {input}, generate the steps to solve the task, response in JSON format with \"tasks\" as key and put the steps as value."
prompt_steps = PromptTemplate(input_variables=["input"], template=template_steps)

# Ask the SQL chain to carry out one step (the data-collection step) of the regression task
template_collectdata = "Given the regression task of {input}, {step}"
prompt_collectdata = PromptTemplate(input_variables=["input", "step"], template=template_collectdata)

# Turn the model's prediction into a natural-language answer
template_predictresult = "Generate the response of question '{input}' with the result {result}"
prompt_predictresult = PromptTemplate(input_variables=["input", "result"], template=template_predictresult)

# Generate a chart title for the query
template_charttitle = "Generate a chart title for the query '{input}'"
prompt_charttitle = PromptTemplate(input_variables=["input"], template=template_charttitle)

# Is the query about data exploration? (yes/no)
template_exploration = "Is the query '{input}' relate to machine learning data exploration, answer in Yes or No"
prompt_exploration = PromptTemplate(input_variables=["input"], template=template_exploration)

# Which table and columns the query refers to, answered as JSON
template_explorequery = "what data table and columns relate to the query '{input}', response in JSON format with key \"table_name\" and \"column_name\", put the table name and column name as value in JSON body"
prompt_explorequery = PromptTemplate(input_variables=["input"], template=template_explorequery)

# Title for the exploration result table
template_exploretitle = "Generate a general title of the answer to the query '{input}'"
prompt_exploretitle = PromptTemplate(input_variables=["input"], template=template_exploretitle)
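Several of these templates ask the model to answer in JSON, and the request handler passes the reply to `json.loads`. In practice a model sometimes wraps the JSON in extra prose or a Markdown code fence, which makes `json.loads` fail. Here is a defensive sketch that decodes the first JSON object found in the reply. The helper name `parse_llm_json` is hypothetical.

```python
import json


def parse_llm_json(text: str) -> dict:
    """Decode the first JSON object embedded in an LLM reply, ignoring surrounding text."""
    decoder = json.JSONDecoder()
    start = text.find('{')
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at index `start` and ignores what follows
            obj, _ = decoder.raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            # This '{' did not start valid JSON (e.g. a brace in prose); try the next one
            start = text.find('{', start + 1)
    raise ValueError("no JSON object found in the reply")
```

For example, `parse_llm_json(llm(prompt.format(input=query_str)))` could replace `json.loads(...)` wherever the reply is expected to be JSON.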
3. Defining the Flask app
from flask import Flask
from flask import render_template
from flask import request

app = Flask(__name__)
# Intended to stop the browser from caching static files (CSS/JS) during development
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = -1
4. Defining the chart types
I use the ApexCharts JS library to render charts and define three chart types (bar, pie and line) as follows:
# ApexCharts bar chart option
bar_chart_template = {
    'chart': {'height': 400, 'type': 'bar'},
    'series': [],
    'xaxis': {'categories': []},
    'legend': {'position': 'top', 'horizontalAlign': 'right', 'floating': True, 'offsetY': 0,
               'labels': {'useSeriesColors': True}, 'markers': {'width': 10, 'height': 10}},
}

# ApexCharts pie chart option
pie_chart_template = {
    'chart': {'height': 400, 'type': 'pie'},
    'series': [],
    'labels': [],
    'legend': {'position': 'top', 'horizontalAlign': 'right', 'floating': True, 'offsetY': 0,
               'labels': {'useSeriesColors': True}, 'markers': {'width': 10, 'height': 10}},
}

# ApexCharts line chart option
line_chart_template = {
    'chart': {'height': 400, 'type': 'line'},
    'series': [],
    'xaxis': {'categories': []},
    'legend': {'position': 'top', 'horizontalAlign': 'right', 'floating': True, 'offsetY': 0,
               'labels': {'useSeriesColors': True}, 'markers': {'width': 10, 'height': 10}},
}
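These templates hold only styling. The data goes into `series` and `xaxis.categories` for bar and line charts, or into `series` and `labels` for pie charts. Below is a sketch of filling them from a query result with two columns (category, value), where `rows` is a list of tuples like those returned by `cursor.fetchall()`. The helper names are hypothetical. The templates are repeated in trimmed form, without the `legend` block, so that the example runs on its own.

```python
from copy import deepcopy

# Trimmed copies of the templates above (legend settings omitted)
bar_chart_template = {'chart': {'height': 400, 'type': 'bar'}, 'series': [], 'xaxis': {'categories': []}}
line_chart_template = {'chart': {'height': 400, 'type': 'line'}, 'series': [], 'xaxis': {'categories': []}}
pie_chart_template = {'chart': {'height': 400, 'type': 'pie'}, 'series': [], 'labels': []}


def axis_chart_option(template: dict, rows: list, series_name: str) -> dict:
    """Bar/line: column 0 becomes the x-axis categories, column 1 the values of a single series."""
    option = deepcopy(template)  # never mutate the shared module-level template
    option['xaxis']['categories'] = [row[0] for row in rows]
    option['series'] = [{'name': series_name, 'data': [row[1] for row in rows]}]
    return option


def pie_chart_option(rows: list) -> dict:
    """Pie: ApexCharts takes a flat list of numbers in series and the slice names in labels."""
    option = deepcopy(pie_chart_template)
    option['labels'] = [row[0] for row in rows]
    option['series'] = [row[1] for row in rows]
    return option


if __name__ == "__main__":
    rows = [('brand_a', 60), ('brand_b', 20), ('brand_c', 20)]  # hypothetical query result
    print(axis_chart_option(bar_chart_template, rows, 'brand'))
    print(pie_chart_option(rows))
```

The `deepcopy` matters because the templates are module-level dicts shared across requests. Appending to them in place would leak one request's data into the next.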
5. Defining a Python function to handle HTTP requests
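Finally, we define a report function that handles the user's HTTP requests. I won't walk through it line by line; the code and comments cover the details. In outline, it formats the user's question with the prompt templates defined earlier. It then decides whether the question is a database query, a data-exploration request or a machine-learning prediction, and takes the corresponding steps.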
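Stripped of the chart code, the routing described above comes down to asking the LLM two yes/no questions in order. In the sketch below the LLM is passed in as a plain callable (`ask`), so the routing can be exercised with a stub instead of a real OpenAI call. The prompt texts are copied from `template_predict` and `template_exploration`. The function names and return labels are hypothetical. The reply check looks only at the first word, case-insensitively. This is a stricter choice than searching the whole reply for "Yes", because the prediction prompt asks for "yes or no" and a model may answer in lower case.

```python
from typing import Callable

# Prompt texts copied from template_predict and template_exploration
PREDICT_PROMPT = "Check if the user input include any prediction task, response with yes or no, the user input is {input}"
EXPLORE_PROMPT = "Is the query '{input}' relate to machine learning data exploration, answer in Yes or No"


def is_yes(reply: str) -> bool:
    # Only the first word counts, so a reply like "No, ..." is never read as yes
    words = reply.strip().split()
    return bool(words) and words[0].strip('.,!').lower() == 'yes'


def route_question(question: str, ask: Callable[[str], str]) -> str:
    """Return 'predict', 'explore' or 'query', checked in the same order as report()."""
    if is_yes(ask(PREDICT_PROMPT.format(input=question))):
        return 'predict'
    if is_yes(ask(EXPLORE_PROMPT.format(input=question))):
        return 'explore'
    return 'query'
```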
import json
import re
from copy import deepcopy

import pandas as pd
import psycopg2
from sklearn import linear_model

# regex_orderby is used below but not defined in the original code;
# this pattern (capture the first column name after ORDER BY) is an assumption
regex_orderby = re.compile(r'ORDER BY\s+"?(\w+)"?', re.IGNORECASE)


def is_yes(text):
    # The prompts ask for "yes or no" / "Yes or No", so compare case-insensitively
    return 'yes' in text.lower()


def extract_sql(result):
    # Take the SQL that db_chain generated from its intermediate steps and drop any LIMIT clause
    sql_cmd = ''
    for item in result['intermediate_steps']:
        if isinstance(item, dict) and 'sql_cmd' in item:
            sql_cmd = item['sql_cmd']
    return re.sub(r'LIMIT \d+', '', sql_cmd)


def fill_multi_series(option, rows, columns, sql_cmd):
    # Shared by the bar and line branches for results with 3+ columns:
    # one column is the x-axis, another names the series, and the last one holds the value.
    # Use the ORDER BY column as the x-axis if there is one, otherwise column 0
    column_xAxis_index = 0
    match = regex_orderby.search(sql_cmd)
    if match:
        for i, column in enumerate(columns):
            if column.name == match.group(1):
                column_xAxis_index = i
    # Series-name column: column 1, or column 0 when the x-axis is column 1
    idx = 0 if column_xAxis_index == 1 else 1
    xAxis_data = set()
    yAxis_data = {}
    for row in rows:
        xAxis_data.add(row[column_xAxis_index])
        yAxis_data.setdefault(row[idx], {})[row[column_xAxis_index]] = row[-1]
    option['xaxis']['categories'] = sorted(xAxis_data)
    option['series'] = []
    for key in yAxis_data:
        # Missing (series, x) combinations are filled with 0
        option['series'].append({'name': key,
                                 'data': [yAxis_data[key].get(x, 0) for x in option['xaxis']['categories']]})


@app.route("/", methods=['GET', 'POST'])
def report():
    if request.method == 'POST':
        query_str = request.form.get('query')
        # Connect to the same database that db_chain queries
        conn = psycopg2.connect(database="demo", user="postgres", password="postgres", host="localhost", port="5432")
        cursor = conn.cursor()
        option = deepcopy(bar_chart_template)

        # Is the question a machine learning prediction task?
        if is_yes(llm(prompt_predict.format(input=query_str))):
            steps_json = json.loads(llm(prompt_steps.format(input=query_str)))
            # Find the data-collection step and let db_chain generate the SQL for it
            for step in steps_json['tasks']:
                if is_yes(llm("If the task '" + step + "' relate to data collection? Answer in Yes or No")):
                    break
            result = db_chain(prompt_collectdata.format(input=query_str, step=step))
            sql_cmd = extract_sql(result)
            print(sql_cmd)
            cursor.execute(sql_cmd)
            rows = cursor.fetchall()
            conn.commit()
            cursor.close()
            conn.close()
            # Fit a linear regression of the last column against the time step 0, 1, 2, ...
            train_X = [[t] for t in range(len(rows))]
            train_Y = [row[-1] for row in rows]
            model = linear_model.LinearRegression()
            model.fit(train_X, train_Y)
            # Predict the value at the next time step
            predict_result = model.predict([[len(rows)]])
            answer = llm(prompt_predictresult.format(input=query_str, result=str(predict_result))).strip()
            return render_template('index.html', answer=answer, query_value=query_str, option=option,
                                   chart_div_display='none', table_visibility='none', tabledata='')

        # Is the question a data exploration task?
        if is_yes(llm(prompt_exploration.format(input=query_str))):
            result_json = json.loads(db_chain(prompt_explorequery.format(input=query_str))['result'])
            if isinstance(result_json['column_name'], list):
                column_name = ','.join(result_json['column_name'])
            else:
                column_name = result_json['column_name']
            sql_cmd = 'select ' + column_name + ' from ' + result_json['table_name'] + ';'
            # Summary statistics of the selected columns, rendered as an HTML table
            df = pd.read_sql(sql_cmd, conn)
            tabledata = df.describe().to_html(classes=['table', 'table-bordered'])
            conn.commit()
            cursor.close()
            conn.close()
            visual_name = llm(prompt_exploretitle.format(input=query_str))
            return render_template('index.html', query_value=query_str, option={}, chart_div_display='inline',
                                   answer='See the table below for exploration details',
                                   visual_name=visual_name, tabledata=tabledata)

        # Otherwise: a plain data query, optionally with a chart format
        prompt_json = json.loads(llm(prompt.format(input=query_str)))
        result = db_chain(prompt_json['query'])
        sql_cmd = extract_sql(result)
        if 'output format' not in prompt_json:
            cursor.close()
            conn.close()
            return render_template('index.html', answer=result['intermediate_steps'][-1], query_value=query_str,
                                   option={}, chart_div_display='none')
        cursor.execute(sql_cmd)
        rows = cursor.fetchall()
        # Read the column metadata before the cursor is closed
        columns = cursor.description
        column_num = len(columns)
        conn.commit()
        cursor.close()
        conn.close()
        output_format = prompt_json['output format']

        if 'bar' in output_format:
            if column_num == 1:
                option['xaxis']['categories'] = [columns[0].name]
                option['series'] = [{'name': '', 'data': [rows[0][0]]}]
            elif column_num == 2:
                option['series'] = [{'name': columns[0].name, 'data': [row[1] for row in rows]}]
                option['xaxis']['categories'] = [row[0] for row in rows]
            else:
                fill_multi_series(option, rows, columns, sql_cmd)
        elif 'pie' in output_format:
            option = deepcopy(pie_chart_template)
            option['series'] = [row[1] for row in rows]
            option['labels'] = [row[0] for row in rows]
        elif 'line' in output_format:
            option = deepcopy(line_chart_template)
            if column_num == 2:
                option['series'] = [{'name': columns[0].name, 'data': [row[1] for row in rows]}]
                option['xaxis']['categories'] = [row[0] for row in rows]
            else:
                fill_multi_series(option, rows, columns, sql_cmd)
        else:
            return render_template('index.html', answer=result['intermediate_steps'][-1], query_value=query_str,
                                   option=option, chart_div_display='none', tabledata='')

        visual_name = llm(prompt_charttitle.format(input=query_str))
        return render_template('index.html', answer=result['result'], query_value=query_str, option=option,
                               chart_div_display='inline', chart_visibility='inline', table_visibility='none',
                               visual_name=visual_name, tabledata='')

    if request.method == 'GET':
        return render_template('index.html', query_value='', option={}, chart_div_display='none',
                               table_visibility='none', tabledata='')
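The trickiest part of `report()` is turning a three-column result such as (month, brand, count) into one ApexCharts series per brand, with 0 wherever a brand has no row for a month. Isolated from Flask and the database, the reshaping looks like this. The function name is hypothetical. Unlike the code above, it takes the x-axis and series column positions as arguments instead of inferring them from the SQL's ORDER BY clause.

```python
def pivot_series(rows, x_index=0, group_index=1):
    """Reshape (x, group, ..., value) rows into ApexCharts categories and one series per group.

    The value is taken from the last column; missing (group, x) pairs are filled with 0.
    """
    categories = sorted({row[x_index] for row in rows})
    values = {}
    for row in rows:
        # dicts keep insertion order, so series appear in the order their groups are first seen
        values.setdefault(row[group_index], {})[row[x_index]] = row[-1]
    series = [
        {'name': group, 'data': [by_x.get(x, 0) for x in categories]}
        for group, by_x in values.items()
    ]
    return categories, series


if __name__ == "__main__":
    rows = [('2022-01', 'brand_a', 5), ('2022-01', 'brand_b', 3), ('2022-02', 'brand_a', 7)]
    print(pivot_series(rows))
    # (['2022-01', '2022-02'], [{'name': 'brand_a', 'data': [5, 7]}, {'name': 'brand_b', 'data': [3, 0]}])
```

The returned pair maps directly onto `option['xaxis']['categories']` and `option['series']` in the bar or line template.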