import gradio as gr
import os
import warnings
import asyncio

from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.llms.cerebras import Cerebras
from llama_index.embeddings.mixedbreadai import MixedbreadAIEmbedding
from groq import Groq

# Suppress warnings
warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*")

# Global variables
index = None
query_engine = None

# Load the Cerebras API key from Hugging Face secrets
api_key = os.getenv("CEREBRAS_API_KEY")
if not api_key:
    raise ValueError("CEREBRAS_API_KEY is not set in Hugging Face Secrets.")
print("Cerebras API key loaded successfully.")

# Initialize the Cerebras LLM (Llama 3.3 70B) and register it as the default
os.environ["CEREBRAS_API_KEY"] = api_key
llm = Cerebras(model="llama-3.3-70b", api_key=api_key)
Settings.llm = llm  # Ensure Cerebras is the LLM used by llama_index

# Initialize the Mixedbread embedding model and register it as the default
mixedbread_api_key = os.getenv("MXBAI_API_KEY")
embed_model = MixedbreadAIEmbedding(
    api_key=mixedbread_api_key,
    model_name="mixedbread-ai/mxbai-embed-large-v1",
)
Settings.embed_model = embed_model

# Initialize the Groq client for Whisper Large V3
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY is not set.")
print("Groq API key loaded successfully.")

client = Groq(api_key=groq_api_key)


def transcribe_or_translate_audio(audio_file, translate=False):
    """Transcribes or translates audio using Whisper Large V3 via the Groq API."""
    try:
        with open(audio_file, "rb") as file:
            if translate:
                result = client.audio.translations.create(
                    file=(audio_file, file.read()),
                    model="whisper-large-v3",
                    response_format="json",
                    temperature=0.0,
                )
            else:
                result = client.audio.transcriptions.create(
                    file=(audio_file, file.read()),
                    model="whisper-large-v3",
                    response_format="json",
                    temperature=0.0,
                )
            return result.text
    except Exception as e:
        return f"Error processing audio: {str(e)}"


def load_documents(file_objs):
    """Loads the selected files and builds the vector index and query engine."""
    global index, query_engine
    try:
        if not file_objs:
            return "Error: No files selected."

        documents = []
        document_names = []
        for file_obj in file_objs:
            file_name = os.path.basename(file_obj.name)
            document_names.append(file_name)
            loaded_docs = SimpleDirectoryReader(input_files=[file_obj.name]).load_data()
            for doc in loaded_docs:
                # Record which file each document came from for source attribution
                doc.metadata["source"] = file_name
                documents.append(doc)

        if not documents:
            return "No documents found in the selected files."
        # Settings supplies the LLM and embedding model configured above
        index = VectorStoreIndex.from_documents(documents)
        query_engine = index.as_query_engine()
        return (
            f"Successfully loaded {len(documents)} documents from the files: "
            f"{', '.join(document_names)}"
        )
    except Exception as e:
        return f"Error loading documents: {str(e)}"


async def perform_rag(query, history, audio_file=None, translate_audio=False):
    """Answers a query against the loaded documents, optionally folding in audio."""
    global query_engine
    if query_engine is None:
        return history + [("Please load documents first.", None)]
    try:
        # Handle audio input if provided
        if audio_file:
            transcription = transcribe_or_translate_audio(audio_file, translate=translate_audio)
            query = f"{query} {transcription}".strip()

        # Run the blocking query in a worker thread so the event loop stays responsive
        response = await asyncio.to_thread(query_engine.query, query)
        answer = str(response)

        # If relevant documents were retrieved, list their sources without a "Sources" label
        source_nodes = getattr(response, "source_nodes", None) or []
        sources = "\n\n".join(
            node.node.metadata.get("source", "No source available")
            for node in source_nodes
        )

        # Combine the answer with the sources (if any) without additional labels
        final_result = f"{answer}\n\n{sources}".strip()

        # Return the updated history with the final result
        return history + [(query, final_result)]
    except Exception as e:
        return history + [(query, f"Error processing query: {str(e)}")]


def clear_all():
    """Resets the file input, load status, chat history, and message box."""
    global index, query_engine
    index = None
    query_engine = None
    return None, "", [], ""


# Create the Gradio interface
with gr.Blocks(
    theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")
) as demo:
    gr.Markdown("# RAG Multi-file Chat Application with Speech-to-Text")

    chatbot = gr.Chatbot()

    with gr.Row():
        file_input = gr.File(label="Select files to load", file_count="multiple")
        load_btn = gr.Button("Load Documents")
        load_output = gr.Textbox(label="Load Status")

    with gr.Row():
        msg = gr.Textbox(label="Enter your question")
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        translate_checkbox = gr.Checkbox(label="Translate Audio to English Text", value=False)
        clear = gr.Button("Clear")

    # Set up event handlers
    load_btn.click(load_documents, inputs=[file_input], outputs=[load_output])

    # Text input: run RAG on the typed question only
    msg.submit(perform_rag, inputs=[msg, chatbot], outputs=[chatbot])

    # Audio input: run RAG on the uploaded audio (plus any typed text)
    audio_input.change(
        perform_rag,
        inputs=[msg, chatbot, audio_input, translate_checkbox],
        outputs=[chatbot],
    )

    clear.click(clear_all, outputs=[file_input, load_output, chatbot, msg], queue=False)

# Run the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()
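# Usage sketch (assumptions: the script is saved as app.py, and the package
# names below are inferred from the imports above):
#
#   pip install gradio groq llama-index-core \
#       llama-index-llms-cerebras llama-index-embeddings-mixedbreadai
#   export CEREBRAS_API_KEY=... MXBAI_API_KEY=... GROQ_API_KEY=...
#   python app.py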