import logging
import os
import tempfile
import time
from typing import List

import gradio as gr
import numpy as np
import requests
import torch
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.graphs import Neo4jGraph
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Windowed memory over the last 10 exchanges (defined for chat history; the Gradio
# Chatbot state below is what actually carries the conversation through the chains)
conversational_memory = ConversationBufferWindowMemory(
    memory_key='chat_history',
    k=10,
    return_messages=True
)
# Setup Neo4j (the password is read from the NEO4J_PASSWORD environment variable,
# an assumed name, rather than being hardcoded in the source)
graph = Neo4jGraph(
    url="neo4j+s://6457770f.databases.neo4j.io",
    username="neo4j",
    password=os.environ["NEO4J_PASSWORD"]
)
# Define entity extraction and retrieval functions
class Entities(BaseModel):
    names: List[str] = Field(
        ..., description="All the person, organization, or business entities that appear in the text"
    )

entity_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are extracting organization and person entities from the text."),
    ("human", "Use the given format to extract information from the following input: {question}"),
])

chat_model = ChatOpenAI(temperature=0, model_name="gpt-4o", api_key=os.environ['OPENAI_API_KEY'])
entity_chain = entity_prompt | chat_model.with_structured_output(Entities)
def remove_lucene_chars(input: str) -> str:
    """Replace Lucene full-text special characters with spaces so user input is safe
    to embed in a query. Backslash-escaping them (as well as spaces) would leave stray
    backslashes behind after the whitespace split in generate_full_text_query."""
    return input.translate(str.maketrans({
        "\\": " ", "+": " ", "-": " ", "&": " ", "|": " ", "!": " ",
        "(": " ", ")": " ", "{": " ", "}": " ", "[": " ", "]": " ",
        "^": " ", "~": " ", "*": " ", "?": " ", ":": " ", '"': " ",
        ";": " "
    }))
def generate_full_text_query(input: str) -> str:
    """Build a fuzzy Lucene query that ANDs every word of the sanitized input."""
    words = [el for el in remove_lucene_chars(input).split() if el]
    if not words:
        return ""
    full_text_query = ""
    for word in words[:-1]:
        full_text_query += f" {word}~2 AND"
    full_text_query += f" {words[-1]}~2"
    return full_text_query.strip()
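# Illustrative example (not from the original source): generate_full_text_query("Crimson Tide")
# yields "Crimson~2 AND Tide~2", i.e. each word fuzzy-matched with edit distance 2 and AND-ed.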
# Setup logging to a file to capture debug information
logging.basicConfig(filename='neo4j_retrieval.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
def structured_retriever(question: str) -> str:
    """Extract entities from the question and collect their one-hop neighborhood
    from the 'entity' full-text index in the graph."""
    result = ""
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
            YIELD node,score
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": generate_full_text_query(entity)},
        )
        # Join with a trailing newline so results for successive entities don't run together
        result += "\n".join([el['output'] for el in response]) + "\n"
    return result.strip()
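# Each line of the returned context has the shape "source.id - REL_TYPE -> target.id",
# built directly from the RETURN clauses above; the actual ids and relationship types
# depend on what is stored in the graph.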
def retriever_neo4j(question: str):
    structured_data = structured_retriever(question)
    logging.debug(f"Structured data: {structured_data}")
    return structured_data
# Setup for condensing the follow-up questions
_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question,
in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
def _format_chat_history(chat_history: list[tuple[str, str]]) -> list:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer
_search_query = RunnableBranch(
    # If chat history is present, condense the follow-up into a standalone question first
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0, api_key=os.environ['OPENAI_API_KEY'])
        | StrOutputParser(),
    ),
    # Otherwise, pass the question through unchanged
    RunnableLambda(lambda x: x["question"]),
)
template = """I am a guide for Birmingham, Alabama. I can provide recommendations and insights about the city, including events and activities. | |
Ask your question directly, and I'll provide a precise and quick,short and crisp response in a conversational way without any Greet. | |
{context} | |
Question: {question} | |
Answer:""" | |
qa_prompt = ChatPromptTemplate.from_template(template) | |
# Define the chain for Neo4j-based retrieval and response generation
chain_neo4j = (
    RunnableParallel(
        {
            "context": _search_query | retriever_neo4j,
            # Pull out the question string so the prompt doesn't receive the whole input dict
            "question": lambda x: x["question"],
        }
    )
    | qa_prompt
    | chat_model
    | StrOutputParser()
)
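# A minimal usage sketch ("chat_history" is optional and, when present as a list of
# (human, ai) tuples, triggers the question-condensing branch above):
#   chain_neo4j.invoke({"question": "What are some popular events in Birmingham?"})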
# Define the function to get the response
def get_response(question):
    try:
        return chain_neo4j.invoke({"question": question})
    except Exception as e:
        return f"Error: {str(e)}"

# Define the function to clear input and output
def clear_fields():
    return [], "", None
# Function to generate audio with Eleven Labs TTS
def generate_audio_elevenlabs(text):
    XI_API_KEY = os.environ['ELEVENLABS_API']
    VOICE_ID = 'ehbJzYLQFpwbJmGkqbnW'
    tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}/stream"
    headers = {
        "Accept": "application/json",
        "xi-api-key": XI_API_KEY
    }
    data = {
        "text": str(text),
        "model_id": "eleven_multilingual_v2",
        "voice_settings": {
            "stability": 1.0,
            "similarity_boost": 0.0,
            "style": 0.60,
            "use_speaker_boost": False
        }
    }
    response = requests.post(tts_url, headers=headers, json=data, stream=True)
    if response.ok:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
            audio_path = f.name
        logging.debug(f"Audio saved to {audio_path}")
        return audio_path  # Return audio path for automatic playback
    else:
        logging.error(f"Error generating audio: {response.text}")
        return None
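# The MP3 is streamed to a temporary file rather than held in memory, so the Gradio
# Audio component (type="filepath") can play it back directly from disk.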
# Function to add a user's message to the chat history and clear the input box
def add_message(history, message):
    if message.strip():
        history.append((message, None))  # Only add non-empty messages to the chat history
    return history, ""  # Clear the input box
# Define function to generate a streaming response
def chat_with_bot(messages):
    user_message = messages[-1][0]  # Get the last user message (input)
    messages[-1] = (user_message, "")  # Prepare the placeholder for the bot's response
    response = get_response(user_message)
    # Simulate streaming by appending one character at a time
    for character in response:
        messages[-1] = (user_message, messages[-1][1] + character)
        yield messages  # Stream each character
        time.sleep(0.05)  # Adjust delay as needed for real-time effect
    yield messages  # Final yield to ensure the full response is displayed
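# Gradio treats generator callbacks as streaming updates: each yielded history list
# re-renders the Chatbot, which is what produces the character-by-character effect above.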
# Function to generate audio with Eleven Labs TTS from the last bot response
def generate_audio_from_last_response(history):
    # Get the most recent bot response from the chat history
    if history:
        recent_response = history[-1][1]  # The second item in the tuple is the bot response text
        if recent_response:
            return generate_audio_elevenlabs(recent_response)
    return None
# Define example prompts
examples = [
    ["What are some popular events in Birmingham?"],
    ["Who are the top players of the Crimson Tide?"],
    ["Where can I find a hamburger?"],
    ["What are some popular tourist attractions in Birmingham?"],
    ["What are some good clubs in Birmingham?"]
]

# Function to insert the prompt into the textbox when clicked
def insert_prompt(current_text, prompt):
    return prompt[0] if prompt else current_text
# Define the ASR model with Whisper
model_id = 'openai/whisper-large-v3'
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe_asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
    return_timestamps=True
)
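# chunk_length_s=15 makes the pipeline transcribe audio in 15-second windows, and
# batch_size=16 processes windows in parallel; both are tuning knobs, assumed here
# to fit the available device memory.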
def transcribe_function(stream, new_chunk):
    """Accumulate streamed microphone audio and transcribe the whole buffer so far."""
    try:
        sr, y = new_chunk[0], new_chunk[1]
    except TypeError:
        print(f"Error chunk structure: {type(new_chunk)}, content: {new_chunk}")
        return stream, ""
    # Normalize the chunk to [-1, 1] before appending it to the running buffer
    y = y.astype(np.float32)
    max_abs_y = np.max(np.abs(y))
    if max_abs_y > 0:
        y = y / max_abs_y
    stream = np.concatenate([stream, y]) if stream is not None else y
    result = pipe_asr({"array": stream, "sampling_rate": sr}, return_timestamps=False)
    full_text = result.get("text", "")
    # Return exactly two values to match the stream handler's outputs=[state, question_input]
    return stream, full_text
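# Note: the entire accumulated buffer is re-transcribed on every incoming chunk, so
# latency grows with utterance length; the trade-off is that partial results stay stable.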
# Create the Gradio Blocks interface
with gr.Blocks(theme="rawrsor1/Everforest") as demo:
    chatbot = gr.Chatbot([], elem_id="RADAR", bubble_full_width=False)
    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
            audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1, label="Speak to Ask")
        with gr.Column():
            audio_output = gr.Audio(label="Audio", type="filepath", interactive=False)

    with gr.Row():
        with gr.Column():
            get_response_btn = gr.Button("Get Response")
        with gr.Column():
            generate_audio_btn = gr.Button("Generate Audio")
        with gr.Column():
            clean_btn = gr.Button("Clean")

    with gr.Row():
        with gr.Column():
            gr.Markdown("<h1 style='color: red;'>Example Prompts</h1>", elem_id="Example-Prompts")
            gr.Examples(examples=examples, fn=insert_prompt, inputs=question_input, outputs=question_input)

    # Define interactions for clicking the button
    get_response_btn.click(fn=add_message, inputs=[chatbot, question_input], outputs=[chatbot, question_input])\
        .then(fn=chat_with_bot, inputs=[chatbot], outputs=chatbot)
    # Define interaction for hitting the Enter key
    question_input.submit(fn=add_message, inputs=[chatbot, question_input], outputs=[chatbot, question_input])\
        .then(fn=chat_with_bot, inputs=[chatbot], outputs=chatbot)

    # Speech-to-Text functionality
    state = gr.State()
    audio_input.stream(transcribe_function, inputs=[state, audio_input], outputs=[state, question_input])
    generate_audio_btn.click(fn=generate_audio_from_last_response, inputs=chatbot, outputs=audio_output)
    clean_btn.click(fn=clear_fields, inputs=[], outputs=[chatbot, question_input, audio_output])

# Launch the Gradio interface
demo.launch(show_error=True)