import gradio as gr
import os
import logging
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_community.graphs import Neo4jGraph
from typing import List
from pydantic import BaseModel, Field
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import requests
import tempfile
import torch
import numpy as np
# Set up logging to a file to capture debug information
logging.basicConfig(
    filename='neo4j_retrieval.log',
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Set up the Neo4j connection
graph = Neo4jGraph(
    url="neo4j+s://c62d0d35.databases.neo4j.io",
    username="neo4j",
    password="_x8f-_aAQvs2NB0x6s0ZHSh3W_y-HrENDbgStvsUCM0"
)
# Define entity extraction and retrieval helpers
class Entities(BaseModel):
    names: List[str] = Field(
        ..., description="All the person, organization, or business entities that appear in the text"
    )
# Define the prompt and model for entity extraction
chat_model = ChatOpenAI(temperature=0, model_name="gpt-4", api_key=os.environ['OPENAI_API_KEY'])
entity_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are extracting organization and person entities from the text."),
    ("human", "Use the given format to extract information from the following input: {question}"),
])
entity_chain = entity_prompt | chat_model.with_structured_output(Entities)
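# For illustration only (hypothetical question and output): the structured-output chain
# returns a parsed Entities object whose .names field drives the graph lookups below.
#   entity_chain.invoke({"question": "Who runs Acme Corp?"}).names  ->  ["Acme Corp"]
# The exact entities returned depend on the model.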
def remove_lucene_chars(input: str) -> str:
    """Escape (rather than remove) Lucene special characters so they are matched literally."""
    # Whitespace is left unescaped so generate_full_text_query can split on it;
    # escaping spaces here would leave stray backslashes on the split words.
    return input.translate(str.maketrans({
        "\\": r"\\", "+": r"\+", "-": r"\-", "&": r"\&", "|": r"\|", "!": r"\!",
        "(": r"\(", ")": r"\)", "{": r"\{", "}": r"\}", "[": r"\[", "]": r"\]",
        "^": r"\^", "~": r"\~", "*": r"\*", "?": r"\?", ":": r"\:", '"': r'\"',
        ";": r"\;"
    }))
def generate_full_text_query(input: str) -> str:
    """Build a Lucene full-text query: each word gets a fuzzy (~2) suffix and words are AND-ed."""
    words = [el for el in remove_lucene_chars(input).split() if el]
    if not words:
        return ""
    full_text_query = ""
    for word in words[:-1]:
        full_text_query += f" {word}~2 AND"
    full_text_query += f" {words[-1]}~2"
    return full_text_query.strip()
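# For illustration only (hypothetical input), the resulting Lucene query string looks like:
#   generate_full_text_query("Acme Hotels")  ->  "Acme~2 AND Hotels~2"
# i.e. every term must match, each allowing an edit distance of up to 2.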
def retrieve_data_from_neo4j(question: str) -> str:
    """Extract entities from the question and collect their graph neighborhood as text."""
    result = ""
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
            YIELD node, score
            CALL {
                WITH node
                MATCH (node)-[r:!MENTIONS]->(neighbor)
                RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": generate_full_text_query(entity)},
        )
        result += "\n".join([el['output'] for el in response])
    return result
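# Note: the Cypher above assumes a full-text index named 'entity' already exists in the
# database. A minimal sketch of creating one (assuming entity nodes carry the label
# __Entity__ and an `id` property, as in the LangChain graph-construction examples):
#
#   graph.query(
#       "CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (n:__Entity__) ON EACH [n.id]"
#   )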
# Function to generate audio with the Eleven Labs TTS API
def generate_audio_elevenlabs(text):
    XI_API_KEY = os.environ['ELEVENLABS_API']
    VOICE_ID = 'ehbJzYLQFpwbJmGkqbnW'
    tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}/stream"
    headers = {"Accept": "application/json", "xi-api-key": XI_API_KEY}
    data = {"text": str(text), "model_id": "eleven_multilingual_v2", "voice_settings": {"stability": 1.0}}
    response = requests.post(tts_url, headers=headers, json=data, stream=True)
    if response.ok:
        # Stream the MP3 bytes into a temporary file and return its path for Gradio to play
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
            return f.name
    return None
# ASR model setup using Whisper
model_id = 'openai/whisper-large-v3'
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe_asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device
)
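# A quick, optional sanity check for the ASR pipeline in isolation (hypothetical file name;
# decoding a file path requires ffmpeg to be available):
#   print(pipe_asr("sample.wav")["text"])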
# Function to handle audio input: transcribe it, query Neo4j, and answer with synthesized speech
def transcribe_and_respond(audio):
    if audio is None:
        return None
    # Gradio's numpy-typed Audio component delivers a (sample_rate, samples) tuple
    sample_rate, samples = audio
    samples = samples.astype(np.float32)
    if samples.ndim > 1:
        samples = samples.mean(axis=1)  # downmix stereo to mono
    if np.abs(samples).max() > 1.0:
        samples = samples / 32768.0  # int16 PCM -> float in [-1, 1]
    transcription = pipe_asr({"array": samples, "sampling_rate": sample_rate})["text"]
    logging.debug(f"Transcription: {transcription}")
    # Retrieve data from Neo4j based on the transcription
    response_text = retrieve_data_from_neo4j(transcription)
    logging.debug(f"Neo4j Response: {response_text}")
    # Convert the response to audio
    return generate_audio_elevenlabs(response_text)
# Define the Gradio interface
with gr.Blocks() as demo:
    audio_input = gr.Audio(sources="microphone", type="numpy", label="Speak to Ask")  # no streaming: audio is submitted manually
    audio_output = gr.Audio(label="Response", type="filepath", autoplay=True, interactive=False)
    # "Submit Audio" button
    submit_button = gr.Button("Submit Audio")
    # Clicking the button triggers transcription and response generation
    submit_button.click(
        fn=transcribe_and_respond,
        inputs=audio_input,
        outputs=audio_output
    )

# Launch the Gradio interface
demo.launch(show_error=True, share=True)