allinaigc commited on
Commit
a44d8a9
·
1 Parent(s): 7723f88

Upload 6 files

Browse files
Files changed (3) hide show
  1. app.py +312 -100
  2. localKB_construct copy.py +101 -0
  3. save_database_info.py +47 -0
app.py CHANGED
@@ -16,12 +16,9 @@
16
  # credentials["usernames"].update({un:user_dict})
17
  credentials["usernames"].update({un: user_dict})
18
 
19
-
20
-
21
  '''
22
- # TODO:1. Chinese display isssue. 2. account system.
23
 
24
- from dotenv import load_dotenv # pip3 install python-dotenv
25
  import database as db
26
  from deta import Deta # pip3 install deta
27
  import requests
@@ -31,7 +28,6 @@ from codeinterpreterapi import CodeInterpreterSession
31
  import openai
32
  import os
33
  import matplotlib.pyplot as plt
34
- import xlrd
35
  import pandas as pd
36
  # import csv
37
  import tempfile
@@ -44,14 +40,21 @@ from time import sleep
44
  import streamlit_authenticator as stauth
45
  import database as db # python文件同目录下的.py程序,直接导入。
46
  import deta
 
 
 
 
 
 
 
 
 
47
 
48
  os.environ["OPENAI_API_KEY"] = os.environ['user_token']
49
  openai.api_key = os.environ['user_token']
50
- bing_search_api_key = os.environ['bing_api_key']
51
- bing_search_endpoint = 'https://api.bing.microsoft.com/v7.0/search'
52
  # os.environ["VERBOSE"] = "True" # 可以看到具体的错误?
53
 
54
- # # #* 如果碰到接口问题,可以启用如下设置。
55
  # openai.proxy = {
56
  # "http": "http://127.0.0.1:7890",
57
  # "https": "http://127.0.0.1:7890"
@@ -72,40 +75,80 @@ if reset_button:
72
  st.session_state.messages = []
73
  message_placeholder = st.empty()
74
 
75
-
76
- # with tab2:
77
- def upload_file(uploaded_file):
78
- if uploaded_file is not None:
79
- filename = uploaded_file.name
80
- st.write(filename) # print out the whole file name to validate.
81
- try:
82
- if '.csv' in filename:
83
- csv_file = pd.read_csv(uploaded_file)
84
- csv_file.to_csv('./upload.csv', encoding='utf-8', index=False)
85
- st.write(csv_file[:3]) # 这里只是显示文件,后面需要定位文件所在的绝对路径。
86
- else:
87
- xls_file = pd.read_excel(uploaded_file)
88
- xls_file.to_csv('./upload.csv', index=False)
89
- st.write(xls_file[:3])
90
- except Exception as e:
91
- st.write(e)
92
-
93
- uploaded_file_name = "File_provided"
94
- temp_dir = tempfile.TemporaryDirectory()
95
- # ! working.
96
- uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
97
- # with open('./upload.csv', 'wb') as output_temporary_file:
98
- with open(f'./{name}_upload.csv', 'wb') as output_temporary_file:
99
- # print(f'./{name}_upload.csv')
100
- # ! 必须用这种格式读入内容,然后才可以写入temporary文件夹中。
101
- # output_temporary_file.write(uploaded_file.getvalue())
102
- output_temporary_file.write(uploaded_file.getvalue())
103
- # st.write(uploaded_file_path) # * 可以查看文件是否真实存在,然后是否可以
104
- # st.write('Now file saved successfully.')
105
-
106
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
 
 
 
109
 
110
 
111
  def search(query):
@@ -129,17 +172,16 @@ def search(query):
129
 
130
  # openai.api_key = st.secrets["OPENAI_API_KEY"]
131
 
132
-
133
  async def text_mode():
134
  # Set a default model
135
  if "openai_model" not in st.session_state:
136
  st.session_state["openai_model"] = "gpt-3.5-turbo-16k"
137
  if radio_1 == 'GPT-3.5':
138
  # print('----------'*5)
139
- # print('radio_1: GPT-3.5 starts!')
140
  st.session_state["openai_model"] = "gpt-3.5-turbo-16k"
141
  else:
142
- # print('radio_1: GPT-4.0 starts!')
143
  st.session_state["openai_model"] = "gpt-4"
144
 
145
  # Initialize chat history
@@ -154,8 +196,8 @@ async def text_mode():
154
  # Display assistant response in chat message container
155
  # if prompt := st.chat_input("Say something"):
156
  prompt = st.chat_input("Say something")
157
- # print('prompt now:', prompt)
158
- # print('----------'*5)
159
  # if prompt:
160
  if prompt:
161
  st.session_state.messages.append({"role": "user", "content": prompt})
@@ -167,7 +209,7 @@ async def text_mode():
167
  full_response = ""
168
 
169
  if radio_2 == '联网模式':
170
- # print('联网模式入口,prompt:', prompt)
171
  input_message = prompt
172
  internet_search_result = search(input_message)
173
  search_prompt = [
@@ -197,8 +239,8 @@ async def text_mode():
197
  st.session_state.messages = []
198
 
199
  if radio_2 == '核心模式':
200
- # print('GPT only starts!!!')
201
- # print('messages:', st.session_state['messages'])
202
  for response in openai.ChatCompletion.create(
203
  model=st.session_state["openai_model"],
204
  # messages=[
@@ -218,10 +260,95 @@ async def text_mode():
218
  {"role": "assistant", "content": full_response})
219
 
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  async def data_mode():
222
- # print('数据分析模式启动!')
223
  # uploaded_file_path = './upload.csv'
224
- uploaded_file_path = f'./{name}_upload.csv'
 
225
  # # st.write(f"passed file path in data_mode: {uploaded_file_path}")
226
  # tmp1 = pd.read_csv('./upload.csv')
227
  # st.write(tmp1[:5])
@@ -238,8 +365,8 @@ async def data_mode():
238
  # Display assistant response in chat message container
239
  # if prompt := st.chat_input("Say something"):
240
  prompt = st.chat_input("Say something")
241
- # print('prompt now:', prompt)
242
- # print('----------'*5)
243
  # if prompt:
244
  if prompt:
245
  st.session_state.messages.append({"role": "user", "content": prompt})
@@ -269,7 +396,7 @@ async def data_mode():
269
  user_request = environ_settings + "\n\n" + \
270
  "你需要完成以下任务:\n\n" + prompt + "\n\n" \
271
  f"注:文件位置在{uploaded_file_path}"
272
- # print('user_request: \n', user_request)
273
 
274
  # 加载上传的文件,主要路径在上面代码中。
275
  files = [File.from_path(str(uploaded_file_path))]
@@ -281,7 +408,7 @@ async def data_mode():
281
  )
282
 
283
  # output to the user
284
- # print("AI: ", response.content)
285
  full_response = response.content
286
  ### full_response = "this is full response"
287
 
@@ -306,19 +433,21 @@ async def data_mode():
306
  # st.session_state.messages.append({"role": "assistant", "content": full_response})
307
 
308
 
309
- # authentication with a local yaml file.
310
- # import yaml
311
- # from yaml.loader import SafeLoader
312
- # with open('/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/code_interpreter/config.yaml') as file:
313
- # config = yaml.load(file, Loader=SafeLoader)
314
- # authenticator = stauth.Authenticate(
315
- # config['credentials'],
316
- # config['cookie']['name'],
317
- # config['cookie']['key'],
318
- # config['cookie']['expiry_days'],
319
- # config['preauthorized']
320
- # )
321
 
 
 
322
  # authentication with a remove cloud-based database.
323
  # 导入云端用户数据库。
324
 
@@ -329,32 +458,28 @@ async def data_mode():
329
 
330
  # deta = Deta(DETA_KEY)
331
 
332
- # mybase is the name of the database in Deta. You can change it to any name you want.
333
- credentials = {"usernames":{}}
334
- # credentials = {"users": {}}
335
- # db = db()
336
- users = []
337
- email = []
338
- passwords = []
339
- names = []
340
 
341
- for row in db.fetch_all_users():
342
- # users.append(row["key"])
343
- # names.append(row["username"])
344
- users.append(row["username"])
345
- email.append(row["email"])
346
- names.append(row["key"])
347
- passwords.append(row["password"])
348
 
349
- hashed_passwords = stauth.Hasher(passwords).generate()
350
 
351
 
352
  ## 需要严格的按照yaml文件的格式来定义如下几个字段。
353
- for un, name, pw in zip(users, names, hashed_passwords):
354
- # user_dict = {"name":name,"password":pw}
355
- user_dict = {"name": un, "password": pw}
356
- # credentials["usernames"].update({un:user_dict})
357
- credentials["usernames"].update({un: user_dict})
358
 
359
  # ## sign-up模块,未完成。
360
  # database_table = []
@@ -366,12 +491,8 @@ for un, name, pw in zip(users, names, hashed_passwords):
366
  # database_table.append([i,credentials['usernames'][i]['name'],credentials['usernames'][i]['password']])
367
  # print("database_table:",database_table)
368
 
369
-
370
- authenticator = stauth.Authenticate(
371
- credentials=credentials, cookie_name="joeshi_gpt", key='abcedefg', cookie_expiry_days=30)
372
-
373
- user, authentication_status, username = authenticator.login('用户登录', 'main')
374
- # print("name", name, "username", username)
375
 
376
  # ## sign-up widget,未完成。
377
  # try:
@@ -383,6 +504,11 @@ user, authentication_status, username = authenticator.login('用户登录', 'mai
383
  # st.success('注册成功!')
384
  # except Exception as e:
385
  # st.error(e)
 
 
 
 
 
386
 
387
  if authentication_status:
388
  with st.sidebar:
@@ -419,7 +545,7 @@ if authentication_status:
419
  with st.text(body="说明"):
420
  st.markdown("* “GPT-4”回答质量极佳,但速度缓慢、且不支持长文。建议适当使用。")
421
  with st.text(body="说明"):
422
- st.markdown("* “联网模式”与搜索引擎一致,仅限一轮对话,不会保持之前的会话记录。")
423
  with st.text(body="说明"):
424
  st.markdown(
425
  "* “数据模式”暂时只支持1000个单元格以内的数据分析,单元格中的内容不支持中文数据(表头也尽量不使用中文)。一般���行时间在1-5分钟左右,期间需要保持网络畅通。")
@@ -458,28 +584,114 @@ if authentication_status:
458
  col1, col2 = st.columns(spec=[1, 2])
459
  radio_2 = col2.radio(label='模式选择', options=[
460
  '核心模式', '联网模式', '知识库模式', '数据模式'], horizontal=True, label_visibility='visible')
461
- # radio_1 = col1.selectbox(label='ChatGPT版本', options=[
462
- # 'GPT-3.5', 'GPT-4.0'], label_visibility='visible')
463
  radio_1 = col1.radio(label='ChatGPT版本', options=[
464
  'GPT-3.5', 'GPT-4.0'], horizontal=True, label_visibility='visible')
465
 
466
  elif authentication_status == False:
467
  st.error('⛔ 用户名或密码错误!')
468
  elif authentication_status == None:
469
- st.warning('🔼 请先登录!')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
 
472
  if __name__ == "__main__":
473
  import asyncio
474
  try:
475
  if radio_2 == "核心模式":
476
- # print(f'radio 选择了 {radio_2}')
477
  # * 也可以用命令执行这个python文件。’streamlit run frontend/app.py‘
478
  asyncio.run(text_mode())
 
479
  if radio_2 == "联网模式":
480
- # print(f'radio 选择了 {radio_2}')
481
- # * 也可以用命令执行这个python文件。’streamlit run frontend/app.py‘
482
  asyncio.run(text_mode())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  if radio_2 == "数据模式":
484
  uploaded_file = st.file_uploader(
485
  "选择一个文件", type=(["csv", "xlsx", "xls"]))
 
16
  # credentials["usernames"].update({un:user_dict})
17
  credentials["usernames"].update({un: user_dict})
18
 
 
 
19
  '''
20
+ # TODO:1. Chinese display isssue. 2. account system. 3. local enterprise database.
21
 
 
22
  import database as db
23
  from deta import Deta # pip3 install deta
24
  import requests
 
28
  import openai
29
  import os
30
  import matplotlib.pyplot as plt
 
31
  import pandas as pd
32
  # import csv
33
  import tempfile
 
40
  import streamlit_authenticator as stauth
41
  import database as db # python文件同目录下的.py程序,直接导入。
42
  import deta
43
+ from langchain.chat_models import ChatOpenAI
44
+ from llama_index import StorageContext, load_index_from_storage, GPTVectorStoreIndex, LLMPredictor, PromptHelper
45
+ from llama_index import ServiceContext, QuestionAnswerPrompt
46
+ import sys
47
+ import time
48
+ import PyPDF2 ## read the local_KB PDF file.
49
+ # import localKB_construct
50
+ import save_database_info
51
+ from datetime import datetime
52
 
53
  os.environ["OPENAI_API_KEY"] = os.environ['user_token']
54
  openai.api_key = os.environ['user_token']
 
 
55
  # os.environ["VERBOSE"] = "True" # 可以看到具体的错误?
56
 
57
+ # #* 如果碰到接口问题,可以启用如下设置。
58
  # openai.proxy = {
59
  # "http": "http://127.0.0.1:7890",
60
  # "https": "http://127.0.0.1:7890"
 
75
  st.session_state.messages = []
76
  message_placeholder = st.empty()
77
 
78
+ def clear_all():
79
+ st.session_state.conversation = None
80
+ st.session_state.chat_history = None
81
+ st.session_state.messages = []
82
+ message_placeholder = st.empty()
83
+ return None
84
+
85
+
86
+ # # with tab2:
87
+ # def upload_file(uploaded_file):
88
+ # if uploaded_file is not None:
89
+ # filename = uploaded_file.name
90
+ # # st.write(filename) # print out the whole file name to validate. not to show in the final version.
91
+ # try:
92
+ # if '.pdf' in filename:
93
+ # # pdf_file = PyPDF2.PdfReader(uploaded_file)
94
+ # PyPDF2.PdfReader(uploaded_file)
95
+ # # st.write(pdf_file.pages[0].extract_text())
96
+ # # with st.status('正在为您解析新知识库...', expanded=False, state='running') as status:
97
+ # spinner = st.spinner('正在为您解析新知识库...请耐心等待')
98
+ # # with st.spinner('正在为您解析新知识库...请耐心等待'):
99
+ # with spinner:
100
+ # import localKB_construct
101
+ # # sleep(3)
102
+ # # st.write(upload_file)
103
+ # localKB_construct.process_file(uploaded_file)
104
+ # st.markdown('新知识库解析成功,可以开始对话!')
105
+ # spinner = st.empty()
106
+ # # sleep(3)
107
+ # # display = []
108
+
109
+ # else:
110
+ # if '.csv' in filename:
111
+ # csv_file = pd.read_csv(uploaded_file)
112
+ # csv_file.to_csv('./upload.csv', encoding='utf-8', index=False)
113
+ # st.write(csv_file[:3]) # 这里只是显示文件,后面需要定位文件所在的绝对路径。
114
+ # else:
115
+ # xls_file = pd.read_excel(uploaded_file)
116
+ # xls_file.to_csv('./upload.csv', index=False)
117
+ # st.write(xls_file[:3])
118
+
119
+ # uploaded_file_name = "File_provided"
120
+ # temp_dir = tempfile.TemporaryDirectory()
121
+ # # ! working.
122
+ # uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
123
+ # # with open('./upload.csv', 'wb') as output_temporary_file:
124
+ # with open(f'./{name}_upload.csv', 'wb') as output_temporary_file:
125
+ # # print(f'./{name}_upload.csv')
126
+ # # ! 必须用这种格式读入内容,然后才可以写入temporary文件夹中。
127
+ # # output_temporary_file.write(uploaded_file.getvalue())
128
+ # output_temporary_file.write(uploaded_file.getvalue())
129
+ # # st.write(uploaded_file_path) #* 可以查看文件是否真实存在,然后是否可以
130
+ # # st.write('Now file saved successfully.')
131
+ # except Exception as e:
132
+ # st.write(e)
133
+
134
+ # # uploaded_file_name = "File_provided"
135
+ # # temp_dir = tempfile.TemporaryDirectory()
136
+ # # # ! working.
137
+ # # uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
138
+ # # # with open('./upload.csv', 'wb') as output_temporary_file:
139
+ # # with open(f'./{name}_upload.csv', 'wb') as output_temporary_file:
140
+ # # # print(f'./{name}_upload.csv')
141
+ # # # ! 必须用这种格式读入内容,然后才可以写入temporary文件夹中。
142
+ # # # output_temporary_file.write(uploaded_file.getvalue())
143
+ # # output_temporary_file.write(uploaded_file.getvalue())
144
+ # # # st.write(uploaded_file_path) # * 可以查看文件是否真实存在,然后是否可以
145
+ # # # st.write('Now file saved successfully.')
146
+
147
+ # return None
148
 
149
 
150
+ bing_search_api_key = os.environ['bing_api_key']
151
+ bing_search_endpoint = 'https://api.bing.microsoft.com/v7.0/search'
152
 
153
 
154
  def search(query):
 
172
 
173
  # openai.api_key = st.secrets["OPENAI_API_KEY"]
174
 
 
175
  async def text_mode():
176
  # Set a default model
177
  if "openai_model" not in st.session_state:
178
  st.session_state["openai_model"] = "gpt-3.5-turbo-16k"
179
  if radio_1 == 'GPT-3.5':
180
  # print('----------'*5)
181
+ print('radio_1: GPT-3.5 starts!')
182
  st.session_state["openai_model"] = "gpt-3.5-turbo-16k"
183
  else:
184
+ print('radio_1: GPT-4.0 starts!')
185
  st.session_state["openai_model"] = "gpt-4"
186
 
187
  # Initialize chat history
 
196
  # Display assistant response in chat message container
197
  # if prompt := st.chat_input("Say something"):
198
  prompt = st.chat_input("Say something")
199
+ print('prompt now:', prompt)
200
+ print('----------'*5)
201
  # if prompt:
202
  if prompt:
203
  st.session_state.messages.append({"role": "user", "content": prompt})
 
209
  full_response = ""
210
 
211
  if radio_2 == '联网模式':
212
+ print('联网模式入口,prompt:', prompt)
213
  input_message = prompt
214
  internet_search_result = search(input_message)
215
  search_prompt = [
 
239
  st.session_state.messages = []
240
 
241
  if radio_2 == '核心模式':
242
+ print('GPT only starts!!!')
243
+ print('messages:', st.session_state['messages'])
244
  for response in openai.ChatCompletion.create(
245
  model=st.session_state["openai_model"],
246
  # messages=[
 
260
  {"role": "assistant", "content": full_response})
261
 
262
 
263
+ ## load the local_KB PDF file.
264
+ # async def localKB_mode():
265
+ def localKB_mode(username):
266
+ ### clear all the prior conversation.
267
+ st.session_state.conversation = None
268
+ st.session_state.chat_history = None
269
+ st.session_state.messages = []
270
+ message_placeholder = st.empty()
271
+
272
+ print('now starts the local KB version of ChatGPT')
273
+ # Initialize chat history
274
+ if "messages" not in st.session_state:
275
+ st.session_state.messages = []
276
+
277
+ for message in st.session_state.messages:
278
+ with st.chat_message(message["role"]):
279
+ st.markdown(message["content"])
280
+
281
+ # Display assistant response in chat message container
282
+ # if prompt := st.chat_input("Say something"):
283
+ # prompt = st.chat_input("Say something")
284
+ # print('prompt now:', prompt)
285
+ # print('----------'*5)
286
+ # if prompt:
287
+ if prompt := st.chat_input("Say something"):
288
+ st.session_state.messages.append({"role": "user", "content": prompt})
289
+ with st.chat_message("user"):
290
+ st.markdown(prompt)
291
+
292
+ with st.status('检索中...', expanded=True, state='running') as status:
293
+ with st.chat_message("assistant"):
294
+ message_placeholder = st.empty()
295
+ full_response = ""
296
+
297
+ # if radio_2 == "知识库模式":
298
+ # ! 这里需要重新装载一下storage_context。
299
+ QA_PROMPT_TMPL = (
300
+ "We have provided context information below. \n"
301
+ "---------------------\n"
302
+ "{context_str}"
303
+ "\n---------------------\n"
304
+ "Given all this information, please answer the following questions,"
305
+ "You MUST use the SAME language as the question:\n"
306
+ "{query_str}\n")
307
+ QA_PROMPT = QuestionAnswerPrompt(QA_PROMPT_TMPL)
308
+ # print('QA_PROMPT:', QA_PROMPT)
309
+
310
+ # llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0.8, model_name="gpt-3.5-turbo", max_tokens=4024,streaming=True))
311
+ # # print('llm_predictor:', llm_predictor)
312
+ # prompt_helper = PromptHelper(max_input_size, num_outputs, max_chunk_overlap, chunk_size_limit=chunk_size_limit)
313
+ # print('prompt_helper:', prompt_helper)
314
+ # service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
315
+ # print('service_context:', service_context)
316
+ # # # index = load_index_from_storage(storage_context)
317
+ # print("storage_context:", storage_context)
318
+ # index = load_index_from_storage(storage_context,service_context=service_context)
319
+ storage_context = StorageContext.from_defaults(persist_dir=f"./{username}/")
320
+ print('storage_context:',storage_context)
321
+ index = load_index_from_storage(storage_context)
322
+
323
+ # query_engine = index.as_query_engine(streaming=True, similarity_top_k=3, text_qa_template=QA_PROMPT)
324
+ query_engine = index.as_query_engine(streaming=True)
325
+ # query_engine = index.as_query_engine(streaming=True, text_qa_template=QA_PROMPT)
326
+ # query_engine = index.as_query_engine(streaming=False, text_qa_template=QA_PROMPT)
327
+ # query_engine = index.as_query_engine()
328
+ # reply = query_engine.query(prompt)
329
+
330
+ llama_index_reply = query_engine.query(prompt)
331
+ # full_response += query_engine.query(prompt)
332
+ print('local KB reply:', llama_index_reply)
333
+ # query_engine.query(prompt).print_response_stream() #* 能在terminal中流式输出。
334
+ # for resp in llama_index_reply.response_gen:
335
+ # print(resp)
336
+ # full_response += resp
337
+ # message_placeholder.markdown(full_response + "▌")
338
+ message_placeholder.markdown(llama_index_reply)
339
+ # st.session_state.messages.append(
340
+ # {"role": "assistant", "content": full_response})
341
+ # st.session_state.messages = []
342
+ # full_response += reply
343
+ # full_response = reply
344
+ # st.session_state.messages.append(
345
+ # {"role": "assistant", "content": full_response})
346
+
347
  async def data_mode():
348
+ print('数据分析模式启动!')
349
  # uploaded_file_path = './upload.csv'
350
+ # uploaded_file_path = f'./{joejoe}_upload.csv'
351
+ uploaded_file_path = f'./joejoe_upload.csv'
352
  # # st.write(f"passed file path in data_mode: {uploaded_file_path}")
353
  # tmp1 = pd.read_csv('./upload.csv')
354
  # st.write(tmp1[:5])
 
365
  # Display assistant response in chat message container
366
  # if prompt := st.chat_input("Say something"):
367
  prompt = st.chat_input("Say something")
368
+ print('prompt now:', prompt)
369
+ print('----------'*5)
370
  # if prompt:
371
  if prompt:
372
  st.session_state.messages.append({"role": "user", "content": prompt})
 
396
  user_request = environ_settings + "\n\n" + \
397
  "你需要完成以下任务:\n\n" + prompt + "\n\n" \
398
  f"注:文件位置在{uploaded_file_path}"
399
+ print('user_request: \n', user_request)
400
 
401
  # 加载上传的文件,主要路径在上面代码中。
402
  files = [File.from_path(str(uploaded_file_path))]
 
408
  )
409
 
410
  # output to the user
411
+ print("AI: ", response.content)
412
  full_response = response.content
413
  ### full_response = "this is full response"
414
 
 
433
  # st.session_state.messages.append({"role": "assistant", "content": full_response})
434
 
435
 
436
+ ### authentication with a local yaml file.
437
+ import yaml
438
+ from yaml.loader import SafeLoader
439
+ with open('/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/code_interpreter/config.yaml') as file:
440
+ config = yaml.load(file, Loader=SafeLoader)
441
+ authenticator = stauth.Authenticate(
442
+ config['credentials'],
443
+ config['cookie']['name'],
444
+ config['cookie']['key'],
445
+ config['cookie']['expiry_days'],
446
+ config['preauthorized']
447
+ )
448
 
449
+
450
+ ###'''authentication with a remove cloud-based database.'''
451
  # authentication with a remove cloud-based database.
452
  # 导入云端用户数据库。
453
 
 
458
 
459
  # deta = Deta(DETA_KEY)
460
 
461
+ # # mybase is the name of the database in Deta. You can change it to any name you want.
462
+ # credentials = {"usernames":{}}
463
+ # users = []
464
+ # email = []
465
+ # passwords = []
466
+ # names = []
 
 
467
 
468
+ # for row in db.fetch_all_users():
469
+ # users.append(row["username"])
470
+ # email.append(row["email"])
471
+ # names.append(row["key"])
472
+ # passwords.append(row["password"])
 
 
473
 
474
+ # hashed_passwords = stauth.Hasher(passwords).generate()
475
 
476
 
477
  ## 需要严格的按照yaml文件的格式来定义如下几个字段。
478
+ # for un, name, pw in zip(users, names, hashed_passwords):
479
+ # # user_dict = {"name":name,"password":pw}
480
+ # user_dict = {"name": un, "password": pw}
481
+ # # credentials["usernames"].update({un:user_dict})
482
+ # credentials["usernames"].update({un: user_dict})
483
 
484
  # ## sign-up模块,未完成。
485
  # database_table = []
 
491
  # database_table.append([i,credentials['usernames'][i]['name'],credentials['usernames'][i]['password']])
492
  # print("database_table:",database_table)
493
 
494
+ # authenticator = stauth.Authenticate(
495
+ # credentials=credentials, cookie_name="joeshi_gpt", key='abcedefg', cookie_expiry_days=30)
 
 
 
 
496
 
497
  # ## sign-up widget,未完成。
498
  # try:
 
504
  # st.success('注册成功!')
505
  # except Exception as e:
506
  # st.error(e)
507
+ ''''''
508
+
509
+ # user, authentication_status, username = authenticator.login('用户登录', 'main')
510
+ user, authentication_status, username = authenticator.login('用户登录', 'sidebar')
511
+ # print("name", name, "username", username)
512
 
513
  if authentication_status:
514
  with st.sidebar:
 
545
  with st.text(body="说明"):
546
  st.markdown("* “GPT-4”回答质量极佳,但速度缓慢、且不支持长文。建议适当使用。")
547
  with st.text(body="说明"):
548
+ st.markdown("* “联网模式”和“知识库模式”均基于检索功能,仅限一轮对话,不会保持之前的会话记录。")
549
  with st.text(body="说明"):
550
  st.markdown(
551
  "* “数据模式”暂时只支持1000个单元格以内的数据分析,单元格中的内容不支持中文数据(表头也尽量不使用中文)。一般���行时间在1-5分钟左右,期间需要保持网络畅通。")
 
584
  col1, col2 = st.columns(spec=[1, 2])
585
  radio_2 = col2.radio(label='模式选择', options=[
586
  '核心模式', '联网模式', '知识库模式', '数据模式'], horizontal=True, label_visibility='visible')
 
 
587
  radio_1 = col1.radio(label='ChatGPT版本', options=[
588
  'GPT-3.5', 'GPT-4.0'], horizontal=True, label_visibility='visible')
589
 
590
  elif authentication_status == False:
591
  st.error('⛔ 用户名或密码错误!')
592
  elif authentication_status == None:
593
+ st.warning(' 请先登录!')
594
+
595
+ ### 上传文件的模块
596
+ def upload_file(uploaded_file):
597
+ if uploaded_file is not None:
598
+ filename = uploaded_file.name
599
+ # st.write(filename) # print out the whole file name to validate. not to show in the final version.
600
+ try:
601
+ if '.pdf' in filename:
602
+ # pdf_file = PyPDF2.PdfReader(uploaded_file)
603
+ PyPDF2.PdfReader(uploaded_file)
604
+ # st.write(pdf_file.pages[0].extract_text())
605
+ # with st.status('正在为您解析新知识库...', expanded=False, state='running') as status:
606
+ spinner = st.spinner('正在为您解析新知识库...请耐心等待')
607
+ # with st.spinner('正在为您解析新知识库...请耐心等待'):
608
+ with spinner:
609
+ import localKB_construct
610
+ # st.write(upload_file)
611
+ localKB_construct.process_file(uploaded_file, username)
612
+ save_database_info.save_database_info(f'./{username}/database_name.csv', filename, str(datetime.now().strftime("%Y-%m-%d %H:%M")))
613
+ st.markdown('新知识库解析成功,请务必刷新页面,然后开启对话 🔁')
614
+ # spinner = st.empty()
615
+
616
+ else:
617
+ if '.csv' in filename:
618
+ csv_file = pd.read_csv(uploaded_file)
619
+ csv_file.to_csv(f'./{username}/upload.csv', encoding='utf-8', index=False)
620
+ st.write(csv_file[:3]) # 这里只是显示文件,后面需要定位文件所在的绝对路径。
621
+ else:
622
+ xls_file = pd.read_excel(uploaded_file)
623
+ xls_file.to_csv(f'./{username}/upload.csv', index=False)
624
+ st.write(xls_file[:3])
625
+
626
+ uploaded_file_name = "File_provided"
627
+ temp_dir = tempfile.TemporaryDirectory()
628
+ # ! working.
629
+ uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
630
+ # with open('./upload.csv', 'wb') as output_temporary_file:
631
+ with open(f'./{username}_upload.csv', 'wb') as output_temporary_file:
632
+ # print(f'./{name}_upload.csv')
633
+ # ! 必须用这种格式读入内容,然后才可以写入temporary文件夹中。
634
+ # output_temporary_file.write(uploaded_file.getvalue())
635
+ output_temporary_file.write(uploaded_file.getvalue())
636
+ # st.write(uploaded_file_path) #* 可以查看文件是否真实存在,然后是否可以
637
+ # st.write('Now file saved successfully.')
638
+ except Exception as e:
639
+ st.write(e)
640
+
641
+ ## 以下代码是为了解决上传文件后,文件路径和文件名不对的问题。
642
+ # uploaded_file_name = "File_provided"
643
+ # temp_dir = tempfile.TemporaryDirectory()
644
+ # # ! working.
645
+ # uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
646
+ # # with open('./upload.csv', 'wb') as output_temporary_file:
647
+ # with open(f'./{name}_upload.csv', 'wb') as output_temporary_file:
648
+ # # print(f'./{name}_upload.csv')
649
+ # # ! 必须用这种格式读入内容,然后才可以写入temporary文件夹中。
650
+ # # output_temporary_file.write(uploaded_file.getvalue())
651
+ # output_temporary_file.write(uploaded_file.getvalue())
652
+ # # st.write(uploaded_file_path) # * 可以查看文件是否真实存在,然后是否可以
653
+ # # st.write('Now file saved successfully.')
654
+
655
+ return None
656
 
657
 
658
  if __name__ == "__main__":
659
  import asyncio
660
  try:
661
  if radio_2 == "核心模式":
662
+ print(f'radio 选择了 {radio_2}')
663
  # * 也可以用命令执行这个python文件。’streamlit run frontend/app.py‘
664
  asyncio.run(text_mode())
665
+
666
  if radio_2 == "联网模式":
667
+ print(f'radio 选择了 {radio_2}')
 
668
  asyncio.run(text_mode())
669
+
670
+ if radio_2 == "知识库模式":
671
+ print(f'radio 选择了 {radio_2}')
672
+
673
+ path = f'./{username}/vector_store.json'
674
+ if os.path.exists(path):
675
+ database_info = pd.read_csv(f'./{username}/database_name.csv')
676
+ current_database_name = database_info.iloc[-1][0]
677
+ current_database_date = database_info.iloc[-1][1]
678
+ database_claim = f"当前知识库为:{current_database_name},创建于{current_database_date}。可以开始提问!"
679
+ st.markdown(database_claim)
680
+ # st.markdown("注意:系统中已经存在一个知识库,您现在可以开始提问!")
681
+
682
+ uploaded_file = st.file_uploader(
683
+ "选择上传一个新知识库", type=(["pdf"]))
684
+ # 默认状态下没有上传文件,None,会报错。需要判断。
685
+ if uploaded_file is not None:
686
+ # uploaded_file_path = upload_file(uploaded_file)
687
+ upload_file(uploaded_file)
688
+ # st.write('PDF file uploaded sucessfully!')
689
+ # clear_all()
690
+ # spinner = st.empty()
691
+
692
+ localKB_mode(username)
693
+ # asyncio.run(localKB_mode())
694
+
695
  if radio_2 == "数据模式":
696
  uploaded_file = st.file_uploader(
697
  "选择一个文件", type=(["csv", "xlsx", "xls"]))
localKB_construct copy.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ 1.更新了llama-index的库。对应的函数名和用法都有所改变。
3
+ '''
4
+
5
+ # import gradio as gr
6
+ import openai
7
+ import requests
8
+ import csv
9
+ from llama_index import PromptHelper
10
+ # from llama_index import GPTSimpleVectorIndex ## renamed in the latest version.
11
+ from llama_index import LLMPredictor
12
+ from llama_index import ServiceContext
13
+ from langchain.chat_models import ChatOpenAI
14
+ from langchain import OpenAI
15
+ from fastapi import FastAPI #* 实现流式数据
16
+ from fastapi.responses import StreamingResponse #* 实现流式数据
17
+ import sys
18
+ import os
19
+ import torch
20
+ import math
21
+ import pandas as pd
22
+ import numpy as np
23
+ import PyPDF2
24
+ # from llama_index import SimpleDirectoryReader, GPTListIndex, readers, GPTSimpleVectorIndex, LLMPredictor, PromptHelper #* working in the previous version.
25
+
26
+ ##* in the latest version: GPTSimpleVectorIndex was renamed to GPTVectorStoreIndex, try removing it from the end of your imports
27
+ from llama_index import SimpleDirectoryReader, GPTListIndex, readers, GPTVectorStoreIndex, LLMPredictor, PromptHelper
28
+ from llama_index import StorageContext, load_index_from_storage
29
+ from llama_index import ServiceContext
30
+ from llama_index import download_loader
31
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
32
+ import sys
33
+ import os
34
+ from rich import print
35
+
36
+ ## enironment settings.
37
+ os.environ["OPENAI_API_KEY"] = "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"
38
+ openai.api_key = "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"
39
+ # file_path = "/Users/yunshi/Downloads/txt_dir/Sparks_of_AGI.pdf"
40
+ # file_path = "/Users/yunshi/Downloads/txt_dir/2023年百人会电动论坛 纪要 20230401.pdf"
41
+
42
+ ## 建立index或者的过程。
43
+ def construct_index(directory_path):
44
+ # file_path = f"{directory_path}/uploaded_file.pdf"
45
+
46
+ file_path = directory_path
47
+
48
+ # set maximum input si771006
49
+ # max_input_size = 4096 #* working
50
+ max_input_size = 4096
51
+ # set number of output tokens
52
+ # num_outputs = 3000 #* working
53
+ num_outputs = 1000
54
+ # set maximum chunk overlap
55
+ max_chunk_overlap = -1000 #* working
56
+ # set chunk size limit
57
+ # chunk_size_limit = 600
58
+ chunk_size_limit = 6000 #* working
59
+
60
+ # ## add chunk_overlap_ratio according to github.
61
+ # chunk_overlap_ratio= 0.1
62
+
63
+
64
+ # define LLM
65
+ # llm_predictor = LLMPredictor(llm=OpenAI(temperature=0.5, model_name="gpt-3.5-turbo", max_tokens=2000))
66
+ llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0.7, model_name="gpt-3.5-turbo-16k", max_tokens=512,streaming=True))
67
+
68
+ ## 好像work了,2023.09.22, 注意这里的写法有调整。
69
+ # prompt_helper = PromptHelper(max_input_s≈ize, num_outputs, max_chunk_overlap, chunk_size_limit=chunk_size_limit)
70
+ prompt_helper = PromptHelper(max_input_size, num_outputs, chunk_overlap_ratio= 0.1, chunk_size_limit=chunk_size_limit)
71
+ service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
72
+
73
+ ## 如果是txt文件,那么需要用如下命令。注意与PDF文件的区别。
74
+ # documents = SimpleDirectoryReader(directory_path).load_data()
75
+
76
+ ## 如果是PDF文件,那么需要用如下命令。注意与txt文件的区别。切需要from llama_index import download_loader。
77
+ #NOTE: 这里可以问:give me an example of GPT-4 solving math problem. 会回答关于这个PDF中的内容,所以可以确认这个程序调用了in-context learning的功能。
78
+ CJKPDFReader = download_loader("CJKPDFReader")
79
+ loader = CJKPDFReader()
80
+ # documents = loader.load_data(file=directory_path) #! 注意这里是指向文件本身,而不同于txt文件的指文件夹。
81
+ documents = loader.load_data(file=directory_path) #! 注意这里是指向文件本身,而不同于txt文件的指文件夹。
82
+ # index = GPTSimpleVectorIndex(
83
+ # documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
84
+ # )
85
+
86
+ # index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context) ## oringinal version, working.
87
+ index = GPTVectorStoreIndex.from_documents(documents, service_context=service_context) #* the funciton renamed.
88
+ # index.save_to_disk('/Users/yunshi/Downloads/txt_dir/index.json') ## in the latest version, this function is not working.
89
+
90
+ return index, service_context
91
+
92
+ def process_file():
93
+ print('process_file starts')
94
+ file_path = "/Users/yunshi/Downloads/txt_dir/Sparks_of_AGI.pdf"
95
+ #! 第一次运行是需要开启这个function。如果测试通过index,因此不需要在运行了。记得上传PDF和JSON文件到云服务器上。
96
+ index, service_context = construct_index(file_path)
97
+ # index.storage_context.persist(persist_dir="/Users/yunshi/Downloads/txt_dir/") #* 存储到本地,为以后调用。
98
+ index.storage_context.persist(persist_dir=f"./") #* 存储到本地,为以后调用。
99
+ print(index)
100
+
101
+ process_file()
save_database_info.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import pandas as pd
4
+ import re
5
+ from re import sub
6
+ import smtplib
7
+ import matplotlib.pyplot as plt
8
+ from itertools import product
9
+ from tqdm import tqdm_notebook, tqdm, trange
10
+ import time
11
+ import seaborn as sns
12
+ from matplotlib.pyplot import style
13
+ from rich import print
14
+ import warnings
15
+ warnings.filterwarnings('ignore')
16
+ sns.set()
17
+ # style.use('seaborn')
18
+
19
+ import csv
20
+
21
+ def save_database_info(filepath, database_name, date):
22
+ # 读取CSV文件
23
+ with open('/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/code_interpreter/test/database_name.csv', 'r', encoding='utf-8') as file:
24
+ # 创建CSV读取器
25
+ reader = csv.reader(file)
26
+
27
+ # 将内容存储到列表中
28
+ rows = []
29
+ for row in reader:
30
+ rows.append(row)
31
+
32
+ # 添加新行
33
+ # new_row = ['New Data 1', 'New Data 2'] # 新行的数据
34
+ new_row = [database_name, date] # 新行的数据
35
+ rows.append(new_row)
36
+
37
+ # 写入CSV文件
38
+ with open('/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/code_interpreter/test/database_name.csv', 'w', newline='', encoding='utf-8') as file:
39
+ # 创建CSV写入器
40
+ writer = csv.writer(file)
41
+ # 写入所有行
42
+ writer.writerows(rows)
43
+
44
+ # close the file to save the data.
45
+ file.close()
46
+
47
+ return None