''' | |
1. 大模型以Qwen API形式提供。 | |
1. 重塑了Qwen作为大语言模型做Text2SQL的提示词: | |
sys_prompt = """ | |
1. 你是一个将文字转换成SQL语句的人工智能。 | |
2. 你需要注意:你只需要用纯文本回复代码的内容,即你不允许回复代码以外的任何信息。 | |
3. SQL变量默认是中文,而且只能从如下的名称列表中选择,你不可以使用这些名字以外的变量名:"长度","宽度","价格","产品ID","比率","类别","*" | |
4. 你不能写IF, THEN的SQL语句,需要使用CASE。 | |
5. 我需要你转换的文字如下:""" | |
total_prompt = sys_prompt + "在数据表格table01中," + prompt | |
''' | |
##TODO: 2. 账号功能。 | |
import time | |
import os | |
import pandas as pd | |
import streamlit as st | |
from code_editor import code_editor | |
import streamlit_authenticator as stauth | |
# from utils.setup import setup_connexion, setup_session_state | |
# from utils.vanna_calls import ( | |
# generate_questions_cached, | |
# run_sql_cached, | |
# generate_plotly_code_cached, | |
# generate_plot_cached, | |
# generate_followup_cached, | |
# ) | |
# import chatsql003 ### 本地ChatGLT 版本。 | |
import chatsql004 ###Qwen API 版本。 | |
import sql_command | |
# from streamlit_pandas_profiling import st_profile_report | |
import dashscope | |
from dotenv import load_dotenv | |
load_dotenv() | |
### 设置openai的API key | |
dashscope.api_key = os.environ['dashscope_api_key'] | |
st.set_page_config(layout="wide", page_icon="🧩", page_title="本地化国产大模型数据库查询演示") | |
# setup_connexion() | |
def clear_all(): | |
st.session_state.conversation = None | |
st.session_state.chat_history = None | |
st.session_state.messages = [] | |
message_placeholder = st.empty() | |
st.session_state["my_question"] = None | |
return None | |
## 原始的控制面板 | |
# st.sidebar.title("大模型控制面板") | |
# st.sidebar.checkbox("Show SQL", value=True, key="show_sql") | |
# st.sidebar.checkbox("Show Table", value=True, key="show_table") | |
# st.sidebar.checkbox("Show Plotly Code", value=True, key="show_plotly_code") | |
# st.sidebar.checkbox("Show Chart", value=True, key="show_chart") | |
# st.sidebar.checkbox("Show Follow-up Questions", value=True, key="show_followup") | |
# st.sidebar.button("清除记录,重启一轮新对话", on_click=setup_session_state, use_container_width=True, type='primary') | |
st.title("本地化国产大模型数据库查询演示") | |
# st.title("大语言模型SQL数据库查询中心") | |
# st.info("声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。") ## 颜色比较明显。 | |
st.markdown("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_") | |
### authentication with a local yaml file. | |
import yaml | |
from yaml.loader import SafeLoader | |
with open('./config.yaml') as file: | |
config = yaml.load(file, Loader=SafeLoader) | |
authenticator = stauth.Authenticate( | |
config['credentials'], | |
config['cookie']['name'], | |
config['cookie']['key'], | |
config['cookie']['expiry_days'], | |
config['preauthorized'] | |
) | |
user, authentication_status, username = authenticator.login('main') | |
# user, authentication_status, username = authenticator.login('用户登录', 'main') | |
print('登录的用户:', username) | |
### Streamlit Sidebar 左侧工具栏 | |
# st.sidebar.write(st.session_state) | |
if authentication_status: | |
with st.sidebar: | |
st.markdown( | |
""" | |
<style> | |
[data-testid="stSidebar"][aria-expanded="true"]{ | |
min-width: 500px; | |
max-width: 500px; | |
} | |
""", | |
unsafe_allow_html=True, | |
) | |
# st.header(f'**大语言模型专家系统工作设定区**') | |
st.header(f'**欢迎 **{username}** 使用本系统** ') | |
st.write(f'_Large Language Model Expert System Working Environment_') | |
# st.write(f'_Welcome and Hope U Enjoy Staying Here_') | |
authenticator.logout('登出', 'sidebar') | |
### siderbar的题目。 | |
### siderbar的题目。 | |
# st.header(f'**大语言模型专家系统工作设定区**') | |
# st.header(f'**欢迎 **{username}** 使用本系统** ') ## 用户登录显示。 | |
st.sidebar.button("清除记录,重启一轮新对话", on_click=clear_all, use_container_width=True, type='primary') | |
# st.sidebar.button("清除记录,重启一轮新对话", on_click=setup_session_state, use_container_width=True, type='primary') | |
### 展示当前数据库 | |
# st.markdown("#### 当前数据库中的数据:") | |
with st.expander("#### 当前数据库中的数据", expanded=True): | |
my_db = pd.read_sql_table('table01', 'sqlite:///myexcelDB.db') | |
st.dataframe(my_db, width=400) | |
## 在sidebar上的三个分页显示,用st.tabs实现。 | |
tab_1, tab_2, tab_4 = st.tabs(['使用须知', '模型参数', '角色设定']) | |
# tab_1, tab_2, tab_3, tab_4 = st.tabs(['使用须知', '模型参数', '提示词模板', '系统角色设定']) | |
# with st.expander(label='**使用须知**', expanded=False): | |
with tab_1: | |
# st.markdown("#### 快速上手指南") | |
# with st.text(body="说明"): | |
# st.markdown("* 重启一轮新对话时,只需要刷新页面(按Ctrl/Command + R)即可。") | |
with st.text(body="说明"): | |
st.markdown("""* 简介 | |
本系统使用大模型技术,将自然语言描述转换为SQL查询语句。用户可以输入自然语言问题,系统会自动生成相应的SQL语句,并返回查询结果。 | |
* 使用步骤 | |
在文本框中输入您的自然语言问题。 | |
系统会自动生成相应的SQL语句,并显示在下方。 | |
点击“满意,请执行语句”选项,可查看查询结果。点击”不满意,需要改正语句“可以手动修改SQL语句。 | |
* 注意事项 | |
本系统仍在开发中,可能会存在一些错误或不准确的地方。 | |
自然语言描述越清晰,生成的SQL语句越准确。 | |
系统支持的SQL语法有限,请尽量使用简单易懂的语法。 | |
""") | |
# with st.text(body="说明"): | |
# st.markdown("""在构建大语言模型本地知识库问答系统时,需要注意以下几点: | |
# LLM的选择:LLM的选择应根据系统的应用场景和需求进行。对于需要处理通用问题的系统,可以选择通用LLM;对于需要处理特定领域问题的系统,可以选择针对该领域进行微调的LLM。 | |
# 本地知识库的构建:本地知识库应包含系统所需的所有知识。知识的组织方式应便于LLM的访问和处理。 | |
# 系统的评估:系统应进行充分的评估,以确保其能够准确地回答用户问题。 | |
# 大语言模型本地知识库问答系统具有以下优势: | |
# 准确性:LLM和本地知识库的结合可以提高系统的准确性。 | |
# 全面性:LLM和本地知识库的结合可以使系统能够回答更广泛的问题。 | |
# 效率:LLM可以快速生成候选答案,而本地知识库可以快速评估候选答案的准确性。 | |
# """) | |
# with st.text(body="说明"): | |
# st.markdown( | |
# "* “数据分析模式”暂时只支持1000个单元格以内的数据分析,单元格中的内容不支持中文数据(表头也尽量不使用中文)。一般运行时间在1至10分钟左右,期间需要保持网络畅通。") | |
# with st.text(body="说明"): | |
# st.markdown("* “数据分析模式”推荐上传csv格式的文件,部分Excel文件容易出现数据不兼容的情况。") | |
## 大模型参数 | |
# with st.expander(label='**大语言模型参数**', expanded=True): | |
with tab_2: | |
max_tokens = st.slider(label='Max_Token(生成结果时最大字数)', min_value=100, max_value=4096, value=2048, step=100) | |
temperature = st.slider(label='Temperature (温度)', min_value=0.0, max_value=1.0, value=0.8, step=0.1) | |
top_p = st.slider(label='Top_P (核采样)', min_value=0.0, max_value=1.0, value=0.6, step=0.1) | |
frequency_penalty = st.slider(label='Frequency Penalty (重复度惩罚因子)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1) | |
presence_penalty = st.slider(label='Presence Penalty (控制主题的重复度)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1) | |
## reset password widget | |
# try: | |
# if authenticator.reset_password(st.session_state["username"], 'Reset password'): | |
# st.success('Password modified successfully') | |
# except Exception as e: | |
# st.error(e) | |
# with st.header(body="欢迎"): | |
# st.markdown("# 欢迎使用大语言模型商业智能中心") | |
# with st.expander(label=("**重要的使用注意事项**"), expanded=True): | |
# with st.container(): | |
##NOTE: 在SQL场景去不需要展示这些提示词。 | |
# with tab_3: | |
# # st.markdown("#### Prompt提示词参考资料") | |
# # with st.expander(label="**大语言模型基础提示词Prompt示例**", expanded=False): | |
# st.code( | |
# body="继续用中文写一篇关于 [文章主题] 的文章,以下列句子开头:[文章开头]。", language='plaintext') | |
# st.code(body="将以下文字概括为 100 个字,使其易于阅读和理解。避免使用复杂的句子结构或技术术语。", | |
# language='plaintext') | |
# st.code(body="给我出一个迪奥2023春季发布会活动策划。", language='plaintext') | |
# st.code(body="帮我按照正式会议结构写一个会邀:主题是xx手机游戏立项会议。", language='plaintext') | |
# st.code(body="帮我写一个车内健康监测全场景落地的项目计划,用表格。", language='plaintext') | |
# st.code( | |
# body="同时掷两枚质地均匀的骰子,则两枚骰子向上的点数之和为 7 的概率是多少。", language='plaintext') | |
# st.code(body="写一篇产品经理的演讲稿,注意使用以下词汇: 赋能,抓手,中台,闭环,落地,漏斗,沉淀,给到,同步,对齐,对标,迭代,拉通,打通,升级,交付,聚焦,倒逼,复盘,梳理,方案,联动,透传,咬合,洞察,渗透,兜底,解耦,耦合,复用,拆解。", language='plaintext') | |
# with st.expander(label="**数据分析模式的专用提示词Prompt示例**", expanded=False): | |
# # with st.subheader(body="提示词Prompt"): | |
# st.code(body="分析此数据集并绘制一些'有趣的图表'。", language='python') | |
# st.code( | |
# body="对于这个文件中的数据,你需要要找出[X,Y]数据之间的寻找'相关性'。", language='python') | |
# st.code(body="对于这个文件中的[xxx]数据给我一个'整体的分析'。", language='python') | |
# st.code(body="对于[xxx]数据给我一个'直方图',提供图表,并给出分析结果。", language='python') | |
# st.code(body="对于[xxx]数据给我一个'小提琴图',并给出分析结果。", language='python') | |
# st.code( | |
# body="对于[X,Y,Z]数据在一个'分布散点图 (stripplot)',所有的数据在一张图上展现, 并给出分析结果。", language='python') | |
# st.code(body="对于[X,Y]数据,进行'T检验',你需要展示图表,并给出分析结果。", | |
# language='python') | |
# st.code(body="对于[X,Y]数据给我一个3个类别的'聚类分析',并给出分析结果。", | |
# language='python') | |
with tab_4: | |
st.text_area(label='系统角色设定', value='你是一个人工智能,你需要回答我提出的问题,或者完成我交代的任务。你需要使用我提问的语言(如中文、英文)来回答。', height=200, label_visibility='hidden') | |
myavatar = "./2D.png" | |
def set_question(question): | |
st.session_state["my_question"] = question | |
assistant_message_suggested = st.chat_message( | |
"assistant", avatar=myavatar | |
) | |
##! 注意在填写列名时,最佳的选择是使用双引号将名字框出来。否则LLM会不认识。 | |
if assistant_message_suggested.button("点击这里可以查看参考问题示例 ℹ️"): | |
questions = [ | |
'你查一下长度超过10的数据有哪些?', | |
'你告诉我长度超过30且宽度超过10的数据有哪些?', | |
'宽度超过20且价格超过25的数据有哪些?', | |
'”长度“超过10且”宽度“超过15且”类别“是”甲“的数据有哪些?', | |
'长度超过10以上的价格的总和是多少?', | |
'长度排名前三的产品ID和价格和长度?', | |
] | |
st.session_state["my_question"] = None | |
# questions = generate_questions_cached() ## orignal. | |
for i, question in enumerate(questions): | |
time.sleep(0.05) | |
button = st.button( | |
question, | |
on_click=set_question, | |
args=(question,), | |
) | |
my_question = st.session_state.get("my_question", default=None) | |
if my_question is None: | |
my_question = st.chat_input( | |
"请在这里提交您的查询问题", | |
) | |
if my_question: | |
st.session_state["my_question"] = my_question | |
user_message = st.chat_message("user") | |
user_message.write(f"{my_question}") | |
# sql = generate_sql_cached(question=my_question) ## original. 核心函数。用ChatGPT输出SQL语句。 | |
# sql = chatsql003.main(prompt=my_question) ## 通过本地LLM输出SQL语句。 | |
sql = chatsql004.main(prompt=my_question) ## 通过本地LLM输出SQL语句。 | |
if sql: | |
if st.session_state.get("show_sql", True): | |
assistant_message_sql = st.chat_message( | |
"assistant", avatar=myavatar | |
) | |
assistant_message_sql.code(sql, language="sql", line_numbers=True) | |
user_message_sql_check = st.chat_message("user") | |
user_message_sql_check.write(f"您是否满意上面的SQL查询语句答案?") | |
with user_message_sql_check: | |
happy_sql = st.radio( | |
# "Happy", | |
label="您的反馈:", | |
options=["未决定", "满意,请执行语句", "不满意,需要改正语句"], ## 去掉第一个happy选项。 | |
# options=["", "yes", "no"], ## original。 | |
key="radio_sql", | |
index=0, | |
horizontal=True, | |
label_visibility='visible', | |
) | |
# ## 非original。后期添加内容。Working. | |
if happy_sql == "未决定": | |
df = None | |
st.session_state["df"] = None | |
if happy_sql == "不满意,需要改正语句": | |
st.warning( | |
"请您手动修正SQL语句,按Shift + Enter提交" | |
) | |
sql_response = code_editor(sql, lang="sql") | |
fixed_sql_query = sql_response["text"] | |
if fixed_sql_query != "": | |
# df = run_sql_cached(sql=fixed_sql_query) ## original. 核心函数。 | |
df = sql_command.llm_query(sql_command=fixed_sql_query) | |
else: | |
df = None | |
elif happy_sql == "满意,请执行语句": | |
# df = run_sql_cached(sql=sql) ## original. 核心函数。 | |
df = sql_command.llm_query(sql_command=sql) ## working. | |
else: | |
df = None | |
if df is not None: | |
st.session_state["df"] = df | |
if st.session_state.get("df") is not None: | |
## 显示查询的结果汇总信息 | |
# with st.container(): | |
query_info_describe = st.chat_message("assistant", avatar=myavatar) | |
if df is not None: | |
with query_info_describe: | |
metric_col1, metric_col2, metric_col3, metric_col4 = st.columns(4) | |
metric_col1.metric(label='查询结果中的行数', value=f"{df.shape[0]} 行", delta=None) | |
metric_col2.metric(label='查询结果中的列数', value=f"{df.shape[1]} 列", delta=None) | |
metric_col3.metric(label='查询结果中的单元格数', value=f"{df.size} 个", delta=None) | |
metric_col4.metric(label='查询结果中的缺失值', value=f"{df.isnull().sum().sum()} 个", delta=None) | |
if st.session_state.get("show_table", True): | |
df = st.session_state.get("df") | |
assistant_message_table = st.chat_message( | |
"assistant", | |
avatar=myavatar | |
) | |
### 目前是最佳显示方式 | |
if len(df) > 10: | |
assistant_message_table.markdown("**查询结果较多,当前只显示前10行数据**") ##TODO: 看看是否有其他更好的显示方式。 | |
assistant_message_table.dataframe(df.head(10), hide_index=True) | |
# assistant_message_table.dataframe(df.head(100), use_container_width=True) | |
else: | |
assistant_message_table.dataframe(df, hide_index=True) | |
# assistant_message_table.dataframe(df,use_container_width=True) | |
## 以下是其他datafram的显示方式,均不理想。 | |
# st.dataframe(df, use_container_width=True, height=100) | |
## streamlit AwesomeTable, Dark模式下显示有问题。可以在sidebar上显示内容。 | |
# from awesome_table import AwesomeTable | |
# AwesomeTable(df,show_order=True, show_search=True, show_search_order_in_sidebar=True) | |
# # AwesomeTable(pd.json_normalize(df)) | |
# ## 展示图形 | |
# # code = generate_plotly_code_cached(question=my_question, sql=sql, df=df) ## orignal. 用LLM来构建统计图表:https://github.com/vanna-ai/vanna/blob/main/src/vanna/mistral/mistral.py | |
# if st.session_state.get("show_plotly_code", False): | |
# assistant_message_plotly_code = st.chat_message( | |
# "assistant", | |
# avatar=myavatar, | |
# ) | |
# assistant_message_plotly_code.code( | |
# code, language="python", line_numbers=True | |
# ) | |
# user_message_plotly_check = st.chat_message("user") | |
# user_message_plotly_check.write( | |
# f"Are you happy with the generated Plotly code?" | |
# ) | |
# with user_message_plotly_check: | |
# happy_plotly = st.radio( | |
# "Happy", | |
# options=["", "yes", "no"], | |
# key="radio_plotly", | |
# index=0, | |
# ) | |
# if happy_plotly == "no": | |
# st.warning( | |
# "Please fix the generated Python code. Once you're done hit Shift + Enter to submit" | |
# ) | |
# python_code_response = code_editor(code, lang="python") | |
# code = python_code_response["text"] | |
# elif happy_plotly == "": | |
# code = None | |
# if code is not None and code != "": | |
# if st.session_state.get("show_chart", True): | |
# assistant_message_chart = st.chat_message( | |
# "assistant", | |
# avatar=myavatar, | |
# ) | |
# fig = generate_plot_cached(code=code, df=df) | |
# if fig is not None: | |
# assistant_message_chart.plotly_chart(fig) | |
# else: | |
# assistant_message_chart.error("I couldn't generate a chart") | |
## Data Visualization: Ydata_profiling (pandas_profiling) | |
# from ydata_profiling import ProfileReport | |
# pr = ProfileReport(df) | |
# # with st.expander(label='**数据汇总信息**', expanded=False): | |
# with st.container(): | |
# st_profile_report(pr) | |
# # st.stop() | |
## 数据可视化内容,高级版。 | |
## pywalker: https://github.com/Kanaries/pygwalker/tree/main | |
user_message_plot_check = st.chat_message("user") | |
user_message_plot_check.write(f"您是否希望基于上述查询结果进行数据可视化?") | |
with user_message_plot_check: | |
happy_plot = st.radio( | |
# "Happy", | |
label="您的反馈:", | |
options=["未决定", "是,需要进行数据可视化", "否,不需要进行数据可视化"], ## 去掉第一个happy选项。 | |
# options=["", "yes", "no"], ## original。 | |
key="radio_plot", | |
index=0, | |
horizontal=True, | |
label_visibility='visible', | |
) | |
if happy_plot == "是,需要进行数据可视化": | |
import pygwalker as pyg | |
import streamlit.components.v1 as components | |
import streamlit as st | |
from pygwalker.api.streamlit import init_streamlit_comm, get_streamlit_html | |
# Initialize pygwalker communication | |
init_streamlit_comm() | |
# When using `use_kernel_calc=True`, you should cache your pygwalker html, if you don't want your memory to explode | |
def get_pyg_html(df: pd.DataFrame) -> str: | |
# When you need to publish your application, you need set `debug=False`,prevent other users to write your config file. | |
# If you want to use feature of saving chart config, set `debug=True` | |
html = get_streamlit_html(df, spec="./gw0.json", use_kernel_calc=True, debug=False) | |
return html | |
# walker = pyg.walk(df) | |
components.html(get_pyg_html(df), width=1200, height=800, scrolling=True) | |
### streamlit echarts: https://github.com/andfanilo/streamlit-echarts | |
# from streamlit_echarts import st_echarts | |
# from streamlit_echarts import st_pyecharts | |
# st_pyecharts(df) | |
### Follow-up questions session. | |
# if st.session_state.get("show_followup", True): | |
# assistant_message_followup = st.chat_message( | |
# "assistant", | |
# avatar=myavatar, | |
# ) | |
# followup_questions = generate_followup_cached( | |
# question=my_question, df=df | |
# ) | |
# st.session_state["df"] = None | |
# if len(followup_questions) > 0: | |
# assistant_message_followup.text( | |
# "Here are some possible follow-up questions" | |
# ) | |
# # Print the first 5 follow-up questions | |
# for question in followup_questions[:5]: | |
# time.sleep(0.05) | |
# assistant_message_followup.write(question) | |
else: | |
assistant_message_error = st.chat_message( | |
"assistant", avatar=myavatar | |
) | |
assistant_message_error.error("我无法回答您的问题,请重新提问,谢谢!") | |