# Spaces:
# Runtime error
# Runtime error
# import os | |
# import gradio as gr | |
# from pinecone import Pinecone | |
# from sentence_transformers import SentenceTransformer | |
# from typing import List, Dict, Optional | |
# from langchain_google_genai import ChatGoogleGenerativeAI | |
# from langchain.chains.summarize import load_summarize_chain | |
# from langchain.prompts import PromptTemplate, ChatPromptTemplate | |
# from langchain.docstore.document import Document | |
# import time | |
# import asyncio | |
# import plotly.graph_objects as go | |
# from neo4j import GraphDatabase | |
# import networkx as nx | |
# from langchain_community.vectorstores import Neo4jVector | |
# from langchain.chains.summarize import load_summarize_chain | |
# from langchain.chains import LLMChain | |
# from langchain_google_genai import GoogleGenerativeAI, GoogleGenerativeAIEmbeddings | |
# class EnhancedLegalSearchSystem: | |
# def __init__( | |
# self, | |
# google_api_key: str, | |
# neo4j_url: str, | |
# neo4j_username: str, | |
# neo4j_password: str, | |
# embedding_model_name: str = "intfloat/e5-small-v2", | |
# device: str = "cpu" | |
# ): | |
# """Initialize the Enhanced Legal Search System""" | |
# # Initialize LLM | |
# self.llm = GoogleGenerativeAI( | |
# model="gemini-pro", | |
# google_api_key=google_api_key, | |
# temperature=0.1 | |
# ) | |
# # Initialize embeddings | |
# self.embeddings = GoogleGenerativeAIEmbeddings( | |
# model="models/embedding-001", | |
# google_api_key=google_api_key, | |
# task_type="retrieval_query" | |
# ) | |
# # Initialize Neo4j connection | |
# self.neo4j_driver = GraphDatabase.driver( | |
# neo4j_url, | |
# auth=(neo4j_username, neo4j_password) | |
# ) | |
# # Initialize vector store | |
# self.vector_store = Neo4jVector.from_existing_graph( | |
# embedding=self.embeddings, | |
# url=neo4j_url, | |
# username=neo4j_username, | |
# password=neo4j_password, | |
# node_label="Document", | |
# text_node_properties=["text"], | |
# embedding_node_property="embedding" | |
# ) | |
# # Initialize additional embedding model for enhanced search | |
# self.local_embedding_model = SentenceTransformer( | |
# model_name_or_path=embedding_model_name, | |
# device=device | |
# ) | |
# # Initialize prompts | |
# self.init_prompts() | |
# def __del__(self): | |
# """Cleanup Neo4j connection""" | |
# if hasattr(self, 'neo4j_driver'): | |
# self.neo4j_driver.close() | |
# def init_prompts(self): | |
# """Initialize enhanced prompts for legal analysis""" | |
# self.qa_prompt = ChatPromptTemplate.from_messages([ | |
# ("system", """You are a legal expert assistant specializing in Indian law. | |
# Analyze the following legal context and provide a detailed, structured answer to the question. | |
# Include specific sections, rules, and precedents where applicable. | |
# Format your response with clear headings and bullet points for better readability. | |
# Context: {context}"""), | |
# ("human", "Question: {question}") | |
# ]) | |
# self.map_prompt = PromptTemplate( | |
# template=""" | |
# Analyze the following legal text segment: | |
# TEXT: "{text}" | |
# Instructions: | |
# 1. Extract and summarize the key legal points | |
# 2. Maintain all legal terminology exactly as written | |
# 3. Preserve section numbers and references | |
# 4. Keep all specific conditions and requirements | |
# 5. Include any mentioned time periods or deadlines | |
# KEY POINTS: | |
# """, | |
# input_variables=["text"] # Removed page_number as it's not used in the template | |
# ) | |
# self.combine_prompt = PromptTemplate( | |
# template=""" | |
# Question: {question} | |
# Using ONLY the information from the following legal document excerpts, provide a comprehensive answer: | |
# {text} | |
# Instructions: | |
# 1. Base your response EXCLUSIVELY on the provided document excerpts | |
# 2. If the documents don't contain enough information to fully answer the question, explicitly state what's missing | |
# 3. Use direct quotes when appropriate | |
# 4. Organize the response by relevant sections found in the documents | |
# 5. If there are conflicting statements across documents, highlight them | |
# ANALYSIS: | |
# """, | |
# input_variables=["text", "question"] | |
# ) | |
# # Initialize summarize chain | |
# self.chain = load_summarize_chain( | |
# llm=self.llm, | |
# chain_type="map_reduce", | |
# map_prompt=self.map_prompt, | |
# combine_prompt=self.combine_prompt, | |
# verbose=True | |
# ) | |
# def get_related_legal_entities(self, query: str) -> List[Dict]: | |
# """Retrieve related legal entities and their relationships""" | |
# # Corrected Cypher query to handle aggregation properly | |
# cypher_query = """ | |
# // First, let's check if nodes exist and get their labels | |
# MATCH (d:Document) | |
# WHERE toLower(d.text) CONTAINS toLower($query) | |
# WITH d | |
# // Match all relationships from the document, collecting their types | |
# OPTIONAL MATCH (d)-[r]-(connected) | |
# WHERE NOT connected:Document // Avoid direct document-to-document relations | |
# WITH d, | |
# collect(DISTINCT type(r)) as relationTypes, | |
# collect(DISTINCT labels(connected)) as connectedLabels | |
# // Now use these to build our main query | |
# MATCH (d:Document)-[r1]-(e) | |
# WHERE toLower(d.text) CONTAINS toLower($query) | |
# AND NOT e:Document // Exclude direct document connections | |
# WITH d, r1, e | |
# // Get secondary connections, but be more specific about what we're looking for | |
# OPTIONAL MATCH (e)-[r2]-(related) | |
# WHERE (related:Entity OR related:Concept OR related:Section OR related:Case) | |
# AND related <> d // Prevent cycles back to original document | |
# WITH d, { | |
# source_id: id(d), | |
# source_text: d.text, | |
# document_type: COALESCE(d.type, "Unknown"), | |
# relationship_type: type(r1), | |
# entity: { | |
# id: id(e), | |
# type: CASE WHEN e:Entity THEN "Entity" | |
# WHEN e:Concept THEN "Concept" | |
# WHEN e:Section THEN "Section" | |
# WHEN e:Case THEN "Case" | |
# ELSE "Other" END, | |
# text: COALESCE(e.text, e.name, e.title, "Unnamed"), | |
# properties: properties(e) | |
# }, | |
# related_entities: collect(DISTINCT { | |
# id: id(related), | |
# type: CASE WHEN related:Entity THEN "Entity" | |
# WHEN related:Concept THEN "Concept" | |
# WHEN related:Section THEN "Section" | |
# WHEN related:Case THEN "Case" | |
# ELSE "Other" END, | |
# relationship: type(r2), | |
# text: COALESCE(related.text, related.name, related.title, "Unnamed"), | |
# properties: properties(related) | |
# }) | |
# } as result | |
# WHERE result.entity.text IS NOT NULL // Filter out any results with null entity text | |
# RETURN DISTINCT result | |
# ORDER BY result.source_id, result.entity.id | |
# LIMIT 25 | |
# """ | |
# try: | |
# with self.neo4j_driver.session() as session: | |
# # Execute the improved query | |
# result = session.run(cypher_query, {"query": query}) | |
# entities = [record["result"] for record in result] | |
# # Log the results for debugging | |
# print(f"Found {len(entities)} related entities") | |
# if entities: | |
# for entity in entities: | |
# print(f"Entity: {entity['entity']['text']}") | |
# print(f"Source: {entity['source_text'][:100]}...") | |
# print(f"Related: {len(entity['related_entities'])} connections") | |
# return entities | |
# except Exception as e: | |
# print(f"Error in get_related_legal_entities: {str(e)}") | |
# return [] | |
# async def process_legal_query( | |
# self, | |
# question: str, | |
# top_k: int = 5, | |
# context_window: int = 1 | |
# ) -> Dict[str, any]: | |
# """Process a legal query using both graph and vector search capabilities""" | |
# try: | |
# # 1. Perform semantic search | |
# semantic_results = self.vector_store.similarity_search( | |
# question, | |
# k=top_k, | |
# search_type="hybrid" | |
# ) | |
# # 2. Get related legal entities with the full question context | |
# related_entities = self.get_related_legal_entities(question) | |
# # Log the counts for debugging | |
# print(f"Found {len(semantic_results)} semantic results") | |
# print(f"Found {len(related_entities)} related entities") | |
# # 3. Expand context with related documents | |
# expanded_results = self.expand_context( | |
# semantic_results, | |
# context_window | |
# ) | |
# # 4. Generate comprehensive answer | |
# documents = self._process_results(expanded_results, semantic_results) | |
# # 5. Prepare context for LLM | |
# context = self._prepare_context(documents, related_entities) | |
# # 6. Generate answer using LLM | |
# chain = LLMChain(llm=self.llm, prompt=self.qa_prompt) | |
# response = await chain.ainvoke({ | |
# "context": context, | |
# "question": question | |
# }) | |
# answer = response.get('text', '') | |
# # 7. Return structured response with explicit related concepts | |
# return { | |
# "status": "Success", | |
# "answer": answer, | |
# "documents": self._format_documents(documents), | |
# "related_concepts": related_entities, # This should now contain data | |
# "source_ids": sorted(list(set(doc.metadata.get('document_id', 'unknown') for doc in documents))), | |
# "context_info": { | |
# "direct_matches": len([d for d in documents if d.metadata.get('context_type') == "DIRECT MATCH"]), | |
# "context_chunks": len([d for d in documents if d.metadata.get('context_type') == "CONTEXT"]) | |
# } | |
# } | |
# except Exception as e: | |
# print(f"Error in process_legal_query: {str(e)}") # Add error logging | |
# return { | |
# "status": f"Error: {str(e)}", | |
# "answer": "An error occurred while processing your query.", | |
# "documents": "", | |
# "related_concepts": [], | |
# "source_ids": [], | |
# "context_info": {} | |
# } | |
# def expand_context( | |
# self, | |
# initial_results: List[Document], | |
# context_window: int | |
# ) -> List[Document]: | |
# """Expand context around search results""" | |
# expanded_results = [] | |
# seen_ids = set() | |
# for doc in initial_results: | |
# doc_id = doc.metadata.get('document_id', doc.page_content[:50]) | |
# if doc_id not in seen_ids: | |
# # Query for related documents | |
# context_results = self.vector_store.similarity_search( | |
# doc.page_content, | |
# k=2 * context_window + 1, | |
# search_type="hybrid" | |
# ) | |
# for result in context_results: | |
# result_id = result.metadata.get('document_id', result.page_content[:50]) | |
# if result_id not in seen_ids: | |
# expanded_results.append(result) | |
# seen_ids.add(result_id) | |
# return expanded_results | |
# def _process_results(self, expanded_results: List[Document], initial_results: List[Document]) -> List[Document]: | |
# """Process and deduplicate search results""" | |
# seen_ids = set() | |
# documents = [] | |
# for doc in expanded_results: | |
# doc_id = doc.metadata.get('document_id', doc.page_content[:50]) | |
# if doc_id not in seen_ids: | |
# seen_ids.add(doc_id) | |
# is_direct_match = any( | |
# r.metadata.get('document_id', r.page_content[:50]) == doc_id | |
# for r in initial_results | |
# ) | |
# doc.metadata['context_type'] = ( | |
# "DIRECT MATCH" if is_direct_match else "CONTEXT" | |
# ) | |
# documents.append(doc) | |
# return sorted( | |
# documents, | |
# key=lambda x: x.metadata.get('document_id', 'unknown') | |
# ) | |
# def _prepare_context( | |
# self, | |
# documents: List[Document], | |
# related_entities: List[Dict] | |
# ) -> str: | |
# """Prepare context for LLM processing""" | |
# context = "\n\nLegal Documents:\n" + "\n".join([ | |
# f"[Document ID: {doc.metadata.get('document_id', 'unknown')}] {doc.page_content}" | |
# for doc in documents | |
# ]) | |
# if related_entities: | |
# context += "\n\nRelated Legal Concepts and Relationships:\n" | |
# for entity in related_entities: | |
# context += f"\n• {entity.get('entity', '')}" | |
# if entity.get('related_entities'): | |
# for related in entity['related_entities']: | |
# if related.get('entity'): | |
# context += f"\n - {related['type']}: {related['entity']}" | |
# return context | |
# def _format_documents(self, documents: List[Document]) -> str: | |
# """Format documents as markdown""" | |
# markdown = "### Retrieved Documents\n\n" | |
# for i, doc in enumerate(documents, 1): | |
# markdown += ( | |
# f"**Document {i}** " | |
# f"(ID: {doc.metadata.get('document_id', 'unknown')}, " | |
# f"{doc.metadata.get('context_type', 'UNKNOWN')})\n" | |
# f"```\n{doc.page_content}\n```\n\n" | |
# ) | |
# return markdown | |
# def generate_document_graph( | |
# self, | |
# query: str, | |
# top_k: int = 5, | |
# similarity_threshold: float = 0.5 | |
# ) -> List[Dict]: | |
# """Generate graph data based on document similarity and relationships""" | |
# try: | |
# # 1. Get initial semantic search results | |
# semantic_results = self.vector_store.similarity_search( | |
# query, | |
# k=top_k, | |
# search_type="hybrid" | |
# ) | |
# # 2. Get embeddings for all documents | |
# doc_texts = [doc.page_content for doc in semantic_results] | |
# doc_embeddings = self.local_embedding_model.encode(doc_texts) | |
# # 3. Create graph data structure | |
# graph_data = [] | |
# seen_docs = set() | |
# # First, add all documents as nodes | |
# for i, doc in enumerate(semantic_results): | |
# doc_id = doc.metadata.get('document_id', f'doc_{i}') | |
# if doc_id not in seen_docs: | |
# seen_docs.add(doc_id) | |
# doc_type = doc.metadata.get('type', 'document') | |
# # Create node entry | |
# graph_data.append({ | |
# 'source_id': doc_id, | |
# 'source_text': doc.page_content[:200], # Truncate for display | |
# 'document_type': doc_type, | |
# 'entity': { | |
# 'id': doc_id, | |
# 'type': 'Document', | |
# 'text': f"Document {i + 1}", | |
# 'properties': { | |
# 'similarity': 1.0, | |
# 'length': len(doc.page_content) | |
# } | |
# }, | |
# 'related_entities': [] | |
# }) | |
# # Add relationships based on similarity | |
# from sklearn.metrics.pairwise import cosine_similarity | |
# similarity_matrix = cosine_similarity(doc_embeddings) | |
# # Create relationships between similar documents | |
# for i in range(len(semantic_results)): | |
# related = [] | |
# for j in range(len(semantic_results)): | |
# if i != j and similarity_matrix[i][j] > similarity_threshold: | |
# doc_j = semantic_results[j] | |
# doc_j_id = doc_j.metadata.get('document_id', f'doc_{j}') | |
# related.append({ | |
# 'id': doc_j_id, | |
# 'type': 'Document', | |
# 'relationship': 'similar_to', | |
# 'text': f"Document {j + 1}", | |
# 'properties': { | |
# 'similarity_score': float(similarity_matrix[i][j]) | |
# } | |
# }) | |
# # Add related documents to the graph data | |
# if related: | |
# graph_data[i]['related_entities'] = related | |
# return graph_data | |
# except Exception as e: | |
# print(f"Error generating document graph: {str(e)}") | |
# return [] | |
# def create_graph_visualization(graph_data: List[Dict]) -> go.Figure: | |
# """Create an interactive graph visualization using Plotly""" | |
# if not graph_data: | |
# return go.Figure(layout=go.Layout(title='No documents found')) | |
# # Initialize graph | |
# G = nx.Graph() | |
# # Color mapping | |
# color_map = { | |
# 'Document': '#3B82F6', # blue | |
# 'Section': '#10B981', # green | |
# 'Reference': '#F59E0B' # yellow | |
# } | |
# # Node information storage | |
# node_colors = [] | |
# node_texts = [] | |
# node_hovers = [] # Full text for hover | |
# nodes_added = set() | |
# # Process nodes and edges | |
# for data in graph_data: | |
# source_id = data['source_id'] | |
# source_text = data['source_text'] | |
# # Add main document node | |
# if source_id not in nodes_added: | |
# G.add_node(source_id) | |
# node_colors.append(color_map['Document']) | |
# # Short text for display | |
# node_texts.append(f"Doc {len(nodes_added)+1}") | |
# # Full text for hover/click | |
# node_hovers.append(f"Document {len(nodes_added)+1}:<br><br>{source_text}") | |
# nodes_added.add(source_id) | |
# # Process related documents | |
# for related in data.get('related_entities', []): | |
# related_id = related['id'] | |
# similarity = related['properties'].get('similarity_score', 0.0) | |
# if related_id not in nodes_added: | |
# G.add_node(related_id) | |
# node_colors.append(color_map['Document']) | |
# node_texts.append(f"Doc {len(nodes_added)+1}") | |
# node_hovers.append(f"Document {len(nodes_added)+1}:<br><br>{related['text']}") | |
# nodes_added.add(related_id) | |
# # Add edge with similarity weight | |
# G.add_edge( | |
# source_id, | |
# related_id, | |
# weight=similarity, | |
# relationship=f"Similarity: {similarity:.2f}" | |
# ) | |
# # Create layout | |
# pos = nx.spring_layout(G, k=2.0, iterations=50) | |
# # Create edge trace | |
# edge_x = [] | |
# edge_y = [] | |
# edge_text = [] | |
# for edge in G.edges(data=True): | |
# x0, y0 = pos[edge[0]] | |
# x1, y1 = pos[edge[1]] | |
# # Create curved line | |
# mid_x = (x0 + x1) / 2 | |
# mid_y = (y0 + y1) / 2 | |
# # Add some curvature | |
# mid_x += (y1 - y0) * 0.1 | |
# mid_y -= (x1 - x0) * 0.1 | |
# # Add points for curved line | |
# edge_x.extend([x0, mid_x, x1, None]) | |
# edge_y.extend([y0, mid_y, y1, None]) | |
# edge_text.append(edge[2]['relationship']) | |
# edge_trace = go.Scatter( | |
# x=edge_x, | |
# y=edge_y, | |
# line=dict(width=1.5, color='#9CA3AF'), | |
# hoverinfo='text', | |
# text=edge_text, | |
# mode='lines' | |
# ) | |
# # Create node trace | |
# node_x = [] | |
# node_y = [] | |
# for node in G.nodes(): | |
# x, y = pos[node] | |
# node_x.append(x) | |
# node_y.append(y) | |
# node_trace = go.Scatter( | |
# x=node_x, | |
# y=node_y, | |
# mode='markers+text', | |
# hoverinfo='text', | |
# text=node_texts, | |
# hovertext=node_hovers, # Full text shown on hover | |
# textposition="top center", | |
# marker=dict( | |
# size=30, | |
# color=node_colors, | |
# line=dict(width=2, color='white'), | |
# symbol='circle' | |
# ), | |
# customdata=node_hovers # Store full text for click events | |
# ) | |
# # Create figure with updated layout | |
# fig = go.Figure( | |
# data=[edge_trace, node_trace], | |
# layout=go.Layout( | |
# title={ | |
# 'text': 'Document Similarity Graph<br><sub>Click nodes to view full text</sub>', | |
# 'y': 0.95, | |
# 'x': 0.5, | |
# 'xanchor': 'center', | |
# 'yanchor': 'top' | |
# }, | |
# showlegend=False, | |
# hovermode='closest', | |
# margin=dict(b=20, l=5, r=5, t=60), | |
# xaxis=dict(showgrid=False, zeroline=False, showticklabels=False), | |
# yaxis=dict(showgrid=False, zeroline=False, showticklabels=False), | |
# plot_bgcolor='white', | |
# width=800, | |
# height=600, | |
# clickmode='event+select' # Enable click events | |
# ) | |
# ) | |
# return fig | |
# def create_interface(search_system: EnhancedLegalSearchSystem): | |
# """Create Gradio interface with interactive graph""" | |
# with gr.Blocks(css="footer {display: none !important;}") as demo: | |
# gr.Markdown(""" | |
# # Enhanced Legal Search System | |
# Enter your legal query below to search through documents and get an AI-powered analysis. | |
# This system combines graph-based and semantic search capabilities for comprehensive legal research. | |
# """) | |
# with gr.Row(): | |
# query_input = gr.Textbox( | |
# label="Legal Query", | |
# placeholder="e.g., What are the reporting obligations for banks under the Money Laundering Act?", | |
# lines=3 | |
# ) | |
# with gr.Row(): | |
# search_button = gr.Button("Search & Analyze") | |
# status_output = gr.Textbox( | |
# label="Status", | |
# interactive=False | |
# ) | |
# with gr.Tabs(): | |
# with gr.TabItem("AI Legal Analysis"): | |
# analysis_output = gr.Markdown( | |
# label="AI-Generated Legal Analysis", | |
# value="Analysis will appear here..." | |
# ) | |
# with gr.TabItem("Retrieved Documents"): | |
# docs_output = gr.Markdown( | |
# label="Source Documents", | |
# value="Search results will appear here..." | |
# ) | |
# with gr.TabItem("Related Concepts"): | |
# concepts_output = gr.Json( | |
# label="Related Legal Concepts", | |
# value={} | |
# ) | |
# with gr.TabItem("Knowledge Graph"): | |
# # Graph visualization | |
# graph_output = gr.Plot( | |
# label="Legal Knowledge Graph" | |
# ) | |
# # Add text area for showing clicked document content | |
# selected_doc_content = gr.Textbox( | |
# label="Selected Document Content", | |
# interactive=False, | |
# lines=10 | |
# ) | |
# async def process_query(query): | |
# if not query.strip(): | |
# return ( | |
# "Please enter a query", | |
# "No analysis available", | |
# "No documents available", | |
# {}, | |
# None, | |
# "" | |
# ) | |
# results = await search_system.process_legal_query(query) | |
# graph_data = search_system.generate_document_graph(query) | |
# graph_fig = create_graph_visualization(graph_data) | |
# return ( | |
# results['status'], | |
# results['answer'], | |
# results['documents'], | |
# {"related_concepts": results['related_concepts']}, | |
# graph_fig, | |
# "Click on a node to view document content" | |
# ) | |
# search_button.click( | |
# fn=process_query, | |
# inputs=[query_input], | |
# outputs=[ | |
# status_output, | |
# analysis_output, | |
# docs_output, | |
# concepts_output, | |
# graph_output, | |
# selected_doc_content | |
# ] | |
# ) | |
# return demo | |
# class LegalSearchSystem: | |
# def __init__( | |
# self, | |
# pinecone_api_key: str = "<REDACTED - never commit API keys; load from environment>", | |
# google_api_key: str = "<REDACTED - never commit API keys; load from environment>", | |
# environment: str = "us-east-1", | |
# index_name: str = "pdf-embeddings", | |
# dimension: int = 384, | |
# embedding_model_name: str = "intfloat/e5-small-v2", | |
# device: str = "cpu" | |
# ): | |
# # Initialize Pinecone | |
# self.pc = Pinecone(api_key=pinecone_api_key) | |
# # Initialize LangChain with Gemini | |
# self.llm = ChatGoogleGenerativeAI( | |
# model="gemini-pro", | |
# temperature=0, | |
# google_api_key=google_api_key | |
# ) | |
# # Initialize prompts | |
# self.map_prompt = PromptTemplate( | |
# template=""" | |
# Analyze the following legal text segment and extract key information: | |
# TEXT: "{text}" | |
# Instructions: | |
# 1. Maintain all legal terminology exactly as written | |
# 2. Preserve section numbers and references | |
# 3. Keep all specific conditions and requirements | |
# 4. Include any mentioned time periods or deadlines | |
# DETAILED ANALYSIS: | |
# """, | |
# input_variables=["text"] | |
# ) | |
# self.combine_prompt = PromptTemplate( | |
# template=""" | |
# Based on the following excerpts from legal documents and the question: "{question}" | |
# EXCERPTS: | |
# {text} | |
# Instructions: | |
# 1. Synthesize a comprehensive answer that connects relevant sections | |
# 2. Maintain precise legal language from the source material | |
# 3. Reference specific sections and subsections where applicable | |
# 4. If there are seemingly disconnected pieces of information, explain their relationship | |
# 5. Highlight any conditions or exceptions that span multiple excerpts | |
# COMPREHENSIVE LEGAL ANALYSIS: | |
# """, | |
# input_variables=["text", "question"] | |
# ) | |
# # Initialize chain | |
# self.chain = load_summarize_chain( | |
# llm=self.llm, | |
# chain_type="stuff", | |
# prompt=self.combine_prompt, | |
# verbose=True | |
# ) | |
# # Initialize Pinecone index and embedding model | |
# self.index = self.pc.Index(index_name) | |
# self.embedding_model = SentenceTransformer( | |
# model_name_or_path=embedding_model_name, | |
# device=device | |
# ) | |
# def search(self, query_text: str, top_k: int = 5, context_window: int = 1) -> Dict: | |
# """ | |
# Perform a search and analysis of the legal query. | |
# """ | |
# try: | |
# # Get search results with context | |
# results = self.query_and_summarize( | |
# query_text=query_text, | |
# top_k=top_k, | |
# context_window=context_window | |
# ) | |
# # Format the results for display | |
# docs_markdown = self._format_documents(results['raw_results']) | |
# return { | |
# 'status': "Search completed successfully", | |
# 'documents': docs_markdown, | |
# 'analysis': results['summary'], | |
# 'source_pages': results['source_pages'], | |
# 'context_info': results['context_info'] | |
# } | |
# except Exception as e: | |
# return { | |
# 'status': f"Error during search: {str(e)}", | |
# 'documents': "Error retrieving documents", | |
# 'analysis': "Error generating analysis", | |
# 'source_pages': [], | |
# 'context_info': {} | |
# } | |
# def query_and_summarize( | |
# self, | |
# query_text: str, | |
# top_k: int = 5, | |
# filter: Optional[Dict] = None, | |
# context_window: int = 1 | |
# ) -> Dict: | |
# """ | |
# Query Pinecone and generate a summary with enhanced context handling. | |
# """ | |
# # Generate embedding for query | |
# query_embedding = self.embedding_model.encode(query_text).tolist() | |
# # Query Pinecone | |
# initial_results = self.index.query( | |
# vector=query_embedding, | |
# top_k=top_k, | |
# include_metadata=True, | |
# filter=filter | |
# )['matches'] | |
# # Expand context | |
# expanded_results = [] | |
# for match in initial_results: | |
# page_num = match['metadata']['page_number'] | |
# context_filter = { | |
# "page_number": { | |
# "$gte": max(1, page_num - context_window), | |
# "$lte": page_num + context_window | |
# } | |
# } | |
# if filter: | |
# context_filter.update(filter) | |
# context_results = self.index.query( | |
# vector=self.embedding_model.encode(match['metadata']['text']).tolist(), | |
# top_k=2 * context_window + 1, | |
# include_metadata=True, | |
# filter=context_filter | |
# )['matches'] | |
# expanded_results.extend(context_results) | |
# # Process results and generate summary | |
# documents = self._process_results(expanded_results, initial_results) | |
# summary = self.chain.run( | |
# input_documents=documents, | |
# question=query_text | |
# ) | |
# return { | |
# 'raw_results': expanded_results, | |
# 'summary': summary, | |
# 'source_pages': list(set(doc.metadata['page_number'] for doc in documents)), | |
# 'context_info': { | |
# 'direct_matches': len([d for d in documents if d.metadata['context_type'] == "DIRECT MATCH"]), | |
# 'context_chunks': len([d for d in documents if d.metadata['context_type'] == "CONTEXT"]) | |
# } | |
# } | |
# def _process_results(self, expanded_results: List[Dict], initial_results: List[Dict]) -> List[Document]: | |
# """ | |
# Process and deduplicate search results. | |
# """ | |
# seen_ids = set() | |
# documents = [] | |
# for result in expanded_results: | |
# if result['id'] not in seen_ids: | |
# seen_ids.add(result['id']) | |
# is_direct_match = any(r['id'] == result['id'] for r in initial_results) | |
# documents.append(Document( | |
# page_content=result['metadata']['text'], | |
# metadata={ | |
# 'score': result['score'], | |
# 'page_number': result['metadata']['page_number'], | |
# 'context_type': "DIRECT MATCH" if is_direct_match else "CONTEXT" | |
# } | |
# )) | |
# return sorted(documents, key=lambda x: x.metadata['page_number']) | |
# def _format_documents(self, results: List[Dict]) -> str: | |
# """ | |
# Format search results as markdown. | |
# """ | |
# markdown = "### Retrieved Documents\n\n" | |
# for i, result in enumerate(results, 1): | |
# markdown += f"**Document {i}** (Page {result['metadata']['page_number']})\n" | |
# markdown += f"```\n{result['metadata']['text']}\n```\n\n" | |
# return markdown | |
# async def process_query_async(query: str, search_system: LegalSearchSystem, graph_search_system: EnhancedLegalSearchSystem): | |
# """ | |
# Asynchronous function to process both traditional and graph-based searches | |
# """ | |
# if not query.strip(): | |
# return "Please enter a query", "", "", "", {} | |
# # Regular search (synchronous) | |
# results = search_system.search(query) | |
# try: | |
# # Graph search (asynchronous) | |
# graph_results = await graph_search_system.process_legal_query(query) | |
# graph_documents = graph_results.get('documents', "Error processing graph search") | |
# graph_concepts = graph_results.get('related_concepts', {}) | |
# except Exception as e: | |
# graph_documents = f"Error processing graph search: {str(e)}" | |
# graph_concepts = {} | |
# graph_data = graph_search_system.generate_document_graph(query) | |
# graph_fig = create_graph_visualization(graph_data) | |
# return ( | |
# results['status'], | |
# results['documents'], | |
# results['analysis'], | |
# graph_documents, | |
# graph_concepts, | |
# graph_fig, | |
# "Click on a node to view document content" | |
# ) | |
# def create_interface(graph_search_system: EnhancedLegalSearchSystem): | |
# search_system = LegalSearchSystem() | |
# with gr.Blocks(css="footer {display: none !important;}") as demo: | |
# gr.Markdown(""" | |
# # Corporate Law Legal Search Engine | |
# Enter your legal query below to search through documents and get an AI-powered analysis. Queries related to corporate law will return the most relevant information. | |
# """) | |
# with gr.Row(): | |
# query_input = gr.Textbox( | |
# label="Legal Query", | |
# placeholder="e.g., What are the key principles of contract law?", | |
# lines=3 | |
# ) | |
# with gr.Row(): | |
# search_button = gr.Button("Search & Analyze") | |
# status_output = gr.Textbox( | |
# label="Status", | |
# interactive=False | |
# ) | |
# with gr.Tabs(): | |
# with gr.TabItem("Search Results"): | |
# docs_output = gr.Markdown( | |
# label="Retrieved Documents", | |
# value="Search results will appear here..." | |
# ) | |
# with gr.TabItem("AI Legal Analysis"): | |
# summary_output = gr.Markdown( | |
# label="AI-Generated Legal Analysis", | |
# value="Analysis will appear here..." | |
# ) | |
# with gr.TabItem("Retrieved Documents through Graph Rag"): | |
# docs_output_graph = gr.Markdown( | |
# label="Source Documents", | |
# value="Search results will appear here..." | |
# ) | |
# graph_analysis_output = gr.JSON( | |
# label="Related Concepts", | |
# value={} | |
# ) | |
# with gr.TabItem("Knowledge Graph"): | |
# # Graph visualization | |
# graph_output = gr.Plot( | |
# label="Legal Knowledge Graph" | |
# ) | |
# # Add text area for showing clicked document content | |
# selected_doc_content = gr.Textbox( | |
# label="Selected Document Content", | |
# interactive=False, | |
# lines=10 | |
# ) | |
# def process_query(query): | |
# # Create event loop if it doesn't exist | |
# try: | |
# loop = asyncio.get_event_loop() | |
# except RuntimeError: | |
# loop = asyncio.new_event_loop() | |
# asyncio.set_event_loop(loop) | |
# # Run the async function and get results | |
# return loop.run_until_complete( | |
# process_query_async(query, search_system, graph_search_system) | |
# ) | |
# search_button.click( | |
# fn=process_query, | |
# inputs=[query_input], | |
# outputs=[ | |
# status_output, | |
# docs_output, | |
# summary_output, | |
# docs_output_graph, | |
# graph_analysis_output, | |
# graph_output, | |
# selected_doc_content | |
# ] | |
# ) | |
# return demo | |
# if __name__ == "__main__": | |
# graph_search_system = EnhancedLegalSearchSystem( | |
# google_api_key="<REDACTED - never commit API keys; load from environment>", | |
# neo4j_url="neo4j+s://a63462d0.databases.neo4j.io", | |
# neo4j_username="neo4j", | |
# neo4j_password="<REDACTED - never commit credentials; load from environment>" | |
# ) | |
# demo = create_interface(graph_search_system) | |
# demo.launch() | |
import os | |
import gradio as gr | |
from pinecone import Pinecone | |
from sentence_transformers import SentenceTransformer | |
from typing import List, Dict, Optional | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain.chains.summarize import load_summarize_chain | |
from langchain.prompts import PromptTemplate, ChatPromptTemplate | |
from langchain.docstore.document import Document | |
import time | |
import asyncio | |
import plotly.graph_objects as go | |
from neo4j import GraphDatabase | |
import networkx as nx | |
from langchain_community.vectorstores import Neo4jVector | |
from langchain.chains.summarize import load_summarize_chain | |
from langchain.chains import LLMChain | |
from langchain_google_genai import GoogleGenerativeAI, GoogleGenerativeAIEmbeddings | |
class EnhancedLegalSearchSystem:
    """Graph-RAG legal search engine backed by a Neo4j knowledge graph.

    Combines three retrieval signals:
      * Neo4j vector similarity search over ``Document`` nodes (Gemini
        ``embedding-001`` vectors stored in the ``embedding`` property),
      * raw Cypher traversal to pull legal entities related to matching
        documents (see :meth:`get_related_legal_entities`),
      * a local SentenceTransformer model used only to build a
        document-to-document similarity graph for visualization
        (see :meth:`generate_document_graph`).

    Answers are synthesized by a Gemini LLM from the combined context.
    """

    def __init__(
        self,
        google_api_key: str,
        neo4j_url: str,
        neo4j_username: str,
        neo4j_password: str,
        embedding_model_name: str = "intfloat/e5-small-v2",
        device: str = "cpu"
    ):
        """Initialize the Enhanced Legal Search System.

        Args:
            google_api_key: API key used for both the Gemini LLM and the
                Gemini embedding model.
            neo4j_url: ``neo4j+s://`` / bolt URL of the Neo4j instance.
            neo4j_username: Neo4j login user.
            neo4j_password: Neo4j login password.
            embedding_model_name: SentenceTransformer checkpoint for the
                local similarity graph (not used for retrieval).
            device: Torch device string for the local embedding model.
        """
        # LLM for answer synthesis; low temperature keeps output factual.
        self.llm = GoogleGenerativeAI(
            model="gemini-pro",
            google_api_key=google_api_key,
            temperature=0.1
        )
        # Gemini embeddings consumed by the Neo4j vector store.
        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/embedding-001",
            google_api_key=google_api_key,
            task_type="retrieval_query"
        )
        # Raw driver kept alongside the vector store so custom Cypher
        # (get_related_legal_entities) can run its own sessions.
        self.neo4j_driver = GraphDatabase.driver(
            neo4j_url,
            auth=(neo4j_username, neo4j_password)
        )
        # Vector index over pre-existing Document nodes: text in the
        # 'text' property, vectors in the 'embedding' property.
        self.vector_store = Neo4jVector.from_existing_graph(
            embedding=self.embeddings,
            url=neo4j_url,
            username=neo4j_username,
            password=neo4j_password,
            node_label="Document",
            text_node_properties=["text"],
            embedding_node_property="embedding"
        )
        # Local embedding model for cosine-similarity graph generation.
        self.local_embedding_model = SentenceTransformer(
            model_name_or_path=embedding_model_name,
            device=device
        )
        # Build prompt templates and the map-reduce summarize chain.
        self.init_prompts()

    def __del__(self):
        """Cleanup Neo4j connection.

        NOTE(review): relying on __del__ for cleanup is fragile (it may
        run late or not at all at interpreter shutdown); an explicit
        close() or context-manager protocol would be safer — TODO confirm
        callers before changing.
        """
        if hasattr(self, 'neo4j_driver'):
            self.neo4j_driver.close()

    def init_prompts(self):
        """Initialize enhanced prompts for legal analysis.

        Creates:
          * ``qa_prompt``      — chat prompt used by process_legal_query,
          * ``map_prompt``     — per-chunk extraction prompt,
          * ``combine_prompt`` — reduce-step synthesis prompt,
          * ``chain``          — map_reduce summarize chain over the two.
        """
        self.qa_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a legal expert assistant specializing in Indian law.
            Analyze the following legal context and provide a detailed, structured answer to the question.
            Include specific sections, rules, and precedents where applicable.
            Format your response with clear headings and bullet points for better readability.
            Context: {context}"""),
            ("human", "Question: {question}")
        ])
        self.map_prompt = PromptTemplate(
            template="""
            Analyze the following legal text segment:
            TEXT: "{text}"
            Instructions:
            1. Extract and summarize the key legal points
            2. Maintain all legal terminology exactly as written
            3. Preserve section numbers and references
            4. Keep all specific conditions and requirements
            5. Include any mentioned time periods or deadlines
            KEY POINTS:
            """,
            input_variables=["text"]  # Removed page_number as it's not used in the template
        )
        self.combine_prompt = PromptTemplate(
            template="""
            Question: {question}
            Using ONLY the information from the following legal document excerpts, provide a comprehensive answer:
            {text}
            Instructions:
            1. Base your response EXCLUSIVELY on the provided document excerpts
            2. If the documents don't contain enough information to fully answer the question, explicitly state what's missing
            3. Use direct quotes when appropriate
            4. Organize the response by relevant sections found in the documents
            5. If there are conflicting statements across documents, highlight them
            ANALYSIS:
            """,
            input_variables=["text", "question"]
        )
        # Map-reduce summarize chain; NOTE(review): self.chain is built here
        # but process_legal_query uses qa_prompt via LLMChain instead —
        # confirm whether this chain is still needed.
        self.chain = load_summarize_chain(
            llm=self.llm,
            chain_type="map_reduce",
            map_prompt=self.map_prompt,
            combine_prompt=self.combine_prompt,
            verbose=True
        )

    def get_related_legal_entities(self, query: str) -> List[Dict]:
        """Retrieve related legal entities and their relationships.

        Runs a substring match of *query* against Document.text, then walks
        one hop to non-Document neighbours and a second hop to
        Entity/Concept/Section/Case nodes. Returns at most 25 result dicts
        (source document, matched entity, related entities).

        NOTE(review): the first WITH block collects relationTypes /
        connectedLabels but they are never used downstream — dead
        aggregation in the Cypher; confirm before removing.
        """
        # Corrected Cypher query to handle aggregation properly
        cypher_query = """
        // First, let's check if nodes exist and get their labels
        MATCH (d:Document)
        WHERE toLower(d.text) CONTAINS toLower($query)
        WITH d
        // Match all relationships from the document, collecting their types
        OPTIONAL MATCH (d)-[r]-(connected)
        WHERE NOT connected:Document // Avoid direct document-to-document relations
        WITH d,
             collect(DISTINCT type(r)) as relationTypes,
             collect(DISTINCT labels(connected)) as connectedLabels
        // Now use these to build our main query
        MATCH (d:Document)-[r1]-(e)
        WHERE toLower(d.text) CONTAINS toLower($query)
        AND NOT e:Document // Exclude direct document connections
        WITH d, r1, e
        // Get secondary connections, but be more specific about what we're looking for
        OPTIONAL MATCH (e)-[r2]-(related)
        WHERE (related:Entity OR related:Concept OR related:Section OR related:Case)
        AND related <> d // Prevent cycles back to original document
        WITH d, {
            source_id: id(d),
            source_text: d.text,
            document_type: COALESCE(d.type, "Unknown"),
            relationship_type: type(r1),
            entity: {
                id: id(e),
                type: CASE WHEN e:Entity THEN "Entity"
                           WHEN e:Concept THEN "Concept"
                           WHEN e:Section THEN "Section"
                           WHEN e:Case THEN "Case"
                           ELSE "Other" END,
                text: COALESCE(e.text, e.name, e.title, "Unnamed"),
                properties: properties(e)
            },
            related_entities: collect(DISTINCT {
                id: id(related),
                type: CASE WHEN related:Entity THEN "Entity"
                           WHEN related:Concept THEN "Concept"
                           WHEN related:Section THEN "Section"
                           WHEN related:Case THEN "Case"
                           ELSE "Other" END,
                relationship: type(r2),
                text: COALESCE(related.text, related.name, related.title, "Unnamed"),
                properties: properties(related)
            })
        } as result
        WHERE result.entity.text IS NOT NULL // Filter out any results with null entity text
        RETURN DISTINCT result
        ORDER BY result.source_id, result.entity.id
        LIMIT 25
        """
        try:
            with self.neo4j_driver.session() as session:
                # Execute the improved query
                result = session.run(cypher_query, {"query": query})
                entities = [record["result"] for record in result]
                # Log the results for debugging
                print(f"Found {len(entities)} related entities")
                if entities:
                    for entity in entities:
                        print(f"Entity: {entity['entity']['text']}")
                        print(f"Source: {entity['source_text'][:100]}...")
                        print(f"Related: {len(entity['related_entities'])} connections")
                return entities
        except Exception as e:
            # Broad catch is deliberate: graph enrichment is best-effort and
            # must not break the main query path.
            print(f"Error in get_related_legal_entities: {str(e)}")
            return []

    async def process_legal_query(
        self,
        question: str,
        top_k: int = 5,
        context_window: int = 1
    ) -> Dict[str, any]:
        # NOTE(review): 'any' above is the builtin function, not typing.Any —
        # it evaluates without error but is misleading; typing.Any is not
        # imported at file top, so left as-is.
        """Process a legal query using both graph and vector search capabilities.

        Pipeline: vector search -> graph entity lookup -> context expansion
        -> LLM answer. On any failure returns a structured error payload
        with the same keys, so the UI never sees an exception.
        """
        try:
            # 1. Perform semantic search
            semantic_results = self.vector_store.similarity_search(
                question,
                k=top_k,
                search_type="hybrid"
            )
            # 2. Get related legal entities with the full question context
            related_entities = self.get_related_legal_entities(question)
            # Log the counts for debugging
            print(f"Found {len(semantic_results)} semantic results")
            print(f"Found {len(related_entities)} related entities")
            # 3. Expand context with related documents
            expanded_results = self.expand_context(
                semantic_results,
                context_window
            )
            # 4. Generate comprehensive answer
            documents = self._process_results(expanded_results, semantic_results)
            # 5. Prepare context for LLM
            context = self._prepare_context(documents, related_entities)
            # 6. Generate answer using LLM
            chain = LLMChain(llm=self.llm, prompt=self.qa_prompt)
            response = await chain.ainvoke({
                "context": context,
                "question": question
            })
            answer = response.get('text', '')
            # 7. Return structured response with explicit related concepts
            return {
                "status": "Success",
                "answer": answer,
                "documents": self._format_documents(documents),
                "related_concepts": related_entities,  # This should now contain data
                "source_ids": sorted(list(set(doc.metadata.get('document_id', 'unknown') for doc in documents))),
                "context_info": {
                    "direct_matches": len([d for d in documents if d.metadata.get('context_type') == "DIRECT MATCH"]),
                    "context_chunks": len([d for d in documents if d.metadata.get('context_type') == "CONTEXT"])
                }
            }
        except Exception as e:
            print(f"Error in process_legal_query: {str(e)}")  # Add error logging
            # Shape mirrors the success payload so the caller can unpack
            # unconditionally.
            return {
                "status": f"Error: {str(e)}",
                "answer": "An error occurred while processing your query.",
                "documents": "",
                "related_concepts": [],
                "source_ids": [],
                "context_info": {}
            }

    def expand_context(
        self,
        initial_results: List[Document],
        context_window: int
    ) -> List[Document]:
        """Expand context around search results.

        For each unseen initial hit, runs a follow-up similarity search
        seeded by that hit's text and appends previously-unseen results.
        Dedup key is document_id metadata, falling back to the first 50
        characters of the content.

        NOTE(review): initial documents themselves are only included if
        the follow-up search returns them again — confirm that is intended.
        """
        expanded_results = []
        seen_ids = set()
        for doc in initial_results:
            doc_id = doc.metadata.get('document_id', doc.page_content[:50])
            if doc_id not in seen_ids:
                # Query for related documents
                context_results = self.vector_store.similarity_search(
                    doc.page_content,
                    k=2 * context_window + 1,
                    search_type="hybrid"
                )
                for result in context_results:
                    result_id = result.metadata.get('document_id', result.page_content[:50])
                    if result_id not in seen_ids:
                        expanded_results.append(result)
                        seen_ids.add(result_id)
        return expanded_results

    def _process_results(self, expanded_results: List[Document], initial_results: List[Document]) -> List[Document]:
        """Process and deduplicate search results.

        Tags each unique document's metadata with context_type
        "DIRECT MATCH" (also present in initial_results) or "CONTEXT",
        then returns the documents sorted by document_id.
        """
        seen_ids = set()
        documents = []
        for doc in expanded_results:
            doc_id = doc.metadata.get('document_id', doc.page_content[:50])
            if doc_id not in seen_ids:
                seen_ids.add(doc_id)
                is_direct_match = any(
                    r.metadata.get('document_id', r.page_content[:50]) == doc_id
                    for r in initial_results
                )
                doc.metadata['context_type'] = (
                    "DIRECT MATCH" if is_direct_match else "CONTEXT"
                )
                documents.append(doc)
        return sorted(
            documents,
            key=lambda x: x.metadata.get('document_id', 'unknown')
        )

    def _prepare_context(
        self,
        documents: List[Document],
        related_entities: List[Dict]
    ) -> str:
        """Prepare context for LLM processing.

        Concatenates document texts, then appends related-concept lines.
        NOTE(review): entity.get('entity', '') is a dict (see the Cypher
        result shape), so its repr is interpolated verbatim — probably
        entity['entity']['text'] was intended; confirm before changing.
        """
        context = "\n\nLegal Documents:\n" + "\n".join([
            f"[Document ID: {doc.metadata.get('document_id', 'unknown')}] {doc.page_content}"
            for doc in documents
        ])
        if related_entities:
            context += "\n\nRelated Legal Concepts and Relationships:\n"
            for entity in related_entities:
                context += f"\n• {entity.get('entity', '')}"
                if entity.get('related_entities'):
                    for related in entity['related_entities']:
                        if related.get('entity'):
                            context += f"\n  - {related['type']}: {related['entity']}"
        return context

    def _format_documents(self, documents: List[Document]) -> str:
        """Format documents as markdown (one fenced block per document)."""
        markdown = "### Retrieved Documents\n\n"
        for i, doc in enumerate(documents, 1):
            markdown += (
                f"**Document {i}** "
                f"(ID: {doc.metadata.get('document_id', 'unknown')}, "
                f"{doc.metadata.get('context_type', 'UNKNOWN')})\n"
                f"```\n{doc.page_content}\n```\n\n"
            )
        return markdown

    def generate_document_graph(
        self,
        query: str,
        top_k: int = 5,
        similarity_threshold: float = 0.5
    ) -> List[Dict]:
        """Generate graph data based on document similarity and relationships.

        Embeds the top-k hits with the local SentenceTransformer and links
        pairs whose cosine similarity exceeds *similarity_threshold*.
        Returns a list of node dicts shaped like the Cypher results so
        create_graph_visualization can consume either.

        NOTE(review): graph_data[i] at the end assumes one node per
        semantic result; if two hits share a document_id the seen_docs
        dedup makes graph_data shorter and the index can mis-assign or
        raise IndexError — confirm document_id uniqueness upstream.
        """
        try:
            # 1. Get initial semantic search results
            semantic_results = self.vector_store.similarity_search(
                query,
                k=top_k,
                search_type="hybrid"
            )
            # 2. Get embeddings for all documents
            doc_texts = [doc.page_content for doc in semantic_results]
            doc_embeddings = self.local_embedding_model.encode(doc_texts)
            # 3. Create graph data structure
            graph_data = []
            seen_docs = set()
            # First, add all documents as nodes
            for i, doc in enumerate(semantic_results):
                doc_id = doc.metadata.get('document_id', f'doc_{i}')
                if doc_id not in seen_docs:
                    seen_docs.add(doc_id)
                    doc_type = doc.metadata.get('type', 'document')
                    # Create node entry
                    graph_data.append({
                        'source_id': doc_id,
                        'source_text': doc.page_content[:200],  # Truncate for display
                        'document_type': doc_type,
                        'entity': {
                            'id': doc_id,
                            'type': 'Document',
                            'text': f"Document {i + 1}",
                            'properties': {
                                'similarity': 1.0,
                                'length': len(doc.page_content)
                            }
                        },
                        'related_entities': []
                    })
            # Add relationships based on similarity
            # (import kept local so sklearn is only required for this feature)
            from sklearn.metrics.pairwise import cosine_similarity
            similarity_matrix = cosine_similarity(doc_embeddings)
            # Create relationships between similar documents
            for i in range(len(semantic_results)):
                related = []
                for j in range(len(semantic_results)):
                    if i != j and similarity_matrix[i][j] > similarity_threshold:
                        doc_j = semantic_results[j]
                        doc_j_id = doc_j.metadata.get('document_id', f'doc_{j}')
                        related.append({
                            'id': doc_j_id,
                            'type': 'Document',
                            'relationship': 'similar_to',
                            'text': f"Document {j + 1}",
                            'properties': {
                                'similarity_score': float(similarity_matrix[i][j])
                            }
                        })
                # Add related documents to the graph data
                if related:
                    graph_data[i]['related_entities'] = related
            return graph_data
        except Exception as e:
            # Best-effort: visualization failure must not break the search UI.
            print(f"Error generating document graph: {str(e)}")
            return []
def create_graph_visualization(graph_data: List[Dict]) -> go.Figure:
    """Create an interactive graph visualization using Plotly.

    Args:
        graph_data: Node dicts as produced by
            EnhancedLegalSearchSystem.generate_document_graph (each with
            'source_id', 'source_text' and a 'related_entities' list).

    Returns:
        A Plotly figure with one edge trace and one node trace; node hover
        text carries the document snippet, click events are enabled.

    NOTE(review): nx.spring_layout is called without a seed, so the layout
    differs between runs — pass seed=... if reproducibility matters.
    NOTE(review): edge_text has one entry per edge while edge_x/edge_y have
    four points per edge, so Plotly's per-point hover labels misalign with
    the edges — confirm the intended hover behavior.
    """
    if not graph_data:
        return go.Figure(layout=go.Layout(title='No documents found'))
    # Initialize graph
    G = nx.Graph()
    # Color mapping
    color_map = {
        'Document': '#3B82F6',   # blue
        'Section': '#10B981',    # green
        'Reference': '#F59E0B'   # yellow
    }
    # Node information storage — these lists are appended in the same order
    # nodes are added to G, so they stay aligned with G.nodes() iteration.
    node_colors = []
    node_texts = []
    node_hovers = []  # Full text for hover
    nodes_added = set()
    # Process nodes and edges
    for data in graph_data:
        source_id = data['source_id']
        source_text = data['source_text']
        # Add main document node
        if source_id not in nodes_added:
            G.add_node(source_id)
            node_colors.append(color_map['Document'])
            # Short text for display
            node_texts.append(f"Doc {len(nodes_added)+1}")
            # Full text for hover/click
            node_hovers.append(f"Document {len(nodes_added)+1}:<br><br>{source_text}")
            nodes_added.add(source_id)
        # Process related documents
        for related in data.get('related_entities', []):
            related_id = related['id']
            similarity = related['properties'].get('similarity_score', 0.0)
            if related_id not in nodes_added:
                G.add_node(related_id)
                node_colors.append(color_map['Document'])
                node_texts.append(f"Doc {len(nodes_added)+1}")
                node_hovers.append(f"Document {len(nodes_added)+1}:<br><br>{related['text']}")
                nodes_added.add(related_id)
            # Add edge with similarity weight
            G.add_edge(
                source_id,
                related_id,
                weight=similarity,
                relationship=f"Similarity: {similarity:.2f}"
            )
    # Create layout (force-directed; k spreads nodes apart)
    pos = nx.spring_layout(G, k=2.0, iterations=50)
    # Create edge trace
    edge_x = []
    edge_y = []
    edge_text = []
    for edge in G.edges(data=True):
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        # Create curved line via a midpoint offset perpendicular to the edge
        mid_x = (x0 + x1) / 2
        mid_y = (y0 + y1) / 2
        # Add some curvature
        mid_x += (y1 - y0) * 0.1
        mid_y -= (x1 - x0) * 0.1
        # Add points for curved line; None breaks the line between edges
        edge_x.extend([x0, mid_x, x1, None])
        edge_y.extend([y0, mid_y, y1, None])
        edge_text.append(edge[2]['relationship'])
    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1.5, color='#9CA3AF'),
        hoverinfo='text',
        text=edge_text,
        mode='lines'
    )
    # Create node trace
    node_x = []
    node_y = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers+text',
        hoverinfo='text',
        text=node_texts,
        hovertext=node_hovers,  # Full text shown on hover
        textposition="top center",
        marker=dict(
            size=30,
            color=node_colors,
            line=dict(width=2, color='white'),
            symbol='circle'
        ),
        customdata=node_hovers  # Store full text for click events
    )
    # Create figure with updated layout
    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title={
                'text': 'Document Similarity Graph<br><sub>Click nodes to view full text</sub>',
                'y': 0.95,
                'x': 0.5,
                'xanchor': 'center',
                'yanchor': 'top'
            },
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=60),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='white',
            width=800,
            height=600,
            clickmode='event+select'  # Enable click events
        )
    )
    return fig
def create_interface(graph_search_system: EnhancedLegalSearchSystem):
    """Create Gradio interface with interactive graph.

    Builds a Blocks UI that fans one query out to two engines: the
    Pinecone-backed LegalSearchSystem (classic RAG) and the supplied
    Neo4j-backed graph_search_system (graph RAG + similarity graph).

    NOTE(review): LegalSearchSystem is defined later in this module —
    that works because this function only runs after the module loads —
    and its zero-arg construction relies on hard-coded credential
    defaults; see the SECURITY notes on that class.
    """
    search_system = LegalSearchSystem()
    with gr.Blocks(css="footer {display: none !important;}") as demo:
        gr.Markdown("""
        # Corporate Law Legal Search Engine
        Enter your legal query below to search through documents and get an AI-powered analysis. Queries only related to corporate law will give relevant information.
        """)
        with gr.Row():
            query_input = gr.Textbox(
                label="Legal Query",
                placeholder="e.g., What are the key principles of contract law?",
                lines=3
            )
        with gr.Row():
            search_button = gr.Button("Search & Analyze")
            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )
        with gr.Tabs():
            with gr.TabItem("Search Results"):
                docs_output = gr.Markdown(
                    label="Retrieved Documents",
                    value="Search results will appear here..."
                )
            with gr.TabItem("AI Legal Analysis"):
                summary_output = gr.Markdown(
                    label="AI-Generated Legal Analysis",
                    value="Analysis will appear here..."
                )
            with gr.TabItem("Retrieved Documents through Graph Rag"):
                docs_output_graph = gr.Markdown(
                    label="Source Documents",
                    value="Search results will appear here..."
                )
                graph_analysis_output = gr.JSON(
                    label="Related Concepts",
                    value={}
                )
            with gr.TabItem("Knowledge Graph"):
                # Graph visualization
                graph_output = gr.Plot(
                    label="Legal Knowledge Graph"
                )
                # Add text area for showing clicked document content
                selected_doc_content = gr.Textbox(
                    label="Selected Document Content",
                    interactive=False,
                    lines=10
                )

        # Gradio accepts async event handlers directly, so no manual event
        # loop management is needed here. Returns a 7-tuple matching the
        # 'outputs' list of search_button.click below.
        async def process_query(query):
            if not query.strip():
                # Empty query: placeholder values for every output slot.
                return (
                    "Please enter a query",
                    "No documents available from Pinecone",
                    "No analysis available from Pinecone",
                    "No documents available from Neo4j",
                    {},
                    None,
                    ""
                )
            # Run the regular RAG search
            rag_results = search_system.search(query)
            # Run the graph-based RAG search
            graph_results = await graph_search_system.process_legal_query(query)
            graph_data = graph_search_system.generate_document_graph(query)
            graph_fig = create_graph_visualization(graph_data)
            return (
                rag_results['status'],
                rag_results['documents'],
                rag_results['analysis'],
                graph_results['documents'],
                {"related_concepts": graph_results['related_concepts']},
                graph_fig,
                "Click on a node to view document content"
            )

        search_button.click(
            fn=process_query,
            inputs=[query_input],
            outputs=[
                status_output,
                docs_output,
                summary_output,
                docs_output_graph,
                graph_analysis_output,
                graph_output,
                selected_doc_content
            ]
        )
    return demo
class LegalSearchSystem:
    """Classic RAG legal search over a Pinecone vector index.

    Embeds queries locally with a SentenceTransformer, retrieves chunks
    from Pinecone, expands context by neighbouring page numbers, and
    summarizes with a Gemini 'stuff' chain.

    SECURITY(review): the constructor defaults below embed live API keys
    in source control. They should be rotated immediately and the
    defaults replaced with environment-variable lookups; the defaults are
    load-bearing (create_interface calls LegalSearchSystem() with no
    arguments), so fix both sites together.
    """

    def __init__(
        self,
        # SECURITY: hard-coded credential defaults — see class docstring.
        pinecone_api_key: str = "pcsk_43sajZ_MjcXR2yN5cAcVi8RARyB6i3NP3wLTnTLugbUcN9cUU4q5EfNmuwLPkmxAvykk9o",
        google_api_key: str = "AIzaSyDWHGMd8a70RbL3EBenfUwimcAHjhvgM6M",
        environment: str = "us-east-1",
        index_name: str = "pdf-embeddings",
        dimension: int = 384,
        embedding_model_name: str = "intfloat/e5-small-v2",
        device: str = "cpu"
    ):
        # NOTE(review): 'environment' and 'dimension' are accepted but never
        # used below — confirm whether they are vestigial.
        # Initialize Pinecone
        self.pc = Pinecone(api_key=pinecone_api_key)
        # Initialize LangChain with Gemini (temperature 0 = deterministic-ish)
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-pro",
            temperature=0,
            google_api_key=google_api_key
        )
        # Initialize prompts
        self.map_prompt = PromptTemplate(
            template="""
            Analyze the following legal text segment and extract key information:
            TEXT: "{text}"
            Instructions:
            1. Maintain all legal terminology exactly as written
            2. Preserve section numbers and references
            3. Keep all specific conditions and requirements
            4. Include any mentioned time periods or deadlines
            DETAILED ANALYSIS:
            """,
            input_variables=["text"]
        )
        self.combine_prompt = PromptTemplate(
            template="""
            Based on the following excerpts from legal documents and the question: "{question}"
            EXCERPTS:
            {text}
            Instructions:
            1. Synthesize a comprehensive answer that connects relevant sections
            2. Maintain precise legal language from the source material
            3. Reference specific sections and subsections where applicable
            4. If there are seemingly disconnected pieces of information, explain their relationship
            5. Highlight any conditions or exceptions that span multiple excerpts
            COMPREHENSIVE LEGAL ANALYSIS:
            """,
            input_variables=["text", "question"]
        )
        # 'stuff' chain: all documents are concatenated into one prompt.
        # NOTE(review): map_prompt is built but unused by a stuff chain.
        self.chain = load_summarize_chain(
            llm=self.llm,
            chain_type="stuff",
            prompt=self.combine_prompt,
            verbose=True
        )
        # Initialize Pinecone index and embedding model
        self.index = self.pc.Index(index_name)
        self.embedding_model = SentenceTransformer(
            model_name_or_path=embedding_model_name,
            device=device
        )

    def search(self, query_text: str, top_k: int = 5, context_window: int = 1) -> Dict:
        """
        Perform a search and analysis of the legal query.

        Wraps query_and_summarize and formats the result for the UI; on
        failure returns an error payload with the same keys so callers
        can unpack unconditionally.
        """
        try:
            # Get search results with context
            results = self.query_and_summarize(
                query_text=query_text,
                top_k=top_k,
                context_window=context_window
            )
            # Format the results for display
            docs_markdown = self._format_documents(results['raw_results'])
            return {
                'status': "Search completed successfully",
                'documents': docs_markdown,
                'analysis': results['summary'],
                'source_pages': results['source_pages'],
                'context_info': results['context_info']
            }
        except Exception as e:
            return {
                'status': f"Error during search: {str(e)}",
                'documents': "Error retrieving documents",
                'analysis': "Error generating analysis",
                'source_pages': [],
                'context_info': {}
            }

    def query_and_summarize(
        self,
        query_text: str,
        top_k: int = 5,
        filter: Optional[Dict] = None,
        context_window: int = 1
    ) -> Dict:
        """
        Query Pinecone and generate a summary with enhanced context handling.

        For each initial match, re-queries Pinecone restricted to pages
        within +/- context_window of the match's page, then summarizes
        the deduplicated set with the 'stuff' chain.

        NOTE(review): 'filter' shadows the builtin; assumes every match
        carries metadata 'page_number' and 'text' — confirm the index
        schema.
        """
        # Generate embedding for query
        query_embedding = self.embedding_model.encode(query_text).tolist()
        # Query Pinecone
        initial_results = self.index.query(
            vector=query_embedding,
            top_k=top_k,
            include_metadata=True,
            filter=filter
        )['matches']
        # Expand context: neighbouring pages around each hit
        expanded_results = []
        for match in initial_results:
            page_num = match['metadata']['page_number']
            context_filter = {
                "page_number": {
                    "$gte": max(1, page_num - context_window),
                    "$lte": page_num + context_window
                }
            }
            if filter:
                context_filter.update(filter)
            context_results = self.index.query(
                vector=self.embedding_model.encode(match['metadata']['text']).tolist(),
                top_k=2 * context_window + 1,
                include_metadata=True,
                filter=context_filter
            )['matches']
            expanded_results.extend(context_results)
        # Process results and generate summary
        documents = self._process_results(expanded_results, initial_results)
        summary = self.chain.run(
            input_documents=documents,
            question=query_text
        )
        return {
            'raw_results': expanded_results,
            'summary': summary,
            'source_pages': list(set(doc.metadata['page_number'] for doc in documents)),
            'context_info': {
                'direct_matches': len([d for d in documents if d.metadata['context_type'] == "DIRECT MATCH"]),
                'context_chunks': len([d for d in documents if d.metadata['context_type'] == "CONTEXT"])
            }
        }

    def _process_results(self, expanded_results: List[Dict], initial_results: List[Dict]) -> List[Document]:
        """
        Process and deduplicate search results.

        Deduplicates by Pinecone match id, tags each Document's metadata
        with DIRECT MATCH / CONTEXT, and sorts by page number.
        """
        seen_ids = set()
        documents = []
        for result in expanded_results:
            if result['id'] not in seen_ids:
                seen_ids.add(result['id'])
                is_direct_match = any(r['id'] == result['id'] for r in initial_results)
                documents.append(Document(
                    page_content=result['metadata']['text'],
                    metadata={
                        'score': result['score'],
                        'page_number': result['metadata']['page_number'],
                        'context_type': "DIRECT MATCH" if is_direct_match else "CONTEXT"
                    }
                ))
        return sorted(documents, key=lambda x: x.metadata['page_number'])

    def _format_documents(self, results: List[Dict]) -> str:
        """
        Format search results as markdown.

        NOTE(review): formats raw (pre-deduplication) Pinecone matches, so
        duplicates may appear — confirm whether that is intended.
        """
        markdown = "### Retrieved Documents\n\n"
        for i, result in enumerate(results, 1):
            markdown += f"**Document {i}** (Page {result['metadata']['page_number']})\n"
            markdown += f"```\n{result['metadata']['text']}\n```\n\n"
        return markdown
if __name__ == "__main__":
    # SECURITY FIX: the Google API key and Neo4j credentials were previously
    # hard-coded here (and committed to source control — they must be
    # rotated). Read them from the environment instead and fail fast with a
    # clear message when one is missing.
    try:
        google_api_key = os.environ["GOOGLE_API_KEY"]
        neo4j_url = os.environ["NEO4J_URL"]
        neo4j_username = os.environ["NEO4J_USERNAME"]
        neo4j_password = os.environ["NEO4J_PASSWORD"]
    except KeyError as missing:
        raise SystemExit(
            f"Missing required environment variable: {missing.args[0]}. "
            "Set GOOGLE_API_KEY, NEO4J_URL, NEO4J_USERNAME and NEO4J_PASSWORD."
        ) from missing

    # Build the graph-RAG engine and launch the Gradio UI.
    graph_search_system = EnhancedLegalSearchSystem(
        google_api_key=google_api_key,
        neo4j_url=neo4j_url,
        neo4j_username=neo4j_username,
        neo4j_password=neo4j_password
    )
    demo = create_interface(graph_search_system)
    demo.launch()