gerasdf commited on
Commit
70798cd
·
1 Parent(s): fc4c25e

lost of changes, history almost working (continuing an old chat not working, as SystemPrompt() won't be there)

Browse files
Files changed (1) hide show
  1. query.py +178 -58
query.py CHANGED
@@ -14,15 +14,24 @@ from elevenlabs.client import ElevenLabs
14
  from openai import OpenAI
15
 
16
  from json import loads as json_loads
 
17
  import time
18
  import os
19
 
20
- prompt_template = os.environ.get("PROMPT_TEMPLATE")
21
-
22
- prompt = ChatPromptTemplate.from_messages([('system', prompt_template)])
23
-
24
  AI = True
25
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def ai_setup():
27
  global llm, prompt_chain, oai_client
28
 
@@ -38,17 +47,19 @@ def ai_setup():
38
  )
39
 
40
  retriever = vstore.as_retriever(search_kwargs={'k': 10})
 
 
 
 
 
 
 
 
 
 
41
  else:
42
  retriever = RunnableLambda(just_read)
43
 
44
- prompt_chain = (
45
- {"context": retriever, "question": RunnablePassthrough()}
46
- | RunnableLambda(format_context)
47
- | prompt
48
- # | llm
49
- # | StrOutputParser()
50
- )
51
-
52
  def group_and_sort(documents):
53
  grouped = {}
54
  for document in documents:
@@ -107,32 +118,105 @@ def new_state():
107
  def session_id(state: dict, request: gr.Request) -> str:
108
  return f'{state["user"]}_{request.session_hash}'
109
 
110
- store = None
111
- def auth(token, state, request: gr.Request):
112
- global store
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
 
 
 
 
 
114
  tokens=os.environ.get("APP_TOKENS")
115
  if not tokens:
116
  state["user"] = "anonymous"
117
  else:
118
  tokens=json_loads(tokens)
119
  state["user"] = tokens.get(token, None)
120
-
121
- if state["user"]:
122
- if store is None:
123
- store = AstraDBStore(
124
- collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_sessions',
125
- token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
126
- api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
127
- )
128
- user_session = session_id(state, request)
129
- session_data = {
130
- 'user' : state["user"],
131
- 'session' : request.session_hash,
132
- 'timestamp' : time.asctime(time.gmtime())
133
- }
134
- store.mset([(user_session, session_data)])
135
-
136
  return "", state
137
 
138
  AUTH_JS = """function auth_js(token, state) {
@@ -150,24 +234,40 @@ def not_authenticated(state):
150
  gr.Warning("You need to authenticate first")
151
  return answer
152
 
153
- def add_history(state, request, type, message):
 
 
 
 
 
 
 
 
154
  if not state["history"]:
155
- session = session_id(state, request)
156
- state["history"] = AstraDBChatMessageHistory(
157
- session_id=session,
158
- collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_chat_history',
159
- token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
160
- api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
161
  )
162
-
163
- history = state["history"]
164
- if type == "system":
165
- history.add_message(message)
166
- elif type == "user":
167
- history.add_user_message(message)
168
- elif type == "ai":
169
- history.add_ai_message(message)
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def chat(message, history, state, request:gr.Request):
172
  if not_authenticated(state):
173
  yield "You need to authenticate first"
@@ -177,7 +277,9 @@ def chat(message, history, state, request:gr.Request):
177
  system_prompt = prompt_chain.invoke(message)
178
  system_prompt = system_prompt.messages[0]
179
  state["system"] = system_prompt
180
- # add_history(state, request, "system", system_prompt)
 
 
181
  else:
182
  system_prompt = state["system"]
183
 
@@ -235,14 +337,13 @@ def play_last(history, state):
235
  response = lab11.generate(text=text, voice=whatson, stream=True)
236
  yield from response
237
 
238
-
239
- def chat_chage(history):
240
  if history:
241
  if not history[-1][1]:
242
  return gr.update(interactive=False)
243
  elif history[-1][1][-1] != '…':
244
  return gr.update(interactive=True)
245
- return gr.update()
246
 
247
  TEXT_TALK = "🎤 Talk"
248
  TEXT_STOP = "⏹ Stop"
@@ -250,8 +351,9 @@ TEXT_STOP = "⏹ Stop"
250
  def gr_main():
251
  theme = gr.Theme.from_hub("freddyaboulton/[email protected]")
252
  theme.set(
253
- color_accent_soft="#818eb6", # ChatBot.svelte / .message-row.panel.user-row
254
- background_fill_secondary="#6272a4", # ChatBot.svelte / .message-row.panel.bot-row
 
255
  button_primary_text_color="*button_secondary_text_color",
256
  button_primary_background_fill="*button_secondary_background_fill")
257
 
@@ -262,12 +364,20 @@ def gr_main():
262
  css="footer {visibility: hidden}"
263
  ) as app:
264
  state = new_state()
265
- # auto_play = gr.Checkbox(False, label="Autoplay", render=False)
266
  chatbot = gr.Chatbot(show_label=False, render=False, scale=1)
 
 
 
 
 
 
 
 
 
267
  iface = gr.ChatInterface(
268
  chat,
269
  chatbot=chatbot,
270
- title="Sherlock Holmes stories",
271
  submit_btn=gr.Button(
272
  "Send",
273
  variant="primary",
@@ -313,7 +423,7 @@ def gr_main():
313
  play_last,
314
  [chatbot, state], player)
315
 
316
- chatbot.change(chat_chage, inputs=chatbot, outputs=play_last_btn)
317
  start_stop_rec.click(
318
  lambda x:x,
319
  inputs=start_stop_rec,
@@ -335,6 +445,16 @@ def gr_main():
335
  inputs=iface.textbox,
336
  js='function (text){if (text) document.getElementById("submit_btn").click(); return [text]}'
337
  )
 
 
 
 
 
 
 
 
 
 
338
 
339
  token = gr.Textbox(visible=False)
340
 
@@ -344,9 +464,9 @@ def gr_main():
344
  js=AUTH_JS)
345
 
346
  app.queue(default_concurrency_limit=None, api_open=False)
347
- app.launch(show_api=False)
348
 
349
  if __name__ == "__main__":
350
  ai_setup()
351
- gr_main()
352
-
 
14
  from openai import OpenAI
15
 
16
  from json import loads as json_loads
17
+ import itertools
18
  import time
19
  import os
20
 
 
 
 
 
21
  AI = True
22
 
23
+ if not hasattr(itertools, "batched"):
24
+ def batched(iterable, n):
25
+ "Batch data into lists of length n. The last batch may be shorter."
26
+ # batched('ABCDEFG', 3) --> ABC DEF G
27
+ it = iter(iterable)
28
+ while True:
29
+ batch = list(itertools.islice(it, n))
30
+ if not batch:
31
+ return
32
+ yield batch
33
+ itertools.batched = batched
34
+
35
  def ai_setup():
36
  global llm, prompt_chain, oai_client
37
 
 
47
  )
48
 
49
  retriever = vstore.as_retriever(search_kwargs={'k': 10})
50
+
51
+ prompt_template = os.environ.get("PROMPT_TEMPLATE")
52
+ prompt = ChatPromptTemplate.from_messages([('system', prompt_template)])
53
+ prompt_chain = (
54
+ {"context": retriever, "question": RunnablePassthrough()}
55
+ | RunnableLambda(format_context)
56
+ | prompt
57
+ # | llm
58
+ # | StrOutputParser()
59
+ )
60
  else:
61
  retriever = RunnableLambda(just_read)
62
 
 
 
 
 
 
 
 
 
63
  def group_and_sort(documents):
64
  grouped = {}
65
  for document in documents:
 
118
  def session_id(state: dict, request: gr.Request) -> str:
119
  return f'{state["user"]}_{request.session_hash}'
120
 
121
+ class History:
122
+ store = None
123
+ def __init__(self, name:str, user:str, session_id:str, id:str = None):
124
+ self.session_id = session_id
125
+ self.name = name
126
+ self.user = user
127
+ self.astra_history = None
128
+
129
+ if id:
130
+ self.id = id
131
+ else:
132
+ self.id = f"{user}_{session_id}"
133
+ self.create()
134
+
135
+ @classmethod
136
+ def get_store(self):
137
+ if self.store is None:
138
+ self.store = AstraDBStore(
139
+ collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_sessions',
140
+ token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
141
+ api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
142
+ )
143
+ return self.store
144
+
145
+ @classmethod
146
+ def from_dict(cls, id:str, data:dict):
147
+ name = f":{id}"
148
+ name = data.get("name", name)
149
+ answer = cls(name, user=data["user"], id = id, session_id=data["session"])
150
+
151
+ return answer
152
+
153
+ @classmethod
154
+ def get_histories(cls, user:str):
155
+ store = cls.get_store()
156
+ histories = []
157
+ keys = [k for k in store.yield_keys(prefix=f"{user}_")]
158
+ for id, history in zip(keys, store.mget(keys)):
159
+ history = cls.from_dict(id = id, data = history)
160
+ histories.append(history)
161
+ return histories
162
+
163
+ @classmethod
164
+ def load(cls, id:str):
165
+ data = cls.get_store().mget([id])
166
+ return cls.from_dict(id, data[0])
167
+
168
+ def __str__(self):
169
+ return f"{self.id}:{self.name}"
170
+
171
+ def create(self):
172
+ history = {
173
+ 'session' : self.session_id,
174
+ 'user' : self.user,
175
+ 'timestamp' : time.asctime(time.gmtime()),
176
+ 'name' : self.name
177
+ }
178
+ self.get_store().mset([(self.id, history)])
179
+
180
+ @staticmethod
181
+ def get_history_collection_name():
182
+ return f'{os.environ.get("ASTRA_DB_COLLECTION")}_chat_history'
183
+
184
+ def get_astra_history(self):
185
+ if self.astra_history is None:
186
+ self.astra_history = AstraDBChatMessageHistory(
187
+ session_id=self.id,
188
+ collection_name=self.get_history_collection_name(),
189
+ token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
190
+ api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
191
+ )
192
+ return self.astra_history
193
+
194
+ def add(self, type:str, message):
195
+ if type == "system":
196
+ self.get_astra_history().add_message(message)
197
+ elif type == "user":
198
+ self.get_astra_history().add_user_message(message)
199
+ elif type == "ai":
200
+ self.get_astra_history().add_ai_message(message)
201
+
202
+ def messages(self):
203
+ return self.get_astra_history().messages
204
+
205
+ def clear(self):
206
+ self.get_astra_history().clear()
207
 
208
+ def delete(self):
209
+ self.clear()
210
+ self.get_store().mdelete([self.id])
211
+
212
+ def auth(token, state, request: gr.Request):
213
  tokens=os.environ.get("APP_TOKENS")
214
  if not tokens:
215
  state["user"] = "anonymous"
216
  else:
217
  tokens=json_loads(tokens)
218
  state["user"] = tokens.get(token, None)
219
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  return "", state
221
 
222
  AUTH_JS = """function auth_js(token, state) {
 
234
  gr.Warning("You need to authenticate first")
235
  return answer
236
 
237
+ def list_histories(state):
238
+ if not_authenticated(state):
239
+ return gr.update()
240
+
241
+ histories = History.get_histories(state["user"])
242
+ answer = [(h.name, h.id) for h in histories]
243
+ return gr.update(choices=answer, value=None)
244
+
245
+ def add_history(state, request, type, message, name:str = None):
246
  if not state["history"]:
247
+ name = name or message[:60]
248
+ state["history"] = History(
249
+ name = name,
250
+ user = state["user"],
251
+ session_id = request.session_hash
 
252
  )
 
 
 
 
 
 
 
 
253
 
254
+ state["history"].add(type, message)
255
+
256
+ def load_history(state, history_id):
257
+ state["history"] = History.load(history_id)
258
+
259
+ history = [m.content for m in state["history"].messages()]
260
+ history = itertools.batched(history, 2)
261
+ history = [m for m in history]
262
+
263
+ if len(history) and len(history[-1]) == 1:
264
+ user_input = history[-1][0]
265
+ history = history[:-1]
266
+ else:
267
+ user_input = ''
268
+
269
+ return state, history, history, user_input # state, Chatbot, ChatInterface.state, ChatInterface.textbox
270
+
271
  def chat(message, history, state, request:gr.Request):
272
  if not_authenticated(state):
273
  yield "You need to authenticate first"
 
277
  system_prompt = prompt_chain.invoke(message)
278
  system_prompt = system_prompt.messages[0]
279
  state["system"] = system_prompt
280
+
281
+ # Next is commented out because astra has a limit on document size
282
+ # add_history(state, request, "system", system_prompt, name=message)
283
  else:
284
  system_prompt = state["system"]
285
 
 
337
  response = lab11.generate(text=text, voice=whatson, stream=True)
338
  yield from response
339
 
340
+ def chat_change(history):
 
341
  if history:
342
  if not history[-1][1]:
343
  return gr.update(interactive=False)
344
  elif history[-1][1][-1] != '…':
345
  return gr.update(interactive=True)
346
+ return gr.update() # play_last_btn
347
 
348
  TEXT_TALK = "🎤 Talk"
349
  TEXT_STOP = "⏹ Stop"
 
351
  def gr_main():
352
  theme = gr.Theme.from_hub("freddyaboulton/[email protected]")
353
  theme.set(
354
+ color_accent_soft="#818eb6", # ChatBot.svelte / .user / .message-row.panel.user-row . neutral_500 -> neutral_200
355
+ background_fill_secondary="#6272a4", # ChatBot.svelte / .bot / .message-row.panel.bot-row . neutral_500 -> neutral_400
356
+ background_fill_primary="#818eb6", # DropdownOptions.svelte / item
357
  button_primary_text_color="*button_secondary_text_color",
358
  button_primary_background_fill="*button_secondary_background_fill")
359
 
 
364
  css="footer {visibility: hidden}"
365
  ) as app:
366
  state = new_state()
 
367
  chatbot = gr.Chatbot(show_label=False, render=False, scale=1)
368
+ gr.HTML('<h1 style="text-align: center">Sherlock Holmes stories</h1>')
369
+ history_choice = gr.Dropdown(
370
+ choices=[("History", "History")],
371
+ value="History",
372
+ show_label=False,
373
+ container=False,
374
+ interactive=True,
375
+ filterable=True)
376
+
377
  iface = gr.ChatInterface(
378
  chat,
379
  chatbot=chatbot,
380
+ title=None,
381
  submit_btn=gr.Button(
382
  "Send",
383
  variant="primary",
 
423
  play_last,
424
  [chatbot, state], player)
425
 
426
+ chatbot.change(chat_change, inputs=chatbot, outputs=play_last_btn)
427
  start_stop_rec.click(
428
  lambda x:x,
429
  inputs=start_stop_rec,
 
445
  inputs=iface.textbox,
446
  js='function (text){if (text) document.getElementById("submit_btn").click(); return [text]}'
447
  )
448
+
449
+ history_choice.focus(
450
+ list_histories,
451
+ inputs=state,
452
+ outputs=history_choice
453
+ )
454
+ history_choice.input(
455
+ load_history,
456
+ inputs=[state, history_choice],
457
+ outputs=[state, chatbot, iface.chatbot_state, iface.textbox])
458
 
459
  token = gr.Textbox(visible=False)
460
 
 
464
  js=AUTH_JS)
465
 
466
  app.queue(default_concurrency_limit=None, api_open=False)
467
+ return app
468
 
469
  if __name__ == "__main__":
470
  ai_setup()
471
+ app = gr_main()
472
+ app.launch(show_api=False)