File size: 6,394 Bytes
b79f67c
5f2f2c0
 
 
 
28f5feb
5545170
28f5feb
b79f67c
 
5f2f2c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7605ad
61821e2
b074388
d5275bb
ed72e23
 
5f2f2c0
 
 
 
 
 
 
61821e2
5f2f2c0
 
 
 
 
 
 
 
 
 
 
61821e2
5f2f2c0
 
 
 
 
 
 
61821e2
5f2f2c0
 
 
 
 
 
b79f67c
5f2f2c0
 
 
 
 
 
 
 
 
 
7a1bf08
 
5f2f2c0
 
7a1bf08
5f2f2c0
 
 
 
a1d4dd3
5f2f2c0
 
 
 
 
 
 
 
 
 
 
b79f67c
 
5f2f2c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b79f67c
5f2f2c0
 
 
 
 
 
 
 
 
 
0ae5a40
5f2f2c0
74128fa
 
 
5f2f2c0
 
 
 
2ab080a
 
 
 
 
 
5f2f2c0
 
 
 
 
28f5feb
a9d9240
aa5ef3e
28f5feb
61821e2
5f2f2c0
 
 
 
 
 
28f5feb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import gradio as gr
import os
import warnings
import asyncio
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Document, Settings
from llama_index.llms.cerebras import Cerebras
from llama_index.embeddings.mixedbreadai import MixedbreadAIEmbedding, EncodingFormat
from groq import Groq
import io

# Silence the tokenizer "clean_up_tokenization_spaces" warning
warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*")

# Shared state: rebuilt by load_documents(), reset by clear_all()
index = None
query_engine = None

# Cerebras API key, read from the environment (Hugging Face secrets)
api_key = os.getenv("CEREBRAS_API_KEY")
if not api_key:
    raise ValueError("CEREBRAS_API_KEY is not set in Hugging Face Secrets.")
print("Cerebras API key loaded successfully.")

# Cerebras-hosted "llama-3.3-70b" becomes the default LLM for llama_index
os.environ["CEREBRAS_API_KEY"] = api_key
llm = Cerebras(model="llama-3.3-70b", api_key=api_key)
Settings.llm = llm

# Mixedbread embedding model; its key is passed through as read from the environment
mixedbread_api_key = os.getenv("MXBAI_API_KEY")
embed_model = MixedbreadAIEmbedding(
    api_key=mixedbread_api_key,
    model_name="mixedbread-ai/mxbai-embed-large-v1",
)

# Groq client, used for Whisper Large V3 speech-to-text
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY is not set.")
print("Groq API key loaded successfully.")
client = Groq(api_key=groq_api_key)

# Speech-to-text helper (Whisper Large V3 served by Groq)
def transcribe_or_translate_audio(audio_file, translate=False):
    """
    Sends the audio file at `audio_file` to Groq's Whisper Large V3 model.

    With `translate=True` the translations endpoint is used, otherwise the
    transcriptions endpoint. Returns the resulting text; on any failure an
    "Error processing audio: ..." string is returned instead of raising.
    """
    try:
        with open(audio_file, "rb") as audio_handle:
            # Both endpoints take the same arguments, so only the target differs.
            if translate:
                endpoint = client.audio.translations
            else:
                endpoint = client.audio.transcriptions
            outcome = endpoint.create(
                file=(audio_file, audio_handle.read()),
                model="whisper-large-v3",
                response_format="json",
                temperature=0.0,
            )
            return outcome.text
    except Exception as exc:
        return f"Error processing audio: {exc}"

# Build the vector index and its query engine from the uploaded files
def load_documents(file_objs):
    """
    Reads every uploaded file, tags each loaded document with its file name
    under metadata["source"], and rebuilds the global `index` and
    `query_engine`.

    Returns a human-readable status string; any exception is reported in
    that string rather than raised.
    """
    global index, query_engine
    try:
        if not file_objs:
            return "Error: No files selected."

        collected = []
        names = []
        for uploaded in file_objs:
            path = uploaded.name
            label = os.path.basename(path)
            names.append(label)
            for doc in SimpleDirectoryReader(input_files=[path]).load_data():
                doc.metadata["source"] = label
                collected.append(doc)

        if not collected:
            return "No documents found in the selected files."

        index = VectorStoreIndex.from_documents(collected, llm=llm, embed_model=embed_model)
        query_engine = index.as_query_engine()

        joined_names = ", ".join(names)
        return f"Successfully loaded {len(collected)} documents from the files: {joined_names}"
    except Exception as exc:
        return f"Error loading documents: {exc}"

# Prefix that transcribe_or_translate_audio() puts on its failure messages
_AUDIO_ERROR_PREFIX = "Error processing audio:"


def _collect_sources(response):
    """Return the de-duplicated source labels of a query response, one per paragraph."""
    if hasattr(response, "get_documents"):
        items = response.get_documents()
    else:
        # llama_index responses expose the retrieved chunks as `source_nodes`.
        items = getattr(response, "source_nodes", None)

    labels = []
    for item in items or []:
        metadata = getattr(item, "metadata", None) or {}
        labels.append(str(metadata.get("source", "No source available")))
    # Several chunks usually come from the same file; list each file once.
    return "\n\n".join(dict.fromkeys(labels))


async def perform_rag(query, history, audio_file=None, translate_audio=False):
    """
    Answers a question against the loaded documents and appends the exchange
    to the chat history.

    Args:
        query: Question text from the textbox (may be empty or None).
        history: Existing chat history as a list of (user, bot) pairs, or None.
        audio_file: Optional path to an audio file whose transcription is
            appended to `query`.
        translate_audio: When True the audio is translated instead of
            transcribed.

    Returns:
        A new history list. It is unchanged when there is nothing to ask;
        otherwise one (user, bot) pair is appended containing the answer,
        a notice, or an error message.
    """
    history = list(history or [])
    query = query or ""

    # Snapshot the engine so a concurrent clear_all() cannot swap it mid-request.
    engine = query_engine
    if engine is None:
        # The notice belongs in the bot slot, not the user slot.
        return history + [(query or None, "Please load documents first.")]

    try:
        # Handle audio input if provided
        if audio_file:
            # The transcription is a blocking network call; keep it off the event loop.
            transcription = await asyncio.to_thread(
                transcribe_or_translate_audio, audio_file, translate=translate_audio
            )
            if transcription.startswith(_AUDIO_ERROR_PREFIX):
                # Report the failure instead of sending the error text to the LLM.
                return history + [(query or None, transcription)]
            query = f"{query} {transcription}".strip()

        if not query.strip():
            # Nothing to ask (e.g. the audio widget was cleared with an empty textbox).
            return history

        response = await asyncio.to_thread(engine.query, query)
        answer = str(response)  # Directly get the answer from the response

        # Combine answer with sources (if any) without additional labels
        sources = _collect_sources(response)
        final_result = f"{answer}\n\n{sources}".strip()

        # Return updated history with the final result
        return history + [(query, final_result)]
    except Exception as e:
        return history + [(query, f"Error processing query: {e}")]

# Reset the session: drop the index/query engine and blank the UI widgets
def clear_all():
    """
    Clears the global `index` and `query_engine` and returns the default
    values for (file input, load status, chatbot, message input).
    """
    global index, query_engine
    index = query_engine = None

    cleared_files = None
    cleared_status = ""
    cleared_chat = []
    cleared_message = ""
    return cleared_files, cleared_status, cleared_chat, cleared_message

# Create the Gradio interface.
# Components are created in the order they appear below: title, chat window,
# document-loading row, question/audio row, then the Clear button.
with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")) as demo:
    gr.Markdown("# RAG Multi-file Chat Application with Speech-to-Text")

    # Chat history; read and written by perform_rag, reset by clear_all
    chatbot = gr.Chatbot()

    # Document loading: file picker, trigger button and status text
    with gr.Row():
        file_input = gr.File(label="Select files to load", file_count="multiple")
        load_btn = gr.Button("Load Documents")
        load_output = gr.Textbox(label="Load Status")

    # Question entry: typed text, audio (passed to handlers as a file path) and translate toggle
    with gr.Row():
        msg = gr.Textbox(label="Enter your question")
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        translate_checkbox = gr.Checkbox(label="Translate Audio to English Text", value=False)

    clear = gr.Button("Clear")

    # Set up event handlers
    # Build the index from the selected files and show the status message
    load_btn.click(load_documents, inputs=[file_input], outputs=[load_output])

    # Submitting the textbox runs perform_rag with the text and chat history only
    # (audio_file and translate_audio keep their defaults)
    msg.submit(perform_rag, inputs=[msg, chatbot], outputs=[chatbot])

    # A change of the audio widget runs perform_rag with the current textbox
    # content, the chat history, the audio file path and the translate flag.
    # NOTE(review): `change` presumably also fires when the audio is cleared - confirm.
    audio_input.change(perform_rag, inputs=[msg, chatbot, audio_input, translate_checkbox], outputs=[chatbot])

    # Reset file input, load status, chat history and textbox via clear_all
    clear.click(clear_all, outputs=[file_input, load_output, chatbot, msg], queue=False)

# Run the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()