muhammadsalmanalfaridzi committed
Commit 5f2f2c0 · verified · 1 Parent(s): a6ab9ef

RAG LlamaIndex

Files changed (1)
  1. app.py +146 -116
app.py CHANGED
@@ -1,126 +1,156 @@
- import os
  import gradio as gr
- from argparse import ArgumentParser
  from groq import Groq
- import base64
  import io

- # Initialize Groq client
- API_KEY = os.environ['GROQ_API_KEY']
- client = Groq(api_key=API_KEY)
-
- REVISION = 'v1.0.4'
-
- def _get_args():
-     parser = ArgumentParser()
-     parser.add_argument("--revision", type=str, default=REVISION)
-     parser.add_argument("--share", action="store_true", default=False, help="Create a publicly shareable link for the interface.")
-     return parser.parse_args()
-
- def process_image(image):
-     buffered = io.BytesIO()
-     image.save(buffered, format="JPEG")
-     return buffered.getvalue()
-
- def translate_audio(audio_file):
-     with open(audio_file, "rb") as file:
-         translation = client.audio.translations.create(
-             file=(audio_file, file.read()),
-             model="whisper-large-v3",
-             response_format="json",
-             temperature=0.0
-         )
-     return translation.text
-
- def transcribe_audio(audio_file):
-     with open(audio_file, "rb") as file:
-         transcription = client.audio.transcriptions.create(
-             file=(audio_file, file.read()),
-             model="whisper-large-v3",
-             response_format="json",
-             temperature=0.0
-         )
-     return transcription.text
-
- def predict(chat_history, query, image, audio, translate):
-     final_query = query.strip()
-
-     if audio:
-         audio_file_path = audio
-         if translate:
-             translation_text = translate_audio(audio_file_path)
-             final_query = translation_text.strip()
-             chat_history.append({'role': 'assistant', 'content': translation_text})
-         else:
-             transcribed_text = transcribe_audio(audio_file_path)
-             final_query = f"{final_query} {transcribed_text}".strip()

-     image_data = process_image(image) if image else None
-     messages = create_messages(final_query, image_data)

-     if not messages:
-         error_message = "No valid input provided. Please enter a query or upload an image/audio."
-         chat_history.append({'role': 'assistant', 'content': error_message})
-         return chat_history

      try:
-         completion = client.chat.completions.create(
-             model="llama-3.2-90b-vision-preview",
-             messages=messages,
-             temperature=1,
-             max_tokens=1500,
-             top_p=1,
-             stream=False,
-         )
-
-         response_text = completion.choices[0].message.content.strip()
-         chat_history.append({'role': 'user', 'content': final_query})
-         chat_history.append({'role': 'assistant', 'content': response_text})
      except Exception as e:
-         response_text = f"Error: {str(e)}"
-         chat_history.append({'role': 'user', 'content': final_query})
-         chat_history.append({'role': 'assistant', 'content': response_text})
-
-     return chat_history
-
- def create_messages(query, image_data):
-     messages = []
-     if query:
-         messages.append({'role': 'user', 'content': query})
-     if image_data:
-         image_base64 = f"data:image/jpeg;base64,{base64.b64encode(image_data).decode()}"
-         messages.append({
-             'role': 'user',
-             'content': [
-                 {"type": "text", "text": "Please analyze this image."},
-                 {"type": "image_url", "image_url": {"url": image_base64}}
-             ]
-         })
-     return messages
-
- def clear_history():
-     return []
-
- def main():
-     args = _get_args()

-     with gr.Blocks(css="#chatbox {height: 400px; background-color: #f9f9f9; padding: 20px; border-radius: 10px; }") as demo:
-         gr.Markdown("<h1 style='text-align: center; color: #4a4a4a;'>Llama-3.2-90b-vision-preview</h1>")
-
-         chatbox = gr.Chatbot(type='messages', elem_id="chatbox")
-         query = gr.Textbox(label="Type your query here...", placeholder="Enter your question or command...", lines=2)
-         image_input = gr.Image(type="pil", label="Upload Image")
-         audio_input = gr.Audio(type="filepath", label="Upload Audio")
-         translate_checkbox = gr.Checkbox(label="Translate Audio to English Text")
-
-         with gr.Row():
-             submit_btn = gr.Button("Submit", variant="primary", elem_id="submit-btn")
-             clear_btn = gr.Button("Clear History", variant="secondary", elem_id="clear-btn")
-
-         submit_btn.click(predict, inputs=[chatbox, query, image_input, audio_input, translate_checkbox], outputs=chatbox)
-         clear_btn.click(clear_history, outputs=chatbox)
-
-     demo.launch(share=args.share)
-
- if __name__ == '__main__':
-     main()

  import gradio as gr
+ import os
+ import warnings
+ import asyncio
+ from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Document, Settings
+ from llama_index.llms.cerebras import Cerebras
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
  from groq import Groq
  import io

+ # Suppress warnings
+ warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*")
+
+ # Global variables
+ index = None
+ query_engine = None
+
+ # Load Cerebras API key from Hugging Face secrets
+ api_key = os.getenv("CEREBRAS_API_KEY")
+ if not api_key:
+     raise ValueError("CEREBRAS_API_KEY is not set in Hugging Face Secrets.")
+ else:
+     print("Cerebras API key loaded successfully.")
+
+ # Initialize Cerebras LLM and embedding model
+ os.environ["CEREBRAS_API_KEY"] = api_key
+ llm = Cerebras(model="llama3.1-70b", api_key=os.environ["CEREBRAS_API_KEY"])
+ Settings.llm = llm
+ embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
+
+ # Initialize Groq client for Whisper Large V3
+ groq_api_key = os.getenv("GROQ_API_KEY")
+ if not groq_api_key:
+     raise ValueError("GROQ_API_KEY is not set.")
+ else:
+     print("Groq API key loaded successfully.")
+ client = Groq(api_key=groq_api_key)
+
+ # Function for audio transcription and translation (Whisper Large V3 from Groq)
+ def transcribe_or_translate_audio(audio_file, translate=False):
+     """
+     Transcribes or translates audio using Whisper Large V3 via Groq API.
+     """
+     try:
+         with open(audio_file, "rb") as file:
+             if translate:
+                 result = client.audio.translations.create(
+                     file=(audio_file, file.read()),
+                     model="whisper-large-v3",
+                     response_format="json",
+                     temperature=0.0
+                 )
+                 return result.text
+             else:
+                 result = client.audio.transcriptions.create(
+                     file=(audio_file, file.read()),
+                     model="whisper-large-v3",
+                     response_format="json",
+                     temperature=0.0
+                 )
+                 return result.text
+     except Exception as e:
+         return f"Error processing audio: {str(e)}"

+ # Function to load documents and create index
+ def load_documents(file_objs):
+     global index, query_engine
+     try:
+         if not file_objs:
+             return "Error: No files selected."
+
+         documents = []
+         document_names = []
+         for file_obj in file_objs:
+             document_names.append(file_obj.name)
+             loaded_docs = SimpleDirectoryReader(input_files=[file_obj.name]).load_data()
+             for doc in loaded_docs:
+                 doc.metadata["source"] = file_obj.name
+                 documents.append(doc)
+
+         if not documents:
+             return "No documents found in the selected files."

+         index = VectorStoreIndex.from_documents(documents, llm=llm, embed_model=embed_model)
+         query_engine = index.as_query_engine()
+
+         return f"Successfully loaded {len(documents)} documents from the files: {', '.join(document_names)}"
+     except Exception as e:
+         return f"Error loading documents: {str(e)}"
+
+ async def perform_rag(query, history, audio_file=None, translate_audio=False):
+     global query_engine
+     if query_engine is None:
+         return history + [("Please load documents first.", None)]

      try:
+         # Handle audio input if provided
+         if audio_file:
+             transcription = transcribe_or_translate_audio(audio_file, translate=translate_audio)
+             query = f"{query} {transcription}".strip()
+
+         response = await asyncio.to_thread(query_engine.query, query)
+         answer = str(response)  # Directly get the answer from the response
+
+         # If relevant documents are available, add sources without the "Sources" label
+         if hasattr(response, "get_documents"):
+             relevant_docs = response.get_documents()
+             if relevant_docs:
+                 sources = "\n\n".join([f"{doc.metadata.get('source', 'No source available')}" for doc in relevant_docs])
+             else:
+                 sources = ""
+         else:
+             sources = ""
+
+         # Combine answer with sources (if any) without additional labels
+         final_result = f"{answer}\n\n{sources}".strip()
+
+         # Return updated history with the final result
+         return history + [(query, final_result)]
      except Exception as e:
+         return history + [(query, f"Error processing query: {str(e)}")]
+
+ # Function to clear the session and reset variables
+ def clear_all():
+     global index, query_engine
+     index = None
+     query_engine = None
+     return None, "", [], ""  # Reset file input, load output, chatbot, and message input to default states
+
+ # Create the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# RAG Multi-file Chat Application with Speech-to-Text")

+     with gr.Row():
+         file_input = gr.File(label="Select files to load", file_count="multiple")
+         load_btn = gr.Button("Load Documents")
+     load_output = gr.Textbox(label="Load Status")
+
+     msg = gr.Textbox(label="Enter your question")
+     audio_input = gr.Audio(type="filepath", label="Upload Audio")
+     translate_checkbox = gr.Checkbox(label="Translate Audio to English Text", value=False)
+     chatbot = gr.Chatbot()
+     clear = gr.Button("Clear")
+
+     # Set up event handlers
+     load_btn.click(load_documents, inputs=[file_input], outputs=[load_output])
+
+     # Event handler for audio input to directly trigger processing and chat response
+     audio_input.change(perform_rag, inputs=[msg, chatbot, audio_input, translate_checkbox], outputs=[chatbot])
+
+     clear.click(clear_all, outputs=[file_input, load_output, chatbot, msg], queue=False)
+
+ # Run the app
+ if __name__ == "__main__":
+     demo.queue()
+     demo.launch()
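
A note on source attribution in the new perform_rag: the code guards its source listing behind hasattr(response, "get_documents"), but llama_index query responses typically expose retrieved chunks through a source_nodes attribute rather than a get_documents() method, so that branch will usually be skipped and only the plain answer shown. The sketch below is illustrative only and not part of this commit; it assumes the standard source_nodes attribute, and the helper name extract_sources is hypothetical.

# Minimal sketch, assuming llama_index's standard Response.source_nodes attribute.
# extract_sources() is a hypothetical helper, not defined in app.py.
def extract_sources(response) -> str:
    names = []
    for scored_node in getattr(response, "source_nodes", []) or []:
        # Each entry wraps a node whose metadata carries the "source" filename set in load_documents()
        name = scored_node.node.metadata.get("source", "No source available")
        if name not in names:
            names.append(name)
    return "\n\n".join(names)

Inside perform_rag, sources = extract_sources(response) could then replace the hasattr branch while keeping the answer-plus-sources formatting unchanged.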