veerukhannan committed on
Commit
121ef90
·
verified ·
1 Parent(s): 8d78e47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -102
app.py CHANGED
@@ -2,68 +2,90 @@ import gradio as gr
2
  from typing import List, Dict, Tuple
3
  from langchain_core.prompts import ChatPromptTemplate
4
  from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
5
- from transformers import pipeline
 
6
  import os
7
  from astrapy.db import AstraDB
8
  from dotenv import load_dotenv
9
  from huggingface_hub import login
10
  import time
11
- import threading
12
- from queue import Queue
13
- import asyncio
 
 
 
 
 
 
14
 
15
  # Load environment variables
16
  load_dotenv()
17
  login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
18
 
19
- class SearchCancelled(Exception):
20
- pass
 
 
 
 
 
 
 
21
 
22
  class LegalTextSearchBot:
23
  def __init__(self):
24
- self.astra_db = AstraDB(
25
- token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
26
- api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT")
27
- )
28
- self.collection = self.astra_db.collection("legal_content")
29
-
30
- pipe = pipeline(
31
- "text-generation",
32
- model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
33
- max_new_tokens=512,
34
- temperature=0.7,
35
- top_p=0.95,
36
- repetition_penalty=1.15
37
- )
38
- self.llm = HuggingFacePipeline(pipeline=pipe)
39
-
40
- self.template = """
41
- IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.
42
-
43
- STRICT RULES:
44
- 1. Base your response ONLY on the provided legal sections
45
- 2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the legal database."
46
- 3. Do not make assumptions or use external knowledge
47
- 4. Always cite the specific section numbers you're referring to
48
- 5. Be precise and accurate in your legal interpretations
49
- 6. If quoting from the sections, use quotes and cite the section number
50
-
51
- Context (Legal Sections): {context}
52
-
53
- Chat History: {chat_history}
54
-
55
- Question: {question}
56
-
57
- Answer:"""
58
-
59
- self.prompt = ChatPromptTemplate.from_template(self.template)
60
- self.chat_history = ""
61
- self.cancel_search = False
62
-
63
- def _search_astra(self, query: str) -> List[Dict]:
64
- if self.cancel_search:
65
- raise SearchCancelled("Search was cancelled by user")
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  try:
68
  results = list(self.collection.vector_find(
69
  query,
@@ -71,8 +93,19 @@ class LegalTextSearchBot:
71
  fields=["section_number", "title", "chapter_number", "chapter_title",
72
  "content", "type", "metadata"]
73
  ))
 
 
 
 
 
 
 
 
74
 
75
- if not results and not self.cancel_search:
 
 
 
76
  results = list(self.collection.find(
77
  {},
78
  limit=5
@@ -81,14 +114,10 @@ class LegalTextSearchBot:
81
  return results
82
 
83
  except Exception as e:
84
- if not isinstance(e, SearchCancelled):
85
- print(f"Error searching AstraDB: {str(e)}")
86
- raise
87
 
88
  def format_section(self, section: Dict) -> str:
89
- if self.cancel_search:
90
- raise SearchCancelled("Search was cancelled by user")
91
-
92
  try:
93
  return f"""
94
  {'='*80}
@@ -103,33 +132,44 @@ References: {', '.join(section.get('metadata', {}).get('references', [])) or 'No
103
  {'='*80}
104
  """
105
  except Exception as e:
106
- print(f"Error formatting section: {str(e)}")
107
  return str(section)
108
 
109
  def search_sections(self, query: str, progress=gr.Progress()) -> Tuple[str, str]:
110
- self.cancel_search = False
 
111
 
112
  try:
113
- progress(0, desc="Searching relevant sections...")
 
 
 
 
114
  search_results = self._search_astra(query)
115
 
116
  if not search_results:
117
  return "No relevant sections found.", "I apologize, but I cannot find relevant sections in the database."
118
 
 
 
 
119
  progress(0.3, desc="Processing results...")
120
  raw_results = []
121
  context_parts = []
122
 
123
  for idx, result in enumerate(search_results):
124
- if self.cancel_search:
125
- raise SearchCancelled("Search was cancelled by user")
126
 
127
  raw_results.append(self.format_section(result))
128
  context_parts.append(f"""
129
  Section {result.get('section_number')}: {result.get('title')}
130
  {result.get('content', '')}
131
  """)
132
- progress((0.3 + (idx * 0.1)), desc="Processing results...")
 
 
 
133
 
134
  progress(0.8, desc="Generating AI interpretation...")
135
  context = "\n\n".join(context_parts)
@@ -143,21 +183,27 @@ Section {result.get('section_number')}: {result.get('title')}
143
 
144
  self.chat_history += f"\nUser: {query}\nAI: {ai_response}\n"
145
 
146
- progress(1.0, desc="Complete!")
 
 
 
147
  return "\n".join(raw_results), ai_response
148
 
149
- except SearchCancelled:
150
- return "Search cancelled by user.", "Search was stopped. Please try again with a new query."
151
  except Exception as e:
152
- error_msg = f"Error processing query: {str(e)}"
153
- print(error_msg)
154
- return error_msg, "An error occurred while processing your query."
 
155
 
156
- def cancel(self):
157
- self.cancel_search = True
 
 
158
 
159
  def create_interface():
160
  with gr.Blocks(title="Bharatiya Nyaya Sanhita Search", theme=gr.themes.Soft()) as iface:
 
 
161
  gr.Markdown("""
162
  # 📚 Bharatiya Nyaya Sanhita Legal Search System
163
 
@@ -165,11 +211,9 @@ def create_interface():
165
  1. 📜 Relevant sections, explanations, and illustrations
166
  2. 🤖 AI-powered interpretation of the legal content
167
 
168
- Enter your legal query below:
169
  """)
170
 
171
- search_bot = LegalTextSearchBot()
172
-
173
  with gr.Row():
174
  query_input = gr.Textbox(
175
  label="Your Query",
@@ -178,20 +222,12 @@ def create_interface():
178
  )
179
 
180
  with gr.Row():
181
- with gr.Column(scale=4):
182
- search_button = gr.Button("🔍 Search Legal Sections", variant="primary")
183
- with gr.Column(scale=1):
184
- stop_button = gr.Button("🛑 Stop Search", variant="stop")
185
 
186
  with gr.Row():
187
- with gr.Column():
188
- raw_output = gr.Markdown(
189
- label="📜 Relevant Legal Sections"
190
- )
191
- with gr.Column():
192
- ai_output = gr.Markdown(
193
- label="🤖 AI Interpretation"
194
- )
195
 
196
  gr.Examples(
197
  examples=[
@@ -205,38 +241,38 @@ def create_interface():
205
  label="Example Queries"
206
  )
207
 
208
- def search(query):
209
- return search_bot.search_sections(query)
210
-
211
- def stop_search():
212
- search_bot.cancel()
213
- return "Search cancelled.", "Search stopped by user."
214
-
215
- search_button.click(
216
- fn=search,
217
  inputs=query_input,
218
  outputs=[raw_output, ai_output],
219
- cancels=[stop_button] # Cancel any ongoing search when stop is clicked
220
  )
221
 
 
222
  stop_button.click(
223
- fn=stop_search,
224
  outputs=[raw_output, ai_output],
225
- cancels=[search_button] # Cancel the search button when stop is clicked
226
  )
227
 
 
228
  query_input.submit(
229
- fn=search,
230
  inputs=query_input,
231
  outputs=[raw_output, ai_output],
232
- cancels=[stop_button]
233
  )
234
 
235
  return iface
236
 
237
  if __name__ == "__main__":
238
- demo = create_interface()
239
- demo.launch()
 
 
 
240
  else:
241
- demo = create_interface()
242
- app = demo.launch(share=False)
 
 
 
 
2
  from typing import List, Dict, Tuple
3
  from langchain_core.prompts import ChatPromptTemplate
4
  from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
5
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
6
+ import torch
7
  import os
8
  from astrapy.db import AstraDB
9
  from dotenv import load_dotenv
10
  from huggingface_hub import login
11
  import time
12
+ import logging
13
+ from functools import lru_cache
14
+
15
+ # Configure logging
16
+ logging.basicConfig(
17
+ level=logging.INFO,
18
+ format='%(asctime)s - %(levelname)s - %(message)s'
19
+ )
20
+ logger = logging.getLogger(__name__)
21
 
22
  # Load environment variables
23
  load_dotenv()
24
  login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
25
 
26
+ # Initialize model with optimized settings
27
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
28
+ model = AutoModelForCausalLM.from_pretrained(
29
+ model_name,
30
+ torch_dtype=torch.float16,
31
+ device_map="auto",
32
+ load_in_8bit=True
33
+ )
34
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
35
 
36
  class LegalTextSearchBot:
37
  def __init__(self):
38
+ try:
39
+ self.astra_db = AstraDB(
40
+ token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
41
+ api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT")
42
+ )
43
+ self.collection = self.astra_db.collection("legal_content")
44
+
45
+ # Initialize pipeline with optimized settings
46
+ pipe = pipeline(
47
+ "text-generation",
48
+ model=model,
49
+ tokenizer=tokenizer,
50
+ max_new_tokens=512,
51
+ temperature=0.7,
52
+ top_p=0.95,
53
+ repetition_penalty=1.15,
54
+ torch_dtype=torch.float16,
55
+ device_map="auto"
56
+ )
57
+ self.llm = HuggingFacePipeline(pipeline=pipe)
58
+
59
+ self.template = """
60
+ IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.
61
+
62
+ STRICT RULES:
63
+ 1. Base your response ONLY on the provided legal sections
64
+ 2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the legal database."
65
+ 3. Do not make assumptions or use external knowledge
66
+ 4. Always cite the specific section numbers you're referring to
67
+ 5. Be precise and accurate in your legal interpretations
68
+ 6. If quoting from the sections, use quotes and cite the section number
69
+
70
+ Context (Legal Sections): {context}
71
+
72
+ Chat History: {chat_history}
73
+
74
+ Question: {question}
 
 
 
 
 
75
 
76
+ Answer:"""
77
+
78
+ self.prompt = ChatPromptTemplate.from_template(self.template)
79
+ self.chat_history = ""
80
+ self.is_searching = False
81
+
82
+ except Exception as e:
83
+ logger.error(f"Error initializing LegalTextSearchBot: {str(e)}")
84
+ raise
85
+
86
+ @lru_cache(maxsize=100)
87
+ def _cached_search(self, query: str) -> tuple:
88
+ """Cached version of vector search to improve performance"""
89
  try:
90
  results = list(self.collection.vector_find(
91
  query,
 
93
  fields=["section_number", "title", "chapter_number", "chapter_title",
94
  "content", "type", "metadata"]
95
  ))
96
+ return tuple(results) # Convert to tuple for caching
97
+ except Exception as e:
98
+ logger.error(f"Error in vector search: {str(e)}")
99
+ return tuple()
100
+
101
+ def _search_astra(self, query: str) -> List[Dict]:
102
+ if not self.is_searching:
103
+ return []
104
 
105
+ try:
106
+ results = list(self._cached_search(query))
107
+
108
+ if not results and self.is_searching:
109
  results = list(self.collection.find(
110
  {},
111
  limit=5
 
114
  return results
115
 
116
  except Exception as e:
117
+ logger.error(f"Error searching AstraDB: {str(e)}")
118
+ return []
 
119
 
120
  def format_section(self, section: Dict) -> str:
 
 
 
121
  try:
122
  return f"""
123
  {'='*80}
 
132
  {'='*80}
133
  """
134
  except Exception as e:
135
+ logger.error(f"Error formatting section: {str(e)}")
136
  return str(section)
137
 
138
  def search_sections(self, query: str, progress=gr.Progress()) -> Tuple[str, str]:
139
+ self.is_searching = True
140
+ start_time = time.time()
141
 
142
  try:
143
+ progress(0, desc="Initializing search...")
144
+ if not query.strip():
145
+ return "Please enter a search query.", "Please provide a specific legal question or topic to search for."
146
+
147
+ progress(0.1, desc="Searching relevant sections...")
148
  search_results = self._search_astra(query)
149
 
150
  if not search_results:
151
  return "No relevant sections found.", "I apologize, but I cannot find relevant sections in the database."
152
 
153
+ if not self.is_searching:
154
+ return "Search cancelled.", "Search was stopped by user."
155
+
156
  progress(0.3, desc="Processing results...")
157
  raw_results = []
158
  context_parts = []
159
 
160
  for idx, result in enumerate(search_results):
161
+ if not self.is_searching:
162
+ return "Search cancelled.", "Search was stopped by user."
163
 
164
  raw_results.append(self.format_section(result))
165
  context_parts.append(f"""
166
  Section {result.get('section_number')}: {result.get('title')}
167
  {result.get('content', '')}
168
  """)
169
+ progress((0.3 + (idx * 0.1)), desc=f"Processing result {idx + 1} of {len(search_results)}...")
170
+
171
+ if not self.is_searching:
172
+ return "Search cancelled.", "Search was stopped by user."
173
 
174
  progress(0.8, desc="Generating AI interpretation...")
175
  context = "\n\n".join(context_parts)
 
183
 
184
  self.chat_history += f"\nUser: {query}\nAI: {ai_response}\n"
185
 
186
+ elapsed_time = time.time() - start_time
187
+ logger.info(f"Search completed in {elapsed_time:.2f} seconds")
188
+
189
+ progress(1.0, desc="Search complete!")
190
  return "\n".join(raw_results), ai_response
191
 
 
 
192
  except Exception as e:
193
+ logger.error(f"Error processing query: {str(e)}")
194
+ return f"Error processing query: {str(e)}", "An error occurred while processing your query."
195
+ finally:
196
+ self.is_searching = False
197
 
198
+ def stop_search(self):
199
+ """Stop the current search operation"""
200
+ self.is_searching = False
201
+ return "Search cancelled.", "Search was stopped by user."
202
 
203
  def create_interface():
204
  with gr.Blocks(title="Bharatiya Nyaya Sanhita Search", theme=gr.themes.Soft()) as iface:
205
+ search_bot = LegalTextSearchBot()
206
+
207
  gr.Markdown("""
208
  # 📚 Bharatiya Nyaya Sanhita Legal Search System
209
 
 
211
  1. 📜 Relevant sections, explanations, and illustrations
212
  2. 🤖 AI-powered interpretation of the legal content
213
 
214
+ *Use the Stop button if you want to cancel a long-running search.*
215
  """)
216
 
 
 
217
  with gr.Row():
218
  query_input = gr.Textbox(
219
  label="Your Query",
 
222
  )
223
 
224
  with gr.Row():
225
+ search_button = gr.Button("🔍 Search", variant="primary", scale=4)
226
+ stop_button = gr.Button("🛑 Stop", variant="stop", scale=1)
 
 
227
 
228
  with gr.Row():
229
+ raw_output = gr.Markdown(label="📜 Relevant Legal Sections")
230
+ ai_output = gr.Markdown(label="🤖 AI Interpretation")
 
 
 
 
 
 
231
 
232
  gr.Examples(
233
  examples=[
 
241
  label="Example Queries"
242
  )
243
 
244
+ # Handle search
245
+ search_event = search_button.click(
246
+ fn=search_bot.search_sections,
 
 
 
 
 
 
247
  inputs=query_input,
248
  outputs=[raw_output, ai_output],
 
249
  )
250
 
251
+ # Handle stop
252
  stop_button.click(
253
+ fn=search_bot.stop_search,
254
  outputs=[raw_output, ai_output],
255
+ cancels=[search_event]
256
  )
257
 
258
+ # Handle Enter key
259
  query_input.submit(
260
+ fn=search_bot.search_sections,
261
  inputs=query_input,
262
  outputs=[raw_output, ai_output],
 
263
  )
264
 
265
  return iface
266
 
267
  if __name__ == "__main__":
268
+ try:
269
+ demo = create_interface()
270
+ demo.launch()
271
+ except Exception as e:
272
+ logger.error(f"Error launching application: {str(e)}")
273
  else:
274
+ try:
275
+ demo = create_interface()
276
+ app = demo.launch(share=False)
277
+ except Exception as e:
278
+ logger.error(f"Error launching application: {str(e)}")