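"""RAG multi-file chat application with speech-to-text.

Flow, as implemented below: files are uploaded through a Gradio UI and
indexed with LlamaIndex (Mixedbread embeddings); questions are answered by a
Cerebras-hosted Llama model; optional audio input is transcribed or
translated with Whisper Large V3 via the Groq API.
"""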
import gradio as gr
import os
import warnings
import asyncio
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Document, Settings
from llama_index.llms.cerebras import Cerebras
from llama_index.embeddings.mixedbreadai import MixedbreadAIEmbedding
from groq import Groq
# Suppress warnings
warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*")
# Global variables
index = None
query_engine = None
# Load Cerebras API key from Hugging Face secrets
api_key = os.getenv("CEREBRAS_API_KEY")
if not api_key:
    raise ValueError("CEREBRAS_API_KEY is not set in Hugging Face Secrets.")
else:
    print("Cerebras API key loaded successfully.")
# Initialize Cerebras LLM and embedding model
os.environ["CEREBRAS_API_KEY"] = api_key
llm = Cerebras(model="llama-3.3-70b", api_key=os.environ["CEREBRAS_API_KEY"])  # Llama 3.3 70B served by Cerebras
Settings.llm = llm # Ensure Cerebras is the LLM being used
# Initialize Mixedbread Embedding model
mixedbread_api_key = os.getenv("MXBAI_API_KEY")
if not mixedbread_api_key:
    raise ValueError("MXBAI_API_KEY is not set in Hugging Face Secrets.")
embed_model = MixedbreadAIEmbedding(api_key=mixedbread_api_key, model_name="mixedbread-ai/mxbai-embed-large-v1")
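# Note: embed_model is passed explicitly to VectorStoreIndex below; setting
# Settings.embed_model = embed_model instead would make it the global default.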
# Initialize Groq client for Whisper Large V3
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY is not set.")
else:
    print("Groq API key loaded successfully.")
client = Groq(api_key=groq_api_key) # Groq client initialization
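# Note: the Groq SDK can also read GROQ_API_KEY from the environment; passing
# it explicitly here just makes the dependency visible.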
# Function for audio transcription and translation (Whisper Large V3 from Groq)
def transcribe_or_translate_audio(audio_file, translate=False):
    """
    Transcribes or translates audio using Whisper Large V3 via the Groq API.
    """
    try:
        with open(audio_file, "rb") as file:
            # Translation always yields English text; transcription keeps the
            # spoken language. Both endpoints accept the same arguments.
            endpoint = client.audio.translations if translate else client.audio.transcriptions
            result = endpoint.create(
                file=(audio_file, file.read()),
                model="whisper-large-v3",  # Groq-hosted Whisper Large V3
                response_format="json",
                temperature=0.0,
            )
        return result.text
    except Exception as e:
        return f"Error processing audio: {str(e)}"
# Function to load documents and create index
def load_documents(file_objs):
    global index, query_engine
    try:
        if not file_objs:
            return "Error: No files selected."
        documents = []
        document_names = []
        for file_obj in file_objs:
            file_name = os.path.basename(file_obj.name)
            document_names.append(file_name)
            # SimpleDirectoryReader picks a parser based on the file extension
            loaded_docs = SimpleDirectoryReader(input_files=[file_obj.name]).load_data()
            for doc in loaded_docs:
                doc.metadata["source"] = file_name  # remember which file each chunk came from
                documents.append(doc)
        if not documents:
            return "No documents found in the selected files."
        index = VectorStoreIndex.from_documents(documents, llm=llm, embed_model=embed_model)
        query_engine = index.as_query_engine()
        return f"Successfully loaded {len(documents)} documents from the files: {', '.join(document_names)}"
    except Exception as e:
        return f"Error loading documents: {str(e)}"
async def perform_rag(query, history, audio_file=None, translate_audio=False):
    global query_engine
    if query_engine is None:
        return history + [("Please load documents first.", None)]
    try:
        # Handle audio input if provided
        if audio_file:
            transcription = transcribe_or_translate_audio(audio_file, translate=translate_audio)
            query = f"{query} {transcription}".strip()
        # Run the synchronous query engine in a worker thread so the event loop stays free
        response = await asyncio.to_thread(query_engine.query, query)
        answer = str(response)  # Directly get the answer from the response
        # If retrieved source nodes are available, list their file names without a "Sources" label
        source_nodes = getattr(response, "source_nodes", None) or []
        sources = "\n\n".join(
            node.node.metadata.get("source", "No source available")
            for node in source_nodes
        )
        # Combine answer with sources (if any) without additional labels
        final_result = f"{answer}\n\n{sources}".strip()
        # Return updated history with the final result
        return history + [(query, final_result)]
    except Exception as e:
        return history + [(query, f"Error processing query: {str(e)}")]
# Function to clear the session and reset variables
def clear_all():
    global index, query_engine
    index = None
    query_engine = None
    return None, "", [], ""  # Reset file input, load output, chatbot, and message input to default states
# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")) as demo:
    gr.Markdown("# RAG Multi-file Chat Application with Speech-to-Text")
    chatbot = gr.Chatbot()
    with gr.Row():
        file_input = gr.File(label="Select files to load", file_count="multiple")
        load_btn = gr.Button("Load Documents")
    load_output = gr.Textbox(label="Load Status")
    with gr.Row():
        msg = gr.Textbox(label="Enter your question")
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        translate_checkbox = gr.Checkbox(label="Translate Audio to English Text", value=False)
    clear = gr.Button("Clear")
    # Set up event handlers
    load_btn.click(load_documents, inputs=[file_input], outputs=[load_output])
    # Event handler for text input (audio defaults to None)
    msg.submit(perform_rag, inputs=[msg, chatbot], outputs=[chatbot])
    # Event handler for audio input (combines any typed text with the uploaded audio)
    audio_input.change(perform_rag, inputs=[msg, chatbot, audio_input, translate_checkbox], outputs=[chatbot])
    clear.click(clear_all, outputs=[file_input, load_output, chatbot, msg], queue=False)
# Run the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()
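# To run locally (a sketch; assumes this file is saved as app.py):
#   export CEREBRAS_API_KEY=... MXBAI_API_KEY=... GROQ_API_KEY=...
#   python app.py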