import gradio as gr
import os
import logging
import requests
import tempfile
import torch
import numpy as np
import torchaudio
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from langchain_community.graphs import Neo4jGraph
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from typing import List

# Enable debug-level logging so the logging.debug() calls below actually show up
logging.basicConfig(level=logging.DEBUG)

# Neo4j setup
graph = Neo4jGraph(
    url="neo4j+s://6457770f.databases.neo4j.io",
    username="neo4j",
    password="Z10duoPkKCtENuOukw3eIlvl0xJWKtrVSr-_hGX1LQ4"
)


# Entity extraction schema and retrieval helpers
class Entities(BaseModel):
    names: List[str] = Field(
        ...,
        description="All the person, organization, or business entities that appear in the text"
    )


entity_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are extracting organization and person entities from the text."),
    ("human", "Use the given format to extract information from the following input: {question}"),
])

chat_model = ChatOpenAI(temperature=0, model_name="gpt-4o", api_key=os.environ['OPENAI_API_KEY'])
entity_chain = entity_prompt | chat_model.with_structured_output(Entities)


def remove_lucene_chars(input: str) -> str:
    """Escape characters that have special meaning in Lucene full-text queries."""
    return input.translate(str.maketrans({
        "\\": r"\\", "+": r"\+", "-": r"\-", "&": r"\&",
        "|": r"\|", "!": r"\!", "(": r"\(", ")": r"\)",
        "{": r"\{", "}": r"\}", "[": r"\[", "]": r"\]",
        "^": r"\^", "~": r"\~", "*": r"\*", "?": r"\?",
        ":": r"\:", '"': r'\"', ";": r"\;", " ": r"\ ",
    }))


def generate_full_text_query(input: str) -> str:
    """Build a fuzzy (~2), AND-joined full-text query from the input words."""
    words = [el for el in remove_lucene_chars(input).split() if el]
    if not words:
        return ""
    full_text_query = ""
    for word in words[:-1]:
        full_text_query += f" {word}~2 AND"
    full_text_query += f" {words[-1]}~2"
    return full_text_query.strip()


def structured_retriever(question: str) -> str:
    """Extract entities from the question and fetch their neighborhoods from Neo4j."""
    result = ""
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
            YIELD node, score
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": generate_full_text_query(entity)},
        )
        result += "\n".join([el['output'] for el in response])
    return result


# Generate an audio response with the Eleven Labs TTS streaming endpoint
def generate_audio_elevenlabs(text):
    XI_API_KEY = os.environ['ELEVENLABS_API']
    VOICE_ID = 'ehbJzYLQFpwbJmGkqbnW'
    tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}/stream"
    headers = {
        "Accept": "application/json",
        "xi-api-key": XI_API_KEY
    }
    data = {
        "text": str(text),
        "model_id": "eleven_multilingual_v2",
        "voice_settings": {
            "stability": 1.0,
            "similarity_boost": 0.0,
            "style": 0.60,
            "use_speaker_boost": False
        }
    }
    response = requests.post(tts_url, headers=headers, json=data, stream=True)
    if response.ok:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
            audio_path = f.name
        logging.debug(f"Audio saved to {audio_path}")
        return audio_path  # Return the file path for automatic playback
    else:
        logging.error(f"Error generating audio: {response.text}")
        return None


# ASR model: Whisper large-v3
model_id = 'openai/whisper-large-v3'
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe_asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
    return_timestamps=True
)


# Handle audio input: transcribe, retrieve from Neo4j, and generate an audio response
def transcribe_and_respond(audio):
    if audio is None:
        logging.error("No audio provided.")
        return None, "No audio provided."

    sr, y = audio
    y = np.array(y).astype(np.float32)

    # Gradio delivers int16 PCM samples; downmix to mono and normalize to [-1, 1]
    if y.ndim > 1:
        y = y.mean(axis=1)
    if np.max(np.abs(y)) > 0:
        y = y / np.max(np.abs(y))

    # Resample to 16 kHz if needed (Whisper expects 16 kHz input)
    target_sr = 16000
    if sr != target_sr:
        logging.debug(f"Resampling audio from {sr} Hz to {target_sr} Hz.")
        y = torchaudio.functional.resample(torch.tensor(y), orig_freq=sr, new_freq=target_sr).numpy()
        sr = target_sr

    # Transcribe the audio with Whisper, forcing English decoding;
    # the pipeline computes the input features itself from the raw waveform
    result = pipe_asr(
        {"raw": y, "sampling_rate": sr},
        generate_kwargs={"language": "english"},
        return_timestamps=False,
    )
    question = result.get("text", "").strip()

    # Log the transcribed text for debugging
    logging.debug(f"Transcribed text: {question}")

    # Retrieve information from Neo4j
    response_text = structured_retriever(question) if question else "I didn't understand the question."

    # Convert the response to audio using Eleven Labs TTS
    audio_path = generate_audio_elevenlabs(response_text) if response_text else None

    # Ensure a valid audio path is returned
    if audio_path and os.path.exists(audio_path):
        logging.debug(f"Generated audio file path: {audio_path}")
    else:
        logging.error("Failed to generate audio or save audio to file.")
        audio_path = None

    return audio_path, response_text


# Clear the transcription state
def clear_transcription_state():
    return None, None


# Gradio interface: audio in, audio out, plus a transcription textbox shared by both buttons
with gr.Blocks(theme="rawrsor1/Everforest") as demo:
    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone"],
            type="numpy",
            label="Speak to Ask"
        )
        audio_output = gr.Audio(
            label="Audio Response",
            type="filepath",
            autoplay=True,
            interactive=False
        )
    transcription_output = gr.Textbox(label="Transcription")

    # Submit button to process the audio input
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=transcribe_and_respond,
        inputs=audio_input,
        outputs=[audio_output, transcription_output]
    )

    # Clear-state interaction
    gr.Button("Clear State").click(
        fn=clear_transcription_state,
        outputs=[audio_output, transcription_output],
        api_name="api_clean_state"
    )

# Launch the Gradio interface
demo.launch(show_error=True, share=True)
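# Configuration note (editorial, not part of the runtime logic): the script assumes the
# OPENAI_API_KEY and ELEVENLABS_API environment variables are set before launch, and that
# the Neo4j database referenced above already contains an 'entity' full-text index for
# db.index.fulltext.queryNodes to search. Example shell setup (values are placeholders):
#   export OPENAI_API_KEY="sk-..."
#   export ELEVENLABS_API="..."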