text2sql / app.py
allinaigc's picture
Upload app.py
2aef356 verified
'''
1. 大模型以Qwen API形式提供。
1. 重塑了Qwen作为大语言模型做Text2SQL的提示词:
sys_prompt = """
1. 你是一个将文字转换成SQL语句的人工智能。
2. 你需要注意:你只需要用纯文本回复代码的内容,即你不允许回复代码以外的任何信息。
3. SQL变量默认是中文,而且只能从如下的名称列表中选择,你不可以使用这些名字以外的变量名:"长度","宽度","价格","产品ID","比率","类别","*"
4. 你不能写IF, THEN的SQL语句,需要使用CASE。
5. 我需要你转换的文字如下:"""
total_prompt = sys_prompt + "在数据表格table01中," + prompt
'''
##TODO: 2. 账号功能。
import time
import os
import pandas as pd
import streamlit as st
from code_editor import code_editor
import streamlit_authenticator as stauth
# from utils.setup import setup_connexion, setup_session_state
# from utils.vanna_calls import (
# generate_questions_cached,
# run_sql_cached,
# generate_plotly_code_cached,
# generate_plot_cached,
# generate_followup_cached,
# )
# import chatsql003 ### 本地ChatGLT 版本。
import chatsql004 ###Qwen API 版本。
import sql_command
# from streamlit_pandas_profiling import st_profile_report
import dashscope
from dotenv import load_dotenv
load_dotenv()
### 设置openai的API key
dashscope.api_key = os.environ['dashscope_api_key']
st.set_page_config(layout="wide", page_icon="🧩", page_title="本地化国产大模型数据库查询演示")
# setup_connexion()
def clear_all():
st.session_state.conversation = None
st.session_state.chat_history = None
st.session_state.messages = []
message_placeholder = st.empty()
st.session_state["my_question"] = None
return None
## 原始的控制面板
# st.sidebar.title("大模型控制面板")
# st.sidebar.checkbox("Show SQL", value=True, key="show_sql")
# st.sidebar.checkbox("Show Table", value=True, key="show_table")
# st.sidebar.checkbox("Show Plotly Code", value=True, key="show_plotly_code")
# st.sidebar.checkbox("Show Chart", value=True, key="show_chart")
# st.sidebar.checkbox("Show Follow-up Questions", value=True, key="show_followup")
# st.sidebar.button("清除记录,重启一轮新对话", on_click=setup_session_state, use_container_width=True, type='primary')
st.title("本地化国产大模型数据库查询演示")
# st.title("大语言模型SQL数据库查询中心")
# st.info("声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。") ## 颜色比较明显。
st.markdown("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
### authentication with a local yaml file.
import yaml
from yaml.loader import SafeLoader
with open('./config.yaml') as file:
config = yaml.load(file, Loader=SafeLoader)
authenticator = stauth.Authenticate(
config['credentials'],
config['cookie']['name'],
config['cookie']['key'],
config['cookie']['expiry_days'],
config['preauthorized']
)
user, authentication_status, username = authenticator.login('main')
# user, authentication_status, username = authenticator.login('用户登录', 'main')
print('登录的用户:', username)
### Streamlit Sidebar 左侧工具栏
# st.sidebar.write(st.session_state)
if authentication_status:
with st.sidebar:
st.markdown(
"""
<style>
[data-testid="stSidebar"][aria-expanded="true"]{
min-width: 500px;
max-width: 500px;
}
""",
unsafe_allow_html=True,
)
# st.header(f'**大语言模型专家系统工作设定区**')
st.header(f'**欢迎 **{username}** 使用本系统** ')
st.write(f'_Large Language Model Expert System Working Environment_')
# st.write(f'_Welcome and Hope U Enjoy Staying Here_')
authenticator.logout('登出', 'sidebar')
### siderbar的题目。
### siderbar的题目。
# st.header(f'**大语言模型专家系统工作设定区**')
# st.header(f'**欢迎 **{username}** 使用本系统** ') ## 用户登录显示。
st.sidebar.button("清除记录,重启一轮新对话", on_click=clear_all, use_container_width=True, type='primary')
# st.sidebar.button("清除记录,重启一轮新对话", on_click=setup_session_state, use_container_width=True, type='primary')
### 展示当前数据库
# st.markdown("#### 当前数据库中的数据:")
with st.expander("#### 当前数据库中的数据", expanded=True):
my_db = pd.read_sql_table('table01', 'sqlite:///myexcelDB.db')
st.dataframe(my_db, width=400)
## 在sidebar上的三个分页显示,用st.tabs实现。
tab_1, tab_2, tab_4 = st.tabs(['使用须知', '模型参数', '角色设定'])
# tab_1, tab_2, tab_3, tab_4 = st.tabs(['使用须知', '模型参数', '提示词模板', '系统角色设定'])
# with st.expander(label='**使用须知**', expanded=False):
with tab_1:
# st.markdown("#### 快速上手指南")
# with st.text(body="说明"):
# st.markdown("* 重启一轮新对话时,只需要刷新页面(按Ctrl/Command + R)即可。")
with st.text(body="说明"):
st.markdown("""* 简介
本系统使用大模型技术,将自然语言描述转换为SQL查询语句。用户可以输入自然语言问题,系统会自动生成相应的SQL语句,并返回查询结果。
* 使用步骤
在文本框中输入您的自然语言问题。
系统会自动生成相应的SQL语句,并显示在下方。
点击“满意,请执行语句”选项,可查看查询结果。点击”不满意,需要改正语句“可以手动修改SQL语句。
* 注意事项
本系统仍在开发中,可能会存在一些错误或不准确的地方。
自然语言描述越清晰,生成的SQL语句越准确。
系统支持的SQL语法有限,请尽量使用简单易懂的语法。
""")
# with st.text(body="说明"):
# st.markdown("""在构建大语言模型本地知识库问答系统时,需要注意以下几点:
# LLM的选择:LLM的选择应根据系统的应用场景和需求进行。对于需要处理通用问题的系统,可以选择通用LLM;对于需要处理特定领域问题的系统,可以选择针对该领域进行微调的LLM。
# 本地知识库的构建:本地知识库应包含系统所需的所有知识。知识的组织方式应便于LLM的访问和处理。
# 系统的评估:系统应进行充分的评估,以确保其能够准确地回答用户问题。
# 大语言模型本地知识库问答系统具有以下优势:
# 准确性:LLM和本地知识库的结合可以提高系统的准确性。
# 全面性:LLM和本地知识库的结合可以使系统能够回答更广泛的问题。
# 效率:LLM可以快速生成候选答案,而本地知识库可以快速评估候选答案的准确性。
# """)
# with st.text(body="说明"):
# st.markdown(
# "* “数据分析模式”暂时只支持1000个单元格以内的数据分析,单元格中的内容不支持中文数据(表头也尽量不使用中文)。一般运行时间在1至10分钟左右,期间需要保持网络畅通。")
# with st.text(body="说明"):
# st.markdown("* “数据分析模式”推荐上传csv格式的文件,部分Excel文件容易出现数据不兼容的情况。")
## 大模型参数
# with st.expander(label='**大语言模型参数**', expanded=True):
with tab_2:
max_tokens = st.slider(label='Max_Token(生成结果时最大字数)', min_value=100, max_value=4096, value=2048, step=100)
temperature = st.slider(label='Temperature (温度)', min_value=0.0, max_value=1.0, value=0.8, step=0.1)
top_p = st.slider(label='Top_P (核采样)', min_value=0.0, max_value=1.0, value=0.6, step=0.1)
frequency_penalty = st.slider(label='Frequency Penalty (重复度惩罚因子)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
presence_penalty = st.slider(label='Presence Penalty (控制主题的重复度)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
## reset password widget
# try:
# if authenticator.reset_password(st.session_state["username"], 'Reset password'):
# st.success('Password modified successfully')
# except Exception as e:
# st.error(e)
# with st.header(body="欢迎"):
# st.markdown("# 欢迎使用大语言模型商业智能中心")
# with st.expander(label=("**重要的使用注意事项**"), expanded=True):
# with st.container():
##NOTE: 在SQL场景去不需要展示这些提示词。
# with tab_3:
# # st.markdown("#### Prompt提示词参考资料")
# # with st.expander(label="**大语言模型基础提示词Prompt示例**", expanded=False):
# st.code(
# body="继续用中文写一篇关于 [文章主题] 的文章,以下列句子开头:[文章开头]。", language='plaintext')
# st.code(body="将以下文字概括为 100 个字,使其易于阅读和理解。避免使用复杂的句子结构或技术术语。",
# language='plaintext')
# st.code(body="给我出一个迪奥2023春季发布会活动策划。", language='plaintext')
# st.code(body="帮我按照正式会议结构写一个会邀:主题是xx手机游戏立项会议。", language='plaintext')
# st.code(body="帮我写一个车内健康监测全场景落地的项目计划,用表格。", language='plaintext')
# st.code(
# body="同时掷两枚质地均匀的骰子,则两枚骰子向上的点数之和为 7 的概率是多少。", language='plaintext')
# st.code(body="写一篇产品经理的演讲稿,注意使用以下词汇: 赋能,抓手,中台,闭环,落地,漏斗,沉淀,给到,同步,对齐,对标,迭代,拉通,打通,升级,交付,聚焦,倒逼,复盘,梳理,方案,联动,透传,咬合,洞察,渗透,兜底,解耦,耦合,复用,拆解。", language='plaintext')
# with st.expander(label="**数据分析模式的专用提示词Prompt示例**", expanded=False):
# # with st.subheader(body="提示词Prompt"):
# st.code(body="分析此数据集并绘制一些'有趣的图表'。", language='python')
# st.code(
# body="对于这个文件中的数据,你需要要找出[X,Y]数据之间的寻找'相关性'。", language='python')
# st.code(body="对于这个文件中的[xxx]数据给我一个'整体的分析'。", language='python')
# st.code(body="对于[xxx]数据给我一个'直方图',提供图表,并给出分析结果。", language='python')
# st.code(body="对于[xxx]数据给我一个'小提琴图',并给出分析结果。", language='python')
# st.code(
# body="对于[X,Y,Z]数据在一个'分布散点图 (stripplot)',所有的数据在一张图上展现, 并给出分析结果。", language='python')
# st.code(body="对于[X,Y]数据,进行'T检验',你需要展示图表,并给出分析结果。",
# language='python')
# st.code(body="对于[X,Y]数据给我一个3个类别的'聚类分析',并给出分析结果。",
# language='python')
with tab_4:
st.text_area(label='系统角色设定', value='你是一个人工智能,你需要回答我提出的问题,或者完成我交代的任务。你需要使用我提问的语言(如中文、英文)来回答。', height=200, label_visibility='hidden')
myavatar = "./2D.png"
def set_question(question):
st.session_state["my_question"] = question
assistant_message_suggested = st.chat_message(
"assistant", avatar=myavatar
)
##! 注意在填写列名时,最佳的选择是使用双引号将名字框出来。否则LLM会不认识。
if assistant_message_suggested.button("点击这里可以查看参考问题示例 ℹ️"):
questions = [
'你查一下长度超过10的数据有哪些?',
'你告诉我长度超过30且宽度超过10的数据有哪些?',
'宽度超过20且价格超过25的数据有哪些?',
'”长度“超过10且”宽度“超过15且”类别“是”甲“的数据有哪些?',
'长度超过10以上的价格的总和是多少?',
'长度排名前三的产品ID和价格和长度?',
]
st.session_state["my_question"] = None
# questions = generate_questions_cached() ## orignal.
for i, question in enumerate(questions):
time.sleep(0.05)
button = st.button(
question,
on_click=set_question,
args=(question,),
)
my_question = st.session_state.get("my_question", default=None)
if my_question is None:
my_question = st.chat_input(
"请在这里提交您的查询问题",
)
if my_question:
st.session_state["my_question"] = my_question
user_message = st.chat_message("user")
user_message.write(f"{my_question}")
# sql = generate_sql_cached(question=my_question) ## original. 核心函数。用ChatGPT输出SQL语句。
# sql = chatsql003.main(prompt=my_question) ## 通过本地LLM输出SQL语句。
sql = chatsql004.main(prompt=my_question) ## 通过本地LLM输出SQL语句。
if sql:
if st.session_state.get("show_sql", True):
assistant_message_sql = st.chat_message(
"assistant", avatar=myavatar
)
assistant_message_sql.code(sql, language="sql", line_numbers=True)
user_message_sql_check = st.chat_message("user")
user_message_sql_check.write(f"您是否满意上面的SQL查询语句答案?")
with user_message_sql_check:
happy_sql = st.radio(
# "Happy",
label="您的反馈:",
options=["未决定", "满意,请执行语句", "不满意,需要改正语句"], ## 去掉第一个happy选项。
# options=["", "yes", "no"], ## original。
key="radio_sql",
index=0,
horizontal=True,
label_visibility='visible',
)
# ## 非original。后期添加内容。Working.
if happy_sql == "未决定":
df = None
st.session_state["df"] = None
if happy_sql == "不满意,需要改正语句":
st.warning(
"请您手动修正SQL语句,按Shift + Enter提交"
)
sql_response = code_editor(sql, lang="sql")
fixed_sql_query = sql_response["text"]
if fixed_sql_query != "":
# df = run_sql_cached(sql=fixed_sql_query) ## original. 核心函数。
df = sql_command.llm_query(sql_command=fixed_sql_query)
else:
df = None
elif happy_sql == "满意,请执行语句":
# df = run_sql_cached(sql=sql) ## original. 核心函数。
df = sql_command.llm_query(sql_command=sql) ## working.
else:
df = None
if df is not None:
st.session_state["df"] = df
if st.session_state.get("df") is not None:
## 显示查询的结果汇总信息
# with st.container():
query_info_describe = st.chat_message("assistant", avatar=myavatar)
if df is not None:
with query_info_describe:
metric_col1, metric_col2, metric_col3, metric_col4 = st.columns(4)
metric_col1.metric(label='查询结果中的行数', value=f"{df.shape[0]} 行", delta=None)
metric_col2.metric(label='查询结果中的列数', value=f"{df.shape[1]} 列", delta=None)
metric_col3.metric(label='查询结果中的单元格数', value=f"{df.size} 个", delta=None)
metric_col4.metric(label='查询结果中的缺失值', value=f"{df.isnull().sum().sum()} 个", delta=None)
if st.session_state.get("show_table", True):
df = st.session_state.get("df")
assistant_message_table = st.chat_message(
"assistant",
avatar=myavatar
)
### 目前是最佳显示方式
if len(df) > 10:
assistant_message_table.markdown("**查询结果较多,当前只显示前10行数据**") ##TODO: 看看是否有其他更好的显示方式。
assistant_message_table.dataframe(df.head(10), hide_index=True)
# assistant_message_table.dataframe(df.head(100), use_container_width=True)
else:
assistant_message_table.dataframe(df, hide_index=True)
# assistant_message_table.dataframe(df,use_container_width=True)
## 以下是其他datafram的显示方式,均不理想。
# st.dataframe(df, use_container_width=True, height=100)
## streamlit AwesomeTable, Dark模式下显示有问题。可以在sidebar上显示内容。
# from awesome_table import AwesomeTable
# AwesomeTable(df,show_order=True, show_search=True, show_search_order_in_sidebar=True)
# # AwesomeTable(pd.json_normalize(df))
# ## 展示图形
# # code = generate_plotly_code_cached(question=my_question, sql=sql, df=df) ## orignal. 用LLM来构建统计图表:https://github.com/vanna-ai/vanna/blob/main/src/vanna/mistral/mistral.py
# if st.session_state.get("show_plotly_code", False):
# assistant_message_plotly_code = st.chat_message(
# "assistant",
# avatar=myavatar,
# )
# assistant_message_plotly_code.code(
# code, language="python", line_numbers=True
# )
# user_message_plotly_check = st.chat_message("user")
# user_message_plotly_check.write(
# f"Are you happy with the generated Plotly code?"
# )
# with user_message_plotly_check:
# happy_plotly = st.radio(
# "Happy",
# options=["", "yes", "no"],
# key="radio_plotly",
# index=0,
# )
# if happy_plotly == "no":
# st.warning(
# "Please fix the generated Python code. Once you're done hit Shift + Enter to submit"
# )
# python_code_response = code_editor(code, lang="python")
# code = python_code_response["text"]
# elif happy_plotly == "":
# code = None
# if code is not None and code != "":
# if st.session_state.get("show_chart", True):
# assistant_message_chart = st.chat_message(
# "assistant",
# avatar=myavatar,
# )
# fig = generate_plot_cached(code=code, df=df)
# if fig is not None:
# assistant_message_chart.plotly_chart(fig)
# else:
# assistant_message_chart.error("I couldn't generate a chart")
## Data Visualization: Ydata_profiling (pandas_profiling)
# from ydata_profiling import ProfileReport
# pr = ProfileReport(df)
# # with st.expander(label='**数据汇总信息**', expanded=False):
# with st.container():
# st_profile_report(pr)
# # st.stop()
## 数据可视化内容,高级版。
## pywalker: https://github.com/Kanaries/pygwalker/tree/main
user_message_plot_check = st.chat_message("user")
user_message_plot_check.write(f"您是否希望基于上述查询结果进行数据可视化?")
with user_message_plot_check:
happy_plot = st.radio(
# "Happy",
label="您的反馈:",
options=["未决定", "是,需要进行数据可视化", "否,不需要进行数据可视化"], ## 去掉第一个happy选项。
# options=["", "yes", "no"], ## original。
key="radio_plot",
index=0,
horizontal=True,
label_visibility='visible',
)
if happy_plot == "是,需要进行数据可视化":
import pygwalker as pyg
import streamlit.components.v1 as components
import streamlit as st
from pygwalker.api.streamlit import init_streamlit_comm, get_streamlit_html
# Initialize pygwalker communication
init_streamlit_comm()
# When using `use_kernel_calc=True`, you should cache your pygwalker html, if you don't want your memory to explode
@st.cache_resource
def get_pyg_html(df: pd.DataFrame) -> str:
# When you need to publish your application, you need set `debug=False`,prevent other users to write your config file.
# If you want to use feature of saving chart config, set `debug=True`
html = get_streamlit_html(df, spec="./gw0.json", use_kernel_calc=True, debug=False)
return html
# walker = pyg.walk(df)
components.html(get_pyg_html(df), width=1200, height=800, scrolling=True)
### streamlit echarts: https://github.com/andfanilo/streamlit-echarts
# from streamlit_echarts import st_echarts
# from streamlit_echarts import st_pyecharts
# st_pyecharts(df)
### Follow-up questions session.
# if st.session_state.get("show_followup", True):
# assistant_message_followup = st.chat_message(
# "assistant",
# avatar=myavatar,
# )
# followup_questions = generate_followup_cached(
# question=my_question, df=df
# )
# st.session_state["df"] = None
# if len(followup_questions) > 0:
# assistant_message_followup.text(
# "Here are some possible follow-up questions"
# )
# # Print the first 5 follow-up questions
# for question in followup_questions[:5]:
# time.sleep(0.05)
# assistant_message_followup.write(question)
else:
assistant_message_error = st.chat_message(
"assistant", avatar=myavatar
)
assistant_message_error.error("我无法回答您的问题,请重新提问,谢谢!")