gerasdf commited on
Commit
f053bac
·
1 Parent(s): aef92d9

added history saving and sessions

Browse files
Files changed (1) hide show
  1. query.py +90 -35
query.py CHANGED
@@ -1,11 +1,12 @@
1
  import gradio as gr
2
 
3
- from langchain_astradb import AstraDBVectorStore
4
-
5
  from langchain_core.prompts import ChatPromptTemplate
6
  from langchain_core.output_parsers import StrOutputParser
7
  from langchain_core.runnables import RunnablePassthrough, RunnableLambda
8
  from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
 
 
 
9
  from langchain_openai import OpenAIEmbeddings, ChatOpenAI
10
 
11
  from elevenlabs import VoiceSettings
@@ -98,17 +99,35 @@ def just_read(pipeline_state):
98
 
99
  def new_state():
100
  return gr.State({
101
- "user": None,
102
- "system": None,
 
103
  })
104
 
105
- def auth(token, state):
 
 
 
106
  tokens=os.environ.get("APP_TOKENS")
107
  if not tokens:
108
  state["user"] = "anonymous"
109
  else:
110
  tokens=json_loads(tokens)
111
  state["user"] = tokens.get(token, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  return "", state
113
 
114
  AUTH_JS = """function auth_js(token, state) {
@@ -126,30 +145,60 @@ def not_authenticated(state):
126
  gr.Warning("You need to authenticate first")
127
  return answer
128
 
 
 
 
 
 
 
 
 
 
129
 
130
- def chat(message, history, state):
 
 
 
 
 
 
 
 
131
  if not_authenticated(state):
132
  yield "You need to authenticate first"
133
- elif AI:
134
- if not history:
135
- system_prompt = prompt_chain.invoke(message)
136
- system_prompt = system_prompt.messages[0]
137
- state["system"] = system_prompt
138
- else:
139
- system_prompt = state["system"]
140
-
141
- messages = [system_prompt]
142
- for human, ai in history:
143
- messages.append(HumanMessage(human))
144
- messages.append(AIMessage(ai))
145
- messages.append(HumanMessage(message))
146
-
147
- all = ''
148
- for response in llm.stream(messages):
149
- all += response.content
150
- yield all
151
  else:
152
- yield f"{time.ctime()}: You said: {message}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  def on_audio(path, state):
155
  if not_authenticated(state):
@@ -194,6 +243,7 @@ def gr_main():
194
  theme=theme
195
  ) as app:
196
  state = new_state()
 
197
  chatbot = gr.Chatbot(show_label=False, render=False, scale=1)
198
  iface = gr.ChatInterface(
199
  chat,
@@ -225,6 +275,18 @@ def gr_main():
225
  show_label=False,
226
  format="mp3",
227
  waveform_options=gr.WaveformOptions(sample_rate=16000))
 
 
 
 
 
 
 
 
 
 
 
 
228
  mic.change(
229
  on_audio, [mic, state], [iface.textbox, mic]
230
  ).then(
@@ -232,16 +294,9 @@ def gr_main():
232
  js='function (text){if (text) document.getElementById("submit_btn").click(); return [text]}',
233
  inputs=iface.textbox
234
  )
235
-
236
- player = gr.Audio(
237
- show_label=False,
238
- show_download_button=True,
239
- visible=True,
240
- autoplay=True,
241
- streaming=True)
242
-
243
- play_btn = gr.Button("Play last ")
244
- play_btn.click(play_last, [chatbot, state], player)
245
 
246
 
247
  token = gr.Textbox(visible=False)
 
1
  import gradio as gr
2
 
 
 
3
  from langchain_core.prompts import ChatPromptTemplate
4
  from langchain_core.output_parsers import StrOutputParser
5
  from langchain_core.runnables import RunnablePassthrough, RunnableLambda
6
  from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
7
+
8
+ from langchain_astradb import AstraDBChatMessageHistory, AstraDBStore, AstraDBVectorStore
9
+
10
  from langchain_openai import OpenAIEmbeddings, ChatOpenAI
11
 
12
  from elevenlabs import VoiceSettings
 
99
 
100
  def new_state():
101
  return gr.State({
102
+ "user" : None,
103
+ "system" : None,
104
+ "history" : None,
105
  })
106
 
107
+ store = None
108
+ def auth(token, state, request: gr.Request):
109
+ global store
110
+
111
  tokens=os.environ.get("APP_TOKENS")
112
  if not tokens:
113
  state["user"] = "anonymous"
114
  else:
115
  tokens=json_loads(tokens)
116
  state["user"] = tokens.get(token, None)
117
+
118
+ if store is None:
119
+ store = AstraDBStore(
120
+ collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_sessions',
121
+ token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
122
+ api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
123
+ )
124
+ user_session = f'{state["user"]}_{request.session_hash}'
125
+ session_data = {
126
+ 'user' : state["user"],
127
+ 'session' : request.session_hash,
128
+ 'timestamp' : time.asctime(time.gmtime())
129
+ }
130
+ store.mset([(user_session, session_data)])
131
  return "", state
132
 
133
  AUTH_JS = """function auth_js(token, state) {
 
145
  gr.Warning("You need to authenticate first")
146
  return answer
147
 
148
+ def add_history(state, request, type, message):
149
+ if not state["history"]:
150
+ session = request.session_hash
151
+ state["history"] = AstraDBChatMessageHistory(
152
+ session_id=session,
153
+ collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_chat_history',
154
+ token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
155
+ api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
156
+ )
157
 
158
+ history = state["history"]
159
+ if type == "system":
160
+ history.add_message(message)
161
+ elif type == "user":
162
+ history.add_user_message(message)
163
+ elif type == "ai":
164
+ history.add_ai_message(message)
165
+
166
+ def chat(message, history, state, request:gr.Request):
167
  if not_authenticated(state):
168
  yield "You need to authenticate first"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  else:
170
+ if AI:
171
+ if not history:
172
+ system_prompt = prompt_chain.invoke(message)
173
+ system_prompt = system_prompt.messages[0]
174
+ state["system"] = system_prompt
175
+ # add_history(state, request, "system", system_prompt)
176
+ else:
177
+ system_prompt = state["system"]
178
+
179
+ add_history(state, request, "user", message)
180
+
181
+ messages = [system_prompt]
182
+ for human, ai in history:
183
+ messages.append(HumanMessage(human))
184
+ messages.append(AIMessage(ai))
185
+ messages.append(HumanMessage(message))
186
+
187
+ answer = ''
188
+ for response in llm.stream(messages):
189
+ answer += response.content
190
+ yield answer
191
+ else:
192
+ add_history(state, request, "user", message)
193
+
194
+ msg = f"{time.ctime()}: You said: {message}"
195
+ answer = ' '
196
+ for word in msg.split():
197
+ answer += f' {word}'
198
+ yield answer
199
+ time.sleep(0.05)
200
+
201
+ add_history(state, request, "ai", answer)
202
 
203
  def on_audio(path, state):
204
  if not_authenticated(state):
 
243
  theme=theme
244
  ) as app:
245
  state = new_state()
246
+ # auto_play = gr.Checkbox(False, label="Autoplay", render=False)
247
  chatbot = gr.Chatbot(show_label=False, render=False, scale=1)
248
  iface = gr.ChatInterface(
249
  chat,
 
275
  show_label=False,
276
  format="mp3",
277
  waveform_options=gr.WaveformOptions(sample_rate=16000))
278
+ player = gr.Audio(
279
+ show_label=False,
280
+ show_download_button=False,
281
+ show_share_button=False,
282
+ visible=True,
283
+ autoplay=True,
284
+ streaming=True,
285
+ interactive=False)
286
+ # with gr.Column():
287
+ # auto_play.render()
288
+ play_btn = gr.Button("Play last ")
289
+
290
  mic.change(
291
  on_audio, [mic, state], [iface.textbox, mic]
292
  ).then(
 
294
  js='function (text){if (text) document.getElementById("submit_btn").click(); return [text]}',
295
  inputs=iface.textbox
296
  )
297
+ play_btn.click(
298
+ play_last,
299
+ [chatbot, state], player)
 
 
 
 
 
 
 
300
 
301
 
302
  token = gr.Textbox(visible=False)