lukmanaj commited on
Commit
7613d55
·
verified ·
1 Parent(s): f9c6b6f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -0
app.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import numpy as np
4
+ import faiss
5
+ from mistralai import Mistral
6
+
7
+ api_key = os.getenv("MISTRAL_API_KEY")
8
+ client = Mistral(api_key=api_key)
9
+
10
+ # =============================================================================
11
+ # BASIC CHAT UI (Gradio Version)
12
+ # =============================================================================
13
+
14
+ def run_mistral_basic(message, history):
15
+ """Basic chat function for Gradio ChatInterface"""
16
+ messages = [{"role": "user", "content": message}]
17
+ chat_response = client.chat.complete(
18
+ model="mistral-large-latest",
19
+ messages=messages
20
+ )
21
+ return chat_response.choices[0].message.content
22
+
23
+ # Create basic chat interface
24
+ basic_chat = gr.ChatInterface(
25
+ fn=run_mistral_basic,
26
+ title="Basic Mistral Chat",
27
+ description="Chat with Mistral AI"
28
+ )
29
+
30
+ # =============================================================================
31
+ # RAG UI (Gradio Version)
32
+ # =============================================================================
33
+
34
+ # Global variable to store processed document
35
+ processed_chunks = None
36
+ faiss_index = None
37
+
38
+ def get_text_embedding(input_text):
39
+ """Get embeddings from Mistral"""
40
+ embeddings_batch_response = client.embeddings.create(
41
+ model="mistral-embed",
42
+ inputs=[input_text]
43
+ )
44
+ return embeddings_batch_response.data[0].embedding
45
+
46
+ def process_document(file):
47
+ """Process uploaded document and create FAISS index"""
48
+ global processed_chunks, faiss_index
49
+
50
+ if file is None:
51
+ return "Please upload a text file first."
52
+
53
+ try:
54
+ # Read the file
55
+ with open(file.name, 'r', encoding='utf-8') as f:
56
+ text = f.read()
57
+
58
+ # Split document into chunks
59
+ chunk_size = 2048
60
+ processed_chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
61
+
62
+ # Create embeddings and FAISS index
63
+ text_embeddings = np.array([get_text_embedding(chunk) for chunk in processed_chunks])
64
+ d = text_embeddings.shape[1]
65
+ faiss_index = faiss.IndexFlatL2(d)
66
+ faiss_index.add(text_embeddings.astype(np.float32))
67
+
68
+ return f"Document processed successfully! Split into {len(processed_chunks)} chunks."
69
+
70
+ except Exception as e:
71
+ return f"Error processing document: {str(e)}"
72
+
73
+ def rag_chat(message, history):
74
+ """RAG chat function for Gradio"""
75
+ global processed_chunks, faiss_index
76
+
77
+ if processed_chunks is None or faiss_index is None:
78
+ return "Please upload and process a document first."
79
+
80
+ try:
81
+ # Create prompt template
82
+ prompt_template = """
83
+ Context information is below.
84
+ ---------------------
85
+ {retrieved_chunk}
86
+ ---------------------
87
+ Given the context information and not prior knowledge, answer the query.
88
+ Query: {question}
89
+ Answer:
90
+ """
91
+
92
+ # Get question embedding
93
+ question_embedding = np.array([get_text_embedding(message)])
94
+
95
+ # Search for similar chunks
96
+ D, I = faiss_index.search(question_embedding.astype(np.float32), k=2)
97
+ retrieved_chunks = [processed_chunks[i] for i in I.tolist()[0]]
98
+
99
+ # Generate response
100
+ prompt = prompt_template.format(
101
+ retrieved_chunk=retrieved_chunks,
102
+ question=message
103
+ )
104
+
105
+ messages = [{"role": "user", "content": prompt}]
106
+ chat_response = client.chat.complete(
107
+ model="mistral-large-latest",
108
+ messages=messages
109
+ )
110
+
111
+ return chat_response.choices[0].message.content
112
+
113
+ except Exception as e:
114
+ return f"Error generating response: {str(e)}"
115
+
116
+ # =============================================================================
117
+ # GRADIO INTERFACES
118
+ # =============================================================================
119
+
120
+ # Create RAG interface with file upload
121
+ with gr.Blocks(title="RAG Chat with Mistral") as rag_interface:
122
+ gr.Markdown("# RAG Chat Interface")
123
+ gr.Markdown("Upload a text file and chat with its content!")
124
+
125
+ with gr.Row():
126
+ file_upload = gr.File(
127
+ label="Upload Text File",
128
+ file_types=[".txt"],
129
+ type="filepath"
130
+ )
131
+ process_btn = gr.Button("Process Document", variant="primary")
132
+
133
+ process_status = gr.Textbox(
134
+ label="Processing Status",
135
+ interactive=False,
136
+ placeholder="Upload a file and click 'Process Document'"
137
+ )
138
+
139
+ # Chat interface
140
+ chatbot = gr.Chatbot(label="RAG Chat")
141
+ msg = gr.Textbox(
142
+ label="Your Message",
143
+ placeholder="Ask questions about the uploaded document...",
144
+ lines=2
145
+ )
146
+
147
+ with gr.Row():
148
+ submit_btn = gr.Button("Send", variant="primary")
149
+ clear_btn = gr.Button("Clear Chat")
150
+
151
+ # Event handlers
152
+ process_btn.click(
153
+ process_document,
154
+ inputs=[file_upload],
155
+ outputs=[process_status]
156
+ )
157
+
158
+ def respond(message, chat_history):
159
+ if not message.strip():
160
+ return "", chat_history
161
+
162
+ # Add user message to history
163
+ chat_history.append([message, None])
164
+
165
+ # Get bot response
166
+ bot_response = rag_chat(message, chat_history)
167
+
168
+ # Add bot response to history
169
+ chat_history[-1][1] = bot_response
170
+
171
+ return "", chat_history
172
+
173
+ submit_btn.click(
174
+ respond,
175
+ inputs=[msg, chatbot],
176
+ outputs=[msg, chatbot]
177
+ )
178
+
179
+ msg.submit(
180
+ respond,
181
+ inputs=[msg, chatbot],
182
+ outputs=[msg, chatbot]
183
+ )
184
+
185
+ clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
186
+
187
+ if __name__ == "__main__":
188
+ rag_interface.launch(share=True)
189
+