Spaces:
Sleeping
Sleeping
import gradio as gr | |
import chromadb | |
import os | |
from openai import OpenAI | |
import json | |
from typing import List, Dict | |
import re | |
from sentence_transformers import SentenceTransformer | |
from loguru import logger | |
class SentenceTransformerEmbeddings:
    """Callable wrapper that embeds texts with a sentence-transformers model.

    Used as the ``embedding_function`` of the ChromaDB collection; calling an
    instance with a list of strings returns one vector (a list of floats) per
    string.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        # Load the model once; every call reuses it.
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: List[str]) -> List[List[float]]:
        # NOTE(review): the parameter is named ``input`` (shadowing the builtin),
        # presumably to match ChromaDB's embedding-function protocol -- keep it.
        # encode() is assumed to return an array exposing .tolist() (e.g. numpy).
        return self.model.encode(input).tolist()
class LegalAssistant:
    """Answer questions about one legal document using retrieval plus Mistral AI.

    On construction the assistant:

    * reads the Mistral API key from the ``MISTRAL_API_KEY`` environment
      variable (there is deliberately no fallback key in source code);
    * creates a ChromaDB collection named ``legal_documents`` that embeds text
      with ``SentenceTransformerEmbeddings``, and indexes ``a2023-45.txt`` into
      it when the collection is empty;
    * creates an OpenAI-compatible client pointed at the Mistral API.

    ``get_response`` retrieves the closest sections for a query and asks the
    model to answer strictly from them. Every response is a dict with the keys
    ``answer``, ``references``, ``summary`` and ``confidence``.
    """

    def __init__(self):
        # Never fall back to a key embedded in source code: anything committed
        # there must be considered leaked. Checking first also avoids loading
        # the embedding model when the app cannot work anyway.
        api_key = os.environ.get("MISTRAL_API_KEY")
        if not api_key:
            raise RuntimeError(
                "MISTRAL_API_KEY environment variable is not set; "
                "configure it (e.g. as a deployment secret) before starting."
            )

        # ChromaDB client with default settings.
        self.chroma_client = chromadb.Client()
        # Used by the collection to embed both indexed sections and query texts.
        self.embedding_function = SentenceTransformerEmbeddings()
        self.collection = self.chroma_client.get_or_create_collection(
            name="legal_documents",
            embedding_function=self.embedding_function,
        )
        # Only index when the collection is empty, so an already-populated
        # collection is not indexed twice.
        if self.collection.count() == 0:
            self._load_documents()

        # Mistral exposes an OpenAI-compatible API, so the OpenAI client is
        # reused with Mistral's base URL.
        self.mistral_client = OpenAI(
            api_key=api_key,
            base_url="https://api.mistral.ai/v1",
        )

        # System message sent with every request.
        # NOTE(review): the ERROR HANDLING rules ask the model to "State" plain
        # sentences, which conflicts with the JSON-only format required above;
        # a reply that follows them will fail JSON parsing in get_response.
        self.system_prompt = """You are a specialized legal assistant that MUST follow these STRICT rules:
CRITICAL RULE:
YOU MUST ONLY USE INFORMATION FROM THE PROVIDED CONTEXT. DO NOT USE ANY EXTERNAL KNOWLEDGE, INCLUDING KNOWLEDGE ABOUT IPC, CONSTITUTION, OR ANY OTHER LEGAL DOCUMENTS.
RESPONSE FORMAT RULES:
1. ALWAYS structure your response in this exact JSON format:
{
"answer": "Your detailed answer here using ONLY information from the provided context",
"reference_sections": ["Exact section titles from the context"],
"summary": "2-3 line summary using ONLY information from context",
"confidence": "HIGH/MEDIUM/LOW based on context match"
}
STRICT CONTENT RULES:
1. NEVER mention or reference IPC, Constitution, or any laws not present in the context
2. If the information is not in the context, respond ONLY with:
{
"answer": "This information is not present in the provided document.",
"reference_sections": [],
"summary": "Information not found in document",
"confidence": "LOW"
}
3. ONLY cite sections that are explicitly present in the provided context
4. DO NOT make assumptions or inferences beyond the context
5. DO NOT combine information from external knowledge
CONTEXT USAGE RULES:
1. HIGH confidence: Only when exact information is found in context
2. MEDIUM confidence: When partial information is found
3. LOW confidence: When information is unclear or not found
4. If multiple sections are relevant, cite ALL relevant sections from context
PROHIBITED ACTIONS:
1. NO references to IPC sections
2. NO references to Constitutional articles
3. NO mentions of case law not in context
4. NO legal interpretations beyond context
5. NO combining document information with external knowledge
ERROR HANDLING:
1. If query is about laws not in context: State "This topic is not covered in the provided document"
2. If query is unclear: Request specific clarification about which part of the document to check
3. If context is insufficient: State "The document does not contain this information"
"""

    @staticmethod
    def _error_response(answer: str, summary: str) -> Dict:
        """Build a LOW-confidence response dict with no references."""
        return {
            "answer": answer,
            "references": [],
            "summary": summary,
            "confidence": "LOW",
        }

    @staticmethod
    def _split_sections(document: str, index_lines: List[str]) -> List[Dict[str, str]]:
        """Split *document* into ``{"title", "content"}`` sections.

        A document line is a heading when it contains any non-blank, stripped
        entry of *index_lines*; the stripped heading line becomes the title of
        the text that follows it, up to the next heading. Text before the first
        heading forms a section with an empty title. Sections whose stripped
        content is empty are dropped. With no usable index entries the whole
        document is one untitled section.
        """
        # Blank index entries must be ignored: "" is a substring of every line,
        # so a single blank line in index.txt would turn every line into a
        # heading and leave no section content at all.
        headings = [entry.strip() for entry in index_lines if entry.strip()]

        sections: List[Dict[str, str]] = []
        title = ""
        body_lines: List[str] = []

        def flush() -> None:
            # Emit the section collected so far, unless it has no text.
            content = "\n".join(body_lines).strip()
            if content:
                sections.append({"title": title, "content": content})

        for line in document.split('\n'):
            if any(heading in line for heading in headings):
                flush()
                title = line.strip()
                body_lines = []
            else:
                body_lines.append(line)
        flush()
        return sections

    def _load_documents(self):
        """Index ``a2023-45.txt`` into the collection, one entry per section.

        Headings are located using the lines of ``index.txt`` (see
        ``_split_sections``). Each section is stored under the id
        ``section_<n>`` with metadata ``title``, ``source`` and a 1-based
        ``section_number``.

        Raises:
            Any error (missing or undecodable file, ChromaDB failure, or a
            ValueError when no sections are found) is logged and re-raised.
        """
        try:
            with open('a2023-45.txt', 'r', encoding='utf-8') as f:
                document = f.read()
            with open('index.txt', 'r', encoding='utf-8') as f:
                index_content = f.readlines()

            sections = self._split_sections(document, index_content)
            if not sections:
                raise ValueError("No sections could be parsed from a2023-45.txt")

            numbered = list(enumerate(sections, start=1))
            self.collection.add(
                documents=[section["content"] for _, section in numbered],
                metadatas=[
                    {"title": section["title"], "source": "a2023-45.txt", "section_number": n}
                    for n, section in numbered
                ],
                ids=[f"section_{n}" for n, _ in numbered],
            )
            logger.info(f"Loaded {len(sections)} sections into ChromaDB")
        except Exception as e:
            logger.error(f"Error loading documents: {str(e)}")
            raise

    def validate_query(self, query: str) -> tuple[bool, str]:
        """Check that *query* is usable.

        Surrounding whitespace is ignored. The stripped query must be 10-500
        characters long and end with '?' or '.'.

        Returns:
            ``(True, "")`` when valid, otherwise ``(False, <user-facing reason>)``.
        """
        text = (query or "").strip()
        if len(text) < 10:
            return False, "Query too short. Please provide more details (minimum 10 characters)."
        if len(text) > 500:
            return False, "Query too long. Please be more concise (maximum 500 characters)."
        if not text.endswith(('?', '.')):
            return False, "Query must end with a question mark or period."
        return True, ""

    def _search_documents(self, query: str) -> tuple[str, List[str]]:
        """Return ``(context, references)`` for the 3 sections closest to *query*.

        ``context`` contains one entry per hit: the section title, a colon and
        newline, then the section text. Entries are separated by blank lines.
        ``references`` holds ``"<title> (Section <n>)"`` for each hit. Both are
        empty when nothing is found or the search fails; failures are logged,
        not raised.
        """
        try:
            results = self.collection.query(query_texts=[query], n_results=3)
            if not results or not results['documents']:
                return "", []

            # Results are per query text; only one query was sent.
            documents = results['documents'][0]
            metadata = results['metadatas'][0]

            formatted_docs = []
            references = []
            for doc, meta in zip(documents, metadata):
                formatted_docs.append(f"{meta['title']}:\n{doc}")
                references.append(f"{meta['title']} (Section {meta['section_number']})")
            return "\n\n".join(formatted_docs), references
        except Exception as e:
            logger.error(f"Search error: {str(e)}")
            return "", []

    @staticmethod
    def _parse_model_json(text) -> Dict:
        """Parse the model's reply as a JSON object.

        Chat models may wrap JSON in a Markdown code fence (```json ... ```)
        despite instructions, so a fence surrounding the whole reply is
        removed first. ``None`` is treated as an empty reply.

        Raises:
            json.JSONDecodeError: if the reply is not valid JSON or is valid
                JSON but not an object.
        """
        cleaned = (text or "").strip()
        fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", cleaned, re.DOTALL | re.IGNORECASE)
        if fenced:
            cleaned = fenced.group(1)
        result = json.loads(cleaned)
        if not isinstance(result, dict):
            raise json.JSONDecodeError("Expected a JSON object", cleaned, 0)
        return result

    def get_response(self, query: str) -> Dict:
        """Answer *query* from the indexed document via Mistral AI.

        Returns:
            A dict with keys ``answer``, ``references`` (cited section titles
            that match the retrieved sections), ``summary`` and ``confidence``
            (as reported by the model; ``"LOW"`` on every error path). Errors
            are reported through this dict rather than raised.
        """
        # Whitespace around the query (e.g. a trailing space typed in the UI)
        # must affect neither validation nor the prompt.
        query = (query or "").strip()
        is_valid, error_message = self.validate_query(query)
        if not is_valid:
            return self._error_response(error_message, "Invalid query")

        try:
            context, sources = self._search_documents(query)
            if not context:
                return self._error_response(
                    "This information is not present in the provided document.",
                    "Information not found in document",
                )

            content = f"""IMPORTANT: ONLY use information from the following context to answer the question. DO NOT use any external knowledge.
Context Sections:
{context}
Available Document Sections:
{', '.join(sources)}
Question: {query}
Remember: ONLY use information from the above context. If the information is not in the context, state that it's not in the document."""

            response = self.mistral_client.chat.completions.create(
                model="mistral-medium",
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": content},
                ],
                temperature=0.1,
                max_tokens=1000,
            )

            if not response.choices:
                return self._error_response("No valid response received", "Response generation failed")

            try:
                result = self._parse_model_json(response.choices[0].message.content)
            except json.JSONDecodeError:
                logger.error("Failed to parse response JSON")
                return self._error_response("Error: Response format invalid", "Response parsing failed")

            claimed = result.get("reference_sections") or []
            if isinstance(claimed, str):
                claimed = [claimed]

            # Titles of the retrieved sections. The " (Section N)" suffix is cut
            # from the right so titles that themselves contain "(Section" stay
            # intact. Empty titles are skipped: "" is a substring of every
            # string and would let any claimed reference through.
            allowed_titles = [
                title
                for title in (source.rsplit(" (Section", 1)[0] for source in sources)
                if title
            ]
            valid_references = [
                ref for ref in claimed
                if isinstance(ref, str) and any(title in ref for title in allowed_titles)
            ]

            # Any citation outside the retrieved sections invalidates the reply.
            if len(valid_references) != len(claimed):
                logger.warning("Response contained unauthorized references")
                return self._error_response(
                    "Error: Response contained unauthorized references. Only information from the provided document is allowed.",
                    "Invalid response generated",
                )

            return {
                "answer": result.get("answer", "No answer provided"),
                "references": valid_references,
                "summary": result.get("summary", ""),
                "confidence": result.get("confidence", "LOW"),
            }
        except Exception as e:
            logger.error(f"Error in get_response: {str(e)}")
            return self._error_response(f"Error: {str(e)}", "System error occurred")
# Module-level assistant shared by every UI request (process_query calls it).
# Construction runs once at import time: it loads the embedding model, indexes
# the document when the collection is empty, and creates the Mistral client.
assistant = LegalAssistant()
# Gradio callback: adapts the assistant's dict response to the UI outputs.
def process_query(query: str) -> tuple:
    """Answer *query* with the shared assistant and format it for display.

    Returns a 4-tuple of ``(answer, references, summary, confidence)``, the
    order the submit button's outputs are wired in. An empty reference list
    and an empty summary are replaced by placeholder text.
    """
    result = assistant.get_response(query)

    refs = result["references"]
    refs_text = ", ".join(refs) if refs else "No specific references"
    summary_text = result["summary"] or "No summary available"

    return result["answer"], refs_text, summary_text, result["confidence"]
# Markdown shown above the inputs and below the outputs, respectively.
_GUIDELINES_MD = """
# Indian Legal Assistant
## Guidelines for Queries:
1. Be specific and clear in your questions
2. End questions with a question mark or period
3. Keep queries between 10-500 characters
4. Questions will be answered based ONLY on the provided legal document
"""

_NOTES_MD = """
### Important Notes:
- Responses are based ONLY on the provided document
- No external legal knowledge is used
- All references are from the document itself
- Confidence levels indicate how well the answer matches the document content
"""

# Build the UI with the Soft theme. Components are created top to bottom in
# display order: guidelines, query box, submit button, confidence, answer,
# references and summary side by side, then the notes.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(_GUIDELINES_MD)

    # Input area.
    with gr.Row():
        query_input = gr.Textbox(
            label="Enter your legal query",
            placeholder="e.g., What are the main provisions in this document?",
        )
    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary")

    # Output area; confidence sits above the full answer.
    with gr.Row():
        confidence_output = gr.Textbox(label="Confidence Level")
    with gr.Row():
        answer_output = gr.Textbox(label="Answer", lines=5)
    with gr.Row():
        with gr.Column():
            references_output = gr.Textbox(label="Document References", lines=3)
        with gr.Column():
            summary_output = gr.Textbox(label="Summary", lines=2)

    gr.Markdown(_NOTES_MD)

    # process_query returns (answer, references, summary, confidence); this list
    # must follow that tuple order, not the on-screen order.
    ordered_outputs = [answer_output, references_output, summary_output, confidence_output]
    submit_btn.click(
        fn=process_query,
        inputs=[query_input],
        outputs=ordered_outputs,
    )
# Start the Gradio server only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()