Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import chromadb
|
4 |
+
from openai import OpenAI
|
5 |
+
import json
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
from loguru import logger
|
8 |
+
from test_embeddings import test_chromadb_content
|
9 |
+
|
class SentenceTransformerEmbeddings:
    """ChromaDB-compatible embedding function backed by a SentenceTransformer model."""

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        # Load the transformer once; every __call__ reuses this instance.
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Encode a batch of texts into embedding vectors (lists of floats)."""
        # NOTE(review): the parameter stays named `input` (shadowing the builtin)
        # because ChromaDB appears to validate embedding functions against that
        # exact signature -- confirm before renaming.
        vectors = self.model.encode(input)
        return vectors.tolist()
class LegalAssistant:
    """RAG-based legal assistant.

    Retrieves the most relevant sections from a persistent ChromaDB
    collection ("legal_documents") and asks Mistral AI to answer strictly
    from that retrieved context, returning a structured dict with keys
    ``answer``, ``references``, ``summary`` and ``confidence``.
    """

    def __init__(self):
        """Verify the vector store, open the collection and create the LLM client.

        Raises:
            ValueError: if the ChromaDB content check fails or the
                MISTRAL_API_KEY environment variable is not set.
        """
        try:
            # Verify ChromaDB content first
            if not test_chromadb_content():
                raise ValueError("ChromaDB content verification failed")

            # Initialize ChromaDB -- the store is persisted next to this file.
            base_path = os.path.dirname(os.path.abspath(__file__))
            chroma_path = os.path.join(base_path, 'chroma_db')

            self.chroma_client = chromadb.PersistentClient(path=chroma_path)
            self.embedding_function = SentenceTransformerEmbeddings()

            # Get existing collection (created by a prior ingestion step).
            self.collection = self.chroma_client.get_collection(
                name="legal_documents",
                embedding_function=self.embedding_function
            )

            # Initialize Mistral AI client (OpenAI-compatible endpoint).
            # SECURITY FIX: the original code shipped a hardcoded API key as the
            # environment fallback. A committed credential must be treated as
            # leaked (revoke it); the key now must come from the environment.
            api_key = os.environ.get("MISTRAL_API_KEY")
            if not api_key:
                raise ValueError("MISTRAL_API_KEY environment variable is not set")
            self.mistral_client = OpenAI(
                api_key=api_key,
                base_url="https://api.mistral.ai/v1"
            )

            logger.info("LegalAssistant initialized successfully")

        except Exception as e:
            logger.error(f"Error initializing LegalAssistant: {str(e)}")
            raise

    def validate_query(self, query: str) -> tuple[bool, str]:
        """Validate the input query.

        Returns:
            (is_valid, error_message); ``error_message`` is "" when valid.
        """
        if not query or len(query.strip()) < 10:
            return False, "Query too short. Please provide more details (minimum 10 characters)."
        if len(query) > 500:
            return False, "Query too long. Please be more concise (maximum 500 characters)."
        return True, ""

    @staticmethod
    def _low_confidence(answer: str, summary: str) -> dict:
        """Build the standard LOW-confidence error payload.

        Deduplicates the five identical fallback dicts the response path needs.
        """
        return {
            "answer": answer,
            "references": [],
            "summary": summary,
            "confidence": "LOW"
        }

    @staticmethod
    def _strip_fences(text: str) -> str:
        """Strip Markdown code fences that chat models sometimes wrap around
        JSON output, so ``json.loads`` sees bare JSON."""
        text = text.strip()
        if text.startswith("```"):
            text = text.strip("`").strip()
            # A fenced block is often tagged ```json -- drop the tag too.
            if text.lower().startswith("json"):
                text = text[4:]
        return text.strip()

    def get_response(self, query: str) -> dict:
        """Process query and get response from Mistral AI.

        Returns:
            dict with ``answer``, ``references``, ``summary`` and
            ``confidence``. Errors are reported in-band with LOW confidence
            rather than raised, so the UI always gets a renderable payload.
        """
        try:
            # Validate query before doing any retrieval work.
            is_valid, error_message = self.validate_query(query)
            if not is_valid:
                return self._low_confidence(error_message, "Invalid query")

            # Search ChromaDB for the 3 most relevant chunks.
            results = self.collection.query(
                query_texts=[query],
                n_results=3
            )

            if not results['documents'][0]:
                return self._low_confidence(
                    "No relevant information found in the document.",
                    "No matching content"
                )

            # Format context with section titles and collect citations.
            context_parts = []
            references = []
            for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
                context_parts.append(f"{meta['title']}:\n{doc}")
                references.append(f"{meta['title']} (Section {meta['section_number']})")

            context = "\n\n".join(context_parts)

            # System prompt pins the model to context-only, JSON-shaped answers.
            system_prompt = """You are a specialized legal assistant that MUST follow these STRICT rules:

CRITICAL RULE:
YOU MUST ONLY USE INFORMATION FROM THE PROVIDED CONTEXT. DO NOT USE ANY EXTERNAL KNOWLEDGE.

RESPONSE FORMAT RULES:
1. ALWAYS structure your response in this exact JSON format:
{
    "answer": "Your detailed answer here using ONLY information from the provided context",
    "reference_sections": ["Exact section titles from the context"],
    "summary": "2-3 line summary using ONLY information from context",
    "confidence": "HIGH/MEDIUM/LOW based on context match"
}

STRICT CONTENT RULES:
1. NEVER mention or reference any laws not present in the context
2. If the information is not in the context, respond with LOW confidence
3. ONLY cite sections that are explicitly present in the provided context
4. DO NOT make assumptions or inferences beyond the context
5. DO NOT combine information from external knowledge"""

            content = f"""IMPORTANT: ONLY use information from the following context to answer the question.

Context Sections:
{context}

Available Document Sections:
{', '.join(references)}

Question: {query}

Remember: ONLY use information from the above context."""

            # Low temperature keeps the answer deterministic and literal.
            response = self.mistral_client.chat.completions.create(
                model="mistral-medium",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": content}
                ],
                temperature=0.1,
                max_tokens=1000
            )

            # Parse and validate the model's JSON reply.
            if response.choices and response.choices[0].message.content:
                try:
                    result = json.loads(
                        self._strip_fences(response.choices[0].message.content)
                    )
                except json.JSONDecodeError:
                    logger.error("Failed to parse response JSON")
                    return self._low_confidence(
                        "Error: Invalid response format",
                        "Response parsing failed"
                    )

                # Reject any citation that is not among the retrieved sections.
                valid_references = [
                    ref for ref in result.get("reference_sections", [])
                    if any(source.split(" (Section")[0] in ref for source in references)
                ]

                if len(valid_references) != len(result.get("reference_sections", [])):
                    logger.warning("Response contained unauthorized references")
                    return self._low_confidence(
                        "Error: Response contained unauthorized references",
                        "Invalid response generated"
                    )

                return {
                    "answer": result.get("answer", "No answer provided"),
                    "references": valid_references,
                    "summary": result.get("summary", ""),
                    "confidence": result.get("confidence", "LOW")
                }

            return self._low_confidence(
                "No valid response received",
                "Response generation failed"
            )

        except Exception as e:
            logger.error(f"Error in get_response: {str(e)}")
            return self._low_confidence(f"Error: {str(e)}", "System error occurred")
# Initialize the assistant at import time so configuration problems
# (missing ChromaDB collection, bad API key) surface on startup rather
# than on the first user query. Re-raising aborts app launch deliberately.
try:
    assistant = LegalAssistant()
except Exception as e:
    logger.error(f"Failed to initialize LegalAssistant: {str(e)}")
    raise
def process_query(query: str) -> tuple:
    """Run a query through the assistant and shape the result for the UI.

    Returns a 4-tuple (answer, references, summary, confidence) matching the
    output widgets wired up in the Gradio interface.
    """
    result = assistant.get_response(query)
    refs = result["references"]
    summary = result["summary"]
    return (
        result["answer"],
        "No specific references" if not refs else ", ".join(refs),
        summary if summary else "No summary available",
        result["confidence"],
    )
# ---------------------------------------------------------------------------
# Gradio UI: a single-page form wired to process_query(). Component creation
# order below determines the on-screen layout, so it must not be reordered.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header and usage guidelines shown above the form.
    gr.Markdown("""
# Indian Legal Assistant
## Guidelines for Queries:
1. Be specific and clear in your questions
2. End questions with a question mark or period
3. Keep queries between 10-500 characters
4. Questions will be answered based ONLY on the provided legal document
""")

    with gr.Row():
        query_input = gr.Textbox(
            label="Enter your legal query",
            placeholder="e.g., What are the main provisions in this document?"
        )

    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary")

    with gr.Row():
        confidence_output = gr.Textbox(label="Confidence Level")

    with gr.Row():
        answer_output = gr.Textbox(label="Answer", lines=5)

    with gr.Row():
        with gr.Column():
            references_output = gr.Textbox(label="Document References", lines=3)
        with gr.Column():
            summary_output = gr.Textbox(label="Summary", lines=2)

    # Wire the button: process_query returns a 4-tuple whose order must match
    # the `outputs` list below (answer, references, summary, confidence).
    submit_btn.click(
        fn=process_query,
        inputs=[query_input],
        outputs=[answer_output, references_output, summary_output, confidence_output]
    )

if __name__ == "__main__":
    demo.launch()