import gradio as gr
import os
import logging
import tempfile
from typing import List

import requests
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.graphs import Neo4jGraph
from langchain_groq import ChatGroq
from langchain.memory import ConversationBufferWindowMemory
from pydantic import BaseModel, Field
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnablePassthrough,
    RunnableParallel,
)
# Setup logging to a file to capture debug information
logging.basicConfig(filename='neo4j_retrieval.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
# Setup for conversational memory
conversational_memory = ConversationBufferWindowMemory(
memory_key='chat_history',
k=10,
return_messages=True
)
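# Note: this buffer is defined but not yet wired into the chain below; feeding
# its chat_history into _search_query would enable true follow-up questions.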
# Setup Neo4j
# Read Neo4j credentials from the environment rather than hardcoding secrets.
graph = Neo4jGraph(
    url=os.getenv("NEO4J_URI"),
    username=os.getenv("NEO4J_USERNAME", "neo4j"),
    password=os.getenv("NEO4J_PASSWORD"),
)
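# The structured retriever below queries a full-text index named 'entity'.
# It must already exist in the database, e.g. (label/property are assumptions
# based on common LangChain graph setups):
#   CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]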
# Setup the Groq model
groq_api_key = os.getenv('GROQ_API_KEY')
llm = ChatGroq(groq_api_key=groq_api_key, model_name="Gemma2-9b-It")
# Define entity extraction and retrieval functions
class Entities(BaseModel):
names: List[str] = Field(
..., description="All the person, organization, or business entities that appear in the text"
)
entity_prompt = ChatPromptTemplate.from_messages([
("system", "You are extracting organization and person entities from the text."),
("human", "Use the given format to extract information from the following input: {question}"),
])
entity_chain = entity_prompt | llm.with_structured_output(Entities)
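# Hypothetical example: entity_chain.invoke({"question": "Who owns Regions Field?"})
# should yield something like Entities(names=["Regions Field"]).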
def remove_lucene_chars(input: str) -> str:
    # Escape Lucene special characters. Spaces are deliberately left unescaped:
    # generate_full_text_query() splits on whitespace, and escaping spaces would
    # leave stray backslashes attached to the tokens.
    return input.translate(str.maketrans({
        "\\": r"\\", "+": r"\+", "-": r"\-", "&": r"\&", "|": r"\|", "!": r"\!",
        "(": r"\(", ")": r"\)", "{": r"\{", "}": r"\}", "[": r"\[", "]": r"\]",
        "^": r"\^", "~": r"\~", "*": r"\*", "?": r"\?", ":": r"\:", '"': r'\"',
        ";": r"\;",
    }))
def generate_full_text_query(input: str) -> str:
    # Build a fuzzy full-text query: each word may differ by up to two edits.
    words = [el for el in remove_lucene_chars(input).split() if el]
    if not words:
        return ""
    return " AND ".join(f"{word}~2" for word in words)
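# e.g. generate_full_text_query("John Doe") -> "John~2 AND Doe~2"
# (the ~2 suffix allows fuzzy matching within an edit distance of 2).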
def structured_retriever(question: str) -> str:
result = ""
entities = entity_chain.invoke({"question": question})
for entity in entities.names:
response = graph.query(
"""CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
YIELD node,score
CALL {
WITH node
MATCH (node)-[r:!MENTIONS]->(neighbor)
RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
UNION ALL
WITH node
MATCH (node)<-[r:!MENTIONS]-(neighbor)
RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
}
RETURN output LIMIT 50
""",
{"query": generate_full_text_query(entity)},
)
result += "\n".join([el['output'] for el in response])
return result
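# Each returned line has the shape "<node id> - <REL_TYPE> -> <neighbor id>",
# covering both outgoing and incoming non-MENTIONS relationships.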
def retriever_neo4j(question: str):
structured_data = structured_retriever(question)
logging.debug(f"Structured data: {structured_data}")
return structured_data
# Condense follow-up questions to standalone
_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question,
in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
def _format_chat_history(chat_history: list[tuple[str, str]]) -> list:
buffer = []
for human, ai in chat_history:
buffer.append(HumanMessage(content=human))
buffer.append(AIMessage(content=ai))
return buffer
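# e.g. _format_chat_history([("Hi", "Hello!")])
#      -> [HumanMessage(content="Hi"), AIMessage(content="Hello!")]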
_search_query = RunnableBranch(
(
RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
run_name="HasChatHistoryCheck"
),
RunnablePassthrough.assign(
chat_history=lambda x: _format_chat_history(x["chat_history"])
)
| CONDENSE_QUESTION_PROMPT
| llm
| StrOutputParser(),
),
RunnableLambda(lambda x: x["question"]),
)
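# With no chat history the raw question passes straight through; otherwise the
# history is condensed into a standalone question before retrieval.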
# Define the prompt for response generation
template = """I am a guide for Birmingham, Alabama. I can provide recommendations and insights about the city, including events and activities.
Ask your question directly, and I'll provide a precise, short, and crisp response in a conversational way without any greeting.
{context}
Question: {question}
Answer:"""
qa_prompt = ChatPromptTemplate.from_template(template)
# Define the chain for Neo4j-based retrieval and response generation
chain_neo4j = (
RunnableParallel(
{
"context": _search_query | retriever_neo4j,
"question": RunnablePassthrough(),
}
)
| qa_prompt
| llm
| StrOutputParser()
)
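# Hypothetical usage: chain_neo4j.invoke({"question": "What is there to do in
# Birmingham this weekend?"}) returns the model's answer as a string.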
# Generate an answer and append the (question, answer) pair to the chat
# history, since the gr.Chatbot component expects a list of message pairs.
def get_response(question, chat_history):
    try:
        answer = chain_neo4j.invoke({"question": question})
    except Exception as e:
        logging.error(f"Error generating response: {str(e)}")
        answer = f"Error: {str(e)}"
    chat_history.append((question, answer))
    return chat_history
# Define the function to clear input and output
def clear_fields():
return [], "", None
# Function to generate audio with Eleven Labs TTS
def generate_audio_elevenlabs(text):
    XI_API_KEY = os.getenv('ELEVENLABS_API')
    if not XI_API_KEY:
        logging.error("ELEVENLABS_API environment variable is not set")
        return None
    VOICE_ID = 'ehbJzYLQFpwbJmGkqbnW'
    tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}/stream"
    headers = {"Accept": "audio/mpeg", "xi-api-key": XI_API_KEY}
    data = {"text": str(text), "model_id": "eleven_multilingual_v2", "voice_settings": {"stability": 1.0}}
response = requests.post(tts_url, headers=headers, json=data, stream=True)
if response.ok:
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
audio_path = f.name
logging.debug(f"Audio saved to {audio_path}")
return audio_path # Return audio path for playback
else:
logging.error(f"Error generating audio: {response.text}")
return None
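# The gr.Chatbot history arrives as a list of (user, bot) pairs, so pull out
# the most recent bot reply before handing it to the TTS helper above. A
# minimal sketch assuming that pair format; adjust if the history shape differs.
def generate_audio_from_history(history):
    if not history or not history[-1][1]:
        return None
    return generate_audio_elevenlabs(history[-1][1])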
# Create the Gradio Blocks interface
with gr.Blocks(theme="rawrsor1/Everforest") as demo:
chatbot = gr.Chatbot([], elem_id="RADAR", bubble_full_width=False)
with gr.Row():
with gr.Column():
question_input = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
with gr.Column():
audio_output = gr.Audio(label="Audio", type="filepath", autoplay=True, interactive=False)
with gr.Row():
with gr.Column():
get_response_btn = gr.Button("Get Response")
with gr.Column():
generate_audio_btn = gr.Button("Generate Audio")
with gr.Column():
clear_state_btn = gr.Button("Clear State")
# Define interactions for buttons
    get_response_btn.click(fn=get_response, inputs=[question_input, chatbot], outputs=chatbot)
    generate_audio_btn.click(fn=generate_audio_from_history, inputs=chatbot, outputs=audio_output)
clear_state_btn.click(fn=clear_fields, outputs=[chatbot, question_input, audio_output])
# Launch the Gradio interface
demo.launch(show_error=True)