# Spaces: Sleeping
import os | |
import gradio as gr | |
import chromadb | |
from openai import OpenAI | |
import json | |
from sentence_transformers import SentenceTransformer | |
from loguru import logger | |
from test_embeddings import test_chromadb_content | |
class SentenceTransformerEmbeddings:
    """Callable wrapper that turns texts into embeddings.

    The instance holds a ``SentenceTransformer`` model and, when called with
    a list of strings, returns the model's encodings converted to plain
    nested Python lists.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Load the sentence-transformers model identified by *model_name*."""
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Encode *input* and return one list of floats per text.

        The parameter keeps the name ``input`` (despite shadowing the
        builtin) because it is part of this callable's interface.
        """
        # encode() yields an array-like; tolist() converts it to plain lists.
        return self.model.encode(input).tolist()
class LegalAssistant:
    """Answer legal questions using only passages retrieved from ChromaDB.

    On construction the assistant verifies the ChromaDB content, opens the
    persistent ``legal_documents`` collection stored next to this file and
    creates an OpenAI-compatible client pointed at the Mistral AI endpoint.
    ``get_response`` then retrieves the three closest passages for a query,
    asks the model to answer strictly from them and validates the reply.

    The Mistral API key is read from the ``MISTRAL_API_KEY`` environment
    variable; no fallback credential is kept in source.
    """

    # System prompt sent with every request (one entry per prompt line).
    _SYSTEM_PROMPT = "\n".join((
        "You are a specialized legal assistant that MUST follow these STRICT rules:",
        "CRITICAL RULE:",
        "YOU MUST ONLY USE INFORMATION FROM THE PROVIDED CONTEXT. DO NOT USE ANY EXTERNAL KNOWLEDGE.",
        "RESPONSE FORMAT RULES:",
        "1. ALWAYS structure your response in this exact JSON format:",
        "{",
        '"answer": "Your detailed answer here using ONLY information from the provided context",',
        '"reference_sections": ["Exact section titles from the context"],',
        '"summary": "2-3 line summary using ONLY information from context",',
        '"confidence": "HIGH/MEDIUM/LOW based on context match"',
        "}",
        "STRICT CONTENT RULES:",
        "1. NEVER mention or reference any laws not present in the context",
        "2. If the information is not in the context, respond with LOW confidence",
        "3. ONLY cite sections that are explicitly present in the provided context",
        "4. DO NOT make assumptions or inferences beyond the context",
        "5. DO NOT combine information from external knowledge",
    ))

    def __init__(self):
        """Set up the vector store, embedding function and Mistral client.

        Raises:
            ValueError: if the ChromaDB content check fails or the
                ``MISTRAL_API_KEY`` environment variable is missing/empty.
            Exception: any error raised by ChromaDB or the client libraries
                is logged and re-raised unchanged.
        """
        try:
            # Verify ChromaDB content first
            if not test_chromadb_content():
                raise ValueError("ChromaDB content verification failed")

            # Initialize ChromaDB from the directory next to this file
            base_path = os.path.dirname(os.path.abspath(__file__))
            chroma_path = os.path.join(base_path, 'chroma_db')
            self.chroma_client = chromadb.PersistentClient(path=chroma_path)
            self.embedding_function = SentenceTransformerEmbeddings()

            # Get existing collection
            self.collection = self.chroma_client.get_collection(
                name="legal_documents",
                embedding_function=self.embedding_function,
            )

            # The key must come from the environment: a credential embedded
            # in source is exposed to everyone who can read the file.
            api_key = os.environ.get("MISTRAL_API_KEY")
            if not api_key:
                raise ValueError("MISTRAL_API_KEY environment variable is not set")

            # Initialize Mistral AI client (OpenAI-compatible endpoint)
            self.mistral_client = OpenAI(
                api_key=api_key,
                base_url="https://api.mistral.ai/v1",
            )
            logger.info("LegalAssistant initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing LegalAssistant: {str(e)}")
            raise

    @staticmethod
    def _error_response(answer: str, summary: str) -> dict:
        """Build a LOW-confidence response dict with no references."""
        return {
            "answer": answer,
            "references": [],
            "summary": summary,
            "confidence": "LOW",
        }

    @staticmethod
    def _parse_model_json(raw: str) -> dict:
        """Parse the model's reply into a dict.

        Tolerates a reply wrapped in a Markdown code fence (with or without
        a language tag on the opening fence line).

        Raises:
            ValueError: if the text is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError``) or the JSON value is not an object.
        """
        text = raw.strip()
        if text.startswith("```"):
            # Drop the opening fence line, then everything from the closing fence on.
            _, _, text = text.partition("\n")
            text = text.rsplit("```", 1)[0]
        parsed = json.loads(text)
        if not isinstance(parsed, dict):
            raise ValueError("model response is not a JSON object")
        return parsed

    @staticmethod
    def _build_user_message(context: str, references: list, query: str) -> str:
        """Compose the user message holding the context, section list and question."""
        return "\n".join((
            "IMPORTANT: ONLY use information from the following context to answer the question.",
            "Context Sections:",
            context,
            "Available Document Sections:",
            ', '.join(references),
            f"Question: {query}",
            "Remember: ONLY use information from the above context.",
        ))

    def validate_query(self, query: str) -> tuple[bool, str]:
        """Validate the input query.

        Both length limits apply to the query with surrounding whitespace
        removed, so padding can neither satisfy the minimum nor trip the
        maximum.

        Returns:
            ``(True, "")`` when the query is acceptable, otherwise
            ``(False, message)`` with a user-facing explanation.
        """
        text = query.strip() if query else ""
        if len(text) < 10:
            return False, "Query too short. Please provide more details (minimum 10 characters)."
        if len(text) > 500:
            return False, "Query too long. Please be more concise (maximum 500 characters)."
        return True, ""

    def get_response(self, query: str) -> dict:
        """Process query and get response from Mistral AI.

        Returns:
            A dict with the keys ``answer``, ``references``, ``summary`` and
            ``confidence``. Every failure (invalid query, nothing retrieved,
            unparsable or unauthorized model output, unexpected exception)
            is reported through the same shape with ``confidence`` set to
            ``"LOW"``; this method does not raise.
        """
        try:
            # Validate query
            is_valid, error_message = self.validate_query(query)
            if not is_valid:
                return self._error_response(error_message, "Invalid query")
            # Send the same text that was validated.
            query = query.strip()

            # Search ChromaDB for relevant content
            results = self.collection.query(query_texts=[query], n_results=3)
            documents = results['documents'][0] if results['documents'] else []
            if not documents:
                return self._error_response(
                    "No relevant information found in the document.",
                    "No matching content",
                )

            # Format context with section titles
            context_parts = []
            references = []
            for doc, meta in zip(documents, results['metadatas'][0]):
                context_parts.append(f"{meta['title']}:\n{doc}")
                references.append(f"{meta['title']} (Section {meta['section_number']})")
            context = "\n\n".join(context_parts)

            # Get response from Mistral AI
            response = self.mistral_client.chat.completions.create(
                model="mistral-medium",
                messages=[
                    {"role": "system", "content": self._SYSTEM_PROMPT},
                    {
                        "role": "user",
                        "content": self._build_user_message(context, references, query),
                    },
                ],
                temperature=0.1,
                max_tokens=1000,
            )

            raw_reply = response.choices[0].message.content if response.choices else None
            if not raw_reply:
                return self._error_response(
                    "No valid response received",
                    "Response generation failed",
                )

            # Parse and validate response
            try:
                result = self._parse_model_json(raw_reply)
            except ValueError:
                logger.error("Failed to parse response JSON")
                return self._error_response(
                    "Error: Invalid response format",
                    "Response parsing failed",
                )

            # Validate references: each cited entry must contain the title
            # of one of the retrieved sections.
            cited = result.get("reference_sections") or []
            if not isinstance(cited, list):
                # A bare value instead of a list is treated as one citation.
                cited = [cited]
            allowed_titles = [source.split(" (Section")[0] for source in references]
            valid_references = [
                ref for ref in cited
                if isinstance(ref, str) and any(title in ref for title in allowed_titles)
            ]
            if len(valid_references) != len(cited):
                logger.warning("Response contained unauthorized references")
                return self._error_response(
                    "Error: Response contained unauthorized references",
                    "Invalid response generated",
                )

            return {
                "answer": result.get("answer", "No answer provided"),
                "references": valid_references,
                "summary": result.get("summary", ""),
                "confidence": result.get("confidence", "LOW"),
            }
        except Exception as e:
            # Boundary handler: the UI expects a response dict, never an exception.
            logger.exception(f"Error in get_response: {str(e)}")
            return self._error_response(f"Error: {str(e)}", "System error occurred")
# Build the single module-level assistant at import time; process_query
# reuses this instance for every request.
try:
    assistant = LegalAssistant()
except Exception as exc:
    # Record the failure, then let it propagate so startup is aborted.
    logger.error(f"Failed to initialize LegalAssistant: {exc!s}")
    raise
def process_query(query: str) -> tuple:
    """Run *query* through the assistant and flatten the result for display.

    Returns a 4-tuple ``(answer, references, summary, confidence)`` where
    the references are joined into one comma-separated string and empty
    references/summary are replaced by placeholder text.
    """
    result = assistant.get_response(query)

    if result["references"]:
        references_text = ", ".join(result["references"])
    else:
        references_text = "No specific references"

    summary_text = result["summary"] or "No summary available"

    return (result["answer"], references_text, summary_text, result["confidence"])
# Create the Gradio interface
# Components are declared in display order: guidelines, query box, submit
# button, confidence, answer, then references and summary side by side.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Static usage guidelines rendered at the top of the page.
    gr.Markdown("""
# Indian Legal Assistant
## Guidelines for Queries:
1. Be specific and clear in your questions
2. End questions with a question mark or period
3. Keep queries between 10-500 characters
4. Questions will be answered based ONLY on the provided legal document
""")
    with gr.Row():
        query_input = gr.Textbox(
            label="Enter your legal query",
            placeholder="e.g., What are the main provisions in this document?"
        )
    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary")
    with gr.Row():
        confidence_output = gr.Textbox(label="Confidence Level")
    with gr.Row():
        answer_output = gr.Textbox(label="Answer", lines=5)
    with gr.Row():
        with gr.Column():
            references_output = gr.Textbox(label="Document References", lines=3)
        with gr.Column():
            summary_output = gr.Textbox(label="Summary", lines=2)
    # The outputs list follows the order of the tuple returned by
    # process_query: answer, references, summary, confidence.
    submit_btn.click(
        fn=process_query,
        inputs=[query_input],
        outputs=[answer_output, references_output, summary_output, confidence_output]
    )
# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()