# Hugging Face Space (status when captured: Sleeping)
import json
import os

import chromadb
import gradio as gr
from loguru import logger
from openai import OpenAI
from sentence_transformers import SentenceTransformer

from test_embeddings import test_chromadb_content, initialize_chromadb
class SentenceTransformerEmbeddings:
    """Callable wrapper that turns texts into embedding vectors.

    An instance loads one ``SentenceTransformer`` model and, when called
    with a list of strings, returns the encoded vectors as plain nested
    Python lists.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Load the sentence-transformers model named *model_name*."""
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Encode *input* and return the result of ``.tolist()`` on it.

        The parameter is deliberately still called ``input`` (shadowing the
        builtin) because it is part of this callable's interface; callers
        may pass it by keyword.
        """
        return self.model.encode(input).tolist()
class LegalAssistant:
    """Answer legal questions using only passages retrieved from ChromaDB.

    The assistant looks up the most similar sections in the
    ``legal_documents`` collection, sends them together with the question
    to the Mistral chat-completions endpoint, and validates the JSON reply
    before handing it to the caller.

    Every public answer is a dict with the keys ``answer``, ``references``,
    ``summary`` and ``confidence``.
    """

    COLLECTION_NAME = "legal_documents"
    MODEL_NAME = "mistral-medium"
    N_RESULTS = 3
    MIN_QUERY_LENGTH = 10
    MAX_QUERY_LENGTH = 500
    CONFIDENCE_LEVELS = ("HIGH", "MEDIUM", "LOW")
    REQUIRED_FIELDS = ("answer", "reference_sections", "summary", "confidence")
    # Title used for a retrieved section whose metadata has no usable title.
    UNTITLED_SECTION = "Untitled section"

    SYSTEM_PROMPT = '''You are a specialized legal assistant that MUST follow these STRICT rules:
1. You MUST ONLY use information from the provided context.
2. DO NOT use any external knowledge about laws, IPC, Constitution, or legal matters.
3. Your response MUST be in this EXACT JSON format:
{
"answer": "Your detailed answer using ONLY information from the context",
"reference_sections": ["List of section titles used from context"],
"summary": "Brief 2-3 line summary",
"confidence": "HIGH/MEDIUM/LOW"
}
Confidence Level Rules:
- HIGH: When exact information is found in context
- MEDIUM: When partial or indirect information is found
- LOW: When information is unclear or not found
If information is not in context, respond with:
{
"answer": "This information is not present in the provided document.",
"reference_sections": [],
"summary": "Information not found in document",
"confidence": "LOW"
}'''

    # Filled in with str.format(); only ``{context}`` and ``{query}`` are
    # placeholders, so the template itself must not contain other braces.
    USER_PROMPT_TEMPLATE = '''Context Sections:
{context}
Question: {query}
IMPORTANT:
1. Use ONLY the information from the above context
2. Format your response as a valid JSON object with the exact structure shown above
3. Include ONLY section titles that exist in the context
4. DO NOT add any text outside the JSON structure
5. Ensure the JSON is properly formatted with double quotes'''

    def __init__(self):
        """Open the ChromaDB collection and create the Mistral client.

        Raises:
            ValueError: if ``MISTRAL_API_KEY`` is not set, or if the
                ChromaDB content can be neither verified nor initialized.
            Exception: any error raised by ChromaDB, the embedding model or
                the OpenAI client is logged and re-raised unchanged.
        """
        try:
            logger.info("Initializing LegalAssistant...")

            # The key must come from the environment. A credential embedded
            # in source code is exposed to everyone who can read the file,
            # so there is intentionally no fallback value.
            api_key = os.environ.get("MISTRAL_API_KEY")
            if not api_key:
                raise ValueError("MISTRAL_API_KEY environment variable is not set")

            # Try to verify content, if that fails, try to initialize
            if not test_chromadb_content():
                logger.warning("ChromaDB verification failed, attempting to initialize...")
                if not initialize_chromadb():
                    raise ValueError("Failed to initialize ChromaDB")

            # The database directory lives next to this file.
            base_path = os.path.dirname(os.path.abspath(__file__))
            chroma_path = os.path.join(base_path, 'chroma_db')
            self.chroma_client = chromadb.PersistentClient(path=chroma_path)
            self.embedding_function = SentenceTransformerEmbeddings()

            # Get existing collection
            self.collection = self.chroma_client.get_collection(
                name=self.COLLECTION_NAME,
                embedding_function=self.embedding_function,
            )
            logger.info(f"Collection loaded with {self.collection.count()} documents")

            # Mistral is reached through the OpenAI client by overriding
            # the base URL.
            self.mistral_client = OpenAI(
                api_key=api_key,
                base_url="https://api.mistral.ai/v1",
            )
            logger.info("LegalAssistant initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing LegalAssistant: {str(e)}")
            raise

    def validate_query(self, query: str) -> tuple[bool, str]:
        """Check that *query* has an acceptable length.

        Both limits are applied to the query with surrounding whitespace
        removed, so padding can neither satisfy the minimum nor trip the
        maximum. A value that is not a string is treated as empty.

        Returns:
            ``(True, "")`` when the query is acceptable, otherwise
            ``(False, message)`` with a user-facing explanation.
        """
        text = query.strip() if isinstance(query, str) else ""
        if len(text) < self.MIN_QUERY_LENGTH:
            return False, (
                "Query too short. Please provide more details "
                f"(minimum {self.MIN_QUERY_LENGTH} characters)."
            )
        if len(text) > self.MAX_QUERY_LENGTH:
            return False, (
                "Query too long. Please be more concise "
                f"(maximum {self.MAX_QUERY_LENGTH} characters)."
            )
        return True, ""

    def get_response(self, query: str) -> dict:
        """Answer *query* from the retrieved document sections.

        Never raises: every failure is logged and reported as a response
        dict with ``confidence`` set to ``"LOW"``.

        Returns:
            A dict with the keys ``answer`` (str), ``references`` (list of
            section titles), ``summary`` (str) and ``confidence``
            (``"HIGH"``, ``"MEDIUM"`` or ``"LOW"``).
        """
        try:
            is_valid, error_message = self.validate_query(query)
            if not is_valid:
                return self._build_response(error_message, "Invalid query")
            query = query.strip()

            # Search ChromaDB for relevant content
            results = self.collection.query(
                query_texts=[query],
                n_results=self.N_RESULTS,
            )
            documents = (results["documents"] or [[]])[0]
            if not documents:
                return self._build_response(
                    "No relevant information found in the document.",
                    "No matching content",
                )
            metadatas = (results["metadatas"] or [[]])[0]
            context, context_titles = self._format_context(documents, metadatas)

            response = self.mistral_client.chat.completions.create(
                model=self.MODEL_NAME,
                messages=[
                    {"role": "system", "content": self.SYSTEM_PROMPT},
                    {
                        "role": "user",
                        "content": self.USER_PROMPT_TEMPLATE.format(
                            context=context, query=query
                        ),
                    },
                ],
                temperature=0.1,
                max_tokens=1000,
                response_format={"type": "json_object"},
            )

            reply = response.choices[0].message.content if response.choices else None
            if not reply:
                return self._build_response(
                    "Error: No valid response received",
                    "No response generated",
                )

            try:
                return self._parse_model_reply(reply, context_titles)
            # JSONDecodeError subclasses ValueError, so it must be caught first.
            except json.JSONDecodeError as e:
                logger.error(f"JSON parsing error: {str(e)}")
                return self._build_response(
                    "Error: Failed to parse response format",
                    "Response format error",
                )
            except ValueError as e:
                logger.error(f"Validation error: {str(e)}")
                return self._build_response(
                    "Error: Invalid response structure",
                    "Response validation error",
                )
        except Exception as e:
            # Details go to the log only; the raw exception text may contain
            # internals that should not be shown to the end user.
            logger.exception(f"Error in get_response: {str(e)}")
            return self._build_response(
                "Error: An unexpected error occurred while processing your query.",
                "System error occurred",
            )

    @staticmethod
    def _build_response(answer, summary, references=None, confidence="LOW") -> dict:
        """Assemble the response dict shared by every return path."""
        return {
            "answer": answer,
            "references": list(references) if references else [],
            "summary": summary,
            "confidence": confidence,
        }

    def _format_context(self, documents, metadatas) -> tuple[str, list[str]]:
        """Join retrieved sections into one prompt context and list their titles.

        A section whose metadata is missing, is ``None`` or has no ``title``
        is labelled with ``UNTITLED_SECTION`` instead of raising.
        """
        parts = []
        titles = []
        for index, doc in enumerate(documents):
            meta = metadatas[index] if index < len(metadatas) else None
            title = str((meta or {}).get("title") or self.UNTITLED_SECTION)
            parts.append(f"{title}:\n{doc}")
            titles.append(title)
        return "\n\n".join(parts), titles

    def _parse_model_reply(self, reply: str, context_titles: list[str]) -> dict:
        """Decode and validate the model's JSON reply.

        Raises:
            json.JSONDecodeError: if *reply* is not valid JSON.
            ValueError: if the decoded value is not an object or lacks one
                of ``REQUIRED_FIELDS``.
        """
        result = json.loads(reply)
        if not isinstance(result, dict):
            raise ValueError("Response is not a JSON object")
        if not all(field in result for field in self.REQUIRED_FIELDS):
            raise ValueError("Missing required fields in response")

        # Accept any letter case / padding; anything unknown counts as LOW.
        confidence = str(result["confidence"]).strip().upper()
        if confidence not in self.CONFIDENCE_LEVELS:
            confidence = "LOW"

        cited = result["reference_sections"]
        if isinstance(cited, str):
            # A bare string would otherwise be iterated character by character.
            cited = [cited]
        elif not isinstance(cited, list):
            cited = []

        known_titles = set(context_titles)
        valid_references = [
            ref for ref in cited if isinstance(ref, str) and ref in known_titles
        ]
        # A citation that is not among the supplied sections means the model
        # went beyond the context, so the answer cannot be trusted.
        if len(valid_references) != len(cited):
            confidence = "LOW"

        return self._build_response(
            str(result["answer"]),
            str(result["summary"]),
            references=valid_references,
            confidence=confidence,
        )
# Build the single module-level assistant that process_query() uses.
# A failure here is logged and then re-raised, so the module does not
# finish importing without a working assistant.
try:
    assistant = LegalAssistant()
except Exception as init_error:
    logger.error(f"Failed to initialize LegalAssistant: {str(init_error)}")
    raise
def process_query(query: str) -> tuple:
    """Run *query* through the assistant and shape the result for the UI.

    Returns a 4-tuple ``(answer, references, summary, confidence)`` where
    the references are joined into one comma-separated string. Placeholder
    text is substituted when there are no references or no summary.
    """
    result = assistant.get_response(query)

    if result["references"]:
        references_text = ", ".join(result["references"])
    else:
        references_text = "No specific references"

    summary_text = result["summary"] or "No summary available"

    return result["answer"], references_text, summary_text, result["confidence"]
# Markdown shown above the query box.
_GUIDELINES_MD = """
# Indian Legal Assistant
## Guidelines for Queries:
1. Be specific and clear in your questions
2. End questions with a question mark or period
3. Keep queries between 10-500 characters
4. Questions will be answered based ONLY on the provided legal document
"""

# Markdown shown below the result boxes.
_NOTES_MD = """
### Important Notes:
- Responses are based ONLY on the provided document
- No external legal knowledge is used
- All references are from the document itself
- Confidence levels indicate how well the answer matches the document content
"""

# Gradio interface: one query box, a submit button and four read-out boxes.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(_GUIDELINES_MD)

    with gr.Row():
        query_input = gr.Textbox(
            label="Enter your legal query",
            placeholder="e.g., What are the main provisions in this document?",
        )

    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary")

    with gr.Row():
        confidence_output = gr.Textbox(label="Confidence Level")

    with gr.Row():
        answer_output = gr.Textbox(label="Answer", lines=5)

    # References and summary sit side by side in one row.
    with gr.Row():
        with gr.Column():
            references_output = gr.Textbox(label="Document References", lines=2)
        with gr.Column():
            summary_output = gr.Textbox(label="Summary", lines=2)

    gr.Markdown(_NOTES_MD)

    # The output order must match the tuple returned by process_query():
    # (answer, references, summary, confidence).
    submit_btn.click(
        fn=process_query,
        inputs=[query_input],
        outputs=[answer_output, references_output, summary_output, confidence_output],
    )

if __name__ == "__main__":
    demo.launch()