File size: 5,657 Bytes
7f3430b
c1009f8
9916325
 
c1009f8
9916325
7a077d7
b88951f
 
7a077d7
c1009f8
f684fa1
a53d4df
7f2393e
 
 
091ce1a
7f2393e
 
 
 
 
 
 
 
 
9916325
c1009f8
c71d159
c1009f8
c71d159
c1009f8
92b0167
 
b88951f
 
c8dd9c0
b88951f
 
 
 
 
 
f684fa1
c1009f8
 
 
 
 
 
b370650
c1009f8
 
 
 
 
 
8ff2c37
c1009f8
 
 
 
 
b370650
091ce1a
c1009f8
b3b5557
7f2393e
5f00699
c1009f8
7f2393e
c1009f8
 
 
 
7f2393e
f0bef0b
c1009f8
7f2393e
f0bef0b
c1009f8
 
 
 
b370650
7f2393e
 
a54b7bc
7f2393e
5f00699
7f2393e
a54b7bc
5f00699
a54b7bc
7f2393e
c1009f8
7f2393e
 
 
 
 
 
50b7e5a
7702656
 
8527f42
7702656
 
7f2393e
 
8527f42
 
 
 
7f2393e
 
8527f42
c1009f8
7702656
 
7f2393e
c1009f8
7f2393e
 
c1009f8
 
 
 
 
42aa752
c1009f8
 
 
 
 
 
7f2393e
50b7e5a
c1009f8
7f2393e
 
c1009f8
 
7f2393e
1e2e0a1
7f2393e
1e2e0a1
7f2393e
1e2e0a1
 
7f2393e
1e2e0a1
7f2393e
934e44a
c1009f8
b88951f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import io
import os
import tempfile
import threading
import time
from dataclasses import dataclass, field

import gradio as gr
import numpy as np
import requests
import torch
from langchain_community.graphs import Neo4jGraph
from langchain_community.vectorstores import Neo4jVector
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydub import AudioSegment
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor

from utils import determine_pause

# Define AppState dataclass for managing the application's state
@dataclass
class AppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    stopped: bool = False
    conversation: list = []

# Neo4j setup.
# Connection details are read from the environment so the database password
# is not committed to source control (the password previously hard-coded here
# should be treated as leaked and rotated). URL and username fall back to the
# values this app was configured with; the password is required.
graph = Neo4jGraph(
    url=os.environ.get("NEO4J_URI", "neo4j+s://c62d0d35.databases.neo4j.io"),
    username=os.environ.get("NEO4J_USERNAME", "neo4j"),
    password=os.environ["NEO4J_PASSWORD"],
)

# Hybrid-search index over the existing `Document` nodes in `graph`:
# node text is taken from their `text` property and vectors from their
# `embedding` property, using OpenAI embeddings for queries.
vector_index = Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(api_key=os.environ['OPENAI_API_KEY']),
    graph=graph,
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding",
    search_type="hybrid",
)

# Speech recognition with Whisper large-v3: first CUDA device in float16 when
# a GPU is available, otherwise CPU in float32.
model_id = 'openai/whisper-large-v3'
device, torch_dtype = (
    ("cuda:0", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
)

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype
).to(device)
processor = AutoProcessor.from_pretrained(model_id)

# ASR pipeline built from the preloaded model and the processor's
# tokenizer / feature extractor.
pipe_asr = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
    return_timestamps=True,
)

def auto_reset_state():
    """Sleep for two seconds, then return a brand-new default ``AppState``.

    No existing state object is modified: the reset only takes effect if the
    caller actually uses the returned value.
    """
    delay_seconds = 2
    time.sleep(delay_seconds)
    fresh_state = AppState()
    return fresh_state

# Function to process audio input and transcribe it
def transcribe_function(state: AppState, new_chunk):
    """Append an audio chunk to ``state.stream`` and transcribe the whole buffer.

    Args:
        state: Per-session state; ``stream`` and ``sampling_rate`` are
            updated in place.
        new_chunk: ``(sampling_rate, samples)`` pair, as produced by a
            ``gr.Audio(type="numpy")`` component, or ``None``.

    Returns:
        ``(state, text)`` where ``text`` is the transcription of all audio
        buffered in ``state.stream``, or ``""`` if the chunk is missing/empty.
    """
    try:
        sr, y = new_chunk[0], new_chunk[1]
    except TypeError:
        # e.g. new_chunk is None when no audio was recorded.
        print(f"Error chunk structure: {type(new_chunk)}, content: {new_chunk}")
        return state, ""

    if y is None or len(y) == 0:
        return state, ""

    y = y.astype(np.float32)
    # The Whisper pipeline rejects multi-channel audio, so downmix to mono.
    # Assumes a (samples, channels) layout for multi-channel input — TODO confirm.
    if y.ndim > 1:
        y = y.mean(axis=1)

    # Peak-normalise this chunk to [-1, 1]; an all-silent chunk is left as is.
    max_abs_y = np.max(np.abs(y))
    if max_abs_y > 0:
        y = y / max_abs_y

    if state.stream is not None and len(state.stream) > 0:
        state.stream = np.concatenate([state.stream, y])
    else:
        state.stream = y
    state.sampling_rate = sr

    result = pipe_asr({"array": state.stream, "sampling_rate": sr}, return_timestamps=False)
    full_text = result.get("text", "")
    return state, full_text

# Prompt used to answer a question from the retrieved context. Its
# {context} / {question} placeholders are filled in by
# generate_response_with_prompt below.
prompt = ChatPromptTemplate.from_template(
    """Answer the question based only on the following context:
{context}

Question: {question}
Use natural language and be concise.
Answer:"""
)


# Function to generate a response using the prompt and the context
def generate_response_with_prompt(context, question):
    """Ask the OpenAI chat model to answer *question* using *context*.

    Args:
        context: Retrieved text the answer should be based on.
        question: The user's question.

    Returns:
        The model's reply text with surrounding whitespace stripped.
    """
    messages = prompt.format_messages(context=context, question=question)
    llm = ChatOpenAI(temperature=0, api_key=os.environ['OPENAI_API_KEY'])
    # `invoke` is the supported entry point; calling the model directly is deprecated.
    response = llm.invoke(messages)
    return response.content.strip()

# Function to generate audio with Eleven Labs TTS
def generate_audio_elevenlabs(text, voice_id='ehbJzYLQFpwbJmGkqbnW', timeout=30):
    """Synthesize *text* with the ElevenLabs streaming text-to-speech endpoint.

    Args:
        text: Text to speak.
        voice_id: ElevenLabs voice to use (defaults to this app's voice).
        timeout: Seconds to wait for the connection and for each read before
            giving up.

    Returns:
        Path to a temporary ``.mp3`` file holding the audio (created with
        ``delete=False``, so it is not removed automatically), or ``None`` if
        the request failed.
    """
    XI_API_KEY = os.environ['ELEVENLABS_API']
    tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}/stream"
    headers = {"Accept": "application/json", "xi-api-key": XI_API_KEY}
    data = {"text": text, "model_id": "eleven_multilingual_v2", "voice_settings": {"stability": 1.0}}
    try:
        # Using the response as a context manager releases the connection on
        # every path, including the error branch.
        with requests.post(tts_url, headers=headers, json=data, stream=True, timeout=timeout) as response:
            if not response.ok:
                print(f"Error generating audio: {response.text}")
                return None
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
                for chunk in response.iter_content(chunk_size=1024):
                    f.write(chunk)
                return f.name
    except requests.RequestException as exc:
        # Network errors and timeouts use the same "print and return None"
        # contract as an HTTP error response.
        print(f"Error generating audio: {exc}")
        return None

# Lucene query-syntax characters. They are mapped to spaces so free-form user
# text cannot break the full-text query or inject operators into it.
_LUCENE_SPECIAL_CHARS = str.maketrans({ch: " " for ch in '+-&|!(){}[]^"~*?:\\/'})


def generate_full_text_query(text: str) -> str:
    """Build a fuzzy Lucene full-text query that requires every word of *text*.

    Each word gets a ``~2`` fuzzy suffix (edit distance up to 2), and words are
    joined with ``AND``. Returns ``""`` when *text* has no searchable words.

    >>> generate_full_text_query("Who founded Acme?")
    'Who~2 AND founded~2 AND Acme~2'
    """
    words = text.translate(_LUCENE_SPECIAL_CHARS).split()
    return " AND ".join(f"{word}~2" for word in words)


# Define the function to retrieve information using Neo4j and the vector store
def retriever(question: str):
    """Answer *question* using graph full-text hits plus vector-search results.

    Queries the ``entity`` full-text index (top 2 hits) and ``vector_index``
    similarity search, then passes both to ``generate_response_with_prompt``.

    Returns:
        The generated answer text.
    """
    structured_query = """
    CALL db.index.fulltext.queryNodes('entity', $query, {limit: 2})
    YIELD node, score
    RETURN node.id AS entity, node.text AS context, score
    ORDER BY score DESC
    LIMIT 2
    """
    full_text_query = generate_full_text_query(question)
    # An empty Lucene query does not parse, so skip the lookup when the
    # question has no searchable words.
    structured_data = graph.query(structured_query, {"query": full_text_query}) if full_text_query else []
    structured_response = "\n".join(f"{record['entity']}: {record['context']}" for record in structured_data)

    unstructured_data = [el.page_content for el in vector_index.similarity_search(question)]
    unstructured_response = "\n".join(unstructured_data)

    combined_context = f"Structured data:\n{structured_response}\n\nUnstructured data:\n{unstructured_response}"
    return generate_response_with_prompt(combined_context, question)

# Function to handle the entire audio query and response process
def process_audio_query(state: AppState, audio_input):
    """Transcribe the recorded question, answer it, and speak the answer.

    Args:
        state: Per-session ``AppState``.
        audio_input: Audio from the microphone component (see
            ``transcribe_function`` for the accepted shape).

    Returns:
        ``(audio_path, state)``. ``audio_path`` is ``None`` when nothing was
        transcribed or when speech synthesis failed.
    """
    state, transcription = transcribe_function(state, audio_input)
    # Each submission is a separate question: drop the buffered audio so the
    # next query is not transcribed together with this one.
    state.stream = None
    if not transcription.strip():
        # Nothing intelligible was said; skip the LLM and TTS round-trips.
        return None, state
    response_text = retriever(transcription)
    audio_path = generate_audio_elevenlabs(response_text)
    return audio_path, state

# Gradio UI: streaming microphone input, a submit button, and an autoplaying
# player for the spoken answer.
with gr.Blocks() as interface:
    audio_input = gr.Audio(
        sources="microphone",
        type="numpy",
        streaming=True,
        every=0.1,
    )
    submit_button = gr.Button("Submit")
    audio_output = gr.Audio(autoplay=True, type="filepath")
    # Per-session state, initialised from a default AppState.
    state = gr.State(AppState())

    # Submit: (state, recorded audio) -> (answer audio file path, updated state).
    # NOTE(review): with streaming=True, confirm the click event receives the
    # full recording rather than only the most recent chunk.
    submit_button.click(
        process_audio_query,
        inputs=[state, audio_input],
        outputs=[audio_output, state],
    )

# Launch the Gradio app.
# NOTE(review): this runs at import time as well as when the file is executed
# as a script; consider an `if __name__ == "__main__":` guard if the module is
# ever imported elsewhere.
interface.launch()