Spaces:
Paused
Paused
Commit
·
b8484c0
1
Parent(s):
785e85d
add the google sheet connection
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
|
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import datetime
|
|
|
5 |
import pickle
|
6 |
import os
|
7 |
import csv
|
@@ -42,12 +43,6 @@ def load_scraped_web_info():
|
|
42 |
chunked_text = text_splitter.create_documents([doc for doc in tqdm(ait_web_documents)])
|
43 |
|
44 |
|
45 |
-
# st.markdown(f"Number of Documents: {len(ait_web_documents)}")
|
46 |
-
# st.markdown(f"Number of chunked texts: {len(chunked_text)}")
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
@st.cache_resource
|
52 |
def load_embedding_model():
|
53 |
embedding_model = HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-base',
|
@@ -74,8 +69,6 @@ def load_llm_model():
|
|
74 |
model_kwargs={ "max_length": 256, "temperature": 0,
|
75 |
"torch_dtype":torch.float32,
|
76 |
"repetition_penalty": 1.3})
|
77 |
-
|
78 |
-
|
79 |
return llm
|
80 |
|
81 |
|
@@ -85,6 +78,45 @@ def load_retriever(llm, db):
|
|
85 |
|
86 |
return qa_retriever
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
#--------------
|
89 |
|
90 |
|
@@ -92,9 +124,14 @@ if "history" not in st.session_state:
|
|
92 |
st.session_state.history = []
|
93 |
if "session_rating" not in st.session_state:
|
94 |
st.session_state.session_rating = 0
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
|
100 |
load_scraped_web_info()
|
@@ -108,20 +145,7 @@ print("all load done")
|
|
108 |
|
109 |
|
110 |
|
111 |
-
def retrieve_document(query_input):
|
112 |
-
related_doc = vector_database.similarity_search(query_input)
|
113 |
-
return related_doc
|
114 |
|
115 |
-
def retrieve_answer(query_input):
|
116 |
-
prompt_answer= query_input + " " + "Try to elaborate as much as you can."
|
117 |
-
answer = qa_retriever.run(prompt_answer)
|
118 |
-
output = st.text_area(label="Retrieved documents", value=answer)
|
119 |
-
|
120 |
-
st.markdown('---')
|
121 |
-
score = st.radio(label = 'please select the overall satifaction and helpfullness of the bot answer', options=[1,2,3,4,5], horizontal=True,
|
122 |
-
on_change=update_score, key='rating')
|
123 |
-
|
124 |
-
return answer
|
125 |
|
126 |
|
127 |
|
@@ -134,10 +158,13 @@ st.markdown("""
|
|
134 |
st.write(' ⚠️ Please expect to wait **~ 10 - 20 seconds per question** as thi app is running on CPU against 3-billion-parameter LLM')
|
135 |
|
136 |
st.markdown("---")
|
|
|
|
|
|
|
|
|
137 |
|
138 |
-
|
139 |
-
|
140 |
-
generate_button = st.button(label = 'Submit!')
|
141 |
|
142 |
if generate_button:
|
143 |
answer = retrieve_answer(query_input)
|
@@ -147,21 +174,36 @@ if generate_button:
|
|
147 |
"rating":st.session_state.session_rating }
|
148 |
|
149 |
st.session_state.history.append(log)
|
|
|
150 |
|
151 |
|
|
|
|
|
152 |
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import datetime
|
5 |
+
import gspread
|
6 |
import pickle
|
7 |
import os
|
8 |
import csv
|
|
|
43 |
chunked_text = text_splitter.create_documents([doc for doc in tqdm(ait_web_documents)])
|
44 |
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
@st.cache_resource
|
47 |
def load_embedding_model():
|
48 |
embedding_model = HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-base',
|
|
|
69 |
model_kwargs={ "max_length": 256, "temperature": 0,
|
70 |
"torch_dtype":torch.float32,
|
71 |
"repetition_penalty": 1.3})
|
|
|
|
|
72 |
return llm
|
73 |
|
74 |
|
|
|
78 |
|
79 |
return qa_retriever
|
80 |
|
81 |
+
def retrieve_document(query_input):
|
82 |
+
related_doc = vector_database.similarity_search(query_input)
|
83 |
+
return related_doc
|
84 |
+
|
85 |
+
def retrieve_answer(query_input):
|
86 |
+
prompt_answer= query_input + " " + "Try to elaborate as much as you can."
|
87 |
+
answer = qa_retriever.run(prompt_answer)
|
88 |
+
output = st.text_area(label="Retrieved documents", value=answer)
|
89 |
+
|
90 |
+
st.markdown('---')
|
91 |
+
score = st.radio(label = 'please select the rating score for overall satifaction and helpfullness of the bot answer', options=[0, 1,2,3,4,5], horizontal=True,
|
92 |
+
on_change=update_worksheet_qa, key='rating')
|
93 |
+
|
94 |
+
return answer
|
95 |
+
|
96 |
+
# def update_score():
|
97 |
+
# st.session_state.session_rating = st.session_state.rating
|
98 |
+
|
99 |
+
|
100 |
+
def update_worksheet_qa():
|
101 |
+
st.session_state.session_rating = st.session_state.rating
|
102 |
+
#This if helps validate the initiated rating, if 0, then the google sheet would not be updated
|
103 |
+
if st.session_state.session_rating == 0:
|
104 |
+
pass
|
105 |
+
else:
|
106 |
+
worksheet_qa.append_row([st.session_state.history[-1]['timestamp'].strftime(datetime_format),
|
107 |
+
st.session_state.history[-1]['question'],
|
108 |
+
st.session_state.history[-1]['generated_answer'],
|
109 |
+
st.session_state.session_rating
|
110 |
+
])
|
111 |
+
|
112 |
+
def update_worksheet_comment():
|
113 |
+
worksheet_comment.append_row([datetime.datetime.now().strftime(datetime_format),
|
114 |
+
feedback_input])
|
115 |
+
success_message = st.success('Feedback successfully submitted, thank you', icon="✅",
|
116 |
+
)
|
117 |
+
time.sleep(3)
|
118 |
+
success_message.empty()
|
119 |
+
|
120 |
#--------------
|
121 |
|
122 |
|
|
|
124 |
st.session_state.history = []
|
125 |
if "session_rating" not in st.session_state:
|
126 |
st.session_state.session_rating = 0
|
127 |
+
|
128 |
+
|
129 |
+
service_account = gspread.service_account_from_dict(credential)
|
130 |
+
workbook= service_account.open("aitGPT-qa-log")
|
131 |
+
worksheet_qa = workbook.worksheet("Sheet1")
|
132 |
+
worksheet_comment = workbook.worksheet("Sheet2")
|
133 |
+
datetime_format= "%Y-%m-%d %H:%M:%S"
|
134 |
+
|
135 |
|
136 |
|
137 |
load_scraped_web_info()
|
|
|
145 |
|
146 |
|
147 |
|
|
|
|
|
|
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
|
151 |
|
|
|
158 |
st.write(' ⚠️ Please expect to wait **~ 10 - 20 seconds per question** as thi app is running on CPU against 3-billion-parameter LLM')
|
159 |
|
160 |
st.markdown("---")
|
161 |
+
st.write(" ")
|
162 |
+
st.write("""
|
163 |
+
### ❔ Ask a question
|
164 |
+
""")
|
165 |
|
166 |
+
query_input = st.text_area(label= 'What would you like to know about AIT?' , key = 'my_text_input')
|
167 |
+
generate_button = st.button(label = 'Ask question!')
|
|
|
168 |
|
169 |
if generate_button:
|
170 |
answer = retrieve_answer(query_input)
|
|
|
174 |
"rating":st.session_state.session_rating }
|
175 |
|
176 |
st.session_state.history.append(log)
|
177 |
+
update_worksheet_qa()
|
178 |
|
179 |
|
180 |
+
st.write(" ")
|
181 |
+
st.write(" ")
|
182 |
|
183 |
+
st.markdown("---")
|
184 |
+
st.write("""
|
185 |
+
### 💌 Your voice matters
|
186 |
+
""")
|
187 |
+
|
188 |
+
feedback_input = st.text_area(label= 'please leave your feedback or any ideas to make this bot more knowledgeable and fun')
|
189 |
+
feedback_button = st.button(label = 'Submit feedback!')
|
190 |
+
|
191 |
+
if feedback_button:
|
192 |
+
update_worksheet_comment()
|
193 |
+
|
194 |
+
|
195 |
+
# if st.session_state.session_rating == 0:
|
196 |
+
# pass
|
197 |
+
# else:
|
198 |
+
# with open('test_db', 'a') as csvfile:
|
199 |
+
# writer = csv.writer(csvfile)
|
200 |
+
# writer.writerow([st.session_state.history[-1]['timestamp'], st.session_state.history[-1]['question'],
|
201 |
+
# st.session_state.history[-1]['generated_answer'], st.session_state.session_rating ])
|
202 |
+
# st.session_state.session_rating = 0
|
203 |
+
|
204 |
+
# test_df = pd.read_csv("test_db", index_col=0)
|
205 |
+
# test_df.sort_values(by = ['timestamp'],
|
206 |
+
# axis=0,
|
207 |
+
# ascending=False,
|
208 |
+
# inplace=True)
|
209 |
+
# st.dataframe(test_df)
|