lobsterScraper commited on
Commit
9337c3d
·
1 Parent(s): 7e6e8a4

Add application file

Browse files
Files changed (1) hide show
  1. app.py +1022 -0
app.py ADDED
@@ -0,0 +1,1022 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import os
4
+ import gradio as gr
5
+ from pinecone import Pinecone
6
+ from sentence_transformers import SentenceTransformer
7
+ from typing import List, Dict, Optional
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from langchain.chains.summarize import load_summarize_chain
10
+ from langchain.prompts import PromptTemplate, ChatPromptTemplate
11
+ from langchain.docstore.document import Document
12
+ import time
13
+ import asyncio
14
+ import plotly.graph_objects as go
15
+ from neo4j import GraphDatabase
16
+ import networkx as nx
17
+ from langchain_community.vectorstores import Neo4jVector
18
+ from langchain.chains.summarize import load_summarize_chain
19
+ from langchain.chains import LLMChain
20
+ from langchain_google_genai import GoogleGenerativeAI, GoogleGenerativeAIEmbeddings
21
+
22
+
23
+
24
+
25
+
26
class EnhancedLegalSearchSystem:
    """Graph-augmented legal search: combines a Neo4j knowledge graph,
    a Neo4j-backed vector store, a local SentenceTransformer model and
    Google Gemini for answer generation."""

    def __init__(
        self,
        google_api_key: str,
        neo4j_url: str,
        neo4j_username: str,
        neo4j_password: str,
        embedding_model_name: str = "intfloat/e5-small-v2",
        device: str = "cpu"
    ):
        """Initialize the Enhanced Legal Search System.

        Args:
            google_api_key: API key used for both the Gemini LLM and the
                Google embeddings model.
            neo4j_url: Bolt/Aura URL of the Neo4j instance.
            neo4j_username: Neo4j user name.
            neo4j_password: Neo4j password.
            embedding_model_name: SentenceTransformer model used for local
                document-similarity computations (see generate_document_graph).
            device: device string for the local embedding model ("cpu"/"cuda").
        """
        # Initialize LLM (low temperature for more deterministic legal answers)
        self.llm = GoogleGenerativeAI(
            model="gemini-pro",
            google_api_key=google_api_key,
            temperature=0.1
        )

        # Initialize embeddings used by the Neo4j vector store
        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/embedding-001",
            google_api_key=google_api_key,
            task_type="retrieval_query"
        )

        # Initialize Neo4j connection (closed best-effort in __del__)
        self.neo4j_driver = GraphDatabase.driver(
            neo4j_url,
            auth=(neo4j_username, neo4j_password)
        )

        # Vector store over existing Document nodes; vectors live on the
        # "embedding" property, text on the "text" property.
        self.vector_store = Neo4jVector.from_existing_graph(
            embedding=self.embeddings,
            url=neo4j_url,
            username=neo4j_username,
            password=neo4j_password,
            node_label="Document",
            text_node_properties=["text"],
            embedding_node_property="embedding"
        )

        # Initialize additional embedding model for enhanced (local) search
        self.local_embedding_model = SentenceTransformer(
            model_name_or_path=embedding_model_name,
            device=device
        )

        # Build the prompt templates and the summarize chain
        self.init_prompts()
76
+
77
+ def __del__(self):
78
+ """Cleanup Neo4j connection"""
79
+ if hasattr(self, 'neo4j_driver'):
80
+ self.neo4j_driver.close()
81
+
82
    def init_prompts(self):
        """Initialize enhanced prompts for legal analysis.

        Builds three artifacts used elsewhere in the class:
        - qa_prompt: chat prompt for the final answer (used by
          process_legal_query via LLMChain);
        - map_prompt / combine_prompt: per-chunk and synthesis prompts;
        - chain: a map_reduce summarize chain over the two prompts.
        """
        self.qa_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a legal expert assistant specializing in Indian law.
Analyze the following legal context and provide a detailed, structured answer to the question.
Include specific sections, rules, and precedents where applicable.
Format your response with clear headings and bullet points for better readability.

Context: {context}"""),
            ("human", "Question: {question}")
        ])

        self.map_prompt = PromptTemplate(
            template="""
Analyze the following legal text segment:

TEXT: "{text}"

Instructions:
1. Extract and summarize the key legal points
2. Maintain all legal terminology exactly as written
3. Preserve section numbers and references
4. Keep all specific conditions and requirements
5. Include any mentioned time periods or deadlines

KEY POINTS:
""",
            input_variables=["text"]  # Removed page_number as it's not used in the template
        )

        self.combine_prompt = PromptTemplate(
            template="""
Question: {question}

Using ONLY the information from the following legal document excerpts, provide a comprehensive answer:

{text}

Instructions:
1. Base your response EXCLUSIVELY on the provided document excerpts
2. If the documents don't contain enough information to fully answer the question, explicitly state what's missing
3. Use direct quotes when appropriate
4. Organize the response by relevant sections found in the documents
5. If there are conflicting statements across documents, highlight them

ANALYSIS:
""",
            input_variables=["text", "question"]
        )

        # Initialize summarize chain: map_prompt condenses each chunk,
        # combine_prompt synthesizes the final analysis.
        self.chain = load_summarize_chain(
            llm=self.llm,
            chain_type="map_reduce",
            map_prompt=self.map_prompt,
            combine_prompt=self.combine_prompt,
            verbose=True
        )
140
+
141
+
142
    def get_related_legal_entities(self, query: str) -> List[Dict]:
        """Retrieve related legal entities and their relationships.

        Matches Document nodes whose text contains `query` (case-insensitive),
        collects their non-Document neighbours plus one further hop to
        Entity/Concept/Section/Case nodes, and returns up to 25 dicts with
        keys: source_id, source_text, document_type, relationship_type,
        entity {id, type, text, properties} and related_entities
        [{id, type, relationship, text, properties}].
        Returns an empty list on any driver/query error (fail-soft).
        """
        # Corrected Cypher query to handle aggregation properly
        cypher_query = """
        // First, let's check if nodes exist and get their labels
        MATCH (d:Document)
        WHERE toLower(d.text) CONTAINS toLower($query)
        WITH d
        // Match all relationships from the document, collecting their types
        OPTIONAL MATCH (d)-[r]-(connected)
        WHERE NOT connected:Document // Avoid direct document-to-document relations
        WITH d,
             collect(DISTINCT type(r)) as relationTypes,
             collect(DISTINCT labels(connected)) as connectedLabels

        // Now use these to build our main query
        MATCH (d:Document)-[r1]-(e)
        WHERE toLower(d.text) CONTAINS toLower($query)
        AND NOT e:Document // Exclude direct document connections
        WITH d, r1, e
        // Get secondary connections, but be more specific about what we're looking for
        OPTIONAL MATCH (e)-[r2]-(related)
        WHERE (related:Entity OR related:Concept OR related:Section OR related:Case)
        AND related <> d // Prevent cycles back to original document
        WITH d, {
            source_id: id(d),
            source_text: d.text,
            document_type: COALESCE(d.type, "Unknown"),
            relationship_type: type(r1),
            entity: {
                id: id(e),
                type: CASE WHEN e:Entity THEN "Entity"
                           WHEN e:Concept THEN "Concept"
                           WHEN e:Section THEN "Section"
                           WHEN e:Case THEN "Case"
                           ELSE "Other" END,
                text: COALESCE(e.text, e.name, e.title, "Unnamed"),
                properties: properties(e)
            },
            related_entities: collect(DISTINCT {
                id: id(related),
                type: CASE WHEN related:Entity THEN "Entity"
                           WHEN related:Concept THEN "Concept"
                           WHEN related:Section THEN "Section"
                           WHEN related:Case THEN "Case"
                           ELSE "Other" END,
                relationship: type(r2),
                text: COALESCE(related.text, related.name, related.title, "Unnamed"),
                properties: properties(related)
            })
        } as result
        WHERE result.entity.text IS NOT NULL // Filter out any results with null entity text
        RETURN DISTINCT result
        ORDER BY result.source_id, result.entity.id
        LIMIT 25
        """
        try:
            with self.neo4j_driver.session() as session:
                # Execute the improved query
                result = session.run(cypher_query, {"query": query})
                entities = [record["result"] for record in result]

                # Log the results for debugging
                print(f"Found {len(entities)} related entities")
                if entities:
                    for entity in entities:
                        print(f"Entity: {entity['entity']['text']}")
                        print(f"Source: {entity['source_text'][:100]}...")
                        print(f"Related: {len(entity['related_entities'])} connections")

                return entities

        except Exception as e:
            # Fail soft: callers treat [] as "no graph context available"
            print(f"Error in get_related_legal_entities: {str(e)}")
            return []
218
+
219
    async def process_legal_query(
        self,
        question: str,
        top_k: int = 5,
        context_window: int = 1
    ) -> Dict[str, any]:
        # NOTE(review): 'any' above is the builtin function used as an
        # annotation; typing.Any was probably intended (not imported here).
        """Process a legal query using both graph and vector search capabilities.

        Pipeline: hybrid vector search -> graph entity lookup -> context
        expansion -> LLM answer. Returns a dict with keys: status, answer,
        documents (markdown), related_concepts, source_ids, context_info.
        On failure, returns the same shape with an error status (never raises).
        """
        try:
            # 1. Perform semantic search
            semantic_results = self.vector_store.similarity_search(
                question,
                k=top_k,
                search_type="hybrid"
            )

            # 2. Get related legal entities with the full question context
            related_entities = self.get_related_legal_entities(question)

            # Log the counts for debugging
            print(f"Found {len(semantic_results)} semantic results")
            print(f"Found {len(related_entities)} related entities")

            # 3. Expand context with related documents
            expanded_results = self.expand_context(
                semantic_results,
                context_window
            )

            # 4. Generate comprehensive answer
            documents = self._process_results(expanded_results, semantic_results)

            # 5. Prepare context for LLM
            context = self._prepare_context(documents, related_entities)

            # 6. Generate answer using LLM
            chain = LLMChain(llm=self.llm, prompt=self.qa_prompt)
            response = await chain.ainvoke({
                "context": context,
                "question": question
            })
            answer = response.get('text', '')

            # 7. Return structured response with explicit related concepts
            return {
                "status": "Success",
                "answer": answer,
                "documents": self._format_documents(documents),
                "related_concepts": related_entities,  # This should now contain data
                "source_ids": sorted(list(set(doc.metadata.get('document_id', 'unknown') for doc in documents))),
                "context_info": {
                    "direct_matches": len([d for d in documents if d.metadata.get('context_type') == "DIRECT MATCH"]),
                    "context_chunks": len([d for d in documents if d.metadata.get('context_type') == "CONTEXT"])
                }
            }

        except Exception as e:
            print(f"Error in process_legal_query: {str(e)}")  # Add error logging
            # Same shape as the success payload so UI callers need no branching
            return {
                "status": f"Error: {str(e)}",
                "answer": "An error occurred while processing your query.",
                "documents": "",
                "related_concepts": [],
                "source_ids": [],
                "context_info": {}
            }
284
+
285
+
286
+ def expand_context(
287
+ self,
288
+ initial_results: List[Document],
289
+ context_window: int
290
+ ) -> List[Document]:
291
+ """Expand context around search results"""
292
+ expanded_results = []
293
+ seen_ids = set()
294
+
295
+ for doc in initial_results:
296
+ doc_id = doc.metadata.get('document_id', doc.page_content[:50])
297
+ if doc_id not in seen_ids:
298
+ # Query for related documents
299
+ context_results = self.vector_store.similarity_search(
300
+ doc.page_content,
301
+ k=2 * context_window + 1,
302
+ search_type="hybrid"
303
+ )
304
+
305
+ for result in context_results:
306
+ result_id = result.metadata.get('document_id', result.page_content[:50])
307
+ if result_id not in seen_ids:
308
+ expanded_results.append(result)
309
+ seen_ids.add(result_id)
310
+
311
+ return expanded_results
312
+
313
+ def _process_results(self, expanded_results: List[Document], initial_results: List[Document]) -> List[Document]:
314
+ """Process and deduplicate search results"""
315
+ seen_ids = set()
316
+ documents = []
317
+
318
+ for doc in expanded_results:
319
+ doc_id = doc.metadata.get('document_id', doc.page_content[:50])
320
+ if doc_id not in seen_ids:
321
+ seen_ids.add(doc_id)
322
+ is_direct_match = any(
323
+ r.metadata.get('document_id', r.page_content[:50]) == doc_id
324
+ for r in initial_results
325
+ )
326
+
327
+ doc.metadata['context_type'] = (
328
+ "DIRECT MATCH" if is_direct_match else "CONTEXT"
329
+ )
330
+ documents.append(doc)
331
+
332
+ return sorted(
333
+ documents,
334
+ key=lambda x: x.metadata.get('document_id', 'unknown')
335
+ )
336
+
337
+ def _prepare_context(
338
+ self,
339
+ documents: List[Document],
340
+ related_entities: List[Dict]
341
+ ) -> str:
342
+ """Prepare context for LLM processing"""
343
+ context = "\n\nLegal Documents:\n" + "\n".join([
344
+ f"[Document ID: {doc.metadata.get('document_id', 'unknown')}] {doc.page_content}"
345
+ for doc in documents
346
+ ])
347
+
348
+ if related_entities:
349
+ context += "\n\nRelated Legal Concepts and Relationships:\n"
350
+ for entity in related_entities:
351
+ context += f"\n• {entity.get('entity', '')}"
352
+ if entity.get('related_entities'):
353
+ for related in entity['related_entities']:
354
+ if related.get('entity'):
355
+ context += f"\n - {related['type']}: {related['entity']}"
356
+
357
+ return context
358
+
359
+ def _format_documents(self, documents: List[Document]) -> str:
360
+ """Format documents as markdown"""
361
+ markdown = "### Retrieved Documents\n\n"
362
+ for i, doc in enumerate(documents, 1):
363
+ markdown += (
364
+ f"**Document {i}** "
365
+ f"(ID: {doc.metadata.get('document_id', 'unknown')}, "
366
+ f"{doc.metadata.get('context_type', 'UNKNOWN')})\n"
367
+ f"```\n{doc.page_content}\n```\n\n"
368
+ )
369
+ return markdown
370
+
371
+
372
+
373
+ def generate_document_graph(
374
+ self,
375
+ query: str,
376
+ top_k: int = 5,
377
+ similarity_threshold: float = 0.5
378
+ ) -> List[Dict]:
379
+ """Generate graph data based on document similarity and relationships"""
380
+ try:
381
+ # 1. Get initial semantic search results
382
+ semantic_results = self.vector_store.similarity_search(
383
+ query,
384
+ k=top_k,
385
+ search_type="hybrid"
386
+ )
387
+
388
+ # 2. Get embeddings for all documents
389
+ doc_texts = [doc.page_content for doc in semantic_results]
390
+ doc_embeddings = self.local_embedding_model.encode(doc_texts)
391
+
392
+ # 3. Create graph data structure
393
+ graph_data = []
394
+ seen_docs = set()
395
+
396
+ # First, add all documents as nodes
397
+ for i, doc in enumerate(semantic_results):
398
+ doc_id = doc.metadata.get('document_id', f'doc_{i}')
399
+ if doc_id not in seen_docs:
400
+ seen_docs.add(doc_id)
401
+ doc_type = doc.metadata.get('type', 'document')
402
+
403
+ # Create node entry
404
+ graph_data.append({
405
+ 'source_id': doc_id,
406
+ 'source_text': doc.page_content[:200], # Truncate for display
407
+ 'document_type': doc_type,
408
+ 'entity': {
409
+ 'id': doc_id,
410
+ 'type': 'Document',
411
+ 'text': f"Document {i + 1}",
412
+ 'properties': {
413
+ 'similarity': 1.0,
414
+ 'length': len(doc.page_content)
415
+ }
416
+ },
417
+ 'related_entities': []
418
+ })
419
+
420
+ # Add relationships based on similarity
421
+ from sklearn.metrics.pairwise import cosine_similarity
422
+ similarity_matrix = cosine_similarity(doc_embeddings)
423
+
424
+ # Create relationships between similar documents
425
+ for i in range(len(semantic_results)):
426
+ related = []
427
+ for j in range(len(semantic_results)):
428
+ if i != j and similarity_matrix[i][j] > similarity_threshold:
429
+ doc_j = semantic_results[j]
430
+ doc_j_id = doc_j.metadata.get('document_id', f'doc_{j}')
431
+
432
+ related.append({
433
+ 'id': doc_j_id,
434
+ 'type': 'Document',
435
+ 'relationship': 'similar_to',
436
+ 'text': f"Document {j + 1}",
437
+ 'properties': {
438
+ 'similarity_score': float(similarity_matrix[i][j])
439
+ }
440
+ })
441
+
442
+ # Add related documents to the graph data
443
+ if related:
444
+ graph_data[i]['related_entities'] = related
445
+
446
+ return graph_data
447
+
448
+ except Exception as e:
449
+ print(f"Error generating document graph: {str(e)}")
450
+ return []
451
+
452
+
453
+
454
    def create_graph_visualization(graph_data: List[Dict]) -> go.Figure:
        """Create an interactive graph visualization using Plotly.

        Expects the node/link dicts produced by generate_document_graph and
        returns a Plotly figure of the similarity graph: curved grey edges
        labelled with the similarity score, blue document nodes with full
        text on hover, and click-to-select enabled.
        """
        if not graph_data:
            return go.Figure(layout=go.Layout(title='No documents found'))

        # Initialize graph
        G = nx.Graph()

        # Color mapping (only 'Document' nodes are produced today; the other
        # categories are reserved for future node types)
        color_map = {
            'Document': '#3B82F6',   # blue
            'Section': '#10B981',    # green
            'Reference': '#F59E0B'   # yellow
        }

        # Node information storage.
        # NOTE(review): these lists are kept aligned with G.add_node insertion
        # order and later paired with G.nodes() — this relies on networkx
        # preserving insertion order; confirm if networkx is downgraded.
        node_colors = []
        node_texts = []
        node_hovers = []  # Full text for hover
        nodes_added = set()

        # Process nodes and edges
        for data in graph_data:
            source_id = data['source_id']
            source_text = data['source_text']

            # Add main document node
            if source_id not in nodes_added:
                G.add_node(source_id)
                node_colors.append(color_map['Document'])
                # Short text for display
                node_texts.append(f"Doc {len(nodes_added)+1}")
                # Full text for hover/click
                node_hovers.append(f"Document {len(nodes_added)+1}:<br><br>{source_text}")
                nodes_added.add(source_id)

            # Process related documents
            for related in data.get('related_entities', []):
                related_id = related['id']
                similarity = related['properties'].get('similarity_score', 0.0)

                if related_id not in nodes_added:
                    G.add_node(related_id)
                    node_colors.append(color_map['Document'])
                    node_texts.append(f"Doc {len(nodes_added)+1}")
                    node_hovers.append(f"Document {len(nodes_added)+1}:<br><br>{related['text']}")
                    nodes_added.add(related_id)

                # Add edge with similarity weight
                G.add_edge(
                    source_id,
                    related_id,
                    weight=similarity,
                    relationship=f"Similarity: {similarity:.2f}"
                )

        # Create layout (force-directed; k controls node spacing)
        pos = nx.spring_layout(G, k=2.0, iterations=50)

        # Create edge trace
        edge_x = []
        edge_y = []
        edge_text = []

        for edge in G.edges(data=True):
            x0, y0 = pos[edge[0]]
            x1, y1 = pos[edge[1]]

            # Create curved line via a midpoint offset perpendicular to the edge
            mid_x = (x0 + x1) / 2
            mid_y = (y0 + y1) / 2
            # Add some curvature
            mid_x += (y1 - y0) * 0.1
            mid_y -= (x1 - x0) * 0.1

            # Add points for curved line (None breaks the polyline between edges)
            edge_x.extend([x0, mid_x, x1, None])
            edge_y.extend([y0, mid_y, y1, None])
            edge_text.append(edge[2]['relationship'])

        edge_trace = go.Scatter(
            x=edge_x,
            y=edge_y,
            line=dict(width=1.5, color='#9CA3AF'),
            hoverinfo='text',
            text=edge_text,
            mode='lines'
        )

        # Create node trace
        node_x = []
        node_y = []

        for node in G.nodes():
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)

        node_trace = go.Scatter(
            x=node_x,
            y=node_y,
            mode='markers+text',
            hoverinfo='text',
            text=node_texts,
            hovertext=node_hovers,  # Full text shown on hover
            textposition="top center",
            marker=dict(
                size=30,
                color=node_colors,
                line=dict(width=2, color='white'),
                symbol='circle'
            ),
            customdata=node_hovers  # Store full text for click events
        )

        # Create figure with updated layout
        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                title={
                    'text': 'Document Similarity Graph<br><sub>Click nodes to view full text</sub>',
                    'y': 0.95,
                    'x': 0.5,
                    'xanchor': 'center',
                    'yanchor': 'top'
                },
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=60),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='white',
                width=800,
                height=600,
                clickmode='event+select'  # Enable click events
            )
        )

        return fig
593
+
594
    def create_interface(search_system: EnhancedLegalSearchSystem):
        """Create Gradio interface with interactive graph.

        NOTE(review): this function is shadowed by a later definition of
        create_interface (same name, different parameter) further down the
        file — at runtime only the later one is used, so this is dead code;
        rename or remove one of the two.
        """

        with gr.Blocks(css="footer {display: none !important;}") as demo:
            gr.Markdown("""
            # Enhanced Legal Search System
            Enter your legal query below to search through documents and get an AI-powered analysis.
            This system combines graph-based and semantic search capabilities for comprehensive legal research.
            """)

            with gr.Row():
                query_input = gr.Textbox(
                    label="Legal Query",
                    placeholder="e.g., What are the reporting obligations for banks under the Money Laundering Act?",
                    lines=3
                )

            with gr.Row():
                search_button = gr.Button("Search & Analyze")

            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )

            with gr.Tabs():
                with gr.TabItem("AI Legal Analysis"):
                    analysis_output = gr.Markdown(
                        label="AI-Generated Legal Analysis",
                        value="Analysis will appear here..."
                    )

                with gr.TabItem("Retrieved Documents"):
                    docs_output = gr.Markdown(
                        label="Source Documents",
                        value="Search results will appear here..."
                    )

                with gr.TabItem("Related Concepts"):
                    concepts_output = gr.Json(
                        label="Related Legal Concepts",
                        value={}
                    )

                with gr.TabItem("Knowledge Graph"):
                    # Graph visualization
                    graph_output = gr.Plot(
                        label="Legal Knowledge Graph"
                    )
                    # Add text area for showing clicked document content
                    selected_doc_content = gr.Textbox(
                        label="Selected Document Content",
                        interactive=False,
                        lines=10
                    )

            async def process_query(query):
                # Returns one value per wired output component (6 total)
                if not query.strip():
                    return (
                        "Please enter a query",
                        "No analysis available",
                        "No documents available",
                        {},
                        None,
                        ""
                    )

                results = await search_system.process_legal_query(query)
                graph_data = search_system.generate_document_graph(query)
                graph_fig = create_graph_visualization(graph_data)

                return (
                    results['status'],
                    results['answer'],
                    results['documents'],
                    {"related_concepts": results['related_concepts']},
                    graph_fig,
                    "Click on a node to view document content"
                )

            search_button.click(
                fn=process_query,
                inputs=[query_input],
                outputs=[
                    status_output,
                    analysis_output,
                    docs_output,
                    concepts_output,
                    graph_output,
                    selected_doc_content
                ]
            )

        return demo
688
+
689
+
690
+
691
+
692
class LegalSearchSystem:
    """Pinecone-backed semantic search over PDF chunks with a Gemini
    "stuff" summarization chain."""

    def __init__(
        self,
        pinecone_api_key: Optional[str] = None,
        google_api_key: Optional[str] = None,
        environment: str = "us-east-1",
        index_name: str = "pdf-embeddings",
        dimension: int = 384,
        embedding_model_name: str = "intfloat/e5-small-v2",
        device: str = "cpu"
    ):
        """Initialize the Pinecone-based search system.

        Security fix: API keys were previously committed to source control
        as hard-coded default arguments. They are now read from the
        PINECONE_API_KEY and GOOGLE_API_KEY environment variables when not
        passed explicitly (the leaked keys must be rotated). Callers that
        already pass keys explicitly are unaffected.

        Raises:
            KeyError: if a key is neither passed nor in the environment.
        """
        pinecone_api_key = pinecone_api_key or os.environ["PINECONE_API_KEY"]
        google_api_key = google_api_key or os.environ["GOOGLE_API_KEY"]

        # Initialize Pinecone
        self.pc = Pinecone(api_key=pinecone_api_key)

        # Initialize LangChain with Gemini (temperature 0 for determinism)
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-pro",
            temperature=0,
            google_api_key=google_api_key
        )

        # Initialize prompts
        self.map_prompt = PromptTemplate(
            template="""
Analyze the following legal text segment and extract key information:

TEXT: "{text}"

Instructions:
1. Maintain all legal terminology exactly as written
2. Preserve section numbers and references
3. Keep all specific conditions and requirements
4. Include any mentioned time periods or deadlines

DETAILED ANALYSIS:
""",
            input_variables=["text"]
        )

        self.combine_prompt = PromptTemplate(
            template="""
Based on the following excerpts from legal documents and the question: "{question}"

EXCERPTS:
{text}

Instructions:
1. Synthesize a comprehensive answer that connects relevant sections
2. Maintain precise legal language from the source material
3. Reference specific sections and subsections where applicable
4. If there are seemingly disconnected pieces of information, explain their relationship
5. Highlight any conditions or exceptions that span multiple excerpts

COMPREHENSIVE LEGAL ANALYSIS:
""",
            input_variables=["text", "question"]
        )

        # Initialize chain ("stuff": all excerpts go into combine_prompt at once;
        # map_prompt above is kept for compatibility but unused by this chain type)
        self.chain = load_summarize_chain(
            llm=self.llm,
            chain_type="stuff",
            prompt=self.combine_prompt,
            verbose=True
        )

        # Initialize Pinecone index and embedding model
        self.index = self.pc.Index(index_name)
        self.embedding_model = SentenceTransformer(
            model_name_or_path=embedding_model_name,
            device=device
        )
764
+
765
+ def search(self, query_text: str, top_k: int = 5, context_window: int = 1) -> Dict:
766
+ """
767
+ Perform a search and analysis of the legal query.
768
+ """
769
+ try:
770
+ # Get search results with context
771
+ results = self.query_and_summarize(
772
+ query_text=query_text,
773
+ top_k=top_k,
774
+ context_window=context_window
775
+ )
776
+
777
+ # Format the results for display
778
+ docs_markdown = self._format_documents(results['raw_results'])
779
+
780
+ return {
781
+ 'status': "Search completed successfully",
782
+ 'documents': docs_markdown,
783
+ 'analysis': results['summary'],
784
+ 'source_pages': results['source_pages'],
785
+ 'context_info': results['context_info']
786
+ }
787
+ except Exception as e:
788
+ return {
789
+ 'status': f"Error during search: {str(e)}",
790
+ 'documents': "Error retrieving documents",
791
+ 'analysis': "Error generating analysis",
792
+ 'source_pages': [],
793
+ 'context_info': {}
794
+ }
795
+
796
    def query_and_summarize(
        self,
        query_text: str,
        top_k: int = 5,
        filter: Optional[Dict] = None,
        context_window: int = 1
    ) -> Dict:
        """Query Pinecone and generate a summary with enhanced context handling.

        For each top-k match, runs a secondary query restricted (via metadata
        filter) to pages within +/- context_window of the match's page, then
        summarizes the deduplicated set with the "stuff" chain.

        Returns a dict with keys: raw_results (Pinecone matches),
        summary (str), source_pages (list[int]), context_info (dict).
        Assumes every match's metadata carries 'page_number' and 'text'.
        """
        # Generate embedding for query
        query_embedding = self.embedding_model.encode(query_text).tolist()

        # Query Pinecone
        initial_results = self.index.query(
            vector=query_embedding,
            top_k=top_k,
            include_metadata=True,
            filter=filter
        )['matches']

        # Expand context: one extra query per initial match, constrained to
        # the surrounding page range
        expanded_results = []
        for match in initial_results:
            page_num = match['metadata']['page_number']
            context_filter = {
                "page_number": {
                    "$gte": max(1, page_num - context_window),
                    "$lte": page_num + context_window
                }
            }
            if filter:
                # Preserve any caller-supplied metadata constraints
                context_filter.update(filter)

            context_results = self.index.query(
                vector=self.embedding_model.encode(match['metadata']['text']).tolist(),
                top_k=2 * context_window + 1,
                include_metadata=True,
                filter=context_filter
            )['matches']

            expanded_results.extend(context_results)

        # Process results and generate summary
        documents = self._process_results(expanded_results, initial_results)
        # NOTE(review): chain.run is deprecated in newer LangChain; invoke()
        # is the modern equivalent — confirm against the pinned version.
        summary = self.chain.run(
            input_documents=documents,
            question=query_text
        )

        return {
            'raw_results': expanded_results,
            'summary': summary,
            'source_pages': list(set(doc.metadata['page_number'] for doc in documents)),
            'context_info': {
                'direct_matches': len([d for d in documents if d.metadata['context_type'] == "DIRECT MATCH"]),
                'context_chunks': len([d for d in documents if d.metadata['context_type'] == "CONTEXT"])
            }
        }
855
+
856
+ def _process_results(self, expanded_results: List[Dict], initial_results: List[Dict]) -> List[Document]:
857
+ """
858
+ Process and deduplicate search results.
859
+ """
860
+ seen_ids = set()
861
+ documents = []
862
+
863
+ for result in expanded_results:
864
+ if result['id'] not in seen_ids:
865
+ seen_ids.add(result['id'])
866
+ is_direct_match = any(r['id'] == result['id'] for r in initial_results)
867
+
868
+ documents.append(Document(
869
+ page_content=result['metadata']['text'],
870
+ metadata={
871
+ 'score': result['score'],
872
+ 'page_number': result['metadata']['page_number'],
873
+ 'context_type': "DIRECT MATCH" if is_direct_match else "CONTEXT"
874
+ }
875
+ ))
876
+
877
+ return sorted(documents, key=lambda x: x.metadata['page_number'])
878
+
879
+ def _format_documents(self, results: List[Dict]) -> str:
880
+ """
881
+ Format search results as markdown.
882
+ """
883
+ markdown = "### Retrieved Documents\n\n"
884
+ for i, result in enumerate(results, 1):
885
+ markdown += f"**Document {i}** (Page {result['metadata']['page_number']})\n"
886
+ markdown += f"```\n{result['metadata']['text']}\n```\n\n"
887
+ return markdown
888
+
889
+
890
+ async def process_query_async(query: str, search_system: LegalSearchSystem, graph_search_system: EnhancedLegalSearchSystem):
891
+ """
892
+ Asynchronous function to process both traditional and graph-based searches
893
+ """
894
+ if not query.strip():
895
+ return "Please enter a query", "", "", "", {}
896
+
897
+ # Regular search (synchronous)
898
+ results = search_system.search(query)
899
+
900
+ try:
901
+ # Graph search (asynchronous)
902
+ graph_results = await graph_search_system.process_legal_query(query)
903
+ graph_documents = graph_results.get('documents', "Error processing graph search")
904
+ graph_concepts = graph_results.get('related_concepts', {})
905
+ except Exception as e:
906
+ graph_documents = f"Error processing graph search: {str(e)}"
907
+ graph_concepts = {}
908
+
909
+ graph_data = graph_search_system.generate_document_graph(query)
910
+ graph_fig = create_graph_visualization(graph_data)
911
+
912
+ return (
913
+ results['status'],
914
+ results['documents'],
915
+ results['analysis'],
916
+ graph_documents,
917
+ graph_concepts,
918
+ graph_fig,
919
+ "Click on a node to view document content"
920
+ )
921
+
922
    def create_interface(graph_search_system: EnhancedLegalSearchSystem):
        """Build the combined Gradio UI (Pinecone search + graph RAG tabs).

        NOTE(review): this redefinition shadows the earlier create_interface
        above; only this version runs. It also constructs its own
        LegalSearchSystem, which pulls Pinecone/Google credentials.
        """
        search_system = LegalSearchSystem()

        with gr.Blocks(css="footer {display: none !important;}") as demo:
            gr.Markdown("""
            # Legal Search AI with LangChain
            Enter your legal query below to search through documents and get an AI-powered analysis.
            """)

            with gr.Row():
                query_input = gr.Textbox(
                    label="Legal Query",
                    placeholder="e.g., What are the key principles of contract law?",
                    lines=3
                )

            with gr.Row():
                search_button = gr.Button("Search & Analyze")

            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )

            with gr.Tabs():
                with gr.TabItem("Search Results"):
                    docs_output = gr.Markdown(
                        label="Retrieved Documents",
                        value="Search results will appear here..."
                    )

                with gr.TabItem("AI Legal Analysis"):
                    summary_output = gr.Markdown(
                        label="AI-Generated Legal Analysis",
                        value="Analysis will appear here..."
                    )

                with gr.TabItem("Retrieved Documents through Graph Rag"):
                    docs_output_graph = gr.Markdown(
                        label="Source Documents",
                        value="Search results will appear here..."
                    )
                    graph_analysis_output = gr.JSON(
                        label="Related Concepts",
                        value={}
                    )

                with gr.TabItem("Knowledge Graph"):
                    # Graph visualization
                    graph_output = gr.Plot(
                        label="Legal Knowledge Graph"
                    )
                    # Add text area for showing clicked document content
                    selected_doc_content = gr.Textbox(
                        label="Selected Document Content",
                        interactive=False,
                        lines=10
                    )

            def process_query(query):
                # Bridge the async pipeline into Gradio's sync callback.
                # Create event loop if it doesn't exist
                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)

                # Run the async function and get results (7-tuple matching
                # the outputs list below)
                return loop.run_until_complete(
                    process_query_async(query, search_system, graph_search_system)
                )

            search_button.click(
                fn=process_query,
                inputs=[query_input],
                outputs=[
                    status_output,
                    docs_output,
                    summary_output,
                    docs_output_graph,
                    graph_analysis_output,
                    graph_output,
                    selected_doc_content
                ]
            )

        return demo
1011
+
1012
+ if __name__ == "__main__":
1013
+ graph_search_system = EnhancedLegalSearchSystem(
1014
+ google_api_key="AIzaSyCBkddDicU_4dor9zIqtdpF8PvAeKzqdR0",
1015
+ neo4j_url="neo4j+s://ffc2cc0f.databases.neo4j.io",
1016
+ neo4j_username="neo4j",
1017
+ neo4j_password="iH1Qe61EwRwhWtoVncW4XiADuUaABOvKtOagu1NY1m4"
1018
+ )
1019
+ demo = create_interface(graph_search_system)
1020
+ demo.launch()
1021
+
1022
+