gerasdf committed
Commit 70798cd · 1 Parent(s): fc4c25e
lots of changes, history almost working (continuing an old chat not working, as SystemPrompt() won't be there)
query.py CHANGED
@@ -14,15 +14,24 @@ from elevenlabs.client import ElevenLabs
 from openai import OpenAI

 from json import loads as json_loads
+import itertools
 import time
 import os

-prompt_template = os.environ.get("PROMPT_TEMPLATE")
-
-prompt = ChatPromptTemplate.from_messages([('system', prompt_template)])
-
 AI = True

+if not hasattr(itertools, "batched"):
+    def batched(iterable, n):
+        "Batch data into lists of length n. The last batch may be shorter."
+        # batched('ABCDEFG', 3) --> ABC DEF G
+        it = iter(iterable)
+        while True:
+            batch = list(itertools.islice(it, n))
+            if not batch:
+                return
+            yield batch
+    itertools.batched = batched
+
 def ai_setup():
     global llm, prompt_chain, oai_client

@@ -38,17 +47,19 @@ def ai_setup():
         )

         retriever = vstore.as_retriever(search_kwargs={'k': 10})
+
+        prompt_template = os.environ.get("PROMPT_TEMPLATE")
+        prompt = ChatPromptTemplate.from_messages([('system', prompt_template)])
+        prompt_chain = (
+            {"context": retriever, "question": RunnablePassthrough()}
+            | RunnableLambda(format_context)
+            | prompt
+            # | llm
+            # | StrOutputParser()
+        )
     else:
         retriever = RunnableLambda(just_read)

-    prompt_chain = (
-        {"context": retriever, "question": RunnablePassthrough()}
-        | RunnableLambda(format_context)
-        | prompt
-        # | llm
-        # | StrOutputParser()
-    )
-
 def group_and_sort(documents):
     grouped = {}
     for document in documents:
@@ -107,32 +118,105 @@ def new_state():
 def session_id(state: dict, request: gr.Request) -> str:
     return f'{state["user"]}_{request.session_hash}'

-
-
-
+class History:
+    store = None
+    def __init__(self, name:str, user:str, session_id:str, id:str = None):
+        self.session_id = session_id
+        self.name = name
+        self.user = user
+        self.astra_history = None
+
+        if id:
+            self.id = id
+        else:
+            self.id = f"{user}_{session_id}"
+            self.create()
+
+    @classmethod
+    def get_store(self):
+        if self.store is None:
+            self.store = AstraDBStore(
+                collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_sessions',
+                token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
+                api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
+            )
+        return self.store
+
+    @classmethod
+    def from_dict(cls, id:str, data:dict):
+        name = f":{id}"
+        name = data.get("name", name)
+        answer = cls(name, user=data["user"], id = id, session_id=data["session"])
+
+        return answer
+
+    @classmethod
+    def get_histories(cls, user:str):
+        store = cls.get_store()
+        histories = []
+        keys = [k for k in store.yield_keys(prefix=f"{user}_")]
+        for id, history in zip(keys, store.mget(keys)):
+            history = cls.from_dict(id = id, data = history)
+            histories.append(history)
+        return histories
+
+    @classmethod
+    def load(cls, id:str):
+        data = cls.get_store().mget([id])
+        return cls.from_dict(id, data[0])
+
+    def __str__(self):
+        return f"{self.id}:{self.name}"
+
+    def create(self):
+        history = {
+            'session' : self.session_id,
+            'user' : self.user,
+            'timestamp' : time.asctime(time.gmtime()),
+            'name' : self.name
+        }
+        self.get_store().mset([(self.id, history)])
+
+    @staticmethod
+    def get_history_collection_name():
+        return f'{os.environ.get("ASTRA_DB_COLLECTION")}_chat_history'
+
+    def get_astra_history(self):
+        if self.astra_history is None:
+            self.astra_history = AstraDBChatMessageHistory(
+                session_id=self.id,
+                collection_name=self.get_history_collection_name(),
+                token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
+                api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
+            )
+        return self.astra_history
+
+    def add(self, type:str, message):
+        if type == "system":
+            self.get_astra_history().add_message(message)
+        elif type == "user":
+            self.get_astra_history().add_user_message(message)
+        elif type == "ai":
+            self.get_astra_history().add_ai_message(message)
+
+    def messages(self):
+        return self.get_astra_history().messages
+
+    def clear(self):
+        self.get_astra_history().clear()

+    def delete(self):
+        self.clear()
+        self.get_store().mdelete([self.id])
+
+def auth(token, state, request: gr.Request):
     tokens=os.environ.get("APP_TOKENS")
     if not tokens:
         state["user"] = "anonymous"
     else:
         tokens=json_loads(tokens)
         state["user"] = tokens.get(token, None)
-
-    if state["user"]:
-        if store is None:
-            store = AstraDBStore(
-                collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_sessions',
-                token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
-                api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
-            )
-        user_session = session_id(state, request)
-        session_data = {
-            'user' : state["user"],
-            'session' : request.session_hash,
-            'timestamp' : time.asctime(time.gmtime())
-        }
-        store.mset([(user_session, session_data)])
-
+
     return "", state

 AUTH_JS = """function auth_js(token, state) {
@@ -150,24 +234,40 @@ def not_authenticated(state):
     gr.Warning("You need to authenticate first")
     return answer

-def
+def list_histories(state):
+    if not_authenticated(state):
+        return gr.update()
+
+    histories = History.get_histories(state["user"])
+    answer = [(h.name, h.id) for h in histories]
+    return gr.update(choices=answer, value=None)
+
+def add_history(state, request, type, message, name:str = None):
     if not state["history"]:
-
-        state["history"] =
-
-
-
-            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
+        name = name or message[:60]
+        state["history"] = History(
+            name = name,
+            user = state["user"],
+            session_id = request.session_hash
         )
-
-    history = state["history"]
-    if type == "system":
-        history.add_message(message)
-    elif type == "user":
-        history.add_user_message(message)
-    elif type == "ai":
-        history.add_ai_message(message)

+    state["history"].add(type, message)
+
+def load_history(state, history_id):
+    state["history"] = History.load(history_id)
+
+    history = [m.content for m in state["history"].messages()]
+    history = itertools.batched(history, 2)
+    history = [m for m in history]
+
+    if len(history) and len(history[-1]) == 1:
+        user_input = history[-1][0]
+        history = history[:-1]
+    else:
+        user_input = ''
+
+    return state, history, history, user_input # state, Chatbot, ChatInterface.state, ChatInterface.textbox
+
 def chat(message, history, state, request:gr.Request):
     if not_authenticated(state):
         yield "You need to authenticate first"
@@ -177,7 +277,9 @@ def chat(message, history, state, request:gr.Request):
         system_prompt = prompt_chain.invoke(message)
         system_prompt = system_prompt.messages[0]
         state["system"] = system_prompt
-
+
+        # Next is commented out because astra has a limit on document size
+        # add_history(state, request, "system", system_prompt, name=message)
     else:
         system_prompt = state["system"]

@@ -235,14 +337,13 @@ def play_last(history, state):
     response = lab11.generate(text=text, voice=whatson, stream=True)
     yield from response

-
-def chat_chage(history):
+def chat_change(history):
     if history:
         if not history[-1][1]:
             return gr.update(interactive=False)
         elif history[-1][1][-1] != '…':
             return gr.update(interactive=True)
-    return gr.update()
+    return gr.update() # play_last_btn

 TEXT_TALK = "🎤 Talk"
 TEXT_STOP = "⏹ Stop"
@@ -250,8 +351,9 @@ TEXT_STOP = "⏹ Stop"
 def gr_main():
     theme = gr.Theme.from_hub("freddyaboulton/[email protected]")
     theme.set(
-        color_accent_soft="#818eb6", # ChatBot.svelte / .message-row.panel.user-row
-        background_fill_secondary="#6272a4", # ChatBot.svelte / .message-row.panel.bot-row
+        color_accent_soft="#818eb6", # ChatBot.svelte / .user / .message-row.panel.user-row . neutral_500 -> neutral_200
+        background_fill_secondary="#6272a4", # ChatBot.svelte / .bot / .message-row.panel.bot-row . neutral_500 -> neutral_400
+        background_fill_primary="#818eb6", # DropdownOptions.svelte / item
         button_primary_text_color="*button_secondary_text_color",
         button_primary_background_fill="*button_secondary_background_fill")

@@ -262,12 +364,20 @@ def gr_main():
         css="footer {visibility: hidden}"
     ) as app:
         state = new_state()
-        # auto_play = gr.Checkbox(False, label="Autoplay", render=False)
         chatbot = gr.Chatbot(show_label=False, render=False, scale=1)
+        gr.HTML('<h1 style="text-align: center">Sherlock Holmes stories</h1>')
+        history_choice = gr.Dropdown(
+            choices=[("History", "History")],
+            value="History",
+            show_label=False,
+            container=False,
+            interactive=True,
+            filterable=True)
+
         iface = gr.ChatInterface(
             chat,
             chatbot=chatbot,
-            title=
+            title=None,
             submit_btn=gr.Button(
                 "Send",
                 variant="primary",
@@ -313,7 +423,7 @@ def gr_main():
             play_last,
             [chatbot, state], player)

-        chatbot.change(
+        chatbot.change(chat_change, inputs=chatbot, outputs=play_last_btn)
         start_stop_rec.click(
             lambda x:x,
             inputs=start_stop_rec,
@@ -335,6 +445,16 @@ def gr_main():
             inputs=iface.textbox,
             js='function (text){if (text) document.getElementById("submit_btn").click(); return [text]}'
         )
+
+        history_choice.focus(
+            list_histories,
+            inputs=state,
+            outputs=history_choice
+        )
+        history_choice.input(
+            load_history,
+            inputs=[state, history_choice],
+            outputs=[state, chatbot, iface.chatbot_state, iface.textbox])

         token = gr.Textbox(visible=False)

@@ -344,9 +464,9 @@ def gr_main():
         js=AUTH_JS)

     app.queue(default_concurrency_limit=None, api_open=False)
-    app
+    return app

 if __name__ == "__main__":
     ai_setup()
-    gr_main()
-
+    app = gr_main()
+    app.launch(show_api=False)
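For reference, a minimal standalone sketch (not part of the commit) of the message-pairing logic that the new load_history relies on, reusing the itertools.batched backport added at the top of query.py; the sample messages below are invented for illustration only:

import itertools

# Backport of itertools.batched for Python < 3.12, mirroring the one in query.py.
if not hasattr(itertools, "batched"):
    def batched(iterable, n):
        "Batch data into lists of length n. The last batch may be shorter."
        it = iter(iterable)
        while True:
            batch = list(itertools.islice(it, n))
            if not batch:
                return
            yield batch
    itertools.batched = batched

# Hypothetical flat message list as it comes back from the stored chat history
# (alternating user / AI turns, possibly ending on an unanswered user turn).
messages = ["Who is Watson?", "Dr. John Watson is Holmes's friend and chronicler.", "And Moriarty?"]

pairs = [list(p) for p in itertools.batched(messages, 2)]

# A trailing one-element batch means the last user input has no answer yet,
# so load_history feeds it back into the textbox instead of the chatbot.
if pairs and len(pairs[-1]) == 1:
    pending_input = pairs[-1][0]
    pairs = pairs[:-1]
else:
    pending_input = ""

print(pairs)          # [['Who is Watson?', "Dr. John Watson is Holmes's friend and chronicler."]]
print(pending_input)  # And Moriarty?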