veerukhannan committed
Commit 6404fd8 · verified · 1 Parent(s): cb7bbf3

Update app.py

Files changed (1)
  1. app.py +157 -283
app.py CHANGED
@@ -1,329 +1,203 @@
  import gradio as gr
- from typing import List, Dict, Tuple
  from langchain_core.prompts import ChatPromptTemplate
  from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
  from sentence_transformers import SentenceTransformer
- import torch
  import os
- from astrapy.db import AstraDB
- from dotenv import load_dotenv
- from huggingface_hub import login
- import time
- import logging
- import numpy as np
- from functools import lru_cache
-
- # Configure logging
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(levelname)s - %(message)s'
- )
- logger = logging.getLogger(__name__)
-
- # Load environment variables
- load_dotenv()
- login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
-
- class LegalTextSearchBot:
      def __init__(self):
-         try:
-             # Initialize AstraDB connection
-             self.astra_db = AstraDB(
-                 token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
-                 api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT")
-             )
-             self.collection = self.astra_db.collection(os.getenv("ASTRA_DB_COLLECTION"))
-
-             # Initialize language model
-             model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-             model = AutoModelForCausalLM.from_pretrained(
-                 model_name,
-                 device_map="auto",
-                 torch_dtype=torch.float32,
-             )
-             tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-             # Initialize text generation pipeline
-             pipe = pipeline(
-                 "text-generation",
-                 model=model,
-                 tokenizer=tokenizer,
-                 max_new_tokens=512,
-                 temperature=0.7,
-                 top_p=0.95,
-                 repetition_penalty=1.15,
-                 device_map="auto"
-             )
-             self.llm = HuggingFacePipeline(pipeline=pipe)
-
-             # Initialize sentence transformer for embeddings
-             self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-
-             self.template = """
-             IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.

              STRICT RULES:
-             1. Base your response ONLY on the provided legal sections
-             2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the legal database."
              3. Do not make assumptions or use external knowledge
-             4. Always cite the specific section numbers you're referring to
-             5. Be precise and accurate in your legal interpretations
-             6. If quoting from the sections, use quotes and cite the section number
-
-             Context (Legal Sections): {context}

              Chat History: {chat_history}
-
              Question: {question}

-             Answer:"""

-             self.prompt = ChatPromptTemplate.from_template(self.template)
-             self.chat_history = ""
-             self.is_searching = False

-             logger.info("Successfully initialized LegalTextSearchBot")

-         except Exception as e:
-             logger.error(f"Error initializing LegalTextSearchBot: {str(e)}")
-             raise
-
-     def get_embedding(self, text: str) -> List[float]:
-         """Generate embedding vector for text"""
-         try:
-             # Clean and prepare text
-             text = text.replace('\n', ' ').strip()
-             if not text:
-                 text = " "  # Ensure non-empty input

-             # Generate embedding
-             embedding = self.embedding_model.encode(text)

-             # Pad or truncate to 1024 dimensions
-             if len(embedding) < 1024:
-                 embedding = np.pad(embedding, (0, 1024 - len(embedding)))
-             elif len(embedding) > 1024:
-                 embedding = embedding[:1024]

-             return embedding.tolist()

-         except Exception as e:
-             logger.error(f"Error generating embedding: {str(e)}")
-             raise
-
-     @lru_cache(maxsize=100)
-     def _cached_search(self, query: str) -> tuple:
-         """Cached version of vector search"""
-         try:
-             # Generate embedding for query
-             query_embedding = self.get_embedding(query)

-             results = list(self.collection.vector_find(
-                 query_embedding,
-                 top_k=5,  # Using top_k instead of limit
-                 fields=["section_number", "title", "chapter_number", "chapter_title",
-                         "content", "type", "metadata"]
-             ))
-             return tuple(results)
-         except Exception as e:
-             logger.error(f"Error in vector search: {str(e)}")
-             return tuple()

-     def _search_astra(self, query: str) -> List[Dict]:
-         if not self.is_searching:
-             return []
-
          try:
-             results = list(self._cached_search(query))
-
-             if not results and self.is_searching:
-                 # Fallback to regular search
-                 cursor = self.collection.find({})
-                 results = []
-                 for doc in cursor:
-                     if len(results) >= 5:
-                         break
-                     results.append(doc)

-             return results

          except Exception as e:
-             logger.error(f"Error searching AstraDB: {str(e)}")
-             return []

-     def format_section(self, section: Dict) -> str:
          try:
-             return f"""
-             {'='*80}
-             Chapter {section.get('chapter_number', 'N/A')}: {section.get('chapter_title', 'N/A')}
-             Section {section.get('section_number', 'N/A')}: {section.get('title', 'N/A')}
-             Type: {section.get('type', 'section')}
-
-             Content:
-             {section.get('content', 'N/A')}
-
-             References: {', '.join(section.get('metadata', {}).get('references', [])) or 'None'}
-             {'='*80}
-             """
          except Exception as e:
-             logger.error(f"Error formatting section: {str(e)}")
-             return str(section)

-     def generate_ai_response(self, context: str, query: str) -> str:
-         """Generate AI interpretation with error handling"""
          try:
-             chain = self.prompt | self.llm
-             response = chain.invoke({
-                 "context": context,
-                 "chat_history": self.chat_history,
-                 "question": query
-             })

-             # Handle different response types
-             if isinstance(response, dict):
-                 return response.get('text', str(response))
-             elif isinstance(response, list):
-                 return response[0] if response else "No response generated"
-             else:
-                 return str(response)

-         except Exception as e:
-             logger.error(f"Error generating AI response: {str(e)}")
-             return "I apologize, but I encountered an error while interpreting the legal sections. Please try rephrasing your question."
-
-     def search_sections(self, query: str, progress=gr.Progress()) -> Tuple[str, str]:
-         self.is_searching = True
-         start_time = time.time()
-
-         try:
-             progress(0, desc="Initializing search...")
-             if not query.strip():
-                 return "Please enter a search query.", "Please provide a specific legal question or topic to search for."

-             progress(0.1, desc="Searching relevant sections...")
-             search_results = self._search_astra(query)

              if not search_results:
-                 return "No relevant sections found.", "I apologize, but I cannot find relevant sections in the database."
-
-             if not self.is_searching:
-                 return "Search cancelled.", "Search was stopped by user."
-
-             progress(0.3, desc="Processing results...")
-             raw_results = []
-             context_parts = []
-
-             for idx, result in enumerate(search_results):
-                 if not self.is_searching:
-                     return "Search cancelled.", "Search was stopped by user."
-
-                 raw_results.append(self.format_section(result))
-                 context_parts.append(f"""
-                 Section {result.get('section_number', 'N/A')}: {result.get('title', 'N/A')}
-                 {result.get('content', 'N/A')}
-                 """)
-                 progress((0.3 + (idx * 0.1)), desc=f"Processing result {idx + 1} of {len(search_results)}...")

-             if not self.is_searching:
-                 return "Search cancelled.", "Search was stopped by user."

-             progress(0.8, desc="Generating AI interpretation...")
-             context = "\n\n".join(context_parts)

-             ai_response = self.generate_ai_response(context, query)
-             self.chat_history += f"\nUser: {query}\nAI: {ai_response}\n"
-
-             elapsed_time = time.time() - start_time
-             logger.info(f"Search completed in {elapsed_time:.2f} seconds")

-             progress(1.0, desc="Search complete!")
-             return "\n".join(raw_results), ai_response

          except Exception as e:
-             logger.error(f"Error processing query: {str(e)}")
-             return f"Error processing query: {str(e)}", "An error occurred while processing your query."
-         finally:
-             self.is_searching = False

-     def stop_search(self):
-         """Stop the current search operation"""
-         self.is_searching = False
-         return "Search cancelled.", "Search was stopped by user."

- def create_interface():
-     with gr.Blocks(title="Bharatiya Nyaya Sanhita Search", theme=gr.themes.Soft()) as iface:
-         search_bot = LegalTextSearchBot()
-
-         gr.Markdown("""
-         # 📚 Bharatiya Nyaya Sanhita Legal Search System
-
-         Search through the Bharatiya Nyaya Sanhita, 2023 and get:
-         1. 📜 Relevant sections, explanations, and illustrations
-         2. 🤖 AI-powered interpretation of the legal content
-
-         *Use the Stop button if you want to cancel a long-running search.*
-         """)
-
-         with gr.Row():
-             query_input = gr.Textbox(
-                 label="Your Query",
-                 placeholder="e.g., What are the penalties for public servants who conceal information?",
-                 lines=2
-             )
-
-         with gr.Row():
-             search_button = gr.Button("🔍 Search", variant="primary", scale=4)
-             stop_button = gr.Button("🛑 Stop", variant="stop", scale=1)
-
-         with gr.Row():
-             raw_output = gr.Markdown(label="📜 Relevant Legal Sections")
-             ai_output = gr.Markdown(label="🤖 AI Interpretation")
-
-         gr.Examples(
-             examples=[
-                 "What are the penalties for public servants who conceal information?",
-                 "What constitutes criminal conspiracy?",
-                 "Explain the provisions related to culpable homicide",
-                 "What are the penalties for causing death by negligence?",
-                 "What are the punishments for corruption?"
-             ],
-             inputs=query_input,
-             label="Example Queries"
-         )
-
-         # Handle search
-         search_event = search_button.click(
-             fn=search_bot.search_sections,
-             inputs=query_input,
-             outputs=[raw_output, ai_output],
-         )
-
-         # Handle stop
-         stop_button.click(
-             fn=search_bot.stop_search,
-             outputs=[raw_output, ai_output],
-             cancels=[search_event]
-         )
-
-         # Handle Enter key
-         query_input.submit(
-             fn=search_bot.search_sections,
-             inputs=query_input,
-             outputs=[raw_output, ai_output],
-         )
-
-     return iface

  if __name__ == "__main__":
-     try:
-         demo = create_interface()
-         demo.launch()
-     except Exception as e:
-         logger.error(f"Error launching application: {str(e)}")
- else:
-     try:
-         demo = create_interface()
-         app = demo.launch(share=False)
-     except Exception as e:
-         logger.error(f"Error launching application: {str(e)}")
  import gradio as gr
+ from typing import List, Dict
  from langchain_core.prompts import ChatPromptTemplate
  from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+ from transformers import pipeline
+ import chromadb
+ from chromadb.utils import embedding_functions
  from sentence_transformers import SentenceTransformer
  import os

+ class ChromaDBChatbot:
      def __init__(self):
+         # Initialize in-memory ChromaDB
+         self.chroma_client = chromadb.Client()
+
+         # Initialize embedding function
+         self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
+             model_name="all-MiniLM-L6-v2"
+         )
+
+         # Create or get collection
+         self.collection = self.chroma_client.create_collection(
+             name="text_collection",
+             embedding_function=self.embedding_function
+         )
+
+         # Initialize the model - using a smaller model suitable for CPU
+         pipe = pipeline(
+             "text-generation",
+             model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+             max_new_tokens=512,
+             temperature=0.7,
+             top_p=0.95,
+             repetition_penalty=1.15
+         )
+         self.llm = HuggingFacePipeline(pipeline=pipe)
+
+         # Enhanced prompt templates
+         self.templates = {
+             "default": """
+             IMPORTANT: You are a helpful assistant that provides information based on the retrieved context.

              STRICT RULES:
+             1. Base your response ONLY on the provided context
+             2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the database."
              3. Do not make assumptions or use external knowledge
+             4. Be concise and accurate in your responses
+             5. If quoting from the context, clearly indicate it
+
+             Context: {context}
              Chat History: {chat_history}
              Question: {question}

+             Answer:""",

+             "summary": """
+             Create a concise summary of the following context.

+             Context: {context}

+             Key Requirements:
+             1. Highlight the main points
+             2. Keep it brief and clear
+             3. Use bullet points if appropriate
+             4. Include only information from the context

+             Summary:""",

+             "technical": """
+             Provide a technical explanation based on the context.

+             Context: {context}
+             Question: {question}

+             Guidelines:
+             1. Focus on technical details
+             2. Explain complex concepts clearly
+             3. Use appropriate technical terminology
+             4. Provide examples if present in the context

+             Technical Explanation:"""
+         }
+
+         self.chat_history = ""
+         self.loaded = False

+     def load_data(self, file_path: str):
+         """Load data into ChromaDB"""
+         if self.loaded:
+             return
+
          try:
+             # Read the text file
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 content = f.read()
+
+             # Split into chunks (512 tokens each with 50 token overlap)
+             chunk_size = 512
+             overlap = 50
+             chunks = []
+
+             for i in range(0, len(content), chunk_size - overlap):
+                 chunk = content[i:i + chunk_size]
+                 chunks.append(chunk)
+
+             # Add documents to collection
+             self.collection.add(
+                 documents=chunks,
+                 ids=[f"doc_{i}" for i in range(len(chunks))]
+             )

+             self.loaded = True
+             print(f"Loaded {len(chunks)} chunks into ChromaDB")

          except Exception as e:
+             print(f"Error loading data: {str(e)}")
+             return False

+     def _search_chroma(self, query: str) -> List[Dict]:
+         """Search ChromaDB for relevant documents"""
          try:
+             results = self.collection.query(
+                 query_texts=[query],
+                 n_results=5
+             )
+             return [{"content": doc} for doc in results['documents'][0]]
          except Exception as e:
+             print(f"Error searching ChromaDB: {str(e)}")
+             return []

+     def chat(self, query: str, history) -> str:
+         """Process a query and return a response"""
          try:
+             if not self.loaded:
+                 self.load_data('a2023-45.txt')

+             # Determine template type based on query
+             template_type = "default"
+             if any(word in query.lower() for word in ["summarize", "summary"]):
+                 template_type = "summary"
+             elif any(word in query.lower() for word in ["technical", "explain", "how does"]):
+                 template_type = "technical"

+             # Search ChromaDB for relevant content
+             search_results = self._search_chroma(query)

              if not search_results:
+                 return "I apologize, but I cannot find information about that in the database."

+             # Extract and combine relevant content
+             context = "\n\n".join([result['content'] for result in search_results])

+             # Create prompt with selected template
+             prompt = ChatPromptTemplate.from_template(self.templates[template_type])

+             # Generate response using LLM
+             chain = prompt | self.llm
+             result = chain.invoke({
+                 "context": context,
+                 "chat_history": self.chat_history,
+                 "question": query
+             })

+             # Update chat history
+             self.chat_history += f"\nUser: {query}\nAI: {result}\n"

+             return result
          except Exception as e:
+             return f"Error processing query: {str(e)}"

+ # Initialize the chatbot
+ chatbot = ChromaDBChatbot()

+ # Create the Gradio interface
+ demo = gr.Interface(
+     fn=chatbot.chat,
+     inputs=[
+         gr.Textbox(
+             label="Your Question",
+             placeholder="Ask anything about the document...",
+             lines=2
+         ),
+         gr.State([])  # For chat history
+     ],
+     outputs=gr.Textbox(label="Answer", lines=10),
+     title="ChromaDB-powered Document Q&A",
+     description="""
+     Ask questions about your document:
+     - For summaries, include words like 'summarize' or 'summary'
+     - For technical details, use words like 'technical', 'explain', 'how does'
+     - For general questions, just ask normally
+     """,
+     examples=[
+         ["Can you summarize the main points?"],
+         ["What are the technical details about this topic?"],
+         ["Give me a general overview of the content."],
+     ],
+     theme=gr.themes.Soft()
+ )

+ # Launch the interface
  if __name__ == "__main__":
+     demo.launch()
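
For quick local sanity-checking of the retrieval path this commit introduces, the sketch below mirrors the same character-based chunking and `collection.query` flow as the new `ChromaDBChatbot`. It is illustrative only and not part of the commit: the collection name `text_collection_check`, the placeholder text, and the sample query are made up here; the real app reads `a2023-45.txt` instead.

```python
# Illustrative sketch (not part of the commit): exercises the same in-memory
# ChromaDB chunk-and-query flow that the new ChromaDBChatbot uses.
import chromadb
from chromadb.utils import embedding_functions

client = chromadb.Client()  # in-memory client, as in app.py
embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)
collection = client.create_collection(
    name="text_collection_check",  # hypothetical name for this check
    embedding_function=embedding_function,
)

# Placeholder document text; app.py reads 'a2023-45.txt' instead.
content = "Section 1. Short title. ... Section 2. Definitions. ..." * 100

# Character-based chunking with overlap, matching load_data()
chunk_size, overlap = 512, 50
chunks = [content[i:i + chunk_size] for i in range(0, len(content), chunk_size - overlap)]
collection.add(documents=chunks, ids=[f"doc_{i}" for i in range(len(chunks))])

# Retrieve the top 5 chunks for a sample query, matching _search_chroma()
results = collection.query(query_texts=["What does Section 2 define?"], n_results=5)
for doc in results["documents"][0]:
    print(doc[:80], "...")
```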