Pijush2023's picture
Update app.py
3595ee8 verified
raw
history blame
9.08 kB
import os
import re
import tempfile
import threading
import time
from dataclasses import dataclass

import gradio as gr
import numpy as np
import requests
import torch
from langchain_community.graphs import Neo4jGraph
from langchain_community.vectorstores import Neo4jVector
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
# Define AppState to store audio state information
@dataclass
class AppState:
stream: np.ndarray | None = None
sampling_rate: int = 0
pause_detected: bool = False
started_talking: bool = False
# Neo4j setup.
# Credentials are read from the environment rather than hard-coded: a password
# committed to source is readable by anyone with access to the file (the
# previously committed one should be rotated). The URL and username default to
# the instance this app was built against; the password must be supplied via
# NEO4J_PASSWORD (a missing variable raises KeyError at startup).
graph = Neo4jGraph(
    url=os.environ.get("NEO4J_URI", "neo4j+s://c62d0d35.databases.neo4j.io"),
    username=os.environ.get("NEO4J_USERNAME", "neo4j"),
    password=os.environ["NEO4J_PASSWORD"],
)
# Initialize the vector index with Neo4j.
# Hybrid search over the existing `Document` nodes of `graph`: their `text`
# property is the searchable content and `embedding` holds the OpenAI vectors.
_openai_embeddings = OpenAIEmbeddings(api_key=os.environ['OPENAI_API_KEY'])
vector_index = Neo4jVector.from_existing_graph(
    _openai_embeddings,
    graph=graph,
    search_type="hybrid",
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding",
)
# Define the ASR model with Whisper.
# The model runs on the first CUDA device in float16 when a GPU is available,
# otherwise on the CPU in float32.
model_id = 'openai/whisper-large-v3'
_has_cuda = torch.cuda.is_available()
device = "cuda:0" if _has_cuda else "cpu"
torch_dtype = torch.float16 if _has_cuda else torch.float32

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
).to(device)
processor = AutoProcessor.from_pretrained(model_id)

# Speech-recognition pipeline built from the model above: audio is split into
# 15 s chunks processed in batches of 16, and generation is capped at 128 new
# tokens.
pipe_asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=device,
    torch_dtype=torch_dtype,
    chunk_length_s=15,
    batch_size=16,
    max_new_tokens=128,
    return_timestamps=True,
)
# Adjusted function to determine if a pause occurred
def determine_pause(
    audio: np.ndarray,
    sampling_rate: int,
    state: "AppState",
    pause_threshold: float = 0.5,
    min_speech_duration: float = 0.5,
    energy_threshold: float = 0.02,
    frame_ms: float = 30.0,
) -> bool:
    """Take in the stream and decide whether the speaker has paused.

    Uses a simple energy-based voice-activity detector. The signal is split
    into ``frame_ms`` frames. Integer PCM is first scaled to [-1, 1], and
    multi-channel ``(n, channels)`` input is averaged to mono. A frame counts
    as speech when its RMS exceeds ``energy_threshold``.

    The first call that sees more than ``min_speech_duration`` seconds of
    speech sets ``state.started_talking`` and returns ``False``. After that, a
    pause is reported once the audio ends with more than ``pause_threshold``
    seconds without speech.

    Args:
        audio: Accumulated samples of the current utterance.
        sampling_rate: Sample rate of ``audio`` in Hz; ``<= 0`` means "unknown"
            and never yields a pause.
        state: Session state; ``started_talking`` is read and may be set.
        pause_threshold: Trailing silence (s) that counts as a pause.
        min_speech_duration: Total speech (s) needed to count as talking.
        energy_threshold: Frame RMS above which a frame is speech.
        frame_ms: Analysis frame length in milliseconds.

    Returns:
        True when a pause was detected after speech had started.
    """
    if audio is None or sampling_rate <= 0 or len(audio) == 0:
        return False

    samples = np.asarray(audio)
    if np.issubdtype(samples.dtype, np.integer):
        # Integer PCM: rescale so energy_threshold does not depend on bit depth.
        full_scale = np.iinfo(samples.dtype).max
        samples = samples.astype(np.float32) / full_scale
    else:
        samples = samples.astype(np.float32)
    if samples.ndim > 1:
        samples = samples.mean(axis=1)  # down-mix (n, channels) to mono

    frame_len = max(1, int(sampling_rate * frame_ms / 1000))
    frame_dur = frame_len / sampling_rate
    n_frames = len(samples) // frame_len
    duration = len(samples) / sampling_rate

    if n_frames:
        frames = samples[: n_frames * frame_len].reshape(n_frames, frame_len)
        voiced = np.sqrt(np.mean(frames ** 2, axis=1)) > energy_threshold
    else:
        voiced = np.zeros(0, dtype=bool)

    dur_vad = float(np.count_nonzero(voiced)) * frame_dur
    voiced_idx = np.flatnonzero(voiced)
    # End time of the last speech frame; silence after it is the candidate pause.
    speech_end = (voiced_idx[-1] + 1) * frame_dur if voiced_idx.size else 0.0
    trailing_silence = duration - speech_end

    # Log the duration and VAD result for debugging
    print(f"Duration after VAD: {dur_vad:.3f} s, Total Duration: {duration:.3f} s")

    # Check if speech has started
    if dur_vad > min_speech_duration and not state.started_talking:
        print("Started talking")
        state.started_talking = True
        return False

    if state.started_talking and trailing_silence > pause_threshold:
        print("Pause detected")
        return True
    return False
# Function to process audio input, detect pauses, and handle state
def _reset_state(state: "AppState") -> None:
    """Clear the per-utterance fields so the next chunk starts a new utterance."""
    state.stream = None
    state.started_talking = False
    state.pause_detected = False


def process_audio(audio: tuple, state: "AppState"):
    """Accumulate one streamed microphone chunk and answer once the speaker pauses.

    Args:
        audio: ``(sampling_rate, samples)`` tuple from the streaming audio
            input, or ``None``.
        state: Session state holding the buffered utterance and detection flags.

    Returns:
        ``(audio_path, state)``. ``audio_path`` is the path of the synthesized
        spoken reply, or ``None`` when nothing was produced on this call.
    """
    # Ensure audio input is not None and has valid data
    if audio is None or audio[1] is None:
        print("Audio input is None or empty.")
        return None, state

    if state.stream is None:
        state.stream = audio[1]
        state.sampling_rate = audio[0]
    else:
        state.stream = np.concatenate((state.stream, audio[1]))

    # Check for a pause in speech
    state.pause_detected = determine_pause(state.stream, state.sampling_rate, state)
    if not (state.pause_detected and state.started_talking):
        return None, state

    # Transcribe only the buffered utterance. transcribe_function prepends its
    # first argument to the chunk, so passing the utterance there as well
    # would feed it to the recognizer twice.
    _, transcription, _ = transcribe_function(None, (state.sampling_rate, state.stream))
    # The utterance is consumed whatever happens next. Resetting here keeps an
    # empty transcription/response (or an exception below) from re-triggering
    # on the same stale buffer with every later chunk.
    _reset_state(state)

    transcription = (transcription or "").strip()
    print(f"Transcription: {transcription}")
    if not transcription:
        print("No transcription available.")
        return None, state

    # Retrieve hybrid response using Neo4j and other methods
    response_text = retriever(transcription)
    print(f"Response: {response_text}")
    if not response_text:
        print("No response generated.")
        return None, state

    # Generate audio from the response text
    audio_path = generate_audio_elevenlabs(response_text)
    return audio_path, state
# Function to process audio input and transcribe it
def transcribe_function(stream, new_chunk):
    """Append a chunk of audio to *stream* and transcribe the result.

    Args:
        stream: Previously accumulated samples (e.g. the first value returned
            by an earlier call), or ``None``/empty to start fresh.
        new_chunk: ``(sampling_rate, samples)`` tuple. The samples are
            converted to float32 and peak-normalized. Multi-channel
            ``(n, channels)`` input is averaged to mono.

    Returns:
        ``(stream, text, text)``: the updated stream and the transcription
        (twice). If *new_chunk* is malformed or holds no samples, returns
        ``(stream, "", None)`` with *stream* unchanged.
    """
    try:
        sr, y = new_chunk[0], new_chunk[1]
    except (TypeError, IndexError):
        # None, a non-sequence, or a tuple that is too short.
        print(f"Error chunk structure: {type(new_chunk)}, content: {new_chunk}")
        return stream, "", None
    if y is None or len(y) == 0:
        return stream, "", None

    y = np.asarray(y).astype(np.float32)
    if y.ndim > 1:
        # The ASR pipeline only accepts single-channel (1-D) audio.
        y = y.mean(axis=1)
    # Peak-normalize to [-1, 1] so integer PCM and float input look alike.
    max_abs_y = np.max(np.abs(y))
    if max_abs_y > 0:
        y = y / max_abs_y

    if stream is not None and len(stream) > 0:
        stream = np.concatenate([stream, y])
    else:
        stream = y

    result = pipe_asr({"array": stream, "sampling_rate": sr}, return_timestamps=False)
    full_text = result.get("text", "")
    return stream, full_text, full_text
# Function to generate a full-text search query for Neo4j.
# Characters with special meaning in Lucene query syntax, which Neo4j full-text
# indexes use. Left in place they change the query's meaning or make it fail to
# parse; e.g. the trailing "?" of a transcribed question is a Lucene wildcard.
_LUCENE_SPECIAL_CHARS = re.compile(r'[+\-&|!(){}\[\]^"~*?:\\/]')


def generate_full_text_query(input: str) -> str:
    """Build a fuzzy Lucene query that requires every word of *input*.

    Lucene special characters are replaced by spaces. Each remaining word gets
    a ``~2`` fuzzy suffix (edit distance up to 2), and the words are joined
    with ``AND``.

    >>> generate_full_text_query("live music?")
    'live~2 AND music~2'

    Returns:
        The query string, or ``""`` when *input* contains no searchable words.
    """
    words = _LUCENE_SPECIAL_CHARS.sub(" ", input).split()
    if not words:
        return ""  # Return an empty string or a default query if desired
    return " AND ".join(f"{word}~2" for word in words)
# Function to generate audio with Eleven Labs TTS
def generate_audio_elevenlabs(text, timeout=30):
    """Convert *text* to speech with ElevenLabs and save it as a temporary MP3.

    Posts to the ElevenLabs ``text-to-speech/<voice>/stream`` endpoint with the
    fixed voice, model and voice settings below. The API key is taken from the
    ``ELEVENLABS_API`` environment variable.

    Args:
        text: Text to synthesize; passed through ``str()``.
        timeout: Timeout in seconds handed to ``requests.post`` so a stalled
            connection cannot block the caller indefinitely.

    Returns:
        Path of a ``.mp3`` file created with ``delete=False`` (the caller owns
        it), or ``None`` if the API returned an error status or the request
        failed.

    Raises:
        KeyError: If ``ELEVENLABS_API`` is not set.
    """
    XI_API_KEY = os.environ['ELEVENLABS_API']
    VOICE_ID = 'ehbJzYLQFpwbJmGkqbnW'
    tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}/stream"
    headers = {
        "Accept": "application/json",
        "xi-api-key": XI_API_KEY,
    }
    data = {
        "text": str(text),
        "model_id": "eleven_multilingual_v2",
        "voice_settings": {
            "stability": 1.0,
            "similarity_boost": 0.0,
            "style": 0.60,
            "use_speaker_boost": False,
        },
    }

    audio_path = None
    try:
        # Using the streamed response as a context manager releases the
        # connection even if writing the file fails part-way.
        with requests.post(
            tts_url, headers=headers, json=data, stream=True, timeout=timeout
        ) as response:
            if not response.ok:
                print(f"Error generating audio: {response.text}")
                return None
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
                audio_path = f.name
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
    except requests.exceptions.RequestException as err:
        print(f"Error generating audio: {err}")
        if audio_path is not None:
            os.remove(audio_path)  # don't leave a truncated MP3 behind
        return None
    return audio_path
# Define the template for generating responses based on context.
# The {context} and {question} placeholders are filled in by
# generate_response_with_prompt. The instructions ask for short, direct answers
# about Birmingham, Alabama, without a greeting.
template = """I am a guide for Birmingham, Alabama. I can provide recommendations and insights about the city, including events and activities.
Ask your question directly, and I'll provide a precise and quick, short and crisp response in a conversational and straightforward way without any Greet.
Context:
{context}
Question: {question}
Answer concisely:"""
# Create a prompt object using the template (built once at import time).
prompt = ChatPromptTemplate.from_template(template)
# Function to generate a response using the prompt and the context
def generate_response_with_prompt(context, question):
    """Answer *question* with the chat model, grounded in *context*.

    Fills the module-level ``prompt`` and sends it to ``ChatOpenAI`` with
    temperature 0. The API key is read from ``OPENAI_API_KEY``.

    Args:
        context: Retrieved context text to place in the prompt.
        question: The user's question.

    Returns:
        The model's reply text with surrounding whitespace stripped.

    Raises:
        KeyError: If ``OPENAI_API_KEY`` is not set.
    """
    # format_messages keeps the chat structure. prompt.format would flatten it
    # into one "Human: ..." string that is then re-wrapped as a message.
    messages = prompt.format_messages(context=context, question=question)
    llm = ChatOpenAI(temperature=0, api_key=os.environ['OPENAI_API_KEY'])
    # .invoke is the supported Runnable entry point; calling the model
    # directly (llm(...)) is deprecated in LangChain.
    response = llm.invoke(messages)
    return response.content.strip()
# Define the function to generate a hybrid response using Neo4j and other retrieval methods
def retriever(question: str):
    """Answer *question* using context retrieved from Neo4j.

    Two kinds of context are collected:

    * structured: up to 2 nodes from the ``entity`` full-text index that match
      a fuzzy query built from the question, rendered as ``"<id>: <text>"``;
    * unstructured: the ``page_content`` of documents returned by
      ``vector_index.similarity_search``.

    Both are combined into one context string and passed to
    ``generate_response_with_prompt``.

    Returns:
        The generated answer text.
    """
    structured_response = ""
    full_text_query = generate_full_text_query(question)
    # An empty string is not a valid Lucene query and makes the full-text
    # procedure fail, so the keyword lookup is skipped when the question has
    # no searchable words (e.g. only punctuation or whitespace).
    if full_text_query:
        structured_query = """
CALL db.index.fulltext.queryNodes('entity', $query, {limit: 2})
YIELD node, score
RETURN node.id AS entity, node.text AS context, score
ORDER BY score DESC
LIMIT 2
"""
        structured_data = graph.query(structured_query, {"query": full_text_query})
        structured_response = "\n".join(
            f"{record['entity']}: {record['context']}" for record in structured_data
        )

    unstructured_response = "\n".join(
        doc.page_content for doc in vector_index.similarity_search(question)
    )

    combined_context = f"Structured data:\n{structured_response}\n\nUnstructured data:\n{unstructured_response}"
    return generate_response_with_prompt(combined_context, question)
# Create Gradio interface for audio input and output.
# Microphone audio is streamed to process_audio chunk by chunk (live=True with
# a streaming input), and the spoken reply is played back automatically.
# gr.Interface links the State input with the State output, so the per-session
# AppState is carried from one call to the next.
_mic_input = gr.Audio(sources="microphone", type="numpy", streaming=True)
_session_state_in = gr.State(AppState())
_spoken_reply = gr.Audio(type="filepath", autoplay=True, interactive=False)
_session_state_out = gr.State()

interface = gr.Interface(
    fn=lambda audio, state: process_audio(audio, state),
    inputs=[_mic_input, _session_state_in],
    outputs=[_spoken_reply, _session_state_out],
    description="Ask questions via audio and receive audio responses.",
    live=True,
    allow_flagging="never",
)

# Launch the Gradio app
interface.launch()