gmustafa413 committed
Commit 9d60267 · verified · 1 Parent(s): 6db07fb

Create app.py

Files changed (1)
  1. app.py +142 -0
app.py ADDED
@@ -0,0 +1,142 @@
+ import os
+ import gradio as gr
+ import numpy as np
+ import google.generativeai as genai
+ import faiss
+ from sentence_transformers import SentenceTransformer
+ from datasets import load_dataset
+ import warnings
+
+ # Suppress warnings
+ warnings.filterwarnings("ignore")
+
+ # Configuration
+ MODEL_NAME = "all-MiniLM-L6-v2"
+ GENAI_MODEL = "models/gemini-pro" # Updated model path
+ DATASET_NAME = "midrees2806/7K_Dataset"
+ CHUNK_SIZE = 500
+ TOP_K = 3
+
+ # Initialize Gemini - PUT YOUR API KEY HERE (for testing only)
+ GEMINI_API_KEY = "AIzaSyASrFvE3gFPigihza0JTuALzZmBx0Kc3d0" # ⚠️ Replace with your actual key
+ genai.configure(api_key=GEMINI_API_KEY)
+
+ class GeminiRAGSystem:
+     def __init__(self):
+         self.index = None
+         self.chunks = []
+         self.dataset_loaded = False
+         self.loading_error = None
+
+         # Initialize embedding model
+         try:
+             self.embedding_model = SentenceTransformer(MODEL_NAME)
+         except Exception as e:
+             raise RuntimeError(f"Failed to initialize embedding model: {str(e)}")
+
+         # Load dataset
+         self.load_dataset()
+
+     def load_dataset(self):
+         """Load dataset synchronously"""
+         try:
+             dataset = load_dataset(
+                 DATASET_NAME,
+                 split='train',
+                 download_mode="force_redownload"
+             )
+
+             if 'text' in dataset.features:
+                 self.chunks = dataset['text'][:1000]
+             elif 'context' in dataset.features:
+                 self.chunks = dataset['context'][:1000]
+             else:
+                 raise ValueError("Dataset must have 'text' or 'context' field")
+
+             embeddings = self.embedding_model.encode(
+                 self.chunks,
+                 show_progress_bar=False,
+                 convert_to_numpy=True
+             )
+             self.index = faiss.IndexFlatL2(embeddings.shape[1])
+             self.index.add(embeddings.astype('float32'))
+
+             self.dataset_loaded = True
+         except Exception as e:
+             self.loading_error = str(e)
+             print(f"Dataset loading failed: {str(e)}")
+
+     def get_relevant_context(self, query: str) -> str:
+         """Retrieve most relevant chunks"""
+         if not self.index:
+             return ""
+
+         try:
+             query_embed = self.embedding_model.encode(
+                 [query],
+                 convert_to_numpy=True
+             ).astype('float32')
+
+             _, indices = self.index.search(query_embed, k=TOP_K)
+             return "\n\n".join([self.chunks[i] for i in indices[0] if i < len(self.chunks)])
+         except Exception as e:
+             print(f"Search error: {str(e)}")
+             return ""
+
+     def generate_response(self, query: str) -> str:
+         """Generate response with robust error handling"""
+         if not self.dataset_loaded:
+             if self.loading_error:
+                 return f"⚠️ Dataset loading failed: {self.loading_error}"
+             return "⚠️ System initializing..."
+
+         context = self.get_relevant_context(query)
+         if not context:
+             return "No relevant context found"
+
+         prompt = f"""Answer based on this context:
+ {context}
+
+ Question: {query}
+ Answer concisely:"""
+
+         try:
+             model = genai.GenerativeModel(GENAI_MODEL)
+             response = model.generate_content(prompt)
+             return response.text
+         except Exception as e:
+             return f"⚠️ API Error: {str(e)}"
+
+ # Initialize system
+ try:
+     rag_system = GeminiRAGSystem()
+     init_status = "✅ System ready" if rag_system.dataset_loaded else f"⚠️ Initializing... {rag_system.loading_error or ''}"
+ except Exception as e:
+     init_status = f"❌ Initialization failed: {str(e)}"
+     rag_system = None
+
+ # Create interface
+ with gr.Blocks(title="Chatbot") as app:
+     gr.Markdown("# Chatbot")
+
+     chatbot = gr.Chatbot(height=500)
+     query = gr.Textbox(label="Your question", placeholder="Ask something...")
+     submit_btn = gr.Button("Submit")
+     clear_btn = gr.Button("Clear")
+     status = gr.Textbox(label="Status", value=init_status)
+
+     def respond(message, chat_history):
+         if not rag_system:
+             return chat_history + [(message, "System initialization failed")]
+         response = rag_system.generate_response(message)
+         return chat_history + [(message, response)]
+
+     def clear_chat():
+         return []
+
+     submit_btn.click(respond, [query, chatbot], [chatbot])
+     query.submit(respond, [query, chatbot], [chatbot])
+     clear_btn.click(clear_chat, outputs=chatbot)
+
+ if __name__ == "__main__":
+     app.launch(share=True)
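
Note: the committed app.py hardcodes GEMINI_API_KEY, which the inline comment already flags as for testing only. A minimal sketch of the alternative, not part of the commit, reads the key from an environment variable (assuming it is exposed under the name GEMINI_API_KEY, e.g. as a Hugging Face Space secret):

# Sketch only: load the Gemini key from the environment instead of the source file.
import os
import google.generativeai as genai

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")  # assumed env-var / secret name
if not GEMINI_API_KEY:
    raise RuntimeError("Set the GEMINI_API_KEY environment variable")
genai.configure(api_key=GEMINI_API_KEY)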