Spaces:
Sleeping
Sleeping
import gradio as gr | |
from typing import List, Dict, Tuple | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoModel | |
import torch | |
import os | |
from astrapy.db import AstraDB | |
from dotenv import load_dotenv | |
from huggingface_hub import login | |
import time | |
import logging | |
from functools import lru_cache | |
import numpy as np | |
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Load environment variables (e.g. from a .env file via python-dotenv) before
# any of the tokens below are read.
load_dotenv()

# Authenticate with the Hugging Face Hub only when a token is configured.
# huggingface_hub.login(token=None) falls back to an interactive prompt, which
# cannot be answered in a headless deployment and would abort start-up.
_hf_token = os.getenv("HUGGINGFACE_API_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    logger.warning("HUGGINGFACE_API_TOKEN is not set; skipping Hugging Face Hub login")
class LegalTextSearchBot:
    """Search Bharatiya Nyaya Sanhita sections in AstraDB and interpret them with an LLM.

    A query is embedded, matched against the AstraDB collection named by the
    ``ASTRA_DB_COLLECTION`` environment variable with a vector search, and the
    matching sections are passed as context to a TinyLlama chat model through
    a LangChain prompt.

    NOTE(review): ``create_interface`` builds a single instance for the whole
    app, so ``is_searching`` and ``chat_history`` are shared by every Gradio
    session: one user's Stop clears the flag for everyone, and one user's
    questions end up in another user's prompt. Isolating them would need
    per-session state (e.g. ``gr.State``).
    """

    # Length every query embedding is padded/truncated to in get_embedding;
    # presumably the dimension of the collection's vector index — confirm.
    VECTOR_DIMENSION = 1024

    def __init__(self, max_history_turns: int = 3):
        """Connect to AstraDB and load the generation and embedding models.

        Reads ``ASTRA_DB_APPLICATION_TOKEN``, ``ASTRA_DB_API_ENDPOINT`` and
        ``ASTRA_DB_COLLECTION`` from the environment.

        Args:
            max_history_turns: How many of the most recent user/AI exchanges
                are kept in ``chat_history`` and fed back into the prompt.
                Older exchanges are dropped so the prompt does not grow
                without bound as queries accumulate. ``0`` disables history.

        Raises:
            ValueError: If ``max_history_turns`` is negative.
            Exception: Any error from the database client or model loading is
                logged and re-raised.
        """
        if max_history_turns < 0:
            raise ValueError("max_history_turns must be non-negative")
        try:
            # Initialize AstraDB connection
            self.astra_db = AstraDB(
                token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
                api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT")
            )
            self.collection = self.astra_db.collection(os.getenv("ASTRA_DB_COLLECTION"))

            # Initialize language model for text generation
            model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                torch_dtype=torch.float32,
            )
            tokenizer = AutoTokenizer.from_pretrained(model_name)

            # Initialize text generation pipeline
            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=512,
                # NOTE(review): transformers applies temperature/top_p only when
                # sampling is enabled (do_sample=True). Unless the model's
                # generation config turns sampling on, decoding is greedy and
                # these two settings have no effect — confirm intent.
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.15,
                device_map="auto"
            )
            self.llm = HuggingFacePipeline(pipeline=pipe)

            # Initialize embedding model
            self.embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
            self.embedding_pipeline = pipeline(
                "feature-extraction",
                model=self.embedding_model_name,
                device_map="auto"
            )

            self.template = """
IMPORTANT: You are a legal assistant that provides accurate information based on the Indian legal sections provided in the context.
STRICT RULES:
1. Base your response ONLY on the provided legal sections
2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the legal database."
3. Do not make assumptions or use external knowledge
4. Always cite the specific section numbers you're referring to
5. Be precise and accurate in your legal interpretations
6. If quoting from the sections, use quotes and cite the section number
Context (Legal Sections): {context}
Chat History: {chat_history}
Question: {question}
Answer:"""
            self.prompt = ChatPromptTemplate.from_template(self.template)

            # Rendered history passed to the prompt, plus the individual
            # exchanges it is built from so the oldest ones can be dropped.
            self.chat_history = ""
            self.max_history_turns = max_history_turns
            self._history_turns: List[str] = []
            self.is_searching = False
        except Exception as e:
            logger.exception("Error initializing LegalTextSearchBot: %s", e)
            raise

    def get_embedding(self, text: str) -> List[float]:
        """Generate an embedding vector for *text*.

        Newlines are replaced with spaces, the text is run through the
        feature-extraction pipeline, the per-token vectors are mean-pooled,
        and the result is zero-padded or truncated to ``VECTOR_DIMENSION``
        values.

        NOTE(review): zero-padding does not make a smaller embedding comparable
        with vectors produced by a different, larger model. If the collection
        was populated with a 1024-dimensional model, embed queries with that
        same model instead — confirm which model built the index.

        Raises:
            Exception: Errors from the embedding pipeline are logged and
                re-raised.
        """
        try:
            cleaned = text.replace('\n', ' ').strip()
            outputs = self.embedding_pipeline(cleaned)
            # outputs[0] is presumably (num_tokens, hidden_size) for a single
            # input string; averaging over axis 0 yields one sentence vector.
            embedding_list = torch.mean(torch.tensor(outputs[0]), dim=0).tolist()
            target = self.VECTOR_DIMENSION
            if len(embedding_list) < target:
                embedding_list.extend([0.0] * (target - len(embedding_list)))
            return embedding_list[:target]
        except Exception as e:
            logger.exception("Error generating embedding: %s", e)
            raise

    def _cached_search(self, query: str) -> tuple:
        """Run a vector search for *query* and return up to five matching documents.

        Despite its name, this method caches nothing: every call embeds the
        query and queries AstraDB. Any error is logged and an empty tuple is
        returned, so callers treat failures as "no results".
        """
        try:
            query_embedding = self.get_embedding(query)
            results = self.collection.vector_find(
                query_embedding,
                limit=5,
                fields=["section_number", "title", "chapter_number", "chapter_title",
                        "content", "type", "metadata"]
            )
            return tuple(results)
        except Exception as e:
            logger.exception("Error in vector search: %s", e)
            return tuple()

    def _search_astra(self, query: str) -> List[Dict]:
        """Return the sections matching *query*, or [] if none were found.

        Returns [] immediately when no search is in progress (``is_searching``
        is False, e.g. after ``stop_search``).
        """
        if not self.is_searching:
            return []
        try:
            # Only vector-search hits are returned. An unfiltered fallback
            # query would surface arbitrary sections as "relevant" and feed
            # them to the LLM as context.
            return list(self._cached_search(query))
        except Exception as e:
            logger.exception("Error searching AstraDB: %s", e)
            return []

    def format_section(self, section: Dict) -> str:
        """Render one section document as a human-readable text block.

        Missing fields are shown as ``N/A``. A missing, null or empty
        ``metadata.references`` list is shown as ``None``. If formatting fails,
        the raw ``str(section)`` is returned instead.

        NOTE(review): this text is displayed in a ``gr.Markdown`` component. In
        Markdown, a line of ``=`` directly under a text line turns that line
        into a heading, so "References: ..." may render as a heading — confirm
        the intended rendering.
        """
        try:
            # ``or {}`` / ``or []`` also cover fields stored as null; str() lets
            # non-string references (e.g. section numbers) be joined.
            metadata = section.get('metadata') or {}
            references = ', '.join(str(ref) for ref in metadata.get('references') or []) or 'None'
            return f"""
{'='*80}
Chapter {section.get('chapter_number', 'N/A')}: {section.get('chapter_title', 'N/A')}
Section {section.get('section_number', 'N/A')}: {section.get('title', 'N/A')}
Type: {section.get('type', 'section')}
Content:
{section.get('content', 'N/A')}
References: {references}
{'='*80}
"""
        except Exception as e:
            logger.exception("Error formatting section: %s", e)
            return str(section)

    def _remember_exchange(self, query: str, ai_response: str) -> None:
        """Append one exchange and keep only the newest ``max_history_turns`` of them."""
        self._history_turns.append(f"\nUser: {query}\nAI: {ai_response}\n")
        if self.max_history_turns > 0:
            # A [-0:] slice would keep everything, hence the explicit branch.
            self._history_turns = self._history_turns[-self.max_history_turns:]
        else:
            self._history_turns = []
        self.chat_history = "".join(self._history_turns)

    def search_sections(self, query: str, progress=gr.Progress()) -> Tuple[str, str]:
        """Search the legal database for *query* and ask the LLM to interpret the hits.

        Args:
            query: The user's question or topic.
            progress: Progress tracker. The ``gr.Progress()`` default is the
                Gradio convention for having a tracker injected.

        Returns:
            ``(formatted_sections, ai_response)``. On empty input, no results,
            cancellation or error, a pair of user-facing status messages is
            returned instead.
        """
        self.is_searching = True
        start_time = time.perf_counter()
        try:
            progress(0, desc="Initializing search...")
            if not query or not query.strip():
                return "Please enter a search query.", "Please provide a specific legal question or topic to search for."

            progress(0.1, desc="Searching relevant sections...")
            search_results = self._search_astra(query)
            # Check for cancellation first so a Stop during retrieval is
            # reported as such rather than as "no results".
            if not self.is_searching:
                return "Search cancelled.", "Search was stopped by user."
            if not search_results:
                return "No relevant sections found.", "I apologize, but I cannot find relevant sections in the database."

            progress(0.3, desc="Processing results...")
            raw_results = []
            context_parts = []
            total = len(search_results)
            for idx, result in enumerate(search_results):
                if not self.is_searching:
                    return "Search cancelled.", "Search was stopped by user."
                raw_results.append(self.format_section(result))
                context_parts.append(f"""
Section {result.get('section_number')}: {result.get('title')}
{result.get('content', '')}
""")
                progress((0.3 + (idx * 0.1)), desc=f"Processing result {idx + 1} of {total}...")

            if not self.is_searching:
                return "Search cancelled.", "Search was stopped by user."

            progress(0.8, desc="Generating AI interpretation...")
            context = "\n\n".join(context_parts)
            chain = self.prompt | self.llm
            # The stop flag is not consulted while the LLM is generating.
            ai_response = chain.invoke({
                "context": context,
                "chat_history": self.chat_history,
                "question": query
            })
            self._remember_exchange(query, ai_response)

            elapsed_time = time.perf_counter() - start_time
            logger.info("Search completed in %.2f seconds", elapsed_time)
            progress(1.0, desc="Search complete!")
            return "\n".join(raw_results), ai_response
        except Exception as e:
            logger.exception("Error processing query: %s", e)
            return f"Error processing query: {str(e)}", "An error occurred while processing your query."
        finally:
            self.is_searching = False

    def stop_search(self) -> Tuple[str, str]:
        """Request cancellation of the running search.

        Clears ``is_searching``. ``search_sections`` checks this flag between
        its steps and ``_search_astra`` checks it before querying. An LLM call
        that is already running is not interrupted by the flag itself.

        Returns:
            The pair of status messages shown in the two output panes.
        """
        self.is_searching = False
        return "Search cancelled.", "Search was stopped by user."
def create_interface():
    """Build the Gradio Blocks UI for the legal search bot.

    Returns:
        The ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks(title="Bharatiya Nyaya Sanhita Search", theme=gr.themes.Soft()) as iface:
        # NOTE(review): this single bot instance (and its search flag and chat
        # history) is shared by every session connected to the app.
        search_bot = LegalTextSearchBot()

        gr.Markdown("""
# π Bharatiya Nyaya Sanhita Legal Search System
Search through the Bharatiya Nyaya Sanhita, 2023 and get:
1. π Relevant sections, explanations, and illustrations
2. π€ AI-powered interpretation of the legal content
*Use the Stop button if you want to cancel a long-running search.*
""")

        with gr.Row():
            query_input = gr.Textbox(
                label="Your Query",
                placeholder="e.g., What are the penalties for public servants who conceal information?",
                lines=2
            )
        with gr.Row():
            search_button = gr.Button("π Search", variant="primary", scale=4)
            stop_button = gr.Button("π Stop", variant="stop", scale=1)
        with gr.Row():
            raw_output = gr.Markdown(label="π Relevant Legal Sections")
            ai_output = gr.Markdown(label="π€ AI Interpretation")

        gr.Examples(
            examples=[
                "What are the penalties for public servants who conceal information?",
                "What constitutes criminal conspiracy?",
                "Explain the provisions related to culpable homicide",
                "What are the penalties for causing death by negligence?",
                "What are the punishments for corruption?"
            ],
            inputs=query_input,
            label="Example Queries"
        )

        # Handle search (button click)
        search_event = search_button.click(
            fn=search_bot.search_sections,
            inputs=query_input,
            outputs=[raw_output, ai_output],
        )
        # Handle Enter key
        submit_event = query_input.submit(
            fn=search_bot.search_sections,
            inputs=query_input,
            outputs=[raw_output, ai_output],
        )
        # Handle stop: cancel searches started either way, so one started with
        # Enter can be stopped too.
        stop_button.click(
            fn=search_bot.stop_search,
            outputs=[raw_output, ai_output],
            cancels=[search_event, submit_event]
        )
    return iface
# Entry point: build the UI and launch it. Run as a script, it launches with
# default options. On import, it launches with share=False and keeps the
# return value of launch() in the module-level name ``app``. In both cases the
# Blocks object is bound to the module-level name ``demo``. Launch failures are
# logged rather than raised.
# NOTE(review): importing this module therefore also starts the server as a
# side effect — confirm that the hosting environment relies on this.
try:
    demo = create_interface()
    if __name__ == "__main__":
        demo.launch()
    else:
        app = demo.launch(share=False)
except Exception as e:
    logger.error(f"Error launching application: {str(e)}")