# advisor/app.py — Hugging Face Space application (commit a53e1b6, 8.2 kB)
import gradio as gr
from typing import List, Dict, Tuple
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import pipeline
import os
from astrapy.db import AstraDB
from dotenv import load_dotenv
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
import json
# Load environment variables
load_dotenv()

# Login to Hugging Face Hub only when a token is actually configured.
# login(token=None) falls back to interactive prompting (or raises in a
# non-interactive environment), which would crash a headless Space at
# import time; skipping the login degrades gracefully to anonymous access.
_hf_token = os.getenv("HUGGINGFACE_API_TOKEN")
if _hf_token:
    login(token=_hf_token)
class LegalTextSearchBot:
    """Searches Indian legal sections stored in AstraDB and produces an
    LLM interpretation grounded in the retrieved sections.

    External services used at construction time (network access required):
      * AstraDB — vector store holding the legal sections
      * sentence-transformers MiniLM — query embeddings
      * TinyLlama-1.1B-Chat via a transformers text-generation pipeline
    """

    # Cap on the running chat-history string so the prompt fed to the LLM
    # cannot grow without bound over a long session.
    _MAX_HISTORY_CHARS = 4000

    def __init__(self):
        # Initialize AstraDB connection from environment configuration.
        self.astra_db = AstraDB(
            token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT")
        )
        # Collection holding the legal sections (name taken from the environment).
        self.collection = self.astra_db.collection(os.getenv("ASTRA_DB_COLLECTION"))
        # Embedding model used to vectorize user queries for similarity search.
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        pipe = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15
        )
        self.llm = HuggingFacePipeline(pipeline=pipe)
        # Prompt template: constrains the model to answer only from the
        # retrieved sections supplied in {context}.
        self.template = """
IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.
STRICT RULES:
1. Base your response ONLY on the provided legal sections
2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the legal database."
3. Do not make assumptions or use external knowledge
4. Always cite the specific section numbers you're referring to
5. Be precise and accurate in your legal interpretations
6. If quoting from the sections, use quotes and cite the section number
Context (Legal Sections): {context}
Chat History: {chat_history}
Question: {question}
Answer:"""
        self.prompt = ChatPromptTemplate.from_template(self.template)
        # Running conversation transcript, re-injected into every prompt.
        self.chat_history = ""

    def _search_astra(self, query: str) -> List[Dict]:
        """Return up to five documents relevant to *query*.

        Tries a vector similarity search first; if that yields nothing,
        falls back to fetching a handful of arbitrary documents so the
        caller still has context to show.  Returns [] on any error.
        """
        try:
            # Generate embedding for the query.
            query_embedding = self.embedding_model.encode(query).tolist()
            # Primary path: vector similarity search.
            results = list(self.collection.vector_find(
                query_embedding,
                limit=5,
                fields=["section_number", "title", "chapter_info", "content", "searchable_text"]
            ))
            if not results:
                # BUG FIX: AstraDBCollection.find() takes its limit inside
                # the `options` dict and returns the raw API envelope
                # ({"data": {"documents": [...]}}), not an iterable of
                # documents.  The old call, find({}, limit=5), raised a
                # TypeError that the except below silently swallowed, so
                # this fallback never returned anything.
                response = self.collection.find({}, options={"limit": 5})
                results = ((response or {}).get("data") or {}).get("documents") or []
            return results
        except Exception as e:
            # Best-effort: log and degrade to "no results" rather than
            # crashing the UI.
            print(f"Error searching AstraDB: {str(e)}")
            return []

    def format_section(self, section: Dict) -> str:
        """Format a single section document for plain-text display.

        Missing fields render as 'N/A'; a malformed document falls back to
        its raw str() representation rather than raising.
        """
        try:
            chapter_info = section.get('chapter_info', {})
            # chapter_info may be absent or non-dict in older documents.
            chapter_title = chapter_info.get('title', 'N/A') if isinstance(chapter_info, dict) else 'N/A'
            return f"""
Section {section.get('section_number', 'N/A')}: {section.get('title', 'N/A')}
Chapter: {chapter_title}
Content:
{section.get('content', 'N/A')}
{"="*80}
"""
        except Exception as e:
            print(f"Error formatting section: {str(e)}")
            return str(section)

    def search_sections(self, query: str) -> Tuple[str, str]:
        """Search legal sections and return both raw results and AI interpretation.

        Returns:
            (raw_results, ai_response): the formatted matching sections and
            the model's answer grounded in them, both as display-ready strings.
        """
        # Robustness: an empty/blank query would embed "" and surface
        # arbitrary matches; reject it up front instead.
        if not query or not query.strip():
            return "Please enter a query.", "Please enter a legal question to search for."
        try:
            # Search AstraDB for relevant sections.
            search_results = self._search_astra(query)
            if not search_results:
                return "No relevant sections found.", "I apologize, but I cannot find relevant sections in the database."
            # Build the display text and the LLM context in one pass.
            raw_results = []
            context_parts = []
            for result in search_results:
                raw_results.append(self.format_section(result))
                context_parts.append(f"""
Section {result.get('section_number')}: {result.get('title')}
{result.get('content', '')}
""")
            context = "\n\n".join(context_parts)
            # Generate the AI interpretation via prompt -> LLM.
            chain = self.prompt | self.llm
            ai_response = chain.invoke({
                "context": context,
                "chat_history": self.chat_history,
                "question": query
            })
            self.chat_history += f"\nUser: {query}\nAI: {ai_response}\n"
            # Trim so the transcript (and thus every future prompt) stays bounded.
            if len(self.chat_history) > self._MAX_HISTORY_CHARS:
                self.chat_history = self.chat_history[-self._MAX_HISTORY_CHARS:]
            return "\n".join(raw_results), ai_response
        except Exception as e:
            error_msg = f"Error processing query: {str(e)}"
            print(error_msg)
            return error_msg, "An error occurred while processing your query."
def create_interface():
    """Build and return the Gradio Blocks UI for the legal search system."""
    with gr.Blocks(title="Legal Text Search System", theme=gr.themes.Soft()) as app:
        gr.Markdown("""
# 📚 Legal Text Search System
This system allows you to search through Indian legal sections and get both:
1. 📜 Raw section contents that match your query
2. 🤖 AI-powered interpretation of the relevant sections
Enter your legal query below:
""")

        # One bot instance is shared by every interaction with this UI.
        bot = LegalTextSearchBot()

        # Query input row.
        with gr.Row():
            question_box = gr.Textbox(
                label="Your Query",
                placeholder="e.g., What are the penalties for public servants who conceal information?",
                lines=2,
            )

        # Trigger row.
        with gr.Row():
            go_button = gr.Button("🔍 Search Legal Sections", variant="primary")

        # Side-by-side output panels: raw sections and the AI's reading of them.
        with gr.Row():
            with gr.Column():
                sections_box = gr.Textbox(
                    label="📜 Relevant Legal Sections",
                    lines=15,
                    max_lines=30,
                )
            with gr.Column():
                answer_box = gr.Textbox(
                    label="🤖 AI Interpretation",
                    lines=15,
                    max_lines=30,
                )

        # Clickable example queries.
        gr.Examples(
            examples=[
                "What are the penalties for public servants who conceal information?",
                "What is the punishment for corruption?",
                "What happens if a public servant fails to prevent an offense?",
                "What are the legal consequences for concealing design to commit offence?",
                "Explain the duties and responsibilities of public servants",
            ],
            inputs=question_box,
            label="Example Queries",
        )

        # Both the button and pressing Enter run the same search; the bot's
        # method already returns the (raw, interpretation) pair, so it is
        # wired in directly without a wrapper.
        for trigger in (go_button.click, question_box.submit):
            trigger(
                fn=bot.search_sections,
                inputs=question_box,
                outputs=[sections_box, answer_box],
            )

    return app
# Build and serve the Gradio app when run as a script.
if __name__ == "__main__":
    create_interface().launch()