Pijush2023 committed
Commit c2d26c1 · verified · 1 Parent(s): 9055762

Update app.py

Files changed (1):
  1. app.py +355 -39
app.py CHANGED
@@ -5,37 +5,72 @@ from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from langchain_openai import ChatOpenAI
 from langchain_community.graphs import Neo4jGraph
-from typing import List
 from pydantic import BaseModel, Field
-from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
 import requests
 import tempfile
 import torch
 import numpy as np

-# Setup logging to a file to capture debug information
-logging.basicConfig(filename='neo4j_retrieval.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

-# Setup Neo4j connection
 graph = Neo4jGraph(
     url="neo4j+s://c62d0d35.databases.neo4j.io",
     username="neo4j",
     password="_x8f-_aAQvs2NB0x6s0ZHSh3W_y-HrENDbgStvsUCM0"
 )

 # Define entity extraction and retrieval functions
 class Entities(BaseModel):
     names: List[str] = Field(
         ..., description="All the person, organization, or business entities that appear in the text"
     )

-# Define prompt and model for entity extraction
-chat_model = ChatOpenAI(temperature=0, model_name="gpt-4", api_key=os.environ['OPENAI_API_KEY'])
-entity_prompt = ChatPromptTemplate.from_messages([
     ("system", "You are extracting organization and person entities from the text."),
     ("human", "Use the given format to extract information from the following input: {question}"),
 ])
-entity_chain = entity_prompt | chat_model.with_structured_output(Entities)

 def remove_lucene_chars(input: str) -> str:
     return input.translate(str.maketrans({
@@ -53,7 +88,10 @@ def generate_full_text_query(input: str) -> str:
     full_text_query += f" {words[-1]}~2"
     return full_text_query.strip()

-def retrieve_data_from_neo4j(question: str) -> str:
     result = ""
     entities = entity_chain.invoke({"question": question})
     for entity in entities.names:
@@ -64,6 +102,10 @@ def retrieve_data_from_neo4j(question: str) -> str:
               WITH node
               MATCH (node)-[r:!MENTIONS]->(neighbor)
               RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
             }
             RETURN output LIMIT 50
             """,
@@ -72,23 +114,188 @@ def retrieve_data_from_neo4j(question: str) -> str:
         result += "\n".join([el['output'] for el in response])
     return result

 # Function to generate audio with Eleven Labs TTS
 def generate_audio_elevenlabs(text):
     XI_API_KEY = os.environ['ELEVENLABS_API']
     VOICE_ID = 'ehbJzYLQFpwbJmGkqbnW'
     tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}/stream"
-    headers = {"Accept": "application/json", "xi-api-key": XI_API_KEY}
-    data = {"text": str(text), "model_id": "eleven_multilingual_v2", "voice_settings": {"stability": 1.0}}
     response = requests.post(tts_url, headers=headers, json=data, stream=True)
     if response.ok:
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
             for chunk in response.iter_content(chunk_size=1024):
                 if chunk:
                     f.write(chunk)
-        return f.name
     return None

-# ASR model setup using Whisper
 model_id = 'openai/whisper-large-v3'
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@@ -104,37 +311,146 @@ pipe_asr = pipeline(
     chunk_length_s=15,
     batch_size=16,
     torch_dtype=torch_dtype,
-    device=device
 )

-# Function to handle audio input, transcription, and Neo4j response generation
-def transcribe_and_respond(audio):
-    # Transcribe audio input
-    audio_data = {"array": audio["data"], "sampling_rate": audio["sample_rate"]}
-    transcription = pipe_asr(audio_data)["text"]
-    logging.debug(f"Transcription: {transcription}")

-    # Retrieve data from Neo4j based on transcription
-    response_text = retrieve_data_from_neo4j(transcription)
-    logging.debug(f"Neo4j Response: {response_text}")

-    # Convert response to audio
-    return generate_audio_elevenlabs(response_text)

-# Define Gradio interface
-with gr.Blocks() as demo:
-    audio_input = gr.Audio(sources="microphone", type="numpy", streaming=True, label="Speak to Ask")  # Removed streaming mode for manual submission
-    audio_output = gr.Audio(label="Response", type="filepath", autoplay=True, interactive=False)

-    # "Submit Audio" button
-    submit_button = gr.Button("Submit Audio")

-    # Link the button to trigger response generation after clicking
-    submit_button.click(
-        fn=transcribe_and_respond,
-        inputs=audio_input,
-        outputs=audio_output
     )

-    # Launch Gradio interface
-    demo.launch(show_error=True, share=True)

@@ -5,37 +5,72 @@ from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from langchain_openai import ChatOpenAI
 from langchain_community.graphs import Neo4jGraph
+from typing import List, Tuple
 from pydantic import BaseModel, Field
+from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.runnables import (
+    RunnableBranch,
+    RunnableLambda,
+    RunnablePassthrough,
+    RunnableParallel,
+)
+from langchain_core.prompts.prompt import PromptTemplate
 import requests
 import tempfile
+from langchain.memory import ConversationBufferWindowMemory
+import time
+import logging
+from langchain.chains import ConversationChain
 import torch
+import torchaudio
+from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
 import numpy as np
+import threading
+from langchain_community.vectorstores import Neo4jVector
+from langchain_openai import OpenAIEmbeddings


+#code for history
+conversational_memory = ConversationBufferWindowMemory(
+    memory_key='chat_history',
+    k=10,
+    return_messages=True
+)
+
+# Setup Neo4j
 graph = Neo4jGraph(
     url="neo4j+s://c62d0d35.databases.neo4j.io",
     username="neo4j",
     password="_x8f-_aAQvs2NB0x6s0ZHSh3W_y-HrENDbgStvsUCM0"
 )

+
+# directly show the graph resulting from the given Cypher query
+default_cypher = "MATCH (s)-[r:!MENTIONS]->(t) RETURN s,r,t LIMIT 50"
+
+
+vector_index = Neo4jVector.from_existing_graph(
+    OpenAIEmbeddings(openai_api_key="sk-PV6RlpmTifrWo_olwL1IR69J9v2e5AKe-Xfxs_Yf9VT3BlbkFJm-UJQx5RNyGpok9MM_DYSTmayn7y-lKLSBqXecEoYA"),
+    graph=graph,
+    search_type="hybrid",
+    node_label="Document",
+    text_node_properties=["text"],
+    embedding_node_property="embedding",
+)
+
 # Define entity extraction and retrieval functions
 class Entities(BaseModel):
     names: List[str] = Field(
         ..., description="All the person, organization, or business entities that appear in the text"
     )

+prompt = ChatPromptTemplate.from_messages([
     ("system", "You are extracting organization and person entities from the text."),
     ("human", "Use the given format to extract information from the following input: {question}"),
 ])
+
+chat_model = ChatOpenAI(temperature=0, model_name="gpt-4o", api_key=os.environ['OPENAI_API_KEY'])
+entity_chain = prompt | chat_model.with_structured_output(Entities)

 def remove_lucene_chars(input: str) -> str:
     return input.translate(str.maketrans({
 
@@ -53,7 +88,10 @@ def generate_full_text_query(input: str) -> str:
     full_text_query += f" {words[-1]}~2"
     return full_text_query.strip()

+# Setup logging to a file to capture debug information
+logging.basicConfig(filename='neo4j_retrieval.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+
+def structured_retriever(question: str) -> str:
     result = ""
     entities = entity_chain.invoke({"question": question})
     for entity in entities.names:
@@ -64,6 +102,10 @@ def retrieve_data_from_neo4j(question: str) -> str:
               WITH node
               MATCH (node)-[r:!MENTIONS]->(neighbor)
               RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
+              UNION ALL
+              WITH node
+              MATCH (node)<-[r:!MENTIONS]-(neighbor)
+              RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
             }
             RETURN output LIMIT 50
             """,
 
@@ -72,23 +114,188 @@ def retrieve_data_from_neo4j(question: str) -> str:
         result += "\n".join([el['output'] for el in response])
     return result

+def retriever(question: str):
+    print(f"Search query: {question}")
+    structured_data = structured_retriever(question)
+    unstructured_data = [el.page_content for el in vector_index.similarity_search(question)]
+    final_data = f"""Structured data:
+{structured_data}
+Unstructured data:
+{"#Document ".join(unstructured_data)}
+    """
+    return final_data
+
+# Setup for condensing the follow-up questions
+_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question,
+in its original language.
+Chat History:
+{chat_history}
+Follow Up Input: {question}
+Standalone question:"""
+
+CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
+
+def _format_chat_history(chat_history: list[tuple[str, str]]) -> list:
+    buffer = []
+    for human, ai in chat_history:
+        buffer.append(HumanMessage(content=human))
+        buffer.append(AIMessage(content=ai))
+    return buffer
+
+_search_query = RunnableBranch(
+    # If input includes chat_history, we condense it with the follow-up question
+    (
+        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
+            run_name="HasChatHistoryCheck"
+        ),  # Condense follow-up question and chat into a standalone_question
+        RunnablePassthrough.assign(
+            chat_history=lambda x: _format_chat_history(x["chat_history"])
+        )
+        | CONDENSE_QUESTION_PROMPT
+        | ChatOpenAI(temperature=0, openai_api_key="sk-PV6RlpmTifrWo_olwL1IR69J9v2e5AKe-Xfxs_Yf9VT3BlbkFJm-UJQx5RNyGpok9MM_DYSTmayn7y-lKLSBqXecEoYA")
+        | StrOutputParser(),
+    ),
+    # Else, we have no chat history, so just pass through the question
+    RunnableLambda(lambda x: x["question"]),
+)
+
+
+template = """I am a guide for Birmingham, Alabama. I can provide recommendations and insights about the city, including events and activities.
+Ask your question directly, and I'll provide a precise and quick,short and crisp response in a conversational way without any Greet.
+{context}
+Question: {question}
+Answer:"""
+
+
+prompt = ChatPromptTemplate.from_template(template)
+
+# Define the chain for Neo4j-based retrieval and response generation
+chain_neo4j = (
+    RunnableParallel(
+        {
+            "context": _search_query | retriever_neo4j,
+            "question": RunnablePassthrough(),
+        }
+    )
+    | prompt
+    | chat_model
+    | StrOutputParser()
+)
+
+# Define the function to get the response
+def get_response(question):
+    try:
+        return chain_neo4j.invoke({"question": question})
+    except Exception as e:
+        return f"Error: {str(e)}"
+
+# Define the function to clear input and output
+def clear_fields():
+    return [], "", None
+
 # Function to generate audio with Eleven Labs TTS
 def generate_audio_elevenlabs(text):
     XI_API_KEY = os.environ['ELEVENLABS_API']
     VOICE_ID = 'ehbJzYLQFpwbJmGkqbnW'
     tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}/stream"
+    headers = {
+        "Accept": "application/json",
+        "xi-api-key": XI_API_KEY
+    }
+    data = {
+        "text": str(text),
+        "model_id": "eleven_multilingual_v2",
+        "voice_settings": {
+            "stability": 1.0,
+            "similarity_boost": 0.0,
+            "style": 0.60,
+            "use_speaker_boost": False
+        }
+    }
     response = requests.post(tts_url, headers=headers, json=data, stream=True)
     if response.ok:
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
             for chunk in response.iter_content(chunk_size=1024):
                 if chunk:
                     f.write(chunk)
+            audio_path = f.name
+        logging.debug(f"Audio saved to {audio_path}")
+        return audio_path  # Return audio path for automatic playback
+    else:
+        logging.error(f"Error generating audio: {response.text}")
+        return None
+
+
+
+def handle_mode_selection(mode, chat_history, question):
+    if mode == "Normal Chatbot":
+        # Append the user's question to chat history first
+        chat_history.append((question, ""))  # Placeholder for the bot's response
+
+        # Stream the response and update chat history with each chunk
+        for response_chunk in chat_with_bot(chat_history):
+            chat_history[-1] = (question, response_chunk[-1][1])  # Update last entry with streamed response
+            yield chat_history, "", None  # Stream each chunk to display in the chatbot
+        yield chat_history, "", None  # Final yield to complete the response
+
+    elif mode == "Voice to Voice Conversation":
+        # Voice to Voice mode: Stream the response text and then convert it to audio
+        response_text = get_response(question)  # Retrieve response text
+        audio_path = generate_audio_elevenlabs(response_text)  # Convert response to audio
+        yield [], "", audio_path  # Only output the audio response without updating chatbot history
+
+
+# Function to add a user's message to the chat history and clear the input box
+def add_message(history, message):
+    if message.strip():
+        history.append((message, ""))  # Add the user's message to the chat history only if it's not empty
+    return history, ""  # Clear the input box
+
+# Define function to generate a streaming response
+def chat_with_bot(messages):
+    user_message = messages[-1][0]  # Get the last user message (input)
+    messages[-1] = (user_message, "")  # Prepare a placeholder for the bot's response
+
+    response = get_response(user_message)  # Assume `get_response` is a generator function
+
+    # Stream each character in the response and update the history progressively
+    for character in response:
+        messages[-1] = (user_message, messages[-1][1] + character)
+        yield messages  # Stream each updated chunk
+        time.sleep(0.05)  # Adjust delay as needed for real-time effect
+
+    yield messages  # Final yield to complete the response
+
+
+
+# Function to generate audio with Eleven Labs TTS from the last bot response
+def generate_audio_from_last_response(history):
+    # Get the most recent bot response from the chat history
+    if history and len(history) > 0:
+        recent_response = history[-1][1]  # The second item in the tuple is the bot response text
+        if recent_response:
+            return generate_audio_elevenlabs(recent_response)
     return None

+# Define example prompts
+examples = [
+    ["What are some popular events in Birmingham?"],
+    ["Who are the top players of the Crimson Tide?"],
+    ["Where can I find a hamburger?"],
+    ["What are some popular tourist attractions in Birmingham?"],
+    ["What are some good clubs in Birmingham?"],
+    ["Is there a farmer's market or craft fair in Birmingham, Alabama?"],
+    ["Are there any special holiday events or parades in Birmingham, Alabama, during December?"],
+    ["What are the best places to enjoy live music in Birmingham, Alabama?"]
+]
+
+# Function to insert the prompt into the textbox when clicked
+def insert_prompt(current_text, prompt):
+    return prompt[0] if prompt else current_text
+
+
+# Define the ASR model with Whisper
 model_id = 'openai/whisper-large-v3'
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@@ -104,37 +311,146 @@ pipe_asr = pipeline(
     chunk_length_s=15,
     batch_size=16,
     torch_dtype=torch_dtype,
+    device=device,
+    return_timestamps=True
 )

+# Define the function to reset the state after 10 seconds
+def auto_reset_state():
+    time.sleep(5)
+    return None, ""  # Reset the state and clear input text
+
+
+def transcribe_function(stream, new_chunk):
+    try:
+        sr, y = new_chunk[0], new_chunk[1]
+    except TypeError:
+        print(f"Error chunk structure: {type(new_chunk)}, content: {new_chunk}")
+        return stream, "", None
+
+    # Ensure y is not empty and is at least 1-dimensional
+    if y is None or len(y) == 0:
+        return stream, "", None
+
+    y = y.astype(np.float32)
+    max_abs_y = np.max(np.abs(y))
+    if max_abs_y > 0:
+        y = y / max_abs_y
+
+    # Ensure stream is also at least 1-dimensional before concatenation
+    if stream is not None and len(stream) > 0:
+        stream = np.concatenate([stream, y])
+    else:
+        stream = y
+
+    # Process the audio data for transcription
+    result = pipe_asr({"array": stream, "sampling_rate": sr}, return_timestamps=False)
+    full_text = result.get("text", "")
+
+    # Start a thread to reset the state after 10 seconds
+    threading.Thread(target=auto_reset_state).start()
+
+    return stream, full_text, full_text
+
+

+# Define the function to clear the state and input text
+def clear_transcription_state():
+    return None, ""


+
+with gr.Blocks(theme="rawrsor1/Everforest") as demo:
+    chatbot = gr.Chatbot([], elem_id="RADAR", bubble_full_width=False)
+    with gr.Row():
+        with gr.Column():
+            mode_selection = gr.Radio(
+                choices=["Normal Chatbot", "Voice to Voice Conversation"],
+                label="Mode Selection",
+                value="Normal Chatbot"
+            )
+    with gr.Row():
+        with gr.Column():
+            question_input = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
+            audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1, label="Speak to Ask")
+            submit_voice_btn = gr.Button("Submit Voice")
+
+        with gr.Column():
+            audio_output = gr.Audio(label="Audio", type="filepath", autoplay=True, interactive=False)
+
+    with gr.Row():
+        with gr.Column():
+            get_response_btn = gr.Button("Get Response")
+        with gr.Column():
+            clear_state_btn = gr.Button("Clear State")
+        with gr.Column():
+            generate_audio_btn = gr.Button("Generate Audio")
+        with gr.Column():
+            clean_btn = gr.Button("Clean")
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("<h1 style='color: red;'>Example Prompts</h1>", elem_id="Example-Prompts")
+            gr.Examples(examples=examples, fn=insert_prompt, inputs=question_input, outputs=question_input, api_name="api_insert_example")
+
+
+    # Define interactions for the Get Response button
+    get_response_btn.click(
+        fn=handle_mode_selection,
+        inputs=[mode_selection, chatbot, question_input],
+        outputs=[chatbot, question_input, audio_output],
+        api_name="api_add_message_on_button_click"
+    )
+
+
+    question_input.submit(
+        fn=handle_mode_selection,
+        inputs=[mode_selection, chatbot, question_input],
+        outputs=[chatbot, question_input, audio_output],
+        api_name="api_add_message_on_enter"
+    )
+
+
+    submit_voice_btn.click(
+        fn=handle_mode_selection,
+        inputs=[mode_selection, chatbot, question_input],
+        outputs=[chatbot, question_input, audio_output],
+        api_name="api_voice_to_voice_translation"
+    )
+
+
+    # Speech-to-Text functionality
+    state = gr.State()
+    audio_input.stream(
+        transcribe_function,
+        inputs=[state, audio_input],
+        outputs=[state, question_input],
+        api_name="api_voice_to_text"
+    )

+    generate_audio_btn.click(
+        fn=generate_audio_from_last_response,
+        inputs=chatbot,
+        outputs=audio_output,
+        api_name="api_generate_text_to_audio"
+    )

+    clean_btn.click(
+        fn=clear_fields,
+        inputs=[],
+        outputs=[chatbot, question_input, audio_output],
+        api_name="api_clear_textbox"
+    )
+
+    # Clear state interaction
+    clear_state_btn.click(
+        fn=clear_transcription_state,
+        outputs=[question_input, state],
+        api_name="api_clean_state_transcription"
     )

+# Launch the Gradio interface
+demo.launch(show_error=True, share=True)
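
For orientation, here is a minimal sketch of how the pieces added in this commit could be exercised outside the Gradio UI. It is not part of the commit: it assumes get_response() and generate_audio_elevenlabs() from the new app.py are already in scope (for example, by pasting the snippet into the file before the demo.launch(...) call), and that OPENAI_API_KEY and ELEVENLABS_API are set in the environment.

# Illustrative smoke test (not part of this commit); assumes get_response() and
# generate_audio_elevenlabs() from app.py are defined in the current scope.
question = "What are some popular events in Birmingham?"

# Text path: entity extraction -> Neo4j graph + vector retrieval -> GPT-4o answer
answer = get_response(question)
print("Answer:", answer)

# Voice path: synthesize the answer with Eleven Labs; returns a temporary .mp3 path, or None on error
audio_path = generate_audio_elevenlabs(answer)
print("Audio file:", audio_path)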