File size: 7,258 Bytes
03bf0d5
 
 
37eb186
925795a
03bf0d5
521baf5
 
 
 
03bf0d5
58bc589
784a1e4
03bf0d5
 
0be7c95
03bf0d5
 
 
cf5ee13
58bc589
 
cf5ee13
58bc589
cf5ee13
58bc589
 
521baf5
37eb186
03bf0d5
 
 
 
0be7c95
925795a
cf5ee13
925795a
 
cf5ee13
925795a
cf5ee13
 
 
925795a
cf5ee13
95d666a
03bf0d5
95d666a
cf5ee13
95d666a
cf5ee13
95d666a
 
 
 
 
cf5ee13
95d666a
 
cf5ee13
 
95d666a
 
cf5ee13
95d666a
 
 
cf5ee13
95d666a
 
 
 
 
cf5ee13
 
95d666a
 
cf5ee13
95d666a
 
cf5ee13
95d666a
cf5ee13
 
 
03bf0d5
 
cf5ee13
925795a
cf5ee13
37eb186
 
ddc98da
cf5ee13
ddc98da
 
 
 
cf5ee13
 
 
 
ddc98da
cf5ee13
 
 
ddc98da
 
 
 
03bf0d5
cf5ee13
03bf0d5
cf5ee13
 
 
 
 
03bf0d5
 
37eb186
cf5ee13
925795a
03bf0d5
925795a
03bf0d5
925795a
03bf0d5
925795a
03bf0d5
cf5ee13
 
03bf0d5
 
58bc589
 
 
cf5ee13
 
58bc589
 
 
cf5ee13
 
58bc589
cf5ee13
 
 
 
58bc589
03bf0d5
cf5ee13
 
 
03bf0d5
cf5ee13
 
ddc98da
 
58bc589
cf5ee13
ddc98da
58bc589
cf5ee13
58bc589
03bf0d5
cf5ee13
58bc589
 
03bf0d5
 
cf5ee13
03bf0d5
0be7c95
58bc589
0be7c95
 
58bc589
 
 
cf5ee13
911a038
925795a
cf5ee13
58bc589
cf5ee13
 
 
 
58bc589
 
ddc98da
 
cf5ee13
ddc98da
03bf0d5
58bc589
 
ddc98da
03bf0d5
 
cf5ee13
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os
import warnings

import faiss
import google.generativeai as genai
import gradio as gr
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Suppress warnings
warnings.filterwarnings("ignore")

# Configuration.
# SECURITY: an API key was previously hard-coded here and is therefore exposed
# in version control — rotate that key. Prefer supplying it via the
# GEMINI_API_KEY environment variable; the literal below is only a fallback
# so existing deployments keep working.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "AIzaSyASrFvE3gFPigihza0JTuALzZmBx0Kc3d0")
MODEL_NAME = "all-MiniLM-L6-v2"      # sentence-transformers embedding model
GENAI_MODEL = "gemini-pro"           # Gemini model used for generation
DATASET_NAME = "midrees2806/7K_Dataset"
CHUNK_SIZE = 500
TOP_K = 3                            # number of chunks retrieved per query

# Initialize Gemini with enhanced configuration
genai.configure(
    api_key=GEMINI_API_KEY,
    transport='rest',  # Force REST API
    client_options={
        'api_endpoint': "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
    }
)

class GeminiRAGSystem:
    """RAG pipeline: embed the query, retrieve nearest chunks via FAISS, answer with Gemini.

    Attributes:
        index: FAISS L2 index over chunk embeddings (None until load_dataset succeeds).
        chunks: raw text passages aligned row-for-row with the index.
        dataset_loaded: True once chunks are embedded and indexed.
        loading_error: human-readable failure reason, or None.
    """

    def __init__(self):
        self.index = None
        self.chunks = []
        self.dataset_loaded = False
        self.loading_error = None

        print("Initializing embedding model...")
        try:
            self.embedding_model = SentenceTransformer(MODEL_NAME)
            print("Embedding model initialized successfully")
        except Exception as e:
            error_msg = f"Failed to initialize embedding model: {str(e)}"
            print(error_msg)
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(error_msg) from e

        print("Loading dataset...")
        self.load_dataset()

    def load_dataset(self):
        """Download the dataset, embed up to 1000 chunks, and build the FAISS index.

        Failures are recorded in self.loading_error (not raised) so the UI can
        surface them while the app keeps running.
        """
        try:
            print(f"Downloading dataset: {DATASET_NAME}")
            # NOTE: force_redownload bypasses the HF cache on every start;
            # kept for backward compatibility, but it is costly on restarts.
            dataset = load_dataset(
                DATASET_NAME,
                split='train',
                download_mode="force_redownload"
            )
            print("Dataset downloaded successfully")

            # Accept either a 'text' or a 'context' column; cap at 1000 chunks.
            if 'text' in dataset.features:
                self.chunks = dataset['text'][:1000]
                print(f"Loaded {len(self.chunks)} text chunks")
            elif 'context' in dataset.features:
                self.chunks = dataset['context'][:1000]
                print(f"Loaded {len(self.chunks)} context chunks")
            else:
                raise ValueError("Dataset must have 'text' or 'context' field")

            print("Creating embeddings...")
            embeddings = self.embedding_model.encode(
                self.chunks,
                show_progress_bar=False,
                convert_to_numpy=True
            )
            print(f"Created embeddings with shape {embeddings.shape}")

            # Exact (brute-force) L2 index; FAISS requires float32 input.
            self.index = faiss.IndexFlatL2(embeddings.shape[1])
            self.index.add(embeddings.astype('float32'))
            print("FAISS index created successfully")

            self.dataset_loaded = True
            print("Dataset loading complete")
        except Exception as e:
            error_msg = f"Dataset loading failed: {str(e)}"
            print(error_msg)
            self.loading_error = error_msg

    def get_relevant_context(self, query: str) -> str:
        """Return the TOP_K nearest chunks joined by blank lines, or '' on any failure."""
        if not self.index:
            print("No index available for search")
            return ""

        try:
            print(f"Processing query: {query}")
            query_embed = self.embedding_model.encode(
                [query],
                convert_to_numpy=True
            ).astype('float32')
            print("Query embedded successfully")

            distances, indices = self.index.search(query_embed, k=TOP_K)
            print(f"Search results - distances: {distances}, indices: {indices}")

            # BUGFIX: FAISS pads missing neighbours with -1; the previous
            # `i < len(self.chunks)` check let -1 through and silently
            # returned the *last* chunk. Require 0 <= i as well.
            context = "\n\n".join(
                self.chunks[i] for i in indices[0] if 0 <= i < len(self.chunks)
            )
            print(f"Context length: {len(context)} characters")
            return context
        except Exception as e:
            print(f"Search error: {str(e)}")
            return ""

    def generate_response(self, query: str) -> str:
        """Answer `query` with Gemini grounded on retrieved context.

        Always returns user-facing text: the model answer, or a ⚠️-prefixed
        status/error message when retrieval or the API call fails.
        """
        if not self.dataset_loaded:
            msg = f"⚠️ Dataset loading failed: {self.loading_error}" if self.loading_error else "⚠️ System initializing..."
            print(msg)
            return msg

        print(f"\n{'='*40}\nNew Query: {query}\n{'='*40}")

        context = self.get_relevant_context(query)
        if not context:
            print("No relevant context found")
            return "No relevant context found"

        prompt = f"""Answer based on this context:
        {context}
        
        Question: {query}
        Answer concisely:"""

        print(f"\nPrompt sent to Gemini:\n{prompt}\n")

        try:
            model = genai.GenerativeModel(GENAI_MODEL)
            response = model.generate_content(
                prompt,
                generation_config=genai.types.GenerationConfig(
                    temperature=0.3,
                    max_output_tokens=1000
                )
            )

            print(f"Raw API response: {response}")

            # Guard against empty candidates/parts before indexing into them.
            if response.candidates and response.candidates[0].content.parts:
                answer = response.candidates[0].content.parts[0].text
                print(f"Answer: {answer}")
                return answer
            print("⚠️ Empty response from API")
            return "⚠️ No response from API"
        except Exception as e:
            error_msg = f"⚠️ API Error: {str(e)}"
            print(error_msg)
            return error_msg

# Bring the RAG system up at import time; on any failure fall back to a
# None system so the UI can still launch and report the error.
print("Initializing RAG system...")
try:
    rag_system = GeminiRAGSystem()
except Exception as e:
    rag_system = None
    init_status = f"❌ Initialization failed: {str(e)}"
else:
    if rag_system.dataset_loaded:
        init_status = "✅ System ready"
    else:
        init_status = f"⚠️ Initializing... {rag_system.loading_error or ''}"
print(init_status)

# Build the Gradio UI. Components are declared in visual (top-to-bottom) order.
with gr.Blocks(title="Document Chatbot") as app:
    gr.Markdown("# Document Chatbot with Gemini")

    with gr.Row():
        chatbot = gr.Chatbot(height=500, label="Chat History")

    with gr.Row():
        query = gr.Textbox(label="Your question", placeholder="Ask about the documents...")

    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

    status = gr.Textbox(label="System Status", value=init_status, interactive=False)

    def respond(message, chat_history):
        """Append (question, answer) to the chat history and return it."""
        print(f"\n{'='*40}\nUser Query: {message}\n{'='*40}")
        if rag_system is None:
            answer = "System initialization failed"
            print(answer)
        else:
            answer = rag_system.generate_response(message)
        return chat_history + [(message, answer)]

    def clear_chat():
        """Reset the chat history to empty."""
        print("Chat cleared")
        return []

    # Both the button and pressing Enter in the textbox submit the query.
    submit_btn.click(respond, [query, chatbot], [chatbot])
    query.submit(respond, [query, chatbot], [chatbot])
    clear_btn.click(clear_chat, outputs=chatbot)

if __name__ == "__main__":
    # Script entry point: start the Gradio server.
    # debug=True surfaces tracebacks in the browser and console.
    print("Launching Gradio interface...")
    app.launch(debug=True)