veerukhannan committed on
Commit
2e7f1cb
·
verified ·
1 Parent(s): 5a36c01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -19
app.py CHANGED
@@ -8,13 +8,26 @@ import re
8
  from sentence_transformers import SentenceTransformer
9
  from loguru import logger
10
 
 
 
 
 
 
 
 
 
11
  class LegalAssistant:
12
  def __init__(self):
13
  # Initialize ChromaDB
14
  self.chroma_client = chromadb.Client()
 
 
 
 
 
15
  self.collection = self.chroma_client.get_or_create_collection(
16
  name="legal_documents",
17
- embedding_function=SentenceTransformer('all-MiniLM-L6-v2')
18
  )
19
 
20
  # Load documents if collection is empty
@@ -110,16 +123,16 @@ ERROR HANDLING:
110
  })
111
 
112
  # Add to ChromaDB
113
- for i, section in enumerate(sections):
114
- self.collection.add(
115
- documents=[section["content"]],
116
- metadatas=[{
117
- "title": section["title"],
118
- "source": "a2023-45.txt",
119
- "section_number": i + 1
120
- }],
121
- ids=[f"section_{i+1}"]
122
- )
123
 
124
  logger.info(f"Loaded {len(sections)} sections into ChromaDB")
125
 
@@ -130,9 +143,9 @@ ERROR HANDLING:
130
  def validate_query(self, query: str) -> tuple[bool, str]:
131
  """Validate the input query"""
132
  if not query or len(query.strip()) < 10:
133
- return False, "Query too short. Please provide more details."
134
  if len(query) > 500:
135
- return False, "Query too long. Please be more concise."
136
  if not re.search(r'[?.]$', query):
137
  return False, "Query must end with a question mark or period."
138
  return True, ""
@@ -142,8 +155,7 @@ ERROR HANDLING:
142
  try:
143
  results = self.collection.query(
144
  query_texts=[query],
145
- n_results=3,
146
- include=["metadatas", "documents"]
147
  )
148
 
149
  if results and results['documents']:
@@ -226,7 +238,7 @@ Remember: ONLY use information from the above context. If the information is not
226
 
227
  # Validate that references only contain sections from sources
228
  valid_references = [ref for ref in result.get("reference_sections", [])
229
- if any(source in ref for source in sources)]
230
 
231
  # If references mention unauthorized sources, return error
232
  if len(valid_references) != len(result.get("reference_sections", [])):
@@ -289,9 +301,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
289
  # Indian Legal Assistant
290
  ## Guidelines for Queries:
291
  1. Be specific and clear in your questions
292
- 2. End questions with a question mark
293
- 3. Provide relevant context if available
294
- 4. Keep queries between 10-500 characters
295
  """)
296
 
297
  with gr.Row():
@@ -320,6 +332,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
320
  - Responses are based ONLY on the provided document
321
  - No external legal knowledge is used
322
  - All references are from the document itself
 
323
  """)
324
 
325
  submit_btn.click(
 
8
  from sentence_transformers import SentenceTransformer
9
  from loguru import logger
10
 
11
+ class SentenceTransformerEmbeddings:
12
+ def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
13
+ self.model = SentenceTransformer(model_name)
14
+
15
+ def __call__(self, input: List[str]) -> List[List[float]]:
16
+ embeddings = self.model.encode(input)
17
+ return embeddings.tolist()
18
+
19
  class LegalAssistant:
20
  def __init__(self):
21
  # Initialize ChromaDB
22
  self.chroma_client = chromadb.Client()
23
+
24
+ # Initialize embedding function
25
+ self.embedding_function = SentenceTransformerEmbeddings()
26
+
27
+ # Create or get collection with proper embedding function
28
  self.collection = self.chroma_client.get_or_create_collection(
29
  name="legal_documents",
30
+ embedding_function=self.embedding_function
31
  )
32
 
33
  # Load documents if collection is empty
 
123
  })
124
 
125
  # Add to ChromaDB
126
+ documents = [section["content"] for section in sections]
127
+ metadatas = [{"title": section["title"], "source": "a2023-45.txt", "section_number": i + 1}
128
+ for i, section in enumerate(sections)]
129
+ ids = [f"section_{i+1}" for i in range(len(sections))]
130
+
131
+ self.collection.add(
132
+ documents=documents,
133
+ metadatas=metadatas,
134
+ ids=ids
135
+ )
136
 
137
  logger.info(f"Loaded {len(sections)} sections into ChromaDB")
138
 
 
143
  def validate_query(self, query: str) -> tuple[bool, str]:
144
  """Validate the input query"""
145
  if not query or len(query.strip()) < 10:
146
+ return False, "Query too short. Please provide more details (minimum 10 characters)."
147
  if len(query) > 500:
148
+ return False, "Query too long. Please be more concise (maximum 500 characters)."
149
  if not re.search(r'[?.]$', query):
150
  return False, "Query must end with a question mark or period."
151
  return True, ""
 
155
  try:
156
  results = self.collection.query(
157
  query_texts=[query],
158
+ n_results=3
 
159
  )
160
 
161
  if results and results['documents']:
 
238
 
239
  # Validate that references only contain sections from sources
240
  valid_references = [ref for ref in result.get("reference_sections", [])
241
+ if any(source.split(" (Section")[0] in ref for source in sources)]
242
 
243
  # If references mention unauthorized sources, return error
244
  if len(valid_references) != len(result.get("reference_sections", [])):
 
301
  # Indian Legal Assistant
302
  ## Guidelines for Queries:
303
  1. Be specific and clear in your questions
304
+ 2. End questions with a question mark or period
305
+ 3. Keep queries between 10-500 characters
306
+ 4. Questions will be answered based ONLY on the provided legal document
307
  """)
308
 
309
  with gr.Row():
 
332
  - Responses are based ONLY on the provided document
333
  - No external legal knowledge is used
334
  - All references are from the document itself
335
+ - Confidence levels indicate how well the answer matches the document content
336
  """)
337
 
338
  submit_btn.click(