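"""RAG multi-file chat application with speech-to-text and text-to-speech.

Pipeline: uploaded documents are indexed into a LlamaIndex VectorStoreIndex
using Mixedbread embeddings, queries are answered by a Cerebras-hosted Llama
model, voice input is transcribed or translated with Whisper Large V3 via the
Groq API, and answers are spoken aloud with MeloTTS. The UI is built with
Gradio.
"""
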
import gradio as gr
import os
import warnings
import asyncio
from melo.api import TTS
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.llms.cerebras import Cerebras
from llama_index.embeddings.mixedbreadai import MixedbreadAIEmbedding
from groq import Groq
import nltk
from dotenv import load_dotenv  # Loads variables from a .env file into the environment

# Point MeCab (used by MeloTTS for text processing) at the bundled unidic dictionary
os.environ["MECABRC"] = os.path.join(os.getcwd(), "unidic", "dicdir", "mecabrc")

# Load environment variables from .env file
load_dotenv()
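# Expected .env contents (key names taken from the os.getenv calls below):
#   CEREBRAS_API_KEY=...
#   MXBAI_API_KEY=...
#   GROQ_API_KEY=...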

nltk.download('averaged_perceptron_tagger_eng')  # POS tagger used by MeloTTS's English text normalization

# Suppress warnings
warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*")

# Global variables
index = None
query_engine = None

# Initialize MeloTTS for text-to-speech
device = 'cpu'  # Set to 'cuda' if a GPU is available
language = 'EN'  # Default language
model = TTS(language=language, device=device)
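# Speaker IDs for the loaded language are exposed via model.hps.data.spk2id;
# perform_rag() below synthesizes with the 'EN-US' voice.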

# Load Cerebras API key from environment
api_key = os.getenv("CEREBRAS_API_KEY")
if not api_key:
    raise ValueError("CEREBRAS_API_KEY is not set in environment variables.")
else:
    print("Cerebras API key loaded successfully.")

# Initialize Cerebras LLM and embedding model
os.environ["CEREBRAS_API_KEY"] = api_key
llm = Cerebras(model="llama-3.3-70b", api_key=os.environ["CEREBRAS_API_KEY"])  # Llama 3.3 70B served by Cerebras
Settings.llm = llm  # Ensure Cerebras is the LLM being used
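# Settings.llm is LlamaIndex's global default, so index.as_query_engine()
# below routes response synthesis through the Cerebras model.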

# Initialize Mixedbread embedding model
mixedbread_api_key = os.getenv("MXBAI_API_KEY")
if not mixedbread_api_key:
    raise ValueError("MXBAI_API_KEY is not set in environment variables.")
embed_model = MixedbreadAIEmbedding(api_key=mixedbread_api_key, model_name="mixedbread-ai/mxbai-embed-large-v1")

# Initialize Groq client for Whisper Large V3
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY is not set in environment variables.")
else:
    print("Groq API key loaded successfully.")
client = Groq(api_key=groq_api_key)  # Groq client initialization

# Function for audio transcription and translation (Whisper Large V3 from Groq)
def transcribe_or_translate_audio(audio_file, translate=False):
    """
    Transcribes or (if translate=True) translates audio to English text
    using Whisper Large V3 via the Groq API.
    """
    try:
        with open(audio_file, "rb") as file:
            # Translation and transcription share the same request shape;
            # only the endpoint differs.
            endpoint = client.audio.translations if translate else client.audio.transcriptions
            result = endpoint.create(
                file=(audio_file, file.read()),
                model="whisper-large-v3",  # Groq-hosted Whisper Large V3
                response_format="json",
                temperature=0.0
            )
            return result.text
    except Exception as e:
        return f"Error processing audio: {str(e)}"

# Function to load documents and create index
def load_documents(file_objs):
    global index, query_engine
    try:
        if not file_objs:
            return "Error: No files selected."

        documents = []
        document_names = []
        for file_obj in file_objs:
            file_name = os.path.basename(file_obj.name)
            document_names.append(file_name)
            loaded_docs = SimpleDirectoryReader(input_files=[file_obj.name]).load_data()
            for doc in loaded_docs:
                doc.metadata["source"] = file_name
                documents.append(doc)

        if not documents:
            return "No documents found in the selected files."

        index = VectorStoreIndex.from_documents(documents, llm=llm, embed_model=embed_model)
        query_engine = index.as_query_engine()

        return f"Successfully loaded {len(documents)} documents from the files: {', '.join(document_names)}"
    except Exception as e:
        return f"Error loading documents: {str(e)}"

async def perform_rag(query, history, audio_file=None, translate_audio=False):
    global query_engine
    if query_engine is None:
        return history + [("Please load documents first.", None)], None  # Tambahkan None untuk output audio

    try:
        # Handle audio input if provided
        if audio_file:
            transcription = transcribe_or_translate_audio(audio_file, translate=translate_audio)
            query = f"{query} {transcription}".strip()

        response = await asyncio.to_thread(query_engine.query, query)
        answer = str(response)  # Extract the answer text from the response

        # If relevant source documents were retrieved, list their names without a "Sources" label
        source_nodes = getattr(response, "source_nodes", None) or []
        sources = "\n\n".join(
            node.node.metadata.get("source", "No source available") for node in source_nodes
        )

        # Combine the answer with the sources (if any) without an extra label
        final_result = f"{answer}\n\n{sources}".strip()

        # Generate speech audio for the answer using MeloTTS
        output_audio_path = "output.wav"
        model.tts_to_file(answer, model.hps.data.spk2id['EN-US'], output_audio_path, speed=1.0)

        # Return the updated chat history and the generated audio file
        return history + [(query, final_result)], output_audio_path
    except Exception as e:
        return history + [(query, f"Error processing query: {str(e)}")], None

# Function to clear the session and reset variables
def clear_all():
    global index, query_engine
    index = None
    query_engine = None
    return None, "", [], ""  # Reset file input, load output, chatbot, and message input to default states

# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")) as demo:
    gr.Markdown("# RAG Multi-file Chat Application with Speech-to-Text and Text-to-Speech")
    
    chatbot = gr.Chatbot()
    audio_output = gr.Audio(label="Response Audio", type="filepath")
    with gr.Row():
        file_input = gr.File(label="Select files to load", file_count="multiple")
        load_btn = gr.Button("Load Documents")
        load_output = gr.Textbox(label="Load Status")
    with gr.Row():
        msg = gr.Textbox(label="Enter your question")
        audio_input = gr.Audio(type="filepath", label="Record Audio")
        translate_checkbox = gr.Checkbox(label="Translate Audio to English Text", value=False)
    clear = gr.Button("Clear")

    # Set up event handlers
    load_btn.click(load_documents, inputs=[file_input], outputs=[load_output])

    # Text input: process the typed question (no audio file, so transcription is skipped)
    msg.submit(perform_rag, inputs=[msg, chatbot], outputs=[chatbot, audio_output])

    # Audio input: transcribe (or translate) the recording and fold it into the query
    audio_input.change(perform_rag, inputs=[msg, chatbot, audio_input, translate_checkbox], outputs=[chatbot, audio_output])

    clear.click(clear_all, outputs=[file_input, load_output, chatbot, msg], queue=False)

# Run the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()