import gradio as gr
import torch
import requests
import tempfile
import threading
import numpy as np
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Neo4jVector
from langchain_community.graphs import Neo4jGraph
from langchain_core.prompts import ChatPromptTemplate
import time
import os
from dataclasses import dataclass
# Define AppState to store audio state information
@dataclass
class AppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    started_talking: bool = False
# Neo4j setup
graph = Neo4jGraph(
    url="neo4j+s://c62d0d35.databases.neo4j.io",
    username="neo4j",
    password="_x8f-_aAQvs2NB0x6s0ZHSh3W_y-HrENDbgStvsUCM0"
)
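# Note: retriever() below queries a Neo4j full-text index named 'entity', which
# is assumed to already exist in this database. An illustrative Cypher statement
# to create such an index (label and property names are assumptions) would be:
#   CREATE FULLTEXT INDEX entity IF NOT EXISTS
#   FOR (d:Document) ON EACH [d.text]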
# Initialize the vector index with Neo4j
vector_index = Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(api_key=os.environ['OPENAI_API_KEY']),
    graph=graph,
    search_type="hybrid",
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding",
)
# Define the ASR model with Whisper
model_id = 'openai/whisper-large-v3'
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe_asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
    return_timestamps=True
)
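# Illustrative usage (not executed at import time): the pipeline accepts a raw
# waveform plus its sampling rate and returns a dict whose "text" field holds
# the transcription, e.g.
#   pipe_asr({"array": samples, "sampling_rate": 16000})["text"]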
# Function to determine if a pause occurred
def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
    """Take in the accumulated stream and decide whether the speaker has paused."""
    duration = len(audio) / sampling_rate
    # Simple energy-based estimate of speech duration: count samples whose
    # amplitude exceeds a fraction of the peak (stands in for a real VAD model).
    audio_f = audio.astype(np.float32)
    peak = np.max(np.abs(audio_f)) if len(audio_f) > 0 else 0.0
    speech_samples = np.sum(np.abs(audio_f) > 0.1 * peak) if peak > 0 else 0
    dur_vad = speech_samples / sampling_rate
    if dur_vad > 0.5 and not state.started_talking:
        print("Started talking")
        state.started_talking = True
        return False
    print(f"Duration after VAD: {dur_vad:.3f} s")
    # Treat more than one second of trailing non-speech as a pause.
    return (duration - dur_vad) > 1
# Function to process audio input, detect pauses, and handle state
def process_audio(audio: tuple, state: AppState):
    if state.stream is None:
        state.stream = audio[1]
        state.sampling_rate = audio[0]
    else:
        state.stream = np.concatenate((state.stream, audio[1]))

    # Check for a pause in speech
    pause_detected = determine_pause(state.stream, state.sampling_rate, state)
    state.pause_detected = pause_detected

    if state.pause_detected and state.started_talking:
        # Transcribe the audio when a pause is detected (pass None as the running
        # stream so the accumulated audio is not concatenated with itself)
        _, transcription, _ = transcribe_function(None, (state.sampling_rate, state.stream))
        print(f"Transcription: {transcription}")

        # Retrieve hybrid response using Neo4j and other methods
        response_text = retriever(transcription)
        print(f"Response: {response_text}")

        # Generate audio from the response text
        audio_path = generate_audio_elevenlabs(response_text)

        # Reset state for the next input
        state.stream = None
        state.started_talking = False
        state.pause_detected = False

        return audio_path, state
    return None, state
# Function to process audio input and transcribe it
def transcribe_function(stream, new_chunk):
    try:
        sr, y = new_chunk[0], new_chunk[1]
    except TypeError:
        print(f"Error chunk structure: {type(new_chunk)}, content: {new_chunk}")
        return stream, "", None

    if y is None or len(y) == 0:
        return stream, "", None

    # Normalize the chunk to [-1, 1] before feeding it to Whisper
    y = y.astype(np.float32)
    max_abs_y = np.max(np.abs(y))
    if max_abs_y > 0:
        y = y / max_abs_y

    # Append the new chunk to the running stream
    if stream is not None and len(stream) > 0:
        stream = np.concatenate([stream, y])
    else:
        stream = y

    result = pipe_asr({"array": stream, "sampling_rate": sr}, return_timestamps=False)
    full_text = result.get("text", "")
    return stream, full_text, full_text
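# Illustrative call (assumption): Gradio streams microphone audio as
# (sampling_rate, numpy_array) tuples, so a single chunk would be transcribed as
#   stream, text, _ = transcribe_function(None, (16000, samples))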
# Function to generate a full-text search query for Neo4j
def generate_full_text_query(input: str) -> str:
    words = [el for el in input.split() if el]
    if not words:
        return ""  # Return an empty string or a default query if desired
    full_text_query = ""
    for word in words[:-1]:
        full_text_query += f" {word}~2 AND"
    full_text_query += f" {words[-1]}~2"
    return full_text_query.strip()
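# Worked example (illustrative): each word gets a Lucene fuzzy-match suffix (~2)
# and the terms are joined with AND, so
#   generate_full_text_query("vegan restaurants downtown")
#   == "vegan~2 AND restaurants~2 AND downtown~2"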
# Function to generate audio with Eleven Labs TTS
def generate_audio_elevenlabs(text):
    XI_API_KEY = os.environ['ELEVENLABS_API']
    VOICE_ID = 'ehbJzYLQFpwbJmGkqbnW'
    tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}/stream"
    headers = {
        "Accept": "application/json",
        "xi-api-key": XI_API_KEY
    }
    data = {
        "text": str(text),
        "model_id": "eleven_multilingual_v2",
        "voice_settings": {
            "stability": 1.0,
            "similarity_boost": 0.0,
            "style": 0.60,
            "use_speaker_boost": False
        }
    }
    response = requests.post(tts_url, headers=headers, json=data, stream=True)
    if response.ok:
        # Stream the MP3 response to a temporary file and return its path
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
            audio_path = f.name
        return audio_path
    else:
        print(f"Error generating audio: {response.text}")
        return None
# Define the template for generating responses based on context
template = """I am a guide for Birmingham, Alabama. I can provide recommendations and insights about the city, including events and activities.
Ask your question directly, and I'll provide a precise, short, and conversational response without any greeting.
Context:
{context}
Question: {question}
Answer concisely:"""
# Create a prompt object using the template
prompt = ChatPromptTemplate.from_template(template)

# Function to generate a response using the prompt and the context
def generate_response_with_prompt(context, question):
    formatted_prompt = prompt.format(
        context=context,
        question=question
    )
    llm = ChatOpenAI(temperature=0, api_key=os.environ['OPENAI_API_KEY'])
    response = llm.invoke(formatted_prompt)
    return response.content.strip()
# Define the function to generate a hybrid response using Neo4j and other retrieval methods
def retriever(question: str):
    # Structured retrieval: fuzzy full-text search over the 'entity' index
    ft_query = generate_full_text_query(question)
    structured_data = []
    if ft_query:
        structured_query = """
        CALL db.index.fulltext.queryNodes('entity', $query, {limit: 2})
        YIELD node, score
        RETURN node.id AS entity, node.text AS context, score
        ORDER BY score DESC
        LIMIT 2
        """
        structured_data = graph.query(structured_query, {"query": ft_query})
    structured_response = "\n".join([f"{record['entity']}: {record['context']}" for record in structured_data])

    # Unstructured retrieval: vector similarity search over the embedded documents
    unstructured_data = [el.page_content for el in vector_index.similarity_search(question)]
    unstructured_response = "\n".join(unstructured_data)

    combined_context = f"Structured data:\n{structured_response}\n\nUnstructured data:\n{unstructured_response}"
    final_response = generate_response_with_prompt(combined_context, question)
    return final_response
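# Illustrative flow (assumption about the data in this graph): for a question
# such as "What events are happening downtown?", the combined context passed to
# generate_response_with_prompt() would look roughly like
#   Structured data:
#   <node.id>: <node.text>
#   Unstructured data:
#   <similar document text>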
# Create Gradio interface for audio input and output
interface = gr.Interface(
    fn=process_audio,
    inputs=[
        gr.Audio(sources=["microphone"], type="numpy", streaming=True),
        gr.State(AppState())
    ],
    outputs=[
        gr.Audio(type="filepath", autoplay=True, interactive=False),
        gr.State()
    ],
    live=True,
    description="Ask questions via audio and receive audio responses.",
    allow_flagging="never"
)

# Launch the Gradio app
interface.launch()