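# Gradio app for the Bharatiya Nyaya Sanhita, 2023: retrieves relevant legal
# sections from AstraDB and asks a local LLM to interpret them for the user.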
import gradio as gr
from typing import List, Dict, Tuple
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import pipeline
import os
from astrapy.db import AstraDB
from dotenv import load_dotenv
from huggingface_hub import login
import time
import threading
from queue import Queue
import asyncio

# Load environment variables
load_dotenv()
login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
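
# Raised to abort an in-flight search when the user presses the Stop button.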
class SearchCancelled(Exception):
    pass
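
# Wraps the full pipeline: AstraDB retrieval, result formatting, and
# LLM-based interpretation, with a cooperative cancellation flag.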
class LegalTextSearchBot:
    def __init__(self):
        self.astra_db = AstraDB(
            token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT")
        )
        self.collection = self.astra_db.collection("legal_content")
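
        # Local text-generation model used to interpret the retrieved sections.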
        pipe = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15
        )
        self.llm = HuggingFacePipeline(pipeline=pipe)
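
        # Prompt that keeps the model grounded in the retrieved sections and
        # asks it to cite section numbers instead of using outside knowledge.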
        self.template = """
IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.

STRICT RULES:
1. Base your response ONLY on the provided legal sections
2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the legal database."
3. Do not make assumptions or use external knowledge
4. Always cite the specific section numbers you're referring to
5. Be precise and accurate in your legal interpretations
6. If quoting from the sections, use quotes and cite the section number

Context (Legal Sections): {context}

Chat History: {chat_history}

Question: {question}

Answer:"""
        self.prompt = ChatPromptTemplate.from_template(self.template)
        self.chat_history = ""
        self.cancel_search = False
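
    # Vector search against the "legal_content" collection; falls back to a
    # plain find() when the vector search returns no documents.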
    def _search_astra(self, query: str) -> List[Dict]:
        if self.cancel_search:
            raise SearchCancelled("Search was cancelled by user")
        try:
            results = list(self.collection.vector_find(
                query,
                limit=5,
                fields=["section_number", "title", "chapter_number", "chapter_title",
                        "content", "type", "metadata"]
            ))
            if not results and not self.cancel_search:
                results = list(self.collection.find(
                    {},
                    limit=5
                ))
            return results
        except Exception as e:
            if not isinstance(e, SearchCancelled):
                print(f"Error searching AstraDB: {str(e)}")
            raise
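
    # Renders one retrieved document as a human-readable block for display.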
    def format_section(self, section: Dict) -> str:
        if self.cancel_search:
            raise SearchCancelled("Search was cancelled by user")
        try:
            return f"""
{'='*80}
Chapter {section.get('chapter_number', 'N/A')}: {section.get('chapter_title', 'N/A')}
Section {section.get('section_number', 'N/A')}: {section.get('title', 'N/A')}
Type: {section.get('type', 'section')}
Content:
{section.get('content', 'N/A')}
References: {', '.join(section.get('metadata', {}).get('references', [])) or 'None'}
{'='*80}
"""
        except Exception as e:
            print(f"Error formatting section: {str(e)}")
            return str(section)
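
    # Main entry point used by the UI: retrieve sections, build the context,
    # run the LLM, and return (raw sections, AI interpretation).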
    def search_sections(self, query: str, progress=gr.Progress()) -> Tuple[str, str]:
        self.cancel_search = False
        try:
            progress(0, desc="Searching relevant sections...")
            search_results = self._search_astra(query)

            if not search_results:
                return "No relevant sections found.", "I apologize, but I cannot find relevant sections in the database."

            progress(0.3, desc="Processing results...")
            raw_results = []
            context_parts = []

            for idx, result in enumerate(search_results):
                if self.cancel_search:
                    raise SearchCancelled("Search was cancelled by user")
                raw_results.append(self.format_section(result))
                context_parts.append(f"""
Section {result.get('section_number')}: {result.get('title')}
{result.get('content', '')}
""")
                progress((0.3 + (idx * 0.1)), desc="Processing results...")

            progress(0.8, desc="Generating AI interpretation...")
            context = "\n\n".join(context_parts)
            chain = self.prompt | self.llm
            ai_response = chain.invoke({
                "context": context,
                "chat_history": self.chat_history,
                "question": query
            })

            self.chat_history += f"\nUser: {query}\nAI: {ai_response}\n"
            progress(1.0, desc="Complete!")
            return "\n".join(raw_results), ai_response

        except SearchCancelled:
            return "Search cancelled by user.", "Search was stopped. Please try again with a new query."
        except Exception as e:
            error_msg = f"Error processing query: {str(e)}"
            print(error_msg)
            return error_msg, "An error occurred while processing your query."
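
    # Called from the Stop button; the flag is checked at each stage above.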
    def cancel(self):
        self.cancel_search = True
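
# Builds the Gradio Blocks UI: query box, search/stop buttons, and two
# Markdown panes for the raw sections and the AI interpretation.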
def create_interface():
    with gr.Blocks(title="Bharatiya Nyaya Sanhita Search", theme=gr.themes.Soft()) as iface:
        gr.Markdown("""
# 📚 Bharatiya Nyaya Sanhita Legal Search System

Search through the Bharatiya Nyaya Sanhita, 2023 and get:
1. 📜 Relevant sections, explanations, and illustrations
2. 🤖 AI-powered interpretation of the legal content

Enter your legal query below:
""")

        search_bot = LegalTextSearchBot()

        with gr.Row():
            query_input = gr.Textbox(
                label="Your Query",
                placeholder="e.g., What are the penalties for public servants who conceal information?",
                lines=2
            )

        with gr.Row():
            with gr.Column(scale=4):
                search_button = gr.Button("🔍 Search Legal Sections", variant="primary")
            with gr.Column(scale=1):
                stop_button = gr.Button("🛑 Stop Search", variant="stop")

        with gr.Row():
            with gr.Column():
                raw_output = gr.Markdown(label="📜 Relevant Legal Sections")
            with gr.Column():
                ai_output = gr.Markdown(label="🤖 AI Interpretation")
        gr.Examples(
            examples=[
                "What are the penalties for public servants who conceal information?",
                "What constitutes criminal conspiracy?",
                "Explain the provisions related to culpable homicide",
                "What are the penalties for causing death by negligence?",
                "What are the punishments for corruption?"
            ],
            inputs=query_input,
            label="Example Queries"
        )
        def search(query):
            return search_bot.search_sections(query)

        def stop_search():
            search_bot.cancel()
            return "Search cancelled.", "Search stopped by user."
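
        # Wire up the events. Gradio's `cancels` expects the event objects
        # returned by .click()/.submit(), so both search events are captured
        # and handed to the Stop button's handler.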
        search_event = search_button.click(
            fn=search,
            inputs=query_input,
            outputs=[raw_output, ai_output]
        )
        submit_event = query_input.submit(
            fn=search,
            inputs=query_input,
            outputs=[raw_output, ai_output]
        )
        stop_button.click(
            fn=stop_search,
            outputs=[raw_output, ai_output],
            cancels=[search_event, submit_event]  # stop any in-flight search from the button or Enter key
        )
    return iface
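
# Launch directly when run as a script; when imported (e.g., by a hosting
# runtime), build and launch the app at import time instead.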
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
else:
    demo = create_interface()
    app = demo.launch(share=False)