# Legal Text Search System — Gradio app backed by AstraDB vector search.
# (Removed non-code page chrome "Spaces / Sleeping" left over from extraction.)
import gradio as gr | |
from typing import List, Dict, Tuple | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline | |
from transformers import pipeline | |
import os | |
from astrapy.db import AstraDB | |
from dotenv import load_dotenv | |
from huggingface_hub import login | |
from sentence_transformers import SentenceTransformer | |
import json | |
# Load configuration from a local .env file, then authenticate with the
# Hugging Face Hub so model downloads (e.g. gated repos) are authorized.
load_dotenv()
hf_token = os.getenv("HUGGINGFACE_API_TOKEN")
login(token=hf_token)
class LegalTextSearchBot:
    """Search Indian legal sections stored in AstraDB and generate an
    LLM-based interpretation of the matching sections.

    Responsibilities:
      * vectorize user queries with a SentenceTransformer embedding model,
      * retrieve the most similar sections from an AstraDB collection,
      * format the raw sections for display,
      * prompt a local text-generation model for a grounded answer.
    """

    def __init__(self):
        # Vector-store connection; credentials come from the environment
        # (ASTRA_DB_APPLICATION_TOKEN / ASTRA_DB_API_ENDPOINT / ASTRA_DB_COLLECTION).
        self.astra_db = AstraDB(
            token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT")
        )
        self.collection = self.astra_db.collection(os.getenv("ASTRA_DB_COLLECTION"))

        # Embedding model used to turn queries into vectors for similarity search.
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        # Local generation model. do_sample=True is required: without it the
        # pipeline decodes greedily and silently ignores temperature/top_p.
        pipe = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15
        )
        self.llm = HuggingFacePipeline(pipeline=pipe)

        # Prompt template grounding the model strictly in retrieved sections.
        self.template = """
IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.
STRICT RULES:
1. Base your response ONLY on the provided legal sections
2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the legal database."
3. Do not make assumptions or use external knowledge
4. Always cite the specific section numbers you're referring to
5. Be precise and accurate in your legal interpretations
6. If quoting from the sections, use quotes and cite the section number
Context (Legal Sections): {context}
Chat History: {chat_history}
Question: {question}
Answer:"""
        self.prompt = ChatPromptTemplate.from_template(self.template)
        # Running transcript appended to every prompt.
        # NOTE(review): grows without bound across a long session — consider
        # truncating to the last few turns if prompts start overflowing.
        self.chat_history = ""

    def _search_astra(self, query: str) -> List[Dict]:
        """Return up to 5 documents from AstraDB most similar to *query*.

        Falls back to an unfiltered fetch when the vector search returns
        nothing. Returns [] on any error (logged to stdout).
        """
        try:
            # Embed the query; AstraDB expects a plain list of floats.
            query_embedding = self.embedding_model.encode(query).tolist()

            # Primary path: vector similarity search.
            results = list(self.collection.vector_find(
                query_embedding,
                limit=5,
                fields=["section_number", "title", "chapter_info", "content", "searchable_text"]
            ))

            if not results:
                # Fallback: plain find(). astrapy's legacy find() takes the
                # limit inside `options` (a bare limit= kwarg raises TypeError)
                # and returns a response dict, so documents must be unwrapped.
                response = self.collection.find({}, options={"limit": 5})
                results = response.get("data", {}).get("documents", [])
            return results
        except Exception as e:
            print(f"Error searching AstraDB: {str(e)}")
            return []

    def format_section(self, section: Dict) -> str:
        """Render one section document as display text.

        Missing fields fall back to 'N/A'; a malformed document is returned
        as its raw repr rather than raising.
        """
        try:
            chapter_info = section.get('chapter_info', {})
            # chapter_info may be absent or non-dict in older documents.
            chapter_title = chapter_info.get('title', 'N/A') if isinstance(chapter_info, dict) else 'N/A'
            return f"""
Section {section.get('section_number', 'N/A')}: {section.get('title', 'N/A')}
Chapter: {chapter_title}
Content:
{section.get('content', 'N/A')}
{"="*80}
"""
        except Exception as e:
            print(f"Error formatting section: {str(e)}")
            return str(section)

    def search_sections(self, query: str) -> Tuple[str, str]:
        """Search legal sections and return (raw results, AI interpretation).

        Both elements are display-ready strings; on failure the first element
        carries the error message and the second a generic apology.
        """
        try:
            search_results = self._search_astra(query)
            if not search_results:
                return "No relevant sections found.", "I apologize, but I cannot find relevant sections in the database."

            raw_results = []   # formatted for the user-facing panel
            context_parts = [] # compact form fed to the LLM
            for result in search_results:
                raw_results.append(self.format_section(result))
                context_parts.append(f"""
Section {result.get('section_number')}: {result.get('title')}
{result.get('content', '')}
""")
            context = "\n\n".join(context_parts)

            # prompt | llm: LCEL pipe producing the model's text answer.
            chain = self.prompt | self.llm
            ai_response = chain.invoke({
                "context": context,
                "chat_history": self.chat_history,
                "question": query
            })

            # Record the turn so follow-up questions have conversational context.
            self.chat_history += f"\nUser: {query}\nAI: {ai_response}\n"
            return "\n".join(raw_results), ai_response
        except Exception as e:
            error_msg = f"Error processing query: {str(e)}"
            print(error_msg)
            return error_msg, "An error occurred while processing your query."
def create_interface():
    """Assemble and return the Gradio Blocks UI for the legal search system."""
    # One bot instance is shared by every request to this interface.
    search_bot = LegalTextSearchBot()

    def run_search(query):
        # Adapter handing Gradio the (raw sections, AI answer) pair.
        return search_bot.search_sections(query)

    with gr.Blocks(title="Legal Text Search System", theme=gr.themes.Soft()) as iface:
        gr.Markdown("""
        # π Legal Text Search System
        This system allows you to search through Indian legal sections and get both:
        1. π Raw section contents that match your query
        2. π€ AI-powered interpretation of the relevant sections
        Enter your legal query below:
        """)

        with gr.Row():
            query_input = gr.Textbox(
                label="Your Query",
                placeholder="e.g., What are the penalties for public servants who conceal information?",
                lines=2
            )
        with gr.Row():
            search_button = gr.Button("π Search Legal Sections", variant="primary")

        with gr.Row():
            with gr.Column():
                raw_output = gr.Textbox(
                    label="π Relevant Legal Sections",
                    lines=15,
                    max_lines=30
                )
            with gr.Column():
                ai_output = gr.Textbox(
                    label="π€ AI Interpretation",
                    lines=15,
                    max_lines=30
                )

        # Clickable sample queries that pre-fill the input box.
        gr.Examples(
            examples=[
                "What are the penalties for public servants who conceal information?",
                "What is the punishment for corruption?",
                "What happens if a public servant fails to prevent an offense?",
                "What are the legal consequences for concealing design to commit offence?",
                "Explain the duties and responsibilities of public servants"
            ],
            inputs=query_input,
            label="Example Queries"
        )

        # Both the button click and pressing Enter trigger the same search.
        handler_kwargs = dict(fn=run_search, inputs=query_input, outputs=[raw_output, ai_output])
        search_button.click(**handler_kwargs)
        query_input.submit(**handler_kwargs)

    return iface
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    app = create_interface()
    app.launch()