# Exported from a Hugging Face Space (status: runtime error; commit e61f851, 7,349 bytes).
import gradio as gr
import os
import warnings
import asyncio
from melo.api import TTS
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Document, Settings
from llama_index.llms.cerebras import Cerebras
from llama_index.embeddings.mixedbreadai import MixedbreadAIEmbedding, EncodingFormat
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from groq import Groq
import io
import nltk
from dotenv import load_dotenv # Import dotenv to load .env variables
import os
# NOTE(review): duplicate of the `import os` at the top of the file — harmless,
# but could be removed.
# Point MeCab (used by MeloTTS for tokenization) at the bundled unidic
# dictionary configuration shipped alongside this app.
os.environ["MECABRC"] = os.path.join(os.getcwd(), "unidic", "dicdir", "mecabrc")
# Load environment variables from .env file
load_dotenv()
# Download the POS tagger data NLTK/MeloTTS need for English text processing.
nltk.download('averaged_perceptron_tagger_eng')
# Suppress noisy tokenizer deprecation warnings
warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*")
# Global state shared by the Gradio callbacks below: the vector index built
# from the loaded documents and the query engine derived from it.
index = None
query_engine = None
# Initialize MeloTTS for text-to-speech
device = 'cpu' # Set to 'cuda' if a GPU is available
language = 'EN' # Default language
model = TTS(language=language, device=device)
# Load Cerebras API key from environment; fail fast if it is missing.
api_key = os.getenv("CEREBRAS_API_KEY")
if not api_key:
    raise ValueError("CEREBRAS_API_KEY is not set in environment variables.")
else:
    print("Cerebras API key loaded successfully.")
# Initialize the Cerebras LLM and register it as llama_index's default LLM.
os.environ["CEREBRAS_API_KEY"] = api_key
llm = Cerebras(model="llama-3.3-70b", api_key=os.environ["CEREBRAS_API_KEY"]) # Llama 3.3 70B hosted by Cerebras
Settings.llm = llm # Ensure Cerebras is the LLM being used
# Initialize the Mixedbread embedding model used for document vectorization.
# NOTE(review): unlike the other keys, MXBAI_API_KEY is not validated here — a
# missing key will only surface later when embedding is attempted.
mixedbread_api_key = os.getenv("MXBAI_API_KEY")
embed_model = MixedbreadAIEmbedding(api_key=mixedbread_api_key, model_name="mixedbread-ai/mxbai-embed-large-v1")
# Initialize the Groq client for Whisper Large V3 (speech-to-text); fail fast if missing.
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY is not set in environment variables.")
else:
    print("Groq API key loaded successfully.")
client = Groq(api_key=groq_api_key) # Groq client initialization
def transcribe_or_translate_audio(audio_file, translate=False):
    """
    Transcribe or translate audio using Whisper Large V3 via the Groq API.

    Parameters:
        audio_file: Path to the audio file to process.
        translate: If True, translate the speech to English text; otherwise
            transcribe it in the spoken language.

    Returns:
        The transcribed/translated text, or an ``"Error processing audio: ..."``
        string if the API call fails (callers treat the result as plain text
        either way).
    """
    try:
        with open(audio_file, "rb") as file:
            # Both endpoints take an identical request; only the endpoint
            # differs, so pick it once instead of duplicating the call in
            # two branches.
            endpoint = client.audio.translations if translate else client.audio.transcriptions
            result = endpoint.create(
                file=(audio_file, file.read()),
                model="whisper-large-v3",  # Groq-hosted Whisper Large V3
                response_format="json",
                temperature=0.0,
            )
            return result.text
    except Exception as e:
        # Best-effort: surface the failure as text so the chat flow continues.
        return f"Error processing audio: {str(e)}"
# Load user-selected files and build the retrieval index over them.
def load_documents(file_objs):
    """
    Read the selected files, build a vector index over their contents, and
    store the resulting index/query engine in the module-level globals.

    Parameters:
        file_objs: List of Gradio file objects (each exposing a ``.name`` path),
            or a falsy value when nothing was selected.

    Returns:
        A human-readable status string describing success or failure.
    """
    global index, query_engine
    try:
        if not file_objs:
            return "Error: No files selected."
        docs = []
        names = []
        for uploaded in file_objs:
            base_name = os.path.basename(uploaded.name)
            names.append(base_name)
            for doc in SimpleDirectoryReader(input_files=[uploaded.name]).load_data():
                # Tag each document with its originating file so answers can
                # cite their source.
                doc.metadata["source"] = base_name
                docs.append(doc)
        if not docs:
            return "No documents found in the selected files."
        index = VectorStoreIndex.from_documents(docs, llm=llm, embed_model=embed_model)
        query_engine = index.as_query_engine()
        return f"Successfully loaded {len(docs)} documents from the files: {', '.join(names)}"
    except Exception as e:
        return f"Error loading documents: {str(e)}"
async def perform_rag(query, history, audio_file=None, translate_audio=False):
    """
    Answer a user query against the loaded documents and synthesize the
    answer as speech.

    Parameters:
        query: The typed question (may be empty when only audio is provided).
        history: Current chatbot history as a list of (user, bot) tuples.
        audio_file: Optional path to a recorded audio file; if given, its
            transcription is appended to the query text.
        translate_audio: If True, the audio is translated to English before
            being appended.

    Returns:
        Tuple of (updated chatbot history, path to the generated speech file
        or None when documents are not loaded / an error occurred).
    """
    global query_engine
    if query_engine is None:
        return history + [("Please load documents first.", None)], None  # None fills the audio output slot
    try:
        # If audio was supplied, transcribe (or translate) it and fold the
        # resulting text into the query.
        if audio_file:
            transcription = transcribe_or_translate_audio(audio_file, translate=translate_audio)
            query = f"{query} {transcription}".strip()
        # query_engine.query is blocking; run it off the event-loop thread.
        response = await asyncio.to_thread(query_engine.query, query)
        answer = str(response)  # Extract the answer text from the response
        # Append source filenames (without a "Sources" label) when available.
        # NOTE(review): llama_index responses usually expose `source_nodes`,
        # not `get_documents`; this branch may never fire, leaving `sources`
        # empty — confirm against the installed llama_index version.
        if hasattr(response, "get_documents"):
            relevant_docs = response.get_documents()
            if relevant_docs:
                sources = "\n\n".join([f"{doc.metadata.get('source', 'No source available')}" for doc in relevant_docs])
            else:
                sources = ""
        else:
            sources = ""
        # Combine the answer with the sources (if any) without extra labels.
        final_result = f"{answer}\n\n{sources}".strip()
        # Generate speech for the answer (sources excluded) with MeloTTS.
        output_audio_path = "output.wav"
        model.tts_to_file(answer, model.hps.data.spk2id['EN-US'], output_audio_path, speed=1.0)
        # Return the updated history and the synthesized audio file path.
        return history + [(query, final_result)], output_audio_path
    except Exception as e:
        return history + [(query, f"Error processing query: {str(e)}")], None
# Reset the session state and the UI widgets to their empty defaults.
def clear_all():
    """
    Drop the loaded index and query engine and return cleared values for
    the UI components.

    Returns:
        4-tuple resetting the file input, load status, chatbot history,
        and message box.
    """
    global index, query_engine
    index = query_engine = None
    return None, "", [], ""
# Build the Gradio interface: chat display, audio playback, file loader,
# text/audio inputs, and the event wiring that connects them.
with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")) as demo:
    gr.Markdown("# RAG Multi-file Chat Application with Speech-to-Text and Text-to-Speech")
    chatbot = gr.Chatbot()
    # Plays the MeloTTS rendering of each answer (filepath from perform_rag).
    audio_output = gr.Audio(label="Response Audio", type="filepath")
    with gr.Row():
        file_input = gr.File(label="Select files to load", file_count="multiple")
        load_btn = gr.Button("Load Documents")
    load_output = gr.Textbox(label="Load Status")
    with gr.Row():
        msg = gr.Textbox(label="Enter your question")
        audio_input = gr.Audio(type="filepath", label="RECORD")
        translate_checkbox = gr.Checkbox(label="Translate Audio to English Text", value=False)
    clear = gr.Button("Clear")
    # Set up event handlers
    load_btn.click(load_documents, inputs=[file_input], outputs=[load_output])
    # Text submission: audio_file/translate_audio fall back to their
    # defaults (None/False) inside perform_rag.
    msg.submit(perform_rag, inputs=[msg, chatbot], outputs=[chatbot, audio_output]) # audio_output receives the TTS file
    # Audio input: fires whenever the recording changes; passes the
    # translate checkbox along.
    audio_input.change(perform_rag, inputs=[msg, chatbot, audio_input, translate_checkbox], outputs=[chatbot, audio_output]) # audio_output receives the TTS file
    clear.click(clear_all, outputs=[file_input, load_output, chatbot, msg], queue=False)
# Run the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()