import gradio as gr
import os
import warnings
import asyncio
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.llms.cerebras import Cerebras
from llama_index.embeddings.mixedbreadai import MixedbreadAIEmbedding
from groq import Groq

# Suppress warnings
warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*")
# Global variables
index = None
query_engine = None

# Load Cerebras API key from Hugging Face secrets
api_key = os.getenv("CEREBRAS_API_KEY")
if not api_key:
    raise ValueError("CEREBRAS_API_KEY is not set in Hugging Face Secrets.")
else:
    print("Cerebras API key loaded successfully.")

# Initialize the Cerebras LLM (Llama 3.3 70B) and register it as the default LLM
os.environ["CEREBRAS_API_KEY"] = api_key
llm = Cerebras(model="llama-3.3-70b", api_key=os.environ["CEREBRAS_API_KEY"])
Settings.llm = llm  # Ensure Cerebras is the LLM being used

# Initialize the MixedbreadAI embedding model and register it as the default
mixedbread_api_key = os.getenv("MXBAI_API_KEY")
if not mixedbread_api_key:
    raise ValueError("MXBAI_API_KEY is not set.")
embed_model = MixedbreadAIEmbedding(api_key=mixedbread_api_key, model_name="mixedbread-ai/mxbai-embed-large-v1")
Settings.embed_model = embed_model  # Without this, llama_index falls back to its default embedding model

# Initialize Groq client for Whisper Large V3
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY is not set.")
else:
    print("Groq API key loaded successfully.")
client = Groq(api_key=groq_api_key)  # Groq client initialization
# Function for audio transcription and translation (Whisper Large V3 from Groq)
def transcribe_or_translate_audio(audio_file, translate=False):
    """
    Transcribes or translates audio using Whisper Large V3 via the Groq API.
    """
    try:
        with open(audio_file, "rb") as file:
            # Translation and transcription share the same call signature;
            # only the endpoint differs.
            endpoint = client.audio.translations if translate else client.audio.transcriptions
            result = endpoint.create(
                file=(audio_file, file.read()),
                model="whisper-large-v3",  # Groq-hosted Whisper Large V3
                response_format="json",
                temperature=0.0,
            )
        return result.text
    except Exception as e:
        return f"Error processing audio: {str(e)}"
# Function to load documents and create the vector index
def load_documents(file_objs):
    """Read the uploaded files, tag each document with its source file name, and build the index."""
    global index, query_engine
    try:
        if not file_objs:
            return "Error: No files selected."
        documents = []
        document_names = []
        for file_obj in file_objs:
            file_name = os.path.basename(file_obj.name)
            document_names.append(file_name)
            loaded_docs = SimpleDirectoryReader(input_files=[file_obj.name]).load_data()
            for doc in loaded_docs:
                doc.metadata["source"] = file_name
                documents.append(doc)
        if not documents:
            return "No documents found in the selected files."
        index = VectorStoreIndex.from_documents(documents, embed_model=embed_model)
        query_engine = index.as_query_engine()
        return f"Successfully loaded {len(documents)} documents from the files: {', '.join(document_names)}"
    except Exception as e:
        return f"Error loading documents: {str(e)}"
async def perform_rag(query, history, audio_file=None, translate_audio=False):
    """Answer a query against the loaded index, optionally appending transcribed audio to the query."""
    global query_engine
    if query_engine is None:
        return history + [(query, "Please load documents first.")]
    try:
        # Handle audio input if provided
        if audio_file:
            transcription = transcribe_or_translate_audio(audio_file, translate=translate_audio)
            query = f"{query} {transcription}".strip()
        response = await asyncio.to_thread(query_engine.query, query)
        answer = str(response)  # Directly get the answer from the response
        # Collect source file names from the retrieved nodes (deduplicated,
        # order preserved), without a "Sources" label
        source_nodes = getattr(response, "source_nodes", None) or []
        source_names = dict.fromkeys(
            node.metadata.get("source", "No source available") for node in source_nodes
        )
        sources = "\n\n".join(source_names)
        # Combine answer with sources (if any) without additional labels
        final_result = f"{answer}\n\n{sources}".strip()
        # Return updated history with the final result
        return history + [(query, final_result)]
    except Exception as e:
        return history + [(query, f"Error processing query: {str(e)}")]
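
# Usage sketch (perform_rag is a coroutine, so drive it with an event loop;
# hypothetical query against an empty chat history):
#   history = asyncio.run(perform_rag("Summarize the report", []))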
# Function to clear the session and reset variables
def clear_all():
    global index, query_engine
    index = None
    query_engine = None
    return None, "", [], ""  # Reset file input, load output, chatbot, and message input to default states
# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")) as demo:
    gr.Markdown("# RAG Multi-file Chat Application with Speech-to-Text")
    chatbot = gr.Chatbot()
    with gr.Row():
        file_input = gr.File(label="Select files to load", file_count="multiple")
        load_btn = gr.Button("Load Documents")
    load_output = gr.Textbox(label="Load Status")
    with gr.Row():
        msg = gr.Textbox(label="Enter your question")
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        translate_checkbox = gr.Checkbox(label="Translate Audio to English Text", value=False)
    clear = gr.Button("Clear")

    # Set up event handlers
    load_btn.click(load_documents, inputs=[file_input], outputs=[load_output])
    # Event handler for text input (audio arguments fall back to their defaults)
    msg.submit(perform_rag, inputs=[msg, chatbot], outputs=[chatbot])
    # Event handler for audio input (fires whenever a new recording or file arrives)
    audio_input.change(perform_rag, inputs=[msg, chatbot, audio_input, translate_checkbox], outputs=[chatbot])
    clear.click(clear_all, outputs=[file_input, load_output, chatbot, msg], queue=False)

# Run the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()
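
# Local-run sketch (assumed PyPI package names; the three API keys must be set
# in the environment, e.g. via Hugging Face Space secrets):
#   pip install gradio groq llama-index llama-index-llms-cerebras \
#       llama-index-embeddings-mixedbreadai
#   CEREBRAS_API_KEY=... MXBAI_API_KEY=... GROQ_API_KEY=... python app.py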