Upload 6 files
Browse files- 2D.png +0 -0
- app.py +441 -0
- chatsql004.py +157 -0
- myexcelDB.db +0 -0
- qwen_response.py +45 -0
- requirements.txt +17 -0
2D.png
ADDED
![]() |
app.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
|
3 |
+
'''
|
4 |
+
##TODO: 1. 转换成Qwen的API。 2.
|
5 |
+
|
6 |
+
import time
|
7 |
+
import os
|
8 |
+
import pandas as pd
|
9 |
+
import streamlit as st
|
10 |
+
from code_editor import code_editor
|
11 |
+
# from utils.setup import setup_connexion, setup_session_state
|
12 |
+
# from utils.vanna_calls import (
|
13 |
+
# generate_questions_cached,
|
14 |
+
# generate_sql_cached,
|
15 |
+
# run_sql_cached,
|
16 |
+
# generate_plotly_code_cached,
|
17 |
+
# generate_plot_cached,
|
18 |
+
# generate_followup_cached,
|
19 |
+
# )
|
20 |
+
# import chatsql003 ### 本地ChatGLT 版本。
|
21 |
+
import chatsql004 ###Qwen API 版本。
|
22 |
+
import sql_command
|
23 |
+
# from streamlit_pandas_profiling import st_profile_report
|
24 |
+
import dashscope
|
25 |
+
from dotenv import load_dotenv
|
26 |
+
|
27 |
+
|
28 |
+
load_dotenv()
|
29 |
+
### 设置openai的API key
|
30 |
+
dashscope.api_key = os.environ['dashscope_api_key']
|
31 |
+
|
32 |
+
st.set_page_config(layout="wide", page_icon="🧩", page_title="本地化国产大模型数据库查询演示")
|
33 |
+
# setup_connexion()
|
34 |
+
|
35 |
+
def clear_all():
|
36 |
+
st.session_state.conversation = None
|
37 |
+
st.session_state.chat_history = None
|
38 |
+
st.session_state.messages = []
|
39 |
+
message_placeholder = st.empty()
|
40 |
+
st.session_state["my_question"] = None
|
41 |
+
return None
|
42 |
+
## 原始的控制面板
|
43 |
+
# st.sidebar.title("大模型控制面板")
|
44 |
+
# st.sidebar.checkbox("Show SQL", value=True, key="show_sql")
|
45 |
+
# st.sidebar.checkbox("Show Table", value=True, key="show_table")
|
46 |
+
# st.sidebar.checkbox("Show Plotly Code", value=True, key="show_plotly_code")
|
47 |
+
# st.sidebar.checkbox("Show Chart", value=True, key="show_chart")
|
48 |
+
# st.sidebar.checkbox("Show Follow-up Questions", value=True, key="show_followup")
|
49 |
+
# st.sidebar.button("清除记录,重启一轮新对话", on_click=setup_session_state, use_container_width=True, type='primary')
|
50 |
+
|
51 |
+
st.title("本地化国产大模型数据库查询演示")
|
52 |
+
# st.title("大语言模型SQL数据库查询中心")
|
53 |
+
# st.info("声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。") ## 颜色比较明显。
|
54 |
+
st.markdown("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
|
55 |
+
|
56 |
+
|
57 |
+
### Streamlit Sidebar 左侧工具栏
|
58 |
+
# st.sidebar.write(st.session_state)
|
59 |
+
with st.sidebar:
|
60 |
+
st.markdown(
|
61 |
+
"""
|
62 |
+
<style>
|
63 |
+
[data-testid="stSidebar"][aria-expanded="true"]{
|
64 |
+
min-width: 500px;
|
65 |
+
max-width: 500px;
|
66 |
+
}
|
67 |
+
""",
|
68 |
+
unsafe_allow_html=True,
|
69 |
+
)
|
70 |
+
### siderbar的题目。
|
71 |
+
### siderbar的题目。
|
72 |
+
# st.header(f'**大语言模型专家系统工作设定区**')
|
73 |
+
st.header(f'**系统控制面板** ')
|
74 |
+
# st.header(f'**欢迎 **{username}** 使用本系统** ') ## 用户登录显示。
|
75 |
+
st.write(f'_Large Language Model Expert System Working Environment_')
|
76 |
+
st.sidebar.button("清除记录,重启一轮新对话", on_click=clear_all, use_container_width=True, type='primary')
|
77 |
+
# st.sidebar.button("清除记录,重启一轮新对话", on_click=setup_session_state, use_container_width=True, type='primary')
|
78 |
+
|
79 |
+
### 展示当前数据库
|
80 |
+
# st.markdown("#### 当前数据库中的数据:")
|
81 |
+
with st.expander("#### 当前数据库中的数据", expanded=True):
|
82 |
+
my_db = pd.read_sql_table('table01', 'sqlite:///myexcelDB.db')
|
83 |
+
st.dataframe(my_db, width=400)
|
84 |
+
|
85 |
+
## 在sidebar上的三个分页显示,用st.tabs实现。
|
86 |
+
tab_1, tab_2, tab_4 = st.tabs(['使用须知', '模型参数', '角色设定'])
|
87 |
+
# tab_1, tab_2, tab_3, tab_4 = st.tabs(['使用须知', '模型参数', '提示词模板', '系统角色设定'])
|
88 |
+
|
89 |
+
# with st.expander(label='**使用须知**', expanded=False):
|
90 |
+
with tab_1:
|
91 |
+
# st.markdown("#### 快速上手指南")
|
92 |
+
# with st.text(body="说明"):
|
93 |
+
# st.markdown("* 重启一轮新对话时,只需要刷新页面(按Ctrl/Command + R)即可。")
|
94 |
+
with st.text(body="说明"):
|
95 |
+
st.markdown("""* 简介
|
96 |
+
|
97 |
+
本系统使用大模型技术,将自然语言描述转换为SQL查询语句。用户可以输入自然语言问题,系统会自动生成相应的SQL语句,并返回查询结果。
|
98 |
+
|
99 |
+
* 使用步骤
|
100 |
+
|
101 |
+
在文本框中输入您的自然语言问题。
|
102 |
+
系统会自动生成相应的SQL语句,并显示在下方。
|
103 |
+
点击“满意,请执行语句”选项,可查看查询结果。点击”不满意,需要改正语句“可以手动修改SQL语句。
|
104 |
+
|
105 |
+
* 注意事项
|
106 |
+
|
107 |
+
本系统仍在开发中,可能会存在一些错误或不准确的地方。
|
108 |
+
自然语言描述越清晰,生成的SQL语句越准确。
|
109 |
+
系统支持的SQL语法有限,请尽量使用简单易懂的语法。
|
110 |
+
""")
|
111 |
+
# with st.text(body="说明"):
|
112 |
+
# st.markdown("""在构建大语言模型本地知识库问答系统时,需要注意以下几点:
|
113 |
+
|
114 |
+
# LLM的选择:LLM的选择应根据系统的应用场景和需求进行。对于需要处理通用问题的系统,可以选择通用LLM;对于需要处理特定领域问题的系统,可以选择针对该领域进行微调的LLM。
|
115 |
+
# 本地知识库的构建:本地知识库应包含系统所需的所有知识。知识的组织方式应便于LLM的访问和处理。
|
116 |
+
# 系统的评估:系统应进行充分的评估,以确保其能够准确地回答用户问题。
|
117 |
+
# 大语言模型本地知识库问答系统具有以下优势:
|
118 |
+
|
119 |
+
# 准确性:LLM和本地知识库的结合可以提高系统的准确性。
|
120 |
+
# 全面性:LLM和本地知识库的结合可以使系统能够回答更广泛的问题。
|
121 |
+
# 效率:LLM可以快速生成候选答案,而本地知识库可以快速评估候选答案的准确性。
|
122 |
+
# """)
|
123 |
+
|
124 |
+
|
125 |
+
# with st.text(body="说明"):
|
126 |
+
# st.markdown(
|
127 |
+
# "* “数据分析模式”暂时只支持1000个单元格以内的数据分析,单元格中的内容不支持中文数据(表头也尽量不使用中文)。一般运行时间在1至10分钟左右,期间需要保持网络畅通。")
|
128 |
+
# with st.text(body="说明"):
|
129 |
+
# st.markdown("* “数据分析模式”推荐上传csv格式的文件,部分Excel文件容易出现数据不兼容的情况。")
|
130 |
+
|
131 |
+
## 大模型参数
|
132 |
+
# with st.expander(label='**大语言模型参数**', expanded=True):
|
133 |
+
with tab_2:
|
134 |
+
max_tokens = st.slider(label='Max_Token(生成结果时最大字数)', min_value=100, max_value=4096, value=2048, step=100)
|
135 |
+
temperature = st.slider(label='Temperature (温度)', min_value=0.0, max_value=1.0, value=0.8, step=0.1)
|
136 |
+
top_p = st.slider(label='Top_P (核采样)', min_value=0.0, max_value=1.0, value=0.6, step=0.1)
|
137 |
+
frequency_penalty = st.slider(label='Frequency Penalty (重复度惩罚因子)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
|
138 |
+
presence_penalty = st.slider(label='Presence Penalty (控制主题的重复度)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
|
139 |
+
|
140 |
+
## reset password widget
|
141 |
+
# try:
|
142 |
+
# if authenticator.reset_password(st.session_state["username"], 'Reset password'):
|
143 |
+
# st.success('Password modified successfully')
|
144 |
+
# except Exception as e:
|
145 |
+
# st.error(e)
|
146 |
+
|
147 |
+
# with st.header(body="欢迎"):
|
148 |
+
# st.markdown("# 欢迎使用大语言模型商业智能中心")
|
149 |
+
# with st.expander(label=("**重要的使用注意事项**"), expanded=True):
|
150 |
+
# with st.container():
|
151 |
+
|
152 |
+
##NOTE: 在SQL场景去不需要展示这些提示词。
|
153 |
+
# with tab_3:
|
154 |
+
# # st.markdown("#### Prompt提示词参考资料")
|
155 |
+
# # with st.expander(label="**大语言模型基础提示词Prompt示例**", expanded=False):
|
156 |
+
# st.code(
|
157 |
+
# body="继续用中文写一篇关于 [文章主题] 的文章,以下列句子开头:[文章开头]。", language='plaintext')
|
158 |
+
# st.code(body="将以下文字概括为 100 个字,使其易于阅读和理解。避免使用复杂的句子结构或技术术语。",
|
159 |
+
# language='plaintext')
|
160 |
+
# st.code(body="给我出一个迪奥2023春季发布会活动策划。", language='plaintext')
|
161 |
+
# st.code(body="帮我按照正式会议结构写一个会邀:主题是xx手机游戏立项会议。", language='plaintext')
|
162 |
+
# st.code(body="帮我写一个车内健康监测全场景落地的项目计划,用表格。", language='plaintext')
|
163 |
+
# st.code(
|
164 |
+
# body="同时掷两枚质地均匀的骰子,则两枚骰子向上的点数之和为 7 的概率是多少。", language='plaintext')
|
165 |
+
# st.code(body="写一篇产品经理的演讲稿,注意使用以下词汇: 赋能,抓手,中台,闭环,落地,漏斗,沉淀,给到,同步,对齐,对标,迭代,拉通,打通,升级,交付,聚焦,倒逼,复盘,梳理,方案,联动,透传,咬合,洞察,渗透,兜底,解耦,耦合,复用,拆解。", language='plaintext')
|
166 |
+
|
167 |
+
# with st.expander(label="**数据分析模式的专用提示词Prompt示例**", expanded=False):
|
168 |
+
# # with st.subheader(body="提示词Prompt"):
|
169 |
+
# st.code(body="分析此数据集并绘制一些'有趣的图表'。", language='python')
|
170 |
+
# st.code(
|
171 |
+
# body="对于这个文件中的数据,你需要要找出[X,Y]数据之间的寻找'相关性'。", language='python')
|
172 |
+
# st.code(body="对于这个文件中的[xxx]数据给我一个'整体的分析'。", language='python')
|
173 |
+
# st.code(body="对于[xxx]数据给我一个'直方图',提供图表,并给出分析结果。", language='python')
|
174 |
+
# st.code(body="对于[xxx]数据给我一个'小提琴图',并给出分析结果。", language='python')
|
175 |
+
# st.code(
|
176 |
+
# body="对于[X,Y,Z]数据在一个'分布散点图 (stripplot)'���所有的数据在一张图上展现, 并给出分析结果。", language='python')
|
177 |
+
# st.code(body="对于[X,Y]数据,进行'T检验',你需要展示图表,并给出分析结果。",
|
178 |
+
# language='python')
|
179 |
+
# st.code(body="对于[X,Y]数据给我一个3个类别的'聚类分析',并给出分析结果。",
|
180 |
+
# language='python')
|
181 |
+
|
182 |
+
with tab_4:
|
183 |
+
st.text_area(label='系统角色设定', value='你是一个人工智能,你需要回答我提出的问题,或者完成我交代的任务。你需要使用我提问的语言(如中文、英文)来回答。', height=200, label_visibility='hidden')
|
184 |
+
|
185 |
+
|
186 |
+
myavatar = "./2D.png"
|
187 |
+
|
188 |
+
def set_question(question):
|
189 |
+
st.session_state["my_question"] = question
|
190 |
+
|
191 |
+
|
192 |
+
assistant_message_suggested = st.chat_message(
|
193 |
+
"assistant", avatar=myavatar
|
194 |
+
)
|
195 |
+
|
196 |
+
##! 注意在填写列名时,最佳的选择是使用双引号将名字框出来。否则LLM会不认识。
|
197 |
+
if assistant_message_suggested.button("点击这里可以查看参考问题示例 ℹ️"):
|
198 |
+
questions = [
|
199 |
+
'你查一下长度超过10的数据有哪些?',
|
200 |
+
'你告诉我长度超过30且宽度超过10的数据有哪些?',
|
201 |
+
'宽度超过20且价格超过25的数据有哪些?',
|
202 |
+
'”长度“超过10且”宽度“超过15且”类别“是”甲“的数据有哪些?',
|
203 |
+
'长度超过10以上的价格的总和是多少?',
|
204 |
+
'长度排名前三的产品ID和价格和长度?',
|
205 |
+
|
206 |
+
]
|
207 |
+
st.session_state["my_question"] = None
|
208 |
+
# questions = generate_questions_cached() ## orignal.
|
209 |
+
for i, question in enumerate(questions):
|
210 |
+
time.sleep(0.05)
|
211 |
+
button = st.button(
|
212 |
+
question,
|
213 |
+
on_click=set_question,
|
214 |
+
args=(question,),
|
215 |
+
)
|
216 |
+
|
217 |
+
my_question = st.session_state.get("my_question", default=None)
|
218 |
+
|
219 |
+
if my_question is None:
|
220 |
+
my_question = st.chat_input(
|
221 |
+
"请在这里提交您的查询问题",
|
222 |
+
)
|
223 |
+
|
224 |
+
if my_question:
|
225 |
+
st.session_state["my_question"] = my_question
|
226 |
+
user_message = st.chat_message("user")
|
227 |
+
user_message.write(f"{my_question}")
|
228 |
+
|
229 |
+
# sql = generate_sql_cached(question=my_question) ## original. 核心函数。用ChatGPT输出SQL语句。
|
230 |
+
# sql = chatsql003.main(prompt=my_question) ## 通过本地LLM输出SQL语句。
|
231 |
+
sql = chatsql004.main(prompt=my_question) ## 通过本地LLM输出SQL语句。
|
232 |
+
|
233 |
+
if sql:
|
234 |
+
if st.session_state.get("show_sql", True):
|
235 |
+
assistant_message_sql = st.chat_message(
|
236 |
+
"assistant", avatar=myavatar
|
237 |
+
)
|
238 |
+
assistant_message_sql.code(sql, language="sql", line_numbers=True)
|
239 |
+
|
240 |
+
user_message_sql_check = st.chat_message("user")
|
241 |
+
user_message_sql_check.write(f"您是否满意上面的SQL查询语句答案?")
|
242 |
+
with user_message_sql_check:
|
243 |
+
happy_sql = st.radio(
|
244 |
+
# "Happy",
|
245 |
+
label="您的反馈:",
|
246 |
+
options=["未决定", "满意,请执行语句", "不满意,需要改正语句"], ## 去掉第一个happy选项。
|
247 |
+
# options=["", "yes", "no"], ## original。
|
248 |
+
key="radio_sql",
|
249 |
+
index=0,
|
250 |
+
horizontal=True,
|
251 |
+
label_visibility='visible',
|
252 |
+
)
|
253 |
+
|
254 |
+
# ## 非original。后期添加内容。Working.
|
255 |
+
if happy_sql == "未决定":
|
256 |
+
df = None
|
257 |
+
st.session_state["df"] = None
|
258 |
+
|
259 |
+
if happy_sql == "不满意,需要改正语句":
|
260 |
+
st.warning(
|
261 |
+
"请您手动修正SQL语句,按Shift + Enter提交"
|
262 |
+
)
|
263 |
+
sql_response = code_editor(sql, lang="sql")
|
264 |
+
fixed_sql_query = sql_response["text"]
|
265 |
+
|
266 |
+
if fixed_sql_query != "":
|
267 |
+
# df = run_sql_cached(sql=fixed_sql_query) ## original. 核心函数。
|
268 |
+
df = sql_command.llm_query(sql_command=fixed_sql_query)
|
269 |
+
else:
|
270 |
+
df = None
|
271 |
+
|
272 |
+
elif happy_sql == "满意,请执行语句":
|
273 |
+
# df = run_sql_cached(sql=sql) ## original. 核心函数。
|
274 |
+
df = sql_command.llm_query(sql_command=sql) ## working.
|
275 |
+
|
276 |
+
else:
|
277 |
+
df = None
|
278 |
+
|
279 |
+
if df is not None:
|
280 |
+
st.session_state["df"] = df
|
281 |
+
|
282 |
+
if st.session_state.get("df") is not None:
|
283 |
+
## 显示查询的结果汇总信息
|
284 |
+
# with st.container():
|
285 |
+
query_info_describe = st.chat_message("assistant", avatar=myavatar)
|
286 |
+
if df is not None:
|
287 |
+
with query_info_describe:
|
288 |
+
metric_col1, metric_col2, metric_col3, metric_col4 = st.columns(4)
|
289 |
+
metric_col1.metric(label='查询结果中的行数', value=f"{df.shape[0]} 行", delta=None)
|
290 |
+
metric_col2.metric(label='查询结果中的列数', value=f"{df.shape[1]} 列", delta=None)
|
291 |
+
metric_col3.metric(label='查询结果中的单元格���', value=f"{df.size} 个", delta=None)
|
292 |
+
metric_col4.metric(label='查询结果中的缺失值', value=f"{df.isnull().sum().sum()} 个", delta=None)
|
293 |
+
|
294 |
+
if st.session_state.get("show_table", True):
|
295 |
+
df = st.session_state.get("df")
|
296 |
+
|
297 |
+
assistant_message_table = st.chat_message(
|
298 |
+
"assistant",
|
299 |
+
avatar=myavatar
|
300 |
+
)
|
301 |
+
|
302 |
+
### 目前是最佳显示方式
|
303 |
+
if len(df) > 10:
|
304 |
+
assistant_message_table.markdown("**查询结果较多,当前只显示前10行数据**") ##TODO: 看看是否有其他更好的显示方式。
|
305 |
+
assistant_message_table.dataframe(df.head(10), hide_index=True)
|
306 |
+
# assistant_message_table.dataframe(df.head(100), use_container_width=True)
|
307 |
+
else:
|
308 |
+
assistant_message_table.dataframe(df, hide_index=True)
|
309 |
+
# assistant_message_table.dataframe(df,use_container_width=True)
|
310 |
+
|
311 |
+
## 以下是其他datafram的显示方式,均不理想。
|
312 |
+
# st.dataframe(df, use_container_width=True, height=100)
|
313 |
+
|
314 |
+
## streamlit AwesomeTable, Dark模式下显示有问题。可以在sidebar上显示内容。
|
315 |
+
# from awesome_table import AwesomeTable
|
316 |
+
# AwesomeTable(df,show_order=True, show_search=True, show_search_order_in_sidebar=True)
|
317 |
+
# # AwesomeTable(pd.json_normalize(df))
|
318 |
+
|
319 |
+
# ## 展示图形
|
320 |
+
# # code = generate_plotly_code_cached(question=my_question, sql=sql, df=df) ## orignal. 用LLM来构建统计图表:https://github.com/vanna-ai/vanna/blob/main/src/vanna/mistral/mistral.py
|
321 |
+
|
322 |
+
# if st.session_state.get("show_plotly_code", False):
|
323 |
+
# assistant_message_plotly_code = st.chat_message(
|
324 |
+
# "assistant",
|
325 |
+
# avatar=myavatar,
|
326 |
+
# )
|
327 |
+
# assistant_message_plotly_code.code(
|
328 |
+
# code, language="python", line_numbers=True
|
329 |
+
# )
|
330 |
+
|
331 |
+
# user_message_plotly_check = st.chat_message("user")
|
332 |
+
# user_message_plotly_check.write(
|
333 |
+
# f"Are you happy with the generated Plotly code?"
|
334 |
+
# )
|
335 |
+
# with user_message_plotly_check:
|
336 |
+
# happy_plotly = st.radio(
|
337 |
+
# "Happy",
|
338 |
+
# options=["", "yes", "no"],
|
339 |
+
# key="radio_plotly",
|
340 |
+
# index=0,
|
341 |
+
# )
|
342 |
+
|
343 |
+
# if happy_plotly == "no":
|
344 |
+
# st.warning(
|
345 |
+
# "Please fix the generated Python code. Once you're done hit Shift + Enter to submit"
|
346 |
+
# )
|
347 |
+
# python_code_response = code_editor(code, lang="python")
|
348 |
+
# code = python_code_response["text"]
|
349 |
+
# elif happy_plotly == "":
|
350 |
+
# code = None
|
351 |
+
|
352 |
+
# if code is not None and code != "":
|
353 |
+
# if st.session_state.get("show_chart", True):
|
354 |
+
# assistant_message_chart = st.chat_message(
|
355 |
+
# "assistant",
|
356 |
+
# avatar=myavatar,
|
357 |
+
# )
|
358 |
+
# fig = generate_plot_cached(code=code, df=df)
|
359 |
+
# if fig is not None:
|
360 |
+
# assistant_message_chart.plotly_chart(fig)
|
361 |
+
# else:
|
362 |
+
# assistant_message_chart.error("I couldn't generate a chart")
|
363 |
+
|
364 |
+
## Data Visualization: Ydata_profiling (pandas_profiling)
|
365 |
+
# from ydata_profiling import ProfileReport
|
366 |
+
# pr = ProfileReport(df)
|
367 |
+
# # with st.expander(label='**数据汇总信息**', expanded=False):
|
368 |
+
# with st.container():
|
369 |
+
# st_profile_report(pr)
|
370 |
+
# # st.stop()
|
371 |
+
|
372 |
+
## 数据可视化内容,高级版。
|
373 |
+
## pywalker: https://github.com/Kanaries/pygwalker/tree/main
|
374 |
+
user_message_plot_check = st.chat_message("user")
|
375 |
+
user_message_plot_check.write(f"您是否希望基于上述查询结果进行数据可视化?")
|
376 |
+
with user_message_plot_check:
|
377 |
+
happy_plot = st.radio(
|
378 |
+
# "Happy",
|
379 |
+
label="您的反馈:",
|
380 |
+
options=["未决定", "是,需要进行数据可视化", "否,不需要进行数据可视化"], ## 去掉第一个happy选项。
|
381 |
+
# options=["", "yes", "no"], ## original。
|
382 |
+
key="radio_plot",
|
383 |
+
index=0,
|
384 |
+
horizontal=True,
|
385 |
+
label_visibility='visible',
|
386 |
+
)
|
387 |
+
|
388 |
+
if happy_plot == "是,需要进行数��可视化":
|
389 |
+
import pygwalker as pyg
|
390 |
+
import streamlit.components.v1 as components
|
391 |
+
import streamlit as st
|
392 |
+
from pygwalker.api.streamlit import init_streamlit_comm, get_streamlit_html
|
393 |
+
|
394 |
+
# Initialize pygwalker communication
|
395 |
+
init_streamlit_comm()
|
396 |
+
|
397 |
+
# When using `use_kernel_calc=True`, you should cache your pygwalker html, if you don't want your memory to explode
|
398 |
+
@st.cache_resource
|
399 |
+
def get_pyg_html(df: pd.DataFrame) -> str:
|
400 |
+
# When you need to publish your application, you need set `debug=False`,prevent other users to write your config file.
|
401 |
+
# If you want to use feature of saving chart config, set `debug=True`
|
402 |
+
html = get_streamlit_html(df, spec="./gw0.json", use_kernel_calc=True, debug=False)
|
403 |
+
return html
|
404 |
+
# walker = pyg.walk(df)
|
405 |
+
components.html(get_pyg_html(df), width=1200, height=800, scrolling=True)
|
406 |
+
|
407 |
+
|
408 |
+
### streamlit echarts: https://github.com/andfanilo/streamlit-echarts
|
409 |
+
# from streamlit_echarts import st_echarts
|
410 |
+
# from streamlit_echarts import st_pyecharts
|
411 |
+
|
412 |
+
# st_pyecharts(df)
|
413 |
+
|
414 |
+
|
415 |
+
|
416 |
+
|
417 |
+
### Follow-up questions session.
|
418 |
+
# if st.session_state.get("show_followup", True):
|
419 |
+
# assistant_message_followup = st.chat_message(
|
420 |
+
# "assistant",
|
421 |
+
# avatar=myavatar,
|
422 |
+
# )
|
423 |
+
# followup_questions = generate_followup_cached(
|
424 |
+
# question=my_question, df=df
|
425 |
+
# )
|
426 |
+
# st.session_state["df"] = None
|
427 |
+
|
428 |
+
# if len(followup_questions) > 0:
|
429 |
+
# assistant_message_followup.text(
|
430 |
+
# "Here are some possible follow-up questions"
|
431 |
+
# )
|
432 |
+
# # Print the first 5 follow-up questions
|
433 |
+
# for question in followup_questions[:5]:
|
434 |
+
# time.sleep(0.05)
|
435 |
+
# assistant_message_followup.write(question)
|
436 |
+
|
437 |
+
else:
|
438 |
+
assistant_message_error = st.chat_message(
|
439 |
+
"assistant", avatar=myavatar
|
440 |
+
)
|
441 |
+
assistant_message_error.error("我无法回答您的问题,请重新提问,谢谢!")
|
chatsql004.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
1. 这里用公网Qwen API来代替本地的ChatGLM模型。为了在Huggingface上演示。
|
3 |
+
1. 使用确定了列名作为SQL语句的变量名,可以有效解决模型生成的SQL语句中变量名准确的问题。
|
4 |
+
|
5 |
+
"""
|
6 |
+
##TODO:
|
7 |
+
|
8 |
+
import requests
|
9 |
+
import os
|
10 |
+
from rich import print
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
import time
|
14 |
+
import pandas as pd
|
15 |
+
import numpy as np
|
16 |
+
import sys
|
17 |
+
import time
|
18 |
+
from typing import Any
|
19 |
+
import requests
|
20 |
+
import csv
|
21 |
+
import os
|
22 |
+
from rich import print
|
23 |
+
import pandas
|
24 |
+
import io
|
25 |
+
from io import StringIO
|
26 |
+
import re
|
27 |
+
from langchain.llms.utils import enforce_stop_tokens
|
28 |
+
import json
|
29 |
+
from transformers import AutoModel, AutoTokenizer
|
30 |
+
import mdtex2html
|
31 |
+
import qwen_response
|
32 |
+
|
33 |
+
|
34 |
+
''' Start: Environment settings. '''
|
35 |
+
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/Users/yunshi/Downloads/chatGLM/My_LocalKB_Project/'
|
36 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
37 |
+
import torch
|
38 |
+
mps_device = torch.device("mps") ## 在mac机器上需要加上这句。必须要有这句,否则会报错。
|
39 |
+
|
40 |
+
### 在langchain中定义chatGLM作为LLM。
|
41 |
+
from typing import Any, List, Mapping, Optional
|
42 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
43 |
+
from langchain.llms.base import LLM
|
44 |
+
from transformers import AutoTokenizer, AutoModel
|
45 |
+
# llm_filepath = str("/Users/yunshi/Downloads/chatGLM/ChatGLM3-6B/6B") ## 第三代chatGLM 6B W/ code-interpreter
|
46 |
+
|
47 |
+
|
48 |
+
# ## API模式启动ChatGLM
|
49 |
+
# ## 配置ChatGLM的类与后端api server对应。
|
50 |
+
# class ChatGLM(LLM):
|
51 |
+
# max_token: int = 2048
|
52 |
+
# temperature: float = 0.1
|
53 |
+
# top_p = 0.9
|
54 |
+
# history = []
|
55 |
+
|
56 |
+
# def __init__(self):
|
57 |
+
# super().__init__()
|
58 |
+
|
59 |
+
# @property
|
60 |
+
# def _llm_type(self) -> str:
|
61 |
+
# return "ChatGLM"
|
62 |
+
|
63 |
+
# def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
64 |
+
# # headers中添加上content-type这个参数,指定为json格式
|
65 |
+
# headers = {'Content-Type': 'application/json'}
|
66 |
+
# data=json.dumps({
|
67 |
+
# 'prompt':prompt,
|
68 |
+
# 'temperature':self.temperature,
|
69 |
+
# 'history':self.history,
|
70 |
+
# 'max_length':self.max_token
|
71 |
+
# })
|
72 |
+
# print("ChatGLM prompt:",prompt)
|
73 |
+
# # 调用api
|
74 |
+
# # response = requests.post("http://0.0.0.0:8000",headers=headers,data=data) ##working。
|
75 |
+
# response = requests.post("http://127.0.0.1:8000",headers=headers,data=data) ##working。
|
76 |
+
# print("ChatGLM resp:", response)
|
77 |
+
|
78 |
+
# if response.status_code!=200:
|
79 |
+
# return "查询结果错误"
|
80 |
+
# resp = response.json()
|
81 |
+
# if stop is not None:
|
82 |
+
# response = enforce_stop_tokens(response, stop)
|
83 |
+
# self.history = self.history+[[None, resp['response']]] ##original
|
84 |
+
# return resp['response'] ##original.
|
85 |
+
|
86 |
+
# llm = ChatGLM() ## 启动一个实例。orignal working。
|
87 |
+
# import asyncio
|
88 |
+
# llm = ChatGLM() ## 启动一个实例。
|
89 |
+
|
90 |
+
|
91 |
+
''' End: Environment settings. '''
|
92 |
+
|
93 |
+
### 我会用中文或者英文双引号(即:“ ”," ")来告知你变量的名称。 长度","宽度","价格","产品ID","比率","类别","*"
|
94 |
+
|
95 |
+
### 用ChatGLM构建一个只返回SQL语句的模型。
|
96 |
+
def main(prompt):
|
97 |
+
full_reponse = []
|
98 |
+
sys_prompt = """
|
99 |
+
1. 你是一个将文字转换成SQL语句的人工智能。
|
100 |
+
2. 你需要注意:你只需要用纯文本回复代码的内容,即你不允许回复代码以外的任何信息。
|
101 |
+
3. SQL变量默认是中文,而且只能从如下的名称列表中选择,你不可以使用这些名字以外的变量名:"长度","宽度","价格","产品ID","比率","类别","*"
|
102 |
+
4. 你不能写IF, THEN的SQL语句,需要使用CASE。
|
103 |
+
5. 我需要你转换的文字如下:"""
|
104 |
+
|
105 |
+
total_prompt = sys_prompt + "在数据表格table01中," + prompt
|
106 |
+
|
107 |
+
print('total prompt now:',total_prompt)
|
108 |
+
# for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(prompt)): ## 这里保留了所有的chat history在input_prompt中。
|
109 |
+
# for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(prompt)): ## 这里保留了所有的chat history在input_prompt中。
|
110 |
+
# for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(total_prompt)): ## 这里保留了所有的chat history在input_prompt中。
|
111 |
+
|
112 |
+
# # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(input_prompt[-1][0])): ## 从用langchain的自定义方式来做。
|
113 |
+
# # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(input_prompt[-1][0]), history=input_prompt, max_length=max_tokens, top_p=top_p, temperature=temperature): ## 从用langchain的自定义方式来做。
|
114 |
+
# if response != "<br>":
|
115 |
+
# # print('response of model:', response)
|
116 |
+
# # input_prompt[-1][1] = response ## working.
|
117 |
+
# # input_prompt[-1][1] = response
|
118 |
+
# # yield input_prompt
|
119 |
+
# full_reponse.append(response)
|
120 |
+
|
121 |
+
# ## 得到一个非stream格式的答复。非API模式。
|
122 |
+
# response, history = chatglm.model.chat(chatglm.tokenizer, query=str(total_prompt), temperature=0.1) ## 这里保留了所有的chat history在input_prompt中。
|
123 |
+
|
124 |
+
###TODO:API模式,需要先启动API服务器。
|
125 |
+
# llm = ChatGLM() ##!! 重要说明:每次都需要实例化一次!!!否则会报错content error。实际上是应该在每次函数调用的时候都要实例化一次!
|
126 |
+
# response = llm(total_prompt) ## 这里是本地的ChatGLM来作为大模型输出基座。
|
127 |
+
|
128 |
+
|
129 |
+
response = qwen_response.call_with_messages(total_prompt)
|
130 |
+
|
131 |
+
print('response of model:', response)
|
132 |
+
|
133 |
+
## 用regex来提取纯SQL语句。需要构建多个正则式pattern
|
134 |
+
pattern_1 = r"(?:`sql\n|\n`)"
|
135 |
+
pattern_2 = r"(?:```|``)"
|
136 |
+
pattern_3 = r"(?s)(.*?SQL语句示例.*?:).*?\n"
|
137 |
+
pattern_4 = r"(?:`{3}|`{2}|`)"
|
138 |
+
# pattern_5 = r"[\u4e00-\u9FFF]" ## 匹配中文。
|
139 |
+
# pattern_6 = r"^[\u4e00-\u9fa5]{5,}" ## 首行中包含5个中文汉字的。
|
140 |
+
pattern_7 = r"^.{0,2}([\u4e00-\u9fa5]{5,}).*" ## 首行中包含5个中文汉字的。
|
141 |
+
pattern_8 = r'^"|"$' ## 去除一句话开始或者末尾的英文双引号
|
142 |
+
pattern_list = [pattern_1, pattern_2, pattern_3, pattern_4, pattern_7, pattern_8]
|
143 |
+
|
144 |
+
## 遍历所有的pattern,逐个去除。
|
145 |
+
full_reponse = response
|
146 |
+
for p in pattern_list:
|
147 |
+
full_reponse = re.sub(p, "", full_reponse)
|
148 |
+
# final_response = re.sub(pattern_1, "", response) ## 逐步匹配。
|
149 |
+
# final_response = re.sub(pattern_1, "", response) ## 逐步匹配。
|
150 |
+
# final_response = re.sub(pattern_2, "", final_response) ## 逐步匹配。
|
151 |
+
|
152 |
+
return full_reponse
|
153 |
+
|
154 |
+
# prompt = "你给我一段复杂的SQL语句示例。"
|
155 |
+
# prompt = "你给我一段SQL语句,用来完成如下工作:查询年龄大于30岁,男性,收入超过2万元的员工。"
|
156 |
+
# res = main(prompt=prompt)
|
157 |
+
# print(res)
|
myexcelDB.db
ADDED
Binary file (16.4 kB). View file
|
|
qwen_response.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from http import HTTPStatus
|
3 |
+
import dashscope
|
4 |
+
|
5 |
+
### 参考:
|
6 |
+
## export DASHSCOPE_API_KEY="sk-948adb3e65414e55961a9ad9d22d186b"
|
7 |
+
dashscope.api_key = "sk-948adb3e65414e55961a9ad9d22d186b"
|
8 |
+
|
9 |
+
# qwen_sys_prompt = """
|
10 |
+
# 1. 你是一个将文字转换成SQL语句的人工智能。
|
11 |
+
# 2. 此外你需要注意:你只需要用纯文本回复代码的内容,即你不允许回复代码以外的任何信息。
|
12 |
+
# 3. 你不能将将字段名翻译成英文,而是必须使用如下的名词:长度,宽度,价格,产品ID,比率,类别,*
|
13 |
+
# 4. 你不能写IF, THEN的SQL语句,需要使用CASE。
|
14 |
+
# """
|
15 |
+
|
16 |
+
qwen_sys_prompt = """你是一个将文字转换成SQL语句的人工智能。"""
|
17 |
+
|
18 |
+
def call_with_messages(prompt):
|
19 |
+
messages = [{'role': 'system', 'content': qwen_sys_prompt},
|
20 |
+
{'role': 'user', 'content': prompt}]
|
21 |
+
# {'role': 'user', 'content': '如何做西红柿炒鸡蛋?'}]
|
22 |
+
response = dashscope.Generation.call(
|
23 |
+
"qwen-turbo",
|
24 |
+
messages=messages,
|
25 |
+
# set the random seed, optional, default to 1234 if not set
|
26 |
+
seed=random.randint(1, 10000),
|
27 |
+
# set the result to be "message" format.
|
28 |
+
result_format='message',
|
29 |
+
)
|
30 |
+
if response.status_code == HTTPStatus.OK:
|
31 |
+
print(response)
|
32 |
+
|
33 |
+
else:
|
34 |
+
print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
|
35 |
+
response.request_id, response.status_code,
|
36 |
+
response.code, response.message
|
37 |
+
))
|
38 |
+
|
39 |
+
return response['output']['choices'][0]['message']['content'] ### 这里是content的内容,不是message的全部内容。
|
40 |
+
|
41 |
+
|
42 |
+
# if __name__ == '__main__':
|
43 |
+
# # call_with_messages() ### original code here.
|
44 |
+
# res = call_with_messages() ## working.
|
45 |
+
# # print(res)
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dashscope==1.17.0
|
2 |
+
fastapi==0.111.0
|
3 |
+
langchain==0.1.17
|
4 |
+
mdtex2html==1.2.0
|
5 |
+
nest_asyncio==1.5.8
|
6 |
+
numpy==1.26.4
|
7 |
+
pandas==2.2.2
|
8 |
+
pygwalker==0.4.8.1
|
9 |
+
pyngrok==7.0.5
|
10 |
+
python-dotenv==1.0.1
|
11 |
+
Requests==2.31.0
|
12 |
+
rich==13.7.1
|
13 |
+
streamlit==1.33.0
|
14 |
+
streamlit_code_editor==0.1.10
|
15 |
+
torch==2.2.0
|
16 |
+
transformers==4.37.1
|
17 |
+
uvicorn==0.29.0
|