# Gradio Space: Birmingham concierge Q&A over a Neo4j knowledge graph,
# with ElevenLabs text-to-speech for the generated answers.
import logging
import os
import tempfile
from typing import List, Tuple

import gradio as gr
import requests
from langchain_community.graphs import Neo4jGraph
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_openai import ChatOpenAI
# Setup Neo4j
# SECURITY: the connection URI, username, and password were hard-coded in
# source (the password is now exposed and must be rotated). They are read
# from the environment, falling back to the previous values so existing
# deployments keep working until NEO4J_* variables are set.
graph = Neo4jGraph(
    url=os.environ.get("NEO4J_URI", "neo4j+s://6457770f.databases.neo4j.io"),
    username=os.environ.get("NEO4J_USERNAME", "neo4j"),
    password=os.environ.get("NEO4J_PASSWORD", "Z10duoPkKCtENuOukw3eIlvl0xJWKtrVSr-_hGX1LQ4"),
)
# Define entity extraction and retrieval functions
class Entities(BaseModel):
    """Structured-output schema the LLM fills in: entity names found in a question."""
    # Names of people / organizations / businesses mentioned in the input text;
    # consumed by structured_retriever to drive full-text graph lookups.
    names: List[str] = Field(
        ..., description="All the person, organization, or business entities that appear in the text"
    )
# Prompt + LLM pipeline that pulls person/organization names out of a question,
# returning them as a structured Entities object.
_entity_messages = [
    ("system", "You are extracting organization and person entities from the text."),
    ("human", "Use the given format to extract information from the following input: {question}"),
]
entity_prompt = ChatPromptTemplate.from_messages(_entity_messages)
# Shared chat model; also reused later as the answer-generating LLM.
chat_model = ChatOpenAI(temperature=0, model_name="gpt-4o", api_key=os.environ['OPENAI_API_KEY'])
entity_chain = entity_prompt | chat_model.with_structured_output(Entities)
def remove_lucene_chars(input: str) -> str:
    """Backslash-escape Lucene full-text special characters (and spaces) in *input*.

    One C-level pass via str.translate; used to sanitize entity names before
    they are embedded in a db.index.fulltext query.
    """
    specials = '\\+-&|!(){}[]^~*?:";'
    escape_table = {ch: "\\" + ch for ch in specials}
    escape_table[" "] = "\\ "
    return input.translate(str.maketrans(escape_table))
def generate_full_text_query(input: str) -> str:
    """Build a fuzzy Neo4j full-text query string from free text.

    Each whitespace-separated word gets a ~2 edit-distance suffix and the
    words are AND-ed together, e.g. "John Doe" -> "John~2 AND Doe~2".

    Fix: the original indexed words[-1] unconditionally and raised
    IndexError when the escaped input contained no words; an empty query
    string is now returned instead. Also replaces the quadratic += loop
    with a single str.join.
    """
    words = remove_lucene_chars(input).split()
    if not words:
        return ""
    return " AND ".join(f"{word}~2" for word in words)
def structured_retriever(question: str) -> str:
    """Return a newline-separated textual neighborhood of every entity in *question*.

    For each entity extracted by entity_chain, runs a full-text lookup on the
    'entity' index (top 2 hits) and collects up to 50 incoming/outgoing
    non-MENTIONS relationships as "a - REL -> b" lines.

    Fix: the original appended each entity's joined block with no separator,
    fusing the last line of one entity with the first line of the next; all
    lines are now collected and joined once.
    """
    lines: List[str] = []
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
            YIELD node,score
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": generate_full_text_query(entity)},
        )
        lines.extend(el["output"] for el in response)
    return "\n".join(lines)
def retriever_neo4j(question: str):
    """Fetch graph-derived context for *question*, logging it at DEBUG level."""
    context = structured_retriever(question)
    logging.debug(f"Structured data: {context}")
    return context
# Setup for condensing the follow-up questions | |
_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question, | |
in its original language. | |
Chat History: | |
{chat_history} | |
Follow Up Input: {question} | |
Standalone question:""" | |
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template) | |
def _format_chat_history(chat_history: list[tuple[str, str]]) -> list:
    """Flatten (human, ai) turn pairs into an alternating LangChain message list."""
    return [
        message
        for human_text, ai_text in chat_history
        for message in (HumanMessage(content=human_text), AIMessage(content=ai_text))
    ]
# Branch: when the input dict carries chat history, condense the follow-up
# question into a standalone one via an LLM call; otherwise pass the raw
# question straight through. Output is always a plain question string.
_search_query = RunnableBranch(
    (
        # Predicate: truthy "chat_history" selects the condense branch.
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),
        # Reformat history into message objects, fill the condense prompt,
        # call the LLM, and parse its reply to a string.
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0, api_key=os.environ['OPENAI_API_KEY'])
        | StrOutputParser(),
    ),
    # Default branch: no history, use the question as-is.
    RunnableLambda(lambda x: x["question"]),
)
# Define the QA prompt template
# Persona + answer-format instructions for the final LLM call: {context} is
# filled with graph facts from the retriever, {question} with the (possibly
# condensed) user question.
template = """As an expert concierge known for being helpful and a renowned guide for Birmingham, Alabama, I assist visitors in discovering the best that the city has to offer. I also assist the visitors about various sports and activities. Given today's sunny and bright weather, I am well-equipped to provide valuable insights and recommendations without revealing specific locations. I draw upon my extensive knowledge of the area, including perennial events and historical context.
In light of this, how can I assist you today? Feel free to ask any questions or seek recommendations for your day in Birmingham. If there's anything specific you'd like to know or experience, please share, and I'll be glad to help. Remember, keep the question concise for a quick, short, crisp, and accurate response and dont greet.
"It was my pleasure!"
{context}
Question: {question}
Helpful Answer:"""
qa_prompt = ChatPromptTemplate.from_template(template)
# Define the chain for Neo4j-based retrieval and response generation
# Full RAG pipeline: condense the question if needed, fetch graph context
# from Neo4j, fill the QA prompt, call the LLM, and return a plain string.
chain_neo4j = (
    RunnableParallel(
        {
            # Context and the pass-through question are computed in parallel.
            "context": _search_query | retriever_neo4j,
            "question": RunnablePassthrough(),
        }
    )
    | qa_prompt
    | chat_model
    | StrOutputParser()
)
# Define the function to get the response
def get_response(question):
    """Run the Neo4j RAG chain for *question*.

    Returns the chain's answer string, or an "Error: ..." string on any
    failure (the Gradio textbox expects a string either way).

    Fix: the broad except previously swallowed the traceback entirely;
    it is now logged while the user-facing format stays unchanged.
    """
    try:
        return chain_neo4j.invoke({"question": question})
    except Exception as e:
        logging.exception("chain_neo4j.invoke failed")
        return f"Error: {str(e)}"
# Define the function to clear input and output
def clear_fields():
    """Return two empty strings to blank the question and response textboxes."""
    return ("", "")
# Function to generate audio with Eleven Labs TTS
def generate_audio_elevenlabs(text):
    """Synthesize *text* with the ElevenLabs streaming TTS API.

    Streams the MP3 into a persistent temp file and returns its path for the
    gr.Audio component, or None on failure.

    Fixes: (1) falsy text (None / empty response box) was stringified and
    sent to the paid API, synthesizing "None"; it now short-circuits.
    (2) The streaming HTTP response was never closed — it is now used as a
    context manager so the connection is released on every path.
    """
    if not text:
        return None
    XI_API_KEY = os.environ['ELEVENLABS_API']
    VOICE_ID = 'ehbJzYLQFpwbJmGkqbnW'
    tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}/stream"
    headers = {
        "Accept": "application/json",
        "xi-api-key": XI_API_KEY
    }
    data = {
        "text": str(text),
        "model_id": "eleven_multilingual_v2",
        "voice_settings": {
            "stability": 1.0,
            "similarity_boost": 0.0,
            "style": 0.60,
            "use_speaker_boost": False
        }
    }
    with requests.post(tts_url, headers=headers, json=data, stream=True) as response:
        if not response.ok:
            logging.error(f"Error generating audio: {response.text}")
            return None
        # delete=False: the file must outlive this function so Gradio can serve it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
            audio_path = f.name
    logging.debug(f"Audio saved to {audio_path}")
    return audio_path
# Create the Gradio Blocks interface
# NOTE: component creation order defines the on-screen layout.
with gr.Blocks() as demo:
    # Input box for the visitor's question.
    question_input = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
    # Read-only answer box; its contents also feed the TTS button below.
    response_output = gr.Textbox(label="Response", placeholder="The response will appear here...", interactive=False)
    # Player for the generated speech (filepath returned by the TTS function).
    audio_output = gr.Audio(label="Audio", type="filepath", interactive=False)
    get_response_btn = gr.Button("Get Response")
    generate_audio_btn = gr.Button("Generate Audio")
    clean_btn = gr.Button("Clean")
    # Wire the buttons: question -> answer, answer -> audio, clear both boxes.
    get_response_btn.click(fn=get_response, inputs=question_input, outputs=response_output)
    generate_audio_btn.click(fn=generate_audio_elevenlabs, inputs=response_output, outputs=audio_output)
    clean_btn.click(fn=clear_fields, inputs=[], outputs=[question_input, response_output])
# Launch the Gradio interface
demo.launch(show_error=True)