allinaigc commited on
Commit
5255efb
·
verified ·
1 Parent(s): e8abf91

Upload 6 files

Browse files
Files changed (6) hide show
  1. 2D.png +0 -0
  2. app.py +441 -0
  3. chatsql004.py +157 -0
  4. myexcelDB.db +0 -0
  5. qwen_response.py +45 -0
  6. requirements.txt +17 -0
2D.png ADDED
app.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+
3
+ '''
4
+ ##TODO: 1. 转换成Qwen的API。 2.
5
+
6
+ import time
7
+ import os
8
+ import pandas as pd
9
+ import streamlit as st
10
+ from code_editor import code_editor
11
+ # from utils.setup import setup_connexion, setup_session_state
12
+ # from utils.vanna_calls import (
13
+ # generate_questions_cached,
14
+ # generate_sql_cached,
15
+ # run_sql_cached,
16
+ # generate_plotly_code_cached,
17
+ # generate_plot_cached,
18
+ # generate_followup_cached,
19
+ # )
20
+ # import chatsql003 ### 本地ChatGLT 版本。
21
+ import chatsql004 ###Qwen API 版本。
22
+ import sql_command
23
+ # from streamlit_pandas_profiling import st_profile_report
24
+ import dashscope
25
+ from dotenv import load_dotenv
26
+
27
+
28
+ load_dotenv()
29
+ ### 设置openai的API key
30
+ dashscope.api_key = os.environ['dashscope_api_key']
31
+
32
+ st.set_page_config(layout="wide", page_icon="🧩", page_title="本地化国产大模型数据库查询演示")
33
+ # setup_connexion()
34
+
35
+ def clear_all():
36
+ st.session_state.conversation = None
37
+ st.session_state.chat_history = None
38
+ st.session_state.messages = []
39
+ message_placeholder = st.empty()
40
+ st.session_state["my_question"] = None
41
+ return None
42
+ ## 原始的控制面板
43
+ # st.sidebar.title("大模型控制面板")
44
+ # st.sidebar.checkbox("Show SQL", value=True, key="show_sql")
45
+ # st.sidebar.checkbox("Show Table", value=True, key="show_table")
46
+ # st.sidebar.checkbox("Show Plotly Code", value=True, key="show_plotly_code")
47
+ # st.sidebar.checkbox("Show Chart", value=True, key="show_chart")
48
+ # st.sidebar.checkbox("Show Follow-up Questions", value=True, key="show_followup")
49
+ # st.sidebar.button("清除记录,重启一轮新对话", on_click=setup_session_state, use_container_width=True, type='primary')
50
+
51
+ st.title("本地化国产大模型数据库查询演示")
52
+ # st.title("大语言模型SQL数据库查询中心")
53
+ # st.info("声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。") ## 颜色比较明显。
54
+ st.markdown("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
55
+
56
+
57
+ ### Streamlit Sidebar 左侧工具栏
58
+ # st.sidebar.write(st.session_state)
59
+ with st.sidebar:
60
+ st.markdown(
61
+ """
62
+ <style>
63
+ [data-testid="stSidebar"][aria-expanded="true"]{
64
+ min-width: 500px;
65
+ max-width: 500px;
66
+ }
67
+ """,
68
+ unsafe_allow_html=True,
69
+ )
70
+ ### siderbar的题目。
71
+ ### siderbar的题目。
72
+ # st.header(f'**大语言模型专家系统工作设定区**')
73
+ st.header(f'**系统控制面板** ')
74
+ # st.header(f'**欢迎 **{username}** 使用本系统** ') ## 用户登录显示。
75
+ st.write(f'_Large Language Model Expert System Working Environment_')
76
+ st.sidebar.button("清除记录,重启一轮新对话", on_click=clear_all, use_container_width=True, type='primary')
77
+ # st.sidebar.button("清除记录,重启一轮新对话", on_click=setup_session_state, use_container_width=True, type='primary')
78
+
79
+ ### 展示当前数据库
80
+ # st.markdown("#### 当前数据库中的数据:")
81
+ with st.expander("#### 当前数据库中的数据", expanded=True):
82
+ my_db = pd.read_sql_table('table01', 'sqlite:///myexcelDB.db')
83
+ st.dataframe(my_db, width=400)
84
+
85
+ ## 在sidebar上的三个分页显示,用st.tabs实现。
86
+ tab_1, tab_2, tab_4 = st.tabs(['使用须知', '模型参数', '角色设定'])
87
+ # tab_1, tab_2, tab_3, tab_4 = st.tabs(['使用须知', '模型参数', '提示词模板', '系统角色设定'])
88
+
89
+ # with st.expander(label='**使用须知**', expanded=False):
90
+ with tab_1:
91
+ # st.markdown("#### 快速上手指南")
92
+ # with st.text(body="说明"):
93
+ # st.markdown("* 重启一轮新对话时,只需要刷新页面(按Ctrl/Command + R)即可。")
94
+ with st.text(body="说明"):
95
+ st.markdown("""* 简介
96
+
97
+ 本系统使用大模型技术,将自然语言描述转换为SQL查询语句。用户可以输入自然语言问题,系统会自动生成相应的SQL语句,并返回查询结果。
98
+
99
+ * 使用步骤
100
+
101
+ 在文本框中输入您的自然语言问题。
102
+ 系统会自动生成相应的SQL语句,并显示在下方。
103
+ 点击“满意,请执行语句”选项,可查看查询结果。点击”不满意,需要改正语句“可以手动修改SQL语句。
104
+
105
+ * 注意事项
106
+
107
+ 本系统仍在开发中,可能会存在一些错误或不准确的地方。
108
+ 自然语言描述越清晰,生成的SQL语句越准确。
109
+ 系统支持的SQL语法有限,请尽量使用简单易懂的语法。
110
+ """)
111
+ # with st.text(body="说明"):
112
+ # st.markdown("""在构建大语言模型本地知识库问答系统时,需要注意以下几点:
113
+
114
+ # LLM的选择:LLM的选择应根据系统的应用场景和需求进行。对于需要处理通用问题的系统,可以选择通用LLM;对于需要处理特定领域问题的系统,可以选择针对该领域进行微调的LLM。
115
+ # 本地知识库的构建:本地知识库应包含系统所需的所有知识。知识的组织方式应便于LLM的访问和处理。
116
+ # 系统的评估:系统应进行充分的评估,以确保其能够准确地回答用户问题。
117
+ # 大语言模型本地知识库问答系统具有以下优势:
118
+
119
+ # 准确性:LLM和本地知识库的结合可以提高系统的准确性。
120
+ # 全面性:LLM和本地知识库的结合可以使系统能够回答更广泛的问题。
121
+ # 效率:LLM可以快速生成候选答案,而本地知识库可以快速评估候选答案的准确性。
122
+ # """)
123
+
124
+
125
+ # with st.text(body="说明"):
126
+ # st.markdown(
127
+ # "* “数据分析模式”暂时只支持1000个单元格以内的数据分析,单元格中的内容不支持中文数据(表头也尽量不使用中文)。一般运行时间在1至10分钟左右,期间需要保持网络畅通。")
128
+ # with st.text(body="说明"):
129
+ # st.markdown("* “数据分析模式”推荐上传csv格式的文件,部分Excel文件容易出现数据不兼容的情况。")
130
+
131
+ ## 大模型参数
132
+ # with st.expander(label='**大语言模型参数**', expanded=True):
133
+ with tab_2:
134
+ max_tokens = st.slider(label='Max_Token(生成结果时最大字数)', min_value=100, max_value=4096, value=2048, step=100)
135
+ temperature = st.slider(label='Temperature (温度)', min_value=0.0, max_value=1.0, value=0.8, step=0.1)
136
+ top_p = st.slider(label='Top_P (核采样)', min_value=0.0, max_value=1.0, value=0.6, step=0.1)
137
+ frequency_penalty = st.slider(label='Frequency Penalty (重复度惩罚因子)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
138
+ presence_penalty = st.slider(label='Presence Penalty (控制主题的重复度)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
139
+
140
+ ## reset password widget
141
+ # try:
142
+ # if authenticator.reset_password(st.session_state["username"], 'Reset password'):
143
+ # st.success('Password modified successfully')
144
+ # except Exception as e:
145
+ # st.error(e)
146
+
147
+ # with st.header(body="欢迎"):
148
+ # st.markdown("# 欢迎使用大语言模型商业智能中心")
149
+ # with st.expander(label=("**重要的使用注意事项**"), expanded=True):
150
+ # with st.container():
151
+
152
+ ##NOTE: 在SQL场景去不需要展示这些提示词。
153
+ # with tab_3:
154
+ # # st.markdown("#### Prompt提示词参考资料")
155
+ # # with st.expander(label="**大语言模型基础提示词Prompt示例**", expanded=False):
156
+ # st.code(
157
+ # body="继续用中文写一篇关于 [文章主题] 的文章,以下列句子开头:[文章开头]。", language='plaintext')
158
+ # st.code(body="将以下文字概括为 100 个字,使其易于阅读和理解。避免使用复杂的句子结构或技术术语。",
159
+ # language='plaintext')
160
+ # st.code(body="给我出一个迪奥2023春季发布会活动策划。", language='plaintext')
161
+ # st.code(body="帮我按照正式会议结构写一个会邀:主题是xx手机游戏立项会议。", language='plaintext')
162
+ # st.code(body="帮我写一个车内健康监测全场景落地的项目计划,用表格。", language='plaintext')
163
+ # st.code(
164
+ # body="同时掷两枚质地均匀的骰子,则两枚骰子向上的点数之和为 7 的概率是多少。", language='plaintext')
165
+ # st.code(body="写一篇产品经理的演讲稿,注意使用以下词汇: 赋能,抓手,中台,闭环,落地,漏斗,沉淀,给到,同步,对齐,对标,迭代,拉通,打通,升级,交付,聚焦,倒逼,复盘,梳理,方案,联动,透传,咬合,洞察,渗透,兜底,解耦,耦合,复用,拆解。", language='plaintext')
166
+
167
+ # with st.expander(label="**数据分析模式的专用提示词Prompt示例**", expanded=False):
168
+ # # with st.subheader(body="提示词Prompt"):
169
+ # st.code(body="分析此数据集并绘制一些'有趣的图表'。", language='python')
170
+ # st.code(
171
+ # body="对于这个文件中的数据,你需要要找出[X,Y]数据之间的寻找'相关性'。", language='python')
172
+ # st.code(body="对于这个文件中的[xxx]数据给我一个'整体的分析'。", language='python')
173
+ # st.code(body="对于[xxx]数据给我一个'直方图',提供图表,并给出分析结果。", language='python')
174
+ # st.code(body="对于[xxx]数据给我一个'小提琴图',并给出分析结果。", language='python')
175
+ # st.code(
176
+ # body="对于[X,Y,Z]数据在一个'分布散点图 (stripplot)'���所有的数据在一张图上展现, 并给出分析结果。", language='python')
177
+ # st.code(body="对于[X,Y]数据,进行'T检验',你需要展示图表,并给出分析结果。",
178
+ # language='python')
179
+ # st.code(body="对于[X,Y]数据给我一个3个类别的'聚类分析',并给出分析结果。",
180
+ # language='python')
181
+
182
+ with tab_4:
183
+ st.text_area(label='系统角色设定', value='你是一个人工智能,你需要回答我提出的问题,或者完成我交代的任务。你需要使用我提问的语言(如中文、英文)来回答。', height=200, label_visibility='hidden')
184
+
185
+
186
+ myavatar = "./2D.png"
187
+
188
+ def set_question(question):
189
+ st.session_state["my_question"] = question
190
+
191
+
192
+ assistant_message_suggested = st.chat_message(
193
+ "assistant", avatar=myavatar
194
+ )
195
+
196
+ ##! 注意在填写列名时,最佳的选择是使用双引号将名字框出来。否则LLM会不认识。
197
+ if assistant_message_suggested.button("点击这里可以查看参考问题示例 ℹ️"):
198
+ questions = [
199
+ '你查一下长度超过10的数据有哪些?',
200
+ '你告诉我长度超过30且宽度超过10的数据有哪些?',
201
+ '宽度超过20且价格超过25的数据有哪些?',
202
+ '”长度“超过10且”宽度“超过15且”类别“是”甲“的数据有哪些?',
203
+ '长度超过10以上的价格的总和是多少?',
204
+ '长度排名前三的产品ID和价格和长度?',
205
+
206
+ ]
207
+ st.session_state["my_question"] = None
208
+ # questions = generate_questions_cached() ## orignal.
209
+ for i, question in enumerate(questions):
210
+ time.sleep(0.05)
211
+ button = st.button(
212
+ question,
213
+ on_click=set_question,
214
+ args=(question,),
215
+ )
216
+
217
+ my_question = st.session_state.get("my_question", default=None)
218
+
219
+ if my_question is None:
220
+ my_question = st.chat_input(
221
+ "请在这里提交您的查询问题",
222
+ )
223
+
224
+ if my_question:
225
+ st.session_state["my_question"] = my_question
226
+ user_message = st.chat_message("user")
227
+ user_message.write(f"{my_question}")
228
+
229
+ # sql = generate_sql_cached(question=my_question) ## original. 核心函数。用ChatGPT输出SQL语句。
230
+ # sql = chatsql003.main(prompt=my_question) ## 通过本地LLM输出SQL语句。
231
+ sql = chatsql004.main(prompt=my_question) ## 通过本地LLM输出SQL语句。
232
+
233
+ if sql:
234
+ if st.session_state.get("show_sql", True):
235
+ assistant_message_sql = st.chat_message(
236
+ "assistant", avatar=myavatar
237
+ )
238
+ assistant_message_sql.code(sql, language="sql", line_numbers=True)
239
+
240
+ user_message_sql_check = st.chat_message("user")
241
+ user_message_sql_check.write(f"您是否满意上面的SQL查询语句答案?")
242
+ with user_message_sql_check:
243
+ happy_sql = st.radio(
244
+ # "Happy",
245
+ label="您的反馈:",
246
+ options=["未决定", "满意,请执行语句", "不满意,需要改正语句"], ## 去掉第一个happy选项。
247
+ # options=["", "yes", "no"], ## original。
248
+ key="radio_sql",
249
+ index=0,
250
+ horizontal=True,
251
+ label_visibility='visible',
252
+ )
253
+
254
+ # ## 非original。后期添加内容。Working.
255
+ if happy_sql == "未决定":
256
+ df = None
257
+ st.session_state["df"] = None
258
+
259
+ if happy_sql == "不满意,需要改正语句":
260
+ st.warning(
261
+ "请您手动修正SQL语句,按Shift + Enter提交"
262
+ )
263
+ sql_response = code_editor(sql, lang="sql")
264
+ fixed_sql_query = sql_response["text"]
265
+
266
+ if fixed_sql_query != "":
267
+ # df = run_sql_cached(sql=fixed_sql_query) ## original. 核心函数。
268
+ df = sql_command.llm_query(sql_command=fixed_sql_query)
269
+ else:
270
+ df = None
271
+
272
+ elif happy_sql == "满意,请执行语句":
273
+ # df = run_sql_cached(sql=sql) ## original. 核心函数。
274
+ df = sql_command.llm_query(sql_command=sql) ## working.
275
+
276
+ else:
277
+ df = None
278
+
279
+ if df is not None:
280
+ st.session_state["df"] = df
281
+
282
+ if st.session_state.get("df") is not None:
283
+ ## 显示查询的结果汇总信息
284
+ # with st.container():
285
+ query_info_describe = st.chat_message("assistant", avatar=myavatar)
286
+ if df is not None:
287
+ with query_info_describe:
288
+ metric_col1, metric_col2, metric_col3, metric_col4 = st.columns(4)
289
+ metric_col1.metric(label='查询结果中的行数', value=f"{df.shape[0]} 行", delta=None)
290
+ metric_col2.metric(label='查询结果中的列数', value=f"{df.shape[1]} 列", delta=None)
291
+ metric_col3.metric(label='查询结果中的单元格���', value=f"{df.size} 个", delta=None)
292
+ metric_col4.metric(label='查询结果中的缺失值', value=f"{df.isnull().sum().sum()} 个", delta=None)
293
+
294
+ if st.session_state.get("show_table", True):
295
+ df = st.session_state.get("df")
296
+
297
+ assistant_message_table = st.chat_message(
298
+ "assistant",
299
+ avatar=myavatar
300
+ )
301
+
302
+ ### 目前是最佳显示方式
303
+ if len(df) > 10:
304
+ assistant_message_table.markdown("**查询结果较多,当前只显示前10行数据**") ##TODO: 看看是否有其他更好的显示方式。
305
+ assistant_message_table.dataframe(df.head(10), hide_index=True)
306
+ # assistant_message_table.dataframe(df.head(100), use_container_width=True)
307
+ else:
308
+ assistant_message_table.dataframe(df, hide_index=True)
309
+ # assistant_message_table.dataframe(df,use_container_width=True)
310
+
311
+ ## 以下是其他datafram的显示方式,均不理想。
312
+ # st.dataframe(df, use_container_width=True, height=100)
313
+
314
+ ## streamlit AwesomeTable, Dark模式下显示有问题。可以在sidebar上显示内容。
315
+ # from awesome_table import AwesomeTable
316
+ # AwesomeTable(df,show_order=True, show_search=True, show_search_order_in_sidebar=True)
317
+ # # AwesomeTable(pd.json_normalize(df))
318
+
319
+ # ## 展示图形
320
+ # # code = generate_plotly_code_cached(question=my_question, sql=sql, df=df) ## orignal. 用LLM来构建统计图表:https://github.com/vanna-ai/vanna/blob/main/src/vanna/mistral/mistral.py
321
+
322
+ # if st.session_state.get("show_plotly_code", False):
323
+ # assistant_message_plotly_code = st.chat_message(
324
+ # "assistant",
325
+ # avatar=myavatar,
326
+ # )
327
+ # assistant_message_plotly_code.code(
328
+ # code, language="python", line_numbers=True
329
+ # )
330
+
331
+ # user_message_plotly_check = st.chat_message("user")
332
+ # user_message_plotly_check.write(
333
+ # f"Are you happy with the generated Plotly code?"
334
+ # )
335
+ # with user_message_plotly_check:
336
+ # happy_plotly = st.radio(
337
+ # "Happy",
338
+ # options=["", "yes", "no"],
339
+ # key="radio_plotly",
340
+ # index=0,
341
+ # )
342
+
343
+ # if happy_plotly == "no":
344
+ # st.warning(
345
+ # "Please fix the generated Python code. Once you're done hit Shift + Enter to submit"
346
+ # )
347
+ # python_code_response = code_editor(code, lang="python")
348
+ # code = python_code_response["text"]
349
+ # elif happy_plotly == "":
350
+ # code = None
351
+
352
+ # if code is not None and code != "":
353
+ # if st.session_state.get("show_chart", True):
354
+ # assistant_message_chart = st.chat_message(
355
+ # "assistant",
356
+ # avatar=myavatar,
357
+ # )
358
+ # fig = generate_plot_cached(code=code, df=df)
359
+ # if fig is not None:
360
+ # assistant_message_chart.plotly_chart(fig)
361
+ # else:
362
+ # assistant_message_chart.error("I couldn't generate a chart")
363
+
364
+ ## Data Visualization: Ydata_profiling (pandas_profiling)
365
+ # from ydata_profiling import ProfileReport
366
+ # pr = ProfileReport(df)
367
+ # # with st.expander(label='**数据汇总信息**', expanded=False):
368
+ # with st.container():
369
+ # st_profile_report(pr)
370
+ # # st.stop()
371
+
372
+ ## 数据可视化内容,高级版。
373
+ ## pywalker: https://github.com/Kanaries/pygwalker/tree/main
374
+ user_message_plot_check = st.chat_message("user")
375
+ user_message_plot_check.write(f"您是否希望基于上述查询结果进行数据可视化?")
376
+ with user_message_plot_check:
377
+ happy_plot = st.radio(
378
+ # "Happy",
379
+ label="您的反馈:",
380
+ options=["未决定", "是,需要进行数据可视化", "否,不需要进行数据可视化"], ## 去掉第一个happy选项。
381
+ # options=["", "yes", "no"], ## original。
382
+ key="radio_plot",
383
+ index=0,
384
+ horizontal=True,
385
+ label_visibility='visible',
386
+ )
387
+
388
+ if happy_plot == "是,需要进行数��可视化":
389
+ import pygwalker as pyg
390
+ import streamlit.components.v1 as components
391
+ import streamlit as st
392
+ from pygwalker.api.streamlit import init_streamlit_comm, get_streamlit_html
393
+
394
+ # Initialize pygwalker communication
395
+ init_streamlit_comm()
396
+
397
+ # When using `use_kernel_calc=True`, you should cache your pygwalker html, if you don't want your memory to explode
398
+ @st.cache_resource
399
+ def get_pyg_html(df: pd.DataFrame) -> str:
400
+ # When you need to publish your application, you need set `debug=False`,prevent other users to write your config file.
401
+ # If you want to use feature of saving chart config, set `debug=True`
402
+ html = get_streamlit_html(df, spec="./gw0.json", use_kernel_calc=True, debug=False)
403
+ return html
404
+ # walker = pyg.walk(df)
405
+ components.html(get_pyg_html(df), width=1200, height=800, scrolling=True)
406
+
407
+
408
+ ### streamlit echarts: https://github.com/andfanilo/streamlit-echarts
409
+ # from streamlit_echarts import st_echarts
410
+ # from streamlit_echarts import st_pyecharts
411
+
412
+ # st_pyecharts(df)
413
+
414
+
415
+
416
+
417
+ ### Follow-up questions session.
418
+ # if st.session_state.get("show_followup", True):
419
+ # assistant_message_followup = st.chat_message(
420
+ # "assistant",
421
+ # avatar=myavatar,
422
+ # )
423
+ # followup_questions = generate_followup_cached(
424
+ # question=my_question, df=df
425
+ # )
426
+ # st.session_state["df"] = None
427
+
428
+ # if len(followup_questions) > 0:
429
+ # assistant_message_followup.text(
430
+ # "Here are some possible follow-up questions"
431
+ # )
432
+ # # Print the first 5 follow-up questions
433
+ # for question in followup_questions[:5]:
434
+ # time.sleep(0.05)
435
+ # assistant_message_followup.write(question)
436
+
437
+ else:
438
+ assistant_message_error = st.chat_message(
439
+ "assistant", avatar=myavatar
440
+ )
441
+ assistant_message_error.error("我无法回答您的问题,请重新提问,谢谢!")
chatsql004.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 1. 这里用公网Qwen API来代替本地的ChatGLM模型。为了在Huggingface上演示。
3
+ 1. 使用确定了列名作为SQL语句的变量名,可以有效解决模型生成的SQL语句中变量名准确的问题。
4
+
5
+ """
6
+ ##TODO:
7
+
8
+ import requests
9
+ import os
10
+ from rich import print
11
+ import os
12
+ import sys
13
+ import time
14
+ import pandas as pd
15
+ import numpy as np
16
+ import sys
17
+ import time
18
+ from typing import Any
19
+ import requests
20
+ import csv
21
+ import os
22
+ from rich import print
23
+ import pandas
24
+ import io
25
+ from io import StringIO
26
+ import re
27
+ from langchain.llms.utils import enforce_stop_tokens
28
+ import json
29
+ from transformers import AutoModel, AutoTokenizer
30
+ import mdtex2html
31
+ import qwen_response
32
+
33
+
34
+ ''' Start: Environment settings. '''
35
+ os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/Users/yunshi/Downloads/chatGLM/My_LocalKB_Project/'
36
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
37
+ import torch
38
+ mps_device = torch.device("mps") ## 在mac机器上需要加上这句。必须要有这句,否则会报错。
39
+
40
+ ### 在langchain中定义chatGLM作为LLM。
41
+ from typing import Any, List, Mapping, Optional
42
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
43
+ from langchain.llms.base import LLM
44
+ from transformers import AutoTokenizer, AutoModel
45
+ # llm_filepath = str("/Users/yunshi/Downloads/chatGLM/ChatGLM3-6B/6B") ## 第三代chatGLM 6B W/ code-interpreter
46
+
47
+
48
+ # ## API模式启动ChatGLM
49
+ # ## 配置ChatGLM的类与后端api server对应。
50
+ # class ChatGLM(LLM):
51
+ # max_token: int = 2048
52
+ # temperature: float = 0.1
53
+ # top_p = 0.9
54
+ # history = []
55
+
56
+ # def __init__(self):
57
+ # super().__init__()
58
+
59
+ # @property
60
+ # def _llm_type(self) -> str:
61
+ # return "ChatGLM"
62
+
63
+ # def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
64
+ # # headers中添加上content-type这个参数,指定为json格式
65
+ # headers = {'Content-Type': 'application/json'}
66
+ # data=json.dumps({
67
+ # 'prompt':prompt,
68
+ # 'temperature':self.temperature,
69
+ # 'history':self.history,
70
+ # 'max_length':self.max_token
71
+ # })
72
+ # print("ChatGLM prompt:",prompt)
73
+ # # 调用api
74
+ # # response = requests.post("http://0.0.0.0:8000",headers=headers,data=data) ##working。
75
+ # response = requests.post("http://127.0.0.1:8000",headers=headers,data=data) ##working。
76
+ # print("ChatGLM resp:", response)
77
+
78
+ # if response.status_code!=200:
79
+ # return "查询结果错误"
80
+ # resp = response.json()
81
+ # if stop is not None:
82
+ # response = enforce_stop_tokens(response, stop)
83
+ # self.history = self.history+[[None, resp['response']]] ##original
84
+ # return resp['response'] ##original.
85
+
86
+ # llm = ChatGLM() ## 启动一个实例。orignal working。
87
+ # import asyncio
88
+ # llm = ChatGLM() ## 启动一个实例。
89
+
90
+
91
+ ''' End: Environment settings. '''
92
+
93
+ ### 我会用中文或者英文双引号(即:“ ”," ")来告知你变量的名称。 长度","宽度","价格","产品ID","比率","类别","*"
94
+
95
+ ### 用ChatGLM构建一个只返回SQL语句的模型。
96
+ def main(prompt):
97
+ full_reponse = []
98
+ sys_prompt = """
99
+ 1. 你是一个将文字转换成SQL语句的人工智能。
100
+ 2. 你需要注意:你只需要用纯文本回复代码的内容,即你不允许回复代码以外的任何信息。
101
+ 3. SQL变量默认是中文,而且只能从如下的名称列表中选择,你不可以使用这些名字以外的变量名:"长度","宽度","价格","产品ID","比率","类别","*"
102
+ 4. 你不能写IF, THEN的SQL语句,需要使用CASE。
103
+ 5. 我需要你转换的文字如下:"""
104
+
105
+ total_prompt = sys_prompt + "在数据表格table01中," + prompt
106
+
107
+ print('total prompt now:',total_prompt)
108
+ # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(prompt)): ## 这里保留了所有的chat history在input_prompt中。
109
+ # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(prompt)): ## 这里保留了所有的chat history在input_prompt中。
110
+ # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(total_prompt)): ## 这里保留了所有的chat history在input_prompt中。
111
+
112
+ # # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(input_prompt[-1][0])): ## 从用langchain的自定义方式来做。
113
+ # # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(input_prompt[-1][0]), history=input_prompt, max_length=max_tokens, top_p=top_p, temperature=temperature): ## 从用langchain的自定义方式来做。
114
+ # if response != "<br>":
115
+ # # print('response of model:', response)
116
+ # # input_prompt[-1][1] = response ## working.
117
+ # # input_prompt[-1][1] = response
118
+ # # yield input_prompt
119
+ # full_reponse.append(response)
120
+
121
+ # ## 得到一个非stream格式的答复。非API模式。
122
+ # response, history = chatglm.model.chat(chatglm.tokenizer, query=str(total_prompt), temperature=0.1) ## 这里保留了所有的chat history在input_prompt中。
123
+
124
+ ###TODO:API模式,需要先启动API服务器。
125
+ # llm = ChatGLM() ##!! 重要说明:每次都需要实例化一次!!!否则会报错content error。实际上是应该在每次函数调用的时候都要实例化一次!
126
+ # response = llm(total_prompt) ## 这里是本地的ChatGLM来作为大模型输出基座。
127
+
128
+
129
+ response = qwen_response.call_with_messages(total_prompt)
130
+
131
+ print('response of model:', response)
132
+
133
+ ## 用regex来提取纯SQL语句。需要构建多个正则式pattern
134
+ pattern_1 = r"(?:`sql\n|\n`)"
135
+ pattern_2 = r"(?:```|``)"
136
+ pattern_3 = r"(?s)(.*?SQL语句示例.*?:).*?\n"
137
+ pattern_4 = r"(?:`{3}|`{2}|`)"
138
+ # pattern_5 = r"[\u4e00-\u9FFF]" ## 匹配中文。
139
+ # pattern_6 = r"^[\u4e00-\u9fa5]{5,}" ## 首行中包含5个中文汉字的。
140
+ pattern_7 = r"^.{0,2}([\u4e00-\u9fa5]{5,}).*" ## 首行中包含5个中文汉字的。
141
+ pattern_8 = r'^"|"$' ## 去除一句话开始或者末尾的英文双引号
142
+ pattern_list = [pattern_1, pattern_2, pattern_3, pattern_4, pattern_7, pattern_8]
143
+
144
+ ## 遍历所有的pattern,逐个去除。
145
+ full_reponse = response
146
+ for p in pattern_list:
147
+ full_reponse = re.sub(p, "", full_reponse)
148
+ # final_response = re.sub(pattern_1, "", response) ## 逐步匹配。
149
+ # final_response = re.sub(pattern_1, "", response) ## 逐步匹配。
150
+ # final_response = re.sub(pattern_2, "", final_response) ## 逐步匹配。
151
+
152
+ return full_reponse
153
+
154
+ # prompt = "你给我一段复杂的SQL语句示例。"
155
+ # prompt = "你给我一段SQL语句,用来完成如下工作:查询年龄大于30岁,男性,收入超过2万元的员工。"
156
+ # res = main(prompt=prompt)
157
+ # print(res)
myexcelDB.db ADDED
Binary file (16.4 kB). View file
 
qwen_response.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from http import HTTPStatus
3
+ import dashscope
4
+
5
+ ### 参考:
6
+ ## export DASHSCOPE_API_KEY="sk-948adb3e65414e55961a9ad9d22d186b"
7
+ dashscope.api_key = "sk-948adb3e65414e55961a9ad9d22d186b"
8
+
9
+ # qwen_sys_prompt = """
10
+ # 1. 你是一个将文字转换成SQL语句的人工智能。
11
+ # 2. 此外你需要注意:你只需要用纯文本回复代码的内容,即你不允许回复代码以外的任何信息。
12
+ # 3. 你不能将将字段名翻译成英文,而是必须使用如下的名词:长度,宽度,价格,产品ID,比率,类别,*
13
+ # 4. 你不能写IF, THEN的SQL语句,需要使用CASE。
14
+ # """
15
+
16
+ qwen_sys_prompt = """你是一个将文字转换成SQL语句的人工智能。"""
17
+
18
+ def call_with_messages(prompt):
19
+ messages = [{'role': 'system', 'content': qwen_sys_prompt},
20
+ {'role': 'user', 'content': prompt}]
21
+ # {'role': 'user', 'content': '如何做西红柿炒鸡蛋?'}]
22
+ response = dashscope.Generation.call(
23
+ "qwen-turbo",
24
+ messages=messages,
25
+ # set the random seed, optional, default to 1234 if not set
26
+ seed=random.randint(1, 10000),
27
+ # set the result to be "message" format.
28
+ result_format='message',
29
+ )
30
+ if response.status_code == HTTPStatus.OK:
31
+ print(response)
32
+
33
+ else:
34
+ print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
35
+ response.request_id, response.status_code,
36
+ response.code, response.message
37
+ ))
38
+
39
+ return response['output']['choices'][0]['message']['content'] ### 这里是content的内容,不是message的全部内容。
40
+
41
+
42
+ # if __name__ == '__main__':
43
+ # # call_with_messages() ### original code here.
44
+ # res = call_with_messages() ## working.
45
+ # # print(res)
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dashscope==1.17.0
2
+ fastapi==0.111.0
3
+ langchain==0.1.17
4
+ mdtex2html==1.2.0
5
+ nest_asyncio==1.5.8
6
+ numpy==1.26.4
7
+ pandas==2.2.2
8
+ pygwalker==0.4.8.1
9
+ pyngrok==7.0.5
10
+ python-dotenv==1.0.1
11
+ Requests==2.31.0
12
+ rich==13.7.1
13
+ streamlit==1.33.0
14
+ streamlit_code_editor==0.1.10
15
+ torch==2.2.0
16
+ transformers==4.37.1
17
+ uvicorn==0.29.0